mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-10 05:17:46 +08:00
8457046b2c
PiperOrigin-RevId: 328023346
458 lines
15 KiB
Python
458 lines
15 KiB
Python
# Copyright 2020 DeepMind Technologies Limited.
|
|
#
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Data preprocessing and augmentation."""
|
|
|
|
import functools
|
|
from typing import Any, Mapping, Text
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
|
|
# typing
|
|
JaxBatch = Mapping[Text, jnp.ndarray]
|
|
ConfigDict = Mapping[Text, Any]
|
|
|
|
augment_config = dict(
|
|
view1=dict(
|
|
random_flip=True, # Random left/right flip
|
|
color_transform=dict(
|
|
apply_prob=1.0,
|
|
# Range of jittering
|
|
brightness=0.4,
|
|
contrast=0.4,
|
|
saturation=0.2,
|
|
hue=0.1,
|
|
# Probability of applying color jittering
|
|
color_jitter_prob=0.8,
|
|
# Probability of converting to grayscale
|
|
to_grayscale_prob=0.2,
|
|
# Shuffle the order of color transforms
|
|
shuffle=True),
|
|
gaussian_blur=dict(
|
|
apply_prob=1.0,
|
|
# Kernel size ~ image_size / blur_divider
|
|
blur_divider=10.,
|
|
# Kernel distribution
|
|
sigma_min=0.1,
|
|
sigma_max=2.0),
|
|
solarize=dict(apply_prob=0.0, threshold=0.5),
|
|
),
|
|
view2=dict(
|
|
random_flip=True,
|
|
color_transform=dict(
|
|
apply_prob=1.0,
|
|
brightness=0.4,
|
|
contrast=0.4,
|
|
saturation=0.2,
|
|
hue=0.1,
|
|
color_jitter_prob=0.8,
|
|
to_grayscale_prob=0.2,
|
|
shuffle=True),
|
|
gaussian_blur=dict(
|
|
apply_prob=0.1, blur_divider=10., sigma_min=0.1, sigma_max=2.0),
|
|
solarize=dict(apply_prob=0.2, threshold=0.5),
|
|
))
|
|
|
|
|
|
def postprocess(inputs: JaxBatch, rng: jnp.ndarray):
|
|
"""Apply the image augmentations to crops in inputs (view1 and view2)."""
|
|
|
|
def _postprocess_image(
|
|
images: jnp.ndarray,
|
|
rng: jnp.ndarray,
|
|
presets: ConfigDict,
|
|
) -> JaxBatch:
|
|
"""Applies augmentations in post-processing.
|
|
|
|
Args:
|
|
images: an NHWC tensor (with C=3), with float values in [0, 1].
|
|
rng: a single PRNGKey.
|
|
presets: a dict of presets for the augmentations.
|
|
|
|
Returns:
|
|
A batch of augmented images with shape NHWC, with keys view1, view2
|
|
and labels.
|
|
"""
|
|
flip_rng, color_rng, blur_rng, solarize_rng = jax.random.split(rng, 4)
|
|
out = images
|
|
if presets['random_flip']:
|
|
out = random_flip(out, flip_rng)
|
|
if presets['color_transform']['apply_prob'] > 0:
|
|
out = color_transform(out, color_rng, **presets['color_transform'])
|
|
if presets['gaussian_blur']['apply_prob'] > 0:
|
|
out = gaussian_blur(out, blur_rng, **presets['gaussian_blur'])
|
|
if presets['solarize']['apply_prob'] > 0:
|
|
out = solarize(out, solarize_rng, **presets['solarize'])
|
|
out = jnp.clip(out, 0., 1.)
|
|
return jax.lax.stop_gradient(out)
|
|
|
|
rng1, rng2 = jax.random.split(rng, num=2)
|
|
view1 = _postprocess_image(inputs['view1'], rng1, augment_config['view1'])
|
|
view2 = _postprocess_image(inputs['view2'], rng2, augment_config['view2'])
|
|
return dict(view1=view1, view2=view2, labels=inputs['labels'])
|
|
|
|
|
|
def _maybe_apply(apply_fn, inputs, rng, apply_prob):
|
|
should_apply = jax.random.uniform(rng, shape=()) <= apply_prob
|
|
return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x)
|
|
|
|
|
|
def _depthwise_conv2d(inputs, kernel, strides, padding):
|
|
"""Computes a depthwise conv2d in Jax.
|
|
|
|
Args:
|
|
inputs: an NHWC tensor with N=1.
|
|
kernel: a [H", W", 1, C] tensor.
|
|
strides: a 2d tensor.
|
|
padding: "SAME" or "VALID".
|
|
|
|
Returns:
|
|
The depthwise convolution of inputs with kernel, as [H, W, C].
|
|
"""
|
|
return jax.lax.conv_general_dilated(
|
|
inputs,
|
|
kernel,
|
|
strides,
|
|
padding,
|
|
feature_group_count=inputs.shape[-1],
|
|
dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
|
|
|
|
|
|
def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
|
|
"""Applies gaussian blur to a single image, given as NHWC with N=1."""
|
|
radius = int(kernel_size / 2)
|
|
kernel_size_ = 2 * radius + 1
|
|
x = jnp.arange(-radius, radius + 1).astype(jnp.float32)
|
|
blur_filter = jnp.exp(-x**2 / (2. * sigma**2))
|
|
blur_filter = blur_filter / jnp.sum(blur_filter)
|
|
blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1])
|
|
blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1])
|
|
num_channels = image.shape[-1]
|
|
blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels])
|
|
blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
|
|
expand_batch_dim = len(image.shape) == 3
|
|
if expand_batch_dim:
|
|
image = image[jnp.newaxis, ...]
|
|
blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
|
|
blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
|
|
blurred = jnp.squeeze(blurred, axis=0)
|
|
return blurred
|
|
|
|
|
|
def _random_gaussian_blur(image, rng, kernel_size, padding, sigma_min,
|
|
sigma_max, apply_prob):
|
|
"""Applies a random gaussian blur."""
|
|
apply_rng, transform_rng = jax.random.split(rng)
|
|
|
|
def _apply(image):
|
|
sigma_rng, = jax.random.split(transform_rng, 1)
|
|
sigma = jax.random.uniform(
|
|
sigma_rng,
|
|
shape=(),
|
|
minval=sigma_min,
|
|
maxval=sigma_max,
|
|
dtype=jnp.float32)
|
|
return _gaussian_blur_single_image(image, kernel_size, padding, sigma)
|
|
|
|
return _maybe_apply(_apply, image, apply_rng, apply_prob)
|
|
|
|
|
|
def rgb_to_hsv(r, g, b):
|
|
"""Converts R, G, B values to H, S, V values.
|
|
|
|
Reference TF implementation:
|
|
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
|
|
Only input values between 0 and 1 are guaranteed to work properly, but this
|
|
function complies with the TF implementation outside of this range.
|
|
|
|
Args:
|
|
r: A tensor representing the red color component as floats.
|
|
g: A tensor representing the green color component as floats.
|
|
b: A tensor representing the blue color component as floats.
|
|
|
|
Returns:
|
|
H, S, V values, each as tensors of shape [...] (same as the input without
|
|
the last dimension).
|
|
"""
|
|
vv = jnp.maximum(jnp.maximum(r, g), b)
|
|
range_ = vv - jnp.minimum(jnp.minimum(r, g), b)
|
|
sat = jnp.where(vv > 0, range_ / vv, 0.)
|
|
norm = jnp.where(range_ != 0, 1. / (6. * range_), 1e9)
|
|
|
|
hr = norm * (g - b)
|
|
hg = norm * (b - r) + 2. / 6.
|
|
hb = norm * (r - g) + 4. / 6.
|
|
|
|
hue = jnp.where(r == vv, hr, jnp.where(g == vv, hg, hb))
|
|
hue = hue * (range_ > 0)
|
|
hue = hue + (hue < 0)
|
|
|
|
return hue, sat, vv
|
|
|
|
|
|
def hsv_to_rgb(h, s, v):
|
|
"""Converts H, S, V values to an R, G, B tuple.
|
|
|
|
Reference TF implementation:
|
|
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
|
|
Only input values between 0 and 1 are guaranteed to work properly, but this
|
|
function complies with the TF implementation outside of this range.
|
|
|
|
Args:
|
|
h: A float tensor of arbitrary shape for the hue (0-1 values).
|
|
s: A float tensor of the same shape for the saturation (0-1 values).
|
|
v: A float tensor of the same shape for the value channel (0-1 values).
|
|
|
|
Returns:
|
|
An (r, g, b) tuple, each with the same dimension as the inputs.
|
|
"""
|
|
c = s * v
|
|
m = v - c
|
|
dh = (h % 1.) * 6.
|
|
fmodu = dh % 2.
|
|
x = c * (1 - jnp.abs(fmodu - 1))
|
|
hcat = jnp.floor(dh).astype(jnp.int32)
|
|
rr = jnp.where(
|
|
(hcat == 0) | (hcat == 5), c, jnp.where(
|
|
(hcat == 1) | (hcat == 4), x, 0)) + m
|
|
gg = jnp.where(
|
|
(hcat == 1) | (hcat == 2), c, jnp.where(
|
|
(hcat == 0) | (hcat == 3), x, 0)) + m
|
|
bb = jnp.where(
|
|
(hcat == 3) | (hcat == 4), c, jnp.where(
|
|
(hcat == 2) | (hcat == 5), x, 0)) + m
|
|
return rr, gg, bb
|
|
|
|
|
|
def adjust_brightness(rgb_tuple, delta):
|
|
return jax.tree_map(lambda x: x + delta, rgb_tuple)
|
|
|
|
|
|
def adjust_contrast(image, factor):
|
|
def _adjust_contrast_channel(channel):
|
|
mean = jnp.mean(channel, axis=(-2, -1), keepdims=True)
|
|
return factor * (channel - mean) + mean
|
|
return jax.tree_map(_adjust_contrast_channel, image)
|
|
|
|
|
|
def adjust_saturation(h, s, v, factor):
|
|
return h, jnp.clip(s * factor, 0., 1.), v
|
|
|
|
|
|
def adjust_hue(h, s, v, delta):
|
|
# Note: this method exactly matches TF"s adjust_hue (combined with the hsv/rgb
|
|
# conversions) when running on GPU. When running on CPU, the results will be
|
|
# different if all RGB values for a pixel are outside of the [0, 1] range.
|
|
return (h + delta) % 1.0, s, v
|
|
|
|
|
|
def _random_brightness(rgb_tuple, rng, max_delta):
|
|
delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
|
|
return adjust_brightness(rgb_tuple, delta)
|
|
|
|
|
|
def _random_contrast(rgb_tuple, rng, max_delta):
|
|
factor = jax.random.uniform(
|
|
rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
|
|
return adjust_contrast(rgb_tuple, factor)
|
|
|
|
|
|
def _random_saturation(rgb_tuple, rng, max_delta):
|
|
h, s, v = rgb_to_hsv(*rgb_tuple)
|
|
factor = jax.random.uniform(
|
|
rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta)
|
|
return hsv_to_rgb(*adjust_saturation(h, s, v, factor))
|
|
|
|
|
|
def _random_hue(rgb_tuple, rng, max_delta):
|
|
h, s, v = rgb_to_hsv(*rgb_tuple)
|
|
delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
|
|
return hsv_to_rgb(*adjust_hue(h, s, v, delta))
|
|
|
|
|
|
def _to_grayscale(image):
|
|
rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
|
|
grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[..., jnp.newaxis]
|
|
return jnp.tile(grayscale, (1, 1, 3)) # Back to 3 channels.
|
|
|
|
|
|
def _color_transform_single_image(image, rng, brightness, contrast, saturation,
|
|
hue, to_grayscale_prob, color_jitter_prob,
|
|
apply_prob, shuffle):
|
|
"""Applies color jittering to a single image."""
|
|
apply_rng, transform_rng = jax.random.split(rng)
|
|
perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
|
|
transform_rng, 7)
|
|
|
|
# Whether the transform should be applied at all.
|
|
should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob
|
|
# Whether to apply grayscale transform.
|
|
should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob
|
|
# Whether to apply color jittering.
|
|
should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob
|
|
|
|
# Decorator to conditionally apply fn based on an index.
|
|
def _make_cond(fn, idx):
|
|
|
|
def identity_fn(x, unused_rng, unused_param):
|
|
return x
|
|
|
|
def cond_fn(args, i):
|
|
def clip(args):
|
|
return jax.tree_map(lambda arg: jnp.clip(arg, 0., 1.), args)
|
|
out = jax.lax.cond(should_apply & should_apply_color & (i == idx), args,
|
|
lambda a: clip(fn(*a)), args,
|
|
lambda a: identity_fn(*a))
|
|
return jax.lax.stop_gradient(out)
|
|
|
|
return cond_fn
|
|
|
|
random_brightness_cond = _make_cond(_random_brightness, idx=0)
|
|
random_contrast_cond = _make_cond(_random_contrast, idx=1)
|
|
random_saturation_cond = _make_cond(_random_saturation, idx=2)
|
|
random_hue_cond = _make_cond(_random_hue, idx=3)
|
|
|
|
def _color_jitter(x):
|
|
rgb_tuple = tuple(jax.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1)))
|
|
if shuffle:
|
|
order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32))
|
|
else:
|
|
order = range(4)
|
|
for idx in order:
|
|
if brightness > 0:
|
|
rgb_tuple = random_brightness_cond((rgb_tuple, b_rng, brightness), idx)
|
|
if contrast > 0:
|
|
rgb_tuple = random_contrast_cond((rgb_tuple, c_rng, contrast), idx)
|
|
if saturation > 0:
|
|
rgb_tuple = random_saturation_cond((rgb_tuple, s_rng, saturation), idx)
|
|
if hue > 0:
|
|
rgb_tuple = random_hue_cond((rgb_tuple, h_rng, hue), idx)
|
|
return jnp.stack(rgb_tuple, axis=-1)
|
|
|
|
out_apply = _color_jitter(image)
|
|
out_apply = jax.lax.cond(should_apply & should_apply_gs, out_apply,
|
|
_to_grayscale, out_apply, lambda x: x)
|
|
return jnp.clip(out_apply, 0., 1.)
|
|
|
|
|
|
def _random_flip_single_image(image, rng):
|
|
_, flip_rng = jax.random.split(rng)
|
|
should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5
|
|
image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x)
|
|
return image
|
|
|
|
|
|
def random_flip(images, rng):
|
|
rngs = jax.random.split(rng, images.shape[0])
|
|
return jax.vmap(_random_flip_single_image)(images, rngs)
|
|
|
|
|
|
def color_transform(images,
|
|
rng,
|
|
brightness=0.8,
|
|
contrast=0.8,
|
|
saturation=0.8,
|
|
hue=0.2,
|
|
color_jitter_prob=0.8,
|
|
to_grayscale_prob=0.2,
|
|
apply_prob=1.0,
|
|
shuffle=True):
|
|
"""Applies color jittering and/or grayscaling to a batch of images.
|
|
|
|
Args:
|
|
images: an NHWC tensor, with C=3.
|
|
rng: a single PRNGKey.
|
|
brightness: the range of jitter on brightness.
|
|
contrast: the range of jitter on contrast.
|
|
saturation: the range of jitter on saturation.
|
|
hue: the range of jitter on hue.
|
|
color_jitter_prob: the probability of applying color jittering.
|
|
to_grayscale_prob: the probability of converting the image to grayscale.
|
|
apply_prob: the probability of applying the transform to a batch element.
|
|
shuffle: whether to apply the transforms in a random order.
|
|
|
|
Returns:
|
|
A NHWC tensor of the transformed images.
|
|
"""
|
|
rngs = jax.random.split(rng, images.shape[0])
|
|
jitter_fn = functools.partial(
|
|
_color_transform_single_image,
|
|
brightness=brightness,
|
|
contrast=contrast,
|
|
saturation=saturation,
|
|
hue=hue,
|
|
color_jitter_prob=color_jitter_prob,
|
|
to_grayscale_prob=to_grayscale_prob,
|
|
apply_prob=apply_prob,
|
|
shuffle=shuffle)
|
|
return jax.vmap(jitter_fn)(images, rngs)
|
|
|
|
|
|
def gaussian_blur(images,
|
|
rng,
|
|
blur_divider=10.,
|
|
sigma_min=0.1,
|
|
sigma_max=2.0,
|
|
apply_prob=1.0):
|
|
"""Applies gaussian blur to a batch of images.
|
|
|
|
Args:
|
|
images: an NHWC tensor, with C=3.
|
|
rng: a single PRNGKey.
|
|
blur_divider: the blurring kernel will have size H / blur_divider.
|
|
sigma_min: the minimum value for sigma in the blurring kernel.
|
|
sigma_max: the maximum value for sigma in the blurring kernel.
|
|
apply_prob: the probability of applying the transform to a batch element.
|
|
|
|
Returns:
|
|
A NHWC tensor of the blurred images.
|
|
"""
|
|
rngs = jax.random.split(rng, images.shape[0])
|
|
kernel_size = images.shape[1] / blur_divider
|
|
blur_fn = functools.partial(
|
|
_random_gaussian_blur,
|
|
kernel_size=kernel_size,
|
|
padding='SAME',
|
|
sigma_min=sigma_min,
|
|
sigma_max=sigma_max,
|
|
apply_prob=apply_prob)
|
|
return jax.vmap(blur_fn)(images, rngs)
|
|
|
|
|
|
def _solarize_single_image(image, rng, threshold, apply_prob):
|
|
|
|
def _apply(image):
|
|
return jnp.where(image < threshold, image, 1. - image)
|
|
|
|
return _maybe_apply(_apply, image, rng, apply_prob)
|
|
|
|
|
|
def solarize(images, rng, threshold=0.5, apply_prob=1.0):
|
|
"""Applies solarization.
|
|
|
|
Args:
|
|
images: an NHWC tensor (with C=3).
|
|
rng: a single PRNGKey.
|
|
threshold: the solarization threshold.
|
|
apply_prob: the probability of applying the transform to a batch element.
|
|
|
|
Returns:
|
|
A NHWC tensor of the transformed images.
|
|
"""
|
|
rngs = jax.random.split(rng, images.shape[0])
|
|
solarize_fn = functools.partial(
|
|
_solarize_single_image, threshold=threshold, apply_prob=apply_prob)
|
|
return jax.vmap(solarize_fn)(images, rngs)
|