# Source code for sdmetrics.multi_table.detection.base

"""Base class for Machine Learning Detection metrics that work on multiple tables."""

from sdmetrics.multi_table.base import MultiTableMetric


class DetectionMetric(MultiTableMetric):
    """Base class for Machine Learning Detection based metrics on multiple tables.

    These metrics build a Machine Learning Classifier that learns to tell the synthetic
    data apart from the real data, which later on is evaluated using Cross Validation.

    The output of the metric is one minus the average ROC AUC score obtained.

    Attributes:
        name (str):
            Name to use when reports about this metric are printed.
        goal (sdmetrics.goal.Goal):
            The goal of this metric.
        min_value (Union[float, tuple[float]]):
            Minimum value or values that this metric can take.
        max_value (Union[float, tuple[float]]):
            Maximum value or values that this metric can take.
    """

    # Left as ``None`` on this base class; presumably filled in by concrete
    # detection metric subclasses (TODO confirm against the subclasses).
    name = None
    goal = None
    min_value = None
    max_value = None

    @classmethod
    def compute(cls, real_data, synthetic_data, metadata=None):
        """Compute this metric.

        Args:
            real_data (dict[str, pandas.DataFrame]):
                The tables from the real dataset.
            synthetic_data (dict[str, pandas.DataFrame]):
                The tables from the synthetic dataset.
            metadata (dict):
                Multi-table metadata dict. If not passed, it is built based
                on the real_data fields and dtypes.

        Returns:
            Union[float, tuple[float]]:
                Metric output.

        Raises:
            NotImplementedError:
                Always, on this base class; subclasses must override this method.
        """
        raise NotImplementedError()

    @classmethod
    def normalize(cls, raw_score):
        """Normalize ``raw_score`` by delegating to the parent class.

        This class adds no detection-specific transformation of its own. The
        score is handed to the parent ``normalize`` implementation. That
        implementation presumably rescales using ``min_value``, ``max_value``
        and ``goal`` (NOTE(review): confirm in ``MultiTableMetric``). For a
        score that is already on a normalized scale, it is expected to come
        back unchanged.

        Args:
            raw_score (float):
                The value of the metric from `compute`.

        Returns:
            float:
                The normalized value of the metric, as returned by the parent
                class implementation.
        """
        # Zero-argument ``super()`` works inside a classmethod and forwards
        # ``cls`` so the parent sees this subclass's class attributes.
        return super().normalize(raw_score)