In this guide we will walk through a series of steps that show the functionality of the HMA1 class.
The sdv.relational.HMA1 class implements a Hierarchical Modeling Algorithm: an algorithm that recursively walks through a relational dataset and applies tabular models across all of its tables, in a way that lets the models learn how all the fields from all the tables are related.
Let’s now discover how to use the HMA1 class.
We will start by loading and exploring one of our demo datasets.
In [1]: from sdv import load_demo

In [2]: metadata, tables = load_demo(metadata=True)
This will return two objects:
A Metadata object with all the information that SDV needs to know about the dataset.
In [3]: metadata
Out[3]:
Metadata
  root_path: .
  tables: ['users', 'sessions', 'transactions']
  relationships:
    sessions.user_id -> users.user_id
    transactions.session_id -> sessions.session_id

In [4]: metadata.visualize();
For more details about how to build the Metadata for your own dataset, please refer to the Relational Metadata Guide; a short sketch of what this looks like also follows the tables below.
A dictionary containing three pandas.DataFrames with the tables described in the metadata object.
In [5]: tables
Out[5]:
{'users':
    user_id country gender  age
 0        0      US      M   34
 1        1      UK      F   23
 2        2      ES   None   44
 3        3      UK      M   22
 4        4      US      F   54
 5        5      DE      M   57
 6        6      BG      F   45
 7        7      ES   None   41
 8        8      FR      F   23
 9        9      UK   None   30,
 'sessions':
    session_id  user_id  device       os  minutes
 0           0        0  mobile  android       23
 1           1        1  tablet      ios       12
 2           2        1  tablet  android        8
 3           3        2  mobile  android       13
 4           4        4  mobile      ios        9
 5           5        5  mobile  android       32
 6           6        6  mobile      ios        7
 7           7        6  tablet      ios       21
 8           8        6  mobile      ios       29
 9           9        8  tablet      ios       34,
 'transactions':
    transaction_id  session_id           timestamp  amount  cancelled
 0               0           0 2019-01-01 12:34:32   100.0      False
 1               1           0 2019-01-01 12:42:21    55.3      False
 2               2           1 2019-01-07 17:23:11    79.5      False
 3               3           3 2019-01-10 11:08:57   112.1       True
 4               4           5 2019-01-10 21:54:08   110.0       True
 5               5           5 2019-01-11 11:21:20    76.3      False
 6               6           7 2019-01-22 14:44:10    89.5      False
 7               7           8 2019-01-23 10:14:09   132.1       True
 8               8           9 2019-01-27 16:09:17    68.0      False
 9               9           9 2019-01-29 12:10:48    99.9      False}
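Before moving on, here is a minimal sketch of how a Metadata object like the one above could be built by hand from the demo tables. The add_table arguments shown here are assumptions based on the Metadata API of SDV 0.x, so please double-check them against the Relational Metadata Guide:

from sdv import Metadata

# Hedged sketch: rebuild the demo Metadata by hand.
metadata = Metadata()

# Root table: SDV can infer the field types from the data.
metadata.add_table(name='users', data=tables['users'], primary_key='user_id')

# Child tables declare their parent and the foreign key pointing to it.
metadata.add_table(name='sessions', data=tables['sessions'],
                   primary_key='session_id', parent='users',
                   foreign_key='user_id')
metadata.add_table(name='transactions', data=tables['transactions'],
                   primary_key='transaction_id', parent='sessions',
                   foreign_key='session_id')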
Let us now use the HMA1 class to learn this data so that we are ready to sample synthetic data about new users. In order to do this you will need to:
Import the sdv.relational.HMA1 class and create an instance of it passing the metadata that we just loaded.
Call its fit method passing the tables dict.
In [6]: from sdv.relational import HMA1

In [7]: model = HMA1(metadata)

In [8]: model.fit(tables)
Note
During the previous steps, SDV walked through all the tables in the dataset following the relationships specified by the metadata, learned each table using a GaussianCopula model, and then augmented the parent tables with the copula parameters before learning them. By doing this, each copula model was able to learn how the child table rows were related to their parent tables.
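If you want to see the hierarchy that this recursion follows, you can walk the relationships stored in the metadata. A small sketch, assuming the get_children method of the SDV 0.x Metadata class:

# Print the table hierarchy that HMA1 walks while fitting.
# NOTE: get_children is assumed from the SDV 0.x Metadata API;
# verify it against your installed version.
def show_hierarchy(metadata, table_name, depth=0):
    print('    ' * depth + table_name)
    for child in metadata.get_children(table_name):
        show_hierarchy(metadata, child, depth + 1)

show_hierarchy(metadata, 'users')
# users
#     sessions
#         transactions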
Once the training process has finished you are ready to generate new synthetic data by calling the sample method of your model.
sample
In [9]: new_data = model.sample()
This will return a dictionary of tables with the same structure as the one that the model was fitted on, but filled with new data that resembles the original.
In [10]: new_data
Out[10]:
{'users':
    user_id country gender  age
 0        0      US      M   49
 1        1      ES      F   53
 2        2      ES      F   27
 3        3      UK      M   52
 4        4      US      M   42
 5        5      FR    NaN   47
 6        6      US      F   29
 7        7      FR      F   40
 8        8      DE      F   43
 9        9      ES      M   26,
 'sessions':
    session_id  user_id  device       os  minutes
 0           0        0  mobile  android       14
 1           1        1  mobile      ios       22
 2           2        3  mobile  android       28
 3           3        3  tablet  android       21
 4           4        4  mobile      ios       11
 5           5        6  tablet      ios       25
 6           6        7  tablet      ios        2
 7           7        7  mobile      ios       24
 8           8        9  tablet  android       27,
 'transactions':
     transaction_id  session_id           timestamp      amount  cancelled
 0                0           0 2019-01-03 04:31:01   60.978416      False
 1                1           1 2019-01-14 17:38:47   68.069996      False
 2                2           1 2019-01-15 11:06:17   59.307461      False
 3                3           2 2019-01-05 09:22:47   97.318599      False
 4                4           2 2019-01-05 06:05:45   64.535362      False
 5                5           3 2019-01-12 13:08:20   64.885430       True
 6                6           3 2019-01-12 15:48:40   91.347691      False
 7                7           5 2019-01-06 01:46:02  108.910062      False
 8                8           5 2019-01-06 03:19:40   96.394395      False
 9                9           7 2019-01-23 01:41:04  107.960627      False
 10              10           8 2019-01-13 06:34:16   88.917662      False
 11              11           8 2019-01-14 03:00:12   87.453243      False}
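As a quick sanity check, you can compare some simple properties of the synthetic tables against the original ones. A minimal, hedged example:

# Compare row counts of each synthetic table with the original.
for name in tables:
    print(name, len(tables[name]), '->', len(new_data[name]))

# Categorical fields are sampled from the categories seen during
# fitting, so the synthetic countries are a subset of the real ones.
print(set(new_data['users']['country'].dropna())
      <= set(tables['users']['country'].dropna()))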
In many scenarios it will be convenient to generate synthetic versions of your data directly in systems that do not have access to the original data source. For example, you may want to generate testing data on the fly inside a testing environment that does not have access to your production database. In these scenarios, fitting the model with real data every time that you need to generate new data is not feasible, so you will need to fit a model in your production environment, save the fitted model to a file, send this file to the testing environment, and then load it there to be able to sample from it.
Let’s see how this process works.
Once you have fitted the model, all you need to do is call its save method passing the name of the file in which you want to save the model. Note that the extension of the filename is not relevant, but we will be using the .pkl extension to highlight that the serialization protocol used is pickle.
In [11]: model.save('my_model.pkl')
This will have created a file called my_model.pkl in the same directory in which you are running SDV.
Important
If you inspect the generated file you will notice that its size is much smaller than the size of the data that you used to generate it. This is because the serialized model contains no information about the original data, other than the parameters it needs to generate synthetic versions of it. This means that you can safely share this my_model.pkl file without the risk of disclosing any of your real data!
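If you want to check this yourself, a simple way is to look at the size of the file on disk:

import os

# The serialized model only stores learned parameters, not data,
# so the file is typically small. Exact numbers will vary.
print(os.path.getsize('my_model.pkl'), 'bytes')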
The file you just generated can be sent over to the system where the synthetic data will be generated. Once it is there, you can load it using the HMA1.load method, and then you are ready to sample new data from the loaded instance:
In [12]: loaded = HMA1.load('my_model.pkl')

In [13]: new_data = loaded.sample()

In [14]: new_data.keys()
Out[14]: dict_keys(['users', 'sessions', 'transactions'])
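For instance, inside the testing environment you could turn the sampled tables directly into CSV fixtures for your test suite. A minimal sketch, assuming a hypothetical fixtures directory with one file per table:

import os

os.makedirs('fixtures', exist_ok=True)

# Write each synthetic table to its own CSV fixture.
for name, dataframe in new_data.items():
    dataframe.to_csv(os.path.join('fixtures', f'{name}.csv'), index=False)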
Warning
Notice that the system where the model is loaded also needs to have sdv installed; otherwise it will not be able to load the model and use it.
In the steps above we did not tell the model at any moment how many rows we wanted to sample, so it produced as many rows as there were in the original dataset.
If you want to produce a different number of rows you can pass it as the num_rows argument and it will produce the indicated number of rows:
In [15]: model.sample(num_rows=5)
Out[15]:
{'users':
    user_id country gender  age
 0       10      US    NaN   43
 1       11      FR      M   40
 2       12      DE    NaN   40
 3       13      BG      M   19
 4       14      UK      M   44,
 'sessions':
    session_id  user_id  device       os  minutes
 0           9       10  tablet      ios       17
 1          10       11  mobile      ios       13
 2          11       12  mobile      ios       23
 3          12       12  tablet      ios       27
 4          13       12  mobile      ios       28
 5          14       13  tablet  android       15
 6          15       14  tablet      ios       30
 7          16       14  mobile      ios       20,
 'transactions':
     transaction_id  session_id           timestamp      amount  cancelled
 0               12           9 2019-01-12 08:42:21   95.111277      False
 1               13          10 2019-01-18 12:24:51   87.564454       True
 2               14          10 2019-01-18 16:16:14   80.868356       True
 3               15          11 2019-01-14 20:36:05   82.260514      False
 4               16          12 2019-01-20 11:45:10   93.706944      False
 5               17          13 2019-01-15 00:37:29  121.200205      False
 6               18          14 2019-01-15 03:10:57   97.777895       True
 7               19          15 2019-01-19 02:32:10   85.587062      False
 8               20          15 2019-01-17 08:20:01   84.328488      False
 9               21          16 2019-01-16 13:52:25   93.445029       True
 10              22          16 2019-01-17 05:45:29   81.199695       True}
Notice that the root table users has the indicated number of rows, but some of the other tables do not. This is because the number of rows in the child tables is sampled based on the values from the parent table, which means that only the root table of the dataset is affected by the num_rows argument.
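You can see this by counting the rows of each sampled table:

sampled = model.sample(num_rows=5)

# Only the root table is guaranteed to have exactly num_rows rows;
# the child table sizes are drawn from the learned distributions,
# so the counts below will vary from run to run.
print({name: len(table) for name, table in sampled.items()})
# e.g. {'users': 5, 'sessions': 8, 'transactions': 11}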
On some occasions you will not be interested in generating rows for the entire dataset, but would rather generate data for only one table and its children.
To do this, you can simply pass the name of the table that you want to sample.
For example, if you pass the name sessions to the sample method, the model will only generate data for the sessions table and its child table, transactions.
In [16]: model.sample('sessions', num_rows=5)
Out[16]:
{'sessions':
    session_id  user_id  device       os  minutes
 0          17       15  mobile      ios       41
 1          18       15  mobile      ios       18
 2          19       18  mobile      ios       14
 3          20       17  mobile      ios       30
 4          21       19  tablet  android       34,
 'transactions':
    transaction_id  session_id           timestamp      amount  cancelled
 0              23          17 2019-01-23 19:39:40  122.995859       True
 1              24          17 2019-01-22 23:26:16  134.690624       True
 2              25          19 2019-01-02 07:06:29   68.864885      False
 3              26          20 2018-12-27 13:24:15  123.332667      False
 4              27          20 2018-12-27 13:24:15   83.848067      False
 5              28          21 2019-01-08 13:44:33   77.846918      False
 6              29          21 2019-01-08 06:06:39  112.753962      False}
If you want to further restrict the sampling process to only one table and also skip its child tables, you can add the argument sample_children=False.
For example, you can sample data from the table users only without producing any rows for the tables sessions and transactions.
In [17]: model.sample('users', num_rows=5, sample_children=False)
Out[17]:
   user_id country gender  age
0       20      DE      M   51
1       21      US      M   12
2       22      US    NaN   39
3       23      ES      F   39
4       24      BG      F   48
In this case, since we are only producing a single table, the output is given directly as a pandas.DataFrame instead of a dictionary.
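If your code needs to handle both cases, a simple type check does the trick:

import pandas as pd

result = model.sample('users', num_rows=5, sample_children=False)

# A single table comes back as a DataFrame rather than a dict.
assert isinstance(result, pd.DataFrame)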