tgan.trainer module
GAN Models.
class tgan.trainer.GANTrainer(model, input_queue)
Bases: tensorpack.train.tower.TowerTrainer

GANTrainer model.
We need to set tower_func() because this is a TowerTrainer, and only a TowerTrainer supports automatic graph creation for inference during training. If we don't care about inference during training, setting tower_func() is not needed; calling model.build_graph() directly is OK.
- Parameters
input_queue (tensorpack.input_source.QueueInput) – Data input.
model (tgan.GAN.GANModelDesc) – Model to train.
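According to the MultiGPUGANTrainer entry, GANTrainer optimizes d (the discriminator) and g (the generator) one by one. The stdlib-only sketch below shows that alternating pattern. It is not TGAN's code: `alternating_step` and `train` are hypothetical helpers, the two "updates" are stand-in callables that only record their name, and the order of d and g within a step is an assumption exposed through the `d_first` flag.

```python
from typing import Callable, List


def alternating_step(d_update: Callable[[], None],
                     g_update: Callable[[], None],
                     d_first: bool = True) -> None:
    """Hypothetical helper: run one d update and one g update, sequentially.

    In a TensorFlow graph this ordering would typically be enforced with a
    control dependency between the two minimize ops; plain sequential calls
    stand in for that here.
    """
    first, second = (d_update, g_update) if d_first else (g_update, d_update)
    first()
    second()


def train(num_steps: int, d_first: bool = True) -> List[str]:
    """Hypothetical driver: run num_steps alternating steps and log which op ran."""
    log: List[str] = []
    for _ in range(num_steps):
        alternating_step(lambda: log.append("d"),
                         lambda: log.append("g"),
                         d_first)
    return log


print(train(2))                 # ['d', 'g', 'd', 'g']
print(train(1, d_first=False))  # ['g', 'd']
```

Every step runs both updates exactly once, so d and g always receive the same number of updates. SeparateGANTrainer relaxes this by letting the two ops run at different rates.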
class tgan.trainer.MultiGPUGANTrainer(nr_gpu, input, model)
Bases: tensorpack.train.tower.TowerTrainer

A replacement for GANTrainer (which optimizes d and g one after the other) with multi-GPU support.
- Parameters
nr_gpu (int) – Number of GPUs to use.
input (tensorpack.input_source.QueueInput) – Data input.
model (tgan.GAN.GANModelDesc) – Model to train.
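The following sketch illustrates the general data-parallel idea behind multi-GPU training. Each of `nr_gpu` replicas (towers) computes the loss on its own shard of the batch, and the per-tower losses are averaged before a single optimizer step. This page does not specify whether MultiGPUGANTrainer averages losses or gradients, or how it splits the input, so treat this as an assumption. The helpers `shard`, `tower_loss` and `averaged_loss` and the toy squared-error loss are hypothetical. The towers also run sequentially here, not on real GPUs.

```python
from typing import List, Sequence


def shard(batch: Sequence[float], nr_gpu: int) -> List[List[float]]:
    """Hypothetical helper: split a batch into nr_gpu equal, contiguous shards."""
    if nr_gpu < 1 or len(batch) % nr_gpu != 0:
        raise ValueError("batch size must be a positive multiple of nr_gpu")
    size = len(batch) // nr_gpu
    return [list(batch[i * size:(i + 1) * size]) for i in range(nr_gpu)]


def tower_loss(samples: Sequence[float], theta: float) -> float:
    """Toy per-tower loss: mean squared error of a scalar parameter theta."""
    return sum((x - theta) ** 2 for x in samples) / len(samples)


def averaged_loss(batch: Sequence[float], theta: float, nr_gpu: int) -> float:
    """Compute one loss per tower on its shard, then average across towers."""
    losses = [tower_loss(s, theta) for s in shard(batch, nr_gpu)]
    return sum(losses) / nr_gpu


batch = [1.0, 2.0, 3.0, 4.0]
print(averaged_loss(batch, theta=2.5, nr_gpu=2))  # 1.25
print(tower_loss(batch, theta=2.5))               # 1.25 (single-device loss)
```

With equal-sized shards, the average of the per-tower mean losses equals the mean loss over the full batch. Splitting work across GPUs in this way therefore leaves the optimized objective unchanged.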
class tgan.trainer.SeparateGANTrainer(input, model, d_period=1, g_period=1)
Bases: tensorpack.train.tower.TowerTrainer

A GAN trainer that runs its two optimization ops, d_opt and g_opt, at a configurable ratio set by d_period and g_period.
- Parameters
input (tensorpack.input_source.QueueInput) – Data input.
model (tgan.GAN.GANModelDesc) – Model to train.
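d_period (int) – Period of d_opt: the discriminator optimization op runs once every d_period steps.
g_period (int) – Period of g_opt: the generator optimization op runs once every g_period steps.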
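Here is a minimal sketch of how `d_period` and `g_period` could translate into a run schedule. It assumes d_opt runs on steps where `step % d_period == 0` and g_opt on steps where `step % g_period == 0`, with d_opt first when both fire. That reading of "period" is an assumption, and `separate_schedule` is a hypothetical helper, not part of TGAN.

```python
from typing import List, Tuple


def separate_schedule(num_steps: int, d_period: int = 1,
                      g_period: int = 1) -> List[Tuple[int, str]]:
    """Hypothetical helper: list (step, op) pairs for a separate-period GAN schedule.

    Assumes d_opt fires when step % d_period == 0 and g_opt fires when
    step % g_period == 0; within a step, d_opt is listed before g_opt.
    """
    if d_period < 1 or g_period < 1:
        raise ValueError("d_period and g_period must be positive integers")
    runs: List[Tuple[int, str]] = []
    for step in range(num_steps):
        if step % d_period == 0:
            runs.append((step, "d_opt"))
        if step % g_period == 0:
            runs.append((step, "g_opt"))
    return runs


print(separate_schedule(4, d_period=1, g_period=2))
# [(0, 'd_opt'), (0, 'g_opt'), (1, 'd_opt'), (2, 'd_opt'), (2, 'g_opt'), (3, 'd_opt')]
```

With `d_period=1, g_period=2`, the discriminator is updated twice as often as the generator, a 2:1 ratio. With the defaults (`d_period=1, g_period=1`), both ops run on every step, matching the one-by-one behaviour of GANTrainer.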