# Source code for sdv.tabular.ctgan

"""Wrapper around CTGAN model."""

import numpy as np
from ctgan import CTGANSynthesizer, TVAESynthesizer

from sdv.tabular.base import BaseTabularModel


class CTGANModel(BaseTabularModel):
    """Common wrapper logic shared by the CTGAN-based tabular models.

    Subclasses are expected to set ``_MODEL_CLASS`` to the synthesizer class
    that must be instantiated and to fill ``_model_kwargs`` with the keyword
    arguments that will be passed to it.
    """

    # Synthesizer class to instantiate; ``None`` until a subclass sets it.
    _MODEL_CLASS = None
    # Keyword arguments forwarded to ``_MODEL_CLASS`` when the model is built.
    _model_kwargs = None

    # NOTE(review): presumably tells the base class to leave object columns
    # untransformed -- confirm against ``BaseTabularModel``.
    _DTYPE_TRANSFORMERS = {
        'O': None
    }

    def _build_model(self):
        """Instantiate ``_MODEL_CLASS`` using the stored ``_model_kwargs``."""
        model_class = self._MODEL_CLASS
        return model_class(**self._model_kwargs)

    def _fit(self, table_data):
        """Build a new model and fit it to the given table.

        Columns that appear in the metadata fields are passed as discrete
        columns only when their type is ``categorical``. For any other column
        the dtype is inferred from the non-null values, and the column is
        passed as discrete when that dtype is of object or boolean kind.

        Args:
            table_data (pandas.DataFrame):
                Data to be learned.
        """
        self._model = self._build_model()

        declared_fields = self._metadata.get_fields()
        discrete_columns = []
        for column_name in table_data.columns:
            if column_name in declared_fields:
                # Declared fields rely exclusively on their metadata type.
                if declared_fields[column_name]['type'] == 'categorical':
                    discrete_columns.append(column_name)

                continue

            values = table_data[column_name].dropna()
            if set(values.unique()) == {0.0, 1.0}:
                # Floats holding only 0.0 and 1.0 are booleans in disguise,
                # so they have to be modeled as bool.
                values = values.astype(bool)

            inferred_dtype = values.infer_objects().dtype
            try:
                dtype_kind = np.dtype(inferred_dtype).kind
            except TypeError:
                # numpy cannot interpret this dtype (probably a pandas
                # category), so it is handled like an object column.
                dtype_kind = 'O'

            if dtype_kind in ('O', 'b'):
                discrete_columns.append(column_name)

        self._model.fit(table_data, discrete_columns=discrete_columns)

    def _sample(self, num_rows, conditions=None):
        """Sample ``num_rows`` rows from the fitted model.

        Args:
            num_rows (int):
                Amount of rows to sample.
            conditions (dict):
                Mapping from column names to the values that the sampled rows
                should be conditioned on. Only ``None`` is accepted.

        Returns:
            pandas.DataFrame:
                Sampled data.

        Raises:
            NotImplementedError:
                If ``conditions`` is not ``None``.
        """
        if conditions is not None:
            raise NotImplementedError(f"{self._MODEL_CLASS} doesn't support conditional sampling.")

        return self._model.sample(num_rows)


