Defining Custom Constraints

When the predefined constraints do not cover your needs, you can use CustomConstraint to define your own logic for constraining your data. There are three functions that you can define:

  • transform, which performs the forward pass when using the transform strategy. It changes your data in a way that enforces the constraint.

  • reverse_transform, which defines how to undo the changes made by transform.

  • is_valid, which indicates which rows satisfy the constraint and which do not.

Let’s look at a demo dataset:

In [1]: from sdv.demo import load_tabular_demo

In [2]: employees = load_tabular_demo()

In [3]: employees
Out[3]: 
     company     department  employee_id  age  age_when_joined  years_in_the_company     salary  annual_bonus  prior_years_experience  full_time  part_time  contractor
0       Pear          Sales            1   38               33                     5  109500.00      19500.00                       2        1.0        0.0         0.0
1       Pear         Design            5   31               30                     1  111023.72      20804.67                       5        0.0        0.0         1.0
2    Glasses             AI            1   32               29                     3  110500.00      13500.00                       3        1.0        0.0         0.0
3    Glasses  Search Engine            7   32               27                     5  138559.67      22440.78                       1        0.0        0.0         1.0
4   Cheerper        BigData            6   39               35                     4  141000.00      22500.00                       1        0.0        1.0         0.0
5   Cheerper        Support           11   42               41                     1  148500.00       5500.00                       5        0.0        1.0         0.0
6       Pear          Sales           28   49               44                     5  140000.00      15500.00                       5        1.0        0.0         0.0
7       Pear         Design           75   49               46                     3  102856.89      18708.69                       4        0.0        0.0         1.0
8    Glasses             AI           33   32               29                     3   71000.00       9000.00                       5        1.0        0.0         0.0
9    Glasses  Search Engine           56   30               25                     5   46508.76      19843.82                       5        0.0        0.0         1.0
10  Cheerper        BigData           42   44               39                     5  103000.00      21500.00                       3        0.0        1.0         0.0
11  Cheerper        Support           80   38               35                     3  146000.00       6500.00                       5        0.0        1.0         0.0

This dataset, which is also used in the single table constraints guide, contains basic details about employees. We will use it to demonstrate how you can create your own constraint.

Using the CustomConstraint

We want to generate synthetic data from the employee records. Looking at the data above, you will notice that the values in the salary column appear to be multiples of a base unit, in this case 500. In other words, salary moves in increments of 500. We will define transform and reverse_transform functions so that the sampled data satisfies this constraint.

We can achieve this with a two-step process:

  • Divide salary by the base unit (500). This makes the data easier for the model to learn, since it now sees regular integer values with no explicit constraint on them.

  • Reverse the effect by multiplying salary by the base unit. Once the model has learned regular integer values, multiplying by the base (500) brings the values back to the original data range.

In [4]: def transform(table_data):
   ...:     base = 500.
   ...:     table_data['salary'] = table_data['salary'] / base
   ...:     return table_data
   ...: 

After defining transform, we create reverse_transform, which reverses those operations.

In [5]: def reverse_transform(table_data):
   ...:     base = 500.
   ...:     table_data['salary'] = table_data['salary'].round() * base
   ...:     return table_data
   ...: 
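
The round trip can be checked on a small frame before involving a model. The sketch below is only an illustration: the toy values and the added noise are hypothetical, and the functions copy their input (a choice of this sketch, not of the original) so the steps can be compared side by side.

```python
import pandas as pd

BASE = 500.0  # base unit described in the text


def transform(table_data):
    # Work on a copy so the caller's frame is left untouched (sketch-only choice).
    table_data = table_data.copy()
    table_data['salary'] = table_data['salary'] / BASE
    return table_data


def reverse_transform(table_data):
    table_data = table_data.copy()
    # Snap to the nearest integer, then scale back to the salary range.
    table_data['salary'] = table_data['salary'].round() * BASE
    return table_data


# Hypothetical toy data: salaries that are multiples of 500.
toy = pd.DataFrame({'salary': [109500.0, 110500.0, 141000.0]})
transformed = transform(toy)

# Hypothetical noise standing in for a model that samples near the learned integers.
noisy = transformed.copy()
noisy['salary'] = noisy['salary'] + pd.Series([0.2, -0.3, 0.4])

recovered = reverse_transform(noisy)
print(recovered['salary'].tolist())
```

Because reverse_transform rounds before multiplying, values that drift slightly in the transformed space still come back as multiples of 500.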

Then, we pack everything together in a CustomConstraint.

In [6]: from sdv.constraints import CustomConstraint

In [7]: constraint = CustomConstraint(
   ...:     transform=transform,
   ...:     reverse_transform=reverse_transform
   ...: )
   ...: 

Can I apply the same function to multiple columns?

In the example above we fixed the salary format, but if we keep examining the data we will see that annual_bonus is constrained by the same logic. Rather than defining two constraints, or editing our functions for each new column that we want to constrain, we can write the functions in another style, where each function accepts the data of a single column as input.

The transform function takes column_data as input and returns the transformed column.

In [8]: def transform(column_data):
   ...:     base = 500.
   ...:     return column_data / base
   ...: 

Similarly, we define reverse_transform so that it operates on the data of a single column.

In [9]: def reverse_transform(column_data):
   ...:     base = 500.
   ...:     return column_data.round() * base
   ...: 

Now that we have our functions, we initialize CustomConstraint and specify the column(s) they should be applied to.

In [10]: constraint = CustomConstraint(
   ....:     columns=['salary', 'annual_bonus'],
   ....:     transform=transform,
   ....:     reverse_transform=reverse_transform
   ....: )
   ....: 
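
To make the column style concrete, the sketch below applies the same column-level functions to each listed column of a toy frame. The helper apply_to_columns and the toy data are hypothetical; this is one way to picture the behavior, not a description of how CustomConstraint is implemented internally.

```python
import pandas as pd


def transform(column_data):
    base = 500.
    return column_data / base


def reverse_transform(column_data):
    base = 500.
    return column_data.round() * base


def apply_to_columns(table_data, columns, func):
    """Hypothetical helper: apply a column-style function to each listed column."""
    result = table_data.copy()
    for name in columns:
        result[name] = func(result[name])
    return result


# Hypothetical toy data; 'age' is not listed and should be left alone.
toy = pd.DataFrame({
    'salary': [109500.0, 110500.0],
    'annual_bonus': [19500.0, 13500.0],
    'age': [38, 32],
})
columns = ['salary', 'annual_bonus']

transformed = apply_to_columns(toy, columns, transform)
restored = apply_to_columns(transformed, columns, reverse_transform)
print(transformed)
print(restored)
```

Only the listed columns change, which is why a single pair of functions can serve both salary and annual_bonus.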

Can I access the rest of the table from my column functions?

If we look closely at the data, we notice that salary and annual_bonus are multiples of 500 only when the employee is not a “contractor”. To take this into account, we refer to the “fixed” contractor column to decide whether the constraint should apply. Access to the contractor column allows us to transform and reverse transform the data properly.
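
This observation can be verified directly. The sketch below uses the salary and contractor values of the first four rows of the demo table shown earlier; the variable names are chosen here for illustration.

```python
import pandas as pd

# First four rows of the demo table above (salary and contractor columns only).
subset = pd.DataFrame({
    'salary': [109500.00, 111023.72, 110500.00, 138559.67],
    'contractor': [0.0, 1.0, 0.0, 1.0],
})

base = 500.
is_multiple = subset['salary'] % base == 0
is_contractor = subset['contractor'] == 1

# Non-contractor salaries are all multiples of 500; contractor salaries are not.
non_contractors_ok = bool(is_multiple[~is_contractor].all())
contractors_ok = bool(is_multiple[is_contractor].any())
print(non_contractors_ok, contractors_ok)
```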

We write our functions to take two inputs:

  • table_data, which contains the whole table.

  • column, an argument that identifies the column of interest.

With these arguments we can write our functions freely and still access 'contractor'.

We first write our transform function as we have done previously:

In [11]: def transform(table_data, column):
   ....:     base = 500.
   ....:     table_data[column] = table_data[column] / base
   ....:     return table_data
   ....: 

When defining reverse_transform, we need to distinguish between contractors and non-contractors. The operations are as follows:

  1. For contractors, round values to four decimal places, so that the result has two decimal places after multiplying by 500.

  2. For employees who are not contractors, round values to zero decimal places, so that the result is a multiple of 500.

In [12]: def reverse_transform(table_data, column):
   ....:     base = 500.
   ....:     is_not_contractor = table_data.contractor == 0.
   ....:     table_data[column] = table_data[column].round(4)
   ....:     table_data.loc[is_not_contractor, column] = table_data.loc[is_not_contractor, column].round()
   ....:     table_data[column] *= base
   ....:     return table_data
   ....: 
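
The effect of the two rounding rules can be seen on a pair of hypothetical values in the transformed space, one per employee type. This sketch assigns through a single .loc call with a row mask and a column label, which avoids chained indexing; the toy values are made up for illustration.

```python
import pandas as pd


def reverse_transform(table_data, column):
    base = 500.
    is_not_contractor = table_data['contractor'] == 0.
    # Four decimals in the transformed space become at most two after scaling.
    table_data[column] = table_data[column].round(4)
    # Non-contractors are snapped to whole units so the result is a multiple of 500.
    table_data.loc[is_not_contractor, column] = table_data.loc[is_not_contractor, column].round()
    table_data[column] *= base
    return table_data


# Hypothetical values as a model might sample them (already divided by 500).
toy = pd.DataFrame({
    'salary': [219.3, 222.047441],
    'contractor': [0.0, 1.0],
})
restored = reverse_transform(toy.copy(), 'salary')
print(restored)
```

The non-contractor row comes back as a multiple of 500, while the contractor row keeps a fractional amount of roughly 111023.70.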

We now stitch everything together and pass it to the model.

In [13]: from sdv.tabular import GaussianCopula

In [14]: constraint = CustomConstraint(
   ....:     columns=['salary', 'annual_bonus'],
   ....:     transform=transform,
   ....:     reverse_transform=reverse_transform
   ....: )
   ....: 

In [15]: gc = GaussianCopula(constraints=[constraint])

In [16]: gc.fit(employees)

In [17]: sampled = gc.sample(10)

When we view the sampled data, we should find that every row has a salary that is a multiple of the base value, with the exception of “contractor” records.

In [18]: sampled
Out[18]: 
    company     department  employee_id  age  age_when_joined  years_in_the_company     salary  annual_bonus  prior_years_experience  full_time  part_time  contractor
0      Pear          Sales           79   46               39                     4   96000.00      10500.00                       5        1.0        0.0         0.0
1      Pear         Design           26   32               32                     2   95577.15       8680.10                       5        0.0        0.0         1.0
2      Pear         Design           16   42               36                     1  126480.15       9222.05                       5        0.0        0.0         1.0
3   Glasses             AI           16   31               27                     4   58500.00      22500.00                       4        0.0        0.0         0.0
4      Pear        BigData            1   47               41                     4  148500.00      22500.00                       1        0.0        0.0         0.0
5      Pear        Support           13   35               35                     2   86000.00      12500.00                       5        0.0        0.0         0.0
6      Pear         Design            1   34               29                     3  141000.00      22500.00                       2        1.0        0.0         0.0
7  Cheerper        Support           80   44               46                     4  117842.20       9358.25                       5        0.0        1.0         1.0
8  Cheerper        Support           58   31               30                     2  115000.00      17000.00                       5        0.0        1.0         0.0
9   Glasses  Search Engine           20   30               25                     5   60273.75      21419.15                       4        1.0        0.0         1.0

This style gives flexibility to access any column in the table while still operating on a column basis.

Can I write a CustomConstraint based on reject sampling?

In the previous section, we defined our CustomConstraint using transform and reverse_transform functions. Some constraints cannot be implemented this way, and in those cases we rely on the reject_sampling strategy. With reject_sampling, we implement an is_valid function that identifies which rows satisfy the constraint and which do not. In our case, it flags the rows that are not a multiple of the base unit.
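
The general idea of rejection sampling can be sketched as a loop that draws candidate rows, keeps the ones is_valid accepts, and repeats until enough rows are collected. Everything below except the is_valid logic is hypothetical: propose stands in for a fitted model, and reject_sample is an illustrative helper, not the library's actual sampling routine.

```python
import numpy as np
import pandas as pd


def is_valid(table_data):
    base = 500.
    return table_data['salary'] % base == 0


def propose(num_rows, rng):
    """Hypothetical stand-in for a model: salaries on a 250 grid, so only some are valid."""
    return pd.DataFrame({'salary': rng.integers(100, 300, size=num_rows) * 250.0})


def reject_sample(num_rows, rng, max_tries=100):
    """Illustrative loop: keep valid candidates until num_rows rows are collected."""
    accepted = []
    total = 0
    for _ in range(max_tries):
        candidates = propose(num_rows, rng)
        kept = candidates[is_valid(candidates)]
        accepted.append(kept)
        total += len(kept)
        if total >= num_rows:
            break
    return pd.concat(accepted, ignore_index=True).head(num_rows)


rng = np.random.default_rng(0)
sampled = reject_sample(10, rng)
print(sampled)
```

The trade-off compared with the transform strategy is that invalid rows are generated and then discarded, so sampling can take more attempts when few candidates are valid.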

We can define is_valid according to the three styles mentioned in the previous section:

  1. function with table_data argument.

  2. function with column_data argument.

  3. function with table_data and column argument.

is_valid should return a pd.Series that contains True for every valid row and False for every invalid one. Here is an example of how you would define is_valid for each of the styles mentioned:

def is_valid(table_data):
    base = 500.
    return table_data['salary'] % base == 0

def is_valid(column_data):
    base = 500.
    return column_data % base == 0

def is_valid(table_data, column):
    base = 500.
    is_contractor = table_data.contractor == 1
    valid = table_data[column] % base == 0
    # Contractor values only need to have at most two decimal places.
    contractor_values = table_data[column].loc[is_contractor]
    valid.loc[is_contractor] = contractor_values == contractor_values.round(2)
    return valid
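
As a quick check of the third style, the sketch below runs a table_data-and-column is_valid on a hypothetical toy frame and uses the resulting mask to keep only the valid rows. It reads the values from table_data[column] rather than from any outer variable, so the same function works for every listed column.

```python
import pandas as pd


def is_valid(table_data, column):
    base = 500.
    is_contractor = table_data['contractor'] == 1
    valid = table_data[column] % base == 0
    # Contractor values only need to have at most two decimal places.
    contractor_values = table_data.loc[is_contractor, column]
    valid.loc[is_contractor] = contractor_values == contractor_values.round(2)
    return valid


# Hypothetical rows: two non-contractors followed by two contractors.
toy = pd.DataFrame({
    'salary': [109500.0, 109750.0, 111023.72, 111023.7249],
    'contractor': [0.0, 0.0, 1.0, 1.0],
})

mask = is_valid(toy, 'salary')
valid_rows = toy[mask]
print(mask.tolist())
print(valid_rows)
```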

Then we construct a CustomConstraint that takes only is_valid.

constraint = CustomConstraint(
    columns=['salary', 'annual_bonus'],
    is_valid=is_valid
)