From 5909da5388d818791a64dd38bb2cc1f0dabc9fcf Mon Sep 17 00:00:00 2001 From: Sven Gowal Date: Wed, 7 Jul 2021 13:36:13 +0000 Subject: [PATCH] Added jaxline pipeline to train adversarially robust models. PiperOrigin-RevId: 383399487 --- adversarial_robustness/README.md | 78 +- adversarial_robustness/jax/attacks.py | 310 +++++++ adversarial_robustness/jax/datasets.py | 180 ++++ adversarial_robustness/jax/eval.py | 46 +- adversarial_robustness/jax/experiment.py | 565 +++++++++++++ adversarial_robustness/jax/experiment_test.py | 46 ++ adversarial_robustness/jax/model_zoo.py | 38 +- adversarial_robustness/jax/train.py | 32 + adversarial_robustness/jax/utils.py | 197 +++++ adversarial_robustness/pytorch/README.md | 28 + adversarial_robustness/requirements.txt | 85 +- adversarial_robustness/run.sh | 8 + rapid_task_solving/README.md | 123 +++ rapid_task_solving/images/example_mpg.png | Bin 0 -> 142077 bytes rapid_task_solving/images/example_osl.png | Bin 0 -> 80059 bytes rapid_task_solving/memory_planning_game.py | 184 +++++ rapid_task_solving/one_shot_streetlearn.py | 265 ++++++ rapid_task_solving/requirements.txt | 6 + wikigraphs/README.md | 230 ++++++ wikigraphs/requirements.txt | 7 + wikigraphs/scripts/build_vocab.py | 166 ++++ wikigraphs/scripts/download.sh | 63 ++ wikigraphs/scripts/freebase_preprocess.py | 106 +++ wikigraphs/scripts/visualize_graph.py | 143 ++++ wikigraphs/setup.py | 43 + wikigraphs/wikigraphs/data/__init__.py | 36 + wikigraphs/wikigraphs/data/dataset.py | 59 ++ wikigraphs/wikigraphs/data/io_tools.py | 179 ++++ wikigraphs/wikigraphs/data/paired_dataset.py | 767 ++++++++++++++++++ .../wikigraphs/data/paired_dataset_test.py | 271 +++++++ wikigraphs/wikigraphs/data/tokenizers.py | 230 ++++++ wikigraphs/wikigraphs/data/tokenizers_test.py | 78 ++ wikigraphs/wikigraphs/data/tools.py | 242 ++++++ wikigraphs/wikigraphs/data/tools_test.py | 195 +++++ wikigraphs/wikigraphs/data/wikitext.py | 218 +++++ wikigraphs/wikigraphs/data/wikitext_test.py | 84 ++ 36 files changed, 5229 insertions(+), 79 deletions(-) create mode 100644 adversarial_robustness/jax/attacks.py create mode 100644 adversarial_robustness/jax/datasets.py create mode 100644 adversarial_robustness/jax/experiment.py create mode 100644 adversarial_robustness/jax/experiment_test.py create mode 100644 adversarial_robustness/jax/train.py create mode 100644 adversarial_robustness/jax/utils.py create mode 100644 adversarial_robustness/pytorch/README.md create mode 100644 rapid_task_solving/README.md create mode 100644 rapid_task_solving/images/example_mpg.png create mode 100644 rapid_task_solving/images/example_osl.png create mode 100644 rapid_task_solving/memory_planning_game.py create mode 100644 rapid_task_solving/one_shot_streetlearn.py create mode 100644 rapid_task_solving/requirements.txt create mode 100644 wikigraphs/README.md create mode 100644 wikigraphs/requirements.txt create mode 100644 wikigraphs/scripts/build_vocab.py create mode 100644 wikigraphs/scripts/download.sh create mode 100644 wikigraphs/scripts/freebase_preprocess.py create mode 100644 wikigraphs/scripts/visualize_graph.py create mode 100644 wikigraphs/setup.py create mode 100644 wikigraphs/wikigraphs/data/__init__.py create mode 100644 wikigraphs/wikigraphs/data/dataset.py create mode 100644 wikigraphs/wikigraphs/data/io_tools.py create mode 100644 wikigraphs/wikigraphs/data/paired_dataset.py create mode 100644 wikigraphs/wikigraphs/data/paired_dataset_test.py create mode 100644 wikigraphs/wikigraphs/data/tokenizers.py create mode 100644 wikigraphs/wikigraphs/data/tokenizers_test.py create mode 100644 wikigraphs/wikigraphs/data/tools.py create mode 100644 wikigraphs/wikigraphs/data/tools_test.py create mode 100644 wikigraphs/wikigraphs/data/wikitext.py create mode 100644 wikigraphs/wikigraphs/data/wikitext_test.py diff --git a/adversarial_robustness/README.md b/adversarial_robustness/README.md index 5c62cb5..0708ffd 100644 --- a/adversarial_robustness/README.md +++ b/adversarial_robustness/README.md @@ -13,7 +13,7 @@ We have released our top-performing models in two formats compatible with [JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/). This repository also contains our model definitions. -## Running the example code +## Running the code ### Downloading a model @@ -47,10 +47,32 @@ The following table contains the models from **Rebuffi et al., 2021**. | CIFAR-100 | ℓ | 8 / 255 | WRN-70-16 | ✗ | 63.56% | 34.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.pt) | CIFAR-100 | ℓ | 8 / 255 | WRN-28-10 | ✗ | 62.41% | 32.06% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.pt) -### Using the model +### Installing -Once downloaded, a model can be evaluated (clean accuracy) by running the -`eval.py` script in either the `jax` or `pytorch` folders. E.g.: +The following has been tested using Python 3.9.2. +Using `run.sh` will create and activate a virtualenv, install all necessary +dependencies and run a test program to ensure that you can import all the +modules. + +``` +# Run from the parent directory. +sh adversarial_robustness/run.sh +``` + +To run the provided code, use this virtualenv: + +``` +source /tmp/adversarial_robustness_venv/bin/activate +``` + +You may want to edit `requirements.txt` before running `run.sh` if GPU support +is needed (e.g., use `jaxline==0.1.67+cuda111`). See JAX's installation +[instructions](https://github.com/google/jax#installation) for more details. + +### Using pre-trained models + +Once downloaded, a model can be evaluated by running the `eval.py` script in +either the `jax` or `pytorch` folders. E.g.: ``` cd jax @@ -58,7 +80,47 @@ python3 eval.py \ --ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10 ``` -## Generated datasets +These models are also directly available within +[RobustBench](https://github.com/RobustBench/robustbench#model-zoo-quick-tour)'s +model zoo. + +### Training your own model + +We also provide a training pipeline that reproduces results from both +publications. This pipeline uses [Jaxline](https://github.com/deepmind/jaxline) +and is written using [JAX](https://github.com/google/jax) and +[Haiku](https://github.com/deepmind/dm-haiku). To train a model, modify the +configuration in the `get_config()` function of `jax/experiment.py` and issue +the following command from within the virtualenv created above: + +``` +cd jax +python3 train.py --config=experiment.py +``` + +The training pipeline can run with multiple worker machines and multiple devices +(either GPU or TPU). See [Jaxline](https://github.com/deepmind/jaxline) for more +details. + +We do not provide a PyTorch implementation of our training pipeline. However, +you may find one on GitHub, e.g., +[adversarial_robustness_pytorch](https://github.com/imrahulr/adversarial_robustness_pytorch) +(by Rahul Rade). + +## Datasets + +### Extracted dataset + +Gowal et al. (2020) use samples extracted from +[TinyImages-80M](https://groups.csail.mit.edu/vision/TinyImages/). +Unfortunately, since then, the official TinyImages-80M dataset has been +withdrawn (due to the presence of offensive images). As such, we cannot provide +a download link to our extrated data until we have manually verified that all +extracted images are not offensive. If you want to reproduce our setup, consider +the generated datasets below. We are also happy to help, so feel free to reach +out to Sven Gowal directly. + +### Generated datasets Rebuffi et al. (2021) use samples generated by a Denoising Diffusion Probabilistic Model [(DDPM; Ho et al., 2020)](https://arxiv.org/abs/2006.11239) @@ -82,8 +144,8 @@ labels = npzfile['label'] ## Citing this work -If you use this code, data or these models in your work, please cite the -relevant accompanying paper: +If you use this code (or any derived code), data or these models in your work, +please cite the relevant accompanying paper: ``` @article{gowal2020uncovering, @@ -95,7 +157,7 @@ relevant accompanying paper: } ``` -or +and/or ``` @article{rebuffi2021fixing, diff --git a/adversarial_robustness/jax/attacks.py b/adversarial_robustness/jax/attacks.py new file mode 100644 index 0000000..0b1e8c7 --- /dev/null +++ b/adversarial_robustness/jax/attacks.py @@ -0,0 +1,310 @@ +# Copyright 2021 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adversarial attacks. + +This file contains all the code necessary to create untargeted adversarial +attacks in JAX (within an l-infinity ball). For example, to create an untargeted +FGSM attack (with a single step), one can do the following: + +``` +import attacks + +epsilon = 8/255 # Perturbation radius for inputs between 0 and 1. +fgsm_attack = attacks.UntargetedAttack( + attacks.PGD( + attacks.IteratedFGSM(epsilon), + num_steps=1, + initialize_fn=attacks.linf_initialize_fn(epsilon), + project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))), + loss_fn=attacks.untargeted_cross_entropy) +``` + +Just as elegantly, one can specify an adversarial attack on KL-divergence +to a target distribution (using 10 steps with Adam and a piecewise constant step +schedule): + +``` +kl_attack_with_adam = attacks.UntargetedAttack( + attacks.PGD( + attacks.Adam(optax.piecewise_constant_schedule( + init_value=.1, + boundaries_and_scales={5: .1})), + num_steps=10, + initialize_fn=attacks.linf_initialize_fn(epsilon), + project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))), + loss_fn=attacks.untargeted_kl_divergence) +``` + +The attack instances can be used later on to build adversarial examples: + +``` +my_model = ... # Model. We assume that 'my_model(.)' returns logits. +clean_images, image_labels = ... # Batch of images and associated labels. +rng = jax.random.PRNGKey(0) # A random generator state. + +adversarial_images = fgsm_attack(my_model, rng, clean_images, image_labels) +``` + +See `experiment.py` or `eval.py` for more examples. + +This file contains the following components: +* Losses: + * untargeted_cross_entropy: minimizes the likelihood of the label class. + * untargeted_kl_divergence: maximizes the KL-divergence of the predictions with + a target distribution. + * untargeted_margin: maximizes the margin loss (distance from the highest + non-true logits to the label class logit) +* Step optimizers: + * SGD: Stochastic Gradient Descent. + * IteratedFGSM: Also called BIM (see https://arxiv.org/pdf/1607.02533). + * Adam: See https://arxiv.org/pdf/1412.6980. +* Initialization and projection functions: + * linf_initialize_fn: Initialize function for l-infinity attacks. + * linf_project_fn: Projection function for l-infinity attacks. +* Projected Gradient Descent (PGD): + * PGD: Runs Projected Gradient Descent using the specified optimizer, + initialization and projection functions for a given number of steps. +* Untargeted attack: + * UntargetedAttack: Combines PGD and a specific loss function to find + adversarial examples. +""" + +import functools +import inspect +from typing import Callable, Optional, Tuple, Union + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import optax + + +ModelFn = Callable[[chex.Array], chex.Array] +LossFn = Callable[[chex.Array], chex.Array] +ClassificationLossFn = Callable[[chex.Array, chex.Array], chex.Array] +OptimizeFn = Callable[[LossFn, chex.PRNGKey, chex.Array], chex.Array] +NormalizeFn = Callable[[chex.Array], chex.Array] +InitializeFn = Callable[[chex.PRNGKey, chex.Array], chex.Array] +ProjectFn = Callable[[chex.Array, chex.Array], chex.Array] + + +def untargeted_cross_entropy(logits: chex.Array, + labels: chex.Array) -> chex.Array: + """Maximize the cross-entropy of the true class (make it less likely).""" + num_classes = logits.shape[-1] + log_probs = jax.nn.log_softmax(logits) + return jnp.sum( + hk.one_hot(labels, num_classes).astype(logits.dtype) * log_probs, axis=-1) + + +def untargeted_kl_divergence(logits: chex.Array, + label_probs: chex.Array) -> chex.Array: + """Maximize the KL divergence between logits and label distribution.""" + # We are explicitly maximizing the cross-entropy, as this is equivalent to + # maximizing the KL divergence (when `label_probs` does not depend + # on the values that produce `logits`). + log_probs = jax.nn.log_softmax(logits) + return jnp.sum(label_probs * log_probs, axis=-1) + + +def untargeted_margin(logits: chex.Array, + labels: chex.Array) -> chex.Array: + """Make the highest non-correct logits higher than the true class logits.""" + batch_size = logits.shape[0] + num_classes = logits.shape[-1] + label_logits = logits[jnp.arange(batch_size), labels] + logit_mask = hk.one_hot(labels, num_classes).astype(logits.dtype) + highest_logits = jnp.max(logits - 1e8 * logit_mask, axis=-1) + return label_logits - highest_logits + + +class UntargetedAttack: + """Performs an untargeted attack.""" + + def __init__(self, + optimize_fn: OptimizeFn, + loss_fn: ClassificationLossFn = untargeted_cross_entropy): + """Creates an untargeted attack. + + Args: + optimize_fn: An `Optimizer` instance or any callable that takes + a loss function and an initial input and outputs a new input that + minimizes the loss function. + loss_fn: `loss_fn` is a surrogate loss. Its goal should be make the true + class less likely than any other class. Typical options for `loss_fn` + are `untargeted_cross_entropy` or `untargeted_margin`. + """ + self._optimize_fn = optimize_fn + self._loss_fn = loss_fn + + def __call__(self, + logits_fn: ModelFn, + rng: chex.PRNGKey, + inputs: chex.Array, + labels: chex.Array) -> chex.Array: + """Returns adversarial inputs.""" + def _loss_fn(x): + return self._loss_fn(logits_fn(x), labels) + return self._optimize_fn(_loss_fn, rng, inputs) + + # Convenience functions to detect the type of inputs required by the loss. + def expects_labels(self): + return 'labels' in inspect.getfullargspec(self._loss_fn).args + + def expects_probabilities(self): + return 'label_probs' in inspect.getfullargspec(self._loss_fn).args + + +class StepOptimizer: + """Makes a single gradient step that minimizes a loss function.""" + + def __init__(self, + gradient_transformation: optax.GradientTransformation): + self._gradient_transformation = gradient_transformation + + def init(self, + loss_fn: LossFn, + x: chex.Array) -> optax.OptState: + self._loss_fn = loss_fn + return self._gradient_transformation.init(x) + + def minimize( + self, + x: chex.Array, + state: optax.OptState) -> Tuple[chex.Array, chex.Array, optax.OptState]: + """Performs a single minimization step.""" + g, loss = gradients_fn(self._loss_fn, x) + if g is None: + raise ValueError('loss_fn does not depend on input.') + updates, state = self._gradient_transformation.update(g, state, x) + return optax.apply_updates(x, updates), loss, state + + +class SGD(StepOptimizer): + """Vanilla gradient descent optimizer.""" + + def __init__(self, + learning_rate_fn: Union[float, int, optax.Schedule], + normalize_fn: Optional[NormalizeFn] = None): + # Accept schedules, as well as scalar values. + if isinstance(learning_rate_fn, (float, int)): + lr = float(learning_rate_fn) + learning_rate_fn = lambda _: lr + # Normalization. + def update_fn(updates, state, params=None): + del params + updates = jax.tree_map(normalize_fn or (lambda x: x), updates) + return updates, state + gradient_transformation = optax.chain( + optax.GradientTransformation(lambda _: optax.EmptyState(), update_fn), + optax.scale_by_schedule(learning_rate_fn), + optax.scale(-1.)) + super(SGD, self).__init__(gradient_transformation) + + +class IteratedFGSM(SGD): + """L-infinity normalized steps.""" + + def __init__(self, + learning_rate_fn: Union[float, int, optax.Schedule]): + super(IteratedFGSM, self).__init__(learning_rate_fn, jnp.sign) + + +class Adam(StepOptimizer): + """The Adam optimizer defined in https://arxiv.org/abs/1412.6980.""" + + def __init__( + self, + learning_rate_fn: Union[float, int, optax.Schedule], + normalize_fn: Optional[NormalizeFn] = None, + beta1: float = .9, + beta2: float = .999, + epsilon: float = 1e-9): + # Accept schedules, as well as scalar values. + if isinstance(learning_rate_fn, (float, int)): + lr = float(learning_rate_fn) + learning_rate_fn = lambda _: lr + # Normalization. + def update_fn(updates, state, params=None): + del params + updates = jax.tree_map(normalize_fn or (lambda x: x), updates) + return updates, state + gradient_transformation = optax.chain( + optax.GradientTransformation(lambda _: optax.EmptyState(), update_fn), + optax.scale_by_adam(b1=beta1, b2=beta2, eps=epsilon), + optax.scale_by_schedule(learning_rate_fn), + optax.scale(-1.)) + super(Adam, self).__init__(gradient_transformation) + + +class PGD: + """Runs Project Gradient Descent (see https://arxiv.org/pdf/1706.06083).""" + + def __init__(self, + optimizer: StepOptimizer, + num_steps: int, + initialize_fn: Optional[InitializeFn] = None, + project_fn: Optional[ProjectFn] = None): + self._optimizer = optimizer + if initialize_fn is None: + initialize_fn = lambda rng, x: x + self._initialize_fn = initialize_fn + if project_fn is None: + project_fn = lambda x, origin_x: x + self._project_fn = project_fn + self._num_steps = num_steps + + def __call__(self, + loss_fn: LossFn, + rng: chex.PRNGKey, + x: chex.Array) -> chex.Array: + def _optimize(rng, x): + """Optimizes loss_fn when keep_best is False.""" + def body_fn(_, inputs): + opt_state, current_x = inputs + current_x, _, opt_state = self._optimizer.minimize(current_x, opt_state) + current_x = self._project_fn(current_x, x) + return opt_state, current_x + opt_state = self._optimizer.init(loss_fn, x) + current_x = self._project_fn(self._initialize_fn(rng, x), x) + _, current_x = jax.lax.fori_loop(0, self._num_steps, body_fn, + (opt_state, current_x)) + return current_x + return jax.lax.stop_gradient(_optimize(rng, x)) + + +def linf_project_fn(epsilon: float, bounds: Tuple[float, float]) -> ProjectFn: + def project_fn(x, origin_x): + dx = jnp.clip(x - origin_x, -epsilon, epsilon) + return jnp.clip(origin_x + dx, bounds[0], bounds[1]) + return project_fn + + +def linf_initialize_fn(epsilon: float) -> InitializeFn: + def initialize_fn(rng, x): + return x + jax.random.uniform(rng, x.shape, minval=-epsilon, + maxval=epsilon).astype(x.dtype) + return initialize_fn + + +def gradients_fn(loss_fn: LossFn, + x: chex.Array) -> Tuple[chex.Array, chex.Array]: + """Returns the analytical gradient as computed by `jax.grad`.""" + @functools.partial(jax.grad, has_aux=True) + def grad_reduced_loss_fn(x): + loss = loss_fn(x) + return jnp.sum(loss), loss + return grad_reduced_loss_fn(x) diff --git a/adversarial_robustness/jax/datasets.py b/adversarial_robustness/jax/datasets.py new file mode 100644 index 0000000..f823ec3 --- /dev/null +++ b/adversarial_robustness/jax/datasets.py @@ -0,0 +1,180 @@ +# Copyright 2021 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Datasets.""" + +from typing import Sequence + +import chex +import jax +import jax.numpy as jnp +import numpy as np +import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds + + +_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) +_CIFAR10_STD = (0.2471, 0.2435, 0.2616) +_CIFAR100_MEAN = (0.5071, 0.4865, 0.4409) +_CIFAR100_STD = (0.2673, 0.2564, 0.2762) + +_DATA_URL = 'https://storage.googleapis.com/dm-adversarial-robustness/' +_ALLOWED_FILES = ('cifar10_ddpm.npz',) +_WEBPAGE = ('https://github.com/deepmind/deepmind-research/tree/master/' + 'adversarial_robustness') + + +def cifar10_preprocess(mode: str = 'train'): + """Preprocessing functions for CIFAR-10.""" + def _preprocess_fn_train(example): + """Preprocessing of CIFAR-10 images for training.""" + image = example['image'] + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + image = _random_jitter(image, pad=4, crop=32) + image = tf.image.random_flip_left_right(image) + label = tf.cast(example['label'], tf.int32) + return {'image': image, 'label': label} + + def _preprocess_fn_test(example): + """Preprocessing of CIFAR-10 images for testing.""" + image = example['image'] + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + label = tf.cast(example['label'], tf.int32) + return {'image': image, 'label': label} + + return _preprocess_fn_train if mode == 'train' else _preprocess_fn_test + + +def cifar10_normalize(image: chex.Array) -> chex.Array: + means = jnp.array(_CIFAR10_MEAN, dtype=image.dtype) + stds = jnp.array(_CIFAR10_STD, dtype=image.dtype) + return (image - means) / stds + + +def mnist_normalize(image: chex.Array) -> chex.Array: + image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant', + constant_values=0) + return (image - .5) * 2. + + +def cifar100_normalize(image: chex.Array) -> chex.Array: + means = jnp.array(_CIFAR100_MEAN, dtype=image.dtype) + stds = jnp.array(_CIFAR100_STD, dtype=image.dtype) + return (image - means) / stds + + +def load_cifar10(batch_sizes: Sequence[int], + subset: str = 'train', + is_training: bool = True, + drop_remainder: bool = True, + repeat: int = 1) -> tf.data.Dataset: + """Loads CIFAR-10.""" + if subset == 'train': + ds = tfds.load(name='cifar10', split=tfds.Split.TRAIN) + # In Gowal et al. (https://arxiv.org/abs/2010.03593) and Rebuffi et al. + # (https://arxiv.org/abs/2103.01946), we also keep a separate validation + # subset for early stopping and would run: ds = ds.skip(1_024). + elif subset == 'test': + ds = tfds.load(name='cifar10', split=tfds.Split.TEST) + else: + raise ValueError('Unknown subset: "{}"'.format(subset)) + + ds = ds.cache() + if is_training: + ds = ds.repeat() + ds = ds.shuffle(buffer_size=50_000, seed=0) + ds = _repeat_batch(batch_sizes, ds, repeat=repeat) + ds = ds.map(cifar10_preprocess('train' if is_training else 'test'), + num_parallel_calls=tf.data.AUTOTUNE) + for batch_size in reversed(batch_sizes): + ds = ds.batch(batch_size, drop_remainder=drop_remainder) + return ds.prefetch(tf.data.AUTOTUNE) + + +def load_extra(batch_sizes: Sequence[int], + path_npz: str, + is_training: bool = True, + drop_remainder: bool = True) -> tf.data.Dataset: + """Loads extra data from a given path.""" + if not tf.io.gfile.exists(path_npz): + if path_npz in _ALLOWED_FILES: + path_npz = tf.keras.utils.get_file(path_npz, _DATA_URL + path_npz) + else: + raise ValueError(f'Extra data not found ({path_npz}). See {_WEBPAGE} for ' + 'more details.') + with tf.io.gfile.GFile(path_npz, 'rb') as fp: + npzfile = np.load(fp) + data = {'image': npzfile['image'], 'label': npzfile['label']} + with tf.device('/device:cpu:0'): # Prevent allocation to happen on GPU. + ds = tf.data.Dataset.from_tensor_slices(data) + ds = ds.cache() + if is_training: + ds = ds.repeat() + ds = ds.shuffle(buffer_size=50_000, seed=jax.host_id()) + ds = ds.map(cifar10_preprocess('train' if is_training else 'test'), + num_parallel_calls=tf.data.AUTOTUNE) + for batch_size in reversed(batch_sizes): + ds = ds.batch(batch_size, drop_remainder=drop_remainder) + return ds.prefetch(tf.data.AUTOTUNE) + + +def load_dummy_data(batch_sizes: Sequence[int], + is_training: bool = True, + **unused_kwargs) -> tf.data.Dataset: + """Loads fictive data (use this function when testing).""" + ds = tf.data.Dataset.from_tensor_slices({ + 'image': np.zeros((1, 32, 32, 3), np.float32), + 'label': np.zeros((1,), np.int32), + }) + ds = ds.repeat() + if not is_training: + total_batch_size = np.prod(batch_sizes) + ds = ds.take(total_batch_size) + ds = ds.map(cifar10_preprocess('train' if is_training else 'test'), + num_parallel_calls=tf.data.AUTOTUNE) + for batch_size in reversed(batch_sizes): + ds = ds.batch(batch_size, drop_remainder=True) + return ds.prefetch(tf.data.AUTOTUNE) + + +def _random_jitter(image: tf.Tensor, pad: int, crop: int) -> tf.Tensor: + shape = image.shape.as_list() + image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]]) + image = tf.image.random_crop(image, size=[crop, crop, shape[2]]) + return image + + +def _repeat_batch(batch_sizes: Sequence[int], + ds: tf.data.Dataset, + repeat: int = 1) -> tf.data.Dataset: + """Tiles the inner most batch dimension.""" + if repeat <= 1: + return ds + if batch_sizes[-1] % repeat != 0: + raise ValueError(f'The last element of `batch_sizes` ({batch_sizes}) must ' + f'be divisible by `repeat` ({repeat}).') + # Perform regular batching with reduced number of elements. + for i, batch_size in enumerate(reversed(batch_sizes)): + ds = ds.batch(batch_size // repeat if i == 0 else batch_size, + drop_remainder=True) + # Repeat batch. + fn = lambda x: tf.repeat(x, repeats=repeat, axis=len(batch_sizes) - 1) + def repeat_inner_batch(example): + return jax.tree_map(fn, example) + ds = ds.map(repeat_inner_batch, + num_parallel_calls=tf.data.AUTOTUNE) + # Unbatch. + for _ in batch_sizes: + ds = ds.unbatch() + return ds diff --git a/adversarial_robustness/jax/eval.py b/adversarial_robustness/jax/eval.py index 3ef8c5c..f5f1453 100644 --- a/adversarial_robustness/jax/eval.py +++ b/adversarial_robustness/jax/eval.py @@ -14,14 +14,19 @@ """Evaluates a JAX checkpoint on CIFAR-10/100 or MNIST.""" +import functools + from absl import app from absl import flags import haiku as hk import numpy as np +import optax import tensorflow.compat.v2 as tf import tensorflow_datasets as tfds import tqdm +from adversarial_robustness.jax import attacks +from adversarial_robustness.jax import datasets from adversarial_robustness.jax import model_zoo _CKPT = flags.DEFINE_string( @@ -48,14 +53,14 @@ def main(unused_argv): # Create dataset. if _DATASET.value == 'mnist': _, data_test = tf.keras.datasets.mnist.load_data() - normalize_fn = model_zoo.mnist_normalize + normalize_fn = datasets.mnist_normalize elif _DATASET.value == 'cifar10': _, data_test = tf.keras.datasets.cifar10.load_data() - normalize_fn = model_zoo.cifar10_normalize + normalize_fn = datasets.cifar10_normalize else: assert _DATASET.value == 'cifar100' _, data_test = tf.keras.datasets.cifar100.load_data() - normalize_fn = model_zoo.cifar100_normalize + normalize_fn = datasets.cifar100_normalize # Create model. @hk.transform_with_state @@ -83,22 +88,53 @@ def main(unused_argv): else: params, state = np.load(_CKPT.value, allow_pickle=True) + # Create adversarial attack. We run a PGD-40 attack with margin loss. + epsilon = 8 / 255 + eval_attack = attacks.UntargetedAttack( + attacks.PGD( + attacks.Adam(learning_rate_fn=optax.piecewise_constant_schedule( + init_value=.1, + boundaries_and_scales={20: .1, 30: .01})), + num_steps=40, + initialize_fn=attacks.linf_initialize_fn(epsilon), + project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))), + loss_fn=attacks.untargeted_margin) + + def logits_fn(x, rng): + return model_fn.apply(params, state, rng, x)[0] + # Evaluation. correct = 0 + adv_correct = 0 total = 0 batch_count = 0 total_batches = min((10_000 - 1) // _BATCH_SIZE.value + 1, _NUM_BATCHES.value) for images, labels in tqdm.tqdm(test_loader, total=total_batches): - outputs = model_fn.apply(params, state, next(rng_seq), images)[0] + rng = next(rng_seq) + loop_logits_fn = functools.partial(logits_fn, rng=rng) + + # Clean examples. + outputs = loop_logits_fn(images) + correct += (np.argmax(outputs, 1) == labels).sum().item() + + # Adversarial examples. + adv_images = eval_attack(loop_logits_fn, next(rng_seq), images, labels) + outputs = loop_logits_fn(adv_images) predicted = np.argmax(outputs, 1) + adv_correct += (predicted == labels).sum().item() + total += labels.shape[0] - correct += (predicted == labels).sum().item() batch_count += 1 if _NUM_BATCHES.value > 0 and batch_count >= _NUM_BATCHES.value: break print(f'Accuracy on the {total} test images: {100 * correct / total:.2f}%') + print(f'Robust accuracy: {100 * adv_correct / total:.2f}%') if __name__ == '__main__': flags.mark_flag_as_required('ckpt') + try: + tf.config.set_visible_devices([], 'GPU') # Prevent TF from using the GPU. + except tf.errors.NotFoundError: + pass app.run(main) diff --git a/adversarial_robustness/jax/experiment.py b/adversarial_robustness/jax/experiment.py new file mode 100644 index 0000000..d604921 --- /dev/null +++ b/adversarial_robustness/jax/experiment.py @@ -0,0 +1,565 @@ +# Copyright 2021 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAXline experiment to perform robust adversarial training.""" + +import functools +import os +from typing import Callable, Optional, Tuple + +from absl import flags +from absl import logging +import chex +import haiku as hk +import jax +import jax.numpy as jnp +from jaxline import base_config +from jaxline import experiment +from jaxline import utils as jl_utils +from ml_collections import config_dict +import numpy as np +import optax +import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds + +from adversarial_robustness.jax import attacks +from adversarial_robustness.jax import datasets +from adversarial_robustness.jax import model_zoo +from adversarial_robustness.jax import utils + +FLAGS = flags.FLAGS + + +def get_config(): + """Return config object for training.""" + config = base_config.get_base_config() + + # Batch size, training steps and data. + num_classes = 10 + num_epochs = 400 + # Gowal et al. (2020) and Rebuffi et al. (2021) use 1024 as batch size. + # Reducing this batch size may require further adjustments to the batch + # normalization decay or the learning rate. If you have to use a batch size + # of 256, reduce the number of emulated workers to 1 (it should match the + # results of using a batch size of 1024 with 4 workers). + train_batch_size = 1024 + def steps_from_epochs(n): + return max(int(n * 50_000 / train_batch_size), 1) + num_steps = steps_from_epochs(num_epochs) + test_batch_size = train_batch_size + # Specify the path to the downloaded data. You can download data from + # https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness. + # If the path is set to "cifar10_ddpm.npz" and is not found in the current + # directory, the corresponding data will be downloaded. + extra_npz = 'cifar10_ddpm.npz' + + # Learning rate. + learning_rate = .1 * max(train_batch_size / 256, 1.) + learning_rate_warmup = steps_from_epochs(10) + learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps, + learning_rate_warmup) + + # Model definition. + model_ctor = model_zoo.WideResNet + model_kwargs = dict( + num_classes=num_classes, + depth=28, + width=10, + activation='swish') + + # Attack used during training (can be None). + epsilon = 8 / 255 + train_attack = attacks.UntargetedAttack( + attacks.PGD( + attacks.Adam(optax.piecewise_constant_schedule( + init_value=.1, + boundaries_and_scales={5: .1})), + num_steps=10, + initialize_fn=attacks.linf_initialize_fn(epsilon), + project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))), + loss_fn=attacks.untargeted_kl_divergence) + + # Attack used during evaluation (can be None). + eval_attack = attacks.UntargetedAttack( + attacks.PGD( + attacks.Adam(learning_rate_fn=optax.piecewise_constant_schedule( + init_value=.1, + boundaries_and_scales={20: .1, 30: .01})), + num_steps=40, + initialize_fn=attacks.linf_initialize_fn(epsilon), + project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))), + loss_fn=attacks.untargeted_margin) + + config.experiment_kwargs = config_dict.ConfigDict(dict(config=dict( + epsilon=epsilon, + num_classes=num_classes, + # Results from various publications use 4 worker machines, which results + # in slight differences when using less worker machines. To compensate for + # such discrepancies, we emulate these additional workers. Set to zero, + # when using more than 4 workers. + emulated_workers=4, + dry_run=False, + save_final_checkpoint_as_npy=True, + model=dict( + constructor=model_ctor, + kwargs=model_kwargs), + training=dict( + batch_size=train_batch_size, + learning_rate=learning_rate_fn, + weight_decay=5e-4, + swa_decay=.995, + use_cutmix=False, + supervised_batch_ratio=.3, + extra_data_path=extra_npz, + extra_label_smoothing=.1, + attack=train_attack), + evaluation=dict( + # If `interval` is positive, synchronously evaluate at regular + # intervals. Setting it to zero will not evaluate while training, + # unless `--jaxline_mode` is set to `train_eval_multithreaded`, which + # asynchronously evaluates checkpoints. + interval=steps_from_epochs(40), + batch_size=test_batch_size, + attack=eval_attack), + ))) + + config.checkpoint_dir = '/tmp/jaxline/robust' + config.train_checkpoint_all_hosts = False + config.training_steps = num_steps + config.interval_type = 'steps' + config.log_train_data_interval = steps_from_epochs(.5) + config.log_tensors_interval = steps_from_epochs(.5) + config.save_checkpoint_interval = steps_from_epochs(40) + config.eval_specific_checkpoint_dir = '' + return config + + +class Experiment(experiment.AbstractExperiment): + """CIFAR-10 experiment.""" + + CHECKPOINT_ATTRS = { + '_params': 'params', + '_avg_params': 'avg_params', + '_opt_state': 'opt_state', + '_state': 'state', + } + + def __init__(self, mode, config, init_rng): + super().__init__(mode=mode) + self.config = config + + self._params = None # Network weights. + self._avg_params = None # Averaged network weights. + self._state = None # Network state (e.g., batch statistics). + self._opt_state = None # Optimizer state. + + # Build model. + self.model = hk.transform_with_state(self._get_model()) + + if mode == 'train': + self._initialize_training(init_rng) + if self.config.evaluation.interval > 0: + self._last_evaluation_scalars = {} + self._initialize_evaluation() + elif mode == 'eval': + self._initialize_evaluation() + elif mode == 'train_eval_multithreaded': + self._initialize_training(init_rng) + self._initialize_evaluation() + else: + raise ValueError(f'Unknown mode: "{mode}"') + + # _ _ + # | |_ _ __ __ _(_)_ __ + # | __| '__/ _` | | '_ \ + # | |_| | | (_| | | | | | + # \__|_| \__,_|_|_| |_| + # + + def step(self, global_step, rng, *unused_args, **unused_kwargs): + # Get next inputs. + supervised_inputs = next(self.supervised_train_input) + if self.extra_train_input is None: + extra_inputs = None + else: + extra_inputs = next(self.extra_train_input) + + # Perform step. + (self._params, self._avg_params, self._state, self._opt_state, + scalars) = self.train_fn( + params=self._params, + avg_params=self._avg_params, + state=self._state, + opt_state=self._opt_state, + global_step=global_step, + supervised_inputs=supervised_inputs, + extra_inputs=extra_inputs, + rng=rng) + scalars = jl_utils.get_first(scalars) + + # Save final checkpoint. + if self.config.save_final_checkpoint_as_npy and not self.config.dry_run: + global_step_value = jl_utils.get_first(global_step) + if global_step_value == FLAGS.config.get('training_steps', 1) - 1: + f_np = lambda x: np.array(jax.device_get(jl_utils.get_first(x))) + np_params = jax.tree_map(f_np, self._avg_params or self._params) + np_state = jax.tree_map(f_np, self._state) + path_npy = os.path.join(FLAGS.config.checkpoint_dir, 'checkpoint.npy') + with tf.io.gfile.GFile(path_npy, 'wb') as fp: + np.save(fp, (np_params, np_state)) + logging.info('Saved final checkpoint at %s', path_npy) + + # Run synchronous evaluation. + if self.config.evaluation.interval <= 0: + return scalars + + global_step_value = jl_utils.get_first(global_step) + if (global_step_value % self.config.evaluation.interval != 0 and + global_step_value != FLAGS.config.get('training_steps', 1) - 1): + return _merge_eval_scalars(scalars, self._last_evaluation_scalars) + logging.info('Running synchronous evaluation...') + eval_scalars = self.evaluate(global_step, rng) + f_list = lambda x: x.tolist() if isinstance(x, jnp.ndarray) else x + self._last_evaluation_scalars = jax.tree_map(f_list, eval_scalars) + logging.info('(eval) global_step: %d, %s', global_step_value, + self._last_evaluation_scalars) + return _merge_eval_scalars(scalars, self._last_evaluation_scalars) + + def _train_fn(self, params, avg_params, state, opt_state, global_step, + supervised_inputs, extra_inputs, rng): + scalars = {} + images, labels, target_probs = self.concatenate(supervised_inputs, + extra_inputs) + + # Apply CutMix. + if self.config.training.use_cutmix: + aug_rng, rng = jax.random.split(rng) + images, target_probs = utils.cutmix(aug_rng, images, target_probs, + split=self._repeat_batch) + + # Perform adversarial attack. + if self.config.training.attack is None: + adv_images = None + grad_fn = jax.grad(self._cross_entropy_loss_fn, has_aux=True) + else: + attack = self.config.training.attack + attack_rng, rng = jax.random.split(rng) + def logits_fn(x): + x = self.normalize_fn(x) + return self.model.apply(params, state, rng, x, is_training=False, + test_local_stats=True)[0] + if attack.expects_labels(): + if self.config.training.use_cutmix: + raise ValueError('Use `untargeted_kl_divergence` when using CutMix.') + target_labels = labels + else: + assert attack.expects_probabilities() + if self.config.training.use_cutmix: + # When using CutMix, regress the attack away from mixed labels. + target_labels = target_probs + else: + target_labels = jax.nn.softmax(logits_fn(images)) + adv_images = attack(logits_fn, attack_rng, images, target_labels) + grad_fn = jax.grad(self._trades_loss_fn, has_aux=True) + + # Compute loss and gradients. + scaled_grads, (state, loss_scalars) = grad_fn( + params, state, images, adv_images, labels, target_probs, rng) + grads = jax.lax.psum(scaled_grads, axis_name='i') + scalars.update(loss_scalars) + + updates, opt_state = self.optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + + # Stochastic weight averaging. + if self.config.training.swa_decay > 0: + avg_params = utils.ema_update(global_step, avg_params, params, + decay_rate=self.config.training.swa_decay) + + learning_rate = self.config.training.learning_rate(global_step) + scalars['learning_rate'] = learning_rate + scalars = jax.lax.pmean(scalars, axis_name='i') + return params, avg_params, state, opt_state, scalars + + def _cross_entropy_loss_fn(self, params, state, images, adv_images, labels, + target_probs, rng): + scalars = {} + images = self.normalize_fn(images) + logits, state = self.model.apply( + params, state, rng, images, is_training=True) + loss = jnp.mean(utils.cross_entropy(logits, target_probs)) + loss += self.config.training.weight_decay * utils.weight_decay(params) + if not self.config.training.use_cutmix: + scalars['top_1_acc'] = utils.accuracy(logits, labels) + scalars['train_loss'] = loss + scaled_loss = loss / jax.device_count() + return scaled_loss, (state, scalars) + + def _trades_loss_fn(self, params, state, images, adv_images, labels, + target_probs, rng, beta=6.): + """Calculates TRADES loss (https://arxiv.org/pdf/1901.08573).""" + scalars = {} + + def apply_fn(x, **norm_kwargs): + x = self.normalize_fn(x) + return self.model.apply(params, state, rng, x, **norm_kwargs) + + # Clean images. + clean_logits, _ = apply_fn(images, is_training=False, test_local_stats=True) + if not self.config.training.use_cutmix: + scalars['top_1_acc'] = utils.accuracy(clean_logits, labels) + + # Adversarial images. Update BN stats with adversarial images. + adv_logits, state = apply_fn(adv_images, is_training=True) + if not self.config.training.use_cutmix: + scalars['top_1_adv_acc'] = utils.accuracy(adv_logits, labels) + + # Compute loss. + clean_loss = jnp.mean(utils.cross_entropy(clean_logits, target_probs)) + adv_loss = jnp.mean(utils.kl_divergence(adv_logits, clean_logits)) + reg_loss = self.config.training.weight_decay * utils.weight_decay(params) + loss = clean_loss + beta * adv_loss + reg_loss + scalars['train_loss'] = loss + + scaled_loss = loss / jax.device_count() + return scaled_loss, (state, scalars) + + # _ + # _____ ____ _| | + # / _ \ \ / / _` | | + # | __/\ V / (_| | | + # \___| \_/ \__,_|_| + # + + def evaluate(self, global_step, rng, *unused_args, **unused_kwargs): + return self.eval_epoch(self._avg_params or self._params, self._state, rng) + + def eval_epoch(self, params, state, rng): + host_id = jax.host_id() + num_samples = 0 + batch_axis = 1 + summed_scalars = None + # Converting to numpy here allows us to reset the generator. + eval_input = tfds.as_numpy(self.eval_input) + for all_inputs in eval_input: + # The inputs are send to multiple workers. + inputs = jax.tree_map(lambda x: x[host_id], all_inputs) + num_samples += jax.device_count() * inputs['image'].shape[batch_axis] + scalars = jl_utils.get_first(self.eval_fn(params, state, inputs, rng)) + # Accumulate the sum of scalars for each step. + scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars) + if summed_scalars is None: + summed_scalars = scalars + else: + summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars) + mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars) + return mean_scalars + + def _eval_fn(self, params, state, inputs, rng): + images = inputs['image'] + labels = inputs['label'] + + attack_rng, rng = jax.random.split(rng) + def logits_fn(x): + x = self.normalize_fn(x) + return self.model.apply(params, state, rng, x, is_training=False, + test_local_stats=False)[0] + + # Clean accuracy. + logits = logits_fn(images) + predicted_label = jnp.argmax(logits, axis=-1) + correct = jnp.equal(predicted_label, labels).astype(jnp.float32) + scalars = {'top_1_acc': correct} + + # Adversarial accuracy. + if self.config.evaluation.attack is not None: + attack = self.config.evaluation.attack + assert attack.expects_labels() + adv_images = attack(logits_fn, attack_rng, images, labels) + adv_logits = logits_fn(adv_images) + predicted_label = jnp.argmax(adv_logits, axis=-1) + correct = jnp.equal(predicted_label, labels).astype(jnp.float32) + scalars['top_1_adv_acc'] = correct + + # Returned values will be summed and finally divided by num_samples. + return jax.lax.psum(scalars, axis_name='i') + + def _initialize_training(self, rng): + # Initialize inputs. + if self.config.emulated_workers > 0: + per_device_workers, ragged = divmod(self.config.emulated_workers, + jax.host_count()) + if ragged: + raise ValueError('Number of emulated workers must be divisible by the ' + 'number of physical workers `jax.host_count()`.') + self._repeat_batch = per_device_workers + else: + self._repeat_batch = 1 + self.supervised_train_input = jl_utils.py_prefetch( + self._supervised_train_dataset) + self.extra_train_input = jl_utils.py_prefetch( + self._extra_train_dataset) + self.normalize_fn = datasets.cifar10_normalize + + # Optimizer. + self.optimizer = utils.sgd_momentum(self.config.training.learning_rate, + momentum=.9, nesterov=True) + + # Initialize parameters. + if self._params is None: + logging.info('Initializing parameters randomly rather than restoring ' + 'from checkpoint.') + # Create inputs to initialize the network state. + images, _, _ = jax.pmap(self.concatenate)( + next(self.supervised_train_input), + next(self.extra_train_input)) + images = jax.pmap(self.normalize_fn)(images) + # Initialize weights and biases. + init_net = jax.pmap( + lambda *a: self.model.init(*a, is_training=True), axis_name='i') + init_rng = jl_utils.bcast_local_devices(rng) + self._params, self._state = init_net(init_rng, images) + # Setup weight averaging. + if self.config.training.swa_decay > 0: + self._avg_params = self._params + else: + self._avg_params = None + # Initialize optimizer state. + init_opt = jax.pmap(self.optimizer.init, axis_name='i') + self._opt_state = init_opt(self._params) + + # Initialize step function. + self.train_fn = jax.pmap(self._train_fn, axis_name='i', + donate_argnums=(0, 1, 2, 3)) + + def _initialize_evaluation(self): + load_fn = (datasets.load_dummy_data if self.config.dry_run else + datasets.load_cifar10) + self.eval_input = _dataset( + functools.partial(load_fn, subset='test'), + is_training=False, total_batch_size=self.config.evaluation.batch_size) + self.normalize_fn = datasets.cifar10_normalize + self.eval_fn = jax.pmap(self._eval_fn, axis_name='i') + + def _supervised_train_dataset(self) -> tfds.typing.Tree[np.ndarray]: + """Creates the training dataset.""" + load_fn = (datasets.load_dummy_data if self.config.dry_run else + datasets.load_cifar10) + load_fn = functools.partial(load_fn, subset='train', + repeat=self._repeat_batch) + ds = _dataset(load_fn, is_training=True, repeat=self._repeat_batch, + total_batch_size=self.config.training.batch_size, + ratio=self.config.training.supervised_batch_ratio) + return tfds.as_numpy(ds) + + def _extra_train_dataset(self) -> tfds.typing.Tree[np.ndarray]: + """Creates the training dataset.""" + load_fn = (datasets.load_dummy_data if self.config.dry_run else + datasets.load_extra) + load_fn = functools.partial( + load_fn, path_npz=self.config.training.extra_data_path) + ds = _dataset( + load_fn, is_training=True, repeat=self._repeat_batch, + total_batch_size=self.config.training.batch_size, + one_minus_ratio=self.config.training.supervised_batch_ratio) + return tfds.as_numpy(ds) + + def _get_model(self) -> Callable[..., chex.Array]: + config = self.config.model + def forward_fn(inputs, **norm_kwargs): + model_instance = config.constructor(**config.kwargs.to_dict()) + return model_instance(inputs, **norm_kwargs) + return forward_fn + + def concatenate( + self, + supervised_inputs: chex.ArrayTree, + extra_inputs: chex.ArrayTree + ) -> Tuple[chex.Array, chex.Array, chex.Array]: + """Concatenate inputs.""" + num_classes = self.config.num_classes + supervised_images = supervised_inputs['image'] + supervised_labels = supervised_inputs['label'] + if extra_inputs is None: + images = supervised_images + labels = supervised_labels + target_probs = hk.one_hot(labels, num_classes) + else: + extra_images = extra_inputs['image'] + images = jnp.concatenate([supervised_images, extra_images], axis=0) + extra_labels = extra_inputs['label'] + labels = jnp.concatenate([supervised_labels, extra_labels], axis=0) + supervised_one_hot_labels = hk.one_hot(supervised_labels, num_classes) + extra_one_hot_labels = hk.one_hot(extra_labels, num_classes) + if self.config.training.extra_label_smoothing > 0: + pos = 1. - self.config.training.extra_label_smoothing + neg = self.config.training.extra_label_smoothing / num_classes + extra_one_hot_labels = pos * extra_one_hot_labels + neg + target_probs = jnp.concatenate( + [supervised_one_hot_labels, extra_one_hot_labels], axis=0) + return images, labels, target_probs + + +def _dataset(load_fn, + is_training: bool, + total_batch_size: int, + ratio: Optional[float] = None, + one_minus_ratio: Optional[float] = None, + repeat: int = 1) -> tf.data.Dataset: + """Creates a dataset.""" + num_devices = jax.device_count() + per_device_batch_size, ragged = divmod(total_batch_size, num_devices) + if ragged: + raise ValueError( + f'Global batch size {total_batch_size} must be divisible by the ' + f'total number of devices {num_devices}') + if repeat > 1: + if per_device_batch_size % repeat: + raise ValueError( + f'Per device batch size {per_device_batch_size} must be divisible ' + f'by the number of repeated batches {repeat}') + per_device_batch_size //= repeat + if ratio is None and one_minus_ratio is None: + pass # Use full batch size. + elif one_minus_ratio is None: + per_device_batch_size = max( + 1, min(round(per_device_batch_size * ratio), + per_device_batch_size - 1)) + elif ratio is None: + batch_size = max(1, min(round(per_device_batch_size * one_minus_ratio), + per_device_batch_size - 1)) + per_device_batch_size = per_device_batch_size - batch_size + else: + raise ValueError('Only one of `ratio` or `one_minus_ratio` must be ' + 'specified') + if repeat > 1: + per_device_batch_size *= repeat + # When testing, we need to batch data across all devices (not just local + # devices). + num_local_devices = jax.local_device_count() + if is_training: + batch_sizes = [num_local_devices, per_device_batch_size] + else: + num_hosts = jax.host_count() + assert num_hosts * num_local_devices == num_devices + batch_sizes = [num_hosts, num_local_devices, per_device_batch_size] + return load_fn(batch_sizes, is_training=is_training) + + +def _merge_eval_scalars(a, b): + if b is None: + return a + for k, v in b.items(): + a['eval_' + k] = v + return a diff --git a/adversarial_robustness/jax/experiment_test.py b/adversarial_robustness/jax/experiment_test.py new file mode 100644 index 0000000..c7a767b --- /dev/null +++ b/adversarial_robustness/jax/experiment_test.py @@ -0,0 +1,46 @@ +# Copyright 2021 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quick script to test that experiment can import and run.""" + +from absl import app +import jax +import jax.numpy as jnp +from jaxline import utils as jl_utils +from adversarial_robustness.jax import experiment + + +@jl_utils.disable_pmap_jit +def test_experiment(unused_argv): + """Tests the main experiment.""" + config = experiment.get_config() + exp_config = config.experiment_kwargs.config + exp_config.dry_run = True + exp_config.emulated_workers = 0 + exp_config.training.batch_size = 2 + exp_config.evaluation.batch_size = 2 + exp_config.model.kwargs.depth = 10 + exp_config.model.kwargs.width = 1 + + xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0)) + bcast = jax.pmap(lambda x: x) + global_step = bcast(jnp.zeros(jax.local_device_count())) + rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count())) + print('Taking a single experiment step for test purposes!') + result = xp.step(global_step, rng) + print(f'Step successfully taken, resulting metrics are {result}') + + +if __name__ == '__main__': + app.run(test_experiment) diff --git a/adversarial_robustness/jax/model_zoo.py b/adversarial_robustness/jax/model_zoo.py index af54646..0fdf9f0 100644 --- a/adversarial_robustness/jax/model_zoo.py +++ b/adversarial_robustness/jax/model_zoo.py @@ -14,19 +14,14 @@ """WideResNet implementation in JAX using Haiku.""" -from typing import Any, Mapping, Optional, Text +from typing import Any, Dict, Optional +import chex import haiku as hk import jax import jax.numpy as jnp -CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) -CIFAR10_STD = (0.2471, 0.2435, 0.2616) -CIFAR100_MEAN = (0.5071, 0.4865, 0.4409) -CIFAR100_STD = (0.2673, 0.2564, 0.2762) - - class _WideResNetBlock(hk.Module): """Block of a WideResNet.""" @@ -89,16 +84,16 @@ class WideResNet(hk.Module): num_classes: int = 10, depth: int = 28, width: int = 10, - activation: Text = 'relu', - norm_args: Optional[Mapping[Text, Any]] = None, - name: Optional[Text] = None): + activation: str = 'relu', + norm_args: Optional[Dict[str, Any]] = None, + name: Optional[str] = None): super(WideResNet, self).__init__(name=name) if (depth - 4) % 6 != 0: raise ValueError('depth should be 6n+4.') self._activation = getattr(jax.nn, activation) if norm_args is None: norm_args = { - 'create_offset': False, + 'create_offset': True, 'create_scale': True, 'decay_rate': .99, } @@ -113,6 +108,7 @@ class WideResNet(hk.Module): **norm_args) self._linear = hk.Linear( num_classes, + w_init=jnp.zeros, name='logits') blocks_per_layer = (depth - 4) // 6 @@ -132,7 +128,7 @@ class WideResNet(hk.Module): name='resnet_lay_{}_block_{}'.format(layer_num, i))) self._blocks.append(blocks_of_layer) - def __call__(self, inputs, **norm_kwargs): + def __call__(self, inputs: chex.Array, **norm_kwargs) -> chex.Array: net = inputs net = self._conv(net) @@ -145,21 +141,3 @@ class WideResNet(hk.Module): net = jnp.mean(net, axis=[1, 2]) return self._linear(net) - - -def mnist_normalize(image: jnp.array) -> jnp.array: - image = jnp.pad(image, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant', - constant_values=0) - return (image - .5) * 2. - - -def cifar10_normalize(image: jnp.array) -> jnp.array: - means = jnp.array(CIFAR10_MEAN, dtype=image.dtype) - stds = jnp.array(CIFAR10_STD, dtype=image.dtype) - return (image - means) / stds - - -def cifar100_normalize(image: jnp.array) -> jnp.array: - means = jnp.array(CIFAR100_MEAN, dtype=image.dtype) - stds = jnp.array(CIFAR100_STD, dtype=image.dtype) - return (image - means) / stds diff --git a/adversarial_robustness/jax/train.py b/adversarial_robustness/jax/train.py new file mode 100644 index 0000000..26e45d0 --- /dev/null +++ b/adversarial_robustness/jax/train.py @@ -0,0 +1,32 @@ +# Copyright 2021 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runs a JAXline experiment to perform robust adversarial training.""" + +import functools + +from absl import app +from absl import flags +from jaxline import platform +import tensorflow.compat.v2 as tf + +from adversarial_robustness.jax import experiment + +if __name__ == '__main__': + flags.mark_flag_as_required('config') + try: + tf.config.set_visible_devices([], 'GPU') # Prevent TF from using the GPU. + except tf.errors.NotFoundError: + pass + app.run(functools.partial(platform.main, experiment.Experiment)) diff --git a/adversarial_robustness/jax/utils.py b/adversarial_robustness/jax/utils.py new file mode 100644 index 0000000..8f0ba13 --- /dev/null +++ b/adversarial_robustness/jax/utils.py @@ -0,0 +1,197 @@ +# Copyright 2021 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions.""" + +import re +from typing import Optional, Sequence, Tuple + +import chex +import einops +import haiku as hk +import jax +import jax.numpy as jnp +import optax + + +def get_cosine_schedule( + max_learning_rate: float, + total_steps: int, + warmup_steps: int = 0) -> optax.Schedule: + """Builds a cosine decay schedule with initial warm-up.""" + if total_steps < warmup_steps: + return optax.linear_schedule(init_value=0., end_value=max_learning_rate, + transition_steps=warmup_steps) + return optax.join_schedules([ + optax.linear_schedule(init_value=0., end_value=max_learning_rate, + transition_steps=warmup_steps), + optax.cosine_decay_schedule(init_value=max_learning_rate, + decay_steps=total_steps - warmup_steps), + ], [warmup_steps]) + + +def sgd_momentum(learning_rate_fn: optax.Schedule, + momentum: float = 0., + nesterov: bool = False) -> optax.GradientTransformation: + return optax.chain( + optax.trace(decay=momentum, nesterov=nesterov), + optax.scale_by_schedule(learning_rate_fn), + optax.scale(-1.)) + + +def cross_entropy(logits: chex.Array, labels: chex.Array) -> chex.Array: + return -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1) + + +def kl_divergence(q_logits: chex.Array, + p_logits: chex.Array) -> chex.Array: + """Compute the KL divergence.""" + p_probs = jax.nn.softmax(p_logits) + return cross_entropy(q_logits, p_probs) - cross_entropy(p_logits, p_probs) + + +def accuracy(logits: chex.Array, labels: chex.Array) -> chex.Array: + predicted_label = jnp.argmax(logits, axis=-1) + correct = jnp.equal(predicted_label, labels).astype(jnp.float32) + return jnp.sum(correct, axis=0) / logits.shape[0] + + +def weight_decay(params: hk.Params, + regex_match: Optional[Sequence[str]] = None, + regex_ignore: Optional[Sequence[str]] = None) -> chex.Array: + """Computes the L2 regularization loss.""" + if regex_match is None: + regex_match = ('.*w$', '.*b$') + if regex_ignore is None: + regex_ignore = ('.*batchnorm.*',) + l2_norm = 0. + for mod_name, mod_params in params.items(): + for param_name, param in mod_params.items(): + name = '/'.join([mod_name, param_name]) + if (regex_match and + all(not re.match(regex, name) for regex in regex_match)): + continue + if (regex_ignore and + any(re.match(regex, name) for regex in regex_ignore)): + continue + l2_norm += jnp.sum(jnp.square(param)) + return .5 * l2_norm + + +def ema_update(step: chex.Array, + avg_params: chex.ArrayTree, + new_params: chex.ArrayTree, + decay_rate: float = 0.99, + warmup_steps: int = 0, + dynamic_decay: bool = True) -> chex.ArrayTree: + """Applies an exponential moving average.""" + factor = (step >= warmup_steps).astype(jnp.float32) + if dynamic_decay: + # Uses TF-style EMA. + delta = step - warmup_steps + decay = jnp.minimum(decay_rate, (1. + delta) / (10. + delta)) + else: + decay = decay_rate + decay *= factor + def _weighted_average(p1, p2): + d = decay.astype(p1.dtype) + return (1 - d) * p1 + d * p2 + return jax.tree_multimap(_weighted_average, new_params, avg_params) + + +def cutmix(rng: chex.PRNGKey, + images: chex.Array, + labels: chex.Array, + alpha: float = 1., + beta: float = 1., + split: int = 1) -> Tuple[chex.Array, chex.Array]: + """Composing two images by inserting a patch into another image.""" + batch_size, height, width, _ = images.shape + if split > 1: + split_batch_size = batch_size // split + + # Masking bounding box. + box_rng, lam_rng, rng = jax.random.split(rng, num=3) + lam = jax.random.beta(lam_rng, a=alpha, b=beta, shape=()) + cut_rat = jnp.sqrt(1. - lam) + cut_w = jnp.array(width * cut_rat, dtype=jnp.int32) + cut_h = jnp.array(height * cut_rat, dtype=jnp.int32) + box_coords = _random_box(box_rng, height, width, cut_h, cut_w) + # Adjust lambda. + lam = 1. - (box_coords[2] * box_coords[3] / (height * width)) + idx = jax.random.permutation(rng, split_batch_size) + def _cutmix(x, y): + images_a = x + images_b = x[idx, :, :, :] + y = lam * y + (1. - lam) * y[idx, :] + x = _compose_two_images(images_a, images_b, box_coords) + return x, y + + if split <= 1: + return _cutmix(images, labels) + + # Apply CutMix separately on each sub-batch. This reverses the effect of + # `repeat` in datasets. + images = einops.rearrange(images, '(b1 b2) ... -> b1 b2 ...', b2=split) + labels = einops.rearrange(labels, '(b1 b2) ... -> b1 b2 ...', b2=split) + images, labels = jax.vmap(_cutmix, in_axes=1, out_axes=1)(images, labels) + images = einops.rearrange(images, 'b1 b2 ... -> (b1 b2) ...', b2=split) + labels = einops.rearrange(labels, 'b1 b2 ... -> (b1 b2) ...', b2=split) + return images, labels + + +def _random_box(rng: chex.PRNGKey, + height: chex.Numeric, + width: chex.Numeric, + cut_h: chex.Array, + cut_w: chex.Array) -> chex.Array: + """Sample a random box of shape [cut_h, cut_w].""" + height_rng, width_rng = jax.random.split(rng) + minval_h = 0 + minval_w = 0 + maxval_h = height + maxval_w = width + i = jax.random.randint( + height_rng, shape=(), minval=minval_h, maxval=maxval_h, dtype=jnp.int32) + j = jax.random.randint( + width_rng, shape=(), minval=minval_w, maxval=maxval_w, dtype=jnp.int32) + bby1 = jnp.clip(i - cut_h // 2, 0, height) + bbx1 = jnp.clip(j - cut_w // 2, 0, width) + h = jnp.clip(i + cut_h // 2, 0, height) - bby1 + w = jnp.clip(j + cut_w // 2, 0, width) - bbx1 + return jnp.array([bby1, bbx1, h, w]) + + +def _compose_two_images(images: chex.Array, + image_permutation: chex.Array, + bbox: chex.Array) -> chex.Array: + """Inserting the second minibatch into the first at the target locations.""" + def _single_compose_two_images(image1, image2): + height, width, _ = image1.shape + mask = _window_mask(bbox, (height, width)) + return image1 * (1. - mask) + image2 * mask + return jax.vmap(_single_compose_two_images)(images, image_permutation) + + +def _window_mask(destination_box: chex.Array, + size: Tuple[int, int]) -> jnp.ndarray: + """Mask a part of the image.""" + height_offset, width_offset, h, w = destination_box + h_range = jnp.reshape(jnp.arange(size[0]), [size[0], 1, 1]) + w_range = jnp.reshape(jnp.arange(size[1]), [1, size[1], 1]) + return jnp.logical_and( + jnp.logical_and(height_offset <= h_range, + h_range < height_offset + h), + jnp.logical_and(width_offset <= w_range, + w_range < width_offset + w)).astype(jnp.float32) diff --git a/adversarial_robustness/pytorch/README.md b/adversarial_robustness/pytorch/README.md new file mode 100644 index 0000000..4f0afee --- /dev/null +++ b/adversarial_robustness/pytorch/README.md @@ -0,0 +1,28 @@ +# PyTorch evaluation + +We provide PyTorch evaluation code for convenience. If you developed a version +of our training pipeline for PyTorch, please let us know as we will link it from +here. + +Here are known PyTorch implementations of our training pipeline: + +* https://github.com/imrahulr/adversarial_robustness_pytorch (by Rahul Rade) + +Here are few consideration when reproducing our training pipeline in PyTorch. +As opposed to the [RST](https://github.com/yaircarmon/semisup-adv) code +(provided by Carmon et al.): + +* We set the batch normalization decay to 0.99 (instead of 0.9). +* We do not apply weight decay (l2 regularization) to the batch normalization + scale and offset +* We use Haiku's default initialization for all layers (except the last, which + is initialized with zeros). +* The PGD attack used during training uniformly initializes the initial solution + over the l-p norm ball. +* We run the attack over the local batch statistics (rather than the evaluation + statistics). +* We update batch normalization statistics from adversarial examples only ( + rather than both clean and adversarial examples). +* We use 10 epochs warm-up to our learning schedule. + + diff --git a/adversarial_robustness/requirements.txt b/adversarial_robustness/requirements.txt index 160c6fc..8e3dcd5 100644 --- a/adversarial_robustness/requirements.txt +++ b/adversarial_robustness/requirements.txt @@ -1,51 +1,64 @@ -absl-py==0.10.0 +# Direct dependencies. +absl-py==0.12.0 +chex==0.0.7 +dm-haiku==0.0.4 +einops==0.3.0 +jax==0.2.16 +jaxlib==0.1.68 +jaxline==0.0.3 +ml-collections==0.1.0 +numpy==1.19.5 +optax==0.0.8 +tensorflow==2.5.0 +tensorflow-datasets==4.3.0 +torch==1.9.0 +torchvision==0.10.0 +tqdm==4.61.1 +# Transitive dependencies. astunparse==1.6.3 -attrs==20.3.0 -cachetools==4.1.1 -certifi==2020.11.8 -chardet==3.0.4 -dataclasses==0.6 -dill==0.3.3 -dm-haiku==0.0.3 +attrs==21.2.0 +cachetools==4.2.2 +certifi==2021.5.30 +chardet==4.0.0 +contextlib2==21.6.0 +dill==0.3.4 +dm-tree==0.1.6 flatbuffers==1.12 future==0.18.2 -gast==0.3.3 -google-auth==1.23.0 -google-auth-oauthlib==0.4.2 +gast==0.4.0 +google-auth==1.32.0 +google-auth-oauthlib==0.4.4 google-pasta==0.2.0 -googleapis-common-protos==1.52.0 -grpcio==1.33.2 -h5py==2.10.0 +googleapis-common-protos==1.53.0 +grpcio==1.34.1 +h5py==3.1.0 idna==2.10 -importlib-resources==3.3.0 -jax==0.2.6 -jaxlib==0.1.57 +keras-nightly==2.5.0.dev2021032900 Keras-Preprocessing==1.1.2 -Markdown==3.3.3 -numpy==1.18.5 -oauthlib==3.1.0 +Markdown==3.3.4 +oauthlib==3.1.1 opt-einsum==3.3.0 -Pillow==8.0.1 +Pillow==8.2.0 +pkg-resources==0.0.0 promise==2.3 -protobuf==3.14.0 +protobuf==3.17.3 pyasn1==0.4.8 pyasn1-modules==0.2.8 -requests==2.25.0 +PyYAML==5.4.1 +requests==2.25.1 requests-oauthlib==1.3.0 -rsa==4.6 -scipy==1.5.4 +rsa==4.7.2 +scipy==1.7.0 six==1.15.0 -tensorboard==2.4.0 -tensorboard-plugin-wit==1.7.0 -tensorflow==2.3.1 -tensorflow-datasets==4.1.0 -tensorflow-estimator==2.3.0 -tensorflow-metadata==0.25.0 +tabulate==0.8.9 +tensorboard==2.5.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorflow-estimator==2.5.0 +tensorflow-metadata==1.1.0 termcolor==1.1.0 -torch==1.7.0 -torchvision==0.8.1 -tqdm==4.53.0 +toolz==0.11.1 typing-extensions==3.7.4.3 -urllib3==1.26.2 -Werkzeug==1.0.1 +urllib3==1.26.6 +Werkzeug==2.0.1 wrapt==1.12.1 diff --git a/adversarial_robustness/run.sh b/adversarial_robustness/run.sh index 6f11347..7fd1741 100755 --- a/adversarial_robustness/run.sh +++ b/adversarial_robustness/run.sh @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +set -euf -o pipefail # Stop at failure. + python3 -m venv /tmp/adversarial_robustness_venv source /tmp/adversarial_robustness_venv/bin/activate pip install -U pip @@ -34,3 +36,9 @@ python3 -m adversarial_robustness.pytorch.eval \ --batch_size=1 \ --num_batches=1 \ --nouse_cuda + +# We disable pmap/jit to avoid compilation during testing. Since the +# test only runs a single step, it would not benefit from such a compilation +# anyways. +python3 -m adversarial_robustness.jax.experiment_test \ + --jaxline_disable_pmap_jit=True diff --git a/rapid_task_solving/README.md b/rapid_task_solving/README.md new file mode 100644 index 0000000..98019da --- /dev/null +++ b/rapid_task_solving/README.md @@ -0,0 +1,123 @@ +# One-shot StreetLearn and Memory & Planning Game + +This repository contains code for the environments used in the paper +["Rapid Task-Solving in Novel Environments"](https://arxiv.org/abs/2006.03662) +by Sam Ritter, Ryan Faulkner, Laurent Sartran, Adam Santoro, Matt Botvinick and +David Raposo. It was published as a conference paper at ICLR 2021. + +To cite this work: + +``` +@inproceedings{ + ritter2021rapid, + title={Rapid Task-Solving in Novel Environments}, + author={Samuel Ritter and Ryan Faulkner and Laurent Sartran and Adam Santoro + and Matthew Botvinick and David Raposo}, + booktitle={International Conference on Learning Representations}, + year={2021}, + url={https://openreview.net/forum?id=F-mvpFpn_0q} +} +``` + +### Memory&Planning Game + +The _Memory&Planning Game_ is a simple variation of the well-known _Memory +Game_, wherein players must remember the locations of cards in a grid. This +variation extends the challenge to require planning as well as remembering. + +In the _Memory&Planning Game_, the agent occupies an environment consisting +of a grid of symbols (e.g. 4x4). The observation consists of two symbols — one +which corresponds to the agent's current location, and another that corresponds +to the "goal" the agent is tasked with navigating to. The agent can not see its +relative location with respect to other symbols in the grid. At each step the +agent selects one of 5 possible actions: _move left_, _move right_, _move up_, +_move down_, and _collect_. If the agent chooses the "collect" action when its +current location symbol matches the goal symbol, a reward of +1 is received. +Otherwise, the agent receives a reward of 0. At the beginning of each episode, a +new set of symbols is sampled, effectively inducing a new transition function. +The agent is allowed a fixed number of steps (e.g. 100) per episode to "collect" +as many goals as possible. Each time the agent collects a goal — which +corresponds to completing a task —, a new goal is sampled in and the transition +function stays fixed. + +#### Example + +The following code snipped shows an example of how to load the environment, +start a new episode and take a few random steps. A plot representing the current +state of the environment is displayed for each step. + +``` +import memory_planning_game + +env = memory_planning_game.MemoryPlanningGame(4, seed=123) +_ = env.reset() +for _ in range(5): + timestep = env.take_random_action() + fig, ax = env.draw_maze() +``` + +![Memory & Planning Game environment](images/example_mpg.png) + +### One-Shot StreetLearn + +The _One-Shot StreetLearn_ is a domain wherein environments are sampled as +neighborhoods from the _StreetLearn_ dataset (Mirowski et al., 2019). Tasks are +then sampled by selecting a position and orientation that the agent must +navigate to from its current location. + +In _One-Shot StreetLearn_, the agent's observations consist of a representation +of the current state and a representation of the goal state. The agent receives +no other information from the environment. The available actions are _turn +right_, which orients the agent clockwise toward the next available direction of +motion from its current location; _turn left_, which does the same in the other +direction; and _move forward_, which moves the agent along the direction it is +facing to the next available location. In each episode, we sample a new +neighborhood with 5 intersections from one of 12 cities. To reduce the +exploration problem while keeping the planning difficulty constant, we removed +all the locations that were not intersections (i.e. that corresponded to +degree-2 nodes in the connectivity graph of the sampled neighbourhood). + +Note: In the paper we used images to represent each location, taken from the +_StreetLearn_ dataset. In this codebase, to simplify the public release process, +we replace the images with one-hot vectors, to represent each location. + +Every time the agent reaches a goal, a new starting- and goal-state pair is +sampled, initiating a new task, until the fixed episode step limit is reached +(e.g. 200 steps). A step that takes the agent to the goal state results in a +reward of +1. Any other step results in a reward of 0. + +#### City graph datasets + +In this [link] +(https://console.cloud.google.com/storage/browser/one_shot_streetlearn_graphs) +you can download the datasets containing the connectivity graphs of different +cities. You will need at least one of these datasets in order to build the +environment. + +#### Example + +The following code snipped shows an example of how to load the environment for +an example city (in this case, Madrid), start a new episode and take a few +random steps. A plot representing the current state of the environment is +displayed for each step. + +``` +import numpy as np +import one_shot_streetlearn + +region = 'madrid' +sl_config = { + 'dataset_path': f'./datasets/{region}_full_graph.gexf', + 'max_episode_steps': 200, + 'num_junctions': 5, +} +env = one_shot_streetlearn.OneShotStreetLearn(**sl_config) + +_ = env.reset() +for _ in range(4): + rnd_action = np.random.randint(env.NUM_ACTIONS) + timestep = env.step(rnd_action) + fig, ax = env.draw_subgraph() +``` + +![One-shot StreetLearn environment](images/example_osl.png) diff --git a/rapid_task_solving/images/example_mpg.png b/rapid_task_solving/images/example_mpg.png new file mode 100644 index 0000000000000000000000000000000000000000..2affd8eba4c137cb8d2a0a0919672ac4d587798d GIT binary patch literal 142077 zcmeFZcRZY3*FTC7Er}XLjUKm7qK_eZlu07d8NI|XdN+a~qLYLeOprv}L@&_>(TPrs z-X<{^Wt72aXWY-*&hI?$`RDw5#%E^lYhP>ax~{eNTHkeve{7_4lZKs!goNa#p04Io z5)yI)5)x85s_Vp((3g~r#Dv^M-B6u`q%wu}%#MP1eb+(vsUZnTh!6?M%U2{M1Y*%k z3<-(96bZ?OEeVNI76}QfS5D&-72=IoPG)+}hK3}65%W|e6r_wK*N8b%;vWeqJIPgT z#2m>ZQjY)1KPA2QA2ehnBynyeuVNg~L@>xzp+ zB}j>w^>FgDyBp*I@$^*+0`UDqp+wAI6-)Bn{fETQ9l&R1`1q~{)W_+roP?Bw6rU>1 z-Me>Hd>oyXo@#3U2Rrc!z~}1c=cObm85kHS5hyDG^>LAuR#a4!l#-E@kr5|Si2DY6 z`q>4Ed;0SKtCIh$N7KpI!N<+Z&kgE%_o`kyd#Jx3fRFF0q5t~)tDk;u&i~z#r|*9Z zi#R~Zt2>g?5>k@?j?F2^?SEmry7Mo#f5!E%=2WgKQv&(8IT1U2)h$(Nm46!ezm@%W z_x?ry)X5hL@xLN}?&;>ID)S#K|3~rv;JRvC>9Jdo6U1E8&BMvl_o}4_va%|Y|Mt$m z2VmJSr^?%&@KUp=P5U7u_mz{&t)yV!~`5#69L9HTrwHW_nasIXP|CAD! zUX?~g^1sALm4-2o){TTjjYLmV{aFwxcAoOPht6}#U8c4L+BeKC>W|+{ysRNfVR`X7 z^);y`BSUoLJ@R%o+Sd;?U#H!rp-!b&*Zmx;=9;2LMNhhJh$em7X;Zm+Vz?J!5P)wu znu9Hbx+9^B^Ul9YPs)|5kER1#%Z&y;J>q>##Y`5X_V1E5Px7*=)33K^A|HOFzeoOW z6Lzv%+ox~-U+N%0N(||8<~h;#qHR{cp3N#w(@vk;UNp%L%uCH%b5MBykdy{{Kn%|2hd) z-|qfD#SB0jSq?iaNm@R~&q56oi(0zfKUaQ9E_uRRLCHrSy6z!3b)%>iNqMW>%;tLJ z`Iji<&shoMVo&kwWSY;rzq3`{llVt+QKk%RUv9?h-!8;IPx|T&6(FlDTo^)`l|dcB zS|#)-R{DFh)2yGG8yre0G6uWqzq0-x0>|;Dt@UEF)#&0k_`=EOU5S;`-nT`> zZE}b7cc|}T#%mQE$FGuo;ZonJ^~Uv*P2`)eEr# zx)VtHES|{D!i72?tPRa+$mQuGP;vjKtT4~zagXpK-eXa&q5DCZv}xEwolshd@}MqJ zAy%89W%SO)_KpznI2x#M@&|hH$7#_qe#;#!$L%%&V_?uVum_ILs+wp;RD(Tl%tT!5 zBM??z?GzaqMoMlz)$n=9-ZkLvwT!C?$x~UqGn|#<|I0iBO85fJ4jVoS8%EEzp3lm6 z$*oZCD$)%PBo%FcZk9_^)Wz*vMcE(jt@~jU3Hk}_F023F% zNcE%-WOT8$c;Vr*4_^%1>j0j0Y}t^Bt@eAfziX$6pV>*E5^y^!X(g2ODX!&=WB;(! zf7c2f6WaZKu}I`IQaxLy16>e@XB?&g@ zZZpoPGv$+uto~2}KHtbVKU$2~D+Ce>#l5jT0B1=JWxd>skh1(qaA4jAmuUF>K$#Uj z5LMP{+7r{pcbDFJq1Js0eYQUg1yvfRGDA~?UUYJtL~ zHk6ye@sW*>sUZAhe>5U+tP#>Ds23$gGRMj4ONfZ>5R4m>YEHbL&o8C z6d6@Tcb}bxQ$$9b^+e?4glrUN)VH2*WpFI)#(H=^8T%6S`7*TiC{%c=W!qm^IRNWQ zo0|SK1K`9RdZ$D-;2Q&HeuLF~O>sDHMK4uW7?^kDg840<)WHqd53mFO#4S3}NVrl-ep8YC4ce+~S(3)Ck!8lSv0uMmMB5n zAyZ98@87@AMp~kPtBJBgnN~h}0~9<5A1Q0)RqSOUYL+qE)%EDmGObH%2YtD^^dvyO zU|FSn4Wv@*zDm9e5cX#GU3u8&7K*~~E@jL!b7pePap>adLZ2x3q`%K{VedUlNwUgb zZ{7EMDW`jxt#b>#M{=3 zw#IQj4-g!+MC{RW4ElE$7{M%(d6Dq7ARcGA|ou$2+73-Z4Q#hN=C1DB? zWc_sM-Qy6AF3Wp1xp?^AUomY9uSm|1G>4d!;C&h29d-%>8OKap36mKKlE@&RIOVzg zE_Tr1bb{7xCKPa zjh>l8;q9$mblx>jt-~+zE)Lxv< zt*H(Qp}z46AB=NCr3*JTs)X$FJ}s^r1-@W<1kG$1=AN$X;*f)Q$F(dhU6*&oPg4Ct zJhbJYPc4cE9ZdZw(8>rUY9}Uy^v6p`Z`8FGe85DKr(Ypn_UYn|UKk-R;yA93ZJD&T zzP|Go)UW^m(gVnKsRDN+=?qlQ*Fo}mVqv}1#Q9-ggFLmy4TIvsGIdoyC?pV{n2}wz zhX7R)4=AFc-hrT9RQYV$>Pidw;^yTx@|bm>eya#s3G8IhxxcFN`>ioMh)hW1cB**R zGZAvUKRF}Dd3i|{&d@86$vEkvy8j`gN9B(+6EJ2Ur_$B=a~!znOP^I+`J711LT*qa z5DaOzVSEwgLb?E;J%lNHGCT=On={u`+-cqq&{(Rq2TG>CwU9~^$bL0{G8E|9_X%xs zT8}*ZmXQHcZ7>{yp><%Nx}>QqH7CY3K+D!gQcE|XK^>i)Mpc_gE#}{SAehhUFd0ZZ z3(KO;C3crAUIW{2A2+xX(|23B5~FiD{UIhaKyc{03*#5A(EV>|tSuO=CY|H&J%F<5 z^?2btl+a5d`*=5kl-qr~P~xUMi?!$oom=Jhc=9aoE2!ZU1#370Mp#e6i$jA4e$~h5 zH^(?PnMcq>ifg{y9(4kO?XIzKi+`~+sQA8yGz_Z{H(Xp`>=2VeZh1i66K=DxcX84F zMakv&xLVwsAL`}WYL&KG?65ql;SMacOJGr9_LWTF(eKy_c~N@ z*(oSwEQ1}veaozrd9hP2;BPiW%!GxMn-^pQVIejBET}&7yTJSQle~CgpQvIArbv*C zXz`)I9Hk@0>R-f@e1!}d8Q*Mw1`VhNpeKLA6h@&(Iw&Ba=NPrsKT)F1Tj6pyHKq;g zom(`<_>d?4U`w2*FC#T%e03cZ5dyagzBU$x+(h1VtkklE0v^eU>sshCl$>$&7@(S^ z?b|!9kus_1ee2`1baW`wzZbryUnOb>BkG+g#-%{yer|^4{D#R^QYPTq@`?{wTWwoN z(t(x6+F4e}S4h_{Z|lQxv5t`|DI)i^e+ulNFtD|6 zg@>v_1R6K@N3PyYF4<*~V!XTOZHm&y*ZIqD{bx78*q@f~y1N%(#9g!*=L88VsX8Ir zTz!8WnmqtgHnE~F3F(cFqnl1-#qyxA-_!oF_jbRbZ`66*o%H6dcu8f($aY#Oj9yWdeT($$}h3FfJXSs^v4$(s# z5_#PpG#`Bw_cRe9XSqX4FxW9fs|4dG=G>2kELeoZ#Io;=1_!u6Nn@~}$ z>CigrTvJu?JY5OF_jz7Hd`rdkjHd=oGpuBxqGah0nM9}&jw4K-IBjeimBm3t>OF~g zppflM5c}6GF$-+{$jYQMaJ-UjmAc7sqM%h?V6@WKzN#qbg^RsQWEWg`$WVd{ewJ_k zncmEgyz@@^A0YAgmy6AUA=*|tJGk)`b-KC;@#}BoS+>@ZsDhibyU#51T%J0m>m22V zn#B&8+39SDijKW92>1RO_}xHd8fSJVA&AaWq=)=1?tJ2z8qdv6c#7gZ z9ZxCXsT^W@=r7QM^=lAld{7FEX_MpkA{tk%;vqo}3E^hX4xB5xXhU*i(kuGPP`UM0 z276*}QB7rKz8U3q42!6QQ-}`3tQcfQ%);abM}f&&!W)v=rJeN$5x}2ov0(|)EL&X+ za4nV?@e0L)E9KbA5CyB%2;`9^KJ=Hli{UM0Y{fi~$G)17aI+v^ zqKPZvxhPJG(;*jngXpBGl2-6kx=m{EP=vHLQOgmE=iCeWu<82eJ#z1%Uu=kNYdutx5Ee2c z$C<=Y{=f<@_OkD>b0M5GHt0N@bRDoN(%m^fEy0-8L4WBDUb0gPRni%?lgs%QIH9_GwLP3ytXSa^yr>Ojh>iC*_sHcMs;ezci~-WllC-* zQ^W65Ln>NdpHg80DR}wD`GN4Jv>qUKB9DALHGNP)%83P&eJTn6i0x(%ymGaPRU4>El zLb2wLJEu1YG9P}z1vDTTX1B_xf_W>7?$S3g>KGh{wYGr(;T-J|j@2I@4o@0mr|C~| znsa}WG7S1eLZDTn*Ti7fii*E2zvO2*F3QnBd;SNRl)Rn~VeE9F(v1sqUQ35FbCZ-3wU=p9a-bQujEK*!_B|-^*T+ZZ zs@I)&j%E`jW9+C4qDn6rZAUd_Hr~1L$nK>TeqQgv+rM@*J$+~kR&OI8+NlrcS-vue z>+EovAMHg|S^8jN5;Ub`rE2fn)(eC3!r&w!{xcakeF@?@#!b&-;gQ~XIK>qE+CuzG zGdl-L{DB3Tt9mD2kG{=bdSBW4PF1*0r#wSg);qvSW-F{-{Q608He%|vdT_MBeZiu# zlbZ^Su>OwNvomcmrcG29wIj9kyHU9SmFBfMw73dt)ex~D;DC6vU7tZPrh+h$ z9gIADMWM|jh4-!V9R>te`4aN&@r>b4Osc&bV4qOT?@)`C!7f&Cb7-2TitLn3&mPzn zrM`zm``!5hW1^d^)LrN0R6f`eqnp@8N4^XVrp*u!j{aF6^>YjL^;jtLZN!Td;2Pwa z<(UFVWR{i@-dZkiAjh-nf7JGv4X&3<{ee2CTx`~ozLKW@YZ!W9TtQnz{W)6$1v6;R zU`6G`0U{K%q3?X|S^kK(qMWbIChW#=!Rz>}6g`@+Q6FASfJjtCD%jSGDMFc_-dHBrTb}mDIF%Av%!s*CA1ZEGmNgypB9y4#W|& z^VY!0+!xtjMdr*3i)GA59&1MQ`nRce+aUz9#jDO@9#UU3c1mia>)l^S#7$Ug4Ex4?nt zV#)+NhA=CR2*9%ES=F}{M0h_Bt5tiG$0mGq5L-=2|Hq5vxfuR+>xTL6om15(L8a&2 zcZq&!j_B1dPZ>~bNpy>f_?nlMR(_*e-@vpYS5|qbmidx>4jDf`nlp#?MdAUbFvatT z`(if^bc<-1*ZI7dtfKn~SIXkEp%SAa!)m035tJI2I+0-{*iI%8gmmN+_G9#|7nk4K zB7}RVhH9%n@z#GRoM)vXLdss~|QIYjgnJw+k?0s;_o3?)pWJ%bu zPwUdb3p6-@(hWm-8kLFW`^-aNQ(wgC?jm)L@OAUHqqlQW@j=~4F{PQ%M$bwx*THmQ z6tZX&JNStKA(CuettT<$SUw<8=Ua6I{ zCx<+>&j=fesLLP6R55K$zqQ{H2C@+5nG!nhcyLxr*V!*?GYF5~9v+#nu~g(&Rw3GP z78h=}*7ECiax2`VQ?|CURb3UfYQ|?GM(uZ2wj6j)QQGDHk0a!668*}0{Zrgs5DR&& zpcX26rqwEp-RFsgH_nM>1-&SNn8dd@T4%|yW!l+Gwum9$yz#tW=6AI80itlBgm9k* z8`Tn`8%mhZwz89n66!^r7j;{i;Vy+Wso~P?xw#{9knN%6-j$zki<$y{ps6nNw`!J$Wr z^cX#`Bhh42dwA0etxDMFfhCxT&rEQtQ9y5;5$8e>D?JJ*v7P*PLB}CW$?}NJE%~)ER=Xg1C=V@Kwkn}ua z((`lU`FWmp1R#9vVt1i*^PjCQjowTT?AyB5Fcr-63_r+7dy?lG+SIQh%kxnMThTGY zv{|#|c6gs~+OZkVaTt3vUWdQ(7fVe{Y|%o+i0c#6I>E-@u}zIE*Y; zHYsXF$ki?eU3orMrRV+L?6b#(2zlG^>w9}F;N6V>*<@`W)bV-+&kA_-*WquAPuKU{ z2?vYQl(@;4Y?p2>JN-I0pBgW_YL2%e9VrnAw%TKg90vhg#NP(BES8+Z@PFV`$ahQ1 z_e2IppSoZ#P_-jhBJ4`_v#@-BLb|Ln#9=C`-RJ13@M-YIkg@Q=i`L33v$A`oL29D; z`HFl%)M_LpH!7Jmw!!`rJt~@EH`Kdz8abZ3J#k?ix0v~lvGWhG$U;hb8QBqM7e3I* zk-8v*0Hl#ED;z@SqK{@7Na~s{rL~v@TFh7YAHrI* zQ%sU4t>3!0dn-GBzsDxV&XcBhJacRMKn0?6c_$QqjGGv0<;15|gxSg*v3m2|U?pwU z>U3!en(=J-6tOBfZw6;!j&<+wsJ7ISnEgDpzNk>2O}SW6-FLp59|8SvjB2HO9y04# z5^RysT+h@|yW3P=8y{AupE(Act51jnSpSJxK4Y8{Wij}}<2fnz?~zg>2xaR;j(>wd z2W^e~S%gpTN#>exUl82Qa`)B6FC|93FV4paBG4N#rQ8yPjgpAg-(YP5t9e=>LiO3S zDtXwY(B(>^;dHOFsDB>^r(XvLmyHPC5Avix&)zam21eAd&~m4J5E=uXem-qOj;XDG z+uM-=MnX$y<|kVt7&bgKFIT2oCB+a;O4M!Am&uHY6)$PH_y7!4r69=D*lePLw@0})IW;$>Q*FZNP>`ra1ZR4G=p9Obv7A`N2 zK(MtySdrm~pd4djQkI%#eL>*R_?pVz6VrS|{<4|gDR#EY<2YKUxw3qr%!kv)bfDSfqJGrFllCDL zWbiOXfV{WD@+&5xie|!O`9gX^64e`w&XGMwEVDrs*fWbm-9eG)@EAtPAW_fo?Lm2! zw}Np{RU+ha`CTwv=-NuR1?I;3kguo84$)OER35(NE({Idw1kJBA|t8wy6@BFhSwTN z$?(Fa2W;dNiFNEkW)-#o(^cF;&7jRO=N#WcRpzbI$cQL-{QMARJU@g6IMa?p@>BOd;3ix zN=`2|eFc*7sSgiq{ljplL~v^dojuj;OoS0OhxCO#{uJ}%ug{vP2%Zsr)Ke@F!lf=# z_VloZvVRXAJ_-}Re~L@NHDtT5FTSWHUwAdw`olqBQ`}YGN&ydwAdDLxDO^jiK(g#K zNr>JKJHGWZbYVX)K4NmMwIJ)+@zjD2zkaGCP zgD>(eZl@@?aBl;T65e5|rlrr&Q>Fo5TRxvt{8E9lD4_pQosxFz1J0B2@q^etR^)CB zelh|CkNsRU%R{%Blv`eR!FF>>?u`+qTod~FbXnw3#p836P+(sFu6c2F`rMCk)WX9X zvxMUw;KD+ctTYagc?E%FY!ZEL4u)l3H+k~t!?6xL@Ma=z<=#9hnZbb%zQQdvSPaWY z)^mj0x36#$G||bd!BObmfpUt>YkecB+Z5awIeCS3$t}FgQ!9_%_6d(m_O1}V44FWW zX?CLhOFs>wT9UrUB4zYhT@HbpghwjhYUkOk6+#{!ji$hAev=ZVa!aF`j{{!cVe9ZC zsE}jc)hEGnt{^v0P`%v&j@<_CxkO8e&`rz5 zkU!t!JjP5^@4;~T-5j%n=fOu9(fyj8U}+)blZ(xu(oZnFga4lfxsvnJ&Ha(NA*JQn zyqi`&eS#UMioFOV@PGm+n!7Y|yD@!QPjUk=k1tP`> zx`L?csq3i=>rM6qp4%3NI)rOML)lpjd)Xq=jS_@UcUmvevGUD!p=2`diSY3G&cn=z zlT3It!$A^tUXYVAcx_@TS}B)Pq}LW}n!7lf&mPf)eAF zYbSs4)D53#kABEpP&qHo2sxN^uMV4Fi(L;~0JNmvrEy-%@g3UkD^L|M!alUEeq3)g z_H}_$bvLr^?($8(1cx^iO6_|@rR5GsIdSgq)u$aAgsF5X9Znc?fWzK;f5~dw6LF8X zyn=dXBg|>O{l;^|sAEQ`8`XqSEFAFXW6rJ~+J>45^EqMp*Urchw@LR%t2{zK?0$D%^R?hjgFqKb2|ocmD}uQJ;- z=;1l;xR~NDnj(Jxyq)JveV=sYZJZ!ps+K6u2e$&hcc8oFk=_oXS&}^t8x`YMxQlGb zmpqztoLTg{?IJ^Pii8+z{g|^T9susR=-y7U=BTqAI8pJ7^FJ1$#XulI&{~xW+KV|W zG`4!!lju5@5y3v>b{WnbN0`TWX5V;Y&T=}cG3zXIFL?`9S=jY(>es-LnChQ~p>3Z1 zte*Rq?Z{(#hiAJ&tg)4pFZ$Wa9i}F37kO-oS7TB}4XsZXx)be64-R{_heO5vVh>%c zd45^e?C*Ne>G=AbT7gAdeKhEPHv>-lO_5dNS!Py_AitD69PVJ~q|~1D=d2UA$q5BD zys&)bk1|5lyizpYl8}r(ARz^A z2{O{;&(y5%H0)TsN{)3WwXi&#-^ncxS(gm|WaIjo=j^M)c(?ySM~Qp*lSR4ij?{of zi5`M<->}(uE6nzDL=ONoQ{bVmJSCt@6SFy+UtYHPh~QP)V%VpEHSQ*OgTcp-81rzO z8t?-KPDL+&A8rC(!{9WuIyq^0+!(lT=S4X1p-yi|&r&kdo zbu`3?7a)?tm`SUACgkQRzP>fw^1+tylWL5|FaK5RY>Bs=ek1DQdIto_QXB}DY!v$Yl~nsw! z-^SwfV>b@{t*b?f;qL8JHfiQ13g?Bl?rehMlVh}-h!4zyrjC!=Qd`3T7``bHyb6ty4cAm{^oJ>mPnE|F(^_@s!z?u@AWE zU^Oq~_*VDOq0U9w&C>w!)f-0Z9X98a%wZoO@scp1!?8?c&2N1lA-UZ*rp*Q75=(2s z5qrYIXJ=;Aw-&_gbXab|h%qgdg5UB4EHoK-4Cw)a-9^ayM$EWCa?;`z^qRmQoIx&$ z)&TFl33B9yh5norq`U92k(#vHJ6&r|P@`M=8(Cnp@dfBV@`$H+6Nc@pR63BmZ>8@I z*>r1)le^f`koSlURbYd|n8Rl=y=FHaQx5JTR3G+yAEg1iik1~gz$Q&6`~eOX!))U^ zgaNY&z2-^o*0SOh=vL^-hR$W#2hJGeX8-P4RY~>b$~AGQRnY5hB&u@|3fq^9tX2zN z2wtnz^khH?hmc+44eDZv?79D3( zvJte0u)c{u8!ymYZ4$hbxNZjj(&oDV1B5vU?b&BJK>b}`A7JFG z0M)6w3ANtJM!E^Jn737@^1+;Vn~R{Fm|nrjp4s6G?64gBo&>G*%G8yYmgTZDwq|`D z@(V}Q9|wsX&5OF`3Jh#LG^FrP77d55I$yjM_)+qixxM6*4UYzY^83{)_O9UIm9^(C zqCI6F?)!Q%Gk?#snA`FUbO+5HLc2cL3NjK3y*7fkgv)<*d56xvbT||zQOS36k|;l1 z26R;l-y`%Op39#ZTh+AWzXYvSzv=AT^y}^_{(WfWaTQpCH!%t-O6^|H)9ing_evLG za%^+4mswv)+?VJex5nKeZ->8>nQj{V4HzugeM=wlcHXS?D@WYaog_c+tb158q-U7q|05N?5w=>LCmU=+R3P0?YoBLu^Fcm-SRuR~m!-=$u zEKGq2>0OSa=Fj#{X4WzXTigl?E(wg}w7R<{0>5`x1rGJ-4(n@bCJ1&3QvAMn&7FtF z0BFCf?0cyO&w?!%s~pkAC>LbpLiMtAaA(7d)h#Oj-i(Eg>)J=nYk*m_H4h6p$_~PL z%ZAcXmT|VXQm`)dnv))xl_y57v z6m>ydDtR;fH@CPtTCjBXL`Q9p^9R`Hh)3J`nS27v<`%=W83M zt9)ISdIzyO`Ao)HgB2X+T}`>i?^~xSFL;0!RZZitd`yW%QwtxDM>KM@Fo36}lm};h zw@1rI97X-|jM(7cBwb5DvFMJAMDQg{M&D+jPa}fW%1gaN9UDaoIoKbGQ>{5+ZFCD; zk>D5+hD}R^GcF#w+pnl{q+|EBy0=_xZH{VFyB0sLc?cf$QeLM4I3>U?+GB0?BzAe} z0^%jZYdm)Y=KTp@ zs1ZyN?~%Ftk}Uk|ym^6G)aU}Ufb2zb^4J}C7g9R>R?FbRDl7Qlz1Z z=x%h_H|brjh1gH(GPuIlY>pOwMw5z#Vz02L|h0gag3_jvMQdp(}6&9n;@2 z=+0zJrGh#PLVL4Lxkh{n7h7j&S(9Xqtht|*RE=gZAJ-4gm)?x#Jh^f4iw%YNE(pyp zni&PLRfsEc{hUAA0whdIMG1zc$%{a8Bc|SeYv&Q4I{r(Moc+K6vFpzj?!(N)teSM5 zP>;)pP?dGxUD$x1$FJK*%Y{WAUmvzJPY7>cU<)wZ=W?($mGk*T8^jFm3WUntOi;x^ zn&n`x%5A}ZKTzw%^`^bAjBtnU%l@dNkRhBVxbIc?l3Qt;g~1Qy4;i)Idw_$O?cL}z zsodwomU4cX7q7?By>#Qes;_NaB`Y?%$0RTulssf!=@FKMnkBbq6;@E2(ETI`w>CzX2 zpyis&gF#=1x3ra}TT2=&2XoB{Ls==f@Qo@!P)D|;23@$Eww%|;lHK<@sWD33De_xg z`h-m-#d|R4Zs;g9yxQr!{bf@7eOl*syKhSR7m!&6oQq}STv4U3%4;xpau@IuDD?ad zc<)xoGH=D)2p>W^zCmhlp4}fKGe7gUOlTdlos<6Wq=&}6=3Rhp-iFm6Dh^%dSSc#j zk1%N0*O6;;SCM^k#r{d^PpfjrpG*8mkRh>BW6EL<_s&Zf=oWl!;iYE0(uENa&uMOR zuGE|#(PV|*NGFq7VJok@VZ{;tIKXc{GqEuOaXME-Jy$DD;q-|ZG}}6puJxw9n5{{G z1farRNUrvBEo-E}wfl$O)B9Vyq{`~_h;&I%>CzwS%1;^g9eY<%g}#3s|F&W)++>V0 zFgoPdvt`40ckXtMSLYp;3Y8-+PY}>{1a^sBqsPWum}jkj}JR20VcomeU%0~ zO?T%kaD4pqB~Oi_O4TYehJ5-_h=O5BrWC*Ytw(qINwgg)H2i8Ofn2f-Bgxo16{m#2 z2;HK!!0O~f_vrj2i5m%7jna8&^m(-FS$#+hPrlNR0eIhF6qfq-&=eq8K1ijd8 zlrLI}#g_T|7_0wC5pcKXTAU$J@p08Coy&0A%!CF}Sf|26^4SXA9S3x%?wyfYcf~Df z<1sWAUThhZ$_3XPd%g9?Fz`Y0zHoKKO?|$Rk+@fI>mNqSK_Mkk4r^ePi1qIgfE(it zL6YhF%( zoJqO2R^QzeRBNYC@$$89G8bG;zL|RS0_Qx8z|8tLD}fk4rPz(iXP-&6QbFTi3Ct=$ zlEv@c;V3oD{z!Lv0yzz(xvbY(*@)W zf#BlcTD-~)n~v%7&$t#NW<24YtW&p741@CD^93IHCP{8FhyNh;sJh)WF-6H`4^B?H zEslIMxC>ah1M8H5X`QVEuG3n-ZQD~!*V7YKwl`(eVQcS+En^!3x2fOY6UMI@cIv3^ zr44!}Wg6f0EwiKZ`!z zR7IZLM~=s_wD`+tk3WzR(zuAFLuX_>`|-kUqa!I2Ip6KxYF!ngvyvG;pCGrpm_bAFH+Aaa7<6QR__iHL%9grLC1tu#wd9w(M81@ zMv9bT%44XMAkn1Q6y@#j??IM{A0x^NHIlJ8%}afC$I=eo$z_diTHJ%;rb_LbM{_?H zBgd&%^Y;6LNj!Q6-G|lL-4d3RVI6Jt1`=NuaP>wEFGuVO z3snN~U6dc)IL~OX{J^XanT`vBu=Ys06JqH97G4ND_Dnu0GiCL%i(fhe+;*+`a8m&% zx~h{P(UuI2j1TQd-|W`gNlKwxcCkvJdr7Ojgy_{ybn57ii~S}Ul+N)Td7mqejm|+! z#rGz@BJ#PtCCH&Vxg5(NS5hQmO+xyX{+{j2D$2iv^`xGt`3u}<(J{a({FUJWOiN|< zM!zAXi`>6QW_K)deYLWx-XZ&AL5H+4dH3MG-5*bA*;9xfJK>a&iNfq^-$8ecg}b_Z z9j=JWv$gs<#wSbyQ8fy7Wc;KD2g7FfIPEMF>Awova*DbPpsF6*-8e3oaOnCzGJahG zcs>rCXH-9=a4(4`yw3t+)ahQZ-Qh2i6ro4m5m3@M7W7St}4PmxDq&68>iUg%E2XXu{mqoKG2@SI;?XJ0{%l8^%^bY?vJeC$bBwL57tH?`}o zv#;whl#+#Ow=5!vR%x5df8vP;=?z3wju!n9PR&#n$KHQkTuw8kFLR3{PMSs&OK2Dl z_go!iQjkj#S`eh(*i*ztCpGl+6~?eO3hU}BShLn(>OOYcBw2#dXl`Z4MOO>-%Q!>{ z6&3muP%Gk%bQLd5D}n|Bke*1amIR2A5mU6#x&C2Y0C<{U}_N z>1qI0sg*3^BJ$_n`y}f0m+5!MWsQm=haZG*PHlcWQ#cPWoMY^HdaQXU=m*(kND0!ccJt|cAJE$1-=ZNUPzZpUEEP8iCl;F3Ax9T;o=D(qZ@o$3wB4Qs4B6V z)(VFD^)h()EESpjRe=+yYg8LPfQ~seV>Q$>#bilkFEG=o#8($+i5@9@5l@g_L6S!k zt#udW9K`Fn?V}WXu#}W@)w@LaSH1VYna}blUZ|#Q-P>?WNLIzvT17()lh-N0U95Yu zey?u)7GI&7!FH{{m4rC=n1k2Go|R){%ukREDksq3jK<^q0YZD=-c8la*4wP8_g!9% zPJA7C>--HFqowvLXC$+Svh5nT6h9q;fSv3LOFEn}3~+jm=I$XpT0E#opT%T4>9j(M zTM1!B-&(es_N~vmFdENI2zF1TW8D|$Y>&Rh&R<&C;dtGp>*z*C9L|Hq_e7G`=8U=J zH*J?%Fp2da$*teq++VUOe;OMB&kyiY&#hP0qLbUCC{nsh>6-VMyzVw3M}Pr^;>~fY zk2w@y!V+B=Xu{hAeKeYdVGS-fx;-31yUUX)b@nv&mM_IeI&uLc)TmVD^A!j5%P@H^ z#v>I-vLQ~#o(KMRCm!fyNXx%RgS06f>Df77*x9NhJvZ}P(qI0o+|Kx5QUmBKq7=*O z`pD;seox zJJ`Y7U1G?1PR?;}2Z!>r8&()a;AigB-?EruMP_Sbo;p<#TbQu?t9V!<;*2@qFZrH& zIoCIcJ+;b(u(uxRabYGZ=s>1cOuT#8-BZir*OOrh6bv#Wz#NpZJstL#%fa-ZeJy1= zD>UpY-V2l9UG;o#RUfr(@?O-X3O>;k>2sf4EuZSCcW#xK;WhftpU4@Inw^eAjuc+6 z+8TCpjK=P8YLIa9uaPjia$TEe;TL@NNQd;W=fe{(nx;!X>9#kr8h$BWcgiQ9p3`$i zXG&dM>|DZVA7U_;;r&nuGs_M$m|kAHk3QN+=96T9LfOrm?RF;8LeyMVY)Uza@l{^- zZ;euyQ3t|Q>lu6>UTzGda3#R-RUKuY;}3|F4=gI9_Ca1a$A=@`tYQLOZzsIbf6Ndj zT6Vv0z9QGY{+DfPkrE~~{*_rp##gUKvU|}h6-jaYkEA`bmT&LmNJDzwG-8&Lj1+$H zaE~~@72)Jn-BMIkNVRj!48ePu#AzffufZ^`aq>gX%Cr{(evX_lJ2fLAK`(TLt!syD z%e=I;!K)9A6-zI_8R);AkXLm`@R;@`r1L9Xrwq7#>XV!-{FGDpoXHF&rEK3ZGe%uk z#LYdOEnr0E{Ln=OJ+zI{i)vI|XJGH{0yXgbnxJiw<4D0=7g@~aq8Y`tBf<3KWf!q1PH=3QZSt4UZUL`5!gH8>^6T1+WasV) za?QO-IW=in`?SYnv62q;$1Xil3gDNluo2@Bk3KqX$1kAv@!>(|>G#s^IsswtkU8U| z9;h7mH!Z%aGA1D}5&2tRHyckk(jyju7bD9FBcqD}@0Oejx^7|a*dJ+SerAXxgYvEE)oxpL>5=ZZN(1_Iai%d|OScghT+Qlj0Lv^T+eacLN zW%8)>bqWuMC#H50D;eH+t5gPLqJ%2_6lUp$F*sO7U}y>B5Ew=akEBwgwPrR|A#? z0Ptrr2mE0t9B2Xv8A|ySOl;|f-CXSW8IE#1#l$0blQ1eIqc_23+IS_1{%t~-th!Mv3BN;i8NIfdkjH-s|kfgI^-54OXBLONUQp!20gk8VbEh|4(Z;<_o$(G*G7c*t$Id+t6 zDVOq8Xdm)7>xjorv(3iIFP5~;s2Hh(7B|>ZC|1qcKC8|$tm29RmMo_~szaJncR2s^ ze*ri$!@apr+r}#dV9Wft@^U5>j9K@{;^NO$aZ9PbyCgG|RaPmLe>mI7je@C+I37Av zF+WZhi1ap8&$ON7zZcv4@m}U6>`|X91;tBQNUUzW;4re#Fa0jXw4PbG+HkS(c&l$P|rAAfcg*Nngqx#U%VN=I(p7k~2nQ zmSdXuYvF;%d{HUqo4a@27#<5s>_$dbws(X(EK`yH&MCa6qi(WzC&Iy=KsmhoutP}U z2h)Py6* zmnY|touNzb+gsnulI$fT>8a+1jpaksXQrrDFi#7W3H@B)_vTJBn-yd92g=o+^$hxlq2S-=&JFq@Gq;{#x{sB!* zK(l+g-rK+|D%h6MRgKZ{6eb~=J#0%Q!c2{KUjgrr9hsf}-u`@{;{Pq+CT7iDFMyfR zf83>g@pfU1^8Ejz>pX*+`2KLORO!+~mo8F6Cp78O6_hG9C@3955t7h*7ZH#yDk{DA zPyzu2>6*~Hgx(>9cKQ8h?u+~4zT4T^Idk^AyJycl=XpM-uVg|C__&K#`B}m67}e?H zJ${STp$cr-)^727me_>(jx(bLHeq+lHHB$@D*64uk81`fYb6Npn1G_v5A0#We005x z!g)mGVLn6|16VUv=o=H`3=2nes=(JgE3O6xH?h_*$s|GPQFXh8ReSozMRy*DMAUE3 z1J8zfrWOYcPfu}{lIquOQ%LDvAcwO)B&GaEA59IFPmwhw-IIbUxQ2%&olDZ$8mcfi zna@#}r$xFmC44nuYqu(bx_#`NTo8>en&2MZ9F3HU*smy^v(@0^PxIAMp4>tZI&ti{ zcf%mfL;PQOS!_n&XU}-VXnjn?KryCRd3yWfe1fQtDE;#wqksec1xIUqLw}VkYtBYH zd|9K1b_FaI4}#LpNcf+e5bG#4Ok{mbR4B3i@L_yuOFL>sO#i{eP#V{e*QX{3@x44lQMFqo^Rgz9{&a$>o@jVoW2h0 z3uG^j$X8$J;cSsN3y)!s7hTM|juV8;Vd_do=GkvUwhe17*S*p$BOL%<1V+HR*Z5HG zB+SSvp`;fYwS@~{9a@Mu$V#$vxJ8Eoogw$Y;Gy~ z-246RtV`_T-OP^ANrPKgl=Hb1q?zj4`zg6TFt-Bl z&q+-40~yJMY2l>9cTbP0LN=pQCTn&2Na&4zhz~1Z6M;1wTW)JWcpJfnJFn3w1{>$y zHBCa%DSFQ=8p!0e z&%^!tDDbc94BcY8IvMYd!6Ei52=erPV~b~4MrEqkw|jL0P7qn*5PT#Vi&NsMgpOR` zH!6dq3C8fU59nd2Q)GnqkCc=mB@H$zzsdR&tq?bcN|zqF zGrxvS01c{FjuUuN%e36fjAmGgXkPy0+mhtzezDA&OCH@N2t^P0}8q zm^qBE7S!42kqTmHH&k^=tKZu0K@IPJ*+X;|>`wR*??UW&Txw*PTKUE|zqgtc@xnUN zTm`bBF$F2q30d9#iBVUv*30rU;YCE&N>V=tXSg7yA&zMWuC85^I$wH2D$n#SOefa` z&n2Nh3JB;inyCHal+9-J7I?K%g;n;%z9*)pI^Q4PZ)+nx+S}1`{8$lD-FP)RX%;fp zI$DFy&jrh^7I?41-NaOiJUjmM1@e1(ZuXD#GYXLB&dC|o_$IF}eleCWLq{i_w-wH$NG@3-8Umke>npEU(Y220w$3h8v=@{HUxbOvRK_Q>(5NhJeNsM8 z)3IBmrnjyd#BA3`B^0<7^{wND+M{TL_~F@}{m;_VCpsKdhi%MblBeDmRQ*7t^4sWL zhPQ){*eJvtQB_&s*rfv?@`IND*^DO|morQ(l&$KZPxZJ^aF&b`q zH?vFm>;11GN5c}r6=z}wL9eryhlgW&D;IfQ#uJ5Dp_q|CA zxTFOalK2WxU_2dLiTzz+Z4p{owTS@Rb%q|6FT_lG9+CjTuy5w)m)C}lVD+K9ojPn` zCU#){nlMJh{$r0Sy{d$lW4y)v)-Y&V_QIBBUG^UuR!Yi$@Y*K`+bK(2hRK{$JClYZP ztE`?_wiLvM(v-;h(!3Gk8#<0X|1*9pRc z6lrJqN>Qbw+#Pct16$TlVD9_d8AkN4ufC4_#U83Iu^8*oQ5nler}F_?m}bHmt&2O{ zvW)Mw4-V|j7ge@v?INia%N$ZvardzKv8azbM@*py06|3foX6)0!lnb5L?Oj0g&Z+Q zTh~(#=?{tZd2UNwmuymbpIvnk!A$l)0i-aQ|7IM=<1&(k!pdS}!f5*jY1rF;GuBK` zMB;-#JDWImu?lQf)8YiPCu2Sv9~gyf6LmnU8Z~NA;&2>tKbZ!+qFVpTI(aT&Bp)`! z)>!v!!FH_j;7ujY*b@>jrr8%!j}aL27xSnlRpqyLoleja@N=mZ#eb)=D>II=u`H7& z;VPu^XEHH3#^~YQsl;xi>WG8LXm*0^zKGhq;%ZOfl``Iefd?iYw@Ce^ncpC^KA?~2 zv7tAuslFo&G_w~^PHR?=dELu?w)lA0(HwI7M~Dy`rM>Egx9T5fd&Jt+l2hV<789Pm zS#D1DFOTKmd_n67@Y$;VF)6}neJQ2Ua5YNbv~|T3LkY!4+Ff885C?jKqtEt=>cjEez1$+92Y45b!F*s7{@+U$9Yc2)ZxN4e480{ zs2&XdOM2m`>R2PEhspOIS2%s8R52k$?7^gX@-wLi;?{hr{(Rn6r!20>hJxIspf7@ zQMSWg2>SRC)9h}Ph*z`83zGDVbSWuKtwdG!)^@+{N4F)|SVgXx&h-^5b57k;++LeG z8s#dbpR124I&~bcyunw>h0x)r6~0Y)+3P3wikyGMy`tQHVz4}T9w-UqT+x(M2@|ab zHawz<`Vn-Q5yb#E3PFY&TUXttSqqO2-WaVKD(W=J15c?=7DtL*@7GPF1Bv~Oa_JosG6b$k*L4VPEx ziw~Obe>|hy_NG!rP5AwNBgPv%P`}fwjc73G`7tp+teIQuh0(p)5BSp{d67G}WqTHo?1<^7?&M-Mn+L zMQo(dgI8lcS#x_7=WN0XQgF8X6{!59x`Dr0{l+&q`?@Yc>Ee{(IrEgGSB39if!1en z2kLMq34?KUPvJ8ZRqg(olH>rVY6X&z3B@hZY)p-6zLD0Ak!cH3l4NpUdi`}4S3gn7 z2x(F0B>Oo$x4yc#k?(soP70erzOd&Pu9;}`u6lt!j$sZSX84D{&o1TN&6Ovqvf~cE z=3W8JAi1Gw$`=`nHSTBo_c^5ebG)lPp+`Uz-;?Dr08zBfgn4HDI&-dYmDikWHfr%l zgcIQiL87PacPX>|G3JI!J2j?;gKV3)sx58uLC-+vZ`85}VE~?l2j^ncS}ICy3XS@H!0z#Bkz|4Q#G3Y{D~= za5%w@?#IO^F;i(lRe!V>Q2*Ae;QnvJ%do~{`P=nus=*=!mN#H*+A&w^n8|daFr6meMrQ+VbCg!}b zMv&LsZg>dD)7K4Ix|}uSV<1kiy1&a=W<#%vtvsNCWuIo%wqvO^1(xbJZM z>PadaGqu-sKGu&}(_2m_Fo!4Djwd+GlFtV|U_Lbz-zW>W0~KR=MCBiEF4-8voj1@7EK2wOcICg5RA7MpwhZmh@kO5};K*>@4`sh8b8KCa zOucnAmW8IjB&pbcoB|umTn(!VW1F24y~L`rIVs}Td^f~{^n7$BP2>ol~sp7<* zCpN`$MZ~0B^(S`zhdZ>FYy~lQr-w|BlNzG5v#QCm#z_xi674^DwxX=b@ENa_}>Wf$t7Qf*nS zG4Uld&FbYCO@shcEwv~|smPUwH_hoGL#X33N2NzDR=OuvpT6|L&K`;pE-}cv4u>9` z4VTbjubY_Nd|2QgqIgUm7{?}MaJR&mvr!U*q!%4i#N*uzcGJjS8jv680>`3V$w4@? z5tNXGGyO{L#u?YJ5Z*&F8e3&m2}&2dw3?eeAwNyIbYsYau4+ zrh(_3eTpscKAjMQ<>R%M$F3$n#PL$A!1dGBdxF#?Nj9ca6KU8&5F|8NL=c(3@EeqD zIHlvChO8RtACO<Dx20cFAM1Gq8p(Boce^K>73XNfGuhA`& zcfJbr?PtCo8zm(jh(7#6L8X@SyVTRqCae*x%VIl+B($b8=+A`zTn8z4o>>eu;e+?V z<~Z+y6=!n-rmN2s1lg;x@Nt;?F?gtINtLgZy(4rPWbHkixg^Tpz<3jBz9K8;uVx&4 z&5?3Y9-=}La?C(oF?B?98Rj3~evpifJHs{<^$%jbIbu%Uixb=VQZ$Ptz1RQ*wA^SLJWmajb23sRVlIk^t zmI~5CW8Q4DdtQESw8WO-$J7o|cZ+q*+Br#M;Ho|S3e(e@=G*pr*wIRZ=#45MTe`B{ z1EGi{;EzGM2?|hnMfH>9vFL1dCYLufszH>+IOV_+kCC%|t7V0<4Mjxew{-o(PSxgd zFmdwstW(xD3`!rO(Q0C@8Yl^u1m?4+^I9MN6>RP9xK5ft(!wIO6`L+;K3(SHtGZlm ziP6#S&EL3csPq0)sF5h9neLF3kPu2?`R>jvEZy}Qol~cCEPA!~dX!3e2d8dvXcK4c z_4!6ffut}{^_{hH#R@$ zcsVO6a3a~TUqMyiFpI4XsJ5}`1+3-;+G}jqiYw)ecBXn%d$w>``*9d=0}%kKX3u>j z#ZXeV+vslJBH}$vG5flUjSFcukm0e(Y(w#7x$?7(*9Y_f$Er}R-sX(Bp?d~mg3AZV z?1m*gZ*msFeN)BL6ik(HUB>%O?>l(ShkW9~yk|aCfDpb$o)*=<2)I}5ufG{$GV1tA7y6WK(i~H z5Rdi0+8%+!P@#^0v+^fhCBN((_L8n%a`~ri;Rbo6RR*YR8q@;p1w8i$Nsc5GBv`%> zE##*2!bf|>L8}#1*)9L_L7U6(p77Sm2Hn(K9%!3#j`BkF?n=jzJ2<5RU&m9qp@ovb zcjtfc5rWw%yMD@+=-$uvk}&HdC!VDF`tdO>&)S_j&m6By2Gg|IcUP8&DdJa8AH?qr zq}b~&J**8$c@k(g2aqdyd=5~$yWAyz&^v}3M_eG-KEpobA4=Gh+0vDxh8Ah80^^rf zW!oOcFaE3FF^H89X#!dRKqXFrrsJqaK>!g{feOviUhw>0Rj$!=p$;&K6mB~a9s^7- zfy530bHsCG~<+ zLM8)Q!h^RWkUReDF?@ZcD}j6g$zgEJXw@hYSDWY8rD)IKtM1%#+@~fOj$2h;%rKqg zl-MpOR5RwtbJ5^Au>@$aD2?&L!BnDJ=rp%p1;dnhJR4e9CL6xb`1S(N) zxy^!W?t3wi>1_-7paMo}gHR*d8+8qEHsL456ai4=UD058pTPIZZ&+q(yo1j6zzhvU z-Hxf7=#b-9Cp>jWQPHpPt2Pi5;<4-T`6}-QiIouvoSz;FxTFvGXV?#4w`V-4{%0D& zt4gME5Iet^A(qM}f{pjCWp+3CGIfu#cw$voV3{cwdk18REu3(q$Sd)5)1E!_l|NnL zgsmbmY^v)6=I<$kmRfCaw8mKWoDYjDTwJoS_%~kY%pOL=>!$-w`rUmM(ZpueGGH08 ziw}upd$z#nB=ar)FbOH@!;-=6$P|-BD!Uu)_r^;(*!wo�EHJm8jtu(yZ;QP44z( zK@4&8=!=!^O79t+VZ84w^NO=Dl-;IRO4^ICRsj3b>?kCH8QiKNGhMNJ1!TR z(rY^u6Skc1&0jJu$SMJjxppt>eFa)&wF@=(BubUJjXJz8T?p!T(3Wguw)wdiZS|zMpqfx#+z_l+mC20H=3v>UneM0PwVf$zdL`mg zqjh2HUuA!^%#_EhpVZ~Y|E>t+^@V14DGZ*VdVNI7Zf}s8nBw_O+)bixf4xM^>)@{} zn}Ghx?NG~BTRzXVNXFs)!6w*iz-WMuW}D$gLc`R$hPxO`#(KoGbIwU`E?eu1_kFH1 ziW!g%8r@>!T)nd&PG8Woj^n)7`hKwUg;=}mPdhBftrBeLoJh4WW0}Yj0#(Q!wanRa z%im>5*NVjelx3(@{y<~g&K1v@lqJlvipFLD_^xLj->Gx{<*Z7(8Q4KHuP+tfHn^&& zqZh@mmA4&eT#v*Q*jpW3netaU>c8Fug8KqHDPaC7N#mxnx}?2#gEY(!9?C8o&Q#Oc zm@?1x-?yF4rvD_>9qYcBo&3`^7kkKZNy$~16o@*j3O;+USSu8?$9eN=D2=e}QsiUE z7gyAv4yKfF29xC-IkN3b#RZLT8g3YhvJbawD^4-qag)f*nKHA+w#}+>I#%@e9gyKA zjHK6i;*wCE(fuR^73sz8^3G+rSHTtOLUwOmra;@m@76(1r!S zZ%;mvy>#qJb!J@Y=Hibe#lvWw9<~6N^?6D^Yy?-PMWII?KUyw4^+6awKwTGG_L zuuGqh@QA=$C!$Q-l7@<`pIe#~Y0+Z2I4VN!?i}Cb&2#ho{Yb&O$Ef>6ygSGUVU*{AD0%dYf}wfM1P*Pw7ULFlqv zE4vDm*r};S`(JJ@SzDyro43LcQxonYG-=DK1G_lJjqVr5(%5~@8Kt6K; zK?1J!s%i&aemeoLWjpT?`NYAVt%H;hXe+hOwNoFYMu#gxF`t*_a3dkutqN+Y)@x3UGcY}fBu^|BEoiztLEgHA?RU$T5GQ!Oz6Gn!o5+|b}1DLfGRv`f%I{~ zNtv@uPK2oP2aH>$p$w`d#9p_+(99jFUcd#s3W$+yi&w3xM*lUI#NIE%FxxPdFyK;I zW|m!hcpHetuXAs6Di(-n@d#iWu7dZQY!8XC%K7R*#haRBV~VHH^y|e0!MOLzd4ae2 zX0*U92*&5h(b4gZO8^!V4swWv^$d9A~tjp_PcwlfOzQ zbOHRx__%x5=ewgwYbLM7dg71u?<54ZW#fP!UxJ_ zr&o6AS$>)(hplNiraN{1#Y7G~XR;LK+fhAHbhMma7QO01HYF{^8ABaQdBqL|AqiZ_ z68xB}Pq=5L=V~NV8BW!c)B;d1|8BdwPgd2L$FOpuaoEh&4ln&0R|{MD9PppPk#(;s zQxUVz!fANA3g4$KU8R&TO({PdgVrmMw?-+meP!0e)}_1SG)OKdd~ev=j1< zMU*P5YRu>6l2uiP=z;v*bWAh{+Dl~!EW*bSV@C* zXHmYF@Y5x-WUszV3+|Dc`#)3@Opi6#>LmaEna)J)msfuMQP5QoL+6Cj?$?;_Ozt0& zXG|!Aaw~!}AzB9Cz)nWdmtj%#zE6F9J=;e$f=Y-g{c8q7PICxq`SlCUgjfm-l|_Kh z9{`^(dcV%+QNoX^2%YY}E8rSxc6RIOLT4ERurqPc@5NaYWJ| zt+pqdClk!*fvC3qvU1QVE)@c5-0Pd3$T|b;;U#5*1B>F|xEDGLPQn|{bruMC-LS$6 z-BcvQNd?vKDJUm3z^&x%Q3?uQ7yB^%d1XY}D<7dx#tnS<{RJ3MI%1GC_aDwLt3t_I zcE|Y7uu=v#DaxfIuQ8cH@n7OpQl&0W!nn>(G)f#|jL!@ZO9o650>WoIC-gbnZzgmv zS6+0HP8Av^*E&T6WL0~%QSRI~hUDch4_X}GHFr;xiVJr|=}=caGkz*#v{UD7jpFOC zrFWl#AMNa{FA0dgVK(0@ey{=ZEv+j%)Z!8DRSN z`op1##Ns#^rmVDLa8pyC3U9~gRI+k=T~=%#gVOR`z6Xlg!XQKQ9V;L35U2?bO%H;` zCtbTF2*64NCBlr4zo!75dUq$sm~#)b7V1jm#td=qtAoa9KwrdIFv7b(5yRc`|WapshA8*XfGa#*Ns!P>hii z$`inKP7`jPWdg&z3;GHEs7RlAGHoL_$wj%Zj0g?+vL3=o|XDz zIlt{g;;5A;uwv1eNzaX?V>PGSA8QLCg5!F5&sj35{yDiox9JOp8(SA`E_sZh>2acc zPFl=~Inbe2gP)?iImfIyjw1UdidByzF3&LFh5sIoUF}@$zEYR&RfN7H1(`oOeN2@q{-NGG)Aq;?^a-B zGCR=P6@-sTQ7}Rh-N!Yg*=xo9nTC$TNZsdBQexXzkiY^b5toB^SK%Ik(JUhUeYV!_ z`>buFVXB|DL{}rZE|FYfsG&xzAn^U{xg>s>^J-ib9>E6My!~eV`W5BNGSPqP7HWvo z-(?e2EWDPoG8Q$yrM@(LQt`a_X?_cmTDf=Y+6*J6H#RJ@4VJY1dW(zt8xwkG&oHq^ z)VTMugq-YY#x2Llf{WlhE>2ECY3}TSsO}m5q50PB+0anA2TEgSsFucHC4qtDqx~v7nWdSaHxqjo{>Y%(Ti47OYs1#YS>KyXyu#krs=S}i z{9dLj8l7_%Xl93mIF*`4iAp&g9%9GMjdEouTfJkz(XWc;G} zPhQ$pc-Vy3+I8C#HbJI&3K_{Ba@DPV5q&OdAN3lquwTk_CG z!RW|I?)tS?K~Qr(UYQFC+WclaowO6PgO^o?xySg$9OD~mH&6YdcVgV}Em!T=cz_-9 z;IZu$Iv#r`9aMES4wSePIzYS6hm$Sr`Oy|8= z+ZTGY(ifm+4^@9YytBC}Wpq=Jv%{acHkTdbc@ZT2i81irA{%ebX`9 z{tjE&U}P7f+;?KVh#L&nsUCJO`4g=%IdAF`fGW2{FUAr z+RD83&*<*Oz=6`v`j1*V@xd{RU0dAt+0kS{iIK7IZDXhPFWc94%VQdo-nO!bAcL87 zF=nsD3W83L zCiDJlu)MEy!@0R-aL!uTlWod_DFlq}EpKxyxHwhdJd|mhFF5N%`YRNa3|no!t!16w z;JtV^ed>o_8O5H}?vAwNHwe!ew{M&O>=?fqor;Amn1?8RQQm7`+oXEqvGxwfx{K7P za324Kx}1EtZdqn})h1NCww{6o)+OVjLXkEh^FV4CBS4nNM7`KP~SKMLb^qNL0tfyRKHY$v|+#o!a!GL59orMd^JDQG)8~ zUsUD^FaIns-?zXJr^D`EvV#AAdmM=f5X{_!R5H4Gra&2>T>BOG&%OZmiQ-s%=p3<^c7p{#iB>}}X6kTKcn7z8L@k5aj z0L|%Om+-BXu-!(Jy^?QBD4K#~W!}-OwXYhRC0B3bHP=oTO@S@e_7=IAb|=m8miWfZ zKFxDp2W+>f!v*oPGgE4`)Cw*(v49YH(lJ5RE^(O2Bb-?l?rIDaVbkMeo|tC&ag|) zI~v$0p6Q}Z7sI1PPaf`Z+uOg79ve~m?WRY}Vp?R{m(Ue8tg&!das8Fz?KHug@+%XB zKx@=octFnPtU<$v^9JYuGJUS+yWwM<=2(ZiS<-uU_K)nJ8%Mc?jUmjkrs455z3V^y zG6)=MSIMrr?Gw6QW)@AnHlQ-MSdf9To`fEXvpu{azA?IsxW~cIK?o+PG1F09BE(V{ zM9JnRAJ;;1904T+c`eauuy4lqjEQz$&8ruzUCz|ICY@+HesDf%C-0!RU`4p%qD&vJ zC+q1zJ^)HWFN+`7s!($j?tBg_HSGHl3e@;>(qW3!OFV0^uW@PHfIe)B=u2_#E>fZ4 zpD*xxq@D}p=mw@8KC>rsu~$k1KOcy#&f^O^cptvRVMvFn8C{-aB#SFtc>u}fMujJL z=RQ_vf*j8kW_>Z;l8+NqJz8)weO)sM5X?&UpM$&=9E0BZM4$S=)xSOd<#-9to8?o@ z!$`C+p6eB7*n90~^S7z`YLA-8j|xV_9YzoAh{BeQM@1#N8zTr8iGx%1BLJqA!3$|1WCIaZaxHsI>D?q+Qp4C*-64%*$aeNsU z#~C<3o;UNT@5t9=k(P(E8p|%zbEzH5nK~d@^z8@C6?)WO|G$dm>Kk zA=bR@3}Bn9(lVVKos1Vcl{p=BFMmOlgF){TADhO@rU4RMjWW+9?H-zrThDZ!b>$I_ zy&-{(7m|{HGL~+SaUinB&a(4bs9)Dt-;o89S2Cq2+YFD0}CG7{@MWO znp^}Dp#a$Q=z87vH*uelnzkj2b4s*D;+$mkYo#(E+H|DIRPMTXDOXbUDFQ6DNv}_*0s7;1R$&Dz_%)1yXX$L$c7c4Jq|evK6`|Q7a0B&Q8N+ z&kN+_{x}c!gBtR_V_90{z2!S~I<qsZ34H|IoO!~n9yr}_#Ul^ZrBM|bGE$pO$2sJ2fxOc?`(guAA*HGQ6fg= zCKnY~*0>DLbtsVq8D&a%#L8msFJF-*?e{!8YB(rBy!00$b2(U zx-_w>l9`N9F|=;0vuW;lVElbzF@i}&dT32cq`U^Pb6zfY(l1bKGD#aFFK%iW56k#FvCwe3uN_Z1wHBPxzq ze`i3iQg#pydLY5`7GHs)II1GWNiOkj1{e})_tOKD2(dBLuMSG#r%yn1CE9LTeKv>t znr6O2`sGNSh~zYEN_ZqKWJtIvdgCGKqi@f09rrUiTTmrHPHiTpk`0pQIlxEiD~~MPvA8xdJDN9 zk4iiLfw1R9>68o?{)o{Q>;q5(r9UJuwF%c1N&BqP0j5a6uGA1E0u_nqP{4yTc3xee z3o~xpl`wHk;wRC@O0@K`%ve`R9Vne`&TXn4Iy##|s2*PF_arI~%18p-HJ2R{`S|)kxzG34pn1$%o}>9)M=a!?GQg}c(t#QB@rNfiuJEwz|Qx7LQj4Y zbM$`K$lbaHQwVBfR>G6-0DlzcTX`w@suY1e2^BYD^WTCall#qeKzhC#_kM34t$*nv zejaeZ37H)DKK({!y&IK2-}TLWkMlDyCmRLa&GB35955>mr&RBq$OW zeXB$vp#ZAw=pD=)O@CUL38nVWlG7a>LHP`J4vFQKnieQ1kfETes`i(;c*L!1xrQ+! z716=3{fOLPQ%~dSjxF`Erb!Ej6eKfp;=pY>>WL9?`&cud0o`$mzsb;@(#O6hTeVX3EHXYf~lxZj8SBy5B1SogW}5sDTt^Auk1EAYQH&o*>ubL;&x27;n;Zi z9L(y1zIQJ7O|7Am!@me-Z1YccwM~moZk3CJ6;7^;bs{b)pGAkI^@vISk#J~!tQv*b z{d(#AP+;3L>99W6RbczmT=V^^S7JJDpMuXy{_Vo1`g0qL=|pY5mZNuG+P+4D>xO%J z#&j#gU7x#y-Bc;hwQe&@9!&yRb6Kl?r1U#^vA071Wq5UBT2=PvS!BoQBJ0P76^ zVC?U?U{l||{2_o~xF9{Pmai`n=6m$r1Oi&X*WzJsajWF&dS^S<&}3bQ)z-q5hiGBF zZD`7n|0u8z(r!oA87Z#^jTpZ=OJ@;gj6sa;xOw-@^Q@ zhRmx<;n3~wPBujVFReiP?9 z(Fwvz%{A};q2ORN-Rqp3^jFi#TY5-kp7OI2bho~=$ie0vI!zor^k#OWYO^(wfP{x+ zl+e>Hh3G|M>^SMzv~=cSzv-#Ffc5xH+RMut4KAu(?^9N6ryN7E8r=T5_-fk> z>c!%b%{g)ts^D-IJup$1jxm)-iQp$VpTnd5`0OQ9_C#Qk*=ONsEF4vzOjKO zLgfN>bqN*sDe`FzUwH@);Ip9Bg%-+AKyy3lth?KGCqv`_z;m^bTc{^kR_bpHHYVonRg{ z>qtQVQ|uT=>wcxKe9kPb_4Ti(_<-2Y@qy(Z&qK|;WXx8IQ=##1cRdDg0(adL`()us zvM8|NPSw+C0X+tc^1UdMl75eH^1sfVpGmfeCP=(6~bm*nhQEornoZ7UmRGo*4CRc0s=l8ie2LuS|h)Ih%VBeZC{^jFN&^~K&e)AA*n#O=xSPWr@WYIis|&QF7!zAGMXry_oR zQ1AbCypomGLRhKu4qol?WQ=wNmYFQzchqCiTIu^!7fFH^8MMT{sMu=nVK>{zN|8rN zt2@=aEq;3E(*P;qjK3D0%^Xu*x}bi!INH^r<8poh#t6!!vt5(ci*2=;y6sc8D_yKS zo-rtYFQX*dr(i4aRfq*e;^xbD%-5OgWST81<$QgZSp-|6JS4frTUnu!W`;i{zo+)s z{f?LkW|dI+-iI7yS!M6HU$3zKGJD-xSj_+XW#y=8HZ@1Lr9E}|+EsI9Wq9&pj#07S z9G~CC%hBI&-`}O1jCl7#7kBrZ`c*3p!1I!A0AS6pHen}<5*Rh4*&-|C+0}7a1rd?M<`~NSCZU+EP*@vv>gk6K_T}x$ zqT}M9&_ONLV=zx|H1k8x5?zgIEPG{EDX!G_`c{cAX&Vx2GY7jos z>ivGTDPm_i%ZOXGNso9Q9JO)i<=IhnlzDgJUKPLacG*eUR^WEK#E;b)- zB+tiYJ>|MHr8x0wYbcT3`@!e8gkAOBd>}#d9{<&AWXy!Gc*h5SE=&Q_P+pFh^7^c_ z^WW{EXTPF&KPaqM_ekK=9<=<8@!mE$x-``O(EH{`@OOqPp|TOrcl_>i^-}H@^&nBp zCW+tqalnmH^-+$zT|s^F6-z35UQUivjtTRJG5>6T+DXSdYC81eE~wd`q;%k*ASeSb z(Y}{Nm;0i0FsiK1ao|1VZ~TilB*{TTZ2W_{Z=G-o4YFXMURC;6F8c1L_RrM5grm;u z;tq(qoBT#RuuZ}{%7J&LdJfhK;$Kr64aaNgkGEt9TY3;Dg7)m){+zz%tt0nr`D$>Scx#0Ex}W#G8@H`~;@XcWfX;ff zKQQ}W5WfcC*^vFCPcb z)cMzy&_A!nX$i~J6uT}au;0oRS?Wi`8Tz{|_yAwQBVp%r7!1+OjsDs_FYxOw1JeHtU`t`zL)X+j9MPW^!J0%Vp;Ia@OW& zT{9@MN>oRsuPHc|-*-1>*6G$LSs9H2Z%Y@o#qB(3I3aUGvb%l6`~6GIru5&Ur0Ss2 zIqJbo8;*zC4@ACEigW6L_cj}!DI@Q3c>i-`b(4q7kVYnLw6(47dIK)Bv8bDJhG zix>!Jiu}gTqxuzcKW#nqZPf10AbT~qXrD zllO#MH9Eesil1k7S^zb4@OJn&T$n0#Xf;q*;bv~vxSTZ2H@sFb+7k#&Q6(nu7QRDTZg zxcblF#}b(~9*I|Wgc_w^>Tdk>j%Ol5ypQFt18*?sk6o=nlMh5y^RC*Mx%_8#^WOCB zDsMa-C}66YkGTN_oX*z0-59U?sKN|RApZo(Vm@=OS$JEKpZGUr#jIqgXcR-#@rVD} zoM%l>^lgziB8Gqc^@nO)?dn%sC+s_0?JKGF!FN5rs{7Kk6Yp6mvzaztIkY9I3~js0 z*7KLn{*w4bU%7+%Xbau>vpB+txpVIKWsdxY(I{SxV9hk=PJu=5ScaK>q7&a?6ykkA zM6^AWqvPp0$fJha(-DmlvuU~);_vr~!eghPW^J-2Rs7VO+|B$GKv|W{dB>Gczw`TE z$t9PvKf_GDoI7$v#fykVlj-Y7-=IzEZCM5-)32aqY(?3m3J=an|z%@GJ4FKXq?@x7@Rc1~RC#hgZ9D z{mNDNO)^g6x8g4G;)BJ;PzL+(a$SJZr9EEtW-!F4&&AY8Twg*{;HpSxZp+&GQldSchIm0Xw8qdTSo zK^kr-o7-jH9wgC9jQD?e0{U#<^AFFMF=6@C71uh@_YZ1yNAdr%090<*Qyo9XvL7gN zU0n|_+db(xR=hC}+W#RqA$~czKCytCESXSGEdHhJ+W(_iUtWflot}>BpR9j^IAk?i z;#Zrc)Z?JTUgild*wGeIK+O2M%5@y%w-)mXw8o$8(#8Y+Gdp$kfCsptcQd74wbqVm zD9<0?Xn1K6^iDLrIR+>jLfH^<9n%pq9w)ANb_rQoGyh_IGpf@)mGvt03$+&?w*)6E zd6p3fFWr&K)MkJp;z;{3WVyekMj~$XC37KA8MfDTF}_2i{NZ}F!Sg`FKk1LK*%vmk zW+fPwL+)EnB_cW${7&6Be5~F{l5nba<6X(S98>pzoumo!Kd!v5y6#k{#jriO87{c= z#$Snt0NXEqx{iC^%q|!*yD8;T0rsTI6yV zc{lp*y6WXzNP*ZC#0Vcp1ASP|hqKTf-{$K}CzMM3;A#5xb`~%4_Q=9}z8c|bM|$|{ z_aRIbcbvW<)p1E7n0NZobs?=nX)hcyrB ztS_FD0QvzFH1uKp-HJvEj^K4(nvy-$d%)LB?^XBpvW3)n*!hbM@LrH{OJrDn zl<)di;-PZmSZAUG+(X!LfSkm~2Zjm13l;D~X?x#)g*hjfVXuFWo@ZX~bm*lZj@3c; z?5>wQuemmma`T!mAwJ{cY!1Jf#iPty#_EWx^TUJa&PBb1R0YlvkPG+QR9Z>Fg`G-H z1MRqPz=PIqkU^{ezRk45aK`4gXwxxOEXf4E%?xN?4ROuOA}zT2dH_3P%UGN$;G zlwAASd;JMP54at^U(wSs`|GP*6wr~2?ANxChJs_0-T#Mbh<(li#kH7yz>D;r{r?#w7pl{rl2u+qb9to_suH>x<|sV=LRcue>_F z>x!$>M{ayqIuAQGe=Z49zK=mkWU4_zXRwwfpyGft@K1NN*96{p1kkxK#3L?8n3E$o zd)o+}&D75&P~La&V0xO2og+(T=yXbR{>)ILYeUdDXDD=ox^cD`L6F%Tu_bUCenGy1 zA;qBMB%ZfyJ{r}LNiyj%%G`xaVqlbx7{rW@LnzE5&c={@;JO>qbr)ZnuDsyFpx;h1 z$KSf=u5=d$NMZO8hH>NSHEAHfAB|Yeq(ncN#4Z@dP^um1_|ogIrSIWre)Z1VBU}8x zf8npDHI5jN?Ad?lVBq_~V-GRf_*(kxhkrO-ec{DfA3*C^V2RQ2+G2u_8AWhEbnkEL zE&4GW5VPJmbD_*f5(FAL%B<2fIGM#r^K09;rDtB=oGvBcmoEw@P=#)=Z;7OHt{q7j zciA8qW3fz(H<`0A8p9!ys9;J^%V|OA24(EcBuq!p(nd=oqryZ(G~Om1aoXfH@PTsNf0QCIs{r9DB5kS=$4ZG@si!y_r)6o$>7fWl5 z7o+FH_&JJT;9164IPMzmKfCXq^qu?fi9qxBe&ruVW!V8%a^KAe?tM=^p1yYX_tQ(; zUQ2)B){mnH)?&P~GVWMJlVnfTMo>ULh~(FhGgs`PJGNL$u_pn<8`Z}GFFe&{u=1m}4=If%Qg~M5#E1Vw_?c7Jpgn)A9J-tcG z0z?*z6eA-{CmF*xIkw|Zxjl^}VqhF{+!bZ3pq4IPn!fO{A4!*=e<3!1AtUw&({~A; z9)9NOq<>(w+0EBn6Uhi|9XEn=2p!?sSSzUGg-`FxkS_xZZEX2ah^X$omynA-lahx zFdFC^*eM(3CPw+!cataCJt%dkHD3G7U_Fh*q*36rp41x6tl=z{G8_RFfb zAOFCIN$Q=G=2HFk20Z=cJHD6heDooz&id1)wd>OLmtM}`G6QPrJgbm-$6<)`47-Wb z{Q!noyU@|^7>Up5qx>9TkmTMc9#3EY!5wMy8?TXgygB`yPyhLd1x!*p!1nF??n*E5 z+n>MnW08ms9j1RR-`LmPz<9d%bO!97EeemGl;*(?L!ggi`*-Zw6BQ;iNjz&Go_*z| zv~Gyu+H_~!^6Yc7Q z&pw?VdF~miWVWYIzVC-3!Q)_Q3>YF;CW@R67yjWmF^=c{#&hhMmoTy9(v`Xuwh?@x z3~$}JGX%B4XtAY13h%?uK10T9V_LCzNhV0;&hi{|q9d*G$YHYL8iw(-c+^&u;iFhO zp2{z?65`Ik55IZ@bvkvqbL>JyW5%d}aet95)+nkFMqu=#$p@zeoJThE1`OXSvhh1` z0u~#U2e|hWcX&H)g-LjPwksL`6p)5D7SxY(jyexsH49OZA9h7N| z`Yoj&k;_xz`4N|He+-NSH!&5j1I9RcTwAocjJ1#MyrC|M`M8Toa6}c1hkE?jS_Fy1 z=vbyDa>8j%;9i2kx$ct7BFS{z1PkBaPsPsTFFcB>|r(LbeP}3BKx<5G!AC zdt<}sb#x441Tm&GmheGSgYf<7U(+ISWQ0=gQb5g=twVVRbHFU0{OODQ!cvtFMeI(tLtruV2!n)SS(poBc zENOm(k+E~o!%lx1o}l1J`Z3mZGVx`3z9V^K<=q0KJV7G8Q?kF!cRBxA(eLbz9xz7W1E4*^RK*;zV!XuBJjNWs*BQluX{&K z!O`fso!&Vht6JfykA;2hwBIRy7xK@5Z6psnL5S~a=%cmi zT4Eaqq@6x!(&RDx!iVv*t=efPO!Dit_`QCRxKQ~=kBzBCzXAaJI%vW~pMsCdmWdN9 zj;(+@##Ah)x!-ujmFXSGxP=J%D@%A6{vUqxYuK$d;j=i9JtLB)Srum;*$rJfR*Yl1 zvF;so#tpG63{ffs;YT(*_Qlt?q=%k)DxJG-UApP=E18P3HnL$>M+|p1hmMW_$BaKt zPB)s2W&$S5es?+}LP!~t{CP@EhBZNCxtWgGDIeJB>>-hv;YM&r!)Q6R*-l3~oSSH( zqMx#=B?!v0MT?Plj5JC~Wmz<*U_?0_$8@6&lC=&xaxI(JbF|P5=RE|LmR7#&s;k5C z5eF0vN9acS{R8iKC-dnRQfc%8Bea`i?p+AK{6jFyjdcU!^g0LqCcO>pMz&y8XzVEX z6PsT^&R$LLxbV_+15?U0hW9=BL{xK~#M0p`;V3E-miOik&dBqh>gO0$7b6ZGs=V`I zeyz+XC!_2#>_9QX7x9AEOv;R_nRpUo=#$O>9GyRo;W=hihVlY!=Tp|Kv!Ru84(qTZ z10&eP$#3Io(#FJd2aL-59?C8?%tz_7A7*NRhFfE({)~C*1nwVr#|=?}d4ztrpURk* zFs3?HQNmqdl>SxAC;%;@%$49PypS?o_|3V!CZnXg<>bz@zmX)>1?x9vqc)6;TMaQx z_e);N1AT%`1Q4SHVdYD`={DyxIP?`|heN7;*kNr9(Y^6!*zYHQz)^)~BSAM#*@~(| z?VuRr{!hM|JZu%$F=Sc0)rZ|$#FX1rOO~kd?`}eYDz)P{V)X2ls$>AGX zCDg$tp0xj)eL>$x!c@B@#!b3MD|_=P5TwQ2f??FOVcH!`7~8drmE(s2l5M#!kn1N$53$*DHUx z9Tg-7>|S4N@4fcA^q#A(ODE8!CU#uYqKFFu;{oJ0Fpkne;{tzX9;`l$YE%=zPA`O( z*5_NHmf(BU3NJ#gd?92*#mWdCUK2iOMi4V5WnNWJegcYW;D_ypqjw{pwk zSJw2?=8OkpuJ254mLIO+wGJIQoUH8s+5`78@LbRQLZ%6sEb*EaCR}~tMd{WXZ;onX zV>AmDtejRSS$%9}O(fjVS7&NF;}Ry3_3t1M2F!NcmTlY8qc|8Bl2yL>>T6@Zm(>PN z)3I!5REB~PEf->karb4c!i`{b<6LbpLW9xjyaH$Yhtu51?P;6?r}E7(N>l!az$)xc zZ*oz+SDAvb1P3%0R4vm>YzhOk)QBS~rwMNO?r|>qT7@7M6@}M+ohp~bO#Qh0yz^u6 zSrtmW*vEH+Iv1wBoGBu%2jBu_?GilkY5>GIn%6QM#@ptT{1wg!O8TS*-5dos(0>*O z(-GTiDH(KXiIrQoKl)(M+DeI&8pp{AH5EO4G@cf+bqsu?8Jsbn9E9?$7=t&)e4V!^yqWX#z^fs>&}YxOD^1acDn7+M+gMAN8g-utl|_?Hi!4+Sm{@jAFy2~U1zK( zjN>w$>v`YI-kj1Ulugeg?{=W0=Wublj9+ookZGrBzKUO9qTrED$4wytA}&U?_U zCed_M2XdWkFkt27c#}HA08fVx(AH$~NWViqmKU^JFKpS8UZ$c%zv1R9uSxG9Y0`wA z26p;3BLk7>JNeB7f_7mX&1uidR=k35$H!x3DB;VHLk;Pv2Z7cSoX#;dW$k&6}ku!W+#;|pQqJ`QUGUXGlowxnE$yP}$NuTmV%jHiG493!&tV63Zo1vDHz60J z-TbLs4<02G60{76Z-P@>5IpNK@McVb$NoS6>Z>uRW_)u}QzpbY9*mw8U!sQH(q7rw zneJ?*H<*rN9JOfH91@QU(pC6{4xAUVCFYsPniX4Cp6Sb&j5O)vKEXS~5}haA!5Q(z znKLCvMv%!gHoitQGtADR)7e*fUv|#<>Dr4fjp;e=>WH}I&XzivAscxHGv<#!_)vP} zTVJD6!Ap$XW5fkJd=#F?o-==Ox?$~xaD+PSHm0x4f?I+6*kg~S2fzDG9Dvc4*PC5H zK^VX=Y?{9)UB6~Sn#F3?o$mAmO6+x3m3Cn|%N$)KPp85WnsagH-$OFWxohq0JQwx; zcVK|-e(Zs?$IK#_$A(hBXE1G?N4EN`vs1s*IA}UFdk*Z2Q93i(&gXN{E=M{-DBa0^ zqklD&j@CcB<)wIvj^}e9e~zXBH$bk+nsb8q20e80rOSvW?Ho;7e?zuN@IkbhnkFrx0!kP72fKT#%;~) zd1=$arD+W#1ky3?l|NR!>||ljBhX_V(;1vPp7924(q8_vcqoJ%Ff^0_SFY+rVxaqnASeFXP*2;Mc!j`uo%6~MUoTl?GA3Js6Fr#CZGcUuLGl_mZ zWyIPCY1rmz?tMEt>{;e!zjkOp@{gl81}Z9(i+PgV`HPpOvlcE+^VkmCgS<5P&FO58 zxGrOz6>XSPLkjy<;L$(PDRo^F_jRG&Ch~7@ggy-(id(Jz6co^A8IUL!zH-*Je|6fo z?b;pd!>FH30=sacvSa0(iE7tdFc9&{G=*%*{tf_ar$Y7V2kuF)9y%~4nUF1K(5Fjg z%uWku&Po?9U6B?Nc=-KzC9J@>Ys>ER6jLa-z`F%c7o<6rcd_)71zE<^q|RNm6y9AD zvNiHtfticC*3S1HxG!x%FU@9sqFp>~(E#b9jhnDv*JMwj;hqbgmdD?~zCFzvA}=50 zAqxUO&E~cHIBprVI~UuvftOEY^ij6@hRTXh$FV8*E|@$Lzn_&kr)it$cg>Tc$u=!` zDAsK#zI@xiOB896*)6+GIu%Xb=2JBdrtfE6v(J9`qoH5bSq`xHG)14Bm9UA8v}Vryboq)kX)bHR+0P^UzI^Rd-PUF|Ng_P=#FOd1 z?|vhl7zf4~Bxz<338oDW7lLw(P^*ezvi2~}DP7%U6F5I@&S&hmX& zY34%k4s7qjj|sogYunk{E#s~;xScr(6NxO*eT`^#n8BQn5>{~Su=Ce%OlPlM%Oi#l zvLg2NQFi&ny63Y;A(|wR%Sz2DGS5BpZ2H14{e0=$v2Fh5U;Ufux=)@(Ai{!IodA#J zwBqGHH^1;g`k9~qTVpxHuw#DtXZ~`!`ct2Y$5n+<8}_Slv(+5i)2jr?zxSPQr9~{* zc>%$(bLVbm8jS0jM43JUApL`$yzeJj^(TM)-_n2i+LuS1@$^6bgU@{-UH_T&X-1wf z@eC%m?711tCzKweH8I89Ipj0S@bVpp_xfl@SKp7Mfb@)n8H(GdP`M_M&XH9J8bgzl zmX4ou)>)Zh%)3VZ*XZn~BIO%j`+EA)H@=Zx*}Y@r=}-CdO(YY}{+oX%E#u>*E0=ju%^!YKvn}OKVFKD_4sPNfcKAxnrx>b z@?^W=#N9^l_y^zl4$j@ow1MF08W#S(U0Sg}>?4SCLdHDAtSAYDK zFQtdaA^#t~@m=ZjpZHW-e(@z~J^^wlm=L6_B-}oo=VL0Bqq0FN-GfP!891c2av8b2 zGW=%Scm1~C7`Yak{VE*~k&Ju>o--k7S^EW4#W}~`Y9rsRKDy`e$B?ZxX))_C4Lt@m zn}WK1`;PR@Z+t!d>p%K~G1GW4M}7jHxNzgfwC-K+C1FDU+xNp5cMPZJc&m?gqCfuZ z+uu&V{h$8h7*3KM;_$=gKlKyo&%NjUX$^x_CN+ZJMx4>_-oOX=GEXPm!XV``?A%2p zp!5yjbIr9`LXx#g447^1Z*|qqox9R^k;#Ad`@cKljHmbUXMPNfxA2rklDmENEQC>} zl#{~OQ-}Ip*@hNZ3A=?)S>C_+t&_cL|MoWB2)R->S-w!xSN3ZY-gC_@lVvNY+&LHj z&t#;5v+EzaE|=3xH>_C`>q(VVTqr*iT`z5ZDShF8J^9Hnr$g@Z)~`!H`NN+`=YH^` zX%2&G_G72MKb>>BueU3oY$56MGe7rNPWSL(`>(@K{;5xVGHoKMq+J*(TLozkDC`f z^B7dX%L^MN4Yae`B&Uy#*6d#gQ(=vmT2$QAX#1aO-=8@NlS0zy2^p>I>}V7%X@BV1 zr{b|k7i`$bx`;llBgGRauDR%i7GY&A9uB=^7`2?MJjHinb%)I-SThj9+c3RM5EEpx@AjV|Q;p1uU zfozIWLE7?_6$}^6x|zpXI@i}SMw>~9Ib`U~GL{8Cwjuc1wQJHp{-?i{e&b*McHkSc zz(B`2M~migyRIg(Vn88x_ded|Q|M2*QlE;WR<( zO`A5QU--peOn>`Ve>pYFol#umRNgsdKO;bk4kBNc4x_}o(R85ZkvRC+kNimbz)yTC zwP1~bvGNA@)<<>LqC;Rj-U*|+tHnI(YtyGc^but2{CFHQUq*mB8$-Gf-8IrENym9O zrni3d)^y7!Kaox_vNqI1r$+_-p4pW3&t_C&)(DJNo6I8_rNeQQ`M@VA^Nh-f(k1M3 zyU;S*{RAV*q>JgbolR+{d{u3T0+6I;wNVFWLp}Zq3sOGv+_Uf@>)6a?1e%c?iR`pT zABXFsF&`%(`9s&gGe%E6aiM?DKIfeDg-b6Ap7_<@{-@*Oi@ET`oTDdFBc=`w!5`0- zys~lA+39naT#!Ef*Zy*9j#ah*F$4LZiR{k;XCJT3j_2sOj%^X8!e7AYbJ0l`_EdKI z(OFiNxFD>_l7e>b@3x1YwPAhw=?g!Xe&VnE?AYa>^YoL{pLZ15q9V*iUW;pal{v__ zPw9Dyfb%#s&YtEa3Yr295bwzfFRM|+H5d76pF^zhkUF|*B?>@~P75CEH|@g@aGfP9 zO$^-TP!SP?<#5-Gd75O(U5`GT-p{&BPUR>%!hW#K{NhV5Nq_I!E7RZj^__NcCM!Tq_|k0M#(CB@i}+?{vf1?g{Ibt!|Ae=j5PTj1&R zv$ks{bewzSMCwPrqH%s7d8};DcEM(vl|hmCA8Ci!S3Zy@3Vjd<_=g+LbaZ9A+BA_Y z&=!Y;ts%PDr}ps)DpH?d5p5p=a0B{wA(e5`(ZOODu78vyzXNGk5rf^(x?)%oUY#}* z-o5EP?@rG>{WLMnXdG8Y(|KETv*6hU`;Vpu*hXdE&WDfstuaSGzRmSF-k6@Ff}@4s zH5LQwok_h}*oFD~S?E@bBU*9yi}6K2`{7&T>1gYCaEY<0ewlvO0)8vYmZy-8{WJU1 zwbx&tVH}5TXeeLU)dgT|ni*KR3a#$895Zezume(cERLm~E4wdL(uE$zTt`%aa` zk@7h51XKL;nUgTd2^k8_>H_DL=1gaYWRG9m9P3-?kbL+4d(&e$pxY=3a~gpqgia;d zd+U+O&a<4)ZThhvqn%124apPVCy7%{M-sW8Y+74bloq;A<QlrK*HZgO#AtEDkNO+ z(?uy?WoqQB(Ai{)(@ibYRTtU}Am-9Pi_W4<*IBuh{*=>%0Z17LI!=#fpwZY57ea~&9;lvn6}acH~B)j7qn{x%z# zTzD3?b>|~)|O8#@(qx0y(lgd-V?rz6^S+(cW=(3g(dw_t@e)iF) zo7Sw$`dfPiMmj$0t(yJTJ$HpOf1G6SQ=4B5 z*?T6R3U`=9l{&C!b@bRgswfw)T0VZ;<${?kAbISi*RrSDoQiCDKQiZNji2hs{M^Nh z(jqbX1jX?*g3f{0x5B&Ivp{FWwF5HOG#FaDYz5T>8=}gzuvvxu8VSJxfO$Mte6hA` z?r?u?$5RL0o!M_kinCH`yK*B76Umv+j)y5%cl zO^AYJObqx~*5IQ!OZ6uV0RB*IF!PzC+N5AdXle4v*y>s2N1giv9xH7%q;Zz($>^h7 zt@k0O#5rYZAy4>I4@QIVK2CVa3Z{Ub)rGigk8#(_JX~NXkv8rg1Y?6_#9}ZmTR(g% z%$Uy2BgK!s@N#Ot$-=_=&8`c$3Hx>S+I7%NzpAiZ`i3W8xsJ_SynJa|wxLV0A#`l| zfDta`^_}A8G|>a-(gO^Vb@U762V~RgGkrl9gl*_cdlKIrxYm#4d}nGKr*}08x9tcx zPG>wRA|D0Ud3Q>V9rc33%fepHxiZ_mkClUCp;AWh%qs6jk*-{_EcMZOLlq!d5Z$C7 zW7OIXl`A`R?8rc-2nd4``7zBwoH&H6MwP?RHBP3Lzj;2D_(&ESc3uvR5ha3CD@$DP zDpt(*dj8e{XZ(==@fwBG82clTPPrM>c}L-qiZAb890j&)Mds{j9@%1r1

_jAR!5 zW0_t&-cEdPfvv-Z+qAcgvnCzQX5vI@_*}O{;oZ-?bc40MIGiR-ihhwII-Hi_`n!*g z(@AYBX)@&oE@yhCvm@Qc`z7>^NB${e=F#%==?p0khp6K6F^mRxTlegWMbjKf*s*U9 z4hYk1a6lauRraFlfb#|aV`a@nI}g9^xo8>359dq2@5EPr@Ac&+laTVQPE9P(7C(;u zr;xEua#p@p@C_*Yz&NO038S9j8~yCrF9q-UDdbFyCgph5&`Uxe*`%rK1QkydDC>YS zuA$@970{3`vdS72K_-obd=Kp_xrSpeFa|!Z9f4zx(crDzV$#7vR~7imxZbwg&?R)4`ix{3Y!*#gT2e19A9?2K zD7Sqb`EpH@Md-PM${fi_olxi9Yrod6$R3MY(9$Qdgze&_h5hopo%m?v_rpGOOjIt? zcLooGwH#z|q0+m@WwMp=nxD&U>?TuYKq9WEnuV^NweJY*v#b3mixp2eT$;Y(;adpFn3{eI}1=LBF|xxrvh9ti`3@Fq3IHMGHR*J?Z*0F)G}FC){@= zFZkJ3@LN`;hr}rR`Thm}$Bv~x@w?K+IkXFL?$Ylbmq`v^x}hAAG>Nr~xL>bX#`9*| zo^}{UGO+ED{`?-?XISHM;U=Nop634ZLh_Um-%SU#WkBCOC%8Hpu+ zB#=$=`((t$$jeTAlCAXfQmW1L8$HfuRkG`ryhh^eQYs)Eu@`q-L;a*|sat}uamd^I z58YO?x`e^nsBF_N7~jz30juO(YP_eKqXXGUUg_Oyqk8*Oi_atNLYZgXx9GnzFZg}5yl?d{hq|p`j$)_sJ3N%jS8?9*ujOmF zZM+Z|@oyanDq|~em(MYBJKHbwfkVE$`OzS9KW+PkY^nQ;R>oHTF2ji5o?`?gwD2OQ z(Q-N}%i6-ACEs~Xll)56PeyM9A zk@xGq@h`tlWI0Rc?6g_qzU!>;uM&)F<@D&Z@0)&?9~b8%c#B7sGP3A>FJ8DH=0;wz z>Aduv`|qJlY-UUwa&=~ZKTHPH()HUOyr1<~cBG$VzM!i%`WqAuL@VSg6N&uqsqjg@kJEt|l1+moLgQ6{D8GAJ_V*CwUtdlB&8{}IVXyD6EVg$U+G>B*VWTWp@iXVBgE1bl4bD!@W7@df?zB_sy9^^VbOg*`tkh1rKEI9p zuKYQI0MDaL_{dB~>DP)@%aaJcT)X6AtF=(~oEw49Bd(|?g~ zb)K9GcK~PQ@fV*9`=#TFTZ1DDzpB&hPn{E)idOFLGU-QtKm0R%IGce4s=#8}Os71$ z=jJ6YlhE3xY`ZKg28B6Cv0v~-ff2mX_`YbvF`SHj*zW|dQC)-GG69pJDY!Q~qc+uJ zLgu}G=~~4#hrPRxx(^z9|B)>5m=4F4&sE-z^X2=;2sQeyWx^Et1UmjGZ;r5TTDva2 z|GMj`B6=Vk&iD7-5Tmo6&!_0z_1MF)$los3hj5C>W#^n5Q?u|sqpxe&ZYwIPHbjBY zkNSqQksA`gRvsCoY=?%=W3wLH{9^jlEgxW-z@~Cn(Z%&xf8l#nbh+^B1?xAmHX8{O z`%(df&W(Ajmj(aEcO7`b6OQDClRnFR3o~=pKsmc&xJJOW!c~r?lJ1 z&vEbgg$?y?$i#3wWh|p@$F&0^I-K*g90|rR>{P+!zGc~xhCXGgeC`t5*p&C{BnI_6 zW52P|(c|zV29eVy{Y^v;O$Q?|a~iq6{SUJSlM9W0{RgZ8(LXcRGW6j9R&zbX)T}!n zc_^ln`m`ivG5Qshq08_C)YBy=VC1d5{XT0`>|@r0s-f3>6tIc;;t}ai$GP_Wi&!i2 z9of`eX-Pl5>x!$XlH8o0dTDcdjrD_`dg+CDP{aZg5FGD0>+@Jx)T$HpnSGn{#YkQn z`FU6;fkp46}b3OL@C|o{`(d zequC+Yl)Az^rQwHak+o_b<%$2FTmKKUmEzFb?lfJ<5k}0RJLUz166Y00JQmJ`L+1ZK&d(j7H@jOZ$~K6>bC01_8H@ zT{l`(sgQsPIz|bbwz5B6a?UyFBN)Z6{NVd6j`R?ht5bTtIvi=U*tb1+Uo13r3&QQ= zuc9nlB?PEC0ZYTO@WOHL2GD784o4D-fVTxrb?yv4Oh`Hv$0xgN#KHLJyKasoRdL&9 zksYOT{^IjnUQQ1&QhpUdv{MHL`As7$&AXx~=|?fPbZRuRWVMIO7tE^j#964>P@;Zj& zhh0xwgELvf)+8zxT(82%XNt;X&$&#)`{+$Kr!U`ed-}m6kA!}5>W$a+5w~Bt^LueW zpM2j3qeAj*ersq`O*+}vU9zvkFi(dM@-6`r^RrR@KDBEwo{sqV3(u!>py91Iy*vCv zgEi$!O)M`zNDzGxAJNASKlkb@>GJbg6qxVBb`juYIuiPirxiS<9>>=lbfC&XiD8~s zb^!-{tSZw^Ev7m)EQY!OOjgrsS6$;sK35*1Dx$E5UBD>c8YB)fbtN4w+pZzB;U5OD zx4vLt1fP1ZE4)hpFGdGBTYC+Rku7LrpDtgrICSV&?z|)2%|gEV(N1fe$C`WFnV#=- ztY=B4yziRpBB0Zz5zom8868cSR`S)MQ=vJ@q{aeb1*U4~&sbTauXl(+6chbF{lh;x zr1R+u?Uqm3TurR?;%i&NC->n1F(pxXkaw*(QBRf?c!#T>%2)a{2IWNWsdX5P*w|qW z%*u?5KR?B~MLylf0X6l6mD0+T>oYnSVI0QS5zB;6>{N0wkF#lCW1t!x)dFXGV4*lK z*VLAtax3VlJOswmIChka6v>T2^=Kqa6puzc5mUQmGV%f9=sO;MFcKxJnD*wh$9s4R zo_^tl8#bn^Nu)R>wvYR@TBTv4$YhpD&Mwe#6zDzw`piq1pv+<guwY>rXwex_~ay~LEZiMqsZ7t z;tG+BfZlZ+u&(fh{78UbGa+H7-b$0QVdQtQ*uIH#aW=mRY71p|f-*q%C0I{EEfTWav$BhWx+z5vTU680oig224O(Wip_i z%F3G!8Vq>nQ?h*T6OX0;=PO@Id-m)^SI$hEShM6J9&c-K|DLO_izybdX$Q|eUvRbBIlEFmXREB9S-G7y=AiMGLlkOZ~gQC{$JC_ zZ@wi9qU=ACXgVTq#YU$?=%_wix+vDE8YJ)>3X?L?a>2Jh{Ki+~HS)ac*Q`xH^^uRI zPrUa7@k9_SjiTzX%NUh|1{h8BwHWa%V|%hE@d_9|PzVfPY_)RwWD9nO}>R{Ii z@X2KV?-&0^y5#k9L*7jMS#{W0vaq`WH1c4;WH2&_Pvb0tGVl7yYBFD)+UGv@iS%Rd z|DjkA(xA=+TDLKp=#Rh?I~75l_BCvkZ&`qO@FF$5|C?p7NSeKmeKxLL8^Ijj zMCMPh#!8KTuUZER(9zO-P(N-b`iLykZl>87B(G(v=p~yrh23&MQy)U!&~9Bo@NCtp zNeUBKKl$Mwj*}adYqzZU)<$%r73XZRK4UFrARy;pX*V!!+qo_M6ZGv#jVW!v`WL;;18Z?ZH(|CU-=YDTmEZoclx7EFzIBKP9*2q z?nX1-jN}G`>68<5O42%DbOCjr0C&lzb7I>{t>69f7t>DGsS)^1mtPsh2p_rOUFl-> zDeP7ZM4IR~&K6^F-`xN^!RMfQsoh%>eCe@Y&h;!>`wpam?YnUF52hDidMQ0m*2_il z9C`DxOT}%;bS_Y;aKFT8vx0l=g_p44;;Q)GYL&hfOEZPut$@i}Ls>KnjJx)M2g_*k zgN)`AFTTLEk!SOjhKB^&E$4oT%|~)Mt-;l)`&KT4fwLesn9uWfp&4klSXi|Q9bEu& z5a-h}P#4R(?}^8GG#F1qVZQuY7RS=L8?8k~a@ms9w|q(NyuX#Gk9MXG-_v(Z6yU%P z`kF;Ft$GlnfrbHu3y$uj>^4?yhQMZL9pM!Bxg=I*65u(O6q#M{HI_4{m8z99S!g>@ zse50c4opNgckfFB+jfTGd7UW(mf2Yq<>N5VUAG~sGukqS0+jwcs4{WX*=+L-B&?R7 zwJz0gI4u`6*zZchB)?M~&L&2uxpjLiS}O0_)WueADeJFM-rmGHj=(ejQOvF(;6kN7 z{%$_TzP=2efG=VWe4M2_e}ayI5gd%xSt{z2Wgfs8c2vep{Bjl)j3>;Dc1e^%gORib z^3?!m*AgcA^(F$v~68crA4!un5p_M>m1WkY?0L3Qr8y4J^X z_G1K1u!&C{&}l7!(QfZ;_q_o=3F-mpJqXRGVEnS+{;FP#Vlg>sVU%tC8 zP^`m10O%Br^9ZbKi?CfQ;a?JmD8=sSv;t$;1#FkW9T*y1!!@JsAtAa0og0sm=KE*} zp8NRZ3D;OT58b)`PWx&yazOhEMu>(QL|L#nUtU_`+h|`e!Sa`6$kWo6F&bM@@N5M6nzn13kZgbDR)WlyHK=(fV z!t<eGdTC5G>h`l4YPQwd^$gioC#lqz9#5p8*)~{Zh`bh^@_*6BO`)@9lUOfwRr^u^R+bDlGmQ9#Y0$dNMd z;>t}HfERDb1SC-2t%7{{jW_TCwlPinXy_`d2wcx@J+^-niCvRov96Sfev(+m0Wome zt-fv(W4dfWb98Q##kTduUwh@1^u*>Dq8j5YbiTfTzBwEoH!M%MUX1GkS-ItO-WxBw zGA(k-8}il9BE{;pZu(sj(_mNShvs&0a_y*apih1D@qI_K3J;sJq#47*1TaD*~RJ2_Vrs@~k?+&p0G*x3)pQ5U)8yo5q zHjE9?S6!3;ao*op4?iOlCSgbr#&`+^!ojLuQHCkwDF}tYmWd*RX->>|$4bv+JzbQs zu@onN)v|09p&O1;`J!MMK$#D0-5D5lpj0B6Z(BBQ!S%1@$ruPP91Fvh z_bp#+4BKE4PdFH0bTn}g$IX_FervN)lDs-Lu+$mlune|u%?jrEAtX|{ixA_w&{4$_ zK?b8Tj(X71VUP=)8NU45`F#a$jkX36ItE6QCEcW9N5Y)>Us`r)Fw)$6Fl5VMzexJY z6Qjvk^alYo**26UiO6M)0yBbO!m8UXFzR&b)Hzy*GiM7~&*eS;B9fKapPWZ@ir|aJ z3KJ0o`7FlCfw6lx{}?1iVi(6aFk*0GM2kzwpVD%)W?*DRR9|2;p`+v81B@id0wXfk zB5=?2R7MW_vu&`YPYJKcT_m(lQZneMQS1@M(P(r=ey={mZkeQ(cU`DW2gT-m-_c|% z_*%ZTQVw0)z`-|JqEpiKMn5}U)5NKclk=^`D29`CZd|;X1YpR436#zJ_j*=+mLIx&(dD?dBYC^;DaVPJhk&dYGOH z+4AZc=0e9h;_cLfK6GhRn02WNw%hkPmChH5YWZRtBcE2LqH}|u@&m6guGvI9E8%8B3ddgN{%xGG0Z0X|47QU*3dl?-UgOvcZbZ4c#$G)i^H83`Sj{!e0 zv(}wDH0+kXUz~k79mUNqA^#4TI{nB>f~czLG9AZ$b5rz42&Ybq0TvQzjFe?ne!0FK zi8|;nKlqRSh-q$AFx44I>nG6f`qX&(H7aHC1(hw6X<@sF4LjSe;rbvJJ5|0|ucgkD zry|+kZMy*fMXZX=8sKSEEgz2!-bhM7D|uIu*AjHnS0*k@M279BdOH$*J+KSXF$xpj za4-$OQ2#pDMWgwJ{u~QH7dpz=(7_0F)pTf$U!$FIH(62AWonBs-vH%lkh zOkl;Bl^h%7N7mc2wQdJX(u8Z8J&lqs9XU$HwI&hFiARF4Utsfhw!>h8CoA1vnEH5B zT+CCaAL4`_)8!^clSYw^wV$QL5{kn{8 zI-X8CS%Rr@4#jETL)VPz+%bG$%Mz4)R6)A0d`hM{l2ePx%wuf)zSutw>e8Pm9R#ED zuTzW=X_g})nmKbBZ`Z9SN_a?>_sSg>e zn{?_C9doGDsPVDelfJGG4_`yDhs*PeGt;;aO ze=_M=#z+Ds<6AI^8Aw4#p~26Q;3wf!*LGo-V=$7z2nVZrVi%+ZauT-DsWe>M06yp_ z&CpL_y9zpjGwi>~9OWPXBr1-vU9P+3Jyh*2C=fI^V25w9MOOJp=MW_i<@Z9bxjpoP z_S!D=nEEW@u3qR|FB$q1-o^J1o#_B8%$ng6YwphwkHvRl&ano+lRPuXcd)M~c7bb! ze4(G(3wTRFWxc=&DYaYTlZLTB{#nQw)&7m8#$8muSQXg|I!0G=UF9FyP)C{6i6*_o zSI`mpD(IMD&0>cJerx!zF72Rwo(ApAtS~7EFEi0v&R26}6yY@!7TGXHa_r<|yzi%s zcmYo492SI%wGRvgbBA?8oDdbZBj73*rxQ7%V96t&6%wa9?A^CNEnK=J%@v~=Z-pJ4 z8lH(@wir>!FwZ=vW!V4rCquE9A@3I;Rjs)d4&ZvSqdz-jL?| z_U}&(0;`4aM8nZ;oVmz;o1B^;2S4Oq*;4+MtEiMhxr-6m+sAZPRGC|MvP;Psc{;jUbpCL=}3R8uL_Ln(Ien<+K4)8 zrkJ3e0Y%o=^tnl=3BcHE8cgK*L+3KO6!fNCIp{6@b?~qgmNK7Ut(yg`Y0_`NOyCit zw(_lOF_|psb<=jiyMYm%8+J>*3&sHgC?DBsP&c0^{A3ul3-Cj(2POhz&@nI$@otl~ zjZC13(Z^cOf_8IRcggk2ih1cFV>tfdaQ2#x&?xArT&Yi$F`rN+9l;qGk*m4b*RcPg zo3em9f;pYys+~Fov|Tv(nXgcyH*{~}Cme-uk29@z&YXE^CP9a~ECzk_4e`CQ1a*)7 zg}_L#q3jpFmpp-6j{fq5zp+ zHSq!lrxCjuhv6?MeehH#uHat8?|Vyq7nT~Xe~^|f;gcJJ=qySvQ?5Q;mgFD%msT`~ z0eBn1LArsKHqP|OB)k+#^S89vEwu&SDxbNR_F%Y^N@)^He4FTfxD59r*P;3@-t3eu zG`wsr%#ic(pSRvl6fQ@TeU1oLnI?RPS^u&O3;4=W7EoB+{~jqk@uZXxYXD z90B(2r&C;6*9QxD*?5%jD$PUeSW?45++~B_BO20M6Dkraj5;%&f!ETmrE8b10_SD? zR(sTV>_9GPq?T;H}M2hw4sT8AK_PG}!@>7u{*d~ABpZ5qqWhsm=jaS~k> z3VQ6abF>BNW?1|%+nrqlmna>$<#JuKRmNU!*VhxWcJkgOC0oqbYvMo?<8$3w#Ijc) ziN~!jIqzCxiL`vYHxUrGd2IaJCzT0i==`3kKsJP0H@#ddel&)3UG0*#d1CO$lYs7- z-8vC>bqj1{;TM^dxC}cFx;W`UAVgVkqgOWKY7yOHuty;p6Pi$VXRa~N81K4&b)I|e ziTI69`M|D%9~>LVjzrZ@6IhSdwc)(~;=NgPs zFc*pKB|rAdtRnDjdy1}J&wIyHEm>B=;RH&7jCfj0>;zhOEffXOv+oFze+)RXq9Q#& z>V3a*3*uBR^Q=;-+dV^>$>wGf&)YehSH?1Ghcvz*Y{mkB;TyKZ%Qo+}p)48SE1Umr zg+MzU{IR8ar9q@@iisk=OGI#bVm~_vaEY2uZJ-II>uq#$R(FGfxW=P6-ul20`;H)9 z2+DD7ZiN~FB)}Mf7Tj#@!&Pyrrkjm2A!`?17LQg75zUo7e?BUF46XC_A}t?;Csf0Da)|2-?)CZ2u0_P^3%Hne zHEq;K{2v+K#D+iXpA#=UK?O}t^FvJ$k0~-JQVVVd`+78zak)GXv2OvJrGCS)r?+uj zlZwNuDpLZd0@mtQ*_|xs=_>MhA2Iq$S!fNr_Q{)tWE<`rnZW}_lI7pDdaC_5PW}Op zt8S%z+1}1Xo)>PBl?xoE%EXRMf|>?a$`RkH&gXzEMSB-3Z^t*?@8ic$zU>lpR=nOx z(lURFRI7rVPh>uZTrK0To0zMOkDWjEzs+F@Y>p&1A~sn4E(#FqirccS($0O3uij%y zezxzvE<0{6Bb5(st^&d>Qt{W{p zLjETH&(YoY(U#UX9uYj43ILcz%d06mC0@mm)RLan(4psr;~GtUxsNsf#|O&~ zrffQ*lJC&G_U4JdsO0dun4^#zv4a;{9Jj^dN z=oi8E?vp3-d2cED-{TChPEvz{Rm$u+vLa>S2oV73Y z**GRk(S*hj{~br8ZKb`>+iuAzewNFPVdK+cZF9A|LBEYCD|BlCIfPpF?rx9$!S3@l zUkCV=2p!v4#O_U+dm{94RBdM<;?7oa@!h8>3Mi!EW+mQftvzmK72k?kZ~?UZl( zpS2UKW}CC7Xi<-@m)njLuzw?dZetzDasdGFCH zJLO@ytip2kjb#y-eUox8#W%$_cchyS7Yb9S>Y>M*b9mll-*}GGDR6K$gGfzD=16_5 z6q~qrdjGrjL!CwT;MS`4-FXUJWzxtb{5LK zoIqrsMf{BuSwDJqN;KW%F@CIdhy;z;YE^2Se1OYib8Q>kH&|H*W23WZBWuU_VfXWi zVxK>)HoLUO+{l{3xS$A;INn!?&wp^J5_G8%Ye@$D7<@>AvLn#rJ__xVwuFoA1EP#= z1hdNdvRk^73hpf2gupj1lxOlR`N2!DXCM# z1=t*DArkVGH>)%8HSf4g2%bakCFdKXy|9_*k?0vVC{+imStT(cQd5;o5*Y`gIMToI zO7Y0dI7a0Y1gU%~*93V<$eCLJtt#QLMF&cms>gj)@C(aF#$l!)jnUf5v&xuEsDBstsH*h-=V=5Pp zKm|m$pU~@;Z%?i_hxyw2f0GKSgr@XaEFWosnRH52mfd>BUfuz$l;pcdekA(`V7m;Fdnf=~XB1L)W4c;$3I zRPqZ&r2?ZMabOq_)v*8k&kd!NFcl%AE@ulTx&Ozr|E0lEC0{Nb5bd+j`+w=>e^~ao z6pN;?WyE&=+ED#}4Eg^Vv#}tmAtX5_)bh9dKb`xJWqK_3A55ChQ$?kB#sBR1{}}TB zpW2XAAcG+X%e6Y3=Kngy|IKnf0aCrsg#X^kRr$Z3NekT%YK0-Gu9NNm>&pH2ZVmW< zey)tduwL~=`F}YRe7yfGb<2n*vF?AdA)i2M8AOGxkkM^}qj!Symnm;$5b*CWTk+a& zsnLq+Auu&7G16{2MJ^v&7p?qeE{M0E%;h!lny80aw2Wo_9@W|2df)tz1f=)mHZ`GL;Jrsvinx_Z#~`h*+N$UY@YMfiHUI5j2Ii|uVqOw&^X*!( zan4Cwvt%W`;&2CLS@!g*F)X7ezfjklWM1xQ zl!p5f?3H-(F_u`N7atoJtV$>+&8*~xlbR+Ypeo~x%S5MuiJn^t{gz+au&&IHvKX$_ z4WJixjECnOA0aQDPYfmsptr+7t*&)?^b_Z#8qpf_+K|@IaR*%&3 zF1O)}*i9e(5P3U!h?|lv$;!xXkWGgZ?oxjn6-SoWd>^9}78;wAd9f92N{qMjZ) zm0Z?ff0wfDt7@e`7fJth&Dfxy)ocBe&8pplSX0>)-`r5y{}3xrq4H=#B927vQ+QB+ zEeaKW*ah?UUtO0-k%$FwKN_nXM&Baixeh9==67Bf*jOYtmB%dP+{%!#*z*x7DSx^N ziu}W%3aZr&^kW(Nfa#o;Jd$+HsBuxmc#bWs6H;HkD>ODHSZ0pedeLx7saW`Rk&AGP zyUchwB=R2R?{C!jAepT?Gsx?aS(nihSIf6`nDW@T#5!XmWTA?f>eWZO{HZ~+(e}s( z4QI0Ip^x{zycf04>hZ1s)?d${A>HS_ldVWQYJd?=Z-ut=+>!t7aD+A36>OX!%vOl> zuQkaCFTd!W`}3}i<2g%XB0co{d6o3c$9hr&-G&1g2EAo4aF zCzHKVTI7}7l^g3v&M#H%6xIychDr-qOul+JS7r_}7cqvMXD6E6c)2f@ipDGR98W7< z?RzR0_P48;mgHrcPrCB5BrBdWo214nu*JF#*jSxb;(vzdH(y$pK3!`#=!_3R z-f+G3up*&K4`+o%>&Log;W6HNe|vlxp5yqept2DMS9e4UM-S{v^18uDy979uoD* z7OFQVxkjy&g#LL4hIEPAP~ExQ8GgdCI#==uQ@2IQuo}TTSXMmtkw<>_ zsVT)z|_KHmbezQTAK!SYW&<{Pe;Z06)Vg$c#A zQ@I9Wysb7Pfor=?{cI|QC+Db_t6K-)oy|&yc0)&Scpz4ypyLE*<-6doXb}p2vk4E> zgq@3jJ8mZtH7Dlv`0}KmX@-j*olfJT#_|E1&}lX1kI{X|FZb!xiN+lW;&q@z9=9xG z=+~`%QIFO@ZD+Az?S*&pP3|lL7BPB?2@NuQRbUPd0i7b}JtBaTiyM)1T*m;Jw&(~Sfk`1TKfnK~sdmyPIByeW!z`Wz zA0hq}*3BsU8|?y~giJHz@C`QKFN#8R*La^+Vp#-_<%gidn$De?*wxR#8+y>(M37|+ zWye}&DxoAXud68CweEYr=h%}tfDD?l2+l~YnJqp+FlPTV>(J{Gzq=IGJD19GTx1Jq z%9v+-{=JN=%GxF<7HkW4v}dn~<}kMX5r}0XQmK6WfNadOO3W{pQ-R98trmWQe9*91 zS(EE_BxQljPvKcZNzv(v#z$z@Zj~q^e<61I{T-8yTR{UWUvjQnb3mo52`OP@lnwUs z@Y-I&Y z7djInM{s8waTGxZTun&83a6@2Tt1%6?vkO@-6g@QL-jOpxE<%CVnEc>)rX|LBGAz< zl0>N?4JYgvq|#|kQBtPZE=n!?xC||8XgpYsimsUz?{J08mblM3Yu@03PrUZxN@@$D zs@Hx(CeA{Bj^y;|&yGF$jC==$jdj1-#M_|9`QszcN&q&xPh|xuIp!%;FdIyr2z~*J z`L1!Uf;_{vQQ!Y)dU7k|XECWvMFq5_;Rv71eNXObe+~z&1-aj$z~JMrOou3eLxmO8 zWXQ}7Y)0^CsJ;J*%N*A}JH<1cLAr#dX88Xeiuf%`8oDv(zu?BK7q5PPJLcxOH59f> zv9eCli|w`mT^sDHCbx>6RrjHBgl|u;^IUhL!WKr7L(?!gGx&9VVq(RUZv0c~MEeLc zM6zLEtuX0qq*ttx>PJSfS`Lws9O^_l+t-u7mAqcF1$u;GNMhWEg-Mc}@CfEp=R`%? z(pQT~-36~06d4O&?0*E7DgZT^$HdgV0x?!2bOiaanjnj|q#)%ojX})4*TMtoisUEk z(nJcCa=WW@+-STKMuVf{^pbuxVtXS&i{F3p4dd?dTF|?(rbG{9vM7=SbnQ%>i#1b> zGfnO&`!QNug-2W#wF&11#U18YNi`zgw78DCAcgS}9ZZ*TsfH0s12r+eE+~{V< zL}D@d1z=xL{M4S1k?TuwwXoXCQk%7Gu0;y(ce={Gb9KM;aGIvs5_f*#lb6nyDKJ$C z{dplt8?$kka;)p3b@Lq@l7fX14@$os3yM}rzqRMhK6&m*AJs6#g0u-#Sut*tNVN#| ze|EU;S#Au8#D^fO`9sEL!NB&c!}3w1S#Y+L{v3rA@W)W+^Owy{!8@OGixY59tyL_m z@-&A~)IhCQhJoFyhh%fdr-o(l;Vmb_t^_PS=Rm+1bzv;5`^0YOV4n_(p2joVWQ{VG zITE@oCDeZZlXX_6>@~lwYzraVH3I?K6qCOk6o=9_Gv7#z8$^{p*Qxe!pP4SstX|_y zA0_a=++c;2J%{}xVBA$SaQ3vFc{=lqoktRFfGnl^80wa;CK6w?TR5+kJQgfJqyCji zM7|K&c-iRK#_87+Z9UezYw%W?nDr~cIiMX)AZE(qVlyvOfF8&h5q5>FU|{b;I+lqF zz*|E}M5&*&=f>%ZA|5O^4qA+>GGBs`05F(C#vd-`kfYOBHb_F_O=ek(4G5}#ojP4z~GXq?E&<6aLV|D>G;tz7t<&PNjg`z*SGD z0A9i&%KMl`8u`A5@as_K806V~`9J(Zv)G3Z^Gl(gJ^osOOHNKu3bg1Ks^v;5TvzQ> zCY^=@%oV4YBoAYq=DmSU#a(RL*d*Q+o9nZ*$wSXBpSN)u&%_O?tx}Kk<@!#gb!Pe0 zD!t`_O-9i|bjh5I%MXEv&E7B2-yvI|0{;1q_BIJTndsl?_&`D}W>(YD#BHpL{DqW8DGQAEr%gdd7~<-8hd2Oh;&-#wP6n zxMh5vyE<|}3|F%YR1J&@JVQWqVd7W#OtLEilW$V|NL8!}8SjT;kh~Xo5!9XNsEzD_F8&%i;VJ{&FErqOt6OI-o}r_F)%A%>fo)ew zaj75I6}uddf}wQ!g70_{^u<*aDGiM4&XgGe;sLQ!Z-^R*udKHiy0k||F~ejxd+iATYNEyX%PMe?T{Ui)&_WxWgaR=;4iE=s*Y>5`I2o>qO$i*hnWXO zBPzYyafN(Pda!L1-B5X;MRZ}bM^%Jj$KC?J*T0WPng6=bd8`iwwZu)dFTcVMzYWK& z6eQ77eGGh2^64dF4^Lshv$E1PVO;pqAgq2Qc`AZEU)ls4ZI!efmL4-gJ{D$5rHV0@ zOBsIFu{aNbb!y2tI>#ZYMN*ygV{GjE?q61aat^|I@L3>$50A+Nz0R(Pe3t|pOMixi zuIw)?5q`L#1#vJNbXxSuiS20N;9>{`3JfWAtRs>rIKp|yk=3>Ws)^CW5uX5<+>ns}hl-a-0rol1EWGz!;+e}=K zI6^LNdBBZu)aO_h4ErbMXd4T?bviH55470DLB%ecv4f(*gDx7<@VpUUh5p7HXvz}o z-;R<;q4u|#qWk5}Id0p_D%RFk^%~Xc)dgzHfpcOU@gt(J;<}fE#-Xm2v#n+~Ol0fY zKYQqD@i%smr^7kn>|O6%qJdN8KjZ~a$mf5|JlE*f{agC?E80Xo+CiHc91nNFCKYAQ zQ-|}0+wx<5Bp}&jqIHvG%@q{jhD*@2D5cnJxnOk%Ms7&g?p>x&XGKNTiu1B#+o~MY zw=@-Opoh`vVfFxg9inikeM+-SrbjyIh2J6m3P^C-zolPe{S^wkU2WtWNd8mjx3TT7 z@%8%R1%{BL>_rrbV=mUv2Z|>`q@aAME7;e#ery<=D8p4+t?7#4zTIcoS&)&nS{!;f z)USmejNXx;=>T?OnAf0M^AAmrqO>_W)H06=Vy+#O&7icl1>iPdJ2>|kBS_DIQg2Js zT42TOuF~T7dKJykm>)mjpn?}Fj{WN}KP$h$$kF7~Ih!jnKlhl7JUE20LJl#78r3+U zML6X!4+qeLpw!4uaulnGd!ml9@Nb6hw+^#T`|P!vIzK?HT)5WwD{2uD|D(QBvg!Wz zp9iDvb3qSLArNbuO`GN9uWGCLSIauW13qG;b?_oT_q)0#C1Fa{_(<=~JI-H2bulJU zJ#$)&^PiugQ(7(YzKjC8+sH<9o9Tr9vAi)AX;z$@bkuSjAtnqV*0gH9k{gqHBq`wS zex$SZ*N)JMyMv6L@7%C;*u79WM0~Deu{nIy-hU_0tE`@-q$KJ*aReG5@jHRC3S9;J zD3Rf0==5JhKNx{ET_bypQ$^3{L?2KF$++rWoA}5*$e?eeb&-|r z670HrEw{#1AM7vF(VAB7)*8W?ahSRl*+{tfld0}>B2DQ$MvXArq^x^Bb%Xr89Lra! zc;4eqh08p1t=SAHqLLc$g{rHnvwh{541W}+Q&QB}3*Aw^9DLk~C`N_hoTR>C+55~0 zUxVA%SIQzA1r#fEM;@Q~`Czg5eCIfSKS>}DsT77s&ERla-^xFxo!fsLgvYIE`hkDH z9H9A(T(np=a!R-$cNe{46q?$W%3x5IMnEZjH7_y``2+}m3Ci%CyNMseiOb=$^`?gI ze!Go36POk;ttH~cCnAHneP-4V*l zmiIGL#*_NKe|^~}eedeO{_r_;zJG_?Vkn^t1*lLiI|>Foji%{0>A!5tRNXDo@KJ1E zWorsPu1W&-cpdN!_Me{^Uk%s27d5z$#qSc#B@05)-d?tv=YN~ic5dbQ+^n4?0K3Oi z*GspH-Vd}SSj?#?L$^EeD!4@)w{@XU5M^9;RGuOnDF590N{>UhkqTPmJu3SfbU&+# z*xGlT+Uayw^mM}YyHFd)Sl!mXJ#h-OcZ#z za8_PvYe4)Twm*UAbs~;t6k;i^W`6rOWOyj6#q$h%`uMB75-@g*eiG z-4kj04fn7%T@n3S`*tg1bDT&W6`Z^xMSRo8*q+m!yL#2I;eL94zXR-6g-B^4#NQvI z$8r^vi=^HAiHt5uc&;Xuyx$5(;R4Tf)sLrN7x&hCsrwuZPFmXX3V-{jvuSqexh;mK z(O5MuUswXQkzm&8ssL5AoWqye@_n;S9DX|HDh%U+_;{uJV`%9-P~f8JTleALW_2zd zN`f7|Ao8dsd*6&S3qX!D2oDqK;}nvMaKPbVW5y*XqIvq6B&wo|V z>mUW4rZKkU+-nykW+nSD^bDV8raQM93UARCls__o^Yq#h{DOreIRWsT8i9-4sP2a1R#Q(~3t;;{}?KxcdB#;Bm06}_~ zQZ{lFfGDz8`4}g_uO>CZqPB%XlmOVDVTYmg--&z2Tl0PFDIU@O>l6Am@P@V;zWH9q zO$UZUc}mpdL$o^sRvuSiN;)Fz&rx{Yh*G_lSW7`|lK%ZM+w{TKt4T znLzxxtJGZS)Um5?|7Nd>?$El?>6*26P-D!6R?f4-juoeR^wsMUsq_FDF*RbX39@b? zHjmG^`$7NRKK!yXF{S&%x!Zto@P;>=X;7@zknlplZNTp)Ax>w$P&&X|^`ZI!C zQla5~FIzXj*3p&_C7h_{bT5?8nQQyX0LHu@_kU~@`MdQB8a{IcuA333*zS1%@@A2e zP7Kg82cRk2{w1kFJA>QBT}M?lyzU0H&-#vv+jE=kPiE;q0J9y%gL>ZmX6@eNJl^9Q z56g*s*@16Hk`fpO!8i~3{k!P2} z4{-NwI`)iI|F%0Xqkvs+6RP*U;hTozn_y`maRXxJ2)Nm4hb<*BE~daUCf0hYqC}UW z@~4=74FwFJBVde=-0HCa*tubct>T$XNU;#8CdI}0tO%R^r{$FFD^K=3Wz?_nGgD72 zW&X(rXn7)O7`G^#-N*R|3uqDIwQ$L7`~u^l!$-;-b{8c8%_RqMu;>Fana97sl=!_E zfLZSXWk}QZ&N_1-10t%&;kNnayPnSCoYCFMqF}{#zqTn@ZxcAoC}h`uMN^qrFQX6T z#vr%EcuEe#y|%MLu~7DcnK?n9{-^d&whzso!x#Q>RJLhUCmJcgVUV?Uz9s75J z(-Um6C3NZ9f_8?l~bo^FK-jE*o6O)7HHBWEb#L z`xsqR%L4p#R2%&ddBcmj@N$3SbRnq0AllB2E1gPgJ$*4&YH%t<-j-wuE&;~(&p^aR z_>Z_h|J)9f=s01!hf$X?NgwJ0pk{75dq5n)RZvhvxuo>*!C#B81mEcAdaMWgka~cJ z&z3$WDfRezRD$K8i8Og!wq_~E67w2SR@70w5s$RIQl#skz0HNTnN86gac6HZf&w5& zjDDy*Hvnd6-8=Tfz`H|;^@e&)=wQ5dbjPY`;A5h*3XD9~vi9CAc}v$jZ?%fJl5-zL zcCc!NF{vjg$L;8F67dU=&sMUDSdHWxzki@L5GS8=bWm5iH_eRt%kI@NmAPqml%4O_ z*i$MxqgT;sRQciG8urJuN)L!b7|TB@mVzz!rnws3^;%&*8k5E903#x_{p+m=DCZf!R&7?Rl&6FDq#(swNU}Yc+s>+$*38Crmg#y=~(Xrr5WfhSF?C|84^S5NA zhX2DVJ)0*vTdQ+QdAkVrU#y(?6HkEMnR_O@cj*>yqnBtB!IcKT&T&<6M7HbZ?lSsELg!S+W9sp0 zs6~aRc6^gQ^uND&Y_pP{`&kosP-8446kWI}kal`Nz1iBegMC&^M!}ZJP zElEj0AG+!Y1>pkrXa*reHFNwSSQ@$5Sa7z2k zFfF{a)&<@;KZ;&41BJ2LP-=G$CJ@Iwe+E+MC}b zwdhNevM*3YRti^Xqfx);MpfdtXsPtxyM}%jm*yjZdV^u{I1_Lh=z3l*ih4j@d7IO= zZ@b}cuJA2qwWOGtK9XrN0m}9w5_JZwK;c_zA?edK2sjg+@T`x90HKQn^|x}$9xu{Q z)!OEo=?ePi!Y=74ftCkWtUy>28^}i_+GJ>Ht~|S*IW^5d+0-b>JiNS@g;#?9b;3?8 z5F>$-V;^P5-D>rmPNF%J?_w1mF%CHCM(vUQdg3Vnk-YgT@T$sP$FrDEvokW7cjJk#a9B4Wdg>5GmD(2dVBzmApL6yOi6`=gIXP zmf6*_-`0IOrgG`KyCo16CDqc6h_9G>e%QAB%&zwFAR0c7H`nRi3`YSj94x{Sy2;89 zyb*(@<()SUw!u!-e5(u+P;0L;SZS6nZm%=;I1hFCZ)hK7 zZoYt7@TF*BxPBSsFwOTgy)G)S z+XzSnDTUsEej(j?9ugSjUW@=csD8y5at@L?DPSi>Xkts{DzHn5xI-vNqEPCqfLBoQ zI+8;pzt;#tmzhhH3ePoCON!HrnN*gRq7OyTmAxOAY=v}RDbQxb8sUE`r%S|EG69^- z77IG}HkZ6dxjPZ@XGD+zkbRLsU=`u_9KaWYe{B?+-EktWy{lPlEVFiURQ%wQ=9u36&vv=%N9yvojvs6&C zN+z40@?58{>gg%6emycoaYa;tHJK8=Yy#%cfu|7j^ae~}LC0`<77%jjP7X@+gqrAH z%<}et<yQfa7E<(Us(z8g8sAUPzjWQd=z87w1QH($Z*Y)CitmDv6Db_h?y|CY?dZg**;rkwWni}EL9qceg zmVfs|=w**te!&g$7QAUWn~C&&!b1>2lSWs_djd;WHvp5p+zIv-df2GaWfR|?mby*y z)c#(1U-&24>U8&dr-)X9F~mLu>{kl|F3tS);nR6Zq^5gDVxt->gp&c7HjpEROsNpP zzbnP%dCquDI}!F@2fKY1HW~Qp6rZbpJ{o>ujdNT^xdI75sH+Z+HULd{VFy?-r33AT zVI9{MdtzoHH_(^EGOgce605OdnL12i&xY_*J!Ug5Ua`zVh zC2uW16pXbm!FF{|EJE*InqN0`V;#mk3Zh`B@ZK`}(om%}s{sXN$OIj--~#wh2hCC> zO3)G%W_TbiKb#+GEcMGUloZeS=qgji4|(T4#}CbQ_i`mnjp#aKJQa?033}VSmC7=S}f1 z;&I@ZBEBBi9{uNz$8g~fBBsRn`DBK&*TgS6Tt$O$`iIx z^gs*H##_2i>LC&ySi;Ob_(mKZ`wYdqi|g4h+mB%F*iK?>v>AY$g-`s}>5#FC)DNwm+@3lBf`fdu{Tn;zJ< zA9RX=8s+W~N^86uCDw_{W&B~|GMqd4`5!oE*vK@1vuhWZsmybu}w^>2~yIsa^e|5n>7ylGWHl$+u4{7RG0`Lm8G*hpsFfo2eYCo&{%! z_`brsU#(R*jY*XZvr2(G8o5_n!6(cGiJ9*lHF}t+&*Nw+ zX?%Bs$Ztf!mGQCXh%6mz7dCr_BjE5uMmg{Hotml;SiOa>CURm&RJQXwP*gt#y>rQ# zizs}zcBKy#2{rkkhV`5IZF25~o0xcU#e1`j^>)?x1Smzy8N|D1E);^Ai>2&!mzsE< zcC%6RqWx!jNw;Sp0mWI7-C(T$*YN9rJnO0@GqbkZ5_(;Hp`HTQIKnbts8gp=Z^nE` z*XwTd{oF$rwZ8t>i1>ejlP4_5TUaEE-9q2Lj@~S=WFMKGTnP$ zBQpBFr<%qXbQJk|&UjOWK&=192Wid}ugPolVjSq7v2ak(6m~=S%4>bZ{G>RYsrehF zlVX|g>|BywApf<|+^x&vNM@h|u}!z!%dBAyF7DM%-L`0>Q{h05!*0WI=4f_1{Z~pA zrS2aG&)oyOeL(K*m+9pWrLitdSj^WMbZ1?cDbRZDs`(Do0Vv29kt(1D@8M#WobQ{Q z!}oN;!(5KezYK#RxP=g`fk%xea$Uwt@{1+;z44lvk}|p*&Z1bJm74T-Io0E~k=aii z$(qPeUSEISR0&NBS`V|_ujUh#K}bE939uos(VzUAKDu8-wN5*Oke-U;C@51hSFUU9 zkvsRBt=gngv+kAmQO>`9>TI_WjXkKVQPy$*6g?i_@2J;o*l9@MZ zAU|Kieb-!SSkGL&^n3#!N8S#W9OXaly6iqz^0qp1(^>5pI5<@Qt^81wCHH6z+2F%y zJ)1h7gRmvG1xM8lvYe%a^NKUENg={?EdC>R^S4;4oxiHi6j4vLhE$d(AvK5Jaqj+d z)vohodX~F*PFFa4NVMV_333t?${jPgNU<{g(lV5xTJ}j}Q6#>+Yz(XlvfSaWy?@29 z=dS!FX8qTTn%)~3a27TwN>ZO(Epwzl*Dz4ow5%OL!U*1qg?h1>-s$%v@>Sp^Eg2xx zsJZnMVq!j1Gd<|)Y(LlXul}F^>fhJ;rFeXLOE9wpoEgswhlIU;R|uaF1Jk zhsR-Y-6j@G*N%$g=P-(P7kDw>m^@c~H957gxdSQ8G7uf%|C21CpNlzUdfB#u2@35v z4^_XII%GMNF(sf}lN46?fgC!H4&HHh#&GRseSQw*q`zbFKl%Y*|#oxbkNaIrNp$0|q|0ozek{Xe4pjn6KW%e~nev$PpPGiBoOZkDN zo|ZztdRpA{`Q6?#jCiWqhwU8qwfUg);~dsC#UOKzmD6rlDvQIv>nKMoUol`q6#1U* z@WwVHEU2HGA6w;x3_$eDtc11N0;;$+*0aji&Plg;@oOl|c0bnu@>`vLy2RNRQqQNb$wCQt5rQYH+VG++l$=Lp#Q-)Md;bb)j1^_(V;{lLd)egk&? z`{vkefeORZyg^V0K+g?+G?35Tl?=}LS);Mjgk}8x1Lr$!=y?BP4Q~Ox9n)6`SfUab zXUCsaqTz8??rE&T@97~;sVl|Y-<()@Di9XJ7rwTODIx3ZNNe3W8O|pvCx zJ`w5zIe%7h5(8~JX5WGOABs9md$g2Yv5)ShM|)-o+lGkmaY@wO50{G-AmYoSPyxHIaxbT=s@JpZ_;S=lb7*;Md{{`He z9GTgJV5V3vxg0b@1sH|zDj)t9nOJp@g|$t&dzP%Hk$Nk5;Mz+y$cnrGFFo^wUZ3H7 zylsMPV*O{C7ck8ia^ADHKAO*kd`8}ffnDs3L?4U8B9@;BXOzVzub5lSZMCVBkoUa~ zILfXYko)=N$J2Q8!@sMqdDlBH7r-T}m{-Ck>lq`1f2fMRwWHQ0|D6Ru-K6`YLu%Wk z{;@=h#Q=;M@htLj)y?=L7np_(Hm!rg7C=+E(ACp6`)yVF(0S=U{JvAEU;ld;^@W}D zNus+p&v)y^_G(e5+RibBaD7%y`pEb_!20bvg(LCw{IcuG_56pcnt&(sypRP#ksn(I zRag+@CZe9yk@tcKQXYn~LKM=~654{m38S)b0a*lt3%(Q2pHd%?Yc24h`^$9+Y)If)pHlZqq-*fK0?Dpg1ed}W=F5@^aZ*j^-IY&O6wT+4`D&HgOdg^MfmLpNc5!$DTqmMToWG+s0f9fL@-Szd{HPYvTEXkG zb=L4D^_l(c4*0%cE1rjMA?0#2BCuIfLWB0m;g6Q?0Gs11+0<_SVE3|Y|2kHD)Uk>t z@bBX^;o-fd+6%Oyu|buIt}>AdEVMRHd0pvt&-2}%??(6KDmofGP`|(OznO1--Er61 zNSrxLMDAInAXl72Zf6T3pCai+?&W^K#Bt& z1-`wUQ+B^znnCN_7SSB{$!tDgHSca*Hm*i3ic_22cG7+?$*U97@`Cg62W$GXWL%l) zL%JhyN{j4&4IJJhAX!rCp~9YlKiOq=pxq@21bSJ9HxKyKZggD1nzje~7G}^8dEH0Q zMupd7V$d7~fH!YbzpWvmbt`Hc9lAba5ssWEb62=j8(EmG=o@6%c;ITWZMc6Bo=KTz zLe{xIMLBsfdnh+LS^OYtwX)b@8AL0TUz z$3~N&Z;$(byLjA2LFb!D%5_d8?f4};huts%wXPtg1Aj;AzxNp@%(Ab?85^0!3m`+B zmdsf>FIo1Jop1T!TyOWr|1x1NV%-ha$>Sc(D2q{6=L-_W5*&&7{!ypC!WYm(t9s!=@YNj3iUVksNoBO0k@ z;JzN(7A&;UZR&?#3g{l?+Nv;MY+0vq~Yyq4B40FCotF-{#9z+!W#v zL+@g-cp*WrObOLIm!pFste|4a{+m_4Z0JhSw9siYyi;l`sfakO=mY1m>@ex;i5IPp zS8E5zpaFtmi#hyU*fzMe!8hYjueS%i`%L`55%!j0bp+9tDDE2U;1Jvo?(P!Y-Q696 zyB^#fg1fsM+$|vlcZcA?;c?%)GvA$W-ktf`)z!PKyLx}S)?RCYlD}Bg-=cVq2fm)+ zZsTx9?|j*D#kAmY-O5U){q(lXO~b0R|E|YcVxX{}8s3BLs{9oNz##cR$nDT%wAwQv z^qf`M?wNN=85I02-Y}@00LCc%F%U(dS4%_WSCF4}TC%!1yYs)_)pBRTfu~vrEeGe_ zb3zt@uiSyPhpA?1ssl1t7NK|JESoJ8La!%A|J;;*I12Az2gaDxV8+FVH0rzBvJu66 z@%!8GkYTs7cG!A$E{*gYc(g*3<_DOXRPwAWywoi2j&a#*l4FasDu~==i8WlVF8((p z{12^L{&1*`^{6_wGtVfNt@M5JWlX8*x>m(b!d(0gTuqx&Zw7iZp$fG78r-(SKdmS*P ztjTt!N(sfXJSz1gIXCE|SIVSMERb1VJgPXz1Rj^AgN4#Lig2^I1l5n_h1ODwE|-a zwdz|@SE(ZH)xs>XP41!s_S}moPWJ&EwR%)a{+{1X%!^jW!Di_=Z;JMEI0kOia~8#s zrBnPZe>*My^+HxG^0`k-b$1a<$)FLLFgDm}m4vWg%{Bfk-@rU{yOJx*3Su7nE95@! zs}71@z}!}zL;>)puiTW{wxPVuT3RzqF4Wz=U*#)H9wF7ie;-%TYRm5vNsSMR?(q?I%<6i$GG=O@C zDPT^2iOY5@1LSD|L+NJKeb7T-nvg4GS8~(^ySL9)2>4s=Z%}=+BRRFkRGQZ$7Sgv` z01^g~AlQ;1`jzAzh2Hq=pz#j3P1zgpR#@<=+ch#EK_bF!wIa9kIHExHm|uO>9pkC6 zK`_xLy2#&)RQyP2DqH!P$K@)8FUl!<7aLdbcaS$YXTx{XeQe`jZF&AY$Z3x5v0wI< zgy%;%^bp=3y1rO=(G+uR;I_A@7r8wRKF=*9X`iTXY^gR*pE|ws`q#}tW+LuFh6VXh zE+}iM!zPahd$?Y&mKKZ0@$15gN5K8eVGWvDuXUSEB@r&Hza^(SuYl`rOME`x**~^y zWWfQpTnnbb;L(ZMW=}D6&nsh$X$Nf$ngz9;sbZz-rZp_TccwvkZ}kz^(Ci#;mY=+q zKQoivs*I+NquIH3fr^>z{=`WTio{c!j{PRhd4hqPeYev;JzvUlpf)%s^s$`fe-ff> zq{U%CXA4gdsnrJ{DlChu6_v^LH0cK?=ct=>>ozUt*%6~i51Re;!r95Y{JFcLxuP$> zchx<6|7X5Wys0zwZwNw69VgDT)+-S}5}m1*>35%Q*2)UnN^R&12qs7ASf86Ktu&1*%c|f* ziHy1>)lu%vfDCv!{?^#vYz`{{vdOtVh3(m(F*y1N|EgP-Uvu@mwXs0?W$QOQ#Y=2x zOgUIYMVUK=O6E%zoO(V$Sz=?>@(-y=H?BDy6Hbkx-(4<6T8!i-Bc6k3C8;0w+tBbO z>+IeVfQV4H-IfKTBFVGkx`+#SLJ9PcKn^zHpF2y^hi{oD;J3g8~ zBpE7mMk)%bKQUY(uIKn*aE~eWR;y1nK~in|?I~C9$0IA-l9BpgP8a)S^h+z7+GXi4KwRcY6!iDIIVUw}ShRJ+oO1&bhF=zNWB0GlV?k~iog zQrf1eY|f9KCuvCL02>H`EiXf6xP^6sG?Rnd9R@#fI&^u?;{2OAp9WW=d715M)rTMluE{}mnG^xOzF|J-f@ksy6#82oc5tE^B zChZI|a3;dD2sv*olvz!EPyub|K-gnQr;C#&yU^sx{iOBrQ#gL@iyWSmH;faEi1pn# z!IS3wcOTCq#e1feOLu}3Chu4jlCD^!$?aPV&mFWpk)b0V(3NcpMD-4DB8{zj?%pe3z$p4yiF@lK+3v^0 z6@q6c9Efi8dZlST3+P4}$FN7j5;4S7SbVn0zJ zSc17gd=1!yPC=4`qD5ClYZ2;gwz4uG9uRvGf#MaAd2Y)77^enB@l0J#Rlr`R_xIaJ z)0H5Xsx>l}wt4lo#i^YbDOFpR)%TsDMaW~?{(NhpdA-w*a$^iWv*Lllws8g>Hah@p z`C7zr-P8o_2AuOe4Z@aPW~qiWsN(t={{fk%n4n9RAv0&tGhv}E(>qVKzyDCI(*80% zo&9nq3yykjumKKWDjxJO(GYnIz<`j7x^$^OvVJt@vZaH>^5y*+9pD^pg_5n1(xhG$ zXDq+OfExCE5W$k7Ofk{YE~C2BW}Ym#S8^5-i}}+wx2ro;*bYmLgDAJ$ZM*g=m-?6 zvA0W@WNseBfh?NG%+d;REV<7g_q=7JGiLTM?YKaO(A2 zq#@?m2Z`f=)>qsZGdbzIUdf1K}E-87h^g?EvH?gOJY8gwo zt{K)K#3KYt;Ya?R>Onnj1cFm+nG}1lcJ4*g6)h28T+u7}o}}~Tr|a$!c`Tn0Fg9^} zcuy$EzFFnbL92nn2Nl4SEUTX8O-&RI9mgY2yRL(>9}h1NLxVq;dblQvvVc8KL0n?z z?10T@v3&Wf>#q0oV$k&~+xFgmy5hEXs-G($cK?h~q_-~MR4?Vr4_rs{&67>2-r;Y9 zxaQy;rVtCf{a@@c9(ba`yPUJ{6#Ml=BP~!%j!`oa{da3|$MO~-b!MZ5Zpui3NbByL z%e%%4_|FlMgzMq#-#)@``}cc`pF&jba*_j3pJH3DKmCU`^qy-6M|V@``R%kA@hT6g z24souW4Ltbq>bCcu)m=qzH)HoPT%?D@wY*+OQMUR&a#D}qh>!<%7y*pC#m}aM<5>qG#SvV|Rnxl%+<@DCaDE;2zkNebFYMNBInP_D-83K5NLsIx!*(Zb z6a8_SMM^}iq{(~zqEz+)4+xhwp5CEU>||FB60S3UZB5;upsXI3Fe`oKW9%|sFB%WQn8c*LzQ>l}|*-U)I0ns`KxHY6j>xmk}Cwp&@zXuQkx7z@zneC6s5dO8;Di7#l92 zaT!bz?K5-mCp^Z;^}grfw;9gZhF*K6*p2K19@x$Vf^I~6z*z?9y|o_Fr^ zyuo?I*##}Na*ZvM!Q(OFt)8_O@%D25;~&Utgmwz%h$M2P{We7ui%!7@Ry3=+FTso_ z4$uJ1R9o{lt9&G2mMQDVXxx(s<{OoFp|goG>A1$(WhL{^Q%{|>V>7hl*VN9a;_PZ= zELS^-tEa7=FJJaR%EUry*8rq4S#?l|p|zJiW!E=!l3MoS_hOYtX) z&d1UrkfS{fr(57fGDt54p&o2zI2zT{1R>1x^>19<42|N14euQz z{)tAX&#yTs3j|`uw5uPCLV#z%p3=Am!e zLP@2DGViH8-dX%HXMo3?`B`V+7pQHOf(+#LsZtoNmi9H1ExEsO#*AuscW{>=5XBfX zUGE#7eJQ9~ZS02aT4UhD70z35Ey>i7QVD1{*r{luAB2obAQZ|@>MI&ZBiE#bbSNl| z_ELv}@2tT?%ZH0kwHbq)uwzpnOS2PG_o&!Ds1!3@(*$0a=dR))P(OL(T?;&~AjH|c zlyV>PvOkJc`xgGCJAzkFOmF0s&qb&>Md=9N0OlJo66R&X@*6+|)o0)(2_@fIWyuFn*gqUTW+&}N2k>UGujjDw;+7!*b$2^B%qt7sL(^c3;^OJ!{}Wa$Oxy!5tKm+i%zE^76wF zcoOp&TZPG2apoSGL+?HsD2F6dt`cY{bJ030qEjnXbmT7b2D4JXy8`Kc3-*Iq8+=E8 z^doZT_bfL9-+vRrU;?X?E#b z?+dKCwkF1!yZa8Opg{cjat|q(1~-k`@3~}<{dAFnM3j>de-r%sJ&=OT1nuhX!fBS$-bHnFiJCCzyA!C;-K& z<3iDhVfI-==Vr&X5)h}jku)^^dXhb2rta2rwk?`4c*E-^4;N-xU$3UtUhZ|#FxG-g zbi{EA3E}9GIa_%imM(ALD64hdYN%0P1i2F@gzU&8BLg49N@jTFr#zMSekWh>V9kaG zFPCf&?a#wUc_n<6PTCBA+4Wr^(`@*QM&0rkPQ+Cm^hc3F?$})^(RsV@=sa^}Ht+L0 zM59BtC(KvcsyJm2{0HKpAT7F!v19ikHw1x6f#uK(`IQyK&<@47fc%x-p@0fvC#jYt zdNdA2*Z9ZecMDydZRS zhEUI)s?N}|omFzZ;Cm*qa#^S-?hkKo+{R!=QvuhO$MxkIJN!JG`)ikVQM?`F*y9a0 z4%Qg${`In_NGURhkR%S4z!ZXvrfOtCmhpGx1xXEa0%$6@-!aCbaETgf=IbHvTp6zD z-IaOiKP4nl_)37@IxwHS?{1~1!ymh`D+${~H4H~Rh9Na@9w?8oL5n#{JQP zTeb0WYRh`b_y9i+w!X{Xdgq`VF1_w&pgDq}vSgu8bgdCtkEE2R@{u}qV}nMevq)CHFX36t%rhUaI;6Og#M!zPG{2AG2_u1&G6t;YGh9rJ@2S3q99U&2C zW`Ttg+jt&U@x4y`W$E$OhWorS#v({I*M=Z8;BM!kV&{gS+-i&)uJ0JP6L)RCbG+hc zd>w7@!_hxJ&kzEv6p0--z~v4J{CClzr3gRnG`UXmM`jN#_mHrvW`M$7E{(%)%I}rAkQvr6Iqw2N_8|nFJe@m(s`9_JMHEF<0cgvkJj)ET*{$Igf%9NtU z6;R7;O5}$IWC}qls%Ms(*xXA3DVnvuJJ97(DQO30ArYDtoT*w!8Je%ck}~e=+O7QB zEr}Cxr}@=->fdz2RHP`}6~14K%=5P5NzHQDH;)C2ZB#(sSlFEQl?+hYQ_@h7axCJC zPJmpRyIMM)1$pr^*>on&sKkP|8=a(~UidmQEtb3F^6eqn2dqEEoZ1ruJ<8f~_%F?G zU$(-+(IoTV5^ei*{zBg`tGQ0sxH!xZG?t#GwfyDKHP>Jxp6;&bjJ+sE&tkib=7a?~ zMAtH(MvGtGQTd-jc@{@f;s81mac2dt{{UDMPMjx-dge3wn?xnB^03?TMeNRUFu7Xw zU%fFJCGLzPMQ*b$uhlohh793EZQk1uduYx1T+aGcHUDHUR7S?fuToC8d0Vxm%P);# zPG{gxx3Ozt0|Vr7$-B7m4@;~#m#u;t=Z@}2dh5q4`Omv!Y7$dG&X#>X%C0k96||CH zSH0F-MwU0f3b5-gs{dZyza?gf!>m~zYIe^2sfuR}wOY8t44rN4@MbnA%>8?=%(-8b ztWY%K2zTN76RQ@r9tuG!U!|&Qn!4AyNcU>#$>OVs@!L4H7(9=R`Dwvw z=LWyePl)6o8;P}rB|1!)(YWmA&JxnlUC=BlNps#HMN74UHW>`|y4w7Y%UxI|RG!N1 zAy)papJ)ET#~?*cM5L%7#Q5sGzBEp&O1-S?XGo`9B$}}%@*NlTZ}gSEg%h{vtqvSG zKzEVyrm!0&Y_b*5n4#Znb-VIu4eC}^vVDjM0LMZ%0QelAY-B$un%t_v`8R0Yb11%!;y6@f8_y5Er<6A;+HNr6D*X9~P@r2u<+tCerB7{AOE}^}RZZlEZJ-yE>9xvfn<6~;PTkV6H?706#EL98cX$4&q zaP!HJJ`b80;z*QOmEwJ7I3uak>Ch7#r-pvCa5p}X2xCEzy?oQ4oJIJSk%jDz^G~y^ zb_LpI-jxv=&}bx{mDA5taRw`j<(<#8abk4!7(2J7ds6cs2j)1J61@kPji|(}399uGN>k~b%%zrr>X5&!vFTGcNT?re)>R2$eR zo&3N5T62M3B6|?Y#*Y2Ift8{TnyAyai}hDRX)Ie>%5ev^IhwWqss!$sdL;K;zu2gc zx1GFnJFePZ)YkS9Mz2L&EY>$^n1-@~Ui4*-l$Bw=o80Z=PkBC6iHLEou(Hxeg1=#n`NssWmnY4HU!-BhjR!^ZMOGoS zWo^0`7jRdM{4TOB@rGd6z|Hz?$zW%L6p}mwKYhokz|4>SLeLM#vtTfQ@V4t_Vr7P> zx}sze&EkPasse_rIbO4*`i3B~s7GC5Sq~ap;VK3FMHHy26#ggNW~V9N1|AdA5E5$t zp3%Q{xyO|@0T;Szt^Tu(YlR_R8dJalGE z)arc zRj>*pWc=qgL5HtqA~KGJXbd!oYoLIIBk_i91dzb4MtLKvy7VDE$+svB=ZKKP?bn|xy>K+!@^BGDv`4cq z%o98jLPfs@jr@WC&}O0Xt}=2>cUa@WbO4=llIl2D2r&v7@98Ek9q%8$q8a@#Rsr%m zuh68WS+o+7Nrfp;b70^7%v|KL`TK`4V(Js-eMiG3be_d|x(I0*qQSGrmk4?&pK>i% z#YH6T@*IL)PM%m6x{^-}fJr$G5j9vl=-XJQ5JpeRO}a6vv3i5V4Oj0Ex=nKl1S{;T zIr_b<%M4VqOd7u7;7x+=Mu^(0bP_!x!{Xb|Q{E&Sm zMaWmU+*7OvEAb5tMK-a3pj%CMiV@+y2t%s?r9FZWZU~%Ug)HOYH~MB-YJz5kVigNd3IG*9F8ExhRkfgkCMS;rzWljrWL2~1eo>iJPBnTD)A^a z_sEI;>CUKj`RaDTG2t~G_;@vUw`wECG5CHnFDm619$fHso(7E1KU>63Bz@gKRACQa z%S(+`Zr~5N){^*{dIxQ$3rhaX*hOzd`vb8mY4a<>8%FBQ45AO-5uA0aIbZ~mHYfCk zliHbgIe{duLU=v8euKn7k*3?k9{pkLKAV5em#i|~%{@hlP#HK5iGB!CZZ2w{_2==g zGdOP|&*K_b6qc=I-f3RbGOm8WZNTX&{W$mGHMDG44S&G_|Bj)v0)4Mv`gqPnJ*zox zF8WR3YcH$-rjuyyJ9_qL;_c+OWX#gwg=;LZ%mDTqJpmjkh{m}L{uyEeXPkEPwo_;* z>l0uqJ%Q>)jCTfsc1JnoHZsLiWR4e{^u$QwNEGR;);RJYbTmleDTj~K(6yXA?ZHa6 z47PK=;Gpq2ke6#Qr`RomG8E4)l>O?y;~mayPKqJGf_@RM%KI4r#u(&Qq!rZKCiafP zT!5#q&7|iY#j+as6a1~D#sQ}PgDqcpmb(cVTHrKypn|&L3I~GsdAnXnuGw2s|u@|ns06F z%drJ-zJ{lY|*Rf=d?y=6p zzTO|A7-wa;Jm9vfP|u?z900Alp!3eS<$|Qw+3xp)BV|IIv4-T zP^*DCkm=cvOd1qgf8pYwnJK|TQyiDXA}w5AQJyhCADd!kl?$Z61xT-^LFCcy5ZF8X z7$=IrO*-gckpKFp*lIhaNa6`7inR^(yQla;)Lqlr#8_@1>PxRW6O*Me14CHH5HdJD z`45I4Tg-&-e(E)4$+(@?K4{Omw4W%`hio`K(kwBKoE+$oiNm#s}W&e!rMS z6w~Oi+&|doXBtPm9_K!2H+RCwFC{WR-0{(HE$pG--F{X2We6;`CQF2nf`!`;+h*i- zivD2zE5{deyug+FS06orpm7Bs?A~)%x=oM(;>RgEIUX55N&>@01$gk)v zoBN75|GkX&hgP}C5j$9q(rt%{=1Xle+9Q1-^5sr&08VeJs|1jb1Z;&skroxK81zvp zObNf972!DYvDpsujB}ArB;}s5hp9&ul2k(sL=w0Q9IzI>#5&A;NB=`RUPYfH%|K*&e1}QyzSqn zsMu>GfBg(lx*mp!j#a)uOuJBuH#TBzX8$|i?k01BHlp5VJ%q@aigyu|>)^{Wkjl9bdXpjDqmumZ>p(Enlm8r>Q+B4e$E@8qo+nTJ)Cw8f1#;ew2y{Qu-I3&~cHXsvp(OQ=WfUZO%S{AqQv6x@f_mhTJn* z_rVla#*QY}XMgxlPNLu5Xa=15oW+rt;F<|8o{!Md|lnxUo@Obk82FI_>#?H6F;%z!t zoOWw2ZB&*|s1ZsNp{7(hNwChQ=5^*7KN-Qs3k*v{_G>+X9|6lodQP1yr>%-Q7e}FF zKn+(aF0L_?QIpi*I8o_sqq%ven}3ASui^MvAb0 zReG1dnzGlP6BVEX2K)O6k4;(v{DW6Xiv#^#l?vyh*=t%t2taA~^Rw>-6}UtFP2snk zQ?U3}M4mCTC<$yy3adCulp^9)GuqZm%ful;s+rgEw&+>vMyulmpI)17Y2{B>W%5XY z9J{r7Xd0MbKe{El&E2F+oxEh6Ft65n9C9oQVMEHQcHOf0kAM+vMSCiOPcVB?&Yfm2 zXhw=Hl0Z#mi#US{_@x5FaNv+l&cC_gk6pnaPasj3*W&qK%XdfRkMI0V&RrbCkn3`Y-f{O)TK24xSB{?R3Y6i|X zc(}k@sL7@Q8J(1mHs`RzVWn${K<|g1_~uOzL60PvX&oH!EA&@^$-i=R#;%X^Cj7qp z6~^nPu?w}BAM%O#sA>YTS?Noyfi&)P1P}aNdL!?MI4`Rtb zbp=JRnw=^a+(5SMIW*L2IhKNt1#}V`k%EtBt^UIYI>^(4xv6i+Lad|S0NBgDt(U%` z$MCs8Lw?HB7yI7{Iotpb{|h0EU6=9e`@x@cJaIlN@+E0?4zH<>d*)bN%;V_}l)M&bb5l zmzOudeLpQGhL+1Jg49B2SohC?N($-Q%kK~0Z~EV>3ucXuPp(sV0u4q^{;%(StAMxw2~0&Czkh5z2^_%`(Z3ldd+qL z_dCnP@bsQ6MVI6uT7pb|{Ju1^;q=D@>4(dEo1W7Q>7VB1VDa`D!dYZsjOPy1!qwEA zV8%z9l0SXmhyTODL+_282;OZS9@HDb@kg$)-{#4`{`Ui;Ljy+>QiOj`LnZTJaoG7^ zdAq#EKz?ct-`-i?rwYP-nL?aMEKUTP#LqK8)>?ieJtge{FYNtETOB6i?{-G-M;J%R zR?|~{MLiH^5YmBOUhi0>AKvdr!cP;gBb$biwZ?8P=YA*#+=~WRfe%^DPX!;tG`FNL zNbZpyq@EBcqo<2n3?a3!6^QVG!YFB! zz}o5DOum?LLz(6Mj2A2%|k`>JFUr#PjtLX-Dcd$;dlMcV_~KkdZ1a(f`ECPyNHX zF?SZ1P0yNrLYNCRnE!UD!U_E<8n1Ah{DWH3ez(#yh$rIX|K`z!p$Tr6GJO0?NYtRV zTB<<46CQjE6$g2$qv|JmaCxAf(s(}VF#L4~gNUm=aEDmXPp7Sp=;SwK_HQXJ(Qz*~ zWbK6|>TQ*msI`$D(Jg;UbP|&~*iq%V9f3W3Lx-`7suB>8HGF%7-u~Zr@ODdh6plI+bAY<^A}>ltRXDEaI?oAcRdhSJQ3{>E?8SE5E$}MgZG3$ zL|7l$2F>CW+^9N5Tvu9nA^dz zK60WVZjE7-bD7m+f6!3z9NyxT9b^JBq7i$@f$M8?;&~LtEUb-F@Z;333pqJkVRHEF z3w!aw)a&sxYNq~SSxWsE6iR>$FwvL{ED{n71APk$p2DBFMbn4+e=iZ?{_{SJ!WV-7 zQtf}1|BvVXqd;LmC~@E6KP>(KR{77i5Hl$`{%4c_O~uThU`k^)Nb-rb|FbRd zXQQ&#|JnWj#+D2a2nqd(qkIOhcVJr5LLjeL zE_aZ)1pu7M@P+Rd?q?Wsxmqx5+~PGRB>zr_&v(Kdz?_BQJ)ZM`(%<`k) z9h0d;x3VY+x;lucC#PX<>$hI4As;XfTQ%Fk@L8Xl>0XfC(0!n=f++*%Xrq=bO&! z@rjsYA>Cp1n}q8NJme+mDEAcMNCd{i_Ub&4 zDk&Zqe`0P>KY^V<9(82>)0aZZ+}~%D$2_DWx-HR}q77B13Jw^n^3c@M{7d(-!tkpC{Z;Ar@2fh4pqQIh38x(F@qdG?I@TeE-M-dekZ8A~V1bI91j1~Pm9 ze1Nv^h+KZwr^4R(&+m4YX3)q?I_@{^w9FY<(*SFza%)Oq#&-hTf_U1FjOLwHQ9ErN zyVm<9Pw;b%{s-GmzwFa@R7?9rd`V7EGAZ2eW=%p?|Me zO>~+`)+Fq+Y1TMy{VI@-Cs;mtI9DTL1=bT^5tYvJZzm7v7UuR;GuxHQ3?5KQyIlF8xnNG>(4tIo^X@0c=y2ngXS(z{l) zH2$O6-m*G`-m^e3(CjW@*x}ig$C4`zG@U#4s0tkT&1}=ocCBh_tlW~w8uH=z6I;vO z;BTMVJLtz=)H`Zyd(%no{U4ud15Jkmaho&wd);^qJ!&IM9J-a&+Y!BO%=hGYC=v(b z0Zpz}ZU3(63!G|d=dV|p4`IE#MHnjxyvLdWsPFP>K52sPTW>eLKuU9#({7F>{bak_ zs_v_97<8^F(an{op3?tGh)s7JRX6s0;l)_pAK70WMR3G{NV#f+50oW z^Jw^wsG9bzro1Dj@0u&nf*BJ@eP&Z!ddrRLpRDdsk!Yb3DHmVd>cXgpBPeLfX6_9! z;x!a*An0E*z}u|NmE73#6mU30x45pHCWYF*+eF;K$Py{ws+73xcv0ju$p9I z)pEx<(%hI;OI5&^rU$yHsjAHub!@&K`*jw zYB7TpL>@z~IM!9h;g-)?uJ@U~(kJr@_$Fe+^Jf+3Z`Ca_SBSo8q9K5D2D~(O)Hx9^ z#XFxleB{4pH>W4wYC}U4haBCSu1ZybG!>xgD;34ug7)3AQTlO0!!eSRgxdxM&93`a zVcv14bvX;lA4&W{?oGcOCY`>;2%KqA%~Zq;+Ha3`b;H)tj{fh8$QP#0lPa@v5f()o zb&gXfjWRd)ZKyc{f3(m-!0m+3=+aL)`d&F|70>nVJeg4y!B%!ee(t$fns8C6r@3HBj z?X&J#h!0b?gr6jvTx+YpJ*zcckqENC_d+;J0u5oL?ldzF=I(W+4N4qzI#;vJvA)O* zc~VyFiK;o;gN_HEATOAx5n+{uOpx6d2B)g!F#d?9# z+293Q84I%`W;8hOU;3eA|FU)0+I(wUjrHW)gWFrKjvsRPuPZ-ckOrv;r*k%W3IZ*f;5 zVrx+P72eI{KtddtLEPL|pCLcy22}Vayo>X38-B(O5BJBy`=#7=RIj?$Q;;j6jFEZ> z(^*eik^#nJJ(L$SV|{cF)pv}39D(0-DmSa2ypu?GAD>5A0)}G&86gct!+p3}`1AMU z^aKSdjahhRhpbX93S=?B1UVfvG<2ObqT`rV-DWv_xuvx9s@g{PZ)v5~>PwjT)Kr6_ zhSS-CTNI)~g;CYTxs#(W;`r$E!hC#R{Q(2oI`DC>*)UpNqMTUq9x)LaT*U}K(GKCx ze_grI4ZHm=xM=OuFJPMxH&tcoobOlBz(tvguBxqQj;NB<%&2VF71W(vt!Xu__3M<|hua5{CeVaL&mKQe6BCQ?f@ z!}!8UVBzy-YwR*TB4x{HFsSx>I~v0%9ZB)y(DFJ6uH5d{3w39nwL=PS*XAC3NpwhJ za;bC2v=h{D8Ccz0#Q!$@+-un5Npcwv&P%YqTRghUNBpb&VkP2gn+OkA(MLq0A{n=@ zk$2c8Hfos`7n|MGvZS86lH6q*$edGuZcm4)Ht0st?TYsdy(;Nd)IRPRO#&2m?t+7Pa=nw9s`M}6Fyb&ktGMHok zE80I!w6~KAckf%HNzS=a)Ac*>9#*7lCG$Y2D-d>co4Wiuj$TLn(A2D+eCi1mpk*jHNZWi>M>G!KQMSKhQ6A)8q*1E)P!t-&Xv}#ZRgYd zT1UpC!^T4;-6-26m~C~PzV@-GHO^FJxFMC)<0=p$zK_)dTQ4hQ0`I-sWezcW5naC# zgowmy4yDC&G#l%p;Y#<4=}%#Lw7?IDoV%NAE&|+3d_l~N@==Q=WGXpI=J8Bg{7e4q zJA1<>AgIP71=?@cvPPgx&^z%a65);m@)4dg{cQS!(;+SVgvkxqgAjy~QAGcw2upTb zQ8aQDrxpT0LWTBAlAeSjRku(b$U8_L=vr-_wy=?x0=88qS!m2sDlE~4)D%vS$-WXv zh9i%YOVg;|49qJJP%;Tvnj_4Fr>G!dPe1Xxq)OJWFkAF=eP4f+iFTa87OC!o!`4NW zR<^-N*VmLa=3;76P%B{|Zofso!+taUHwRn#?NAwVq9=~}ktW-}9GRIJ7`$>2OdHA= z$slPh_L@L6Pf)S~%`t(m?fFyCcN3)6$myp(1vM-Gdp;_?hMLWG(771Lrn+9XHg4_! z$If2oKu%A+%pNr>94cPvaf5LaWNXNR62H2F9O@~h0gY#_fI&%T&N@(}Wa^!SajZ`7 zq%o+M14S)=T7@l>HVxI@9CV;@t!Aw+;c)p65>D<7&j;s_?bbDR!k z_;P%C6H_}>#mH+}6|Q=alFkzLF>1gDr97wZ(wffH3W{0!t=Xs$Q4!azn##wpeE_Dy z&oR6|Y)UO9hcw@IBpM4VfTpLAT$FI-_~yKPtxdp-qkqBo6k9@h^r`=p%T@&QX0U0B z^+{&fkT5MqVwO?VjeYKwi^Y*dteK2Ch;aE`5+5YwhjW9wEKx2-3qFFx?xF4wZ`gXO zuHyij3NyeSfX7mX#(KUqx!tK3ey9~W7{mdNY9W5N&|@^hu1&ULw}~y4Tgcb&sS&JU z@Xx1wp2PAo^-q`)vL%rJNy>ZsA6x(=dHuO#5a~W>{k!ZTipPP1C}7lKr(|g z4<9Kir*?R9&f7r;;K#_DXqd(P-W*y=baD*Dz@&5W{!{Q5 zedQ^`Vm3syeBWSHEbS$_pSc;cf9O{0 zFA?0e79}8t9;Sw7LpOQ)Y4R2`dXJ}M{T@M5ju@4pRW=hXsLRrdFuNN*fsnfPl%;RCRVj%Wy+i!;kRJ4+E!GqsEV72&EH;VYQwFM!Mth=!p)^bpueZJYp^oz*2+LR zNP{)YvlU(Vx3QC73NZ>jhksb07moHbQFk|Y~9aXH35!|d;HfCOtY_Ur4ry7Suj_eMPAx!-UpCB0ASaN+m8@L@Nqvbq*s`I zOb${pa5I(xr8sK{=#wIhwvbpHAXU&>`{>Gh!D-})DL0FY%A(9^M;en=V!`BdE~}p) zSdn3F5Ns=Xqj}&hGaoZuYt!RIg9v%4zf z#6dF`0U3XR4E@MDkrz?S)dV%}Tgd&@emXhuAG-%k6PkHz+nC1cWK2OW0)Z+E+>GDT zw!2;oc1b}&W!h~A8uE)1XeCutyJ47w4Mu5HW0B}nTRKeBfF z>`CdZE>`7}EirF*zdYMtE>L)W3XHBbdFIvH8`1MOFWSP}t|wmKSF=iGmLSkLPww8| zG)z#t7Z~%oXr9&Rt=WhF{>k-=&Tbszp({zwTwSVlJ&8;euf%jZ#$_(voj`+%$BlzP zHG776(r5KLRO>t>pBQZihia(nv?^Q9PG`ZKm@q^48#LxF*_3prWowf=lSOTSy)zc; z>r$&5V<%LCYYb-=r0ZSAfP&EnEWhjN4`KnIMP@X#$3<#|PtgkxTWscK#xv)2E+kJD zItu=h8sdJ%QA$8i!VhOx-ST6L+P@-40^-hxue{%POYwS0=d(cvgY`{gUqbLK;z&M^ zV1*mzbXy%gnjh_4Lh0MFf5_d_n@LmQW%2#6QGkX-Gr%@xWZO-I<{`A|o?Dau{-CX9 zR#I!PXDgqb{VGHR^EZWxKO~9Gut}b_q{CR^NwsY{+X+P}+vh;lH5Esn=W39gzwJof zc>ReAEI&!FUbOIM(Rj*9*D_vSeUf&Ky+Acw?q6=xak>9VP`y!nk|$!~T_u~7NDpV6 zzT{ts_(NN5IE(Vf(_f84zc-dnE;pi6*rl@`{~rL1Ky$y6_GQ+lo+b^Ai@ZYRyp|YL zg~ok+Gwri=-Td*YC>dhDd@Nkr0><|rJ8I~HRnkkXqSGbfBOEw_C8s*i+KvPJeMhR> z)a$Cfx!E$JS7ai~-^h-Nuw$cLWUR{o%G?F3>ml(=@gV!%bFtqJOP`)T-(_u0&9*{3 zwz#1|hoOrFj+{Mk0S_lv)JRlP-Ek=|Qm=guDKvOQ1lw^7lwLoB=? zcB)kc)zI8vi`yHlb(MN92p%o9cOGRqPooUaDrLoEtKvewXzn~Pi;N+FOt3_s!&@XJ z&YbJALnlwzu`{QwUxp=kt5Ow7gO)e0S+?A1piX>NtIvF{B}W51UXx~9G5M6{H8@9x zLATgU!oNBv*!kWbd;EpxtwRlp*RNP18>>F3LelZ)`xKNua>NdtIA+_{Z?x4*C3~gQ zm^aSjOLVXcA)+{ju1S{GVHE%3n7K_;bbhA~IUzfwPrC)*O3B$pf*C!Aa~CvU7mtc& zsH)7PUkOO!ki;;^ep9*Co$azN@dmF6Fkb8)kX^Xjdb|1{RmTZnmrJ)dwA9-Q^j52a zj?w{CAHtI?FX+{BgZ^Bz)umK*rYgx;h7=V6BlyCX;tQw6YSHzVaC`FX8QJ{D?NpaC z_*yB7kqm#g2$TXIaQQaTfT;Tf1 z=~MQ??p=1{nzim*)k}sk`jvR?lb^Qh-~sEt(CfV#m$kK9oq`Zq*_hZ9u~r`lCfBXV zT=*k|v3M4M7n!@~@FD5a7U`{a#iHfVGep-OU?BlVup1oH2nSQ~$_vC8Tm{xJ-~A`!cOvv{>$4p`|%BS`ruA z&|JHHjsC`eMQ()y{1J@o>gl$_(qGK|1tnO)8-)sW>KTL`vr@8j5rJOW=`4ZwN^S(k zH9TUHUUQ5A2||_$c!`i+lhI-lF+#TQ-)CR{;Sa1@3Z$V{Yrd+gJbQQW_;FX7kG}A% z{q$QOa%Y;~)KI#qXuUgiF$PRfij;)rDPuy=OoAc^Y}dj4_V|m>+ud7kR-@8dcS6UG z)~|@(NkAc}WI?;3JjtnM^br}lQzOT{q?xEE*ToR-?Ydx3eeF3rcS03BT9b&P#L=s1 zthVOmO}627H6l_8etD%FCk0So-cu^uWPRP@jCLcDQi^SmedGScxNgXYj0^6F5)&-> z(TmjL>?mr{g@7uxC`pDoGt1h}YGLAUTu` zYrhHBiH@ONc;%!Vx4-|_uiHcS-seUqf)SoOEob`KT|4b-k3VYLHf;2uNxgWlf=q;d z#|c;unxb8BFrmgxCEn^md@LtIM-nAiSc1x*#`pl3-IyaKzP$$Ix_Uia^O$@aYb zqW%Buy$6t7_nqhW>o7Uzj2wu_0DuH@5GhfC(u%8HY0I~_!F9K{b>}Plyz9HNUF9oZ zxyonz%2jr`uDomS-g)o#RvV>ViDHx}CNYX6M34ZHGiETDoM*b*pYQMW8}tkY1ZJjZ zFay3n+323vuiyK<-~a#1;eYbTA?wAlTYw|7q!c=DYO>pQ-$^D6y-es2H8cTTL83Ll zU5Im1r8&`<0|PjyT*-$vJQ4<)WB^Hd``cH3Vh`SNm+jiL+2w#D;{9ZI96Wy9o`36S z?i{YiAyJL8UxGrX^`919!PD}YPVT*;P!t8H?|#L4$Q@cAGAa~=c%OA~5ShDU!zOQ7 z9ym;bvt_Uvd?co_0pzh8BKe?yF0Y-Yl*X)=dMt!Aeepz>{rFE_bpTv$fp5tvDV!}E zo2-3Xo0Z_`mqSw7Fj81A6Gg9?!k|-MF3RA<>lQu0k=OMLkIc32#F9m$vRns8hV0PE z6ZY=M2kre&54qAgMufH*N7*`zxQE$!U<=K(TM49;VSG40F2k6kWF%IsD1IJbT=YY( zzzfJpm+cEA*(Lte)7NKz_WZN<=$^gMP6aw)z*nU|NH2U1{p;nQzvISq9Zrb~0xfa% zWstEXXCK$dn-Q&bA){04!ZVI!=wF#jEfHx?%I1(vJ$ClAJ$UCHcjU;CCWncHf|nN% z*zhuMcSw@OnJ0&dXiH8G1;O;v-w6U~3ZT9D{dev1`5w? z7s#*^Pi?8MciNhq++!v5UMiI_j9w5c9YOZW`J6}nbp?{hLO~MWmy>QvSsTjY$QdVb zIR5E#&)9dKc+zbv1*(;SaG|@~&f%Q+@w;!iajhVx4 zUEN8Y17rA6{QcyGbM}MRerk7Z-R6OFIU}#|eV$W}7kGf$Jt%`yeNOT*xGMBC?qxWa zQWpqNsT3D^1}zn)VQqu$dU%_aR1{lvGyUw)ujJHJpA+qz z(>Z?5Wmn!DCYopDrg|6Z`Q1UAeR!<_IZK9cgek!DHn!LMhY#6t94E@+Ye2`A(^-vR z-FYIOB^y}tEFD0+E{EUc1EWfM^4h#>w(B8@}9%3-g__Kb%j{A{Si{p&N9CQxQ7uo;0!eFN=jN`2nXV&EM+ zcgC}fI`DWVrA%d`eQ_LVXv{)!G;%QD>rD!$dKM@m>KG4BP-Eg}5q$UT+|B$4j>_X_ z&$==9?9X1eJ8+iVvtySvqmUG!n8@&AA}p=C&|wUHdh#%NId@ywM=|g;bdzGo4rPhB zgt_Wcq%NWmA4TD;tFH1=B0oj9&!sN}gEtVCVW&s}JYewy_z;ABD2km^XT79`Bhaus z^YExk7`h0`fiiwXIwQ zO<@;1T=@YHgD5`T^exfOIFOeMG&S=)Lv-^1wDs<*Ke48kcH95hm#n6tiB!u1pZh){ zcLQYb9Ko4VPiBFd{tiQ5y(6SKqi~jCq(RZEy#_ z6pKzW5R}pN7caeF6)5&!xO=~?!wIY)fIBjoJM|Wkv&0p;#gmo;EM-{buHuz*Md1P7$Plo4W4#AI~ z)D=tklS)Ix2c4H@=jS3u#9JR7Kw*EyO9u9B-(kP}=;IzRQCj<%u8ST4y@>Pwk$rbt z6B$Kl)Xwh|>5LO057*7&puDwyEt*hsLK#agd$NpvoOSxj2Orq(jhlSUs)Y0s{7X^h zzxwfuc9IN)C-1q>Gb3fpi4nREH;BQ2ju2OX1sRr);qnzDB!yAuL9p2ubtbHSDl`i3 z@ArVvAhJ%fS0xS8LomWW)JsBojQRvnyEFauVm?ULXaj&=GE7Esc*yZV;ewZup8DvV zt-rm)w%onZYFjB0B{QS*#3i!2j$`Z{@N;%Obel)QMFYw*QI?7_OX}!lqGqsx0A3uk z@{%U_if1}VPl5i!p^qofOJ05dJ$vT$S3Ohf3(!$J@=ICea#o&$wmv3nPnkO{WXBmg zkLFqpmQ_L=jjJ&r+2ziD**2~l2Ipv=l_B@sn{TjH?|rVzNHz$+KYjmQcUI}|Cy-wo z$w*b!oM$j9psZ|ff|;HndG#&lb;N1b9n?6|fKd%-5eajS5#aTVPKmrXeF>#JdAz||IBFP4dG9L$yGY? z!(o^4;kAgBZb^fJD|S-OB36ljBHya_lrpFCoacV_hL_^(-M-zvbngQYZHF^XdCElW z$_8t<~VU4 zhrC&yr&|)~$QpCgKL?-1Y#m%)Q*P^acG#{*wpm?Ut&L(>sdtGCx{rSPiS2lBE8bB> z5>zIG%)wa_l;+UcKPJJrxFP&acRd##H5{KxvmJmzzwVhO8`|6L(Y^azAhwcrEARFo z28ibV5cdRM(K`lysv` zWduKd=Pi5ZiC~#bq0BNj>R&Q0#UeX2 zZ{yU?A0iX)z1LoL2g&xk_S*Kn_gEd#?l_JxrH`sq=fdfeHclCw(rU%zQe5*g25ULd zgvenMVl+go_z$!o{v$^^3gk3>af;+t3ESSQ3`b72^B8neN?v^XO*fPtCOZDu-n%I) zqa25GU2gF^DK;JncVz_H3TbxM?Qir?@km#U9Y#4!uRSEMBxhB(J|~U2+>>*YNc|A< zK>b!0sVs@M6kXSD0xEZtGci|g>T^STJAi>8}UBr08%*Nr9LN)3+4;_N>Nt;r<(})BLr}syzc>` zayWNziVf3$(WM8p z?ys^@Wom7|Z>zO#Y3A=jJN?Ny8}1#kk0{qsw#g10{f_@^@{Fz%1k(nsG*iC#9p6M> z(ouxR9H(>m*fP#P{PZxH=C8SfY~PL@p0TBpL9GFmVW|X5<%H4|gOnbenksL?F{Lc7 zMa^{x$gT7a&Lw9=Ux8vb^Wl`y+gwYa45yv44|-3Xuonopd_or82so9q$WS zFXhT)7bxIVMwoUsRWFmpI!xpw@qj7M91iIlx~@uZ4xc)SP5rtlSbFd7+wF@uo40hV zHD#5Lv$KzEF4_L&47XC#Wt=EwObn#D zu@dJ=jh*<#X^cEwOTP}@hY++~lNY@CD;r^b78p15UFhKzW6d3EAf+;pzkBT!dx#YF zZ$A13yMqjqauiNUbw!Y5eB4H~b~A)5g(JzFyOF{A3BAVotV_&=Z?%5*hF;KtvU@96^sxIi-F+hTCdI%O?a*)W7 z6K&}w+?G}pS>u{|JM#8XkJRZrKbe1(#@Un7HP@gE1U+RI&IIT?gi@*AB_ovN_*`d^ zpAxB2o<~N9 zl&ENg7+{83v(w64csS=>8BwCG?!F!r`ZHcC(Aj<2e&^|5v3)ysT76ATz;D9^t#_5p zNwMK~nG8-~IV2W}cdqA=KV6Uc4A))`h;w|@%GH0@I9=Dod7$SOANM{%Qlb_()C)*W z@8po2!g+EvjFgj2UXfc)Tfk6JTW$>U`48*~BC=c7taHUO#mPk&$ea0l1CdxchSW$m z7(gjS79J(JsD5IRQI|!H-dL}`&eU50ygQ*}zZUjoL`!z5gjdG$d!Kw_hjEzv##f%Q zCotsXEOEUsXI=Quxvm^vzU^q}yvkm8zHcrIp#jf|U@osWv@U&}3Z~g`<^EWXSR6ZE znzslTbuVRW1zVev1<8McA$LeGR2l9R@JlCDkA-5Kkqra@3X78*lSA{$FM{8rlR>L} zlRcCURJu`j<&1fL5-QvPW1h)o3<&JXHB?GEh*MD|I43B(_$SD(hbU=MhOy+4Xsn!U zNR<#On@`63bO|}%d5JDK)V8S6$YMYfDhG!g2g>$#*>E*>##r0h6yTGS$K4tE(%Wy* z7icpj+_!tn0rd-+)-|LW@bYy#7;N{vV7V27A}4h}dYp8_+(j6EpY!gIr02+4t2tlK z_!z*LEq$Hw&3yjQjmUsZje4fAWR3K~8T2JIV^e8E0MV9YmztC-SY1(DVg23My{ysM zL-t(I=dO!>Jd_hI!fC2&XSja~Ji3k@A_&@%U45DI#=|F1+G}LKe&fN1-O;?A%zg!W zL|fWeLH!Naw6p*fefc=9XH>v{ij2!R?+W_I2(+xmvFaQq^7wf}N|-kDarNNX#hxCo z|5OA=4D=uV(zooPyQl+2$ydj5H|BDG3H4}8GH-AvB3h`znZ2eI8sFYbFy?D;E%&OU z%sSDiW8DIVFJ7K+JT>M~HKj~yVP8FhU1qdfhte+9p?MMmOMOkc&-PdiX<D4vbU|PSIDKb;QJ6p^STHo11F$;|%hPkvQC3L87#4aB1v?S#$QbI*KA$>`-YNxLeb{!bS zU{q?e?zNC1!%aO)`s6T?;;!dshzm>*@#&(j|I)>CRz>sC&AaZf+Au1VJx~0xKt^wD z7H#o+U3;8oQPxT^%JiJ;K{xxn4e$~uoJ4KsR5)pzM+u>#H!3AmnSdI9{?gNr+1)#L zQ5(!|7}roT#M4q_gvem>#%c@2ElRQXFH&~NoG>o%!YnTtu_OayXaSd!oEoPA zWiN)86ht*C7tPHczN>)b7BrlI-%Me|^>^XSKScvejI0VGcTz-W4_#Z4@S-H{e#&#^ zy7uhfT=lnTRysuk(3(PYdVpE|ulXHI!L!_!LbQyE`c%>-5Iuk084q$o2=3Vn5eN zOTnk?C5>!bOi7uS&$%^yYq6yOosf{x@I4TrEfNOpev ziD^9#nsGfPM>q7Vp#NtGbhs1a83A3vPcvfafA4=x24^9w~o1i?W%tmQMbmZ;h-X(oI z@T22PQE(TE*Fr!BW_g?`AsPy!yqlgSD#IyJUHvsjuo7s;fka^8&6tpD{%FoGe5Vbq z3C&?X2QcB|3tSMQuLZyuu6Lz*j**6}^vf%POAj9`V$sIsb3HWqIs%RL(I~IbqvGd} zT%auHw5{L0)=P4f85PPs|0Tam#W9!?swrQ%z!6<#kX#ecm3GDvWIi}Pi6S~c?~E}V z$hGX+RYeI%fzs5$_q0&tWdDt5X^t_4!lIriMRJ(T@%`uZkVn8#>gH0oK)v%uWQWh93Wi(kQRYp#EIq# zpB{FhgJnN4HR|O8=~$MB&sb%o4R=$H`Qa%~WiCT*O^#35v7eu`&Xb*<^1Su#O-^&- zccR%~6%#?@Kk`?AF@P_m=lOyfgSKKWw{m1F&n#!6c5G22-x?fX_4H9wJ-@a-ok?3j z#Q*yB3{3Hd>pxmUZWPQHvRn9ka47n^*1A=8!N(_#dlXcWZe{VP97;j7prmjP+-PAy zi!vUh73R*-9A|&#A!UB)Qpzm!W&4?D&+|*^M?AaQnJ+|cg?CGmRiz#aDRyQH=t2kx zNAOwB63M9pFCVlsht4@)x`31a*oP-=&GvTd*hbTG_`Z*v#Veyu*;|rdi@~D|)`a6R z7q$>aWfD!=C7S=MbU=Kwfvh~0SzSeAbMQCmvcax zf{TP;5~QLZiey(9R9Ud>LO<6WGvoYwWu|q#QA)Ly{i^B>c|PiC@T!efL|eH-IKT(ANh`Fy-`s7rEi}!g)>G#NL(rgf5=lIL_LN;F0w|*n6+HX>{NGB*1Q8x* z^Cy)JghYBRLn6G}?BD9Rp5NuAQ*|y`hrjy#zPkg=k#lVvXe(!J&A5ihjRc|QL-aM* zTm(i5&N(tSE>qh$d6^y!3#Bh6Gl_t$pj4!Q`{i@y%o%E?r>0(IRqTu*oT`Zz59+KBTI*#C*OQk{ zPz%m39~otM&5sKRXQJ2vwxb*5m8Q#jE`iBdS3_S3UZq#+{P2|n=4XDCVW`*qAp6Wq zuKq7UhG^kh23`1oYVarnPFr4Q0iI!PluJVLyqW62DI}!e@IvHPzZ^X4|O5aC^JYeG{c1KD~#6?n$j1dXv;Yn z4&6m^OLJCDS-|xS$1A_gUV+~QYY4?tj~-q)diVA8S+YyG?uWM2v!aomi1e-{DA^^s zHRoJ_U|n;R=yDITYpjo;fDk|5!M){(Oi~Ia`BjQjdg1Qwt8XqjF@b}i=llg=?78}c z8#p4JXi8hyLt6s{G*d-(c`$9(B9mQu@YTci=J(#4an%jbsyd>W(3z<~PY*0j*Fx~fwV01WG9ydql~t587rL76rh5m@5p60t|7;syL!nA5;G+Qs zgx&0XIg+o9lmUX7+J&l_M%B_GoW9amdT!1Br7{@@$rv9zP7qG(WdX}V*Nnqgi`bN+ z9NmSx&A!dT|JGXOa?>zMFU(nY^SSUuj4Z1JyLc8t0Pd8}g%f$pQ3JwM-EHTdyZJ)X(boPGGW)c6WzzHAcIq7fjAJKDx&PO3_dl3M**Ae z8`xbLoc+PBLb*>6!mD>JlNrNL)IZny#f-C8>>sPgVG_;BEQ=`Ms}L#0RTI!O;Z{=q4G`A^?K`7Fh-EA zi2l59>8IU5<~xMPOEl*CjW0ZD|Hae4%FYFW;SkQ7Jl0CF)tEy13OCmp`Y!xka8AyH z=5n2DUEuR}+;Jit&5ctXeVhhfS}W42rXD`y#eE|5*_jJ#@ z7F7KirpBhMetoSy@;moh`!1Tym0qp1Nx0p%f0Nz&^xYnj(-z%Q)OyM2I(6WTefZKr z$600lJyzxC1&?ZY2_Yy+2u{9d`Z&UvJ3e$-qPg12pZ?y;TwAGGz` zc2fSswy#-Q5RdVU67Hu4iQ)WdzOKSG*Yc4=M>^-K^MbQJ-~Wb)xdw8C%GQQvZ|tgd zE!z6=yKmz>N!r&Qco@fjRn`!=askby=I6+fBlbhK9X{7hqu3M2pp&2n(R>jeN5TF2 z{`>85mDsIY&t?%gU&1+_@5Qp`>-TycBl7qo8rZ)6(T8?{3?^?}JD+Rv*i&OETXR>d zwcOTZ8}9C~w(abE#unVOQ!%QmD(&&#ewfC2>oMXgz){9`@;LMHIgjAI_Lpz5Lr%Fh zZ)otcxZoKxcIxCw`@wV1+3~kdS>HLDK+n^$BN`I!cdp%FU%&rByKl#CYpSVpqe0*4 zKb3yy+0Xa%`tfV=!2sW<+0ZrcGJ&4(#r^ksw!zPSalp$1 zlv@7uV~?A5C{*co$OA(p>2K|_c;Mrs_Jg0j#tw0Bq8qZ~!#rt1x?~LjhHY!s**bb@ zeB-`{Xi(q4ZCDWwba==4zw7rSI7|ry+J zFJ*P92SzTwh5r%g>s92!zy9u@kjONaYm9@L@ z=@PALXP7tGcKm<-^Z#I@(3-9r$c!7muCf*jN(*fB{p)T2x9_s5=GoiVD~Ox1Z4Yj- z{a@c}ZMU_!bANPj%+4M8s5<1h)4sYRmKH6}{Ot8kF0+2_W#HFgk1?bY{x&arY=E;rFkEr}u&PU!SyA%f<% zR{Pe|Pm|U)Uy=>3qNOm!Z#T7cc#gCDHX9$Lbm}vd!D~FKA68jJWpleV zG&cCWiJfYQpA=!Ns}!enxZq7un>PBzXvZn-IPmksY@1HQ@(S>O4}DCu!$2;sbsfDb*fx9H zt{v7qvd-GKZ=G?@d{8OFV=KG0RMpo8PC8&(?C^EMMJgzp+`4nS6*SeeGm~g*o~wh? z<=8K7EVQ!bVk@sMv1A-v0^>|5Gp?9ww}Ca=+f2LUXsF?9NqLcd^vZ;N@cco$>xtb2 z0V;5ukm=!uKn&hhZM(Maw)G{OZE6TYo&`o@tOI?7YtOuL4h+}$sQD{Gk7;Dy z>$rCbjnS{bd5s_4=wX0Sub%4qI@@{2?N;2@G=JLCy|_-iFlMDq&{hLYIF*_&JZhQc zJg^-1$|Pyq(#nfnawsPq9T;`F_3`VU+UbwZ*riijtfnR9n*-?gqJe@0&tKDIckU>) zO(nMl*S^tb%sDT~JU63vG&DDP217^%{tlelv8xr`um<_1ru3&To@e)tvLK@Vg8mx{hRt-y4~!T@kew^2Hx@E$Nz#EzS0oC!eH* zIVelgedjYoO6p5e+);oogN+b%-^;y*!-hPQe)4H~ zCB2eLTz*}mH9~D=^^7}&w55a7`U-RlQ(-O5$Xl85yVdZ=>T{S#;iR#Zcks#%9^GiJu6O$z?$Z@_O;*o z6|0O_yP&-Kx;J=4cH;SNiA(yh{xQL)sFpCD|6T zvcAHzN+wA!9qbvjrpbC2Typ%CSDz4BDdJc&24F(G7hH+Is@KT`%8&$N8MWdvEHw1JoF#>FnD99o zAqH3hD06Y-QYtna9-MyHoQQXc1{)|D*t>lfJ5H6@U%dE&eeWkPc@*_Qvbd!D%kU0| zc-=Z^ZvFjta{Nx z;*e<)22R(uv(5L$?N$?`Zx5xXfwSX!P|fgoAzm-aP%?>#Qr2n`P6CRA-14qBMgN5(R zutKu92^zv1z3e*NIr7pwZ@EGD{F^_U#&aW1?p(bPa&SlSw%xmI+aveW_%$0~S_rYki^5}sKt@%PLDYzP;6bB0^D`mHruv-gZ2E)U;0i|CvjS7 z2Nl_5>Ngn5FX6ic2f@A#_uAg4zLJ#-b6ofaCPK3n(+)YO&jq{W=)l4*HREoQ$3-*& zzH{x}cE?w7M#ym_#;}k9OqWs@?ZdIRZ77Y6=&6xcF61D~<|0@r*`)w=hJ~qsS~;2g ziITVt^$yuM(NOrAay)egX5Tbu&zpGPuf|T2| zlsYD6_RxfDC!8R?sMM8_-~HuZwLf|02lnh6ulecPhE_WeC`C-e86^f&vXaToQBM5J z1^koYT9egu-?)sxAr*14lq)7LrWGqAs#b}@%;H?^s7lpXW4hIwe2I8x43Ovx9Ome2 zTYvLGM%|?|m+joA7jUvrM@l_-8If{V2+=yr=zd;b&@~NdWc_4|){o5Z#oQL8|2Jp?NOGDWd47(q7el{_lO|YhIS~*FX8OXZWPj?CL?PY&b+) zY^lO_*u_?;{nrZ&ZvRPX4T{=Wu{O^NzDVk?LId9#kd#t*t+6_GIMj&BDM-eM!%eabyN-HnAA}q7N zp7&df8uet7?bi{?L)vCFdj%;7%$ zQyqr%m+rlv$Y`nk-zc7c`_rFzIq4Vgz0Vu1c?m%_3qdlmfH`AAudJtqq7wR^(t41? zSk z#6Q$)p|C7)ib%23C`pu9T~PycR_1>%_#Z@H8ooCvr|u9=C^s@hM(|JOYRj6n_7t*9 zC08mbm7`<}23-XX(c!TXw&Xl%$IhIwE+Vs!?%C(N>4dPzVIm$CBM17Xqc)bOfrJp8Eb+rGp8=x_hVp8we!ZWrj-dYZsqDaU4EOD3YrCA4o|l+pN9#Wojr`)I(s?T0yXRO7v6mcjbwv06gFFJQ9Z6mmCP(+6j6e}ZZ_Tv z4X0u*9%ZzZBFE)~JVKe$pHpTV8u%O9+FftCOmpo|`R>zGCw)ur`|r4u=h4Fv?x-kt z!)_7K-FiQjEN3!BUve7yPER@X2lY{odSAXh` zs;{vXZ7oidDR{8Xjmz%Bp38^cmBG<4fsBpGW6tNfw&0!b{#IFzgq-H3FAb2Ti#aJ17d(h6a6WD4g=xfM~uePtl; z_^&R}V2wpBR#t#LhW@#jA@n&Oaf*5)?d-+0nL$lurESL<_AtB7zVy!9UallvZ~OX9 zWJ;FP2W;GqojYTPP8V|(wW-^HZ8`|&|nNR{He zYttqiXBZy3CJMfUjxwZtXd)Id0#7wC%3$PPbC2wCC$G<-D^8OUB88I&@=Gmt-AZ67 z!>SAi2T$+rk^qsx+7%tC=KGyz-lreXNh?&J6L2y{zm9${!DT_!`s;!1=c$QL9Pd@|W*M;qtJb|_Er4%Tg+OjehZh@Lb zlzM0lQ6DYZ3d@>WNG(?>P9Oc8R95=T>#w+DSlbM%CS3CsgSJRP@Fw6Quf?ui8C;&h zAUL%JDm4FD!@+mD_SJLMsp8k#rc(~n>1OI$EDIc>EzyFUhvm?M8WhWU9}cbK&eL6X z;HRHhC))^@RjYX`K{Mo*b{rTZqiWw*Z})OI?L43_bmR@9N~L<&NN-U;r5X(M8f08S=wRY^?hpWv z?(EO58^37uk7K~MT1L*q2&84mlnRGL#6=qVYprNrmA)^7PCM*9_0S{s(%Wy@^KZWH z5%7oZ+Do>^8lNW*QlW3gFaTkai(3-_qq!G-$w*X&vnq7a!djqpn%x~B?3p^JMK~LN=n)!pT>reTb-1a)mPD* zYqQl;{#HskcPIpsvB23cd03EH_?D-_f)5#$@F(dMX|{a!T{Cm&K@sg;n>M?HSq(@J zpE^NV2tf#FstEZt!A?4gfY%_;e0M+L6=$_@7{ehD1=F&4&V3WkaYncvN3bfx7P*0r z7@Q>C@a(0F-j|6ltqfkaZM}L%PWt8mFTeM${fL0IoGJJ0+->a&urOz?%gG6I9XiMb zWzNsgRz|uEdYslnCJp_gXN7BUuHQI4>V!GyD))ScQSwW45T+yMo^XSob-zL@avNzR zjVwxyOoJz#N_yexgJhlm%Vct2HxdP-qIK>C4`9q) zC}I*)kR*r-h#1M^des1aDX=1130f&lJ{Gy?q;ys%YpoJO8-{qNGEzEa(4phOz}D?r zxn%gDVmZK)(u5YSnRq5LMx7!V1OaSlGW@;LpB`rB@U247&WY^v&J>`cK+}M zJO2Ku!0E$tl~xs7JL=JWIHwYC;ZPh2GU-O`7xSah81iXvmO0*@u7q7?3C1)hp zNz~~P=<6<}d zhiFTqBvI;$=OlDBnC^4h((!WS(;biRM9J*8&f~}}B8}oZYW^xiZo?hx*q(6{y-jeA z;M9-;ELqr?Y_Yl|eG#PKF&2x*7{;p`7oqV!H3p0Uebv!}VK3PYa`wOT@khKqws#MI z>`rIx>?eoH9b0d+%{Zf#9qF*h(5_^AX=M=zuhF;~1I%**91-@k6o!UxhYHHoDD;~T zKIF8d@+PfQmFOtN{NcOyxfAlwo`2T$LTyRADRh`f=aSrF4d;XBI>3@eWUK^Rj0Z5X z2xmP@@VDqnLnX!1U*?*38Tc-{SWEeWu>zUZRM5ud6bPHNG3lt#dLer9*8Ll4pxkAj zzI}`auR-9XlrrrGlh&umeXv#h~wWTsD)PtSm2;53gxyb^hRiG|jc}7>zkG z;5HXdr!=K6!f~i~X-?(f{r2Nex+A{`ylbtN;H=tB>4xY@E8)lQy=A8_T(H~Lt+x&` zx!jQsjpYmHG4(ry{xf7At2wdYQ|a?n1K3@5p&N=1ge5|=&vR#_>;=|PzF;hj6Zj(2 zsxQ^;8OV}fB7fM=vZycYNodtz3Z~gP9sG$?H5j@!* z3<1zYy%I#l!Cx`a0!}G_XqyXYB1t)pl+RG9T{N7UYbo5d2u31wD#Oafju_oq7;ht) zLXYjc2P1^_pst%J!HSwoiPsK0iYRGqNU@NCG2I+)(E#3p7Xarfd2=uQ`7Wn1Ejg`= zt!vlW|MZXlsjpQz29Vh=wy$8Q7M7a?-Mwu68lhFgj zz$wjR3*RD|he6sxR%F#xiO%~XtJJ25Iqc1S%-<%fOeVMyKRRn`6?X@c6tyACE#!M_MI<%#Z*cv zIZ;+zN~}w3mo!)LG-zzb#La!{8$K%IK{7GJcB(0Q=>+R^=Fh`*G6i2&X=^d**mC;k z4yP|C|Wym!J-lT_#t`cg1jO@tQkn=7^W<(qqQ}nrs!FwgA zW(?h*|2>nFMOuP24bi&am!5X_V7Q6S)Dy@ZP<*F(*9HL!^2n*8{w0$C!lR74Xv~cPw`mki6AxlN zr0FN8b)faT2idh{%{r&AFqo$0DyK=Tonvq>QMBb_+q$v-v2EM7ZQIt3ZRf_eZQHh! zd*e*zRn63Uqfhho^y=Dux=vMh?OOY{7(KqYj_u^@q*W6v;(Lz+B-oZcS0=XHwfhvG z;t1=qZCl`Du%CLj3eXFmi8L7@J{sUl8(++_Bd+3oZ2!1^GW@F?_EBq54K;UQMUuE| z?U^$5;Q##2&kFDL&D(xl$~~Wq9%3D0Cao3G20zmv>NOziv}>zLC!ZYe*gNw1pvEsD zGy-4GFv;kMrR}6LAJc5`{S<5Xw>h@p%|hL~m+_(2lKO!JH$D2@I#8~x0xm$Zu z&(#h}Jx-y50M>x8L_s>mr?BeG`t4#%X>P`nz;uA7iWEG$fqLu_85R+vIm79FPoOIp zTw$)@!vrppCLx=Iv;uXUJAoy(9xEKA7*hg-6~r9$-sy0RBXzy}=;dp7r>gfg9J+^m z86|=qWsZ2BMsnkDS93Y7+6P;lNX`r~`oZ~xc2N!FXfOv*E+LI`#8Y~W*`@hF+!u&N zVV-godqGQ@QDh)}fk`@6lQp;NXEx?2uRY~`vl_oQk9Xwk(SqcH+y^Bxr@*1Ikgtn8 zf|Lqs(mnR#0=l$n_7n2U{B+_%h9&y_3O=6`J^_d9TOwhYbHu;jpXHIj>$LnSTYM!;rT{71Y?(|vB1CxL74f;&TlyeN|p4n7z!UNXP|5(%Ug9)*@cl|#*vY;jFyk9V8JQK^EqwXsnSgX zq=`<-qT^p7N@9Rl-^7M z*a;I1fKvYWvQzVZ{U==KBd!1o>Gjd%9%-IXkELhbQ%LvYvlTfB1bqw%0T=f_7SX2< zN_h9jJO_^Rvuz;sgrhp~_tSFbRoe=;i6lxb%Qj!tG_r1dDG`3{B7wLl-_t*-r%S@o9>TzV^ zl%_$u;1`tubGMuyb4&z^FD08I`aJZufpFo*$(p9A%yBDWf4$<_F^~Tpx`{H>M|}l_ z*k{B#5G~OARf)zJS>u7|D)RNz4&|9#b6gj$u90(FxnOu8q^Ku{=Hu>GM=h(xy=Z~p zbMUeF6vKfUYB3BO>}*hK6sIIKrUKo7@K8C|uAt(wU9dMY5jVzGfzE>xcG|)9S>h#i z?ODxIw)bRUdA^QfQbB5^m^!S3)bxs{#HA_(YQ5ip&CTo=^dy}O_nA^MCsRHa$ngsm zflQ?Vse?f!RZ4qai=N-1@SghUoM4;!>FY8v4PdWevO@7V!4qzlJ2++;r1(33ajTu2 zhbjq*WDg^MlZVK?W-`D&8LEW1^ZmxCwInhG+7cyfzXQr2`|H=9_GAAo8B)Ql3Fu4d ze3G~=E3zlvBC@(Q%O#lDbGzZ#837)?(;Juc;zID%UJ<#mtBvmogV{b(jHBo&j+LcY zs5269#w^hrWeg@tvt8YE%V6B^Uw{1sx#+x5>2&)Kd0+I~l3>D_rYUkLASKB($*iN; z;(yMW{{rVDZh4a-K<8zcbEASG8Zlx)*u zA{{Jlj|n%)>PJbHpx9avq6<7`#l!X&IwSFL-4FK1dkN*{8zX?Ux3fdkDn^af{NMv0)$Lc_$j_$!099$d4 z_?IucgOX#b_F0Z^1G}a}WtYf4QV&Vj6}px|Et^doj&<455c-O4XajzZ^i0za(upus zADoi4lZ#aY1h|B{l1C6Df!#Y)f1eTFzx$_URWeAbBS6gYx68%e7biRj?6((=2V3|W z^JCGt3RNyoPOj7MQ5nl>xhkVXN|qVQGy@pWq8Jr_Ks?kwOHOw)*~b20`1P;MBOjlG z^c+tA9sG9V-}bILGNePlXnCU3m?G|$DIwEcd3J>-wE#w}9I|NBYNcbx<8&97bel-< zy`*@J(s|m#rBH<~*<5U084osp7&UwaNovV|KZOP{-8xYH5u-8?a)ydivU(yTFMs2L z9;8V&B8{5iI6P$k!|PI{0jL`wLntrlGw#XzSE`TLaql96CeQefWghfeqC_HYA!k^i zna<#JXH(4P=i}R(AUxrn8b5RT>$JHv&B%gQqJEeX6z!yaXv8to@2(&6@0U2i`tg^Z zTNa0D@86*`pVMHe>;xx;wj$9urF$T3W-F92*Hj}Y@e9*J7k|-B6{SUP*Y!*F(TE68 zR{8);;{J-SXhHo@%FHOQRWe-JfjR`7Np;_gt5dY!G$uDzL&gE3n#7<>p@?y^8*Oj= z7EY`~;svbLkb<BFCJG@H@;B6MVuR zX;&!x%D#=0;K{0Kx;=IY+DaU00(2ZeI;%O8HXwf>*72*u|Gsy6U-7=#`FOmi>JX8~ zaVl)E(cJEOEZ|T(iIW+x?Jx;^ihm$Q%#$QDNo3{?{gMAK}h1U$m}Xn34}KBO-D(n|MFKdh|cG_0Y> zSKB>JR~oU%LK<8BBB8FrBkNg)_3k39hfdanwg~|gM29Se8C%}G zGyRr*6b_R9{7r@Cx$u!a6@K3~h$`=fB(cd)lk9xl3(gSnaXOPA7y5Z4bgtr{t%mku z^$9@6DjO6hlshMBdl6^&?NDEleK>+Xap>l3a>qekq#Mh9rt^Kd!dJM@6NW?TS)h8^ zASrjn%~`xdh^Cj75kG5z%%8He`o5uJ$n!YAgCpomi`%-9QZM;snLocnW6aU!$&@*S zu|&I$0~s@nkOAG6PcDHuzX?$FO zHmLSivoNwEcCa|k=a6669^h+P7zSv~ACCUk#AUU%rM0eRQVV9VaXg@X7;6PPS%7CB z8}1vLGw3=G6pV0RBgr$O*~oQ-I{qY;sc$a{-R9ot=MIr372ga57~#Zv8VJZJAwY-3 zQQue<57Xh-15Xqti%pe~kH>Ji9W|i~OJB0v80a~qhRgs$W)i6hiDFnye&++p5^ud# zl<=|h*;~cx`UOhn_l)`6KH{@^pB14ICk|AGDbOeN(I1_a%Ura`CK`(GK8Fou5sg(pw zSeUnTj)igie+AhZ!+ zASgzKzuXKrnpgGq?y}kLn2;5-BxT(q2&o^Xj*YXUj#QyQ({k6Q+)Gk9|62;l!NzYw zJ`!S;U0FV=E(JfXE|^)kTNf#}q59J_B+R6z=rCS$ab!xJS2a}EX>tg_L-CM+41C0b zaYB_P%uQ?x zQ&M!CWuzn~Ks5L9cBc=jPD$-6bOfMvZ_Fet8w3Q+BX1B-Kj?KUt73OxvhlOZKp)Yv zc{?|by;M7}OSKL|X(EPvOf@-GP3#6~h};S3X3&X99R#C{K0 zy)|1@hZ@4-7frDbnG)QwYB`)vy_=gBcw9@J-;X*CGLHwor}^}V3bf5VeTk27zOC?N z+Tu2amxa20H(b%%s$*NV;j1gG?3&OBt;uk{oU)-dw?RDN4d#`;p?kdDO`J&z+naNT6EP}AxVbZ%*!nN0nUIacys#_Z}YWAN|i z_vLrf?)$dCop;MtG{busDVsfTsXacORvW@MT_4e z`R3gjC<}<0oJceZ%Vc@T%clZmbQbP)LRpWMJO5jqtil5Xc)jXv-RC~RM_@Dj)0#bR z+b)blzUcsKeKvy-1xm!QC>dq`uPZjbi|0zzbLjW8HF3AI*E~&I&uU-e&Dw(ygEBzN ziJSscSws6V&cuK+eWFlHNb;iz$e}MBjQVC8~f$;&Z?E-T3`4d#`aP`}Od=I6q0@F4IvmO1DfRHIS+lrzSWf zlVglo;YBz3MNdT1_xI=izw61n8dt_Lv;RM#43WG`x;)T@DV*px4p3};8U zvjXX~t7dh+diNdQRlm>MZ(nQCQz>b7)EUjo#fsO-5YT7h@uj&XWMe0O;Hd3?yY`^T z?VtAqp9wi5^91(R4MHojT1~eOLfVs*og$hZ?qrpRKZ^B~8n$f?@U@->Yr2Y_t1zqv^XW_&vYCv-(ag+SU3N9Bd*=FY^Tbg$HT|q$1Umy2ALMoo=82GF&S{$ zFcxleX6${A)eHaK?tRC7H}C!SR-=EdQEH zwGk_}4jDUl*Y`lE<(PgZwJ|8?T7@=xJzKdw1<4pz2@2BLT`{tbz+17`IL#^iz8Iy~ z`UcZ;6vffZ?zz5ZLaC0B7on^znSff7kq4wiVKk9%Rx_#E+J78#o7^EWwld8Co|7#; z%`_o7knN5F5{Yh0(5IZk(D;r1{@s|}t>0lJZjmxKtNC?088K=)JnXb3(}rp*skzuO zb-*){>IXQx>WeILtLY6t84iC>4uafDw;&nZBPC*6DrUy<@`L%*uCW+qBfvf}SF1~C(>5vr;UAOrm) zsrzuYRb}qZHsLdgT=;;KvPBtP9H$PK>Gf6HYkfcN2U0ocI*LCJ0ne&Mw*v=BlBbzBzDB_buRh_sn~F-nd}4ky940N`@5u*N^W^0WrxC7P!k(W zTWRl$Lwl;naS!y8M0-LdS5TnuIaGmNXC@r>S3#eK*ofL-2ZHr()BKx-lxIs?54*|4 z!k(R0Svz(E^&C7buXaNe#P_pJKYTHphu4~yJi;7a%Sakdx%Xla} znXY_?Vf$^qGNkXka0S2R>{mEul}Xo@mC;jMv1h}ZTV#s^C@o^ZzOIu;avBX|Ed-x; z{I~yI+HSLAAgJ84u7Nvr_+-cs(uk>t_+&eHICHLs-Y#J0X+bA_WRozyQg1 zP}~g?LZ%z81}BL{%xSJVl4d&eLXff)w)xzAKoPK=Jz1gC=gf|VQz$HFpd6P%2sf5&NFX&&$__p*wvalT92dH$bht~&0+wr`gX}y9 zE)JMl;jg4xxGsPJh7a53$f|t{$Ncgh3_Afw5cpIo9{(hN0`NzmtFZcOiZ+xfFB)U_pGaM zHbp(O4X(@ue5-^(w-Xm zr~aj+8WmqH`0XD~zUbYWBPXFW!#QoCgl(Fj+5<>-?0Z3r19DKuoXtuZcAUxIz}HI5 zyml>HnGmEegj_J0W)J_Q^zak47nk@iiPYxuDuGAPLuy`w8wSqIXCTMs>zye5XO!(% z`W5`R#>0hJVP%@a>{qIE+Ha+F+0_X@@O(Q`oe!Ktg3c7xu_4qxGDYL1U#e7cn9<^? zG5SJNx??tyBau;cbRv9Tgj zuc!OoE_%(67&Jefv3LbT0-->WA&B8bup$_c3<(DW{_{e-Ih+VW5)TPWht7Xq{@-+U zZ;VKsGHTy=V)Fmnwg0kkgbkW7NIDA@ONaTt*7`5HD!>0;c$q(Z9`*nDVEIU;bx zFwQJgG%5Oj+Q9$RRTBIZHhG>kd;#x&M794>5CRPXA9`-^J)p;s3v{(BYCx zhABgpox5JcX1N~^_~Ug%z^9~E^CjTSs$@Zpe6mc9$|vu3*%vjJPu$K@mh^yHTSVGu zSe-bA)1wQzk)u{yF1TaC^Smq9Dm0H~LyA#5*{WOuRy~qIJB+X+mJicOy z)lcH%67|~R08I0Y=Z8>BP`EPRqvX=penJF$2AAVn^JuqHpY4jf!B}wP+9;5D*;BBM zZdFB%Jkq|WcU7NqYm|43TB;apu$NH*h(^cI0`|8?BS(4cURcN>7=S|kO@_M2fI*Rg zz?csiz4OY9RTTOG(pULPZ!`_%hLdh~|;3giJEP$rai!Ib^Zel7nC$b&H}| zw(d&6#)1q_-#3ZeOc|Q)Qq^@oF zv!E0m%CKWIEy;n#8!3NJ6*pYnQpE#nEMQ&$fp$C+ou0SQ$mw3bC*U zF=)M^ZN@OL=)V|(R#a24_`zy64K=Y9BCmuT!KCLGoiM!YPV<>mamlx=;QrcjGEno zHS=Ens5Axm3Oe)8cw4H*+IAWvzNz=K%$S00EvXWr5CiQ$F*EKrrBF&}!C{+;-a+{0 zwN)TNyj2`gGEuWJnl6tmqqq#TSx6g`&Gj0LT~xC1?K6~C{^=)SfC=~7HerELZC!}( zp`A8G;pNfOu)&=hT&zMa&_ZY?wLUV7248I@)dwy{(>7#wq6V~+5$fh8*f~36hFt;$ zfwx5kI{K>Spu_ONhu@Q5>J-xIE$ACh)mrop!+>H5^Gt|!wOhd5HG_SVe~}qO*K~HJ zEsd3iqH;~c)kZFg2TcTXn|`Bg#N82 zz}g2DPU{G)pSIg<8cxJ2=--$YYL>BHMVH)GVrZ~i6edZ9@CO)VA8q$piyJrEZi2)( zWlxmD8p7_T;UsACKrlo>>g+29>!ww87p(u?@hszZZ=%_{gZmG`i1zi{{sz3I^d(3x z^&8IBRZCbDRj2i-9Q7bWM!$1t_;H`7#P`;Wt6-0)a=O?6AC|f4Cz8b)*epcwU0v8fyAj_>Fu| zVNn}mn_a$T%S+4LeusBkeZQ($eFj@Kw<6bFFL~6xR_Dr1I1Yi2{;q^Ogu3h0wv+1Y?~Wj8{1 zfGa~vTW-3xVSbtoLL-qNab=+C1K&>jL^`-#Wr6g(5jRmjzxwHHDAwFZTdvg@c7OF3;N;nwV~n$ zS7#QD-1eI{!2GAk?7`yDBc@o4_;ad-rfsW<+d`=6iG`5~$%2{#)0UQT!=6q3{es-y zUtP@x0;_DSnh?KzySkPRO2|y|(KKRu(|ic87D$*5kitX|7W>kPg5t6BSjd5}!mp;< z8&}|Bo~LDt2a!f;K}Q~#_DH>Y@j$VZNw5Y7i0WY2IOTY!0j@4COEY===G6=tD<@*q1NSPLDh4UT zp=x?E#pEVolfYNBPQL>oc${lcK9YSNHYth1w@ z;PJNNoZDlF>M^k}>Wlo#D-3hOH+;%$KC6MO$xv;*qEVQx1O3y)aKcIp{1Oxjv(V`S zMOvf0q{9NII2Iy-%A4fn70F6kX>q$o#~KC&J9N6E?5FGjm?wd`x#fmR-;rMxM(0+ z5gYAjRtuqFU?Mjj3lNTTj)$Kq3h+f6Jl}e4#2R*0`~6Y2FV$BDR%B6GoNm)zGFrqW zMOtZ6j>vRcH1$Y{Y>#4KBjN9CDB6RtTp(CPwSy=v!?;gUHL22ojC(Q-gKzc2vyCPXwMS#XKKoNe{k{TiE1KdsF} zc)90}O}(uE!Qb%w(f(++*4ola0t%wbO7|N2l^H-F;h)2nYj@-a%PY=6tgp;L$XvC| z@`eeMR?>Iq>0-?*#?3%pQYi(A6w}rD)q%$#q&-u#Cw)UXeo<=&33M6?iz(o|5o`hd zlMc^z4s=)Ew)ia@w?XfsXZLVBl25@Y389Hda@Pp`vI*wVv*f$}_WoYoUA;^-XE@Xu zwQpK7r_S1>iXcy|fv^tlS2ITEqS4wJ)a>2fx%u|~R-LkEPuN-q+jyzQB-kP=6IfY- zKlavi)N}O5M&n6+J)*D{|87lOpaGmGRZXo8A8R4SBjD-!xbQpqG8FhTnjf0a)aLJi zxc9tn7W7k^B)y7Q_M+d4-|$~846HP-+=f;CWdu~@U~^%Mg8gL`m6_k;n24yZa&fuH zFnsLkDcymEXiU`it{((PwdGnN$W<4*iYE3|?Z;6Y}iy1(PHU z=aGcnOrfi)cTjW{$wc_in#e^~V8*(Z1K4TG)66bVnR-*QY-%&J+pY>!9Thbjoz`M+&A2*k?7LKA8b> zQSpAZZ8xX~1+PkZKKZGJWqo8~XX%4Vc<`LR*q1jxe+3741ZyTSK zQ3SLq^0!$@Qd4c^(gy$o4_tm>TuymN`%~na6;Q z#Qq3;%9lCdVoa{=LQVnLE1*F1K1!P5loHTi?ZTt30W`lmL@lhJxnwj1DsuJYb6fJFKKx5)+~CPjVPEGyr&H z$~oIKTyS!p2-t_H$#IlH)W)Q+B~T*5b(qR;iF1X@rO*S6>l=F$^?6zH`)93SzYz{i zCkeYoBXHv8E#{zE_DB)Z&KTneF?GNxoe&6ykTGMhM#eD=jPVQH~(Febri z<(yU_{TY2{zY;xJD;-KmN8~3}741VFcJlBc6p?`~BtbjQSBybC(I@EVdlr8+5C+=F z<{MKRbKSzGc`A$Oy*-;?)?V?HWqhOD(CTuj%O_mnP&k|GcyV;tV8GYhosmC@%M&x9T*!M&nJE@US3z%u=;2YgAt@}V%s5u@Lf8DQu+h)QSNA+QtXl9||qH_}Al!;ex-(t7iK_c`J|F9cDmwX?Yqj1hJQ zcj;bquOyy+&u~G<6oHY9%)x@I0zc^{H`lxR-s}9zL-U*OE-6aRFG3^HMvss8yKXV$ z+h2^e9~i3*tCK<8)3X=tRyaAziDs@(lf?d(pZ%J{sVM6)Kc%y_wQvao#24G-<{usJ z-50q9s#MYM{>!%|e&!8L%v>N9YC_?uxjMbX9wNP);Etg}52zQYp!u8LeZT)7!`VeH za_t8IQRNWouwL6wulE&W{iZi7G@}~e)_SzCoURf_IMm-Pr_f?4ply(99z)EC41Y)7 z`CH9*!Ppq_dJroLkc8V(lPjNXuUmeicQ#?ZVES|R79^F6o5&5oq9~70-i-v&gHs!i zh+)buNK~mk_qJQFbMksbB*%)cy^k0>KX2Z3Bauju@1;ESIh-f-y#GSFRXlHVh@S7ut#p9~0Ljp_ro|MX%qQy9ggFpnZh=J);L^PAH1xHM=DT{W8SJ?Q+nz-Q!ug?f#x47_92`IA(Y}0C@<=gGu6cF*vd0BpCwDPeO)^z5P5Q z^Zs6x&(1Xo2&0pPzIIA(vFHCHjb&j6i(>X6A;a=KF;V`ND6r1(ihC2$2TWAN5tz^C zcai-08a=By3*-arVGy$6)Detf z;SwAlN%)@A^Q38T4Vr(u#jl9>^XyPfJyHNdZa(I$-riy^*Z&dz{k1?|htc;9R=>$M z!EASopH#K zhc)n;5sj(imqQ5tRp0Htc{6a$<_S#JTr$txk=^nzywt=N$mj}iNcY0PWRR%nT3R*V z_1)LxcHQrqgajwl@gssI?OkDOU90umm)@Z#l^jSLs?0v0;F_#tB<3OZE_FD;>iUz{ zwAbr(?s+O4a3R%!VMBzlLxt8Fdu?rFtR3FG#J-k!VphfvSV{D*fF;D{A$O1g^U&cT zCJRqSuW&*#O?4p|d~@LG{g~Qaj~s-o2K#zFQ{HX*`<%4}eC-5l`r@=I`hL#rZ=J?X z-*I?U1BqYH`)_Bo$ixO7;hT$z00ql?J{=bnrbdUH!?neJZPIP=3?yf6O5X2|y#=DD zBclhVxeAz54CD)!g0**64xiViEDoP0h0TvPdsKjNUe%$eWzp{3X zo526*<%`T^bo;3eFnpe`#{`dl zm6S_)8e+2AZ!CWOZGX$VcgR)Kcq<4ATo=^+k5~y!wOoi*nu%63k#GZ z09*ped~CSe+x=$CIWL3Rc_xF|+>iWqj=Y-my4B?sS$-RA#baMB@A8O^C1Cr|`7iUkJY zWFZkl2tHyt^F*|DK)rb$w2#54j2 ziXV3?L56uG9DGA2u9cpB%t2)TIwZ{b!bXl3e`9 zJNWzFgZ`tCgROeg9mjr8w!w9E<}R=ILo}V))G3|P(0O6a(FzyluDq+bJ9Ax^sAvLS zNh9pfbhh&ZifZpfLT5p<&oPNz1Uy#>(6)uj`o)x^rCr*7f*f|}b)b1J%DFSp?17A%H?&S}qIr#dYKbq*0%6odiPVFw~0*0>unPye!`Bwc!xb-p3VpS<^~Baf5{48|js*muNb$H%1X7rQ5nd{R!RdPCVY6B` z+=PV_?Jr>L($Mg6iUcv_#}*Sg?KQyitT-sfrZuM+vTG*0XQSE`%5C>ifcBCiCRVi3 zpPiJn7YUE^Zv8ALY)x953^l6Mb$PD@k&aqv(#fz}a13%MTSD%X`S%K8)%TNNpp*Z8 z)bJzO(lBHM9Aseh{xSgVN}bh@*JHEVumWzVjJ{K#enD%XaY(Imh zn`p0>mSoYrQhHwa`)x^@KJ@#Wf*<%;X{VqKs3T+o#U*jSOxw}e{d3H>f zOkzYk^T;}Sa*VYmRk~%f(O>^%B$l1BNA4%}qQY!RT-WGhgS5osxWm;jItCyc{Wo0k zpm#hG^-5Qk4K0=ppqY7(-k#@(jji)&dF~>q9CEV>+clkA9?2im_{A3+J!fZNYQpb=8-yj=JTFwF! z2vkqDH72|#q3G5mj{;$Co)$r`$oYnMbc=7Lu z{d~rGKX<=>@z+Cu)^ESzhFDZMk2R;$s)}6ZI#e7~e45vXGYK}a^jb@7Dn5XzcU?AW zVS~QG*ivkrA45JFDNk5F5URmWr%V-v*wc#bA{DbbVMAt1)Mvd+@nm+Ib5_KOTIDxbr3l z&-G}FM2{MaaVm7Ht46;O>TZwPD5}x)Pr`#pB3D#jeR$(49wv_@co6}5pi+&8AZ$`- zN(pJf&Xs|?)mAZnH_|Z6?8>U-syP~82O2r#FZGvXaEKgYvAKimhvI12M3bgE>WNED zn8gQ+-%>e)NNZSOyFQ5@EEG7yUF zg5t}$VVl+C$(baIsRM@{Kv`%8j2Nu#9WrQBJIf>xWg@L(j>oSdE5(sSfDClp_?rt) zdy|`}LTYF`(@po`!SVhM8m8Cuxr+Wx9?hZmIgf>=bK+iA2bD4HQ*hYAsVNXgU^Zz6 zKUJQxEHX4i?R>4_|G@%aUzNTon5nK&r7!hRYk;mV^8rhDwOx#X7mKtZQ6Z06KozS* z6uztP;ZK{yLVT%Kgb={RjpKqp0;!d=P@`_fuIJhcVdSH7gr`StnSvGESQa$=m;Q zr}a1Xcqf}2EL>-LI)6uq)vi3!_38xsyLpWMTxYRc=j!~jhpxW92H~p^8DLwF5n@bC z;NoS{O1J0PSae=@U^@;g7UqYFrV7$uCT{5-27OTP$`oaAWuAtyVCQ*Zn>O!~KGNWr zggAcIJ45c+hDmj=f?a4j^!$_CYwiA$#{vT83GNKiP2T)R-!;g6Ye}=wv~iMVG&aZo z#Owk$qcHA1I!BLbRhE;6-{4+Gs>LNhl7NS#nUvlz?4{|bU>PTvj^p|1Zl1k&*@W86 zQ5IrpcC2nJ>(_v?sOc~D2lxLZskxnUONFEd-M$P>a*Be<0@`{ z@F=}zoXX;gET?5ERFZET*eb>g0g#|uxoo#f8oba+-_1jl%RPH!+r_Fjsqv6CiVxk& z8)4H%?sFt-KVRI3J)q2m5%%kdy^G}4p*n5cVF(L=bH5eFM?QGxl0hL1z~ZfU3hj~@ zT^HFf^z^bDN|I=rS#JBAXd$v1cZdxqc8rSVJjK8awiQh8%`?f!p z$$^lx;4W(}4*>E?&>!+_$>L@mp%c@489@feKAK%j(`K`dKuCFvjoFjUv;6^L3pF(5 zgB$B913@B8v4czb%h-fA20F?XgGg_$yi0@ly1QT=j{F>sm!tI|UvUDD!_)$dGj}!- zHg(V7ylS;Y+XmZR0OQ~sl{{27YUkr3b=oMNm#s1aNZwsM2E+R&Prdt8`Qu?afW4^2cL0Eg2yviJKiloTL%pZS5 z2c87GKui{iIYS(OG98iW%y49DR7>h6MeYk8D$Kn}lo}0o2w2ZhU+gZR*@daAaJ`Q} zu_F1|LgunL6(L1IO=Ebn9kM$$mh|{oKEiOZZg|`99M-4AibOBR3FdsoVC=9Y8-4A{g+qVlE7!fObhWI?{dXp~5(Q#9>M0i{iXczlEok{x&mgkKq z3Br;#cf>16Y$0<&XBvyy9b+!I_zN`&Dbr$XewgEL$LKXxO6jIkOO#~IxJkHQ%(33b zssi}Q0%r|h&^Y^{N-?@rsKX$Ll5fQvJxnbq?y$#rgMLY=t;0?Fr|(1Bt^=%DSH2Y= zBI5=WFqgd#CQ&wDfZ%q&GI5r27bK1F3DSmJILpuZ)Ix z{;|c$)U4zQ=M0hHUYFcFn-8U|q0UxD65P^9SDfB_*;|`K8$UzdPdU%R`H@;4ockcN zE~A5l`jqYyuawS+trb|PpF*&b1EZJoTsMs`3X`7`$YFTLTy+_e3+0~}?J?H4p) z;IXI@-dioV4!Dy_NiSA*s4C8eI6|_Oq!2B=mA=0KffjSUV4G7YH8NIxIm!z!Ts5bT z(xZsK%uN(Y$S2tFQOG1Os_OkHU^Ftr;ub6;7@!8s-;xLvqPeRE75U6i@1X#qLUU|s zl^`5J#Gvv*Y3bs%3@X6&5ta_m!?lRD{Et-*rzD7>Y3R2Goe*S~*~yU6q|B&n45FqZ0b}*4FBwv`&tBw1~;$ji+udOGb88 zM^%nu@z2T8K+h69NE>2;!RiR!iX^|96Sb92aQQ0{pr<4@X+=93aO&r*ON%PY#CL+y zGFcf)UMC>}c#N>&rlsf${-8ncFQ>lkw}U6Jd6-D8xNlIh7J1-AhL41>G7I|Lux9RfjvyAvRX^R55tocF!?ug_Jl)xCP{dY-PX-h1~` zRS=>K?&6ssB~!VaU;TVF28FczwL}oZ-$m(9IyybLB@Tupg>MwuK7?cIrQSBDa}M9+ zDPc&r8V>T{x={Zrmos#^ap(V{@H^GGRap2=zq4eE3MSM=xXrSR1&I>CVuJ*}c zpjxepb}aM~V}aLr6bw9mOP`g^P4bahN6>|Rm(d_9p`$A*vS^<+I6zE@r@)QEZx>{GPszvJC%(Cu}y}FLMhX(GO~?&MY#QX zGHx?Wst+{mLU(hqEu;^(PxMyC%;6~j>`>)O-ZNQV8`)QFd9lgv-W z^qTTFhy;99#{Mp_jfDX!hShFdHcSyv)0!h>MGvTbS@+Y*GdqX`*~D&#YnnSFUqzph z+feuOF1_opveh2`w*^>mnDnNK6rIB>Ra|!T9_qvh4tlzrskUpxQsAzalzwfm2gzQ{ z$LdEwqa(Hj%z+8J%U<@Xk!d6n)LO4)9$(#S;4;zdv}kngubkB|n_AW1=H_TkehiP|1TNg7dol0frwNN{t?N#nw<7E7|}0^fVfh3Fg=C^*P8 zNG_B_!h2!r(N1RSw$!i4CMvy)EN*3AdZ5m+}XtgCH9&E~i zO0L;eDu-+hfivZ0VyY#OY9gc%n!PzQGKJH*M-cCAsqXS29h5>nsIyC_qSjj2{IcFU zHg6d@#J9>q{f{orXx>thh<+O=hZ2x=a<|VqZ|{n5`R#Hc7$yRPn{GDM+i>cOhLOSg z$cTZf3|1%Pi8TOz-;`ms*oVOpZ6lLcENCMpX}T?CJNUM7?f^K-Z*7&p3yB`)I^)SS zk)*O|fpXUKy|5Fn>k2b}o+385r);Nh8i^u$!J5uAr}uu<(}z=7t4yOB>|C#nWCd&Ijrm}#m)(6S;3ON}uwJ^26D={$Wm}*eP7LynH47%7G zqLt*CdP*@74js1O_gsbT%^_}ygpf;B>k0WsmZQ+m2dyuXpZ)OhtXe@FiZl0ZN-pf4 zrMzYl&e`;zoZs*?&a`g_uP77o&?@_Viu1keNi(wf{G_N~!rzE71hPjJYQ9_i33Zgk z)(@FAuS%m0{U`1FiaQ-G2o0a>W|Bd%)@|n>^@fd^pNSAdywN(5B9_^3T^;W)vg9Z z<~{0Nb+bwycS*1-7~_|_+~jeUiE+;++{wr5EmwCj)0&ERcSMD-db-cQQNXoD9#3cHAG=@rh&{*IGMj8cx+M(`WLO<6_BHF~NnSsks#VQqXzVttiCC9O8NW z#HVi*qilF#qT9KKj>FEpg!u{=IwjkVHkHY{5fU()

rVk0bVQPRD?$9NrkMc-8{ zwsXjs5C2V904RRaH9X=M(?$nS%FrVG(O-lKfH@Jx><~rleli>+AL3#V=9n_Y1${wQaEwa~*%h=#h|?Y9{Z|LmIg!*}fq zbGF8{>y#>zg|DB?OI8VRGFqu2I-rXq1V?h%)}qF#TRXjk9DU2UjsIY|=giyw!wr`# zQo|GnvWK2Be&@rO<`~{}zMtC9!06`ZbynxpA>*J9hb6DzYgwjCf)^RH#N=D*E01jo z7v^L0LX@LisGNR{`1fH z*1I_xb&&|##swT=TO+)sAbC4~I&eg%0MIT2dM&Ls95Dzy;z<;~7q zapUMrGWu7$)qR*?DYk1zi7ptPIe0n0D;sPQ+5!7Jbpm-?m#k9lh?qS}q{;rKe=BkB zx4^rQg%iV534X+1-$UtEQ`kxaM+qKSj)UlrZ6@*9?t|oNv4cK=WjKZ`Fhgk+(9y04 z1?F3nYQsOO2$!q`8{mcOJ>(lW&KEIVA+##=k(YQW6eubP^uf5x?(3Atf6J;0E#%;nmp3v>oxe9 zIPR*QG+D<^1=3q#)VNSB2StZS%^$pf)oF`I-1d)`zbR8=fngVH7s$!jHp5%{inMqC z$_~j3xX5n4#$d^@bTSb3pD5nf!y6SIPGCvDxcK=lOIJIHr=&;}VqNK0R04tdMJ4%BQw@0Zv* zn~U?)q4_`T=odXN@=hCqqhNYy-|b4jZ!*h6q@;x(yI>mAB-`>KEQbUq3KYQ)5UCt- zccRPfIy6ORwf%Q<(m>py8l;7v6Bz6W-BI=;mG_=mQ2o%1P#i{Auqm<3tiI=JlS9Mo z*to=;kPR%Rjxta-Dnox=p0Ujqx!5U(hvm+o@F~W<%$J;Dd<&pM^*Xi@{It$5kE(Id zn@k60l4%sU4d60t<{gczEp^h-(qfY;*LFpoXiHZK3t&e}_<5PN#NdqayZz~4l~jXH z8B>Q%kk9`(`+`#fCKS?qq=b-N-^a7kw$D6gnp`-TN%O@Q3MNE7@pr8cR6@e@ zA`~qH=r5agHks~8p`3wzJ_)dOcb4MA^%LZaWhe)F8qig^4J;|G9qoH>3pmZ|GT^KR zy)Ymm)4DqnRE!U(OJ=k{Rg9-F0LyE5K&0LPI1*5F1^)UvnadP=;vL zi7Dg>duu$PJjYkuI96$kX-Vpgpk(}xjSs7{aW==|s0u&1>?AI&g(y(;-bf}8PghmV}c^wf{GJecOGep>MZ}ol`nBI1TG2-Q7eZnWlG=~~u z<){{1wmtaZ>h--8eb!K&{LI)kZMrok{Jo_5eEkRyMWgVsoh$Z{yc?}E$U;GyE9WGk zb9Vmf`t;8Scry?vX>>3br;@W^(!Bj_UQ+BPKZku)*;dGQSKKOtj|xZTbALR-EDZ+_ z_ip5n#ppE#dU!@z60F7V`?F(sfz1;JDX%xZJGJuXvLsT&a;|wmNJLy}2fsx9eO_es zn)f(EwD%sPe$EEJxa&nzw@Y8q%kAutG8F;KE<%@Vu0^nMxJrrOq_4vB;sIFzL)Y!X z0M{Y2hB4Ls7Q^CjQsdO*Q+O8O?NZ&RTFF9rF3bHI|8KEjgL-inOT{c2LtAV^6wzIB z=LPD&x0_t+*qqXz4^D(e=EkI_kC0ETERkzz&sVRg)NbBlBPOmIXZi_xAgmO94WxB zM^nydxUyC(5H3|=s@M_)e~$^r?Vr!~-o$x3Dj_EtSbd%lCrWUMlbwE-mRu+9r$%D+ zT#cz3i4EQr_3i)Wl@7TWt_{rewFg%8&9>-Xi`S9EEplECuyE%0Zsqv?cD_I{d0kBq zuf&%jcJbCov0mqsvwQtw+jWcTum#58_Db=AN1_t;!JZj)f|!opLpaSlUrV@nqs(ew z|1@up>SzTzEj75a-UL;fTAM zPX-@>&S}*+wYWJRG||fv=Z&b#@?Uerp5o$J*@FdN`n>ujFO50ets7{+mN;41MYNLn zo)bAxAI-+JNUiI3+Xp<3K3t%xbYC~<4GW6UUgZ0lsO?}?uukR*eDOc7Xbo>yXW2@9 zX+i8^HRzHw$eF?_oifTRJKUx1(pqyoRh?Yg?oPba9?tf6R-fruzi4NExlgBevUctU zA#CuJB<2!<+LKGeRDM1=ZPQ>3&iL$sV-cF_Jek`TZAp07gQXEUW=_|=IexLy-a&qU z$;%rzP(&8O{Y~(V%~OnH0GX;^UpXy z&hl0Yqj6G!J|H2NFk1Ahr;W3E^GJ%ny>88;>V!xADb9o28n;uhk6J=cT|bnQC*m}q zDOD?jvvc?A_o!?et(qFzrty%kri zQ{F+0)j>q;6#qPwxdq$Gx>Y_x8sz9&DYji2Ejc;D*RI;maw?09b*JLR`Y za~Fa!dhj!SOj?j~goD)hcXt#9K;fmf+FI)(CF81OmliUudSk67WBt@U?Gl@`>ys94 z(5~dBdv7i`5)YB~ zqdxbH_J0cmdYp(R?;ID+T_*Z?a#hK7+Ha6R2(9 zWs#bW)+F_^KHb9rN+#83ah5W+{40rbL%mb7G6Z`3F*I%!pCO@fqR-3hWy}|8D%|Os+ z#5}@`Z^@5bPanBOf09eB0VtdkDU*?gSbin>-ZbglIcCc3(7(4O@tOEBpH3*Rc`2)R zZsNUtqFEmOAeQeh(}CoNs~6B;B$pw!WuQ|nVc<4)N7nbK=((eNii^X*DZe-2Z$*ae zuDatQ5aB+}9bVFjE~e%|6>0e;3+Ni-BesYY9d6)_Ik0+@QhT%&sNth`e2MAwJIc>C z@2w|tKG{q2jd|)DlWr)22NZ|PDDTu+h*@;)Sp*5Jey7+pfm@E&KJMNk(x>kb zeI(BiIOEK99=nAPJ-0kA*v7d!K_O*CL=LnKk2f8wSzn>HC{(0pX=|RO@*RDr1EZd!3`xu&?L{NDl zpwu_NA{8e|9w8{)DctpP+<`EH-nF1r?G=y)^pr3A^o40bh1A9Zez`LfKms7t4V`$c zclwhWsw(H7sbSS*igM%=*KQ3n>$wR}7Ue|1#PETbzDz_!DM!tHQRN{_cr3eUMGJ3% z?u6tRPpFehT}3IQpiTFpKVGAavvGVrFgro(-;kquQY;@!BWn=7Vh*WtO6e5KkTC;| zOY)5@3Vgsul=NK!u00ySJr6NF^~_*voGziyQNT%g>8Lv_xJW_KOSKmC5fmj}pL zD`k*Ea+(+vQ*)L$ztSS*0Bc=E_UnwUIc9y%8{8)xDj0r_`0>Bu2^@l<^>iwpQn?#+ zKd7hVf*_~YO%h}?C#{ZigJU!#J#*4HkDg`EaaIRe)j!|?Eb)(FLdQw9zxohjSXF*~Gmo(-SvhA*DJ zbEnn4{Z#xUc(PYbKQ%xgmU~9DYE+Q!|NKcT*P*0q7?S2qK6ck|GmSQ0<6Ft3)n5Sk zsakk4s9J9}r;z+u^ZILjX*8e`jd&EA7fAu)nfg`PmL?aT;}BC7Xu=es`+UzgFeFXg zia(aiIEHy;bUN~J11;|CgqN|m*qR#6tYA6smo9G;?)RkTlfg;~%3l_z(IK=$pyie_ za)TCNlY%G|QK5Jx=$bH@>Zg!*W=M5+Q}?JBaTEu=6~)9rf7J&a#g7>A+?JelGbd{) z>`c;W(7uaKpTOv)w&0nXxavw5YjnhvZz*^N z<;x7>Me+_eYD&VE7Q5Mb!G3C3EJ^bsdslI=A>YRm`it~>$4>OPs^Cq8IQ>AdN}+L< zmxHbiF;Ao0t&-#RI401*>cxN~AcB>=N~%sL&_zf0kdE5BxqP%whmO#_7`5yk`^R7lK1UgenPoyBL;WV;z>qJ7Asb|FX=r=O%zvd@1Ov-7(pAQff*s(qBxE*3mQ)D@n(ON^)) zaKE2&J+bB}Xr-7M`{+VM+c-T}shlKcb4{f1uCX)v5C@YbzxvVv@l!4dJHTX`86k)> zB6+Y!2@w)=R~Y$KDqN~9Jv?hcAJ!hL(z;(nZB_S4#Hs`Grwmhb`6=3sslB#A^m6)T zBJ@fMluv-nL~BoFJ85$W#g~hwJrY*n4_y?V3(m~ z5!tp$>FNn!9C7VhtF7;E-m}0FLCZ~Q3H|$CYKGiaN!~$bT#IvslckX@ zE=}(H0hUMmY)}$|z8076M1~0agn$Y#rf2RYmsOMA?zt`Jl}(@_0Q~&@-l+gb0PPax z!K|!L!Y$GK?91iEe2MDx;7}WvEKEs3h!9P($fj4+ZraKCzUTZ6E4S@?8|Kw33NFh;2HMk z)7z_e^SZ_yA6)7c^Re}pe`<^VRQ{T;*+5@N7$BB8S_d*vl>Mqvr{%%cuaM?#6#=n7 zGAPY*{l>pPHV%^BtwoQM8e zML*zdSm8!6OIpyLti-#Gx*MED||fWyiP;VsYQPD4yB}L{CQKXe<{9T z>nL)F9TJu7E2SE~PPwQ$XgxSUEFiyH>cK55nA8(Oo(sNl=Oo$&Rtmsx3cpdjd+Hd; zmrMaPq|Tk%2{6~rI&hS#+{`J9i4bNICfSf93!E~cRLa8WD@AjEmDXx!&}5i z=M3MJJA-_@rV=CYNbLRuxMAKB>Bp;nF5=Z3VVFIk39plK*csyg42;3o+aEG0AmVc} z{DKwR>NP%zH=Xpi8Gmc(!}+3m)|yL_HQ)*|5aZkrcaiVMZ0S}$#=#%@m=E%uRw@)8`@Bf5u`%8wdLeE#Y_UcopGW41kI5x*=uYaMlm+Q zOp4DTm}IfHpN`&3zO}|l$C-7@+h;l*E2MP~uFyNN8j!X{+mdv0fa`#EsgL&6zxthn z9x1Pu@#y?>T<7En&P1*_wm&BSMVGm9BH|JaewY$IWS~!3+;P6b+FcUvr!5+5_=XRq zd7ihso@;-W=&5yea8Bd3`N=vkuuik&N^J00D-U_g7Y-x}JQ7*tf}m@W47kiO9;t^v z&qi4Zg}4eG*t&^to`Lz)1I{m>*HlAsy7}Z9x6b}D518HUHqB>6Br*~Tq}LYle6;8Um}v|&<%Lh19n0Qor0OYj9kY>!>+qai!=NbFQ2&4X&Zy(l z`qtdn2%$W$6gOY@PTEWkjtwvda1UXOdif~Wx%+f$q8>-47As_019 zD8_rO^M`Lg+KFfn#f+m$JDJOqCttaNaQR#E%Eo1#k0P-k1gW_bKM}v`VEbrMj6OB~ zMKh2+^_M4~!sv;5Sx6=* zc&Zw$+g?v>J{w?6sbW>smr~9xlcz*1+g0gaP50nsSQ>1ZJZ3Kw zvY4lp4>enlshFM5n?tc)sSu0cF?-rCzS+;JY!pwI+WM2T;xd{nAu**in(7f)i13X= z2t}r4g3CDsJxv`u_}M4xlQie<20tx?8N%aT|^|yZwz-mm1G}vA1l^6 zJ*`SnF?Pmjdbat+n_bknRN;!(dQN(IIjmb9l%AEZ=$MZ@?$;cOEOR&%PD&jep!12_ z-6rku3gP<~mH`yNWqGo>kp8Ep3yv8SFBBf)N8sQ27K3#sdk9)3Hwt2|=Re7=*7UY6 zgT54ArT2a?nPs$F_2f z;zJ#}*4R~3Ppo$Q>4w=@R`vr{xV-Zrqz_3QMf6&V6PfLOhd^Sh1y&738*8X6g2OL5 z#B;T?@1i1>SDImjIR-w?j=x%z8h=v{WKMN*WHmnGlYq?yd22CHGOdsiexRfpq&v8-xF>!xll2w8p6olDpDZ1a~OToD9?IRmj#n3>Ik+2{kO+;}>Od zTe|cIa$o`egWF6CoJx^M|AaL=Ex_KsELrpeLZ;NNyPw^fE%DH?Gb=e0hj zQu#|n)LP0mLosv=kB;U%fAVLWbjiNN~eLpsFFL94X0y6IT~Hy^Bl zgh}4{tn^nZQ?`o@M3tcf4=J0MPp&^)OSCO`r8#UHe}@tD)gDpfo#Lab8Pj_1FrqIs z6AzaOe%Z9PYmX#a^ASN+$UdyZQK9SS`79I+>7rMuRZ)d#m3`3N&wq5&m2XLf-ic2YN z5x3_#TLv>jW6Sqs?8e0X>ymKVM99WCjqj}P4sp?&d6cLWhm+dB?JQS&agYQWgnjJLLV{-MLN40z~DC8Q_1L3cDL=M_8WmS z@egrUNwIifgb(!fAv+TBUZ9K8K>S_0bu@0yQ7fTlEd5WF!r>Oh*J*}ja$m@JG`pZT zy%x!Px^(as*@={#$MQjofY@K3Rg=Hf2B)k*MuZgYy=Y^8GDXFhWxq1&|NUZjs*5Ht ze2Vc=%BtV83f2ZnK=@^;leXvZJ4{JDm}?x{+N_JZCa`QdK{xho zqCc^EnN8pfU63o(ze&0r)S#{y3 z7)k{v1ewwH{KoKJ2{22@FaA||hGuCcvmPO25CIu(|5`uTSRrv{&tNua*9n=KIzltD zBOnF!F_klZv~4lZ^YndNRS-RSX8is;`k1F14mB)>=TsyB*MeRd^A5F5n;I5|8U=0< znUgi7s#FLe%%UC4z6*avgd~T%v^P({GQ(ofEj1t=tsOeL3F5fb0t z;9B=EOWSYo#@Nt(l+w`g-9}8&s^MK`BDSjMeyh;==a!EiS37q2__DT0Fpfx>D?05@ zFq!4lH!@}1#1GP`N@ZDvL#vHLHIFxa(vhD}6`uY9zVQnT6u{>0&GX^>uC|OG|BWIT zk^PHBd5e`Jeqb}y06BwKBh9zw3Nl79h1!3`;oG%x3dcXD>Uc(vyEDuJx8%i1Hr9!? z`&H0DfP~t1nR7Cbj#ErJGM~88u|;se+%o@6(N^;%W7R!*g|32m%tEo3-$qoL0@9S?t&_=+p%#PTj_BOuL?~vkoA* zi2Se$*XmkjSR>Vsy!iP_8QUhJvTJ)M$Z+qAr! z=_sBXj3?>-Cwa5sf^GdJLC+;pL9Srw0%x+4-yHs*CPhD}6q)(99O zl#=%o0LLSM0g&gr@UsAi$g)dF{)ftHUmVQB2;)c@CxIqQ=2j*^{m^|>Vv?g?gN?nw zO~^Ds+(M@*cVyWv9H1y2)8plEpEOE-3<22g(kt;gM+t0a;NdCV4JJJyEe`YV+b!Ws7juG?|KSaf0dTz+Dw`nVvD(LysuNV20ynN`b3`{ z_p2YAO1Psc1mKA1t*4l=DQjLiWI%3xkEf)2 zGh@ANU{yHJIFIOp97ZB+CQNKyEV?k z+!c>_vA|a7{2oo><<+1>Qj6;fMVHylEBaCwb`>{1#f_D=@*!%J28;TkELVI91oTpR z&r^*hg}40)&x}KuXI0ozW0lPl@{<_rTwxHV7e=hf;(jSUl(crFikplfa-EZ}%3`hzKJ5c{}moMvQpg=rX= zacM`Y4VOM@`>UwhxS4>I_{b1B=0kv5TGMb@Fm<6AJ$f)4#s=L&tlmT(^-!gG6`bTT zH74_mz?Zf)7BioXnvy!X^}Gf$!`G%R1ur}A;BKxgbgX9FwP+CyhS~DdA8um}+0rNq z1L50y+jnMT$c#R09t0s1Yoe7)AuV4w?9SkV?Q+3baLRou$Zi0tLrIJN=ZOhEV_5*5 z-SSc*vr0R9B!HR;YR@bGsr1{pk3nrbN}Qnh32(Czh^_#G)+5%sJxV7mmi`R5Zb#^& zFJ8Fx$KItZOF)eqMTRepfT84)_DcPK1`CINdgLIHHUQ2~dd1SX#`nS2T;g)0UVcP zaj#EK_P(+vG*{T6`mhSS=W~CkI3;{Gbm)w+n5n2njr+7vT2JFp`-uLQpD-kva$7rE z4KleT+>}LDD?yM|t;i6f)MPkUW<m(2jQ@Nni< z(RRSvE(un%Sg)3~{4+~$rgK!JvQO(x5|PK9#uhrw2-6;)WU`1_y^f4B77j3zsYEoz zsyH<#9UYN(Y0=R89N7ORU9pLTJi)ei)uwiRpD-OqLR_GXQ+iLmLgP7}S3TBMPdlHF zYIjj>*cKbLTs#hnbZ|rVALd<<*S1Bb238t3pb+3IjaIP6xGpxhR$AgshQywyUCx;@ z5((vqOhJXb!bCT8IPE3NNJZB0rG#GO6W4gO>mBP-T*Y#-&Z-Bu)wbj;9k)l6HE0|{ z3PP`9g<-y(EC#uP9^D4()Hub0^J#77$-z`+G-jvv#+5$Uf3a4+O8YJ3?!a2OHRIo% z0x}+Yjjq9)bnNmV4GaL=ogYMc6mAuzr|U1$UmYQG?+SKutc3= z$-)R#$`5uvxxFTiz6yRr9HW?85O_7UEqA66n(g6HRY0gZ_U3Gl{+2Mnov8W(ufsW_Cc15EU0_NV4C%IaXi1szvB8M_Q;T_ zh$LTP)H*FH7Z;?+iK1J!7^=EH99Du_eahe0WF>u)6QfC__&DGsFBEE^&{Cjzn(5X= zE4<#YlGIV?mN@~;e(jHDEU6wtX~-d$VNypJCabsI2HS%doUPpIa;R1lKyBLUN3e2Z z2*J0!wg@6$mZOwi;{C71>5|r=i>qK3lwn>yf&9e2#H|E$(8K~gY35%Vj>Vu^)|bz6 z^T-1+uIbq@yXv$#jXEI*)nAgHaLpY06sk?=wUPDd>nR#?Tdelm3{G%lAgF(cOvcC;xRxi*1M6v z(hld~qt-|9Rgg^EvO~p_$K#%K>B;HqvCL9}!XkzJ+WcNM`cbHAD`+j8XjWd1dV1)Q z#zI8=QvmE^aroWJN-x)6&TC1N2x+0d2L7>|Qq2H-juBB9=XV>mV2{&SW61{HBi?Lu=~uOQQy?cjCvELw8vV!ZdC+8mBy% z^Zt3SUGA=-SbY}``Ef7$j4p~H`R|0sjqYTIr-T+5ch!K3bnp-KOdr>1xVx!N ze8mx=5uRiB^r*5QzpNS@P@d5AX7+SuQGz&siW6E$P_j@3>HRC_Re5tEIZ; zgMt13j&t?%RIX7yZC9-J=GkT%{g;*tKFF55tMIt{@H=pcJ<*mnV0wI1kw%@AJ`6;B z`OwccWLxmFx^~WQR%idguD*JP0KlNz655wqe4iY#s3$~cUIb=|CbL}TJqoKf0J4?6 z&MjZNK#_!(q*K^)wF_Z$G)RW#k*W2IJqEbq&}exIqJz>|IA65_VDPl4=oCzZ37`gD z!J7lm2pi7hmzQZ##T9eowWEi9PZMo&7{r)V1KWrz|VfCZ+TMV0A zM2JRtfZCjwRFSElShAx_VVLbt`kL-q;+xre*i{`X?-hMVUZV zvZ*Q~_S?27#s?-L(=rVi@vWBn2HQrm)p;CjY@en3OZ4DW&1nWjA18qGM_WAm&(;i_ z?s*6m*Tg6kONryyl%Zk8D=i6^hvzTg83q+f6q%V33Oh0l0IdO@aKagd!Pf{RZb^Hj z+%&|cbcw<-2G0gDCGP$dr$R?UoB=f_LJ4`Z%>t1;xLd+Le0F#Q0hFv10M&Dt^deyZ z%zHmhZe9oWZQYb=6~ZV>7aCg@hq?S&u=yQq2IkQYU5_KOYAbOg6LIr=B7%RpJ1<*P zxf0D)kS+WZqpMRZr^F*-0T!C>iIyOBR2OH22Ka+cO6qlnsNJLEwh)2~V7r{dSd$uRWF|zWA06 zo+LmO0nlk*scG{q(63V{3mM4Vm$?F{ZyJ%V^?LPb1Bn&0B z1s^N!ZF@pM?J5%)mz4ZLQEZ8id_?_KiX0Q;b1^UE(v-wcTJ~qMm|AEjW6q!v$EM7` zCh&4pegc;ChvoqE5`YKCW2#31`;EsY$uy68Sp1lXk{UJ zPTrJ+BpsCN?<5KWqQ^T`($IKUiO94l#2$)R3$SFO_M~?QbEU4r0p)kM?lm^(2&evA z*j4LZ&gKGhK2lmNgW@+}*FP_Z(2?rUPLr7q8^`ep-&br6v;%q2wD9_rge-k%^@FFWEVen7z%FkBJ7%u}` zPCMjx>(Uy-WD3>LWu813%-3XbvMCN$3$Is1@GU&)>>}zgOT#QZ@r#vJr+K!e277xW zc^KxTlF^}QdL}*?F5uHb<#r|zBp+cqNdQr*kW@uRrWZ^J_7X4`5JehPTt7Qx=;MMV ztxs}B)-7R>TWyFZ-Be`SJH$G(#@zjwk{cICRh!I(PWQvInrg{Yyl^%(A7|6XnyAp) z1Uc(U3nV?j8Hx>*H*266g~w^l{?gUOoUGAKeD)^@#$_MeCpA`Nb)m zD5MI{V!}-+_sP9y+!#Hym#_9;bpgTXB7jW>O9nYyNk_xFmT@P!OeM_pd$HVcP+S}u z?jK8VlW^*_e6~7p?Xm2>mkZ#~Dd(V+V|ylEy2Q3141`crQ8|j25-9+4R#^k2sHcl^ zWBBF8rYInGwa;_{l#9aB)cvHp4oEV``RpzW-06#wmozoeh!-ETSw|SL3yZiF(rdkh zCAlQB@%gRIhizY|?g_@l=X|+ZpD3|lIr`VX18kc@t+yv5&CqA4H%Iy2+``W?n3JME zGKK++Vxo=uH@?U%EiBMYxRkO36jU?vHUh_Sqrn3r3^YCXID0xor9=a&*MFtY=Nc&x zjAyyevV>z2JmrbLfg^)Z<~pFdk@;bPG4%?-`_+ktF9y9)I_NU-x3ISw1vnjCT9p>4 zGJ6h0uGIewdKEX-E{=~zkO%MeoC1h9iIawS#H1AQ6a<=*8IzN?3S)j3y&zPPv3~Y% z!ngB!-pOdbOKT64N!H+q;bx1H-U(k_`A9Bn)9d@k_RQdHcn#5wFaijoRSRULbXqN# zF^SHGk2|gf{o|PuCv2n5va06)T-t*jce*$D7*oGsj=A>AIMaxA_!=E|0Kna*jjKd) zUTrOMInngxy<>mQ`$T+H*!x45!3UB>(Z!Qjx{gO_qF6cBZzb!W$il2v8c!c-lvnVi zQaBe03k~y23vPx0BSs|U@EDFvEcj7=>~n+=>vZPW5J>s|bEkoUp~sYy6w^eK1$VB@ zTG|2yi15`!C8@Q)RTTxKyV)Y}eM!)B5lFM$t3R-DwZul0Va(i$+QR%9kcQ2Bg()_p zShc!Ohfv?0E#`o(pk?o;Ed1>Q8Zk37Z}2rXzf-vy`J9jVN7k39i+Q%bVIR52R`Lb! zOuLeextu#FP)R_}2Q4_KrN+xFREhkC;A{awdRF!ayRE$lgEI|^B5xy(U6J_+S{(XZ zhTbtBaUYT~4jV#g7_^0H2yF~N>)4AVyFcEi=1*>S${mbnD2pR=JoQYoU8YCm{xQ<} zaD1VesMI9@N#-#pFb1G-z0~Mk2lqNXi(U7YwExj5h7Ai^5$ZL`L7PqFbD}>;8Y)o| z@2E&o*RtlKP8h()3y8z@vW;B&a9IXQwS@fRc1j(Mm8E8z0M)m|BJPU+PP7%O&SSU! z-qeD2sdc?)>Hl72Cs=oROCqcVT?8_U+~-FWAHaZbktc&8@njttO~6P9Qu{mf4`c;U zotV_8FxwjD4yli|DwJrmkxlW1xKv}AAvELB1=rqUqjkjo?|o!%U4NuaehND7--?4< zy2Fr7t7!TPblMTL1)Wiwo_XYX%psB%O(l=?$lAla4KC%h9hJaY8!_8a_KY8TyKOVx z(4@U~D6>_8+AiFydgpfs8O&aaK3#|1%HA+ud|@a;SIbf6AzvEk(7`I&W-dXz0iS1= zb~g0Ip17vhR!Oe=Wu>y z740b?>=;$$^UoWWJyBsVL1Hmajb|K+N{PyEM|FL2$$CMZ~*bO?HsuVqPtG%MZD0&1X}G z)I7h2p!$%+70^?2(>zy#XuH-?l8`LXla*Jf@y$54se^0cC@SeLD`zQZ%KMSGEY{m8 z4f{xsALA_kf*fwj!ca`Li9d;#&N2uTn#z0-AFjnDrAC3sG>%d3-RX?YL={-x*8BGg zx0%|w;B%+_u;Y&j?yD!aYf{z=rC3L5fQKcA>*iFQbEeo^u!k8GRcB8~V&ddne%5?E zs(t=UFdf`_II_r@flXChQ!QtpX2GBTWnK-{8E?`duc(T$t> z!J)eyn~HJIYdL0m3jYkh9A#02^QKcLup8uqMoff#1mBA@{TV64LagwvuwMn( z5-pvn_q``C9M-@iLzRWNy5?^(^sTJ&%^#^N9EJ*(JUD&@b|NCZKE^s&Jn!ek8bW_} zoKbpxzQ|!1($$Zv)G%|O3a8vP*726SOy%Wg7%+CTLA!y0V9DA2ZrhLdMNVYBD6HF! zU6UMavktmnbYj?h;H0X$Pw7q95do=zOjt(E4zj+iIlVitF1@L26}lIwmwB#>1kmcS zF1rJ8LH|8x&Vdkyw!E8q2u^fmmC8NB<`S9|i-GUI>Iu8k3eb zJ1xLcy*yxvXYgE-k;4+51$)cwB+m7`{ATCB-IXv8ikJ-x1St`5WC1M_@i!F>&2`JQ z{iFzLQ+|Cf>&0Y#UGi9DTe||;DnQxHO}IGCr1AIIjy|+7lOo|B0kLI;glis_+f3a& z96O2s86=v_;?Uxy^^p_XH7eRh|BG&3uI-Y_4Ee{CZY}uNZk3tSTv1!C+s0F3+s2KO zpE$1K^bHLAoPU-JKuX3}T{4@S;gF0Ghns>Fbi)=kvkT69A8}*HD7H7A^u$ZXO9>H9NR;t&!*An!-C#~3` zrLRZdWLnV05t#^(Phbw#5xgDVm_F2h(Jf7I-o;Ptw)DF zu6i7<*t~7gbn}w`1VbEfjYeJ*vYgJV$(#G`ir(@^D%!Ky~#qJM)LVnW)omr%(UVmdwb7NDKEU;|!uXxqQFh*(4Eh zh*BkxpX^`qvQdz%t)x_X<0Q<2egu41@{XCv$DxjSBXZ6kx7`^8y=*YnE%f2ZF1zYecay_~!{54)Mqg`}`Dy*Mri z|KY%Z3(J>=j{>%F|4&!nZwc-T{Z68Yq+;|pQ{X?36EB*G+}2kln1P@Dh=FM4n|yt z=|BG$8UNocp&#DiqUUtn^LhR^yi_m0$zP$g9GJw!KT+7Ra=AQEZh8I8jowFE( zvBuB;g|gl)v3Izkr8TcgpZ?3U$&-EOOr1tGXY#*MhJ^JFmq^I-%=!4g;gzC#=Zq@& z+VQ_p#^HnQ5<0q7?hvH$-|+rF7x_P2GWh?w$p7Jz|IbDK|GCH>fBDJ%{XGlq(aoup z$mebB%WuNl6$>UK{|CP=2+9=i(ZCIzxv{#MRkFCfs)|Jc<-qX3TZJA8F8r9==-{Zu z#k{M&C;UHzBzTfwE=IRcH5l-JjgZjq7>4)vKebd0L}EoQ{aG6Hp|X&(vXUyRDWCeO z;%DH;z1ZveV)XmDHfR4OYgtGh`2G!F&uzR_lj z{)f1sEc$NkTtZ5Qb_8dL(8Gkg**-ISq#$|M7gEwSd?p8G@$o(g$DXJJzxNfQ*=CmS zHduo&@;YX)t^c(~_Q4|j?pp2Wd9B+)$LIN|w4Zxg_+10k%nM%T4MQb@VxBpp?vmIK ztlH$@5U}WCmndg`Rhi_gxJ$zC+4#Shdw(){ng9C(mHGV+=%?=+cwTmtq}o8Fi6EkM5D*cgqO_1i z6oS$_Tm%A$C|Kw{5=balx`2e57`n8C1QOD};QPk=@xb#hW1KP0E^D4O*PPkw$SHM3 z2^~E+UIP?7Eep|%oz#~@P%R1k=DiLc304_{ws#Lqn%0L6?4P!P9=;z;WkbB_VDk`J zLD`<39z66BjX+%52jCa9^@&Q!$d*`y3GBaB0>oCof3JJwLa@^3yBpUXHikU$L&ctG zJ?t7q!@JwTF*%KL<(cFBMn=e>5f_&vVtf9|5H;1)c z`GFA5TXd_k*kzC(=)S*y&zsvQ(%Wd#b?FB}Qc226rFM0u#kS+uutA^C>uZAD(Qn4ixN-%ek$dtheu zNx-DffNWPc4Fg1|b4Gp8Q?}{r#Xn9?Jw{A{5d(*wkl$?!y6lmy6$dO)=>%86BVW3E z@Dc+nJJai)d-7tw*-S^$jU?TZ@vd)xmFmMZr4vVO1Lfya$tiISQo?rI~}~3a-XnCqckhmy$Z;bXG>n z4Z`RFaG`lua=r$T{W7st)wk`vT(SSrFE43tpDq`tQNPQo;%LJdUGBxmZXuYXAbBE8aNW<(f(s$FBlh7R~Qr-;Tvn#TyAwOJ> z@%{2QZeR_Mqx*uKC9!MaK``$DrJT9Dfm?LkRN;CbsNbsR=ef#!l1p&b|_B=l0KC;v7;7#zgfv z3=CCSOZnC(>6<0f>EbU?r*lj*;hBA$UuMYxzMydeDNzEwF!W( z;EkWR(X@FuFc|k7<7}#$OpS(T(ync!z_qbr3BN&}Sf-y_c?fx~ER04XP6Ie2k1Xfk zc+q}rhhTWqC2shFBR_XuZIRpCmUYVymGOv1CY}#T&$9@a6rVql>)%;LI;!*7*d(P3 zMYRu;)8}*_XheQEa4(E)Yi-~KO9q0M9P!qib!#+Q$mmFo|F{I#Q$Kv85V!F)DAGD? zN(}!N$1JkRD$+F^U+7X|+>6Gnz`EfdKLV72u|GBK+Fyoq_ZppmejVup(wW0EX~JgRfVUPk(abqW>A9K`!^aG0L^?HiXyWb3OYTs>J=-wo}NbP{IM;2hoM z(aH~QpsA2B(o)<_&pW>ruh2apkQmGWZPUOdA!HCU1sty~{QbieAQ%@H7pNDazYWbV z}n;f_V-)fJVRKsW{Yb_n;pyUe>=9fPC{NzG zW9UZ{nDO~DlbH%wlAd<^j2lrPW=>&)a}b(PghWD>h$u1u0I=SHiLOg>iR^E=!gnw0 zz+Dt#7(>jR{@GWeGQvy8z-%ID^A}#jy)&ZBYrTO;;#YUnT+^LYCFe)U9zfE6mZFSU zgGQ03zt*81b%;ffHuJ=%XNhYG&_GUp#UY?3mN54)5$ z4J;+QAVjxf+kV^5nN1Lq3Yu&4h|J}IPBGXN?VzYgpB%{t4<1x>m4}i--9}FEw*2O{ z^M;U{<^DYpIviN18%onBcWc!w`~DtykFV%bt9%NqpOs)3_)p+z69G$KYGM592n|Od zj!mgWb9?<7?bQ#@1+D5f8r8f}%ne2fJS@+>GrVEMAseMDvg=ut@fEBO9DLphv{F4R z@gJ0nKA&;kzNVF$aa8e#+`~oIaM>ner22^3O(A`vF(ss2-)M)5tp7Hj9pzTbt8Y%? z_r;N>0IG{G+F{0-cD5r=$l=W zx^}9{mj9XQm`#L@!+yjaS63kBd*!3ahR0*RANHhMAgJF6=YXjo!Z(cgIGn?r$Irb5 z9-MBEt9%LwnMry)e8*>0*F{CnG+w|D^HmVasD;Wvh_3)6$uWak;g+g{JS^KHV??3g znM6xHuYaZWIzczo*EWi6^-s0O0|2tXQoho!u96DY|7Lx8n^`*Ic@{Pi&}R_H-{PWa z$7JC+-$D+v*aQwFQ#YU}&3bPxbkI`UH)c90LPXAFq(WtJ#S_4Apd|g^yF6O;!K*8M z6bPjTL47tjuAB{53CVew%%ifWo0_4*ZO6WHk-`yqT};x3lC`~?(F+y#zKimoXDUKT zir(2FOOCh)CU4hJFNabnkF@E49cWHNIJeAe*kx~BasKF-WP8t%qhncznAMEHxNOr3 zSK`rUUm^d1JgyLKv2^4%U z?Irzcf0f{n_>;*7FZ1BQpJR`IwgC^Qs;G*lf3rNK(pUU3$+XYU;C$v4df@Vi2zw!c z{h#4GlejYjF4P9hXp(b4K}*C1TtQHa1Bs478ZsG3P9w|2Dg9otl zr7bj8!{~$XnlsalBa)oy?PaPl;KVQUr*vVb=&v8Y@}Xs-qH1D0%z%I>l*1jHLp6_Z z9+ewXHVtnbTwc{SGihXAbKh$=B3Fys&1WQ$n^!616s~(b5i*~Xp75d!x^WxoWv+x> zFVY?z5Sy}hJA2W0jyw-#&r8Tf$`~^7LQw$Lf~mL4aljMR*HIPM8N-J4j5y!fU;USud^eKiQFH#9svri(QS-p`tKWxo#*y=pD)@itMz&h1+M z>NQ~4v`l4tnS)wa*zkSxMIrFQWp9-hEm)}c=t0=o0U(s&sMijv>kVwzR5NpuFyR@? z)M!og$gNvfWUq2vw{5wSF1m}V@!R~_mxroPl<`^`iFwcL?d;})Up|``Q7Pqq>cDCf zF9ekeBw+o!?kS~=LIo6BQ~tAdYMeg?1;GH+4Onh;e{84T~U!mObw*^qk3-YXS*<(HJ^yBE0*f=tRa9;ajox zQWr};ed-}k)0awXXKjL069ZQ#d`)8?oD^_;GCGWklv)nRD9X zy0fJ2Olmd8rE+OaXeA*PsPWFAy+$fddrfyus*AXtDeCM7Z;2&zQRlRt8Hiti`PB0l zvxQO7z>{KR^ZMP*HMM?kthH4&m9mxI1-}yhe5g&@I$=FoD&4lMT7#zXI#0m$`OjPK zi$ldGm+bDofhJk4U1%4(If0+yW~=At>y$K=iN-CS;SLI#3-=qccY)Qqv)-oUIA5p9 zrkd$3T6&muBiT#HyCP4dtso&%)eiQN<#LBBbP|+@HT~ssJe-nik*t+F(~k6@*kdZ? zW*?r_|6Y>1F%>X?VNGn?y(8gllP#tb8{8G1PF(DgK~-@-whXF7=!Cj!wm96hadF=p zN3-n}m_0O`6R53Bc|LL}vsFz=%HoEdO|oY4+4JYUWAJ*)r`pz{6yae5=}+?}O!{~s zG3NdwSyHbA8mG57A^wvqt3-8IG!=+Nj9QTm1f~eQYdQ}{`ZauK`ukP-;!a7&lLYxE zp^DN95rp0znaln?%9BWg=cWP_E*jZGyD3GrO!;x`rZqtZmpQ@tTGQW{#B3~K#?Qzs z>bSgf8GMtIa19vJ<*3qbMz3X(9dNXf)`~&Q;Cw2YnyHR-;;im~>C!y%>2CTCu8Lsv`6Wen+f5_V6H!VH=bUri|iW0hn%^OGvgK|3LKa zqZ!796?+8a{&_uF3xq3wUd|*PZ!2;0yYn@rK-3~ORvoeyBf%7t&vH<1t<{-LwS5{Z z9}ExUFmSb?>A*G-^+ug$wy|D&5iB2#>NcB3hqgt%=}HSFpPD@re{Mh3=I5+???k=5 zUC`s`y?X<-=wp~B??!2NM3$x}?B_lglQr~n(Z-`mBgL@R^UYG`#3Kc{Yj*BFM{l0J z-+C12Fq{IkaA)|Ydl<0R2Lr9wos{Pc3M8+0n=@kjs!+^?ZDkNQxLH{EtSdMf!((6B-GIjhsIAa-Rw?HcbR4ci45gW7oRk0CKZ%4}NwYp0=huT2TL? zv|o)^VDGDQpIr89nF{f52wwhzoeGHN{^4KJCG_|E=BUtjRvItd>VBbhw|6s-?(Na`Gq#@}JTcmW#k_%FI?rn&{Yo82ylCD#?k zkftk251qD;n1@5L3(+gutT3aX?OJS}I#-u8bIr?A7LzY;C_3d7@EwFgQ@JO!9{5!` z#(N~JAG5LdclHlvCUv0F_)8DU$(TNK43AHUyFQ<)Z2M3#8wer4l^~{Uo{+jY!p>0j z(0p13I*@zvA8;JCyl1HtvT^!!ysGCLGn7P?Qg09rT&s*zT2ewl*Ce6t*U24*lyVI) z{h4D$_>;&LA2ZiYl;lIMF*uowm2W#EkZANH@yVQS*QKBFg6`ApnC@R9aT1eY!A?1) zrFYbcLtlp;T`@3oi@q2^m(L04)eW2pWnSSuG%nz|fD)TYcYIVH24^k%@ILX;$P_M2 zQ;i&rKFlkl>_4%&*^0r=azjPFFUY-`c9>nhp7HTG6%}QBY_) z3!SKYQgp0?Ni&^FESou@AfVyd{>FXM6_=Uz~}Z=h-Xj@i9?0WJZJ$ z*Iem4=0WRVBn3?$bxqW6<ETP82y}GVp8tT$VAr+_U3u4hee_ zVLjrj4C%OL@&dOh5ZHCRw@2D1+uenm6UKP^oJ72-VV zJC-yDAZYHfp;ebkRRnpGAQ8h9dhGTv%#}-*@#X+})MNUtQ3_tu3q5Z(C`yO2=>L;ZQbYVkt#L$ENKd0pAq3Amya$j} zpXhbmla8bAG}3I2ENfOBs{709QfyXct*rubyN?&&%wxJG{yBtq=V?`1>BFK{K}yW! z+uD)eY-?cj^2?W=U-tSaV=sAdRZsZfvZ46Fl_Amiw9BY*EV1`7HgXPcwB9 z=bhd=-=FJ+1F5n??XR~KvfC~pZDX?Q!j%!t&4tt9hH|Z(9URF1eSZ`VV}qMXzVa!v z(vIgJ3HBKgn2|MT#+1yr=Ei+P> zd&eWilZk3N!?Jrjq_V_FYmtk3!vXraiR!EBJJ_$yJeRrq#rqkR?X&~I|Kk)afd1)Z zutnJa{OyLv|Lm3E!c*E|V|HkVD*sQE7#nRlq7a@@0De0e`yWS`mw(Ie{7irPV22mL zA6qpKu$5?yUMwg4`MJFyIT9Wo$>lJ6?ZLX$JFL;4t?w7w+H8O^#dL>A_%qh#<}C$~ z_fu2$-M>wFVc(W%Aasl&_kYZk|20V6+>Uc69Dd8y>dQ8C|NjTC^Kw`1_r64y$lu1Q zQMToHHkYV>7~W3b|L0NkUEEb$=lze(_qVZ1IdCU@&+TW1?b6?tp7mmj`R%#~jriMG z1yr~bK2x^XRs2^o8e+VKe4Yv;)ZfO+4HJ&!KKoSBkKPLZ@Nh4)t5#ROUUGZ+{{Txo B8zKMz literal 0 HcmV?d00001 diff --git a/rapid_task_solving/images/example_osl.png b/rapid_task_solving/images/example_osl.png new file mode 100644 index 0000000000000000000000000000000000000000..4ac08afe79126ecdbab8fddf65c63d61f0332dd3 GIT binary patch literal 80059 zcmcG$bzGFs_Bc+t)KUV{E#08x!h$HEl!%n%N~}oj64JXM-CdHB3dka%bT>*#NC+$; zCEdB-<$dq{-1y!5=WqA5uVk-KDLC5{o>#<;@n`vmT;cnGW}oq z87|-7`{3c=gxcW{{JqZ$?DOW6g8kjV{OyUKi~CQExp@ERO`w^J|Bo`R(G8fmPtYdz zK@3wjcE!?n`{suW%^Ahc7|+h`xsjWZmZqGQv!e*a+WDo8h^Hg$1_}q{DTggO+PFd3 zJslmKT;)6!IREM)hb`Y!i*mC6)y2(TfzwFqDZ7fZiw(P!h?s~Nry>bEJ3Gk5+E(t_ zL)E{7W8V}wU%9!#ypWg9?rsX4oHqmg-{)_1y4l(Ob0jC%zfTK0LD8EQQE?G5 z(f6%Z5)gr+Bw=dx!#NUtfgO7IX(ND?0zB>F@nwm4G~UME?o*4oBi&$oZ4!SV*TDo&rx z|04aR*kh*u?}Yz`uJs0Y6i&jx|5}$$%LJ+av19}RR0J=@7;x{V{yRbcT-mU;yN>^S z`wtqph&nAL@YF>D{~J4Qs0(~N{lA(BR4K#StO?>w0sqRwe_bI!V)g%OB8tM|?jvHa z1LFM;!m*6T{~t}1|Nk<2li+PvXlo#a;!nEQpK2yv6t*jj!b+3WL4-H6`Avlj`cg70 z9Xs`S3`?KTGlD!x;n4`K&u?3~lyXmV->OpZ-@Vi0^{M=A3%K5#dBym!Fzr))w{Y6N z`1U2SWTd$-6N0fLu)b~ejkfzOI2UTINN07k zr$4uXnjdaA9t^&@I-2~d20kBgtoklx&*V_i_JkWT?{`wOS>8}C>s+^z;+P*T_2MIJ zkA~o_5~UFz(JbO(Yipwvh@e9Vj7=XG2kEgW#K zxSx=Iml4~QOV!`WTJu2th#f6byYTiHO@XX!+IC0T>-a8G&-lz8WWO2LHxy=U_UT*@ z$@8D;uzo9Wu$-vlWN|*4>fQF`cIO{{!QV~|TNByG)x|Yg9&>e$pS$jCb%`S2Gq4V_ z83Pl?@zlYxo(_IKqNQ#oX9rtNzWcLcn%^_kPr|rj)}U~Uvwp+lQr*)}tJghP`B$+^ zDRX_gG-KFuK8`pJY)mv>)sp^IeYLD}Z9U^Uq~N2?ki3FEilep_fSj?Mn@su8TAZOa zA*SxA%P~pl37^7hwvUsukXJ#!PGP(&R_Me6api&@XH*T!mB#uDCFpcdVcF4e)f;=M zKJ#YAaVDSjXjrx(!9L*07sI5$eej{YxAUR*y=HIolTi40)4tx2sbJzbH{emBWHKfb zD-MiKiQ=G@zQZOd}jZBvoadi)37#A)MIL-d>i<)=L|pfVn)FnVCEO@@64` z@XOYhnQ{xfdWWe>#QItNX8BA@$kMfB&e=CMA`6yl3K0FQrN1#W%l+JQHk&`n2Poac z>f>^+dTHX+#$kCdJ8_!2frp{^M%dWW;l8m`nPoodeR^&uV7hcnwesH zd$#*uD^FCHGWbo;K0@8QEzDJXYh9k7-b^)3p5Qf3yV%R4u299p_s3;Wh0@iuqe+Vv zX)#Z;avooUB$vL!h|966^vjqzrl#_W{8idBCdH2X?!cYc?$pg25_s^s4CgQAPNprc z#x0H#hq|BUo811!GjF55H+u@)Nv}{`%MJ8)s+h)zUN4GH511cUvih%alUA{mIh4N? zGd3=mr~u>jDth(IndZG`#cRZFAlo0PE4^0s`B?FBDY#J@3g}VfDa%PU zdt=;y4q-ZHFp7vw=0BF3g7>h@U{A@|&~Hk_o#9B)=32v7L||OHp#*LXZwZ;5#)I?I zvx|wZ&Eug*%zZy4Eu~%>ckFVXv5OrMJ>5Q;M`DhDOSjam5=&_O?yIG>lb-i=6`e63 za-DBVaCF)FWs17=GNeO^O-R%(Fv5*Se#enFiyuT-e3?gFgk<~?etBy~q~Xu#j`+(K zs_Ozl1$2R3Au|^_uW&@zDa$_l_a5-|UK1ONC4!C(-I44?2H#+MfJ20L5QWp+N(w&G z1CM_zXN9%m4vJCRW zW-P8(EWU7O1#}2X0@MNg$2@0$px_yi#@+TCX7G__4G z;OAWan!hs3p|At3@ax#MoR_vtXCGv)9tPfJ(rr2VASzZ26tjvbD9ft)jJ}?N!?dq*u{z`5*`5i+Pu;+(~gY zOi?r0`1}6VX4id=gOsn0g_980liHq|Smfo>7IuCvU_Qv7F6oqErf)yL67eo1u7WXV zSDlo{BSw%C=>|h$ zP61!V^Df&(FCSQt+eFta* z0$)`5-P`{0C~BZ_{~JT%)!fdNIgm$zG4X^turFww0k;4~>GbwhchKhoTPQ8`75B>O zn4Zp-VIY0jpQ<6ntA*}`U=@=4t-Cqx{sep?NqBcH`N1|z3`3r(7e`ngXFr_HFG$OR z)V_}c{@CIgy3vYBNdh>Kr#ins(a3Im$oDq(xC)@~&1IXV7h?819J#`*DxTCeSXRWm*PgP8Pz-5DZB`!Jj3qO@7b87wmHW| z$ntw8G#fn4j(lFzxI-IWjE7t5sCAct5DiR}m+31=Am%Eap;vY|Jl}FgHuNlGx1E*G zH!Id3u3a~041qUYAd%h5#8bD#Lyc9`c?#xARHXPu+9sdrle`_A6$)dZmTbDb?SxM{ zU^FDN6lLzuf18GiBnWDhJ8_Ay% zgV?gSFlqi{N;n9EP`>ns1QTYnZvvSw7nx4>h;Oy!2D33~O(s0|eek(eSf;Zqsx2xY zuJFjST~JP@zcnaevh|_qBIVbW)61t1zx|d-gPlgP;d9za+>!j5?e#moi4D8_Uu?4q zZCqz+v6sz+5T5n)LsAv|4jvW3Dc+f5B|+Br_C~kt`Q`CDJ(>@)u4(DfdnBx%MKi{^ z7ju6s5SOaK?gSF&mn45Vf|lkQ^*UFtk4;8pl2 zUArKW$>}8My7Fjlm&fX}J`8>%WfTcueUmNMN9Kg@p+L;BzZ=*7Vuh4AQ=xjQmHIN7>MXw6}`TbAxx|Z_>AyHHn4gO_YSO8&Ws+(lB ztu7xu4M89=Nq3Lp@n?j;)$Yrha6?C+=5CxDPzva~OJo(fa*-5+HsoEVM)noMj_-)o zz8r4z0oG>vX#rpP{;uk;huL$=3yui{6Z zX|8h8$`K_9V3*RRDL)M*w@Ud6H;FG3^RGQ81dhj#$9Ml#NhllBUeS)AZ}(o{e5^LmHVljmx1nJ`sdD8+p#(waelSO zN|Xc$`V+BeG=?PgMs*}5#of2JgrwKuhAER6f6ejdUR$2c&7;A~gh+LfNU}6tTrN=r z#z#MHd%HRkmB7YfX^Bl6!s$c%;&Djv1+m zNHD+LCmUVOv4dq@{}RF!JF7?v+)C${-FNS^bbE$zVSrV~$mJdxQt9Ll{TP>{mO*#$ zYTpBPO|bFNxW_PnAPEq78_(fRcco!gqp?wY*$@L* zy-5-R9o5m<#=d!9Y4#b4uzodJu~bzkr$_me81-|LXlKez*Ar1zebE2DIz&5dk!iN$ z^er!0g3Q@r4w??ri8W2crid2&(E;`_oWTl-i7$janY;8z864^1{MQcppj+W2L(c1x z*A!QpQZIj`UX+o&tB{+FPkp2};itM&(LuhMOt!N)5k>s|Ze|ppt+6lSr=~=kE$IDS zqx;XUL&Si$rZFS-_}TBsSp5{EcriDIh7r+I9Jj+H-^u-!Rp{+)^#qTSLU zpS{viYClgpD>v=^NBX9HXBS!wpC9$H^jXO&1`s@#i<_da9oHY&jx%DmG6Q% z#+&hES_KlG{9x}f=QBrtW=tKa;Ykb046~xQ46$Rp7j8JWn=fin3$>sgIAVZAWORo* zJtM(ROoU48V+YY|ypX~6bW7XARzhIo^6YEL*@tZHDsC)GoF}i#?h3`c2ag0sHNip+ ziTyq%HB(D_-FUe;+>4fYJ*C8d80$&MoDf%Ng$sp3Z*fOO(OnWPT|shoswcw>KMEDu zMxTlMbNmEw(~!f*b5&<)C{7C38O7<>=+iEzQCBZ?h*~fgE%QBgl$|;^f`*g=m!&0{ zS>-xTV=|>$4*&$roeg>%3u+3fbTKuPv?;xI#v#F5f%JkGTO|QvK!&7{Y0K&RDkN6n zl8R_Pa?iD<9Y1HxnRcetS1*fBu>%w4pp1gY0ed+e%C!SHwsz8qf2sW|YjMi|Ea1H|VLX%qj z9J+TgE2i*mPRO7dw$FGosx{|R?>;un9vempckX=R>#-G|AZ#A%&pBBVY78u8mLXT8 z(Frg0B-CN{UVFc#u;ZoEBDYn8@^hrN0P=uw4j#@LmFgw1o$vpCXze2>X(!duj7LKr z71;Db6QxF+e1Q zI)$b|;Y**Vs~#GGR%qD)Oq2I16MAkGvF<>evoGFJRt}ahmVQo!f@57PnZ2k%<*)EN za}u7j0ffB=!y|gT$@+F~9D=5MB-)UNU53Y6@glqm*IOR#6xqAL9CbEnhFx{NOgBca zkDw(`h^g;VGBX;^K)w%@4i`Nisz})f6A3vG^ope4PiZqOA2`$LX#)fEth37B`oPGV zTNS!B%0J00?UMiMX&OymyC*~T6u zZ!glqVaKya{#+-h+yIZ5`NmiHAD&@l%Xp4BLl>U12LqgP@3!9CxSd8FjeYCK`Z${B zIQyTyzBAPwfkIww&LlY57Mga6%V|jbxT-gVsCA58{pCQCA&Q%!?c3<+b63a_j!DlR zazjyKT`y*D#HfFAVyhF6w2c7sF41SlAvTU$PEy9DUwSpRq>U1&n zb}=@rn)=D`ML*TlIl>W;c=l>M+?9j#siPlrrG;G^gn3=~VWZl?^E853U2B)7G)j7? zJgRzuuI?NmNW!y}T=Y>YEsAcV;Ld+zA3KpI2_E&?lfmH<10_^RmoBs7XP@=L)Z?76 z1KB4H97CxGM_~tp@+b^mXSUDw6jIf>dGsCsTwSX}k*DD*>z^Un)y(VCJ44f z^MLaD^c{)kOfQn(=y|=!s-RUHBxMUISslqNf1R;+D-CN@d_jJnFK+c?v2A9~`rS?{ z>Og1araO;h^WAN`N-g;<*z`DT(y+nz8@@=h!0BNz*w10AT&lXnu}FgNl8ej)UKW1v zkY1C-GzWQs^^#2?BW@e%+g)FDCaSu&;u0RB-aV%uvUQd7i5`UJg}{?B$>I5zgKx@Z zd19&Ym4i({nyM3KM{;r;JcEuBgT4StqaJNaBLx#h-(?tLL^IuI}5VP^LLX z%@T^VQ#bZUL*^n!#BJ%L!v?@FlwV>V$kI1w?*{Q6=G@8P*WA|P!5dfgd$3=-6k}pO zug|#Ez3HG)PnLM{%VQp~Ue$=Ym|yq17@B$ckZH_5F=M7p%eEF_{&9kE{ft-8u^~^4XsS zhZv0H5U6UEKgZyhK@Sw{>RT!1v`_Wn%f~EwPgB36ad2+vD_+HqrN~3#0g|e`(L|)F z3OnwZw8`$S7cl2F7 zWZw7XulrmKG5ParDBDqtK#NTO@_U_}{Z+|fprc|kNU!!xNMkXZji98wCiX@aDS&~S z6U%K_hxcTx#+zAq53Oz9qD!|VPhfT387w*F)bGP$2irAFezz6>LM6v&Xq9U#m=T(6 z+yd`rfL0?A!*{On}U!H4fDYLHix$WJB%L!rzr>OOrM=A+b25k9pcOK0yx zS*Nr2qM(Ne=sx&I(+Y9SSoG$ten@0}-DqzG20A%$lL#u6f7FdHAY#k6v z2pAH@?2DfGcC5AQPYHa_1^Rw+dAxcKzkJ8s1(V* zdYE>juk7%Wvkh84t3qUd3*Oy@3f6=U)R86f~9*PBc88-OJ(03d?VrS@QLx z1Fz3fn0o2K7s6L>=^ArU$WnyukNX2g`+Kp-M!$6piw*a8t61%3t8U28>0*TKYGe*3 zVe<1#WU{0atP7yNWu#4c zbFM!stUuP-09Oe)1Mab1t*KoctH7Wx?cQ071&P`ledrq=UZh`8U;_?Yr50DOf3#Ip zz7B*(k_I##PnWV)G-92=+feTs2FGs;GyaC5`jvUh*<>?~nthkl3bor;o^XeL|HLswe*zqf!EdCK<}kcDRG zQin^5t37!}l+Q}>zAk;Z$7vLBfUq2ne5B1RxLFdFfBpSB%?rAKwI3|zt|u|)So)}a z_pVP`TCTdCFIDG}`iNi@Ch9X_2q4*VY~M0*jE!cP>$XY9`phn!2OM;eWvd+|ZU#r= z7<>QPG@o>hHkJ<2-g=(%a`*`f{Ujx1t7*CY0DT3S8T;~t^>ity6ft$MzhpXeh88G~ z=rH~{Y9CT+e2S_)6*N36tKGQXHJxooRg`O>7^3w+m$LE#&RzVMq(l1C?x=EwQdl-T zf3MyUz2TKyI?DLFL5v~PoC^$SV>@jn^v?F&@tL3bN(-QwRpCWkALI+*IfwR^(ult&Y*;+pk6Oz z0#>FgkpmgaDzh|PrL0C|R6jMfnNxP>urcqZmp0)^hFauuYq^Lwt#h~Lu@;(lFcWLP zkI{+9rz+K*Ea}j6L7QWXrQmpd=@g9lP4KB5rD@J~J9folzYZInH0N75#@^PSbIe|% zt4&D}DG37&HinD3cdG8ly*Ya(3S(jPuoYO{H78!!oh!Kvtt#A0QDE|aNxN5s@Tl%sy2Nql+HcjR?!~;i2vEh$w>AqAKl1QDtQkduS7njd=nQb-A zChY|)LW8T9Y3Y`{Ryx&WDA@Q#i%kVxN(EaML#fP@%g1AN{5JG1B-Z5{?{tPq%qU9j z+dqm3AZM;o5Cqbjz$)bQRE@fY$q5c0EkDW_#*A{QW*wVd$+ftV`tjpOO4&YEV3_EM z7e_rGj_(z2IrpsO6y8cuv=n8l1TO>DVYM3a9Tr_=(#z?-Vh8XlVW38oqeJ48kaA?^dM4!0k(U-SsjgO3L zVMN8VxI}66CSERxdXt^StL>NVNDYp3_sN~*QG1!kI7ASBMJt@AUc+Su1zS3*j)j3& zj}!}UQOl9AL~H9{a_9NKHjdlJyYbByS@4+bFTFz|)|~sd-Z?NhFWAZj7(UBPj?G+T zKT`TiLxBiUDDOz@zbi6ZXu9t=VGKQDCJ@|fg@yH1WF8@(A8%gwXTyVCIv$%x=fLi9 z5k@Iu&W5A~XLi>$?a;i5N`J+2W1Z0=^E`QwPtMpRAshQ8s?&^50Urq3&uJ<>PAD_L~UBUr0&$}Qs>!!>^PkOR(U{B(he zWW|CoAEGPJ`AM1jj&^o1zgoCBWqi@@;z?$Om@Q@VNLxe2C4E6J@5^D!__%-&m+1x^ z2AVp)(8{K@N1_``dVn1!$#0ZT9_%8ohcN5O0`MoPxeco|)Ndoo;^+cGitj;Mb5=5F z6pWs(F|fy#C<8B(GB|n{h7*MFLW3x|B@i-QZ0|XUSF#_|cf~$0C}xFj5KvNZKG0S# zG2qm#4GIK<_wRR_&i0qgVzm?Stf213<2kA#Uk+s!yu!OqaE{#TI=R>s?Td->tuB?h zYT_q#fu_dgVRXGAhV^c{;Il;VlptMgFZ<~SCn?HM^*~S(r?zhH?m_-RA@KRCHY+^g z&6@Fdl8w%4q!&DHCxL+_v9>du0#rQ97f5^0T8*_InXI4P$olf>LxW{ydP)7H*EYoj zy==xKA2+}4MeZ+>UDM;t_C)`cQZtZ_o9hNS^+29@T#dqPT6N|0vMPM#A`KPMj}c@~ z_#ChSU*!CJ&(*$vLvWi-f9T0dmP6z!npq_I+M$tvt9El?#q6C5ucPMz+%$rTU+o+a zL09;UaWH=%2cB84uZB6nT+b^CIsH(C(Pi>gCRNmSh5CP4N$W$?pn?)0m9cEq;*Ob` zqYOap-TuEKGYLKH{=kjWJ&N>?%HfDKT_q+Fr`$*y!nMf)yUH*N31>?sZ8(r&OsX&I z%>c0mI5%e)xr+NMsP6EPLm4>gi zjKtzDXNRs#Ewud_CB!*l1}5}wVXM-}1L>I^Yy|fRtLydTh#grl^s#Y~HAQRRf?k1s zb-A9$H72d)W_sKe}6LvlFTdcaIN!?^pl{KWSZ6zPvmqYL$0t@$5w@K9j$DQX5*fBy^6Go?ziM zOmVJrUC@E^x{-@Hn?rCKWTxD#npIJ|coPG~6uqrI>RDz3C%Bt-P_QUvNAtrO%yU)f znT+qhzHCK*n!oxrzjfh&>WDX~C!6CA72_&abFa^ahxkT>mZ1$%o*m&ii=H7CjX3f~ z`_2w(gtKjsmGcxaaDl02A48zjHQNQ*X@@i=N6Y;7KeMU ziK1Ae2o@74CR5C?*%TNzpF!8`AA%E@V`t37+)=s5QS_r1>fx0;wDFiTFa|1LrsvPM z5TqX}!b)b89>QX#tkS#EV+#(7wy;rUP0GH>u-Ks`yXF(!S6iRJeZ5TZ)YD7#cE#H) z%p=Pr1cG0qLSmSj@b)84%=Hdt4sm3NtP%Y&EZ2MLe`?7G;yf5+dVj>)9B>Dy_)}Os zlv;30vph&&;lSi*0U?W5bWr?UY?Qs4b7Ox=Q&uUD26tsSNPn9<3y1o(wSQPdCRQIq zf3p>AaC`D0*;v6YmWbPFk0ZC7KgODbZLQ$fk8ZbacdcyxZ^t0Wy%F5T9lsrWtn!W- z9eO#|@b%LEVo)YZ^Kw!2`sENe4S)pKi+3Pf_Yh8rk8ks+*lI~K@J<9SA};3m6F!vv zP6}SE1yFn0kZFQyiB7_v@-B?{h-_&kU~kI=mpHdGs) zv$$NV(i{eRN0y;9E+rEMn{_>Q6Akxb_sO*52tCmc@S{lt%`tF3BtZN$NDpbWz z0J$=zxA&AR!Lpsdv`vgP@3farTJvj5d`|d79Nt%7mmgl0c(_MV_u^a+0}kD%nYm3G zCcdoQ$2zo;V+Zk4A6epwuUUX!FJFPp_u6%=pPF!TSueTfT$qY}?cx~yYyVWYk!MkeFKO5J-Fkj-Fyo?bSIkWz}RTPg7A;Y=0D8 za2FErP$t&tu4qMkOp|mkrFlDpb;_dC(+8>iD)=nz3ZM@;p}5yYJ;ULeJn^!2;(y|h z_^fGgu`Uk)>Hec@2&rMWb?r@B-S*YuubU*AcjF(7B?J{ql~UeNJM^1Ln5}=rml=s& zrt*lA3&|L@x7646^?c6QZ>GBR#6(8M&b~kFBsuh;W%?F*B-T)P)pO26o#AF+n)hqu zRxe52Qt@T{{Uy@CS4&(=$@%rlEeJE2bCimwg!Xog_(ee)`w01wd+l2%f6+DIX+DqQ zn&IMos!vki=v^Xzp+TO%K#<9abkzpLm84um~ zb}7%EuveZln!K_U6ouzWq9#6@fZu3XpJ{Ic5qrxOn7_Y7rp>TITfVFTI(%>{gN@r_ zeA$W**FE|Vv9T!VSqVf!{deZ4dR>E~w-|?8?0#o?ydOI}S}uNF8T&=o3%D}mjb|Vr zNzfo_)**WE$H0+O7*)2eV&kG03#?$19|dOv2{_sjtFH-q+l8%PN-Q;Kj;k4zSW!(zo9LV&I05~j!r3K`PY-DYc+JxIry@CSrwBh{EbUzaW zB}cWFbI22YJSED|HtL`fwF(a)M(5tmKo5JGk?Nhr=|mF}ukA805b+0$ruXS3H}#is zJbP!fc2BO;x~S7%-F#U?ju}@iXy195n+$LD9qbuTkod*pKK8JFb|(mNgKw)!U6OKk zr=1cGb%D_n7Uj3`fz-v7LhC^4u#a-iX{~T6!v`9@z}}Fzij1me{ZKM()tmuo-|^zc zRO)i|hjk^wW3G1fLpiWy-l4)3)~c_4-U5e&`g>_=NNX>W`XkkF##EE0I;0Ko{tFTa*_PR zEVYe_Zjltki=^=Z3Hx~!#l_3R6GSB35uHI0g9GV zyv&5ol01Sg`|8(lazn_~r0p1PsPmY;Y;CE^R+sojgfj`@P!}&^>rq;S&pn57UOHUI zhZ!|Fox{3mO(Jk3>G$oiU;JH7^`6xwa95%kr2h!k$Zo~O2Q6-t3F|iZ(JBNnFVa_P z9v7&XYf>^S58DTECGsStxMgx+QwcQ|?oN7=?huPcxz!kNP2JB5%(iCdvdGtN?vuK1 zS6jS?4zy4Dl;nPL1-;wmINFtUI@wCLddDNdFH!iAF_1x(dp!}-SRi!T&!OW92t_k^ zGEl~WZ9=IK3#J;RfKh$LYuH_Ukw#mWF521=&HFK~6t?-=vU}egyvZ7unFjmyw0-xv zoz9ruM_nvDnR&8Y#QaG$q5XPi16T`TS1!tGZ4Ean*5IR^{d6om;7yr=>*WH$IlwC9 zc)(6&I{a8|qBm-fEGt9^&3_VIx+cV&EnfQ@hF^DEl$o+ zoa{%@qeAwE;`VzRuQSWfJbqvD$}n+I(TF5Eq642qfKJr+!HI>^XTYJ1#zbgwaZ*k_k`DtKcP_YkHNWsr$TI z4|zqj75POf0>eQEONx2oY4H;gnG)>Z$Ki|YYmba*Q+;cBuXQ;W)XaaWxx~Exac4F# zUM||im*|6G6?atP=_4O>wOPMXjgO0ZcO;!#=4Bjy#M>`lKG_rU4z>C|bK9x84SL1F*doOvI&X+0}Umqd}7IZ0;DeE1I$nG;7pZstem+LBMbp=YFy5EX} zBCoT^hU79!1~YZXV-iFC-$0hMRCQEwwv4Y5rCo}DG5W``^h?;ZFk>xu$0A#~&2*Vw zDZh4mUd#^Wszp)TueuGdFf&T7Q+Mxm$LNH`(k(mt0U1^EYCUa}BsM6d(@K?U(sL^~ zDpB9@9VpQ!^S;AOvLSSfSxu(1hwWFu`yZAA5uZ1ly9RPq=#tb<7@T**4;7^Omzpu2 z6pH&#SDy1QTLqKh6Pu~OYnSw25VuL8gTmn(q(d^WPnRi6gegE3K^N;f8J|bj28Wu> zyRC36$&@aVoQU;a%S09}2<7ai#Yk31BWK?EPt>{I!VLBNdl8cGFmAdE83$<2D z4%cvbUak`Yg%;uTIt(p)f~YFlEei%1joNXr&8xQi28~JiTzy>lz`$EEJo8apVr^|v zO8wl8zJfY?c)n{-x}VOy6!W@8kr}Aci;{3{CwFI@{k^azmTci(6jRW`Nr+nfbQ6n| zgtrcOC5d7s%D&!JyO@Z3-u2o$mF$nO_FYM0CkYo>rQ5xXb!{Or&SUG8XS(8p0=1c# z?T@)=zoB1~wft0!a%~)Oj8ga^p0^1|ig5Yt>Q*Qd87sc92_a1InI30)G zcaia@*U|wozH;|&-&ZdXF;B96*>N}Gvzzr`n2PS_|63uTk7g?sf`ig$C-?|&9EV%UH zfO&-S#gU2j+%M{sMP8I?AxBQwZ4OT10s)aT7h#=+J(?7kL(A^Ym^e5l`XWH2z5}?j54B9>TV->VruU+UMdF}&ps8DOfbTk>? zDZifx+%N2W9@k&$28*4N9b(Df6C_m(AZ8ZOX@6F>8+TYh_0^MdSal80I#f|mo3+vQgtDNYb`FKkp@0dye9>?9W7H5^=}x&+gj)1Z64 ztxHH})6VbtIrA_%e1y}16m^Cqdfp*DuAgJcAd=RsCZ+h@U ze(0+lm8|TX*DB-3f|(bW36k}}Qi0IJ@#TS9W%>LJjd_k3mwb%bK?gP~KsrhAWQ}bZ z|Ld!keTaUM{5!iRJmQ%eJp!EJ-EeWU#L$+sVOnxJ<;*>Nr zGBj{qH;byzoL6SDpHk*=2K-A!3gL-;6h|3IR%S)i!M1)|A^H+~7cypwHbGAJ19T-!66Y^Tay@zK zE)n$9s5_|sw_4z*#>h~mfz&2tBVJePD#wv<>~(HOsmkt%Q1$QPiVae= zx>JnNNbyCgh^;K-xNv|pKBe=m)LfFuiap=}>`+Xrk64o55vn*=qI4b)1<@8IOSiLU zKN~CPiam$K-JXnK#yyJpA{iAD5#VVg&99t7{fvAZLTuj^`KKBm{~n5>ExombE9Ir6 zBOvNg_TtI{uh638Kp}Y{H)1(VC~h^5?`E*&?clm*As+;dY+jf(6=JGf= zwE8<`hL;q+6MGk!nV-&jMB*6&vZ?^xxzw;)TXswXpDh51AlzDZXhsm|P(G^oI?)wn zs++a|F7vTldcW!>Nae^%ZV@Pi5HMdYah=W6ditRHmitS80(o~`h=60-M~@+Tr%qJ% z3jWEd=r%;Yq(W9l1TTNfvMAw|m36p7q6bOYn_Ptv@F8-$uKYA}o&I#SVo+u?+2|@| z<(HJ(Nq?-=@491se49)jvdRU0kY6#>Q=6%q_xvR&)T6iIGMtciZKkgrHwn~F*pz?y z94xOVE4Y(=rUSO5;lN)kGxQ6szn#vTDl>@CBi}axPG*0i=-WqC9f#%N z*XpP}FM=ovh`Q^f7Z4L7K9%Gzd6IMOrOc=Iud#x*j;tk!zzZK-JON+M0X}VL(yRVB z=xMMtOJYofG=$R`n?*GnE*T*ysIyC#(?QZ!q6aULKbs_cgbL_gbeZ;lvz+$bPOVFF zyl9Pey8g|OdTy4!4-GGJFsiv*vV3dxL(W*rVaK_YpP17}&l*=p?Z##7-^AyfZ*K^~ z)oL`*)f!vbgY0ZAyQD3{DQBK-OUSb%p*w`J@c<2S~I) zrlfB*oi;c)75$d?Q}>=}LmGiJ)-t>eOD6D4);qZ;L;K-DX{FmsAS+|qU{bPHk4l>; zSHu&6wnBBI$8E8m_wb@eYo2gxGDgN2@|KQJicDs6-O?~4^Zj;Pcl*9ASvZ?lQ6?oj z*~8WEjCScME=jO(y>~&+8uhbB0Qxql)!xAIpFU2brL>#?6Gq+#9IyJjja2$RyG7@9 zp)SbVqhWkZG#)E=tAn(7Dexm`mI6J`R+gcj&k0aGAmZHjkUW z{SGRo)UB{xWU^_`lIbdPvaJ~Kw?T$VbT7Pde87+e;ojM9Ij%u%5!QIB=j1uRO-!$1 z3>r*nnau$aR4e%Lv$P={BWD*GFnwM%I;t{0S#e+&b*jk|y@WSfFVn}@<%FW&8o zn6IP71M>uiG;6q`eE=A&h}x@YzT`T0lT`3)M`jJHEzKGkJRR0gW~5{v%=H41U3%r~cNV=dU@*?#{}%v<-t;FQZs$UcB|&O# z>uB)a#!Qq-Z=ds(Pj%_m;@!+g^CNGP1Ltz1cqCHYytUS8^6%dl z__5q26Z}Wgdl(wT8(6C`i#^z_j|UG$2I*X8^La;E*36}v3iaUa8Z^L`MC8$?P(Hsly=}52O5sUdl#2$~WS8dIXEFPx zl9h~!ZJ)ax3xff<9X0P`rru2-1hSogcLOL6<(X318PhZZxjJc>t*r_}%(C2nG6#2P z|6yGCu;3|zatAVbxVxx7Ob?~!#}_PXg!Lv!baDvWBtlvE)%d7*aEBN8wE(PdzC*{K z-{}vN*z^xQ(=F;xy%a1pd`m3)=LD=NJRB-Iv@<(ELfLE{6{RcSgO2T zn_s`I;9=X3a*M}p`x2hWAiU1yIt29prV4MC$pr@G(j@=di0{Qr^mo>5IUQNykv0wM}Z2Puj)5eOhPR1pwq3eu!Y7Xs32 zC?X)e_oh^(g&IoeNC{PX3B9+_2_4Rz=XuY!*0Pi)|$*cvuDrVGka!Vvq*Q- zj_%>9@)PKPdLhb~US?hKSko(+-bmOuUC0uStHjVnDM?P(T|E#$_s)7tv}RGz9(hlv zL*Mj=p8ai_La{OawHW0Ux@+LhRS^e)ohJ%)-jBE(e{i_C#IEOtPkG%&U>L8rM7tlH zDH#&#mmJUb_)J~%_N#QEPMBEus?MdkqQpe8rF$oMcJPp5Wp3WLRn7Q9lK*m^sB=@; z?0sU*UoD(U3bVYLUp)RF`kC59|gLiOnn7x*ISG zG-H|~164>5**A5|gNgW1mZ%3SoV2D|jQ!DrEHg}P1%zZmm$cs#v&m?jdPlrNb%~P- zh+7x9z2w-QDmFgXR+(LYwn&a|aDa@KcH^LuXs6FmUtGXb^?%G~VWJ=5hod^LaQO$R zW$)sXeZkSc0GJc+P4#xog1>BDv1s3d@fKI>+>s;^s$WYyfvbW z-6ThEX%v13q*odk1mY z)s)!eS`hQSsk`)@IS-i$ObFy3R0F3vmZBhPO>)9RC=e&jn zo;|v=Lt8J;A1v|2dI@mrr;9}}Z-1o!OCn$v&O9og6sQAD@WjWlo%m6pK2mx*$or1` zLrgbo?{itUANzX)w3dHg5jy3GN>K7w_!&KH3RM1{;Ph)ytB zC51GTUJjD7(vG`3Kq!GT6pYMHCb55ONc&C6*i469rYXddsS$R^dBXL)QA zec3Z$oBs241LsM|5dG?t8{oq8s5uhEQnBf&-MJh?9BW;{B5KY5gqX%OgSpN8VU;x` zlJALa`Pa~{M-k09kCXpCs1g}K;?MHkev@HOA}gCV#OibKnt?{QO7;3InrGngBJ)YB zVyq-T=pd(09rVS&a>ZC%GqP1@2{x5g@zgn+6TpKbkJqmazEL@WA-*a%04d@2f_0 z$3-0)p^#A!u}zMldz{Z{%SU$GKUYi((xdJ} zRjLr5wWf&AUS|K|ayN@q0v*H}qQ6~dL89!@e?mS&e(_2AN%m*(> zrao3_cz3bWbQmti3jyhGdjyiE5jN59tri4T7^*d34(Fj_>-dc zibjb;M!Zt9w82r&0k?_+OXMYzQ5-LGJ@c$BvzrwbMmV%~v~}(|uooT7EZll`t@#b@ z&>sb1(lNbJMko+}E_YjyZB&FkT(c?J^n|ZD`;Zsb?Gk&szw6v{;Zi=Q6IVH#X&ZbX zrme>bH7@@6CwypMuc6Iwq6R7dW2yjSEbA-rULP#R3>I!_9;?EE-PP*gf9wYiyVI5I z9rAdA_)6%^MKzU8-ZbEarQ_w@CH5@ERb2`OG6;XAeBr%OdcFYY+6 z|C2vPh+F9Ii&LyojCoP17Pr8KP4wCk=vBCFEStc0JAQEOg{R&OBYkg;WWCmPU=495 z$R~zdB725hlaOi?&ECx4-=C+TC{DBXX#uvXT@tkq5_fO_@O&PUz?Iu zV&y)Y?t@viajk%qMgH#C;};gkM#LPv)7_fbkSbmw#uunQRy5+Mm}w`M9!9VI5%I>8+ww^A z|IB|m(*0~;L-4q$e@6?zij;_RY7J~oxhtR10sVCcBEqqq2mE0cd7hjh!xgo}hE{!R zwcXlQ#D}{E;+bgg07r&nelPgnYt^(CQ^7CtoLw>8J?VkhK{tg@<#xjj5}Ov>%`W#O z$FBq3-D*y?SkaT0MVnW~jpp}sy%~1g&r$~tyO)gxB`an!2QFE61cI7vj`h9NMJlHM zHXInPdF<4ep9C6j-Y6Mw>7;9QkTh{_!4{!4AWphDn`_PEYmgylH_B!|A~17zE2yW_ zX21d=Id`Ufxg>97ww0p7zf2>N&au7TFt|1OwubLovtz*KEQM=k;fkWYDa(R{sxzdu zWL>8n@l=x2Lvvy6qMpw^sc*pMWDPtJ>~l%jt2xfU@iHE4Y)&KukW%VeBQwV1nv5kF zxJgJ(^z``8k~a);5m^?TOBQCWT4IOp-Zfsm7%%w6Ve)uwnCDvO>zC<{?pFD!bJWmq3Jm*}8zyZ&7vDq ^+A)S>M2C>Q0u5O|dC_Kt>D; zF6KALJ8O8tFbl(w-1G907HQIJJj(PI(?^_7+#g`e{4*u~uiPC99gx@SEWw`Q)9mAzGlX?ZRf_eU< zMxgp6H@t)B<6u}T`G0*9TfiEd=WYF(%;JnDD+4xdzmZ7Y7ZBqA)!~1C7?nH(u9=lQRQCE$seEnz!whgb{%9E!{pa5S zX$t^(L;v3*0D5@(TnXntzK-o^-?2`)FMpq__4It92kMLgHFhF@CD= ztPy|&g$!Jq=kntHU9zt@*DzEiLwi65)cX1(Y|wr35i&7;5Rvx9HJ8cK?*uRS0D z#Rqb++hm9`M?pNzOcsgaiGj58(jTPF{M9oqx`@5jd#Fi^^wAyQVl;fl1yyy00oL_3 z!x1#SFMdtEHOHo8Hfw*Hch~@EGm{pXqa^^NX_nm+OI8Ak7eoHiSkk=1Vr~F88G5Us z!}QR7En|E`kg8JGwl)_HWNf2<_>=4$)5S{jrWA2^$C63FmX%cWi`7(gWbNwrJ6tq{ z+(b6lv!^r?^#=t#MnuuG1iN-8#aF-JOa0(#>=r`YWuOy&e{9YuelcS0AMK5l67tm< z;~>Z8Sb z-nZ<}DUg|Qu7)@W;hciHPLGc5hLyR(j*U>Jg0{Us}{qI@d@_tbrPtg7QtoX6A zAA5?jGEe`oO;Jf+)WT2TV7N@ClM$*?7w(Zs!6r2Ow-LHZgYIx<9GyMoY?{t+YQ`z1 zLz;pr-k@HEPynilqqrCeTrkAI7*CB*i!Jh1Y?9tE%!Gbor97;{1yQ^7gCC)jyqo4@ z>)CynG2^|`XV))f-uV3KKE9u9L^Ha2W_I>u_5^N6zki0Qy#2t&6W!>Ea%u~onmL&r z?%#AGX~!b_ACCgy!n6Fqv06;AN|zBe4o#b`PkUKe7hYJCqWq;}Kc^TBeUT;;8-* zm1=Na?jj}huyRt4?!-)RqaGCju8O2TYsZW-H%&hPZwGnoltrpdFUm3Tcyec*=L&!6 zN#L@!NV7NBy=bYL!tvalahWwoc$3Lhw469~iq4qB5VlQF@H$e}RuRE+l&I%r?^)^@ zz=mCMGL=3j)Ps(WqR*{MAgYqf^O%WrmmRILg(OeJQVRit4cMki*|T-Wv$C&l{F-~6 zoYXV*OS?v4jOHbPPXy@`lB@nh>g%=Dbj@~Vn88dzPTt{#oA833Q4y+j5>Y%gv%P_3`U}kXR zWsj{s{P|#~9ky3SBDk(j_)}Ym`?=L<;)*R8VpirfZ``_0VT}viV2u|8^f2@Rec{4JwyXuwk!E z$aSjSx^R~{DmN%>wjacEw@j&1wuo68pT1n!t9+h#e(dL$fNI;TQ;j7KZkl(Xyf7TApn~B*x)eO8kqCPdLWjJZDzo~VSd1cBQUK*M05+g^N_XgT6(S!@WM7fYB1kBAJWvhgQkFanKj1npJ`-1 z!t*P8@=i=9f3jfMs$XJv+67~k@5=bd_R4(7b=utmah;HgQdmoyiI6-za+&NCf+see zYoL)c*EUz1@UoSccd@`lYSJG~Hgy|$=_Td%7gK}s)pTI-)mhts?S_+?Ioqbg2$vP- z@yaC?$*S}({W_2(G~5d2-Df@3kJz5x5^vWE#R-aAAb8lN<={;k8{8ygDYP*Q^BfE{ zn7;ps!N%us_M|F*8)1hs>C~UnpDBpF$K6z!?!4oUCb_;uLtqVRHdx>jhdqQ1*3ssT6*5H;GHBBR1-jBgR~uw*`d&-P}Bx5 z0xn)R;3;vaGw>PZYS;sy8z=fSX(!wO&g=XW!qbDkQ7uCN=6O8gqFG&WxsjhXw1eJ3 zdp6^sd1tr}J)L0$J180(D2D$ZtGf5*Va)%>83|69dv78DgW~26P3N5|CM(AQNI>|N{d1oyRCAia=P|;}+ zY3J!IH_5hL;p%z0#5gC2vZ*LCPmRfoi9UE*Y&q5+j$kp|H54S#R`S^ev%cFx0J?NdyjpxeSaq8Klnm9O z|3)TAh=)UH1LR|&CkMP;P@cODC1~lq6UQzR$d!l}$WZS#HQm?mIXAuWrtXl`DK0Pk1=97!L?bY+A5AB*aLm_V#}#Gf8F{xzdKUDdTk z+UzEks`5!2rI7y8(Kqa$1emIIebl>hw2Y)A~MPStp36dJ+v&xn8w2|-;RH;H~@A|dTwPtx7| zn~$FQYT4KhXn#U_{addDsZpA~&($&NzR>hvfZ^D9_U5U*$}Q~1NQ$M{mXb!P_i`2h z<~BH#)$Ox*{o{%R@%F-R&%Lfqx}Q5u0+?-QQX^0qNvcLg7%LWwhUa%C^@Eso-1Qs#>1`3rxL9j*q0;Ooa1vg-W&!$76uYc$5B&4!bL#ulj^fCO26G z(y^&Z#ct01eCmN)cSmB%3K{c&J2d!z87< zS-zXJ(fy}DT{<-jklqAYe9;@pv zQ!C}=Y!>tF*QzQmZ*5Vis~wSJvB11feXZH4+QgMU{h&9w^XxYhun7hwxv>QGFtDn> zTj2Oas_jpxEwW{l)>qhJUHWv)XoxqdNP4d``Q^&jmOmG+7|Jedk)_v?R(Je1q?_Lg z7AjFIm3ujRY>ESSw3lB0z<_?DmPecHpJrBmZ9aNp^we^<;bccE3D}WCk70iR^-yR~ ze~z^jvwWM@BG}*U`AO17y+AXxjX%Y1*hC4M@5SnoCOyLY91!fMqzRPB#g@4L#Mc5o zX`Y&4oi$$c^R0Fpvcsr6kkd#Hsm@)$tN8c6DM0V0I#6IERHDL7C@yKbKEKL?=@|3r zJ7*37a>I@y{>KZV*ZKcbiV@e%9Ka~=RKEn6Q5%ImED=vVcw?XO{wbm1MsH3lzlkL7G-Dk z4t6Wht|RE?)f*1{*noHUOuGC}phAhe!@EDK+<30N+m#!`8I1);3g9i-_J_g*KaHg)-VYKToOL2iugCXk^UMW z;7>|&f*Yz4vo5uG|IB^lk9<`a?cN=E0;N%^0RkU)mtWK28PjG{PdO8XA`anaL}jN#rDb@*ct1sj?yCs{-&xk47&_0ktI_)c$j`75})Rse0DKnwT}T zrbNAxsWgxD7X`%49EgwXy2e!({vYEAiYg1?f-UzF)M4wwAp`_YVyfAvV&@5jp5AVH zTY`AWV$BZJ)%LNiV#XDyux5-!SI;#HB~3j0C-9DC+7w}RxI1RMPmpn)wI66LlT^Mj z2x$?JUo8$yC>Cc!e6qFrEig-b#wV;i-*GDScGv(wZ9wbIFR_zI5U&Vi{L2i);g7B~ zjyk|@kbDT|DMZ7^Z?+TT#FbV_wgMhI0yQMZ#A_t5R)-Uw5k`L}$+ zKym>o}K*kAvB?l9R;Q-rC#(`22Q^5+x z+w=Ec1Op~sJD}Iqh1isxx-6<+Xed~cP>Ba?R)I0S!nc|{a`r$;zh%H%rNi+|(w6QK zrUQT{+3BbL*mcY0AMKci$vhZ6ASzjBTgnbj5xiNVSRhE`6Y@3_P(UUSa7kH75SW+- zzky-V`^A>d1cDL*D0uv<0-cIRz&IIR<^Pd3(jy!h!|zZXwC)HvZ&hFeCbc@12@JVWx(-@h?f7ebFCs$2|0KKnbdp!!wX=rbv3qgTp0K;b9_ zD6n!Tyv0<6e0aZ;DE85*&BvGY&KYWeiDGUxQp=pr*8al|-l2WtA#Ef|`eS%RF6_*s z-D$-AEij|-aAo|1kw0oSUST~^fnDp4^JC9@6YnLu+sQS&Lum0jzWeITVfXf#!N^rh z6FdopJANE8&6|&6rNUT`??~1CyV(WsZ_NU%U+b@dYn)zzs6zN zQ0{WR_G8o`WzT>)!#{BkbPf!?QEsnA0y_mUt2q>?PBn+jl{Co!Z=17pP>$}n<57q# z29!JtkBTAiJCg&}4H8_B|Igq*c?UZk{VceYzb-I5iXSHk$g0pWBw*Zuoe~&Sot69= zX(M7nP(%@6EgOu-v^d!*7gBg{Wj_1b9QV{{EYTL2_;G^1|0*Tm`f2fF>(7Poo2j%P z0Zgn?qU7D}q~jSZ{Mr@D|2^5RlFSF&w#}-?sfrff0;h=3^cNX(XnxU-`ZN6S&mqgHC; zl(Iyt#M&NfD==u+_kgL^9T0DGrlyi2i?yKvQU=Im9^m#i@IwagOZX@YFo55?oh~C= zOhq@{-8t?e09Vp@z6mU;BHmsDK^f3Ep?2dY*(z~`0W#RnrR1Rs_IyCfS7oU*hAXu( zPx;?6ao1lhs1*;6n`-6U7CNc^35tmH8ADmtVidW9TTE4=l}&;F5kF-25B3DCYBUl{ zcwBviLna*ogi+0ETXOJ)9g|fSv53ODhowTP{>Fl^8)YqBKvs2|2%0~q%WPD}ng1#m z^}}EHhn~^+BdPsV)ufHONkLoS&!A*JLbYk2)(+oh(XFCkSXxzbgO=mLei*NBqf`t~ zkKczs!WGi_M2sj_dV4HuM@ZDDnFqBR!@mhkJe3Ad2 zCK2FvC}N+0Xi9EqU}_UO^gCz`;|C^>4$olmXxoK=XDj?vOJWVShcHv+~9ISUwu+&we~XlH2@)N9$_}ZwG8Eg zuis}ezx?5LUgVrLa_Y>OpxG>TgMO$kUq&uE@>_W>Q%Bm>6xHjP)TmV+i)FCKPj zt%+A3W8wn`LfFO7<)xb`72$o|it|2ypfvD`gRHNCi7ed`VME7R~5}{(=?&`8gQ)*ffnqBG)mZ4I#v|BC_}zH$rd&5 z@~{OQHs}dqlnJNOxt~&}fLoBGk1}F0DIh9QAY68c1UELKxwv{}%GSU|-FmAMj?~<$ z&W=3u03Y`9UqJEWAj(@p*6Zk40s&ItnKtK}*4h6&git-;Ro}3L_8E&1 z-TY;)SCuy%Ws!yIg>pH~C{9;0y~8zK9k(b0xcucV?G#!^70u>f$15uq1KP(BzKzQ! z6k%p?tRg8f;JO78$C`!QX)L?@qNMDbw{k78J@P!eR*vrsX!fxtip-(IW2M3A2qIGv z{s~>7Gqeo4+{-|ZrS5#T(D%^0IOmmz1CX#hbj3-d^@qqB;9UpcS&_N19mGka!EX7D zH1zkE7<<)O|K^uZ-Yv1AsHdUhGw7IHs;UYe;f9&>c@4;b4BVEW-m=hk`#{R19&4x) zR7u0{F$dBvv7NtmO01vclvG){z9 zCEV5mWIP@GH!@og8?2(k+f(Kl$3`k7foDg0-^;t+4)pZA&6R9H$TCh$xwJ^9XvUQ( zzB}!R;+1EdX>^|6N2fXV+mg7*8o>BB!@!b`{!X9n;tq#5mzz~rMnbo<(cj5#8pY8p z7u8_fh6>0iw0*JP1&WwI>OnPLY>x}hd4)3M*xTQ~=J%~#-L#(iou4rj(Js`z5iU~vhY3RPF}4+v+ESE!t>?xUwG25bQwoW7)pSrbLQ z5cdw9eDe4QTHd5fssquQWZDs%lkpD$#c%5l`Z<90!%Cvzti;LWKotpeT?UrOS;L^`N zBaWxGQ2IhV)@z1t(%{V`6AzgXp7W%5nvSxp$^+m`!3ocNDSmQ?XXGI61<9ejN--sH z#CA9(FL@A+iB~D9*7os-D_Ahqo-T%>20IzXxtkUbGOUZFjDIC@WEiIv9{6;T1+X!P zS4v(kL{_T$Y^tLRY^!?wt7isn*NLm->0NEWyO;Osq{lR`!jv?tiixK75m66+VI)&J zCwC&xb-$0xlrU}0No=6d>heJ9d)by81ZH7_;+9&M+1>{OBt9M;H(9@VQpyDW2-ZE zIG2(XTvEUK7;lxx@AHWx#oVjl{F=a`IpAd(^p@ex$BXq&iD_=cSrMjJYj@InUBDC0 za$!3&Hn%>Az9fU$MmCzH!si+%JVROFGDDV_Oe07AhJcWsk4(M^G`HbGfA_PH zp0hJ-eA*3Ng-ep>T7-XmG}k<`!-wT zDI8$wk&LW?(Ptq(>_**~6&`49`rp1I+vEA{scThpsJwWQMb) zT&Xvh9sTgn;DdDp8#n&*=uWMwPo2K|&KU$Lh=Mn`&SgVJxL6Qy{^d02zAQS)6`r1{Zk`Iv0mEfF_6{ zq|K^qV#L?s)poGWQ*(svD@*4+?bzd_O8;AUm7cNBwDh4G2YGE5WH&(Z`MR{I*obO$ZWThUN&37-ocr-lNeAtYV7xmbAo-n;*3G-eEg(Mye?Hs2->PDSToKAR}Fq;0;QeZVvHE$G3wmXopOOCNhR$z_AI}wfu=M=;jO2HXUeQlQ|G?==+ zdk%45hy|Y%*zao8dP)w|uFX33i+g4|W6pO+F;g5I`g%2%do-uEO^|_Myhzk;$aXDo z#dLBg_n-r27{%ywbKxM8=mQZ2u5z9w1)Wka`X|H=cZot&JU;7{Ds>pf-)lS zh!e}V8Y_}I)2dxz!*hL|xxyxS>K8`$;d2_xI=YB+ue4`LAm+}o00(X^iiRd+~dUdTC z3pYF2;nLC3=@;&~!L~J=+?RyL@NUm+k%3(4YS+RitL*5MEBvFxjHynp6!J!5%4=({ zjuWG@r4|#7GlMbcuVGO89_uu3mM{iA$F@Z@rs&}{#|dvItbkhJcy5R$V3kL9m6 zj5EA4&GXQ;i#j@T+9*=V!ViCWRw+RVA%rs@A+F%!16FgBsFQj7#^KdsU3gb-^=x#L z$07$!N-7&YinZ{e<_c|}Nla!k%Wi)v7p(E1AdFgK8kr)Ie>d(hS5VFO+E!hEoPx4j zIv&EFUYFC=^`7}oGX51ZZKZ~`T#OKg|))v@NZ)a(50#o2m= ztLfar)3Xl4qgd}Hb^Q-_g&kJ9HEO*!%15-?kjUZAbcV@V^^l&-*E51>Zawg}c1}ap zTtCXr+Our&f{j++VPWsR`)Z4Pn%?j{r+?V_@flIVfj>-}YIkaAeKoD!G@)jnp9+@`lQ)kHFHaTD|E<*r zBYV2y9-7Ga$Zt|i=833$EO4fls}4VNyOnXZih!8WFQw^IT}pPkT5DextZaZF zfAxAiBE7`#-Z*A^O8KC|L!qTnzg(g1B01&u-mB~bOUTO;PgTE^*M5tOpL8x)5)Tdv zVO%8PuHTrbkh{GJ)N8@el4f_j6_Jz7hUZ zV=VlrK4RC1%(Gu6^{_3J;AWiA-An|oqce+Va5cMn_-y!5_GV}5#? zChc|Co5X0)&Fc>9s zwLAiI;N0mFO=@NrwJC@-fGQjJvn^2MHrD{v1NBz=(AW3cC{T{+(CUUOp;J_21%Y)l zc^Y^-8uDs|?u&UWZ7LOa22VaWLxM_e_;R2HzG8Kn;85v%Xsui0h_L6vw3&#G-cy!* zo}Mv??s9X9)=R1-F(I&@m9G-wUON%!&w^j~i$J>Qy^hjOjNZUY^QmfC<|9`$dkkHL zr()gWo0!#4ZQo^%KBiJ~(T^$Btx!(paXb(_>5B4df3}Iq{Y?HwyPIa1-10ii=Q`gf z=7_R*g`j$_m^hPLAw1;`GT=}xXEu~avm|QV;Ho(17eP$}>!a*g0veKX0-r;A7JnRi zbJKzOfet&A5_{3ZX1I=o+p`TcKhiflHLru%8wY6wE_ z4!b-ma#6A6xvk+o>50i%vMR6j=So{kK2^E6&pF{pase3^sXyJ6bP2QmI==rxmNq2? z#WA{Y_tl+fZar12qMBiYwNuQfl}zWzy@8j%y~MR>p7hUj`M42%;J>pze58ADd^s$6 z`QhAY{l&oSF_%S-O~{KMVtwS(?OXSj8E>u`munfXbJfkms4N?y7L1n)v9b`iE+M zcS{uf@#Uvh+O&EtA5@v((3!c?(>VZ(y`$bVVim~|^Yg*e5Mw;*SDwF65oIKc%Io&o z_d6jKCe-aD>mxg0{KzQNc7d?=#L5RDnhQ@o)j#BIOp0iPz#)lCA&JyV+mkZ^*KO)Q z0@Z!JNl1M1M7md%GoK91LYs~>Eo^#Vn$u#=m@qlpk~p7>e4m8QOb<0y)r`xzyg!3$ z935Z`{p@WLVqId+ubp41AN@(4uUgo{cW*AmckeJwW$KT5b3C!1UG9#Ot4*pji~+-A z1A|f&TBaZA?C(DifqvLWZ_5g#u=lgijtw(4m|P6J-M{89-#RiD$qvr^F?~k=f_R8PkwAO#OLIloU zZm=F4988PcycZyPR@8#$@UyhyYn#FGa|zJOdm16P@0B)V)4aFz9}HOfpWk1WoQ0|z z3@-;PQ~Qb-y)n_l*s#vG{khCj*Ucw!`Q)iQ-I+8>odAupdjFd6=PO~UH}j9@(6o)I z?^-0L?RTUTokGC^rqzYj0ihR-{&V1dz|CdzUP(OuogEfCW*$htN##GyX9r%oPeu; zw;*{W@orn7qRYvVb@T)Fy_uI^HiR9Iwz_lchr|?UdMND?MsZ4)&i|T z`f*QY56Z5ACDr+yuu0j4x4vwS6;*0+@4z^UGLU~;Id7q{<^J?73t{t`mf2;xHVW{7 z#1ZxQ&J04HlI!XfAAxgR*yV#ExT7Q6pO=3v8v;dFRzu`sqkD4RFX0d6T3ByLHX_ov zC1CYF8Xb|0yY-FptnC?{80i-a)VW#0f2z9TL!BsEV-$1onTq4uJV>gCBRxW75|_6F z9$Ec7py4DtU*vd)VwcJw&bui7yHWIdbxzFYTjkg<{H$N=ThCjCgSOa0DZ7ej7;Epa ze%B+AQ3~R7XFWxI=~$*+>STfvTPtPo&LA;?BClPUbbX(|@0igcxXJAVA(QO84L9)Z zX@o}kg>z#61|&y$Ew8)qI43}K1e|MKe*Ubzw4)$a;C*(eO6Mfv?Ax}UyIKi(bqpJW z&s&d^e)~aB;ppaQrp5_Z6pBCaT5x(nIAHE?z#7|KuHal(+sOhBVs3oo;pmB4y?lw) zbe0o`2C`Z)aH{tDQ-v{-#&zBVnR0ymxMtWEJe<%*}mQCckY_sbNl$&;h(b#%C zz$eDLo0KA+p18XWS#oB=`ypmj>TCS&jnN113VZG&>NS!6V;B8(Uhj&Qc?vG=4NllHs-P`;wFXpm7Dk{>hHJ3O|p`S_H9i(`#+_v zrs-d(KXnqPejRt{HDw^DZEGzj?C00}OE`ZOgn0|IhiJH46}u}m|0)d8c<%Wd*HPoX zj16_h;H<}y9Z!B_;mJvCL&|#X>vHw-8OJycRgD2&mWM}4`-nN~gY#^aV7x7nJmna5 z3)LP5{HFs3qk?BQ?nJLcaRd3OG;XBxL_8Pch?TZ}i^_&ti;|B2 zT?>A^n;6Sn=xUR=F`TdX(c`xg*+}}Uef&v$gIjdkkA%Kc87YwZE7|fA-~AN6p!}|p zx8B%LQ$caJ1X5El{5fW&83}RNuo18I+ApZwn$kODsfvx5`KS$uvjd9B z(_Gd)$>BHsLbjZ3^2CX=v?7c>4r=y|RkT06^WBwPHf_H1(@IaP9rqR*BDu9jIh>cj zZ90SDNG|AZREel!7K1TcErwMw)F?MHtX zi0Q=NpW_L&xw|HR2;1qrQ?ptc*j-dnL14{f>ltOvJ>#z*%2^Y(2l*8RH;6>;px~di zEjfjn5-i1E4Oj1?m;;_mQP@sA73cF5eOz^;`6K16ek-(P$wqm$Skp$QP;P;MQ8B#~Mzg2-Nrvk4eGAhM3 zomMN*>JKSt{bm;s%an9Qx_eG$j(M>v6fA!iwV)HSVox8QLXHJqtX!SU4Vp=U-;%_O zXN7~OBY&fw?Oi8AZnQLQr{}Zc5bI@u=iJ)yp}!NiT(+asKX>&Fr{Q+#4%*7=ygXVH zyOSmD2dKo$`y)i(xs6%6I*-m&3M5&opC_=T5o4a7nSag zSkFNicc|C4YcA9}?8`ipB_zMD|@a(A=n<%q60jQ&SO1UIb)?7fW* zEo?WeLfPr9ky@yn!$Kh;-N=eI^Q~J_w-jDWf3Rg#Y_u(*?jujdJRcx0c^;*Va4FEs zKDTO<+d*pMIE$ZWlt`U-4qOK~o|5_fYlHF0;=<+X=q22=xeKEF*y3@LRa6#)aIZz~ zXFeAG@*@nf!CI|Exn^vMN1Ox%nGe*CwI0p*(6*Jw%>x+)IX@4ESMr;#DahW?KIdLN?u(Ul-x*k`*kJ4q`C7&IjsRmSkXEX9 z+TB{EmX$Hrbk!soeH3<7BZm8G3V70-dDy^Pw#+v42%XNX?@V)eNV{;D+g~B^m<2#19$v4?jKcjC(8at=qcwRlJo-Tv~hyZLNsM zktLfOKmp^soP+!D>rglfzWY zr#~G5vSM(rr%8oek|le+O5c;I<=6WnIt?!A^(U894j0Raw3eJ#@mdE@*Jm`C@9d^H zG%2aqns3zMA~!Qu4+Iy|Qm)#)tR@B8SAKt*)UfP3JOkDXsqF7mFKT10MjqHx*2=Ds zXVDe$Fq43eF%!wdLvR_-4X|5O7(G~X%CZsU_NTN1)N#ZSq&2EU=vH1O3ddjXPDi6$ zqmX2~abKMa%~q$YSHf@Hw9bzhpH0dj-pJj!#7yVbs*L{TQBJY0`z=82#-mwGBJZ?9i zW+j=VvtBq$k+g`nZr}1f;p4$emP4%s^HkObS$)sCZAZCvR1E5v?rc8f9$LX4U=xS|&Y8@5 zxyY7RQ|(;be>V4NS9R{(>jR}!ZcIp*?Z)-$)W3=NU3U5nUMyoC7&i9zRM$Bm%PE|j zgGW#3ROQMQi1EC_+1gT)G4{ULeW?SzfjuhVXDM+a+R@NIKd1~{Ps1QP9Z>YHznCr4 zLLHbc#mTMb`%=CNh%pZG8*=MByVsb;EE}X%_q`$8^X>BdYDKwSEcZsON96WLA0-v~ zx@XGV@Wecp6yp8Ix*?UtnLnee+DxY>_CH54?BqS!-wCrFbc=N1S#C8?A%R)m7;uc> z)q6f*jCrJyr80_mntc2Gx1M9+H1NX1gV-OiD;`?N-pL;Mca}R}q4A9f9kUcN&yz&% zq@PzAg!U!U5|deX^9(ch{Y|0u<6eZ5cN*c2&4a4mkntx>+!Rb*8=(q3NG=%B zoqN4bvTC?FP70&Xu99P>nRGR|CZ2k_>}WD?>xg*I7p0aA*H|%&l*?Wl6=3N&?pPwnUiE==t_zUq=Hvk9&G z&Aq#2mH+0=uwbO}8pnupo8jn+7>V87@4_;tsLPB7y~S}d(E+DaRvx_yvx?<#h9?7; z`;3S^H)bse+O1dC{PQhjih_BV~N-! zeo}AWhM?rzNHypOR<@4FK}5cEM~BsO-n!n$hnb245&F|ZF%5OOg+HqrQ@0wW+bYL* zpT}Hex+ygRFM8KDjcKNpt+Cz^ak>BxW)bjtVr<57{ZjJX1wO+7G-Qe@X?dG^|&UtR0=6{Grs%)?}ytGr*byi6m ziRT=Hpq`^%ChDZ_{607(Q`HOQ{(WED|FYbE7=NiXFo|#=(Wd&&8SB zkTC`&!FLEwJv%~ij~q3J()e9V>aQ1jhIAZRjtApcMc2t6s}ax-{p!oz8b>N9$BD{j zqdGs|R7)7rm9CO(d`~GN$zAy+7hL?81=W-!Y1nw`75!Mm>9uI&U(>0xV$-<H?LM8yg%?ib}Gy6GI(=4X0OZM=6+33}8iA!1UX2y^^XV8STCnFo&eQ>7BN zLOQTG;EmyT-`8r23TkqXbb2}%8CbES36sj}FJpuicKkhK$218G3KVDMYE#WwK&q3No_nta1GIclSY(b5g0Tcklox76sCmXYDXsBPm(=B}ML=&H1-XTHijHL1Zs2r6W;cR-BUoeh_Z8pMm_9c&k>viZRh za=cvAK|ObN=Ph0vRFVHvqr`Sbudr;EtQ8Q+X^?icdv&7ok-oCG5D}8Tt<11_jEo}r zQcC|52Af<3E}_3Y(#RS0hWdY8h1=HPVcmO zm3eIftINK&cI?orrRQYKji^z`OxCDI?mAR`M2o@&c>7WB6!5_5(&Noej#Kk#X)6opnv-$`X*g#vxnFtzM2$-%jr!;C6z z;O``+#&Q;3-Ci4)u>~uBS++6FVF=mzkxYIKJ{RU;Ojwbc=-1Jow9B^7zmV{X4vG@t zjFX{**Io=!G0CV2ep2pcxJ}=n7UfQ|{%$dKH;G5C*TJHZu%q@IGM3G7C)T#XV+U8P zYDnLBU>*!QR%=cj1F*2kz|gw{EX1;Jkk9S&4QYiKX*c+XMut~B0Qv+L za!mgWkS#@jqLs5CHyv+o=76NIxwMa(*Z>uwSk|rsj8DA!17E3Jfsw-JJius6z4ETw z^4IO`C%&5G*ltXF<+1IK198N2RGx zi`3S7oqj6Ai!qti`yu}0Pk+uX=oXSymr_1P1+`odAfAU=GDS{F|GTrID5kI6LL_)z zE`{f#945aHGF`fi_YJC=fmt?Tm}umRp05L0>H$<$>!(tfmTIU-i`|h-i4vR{9YQ8U z^;GsyxVFJLpeS~bf;7f+cUl;)sB%v6Zp-ot@l zFc3%oi1M##=8dS&<{`InvuJj)!U)-PmCgI_lEoCz*b} z>dYTZe+R26<0DMkCp9P9#DBnb*_XIs$Cg+WjsT zt&&>xP6?0XR89LSwnaRcLPf|r$RCc*9}+el7ivt$A#uA4oc!m-?X|45q_2GzD>=8T zXE7?Wv$J{Vz*?V?&(IxE-ESBc8dhN;Fir`;av)4PE!Brv=|^ao(Mfz{t%gS=cG!+l z8g=-8kEfGf^YbU^-?E1$kVnZ-D>g!|4^TvbEZ8EmAZ$G)D$4=Of8(+0sAZuW=ge?^ z98FK5@Ky`8;wx`bum|2oEW>a~8!m$KEU~rITS)guoBiS8w-30GW~=m-I*ahQtbX5Q zWlTwC53<1T4xMw?k8J5|I&3LsHD7C5HVzYnF`>aQYSTJcXk5s4-ZYZAsQJ9dy^ip zW&#nPCnq-#?2~F|bY4=DLf~O==ckA3FwIa_0IGP38{jSFJmIE=j%BgT*6>RBt;{>D zHLQGkSIAf3_2YM;JTw^YJqBcym~34sZN8*JhZx6oFb@UMbEprDJNA*h8O-83jXwvp zvIzQ%(Yl;@`{MR}L_1sf-5<9qnv;?7qRP?nDk#OtaEW9qMIl>PDcs#RTm&3921`}W zqT|g$KKbnpc=-6kL{qw1hkt9LK25k|yP4Mye{eSu51YKNVYCo8{E1}OCFM;p_B zYx!vB#VF`xS+crv78XUS-7!;4M3{EIR6ECmDU5Trkj)Op?M)EFaz93#?=_ZrQu*s@ z_SZ3zDUBOIMVa10!v$lVlZi`)nPlIFHsp_Lq1r))i5X548il?@-@P~C#oSvzBSX9< zUP?jS_xw_vwRYs-|HU_YKhiZauV3Z=4dQrpJ1H^mX%%R zcVzO>XkgoO2(hWGAl3b%GtW=JaYV8iMJ7zWH=}+m`t{0A4 zwv7#p*DHD~I~kDC>*sW4o56m#vpl$9 zf|QUHy&uyvGt9`31<5motAGv4X0bT#X@ou98Qtkm_Vl?aGV}8C*MjUJb?RO?i7(57 zuTEBf;fzPgO;uNwt>Z<8va6hl5!pcR6AbwLOWH$(suxLeum#Lq&aeM>v2bPaAxlRJ zVQ0g|ett_mSGAwK0Fq)SNMeqO6_{|Mx%g+15+Cyba%WY#j}XX$#MHd~{stEp*LtHn zG;xXKTX?nwb~uhsYL-=;rGY7#_8Lyr_Qs1Iykq2$)U*&K9Qiyb}!L+g}#uC~cyT41QF7G5ElgT#`2AND&w-q;7PUTb$l=&xR-Bu1MV z%)6e1wq*8AW4{=^?6)UO2*^Y}L@6u;AxDzkFvHUCRek%U0Sbj-jiFePYm%RxhUZdN zY_}}3UMDN9OzTg64(ok+&vig!^^}c;=!XjFPBMUW+xbF zYMr0tcLlCbky1f5-(Ifq_wl5npQ}bT<0T_$BLjCzY&m~@=JW|lFbw zW|F|WnAlH|M#>Cyo z{Efmb=EujV5LK{ftQis4yUzPR5oBD1v)$$lgKTG#;gUUylBD8Q1#RbO7GPmHz5ygP zB}NEyZ}-ZNL2M#lsw6Q-@!gTl()yt{@k?H!p(sAWUgtLTXVt5|V3_CyXi>ewza7u} z;#E7GX)XQsns2RW*Kl2=m4)&`$clBy%w6(7o}K!;=-=ny0jKHD-*++Xrqtl?n|;aG zgl0mMonpqZE(2p}FdgIJ5#$q5<)xbCBhHh3HUo*Yo;~tis82QW5)Ke2 z7ZU1Mn;MzcYviKv(EHkp(rnQS5Scsry#83$%X6ZqcU0*p;W+)g;R@8Aq_sg~bdZe0 zA!ieKCH7X5z2JxU)NVIbFj1trR~qZ&$6ChH9FvIRy!)9Qz3< zKbhQqD3$r$M?S4aFdi|5T_L?-H(|f-mxE-d?kw20v&~vn*St<42PukOvV`(rn^}1$ z!-9z>x>BV&#n9=3Mw`#+2}0_wnBxv*4dUIYT)y4px^Cv-=}GbdS4y338w@E!ZnI&4 zw958{%B%14PZ|k)S&;6dwGQL89AULrf46mp_EH&+b6V0p$$1AaM?MV~w|9CVN#qhN z3fX)XMGU;-H3~2b@$qDOIVQEpboo;gN$JtVU2;b4_VVqs6V3dr!x0G~0xyNNuU!b8 z!b(!uq`@e_NofIY7l-M*D@uETw@5ypx^>FxU zg>lNM;8l>x4{N;>z07~%Zty_62ytK9lKyM|1x{JIUca-1HueqBaeaLe!|nM3avJ^# z4SSjmeEDZQjleFDjV7lhS^eAcV0~&(gXKPdcG<;vnuZiJqhex(M-3h+Q`wf}-=oH9 z)r5|<5w^rnzQN~BjK+wcnye|(@{jV2mV9f+x7Tw#xG*kv&V`E2E991w*63KM=JJ^< zVUDP&zHo9#oP-6M{MohDsG*^uk~CD2Mt*F%HkOCpqaQK9tnh8`I~he7Hi^BCDS!mn zI4HDmxE)8UcW{#D7y-v9%GgUXFbmSV`Nowj{W&nR^lNd-Y(;%Z;MBjM3wYT5Mex6c z=TCG`cS`(FPc+>qtybwgTwA5r_qV~%J}rHAt3lsh1n^^Ua`i`eq+n-i-useQm&m_B z6@Rtq)c&==@Or86=H9MgsP=h_+~ZfW1^*2KXIds93f6~n97pxuIc>k|DUV4=EfvYS z9AUqDq4yek4;Vq}{y{Bo`xoo2Xp-CsBnk|5pC24z@vU)W>%b*1Xls4RFr<2HiLTm- zNi>F8ZB5T4o&pbMzcx7R<7P`bHUan3Sh;6q+>J1gdIu7Rr+dEhUuQluyDT@SAVC1p z*0SOT&_WRt-+L0HFh({ZA)!t8sI@3A)TGm;#`3q+GGM6hEwuH*L(TFcSh9dJ@;-cj z#%6h>E8|yGLb6N=;?0}C_}X1&L1ON2y$f87chZtxeWFp2*jAE$(E5@mGpH^q_GdJE z3&;Z%xVh)=qsX5$EyVNy1)6L5@lig7^8x(QZZ$5JJ|@ykRY(w>YB2&?yadZO>cV=y zQL6A)KBG9yGpVdANgW4X>lUeevV8GAv-3U&Pe@g(C}dnq>y#-EY=z@_I%t|&;~7ZN zM{F7z+FF;KbMZP$5?4Vol#mhh5qR3Gl*Pn=(SMq~{|h?wtL{U1lfKo_qKw=ZbA4Wi z3i`_Qk<40UJE52F){M6Gr-nzy(vVI5if;WyUD&05)|#;Enxx;KY&(;z_D$Vl0^1GT zs+sOfY;$ajbB)s|r5ATQ^&<=<8Tqcx|Ju>6kU3Ur>&K?%Z>^tKu?g?x8#W_C755s>oLUQqZ)OpS`h=gfq@4&QB((AagvF+lh2D7&6%O1ga@vrVADhuCh z{@Uyg-f$bwy?N{P%}O@{0yN{)uXa-%O+fYq3Ie*@!Lc5oR!k(vg1S*4?`1xwqEa1( z8}ca0PuIw81KRC5ebx(kRr16wU7l2$>ZZjgowbX%2H`F@MGoi$xJ_@q+IfsE3oZ0; zjZlS%PAJlVZTava#{`&zK$cj!^D_;{%__v@jhQpuKl3TEH#A@kn~-EVe6CtOHP zsECUU0~obgCi@-G%o$kD$k*8UQ=TJwS9Tc$!`YW6CZ_aG~wkHm=A z#i~%g-N0gUYaB=x)6Y2z2&!w$zxuC;^X?z=!Mm}KkAIsA3Kp=J+N+((jd~0NWZVI* zQPXB1AE}?-nayfB`=oH1h1?fA7Lj>aBW`G+d{BYQ9Vgt^jMlh7@Tcxr&*6FUHKqJ$O`8oz zdhS!sP}GP|Fi=yxLA_7XId1j&DYIF+#)^M;gbu5QGF5-t{PJ$oOva7GO(K_pkacH% zlRnneKdFG75ROH#%tn%%2k5cK{ba;je#Avri9|H5Myv6a!^1)lXa}`KC6wo72tU!O zY@U2ep|TN%Yx}Q_Lse?r5Ud0*QI}eLy4=^K5Ejw#@6O96Z}$#Nwmq00aF(dJVwn~~ zL~;kPRp`ZchtR+3iXJ%SM7}3ItWbOg}(d)A5d!M|%9q^LfLX zcNn+p^BY3~klUK$on5?-&)g!XW@?wLDhM(os%J#7IN4>Zt8^ln1 zSZ-4oaT0@O9U2EdB3&Bl<)V@(*?WaXpDHde!g`S4-4W@FcFr^0F!Z6$n@T(1p=sui z0usofACZ{*cEyo)u0;kl!xE%=10-debP zGYHl^vh5;~XFwz(nxE4|b@#)WRy_z9=^>J*!{gK^rqqGXp;Z=2Q(P6P39 z9#ve%TGj>=k|P^v2jet)b!@o_`n*NyrWD~0vTNQjj{SE_$=5c$WXKx$s3tP>m}EOV z^~(jp zpLNJEUCBUvM!)Co(1^gRNj(n_W}|Ylv|j7eO}Pfy7r9NZ#4c9t^2vaV5hH2nIN^s# zq55et6XcADvdj3g#oJ|B7U0BxFrPziMJrO?xg)cQbegmFlFs5B#sW4{v zCFYCMM>-BZ-3H%%U3uC-=l8p?-t|)VY7gCo`+Kq8(fVutmLWgrRt|?A(8$On#*E!k zT0l0Zz$+{(lJLmyU4qBl6cuEn>fe418eR90IA-Q`M>jjl#6{;Amp1^gn~|mjz?t62JX_YUcBz7-v2f=3H$bV7W^1ItpWOhz&!I{BTz*vIzkyJZBOdwQ(Q8m8A^+*~0PTAZP*9pX z_=sLiiXSxo@{`Onl9BCqkncLQ)%L|9{UgZ`(Wg1FtF`u}ivr%x()%SvGgbbatl59P zE z!l|oD_#@*%0Xu7dJDJm1)VzkJ8nB5RphC6}VE!?rA0SVGLw;&&&Pv9GhV-3kYjBUt zv_2zw4cq*N7kIh6_Hm(UeT3DY#)CA4JN6x^{HK>~;n|?`oK&lh7$3CeBI$a+RszqV1YF7U#|MJJ+)Puh&|gt%2k zrFZ=OBUoN%?UepT$231nz0|?tO9{`nIo6Ko_;&YFXM+-4Esls{OzzN(eYcEOVc}X{ zaUX@cZrP78IE=0qP16b@Q+HNWbza<7i$(80x>%q3Rgq9uJx*{LwbH$?JscimHq zjg5zN%YD3;@eJ3Z>K#w*&96Up!Qt1OAZvxiK`HlS;J>-ivO>WXuLQ*?5> z=2Kv7N!>2yu~%0rtd_EQ@N#%;BqN_A+|W`}r*0E(o5FXe@@ACqi-|FW3YUH+3DDcNu3-5eSUA(!vr{r;2BqXH|>t34J}zRECne zP}(9M^WJ%dsgj3CH+erxkR94&bJ|<40zJ82~_~D zeg_~w+yN>YlC_G$ZE0GO#)lG=09qnl8qQqQE<}K#dnmw*eTx0*VjH}n!J|zsJBUye zX?lWkBq6_WM4UB7%*lT0a%Z^UoH8lyNTfCjuB?#}!ZozU=-*|*vk5;xf35)l71Nb) zm~^P^E&=AU@%tU&pI(g2NgYWI7j|qi;7LT91#^M{d<58tJ{x-yltexjB=!4&GE6WW zdPz36Q^)hwRKq=L9i~>=dIBSA??HTrQXI|mFCD8?+M+YBe?Am~UK@RzazAZbhv)9gIR-QjYg%eTzjg_4dR9AZozh9*PEHV7 zXMY6PHqC>ID2=c>S$E%x+Mybvv%CzidxWh0p1q;@QhxyoT4I1wa!7yIq28aaOc2O& zo}UZ=3dZlQNXH0mfEv_H2Plyqa1Lc#eMnH8mGBhoe@PpJ#OcH^bys1~$Xz(ao#am+ z*G%5gZAwq!XeT_;-e!d~^%jSvx8>xw$u|cu+kP64!fqY;Y(f2%gW|TiZ05LakQ(hz z#gYwi&SrXzB8hC1-nAiR|A~t?qOAzuAMMZc>6frdF`M6WEDTp;uft8Y>RE2mYxw{q z0lHaOWp;REiAt6X5(C=roRKWQ)=z_cz670|>k>s4WK=Ma7tBAx zNt|(a@fY9@1(LFCeLW9=Y#0&N4RR7rV3TC!JwT!|?6Ua%A3DlkH?g6d+TIL$8K31x zvBIt3r@|RWfkS={vu$3K>SvO}z(3l5o{Eo5+B9<(!RV;{@jrnLFi<`{DHSBj69$x| z=%XCZ7-0O4>g4+&Kf}%n$JY3wQ8MSU>yUkA-6Kmg*zZS>l=uC!KW+q4$hM%HNj|Q# zE9@B-hzPqK5&wGh7Mi~%~h@Z{P>4Z$m&_Rh-{NzVInuIwPm1I049sOoHEN2BlT;o%u+ z2-=kxwX`!FZ6az*M*?9O-+T2@{ScKZ4A(b#y;WMhXr|oz?dFN8x+_XRH=%Fh`3ox? zstZsS+PIrWFJsM?{|DnIcr&02>+p}S-vZV%Rqg4Z$0TMRiaubbgQar+VWUITe7to`+D)^(3U^Qp*JK1#S)zt8oB|e-9)c))U4-&Y)H~9qR^Z1 zHga5}D27mkVFHt2sBQ?x-{?hX69nC~<0JYRyXVhHl_MJ_P(-#i$!n!%D=^33z0t#e zBuYWm+i22ix;3vxv*=N-HBceL^Y|AM+RA+3{ycJlR{jRkfa24uWv^7<;v;GjAcTuidQu4T`K<-^_ z`z5cqBzk9RpmU4(tyKu7G1n6N`qJKWl{dR(S#J@LexaM7Y$Zg1WxbAISQrLcwB8LL zPXGSB>=>?UkBTTBl-CBAtiH3RbfdN5qxH~j`gMA0@ROYoM~8u?r985HM(B5^v0_{M z^B-e=E>?qvpP7Hy+}lgY^V!8%jeSnj=h^divax90W-{;T{7D{&kq3Qn63Yy$$}2Di z0eC({WKh^SiW!Dy*eetz07Krt8iKNyG_4=<>1~+9)mK9@Yf5`{smZy_MV&$f1+?;E z(8~mzwoh}d3+A1rW|q>{Mj|z3cuBjDs5aYrNVe?-^T^I0tj?W3_?HFad8)e`?w~a* z*{a38Wb0S9-OTk|qXZD=8KF@Tr#(s)TlUevjO4f}#1Xv>^m~S-xbYfNUd~zTH`H4` zey{cQa&-R$H$`HogJWP1zL1v)3j6rY*azfBRRJOmG5^U~yUo5He1#A+Dg!M3>+n}? z;P@R;moM(A?`~TnkkvG#?8#{p0@}c%zl-8+Zi#sj0ZPy52)KQdQe%iYpL~}By6w11q zrDijJL89-v#AM_M`45Pq`poN~`vHk2YZta@{H^wCu~$^;$HdKoT|nKEiirPN$pu?l zy?_vp@*u~R6#VzkhGXmt#^r>e0#QORI&f4wOccGQTL-Csaq#iJBtI_B1DC<|>xese z$uOyR{oo$C|Zs z%zH3y_Rjq_ZC$I@_Ss1bBwP_0AZ945rKrp`vZ4zx1ITxloLOlW%~IJ z5&6PTxyM$tm$nnb3LFn|?WQE9PV@vxLQS!Gu_>cOjy7FxOh~>)AVm}j!t*bcP3Ga9 zb#C8|B9~6#@Fh&Zf zQV2E-Aef_;0_~e?v#5%S9OvDX>DcjP;4pguZ$+w+9w|COd0{~X-W8#LS>>JKm zO68jmpuzvhrIT9=0vV#qI(Rx}_?eJcs+iYfMfKF+-|bOgtJMV!CC%~;B|^T3@{T}K zSwq9e5PhvX&iP$i(FtGW23t>stmTdO;rx)L8k-Zo9_3ayUTp*@;D$u&r8hTd9d%}O7$_WUtxe*W zDgEG3O{p4udKCJPpAd469`2iHTkw3o{_lBOyZ(!r=a`JgLwc^Op@!nh`^I<*0et|7 zcDtBT_Lpt@3rYNfMui^A_(mYY*|tz&@FfOW5qcn8v^$?oRB}LGH-vWs+ZP?N=-TNC z{D^kl7>5C6{(_BO+|CohTE-{tx5E`$BBH$xHYa-e zQt!usERrQ6Szx`#&N7|uJCshQ2It$F)r+|Mu4Gj;UBuMnoobt|d&hg1?LQWBHr4+2 zn>HVxCh~dHUBw7^>gb&!IAPcTBR7QbHq{w|6zvWd+G7RPrM_)E^Y;Z&mmE895KtNR zBD{-^fz6CfWAAIKpBIIW?xl~5)YHj@twDF|C#F*Yn5V`mr)w~6B940}M-I%l-{0+! zt!8~weR0Yu#;B%Fljh&{4&m3zK;v&W)HjqS=%`;l73Ka4?7Gwlw;H5!-o#x9r!Nx; z-62gxPCT+|TzIbs|H_6GfB5fpCBcO73%lrZU;`=n`}~|W@+7~E z++HVJRBm+t+`z|~mW?|%-E;p-gYQI$0S2GfVYo+~0`sGUut^ofzhJd5bp?0D`Sv== znX$AFL+Q3@$t+!`*s#fMbzvI>mI)S#-)pNozDwQ3Q*Vi6DSy5pM~a1vhCKFv>-l{+ zP<8Ac_wqVi{{C}hOmqH;GXQ@XrYRdLK+kJJ#Tj58jI3oP*-*w`niI;TeqK;&`1?mJ z8#Qof^oacl>nh3K`b&_^E8v3JR^*158;fa259v6T$Q@2UrW`j|F3f0~>MrM0z5|KA z>G|(pl!}%mJK0O~W8U?;TGJw{3(S1M!1|!3ZPbIkYHtz_jZ&N}EJS|;3E;^?l5?ub6B_xV?DD5+O*uAkQ&K(GyKVzv*Oa^% zNpRLd;k{SsFXy|JY%4IG7L#2d(_upL9qzztq;@MwwF#KiSR@xJB43nMgHaPAA~VW9 zqC)CyXSCkRo24=f6h#*<)=xeefS%6#iqI2q?;AA)L=7RXn~#` z`1kl-ejfY3qkY32K5wRrPuE?_yD;7EkobKsH}LRHAQmCvpRqXf5{o(%!n?%p#8D_H zu)uUm6rcqTeZ(n*1G&G@k&43fe=$grxV*lDyLsv!a!&i(jb+jQ)!UZ^AFh1ygo5eH z5Dh?XU!G;PSos0vRmL8luDD5+L>qmIQ3$Gk9Oa`{@_$NM=}Z!;7@G4ngP0{bBWM~m+4+JWl&M{mD7x7z<^+FxZ=@({Wuy~tXXetVj8i$Rn1?zKe9 z2wTb?_?!Pz5KXs~P|9BggaE9RTg26XBs@1;sq;y9CZfh#v?lpNhS z>Yfgj3~X0g!fGfpJv<(i2h?e3%oO|)%C|hHqSwWomZZNas^x#u{;XX;oL==#zy^*s z+D*KFGa5t0nA>emLxZhYtl2`r$1s)Sgo*|Z8H*smAozfe24NZo6^3A8@$!ds1LO#( zMJJ8yiD1)n0vInB0h%b-*x;t&7F(6bw?#w_l~);kJY)p$mpZlyLN z7z1kDCbY7xBv}!5=>m{QyB#o9n@{2+2WmeH!*fyCKnppiDU*M%Np78 z!bX^K;U!E#u!<^PDb{WP#~|}NVZFhN-%t8+U^lv=&8`tluIDnmX7UPA8y$#VdwhxLLiY+{PkKSIf|S zzRr~G1f=7T3TWGM$i`t+hGYzM$`yUu!B*RT?nE&eoJ=zN!<JV z!mLw!?Oh$0k-rt+BE5~38IJilx5d}W2y7c%?t}L$mV@F1%$XC*h>ib zFY0M#s2u*h)bY5cQ`pA#I=F9AS=wet!xwjT)8BW?KcAVJ4)BY(UBob4i%}cZ0H~U5 z)lJ2q0dy6O;GVTH9M>>B7t3^#H^mjwhYYs)#W0{g=&3`SQiEh!q$wus zAcgu%Q?7+vY$l9INAgaA$&0SZn`q>ua-#cu^7*pHEwSh@=zb>?-*vL2ThUt7i>)~C z-rVMpqlzZhzeXJltrmc267k9Gk$sYFClLCoTiP4Nk8!*!1`b)0#armTg75%cL!vE1 zZV@-&R2WhthmGj+xc@ksbzXFJn!}T(w900*+NP>wBZv`8{X!G=59AnV)Zv&BzrYNT zU{g*v==r{$|LX?j1vG&$G{a}ASOLEqc_?P;MJankGkgrOIBjWfd-Br=_0bAmbl61y&vECMS6-wj}5p>c)d1<6GXM%FpNZ5U{M>EnAB zADw@SPeU(NM$|r8MNbSzddmC7cpa|lKdgA^lku1^sAlq}(uYTj!kL;t@b#hz!a&)-85Pu)_FnrJaXNpJ+t1RBaBj0;SNj`OAo$(WdX6K4eI zFA&5S?;qF(oK(%fy^+}T6o?3^-H6+aKjvF2UV#CH+GREjD?E!0TpzTjqa9DaH0e(_ zn5L%Z+Ta#bLTXGhKDnfp2D~$-koNl#D0q6y?>}ouA#VL+-hH2P9+?s|Kl#sX;`Qfl z3Q2{Hm5)e2tUuM+NLov@xK76yQrba+z&^ut#*qFGGm0fe(Sc|)P$y9e9lVWDQFC@G z6^`7RQC?Vf{}-+uW`z|G8Wubg#f`TBUC5`fgkiWs;@@erPUIseRkdADF{R!r?kUVH z3ry@2x4P(Wg4JyOG!rNXUnX3E$Jy*%JV2sWzB>AU`Gy&O!P+gLcULnSHi_yu>(z;L zk`rwCB}z%}iJaMl2JHtm|I@b%;!ZyEWx+Ojjc}glnzhMl75&M4X>#FJ%F|zOplrNEltc`h>ca-W|wLtpG z_?}@}Uz@hQ*}N~G=SQR;x1`$~WZ;khKro3T&`pzeY!V(cI-;OW^RdaLP3=#R~c-u<6(odYQhVzXq%K#oioX?8-Z8_1a!v02*JV(O;4;WYS` zjQMO=79Zf7<+cB{%2H@vY-o&fC_~k8aSk-AqVB*W5-1@zp8tyV`6B6=rNc8|zi@rr z0w@~p6Vstk4cmLY7;=Km;cHCD(oNXTB@y>t+=j%&?`|{NG@tb+*6G>(Zh;SiXZbVs zX-(OAimS#FX`#k+0pBM`jb$k`oiD}Cj#G6*Nk6n;s7gN7Ng$ngsorDCa=&76bDNU+ zQm-S^k?$m5CE>HOV|1eSkHFZ+KO#Cd`5XfB4Po4a^n?6?Smbx0F?VH(dCl4(A<^Pw zmn^gh6Hl#-9GP$|r<0a})h0kX*-0eBu+Vl+qkUh)dZ+XuJE1x5#;E)&R)0LKG6gH9 zo<;2O^hlvSDN*+nJ7>xXq=I87ZppLsUv}ekn;+iSX!qZo(E_eDojiZn%DS7o+5g`Q zAhKrE=ez;B;2`Gt!tMu_mx2$P@nW)FH}*DZ9pEXp`L)kBCBw0s_-2Ic2@_g@dr{ij z*5CEN`JSn14WaFv;01HvO-vt+T)C7h_}xu)n~rfFBbK7{Gy@nb}IfndXCfMPhf0gHQ?DtAu?>lKuGuah{EAleR!bNqgq?Fg~32O z3I`Mjs~h-t^fmBc;H=`UIvfqpz#xKtsufSkU?}M`aHF(ezkbPRA~E zIYpd02_xC$+c=hI47Uml1aA-pa1q<0oAD?$Eq$J(zfmksOS;Oo+ zim7>-dZk(_Y>jom2nmw^^VM9>j!cWYz20q4C zroqWdU-4cEwE~PndxSRn@@W)2)HA4nB?em7SE7nWjtv$w#A(Zg8rS`yEaX3)duk0nyFgu?5>CUdLeno~6~l-OFI1Kt zDND@ySAx(&VMO5vU5LDq3d(s!(8aXu3;rn3&UN)h)d9v!kZo1)`|{yM)lScsxJ3Q& zRp9&&wNiN!=z?nwlT|b@wf)2WbfpcalGPc?J^B;2k~EdCtkfU}40Nr90WlTV%j4Vhzh%>SG;G_N;$T#cO_&osXhX! zKH{0JeG%U&SC(5~Lk9qV*g`|^QLRqM;}vCa(5MN8QKAb&2UeNREGCFVx4Zmt`*P8) zv}C)9%J_iA5jw=*r4@T9OR)a-ct>8q%V$cTrIw}{jHJGU^g^|wdH~m>^MF=8?06Uql+eBr>k8Fcb9mn`y^&V+7XBu@48n~4OO5k<7I6qHl-b3 z9O*SmD~kHEu;wfGQI}nsv(=)OiU+z6?0SO^1Sh%;hO}sGckF`#x@$~ z78I#|pD(v9S5TBGuEHY^U3#?>WL3sVWG4u3Pyh!Ciy>cexaB58&RfYp>7L-hU^L7? z&>`nDdN*IL8z?}>|41zC_-q}izEnV^!XV1WiM56`qH^hb&;R+`Q^3YMkVvPKuTL%3 z`1;*g+k`?Yp_;6ichf9-VBKD66I4Bwh0-OAc_)=pGj*2J^TM)TeaDFBcE!8Ef6?<9 z_5pIMSck$3@r7CUJM{+qUJa23>%`EaIxSMohiFAil}NgS5T<8@K(wuH3Jl_2$_2yc zO?E;hTtjJuatvQ`1$%_8d5=!_HmWb#+t9JvU~yxzYEwo(ifkT1S}W&z4hHiY&SXYP zF$t4|O5f=*7-cE%Re2p2YqCPJMvNsOz-^q5OtXOEyBmI}kTV-6*a1#T%Ct>^tN6Cq zFxg2IYd%Em@jt`0Hn#4pAF?FOTd!f_?FLP~T{yiIZm8iZ&C?vF?$+N9-mIqYN;JE$ zf=Eet%(S--iqI?NTxJFzfFdpcM!>TinMI?eA$m1b30+8he ztDL3TP{_YA*AF^nRUVgJ&pUSf9;4Vp2g{}4q8%h`A#H*QEeO{lE%~Qq@G!aBM`XXZaiJ)A7{?Pf%3G+{MxD z^KzcRhjK+ypxaAU)`3ka-%DiJn0OHyR4xX^PE)vC?em{c#GMKy+gX+2gz#M0+Ky9g z$DbN{exhnrWS=NpQ$l;zdmzkJAG9~4Bu?v}_k7H*Nrp+7@6J^S6L0}40N~`C%eVCn zYzA%9A4-3H=DJ4p86n9CHHmoVmD|$>LFS-iuX6&pVBXAN2dZz?@fk2d9 zzG|2kGylz2lJqBJDv)XvKFvYd7OWK0_0{vEXG<~B{qo?K=!@7eIWu9>Y6J72xdb{G zi0VhAJ7&0R1yqhI>=--fpkX}H><~E_OIJXB;Ei3PTUe10 zh;?U<+&nedci(US^Q!S%N5Lc1JoWa#YmOnBS5Vrgn-*rHa|vQ;D!Sh1K|!$QR7tW- z%**?Iq)|J>?ss3TdW{9tu)%uhyNL_a9`hXPT^Ju3P23>q2xjaF#`w!$8-2nw!^@<7NQV`pxp9|ROqBM|8UTyxZWJB#fuWWLUH*?mcvlZ)<7Rw1qSu-%`Wq8%vwkZJpsX9DO+kc46$m7y|#3@DXJNqY29eBM4$l zC)}Pf0&g&$7xEMk@@nsn_Xq|NnS!AQp#p<4iYoXL#GBysp0A(UfAbWG9tAL%5Iu!# zgFtLi5YellZr;go0`Kd$AAC@9bRK_^dU*@h7RYhs{lU7-=HE=Y-z8rKH;@&}Du!id z%s`$2S1dBbC6`vY<>s4jc1jCTDBuW~gYv>2d4^@7iDsKVeY*1mVgK#mH@PlzDsAtC#62b5=J`Df* z>#sXy=0`vJk@Ngu5rv5XZ$JSf3mf6FT#wQN0N@GcIc&ptf`Dlo-W%EwuF%fVVnSEw z@v?rKU=?~zq~oF2`QJ640`xnKdQ2_5{r1~kIJX&0YJThE*Fxi;hK(A^fc^vIFTejk zh>PxTl|N6EPoXB2(w@Y^m@XX2!Z(pQbQO_fhkAe!k4Uhp>G+1ousp920u zxJcKiHuxWp56gM7Y&olNq1?d#7(q}zF_@6^ur zGr$g_O;9>8njqIOoevzr6O1DI^2;wf8THv`pLLEM{sUG}Y^o19Ll}OMMyXH-7=mGz z>%bhahCOvsmmQB331CRL2DauK+}I|ggmFz8~8$G`$6@V>bh&tY;6dp2|_@CM97 z&I(s(YiOyfSW{~aCf1=4EC{=1pkehD8;C+*p#L420(5Q$@9F=LznfDnLj1SJQrJqhk$P=p5B z6;w)Le2gUsBFy76p#-ln3J%tC1QtL7hNnMpf53oa+?PEY$53#hRcI4!qfc^wyhyag zm)cJO`ZE}_Y}qm=ETfpiBfdnYK;uS@CBJV!w}oDXp_}PJMM5eg9im+kPV|3JW9T+5hF%8d;w>If+CoQAYxd9 zu?4)b<^8}R>gGRSjNuG>!yUsJew#t1yyvONdFmzkFdl2#$ZwP@JV!yStjN9mH8qn0 zI1vzN5mAot}2ClQsA?6RYu@DWD1--C?Vht(T5yI z?m?~rdx#eKj(UpuPW|9bPys=9uno+Cx4|j;0K87U{0}VQ8IWnf6flN+Q6Jwq&&GFa zYgL}p{sRtzeHiX&2l5PKpOB`})KO@`EQ~1N4a!of0cf$q8U!2!7!)Qp z>ISzEIt*vn(_Zj~@8uPq18l}H47L@jkh^>5E~!_)zGSF!me)5=QWV>0b4Bh~UVZ+2 z{3&4lkZ1$mHDeWHS4|NZ<4=+6uMxh5#~A0~9r!&O{K0*BH-zM)q`dj&o363kpOAU# zL^h>PoHHmM$Tplut_Wfy=fNXnSdJU76R84Tfi>IlQhEN|$18-%f_Bkr}+<*|}eN_qzT!=pTo-gqv$XNDIA5JajlyHb{T4t z^PPKfPcxcOx6#D*FgZ~Nul@ElIi~kbZKuW_6_ikB2>YLK%gt`87yttee5pMY;2pE^ z{ux;4KM1QNFs7gSQa%O3@hh;>ph11dTbxJ0X8`AWxa-TWrSuB>FNy-jSrimPSbz1a zUpeKI`Mlr&Mibs4SVZCd+J8&$RQeh27-zsp{)N=?gj1{)&$!@AEYgW-y+C-{@Lg9mWJLoQ!h|T>o#)v%2ipStC3yN=XhYs z1I{=EO(V5GcyIr`W>6rAH3*MIDmda@%TQyT^$F_-HKWC~rxox4Jj(_Tg5}XD_Fmw% zDO096MS`$-47JEEB$y^jg!c)Lqfk)7;f?WXQ0BqOVps!iFx`i#%{&8m!=CVZi$=j< z`^6VuINU+zL9q#94ednEK~|ykWH;{&a-^#@ANHbly^uUV_dU7P-`Bj1~k-o2|6Nwpp(L@ z(I|FrbR3dQQ^yN0yx?$#I&RXbKVS-2fbkT}poGI2j&nc5E`D3#{){V6joG1zCo^`zko9sRFG$ojoC~m4@P2#VhwfE7W(1aZ@uk|&h1qe?9{4_G**3S z=GsE}V8#r2aLR)+a@0sC(-}Ri-SOPlIt>nDu`yVnsFD|h0Rt<)b(HEPXgptDEb5|j zL%(PmOst)}AB+)zGC~VhQkaw^u!MtAFHC04Tv6b*>o7=gZ?@2a^%Q2uz`+K%@S8p0 zW=0hNfbc{URPfdUA}C)NP_U+&g_iaf78J^R%IFPFosk=|dP`RKG)dB_2yR*CEt_T4 z!bL8;4ItrOh$Zzx8|Q!5NDA02nh7-mWE$1x0i^ao$Aj{+D746~tgvBZMvR zhPL7<0&6hZQZW90rpA)Fvp$w|O^%zNl`W|qI?9!+x}u8KSFml1eE810sth$zBkebQ zR6^*BslmU&6}Ag1sBTejK6Bw2BU@`- zO>$QvgcvtC&zNEFu`*9!ZKwUY@}i88JpSWM=xWe}~ui-Y{lBOK=9ABANxc2EJv=A1MuR2H|w(ogphwA2?(@ zT-mz0KQa?rXu;?Uo*=`34d~q946+m$1m5r)46@vkv;$d&s3mZQ?`*Uc`Nm{3>cW9N z_v25bcl(Y`u1RU%K^oR?pi46!l*NT>RD7Ofk zD7A}K8Qs2dlbpSHUiR+UE1!P)sZ+kOlvXEE^6gfY4P_Fnt43FW2 z4RHKanSA|{gI&fzBq^z?av< z`@v>2E?}g!1dShj@PP}RN0|Tzh(yC+QPmO-XSkO=S2(!_`2}MNxWNYYpzJVx5V;3g zgmDZ!Gy^SIgTjQ7rD*Ls`IGynNPbT>27P6;y>jKU_lU;HS5EQ|icl&)=Cw#r7Ga#5t?PI8ydnyhrT^y}Zx8TS~&qQx8Wd%Ehc$&3W> zeJqT#-~jK21k)H{G0c+0p6FBL8Kxd{{(4F{tN|-bCNa4N9OYVMD)0rzB65w%I2c%P z3NdvM;|RC{)?hGUECXwhb->Qg^%=S(HIs*Kzf;mW>p1sSG~Ri|sh;|r|Mtp@vfzsa zF0n519Jo~-rB)eCFqV)Q^u1|s%N_;un{@tG{)o|X;_PYJsvz+H+B*+8tEy}NFBlkl z9eN$A^ezI5il7)Z#ui(ui76(g`TXa_^!Adx)b})b=_ba+SYwHbU93p&y~EHMhRy&3 z^IzYMd%Rv}V1{Y;&iwX#&YXMtKHqcpI%}`J_S)aN{*!jo5AU##eC%VP1sDssx7$<) z*R;JRFEk%^`maUBq9ShwAe*mV(X?* z*{0tw!VDo3^N@T^EmuDQS_bFLo?}nARr}?KAM>E4gKel+So*rGS+`}g{r#7B*&8EX zu`51!jZL0Bxm=zK)ql6umjlWhP)d|imq7dupw~R-9#BuXchHL0-m>u%&a*4WoMQ(M z=iA~pXV?RG{?0D`@b&hAtFH{6w2ALj>fT*y<&1KNf{qej*?J8w@_3Z~;0Hel0}>Et z0E1UFn4He@E6p)+pQzVhL}5x!7>zKbAXm_B&}A@?Aa9Vx>)mj-bI)$O?b;h{pzAeR z&aI%pV#tM8E#F{m=oM>7G#SJG3u~MHlmsd`&G^FM%j)FrxkS&wsXu z{_B72ysIv?0ofiS;`|=AYU2j`#edvkBi}dC-hb7Vp_@l}JTR;GZmaL=Fhh^QSi%%# zP2^wRoW=5mp_KX(x&+2qM&cqyBZFqJjTHAS>orSr9Q(w28gdQ!iXkI8A#0E?$eB1{ zM42*WN~l{kaxWC#Ne5;BdBaDouY2e-L^Hf-ifZCK2W0oOaU(}rjvIX6@;VoW8w@5D z#WXq3CtFSrgNegC=87R;>AQV&$1c15)1L{J@49tyn>WjK0Y7%gP2 z-+1l?>*Ow}8SXxI%(ulvVP@aPUa+&8Hn&3w!ZYrpd1?`e0+`DCi#{~mG?yix|Kula zq`M@yb9sUU-{(r=c_YuXM_+ou9`G17^eN>d8IMh^J4)@`@X^J92=9t^UroMq+(K`@ z`Q{*FcwX7r+11uck(@U;AZuvl#l0OlNF|HbKx7SUP&V$r?>^hJev`FqdZHUpp??#3 z??}ge&Usq9tZ7ttTj+zFd>t*PCV?HU*AP#uQ^O|qo&Wrro#BdS^Ttg=SF7!#ciEgX z&$8b?_)lxyu1$!GQJl+kYTjAu206J)lrfDvM*;p0M^l z6@eUYY1^_X$NIUP8Q!OFa1pGZfv9vv(G~AZZ)IDyZjHUN%-vai)B;-a^=wZ?^}X9E zA5>&8&Q`X;1S2rP%qC5m6v`>GCY7Gqsboz{H`rlFd1c*(pxmCINo&9fE++S5q_?XDd=?VxwPmI1>%wKd zE&*j$)7Q=;b@%Ml#g&Stp+%nx{V`V%>p8@>4xj!{>({P3;pS7;i)jB)*szfDg46O# zA*{?q;pTc&W-ywha-!hIiF|f-Yi^bYsa?8ewQbv#YxK9{J>(iY)}852E-TDi5md&f?eiH&UW&yR^428Y3YqXd>@c0=r$N< z+qP)J!AD%{(aDD~cYS=72MMMjgH%=JrL^~w>ibbfDD%-fw6gO{nZ@vs zIABI)U-NX&pD1g1|H$_+A5E-V&mOiW$GZf2P-zl%DC$!F^4SjWinH9`+l+<{tha}3 zs7?>dv@w`qt=sPv&}Lqkn>1`hcgpu93Vl=GmhT_M!wJ@^_w1+@CnbOZg9(9I*#?uC zk#rEs2&OoZZUT+w8@Qtf>o6;+JIl z=s8L$OyF&ylJ8zYnJ!COxKWmm-`kD&?_}KHgC-3cS#x(wKIARK>G1BpllkiZMHwP7 zr&3WBU2@`Gr^~a7a)vp1W{#VGYwudLVA&Fzv1p->jI%k&np}^|uqG$RW-MN4b39mR zCyy~uTui3OK=bM+5Z^;QWDgxYWGxys3Eeki&UoG=%w+@-Z|Slq!+U(<%B)m205p>piB4k1pu?;8t%9p1pL5 z&Gndb4I4Fd!^(gnFNFBco=zTgOyv{EcW(zg^qVzm9ArvySET;bB(q`Y`aqpasq6&M z8SyZ|=uJRWrFRHKITK|Ctf(8|H92$MT$}0jr48Qwm^u{lh6d;9uGcK^x>N`6cEr%d z(O?qK@y^*R|5ti+$#D)%z2&7_cDB8=bh+j1*k<|On%~To?fB{ABe>uD;pObyZXNr3 z-@L4>%BNQYufR)4OAHgPnH2aqnnY#ZZgsr;<-H%2FrrmAYS7T0dwPn^ShdW0`tW;U zSYcN%)@|No1BMT?k6e05ky26KYwGw)B;d-TYmS!Xo=yez>)Kj>e)R1{&hUOxg`om- z$jiiLmor=4bEs`Qk6(i&T&c_i;D&4=s(Nhst!z0^+2=1agi;(3^W4t8**<&AE%vg< z*LiU2RC{g7LhIbRP3R7|(ci~BZz#C&3%A&WNfWJ+zq3lIOCUZ|yw?bt`EuS)+nl>I z=rx(XE*NS(`$4bqt`Vy@=a@(NwN@$ag_)|v(Ps!uNFz{HRs`d7PE{_&<*TmCWripH zUFr~^drgNf zoo(ZW&9>iVN&|1e!x)QBgSR4j%|@?V<$4$LVP4mw%qW$e0D29qku@|1DV0nB`AGN0 zapT6>+SP09p691nH}?YW)G8}@wQcbC5k33n+rDh)o_}7@7pqg=q>aI(gDZ=pJa`tv zC{15D-$wfABLjP$=*vq*1~wwK*V&l zXqC*IcZ*hEgyX2uXZmoEOKrIip~Z)%PoLp-)!>g>&tARVJ-)5_`Id|^WwxL=AWI1F z+p1lAd+q6GY}9~3?iOFqS~qVI@<~jZr5?XzmRrIP9xt%I9)Cv$lhb>IeE6uu(A|=V z@=7Ns(7?U^7~O}!Zs(qVzAap^&~iLzS{>&&{>%|JXz-w-E&$c3RGj=Sw6%zU>nA9S z{Ra%RZSMIn$6N0EcJ1b)t97%+Z^vZY<-zRdtz2%eOncKVz47`GK(FX{{U|j7zSVfm zlBmS+%~lzCFTSWR(?Wjv6_?vkZ}3^-aWCi&zstF`?(M;BKXPMe5r$!LaZuiSR;{US z0z6l`GYlU-!tVU{6dURBN``oRwirO15xEw5*K+ilg62(a^jT*Wb?c1pM0MXox?CT* zLRS;?8XAH~Vp#msi(Kdic#{vXaNdOTZIOHHZ*q^>{oY-5(Ai_`<}qWen|t%qxK`b| zc3K!rFnkbG12^MKFT2cMdG2|;@0q8)GBYr^%~RMnxa-b*udI$ddu*ui;I@+v?vTa5 zM%A~L7Z`R-6nGg-PW~(qjwlyY{P1e(+qX~X^Go68eU+8|STG1UG!$HJBGm08@F$DW65Hx->QXOc+rLa zE$O62zSDfXRO+b@kkloobP&<4JGiH1%NETI=ZM{2_oPb`aV-gbkSd1Rv@n>&iVQE| z#EBDwweCqD_UypYMIHp!hoy9bNv^xj*KO6@F7PfAC?KSGT?j>%KMEcB$`hTWJ9Vb2Tz=Jkt^G4g`Zh;vd*JS5`eU@(pcJagu zgJMyg{VCIjXz@AaBZ&e|D@;|zN-91l>9rP_LS?UqS6uvGdSza@Eg}&0lV(2B+T@ET z2YqGAlqq&#-70J26_`Ulg3d0NH;uAd+NIvwKXT+qc|{b_EBg?kudn8I8JlhRW}rZn`1`kwCLotRjkD)J^R3?>mCmOBLZRaadV#tDDp_1A6h z?mgDQEBBKME5?AbyaN_b6$7@_t%Q z9RfTvyoJv5ZV&zX_4lFp7TIzSaz{LY(cWi!s7Lkg;(;}(=v23hVCqmSTTbE!5Q7t5 zY)q_EIyC{wlv6WGr#+A%bhEnd`s?iMbI!K;^X6I3`VH2waU<*Fu_Oi#9O$u<+7$)6 zOQ(BFrMzM=LEa?BF^O_0mGYTNc_~TOl(asT@~*tRmjH(L@#DvbE>(1WAa29~?`k#p zoUz`md~~Q^(IuZszf#!<#4-Up8uyvYU{ZR9kzeu=C9O~1uPT<)%3uOGj3KxqN8<=B z1t=Vh&)mYTdMfDUe{GsIZ5D>WUAAJG-TRMw?0^S>I(ytXDdQGZ{yWqeMBwBEkSQUS zPtV5IrfnM=;jOzU*C^Z!8J8+9f+xR=)ZYqElsJYJM!A5wk_gl$0#UEQ{Wx~zq?=ll zL39zMJC;)U2;e!0d?7ZQJle{~Qb(sJ0b~u8c|1GYdbdAfmEc8ABPsF+nWI!b0@Ohx zOT}g@AInrbI;{*Q5grv6l22 z)X^evY65)xFwin7RXhO&86EA78L!y9`Sa{>zTD<(>ND4Uot%K=%*pwj?p9K}xdQPmj zR@{|zrm7~8mIjjuDZYep9@lGr`UFrU)~;J?%T_P2E`74BRi}nFf7=q9Ip-}~yM9fP z066_C=tvPrKLSz4=%ZJ$&kmkcV(>26xN&1>{YA+T=ZfX5v!xFK$(o9NR#BG3`IaqP zio8rzQLI?jI$Qb>NR%_NY@|=_Q#Yp=OwehF1G#OQm5<}Bm8o=Us3 z_N(~>aE*WJ$tUgE*?+R3=k>FEk1f!ySqm$ubIA7Y$PYm<=?9FWAt77y-v`~H2vimU zrW3Nn|Y70E5Zet%^ zu~EZD*34y1J@-73tSJv+9V-GN5DC!e%n;mj=ghIGFTI4jY%rWOa05ycmpgSc>Z*7T4K7}wFZt>0@MdN#KPEsooQ7dBeEj_oX~bynyL zETc(fKMkE<1ZpM$M#Wjbe!V^Z_~Z8I(@$CBx*3*HaNO3pm&vl_%Y$OkqGgNF4;aIW zWKGR{4|G=|P?-cUlx*9!&0c)*MSIj`&C2B~tbq^4mA7}dtysCzR;^xbty{IWtTtJp z5+gZNnNLIK7J-^gfZ+A3U8X$#?|<79uGbviv)Agn0cH8BRkmQ^BHNj})7rOdA4XV| ztVzFjAuWTxrXS|i&oPA;?~>PFd(E~tc+JlL;8<(XF4K-2EU>IDjjhMvHg<6LCcE>u zzYIeRUvkN19-y(|yKang6qn?QQXCbn6M@=CAj*iPK1}ny_ugxVH*K@8Uwwmh$Z8XU zz3xAB&=#&(VSo6)|FcW4xWX>K968hYT`e;49L3R-tSOF))`>vvB7lrQ7SR3XAOHA= zEuA^ry0^@-n=hDbtvtTKK{wd0+mvIor@n4^xx4I(U;dJ{Z{Pl1at2~0$r{O-+Vwu^ zZbjf^1frZ+uwa4RckkbA|E4YWzJ5b&K+j(8$=1*fw0mvY+BNpnz4uvx8*xASv5$oj zUlY$6&yl!}-;y=SM5=FD29xSzpHf%GSTbks9Gf+Bru7@!)B2s&&i3!gx1&dnTfyPu zmhA?Uqxr|}ud|=9XJ33eSV)Hq9vX%tW5gmNLaSh>)FV}9L?9&z(64vy>^b(UU;WA+ z`RV`IMQ4q*Om~Yv?vad|UWXO;( z>@fNaat3*$;p0mpto0%w0+9g54m{bYKRx!xd+cZ5`;LtpGR)evXk~Tk)(Inx4)53B zn))b!Q=Xn;{rdN}E?v3=y(Y>V={1pN>dw_ACv~4qmE??qj&iXpZNjoV zrQNJ^zm|3+|Cr61vff&CY-~d>=xX&c>)NpcO>FMWxwdZQR%_n0h3((JKTM1_j?o2| zH0HS66XXb&*n)zB5S5<3eOxbIQwEkQdBnPi2&4=FWB{!&3*0!{zG-v2Y}|NjpOs|= zK6E*9!T$y0$tHJQU9@I(2n4(Tz<%4eZ(o=gNhivgC~J^2QPv=D7zRDcnRw%pHIano zB2YUBtXaF(X1w*59XW8&&h^U6@V@YUqZM43Cqv(mhSyoIBeH%*Y5jJ9g{{ z#ujuKWDQf)Z&-(niFysPhI7RvITJ}}E&{cO0CHgN+_|=J?mTPPrnOys*4ft0jj_jE zHWV^BistSqd&qT_8H?uITJL&{&VpWp+(Et|Q_yLmUW43;awfSBg`APBsrvUz9wt@K zJms!PJ!RkCeO}lXTZdtFtwZk?mQk;+)o6;dKaCR4H{Wajw==gM{TCJ z+|Xi-l0gLp`O&OdGi%+tb&xGl&M-&LL>a^~EyblRf63*xBnYKFUrEPT@tk!KC2Z#+{_G8jg<^^A-&KSV1tva?fs(4aQllCm2xZp3}tTK)bfB zZIHXl(;B?X6_VHj6I))GZ{EB)#43og200UD4Du$*5=QrmzmY4^?LJxFaGuhp(#V?P zWI<`qf2#dTk|`yvKh^oPT?Epd0I~%c5)85aui&_QfxP`)K7O37LE4RTxdVR@TlYR1F_1vV(9_ExC=Ji}&^zGBf z1`i$_>O{yJjI(&_M^6%#kwJ9xiE;+Lh6!1cSZ|6lCd!pW9fx&H99Nc<3bH0Szo{T+ zWH3p7Dz#`izJQAtFR}+7`G@tsxRLc9-rgEE&M-HRT1Nf4)~H#1cPaBS@2$G^>eurj zvFh8Xi@Uq`Nuh0hp}7qmF~mj=A7w3?w+PBbR5U2?aY|elD?GF^uUoe+xX*(eVG<21 zQMrhM0O5e55sfYs{#X%WpIAAH3Jb@VH6;p&;*^5Yo{wV@d`0!vvja_Y+W0Z|b1{IqqYmhHd-o)P+grk84riq3X$QzZE zC}*PVDeH5JvZg$8CSH?saDVY0sld@Hg@=uj2-HRbv0Qa@xpKz90rsa?U$cIlyM_Ra zc$(Dl5xFp?!WWqZs0~yFFVwq_PGIM z)yDO9ux_E<>_J}#4H^`559ACw0p7Dwu0(wXd4qh3vL?zK8l%yHqJafCl*y>GATy95 z(O|;wC|}rz}*4Vo=Rd&Aw{bR~QcE|r-V$HI=C3l-w^gUnu^A3BG~8dKJ8X1~MJUip94Gcc2%xA?@T0;Kg$c|2W*_G4Q{IBkJ)tOV_x#SB zKHZisUSfOw-s-w?)}eDpJ8RrH>)W?)$RpQR5`o%CfO3YxWb@FWmf5nIy}n>x@E!>j z8Q*r;m4?M@SKHE@4L0cm?+?Zj+~G-yA8{%Zl=G5Olr_jWWDOOX7@Hu<8s^BA4I4Iu zIUi|c4eQ&rYZv5Abd8O!uq;P;6X)b5QPv<+P)?XfSrg?AidA`K4fo7_p)is6X>Yt~ zYgetb!`|mP(-pMrUOjE>*=L7VbCe&hsU!lmi2&ulbLY-Mest*C#pbMB9!9Ge(z}lb zUTkbf{9CBcguG#nJYmkX zY15`4W6*EVbEp8J$IzgHj6sKC9_0=5L^(rwMz%z~Cd!wn?;vO5lz1N}%m2x?ac@-Q zxR+HvH0qLt3vJ!1HQr#BX&t+CwtfQ!1ic2<6>+_jAXN$QjXmRxGi<$k!#?9d zU8l~QZTk-$u-;v|g_vxN{oMS$$a631m6xy*Pd_a%e zfjx(9(frkRAjfR#)R%4i_;UlKSQ&`(laZsv0>wXhPQhV)oY;;Mf?!}mC@>))sPLeO zFh>bl=6wS(w2*LlBN8>0A3o>^3o2i6Vm(5F0!ZPHg`T{}0?F^hiSrhBZ{&5O2XcDm znP=?xcmCQ2xN)R=+x8*;0u{h!KHA4kH+{lJpEWwI4EoD$^2T>_8cz@`ZsHMI_1_)}xG|U?j?vD0etd;`+GE zJTb8kd0f7Hxjp>gL-yQLPumDrDB5`IdG7wbHgD+?J9p9~yZ(mj?F?`ALcS@9Ky4&| zo-k|HEPL>u_uDfMJZhI*a*=h;YHOQ!Znvdt*VuL+Qutf{^?kei10M+G1UWeT(f1>`f&b|KE zda&%6&i00v?Ck7NwvjdHGRPb1PW(pJU}QnZh=vo6O_nq9zLMwhI;=;(dhOL$?GMiG zpx5&T_UL7;+*rL~$99|T^6&jue88@~?%E>HoOoNu0<(%e3my3KAzO6Rc+*e|&qSYM+BHdcNR3|L%~)8YKX_TYmL+7k~wZ0GbFVEtXGXx_L<5Ng|Y<=X2W z8GGyDL-x5Z+-l?8+oUpCCFf1cB9LkXD2IHwXUv#kb3CBy@+C`b$JXuExqCMoIr>bS zbkRlOLruAiAK9wB2joSftU-1}d4rrG|IAs>zKOCX%AQ176U#fY12_99YmhIA@&>uX zguG#n;!#{eG5hWR{nj=uU2Yc)8{yrWI)`pe#D!V!#+;Yt&$hNbd)Vhb|G7$+&Bf(U z>qH8U9!}tCANI=Qd{G0^Lcys*ubGf?EDEA*qPpl(Zw57szc5o7myu^ z@&=g_p_GhTMe<7P`Dy zY}4KQbJvcYHo|30&)&VPk9(mE_b$ei!KzNk!l;`>c|(3#hx|acB%|XaM!gKLlytWX&b`q3_Zj2{kQ0W3 zM*eZ7aBXYp9vrKtc!8*EXOEl|JT@oWN`<3oS46ypq>Y#+8!kSYAM z4o8tjn!SGRS6+F=GV_nvg`>xK zVejJHP6!j+e;Hov^WXcfT`*~4=#qeywm6{?GNOMXP{Rq}HA1WC#EBEb{D6O)H+Ty} zcAwrh(Y+)3_wQe61WqMqxNxj!oGKHXD^aFId4q>a;(C4~gOD}IU+zE35Uv&F3Uihk zvh|t(05tqbL_t)e@_~F|KlYDumus$Cxyqh?_8Gf!)L5Hz)>zAG-qK~vF&`#1!}@gX zZkxQTiV#FPCFO$yhMZ4BprI#&T7P=&~azk(F zd~r^7x7scO=|_NagI9Y8*KK;cF0;+OBNlo*wp-2~XJf~X4W1d5rT3po-cV1WtTW-A zcCv)=l?Ea_-4f+Z>~4sD6J-!`2jdF&#r-oOOVC%4F&IzKYZB!Rzu7O!8}k0<8*kW} zc?<0$U;2uT9x%vS`f=z=M?Ejy-FR~3$Wgohh3DP#_)_cc_14Pbcrq?Bn4FAG4d0eo zKhs8!8e_wUjtDIVeCP;Ldg6oUE>hmdckP;Wq2*%2gbAS%Q=F?+G{BS>S4EtMg2#C% zkQ7*4BoGJYa=gHon^^xIJ*?+CZ~b*w z*3RB4Lmn%OXCe#z7J*bGK$)kKz;_n~pO#aC*vBp_m67L)mNT3WIfHIPS&malx()Xj z8;Rl^d4v9fjNvzi7Ub^i*|S5XhWJwGOc=Ad&feV~f5rEiFmQ|6oIrMP}hwIl%+v~275RCXG3E%-)@E)uB=p% z@`kzzvIgS{F)5=C6X%J#P4rwrrbIczoC#wK^(JHu+v&1^k{2d4LXesCYAN}ronkBV9_F2Tz1#x6T7Q)rndPE=S>Mi2s2em1!g<*jMctlrU< z*1mO?)%TVe#9|@B(bAcy7_yLAm1F9aAVbUVc(JMOjl)H@)P%;|fn@CsCh4uEb?z zO_Vp$HJ^qfxX`-1diCm{*R*Zc!a8KN@vhjJCuE(=7!(}xLt{}}*A3Tf+7OJsa8we3 z+D;&rRXjzRC{ssVuc3h`9a79cc`a={3L|T>e2KD#^-?b z2OBbu{ewhyIa9B0J!|O(oQ|!sY){^9EAU3{>PRw}R0s7GyJm48ISLGkK4@&GB_{3{ zx24!^oH9Ex|HbXdJ#mkvZE>P7Q?Wx3;DLfrnc@8v_Z-+Cy59sLxtMmAi$r_XaCJaTTNyr$K4CD=RhZdSf zUU^wKXRaOair-oG#CIz{FX|KSoRbOYSu95Wlc`@Ods zD~W&zyhDI*JO-249XDMJChy$AJHJYjGjR{LEN_U>+|#ZaLnU@vXUYt&_XQ z5 zR3VB4awaG%KF>dL*js!z+FCc{5@e2+9};;`5`o%GAXZ>d)-lXtfK2z4C|eSLC*Rkp zWDWacP(l{tGM}BD9RlLkb%W~??{2fnb(jJ--r_;a1iLNnX1`?R3hUavd)QwxCYidH zYcm0i2k13)1B`X3bmfsKTZ;RgcweWIH5`vDM$W(o<7}sHU2WF<`L=%R7T0MGh3&|h zx*kJwr-w*bm9x=SxW3cP6YfhcQQx=z!}WzAu4)0yt|r(^DA(xqL8V3^(Kbx($XU%FwP%c5Sc*YpbVJIeFY z;K0xl90f(H-=Fd)0#!jEwlolsur?V?%EC6vWsYg@L0~`j zv5(n9_dj6YN8Whg+&0ZyxXbEZo9)43`<*q)Zoct`5R^LcxUyVN`-(tKB!DLtdJSFO zX!*rZTiUL;HIYa1JtfKuI9+<_rB=rcu0Q|PT{dd?FdNplU+`*Mv2mTf@#0kb?6>c* zfv(r!bt}ClnWmO&Gl6I@A*g$8GMJS1nMYYu$7S-UQKRgmpZ=u%<@bNIXMM!cVFL$P z>t@Yutp|#o`r^wr>Dnvo6W3iIg5Q_+{1v^Q3?>y#v<97z4`3Yr7R5rTWCAE*XY@V8 zK7P|pHgwo9TeV_^?ccN48acyD&KzdL-G!dk8k93q6ci<&QCmcy_7FfZCrIHm51uT= z;?&O_0}~^D5WsTwdFR@Ob?fb@cg^Y0wS#@;;z>5rm5*-Sx+#9isZnZ=T1J3xHKQ}& zB|*$2f=$a{Qv6GRCl1CC7@hAUB`#mS+%~P>V9k8}`_3M1U-;%Xf+rgWT$1#f;yG$< z?IeKr8Utp~5X1<1BsDx2|1RKpf7R7j2i=7cw>GX{Z~5MB^2|Z~?1s;M+IqN0E&5K3 z15+KcMh27OpPkk^0SqR%bMw(8qOn>JPX1cJ3Aa?34RfPOzW2kc<8Ja8Nl?&~?R`~L zT-(xaAh=6_#x;;Yu;9?RlK=q{tg#LdJOl|e?ykWh1dBj|L-5AkEjTnT!L@O?i+#p_ z@7dYsKHf3z!(q$^##p_2N>|NU_06xUfb$0N>|TuwWp+NcG9Bxlc>)|`DzX>s%}KTLiT~o`vgZg_QSNdStqhIEz4XX8d5 zDgwOU)=~jfWlG9uatrlnIn8Dn9$r5fH|jk}CQN>Q%V_&GQAe|!C<_(?FS|j79uK*>$Wqguq8Xb@j_Y=RUSLxz zh?!v*BNT1K`ofF?{<)uIavW~SuiSZ=*8^5ejeBHM?QLM3=r&v-$3TlE7l3U<0(u^5 zp5yF@s^O!Yz<0%}Jf5EZsd8Kbwe-1MfU=sJS`_HIxGd8wfzNqS=U>hGEON@T5k5+5GKrUn~}425l1N zItH7S;UqN02;ky!6=3|3gQq|zLBUy3)I#jR$~AeLCd7AJpVUAsa%aW@sK~AnaMW$` z^$gf>G_O#;${=~th(J*-boI|Q8I5}afU)*7N=0xJj~wuX`GHIy=B~|t-IYU=%M!(2 zfVL2=Z3wiH$9CGjm%ulKpjiu|9H-quP5h`(ehn34)Tb$5_R90i(nk+_K}!G@%YFFt zL!5^B*?UIfJTkP6xPW{xV1_Cx7UPCW6*BoMLP7{={nc76MF2JkWp~=K2mXU&f$I#t#MP8>LmUAz@&^$V#-772 z)x%~$M#k~(kp6bZA1tWAdXPy2_33cw6K#sApJn*$)+}&kj@4G2Y|YQ>Um5_4BZb!U z4Cop0@C42jzy5gzzTbjd=`|y`HCE>unk|Q}_}6FH{hNT0XzVM?qxF8=C@Fq4e9EVC zC%NIikoNViv!Yz~hHqTS7-)}HbY%e!<6G{A<2X)Zcw02R#A}E_=_K5|oK_UpHEmXb zUC8hFQ9S_j$CGfWa1tr+q0KUv9eoeC&6gh~Y?5fF>pLB*?8^nV#%thHuY)>qIXO6b zd4BxG$Y}aa)D&SB(Ho!2_huTGyTQz9t*=s$Q-6Bg4#H`1#41*DD}^-@DvjP69)}g` zi%4mC_MK{VZ36wtCF@IqfNfFmH;x zfVot%DkmzO;H8Onk%3jr&6cs?X_VFYSG{M{YlZMAD*m5yn8>NJv3$EiAUU4T+8dAg zO9Taum-SqbBQq*fECE^$=FCU?1}k;&MT(t%?d-+-!Ac--A{CcuCarFCja$!Fky@y1 zaQ-H7h(fq0&$(K1RRfT#3v9Vm>ApCRn?oIluwCtYtFh^e>Vj^^fbNU{OxxkIua!GP zDAZ3~A`Ly}D#n4N)eF#kjOXvAY>dk@XOXY!%w}PC zI8I*EgIb2kr_Q^;DV=&90udocRhw0oGy0+ z-J2L}=HT+e4D;gZ9Rzf*ecV={A8*E;w?=c-LnRri&X?S*3$a3*Jg`lUo9fI~!xNGH z_CCGdw^4i3K=6`&uArT)`PZhZC~EMpj7;AChT{{lS#$3a^ExAD?)cNPUsQk!$6-I9 z;V6O6fFekD%ON|Ij!m`iw=G>guf@hIy~MTvk=@ToRMUDo z`f@6VOXT|@i{>NxXUSX@Q@dr;)ka<(^C{)BHr(O&gJ|17r`CPQHhD5#jYzAqB}S!M z)F+2HXl`gHkO3V(VD+8ZEJ)?ETZEq*xmlSe$4JgkR}HZD@T$n+5)M_b4@pDnogks+ zBZ&KX-)+>0nXZXYh##T*!QtSK$TjLD2cG#N;XPE~2E}8@V#q2*pWn_}5#XmFTF=L) z6}L8bI`n=Ee7%=%2F3Z+?n1r~X6S`-L73Su}ef9I6 zP=2!2$Jus)`4E&>Bq+eNMo2;r1GVRYq&zbYI#1o5x3@bcS&N?k;*1k|n^O^-itkVS z&}XlKXpn_Kgl14<+QwQEc!n`#1zC+*ow1W49qVhVX}anYuF)`q_(D1@Wl>ctES*~k zz{$JiQ>vk%IYRY;FX+BweCDkDwWQHb5IZO%ynmQkap!ENBO#%V*jx!sj$su$J(Nk^ zv{$T<$Lt~y)lp`NCnu=*k>(T4n{Nryz(^CAq)r=60ls;vcGZB`5BoBs(I$vj%;)>( zH~%8akfV0Rqtf%=CsiJTv^28?6BA5F+81vjSU1F8=H=XICD~ZcZ!+!==zMlBD$f7L z;P=Ee%KpRHbkcf372L}sV9D2kMGd6n;*lN?zX}+yu}LP{5ka??)dOa)q=7mtp2#Ki z$t(1*<>JKl+za0t3DoeQ4lL>Ler zkiV8x`)lNC^7rZkwT+iP459DuHqM!H)^Fh#HokltO-@r$k4Z2&^4?~<;@rvZ2cf`K zn7UUF1|AJjbi9!hdkcX50fajStvph@soHjg`wqU%SdQDqjqCOquh=%-`CY7v&e~vU z?zaZBc<#noR`m>9RRcjr3pMB2FpClJ7?x#1`b!U{$FcMsnDjjFvb=W3Av*_(bfqE@ z6w@Ni3VNi*jSsmeLe33dhir{i%JjkW^FVifmwTk;_iZh{g_2_S^>U!yqgXb#e>ZZ{VBOU{1P6cj9YwX*vkj~D4 zjpBwDE28QZ7_N=LFZu3~Li3)o+~d}JX6N%O8v9uwtst!}g`sCDa4Y--mMlm4>PTrL z8rZei1EdltVY!RaGideg;ueH7PJ1L1Hivwhy19t@?O9x|k|Va~Sj@M>nlv%_k>KCy zfNdfWFZIicO33cDV!gw;(kGdjE#~W0S_v~sqg+vs7B|7)Ik<&7#YWP(-%}r0BMiq5 z#%oD6POK62zoUDUeI@GgE^FSLSbNx9IEy)+l~fG$l=X##O{=@+7UoW_%h*ulvMa zCu2*X(5@?%alFvO_5C>qp}}8VPplUetlr(~ZRH6OSwmZ~KE{h2}%glsdWF;PAD!YpA z7BKs-?J}In*C?Q%`NgyLuVI8IVVM82V_rT$QD~MP2poQQ^#!)aZ1d|(OCvCU7YuBc zv^Iu+>C@ep{(ECk#mbxFQl)$O-h!`mJ?Y8kv1@*3FXFr@v>Z4+*`lK#1aYq0JuKT-FkG%#Cd6efZa?>LK7VJOTOA-|N7^(l*DKG$r559P;Vf@FC zzYphshL#RIb;tkPNdM=N977oJP#GxLcxSZ!vpr4OpaqW9|Iqwc0E?RiWjNpyoUeO0>Tjgu|r^&W?~3XCL!Vk9FTvmQ0UzxG5uK%xe@&&h*UWAogR!K?8qpJ6tzaC!{(hJ1{RNF`T>6 zOL$HS+D)5b63~=12Vf?jKHU?QpfcC|BxXy3AGBAC&H3lHB85)!BK|#kOF^DiJC0Mj zmULv|I)ENeZrw3vE=DuAmm;_;8D)b(TT0tLwU5vB4Y<_|B(?L@Pp`g;M1$Kr%#AWG zANQYVT4i=gen2@Fb36|1&lH-cgDs$3_HkyG8r8c`0URcoagI3krShV#BW-?D1n$GX zPacIh4+nKKPKzwznM0JWvN>H8^8pWg7c35ZvGLhu!q9bUC|lO<=IR`Jw9*}Kx6tIb z=z8dIaHiE?^QIsaAVZG=T$FG;mh^If`!!x^H}fv7%}f0~qRcjCeXhbL`RgDo!n1`TPU9>aE-dr8zHuYq6^~Q}Q;T{%ZE+WoOK-mD6ee>-QQ!=M+gis9Ary$ti<7P7>KMU)tg1XwZ1WeUM;U7(dHbw7 zDc^;)9Of$RAp0zCH|Og+K)r&}@&VD_qPksC2*42nGP&3wft=q13Q<&KNT8=?tF?ts zDndr=gc4_2*DGw2r2;nN@GVcEv={zF-Oa@oCo}VBP4>~!!@-I{$y-lIt+Q1We~O&? zC(GF-w@Zi#q|&yO?3vCapSpz2?akO|WgGeAxv|73m(PZ0-AWVevUiMdCU)?&%zyQy zz-FRo@tlQ94{_3@QX^&5)d>``1uA4h7xRF;z#MR!1Gj|3mx&ECrKX+R#Xi^h*T=b% z6(UY+AdL8zy{2DfhtU5xlzt=tN7bwF%RoP-VI>ld`Znff-Tlq{;%Uj-^Ffm*H5Z%t zJ3u~~QhTp0$atx#qRnh|X*~cBxg7NiVFt92s)M0BGgX9%h|%xh4%5TK!|id1eZ}Q& z-C#!yJw5wZ80?JSTy~rG5*NLP>ZvL9E-Tmmav)*`}nmY5O2<~3yb(VMv3E>ig zgx&-9KL>20md znqhB|9nA3tMX_GnjSSM8WR1qRA4*p~-=0W|U2YB|M`wJX^t8wt96ypw?BGy!@5A`! z^3e60c5ZK8_y}Mc@^eIt}XC(Oa= ztz-B;(F<9giQhrOp6jr50r<%{bdS>6g@lCCKkOI)goc%KAWf|hT$ZbD`{^Z6(%ts* zbdPXm*&7scgRE))Vv&SN@8mR?|PNpmg)!&x(d#;k8cRg~xta?x42+xt`&{ z1@7ebIq!uY(-Eu)*G7At*3DOVUtj0}bol@c41zB(JB7<^8+^lNa&9AA_yjE0z-={g z(XhFxtwzuuVY{C0qy{?;D;_P-E8mYU<;>tn*S-Bi!{xG2vAbHophA<7FI{IHLV)zi zqx&F{%kAhDdR=*AGhx9keCRnhAqb*dr&8mg#RNGu+lvhs?3=4}Nk*P_kD! zpEqO4ZDWHi^HNNupgszh&6X_(0*~bzV1{61n<;Yn_M&v|s+Vx4|Ma$lO3>QmgzC-Y zFf+O~49){mu6B*?woS0|aM`XHL^jTG)xVE3Je+bK*GRJ!GkkhCa!r205q!>9Z>Lw|`40v`hvQ<1=Grs5sL?N?*%|hO~P4l}7Cv zt$RorE69Xn%hF&p=Vp3~wmn*$YbsH2-~68IKOC6CX?!_7;ZF=W*n-IU+w1wot5pH2 zb}$f7^@oerP7EB~!t9AZqgA-bUg#h4(l`SqAKLzxUX%>1QpX+|S@bPH|LKECScf9S4Jn}5S z73F}N=VHjvA0ACBCg?oBKoz;)Dy|W>cJ&;sdJYI+71qH!rXsf&V>#E?xNeD1VEMy3MLSSoP7NPo2qt5ousGACmABbT@TSmdas1-Ib*_f>!0mM?707eQ6mie35P!ss(+^Z^Fk5GZ{L<2nrM+^O2zjkvi_ihnvr`YB;_HDH5M~c zVx>_&f;@YR2m!=_L)}KdsWM`qV}O>vy`YT#?R-cVqk7`TpT={$uPA9-IP;T0;hX)< zmLRpLRV;0{qB?5pudAqk z_xq*HNdZYL!=`Y$vEDzvDa>stO&qKT?k{agwPu5fAF@V1G^};T zj)3W#IA`bdpM=`^W6G$i@q`(0*Q~@X6w3A`ao_yb61I#e3F!<{%{$YWGOISyAI8K3nG`o%Sc6c^CWo?-hNoO%mk+gRej}x}vEf?6ic6Iv zMv`;^!~SxS_+S;RW6G$(iHUHgB;b6ya@cS8IlTU&T-; zVdzzeZ!vY!nr_njpX7!52_Ij0?t)~MoGV^qZ3#2s?@`Yg?-MXS;`R#)qn3#zRS?PH zAm3vjB#@|+8^5!(+fScl^G*F zvu6&7ww`%5iG#2q^B@)0L`{VU*{iexG6i5DU7q{x^(ka9L%fwNJl{5HX+)k(9Si{- zA`d$0tBh|Ble29I$fjWZu%h#$^$-v%tZ{)@Mu?hy%z3WKSGMI$mML%Nt|$wURD#nm zdPt!3BdTY*tS_E~c@*Gde|tW=H_2;SZ!@~(GSTl-HjlWVxXetwHK#3`x!7+GnNwPm zHW7y^_*xCDMY;6aW#TB{jtp3@i<7dnAVD+Md0?fwiJ;1bpfog(xrP0<|1aX)r z79PrE=|qa)i9Kui7})cawDSX~Fyw>dCZy#S0$(2{0fng9n{ zWIa&0sKveZ@&z&LtG4Sv(yz^pK%Tf@12tg_e^rkgPBHg91~i|jL9y25ye@RYa}+X=4&4s6m>bFTAt z^k+?&EiME7d~JEsQl`NX4S)+M8*75c0;dO^gMvj}w!#KQsis@P5I3EdV}QevbBeOe zxVHw>oR^tu%6~|g2}}3>6LysfBHj;=)(j`wS24-PYoa&(8_U|%~Zi&a*P0ovhTFN{M16N`5@@)CilCx6U z8%7jC60cuCwQO{R4-IGS! zdF{hV9x&8~z}m%~(eKQ9P^tOJ<``oU zMsJuK2k(6WlU5|{%+4n7#B}okWCtjT7-uf{-QKycTMxKyZrFx?Yf00;$f#ju0gIJ&e(^U+oPDF!q$gPbsu(k9q(vXe`|pI+GV1tXQsXk zFKy{cS6KsMz-7Q;nYUN#Nc!6Y`o+k|v{ofipq_;3h72*s=Uj)HQyQ6|pI?sJ;qsrA zz$qu)Ta!t!qBx|okRC-062$pB}v3^Ldi*VdHVJYhgE^ zu#Al9yrrvMc{x_kCK+*8ez)QHD>r4SZ7?3CW&(cg8qw(rfU4lM8XwIrUlOg$i#2)` z=jvyA?zAnR{NwWAN4W3nl$T_`H4Q?-Sx6Ht*fRn!@NpcZ-!@*mTmtonfU6<4Hk2^g zvl8PWf9Gaa zCUWRZH{P5Sn>>Mv8?k?6N0-Lg=&04w_PH~Pb97nK^(Fs=5*7^6R-lJm zK+a96)z}_>LuzoP(neZSNg>?~i?ivYBGfc}*R0uuASq7-8uEHmr@^{0bp^k47Uz0z&0T=si@tn zWC>7J-pUtitZn%Fm-u2OgaG){YPL>L=MPQ&pc1VG&F>fRZVIjEyl<2Z#tr^1#!_h3 zI4G3QM>do@!hgV@jcjBmGUtD{=6~GD`TND6Fr^ap9Sf`fv!=qIzrbJ00SHtuSrPk4 zZS6FIZG%MF0Q* literal 0 HcmV?d00001 diff --git a/rapid_task_solving/memory_planning_game.py b/rapid_task_solving/memory_planning_game.py new file mode 100644 index 0000000..eea85f7 --- /dev/null +++ b/rapid_task_solving/memory_planning_game.py @@ -0,0 +1,184 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Memory & Planning Game environment.""" +import string + +import dm_env +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np + + +class MemoryPlanningGame(dm_env.Environment): + """Memory & Planning Game environment.""" + + ACTION_NAMES = ['Up', 'Down', 'Left', 'Right', 'Collect'] + NUM_ACTIONS = len(ACTION_NAMES) + DIRECTIONS = [ + (0, 1), # Up + (0, -1), # Down + (-1, 0), # Left + (1, 0), # Right + (0, 0), # Collect + ] + + def __init__(self, + maze_size=4, + max_episode_steps=100, + target_reward=1., + per_step_reward=0., + random_respawn=False, + seed=None): + """The Memory & Planning Game environment. + + Args: + maze_size: (int) size of the maze dimension. + max_episode_steps: (int) number of steps per episode. + target_reward: (float) reward value of the target. + per_step_reward: (float) reward/cost of taking a step. + random_respawn: (bool) whether the agent respawns in a random location + upon collecting the goal. + seed: (int or None) seed for random number generator. + """ + self._maze_size = maze_size + self._num_labels = maze_size * maze_size + # The graph itself is the same across episodes, but the node labels will be + # randomly sampled in each episode. + self._graph = nx.grid_2d_graph( + self._maze_size, self._maze_size, periodic=True) + self._max_episode_steps = max_episode_steps + self._target_reward = target_reward + self._per_step_reward = per_step_reward + self._random_respawn = random_respawn + self._rng = np.random.RandomState(seed) + + def _one_hot(self, node): + one_hot_vector = np.zeros([self._num_labels], dtype=np.int32) + one_hot_vector[self._labels[node]] = 1 + return one_hot_vector + + def step(self, action): + # If previous step was the last step of an episode, reset. + if self._needs_reset: + return self.reset() + + # Increment step count and check if it's the last step of the episode. + self._episode_steps += 1 + if self._episode_steps >= self._max_episode_steps: + self._needs_reset = True + transition = dm_env.termination + else: + transition = dm_env.transition + + # Recompute agent's position given the selected action. + direction = self.DIRECTIONS[action] + self._position = tuple( + (np.array(self._position) + np.array(direction)) % self._maze_size) + self._previous_action = self.ACTION_NAMES[action] + + # Get reward if agent is over the goal location and the selected action is + # `collect`. + if self._position == self._goal and self.ACTION_NAMES[action] == 'Collect': + reward = self._target_reward + self._set_new_goal() + else: + reward = self._per_step_reward + self._episode_reward += reward + + return transition(reward, self._observation()) + + def _observation(self): + return { + 'position': np.array(self._one_hot(self.position), dtype=np.int32), + 'goal': np.array(self._one_hot(self.goal), dtype=np.int32), + } + + def observation_spec(self): + return { + 'position': dm_env.specs.Array( + shape=(self._num_labels,), dtype=np.int32, name='position'), + 'goal': dm_env.specs.Array( + shape=(self._num_labels,), dtype=np.int32, name='goal'), + } + + def action_spec(self): + return dm_env.specs.DiscreteArray(self.NUM_ACTIONS) + + def take_random_action(self): + return self.step(self._rng.randint(self.NUM_ACTIONS)) + + def reset(self): + self._previous_action = '' + self._episode_reward = 0. + self._episode_steps = 0 + self._needs_reset = False + random_labels = self._rng.permutation(self._num_labels) + self._labels = {n: random_labels[i] + for i, n in enumerate(self._graph.nodes())} + self._respawn() + self._set_new_goal() + return dm_env.restart(self._observation()) + + def _respawn(self): + random_idx = self._rng.randint(self._num_labels) + self._position = list(self._graph.nodes())[random_idx] + + def _set_new_goal(self): + if self._random_respawn: + self._respawn() + goal = self._position + while goal == self._position: + random_idx = self._rng.randint(self._num_labels) + goal = list(self._graph.nodes())[random_idx] + self._goal = goal + + @property + def position(self): + return self._position + + @property + def goal(self): + return self._goal + + @property + def previous_action(self): + return self._previous_action + + @property + def episode_reward(self): + return self._episode_reward + + def draw_maze(self, ax=None): + if ax is None: + plt.figure() + ax = plt.gca() + node_positions = {(x, y): (x, y) for x, y in self._graph.nodes()} + letters = string.ascii_uppercase + string.ascii_lowercase + labels = {n: letters[self._labels[n]] for n in self._graph.nodes()} + node_list = list(self._graph.nodes()) + colors = [] + for n in node_list: + if n == self.position: + colors.append('lightblue') + elif n == self.goal: + colors.append('lightgreen') + else: + colors.append('pink') + nx.draw(self._graph, pos=node_positions, nodelist=node_list, ax=ax, + node_color=colors, with_labels=True, node_size=200, labels=labels) + ax.set_title('{}\nEpisode reward={:.1f}'.format( + self.previous_action, self.episode_reward)) + ax.margins(.1) + return plt.gcf(), ax diff --git a/rapid_task_solving/one_shot_streetlearn.py b/rapid_task_solving/one_shot_streetlearn.py new file mode 100644 index 0000000..c69825a --- /dev/null +++ b/rapid_task_solving/one_shot_streetlearn.py @@ -0,0 +1,265 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""One-shot StreetLearn environment.""" + +import dm_env +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np + + +def deg_to_rad(x): + """Convert degrees to radians.""" + return x / 180. * np.pi + + +def rad_to_deg(x): + """Convert radians to degrees.""" + return x * 180. / np.pi + + +class OneShotStreetLearn(dm_env.Environment): + """One-shot Streetlearn environment.""" + + ACTION_NAMES = [ + 'Forward', + 'Left', + 'Right', + 'Collect', + ] + NUM_ACTIONS = len(ACTION_NAMES) + + def __init__(self, dataset_path, max_episode_steps, num_junctions=8, + target_reward=1., per_step_reward=0., observation_length=60, + seed=None): + self._graph = nx.read_gexf(dataset_path) + self._node_attrs = self._graph.nodes(data=True) + self._num_junctions = num_junctions + self._observation_length = observation_length + self._max_episode_steps = max_episode_steps + self._target_reward = target_reward + self._per_step_reward = per_step_reward + self._rng = np.random.RandomState(seed) + self.reset() + + def reset(self): + self._previous_action = '' + self._episode_reward = 0. + self._episode_steps = 0 + self._needs_reset = False + self._subgraph = self.get_random_subgraph() + self._observation_map = self.randomize_observations(self._subgraph) + self._position = self._rng.choice(list(self._subgraph.nodes())) + neighbours = self._neighbors_bearings(self._subgraph, self._position) + self._neighbour = neighbours[self._rng.randint(len(neighbours))] + self._set_new_goal() + return dm_env.restart(self._observation()) + + @property + def _current_edge(self): + return (self._position, self._neighbour['neighbour']) + + def _set_new_goal(self): + goal = None + edges = list(self._observation_map.keys()) + while goal is None or goal == self._current_edge: + goal = edges[self._rng.randint(len(edges))] + self._goal = goal + + def _one_hot(self, edge): + one_hot_vector = np.zeros([self._observation_length], dtype=np.int32) + one_hot_vector[self._observation_map[edge]] = 1 + return one_hot_vector + + def _observation(self): + return { + 'position': np.array(self._one_hot(self._current_edge), dtype=np.int32), + 'goal': np.array(self._one_hot(self._goal), dtype=np.int32), + } + + def observation_spec(self): + return { + 'position': dm_env.specs.Array( + shape=(self._observation_length,), dtype=np.int32, name='position'), + 'goal': dm_env.specs.Array( + shape=(self._observation_length,), dtype=np.int32, name='goal'), + } + + def action_spec(self): + return dm_env.specs.DiscreteArray(self.NUM_ACTIONS) + + def step(self, action): + # If previous step was the last step of an episode, reset. + if self._needs_reset: + return self.reset() + + # Increment step count and check if it's the last step of the episode. + self._episode_steps += 1 + if self._episode_steps >= self._max_episode_steps: + self._needs_reset = True + transition = dm_env.termination + else: + transition = dm_env.transition + + # Recompute agent's position + self._move(action) + self._previous_action = self.ACTION_NAMES[action] + + # Get reward if agent is at the goal location and the selected action is + # `collect`. + if (self._current_edge == self._goal and + self.ACTION_NAMES[action] == 'Collect'): + reward = self._target_reward + self._set_new_goal() + else: + reward = self._per_step_reward + self._episode_reward += reward + + return transition(reward, self._observation()) + + def randomize_observations(self, subgraph): + edges = list(subgraph.edges()) + edges.extend([(y, x) for (x, y) in edges]) + obs_permutation = self._rng.permutation(self._observation_length) + return {e: obs_permutation[i] for i, e in enumerate(edges)} + + def _calculate_bearing(self, node, neighbor): + lat1 = deg_to_rad(self._node_attrs[node]['lat']) + lng1 = deg_to_rad(self._node_attrs[node]['lng']) + lat2 = deg_to_rad(self._node_attrs[neighbor]['lat']) + lng2 = deg_to_rad(self._node_attrs[neighbor]['lng']) + delta_lng = lng2 - lng1 + theta = np.arctan2( + np.sin(delta_lng) * np.cos(lat2), + np.cos(lat1) * np.sin(lat2) - + np.sin(lat1) * np.cos(lat2) * np.cos(delta_lng)) + return theta + + def _neighbors_bearings(self, subgraph, node): + bearings = [] + for neighbor in list(subgraph[node]): + orientation = self._calculate_bearing(node, neighbor) + bearings.append({'neighbour': neighbor, 'orientation': orientation}) + bearings.sort(key=lambda x: x['orientation']) + return bearings + + def _sort_neighbors(self, node, neighbour): + bearings = self._neighbors_bearings(self._subgraph, node) + bs = [x['orientation'] for x in bearings] + idx = np.argmin(np.abs(bs - neighbour['orientation'])) + return { + 'forward': bearings[idx], + 'right': bearings[idx-1], + 'left': bearings[(idx+1) % len(bearings)], + } + + def _move(self, action): + neighbours = self._sort_neighbors(self._position, self._neighbour) + if action == 0: + new_node = self._neighbour['neighbour'] + neighbours = self._sort_neighbors(new_node, neighbours['forward']) + new_neighbour = neighbours['forward'] + else: + new_node = self._position + if action == 1: + new_neighbour = neighbours['left'] + elif action == 2: + new_neighbour = neighbours['right'] + else: + new_neighbour = self._neighbour + self._position = new_node + self._neighbour = new_neighbour + + def _all_next_junctions(self, subgraph, node): + neighbors = list(subgraph[node]) + edges = [self._get_next_junction(subgraph, node, nb) for nb in neighbors] + nodes = [y for (_, y) in edges] + return nodes, edges + + def _get_next_junction(self, subgraph, initial_node, next_node): + node = initial_node + while subgraph.degree(next_node) == 2: + neighbours = list(subgraph.neighbors(next_node)) + neighbours.remove(node) + node = next_node + next_node = neighbours.pop() + return (initial_node, next_node) + + def get_random_subgraph(self): + graph = self._graph + num_nodes = len(graph) + rnd_index = self._rng.randint(num_nodes) + center_node = list(graph.nodes())[rnd_index] + while graph.degree(center_node) <= 2: + rnd_index = self._rng.randint(num_nodes) + center_node = list(graph.nodes())[rnd_index] + to_visit = [center_node] + visited = [] + subgraph = nx.Graph() + while to_visit: + node = to_visit.pop(0) + visited.append(node) + new_nodes, new_edges = self._all_next_junctions(graph, node) + subgraph.add_edges_from(new_edges) + node_degrees = [subgraph.degree(n) for n in subgraph.nodes()] + count_junctions = len(list(filter(lambda x: x > 2, node_degrees))) + if count_junctions >= self._num_junctions: + break + new_nodes = filter(lambda x: x not in visited + to_visit, new_nodes) + to_visit.extend(new_nodes) + return subgraph + + def draw_subgraph(self, ax=None): + if ax is None: + _ = plt.figure(figsize=(3, 3)) + ax = plt.gca() + node_ids = list(self._subgraph.nodes()) + pos = { + x: (self._node_attrs[x]['lat'], self._node_attrs[x]['lng']) + for x in node_ids + } + labels = {} + nc = 'pink' + ec = 'black' + ns = 50 + nshape = 'o' + # Draw the current subgraph + nx.draw(self._subgraph, pos=pos, node_color=nc, with_labels=False, + node_size=ns, labels=labels, edgecolors=ec, node_shape=nshape, + ax=ax) + max_xy = np.array([np.array(x) for x in pos.values()]).max(0) + min_xy = np.array([np.array(x) for x in pos.values()]).min(0) + delta_xy = (max_xy - min_xy) / 6. + ax.set_xlim([min_xy[0] - delta_xy[0], max_xy[0] + delta_xy[0]]) + ax.set_ylim([min_xy[1] - delta_xy[1], max_xy[1] + delta_xy[1]]) + # Draw goal position and orientation + x = self._node_attrs[self._goal[0]]['lat'] + y = self._node_attrs[self._goal[0]]['lng'] + rotation = rad_to_deg(self._calculate_bearing(*self._goal)) + _ = ax.plot(x, y, marker=(3, 0, rotation - 90), color=(0, 0, 0), + markersize=14, markerfacecolor='white') + _ = ax.plot(x, y, marker=(2, 0, rotation - 90), color=(0, 0, 0), + markersize=12, markerfacecolor='None') + # Draw current position and orientation + x = self._node_attrs[self._position]['lat'] + y = self._node_attrs[self._position]['lng'] + rotation = rad_to_deg(self._neighbour['orientation']) + _ = ax.plot(x, y, marker=(3, 0, rotation - 90), color=(0, 0, 0), + markersize=14, markerfacecolor='lightgreen') + _ = ax.plot(x, y, marker=(2, 0, rotation - 90), color=(0, 0, 0), + markersize=12, markerfacecolor='None') + ax.set_title('{}\nEpisode reward = {}'.format( + self._previous_action, self._episode_reward)) + return plt.gcf(), ax diff --git a/rapid_task_solving/requirements.txt b/rapid_task_solving/requirements.txt new file mode 100644 index 0000000..2f8c891 --- /dev/null +++ b/rapid_task_solving/requirements.txt @@ -0,0 +1,6 @@ +dm-env>=1.2 +dm-haiku>=0.0.3 +jax>=0.2.8 +matplotlib>=3.1.2 +networkx>=2.3 +numpy>=1.18.0 diff --git a/wikigraphs/README.md b/wikigraphs/README.md new file mode 100644 index 0000000..71e378c --- /dev/null +++ b/wikigraphs/README.md @@ -0,0 +1,230 @@ +# WikiGraphs + +This package provides tools to download the [WikiGraphs dataset](https://www.aclweb.org/anthology/2021.textgraphs-1.7.pdf) +[1], collected by pairing each Wikipedia article from [WikiText-103](https://arxiv.org/pdf/1609.07843.pdf) +[2] with a knowledge graph (a subgraph from [Freebase knowledge graph](https://dl.acm.org/doi/pdf/10.1145/1376616.1376746?casa_token=H2ggPTDMoZUAAAAA:7wBhO9hnOzNKoJyMH0PcpVQZ6Vg6Ud6hObiDJTzLCGRiBwmYFjOFSXrG5PcKLStu5-n4_OfkPJtbisQ) +[3]). The baseline code to reproduce results in [1] is included as well. We hope +this can spur more interest in developing models that can generate long text +conditioned on graph and generate graphs given text. + +## Setup Jax environment + +[Jax](https://github.com/google/jax#installation), +[Haiku](https://github.com/deepmind/dm-haiku#installation), and +[Optax](https://github.com/deepmind/dm-haiku#installation) are needed for this +package. It has been developed and tested on python 3 with the following +packages: + +* Jax==0.2.13 +* Haiku==0.0.5 +* Optax==0.0.6 + +Other packages required can be installed via: + +```bash +pip install -r requirements.txt +``` + +Note: you may need to use `pip3` to select pip for python 3 and `--user` option +to install the packages to avoid permission issues. + +## Installation + +```bash +pip install -e . +``` + +## Preparing the data + +### Download the data + +You can download and unzip the data by running the following command: + +```bash +bash scripts/download.sh +``` + +This will put the downloaded WikiText-103 data in a temporary directory +`/tmp/data` with the tokenized WikiText-103 data in `/tmp/data/wikitext-103` and +the raw data in `/tmp/data/wikitext-103-raw`. + +This script will also download our processed Freebase knowledge graph data in a +temporary directory `/tmp/data/freebase`. + +### Build vocabularies + +For WikiText-103, run the following command to generate a vocabulary file: + +```bash +python scripts/build_vocab.py \ + --vocab_file_path=/tmp/data/wikitext-vocab.csv \ + --data_dir=/tmp/data/wikitext-103 +``` + +You can change the default file paths but make sure you make them consistent. + +### Pair Freebase graphs with WikiText + +You can run the following command to pair the Freebase graphs with WikiText-103 +articles: + +```bash +python scripts/freebase_preprocess.py \ + --freebase_dir=/tmp/data/freebase/max256 \ + --output_dir=/tmp/data/wikigraphs/max256 +``` + +where the `freebase_dir` `/tmp/data/freebase/max256` is the directory that +contains the Fsreebase graphs, which should have files `train.gz`, `valid.gz` +and `test.gz` in it; and `output_dir` is the directory that will contain the +generated paired Freebase-WikiText data. + +Note: you may need to use `python3` to select python 3 if you have both python 2 +and 3 on your system. + +Given that there are the following number of articles in WikiText-103: + +Subset | #articles +------ | --------- +Train | 28472* +Valid | 60 +Test | 60 + +*Official number is 28475 but we were only able to find 28472 articles in +training set. + +Our dataset covers around 80% of the WikiText articles: + +Max graph size | 256 | 512 | 1024 +---------------------------- | ----- | ----- | ----- +\#articles in training set | 23431 | 23718 | 23760 +Trainining set coverage | 82.3% | 83.3% | 83.5% +\#articles in validation set | 48 | 48 | 48 +Validation set coverage | 80% | 80% | 80% +\#articles in test set | 43 | 43 | 43 +Test set coverage | 71.7% | 71.7% | 71.7% + +### Build vocabulary for WikiGraphs + +You can build the vocabulary for the graph data (the max256 version) by running +the following command: + +```bash +python scripts/build_vocab.py \ + --vocab_file_path=/tmp/data/graph-vocab.csv \ + --data_dir=/tmp/data/wikigraphs \ + --version=max256 \ + --data_type=graph \ + --threshold=15 +``` + +This gives us a vocabulary of size 31,087, with each token included in the +vocabulary appearing at least 15 times. + +You also need to build a separate text vocabulary for the WikiGraphs data, as +our training set does not cover 100% of WikiText-103. + +```bash +python scripts/build_vocab.py \ + --vocab_file_path=/tmp/data/text-vocab.csv \ + --data_dir=/tmp/data/wikigraphs \ + --version=max256 \ + --data_type=text \ + --threshold=3 +``` + +Here we choose threshold 3 which is also used by the original WikiText-103 data, +this gives us a vocabulary size of 238,068, only slightly smaller than the +original vocabulary size. + +Note that when loading these vocabularies to build tokenizers, our tokenizers +will add a few extra tokens, like ``, ``, so the final vocab size +might be slightly different from the numbers above, depending on which tokenizer +you choose to use. + +We only showcase how to build the vocabulary for the max256 version. The above +steps can be easily changed for the max512 and max1024 version. + +## Loading the dataset + +We provide JAX modules to load the WikiGraphs dataset. There are three classes +in `wikigraphs/data/paired_dataset.py`: + +* `TextOnlyDataset`: loads only the text part of the WikiGraphs data +* `Bow2TextDataset`: loads text and the paired graph representated as one big +bag-of-words (BoW) on all nodes and edges from the graph +* `Graph2TextDataset`: returns text and the paired graph in which each node or +edge is represented by a BoW + +Different versions of the dataset can be accessed by changing the `version` +argument in each class. For more detailed usage please refer to +`wikigraphs/data/paired_dataset_test.py`. Besides, the original WikiText dataset +can be loaded via the `Dataset` class in `wikigraphs/data/wikitext.py`. + +Note: you may want to change the default data directory if you prefer to place +it elsewhere. + +## Run baseline models + +Note: baseline models will be available soon. + +## Citing WikiGraphs + +To cite this work: + +``` +@inproceedings{wang2021wikigraphs, + title={WikiGraphs: A Wikipedia Text-Knowledge Graph Paired Dataset}, + author={Wang, Luyu and Li, Yujia and Aslan, Ozlem and Vinyals, Oriol}, + booktitle={Proceedings of the Graph-Based Methods for Natural Language Processing (TextGraphs)}, + pages={67--82}, + year={2021} +} +``` + +## License + +All code copyright 2021 DeepMind Technologies Limited + +Code is licensed under the Apache License, Version 2.0 (the "License"); you may +not use this file except in compliance with the License. You may obtain a copy +of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed +under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. + +[WikiGraphs](https://www.aclweb.org/anthology/2021.textgraphs-1.7.pdf) +[1] is licensed under the terms of the Creative Commons +Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. + +[WikiText-103 data](https://arxiv.org/pdf/1609.07843.pdf) [2] (unchanged) is +licensed by Salesforce.com, Inc. under the terms of the Creative Commons +Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. You can find +details about CC BY-SA 4.0 at: + + https://creativecommons.org/licenses/by-sa/4.0/legalcode + +[Freebase data](https://dl.acm.org/doi/pdf/10.1145/1376616.1376746?casa_token=H2ggPTDMoZUAAAAA:7wBhO9hnOzNKoJyMH0PcpVQZ6Vg6Ud6hObiDJTzLCGRiBwmYFjOFSXrG5PcKLStu5-n4_OfkPJtbisQ) +[3] is licensed by Google LLC under the terms of the Creative +Commons CC BY 4.0 license. You may obtain a copy of the License at: + + https://creativecommons.org/licenses/by/4.0/legalcode + +## References + +1. L. Wang, Y. Li, O. Aslan, and O. Vinyals, "[WikiGraphs: a wikipedia - +knowledge graph paired dataset](https://www.aclweb.org/anthology/2021.textgraphs-1.7.pdf)", +in Proceedings of the Graph-based Methods for Natural Language Processing +(TextGraphs), pages 67-82, 2021. +2. S. Merity, C. Xiong, J. Bradbury, and R. Socher, "[Pointer sentinel mixture +models](https://arxiv.org/pdf/1609.07843.pdf)", +arXiv: 1609.07843, 2016. +3. K. Bollacker, C. Evans, P. Paritosh, T. Sturge, and J. Taylor, +"[Freebase: a collaboratively created graph database for structuring human +knowledge](https://dl.acm.org/doi/pdf/10.1145/1376616.1376746?casa_token=H2ggPTDMoZUAAAAA:7wBhO9hnOzNKoJyMH0PcpVQZ6Vg6Ud6hObiDJTzLCGRiBwmYFjOFSXrG5PcKLStu5-n4_OfkPJtbisQ)", +in Proceedings of ACM SIGMOD international conference on Managementof data, +pages 1247–1250, 2008. diff --git a/wikigraphs/requirements.txt b/wikigraphs/requirements.txt new file mode 100644 index 0000000..f4a6b0c --- /dev/null +++ b/wikigraphs/requirements.txt @@ -0,0 +1,7 @@ +absl-py==0.10.0 +dm-haiku +jax>=0.2.13 +nltk>=3.6.2 +numpy>=1.19.5 +optax>=0.0.6 +scikit-learn>=0.24.2 diff --git a/wikigraphs/scripts/build_vocab.py b/wikigraphs/scripts/build_vocab.py new file mode 100644 index 0000000..2d2d819 --- /dev/null +++ b/wikigraphs/scripts/build_vocab.py @@ -0,0 +1,166 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Script for building vocabulary files for datasets.""" + +import collections +import csv +import enum +import io +import os +from typing import List, Tuple + +from absl import app +from absl import flags +from absl import logging + +from wikigraphs.data import io_tools +from wikigraphs.data import paired_dataset +from wikigraphs.data import tokenizers +from wikigraphs.data import wikitext + + +class DatasetType(enum.Enum): + text = 1 + graph = 2 + wikitext = 3 + + +FLAGS = flags.FLAGS +flags.DEFINE_string('data_dir', '', 'Path to the directory that contains the' + ' unzipped wikitext-103 data.') +flags.DEFINE_string('vocab_file_path', '', 'Path to the output vocab file.') +flags.DEFINE_enum_class('data_type', DatasetType.wikitext, DatasetType, + 'One of {`wikitext`, `graph`, `text`}.') +flags.DEFINE_integer('threshold', 1, 'Frequency threshold for a word to be' + ' included in the vocabulary.') +flags.DEFINE_string('version', 'max256', 'Which version of paired data to use.') + + +def get_vocab(dataset: wikitext.RawDataset) -> List[Tuple[str, int]]: + """Build vocabulary, return (word, count) tuples sorted by count.""" + vocab = collections.defaultdict(int) + + for pair in dataset: + for t in pair.text.split(' '): + if t: + vocab[t] += 1 + + return sorted(vocab.items(), key=lambda t: -t[1]) + + +def write_vocab(vocab: List[Tuple[str, int]], output_path: str): + """Write a vocab list to a file.""" + output_dir = os.path.dirname(output_path) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(output_path, mode='wb') as f_: + with io.TextIOWrapper(f_, encoding='utf-8') as f: + w = csv.writer(f) + w.writerows(vocab) + + +def build_wikitext_vocab(): + logging.info('Loading the dataset.') + dataset = wikitext.RawDataset(subset='train', data_dir=FLAGS.data_dir) + logging.info('Building the vocab.') + vocab = get_vocab(dataset) + logging.info('Finished, vocab size %d, total number of tokens %d', + len(vocab), sum([c for _, c in vocab])) + logging.info('Writing the vocab to %s', FLAGS.vocab_file_path) + write_vocab(vocab, FLAGS.vocab_file_path) + + +def build_graph_vocab(): + """Build vocabulary for graph data.""" + logging.info('Loading the dataset.') + dataset = paired_dataset.ParsedDataset( + subset='train', data_dir=FLAGS.data_dir, version=FLAGS.version) + logging.info('Building graph vocab.') + + vocab = collections.defaultdict(int) + for pair in dataset: + graph = pair.graph + for n in graph.nodes(): + for t in tokenizers.GraphTokenizer.split_node(n): + if t: + vocab[t] += 1 + for _, _, e in graph.edges(): + for t in tokenizers.GraphTokenizer.split_edge(e): + if t: + vocab[t] += 1 + + vocab = sorted(vocab.items(), key=lambda t: -t[1]) + vocab = [k for k, v in vocab if v >= FLAGS.threshold] + + logging.info('Finished, vocab size %d.', len(vocab)) + logging.info('Writing the vocab to %s.', FLAGS.vocab_file_path) + + io_tools.write_txt_file(FLAGS.vocab_file_path, '\n'.join(vocab), + # Some unicode characters requires utf-16 to encode. + encoding='utf-16') + + +def build_text_vocab(): + """Build vocabulary for the text part of the graph-to-text data.""" + logging.info('Loading the dataset.') + dataset = paired_dataset.ParsedDataset( + subset='train', data_dir=FLAGS.data_dir, version=FLAGS.version) + logging.info('Building text vocab.') + + vocab = collections.defaultdict(int) + for pair in dataset: + for t in pair.text.split(' '): + if t: + vocab[t] += 1 + + vocab = sorted(vocab.items(), key=lambda t: -t[1]) + logging.info('Finished, vocab size %d, total number of tokens %d.', + len(vocab), sum([v for _, v in vocab])) + vocab = [(k, v) for k, v in vocab if v >= FLAGS.threshold] + logging.info('After filtering, vocab size %d.', len(vocab)) + logging.info('Writing the vocab to %s.', FLAGS.vocab_file_path) + + write_vocab(vocab, FLAGS.vocab_file_path) + + +def main(_): + if FLAGS.data_type == DatasetType.wikitext: + build_wikitext_vocab() + elif FLAGS.data_type == DatasetType.text: + build_text_vocab() + elif FLAGS.data_type == DatasetType.graph: + build_graph_vocab() + else: + raise ValueError(f'Unknown data type {FLAGS.data_type}.') + + +if __name__ == '__main__': + app.run(main) diff --git a/wikigraphs/scripts/download.sh b/wikigraphs/scripts/download.sh new file mode 100644 index 0000000..ac11ddd --- /dev/null +++ b/wikigraphs/scripts/download.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +BASE_DIR=/tmp/data + +# wikitext-103 +TARGET_DIR=${BASE_DIR}/wikitext-103 +mkdir -p ${TARGET_DIR} +wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip -P ${TARGET_DIR} +unzip ${TARGET_DIR}/wikitext-103-v1.zip -d ${TARGET_DIR} +mv ${TARGET_DIR}/wikitext-103/* ${TARGET_DIR} +rm -rf ${TARGET_DIR}/wikitext-103 ${TARGET_DIR}/wikitext-103-v1.zip + +# wikitext-103-raw +TARGET_DIR=${BASE_DIR}/wikitext-103-raw +mkdir -p ${TARGET_DIR} +wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip -P ${TARGET_DIR} +unzip ${TARGET_DIR}/wikitext-103-raw-v1.zip -d ${TARGET_DIR} +mv ${TARGET_DIR}/wikitext-103-raw/* ${TARGET_DIR} +rm -rf ${TARGET_DIR}/wikitext-103-raw ${TARGET_DIR}/wikitext-103-raw-v1.zip + + +# processed freebase graphs +FREEBASE_TARGET_DIR=/tmp/data +mkdir -p ${FREEBASE_TARGET_DIR}/packaged/ +wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1uuSS2o72dUCJrcLff6NBiLJuTgSU-uRo' -O ${FREEBASE_TARGET_DIR}/packaged/max256.tar +wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1nOfUq3RUoPEWNZa2QHXl2q-1gA5F6kYh' -O ${FREEBASE_TARGET_DIR}/packaged/max512.tar +wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1uuJwkocJXG1UcQ-RCH3JU96VsDvi7UD2' -O ${FREEBASE_TARGET_DIR}/packaged/max1024.tar + +for version in max1024 max512 max256 +do + output_dir=${FREEBASE_TARGET_DIR}/freebase/${version}/ + mkdir -p ${output_dir} + tar -xvf ${FREEBASE_TARGET_DIR}/packaged/${version}.tar -C ${output_dir} +done +rm -rf ${FREEBASE_TARGET_DIR}/packaged diff --git a/wikigraphs/scripts/freebase_preprocess.py b/wikigraphs/scripts/freebase_preprocess.py new file mode 100644 index 0000000..cb29f8f --- /dev/null +++ b/wikigraphs/scripts/freebase_preprocess.py @@ -0,0 +1,106 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Preprocess freebase data and pair with wikitext.""" + +import os + +from absl import app +from absl import flags +from absl import logging + +from wikigraphs.data import io_tools +from wikigraphs.data import wikitext + + +FLAGS = flags.FLAGS +flags.DEFINE_string('freebase_dir', '', 'Directory that containns Freebase' + ' graphs.') +flags.DEFINE_string('output_dir', '', 'Path to output directory to store the' + ' paired dataset.') + + +def pair_graphs_with_wikitext(subset: str, graph_dir: str, output_dir: str): + """Pair graphs with wikitext articles, and write to output directory.""" + logging.info('Pairing graphs from the %s set from %s with wikitext.', + subset, graph_dir) + graphs = list(io_tools.graphs_from_file( + os.path.join(graph_dir, f'{subset}.gz'))) + title2graph = { + io_tools.normalize_freebase_string(g.title).replace(' ', ''): g + for g in graphs} + n_graphs = len(graphs) + + # Use raw version of the wikitext data as the tokenized version has in + # titles which is bad for matching. We will handle the s through the + # tokenizer to make sure our data are equivalent to that of the tokenized + # version of wikitext-103. + wikitext_articles = list(wikitext.RawDataset(subset=subset, version='raw')) + n_wiki = len(wikitext_articles) + logging.info('Loaded %d graphs and %d wikitext articles in total.', + n_graphs, n_wiki) + + # Keep track of the article titles in the dataset. Unfortunately wikitext-103 + # has about 1% of duplicated articles, we want to take care of that. + retrieved_titles = set() + pairs = [] + n_duplicates = 0 + for a in wikitext_articles: + title = wikitext.normalize_title(a.title).replace(' ', '') + g = title2graph.get(title, None) + if g is not None: + if title not in retrieved_titles: + retrieved_titles.add(title) + pairs.append(io_tools.GraphTextPair( + center_node=g.center, + title=g.title, + edges=g.edges, + text=a.text)) + else: + n_duplicates += 1 + + n_pairs = len(pairs) + logging.info('Matched %d/%d = %.1f%% of wikitext articles,' + ' and %d/%d = %.1f%% of graphs.', + n_pairs, n_wiki, float(n_pairs) / n_wiki * 100, + n_pairs, n_graphs, float(n_pairs) / n_graphs * 100) + logging.info('Detected %d/%d = %.1f%% of duplicated wikitext articles.', + n_duplicates, n_wiki, float(n_duplicates) / n_wiki * 100) + + io_tools.write_pairs_to_gzip_txt_file( + os.path.join(output_dir, f'{subset}.gz'), pairs) + + +def main(_): + for subset in ['train', 'valid', 'test']: + pair_graphs_with_wikitext(subset, FLAGS.freebase_dir, FLAGS.output_dir) + + +if __name__ == '__main__': + app.run(main) diff --git a/wikigraphs/scripts/visualize_graph.py b/wikigraphs/scripts/visualize_graph.py new file mode 100644 index 0000000..b982852 --- /dev/null +++ b/wikigraphs/scripts/visualize_graph.py @@ -0,0 +1,143 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +r"""Tool to visualize graphs. + +You need to have the command line tool `dot` installed locally, for example by +`sudo apt-get install graphviz`. + +Example usage: +python visualize_graph.py \ + --logtostderr --graph_ids=0:48 --truncate_limit=500 --layout=fdp +""" + +import html +import os +import textwrap + +from absl import app +from absl import flags +from absl import logging + +from wikigraphs.data import io_tools +from wikigraphs.data import paired_dataset as pd + + +FLAGS = flags.FLAGS +flags.DEFINE_string('subset', 'valid', 'Which subset to choose graphs from.') +flags.DEFINE_string('graph_ids', '', 'A comma-separated string of graph IDs' + ' (0-based), for example `1,2,3`. Or alternatively a' + ' range, e.g. `0:10` which is equivalent to' + ' `0,1,2,3,...,9`.') +flags.DEFINE_string('version', 'max256', 'Which version of data to load.') +flags.DEFINE_string('data_dir', '', 'Path to a directory that contains the raw' + ' paired data, if provided.') +flags.DEFINE_string('output_dir', '/tmp/graph_vis', 'Output directory to save' + ' the visualized graphs.') +flags.DEFINE_integer('truncate_limit', -1, 'Maximum length for graph nodes in' + ' visualization.') +flags.DEFINE_string('layout', 'fdp', 'Which one of the dot layout to use.') + + +def truncate(s: str) -> str: + if FLAGS.truncate_limit > 0 and len(s) > FLAGS.truncate_limit: + s = s[:FLAGS.truncate_limit] + '...' + return s + + +def format_label(s: str, width: int = 40) -> str: + """Format a node / edge label.""" + s = io_tools.normalize_freebase_string(s) + s = truncate(s) + lines = s.split('\\n') + output_lines = [] + for line in lines: + line = html.escape(line) + if width > 0: + output_lines += textwrap.wrap(line, width) + else: + output_lines.append(line) + return '<' + '
'.join(output_lines) + '>' + + +def graph_to_dot(graph_text_pair: io_tools.GraphTextPair) -> str: + """Convert a graph to a dot file.""" + dot = ['digraph {', 'node [shape=rect];'] + graph = pd.Graph.from_edges(graph_text_pair.edges) + center_node_id = graph.node2id(graph_text_pair.center_node) + + for i, n in enumerate(graph.nodes()): + color = '#f5dc98' if i == center_node_id else ( + '#b0ffad' if not(n[0] == '"' and n[-1] == '"') else '#ffffff') + label = format_label(n) + dot.append(f'{i} [ label = {label}, fillcolor="{color}", style="filled"];') + + for i, j, e in graph.edges(): + dot.append(f'{i} -> {j} [ label = {format_label(e, width=0)} ];') + dot.append('}') + return '\n'.join(dot) + + +def visualize_graph(graph_text_pair: io_tools.GraphTextPair, + graph_id: int, + output_dir: str): + """Visualize a graph and save the visualization to the specified directory.""" + dot = graph_to_dot(graph_text_pair) + output_file = os.path.join(output_dir, f'{graph_id}.dot') + logging.info('Writing output to %s', output_file) + with open(output_file, 'w') as f: + f.write(dot) + pdf_output = os.path.join(output_dir, f'{graph_id}.pdf') + os.system(f'dot -K{FLAGS.layout} -Tpdf -o {pdf_output} {output_file}') + + +def main(_): + logging.info('Loading the %s set of data.', FLAGS.subset) + pairs = list(pd.RawDataset(subset=FLAGS.subset, + data_dir=FLAGS.data_dir or None, + shuffle_data=False, + version=FLAGS.version)) + logging.info('Loaded %d graph-text pairs.') + + if ':' in FLAGS.graph_ids: + start, end = [int(i) for i in FLAGS.graph_ids.split(':')] + graph_ids = list(range(start, end)) + else: + graph_ids = [int(i) for i in FLAGS.graph_ids.split(',')] + logging.info('Visualizing graphs with ID %r', graph_ids) + + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + + for gid in graph_ids: + visualize_graph(pairs[gid], gid, FLAGS.output_dir) + + +if __name__ == '__main__': + app.run(main) diff --git a/wikigraphs/setup.py b/wikigraphs/setup.py new file mode 100644 index 0000000..680bdfc --- /dev/null +++ b/wikigraphs/setup.py @@ -0,0 +1,43 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Setup for pip package.""" +from setuptools import find_packages +from setuptools import setup + +setup( + name='wikigraphs', + version='0.0.1', + description='A Wikipedia - knowledge graph paired dataset.', + url='https://github.com/deepmind/deepmind-research/tree/master/wikigraphs', + author='DeepMind', + author_email='luyuwang@google.com', + packages=find_packages(), + license='Apache 2.0', +) diff --git a/wikigraphs/wikigraphs/data/__init__.py b/wikigraphs/wikigraphs/data/__init__.py new file mode 100644 index 0000000..8cbeba0 --- /dev/null +++ b/wikigraphs/wikigraphs/data/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""WikiGraphs data modules.""" +from . import dataset +from . import io_tools +from . import paired_dataset +from . import tokenizers +from . import tools +from . import wikitext diff --git a/wikigraphs/wikigraphs/data/dataset.py b/wikigraphs/wikigraphs/data/dataset.py new file mode 100644 index 0000000..e34f523 --- /dev/null +++ b/wikigraphs/wikigraphs/data/dataset.py @@ -0,0 +1,59 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Base class of the datasets.""" + +import abc +from typing import Any, Iterator + + +class Dataset(abc.ABC): + """Base class for all datasets. + + All sub-classes should define `_load_data()` where an iterator + `self._data_iter` should be instantiated that iterates over the dataset. + """ + + def __init__(self): + """Constructor.""" + self._data_iter = None # An iterator produced by `self._load_data`. + + @abc.abstractmethod + def _load_data(self) -> Iterator[Any]: + """Prepare data for another pass through the dataset. + + This method should return a generator in a child class. + """ + + def __next__(self): + return next(self._data_iter) + + def __iter__(self): + self._data_iter = self._load_data() + return self diff --git a/wikigraphs/wikigraphs/data/io_tools.py b/wikigraphs/wikigraphs/data/io_tools.py new file mode 100644 index 0000000..23abb02 --- /dev/null +++ b/wikigraphs/wikigraphs/data/io_tools.py @@ -0,0 +1,179 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Some tools for I/O.""" + +import gzip +import io +import os +import re +from typing import NamedTuple, List, Iterator + +from absl import logging + + +def read_txt_file(file_path: str, encoding: str = 'utf-8') -> str: + """Read a plain txt file.""" + with open(file_path, 'rb') as f: + content = f.read() + return content.decode(encoding) + + +def write_txt_file(file_path: str, txt: str, encoding: str = 'utf-8'): + """Write the given txt string to file.""" + make_dir_if_necessary(file_path) + with open(file_path, 'wb') as f: + f.write(txt.encode(encoding, 'surrogatepass')) + + +def read_gzip_txt_file(file_path: str, encoding: str = 'utf-8') -> str: + """Read gzipped txt file.""" + with open(file_path, 'rb') as f: + content = f.read() + + with gzip.GzipFile(fileobj=io.BytesIO(content), mode='rb') as f: + content = f.read() + return content.decode(encoding) + + +def make_dir_if_necessary(output_path): + output_dir = os.path.dirname(output_path) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + +def write_lines_to_gzipped_file(file_path, lines): + make_dir_if_necessary(file_path) + with open(file_path, 'wb') as f_zip: + with gzip.GzipFile(fileobj=f_zip, mode='wb') as f: + f.write('\n'.join(lines).encode('utf-8')) + + +class Graph(NamedTuple): + title: str + center: str + edges: List[str] + + +def graphs_from_file(file_path: str) -> Iterator[Graph]: + """Read freebase graphs from file. + + Args: + file_path: path to the input `.gz` file that contains a list of graphs. + + Yields: + graphs: a list of read from the file. + """ + content = read_gzip_txt_file(file_path) + + graph_header_sep_re = re.compile( + r'(\n)') + graph_header_re = re.compile( + r'\n') + parts = graph_header_sep_re.split(content) + + # Skip the first part which is empty + for i in range(1, len(parts), 2): + header, body = parts[i], parts[i + 1] + m = graph_header_re.match(header) + yield Graph(title=m.group(2), + center=m.group(1), + edges=body.strip().split('\n')) + + +_UNICODE_RE = re.compile(r'(\$[0-9A-Fa-f]{4})') + + +def normalize_freebase_string(s: str) -> str: + """Expand the `$xxxx` escaped unicode characters in the input string.""" + # '"' is escaped as '``', convert it back. + s.replace('``', '"') + parts = _UNICODE_RE.split(s) + parts = [p if not _UNICODE_RE.match(p) else chr(int(p[1:], base=16)) + for p in parts] + return ''.join(parts).replace('_', ' ') + + +class GraphTextPair(NamedTuple): + """Text paired with raw graph represented as in `edges`.""" + center_node: str + title: str + edges: List[str] + text: str + + +def pair2lines(pair): + lines = [f''] + lines.append('
') + lines.append(pair.text) + lines.append('
') + lines.extend(pair.edges) + return lines + + +def write_pairs_to_gzip_txt_file(file_path, pairs): + logging.info('Writing %d pairs to %s.', len(pairs), file_path) + lines = [] + for p in pairs: + lines.extend(pair2lines(p)) + write_lines_to_gzipped_file(file_path, lines) + + +def read_pairs_from_gzip_txt_file(file_path: str) -> Iterator[GraphTextPair]: + """Read graph-text pairs from gzip txt files. + + Args: + file_path: a `.gz` file of graph-text pairs written in the same format as + using the `write_pairs_to_gzip_txt_file` function. + + Yields: + Graph-text pairs from this file. + """ + content = read_gzip_txt_file(file_path) + + graph_header_sep_re = re.compile( + r'()') + graph_header_re = re.compile( + r'$') + section_sep_re = re.compile(r'\n(
\n)') + parts = graph_header_sep_re.split(content) + + # Skip the first part which is empty + for i in range(1, len(parts), 2): + header, body = parts[i], parts[i + 1] + m = graph_header_re.match(header) + + # 5 parts total, empty first part, "text", text section, "edges", edges + # section. + section_parts = section_sep_re.split(body) + + yield GraphTextPair(center_node=m.group(1), + title=m.group(2), + text=section_parts[2], + edges=section_parts[-1].strip().split('\n')) diff --git a/wikigraphs/wikigraphs/data/paired_dataset.py b/wikigraphs/wikigraphs/data/paired_dataset.py new file mode 100644 index 0000000..6987bfe --- /dev/null +++ b/wikigraphs/wikigraphs/data/paired_dataset.py @@ -0,0 +1,767 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Tools for accessing the graph-text paired datasets.""" + +import abc +import collections +from typing import List, Tuple, NamedTuple, Any, Dict, Optional, Union + +from absl import logging +import jax.numpy as jnp +import numpy as np + +from wikigraphs.data import dataset +from wikigraphs.data import io_tools +from wikigraphs.data import tokenizers +from wikigraphs.data import tools + + +ArrayType = Union[np.ndarray, jnp.ndarray] +DATA_ROOT = '/tmp/data/wikigraphs' + + +class RawDataset(dataset.Dataset): + """The untokenized raw dataset.""" + + def __init__(self, + subset: str = 'train', + shuffle_data: bool = False, + data_dir: str = None, + version: str = 'max256'): + """Constructor. + + Args: + subset: which subset to load. + shuffle_data: set to True to randomly shuffle the data. + data_dir: if provided this will be used instead of the default location to + look for data, it must contain files like `train.gz`, `valid.gz` and + `test.gz`. + version: which version of the data to load, this must be the name of a + directory in `DATA_ROOT`. + """ + super().__init__() + self._subset = subset + self._shuffle_data = shuffle_data + self._data_dir = data_dir or DATA_ROOT + self._dataset = None + + allowed_versions = ('max256', 'max512', 'max1024') + if version not in allowed_versions: + raise ValueError(f'Version {version} not one of the allowed versions:' + f' {allowed_versions}.') + + self._version = version + + def _load_data(self): + """Load and prepare the data iterator.""" + if self._dataset is None: + self._dataset = list(io_tools.read_pairs_from_gzip_txt_file( + f'{self._data_dir}/{self._version}/{self._subset}.gz')) + + def source(): + n_pairs = len(self._dataset) + if self._shuffle_data: + idx = np.random.permutation(n_pairs) + else: + idx = np.arange(n_pairs) + for i in range(n_pairs): + yield self._dataset[idx[i]] + + return source() + + +class Graph: + """A convenience class for representing graphs.""" + + def __init__(self, nodes: List[str], edges: List[Tuple[int, int, str]]): + """Construct a graph from a list of nodes and edges. + + Args: + nodes: a list of node attributes, one for each node. + edges: a list of (source_node_id, target_node_id, edge_attribute) for each + edge. + """ + self._nodes = nodes + self._edges = edges + self._node2id = {n: i for i, n in enumerate(nodes)} + + def nodes(self) -> List[str]: + return self._nodes + + def edges(self) -> List[Tuple[int, int, str]]: + return self._edges + + def node2id(self, node: str) -> int: + return self._node2id[node] + + @classmethod + def from_edges(cls, edges: List[str]) -> 'Graph': + """Build a graph instance from a list of edges.""" + node2id = dict() + parsed_edges = [] + next_node_id = 0 + + for e in edges: + src, edge, tgt = e.split('\t')[:3] + src_id = node2id.get(src, next_node_id) + if src_id == next_node_id: + node2id[src] = src_id + next_node_id += 1 + tgt_id = node2id.get(tgt, next_node_id) + if tgt_id == next_node_id: + node2id[tgt] = tgt_id + next_node_id += 1 + parsed_edges.append((src_id, tgt_id, edge)) + + id2node = {i: n for n, i in node2id.items()} + return Graph(nodes=[id2node[i] for i in range(next_node_id)], + edges=parsed_edges) + + def to_edges(self) -> List[str]: + r"""Convert graph to a list of edges. + + The converted list of edges should be compatible with the format specified + in io_tools and compatible with the `from_edges` method above. + + Returns: + edges: one edge per line, with the (source, target, edge_type) separated + by `\t`. + """ + edges = [] + for s, t, e in self._edges: + edges.append(f'{self._nodes[s]}\t{e}\t{self._nodes[t]}') + return edges + + @classmethod + def subsample_nodes( + cls, graph: 'Graph', subsample_rate: float = 1.0, center_node: str = None + ) -> 'Graph': + """Subsample the nodes of a graph.""" + graph_size = len(graph.nodes()) + if subsample_rate == 1.0 or graph_size <= 1: + return graph + subsampled_nodes_id = np.arange(graph_size) + if subsample_rate < 1.0: + subsample_graph_size = int(subsample_rate * graph_size) + if center_node is not None: + # We need to keep the center node during subsampling + center_node_id = graph.node2id(center_node) + subsampled_nodes_id = subsampled_nodes_id[ + subsampled_nodes_id != center_node_id] + subsample_graph_size = max(1, subsample_graph_size - 1) + subsampled_nodes_id = np.random.choice( + subsampled_nodes_id, subsample_graph_size, replace=False) + subsampled_nodes_id = np.append(subsampled_nodes_id, center_node_id) + else: + subsampled_nodes_id = np.random.choice( + subsampled_nodes_id, subsample_graph_size, replace=False) + subsampled_nodes_id = np.sort(subsampled_nodes_id) + map_subsampled_nodes_id = { + old_id: new_id for new_id, old_id in enumerate(subsampled_nodes_id)} + nodes = [] + edges = [] + for node_id, n in enumerate(graph.nodes()): + if node_id in subsampled_nodes_id: + nodes.append(n) + for out_node, in_node, e in graph.edges(): + if out_node in subsampled_nodes_id and in_node in subsampled_nodes_id: + edges.append((map_subsampled_nodes_id[out_node], + map_subsampled_nodes_id[in_node], e)) + return Graph(nodes=nodes, edges=edges) + + +class ParsedGraphTextPair(NamedTuple): + """Graph-text pair with graph parsed into a `Graph` instance.""" + center_node: str + title: str + text: str + graph: Graph + + +class ParsedGraph(NamedTuple): + """Data structure for representing a batch of graphs.""" + # [n_nodes, node_feat_dim] tensor. + nodes: ArrayType + # [n_edges, edge_feat_dim] tensor. + edges: ArrayType + # [n_edges] int tensor, index of the sender nodes for each edge. + sender: ArrayType + # [n_edges] int tensor, index of the receiver nodes for each edge. + receiver: ArrayType + # [n_graphs] int tensor, graph size (num nodes) for each graph in the batch. + graph_sizes: ArrayType + + +def batch_graphs(graphs: List[ParsedGraph]) -> ParsedGraph: + """Batch a list of graphs into a single graph. + + This method also updates the sender and receiver node indices as well as the + graph_sizes. + + Args: + graphs: a list of graphs to be merged. + + Returns: + graph: a single merged graph. + """ + for g in graphs: + if g.graph_sizes.size != 1: + raise ValueError('Each individual element in the list must contain only a' + ' single graph.') + node_idx_start = 0 + nodes = [] + edges = [] + sender = [] + receiver = [] + graph_sizes = [] + for g in graphs: + nodes.append(g.nodes) + edges.append(g.edges) + sender.append(g.sender + node_idx_start) + receiver.append(g.receiver + node_idx_start) + graph_sizes.append(g.graph_sizes) + node_idx_start += g.graph_sizes[0] + + return ParsedGraph(nodes=np.concatenate(nodes, axis=0), + edges=np.concatenate(edges, axis=0), + sender=np.concatenate(sender, axis=0), + receiver=np.concatenate(receiver, axis=0), + graph_sizes=np.concatenate(graph_sizes, axis=0)) + + +class ParsedDataset(dataset.Dataset): + """Raw dataset + parsing graphs into Graph instances.""" + + def __init__(self, + subset: str = 'train', + shuffle_data: bool = False, + data_dir: str = None, + version: str = 'max256'): + """Constructor. + + Args: + subset: which subset to load. + shuffle_data: set to True to randomly shuffle the data. + data_dir: if provided this will be used instead of the default location to + look for data, it must contain files like `train.gz`, `valid.gz` and + `test.gz`. + version: which version of the data to load, this must be the name of a + directory in `DATA_ROOT`. + """ + super().__init__() + self._raw_data = RawDataset(subset=subset, shuffle_data=False, + data_dir=data_dir, version=version) + self._shuffle_data = shuffle_data + self._dataset = None + + def _load_data(self): + if self._dataset is None: + # pylint: disable=g-complex-comprehension + self._dataset = [ParsedGraphTextPair(center_node=pair.center_node, + title=pair.title, + text=pair.text, + graph=Graph.from_edges(pair.edges)) + for pair in self._raw_data] + + def source(): + n_pairs = len(self._dataset) + if self._shuffle_data: + idx = np.random.permutation(n_pairs) + else: + idx = np.arange(n_pairs) + for i in range(n_pairs): + yield self._dataset[idx[i]] + + return source() + + +class BaseGraph2TextDataset(dataset.Dataset): + """Base dataset class for graph-to-text tasks.""" + + def __init__(self, + tokenizer: tokenizers.Tokenizer, + graph_tokenizer: Optional[tokenizers.GraphTokenizer] = None, + batch_size: int = 1, + timesteps: int = 128, + subset: str = 'train', + shuffle_data: bool = False, + repeat: bool = False, + version: str = 'max256', + data_dir: str = None, + subsample_nodes: float = 1.0, + graph_retrieval_dataset: bool = False, + debug: bool = False): + """Constructor. + + Args: + tokenizer: the tokenizer for text data. + graph_tokenizer: the tokenizer for graph data. + batch_size: number of sequences to put in a batch. + timesteps: number of tokens to put in a sequence in a batch. + subset: which subset to load. + shuffle_data: whether to shuffle data. + repeat: set to True to repeat the dataset infinitely, otherwise do only + one pass through the dataset. + version: which version of the data to load. + data_dir: if set load data instead from this directory, and ignore + `version`. + subsample_nodes: the proportion of the nodes in a graph to keep. + graph_retrieval_dataset: whether to construct the dataset for graph + retrieval tasks. + debug: set to True to use debug mode and only load a small number of + examples. + """ + super().__init__() + self._parsed_data = ParsedDataset(subset=subset, + shuffle_data=False, + data_dir=data_dir, + version=version) + self._tokenizer = tokenizer + self._graph_tokenizer = graph_tokenizer + self._batch_size = batch_size + self._timesteps = timesteps + self._subset = subset + self._shuffle_data = shuffle_data + self._repeat = repeat + self._subsample_nodes = subsample_nodes + self._graph_retrieval_dataset = graph_retrieval_dataset + self._debug = debug + + self._dataset = None + + @property + def num_articles(self): + return self._num_articles + + @abc.abstractmethod + def _process_graph(self, center_node: str, graph: Graph): + """Process the graph part of a `ParsedGraphTextPair` instance.""" + + def _process_graph_text_pair( + self, pair: ParsedGraphTextPair) -> Tuple[Any, np.ndarray]: + """Process the given graph-text pair and prepare one example. + + Args: + pair: the input `ParsedGraphTextPair` instance. + + Returns: + graph: the processed graph content. + text: the tokenized text, a sequence of token IDs. + """ + return (self._process_graph(pair.center_node, pair.graph), + self._tokenizer.encode( + pair.text, prepend_bos=True, append_eos=True)) + + def _load_data(self): + """Prepare the data.""" + if self._dataset is None: + if self._debug: + data = [next(self._parsed_data) for _ in range(10)] + else: + data = list(self._parsed_data) + self._dataset = [self._process_graph_text_pair(p) for p in data] + self._num_articles = len(self._dataset) + logging.info('Loaded a total of %d examples from %s set.', + self._num_articles, self._subset) + if self._graph_retrieval_dataset: + # For graph retrieval tasks we pair all texts and graphs in the dataset, + # and indicate their (text_id, graph_id) + retrieval_data = [] + for i, (g1, _) in enumerate(self._dataset): + for j, (_, t2) in enumerate(self._dataset): + retrieval_data.append(((g1, t2), (i, j))) + self._dataset = retrieval_data + logging.info('Constructed %d pairs.', len(self._dataset)) + + def source(): + n_examples = len(self._dataset) + if self._shuffle_data: + idx = np.random.permutation(n_examples) + else: + idx = np.arange(n_examples) + for i in range(n_examples): + yield self._dataset[idx[i]] + + def maybe_repeated_source(): + if self._repeat: + while True: + yield from source() + else: + yield from source() + + data_iter = tools.batch_graph_text_pairs( + maybe_repeated_source(), + self._batch_size, + self._timesteps + 1, + pad_value=self._tokenizer.pad_token(), + seq_and_graph_id=self._graph_retrieval_dataset) + + if self._graph_retrieval_dataset: + data_iter = map(lambda x: dict( # pylint: disable=g-long-lambda + obs=x['obs'][:, :-1], + target=x['obs'][:, 1:], + should_reset=x['should_reset'][:, :-1], + # If target is a token then that target should not be predicted. + mask=(x['obs'][:, 1:] != self._tokenizer.pad_token()).astype( + np.float32), + seq_id=x['seq_id'], + graph_id=x['graph_id'], + graphs=self._process_graph_batch(x['graphs']), + ), data_iter) + else: + data_iter = map(lambda x: dict( # pylint: disable=g-long-lambda + obs=x['obs'][:, :-1], + target=x['obs'][:, 1:], + should_reset=x['should_reset'][:, :-1], + # If target is a token then that target should not be predicted. + mask=(x['obs'][:, 1:] != self._tokenizer.pad_token()).astype( + np.float32), + graphs=self._process_graph_batch(x['graphs']), + ), data_iter) + + # Filter out batches that does not have targets. + # This may happen when an observation contains a single last token of the + # sequence, which was predicted as target in the previous batch, and only + # used as observation in this batch, without a matching target. In this + # case all the masks are 0, therefore this batch provides no training signal + # and we can safely remove this batch. This also avoids some potential + # downstream issues. + data_iter = filter(lambda x: x['mask'].sum() > 0, data_iter) + return data_iter + + @abc.abstractmethod + def _process_graph_batch(self, graphs: List[Any]): + """Process a batch of graph data. + + Args: + graphs: a list of graph data, each as returned by `_process_graph`. + + Returns: + processed_graphs: processed tensor(s) that can be directly fed into a + model. + """ + + @abc.abstractmethod + def return_faux_batch(self) -> Dict[str, np.ndarray]: + """Return a fake batch with the right shapes and dtypes.""" + + +class TextOnlyDataset(BaseGraph2TextDataset): + """Text-only version of the paired dataset.""" + + def __init__(self, + tokenizer: tokenizers.Tokenizer, + graph_tokenizer: Optional[tokenizers.GraphTokenizer] = None, + batch_size: int = 1, + timesteps: int = 128, + subset: str = 'train', + shuffle_data: bool = False, + repeat: bool = False, + version: str = 'max256', + data_dir: str = None, + debug: bool = False, + **kwargs): + """Constructor. + + Args: + tokenizer: the tokenizer for text data. + graph_tokenizer: not used, keeping it here for compatibility with other + graph2text datasets. + batch_size: number of sequences to put in a batch. + timesteps: number of tokens to put in a sequence in a batch. + subset: which subset to load. + shuffle_data: whether to shuffle data. + repeat: set to True to repeat the dataset infinitely, otherwise do only + one pass through the dataset. + version: which version of the data to load. + data_dir: if set load data instead from this directory, and ignore + `version`. + debug: set to True to use debug mode and only load a small number of + examples. + **kwargs: other arguments (for interface compatibility). + """ + del graph_tokenizer + super().__init__(tokenizer=tokenizer, + graph_tokenizer=None, + batch_size=batch_size, + timesteps=timesteps, + subset=subset, + shuffle_data=shuffle_data, + repeat=repeat, + version=version, + data_dir=data_dir, + debug=debug) + + def _process_graph_batch(self, graphs: List[Any]): + del graphs + return None + + def _process_graph(self, center_node: str, graph: Graph): + del center_node + del graph + return None + + def __next__(self): + batch = super().__next__() + # Data should be text-only. + del batch['graphs'] + return batch + + def return_faux_batch(self): + """Return a fake batch with the right shapes and types.""" + obs = np.zeros((self._batch_size, self._timesteps), dtype=np.int32) + target = np.zeros_like(obs) + should_reset = np.zeros_like(obs, dtype=np.float32) + mask = np.zeros_like(obs, dtype=np.float32) + return dict(obs=obs, target=target, should_reset=should_reset, mask=mask) + + +class Bow2TextDataset(BaseGraph2TextDataset): + """Dataset for bag-of-words to text.""" + + def _process_graph(self, center_node: str, graph: Graph): + """Process the graph part of a `ParsedGraphTextPair` instance.""" + # We don't use center node in a bag-of-words representation + del center_node + if self._subsample_nodes < 1.0: + graph = Graph.subsample_nodes(graph, self._subsample_nodes) + + bow = np.zeros(self._graph_tokenizer.vocab_size, dtype=np.int32) + for n in graph.nodes(): + for t in self._graph_tokenizer.encode_node(n): + bow[t] += 1 + for _, _, e in graph.edges(): + for t in self._graph_tokenizer.encode_edge(e): + bow[t] += 1 + return bow + + def _process_graph_batch(self, graphs: List[Any]): + """Process a batch of graph data. + + Args: + graphs: a list of graph data, each as returned by `_process_graph`. + + Returns: + processed_graphs: processed tensor(s) that can be directly fed into a + model. + """ + empty_graph_bow = np.zeros(self._graph_tokenizer.vocab_size, dtype=np.int32) + graphs = [g if g is not None else empty_graph_bow for g in graphs] + # B x [V] -> [B, V] + return np.stack(graphs, axis=0) + + def return_faux_batch(self): + obs = np.zeros((self._batch_size, self._timesteps), dtype=np.int32) + target = np.zeros_like(obs) + should_reset = np.zeros_like(obs, dtype=np.float32) + mask = np.zeros_like(obs, dtype=np.float32) + graphs = np.zeros((self._batch_size, self._graph_tokenizer.vocab_size), + dtype=np.float32) + return dict(obs=obs, target=target, should_reset=should_reset, mask=mask, + graphs=graphs) + + +def pack_graphs(graphs: List[Tuple[List[np.ndarray], + List[Tuple[int, int, np.ndarray]]]], + truncate_node: int, + truncate_edge: int, + pad_value=0) -> ParsedGraph: + """Pack a list of graphs into a batched ParsedGraph instance with truncation.""" + converted_graphs = [] + for nodes, edges in graphs: + converted_nodes = [] + for n in nodes: + if n.size < truncate_node: + node = tools.pad_to(n, truncate_node, axis=0, pad_value=pad_value) + else: + node = n[:truncate_node] + converted_nodes.append(node) + + sender = [] + receiver = [] + converted_edges = [] + for s, t, e in edges: + if e.size < truncate_edge: + edge = tools.pad_to(e, truncate_edge, axis=0, pad_value=pad_value) + else: + edge = e[:truncate_edge] + converted_edges.append(edge) + sender.append(s) + receiver.append(t) + + converted_graphs.append(ParsedGraph( + nodes=np.array(converted_nodes, dtype=np.int32), + edges=(np.array(converted_edges, dtype=np.int32) + if converted_edges else + np.zeros((0, truncate_edge), dtype=np.int32)), + sender=np.array(sender, dtype=np.int32), + receiver=np.array(receiver, dtype=np.int32), + graph_sizes=np.array([len(converted_nodes)], dtype=np.int32))) + + return batch_graphs(converted_graphs) + + +class Graph2TextDataset(BaseGraph2TextDataset): + """Graph-to-text dataset. + + This dataset encodes the graph nodes and edges using a bag-of-words + representation. + """ + + def __init__(self, + tokenizer: tokenizers.Tokenizer, + graph_tokenizer: tokenizers.GraphTokenizer, + batch_size: int = 1, + timesteps: int = 128, + subset: str = 'train', + shuffle_data: bool = False, + repeat: bool = False, + version: str = 'max256', + data_dir: str = None, + subsample_nodes: float = 1.0, + graph_retrieval_dataset: bool = False, + debug: bool = False): + """Constructor. + + Args: + tokenizer: the tokenizer for text data. + graph_tokenizer: the tokenizer for graph data. + batch_size: number of sequences to put in a batch. + timesteps: number of tokens to put in a sequence in a batch. + subset: which subset to load. + shuffle_data: whether to shuffle data. + repeat: set to True to repeat the dataset infinitely, otherwise do only + one pass through the dataset. + version: which version of the data to load. + data_dir: if set load data instead from this directory, and ignore + `version`. + subsample_nodes: the proportion of the nodes in a graph to keep. + graph_retrieval_dataset: whether to construct the dataset for graph + retrieval tasks. + debug: set to True to use debug mode and only load a small number of + examples. + """ + self._graph_feature_dim = graph_tokenizer.vocab_size + super().__init__(tokenizer=tokenizer, + graph_tokenizer=graph_tokenizer, + batch_size=batch_size, + timesteps=timesteps, + subset=subset, + shuffle_data=shuffle_data, + repeat=repeat, + version=version, + data_dir=data_dir, + subsample_nodes=subsample_nodes, + graph_retrieval_dataset=graph_retrieval_dataset, + debug=debug) + self._placeholder_graph = self._process_graph( + center_node='', + graph=Graph(nodes=[''], edges=[])) + + def _process_graph(self, center_node: str, graph: Graph): + """Process the graph part of a `ParsedGraphTextPair` instance.""" + if self._subsample_nodes < 1.0: + graph = Graph.subsample_nodes(graph, self._subsample_nodes, center_node) + + nodes = graph.nodes() + edges = graph.edges() + n_edges = len(edges) + + sender = np.zeros(n_edges, dtype=np.int32) + receiver = np.zeros(n_edges, dtype=np.int32) + + nodes_bow = [] + edges_bow = [] + + for n in nodes: + bow = collections.defaultdict(int) + for t in self._graph_tokenizer.encode_node(n): + bow[t] += 1 + nodes_bow.append(bow) + for i, (s, r, e) in enumerate(edges): + bow = collections.defaultdict(int) + for t in self._graph_tokenizer.encode_edge(e): + bow[t] += 1 + edges_bow.append(bow) + sender[i] = s + receiver[i] = r + + return (nodes_bow, edges_bow, sender, receiver, graph.node2id(center_node)) + + def _to_graph_with_features( + self, nodes_bow, edges_bow, sender, receiver, center_node_id): + """Convert the input to a `ParsedGraph` instance.""" + n_nodes = len(nodes_bow) + n_edges = len(edges_bow) + + # +1 for the center node indicator + nodes = np.zeros((n_nodes, self._graph_feature_dim + 1), dtype=np.float32) + edges = np.zeros((n_edges, self._graph_feature_dim), dtype=np.float32) + + nodes[center_node_id][-1] = 1 + for i, bow in enumerate(nodes_bow): + for t, c in bow.items(): + nodes[i][t] = c + for i, bow in enumerate(edges_bow): + for t, c in bow.items(): + edges[i][t] = c + + return ParsedGraph(nodes=nodes, edges=edges, sender=sender, + receiver=receiver, + graph_sizes=np.array([n_nodes], dtype=np.int32)) + + def _process_graph_batch(self, graphs: List[Any]): + """Process a batch of graph data. + + Args: + graphs: a list of graph data, each as returned by `_process_graph`. + + Returns: + processed_graphs: a list of processed tensor(s). + """ + graphs = [g if g is not None else self._placeholder_graph for g in graphs] + return [self._to_graph_with_features(*g) for g in graphs] + + def return_faux_batch(self) -> Dict[str, np.ndarray]: + """Return a fake batch with the right shapes and dimensions.""" + obs = np.zeros([self._batch_size, self._timesteps], dtype=np.int32) + target = np.zeros([self._batch_size, self._timesteps], dtype=np.int32) + should_reset = np.zeros_like(obs, np.float32) + mask = np.zeros_like(obs, np.float32) + # A batch should contain `batch_size` graphs. Here we make sure each graph + # has one node and one edge. + graphs = self._batch_size * [ParsedGraph( + nodes=np.zeros([1, self._graph_feature_dim + 1], dtype=np.float32), + edges=np.zeros([1, self._graph_feature_dim], dtype=np.float32), + sender=np.zeros([1], dtype=np.int32), + receiver=np.zeros([1], dtype=np.int32), + graph_sizes=np.ones(1, dtype=np.int32))] + return dict(obs=obs, target=target, mask=mask, should_reset=should_reset, + graphs=graphs) diff --git a/wikigraphs/wikigraphs/data/paired_dataset_test.py b/wikigraphs/wikigraphs/data/paired_dataset_test.py new file mode 100644 index 0000000..57a3830 --- /dev/null +++ b/wikigraphs/wikigraphs/data/paired_dataset_test.py @@ -0,0 +1,271 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Tests for wikigraphs.data.paired_dataset.""" + +from absl.testing import absltest +from wikigraphs.data import io_tools +from wikigraphs.data import paired_dataset +from wikigraphs.data import tokenizers +from wikigraphs.data import wikitext + + +WIKITEXT_ROOT = '/tmp/data/wikitext-103' +WIKIGRAPHS_ROOT = '/tmp/data/wikigraphs' +WIKITEXT_VOCAB_FILE = '/tmp/data/wikitext-vocab.csv' +GRAPH_VOCAB_FILE = '/tmp/data/graph-vocab.csv' + + +class PairedDatasetTest(absltest.TestCase): + + def test_raw_paired_dataset_size(self): + dataset = paired_dataset.RawDataset( + subset='valid', shuffle_data=False, data_dir=WIKIGRAPHS_ROOT) + pairs = list(dataset) + self.assertLen(pairs, 48) + + self.assertEqual(pairs[0].title, 'Homarus_gammarus') + self.assertEqual(pairs[-1].title, 'Rakie_Ayola') + + # Make sure the content of the articles match the original + wikitext_set = wikitext.RawDataset( + subset='valid', shuffle_data=False, version='raw', + data_dir=WIKITEXT_ROOT) + title2article = {wikitext.normalize_title(a.title).replace(' ', ''): a.text + for a in wikitext_set} + for p in pairs: + title = io_tools.normalize_freebase_string(p.title).replace(' ', '') + article = title2article.get(title, None) + self.assertIsNotNone(article) + self.assertEqual(article, p.text) + + def test_graph_from_edges(self): + edges = ['A\tE1\tB', + 'A\tE2\tC', + 'B\tE1\tC', + 'C\tE3\tD', + 'C\tE2\tE'] + graph = paired_dataset.Graph.from_edges(edges) + self.assertEqual(graph.nodes(), ['A', 'B', 'C', 'D', 'E']) + self.assertEqual(graph.edges(), [(0, 1, 'E1'), + (0, 2, 'E2'), + (1, 2, 'E1'), + (2, 3, 'E3'), + (2, 4, 'E2')]) + + def test_graph_to_edges(self): + edges = ['A\tE1\tB', + 'A\tE2\tC', + 'B\tE1\tC', + 'C\tE3\tD', + 'C\tE2\tE'] + graph = paired_dataset.Graph.from_edges(edges) + self.assertEqual(graph.to_edges(), edges) + + def test_bow2text_dataset(self): + tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE) + graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE) + + batch_size = 4 + seq_len = 256 + dataset = paired_dataset.Bow2TextDataset( + tokenizer, + graph_tokenizer, + batch_size=batch_size, + timesteps=seq_len, + subset='valid', + subsample_nodes=0.7, + repeat=False, + data_dir=WIKIGRAPHS_ROOT) + + num_tokens = 0 + for batch in dataset: + num_tokens += batch['mask'].sum() + self.assertEqual(batch['graphs'].shape, + (batch_size, graph_tokenizer.vocab_size)) + + raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False) + raw_num_tokens = 0 + n_pairs = 0 + for pair in raw_dataset: + raw_num_tokens += len(tokenizer.encode( + pair.text, prepend_bos=True, append_eos=True)) + n_pairs += 1 + + # The first token of each example is not counted by `mask` as it masks the + # targets, and the first token of each example never appears in the targets. + self.assertEqual(raw_num_tokens, num_tokens + n_pairs) + + def test_graph2text_dataset(self): + tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE) + graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE) + + batch_size = 4 + seq_len = 256 + dataset = paired_dataset.Graph2TextDataset( + tokenizer, + graph_tokenizer, + batch_size=batch_size, + timesteps=seq_len, + subsample_nodes=0.8, + subset='valid', + data_dir=WIKIGRAPHS_ROOT) + data_iter = iter(dataset) + batch = next(data_iter) + self.assertEqual(batch['obs'].shape, (batch_size, seq_len)) + self.assertEqual(batch['target'].shape, (batch_size, seq_len)) + self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len)) + self.assertEqual(batch['mask'].shape, (batch_size, seq_len)) + self.assertIsInstance(batch['graphs'], list) + self.assertLen(batch['graphs'], batch_size) + for i in range(batch_size): + self.assertIsInstance(batch['graphs'][i], paired_dataset.ParsedGraph) + + # +1 for the center_node mask + self.assertEqual( + batch['graphs'][i].nodes.shape[-1], graph_tokenizer.vocab_size + 1) + self.assertEqual( + batch['graphs'][i].edges.shape[-1], graph_tokenizer.vocab_size) + n_edges = batch['graphs'][i].edges.shape[0] + self.assertEqual(batch['graphs'][i].sender.shape, (n_edges,)) + self.assertEqual(batch['graphs'][i].receiver.shape, (n_edges,)) + + # Make sure the token count matches across the tokenized data and the raw + # data set. + num_tokens = 0 + for batch in dataset: + num_tokens += batch['mask'].sum() + + raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False) + raw_num_tokens = 0 + n_pairs = 0 + for pair in raw_dataset: + raw_num_tokens += len(tokenizer.encode( + pair.text, prepend_bos=True, append_eos=True)) + n_pairs += 1 + + # The first token of each example is not counted by `mask` as it masks the + # targets, and the first token of each example never appears in the targets. + self.assertEqual(raw_num_tokens, num_tokens + n_pairs) + + def test_text_only_dataset(self): + tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE) + + batch_size = 4 + seq_len = 256 + dataset = paired_dataset.TextOnlyDataset( + tokenizer, + batch_size=batch_size, + timesteps=seq_len, + subset='valid', + data_dir=WIKIGRAPHS_ROOT) + data_iter = iter(dataset) + batch = next(data_iter) + faux_batch = dataset.return_faux_batch() + + self.assertCountEqual(list(batch.keys()), + ['obs', 'target', 'should_reset', 'mask']) + self.assertCountEqual(list(faux_batch.keys()), + ['obs', 'target', 'should_reset', 'mask']) + for k, v in batch.items(): + faux_v = faux_batch[k] + self.assertEqual(v.shape, faux_v.shape) + self.assertEqual(v.dtype, faux_v.dtype) + + self.assertEqual(batch['obs'].shape, (batch_size, seq_len)) + self.assertEqual(batch['target'].shape, (batch_size, seq_len)) + self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len)) + self.assertEqual(batch['mask'].shape, (batch_size, seq_len)) + + num_tokens = 0 + for batch in dataset: + num_tokens += batch['mask'].sum() + + raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False) + raw_num_tokens = 0 + n_pairs = 0 + for pair in raw_dataset: + raw_num_tokens += len(tokenizer.encode( + pair.text, prepend_bos=True, append_eos=True)) + n_pairs += 1 + self.assertEqual(num_tokens + n_pairs, raw_num_tokens) + + def test_bow_retrieval_dataset(self): + tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE) + graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE) + + batch_size = 4 + seq_len = 256 + dataset = paired_dataset.Bow2TextDataset( + tokenizer, + graph_tokenizer, + batch_size=batch_size, + timesteps=seq_len, + subsample_nodes=0.8, + graph_retrieval_dataset=True, + subset='valid', + data_dir=WIKIGRAPHS_ROOT) + data_iter = iter(dataset) + batch = next(data_iter) + + self.assertEqual(batch['obs'].shape, (batch_size, seq_len)) + self.assertEqual(batch['target'].shape, (batch_size, seq_len)) + self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len)) + self.assertEqual(batch['mask'].shape, (batch_size, seq_len)) + self.assertEqual(batch['graph_id'].shape, (batch_size,)) + self.assertEqual(batch['seq_id'].shape, (batch_size,)) + + def test_graph_retrieval_dataset(self): + tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE) + graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE) + + batch_size = 4 + seq_len = 256 + dataset = paired_dataset.Graph2TextDataset( + tokenizer, + graph_tokenizer, + batch_size=batch_size, + timesteps=seq_len, + subsample_nodes=0.8, + graph_retrieval_dataset=True, + subset='valid', + data_dir=WIKIGRAPHS_ROOT) + data_iter = iter(dataset) + batch = next(data_iter) + + self.assertEqual(batch['obs'].shape, (batch_size, seq_len)) + self.assertEqual(batch['target'].shape, (batch_size, seq_len)) + self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len)) + self.assertEqual(batch['mask'].shape, (batch_size, seq_len)) + self.assertEqual(batch['graph_id'].shape, (batch_size,)) + self.assertEqual(batch['seq_id'].shape, (batch_size,)) + + +if __name__ == '__main__': + absltest.main() diff --git a/wikigraphs/wikigraphs/data/tokenizers.py b/wikigraphs/wikigraphs/data/tokenizers.py new file mode 100644 index 0000000..0d28619 --- /dev/null +++ b/wikigraphs/wikigraphs/data/tokenizers.py @@ -0,0 +1,230 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Tokenizers for text data.""" + +import abc +import csv +import io +import re +from typing import List + +import nltk +import numpy as np + +from wikigraphs.data import io_tools + + +class Tokenizer(abc.ABC): + """Base class for tokenizers.""" + + @abc.abstractmethod + def encode(self, + inputs: str, + prepend_bos: bool = False, + append_eos: bool = False) -> np.ndarray: + """Encode input string into an array of token IDs. + + Args: + inputs: a string. + prepend_bos: set to True to add token at the beginning of the token + sequence. + append_eos: set to True to add token at the end of the token + sequence. + + Returns: + tokens: [n_tokens] int array. + """ + + @abc.abstractmethod + def decode(self, inputs) -> str: + """Decode a sequence of tokens back into a string. + + Args: + inputs: array or list of ints. + + Returns: + s: the decoded string using this tokenizer. + """ + + @property + @abc.abstractmethod + def vocab_size(self) -> int: + """Size of the vocabulary.""" + + @abc.abstractmethod + def pad_token(self) -> int: + """ID of the token.""" + + @abc.abstractmethod + def bos_token(self) -> int: + """ID of the token.""" + + +class WordTokenizer(Tokenizer): + """Word-level tokenizer for white-space separated text data.""" + + def __init__(self, vocab_file: str): + """Constructor. + + Args: + vocab_file: a csv vocab file. + """ + content = io_tools.read_txt_file(vocab_file, encoding='utf-8') + + with io.StringIO(content) as f: + r = csv.reader(f) + vocab = [w for w, _ in r] + + # Add pad and bos tokens to the vocab + to_add = ['', ''] + if '' not in vocab: + to_add.append('') + vocab = to_add + vocab + + # token-index mappings + self._t2i = {t: i for i, t in enumerate(vocab)} + self._i2t = {i: t for t, i in self._t2i.items()} + + self._unk_token = self._t2i[''] + self._bos_token = self._t2i[''] + self._pad_token = self._t2i[''] + + @property + def vocab_size(self): + return len(self._t2i) + + def encode(self, inputs, prepend_bos=False, append_eos=False): + tokens = [self._t2i.get(t, self._unk_token) for t in inputs.split(' ') if t] + if prepend_bos: + tokens = [self._bos_token] + tokens + if append_eos: + # Reuse as . + tokens.append(self._bos_token) + return np.array(tokens, dtype=np.int32) + + def decode(self, inputs): + """Decode a sequence of token IDs back into a string.""" + # Remove the first token if there is any. + if inputs[0] == self._bos_token: + inputs = inputs[1:] + tokens = [] + for i in inputs: + # Use also as and stop there. + if i == self._bos_token: + break + tokens.append(self._i2t[i]) + return ' '.join(tokens) + + def pad_token(self): + return self._pad_token + + def bos_token(self): + return self._bos_token + + +class GraphTokenizer: + """Tokenizer for the content on the graphs.""" + + def __init__(self, vocab_file: str): + """Constructor. + + Args: + vocab_file: path to a vocab file. + """ + content = io_tools.read_txt_file(vocab_file, encoding='utf-16') + + vocab = content.split('\n') + vocab = ['', '', ''] + vocab + + # token-index mappings + self._t2i = {t: i for i, t in enumerate(vocab)} + self._i2t = {i: t for t, i in self._t2i.items()} + + self._unk_token = self._t2i[''] + self._bos_token = self._t2i[''] + self._pad_token = self._t2i[''] + + @property + def vocab_size(self): + return len(self._t2i) + + def encode_node(self, txt: str) -> np.ndarray: + return np.array([self._t2i.get(t, self._unk_token) + for t in self.split_node(txt)]) + + def encode_edge(self, txt: str) -> np.ndarray: + return np.array([self._t2i.get(t, self._unk_token) + for t in self.split_edge(txt)]) + + def encode(self, inputs, prepend_bos=False, append_eos=False): + tokens = [self._t2i.get(t, self._unk_token) for t in inputs.split(' ') if t] + if prepend_bos: + tokens = [self._bos_token] + tokens + if append_eos: + # Reuse as . + tokens.append(self._bos_token) + return np.array(tokens, dtype=np.int32) + + def decode(self, inputs): + """Decode a sequence of token IDs back into a string.""" + # Remove the first token if there is any. + if inputs[0] == self._bos_token: + inputs = inputs[1:] + tokens = [] + for i in inputs: + # Use also as and stop there. + if i == self._bos_token: + break + tokens.append(self._i2t[i]) + return ' '.join(tokens) + + @classmethod + def split_node(cls, txt: str) -> List[str]: + """Split a node string into a sequence of tokens.""" + if txt[0] == '"' and txt[-1] == '"': # Node is a string literal. + tokens = nltk.wordpunct_tokenize(io_tools.normalize_freebase_string( + txt[1:-1].lower())) + for i, t in enumerate(tokens): + if t.isnumeric(): + tokens[i] = '' + return tokens + else: # If node is not a string literal it is always an entity. + return [''] + + @classmethod + def split_edge(cls, txt: str) -> List[str]: + """Split an edge string into a sequence of tokens.""" + return re.split('[._ ]+', txt.lower().split('/')[1]) + + def pad_token(self): + return self._pad_token + + def bos_token(self): + return self._bos_token diff --git a/wikigraphs/wikigraphs/data/tokenizers_test.py b/wikigraphs/wikigraphs/data/tokenizers_test.py new file mode 100644 index 0000000..3bef1da --- /dev/null +++ b/wikigraphs/wikigraphs/data/tokenizers_test.py @@ -0,0 +1,78 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Tests for wikigraphs.data.tokenizers.""" +from absl.testing import absltest +from wikigraphs.data import tokenizers + + +WIKITEXT_VOCAB_FILE = '/tmp/data/wikitext-vocab.csv' +GRAPH_VOCAB_FILE = '/tmp/data/graph-vocab.csv' + + +class TokenizerTest(absltest.TestCase): + + def test_tokenizer(self): + tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE) + # Vocab size must match published number. + self.assertEqual(tokenizer.vocab_size, 267735 + 2) + + s = 'Hello world ! \n How are you ?' + encoded = tokenizer.encode(s, prepend_bos=True) + self.assertEqual(encoded.shape, (9,)) + decoded = tokenizer.decode(encoded) + self.assertEqual(s, decoded) + + def test_graph_tokenizer_tokenize_nodes_edges(self): + self.assertEqual( + tokenizers.GraphTokenizer.split_node( + '"Hello, how are you?"'), + ['hello', ',', 'how', 'are', 'you', '?']) + self.assertEqual( + tokenizers.GraphTokenizer.split_node( + '"This building was built in 1998."'), + ['this', 'building', 'was', 'built', 'in', '', '.']) + self.assertEqual( + tokenizers.GraphTokenizer.split_node('ns/m.030ssw'), + ['']) + + self.assertEqual( + tokenizers.GraphTokenizer.split_edge('ns/common.topic.description'), + ['common', 'topic', 'description']) + self.assertEqual( + tokenizers.GraphTokenizer.split_edge('ns/type.object.name'), + ['type', 'object', 'name']) + + def test_graph_tokenizer_vocab(self): + tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE) + self.assertEqual(tokenizer.vocab_size, 31087 + 3) + + +if __name__ == '__main__': + absltest.main() diff --git a/wikigraphs/wikigraphs/data/tools.py b/wikigraphs/wikigraphs/data/tools.py new file mode 100644 index 0000000..bfb3683 --- /dev/null +++ b/wikigraphs/wikigraphs/data/tools.py @@ -0,0 +1,242 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Some tools for processing data.""" + +from typing import Any, Iterator + +from absl import logging +import numpy as np + + +def pad_to(x: np.array, size: int, axis: int = -1, pad_value: float = 0.): + """Pad an array to the specified size along a specified axis.""" + if x.shape[axis] > size: + raise ValueError(f'Data item has size {x.shape[axis]} larger than {size}' + f' in axis {axis} already.') + elif x.shape[axis] == size: + return x + else: + pad_amount = [(0, 0)] * x.ndim + pad_amount[axis] = (0, size - x.shape[axis]) + return np.pad(x, pad_amount, mode='constant', constant_values=pad_value) + + +def dynamic_batch( + iterable: Iterator[Any], + batch_size: int, + timesteps: int, + return_incomplete_batch: bool = False, + pad: bool = False, + pad_value: float = 0.) -> Iterator[Any]: + """Batches up values in iterable to [batch_size, timesteps]. + + This function takes items from the iterable and pack them into the batch. + Sequence #i in the batch is a continuation from the sequence #i in the + previous batch, i.e. it will start from where the previous sequence left over. + When an item is finished, a new item is taken from the iterable to append to + the sequence and fill the batch. + + This function is designed for language modeling, where the input and the + target sequences are offset by one. We take that into account by making sure + neighboring batches have one token overlap. + + Example: + If the iterable contains [[0, 1, 2], [10, 11, 12, 13, 14], [20, 21, 22]] and + batch size is 2, timesteps is 3, then the first batch would be: + [[0, 1, 2], + [10, 11, 12]] + then the second batch: + [[2, 20, 21], # seq 0 finished, continuing from seq 2 + [12, 13, 14]] + Note the overlap of 1 token between these two batches, and the continuation + of sequences across batches. + + Args: + iterable: the iterable that yields sequences of integer token IDs. + batch_size: number of examples in a batch. + timesteps: length of each sequence in a batch. + return_incomplete_batch: if True return the incomplete batches, which + typically appears at the end of the dataset. + pad: set to True to pad the incomplete batches. + pad_value: the value to use for padding. + + Yields: + batches: where batches['obs'] are the observations of size + [batch_size, timesteps], and batches['should_reset'] is a 0/1 mask of + the same size that marks sequence boundaries, e.g. the entries in this + mask are all 0 except at locations where a new sequence is starting. + """ + if return_incomplete_batch and not pad: + raise ValueError( + f'If return_incomplete_batch, then pad must be True, currently {pad}.') + + iterator = iter(iterable) + elems = [] + for _ in range(batch_size): + item = next(iterator) + elems.append(item) + start_batch = [True] * batch_size + + iter_finished = False + loaded_finished = False + while not (iter_finished and loaded_finished): + batch = [] + for i in range(batch_size): + # should_reset value is 1 when a new sequence begins. + # [old[-3], old[-2], old[-1], new[0], new[1], new[2]] + # [0, 0, 0, 1, 0, 0] + should_reset = np.zeros(timesteps, np.float32) + if start_batch[i]: + should_reset[0] = 1 + + # Pack new examples in the sequence until they go beyond the required + # timesteps. + while len(elems[i]) < timesteps: + should_reset[len(elems[i])] = 1 + try: + item = next(iterator) + except StopIteration: + iter_finished = True + break + elems[i] = np.concatenate([elems[i], item]) + + batch.append(dict(obs=elems[i][:timesteps], should_reset=should_reset)) + # Shift and make sure we have a 1 token overlap. + elems[i] = elems[i][timesteps - 1:] + # Since the last token is shifted to be the first token of the next batch, + # We need to make sure reset is handled properly as well. + start_batch[i] = (should_reset[-1] == 1) + + # If any loaded data is not yet consumed in the output we should keep + # generating. + loaded_finished = all(e.size == 0 for e in elems) + + if not return_incomplete_batch: + elem_len = len(batch[0]['obs']) + if (elem_len != timesteps or + not all(len(x['obs']) == elem_len for x in batch[1:])): + logging.info('Dropping the (last?) incomplete batch.') + break + + if pad: + for x in batch: + x['obs'] = pad_to(x['obs'], timesteps, axis=0, pad_value=pad_value) + + yield dict( + obs=np.stack([x['obs'] for x in batch], axis=0), + should_reset=np.stack([x['should_reset'] for x in batch], axis=0)) + + +def batch_graph_text_pairs( + iterable: Iterator[Any], + batch_size: int, + timesteps: int, + pad_value: float = 0., + seq_and_graph_id: bool = False) -> Iterator[Any]: + """Batch graph and text pairs. + + This method pairs text with graphs, each text sequence is split into chunks + (with an overlap of 1) of size `timesteps`, and the graph associated with the + text is used and associated with each chunk as well. The last incomplete + chunk of each text sequence is padded with the `pad_value`. + + Args: + iterable: Iterable that returns (graph, sequence) pairs, graph can be + anything, and sequence is a list of tokenized token IDs. + batch_size: Number of examples in a batch. + timesteps: Window size for the sequences. + pad_value: Value to use for padding. + seq_and_graph_id: whether the `iterable` contains `seq_id` and `graph_id`. + + Yields: + batch: a batch of text sequence paired with graphs. + """ + iterator = iter(iterable) + seqs = [None] * batch_size + graphs = [None] * batch_size + graph_ids = [None] * batch_size + seq_ids = [None] * batch_size + + iter_finished = False + loaded_finished = False + while not (iter_finished and loaded_finished): + batch = [] + for idx in range(batch_size): + should_reset = np.zeros(timesteps, np.float32) + # pylint: disable=g-explicit-length-test + if seqs[idx] is None or len(seqs[idx]) == 0: + should_reset[0] = 1 + # One sequence exhausted, get the next example. + try: + if seq_and_graph_id: + (graph, seq), (graph_id, seq_id) = next(iterator) + graph_ids[idx] = graph_id + seq_ids[idx] = seq_id + else: + graph, seq = next(iterator) + seqs[idx] = seq + graphs[idx] = graph + except StopIteration: + iter_finished = True + seqs[idx] = np.array([pad_value], dtype=np.int32) + graphs[idx] = None + + example = dict(obs=seqs[idx][:timesteps], graph=graphs[idx], + should_reset=should_reset) + if seq_and_graph_id: + example['seq_id'] = seq_ids[idx] + example['graph_id'] = graph_ids[idx] + + batch.append(example) + # Make sure that there is an overlap, as we generate targets by shifting + # the tensor by 1 timestep. So the next element should be shifted by + # `timesteps - 1' timesteps. + seqs[idx] = seqs[idx][timesteps - 1:] + + # Make sure all loaded data are consumed in the output + loaded_finished = all(s.size == 0 for s in seqs) + + # Also check for the last batch to avoid returning a fully empty batch + if iter_finished and all([np.all(b['obs'] == pad_value) for b in batch]): + break + + # pad sequences to specified length + for e in batch: + e['obs'] = pad_to(e['obs'], timesteps, axis=0, pad_value=pad_value) + stacked_batch = dict( + obs=np.stack([e['obs'] for e in batch], axis=0), + graphs=[e['graph'] for e in batch], + should_reset=np.stack([e['should_reset'] for e in batch], axis=0)) + if seq_and_graph_id: + stacked_batch['seq_id'] = np.stack( + [e['seq_id'] for e in batch], axis=0) + stacked_batch['graph_id'] = np.stack( + [e['graph_id'] for e in batch], axis=0) + yield stacked_batch diff --git a/wikigraphs/wikigraphs/data/tools_test.py b/wikigraphs/wikigraphs/data/tools_test.py new file mode 100644 index 0000000..78b27a4 --- /dev/null +++ b/wikigraphs/wikigraphs/data/tools_test.py @@ -0,0 +1,195 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Tests for wikigraphs.data.tools.""" +from absl.testing import absltest +import numpy as np +from wikigraphs.data import tools + + +class ToolsTest(absltest.TestCase): + + def test_padding(self): + np.testing.assert_array_equal( + tools.pad_to(np.arange(3), 5), + [0, 1, 2, 0, 0]) + np.testing.assert_array_equal( + tools.pad_to(np.arange(3), 5, pad_value=-1), + [0, 1, 2, -1, -1]) + np.testing.assert_array_equal( + tools.pad_to(np.arange(6).reshape(2, 3), 4, axis=0, pad_value=-1), + [[0, 1, 2], + [3, 4, 5], + [-1, -1, -1], + [-1, -1, -1]]) + np.testing.assert_array_equal( + tools.pad_to(np.arange(6).reshape(2, 3), 4, axis=-1, pad_value=-1), + [[0, 1, 2, -1], + [3, 4, 5, -1]]) + + def test_dynamic_batch(self): + def dataset(): + data = [[1, 2, 2, 2], + [1, 3, 3], + [1, 4]] + for d in data: + yield np.array(d, dtype=np.int32) + batches = list(tools.dynamic_batch( + dataset(), batch_size=2, timesteps=3, return_incomplete_batch=False)) + self.assertLen(batches, 1) + np.testing.assert_array_equal( + batches[0]['obs'], + [[1, 2, 2], [1, 3, 3]]) + np.testing.assert_array_equal( + batches[0]['should_reset'], + [[1, 0, 0], [1, 0, 0]]) + + batches = list(tools.dynamic_batch( + dataset(), batch_size=2, timesteps=3, return_incomplete_batch=True, + pad=True, pad_value=0)) + # Note `return_incomplete_batch=False` drops all the incomplete batches, + # and this can be more than just the last batch. + self.assertLen(batches, 3) + np.testing.assert_array_equal( + batches[0]['obs'], + [[1, 2, 2], [1, 3, 3]]) + np.testing.assert_array_equal( + batches[0]['should_reset'], + [[1, 0, 0], [1, 0, 0]]) + + np.testing.assert_array_equal( + batches[1]['obs'], + [[2, 2, 1], [3, 0, 0]]) + np.testing.assert_array_equal( + batches[1]['should_reset'], + [[0, 0, 1], [0, 1, 0]]) + + np.testing.assert_array_equal( + batches[2]['obs'], + [[1, 4, 0], [0, 0, 0]]) + np.testing.assert_array_equal( + batches[2]['should_reset'], + [[1, 0, 1], [1, 0, 0]]) + + with self.assertRaises(ValueError): + batches = list(tools.dynamic_batch( + dataset(), batch_size=2, timesteps=3, return_incomplete_batch=True, + pad=False)) + + def test_batch_graph_text_pairs(self): + def source(): + yield (1, np.array([1, 1, 1, 1, 1], dtype=np.int32)) + yield (2, np.array([2, 2], dtype=np.int32)) + yield (3, np.array([3, 3, 3, 3, 3, 3], dtype=np.int32)) + + data_iter = tools.batch_graph_text_pairs( + source(), batch_size=2, timesteps=3, pad_value=0) + + batches = list(data_iter) + self.assertLen(batches, 4) + + batch = batches[0] + np.testing.assert_array_equal( + batch['obs'], + [[1, 1, 1], + [2, 2, 0]]) + self.assertEqual(batch['graphs'], [1, 2]) + np.testing.assert_array_equal( + batch['should_reset'], + [[1, 0, 0], + [1, 0, 0]]) + + batch = batches[1] + np.testing.assert_array_equal( + batch['obs'], + [[1, 1, 1], + [3, 3, 3]]) + self.assertEqual(batch['graphs'], [1, 3]) + np.testing.assert_array_equal( + batch['should_reset'], + [[0, 0, 0], + [1, 0, 0]]) + + batch = batches[2] + np.testing.assert_array_equal( + batch['obs'], + [[1, 0, 0], + [3, 3, 3]]) + self.assertEqual(batch['graphs'], [1, 3]) + np.testing.assert_array_equal( + batch['should_reset'], + [[0, 0, 0], + [0, 0, 0]]) + + batch = batches[3] + np.testing.assert_array_equal( + batch['obs'], + [[0, 0, 0], + [3, 3, 0]]) + self.assertEqual(batch['graphs'], [None, 3]) + np.testing.assert_array_equal( + batch['should_reset'], + [[1, 0, 0], + [0, 0, 0]]) + + def test_batch_graph_text_pairs_batch_size1(self): + def source(): + yield (0, np.array([1, 2], dtype=np.int32)) + yield (1, np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)) + + data_iter = tools.batch_graph_text_pairs( + source(), batch_size=1, timesteps=3, pad_value=0) + + batches = list(data_iter) + + batch = batches[0] + np.testing.assert_array_equal(batch['obs'], [[1, 2, 0]]) + self.assertEqual(batch['graphs'], [0]) + np.testing.assert_array_equal(batch['should_reset'], [[1, 0, 0]]) + + batch = batches[1] + np.testing.assert_array_equal(batch['obs'], [[1, 2, 3]]) + self.assertEqual(batch['graphs'], [1]) + np.testing.assert_array_equal(batch['should_reset'], [[1, 0, 0]]) + + batch = batches[2] + np.testing.assert_array_equal(batch['obs'], [[3, 4, 5]]) + self.assertEqual(batch['graphs'], [1]) + np.testing.assert_array_equal(batch['should_reset'], [[0, 0, 0]]) + + batch = batches[3] + np.testing.assert_array_equal(batch['obs'], [[5, 6, 0]]) + self.assertEqual(batch['graphs'], [1]) + np.testing.assert_array_equal(batch['should_reset'], [[0, 0, 0]]) + + self.assertLen(batches, 4) + + +if __name__ == '__main__': + absltest.main() diff --git a/wikigraphs/wikigraphs/data/wikitext.py b/wikigraphs/wikigraphs/data/wikitext.py new file mode 100644 index 0000000..b22b70b --- /dev/null +++ b/wikigraphs/wikigraphs/data/wikitext.py @@ -0,0 +1,218 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Wikitext-103 datasets.""" + +import re +from typing import NamedTuple, List + +from absl import logging +import numpy as np + +from wikigraphs.data import dataset +from wikigraphs.data import tokenizers +from wikigraphs.data import tools + + +# The data directory that contains subdirectories `wikitext-103` and +# `wikitext-103-raw`. +DATA_ROOT = '/tmp/data/wikitext-103' + + +class WikitextArticle(NamedTuple): + title: str + text: str + + +def articles_from_file(file_path: str) -> List[WikitextArticle]: + """Read wikitext articles from file. + + Args: + file_path: path to the input `.tokens` file. + + Returns: + A list of `WikitextArticle` tuples. + """ + with open(file_path, mode='rb') as f: + content = f.read() + content = content.decode('utf-8') + + title_re = re.compile(r'(\n = ([^=].*) = \n \n)') + parts = title_re.split(content) + + # Skip the first part which is empty + return [WikitextArticle(title=parts[i+1], text=parts[i] + parts[i+2]) + for i in range(1, len(parts), 3)] + + +class RawDataset(dataset.Dataset): + """Raw text dataset for wikitext-103.""" + + def __init__(self, + subset: str = 'train', + shuffle_data: bool = False, + data_dir: str = None, + version: str = 'tokens'): + """Constructor. + + Args: + subset: which subset to load, one of {"train", "valid", "test"}. + shuffle_data: if set to True the data will be randomly shuffled. + data_dir: if provided will be used instead of the default `DATA_ROOT` as + the directory that contains the data. + version: one of {'tokens', 'raw'} + """ + super().__init__() + self._subset = subset + self._shuffle_data = shuffle_data + self._data_dir = data_dir or DATA_ROOT + self._dataset = None + + allowed_versions = ('tokens', 'raw') + if version not in allowed_versions: + raise ValueError(f'Version must be one of {allowed_versions}.') + self._version = version + + def _load_data(self): + """Prepare data for another pass through the dataset.""" + if self._dataset is None: + data_root = self._data_dir + ('-raw' if self._version == 'raw' else '') + self._dataset = articles_from_file( + f'{data_root}/wiki.{self._subset}.{self._version}') + + def source(): + n_articles = len(self._dataset) + if self._shuffle_data: + idx = np.random.permutation(n_articles) + else: + idx = np.arange(n_articles) + for i in range(n_articles): + yield self._dataset[idx[i]] + + return source() + + +def normalize_title(title: str) -> str: + """Normalize the wikitext article title by handling special characters.""" + return title.replace( + '@-@', '-').replace('@,@', ',').replace('@.@', '.').replace(' ', '') + + +class WikitextDataset(dataset.Dataset): + """Tokenized dataset for wikitext-103.""" + + def __init__(self, + tokenizer: tokenizers.Tokenizer, + batch_size: int = 1, + timesteps: int = 128, + subset: str = 'train', + shuffle_data: bool = True, + data_dir: str = None, + repeat: bool = False, + debug: bool = False, + **kwargs): + """Constructor. + + Args: + tokenizer: a tokenizer for text data. + batch_size: number of sequences to put into a batch. + timesteps: length of the sequences. + subset: which subset to load, one of {"train", "valid", "test"}. + shuffle_data: if set to True the data will be randomly shuffled. + data_dir: if provided will be used instead of the default `DATA_ROOT` as + the directory that contains the data. + repeat: set to False to go through the data only once, otherwise go + through the data indefinitely. + debug: set to True to only load a small amount of data for fast debugging. + **kwargs: other arguments (for interface compatibility). + """ + super().__init__() + self._tokenizer = tokenizer + self._batch_size = batch_size + self._timesteps = timesteps + self._subset = subset + self._shuffle_data = shuffle_data + self._data_dir = data_dir + self._repeat = repeat + self._debug = debug + self._dataset = None + + def _load_data(self): + """Prepare data for one pass through the dataset.""" + # Pre-tokenize everything in our dataset so we don't have to when going + # through the data more than once. + if not self._dataset: + raw_dataset = RawDataset( + subset=self._subset, shuffle_data=False, data_dir=self._data_dir) + if self._debug: + # Load a small number of examples for debugging. + self._dataset = [ + self._tokenizer.encode(next(raw_dataset).text, prepend_bos=True) + for _ in range(5)] + else: + self._dataset = [self._tokenizer.encode(item.text, prepend_bos=True) + for item in raw_dataset] + logging.info('%s set loaded, total %d examples.', + self._subset, len(self._dataset)) + + def source(): + idx = np.random.permutation(len(self._dataset)) + for i in idx: + yield self._dataset[i] + + def repeated_source(): + if self._repeat: + while True: + yield from source() + else: + yield from source() + + data_iter = tools.dynamic_batch( + repeated_source(), + self._batch_size, + self._timesteps + 1, # Extra token to count for the overlap. + return_incomplete_batch=True, + pad=True, + pad_value=self._tokenizer.pad_token()) + data_iter = map(lambda x: dict( # pylint: disable=g-long-lambda + obs=x['obs'][:, :-1], + target=x['obs'][:, 1:], + should_reset=x['should_reset'][:, :-1], + mask=(x['obs'][:, 1:] != self._tokenizer.pad_token()).astype( + np.float32), + ), data_iter) + return data_iter + + def return_faux_batch(self): + """Return a fake batch with the right shapes and dtypes.""" + obs = np.zeros((self._batch_size, self._timesteps), dtype=np.int32) + target = np.zeros_like(obs, dtype=np.int32) + should_reset = np.zeros_like(obs, dtype=np.float32) + mask = np.zeros_like(obs, dtype=np.float32) + return dict(obs=obs, target=target, should_reset=should_reset, mask=mask) diff --git a/wikigraphs/wikigraphs/data/wikitext_test.py b/wikigraphs/wikigraphs/data/wikitext_test.py new file mode 100644 index 0000000..5658da2 --- /dev/null +++ b/wikigraphs/wikigraphs/data/wikitext_test.py @@ -0,0 +1,84 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WikiGraphs is licensed under the terms of the Creative Commons +# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license. +# +# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the +# terms of the Creative Commons Attribution-ShareAlike 4.0 International +# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at: +# +# https://creativecommons.org/licenses/by-sa/4.0/legalcode +# +# Freebase data is licensed by Google LLC under the terms of the Creative +# Commons CC BY 4.0 license. You may obtain a copy of the License at: +# +# https://creativecommons.org/licenses/by/4.0/legalcode +# +# ============================================================================== +"""Tests for wikigraphs.data.wikitext.""" + +from absl.testing import absltest +from wikigraphs.data import tokenizers +from wikigraphs.data import wikitext + + +WIKITEXT_ROOT = '/tmp/data/wikitext-103' +WIKITEXT_VOCAB_FILE = '/tmp/data/wikitext-vocab.csv' + + +class WikitextTest(absltest.TestCase): + + def test_wikitext_size(self): + valid_set = wikitext.RawDataset( + subset='valid', shuffle_data=False, data_dir=WIKITEXT_ROOT) + n_tokens = 0 + n_articles = 0 + for article in valid_set: + n_tokens += len([t for t in article.text.split(' ') if t]) + n_articles += 1 + + # Dataset size must match published values. + self.assertEqual(n_tokens, 217646) + self.assertEqual(n_articles, 60) + + def test_wikitext_dataset_size(self): + tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE) + batch_size = 4 + timesteps = 256 + valid_set = wikitext.WikitextDataset( + tokenizer=tokenizer, batch_size=batch_size, timesteps=timesteps, + subset='valid', shuffle_data=False, repeat=False, + data_dir=WIKITEXT_ROOT) + n_tokens = 0 + n_bos = 0 + for batch in valid_set: + n_tokens += (batch['obs'] != tokenizer.pad_token()).sum() + n_bos += (batch['obs'] == tokenizer.bos_token()).sum() + self.assertEqual( + batch['obs'].shape, (batch_size, timesteps)) + self.assertEqual( + batch['target'].shape, (batch_size, timesteps)) + self.assertEqual( + batch['should_reset'].shape, (batch_size, timesteps)) + self.assertEqual( + batch['mask'].shape, (batch_size, timesteps)) + + n_tokens -= n_bos + self.assertEqual(n_tokens, 217646) + self.assertEqual(n_bos, 60) + + +if __name__ == '__main__': + absltest.main()