class CTGAN(CTGANModel):
    """Model wrapping ``CTGANSynthesizer`` model.

    Args:
        field_names (list[str]):
            List of names of the fields that need to be modeled
            and included in the generated output data. Any additional
            fields found in the data will be ignored and will not be
            included in the generated output.
            If ``None``, all the fields found in the data are used.
        field_types (dict[str, dict]):
            Dictionary specifying the data types and subtypes
            of the fields that will be modeled. Field types and subtypes
            combinations must be compatible with the SDV Metadata Schema.
        field_transformers (dict[str, str]):
            Dictionary specifying which transformers to use for each field.
            Available transformers are:

                * ``integer``: Uses a ``NumericalTransformer`` of dtype ``int``.
                * ``float``: Uses a ``NumericalTransformer`` of dtype ``float``.
                * ``categorical``: Uses a ``CategoricalTransformer`` without gaussian noise.
                * ``categorical_fuzzy``: Uses a ``CategoricalTransformer`` adding gaussian noise.
                * ``one_hot_encoding``: Uses a ``OneHotEncodingTransformer``.
                * ``label_encoding``: Uses a ``LabelEncodingTransformer``.
                * ``boolean``: Uses a ``BooleanTransformer``.
                * ``datetime``: Uses a ``DatetimeTransformer``.

        anonymize_fields (dict[str, str]):
            Dict specifying which fields to anonymize and what faker
            category they belong to.
        primary_key (str):
            Name of the field which is the primary key of the table.
        constraints (list[Constraint, dict]):
            List of Constraint objects or dicts.
        table_metadata (dict or metadata.Table):
            Table metadata instance or dict representation.
            If given alongside any other metadata-related arguments, an
            exception will be raised.
            If not given at all, it will be built using the other
            arguments or learned from the data.
        embedding_dim (int):
            Size of the random sample passed to the Generator. Defaults to 128.
        generator_dim (tuple or list of ints):
            Size of the output samples for each one of the Residuals. A Residual Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        discriminator_dim (tuple or list of ints):
            Size of the output samples for each one of the Discriminator Layers. A Linear Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        generator_lr (float):
            Learning rate for the generator. Defaults to 2e-4.
        generator_decay (float):
            Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
        discriminator_lr (float):
            Learning rate for the discriminator. Defaults to 2e-4.
        discriminator_decay (float):
            Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
        batch_size (int):
            Number of data samples to process in each step. Defaults to 500.
        discriminator_steps (int):
            Number of discriminator updates to do for each generator update.
            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
            default is 5. Default used is 1 to match original CTGAN implementation.
        log_frequency (boolean):
            Whether to use log frequency of categorical levels in conditional
            sampling. Defaults to ``True``.
        verbose (boolean):
            Whether to have print statements for progress results. Defaults to ``False``.
        epochs (int):
            Number of training epochs. Defaults to 300.
        pac (int):
            Number of samples to group together when applying the discriminator.
            Defaults to 10.
        cuda (bool or str):
            If ``True``, use CUDA. If a ``str``, use the indicated device.
            If ``False``, do not use cuda at all. Defaults to ``True``.
        rounding (int, str or None):
            Define rounding scheme for ``NumericalTransformer``. If set to an int, values
            will be rounded to that number of decimal places. If ``None``, values will not
            be rounded. If set to ``'auto'``, the transformer will round to the maximum number
            of decimal places detected in the fitted data. Defaults to ``'auto'``.
        min_value (int, str or None):
            Specify the minimum value the ``NumericalTransformer`` should use. If an integer
            is given, sampled data will be greater than or equal to it. If the string ``'auto'``
            is given, the minimum will be the minimum value seen in the fitted data. If ``None``
            is given, there won't be a minimum. Defaults to ``'auto'``.
        max_value (int, str or None):
            Specify the maximum value the ``NumericalTransformer`` should use. If an integer
            is given, sampled data will be less than or equal to it. If the string ``'auto'``
            is given, the maximum will be the maximum value seen in the fitted data. If ``None``
            is given, there won't be a maximum. Defaults to ``'auto'``.
    """

    _MODEL_CLASS = CTGANSynthesizer

    def __init__(self, field_names=None, field_types=None, field_transformers=None,
                 anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None,
                 embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
                 generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
                 discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
                 log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True,
                 rounding='auto', min_value='auto', max_value='auto'):
        """Store the metadata options and the ``CTGANSynthesizer`` arguments.

        The metadata-related arguments and the numerical transformer options
        are forwarded to the parent constructor, while the remaining ones are
        kept in ``_model_kwargs`` to be passed to ``_MODEL_CLASS`` later on.
        """
        super().__init__(
            field_names=field_names,
            primary_key=primary_key,
            field_types=field_types,
            field_transformers=field_transformers,
            anonymize_fields=anonymize_fields,
            constraints=constraints,
            table_metadata=table_metadata,
            rounding=rounding,
            max_value=max_value,
            min_value=min_value
        )

        self._model_kwargs = {
            'embedding_dim': embedding_dim,
            'generator_dim': generator_dim,
            'discriminator_dim': discriminator_dim,
            'generator_lr': generator_lr,
            'generator_decay': generator_decay,
            'discriminator_lr': discriminator_lr,
            'discriminator_decay': discriminator_decay,
            'batch_size': batch_size,
            'discriminator_steps': discriminator_steps,
            'log_frequency': log_frequency,
            'verbose': verbose,
            'epochs': epochs,
            'pac': pac,
            'cuda': cuda
        }


