HMA1 Class

In this guide we will go through a series of steps that will let you discover functionalities of the HMA1 class.

What is HMA1?

The sdv.relational.HMA1 class implements a Hierarchical Modeling Algorithm: it recursively walks through a relational dataset and applies tabular models to all of its tables in a way that lets the models learn how the fields from the different tables relate to each other.

Let’s now discover how to use the HMA1 class.

Quick Usage

We will start by loading and exploring one of our demo datasets.

In [1]: from sdv import load_demo

In [2]: metadata, tables = load_demo(metadata=True)

This will return two objects:

  1. A Metadata object with all the information that SDV needs to know about the dataset.

In [3]: metadata
Out[3]: 
Metadata
  root_path: .
  tables: ['users', 'sessions', 'transactions']
  relationships:
    sessions.user_id -> users.user_id
    transactions.session_id -> sessions.session_id


In [4]: metadata.visualize();
[Figure: diagram of the users, sessions and transactions tables and the relationships between them]

For more details about how to build the Metadata for your own dataset, please refer to the Relational Metadata Guide.

  2. A dictionary containing three pandas.DataFrames with the tables described in the metadata object.

In [5]: tables
Out[5]: 
{'users':    user_id country gender  age
 0        0      US      M   34
 1        1      UK      F   23
 2        2      ES   None   44
 3        3      UK      M   22
 4        4      US      F   54
 5        5      DE      M   57
 6        6      BG      F   45
 7        7      ES   None   41
 8        8      FR      F   23
 9        9      UK   None   30,
 'sessions':    session_id  user_id  device       os  minutes
 0           0        0  mobile  android       23
 1           1        1  tablet      ios       12
 2           2        1  tablet  android        8
 3           3        2  mobile  android       13
 4           4        4  mobile      ios        9
 5           5        5  mobile  android       32
 6           6        6  mobile      ios        7
 7           7        6  tablet      ios       21
 8           8        6  mobile      ios       29
 9           9        8  tablet      ios       34,
 'transactions':    transaction_id  session_id           timestamp  amount  cancelled
 0               0           0 2019-01-01 12:34:32   100.0      False
 1               1           0 2019-01-01 12:42:21    55.3      False
 2               2           1 2019-01-07 17:23:11    79.5      False
 3               3           3 2019-01-10 11:08:57   112.1       True
 4               4           5 2019-01-10 21:54:08   110.0       True
 5               5           5 2019-01-11 11:21:20    76.3      False
 6               6           7 2019-01-22 14:44:10    89.5      False
 7               7           8 2019-01-23 10:14:09   132.1       True
 8               8           9 2019-01-27 16:09:17    68.0      False
 9               9           9 2019-01-29 12:10:48    99.9      False}
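Before modeling, it can be useful to confirm that the foreign keys declared in the metadata actually hold in the data. A minimal check of that property, using small stand-in frames with the same key columns as the demo tables (this is plain pandas, not part of the SDV API):

```python
import pandas as pd

# Stand-in frames with the same key columns as the demo tables.
users = pd.DataFrame({'user_id': [0, 1, 2]})
sessions = pd.DataFrame({'session_id': [0, 1], 'user_id': [0, 2]})

# Every sessions.user_id must reference an existing users.user_id.
valid = sessions['user_id'].isin(users['user_id']).all()
print(bool(valid))  # True
```

The same check can be repeated for each relationship listed in the metadata, e.g. transactions.session_id against sessions.session_id.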

Let us now use the HMA1 class to model this data so that we can later sample synthetic data about new users. In order to do this you will need to:

  • Import the sdv.relational.HMA1 class and create an instance of it passing the metadata that we just loaded.

  • Call its fit method passing the tables dict.

In [6]: from sdv.relational import HMA1

In [7]: model = HMA1(metadata)

In [8]: model.fit(tables)

Note

During the previous steps, SDV walked through all the tables in the dataset following the relationships specified in the metadata, learned each table using a GaussianCopula model, and then augmented each parent table with the copula parameters of its children before learning it. This way, each copula model was able to learn how the rows of the child tables relate to the rows of their parent tables.
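The parent-augmentation idea described in the note can be illustrated with a toy example; the sketch below is only conceptual plain pandas, not SDV's actual implementation, and the choice of count/mean as the derived "parameters" is an illustrative assumption:

```python
import pandas as pd

users = pd.DataFrame({'user_id': [0, 1, 2], 'age': [34, 23, 44]})
sessions = pd.DataFrame({'user_id': [0, 1, 1], 'minutes': [23, 12, 8]})

# Derive per-parent "parameters" from the child table (here simply a
# row count and a mean), then attach them to the parent rows so that
# the model fitted on the parent table also sees child behavior.
params = sessions.groupby('user_id')['minutes'].agg(['count', 'mean']).reset_index()
augmented = users.merge(params, on='user_id', how='left').fillna({'count': 0})
print(augmented)
```

In the real algorithm the derived columns are the fitted copula parameters of each child partition rather than simple aggregates, but the augmentation pattern is the same.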

Generate synthetic data from the model

Once the training process has finished you are ready to generate new synthetic data by calling the sample method from your model.

In [9]: new_data = model.sample()

This will return a dictionary of tables with the same structure as the one the model was fitted on, but filled with new data that resembles the original.

In [10]: new_data
Out[10]: 
{'users':    user_id country gender  age
 0        0      ES      M   27
 1        1      DE      F   56
 2        2      ES      M   42
 3        3      US      F   24
 4        4      DE      F   32
 5        5      UK    NaN   47
 6        6      DE      F   47
 7        7      UK      F   25
 8        8      UK      F   39
 9        9      UK      F   57,
 'sessions':     session_id  user_id  device       os  minutes
 0            0        0  mobile  android       27
 1            1        1  mobile      ios       10
 2            2        1  mobile  android        8
 3            3        1  mobile  android       28
 4            4        4  tablet      ios       20
 5            5        4  tablet      ios       27
 6            6        5  tablet      ios       23
 7            7        5  tablet      ios       31
 8            8        5  mobile      ios       23
 9            9        6  mobile  android        9
 10          10        6  mobile      ios       12
 11          11        6  tablet      ios       14
 12          12        7  mobile      ios        7
 13          13        7  mobile  android        7
 14          14        8  tablet      ios       19
 15          15        8  tablet      ios       10
 16          16        8  tablet      ios       10
 17          17        9  mobile  android       12
 18          18        9  mobile      ios        7
 19          19        9  mobile  android        7,
 'transactions':     transaction_id  session_id           timestamp  amount  cancelled
 0                0           0 2019-01-10 03:58:42    80.3      False
 1                1           0 2019-01-11 07:57:57   122.1      False
 2                2           1 2019-01-01 12:55:17   112.9       True
 3                3           2 2019-01-08 14:53:33    85.6      False
 4                4           3 2019-01-08 20:35:11    94.5       True
 5                5           4 2019-01-23 05:43:09    91.1      False
 6                6           5 2019-01-28 20:04:04   118.8      False
 7                7           6 2019-01-27 15:28:46   109.3       True
 8                8           7 2019-01-28 07:41:40   115.6       True
 9                9           8 2019-01-28 23:38:40    66.2      False
 10              10           9 2019-01-14 18:25:21   124.1       True
 11              11          13 2019-01-01 16:49:37   103.4      False
 12              12          14 2019-01-19 22:13:22   109.5       True
 13              13          17 2019-01-08 22:34:41   130.5      False
 14              14          18 2019-01-11 10:15:59   131.0      False
 15              15          19 2019-01-13 03:47:21   108.7      False}

Save and Load the model

In many scenarios it will be convenient to generate synthetic versions of your data directly in systems that do not have access to the original data source. For example, you may want to generate testing data on the fly inside a testing environment that does not have access to your production database. In these scenarios, fitting the model with real data every time that you need to generate new data is not feasible, so you will need to fit a model in your production environment, save the fitted model to a file, send this file to the testing environment, and then load it there to sample from it.

Let’s see how this process works.

Save and share the model

Once you have fitted the model, all you need to do is call its save method passing the name of the file in which you want to save the model. Note that the extension of the filename is not relevant, but we will be using the .pkl extension to highlight that the serialization protocol used is pickle.

In [11]: model.save('my_model.pkl')

This will have created a file called my_model.pkl in the same directory in which you are running SDV.
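The save/load round trip follows the standard pickle pattern. A self-contained sketch with only the standard library, using a plain dict as a stand-in for the fitted model (SDV's own save/load handles this for you):

```python
import os
import pickle
import tempfile

# Stand-in for a fitted model object.
model_state = {'tables': ['users', 'sessions', 'transactions']}

# Serialize to a file...
path = os.path.join(tempfile.mkdtemp(), 'my_model.pkl')
with open(path, 'wb') as f:
    pickle.dump(model_state, f)

# ...and load it back, e.g. on another system.
with open(path, 'rb') as f:
    loaded_state = pickle.load(f)

print(loaded_state == model_state)  # True
```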

Important

If you inspect the generated file you will notice that its size is much smaller than the size of the data that you used to generate it. This is because the serialized model contains no information about the original data, other than the parameters it needs to generate synthetic versions of it. This means that you can safely share this my_model.pkl file without the risk of disclosing any of your real data!

Load the model and generate new data

The file you just generated can be sent over to the system where the synthetic data will be generated. Once it is there, you can load it using the HMA1.load method, and then you are ready to sample new data from the loaded instance:

In [12]: loaded = HMA1.load('my_model.pkl')

In [13]: new_data = loaded.sample()

In [14]: new_data.keys()
Out[14]: dict_keys(['users', 'sessions', 'transactions'])

Warning

Notice that the system where the model is loaded also needs to have sdv installed; otherwise it will not be able to load the model and use it.

How to control the number of rows?

In the steps above we did not tell the model how many rows we wanted to sample, so it produced as many rows as there were in the original dataset.

If you want to produce a different number of rows you can pass it as the num_rows argument and it will produce the indicated number of rows:

In [15]: model.sample(num_rows=5)
Out[15]: 
{'users':    user_id country gender  age
 0       10      ES      F   39
 1       11      US      M   33
 2       12      ES      F   57
 3       13      US      M   27
 4       14      FR    NaN   57,
 'sessions':    session_id  user_id  device   os  minutes
 0          20       10  tablet  ios       21
 1          21       10  tablet  ios       22
 2          22       11  tablet  ios       13
 3          23       12  tablet  ios       22
 4          24       12  mobile  ios       18
 5          25       14  tablet  ios       19,
 'transactions':    transaction_id  session_id           timestamp  amount  cancelled
 0              16          20 2019-01-14 00:37:00    84.1       True
 1              17          22 2019-01-01 07:24:23   123.5      False
 2              18          22 2019-01-01 15:54:14   106.0      False
 3              19          25 2019-01-11 04:22:24   103.4       True}

Note

Notice that the root table users has the indicated number of rows but some of the other tables do not. This is because the number of rows in the child tables is sampled based on the values from the parent table, which means that only the root table of the dataset is affected by the passed num_rows argument.
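The effect described in the note can be seen by counting child rows per parent value: the number of sessions sampled for each user varies, so only the root table matches num_rows exactly. A quick illustration of the counting itself, in plain pandas with made-up ids:

```python
import pandas as pd

# Session rows referencing 5 sampled parent users (made-up ids).
sessions = pd.DataFrame({'user_id': [10, 10, 11, 12, 12, 14]})

# The number of child rows per parent row is data-dependent, not fixed.
counts = sessions.groupby('user_id').size()
print(counts.to_dict())  # {10: 2, 11: 1, 12: 2, 14: 1}
```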

Can I sample a subset of the tables?

On some occasions you will not be interested in generating rows for the entire dataset, but would rather generate data for only one table and its children.

To do this you can simply pass the name of the table that you want to sample.

For example, if you pass the name sessions to the sample method, the model will only generate data for the sessions table and its child table, transactions.

In [16]: model.sample('sessions', num_rows=5)
Out[16]: 
{'sessions':    session_id  user_id  device       os  minutes
 0          26       19  tablet  android       20
 1          27       19  tablet      ios       22
 2          28       19  tablet  android       16
 3          29       19  tablet  android       34
 4          30       19  mobile  android       28,
 'transactions':    transaction_id  session_id           timestamp  amount  cancelled
 0              20          26 2019-01-13 07:04:12    91.6      False
 1              21          27 2019-01-23 16:31:30    97.1      False
 2              22          28 2019-01-09 22:54:32    60.9      False
 3              23          29 2019-01-14 10:42:05    75.7      False
 4              24          29 2019-01-14 02:26:57    84.3      False
 5              25          30 2019-01-18 05:44:10   109.8       True}

If you want to further restrict the sampling process to only one table and also skip its child tables, you can add the argument sample_children=False.

For example, you can sample data from the table users only without producing any rows for the tables sessions and transactions.

In [17]: model.sample('users', num_rows=5, sample_children=False)
Out[17]: 
   user_id country gender  age
0       20      FR    NaN   35
1       21      US      M   30
2       22      BG      F   22
3       23      BG      F   57
4       24      DE      M   52

Note

In this case, since we are only producing a single table, the output is given directly as a pandas.DataFrame instead of a dictionary.
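Because the return type differs (a dict of tables normally, a single pandas.DataFrame when sample_children=False), downstream code may want to normalize it. A small hypothetical helper for that (as_table_dict is not part of SDV, just a sketch):

```python
import pandas as pd

def as_table_dict(sampled, table_name):
    """Wrap a single-table sample into the usual {name: DataFrame} dict."""
    if isinstance(sampled, pd.DataFrame):
        return {table_name: sampled}
    return sampled  # already a dict of tables

# Stand-in for the output of sample('users', sample_children=False).
users_only = pd.DataFrame({'user_id': [20, 21], 'age': [35, 30]})
result = as_table_dict(users_only, 'users')
print(list(result))  # ['users']
```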