# Source code for skll.config.utils

# License: BSD 3 clause
"""
Utility classes and functions to parse SKLL configuration files.

:author: Nitin Madnani (nmadnani@ets.org)
:author: Dan Blanchard (dblanchard@ets.org)
:author: Michael Heilman (mheilman@ets.org)
"""

import csv
import errno
import logging
from pathlib import Path
from typing import Iterable, List, Union

from ruamel.yaml import YAML

from skll.types import FoldMapping, PathOrStr


def fix_json(json_string: str) -> str:
    """
    Normalize quotes and capitalized booleans in a JSON-style string.

    Python-style ``True``/``False`` are lowercased and every single quote
    is turned into a double quote.

    Parameters
    ----------
    json_string : str
        A JSON-style string.

    Returns
    -------
    str
        The normalized JSON string.
    """
    # These are plain substring substitutions, applied in this order; they
    # also affect matches that happen to occur inside quoted values.
    substitutions = (
        ("True", "true"),
        ("False", "false"),
        ("'", '"'),
    )
    normalized = json_string
    for old, new in substitutions:
        normalized = normalized.replace(old, new)
    return normalized
def load_cv_folds(folds_file: PathOrStr, ids_to_floats=False) -> FoldMapping:
    """
    Load cross-validation folds from a CSV file.

    The CSV file must contain two columns: example ID and fold ID (and a header).
    Completely blank rows are skipped, and a file without any rows yields
    an empty mapping.

    Parameters
    ----------
    folds_file : :class:`skll.types.PathOrStr`
        The path to a folds file to read.
    ids_to_floats : bool, default=False
        Whether to convert IDs to floats.

    Returns
    -------
    :class:`skll.types.FoldMapping`
        Dictionary with example IDs as the keys and fold IDs as the values.
        If `ids_to_floats` is set to `True`, the example IDs are floats but
        otherwise they are strings.

    Raises
    ------
    ValueError
        If example IDs cannot be converted to floats and `ids_to_floats` is `True`.
    """
    res = {}

    # The csv module requires files to be opened with newline="" so that
    # newlines embedded in quoted fields are handled by the reader itself.
    with open(folds_file, newline="") as f:
        reader = csv.reader(f)

        # discard the header; the default keeps an empty file from
        # leaking a bare StopIteration to the caller
        next(reader, None)

        for row in reader:
            # csv.reader yields an empty list for blank lines (e.g. a
            # trailing newline at the end of the file); nothing to record
            if not row:
                continue

            example_id: Union[float, str] = row[0]
            fold_id: str = row[1]
            if ids_to_floats:
                try:
                    example_id = float(example_id)
                except ValueError as exc:
                    # keep the original conversion error as the cause
                    raise ValueError(
                        "You set ids_to_floats to true, but ID "
                        f"{row[0]} could not be converted to "
                        "float"
                    ) from exc
            res[example_id] = fold_id

    return res
def locate_file(file_path: PathOrStr, config_dir: PathOrStr) -> str:
    """
    Find a file, given its path and the configuration directory.

    An absolute ``file_path`` is checked as given. A relative one is joined
    onto ``config_dir`` and resolved before being checked.

    Parameters
    ----------
    file_path : :class:`skll.types.PathOrStr`
        The file to locate. Path may be absolute or relative.
    config_dir : :class:`skll.types.PathOrStr`
        The path to the configuration file directory.

    Returns
    -------
    path_to_check : str
        The path that was checked, as a string, if it exists. An empty
        string is returned when ``file_path`` is empty.

    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    """
    # nothing to locate: an empty path is passed through as an empty string
    if not file_path:
        return ""

    # work with Path objects from here on
    candidate = Path(file_path)
    base_dir = Path(config_dir)

    # only relative paths are anchored to the configuration directory
    if not candidate.is_absolute():
        candidate = (base_dir / candidate).resolve()

    if candidate.exists():
        return str(candidate)

    raise FileNotFoundError(errno.ENOENT, f"File {candidate} does not exist", str(candidate))
def _munge_featureset_name(name_or_list: Union[Iterable, str]) -> str:
    """
    Build a single munged name for a featureset.

    A string is returned unchanged. Any other iterable has its items
    sorted and joined with '+'.

    Parameters
    ----------
    name_or_list : Union[Iterable, str]
        A featureset name or name components in a list.

    Returns
    -------
    res : str
        name components joined with '+' if input is a list or the name itself.
    """
    already_a_name = isinstance(name_or_list, str)
    return name_or_list if already_a_name else "+".join(sorted(name_or_list))


def _parse_and_validate_metrics(metrics: str, option_name: str, logger=None) -> List[str]:
    """
    Turn a string holding a list of metrics into a validated list.

    The string is normalized with ``fix_json``, parsed with a safe YAML
    loader, and then checked for metric names that are no longer supported.

    Parameters
    ----------
    metrics : str
        A string containing a list of metrics.
    option_name : str
        The name of the option with which the metrics are associated.
    logger : Optional[logging.Logger], default=None
        A logging object.

    Returns
    -------
    metrics : List[str]
        A list of metrics for the given option.

    Raises
    ------
    TypeError
        If the given string cannot be converted to a list.
    ValueError
        If "mean_squared_error" is specified as a metric.
    """
    # fall back to a module-level logger when the caller did not supply one
    logger = logger or logging.getLogger(__name__)

    # normalize the quoting/booleans first, then let the safe YAML loader
    # turn the string into Python objects
    parser = YAML(typ="safe", pure=True)
    normalized = fix_json(metrics)
    parsed = parser.load(normalized)

    # anything other than a list is a configuration error
    if not isinstance(parsed, list):
        raise TypeError(f"{option_name} should be a list, not a " f"{type(parsed)}.")

    # `mean_squared_error` is no longer supported.
    # It has been replaced by `neg_mean_squared_error`
    if "mean_squared_error" in parsed:
        raise ValueError(
            'The metric "mean_squared_error" is no longer '
            "supported. please use the metric "
            '"neg_mean_squared_error" instead.'
        )

    return parsed