class TVAE(CTGANModel):
    """Model wrapping ``TVAESynthesizer`` model.

    Args:
        field_names (list[str]):
            List of names of the fields that need to be modeled
            and included in the generated output data. Any additional
            fields found in the data will be ignored and will not be
            included in the generated output.
            If ``None``, all the fields found in the data are used.
        field_types (dict[str, dict]):
            Dictionary specifying the data types and subtypes
            of the fields that will be modeled. Field types and subtypes
            combinations must be compatible with the SDV Metadata Schema.
        field_transformers (dict[str, str]):
            Dictionary specifying which transformers to use for each field.
            Available transformers are:

                * ``integer``: Uses a ``NumericalTransformer`` of dtype ``int``.
                * ``float``: Uses a ``NumericalTransformer`` of dtype ``float``.
                * ``categorical``: Uses a ``CategoricalTransformer`` without gaussian noise.
                * ``categorical_fuzzy``: Uses a ``CategoricalTransformer`` adding gaussian noise.
                * ``one_hot_encoding``: Uses a ``OneHotEncodingTransformer``.
                * ``label_encoding``: Uses a ``LabelEncodingTransformer``.
                * ``boolean``: Uses a ``BooleanTransformer``.
                * ``datetime``: Uses a ``DatetimeTransformer``.

        anonymize_fields (dict[str, str]):
            Dict specifying which fields to anonymize and what faker
            category they belong to.
        primary_key (str):
            Name of the field which is the primary key of the table.
        constraints (list[Constraint, dict]):
            List of Constraint objects or dicts.
        table_metadata (dict or metadata.Table):
            Table metadata instance or dict representation.
            If given alongside any other metadata-related arguments, an
            exception will be raised.
            If not given at all, it will be built using the other
            arguments or learned from the data.
        embedding_dim (int):
            Size of the random sample passed to the Generator. Defaults to 128.
        compress_dims (tuple or list of ints):
            Size of each hidden layer in the encoder. Defaults to (128, 128).
        decompress_dims (tuple or list of ints):
            Size of each hidden layer in the decoder. Defaults to (128, 128).
        l2scale (float):
            Regularization term. Defaults to 1e-5.
        batch_size (int):
            Number of data samples to process in each step. Defaults to 500.
        epochs (int):
            Number of training epochs. Defaults to 300.
        loss_factor (int):
            Multiplier for the reconstruction error. Defaults to 2.
        cuda (bool or str):
            If ``True``, use CUDA. If a ``str``, use the indicated device.
            If ``False``, do not use cuda at all. Defaults to ``True``.
        rounding (int, str or None):
            Define rounding scheme for ``NumericalTransformer``. If set to an int, values
            will be rounded to that number of decimal places. If ``None``, values will not
            be rounded. If set to ``'auto'``, the transformer will round to the maximum number
            of decimal places detected in the fitted data. Defaults to ``'auto'``.
        min_value (int, str or None):
            Specify the minimum value the ``NumericalTransformer`` should use. If an integer
            is given, sampled data will be greater than or equal to it. If the string ``'auto'``
            is given, the minimum will be the minimum value seen in the fitted data. If ``None``
            is given, there won't be a minimum. Defaults to ``'auto'``.
        max_value (int, str or None):
            Specify the maximum value the ``NumericalTransformer`` should use. If an integer
            is given, sampled data will be less than or equal to it. If the string ``'auto'``
            is given, the maximum will be the maximum value seen in the fitted data. If ``None``
            is given, there won't be a maximum. Defaults to ``'auto'``.
    """

    _MODEL_CLASS = TVAESynthesizer

    def __init__(self, field_names=None, field_types=None, field_transformers=None,
                 anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None,
                 embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128),
                 l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True,
                 rounding='auto', min_value='auto', max_value='auto'):
        """Store the metadata options and the ``TVAESynthesizer`` arguments.

        The metadata-related arguments and the numerical transformer options
        are forwarded to the parent constructor, while the remaining ones are
        kept in ``_model_kwargs`` to be passed to ``_MODEL_CLASS`` later on.
        """
        super().__init__(
            field_names=field_names,
            primary_key=primary_key,
            field_types=field_types,
            field_transformers=field_transformers,
            anonymize_fields=anonymize_fields,
            constraints=constraints,
            table_metadata=table_metadata,
            rounding=rounding,
            max_value=max_value,
            min_value=min_value
        )

        self._model_kwargs = {
            'embedding_dim': embedding_dim,
            'compress_dims': compress_dims,
            'decompress_dims': decompress_dims,
            'l2scale': l2scale,
            'batch_size': batch_size,
            'epochs': epochs,
            'loss_factor': loss_factor,
            'cuda': cuda
        }