Source code for tgan.trainer

"""GAN Models."""

import tensorflow as tf
from tensorpack import StagingInput, TowerTrainer
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper


class GANTrainer(TowerTrainer):
    """GANTrainer model.

    We need to set :meth:`tower_func` because it's a :class:`TowerTrainer`,
    and only :class:`TowerTrainer` supports automatic graph creation for
    inference during training.

    If we don't care about inference during training, using :meth:`tower_func`
    is not needed. Just calling :meth:`model.build_graph` directly is OK.

    Args:
        model(tgan.GAN.GANModelDesc): Model to train.
        input_queue(tensorpack.input_source.QueueInput): Data input.

    """

    def __init__(self, model, input_queue):
        """Initialize object."""
        super().__init__()
        inputs_desc = model.get_inputs_desc()

        # Setup input
        cbs = input_queue.setup(inputs_desc)
        self.register_callback(cbs)

        # Build the graph
        self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc)
        with TowerContext('', is_training=True):
            self.tower_func(*input_queue.get_input_tensors())

        opt = model.get_optimizer()

        # Define the training iteration: by default, run one d_min after one g_min
        with tf.name_scope('optimize'):
            # Clip gradients to [-5.0, 5.0] before applying them
            g_min_grad = opt.compute_gradients(model.g_loss, var_list=model.g_vars)
            g_min_grad_clip = [
                (tf.clip_by_value(grad, -5.0, 5.0), var)
                for grad, var in g_min_grad
            ]
            g_min_train_op = opt.apply_gradients(g_min_grad_clip, name='g_op')

            with tf.control_dependencies([g_min_train_op]):
                d_min_grad = opt.compute_gradients(model.d_loss, var_list=model.d_vars)
                d_min_grad_clip = [
                    (tf.clip_by_value(grad, -5.0, 5.0), var)
                    for grad, var in d_min_grad
                ]
                d_min_train_op = opt.apply_gradients(d_min_grad_clip, name='d_op')

        self.train_op = d_min_train_op
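
# Usage sketch (illustrative, not part of the original module): driving a
# GANTrainer end to end. ``MyGANModel`` is a hypothetical
# tgan.GAN.GANModelDesc subclass, and the FakeData shapes are placeholders
# that must match the model's input descriptions.
def _gan_trainer_example():
    from tensorpack import QueueInput
    from tensorpack.dataflow import FakeData

    model = MyGANModel()  # hypothetical GANModelDesc subclass
    # Random data whose shapes must match model.get_inputs_desc()
    dataflow = FakeData([[64, 100]], size=1000)
    trainer = GANTrainer(model, QueueInput(dataflow))
    # train_with_defaults is inherited from the tensorpack Trainer base class
    trainer.train_with_defaults(steps_per_epoch=100, max_epoch=5)
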
class SeparateGANTrainer(TowerTrainer):
    """A GAN trainer which runs two optimization ops with a certain ratio.

    Args:
        input(tensorpack.input_source.QueueInput): Data input.
        model(tgan.GAN.GANModelDesc): Model to train.
        d_period(int): Period of each d_opt run.
        g_period(int): Period of each g_opt run.

    """

    def __init__(self, input, model, d_period=1, g_period=1):
        """Initialize object."""
        super(SeparateGANTrainer, self).__init__()
        self._d_period = int(d_period)
        self._g_period = int(g_period)
        if min(d_period, g_period) != 1:
            raise ValueError('The minimum between d_period and g_period must be 1.')

        # Setup input
        cbs = input.setup(model.get_inputs_desc())
        self.register_callback(cbs)

        # Build the graph
        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
        with TowerContext('', is_training=True):
            self.tower_func(*input.get_input_tensors())

        opt = model.get_optimizer()
        with tf.name_scope('optimize'):
            self.d_min = opt.minimize(
                model.d_loss, var_list=model.d_vars, name='d_min')
            self.g_min = opt.minimize(
                model.g_loss, var_list=model.g_vars, name='g_min')
    def run_step(self):
        """Define the training iteration."""
        if self.global_step % self._d_period == 0:
            self.hooked_sess.run(self.d_min)
        if self.global_step % self._g_period == 0:
            self.hooked_sess.run(self.g_min)
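
# Usage sketch (illustrative, not part of the original module): with
# d_period=1 and g_period=3, run_step executes d_min on every global step but
# g_min only on steps divisible by 3, so the discriminator is updated three
# times as often as the generator. ``MyGANModel`` and the FakeData shapes are
# hypothetical placeholders.
def _separate_gan_trainer_example():
    from tensorpack import QueueInput
    from tensorpack.dataflow import FakeData

    model = MyGANModel()  # hypothetical GANModelDesc subclass
    dataflow = FakeData([[64, 100]], size=1000)
    trainer = SeparateGANTrainer(QueueInput(dataflow), model, d_period=1, g_period=3)
    trainer.train_with_defaults(steps_per_epoch=100, max_epoch=5)
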
class MultiGPUGANTrainer(TowerTrainer):
    """A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.

    Args:
        nr_gpu(int): Number of GPUs to train on.
        input(tensorpack.input_source.QueueInput): Data input.
        model(tgan.GAN.GANModelDesc): Model to train.

    """

    def __init__(self, nr_gpu, input, model):
        """Initialize object."""
        super(MultiGPUGANTrainer, self).__init__()
        if nr_gpu <= 1:
            raise ValueError('nr_gpu must be strictly greater than 1.')

        raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)]

        # Setup input
        input = StagingInput(input)
        cbs = input.setup(model.get_inputs_desc())
        self.register_callback(cbs)

        # Build the graph with multi-gpu replication
        def get_cost(*inputs):
            model.build_graph(*inputs)
            return [model.d_loss, model.g_loss]

        self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
        devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
        cost_list = DataParallelBuilder.build_on_towers(
            list(range(nr_gpu)),
            lambda: self.tower_func(*input.get_input_tensors()),
            devices)

        # Simply average the cost here. It might be faster to average the gradients.
        with tf.name_scope('optimize'):
            d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
            g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu)
            opt = model.get_optimizer()

            # Run one d_min after one g_min
            g_min = opt.minimize(
                g_loss, var_list=model.g_vars,
                colocate_gradients_with_ops=True, name='g_op')
            with tf.control_dependencies([g_min]):
                d_min = opt.minimize(
                    d_loss, var_list=model.d_vars,
                    colocate_gradients_with_ops=True, name='d_op')

        # Define the training iteration
        self.train_op = d_min
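
# Usage sketch (illustrative, not part of the original module): replicating
# the model over two GPUs. The trainer wraps the input in a StagingInput
# itself, so a plain QueueInput is passed in. ``MyGANModel`` and the FakeData
# shapes are hypothetical placeholders.
def _multi_gpu_gan_trainer_example():
    from tensorpack import QueueInput
    from tensorpack.dataflow import FakeData

    model = MyGANModel()  # hypothetical GANModelDesc subclass
    dataflow = FakeData([[64, 100]], size=1000)
    trainer = MultiGPUGANTrainer(nr_gpu=2, input=QueueInput(dataflow), model=model)
    trainer.train_with_defaults(steps_per_epoch=100, max_epoch=5)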