mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-29 19:55:25 +08:00
Enable Travis tests for geomancer
* Added a flag to disable plotting when running headless on Travis PiperOrigin-RevId: 368229292
This commit is contained in:
committed by
Diego de Las Casas
parent
d4a9a684cb
commit
20e06f144a
@@ -12,6 +12,7 @@ env:
|
|||||||
- PROJECT="avae"
|
- PROJECT="avae"
|
||||||
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
|
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
|
||||||
- PROJECT="gated_linear_networks"
|
- PROJECT="gated_linear_networks"
|
||||||
|
- PROJECT="geomancer"
|
||||||
- PROJECT="iodine"
|
- PROJECT="iodine"
|
||||||
- PROJECT="kfac_ferminet_alpha"
|
- PROJECT="kfac_ferminet_alpha"
|
||||||
- PROJECT="learning_to_simulate"
|
- PROJECT="learning_to_simulate"
|
||||||
|
|||||||
Regular → Executable
+6
-6
@@ -13,9 +13,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
python3 -m venv geomancer-venv
|
python3 -m venv /tmp/geomancer-venv
|
||||||
source geomancer-venv/bin/activate
|
source /tmp/geomancer-venv/bin/activate
|
||||||
pip3 install .
|
pip3 install -U pip
|
||||||
python3 geomancer_test.py
|
pip3 install geomancer/
|
||||||
python3 train.py
|
python3 -m geomancer.geomancer_test
|
||||||
deactivate
|
python3 geomancer/train.py --plot=False
|
||||||
|
|||||||
+17
-11
@@ -28,11 +28,14 @@ import numpy as np
|
|||||||
from scipy.stats import special_ortho_group
|
from scipy.stats import special_ortho_group
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
flags.DEFINE_list('specification', ['S^2', 'S^2'], 'List of submanifolds')
|
SPECIFICATION = flags.DEFINE_list(
|
||||||
flags.DEFINE_integer('npts', 1000, 'Number of data points')
|
name='specification', default=['S^2', 'S^2'], help='List of submanifolds')
|
||||||
flags.DEFINE_boolean('rotate', False, 'Apply random rotation to the data')
|
NPTS = flags.DEFINE_integer(
|
||||||
|
name='npts', default=1000, help='Number of data points')
|
||||||
FLAGS = flags.FLAGS
|
ROTATE = flags.DEFINE_boolean(
|
||||||
|
name='rotate', default=False, help='Apply random rotation to the data')
|
||||||
|
PLOT = flags.DEFINE_boolean(
|
||||||
|
name='plot', default=True, help='Whether to enable plotting')
|
||||||
|
|
||||||
|
|
||||||
def make_so_tangent(q):
|
def make_so_tangent(q):
|
||||||
@@ -139,8 +142,8 @@ def make_product_manifold(specification, npts):
|
|||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
# Generate data and run GEOMANCER
|
# Generate data and run GEOMANCER
|
||||||
data, dim, tangents = make_product_manifold(FLAGS.specification, FLAGS.npts)
|
data, dim, tangents = make_product_manifold(SPECIFICATION.value, NPTS.value)
|
||||||
if FLAGS.rotate:
|
if ROTATE.value:
|
||||||
rot, _ = np.linalg.qr(np.random.randn(data.shape[1], data.shape[1]))
|
rot, _ = np.linalg.qr(np.random.randn(data.shape[1], data.shape[1]))
|
||||||
data_rot = data @ rot.T
|
data_rot = data @ rot.T
|
||||||
components, spectrum = geomancer.fit(data_rot, dim)
|
components, spectrum = geomancer.fit(data_rot, dim)
|
||||||
@@ -149,6 +152,12 @@ def main(_):
|
|||||||
components, spectrum = geomancer.fit(data, dim)
|
components, spectrum = geomancer.fit(data, dim)
|
||||||
errors = geomancer.eval_aligned(components, tangents)
|
errors = geomancer.eval_aligned(components, tangents)
|
||||||
|
|
||||||
|
logging.info('Error between subspaces: %.2f +/- %.2f radians',
|
||||||
|
np.mean(errors),
|
||||||
|
np.std(errors))
|
||||||
|
|
||||||
|
if PLOT.value:
|
||||||
|
|
||||||
# Plot spectrum
|
# Plot spectrum
|
||||||
plt.figure(figsize=(8, 6))
|
plt.figure(figsize=(8, 6))
|
||||||
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
|
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
|
||||||
@@ -177,7 +186,7 @@ def main(_):
|
|||||||
fig = plt.figure(figsize=(8, 6))
|
fig = plt.figure(figsize=(8, 6))
|
||||||
gs = gridspec.GridSpec(1, len(tangents),
|
gs = gridspec.GridSpec(1, len(tangents),
|
||||||
width_ratios=[b.shape[2] for b in tangents])
|
width_ratios=[b.shape[2] for b in tangents])
|
||||||
for i, spec in enumerate(FLAGS.specification):
|
for i, spec in enumerate(SPECIFICATION.value):
|
||||||
ax = plt.subplot(gs[i])
|
ax = plt.subplot(gs[i])
|
||||||
ax.imshow(tangents[i][0])
|
ax.imshow(tangents[i][0])
|
||||||
ax.set_xticks([])
|
ax.set_xticks([])
|
||||||
@@ -185,9 +194,6 @@ def main(_):
|
|||||||
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
|
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
|
||||||
fig.canvas.set_window_title('Ground Truth')
|
fig.canvas.set_window_title('Ground Truth')
|
||||||
|
|
||||||
logging.info('Error between subspaces: %.2f +/- %.2f radians',
|
|
||||||
np.mean(errors),
|
|
||||||
np.std(errors))
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user