tgan.trainer module

GAN Models.

class tgan.trainer.GANTrainer(model, input_queue)

Bases: tensorpack.train.tower.TowerTrainer

Trainer for a GAN model.

The trainer sets tower_func because it is a TowerTrainer, and only a TowerTrainer supports automatic graph creation for inference during training.

If inference during training is not needed, setting tower_func is unnecessary; calling model.build_graph() directly is enough.

Parameters
  • model (tgan.GAN.GANModelDesc) – Model to train.

  • input_queue (tensorpack.input_source.QueueInput) – Data input.
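The MultiGPUGANTrainer entry describes GANTrainer as optimizing d and g one by one. The plain-Python sketch below illustrates that idea only; it is not tgan's implementation. ToyGANTrainer, g_min and d_min are hypothetical stand-ins for the generator and discriminator optimization ops (in the real trainer these are TensorFlow ops run in a tensorpack session), and the g-then-d order within a step is an assumption.

```python
from typing import Callable, List


class ToyGANTrainer:
    """Hypothetical trainer that updates g and d one by one in each step.

    g_min and d_min are plain callables standing in for the generator and
    discriminator optimization ops (an assumption for illustration only).
    """

    def __init__(self, g_min: Callable[[], None], d_min: Callable[[], None]):
        self.g_min = g_min
        self.d_min = d_min
        self.global_step = 0

    def run_step(self) -> None:
        # One training iteration: the generator update runs first and the
        # discriminator update runs after it (assumed order), mimicking a
        # control dependency between the two ops.
        self.g_min()
        self.d_min()
        self.global_step += 1


log: List[str] = []
trainer = ToyGANTrainer(g_min=lambda: log.append("g"),
                        d_min=lambda: log.append("d"))
for _ in range(3):
    trainer.run_step()
print(log)  # ['g', 'd', 'g', 'd', 'g', 'd']
```

Every step performs exactly one update of each network, so the two are trained at a fixed 1:1 ratio.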

class tgan.trainer.MultiGPUGANTrainer(nr_gpu, input, model)

Bases: tensorpack.train.tower.TowerTrainer

A replacement for GANTrainer with multi-GPU support. Like GANTrainer, it optimizes d (the discriminator) and g (the generator) one by one.

Parameters
  • nr_gpu (int) – Number of GPUs to use.

  • input (tensorpack.input_source.QueueInput) – Data input.

  • model (tgan.GAN.GANModelDesc) – Model to train.
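With nr_gpu towers, each GPU computes its own generator and discriminator losses, which must be combined before the two optimizers run. The sketch below assumes the simplest strategy, averaging the per-tower losses; treat that as an assumption about tgan rather than a description of its code. tower_devices and average_tower_losses are hypothetical helpers, the "/gpu:k" naming is an assumed TensorFlow-style device string, and the loss values are toy inputs.

```python
from typing import List, Sequence, Tuple


def tower_devices(nr_gpu: int) -> List[str]:
    # One TensorFlow-style device string per tower (assumed naming scheme).
    return ["/gpu:{}".format(k) for k in range(nr_gpu)]


def average_tower_losses(
    tower_losses: Sequence[Tuple[float, float]]
) -> Tuple[float, float]:
    """Average the (g_loss, d_loss) pairs collected from each tower."""
    if not tower_losses:
        raise ValueError("need at least one tower")
    nr_gpu = len(tower_losses)
    g_loss = sum(g for g, _ in tower_losses) / nr_gpu
    d_loss = sum(d for _, d in tower_losses) / nr_gpu
    return g_loss, d_loss


print(tower_devices(2))                                # ['/gpu:0', '/gpu:1']
print(average_tower_losses([(1.0, 0.5), (3.0, 1.5)]))  # (2.0, 1.0)
```

Averaging the losses (rather than summing them) keeps the loss scale, and therefore the effective learning rate, independent of nr_gpu.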

class tgan.trainer.SeparateGANTrainer(input, model, d_period=1, g_period=1)

Bases: tensorpack.train.tower.TowerTrainer

A GAN trainer that runs the discriminator and generator optimization ops (d_opt and g_opt) separately, at a ratio set by d_period and g_period.

Parameters
  • input (tensorpack.input_source.QueueInput) – Data input.

  • model (tgan.GAN.GANModelDesc) – Model to train.

  • d_period (int) – Period, in training steps, of each d_opt run.

  • g_period (int) – Period, in training steps, of each g_opt run.

run_step()

Define the training iteration.
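Reading d_period and g_period as "run this op once every N steps", the training iteration can be sketched in plain Python. This is an illustrative assumption, not tgan's code: ToySeparateGANTrainer, d_min and g_min are hypothetical callables, and the sketch simply checks the step counter against each period before running the corresponding op.

```python
from typing import Callable, List


class ToySeparateGANTrainer:
    """Hypothetical sketch: run d_min every d_period steps, g_min every g_period steps."""

    def __init__(self, d_min: Callable[[], None], g_min: Callable[[], None],
                 d_period: int = 1, g_period: int = 1):
        if d_period < 1 or g_period < 1:
            raise ValueError("periods must be positive integers")
        self.d_min = d_min
        self.g_min = g_min
        self.d_period = d_period
        self.g_period = g_period
        self.global_step = 0

    def run_step(self) -> None:
        # Each op runs only on steps that are a multiple of its period.
        if self.global_step % self.d_period == 0:
            self.d_min()
        if self.global_step % self.g_period == 0:
            self.g_min()
        self.global_step += 1


log: List[str] = []
trainer = ToySeparateGANTrainer(d_min=lambda: log.append("d"),
                                g_min=lambda: log.append("g"),
                                d_period=1, g_period=2)
for _ in range(4):
    trainer.run_step()
print(log)  # ['d', 'g', 'd', 'd', 'g', 'd']
```

With d_period=1 and g_period=2, the discriminator is updated on every step and the generator on every second step, so d is trained twice as often as g.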