Files
deepmind-research/iodine/modules/decoder.py
Yilei Yang 152ac280f2 Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 440896914
2022-05-26 17:43:58 +01:00

49 lines
1.6 KiB
Python

# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoders for rendering images."""
# pylint: disable=missing-docstring
from iodine.modules.distributions import MixtureParameters
import shapeguard
import sonnet as snt
class ComponentDecoder(snt.AbstractModule):
    """Decode per-component latents into pixel and mask mixture parameters.

    The wrapped ``pixel_decoder`` is applied to all components at once by
    folding the component axis into the batch axis; its output is then split
    along the last axis into one mask channel and ``Cp`` pixel channels.
    """

    def __init__(self, pixel_decoder, name="component_decoder"):
        """Store the wrapped decoder and create a fresh shape guard.

        Args:
          pixel_decoder: Callable used in ``_build``. It must expose
            ``set_output_shapes`` and return an object with a ``params``
            attribute.
          name: Module name forwarded to the Sonnet base class.
        """
        super().__init__(name=name)
        self._pixel_decoder = pixel_decoder
        self._sg = shapeguard.ShapeGuard()

    def set_output_shapes(self, pixel, mask):
        """Record the expected output shapes and configure the pixel decoder.

        Args:
          pixel: Shape checked against the template ``"K, H, W, Cp"``.
          mask: Shape checked against the template ``"K, H, W, 1"``.
        """
        guard = self._sg
        guard.guard(pixel, "K, H, W, Cp")
        guard.guard(mask, "K, H, W, 1")
        # The pixel decoder emits mask and pixel channels together, hence
        # the combined last dimension of 1 + Cp.
        per_component_shape = guard["H, W, 1 + Cp"]
        self._pixel_decoder.set_output_shapes(per_component_shape)

    def _build(self, z):
        """Decode latents ``z`` of shape ``"B, K, Z"``.

        Args:
          z: Latent tensor checked against the template ``"B, K, Z"``.

        Returns:
          A ``MixtureParameters`` whose ``pixel`` field is reshaped to
          ``"B, K, H, W, Cp"`` and whose ``mask`` field is reshaped to
          ``"B, K, H, W, 1"``.
        """
        guard = self._sg
        guard.guard(z, "B, K, Z")

        # Fold components into the batch axis so the pixel decoder sees a
        # plain batch of latent vectors.
        flat_latents = guard.reshape(z, "B*K, Z")
        decoded = self._pixel_decoder(flat_latents).params
        guard.guard(decoded, "B*K, H, W, 1 + Cp")

        # Channel 0 carries the mask parameters; the remaining channels
        # carry the pixel parameters.
        mask_part = decoded[..., 0:1]
        pixel_part = decoded[..., 1:]

        pixel_out = guard.reshape(pixel_part, "B, K, H, W, Cp")
        mask_out = guard.reshape(mask_part, "B, K, H, W, 1")
        result = MixtureParameters(pixel=pixel_out, mask=mask_out)

        # Drop the batch dimension from the guard before returning;
        # presumably so a later call may use a different batch size --
        # TODO confirm against shapeguard semantics.
        del guard.B
        return result