Files
deepmind-research/byol/utils/augmentations.py
T
Florent Altché 8457046b2c Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
2020-08-26 16:54:56 +01:00

458 lines
15 KiB
Python

# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data preprocessing and augmentation."""
import functools
from typing import Any, Mapping, Text
import jax
import jax.numpy as jnp
# Type aliases shared by the annotations in this module.
JaxBatch = Mapping[Text, jnp.ndarray]
ConfigDict = Mapping[Text, Any]

# Augmentation presets for the two views. Each nested dict is unpacked as
# keyword arguments into the transform that carries the same name.
augment_config = {
    'view1': {
        'random_flip': True,  # Random left/right flip
        'color_transform': {
            'apply_prob': 1.0,
            # Range of jittering
            'brightness': 0.4,
            'contrast': 0.4,
            'saturation': 0.2,
            'hue': 0.1,
            # Probability of applying color jittering
            'color_jitter_prob': 0.8,
            # Probability of converting to grayscale
            'to_grayscale_prob': 0.2,
            # Shuffle the order of color transforms
            'shuffle': True,
        },
        'gaussian_blur': {
            'apply_prob': 1.0,
            # Kernel size ~ image_size / blur_divider
            'blur_divider': 10.,
            # Kernel distribution
            'sigma_min': 0.1,
            'sigma_max': 2.0,
        },
        'solarize': {'apply_prob': 0.0, 'threshold': 0.5},
    },
    'view2': {
        'random_flip': True,
        'color_transform': {
            'apply_prob': 1.0,
            'brightness': 0.4,
            'contrast': 0.4,
            'saturation': 0.2,
            'hue': 0.1,
            'color_jitter_prob': 0.8,
            'to_grayscale_prob': 0.2,
            'shuffle': True,
        },
        'gaussian_blur': {
            'apply_prob': 0.1,
            'blur_divider': 10.,
            'sigma_min': 0.1,
            'sigma_max': 2.0,
        },
        'solarize': {'apply_prob': 0.2, 'threshold': 0.5},
    },
}
def postprocess(inputs: JaxBatch, rng: jnp.ndarray) -> JaxBatch:
  """Apply the image augmentations to crops in inputs (view1 and view2).

  Args:
    inputs: a batch holding the `view1` and `view2` image tensors and a
      `labels` entry.
    rng: a single PRNGKey; it is split into one independent key per view.

  Returns:
    A dict with the augmented `view1` and `view2` images, plus the `labels`
    entry of `inputs` passed through unchanged.
  """

  def _postprocess_image(
      images: jnp.ndarray,
      rng: jnp.ndarray,
      presets: ConfigDict,
  ) -> jnp.ndarray:
    """Applies the augmentations selected by `presets` to one view.

    Args:
      images: an NHWC tensor (with C=3), with float values in [0, 1].
      rng: a single PRNGKey.
      presets: a dict of presets for the augmentations.

    Returns:
      The augmented images, clipped to [0, 1] and with gradients stopped.
    """
    flip_rng, color_rng, blur_rng, solarize_rng = jax.random.split(rng, 4)
    out = images
    if presets['random_flip']:
      out = random_flip(out, flip_rng)
    # A transform whose apply_prob is not positive is skipped altogether
    # rather than traced and never selected.
    if presets['color_transform']['apply_prob'] > 0:
      out = color_transform(out, color_rng, **presets['color_transform'])
    if presets['gaussian_blur']['apply_prob'] > 0:
      out = gaussian_blur(out, blur_rng, **presets['gaussian_blur'])
    if presets['solarize']['apply_prob'] > 0:
      out = solarize(out, solarize_rng, **presets['solarize'])
    out = jnp.clip(out, 0., 1.)
    return jax.lax.stop_gradient(out)

  rng1, rng2 = jax.random.split(rng, num=2)
  view1 = _postprocess_image(inputs['view1'], rng1, augment_config['view1'])
  view2 = _postprocess_image(inputs['view2'], rng2, augment_config['view2'])
  return dict(view1=view1, view2=view2, labels=inputs['labels'])
