mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-28 10:45:46 +08:00
Public release of Deep Compressed Sensing project.
PiperOrigin-RevId: 272403580
This commit is contained in:
committed by
Diego de Las Casas
parent
f119aac804
commit
fa33baacca
+106
@@ -0,0 +1,106 @@
|
||||
# Copyright 2019 DeepMind Technologies Limited and Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Network utilities."""
|
||||
|
||||
import functools
|
||||
import re
|
||||
import numpy as np
|
||||
import sonnet as snt
|
||||
import tensorflow as tf
|
||||
import tensorflow_gan as tfgan
|
||||
|
||||
|
||||
def _sn_custom_getter():
|
||||
def name_filter(name):
|
||||
match = re.match(r'.*w(_.*)?$', name)
|
||||
return match is not None
|
||||
|
||||
return tfgan.features.spectral_normalization_custom_getter(
|
||||
name_filter=name_filter)
|
||||
|
||||
|
||||
class SNGenNet(snt.AbstractModule):
|
||||
"""As in the SN paper."""
|
||||
|
||||
def __init__(self, name='conv_gen'):
|
||||
super(SNGenNet, self).__init__(name=name)
|
||||
|
||||
def _build(self, inputs, is_training):
|
||||
batch_size = inputs.get_shape().as_list()[0]
|
||||
first_shape = [4, 4, 512]
|
||||
norm_ctor = snt.BatchNormV2
|
||||
norm_ctor_config = {'scale': True}
|
||||
up_tensor = snt.Linear(np.prod(first_shape))(inputs)
|
||||
first_tensor = tf.reshape(up_tensor, shape=[batch_size] + first_shape)
|
||||
|
||||
net = snt.nets.ConvNet2DTranspose(
|
||||
output_channels=[256, 128, 64, 3],
|
||||
output_shapes=[(8, 8), (16, 16), (32, 32), (32, 32)],
|
||||
kernel_shapes=[(4, 4), (4, 4), (4, 4), (3, 3)],
|
||||
strides=[2, 2, 2, 1],
|
||||
normalization_ctor=norm_ctor,
|
||||
normalization_kwargs=norm_ctor_config,
|
||||
normalize_final=False,
|
||||
paddings=[snt.SAME], activate_final=False, activation=tf.nn.relu)
|
||||
output = net(first_tensor, is_training=is_training)
|
||||
return tf.nn.tanh(output)
|
||||
|
||||
|
||||
class SNMetricNet(snt.AbstractModule):
|
||||
"""Spectral normalization discriminator (metric) architecture."""
|
||||
|
||||
def __init__(self, num_outputs=2, name='sn_metric'):
|
||||
super(SNMetricNet, self).__init__(name=name)
|
||||
self._num_outputs = num_outputs
|
||||
|
||||
def _build(self, inputs):
|
||||
with tf.variable_scope('', custom_getter=_sn_custom_getter()):
|
||||
net = snt.nets.ConvNet2D(
|
||||
output_channels=[64, 64, 128, 128, 256, 256, 512],
|
||||
kernel_shapes=[
|
||||
(3, 3), (4, 4), (3, 3), (4, 4), (3, 3), (4, 4), (3, 3)],
|
||||
strides=[1, 2, 1, 2, 1, 2, 1],
|
||||
paddings=[snt.SAME], activate_final=True,
|
||||
activation=functools.partial(tf.nn.leaky_relu, alpha=0.1))
|
||||
linear = snt.Linear(self._num_outputs)
|
||||
output = linear(snt.BatchFlatten()(net(inputs)))
|
||||
return output
|
||||
|
||||
|
||||
class MLPGeneratorNet(snt.AbstractModule):
|
||||
"""MNIST generator net."""
|
||||
|
||||
def __init__(self, name='mlp_generator'):
|
||||
super(MLPGeneratorNet, self).__init__(name=name)
|
||||
|
||||
def _build(self, inputs, is_training=True):
|
||||
del is_training
|
||||
net = snt.nets.MLP([500, 500, 784], activation=tf.nn.leaky_relu)
|
||||
out = net(inputs)
|
||||
out = tf.nn.tanh(out)
|
||||
return snt.BatchReshape([28, 28, 1])(out)
|
||||
|
||||
|
||||
class MLPMetricNet(snt.AbstractModule):
|
||||
"""Same as in Grover and Ermon, ICLR workshop 2017."""
|
||||
|
||||
def __init__(self, num_outputs=2, name='mlp_metric'):
|
||||
super(MLPMetricNet, self).__init__(name=name)
|
||||
self._layer_size = [500, 500, num_outputs]
|
||||
|
||||
def _build(self, inputs):
|
||||
net = snt.nets.MLP(self._layer_size,
|
||||
activation=tf.nn.leaky_relu)
|
||||
output = net(snt.BatchFlatten()(inputs))
|
||||
return output
|
||||
Reference in New Issue
Block a user