Fix silently broken use of tf.config.experimental.set_visible_devices.

The function expects device strings to be upper-case ('GPU' instead of 'gpu', etc) and silently fails when used with an empty device list and a lower-case device identifier.

PiperOrigin-RevId: 382547669
This commit is contained in:
Georg Ostrovski
2021-07-01 17:04:08 +00:00
committed by Louise Deason
parent ce87dbef98
commit db1d396077
2 changed files with 4 additions and 4 deletions
+2 -2
View File
@@ -154,8 +154,8 @@ class Experiment(experiment.AbstractExperiment):
):
"""Initializes experiment."""
super(Experiment, self).__init__(mode=mode, init_rng=init_rng)
tf.config.experimental.set_visible_devices([], device_type='gpu')
tf.config.experimental.set_visible_devices([], device_type='tpu')
tf.config.experimental.set_visible_devices([], device_type='GPU')
tf.config.experimental.set_visible_devices([], device_type='TPU')
if mode not in ('train', 'eval', 'train_eval_multithreaded'):
raise ValueError(f'Invalid mode {mode}.')
+2 -2
View File
@@ -90,8 +90,8 @@ class Experiment(experiment.AbstractExperiment):
raise ValueError(f'Invalid mode {mode}.')
# Do not use accelerators in data pipeline.
tf.config.experimental.set_visible_devices([], device_type='gpu')
tf.config.experimental.set_visible_devices([], device_type='tpu')
tf.config.experimental.set_visible_devices([], device_type='GPU')
tf.config.experimental.set_visible_devices([], device_type='TPU')
self.mode = mode
self.init_rng = init_rng