def _maybe_apply(apply_fn, inputs, rng, apply_prob):
  """Applies `apply_fn` to `inputs` with probability `apply_prob`.

  Args:
    apply_fn: the transform to apply. As `jax.lax.cond` requires both branches
      to agree, it must return a value with the same structure, shape and
      dtype as `inputs`.
    inputs: the value to transform.
    rng: a single PRNGKey used to draw the apply/skip decision.
    apply_prob: the probability of applying `apply_fn`.

  Returns:
    `apply_fn(inputs)` when the uniform draw is `<= apply_prob`, otherwise
    `inputs` unchanged.
  """
  should_apply = jax.random.uniform(rng, shape=()) <= apply_prob
  # Documented signature: cond(pred, true_fun, false_fun, *operands). The
  # former (pred, operand, true_fun, operand, false_fun) form is legacy.
  return jax.lax.cond(should_apply, apply_fn, lambda x: x, inputs)
def _depthwise_conv2d(inputs, kernel, strides, padding):
  """Computes a depthwise conv2d in Jax.

  Every input channel is convolved with its own filter: the number of feature
  groups is set to the number of input channels.

  Args:
    inputs: an NHWC tensor.
    kernel: a [H", W", 1, C] tensor (HWIO layout).
    strides: the window strides, one value per spatial dimension.
    padding: "SAME" or "VALID".

  Returns:
    The depthwise convolution of inputs with kernel, in NHWC layout.
  """
  num_channels = inputs.shape[-1]
  return jax.lax.conv_general_dilated(
      lhs=inputs,
      rhs=kernel,
      window_strides=strides,
      padding=padding,
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
      feature_group_count=num_channels)
def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
  """Applies a separable gaussian blur to an image.

  Args:
    image: an HWC tensor, or an NHWC tensor.
    kernel_size: the requested kernel size. The kernel actually used has the
      odd size `2 * int(kernel_size / 2) + 1`.
    padding: "SAME" or "VALID".
    sigma: the standard deviation of the gaussian kernel.

  Returns:
    The blurred image, with the same rank as `image`: a batch axis is only
    removed from the result if it was added here.
  """
  radius = int(kernel_size / 2)
  kernel_size_ = 2 * radius + 1
  x = jnp.arange(-radius, radius + 1).astype(jnp.float32)
  blur_filter = jnp.exp(-x**2 / (2. * sigma**2))
  blur_filter = blur_filter / jnp.sum(blur_filter)  # Normalize to sum to 1.
  # The 2D gaussian is separable: blur with a row filter, then a column one.
  blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1])
  blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1])
  num_channels = image.shape[-1]
  blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels])
  blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
  expand_batch_dim = len(image.shape) == 3
  if expand_batch_dim:
    image = image[jnp.newaxis, ...]
  blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
  blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
  # Only undo the batch axis we added ourselves; squeezing unconditionally
  # would change the rank of (or fail on) inputs that already had one.
  if expand_batch_dim:
    blurred = jnp.squeeze(blurred, axis=0)
  return blurred
def _random_gaussian_blur(image, rng, kernel_size, padding, sigma_min,
                          sigma_max, apply_prob):
  """Blurs `image` with probability `apply_prob`, using a random sigma.

  Args:
    image: the image to blur.
    rng: a single PRNGKey.
    kernel_size: the size of the blurring kernel.
    padding: the padding mode forwarded to the convolution.
    sigma_min: the lower bound of the uniform range sigma is drawn from.
    sigma_max: the upper bound of the uniform range sigma is drawn from.
    apply_prob: the probability of applying the blur.

  Returns:
    The blurred image, or `image` unchanged when the blur is not applied.
  """
  gate_rng, draw_rng = jax.random.split(rng)

  def _blur_with_sampled_sigma(img):
    # A one-way split is kept so the derived key (and hence the sampled
    # sigma) stays the same for a given `rng`.
    sigma_key = jax.random.split(draw_rng, 1)[0]
    sampled_sigma = jax.random.uniform(
        sigma_key,
        shape=(),
        minval=sigma_min,
        maxval=sigma_max,
        dtype=jnp.float32)
    return _gaussian_blur_single_image(img, kernel_size, padding,
                                       sampled_sigma)

  return _maybe_apply(_blur_with_sampled_sigma, image, gate_rng, apply_prob)
