Mirror of https://github.com/google-deepmind/deepmind-research.git (synced 2026-05-09 20:51:03 +08:00)
Fix bug for dynamic expansion in CURL.
PiperOrigin-RevId: 356227182
Committed by: Diego de Las Casas
Parent: 91f4e2d2e6
Commit: 0909ded4c7
Changes: +18 -4
@@ -591,7 +591,7 @@ def run_training(
   exp_wait_steps = 100  # Steps to wait after expansion before eligible again
   exp_burn_in = 100  # Steps to wait at start of learning before eligible
   exp_buffer_size = 100  # Size of the buffer of poorly explained data
-  num_buffer_train_steps = 20  # Num steps to train component on buffer
+  num_buffer_train_steps = 10  # Num steps to train component on buffer
 
   # Define a global tf variable for the number of active components.
   n_y_active_np = n_y_active
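The constants in this hunk control when dynamic expansion may fire. Below is a rough sketch of the gating they imply, with assumed variable names (step, steps_since_expansion, poor_buffer); it is an illustration, not the exact check in training.py:

# Hypothetical helper illustrating the gating described by the comments above:
# wait out a burn-in period at the start of learning, wait out a cooldown after
# each expansion, and require the buffer of poorly explained data to be full.
def eligible_for_expansion(step, steps_since_expansion, poor_buffer,
                           exp_burn_in=100, exp_wait_steps=100,
                           exp_buffer_size=100):
  return (step > exp_burn_in and
          steps_since_expansion > exp_wait_steps and
          len(poor_buffer) >= exp_buffer_size)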
@@ -749,6 +749,21 @@ def run_training(
   train_step = optimizer.minimize(train_ops.elbo)
   train_step_supervised = optimizer.minimize(train_ops.elbo_supervised)
 
+  # For dynamic expansion, we want to train only new-component-related params
+  cat_params = tf.get_collection(
+      tf.GraphKeys.TRAINABLE_VARIABLES,
+      'cluster_encoder/mlp_cluster_encoder_final')
+  component_params = tf.get_collection(
+      tf.GraphKeys.TRAINABLE_VARIABLES,
+      'latent_encoder/mlp_latent_encoder_*')
+  prior_params = tf.get_collection(
+      tf.GraphKeys.TRAINABLE_VARIABLES,
+      'latent_decoder/latent_prior*')
+
+  train_step_expansion = optimizer.minimize(
+      train_ops.elbo_supervised,
+      var_list=cat_params+component_params+prior_params)
+
   # Set up ops for generative replay
   if gen_every_n > 0:
     # How many generative batches will we use each period?
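The new train_step_expansion op limits gradient updates to the parameters relevant to the freshly added component: tf.get_collection filters TRAINABLE_VARIABLES by a scope regex, and passing the result as var_list to optimizer.minimize leaves all other weights untouched. A minimal standalone sketch of that pattern follows (toy graph and loss, scope names borrowed from the diff; this is not the CURL model itself):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x = tf.placeholder(tf.float32, [None, 4])
with tf.variable_scope('cluster_encoder'):
  w_cat = tf.get_variable('mlp_cluster_encoder_final', [4, 3])
with tf.variable_scope('latent_decoder'):
  w_prior = tf.get_variable('latent_prior_mu', [3, 2])
loss = tf.reduce_sum(tf.matmul(tf.matmul(x, w_cat), w_prior) ** 2)  # toy loss

# The scope argument of tf.get_collection is a regex over variable names, so
# only parameters under these scopes receive gradients from this train op.
cat_params = tf.get_collection(
    tf.GraphKeys.TRAINABLE_VARIABLES, 'cluster_encoder/mlp_cluster_encoder_final')
prior_params = tf.get_collection(
    tf.GraphKeys.TRAINABLE_VARIABLES, 'latent_decoder/latent_prior*')
train_step_expansion = tf.train.AdamOptimizer(1e-3).minimize(
    loss, var_list=cat_params + prior_params)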
@@ -1078,11 +1093,10 @@ def run_training(
         for bs in range(n_poor_batches):
           x_batch = poor_data_buffer[bs * batch_size:(bs + 1) *
                                      batch_size]
-          label_batch = poor_data_labels[bs * batch_size:(bs + 1) *
-                                         batch_size]
+          label_batch = [new_cluster] * batch_size
           label_onehot_batch = np.eye(n_y)[label_batch]
           _ = sess.run(
-              train_step_supervised,
+              train_step_expansion,
               feed_dict={
                   x_train_raw: x_batch,
                   model_train.y_label: label_onehot_batch
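The actual bug fix sits in this hunk: every example drained from the poorly-explained buffer is now labelled with the newly activated component (new_cluster) and trained with train_step_expansion, instead of reusing indices sliced from poor_data_labels together with the general supervised step. A small NumPy-only sketch of the corrected labelling (array shapes and the n_y / new_cluster values are assumed for illustration):

import numpy as np

n_y = 10              # total number of mixture components (assumed)
new_cluster = 4       # index of the newly activated component (assumed)
batch_size = 8
poor_data_buffer = np.random.rand(100, 32)  # stand-in for the buffered inputs

n_poor_batches = len(poor_data_buffer) // batch_size
for bs in range(n_poor_batches):
  x_batch = poor_data_buffer[bs * batch_size:(bs + 1) * batch_size]
  # The fix: assign the whole batch to the new component rather than slicing
  # stale entries out of poor_data_labels.
  label_batch = [new_cluster] * batch_size
  label_onehot_batch = np.eye(n_y)[label_batch]  # shape (batch_size, n_y)
  # In training.py, x_batch and label_onehot_batch feed train_step_expansion
  # through sess.run(...).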