From 0909ded4c76379a60260689e34ef64835778fb60 Mon Sep 17 00:00:00 2001 From: Dushyant Rao Date: Mon, 8 Feb 2021 11:21:21 +0000 Subject: [PATCH] Fix bug for dynamic expansion in CURL. PiperOrigin-RevId: 356227182 --- curl/training.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/curl/training.py b/curl/training.py index 0cce039..8943bbf 100644 --- a/curl/training.py +++ b/curl/training.py @@ -591,7 +591,7 @@ def run_training( exp_wait_steps = 100 # Steps to wait after expansion before eligible again exp_burn_in = 100 # Steps to wait at start of learning before eligible exp_buffer_size = 100 # Size of the buffer of poorly explained data - num_buffer_train_steps = 20 # Num steps to train component on buffer + num_buffer_train_steps = 10 # Num steps to train component on buffer # Define a global tf variable for the number of active components. n_y_active_np = n_y_active @@ -749,6 +749,21 @@ def run_training( train_step = optimizer.minimize(train_ops.elbo) train_step_supervised = optimizer.minimize(train_ops.elbo_supervised) + # For dynamic expansion, we want to train only new-component-related params + cat_params = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, + 'cluster_encoder/mlp_cluster_encoder_final') + component_params = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, + 'latent_encoder/mlp_latent_encoder_*') + prior_params = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, + 'latent_decoder/latent_prior*') + + train_step_expansion = optimizer.minimize( + train_ops.elbo_supervised, + var_list=cat_params+component_params+prior_params) + # Set up ops for generative replay if gen_every_n > 0: # How many generative batches will we use each period? @@ -1078,11 +1093,10 @@ def run_training( for bs in range(n_poor_batches): x_batch = poor_data_buffer[bs * batch_size:(bs + 1) * batch_size] - label_batch = poor_data_labels[bs * batch_size:(bs + 1) * - batch_size] + label_batch = [new_cluster] * batch_size label_onehot_batch = np.eye(n_y)[label_batch] _ = sess.run( - train_step_supervised, + train_step_expansion, feed_dict={ x_train_raw: x_batch, model_train.y_label: label_onehot_batch