def rgb_to_hsv(r, g, b):
  """Converts R, G, B values to H, S, V values.

  Reference TF implementation:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
  Only input values between 0 and 1 are guaranteed to work properly, but this
  function complies with the TF implementation outside of this range.

  Args:
    r: A tensor representing the red color component as floats.
    g: A tensor representing the green color component as floats.
    b: A tensor representing the blue color component as floats.

  Returns:
    An (H, S, V) tuple, computed elementwise from the inputs.
  """
  value = jnp.maximum(jnp.maximum(r, g), b)
  chroma = value - jnp.minimum(jnp.minimum(r, g), b)
  saturation = jnp.where(value > 0, chroma / value, 0.)
  # 1e9 only stands in where chroma is zero; the hue of those pixels is
  # zeroed out below.
  scale = jnp.where(chroma != 0, 1. / (6. * chroma), 1e9)
  hue_if_red_max = scale * (g - b)
  hue_if_green_max = scale * (b - r) + 2. / 6.
  hue_if_blue_max = scale * (r - g) + 4. / 6.
  hue_if_not_red = jnp.where(g == value, hue_if_green_max, hue_if_blue_max)
  hue = jnp.where(r == value, hue_if_red_max, hue_if_not_red)
  hue = hue * (chroma > 0)
  # Negative hues wrap around by one full turn.
  hue = hue + (hue < 0)
  return hue, saturation, value
def hsv_to_rgb(h, s, v):
  """Converts H, S, V values to an R, G, B tuple.

  Reference TF implementation:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
  Only input values between 0 and 1 are guaranteed to work properly, but this
  function complies with the TF implementation outside of this range.

  Args:
    h: A float tensor of arbitrary shape for the hue (0-1 values).
    s: A float tensor of the same shape for the saturation (0-1 values).
    v: A float tensor of the same shape for the value channel (0-1 values).

  Returns:
    An (r, g, b) tuple, each with the same dimension as the inputs.
  """
  chroma = s * v
  offset = v - chroma
  scaled_hue = (h % 1.) * 6.
  hue_mod_two = scaled_hue % 2.
  secondary = chroma * (1 - jnp.abs(hue_mod_two - 1))
  # Index (0..5) of the sixth of the hue circle each value falls in.
  sector = jnp.floor(scaled_hue).astype(jnp.int32)

  def _channel(full_sectors, partial_sectors):
    """Builds one channel from the sectors where it is full or partial."""
    full_a, full_b = full_sectors
    partial_a, partial_b = partial_sectors
    is_full = (sector == full_a) | (sector == full_b)
    is_partial = (sector == partial_a) | (sector == partial_b)
    return jnp.where(
        is_full, chroma, jnp.where(is_partial, secondary, 0)) + offset

  red = _channel((0, 5), (1, 4))
  green = _channel((1, 2), (0, 3))
  blue = _channel((3, 4), (2, 5))
  return red, green, blue
def adjust_brightness(rgb_tuple, delta):
  """Adds `delta` to every channel of `rgb_tuple`.

  Args:
    rgb_tuple: a pytree of channel tensors (an (r, g, b) tuple in this file).
    delta: the brightness offset added to every value.

  Returns:
    A pytree with the same structure as `rgb_tuple`. Values are not clipped.
  """
  # `jax.tree_map` was deprecated and then removed from JAX;
  # `jax.tree_util.tree_map` is the equivalent available in every version.
  return jax.tree_util.tree_map(lambda x: x + delta, rgb_tuple)
def adjust_contrast(image, factor):
  """Scales each channel's deviation from its own mean by `factor`.

  Args:
    image: a pytree of channel tensors (an (r, g, b) tuple in this file); the
      mean of each channel is taken over its last two axes.
    factor: the contrast multiplier; 1 leaves the channels unchanged.

  Returns:
    A pytree with the same structure as `image`. Values are not clipped.
  """

  def _adjust_contrast_channel(channel):
    mean = jnp.mean(channel, axis=(-2, -1), keepdims=True)
    return factor * (channel - mean) + mean

  # `jax.tree_map` was deprecated and then removed from JAX;
  # `jax.tree_util.tree_map` is the equivalent available in every version.
  return jax.tree_util.tree_map(_adjust_contrast_channel, image)
def adjust_saturation(h, s, v, factor):
  """Multiplies the saturation by `factor`, clipping the result to [0, 1].

  Args:
    h: the hue, returned unchanged.
    s: the saturation.
    v: the value channel, returned unchanged.
    factor: the saturation multiplier.

  Returns:
    An (h, s, v) tuple with the rescaled saturation.
  """
  scaled_saturation = jnp.clip(s * factor, 0., 1.)
  return h, scaled_saturation, v
def adjust_hue(h, s, v, delta):
  """Shifts the hue by `delta`, wrapping the result into [0, 1).

  Args:
    h: the hue.
    s: the saturation, returned unchanged.
    v: the value channel, returned unchanged.
    delta: the offset added to the hue.

  Returns:
    An (h, s, v) tuple with the shifted hue.
  """
  # Note: this method exactly matches TF's adjust_hue (combined with the hsv/rgb
  # conversions) when running on GPU. When running on CPU, the results will be
  # different if all RGB values for a pixel are outside of the [0, 1] range.
  shifted_hue = (h + delta) % 1.0
  return shifted_hue, s, v
def _random_brightness(rgb_tuple, rng, max_delta):
  """Shifts the brightness by an offset drawn from [-max_delta, max_delta)."""
  sampled_offset = jax.random.uniform(
      rng, shape=(), minval=-max_delta, maxval=max_delta)
  return adjust_brightness(rgb_tuple, sampled_offset)
def _random_contrast(rgb_tuple, rng, max_delta):
  """Rescales the contrast by a factor drawn from 1 +/- max_delta."""
  lower, upper = 1 - max_delta, 1 + max_delta
  sampled_factor = jax.random.uniform(
      rng, shape=(), minval=lower, maxval=upper)
  return adjust_contrast(rgb_tuple, sampled_factor)
def _random_saturation(rgb_tuple, rng, max_delta):
  """Rescales the saturation by a factor drawn from 1 +/- max_delta.

  The channels are converted to HSV, adjusted there, and converted back.
  """
  hue, sat, val = rgb_to_hsv(*rgb_tuple)
  lower, upper = 1 - max_delta, 1 + max_delta
  sampled_factor = jax.random.uniform(
      rng, shape=(), minval=lower, maxval=upper)
  adjusted_hsv = adjust_saturation(hue, sat, val, sampled_factor)
  return hsv_to_rgb(*adjusted_hsv)
def _random_hue(rgb_tuple, rng, max_delta):
  """Shifts the hue by an offset drawn from [-max_delta, max_delta).

  The channels are converted to HSV, adjusted there, and converted back.
  """
  hue, sat, val = rgb_to_hsv(*rgb_tuple)
  sampled_offset = jax.random.uniform(
      rng, shape=(), minval=-max_delta, maxval=max_delta)
  adjusted_hsv = adjust_hue(hue, sat, val, sampled_offset)
  return hsv_to_rgb(*adjusted_hsv)
def _to_grayscale(image):
  """Converts an RGB image to grayscale, keeping three identical channels.

  Args:
    image: a tensor whose last axis holds the R, G, B channels.

  Returns:
    The weighted sum of the channels, repeated 3 times along the last axis.
  """
  channel_weights = jnp.array([0.2989, 0.5870, 0.1140])
  luminance = jnp.tensordot(image, channel_weights, axes=(-1, -1))
  single_channel = luminance[..., jnp.newaxis]
  return jnp.tile(single_channel, (1, 1, 3))  # Back to 3 channels.
def _color_transform_single_image(image, rng, brightness, contrast, saturation,
                                  hue, to_grayscale_prob, color_jitter_prob,
                                  apply_prob, shuffle):
  """Applies color jittering and/or grayscaling to a single image.

  Args:
    image: an HWC tensor with C=3.
    rng: a single PRNGKey.
    brightness: the range of jitter on brightness; skipped if not positive.
    contrast: the range of jitter on contrast; skipped if not positive.
    saturation: the range of jitter on saturation; skipped if not positive.
    hue: the range of jitter on hue; skipped if not positive.
    to_grayscale_prob: the probability of converting the image to grayscale.
    color_jitter_prob: the probability of applying color jittering.
    apply_prob: the probability of applying the transform at all.
    shuffle: whether to apply the jitter transforms in a random order.

  Returns:
    The transformed image, clipped to [0, 1].
  """
  apply_rng, transform_rng = jax.random.split(rng)
  perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
      transform_rng, 7)

  # Whether the transform should be applied at all.
  should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob
  # Whether to apply grayscale transform.
  should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob
  # Whether to apply color jittering.
  should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob

  # Decorator to conditionally apply fn based on an index.
  def _make_cond(fn, idx):

    def cond_fn(args, i):
      # `args` is (rgb_tuple, rng, param); `i` is the transform selected for
      # the current step, so `fn` only runs on the step where `i == idx`.

      def _apply_and_clip(a):
        # `jax.tree_map` was removed from JAX; `jax.tree_util.tree_map` is
        # the equivalent available in every version.
        return jax.tree_util.tree_map(
            lambda arg: jnp.clip(arg, 0., 1.), fn(*a))

      def _keep_channels(a):
        return a[0]

      # Documented signature: cond(pred, true_fun, false_fun, *operands).
      out = jax.lax.cond(should_apply & should_apply_color & (i == idx),
                         _apply_and_clip, _keep_channels, args)
      return jax.lax.stop_gradient(out)

    return cond_fn

  random_brightness_cond = _make_cond(_random_brightness, idx=0)
  random_contrast_cond = _make_cond(_random_contrast, idx=1)
  random_saturation_cond = _make_cond(_random_saturation, idx=2)
  random_hue_cond = _make_cond(_random_hue, idx=3)

  def _color_jitter(x):
    # Drop only the channel axis: an axis-less squeeze would also remove
    # H or W whenever one of them has size 1.
    rgb_tuple = tuple(
        jnp.squeeze(channel, axis=-1)
        for channel in jnp.split(x, 3, axis=-1))
    if shuffle:
      order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32))
    else:
      order = range(4)
    # At every step all four conditionals are evaluated, but only the one
    # whose index matches the selected transform modifies the channels.
    for selected in order:
      if brightness > 0:
        rgb_tuple = random_brightness_cond((rgb_tuple, b_rng, brightness),
                                           selected)
      if contrast > 0:
        rgb_tuple = random_contrast_cond((rgb_tuple, c_rng, contrast),
                                         selected)
      if saturation > 0:
        rgb_tuple = random_saturation_cond((rgb_tuple, s_rng, saturation),
                                           selected)
      if hue > 0:
        rgb_tuple = random_hue_cond((rgb_tuple, h_rng, hue), selected)
    return jnp.stack(rgb_tuple, axis=-1)

  out_apply = _color_jitter(image)
  out_apply = jax.lax.cond(should_apply & should_apply_gs, _to_grayscale,
                           lambda x: x, out_apply)
  return jnp.clip(out_apply, 0., 1.)
def _random_flip_single_image(image, rng):
  """Flips a single image left/right with probability 0.5.

  Args:
    image: the image to flip; it is flipped with `jnp.fliplr`.
    rng: a single PRNGKey.

  Returns:
    The flipped image, or `image` unchanged.
  """
  # Only the second sub-key is consumed; the split is kept so the decision
  # drawn for a given `rng` does not change.
  _, flip_rng = jax.random.split(rng)
  should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5
  # Documented signature: cond(pred, true_fun, false_fun, *operands). The
  # former (pred, operand, true_fun, operand, false_fun) form is legacy.
  return jax.lax.cond(should_flip_lr, jnp.fliplr, lambda x: x, image)
def random_flip(images, rng):
  """Randomly flips each image of a batch left/right.

  Args:
    images: a batch of images, batched along the first axis.
    rng: a single PRNGKey, split into one key per image.

  Returns:
    The batch, with each image flipped (or not) independently.
  """
  batch_size = images.shape[0]
  per_image_rngs = jax.random.split(rng, batch_size)
  flip_batch = jax.vmap(_random_flip_single_image)
  return flip_batch(images, per_image_rngs)
def color_transform(images,
                    rng,
                    brightness=0.8,
                    contrast=0.8,
                    saturation=0.8,
                    hue=0.2,
                    color_jitter_prob=0.8,
                    to_grayscale_prob=0.2,
                    apply_prob=1.0,
                    shuffle=True):
  """Applies color jittering and/or grayscaling to a batch of images.

  Every batch element gets its own PRNGKey, so the random decisions are made
  independently per image.

  Args:
    images: an NHWC tensor, with C=3.
    rng: a single PRNGKey.
    brightness: the range of jitter on brightness.
    contrast: the range of jitter on contrast.
    saturation: the range of jitter on saturation.
    hue: the range of jitter on hue.
    color_jitter_prob: the probability of applying color jittering.
    to_grayscale_prob: the probability of converting the image to grayscale.
    apply_prob: the probability of applying the transform to a batch element.
    shuffle: whether to apply the transforms in a random order.

  Returns:
    A NHWC tensor of the transformed images.
  """
  per_image_rngs = jax.random.split(rng, images.shape[0])

  def _transform_one(image, image_rng):
    return _color_transform_single_image(
        image,
        image_rng,
        brightness=brightness,
        contrast=contrast,
        saturation=saturation,
        hue=hue,
        color_jitter_prob=color_jitter_prob,
        to_grayscale_prob=to_grayscale_prob,
        apply_prob=apply_prob,
        shuffle=shuffle)

  return jax.vmap(_transform_one)(images, per_image_rngs)
def gaussian_blur(images,
                  rng,
                  blur_divider=10.,
                  sigma_min=0.1,
                  sigma_max=2.0,
                  apply_prob=1.0):
  """Applies gaussian blur to a batch of images.

  Every batch element gets its own PRNGKey, so both the apply decision and
  the sampled sigma are independent per image.

  Args:
    images: an NHWC tensor, with C=3.
    rng: a single PRNGKey.
    blur_divider: the blurring kernel will have size H / blur_divider.
    sigma_min: the minimum value for sigma in the blurring kernel.
    sigma_max: the maximum value for sigma in the blurring kernel.
    apply_prob: the probability of applying the transform to a batch element.

  Returns:
    A NHWC tensor of the blurred images.
  """
  batch_size, height = images.shape[0], images.shape[1]
  per_image_rngs = jax.random.split(rng, batch_size)
  kernel_size = height / blur_divider

  def _blur_one(image, image_rng):
    return _random_gaussian_blur(
        image,
        image_rng,
        kernel_size=kernel_size,
        padding='SAME',
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        apply_prob=apply_prob)

  return jax.vmap(_blur_one)(images, per_image_rngs)
def _solarize_single_image(image, rng, threshold, apply_prob):
  """Solarizes a single image with probability `apply_prob`.

  Values at or above `threshold` are replaced by `1 - value`; values below
  it are kept.
  """

  def _invert_above_threshold(img):
    keep_mask = img < threshold
    return jnp.where(keep_mask, img, 1. - img)

  return _maybe_apply(_invert_above_threshold, image, rng, apply_prob)
def solarize(images, rng, threshold=0.5, apply_prob=1.0):
  """Applies solarization.

  Every batch element gets its own PRNGKey, so the apply decision is made
  independently per image.

  Args:
    images: an NHWC tensor (with C=3).
    rng: a single PRNGKey.
    threshold: the solarization threshold.
    apply_prob: the probability of applying the transform to a batch element.

  Returns:
    A NHWC tensor of the transformed images.
  """
  per_image_rngs = jax.random.split(rng, images.shape[0])

  def _solarize_one(image, image_rng):
    return _solarize_single_image(
        image, image_rng, threshold=threshold, apply_prob=apply_prob)

  return jax.vmap(_solarize_one)(images, per_image_rngs)