{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 6 (JAX): Transformers and Multi-Head Attention\n", "\n", "![Status](https://img.shields.io/static/v1.svg?label=Status&message=Finished&color=green)\n", "\n", "\n", "**Filled notebook:** \n", "[![View on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/JAX/tutorial6/Transformers_and_MHAttention.ipynb)\n", "[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/JAX/tutorial6/Transformers_and_MHAttention.ipynb) \n", "**Pre-trained models:** \n", "[![View files on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/phlippe/saved_models/tree/main/JAX/tutorial6) \n", "**PyTorch version:**\n", "[![View on RTD](https://img.shields.io/static/v1.svg?logo=readthedocs&label=RTD&message=View%20On%20RTD&color=8CA1AF)](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html) \n", "**Author:** Phillip Lippe" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
 \n", "\n", "**Note:** This notebook is written in JAX+Flax. It is a 1-to-1 translation of the original notebook written in PyTorch+PyTorch Lightning with almost identical results. For an introduction to JAX, check out our [Tutorial 2 (JAX): Introduction to JAX+Flax](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html). Further, throughout the notebook, we comment on major differences from the PyTorch version and provide explanations for the major parts of the JAX code.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "**Speed comparison**: We note the training times for all models in the PyTorch and the JAX implementation below (PyTorch v1.11, JAX v0.3.13). The models were trained on the same hardware (NVIDIA RTX3090, 24-core CPU) and we slightly adjusted the tutorials to use the exact same training settings (same data loading parameters, evaluation schedule, etc.). Overall, the JAX implementation is almost *4x faster* than PyTorch! However, this is mostly due to the small model and input sizes, and the code has not been explicitly designed for benchmarking. With larger models, larger batch sizes, or smaller GPUs, the speedup is expected to become considerably smaller (see e.g. [Tutorial 15](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial15/Vision_Transformer.html)).\n", " \n", "| Models | PyTorch | JAX |\n", "|-------------------|:-----------:|:----------:|\n", "| Reverse Sequence | 0min 26sec | 0min 7sec |\n", "| Anomaly Detection | 16min 34sec | 3min 45sec |\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we will discuss one of the most impactful architectures of recent years: the Transformer model. Since the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) by Vaswani et al. was published in 2017, the Transformer architecture has continued to beat benchmarks in many domains, most importantly in Natural Language Processing. Transformers with an enormous number of parameters can generate long, convincing [essays](https://www.theguardian.com/commentisfree/2020/sep/08/robot-wrote-this-article-gpt-3), and have opened up new application fields of AI. 
As the hype around the Transformer architecture shows no sign of ending in the coming years, it is important to understand how it works and to have implemented it yourself, which we will do in this notebook.\n", "\n", "Despite the huge success of Transformers in NLP, we will _not_ include the NLP domain in our notebook here. Why? Firstly, the Master AI at UvA offers many great NLP courses that will take a closer look at the application of the Transformer architecture in NLP ([NLP2](https://studiegids.uva.nl/xmlpages/page/2020-2021/zoek-vak/vak/79628), [Advanced Topics in Computational Semantics](https://studiegids.uva.nl/xmlpages/page/2020-2021/zoek-vak/vak/80162)). Secondly, assignment 2 already takes a closer look at character-level language generation, to which you could easily apply our Transformer architecture. Finally, and most importantly, there is so much more to the Transformer architecture. NLP is the domain for which the Transformer architecture was originally proposed and on which it has had the greatest impact, but it has also accelerated research in other domains, recently even [Computer Vision](https://arxiv.org/abs/2010.11929). Thus, we focus here on what makes the Transformer and self-attention so powerful in general. In [Tutorial 15](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial15/Vision_Transformer.html), we will discuss the application of Transformers in Computer Vision.\n", "\n", "Below, we import our standard libraries. We use [JAX](https://jax.readthedocs.io/en/latest/) as the acceleration backend, [Flax](https://flax.readthedocs.io/en/latest/index.html) for implementing neural networks, and [Optax](https://optax.readthedocs.io/en/latest/index.html) to optimize the models."
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Device: gpu:0\n" ] } ], "source": [ "## Standard libraries\n", "import os\n", "import numpy as np\n", "import math\n", "import json\n", "from functools import partial\n", "\n", "## Imports for plotting\n", "import matplotlib.pyplot as plt\n", "plt.set_cmap('cividis')\n", "%matplotlib inline \n", "from IPython.display import set_matplotlib_formats\n", "set_matplotlib_formats('svg', 'pdf') # For export\n", "from matplotlib.colors import to_rgb\n", "import matplotlib\n", "matplotlib.rcParams['lines.linewidth'] = 2.0\n", "import seaborn as sns\n", "sns.reset_orig()\n", "\n", "## tqdm for loading bars\n", "from tqdm.auto import tqdm\n", "\n", "## To run JAX on TPU in Google Colab, uncomment the two lines below\n", "# import jax.tools.colab_tpu\n", "# jax.tools.colab_tpu.setup_tpu()\n", "\n", "## JAX\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import random\n", "# Seeding for random operations\n", "main_rng = random.PRNGKey(42)\n", "\n", "## Flax (NN in JAX)\n", "try:\n", " import flax\n", "except ModuleNotFoundError: # Install flax if missing\n", " !pip install --quiet flax\n", " import flax\n", "from flax import linen as nn\n", "from flax.training import train_state, checkpoints\n", "\n", "## Optax (Optimizers in JAX)\n", "try:\n", " import optax\n", "except ModuleNotFoundError: # Install optax if missing\n", " !pip install --quiet optax\n", " import optax\n", "\n", "## PyTorch\n", "import torch\n", "import torch.utils.data as data\n", "from torch.utils.tensorboard import SummaryWriter\n", "import torchvision\n", "from torchvision import transforms\n", "from torchvision.datasets import CIFAR100\n", "\n", "# Path to the folder where the datasets are/should be downloaded (e.g. 
CIFAR10)\n", "DATASET_PATH = \"../../data\"\n", "# Path to the folder where the pretrained models are saved\n", "CHECKPOINT_PATH = \"../../saved_models/tutorial6_jax\"\n", "\n", "print(\"Device:\", jax.devices()[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Two pre-trained models are downloaded below. Make sure to have adjusted your `CHECKPOINT_PATH` before running this code if not already done." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import urllib.request\n", "from urllib.error import HTTPError\n", "# Github URL where saved models are stored for this tutorial\n", "base_url = \"https://raw.githubusercontent.com/phlippe/saved_models/main/JAX/tutorial6/\"\n", "# Files to download\n", "pretrained_files = [\"ReverseTask.ckpt\", \"SetAnomalyTask.ckpt\"]\n", "\n", "# Create checkpoint path if it doesn't exist yet\n", "os.makedirs(CHECKPOINT_PATH, exist_ok=True)\n", "\n", "# For each file, check whether it already exists. If not, try downloading it.\n", "for file_name in pretrained_files:\n", " file_path = os.path.join(CHECKPOINT_PATH, file_name)\n", " if \"/\" in file_name:\n", " os.makedirs(file_path.rsplit(\"/\",1)[0], exist_ok=True)\n", " if not os.path.isfile(file_path):\n", " file_url = base_url + file_name\n", " print(f\"Downloading {file_url}...\")\n", " try:\n", " urllib.request.urlretrieve(file_url, file_path)\n", " except HTTPError as e:\n", " print(\"Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\\n\", e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Transformer architecture\n", "\n", "In the first part of this notebook, we will implement the Transformer architecture by hand. 
As the architecture is so popular, its main components are already integrated into Flax ([SelfAttention](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.SelfAttention.html#flax.linen.SelfAttention), [MultiHeadAttention](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.MultiHeadDotProductAttention.html#flax.linen.MultiHeadDotProductAttention)) and there exist several implementations (e.g. in [Trax](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)) and pre-trained models (e.g. on [Hugging Face](https://huggingface.co/docs/transformers/index)). However, we will implement it ourselves here to understand it down to the smallest details.\n", "\n", "There are of course many more tutorials out there about attention and Transformers. Below, we list a few that are worth exploring if you are interested in the topic and want yet another perspective after this one:\n", "\n", "* [Transformer: A Novel Neural Network Architecture for Language Understanding (Jakob Uszkoreit, 2017)](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html) - The original Google blog post about the Transformer paper, focusing on the application in machine translation.\n", "* [The Illustrated Transformer (Jay Alammar, 2018)](http://jalammar.github.io/illustrated-transformer/) - A very popular blog post intuitively explaining the Transformer architecture with many nice visualizations. The focus is on NLP.\n", "* [Attention? Attention! (Lilian Weng, 2018)](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html) - A nice blog post summarizing attention mechanisms in many domains, including vision.\n", "* [Illustrated: Self-Attention (Raimi Karim, 2019)](https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a) - A nice visualization of the steps of self-attention. 
Recommended if the explanation below is too abstract for you.\n", "* [The Transformer family (Lilian Weng, 2020)](https://lilianweng.github.io/lil-log/2020/04/07/the-transformer-family.html) - A very detailed blog post reviewing more variants of Transformers besides the original one." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What is Attention?\n", "\n", "The attention mechanism describes a relatively recent group of layers in neural networks that has attracted a lot of interest in the past few years, especially in sequence tasks. There are many different possible definitions of \"attention\" in the literature, but the one we will use here is the following: _the attention mechanism describes a weighted average of (sequence) elements with the weights dynamically computed based on an input query and elements' keys_. So what exactly does this mean? The goal is to take an average over the features of multiple elements. However, instead of weighting each element equally, we want to weight them depending on their actual values. In other words, we want to dynamically decide which inputs we want to \"attend\" to more than others. In particular, an attention mechanism usually has four parts we need to specify:\n", "\n", "* **Query**: The query is a feature vector that describes what we are looking for in the sequence, i.e. what we might want to pay attention to.\n", "* **Keys**: For each input element, we have a key which is again a feature vector. This feature vector roughly describes what the element is \"offering\", or when it might be important. The keys should be designed such that we can identify the elements we want to pay attention to based on the query.\n", "* **Values**: For each input element, we also have a value vector. This feature vector is the one we want to average over.\n", "* **Score function**: To rate which elements we want to pay attention to, we need to specify a score function $f_{attn}$. 
The score function takes the query and a key as input, and outputs the score/attention weight of the query-key pair. It is usually implemented by simple similarity metrics like a dot product, or a small MLP.\n", "\n", "\n", "The weights of the average are calculated by a softmax over all score function outputs. Hence, we assign higher weights to those value vectors whose corresponding key is most similar to the query. If we try to describe it with pseudo-math, we can write: \n", "\n", "$$\n", "\\alpha_i = \\frac{\\exp\\left(f_{attn}\\left(\\text{key}_i, \\text{query}\\right)\\right)}{\\sum_j \\exp\\left(f_{attn}\\left(\\text{key}_j, \\text{query}\\right)\\right)}, \\hspace{5mm} \\text{out} = \\sum_i \\alpha_i \\cdot \\text{value}_i\n", "$$\n", "\n", "Visually, we can show the attention over a sequence of words as follows:\n", "\n", "
 \n", "\n", "For every word, we have one key and one value vector. The query is compared to all keys with a score function (in this case the dot product) to determine the weights. The softmax is not visualized for simplicity. Finally, the value vectors of all words are averaged using the attention weights.\n", "\n", "Most attention mechanisms differ in terms of what queries they use, how the key and value vectors are defined, and what score function is used. The attention applied inside the Transformer architecture is called **self-attention**. In self-attention, each sequence element provides a key, value, and query. For each element, we apply attention in which, based on its query, we check the similarity to all sequence elements' keys and return a different, averaged value vector for each element. We will now go into a bit more detail by first looking at the specific implementation of the attention mechanism, which in the Transformer case is the scaled dot product attention." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Scaled Dot Product Attention\n", "\n", "The core concept behind self-attention is the scaled dot product attention. Our goal is to have an attention mechanism with which any element in a sequence can attend to any other while still being efficient to compute. The dot product attention takes as input a set of queries $Q\\in\\mathbb{R}^{T\\times d_k}$, keys $K\\in\\mathbb{R}^{T\\times d_k}$ and values $V\\in\\mathbb{R}^{T\\times d_v}$ where $T$ is the sequence length, and $d_k$ and $d_v$ are the hidden dimensionalities for queries/keys and values, respectively. For simplicity, we neglect the batch dimension for now. The attention value from element $i$ to $j$ is based on the similarity of the query $Q_i$ and key $K_j$, using the dot product as the similarity metric. 
In math, we calculate the dot product attention as follows:\n", "\n", "$$\\text{Attention}(Q,K,V)=\\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$\n", "\n", "The matrix multiplication $QK^T$ performs the dot product for every possible pair of queries and keys, resulting in a matrix of the shape $T\\times T$. Each row represents the attention logits for a specific element $i$ to all other elements in the sequence. On these, we apply a softmax and multiply with the value matrix $V$ to obtain a weighted mean (the weights being determined by the attention). The computation graph visualized below offers another perspective on this attention mechanism (figure credit - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)).\n", "\n", "
 \n", "\n", "One aspect we haven't discussed yet is the scaling factor of $1/\\sqrt{d_k}$. This scaling factor is crucial to maintain an appropriate variance of attention values after initialization. Remember that we initialize our layers with the intention of having equal variance throughout the model, and hence, $Q$ and $K$ might also have a variance close to $1$. However, performing a dot product over two vectors whose entries are independent, zero-mean, and of variance $\\sigma^2$ results in a scalar having $d_k$-times higher variance: \n", "\n", "$$q_i \\sim \\mathcal{N}(0,\\sigma^2), k_i \\sim \\mathcal{N}(0,\\sigma^2) \\to \\text{Var}\\left(\\sum_{i=1}^{d_k} q_i\\cdot k_i\\right) = \\sigma^4\\cdot d_k$$\n", "\n", "\n", "If we do not scale the variance back down to $\\sim\\sigma^2$, the softmax over the logits will already saturate to $1$ for one random element and $0$ for all others. The gradients through the softmax will be close to zero so that we can't learn the parameters appropriately. Note that the extra factor of $\\sigma^2$, i.e., having $\\sigma^4$ instead of $\\sigma^2$, is usually not an issue, since we keep the original variance $\\sigma^2$ close to $1$ anyway.\n", "\n", "The block `Mask (opt.)` in the diagram above represents the optional masking of specific entries in the attention matrix. This is, for instance, used if we stack multiple sequences with different lengths into a batch. To still benefit from parallelization, we pad the sequences to the same length and mask out the padding tokens during the calculation of the attention values. This is usually done by setting the respective attention logits to a very large negative value, so that their softmax weights become effectively zero. 
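
To make the variance argument above concrete, the following small NumPy sketch (not part of the original tutorial; the seed, sample size, and the `softmax` helper are illustrative choices) samples $q$ and $k$ with independent standard normal entries. It compares the variance of their dot product before and after dividing by sqrt(d_k), and checks how peaked the softmax over 16 logits becomes without the scaling.

```python
import numpy as np

# Illustrative check, assuming i.i.d. entries q_i, k_i ~ N(0, sigma^2)
rng = np.random.default_rng(0)
sigma, d_k, n_samples = 1.0, 64, 20_000

q = rng.normal(0.0, sigma, size=(n_samples, d_k))
k = rng.normal(0.0, sigma, size=(n_samples, d_k))

dots = np.sum(q * k, axis=-1)    # one unscaled dot product per sample
scaled = dots / np.sqrt(d_k)     # the same values divided by sqrt(d_k)

var_unscaled = dots.var()        # should be close to sigma**4 * d_k = 64
var_scaled = scaled.var()        # should be close to sigma**4 = 1
print('variance unscaled:', round(float(var_unscaled), 2))
print('variance scaled:  ', round(float(var_scaled), 2))


def softmax(x):
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Attention weights of one query over T = 16 keys, with and without scaling
logits = k[:16] @ q[0]
max_w_unscaled = softmax(logits).max()
max_w_scaled = softmax(logits / np.sqrt(d_k)).max()
print('largest weight unscaled:', round(float(max_w_unscaled), 3))
print('largest weight scaled:  ', round(float(max_w_scaled), 3))
```

Dividing the logits by sqrt(d_k) acts like raising the softmax temperature, so the largest attention weight of the unscaled logits is always at least as large as that of the scaled ones; with d_k = 64 the unscaled distribution is typically close to one-hot.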
 \n", "\n", "After we have discussed the details of the scaled dot product attention block, we can write a function that computes the output features given the triple of queries, keys, and values:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def scaled_dot_product(q, k, v, mask=None):\n", "    d_k = q.shape[-1]\n", "    # Attention logits QK^T / sqrt(d_k), shape [..., SeqLen, SeqLen]\n", "    attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1))\n", "    attn_logits = attn_logits / math.sqrt(d_k)\n", "    if mask is not None:\n", "        # Entries with mask == 0 get a very large negative logit (weight ~0 after softmax)\n", "        attn_logits = jnp.where(mask == 0, -9e15, attn_logits)\n", "    attention = nn.softmax(attn_logits, axis=-1)\n", "    values = jnp.matmul(attention, v)\n", "    return values, attention" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that our code above supports any additional dimensionality in front of the sequence length so that we can also use it for batches. However, for a better understanding, let's generate a few random queries, keys, and value vectors, and calculate the attention outputs:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Q\n", " [[-0.6613315 0.70056266]\n", " [ 0.08239268 -1.7793142 ]\n", " [-0.04378588 1.0965251 ]]\n", "K\n", " [[ 1.7257481 0.35568172]\n", " [ 1.3034704 1.2873708 ]\n", " [ 1.6871481 -0.5714404 ]]\n", "V\n", " [[ 1.5129997 1.1050899 ]\n", " [ 0.27949408 -0.46224892]\n", " [-1.1003422 -1.1437942 ]]\n", "Values\n", " [[ 0.376226 -0.14656176]\n", " [-0.42778552 -0.5989564 ]\n", " [ 0.4362476 -0.11678296]]\n", "Attention\n", " [[0.27963293 0.54049295 0.17987415]\n", " [0.22194655 0.06706189 0.71099156]\n", " [0.27977085 0.58373076 0.13649833]]\n" ] } ], "source": [ "seq_len, d_k = 3, 2\n", "main_rng, rand1 = random.split(main_rng)\n", "qkv = 
random.normal(rand1, (3, seq_len, d_k))\n", "q, k, v = qkv[0], qkv[1], qkv[2]\n", "values, attention = scaled_dot_product(q, k, v)\n", "print(\"Q\\n\", q)\n", "print(\"K\\n\", k)\n", "print(\"V\\n\", 
v)\n", "print(\"Values\\n\", values)\n", "print(\"Attention\\n\", attention)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before continuing, make sure you can follow the calculation of the specific values here, and also verify them by hand. It is important to fully understand how the scaled dot product attention is calculated." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multi-Head Attention\n", "\n", "The scaled dot product attention allows a network to attend over a sequence. However, often there are multiple different aspects a sequence element wants to attend to, and a single weighted average cannot capture all of them. This is why we extend the attention mechanism to multiple heads, i.e. multiple different query-key-value triplets on the same features. Specifically, given a query, key, and value matrix, we transform those into $h$ sub-queries, sub-keys, and sub-values, which we pass through the scaled dot product attention independently. Afterward, we concatenate the heads and combine them with a final weight matrix. Mathematically, we can express this operation as:\n", "\n", "$$\n", "\\begin{split}\n", " \\text{Multihead}(Q,K,V) & = \\text{Concat}(\\text{head}_1,...,\\text{head}_h)W^{O}\\\\\n", " \\text{where } \\text{head}_i & = \\text{Attention}(QW_i^Q,KW_i^K, VW_i^V)\n", "\\end{split}\n", "$$\n", "\n", "We refer to this as the Multi-Head Attention layer with the learnable parameters $W_{1...h}^{Q}\\in\\mathbb{R}^{D\\times d_k}$, $W_{1...h}^{K}\\in\\mathbb{R}^{D\\times d_k}$, $W_{1...h}^{V}\\in\\mathbb{R}^{D\\times d_v}$, and $W^{O}\\in\\mathbb{R}^{h\\cdot d_v\\times d_{out}}$ ($D$ being the input dimensionality). Expressed in a computational graph, we can visualize it as below (figure credit - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)).\n", "\n", "
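
The Multihead equation above can be checked against the fused layout that the `MultiheadAttention` module below uses for efficiency. The following NumPy sketch is illustrative only: the sizes and random weights are hypothetical, biases are omitted, and $d_k = d_v$ is assumed, as in the module. It evaluates Multihead(X, X, X) once literally, with one attention call per head followed by concatenation and $W^O$, and once by placing each head's $W_i^Q$, $W_i^K$, $W_i^V$ side by side in a single matrix, projecting once, and reshaping the result into heads.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def attention(q, k, v):
    # softmax(q k^T / sqrt(d_k)) v, applied over any leading (head) axes
    d_k = q.shape[-1]
    logits = q @ np.swapaxes(k, -2, -1) / np.sqrt(d_k)
    return softmax(logits) @ v


rng = np.random.default_rng(0)
T, D, h, d = 5, 8, 2, 4   # hypothetical sizes; d_k = d_v = d, output dim = D
X = rng.normal(size=(T, D))
W_Q, W_K, W_V = (rng.normal(size=(h, D, d)) for _ in range(3))
W_O = rng.normal(size=(h * d, D))

# 1) Literal reading of the equation: one attention call per head, then concat
heads = [attention(X @ W_Q[i], X @ W_K[i], X @ W_V[i]) for i in range(h)]
out_loop = np.concatenate(heads, axis=-1) @ W_O

# 2) Fused layout: for every head, put W_i^Q | W_i^K | W_i^V side by side,
#    project X once, then reshape into heads and split into q, k, v
W_qkv = np.concatenate(
    [np.concatenate([W_Q[i], W_K[i], W_V[i]], axis=-1) for i in range(h)],
    axis=-1)                                               # (D, h * 3d)
qkv = (X @ W_qkv).reshape(T, h, 3 * d).transpose(1, 0, 2)  # (h, T, 3d)
q, k, v = np.split(qkv, 3, axis=-1)                        # each (h, T, d)
values = attention(q, k, v).transpose(1, 0, 2).reshape(T, h * d)
out_fused = values @ W_O

print(np.allclose(out_loop, out_fused))   # True
```

Both variants compute the same function; the fused one needs a single matrix multiplication for all heads and all of $Q$, $K$, $V$, which is why the module below uses one `nn.Dense(3*self.embed_dim)` layer followed by a reshape instead of separate per-head layers.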
 \n", "\n", "How are we applying a Multi-Head Attention layer in a neural network, where we don't have an arbitrary query, key, and value vector as input? Looking at the computation graph above, a simple but effective implementation is to set the current feature map in the network, $X\\in\\mathbb{R}^{B\\times T\\times d_{\\text{model}}}$, as $Q$, $K$ and $V$ ($B$ being the batch size, $T$ the sequence length, $d_{\\text{model}}$ the hidden dimensionality of $X$). The consecutive weight matrices $W^{Q}$, $W^{K}$, and $W^{V}$ can transform $X$ into the corresponding feature vectors that represent the queries, keys, and values of the input. Note that these weight matrices are commonly initialized with Xavier initialization. However, the layer is usually not too sensitive to the initialization, as long as the variance of $Q$ and $K$ does not become too large.\n", "With this in mind, we can implement the Multi-Head Attention module below." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Helper function to support different mask shapes.\n", "# Output shape supports (batch_size, number of heads, seq length, seq length)\n", "# If 2D: broadcasted over batch size and number of heads\n", "# If 3D: broadcasted over number of heads\n", "# If 4D: leave as is\n", "def expand_mask(mask):\n", "    assert mask.ndim >= 2, \"Mask must be at least 2-dimensional with seq_length x seq_length\"\n", "    if mask.ndim == 3:\n", "        # .unsqueeze() is PyTorch-only; JAX uses jnp.expand_dims\n", "        mask = jnp.expand_dims(mask, axis=1)\n", "    while mask.ndim < 4:\n", "        mask = jnp.expand_dims(mask, axis=0)\n", "    return mask" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class MultiheadAttention(nn.Module):\n", "    embed_dim : int  # Output dimension\n", "    num_heads : int  # Number of parallel heads (h)\n", "    \n", "    def setup(self):\n", "        # Stack all weight 
matrices 1...h and W^Q, W^K, W^V together for efficiency\n", "        # Note that many implementations disable the bias (use_bias=False in Flax), which is optional\n", "
        self.qkv_proj = nn.Dense(3*self.embed_dim,\n", "                                 kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init\n", "                                 bias_init=nn.initializers.zeros  # Bias init with zeros\n", "                                )\n", "        self.o_proj = nn.Dense(self.embed_dim,\n", "                               kernel_init=nn.initializers.xavier_uniform(),\n", "                               bias_init=nn.initializers.zeros)\n", "\n", "    def __call__(self, x, mask=None):\n", "        batch_size, seq_length, _ = x.shape\n", "        if mask is not None:\n", "            mask = expand_mask(mask)\n", "        qkv = self.qkv_proj(x)\n", "        \n", "        # Separate Q, K, V from linear output\n", "        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, -1)\n", "        qkv = qkv.transpose(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]\n", "        q, k, v = jnp.array_split(qkv, 3, axis=-1)\n", "        \n", "        # Determine value outputs\n", "        values, attention = scaled_dot_product(q, k, v, mask=mask)\n", "        values = values.transpose(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]\n", "        # Merge the heads into the output dimension (self.embed_dim, not the input dimension)\n", "        values = values.reshape(batch_size, seq_length, self.embed_dim)\n", "        o = self.o_proj(values)\n", "        \n", "        return o, attention" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Out (3, 16, 128) Attention (3, 4, 16, 16)\n" ] } ], "source": [ "## Test MultiheadAttention implementation\n", "# Example features as input\n", "main_rng, x_rng = random.split(main_rng)\n", "x = random.normal(x_rng, (3, 16, 128))\n", "# Create attention\n", "mh_attn = MultiheadAttention(embed_dim=128, num_heads=4)\n", "# Initialize parameters of attention with random key and inputs\n", "main_rng, init_rng = random.split(main_rng)\n", "params = mh_attn.init(init_rng, x)['params']\n", "# Apply 
attention with parameters on the inputs\n", "out, attn = mh_attn.apply({'params': params}, x)\n", "print('Out', out.shape, 'Attention', attn.shape)\n", "\n", "del mh_attn, params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One crucial characteristic of the multi-head attention is that it 
is permutation-equivariant with respect to its inputs. This means that if we switch two input elements in the sequence, e.g. $X_1\\leftrightarrow X_2$ (neglecting the batch dimension for now), the output is exactly the same apart from elements 1 and 2 being switched. Hence, the multi-head attention actually treats the input not as a sequence, but as a set of elements. This property makes the multi-head attention block and the Transformer architecture so powerful and widely applicable! But what if the order of the input is actually important for solving the task, as in language modeling? The answer is to encode the position in the input features, which we will take a closer look at later (topic _Positional encodings_ below).\n", "\n", "Before moving on to creating the Transformer architecture, we can compare the self-attention operation with its common competitors for sequence data: convolutions and recurrent neural networks. Below you can find a table by [Vaswani et al. (2017)](https://arxiv.org/abs/1706.03762) on the complexity per layer, the number of sequential operations, and the maximum path length. The complexity is measured by the upper bound of the number of operations to perform, while the maximum path length represents the maximum number of steps a forward or backward signal has to traverse to reach any other position. The lower this length, the better gradient signals can backpropagate for long-range dependencies. Let's take a look at the table below:\n", "\n", "\n", "
 \n", "\n", "$n$ is the sequence length, $d$ is the representation dimension and $k$ is the kernel size of convolutions. In contrast to recurrent networks, the self-attention layer can parallelize all its operations, making it much faster to execute for smaller sequence lengths. However, when the sequence length exceeds the hidden dimensionality, self-attention becomes more expensive than RNNs, since its per-layer complexity is $O(n^2\\cdot d)$ compared to $O(n\\cdot d^2)$ for a recurrent layer. One way of reducing the computational cost for long sequences is to restrict the self-attention to a neighborhood of inputs to attend over, denoted by $r$. Beyond that, there has recently been a lot of work on more efficient Transformer architectures that still allow long-range dependencies; if you are interested, you can find an overview in the paper by [Tay et al. (2020)](https://arxiv.org/abs/2009.06732)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Transformer Encoder\n", "\n", "Next, we will look at how to apply the multi-head attention block inside the Transformer architecture. Originally, the Transformer model was designed for machine translation. Hence, it has an encoder-decoder structure where the encoder takes as input the sentence in the original language and generates an attention-based representation. The decoder, on the other hand, attends over the encoded information and generates the translated sentence in an autoregressive manner, as in a standard RNN. While this structure is extremely useful for sequence-to-sequence tasks that require autoregressive decoding, we will focus here on the encoder part. Many advances have been made using pure encoder-based Transformer models (if interested, models include the [BERT](https://arxiv.org/abs/1810.04805) family, the [Vision Transformer](https://arxiv.org/abs/2010.11929), and more). If you have understood the encoder architecture, the decoder is a very small step to implement as well. 
The full Transformer architecture looks as follows (figure credit - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)):\n", "\n", "
 \n", "\n", "The encoder consists of $N$ identical blocks that are applied in sequence. The input $x$ is first passed through a Multi-Head Attention block as we have implemented above. The output is added to the original input using a residual connection, and a Layer Normalization is applied to the sum. Overall, it calculates $\\text{LayerNorm}(x+\\text{Multihead}(x,x,x))$ ($x$ serving as the $Q$, $K$ and $V$ input to the attention layer). The residual connection is crucial in the Transformer architecture for two reasons: \n", "\n", "1. Similar to ResNets, Transformers are designed to be very deep. Some models contain more than 24 blocks in the encoder. Hence, the residual connections are crucial for enabling a smooth gradient flow through the model.\n", "2. Without the residual connection, the information about the original sequence is lost. Remember that the Multi-Head Attention layer ignores the position of elements in a sequence, and can only learn it based on the input features. Removing the residual connections would mean that this information is lost after the first attention layer (after initialization), and with randomly initialized query and key vectors, the output vector for position $i$ has no relation to its original input. All outputs of the attention are likely to represent similar/same information, and there is no chance for the model to distinguish which information came from which input element. An alternative to the residual connection would be to fix at least one head to focus on its original input, but this is very inefficient and does not have the benefit of the improved gradient flow.\n", "\n", "The Layer Normalization also plays an important role in the Transformer architecture as it enables faster training and provides some regularization. Additionally, it ensures that the features have a similar magnitude across the elements in the sequence. 
We are not using Batch Normalization because it depends on the batch size, which is often small with Transformers (they require a lot of GPU memory), and BatchNorm has been shown to perform particularly poorly in language tasks, as the features of words tend to have a much higher variance (there are many very rare words that need to be considered for a good distribution estimate).\n", "\n", "In addition to the Multi-Head Attention, a small fully connected feed-forward network is added to the model, which is applied to each position separately and identically. Specifically, the model uses a Linear$\\to$ReLU$\\to$Linear MLP. The full transformation including the residual connection can be expressed as: \n", "\n", "$$\n", "\\begin{split}\n", " \\text{FFN}(x) & = \\max(0, xW_1+b_1)W_2 + b_2\\\\\n", " x & = \\text{LayerNorm}(x + \\text{FFN}(x))\n", "\\end{split}\n", "$$\n", "\n", "This MLP adds extra complexity to the model and allows transformations on each sequence element separately. One can think of this as allowing the model to \"post-process\" the new information added by the previous Multi-Head Attention and prepare it for the next attention block. Usually, the inner dimensionality of the MLP is 2-8$\\times$ larger than $d_{\\text{model}}$, i.e. the dimensionality of the original input $x$. The general advantage of a wider layer over a narrow, multi-layer MLP is the faster, parallelizable execution.\n", "\n", "Finally, after looking at all parts of the encoder architecture, we can start implementing it below. We start by implementing a single encoder block. In addition to the layers described above, we will add dropout layers in the MLP and on the output of the MLP and Multi-Head Attention for regularization."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class EncoderBlock(nn.Module):\n", " input_dim : int # Input dimension is needed here since it is equal to the output dimension (residual connection)\n", " num_heads : int\n", " dim_feedforward : int\n", " dropout_prob : float\n", " \n", " def setup(self):\n", " # Attention layer\n", " self.self_attn = MultiheadAttention(embed_dim=self.input_dim, \n", " num_heads=self.num_heads)\n", " # Two-layer MLP\n", " self.linear = [\n", " nn.Dense(self.dim_feedforward),\n", " nn.Dropout(self.dropout_prob),\n", " nn.relu,\n", " nn.Dense(self.input_dim)\n", " ]\n", " # Layers to apply in between the main layers\n", " self.norm1 = nn.LayerNorm()\n", " self.norm2 = nn.LayerNorm()\n", " self.dropout = nn.Dropout(self.dropout_prob)\n", "\n", " def __call__(self, x, mask=None, train=True):\n", " # Attention part\n", " attn_out, _ = self.self_attn(x, mask=mask)\n", " x = x + self.dropout(attn_out, deterministic=not train)\n", " x = self.norm1(x)\n", " \n", " # MLP part\n", " linear_out = x\n", " for l in self.linear:\n", " linear_out = l(linear_out) if not isinstance(l, nn.Dropout) else l(linear_out, deterministic=not train)\n", " x = x + self.dropout(linear_out, deterministic=not train)\n", " x = self.norm2(x)\n", " \n", " return x" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Out (3, 16, 128)\n" ] } ], "source": [ "## Test EncoderBlock implementation\n", "# Example features as input\n", "main_rng, x_rng = random.split(main_rng)\n", "x = random.normal(x_rng, (3, 16, 128))\n", "# Create encoder block\n", "encblock = EncoderBlock(input_dim=128, num_heads=4, dim_feedforward=512, dropout_prob=0.1)\n", "# Initialize parameters of encoder block with random key and inputs\n", "main_rng, init_rng, dropout_init_rng = random.split(main_rng, 3)\n", "params = encblock.init({'params': init_rng, 
'dropout': dropout_init_rng}, x, train=True)['params']\n", "# Apply encoder block with parameters on the inputs\n", "# Since dropout is stochastic, we need to pass an rng to the forward pass\n", "main_rng, dropout_apply_rng = random.split(main_rng)\n", "out = encblock.apply({'params': params}, x, train=True, rngs={'dropout': dropout_apply_rng})\n", "print('Out', out.shape)\n", "\n", "del encblock, params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on this block, we can implement a module for the full Transformer encoder. In addition to a forward function that iterates through the sequence of encoder blocks, we also provide a function called `get_attention_maps`, which returns the attention probabilities of all Multi-Head Attention blocks in the encoder. This helps us understand, and in a sense explain, the model. However, the attention probabilities should be interpreted with a grain of salt, as they do not necessarily reflect how the model actually arrives at its predictions (there is a series of papers about this, including [Attention is not Explanation](https://arxiv.org/abs/1902.10186) and [Attention is not not Explanation](https://arxiv.org/abs/1908.04626))." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class TransformerEncoder(nn.Module):\n", " num_layers : int\n", " input_dim : int\n", " num_heads : int\n", " dim_feedforward : int\n", " dropout_prob : float\n", " \n", " def setup(self):\n", " self.layers = [EncoderBlock(self.input_dim, self.num_heads, self.dim_feedforward, self.dropout_prob) for _ in range(self.num_layers)]\n", "\n", " def __call__(self, x, mask=None, train=True):\n", " for l in self.layers:\n", " x = l(x, mask=mask, train=train)\n", " return x\n", "\n", " def get_attention_maps(self, x, mask=None, train=True):\n", " # A function to return the attention maps within the model for a single application\n", " # Used for visualization purpose later\n", " attention_maps = []\n", " for l in self.layers:\n", " _, attn_map = l.self_attn(x, mask=mask)\n", " attention_maps.append(attn_map)\n", " x = l(x, mask=mask, train=train)\n", " return attention_maps" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Out (3, 16, 128)\n", "Attention maps 5 (3, 4, 16, 16)\n" ] } ], "source": [ "## Test TransformerEncoder implementation\n", "# Example features as input\n", "main_rng, x_rng = random.split(main_rng)\n", "x = random.normal(x_rng, (3, 16, 128))\n", "# Create Transformer encoder\n", "transenc = TransformerEncoder(num_layers=5, \n", " input_dim=128,\n", " num_heads=4,\n", " dim_feedforward=256,\n", " dropout_prob=0.15)\n", "# Initialize parameters of transformer with random key and inputs\n", "main_rng, init_rng, dropout_init_rng = random.split(main_rng, 3)\n", "params = transenc.init({'params': init_rng, 'dropout': dropout_init_rng}, x, train=True)['params']\n", "# Apply transformer with parameters on the inputs\n", "# Since dropout is stochastic, we need to pass a rng to the forward\n", "main_rng, dropout_apply_rng = random.split(main_rng)\n", "# Instead of passing params and rngs 
every time to a function call, we can bind them to the module\n", "binded_mod = transenc.bind({'params': params}, rngs={'dropout': dropout_apply_rng})\n", "out = binded_mod(x, train=True)\n", "print('Out', out.shape)\n", "attn_maps = binded_mod.get_attention_maps(x, train=True)\n", "print('Attention maps', len(attn_maps), attn_maps[0].shape)\n", "\n", "del transenc, binded_mod, params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Positional encoding\n", "\n", "We have discussed before that the Multi-Head Attention block is permutation-equivariant and cannot distinguish whether an input comes before another one in the sequence or not. In tasks like language understanding, however, the position is important for interpreting the input words. The position information can therefore be added via the input features. We could learn an embedding for every possible position, but this would not generalize to dynamic input sequence lengths, in particular to positions not seen during training. Hence, the better option is to use feature patterns that the network can identify from the features and that can potentially generalize to longer sequences. The specific pattern chosen by Vaswani et al. consists of sine and cosine functions of different frequencies:\n", "\n", "$$\n", "PE_{(pos,i)} = \\begin{cases}\n", " \\sin\\left(\\frac{pos}{10000^{i/d_{\\text{model}}}}\\right) & \\text{if}\\hspace{3mm} i \\text{ mod } 2=0\\\\\n", " \\cos\\left(\\frac{pos}{10000^{(i-1)/d_{\\text{model}}}}\\right) & \\text{otherwise}\\\\\n", "\\end{cases}\n", "$$\n", "\n", "$PE_{(pos,i)}$ represents the positional encoding at position $pos$ in the sequence and hidden dimension $i$. These values, concatenated over all hidden dimensions, are added to the original input features (see \"Positional encoding\" in the Transformer visualization above) and constitute the position information. We distinguish between even ($i \\text{ mod } 2=0$) and odd ($i \\text{ mod } 2=1$) hidden dimensions, to which we apply a sine and a cosine, respectively. The intuition behind this encoding is that $PE_{(pos+k,:)}$ can be represented as a linear function of $PE_{(pos,:)}$, which might allow the model to easily attend to relative positions. The wavelengths in different dimensions range from $2\\pi$ to $10000\\cdot 2\\pi$.\n", "\n", "The positional encoding is implemented below. Since JAX arrays are immutable, we construct the encoding with NumPy, which supports in-place assignment, and then transfer it to the default device with `jax.device_put`." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class PositionalEncoding(nn.Module):\n", "    d_model : int         # Hidden dimensionality of the input.\n", "    max_len : int = 5000  # Maximum length of a sequence to expect.\n", "\n", "    def setup(self):\n", "        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs\n", "        pe = np.zeros((self.max_len, self.d_model))\n", "        position = np.arange(0, self.max_len, dtype=np.float32)[:,None]\n", "        div_term = np.exp(np.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model))\n", "        pe[:, 0::2] = np.sin(position * div_term)\n", "        pe[:, 1::2] = np.cos(position * div_term)\n", "        pe = pe[None]\n", "        self.pe = jax.device_put(pe)\n", "\n", "    def __call__(self, x):\n", "        x = x + self.pe[:, :x.shape[1]]\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand the positional encoding, we visualize it below as an image over the hidden dimensions and the positions in a sequence. Each pixel therefore represents the value added to the corresponding input feature to encode that specific position." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1BhZ2VzIDIgMCBSIC9UeXBlIC9DYXRhbG9nID4+CmVuZG9iago4IDAgb2JqCjw8IC9FeHRHU3RhdGUgNCAwIFIgL0ZvbnQgMyAwIFIgL1BhdHRlcm4gNSAwIFIKL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gL1NoYWRpbmcgNiAwIFIKL1hPYmplY3QgNyAwIFIgPj4KZW5kb2JqCjEwIDAgb2JqCjw8IC9Bbm5vdHMgWyBdIC9Db250ZW50cyA5IDAgUgovR3JvdXAgPDwgL0NTIC9EZXZpY2VSR0IgL1MgL1RyYW5zcGFyZW5jeSAvVHlwZSAvR3JvdXAgPj4KL01lZGlhQm94IFsgMCAwIDQ0Mi4wNjUyNSAyMjIuOTQ4NzUgXSAvUGFyZW50IDIgMCBSIC9SZXNvdXJjZXMgOCAwIFIKL1R5cGUgL1BhZ2UgPj4KZW5kb2JqCjkgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMSAwIFIgPj4Kc3RyZWFtCnic1VZNc9s2EL3jV+DYHrzaxTeO8bh121NTa9Kzx2JkeUxqFNXJ3++CkkgApMV0JpcerJGegX37sJ8kX8TqA8ntUaJ84b9vkuS9XN01X3dPzV/3t/LpKJDxVhijAJ1Vln+9Zr+UUhBN8JZRLH49C9EJts0X7tnsVgiDYCM6vqU9WKP5WCu0c+Btib7mqEIEp07waCFHmemzOMgZ81o5ICfJacAgvzTyb9nJ1QeVFBMrJlaME8UHvudl0k3OzJl9auXqd5J3e/lRfJSHi0VksckqQjjbZWRedQYawLNoccvv9U3cruXqV5KEcv1Z8B1DGNOB9LZWrjfiJ/pZrl/kL2vRMwlPLDJQyZCB1xmcgYChtI8lAaEFMrpiyNHrFDGAD6UEVVPoCES1iBy9TkFaccKUMnTN4TkN4kRGhi5wpKTUpQ5TcSj0XBm1jhy9zpHy2lfhsDWHoWSj5sjQBQ7+v6vi4WoObwEn8cjRBQ4XwFXx8BWHxgBxEo8cXagNVOCqeISagztVnMQjRxc4tANXxSMOHGV6WAVaG5cOR1C+N5/O/7k/7v7Z7Tu56+SxObw13VNTejnX19pU+zjtdrOv7UH1L0bRQVC9CzONYqQhjyw7fctpRvS95ITgKRlPpaBM35bmGkZGxO+HnmqiAV0k4rpW3lyI6raREXEyYJwoGtAlIu5RKg6C6tYx8nBek6r1DOASCzdbrQY1Zj6ReO7IG0rjx4DvXeIWFZQ9edf2Ezld/2232TSd3Ozapjtyfl2MKfnHaZ73c6icbQvTqEi0h9nh3L43nPn8f5jwxeki69+1jr2ubTa5t8PDqfRwPPOso/6iO1/M6+/xVXLp7Te7biv3X5sv8rl6wOPlBdNQ1yFyjpZuc5itmV8nuHLotD8JMXe1LVA2FJA7RFp1RphihBicp1ji+TvoyLWuzBW4tDLiBWkGjzF5TtmbLTTXVqRwXpB0eqhk+7T39auROq1G2Vp0WYm2ObP1EFVAq/sU5LlXw0M1HUQqh5tUGAbNcN8AaSLvrdOKmcsqCqB9OKXXqVyYwKYACw7wQfTxShKGL2yAb9/cNS+Pn94eHrvjTbvr3o6sRfCad5kno/ueGZjfqtL9EV5w3xuw5JTnQHn6HvdPe8CPcp8bnovcVSr3R3jB/WggImnPmwTG73Ff/dDXJ+LyCFpbKv3P8PlmnEkgsqCCM94Z/ji7ifUGkXHyOsrz3Uw4R3yZU1vg2rTeYQy2eJp5ThshWG0nnCO+zGl53FrnvPU8fYpsmucM7JIOzlaUA7zM
6B0Yk0qT27kqym+WMR8AGePS3pMzVptPHsd3ZuH/uU8/cE38C7SKRBYKZW5kc3RyZWFtCmVuZG9iagoxMSAwIG9iagoxMDAzCmVuZG9iagoxOCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDc3ID4+CnN0cmVhbQp4nDM3NVIwULC0ABJmpiYK5kaWCimGXEA+iJXLZWhpDmblgFkmxgZAlqmpKRILIgvTC2HB5GC0sYk51AQECyQHtjYHZlsOVxoAnuAbmgplbmRzdHJlYW0KZW5kb2JqCjE5IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTY1ID4+CnN0cmVhbQp4nEWPOxIDIQxDe06hI4B/wHk2k4q9fxvLO0kaLIwlP6IrOvbKw2NjysZrtLEnwhbuUjoNp6mMr4qnZ12gy2EyU29czVxgqrDIbk6x+hh8ofLs5oSvVZ4YwpdMCQ0wlTu5h/X6UZyWfCS7C4LqlI3KwjBH0vdATE2bp4WB/I8veWpBUJnmjWuWlUdrFVM0Z5gqWwuC9YGgOqX6A9P/TKe9P9z0PYAKZW5kc3RyZWFtCmVuZG9iagoyMCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDMwNCA+PgpzdHJlYW0KeJw9kjuSwzAMQ3udghfIjPiT5PNkJ5X3/u0+MslWgEmJACgvdZmypjwgaSYJ/9Hh4WI75XfYns3MwLVELxPLKc+hK8TcRfmymY26sjrFqsMwnVv0qJyLhk2TmucqSxm3C57DtYnnln3EDzc0qAd1jUvCDd3VaFkKzXB1/zu9R9l3NTwXm1Tq1BePF1EV5vkhT6KH6UrifDwoIVx7MEYWEuRT0UCOs1yt8l5C9g63GrLCQWpJ57MnPNh1ek8ubhfNEA9kuVT4TlHs7dAzvuxKCT0StuFY7n07mrHpGps47H7vRtbKjK5oIX7IVyfrJWDcUyZFEmROtlhui9We7qEopnOGcxkg6tmKhlLmYlerfww7bywv2SzIlMwLMkanTZ44eMh+jZr0eZXneP0BbPNzOwplbmRzdHJlYW0KZW5kb2JqCjIxIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjMwID4+CnN0cmVhbQp4nDVRSW7DMAy86xXzgQDiLr/HQU/t/68d0glgYGhLnM0RGxsReInBz0HkxlvWjJr4m8ld8bs8FR4Jt4InUQRehnvZCS5vGJf9OMx88F5aOZMaTzIgF9n08ETIYJdA6MDsGtRhm2kn+oaEz45INRtZTl9L0EurEChP2X6nC0q0rerP7bMutO1rTzjZ7aknlU8gnluyApeNV0wWYxn0ROUuxfRBqrOFnoTyonwOsvmoIRJdopyBJwYHo0A7sOe2n4lXhaB1dZ+2jaEaKR1P/zY0NUki5BMlnNnSuFv4/p57/fwDplRTnwplbmRzdHJlYW0KZW5kb2JqCjIyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjI3ID4+CnN0cmVhbQp4nDVPO7IDIQzrOYUukBmMbWDPs5lUL/dvn2SyDRL+SPL0REcmXubICKzZ8bYWGYgZ+BZT8a897cOE6j24hwjl4kKYYSScNeu4m6fjxb9d5TPWwbsNvmKWFwS2MJP1lcWZy3bBWBoncU6yG2PXRGxjXevpFNYRTCgDIZ3tMCXIHBUpfbKjjDk6TuSJ52KqxS6/72F9waYxosIcVwVP0GRQlj3vJqAdF/Tf1Y3fSTSLXgIykWBhnSTmzllO+NVrR8dRiyIxJ6QZ5DIR0pyuYgqhCcU6OwoqFQWX6nPK3T7/aF1bTQplbmRzdHJlYW0KZW5kb2JqCjIzIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjQ1ID4+CnN0cmVh
bQp4nEVQu41DMQzrPQUXCGD9LHued0iV2789SkZwhSFaP5JaEpiIwEsMsZRv4kdGQT0LvxeF4jPEzxeFQc6EpECc9RkQmXiG2kZu6HZwzrzDM4w5AhfFWnCm05n2XNjknAcnEM5tlPGMQrpJVBVxVJ9xTPGqss+N14GltWyz05HsIY2ES0klJpd+Uyr/tClbKujaRROwSOSBk0004Sw/Q5JizKCUUfcwtY70cbKRR3XQydmcOS2Z2e6n7Ux8D1gmmVHlKZ3nMj4nqfNcTn3usx3R5KKlVfuc/d6RlvIitduh1elXJVGZjdWnkLg8/4yf8f4DjqBZPgplbmRzdHJlYW0KZW5kb2JqCjI0IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzkyID4+CnN0cmVhbQp4nD1SS24FMQjbzym4QKXwTXKeqd7u3X9bm8xUqgovA7YxlJcMqSU/6pKIM0x+9XJd4lHyvWxqZ+Yh7i42pvhYcl+6hthy0ZpisU8cyS/ItFRYoVbdo0PxhSgTDwAt4IEF4b4c//EXqMHXsIVyw3tkAmBK1G5AxkPRGUhZQRFh+5EV6KRQr2zh7yggV9SshaF0YogNlgApvqsNiZio2aCHhJWSqh3S8Yyk8FvBXYlhUFtb2wR4ZtAQ2d6RjREz7dEZcVkRaz896aNRMrVRGQ9NZ3zx3TJS89EV6KTSyN3KQ2fPQidgJOZJmOdwI+Ge20ELMfRxr5ZPbPeYKVaR8AU7ygEDvf3eko3Pe+AsjFzb7Ewn8NFppxwTrb4eYv2DP2xLm1zHK4dFFKi8KAh+10ETcXxYxfdko0R3tAHWIxPVaCUQDBLCzu0w8njGedneFbTm9ERoo0Qe1I4RPSiyxeWcFbCn/KzNsRyeDyZ7b7SPlMzMqIQV1HZ6qLbPYx3Ud577+vwBLgChGQplbmRzdHJlYW0KZW5kb2JqCjI1IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjQ3ID4+CnN0cmVhbQp4nE1Ru21EMQzr3xRc4ADra3meC1Jd9m9DyQiQwiChLymnJRb2xksM4QdbD77kkVVDfx4/MewzLD3J5NQ/5rnJVBS+FaqbmFAXYuH9aAS8FnQvIivKB9+PZQxzzvfgoxCXYCY0YKxvSSYX1bwzZMKJoY7DQZtUGHdNFCyuFc0zyO1WN7I6syBseCUT4sYARATZF5DNYKOMsZWQxXIeqAqSBVpg1+kbUYuCK5TWCXSi1sS6zOCr5/Z2N0Mv8uCounh9DOtLsMLopXssfK5CH8z0TDt3SSO98KYTEWYPBVKZnZGVOj1ifbdA/59lK/j7yc/z/QsVKFwqCmVuZHN0cmVhbQplbmRvYmoKMjYgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA5MCA+PgpzdHJlYW0KeJxNjUESwCAIA++8Ik9QRND/dHrS/1+r1A69wE4CiRZFgvQ1aksw7rgyFWtQKZiUl8BVMFwL2u6iyv4ySUydhtN7twODsvFxg9JJ+/ZxegCr/XoG3Q/SHCJYCmVuZHN0cmVhbQplbmRvYmoKMjcgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzMzggPj4Kc3RyZWFtCnicRVJLcsUwCNvnFFwgM+Zn4/O8Tlfp/beVcDrdPPQMCAkyPWVIptw2lmSE5BzypVdkiNWQn0aORMQQ3ymhwK7yubyWxFzIbolK8aEdP5elNzLNrtCqt0enNotGNSsj5yBDhHpW6MzuUdtkw+t2Iek6UxaHcCz/QwWylHXKKZQEbUHf2CPobxY8EdwGs+Zys7lMbvW/7lsLntc6W7FtB0AJlnPeYAYAxMMJ2gDE3NreFikoH1W6iknCrfJcJztQttCqdLw3gBkHGDlgw5KtDtdobwDDPg/0okbF9hWgqCwg/s7ZZsHeMclI
sCfmBk49cTrFkXBJOMYCQIqt4hS68R3Y4i8Xroia8Al1OmVNvMKe2uLHQpMI71JxAvAiG25dHUW1bE/nCbQ/KpIzYqQexNEJkdSSzhEUlwb10Br7uIkZr43E5p6+3T/COZ/r+xcWuIPgCmVuZHN0cmVhbQplbmRvYmoKMjggMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxNjMgPj4Kc3RyZWFtCnicRZC5dQQxDENzVYESeIA66hk/R7P9pwtpvN5A+niEeIg9CcNyXcWF0Q0/3rbMNLyOMtyN9WXG+KixQE7QBxgiE1ejSfXtijNU6eHVYq6jolwvOiISzJLjq0AjfDqyx0Nb25l+Oq9/7CHvE/8qKuduYQEuqu5A+VIf8dSP2VHqmqGPKitrHmravwi7IpS2fVxOZZy6ewe0wmcrV/t9A6jnOoAKZW5kc3RyZWFtCmVuZG9iagoyOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDY4ID4+CnN0cmVhbQp4nDMyt1AwULA0ARKGFiYK5mYGCimGXEC+qYm5Qi4XSAzEygGzDIC0JZyCiFtCNEGUglgQpWYmZhBJOAMilwYAybQV5QplbmRzdHJlYW0KZW5kb2JqCjMwIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNDUgPj4Kc3RyZWFtCnicMzK3UDBQsDQBEoYWJgrmZgYKKYZclhBWLhdMLAfMAtGWcAoingYAn30MtQplbmRzdHJlYW0KZW5kb2JqCjMxIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjU1ID4+CnN0cmVhbQp4nEWRS5IDIAhE956CI4D85DyZmlVy/+00mEw2dpeo/YRKI6YSLOcUeTD9yPLNZLbptRyrnY0CiiIUzOQq9FiB1Z0p4sy1RLX1sTJy3Okdg+IN566cVLK4UcY6qjoVOKbnyvqq7vy4LMq+I4cyBWzWOQ42cOW2YYwTo81Wd4f7RJCnk6mj4naQbPiDk8a+ytUVuE42++olGAeCfqEJTPJNoHWGQOPmKXpyCfbxcbvzQLC3vAmkbAjkyBCMDkG7Tq5/cev83v86w53n2gxXjnfxO0xru+MvMcmKuYBF7hTU8z0XresMHe/JmWNy031D51ywy91Bps/8H+v3D1CKZogKZW5kc3RyZWFtCmVuZG9iagozMiAwIG9iago8PCAvQkJveCBbIC0xMDIxIC00NjMgMTc5NCAxMjMzIF0gL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzNwovU3VidHlwZSAvRm9ybSAvVHlwZSAvWE9iamVjdCA+PgpzdHJlYW0KeJzjMjQwUzA2NVXI5TI3NgKzcsAsI3MjIAski2BBZNMAAV8KCgplbmRzdHJlYW0KZW5kb2JqCjMzIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTYxID4+CnN0cmVhbQp4nEWQSxLDIAxD95xCR/BHBnyedLpK77+tIU2zgKexQAZ3JwSptQUT0QUvbUu6Cz5bCc7GeOg2bjUS5AR1gFak42iUUn25xWmVdPFoNnMrC60THWYOepSjGaAQOhXe7aLkcqbuzvlHcPVf9Uex7pzNxMBk5Q6EZvUp7nybHVFd3WR/0mNu1mt/FfaqsLSspeWE285dM6AE7qkc7f0FqXM6hAplbmRzdHJlYW0KZW5kb2JqCjM0IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzIwID4+CnN0cmVhbQp4nDVRu3HFMAzrNQUX8J34lTSPc6/K278NQDsVYRoEQKq8ZEq5XOqSVbLC5EeH6hRN+T5gpvwO9ZDj6B7ZIbpT1pZ7GAjLxDyljlhNlnu4BYEvDE2JuYXz9wjoKwajMBOB
usXfP0CzJDBpcPBTkGutWmKJDjwsFlizK8ytGilUyFV8Oza5BwVycbPQpxyaFLfcgvBliGRHarGvy2Up8rv1CRiEFeaITxSJheeBDmYi8ScDYnv22WJXVy+qERnWSYcHUgTSbG4SMDRFsuqDG9hXxzU/T0fZwclBv4rB+DY4mS9JeV8FoRCPF/4Oz9nIsZJDJBTyfbXAiCNsgBGhT+0jEGUgNEX37plSPiZViu8ARiEcfapXMrwXkdlqhs3/GV3ZKgoGVVkfn0ZwJoNJOPNkowrTUrXTv/vc4/MHY2N6gAplbmRzdHJlYW0KZW5kb2JqCjM1IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjE0ID4+CnN0cmVhbQp4nD1QuxFDMQjrPQUL5M587TfPy6XL/m0knKRCNkISlJpMyZSHOsqSrClPHT5LYoe8h+VuZDYlKkUvk7Al99AK8X2J5hT33dWWs0M0l2g5fgszKqobHdNLNppwKhO6oNzDM/oNbXQDVocesVsg0KRg17YgcscPGAzBmROLIgxKTQb/rXL3UtzvPRxvooiUdPCu+eX0y88tvE49jkS6vfmKa3GmOgpEcEZq8op0YcWyyEOk1QQ1PQNrtQCu3nr5N2hHdBmA7BOJ4zSlHEP/1rjH6wOHilL0CmVuZHN0cmVhbQplbmRvYmoKMzYgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA4MCA+PgpzdHJlYW0KeJxFjLsNwDAIRHumYAR+JmafKJWzfxsgStxwT7p7uDoSMlPeYYaHBJ4MLIZT8QaZo2A1uEZSjZ3so7BuX3WB5npTq/X3BypPdnZxPc3LGfQKZW5kc3RyZWFtCmVuZG9iagozNyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDQ5ID4+CnN0cmVhbQp4nDM2tFAwUDA0MAeSRoZAlpGJQoohF0gAxMzlggnmgFkGQBqiOAeuJocrDQDG6A0mCmVuZHN0cmVhbQplbmRvYmoKMzggMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyMzcgPj4Kc3RyZWFtCnicTVE5bgQxDOv9Cn1gAOu05z0bbDX5fxtS3gSpREMUScnlKVMy5bK5JCMka8qXDo0ttly+D0JTS0XB1L1FdclrmKasWyxd0POpLK/hGOB7dzfUP/SI2QKR0YJdYYEOkDu4YPg9eyZsUwsiUSXUDGCasMIcrkQMQQZjnRkGpQqDU/V3leOzDTsF1g5mU6RHUhOddIPmhbfeciGCrVO5qTfShNzZpxhiZeO+SpfjA+BgostEZMTmZTieDmFo8M40YIWzHsQEmdaR0ouZkTENN+nI1VeLis82GUue0f/2h/orn27/gxB8xvsHSVVcfgplbmRzdHJlYW0KZW5kb2JqCjM5IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTU3ID4+CnN0cmVhbQp4nEWQuRFDMQhEc1VBCRKwCOqxx9F3/6kX+Uq0bwAth68lU6ofJyKm3Ndo9DB5Dp9NJVYs2Ca2kxpyGxZBSjGYeE4xq6O3oZmH1Ou4qKq4dWaV02nLysV/82hXM5M9wjXqJ/BN6PifPLSp6FugrwuUfUC1OJ1JUDF9r2KBo5x2fyKcGOA+GUeZKSNxYm4K7PcZAGa+V7jG4wXdATd5CmVuZHN0cmVhbQplbmRvYmoKNDAgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzMzIgPj4Kc3RyZWFtCnicLVI5jiQxDMv9Cn5gAOvy8Z4eTNT7/3RJVQUFqmzLPORyw0QlfiyQ21Fr4tdGZqDC8K+rzIXvSNvIOohryEVcyZbCZ0Qs5DHEPMSC79v4GR75rMzJswfGL9n3GVbsqQnLQsaLM7TD
Ko7DKsixYOsiqnt4U6TDqSTY44v/PsVzF4IWviNowC/556sjeL6kRdo9Ztu0Ww+WaUeVFJaD7WnOy+RL6yxXx+P5INneFTtCaleAojB3xnkujjJtZURrYWeDpMbF9ubYj6UEXejGZaQ4AvmZKsIDSprMbKIg/sjpIacyEKau6Uont1EVd+rJXLO5vJ1JMlv3RYrNFM7rwpn1d5gyq807eZYTpU5F+Bl7tgQNnePq2WuZhUa3OcErJXw2dnpy8r2aWQ/JqUhIFdO6Ck6jyBRL2Jb4moqa0tTL8N+X9xl//wEz4nwBCmVuZHN0cmVhbQplbmRvYmoKNDEgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA2OCA+PgpzdHJlYW0KeJwzMzZTMFCwMAISpqaGCuZGlgophlxAPoiVywUTywGzzCzMgSwjC5CWHC5DC2MwbWJspGBmYgZkWSAxILrSAHL4EpEKZW5kc3RyZWFtCmVuZG9iago0MiAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDMxNyA+PgpzdHJlYW0KeJw1UktyQzEI279TcIHOmL99nnSyau6/rYQnK7AtQEIuL1nSS37UJdulw+RXH/clsUI+j+2azFLF9xazFM8tr0fPEbctCgRREz34MicVItTP1Og6eGGXPgOvEE4pFngHkwAGr+FfeJROg8A7GzLeEZORGhAkwZpLi01IlD1J/Cvl9aSVNHR+Jitz+XtyqRRqo8kIFSBYudgHpCspHiQTPYlIsnK9N1aI3pBXksdnJSYZEN0msU20wOPclbSEmZhCBeZYgNV0s7r6HExY47CE8SphFtWDTZ41qYRmtI5jZMN498JMiYWGwxJQm32VCaqXj9PcCSOmR0127cKyWzbvIUSj+TMslMHHKCQBh05jJArSsIARgTm9sIq95gs5FsCIZZ2aLAxtaCW7eo6FwNCcs6Vhxtee1/P+B0Vbe6MKZW5kc3RyZWFtCmVuZG9iago0MyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDE3ID4+CnN0cmVhbQp4nDM2tFAwgMMUQy4AGpQC7AplbmRzdHJlYW0KZW5kb2JqCjQ0IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTMxID4+CnN0cmVhbQp4nEWPyw0EIQxD71ThEvIZPqmH1Z7Y/q/rMJpBQvhBIjvxMAis8/I20MXw0aLDN/421atjlSwfunpSVg/pkIe88hVQaTBRxIVZTB1DYc6YysiWMrcb4bZNg6xslVStg3Y8Bg+2p2WrCH6pbWHqLPEMwlVeuMcNP5BLrXe9Vb5/QlMwlwplbmRzdHJlYW0KZW5kb2JqCjQ1IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzM4ID4+CnN0cmVhbQp4nDVSOa7dQAzrfQpdIIB2zZznBal+7t+GlF8KQ7RWipqOFpVp+WUhVS2TLr/tSW2JG/L3yQqJE5JXJdqlDJFQ+TyFVL9ny7y+1pwRIEuVCpOTksclC/4Ml94uHOdjaz+PI3c9emBVjIQSAcsUE6NrWTq7w5qN/DymAT/iEXKuWLccYxVIDbpx2hXvQ/N5yBogZpiWigpdVokWfkHxoEetffdYVFgg0e0cSXCMjVCRgHaB2kgMObMWu6gv+lmUmAl07Ysi7qLAEknMnGJdOvoPPnQsqL8248uvjkr6SCtrTNp3o0lpzCKTrpdFbzdvfT24QPMuyn9ezSBBU9YoaXzQqp1jKJoZZYV3HJoMNMcch8wTPIczEpT0fSh+X0smuiiRPw4NoX9fHqOMnAZvAXPRn7aKAxfx2WGvHGCF0sWa5H1AKhN6YPr/1/h5/vwDHLaAVAplbmRzdHJlYW0KZW5kb2JqCjQ2IDAgb2JqCjw8IC9GaWx0ZXIg
L0ZsYXRlRGVjb2RlIC9MZW5ndGggMjQ4ID4+CnN0cmVhbQp4nC1ROZIDQQjL5xV6QnPT77HLkff/6QrKAYOGQyA6LXFQxk8Qlive8shVtOHvmRjBd8Gh38p1GxY5EBVI0hhUTahdvB69B3YcZgLzpDUsgxnrAz9jCjd6cXhMxtntdRk1BHvXa09mUDIrF3HJxAVTddjImcNPpowL7VzPDci5EdZlGKSblcaMhCNNIVJIoeomqTNBkASjq1GjjRzFfunLI51hVSNqDPtcS9vXcxPOGjQ7Fqs8OaVHV5zLycULKwf9vM3ARVQaqzwQEnC/20P9nOzkN97SubPF9Phec7K8MBVY8ea1G5BNtfg3L+L4PePr+fwDqKVbFgplbmRzdHJlYW0KZW5kb2JqCjQ3IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTcxID4+CnN0cmVhbQp4nE2QTQ5CIRCD95yiFzCh8wOP82hc6f23dvD54oL0SyFDp8MDHUfiRkeGzuh4sMkxDrwLMiZejfOfjOskjgnqFW3BurQ77s0sMScsEyNga5Tcm0cU+OGYC0GC7PLDFxhEpGuYbzWfdZN+frvTXdSldffTIwqcyI5QDBtwBdjTPQ7cEs7vmia/VCkZmziUD1QXkbLZCYWopWKXU1VojOJWPe+LXu35AcH2O/sKZW5kc3RyZWFtCmVuZG9iago0OCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDcyID4+CnN0cmVhbQp4nDWMsRHAMAgDe6bQCDZYYO+TS0X2b0N8TgMvHQ+XosFaDbqCI3B1qfzRI125KUWXY86C4XGqX0gxRj2oI+Pex0+5X3AWEn0KZW5kc3RyZWFtCmVuZG9iago0OSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIxMCA+PgpzdHJlYW0KeJw1UMsNQzEIu2cKFqgUAoFknla9df9rbdA7YRH/QljIlAh5qcnOKelLPjpMD7Yuv7EiC611JezKmiCeK++hmbKx0djiYHAaJl6AFjdg6GmNGjV04YKmLpVCgcUl8Jl8dXvovk8ZeGoZcnYEEUPJYAlquhZNWLQ8n5BOAeL/fsPuLeShkvPKnhv5G5zt8DuzbuEnanYi0XIVMtSzNMcYCBNFHjx5RaZw4rPWd9U0EtRmC06WAa5OP4wOAGAiXlmA7K5EOUvSjqWfb7zH9w9AAFO0CmVuZHN0cmVhbQplbmRvYmoKMTYgMCBvYmoKPDwgL0Jhc2VGb250IC9EZWphVnVTYW5zIC9DaGFyUHJvY3MgMTcgMCBSCi9FbmNvZGluZyA8PAovRGlmZmVyZW5jZXMgWyAzMiAvc3BhY2UgNDYgL3BlcmlvZCA0OCAvemVybyAvb25lIC90d28gL3RocmVlIC9mb3VyIC9maXZlIC9zaXggL3NldmVuCi9laWdodCAvbmluZSA3MiAvSCA4MCAvUCA5NyAvYSA5OSAvYyAvZCAvZSAxMDMgL2cgL2ggL2kgMTA4IC9sIC9tIC9uIC9vIDExMwovcSAvciAvcyAvdCAvdSAvdiBdCi9UeXBlIC9FbmNvZGluZyA+PgovRmlyc3RDaGFyIDAgL0ZvbnRCQm94IFsgLTEwMjEgLTQ2MyAxNzk0IDEyMzMgXSAvRm9udERlc2NyaXB0b3IgMTUgMCBSCi9Gb250TWF0cml4IFsgMC4wMDEgMCAwIDAuMDAxIDAgMCBdIC9MYXN0Q2hhciAyNTUgL05hbWUgL0RlamFWdVNhbnMKL1N1YnR5cGUgL1R5cGUzIC9UeXBlIC9Gb250IC9XaWR0aHMgMTQgMCBSID4+CmVuZG9iagoxNSAwIG9iago8PCAvQXNjZW50IDkyOSAvQ2FwSGVpZ2h0IDAgL0Rlc2NlbnQgLTIzNiAvRmxhZ3MgMzIKL0ZvbnRCQm94
IFsgLTEwMjEgLTQ2MyAxNzk0IDEyMzMgXSAvRm9udE5hbWUgL0RlamFWdVNhbnMgL0l0YWxpY0FuZ2xlIDAKL01heFdpZHRoIDEzNDIgL1N0ZW1WIDAgL1R5cGUgL0ZvbnREZXNjcmlwdG9yIC9YSGVpZ2h0IDAgPj4KZW5kb2JqCjE0IDAgb2JqClsgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAKNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCAzMTggNDAxIDQ2MCA4MzggNjM2Cjk1MCA3ODAgMjc1IDM5MCAzOTAgNTAwIDgzOCAzMTggMzYxIDMxOCAzMzcgNjM2IDYzNiA2MzYgNjM2IDYzNiA2MzYgNjM2IDYzNgo2MzYgNjM2IDMzNyAzMzcgODM4IDgzOCA4MzggNTMxIDEwMDAgNjg0IDY4NiA2OTggNzcwIDYzMiA1NzUgNzc1IDc1MiAyOTUKMjk1IDY1NiA1NTcgODYzIDc0OCA3ODcgNjAzIDc4NyA2OTUgNjM1IDYxMSA3MzIgNjg0IDk4OSA2ODUgNjExIDY4NSAzOTAgMzM3CjM5MCA4MzggNTAwIDUwMCA2MTMgNjM1IDU1MCA2MzUgNjE1IDM1MiA2MzUgNjM0IDI3OCAyNzggNTc5IDI3OCA5NzQgNjM0IDYxMgo2MzUgNjM1IDQxMSA1MjEgMzkyIDYzNCA1OTIgODE4IDU5MiA1OTIgNTI1IDYzNiAzMzcgNjM2IDgzOCA2MDAgNjM2IDYwMCAzMTgKMzUyIDUxOCAxMDAwIDUwMCA1MDAgNTAwIDEzNDIgNjM1IDQwMCAxMDcwIDYwMCA2ODUgNjAwIDYwMCAzMTggMzE4IDUxOCA1MTgKNTkwIDUwMCAxMDAwIDUwMCAxMDAwIDUyMSA0MDAgMTAyMyA2MDAgNTI1IDYxMSAzMTggNDAxIDYzNiA2MzYgNjM2IDYzNiAzMzcKNTAwIDUwMCAxMDAwIDQ3MSA2MTIgODM4IDM2MSAxMDAwIDUwMCA1MDAgODM4IDQwMSA0MDEgNTAwIDYzNiA2MzYgMzE4IDUwMAo0MDEgNDcxIDYxMiA5NjkgOTY5IDk2OSA1MzEgNjg0IDY4NCA2ODQgNjg0IDY4NCA2ODQgOTc0IDY5OCA2MzIgNjMyIDYzMiA2MzIKMjk1IDI5NSAyOTUgMjk1IDc3NSA3NDggNzg3IDc4NyA3ODcgNzg3IDc4NyA4MzggNzg3IDczMiA3MzIgNzMyIDczMiA2MTEgNjA1CjYzMCA2MTMgNjEzIDYxMyA2MTMgNjEzIDYxMyA5ODIgNTUwIDYxNSA2MTUgNjE1IDYxNSAyNzggMjc4IDI3OCAyNzggNjEyIDYzNAo2MTIgNjEyIDYxMiA2MTIgNjEyIDgzOCA2MTIgNjM0IDYzNCA2MzQgNjM0IDU5MiA2MzUgNTkyIF0KZW5kb2JqCjE3IDAgb2JqCjw8IC9IIDE4IDAgUiAvUCAxOSAwIFIgL2EgMjAgMCBSIC9jIDIxIDAgUiAvZCAyMiAwIFIgL2UgMjMgMCBSCi9laWdodCAyNCAwIFIgL2ZpdmUgMjUgMCBSIC9mb3VyIDI2IDAgUiAvZyAyNyAwIFIgL2ggMjggMCBSIC9pIDI5IDAgUgovbCAzMCAwIFIgL20gMzEgMCBSIC9uIDMzIDAgUiAvbmluZSAzNCAwIFIgL28gMzUgMCBSIC9vbmUgMzYgMCBSCi9wZXJpb2QgMzcgMCBSIC9xIDM4IDAgUiAvciAzOSAwIFIgL3MgNDAgMCBSIC9zZXZlbiA0MSAwIFIgL3NpeCA0MiAwIFIKL3NwYWNlIDQzIDAgUiAvdCA0NCAwIFIgL3RocmVlIDQ1IDAgUiAv
dHdvIDQ2IDAgUiAvdSA0NyAwIFIgL3YgNDggMCBSCi96ZXJvIDQ5IDAgUiA+PgplbmRvYmoKMyAwIG9iago8PCAvRjEgMTYgMCBSID4+CmVuZG9iago0IDAgb2JqCjw8IC9BMSA8PCAvQ0EgMCAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+Ci9BMiA8PCAvQ0EgMSAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+ID4+CmVuZG9iago1IDAgb2JqCjw8ID4+CmVuZG9iago2IDAgb2JqCjw8ID4+CmVuZG9iago3IDAgb2JqCjw8IC9GMS1EZWphVnVTYW5zLW1pbnVzIDMyIDAgUiAvSTEgMTIgMCBSIC9JMiAxMyAwIFIgPj4KZW5kb2JqCjEyIDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDMyNyAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgMTY0IC9MZW5ndGggNTAgMCBSIC9TdWJ0eXBlIC9JbWFnZQovVHlwZSAvWE9iamVjdCAvV2lkdGggMzI3ID4+CnN0cmVhbQp4nO1daXhWRbLukAQSICELyBY2EcKqICAioigq4gYCigujqMNw0UFxFDdUcBvHZRaf0dEZd72AV0ZlZh4EBAURUEE2RSPIEk3YQgLZvywkuT+6ztslpz3fRwgBeur9k3qqvz5L9+mcqnqr60RVV1cppZRSw4ZdoDwcOHBACwsXLoRy47mXaGHJNmp9/NOn0Tr9gy+18M4770D58ssvayF98f9q4ak/rUDrfdPO0cJ3Q6+BctKkSVq4/vrrtfDoZaej9f6h92ph+CkpUJ66bL4WRowYAWVycrIWPl78kRbeaNELrbkVdNd3bf0QyuumPqSFtWvXauH9999Ha9m9U7Xw+oJtUD799k108PxULTz55JNonTFjhhaubvAjlA9MnKWFX4/sooWoGWYYx44dq4VBgwZB+daT99DB00droWPjWLSO2/ONFi688EIoCwoKtLB48WItrB4wDK1Lswq18MSqv0B53zufauHdd9+F8rXXXtNCp/+8qoVn//qFuZfp52thwxmjoJw8ebIWbrqJBueh87uaLhfRmFzSozmU3RbMI+Ul9Iy1bNkSrYvm/1sLr57UB8rCg9VamLqd5veq/7kLrZs2bdLCBx98YLrcfqsW3vo4UwvPvDMJrS/vjiflM89A+dhjj2lhVPlmuv5b/8/c6dhuWiif9iiU1157rRbOOeccKF+dMUULj3e/WgtdmzZE6+idG7QwbJiZo7KyMi189BHd4Od9h6L1s11FWnj0yxegvOv1BVpooAQCgVuQVS0QuIaYtLR2Wpo+fTq04+N3a+G+VgOhvOWyU7Rw3TtkjPW86iq0DhgwQAs/fvERlE93vUILOXExWnixegdazz+f7Lf8xebUsH7XD75YCw/+0RiETy4n6wgGv1LqFu/UMPiVUl0XvaWFyTF02TD4lVKbziFDqMOA86H81a9+pYU3bx2phQf6XoHWEelkYz+825igg4cP18JJJ52kheyfjLENm/9lz+BXSj2/m0xiY/BffTVa582bp4XQ3XdA+dv2dBkw+F/bn4zWDh06aAEGv1Lqqhoa5/ub99PCpCvT0TpuLs1gN3bqs88+Wws/rpoPJWz+HM/mf75yK1oxg8UfzoTy66+/1gJs/plPF6L1D58/p4V7Zi+FcsKZZ9J9+Qx+pdTkhmTowuBXSq3rT2PSrs9gLcDgV0rNvs2bwVMvhRI2//TslVo4yzP4FbP5szONhwWb/w3P4P9bzjK0wubf5D05Sqn582n0Dkw25v1vO9EwwuaHwa+U6tSpkxZg8CulRpZmaOH+lL5agMGvlLpyGj3qXUePg3Lo0KFakHe1QOAa
ZFULBK5BVrVA4BpisrOztDRjxkxoe79JTmnLAYZgGDWF+JXzKn/SwlOFJjr/0ZxVWli66QYo79k8TwtvLVmjhV69DL108OBBLdx5551QNo8q1UJcciMtFG6tRuvcK8gXffyfxhW/9FLynR544AEos7Lovq65m2izNo8/gdbQJCKQxuyKgTLxublaKGlHbvDk7xejddq0aVpY3a8flEOGDNHC008TO5XzZ8Ov7CqjGyyoNLew7TG68ldeec1/2aB24GsppZ5aNUcLMfnksjae8me0jswnj7f3nFlQJv3rn1roFqIZvPoFQ4HETZighSeeMGNy9RkUgFh8/gQos0J0C6cmxZFq82do/d3vfqeFRx55BMr+/ftrASTl7++8Ha1fXXm5FsoZQfhAs8ZaGFy4UQsZN05D67aM+7Qw6PVFUF6yN1oL33xDxF7hH6aidebFs7XQNMa8t7qNIQc1Opr69unTB61LlizRwqUjR0P5zIr3tNBhPY3nXztfhNY2xRVaeHiM8XjT04nGmzHkXCjnb39TC2ueJa6Uj3zmvBe18J/L7odyek6JFq7q3UILvf/8R7RiQTVu3BjKVatoDcq7WiBwDbKqBQLXYJitMz12QSm1bNkyLST8+3ko/z6KUnMWeTzNb27ui9aL/0LZY5whmzCI2KmOHTtq4aWXXkLrgHLiSN4dZ5Krbp38By2MH5ymhRfzTDraC7PnaaHXRGPyxcWRcXjvvfdCed0QMvWXDL9FC7c/azKcRrRsooVnvvoblIt312hhuEcR5T5tSLWbb75ZC+/PNdlF68eS8f9MR6LN0hMaoXXq63Tqbd0Nv3Lb3Xdr4dt0opouusgYdevWrdNC6DlzL/8YOl4LYJx+M81YdyPveEoL9913H5Rju/fRQs+ePbXw3nvvobVzBh3p9d8YMuyOonIt3HRFFyhfzCAa7/En6Sy3X3YzWpHAxymZ4W3Jvv1wBN3p5PvNpI/2CMLnM+dBOfszSgUbOJPOUlb2d3NhdxDJ99G7b0K58nLi5B5IpdTDgclxaJ3xIXlqX8WZe7nmfrJvM18hKhSZfEqpH374QQtZU8dD+Y/TLtNCSkO6qUlPjkRr0RW/1cLd3pwqpb6IbEG9cuVtaF1UTm7OxBtPg/Liv5IT8dBDdC9YTSrcgpJ3tUDgGmRVCwSuQVa1QOAaDLP1wQfzoL39dqIikP2nlGrYlniss88eqoV9N9yC1i6ludS3wT4ouxeT+7puBZ1l6UUmra/xBR21cO2SP0F5djPyNt944w0tTBl2OVpzcnK00L59eyjHjaOkOb5pqeIb2owV77lbybHmX9iivUQbhC4yRNTQuynDcfUK2r204ONlaH31Vcph7NjZuGoxMUSMDbzmLC30+PWvzQ32702nftYwFuPX0lCs2k+MVMs5y9G6NZM8vX5PTIFy0iQKIsyaRdzVGOYk73q9jxbatGkD5ZQp1H38ePIST8pchda1L9FWvD0e96aUqqSogtr0WTaULZ8g1/HBu2gv2llnnYVWpOjidBynDyaPd+LEiVBeOJQSPPc+Z/i8mGc/1sKQXOI1m7HJOm0WEXsVPRKh7OKNQFdvTP4xdy5aH7zxQS0gk1cpNWrUKC3ccAORr2n7v0PrhrH0mC1YanJ+S6poULo3ooku3mke7w4J5GxPnToVyldeeUUL4JkUy+U89dRTtXDLLLMz7N6RlP26/wUTTFl8Mjn/Ie9ZvTbGbNQb0ZxItf75G6EsHXydFuRdLRC4BlnVAoFriILZxu23MWPGaAH2m2Im3NqHKEtp4efGVNtbTnTXWSlmM8p5t5G5lXY32W+frDEmPWzazz//HMrqakrDQorSLbcYO3/EsKF0ur8YW2WZZ78t9+w3xUy4S/q11sKAx4xtXNCbUrhg0yql5nom3E8/UfIct98uv5wsNL43KC3/ey1smP6sFrj9htyy/kmGdLlg4hlaaH/P41pYmWG6YNPSihWGz6uoIHPL2G9sTEba7LcVj5MD
Al+DZ1mN6EXpSmc+ajYbhTz7jde9mDOHrN/t27drISXF1KtAVt8EL1lNKdWlhrykTQ//XgsLPjRpZJmllVrolWgowOHjidHp/CAxZF9lFaAVY7J0qdnmVVpKc41sxRtvvBGto0dTfljZLEO5ffbwv7SwcCftIWvYIMpcg1eHY/BjZh9bzQjin1BJ4u2330brli1btJCUlAQleEr+nPRqSjOYMXOmFhZ9sMUcx8tR46UURng7tNK9pM+v880MYkxQFUMpVVhI9yXvaoHANciqFghcQxTqlq1duw7aTz75RAvc5tm6lTJXYPw0bdoUrV27UubWeeedByXkPr0pw6l88eto3TbnP1r4ZsF2KNflU7kmFKZq7qX1KKUGnURB9fQre0DZ/jqKgdf0NZvg16yh9PePPyb7/NNPP0Xrjz+S0YvqUIrZUd26kfHD60ghjNm9Q2soQwvJidg8iwyhDcuMOb2xgLK1yqproGzjFZAYmEYR3e7j+pjWcVStLdTxDChXr16tBZhbPL6KfSww1JVSzZtThYDevSkOf8EFpi4dSmp1TDSbW0rm071kzDb7N9Z+uVML3xXSvVSaWzHl0wZ2NWZ5t6vIdWo5jkziA0md0QpvC+W4FJusXbt2aQGOmFKqRQvyF/r2NbmMw716FYMHk5fXWpnaDPn/foMue9ZKcy9fk2uwucgMFHCKZ/0O8NwTpVSP6yngnzxqghZ2RyWhdeVKOjg3gzds2KCFPXv2+M+SlkYZk/3YHiEY7bxYXWoRPUg5cymjLuOd1Wj9cst+LcCjUUrFev6EvKsFAtcgq1ogcA2yqgUC12CYLe4kd+5MjhDKzSnmJJ/eh3iI8o/fQOuOOVS0edN8U61uzYEgJ3lgKnFg3ZiT3OFaSq6KGnglHcRzuhRzkpcvN/lYIF2sTjIcfu4k4156dEqDMrSI2IIfZtPu/I2fcCeZDo58I8Wc5AGtafR6MCe57TXEFZV1NvlYcJKxTZ+TWMFOMjZgWZ3kzqlmA33pQrqX796myMj6z3ei9RvvXqxO8oDOSVB2v5qc5FbXkJOcn2K2vn3xBe3o4o7ll19SocidO+mMKI+hWNG/004z+5OQFIgqFG1jDElZ4DnJGbPNQK1Zv1cLwU5yv+6pUPb0nOSU0XQvObGGucQsYF4Uq41pdZJbt6YIC1hYxZxknoTXvJQ44H3/9JzkOaac5prNeVrYUWKc5GjPScYuwH6nt0Jr9+updGTiFWYLXXYZzaC8qwUC1yCrWiBwDYbZ2rEjE9rvv6eUqfXr10OJAlGghbDXQjHrt0ED858iMZHIG9j5XbqYrRGoGgX2RSl1yilUPat5DFkjFRs/QWveCuISsj/bDOWP68k6Qo6OYvsWYGTGsUSi9p61mc6yvtLOpItMO5euJ3GwcUBielChgp379kO5eTNdxsaNGw8RFHMN+EAVFxernwOjpBiLwwcKxipyy9LTTXHvVol0C5XfLIOyYBV5KFnLqSDBztW70fq9R1PtDBnbGPRbrBkn1TaeBqp7Atm0bZgd2O5ccp1SBg+BMvZU8m5yQuR2IQdLsaeIP1ogTWHo8lECy8VrdGGgTj75ZC1gcBR7tEBSKqXataErr8og6q5olTG2s5bSxO1cvQvKLV624k8egcT9LzzoreIMQYj8sA6MIUsbQrPZ4hzi4Rr2Mf5gQXSCFlC5QbGPCoEqw6pUSu3eTbOZn58PJQZK3tUCgWuQVS0QuAZZ1QKBazDMFgoAKLYvBwluitEqcFp69DCMFMiw+AOZUJatoazAPcuJncpaab6ztX0rfTF3Gwvo7/dKHSJjMJFtNurUhNy8ru2NL5p2JtV/SzvPfBM3fiCVbqtqR04pvFylVEYGfcSIu8EoEYE9W/v3GxcagQM+UHCJMVDckYM/zKugI3CQWJmvhYr1xrvLWU6poNmfGRdrx7dUkQKBg33eDjnFBqpJtCVw0K0lUW5pg8yevLbn0IU1HWQYsuhu5Btn/pQFJQIH8O54
IY0dO2g2c3NzoURCMSIsPHAAZosPFFxiCBglpVSLRnSLFRtN/vJ+j4jK+pSmMnvDXrRmeJUVrWUhEGFpG2+mskcy8axtBpiM4Hbn0sQ1O5siLDE9zdfaduVRjipGSbHx4Y/Wtm20a23vXrpIf3hF/ZxdxkBhZXEuEGuQx1/apNI4y7taIHANsqoFAtdgmK2cHFOQCVtnkB6klMrMzNQC7C5YqoqZFgcOHIASZgbPlAJgynLGIiGBovxIqGrbti1aUQaZf7AGNcz4L2HAxJfna6EqOwOtFVuIVjmwyZAueRmUAJTn5frszS5CK0og7Cs3Rh0+tQMzmP+bjPdM4haNjKWHdLRWLeiuU7uYz9Y270WWfGpPc4ON0r1PyXQgJ6gq2ZRtw8hj1pRS2dl0L5g17oBgWtFXKVVQQIUKYEIr28Q1bGh29mPimjVrBiU4p3btyDPCrHGZV56DG4hZi843D15VJtm0FVuN8Z/7NZFheRn0y9zvjbu0Zx/dwi5mgcO5Kz546KwpNnG8vESLRpQK2c5j+FqmJaA15RSauOa9zL0k96Lcu4anGKYtuj1Z8qFGSVrYt88sN0wHJkuxxQUlX25weYqKzCNqfB8lEAjcgqxqgcA1yKoWCFyDYbZ4pie+XAUvVymVmkrbX+D8WP0l7vHC0UWXJMYl1OwkRxderlKqKIPS4vK+9ViTzYY1yfPIsCyW6rjHfFDW8D18NxLdFMsYxWeTeK5fWjPaHJPcKUkLzXuYD/2m9qL7atLd5LfGdibfKaoNkXy5+01YAWl93OOFf4vwBPZpKeboclIN4Qm+KQ3AxPHwBMgkhCd4tUlMVocOHaCEzMMTrVpRlmWT6pAWqnea8ES5N3EF37LwxLeZWsj9nsIT+7JMlRJMHA9PwNH1z5pijF0K2/CH8EQbb6daCg9PdCd2CrOmlIpLJ2YoppM3gy0NLWQNT2Bq4NwiXVoxR5dnBCOFk3NX/vAE50cxcbyqoT+uhNxYxWYQwQv+S3lXCwSuQVa1QOAaDLNVUGDMJJAc3A6EDFuFGx6Qecge8Xcc0BqI54YlbBVejw6AtckNGHAt3ARFmg6cCF7IGuwLXAPFqn9DgNPBu3MzCXLDg3QvNfuNOV2VQ/LBXSajLuR9Aqnop72eYEasMJtmoXgvKxuwn6zfXI+bAUmjmPkKQf2CKQtgV1Z8tPm3DkaHfw0HVS5SvPLdTb2PBCulEj2aJyHNjG1COxrSJh3IIIxp3RGtMa1Ijko1aYtVTcjaREVrvhUJDx5/GrG7Cw8eZ+nw4PEuODiexlAohFY8eNxa9j+E3FHFQwiPVSkVH085ajxRDMwfniJeah4PIaxuxXwfdOEPML4xzPP2YkL5dJFKIBC4BVnVAoFriOKhUQCmhTWRCAKPkFvtBJgWsHi54QEl7wIZRksTlphVk0fma/U+k2dzcDfZtxU7TXyS2bdkocG4VUoV7SZHoDjHGLrYMnHAi6UjdUwx+7aCFff2OwnW3DKertQslmxafKOzOYvDJ3j1zxLaGPstIS2JBM+4TWhvhjG2DYV5Y9qasttRKW09gQzdohJm0nsmaF5eHpQwVvlWDdi3sHh5K7wtGLfKZt9aPSxezMwP7mHBvuWGbpMm5AjgOYFRqpgpyz0s/0PILV7rAwwPq5GiDUg1eeYrVFU5mXQvuzKhLMumJ7Moy3gEeAiLdlFgvHi3iZAX5h7qYSnmZIHZCbGCDbzCvB/yrhYIXIOsaoHANciqFghcg2G2SktNlL+khD6PykP/8J2QNMOzZ+BNcU7CT1TwVrBc3CsD3YVr4F5ZeTnthuf0g9VVAydhZcgATlT4aTMeVoDcqJH5RCs8Pbh5VnaNb2kCFQGfjZMTUHL+DMfx03X8jLiGn8ml+fpvTZHxh1UxudNVBww3CbnqgGHayrzSi6U5lDNXlmcmK5RHD0Ap+8ZwmVcEPgSBVcWAl1jIeDh4jNbgRaiKlMyvtEQ0rOCVFfFdWwhWYo8HQVCxI8GL
fcQlm9mPT/Yc/mTj8DdOpZGPSzXTGn8S+fyNWyRpISbF1CqMTiY/PzrVVHqMSkg9RFBNTRSgvJIedSwTLsu7WiBwDbKqBQLXYGe2YItaM2lgi3IDFXaplX6AlcitTZiR3NqM0EDliTs4Dlf6T80NVFVChmVNsck98tulVbmmhnZVPtmloRyzfyOU57E4+/JJyDVeiTFB84wjU+Z9/rbYK0LGc8IKbYliMFBhi3IDFXZpsFHK/3+jwFlDtuMF5mh8tEWZaDNQg+3SeG/fRePmZuRhl8alGq/Eb5dyWxQGalSisVqjmno8lmeXhsqNUwbPkRuo8Phq4UIGt3IX0uq9wouEENaFhBzsQnJghcq7WiBwDbKqBQLXIKtaIHANhtmqqDD0AwgkK6sEpdVz4Ep0gQPDK91B5kq4JVb/BDJXWvd++a+2stLcIJwW7t5A6RfUL1BlEfJnVliDF9ZNaZD9ApdjY2Oh9Ic5eBAE+4og/JISwQjELHjcBDELHtGA0npAdLee2hqdgZLfQixIqTLylmtCZjsglEZQqqqIAiI1JaSsLjIhkqrifE9g1Q4K6dEqzy86RKOUqigiubKkHMpy7xtmFcWVFqVH8h1kZRLLvEoSiJso9kEvazAFoRauhCzvaoHANciqFghcg53ZAqzGYTDvZa1nEGwlciWMMWstBLRymsq6T92/w4wfx6r0W4w8jQxKq3EIJe8SG+sNhc04NBYjsxKrPTuwuijfKEsLD1FWFZsu5fle0fUiZhwWlhyirChkFQJKyO8oL6jwKytZKhgsRquVCDuwzJYKBoPQaiVayTkoD9uf+TnwXDKSTkVHHZpbxok9qxIkHxi+OFtrY7atMNb7aFQsK9HXKLGh1+plKCaa56Shp2yYaJ7G2ITGnpIe9UZJTVkryQ2aJkHZoAlRvPKuFghcg6xqgcA1mBh4ZaUJyiFczOPGkBEZ5q3Wsk/+ULM1pcaq9Kfj/FIXa74OYuCIlkPDZR7lhtIvWO9FhQubBysRNrdG2oPD79aQuzU4X1fg3pZfWQtPzaqM8ozkYAqAy3B5rK6f1eMLbrW6XdY4PGTulEWotB6HX48/g9Pahd81usu7WiBwDbKqBQLXIKtaIHANYZgtDr9nFexNWZVhXazo6OhDWiN3sbjS725xf8lKqvnrIgS7WCqcvxSsDHYIrU6UtRW3wO/Fr+StOCMnflSlR31VmEBGTUXI1xpiXeiX1SGzNaqmrOQQJTRKqZpyrwtTVnvZgZWl1FoVMsELKA+WsTBHCZT0y6oy1sXj4SpLLErkdVWx6n9QHgwZZZVX46HSaw2b4FVZE8Tn+Tk8pVSVrQt+iVZeNMKqBORdLRC4BlnVAoFrMMxWFcsZqqoiJedXoIyw1fpLToZZj4MfBJ/Fepzg/Rv8OBEqrYyUle7y037qFwjCCLtY79p6YdaBglzjmWrWLmjlP+AziB9EvrmlTni4sMxcnVN3wbASe9bWyCnA4NZg7xUyuEAlVRMEAochq1ogcA2yqgUC13AYzJYfwc6G9Qe18EBqwZ8pW+4hODNrq7KRalanhR/HT6pxAgnXY2WVoLQm/QV34exaMNtnTS20FjwPVuLU/PatyuAu1oHyj7y1S/AeQWsrP040Ci0c9OiuKhPvMMpKEzep8f/yYAXr4v2yysQsajxesIYThAe9z3SVhwJbOafoHcc7I2+trgRLx77LW0FKeVcLBK5BVrVA4BoMs1XDM1c8boPzB36ltZWTJRF2CT6OlSoLq8Qxg7tYrzb4gJFvwLKWdPazU9bWyLuADAtW1qILlyNstZ4x7DD6f2mdl2CllZALVvID+o+swhF7/uuP5Ae17hK2FUp5VwsErkFWtUDgGmRVCwSu4YiYrbAIS30F/CzyHD2rMphUq0UXa44eJ04CWoOVVsotWGk9YPB+OH5AKxdoKmnY7jrC1qPRxUrd+Qun16JLcKv14D+nyqL9SusM+n9p3acY3CXsWcy0KoFA4BZkVQsErsHO
bCFEzkP/fiUPr0eorEWX4GuIXGnlV6xdrMyHtYufnKufLrXYOFW7rwsFb8CqRe3ECJXWyYqcZ43wOGHP4ieQIu8S9pcB9xK2i/UseGbkXS0QuAZZ1QKBazi6MfAIEWGovHZd6uqXRyNQf7hHtip5PByIMDgf+XHqqjX41NYQMRC8pcd6ZGsX61msvID/mJF3qatT1+Is8q4WCFyDrGqBwDXIqhYIXINhtpQyzk8w5+QX1C9sf/HzNNYuYamdWnSx/tLfWuc3GDl15z/IYR3ncLsEX/9RPXXYLhGyOJEXMKyrLod7wCM8dfBxwm7tAssl72qBwDXIqhYIXMNxwWwdIWpBjNX5cY4qOVeLLrWg3Or8OEfSGiFdV29djoeDR35AeVcLBK5BVrVA4BpkVQsEriEMs8XhD6yHJX78SisjFXmXCM/CDx55lwg5pyO82jofxgivNuxoRzgdtThOhOOgImbFDut6/L88kuPU2zD6fxl50UJ5VwsErkFWtUDgGlxgtuocdUWVHQ+nPt4O6BJ9WIsuR+OXAKgveVcLBK5BVrVA4BrsMXAg8mhhcBer8ki6RB5CrPMLq5/bPx4G2XoZYbtEGLEP/kHw7Z9YXcJ2r/Mu8q4WCFyDrGqBwDXIqhYIXIMwW8cpjiG7FjmOh4us52s4qqers92HdXIUgUBw/EBWtUDgGsIwWxxh+QmNWlAyR/WXkR8nLGlRt6c7ri478jPW1YXV+ZNQVzcY4XiGPWD93yCuXN7VAoFrkFUtELgGWdUCgWsQZktwtHA88F71j+Phro/9FQgEgrqFrGqBwDUcBrPlR+SxeyBy2uBIznKE3eunC1A/Y3JCDGMthuJITlf/Y1I/cy3vaoHANciqFghcg6xqgcA1CLMlENQf6of3kne1QOAaZFULBK7hiJgtK46QLfDjSNiOX0KdX+Rxe0CZDrcPKF/kEQj+KyCrWiBwDRIDFwhcg7yrBQLXIKtaIHANsqoFAtdQ98xW5KhztiAYR4OSiRD1fKfH5IzH6tTH8E7r+YmK/E7lXS0QuAZZ1QKBaxBmSyBwDfKuFghcg6xqgcA1yKoWCFzDsWS2aoFjSGPUAseQS6sFTqyxlasNgLyrBQLXIKtaIHANwmwJBK5B3tUCgWuQVS0QuIYTLAZe5zixQql1jhMrSl/ncHX25V0tELgGWdUCgWuQVS0QuAZhtgQC1yDvaoHANciqFghcw387s3UM4Sqtctziv4fGk3e1QOAaZFULBK5BVrVA4BqE2RIIXIO8qwUC1yCrWiBwDcJsCSKC8HAnEORdLRC4BlnVAoFrkFUtELgGYbYEAtcg72qBwDXIqhYIXMP/A9aOwEMKZW5kc3RyZWFtCmVuZG9iago1MCAwIG9iago4MDAxCmVuZG9iagoxMyAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyA4IC9QcmVkaWN0b3IgMTAgPj4gL0ZpbHRlciAvRmxhdGVEZWNvZGUKL0hlaWdodCAxNjMgL0xlbmd0aCA1MSAwIFIgL1N1YnR5cGUgL0ltYWdlIC9UeXBlIC9YT2JqZWN0IC9XaWR0aCA4ID4+CnN0cmVhbQp4nO2U2w2EMAwEkxCaoP8q87CvgUykRdahe/CJNTveBJGv60qrpyzfppRqrVUclLJOe5b4+nXP83yO+K8b4Pjdb/cdxHEc4iDnLBK6XHc8KtcdSOjyQCLwrD6sYCEiucFgdpHwOYAwchg5AqOQQIfP9YALylGbgiQfWJCIDoSPRlF0tTrRkVAL3mkeVtCJsAY3aB0GPiFq4kURgQ5DInJdJOA+Ah3WyMFR+L8ioskEOSat69MpCgsCsTkS2aFHDWruQDSjqPX7DREYFbhuvxEVJX8BLW3aFwplbmRzdHJlYW0KZW5kb2JqCjUxIDAgb2JqCjIx
NQplbmRvYmoKMiAwIG9iago8PCAvQ291bnQgMSAvS2lkcyBbIDEwIDAgUiBdIC9UeXBlIC9QYWdlcyA+PgplbmRvYmoKNTIgMCBvYmoKPDwgL0NyZWF0aW9uRGF0ZSAoRDoyMDIyMDUzMTE2NTkyNiswMicwMCcpCi9DcmVhdG9yIChNYXRwbG90bGliIHYzLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZykKL1Byb2R1Y2VyIChNYXRwbG90bGliIHBkZiBiYWNrZW5kIHYzLjMuMikgPj4KZW5kb2JqCnhyZWYKMCA1MwowMDAwMDAwMDAwIDY1NTM1IGYgCjAwMDAwMDAwMTYgMDAwMDAgbiAKMDAwMDAyMDkyNSAwMDAwMCBuIAowMDAwMDExOTcwIDAwMDAwIG4gCjAwMDAwMTIwMDIgMDAwMDAgbiAKMDAwMDAxMjEwMSAwMDAwMCBuIAowMDAwMDEyMTIyIDAwMDAwIG4gCjAwMDAwMTIxNDMgMDAwMDAgbiAKMDAwMDAwMDA2NSAwMDAwMCBuIAowMDAwMDAwMzk3IDAwMDAwIG4gCjAwMDAwMDAyMDggMDAwMDAgbiAKMDAwMDAwMTQ3NSAwMDAwMCBuIAowMDAwMDEyMjE0IDAwMDAwIG4gCjAwMDAwMjA0NjUgMDAwMDAgbiAKMDAwMDAxMDU0NiAwMDAwMCBuIAowMDAwMDEwMzQ2IDAwMDAwIG4gCjAwMDAwMDk4ODggMDAwMDAgbiAKMDAwMDAxMTU5OSAwMDAwMCBuIAowMDAwMDAxNDk2IDAwMDAwIG4gCjAwMDAwMDE2NDUgMDAwMDAgbiAKMDAwMDAwMTg4MyAwMDAwMCBuIAowMDAwMDAyMjYwIDAwMDAwIG4gCjAwMDAwMDI1NjMgMDAwMDAgbiAKMDAwMDAwMjg2MyAwMDAwMCBuIAowMDAwMDAzMTgxIDAwMDAwIG4gCjAwMDAwMDM2NDYgMDAwMDAgbiAKMDAwMDAwMzk2NiAwMDAwMCBuIAowMDAwMDA0MTI4IDAwMDAwIG4gCjAwMDAwMDQ1MzkgMDAwMDAgbiAKMDAwMDAwNDc3NSAwMDAwMCBuIAowMDAwMDA0OTE1IDAwMDAwIG4gCjAwMDAwMDUwMzIgMDAwMDAgbiAKMDAwMDAwNTM2MCAwMDAwMCBuIAowMDAwMDA1NTMwIDAwMDAwIG4gCjAwMDAwMDU3NjQgMDAwMDAgbiAKMDAwMDAwNjE1NyAwMDAwMCBuIAowMDAwMDA2NDQ0IDAwMDAwIG4gCjAwMDAwMDY1OTYgMDAwMDAgbiAKMDAwMDAwNjcxNyAwMDAwMCBuIAowMDAwMDA3MDI3IDAwMDAwIG4gCjAwMDAwMDcyNTcgMDAwMDAgbiAKMDAwMDAwNzY2MiAwMDAwMCBuIAowMDAwMDA3ODAyIDAwMDAwIG4gCjAwMDAwMDgxOTIgMDAwMDAgbiAKMDAwMDAwODI4MSAwMDAwMCBuIAowMDAwMDA4NDg1IDAwMDAwIG4gCjAwMDAwMDg4OTYgMDAwMDAgbiAKMDAwMDAwOTIxNyAwMDAwMCBuIAowMDAwMDA5NDYxIDAwMDAwIG4gCjAwMDAwMDk2MDUgMDAwMDAgbiAKMDAwMDAyMDQ0NCAwMDAwMCBuIAowMDAwMDIwOTA1IDAwMDAwIG4gCjAwMDAwMjA5ODUgMDAwMDAgbiAKdHJhaWxlcgo8PCAvSW5mbyA1MiAwIFIgL1Jvb3QgMSAwIFIgL1NpemUgNTMgPj4Kc3RhcnR4cmVmCjIxMTQyCiUlRU9GCg==\n", "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-05-31T16:59:25.932478\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.3.2, 
https://matplotlib.org/\n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Create encoding block, bind to access positional encoding (module has no parameters)\n", "encod_block = PositionalEncoding(d_model=48, max_len=96).bind({})\n", "# Obtain positional encodings as numpy array\n", "pe = jax.device_get(encod_block.pe.squeeze().T)\n", "\n", "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8,3))\n", "pos = ax.imshow(pe, cmap=\"RdGy\", extent=(1,pe.shape[1]+1,pe.shape[0]+1,1))\n", "fig.colorbar(pos, ax=ax)\n", "ax.set_xlabel(\"Position in sequence\")\n", "ax.set_ylabel(\"Hidden dimension\")\n", "ax.set_title(\"Positional encoding over hidden dimensions\")\n", "ax.set_xticks([1]+[i*10 for i in range(1,1+pe.shape[1]//10)])\n", "ax.set_yticks([1]+[i*10 for i in range(1,1+pe.shape[0]//10)])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The plot clearly shows the sine and cosine waves of different wavelengths that encode the position across the hidden dimensions. To get a better intuition for this pattern, we can also look at the sine/cosine wave of each hidden dimension separately. Below, we visualize the positional encoding for the hidden dimensions $1$, $2$, $3$ and $4$."
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1BhZ2VzIDIgMCBSIC9UeXBlIC9DYXRhbG9nID4+CmVuZG9iago4IDAgb2JqCjw8IC9FeHRHU3RhdGUgNCAwIFIgL0ZvbnQgMyAwIFIgL1BhdHRlcm4gNSAwIFIKL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gL1NoYWRpbmcgNiAwIFIKL1hPYmplY3QgNyAwIFIgPj4KZW5kb2JqCjEwIDAgb2JqCjw8IC9Bbm5vdHMgWyBdIC9Db250ZW50cyA5IDAgUgovR3JvdXAgPDwgL0NTIC9EZXZpY2VSR0IgL1MgL1RyYW5zcGFyZW5jeSAvVHlwZSAvR3JvdXAgPj4KL01lZGlhQm94IFsgMCAwIDcyMS45MDYyNSAyNzkuODA4NzUgXSAvUGFyZW50IDIgMCBSIC9SZXNvdXJjZXMgOCAwIFIKL1R5cGUgL1BhZ2UgPj4KZW5kb2JqCjkgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMSAwIFIgPj4Kc3RyZWFtCnicxVxNbxw3Er3Pr+jj7kE0Wfwo8pggWQMBFtgkxuYc2EoiQ5LheLP79/dVj2ZYbPVM91icVozYaorN4St+VdXjGzd83L35xg2/fxnw12CHj/j/f/j5rTzvbiweH3ZMzhSbKOLpXj0RF5Nt5ohSVNWPf+x2v+2sKY5TYBtzHqYPoVhXkuU8/Ckf/PZZhePDblJ7twvRuLEHTj6SEn6BbvpQTEgl56TL73U5RTYh+7GLx0Z04djvz8PMB7hAZfA2GJ8OfwZmk6L8JuPv4c/b4ZfhcXjzDe3t+cPgdh/x796eb767/e/d+9uf3n47vP+yixmYbEmuhVCLdad2P+9+HD4fGrbGRYyVG/Ztj49vn0p3374b3vwDjdrh3W+7GA0nR+wp+Dy4RAY/5riH8u7D7m/u78O7j8P378bm+2DeHfr1HDPe9JkpUIu5FvfAzMGQ5RgD+RjnMNOmmEs02ftIvsVci3tgLoTu+Yz2MNZzmP2mmJ3DGHDJNrSgVXkP1M45ALEOeF2ZQx22Re3RAUo2T3YlVd4FNRUTQ/IF2zDan4Edt4UdMQpMPqUJ7FreBXZgUwolzmRpdl2nbWGnYpLLOKEmsGt5F9gpGezdJRZXZN98Dpu3hZ2zsSkUyhPYtbwL7BxMzpGCdXKYP0edN0VNNpnoHNnSolblPVCT9cZ7F8k55tmlXbaFTcGUyGGK+ljcBbTLJpaMsyGlEmadFLst6uBNsJ554pqp8i64PRsbgi8xwUWbxb2td0bJmRytjRP3TJV3wQ2nFM584pQizR7bblsPjbI1VJL3ExdNlXfBzd4Un7BwUgize5rb1kujkg0HjO7ES1PlXXDDOYVrSj6ju/Pre1s/zTs2OK5wpk6ix1reA7d3Fm2XiKiVspvFva2jhnDDJB8RVU9w1/IuuOGfEsdsbXJlPug8emoTDy/iWCHnbcoOLyJot8mWTMcX//Xpy91/7j49DnePw5fbz3/dPr6/3dSE9QNywt6QLVpr0w61/CILfsaH2uHG4leEacNOkiguk4kS3ggMP7x/aO2Vjed8yMocg/kdbPF5Z42VBvHn8APexqs3391+/PXff/386+OXm4e7x7++DN992v0ofzY1n5z5KWLKldZ8qvwlG46hkMSAmDk4ubK3Xvo72mhbV0JlmBwGgjy5CeBa3gUwgiRyEecLOpvaDE/ziky0G5kYsBAXWfCO0XSiUpwdbfWws0/T6rDqfr0fsOA+fbh7/P1aRjRxjxQYUk7YRfAAt3D/H+NhTAeSF8fhp6mxde5Mzy+V
XkJf8TqP+5vKwATELswhSLHOUZA3hDnkZDbq6L4EmArhN0m5Cn8zMBI7xED3TRxo5Uf0bKyvAiVfTHHeZ6nfhBLWAG7J0h/lbHuHCNjCNFKsnFGLORRJ3rxvnLj9XgSIUq6cGdSn4vfd14c9EOKjsV4kZaoPQ7F7xqdJuTos0P1UUjrsdXb4Ye3Y/fn7fk7vJ/Txjd25N572SJm5h4G2LGnN42CPZww2uuHNP63sa4fq+D1WBCyAyY3RwBLBZxX5s6K2wWgntjKoS3VvSLabZDFLvKcV1XHAxYCBTx5DvVgd8yNYi9OZ7XLjkqLPGB0fVjTtZW4UEpN4v1gbfg3HEAu6UVZ0W0aHGZ4QJvdy4zfoC/Yj2BwDvKJxbF/Ry8DHFQOE2qXAICVaeGQrrIKF59klb1e0jafI2RVYcU3HZd/HKsKMtdqIP+ojgIRBwJpc8EW0MzfLaWBxzlMjD6eoEbxxGcWiX6gNnWv/abtYPuNodA5lk/N5dA3hHFLyT23JCfX907EkfuEfdx8+3D4OH+4ebh+/iK94PACfaK5LOKk9GTblpp7RWhNuCgvEz8RVLJN1atF7VdxwU7WRZ+zUeOoefu/3UdVXnb27A0t12oHBluZY+KPJNKvlPWIGbIJYmYy+5mzncyGNy7ER+EA4z32y0zVWy7uAxzkqeSAcYDzPY9BrgIcjYWc2DFXeBTzayLJ8sqQHlqmrjcCj3egSnLwJ+FreBTwn44FDXL04mxoJrwG+sCmJQpik+FV5F/DwouEWMFzqwLP5kfgK4DG68KsRLkxy3qq8CzHvJLQoNrgs7v8im7UReC8ObLTTOwnH4i7QEcYkOG1ECMxnVzy/BvTojYeXX6bYa3mfCxljvIMAIRPNMvX5NcCzQxDIkSeZf1XeBXzKJibOOWQ3u9uV18Be0NWCEGKS/VflXbBnb7IPgTHp43xW1L4Cemy+BpGW8xP/TpX3QI+Y2CD0QdRNhWfJe/caDl4iNjan4CYOnirvgh7eLRPbYCmX2UXvXsPDSwjYsA2nMvHwVHkX9HBvHXvJN52gQV7Dw0sIIAvnwhMPT5V3AY82ksSGRJgBK8ivjdBnMoEixYmLp8q7oId/axOCZMymML/sX8PHYztSLNFPfDxV3gU9HNwohz0smuaD2nkiLMq9NpySJXiX6auIsI0iBZXp0ISYynS8nA/zaMLZfOSPtybFNjdlS17U7FAPbsxjToZKxi8SZNuD10SZTo114Mmm4C8my3zGzEE0cDTH11Jm3c1q9jlia4IstgJYwp/Y+JSZzDNsWZPCU0bX2S2PpYf15CRPqRM/NmKDKnlkonSuBEs1YZ6OCUyVRcCP0SLklZmrA2m4Wdn7xNK8CjKx3zH5ONJxOv4K3thQvBXeSocmiKVLYPwj5cqdz6Pjsc+yake3OOkOkbSjXT0sMC6Ix8dy5QSN9xIAXOae9g+ENnQcMYL3zclJ6CfD6lKuzxpMP8se+3pNgJ8bMkWSna+mmLE6qF4SqXVgxyk+0gzuPJ2SOUQr97PycnVhSIIgzfFAvJ1tXcK/bBH/EJfF6nJ9Bos1F0fBL9ZGV+Rs9Vw8xcXa5NB2yRg075d7coNZzbAspuSBNjpXmzDtLFkc4JZWWAUjxjJLsc1haBd7LhdEsVFmcSkWK7MQvS4WPjJ152tLywUzPS4PPmEOMgxYUlnutYAUVg9+MKcV81CYUTQfsTkH3foF7NgJEuYEr4LlOEvPPJygZ4Tuuozmad/QFNnJT7iMIQsJrhYH8RnFa7yAIaPNGbIDUxiwk4VnROGxtFFuOccwk5UIWkm3mtLKjj37gEt1W2Lv3Xndluq8EnPp/vSRbRHh+Dmj1uqPsd6YURiVeKsHxlamVTHO3/3tj7Fe/1EYlVirB8ZWllUxzt/z7Y9R3WVSILU6qwfKVoZVUc7f6r0CynozS6NU97W6oGxlVxXm/CXeK8CsF800THX9rAvMVmZVYc6rq64As96b0zCV2qoLzFZWVWHO
q6muALNeA9QwlbqqC8xGRlVRzqun+qNUlxoVSq2m6oFyIpuqMOfVUleAebyjqVFW8VQXkK1MSjkF81ear4CyXjnVMJVaqgvOVhalcG7l/agrtBqnuljbBWcrg1I4t/KA1JVgjVOpobrgbGVPCudWXpC64qxxKvVTF5ytzEnh3MoPUle2dVyl1E49cE5kTQrnVo6QuoKucSp1UxecrYxJ4VylXpKXs6N4mW6pv7Fq03LmPlfc1OKX65UCPG53JBO2kytd0WqcAe65bKcW91DtcDSIhDUPsdWRrtIoTu6joQuTO961vAdSRE5wuynmIn29SJ8kCd5snfdlNNLl8qQXWG93FCYxPSXe8ZBIUuVjHN385pwwSc+nmlApRYQ32GwaWZLDThuxnbiJLEmMmDimGFpZUnGmJBtjaVVJbI24/24iSorC8oQy7pE6GElyRQt9bSVJUfK3sFBpFElytGPte2oFSYBABHPwRI9kCyYHl1FHpN0MlEdphVs9UpF8pvC+rRwJqwW7LOq0aiTJmWVJs+jU6ZrharVIx/LduTdOapGO41v5FjqTQ7eyJwaMUuGwWFuWotwUxh570PScqX1DsHvAnMs203JtB/M5gA0OQcaKxrEiYubghEJZ0ToOZCELsP5oRWeScAvCsaUDT3SmNs4VrN/kg7XLPXF7yRV7rKvl2kApPCTs4ni5afFzo48xYRUs1saidF44JWF8l+3nhcixMv/jmqEXpjZg8LHw14w9tnu4dbLQizbKV2iRlEM2n6JvpUjz+f/pC+vpgka2VFs52/qLdEiO/L6x/emxSLX4zamWyjnpXP6RclLGU8RVQ7UonmuebGlvM3zlOXtORV1vKuj5paRJPRz+iQbpFOtydbD1+oUGq6RIXcC2mqNT9MvVwdY7JRqskh51AdtqjE7xMFcHWy/KaLDq+kwXsK2m6BQdc3Ww9fqPBqsuBXUB22qITpEy1war7jRpUlhJibqwwq1m6BQ1c3Wwx4taGmtVDnWB2mqETtEzV4daL59prEoq1Iftt1oTdIqluTrYeqNOg1XSoC5gGw3QKarm6ljrLUGNVd0d7IK11fycpGyujVbdfVRo9Y3IHmgnGp+TxM3V0dYbnRqtkvp0Qdtqek7SN1dHW++parRK2tMFbaPhOcnhXB1svXyrwSopTxewrWbnJJNzdbT1SrFGq6Q7XdC2Gp2TfM610aqL0jpQVdenu6BtNTmLrM5EinMZq3N1H1vF9ordUcH9i8mdifhma4ZnQws2mfmjBXsQPRPZyRLbsyHmhvXR6Z8OpM8E9eXMz0Rs89X8T0d7PrFAeZRhJAAbkcNS8k1OcXwYv8gsRHdecNMYXud20Clrkw2hFdxkHOdp/H43lS5IaA5t+NCqbUIRN8zn0IptAgAGvxer1GgtFkkY8viVbyquyfLdayFG1yptHHZjuFC0b0R7y1ZkZDH5idLGJiNf5TUR2uBFgkcx4lGeSpL32HJoZTYRZi9xz4LpIxFNE6Vxb1OnREyGLOaezu2uGq+GBqpv7M69cVp4Uwd4nO5jTt2fScHLd9JhllpAcsu1x8fowp75WqxOTvL7IaHreUV1EQGFYqNl1F9R3UjK1JXMvFzbEaaY3BMWgy7WpvGb5oQFCnG5467IRSN23rm0oieYhI7l+xEprDAK1iWmbCDGIK+wSRLmlrGc0+E7BhcGCCs6+2iLW2NxrCSMPlZFyCs6I5ougmXQuF+sXbBpO9kty0EY5b9aeKMTfydIhEZ3M0tQPCOD1hMaE5GO5oNOt/8i0c2ldFANKHb/B5p5Ts8KZW5kc3RyZWFtCmVuZG9iagoxMSAwIG9iagozODExCmVuZG9iagoyMCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDc5ID4+CnN0cmVhbQp4nE3Nuw3AIAwE0J4pPALg/z5RqrB/GxsiQmM/6U46wQ4V3OKwGGh3uFrxpVGYfeqZEpJQcz1EWDMl
OoSkX/rLMMOY2Mi277dW7hfeGxwZCmVuZHN0cmVhbQplbmRvYmoKMjEgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxNjUgPj4Kc3RyZWFtCnicRY87EgMhDEN7TqEjgH/AeTaTir1/G8s7SRosjCU/ois69srDY2PKxmu0sSfCFu5SOg2nqYyviqdnXaDLYTJTb1zNXGCqsMhuTrH6GHyh8uzmhK9VnhjCl0wJDTCVO7mH9fpRnJZ8JLsLguqUjcrCMEfS90BMTZunhYH8jy95akFQmeaNa5aVR2sVUzRnmCpbC4L1gaA6pfoD0/9Mp70/3PQ9gAplbmRzdHJlYW0KZW5kb2JqCjIyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzA0ID4+CnN0cmVhbQp4nD2SO5LDMAxDe52CF8iM+JPk82Qnlff+7T4yyVaASYkAKC91mbKmPCBpJgn/0eHhYjvld9iezczAtUQvE8spz6ErxNxF+bKZjbqyOsWqwzCdW/SonIuGTZOa5ypLGbcLnsO1ieeWfcQPNzSoB3WNS8IN3dVoWQrNcHX/O71H2Xc1PBebVOrUF48XURXm+SFPoofpSuJ8PCghXHswRhYS5FPRQI6zXK3yXkL2DrcassJBaknnsyc82HV6Ty5uF80QD2S5VPhOUezt0DO+7EoJPRK24VjufTuasekamzjsfu9G1sqMrmghfshXJ+slYNxTJkUSZE62WG6L1Z7uoSimc4ZzGSDq2YqGUuZiV6t/DDtvLC/ZLMiUzAsyRqdNnjh4yH6NmvR5led4/QFs83M7CmVuZHN0cmVhbQplbmRvYmoKMjMgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyMzAgPj4Kc3RyZWFtCnicNVFJbsMwDLzrFfOBAOIuv8dBT+3/rx3SCWBgaEuczREbGxF4icHPQeTGW9aMmvibyV3xuzwVHgm3gidRBF6Ge9kJLm8Yl/04zHzwXlo5kxpPMiAX2fTwRMhgl0DowOwa1GGbaSf6hoTPjkg1G1lOX0vQS6sQKE/ZfqcLSrSt6s/tsy607WtPONntqSeVTyCeW7ICl41XTBZjGfRE5S7F9EGqs4WehPKifA6y+aghEl2inIEnBgejQDuw57afiVeFoHV1n7aNoRopHU//NjQ1SSLkEyWc2dK4W/j+nnv9/AOmVFOfCmVuZHN0cmVhbQplbmRvYmoKMjQgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyMjcgPj4Kc3RyZWFtCnicNU87sgMhDOs5hS6QGYxtYM+zmVQv92+fZLINEv5I8vRERyZe5sgIrNnxthYZiBn4FlPxrz3tw4TqPbiHCOXiQphhJJw167ibp+PFv13lM9bBuw2+YpYXBLYwk/WVxZnLdsFYGidxTrIbY9dEbGNd6+kU1hFMKAMhne0wJcgcFSl9sqOMOTpO5InnYqrFLr/vYX3BpjGiwhxXBU/QZFCWPe8moB0X9N/Vjd9JNIteAjKRYGGdJObOWU741WtHx1GLIjEnpBnkMhHSnK5iCqEJxTo7CioVBZfqc8rdPv9oXVtNCmVuZHN0cmVhbQplbmRvYmoKMjUgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNDUgPj4Kc3RyZWFtCnicRVC7jUMxDOs9BRcIYP0se553SJXbvz1KRnCFIVo/kloSmIjASwyxlG/iR0ZBPQu/F4XiM8TPF4VBzoSkQJz1GRCZeIbaRm7odnDOvMMzjDkCF8VacKbTmfZc2OScBycQzm2U8YxCuklUFXFUn3FM8aqyz43XgaW1bLPTkewhjYRLSSUml35TKv+0KVsq6NpFE7BI5IGTTTThLD9DkmLMoJRR9zC1jvRxspFHddDJ2Zw5LZnZ7qftTHwP
WCaZUeUpnecyPiep81xOfe6zHdHkoqVV+5z93pGW8iK126HV6VclUZmN1aeQuDz/jJ/x/gOOoFk+CmVuZHN0cmVhbQplbmRvYmoKMjYgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzOTIgPj4Kc3RyZWFtCnicPVJLbgUxCNvPKbhApfBNcp6p3u7df1ubzFSqCi8DtjGUlwypJT/qkogzTH71cl3iUfK9bGpn5iHuLjam+FhyX7qG2HLRmmKxTxzJL8i0VFihVt2jQ/GFKBMPAC3ggQXhvhz/8ReowdewhXLDe2QCYErUbkDGQ9EZSFlBEWH7kRXopFCvbOHvKCBX1KyFoXRiiA2WACm+qw2JmKjZoIeElZKqHdLxjKTwW8FdiWFQW1vbBHhm0BDZ3pGNETPt0RlxWRFrPz3po1EytVEZD01nfPHdMlLz0RXopNLI3cpDZ89CJ2Ak5kmY53Aj4Z7bQQsx9HGvlk9s95gpVpHwBTvKAQO9/d6Sjc974CyMXNvsTCfw0WmnHBOtvh5i/YM/bEubXMcrh0UUqLwoCH7XQRNxfFjF92SjRHe0AdYjE9VoJRAMEsLO7TDyeMZ52d4VtOb0RGijRB7UjhE9KLLF5ZwVsKf8rM2xHJ4PJntvtI+UzMyohBXUdnqots9jHdR3nvv6/AEuAKEZCmVuZHN0cmVhbQplbmRvYmoKMjcgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNDcgPj4Kc3RyZWFtCnicTVG7bUQxDOvfFFzgAOtreZ4LUl32b0PJCJDCIKEvKaclFvbGSwzhB1sPvuSRVUN/Hj8x7DMsPcnk1D/muclUFL4VqpuYUBdi4f1oBLwWdC8iK8oH349lDHPO9+CjEJdgJjRgrG9JJhfVvDNkwomhjsNBm1QYd00ULK4VzTPI7VY3sjqzIGx4JRPixgBEBNkXkM1go4yxlZDFch6oCpIFWmDX6RtRi4IrlNYJdKLWxLrM4Kvn9nY3Qy/y4Ki6eH0M60uwwuileyx8rkIfzPRMO3dJI73wphMRZg8FUpmdkZU6PWJ9t0D/n2Ur+PvJz/P9CxUoXCoKZW5kc3RyZWFtCmVuZG9iagoyOCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDkwID4+CnN0cmVhbQp4nE2NQRLAIAgD77wiT1BE0P90etL/X6vUDr3ATgKJFkWC9DVqSzDuuDIVa1ApmJSXwFUwXAva7qLK/jJJTJ2G03u3A4Oy8XGD0kn79nF6AKv9egbdD9IcIlgKZW5kc3RyZWFtCmVuZG9iagoyOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDMzOCA+PgpzdHJlYW0KeJxFUktyxTAI2+cUXCAz5mfj87xOV+n9t5VwOt089AwICTI9ZUim3DaWZITkHPKlV2SI1ZCfRo5ExBDfKaHArvK5vJbEXMhuiUrxoR0/l6U3Ms2u0Kq3R6c2i0Y1KyPnIEOEelbozO5R22TD63Yh6TpTFodwLP9DBbKUdcoplARtQd/YI+hvFjwR3Aaz5nKzuUxu9b/uWwue1zpbsW0HQAmWc95gBgDEwwnaAMTc2t4WKSgfVbqKScKt8lwnO1C20Kp0vDeAGQcYOWDDkq0O12hvAMM+D/SiRsX2FaCoLCD+ztlmwd4xyUiwJ+YGTj1xOsWRcEk4xgJAiq3iFLrxHdjiLxeuiJrwCXU6ZU28wp7a4sdCkwjvUnEC8CIbbl0dRbVsT+cJtD8qkjNipB7E0QmR1JLOERSXBvXQGvu4iRmvjcTmnr7dP8I5n+v7Fxa4g+AKZW5kc3RyZWFtCmVuZG9iagozMCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDE2MyA+PgpzdHJlYW0KeJxFkLl1BDEMQ3NV
gRJ4gDrqGT9Hs/2nC2m83kD6eIR4iD0Jw3JdxYXRDT/etsw0vI4y3I31Zcb4qLFATtAHGCITV6NJ9e2KM1Tp4dVirqOiXC86IhLMkuOrQCN8OrLHQ1vbmX46r3/sIe8T/yoq525hAS6q7kD5Uh/x1I/ZUeqaoY8qK2seatq/CLsilLZ9XE5lnLp7B7TCZytX+30DqOc6gAplbmRzdHJlYW0KZW5kb2JqCjMxIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNjggPj4Kc3RyZWFtCnicMzK3UDBQsDQBEoYWJgrmZgYKKYZcQL6piblCLhdIDMTKAbMMgLQlnIKIW0I0QZSCWBClZiZmEEk4AyKXBgDJtBXlCmVuZHN0cmVhbQplbmRvYmoKMzIgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA0NSA+PgpzdHJlYW0KeJwzMrdQMFCwNAEShhYmCuZmBgophlyWEFYuF0wsB8wC0ZZwCiKeBgCffQy1CmVuZHN0cmVhbQplbmRvYmoKMzMgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNTUgPj4Kc3RyZWFtCnicRZFLkgMgCET3noIjgPzkPJmaVXL/7TSYTDZ2l6j9hEojphIs5xR5MP3I8s1ktum1HKudjQKKIhTM5Cr0WIHVnSnizLVEtfWxMnLc6R2D4g3nrpxUsrhRxjqqOhU4pufK+qru/Lgsyr4jhzIFbNY5DjZw5bZhjBOjzVZ3h/tEkKeTqaPidpBs+IOTxr7K1RW4Tjb76iUYB4J+oQlM8k2gdYZA4+YpenIJ9vFxu/NAsLe8CaRsCOTIEIwOQbtOrn9x6/ze/zrDnefaDFeOd/E7TGu74y8xyYq5gEXuFNTzPRet6wwd78mZY3LTfUPnXLDL3UGmz/wf6/cPUIpmiAplbmRzdHJlYW0KZW5kb2JqCjM0IDAgb2JqCjw8IC9CQm94IFsgLTEwMjEgLTQ2MyAxNzk0IDEyMzMgXSAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDM3Ci9TdWJ0eXBlIC9Gb3JtIC9UeXBlIC9YT2JqZWN0ID4+CnN0cmVhbQp4nOMyNDBTMDY1VcjlMjc2ArNywCwjcyMgCySLYEFk0wABXwoKCmVuZHN0cmVhbQplbmRvYmoKMzUgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxNjEgPj4Kc3RyZWFtCnicRZBLEsMgDEP3nEJH8EcGfJ50ukrvv60hTbOAp7FABncnBKm1BRPRBS9tS7oLPlsJzsZ46DZuNRLkBHWAVqTjaJRSfbnFaZV08Wg2cysLrRMdZg56lKMZoBA6Fd7touRypu7O+Udw9V/1R7HunM3EwGTlDoRm9SnufJsdUV3dZH/SY27Wa38V9qqwtKyl5YTbzl0zoATuqRzt/QWpczqECmVuZHN0cmVhbQplbmRvYmoKMzYgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzMjAgPj4Kc3RyZWFtCnicNVG7ccUwDOs1BRfwnfiVNI9zr8rbvw1AOxVhGgRAqrxkSrlc6pJVssLkR4fqFE35PmCm/A71kOPoHtkhulPWlnsYCMvEPKWOWE2We7gFgS8MTYm5hfP3COgrBqMwE4G6xd8/QLMkMGlw8FOQa61aYokOPCwWWLMrzK0aKVTIVXw7NrkHBXJxs9CnHJoUt9yC8GWIZEdqsa/LZSnyu/UJGIQV5ohPFImF54EOZiLxJwNie/bZYldXL6oRGdZJhwdSBNJsbhIwNEWy6oMb2FfHNT9PR9nByUG/isH4NjiZL0l5XwWhEI8X/g7P2cixkkMkFPJ9tcCII2yAEaFP7SMQZSA0RffumVI+JlWK7wBGIRx9qlcyvBeR2WqGzf8Z
XdkqCgZVWR+fRnAmg0k482SjCtNStdO/+9zj8wdjY3qACmVuZHN0cmVhbQplbmRvYmoKMzcgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyMTQgPj4Kc3RyZWFtCnicPVC7EUMxCOs9BQvkznztN8/Lpcv+bSScpEI2QhKUmkzJlIc6ypKsKU8dPktih7yH5W5kNiUqRS+TsCX30ArxfYnmFPfd1ZazQzSXaDl+CzMqqhsd00s2mnAqE7qg3MMz+g1tdANWhx6xWyDQpGDXtiByxw8YDMGZE4siDEpNBv+tcvdS3O89HG+iiJR08K755fTLzy28Tj2ORLq9+YprcaY6CkRwRmryinRhxbLIQ6TVBDU9A2u1AK7eevk3aEd0GYDsE4njNKUcQ//WuMfrA4eKUvQKZW5kc3RyZWFtCmVuZG9iagozOCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDgwID4+CnN0cmVhbQp4nEWMuw3AMAhEe6ZgBH4mZp8olbN/GyBK3HBPunu4OhIyU95hhocEngwshlPxBpmjYDW4RlKNneyjsG5fdYHmelOr9fcHKk92dnE9zcsZ9AplbmRzdHJlYW0KZW5kb2JqCjM5IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjM3ID4+CnN0cmVhbQp4nE1ROW4EMQzr/Qp9YADrtOc9G2w1+X8bUt4EqURDFEnJ5SlTMuWyuSQjJGvKlw6NLbZcvg9CU0tFwdS9RXXJa5imrFssXdDzqSyv4Rjge3c31D/0iNkCkdGCXWGBDpA7uGD4PXsmbFMLIlEl1AxgmrDCHK5EDEEGY50ZBqUKg1P1d5Xjsw07BdYOZlOkR1ITnXSD5oW33nIhgq1Tuak30oTc2acYYmXjvkqX4wPgYKLLRGTE5mU4ng5haPDONGCFsx7EBJnWkdKLmZExDTfpyNVXi4rPNhlLntH/9of6K59u/4MQfMb7B0lVXH4KZW5kc3RyZWFtCmVuZG9iago0MCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDMzMiA+PgpzdHJlYW0KeJwtUjmOJDEMy/0KfmAA6/Lxnh5M1Pv/dElVBQWqbMs85HLDRCV+LJDbUWvi10ZmoMLwr6vMhe9I28g6iGvIRVzJlsJnRCzkMcQ8xILv2/gZHvmszMmzB8Yv2fcZVuypCctCxosztMMqjsMqyLFg6yKqe3hTpMOpJNjji/8+xXMXgha+I2jAL/nnqyN4vqRF2j1m27RbD5ZpR5UUloPtac7L5EvrLFfH4/kg2d4VO0JqV4CiMHfGeS6OMm1lRGthZ4OkxsX25tiPpQRd6MZlpDgC+ZkqwgNKmsxsoiD+yOkhpzIQpq7pSie3URV36slcs7m8nUkyW/dFis0UzuvCmfV3mDKrzTt5lhOlTkX4GXu2BA2d4+rZa5mFRrc5wSslfDZ2enLyvZpZD8mpSEgV07oKTqPIFEvYlviaiprS1Mvw35f3GX//ATPifAEKZW5kc3RyZWFtCmVuZG9iago0MSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDY4ID4+CnN0cmVhbQp4nDMzNlMwULAwAhKmpoYK5kaWCimGXEA+iJXLBRPLAbPMLMyBLCMLkJYcLkMLYzBtYmykYGZiBmRZIDEgutIAcvgSkQplbmRzdHJlYW0KZW5kb2JqCjQyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzE3ID4+CnN0cmVhbQp4nDVSS3JDMQjbv1Nwgc6Yv32edLJq7r+thCcrsC1AQi4vWdJLftQl26XD5Fcf9yWxQj6P7ZrMUsX3FrMUzy2vR88Rty0KBFETPfgyJxUi1M/U6Dp4YZc+A68QTikWeAeTAAav
4V94lE6DwDsbMt4Rk5EaECTBmkuLTUiUPUn8K+X1pJU0dH4mK3P5e3KpFGqjyQgVIFi52AekKykeJBM9iUiycr03VojekFeSx2clJhkQ3SaxTbTA49yVtISZmEIF5liA1XSzuvocTFjjsITxKmEW1YNNnjWphGa0jmNkw3j3wkyJhYbDElCbfZUJqpeP09wJI6ZHTXbtwrJbNu8hRKP5MyyUwccoJAGHTmMkCtKwgBGBOb2wir3mCzkWwIhlnZosDG1oJbt6joXA0JyzpWHG157X8/4HRVt7owplbmRzdHJlYW0KZW5kb2JqCjQzIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTcgPj4Kc3RyZWFtCnicMza0UDCAwxRDLgAalALsCmVuZHN0cmVhbQplbmRvYmoKNDQgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMzEgPj4Kc3RyZWFtCnicRY/LDQQhDEPvVOES8hk+qYfVntj+r+swmkFC+EEiO/EwCKzz8jbQxfDRosM3/jbVq2OVLB+6elJWD+mQh7zyFVBpMFHEhVlMHUNhzpjKyJYytxvhtk2DrGyVVK2DdjwGD7anZasIfqltYeos8QzCVV64xw0/kEutd71Vvn9CUzCXCmVuZHN0cmVhbQplbmRvYmoKNDUgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzMzggPj4Kc3RyZWFtCnicNVI5rt1ADOt9Cl0ggHbNnOcFqX7u34aUXwpDtFaKmo4WlWn5ZSFVLZMuv+1JbYkb8vfJCokTklcl2qUMkVD5PIVUv2fLvL7WnBEgS5UKk5OSxyUL/gyX3i4c52NrP48jdz16YFWMhBIByxQTo2tZOrvDmo38PKYBP+IRcq5YtxxjFUgNunHaFe9D83nIGiBmmJaKCl1WiRZ+QfGgR61991hUWCDR7RxJcIyNUJGAdoHaSAw5sxa7qC/6WZSYCXTtiyLuosASScycYl06+g8+dCyovzbjy6+OSvpIK2tM2nejSWnMIpOul0VvN299PbhA8y7Kf17NIEFT1ihpfNCqnWMomhllhXccmgw0xxyHzBM8hzMSlPR9KH5fSya6KJE/Dg2hf18eo4ycBm8Bc9GftooDF/HZYa8cYIXSxZrkfUAqE3pg+v/X+Hn+/AMctoBUCmVuZHN0cmVhbQplbmRvYmoKNDYgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNDggPj4Kc3RyZWFtCnicLVE5kgNBCMvnFXpCc9PvscuR9//pCsoBg4ZDIDotcVDGTxCWK97yyFW04e+ZGMF3waHfynUbFjkQFUjSGFRNqF28Hr0HdhxmAvOkNSyDGesDP2MKN3pxeEzG2e11GTUEe9drT2ZQMisXccnEBVN12MiZw0+mjAvtXM8NyLkR1mUYpJuVxoyEI00hUkih6iapM0GQBKOrUaONHMV+6csjnWFVI2oM+1xL29dzE84aNDsWqzw5pUdXnMvJxQsrB/28zcBFVBqrPBAScL/bQ/2c7OQ33tK5s8X0+F5zsrwwFVjx5rUbkE21+Dcv4vg94+v5/AOopVsWCmVuZHN0cmVhbQplbmRvYmoKNDcgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxNzEgPj4Kc3RyZWFtCnicTZBNDkIhEIP3nKIXMKHzA4/zaFzp/bd28PnigvRLIUOnwwMdR+JGR4bO6HiwyTEOvAsyJl6N85+M6ySOCeoVbcG6tDvuzSwxJywTI2BrlNybRxT44ZgLQYLs8sMXGESka5hvNZ91k35+u9Nd1KV199MjCpzIjlAMG3AF2NM9DtwSzu+aJr9UKRmbOJQPVBeRstkJhailYpdTVWiM4lY974te7fkBwfY7+wplbmRzdHJlYW0KZW5k
b2JqCjQ4IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjEwID4+CnN0cmVhbQp4nDVQyw1DMQi7ZwoWqBQCgWSeVr11/2tt0DthEf9CWMiUCHmpyc4p6Us+OkwPti6/sSILrXUl7MqaIJ4r76GZsrHR2OJgcBomXoAWN2DoaY0aNXThgqYulUKBxSXwmXx1e+i+Txl4ahlydgQRQ8lgCWq6Fk1YtDyfkE4B4v9+w+4t5KGS88qeG/kbnO3wO7Nu4SdqdiLRchUy1LM0xxgIE0UePHlFpnDis9Z31TQS1GYLTpYBrk4/jA4AYCJeWYDsrkQ5S9KOpZ9vvMf3D0AAU7QKZW5kc3RyZWFtCmVuZG9iagoxOCAwIG9iago8PCAvQmFzZUZvbnQgL0RlamFWdVNhbnMgL0NoYXJQcm9jcyAxOSAwIFIKL0VuY29kaW5nIDw8Ci9EaWZmZXJlbmNlcyBbIDMyIC9zcGFjZSA0OCAvemVybyAvb25lIC90d28gL3RocmVlIC9mb3VyIC9maXZlIC9zaXggL3NldmVuIC9laWdodCAvbmluZQo2OSAvRSA4MCAvUCA5NyAvYSA5OSAvYyAvZCAvZSAxMDMgL2cgL2ggL2kgMTA4IC9sIC9tIC9uIC9vIDExMyAvcSAxMTUgL3MgL3QKL3UgXQovVHlwZSAvRW5jb2RpbmcgPj4KL0ZpcnN0Q2hhciAwIC9Gb250QkJveCBbIC0xMDIxIC00NjMgMTc5NCAxMjMzIF0gL0ZvbnREZXNjcmlwdG9yIDE3IDAgUgovRm9udE1hdHJpeCBbIDAuMDAxIDAgMCAwLjAwMSAwIDAgXSAvTGFzdENoYXIgMjU1IC9OYW1lIC9EZWphVnVTYW5zCi9TdWJ0eXBlIC9UeXBlMyAvVHlwZSAvRm9udCAvV2lkdGhzIDE2IDAgUiA+PgplbmRvYmoKMTcgMCBvYmoKPDwgL0FzY2VudCA5MjkgL0NhcEhlaWdodCAwIC9EZXNjZW50IC0yMzYgL0ZsYWdzIDMyCi9Gb250QkJveCBbIC0xMDIxIC00NjMgMTc5NCAxMjMzIF0gL0ZvbnROYW1lIC9EZWphVnVTYW5zIC9JdGFsaWNBbmdsZSAwCi9NYXhXaWR0aCAxMzQyIC9TdGVtViAwIC9UeXBlIC9Gb250RGVzY3JpcHRvciAvWEhlaWdodCAwID4+CmVuZG9iagoxNiAwIG9iagpbIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwCjYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgMzE4IDQwMSA0NjAgODM4IDYzNgo5NTAgNzgwIDI3NSAzOTAgMzkwIDUwMCA4MzggMzE4IDM2MSAzMTggMzM3IDYzNiA2MzYgNjM2IDYzNiA2MzYgNjM2IDYzNiA2MzYKNjM2IDYzNiAzMzcgMzM3IDgzOCA4MzggODM4IDUzMSAxMDAwIDY4NCA2ODYgNjk4IDc3MCA2MzIgNTc1IDc3NSA3NTIgMjk1CjI5NSA2NTYgNTU3IDg2MyA3NDggNzg3IDYwMyA3ODcgNjk1IDYzNSA2MTEgNzMyIDY4NCA5ODkgNjg1IDYxMSA2ODUgMzkwIDMzNwozOTAgODM4IDUwMCA1MDAgNjEzIDYzNSA1NTAgNjM1IDYxNSAzNTIgNjM1IDYzNCAyNzggMjc4IDU3OSAyNzggOTc0IDYzNCA2MTIKNjM1IDYzNSA0MTEgNTIxIDM5MiA2MzQgNTkyIDgxOCA1OTIgNTkyIDUyNSA2MzYgMzM3IDYzNiA4MzggNjAwIDYzNiA2MDAgMzE4CjM1MiA1MTggMTAwMCA1MDAgNTAwIDUwMCAxMzQyIDYz
NSA0MDAgMTA3MCA2MDAgNjg1IDYwMCA2MDAgMzE4IDMxOCA1MTggNTE4CjU5MCA1MDAgMTAwMCA1MDAgMTAwMCA1MjEgNDAwIDEwMjMgNjAwIDUyNSA2MTEgMzE4IDQwMSA2MzYgNjM2IDYzNiA2MzYgMzM3CjUwMCA1MDAgMTAwMCA0NzEgNjEyIDgzOCAzNjEgMTAwMCA1MDAgNTAwIDgzOCA0MDEgNDAxIDUwMCA2MzYgNjM2IDMxOCA1MDAKNDAxIDQ3MSA2MTIgOTY5IDk2OSA5NjkgNTMxIDY4NCA2ODQgNjg0IDY4NCA2ODQgNjg0IDk3NCA2OTggNjMyIDYzMiA2MzIgNjMyCjI5NSAyOTUgMjk1IDI5NSA3NzUgNzQ4IDc4NyA3ODcgNzg3IDc4NyA3ODcgODM4IDc4NyA3MzIgNzMyIDczMiA3MzIgNjExIDYwNQo2MzAgNjEzIDYxMyA2MTMgNjEzIDYxMyA2MTMgOTgyIDU1MCA2MTUgNjE1IDYxNSA2MTUgMjc4IDI3OCAyNzggMjc4IDYxMiA2MzQKNjEyIDYxMiA2MTIgNjEyIDYxMiA4MzggNjEyIDYzNCA2MzQgNjM0IDYzNCA1OTIgNjM1IDU5MiBdCmVuZG9iagoxOSAwIG9iago8PCAvRSAyMCAwIFIgL1AgMjEgMCBSIC9hIDIyIDAgUiAvYyAyMyAwIFIgL2QgMjQgMCBSIC9lIDI1IDAgUgovZWlnaHQgMjYgMCBSIC9maXZlIDI3IDAgUiAvZm91ciAyOCAwIFIgL2cgMjkgMCBSIC9oIDMwIDAgUiAvaSAzMSAwIFIKL2wgMzIgMCBSIC9tIDMzIDAgUiAvbiAzNSAwIFIgL25pbmUgMzYgMCBSIC9vIDM3IDAgUiAvb25lIDM4IDAgUiAvcSAzOSAwIFIKL3MgNDAgMCBSIC9zZXZlbiA0MSAwIFIgL3NpeCA0MiAwIFIgL3NwYWNlIDQzIDAgUiAvdCA0NCAwIFIgL3RocmVlIDQ1IDAgUgovdHdvIDQ2IDAgUiAvdSA0NyAwIFIgL3plcm8gNDggMCBSID4+CmVuZG9iagozIDAgb2JqCjw8IC9GMSAxOCAwIFIgPj4KZW5kb2JqCjQgMCBvYmoKPDwgL0ExIDw8IC9DQSAwIC9UeXBlIC9FeHRHU3RhdGUgL2NhIDEgPj4KL0EyIDw8IC9DQSAxIC9UeXBlIC9FeHRHU3RhdGUgL2NhIDEgPj4gPj4KZW5kb2JqCjUgMCBvYmoKPDwgPj4KZW5kb2JqCjYgMCBvYmoKPDwgPj4KZW5kb2JqCjcgMCBvYmoKPDwgL0YxLURlamFWdVNhbnMtbWludXMgMzQgMCBSIC9NMCAxMiAwIFIgL00xIDEzIDAgUiAvTTIgMTQgMCBSIC9NMyAxNSAwIFIKPj4KZW5kb2JqCjEyIDAgb2JqCjw8IC9CQm94IFsgLTggLTggOCA4IF0gL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMzEgL1N1YnR5cGUgL0Zvcm0KL1R5cGUgL1hPYmplY3QgPj4Kc3RyZWFtCnicbZBBDoQgDEX3PUUv8ElLRWXr0mu4mUzi/bcDcUBM3TTQvjx+Uf6S8E6lwPgkCUtOs+R605DSukyMGObVsijHoFEt1s51OKjP0HBjdIuxFKbU1uh4o5vpNt6TP/qwWSFGPxwOr4R7FkMmXCkxBoffCy/bw/8Rnl7UwB+ijX5jWkP9CmVuZHN0cmVhbQplbmRvYmoKMTMgMCBvYmoKPDwgL0JCb3ggWyAtOCAtOCA4IDggXSAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDEzMSAvU3VidHlwZSAvRm9ybQovVHlwZSAvWE9iamVjdCA+PgpzdHJlYW0KeJxtkEEOhCAMRfc9RS/wSUtFZevSa7iZTOL9twNxQEzdNNC+PH5R/pLwTqXA+CQJS06z5HrT
kNK6TIwY5tWyKMegUS3WznU4qM/QcGN0i7EUptTW6Hijm+k23pM/+rBZIUY/HA6vhHsWQyZcKTEGh98LL9vD/xGeXtTAH6KNfmNaQ/0KZW5kc3RyZWFtCmVuZG9iagoxNCAwIG9iago8PCAvQkJveCBbIC04IC04IDggOCBdIC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTMxIC9TdWJ0eXBlIC9Gb3JtCi9UeXBlIC9YT2JqZWN0ID4+CnN0cmVhbQp4nG2QQQ6EIAxF9z1FL/BJS0Vl69JruJlM4v23A3FATN000L48flH+kvBOpcD4JAlLTrPketOQ0rpMjBjm1bIox6BRLdbOdTioz9BwY3SLsRSm1NboeKOb6Tbekz/6sFkhRj8cDq+EexZDJlwpMQaH3wsv28P/EZ5e1MAfoo1+Y1pD/QplbmRzdHJlYW0KZW5kb2JqCjE1IDAgb2JqCjw8IC9CQm94IFsgLTggLTggOCA4IF0gL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMzEgL1N1YnR5cGUgL0Zvcm0KL1R5cGUgL1hPYmplY3QgPj4Kc3RyZWFtCnicbZBBDoQgDEX3PUUv8ElLRWXr0mu4mUzi/bcDcUBM3TTQvjx+Uf6S8E6lwPgkCUtOs+R605DSukyMGObVsijHoFEt1s51OKjP0HBjdIuxFKbU1uh4o5vpNt6TP/qwWSFGPxwOr4R7FkMmXCkxBoffCy/bw/8Rnl7UwB+ijX5jWkP9CmVuZHN0cmVhbQplbmRvYmoKMiAwIG9iago8PCAvQ291bnQgMSAvS2lkcyBbIDEwIDAgUiBdIC9UeXBlIC9QYWdlcyA+PgplbmRvYmoKNDkgMCBvYmoKPDwgL0NyZWF0aW9uRGF0ZSAoRDoyMDIyMDUzMTE2NTkyNiswMicwMCcpCi9DcmVhdG9yIChNYXRwbG90bGliIHYzLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZykKL1Byb2R1Y2VyIChNYXRwbG90bGliIHBkZiBiYWNrZW5kIHYzLjMuMikgPj4KZW5kb2JqCnhyZWYKMCA1MAowMDAwMDAwMDAwIDY1NTM1IGYgCjAwMDAwMDAwMTYgMDAwMDAgbiAKMDAwMDAxNTUxOSAwMDAwMCBuIAowMDAwMDE0MjM3IDAwMDAwIG4gCjAwMDAwMTQyNjkgMDAwMDAgbiAKMDAwMDAxNDM2OCAwMDAwMCBuIAowMDAwMDE0Mzg5IDAwMDAwIG4gCjAwMDAwMTQ0MTAgMDAwMDAgbiAKMDAwMDAwMDA2NSAwMDAwMCBuIAowMDAwMDAwMzk3IDAwMDAwIG4gCjAwMDAwMDAyMDggMDAwMDAgbiAKMDAwMDAwNDI4MyAwMDAwMCBuIAowMDAwMDE0NTAzIDAwMDAwIG4gCjAwMDAwMTQ3NTcgMDAwMDAgbiAKMDAwMDAxNTAxMSAwMDAwMCBuIAowMDAwMDE1MjY1IDAwMDAwIG4gCjAwMDAwMTI4NDggMDAwMDAgbiAKMDAwMDAxMjY0OCAwMDAwMCBuIAowMDAwMDEyMjAzIDAwMDAwIG4gCjAwMDAwMTM5MDEgMDAwMDAgbiAKMDAwMDAwNDMwNCAwMDAwMCBuIAowMDAwMDA0NDU1IDAwMDAwIG4gCjAwMDAwMDQ2OTMgMDAwMDAgbiAKMDAwMDAwNTA3MCAwMDAwMCBuIAowMDAwMDA1MzczIDAwMDAwIG4gCjAwMDAwMDU2NzMgMDAwMDAgbiAKMDAwMDAwNTk5MSAwMDAwMCBuIAowMDAwMDA2NDU2IDAwMDAwIG4gCjAwMDAwMDY3NzYgMDAwMDAgbiAKMDAwMDAwNjkzOCAwMDAwMCBuIAowMDAwMDA3MzQ5IDAwMDAwIG4gCjAwMDAwMDc1ODUgMDAwMDAgbiAKMDAwMDAwNzcyNSAwMDAwMCBuIAowMDAwMDA3ODQy
IDAwMDAwIG4gCjAwMDAwMDgxNzAgMDAwMDAgbiAKMDAwMDAwODM0MCAwMDAwMCBuIAowMDAwMDA4NTc0IDAwMDAwIG4gCjAwMDAwMDg5NjcgMDAwMDAgbiAKMDAwMDAwOTI1NCAwMDAwMCBuIAowMDAwMDA5NDA2IDAwMDAwIG4gCjAwMDAwMDk3MTYgMDAwMDAgbiAKMDAwMDAxMDEyMSAwMDAwMCBuIAowMDAwMDEwMjYxIDAwMDAwIG4gCjAwMDAwMTA2NTEgMDAwMDAgbiAKMDAwMDAxMDc0MCAwMDAwMCBuIAowMDAwMDEwOTQ0IDAwMDAwIG4gCjAwMDAwMTEzNTUgMDAwMDAgbiAKMDAwMDAxMTY3NiAwMDAwMCBuIAowMDAwMDExOTIwIDAwMDAwIG4gCjAwMDAwMTU1NzkgMDAwMDAgbiAKdHJhaWxlcgo8PCAvSW5mbyA0OSAwIFIgL1Jvb3QgMSAwIFIgL1NpemUgNTAgPj4Kc3RhcnR4cmVmCjE1NzM2CiUlRU9GCg==\n", "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-05-31T16:59:26.427804\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.3.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sns.set_theme()\n", "fig, ax = plt.subplots(2, 2, figsize=(12,4))\n", "ax = [a for a_list in ax for a in a_list]\n", "for i in range(len(ax)):\n", " ax[i].plot(np.arange(1,17), pe[i,:16], color=f'C{i}', marker=\"o\", markersize=6, markeredgecolor=\"black\")\n", " ax[i].set_title(f\"Encoding in hidden dimension {i+1}\")\n", " ax[i].set_xlabel(\"Position in sequence\", fontsize=10)\n", " ax[i].set_ylabel(\"Positional encoding\", fontsize=10)\n", " ax[i].set_xticks(np.arange(1,17))\n", " ax[i].tick_params(axis='both', which='major', labelsize=10)\n", " ax[i].tick_params(axis='both', which='minor', labelsize=8)\n", " ax[i].set_ylim(-1.2, 1.2)\n", "fig.subplots_adjust(hspace=0.8)\n", "sns.reset_orig()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the patterns of hidden dimensions $1$ and $2$ differ only in their starting angle (phase). Their wavelength is $2\\pi$, hence the pattern repeats approximately every $2\\pi \\approx 6.28$ positions. Hidden dimensions $3$ and $4$ have a longer wavelength, since each subsequent pair of dimensions uses a lower frequency. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Learning rate warm-up\n", "\n", "One commonly used technique for training a Transformer is learning rate warm-up. This means that we gradually increase the learning rate from 0 to our originally specified learning rate in the first few iterations. Thus, we slowly start learning instead of taking very large steps from the beginning. In fact, training a deep Transformer without learning rate warm-up can make the model diverge and achieve much worse performance on training and testing. Take for instance the following plot by [Liu et al. (2019)](https://arxiv.org/pdf/1908.03265.pdf) comparing vanilla Adam (i.e., Adam without warm-up) with Adam with warm-up:\n", "\n", "
\n", "\n", "Clearly, the warm-up is a crucial hyperparameter in the Transformer architecture. Why is it so important? There are currently two common explanations. Firstly, during the first iterations Adam's adaptive learning rate is computed from second-moment estimates based on only a few gradients. Despite Adam's bias correction factors, this leads to a high variance in the adaptive learning rate. Improved optimizers like [RAdam](https://arxiv.org/abs/1908.03265) have been shown to overcome this issue and do not require warm-up for training Transformers. Secondly, Layer Normalization applied after each residual connection (Post-LN, as in the original architecture) can lead to very large gradients during the first iterations. This can be solved by using [Pre-Layer Normalization](https://proceedings.icml.cc/static/paper_files/icml/2020/328-Paper.pdf) (similar to Pre-Activation ResNet), or by replacing Layer Normalization with other techniques ([Adaptive Normalization](https://proceedings.icml.cc/static/paper_files/icml/2020/328-Paper.pdf), [Power Normalization](https://arxiv.org/abs/2003.07845)). \n", "\n", "Nevertheless, many applications and papers still use the original Transformer architecture with Adam, because warm-up is a simple yet effective way of solving the gradient problem in the first iterations. There are many different schedulers we could use. For instance, the original Transformer paper used a linear warm-up followed by an inverse square-root decay of the learning rate. However, one of the most popular schedulers today is the cosine warm-up scheduler, which combines warm-up with a cosine-shaped learning rate decay. Optax provides a similar scheduler as `optax.warmup_cosine_decay_schedule`, but let's implement it manually below and visualize the learning rate factor over iterations. 
" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def cosine_warmup_schedule(base_lr: float, warmup: int, max_iters: int):\n", " assert warmup > 0 and max_iters > 0\n", " # Create function to return lr based on iteration count\n", " def get_lr(train_iter):\n", " lr_factor = 0.5 * (1 + np.cos(np.pi * train_iter / max_iters))\n", " if train_iter <= warmup:\n", " lr_factor *= train_iter * 1.0 / warmup\n", " return lr_factor * base_lr\n", " return get_lr" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1BhZ2VzIDIgMCBSIC9UeXBlIC9DYXRhbG9nID4+CmVuZG9iago4IDAgb2JqCjw8IC9FeHRHU3RhdGUgNCAwIFIgL0ZvbnQgMyAwIFIgL1BhdHRlcm4gNSAwIFIKL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gL1NoYWRpbmcgNiAwIFIKL1hPYmplY3QgNyAwIFIgPj4KZW5kb2JqCjEwIDAgb2JqCjw8IC9Bbm5vdHMgWyBdIC9Db250ZW50cyA5IDAgUgovR3JvdXAgPDwgL0NTIC9EZXZpY2VSR0IgL1MgL1RyYW5zcGFyZW5jeSAvVHlwZSAvR3JvdXAgPj4KL01lZGlhQm94IFsgMCAwIDUwMC44NzgxMjUgMjI3LjA0MjUgXSAvUGFyZW50IDIgMCBSIC9SZXNvdXJjZXMgOCAwIFIKL1R5cGUgL1BhZ2UgPj4KZW5kb2JqCjkgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMSAwIFIgPj4Kc3RyZWFtCnicvVhNcxs3DL3vr+AxOYQGQAIgj0k/PO300sQzveTiOkpij+00ttP8/T5QlqVVnU47Xfkgj/aZAvGABxBcThfT0UtOH24T/iRKF/h8xffjeJ4ovaB0NSlRbt5YNJ4vd59FPFPFl0ss3nn4OE3vJ8qd3aqTtpb2H2on7kbe0k1sfPy3BQ8P097qaaqeZb1/lUwyvl1NtZds+/DlLiykmWmDPxiZwcPzz+mxLWq1XBNbyQQ/Vum3dJ2OXso6ej/jc4HPiN509P3qz/Oz1evjV+nsdjLPap1Knfm7RWcOTG+mX9PnjWHKrMjMxvZ4PL5Hp1cn6ehHTszp5P1kNSuZlkqdepKa8T0snrybntHzdHKRfjgZlhelxtxyEbZWZtx24EXIMfXMpXIrXMX32IkekJ/1TF5LkTm/LbwMP+25NYRLuFfd44d6Oxg/7JebWG8847cDL8JPmLKVXmGmWtnj5wfMnzhl9W5zeW7RZdhpqB3mSEbEZuyYDpg+6CVDhOJ9xm8HXoRgQflRC3ul9v3mwoesv+Kw0WqXNie4hZchGPVXwl4l3u8vfMgCrCLY2dV9fpht4UUIVhSgtrBXtew3GD5kBVYvWTtBPHOCW3gZggbJV9gzJd3vMLJTg5tfybovSfbeKiKdevaHH/x0t7o5vTv/dH2b3j47v06/n96dfVzdvn1+wDA9GOu5VmoqezPNA/w/GpXkamphTTEtlKKIPzIwZoR8SA1sjHnHN8VBPie3hZcg54ru55ggcaJuyMkTkBtTCvWOqWfGb
gdfgh5HvbhUNCyMyPf86lPwKw2jeNXa9/ht8UX4Fc1EThj6jeyenz0FP2vZoRveu07s4Ivww6jeKhvH84Zfewp+vWXjYnvFt4UXYdfRet24kejmfMmPd1+C1ReM2x9rxmQRTdg56/jZFf5L49e/rE5vrs+vPyR05FV6f3p29+lmiWBlTV9xt8TIRqULq4NYXV85BUckJLe5dNb0+jjNg7p7p9vp1w03YG2kYNg5JgrEFlEF3Jyow6XusKtlwIhUR9ezCJqhggQLLnHXidmGWgkYgXHiqgPvuTR46iPGjosQuhtwNph0hUlGUqV1TDOB4wwUapjXwhe12jzODi6cXag0Sdw6Ei+Fh5kCHqV6w7bwt7ipcOCIL86JIMIthOtleKmaq4tJmIHDiDW0DBwxx2RNCCf6L2j72goGNgykSliO9lwF4/cw0zgLGgfEx5F7VH0ZZjrmFWhCo1yQ4O51VB9O8qy44nrg0Fz3NoIgDJEBFYOYHAI1GuOEQJhmnQtgrKiI61iOFmOhUWxbS4hPJWImWARnDBTXXQhnZGQKoc5INkWqcDXCqUkS7he4A9FzBJkMwSll5KpA0YLkGycUl5tQhL6Ix2rDrl0yx4VoLEb7jO7SOTXI1kob5yDG0cwQZAEMrwSRHKtx2WiQQU0OzRTraxQtCi45W0Kku0EDY0esiXMVmceNNXpOJKM0eNfIkHezzE3bWAvxYPqlygk5RDlyHFdAckHitSZDlyIkoA8YN4IiDWmGP4otStjADhAmZJHUMlSOqSzQUKI2CCopFIQUDMslVOmIcNLYr6OGBhyq0Ui3IkgMK8NGDU9xqsM0A8aNeRjRoXhoJClBp8j0MAJfI7pIXo2bLAb7NmBHenu8IAGsUjAjjldB2B/VgeCtixmH2bfheYuUaCdgd//a7JtD3Wz0ffTNE+w++grr6puvsOIX/+VV2Gz9jqV/3IFA8F+dA7J+kQGJr+2gIu/tRBf/7tPt+TUa8enN1Ysvf6SHpv46mvobDNfvvlyutn19+gtbAVgmCmVuZHN0cmVhbQplbmRvYmoKMTEgMCBvYmoKMTM1OAplbmRvYmoKMTYgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyOTIgPj4Kc3RyZWFtCnicJVJLbgUxCNvnFL5ApfAn55nqrV7vv63JLEYQIMZ2plSxERs/Isg2lDZ+ZUU3NAR/a4qqhe/KNIgfJE99EBXg6VmhhgPPF8Q6b3yW7XMz9YRDTsOC5QuREFF+gi5IzXgq1GS46IkJxn1E33ArpLOkOIqIHhTdSCFskncklHn24CvBSwymftUYpUx8lse5WQgjyUU2smZJCjfMFkoOCsiiKdzA9VVU2ZQfFBhkE5acKdIJMhjmToykNjdy8LoWTKQFtm+mzY7RSOa5p6NCpBYIO+FBHeSzB03C4UFXOOvEn2iUP84xowkyeumilEP37Zyp8smoTKhb5z4nuec4ml4OtQOrnkDKdJqJV2II0C6RMS8GlnSi5sXGjPc/eNbnHx7UZ4EKZW5kc3RyZWFtCmVuZG9iagoxNyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDQ4ID4+CnN0cmVhbQp4nDMyt1AwULA0BhKGFhYK5oZmCimGXGB+LogCCeRwwaQgLAMgDVaRw5UGAIcJDFMKZW5kc3RyZWFtCmVuZG9iagoxOCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDU3ID4+CnN0cmVhbQp4nDM1NVMwUDA3BhKmRoYK5oZmCimGXGB+LogCCeRwGZpZILEsTIAMkGo4wwBIg/XkcKUBAH5vD7EKZW5kc3RyZWFtCmVuZG9iagoxOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDI3OCA+PgpzdHJlYW0KeJxFkbtxBDEMQ/OtgiVQ/Er1n
MfRuv/UAOUZB3fECvw8UW0mKn34p0d6lXytZ75/GHjwPn7qT0XpqNQ9BRlLCu7nyQ6p7VKKE2SVIdqGU8hJ5FQi0mGECwcq6kh5SFhNrR9jt9ri2ZKrxT0l4ogb5wTm+EL+cXE1kJTYUdagi1XMDHOX3CnrJBySKugTeYwBmoWOvM9GDigixYwULRbNOaixVvC42C5xbMOmm1diMigS7sLPa4jgGBxsxxXUa6POJ36e1Ve9o8jQF4mfody2d48ClOmoSABhUPSWWM4r4sKBJ+P6AmN55chZfxTycJLod3riXfIodMpQrAS1i4smbhw+EerS8AR3OuPFpfrne5/vX2XBaFcKZW5kc3RyZWFtCmVuZG9iagoyMCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDQzOSA+PgpzdHJlYW0KeJwtk8txBDEIRO8TBQlslfhKimddPtn5X/1gfYKRUNPdMFVblkTKS1VKU7Yd+dKHE/Mlv496kB35ITtiy0TjiG5iBfmV96P7ilpx66KLKi+5wYXVkbPEF7guHtXh/fixTkJd9pEIldMwUVfuleBDNSQbzbs+tURTJc2n88Rp0Jlx2xXdtd9Y9pu4SwzABrVDMyfe7BtrRtCht9PH1cTRCl0INklbVJTSi9vdbzQ4uT68ApfuHtrvB3WBhZCKOyG9jzvhGJ0FltKt3BoHvlXgwLHORgUWrOZr3G+F466ZgVtMxK29J4s+0VZEXI2WK0BJehaoSzKPlO2+gVeCnRfXsj65jmO8zLWZZxPOhPih5hRLACdd41DCF4fwrdgEp3tHWO41mdmlImbipT47ANr9bEVLn+jNZTJtN5kCU1E4xZ71WKjJzcTaThwGMayVBfWButbq7EPcntcoo8apz4IXHBP9vhqt2DIDt7fYQJs4PnXWLIo6RUsyl8/6YFr71n7T7LDY1d1jywtjls5v4bkmQvh8fhRjyBRoT7BmPXq7mXO7IqezLk8W2XtgVJr8/1Lv5/sPgTakSQplbmRzdHJlYW0KZW5kb2JqCjIxIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTY1ID4+CnN0cmVhbQp4nDVQO7IDMQzqfQqOYP2t8+TNqzb3b4OS3cZgNCDsdseGKI82Q0niT5buEd6L+ijXkt03UzlQr2FGYxdvRB+1IXvjtTQLUgktTnVDj0BNOXGpOydMbpbmsPBh0dB0ZI2/kedJzD7cEXT8dpYGNJKTY0/i0/9aZcn6zKNj0NsR5GRsEi3w4hZpoiK1pxn3Ric8CpGHKF8PX9P9Tfl9yrX+P/lXOmAKZW5kc3RyZWFtCmVuZG9iagoyMiAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDQzNyA+PgpzdHJlYW0KeJw1kklyJTEIRPd1Ci7gCDFoOs/v8Mp9/22/pNyLKpAQkJkw57JhuezL3aaXzRz2x58aZavs75PbFc4a5hgfNu3zxBn2NS1qd2J4tv08Pt9S7mFhJ4xyn2dS+6jMTf09N5dyVljx+Ez6WozF9aJsbKNBVNm9FlOv3bfFuuQei307NY4SnFNcng8yb5GGTx4dAJJj05K25Ofli47Io/Nrz2tn/I8cbs4FGnk7reIoMoeV3qJDTaGItqgByb4ZsggF+MrGtvAChoV2dzbznPeVRNL+PJwKjCpGEB61JJmPY4V+nmlzSPzNfIQwBmrGy1PTilZPOeImL9FQLxK5NdPPIwyTkRac6/JN/K1JFnVLGDasqFiHqAt7Hd6IESq3CrLZ1fACPX/a85zEmFh16SWMBVfBGwxpNIbRKAJLFjwcekOi2O+qvdIH5Fm69e6WhhYIGdqO0BqobUjQq61DUGDHuC01NyPNNQCIe6lJ7ySgfR2AEoF42+wcearCUl2YsLynxd8NSfOcQlDWOxgU0fkeRROF9/1dDPYut4phj5r3PC4QI
CRizj41wXeXfqn+PN//ABlPplMKZW5kc3RyZWFtCmVuZG9iagoyMyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDI1MSA+PgpzdHJlYW0KeJxNkUtuxEAIRPd9Ci4wEv9uzuPRrJz7b1PgiZSFRYmigYcjkpgy6CVCIUFbkt6yxDfyPwsO092hjXtJ2D/l2aYUU3CS2qYwJq2YeC3TJ2OBqEZ2EkMEjouT1yE3fIfJsYejCs6GY+h9ipwDVUW2FU4wlNWsaewTNftNK9noVEaCKveinDFoIJgbOoBaMvFaqjVKDj6nxgXj9eUGk4MnDRcYxagKlGFL2dB6uhabHQN29jA9/sw01RYWuNkGMlb6IvcowCl2qidq9pla9am6wuYNsIGBXgEoR1b7PAdDzkQAAHEy+BldIe3wrIWetpsp590fyrU+v/AiXAYKZW5kc3RyZWFtCmVuZG9iagoyNCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDI0NiA+PgpzdHJlYW0KeJw1UUlyBDEIu/sV+kCqDMLGfk+nckr+f41gZg5dwsZoodecmODFlxniGhYnvm3EDNid+Bt1aXnxO+KosotYiXSEnRp8BtVdIK1JPFfjM3yyK4sNc5iO6+h+T9VRs5at7SIUzQWNCLFux06Uh2echSiCamXCVvksGjuSlR2X43JdwoVi4isH9X6Z5pu2NCLKvr63/zgutd3qCS4qJsVLvWZGT3IJac0rHjFwalJRPG+jojK6MjmL8A4WVl5MJ6Y6rjl/oe/uqKoV1wurZWx9s5PdhdwdbNdCo0DyaqogtX6BSK7X9WFvh9KuVX9+3TN+/gHOaljNCmVuZHN0cmVhbQplbmRvYmoKMjUgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNTcgPj4Kc3RyZWFtCnicNVAxkgMxCOv9Cj6QGYMA2+9J5qrN/9sTeFJJi1cgKSJlClxequLbZWnKR4dP4/zbmCHPQB5RF5j2rxar8T1Mo5muRYU6xMCxzi0eslU4TxPVzSlfNXbdsxkNcTgmgU5xE3Bv0tCpMZXu3DwhvlwU/D5Zy5dKcKFRFJjcgka6YYRiUJOgEVATWi9IBjxLsCtonUga7OtkFfsZvwIum4XdwzPUor1+m+lhIGymJWYyXF3Q4xXWjBHYEOdZBWF6EYBXUpCsYO4+y7pwxPmuezValKIYjGXwDzB4afxqEF0JaMtZpOVFBrfLlBpvWy5+bdEoFl9oHPT2i/Ief//jlV6CCmVuZHN0cmVhbQplbmRvYmoKMjYgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNTUgPj4Kc3RyZWFtCnicNVDJbQQxDPu7CjYQQKdl17NBXpv+v6E0G4wHImyRlJi5IfDClypSE+mCb11hCt2F35VqUHa9V9yCiiFKsBXhBp7X8uvgFzEidp76WiZnkAZ5FBFHHt7nJY421Rpvy2yZooaBr6EyHTHtGgcpGyY101ndqWT0C1FITkcEueS/OKpTxWYjjz3VdnMGZfAmYBxsKq3pYzXovZSaShclU51/JefZs1KgOEpMAr3q7k1dd4OOYF84czvd7ec+gUkHwNk+odKrs5PLeMMexHj1wNOn2w/nJrsxdTrtoL49mdiRTzbm97lhAkF3rcO9xyEZ7eUeTiXu++/4Wj9/SRdcugplbmRzdHJlYW0KZW5kb2JqCjI3IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNDE2ID4+CnN0cmVhbQp4nD1SS24FMQjbzym4QKXwT84zVXe9/7Y2M6/Sy4MJBGxDZsmSUPlSlVST1iPfemm3+N7ye2kgvlRUXcKWnJSIlPtCRtQR/OKMSQ9c09klu6XsiFpKFdO1XBp1DHXYxDrH3pd7j8d+vdDNC
m9YK/BftiW2o2g81o0ReHEQ6RgUkf3CCj+DM4gX/fxlgojxC/kZ4ql4i8ggSHQ1IKYAFuexi9XoabAXmBtaMIm1lgsQR41w1o+9L76ip7ERV3xNetm85n3Q2GoWgZNghGaIooYbWUNNzxR1B9wS/SegSZGbQ6EHCNogCiGV1ZOPTdHASQM3BssxGSVzwKLnvifDa71vfNtU8QMwlOx8ZB4PQ/CN7TiIoG9B2Gdo5XizcAMZKAEtEKDz3AAJM4itlH2INvE16KvlwwcRfzT5HU/RTZ1xHIxDZ7G0uIQLnDfFtAXJODybFYOSS8CIfGFI99BxCNw+BHStmdzGVj6iaL0irdk2egDODNt4yrMovCdlw3wUJ9kkxqI5hYSn2EVmaOtIAriYtA0RUMPafCje188fiKGkNgplbmRzdHJlYW0KZW5kb2JqCjI4IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTYwID4+CnN0cmVhbQp4nE1PORIDMQjr9xV6AmAO857NpEr+30ZeZ7IpPBIWSGA1IWi+oQNlEw89ZrF+L/AQvI7+YWgTKdykfJOUiRbkcHQiQ3EeKo5kg7I7e0BdUVJLSWOaQtuuVBty4XlYxP6Za5/Ye3GeStPFB+NsKlnAkv5eMJ8Xssd0/4gRlz9rejOxOK0Tyn2ia2Pmpfj3Hqv4Y/vopd5M9rELnh/cbTvYCmVuZHN0cmVhbQplbmRvYmoKMjkgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNDYgPj4Kc3RyZWFtCnicRVE7bsUwDNt9Cl6ggPW1fZ4UnV7vv5ZMAnRIxNgSSTFVjYl0fJmhrLFm49sGT2xv/A6LhJ3CZ1hOWOlpGDttG07iGs6RZfBo9IQTslwjLAQiD1Yj1oHNzfPkW1zpQQ6/q0fpRmgX1BGeiM3xCnGV84uPFeIsisy7UpxO7xM6ikN3J6ilG1NP071m89EMl4NaiNhayZ+FPyNJ/o/aXbekfVFtZEwin4bUltnIVXDKqcpi3Ujmk6az2GkKIplSdN/xxhuzp9YSssV+KhmVspjVnQSzM7okh36MMlV9shYyKnDGOCMirsp8UywL77+7xs8fHkpY9gplbmRzdHJlYW0KZW5kb2JqCjMwIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggOTAgPj4Kc3RyZWFtCnicTY1BEsAgCAPvvoInGChS/tPpyf7/WpFx9EJ2EiCqjSpBxtB6k6HRgyIcxjcVBuoFB7DyABGf671cwEGZxrNNeRrppho/Zk9qbGejmg7PfRXxqnx/MdkhKQplbmRzdHJlYW0KZW5kb2JqCjMxIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzUzID4+CnN0cmVhbQp4nD1Sy40lIRC7dxROoCXqC8TzVnPazf86NszsqSzA7U91VWMgHK+PjVwbFQN/7KmBNx3/HovCW4W/RBvvMlhy2hiw5pWZ4/PYmoS+4NYEMeGVF3we3z8wvO+ryPXLjEml3YjFuxkIPc7UzeYjMlJSdkYvnbfBHWFB634CyEBymm+eYA9MCRfNSs1h+6T0PpIi84OGqIna1Nw8JiV5ZiOQNCLDSWP89jSUKZudelyskGrwVChorEbR40KWOEJlm7WdUv8jpr2ADbJvZm8m7LyNkneaiUQy4ms9bjG2jpy2YjQbY96NOTdzAF3uuNAy9KqYRPtpNdFaT2jDLFtez3ZJ8mApW3sWGowfDVNxzQr8VMvuFtN7Yup1adDMOCBi6TYYw2yftZFIgaRHedX0vp3oF1DdpLHtaDV2OHG7D3Vf1Oo7+e9QVcg2F0bLxqrSji0ajckblwnDb5TP8/UNIeKGVgplbmRzdHJlYW0KZW5kb2JqCjMyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlI
C9MZW5ndGggMTczID4+CnN0cmVhbQp4nE1QuxFDMQjrPYVG4GvwPMmletm/DfByuRS2hHQGYfcNwu7LMhG88eQ19buhhWux2x8zP82OwWlDbMOVoHQGH0stbiUZLgJrh6Ic04CdUjxhwXVqrHk7WSrnhNA4N8oZJyvMtYzoh+18WSj0VBfy4tVRupu6TF+tytwhhwcfS/ZXsZ6cEK5EauX0PiYEjkpBAt53knIqrdY/9e4qNig5b4p1pvmva70+/I0+swplbmRzdHJlYW0KZW5kb2JqCjMzIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNTAgPj4Kc3RyZWFtCnicMzY2VjBQMDZSMDI0VTA2MAJiY4UUQy6oSC6IARLK4YJJQlggyRyYqhyuNADiow2TCmVuZHN0cmVhbQplbmRvYmoKMzQgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA2NyA+PgpzdHJlYW0KeJwzMjJSMFAwMwMShqYmCuaGZgophlxAvpmhqUIuiAESyuGCSUJYIMkcmCowwwCi2NTQEqoEwTKAqcjhSgMAlXoVTAplbmRzdHJlYW0KZW5kb2JqCjM1IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNDggPj4Kc3RyZWFtCnicMzIyUjBQMDMBEoamRgrmhmYKKYZcYH4uiAIJ5HDBpCAsAyANVpHDlQYAgA4MJQplbmRzdHJlYW0KZW5kb2JqCjM2IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjc0ID4+CnN0cmVhbQp4nE1SS3bEMAjb+xQcwfzxeaavq+n9txU409dFImKBkJKUKm2KwC3jkOumL17z/NPgfOi92PxfZRZdBZMlE5eQHSbZGN9JryWKORGSyBHULYOvpbbvCea6Qw86d4Ax2VDBpUWGOTOgnmbqgIG2XZXY9ahFXLVolp1SMFftIB0u/Uwkawao3nu62nAfxX+omHsqZIos0gogcsF57wmoFAUUrPcZkts4EJzYgSfscSOvi6/lLvcEKa37D/Jwe7M05FakRH50DG5uBlV7UnR8UDU/VQb8Yd92zEFVvN9ovy8Dyzb7pORxIJ73RMFYkjB2ajN8ehpfLnMSciBxtjf2Gm32VoxBiTPM9TR/xnt9/wJnsGqfCmVuZHN0cmVhbQplbmRvYmoKMzcgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxODEgPj4Kc3RyZWFtCnicTVBBEgIxCLv3FTyhEOjCe3Q86f+vBhwdD9sEUtKwEUe2nD48Lwlsueua+tUQWvJc6vHHnB9ZQmKrGHLGoHvwtuD66VzsmAuqfUDFzThjdLB5zoNup1o5yUrFL3atqPLG9lYyBJlzH1Ef1Jkh20yCqh9C48vohuIsHZE1nNnal1k6m1s7QpwbUEFvluPg4WJlg7dlPKdjOsm1WGvP6KEDK6UKr0HL3rRZZ5o/+VyPN55TQ7sKZW5kc3RyZWFtCmVuZG9iagozOCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIzMiA+PgpzdHJlYW0KeJwtUEGSxEAIuucVfmCrWkU7/Z7Zmtv+/7pgcoLERsCqtmWZ9uNu5ccql/36xT9Rx/5EssrIEW3uadhpn8tr871beIwmdg9+rsQehkXZakO5oTXB4Rc3yCdxBqM3J8PW4vtjTj1uIjk1fWxzQTIAYdFxTDqVO3yCy1z4uWI9VRwwJnPtvGVQ5FBR57a3HVsE3p5ifjjOm2Iic7nLyk/Z3hYZ1o9VyymZgyR5QE7zrvc5HLMAwQoHg9GhCVmGTsAgG6PBUjpdGKyXPAOYVyaY3HIVUwi9UKxHo6C56crgGQ8+pb7/VM5WwgplbmRzdHJlYW0KZW5kb2JqC
jM5IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTE0ID4+CnN0cmVhbQp4nDVOyw1DMQy6ZwpG8N/xPK/qKd3/WsdqLwZhQLgHCEzVV1ORXHjxupTwWbK98Qx6DAuFG0G0lTYLMawKz+JIWBZYAxY2peZ2P81cq9Psu3tkUl63ZSNE2yNpCHcoEWInlGPGPOs/6/xWnfX+Ai2WIl4KZW5kc3RyZWFtCmVuZG9iago0MCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDI3NiA+PgpzdHJlYW0KeJxNkUtywzAMQ/c+BS+QGfEr6TzpdJXcf9tHuel0YZOGKACEM0uGVMlD95LUkvQhX3p9oHd3qVtel8b/LlK7q5CYU3SB7Cmmg5khz8s8JM3Fyg6n7Zv7eXmM0/nczC4Jde4WJxETNr6mSYSCMrU3JzmmeM7j0NVOtfI+6a5VR4miFQs31jpRS7AWyAUuR4hZywNDi4GHKrbuiuH6RTD+SDhVJrA234Z6CQeabBUN8z4Bvf6iunMxEn2fThfXkgcDnY+O1TJsOxljoBBb0QVXREXj3MazA+uJMVhWg0gMgh2nWrWD7nqLnugofeXp4UpCZWVnIo7IOhXxHDeinYsfi3FsafUPAcGXm8lnlef1/QNl6mXyCmVuZHN0cmVhbQplbmRvYmoKNDEgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxNTAgPj4Kc3RyZWFtCnicRU+5DQMxDOs9BRc4wPrleRykuuzfRvIhSCWaMklRRDDhhIvnAq9AcOJFg0Uf7jMoDRdpgrS4CBARFHsUTG8xBfVgW8UWEHGEQGliJTS5aOKAudZfhqujAzprj1/qPdgC7rN27VeGtGAprV6FpE6oly46LqrH1xKSxVjl8jzzOaQRTUZrL5PT7lTa/3b3eH8BZ3IwUwplbmRzdHJlYW0KZW5kb2JqCjQyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTUwID4+CnN0cmVhbQp4nD1PQbJDIQzaewou8GYMiRrP8/501d5/W2Lbv4IBCeju6JiGi32De2Ex8WfNGB/t1X7us9lIXMPBvmCdBzk27lbM2aV0eBosJyJMjq1A5Ibp3fAJc2GyHE41d13SArnqqfL7dBdTYQZmTJAKzThlccKUEwrXZNfZwu8UMVsGZodWUPEo2YhrlSW0SPz/8G6PNzQIMRUKZW5kc3RyZWFtCmVuZG9iago0MyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDQyID4+CnN0cmVhbQp4nDMyt1AwULA0BBKGQNLQwEAhxZALzM/lggrkcBmisEA0lEoDAH7MDBIKZW5kc3RyZWFtCmVuZG9iago0NCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDE0NCA+PgpzdHJlYW0KeJxNjzGyAyEMQ/s9hY9gyTaw58mfX23u30aGFClAAs2TxxFhbqN0RU6rcPvDtd/vlsJtz4XMH5fp7YbcHIZFK3ejToH2uggVMY0xdiHrqJJ1XEDK70hvpn+S3ctbvUHjDOktmltFx3FESWNPTDaNJQZuEBPiMabFzE5KLmkoJb62cmInvflzdn2u/w/DNDGPCmVuZHN0cmVhbQplbmRvYmoKNDUgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA0MDggPj4Kc3RyZWFtCnicLZM5cgNBCEXzOQUXUFWz9HYeuRzZ90/9PnKggaGn+QtojmHD0u3lbrXcZg778oeKz2m/j/sxX8d+Ho8wz+S3zAe/eewcez9+hu1h4WE0iNnh/cTtJLm2rmWFnUk5qfmgvonxHyt1omyl5QJoH65M8zt0IthLZaTFOGBMi9CJXxBFq
BbV0R2D++/nUthle1vQYxUixWqCmGG0TpooHJVJapTRoWKKdpXIbpky7SyrI9pldUuwuDJ5kxcT3b4G8bYZdAmwx20vRU4RP/YnS76fAR9E666EC5mTu8GBCVXBQotOYlJ0KTALF/Nj41xYl8wlMTyvHBb50YZX9jfYVUgKTCimFZUd4TKiM9+qaNb0zx4mzwopoxsaFPN6n5Dt2zuQTNQLK1cPoRiNzJ+3VfnG1tztBt9mthspV8TV5aCeoZGKF57liZ4XmtmtYA2kPQrm6IrYPSTujN176ic+ccrBzqSbtztb/tI02jMoXCiA0asn8Lj2En703ovVq7dD01MmjXTSxlFln2AJDbn8+WO8n+8/heOYmQplbmRzdHJlYW0KZW5kb2JqCjQ2IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTM4ID4+CnN0cmVhbQp4nDWPSxLDIAxD95xCR/AX4/Ok0xW9/7YimWxAlphnkTkhiOKRqigpfHRwnmb4HXGsPd7wUdMXVcxErkZoIy3glYgIXMNd4DNgnbClsFJoFxNLh3rBwkDTCBLaejfYvBfYSLOhJOoSmByiCR8vEl1JfojheXaxT0rDSU663usuf72/2OP7B2dLKxYKZW5kc3RyZWFtCmVuZG9iago0NyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDM1MyA+PgpzdHJlYW0KeJw9UjtyRDEI698puEBmzB/Os5lU2fu3EThJJQYjQMLuQYe06IOZnA8lN33yY13kxvR+DElXo+/HjpBHkTZKW0kzKU7T61FXCkVGgBYk1YuvR4JvRgMVRcJOgarXwzVsJY4gT6DPHJ8XTLMOYnEy7DCoMXMYnewgk0ImRgK+2Zk5mG7QIgFO4KV7cXbLjewADTwbBdPNsKWCM7L1nEVRwctEs58jy4aOhZnggzN6igyLat9d1oBIOAj9vUZKxSL2YtmIfRRuk1USI0toHeEBXekILMfLawkbwhnLXuChMddeSNoWR969mXZSjh0wIpJ3VRxhlmxIg51/Jx2De4W+b4SzjkjeI9TGqElI54QNRSCPjpI1GgdMEkdz2FU+gDWEJ5iPkLCmQD7Txg7uCIoJMnlRZJ2cKOeeQcqXo3YvZvhbMEfGGcyqixhuv5lTW8H/HHbZLisoi/4kvp6vH1MwiTEKZW5kc3RyZWFtCmVuZG9iago0OCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDE3ID4+CnN0cmVhbQp4nDMyt1AwgMMUQy4AGuMC8QplbmRzdHJlYW0KZW5kb2JqCjQ5IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTcwID4+CnN0cmVhbQp4nEVQOQ7DMAzb/Qp+IIBFH7Lek6JT+/+1lFMki0mQskibvlBhC8cE3eC14mWFY8ED35Ka4VPYB44Gsu3J2hPOYs4k1h2HBlvFStWYK027miEaeqprYHYsIiJPG0yR6KMqQPM3GRYism4yFSBrxi54scvMpg/7r5D7MLvvGtXR9dw6hB2xy7ojpCtFDW2pnKUcE3JYBQNUguAs5CbshOsfrm86y/sHMoY9iQplbmRzdHJlYW0KZW5kb2JqCjUwIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjY3ID4+CnN0cmVhbQp4nDVRSXLDMAy7+xV4griL70mnp+T/14LMdMYyaHMDoIjEgTZfcQwljR95JryOzwYH78fOAutUYAaXeVLwesLQbFSIOvpCOPH1zIfcgqRBlUd4MpjR5gS9MDdYEWtmTY+x22OGK/zexVBlZiPOtW7EJZZz+Zkeb6Q5TArpCa0vco/F988hUVKWSuS5wy0o9pKwFcLri2f3MOCq94iKakwLpQvpZa4skigOVJH1SqeIOERqI+egJ
E134hrkXJW0YFYEJy7qkJ/IaYd3wmmU03O3WCLMnFo7xiRXiva7JvWKtXBuD4yduiap0XzW6qH1rJXblDYZoV2jQZKiD/WEzvW+/u/5/fz+ASsdYNgKZW5kc3RyZWFtCmVuZG9iago1MSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDE4OSA+PgpzdHJlYW0KeJxNUMFtADEI+2cKL1ApQCAwT6u+rvt/a5Kr1AcyMmAD7oGJWPgQwcoFl8KXjDWb/zm4A8+wcEjCZJ5WXXLwc+jLSJJhjzuCtGhBNmQWTFEBn2TTEIm9kIVggzjJVmYPlxCvA7Wbvss8Q1z/ZWryZpJtZ4yepJdlG4cdXaELdaQUPOvuuSfHj5NeJ9IUYWLck1Uzu93Gv3Dath4xS6JVF4qnhsJ4kjAa+xldiXfNvz/ebDaedz7j+xf2zUSMCmVuZHN0cmVhbQplbmRvYmoKNTIgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNzIgPj4Kc3RyZWFtCnicNVFLbgUxCNvPKXyBSvxJzjNVd73/tibpk2YECdgYJ7MgCMOXKlIWWje+9eGNp+N3kvBmiV+iFjIb77OYy4YSVcEYPPcUtDeanWZ+uKzzxPdxvTcezajwLtROVkKC6E0ZC0X6YEcxZ6UKuVlZVFeB2IY0YyWFwpYczcFZE0fxVBasiCHORNll1LcPW2KT3jeSKKp0GWGt4LrWx4QRPPF9TG6myd+5q1EV78mipmOa6Qz/n6v+8Wwy8zyuKPfRHvQ6lAIuas6F5Yyqo0BP4rGmOsbc9jFmCIKnIZx4h00W1D0dGReTazBDUlZw5YwoDrmRw93vDU0p46PxwfI8gNLwPFvS1BZ8Vnmfnz/0lmVLCmVuZHN0cmVhbQplbmRvYmoKMTQgMCBvYmoKPDwgL0Jhc2VGb250IC9BcmlhbE1UIC9DaGFyUHJvY3MgMTUgMCBSCi9FbmNvZGluZyA8PAovRGlmZmVyZW5jZXMgWyAzMiAvc3BhY2UgNDAgL3BhcmVubGVmdCAvcGFyZW5yaWdodCA0NSAvaHlwaGVuIC9wZXJpb2QgNDggL3plcm8gL29uZQovdHdvIDUyIC9mb3VyIC9maXZlIC9zaXggL3NldmVuIC9laWdodCA2NyAvQyA3MyAvSSA3NiAvTCA4MiAvUiAvUyA4NyAvVyA5NwovYSAvYiAvYyAvZCAvZSAvZiAvZyAvaCAvaSAxMDggL2wgL20gL24gL28gL3AgMTE0IC9yIC9zIC90IC91IF0KL1R5cGUgL0VuY29kaW5nID4+Ci9GaXJzdENoYXIgMCAvRm9udEJCb3ggWyAtNjY1IC0zMjUgMjAyOSAxMDM4IF0gL0ZvbnREZXNjcmlwdG9yIDEzIDAgUgovRm9udE1hdHJpeCBbIDAuMDAxIDAgMCAwLjAwMSAwIDAgXSAvTGFzdENoYXIgMjU1IC9OYW1lIC9BcmlhbE1UCi9TdWJ0eXBlIC9UeXBlMyAvVHlwZSAvRm9udCAvV2lkdGhzIDEyIDAgUiA+PgplbmRvYmoKMTMgMCBvYmoKPDwgL0FzY2VudCA5MDYgL0NhcEhlaWdodCA3MTYgL0Rlc2NlbnQgLTIxMiAvRmxhZ3MgMzIKL0ZvbnRCQm94IFsgLTY2NSAtMzI1IDIwMjkgMTAzOCBdIC9Gb250TmFtZSAvQXJpYWxNVCAvSXRhbGljQW5nbGUgMAovTWF4V2lkdGggMTAxNSAvU3RlbVYgMCAvVHlwZSAvRm9udERlc2NyaXB0b3IgL1hIZWlnaHQgNTE5ID4+CmVuZG9iagoxMiAwIG9iagpbIDc1MCA3NTAgNzUwIDc1MCA3NTAgNzUwIDc1MCA3NTAgNzUwIDc1MCA3NTAgNzUwIDc1MCA3NTAgNzUwIDc1MCA3NTAgNzUwCjc1MCA3NTAgNzUwIDc1M
CA3NTAgNzUwIDc1MCA3NTAgNzUwIDc1MCA3NTAgNzUwIDc1MCA3NTAgMjc4IDI3OCAzNTUgNTU2IDU1Ngo4ODkgNjY3IDE5MSAzMzMgMzMzIDM4OSA1ODQgMjc4IDMzMyAyNzggMjc4IDU1NiA1NTYgNTU2IDU1NiA1NTYgNTU2IDU1NiA1NTYKNTU2IDU1NiAyNzggMjc4IDU4NCA1ODQgNTg0IDU1NiAxMDE1IDY2NyA2NjcgNzIyIDcyMiA2NjcgNjExIDc3OCA3MjIgMjc4CjUwMCA2NjcgNTU2IDgzMyA3MjIgNzc4IDY2NyA3NzggNzIyIDY2NyA2MTEgNzIyIDY2NyA5NDQgNjY3IDY2NyA2MTEgMjc4IDI3OAoyNzggNDY5IDU1NiAzMzMgNTU2IDU1NiA1MDAgNTU2IDU1NiAyNzggNTU2IDU1NiAyMjIgMjIyIDUwMCAyMjIgODMzIDU1NiA1NTYKNTU2IDU1NiAzMzMgNTAwIDI3OCA1NTYgNTAwIDcyMiA1MDAgNTAwIDUwMCAzMzQgMjYwIDMzNCA1ODQgNzUwIDU1NiA3NTAgMjIyCjU1NiAzMzMgMTAwMCA1NTYgNTU2IDMzMyAxMDAwIDY2NyAzMzMgMTAwMCA3NTAgNjExIDc1MCA3NTAgMjIyIDIyMiAzMzMgMzMzCjM1MCA1NTYgMTAwMCAzMzMgMTAwMCA1MDAgMzMzIDk0NCA3NTAgNTAwIDY2NyAyNzggMzMzIDU1NiA1NTYgNTU2IDU1NiAyNjAKNTU2IDMzMyA3MzcgMzcwIDU1NiA1ODQgMzMzIDczNyA1NTIgNDAwIDU0OSAzMzMgMzMzIDMzMyA1NzYgNTM3IDI3OCAzMzMgMzMzCjM2NSA1NTYgODM0IDgzNCA4MzQgNjExIDY2NyA2NjcgNjY3IDY2NyA2NjcgNjY3IDEwMDAgNzIyIDY2NyA2NjcgNjY3IDY2NwoyNzggMjc4IDI3OCAyNzggNzIyIDcyMiA3NzggNzc4IDc3OCA3NzggNzc4IDU4NCA3NzggNzIyIDcyMiA3MjIgNzIyIDY2NyA2NjcKNjExIDU1NiA1NTYgNTU2IDU1NiA1NTYgNTU2IDg4OSA1MDAgNTU2IDU1NiA1NTYgNTU2IDI3OCAyNzggMjc4IDI3OCA1NTYgNTU2CjU1NiA1NTYgNTU2IDU1NiA1NTYgNTQ5IDYxMSA1NTYgNTU2IDU1NiA1NTYgNTAwIDU1NiA1MDAgXQplbmRvYmoKMTUgMCBvYmoKPDwgL0MgMTYgMCBSIC9JIDE3IDAgUiAvTCAxOCAwIFIgL1IgMTkgMCBSIC9TIDIwIDAgUiAvVyAyMSAwIFIgL2EgMjIgMCBSCi9iIDIzIDAgUiAvYyAyNCAwIFIgL2QgMjUgMCBSIC9lIDI2IDAgUiAvZWlnaHQgMjcgMCBSIC9mIDI4IDAgUgovZml2ZSAyOSAwIFIgL2ZvdXIgMzAgMCBSIC9nIDMxIDAgUiAvaCAzMiAwIFIgL2h5cGhlbiAzMyAwIFIgL2kgMzQgMCBSCi9sIDM1IDAgUiAvbSAzNiAwIFIgL24gMzcgMCBSIC9vIDM4IDAgUiAvb25lIDM5IDAgUiAvcCA0MCAwIFIKL3BhcmVubGVmdCA0MSAwIFIgL3BhcmVucmlnaHQgNDIgMCBSIC9wZXJpb2QgNDMgMCBSIC9yIDQ0IDAgUiAvcyA0NSAwIFIKL3NldmVuIDQ2IDAgUiAvc2l4IDQ3IDAgUiAvc3BhY2UgNDggMCBSIC90IDQ5IDAgUiAvdHdvIDUwIDAgUiAvdSA1MSAwIFIKL3plcm8gNTIgMCBSID4+CmVuZG9iagozIDAgb2JqCjw8IC9GMSAxNCAwIFIgPj4KZW5kb2JqCjQgMCBvYmoKPDwgL0ExIDw8IC9DQSAwIC9UeXBlIC9FeHRHU3RhdGUgL2NhIDEgPj4KL0EyIDw8IC9DQSAxIC9UeXBlI
C9FeHRHU3RhdGUgL2NhIDEgPj4gPj4KZW5kb2JqCjUgMCBvYmoKPDwgPj4KZW5kb2JqCjYgMCBvYmoKPDwgPj4KZW5kb2JqCjcgMCBvYmoKPDwgPj4KZW5kb2JqCjIgMCBvYmoKPDwgL0NvdW50IDEgL0tpZHMgWyAxMCAwIFIgXSAvVHlwZSAvUGFnZXMgPj4KZW5kb2JqCjUzIDAgb2JqCjw8IC9DcmVhdGlvbkRhdGUgKEQ6MjAyMjA1MzExNjU5MjcrMDInMDAnKQovQ3JlYXRvciAoTWF0cGxvdGxpYiB2My4zLjIsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcpCi9Qcm9kdWNlciAoTWF0cGxvdGxpYiBwZGYgYmFja2VuZCB2My4zLjIpID4+CmVuZG9iagp4cmVmCjAgNTQKMDAwMDAwMDAwMCA2NTUzNSBmIAowMDAwMDAwMDE2IDAwMDAwIG4gCjAwMDAwMTQ2MzIgMDAwMDAgbiAKMDAwMDAxNDQzOCAwMDAwMCBuIAowMDAwMDE0NDcwIDAwMDAwIG4gCjAwMDAwMTQ1NjkgMDAwMDAgbiAKMDAwMDAxNDU5MCAwMDAwMCBuIAowMDAwMDE0NjExIDAwMDAwIG4gCjAwMDAwMDAwNjUgMDAwMDAgbiAKMDAwMDAwMDM5NyAwMDAwMCBuIAowMDAwMDAwMjA4IDAwMDAwIG4gCjAwMDAwMDE4MzAgMDAwMDAgbiAKMDAwMDAxMjk0MSAwMDAwMCBuIAowMDAwMDEyNzQxIDAwMDAwIG4gCjAwMDAwMTIyNDkgMDAwMDAgbiAKMDAwMDAxMzk5MiAwMDAwMCBuIAowMDAwMDAxODUxIDAwMDAwIG4gCjAwMDAwMDIyMTYgMDAwMDAgbiAKMDAwMDAwMjMzNiAwMDAwMCBuIAowMDAwMDAyNDY1IDAwMDAwIG4gCjAwMDAwMDI4MTYgMDAwMDAgbiAKMDAwMDAwMzMyOCAwMDAwMCBuIAowMDAwMDAzNTY2IDAwMDAwIG4gCjAwMDAwMDQwNzYgMDAwMDAgbiAKMDAwMDAwNDQwMCAwMDAwMCBuIAowMDAwMDA0NzE5IDAwMDAwIG4gCjAwMDAwMDUwNDkgMDAwMDAgbiAKMDAwMDAwNTM3NyAwMDAwMCBuIAowMDAwMDA1ODY2IDAwMDAwIG4gCjAwMDAwMDYwOTkgMDAwMDAgbiAKMDAwMDAwNjQxOCAwMDAwMCBuIAowMDAwMDA2NTgwIDAwMDAwIG4gCjAwMDAwMDcwMDYgMDAwMDAgbiAKMDAwMDAwNzI1MiAwMDAwMCBuIAowMDAwMDA3Mzc0IDAwMDAwIG4gCjAwMDAwMDc1MTMgMDAwMDAgbiAKMDAwMDAwNzYzMyAwMDAwMCBuIAowMDAwMDA3OTgwIDAwMDAwIG4gCjAwMDAwMDgyMzQgMDAwMDAgbiAKMDAwMDAwODUzOSAwMDAwMCBuIAowMDAwMDA4NzI2IDAwMDAwIG4gCjAwMDAwMDkwNzUgMDAwMDAgbiAKMDAwMDAwOTI5OCAwMDAwMCBuIAowMDAwMDA5NTIxIDAwMDAwIG4gCjAwMDAwMDk2MzUgMDAwMDAgbiAKMDAwMDAwOTg1MiAwMDAwMCBuIAowMDAwMDEwMzMzIDAwMDAwIG4gCjAwMDAwMTA1NDQgMDAwMDAgbiAKMDAwMDAxMDk3MCAwMDAwMCBuIAowMDAwMDExMDU5IDAwMDAwIG4gCjAwMDAwMTEzMDIgMDAwMDAgbiAKMDAwMDAxMTY0MiAwMDAwMCBuIAowMDAwMDExOTA0IDAwMDAwIG4gCjAwMDAwMTQ2OTIgMDAwMDAgbiAKdHJhaWxlcgo8PCAvSW5mbyA1MyAwIFIgL1Jvb3QgMSAwIFIgL1NpemUgNTQgPj4Kc3RhcnR4cmVmCjE0ODQ5CiUlRU9GCg==\n", "image/svg+xml": [ "\n", "\n", 
"\n", "\n", " \n", " \n", " \n", " \n", " 2022-05-31T16:59:26.986974\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.3.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr_scheduler = cosine_warmup_schedule(base_lr=1.0, warmup=100, max_iters=2000)\n", "\n", "# Plotting\n", "epochs = list(range(2000))\n", "sns.set()\n", "plt.figure(figsize=(8,3))\n", "plt.plot(epochs, [lr_scheduler(e) for e in epochs])\n", "plt.ylabel(\"Learning rate factor\")\n", "plt.xlabel(\"Iterations (in batches)\")\n", "plt.title(\"Cosine Warm-up Learning Rate Scheduler\")\n", "plt.show()\n", "sns.reset_orig()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the first 100 iterations, we increase the learning rate factor from 0 to 1, whereas for all later iterations, we decay it using the cosine wave." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Full Transformer model\n", "\n", "Finally, we can embed the Transformer architecture into a full architecture. We will implement a template for a classifier based on the Transformer encoder. Thereby, we have a prediction output per sequence element. If we would need a classifier over the whole sequence, the common approach is to add an additional `[CLS]` token to the sequence, representing the classifier token. However, here we focus on tasks where we have one output per element. \n", "\n", "Additionally to the Transformer architecture, we add a small input network (maps input dimensions to model dimensions), the positional encoding, and an output network (transforms output encodings to predictions). We also add the learning rate scheduler, which takes a step each iteration instead of once per epoch. This is needed for the warmup and the smooth cosine decay. The training, validation, and test step is left empty for now and will be filled for our task-specific models." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "class TransformerPredictor(nn.Module):\n", " model_dim : int # Hidden dimensionality to use inside the Transformer\n", " num_classes : int # Number of classes to predict per sequence element\n", " num_heads : int # Number of heads to use in the Multi-Head Attention blocks\n", " num_layers : int # Number of encoder blocks to use\n", " dropout_prob : float = 0.0 # Dropout to apply inside the model\n", " input_dropout_prob : float = 0.0 # Dropout to apply on the input features\n", "\n", " def setup(self):\n", " # Input dim -> Model dim\n", " self.input_dropout = nn.Dropout(self.input_dropout_prob)\n", " self.input_layer = nn.Dense(self.model_dim)\n", " # Positional encoding for sequences\n", " self.positional_encoding = PositionalEncoding(self.model_dim)\n", " # Transformer\n", " self.transformer = TransformerEncoder(num_layers=self.num_layers,\n", " input_dim=self.model_dim,\n", " dim_feedforward=2*self.model_dim,\n", " num_heads=self.num_heads,\n", " dropout_prob=self.dropout_prob)\n", " # Output classifier per sequence element\n", " self.output_net = [\n", " nn.Dense(self.model_dim),\n", " nn.LayerNorm(),\n", " nn.relu,\n", " nn.Dropout(self.dropout_prob),\n", " nn.Dense(self.num_classes)\n", " ]\n", "\n", " def __call__(self, x, mask=None, add_positional_encoding=True, train=True):\n", " \"\"\"\n", " Inputs:\n", " x - Input features of shape [Batch, SeqLen, input_dim]\n", " mask - Mask to apply on the attention outputs (optional)\n", " add_positional_encoding - If True, we add the positional encoding to the input.\n", " Might not be desired for some tasks.\n", " train - If True, dropout is stochastic\n", " \"\"\"\n", " x = self.input_dropout(x, deterministic=not train)\n", " x = self.input_layer(x)\n", " if add_positional_encoding:\n", " x = self.positional_encoding(x)\n", " x = self.transformer(x, mask=mask, train=train)\n", " for l in self.output_net:\n", " x = l(x) if 
not isinstance(l, nn.Dropout) else l(x, deterministic=not train)\n", " return x\n", "\n", " def get_attention_maps(self, x, mask=None, add_positional_encoding=True, train=True):\n", " \"\"\"\n", " Function for extracting the attention matrices of the whole Transformer for a single batch.\n", " Input arguments are the same as for the forward pass.\n", " \"\"\"\n", " x = self.input_dropout(x, deterministic=not train)\n", " x = self.input_layer(x)\n", " if add_positional_encoding:\n", " x = self.positional_encoding(x)\n", " attention_maps = self.transformer.get_attention_maps(x, mask=mask, train=train)\n", " return attention_maps" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Out (3, 16, 10)\n", "Attention maps 5 (3, 4, 16, 16)\n" ] } ], "source": [ "## Test TransformerPredictor implementation\n", "# Example features as input\n", "main_rng, x_rng = random.split(main_rng)\n", "x = random.normal(x_rng, (3, 16, 64))\n", "# Create Transformer predictor\n", "transpre = TransformerPredictor(num_layers=5, \n", " model_dim=128,\n", " num_classes=10,\n", " num_heads=4,\n", " dropout_prob=0.15,\n", " input_dropout_prob=0.05)\n", "# Initialize parameters of transformer predictor with random key and inputs\n", "main_rng, init_rng, dropout_init_rng = random.split(main_rng, 3)\n", "params = transpre.init({'params': init_rng, 'dropout': dropout_init_rng}, x, train=True)['params']\n", "# Apply transformer predictor with parameters on the inputs\n", "# Since dropout is stochastic, we need to pass an rng to the forward pass\n", "main_rng, dropout_apply_rng = random.split(main_rng)\n", "# Instead of passing params and rngs every time to a function call, we can bind them to the module\n", "binded_mod = transpre.bind({'params': params}, rngs={'dropout': dropout_apply_rng})\n", "out = binded_mod(x, train=True)\n", "print('Out', out.shape)\n", "attn_maps = 
binded_mod.get_attention_maps(x, train=True)\n", "print('Attention 
maps', len(attn_maps), attn_maps[0].shape)\n", "\n", "del transpre, binded_mod, params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Trainer module\n", "\n", "Finally, we add the missing parts needed for training a model in JAX and Flax. Note that we leave the input preprocessing and the loss function unimplemented, since they depend on the specific task, as we will see for the tasks below.\n", "\n", "For optimization, we use the Adam optimizer with a cosine warm-up schedule like the one discussed above (here, we use optax's built-in `optax.warmup_cosine_decay_schedule`). Additionally, we use the optax transformation `optax.clip_by_global_norm` ([documentation](https://optax.readthedocs.io/en/latest/api.html#optax.clip_by_global_norm)). This clips the norm of the gradients for all parameters before taking an optimizer step, and prevents the model from diverging if we obtain very high gradients at, for instance, sharp loss surfaces (there are many good blog posts on gradient clipping, like the [DeepAI glossary](https://deepai.org/machine-learning-glossary-and-terms/gradient-clipping)). For Transformers, gradient clipping can help to further stabilize the training during the first few iterations, and also afterward. The clip value is usually between 0.5 and 10, depending on how aggressively you want to clip large gradients." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "class TrainerModule:\n", " \n", " def __init__(self, model_name, exmp_batch, max_iters, lr=1e-3, warmup=100, seed=42, **model_kwargs):\n", " \"\"\"\n", " Inputs:\n", " model_name - Name of the model. Used for saving and checkpointing\n", " exmp_batch - Example batch to the model for initialization\n", " max_iters - Maximum number of iterations the model is trained for. This is needed for the cosine warm-up schedule\n", " lr - Learning rate in the optimizer\n", " warmup - Number of warmup steps. 
Usually between 50 and 500\n", " seed - Seed to use for model init\n", " \"\"\"\n", " super().__init__()\n", " self.model_name = model_name\n", " self.max_iters = max_iters\n", " self.lr = lr\n", " self.warmup = warmup\n", " self.seed = seed\n", " # Create empty model. Note: no parameters yet\n", " self.model = TransformerPredictor(**model_kwargs)\n", " # Prepare logging\n", " self.log_dir = os.path.join(CHECKPOINT_PATH, self.model_name)\n", " self.logger = SummaryWriter(log_dir=self.log_dir)\n", " # Create jitted training and eval functions\n", " self.create_functions()\n", " # Initialize model\n", " self.init_model(exmp_batch)\n", " \n", " def batch_to_input(self, exmp_batch):\n", " # Map batch to input data to the model\n", " # To be implemented in a task specific sub-class\n", " raise NotImplementedError\n", " \n", " def get_loss_function(self):\n", " # Return a function that calculates the loss for a batch\n", " # To be implemented in a task specific sub-class\n", " raise NotImplementedError\n", " \n", " def create_functions(self):\n", " # Create jitted train and eval functions\n", " calculate_loss = self.get_loss_function()\n", " \n", " # Training function\n", " def train_step(state, rng, batch):\n", " loss_fn = lambda params: calculate_loss(params, rng, batch, train=True)\n", " ret, grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)\n", " loss, acc, rng = ret[0], *ret[1]\n", " state = state.apply_gradients(grads=grads)\n", " return state, rng, loss, acc\n", " self.train_step = jax.jit(train_step)\n", " \n", " # Evaluation function\n", " def eval_step(state, rng, batch):\n", " _, (acc, rng) = calculate_loss(state.params, rng, batch, train=False)\n", " return acc, rng\n", " self.eval_step = jax.jit(eval_step)\n", " \n", " def init_model(self, exmp_batch):\n", " # Initialize model\n", " self.rng = jax.random.PRNGKey(self.seed)\n", " self.rng, init_rng, dropout_init_rng = jax.random.split(self.rng, 3)\n", " exmp_input = 
self.batch_to_input(exmp_batch)\n", " params = self.model.init({'params': init_rng, 'dropout': dropout_init_rng}, exmp_input, train=True)['params']\n", " # Initialize learning rate schedule and optimizer\n", " lr_schedule = optax.warmup_cosine_decay_schedule(\n", " init_value=0.0,\n", " peak_value=self.lr,\n", " warmup_steps=self.warmup,\n", " decay_steps=self.max_iters,\n", " end_value=0.0\n", " )\n", " optimizer = optax.chain(\n", " optax.clip_by_global_norm(1.0), # Clip gradients at norm 1\n", " optax.adam(lr_schedule)\n", " )\n", " # Initialize training state\n", " self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=optimizer)\n", " \n", " def train_model(self, train_loader, val_loader, num_epochs=500):\n", " # Train model for defined number of epochs\n", " best_acc = 0.0\n", " for epoch_idx in tqdm(range(1, num_epochs+1)):\n", " self.train_epoch(train_loader, epoch=epoch_idx)\n", " if epoch_idx % 5 == 0:\n", " eval_acc = self.eval_model(val_loader)\n", " self.logger.add_scalar('val/accuracy', eval_acc, global_step=epoch_idx)\n", " if eval_acc >= best_acc:\n", " best_acc = eval_acc\n", " self.save_model(step=epoch_idx)\n", " self.logger.flush()\n", " \n", " def train_epoch(self, train_loader, epoch):\n", " # Train model for one epoch, and log avg loss and accuracy\n", " accs, losses = [], []\n", " for batch in tqdm(train_loader, desc='Training', leave=False):\n", " self.state, self.rng, loss, accuracy = self.train_step(self.state, self.rng, batch)\n", " losses.append(loss)\n", " accs.append(accuracy)\n", " avg_loss = np.stack(jax.device_get(losses)).mean()\n", " avg_acc = np.stack(jax.device_get(accs)).mean()\n", " self.logger.add_scalar('train/loss', avg_loss, global_step=epoch)\n", " self.logger.add_scalar('train/accuracy', avg_acc, global_step=epoch)\n", " \n", " def eval_model(self, data_loader):\n", " # Test model on all data points of a data loader and return avg accuracy\n", " correct_class, count = 0, 0\n", " for 
batch in data_loader:\n", " acc, self.rng = self.eval_step(self.state, self.rng, batch)\n", " correct_class += acc * batch[0].shape[0]\n", " count += batch[0].shape[0]\n", " eval_acc = (correct_class / count).item()\n", " return eval_acc\n", " \n", " def save_model(self, step=0):\n", " # Save current model at certain training iteration\n", " checkpoints.save_checkpoint(ckpt_dir=self.log_dir, target=self.state.params, step=step)\n", " \n", " def load_model(self, pretrained=False):\n", " # Load model. We use a different checkpoint for the pretrained model\n", " if not pretrained:\n", " params = checkpoints.restore_checkpoint(ckpt_dir=self.log_dir, target=self.state.params)\n", " else:\n", " params = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(CHECKPOINT_PATH, f'{self.model_name}.ckpt'), target=self.state.params)\n", " self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=self.state.tx)\n", " \n", " def checkpoint_exists(self):\n", " # Check whether a pretrained model exists for this Transformer\n", " return os.path.isfile(os.path.join(CHECKPOINT_PATH, f'{self.model_name}.ckpt'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Experiments\n", "\n", "Having finished the implementation of the Transformer architecture, we can start experimenting and apply it to various tasks. In this notebook, we will focus on two tasks: parallel Sequence-to-Sequence, and set anomaly detection. The two tasks focus on different properties of the Transformer architecture, and we go through them below.\n", "\n", "### Sequence to Sequence\n", "\n", "A Sequence-to-Sequence task is a task where the input _and_ the output are sequences, not necessarily of the same length. Popular tasks in this domain include machine translation and summarization. For this, we usually have a Transformer encoder for interpreting the input sequence, and a decoder for generating the output in an autoregressive manner. 
Here, however, we will go back to a much simpler example task and use only the encoder. Given a sequence of $N$ numbers between $0$ and $M$, the task is to reverse the input sequence. In NumPy notation, if our input is $x$, the output should be `x[::-1]`. Although this task sounds very simple, RNNs can struggle with it because it requires long-term dependencies. Transformers are designed to capture such dependencies, so we expect them to perform very well. \n", "\n", "First, let's create a dataset class below." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "class ReverseDataset(data.Dataset):\n", " \n", " def __init__(self, num_categories, seq_len, size, np_rng):\n", " super().__init__()\n", " self.num_categories = num_categories\n", " self.seq_len = seq_len\n", " self.size = size\n", " self.np_rng = np_rng\n", "\n", " self.data = self.np_rng.integers(self.num_categories, size=(self.size, self.seq_len))\n", "\n", " def __len__(self):\n", " return self.size\n", "\n", " def __getitem__(self, idx):\n", " inp_data = self.data[idx]\n", " labels = np.flip(inp_data, axis=0)\n", " return inp_data, labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create an arbitrary number of random sequences of numbers between 0 and `num_categories-1`. The label is simply the input array flipped along the sequence dimension. We can create the corresponding data loaders using PyTorch's `DataLoader` below. 
" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Combine batch elements (all numpy) by stacking\n", "def numpy_collate(batch):\n", " if isinstance(batch[0], np.ndarray):\n", " return np.stack(batch)\n", " elif isinstance(batch[0], (tuple,list)):\n", " transposed = zip(*batch)\n", " return [numpy_collate(samples) for samples in transposed]\n", " else:\n", " return np.array(batch)\n", "\n", "dataset = partial(ReverseDataset, 10, 16)\n", "rev_train_loader = data.DataLoader(dataset(50000, np_rng=np.random.default_rng(42)), \n", " batch_size=128, \n", " shuffle=True, \n", " drop_last=True,\n", " collate_fn=numpy_collate)\n", "rev_val_loader = data.DataLoader(dataset(1000, np_rng=np.random.default_rng(43)), \n", " batch_size=128,\n", " collate_fn=numpy_collate)\n", "rev_test_loader = data.DataLoader(dataset(10000, np_rng=np.random.default_rng(44)), \n", " batch_size=128,\n", " collate_fn=numpy_collate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Remember that these data loaders return NumPy arrays instead of PyTorch tensors, as defined by the `numpy_collate` function, which combines individual elements into a batch. As the dataset is so simple and `__getitem__` takes a negligible amount of time, we don't need subprocesses, i.e. workers, to provide the data (in fact, more workers can slow down the training due to the communication overhead among processes/threads). 
\n", "\n", "Now, let's look at an arbitrary sample of the dataset:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input data: [0 7 6 4 4 8 0 6 2 0 5 9 7 7 7 7]\n", "Labels: [7 7 7 7 9 5 0 2 6 0 8 4 4 6 7 0]\n" ] } ], "source": [ "inp_data, labels = rev_train_loader.dataset[0]\n", "print(\"Input data:\", inp_data)\n", "print(\"Labels: \", labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "During training, we pass the input sequence through the Transformer encoder and predict the output for each input token. We use the standard Cross-Entropy loss for this. Every number is represented as a one-hot vector. Remember that representing the categories as single scalars would severely reduce the expressiveness of the model, since $0$ and $1$ are not more closely related than $0$ and $9$ in our example. An alternative to a one-hot vector is a learned embedding vector, as provided by the `nn.Embed` module ([documentation](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Embed.html)). 
However, using a one-hot vector with an additional linear layer, as in our case, has the same effect as an embedding layer (`self.input_layer` maps the one-hot vector to a dense vector, where each row of its weight matrix represents the embedding for a specific category).\n", "\n", "To implement the training dynamics, we create a new class that inherits from `TrainerModule` and defines the input mapping and the loss function as follows:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "class ReverseTrainer(TrainerModule):\n", " \n", " def batch_to_input(self, batch):\n", " inp_data, _ = batch\n", " inp_data = jax.nn.one_hot(inp_data, num_classes=self.model.num_classes)\n", " return inp_data\n", " \n", " def get_loss_function(self):\n", " # Function for calculating loss and accuracy for a batch\n", " def calculate_loss(params, rng, batch, train):\n", " inp_data, labels = batch\n", " inp_data = jax.nn.one_hot(inp_data, num_classes=self.model.num_classes)\n", " rng, dropout_apply_rng = random.split(rng)\n", " logits = self.model.apply({'params': params}, inp_data, train=train, rngs={'dropout': dropout_apply_rng})\n", " loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()\n", " acc = (logits.argmax(axis=-1) == labels).mean()\n", " return loss, (acc, rng)\n", " return calculate_loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can create a training function, similar to the ones we have seen before. We create a `ReverseTrainer` object, run the training for $N$ epochs while logging to TensorBoard, and save our best model based on the validation accuracy. Afterward, we test our model on the test set." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "def train_reverse(max_epochs=10, **model_args):\n", " num_train_iters = len(rev_train_loader) * max_epochs\n", " # Create a trainer module with specified hyperparameters\n", " trainer = ReverseTrainer(model_name='ReverseTask', \n", " exmp_batch=next(iter(rev_train_loader)),\n", " max_iters=num_train_iters, \n", " **model_args)\n", " if not trainer.checkpoint_exists(): # Skip training if pretrained model exists\n", " trainer.train_model(rev_train_loader, rev_val_loader, num_epochs=max_epochs)\n", " trainer.load_model()\n", " else:\n", " trainer.load_model(pretrained=True)\n", " val_acc = trainer.eval_model(rev_val_loader)\n", " test_acc = trainer.eval_model(rev_test_loader)\n", " # Bind parameters to model for easier inference\n", " trainer.model_bd = trainer.model.bind({'params': trainer.state.params})\n", " return trainer, {'val_acc': val_acc, 'test_acc': test_acc}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can train the model. In this setup, we use a single encoder block and a single head in the Multi-Head Attention. We choose this setup because the task is simple, and in this case, the attention can actually be interpreted as an \"explanation\" of the predictions (as opposed to the deep Transformers in the other papers mentioned above). 
" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "scrolled": false }, "outputs": [], "source": [ "reverse_trainer, reverse_result = train_reverse(model_dim=32,\n", " num_heads=1,\n", " num_classes=rev_train_loader.dataset.num_categories,\n", " num_layers=1,\n", " dropout_prob=0.0,\n", " lr=5e-4,\n", " warmup=50)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's print the results:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Val accuracy: 100.00%\n", "Test accuracy: 100.00%\n" ] } ], "source": [ "print(f\"Val accuracy: {(100.0 * reverse_result['val_acc']):4.2f}%\")\n", "print(f\"Test accuracy: {(100.0 * reverse_result['test_acc']):4.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, the Transformer solves the task correctly. However, what does the attention in the Multi-Head Attention block look like for an arbitrary input? Let's try to visualize it below." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "data_input, labels = next(iter(rev_val_loader))\n", "inp_data = jax.nn.one_hot(data_input, num_classes=reverse_trainer.model.num_classes)\n", "# Use train=False so that dropout is deterministic during inference\n", "attention_maps = reverse_trainer.model_bd.get_attention_maps(inp_data, train=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The object `attention_maps` is a list of length $N$, where $N$ is the number of layers. Each element is an array of shape [Batch, Heads, SeqLen, SeqLen], which we can verify below." 
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(128, 1, 16, 16)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "attention_maps[0].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we will write a plotting function that takes as input the sequences, the attention maps, and an index indicating the batch element for which we want to visualize the attention maps. In the resulting plot, rows correspond to layers and columns to heads. Remember that the softmax has been applied to each row separately." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "def plot_attention_maps(input_data, attn_maps, idx=0):\n", " if input_data is not None:\n", " input_data = jax.device_get(input_data[idx])\n", " else:\n", " input_data = np.arange(attn_maps[0][idx].shape[-1])\n", " attn_maps = [jax.device_get(m[idx]) for m in attn_maps]\n", " \n", " num_heads = attn_maps[0].shape[0]\n", " num_layers = len(attn_maps)\n", " seq_len = input_data.shape[0]\n", " fig_size = 4 if num_heads == 1 else 3\n", " fig, ax = plt.subplots(num_layers, num_heads, figsize=(num_heads*fig_size, num_layers*fig_size))\n", " if num_layers == 1:\n", " ax = [ax]\n", " if num_heads == 1:\n", " ax = [[a] for a in ax]\n", " for row in range(num_layers):\n", " for column in range(num_heads):\n", " ax[row][column].imshow(attn_maps[row][column], origin='lower', vmin=0)\n", " ax[row][column].set_xticks(list(range(seq_len)))\n", " ax[row][column].set_xticklabels(input_data.tolist())\n", " ax[row][column].set_yticks(list(range(seq_len)))\n", " ax[row][column].set_yticklabels(input_data.tolist())\n", " ax[row][column].set_title(f\"Layer {row+1}, Head {column+1}\")\n", " fig.subplots_adjust(hspace=0.5)\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can plot the attention map 
of our trained Transformer on the reverse task:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<Figure: attention map of the trained Transformer on the reverse task, idx=0>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_attention_maps(data_input, attention_maps, idx=0)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The model has learned to attend to the token at the mirrored (flipped) position of its own index. Hence, it does what we intended. However, it also pays some attention to values close to the flipped index. This is because the model does not need perfect, hard attention to solve this problem; an approximate, noisy attention map is sufficient. The attention to nearby indices is caused by the similarity of their positional encodings, which is a property we intended when designing the positional encoding." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set Anomaly Detection\n", "\n", "Besides sequences, sets are another data structure that is relevant for many applications. In contrast to sequences, the elements of a set are unordered. RNNs can only be applied to sets by imposing an order on the elements, which biases the model towards an order that does not exist in the data. [Vinyals et al. (2015)](https://arxiv.org/abs/1511.06391) and other papers have shown that the assumed order can have a significant impact on the model's performance, and hence, we should avoid using RNNs on sets. Ideally, our model should be permutation-equivariant (or, for a single set-level output, permutation-invariant), such that its predictions do not depend on the order in which we list the elements of a set. \n", "\n", "Transformers are well suited for this, because Multi-Head Attention without positional encodings is permutation-equivariant: permuting the inputs permutes the outputs in exactly the same way, so each element receives the same output no matter in which order we enter the inputs. The task we are looking at for sets is _Set Anomaly Detection_, which means that we try to find the element(s) in a set that do not fit the others. 
In the research community, anomaly detection is commonly studied on sets of images, where $N-1$ images belong to the same category (i.e., share the same high-level features) while one image belongs to another category. Note that a category does not necessarily correspond to a class in a standard classification problem, but could be a combination of multiple features. For instance, on a face dataset, a category could be people with glasses, male, beard, etc. An example of distinguishing different animals can be seen below. The first four images show foxes, while the last one shows a different animal. We want to recognize that the last image shows a different animal, but it is not relevant which class of animal it is.\n", "\n", "
\n", "\n", "In this tutorial, we will use the CIFAR100 dataset. Similar to CIFAR10, its images have a resolution of 32x32, but CIFAR100 has 100 classes with 600 images each. The larger number of classes requires the model to attend to specific features in the images instead of the coarse features that suffice for CIFAR10, which makes the task harder. We will show the model a set of 9 images of one class and 1 image from another class. The task is to find the image that belongs to a different class than the other images.\n", "\n", "Using the raw images directly as input to the Transformer is not a good idea, because, unlike a CNN, the Transformer is not translation invariant, and it would first need to learn to detect image features from high-dimensional input. Instead, we will use a pre-trained ResNet34 model from the package `flaxmodels` ([link](https://github.com/matthias-wright/flaxmodels)) to obtain high-level, low-dimensional features of the images. The ResNet model has been pre-trained on the [ImageNet](http://image-net.org/) dataset, which contains more than 1 million images of 1k classes with varying resolutions. Since these images are usually scaled to a resolution of 224x224 during training and testing, we rescale our CIFAR images to this resolution as well. Below, we load the dataset and prepare the data for being processed by the ResNet model." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "def image_to_numpy(img):\n", " img = np.array(img, dtype=np.float32)\n", " img = img / 255. # Scale to [0, 1]; normalization is done in the ResNet\n", " return img\n", "\n", "# Resize to 224x224 and convert the images to NumPy arrays\n", "transform = transforms.Compose([transforms.Resize((224,224)),\n", " image_to_numpy\n", " ])\n", "# Loading the training dataset. 
\n", "train_set = CIFAR100(root=DATASET_PATH, train=True, transform=transform, download=True)\n", "\n", "# Loading the test set\n", "test_set = CIFAR100(root=DATASET_PATH, train=False, transform=transform, download=True)\n", "\n", "# For later, keep a dictionary mapping class indices to class names\n", "class_idx_to_name = {val: key for key, val in train_set.class_to_idx.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we want to run the pre-trained ResNet model on the images and extract the features before the classification layer. These are the highest-level features and should describe the images sufficiently well. CIFAR100 is similar to ImageNet to some extent, and thus we do not retrain the ResNet model in any form. However, if you wanted to get the best performance and had a very large dataset, it would be better to add the ResNet to the computation graph during training and fine-tune its parameters as well. As we don't have a large enough dataset and want to train our model efficiently, we will extract the features beforehand. Let's load and prepare the model below." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "# Import and install flaxmodels if needed\n", "try:\n", " import flaxmodels\n", "except ModuleNotFoundError:\n", " !pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git\n", " import flaxmodels\n", "\n", "# ResNet34 pre-trained on ImageNet\n", "resnet34 = flaxmodels.ResNet34(output='activations', pretrained='imagenet', normalize=True)\n", "main_rng, resnet_rng = random.split(main_rng, 2)\n", "resnet_params = resnet34.init(resnet_rng, jnp.zeros((1, 224, 224, 3)))\n", "# Jit its forward pass for efficiency\n", "apply_resnet = jax.jit(lambda imgs: resnet34.apply(resnet_params, imgs, train=False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will now write an extraction function for the features below. 
In practice, this cell requires a GPU, as the model is rather deep and the images relatively large. The GPUs on Google Colab are sufficient, but running this cell can still take 2-3 minutes. Once it has run, the features are saved to disk so that they don't have to be recalculated every time you run the notebook. However, this requires more than 150MB of free disk space, so only run it if you have enough free disk space and a GPU (Google Colab is fine for this). If you do not have a GPU, you can download the features from the [GoogleDrive folder](https://drive.google.com/drive/folders/1t8D-GTfmxJ42xLNTeMGMKSYuWBua5lXI?usp=sharing)." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "def extract_features(dataset, save_file):\n", " if not os.path.isfile(save_file):\n", " data_loader = data.DataLoader(dataset, batch_size=128, shuffle=False, drop_last=False, \n", " collate_fn=lambda batch: np.stack([b[0] for b in batch], axis=0))\n", " extracted_features = []\n", " for imgs in tqdm(data_loader):\n", " feats = apply_resnet(imgs)\n", " # Average pooling over the spatial dimensions of the last conv features to obtain an image-level feature vector\n", " feats = feats['block4_2'].mean(axis=(1,2))\n", " extracted_features.append(feats)\n", " extracted_features = jnp.concatenate(extracted_features, axis=0)\n", " extracted_features = jax.device_get(extracted_features)\n", " np.savez_compressed(save_file, feats=extracted_features)\n", " else:\n", " extracted_features = np.load(save_file)['feats']\n", " return extracted_features\n", "\n", "train_feat_file = os.path.join(CHECKPOINT_PATH, \"train_set_features.npz\")\n", "train_set_feats = extract_features(train_set, train_feat_file)\n", "\n", "test_feat_file = os.path.join(CHECKPOINT_PATH, \"test_set_features.npz\")\n", "test_feats = extract_features(test_set, test_feat_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's verify the feature shapes below. 
The training set should have 50k elements and the test set 10k. The feature dimension is 512 for the ResNet34. If you experiment with other models, you will likely see a different feature dimension." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train: (50000, 512)\n", "Test: (10000, 512)\n" ] } ], "source": [ "print(\"Train:\", train_set_feats.shape)\n", "print(\"Test: \", test_feats.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As usual, we want to create a validation set to detect when we should stop training. In this case, we will split the training set into 90% training and 10% validation. However, the difficulty here is that we need to ensure that the validation set has the same number of images for each of the 100 classes. Otherwise, we would have a class imbalance, which is undesirable when creating the image sets. Hence, we take 10% of the images of each class and move them into the validation set. The code below does exactly this."
] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "## Split train into train+val\n", "# Get labels from train set\n", "labels = np.array(train_set.targets, dtype=np.int32)\n", "\n", "# Get indices of images per class\n", "num_labels = labels.max()+1\n", "sorted_indices = np.argsort(labels).reshape(num_labels, -1) # [classes, num_imgs per class]\n", "\n", "# Determine number of validation images per class\n", "num_val_exmps = sorted_indices.shape[1] // 10\n", "\n", "# Get image indices for validation and training\n", "val_indices = sorted_indices[:,:num_val_exmps].reshape(-1)\n", "train_indices = sorted_indices[:,num_val_exmps:].reshape(-1)\n", "\n", "# Group corresponding image features and labels\n", "train_feats, train_labels = train_set_feats[train_indices], labels[train_indices]\n", "val_feats, val_labels = train_set_feats[val_indices], labels[val_indices]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can prepare a dataset class for the set anomaly task. We define an epoch as one pass in which each image has been used exactly once as the \"anomaly\". Hence, the length of the dataset is the number of images in it. For the training set, each time we access an item with `__getitem__`, we sample a random class that differs from the class of the image at index `idx`. In a second step, we sample $N-1$ images of this sampled class. Together with the anomaly at index `idx`, this set of $N=10$ images is returned. The randomness in `__getitem__` allows us to see a slightly different set in each iteration. However, we can't use the same strategy for the validation and test sets, as we want them to be the same every time we iterate over them. Hence, for these, we sample the sets once in the `__init__` method and return them in `__getitem__`. The code below implements exactly this behavior."
] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "class SetAnomalyDataset(data.Dataset):\n", " \n", " def __init__(self, img_feats, labels, np_rng, set_size=10, train=True):\n", " \"\"\"\n", " Inputs:\n", " img_feats - Array of shape [num_imgs, img_dim]. Represents the high-level features.\n", " labels - Array of shape [num_imgs], containing the class labels for the images\n", " np_rng - NumPy random Generator (e.g. np.random.default_rng(seed)) used for sampling the sets\n", " set_size - Number of elements in a set. N-1 are sampled from one class, and one from another one.\n", " train - If True, a new set will be sampled every time __getitem__ is called.\n", " \"\"\"\n", " super().__init__()\n", " self.img_feats = img_feats\n", " self.labels = labels\n", " self.np_rng = np_rng\n", " self.set_size = set_size-1 # The set size is here the number of images from the same class per set\n", " self.train = train\n", " \n", " # Array of image indices per class, shape [num_labels, num_imgs per class]\n", " self.num_labels = labels.max()+1\n", " self.img_idx_by_label = np.argsort(self.labels).reshape(self.num_labels, -1)\n", " \n", " if not train:\n", " self.test_sets = self._create_test_sets()\n", " \n", " def _create_test_sets(self):\n", " # Pre-generates the sets for each image for the validation/test set\n", " num_imgs = self.img_feats.shape[0]\n", " test_sets = [self.sample_img_set(self.labels[idx]) for idx in range(num_imgs)]\n", " test_sets = np.stack(test_sets, axis=0)\n", " return test_sets\n", " \n", " def sample_img_set(self, anomaly_label):\n", " \"\"\"\n", " Samples a new set of images, given the label of the anomaly. 
\n", " The sampled images come from a different class than anomaly_label\n", " \"\"\"\n", " # Sample a class from 0,...,num_labels-1 while skipping anomaly_label:\n", " # draw from num_labels-1 classes and shift labels >= anomaly_label up by one\n", " set_label = self.np_rng.integers(self.num_labels-1)\n", " if set_label >= anomaly_label:\n", " set_label += 1\n", " \n", " # Sample images from the class determined above\n", " img_indices = self.np_rng.choice(self.img_idx_by_label.shape[1], size=self.set_size, replace=False)\n", " img_indices = self.img_idx_by_label[set_label, img_indices]\n", " return img_indices\n", " \n", " def __len__(self):\n", " return self.img_feats.shape[0]\n", " \n", " def __getitem__(self, idx):\n", " anomaly = self.img_feats[idx]\n", " if self.train: # If train => sample\n", " img_indices = self.sample_img_set(self.labels[idx])\n", " else: # If val/test => use pre-generated ones\n", " img_indices = self.test_sets[idx]\n", " \n", " # Concatenate images. The anomaly is always the last image for simplicity\n", " img_set = np.concatenate([self.img_feats[img_indices], anomaly[None]], axis=0)\n", " indices = np.concatenate([img_indices, np.array([idx], dtype=np.int32)], axis=0)\n", " label = img_set.shape[0]-1\n", " \n", " # We return the indices of the images for visualization purposes. \"Label\" is the index of the anomaly\n", " return img_set, indices, label" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we set up our datasets and data loaders below. Here, we use a set size of 10, i.e. 9 images from one category + 1 anomaly. Feel free to change it if you want to experiment with different set sizes. 
" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "SET_SIZE = 10\n", "test_labels = np.array(test_set.targets, dtype=np.int32)\n", "\n", "anom_train_dataset = SetAnomalyDataset(train_feats, train_labels, np_rng=np.random.default_rng(42), set_size=SET_SIZE, train=True)\n", "anom_val_dataset = SetAnomalyDataset(val_feats, val_labels, np_rng=np.random.default_rng(43), set_size=SET_SIZE, train=False)\n", "anom_test_dataset = SetAnomalyDataset(test_feats, test_labels, np_rng=np.random.default_rng(123), set_size=SET_SIZE, train=False)\n", "\n", "anom_train_loader = data.DataLoader(anom_train_dataset, batch_size=64, shuffle=True, drop_last=True, collate_fn=numpy_collate)\n", "anom_val_loader = data.DataLoader(anom_val_dataset, batch_size=64, shuffle=False, drop_last=False, collate_fn=numpy_collate)\n", "anom_test_loader = data.DataLoader(anom_test_dataset, batch_size=64, shuffle=False, drop_last=False, collate_fn=numpy_collate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand the dataset a little better, we plot a few sets from the test dataset below. Each row shows a different input set, where the first 9 images are from the same class and the last one is the anomaly."
] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1BhZ2VzIDIgMCBSIC9UeXBlIC9DYXRhbG9nID4+CmVuZG9iago4IDAgb2JqCjw8IC9FeHRHU3RhdGUgNCAwIFIgL0ZvbnQgMyAwIFIgL1BhdHRlcm4gNSAwIFIKL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gL1NoYWRpbmcgNiAwIFIKL1hPYmplY3QgNyAwIFIgPj4KZW5kb2JqCjEwIDAgb2JqCjw8IC9Bbm5vdHMgWyBdIC9Db250ZW50cyA5IDAgUgovR3JvdXAgPDwgL0NTIC9EZXZpY2VSR0IgL1MgL1RyYW5zcGFyZW5jeSAvVHlwZSAvR3JvdXAgPj4KL01lZGlhQm94IFsgMCAwIDY4NCAzMDAuMDI1NjYyMjUxNyBdIC9QYXJlbnQgMiAwIFIgL1Jlc291cmNlcyA4IDAgUgovVHlwZSAvUGFnZSA+PgplbmRvYmoKOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDExIDAgUiA+PgpzdHJlYW0KeJxVjjtvwzAMhHf+ihvbRSJpS0pGp2mMjA4EdA5cJa3hR1MDffz7ygH6GojDHXj8KOjIVoLzDEaX5x2CGnab3p7bdKg3aGfinA/kV2XW/qoFs2F13msO+L99IhrpgmD0Ot6vjYcGNo45b6iTgNeEB4ywlS5kyWTJZEadiz4sPM4V+TnSDrB7wXZCQw0u30XG+W958bSJsDuBKOKJtCzN2hWiDroqTPj9ID7STTVOw7H/RPo4Di99mjGNuNvvqoMw3yJ2uI/U0BcZLEFQCmVuZHN0cmVhbQplbmRvYmoKMTEgMCBvYmoKMjA0CmVuZG9iagoxNyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDg4ID4+CnN0cmVhbQp4nDWMuw3AMAhEe6a4Efg4gPeJUpH92xBbLrh70hPnOcDIPg9H6MQtZEPhpnhJOaE+UTRabzq2SHO/vGQzFxX9M9x9he3mgGQ0SeQh0eVy5Vkpej6X2ht+CmVuZHN0cmVhbQplbmRvYmoKMTggMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyMzIgPj4Kc3RyZWFtCnicNVE7cgUxCOt9Cl0gM+Zvn2czr0ru30awk2ZhAQkJ5z3YiMSXGNId5YpvWZ1mGX4ni7z4WSmcvBdRgVRFWCHt4FnOaobBcyNT4HImPsvMJ9NixwKqiTjOjpxmMAgxjetoOR1mmgc9IdcHI27sNMtVDGm9W6rX91r+U0X5yLqb5dYpm1qpW/SMPYnLzuupLe0Lo47ipiDS4WOH9yBfxJzFRSfSzX4z5bCSNASnBfAjMZTq2eE1wsTPjARP2dPpfZSG1z5our53L+jIzYRM5RbKSMWTlcaYMVS/Ec0k9f0/0LM+f5owVEcKZW5kc3RyZWFtCmVuZG9iagoxOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDc0ID4+CnN0cmVhbQp4nDM1N1UwULC0ABKmhuYK5kaWCimGXEA+iJXLBRPLAbPMTMyALENLZJaJsSGQZWJhhsQyNrGAyiJYBkAabE0OzPQcrjQAA3EYkwplbmRzdHJlYW0KZW5kb2JqCjIwIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNDkgPj4Kc3RyZWFtCnicM7I0VTBQsLQAEoaW5grmRpYKKYZcQD6IlcsFE8sBswyANFhpDkxFDlcaAKVEDOQKZW5kc3Ry
ZWFtCmVuZG9iagoyMSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIyNyA+PgpzdHJlYW0KeJxFkEuOAyEQQ/ecwkeg/nCejrLq3H87LjrRbLAlKNczuQMTe/HITJRuvGS4O8wVn+EZMHP4SphsxEzoTlwjlK4U4VSfCI7L3rzpoIl7RM6jngVZ1c4NagFnkuaC7YIu54wVN87JrUblzfSj1xC+aXcf13mH9kjj3sNUvs451c67ighpC1nVtL6QbBTJDms/Kk3bzssQseBsGlboHN4Iu1d3J0sYfr/yMCUTPw/d+lF8XTej6xRnJ1cma8956EnpX/XKow/FcSnoF7HtzCT3X6dTkqlTe2fvaf2nuMf7D5BuVjkKZW5kc3RyZWFtCmVuZG9iagoyMiAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDMwNCA+PgpzdHJlYW0KeJw9kjuSwzAMQ3udghfIjPiT5PNkJ5X3/u0+MslWgEmJACgvdZmypjwgaSYJ/9Hh4WI75XfYns3MwLVELxPLKc+hK8TcRfmymY26sjrFqsMwnVv0qJyLhk2TmucqSxm3C57DtYnnln3EDzc0qAd1jUvCDd3VaFkKzXB1/zu9R9l3NTwXm1Tq1BePF1EV5vkhT6KH6UrifDwoIVx7MEYWEuRT0UCOs1yt8l5C9g63GrLCQWpJ57MnPNh1ek8ubhfNEA9kuVT4TlHs7dAzvuxKCT0StuFY7n07mrHpGps47H7vRtbKjK5oIX7IVyfrJWDcUyZFEmROtlhui9We7qEopnOGcxkg6tmKhlLmYlerfww7bywv2SzIlMwLMkanTZ44eMh+jZr0eZXneP0BbPNzOwplbmRzdHJlYW0KZW5kb2JqCjIzIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjQ1ID4+CnN0cmVhbQp4nEVQu41DMQzrPQUXCGD9LHued0iV2789SkZwhSFaP5JaEpiIwEsMsZRv4kdGQT0LvxeF4jPEzxeFQc6EpECc9RkQmXiG2kZu6HZwzrzDM4w5AhfFWnCm05n2XNjknAcnEM5tlPGMQrpJVBVxVJ9xTPGqss+N14GltWyz05HsIY2ES0klJpd+Uyr/tClbKujaRROwSOSBk0004Sw/Q5JizKCUUfcwtY70cbKRR3XQydmcOS2Z2e6n7Ux8D1gmmVHlKZ3nMj4nqfNcTn3usx3R5KKlVfuc/d6RlvIitduh1elXJVGZjdWnkLg8/4yf8f4DjqBZPgplbmRzdHJlYW0KZW5kb2JqCjI0IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNDUgPj4Kc3RyZWFtCnicMzK3UDBQsDQBEoYWJgrmZgYKKYZclhBWLhdMLAfMAtGWcAoingYAn30MtQplbmRzdHJlYW0KZW5kb2JqCjI1IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjU1ID4+CnN0cmVhbQp4nEWRS5IDIAhE956CI4D85DyZmlVy/+00mEw2dpeo/YRKI6YSLOcUeTD9yPLNZLbptRyrnY0CiiIUzOQq9FiB1Z0p4sy1RLX1sTJy3Okdg+IN566cVLK4UcY6qjoVOKbnyvqq7vy4LMq+I4cyBWzWOQ42cOW2YYwTo81Wd4f7RJCnk6mj4naQbPiDk8a+ytUVuE42++olGAeCfqEJTPJNoHWGQOPmKXpyCfbxcbvzQLC3vAmkbAjkyBCMDkG7Tq5/cev83v86w53n2gxXjnfxO0xru+MvMcmKuYBF7hTU8z0XresMHe/JmWNy031D51ywy91Bps/8H+v3D1CKZogKZW5kc3RyZWFtCmVuZG9iagoyNiAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDE2MSA+Pgpz
dHJlYW0KeJxFkEsSwyAMQ/ecQkfwRwZ8nnS6Su+/rSFNs4CnsUAGdycEqbUFE9EFL21Lugs+WwnOxnjoNm41EuQEdYBWpONolFJ9ucVplXTxaDZzKwutEx1mDnqUoxmgEDoV3u2i5HKm7s75R3D1X/VHse6czcTAZOUOhGb1Ke58mx1RXd1kf9JjbtZrfxX2qrC0rKXlhNvOXTOgBO6pHO39BalzOoQKZW5kc3RyZWFtCmVuZG9iagoyNyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIxNCA+PgpzdHJlYW0KeJw9ULsRQzEI6z0FC+TOfO03z8uly/5tJJykQjZCEpSaTMmUhzrKkqwpTx0+S2KHvIflbmQ2JSpFL5OwJffQCvF9ieYU993VlrNDNJdoOX4LMyqqGx3TSzaacCoTuqDcwzP6DW10A1aHHrFbINCkYNe2IHLHDxgMwZkTiyIMSk0G/61y91Lc7z0cb6KIlHTwrvnl9MvPLbxOPY5Eur35imtxpjoKRHBGavKKdGHFsshDpNUENT0Da7UArt56+TdoR3QZgOwTieM0pRxD/9a4x+sDh4pS9AplbmRzdHJlYW0KZW5kb2JqCjI4IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggODAgPj4Kc3RyZWFtCnicRYy7DcAwCER7pmAEfiZmnyiVs38bIErccE+6e7g6EjJT3mGGhwSeDCyGU/EGmaNgNbhGUo2d7KOwbl91geZ6U6v19wcqT3Z2cT3Nyxn0CmVuZHN0cmVhbQplbmRvYmoKMjkgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyMzYgPj4Kc3RyZWFtCnicTVBLbkQhDNtzilzgSSQhAc5D1VXn/tuxw1TtKoYYf0gP6bJVHutTYnWJ7PKlTZfKMnkVqOVP2/9RDAJu/9DIQbS3jJ1i5hLWxcIkPOU0Ixsn1ywfjztPG2aFxsSN450uGWCfFgE1W5XNgTltOjdAupAat6qz3mRQDCLqQs0Hky6cp9GXiDmeqGBKdya1kBtcPtWhA3FavQq5Y4uTb8QcWaHAYdBMcdZfAdaoybJZyCBJhiHOfaN7lAqNqMp5KxXCD5OhEfWG1aAGlbmFoqnlkvwd2gIwBbaMdekMSoGqAMHfKqd9vwEkjV1TCmVuZHN0cmVhbQplbmRvYmoKMzAgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzMzIgPj4Kc3RyZWFtCnicLVI5jiQxDMv9Cn5gAOvy8Z4eTNT7/3RJVQUFqmzLPORyw0QlfiyQ21Fr4tdGZqDC8K+rzIXvSNvIOohryEVcyZbCZ0Qs5DHEPMSC79v4GR75rMzJswfGL9n3GVbsqQnLQsaLM7TDKo7DKsixYOsiqnt4U6TDqSTY44v/PsVzF4IWviNowC/556sjeL6kRdo9Ztu0Ww+WaUeVFJaD7WnOy+RL6yxXx+P5INneFTtCaleAojB3xnkujjJtZURrYWeDpMbF9ubYj6UEXejGZaQ4AvmZKsIDSprMbKIg/sjpIacyEKau6Uont1EVd+rJXLO5vJ1JMlv3RYrNFM7rwpn1d5gyq807eZYTpU5F+Bl7tgQNnePq2WuZhUa3OcErJXw2dnpy8r2aWQ/JqUhIFdO6Ck6jyBRL2Jb4moqa0tTL8N+X9xl//wEz4nwBCmVuZHN0cmVhbQplbmRvYmoKMzEgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxNyA+PgpzdHJlYW0KeJwzNrRQMIDDFEMuABqUAuwKZW5kc3RyZWFtCmVuZG9iagozMiAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDg3ID4+CnN0cmVhbQp4nDVNuRHAMAjrmYIRzKPY7JNL
5ezfBuy4QTp9IJQba+QBguGdbyH4pi8ZhHUITyq7JTpsoYazCpKJ4Vc2eFWuiva1konsbKYx2KBl+tHOt0nPB6XeG5gKZW5kc3RyZWFtCmVuZG9iagozMyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDEzOCA+PgpzdHJlYW0KeJw9j0EOAzEIA+95hT8QKXZCWN6zVU/b/19Lmt1e0AiMMRZCQ2+oag6bgg3Hi6VLqNbwKYqJSg7ImWAOpaTSHWeRemI4GNwetBvO4rHp+hG7klZ90OZGuiVogkfsU2nclnETxAM1Beop6lyjvBC5n6lX2DSS3bSykms4pt+956nr/9NV3l9f3y6MCmVuZHN0cmVhbQplbmRvYmoKMzQgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyMTAgPj4Kc3RyZWFtCnicNVDLDUMxCLtnChaoFAKBZJ5WvXX/a23QO2ER/0JYyJQIeanJzinpSz46TA+2Lr+xIgutdSXsypognivvoZmysdHY4mBwGiZegBY3YOhpjRo1dOGCpi6VQoHFJfCZfHV76L5PGXhqGXJ2BBFDyWAJaroWTVi0PJ+QTgHi/37D7i3koZLzyp4b+Ruc7fA7s27hJ2p2ItFyFTLUszTHGAgTRR48eUWmcOKz1nfVNBLUZgtOlgGuTj+MDgBgIl5ZgOyuRDlL0o6ln2+8x/cPQABTtAplbmRzdHJlYW0KZW5kb2JqCjE1IDAgb2JqCjw8IC9CYXNlRm9udCAvRGVqYVZ1U2FucyAvQ2hhclByb2NzIDE2IDAgUgovRW5jb2RpbmcgPDwKL0RpZmZlcmVuY2VzIFsgMzIgL3NwYWNlIDQ4IC96ZXJvIC9vbmUgNjUgL0EgNjcgL0MgNzAgL0YgNzMgL0kgODIgL1IgOTcgL2EgMTAxIC9lIDEwOAovbCAvbSAvbiAvbyAvcCAxMTUgL3MgMTIwIC94IC95IF0KL1R5cGUgL0VuY29kaW5nID4+Ci9GaXJzdENoYXIgMCAvRm9udEJCb3ggWyAtMTAyMSAtNDYzIDE3OTQgMTIzMyBdIC9Gb250RGVzY3JpcHRvciAxNCAwIFIKL0ZvbnRNYXRyaXggWyAwLjAwMSAwIDAgMC4wMDEgMCAwIF0gL0xhc3RDaGFyIDI1NSAvTmFtZSAvRGVqYVZ1U2FucwovU3VidHlwZSAvVHlwZTMgL1R5cGUgL0ZvbnQgL1dpZHRocyAxMyAwIFIgPj4KZW5kb2JqCjE0IDAgb2JqCjw8IC9Bc2NlbnQgOTI5IC9DYXBIZWlnaHQgMCAvRGVzY2VudCAtMjM2IC9GbGFncyAzMgovRm9udEJCb3ggWyAtMTAyMSAtNDYzIDE3OTQgMTIzMyBdIC9Gb250TmFtZSAvRGVqYVZ1U2FucyAvSXRhbGljQW5nbGUgMAovTWF4V2lkdGggMTM0MiAvU3RlbVYgMCAvVHlwZSAvRm9udERlc2NyaXB0b3IgL1hIZWlnaHQgMCA+PgplbmRvYmoKMTMgMCBvYmoKWyA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMAo2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDMxOCA0MDEgNDYwIDgzOCA2MzYKOTUwIDc4MCAyNzUgMzkwIDM5MCA1MDAgODM4IDMxOCAzNjEgMzE4IDMzNyA2MzYgNjM2IDYzNiA2MzYgNjM2IDYzNiA2MzYgNjM2CjYzNiA2MzYgMzM3IDMzNyA4MzggODM4IDgzOCA1MzEgMTAwMCA2ODQgNjg2IDY5OCA3NzAgNjMyIDU3NSA3NzUgNzUyIDI5NQoyOTUgNjU2IDU1NyA4NjMg
NzQ4IDc4NyA2MDMgNzg3IDY5NSA2MzUgNjExIDczMiA2ODQgOTg5IDY4NSA2MTEgNjg1IDM5MCAzMzcKMzkwIDgzOCA1MDAgNTAwIDYxMyA2MzUgNTUwIDYzNSA2MTUgMzUyIDYzNSA2MzQgMjc4IDI3OCA1NzkgMjc4IDk3NCA2MzQgNjEyCjYzNSA2MzUgNDExIDUyMSAzOTIgNjM0IDU5MiA4MTggNTkyIDU5MiA1MjUgNjM2IDMzNyA2MzYgODM4IDYwMCA2MzYgNjAwIDMxOAozNTIgNTE4IDEwMDAgNTAwIDUwMCA1MDAgMTM0MiA2MzUgNDAwIDEwNzAgNjAwIDY4NSA2MDAgNjAwIDMxOCAzMTggNTE4IDUxOAo1OTAgNTAwIDEwMDAgNTAwIDEwMDAgNTIxIDQwMCAxMDIzIDYwMCA1MjUgNjExIDMxOCA0MDEgNjM2IDYzNiA2MzYgNjM2IDMzNwo1MDAgNTAwIDEwMDAgNDcxIDYxMiA4MzggMzYxIDEwMDAgNTAwIDUwMCA4MzggNDAxIDQwMSA1MDAgNjM2IDYzNiAzMTggNTAwCjQwMSA0NzEgNjEyIDk2OSA5NjkgOTY5IDUzMSA2ODQgNjg0IDY4NCA2ODQgNjg0IDY4NCA5NzQgNjk4IDYzMiA2MzIgNjMyIDYzMgoyOTUgMjk1IDI5NSAyOTUgNzc1IDc0OCA3ODcgNzg3IDc4NyA3ODcgNzg3IDgzOCA3ODcgNzMyIDczMiA3MzIgNzMyIDYxMSA2MDUKNjMwIDYxMyA2MTMgNjEzIDYxMyA2MTMgNjEzIDk4MiA1NTAgNjE1IDYxNSA2MTUgNjE1IDI3OCAyNzggMjc4IDI3OCA2MTIgNjM0CjYxMiA2MTIgNjEyIDYxMiA2MTIgODM4IDYxMiA2MzQgNjM0IDYzNCA2MzQgNTkyIDYzNSA1OTIgXQplbmRvYmoKMTYgMCBvYmoKPDwgL0EgMTcgMCBSIC9DIDE4IDAgUiAvRiAxOSAwIFIgL0kgMjAgMCBSIC9SIDIxIDAgUiAvYSAyMiAwIFIgL2UgMjMgMCBSCi9sIDI0IDAgUiAvbSAyNSAwIFIgL24gMjYgMCBSIC9vIDI3IDAgUiAvb25lIDI4IDAgUiAvcCAyOSAwIFIgL3MgMzAgMCBSCi9zcGFjZSAzMSAwIFIgL3ggMzIgMCBSIC95IDMzIDAgUiAvemVybyAzNCAwIFIgPj4KZW5kb2JqCjMgMCBvYmoKPDwgL0YxIDE1IDAgUiA+PgplbmRvYmoKNCAwIG9iago8PCAvQTEgPDwgL0NBIDAgL1R5cGUgL0V4dEdTdGF0ZSAvY2EgMSA+PgovQTIgPDwgL0NBIDEgL1R5cGUgL0V4dEdTdGF0ZSAvY2EgMSA+PiA+PgplbmRvYmoKNSAwIG9iago8PCA+PgplbmRvYmoKNiAwIG9iago8PCA+PgplbmRvYmoKNyAwIG9iago8PCAvSTEgMTIgMCBSID4+CmVuZG9iagoxMiAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyA2NzAgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDI3MSAvTGVuZ3RoIDM1IDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDY3MCA+PgpzdHJlYW0KeJzs/eeSbEmSHggqMXKIsyCXJCnSII3uwczIjOw2Xg9vB8GurMhur0Cao7sqK+llEeHkEDNT1f1hxz0i7s2squ7FYEYgaXkl0sPDySFqSj79VBX/43/8j/Dz+nn9vH5eP6+f18/rf5RF/2cfwM/r5/Xz+nn9vH5eP6//lutn0/7z+nn9vH5eP6+f1/9Q62fT
/vP6ef28fl4/r5/X/1DLPf3lb3b/5rfrL8AMAAAQEAEBEJ+9A5e/XR7XZ5++yJZPAHvy3LMnnvzh2apvtMtnmwF4gpagI1ixRkQHWNRmVTAjBI/gEfv7b1Zvf1M/Q1S/37+dxml6OJWUl09DPB+XgQEgAiIykePgQwihaWPbNqHxITgiJiJEAkAzNTMzBTNDNASzemiAl7NBUAMDMAMzVFHJpeRccs45p5zznHJKKmKiYFo/EwAQqWmaly+vr3abEJxjnmbT88d+/8N3X3/zW3QGztAROSRCBAQBE5ACZkCEyIgMiGBmJmAFQAAUsB4omTmAgOgBCMAMBFHqawwQDA0YgQEIDREMTcEKQFnOEREADciAAQiMwBAAEQkRANQsg2VAA0IwNvW666//5MW/XW6HyH/6T/9pGIaPhAgRiMk59s4xMxFjlbdH2VikwQABkAhi41+9evH69YsXL19sNyuTUvI0D+MwjIfj/HAY398fT8NU1EREpIiqmmoRKSK5SMkqehHOP7gQ8S/+4i+urq6W2/H993/5l3/5R773X7IQo49t0636ddd1h8P+w8OHnJOq/P6DrA9+z3ntdrv/8B/+Q32lmf3l//P/9XD/YE92cr0dTEiEuMgCRuea4KOj4MgROkIzVdEsmkQN0JCqknCMjskzOyJEE7MkMmUZ55yzqAETNY13kSmQkmZLBYuxASExoSEU1EySyZTQCNAAFFANRU0MhBiI4YtX/36zelFPan94+C9/8//5sbPGT555+jQ+UV6LPvhUH11e9OyPlw9+VFCAP/2FcP6Cpz8RrW26/+1//gsirq/5+tu/PhzfIxtSVQpGCERAbERG+OwQzEDFVEDNEIAIiJAIzFDUUrJ5tpJNCoACIoSAzpGaioAqGgAzOAfegSNgJBEYJ52KZgVdxAnBwASsGOhZy5mZgRpY/croKfh/9fmvvrj9rB4Ys653BwQDs4uORaT6BCIAIhEiIdYHiIgXs3HW0fW55dqDmamK5Fyy5Aw5S8pF1RCRGJ0jMFC1IlKkEDN7x6Sg+XBIp7FsNuvtbtUEJMr74f1x3GdR4rhd3a67bXTBMSOCCH79Fct5k7388k+vX/zyrPM/lh/8sf8tx36RhcdHj0IET//8yapb8yfF6Mnb8dkzy2+EQIjff/Pbr/7pby9/f2baf7v+8v/x8n9bdku9MYh22fqPuwPhrFMQ0ZYDPn+p2UXNPNl49lxn/9hp2MU8AppVGVOwhmDLcOXsxmmPEJGS2DEXMGOEhqAheGV6Me1q+v3du+PD/uH79/NpwkfT/mQRkSP23sfQdl3f91tdmTNszDkkh8hEiACoiqZqogYKRIpoAKZQjTkaGIAiKIAYmKEpFJNcckrzPE7TNI7DeDqextNQ5qQ5qYhpUVUAIHKbzTpG3G0bxxQczenxOn348O5v/u6/YFBsgBrihp0jBtJkMltJYALkmD2Sr04IaDabwRKgGCoQmDmzBrBF6gAIzAwyYkIoBmKABgTgwTwZozGaoAnaBJaMBNAQCYANWc0BeFAGIwBCrEpZFCbTwUiRCCxqacovbv/kYtpV9S//8i8/fPjw9E4jABE6RyGEponee+993fdPBcas6gkEpCbybtdv19jFq9cv21evrqxM84CHu/kDyDzOaT58+PD23f1hyiXlnNJcRIpqTjlPaR7GNIwlZ5XfZymfLkT8sz/7s4tpf/fu3X/+z//5j3zvP3shINK621xvb1+/+uz2+va7N9/+5ut/Op4Oc5rgJ8z2ogwBAKBK1I+uX/ziF3/xF39xMe1/91d//e3vvtYFsUNEIDBGcJ4cIRE4REfUh7Btm75xrefGU3RoqiWXMZcpSwFSZEAgguip8dwFH71DAjE95XyY0sNhHsYsat7xet006+h6Lq6MNsw4ixN0xN6REmQqI+eBrHg0B6iAxagYZLFkkNkje7zafH4x7afh+P/9q//3j2oSRPxYT+KTUOSJKrAndvMTjXuxMM+euzzz
B037xaTbRZDBAGy3vfpf/6f/28W0v3n32x/e/gM5Q646BYjNOfPenFMiQDRVqGpWxEqGUkykbiJzDplJDXKBYdDhpPNoaTZQcAh9z01DRTVlE0E19AFigCZa45CBS4aHQznOOqoVAwACQFDUbDobFMBqqc3UQBSMyBzzqnOrdt2tLqadSNebE6KA6aLDkRAJDM2s2nViIiZmYiJiJCKk6swrAFSDX+MFNFDVUkpJOc3TNGYTnSWleSpiRBQCO2YEBAOVVCQFH31sPBeTMd8fT0Neb1/2G1j3yDTM+NuH/O0M6vyq2ej1lesba3xAxJLx269JZLmL1y9++as//b+bqmmN1x5v/Hm3LT/h8fFi2gkur1jsV91zy8vx8aPOYnh5bAD2TGovUnp2MhdPCB9tGZ7tLiMwIgL8pGn/xHF4cmJ2EWdc3DAzQDSrTh4aEgKgGSwHWKX0Uyt+jsSeLqynj2CK1bpXSTarQcBJwMymogGMVEQt16gdjMEcAs3y+eUL1Mb7w3Q4SS4XBMIACBGrZDl2IbgmNG3TdE3Xdn3Xdl3TdU1o2HviGrQvW1JVVZfzXNyc818M1NRMzPQc4pqhmgAbBwrk0QMHR46JaILjVApAuVyTeuNVSkrJOTYFNX5UFUswheCMGw69j5GZcD4mEYFCYEQMzICoBmYKoGd7qGAKhk/ExRAUQBEEVAwNFlSCKoZxvqd29loWt7PGzPXUzzfWLqCFgV2+UKG6Ovj74pizaICqiWjOhYgAiIiZl31R9eAi0GaG7JhfXG9+/cXtn3958+9e9y/bsra9WJpgVrYU/NS307Vl1aaLx3Ea5zTnUswEOecynabTw/5I99PplObJVPEPHOJ/74VAjK4N3W69++z2sy8+/6KJUVS+/f7rdx+S2o+7IxfTbmaI+McCEou7ZPZ45w3r+w1UQdDQYM7lCHMpZfYYHXmHYKaqc5FUTAAVxMAQbHaUPFlrZkAEiqpmROiDCwo5FyQwMiBDB+SIlBnYrAaFBmQYDFUhGyCCAoACCqAAKtZQlgHpj7ppeLHiBmfVjIs+qX9fnll++clr9kS3nq/a+feLtv1DR/LJHcEnP5fFDtkTgNQdAWjOwMgQDVEBqpWpl8FKgTJDylAyqBmAERkiqEIREEERqrexqmaH1DhWw0A6FyhmLqAL4BwgmYmAQmRQj6SWFURNC2gBLABVrwGaQY1l1EDVTA3VVD8FwAjMoOrC5fLXB4qAZrhs67NORiAEJAQzgiWERDNTkTTn03G8vzsNp7kUG4f8cDc97MfDaVIz71zT+K6LzjETEQE52F7hqvdoVgqZERI6j94BooimIrNoBmRmrtAUgtUD++h2mIpJuQCrF5v8aFbxHKNfhAgvf6nGsOLddjHni4F/DPHhgo+cA2fEp6KJz8Cli4uwPEDDx+8FBGBERnT87GZ8ZNo/krvzuVVM9NFlXex6Ve5ohqAoWo/fCPUSwD9urd8DN9hlj5zhC6s4cbWVAjYKzKJ7UFQzUdAzumWKpmD2Yn7UfWY63B+mYZRczGzxchCJPTvvQvRNCG0bV23Xt33fdW3TNCEG7z07NmYgQloUJpyd0Ho4ZzkwUzUVNTVVKyp1k8Fi8gwZmAid5+B8E4kYAa2UMs0qjxe4folIKTmn5MzQ7Mn9QUAGcIgBXUtxxW3nHKNZLtlEDACIgdiADEzPWQZEV9W2VS8cq7dQ4TYFVTA1AEOqZtie+pNVHz5CHbY4M+c7WH/g8rqL52BmBgZqaEb4x1A4zEDEAAQxI14MO12+E86+hEPogvv8av1nX9z8u5f9v1pbDw9uOmRRLppNk8PcNQbMjvq+eTgNpymNRTOQkJtTOR1GIJdSzjljmv9YRP6/30ICcuTb0G677e3u9vMXnwcfipR5Gh/2dykrfOIlV7tORGfn6o9ddtFJdfM+qrfqiptYhYCKiqWMI2NgdG7J+RSDoiaGumSnbCZIjggIkdgBkikaMcbobRGFimepEaBDFiZl
NDE1U0UH6BBEIAAAgFa1KAZiJoCKYEBoPx4gP2qYTy7qRek8CtT55+IALEHKj3/QOTp6lPuPPv58QBd19wfuAv6oFiSH7FEVQaoyM0W4RPlmagqmpgZmJhlKwjxjyia67DozVUUxQEIkMlCodhSAAD0iEgojkmYw9hA8sAMC0KyoEBgRgBRThQQAioAUAwHVZYsbgtXYDdEQn1jwJ8sIoAoqVLsAWAMOs7OuM1UjrIcNAAg1sFBTVVBVyaXMUzoehvfvDt9+/X6/nwz8OMrD3XD/MNwfBjBsmtB1Td/HGFzwLgRuW+9ds14BIqfEAM4H9Z6dA4RSZM5lLpqBmNkx0QIWmJ6Dmye3A5QqAnuON5f/EOHsjHxkeatlpsX6LwrzEthfou2LtCxW/Jx+QFikuyYwniEDF8/grITxmbuw5MwdAgP432fa8ZxfXzIlZ5v7zNc8H3wNKbWEMjfjQ3d6jyrJhRS7FFeFnQIKOSNfQSWAs9UGe8ypVuNythz1IqNZRfWXtwAagAALEEA14QtIVH0AWKCki4RBGlOes8pifYiYQwixbZreN52Pje+a2DVNG2LjyZMhFTUoulw8MCDAGokj1sx7tXdmJiompstmQzWrSW184sTXJDYYIiI5Cm1jBpqyzhlMZ5UFPkVArFtTKur/LG9AYM4oALfkOworjB0EVlRicCNaGoRIARHIkJC9c61jYBAos2gWqbipt+pPY92g1RKQAgKQAQGQXfxGWu7SOYFyUVtoFxdsyaSdb5rZAuwbVQcSqmf0h1dNlRUlKswZALx3S753ERhEg01wr1b+X6/dv2nlZvjeffM9BgSPSIEhcGmDtCvXQuvBQmBYrZsh6yHrUHQsCKd5ykAhEntEegTG/q+yEBGZXfRN4MjgZNZ5SH2z+uXnv7q7//DD2+/NDimnT974iPfCH7YrT74PAACobq2aHAGgmgEF0Ef82AxM1JLAhMC8aCsFMEN9TJqZKoDpMAuxBCB2IGgAGEJgh8Q+SxaTqWQQIkMkYnBSRFSkqKIRqxKiQ1CAupnRwNRA6mZSg0p6+fQKfHxRzD7Ovn189j8SOv/YC59H7T8Sfj+L2n//9T9H8D/ypYiAhEyOyBGqqbIzApFSQO3s7NYgU01BBUzABAGACCpiBhV9qarVFrNVDKaxeJQQgJ15Z0zAHpiREJY8OgIjGKEHMF2CTau32M7RDQEyETtiX/UieQcVLnh6BRb1hedY7mKizkis2QLtG5iaYs3Dk6mUUoZhPh7Hu/vT/d3p4X58//bw/fcfjseEzhOzI0qlpFSICJDUbM4pl4JAhOAdFjERbhrPLhLBah3bNjpGMykli4oaMHpP0aFnJKwa7ZMbFzytIomZaNVr1SbScmJ0NsfP8uwAgIRA50B8iYrwbFEfI/FzuP7kY7A6OIuBxwvifsGfzh+FgPb0T+fEwMKV4ucB1SdR++IrPHr2i1d71rTwKJ5GYLFM6+Fud/fV1dt/BElj6E/91bi6mVyTiLPvJPbio/gAwIBkdes92gaoqLAtRk7RCkkmVTQzQyNSIEUyJFtMLgCqPTmmapoeJcys5CJFEBCJiJhDDE3bdOuuX4fYcWx9jL4JPjAxGWgWFdWiFqHqNzAzNKxoNCMCEQCqqUpNuyvoYgABCC9u6dktgopqV98U0QWPhtLPZZqlpJyTXXQQVl4dnHXsk0UG3iiSa9F3GDqInTYOAlFkR1rQREt1iQgdcXDOheA8GpRUSsolo6oAymKoFdDQEAEXu25kRufswFnR1L1pcInUAeDx12fPLFH7cketRt0OkX9/2P4oRaYgphkLUd0f4NAh8gWucmBX0f1y5X/d2i9wWD/s6f5gTiwQhB7d2vFV4685eg7BCgcOaxdOiiEJj0VP6TQVBRQDfZSZZ1r5/1xLjwiE6J1vQxNdJOAyl/E4bq7Wr29ffXv9YrPeppxyyXZJhJ3fukAn/8yo/RyNAKIRAFHV28zMCmYFappRwVSMsDqBhkUX
FYeISFW2aZF+A4A5q8/KwTl2jIqM7FsFRg7jPE1lzCokxSkjIyERsJhqUaCaPCJkAgbQM51LxUyXhIEZFNNPTvPMC/royY9f9sRK48dPwCeJefhEJqrkLPHIAoeeAf+PbsqPHEa162cj9/GnIyEzITIaLWg6FLBckhZQAkACJDq7WgaKpvUiLfGX2NnxtgU2JENRULU5KaMhEBF6h+CB3HIgaIthqYk0FEBduDcMpGgEoDW+IwImagLFRur9cUsi8OkVsnpwWBWyVTV6DlHra0xEVCVnO987B4YiZZryw/3p/fvjDz/cv3t3PO7zw366vzuMU0Lmft3e3qybNnQ5E3LbhBAdOzNAFSxZUhJ/Pzke2q5rmtium1WPTROZQIoUyUVEFTx5z9GRo2pMz+Supys67BsuokVrxrXSRfGsHh/5LY9x+BJ/Az0JshdbXO3aY6h9AQGWILo+v5h2eHzm0T194hyc0fhzsH++snxxLJ6sT037OZX++MZnAv4YvKOx6Xq6/+LN3778/q+uvv8ryPMQ1qdme2y3p9AOvhnb3bS6nTav581L8Z1wtcE1El+2i0ElHxhrDnkI6RjmvcsTmClycU3hWDgU9sJOkc1QwYxMDUERCEz140NFQERCct5zaELbh66Pbedihy4COQUSgZINQIEU0BiBFYmQmQ0M0cgAzQiV8ByAw1P4DgGREBXMLilufHKpDFS1wmj17877tmvTNM3TpLnacwBAFTQhz6GNzZwfGfJABsGoAdeiayxEaRrsA7qGpWEWQdXhpCmDgUN2LqJvKTboloQna9FSSimp5CJZNZtlBVN1aqBKNfQyQD4rPNSK/eniY8MSxi2eI5014yPCdoFYqo1iJEakP7ao0hZkXueUAZAICZFpCSORMCC8aPyvO3olx9XDQ5A9yQGDQUD0kcO6a5NvMbmGCZJlAAgODWhUQ9B5mk6H4/7+4fiwn4Yhp2Q/TTf777Uew0CsCROk6ELXdE1sPDtTLfOMsmpd3K02N7vrYTwN02BwlvPlrWh2gYoe/bM/aOTxHCIQomMK3scQ2TtAyEVmmC2LnlWeVSYlGiAiKGjlxZ/F43w8hChAwL7t1ptt6zwCoSjO2Yy8MpcEGVMW1ay+gkVIoKgGIIaCCMsTSwiKYNUBRUQAEdVy3jHnE/7n+TPP7DpenrInm/qjZU/+V+PQi1NaT/u5i3j+GsRLlv3HuI2P316Xc977iEgA5AxVNGfLqeQMKlD9ZmIlAiKyglbQipUsgOgATU2yiJoo1MwzCpJpzSOa1keewHElrIsCauVUI6MoTmJp0mmWLKaApoiIjgk8oIKiCaMEwuixCwxUSW9GT0mvgGCIGdDO+DbZORC4sM9UVYpM4zQM0zjO45xzgpxhnss4pP1+2j+M+4dxnsX5xof25atQJKeU+1Xzi89vYnDjMOYiasjexcb7EL1rSil5Tt45ZHc8zfvj9MKtt9drJg8GRXLKcy5FDAi958iLd7NUCH0kRYGh81YIikJRELOajL2EyLgY14twYKWqEEJlzz35d8YfL/H6s7C9SssS5eJixO0Syj+VqDOqegFKnhh2q7ry+Xs+Me1Pg/IfW5ecO4LT0pbx6vD95z/89etv/3L1w99aTqNfncKqD90htIfQnbrrYfOK0mjo0gq1cReQHQDRDMEIjFW8TDEfV6cP3fghnt67PCiCuJjCOvs++TaFNvumcBQMAliLOxToR7c3IhKzI/ahCd069KvQ9T407D2gU0NUkyKIZkhABmCCwGq1jEfNEIzBGM+ZJrSzV2XVzuvijBGgMZBqzcnXq3TGnbTm5K0CUMzsY/QxsHNUslp1F0iNVMmRiy6kkh41FgMEowjcgI/qozURVy23TNASZNKlIAQMHBFSQNdC6CwEYyJCMoVSKCdLM5ZJymQyQWIDUlWxSr5b7js9yupjJszO4PySRLkISX3l2fc+HzABMCDTxw7kj4vS40NRsCKEuV5/dFy9CI/UO3zZ0C+j
3cixOZ58OaCNmAwCgHMcpijozXvXg6fJsiEVK8kYtUhK83AaDofTw8N0OuZ5NClPfGm4ABOPAdajL/5/0KrOejUThoCOOLrQhWbV9G1oAhFKKtNRx4Bee0c36/WHu+bDsv2rTwkX/g08HjlcHvx+k1fDC0JwjDGErm03q40PQUynnIZxTKkUMRVTVUREQmZkR6ZZSj4DH3BGWXGpF0FE5qZrt7td1wVkGMZ0GJJQyYieTBRFJ81ak4oLdqBLnVWFECoPdUm0LVk5VEAtIFmfuL3P1tPY/VNGwhmJOpvx54H7xbo//8Rnj57GUZVkVCEmei49lfrgvUMkESkil9zG+cN+RLAc+xCaJbWLKCg5kwikGXNaaDHESEyOEI2gYMla5lrLiqZWMojUyjR0TGBGBowADN5hcBxD28YYAjsHABlAHBsjMlBJeFTlIpCzU6lkCAQyRRUUADHLDNmjtEGaIEAmsEDVTzx4A1NLVqxkRGJylXwBZzVBiKiiKeX9/nT3YX//cHrYj8Og4yjTKNMk01jSVHIuxG69bVd9u90GBDnujzH63abrumjbbprTMCYgCk1ou1XXraTIPM8iWgqcxnEYU7uK8wQVs1xMu4gaMnnnIuEjVfms6h4Xo0Wuyh8YrVT1dwmyl8f2jGmBj3vq0Zaf84r4GMcvP58J21nf4AXMfozHL1H05XufGvxLlLw4wB/J8I/Q6J74ps+efXRbANCsyePt/pvXb//u5ff/Zf3+n2A6FlFTRUkun+IcxUUa7tzxLZaswGYgLgK7R+wakVTafOqn+93hm+3+283D9+3xHc97tZJDyM0qt1ep2WTfp9inuJr9anJ9opDRZeNsLIBqFSS6nDMSOecoxCa2q9BvXNNyCMjOgBb4WxWRzRwCmdagGkw1M3DdoATBAWF17OSS+zUAQqIFaEAjrBAWslEN07V+voAqqpKqGsCFclKZ4K5mlOsvHs+FLwsT/1HKgCJSAA7mAkRvXeR1i70HbgGFofhxyOOoUCEKNueEnTqPMRAzYk1yqNPCMkMeNZ1kOuVx0HmGlE3ETA2Wy8IINRMPixJa3E4wAkQDumSHzmKyRO1AWEveEQmN/gWm0erpay7GrI6JkBF6xy8ivvL5JeaVnrgMDJlrFYIBSoF5RLnD7Bn62OGKPIA7TaNlzWOeD6f5eMing04nzJNHIQcKTipPAmroqfDfL46/6IclmmGkxjfbdrPrr3bdeh1CQ8r5oPvjWN7Tg/fTcOvth0DB8SyqZgSXSo2a1PnnRa8ItRIaGcCzW7Xt1dX1ixcv264X1THNh9MppSIK0zQPwyhSEKFtm9WqmafheHyQkheH9UzWWqSCgBi9574L1zdbH/3d/qh8mmAiE0eRs5SSJBc1RQZkqOwsKKq28DOqlJmd81mKlXqlxbT8s070ybInArk4AWfHCD4y9M9fX3GJSxR1Bt4RgAgcE5MT0SKVOmPOcYyx71fO+2kcx3GczErJnxzPs+3BzjsXRFWlZniLglYouAhKglrJDWgIimaoWMTmbEhYSgW9UAVMUdG0lh8QRs/euXUTt327Wa82q9Wq6RrvsFpqJ57Ns7NCw6DHUzqcpjllXcgNaIYqWAAK4oQ6ou4RDoiDgMDCQ3oGyKtN4zSN6bDPzsf1dt00znvEis+igUEpMk9pOKXDYd7v0/4h7ffldJKcTY1iaNuWiZQIfAirlbu9WXkHD57mVO4fhnGWVdciBec0lXI6jLV8HwBKLkTsPXvvaC77h/Hrr94Hv+7akPKc8ihaAInZe/YLpaRmDT4xdWiKUmhp5IFQGd1oVfnRox1dYmk4P3gatV/M8CXUvnzPBcw/y/LF4l8eP77smSnGj9Xq44f8GIHoU9P+XFXgY6R+KWBn02BlO96//vBPr9797fbDP8bjD3NRNPRFTTPIhNkTeofM7oNgmN06h3Xqri22wExgZOLL3KTT7vjD1cPXt+///uruN/39D/50J5oy09ytU7ct6VDmTQl9
juu52YzN1oXtyB1iNPNibsHspDw9ZHYOCFzoXOx803GIxIzENWlXA2w1UFvYxWfaEOQM534RSNW+GMjZ1TdTNTOkpW4MoMoHIKHV268AxUwXqrgtFEy4RMeIROycZ3ZSpMLPgGSAaioqT+MFdIiRMAB7Cx5aD33AdYOraJ40ELG6wx0PR0sKCObJgpPgoAnYNOhdLR5FJKochjzqfCrD0YaTDicbBktJSwEVW8g4AKCIi6O6sKSo8uOqPccneqlGIwqLLkRARqvX5g+n2p+tystTVclSXHHikIkN145eNfbSpSssLU6OEqMyoTg0UjSxnGA6wuQQVl6567dKkGbDKdswwThinhhKYLPGRWwkB1ErspStmpqULCWpqKqcs6efbqL/lqtubUMj4uBj3653q6urfruNbeewsSnMM46jDioeOnYvzK5IOkZRzGaL179wFw3PGV/DhUK7XNGfXjVeZ8AYXN+1u83m5e3L9WZnhEnKMI25iCqcjsP9/f0wDCnP63V/dbUZTgcEGYdTmic9Vyc+nhUaohGr99j3oe0bRckqpyKDFiYj8FC4CJgUdMYeAQEUwdCKGgMyIiCSiRmomVrlm6qCCaLx8/ti54gQlgqOC1p5OftLoe6ze3q260jwBJSCT+77uU4HHilPBAjmHUaPjXfO8WnKpzHVBk9tjP1qtdnufIgHd1CAUnKRpWuW/YRQETlmr5YMTK2olapSKhemCEpeKsdU6/5DM6gFSUhmBqpLBQIgVuZEdK6JoWvitmt3fXe1Wu9Wm1237WPLpozFuxycBmZQnGYbxmraZ9FUnX1VFMUMkABG1MHKm5IhpTKXWRQuHv55qdk4zPv98PaHkw9NUdpu2/U6LLk+AwPNuUxTmuelD5MplQIpaSlITLGJ61XjgxFqSkKkwWMMTlbR9ro/jXMSJuc9IaKWMgxzSdlEiBDMnPfkAhN474ZhnqfpagdXV/1cplQmMUUkx86zI4TameTRvXu6TNG0Bmz2pLML0wX3BgRAgqUMn86U+ac0OnhqyBEuVelw2aAGl25Azw/iyas+PrZPBQh/+o8/ErVXM/5Uw9UgDdEqLOtFrvLhs8N3X77569u3/+DGe9PikDwaW1HVZDqanKAckMHJeLjzzfe8/ow2B2UP7J3lNh93h+9u7r968fbvrt/9Q3//dbN/C9NJSk7scuwKB3Uz8uAASYRUULNpKWrZFcDO1JtQLfiAPD85fHQ+ogK5AOx1YTfWwhDS6mCfq9ek6LL/AYFQxXIWdsRIRRGNFIyB1JCslnIiKIuRKEnVqVqtgtZSCgQgNEMBVMOl4Pxc04gGtc7Ls3NOCiJVwmhVYkXl2f4nxEDkjZ0Fhs7RytMq4DqW6G0TqWMePoR0og97yFI8UuOw99AFaIN5b0zATMQEQCpaPJTIcx+mkU4nPp3KOOg06jxZmqEklVJgQUZr5bs9oY2gQdW3la/zGFtVnhMQGBswKKnivyAONlUoprkoFWVW5/nK6RfRbn3uOTdBXSACNVx6Q1g2KKpzEjuZvTUhjy4E5Awu51jmDvO6xbLr2Ps8b6xkKVrESrEiKkWkyDwM8+mU5jHPyX6kVPe/+TrHucg+tKtus1tdX61vdyGs2BoZQx5COUY5dVxWTpFDY3wrw5YtCxTFmhVSemRxAcB5j8KzeOAnVs14BOY2hK5puq7r2n673TWr3oVg52q10/F09+H927dv371/23Xtzc3V0AUH5b3J/TQu2M7ilQGgqqlYEc1Fp1JOYLZqqVhzlHyyPI5iQkBsBVMpUISlOttgCiZG3oiIGIEQoYgUFZEsamiGDI6YET/1GSt6dP4Nq5V++jvZIyy6RFWGaETnYEfPjmwt8FrCcwRAtIXMe45RncMQaNeH6z70jWeib+/GWcBmcwSbVXd1vWtXV+iCGKZS5nmkPC9itfgJn9wdAzNQMRGRUooIGBIxsTIZoRVQqfkRswr0LoGGmZRaXm6AlWyHjrmLzaZvV3237pou
unVwN8HdMl05vwotkkMC74ojYQAwyAwbL7sm5ZxEkhYx0SyUlEbVQWTFKs54HhMcxzIdca5+xVOGvKkdT/PDw/j2/QFxHGctZRfjNka/0PZMU8rDMIlAiM0anfeRaHI0TUkVsGkxthA8q8DpNKQ0v3PU95FA2QGhSJnGATU6R2Ba8jyWNEmZHTNVog+xj6u+afM8DafT/T2/fy/ohwzFEJE8s2NmQFDTc6H9xzfkCaC+QDZGCIR8tu6LJNUIkCrJ8RyL49Mc/LNPXZrYPHHwHrHRy7b9sbj8Y2H/ifWp1f80177EbB+99nJUCNaU+eb49rO7337+5u+3d7+FeU+SHLI3iCJmkI1H46Nxgx6djv7gmw9+/z4c7pzzjNrP99vTDy/f/t3LN39z/eZv1u9/Q8c7m05JrCDnuCougBTOCekIVqDMKpOXkawYYPY6EIB6K2zEwA6eEGwQwDkPhuQCkluyqoCXriiV21JD9TMDDw0RFEUUQGoPzaK1qhbQgNHIFgIWoSMK7AITA4BISfOkWtDUe2oaZ1Ykz/M8JZiyZTCtzVsRFr+amR07ZQeIaqaipUguOWcy9I+XnhAcojNmDUwtY+eo99QF6aJ4T31ww11XxoJQjtlaJ62DPmAfsIsYFtNu7GBxwz1qgzlTWrmux+5I46DjSYeTDkcbwWap/SvRFnodPAYs5yjJqiBUIdMLNLck2o0/qlf4o5eBgRVTEqWijbcIsGN4FWAbMbbk+8hNNM1qGVDIVE0wGWS1PKncmzmKvVPnsgul9JJ3rNb72LT9lpOYimSxUiAXKUXyXNKUhv2emOAIqiqlmAhcyCD//JP4wycJCAhM5Lzvm27bb3erzbbr1iidjkEPPu1jOTVy6qmsxAKHFfoXkG8cngqdQJ981BlGs+p0n9PCz2zbjyzPEB0Gpia4GHz0IcTQdP3u6rpbr2LrPROhzcNpf7/+duPbIM657abrnFEZ03A83N/XtkYIgIzM5L13wbEncoBYrJwwSxeidfgw8sPMmFARjcgARax2Z2FGx2hqWswhYABPjhwhotT2KQ5QwAw9eUeePzbt9oizAGLtcQWP8dCFdPzRv4pHnUs0Ec9aFRHPyCtUQlPtiEpI7Mh7ahvuW3e7iS82TetZ1A6TvaFkrJFpvdrsNjtq1oIuNCW2UxgPKXFOBX46alfVUpbOyEVUSkUVKw2JzNgEVFTkzG+myxtBzvzdSrLz7NoYV22/7Vbrrlk1sXW6JrvSclvmbUmtqrFHF9gpo9QyY4+opK0TKUnyLLlIsWQ8QdB5HoaTR+s9DQorzpHymfr10VnYNOVxzMOQc07jJMy02bRIHDxjDQRURQ0QnXfsOAQPhkx4OOWUlchMhRCRgQhyKofjSbWs+8AEDFJE82wMwTWhql0wQxNjoqpIAVfAsQ9kOc/Dwx287ea4HrApBsTsHTtHS2MsUzWoGchPbgnWCG2JxWssQYRMeHH9EAEJmAwv9exPo/Uf+ciP/ngO55dvrwnvH9u2j1H+81i+OqTLox/Z8j8atf/kqiBzOx9e3P3m9Zu/vX3/283+LedT1LIyiMV8FlVISqPRYPyGGw46huluHMLp3vY/RCqr2b14+M3L9393/f1frd/+V3d4q+N+yjIVy0DFOamFjCAkI6cZ1WEJpm2xLmJmxgI4OD+JWnEWGvAO89M9j+g8GZL35Ny5YzEgkfNOReHMXEUkouXuVT6FiJkJM4CRFrViKc0ieem2qErkYnB977erq67tHNM0j/f392meCWC7Wd3e7pggzePd/d2H9+8JpoKl9jkuonAuIGZmJjZAM8slp5Sm2SFDCO5Ch0BcGkQzgScMhIExOmo8tcFipM57+WX0DC6c3u4naiWydJ77QH2g6MExsgMkNQBREIcqWLzlYt5ZjJh6N69xONnpYPv7dHwo86w527ljHgHQExR+AZDwbFEuuHCVuJplP/cb+mevigoUVRJBtQZg42gXuV8Fv2O3idy5lAfNA2HB
kgkSpEwoKKLjycBTfEAJTF2jujFxjjZNd/DdA3cnw0kkiWWFkiUnmYY0HkfwrGhiUko2M5XHThyGP7VP//9YiEQUQ+jb1VW/uV11u0ArGxsbgzyEcvRyajS1VjqzTrEjM5SXHl+17n2h91kVal2o1ATJM+NWEWiEn2CbLSsytg4DQ8NLf2Gs/BLnfQhtE5qAkQu04Wa1WrldS6d5LgDmjWDb7u9icE6yiiE59p5jDE0bVutuve36VRMDsY6c5pZbIrf20no0tAwqaLWEzgRF1WSZa6Bq6gzUmCjEyM6T8z5KESlZpZjH6DAyP9FXZzfz8Ql7gn+euy6dQ6qFEnNWkIi4VO7hpWgMLmDlQjIhInauibFr4rqPm1VYtX7V8tUqbFtXcnk4TISDKTuKTQxdt227qxlCKgYUQ+hi0+U0q4wighev67lQ5VTSlCrrTsW0WCmWZ8uTpFlzjaKlRuZ4TvdrxR2tapSl7zM1Ma5Xq3W/WrV9QKRcYkkrzFvG9TzEYlRAelJwaAzEtEAvqKomBQRA1BTFQLnV0J/S/vvTQ3R2RcEKOSUyQnjsivB0lWxm5EMYp+n4/hCC3+5WzOx3TW2g4bxvuygy55IRIQRabwIzItL+kNMgZVLaWdvybtuVLGnOJSVoHYExiGnBYujJQUQDzUJIkZgJ0DSLqoLMQ2GgcvQwnu7mH4RXt6W9BmxdE0P1DhkdIoOBqlHNcjzboQCwxOxUx2pQNe1LUUk17YCAFZfFc1MQ+ASOfyqsi1V+YuGX9nxnFK+i5TWY/7GPePyox284I/r4I/ymP2zan6P56iX10/2Lu9++/vBPLw4/bKYHb9KI9KI+GyYVhaTQGfXAwJoVB3cY/fvV3dcSsJtWG2cv7v7h+t3fdu/+0d19l3Iaik7GAzghMARCYyheRy6ZDcC8WitkSAiupTwQjmQT1jiRGUJjyT89TBcCGrkQnffsHNGFErngyxXUqnloWFo0LEWYYCCilHOe5lJknEYRibGNIXrfeB+Cj13Xvby52W230bv9/j6PExZl4qvN1S8//wUTHI8HzXi4H7mJzCyaU5nhuE+pVJKeYwfe136ciKCmRUrO5P0TlAsRCZiACRyBY/QMnsEzBkdNICZ2r9smOvK0veNkibz0PfY9rVqOHomsFiILmAjUChkRyGLeW4iaW8qJ2pbaxoJ3wetw0mG0aQRLoIKqWDs9X6oz8RIS2XM1dSbc/Rjm+MfKmAGomoo4055xG8PVqul3rbtpedtTGyifNJ8IEuXJ6ATliMMIUwIplmY7nQBaCugBupJjDNddGFbru3a7RzqWMhWdReepzFNGN4mZnxoeAgdPziFl/FEv/r/ZQkLy7NvQ7rr+um2uPG0p9Tq0cor6EGTwOncgPcGKcM20coqMnxl9Ke7bSb6brRg50o64ISBENcxqSTUVLbqkEH//tW49pUCeMLK52nnYBGqmBY1JPUP0JQblhqI2XjeHwzCMeWTyEO/62EUvRbOYD75pQ98361W72fZXu263bVYtRhxdyaEkwtAjtVhIkpVspohASAYsYrIkoAwBtKgUUTE0ZHaenIumZjmJJHDwiWn/fae4dFg8Nz2u1p3OXTHOlrtad1ioUIhM6Cqi5hz7wCH4GOOqi7tVc7UKuz50gZqAXXDB8/44708A4NHIO9+1fdvufLMbZyuakRrn2xi7FMZ5Tohyzq5+LF5SJKcatkvOUrKUrHnWedacpBRTeWTq4EJdXciujzQAxCXV55x3jglJCqbUwryFtGFdec/mgFqLOzU28AZ8MSRqIkUlowqqshoZNxjX6vIsDFZyBsmmZelW9WnrADNLSVSAyYFhmvL+YXz37ti0oV8HIjJAYnaekExqdE4Uo0PEnDQn05JVTIugYRO4IOZp1gJLFrty8i/9dKzKCXnnCFVKqWRfzXMas5YJNaUBHgoUxIy0wuC64MARcr3tNUeiavZx3QU+WlpausnUnDvhBVeHc2Ln2fv+YChwAeFN
YZ7meU4lZzMNIcQYvHfk+HxbLxf26Tc88kIeIxAzfMIzuayPutEtP5Ybfo7yzyEbomksw3p8d3v/1euHr1/lh60mp8BZYVbNlgsUgwIApgFspfRCR7MPvarAieavgscIKZzeuuNbGQ9TKoPgqG4yKggARlDYBi/ZF2Jwhj5rP2I/QTNbO2oYC51QsyV1DXgHsbGmgzk8OQv0TcdAPrTsPHGlQpia5pyhRmYLokJEpAZaLw8QIBmAlKxzGk7H0+k0jiM7//lnX9xudze3N23b5FTW6/Wvv3xxe33VePfd9/bu++8yqmO3W3Vfvr411Teqb8mb0KrfvLh9oSCn6fj9D9+OY5ppROTofSAUU3auCb5iVh915Kj4Za2pIAJeflUiIkJCCoxxB13v+s324bQ6jXOW5DjFCKvOOwdWk3QqRa2oiVhRK2JF1DtxXnPG7MU7ioGahtfreDjo/qAPd2W/13k0UQVylc9QmSOAC3PO6gY7S2LtV3PWl//SaBcBQEklom4c7vru6uame7Hhl2tcrbFpvM2sE+qE6WDuHeh7HD9AUiyERJBmoyMoIBCXFDx1gder2KxXHdIhz6e5DHM+FBETkKwl1ahQTQGMCIm5TnxS+z+ENs/ETejWTb9rul2gjY2rnFY6tjY2OgbNAUrHsCLYeFwH6yOwgy8m3gP/ZtDfkRWyLtCvOnodKTgUwPsEHyZ5c0z3k0x6Cdl/0kPpopPGkVkgc1BIE0hSmTVPkp3krKQGCTlHSrtO3W08Btkf9BgokLta+d2qMbMpa9vF1brZbrqrbX9zs7m+6jadrcLc5Dno6OYEFprsmyw8H3GeCIQRgLmCsyJFTQiB0STrPGaiGYgpBPSOnUNi50ECMkSGSM7/1EkBwCdh07kUaUGxl6Ro1deLajarzyMF530MsW1C28V136zX3WbdblbttvW7llZeOxbSoiIGVJRUfdKg6ghdG9u+3/l2C35tuSjOSMacvO+8PzENiOlclf+JaZc6q01TLtOY0pRLNilWipVil2r+C+vlUulXsZb6iYQAZlq05JLmaTIhSa1Mayi3VDasUT34jSkgeuUGOCAz17ykmuUpzVpyQTAFBGRmH13oQrtpetQJBEq2lLWIPsGGH5eqDUOe56xqjrlpQiny7u1DvwrXNy0zk6GZqpZccsqJgB27CjbExq1WxoxS1LFJLqiiRaEYMTEEQDNlM2COiK6UYmrBe++dd06laDEkjIHNyjTmORURYyQtfHwoWcU53q0iWSBkMzQ9h9GoZ8bK04hqycQs1WyXMrYzFe5J/XplIy0NEJ9Z4fO9ftK56OzcKeQ5f3h/9/7d+8PhICVf7XbX17vNbtv1HTFXt/QRDj2vBZZaoBY4h/yXdkXP1sdeMF4axl+O9CxKhkBWmnTcDu9eHL99Oby5lrG3QgU1Q0omBZJAURBc2rgglqCwg0PE5OghFAdQRFNO45ynrDYaDkYTUkZSQmIApkqxNsJCXCgM3D3w9uiuRl7PuJ6tn8RPAMIE3oOP4ALQY/9cBPShUSDnPbOrGYzalllrpWnFUXgp215YYbDw3pe+8DnncZiPh3EYYmxb515cX/3ii89jE9+9exc9dZE3feibOB7bvvFpDN7Hvo2rphnHMU1TmmYt2jXdi9uXRcUdm2kqcxI2GoAhTyhJtBBT34TGO8dEzy1iZWowmSNzBMzAZMQ1Q0mM5BhClK6Trg/XyQ+jTykBTN5Z23hmVClFJYtklSKVQSZFNAnMWUOGlCxlzQ5iwCZi13FsKDTGRERyIJ0Gq8yDyn+2iirWPjz2PHB/FPZPN/4fvxDNyKQl2wbatnG16sNqg90WmpWFBlEYM2CGvEZxlACPiWZjIdKgYJjmOpnORLhrOk/WeOqDAwxTCSWxpDkNMJxsGHQaIM+k5oiC92Tg2ImKipZSfqoz6L/svACQkSPHTeyuYnsVeEuysmFlw1qHTnNnpWVtPPQNrlpcNbRqqA3IjDpAesBvCh6B1NOm
5T9b0a9aarwJwPtE34/2m3v6ep+/H2SfNNXqwKVW+qN6BewCQ+MIwLOLASMpW7I8ynzKkyX2waRIspDRSetyXEELGBEaRg90vw13Vy2RnebS9X6zDruNu9rwi6272XLrc8TsZPZl5JJBS6NtaxLSyPOohJVICqqCRQ20CCCyQxSTWRJmg5kFyNATO0fkHKIjiAgBiX/kusI5ArkY9TM6ipdE6BIH0aOGppocYe9c8E3T9qu+3ay6zapZr8JmHberZrOKm86vvfWco04sqSRJSaZCR8EkNAmLsXOu7frV5iq0OwxrSKmOWSWemSNzRHJLnLTETM+0sEmFnHSeZTyVNJXz2FY49524vOVxq52N+0JSOSdkTUUkZwPxOm9k2pHsTFs0Z6QISg7JE0fjgFyjUFMRg1zEchFErYQxBMOSg+rKOSnOJKcis0g5V4p+hG6Z2TynnAsaeE9t60XheBz3++F4nB1zDG7B/NQkS1EQROeoOlqOMUYWrrNozIBQkcl738TQAZj3AztYbTbeO4Aswm1HTRNXm76kBOiIyUWX52maRledFkADK0n1CNMDTj3ktUqvDs3oIi2feMCIgPzsMj++8EmXpsfQd8n+XD6wjmqCJ29bIqGzAyFFxmF++8O7r3771d379ylNtzfXL1+9+Oyzz65vr2PTuFBnlNUuqBfC6lMJuHgjZ87LJ+fxzLR/nOl/tgzQ2Eo3P+yGdy+m97d5v7bszVKGnCErJIGslhWKoRgUw4FhQkUYO5v6jN2Ag9pedDY4AYyAE+BEmJC0zjlxbMFJiBZj9k3x3eS7vb96H148hKuT32TfGTUCrigJeKuF8kjPgWEMPkjtB1Mn8y3pkeqiam3EAQBKIoUrJ7wWoYmKWUFUhxCYo+OEFJi3q9Wrm5svXr0AxHdv3pyOx8Nhv+mb6Ngx931XCjgXnAvznO7u7r795psPH94VmX3gvm/HORG53e5FjOu7bre/+0FPH2w8qGRCaJsmxoCe0fEl0Q5LkZI5Bk/g2TyDc+BcLaitI7SBYGYSDiX40EdnGgGQ2YJnJjSLYppVikheyHqSSkmljInGuUyzTrMlBylZYgsOmdB7Ct61DcSQ9/s8J80iiOcyqyquFfnVx8DdEJXQCI0+Grjwxy88e8zQM11Htw3UMDgwU9UsBGKkwIS+xehgpTgL7mfKyOKscJkpF7NxEsOC5NSA2QVuPYEoaYF5yMe9u3vQD/tySHJKOKcAACG6HjSKFMk555xsmqSUP3jEf9RJAdR2rt7FPnS3TX8b4w51bdPKTisbVjCvSTYA6x5XG1pd8+qa255jw8EZgYUDhg86NXC9cdTEXef/bYtfRI1YFGCP/ofE//U+/PW79Jffjb95KHcZ5qVZo5EpPT+a1jM1PtS5Ck3TNOwhQTrmgROnGYJX8yFlKRKKYw0k3JSg1qAFsNNVGF/2TYD9OMfG9T2uW9mGeePHNUOk5GEiTAzFGRBwR7pCbUvyaRZEdI59IOeEsxUzVSYM5kgRVItJkQmS4JxDaUJrzpEjUmCAj5uWf6LPzkDqY8h+ocFjpc3VWehKpEzsgm+adb/abbcvbq5f3m5fXa+uN3HbURehcRpJAmZXTpQOUI6SBivI6s0oq58Ex4yC5IPvN5v19U1cbS30FBxlBFDFCTAg+qVGCmUpK3h+1GaoBdMs45DTpDkZaC1BQQIDWgbDAJytw2NycQlK6o8lawZGKkHkyspr1GuS3qlzZM5BaCw0ytHQIzt0NTpUMqhd2YsqkDA670kt5f0HHKZeZTKdi8yig0oys6U/NT617pUzpKpEHDwzQ0qSig6n+e7D4B0F3zKTc56QTSzPktS8YyJMSSSrSR3t6gnrABeMUdu27dY7BB3GxOxvX77ywec0hTAgj33f39xepzQeHvbkyAU/DSd3OpZSipQ5TSnNhIgKw4Pc8bjtxlWXmJ1bRqk8KXx8ejuq/rIlkQMIUNWcLdSM6hyaLv08ahXdI+tlKbm8hNWX
SreKxUOa82F//OG7N7/9x6/evf1+Gk9vNqt3b14Mx+M0frm92q3WK99Ev2QJyc4NPy/mvCb8H52OH3MZP4naP1ZMz87aS9oN9y+O76/Hh1UaXSlWQMRKDdYNRCErJIVkkBAmhIkNWR0ZmVI2EyiCM/JIPAHMCBlBECvR3ZxDDuBadasSNnPYnNxq77d38eYQdmPcSlxhbGsbGWhWEDtw4dOjZnJmgERPPK3q49Y5RJes19m1N5CSpAiAOkdXu/W6vTl0/Z0PJSshMZIjCj6YmQnkVEouooqIbde9fPGy6yYi13fdPOf7u4d3794Nw8k58p6IrJSU5jnGZrXaeR+aEMo+2ilintAKeGcECgbPQeAKyDOqQ6u5dlenHfOZq3nujcOEnoBiJWlR7adR0Q8xLcqiIpXLp5pLSaVMqYxzGWcZJ5lmmxzMzrxTx+QdeofBVWYVH08wTlDEilyUSx0icZYku6jTOhjqn51rf+qQIoBD7APfNH7F4EqyadAjQSrgRmAAT9h14D2GLfVC7YFixoxmQGQmqmPKAhM7SzoZtoiOLKqo5ZSndjqEae/GA82FsrGpo6r4rKAAFFHDIoSk//K8wvOFSAYesXW8Dm4b3c7bltIGpg2Ma04bp7uIu5Y31279wnW3rrlxvmMXiKCQFL5X7uTPG7jZkfO08vSZ01sUhmJgE+MLsesIG88kwRH+/YN8mFWedEx8ejhNdL4JwXEIMUQXPDidbT7kk02QGBqvGEWylmzFOyW2BksIxat6g3wVrKz7Bj8ciVhj0M7nFWtnHFUbU89CjhmjR2/QKqx2RW42wyHZPudC5DyrI49WGNW7hnkVAxpkkUlkzJIti4oxgCPigAznMpePJetHweHHuRxwBpNgSUgTM3nnm6Zp275fbdar69325fXm9c325VV/sw3bjlZBA2bWQpqhzCCz2CyWxcTAK7pibhaesk4FgEPTrbvtdbe7xWZdsCGH5ESKU2NVUq1Dk+pUtOfN6QAAICWZpjxPJc1SipkRncl/lR8lBpXMDo/b5Jxyr7wrNURCXnJiHrRTu0J5gboh9WzkCLwzZiM2JAUkJGCGWtpldWAfq2JWZRTwoqp5zpayUwGVWcskZVIpj103P8pQIxFCjTrYRM3Uiug8lfsPQ9+FzToSUQixbWITo6ScshAQeQ6ucVzLhLE6wY4ZAEvWruvW2ytVCYfRh7i5fuF9mIZB7Jjk1K3X66uXKc1GDTH7GFw8kT+oipiO42kcB7OsmtM03r07du2D94Edh+CYCNAeTeZ5qaEInjslGC69+qutsLMlqdqupvfPCXg8dzg6C9ujkC7wsOWS5ql8eHv/3Tdv3vzw7uFun6akWaZh2N/dve87RzRP4zBsulXf9qtu1bkQ1EClUryX8IfOzL6zSCgaqv5e0/4oPGAfuZdo1uT06nj/+f7D9jT6KVdmTBaUOtAXTAHEIKmNiCNCcpAbBI/kDRHF7GQ2Ac6AySwbCpiQCZoBKrGiU/SKMXE/+Nsh3Jz8enDr0a/muNV+Z6sdbHaGBKWA8+AbIPepnNmllB1qXgCpNpuo7StrOomI2LHzAKAi6bgfj0fvfbfbffn5F7/68osP7999/bvfDeN8Op7maT4dT8NpJHIiiOCdi8HH4IPfxl/+ys9zUQUTmef54eH48LBXlX6z9o7meTgc7x4O76+vX63WPTEGxxIdtI3LR8vDoUyDZDVQRRf6p7kVhCUSdoTekXPkHJ/tupnU4UwAnJGASZm5jnLgZaKBVQO1fIonAxD1WWTOZU55nMsQZRjL4GWcxTvzrEzGpMxIZCFSs8fD0Y7HMk61p84j+HSe0Xe+/IRQJyr9i0xi9aARwRGtg3/RxR4MToNggTQ479E7cIgxWLnC1ZZci/GK/Tuho2lWUTVS0TKWOemRbR6lmRSyRVUycZqCTlGGVqcWUkM2OVIjNSxFDUS01Lo4qVOffox3+setR/yrujoMFklXztZR+1BWDjYwXtm4hbzz
drXi3Qu3/cyvP4v96xh2jtcMDo0MJNkM3M5dlC9X+PJEWMQVaVPhnMEEQAPM18z9ijbkevStp2MZx2IDoCoYfMT2gSYEbaJnDI6DJyYFmWQylWwyM/TRfDbIJtly8aLeAohH8UEjAF/HNtCmw7cPkMpomgJKixgEXFLPPjrnYu8QvWuQ+8Cr1Novii/h7s39fsqFQ0BC6BpUJZVVCJu2MbFxTh8Op7fH01HTlIvlbJqRCvs640R/bDTBE17QovkulCF8YtdpGa/mue3bm9vrm9ub25ubF9fbl7v1i028WfEqasTkcfYlYZkgzypFFUSyFiwWC/sMIVszmR8LD6lMCdh3q7BeXb1sty+SRcloKIjODItYKSAC56al8qM2cRzn42FMKZWam/DoKg3dtI56oQokynkIO9Rhe5X/ZaCGtXcHMSEyYYu2BrkCvUZtyYgBnDP2iqAqolKbAgMSceVkIzlPHBX8PE9WSsEZEUVFNIuWJGXQPFiZTMvjOPZnixBjE00VlKRIzuIcBuGS7OFu2mzSzY01DTcNrTd9SablkOeRGWNommYVQlc7cpYiAOacB4CcS9u0m6ublJKLD+xj6LYhRDXnJiCvLqxCu0WfoiA7F2JE34Hr6hiLYRybcdA8lHH/4d039x/2ZiaiXR9XqxY9UB1U8JxUI2qp1GbitfoR6i5eGtHRGaRf0u14rpWs3M2aeq/WB88/rcbZCno8Dh/eP/z2H7/+6p++uX9/Z4a77VXdg8FzGqd3P7w5no7d3Wq13e5url/wy5ZY1EqGnKtSQmJ0jOyIGQGhdkgDtSLy9EQ+Nu0VP7gY9cXII7JqU6YXx7tfffj2lx++25xOPEmqsqsgtoxQNwAByAAzwsSQHRSuo9OWkeKjwmQ4A2TAolAABUDqLEekQr5wO7vV5DZHvxvi9dhcp2Zb4qq0K23XttrAamtEVgQBAdmWxinPFCvVunaiGjAhVcoImmmhglSAkIlr6o4IwBSwgM4gxma79fbzz76IIczT9NXvfnc6DbnkOSUpQsETOefJucpFdezcBjgXFbHD/eH++GG/3w/DENq46jtiHIfjPA2lpGkaTseHaZpzSeSDW21jYks8DoqaH+Gby2kAVASRa8ju6kC1mlWocDiAIQgaqVFGV8jq7naXqTuVBWJItZSrVkZ5gkDQOGw8d0GGyMNUTmM+jXlwxmSOjdmI0Hlkx8TLsCmdl8YZT7jx+HiwlSS/gKL/smWB3aaNN6vudtX3gUHNUjYUwQEIKTJ2LcQeROs0SigAs8qUZdJSOM0yD3lIdnRoh6m832/adk3oSW2eNM8oM2vyUBrmNjokJgIwUgNRFSkkxI4BApjRHz3n5qMb9+QKLHubEAJBS9CxrpyuIe0oX3u52eD1i7D5Mqx/GZrXTXPbQMvm0WpTmgw2C2hCkL6l1YQ4Gg6GDwlyUTWACjyUFtk1Ztf+ZPzVhCfI3x7ymMTs43KrLKKloCJBHShSezSWkrNJZhCvMRh7VS/ionoxZmWUgOaDMXETXfTaRj0OOqfiDCLzpm3Wbd/2XWwiMTnnvG8prNSvKFFZvexv7394+/40jADgmLxnj+ZM++A2TVCxccrffXj46t3dd8f7d9PJRFAyY/FeAxsT8iep9icX+ekFX9yqM9mMkdh53zTNdre5vdl98dmL1y+vb7fr60171YdtQ73PHpKWAUtSLZpzmZMoGUaFVhAKaTZN6ibwg8ApyZRcUd90fbNar3af+f42T6ZpVkNVEzFRUMPaIojQCZQnHacfV8klp2RaCI3dGSlH1WKqpjUkW6K+J/8qAmwGYqa1Bp+QyBG1qBu0FVnL4L3DwBYac1HVRGawgrUlNBEjAJoqAjG5gBhKoZznZEJkqJZFisqs5SRplFxnxMPiD8BTtWsGUowQnXNcGxMoSAFVmyc5PEwf3h9Xq+A9STEi9j7EBvpuvVpt2mYTYofsFSDnLKpMThXGcSLvAb2o5EKKXMRh4SlhyiTixbxY
EEO1htAjR2Jw3tgzewfcObeS+WGywuRzlvu7E7G7eXnVr/v1JsbozyH141JV0QJLFhcfq90WRwsALhkeQDrPhVvmhC7TQpeg/bFK2IxARPf743ff/vCb3379m998naeZDLCPMbrgMXpiAJnnfZpPp2GcZgXs1mtyUQ2KgAhoHdgpqA6dmVqdB6i1lZjIT0ftS0PUpz7ZkmJw3vRm2P/q7rt/9+Yf/+Tt79aHI822DAn7aAICgCAUssyQ2YRMzVCBFEBgMpwRMmIGKIbZUAAFsRgmoJmb2a+GsD2F3eDXU9zm7QtZ3Vq/s9iaCxYbC9GAgNQqVmACps9UF6IPrVgtaKelRYxjM1AVt4xtqF0hVKUQgmNogoM2oDIaBOe7tk/rab1ZhxgAUVRVhZh88LFpRNX74JxjJiZkAiU0w5zT/f3DYb9PaepW7Wq1IoTT6VhKcexOh4fDw8M0jbnkvu/7Jhp0ZGpu4pK0zrl6ts5TUh05X3202t8CMgA7JmUDrk13iERdsYAIjiFgdTANCckWNPNcM2MGYMTomRtPpXV9duNcDgNEL94JozIVQiWqjj0DOBEUxaLLLKulsc1TiPSMyf8LrPqFgIIIXXAvVv3Lzfp2u+4iAhUjRVXLSTSjBvAOzQDRSpF50mGQ0yDHlGdL6uZRxqkMyQ6Oxvvh/Tc/9FKu07xuXQO5pLlIARNGjc73LjjvfWEkD7jsEyRkx6CKAPRkOm11lP5gu7pzaRUsoAagLa0KkZA8Uoe2Zt1w2bVyu4EXr/nmF6H/RRN+EXkXqQsCmGW5mICKUKwlKEBeqFEICGAwFDABIdBlQCVaaQlfd/Cvsfn3utpzeij7Mc+GoE+IdAZwPI5pf4xMbXSqYqbeC5BTzVnKqEoSnQUSRAFSYwV0SixEFgh85DZ677oY7P5QjqdChtE3V5vrzea66Vau7dB58sHHjmMPcdWj3/xCvtifPrx9N+z3eZ4dWRd9IPNWWgetQ1OYs3799v563cdvML2ZJlMsxYFGZ23EwOQZP73Ql7ofxKdFQ1Wq6ixCx843Xb+7uvrVrz77k1+8/tVnu9e7rnfakgacvGSSlGVOaVYDZp+znyZSaihsjKIgZ5TZSgKYFQ45ndI4l2TYrLe3uxevV1evXbOFPIqKZpNSiqiqITpizxSIPGI2kOdeSJUSRRNHymwugPPI3gAgT5azglGdw1JLsKGWtMNZ5VY1BmDEQI7IeaSOdEPasrFDbKI1LXAH3KiqphE1EQgjMTKdR1YaALmALoL5PI95HpGKI1CjYjhrPpY0Si4mes6e6RO+NQCo6vE4x+BC33AgJlPJeS5igGaH/fj1V3m9ias+SLGchDisNt12c7NZ75yLxJ69B8KUSikGwGkWGyFlOJ1kGPIwCHs9noQoHe7H02GeJosTDiPkhMOAUZCZpJBK7TGEnhuKIetcnPMhOhfnOd3dnX747r7rW+93ITow+Kj2zUDEcu3ZXJXnY80vGuoZBUPAOq2HDB8T8Mvs3jNbEwDO/b0MpNjDw/Gbb9/87psfvv72Dah5x+M8T1O83XVd7NoQieD+eDwOYxbhGHfHWx87WwaVEACImQiokiq4pX5Pqpb5A4D8OWo/b436ItP1fPrX77/699/9zZ+8+acX9z/4cZbL1INnk1kQCbAyHAmEoBAA1edUFQRqpI6CVKpRR8rsk2vmsBribmxvTu3N2N1O/Yu8eSVXn9n6Bpq1+VCTz4B87qaidXzEmWPwuKPxUq2wACYIBmaqoqoKZpqL5FlKQi3b3WZ7tb3q+zJd3b+/z2ne7/f7/Z6I+vUqxuic88H54J1ndkR1wiURGKoYWC0ZBxUbhun9+/f7w16khBBWqxUhTuPAxJ7dh4cPdx/ez9MIYOn2Vq+uM4KjkLlVV1Rm02eIChjWIUuGIKACUtSyWMoMhccDaQbNggrBYxOt67RrAStAp8gOkA3ILW23cIHw6sBUA6wdiByAc+C91QbLwUFgC848q2Pgmmcy
q6Y9C6pBySCVGWCwjI557FSzdHeAf9YyrHfUI15F/+W6fdXFdXBNZOccOgXImgZLkwUAk+odqxXNyXK2VHQuOpkaaAEDKmZjsbvjNP5wF3L+ME1X67hrmPMkScoyGxxAi2TIWaQkVUEE5xlBCUClLMXRT9aZGvN7rfuZalOZskjEiB6h97xytGW7Yb1t9MUaX1zz7QvYveLVa4i3xhWS1GJKpoSh5diCZYgjiJkk4gJkKmwNYSQMBoDVtCOqaTZHseUXm/Wf7V586PO387dDuR+H03NQy8ai0yziDQmcK44zgSGJGoFKAnMsk1fP7Jk8YZ2hRKZIimwOyTnu2xoVWdsGohjDarO+Wa1uXNtTbMEH9NGFlnyLvvHsGoP1arxqu+nhIZ+OrLlx6CG7MnrMnhUAi5K3lmw3TIf7w8MHTUmQi5Eaq3l6lufBs76Bx9QnXlJFUJs7IwE5Dm27Wr/67NUXX7z+N79+9evPdq9XtAvqSsaSVSWbCJRpKu/3o2CI/W7OtD+MRl1Y7TiuyLfZWbKcRRKVEw4zBeriKuj2xevdi89ds1YMptmKqRQrGbTUSdBEDLTUR1f78JG8eCdNI4RKDMxGTquGMce1y/oyyeQy0KjOksTzuIcl2w5m5s16sC3Zzlnn0UePXQ/dxiyqOisFZESZ2QoBEBADmKmJqpmRAw5EwYzmlA0mZlHAZO5Y0r5Mg+Si8kSWPj4RKVIQcy6V0uAqYgwACDnp8VAQHJgikGlwLjjXNs3auVYNpPZiIpZCpaiKzZOmuTaUSeOQpgmc2OlYCOV4nMYhpSRhTKfjmFMeTrOqsXOStSQEWHrOgIGqAFjXtVJ2+/1BRR8ehvfv9lfXbb9qCD/GtMAUrNSmnGpGiHXGx2UWYS3qX7x3MlheVJ89e25LUhyWkqIF1dBpTIfDcDgMh9OIBp45zXOaRrYS0BrvfOBSNJeSUy4pSy4qQs4TEQIpGcoyXQHEqnNhZhdeydPzeG7azyjPEsADAgKBBck3p/t//93f/O9f/eUv3v5mvb9LpRQlpLM+t4XYTLiUTtQLq4hCiIyMSIoEC+VAERVICAtSdi77OMfV2F6f+len7sXY38yr27J5KVevYfcZdFtz0RBrqL0UhaiBKqia2kf14AamWrSOQ6vNGrUIgIiUXEzNTNM0pOGYxhNo2a36292L3WaDZn/713/17ddf//DD91fXVy9f3fRd18Q2+Ni1Xdc3zjOg1lFPFRXPSREhJ01FS5bD4fD23dvD4QAIsWn61RpMVYqClGLD8fDdN7/L0+AYDQuyzc3KOY+uM7FSVHLxTwTNFLWQiGbVpDpJmYVTQVQ3ZX94r8e7Mp0KGmzW7mpL1zvUjWERLIaNQVAISmC1MP3sdVY2IRlY7fROqJ4raEjR+8ZD4y068lwcIaEBiBmpgYKlAkXRFExAVStrzhAVQQkMDQEZmOEnkdOfXsgILePLJvzJunsVXQsSgvddRFZQShNmzWAeQZEMyKyIalZTE4NskA1ASYmcA4G5wMNY7tKsw9Dc76933Yttu/YUTLISsNek8zieRjtMNmfJRcCAasUhSp0t/dSK43kGAfyYdf8ErTCiykMMwbmO7crTi0Cvgr125bOVvfjc3X4BV59Jv1PXCfhsE1oWQzL25BtuOte9QDbQwUQsDaCDZinsim+oJcyECUjBCAGKyWAOpQ/9even6z+9v5K/f8D7UXIan805B8yGyYiVslIxKKokmVQNCMwkQ86QMqeCcyEv4ItRbT1IKiJekQ3Z+VXfcgy9XLNf+7hru6vYbtE35hvzEVxADsi+OjcE5oziWmoCDtPgNGGeIR3IJsKMBJ7oNjq4bu7u128+bOR4vMsFEtisCiKSwesTyTonNSs9/EKcq/gn1BCLAYNvN9ubV7/+t3/6Z3/6q1+9bl+voJ/v3LQv4zxnmYwMXXDhYZB/epMThu3L3Vz47R0aNx10De2acGWOMlStO43hWJqmidD4sL1+sdreFqFp
yipmpdr1xCBKJvQkVvqJ7oZNq+uVLKlP0zox3YCYgjFKFlSlJbVuS21YdVwML+RbNQUtUcvWyhXplYe+da4NtN5oe6XFy6xa9iAz6YxW0KzSsNRARUSqY+vYRyJvCkmS5TEDzsAPUh7yPFguIHr2Tn7kZMiKlnGa1VwTqYbO5xeTii/Z5zkwOyLvfRdji9ik2VJKuRRyEdGLWCmac0lzGUdVxZymeUo5gRlOQwHQcZinaS5Zx2E87B9KLsNpMGsds5RSkqnVEalFZE5pVE1938fgiWkYjuMw3304nk67bcre8cf1PGaVomKA1bjDmbd0prifqxIWOPTsXuG5gGEhH1F9sjYNM9VStGTNqZZgg0hthJ+no0KZUHOMYU2dAVSkmYkqA9N7JnZQWwuTiaoUra0zkS7DZAl/n2l/PL3L9iGn+er08IsPX//bH/7+X739r7vhvZM51/SPAixwA9ZAmhA8QkNQGAsDcG2EXL9z8TzrRNTaX6hWoYhzxTepWc+r22nzeVq/KOsXurmF1bW1a/PRkM4Xrc5S0Meqz0rafqpqDSQn0YrXkdHiU0mRkssStZcZrIAVXRo3Q9et1133/bffvvnh+8Pp+Obd++3VeumJ4JyPwfkgZmmajseHnOWwP576jUevonf7w/E0DMPw29/85tvvvpnnse1XPrSlmIrUQ2HHgCaaiiRTmIfTcNwrckB2rkFDSXOB/NSFNAFJkDPMCccZTxNHcpidzjjt5e236e5Nmk4FzbZrvrnhz17RZy9RbwmtbvpspuSF2NdykqoFEQiQFc54ZoU4DMFxncjqCT2Dd+icEgOgGmQDFLVpxpypJChpgQRr2FRDeKtS/oc7on0ibggI0Hi6jv6z1v8q8g2JywNms4qHlMnyBJI0s6YJ5xHmEQyxAkGKUJM6lbPGDA5zljHLKUlK5TDlcZqncdq2cRM9EguyYjaVknMac8ollyrNUHIuKec56SX98LjrfzJer6OTiNCx69qu7/tVv16v+s2qW3tcyWFbjjfl9Nrlz1fy4iVefcnr166/AscKQ7HBjMwYzJv1Da4dOkdND4xWGF0L4CwTJDDcQH9rYYc3PalDRQDMaRj3P8xlktjh9oubX/35r1f2J9+VN/fjw92H5/ldm5NOWT0zsW+atu+CI8M6xxPZiJwj9uxC9E30DXIwcgpUxFK2XPtAOOd8bADYmyO/c/HKxy3FNfjGXAQf0PmlS7EKqKAVSMmVQqJODXOhNMJ00PEBbUDO6IyYWmp3rrlp+LZt3x3nckqKSXXM0YLPbSjhWW8qArPzqMFHW2MAAAzASOxC3F1ff/blF7/81Zdffvlq12aPJzTNWfenvB/KoTB4t92s7hJ+9cES6svOKzYHlaJ8OOom8G61AnbZxHw26Mia6PrWsQuh6TYcujJmLROUDJoZCqMaCoGA1TmtZybSj/HovNMYlRkJUQ1EsQiJMZpnZAbT2kVJQcRy0ZRyLloUTCuBygytYVwTXJO+JLlmWzkLDiCQxQCxVWDNGcBQM2hGLbD4Q2YGCw2LCL33TRtSS3NQcVlx1HLU/CDlKGlClTrk+tJ4/WNuEKpayoIIRFxKnXGNYEjkgFrTaBoNvYGTwslM5hkU5nnORcgVRF85rTmnnDUXNqNSdE6lFAKEccwAlrLlAlIspzKNc8klpeycT7GoaCl1oI+qZS2TZEVw3qNjaoKfJxqnfP9wur8/bTbtat1+TOBQBTlD25V8rQvMvmTV8dHEgykgmshl4saC1AMBaHU261SjnMs4zMMwTeNcskJ1CwxULEM5jdP+6A7D4JvIPq48r1frVd83wQcmz4yMomZoiMaIyGRnw2cX1/H3F7891VWAZkhO5NXDm3/15h9//fa/vn74XZCTkpqe+y4vLf/AqM6VUUfQAxIBeyRGq5i8aR2cVoumltoArEE1KpM6L6HLq+t89bpcfa6bF9DtIHZGAcxAy3Id7dGiP3v8/NBLnovU6eEETDUjp1kkZRUF
UCKN0RE2kkspcjgcTaHr15vNtl/1qcjdfj/m4qMnQmYkZgUaU57G6f2Hd2nKt7vbVbN2GKZx/Pqbr9++e3t39+Gbb3/3zTdfrbbbF9svXGj3h6lISml23rGn2Ll2HYiyFRXJ8zCQb9k3HHtgL9Op5PRMxsQkQZ5pmvA0uqP3XOJsvH8v77+fv//d9P6HOU9KoKsOb29of+/TGDzF4IgRwESkuFKcL+zYXJ2pVRvLAVV/D8Fw6blINckcKbAL3kXv2WWkXOfSAaAqTTPNEw5Hnc6UkurKLh5WXfJJ68Y/uBCBqI3+5Sp+2fovPW51xmkWnxQCWIE8yXCyNJsDGTwc9uwa8hFEUQFqVliUCAiQHIFgBskKojgXLSlPqQxjPqzb6/Wqa9kHrwhIBUy16oZUasePknJOOadiqnXiwCJVTxp8Pj/4JcNLRN67vu9fvfzss9eff/H5F1989uqz2/VtSKuHf+rvftO++6otp67T/jPXfdH5LRGhnUw/FMtiLNaI9goEuO3QowZvSKYFlaGgTWiDw3iDq39N619z+xlzi+C1qD7c7X/3d/v795l41f36iy///PUG/uR3p6++ffvVb/7pWTWiwZjyOJc+ehfCar3erRqm2oDPimEBik1om7br+37Tdw03XgNnh5nKbGVWNEFidhSjw4AQkdfkV+A6da25JWQHYlAxLZAmyLNJ1mmU8QSnIw4jDIONBxsfZLxHG8AXcmoeyPcRV2vWXQih0PSQ0wh8shhT8G63y/3q2YU/Zw6XwP0y0QiQAB2yj017+/Llr/7kl1988eJ61+J4GsejFsglvDnNb+7yh7Fwi5912wdx3xwoieBWY+ctbFIq03GGtqzRI0fRgsGzg8Z3zarUchUDp4IqxcoMmsiSQ/GshppNTItqMS2V02B2mbjwuIjMscaGvEdDEuF5plwcmEfnIRAqUUXGDHOWaZrnOc9zyaKiZqCAunJwHeB1sM+8XrN1ZEymWBWxUyJFAbCa169FtjX9o4ZqC80WnQ9dF8rKjT3IXCAPIvcpPWg6QU4EQrX/l4IZqXxMLgc0gyJGWQmhFMlZVQgMnQ8OOoKI5kFZlaZUQBOomWrOIgLEGckbkKilkkXALACigaRsaixK01wQoQirOjMoAmlWEdUlIJZacVeJjCqTysSAjhrEYpAqV2acsu7H9+8Pm00bGt9wfHoOpmbl3BUF6zmBLdkUWhCohW9kC3sOVKHCGTXxTnXGe6U2es/saJ7ScX86Hk7DMJZcEIjICMyA0VRVx5ROU+pF+r5frdeb3Wa33XaxCcyOUAGkTtdRA0DHrLUBm9YYt/YeenY7PjLtZw9gccoQzEKeX919/8sffnP78EM37c2y8sL0NDtPnqlGGs0AmMAjOAeBARkMYATLiEZYENVwuQhwGRha6S+KoMRETUvrna6vNXRGvubIn5nzR4sOFyTgue4CERGpg30IgRERVFXERFSymjRt6PqecY1KxOG43+8Px93uipxru5Whk1oLJVokT/P44cNd/8M75DDP0/v37+Zx/mF70zY9czjs7//hH//hu29/93D/4f7h7jjsV7tt07Qp5+9++E41G+h6vW7bJjZNv9oQOskFnRPRPCdyM3JkZiMCds+yu4qaUGbOkxsO/kNxJzUYy92b+e6HdPe+HPcISt4ROGwTHUfeD/hhLyGCKncZQrQQREPxgZwxMQHVahgzXFpv4pLmRl5sfCEUQuNaLsREMBNmIlOlaXLjSIc9TiNKrsxcO8+6hiURrfqpaf/U1F9UnAEgkwtu24bP++ZVoCtNTTabpeAIiU0y5BnyBJKKZlEhYJ6S71eWzcYJpgKnbGMBh1ApBkLFoACKkaoVEJm0lLkUSAnaJscYcpFpTmkuS/mQWsnVqOeSi9RKkueG/OwpPz0NBABi6tpuu9m8fvXq888+/8UvfvHlF19+9uqzlze7mxVvyjv/7TsP6Mfipplbc96cFhrRcrG9wAcFRXBgYuqRKFLcUGiR
yFSsJCsFi1oiyBGaa2y+pN2/xd2vkHtEj6VQ/+BxHXYfQM2//DLsvmx07Pp1E+OnJP9cJIsYABExEbEjlFojTUBM7GNsuq7tVl23ahuKTiKXgJnEY/ZkwmSIzhSNGDAABaAA7I2cLbxLMC2YJ0uDjQebRyvZUrZ5wmmCeYB5smnSadZ5RsigGbSOEU7McwRZOfIGMukwpjJYCMl5/vWcPzqXJQ9yZgfVmNKAAQjJtf1qd3v76ovXn33xerVqVNLh7pAfDm3TF1l/n8r3p/L+IYWMPWynps1hNw354TivnTbdylHWw8G0NmekQozg2CE5Tyq1CamKiRSQBJYIksMiJIy1M//SClJkMe3wkaqqu1wXFHJRV0hcWJVBHbJjdr4mNogdMahJkZxLTlJEa7TEaD3bFdsvnb501nNBNDUwVVIBLWa8jDGtFcpmBqiAxUCspqGYGDkggvncxnYTTMZsc5J9mk4m2VUeEC2mvRTQAlKe3ogYGlU1I8dIZAhara8p1P7xkjlPVFkyUoqUbKqgpoJaCVkMgE4Ni4Aa1kgE0AESOgD2CowGgB6xDuNmVTIlMxLRec4AqFb5WCy5aBmNDF3l2NuUbcpwGmUs8w9vDl0fu1XDHJ+BWmZai4orNQ1MVLXmtYnr1GGozXgJmYAZiAxIa2bfDNXUBESNaptF8wgupzyP8zzOec5g5h0zoqvtBSi2kbab/ubFzctXr7a73Xq96rq27/vgPahJFjFLpaQipWglR6uZqKiKmQAAIc3zs8jwk6j9UlpRlbVqM08vP3z/xZuv1oc7znNxtvSuNFCtPMnqo0Dtu+IcBg+RoWU7lwzgiVCQMkFR0GXot537aQCpUcmuzF6zR5PYatMreTCA2nFxYTCcwSMDvKDxnzYdMJBKl6s0uqqOtU5bF9VSNJGLq/Vq3W8b3959uN8f9nd3d+vtThTbdl1HAKpBziXN8+m4/+6778ScKGkpH969n8fx+9VV0/Quth/evfn7f/y7b373T8PhXkw4BHbkfBiG4c27N2YSm8COm7YJoVutdkRxTtlUFDHPGWAk9i56AKgVeo/noWAZNfk8xZO6fISyz+P9vH+XDnelZDZzPmBoyffseoJIyez+VJwvIj5l6huTxrQYKBE49ARMAAwkBmzAdWpsLeCu2JyZAUpwyLVdV2AGISyIRUTmmcaB9h2NJ5ynWoq/1LvVQOn8GR/pL/zkwbMXEFEIYds2n/fxpcdeJpfFqOSihgqloBTUAqCW0MYJx4n297rdEQY9Hm3MdpztlC2SOQUmVRZDWVroEJqJaJaS8zSO2Tt2nsHQzETQdGlAqkXynHIuKrrU/j7dGfDMrlcebF0hhBe3t7/+1a/+1//5f/mf/vzPf/3LX37++tVmtekCBzvQXZZ7AxrZRrIEgJAR9kkPYGOBQXAwQIAA1jhFT2HH/S2FNQBhGWE+Who1CxY26QCvMbyG7he4+jVSC+hQC4ebVXPlplOS0vQ77m7t7k3lonx6M1QXjLg2IMulSK2jJAJGcj7Epu36ruvbtmsiBFcCu0DeqccSSRJqNiBJqgTqmB0jMlLtt4BmBlJAM0x7mx70+EH/f6T9V5MdSZYmCB6iqmZ2ibuDBktaVZ1FtmdnpluqV/a/j+zT7MvIdlXNTLHOYplRERkRAJxdYqbkkH1Quw4gIjKnZdcEAQEJANfcTPXo+c5H8gyiKIZiWFqfp1gpWpuKdR9vRGdCIEVvyX1DkAChQSl6BuVIIWCtH/NMf+RCB3JgA2KOm6ubV5999ukXX7z89FWIcj4+fvfN3fnuuH99bWn3RuS72m4PZYv4Ke1hQ+nqdWmPp/MSp3r16rWHej4vDBARiFAYCShQIDNUBhMTcRFoFSSjZvJK2BAaeHNtKnX9tnbtP1LXAUAFaoN+KGLGNcsZojsDMFNIIW5i2MQwpZgCByLoTC83dydwRtiA7UGft/yszklzb+dIjVvFtpgmNwXXy/yy76Bo7upgyEirYZ6D
hmGcNleLO4I3KKf6sKBpWN1awA3VQKpLAfuwtNM0bXppZzImNXVAMYMmCEjM3CqRubmJSY/B6cp86FsQAZAjIwCvADcGgoAcEIEsEgfgAcDJjIKbIVICTIAIJKIARZACIANFBmqiWrJzcGV1L03Pi52znxf3LN99d5im8OLV1TBuPlwibqamvtqvWVPt+wExIREiAyAaIWAgigHHgWKkEKlbn5qZKoioiBFiiL1FQq3SOi3OlAmHFGPgIYZINkTY76eXL29+/vOfff7F59fX15vNhvtxkkhFTa2p5Sa1Sa3aDVYd3F3N1Pp4BTnn8uF79QMa3bo+VqoAqw11uTo+XD3exyV7cyNXuojRaf2SEIIzEIIQegAIgMGJMJgPBsmwAAhSZWzu2o02bPXeRXUSCS2HcorzfTrfy3ySqWBiQ/aLLPjy+d4fJC5L5YerxbuqtG9VSIyAhuZgZtLaUusiUwSz3Wb34uaVVj0+HpdlOZ7OBjRudkvO0mReipmW1krNjw93gBxDQvDT4bHV8nD/7t1uP242t7dvb+/ePjze1nwKMabNSEQiupRyPs/EGGJQNWmKQDGMMbga9aVujq01P8+hBbb2vaGPi0v2wgZuBRuK12Mth5pPWjMgeRg87WF7A/tnvH/O0zOkjWbUh6xNdcm03/J+g7tt54RoUOOIEA2CranqSIDeM/QQ+6y8J004ukMAGND3kcCYAF1VPC/2uOPzCeYzSgOA9137+ojMvtfsrtP4J2XS+qP3/w8BRPcN4U0MGwZCdRUtxUHMBdXIdAX+GwIrNIVSVZrR6FldV36KqBmDETn1YKsnb2d0B1OoKq0J4ftf7Bk3qiaitdTWJ4TuAPBxZf/owssyYeZnNzefvP7kL/7sz/7iz//iz/70T//oF798+eLZ9X6XYgzQKINbhZq95WBK7i7o2e1RHMAXAXF0RkZnhCHRfo/7Z7B9DnEEUytnPd/jfMK5WUvqGwtXHq8BJ7QAyIhsCDhsp5SSq5pTGBzTPOeHh8fj8aj6/XLYzEvTc66Pxzky5VIjUxrSuNmMw2bcbHZX+81mGBIxCrmTC7sF8EBIgcHQ1KyqutG4CSESRWK+qAfV1cEMJFs+wPJg5QBtRjM0AxGv1UvWPLcya8tqQuSJCPtSpYTEKcBmsDFy3zBFXd3USL9HfVjZD0+s+NVppB/2keL+5sXrL372/NPPN9cvyunt47vTb3979/Du9Bw/HV8OZXrheyzfnnypx7lsdtubZy+xwXyYW2uUEjthzymrhUKIxIjcBSaIaAYgqiXX5XQ+vDsebs/H2/l0l5dTXo55OS3LaZlPOc/S6kWq+yNXU6wFAViNY2AiVgmuwS0gBIdgxmbUQ2mZKEWOSLQCkk4ADDCojSKx58iZqyODkaOboVbrjNmuoVrP8aambq7m1aE5CLi6iZugG5EAZfFzk7O0ygodIHEAc1chrWQF/f2rhQgxUleMEhChExMhGaCoh0jcVX0IvlrxdC2+AyBiAAxIDMSA7MBIQMQhJOKIIYB5cyEKYZj6fIIIOaSYYhwnERIX6Pb3GJGSGIjVVrLkWTriYiXXcp4tF1QlB1tmOR7L6Zg3m/xhae8hnGogCqVprlKailpP43YAcEJDdGSEIfJ2GzabtMGe2eEiWqu0Jq1pYHKL1M3FzGIM283m5vrKzFNKY4pDCugtsT+72b16+ezF85ub6/2mL2V3dxA1ExUDMauitWqpotL7VidcS3XHS7+nr/ph1/7eQBQBgrQx5+35uD2fqIgqqoIgtG4SCetgiwABQakTYUEDGIMRGACpB/BgYEg1QAMTcEU0RzNcNZmoXHNcDsPxbXv4tl69peEGriJEXiWEH7TmeFnU+AQv/OByEbfeY62JpAQIYu5N6lKWYz6ntpSB07ObF8eHQ2CqpR5P58A0TNs552VZjse5tVhbU2nn9mjmkRMRLueTSj0d7u7vNsM4PtzfnU6HUhbVlsY0DAMx19ZKbU1t4MghgqM0
dUGCSNgIFQgN3BRUpJ7PnHGMOISLbKK/ZOJtNhepC5iAFtVZNZsbIiInSDvbvqCrV37zAq5vaLODMLqynrTNuZ4XzjW1xuCEPe1oADfoxtQXZwUGMMA1sxvAvdMbFQCMEYYAtI3dtR6hidqy2OMejwd4DF7wQ/JGfzr9xPVxUw7QTxD++7p2t2g2AewCpwjGrl6hCVh1E1zHSeAASI4ELArSzMy5eB0MsCJVRysmrkrNAmLqebnwvrkGNPUei23rf/5+yGOmambqH9zPj75eePmeEMcYf/LpZ//xL/7jf/nLv/yf/u//4+efffbs5oYvgKRJ8Xz2+STzAqUFM3QAAciOrmDgVYGJxggJPBnuhvDsGq+f+ebamFGzLUc73YXjic4iddtw53EPaYfOWBsEAlZHQKY07CgEB2rqeal3d3ffffft/f1da98HscU8ix3OBd1rk915mMZxf0UvtylOu/311dXVZppiZEVrICvRmwAZncFNRUvVYmIhMcbtABSRGAgADLSBA7QKdfF8hOUI7YyaCR1BXBerWXJuJZdamhQxSQwYAsXAKWAYCGMabDv5dghj4pAFpGs0fsjgeP8Ll4fia7YBMIewvXr2/JOf7J5/ytP1w+3td7fLb7463L45HrfLiy2nqxejBY/f5FqOh8eU4vObG6pyfji2Vp0YAyCxqpQ8U4xx3PZ/AcG7x7NJa/k8H28fbn939+6b0/H2fLovyzmXuZal1qW1Iq2YFnddId4fXCJQKqlRayGlEEJwD+ABPTpEgyBKDbGQESoRMiEHioyRMTKFPoGoAmLV2C0G8IAUQKODGZCKeVRkxwBIkRIAm5lrM9OmXs2qWXUXbSqtaaugWduxLKeyLFIaCmLEfhRWAxGyxtDwg8MKYoeme8VeFx0QdSo4IoWUKITOOmdzhwB0OTEAAzISA5I5AiAxcUghDRwShgBirIUopHEkNGHjyG4QY0rDABWbCQAgB8QBMIkWaXMtsywzoRJCU81N5sVqJbeApNKgLHI+5fNm+XC7UtMm0hSqeLfizkWK6KXkuK8SL0TTMXFpg9pEDCmymdXalqXU0kQ0hrBS5VTMYUjp6nr38uVzIj5PyzSmMQUpZ/J2tZ+u95vtZhhSIATtO5GBmouZiIq5Gtamrao0FTEiDIGIkYihN2r40ejth7P2p2dFZLpb5ufH++vT42aeublJMFIHNwKPFAYmBkVv6lptjUbFbqSBjXAxKOQCaIiVw4lDIxcycVcHJVIkAwJDNkst2/ygx+/q4zdlusZhco6rF8D3LPz/8OXeqfBsRAxujAGAgTr2zIiA2lo+LSU3M3QkJGxS5+W83W4w4JLnw+Pj7vpmu9s2MUSy1sp8erh/S0h5PrvLfHp8vB9jGE7HY11yn+3HNF4/e7HZ7jqRh3kIaYjDBjj0WFUz75Zn6qZm6xjOPTDG3eAcAd4zgE1cFpGijq1bBLgAKCECR5x2vHvGu+ewfUZpBzSKRZegFNCd3UAF6qGVpqIsEtzQDAZzRkdyZnxP44Q++jZHBTToBnZ96IIAhFNE36J7EK2l2OHR7h98SDaTmz6FvHajATeA7xE6xiGOQ2hqqk+kR7+kUwICRMQt42aIab/hLUMwq4BLAQEQA4ULFRSMARkhIjEDBAiD7zYWo4Rd3RS5PS1LywqtowR4OQj2ecElfNGg13LvpXxFrS/jnfUz/b6rk2Ddiehmf/X5J5/+p//hf/x//OV/+dNf/eqnn32+2+0SByBwMO/uY64OgHHCYQspgTMGR+qcfrDgMBDsA4zoo+OziZ7vfTdpYDLBfMb5yKcjHjPM4LDz8RVuXtP0nHggBwRb1dtISEGN5rm8vb397W+//K//9b/++te/fvPmTa31B7dA5pCb0lLFrKpfUxohQphC2oQQEcxkbqCkLbB5AFBC7Z7bJsvS5kWFncIasmdmrQAYhuqOroatQJkxH7CeQTNiJTKTavWoJUuttbXcZ4fgQGEzDDRGDIFCcAyT
wc0VvbgeX+zHd7lBNgO3npTy0UJfD8L9+ydhu6MTM3FUwyXL4bBwms8ZKkwatgXKu4eTvrv7bDuOm2mzn/LheL5/NzHe7K+GGF1bK1lKcUBiMrP5fOI4xGGDYCbqoi7S8pLPh4fbb+5vv368++bw8GZZDstyXM7nnGfRKtLMmq273Y/40F3ugtTRZV14DsTEhIExEEaiQESG3hxZgVonVIEFV0PVdV+nqlSNBdg4eIoYElgEYEtkEXCEkCwBEnvcGg9uYK1ph8XFmlnzLuuVojJLeSynd8f7x3ys3oyQu5pZ3btOwo3wo4OKu7dW3NyNER3BRFRE3bnXHnMwuMRROwSOPaYLEb2/RT0Tp5vwYHBKwBPEhEzkgizExIkJQBQxMCFTjJQCQSRJgMghuEc1rC2X8liWU1tyJGMCcWhiq4D/Ej1tYnlu51Md4nsv5tr0nEtTKM1Ls1w1V6lNzdzWYSPCOkSy0khAmkmz1qtya5KX0pqaWggq7rlgJy11l9hxHMYxibQhcgzUspaa55mPp/F8msdxDkmJAyIboJp1TYQaAHATLbm2JiJKTDGGEPoL4qsx6wfXDxnyl5YdIZo+Ox0+uXv77PiwyQs3cAsu4KgwAA3EV8kTNVcpIu4i/a5BEStCBpgJZsbiIIiZ4xHHGkjFUSpYUyQhcmAHIvOoDcrBzm+Xw+/i9qbsX8CwAYrvN9ofb6N+hLFlPVbZkATVIgIBAyUOFmMbtBZXX87zfJ7nXJu5E1Vpc56HKYVA5/l4e/dm//yFIaoBcUCvUvP5+AgIpSwEtpyPx5ACpWUp2ho6IvIwba9vXk7bKyBC4hCHmKYQJ8dYzZuZmIlKkyoqqtqKaBNwg0g29v73/f2aumRVNVNnTAET9ehgshBwsx32N2l37ePOMamSNrcKRpECkgOVRc5zyQWkBRPkToFxi4TMgJHQeeU0ugGgg4I3B0WA7jtPF3JoIuIxEEYHrLXeP8jbdzokCwRV8Gl2596tBz7qrxBgsx2WPOSlVZA+5f2IRAcQCbeRp02K11vcDc4OJ8V6dqtQES4bo3mXhyExMybiEcatD3ujjT4L9Srn9u25HGexetn04aImxtW+iIiof6HJ3QyhO0A86Sj/ry68YAAxhE9evfrzX/3pX/7P//n/+Zf/5fnz59NmQ0giimGdOq8LiRNP17B7Bue3RmcODYN5AAL0AL5Fuwq+Zdwgvtjgs8mnYGjQFlqOeD7yaYaz+BJ9uobNZ7R5HaZnTCM5rm6FRICkAudl+fa7t7/+53/6q7/6r3/zN3/z61//+t27dz/s2vtctZnNtUlH8PeklDCMyNEdVGrJGTgTlxjAEoMGVHYwU5FzrnNx2tKwUjS0VVUlJWQCd1CFWqiXdjkTFWQBVPRs7Sw1q5h0twYAARw48DDyELuzIyFtiJ4Bv7rZvLqevj7M5CaCP/KAnri0F+HbSoBg4hg5DqXI3e1jnN7mIqDqcZduXsaTneazvfnu5afPd9tht996Wc73twPCzXZHHExVSq7zieMQArvb+XROw2a7VwC0VqGJt1qWw+lwe/fu67ff/tt8fFeWh9aK1KXk47LMnd/kYAC6so5/rGXvz8ORFMDNSS/cRgpMkTEwRUJ0NDWvCg5mBqLexCJhIGRAcqdmVJXE2TkARYARIQGyR/KRaYtxchyQ2NOV8qgGnXwios28uYmrmZhp1nqs8/38+O54+5iPjQQoQHf9UrcmVlrAbvP8ATHIvdaiYm6EiExQa2tN3JiJEUjVNQBiP7gQdq8M6kjvKte37r7i6Bgck/FInHzl+xIRckAAckIHwhAwBg8EzjQkJOKYVNGqNMu5HnKZW2nGnpjskgeB5OzO5IwAanmu52NJz96X9tLkNJeiXsWreG3WmramIt4JXO4Ahu6OblWweStSl1rGIaTAblZrVXE34GBVug2GhcAcqNQGsC4aNWxN8jIv80m1ucMwbAFDTENMQxxG4tDFhLlUMyAO0qyWWmsTVWYC
N3MmQ+ZA1JuI99cPunZ3RyD3oeYXp8c/+eZf/2///t8+f3yzlSWCE1Eg9Eg2ge4DXUdlsOIoiAxdN+8MjbwgzoAzwsLYEAvSEsdTeHYOu4ZxOt0O852CG66lHRHdkEVSOU+nt/X4XTl/psNOh71fXqzLYvbVVs0//va0+SKOm01t4mYAqOogRoFDmIar7Zg2ddjWOZeST/P5OJ+bCBG7mZQiTQijiIg0AOcQUkrDkCRnVQ3MSNSYTbW1WmsRcYAQ4mbcAAWctjccN+ZcqpbSWqtUaJ6zGoTAOddSW621ltLaotLIKTHHFIcUh0h9xPH+Yahb7V6+0KN7ERwJkJGRGSP7AIpavaEUFPaGpBA9ReIY3II1X1p7d2ymjmrS2BQmpMSGUSmKIwOIIwAYgDgIgnm3oUEGNPqgMI6M15uUb/DhBd6+839PFtDFQf0yOH9aoB92V4jXz7cGZT6XZamltNa61++FrYkQU9juN9PNlq8nnIKbQg0UohObIahD59n6KtsBAifw4BiAtgn3e+HtHPPDm9PhUE4gC5P0kI2LB3Hv2IkIvL+paA5EYOT4Me/v6Yd/YNaeUrq+uvqTP/6Tv/zP//lPf/WrVy9fpXHomQWACH0xgAGyUcRhj9dfQDmqHuxc3R8JKvabcfToMILvB3w24c0WRgaq2I4wH+DxHh+PcKhaWHDvm89o/zOML1xHM4LgPXAcQ2xid3d3v/nyq7/5m//9b//u7//pn/7bl19+eXd3V2u17yuUIAUcAiIgr2aNhEjmXmqd54UMrbUpLhBKiMUimgeToBzc1KW1XFtVHEcI2FCtLV6zuRM5EiAYmlKtUGbIJ9TZQqOoNKBpUcluDRFD5CnECKaom4ljcEIBWM11Ag3bMTy/ip88216/PSU4iYE6/hh4hxeWPAKQIQFHjCNv9nHaLaV9+/U359P56vrq6vn1NI2vfvaTuN19/dXXMj+Wh4cRr4dxU6bN+eHxFM5V0SkisUnNh/tpd7UZh9ZknudhmvclE2DLxWuRkk+Pd4eHt8v5wXUZokeMrRmT1EwlP9k72EXA8wRW/eAGAnDqv90p7ADIa9I8osMa6upo5qYGqqhKFkiYAhIBogOKkToZslMACEgNcMAQaQy0obAn3kA04gTDNfOojmLd1EOaeTNtrqJVW3k4n98e7t8d7h/Oh7nNkoyVQBwAQNSaaVMB6+qND1dNq+uZHhHcXMTNHBEiI5GZijQGAGMnZAB4MjnrwgA0QwdmJiADBo5AZNj3pnUIQSEAGMVk7sjBmRXcCSlyiDHEQRQNiebVZBsDO4AjD+MmMRmCuigouJtazuX+7iRC19f+pCM5z/Xd/VzVq4J0wqm4aqfIgcIF23NHcDIQp2bezIeiQ1wZu6ZmBohK2Fa7G3QzPZ0OD48P93d3x8OB0BG85LO2uixjrabOj8dlGMZps7m6up422xijQ5c1uJqaGrgTIQMRkyOIiIkhVkKq7Q/R6Hp6CLLbPp8+v/vmz7/69X/87T989vjdVksEI8Jum4lb1D35FYmDCBACMRKhEyp5I8uAC8CCmBEbU2Za4ngebu53n+W0vwkRvZhUNzNgR15dbNRCzcN8Nx7fLqc7mZ57mIzC08f7vq69s72+R9pC3Ox2obZWm5mZuIMz4JDGq92e9ib75f7Nd/fv3p7m0+PpWFojZFFrubTSCNnNED1EHoY0jsMwDAszOIzTxMxSzyU3VVV1wMiJ43RNaUojT9O1QSjNVdqy5Jxn0WZuY6sxpJKXpdRSSi25ldmt7ba73Xacpk1MAaEh6kcgRBcaO1DPeexEdHNyIiBQ1hLqmRHdavDWYAIDMHZFGjlSZBytmZdT1lqhuUoACM6OASgqByBgh4bUfbjFQeDJpRcAiRA6DILuHgl3Y2hX+OlzfPvM91NL7AXeUyC6+mPNC3r/OODZyx2FNhxjPOb5nJelQgERXS0PEWMK034arzbhekR2z607hDqSe5cR9db6qfPpGjshUnoR
aLvRzc0Z8tvt9JhiU5gZWyRDdLGuqrpM3C9NOyJ3vQW6oRN5txf8A437kys0Ak7T9Orlq1/96lf/6X/+Tz//2c82m40hqDsS0LpIEN0AgtHowzU/+wlAVTzAsXhRaicS6yoPT6wpwG6Dz5/79QYGBq9UK5wf/HCPj2c4isiuxee4+4Kufu70TGtwBidnoE4JnvPy5Vdf//Xf/PX/8r/8v/7u7/7u9vbdPM/dMeSHN5IYhoAIxITMHGPosHPO+RQMGnhrkHIYqlpVJ/OozA0YVa21KtIMGR0jVG9QTiCG0tCFej6zG7fqNftyBMmYhEZgZ1MVK+oKFEIKgYMTAGkMxtBAzczQFQEwaQqbqym8vJquxpSQS4dYPr4bvNh1Xzhe6BQoTTjtx6vn4/65Otzf3j7evpvG4bNf/vzTn//0xRef3rx6fjrd3X33Lj/cjxximuK0K9+9O51zL+0Uo7eSH+6GwNubV25+nI/jcG7LQgAtZy25LvPpcHc43EqbYwCOiQBKIURb5pgDu6tozzd+ogT9+KtFASj1XgUvKmp0IPfVa84B3RXBGFXRBVGZ1ChxCAS9/qM6GpAjAzJiABJghTTgFGkLtAHeMIcQR4w7w9Sl4GLei3VVKdZqLaXMt8fH7+7fvXu8O82n4sURsbE3BUdv6lWtWnMD9A+0bwC+SuEIe7ykdSUUghM7uKgUADdj4kDISEwYVqc3JOjeYgDEEZDFyZGByaFjdoQcMTCEgGAcB3BHZgdXUEfgyCHFOAwgKI6BmQlDjO4jqSOFYXMVUhSXJqWYWmuiuiwGd8dS/Jd/9P7RnBcxyM36lnOheDm6o10CTvtsEAHIQBzEQQxr8BKoU6F9nfSZr27TUFvJeT4cHh4e7x/u746HR5XmpggemHLR0rxUv7s/DcOw3+9evCjPnj3b7nYxxn6SMBNTJ8TAXV2PQCCitVV3Q4Ba/0BpBwRHdtvU5adv//3Pf/v3/+Hrf/zp7W+v8yGCMimyGRsnw4F9NB+kKFRWZucAimiR1L0oZvXFsSAWxsZcI9cQahzn/afH/WdEGLylw5uwnBqhEfuqngI0DS3Hcgzne9o84rBHChcw61JDLg4pvahf3HPeL/jtbhtrq0uW1jrJUNUhbRIP22nk/Y0s+f7dm2WZHw8P3pqptZJrzpvtPsQESCGEEDhGSimmNKQQncL1fh9iqPmo2hADxzFOW6A0VVWtKRJiXM4LYBOFZT7m+YCMrZzzPIaQWq0lzzWfpC2BfH+1/8UvfvH69WtRm5f58PBuWebddv9+sRiarG2gE/oTCU1Biy2HYupx5jhhTD5sdLPTzd431aeNl0ETAQNBiD7AIv72jM2giL1UfIWEjDE4gCCAcwNCR3M0RAbgC5V9zdTDSxQHu08Mz3fh02fp0+f6cNfy4qKwJmT3+bwifRwrcP1yx6MNm2HcDOfjcD7l0ynnpdbazJyIgNg4YOQUKUoO5wPNZyzVVfGp53m6fQNo5t2pOzRWAKSC+Gj+ndijWHSoQEbklwE7fMgkRURAQuLu/uYAgJcquNrIXj74hYf3fvtyACCiF8+e//Ev/ujnP/n561evp2mzzunRO1rYMwfd2SE4TzBeA0VKFDbmh40/7jTfgwmqopsO3MbRh1c0fc5pYHbUs+dHP89+PtvSvLGl57D5GV//nHZfFN1I9TQRB0YOpcq7uzf/9pvf/tVf/fVf/c3f/PrX/+3Nmze5ZJUP992Pa4kbgzFTCJxSGocUmRBcpbbiDUFZjcVUTE3FBaCKO5g1kVqbg1II7mZmbbEmVBu1RlLQhMgILJpSKzafrWWPQg0jBgevWtTUBcEIOQABkoEqNu3TSwInRBZRNhcjMELgbgXi37PS/GDksf40EA+bq5e7V59fffL57vnLVrXMS3k8WMmPb+9iSNfPnm2ur/Y3V8vjQz2fzhTH6xsOQ2tacm61EWEcUtO8nB42m014/gkalHnJ4VgOx0AotdS85HwutSg4pSHR3jVr
o9Z1MMjEjA1XVchKAIBVMfyDy93ctb+WHfUSFfBmaIRKwNAVK935A4wAGpMoN/bY4aHuct5fOMCAZESEHQZHrNbaomhpmEaPmygDCPTHB67gRdrSytyW83w6HB/f3r797t13j4f71op7c0cFFENEcnOrYqKmpuDyIQKMwGvmuvnKKHJ3N21mZ1FnaSF0TjE7BOIx8BgSmK2yt17jjQiRHYMT+0WrZw4O7BgMiBCBIgE4UVd/4aqMAKSe3Q3Ye7lhE9MErbED8YCcQtilWLWZmACYiM5LcywfHuhrcydT7w06+Crn6Sco0IvkGsHfR7c6urq6NQW98Hs65q6mrdZcynk+neZjKTmX5VxaFlcxMyPAiJiQDamp5Vxyzsuy5FyOx+PV9dV2uxvHkTnAJS6eCQk6C9tMVUXNFNxV/kCoqyMaJKlXp4dffPOvf/Gbv/3ld//06vC7jTeC1fcd2DB4SAYJITRDCEEpOCdCIh9IxEu27J4FcsBK3JhrCI1D47RMz4/PfsEEEdpVq2OekdgDG/SDW3BEshbqHM73vHmg7XOLI1BcV8cP4PcfrhZEnIYhELJpcQXRUpuoaZrQYBo222lz2L5DwlLz6XgMSO7QgfJ9KZNBb2W6F0EIHGMMISDgfr9NKZ0OY60VMHCa4rilOI7i2mpEA9fz6dRBs5wPJT8AWCYKYQicRFRaFSkIenWz/+yTT/78z/70pz/72Zt397/73deP9++WJbv5k0u2O7j22RO59xyAjk+5VV0OOS81pMCJQoS0sc21bWbfFtjsfBxgGmBMHjBQDE3tdNbTrIezZgUKTAFShNB5c47OnRHVX2KClfrucPHZWP2gzSPBzcSf3KQvXtndO7i/L7m5A6/DRAdS+DDIAxGvXuzSFsbNMG3HaZvHxyXEcOQZTi7Sc5pIETtcFErm4yMeT75kqLIGX/TOe/XNxq47BTMUQEBHWpo8lvqmtqPotWFnbl421adxzvuPBABEzLwGYKsprBY166mEiD6MfYNLXUfAEMKr56/+5Jd//JPPvrjeXUeOoh0tXlWv61bn5BCARxgZxj3uNvF60uNVm0af36E3cgETRSuEwq8BfjLAOIGSvoF89oLeSG0wmmD7GT37Y775BW0/9bNrVSOCGIDjMi+//e2///Vf/fX/+r/+v//u7//+7bt38zz3e/xx7zyAXiECUQo8DWkaUgyMYCZNxU3R1FzN1U1dwasBoItRK5JLc2YcKBqoikpWzVwLl4ItozQiYzJAZ5F2nqVlYyHFFBIwFqlNmhVxhZVPRc6kBbULILo9bEhmAcqC0pqZ9Rr+PYtsAOiU4L7feTcR4WGzf/His58/+8lPr16/arXl4+n07Zvz29vl8Xzr3332y19sr6+2u91mO9V5BqNhf0MUTE1KbUtOY0pjkgz5dKjbK1R31ZZzoVAfH51ZtdW65DI3E+BAuMEQ6oJVW65YqosCdKOEC8OvE84udL/vPxEzU7N+CgXXXjQNkIkI1wrt3s/Zhu7k1phEQwuQAhJSr2sMEJC0h3X0KudUm4nmU5uL0bS53isJDltDJkIANasmSyvH5XRcjg/Hh9u727e3b9/evT3nk1kDUDfvosUeqWtVTNx7U/sxiEKMrt61J119rOYi1axxaDEU5oQeRViEibYWzByDEbA7cb9h7F8q7KM0h9UHExzIHc0QGIFCVzfaKqRzArsMbN3dwIzcQ5ooDFazS0OO7oF5k6JoaKCtayqtGZaPKqJYj5iCtcw8bWW06iL8MjYhhEAUiWI3Buk0o+4rj95b99rq8XR8eHw4HA+n5dyH1uLk3G2YtecA8TDFcRNSBISSl2WVaB1Pp+vr6+ubm5tpmogDcyAOveVwW2Xtbu4dIfgDbnRsMpb5xeO3P//u3/7oq3/8xbf//Oz0LtlC6EambM5dvgNMTms5xrBhBjQCcWpEjax2Vwu0RtQCtcg1cEVs7o1CHXbHF39EacDaoFW3ZsQtbjSMxKFLKFjKuDy0
5VbLKxu3NnQvl8sCx9+neuuXgyzeqstiLUutZc5LLuA6jOM4TuO0B07dcN/NwjhyGpY8y3JysEstZ5NS80m1dQIXOQTAgXlMw5BS1T4Mcw48biYp5K22Uks+1SatWa2HWu/dlRCZI1HsTOwYeL/b/vLnP//zP/uz//Anv7p+dvN4mGtt85xPx8U+sQ/vw+3i9AfvVbzuAAYm3cjJqVJhL9lKtSX7kmHcYBxwHGAcPQUKzJahnYmaJCgPlWfwDAlDBMINYZ/6IJI7rdvlikNZZ86v1R2RAAJAYr/ewi9+MhyO8Pa25qYL0oqXOq5atQ8qyWY/hdFDDGmM0yZMEw8DDgmZPM9N1QkR1Nucz3fHZT7IwzmcF6zFilgzFEddG6CVhk8AA8WbXXh1g9d7iXxalsPpWFWVSRwVEQzeC9dXdOfSj/cVxt5RemYTRSJocBG3IMYYYgz8YXV3AIcYw3bavH75+hdf/OLF9QvGYOZdUwhE0PNyOulezY2IB6RE5IgDTgOlkWj0co9YEYVQrZRlXu6W+Phlul6uPv9iv+GbmG5o84nVO6SCu0D7P6brv+Cbn8D4jKBRqjCwc3SiWtt337357W++/Orrr29vb2spT0OF3z9c6JRCYg4xxtjvEzujUM1QVUW0oRZXZgCy6mjic5Y5SxynTYhW0eYmWltTqplqhlpQCpEFcg0U3KVWKdKwgsEQHSPV2lqtsmSrumoheQ1xQgACDAiBmCIL0d0j3h3knJvYj9+Jr+etC4rNBBwhROAgteXjOYS42+w2n4ay2b/95pta692be6QAEMbN/uF4a8tsTQJgYkKT5XjvumEORGHJZT6dl9NRS2N3q3U+HFMMiCAqat5j5cxMFXODpdpcZC7d9EguCZsfiPJ+rK4DgAlpubxmrgSVCRm1R0ZzD95bz7fdUa5b11mPHeumbwQYaBWJE5IgE7A7llJPtd2ejqcm+93NzXJear7aXqeQmEjdSy3H+fh4erg73N4+3N093D8cH4/lLNYcDbo3ipqWntOJ3h0mun/LB6vD3Zcs4Ea8AvJNrHPQ3OFmN7x4mTZTSpFq8WWRmnOrYI2aEYYBA4CBMYILqANUAGQioDWT2gDNgxobUH8fVk6PuambG4mjQCutznMri2njNMSYDNy65aYqI6WYbJgASqvVTIm6UdiHi8MAdTVIWB8Xed+KwQicCAJTIIqEiXFMITJRL+a9xyc0N1EteX54vL+7v79/fKhN1DEOY4gJOHGaQmBmdofA9Gy/vdlN2xTQ9XQ8LMvsarWU8+mIYG6Se1DZMKZh5BCQuN8/OjASMjAR0Uc38nFp1zaW06vbr37+1T/84ptff3H75b48IDVFaOwtmAUDAiIYAjA7kDoF3DAgi0FtWAyLYyWv7NWxBZTILXANoTpVQEHWOM1XL3z3nOcD5FPM9+jaxpuWrpwjIgQvgBRlHspjrQdtVx6TU1x35XVU+lR+Lqv8w7dMFpdqkk2KSWnlPB8PahqHabu/2d+8cgocByR2t5CGYRj58U5N3b0LFRixlZzPR2nNHVQN3EEtAI4xjjG17jBlFUFjYrBQpbXW8nwuZWmt1XZs7QCuiIAUCLm3glO6ut5vfvGzn/35n/7ZF1980c/C87ycTsv5vHyYueu9Z6aPnHTXeTN220o3BxVz8Fb6WvLaII5IEVJPt08UhmiNZEbJBk0emp7RNVAcAzAw44DOACtVb/U1uKjQcR0zdfd5WJsD32/4p5/Hx6N9+RUfF23ZqpErwWoo9MFKARjGxMkC0zCwbMI08ThgZACzI+acJRCiW13K8Vbm+dgel1gKSuuiHDIAWwF/NVAET0gx8rMdv7qxaSzmp9PpfDo5OIVgDrquy4/RnffYGuAaWefkZGZIDoBm6EYAyMzjlMYxfc+lFRFjTLvt7tWLl198+vnV7godTd3RsBNxDEBdXXt8UafSEyNEAh4dthA3RFuTA1ADVmSz87m8vf/u68O/fDO/PCKEZ69vnl0P
r3l7JDrCVtCYN1/w7hcwfqa8Uz3XLiyqsImjqp1P54eHh8fHx3mewR0R/XvRzR9fiL1jJiYOIYYQiXjNUDJTdVGTptWN1ZDdCE20Fj1mORXfw0hTgAJmVURqqVAXaBlqAS0MFthtiJFIi7SqxcXNW0IcSKq03OqctVSw7mu5RmJ3m9FIGEmdOQN9d+9vHuS41GZmQD/KMHd4mmUjIkMIjthEzo+Hlstud7Xb7Xc3N76/yrncvXv38PZeBYaIw7Aze1dbllIIMRGp1vPDrWljZqJQa5vP5/l4UAFGAJV8PnlKHEM37wMkJBLzUnQpbV7qnMuSO4NZ3P2DdYBPa/eHKKMKSQ2rT5IrYmXyyNGMAgdnQyAz7CEc0C0iGIBs/eZg5ozkAESk2F1pqAGJ+WEpt6fjNw/vHpfz9XI611y15ZrHYRM4unuu+fH4cPd4+/buzd3j3cPxMJcs1lYKbf/s6ia67gYdIohMMWB4X0vcoRRB9Aio5k20ipYmph4Cbffxk882+31MEZZZzkd5vK+ngzcJ2rg3SbDm0Dlg6w2oUURi504hRldQESBq6uvhxhTX/B0DNghWcy7zoZVZVRIih0johgBibsqIMUQfBvBRdTEQYgwcP14dxqRrOmlvJ8zMQETAhAkC0sgpBYyEQ8AxYQqMCOrW1Lt2z0Sa5PP8cPvw5u7h4Xg8IQ9p3IY4xGHi5Ml9mqZhGNWMEW/2081u2KWA1rpAIi+zq7Za5zO4qbQqmzaZAnj0RCG6QycHIXQ3QeCPN6uPSju5RlmeHb/77PY3z8/fbvWRudWAmWAhWAg9wC7AdoK0Q99iCyCEQlwbLQbnYnPV2e1MlgM0Rk2oESVQJSo+5LCRMHocbbwq0/7w6Z+Lw+buN3G51zBq2rXhyuJAIEjuHAk0lYPlBw/Jw7TCHB+AJQCA3U7so4cDFIg9BI1BPSgQFwSUWk6nx3k+NRWOabO/ckOp1dSIIzEBeKt5mc+tNhXL55mBtCkBuVoTqTlLSonTEOPxPMtykNM9u4tTExWtTXLNp5rnJlW1+JPUn3rEOBHFIfFuO+62w2aM6FpKzcuyLDnnnEv5uNPqrjtgbgS0Nu6d0vEEW/gF5EPqukMTbBWses7Ks3AQDOoevAUXAo9ytPrGnIoa5Dq4DTfXvN0CcUcOfQ0bX4lnQMQO0EnqRMyEgX0a5ObKP3ttv/xpWirmb5uYmYdu2PohudwBeqYuM/IUYaCUMLKjm6kRYTjkaeAYwKWVU8lLzkWieefPGrk/KdR7EDohpcibkcZooPPDw8G5HArObcOBptG9dmckA3C6xMrDk5cGwXuvBAdwIlRdZQAcKIYwDGmzG6dpCOF7DoGQUtrt9vv91Xa7jzHZiqx0JwBYfb0vdb1H8TgyM3VHDseNpxfOG6MGATCQ+aKbZ7fnX//9P/7b8JvHN2+WP/0PP/nVH392s/3luOFg5goYd5CuMF6L+Hd377793e9qWTbT+Ce/+tU0TZ99/vlPfvKT/d//HROp6u9v1t8v9O6IjZ3xRwzY8ydAzJp6FauX8uHNi2vJdZ7LuXoxhuhTc0CRKlJrLRladSneCmhltMgA5pFZm0nTrGrmbVYSEzHJVotrXqVEhu5I4EzgASAwMmEBP4h8eStf35WHuVZzc7ugNj9cHwBIiIwx0pCa6+n0uByOJHgcp83+6vlnn26vroar3VZambPmdy9f3gRMgNxaOd3fByQwd9XH+3eisr9+yZgAuNZ6erwjSoyO6E2FPCCh95lGDxMqJS+n+Xw4nx/zcqx1Vm393I3YJ7X+RJL/0Vm7CrVCZtAlmAhKBEoWAoegTA2AzLp5RIdb2GxtI9WtU+8SMSMqeReUgEN3Pj/U/LCcj8t8zgtwAAqI3JptxhxjMtd5Ob+9fXP38O7hcH9azkstYurkAGBrwBm+/+x94RAyM6f0vYkVM6tqrdLEcmnzLDnrdju8eLF7/cnuxYtpGgMRBm5Mzc3ccV6klEUV3Aw5EQVX
BYDui44MyG5GyIAc3UgFDSzXpl0T74rdnN8B2JFaXo7LfGsyrw0VDw4Ehm4VrBE6MVJIFCeqC5kRAn6gNwYARk9kgS0whMBEQcRr1bnVJi0EHkIcGSIjgZELGJhK1wCotlLbUsvxdH44HB4Oh8PxVEXjMMY4xbRlSm6ESByIw0RhdDVAMOiABAaiEGJKyVRcJTAFRkZw01YzgIvUmIYQB0S2zuzr3hYAqr8fkEfXoPVqvn91+N1NuR3xjAlKwFuiO/QTYGD4NGoa3bbgE2aAhmiRaqHF/Vz8PMvMvmygMDYmCaQRG1NFLpgqjUaMCB6SDOPp1R9XTi0O4/3XZAo8tPFGhq0xIHqEBsyh5ZhPErcGZNyfwdNSsRXg6jSrD+6DmNkhRIiGqhhTYg6qspwP8/lYckbkzWaX51yXubXyNBErZTkdH2utZp7ngoZgzhTAQaQtSx7T0FMCXIssmg93JlIoNQcVrXUu+VzLWVUMZJ1XO4KbERABujFBihi4dxcl55ZzLqXUVltrH04anoS874emqw9Wl1P4pd4BdLMwYAZEJxNv3VaNwNGcuq0MEhLS0LIs92KuIsWMGII6E9HAwKFbbfV/alUtrCxfFzRDImAI5uNgGOT1C/zlT4bDCd89trxAXQfi3yOfgZRi2Lp9UoxhSBTQwUybgTuBJ+IYAKzVpZVSq2oj4MAAYJ1ki5fhAAEy8TSEzQiRmtbTu+WQrWTEjAOwx7hw6/J8J0BHYuxBBr527E+OWStubRfxGyFyoGkzbLfjbj9N08jhR7r2zWa7mbbjMBIFc0C7xGOKO4j56rNLvRd2cgBk7nIDh+S8Axoc1QJCjD40mnbFvn33rqjMLYft5pPPvnixffbZ5uqKHLvGyBGcQ5Plzbvv/umff306HK/2+1effvr5F1988sknn3/xxdV+H0Kwp6SR3zeyeq/x7wpqXvVTPfnNoKk3hUpdtYJV3avMcz2fS1ZUgmFjS1UwD25ScivZpLo0lwraGC0ygmNkt+ZNbBEztRqEGpm5ZK8FpaIpqLt0eqgDIYQ+ryU8Nn+3tC/v6zeP5ZBFVizoe64vT3g8AhJyjONm2u9jiqaiZ7FzW+g0Py48bHizTbv9DuD2y+/K+Xy92QxjQGQ1O97fR0AQc9XT4x047LYvGSJT1NZOh/shTUSDA4o2tsCQrItXRFpZynJazo/z6WE+P7RylLaYtcsMCz4AdX/vecuEpLFq59NZR32VTIMKE1FPFkMHRCJCZuoRYKSm4kxICIQhBWdbKXluZoDQ1M61nkpepFURrjUsC9NJDYtIDFG0Hk+P37z95v7h9pzPTZqBA2EHq1YSmT9ZBqykMkIgJoqMH7SJXXyiqrVqLm1ZSqluRpvN8MmnVy9f7vb7FAJ3uygzr0VrA3UVq9ZQxcjUKTp2M1frwBqaG3UZe+wjf1HJOa+lHYDA+jARm4PXmo91uQPJkQgoAia82N3Ciksih0Q6Eg+oSuRE6cPtisETWeIO7XDgUKrNptVVtQXGiBAQQpclmomYKiBC05ZLPs3z4+l0d/94e/845yLqcZjGzSalbQgbczQFYkaMANE99NOqGjVxUUcCJO7V3ZUiY4ocAzO5mdTqIq3VyrESRVytCFfJrekfoNEBMuCkstcyUY1DgxHyEL7G4SuIFWxHbc/5KrUF3d2PgMLEEy4Nl2gLQjYsBIWwRmqMEllCKBgWj60bGpczn+90nCzFNl3Zi59ZiGX3ajje8TKDO0rxMGlMxjsIyXkEc6qFIHgC71lyfql7Dj8Uv60bWD/fcogJUhpjGqxkycvy8HB4e4tojMFam8/zNO3qtHMzQFzy2Q8kqhwGaV68bcZxGjchpn4471yPKqJmVZa7+7c8zxKSE6NDng+5nKVmB0O88Po7fdPN1BvgMi+Hw/HxcH88PqRxNEM1V704r390K++Zc5efXtZY9xDopBNYCaLMTmEVXa0UTmNAx972AjoTIihgUbo79VF9zdkOS1wkvjDe
E4V0iYqgy2RwFdigE4A5EoSIyA7arnf8k9fx9o6/+o5mh0d39Qt/6IObyKdZIYcYcErjMIUhgarUIeemJsQOYkjq3bsNzQJYJIgECChKT4ncBojgiWg7+W5TGZY8H+7m40FKHaUFUZSmzVXIPSAxpkCoRkRN1NsH3vC4Hgd7ooOZmzkHTinsr6ar6+20ScMQ+eO+BAAIiSlQT6FAMsDOyi1VpEmnBcfIIbAiCCGHAERsiEruCg4EAQm6ZzpCiBS2I/7k9U//pz/7H0z02fPrL15/up22IRCQmXXrzf7M1cp8un/79puvjo/H9vzFcj67KjPHGAMHpu5u/mML4sNbACdYGUCBkOiCFzu4oTqKuTgWNVG3IpJrzdoqCaADtirLki1gANWatRQXcRGX5iaMpowI1NhNrDaZaxP06EiJgVybNwUxFFsNC8AN3LocSh2L09tZf3eQr4/toVhxRMKLFuZ7l68Z0cwUx+vnLz/52R9N+xuOYz2W/JjViOIY0w5hjCMy8Dk9LMfczgtq6MmUy+N9MUgxoLvWLCVD04A8pEHavJzuNeWUdjGOGKIYY6PmWmo5nR6Oh7vz8f58ul/mx5qPKotZg+4p22mlfX52UVD96OMwA1U3fS+SA1j9EUXtQpzttTMQmZGsFqTOCWLgGDiZczclJTBDRXBDbKpFm5giYggxhMQcAVDVaqut1dP58HC47/26aOuZYsjETO5uqk9QF3SWKToxcSAiuATZvd+bzIqZ9rJ9PrcQ482z7etPrl693u32EdBUVMXy0pZFanNzSkNCDstiuRSTqo27zYkDIKG7IAXnQDAAoQgYruipmSMzIRICMzOxipvMLZ9NZgYgGogSYOrbX89t8Q5fM8cYkPtEikNIH6KM6EAGATAiBQdSQxEUZcBAgYjBUcTWaRICuJlqa+V8Pj0+PjweD4fjec4tV3VkCiPxRLxBHGAVO4AZoIJKn38iEbp1o2vgrsFhijFgoBQoBY6ROTCF7u9C7lZLRmxIkTkQx6dZyYfv1cddOyA6RLfRJQbljdoedYKjxTsbO1O/cC0RiUABjkDGFAfOAywJlgCZqAQokVoiYZTAjUOxUIzVnCXH+U4Pv6vD6Gn0NMn+haaxbZ5p/CY9vgn1iKAOaBRbGCEMBNHUIS9oCIAQEmD/zL1175u1f3wXQETsbsTO4IFSSnEYRURqLfP5/HA3TENkBpNyPiznQ843ZoqEtRQxSzFxiFKbi27HKaYhpgRE55wNIcZYWxO12qodHzEvECfiQIg1n1vNZrLusfQ0Qbh8eQFb0/N5eXx8vLu/jcPGMJZSVbQfiP/AjtwL9NPf9X4X7z1px5r4UggML/9il9xc9m8iRxTHY7HatImci2TThlYxQgwboIGQAyARuPd5U+fyeY+oAGBGNAS03YSvn8NnL/H1DT4UP8+uDj+8ibpk8WxjDIHAMQSGIY6btN1Ht4EZJItXcTV3M3KJYJE8MFZ9Eg/1kxIiAKOP0aZU0edST4/n010rzZqmiljBG2iXTIVIiMhKzI4F3UXNLtB8F4u+r+sAEAJPm7S/mq6up5RCCJ1H/NGFiERMyIirdbMDmlkpNS/FVQkRPXUuCFIflTCSsIKDE2FkZkQkAyIADgSbxF988rn+uTj4dj+9/PTFbrOJgcClNavFEYkQwWU+Hh9ub2/ffDfPmYjevX272+/v7u5Ox6NIg/fF7w/VdgQg9IAYuyE5wsqjvHArzKGZqYJmaUttcwEFAnZCRGjSlrxYAAb1VrQUaOpiZs1VGV0Z3TCSqnppcsytuQdDHpxixwZQncRJXXvUA3k/dnI2Ogr8bpavjvJ2sbO6ACIB2I9z0NbmmJjisNldP3/+arO7Jk6Zy8JFPSCPw7BnHIchQojbaaN8aPNZCrg0MjufDigerq8R3Gq1UrTWflqSZsv5VPMyTjIMG0uDQiuSm8pSluPj3fHh3Tw/Lsuh1ZO0Bbz+nrr+0af93mV2qetPvrld3KAOuiJg
7j3fW4E7KdSAFJzJBcEZySCqm2gfQBmCK3hTFW1mSoghxBgih4CI6lZblVbvH+/uH++P51NuBaE75awMOe9IUd84fAXuECFEDIk9oKF/XEvcvbmZ6ZrUvtmGV693r17vbp6Nw0Du1n04c27zWZYFWmMO4xg6mUirWzMwZfPuiI5uDTmAJwcHQkNDC02klcUciGPPRuuHEZWibZY6g1akSBTdgzm5gbqbqVnraY5EwB1DJaQQOKYPD42rEM8c1NVFXVsRrYIOTERA5t5EQDvUBeZaazmfT4/397e3bw+H03kp5kxxjGlgHoknpAEoAvJ6FgU0A1M38RCIERm7Pz24A3GIKTEDgw+RExP341TgLhYQ6ab0BmjuFgCoH8/99wPy4A5gCtAY2wQSKbwO0z69PFOdXYuOqszWAiCzBq7EEkOlcA54HmDZYFYtidqWJaAgNqSCnJ2KAmoZ/N4PXxJpzwDX3Usf98CDbl9kHNr2WTq/43J0adbUvLmREriAgXpSNMDBPRHC2pk+te8frh9EmMYoyoiCQG4WYxrG0cyIAwcwOweGlPAYXb2VuuR8VtMQYhORWiMHDNTqUqWVzRhi5CHxMFQVm+cQg2oTEVFVq+zIFL1bSjZxNyJOQ2RmIhIVESEEZhrTMAzjkAamcDyef/e7b5cKTunu7jbnxdUJ+WMsexW1I74XVnb0Edfh8+XeCZDAe/CSm/awVXdmx66dQ3B0xO7Nhg4oyAZ0W7U8SKN2bpJlAB8+eZ1iQOf+91Mn8SEReKDQyWaA4KDIxgPjbmjPd/bZC3p3tndZs/xYe4Lde81MVVWZABA50riJ5kqMJVZZkAposwpUgapBKhqy0tLJImACqAAMztC5wgVwMZyVlkpL1qKtEjZ2De5IoRvRBLLIashMjt4qSIc+12Z9bY+IgIiGMUybNG3iOIWnbJnvXYQUODAHJkZid6xNm9a8lFZqCiGGFDgQUE9bUDVZ8vG8EIdpux2HREzYmb8EiMQEKdLNzbX/4qfEtNlO26vNdrsJjC52PJ4fHk9MkTlKbe/e3P3u2/s3dydmOi3l//zbv/2Xf/u3b7/99p//5Z/fvHnbWrOn6cLvvwiNwQm0pz0wOqNH8oAQCQMHRK/iZW7Hh5MsCzQdOEwDowOYttZKQRAPoN6y1mrSY0vF1RhBEFWUiUS9VD3OrZhjBZo4TEQM7mhECtQzhAkJMDol53AsdlvaXZWDtAygvZL56uj1/b69/woiEGNIrdnx8XQ+ZlNzD0jDMG2mzVWKmwhpDBtOqVxf6fH+fHjMeTYULYu1BQ1MNwiOalZLmR9DSu6iWpfzwR3m8zkOQ0oDhQDEpdWcc493a3UWWVwzeR+xXwIVLqO09YN+v8Z/7z5WKQddDCRWkvA6kAN0dFuhJiTjaHHwGC0EZVIA7UHWsmrfAQEUvKmom5MjI1/2DSR0sFyWZTkfTofzfBJp4OC0mjohIXTgmpmwu8ev/HwHD5FjCsqsRPRhs4sQE4pYawpgm2188Wr67IvtzfMhBL7IF8kBRCxnOZ8lZ0xjGIY4DEMKKUcri+aitbbVqh0CXXjD5u5SnUjEWm0AZOzEXTBHAGhtkXKAWsiAMLgGaWo5u/VpUXOttPpqo0kFVyQgJmL+8LUiQmZsorXWNTBKfR1pA5qLS7fmNlVp0kotp/Pp8fHhcDicjsfa1CBwSBQmjCNwAo5AATkAh950gxMiASERxoBjwk2iKWJADETjNiJOJgVBU3cb7krGS+w2Bw4aVF3FTFtTIWZiNvv9gLwhCNI8jofd1bFt5+E4vgx0HW8O4EHaqWFrIVgbEYagKTbmFqJDnBGWwfMWC1EN2KYgBOpQHbNRMWwCIDX5Ec/G0ChEYm7aRFXHnYWxXY0y7W3ahNMdnR+hFQVyW7OG3AWgQqjIA/TA+PcV/QdrBnFMUcT6ON5UQgxpGNxhGNJ2N2wn244ahxCDmbdS53k+uCsCmIqqWopO2NoiJS95N+AGQ+AY6zLX
VqIGAAck5mhdF6ddNWT9izsO49X1zThOIcRc85JncGGC3Wa73WwIOcaYl/Lm7bvTIgrh9u2b5Xw0FcIf69qfVjushM3eu65Uuh603Rsx6vluau5u2B0oCcCgy3+tCwz6yFkQgVmdSkN9bNXakHCbaEi83QTifh5Y//0LKE/9QLvuzRAi2zTIzd4/eU5fP8Bwr2cB9e9TmYmZkKBHINVG6IAeAg9TQoQuAinACCCu1Sk7lapDUZwN8qWbEUAFJAR2KyK1ZudFfBZe1GrzptIYBcAYu/w3EIeBMUTHkHID9BmKZVVfZaDdxYn7aZxgHMM0xXEMKXGvOj+2C3fGlTXRJkpsTVtptTYxMyKOKfbbRUBVyUueczmfc0zjJ599HuNAzgR4OWYBIjDjdrfpNknjNKUhhMCOVmu5u3387VdfmSNCKHN59+bdv/7m62/f3m02UzW7Pzy2Vu/ubt+8eXt7dyvy+2RiH119EgDvnTecCQJhIAxIBKSipcnjodzdnqWU6L4dwLCLbiyK1IpI6iDeitSiat1cE8wDkROBGruLWhErYrm5ioBYMKIBiQmBHAmoHz0DcDSMDcPB6m3Tg8gM2gg9IDitcy37saNWvyMOIY3mOJ/nlkuZ55C24/Ym8AYnAHVoTsYh4rgZhw3fvzueH+9CYDMFqQBoUntmp0mbzw/colotZcnzsYkAHkOMIcXON2ytlVJUqko3dWkArcdcr9yftS7DpTqvg7kfv3rEfM9I4TU4FBzMHS7xyN0t0s2AAck4eEhd/GgAnSOuCtaBXQJAcHMXVXNDBCJckZ/uBG+ay3w6H+f5XGpWs6dx+lr+EbsGmQmY+jjEe747MxETEPnHZ18knKaoaiHotAlXKbx6vX35arPbppWuA6jm0jQXOc/1PNdl8dEYIIwjxxTc0RVaE4DqauYIlKBTpN1UxYAUSQ1MFJDR3I19dcw0qbPmI7UajABZlawqenGvoIJiaI5k2AOsTImIIfSIxo+eBgIC1NZqLbIuaiLqkfYArqImJrXVJeclz/M8H0+nx8PDPC+1CVJMw8RxDMMUhimkMaQUUqQQidbxaF92fVcPDGOkzRg2idAtME7TEBOZVXSNCIEgrP790g3+3MEMm2jD1pqqqjih0fdcqj4q7UI8j5vvXn7+r+0/7E7L6G2aGg2Ge92wSSxeKwf1bbB90mkQCBVYnDN4S1F21AZrBMIoAE2hSK/raM1RNWAjzaE8xsO/j6DL/LjsH/Kzn+r+lccBYmrb5xonmq6hLiDN1dwu9O84YEywRpZc0KtOj/9e1w4wchBQUPNgSsiEMXDYjkMMr1/sPnu9DzGUqkQqVvJy5ONA4K6t1SwqIZK5tJal5Xk5i7uqAqKqSKvuFmPcbPeIKD3mxdRMu06CmW5unv38F//h+YtPtvvr83x6eLhd5gdp59007qZR1cyM0JfT6XTKS5V3t3fH40NrBb6f6LwSCp6uPpdENL908dgXPfaQBTBfg0p7SGufxfcDQJePcTdecAdCiEgxQByE6tnq20f/6tu23fBmw9foW6SeDofdoUHRtTNXDbGbYxEypEH3e3j5kl/cw/6NnsUz+PduJIwDiPQc5JKLu8XIxGGzoRRjSsJY0SM7imq2dlYcqg1nxcW9ogIqIRhgD5ovoMeljrAMwyy8VMoWBN1W645OVncSGSDst8Ow2VAcTnPte1RrIs1UBVaAkZkZ0Jhhs0njyDERM2rHr753ObRWz+fT3eP929t3FOJutzdTUQXHEBOnyJGBANDRIZfy5Vdfffvm7emUr6+fb7fX+92VBVDq7jbrwdQckDkOkThgYAcyR3eq1b/++ru//qv//fb+/nSay1KOh9OXX3759t3bcYgxcmul1pJzzktelkV/j0fN9y5ipsCOaF03tLr3hECMgCK2lHp/Ot8/nO8P1VuLAHP1uepmCpspxAitGZOh9wNODzACc0TiECOlECIxAak5q0DwrFmsVm0oKMSpu593Gp8DReeU
lY5Vb6vcNz2DaQAAoE5H91XP8EMcpRNPYhq2u6thnBBQ61KOd0KzL4Wy+yxxczPuK0WctkFdncxssXpyi+iA3gBQpAYKiYObnQ7vjEBM6nIsy1lVjKg2pNKPyL62cm7ght4QxJ8yYPypGENvNf/wcAQAKDjHFZGg7qfeBdIGHYIxBViBKrIL28jd3UwdzYW8H3sNwZG6ezJ02xhHR0YyhNXIzc1UrM3Lec7nps3gQpO7WDwhETMxXUYBtmq1CRjczcCq+Aqevb8LJrq5uRpSTqES02Y/PnuxG4dIhIBmAK6es5yO9fi4HI/LPLdSXdRb01rbkCZTEjO1YlZF1RSJBjYld2AzJwWUbl/j3TTTHNicFMSUrJ2sLaCtUw4BAEkAClgDteCRYUvkROYmADwN2Kyatdbah+XDVKRmqUVaRcAUeX2bO89UVaTNy3w8He8e7o+n45JzqaW2Zg7IIaYxjVMaNmmc0rhJ42YYxzgMSAzeRQdP9gTgboQUI223424MrVRwDyEOQwphYnR0JVcCAzddjWnc1EV9XT0spdZaW87yh4xmlXiJ43fPPvtX+NX142kq7Wp4HPmMmwqhEgM08pB0k3y3lTRYA23QGjRFjUGQxEzcxL0pNIfqWIS0AYoHNyJ1K66g87uiormUZpj2MN0AR4/R08bigGlCKVhyJ0R20hiECHEADpel8zEgDx+WdhyYGcFVhakyxsjjkIYUnl1Nn73affpyu5T2Xc6iYuatlnw+MCKCSStiWsvibqJNTXNexKBDnWYdkorb/dXN9QsiPB6P83xalqWzJJEgxnh98/wnP/vlZ5//4vrZq/N8urt7ezi8mU93Q/AhYCullkJgJm1ezqeltDoz2ZDI/XuA/MfX0769OtesfLXeskM/q/uqG+ALurbG2uM6KvNO/PBu3gU4MA8JkYzw3OrtQd/dt6sr4hC499pIZD0boCevGZD1ND5HBEZKOG7p+gafXcH1Rg9ZW/tYo4QYYwAM1hTMWi7ghjgMY0oppeQhiFuQilRFaszCZ8BRcajggslBETsSgAARgMwll3K0udjZh7lQNhY05zVUvIMLrhrdN4F2U0zTMEZya24qouuhEJwJYooxRiIn9mkKw8CB+3D3R9zPHLzU8nh8+Pa7b3/z718CBoDQ9+IQiJm5M+HRzXTJy9u3b//x1//45ZdfNfEvPv3ZL3/2xy+fv+zNEwIaPpn9ASBg6DHvnReMZl6KfPft23/4h3/88t+/fHd7W3KrpZ5Op1xyCIToJS89s3XFrP97CjuAExmSOFS1IMpq0T04kLmYFZHH8/z24XR/XI6zuFhCyOK5WXNH5iH6EJzR0NXFVE3dFYhiDHEYN9M0piFiIHeRWIRCQ652yiJWq7mbAgbkwNSNP5xCczo1vy1yV+WoVhEgUQhMSoooa8X07y2OPpJCBEZOHBnQRbwWKLN70+qlEmRPO1VxjqySJC9qzbWCli6hYDQDNC3qFsBN23I6VhNxt7ZYzeZiiI4Xt8Wnr/H6aQRA8WLJ/P5NufTrf9hdCwCIIfYkZ3TmlU/jBt213N1RwLvOy5x8NUVzRXUHVHRGNyBDVkJDNwQE99XKBRwROnbe8XPR1mpb8ryURUwu/fr6tQToMdqdnutmbm6I6/S949IGBgzI7h+orYjo6mozpDAMLUTaXY3jZghMly8FmHspcj7V07nO57JkqWKQdZlrzm1IEyKreM61VlFzAqaLt7CrutGq9eo8fWZa9fzkigDgcvaW0cXcexgemoM2UEFVQgo8IAGxGRbsJxN08wrwUbNrKuJFpJq2GGJgioGIWc1UtLU6L/Pj4fH+4f7d7d3pdKrSzL3nzoU0pXE7TNth2MRhM0ybYdzEIXGMvkL6hk49TNvN3Y0IYuTNZthOwwKoIkSMxN0BFa2XdnUV1GZoBGDkRKZEzJ0M5Krauagf3shHpd0BG4XbzcsAch3yMPMn9NUVvgWeMSwhDMGNItMw4LQ1I18KzhVKM2fl
ZMSqYgamYkpi3BpKRmwwmAe0gOBECrq0XHw2WCRlE32qzghdoZicgocRzLpmEQEu/Abox9UnQP5Hlg9CZyp58Ba9sA5DDJGfX00/eb1/eb3ZTfHfv72/vT2dTxUhgoGVup6WzRDcWpW+TVJoqlIWaUWkAOg4DS9ev/70s5988fnPRdpvf/Mv9lZFViiEKBIPu+tnz1599uzVZ1fXLzZXz7dXz06nl+fzbVke23Joaig1UWAAFXWLu90zA5/n2kTDhwZP8IO9eu3eL9aVeJGqkWF3aQRacQ6mrudcjz3U3VURATtdyrtxCWEIPkScAnMiYX3M7e2dx5jIAzqAYQAOwAmAwdUVo8ctQgAjN0aPiQYYt36195c7OsyWxfjjj93XqCG4qzUVcA4hDkhhYEAErRmIqzMJUUY+Y0xkzK7RRnbvSXKIQCDkGEFAc4WThllpaaE6CYEBQLdpIoBuBWlmpWBN0xhCAroeA1NMw+GYz3N2E0JLMcQUCA3RQqC4+rIBEvDHKr7+s5Lz3d3tv/7bv4xxix4349Vm2sTAQLhGzTg4QKn1d9/87h9+/Q//2//nf/v6d98+v3m1Ha8eHx5OxxPyFjH2+U9nXDKzr2PU9chEBK21ZTnf39999803v/v663e371RUVUXETKUhAHyoYv/vK+sAAM08q6NqBdFcLQRAFrHg0Go7LfnhtNyf51OuSwMwamu+n1PzVH2IPrAHdgLAtf9ApjBur7ZX11c317vtFEkZmrdacx4Pc+Cza4UFGFAczQAdmUKMMQ2hmOeqj6XdL+3YpDgYc4w8UiRPla2CNBSt8r2z1mq+4GBV6mkmYQsJigwYghOLcs5uR4GIFE4MZYnQDvWUoVmAHiYNkVndTLJDccAmtbS5qqiDWwNrT4FUT0g7QFeedlrne6L4+gTenwAuf+oPXiFCTJ0A4x2N739yzR7u7qYNtYHUnnPAKuti7vdP6BCsnwkBjdbSrj0pDBCI0b2XZ6tN8pJzzU2agb0/v3bU0R3VBADATQ1s3X4Rn+BRdzRAQ1Zt73EtREgphUDTlIgxRGYk69pQQndQMxEttdWirXmtXqq1ms1KPFUO0R1NvTUFgHEct5tht2shUGtWWwBn177/kTmjB3dDI0RycHcFzagVUZABCYAZumTWBbw5siMbOZAZuKjlVtVKCD4MH3IGwFTVGpheshl73p6JWC7l8XC4f3y8u384HA/neRY14hA4cIichpjGYdpO0z4Nm5g2MQ3EwQ2kyXpOW0cb4AbkPejTQ+BhSNNmdNVSwMxrFQBPyjEAMwdi7blEqt3QJqWkaj2vmYiIOcaoJzm10/v36nvvmQIf4x428BtfAvpS+YVyoEPiZdrIwBATU0oQkmdzRStq5+pkuKH1JC+uxmIojbWCV0eFABAQGFGBDLhBLJBKmCRtNaSeVwrm0LOOEIECUHhiinUmIii49eOJX75/OkR/BMhHCoZm0Qf1OnhMKabw6cvdzz+93g7JxKTdH465FGdK/ayw2o50nY2bqQIxAqs5qKgqAqRh2O13rz757POf/uInP/2jeT6/vX0THu9Dip1lwiGMm83u5vn++sVmd53GbRghjpths9ks+9Px3fk4cIi6DBMjgwJSiGXabyiG42HOudKHrocrsndB4y/cGljVvB13cu87TM9XMHBF8DXgzC4bDBgA4SW5AC6e3EQEIfiQfBwgRYDoWe0w++aRsVJd1KqNMW4i7yImBkfhATgQOjaB2jAXPhdwsHHAmy3sB7g/A328mzE5BXcgI6pV3VS6sMpX22hczcBBiSqFhUJk4wTGJm5oneSK7kABILqwF4Oz2lmsqAuAMQAhhh4RCavtq7nmpstiEWKMu4i+HZBTTEMcB1MBlBgwBkJX7BEd3cZ6rbI/mI8iqkrOy+F4uLu7Ox6OZalDGCIyMqGTm7fSSl1u79/+469//Tf/5//x9//4D3f3D250Op9Oh+P5eIqJAIy4467CzJAuZ7SVCAEA3lo9
nU8PD3fv3r29v7s9PD48xbj9X8rb/vBVzYooOJKBYFFEMy8hsFmt9XRejks5l1aaac8ddAcHNBgEcvNcNbPHaIGBoKuUgCgMw7TZXW+vbjb7KaCyN5SShggO0uqyUFMABSIUREIKzDGElJKIabVqvqhWc8MuT4qJh0hD8bY0INGm8CNcFAcEdBGZlyYMQYJpghiA2BHUoFUvi8xxcc1nRDvofLYqaACo4EQEZq5tMXMEUGmtLWJqDgDdAOZietQX0tOXvhO53+88l+f3gwL/hx8HM4Tkl7a5///9gNhRendHDyAEZtBxdgC0NfLAiYBDR++M0JCs+wSYmXhXwDustdnFpJRlyUttVU1WARteJu2dLtcDELtYzN4D9j1V3fuHREdEk/dtIiKGQNDnwuvK6fJ6ugD93dyYQuAYQmCo4GbWqqg4UrMOETikFNKQ9lf47Bmm6MuiS7acMVeDpqLkHt1il7cCkruZCWpBq8iGyLRqYHphbmiVIDQVbYoq7s20ArQUbbcP20368Ajvbubqa3q0a5+tqy2lHM/z3cPD/ePh8XBcchbVPoIKnVs4jGncDOMujduUphhH4oBAPYt3nR7jRf5gAI6GSAQhUEwcUwgpNBF116Zm7skJOXJgJlNV9dYEFVNKKYZOkWZWVevh43ObP3yvfhDq6q5OM45f8asW5JDbJwWv9HCVlqtRdwOMI3BEQ7RWAV1VdalIEAPEBIhkGMSpNqwLWjEWIyckVOQKafHpDLtTfDZvPlle/FJe/sKvXnhMXfndV8IHcO7Tu9Tj9BxXnuYHdd0vxi4f3lUgAxT3lHCLMQ3p+mr76avdp6922vzudlka54aGHFLyNXevaz4QgXwliQRw6i6GxDHGNMbw7PmL15/+5Pmrz6bddWkKxMDEMZiZio6b7c3LV/ubl2HcGnIVdQdVYJ6mDYc47PbXWs5Q5+QV2sLxu+P5uN3vKEZyjpQ/Vlv5h3UdVhbcpUW4gOLYVRyEK2qpDoC9d19P2e/rOsJF0I1ABMQIgSwFHaMMg6ZBKUIDOhwp39P9d7Wc2rO9P7/iF3vcTcjBkkJKbo1y5vsjvXvAXB0AwHwz2Xb0yPChGhzBIwsE5zGaMsytGYh5bZqrMJI3KaW1UkE0IBlz4bAkZ+uncUdRF3N1d6BIGFEDFMKz8OJQCTQAMGEAitRTHNZDGlKrOh8XkhaHSClFHvZDcgo4pKZmoIQaUEAFVMzcAc0dTTtn/vuCK3cOcRynZzfPXn/yyXa3c3dVNTPCwMQqOp/nr7/96p//9Z/+6v/4r3//3/7umzdv3Dsrpy5Lnuc5jbH7cBAjMwIgqSHjh+cIM885Pz4+Ptw/PD4+5pz9gw36/5+6Dg5NtEhzQxStZk1ayzkxBbDuJqZqDB65n7T7vtyrHFa13HxhTAQpWCRkRDBgwkjESKJWqhhDJAocMQ6UMg2BE3LFzoqDxJw4BI5MiUgYhxhT1BA0qCl6DHGIaQhDwIjBNKAHxP5F+nBr8Isw0gxac6iuBICMAyMyBiAGJDCzWtWbL2rtaOXQSjFTAev4qKjUtojICgleDId7v27rLvS9Io2XT/H000uPgQ4XC7dLgf5DF7JT6Kfvvs1d/rres/diz/3ESa1AayYNLpA5coBoTuCBXNmor3bvZE9toq0JIhFCs6biy7LksnQVe/e2XedDeHnzHFcypl0maxdnBbhA9xcw4qN7814WO8RGnb3XXyFCQkaepuHmxggDhxiHGmINoZWifUK1ljviaZNevJyev5xePE8phZxxWWw5t+OpPBzyUgAkmaW1wGN0d1d1bWiiXSqpRlgBydxMq2sVKATmUAAKosbg2w3eXMdXr6fdbsMfvliXfXI93nQv93m+f3x4OBwO5/Occ48I6JZyzBE5AAaixHGMwxTTGOJAHLss1lZG7krA6SEsF0TWiZADYf9NAiBUMVNXNXcPjCl2ryJGIlWTKgg4jVMIMYQgIrVVUooxxPiRY+73S3vfQBrE
+3DVomacH1Set+FZPb10e8Z+PfrQjaSQMYhxVV+0CRVhZw5oEnPDc8WlgFUDM0QGDEpD5WlO+/Nws+xel+vP9eUv9PkXPu2AaQUqnsrWirH70wkW3GEVJV94c++/ffSO9fOjOwj7iDwMPE3j9fXmerfdbDanU1s0Z0WDHpARHBwQOQTiYGAOxkxIBBQc0Jqie+QwDeN+v3vx8tXrT396df2SwmCA/W8IKRExY9hfP3vx+rPN1bUjihqJIpA7EkcOIQ0DwB6loGbWRfMhizhhGgcknIaEagXfH2x+BJBf37yn38c+cqeOt3vvPwwALifXdQO4dBLgl3UK6ITGDJExssVgKUFIrAiniuUEevA3v9XlUV4+89OLIC9JbnAzIaHbYoJ8PvG7d/Tbb1DMdztX8WnwzegDQ/i42WUyYogpqUJVbEVFzUtzygxoTZZ5qaUGMQBWjjXZrOAOCp3jpCBq6uCAkTGSIhaAk8JM2AIaAESkQBDRCcndOhLAKIBVfS5tAAiAFomYIlJkNGIACoQBAcSdANQd1tw7QqAfmJY7ABGlGLeb7fXV1TgOgK6mTapnWYot+XR3/+6f/vXX//jrf/jb//b3v/3qy3Oep3HTVkvWkpe8zINjP6oHouAOpoaA2CmilwPd8XR6+/bt7f3d8XSqrbl/bM3///N10SF1NaJJA1fUZowBnRCZYIiIjEGtqndI9wn+NXcDdAIISInJnRRBnMystZpnI69SU+Qh0BAJnLVHtqSAiYiMmWFgSkzcjXghEk0pbQbYjCAu0DwQJ+LgQGZsFsA8APmH7mcfrYee/UXu7MDUhdlIyEDsDAaiMouoWtF61HL0trg20I7uoJiKiVh9D4+te2HfDp8a8f+er+9Hf+K/64GQIa0Joisc04toL+/oCEjkXdDZc+1d4SKDxtDlE+xqzp12B449FUy1Nam1ETGiuYGIlpqrNHfrqpeO/12+kr7O+f3pdj9oZy8/RLycWT64Se/ebKadcYxAjIAGBoQX16wUA+4GJAyRhilMm5SXVmunBHifLhHTuEk3z9PNs7i/iinxtIFN9nkwpFYlIyEAqmITb+pNVdVNFbSCtj6Nc6+kiwOYm1kxq+RGoEiVuQ0j7Hf84sX08sX04uVmmsb5iE8brZo3M1knX9ZaO59Ph+Px/vHhcDoutYoqUg+I6WB5IIrEkUIKceQ4UEhAwXEdJbo/4aSr9vB9V7aqn7zP8de21VxkhXxqpCGqRgakECJzqLW2JqoaAofAiO4QmN0dwsfWmd8H5J/65BaG47Cv6fndedmc6k2rn2f9vNmn5HvvSaBE4+SbKuNZFrE5cw0pDq7hVIbHHE6VupmSUVCMQpsW93X3vF290Oef2c2nvn/pm2snBNPLwbC/2E/Vuh+X1w3vqcY/lXZ/Ku0fL6VIZA6ROQWexpQiRxOopea0FJmriCPFyIFJ0bqObRzTMPb2PYTAgdXd3LvR7H6zvd5fv3j28ub5i6tnL2McWxMpEigNaeNmaZuurp7vnz3fP3sxjJtlWVIYUojdAqLPiZkCYUIe0Sa00YmH7Ty0Jlq1FHIbAlV7/5LhZdVduKr+0R7jnXHlvWXvLCrroI0TBuZAhAFAFbQDl51kvNLIyThIZEqMkT0whZAowCz6mK09+PJO3n1l5aCPJzhlEAsK/Jp8GLzNUtVP98Ptt/TVl2iIr145BxuibQYYI8TvOb04IjBxciQgMKhNpEipTcHdqrTzUpeMBo7BQmwI5l7dqkHuZ35kCA6AEMiZFbC6z8EyegNUQk8EqwK0j9nBAJ3JmDWSJqoBm4PWKlYXgdK8AisTJwoJnNAAndgRmSEQECg9tTLryvDLnrvmpiGhkws0bW05ng7Hh2++/eqrr377L7/5l3//6rdv72/Pee4RmK39f9n7kyZJkiQ9FORFRFTVNneP8IjMyuqq6kJv6H4PwBEnzGEuINyAEy74L/gzIAIOOPd5iPBAM5c3ABqN7qrKzMrYw8PDN9t0
ERFmnoOomlssmZVbNb15VFJRGRbuZmq6iAgzf/zxxykOMcYYh9gPA3lGCh6QiAEoZ0VDdlAcfDNNKd/e3r548eLq6qrruyy5mIvfHQN+ixECN5WTDCWgIFA0IKDAGDyHEBQgZulj7oeUdOzZAQZExg5CzfXCNXOqG+AMGEU1W4r9bh1ThFBhqKq6qut61tTOYVbOHMRXWiV1Ao65CegYENRUszj281CtZrzLpBbREihi0iwDZMh9RsmOjCvi43mFI2+0ZE/YYQgUnGNgAhrZOQRAqjbk1EcdUu4k7jTtQXqVaGCKiEwGo0rZaJ5UD6m+cWv5BnLrNDsOkeyUJIDROv+uR1akoMvLaa8bD2OFzVBSb+OGV9h1gAaqoGoA5txBp8FUjUALmJSliBDEwrUxNRFJkg2sMOuMDvj/2BQd4KAse2TV7wW94QC2f3hZqiKas0jZRLGo/mMpCTZCIoZQMbGralie+BituFSSU4qpH3JMYki+crMF+jCCHyGU6F+HqPXeQkXzuVd1bUvbnWy2SVUsZZUEmky9qUlCxGxmCqI2AEaHyA7rgMtleHjuzx9Vjx7NV6vKOwTAbncPv8SUYxphtrZt2/1+t9+3bdunIWUBJHYFJmVAMmBEJvbOV87XLtTEwYDUALRo8Be1SgNgIEZCIDAAUCuCRACmKjENw6ApZ1ERHQupQbVn8ASByTHUVa0iKUZV7brOTOs6IIL3rGIlu3n8ON437cWhAARFjr5K9el+ljZ93maOse/boXPxNA6ejAFYGajSMMMoNKhmowygHFOzh8W6Xgh6I1b0yl78PDdLXT3Qk4dwcg7LB1DNgQOYgAocTR44mHa7N/aTxRv96aPA/X7Pvb8qhyqAZp5x2YTgEDWzZklpiLlLkgyoJEhU2YeqaebLk9liweSIyDlHjkvhMwB455ez1cny5Oz0wXy+dKFR0dhHU3O+ms0WIYTlcvX40U/DbG5MKcW727eo6p0LoSFyRmVBjKrLBh5Qwc3qxSmCptRqGkgTaO7ebuN9Bmtkub/zgA6Gv7BZC7COgACqoNmKMDtnJWWCUcTw3ltAMAQjQzbnwHvwjpwjcihAXYbtLm421l1p+xY2t5p21op2hoKaDUsPSc0oit0e2p1tN6aoi5nN5hgcNUEbZ+HdpiqS1UApmSjkDDlZHERUE2VTtZRtSGBQhB8NOCtls2iWRaMAs5Irfm5JJ2I2TAqRNAOAc+zYPAFhLq7p2FsAjVgcZ8/JOyU0NBXNIiJCZoEMna89VJ5EKEtp01nUWktDy/dqEcdIN6W83W0v376pQp00M1PKcb2+ubq+fPXq+cvXzy8uXt/c3Qw5ZpUSBGXJMadUZI5K445SUIGkYjkncgjkXbHrWbuuu7y8fPLkydXVVYrR9Iju9IPHvAlpVuWkOauoIppHDA6b2i9m9XKxAKYhprbrt/tuSJoVREDNZrWbz6vlKqxO/XwGdVCMYpgko4hIHrKpxgF8yLGOMaWUfeUAJBqq91jXaAqOoaqMMBcQLplDcqGa1eFEfJY2R8silkWiSlJNYiLMSKMM/wfPBMwsq0axpJYQtWwPCGBGqhhNk0rMfc6dpM6kB01W5B8IERjGRjkMwFhKi1RgtKgHE2dWMIv3sbQpTT79OQS3YwR/pCP/0VEwfzwcqwTq0/amWmJgVEHR0lp0lPEpHbqLgZ/+6BglagkEJWWJMVnJE5iZmZhZSTZRib9LgD8x5A43Fd5/BTh5LONe/e5FmRXQUCQDkVMqmOFICgADMkIjQh+IHVfVeDmSpe+kbTWKgpoPGGp0wZBHoW5CRRS1ZCDOo/fuwakz49obaB66lCxZjiYZQIiFnTEaYkY0JCVO5CQw1ZVbzP3JaXj4KDw8r09Om7p2Kir5HYvYx6Hf98MwdF1fzPowDCklQ7CRY1v0/B2SZxdcVfvQhKoOzdz7mtkBUsG3cIJwyiwaSUWEZmA0dj1U0xRT3/VouRRBgAEoqFhWjYP1hJV3VDnv
Q11rSjmnmFImQselzgSK/sB7qtjvMuRHu1k4n4rEtliZc8lX+/3p62G9k/X17e3JXVyyNpwdKwsBLbzHJjPnrBFBGa3O9aOu/qNULSzUQGzkrGqgWdh8CbOlhRm4GpAn8ilOfuGR13ow7e8iZBOVbArn75l00wxECA6TSZbkmCrP80CYlclAU8q5i5KUgCr2TUC3XJ2cPnz08PyT07OHpZ6zcLsK0RqRHHMdmqqq66omdqqWc4SCzVbVYnXmvXv44PFPPv35kOLl1cVut8lxSEPvnJvNTnyokQgIcllPVgpUlMA1s5PVLHhST9lBMomv1n8f83C4EJjM+3RvDlc6qluNqE5pyamgYsUpz0lJjItQRVmSiFPIjkBIDnww78F7Yh+A3D5ZauXqFm+ubbiReAtpUFWIPfR3NojuO+x7HHr87Nx571IGBSUSA1MBVPKeawdNUHc0rcxgGLLFmDKL0tCl2KXUR1VDJlCDLM7IV1UgR+wL2ae075aUUxJn7ACKUJeAZoOsJoWK5p2f1ehdAhtEU0yFG6RF74AMyCIBEJFziKSkSplZ5iCh8qH27AgJ+0H7KERszpWVeKC6vLcRZ8ltt3/56kUc5OLy4vTsgZr0Q3d3d313d70pcmhDoR/blBofQygxVQAkInbMhc2PKeeuG3xgDlwAvBjjZrN59erVF19+cXV9JR+pr/8hA5eLGvtm6FKMImZg4MCawKt5/eBseX7+kL3rY1rvdtXdtutjSpqzqtpyXj04a84fhIdnvnHJQVRiUW/ZikdloEhqlnMasmnMkXrnHAKqsaeGCEEdCzsxk5RQtFeomGsI3vNyVvW97WFQyTlpjpJjNlE0K8L9HwTPJf6RlIeu26pYDsmRo1GOVdU0m2aVpLnwCkBLU8BCjkMDQrTS65odl/46GKMIqR5LNoGNlDQ40vy6DzbejdrHj02Y+u+K2ov9HhnqxQgiGKqYahFXQGIUAUkqpY2njQ57mV367iBAMxVVVVEVURkfM452EpHQFbVMMLVsCaSIGNohQv/IjbYDXDiRgN59q5WKOxU009IaCUxRD26FApZtS4vvVZKwkvu2X9/tr9epz7g6bXxVEqNSCAUiMvRxs+vaToh8VYWmJjADkaFLbegHGnqJYEIMVaOzhXqnji0EDAGriqvKVZWrKz+bhdnMzxahqj0ipCSmoO868G3X3d62bdft27bvh5QSIhWIF5FkbJdMSI44uKqpZ/O6mVf1zPmAzuO9omi5QWUnK7VJxV5hQXWQAMkkS9/3exLL3rEHdGPNshWOoUXKQ5W8o+B8XdcI0HVt13U55cgUTAt1wTvv3TvW/L2o3QB0YoGaAaKvgJyii9U8d6uhX+776i6GxeBq3Xsa2IjVN+qX6BvrgiQ1Rg4wO5UHf5RmD6yejbXGVYCqglBbqADcKDVeHOwDEgVH6ProasI9MHawbYf/2LSijtYPIs5nddcPsMk5Y5Yswk6TZsuR2j7vuxQzIFf1jIjw9Oz84aNPHjx8tDp94JAIwCbnt+i+MLFzDomHmFVjjCkOfd/t265NKYGBcwGQhyHt97v1zfVmczsMnYowudls50ODzFBkH3BcHETgGRqvswDzynnnmSu0jPhOwPtBmfsRW6eA63TQuMBSylpooZJEkrDwfQyBRa9RAQFLcVHlqoqdYzHcDTgk3Xd6s4a7DeYtaAuWEJVypthTzDhEk2TDQEPE5RwETDQ7n1DBhE2QzVUOF7W8lxTNogrZIKqhZNGcJWZRYAdc6u8IAxMjG5KYZmT1DsBnwkTkwBxAqe3T0smZiJlD7V0d/KxC56II9inv2twny1imsCGUYkBBVCQYNfbQMXmipuZZQ8AopQ28oTIZc1aQrJoNsnwo7qaqMcXb9U0/xLvtzWw+F8lDHHb7TbvfxTRkyeODm56dqaWcY0pRcja1gw5QCTaT9EM2wtrQALPqbr+/vLp88fLFs+fP1gdi/I+AxI8T6PHjh6ugm7vtbtv2MeUk
qEaE3rumrhbzJjTVTK2aN81i3g8pxZyGlGOaz8PpSX126pYn6ABRRCGQEaGXIJpVFRRQkI1ZCRMKqDkdm3gIOUHMxEKUskbVHFVjrnJaYIZSGovEOHbcMZDyGEsvTuc+qv1rACqSonVmljUxMeEErZuKSVECVxVTwVHiQQFAgcyUVEEYUVGLMHqxgmOJxEHc4KDz+r4J/4hp/whY/Q0GPidIAxwE/EfTrlgicskGAOzIRtMMRlDkhqZNUScF54LjagHBi387krNUS3k6llC9NK/kKSB4/35+wDG5v+DD3/ecgPfn171hK7betAAgNhbl6+QFmqrmnGJs22G7G/adCjoidp6RYBToMBPJQ0z7LsUE3tdV8I4JzByL5+w5MUbQAdl8RbMlnpzxrIY6UBW4qrCqsAoUgvPehcodekOojt1Q33sybd/dbTZd3/fDoKoG6Jwj58g5LHJBgEiM5Ii9rxpfz3zVuFCz8wB4qOKwMblWxBiIivSP0WjtoTyCsntTkZQHRibiwpMykJxVNKbc9yk4bqpCjS8F72gqhKii2Wz0097FtD4w7WNlbtlN0ArTpV5CaHR+EuND2Z/1u+v15o3rrimuWQZHusD6FOpT3J7AJpNj79x8wWfntHiszcKQzKyo8lthfJmOmsiGBnaYXvdB+yHrY4cEo01A/T3wdRzd308uxNXpgnaIVxpzv+s6y1TZAOLFaNvm3X6IydjXdbNsZrOT0wfLk4euWogxIqppTpJSSiNciYSlyZDFGFMaUoopxjS0+91mc7cpFOnYp+vLq67dbe6uu36fJeUkMWpVzdhXBYcpwFpBYonIOWoqntduOasWszCvnCPI+pHlMk2Rd1JgJW8yMuSnFW5qaGhqOQvGjJHYj/ekKMSKGSE576u6mTWzKhAzdLHv+37bp12X21aHhBYRIqBw4empUu4xRoiD7lrZtXB+Bos5JdEqRFEw9ZoYvAtEyxkpvOOgGBmYKiQAKiU6KmKKwJ6dD1w0GA3MsmoGy2hKaOzKJpSL5nlBHIiZua6ruqmrxSzMKvRsBDEp7bsEls2S6dHtKtSkQiAEBTQiJUM2CsBOldAMyJsHAEYlkqhZco5RYtIPHkehCw2pS5LaYcu3JdSWgrXbaAHGtxZPTNViTl0auhwHyTJVy2bNEmWIaUjCwRuwGlqWu836xcsXz18+v7h8vdvt4McdCD/7+c/kfHHx6uLy4m262w4xqogXHANCUGbmJtSr1QMmSZKH2O+7vm1DwPnczWZaVUImljxVjtB8DSaaUk4i2dCQyAVBylmyasaxACoDJsNsbEZRoU3QdrlvO8+yiBCqipzPOZfCODIiMVQFAgIjzy582K1nNKVqGRQtaZKhdEScsnQ66qyA4r1Vmkw0ggGO0a1gTpG59OYqkQYeZF5s7C8iRQ1s+t4JRH/XtNv9DPhWIw84tBMfp6BqNu7EpSkcIJoiMbIjADMxwQMsf39uY6ZnNCmlYbyxQ+/JUEGxAF9kk3I1lsjQ7q93NDrHAdT79/seSP1gYiFw0WiAgjWOiMKo/1SOO4KLhTOtElPq+7hvY9spIFdVmDVhVnvi0mhOzVByTjHHQdUoVC4ED8AqSURUM4KAiUFmh/WCFqf08DGv5tWiboJn55AJmcqWi4SgotmUBNGV9pb03oUOw7Db77KIAbhQMTtiBmIABGLPjtgTOWKH5J335IIhl0xlcZVwBGtGNIWQkEEVDUu7hNL8BgjJOQyVrytXBfKeg/fOeUSvAgNSREgxppS7HoKn5aJ2vgre13VVV1WMQ05DjP0wRCaqq/DehXzQHuY4Mh7/LlXmbD6YD8peuI5YI6+su6PcOZDOhkG7rNcAquAE2cxABBDNVUA8LiicQm2zqSdaWX5TBA5HdPcDIF/oMvc4PEwx+lFK6p0phiGE2Uwens5TlqZ2zhGqCYXBQlQUcOShpiqEumnm7JtkaEPsszAYmuSYcsop5dJUEUeOmgxDn9KQU8w55dwPXdv1
nUmOsTeDnHKKQ+w7UQG0tu0Mb53bI3NxlqcSuzFD4hyH4JoqbJpq3lSzuvKOsnxjdDYG/VNdemnCOP0QtKSZR8VhySIxY6FqAZTJBIw+hLqpq9A4biTLLg79kPfdsE/aJs3RJBsmwAQoBIBlUkqBws1ELWZd7/TBiTkHSQ0ABk0MRuglAb/LPgMEct5MiAi0KEkUZigxe3YVO4cIpqKSFSSbKIiiAaHZpNyGaohEwM5RCGExb5aLZjEPdaUMYkpZlKjvY0piMiTTsYyokJBAysZW2Eijap9ZLpsjAjsiQiAWxJwtgWVRzfJ+QnGcmCYqopLy8OFvx2896miipuMSzDmbFlFvI+jj0Lbtbte2+2G5WpJzWWokfXPx9osvv3r16vV2u40pjlDqjxa148PHjynNRszAQAFzyj44X9fkfFFO8VUVZrOqrnNK3W7rySpnoeJm7kPITINmKclfY2DPhAApqWRAR0XRnVj7KCkVzkrJ5yYpk5cENCtE0S7lks6vc67qWkzZOw9UCPHsGVQIrPZUeSb+qNEcH6aAYhEZxQOqbIBaKJVW6kmARpYKYCnMKltO8SKKfMUBK0NEm5Izo+s8qs6Nph2/9sF8h6eVI+au8K0KrofjNC0lQWN9fenJBsAgbIRwnxUwK+w15pxzdmU1mYpmtUxsrkJgymIFBtdD9ne0r4dmCXZ/yK8/2w+226NfqYGVEMYIpsgNQQ+LEQwPujcmojnG2PVD20ufjAM3tW8qrjwJFBKzisEwpCHmrEBEIRA7EJWUc0oppRRTypoVzQcLc2kW1sxpPqd5zUV+akJKcbInCpoBjHSysu9ej6hmLdTqMe4Y4ycgYnbOswuIjMQGDEA2auSNXNOyzR/suplNfeyA0Arn2QxGoifgmCx0JWUQvA+ETsQAxgbtSYSSxeyTqAGQKyE6sqOhh5xTTv3A0fUupa/XkMdJkt2miGN6iWPeiQiauYUAzdJOH2u/y7FLeUj9LrZrUYcpMcggmLte726VFxZm5gMesryTp2sH+htODRDhwwj8gLcDgI0g1CS/OwVkE9JwPMmA5k39i588VFXwwYA0Wys+52AuugZmpGbI7Jl9Eon7rRZtCBVQxdEJFtXRipqZao5pyGkQFSl4dxqy5Bz7tBtyTlkyqCEgO0cuGNGQY5Q8PsLSDJSZnSd2zJyFc+YYYxfjrhtCCI7J8vvUrXceUDHoBACHwgkc7+1Yy16QAVUrlDohAkIa6yyYKYS6mTWzeRWCidu1u65t+9gOEjNiBjI1yoZZKSEoIVFZEIYqYL3BbW/9IOu13NzYYsZ1zeyVOedkOVsvkKKiO04Po/MzMyrwEVEmVmIGcK6qKFRAVAo5zUQtK+TpgY4sgyK4TgxcBMzqOiyW1XLpmxkHD6Zm4liCWDNr8pA1iuYkgACgVIQipJR1AZiogpmYDUiAQB6I0TkmREMSBWEQMkH7EI3/HqMEkDHFIUYpaiMjEETtvn3x6vn19d1+1z84e6iGJydLdvjs+cXf/+o3ry8u0o9Y83Z/QrBYnVTYpJyzWUbiqkpDCkzNYubrWokViZz3oaqamdkuSxSLzFpVoZk3REkVJEkUySImFtix5wgYmTHUvppVzRzJWdtrP8ScUmH+SFaVUey4ZCAJACFn6fsO0JARiEPt0Rl7C2Ilz4SgnsARMvMH1zNtHKiIBqVAGGDCRacwdIqnJ7uFI/I3PeHCXBs5zaUA9Z5RVnLFOm1YxyH70Xu+r/NlCWUghFFziZhxygMQHUGVOnJmmSDTWO9iI3tNc87FROSci+pclpwlIWmokByRQE5a4P3ploz2Z1zg793RD8/z6PZNd+34FpT+7lJqTQoKPULQAKpiJgpEpcf0KJaXhhjbLnZRslLt3Wzmq0BMpeMaqqmKdn3sewFj55z3SCQ5yxBjP+SuT22fhiSKQMF8I642doglqyo49q4cJwEiWenxgqBWGtPi+6uciJk9cYHcS5ODsnc6dn5EGIHUTDSr
ASAjMbNONwcL2p9zLthhcQyJ2AgIBYWAjBRNUMVERNQKEjlr6qoKZpSSZtE+DoqWTVghi6YkKWulykzk2Fsws74fRKHvo5kOw9dryBOAg8MUPoJdJhfY0IBIKSA5CLXWc8sRcrJ+r2HFPjhHLu5FtYsmu43NdpYTsIMSOZodPF00IDCEIvzzEdNewHqC6ZcFMzKA8sEiyjCR5N+fhUhEVnsCRPReyWf1OYcQw1zkHOZZCmOREUgKDmyqJqCKqmSAZllyVim8T1XNwsjADkVF1an6nDgz92hJogkgETly7EPdVE3jQiDnysfLyRMy8VQKWbpjMDvniifIPhCTfpzFMq6jkWkJU8CAUARrYHLzp8wGghooQDZwhmrTfSMGR8qQMLU5dbEf2rbrk4oYGBMi4gAwGEaFIi+NjAXKAzOEDGSCMVpv1rUyb2w+p6oxH7TO1mXMgJ2Yf//MHaAbIx9RM/DeGXl2joimjcUmb2/y3hAIjUCLeoxjdJ7YETOCQc7ASQ0kq4pm0ZRiBhtlF9kr2Eh2KWoPJRelJemqKmaISA5dKTi1MUQzVVRjgLKyf3fR09cOm/bPEcFEhHJq7BwS5Zxvbq8///Lzly9ebbftTz79I0M+Wa8M5IvPv/zyiy+vrq5zFhvbqP+Yg5mDD4vV4lzPuapOH+5jP7DaPPimClXwoWk41MBOTPvYb3dbHbqKAblynkVlSJYyZfWGDhxacOZZDHJOTN7YIwdHzjvzbEkAEdgTOQ++iIWCqjnvqiqYqkQBNWIqmUoHhWoKMBYrKII6Akbg96P2UTO5UNAMAUGm5aAH+33/SI7NF8KBwHx/uJK6GY0mTHYd3kkCHschcEjA358RwOjQ3YNX5fXoVOD7XoAhKiGVzsw00XipaLMV2kg5F80jJ4lK0twMZDTQIpJSGlOuhGaWJYom8kYOnEfiUQgNSlOO+ztiIww4klOmXxxfE77zuhgwPDgg5QOqXb9FzEWYEZkRGczByKUqiCLaWM2DKpKT9n1u25SyIVFVuVnD3pW2hGBmKirJuj7HJMzOO08EqhJj7LrYdqntctdnMeBAvlZXq/PApfcFAoBpaV4/xUIjtEmmCojTsvzAtDsXip4UUQFcCYnYOWLGwhoY6yNpchyhEGJUy/3B8kQAgIjGfAciUkGMFMzMUMwSQNdZxbCceUAsXLicS9e+nHKOOauYD6RAWS1LORCyYzBV9ewcIA1x6Ie+7frjC3nHtHuEiiAK6FSOMXatNoPR5E+0SETwHh2j1WBms4U0y3Y2s3pGu2vYbwal3Lc29JgzeEPCI/sMBMYAHgwRolmaehwfhegAYATm0QBAShsNG0UMGcAhKJroCA+9J2WhSNmypp7RQgjsPfG8kTr5YMHNlmQj34PASLQkPkVUAJTAHBIilGgjZ80iWTSlPKQ+pShaMBtJQ9/td6WJZ0rRDMix91WzmM+Xy1AF51gVpAgNTqgMFZ48ACB655z3VQhVqKq6YsLrzZM8Rbxm9yvsnpY1xnGGEzkekAyoLPECapgZAqIiGkIJ5Q3BEAUxg/S5z71CFh2SDkkSEAB5EkQF69RagMEsG3DxIBgM0QgRDEnUUqaUZeh120nd5WZusyU2SlUWQ85Ks/e0RQqhVySnlGI001B5IE9MBqqqqIKmCIBGhWkCcM82IySHHBx6h8iAKqkfTFmSkHOimiUniTnGPCRTQCLyDhDYc1V7IjYrs4mi5iyWs6gpMnklKZsjoIApFllhBQMmNKYfYlYnopEBIjM755333gfnPAB0fffm8uJXv/pfX3z527u77S9/cVtV8/l81vftr3/96ydPn97e3qoKvOdk/xhDcjLWqnIPHpycPjhLWYd9pyk7M0YoCSP03hC7IW52u7v1LWmkWTBskCAlbfusCcA8B+ccu4qBSlcsBUVNZqxGSslYDJOSYl3V7IgJVGWI0QAqFQKoQtAkmgURkRiIxuIv1CKxSCWaJXD0jqv10fC6wOw46q8fIYST
yZ3eicd2fUpoHQ6lR89uOur0VN+JVGEyi/dPaPqKI8rffXJmfMs7k4qxNAZC4lGg0DlGRFUUUUIbaZRmUvqYKyICu1LgXrqMmIgAWIxl0yZEy5JEs0NwRC4gOyoAcdkiimNTllhplvvBVP+oSzTm0safHn1CVXa7G2LxwQUfEANRVZqjl3MHyAhQujKjcRaLUbtW2lZFwDmqa25qLror93F90r5PKZmvvA+eACXlvkvtPra71HV5GASQ6tqFOjsvzpFjT0W5fNxHYEqxjBbIjEq6uCxzeFcXm9ixC6PAJzIxkyvhe4EfTEeWPCEBETFxCc1FxsAdEUeFyqPBzM674j1O8hViWVESa17Nfc4NIBChiAxDbEsDnyERQoMM5KToGSASE5gpM7EQOSQ3xH3b7rpvMO0PK/rjmdsk2SXIBggWGBwh2Kjgng16hVQsCOAUMgIgKbnk2HxAqkFYul5FQBKooOnEXbHS7NITNAwrR4FgL9rpmABFAzOTUadZA+LcMQNEU9XR1WawirFxGLO0MZsaAS7lHVtioFly30UmxMqcN0QhAu/CzNVVU42O8ejtm5QGiSYoA2oqDc8yzDNgseuilrLE2McUS38OA00x9m3b7Xbtbtt3XUrJhTCfL+bL+Xw5D2wMIkbJvCqqjn7JYR0hovfe+1B5H3wIwSPC3XM+TphMDJdpVR2YEFMAfyDWTpyad7yBMcEMBbA0y5K7mPtsNhiIQS7p0hEhMYVs1gq0ioOiGIxKq3b4MjBUQ1Er2q8RbCAbHA4VtQguKZZW6e9ROLKoiEiynAmtCsy+MvJZQURE0ti2aBRZppKmQQRyxOjLP7KaZSMFJBHtRQ0wO3WikFVzTjnl0hBDDRRQER2zK2rLZiqYE4pCShajiAgCO5LS3bQ0OSFPULbSUh91T5j6fuMAdpVq+NT3Qz8MWWSz3b5+8+w3n//6y99+8fz587btHfvlYul92Lfbr558uVnfxjjYgVX6Iw6DnAYhRdCq8q5qEHCoKxmi5YSGyB7Yo/MKkCWGqglVTQLsmRiRjJmDr6ji4EKovPNOIMc8eDVvxOyJXE455di1XTcMMQs4rqgKVWAmVSEaG3o5InGiLksSKHENe2aOMcV+UFFTIVBC88yO8X0WHRYL/S42bjC17ZsMLpYw/N6+3kdq045/+OGB+V02q8PHPo7fHLnZY8AzrrdjHwRpioI/+iSJyTnCSXiYHHIpgBIAAB29kEIFKDUDZoBFpceK/oKNJ1xi9xKPSYknEgAqMbEDdugMTVHyuI0UyIwN0FBsAiS+fr7hwUwaGL53S4ywG/tAIQI6AAErrVcVQAAyAJZIVxVSzH2f2j71MSO6qnJ17UJwRAhTYDzElAaNSQ0wVC54NtNhkN0m7rap3eehVzXggPUcqwaco2KKCRk/dr+L3Z02VSm0FwQ5fiMiEbsiSjPqzTmHRAZmRcndrGD1zJ7cCDLYyIRHKOUIE7IxAnXMzjl2rngAJqIiZAIIGSFliTH3fRz6iAB9n/t+6PrY9bEfMhHVAkksZcuiWcQZMRODiRRowYlg30tM71TJvrMH/6Rmd+LedPi2l0EUAZaBakYxyGrZsBO7jbpXS1aKGsrNAAM0H8Q5q2YADH2EfAcx4pQLKsWDOM2ICmHl8CcNrTzeZdxmBbPSGswUolpWM4UZw8PaecJBRBUAzAF4gkWgled9zLctqIFDepjeMyYpZ2kHZSInAJpZ9giOcOE5eNeU0o8C3FiRjDBFyzgkjMkAlQPUcwuNqpUSqCya4pBSYdcplGLlGPuu6/b73W672+2quj578GC5nM9nldM9xG2GOvFCzJvyKLA80ReIyYcQvPdcmoazqX5Fx9lEO5h2PDj+k+lGGOtYprxZqfeZUD+ciPhHPDuNOQ9JMmgGQmSHQIjEY8YniyWBaBCNk7HBdC7lBMaOcYWJowCKqIiCnAFiZoax1RyCcThaKwaaRCSpJAAJHn0VqlkjRt2Qh14sRS2sWS1u
CpUcFRFy8M57FctZUk6QMpM5VuLeW3I+EwczPKQhRKFcRMqmCOwQgYl5vEXJstgQte9EckJDh6RsWCJ2tADeVQ5BzLJKUsnv4JLfc6CZpZS6rltv1uv1erff79vN//zb//k3f/s/nz5/ent7YwZX1xe/+jWo2b7dbdY3Q+zsHlL+kUeOXQI1MibvHRExqc+oeRAEclXNYUahNsAsUXKM3dbiLnD2wRNjVXkmX9fNfDav6sCedu1us9/WhkCBkFEx9UO/32/Wm24YkF01axxhxY4cGTARFvmZSJwpZSKhjEClx0wV6qHr9oBD16WUwBTJyKHH9+0rTqIy9+MQpx38qgmvL1j66Fi/P94zaEf44fjeDz5ynzec3oDHpv0+tEUzY7aiG2dW5vbx4bjU9TEUKLswskeXQwFpdL0BVZSKtjhAae0EbFwo/ocdQjUfeoEboGQzKCWfSITek2QQMYOJzcnFuEGp8P7QAzmsADzyWe5TFtMgtDpkZGDP7IxK1bIJgCLkwlMEQABnhqIWY2zboevikHLd+Fnj6uCCZ2ZUyCU2k2Sx1yTK7EJFIRCIxl53m7TbpH6QnI2IQo3NXKsanSvMdXcfs398+RQRIXk3FDr8kmCsQXPOOWaPRAagkrKoSEZA79n7QM4DcTkEohW4pDDnAA7bLx3sOjOLKpbAP0dA4AJLGaUkXRf3+05y7gftutj3qetT1ydm7gcZKo1JYpKYs1dy3o3dzIiJPaCTEj0ejXcsYs1w6tGMKrJBDBCWjmqmpJAMktoma1SLYqI2lQXhiBAjGLIGgmaJq3NEh31n1QyJ793laeEhoEPwZI1DR3Tip2YKWiRMLIr2OXvCZYCK0YzLwmQETzBzNGNsiDwwGDqi1f6dqL2uKgwBDYlpcbIKDkkjmxlhYlA2JGMEj+JGAGKkVxiLUhYjRGXHVCoKVKGoMdQsWlqWAZaWnCn3fb/f7eazZj6bz2bNw/NH83lTB4LU2OCFGnFLsGDGpqYmOpZ6GiKOj7zUZSCq5HdX/YEndz9TR7s+cWlxZPkdGI8HlhaUxBtSKS8cqQmW1aJaBAUCJENUBLFSjaYggFpSHXa0cd0H7hM2UJ4+KjIxGJEhoSrksV7nZH5MBjSEzJidV++oaqq6mVWLlSjtu9Tu9p1Jsl40jd9DZmDM6AJWcx+aJqfUd9rvJXaRDDNjCMjO1JIBIjp2bMZiKn0aorRtHoaUFWIy5/yMOAQGMoUsojmJpKw5qVNIggieS/kGeRdC4ISgWQlL9vb7DTu2ymaWJW22d0+ffclMN3eXu/3m7/7+b7766svN+i7GHpF2uzuVJKoxDjEOqvkDe/OjjWEYKrSxmbRvAEEoqWnKEZEZwVehnq/IOZFSIxF12DAMTe2aWQ3AYFxVdVPXzrORxuyqFIx8PSNmDwa7zSabURwIlJjRkaikFNlGux68A6s8sTifU9aUiZ339axZ1HXT7fYBsUXs0XKOYMIfpkZGO3qEldt7GDscZ4/hns5+eNtHPad3foL3PzwY+XtrPlp3HInRozQJTOAcwlSNYYAME0P5PfCBaExMI6PjwgwBACPVIl5TtkYDZCYDgENFJpZ30hRMTxvFvX1GNTCBnAwQfCgCyUU8wgCAWIkAmDKZqaaDbs3h0u0wh+8vGT+w6wCABE0FSICeoOTjTMEyQEYbCjKPyAAmqjlr38d9OwxJDClUrmlCFRwTFRKVlnUaNUYDIOddCC4Etgw+cAjOhexEFYEIfCMcgB1MIXtJ4xz2rtHO2yGDc7goVDU9bssLIw8+FAIds0NmACy9Qc2QyBERsSuZI5hgHZgCLDMrqgkj7WHC4/E+BiuYvpRmB4ioADlbjBKjMHHOJllTkpRKZTsMMfdD7D32NQ+D8x5LzZmaEXNV1U0zm82GwYZ93h4u5H0aXUA787BgSmaGNmP2xFkgGgxmHGWTZBttMjV0n6cq8wkRQ02rB+A89p3OV+bcBCkdLZyxtMPQ9NS7ishGMh2Z
opn1Se4GU9CKofEQCBkJARnREXgwZ4YOsGZGcEgzfwyA4Xw2D8EvV6fsXNVULIPtN14IAXqUARMAIwBJdpJKahfZAWGyLJDVWNWocKlQiQwlA4KxM6wMyIwQUM1Ectd1lePguanr+Xz+6PxRXQci0ORl8MY1ugWBR2M1E1NREVXLGVQL5mMTPmjvykdMmflJqWNiOpQ0EOFRRD5adhtp80TMJT005hwAijaXQVJMhtEggypkMymOhmlh8RQ9erBCvTsUh4wL2mzSwJoKctCROspEoNkyWlbLOcYjaAiBWZjVFynT1axZrMLsRMzVbQrMlFMrQ5JkUEQ1ARF9oNDQ7MQ3i3oYjH2bU+z3fU5U2IgB0UAMhNgxOUCfTLJ2/ZD3u9ju4pB01qv3gb333kNh0pc2DBJBIyqgDh6p8mjggEJdcag8AmkCRzGjvPs4xrIW+Fbj+G0qEjfbm9/+9te3t1eff7Hq+vblq+fr9U3fl+jc+mEfYzdtDVPU+PsZQxwi43xV1c2KQ8giCm2S3A8dkXPNnBzP5nNfNWrmHXlWGRYk++AhBGZ0iOzYMTOgZM2OIHjvqxr9zFe1AlBVCeOAapUrFMUs2Xp1TM47751nxhACO/MqOWsW7+u6Wczny6ae7UPlTD2BJxsGTGlgJPjgzr//OD5qc975wLEjAPeb0fe+1WOkPtY2Tdv3ZNeLMzzRZCam0vtdbiaZnmLjaSxeN0UUPDLtAACOgJCUsk4F9qW9GtIUtd/T9Ub6v6GppaQGRkzOl0h9LEcjBucKScwkm8jUomsCPo45ymOCYvz/+4MQZhUWdSFFKiQAwIwwAPQlXp8ELjQl7bq0b2PKhsQhhKYO3vGkSGAqmmIeepBMyOxDCJ6Dd0bUNGGxqsTEUFhBnWElwILkmMOhS/SxxT381+A++w8jpFIgyKMLYee8FeUYQgbEUvKohoDEzjE7ZGeAKgoIU1x2MO16yLIzl9/R9FswtVJuWPbqca81yKpZLMshfQ2atQhYgukwpN5T56zveYg+RCIyNBBVZqqberFYiGonLRy1dX1Xmm5zVV98GQwUTEwBISAxkhhms2QQslCfH0TtxfK0D0+laVbysZgzyQA4oItm2frBkoN7EwQI0Dic93jSwcpjwxyIJrbQmDZOonVKClZ5rhg9jcXQVCBkMzRjVa9GCIzodjeHqxDVr55dOu+KKrnzDnOEvk3ietxFCAmLbik4E695rLNiNkSRNudOlA29a25c1TAqgY5KyMxGbizZmbr+xZj6vhuGIcYU2zr3G+8comkeNEcgh1wjMJZpbaNKlqlAaWY8MWbNTFWKhG0ZDx8+/Mu/+suDdZ/wcEMiQkYeCU80uZU5qYqNVYZUCBaIDoBLu0AY5eqSWQYTBEU1kMNqLtX3RdAVFAHII3lANhgbbmCpgZNkJqAI6JAqpEDoEKykgwwUHq4eH+1+6Pgh4tKxIXqVOvYz1UaBJXpU9I6aeu6oL/z+EnuwIxecwxlozTALPJs1PUq0jIgUKq4qYnaIDsGDOQIXSGd1bcuVxzjMcxINlV8tZrNQBfZGBnWi1SpQLymhpipAU2EV0HsyYEDnm9r5gKiEmXFomux8fbiQ09PTf/pP/+l3heinVY3O+Vkzr+umqqu6YXY/HYaHpnKILI+P/F18iN89Hj58eECDDGCfKx0w73ArGRhUJfY2dG6IDSJ3e2oxbvLW+cEM0tAPLWqqwMxFcz0i8tQCGRVQjfshxMTAFTnPAxvArq9bPcmOoe5HCgiYjDxOzuYA0EBLya+hKgm4gNyA1SmHATFVZLog17kYURIWf9bPDxfV1M2f/fIvzA4AyZFdv79zU7B5jCa/PwwAjlOzU+Z6evvRhybjOf3TYKK7HFJjU3A/GcfjkylHburZsXX/5JOfOueAio2nor1uNiUCR08PC2NYDMpvijbNO0niIyTCwA7pBwMrfapDRcQg2XIuFHFwDpiRkEQs9pJiIfzeJzTsKMi9T7RP4/z8
/Og2VuT+sSEaBQSH5qbGkwkhGQAAAzozRwwQbLkcEOMqgwGvVs1yUdWVd0RICpDrMKDl2pkqIXFV++BqJgawuk4np30IcbGIamKs4ASd+srVITRuxjAf+78dT4Hp9thUoYUIpqhIeJBGBQCAWcMPTmHsuFVU8QzUSJXNaIRCD+gIKpIQGtIoGVksN6IyG5MQZaL7ZoTCoiAVJHXZsXgWInI0qAxd161x13VexIYhDakz65kyIprEoXdbc6brITZ17bxnBDAFyZqT9X3fdX2fuuNpjf/+3/97+MP4w/jD+MP4w/jD+MP4v8v4WAPkP4w/jD+MP4w/jD+MP4z/vx1/MO1/GH8Yfxh/GH8Yfxj/txrv5tqdI/q/rrE/lH19mDJTlTxJvZR68Y+86Vt/D0y9jc1MpJS6jUf+2IBDDu5jdTXf5YvNRm1RAAAoVRM/5IAfHt8mpkf5yUQ6wR9wuz4yRCTn+/r80qroRzz+jzm+MaOdj+rfiMh7/03v/ubvmVK4h4LG41+9d3/emVs/+NaVqvrDP73333WZH07+vp2I3acyPzi/j5zw/Sq5r+045L+PkrnfOFJKx1P3e8+rD+//gb14VOuFh6zqu0x8uM/KT399p7VvZjHeU2qccx8T0P3wfL/h+N+HlvHBRX3nUVoijUdDLKvj/7or/WvGe7vuWI5QyoqOGBhllGk7kX2nVuT3R5r+dXQIg3c4GIf/v7fr3tcm3P8KCw9ep4UHh0Locd++54UAlAbe9+v6Hcvxz//5P/8n/+SfjJ0Bp4eEo/DJgVLyjn0tv/jYw7SjP+98cKwZOBAxbRRXLVQbG1UTSJENkRDN1HLKKaeYFNDYOe9CqchFK3ye//7f/8d//T/+3+WLm6b5d//u3zVN8/VP8+vGgZyMWeT29ubFy2f/5//5//3N55/3Q8/MZ2dnD84ePjg9Ozs7Ozs9PTk5WS6XTTOr6yqEwM7R2L/ie07ulNJ//I//8fb2tvzzn/2zf/Yv/sW/+H6HOlzP8T/btl2v12/evLm4uBCREMInnzz+9NOfLJfLuq5/RK/u888//+u//uvy2rH7f/4//uVitrBjatLReJ/hZMc7/e/YKMqOd5hkh1qDiZB6P+/ufw5jyxibdG2nBVjeME1rQAP7P/4//6/rm6vyqz/90z/9V//qX33XW3EYOeeu67bb7e3t7X6/7/u+bItluTGz976QdMrr09PTBw8e1HX9w/xUAIBXr1795//8n8s1EtG//tf/+rPPPvu6Nx/PmYPPkXMehnh1dXVx8ebi4uLy8rIb+pRFppoiJCTksXADcSSa3nPKkIl8CLNZffbg9PEnj84fPjw9Pavr6iDoMu5X37h2/vqv//rzzz8vrz/99NN/82/+zbd3gN67rvK3iAxDv9/vb2/vNpv1drtt2y6l1DTNgwcPHz969OjRJ3XdEN9Lih7vbMXhH0lZRxf7zWdye3v7n/7Tfzo4W//yX/7LP/3TPy2+xQefPfhUYIfJO/0Kj982MubHia4GNrXvnMTCoLSVY0JG4lGrFo525vvxLefbf/kv/+W//bf/Vl6fnp7+23/7bw/O1u/VwL9HOP2BR+u67j/8h//QdSMHbcBtb5u2a9u27fshxphzVhPnrap4sZgtF4vVcrWYLSo/J6hypJxBTGLuhmGfcid5KF0T1uvd7c3W0Lx3o5reOE/Ie//w/MGDh6dVHZBgu9nutvuhGwBwPl+sVidnDx+EENp9f/P25uXTl69fvX57cdV3PRKenZ/95I8+W54tqiZkSV3Xt23XtcMnp5/94vGfHK7r/ai9qio9dLItg6aGpB+EqDTZ9cP9PRDhv860F7uuo0791LptbK+tRRfOABVI0CmSGaQ49Hno2q5tOzVgF+aLeVPVVTWadyQMPhw/7JOTk9ls9u2fLuJh+o/9QGOM+3arovt2v17ftX1HRCKSU47D/YgxLpdRdGFmNRGz884RMxPfM9s/9M6+ZqSUju1rXddnZ2c/hCl9
iKxEJMa42+12u93t7e319XXOOYQQQlgsVovFomlmIfjfFT182zGf3zOZAWE+my8Wy3vzeT9Jprfg/c/fe9s3L107mPNpJh1m23GBzjt7QXF8DUaf+mNh9OF7+UhBKIRwenr63hu+9sTeNSQi0nXdfr+PMaaUYowxxlIkU357ONrBqy4vqqpaLBbvPZdvs50dn8B2u8Ujyv1yuTw7O/uGTx1/VlWTyH63H+Juv283m816vV6v113fR5FR32LaK0a7Xoo2xh+PJ8zsqirkPPfBL5fLYZlElNnNZjNmhlL5gvQNV2dmIdwvc+fc6enpfWD93U27mcYYEWEYBpo0xURy33eqEkJYLpcG5rwLPrBjYjySgDpEI6PF/fAEvu6UDsBDGfP5/PT09IPVcTjPA3wokvOhc+shkB9hODh0ZEcR0Zz18I77lUDErglNVdXeeyJ+z486/upvM8fq+r5+hIhOT08Pbujvz7T/zv3wu351COF4182aopTWM10fu7FfI6gDQ3ZJWCwYJiAhpwSKYsUOGqEYKJqBoohCSnnoY0tMHBDZkYMxClcBAUUlRxwYEYBALCeJqhYkiCUjBQZkM9RsKcah6/b7fYsE1cwPsa2zZ7WsKWkfc9enLsnXt4fZ7fZXV1fl1o1tCoiKVl6p0cN7LG2EJEYPfRpHeNZ7pn18LlAaC4qKjj0LD39MRTWV8gtFJxSSUZ/y+m795s2bt2/fXl/fmtq8bj779JNf/uJnp2erZlaRI0Do4jvyud933ONyOae729tXr1+9fv3q9cXrru9F5DW/rkJVV/V8Nl8sFicnJw/Ozh48fPDw4fnZ2dnJyclqtVoslk3TVFVVVXUIYUTbPpiOX2dLfqxxLzVhAABt293c3HzxxZd/93d/9/z5szdvLlLK3vuf/OQnr1+/+au/+ss//3M+OVkdQx0/5Hx+iDvyQ8Y7OOk0RaFIP76LnX3TQSbU6yMG/zudzHQcVe267urq6quvvrq8vCwhe0oRkYqtUtUQQlVVRFQaS5jZZrPZ7/c/+9nP6rr+0EH8fY9DBDnktNntX7169eSrJy+eP3/14uXN9fV6vY4pZZEkmkVsgjARoMTefIxcTuB5VdeLxfzs6uzu9vb6+vrx48c/+clnn3322Xwxd8zF3fpOV/jtQ3aYHujhNYCpWkpxvV5fXV3d3Nys1+v9fn93d3d9fS0idd3c3NymnB+dP1osy6IOVQg+jEJgcI9pjyDRDxn4teWORXZV0rDfb++Grk0xSs4FoAUroqU5ixTsh5i7tmvb1lQJ8CAFDoREXDXzR48/O3v4iHCGkxbIFK/9CLPrHyBkf+/gP3yree+AKaUudkNxwXNOOVuRbjcCY7BSeW0GapjVsOgDeO+deeeqGC1GRMCclYgJqa7q05PT2aKpqkpE45CHIakBU2BXe1cDAGFryjmDZB2G1A+p7yICiygiMrP3zgfnIptJafWnGs0QICPK+Ofd3e0d094P/XozytlwkePlsW9t0UwjOjbwxdGe/hQJpaOMxZF1f2eYmYppkdJVAZPSW800axpEs5pmoMH8ts9Xt5vXby6fPX/+5uLN9fUtip7WTfvLX8wRSD8FPCHnFCHGBD/aMFUdhv7m9uby8nKz2Qz9kGJMWTrpNrZBIMccfJjP56enJw/Pzx9/8sn5w4dnZw/Ozs5OT8/m83ldRlWFqqpCcM47x8yOmekoh/d7GuUhlIZ1wzC0bfvmzeXz58+/+OKLzz//4urq7W63SzEZQN/Htu1LvPXTP/rs/Py8IMA/9gm9d3r3k+QA98C3X6XHGbACBh1kOqYY0VRzSiXlT0QhhHcC3ylkP/7SD/Ov3/bqfsfJWoxxvV5fXl5eXFxcXV0d2kwzczm9w9eVID6lVKy7mS0Wi9PTU5wSmf+QI+ccY3x7e/vq8uL58xfPnj57/vTZq+cv1nd33X6fRUQs5Zwlq045eDAsfcfhHaPLzCGEqq5mzWy9Xu92u+12t9u1w5BylseP
Hp09OK1CwEnN7TuNDyDDj4z3niYiqlrOue/77XZXvKi+72OMfd9vt9v9fm8GQ4yAcHN7u1qt5vN509SzWTObNVWovJ9W9Kg1RkeezNeewDev+o/8tvxENeeh3VxfvX6yv7tKfXtv2gFENKcsqmLgHBNx13Vt25IhEzlGcqM8PbjQLE4Ws9lisayqGhFgbIZN753k99ud3rPr3+Mg32bRvYexfddv+Z3vT/mgUxJTjCklMGRG04NdB1XLklOKAJKyETqPyAyIBMBmnCJNSBY7H6pZPZvNq7pKMQPELABixMG52rkawJgDkQNAUR1Saru+btvSbQYZQx3qWd00TYpxiCopdW3nK1YQA40ppphySipfryGfs/RDLPn6ojY1toZVUDVmJgImIrKCvpUGQkAEk8s5wYmH1Pxo2oubfKTfYFgkz0VMczH1JlniIDmKaZ9tPeirq/XnXz3/6snzZy9eXN/cdm0X1M7Y1TF+tpjPa18vKqoqAczyTW3Ov9ODN8OU0m6/v7m+Lv57M2uquhLVGFOMMcU0xNh13b5r913bDX1Mab/f3dzenp6enZ2e1U3jS+MX75vZbD6bz+fzxXzWNLOmaaoiG38kffCjDwQwg2LXr66unz97/uWXX/7mN5+/fft2t9sRufOHj8oldF3//NkLEd1td/948xd//ud/9ujRo5OTkx+Yd//IdX1kzRbLCojvg/C/c9j950cpz5KLxBFIwpTzvm23m81ms2ma5tGjR03TMPN4/GlzOFj3jyViv/a8v+UocZiq9n3/9u3bN2/e7Ha7nDMiOufNihYV3nePmIb3Puesqtvtdr1ebzab8sPvfSbf7+T7vr+6ufniq9/+3ee/ubm9bff7q9vrN28vt+t1bHtRLdYxiRS6gB3Us+E+ni2Hcs7VVR2qsKv2+7btun7oY0rStf3V1fUvf/nLv/zH/5hPOYQCWX6Ycv7ak/zeK0hVY4xt2+73u67rCu/hgI7knNuuk5cyDMNy9WI2mzezumma5XKxWi1PTlbL1XI+mzdNU1d1FSrH3jlfMM73JvFHuQvfcpS3qkjs9uur1y9/8z/Wl0+l35omGj0JKrqtBqWTNRFxzpKzeHbkPAbvgkdfoa9UK4uUU5tzNFAkBChMMPzhfuwPjNePPezj3eC9A5aZdmiFfkyl/J3R/Lc8t5yKaY9xGNIwpJgRCBwrjx01ASCLxBhVDIxyMseeHThPgIKkBRASUVNALD1/CHiMgAvZzRCZAnPFHBCsWAv2bMliivt2z56z5rqu0GOzqBcni/1uF+MQYxr6uL5Z55zqrgYCUR2GYehTWn29adciO6qjQLghGqITE1KisTOoFt4MASmUxunvRGHv+FMj/+MdKGy6y4ao45yaYihVU8k5dcNwt+teXW9/+/LyV188ffri1Zurq3bfquS56MxoOD0Z3r5NPznX+Ai9V+Bvnpnf4qHen6CZdV13t767vb3dbDZm0NSNgWURAFTVlLKoxJSyKjvu+r7tWnIcU+r7frvbhVAd9um6rmez2WKxWBT7Pi9+f1PXdWkNc9jTS8T2nU7+vV1jjPyQzCxn2Wx2b9++ffr02Rdffvns6dOXL1/FGJ1zVVXXdQVIMeahj13sXr96nVMqeMowRDObz+dVVf3onsfHYpejHA7i8d5XXn+tIz/F7gdcqCj9GQgi7Pf7y8vLNxcXb968KZ7K+fn5bDYrQl9gRz7m0cB3k/0AH7Hs3+me2JRlv729vbu7SylNmxUV966qKu+dc6W77zgZCjGi67pCj1iv17PZbD6fH+7Pt7EQ3/48P4QrSjObi4uL33711ZfPnj579aLrehXZ7nbrzbpvW0uioqKacs5ZRn9p4s7DPcY7HrCwCgBBVEU0ZwGgum5ylpubO1WbNTORz87PH3jvy47y3Uzgt3YFDs6ciPR9v9/v9/u27/uClKSU0gT2xGFIKfUx1rc3dV1XdVXV1WI+Xyzmq5PlarWcl9HMZ818Vs/ruvEhTFQb
Ko8SJzjzYH7e8yO/7vyn9xRISiXFfr/Z3bzevH2mwwYtuVEPkPSojTeMExjBUNmJczkE8B59Bb5SV6vm/c2y9oEh6bAEKPbGkwvsA9HHKw6+/e39Tu//8JJzzsMwHOogCqX00DwNAFS1bdvNZhNjNDPnXF3XbhqHjfQ9ktN3Op+SEZMsOWfJoiIIoDjqv5YNQ0Ull/6QJIIIRUWUADVnSeVPEpHiNplITikiQE6SYpYsZlziWRUggrHloWd0KCr90MMWRcVwQYDk0Fe+qisfPCLmmNtdZ2BJlRwbWM6S72tHxvFubRUSkiNUA8BjaeOiPgo03Sw8hEfHpBK4n8FHgfv08/sHD4BqRlREYk2pGHYgQMQsstnuXl28/fWTl188u3j66vLt7brrewCrHS8RT7POY+TtFvd7zBkNiR3iD+d/jWtJVba73fXV9d3dXdd1hFg3tYhCikjTykREQnYcqipUlQseEYYYc5bdvitIXble59h5X1dVXdVN08yaZjFfLBeL1clqtVwtl8vZbFbXdVVVhcD4LXHg94x6cWZFMgB671W0bYeL1xd/+7/+7vPPP3/y5MlutzOD2Wy+WCwQIBUYVQyQELlt+4uLSwNru7brOpH805/+9OHDh8ex+w8x8xMb/Xdc2ocb+tdt8e9bX4SSRiktU8Hs5ub66dOnX/32t0+fPn38+PF8Pi/JbHSoqqAAU4uN42qu34crU8zkdrtt25YIvXcimQi99yFUzawJPjjnQvDeV6X6NMZIRMXG9H2/2WwKzerHyhF88wmb2X6/f/369a9//ev/8T//5812nRnFLKc8xNj3gxlUVSU5x5gUCRh88N65wzy0sc0BFUtWqODOMTEjYkp5s9lWVd11AyLlvC7+v0iezxseOT2/r3TVwbjmnPf7drvdFfQ1pdT3fdu2XdcNw1CC+JJt7YY+VCGEUFV+vb7zwVUhhCrUddU0zbxZLBer05Ozk+VJM7nsIYSmaeqmrkJVTM63p/sdP43xLwNTBcmWB5PBNDIIG5IRAKKNjczMxiopAEQgFUyZJDGxM3LG3rhy+x2pSL+Pu6vZfFXwYlcvwvykOXkYqtlYPWU/aL1/j1GeSN/3Nzc3u92uVKMVf7es3KZpiCjn/Pbt22fPnm02m5RSCGE+n89ms9ls1jRNcxQyFdf5cOePA4by4usuEBGn2gGzieuNClgk9MtdVjMxACQgQEIoDdqzgcYh9X3q+zT0KWU1AMky9AOYDtznpCmqJCPycYhdNxCRc6BmyESe2bNGSTnLbi+SkcE7lyVB4dw5JqKUJQ4RmZA9B0BCG7vTvnNF75h2MxjXJcLYgYiK+DiWHIMZTj0JDw0J8Di0OZj26d/3Yfr9fw0KII8j72aM4FWk67qb27sXL1//9vmrr568fHlxfXO76bpBVRxaRbhQOwNYpMjbLez3MEQUI2b8EbR3yprXkhm9unq72++zZGLy5syymalIFhEVVS0wY1XXoaqYuDD/AZGYDCCnlEUk51IcxI69c8H5KoSmaebz+cnJyUi7my/qpmmaZj6fMbG8my95bxxv7qVYuZ1G1/UxDog0n88BcL/rnjx5+rd/+7+ePXt2c3sLALOmAUQRFZHC7hcRACR2pjIMw9u3VylHRFRVEQXAxWJeuF3wLZbE72l8nT378Kcxxuub681m0/f91du3T7968uz589evXiFi8dJU9cCKL87ku0H7wbTfX+L3sKXveV1xTHy0MQ6hCg5ZNCEAMznPjtk59t6V/euALuacvfcppZxz27YxxsKi/33f/AJTv3nz5m/+5m9+9atf/fbJE2Vanj8w1b7vTXU2m61m87OTk9vrm8vLSwCrOJyfn5+dPShWs7AESoEIEsUYb25u9rudqDrHq9UJAO52+5Tyvt2XyODm5mYYeu/p4cMzRJjP58z8MYDnxxnlJIdh2G63m826RO0FnN/tdm3bHrgOWXLK4iQrKDFW6J1n772hxdinHNuu3bnd+m69vl0v
5suqqkLwzC6EMJvNZgWim80O5ufrsirHc+bjcTMAmIImlESaEYWBCBQMyaAAqlA6q9D4gdLyc+Q5MwITenLOSPbSXQ8bgLRHIvaVT70R+tmSQ41Ty9nvSmn8gcPMhmEolJSbm5uDaS/rojCXSm7r4uLiyZMnm806peR9KEb9+L+HSGmkO1WV98F7PyaYJ2NvRSv/g7tNxM5558S7rE5BDowgM9OcUhx6AjNR78RzYKoAQbIgAVjp74KlQwiMhYcqOUc0RJKkOalmI9T9fhfWdzkPPvCQ+iyCSM55UzRNKUbJCUAdsyTp9l2MSUQBwFRTVEAEIi+BvUMaRe+PL+Qd0y6qOQtPnQWRCIDMUM1UbQKMizlHRLRCeJ8ajsER5IVoY3gPAPChpItZaWx6gFRNU0y3d+sXL1//+vPf/vb5q5dv7263+yGmUvjuzILJHOyMcJki73aw22s/WBYI9GOYdgAAEen77vb29urqqus6BCw7lJrmnGMqfnxWNUR03td1HUIFSAU9ms0WJycrUxv6vuu6XiSnLCIwQIdQ7n+hFM3n88VisZgvmtms2PuT1UmowrGWxdcPK8HQdrd79erV61ev37y5vL29bdvOOXd29sC50Lbdy5evfv3rX202W+e4qipil2Ia+iGLFK2J4kbQ6MBZ3w9v316patt1KYuo/tEf/fT84cP3hIx+f3vuR/Dwcb58zMIewKKyqQH0ff/q1avnz55dXV9fX11dX12v1+sCa7dtOwxDCXvHrPDYg0/KBlh88ZIgnmpP8XtZ9ndGoTuUc1AVwtKWSKGQzabc4eGWllEyiCXyKM5BKZkrIcgPPKVvHjnn3W735MmT//pf/+vTZ89EZHX+0DuXU273eyJ6/MnjX/7s5z//6R/96u9/dXd3BwDNrPnlL//4T//0T4vdEhHVe9O+3W6/+OLzFy9e3N2tveef/vQzZvfq9WtT2O93IfjTk9O23b9+/SoE/vTTRyG44k3+npyY4sMVKGW9vru9vdvvd2V67Pf73W7XdV1J5RYQQiQjAZivqnByenJ2erpcLVQlpTii90l2u12376/pZkw6iCBiycQtF8uzs7NPPvnk008/LYyK7+qfGcBEWAayTJbMhEEJgUb6KI3NephHrBQLb84jhdJLDn3FoXahdlXtfcWVonWSgYkVRbOT1IskUSEi/Ac36gAgIvv9/ubm5vXr15eXl0XwqoTsJQonor7v7+7urq6u3ry5aNtWRJg5jJRGV1VVCKGu67qpm7pummaxWK5Wy1K1NFr9qgpV5Z1j56iYqA8iB2YXfAUKqICKZDEnQSx0MomxN0sxxhCG2tdVNZvVDglLDSghEwGhMfmRXAkAJcgXRbQsmrNINjDdbNYC1nZN1XgDUUsA6HwgcKAQ+77vYxoGAJQksU/9rk8pl/B7Ik5qEAlN5YN3nt/rKPguIF82OLWSHpqC7gPVqLw+Knk/guLhCHU3Kz6fHQw7TkcrIbuW7uEwvk9N+76/vb19/uLll189e/ri4uLt7XbfpygmBgYE4BHmZqdmDxQWKfJ+p+1ehshZEfBD5+u7zi2YPMftbnd3e3tzezsMA5Zmc2CSc0yxEJtFFBDY+QIUeR/MAAGrqn50/uiPf/GL4H2Mw36/3213+/1+3+77ru+HvrAu1bTt2rbrtrttKZNzzldVtVou66r+FqbdynZ/dXX98uXLX//m8ydfPbl6e7Veb/ohMvNqdeJ9iCnf3d29vbrOOTezhthRjAY2GfV7SaKJT6aqKklv79aihsTlatXs7PS0iAR8mCH/MfZfnLavj1/q+KajNXhwE22cnVZS530crq6vnj1/9uLlq7vbu6EfUsqmGmPe7dv9vhuG6NgjUkyxbbvtdtvu97PZbLVcFfdl7D1MhaEz3pgfcm3FtA9Dn3MujbqJgJkA0DlHxIVte/CYEZGIHaufBiLK5IqV6qbDnfl9pA/6vr+8vHz27Fmp01ssl02MQ9t1bdvu93Wozj958Bd/
/ue//OM/3m42v/n1r1NOohJj7PpOVAhRVM0MZewl3vddCYJVpZk1v/jjn4dQxZzWd+uh77o2zGZNP/Tb7eb169dffPHFarU8Pz//4So9H706mIDftm3v7u7u7u7W63Xb7gulvIxhGA6Or4qYqqmCGqoxoGOufEAE8T5LzjmnmFOSUvJTBAv2bRuHgUoRzWz++PFjRFwsFsf99779GD9ASFP1FSAER03lHZe9nAzIkGCsYCLnPXtPfka+Ie+RnZEDdOg8++B87aqZb5YuzIgdciBXs3NmZqJWaEs/VKfud49jD77wUW5ubi4uLt5eXd3e3SEAM4uoiA5DTDm1+3azWd/e3m53265tRTIAOOeKy4vlyplK0t37UFVhNpst5ovVarVajqZ91jSzpiAps9l8XvjO750YITKzc6zeWfYghoBgVvIqACCiiEIkmdXZ2DcTkRw554NjIMs5Yu/z4FJ08ZASuN/lzFQlxgHbnUFO4tkhMhJS8BV6cuw0Zyhd4hUInWP0wZzP5HpAVFXLBkOpX2NiIn6/6/F7OqajbS6RBMLY9rZEOwAKQKN1p4n0DmBgB+1IOE63j9HUyFsulwTl81bkk8wAEFFEN7vt6zdvPv/tk998+fTier3eD1kQkRG1RFEVwInhA7OHqouYqG2t67TofiDZj7ERmFnfd+v1+vbutgA+SASqqhZTGoYhpiSS1ZSJQ/AF8HHOmxkiN/XsJ5/+5H//3/7305MTAGvb/Wa9ubm9ubm+ub6+vrm5WW/W2922j0POuY99ktQPg3MOCb336/VdKZo/Fgv8uvPsuu7582d/+7f/67//97/56qsn+30bh6RqSORDIOcQQESGGBExpmRmKefyKFVVVAs4ZWaqRQRDyRGzH1K+vr3tY7pbb5IoIBFxVdVF0+N4ovzwGz4d6R7+nvyMb8In7VCDW/KMWDYlLYDem7eXL1++3G733gVCJqKYdLPZr9fb3bb1rnLedzFeXt+8fPHi7eXbn3z66S9+8YvFYlGFACUrA0j47cvgv2kUfHsYBtECDxAzOucQKfjA5AEQjMzIjAAIgQm55N5CCAWWLABy4cz/sNP5plG+aLfbvXjx4vnz5zc3N23XNbNZ7PvN7d2+bft9++DT1S9+9rM/+ZM/+eOf/+LJb7+azWfb3Wa/3z1/+bwbuuB9mVEGo+igiPZ9d3Nzs9ttRfJiPv/lL3/RNLOb25th6Nv9frffhsrHGAFtvb77/PPffPLJ4z/7sz/7XlKS3/YaU0qbzebm5ub29na9Xsc4FD5dsevFESmUOhUBNVJDUU05dv3QdkMIoQrOcfAeCEoNb06SUu77Pu9y3Ayb3SaljIB1VaWcHj06H+JQsnjf5XyPIU9AAiZzZITQeF7O6ip4JjQgVUwK0ZDYOefrpq6buWuWXC841MR+SDnmnBQNHdbLsDibn5zXixNib8gqClQhkEmRDvsHTbqVNbLd7i7eXL589frm5rbruuADs2dmQBxiuru7e/Xq5dXV281mbaoh+ANnrsiWGZiopJy6vlNTBERC71zwYd7MZk1ThaquqlndzBeL1cnJg/PzTz/77Pz8fD6fv5fyQwQCY0LHpM6ZVzQ0sBC8DwRgREB86KhNhkXi1YVQ19XMjDxnSTh0aejjEAd2o+OBAMbADCZjJJ9zigMYiPPsgnOhCiF4H6SuwZQdS1YwJHSStK8GU+r23UADQipCMJJFsnJWoZH+fhjvmHYqgMJR8D0949GVO7yeiHKIUz1AMe330OKUoLhn1B2ioEILKTg+wJDSZrt9+uLFr7/88rfPXrx6e73rNQshEjGhKoGx0cL0XPE8w0nShhL2vQ6D5KgqP9Y0VNV9297e3m63u74fVBUBTTWlXLTncsolX1uAoBJwM3PZc4vvuV7fOaa6qrzzq5NVVVWnq5NHjx5tt9vNbrPdbfdt203MnVGYLMWcMxMhUR3qb3CYEUHVYozb7fb58xe/+c3nz549u7h4I1mnZDFC39Ok
NQQAhXsCBmp2zOWxgx44TJtHkULIuTShF5HZfFYQYFU9PVk1dXPIvPzwYQexj8kbHIv2xjj+eMHhkZ3Fo4+P8JGappzbrl9vtuvNdt/1KYtjZO+982J2d7d++eqVD+H84fl8Mb+5Xb94+erJkyeXF29E7ezBufeVdwEAoCQrR+fBfqBtL4akEOO993VdheC8d0RcV3PvApHz3ocqhFBVIRTIvTA8DsH68cP6QWfzjaPo4F5cXPz6179++vRp27ZaiPptB0Bd38WuN1Embprm5OTk4fnD88eP7tZ36+2mZK+qqvLeTyQ4FJU4xMJNU9Wmac4enJ2fnzMzE0qOKUVVIUIRSWnYt3R9fb1erwuLrczhH/caC/m5AIRXV2/v7u72+33OaRiGvu/LkszTGDMLJUgxSEPcbXZgNgxDqHypayfHQAQEZiCiSVOWnCWnnFJOCMhMZuqcD35UVviWy+fwpnHS4zQj1cwUgRyTc0yIMVsX86CYgEAyCQxIkdw8LGbEqpDSsN7udm2bDBQZ3V012zxIega8PH3ofZUsKYCIkIipwcRI/v0Z+AOCAgAistvtrq+vLi/fvH37tsyWEIiIzKDv+tu7u4uLixcvnm82G9Ucgg9hzGscaqvK4Ub00cxMTYrlTDmmvm0dMxN7dqGq5ovF+ePHNgH+71+mFZKtQuEw4Nj1oHBdAYEI2bkQqqquQghTdUsIoamrmSmipcpH77zzzntHHr137JiQmJUhoyVEISYuqo2qmlFQhVQdIrEnrpoZkhuL58jnKPvdPg7Z+VBU4spVmsKR7uvX59pLkm+06DYl2AuQgIV0UJw6O3Dgy3783t058OuKg1A4CIf/wf1JgZrtu/7N1fVvvvjqb3/1m6ev3txseqSasCJiBQFWAqtFThQfC52nfBKFMeWYLCXNWU0Q1X7YJCyXrKr73f7m5na/2+eUi/eas8YY+34YFQfNELGULYVQMbuSUwGAYRjevHnDhA9OT5er1WI+ny/mdV2vVsuf0GcGNsSh6/vdfrvZbkvJcgkdbu/uckolc/87l5OqDUNcr9cvXrx48uTJzc31MAxETMiII45Wolg0O+iYwrRQS0HOAZs5PHcAUNOcskguCeDtdvvkydMYk5mpyD/65R/7c+8cf48F/26m5t07DyMef+xzIBYaxrRgcUqql2gQSxuCYtcJCUVtGOJu1663+303GJAbebV1CMEMb+/WZs+22935+fnZgwc3d+vnL149f/bs+vp6uTz541/sFovlrJm+SAzQgN7NNn2vUaiOOWdCqqown8/rulLNzL6p51VVJD9DUdRhKslTizE6x8UOHRI0v1fTXlyQ29vb58+f//3f//2zZ89K/X0chn7fglqMUWIaur5t9zllH/zZgwefffbZ64vX6fWrYbtpu66ZNU3TFGYymOWU90X3YRiKN/Do0flyuRyGoR+6tmtTjilZSoMBmCkRHhPUfx91/OUy27a9ubl5+/Zqu910Xacqoz7JmDjQfKTnSsiIZAbDkES3Xd/x7Z1zzI6cLwTIUtPuiSin3A1dysnAiMiRa+pmPp8vV8v5Yv79myRNKXRVVFWTpIJjBs1w16a7XR+BxVcCIDC4rqvb/pF540rSrmu7y6vr283GGJWwS+LC7JP1+qfd8DN2qxPWHMUcQiYnNlYj/96HTSPGeHd39+bNm0KgAwDvQ+G7ZZHNdvv8+fOXL19eXr5RlcViHnwoARUAmqrkrEcRi3OeJyOPY8bOcs6SsqlKymZGzq8329lisTo5OT09PVYvhuIrgIxl4KCKZmhIxN75UBETMXvvQhXqKgRfefLBheCbys+qUJuAZitFEczsPLMn750P3jmXkybK5QaXajfHhIimJkkSCrF4b+zYhcb5KoTae+/Yxz4yuf2mZedLwdrIQR+D5wILfL1pP5KNL7WYAAetryO9GZtahxHeV7YfHLEDUA9T8UC5YdNni/BcKWHP+7Z7fXHx1dNnz169vri6WbddL+Is
MzkyFABFqADPDB6LPRY7y1ZlMYdZxFTMBEwPJYffY3odXpeoZb3e3N7c9n1fPBNVSCn1XUmTJxFRs/LMCr1jLC5FLjmeYRguLy+3201TN1NFxliV0TS1D957f3p2dnJyWr7uzeWbly9fZsnbzXY2m61Wq9RH049cyzEhIOexnmq73fT9oCIIWPgRYFAosnj0pA6riJXN7MDGOmw0hTMkkkXHClkAUNXdbnd5ednUdcmtqMiDB2ez2ez7Wff30sPFLTQDNdU0pglURKe5hAAjkjWqIB5UziY8aELkc8o313eXl1fbbauCs2aBCE2opny2dH2vKm3X3tzdzefz7b59e31zd3sXY0o5lapFNUUg0JHbqQrjMv8Bo2R2U85E5LmazWbzeWNmzvm6aqq6rkLtvWfH08pEg5LV4/LISueC4qL9XqP2vu9L2qhgDERUSnvjMCCgqhJiHIbbm9vdbmdqDx4+/Ed/8o9evHr54tVLEamqUGQbqqryzplaitF7753v+361Wv3Jn/yjX/z85yG49eaubffD0BcPMqahqPd45+6T3D9q6uF47XRdt16vb29v1+u7rutzTjnn4n6V7y1gyZT+KBPQEzlEVsWUNEuKMUEpCmYiR+yZnSMiFc0pD8NgaghYZC0W88VysZzNZux+QIGujTVupuUvEZWuH4Yhb3vZDhBWD5YPH3V52LXbNvVt24JbR8Fh33f7rs+CLoRZDZ7ybp9NdvvN7fr6bPsohKowSEHF2e8x43N/KROya2Y5pXa/v7q6evPmzfpuPQxDCNWIXeW83W7fvHnz4sWLIs9cVcE558bGIVgCzSyComUzQ5w6mU0oIAKoGYiCFTJEkizIeT85kSVVf3x6zrsAFTIBYdkHtAjNMpN3xbSjY+TSLIG898FXwdc+BOecgjACghb2NyEUmjI7dI4R0MSIR9X50jZFRc0AkBGJmbN3AF7ViJicd77yzoOhr6riHzDxsVEnGM8Fv8G0l3GAwg5hd+HBH2tzqoiZFr+BCWHqgjph8pN1H+uZ7ylXpf7aVE1y1+5vrm+fPn/xxVdPX7+92bTDoCpoYFklImAGMLQa7JHap6KPs65Eg2hStbFST8EETRF+6KSMMe73+/X67u7uLsZU7pSqDkPsun4YYoqpyAshonPsfdGfQkR0RUezqgxgvd1uNhtEOBYXWy6XJycnZw/OTs9OT07HmjcmXq1WzHx9c933/XKxOD05uYk3WfM3n6qIpJRTSikVvUlTNQAZn7VOT/wI+CqxyHEkdJBmm1K5SVUMlEbhTASAnPN2s33+4nnOqRzOOS7C+N/JuuM0jozTuPxEJaVUZP6mpimiBjg1UR1V/Zx33jl2jI6nUotRQ0tt6PrLN28vXr/t9gNTWK1mwbFjLA67qiaNMQ6b7RaurhBpSKkfoqlVPiDStJUbYWlbVLxY/KFw/BQmShZiCtVY91jmTyi10qH490UjYrTuZdoc3LIYx4rEwxL7gWf13hnCRCC9vr7e7/fz+Xy1Wm02m5yyiqaYyJDZBefSEK+vrtZ3dznnBw8e/Omf/dlXT588ffZEzZqmWS6Xy8WyaWrvfAHz2/1+t9vv9/sHDx781V/91S9/+Usi3O92XbtPaXDOA5R+mkrERTz1uGTgxx1l+rVte3t7e3t7u9lsS4eeA/xuIyW+5EGyqjK5Enoxeyr8c0IkUFNVTVm0ILdTagh0RCRBgZCYuK6bxWK5mC+auvn+vZdK1bFOXbUQCh92iHm9bvfiEi+XJ48f/fwvtt06vX0RNzd928J237Zpd7cbumH18NHJo0erBydc+2az2XcDIPd9t99v6rr2PpDjgvX/QF/2O16XDcOw2W7fvn17cfFmt9sVL7ZwWkuY9OzZs5cvX242G+fcrGm8844dIgGUAjaYcpE2RaUIUyE70WhBSwRfHpVNPz9mpx7PNx9CE+acIjkGQFHLogaAjtBxMfmKVvRmAMYKveCD946ZQAVASgUOmCAYoiIqoRGB
0hRym6oCKgAoIZoioGpGyZSyM7AsRoR6IOkR81Q2UyJ9g2k+lhrHD5jI79a1v3fnAcxKIumeulSGmpGCISiOJUOHiP/+oyPu/o6Co6mJSE6x3e9fX7x5+fri6vqmj5nDrF6c7EUH6cxUQNSEgBaAjxQ/E/tU9CSnJmdWzVAQg9KSDsj0eyglHud7VLWUXtzdbdq2FVEiNks55b4fuq4vNYUGyvd32BXCJEykoWJjU4wIwEQ554I0l2natu16s569na9Olierk+VyGXy4ubu5vb2NMQbvV6uTh+fn65v1cafzD0cBnZqmXq1WJyer3W7fdX2x7hMqNT6AieVok/865hqLutPhwmUcijjKPx1MvqqmnAp59csvv8g5ApiIPHz4sGmad031t8rM4YF0YZZS7oe42Wxv7+5KPXHxo8eIDZEQD2oV45jNZk1TN81Y7+KcqXZde3tze3d7F/thtVh59s5zFbiuXIpxvdnsdvt91xU3aOoSQ4Do2IUQRPJ2u+26LqXkmBGKiEMpgfs+yfbj1MNkqdQxB+9LbA6oSACUAEHN0ASs1Mlw8b+J0Tk+KM8WR+q+QOBHYoweXhfzVuTTiejhw4ebzeb169dq4+Jmx3Vdhaoqy2S/3w9xWMznP//5z//iz//8+vp6s92Y2nK5XK1WTdOUYrmccz45y1lSSo8eP/7Lv/zfPv30k3273+/3KWUmns9mIgoGhS14Dy5N48e9TCiu6nZ7c3NTnrhq6ZE2xujTKLBByfcQERMQGIICEDKxcw4ZASHlGPMgAqpiAFqEOg0IRryUiefNfLlYzueLuq6Puwh+t4H3eShDBDBVSTF27bDe7rU6qVYnYb6iqsbUsg+r0/PVKViy3MVB923Wk1DNVsswa6qmbpYPxHwWZg4ANvQ7hIWnGo+weHtns/8RxuFBHM9eVd3t91dXV1dv397e3gzDUIqNAaCAK69fv37z5s1+v1dV770PYSoqUSJidgAmood9zsBADHRMkKvoQdmfCpeouABMRFRobB9CROx94IaYEUkz5KiJREwMQA8JbTNRFJnieUI3JmZVLYtm0aSawRTv917LkrNYllH/X1UgGTimstYLIyM4z4hkIOPTKNgQOArBhcpXlffBEZOqjElzURUxIfsGGh1MYXZp6gqIYKqAdLTB3UfnU9GRihlZqb04HGR8q6pOkoHj4zQVyV3X3d7dvXjx4snzl+2QgPzy9EEEbofY94OiKYigNoYPlT5T/KnAoyzznINkMAMkJTZmGNsOK/6AAKugprvd7urqarNZD8OgZkSsajHlvu/7vs8pT/4TFs4EO4dI5V6VssWUkqmKimPnvNMi3wEmql3fDzFe394AWN3Us2bWzBrv/Sh1uduHEE5OTs7Pz59++eSjJ3lYEohYVdVqtXr06NHjx5/c3W22232J6AjLc4GDbUGbGqRPM+w48it7etlYAYyoyDm8E+6XmH6/2z3p9vv9rkiFFdLIvST7dxkjECeyb9ubm9vXFxfPX7w87LZFP/KQOfPeV1W9WMyL2VitVqenpycnJ4vFYj6b1VWtKje3t28v3243GzR79OBhoZ5WlVstm65vEW0Y+rge+iHq5AB7X1Wh8s5572NMd+v1breLMYL3xas2NKMfmmw/3D0w4xFIRAAxyGpY+K0gGdADOgQ3GgaESV6TiJCZVKnMtN9HLDs6cCkVdc8QwqNHj9brdeHcmBkx+SrUTVM3ddd1Q+y7vksp1s2Dn8yaP/+zP99ud8+fP7+7vSuZy6ZpnHM5RRFlcszeOf/pTz79sz//i7oO2y+/bPddzsLs5rNF2UHKc79XE/r9jHKZhRu/3W5Laauq5CzvmvYRLCkmgEuQVbIzYIy+ClWoPDvuh67tMeYogmrFshvAaNrNgMnNZ/PV8mQ+m4VQvZcF+24D7xNQYDCCIl2/3/chPKhWJ1xVMccsmYhOTs6Xq7N2199e3eG61yFzHXwdkICZzx7/bLZ8lDLHOAz7q2HYew7O
K05s2n+AUZ5yFtlsNpeXl1fX10XS2/lARKratu319fWbN29ubm4K3FjXdRUCjpRSA0A3CiBmgLGMa+JxaUFQEFBKcFXw6gLOIxIzOYYSjE3J5cMoFDkiRiCJGl0iiipqZgoKZmhohgQkZFqakI+EOzUz0ZQlZkmi2cxGzQFAtbLbapEMMSstTo2QCpm/qqpQB18F9qyAlBRB0bSo3DIhF+te+xAcOxQBMLWiRZOIEL+pPcx0pTgyAxENjMDgA0EYm8rYSrdmnmhsRyvTjhtmj4UfiCnltusvLt++fPni1eXbu+1O0ZH3q2pOvt5vd0PXt6lXNU+2MvujjL/I8EmWZRYn2cwUQJg1BKsqDBU5R/iDqLSl+mK9Xr+9ervf71WNkEquMQ6DZAFA5z0SGaifZGoKvCaiANkMVDTnjAAAysyGWJSGEYkZ1UxyEpUy/0Rku9seOFYw9mw+OTs7o29E7cqOE4JfLBaffvrpH/3RTy8v397e3PVDLDD2aNmP1mdpxgXvctkKse4Afo5IFo6OXUHCEbGglDEmE2HGu7u7p8+eFQ5LgWTrpsHvvlcNfX99c/P69cXT5y9evX59efl2t9+nyagXuyZTb9PiyjR1PZvP5/PZYrFcLhcF9pjNZwhwfX29vruTNDSVL5JCKfYZKFWAoPNZPZ813nE/WBYt7FPnfVVVBQdvu/b6+vrB2dmDs7PlYlFXlR0hk98Xk793jMprJio6GT4wMjgH3jvvHbN37Jk9kSN0iEzIAKSMKedhKPVvxTPG39O+a1OvuaLIXXbSks4cC4yYYo7DZkgxEVOhzRb/9eTk5OHD86u312vcNHVzenI6X8yDDzmnnMUMmVxVNacnZ3XVECEgZpG+6/e7fV2Nmqw2icz/nnINozeZc1HtLUJGKaUS/uacDgn+43QAETnmsSszUslHkhmDeSLvWBI5wGwg43wxuKegISEGF+azxXKxrKpqQvg+wiT9FmNqrYlYQoWUxRBizEMUS2kYhs32rrecc68i7EIzW7JbADRt24tmUdnvN6I1s5esAOx9ZWYRwCSrFqLJPxB7rrxIKe33+6urq4uLi/V6PcTonEMcpSCK6kDpeXGA7kIVaOw1zoW/XMJfRHaOtBhpOwAPdqjrMgMp2pOlltf5AxY4sj6PBiEBOzREscjRk2NAUUMDMqCxXAdNQQVykiFGxz0AivdMOAxDV9oPZDVAZucdO+eK5VeF0rN8bG4OZkW/6tA4nRAKvKCiokPfmuWchiKDlmVgBh841F5VNKuZ5pwhoplJ/nrTfuAtYikMLIkKwvcqHafHg1o4WzhmJUe/oLyplIwgHNKExVHoh2G93b549frLJ0/Xm10fMwUXXGiWJ1Uz29zcdrt92gpYqtnOxX6h+LNkD5POsqBpBhPEzE6rGuoGqxqdL5mG7zrHcFJBEcl9393d3V1fXbVdB4hEBXBLQ4ylBLx2bGZiUsRlqyoUeK3cUFUrRUrMxESiRc5Vcs7sHCPbqGxQOHcool3fdW1bKEunp6er5fLk5OT05IS/hZfC7Gaz2cOHDz/55NPVauWDH4aoKuW2E+OEyoMdShPh3i5MYMoIqBTYcYxRx6oIKApQfd8PQ4xDFMK69sMwXFxclLtWAi927qAu8lHM7SP33Wy/37948eI3n3/x688/f3P5drdrVdU5bppZaUtTlCgKXhJj3O32xwJtddOsTlar1Wq+mDPRbrNRSaer1ax2+9gOcZ/6NkdAiC6Epq7mi1ndVO0waBJGZOd9CD6EkkTpuu76+vrB6dnDsweO2XuHhd7yg6P2crnjQyFiZu85VM558p58cM55x4HZMwVCh+gQCKBIQELKUtcxpaRqP7rZey9rcAjZ67ruuo6mMXYtAmu7tt23jt3p2elU5EYI6NhXUzxa1/XJyclyuayqkLOklGPMZlhXTQhBVA2gHLBs60RUuiYdFL9/xAss4zAnbeLGbzabIkVccvwiOcZ0qAo5dnaJyDEFx56YkNRE
zciUVEmFjLnIbqvZ1HYSp7JfAGDiEKrFYrFcLg8VVvbtqRKHdx3lVAFRAUzFLCtazDmJah+323UPwvtbdhy8NwMkN5vVTM12c7ff3/VDd3OdVicn3td914a644Cmgqolj2g6MpHx3S//scZ7W4RNsrKlI+J2u805MzsAKDji3d3dZrMxs2KAy57gmJG4BB48ZtwVAAqfqaRSStVMeX3QUjFTUbUiIcrsfajrusy9D5lDperaGIkhsPfk2JAMGMBNJD0zBEDNFkm6vjOznKW0+Y1DbIe+jymJqCGzd94F7wQgi4yhVpEpUDU1sJH7N7LWTU1ASlNUUdEUY8fOgVmKKcYe2ULt6lllprFPxWe1CYI9vufva8hPqMZUto5QOs8fSqCP3z1VrMP9Y7OxuA4ndGQ8R9Wyui7fvn356tXT568u3t6S89VsAeSQ2UyZ+OHDR5JEzOJu84D5ZwqfSXqU0jxlFhXAhJCAclXR6oQWK6xnQN4UviO18/hSrO/79fpuvbndbDcpxVJSLCp9HPqhB4SqrqqqIqYhRUSsyvbMhRXv6VB2rCo26vGOYiNqfgoaTLX0zNCC2o8dl805t1gszs7OVsvVYj7/NrW85XF2Xbvdbvb7Xd91KUURPVAiCACJxzzT9IcO9h5AVHHKkkwsiBG0n571WMBJSMnExMxQDYYhXl3fpPybru+B6I9Fzs8f1lUNcM+U/Lq4BBGLwOfl5eXTp09fvHhxe3vX94Oa+VAtFvPz8/PHjx/P54sQQoyx77t2PwqExakll6h1XZ9Vt21HnhyTQ5wFn1PnHFesrnGrZhFzamOvJs18OWvqxWLR9rGLmZgLy5SISpyVc+7a9vLyMoQQhyHGR3XlgyshNf9AJt2BzlP0rabceUmle+8Cc2DyRJ7QAfBo2g0VjHlkEeasY8liebo/aiIUpw4uqloUt/u+n4AcLNBo27ZFBuSTTz75i7/4i5///OeLxQIR2ra9vLx8+fJljHG5XJZeR2WjJCY2JlEVEJFhiLvdvmmqk5PTn/70pz//+c+HYShJkL7vy0fKzTlEAj/aFQIAgKr2fb/b7bbb7Xa7LS6jaj5oAdnEoTsy7chMjsgREpTUp6GIphhRcxqGOGiKBFbgbBsLkIpeHHnnF/NFWdQTW+J7X9S4Pg3KPiOGgkYKKAqxbdP1Wz/s/byumznOV3lIaRiMIPVpfbd5c3FrNjS1A6DF4gxUUGPuc06DpQ7NqFR4wXePj77XKKjher2+uLh48+bN9fV113WmpReqFiGE8oy8H8PrEj9M7qYrwGFZW67Qa53POReOFDPnlAWkRKdqpoW3rcLMSFzV1XyxWK1Wxet67/QcMbMHNFHKPgzOO+acyaHz5JAJAUXBFLKJ6SA5xSENVQzeO+dylIP0kZkxkyMXfDDCrIqQQRNktAwKBjiWyQGgisWYUbTQJEVMRDUroBESmEkW0eyCmy8XTNzW7X7XxiFLFlXVqN8YtR8lHkahTZgq245GecPB0yn+QCHdTLNjav869mSX4ou1+/2rV6+//OrJy4vLu213+nBWzZdmoAY5ZQQ8OTlFxX63TzH91PTnlj5N6SRFl5OoCWEiGoC1mbnTB7w8gapW4uJzf485BgCq1nXd3d3tZn3XtjtRQyITizn2fdf1nar64OeLufe+HwZRDSEwOyJy3jV1TUjHdAwkEpXYR7AiUnv/ZYxMAFK8ADMkKvXxJ6vVg7MHy+WyqZtvdupHL8H+f8T955IlR5ImCqqqESeHB0+CBIp1V/XOjOzIXZn3f4S5e5tUVxcKQPKghzo1M9X9oe4nAwkO5MiaAJDMRETGcXdzU/YR1W/ar9fr/X5XN7W2bQ0MiBLUuHyc0CmcdITjjI0ZHAUQB7ydthQ4JWHhpPgUJDIISesNFggpbff7+4d1Vdc+L6yxWZaZ0a7ro43xaFsN3PUY43a7vb6+fvP6ze3tXdf1gOR9VpTlfLG8uLz64vPPl6tVWZYxhqZp66pWZe9D
daiqqq6b0PcxxijcVlXdt0RwvloUmQl964RyK1me+yLb1/Xh/fs+BZ8XxpCacpq6sdY47wYakgxNl8Th9vY2hhC6PoR+PpvNpuV0Msmy7Lc3KYeR7WhNYYYJ+hA1DFpCRzprB6OuS/rJjo2KGNmQGSULPn1cPzbk9Rjd7/fHvaEFfd/3WZatVqvPP//8f/7P//mHP/xhNpsx836/f/vu3cuXLxFRARBqBBxjOh4UInwsl61drpbLLz7//C9/+Uvbtn//+981uitxQAGDn1ymRpeWiRraVSi+7/uUwiOANB7xpDC0sow11lpjkVDG/i5ziiGlXlBijIkTOWe9YyREoyZ+KIhIzrrFYjGfzSflxFoLvwX/OPbcFEGWUkJSWD4CQNu28f7OdkURJsLibN63XVdVgbr60N7fPrx9ey+pmUxcWU5OT1uOvcSm7/vQNRhaSxY5oYznBj7+oZ9sHWOGnj5936/X6zdv3tzc3Gy32xijCtRoFrjf76uq6vt+0ITPc2stjPIX1lprBgIYIVmD1hgig8giQEjW2JQSDIQ1AyklgMScmIkMGZMXxWw2my8W0+nUOfeRI5dBdEQIwAZ64xSTr3I3zngy6rbHUVKKKUrsOu5sn4fee++MS4HbJnRtiH1CRCFAIGccWnIAJBGYIIhETkkQ2RlrjSXAlCRKgpiG5ok+a04srGe2mlX6LHPWTaaTYlK7bF/t6/rQhC6omefjC/moameRNG5CLThgAOs/ivDfmtqKICh0Hz9MeUVAuRoj6DSEsF5vbm5vv3758u37myg4W53aPE+A3loUbLsuRkZyxphlVlhXvGibp32c9a0PnaTEIAEwWBuy3K2WxcWFWyyZXEyMKfy4YdoPrKEBUtfVw8P9/rDv+w7JIELXd1VdV01dN7VuOwDxWQaI2oVWg21rrLOWAFFkUpaz2awoS+fcoTo8PKzbtokhCohC8ESYoon2A05NT88syxaLxclqVealNT9NKtMYPcghOGetM4baNqaYeIA+gW76D86a+j3jX32srOUDwRQAwJKx1iYaVDbVFw4R8zzXmqHr+hAiIjDLZrv76quvrbVkzLNnT0+Wyyz7oOv0PQWKAAN0IWy2u4fN9lA3kZPLvBCFGGMKdV3tdtvNdjOZTop8CZjneT6dTPqw7Nqu69q6btq2DX3oQ6i7dlcdru9u+tA65511IcaANC2LybQsJ5MAGBJv9tWu6vrIh7qNIVoifU0Jdaszc4oppD6GPmh7YL/brVbLs5PVkydPTlYnv67SOvYdYfScHkL7sGSAXqNFcqhxHQ0CgegsjwHRkFHm37Fq/4RF++PrUoE2nWse7bN0Co6Ik8kkz/Orq6svvvjiX/7lX/7bf/tvV1dXzrndbr/b7Q6Hfdu2akzsnJLZoio7Ke5IBISTGs8XRXayWjx7/vx//a//NZvNzs7OXr9+vV6vmfnomVSW5a9GaH53Hf8SZlYpiKqqNK6H0I/p+JCEMEuMGgLEWvLe50VRFlMCSCFKDEQoCEnrFU4MLIRZnmXlRACR7Gw6m02nmc+scSCwWCwWy8XgnfgbHxsiIAmgsuAA0SA5Y7yzVR+atoYYOo5AzueTut7nO4+ATd1y3xrBpk2HEB5u17Pp+yKfcgwpBRIurXHW4FHV7P9M5T7g1/XXIl3f7/b79++vX716fX//0LQtAlqHzCkErqrDfr/rupY50ZAWk/JeCQkU3m5IR+gICqRGFUAb1CFGNX2FSSROzEmENdg7Z6fT6XK5nE2neZ4rbOjxpw0hxLYnQYkSYxIWVPU5l+c+ByIWEY6se4GZJSbgaARB2KTUS9/F2IXUJ2CByL01vbe+8MYZZ03yNnobQyIERMyzPM9zIGKEqJZVGs71cDao4s2g1AuLpjCGyBC1TTeZTjfr7eZ+e9gd4iEBfuuw+lZoZ3nshaWN9aPEjW5oRtQ5NSAq34OFWSOHALDecRHhJDHqOCHG2LTt++vrr7959frt+/V2Nzs5n52c
ROEokhMZIYncNx2YJDEujF1Y/zzWF10o+5ZiFwUiYo/YZ1lazrPzs/zqys0XTCYmQQ78y0L70B7T1lBVHR7WD4fDIcZIFhBAe3dKx1K36RBCXhRFUdrSmbEVb406BiAALBaLzz///OTkpCjLh4eHV69ebbcbNYjUjkUIKcWYGAwRGQMjBDfPssV8sVqt8jynn0eP0XpOnWHn81lR5E1TdykMUB5BBFIk4NC2Oro5HRMvgGNoPxZY+tfqaRhCSEkxLE6HEW2rMGZBIh3Dv3n7BhCUR+edtXaBo9jRd5cAsEjfh+1uv9sf+hgByTqTQLrYt10bYm8tOWem0+Li/DTPnbMGimIEcUjoYx9CDLHtuvV2c313W9WH/SEVWe6zPNSHVmIqDZqMXA7kmi7er7ddH9VYObJYa5wdCG7MSVJKKcUQuq5LIYlIW7e7ze5ktaoPdVlMZ9P5b28Ma/03oNGGxyJEBnHAzSGasRVvEBQGSThOEJ1zMaahLfJ/pmGqoT2EoNPHruvm8/lisQghWGsXi8XFxcUf//jHf/mXf/n888+fPHmivMeu63a7XVM3SqfU6bs2VLuuFZ33Go9oWDiEUB2q6bREoovz8+lkcnZ2dn5+/re//e3LL7/cbrchhNlstlqt5vO5Hri//boe98D1g+33+7oeLFxHUZrHKY5iZPRdMN5neVGW0ykw9F0rwUgKAiyo+JokAGjI5Vk5mbCgMfb09PTs9HQ2nWfehxCLPF8sZj7zhggUzfoz8pXvtEUVGEtyNNdmQAOGKHMm987UMXV9n2InbF1WFGXlvXaBQhsspGmRpa4NoT1s9/fXN0WeS+qcocJ7M5s6AuVhw1CZ/Z/aZjrlZZGmadfrzbv371+/ebPebPo+WGuJBwmKw+FwOOz7vlNpCRpEXk2e58YanUqAGtg/KjJBhAiss8YMZrbGGmONNaYPoDkmIVhDufeL+Xy5XOqg/bvnVdf3KVUEBAnaro+RAdAZl7ksd4UgJmZgVFoUqHOrJGFMSVmJkEJKIUnPkjiE1BvqnEVET6ifQf9htIikfvSCkERSDCwQE8cUJSUy6I03hmRovqI1znufeZ95H0KYTEvrHQAkTl3XPW4Sw/dV7REUDIgibIAAQP3UkQBZ0vj/FGM0cOWUk5EAArCwQGKICVLUft92u7u7e/jm9Zs37677yNP5sihLawwCsah9kkACStAe9lRVy6a96MOqaSd1a2Ng4UDYWmozG1dT++zKPX9izs9oNkNjidCoK+nP32HjCqGvqsNmu10/rJumZk4cIEk4VNXhcGjbVjUNACDGmGL8IESjZbtTzx2yxiwXi+fPnl1eXU6ns8PlxdnpiZ4jSitSk5hDXbVdqzGVU0IAn+dlUS7m8/ls/j1qxj/w4bXWv7q6+sMf/qAdLY3OKSWRxClFCHqVWrTTB3j8o+H7I/Cw/p08Csg/HrXqqAFx8P+JKemdZlET2FthiaF3huQ5n5yc0Gj9+62PrQP+xH2ITdd1IQIZQWz7rmnrumk4JUIEScJxMZtcnJ3hcqlGtIRkxv6bMIQ+7A+H3W7X1U1XNanrM+fLvNzUzb6q+/b+YXuYzKYPu+3tw26zr0NU5GPhLJERaw0BgKSkbhIh9l3Xd52ykGOMTdMeXLUvD03bhfArh74jQhPHASEeh+5IhMiadyEZBIvah5cRAz98p/pGg/cuxqRKfJ+kkP3WcxlhdPrcnXNFURhjPv/88/V6fTgcptPp+fn51dXV06dPnzx5slgslNumCrh1XfehlwFRwyISY9psNtvtxhjrfTadzrzLNYFWd7Wu69Vn8/PPPy/L8vLy8osvvri5uVmv11mWnZ+fn5+fH0O7fCJ9Hv0AOsRVFdsjHh5AVCVaRmypiCCSNdY573zmshwVgp1CkgQoQNB1TdM1kSMiANFQ6RD5LCsnk7IsMp/FENVaNMs8En1gqvzos9B17Ini+Fu1Xgc0IsBJwIAlzL2bTaEL0Ac4xNTVqd5Wa7qX
GLlrZlmeGXcy8fbqZD/Nu651Dg1Lv9+13mSLeW69QyFg9a2DUWTsgzPSp9puQ9UJWubplP3m5ubhYd00rbCo00nf9SGGpq7btpXBv4qIyFrjvSuK3Fg7fjKQxByTyOgeSmgUTYQIIMRkRJ0cGfHIIMWyzOfz2epktVqtdKt/d4O1Xd/wQVVpQhtC26eYrDHOuMzmZAyDOHKWbI8YDDKjAJMSjBSsx2wFBAkIDaFJEJuuI2ZMgjjQ6JSOxBIjhz4IYRJOKaSonttBlZ4FgNkalU1S0HLmXZZZr/A+mapvWdeHLvjsW7iBj0N7SpEIgQ2iCKIK8eo/LEiCzMAAisJWJ3cFWCJLFOklpcQQGFOEGEUkxXh7c/fy5evXb9/frtfT5XKxWBnvAMQSCWDo2tgGCQIh8mbn9rtl15+17bRufNsip0jSGmoz200yOlu6F0/d8ys6O4GyEEMqRvBTZmnHDTb8QkOaVh6bzXq9WTdNw8xJuOuDznfV3nHgAmpda4x24tU4UEeDSt1dLBZXl5dPnz6dzWZ96M/Pz5q6adtW3WBfvnz5lfuabu9EOKbEKUVmFPHWTSeT+Ww2nU6d/R5lwO9dGtovLi5ijDc3N8oSiTFWVdX3MWmPiIdC/EPtPuz7sXR4FN11KdnsWMETkXpxxhgUp2tGIxwYYdWb9Wa/2wOn2XTivMvzfPIDSEAWfRlZPx+iSQx129R11XaNJCbAFPrQtWcnq836aZ75zDtCUHMnbz2RAcEG2+pQNVW9fVgftvvEkYCc88b4Phy22wMhTubNoan3ddf1KTGTGetmAWMIhROLtlL6blAOVx1wROLEfR+6ro8hJhWA/A0Lj0nViGVEVJ0l1fz/AIkHQKWL4gBERS3TdfYjoJ2X/yP1lIyafc45bYnzKGS0Wq3Oz88vLy9ns1me50fcpY5Lq6rqu55HMI3+oqqqzWabZb4sJcsKa1gEkqSYuGmarms5JYU9L5fLq6urzz///Pr6+t27d4g4nU4vLi5+e9X+OEzKiI/ruk6bcI/B8DCEZnwMJCIiM9IxjHOEhMayeAEe/CJq7EVFI0VJTSBASKr3oH1gQszzLMsz5x0d1SZ+4mPD4y7CCIHRTWQADQCJoIp5EkDmzWxCMWJKgFWQLnITGto6iF76crb0k/np1C/Kol1Ou67vQwcQTQrSNx6nhSWHghyBo3D6+Md/onVsDWqd0DTt/f39mzdvbm9vd7udNolVAiQE7rq2bZvQ9/gBNKfUEuezzDqrgiuGiGMKIehpAgBoyHrvvAMAFu77vu+DJmvGkGIdjDGTyWSxXJysVovFOCj5zgoh1F0rSSRK7AOHxCk5Ywwab5x3HpB6co5MixAMRiGWOAgbqPsEiEWd9JMzJAKp64VSwgiGGCkmFpGUmBOYrgcQIGSQyDGllIYQkQAAez2rgcQAIhlC1a0lQgdWfF4Wk9m0a7rQxSzLHl/ItxvynDhFBAMGAQhwQLmPc3ZmQVQROgYEYOBBOC4mjhyEAzBHwcgp9KHvmrre7/evX7159fL1ZrfvYzJLyr1PzKntvM/J2E6oiaHf73F7mO2rk0N90bTzQ2XbNkmMFhpjDt4105JPFtlnl9PfPfVXpzLJgkWUCAkA5CfFWb+71XTKfnd3+3A/SFjElEJMTdMeW/GPO9Xe+8xn+uqqpaM1VkSMtfP5fLVazReLsiytNYYyQzSdTFLUC+2qqrq7u62rqu+7EGMIIcUoAmVRLBeL2XRWFoUqMPzMq0DEoihOTk5evHix2+20Frm/v9/tDn0fU0wJh9mSYkmGypEGY59jrvrRdPxYsh8ZaGHwqI9uzGn0Le37fiB0WbvebL788kutK6+urhaLxXfvuW4jY6goiizPE3PXtW3TdF2fBtQVgAghHqrDdrtdLpfL5VJzKG8dGdO13eFQ3d3ev3377h9fffX6zZvNZm0s7Xa72XRSTiaGaIMPKUXrfYFwenZmvW+a5vgBcOQ3931ou16FVAHAee+s
99ab45FurTHa3PuZD+R7n9Eol0GEiIIkSEJGUATNcbj+wTpxqJpAUaj6sQeBiZHH8ElRdPrjBhFiPWT1uWuMF5E8z8uy/G7rUkTattntdse+lNqrKCCOmbPMq0hLiBFAvUyi+tuGGGE0MpjNZirDfHV1JSLOufl8aF99qhbFcN9EYoxHu0UN8Me89tslOxpjnHXOOzJGdelYWLQwJ0QDoNZKCDpBY2Zrxm1DJqUUIHiN884N8Hgt2n/i+SlVcsQHjDNRGa8DkESLK2aAZA3lGS3nzloqJ2neMhOQp4m3M4uFgYzEGAQ0s9wLFswJQKwzeZFNnLMgkHqOVijgQJ1F/D8wbtcCr+v7/f7w8LB+9erVq1evHh7WOuvUOY4w96HXA0F5mFo4HRmYens0XlprGYkTCygK3fgsW56sTk5PFci53+83o8al/g1qOLRYLFar1XK51I33vXsMgRANIAsMojCDoXlKFJN3ZMl5Zzwaa6hLNkgfU4gxcmJOgJQABYmNQQ/oLSbiDjkl7rvERAwYA8SeQxdSYpEUUyRLSMjAmtgbQrBOx52gSnYxKoQtsSRhFUiKMTGjsT4vJrM559+2Qv52aFfdlWFiOoZAHgxwRQ3bURhE9XQSJ04RUuKQOHBg6RE4sfSxbZu6Pjys17e3t+9ev7t+9z5Gts5jYsPAKUkMSB4FUpIu9O1unT1sVnV8WofTqp40NYa+R+48Vd5WRRZPpvbJWfbsPH96bk5mISOmhCyIgirs9fM22fG/KaX9fn9zc/PwcF8dqpiiIGhqr4qnj33eNG303rljiHOODHFM1lrV4ZpOp957QgRjCkMIOSCmGOumKYrcaYMvywBQ9ZIM4bScrJZLhRZrlfzjn/9x/9x7P5vNnj9/rp3Gvu9FIEVJsYrSc2IYe/L6khgzmAjA952bxx4gDyKORmkhzNyHPiUmUuVzr2PLEAIgeuectXVdv3r9ehy4Yp7nHxEWNGyDgDGmyPPMe06pD2GAdSrwVwDGxqkCpwnJGat+9inGzWZzfX3z5vXbV6/fvHz18vburm4qn7n1ZjOdTk5Wq+lsFkMfYywmRSY5GHLer9drZXMpViulFFNq265pBw8PZ5y33lvvrCcktWgaz+lfa9WltxSOmgE0+jBCYpGYEMhCQkoIBlE7TyoLJaj+ScAA+AGKhj9rTPvrlvar9bmrItNkMjk5OYExLh73xvHrmbnrOlWcZU593x8d29TdzloHgF2rhrYq2hg1A4hjxgwACtabzWbHonm8XQOO9+dfwo/8yVFG9zhlHwftcByWHBl3+r4rWBYRk1oowjjWRQVqcEwJENCQEtLIkZ4MSBhi5MR5lmVZ5rw3xgzP7kfj+jDUgJEIoZAmRMJjg5wASIASA7MIJwTjHBKZzJuigFmPUQRQfGYLR7kBh2JRjAXrjfFuoGcMuYshThwTgmEKqKNrgA8f8Tdst8c3X9t7VV3vtrvbu/v3799//fXXb9++PZbsevOZuWvbpm1iDDLSQ3T4OdxYABifAiEByTDeSkkn8WdnZy8+/3w6nVlrlMsKACEEGu1MsyxbLper1WqxWEzK0v6AtaAh64xnSIkxYcABJQEoTMw2SUZkAck4QQaDIAaSBQgRopa8gkxIQOwRcgMBJaEklhg4IUUhTpAihxBjCBpfDRtjDRCOHTtDBpFwlKYGFtYmxWjOKQiYEsckgMb5vJyK/5GqPaWoXhSKyhNmtdZUKTkQBqGEJAkiS4yxC10MvYQkkSVBQoqG+i50h2a73dyv72/vbu9ub+vtvqs7AhQw/aGudzube+d9SBzauqkOcXvIdvvVfndRp4s6TprWhK4j7qxpc9sVWZgUdLrMnpy7s5WUWTDQxw5QrDUyop1/zp47hsY02DZv3r17/7Bet23LwoDQdV3bNF3b9n1/9EKl0afEGGsGtTQa31bx3q2Wy8VikfkBLzP8gwAAIfR1fagOh7qqYgqEKMwxRBRwzs/ns9VyVRSlMZYQ089u
h+mp55w7OztTsngYJschRE6J+37QuTueldZardwVKPPRuXn8rX6xZnh63KtUOzOHEI+9fU1ERESJZG3TXF+/178jz4vD4fD4L5fjy64qj4P5EFprbbI9J2YGFhxSVA4h8JhMWDJN2242u2++efXNy1e3t3f3Dw+HumYQNIYFHtZrMhRjnE1K611eFuWkTMBA1HadPuvjNYYQ2rbrQq/YNGedBnVrnCE7DLSRiNA5k+Xfn9f/nOczlBhIRJbQIaAwxz60+xBSbQlbX2audG5iXWF9rq8yiA6FEnNSp68YYmI2hN+fkf3m9a0utDFm9BT/bq/ycYCXUfJdlYD7vn94eFBc/XQ6nc2mLNJ3oeam73tDFhCPZgXfpa3rjzv+4S+N6x9dDjwa0muxvtvtbm9v7+7uVIfuGNpxdCM8RhcNKh8Y9iIxRSJDhowhYwkIEseQYh+CdarnAzGl4f0iAwIpRjDGea+4hKHnAj8V1xPrjChxREQypDKFiEdwG4I2fpB4aMsnQZ3Ii3eGjBEgQDKWrAEDCTiIsuGREaMRi9YQWWSAGASYCUAMU6BBSfPX769jcvb4wakuwna7ffPm7ctXr66vr29ubu/u7tSK+nHLJMXYh6A6049JJYr/yPNca0sVEQkxoIBzFgEIMc+L2Wx2enr25OnT2XQKiH0Im81GoZ0AoE+hLEut2qfTic+yozjYRzvNez+xZYwhYCcxcAyI5CxlRB7QhCixjcIBUg8hQIqEiYmFhCmx6gDhwFECBk5AMrrIoai9GyEaiZQQ41Aww6A8gggojAikwheZMwZ1ohRiSKEPnaTACiHgxKGPMSRAdN6Zb490PwrtHGMcdjYhoQLmRJTpgSgJWUCEJaQ+9G3fhr7nECGJMEWkDrGumv1md393d317c393u354oMgeyZGFmNqqOmz87HyZlXlT9dWhbrc7Wu/mu8P5oT6r06KNpu8jhOCxzU2Tu67M0qQ0i1l2duIWM/Emcop9RyIATgVS4ad57R+olQAYY6yqer1+UBha13cAMFTtbdv1naqA4Qe9EeusNSPwnMgMkHNETQbn87l3jggH8IEM+Xfbtdvtdr/fNU2dQkQA4ZRi1FHcfDZfLpeFTjFHfvkPvTnf/a0xZj6fA4AaTh8Oh6Zu+pBiSBqJFe40ZCfekeJHzfcAso59uGPqM5ZQ6MmLcEwhpRRDJENHlJPGQhHRuXvfdXmWrVar/X7/8b1/NAjR5ryehiYSBGFhYTE08Cy1ojJIzlhjTFs3d7e3r169+urrr/b7qm6bPkY0xnkvkvaHg+KAuuViNpkYZ0dRheFkH1mzQ63Ztk1MLEqKtSORkaxyx40xzts8z/I8y/PvAdD+jHV8Uqg/BMmCsERJXd+H2HZbBI6+7fM+yzgrJEMSYxkROEmKPMQ/YAF9iOjUE27859OtY2iHMauDR2jNH7lCBc4RGWutytrkA2osK4pcRc7btmuazlmHRMLhaIj+I0X2r47rH++vUUC6aZrb29u3b9/e3t4eQ/uxS3HM9Y/fpdtyZNhL4gSIBo2STlhSTDHEGFI0TnmvwjI0PFRNVv9m5c6NoX38lD9wWaJqol3TtU3XtzFFESByWZaXk6l3VpgBUNAwmCgUWWISG5MQ81CCWWeFyNGgbMbIgVM3suUigCFxKA4Ng7GMLEKACGyAAqmo8yflVSq+Yb1ev3379ssv//Fff/+7QueGnt+3qYkxJSUTyTis0VPXe6+bKsY4TkCEUzJkrHPaRSmKYjFfnJyszk5P86JQn6djrgYA+jiUUrRcLsuydM6OBcbHy1lLVIRgCIBjkBSRxTvvjfVDGycl5gApEkfiNLAmUIQU1SeIhIZMJElIbIbWnBpvG0BDhOAkhSQpgcLE8Dh5A8XaEKK15L1zzgzVDqfIse96iWDQIlJKHEOMIYIwWSLzwwh5nboNTSrQ4A4gKKJjGojMsQ8pRg4xpRg4cUoQE0eJDIc+PnTterN7uFtvHtbb7bre77uu
96CoIUnMbdPg3thFAVIcumq/fbA3m9l6+6xqn7e86HsTYyBunalLV2fUEQWyYL31uctLcp6TSAiIiRJTSsZZZ4z5ufXu8HVd1+12m81mvVGKWozqnTt27XQKMnTjdYc559Copa7KFwsheOcVmjGdlCqv/dH9bJp6s1lrh3nA5SYGEe/8dDJdLJfzxcJnHsd568+6hkfnoDGmLMsnT56EEOu6TjElBk4CAHXdKNIShsIIAbC0xpAhQyMP7vEpe/wzGPH2bIzJ84yI+n6YUxIbnUzkeQ4AIQRJyVkKIod9urm++eqrr37IvI5Z1ItWRhyAgHYXZRw0AhKaEZzojEWktmnu7+7v7++3m20XI4/TtZQopphSONRVeh+22810ojiwMgnvD4ftbqtOkd77ATcXQogRx6BO48LRnDfLsuVycXZ+Ol/MijIn8ytC+/iYxkdEaDgxhAR9jxRM3wOyoOmBYsIoIEhEBliRK1o/DR5SqlhhjGKDP61izRDP9OD4bsn1498ZY+q7ruva4zwoyzLvfYxhs5nEGLuub9s+RVbPD4A0acsYA/9UFv7bexM8ipqpf6sajdzc3BxRKTw65H5oQjwKA25cxljmD12KmFKIfdM3fQhj/i4sYtTkO8uMGWr0IbRnmVb/Aj/22JjVMauqq33ou6o63NzePWw2TdMVRfni8y8uLy4W0xKRGExg6qM44RhSJECCxByZARiNoNXynpBFUs8AaBgNCxgBC+o06hiAkxiJlETAkbPRjInNJ9lfej8Ph8P9/f3Lly+/+uqr9++v1+t1XTfHUcjjQc9xHzIzjDMsa22WZaqXoG7USOi811YWIVpjhRki5nl+crKazxfeZ9qkubu705qtbVt9BMaYwXd4NvPe0w/jrgkJjUFxJIBJHFpgKazzPiOymEBGAXUFxQIkTbwQnSUjFiUJ+mQweeTMMMNQHcckCCSDAJ1ABoiSgGEU+UZR0zUlmwj2cERqoyMQlASxTyISQwCAof8UkwJ0P8qYvweVrYolWp0O7VsdliZOMXZNG3sd5bIgSGLuYt/Hpot3df3+cLhdb9f3m8N+31aH2PWQGIGssEFOnJquiwcw1TRmttpv+s1Dud6ePByuQrwInIXEwJ03dWHamWstdoEFjbWZzQqXF8Y6FoEYCXBoggiTd/hL1OhEuOsajeuHw77tWp34xpTarmvbVg1AAUTNSo7zdZ3ukDFIxMxoTV4Us+lsPp+VZaFKKGN0RxlI85XyiIZaYVDJwCzPlD18FDuU73unvlvffLwRiVQpjJmbpgkhdZ1KSydmbhpFL4euA0QgY3zmlZGvEwUeDiiAoWrXvTW8ZimJXrFzTvVnQ+gVKOCdM9Yyp8EozxgEUAO9m5ubj7v947WkFNu2a7tWlbwGzRAR0NfEkKqrZholrHPGikDow+Gwr6pD3TSqTIHaMTWICfsuxdjvDl1VHw6HfZ5neZ4n5qoZGMxaBKQY27ZNKQnAIGcxDvOIyBrrrCuLcj6bXV5eXl5dzuezATnxC9e3npg29wGVIYDJGBCnllMpRuhTJAaDZBGJY5QUkZM11hin4FQWUJWMX/oxfvanlWOcO56231346ODQVlDTNApTUmyaGiwxc1Ud8ryIg+uKIJo8L5y1gDyZFHXdqC/Dj2zsn+wZ/Pi1qPvATjP39Xqz2Wy328diNVoaPv4uHBHyMMoJjDAuA8JjAsos0vV927R9DEq6YhYF2FvrFGmorSwtBo7Yw5/sdTNzjKHvu9B3+/3+q6+/+urrb+4f1kVertebP/z+d8+fXO53u7rt28BdYMccAycS0NCeGJAxjSWfGERSgxHgBJJInIZzECZkEE5oIpjAguzLUjL9hI/u+W8B02m9fn9//+rVq2+++ebNmzeHqoJHLZnjOqZWSVEwKamupYZ2VT4+6hAr0UU1TgcgC5LWNqvVyWw6NcbU+/rh/v727u7u7q6qKs1ZFRlaluVsNptOJs75H5uQICKiMRaBCGzmpxap
8H4+m5U+ky6GEJEjAiMmABROSnQjckBOiNggWDaQrAXviYUhxtinIKo9QgaIELwLAjFwSMBaWmhfHkljR8QAfbBkiMhrSWmMJSRmjqHXKYbOuEhn6N8Ogt8K7Qo+1DuuLzAAi2iJx5I4hhBDl1JU32BiSV2I++awO9xtD9dV9a6udk3XNF0XUhRUUEEE6TgBozBC5LbhcHuXVQfc7fL17ryun/bdqk9ZTCzQOVdPs2aW9bMsCHfbmtAWeZGVE5sX5JwAoohFMMLEURKGAD8TRgcACnqsm+ph/bDb7dquDSHElLQj1PZd1/cpHStdOpbsakigajAIyJzImPlspsJGWZYhHotgRJQYWZvkDw8PdV2pPoqisglxUk5Wq+V8Pi/L0ljz49XYR+XUR6eelgiLxfKLL74AwNDHo58VgOjJq9H9qHea57lzdhgxJrUHGMbwGtpTCuNRyTH2+iI46wzR8AXMjtBaZwitMUWeaWh338vO18Mxpa7vq+qw3+/rqmrqRlEziKjDjsz5LMunk+lytZpNZ5nPLJmYkk7UhpmoVkqIgECGkCxRnqKJIYhC82Ko6ioxdyNaQrMbTViG2cTYh9efq+nRbDK7vLx8cnV1dXl2fn46mUwG7N+vXcOElBAREqekVHVniGJKIQiEkCIEwQ6oQ0Dt/tFg02IAUMHDGng/2Px80iWDl6AHgGMb44fCqm6Itm232+3d/d3N7e1msx1F4GG/363XD7PZvCwLZg4hdX2wxp2dnZdFmbj33jw8POx2u4vzCynke8P3L+gcfPu74EjI3GweHh70ByHisey7vb095qwyakEeA8wxubEj7Y3IEBFYFDhqcXIIfdt1KUbA4W4oPEO/S3Ggxgwx6Tjr/TmPwHmf5XlMqW7aV69e/fu//tvN7Z0x5uH+7v3bb/7wuxfS7t7eXO+qOgTOAWJIgUbUJYMAK6EJQQiNIDECIoqNwtG6HFEEnbBwSJxiB1knLjBYYC+oselTlext267X69evX//9739Xg+CT1QkiGbpumuYx9uXIQhwOxxDBWhlhdHobZXSqBJEYwnG4l1LU5zubz5erVZ4XzGm/393c3q4fHuq6VsCQvvJFUczn8+l0WpSlteZHVAZYW9FgkDJfZN6W8+lsMZ2dTcuCqNnvD/tDatuu70Uic8cRANiRBXACJg4odgBBsJktC+Xse4wRE8YUmSWkFCMAGuOEACGRQbJaSyML6d1g5hhCjyQMiMQppSgskCKn2CnnPCk4B1F92x9fyLdCu07Oju9JjBGAUxrrjpRiiCEEYDaqq9xHPtT9/fZwv1k/bO7rehP6miWKEvPVBI8ScwAmUmQ9cIjdOmVbOqnrVdVcNP15iJPIyBCMaXMfFrO0msCsgK5PdULr3WTiJiXlGToHg7qHagBJTMCQYvpJ8tuQIzKnELrDYf+wftjtd8rGiSn2MXR9r0YkY8mOR0Uway0ZIjVVTAkBVLkwz/OiyK2xgCAsTMepNaYU2rbbHw7b7a5pmpQ4RT4Ok6bTyWp1MpvN8rxQy6MfCiJ6qBxFNh6zz+ERcz3L/Pn5mQi0TRtCqKr6aMSuCKa+71X4wjtfFIWGN0FOkMa6fXDHFWBCIMTEUTXARcRa751XMEEfexEmg5lzmHtvXT7C6dVVPYTweNyujbeUYt91VVVVVVU3Kl0TAMB7b8g4Y4osn8+my9Xq5ORkMp06Z5EQosZINT5GYG2MCIweTcYojpRS6GMIMUTmFFXgWYAIeXz8A4ZvILj5cdDuM5fNF/Pz0/PPPvvs+bOnJ6vFfFY65+Q3RXYAAEIwBEjCklhQyJAzRMyAsetD4oQMxNQFBOQUJUZIEYGcS2QdIcIwrQA49lV+0yf61jpuJBpLpceF++PqWbebVmO77e729lbxULvdngdXhWy32+13u9D3h4MXAcVyOpd552LoQ+yJ4Obm5vz8/OrySiFmavM6XNWvutXHF0FE1FDu4eHh3bt3erJrV0zfuGOQ4AEVDPAo
udBbMbRwxkE7IhhjtPcOMhho9n2nJwAAiDDBIAnsndPDISt8URRZnh0Huj++cGjgZynFumm7vn+4v3//7vXt7R0AxND07Z5Dk1HcbLdN20lIETlFiCQJmFXRW1CAUxSOMoyXEIjIZejQMCVClhgQY+6QjA0EPUhgFDfI7w2P/jdMQ4biO6XtdvvmzRuViDfGrFYrY2yMcb1eH693lHUCNRAbnXkYQI4FVZ7nSsnRR8MiKUQcUd4xsXduOpks5vP5fO69CyHsttvbsRUPmq3iB+3O6XSaZ49UE77vYpk5QkKyhjy5uS9Wk8XZcrVazsucwPhNMtt2d0BoJPQpNjGBIfE2M2xSBMQEEBkRDZDLyU85RWJrODiOLIG5jzEmiWDRWMsoAkAGxjY5qs4Cj5P7lBgxAbIk0Z5QAokpymD6LokZBBE5/ohfu844QauV4TlFBCFUHHrilCQmBCCS1EfeVe3tpnl339xv2v0hxigG1ckNwSAyoRFkUX1mEQSJIknE1bGIfNL1n7X9ZdvNY0ShztjWuzCZ4dm5P1+YScZV3ex6652bTc205NyKIxFBGZQGkyRISQTDz+W1Ywixqurtdrt+eKgOh0EINoa+D13fhdDH9GEapL3oI8kSAFIMTV0TERFMioKZQ4xV0xRNkWWZs5bMwHSKMVVVfdhXh0PVdUFUcyklQPDez+eL09OTyWSq3nHHj/e9H/o4O1Rinvb2ZTRWV4NCPStPTla///3vUkp1XY9JAIuIcvn6vgcYUUKITkcMx+JUwyciIrA1zLbvuxj7FINwssZ6Z8uyFOHdfgfAlijPs0IlYcmqCcrl5aVS7W9ubj7cdCQACDF2Xdc0bds0fddxStYY57I89845Z+xsOj05Obm6vDo9PZtMJkhGQIOi6DmryviMKAq8Q0DU1hwZYwmQyCQTosolIaYUEzMqcEUZJWNc1zYMGVsUxXK+fPrk2bOnz64uLs9OT4vcjTJVn6ANrvISgJIAGZDRgvEppj70IQFaEiZOQEQENnJKQYJNMbK3QNaipuVay36iqP44NdQMXlUKdMN894sV2yMibdNuNtvr6/evXr5+/+79w8NabbY1+Z1NpwTQ9W19OIwO42QI6vrQ920IHafw+vXr5WJ5cX6ZZcV8PnPefiAX/dqgcmzq1nV9e3v7/v37+/t7Vb85OzubTqe73e5wOMjognNMWT6K68c/lFHlhsgYiygASQnHzBxi34kk1UckAGtM7n3us8x5zb0VsZVnmTkyp3/00rSE8FmemBE3zBGldybNSsMskLq+rZqmQau0G+20pRiBQKJwTBCjKM7ZWrE2aeFDlqz3U2vAFJ340GJV95zScmJmpbcZOSPAYCXHgWb3oY3xqx+EZlfv3r3729/+dnt7y8zn5+efffbZdrt99er1oTr0fQciR3kMbaOowi8CGDK6nXTKrsYwXdcd6Qwh9NZYXzhOHDlZXy5Xq+VqNZlOkKiuKvXmPuwPSb0/BDSrm8/mq+VqUk6ccz8yaAcAGYzNSMSj5D0XbcqblEczhcJZ9h4mDg+GajlUKRwikyEufOkBowTHfZeQDWsvR4ASmAiGkYF4nIImrZhRLGJC+EA1138BcRhDWEdKWgYQBEPknBVhlhQZEktiSXGo3kP4YVNXAFBbz8GuTfVygI3S2YWBGZlJn8qhiTcP4f19fH8vmwO2HRGZ3BtvkkFAg8SAFoFBlAQvwJyQgZPveNam8z5e9WnVpzylQNg705dlXC7M5QVeLNGgRevKnfPOz6a2LMQaodG/VYBFDQ4jA4Qfrdofp/Z93x/ncIeqGjXeYx/6vusVkMyPPJtH5QnndPaGJJwSJyHs+/5QHe7vH8riTdM00+mkyIs8z7zz1ti6bra73eFQtV2nwmYpqfsQ5Xk+n8+Xy5NR7HAYz3/3Y6tCp3b1tcd4OByOSGNtxU+n05OTk9PT09PT0yzLz87PQ4iHQ6VGHYqGO7Yi+76r63pk5jvvHJFBBNHSbWT2
ATqRZC12HbRtyykJJ0TMvCNDKQVAmZblZFKWZWGNRaQ8y6fT6dOnT58/f359ff2tTYUgAjHEruv7gX3Ahsj5bDItF/NZUeTe+8V8cXZ6+uTJk9VylReFUnoZhAw577z3WeYwpqj4suOJDgCooFQcDHGiATTKPhSJCAOXyAzmyd45b63zzudZfnJyenlx+ezZs2dPni0Xi8mktDTYMn+ShYhkUDXPAktgNCYTA0CJBIi8JWvQEBqyRpKwYUArSAL0AVsIj+PQJ/hsetu0uaJbCwCur6+1BFeow3Er6O5MiZum3Ww2N9e37969u7+73x8OIgKAfR9Quy/z2aHCFlpNE40xhNS09YCQArh+f3Ny8v7587vFYpkXubUWzMf7/ueHFvk2yU1xW1VViYjK4Kg1g+Io1RL+2Jk4NuGPLfrH/XnV1wNtziGo0gBz4hhTDIhgrCFAI+iNzX2WeW+tDaEHAe99WZZuoML+9EVpP8kaUkCwNeSc8c44S8KSe18W5Wy2zI30myIak0Q4cowIIiFJCNwH9esAIlb7D0F0RZaTNVgwlrsmrXf9el2Fvr9a+quT/GwJ0zIhiIPSkgxQUqLfsrOUUXx7e/vu3bu3b9+mlObz+dnZ2Wq12u52m+2mOhxSjKoFxQLMAmMaxSkhgrXmcVwfBVt4zAMkJcUKDD8xz/PV6cl8uciyrOu6/X63Xm/W603T1JoNIIJzTsW8lyqT/MEf+fsXMyRBIJPASrKUqA3QBejYZOiTBfYkmZHOgTNiUJANSZ5NcuSYGpvAJOGkeEAchuFD5TQA7xC0VCFCIEgJGNQLlGXAOwHo5NQ4LTMIBig7ARgRw2yYOUYZIAoxacR+fCEfS9akEFJKMaaQhtIQBZz2olFA2HCCvo+HSu7W8c213Dy4h33e9lPG1vvW2SDQobZ0dW6nyj7CzBwCYPLCp0k+Y3iS5CRCwUhikrEhz3g1x8sz8/SCV7PQNlK3Ls/zPMumpS18Im2DMzAIMAvHlKJ2X8PPqdpFhJumfni4v7+/32w2dVWFoITwQVHrKKIOYzY9hEDvNYVUyoHOp7uue//+umma27vb5XK5WCyWi4Xuoel0ut1s7h/Wh7pKrHw4UNqEKnwtFgsVO/yRdqSStd6/f//NN998/fXXr1692u12VVUde/I4IuTPz89fvHjxz//8z8+ePZtOZxcXF3/+c0TEweN81MwOIabEXdfudkNPYjKZFEUxPKeBvKtaEYgIKfkucwTSdT2nFEOfUvQ+X60W3ttyUmTeD1vQ2OVydXl59fTpk8vLS9WAOy5E1GbmaCgH1ti8KCaT8uz05Orycj6fl2WxmM9Xq5VWWtY6QbU0Bp/72Xy2WM62h1nVtF3XCzMDj1muWuIgIgAaMmTRAJKAiALuARU3p1I01jktDKbT2clq+ezps+dPn52fna+Wy8z54Ukcuci/bemprQpazNz2yVguitx5V04dJ0E0xjrrvKKH1JvOeW+zjAFiSOMRMSjMf6q4rquu6zdv3vzjH//48ssvY4x1XevNPzk5ubi4GCSYyMA4Emradrvb3a8f7u7vd4d9CBEJY0r7qqobnM+mk+k0L4oUY5KBSdt2XVtVMQTrXB/jerO5vb29v3+4uNwvl4ssyyyNOt7Df37x5SFijFHdX7bbrTHmyZMnV1dXp6en2oQ3xqgOz2w2UzD/MVrAo8jxLUY7sxJKcbSNEOCUYgxBmMkYwoH/6p0r8twNohGouA1toY1KNfJzrkmhB5nPimJSljOfFX1cI9BiefrZiz/885//u+X2H9VDXL+NIpxS31MkCZH7nvteQpAYAIABkAHBmKnLjFvUdnGQ8qu79devN7e329D1z06yPz6Z/L8NL7LCG3AmOAKDg5zNr95exxH7y5cv371717btarX63e9+pz5D283m7u6uritU2SxEYIahp6h9KUYka1FPWvUSPHI3YExGtcGpZ4ixdjqdnp2fz2YzRKzr+v7ufr1+qA6H0IdBWwQp
z/PFCFjOfobQoTAmJiEEApFEFEJoQ2jarjCGmjocWq57CExAzrrcueA95H6SQwpOsGeyKQELQOQU2kMSFAEEJjWDQiIkQ8aSJaIkCTkyswCr+isO7l5orNGUEQmFOSWMyMBoDbGzKSUAFomSoqQkieFHZu1931WHgyYC6vCu9rcGnA5HMCXsW9kf+O5B3t/C22t7v6WqnSZIxkWg3qVoU3SM6gJHKIRCiqcTSimDNBc+j/gkyjJKngQFo7HB+zgt4Wxlrs7M2SlMcxFB77Nykk/yfDa1ZQnOsY67Sf9lEVCFD/kZCHmFoe73++vrGyW5Ds5sIWrUiSGmlFSMAx5Zb2kdk2VZnmVkTIraQkpJuK7rru+3+8Pk7mE2my7m89VisVotF/PFoTrc3NxsNtuu70NQMfbEIt656XQ6m81ms6nCl36o8ds0zc3NzZdffvmv//qvX3/99bt375S6o/N1GHFDRKRjLWXuffH572az2dWTqxi1dg9H3p1Iy9zHmJgbrWC0q681DRFZa5yz3jnrjbUEADHmhrCuqhBYOCFI5t1kWuaFAoBRmJ3LJpPpxcXls2fPzs/P1Lzr0XUMddUxeRIA550DN5/NT09On1xdnZyczIY1nUwm3mdKERYRQCiKYnWyetJcAeB2t98dqu1+X9V1jEExZsPpidr517dCrAgCJmNAGU3WHpkOeVFMJ7Pz84ury8unT55cXV7NJtNCoWtj9xR/Gtf8sxaC+rNbkRQj94F9Rs5an1kQAUBrjHNeC6ZoyBhL1iCZlGIICQDIWCT6hDN2rXHrulZW2O3tbdM0IlJVlbV29DJPJycny+UyyzJDVtPormv3+/39/f31zc1+f4icCExkVvyNd06FVY11ISbBiCwDghIRyaTE+/3h/mF9d3+/2WzOz0+LsjBmtKj/5Vd47MbFGEf7me5YKWrOGkIwo2PsdDrVKazmxMfeu+KtYMQcyCiqr0N252yWeRGJqY8hqvHaIG2E5J0fDcVBx7oK/vp4oPuTFydCSHlezKbz2WyZZbM+gIi4bLJYnZ1fPsX+8DYvCY0IxCghMAJ0PXc99x3HADECJxBAIUOZyamQbNpgdmjgm7vmb6+3Dw/7FELdRSL87DS7XNgJgaNkhuE3/ZZ0lpm1A/Tu3bv9fp9l2dnZ2dOnT5n55ubm7u7u4f6+aRoiSglUFuGIH9JFhMbY42FrRrvVYysFAPQ+p5Q0XZvN58vlsiiKlJL2bDbrjQLjcQTMTicTdRQsy9Lan9aXTOrlilEwJNsAm9aZOrN1lQvDoer3h25fdU3b9aGHlLxxhbdlVnoJ0LQJkWUY5jEnCUHUX1vHJCCjYwQSoiGkpNS3gVyFMhYEpCeDU7yzcIoowEYsAzAgcEqhNxEjKmcgpR9DyNfVIYZOtJcvDACZz0yGhM4SUhLse1nv+Pae31zD9Z25X5tDQyE6MMYRYGIbEXtCqA31KB2m3nFC4cSeOUu8ErlI8CTIWZAisrAEMjEz7awIJzP35NQ9OaPFTLzFpjVZYeezYlrms7mdTCm3CQEjKxhGGBmBUAzCj49Pjpuvbbv1ev3mzeubm5v9ft+0rWpfxxhiCDFGHhu92orXzqQG+DzLyrL03mnHIMbUh9C0XUypbpqu73f7w93dfZ756aSclJMQQ103VXWoa21ISh8CiGR5tljMZ7OZbrUf+cDr9fo///M//+//+//+3//7f282m67ryrKcz+dqoY2Ifd/f3d0pJPj6+nq93jw8rGPkP/z+94vl4unTJyEEESXFqWCIIt6jDuOPdT8iOucUaZF5l2VeyWfWGQApy3y/2z88bJjFWlMU+Wq1KPJMYBAYm04nT548ubp6cnl5qfYwH41sj3OyvutC6EFEOS2z2UxFJE5PT4/X5ZxDIhH1ygJCnEwmV5eXeVGcnV3cr9e39/dv3ry9vrmpak4xjjbmOqRipfAQkncuEXJKgGhG3Uq1m5jPZk+fPn/x2WfPnz1fzOeTsrRkCARH
e7VPF0aHSs4YK+O2CSGohq4CHaxzmc90fyIRYhSElDiEFGNUA2GdGn6qpQXW9fX1mzdv7u7uUkpnZ2eTyeTy8lL1ZzabzW632263L168WK1OJpMSBgxdX1XV/f39u7fvdrudMDBhZGEBYdnuD23XFUVuretCiDEBgAhkeeEFDJGINF233W7v7m7v7++fPLmcTife2yH5+wlrtB9bOuJVxIC+I1qv86gme6wFi6Lo+74sS/XS0BdBMxsYMdv6ghx1pgHEGMOc+tDHGAdW8/BkjbUuy3IiI+NY9xiZfuljIaKyKOeLk/niLC8XfaQQQhywIiRokmBkDBEhQtcLsnQddz2HXmKEmIATMIPJjbWZyQsqiirw3a65vdvd3e+btgeRh0N8t43XB3nSgrWcAQPCUQoNfu3mTyk9aDv0/l5Enjx58vTp09Vqpd4/79+/32w2MUbnjkAWOI5CjpyFI2nQGKPFACKqL5Gq0iq4NYSg+sR6aBhjumFfDYKDAKA/KMuyxWKhx8vPtB1KMYa2haF7l9ilvYhBzH0eA2939WZb7fZV23bMyRma5H7qbZllJkInIDENrR1CEjYIAMIgaURisqiFbkQBEkJJqJW66sgYRAJz1D801loHIgxAyMZYcEBIRElYUmRJLCExpCT8UYn4rdDS9V0fWj3fRERDOrDFFCgFrGrYbuX6Vq5v8d0tPWzdvnJdsAyO0EoCjEjBIWUIlaPaYA1cA0cSAZwIzoUuElwGOY88S+IA2drkKUyyeLqQqxN7derPVjSdElEsezNPRkwxKex0SllunCEAAk6YMCVhEkEQQ5gMtT/5zPq+3+/3d3d3b9++u729PRwOWqAoQj5G1eblcQxCR4im9y7LvPfeWvXNRkVnCEDXB2AGgZQ01w91Ve13O+89gKTEmjfoHlbWZlEUi8UH3fgf+cDqY/3NN9+8ffvWWrtcLp88eXJ5eTmfz7Uiadv2/fv3b968efny5X6//+abb0RgUk6tsX/K/lhOyufPn7Vtt9/v1TKeVZ4cQC9ah+7OWS1rMu8skbM2y3xZ5kWRZZk3lhbz2a6cpJgOVZ1iiKEnAGcNA2oVfHJyenl5pb1c9wPKzMISQlRHJha21mZ5pq6XRVGqrUhRFEe+0KDQIoKA2kZz3k8m08l0VpYTFQVcr9f7/SEmFQxgloH+reRZIiK0Kjmm1A/rrGqEXV5evvjs+fNnzy4vL3KfGSJgAWF8VLV/wiUyyJjjqGwqLOAzaw0iEnPiBALj5EQHPqzZmLUJALx3AD/U3PnFS1npyvYGgNlsNp1Ol8vl+fm5MWa73e52OzXYULd4gHNEbJp2t9vp8f3w8NCH4LPcem+cV+aMpNjHSCGyQAiJRQYWmTGEhABd31X7fd00Nzc3NzfX6/XT+XxeFOrDLSqfLfhrArwSQLShNYy6ndM4HUI4HA7b7bbve+fcycnJYrE4OTmZTqcwog10HVVOAUC9ZFJKAHic2jLriAdZILFgGlIHFZcNodewpPXAz8HGf7QQ0fmsKMqinPq8ZKAuJAX/6PxMwCSmkAAjxACQpO8ldBKCxAQpYUogIgSGfGaLnIpcOk5J+j71feTEgNBG2be87WTfyaLQUSGJ2hj8hr2fUlqv19vtVhsnT58+vbq6ms1mNzc3TdO0j6S7VYqOxxJTRm7bcAqNOI/jn6uKWkpJc1xNCIwxi8VCT0JtOGmFo/NK7b6oSK1WDpPJ5AiFhh+HPggjR5QehttiOaYQoQ+MFKu6Oez3++2ua1oCodxaN7HiMHUcuhi6ELoYehBBaw2StxhFgkjS5rlEkIiQUCIIS0LgXkQJF6hjEaMt+yMgYBgQEpIxJGAVZybW2Mx7SSxRgEWSfFTffiu0xxQFkrXWGoui4wYm7qlL0PXx5j69v5N3N3i3NtuDq5usjy6yFXIkDthg9BgmgAuRQzIHiwcDFUoUQKEluFMw5zGd9nEaJBck59iZvrBhOZGnZ/azS3dx6pdzKkqHaObC
xpti5jKHxUSMBySDaCwQJoQIYAmdMYk5Wfut4e6jdQTQQdd16/X99c31+/fv7u/v1ba1j0FRCEMLjkVECMlZm3mvXibeucxnzjtOqUuJUMUFUIQFhAiNHXyeUox927KkrutAwZasALoUU9Km5WQyXS5XZTlRovxxR333ox8Oh81ms9lsmPni4uKf/umf/vznP3/xxRfqxIVIIfS3t7evXr3613/917///cv379+/fPlN5nMAnEynLz77bLFYfPHFC51XNU07KmyKkk20jakqMXmeFZlHQmdN7t2kKKbTsihznzlDZj6bqkFtU1drkum0cJacd9PpZHVycn5+qaPZH0qKFa0UY+iDmqjykVxuxln9x3Rq0Wa8UnXBkMmcJzTGaCslm01n796/v7u93R32TV0PTjMsMLhcCKh4iLFIRIbUm2exWJyenb148eLzF5+tlqvce0OIoMhVGC3TH8Onf1OcV1S/CLdtU1U1ETrnhvFPFry1xphgQ9/3oq0oFaJi3ZAhhIiITdMoe+e3fhoAANBzkIi6rjPGaMv92Dvx3l9dXW02m/fv3z88PHzzzTeHQwUAzvmmbe7u7hR/XlUHtG4+nReTmc/z0Hdd24a+SylYIiSyaIhMWU7yInfOG2MQoK6rFFMI8ebm5s2bt8+ff3Zyslos5s4LIjwWTvoVF3Us0O24EFE1o969e/ePf/zj/v4eEZ89e7ZarbSg1Ha9vmUKvtlsNqpso4V713Wa9QNAlg2wEgRU518RMWT6vu+6Vqt2TWWOwekXXsTg7EDWunFAJsx934cYVG8KyQmYlDBGSAkgYgwQAoSg+GbkJAKIxposM3lmc19YmrZQ5Lm3VvT1IIiAbcQmSkjDlFuIAH/T0Eerdq0WFovFs2fPzs/PVSN2EKCylkcjgcdavxqtFU+uALpjYiSjjsLxZDh279UkUNszdV3vdjtVHtQv1q/Xyl4tYX7Inf27y5LkRlQygFHQmiwrnZ+CyZNQCKnvu9A2qW0AGMBixtykdtdL6Lpq27eHEHpSjyHnnDFdihz6yFFSAA4IkTBajJiYgYfIPJhNkCXVqkORxCwpJq3mQUCjOwkkZE6MgM55YEQ2BIbAfFRWfbshjAAIZMl6iwAG0IJg28W64c0uvbuV9/fmdk27yvadDYlYaMDyqh1tshQLwjlBA1ILNUwNYRJDSWaJlwEWkWeMzjljCTLLuZeJ5+UEzla4WsikEGuRkND4PEfjbJ7QGcmz4c8BDSGKESKRhDbhkMR9b2f7WPsBC1f14ebm5v379ze3t5pd9kMTnpOwqCyaCAJaYzLn8yzLsyxTX0ZLBiGmJAxoCBCVycaabxI6a7y1CUGSBdFnI8yJiaw1nSpLW5tl2Ww2Wy5XeV7Q6DHzQ0mk6mop6scqZuTs7Pnz57PZLMsyREopnp6erVYneZ6XRQkAt7d3b968KYry4uKiyPPnz5+dnJz87nep7drdbt91/eFQtW0boxXhlFLfd3UNWeYz74s8m0lpDHnvijybTorJpMiL3HtX5F7Z+Q/3D23bVlU1nU5ms9lqdXJ59eT09Gw2nTqf4fdfjAwjnpQGJF1MSKRKfxrbAT6IvetThpEKgiIooOxhsAhIg+pOlud5PptMHjbr3XanJ7Jy/EDJ70ia61pnvfPT2XS5XJ6dnV9cXlxdXZ2cnBRZrmr6KIAjQgc/cKAGfONvXBqrmEVFf621CgwwSAp7McZYZwFA+esCEjmFGPs+hNAjUpZlQ4nzidoJmiuk0dQkz3OFOeiUR6vVY1Dc7fZ9iGVRsMj1zY3mxDGlIi8n0+l0vsjLMvR937V916YYde7BDEgmz4ssy5x3xhhCtM41ddPW9eFQ3d/f39zcXFyen1+cZ7kftK1/7dKIDmNn7v7+vus6a63++vXr10qw1qbX1dXVxcXFfD7XWlAZKBrXHx4edBihdfzhUDV1AwjOOQASBu2oDjzglJhlt90ZckXRTKcz7RZowgrfFgb4mWtAXY70
VBbu+05PAAQUocQYGJAlMYJopa5WMTAImAMAEjlrvPdFbsR2wUzLLPMmJYwsIpC0fhRM47sFaPDnggK+f+mUR6uIY8DWPzkON4/IBo3vzKrSO+j2H1kMCl5WzZ/HHXsZSbyK2lHgpPKGrq+vtWQ/TtkVX6w5q2YAP/NBOGPQO4veUiZUoJ9ks1k+maFxkTkJiKq1ABuOGEJqQ0f1PjkJsa2rFAIIqAK+95n1HvouxIAiyAmFDQAQMIBA4hSYEyMDWiFCQQW3IUKKzElSTHpODmNHIIDBGB4QrXUIhtAaMsYY537Y+Y2sQWNclrnMOzRGwLY9HPbtm7f87gavH+x67+rO9dFAQhBBZAIW9Z1mL9ExlYycMEaMgIkxoRUWiOyj+CSexRmkMsNpkYpMygwmOSwmsJhLnkeAPkbog3VA1hjrXEFgKHqSwSBDXaIYCWjgNTMaHlyUv3/LKb437vf7t+/facmuvB09wvQLQAAFUQABnXX5OGb3zjlrSAkwzChIZBApxhBTTCkConASjhwJQbwl73xe5DGmvmv1bK9NQwBEZjqZzucDNn4sEX+wzao4Pi0+1ECpqqqu62az2aAtbM1qZfI8m04m89mCBf793//j5vrmm2++Pj09VQGm09OTqyeXXd8d9geVkjDGeG+ZNZiGtpXD/uCMnZYFL+aEmDlX5L4sskmZl2XuvXOWzs5O6rpp27br+q7rU5LZYnl+eXVxeTWdzqwZKLyP//vhGbAM3L8hY48GLIxuY0SUErdtowJzZAjNoO0JipJjGc4tAPVdnc1mPsvKsjxdnaw362OXeLPZNHXNgiJ8pMJ766bT6dXF5WeffXZ+cXF6cjKZTfMsI0IQPjouIQAKaBvse6/iVy8cWKrU933oo9FOGghzCiESkUtuIGVphBsEU/u27RSidXSO/yQa8oov0Q/W970WuDrfqetaT8ntdnt3d39//3B7e3d9czubzcrJ5Pr6ervf9yGQMc57n+VFUZaTKecphCKGHpiVNxGTJFb9YBrSNyIArGcVMIe+2+12NzfXD/dXbfN8UhbkLB47kL98aesVEauq0l7UYrHI81xxgu/evauq6tmzZy9evFDkfFmWfpRNPGacTdNUVaUw+81ms1lv1uv1erOtDlXTNm3bdm07evckEUaKIQROUlV1WU5PTk7Pzs4+0Et++Rpn3SoIYECB+n3fdW1MEUUVTyQmQQYGAEEWYhZmlSLWgwwEQYwxzuVFTiZn8bNpXmSmj5SCKKFUEIVQgFgMoEOyMHpC/rqHII+EzlS9X2dPat5qrS2KAgCGUiolACBivdYjwVj78ACgJZCefhrINZarGYeIbDabN2/epJSUNHRzc3N/f9+27Tg9QW10nZycaNP+x1FNj5czzlOR2dK7CbqZyRd+tTCTaWLoQiuAxlrvnUmOYiLuQ10dQgo1kIjEBGKdzVzmsjzP8sJlWRIhJBJABiMISEROUBKnkEQ4JUkAAIZAFZCQkDBJYOYUGRVRhgiIwpAU2iI4CnPbo/pW9iOmrmoux4DMwAQUuT80cr/htzfw5r3fHrDqKDKJCAmj6hUjiBAIISMkkkhRecySmAEZgJXZTihEiN6ig1TYmGOfm1B6Xk5lOUt5kQBj0/YszvfOefLOWhutRzEcWYbTlwAFHklQIyLI90tjjl+DzLHr2u128/79u+ubG4VaHDNrIrI0wBS1LaioV++9syq4bpAQUPT0NWREQNmEAoOMACHG0Ku7NhujHjA0Wk0wc+j7Is+Xi+VysTzqxuvH+6FNNplMVquV8kT3+/3Lly9nsxkRPXv2/Pz8vCwneZYZa4q8vLx0ArDfH0JIddXs99WXX/5jMpmcnZ0SUTkpnXV5Xkwm0/l83nVdXdNopZhijE1b2wPtD5Oqni3jHHHoQ2TOZN46ZwVgNpudX15an6WUprPJ5eXl1dXTk9OzyWTmffbj58HR1FUGiYoEgOqXBQAhhO122/ddUZTL5XKxXMxncwNDLgXD
VgbttBMqB9g667x1ZZFPJpP5dDabzsqiJMQHkbZtQkgAjIacsUpxfv78+YsXL1bL1WRSWufU9wVHt6vBLmvcUvDp4jqMJr9jbD6ChnDo/YSkZ5+11hpnrFXcsBoUwSNDtk+1/EDCDjqTCiGcnJxYa7UM2mw22tWczWcicqirpml3h30Xwm6/b7uOQYy1LNy2jatrQYwxhL5LMVqi2WxaFGXb96nt2qYNMRlCbZz0XRdHBm5d1zc3N+rhMZmUxkx0DPrrortGDoVr9H1/f3+/3++JSG1C1GJYsQvaMe667ugBo+xW1X1arVaa3xwOh/1+v9vtt9vdbsAe7Pb7/aHaVdWhbbU/FGKKbdf1IaYkZTlR2PbPjyI/smTQUuWmbqqqbpsODAmggAmCJJAAEIH1AEYYTUsEAPW3hGiN8d6kwp0sy8uLRVkXXRQkc7oqTk5mk0luqAMhUE7Wb0itAMAYM51OdXix2+2+/vprTbYUuoGIy+VyOp2KSN/3IfTaGzsKHhxLbY3lej5rF/1YsivFse97/RFEpBwilSpZr9dKwNG/R7VDVqvVyOH86Sn7cCHWWsoLXxb51PipKaZuWmKR1W3fgqhuljEGDGECkMSpTxgigjcmy5y3hc+mLpu6fGaMF43KkkAYgMcPoEcOqBafRg1tGCsvDgkRSdVzOAmjEKnAOij7W0Rn79aQ+sdYY4zzP9yQFwBmlMDMMQGbLsaHHb+/t9cP/n5DXbRRWW0qigCIyKANioHjZiRCYoQkbCQaAANiFI/AuZHSYmkkowa4wb4Fkyj3ZW5n02hdDIm7CqraeWecs3lmfGa9N9YZY40qqhAZM8hoyHDmy4cg/vEapGBiVJuWh/fv39/f3ymYdhz0Dv2flBKiChSDy3xW5M5744yx1liLpMInhoAMUUycYkqcDKH3rihyYW6qKsYEIkd5L92dRVG0XWcMTSaTk9OT5XKpmI6ffFtWq9XV1dXDw4My2v/rv/6rqqr376//+Mc//e6L3w1d5aLI88x5f3Z69pe//CXGdHd7949//OP169fO2dPTE53Tb7e7uq6sNefnZyK8XpPyTwC6EIJ6J2932+l0slotUooySioYAkQxhEVRXFz4z158nud5WRSz+Wy1Wk0mU+v8T+b5w2NARMRB4QCiMCNhYq6qSnXEiqK4vLz87LPPJkVJRJxY+ZoyqGPj2GQnPcAynxlr8yyfTaeLxWIymaQYQx9C33fcMUdvfJZlpycnX3z++fPnn11eXuZZhqNz6QByBiAcp+zDEanrU0RTgaMqhwA7Z8dXW7RfgYFiCn0fHTMQesyc8zElbc4/DuqfTowOlJ0RQlDjItXg1FirgvAKPnriPRI9PKxfv317qKqu6/owQNWss13f9w8PTRfyfdG2TdvUnFKRZ0WR5XlWV9V2t9tsdk3TEKFyz4R5v9v1bYvAXdfd3d1d39zcPzwslvOiyAF+6XD6KIIr+qItl8vLy8v7+3stu+u6Vv6IikQpF+D6+lpx8mVZqkugJgSKmVccnDL7Qwh9H9qmq6p6v9/vdtvtdrPerDebB7WcORz2h+pQ13UI6cie/0nay08uGdVwNZhVVbXf7+rqQHkuSGJsAoyiJt8iJEIAZgjrAkSAIKAu3IbFoGQezs8mX3xx0fbIYqyzq5l/8aw4WYgLLGIQLKKlgaD1K8t2Y8zFxQUA3Nzc6FBD3/cjgPT8/FwTKY3TWjd1Xa+sxbZt47jUCraqKuXLaFRWpOTR+FUBkg8PDxr7te9yJMFba8uyXC6XJycnZVn+ouSYjHHWZHleFIXLCltktvDibQhRhxZ6i2S8VagIII9Fns+KsizmRbkwfoK27Htum1Y4SooiESABMYyqxXKkBbM+RnW8JGsMDHYvWgqozrHKygzCLsIAhgxZ5501dlDF+PZU+iMNeYfODWLpUULft4c2bQ95VZuu5ySKXhMYShxlHUVgA8QgBlgYSARZICUhI2SALLoMMyu5lcIkCxFinWJL
JLm3q3l2furOTgMipRi6jlNgjhKZe0Fgkmiic2gsGWsMG8POHl/pQd0ZQH7AHkbrn+12+/bNm7dv3z48PLSqmvdI52joyzEnVrEH47zzWeYGbRNnnUVAzapw5MIq5g5GupTuV52T6YZWOITu1BgiIU0mk9PT0/l89l1uzPfmJvP5/IsvvtD35KuvvlJ6aNf1dVXf3d5fXV2dn5+rfdxqtSyKfDafP3/+/IsvvtAS/82bN//+7//e9/3z58+JyDn//Pnz+Xz+7t3bly9fvX//fhxVJBXo2O33+cP9fD6dz6dlmU8nReYtIqijRIxQlOXl5dVisVR2XJbl1lr8OaBTQCJ0zqrKo4BI4hCjTkb7rt1sNu/evfPeb3dbFs4yPysmjghUsQBV3BCRjIpqCCCR6hiTWKtcfBBp6kqEjaH9fhtiKMvi9PT0+bNnT58+PTs9KYvcGDOI0I9n2FCsPw6cw68+hTrMh8erJGlVLx9NLWlQ9lXGKyLCkV08/mQc168taD9eiKjulsvlUgfe6/WamYuiaNtW6cir1Wq1Wk0mE+vcZDpNwt+8fHVze3uoKkDIi4KM6UOKDN5ZBUw4a5nAGEq64WMQZkPonUFEqzbCSHmWOUOWyFoTQtis12/evJ7Ppov53Dv/6y4QRxPV5XIJAJPJZLPZXF9fd12nbTOVh1LuU9d1Xdft93sd7ipuqyzL6XSqahMa+PXPmTnNuOtD17VNXVdVdaj22rHfbjfb7Xa7227Wm7qup9OZtvo/Ion8onaLjB50Gqh4bG+s1+ub65t+Oum6buxk6bQVyIBxIAgsCKyVFookxWmnFBE4z82TJydZWQIVSB5BcpMuyj6nGhOx0GAe8S3hoF+8rLUXFxfOuRDC/f29Cl3jaMOR53lRFNouOkZffRw6BzkcDsq3XK/Xqv/f970iPXFkyulZyoPoLDC3cuwHav07ekNnWbZYzFVG7OdP2XUhAKIQiEExJJbAGwErhcPgTedNdJQsgkFgNGScddYZ48DlPp+Uk9l0MpkZPxHM4NDUVWAJiJEMG0Bg9WYVgKQGAEN7faC5kxkFAREJkYfTaXCIUvnVQZiXNCTJAPs1hB/ZVH4rtDuX2aLUu5nqrueub2OoexPZgwQDCCCCFsAIGNE8UQaBewEGIhHSqp4QiMWDOGumzs5K8Sah9H3Ttk0HwnlWzKbl1WXx9Ik9Pw0sfehCVcW25tgziJCwxJQkpsSMjEacE2cZWC9fw7ZCVdIPCM2mFJumub29+a+//9c3L7/Z73cAMJlMtA95xGpq6scKf3DWqvqGd6pv6qwTTsLDRP8otgAgiCQA+lsEtM6qfoImqgpKapqm63okms5mZ6en0+n0MQ3jR9ZkMnn+/Llir8qyNMas1+uHh/vDoXr9+u1yuVytTk5PTy8uzp89e3Z1dXV2drZYLH7/+99tt5vr6+v1ev3v//4fu93u7u7+2bNnV1eXz5499d6/fPlyOp1pW1LvgDayDoeDIVRjizx3RZEBcNM0TdvEJC6frorpdDpdrZbHgeLPfGeIEMF4n3nvjbUAmJj7vq+rKoaQUliv19fX14T08PAgIs65q/OL1XSuE3flrpExwyBqZOGSOucyWyHK/HI5/yw9y7wrcr/erLuum04nT58+ffbs2eXFRVmW2gDEgZMMjxDwH1+HDK/Uz9IR+7E1psADRyBFQhTxY16MiIOwgLEWjQHUWdMH/VT5IIB6xAD81qXI4adPnx4Oh//6r/96/fr17e3tbDY7OTl59uzZYrE4Pz9XrAYRAWLdtm/fvVOXTGOMin/FJJHFucxaF0IZQ88pEKIh6NsGmIvMe2dBhrLGWquyqcLshkqrV8bmcrl4/vxZWZS/+vK0LFNZkouLi7Zt9UMCQFEUz58/v7y8lFHhXLlYR+N2xKGpsFqtVHb+MVeKjClLU5b5cjFPwwAr1HVdVdVut9tuN/f399vtLsuy5XJ5dXX1uCH3S+M6jBMrjWrH0P5w//Dm9atu
Ma8PhxSjDHr22r1FAygGSRAEJQIwsIS2rdu26fuukKzIs2dPz54+98Vk5VzRNXWst1RdY70XgYTEpLZvv2l7GWO00pjP5+v1erfbHVnpZVlqwnQsqI7Mt+MQfb/fv3379s2bN/v9XpvqGsvlg3yQjM96QG7I2BCTYSw7jNJU20DT0/l8/mj0+bPWCOMPKfYm9sA9SbCUysxIdKl13BruiSIJGgvOO/SOjWWyjjJvi8xPCusyEdc2NXMPEojYWgRDqEEzCvckQZIiuEG0y69EXRlxRXpSAQCwiPbHEw+qvDERoGLFaAB78SgJPayPTF2td7lxhpASE+TBlmWcTEJZ1H0TmT2zZ/AsmWjtwTAU8cIoqF1cGKYm5J3JPZY5Fk4shNSHGLq26VKQSemXq/LsbHp+7uYzKHIQwGSdNVxkEgMLJ4CEwCrxHwX5OL5HQaMYqCE0sygs6PHSp77fH969e/PX//zrv//bv71//95Zp0SI/X6/2+26rjuiP4Zc0tk8yxXW67x3R/etyABgHBJRSqN/FFGWZWpjqqr7IBJC0I9lssx5ry0mxJGJsVrqiYmPDLB/aCkq/unTp1p8zOfzt2/f3tzcbre7qq7rur65udXy682bN1dXT54+fZplWYxpNpudnp7GGLfbrYikxN77y8vL5XJxcnKi5Qsi6NRAXzOttA51dXt3hwSJ42G/n01L720IwWfFsxe/y3yW55nS8X/hKYBkjApwKomoV1Pt/R5RmqbZH3ZVXRljyNL1zXVe5ClEvILceYNE1jCARUTDgAIK2hhGVYAAjChCJWR4unKOitwfqvO+74uyOD09Xa1WRak4mrFYH8Lq+OEA4PhaIGo9JL89rg9/31ifG3LOAoAO0R79ONbd65wT+aC9peIKRx4X85Ee9luXNjPVgg9g0DzXjnRKabPZIGLTNJPJJC+Kpm0BIMS43++7rtNS3jvHAixgjCNjM29TdMIRhI0hkGSJrHE4FBJDbiLMmGUgYohiDMyx7dq7u7v1et11XeL0K7jg8Ai5qa8JAGRZpoZJyiVRAJ1OeZtHSyt45VvHGDebjdLct9utFnyTyUSNGR8D+EVkOp12XXdysqrr88vLK814iqJQaVsNYI8mKT/3iR3juuLztXDv+3A4VPd3txaSSHLWOptZbhGAUJxDtEhCLACM4AAS9MQp9rvt4e5ma5zPtQ2Rl8Ukt9b3NnRMfZVCiAzIaIAs4G/Fc2gUVyW++Xx+9FQ9qsspbvHYJuQBUDJgGMuyVDScZjM4stL123XHaqt00DcU5jQYAQDAYMMhQITWutlspnXOYwDdz7zAmPrEnWYOBQhY48PESD7JMk+ZkdJjn1NoXIodI0NmkCgKQgRuYnB9Z/s2R0tEjMwDTpEMGQCDgCjAASNw7KkXifItrZnRPliUqYgjcBZHsJImBzhMXWJKqkLLAPyRT+VHod1o/9mQYTAUJZysuK5iX1cS265zIeaRiySYBoVvgQFWrjkTIhokYxxkmSlyW+a2yBJyH7u+bZq27WNia/x0XlxdTS6vypMTzrOESATOeHSWZEIMwhw4JhAmgJQkRO5DDJFBUJiAiAxj4sQh9L1azT5aI8siPDzc//Wvf/1//p//73/8x38AwOnZ2XK1Ug8D/Rat12GUUMi8hp/cK7fUWhKREDlFQrLeEFGKgzuLsbYsJ9bZFCMzx7EHoHMC3eU4gpDLopzPZovFIsvz42H0k2+LMUZzkfl8/vz581evXr98+fKrf3z98uWr+/v7h4f13d2d9/6bb77RwfzZ2dlsNjPGXl5exhjfvXuns8bT05OmqQGgLItnz57kecacmqbRo0QzQQCOKW532xD76nB4//7dpCgy74015+cXl08/y/LcWqdiczKKe/ycpV88SkMXztm6lrZrd3uMKVSHQS3EGYcGt7vtV1//wyCVRbmYzossNwkFAI0xonEdiETphQMmQFBEDFlny6Jwy8U8xJASG2vzPFfyFY5lCQ6BfbzJ
wxt1/CUM0in4Y+SFn7805GiQHmP26E0CAiDKn1aqrgqfwdibfUTRJp0+fRKIvAaPN2/evHv37kgJm0wmetS+e/fu5uZGD+jlakXWtl2nkm0hhjzPBQQIDZKBwZ2TgAwaYQRO2sgjo7MGGmdVAxZKj+BhniXSd/1utzvsD33oH0/Hfv69/d6vx7F60y13cXHx7NmzowrKsdmrAX5EzO1UqVSFHVUkVTvMuu+On02jjiIS5vPF6enZiPcm/fNfFyNxFMJTNUAFhekHbptms9lMC186U07KopgSdEgBTbSEBMSggoBAGuAZWon7zZ5f32dFMV/M8lwssqQmcgeho9hC10jfC6CQU0tR5bX/ls2lm9w5N5lM0iNTdnpkFnycZj5eCr47qs7p61AUhVLXFFqv8/iu60OMmhvIKAk+vMICoOYUeb5cLs/PL7Rk/6UCA30KKVUphZBClAgGfFcWqSiLzBZ5YXnqZeKl3svQXwbln0PLkZuKEQVNZMwzTMBiERyhN8RaUxAAJiIADi21AAEkjbxFdYlBHCWOZagvjveOxsqEDCKquizQcA5+XIp8e/BubOa8tjyNMa7IitMVSOqMdGWWdrtU1dIG6SMkjjGZFA0kQwJIgAQ6RHOZ9Rn7LGaZMUgpxr4JXdOHLqQEeUGLRXZ1WTx7ZldLyTIBlJRkMIG0BonAMDOEVjiAsCCARWACxhQ59n2MybrBhH58wt86hUWkruu7u9svv/zyP//zP29vb8vJ5Pz8/A9/+H1KfHd/z8y73e7oB6wt9Bijz/K8KJyxChslRBIkAGYBZI5RY6EKuCrEIaUUU0TEPM90asKjjpIi1Pq+n0wms/lsNptPJlPv3OMj6cdrdz3ftZmsIN7T09PT07PLy6t3795dX9+oEr6S4vb7/c3NzdnZeVkWWZadnJzoybXdbr/55uVkMtUGzmq1ms1mn3/+Rdf1iDhgSq3J8mwyLbPMO+eM8setz/Iyz/NyMmOGtm3DqPn1M7OT4zXiWB1Op9MiL/b7KqYYU+r7vm4aAMnzvJyU0+k0pfSw3twvHrbbnbeZd16xK2aIbICk9gkjIc6gqBIHiDForfHepyG/R2M/MJHwUQn5gx91+HcokT+J/tvjQ+0RviFxUs4SmoEdhiLwqKYZmnKfFiEvIrvdToVmD4fD2dnZycnJ06dPFVjXNM1+v1cc8m63O1RVEmn7/ubmpm7qrutCHzSA6X0kIC3HhZMeTQOq35hBM1OZDqJVAAx6gaKTUc4yDyhd32nT9Vdf1PEOP75MGbVLj77M+r8eF4taxyskXmO8duy13SUi2mM7grkebyHtTar06eOPcdxgv/SpMbOmF9fX1/cP98oIR8QQw3q3PTs7eXL1fD6dEHN/99LW74wwWRYC0cxPgABJyAvmYnoI/WG7uysfSicxcgzOO4MIbZsO29jsuG/A+qOWuaa8Gil/RQJ5vCcwlojf+2X6UMa69MPrkFJS+9e2bXVgmmXZxcWF6gaKgIJt62GY0itoLsaYBtiTNvlFxT/Ozy9UO/lXEBEVKi0ogfsm1NgYs8+ttbnP/NSZ0mVUeOgP1O+preu2CxxD7CUKQ0gJyDnf2yx6BCEES2AJjEFUwQ5No4A4gLVsSI1UB6CxlsdEBiANA2I1rVSYvaQISZJhcmCQwFgyhsgQGKCEH9Va363arUqxIILJXLFa2My60pvFtLq5DetNf2ik6STEPgQTeyPJkiAhkkXjyDqXFT4vgvNibOo6rg9c7bmuBAUyl80n/ull/tmz7NlTmE4ikTBLjGBIP4GAYbQJODKGKCwRICGKkDBhFO77gIA22KElO7Q7v3VVOjn+xz++/Otf//rlP/6BKF988cU///M///nP/7zfH/7+979vt9vr6+vJZHJ+fi4iXdfp1tHWsbeGQJAFWYAEhx4Qhz4wQNv1IUY9n4IqfadoiIqypCGVSilGEDmCPxXsNp/PiqKw9vs3/Y8sjeuaDp+fnz97+uxPf7x/8+bt
69evv/nm5du3b9fr9X6/v729Vezu+fnF06dPzs7OROT6+vrm5ubly5dVVW2327pu/umf/vTixWcXFxea1jRNi4Q+86fnJ5dXF2VZWmPrwyH0YT6bzmez6WQ6nU4Rab/bVXU97TrNr3/+5xdmwCG0z2azyWRi15vEiXAIZtYan2WTyWQ6ne33+7ra7Xb77W43m8ymkymSICcZ6LiDLRIRjamt8sqG0w1JkMmAMHw4Xh+XXMdb+gOfdUx+UUPFbwruj9MIrU6YWUuTlFi5PwogH+IlQIwpxjQeeo8EdPCToeS32y0i3t3d9X3/pz/96Y9//OPV1VVZlqrfolAmFQO/u7vb7ve7w+Ht27fKKn7oH2QU9AUApUILM0uigTwqAkDq56Z0HS0+EGRsK2q6YoxZLOZ57lOMg07wJ5LSPUYOHgVPNB+F8WYqWEQ7vfP5/AjRqqpK5R0V14mIq9VKtVPwOxO0xxH98Xb6LVW7vsXX19cP9w8KsjGGYkrb3b5Nsnr6fObcpCg3X5v9yy13rbWIFkciFRhEgwbIJDD7PlWhqtd3t4Y5tLFf5Jn3Bqmrpd6Fepv6lsgO7xOZH0t4P936qKrRZYyZTCZPnjxRJbu7u7vD4UBEp6enyifS0F5VVVXXVVVru0UBE4qa0rkngJyenKrGraoX/IoRj/MuM4UC3ELq67bC3YNBMy+nVBSZdXnpLWfEnoMNPdR96FIfUmIGFCIXck4lgBhCa9AQEAEhgBavBhCNGHIWrRFrJKEcXWrHfBEIDDCy6nzg4K4sAChGUISQLBEYR9YYPTfYgJFvOU18m6fxSC1oyMKs8ZPSGHG580XeL+ayr6FqoO1j18a+IY4G2ZAx1hnnwWfROraOgZhBCACZLFKZW2/NpCguL4onT93JCvJMjGFtqOpc0ah+rqL0wFqr7HH1mWe1jkcUJGZRXxA9rgdrrHEx8/393e3t7b/927+9fPXSGHry5Op//I///vvf/+Hq6jKlN4kZAafT6YsXn//lL38xxnRtq3JUdV03bRtD4JT6vucUrUpgA6rFBwOmFADAe6+1gG6BxJyQkgCnhET2UZpPiGVRrpar6WT6QSP9l/Qej6XboOvks9lstjpZPX169dlnn11fX9/d3d3e3l7f3FRVZchkmT8/P1ssFi8+f/HmzZv//Otfb2/vbm9vRCSEoDSe8/PzyWTy7NmzPoST05P15uHi8uLpsyeKA2rrJvahLMqyLDKfCcB+v1PCEhFdjKi0x4fdj18OIg2sqsVyvlhkd3dJ2DrnUvRejTq4rpuUpKqqqm52+/1ms1nNV8t5VIU6kWNsO/aoNJfFYwjU4A7DuPzD5zl+ATwK9h99Pv26757gv3E9Psv018MgTVRAlxhJMSuPv1ZGRjsp23PA3elX/daPpLNMACiK4uTkRMX/VSxMR5uTyWQ+n0+nU+d927bbzeaw33OMKYS2rkMMPHLuEfR2C4KQCmJqj3T0oRivVdMvHCI7kjVGWaPp6LX4W/OoDwvHIYjIoHLIowHJR7nd8ehnZkXFm9Eq8O3btypmp/zV7/Z1cRyvwPdt/l8UKI97su/7pqnVwk5FrEUkptj1QQAn0/n5ybKwkMV1evhbBztrAllEQ4iAynVGQ9YAGdNDFsVILQfpTFv12+RdsGRToNCmrhEABCuUATlE8+nIlT99pR/9FhEVh+i9n8/nz549a9sGAIqi1JJGq6+27dqua9tjTB+gEqP3VQKA05PV5eXl06dPf8TP4seXN9570lYACBIaSSl2bWzr1NYuywyw5WA5YOpT6Lq2a/susSAaa6yx3mW5zwuf5zGxHfpxI/OFkJGiQaukBBKdaI3GVUgGyY60cgBBZpAEgwVu5Bg5CDBZMEPEJCQUACSAhPCIKPaRX3tKfQihTzGhIUOGCIw1fjYtimwyLcPJqt9Xoapj1cS6SW2FsWcUMQZ9ZnxGPmPEwJJCSiGZzLlJbmThCFyeudmkODvLz0/NbCZEop7zhGQM
WkPWkrFEhhkRSJxFdBFTiIkH+04AMuQwxRRTEmEU0Nv57dCe3r17+/r1q//4638c9vsnT6/+6Z/+6b//9/9xeXlJRCnF3XYbY5jP5r//3e/+P//X/5Vledd19/f3d7d3d/d3Dw93u+12v993bVvVHQgbRO+9dz6BAGLihMZ477I8896FEHsBTjGopXNIWZZlzhtrRMQZy85NJ5MTZROpleQvMbl6HGy04LBTW07K1cnq+bOnv//DH3bb3f39/bt37/7+5Zfv3r2rDtXJycnTZ0+fP3s+mZSvX78hRID/+Oqrr968eb0/7B/W93d3d3/+8z//8Y9/XCzmf/nzn588uaqa+vLq4smTKxXL5JiExTtvjUXEw2H/5T++vH+4v7u7E5HZbKYDsJ+7BtNyo2Ow5XJZFmWMaSQg5TGG0IdtHwB2fd/HmKpDtVlvq9MqhGCtkTG84VjE6t8K3wkI4+wcP/ofP6tH+qnjOjwK5wAwcE9SGoiqiMaYsafIYwdq4LDo/x0M6+hTHrva6dWyVelexzRUid0q4zWbzaxz64eH0HV91ynfKMU+9iHFJGPvHQAJUatwMEPdIMI8ziBGlvZwUXpLxFljkDkOkf+TVezDOkrPHju37tuDsO9+vd5wzZ5jjHd3d03T3N/fKyhV44SMhAX40ez81xXA4+1iACFCa0gbOwCg7gNlWS5PTmaZ4e3r9bSU3hrDZIAcjdoMhtCQNWSM9zBlibFPfc/bqm2tWJus8WbIBNAV4EpwEzQZ0ugu/ynWz7z8420cU1hTFMVqtdI6jpkBcOzYs+prJeYQonpuhaMZ95i9IeJiPj85WSqATnO7H3no37s8udK6YbYkqANAiSl1bWoqgYQoEBoIbeqa0LZt07Z9J4DOkfHq7zMpijLLsxh6Z0iJciIDWYYQLIJBQZRBvB/1iY1TLINKY2dM2jqUFI9Vt4ITLRogQMQB8qYIoR+ZtaeUQt/HEGKKhg0aQUMASJbIe0I01rmijG3XN31om9Q1GAOBOGOcd9ZnznlUGEBiTGJZnACpZJ0l8TbMp1gWPvPeWbLWWJUVxYFWKcwpJAZOzBwBWG1rEb0hTowpQUqcbKIYVTRGBNT263gVMaa//dffbm9uDod9OSl///vf//53vz89ObPGHqrD/f39+/fv+75frVYXF5eXF1dFUcYYl8vVxfnldrvebDfbzXq73W53u8N+39R133Xa2OtHtQSjaVCMEVGYnbXkvHMuhNhKqwdlDDHEICJlUR7xIDp9/BUJ8nf7ydaCc85nflKWi8V8uVoulkMvazKZ/OH3fzi/OFdbNZZUlHmW+5vrm0NVvXnzuqmr/X57f393fn6+WC7zPFuuFucX52enp1mWO+WeiphxXEqGrp5cGWt0ELvb7bz3ij59XBD/+AUQYuaz1Wp1fna2Wq1CjDEFZhlNulSSP4mgIRNTquu6ruuu733mPAxiqPgIkoPDC/Ohv/eokMKP5oXfWy7Adwa0v/S5/NRFa9eGYBS+Y4YYUt+Ftut1iKBHkvJeeDi54mOn6lE64ZNVVaoVrw4FdV2HEBTveUxBYoxasN5cX7ddt1wsqsOhPuwHzSwZGu9aoSMKGasi4YoEPNKQjo1u1RI5DjjGm/9B0Oc4ffjVS0bymF5IlmVlWapfiG4k1Z7TL/5ozzz+e8xozKqCfQqkV2wsfHsX/br4/SMfXvvS5+fnL158fn19vdttu64jMicnJ3/4wx+/+Pzz5WpZlmWUmJdFluXRe4OJrJAzOMh4AQoTiIFkCR1iRGQhMmAACQ2a3OSlz0rjc1ss7PRycvY7X67I+l8YAT/Z+ijAH2/I43n8EX2iMT/GpJH+mDjqS2QMFUUxKYtfZao7LEkgQbutNELWyCJaESPRSABhSD3HLvVd7PsYIgdGIuPIW5c55601CJgixkCxN7Gn2LMwMIGxgiQpAgfghMAqNzP4NgszB4nMIH1oYopEIw+WJXFKMYaQENEk7QbQeCSCAJTo/KMn
+J3QHkIMIaWEAgxIIIIEyjayloz1xSQldn2MfZ9Cj0kb8uSss85Z64gQ1XAA0AJZRE4ppFin0Enqve2dnTmbOWucM9YA0sC7Z2GJzJBYg3USYASwxhiDYpEZE0uIklI01oYQQJlmKbJ86ETEGP/2t7/WVQ0g5+dnf/rTn168+LwsJ13XrtcP19fX79+/y7L89PT07PRsNluUZSki8/kinse2reum2u932+1uvX5QN6HtZns4HOq66vo+xSiILAPdLsWokh3ayQwhCg8PQxE6x/bmUTzhV7xAj1t/3/5zUF2dclLMF7PT05Oqqtq2sdadnp6o8mVR5EWRz2bTsiz+8z//88svv1SB7IeH+7fv3rx48eKLL7548eLFYjH31h6no9r9P37YLPcXFxeq2dd13WazOQr5/bzPrxkrOu+Wy+XZ+fnp2VlV1+vNOiW21qnVDWLSZqp+l3JeQ+iZC5XVHrbyt4/XxxH98V36KH96/F3fjSLHP/k/EN0JkTSpJrRaSKTEoQ+jyMRQnR8byNpa1CLSmF+kIPCzVlmWigY9HA7r9Xqz2ehPR0Q9KxWJ8vLly5cvX9ZVdbJaVYfD7c01gqjePiHyGJ4R0RirqV6e598ajoioLCiM2g96S44ZGsCHuP6p7rzuAfW8UVEUBcDrx9OveZwiP/4FPxI9pdHd9RjXP/qWT76MMWqH+pe//Llp6jdvXu/3B+/906dP/9v/67/905/+dLI6yfO8C61zmfOZ885IT1bIGkRQPW9hIAQzEETJoxfyZDOymfGFzabFdJVNV65cuOlJNr0oFpfZ9MRYj0PnBo7//f/vGrLab2Pu1JtzdI0Z/lC/XoFi1hhrzWPM4y9dHDkKK79DTx1LxqNxKA6TgSTCwFFS4Bg4RkiMDAbRIWXWeWMsAKbAXSNdRX2NfU2hUViNJJvIpBhTDCCJAGis1xFZJIYECJiEx9A+JNyDfWhMMSZV7PqAkNCWHgr5fPpIRf7bs3YWbeprcU+I2iIn1uxaBk8hi+S8yzNJiTRJJDQqJWDGAxjUi45AsO/7qm03oT/0fW5wijgxhpxFhc5pz4EFhBX9z5KGUQfIKN2lDxsQQQ9HEbCGREwCYeDHT5E5vX//XtWRPv/886urq/libgwdqsPbt2/fv3+/3x9ms/mTqyer1cq54UgdOMTO5Hk2KafLxcn52f+PvT9pkizZ0gOxM6jeyQYfwiMyIiOH96pAAF1Ai0C4aAhILLjihj8K/4YbCpfNBRpbCkSazUULWah6Q44x+WzjnVT1nMOFXrMwj8x8VYX3CkBTnkqER5i7+bU7qOqZvvN9z/ftfr/bbXfb7Xa33Wy2u91+v+/7bhxDinEcRj2kVb33RTHm8n8Wes/7WlVVFxcXuUe2KIrsgf1nBGA/G2WeBK+Uy1R1XacUc9tP7sOpqurq6hkizmbN1dWzZ88uv/32mx9/fNN27Q/f/7Ddbm9vb9+9e/fFF1+8+vzzzz777PzsLOfbc5Xr2MyT5QeKomjb9ubmJl9abg36O7Eqx72biKuqOj87f/nyZT8MIQawQ5dBYsKEKGZGiHVVV9VEX5X1WItJqOdTkbCf2vUj3PR4q/7rxCOHj87Ff8dusVwy8ziMzHSiKW5HDnnvXQ7gM5Aj06J5fyL+fby2P2JkJE1Zlvv9/t27d2Z2JGkRkWEYNptN1tFi5qIosvLYhBs51LDhBMGQ0QCqemT6PI5DOPWk1fsTF+2TPfqPHNkkZNOel+Fms7m+vs6EaKctiD/93XxpOZ+RtUbyEjid4X8gD//HjHzYsixfvHj+r/7Vvzo7O/vw4X3bts75i4vLr7766vXr1/P53DvSsvJlyb4gciiZ6tQhMTADOiLvfFGUBfuSfUl+Rr7hcuaKGReNr+blbFnUc1/NXL1w5bJollWzmBpETyEd/1XHz7vvcGj0OgA4stHP72LOnOqY1Rn+sy8jpgAxHgBtLjOSkglqBIlo/gAMc458yb7xpUMm5sp5D4gx
xXbfhyAI424TNvey36Sui6aRSMglctEspcQAlS8mvhpCNU0pEgoqKahpAhMzsilkN5HMRGeqZkkAT6uTZgA1RPhF0w5malkLizAD79khgaasO56J69k5yhn2yQPI2R4gBCJgIpe7JIEEMCl2ktaq92PYdd2ZcyWScebeQj2AcUzUENWyZPX055DcA5jgtTlhqJmvz8gcIxoq4GktUtUeH1fPnl2+evXqV1//6urZVV1VqrLZrN68eXN3dxdTnM/nn3/+6vz8nIgP9VpExJx5qKpmuTRVSSmOYey7PstmPD4+Pq4eHx8f1+vNbrtt2zYMQ94oc7Eo68rkDTHGeFQf+oQ84Y8ZP40+8wRHzPbgCXLEDujTuq5fvHiembYWi7kZvH//7vFxtd1tP3z48OOPP7548eLLL7/88ssvM39tJnJqmuYIhs/JyRBC27aZ4TnbgxwJ/eGlZIf0LRL6wme9lmEY+75ndpKe1Mwy1+9yeZ71JKqqKquyqibJ52M3C+ITF2facA/T6fTLT07mT4+V++WRFx6qmnP+8uKiaWYxhuxKZq8op6yPOt/OcQaXZWLzbPaICOBPFtfGGPOs6Loua2xnZvVcY869lF3X5db2bOfGcYSn8LTjPTwNcI+gdDjFBB5M5unLYy7iNGT/Yy4Qn4La8slnLNV2u3337l0Wrs2Xefpbp5+LB4n3/X7fdd0xFUEHjql/JLt+PIeiKC4vn/2Lf1F88cUXj48PXdcTUVmWi8Wymc2apkFTK6uimhX13JeNhYEduaJGVyKX7BsumqKeF/XM1zNXzV21cNXCl0tXzX0592VdVJUrSvYFuYJc6Zz3jo8++sec13/Z8Ut39ZO5ZNM3P0YLx4dyzDP+kdQPKUZJXWZ99c6jFkaImkAiSARTJCZyjovSl01ZJdWkQs55X3hmSCm0rcF+lJjabdo8pq6VYQgAgSiyS+QEURQcUVNUAqqolpEsCpaUgAAtpx0MTq90qqiraUpqplNm0qaZGf0TStanMLoM58lljylDJUFFYzRT751jco7ZoagYTkpcRMyYs0BAYERGqICgBgoUwXqRfQhdTEEMib0riOjQkj89OVVFhANB8oEX36bIy47P1QxVQcVMQBVNCDRrHX0yJaqqevHixdXz51VVx5i22827d++/++67tm2fXz1//fr1Zy9fLhaLDFA6dVcz/78BMOc4viiLqq5ni+XZs2dXu/1uu9luDmO323Vd2/fDkbl2GIZMqJTzeAfe+OUx9rI/ITd5PuE/eLj80wxRefHiBRFVVXl5efndd9+/efMmazns9/vc9fTu3buLi4vLy8sMdstbYS6oZzLIu7u7zOL36tWrL7744uzs7FSK8RdP5phuNQDAoiwuLy4kSVH43W4fYsjI1xDDJHJa+Ivl2fNnz16++Ozy4nK+mFVV5bw7YMV/5vBw6ESfvvOxie2/5jjOqmzwmDnTBiACURYecmaWqQVykY4oy2FZURRF4Y8N38eGcADAfwAK82fGMAwZG38k380kNnAIoIkoexXHeZ7ZjY450lM7h4du5qNdP37/OI5vOO7RR+sOJ5vXH3ezn4zMZXZ5efn8+fPb29vr6+vc6vby5cuLi4vsJh7nEh6y8Vkk9O3bt9fX18MwzOfzvAr+swu3f/9xdEoyB6VzPqvYmU1kOBnhCGYEtrh8cfXlPy0KH4cVO+ebuStnXDRczl05L6qZrxtf1q6oXFmzr9jXzlfsS+cK9m6ibWaXMdaO8eRZ5CzgP5YH8w8dTyKZvJoMALNnmZ3LkxbEw138Yz5RUrShz5wFZKZECAUCqmjW8c5paV8UdV2LqisLIyzKipzLkj2WRGPQsddxIEkOjIkYkJmMGZkYUBk9oyonk2zdM480MiITIHhySSWpJM3Za0MCAiJUMgSMergZZqZiuZPs9EI+bX4DM6IsSZ77LtRUU4wA5kpP3jnPziGmpCqQ0xYETDjZ9emrTSx1igIYVYOoGjL7whVVUXI27TA1
uprZ4YkYYDbVipnnYgLcZnytoSmagAloMlXQSdXwE/pcPDC9FL6IMe337YcP7374/oc3b94CwJdffvn1V1+/eP58NmtOvdRpWhwTCVMcz459VTWLxVIuU66gt+1+uz3oOq9Xq9U6s1mN4xDCJPycTyBTKGTOdkT809r14854mpT+6RuO/8+B+NnZ2eefv/766199880333zzzbfffvvw8LDdbj98+HBzc5NbgLJUxnK5zHlaM8sXtd1uU0o52593/MVi8Xee58HXnrJp3vnlYumcXy6Xfd8P4zgciD9zQFmU5flyeXF2fnF2tlwsqrLwhZuMAoJlEcunhuDUMhzs/C+aisOkwmmfMDvOxE/e8ycYCHDQO1A1RPIeiTiL5XjvmT/2W+d5l2HSB1q6Y3s9Ttf2R8+fjJvLkzPn2+/u7vLNhwOrGjNnvGQWfs1+Kh7SlEdLfDRIx5n2Uwt9NO2fxPHHnLz9iRLypyF4bsd49uxZ27br9fr+/v7HH3/MnLKImLNBp0WBjPjr+/7u7u6HH364vb01s6ZpMmXK31ly+uPH6Z3xvmB2mff39M475xDMMS6vPnvx67+qFhdjtyHni2bh64Wr5r6cuXLuysqVlXOeveepw+L4h4inNlo8orAITx/Nf7PjAO4ANMw0lNmcf/quP7KNTxVVEJBIs3R5rrqbWkqaREwcIDrn6qo2BJcK8q5uZkAUYxqHcYx9iknHEUMgE0Jkdg5J2aFjnsR5wQxULWoSk8wld5SIAUQDi5KGOIIlNc1RvBHopJumaplw1yT34MuhNHEYT9nonCsLjx+9cgVAQCPviJEKx96xZ+eR2JsxwISJyedDCARKefsFUmAzrpRnTRPECu8lpYvFYl5VpeeM/IBPuLpRAXJTnCGdmPb8UEHNVJHAQJNk3fbMC/70aVqMcb1ef/vtt5Lk/fv3+/3+xx9/ePP2zTAMr1+//qu/+qtf/epX8/nceX9a/J78h4/h+6GyOQW+hOjzAquqaj5fXF5edl23zwX53S6rOu/3u/1+n+uUs9ksk1HX9ZFcFuAfAasFJ5vsH36Pc26xWOQW8+fPn3/xxRd/8Rd/8f79+5ubmywxdzThiJiDGzqIKuawJqcoM8vK35PsSQ1VM0QcAAABmbkqS0IsiyLEGOfzFGMS0awZ6nxdVbOmqbP6NdFkiw8WGKY5No1jmu74r/2ybf6pIf+FZPDJzPv016eb+XdeeN5n7KDzmztxidj7jO2Y5tdR7QkOCz4lyWA6ETFTouI0sfT3GX9gjuVMzPv379u2Pbb/ZmmAcRxzwil/J7e8z2azHMEfQ/BPguxjKH80lqeWPs+cY4X+NGo/5LEsW/e//9X94YEHdMhyuXz16lXbtiKy3+/fvn0bQlitVplg/KitnskoMxPc9fX1+/fvzez58+evXr3Kb/v7LK4/1WkfDXleX3rg380JZwBAotn5889+/S8Xz78c+g4QyRfOV86X7At2np0j5w4yKtmW02FgbrjAj0f89LL+2wnZf3E8Nd9/2nNdzGb1LJPqMxCx864o0HlzhZAPipRMFRKgMRtSFAEz5AiA4zgO/TB0XRo6GweO0YmKASARe8ceySkxAEwybqCIKEZw1IF3RM4hoZrGlBCQMZpNRs4AVDUQetSkZoBJLYGgguUc/sl4YtqrqlwuF4AIYCmJigBlhC+xc0XpqqqoqsJ7RFOAqTsXEAnJMVBWp7Tcr06A5ME5Y+DCl1UMwUTOZs3ZrM4QuqNpz/8eNOrsNO7K9uCwEWvu/AyBzRAxkYgBABLzcHohRNh17bfffrtarRaLxX6/f/v2XYyhrpvXr1//s3/2z19/8UVV10y5KwmeFGieOAkn/7WPoOWiKGezuR50gjOPVRZ1Xq9Xj4+Pq9UKAOq6vrq6Wi6XRfEPExb8B40nOaufgOxOTdGxnpe1al69evX69etf//rX79+/f//+/YcPH66vr29vbzOr3X6/P+pGZ58gwwPLsry4uMg8jlny
4e84PwPN0A+ADJ/I09gxY1EwcVkUBypSy+lHOqFPB5p0ivF4OPhp/eXJlU8csb9glQ9z6Wgpf6bQe0xL/h2X9neNzBqZ4+9crOm6joi990dXyTlHhPlx5Rg9pRRiNJ3iEmb2HpjzPvxHnhEAwGKxIKJcYs+tksc+xox4yLD2LFsgIovFom3bt2/fwqFqDofpdFoKPVrrU4N9JL+CkxmIJ+NPFbIfB56AAGazGSJmFahvv/329vY2hJCTbbkekWXBMp5gvV7n+lTf91dXV1lNcbFYZLTj6fH/VKf6S+d/vIrThMfh54ZI9eLcV4t5iCGE3PjLNNGQZADzR9uNmUP22PLy8Wfwy1vHf8vjH9vpmDXNRTmzTJRpZkhFWZB34At1PhihmAEIkLFTpDGKalTDTMU99MPY92nsLQ5OQ2EKADY10HkEh8YwSakZkoJC5u5HREJmZEKHiGpqCAX6HMke3DAz1QLBoyU1AYpJgqGJKskn5BdPTPvV1bOvvvpa1dQ0s1Jkb+8Af0fvcoNBdiCOFhgQgWkKKz5WbIAUSIFC0pBURUC1LHzlHRESHTZoO37NPXCHdGn+5pPYKe8DMg5j3/dJkqnlyduPCtf3+U3M/OWXX4/jmAWAvXcimlK6unr+l3/5l//8n/13r15+vpgvmdyhz+rvFQ8dlsHH4DvvYkebt1jMLy8v+v6zHMRnrdhM3OYnjfl/dF/4p/vOJ985vsy7cC6Wn52dvX79OpcWMrHd3d1dBlhlan1ErKpqNpstFovnz5/nlrlXr15llua/86yyadcDL6zZUfkAiXJ/pR1cPDzuQwAAZmqGpqB22I5s8g7g8BjyODXMP2faj3boJHQ8ZuPtxN7DcZf748u/SFRW5Ww2OztbmmmuWB9jKcfsnM+v8tXlzxPRLIPmvZ81zWw+c56dd/QLUIN/6Mha7F3XVVXV932McbvdwiH4zgydy+Xy+fPnL1682O/3OTUNJ0XxT5whO3Qe5zlPRKctbb8Ujp8a9T95Qjh/dG4VyX2b2SXN+nVZbi6fqh14dbJbkxlPX758+dVXX2Wmmv+Sdv10HF2f4+fm+45gROT8FHGpqoHR1A8FlFHQmdAk31U8Lpdpmv3hq/hvPWT/xx5E6DkbWwJAIl849o68N+dTRnQjg0MCwBjVYBxDlqwJYxjHcRiGGEaJgTUVKN45di4BRdEoSUyy4gpytoIkaiKipgARYNrWdCqgi4IiGhGCo8kOe/LskmoUHSFBUiVSIn764J6Y9rOzs9evP5eJO0QmPN1Ek4NgQAjEeYeZtkYAyOT2RHCyOqeoKAvQfQRIT7C7PL2e3tDjLno48DFj+mTPNlOVYehd0UkSO2T2bh/WJ4+GX758uVqt1uvVZrNRlaIol8vl5cXlX/7FX3711dfn5xdlWeWt6afz+B+0ndOBxCoTpKuqSMq14xijiM5ms6IoidgOXJv/gKP/o7nSx8NmVPbZ2dnLly9zTPn4+Ji57TJiLlOaIGLe7s/Pz7OO1vPnz4+d+p/sPp+cdk4iyclef5rO/ZhXt6kW8jGTm+16Vlj/WFYzmFCVfweq9pNbd/xmfug2pa/s9EcfS0SHGfs0bfMzB/zDAydtyiqb0typT0wIoHZQRCI+gQced3Dz3ldV5b2b45yZTjnp/iCQ4BdvyPGUssjQOI5E9P79+4eHh67rcqNXBo5lqs6cjc9O3ml6Bn+uv+BownOseQqVh5MwHU6s+OlMOJqxP3xXf/oI7Bf6HY6Hys3imU6nrut3797d3NxkmMgxKQUHhEFWbf/iiy+y7Fv2XO2XsSx/kvHT6fqHnImDaAJNtORgRpZzp3mJ0Ondniz5YVP71IM6Xbmn3/zPuNLj0/xHukv/eOPTs813DDEjzoiZHRMTOgZmUQMDotxARugGBYgpxhRVJKYYQgwxhphiTGxJ0JShRFbDqBYlxRyO5siSSIEENIqmnCcXPex8ZmB0aFdjRjAgh0TMjsD7JEoxabJE5BD9gSXj
43X8u3/3744vcrvRtComwdbjFDkENR99wI/W93hPfjJODfUfeNs/6DmYiooe+KEAAbFr2+1ud/oRMcau71JMlp1c5+ez2fn5edPM6roiohNE25/MS7WcdVYVmXT5crflcVv7w7+uqm/evDkK1J6fn19dXf2pzu0n5wnwtCZ6jF2OgnWZr+PYmpwz5DlFkdUSj/CiT3aH/X5/fX19/NHlxQtmd+qx2SFWtsOLPEOeYpanCANPpx3Ax8n0SzPp53YWe/LPz/3SyeM5uhvr9WOME+nKfD5/+fLlL/7+z5+ImUGKsev6EINq9kvw5PN+MYrCA1+Cz60pf4T+W9/37969O758/fq1936c6LiH3NlhJ2WXo8A2Mx/bHbPM6NEW/tL2ferh/YE35I0l+8Tn5+fPnz8/Np3/gQu5vr7e7/f5/3Vdf/7553/YvTv+P/fcZwxsZh3Pl3y6EOhAvpt1x/Pl49MZ+ScZIYQ3b94cP/rly5fz+Rx+blX+gYs6vnwSE00uL/zCVvNpqgsOhvzv8+k/Hff39+v1Ov/fe//ll1/+F8Ab/smHiPz4449HFdrCoafTPefogHukg/7KwRMVSeMwiCSAKcVoB9Y8NUUzwsyyx5ZZYzJ6bHIeCBH1ADU55jQ/8dxztjfnEPDQ4AcnIdNRKtK7siyaj794atr/PP48/jz+PP48/jz+PP63Pv6352T9efx5/Hn8efx5/Hn8efyB8WfT/ufx5/Hn8efx5/Hn8f9X4wmM7vlFc7aowMRUTMVM1MQm7hbMpAdIZJLS0OrQWmzRomPDTCNgKAqiEMXaAbYddglHYQAiIFAFVToUe9QgmSQTUQU1Z+SNCgNP4AiJwRAFIKmJASiQggMomSvnHTtmUtGM9QOzuGzCxTxfBRN9dvXsWPU5ltSPBQycgPiGuek//xAIDMzUANgXzhdH8CB8UrnKpVI7FLcMDExCjONgKhMmgfATCODJr/7cNwxMTVXvN9t0qPosl8vzi/OMJc/4wRjGFOOxSOm8n9omfuHA+cvxTA6YtZM3q5po6Ltx34b9LrYtxAApWdbMBcioEjgWvZmRyByj91RWftZU84WrKnL+I9kvYte2t7e3+RUhfvnFhSPSqBYFg4AaYb5xmX1Ip1aImCCMMI4QAoiATd+3k9qhnTa+IQCiIQExkMs4F0AyU1OFlEwSyMRYbFMtH4EQixKrmpoZNg0QAyCEEcJoKZqKeQ9lSYsF1vX7D9fDMPVVMkhpEREOPWh2rJPj4THkZ2F4QASeQplOaswH9eaJkcHg+AQxE+jix878AwzKdOoaObZaZCmFSVw1V+nsFMZ+7D2JBuv40Yk/n8+948OzOkVcHYr/03OHAzJ2evrHYuxhEuHpjPoFpA0czxZOmxBUVZLEkMKQJIpGwags3tVlMXdUMxX5t07ryavNpj88Du/9xcXFKfXNycn/ofGkOA0AiMce2+PrfN8+HusEiHFSBT05yN9vHKv+Dw8Px9r2q1ev5vP5z8FAjt85wicR8s5lAhZAB9POtAPrs5IY6JHkE9WcmVesgEp0nnKpGCzJJIpqpod+d0c4yX/nv0iE5AAdIsNByv2ndyN3yR7OEdWVx5/+Mg7w6WSZiERxWsxwELD7BJ8FT+dSPu7JPTrcTPuIyTllbDw52uGfp6iF0B+/0wYb4gnsYNrhnzxofLJYpu7CbAxOR37LJDRz3KVP3gAHytXDIT8qch0xSVMZ/nCFR6Pz8Uj5chArj82J9NsT0/6//6uX/8N//8pir6FPsU2pCzIkSMjE3pdF44vKlVXqtrsf7vv338SH3xXyuJilgpHEx4R9gHaETW+/WdP/+gO/XxcfugqtLNDbGGwYHKgDMMNotkndLvVDGCDqMhYX4q8An3k8b5hLTA47gG20IYAOUEabi71o5s/OL85n87os4xi6/V6joNrqX/36/v/4V/kqisL/n/7N/1AVRYbZWSYEyDh9NDAkc9NdwwAwqqkBEHg0DimI6fzy+eLi
Wdk0PtP3T/19R9fg41aFCGCmYu3qYf3+fRr2jApM5tz0aDJCBQFyg8o0yQynX80MfqhqMcRxjP/T//y/7Lo+f86v/+LX/+b/8G8EKCnGGIe+fbz9sFs/OMfNbPb8s8/Ozi+ZC0KHH1cFnG7QB0w5ApqREigaZJlgRVBQCFH37d2339787W9u379b/e439viA2430vcaQG9AJHTIjMzpHVYVNDfMZP7ssXr66+Oz5q7/6q/PXX9QXl1QUOimO429/+9v/x//4P+arYEf/l//zvzyry/A46mMHjy2P4sgRREs9xA7CAMkgAqx6XN3C3a3d3UHfWgiZdNiQDMkyY4LhREaHAEzq2KiwoqFmjrMFzhpzTlPUodf9TtpWh15jVJ1cQHBMZeGef+Zef+muXrhf/ROq5wgID3d2fy2blcgoyxm8eln8y/+ev/r6//p/+7+/PQDQagufycqzFWSERghZZpkRGMHx5PvmzmIgQmQkBnRIRM4TkWXArffMntgROSACyq4xE+QtwhCFUCZ4IbIBmQSLAVXAzNgbOzRANZAIElSyk5ws60eIpIyuEc1NhqtI//NjcQQA/fOvvzxfzAGMEPmjrNOB3uTITgYHwjJ2QISEimSTa0OTswcHF+bn0IAnczBjJs1g6pmROIaubVe32/t3u/3Dfmw7tx7r3cX5Vy8vXp3VXzbFZ2ioJpaVogzM7P/5v/y/fjw8juVy+a//9b/OJA1H5bpPIeCn2/e0Bx43zsOZESiakGpm6DYkxYOaNgDY1OSDAIia+5GnzRgwc6L97Of93AlkzoD1ev0f/sN/OOK2/u2//bd/9Vd/dTRSx7cf/gAAAlC+cwCGqpZ2Fu90/FHCtxY3Ft5DWNm4s5gsWBJMyiEtolxG97kWF2526WfLoqoMbbvfbrbr3W49joP3XPi6KOZl0VRFVZVFVRVlWbmq5mJB/ozcjKgh4BPLdHy68O///b//j//xP06nSyyzS5vMzumjwI/9UMfOlvwTAsB8Ax0CY/YhMK95QZtYriaW8cmPP0w1mxTqzdBMFcxAAXXyUS273zg1B9jpl+NMPPRtq8rdD3DQDn2/kTdrgYPuFxLm3t1MGGRgCJhhrVN7FDI79t4TTVJPMcaMb/uIhC28c46RyABEJYlIUhPRlCTq1InmnPPMmSIMIaPFk2SJjaNwItEUFompiIkkyah9wC8u+C9ffNRHeGLa07AKu5QheMSJkRw7NEQmdt55JiYV6fbj7W27e9/KKs0QIRVVgaAUknXRAlRazRcvz7+qz6qh/iwUMYiMw+b2cXMbhy7sx2gKAtBpjGoI6B3OC7j09qqmz2Z0cYZ1g+JxVNj2vNvY5l5kDxgplq6tCyt4ixDNBkMzQNPxSesIzZaLuixtEskmAzCaaO/QCA3VNGoYxqEf2yQCQPN6WZU1jpRiYiQQQAHyeLTgeDKrAQ5eoU2BJRBx4YnqwiM6NvbHaGByKRDIkA0OMEkDMFImI0RS0TAE4hFPIaY5JAVS0K7dPd5df/jhm9Xt+6Iszy4vi8LX9aysPDs6hITTwjl+PcXE5i2YDkFYnupx6IfHh+2H95s3P/Y3N7Le4L7FIWBKqEpAhIao2T8xAA2RmNE5azvZbMPjur9/qOfLcjZn57J5so9L+XAdMTGzH4Img3JGBRIhxg7bFsYRdjtoB+gGWK/h4cHWG+w6CxGTTO4Yqk76QGQft0AEg6yCBKE3BNOEoVMiFZEQdBhk6HUcNSVVyWsZxcgINjt0N8ReRfDyBc7OMEQAR+QMWcdo7d76zsYRTnqyiah0hSfzbISKCDRZRGSErHmYNZWAGIkAORt4IiZySFkym9kVxJ7ZIec3k2WGyUkhWs2CWAJTBARTEIG2s/0ex4hmOJvhrOGiQM/GoIooZEKmDCI5GqaUVEQPaqTuqfVh550rMlUnHYRjM6Nkbo7OwQMRIzHmPojj/FIDMbCEuSMo67NOMRf+JH90mHl26MnGKe4kQg/gw7Lo96XE
BMl8pEaL0iMrMrDjvFRN9cB+94QlDRGzpk4m18OnA56OvKnrFA5ivh40zLKDqkkkmhlNvbQM7IxpCuVtcu4NDvfhY5em/f2B82aWlfF+2ghgH5M0J0HntEtMCbeMiQYZMe10fCfjNxa+B/nR5Br0zmSvqbdoltAUTRk0mUoYdQg9xnsel9WscUVhKoUfyyKaJjOJManENOxGxIGhLq2qXFmXvlq66tKVV1xeOb8EbhAdGk/PE+1w2icPGwkOWdKPD+JjyH9q6KegY3LUp91Js0FDTSpRNYGlnEYFPJj2KTtGCEzIgA7RTbxqeEiQAdFk+A93d7LwcNIVdZieSJ/s63mNHtn7ACFFMZPJec29nQaWG3oBwCIGzIYbAbOOpZkhGDsmIkVIqhoTIqLmVLiYKWVCK3MwSZiLARlSwczOAQA5I0rMVDiXPw4P7dYAkESSpBgppZQvwPGT8voT0x7bm/7xg6/mvpqh844ZzLPx1BfCbGBjN67u9z9+v3r8dgvbcF4ShKJpSEEH0U6MZmf+7ItnL75u/Ne/ollv1O0329Xd97/lJH0b+81+FBE1SGCGVjAtCnw+xy/O4NeX9PoZX11hs0DyEIX2W3f73t6wPjKNgwtNtVvWLfsULQBEIANAUwfmT57N4uK8qWubwmLOXjkasjIZglmQoYu7TRofw0MIkbAoFrPzZUU9hS6gkQa1wkBOEnVwkhg7vFA0BRVTI6TSc8WzmSdfGDvLLgUaQLZMxkbZtCukCMlMvTlvjpBVbWwH9v3T7hE0IENUtc368f0P33z3N//r7dvv6tns6uXr2XxxdvHM+brweGra4WjXD7PYcFpBWY/PEBXyaWm/26zfv73/9vf3v/3b8d07ubunYYAYADRrAaBpdn2zcE/efgjQkCO6vp7tlmfVbDE7P3dlmckN4GAJPq6sMaKa2/U2EtZn4AtAhTbBLkG7h7tHeFzhZmXrjW23OIwWEkjuakdFUkAFMyRDUKADkxMgAIiCBUjJUoBhD0hqMAkgSpKYJEZNuX0ZDIiccjJLa2s72O3w4R6/+DV9/jWQQ3TIBcIIw2Cbje122u1N5eNSYW7Kikk9WdYrz8mMHOvyRMjNyG4y8IdulfxzJIfImNWfnSP2lHMhk84iETGQGsqYYoiCqIRACTCKrtd2+2j7gdToxaUDcN5RUQiwmJPEqglUQQUzS22MB9OuAObNAPU4jdkVrijNDBBcNuwHOz2FvvkfzrqPmWRSTYRENQkkAREqPFYllQWym0zd1JfzaZNipvXTQ9MPgoEqIyJzmVIcBpFkFqlIfgZFWRxMu0MgzBIgh4TkqeOLiFmZbT6fZ2P5SeD+SauYgEm2kgYMRAogZjGlYZQYNQioMhk75pLBFVoX4lBzoA5gQAaMgATAYHRCiXh0WX52nJ7GOI4ppZ8oQGo2YAqGQFOO8FBRUTikG8HIBOPaxnfS/bV0/2+Q7wg+gO5Ae41BQrSRLTgFMiTTTrTth8023gg3WMxny7NmtvDVrPBclaNpGAYJwUx3EAXCUEBfu76pdTbjcnZW1M/9/Mti8WurP+fyM6IZWYFI2an+SQUBgXIiw04tEHz08/M2NeV7crRjiDbl5CVLiqNElaBpkDSqjqIJDnnOKUmSCc3RExbIJbuKvGd0lJlhkQgIjcjomKt/2gmZKy6ZhDRHsXQathESO3TOuVx8VFUxObj4R1JoOchCpJQpX4GIHLu8WSIhEwGwYxIASZMmm4maKJgRQOGdz3zxgGOMyZIhosvsto6YwMyIUJVh2oERYZJ3AUgphRhC5BiimQJi4Z48lCeTrF/fbt48uPq8mF82Z5fVfOkdKVKSFMfQ77p+F1d3+w/ffvj2//Nu/XZVDKktIa6hnqNVnme1O1suLl8vX//KnX1h1SthH9HGYd7tmvPzcnm5+P77D29/vF09rLrtvvE2r/DFuXt1gV9c4hcX8OpCrpY6b7AsCZCTuO6iKSoeQeCMt22FxRksz6Pg0A5D
gjFI0qgqc9Pzk0lGvsCiQMs5NpocRTVWQVWxtNvfv7t79+OHb9/dfB8lVmWlCE0zJ2UFGkOK2ieiYGp4KKsffNUT0z4VwUU1jmOUhAWBc+gZ2GV/nkzpoGaCQAgQJYU07oddiOOsnDfF3HtGY4gE4YnblQl/VCHGuF0/3l+/bR9v0v5hTH1Xlbv14363q+plVR2WzqEiZEeFU8ScssoEwIfTRkDQlFLf7W5u7775Zv3dd/3bt/rwQPs9pgRq2QtAmA5wuGgFEQgRc3mbXVqtupub/dn57NklFp4XC+LCPnHmwXTXGRGsO0ge/Bw4gQ6238DdrV1f290DPj7iemX7VrshMyhmmSDFnH9ENTAwRTViJTKij2xDZmgJVSDmEqMlhSSas10S49T3CWhALMZqPqUijJYSjKORBy5oeYFlreQM2cbe9m3abnGzsfRRJxERvWNGZDLM/BDMOVpHYmRCyuROWbCY4WP4S4SO0CF6RAfskB06j86Rm/KQRIyOFJKo9jFs9huA5BBcm9xmcG8e/LtHtx+dIe4DjYmZqC6w8IAIjlGSqYAKigIxIYtEFVEVAKOoAOF4IeQc50oTYi4pTNl4+shDCgCYhLoBh/xnhHG0MVgINowQI1QVLBbw8jN48RxnDRUFTrlCxE93fTVTS2Zh1K63YcBxQFVgohS9qDdkAQ7iukA0mh8RlRwTOFQ0y0/VzJ4I3uFRUtu5o9LBJ1H7NBOPqbVsIJLCOEo7pM0ubvdxu9dhwCQE4NAlZmXGszk9P6ezhpoCmCDbKmZEAxMNQ4yDGSA6X1TOl5/ShZzO/hO78jEWfLrM4VDfyy43WgKNJqNqEI3ZMUVNqgPEawu/h/BbDL9DvQXdmQyWRhkl9SAjSWRgVkcRKaKNEAZDQHTMSl6R1YDQO5bSK6o60xiCyKhhh9SWRVu7oXKxgILlDPv3yT5Y+JXVv3blK3TPgWugwiaMyMeRI4djrftJCifn0W1KByPYRKl8SEaYBpGY0gCxhzRYGjT1KoPqoBbzs5seXlYPA1LwBiVQg27uypkv52VVFWVFPKmLZ9TKT6md8ESjKy9O0CexvAFkbkUByCHCGMYQYl7z3nt2LufikNlS0pSiJFVFQueEiZGQjQnJAJJ85HHKhP3AfMgBoxo4dt77oixFTczAUFNKkDcNQDNGLCZOKzvkJRABxHHhOUYXfTAzQCzLCDAeL+SJaR/WD5v+O/Nnbr599rkQMjelMo5D3+369m5Yf9hef3f3/tubN7973z20jfLoqbuT8ozdZXX25Yvni6+al7+6+Oqr6uySioWSCkQzTrF59vz8xRevnv3uzeL599//5tuHd+OyTK/O7Z9+5f/yNX1+ZS8u9GwmdZFYAZRNOCZfzZrgqgug4bzU/Uxg4YqZ9DEphaBDiKNwVHD05BEqsiFBrr5McDlBiGhRNYTU3z1+/9vvfvP777958+FHgTifz2azxfPLlzUtUYuQksY+MBWQlXQPW8SnUbuhKZiCqoxDipGcUzRE1OkH5lRpKnmaMStCCMO+39yvbrp+f372XM6hceS5EDSlJyix/EmiGmPYb9bru5vUb7wFTJj6bbtd7Tbr5fmVHVyOJ8CRTAs1Jfr0WGbMiQQD0BjDbr/78OH+97/f/fB9ur2mfUshTiiuKcKfPGw4rkJVlAQBAQmYZbPp7+72Z2fN1RU1dVOW6NwRh3S8Blu3BmCPHZiHqgF1OG7g4UY+vLcPH/Rxjes1bbc2BAliigqoxEo0kTzoJKVgSMpmzOryPphBFAqmmKFyBqIWFVLSmDSlJDGKqpjlHYFUnWrBJAktJQvBihqKmsmRL81QDS1Ek5A2W1utTT5G7YiQI1ueAAtoxBnBB44zGTOyw5x7J8rh+KSwhUzgETyiM3JAbM6Zd+g8IZMROibO1fNxPw73m7Xq6EDLx76+2S9+WJXvtr5PHsjaEWLCRUOXS6gqVxQgCSSpJJBEJIAEQEiUMIKSgSLJ6QNBZnSewOgg
6snHNDwdUHsiuO/wcUWPj/j4SNs9dS30vfY9dL2NQavaLi/hv/vnQATssK7pmLi30wmQwUICMdhmIzf39viIux2ZUlVQ6V3pXD9wNzjtoNu75IharGNWrjDMOS/MpLSflLWPhvwnfD4HBOLBKbVDvRpEIYS02oTb++7tdbi5i+stjIGIHbIDVoNRhZ6dlb9+XX5+VV6dYVEoIvmKSwQys7FvV/1+JYLI1eLs2cwVx+3nNGnx0/Fzp5rPEA/7iSAE0NbSTuNG4l5SrylZQksB0g70LchvGN46u0UZNIJESiOnkdNAMRQxFlgyEEdyA7gRKTG50nHjqCTwajiqJULwDOzNq44xJOgFu9q3Z/N+sdjWzZY5GXhJb9PuRxs/QHyk2T/nGQJdGTjAn2utOhjzg2mfIoNjBQMhW3dD0Mm6K5ipxl7CXsa1hg2mFlIH2oP2ZgNYOhbGs3UXQzGM5oOWggujpasuq/pyphdEnoA5SxjkWsZJyP7JbcdPgQDHqWOqqiFERFNNIuMYYowKRsSGWDEjMzt2iEgYJIEmVQM1FAVEB5yFV9QgpmnpcVY+Z0Sc6kEqktQKx1VZeF8gQj+MYwgaJSbJ7rbL8QKiZybOysgGCATgHBXAyXOKzgAIsaLuF007CqRe7m/X2zDM34b55aaa1cTY77t+3fZ3Q3/XdrdbeWjPNrBIs8oKMkgS3Gwxu3h99uWvnv/FXyxfXbqmRlKwFiSijYhUsLu8vCjrZ7PZ8vmz5eszvP0+LmnzYj7++nXxxWfubIHzGZYFOEwwjhosGY4KjxHvxa3Yj01TFAuVWiOJWVQJphGVmqKanxfzBp5cCQECWAQLZsk0oSVIe4mP3dCu2u79hzcf3n+fQne5PE8WXMlMnD1JAiICZGJP7Nnwo+t3WL82TQ0zNEVVECViy5WZ7HqjQYoYgvVjGkZLYmpYFOpoP6zv9ne39+/27S6Mg4ICYVMtFFR/blPIaQFJMcWRTDyZgaQ4tLvNZv347EWbUjzQi5/ENEco5fGM4LjoAM1S1w13D+27D+2bt/H2DnY7HAOITuZqqoeh4aFH4LgWFSwhUoCRtW3DZt093O9uPrhZUzQNOY8VfbpWdr0lgYcVCBok9Gj9Vm8+yM2dPK51u8d2oC5YSCqmhoqYXSoBVMUMeNfsSeUUM6Id0n84QRfyyWXTblE0isYkkjK7symQoZGaE1NGZYTEaIC7HT4+QNkAsqWoXWu7nULS1UrOlnZgBgSADDpjNCKbti124Bw5JnbsKC9GyElscpgxiAiIB9NuHtEbszkyR8qoDERE4MkxMZlZNK8Awzi23UaHdnazP79uZxspFUrnGVDaQW5W+mqNz8+paqj2mYsSiIwIUBDSFJAAgJKa4lOa/8kxmfKjNCVKEQHyvVTrO3tc8btr+uEN3z24zYb6jsIIMUIMEJIlEXZyd68AxgRENmugKKZ08nHFIKqZSkrbXXr3Tt+8xR/f0sMD9y2BUVlAU0lTl2m0blvL2mBDy+CufPH1C6RXMFug98fo8Kdx8cfd+XTSH7Nsxz0dAHObSVLZd+F+1X3/ZvjhbXj3XldrL8LM3NRIDoLGbhzbTq+r8PAQX13pZ1dQVcqMTc2LBhu0Ij4+Xj88XI+R2S2+/BpmzRLR5U+yj37/ExP/BzL20xvMwJJpq2kl43sZPuh4LWFl0mpKFkmTaBoJt84/IA+JOEY39DGMLgSOg49jMY5ujI6EyVgdR+QeWNixq4ALRUpmJhEkpiHJKDokHaL0A0vf+HExH5fnoZlFLgakEQwojSg9aMShV4pCyKBEhWH5U+uOEyTpuMvAwbpPcTsaICiBAETTqBJiiikNEnYatpbWKFvWjqxnG9EGhACYMtIhByQGIIbJANUlK0LcDrqFsW27dhj7GMbl/GzeLJgcZeXQk3OzA+8ennJj/9S0H0S/Uko58YnEviwBgJiLwk+gorywmcuyZMfHAyIcNXlIFcxkwrIQwVSkNzBBNVJhgMhc
qNXsCu8JyBPHlDSX1XE6Tt4/kPiwpEynKpWiAU+ii8jwy7V2Jk9WPNztv323Ad4X5e28qUuisO3ipkurALtYjKlUXLi68BVZ2UtahT37s4vXf/nin/zTz/7yV+UCFdqYWgiiOoiOQCW5RV2dLReXz87Ovnwx/2K+v3u+nkU99/rZVXF+XjjvyTnkwlRsWIsMIcJ20Pc7+XGrH3oU8POmwpHDEMdx6MdhCEPQWDXlbDkrF/On04zAFGw03ZvuQQJq1Hgn49vNZv3+bn99fb9ZPcxmV1+9+nrUEC0tZktHnozAzBGD58JzURxM+wHKcjKRswVXFEXUwEGnkGVyLnEcbLOVzTZu9xqiiVFdaeVXw8PN7vr2/n3X7lQSIRa+dOhQ8WeYwSe8m5plcC4yYTJNMey269XjXbvfxtBTURG76bQ+7iMIRyQvHjD52cKbStt2N7fd+w/9uw/yuOKuRxGbgHu5GJbhXXCI+A+5flWzpBFhJOg73W2Hx4fthw/FfDG7uHBV5Qr/NGwH24029vjwgGOwsDMC63by8JDuH2Wzl37AIdIoFlUVFEEJzCY/WC0jVHMGRie7niWPcyygBjrRNSpiMoiqQTWIJEmSRFImJVZDzqZdGJWRnBERdR2vH7Gs8qq2sdfNSjXJ6lEW808S8o6J0Cj7FUTAHpwnx+ymtDZmiScmJEfEOQhGJDRG82Qe0AOTMSQUQDNEysTQzpEjU3XqiZwZdm3XPtyn9+vqpkebV9WiYI8Att3pw1buVvbysry4dOiAyZBUjwgjzE0FBgBKoPYJ+6diLutOSG+buPrNANRURfRxZd99X/zmm+Jvfu8eHouhJ01ZUnLSdTRMqrJaJaLk2M7P7LMXRgTeT5bdDkeWFPs+3d2lv/0b+NvfFt/94Fcrp5EIyDusKq1rACgkYupZe5uN9gxdvMLiBXxmeH5mQAAHJPpPDCQefRSAyavItf1sVeyo6WikIENKt5v+u7e7v/5N9/vv8O7Wj0O5aKrlAiunprEfbb2T+3U0Gz9cpzcX9vIFNY06b2c1Xtb4zNnC7u7fvnn/ph+8K54tF89ev/oa6eBO/NwZ/uz/T5YHgCmaQOolPcjwNuz+Nra/0/5HDfegnSXRxBJRBJwnnHkr2XjWBtr0Ogw8jj6MdRjrIfCYiIGYCIXBMSCjY6BCwSVFFElqFlPohrgfpQvQjzx2tY+LSztf6HwhRaPKBiBgiTi5GEWixj0MoFgiNcTPJvDaJ/thFjA8TWDkJ4bThoI5zQkRbJTYxrgf+zaMOw0bS2u2ncM94+hxdBgJA2EE1EOJfcrCCEA0o8iibtB9GNpRO4V26Ls4dCixcFwWDbH/A/f8xB381LTn7sS+78dxJETnfT2bFVWVTazzU+kna1FmwEdJJR3UDlWyegLApEFuzMC5QG6QEZsqAiIM5hAJ0RE1dZN1rh1x9ioy6BMAIANwpqARzUBNVM1UIG+HRyfgaY3kiWkv69nMXdRlpNSubjZj185cOUdXjlKMVgbgRE6xcm5R1YWrRTgiJva0XJx9/vrsxctyNkceZAyaOpVBtBcdkBsWchzJQ+U9Lxr9bDGPS799aBIvKl8WtfkGXAPFHERspNStwxC2j+n97e7tOt0OC+YSZ6Kj7Tbdw2p9/7Dq205SvCwvZmWJ3p9eCBqABItribdJHlA71iDxYejePdyv3r7bd7t41tAXr85+9fWXifwofHH2elGfxV0cdzskYl8omEqyY0PuERV7OnNteoBj3w9tB+qk8ewQFeX+of/m+3B7nzZbUzEiq6tQuQ+y/hAe+3Yj47hV4KAllk5d5Wc/wyBkk/vm2LErxHAMKiAAvT7eUbO8+uzV4mx5fnbJ1QyRzQ7+6VS3g+m8D3BABUMzS3HcbPbv3/fXN2m1hrazmBQAOEdyTMRGbEyGCCd4gxwIqSlKhAAwMHQ+rFfdbbNfLuYXF64oXVnhqdiXmew66Xb0cEvtHtpaEaQf0m4fdm0agoaEUTIl
ghkqHgJ0MDVTQ7Ncd0chVAATMzIEm4p7k11XRRTEZBaThphCTDEkiSKqSc2QjIAAHEBeJ47ZKfgQU9fi6gGSgImNg+5XCmqrB1vMLT1NyDNhRhcyITM4j64gl0vVuYH1oLlFk67D5C0ZozGaA8Ck0o/DbmiNrCjKebXk2jE6ACRyhS/PFucvnr1M7T7Jve8Td8HNHZ8vzHtJYvsW2kHuNvLhsbh6iefmvAPHqiCIuSMIDSGX+lVRlZ7CtiaJ3UPjrSmCBTURER0G2e3cm7f13/7ef/+mvLnxw8BgSJg5Aw4JEmUBAIHNGt++lw/X+uVrYKe+mD7BzFRjinG/DR8+6O+/dX/9n/x3P5TrjQ+BHREipIT7Pe7b3MBOYExmKWpcA38PifSfdPqrr2G2MF9NjSZPd+djcnvyYc3gUDyyA6IabALuSTeMd6vNb7/d/83vwjc/4PVdFca64Kb0vnbgIJkoq2MrGKwfxlUbUx9S7+czLqvYcmjJYqHK2+3tZn83hlllTVI5EB2YTesXs990gL7Az8L6TkYC6yVtdLyX7ofUfpP2f5u672W8l3EvUSWhxkKV1ZCrSrkBLSLBfijWLXe9DaMPsUqxUmIs0SrEAtkRMk24NUgpmZmEQBAtDqnfdmE34pCKFOaUfBFKjgUbZ/A5zAARrAMEZCHoBdaqtxrfQ/gS/R6pQKSnF5QbPI4pFoRDO+TB+UIEA5MYuxg247AKw0rGjaYt6p6t9dQXOBQQHSSGRCSIgpTd6AlWbwhq4NQAJFnqoyKoJo0qrY6aWsdSeILFFddn0yefTJifvfufeGMiklICAO8cM7uicMyImEuCkJ4E/Zx71ogzkFNVo0VNAgaA5IiAGAxiiKIqImLZKIupEigDxBBSjIAYU8rqzgoGhCqQqehBJEjk6Ni56aNz52+u9+cdFkBUGZ5UrJ4s+qJqFs3F+WJ/Vqwft/32XRfMCfhnWJRU1s57xozOn7nK+WJAVJIBeb5omueXzcU5F15tUJEUB0n7qL3YSILe6hRGdcE5LRwsZoVbVtZ5HggSi5WANbgFFeemotyK9mMv+9V4/y7ermWtzpfBx5hGW612d4/r28eHYRgQrDxfSOYqOZ1lJiiDDPcS3qb4AWxnEMO42u9vHx7W7z/sEernZ89+9XL+T796hsWzEeaMZ5Cq9eo+tHtCQl+IaQrhZPZOHRhw+JsLeQqgauPQDV2P4NNYOkZGizd3+7/+3XD9IW435shKP5Z+X8B76m6xdSkVCn07Qhtn5bIu5nzG3tefzjozAnOERVH4qhZ0XdSkomlo4V5cefXy5dnFWVn4oijydneYxNPGMvW6H+6LIZiIhXFcr/bv3/e3d7Ld4jCSqh3w7ejYnAN2wASImLHxUwPTlPqeUDDjgJ1L203vfTef788vy/m8OjuDkxI1GMi+k+0WHu5su8KtF8MYJAwxhCAhaRRIiqKoucMJFNAoSyygQX5JQig44alAFFHRCAEObacmAAkwmiWRENIYY4xJkiTTZKBoBtm0W05iecMCMInI0NN2gyGAJIuDdTtjsPUKlgvwDRzS2QiYu0sQAImIHTmP3uceMcJjx7RNUXGOm5Ehw3qRyTipDWO/Htf3472yNvXczJpy5qFAUEL0rljOz+0qhtXjiO/rhGVQz45mjXoXx0hIMER52On1Sr7qYRR2HhzJZMsm045AWfSZzNjZwfzlnSAXrk0BVRVMUuxDGMYw6npLN3ezb3+c/+3vy7uHchgdEVWleWeOYUohRZSEIqxAfU+3d3B9Y7d3Nl/CYpnvgZmlFIduN97dhN/9jv7T3xS//X15/1AhOu+pqpDBNNE4umEEMyTCwlNZWALbdfLmbdr3CczqEl8AnDlAwp9YxiegOZum6MGo5pzExGdhMYX1tnv3fvO3v93/f3/jbh6qfmiaYlZV5aym2iuDqbEDX5AWLKOGsZV2GP1IMvezGagFjdIUcV6046aLW1FXZKwzTeFUPqeDqfglQ/7p
MBhNtxLfpf472f1Gdr/T4RsYrjUMadRx8HEsUwAFNEY2jlxqrAbg/eg3HXW9DaNL4sV8UbuyYaqNSstdG2ACIKaWQopAqpQGG9vYroawDy7qnLSZAZMyjQQGUoJ4RAeAZglNCMhAgHq1jaQHCxsoenaRuDy9xjzFDQF0ennIF8LRrpslkaEfNv3+duhu43CHsmbbFTQ4DgXGkpIHdaiUPXwyYKSpMzRH/qBmzszQomnBlpGWkiTFLoWdd1B6711ZVXMmN53Ik3SofWrjPzHtqqLKzmW+BOc9MRuAqoiZiBxNOyEoO5o0+HJF1BAwk0jRgRUikxmEEGJKagemk6xpCRqJc/daSKkqCu8cTlkBTSopxdy2TsTsHGW9XiTHVDguHRPAQVFZmRROzOAnCXlX+vLZvPnyYjbOzPyoikiEdQG+AnQI6NB5x0A+GfWgG5QbDEb9nrvEA/rSK6urLY0JRgMzKIjnzs0cIdqAIKZD344Pd7H9UXRtF6Oc49icU+lAxgAx6bCSoYttinvWzvm4mPlz1DLuhq4dduvHod+ZBfbI7KlwSvhUEdpQe4gb6D9AfEdwZ7AXjW2/f1i39+txs0/nM72cydLvynTDVV348xhpTApmGSTB3hlamiR9jtiY08A9d1ZnHDsgICOBWhwDmZDG7uZu/+atrh5Jgs6KUOrDuL/thzsYdjYuAlRCUHgTH3a7oWsXzTm6J6YdzRCUEJzjZr5YXlyt768D3o+pj2HUkBLS+3ffz5azqq6LsqpKdK7A47I6OdIRTwQIIil2Xff4uLu+HlaPOgwoojmhxo59wUVJvmDnp25RVRCxJKaa6c8OOVyBlDAE6Hpzm/H+fn/9oTo/b55dyRhOPt1i24XdljdrWq+wcAl5jBYEgoImhSiYBMQOrUmoAKa5sg4KqEzKrJihcmpqgIaiE4ZfNHNWCOKEnBtjXkghSW5pj4AKqgpk6pXQIQNXzOAcIJmaxaBgkAKEQbtOGWG/h+0Wz0s4ELcZZN1SRTBSOtTWaKqZ5C0NDNTUCFSBBNWBAyTOltYQh76/ebh5t353vX9PNT9//pn3/mx24dmbGiMCmkfXVPNlsdj7WYFlJURd1FWnRDBG2o3URVx3+rBNq720o68qX3lhAEXFBEAEjHToGVNz/un6+Kg2JSpxGDfb3fvt9mG/6fz97tn1tv7wUD6uqhidYyoKrEooPGbtPlMQhGRAAGKoCuMgqxXe3NqLz/TqijIth8jYd7u72+HHH/j339Q/vi33XQXoioKrEqsKUCEIMYIjRETnsSqhLjONI4Lgbo1vfgRiXW3hs5c2m1lZndZHfjoOS/QjdCA3LWtKab/f//jD5j/9bffNN/LhfdnFEqlk55w3IBFTEFNlUVYBi1aqzVyooGvGuHCwKODM+/MqPcOhTmMvUpAzX1Wlz03qRNmtOIW4/D2HxY0MmNq/Tu3fYPjBwTtH9+I7RgHlceAQXde6qA48kiBbjES98iAwpiJC5ktyZI6LwpfOO2PUzHFjmrOKKmqiFoOMber3sV1HGbQmLgqKaAF4EKgCFbGkscxMFmodiKmAKIiBIIpDKsgkc199moXAA0bu42uYHoaZJonjuO27x273odt/0PCIaVXgnrAnjDwd9MATlFsUcwNxDllwSokRoAJ6NK/qnDKNjEhGokks9d39elPVzWK+OCffMHub4jE7OiJ/2LRnkhnH7Jzz3jvnckyMYDQ1sk8ESpkmYYSMtgIinLrP2eUIX5OapRBjDDH3ex/K+3aA2IJ3TM5Fsy6MQcTlTlRAVck+QYoxiahlhD4hs3OucL7yrvaTImRG9JNPs18y7SDgBC/KIp7V7TzFmbVIUJZ8NreyTr1hMAcFIalRFGtBV5BuMIDuH8aHl/HyHLxnLFyprk4SFStE9sWsLOfOMeFoGmLcr1f7d++6+2/T+GgvJL5yw2dez/1II0AYrd/KMKaeZShprAuZoZ/FRGPftrttt1vHMDCDLwtf1b4skpk8
mWYG2po8Wv8e03tXbJINY4y7/XC7Cg9b3fdw3sjch1IfrCupmFNxlTRJBCIqqsL70hWlER1hQQcMGpyY9uMLAiRTb0VBpClEiwMOfX/30F3fun5flZxmNDI9xv37YbuVkFJqOqTI1FTOjdAPOo4gSvbJepmids88ny8url7c31xRfZ+i9uM+jkM0uf7w42w5Ozu/aGYLREZkYOAcaB7LSUcM3SEjNrb79vFxd3MzbtYaAqoKEjGj81yURVlTUZD3OZdnkiwlhQgpIRyJPxQMIAmGiH0PRHH1uL+5rq+ulq8+Twc2UAAwg9j3cd/Kbk/bHXqX0A2CI3BiZ0koJEgHmNyBYdWmzFMuvZM4nhBeBmCWE/40VVHVVBVB1JJqiimGFGMKKUXVZJYAEoCYCQIjqCkxOyRhZ95b7h+VZKAQRxsGHUdlwq7HtoXF5cclYqaaptKE6YH677BrTCgfnE4eFDK/hxKiUmYJABzSeP9w9+76xzfbH8plSRWdLc+G2FZcIBsAZxxO6ctZOTsrF+TqAhz0KcEeFGAItBtwSLjrbdWm9T7te3+2JGSHYIikCI6YHGqmZAFVdf6pRbTsf2hMYzfsNrt3Nw9/c3/3bnOzW1z31U2EbSxDKpmpLLGuoKrQe2Ce+HOSGQgAIggqYIy02eDdve32FpMV3sxiin2339/ejD/8OP/hjb++L1MqfMFlhWUJZQEaIREyQeGRGcsSmgpmNTACGA0B+xFubrQPstnCamNXV3a2hHGAXxh4LOziEcgKgGhkEsewWe++/379N38Tfvge7x+JvKvqnIKRpBbEWNGMREnVIFkJdl6GRtpKZKn+HMur0j+r+lkcuY/ezKMXaohdEh0GKEyJgAmQDumaT5Hyv5QQ1vggw1b63+rwn7zcOVqhb41SZFRBYhLlYaQhEQiZgUkYwXrxggzskYkLQmM1ds45LnLHKOTWHVURTIopWYgWBu12sd2FfpcgAtUckQa03lwZuOjRFwVqSR6BRjM2RUuQBJKRMpohJTNRVD0wHH8cU6n9CHvIy8MAzETCGLvt/mG7ed9t3437Dw62FbbgArtIpghqCooHBMohGaMADg0Fju09GeaCakRAlIgUMSIEBAWVYVzTzp+dPQ+x874kKi174ZNpP6DnYOohmjLtJ4OJC1+4iUWucMxqKirMpBNPzeQYg5qCiahh0omwJhmSYWFoKinFlFJMMSWZ2tkm8gvMMT0yZWI7VLM+RoyRiHIhwMxUJMQYQtbXFlEDJMy9ct6X3tXeeWYmBjMVKRqB6uOFPDHtstnH/aPbDY1QzVDOCM+WdHU2u7p04HbfP8Jde24MQiZpRGw9tCwt6n27/f7Nt+cXsKzH86ZhwNIT81KpAKqYS0ee2RBTGHab1d319f1379bXd7Fd43uWWxj+hTOHqcHk4qB9qyGZlpa7Lvt218koICF24xjTCATel9VsPluec1nux6EK4ezjzqUaHlVvrX9vcosWouq207uNfFjJQyuDaExj6DextVgrV+dUXoEy0rKsq6Ioirr2Zak5gPiIC7KDkTzmuzO2hxA4hmFsOcZOdEx9Z6vV8LhO+xbHQcENI26TbCHsXBpNEUCT2aAlynywubnGeT8hSD9d+AhIhPPl2fOXr7fbbRdGuv1gj3eyfwxhWK/ub2/enV8+K6paFc6WWFU1HlEVH7vW8jmjAcQw9ptNt3rsVo+hbU0SACgSsXe+LMu6bGZYlOjdhPhVtRiFgoaASUySambUM1CFmHCMyKO1bViv+8eH/ePD0HWn15CShCg0Co6CghFhUAoI6hBEOAqITu76cVPOIC+kXGIXRFWBJKAZ65zNvgAAqJqZAiS1FDXFlFQzcW226Lk/XhHlABkIgMQUPIv3BmoihmpKFlOuOJsBxgQhPCG7mKL2KdvLYJBZghGRaEoE4wE3d+h0RyZgBAZjA0L0SGRkCSWogNgYdYjSqzYOmTJMgoEcl3U9WyxlsdRmF9WNUbAP1A4yjKjJ
IsAw2HYnm61enaHWRohKqJzRj+howl6osB9PH0duJ0wxdt3ucf32bvXt9eabx4cP3c3gbtWvqA7kiZkdFAWUFTY1+MLYgYqlAEEBUoZDICuAUtvR48p2WxwHYFKwcez77Wa8vdOb+3K9r8eQA0ucquM2nQgxFgzeQ1lh08CiMceAhn4AGjgm//gAbQfvr9PFZVqeQYrwk3F4Rifu9xF5haaa4m7XXd90P7wdvn8D94+u64SKUYyYEwGjkZVUEBKRgiPy3hGLeZMZ2kWhV7P0fMZnNSxcgBDCKGGgfef7bbG+E//dfvT+/IwWc17OqanNMRDZSaX3E1vyiY2X8EEGwbTxAI4rohlgNEikQiRMwbmM6EBxFBX6TkfTERAd+QxbYMzJLlEZhpjrA2QGoKomCjFBTDYOOg6pb+PQpjAqAQXAAagVz4M3dRo0DaGptayNfQcugZkJimJSMI6mrfOPVN5o6VgQ7BRkCpkh1ciO7bgAYJpSCvtuu21X6/WHzeZd6u8gbhoe1ZsRq6KKJbQMnY0AdEyOkhGBZyM2IgMGQwNCYwyKQ7IoJKp6wAoDgmgcYzfGNsShNiEiA1KFyS/Aj7H74bQ/pSQ4PCwQEUwJAZjJFwU7JqQDkxCYmYmIWTSSnBbVZKCikFBjlDgOKURJKdeHHJP3OQvAmfjx2NChalGiaDIzJlLAEtl7n3lnC+9DEWJKKalmQhHkBKgxhRAYyR0QJ/Onybknpj2s9v1wb1GxEwLzM+c/X/KXL4qL89jr5t1jCCMKg2BMFgh75B7TYJI2229//91ZnZ7VgZ9dzIrG+abyNfqG3ByAwQwwiAztbnV//eHD9d27u91tD7tUXt8NjxpmNc6JsYwzHG0MIIpUkFPEmELat5s+mKkGBQE09ujYVVU5a4JI2+2bcXG62DU8mtzreAvp0ZBC5M0O7zd6u9FNbwEspDB0IbQpNSrzG0ifmdVEDVcVc1nOal9XOmHj1VRiDCpRwRAhc20jEqFzExUJxaF0ZF0rbdvHfZvuHsNqI33PYRTSMEA3Su9iZDMgVkRQUquCLYItjeaUmWk/SXFNOyARNbP5sxefdX0fVbkogDDJ2HXrtt0+3F9fX1+UdU3kJt4z5k9Bw9NXM4M0jN1m3a1Ww2Ydhz5HwIaIxOy8L8uyqqEozGXTDmhmLiGxEkGIFhEETQBVQMwsGQZwrF0ftpv+8bG9vxvSk0bqpBpyz0owVIuog0IiMABMYklALRPMHQOezCipSIacEBNkEj+Zig1maCaiOOWITQHFMgRGolo8BOua0884ofMAQAyEUByr91Z4S9EkWgIhAUmqmjNnLApZfOj4MHLnH0wARQU4tI/lvciyaUcizGj5qeOFgFAdIAMyMGMF0Bg0aiTCabTQSb8zqrhyhF4wd15y2TTz84v+8jJdtHGwYUROwmNASQQKKjSOtt3JeiP9c0xinPvygODAbccezFCFnMenW5qZphiGdrd+uH54eLNuP7TrO3kUt6amLRsofeXQOfMeqhLqBsoS2FmKEBBAshjJ5DGB0dDRaq3bHfYdFIUSjsMwbLfx/pEfVkXXV0mdy82+CABZKQotlw4cFAVUlTUNzOdQeEQwXyI52rRu28LjCpLCbGGzOX7xOSxPV/rJjpz/Hnjf8xwyMEtxXK+7d9f9j+/Ht7d+u7dREumQTAwK1RKgMPONB+8QiYmdL5jULCiTzQq9rORFHReFFRBGHYdRx4Hb3m+2RSzT+F27Dv75lX/xrHj5wl9dwqymsgDiAzDn4G78QtQu8U7C4DQ4anKVVbQX6A0MMTm2ooC6ZosUkPsgXWddoojqKgBGQ2CcuhskxhAFc++vKU4ZLYtJQ9BxSOMQw5BCFFEzsmg6KLoIZAyBbVDtx1SP9VyKuqMiIhoYTVE6j2Ar8e+ouOACzRPoad1togaziRx6wnSkFIZ+v90+3q9uHtcfttsblm0JQ1kqOgTzOUFnmsVtNOc8Jl4CNCZQ1owCshzdIyha
AByVhkRRSY0UCKbuMokyjnEIcci5SADKxHcG+vP9kyfO5jSTkDRvOCKI5L0rCl8VpfeZIRinOqVIUh0FomiK0ZKoWdAMSdWUUpIoIhO7tOei9GVRFN4TISGa2cRKK0lUY1IzVTZWUwAi551Dr1aWIjGlFJKKgQAltZQkxTDGQBlkTcxESU/rxU9Ne7fqH+7WcdSuly4gLhp/dY4Xi31Mm9X+bjekNoBoUJgbqKOeXIgxamolfjekshurbtd+ef7q+Xy+PK9mZ75aunJmwGYmKQ1Df/3D2++/eXNzs25Fq1cLBv94ff/hdvMbGnjP/+wVvTr3HgtXajn31cJVjXlONo7SGSgbkCIlDjFIVOjGMZmOMZxdnSx4Uwtbkw3FrWlrUoVgux43vW5HGxWRWUGHkEIMlgZLa4u3BkuzMwLHWBExMdmU0YpDah9Wt/v9WjQRY1XUVdGUvmnKxaxasuOpdZkIADWltG/Hh4e43VqIFkVHol6KPda1NSVLQc5RFWIh0hgsRBYhNOPIklSf1kSPhWcw530zW7z4/DUXfr6YzWcNodxci6ns9+v7u7e+IiAxTOTJeUdIyHxIb2PuoMhbXhrHYb0ZNpu4b2UMnFmh7Ag5/tjonHPIRETOAZEyG0ehEWPAqZInpiKUYIw2jNT1YbNpb28HeoJqFHIJWYBMEZIltBFMbOKaU1NAgPwrdqi4IymzIQmRAMrU4aaYSa8OwHgwMCab0LSWAKJZVI1mySxrwghg7oY/Jm09U1P4uiiLomJAS1FAMxwvd64YGBIcdBpOZlb+fSSbulBgsufZugNkqgrONHM59WYwzQ4AF6Ua46KNV3vRDg307L5tYEVjQ5eOL0uqyRyrQ0KsmxldPfdfDH1E2oe0G6MZhlGAvaghoKju9mm1kbajGCe22gykJLSJ6TazuP0kLjGQlGIYQzfoPhU7vtxVzSBfCV4BN4iU08vswBdWVVDX4AuMwRjAEkhATaAp78sYBtpvcbOx7Q7qRgsXxz7sd7bbu33nQ/JmlAtCaiCZOFgmlBsxOI9lCXUN8wWUFTCh7wwZksIYWBKIUN+6ceAXzwA+Ne04obEnnGtG0U1iAyoyDOPNfffDu/DhQVYdjYJqBiJpDMliUlMAEbQGmsLYKQBz4VLCbW+kUmpa1PFKUBUMksQYR5VAlrwGN7Z2fz1s+/Dhms7Pi69el1+9rr54WTx/hgXioTMhT7tfQsir9moBeA5YS3yQ0KVokgREFMwX1jSEHGmg3WjacrvBbS+jIFdWLbRoXFERUO5fVDNlJiaX70dGxaSkksQkMUjl0SNHNBVJoRsMnIIvUulAEFOSJCmlQHEkVCJCYwRyaGY7kzcwpIRboj1j0LR5+hwOqyAbUtUkseu7zXb1sL5/XN1vd+uh60qSynOG7zMBszk/ZTqIJtdgOiJNZPE0qfKYmoVkfZJBcBTcRx4TJSiUfCaiNkUziynFGFRlMsQENrVMwM/VSZ70vxEx+6lTCZmzBBRRzpCrc25qUgczRFZ1qJqzimSK0AmAEpMvXZOilxRztqYqXVX7LNeZ++ZFVFUm5wuAmBEnoSYASqok6pm8d5kWOqkm0ZA0xDiaBaVkBGqE4Bi9Y8dw6qI8Me27Tbh+14WgneC+muvZQmdzdeVmtb1/aDdtTEE0xmASkFipN0mYUGMMcdP3NzC85V3Vz4t+LpfnaXnpm4WrZgCkCnHUdjve/P76w++vN6u9ES1fLKlu2nW3ud6+H1K9kwutF+jms4JIyXtXkHdaYPIpFMnICs7hkaqAhTQO7TahJdMwvDzdgS12oC2kDnUAcaaUhMaofZAo6ggQMWOAQBOkvcV7wyvTDqjOM8DAooQgQ5Ru164+PHx///BhCAMhzqrFoj5bNs8uzz4rXOP9MR1EZmZJYteF1Trt9hqDJpFA1KfK4RzsnEidY+9mJTQlzCLMVKpx8H0PMVgpp1hOy4oUoJC5Ear68tlVVdezpioL3+03Xbtr+9XYt+v1
tSuNnHHBs8WiqhpHnjOe207c4BzGDOOw3YbtLnW9xkiHHqg8LXJhCVTAyJSBCdlR7ohjFgqGqACohpJT1IIxgYs2jtb3cbvt7+/G6gkeUIsqFVVAJ4aokABCbukGI1MFQCTLjRyWKWjQgAQpc9KJgWrW55uyfVNmPQPqmCy3vYElg2g2mgbTBJaz5wogYGp4sMJUFsWibmb1rKpmbABhUJ3w/5PKGEA2OZ9AbAwAgLKgiyIdCrqUQzTItRMinKRfDlMJCQ1JBLrEq/38sXu+juUeNWjF3Xx8dDuG4M01xgyFy3Q3VTMrL6/caMylblp93IWUpO8VkgZxhJRUdp2sdtL1LiT0JSIzGQIqgWVvw55CPw/b2TQV1DBB0dP5tqh31dUor8WeAdTZFSEG53K2HOoGygriCGSgCVMAiaAJwEAAYqD9Hrcb2G7t/Ey4iuMQuxb2HXejS4ltYiIANROBrGMAMAnyemdFgXUDswU0DbAHV4ABDCN2HUoiEQ7Rh5FF4GcHTvBpOJagAMFMY0xt19/ctT+8C3dr2wcxAEABSylRVEkZOaLGWkBDRZH5CjgR90YajWOaleG8JTZSp0PSPmIShkw7FKHbyWaQ+xXcPKa2S8NAjnxVwnIJE9XElMD9+TMHMBtNR6MzgyIlSkFSTCoRQYDMl0acfDlSYWmf/J5RSXoZR4BBUiqK4MvGoUNANUwGwuyYXKZznRxRE0IpfSAWR4hqYZAQQkqRUD1hwVB49V7YR+QIOKolELCsgoSECESd6V7iXuBRHAk7i/3TpzCFs9OqN40xtN1+s1lvNqvtdjX0e0mBvDpi77DwWHgpCy28OWdERqQ4rdq8oQIS8pTOw2RgAqYakvYjdol74Sgsk2lHQzBjAIwp9n0XwiiasqDicdofiQ4O54ufMICSYy7QVA1sqqYRIaKqCSof8gkZpZklhDxDYagJAoJLqIreO0OfYowxqAioFoUrPGdC5Nw6PzW443QiTISUqedIwVISNGQkYi69y1F+FAkhjogOzBMkh6aKiJ5d4X3h5ZRP+olpf+xp2DihIvkylYvkl/sdjnHcrlK/A8EyVfUadIw6KlRmZgk5zjk1tV1c8Ofneu47P47hcb3t79u7GfoF+rkZqUBsU78Oj2/acN2XhJfzellWripSVc24KvbDmGTd6AOyXTpXQjtSu7a4TW5IF2znDVTOFGVQbUX3op1Bb9grDmpOT4pwBqAJRFASmJJB5dxyXjRlxNhhCM5hzW4+q6vaMTPKCOMaaIfYA0pu4o4prnf3m/ZhlN2+e7xdvbt/vF5tNmGIBZdns4sXF5/Hz3Uxv6zq5kBCl3eTJMMQ230cekkpiZIIBC57vQArFCQoudQEmKFbEJRg0Pdpt6MxaCOftMzk5XkE3zrn66YBfZbi+OLlq83mXu6GPmzC0O33j37lqqY+P79smlnpJ+wuGvKU47YJKxtC3Lep7zVE00muPv8kxRhDQDdklj10HgGVkZgJKcvQmQGpmagmMUxmYKqQBGKEcUxtO67XcSEwO5AIIcLyTPs+FGVCAgMBjETGxExslHXJbGKPssmxyhx8mUk+0ytO4C8zMJQDChop68eIWEwSNP/RoJox6zntIZB7KJCJi7KYLxbnz64Wi2VVVAhqY2tJTKccSW4WMF9AUcEJ24thXmuM5PMfoMwamyXjPsIxcgeuSQYVIRmYSGy7+GETv792bx8ubruzYOCVxt6vH30TrYVe2KlRWZAvkNj5mhfMWBSLs3G7H+9WQwj9bocpcIoOkBNIF2TXWRcgiWOmwmMysylBg590/ZzuX0RFUTX14rxcVDCrN375iFd7uwx2BlhMWjUMzJOhLSqoZ1CW4BhUIAwgMWMdEARTwqHHzRZXK31+JYVLIaRhxDFQiCQHIFNOdEz3WKe5wQTeY1lAXcNsbrMFuAKIIUXY7dA7dAzeo6GBIP3ESTkZefc+MpKiWRrCsNq17+/2bz7IpiUxzbhGNMy5xBh130ZL
o8UyhKqZO+/JgAM0yYc2jjYKbjoR3J41V3MnqdhDHBnZ86yAsqKh8IMvoyOz8e4xqg1FyezpK0c+43UmQkg6ALk+sfKWeg0b0R0kkXgPcs+4ZxcnxUUwM/MiXAUuPJEnck2lj2sZZFBzFAsbCw2gqOSQHUrUMWkYJcVU1dY0Oqt1VkntY+1S6ZRAU0gpJBEhtLrAusSmsIrFszhOTIIIYKzCpiiACMgOyCtqxwoQf0xdo3F+akEI+WA0FcBSin3ft/t2v2v7tkvjQBa901kFiwYWtS1qnRWxKaL36p1O/a3TJjB1jGZZRUAypGQYBHCAYNArpGTJUMErFgbeMm0QEhj3fbda3zX1sqrmTbX0vlIlswNQ5tBtDwATAcXp0mDnC5aURCWzW6qpGmRiSTVIqphXmGpJOvOwdLzwIKNsbMgld4+s7IJjIsgCjIiQUoox5bphVoebfOysUEBkU1tqXiaEakE1U1O4Q5xWOPaOmqpIKrmBi4mYmAlr7n7RtPc0i8VzLWqtGm1qcWXb4tCm2LNZ4xtU4Z5s7PoUZA46I52V+moG1ZV7+Xn16jm+upRlPXqJuqcgfpRmTE2MnALaLslGukfAlpuLGZ7Xc+cc04uqqKpaOvP72F/jSk1adDW0QTcr053MxGYzLguaNx4Jxihdwl2CncBG8HHU1aglfbJcsnEwUkPVkvG89ucVzV0kjrWni6a8XDazhtkhaoBxi7xDN4ATA0yawhDuHz5cP/4YbDeE7Wp7v9rdX9/e7bYdKS2bi+6zwfv6xYvP66YpXZ074Cy7W8MQuzaNQ0oJ1DAlDMSIc8E6gBVG3krDGqghKBBhHKXtIASTTxLyAB+pL7KeEntf0HyR0tWz5589PNzs2ochbGIc+hZdSatZs1q/WC7O5/Wy9BWSP2m5nQxmCiH2XRpGFTG1jw36IilGCgMQsgqpkgGRg9ylmb1Xy72xoilBjJYm7RvVqRFOui7ston51LTzs0tKSebzUJQQUwIcHYP3ReHBjEOcGupMzQjgGEHbkYrfVDNcDifhbARAIM6peDWTJDGlMckgOqgmm6ysIQigICIiI/mymM9mZ+fn58+umnrmAbXfC5NKlonOnHcI7KAooW5OTTsAKjJnKlnn0XlgD+SQHLIHMIJj2p8NOFt6RwQKKcRhtdv+8A6+fT/78LjcDqUiO7QhmE/qBwsY0FtVFxeXWM7YO1eScxXXM3954fatzWbwuI7X92HfFe2ACpwAumj7QfsAUYnIeW8oqnjoxAPUY9vB6erIMmzkgebCyxEvd3C5tfMIjYEndJwbjQ7C88RQVjZfAgGEBkxh6EGiaULIVXPBMOJ2C48rbffaFCmMEoIbI4aEctBhyIIYYKACaMCISMAZRldCXdtsAbMl+BLAYOiwKLOUDjhBzXnNn6EuPxj2qQHkkGhBA4ttP9yuug93/fUD73tSE5ok8DLQLInGIQVNo6UyRhm1KEuHDGaVuZkVsZPxQRQ7ksKnqkacjyWNEEnqauaxcd475rJnN0DatnFIsZmPVeOqGbhCSweOmJiI7JSz+smsChBbjT2kFmyHuGMekGSSQEI1UNPohHzhvfdlUczqOKtpP/AoE2+SGERVRk9Y9FFTK30XYhoL0nKu57VensmySfMqVi45SCqiSUWAwLyjwoPzwKT5gZqCCmkiSDitPEUDdIRMiW0H6UbH0uRLgKuTx5AT8pY11EXiMPR93w19l8YBJZYsdaHLxs5msKxhUencp8Yn78WzEClmdoaDac9Cx0yUYZFJaRRTgEFgH80GU0DDCaGqxlmazwD7scPNXVMv6mru2BdFRYSZvBUmiMzHwP3TqJ2QmEUl7ztJLCZ2JJQTTSJZyxXNGLRkqAs883ZRaJAUbCyVSmMqnJX+YNnVzCQ3s8UUo0xAW0IAyMIIeYpMBJwT1MhMVQwiAAGIEiEwkfPM5LI2fNRESI4533UPT6APT0z77POv582LAdwA1EkcJPQpRtBqdjZbsJyN
3XZrd8XIj327LTg+X+L5JVdX9eUX5cuvZ5fPYDkLBXeoKQ06tPFu1d4/xO0D92uqB24i+8ANuoRe2atqlOgbPn/WlOCqbSxX424bVtchlhqQJGIx2OcztzzjxUUxPy9dwaY4qvXJdkrriN/d9d/ctPPZEzY6yM9bWcVoTK5MS2cvF+4vXzRh9E3hXl/NPr+cLStwTtESDHvwHcKIhSniOPSb/vGHt79/c/1bLiO4OI5DiF3btdvtHpRiMOf8fHF2df+ZK92z5XM2l/dMTUlDlBBSiEnE1DJvBqbEI5WOfMmucJ7QM1XeleRMzWKyJFnv76eb18fFA0DA6Iqqni0vn11cXd3dz3c7ZymmcRj7fdc+rlfvl/Oz5fyiKmr0C0Q+iMeYmohFkTGGUSTBITE61eREJIY0EgKYJFZFJHQe7NBpigTMyJ58wuiReVJQNAXVrLYi44BdK3V9etrl6y8KX9CLl/K4lt1usNRWjup6Vi8IyEKEGCCOJmKiudcuh+c64YA0K5FDdnSJkF1WXTPEDHxLSWJMIcmoOgDEQ4ykmVQdsWQuvF/M588vL59dPV88u/KA2HXJsoTyodcbQYnQOasbnH/MqR7uABs5YA9coCuQPZAH9uR8zvQRgBFkdjojJKKCvcbU74btanv37j2//1Bth3LUBXuvbDFqlITDSJuR7/TsCl8NNDdyJRDnDjsEc2VVILvnz+h8aXdrBUQxBqMxyRB1jCpCSOTc1DuLCGA5ijH4VGpMVGIYut1mvL/1H67nNw8v2uFCrWZ2BERA7hBjmmGOLaoanj2DugRTcGRDhxJBEqgAC5CiCe52eHcPm63Oax2DhgRHnkGDTBoMmjOQAgyADpCAHTgPRWl1A7OlzZfgSggRuQBkhMzgSnBsi/zFlTGlzg7nnRtlNt2P1/F6ZZtWUwKChLmKAABGmcLIkFMKnYUksY+F98yuqKtq3iyrZVkuYo2KXPRN/Vg1XMysHFKMII6KEkpvYGIxio1GAgWaPeyGH64F2O86d7HwZwuaN1CWSpAFYj/peSfIyiE7wnvigSiQE6QpmswVdDAlFiZhioUf5zN3dcVj5KRkwAY0RuhHC7EeQ70ZCosOYYBieHWePv9Mnl/axVJmVax9chQJcrf7oe+HcjJ8KmiYoimqsCaWBBItRo7BDNDEIxuBmO1Ub8wuj6Y9x9iAuSlWs/pJjKPEEdJYYHKF1qXNGzhr5KyWRanzQmfOGmfeQYam5M3qUNtCdEgTFpUNXVREhNFL6WPhxLGymGiGvSvYMTeUxnGbwlAWVeGbqprNZ2cZhZJ3stMi2yeFdgBQVUmZoFrUlADI4KjblRmzshvpzdQbC2CIMbT9bteuNn2gSHNaYDmbe+cZVdTUMKYUKQEEhIiIzBnHY7khxzE5dviRNx6PwtiEqKhgYIeqk4DEpCpZwR0lpzlVvYtQfryQp6b9s1cXL6p1G/p9329W2z500YywbEpfN2WtXFYiShaT68syPnvlvnzFz1/y1ZfN1dfL2Rl436GYjiG2sffaDcmhpZb6W1dI46ku2Vfe9eQHNRmCAiFrPfPLgeuxwA2PO9zspHMihdQVXc381UVx9VmxuCzKs5K9s4lTFwZ0e3H12R48x7OThj5EQA/owNgSakhMoeH0vMZwVWvytaNnS/+sZu9y7VdABoIRXQIwMWuH/f3jzYebH95df1vNoGwYAJKEGMM4jppQBdi5+u7t8v2Z8847P/OLHJhkIyMhSEqSK8SoEA1TQubCsRfzaSphMzKpaVQLyZLAz0Ttp5vXBIhCdkVZzc/Oz55dzRZnxUMtkiTEFIax2203t5vl+e7iRV3NCUtyR0SbqUYZhzQOKYxmSszAfAyMTUUlSiRCM02mSszmC1CBqWs709rwVEtmBiaQCfNskixFGwfpOp1/7LZCovLFZ4Uv3esvcbVOTL1025q4abg591hoEh566PYQI4hYGBVHO3hFk6JbPkNCRAZ25j04ZxllqppEQkpD
klZ0b9AhBgSd2mcy5zxWhW/q5vz8/Orq+fnFRdM0OAwSR5KgdnAhABTRmMEXWM9wscSU4IhtJAIm5IJcQb5gX+bYHblAdtNejGaEdHB6mLnwZcIgsu66bv34WKxWMiBjUfqqZDNDTYOkaNhHt8VVS21ERfSFOScT8T9yUXiD4vzcL5dYlGZAYmxKUSUmjSkX7ZAJNdPjgFluf8p98k82rzCOvcnu/lY+vDt/f724fzwbxwUiO6ZsTJkADUwg81SbQeHhbAnnZ+AdSIDNCsZhSstLQlYww3aPDw+2XuvZQkOwJJAEo4JatlCgE4whtxeAAzhG7UUJ9QxmC2gWgA7cDoDA8o46sXPDpyDTaV2cIs+zYdAcXPXDcPvYf/8+3a6gG1VEGQQsQcZXAoIRoANkNRolxhSG4JnZ+wVAvVw01axpCi1JESl4v/PmQKiKoGkiEScWIUkQVZOxEQrKrpebBxELu654fqHPn9HLKzwHrIp85Z+4J0Qz5jnKNVHnXGBnU0IIS0BSi2YBNYJGJnGcqjItFvGZODHKnGaqOIzWDdjuZbczPwqXLOXg6/71c/niM7041+VcSxe9SwRpsr+WG01Asyzi5PBABtmooApKguQMyNTYxJkwKKEK2AC2PtUZgym5DQg6HVZVJJqOBEPFgV2a13o200UdF1WcFdp4mzFWjJ7BUQbxZmw9IiLkbsmseIZOkVFQTD2DZ3GsLgM9DY3AyJLkblgF0CQhardr5832/OrqpZo4YkQyOxRFpkvNFcuneJrJek/TCRCyroJIpuAXUQXMCQwAMzaz2PX9arterR8224A9tzVzeXFeurJAn6XqQvKji0wYshDkgT41paRq3jl2nEWVM4yZTmhJjhT6ipZMTDTGpCImll0gUzWxqv5l014t53O/fBhuVpvN9d3tartGR64qlHgUXdZNWdXny0WDnSy2L87xL/53s69eu/MLm1819YtzVwJACVHQWo/CpT13rL7kAMUelza/LJ4Tu0SgOHS7NqKTirANGEUVyflmWVX1QuLOYWtVWD7j11/Orl6UiwWzpwjYRh1SIsaqdE2FTeEjzbl0125292S1lEilQQFCEAVwJGiXhHTuTciB1Bx4SODRHKIIip+0xkyTpM1+ffdwvd7dt/1aiASd48IssUP2lJIMMWzb3fu7D0qQREtffXb+eWUVgllKGqPGJAf+ISPKitVKlgBAJI7iwApCUjUmHEYMkZOC/aJpn9YOHPZqds1sfnZxuTi7aGaLdjtoTBY1xaHv1tvt7XrzoapmjitkYnIEgCpp6MJuE9u9xpEJy6YSFQgBJeV0FYKBJo1gKVkSZCbnXVmhVTmHC4CW81m5I+D/x9h/dkeS5Nii6AZgZi4igiqTKUq17p4554n1/v/veOved+6ZmZalUlCEcHczA/A+mAeTWVXdc2Ox2ZlMFhnubmYANjb2bhxWRRMz8Zx9WXyePX82fxwur1Pq+9/9fprnqZc83x27Sv2Yhm2STW8STiccEubZS8EcXcRzca0rO4DgzBAiFpcAiYjRmU3ValG1Uu2ktnf/CH8kHIkykRMC0Dl1RL3wbhhuLy9vX768fv1mHDeSFz880OEjLUdyha/+scrkFKTraLvhyws8PCLn86KS2A0SYkgpxE5SJyFwCBQCONBq3M7ERCIkDOYgklIPF5dYgeKVUFWCpeAXOxf3DJpUconF42xhcSlEHjxEixHe5JCIiCV13bgZNlsLkWydY279/caxVYe0ce413PnaUfyZiehh/5BdH//6F/mvP7/8/ofx8ZAc0kWKEUKtXl/1k2tBKVSrk3sX/HJLmy1OB7x/j9ORTgfU4qqeHNVpmen+Dnf3dn3lc3FzqJEp6foDSXHmWD1rQAsjhnXEbrujbvAG4BeFrjIGq3R46x/9i9dZGBKOepqXD/fzX79d/vdf8eE+VCtAYapu6l7JbSXUu4LESRzVTKoWkejeqRkRSYghQcRBVElmQGg14yE3kBsoMynD
yaGNkA01Pi6mH+3xePr2x3x7Y9Ovhm/extsbl88ix7o7uldx6BwfQO8R4JEQdxQviW+IOrLF7YT6CD2AjyyzWwluhgKs3ER3jBXbTNPAFz1tQ7kZKKTc78rNrV+/8L73lCyIMRmf71VTf8JZ8e1MH28YCZuQahussQBxh2a2zG7sLoAzZcJzViMRERjs3ISDiQxehZYunkROncy7ft6leRvzJpRRMASkIFE4CISfPE6a5hwTEwJx20okBGGCK54cfwQcwE4cIvdBcrUlV9VqViWomximXI+1LuYVSMxraG9N97UO/9k8HBM3zafATT+KAksQAXFVL6VWVSJCEO5DCiRUdFkOD/fv3r//4d3946IqnQSKN5f97pq6rlIozkv1JYY5hBwLrfI1RIQG1zcTitW+Wg3uQiwiIQiIajOdcVe1XHOtteSqqm7rTVvnDONnJNPPQvtSlse6v3u4e//x3bv37+72D93QpdKfljIvueOYUohRwsA8yO3r9PY32zdfpXGjcdfTbgDDK8FmxAMHlkEvYkex972GBx/KsAtboVSBJdsyLUJVF+bCnUsfpB/ThkertuyJ1Kizywt++UV3fduFiLng8ej3k+1njeIXtV4JLgd5fSmx29nSv3tWKJJ0qJ17dBMulVCY5kEkdW2uubJWquYq6GLjE8O5lRRm+Xh6fNh/yOXUjK7P5AYXIQkMtlrraZnt/sNSc0r969svxrSNXXBzV/VarVZTMwcR1EHupGbMymbmBDMzNDuFJGGpoRqsmY7/q7Pr0/qT0I+bi8vr6+uXdxc3+fQ4L7Ora6nLfDgcP9zdf9v1mxhHCHdpZHNaluX+7vjj99O793o8CrwfB1U1Is9ALcwkqxOoriF1jhy7OGSvBQJ/cgdp2Oi5ddVuEdqF50zz8twLFURhHGPXd19+nY4H4eyHUOToIR1iTNINNFAKjGrCNi9OBGaEglJQK2p1YncjEQ+tJRjB4nBTq2q16qx6dH90f2DagxaCccPzKII3wpcxvrrYvX5xc3PzYnt5Fd3ouNfDvR8fkSc3hVtLrJzYg2AYeHchl9d0OH66DhZJvYQoKUnsJCYOgSSAGSRgoSYovRq38zoFF3sqjhAhYuRKpmwWBZvemRwFOROYjcSFEZgTSfQQPAR3rFraQhJi6vpuGEsMTxMNTnBu9tduq3Z1E9Uwbyq9n5UpaE9q/3gX5nn+9tvh2+/jx4d+WkKfKATE6M0M0wyluitqRcnIGTl7ySD2caSrK3rxku4/4u4DygItZE5QqoUOB9w/4P7RlyZ67mvGDLTmClr6CD9Xr4Q2idBI+P2IkEhnqKIUVF0zLvOmNfiv2lVPj8lBauXxMH37w/SXf+S/fuv3B1Z3hjKaO7ABurapXAGGC8AGdg/uxpxrXXJJpYZqTERKrOBiJG2+EQ6YkSmoOGW3gqpN/M9pyezOp7m6z2b53R0FoT7xxQZDh59X7fFSumh14xYhCgmIl4ivSd6Ct+QVuod8QP2A+p4MMF0lG+jJbJKSoe/QydKz95wvB+6GOl7U8QLDFiwtQ7Qng9yWL8HPcgzruyICgwJRgAmIHMZuwZvuGhuJG7tG8lXL4bM7z83dzRxQU9ViOhGOSQ4pHMd42nTLmJYhlF60C9QFTsFCRKuAmVq7ngmCZoPcLFBEfJ1GceMm4bpOpyaJISbpUhjCUvQ05XnWJRfAXQhU1Ga1bI3vSQzoSsLxM+5/7uY8vYQpiqC5sa92qe0mozlEa9VVO9pdVfMylcPjw93duw937+8elqKxSzjt4vRx6CXITiktFKJwIklEWZp2Vcu5ydSAFtrhZma15uLmUZqCWCSiolqq5lqzmapbbShCU/vmlWb985Tx+V/++pf/LO8fPnw8vf/wuH+4n07HnGc5TeAwb5eLboxjn0+PkveXQ9kMGLeUtoE2UjtxUjYEY8JA8sJ9Ni+h7y6ue32T4yHbg/CkwTl5vOFtB9TZvFCIXexTCiFpStTleQl5Lwsk
1BQohuxEp4yPE/39kd6f8DCx2HIp81enk+jcb7Zvd7sfmJ+BQwROaLRJE6cKNZRMSkLuWlAytJorWQR6lwHcgzuwOMxtKeWY6yn18ZKvYhdCFK26ZGchiZDAalZNp2VR+MfHu7uHDy8vXu3C1qx6rd5ufJMuM3dVMWdzr42PbgRjt8rIgagkdl23yU/YTr8c6MkdRNR1w3Z39fL29cPHd/v7D/N8NCUtpjXP8/7+/rsY+xQGImBzI4vZ/X7/7T8+/vnP+//6S/lwx1WHoTczZa7kBhfmEAITOajWarXWZeE41WWW1CE6WKiFcGsTtL4OqbSF3zD5UjwvXj4jdBAzSwyX1/H1m4Qp7aWr72ctRytCS5JUFJLYFhgshBBDpE6pVl+WBioAQIyICSE5B9dqOatqyWWp9aj64P5ANLE4YSRKzB3zRvhCwnWKN0N/e/Py5es3F5dXkYWOR7v/YId7WyYrxU1XWRuGEVPX0XYrF5fh+op++P6zq4gdh0Cho9B67YFFQGxEBHFnpxbmA1jA5BTM2Ug4RImxDbY4Nc8WcWaX5BQNUiVq18sw0mZDfQcJTkxP6naAkIQQY4waxNmNXRkayAJ7YBdS8urWkiEzbdL/7QBX+0xo9nB3Fw97/nCXHo5pqbFNX9DZplYIWr0azFEKcsY84eGR3n3wy2tcXiJ1fnVNuwtKPZYTJCKCQJJdlsyPe9zdt0qQnmiPrcN8FgUDzgHbATAkeuopdggRDuSKnNfEzgxmUPPq9n8jtDucnKhaeffx+H/9ef7z3/XH95imVftrZY00+InMm0eOE6g27U+DAWw+z3l/v3dDzTn2IYSzSS8/uTO6V/JCVIAKUzO1UqvCuXSSUmChovRwLPf76WLkm13/5Sv08eejiMSRpCMmsDIxU++0BV263ECuCUxhQtigBs8FmmEzrLY2GIg+2Z04Qqixs8FCTNwN3m809s5in74DZ9zjk9o9npSyHOweiBKoIw7kTFyJC3GVYFXUqZoH1UCWiCKeu5E0S0TAiKrqsszzvM/5zvSO+T7KvotLF3MXagwWBTFwCMKRJLaJdvBzQ2JyEpAQCRGTEmBmtSpq9aruTiwS+zikYdttujTyUsrxyPePpo/V1R1M7CBtTemWvpibWTmD/i2X+GmfR4iThKbZu+rrwU291GorWtkcWj3nvF8mlA/68MP9uw8f7x8eT1MQvozdlnKa7+QAKhNJz5SidMydsMQY9CyJYW6t/x+IYhCGu3FeG+epi7HvEhHVWjLRbC7sEiiABFLlzIYFtXYVS8EzHOWz0P7u3fcf/vMvy8Kn2eb5VEtWU6oKClFSVa1Wp+UU8vGmzx1Johq4ErG7ep7c3ZszNno4uYlIF4dhey36ko+TLXcnVI6OroBrhBsHTnEbutEiWaVltn3OD3mZyzJqdQXlohMdKt5P8u2S3pWwV5al7nOO7tdJA/MYYvTPVbI5gTtHMhN3sNWGnzm1mYkCLTBjd2ZBt0UcwAkkbqo2T9P+NO0d1vVDPwzEfMyHklc5iCZzUqqqWS5lXqbDtJ/mQxmztYSqqqudFVEd5uRgJ9fWMza4GZkJqiqbpnVW8l+fWuv+OUd7DiFtxu3NzauHl2/effePw+HBXU3dTEs57Q8fQuy7NDKIioVjLd+9v/+vv9z9x/8+fvdd+fhAOUcRT0nNuJaqVWil8gKr9rTVWpalTpPEyGaQQN5gieLa2lr+vIyHrsJXP7HxIGIOgftedjsu16Gb4pKnaX+ci/kkIkPU0JsvZrN2LGNI0UnUKXW85IYieAwI0Tm4w2araqXUpdS56uxeiUhkYNoy75jHKGNM2y5dpO5q6K8348X1i+3NixQil2yngx0e9HTUvLTumcK1Oc6J0DDyxRVfXYWrawrPNggxSSQJxIHW4C2gsBp/gR2rsnXzoQQxSAziJCGlmFIIwgS4eVXPFRJcSZ2VgvYDri5xfUEXG+o6J+anH0ogX5FKYoDgAg3QQN4FGhIPibpo
DF2rGoNaM29xGK3w+qfXvH/Q+4fNwyEd51iqtGW6spIZLDAABe6kirxgOuHhAT++oxcv/PYlQsDlJXYX6EfMR5SlDWhSLlIL7w98/0BDwvk8/ST3tQbWtYiHAtpohxGx99gRBWjFsmDJyBmlQhVVUdXU7Jci+89DvVf145S/fXf6v/6Sv/1BH/dtNTaMyc6gfVu2BqzgBxHBhdDU1+Ylh8eDq+oyd12IkSWyCLcGiTemeyUUQkVzsDD3Wou68ZJCSpGDZ/UPezt0y/c34cOdTYtcbn5BTJojcQUDbOtIvrchCyaIUyRSQueeEOQZB+/cGm8hmgEHBQ9RiZCMY4c0GEcD6S8p4T37En36eYA4JaIBiGAiLsRgcTdncWJzh1sgj8Dnob21TcxyKadpOjzeHw4fpuk96l2ifQynGEoMGoIGaeOixEEkEAcIuxCYm38ag4iliTUzsTgHc3Kv3mTiWyYOJgld1213Q7+JccCo0nfRjKbZPWu1dcj+DLkC7qq56tz0GkUCIT5ny69PgygwC1MgJiJvYtUrDAaR5qRKgmanp8dF81wf5nJYSq2lj2nbS0ReHt97njmNFnrlntLIaQxhIO4quHpjETXSBzNEgGatJCwOFg4soemOMQuLc0RiqcGz6hJrrlbMVm6xuTML/3NAfl5O+8ODeTIPRMZnlFFC6oc+dokiL161Zq9ZMvPpyAcTiqgN2TInK0TGzLVIKZSEolJidHFZ8uP3c5zmZEODwCSwDKGjFGI3lXJYTj/+ePfj+w8/Hj4Y7V9R3R6Cfqyl8qP5R+ePFPfdmGMkOJbDfikPe2Uc8pwnXD5ftODOuXPq1AVFiTJpEzqyNvLuau4aRFkdCM69S+ckZprLdDg83D18rDSFDiK9sMzzw+Ew5VJZeLMZukTH06RqEsNq21dy1exaTdXMTBuRsvEiuJl8wkG14ZMOhnMrXdzoSYv85wdVG9R5JmV13vzC0qXh8urFy9s3l1cvHx/v1PemBsCsLvN+v38XuacM3he+m6f/+nb/X3/d//nPy/2dNSUQa7GMPQrVsNLkWIglODmqumvOZT6xsFSlENpOcaumZSVYnbEgNycCVKlU/1xahIWJyE0N5oldE/PoPs/5tNSl2tIH6jeKWi2X0d1CHCn1HqkbpCqpwtSZjQhqnnMtS16mueRZ6wKrRL3wyNwxjRIuY9wO/bjdbjbbcbsdN5t+s0ndEELCNOnDne4f6jLXUmqpqlrdKrMyOzEk0mbH1zdydS2Xl5+HdiIRkGCFB9d+N7fZgXXun8zATeqazo1ykhj7rutTjCCWoqQzZO8heS1WUUOyiwv+4hW/eYnrHfoOTqwQQ4PkufXTa7GSFdXES6SQBJtedttwsQ2bkUJoA0hM3JR9jJyc/Tw/+7Sk6mniwzFMc5eLqILIGgyjRmJQhhORrBy6kjEd8XhHH3p8eInH19husdv6buebDU17LEeYQpmY2EyOR7l/YOyoKtkqAbbGs6cxoxZU1d0IykBEGMAdnFEU84JppmlBqaiGop6rmReinyiFOs6sJ/c2/uYOm+f84X75y7fz//pLeXfnWg1uTOQNcv80+u5r/wJKXonQBLHNiFSWzAZbSt2fUpAkJEEoMAmtrD44tdTk7IO4ui3AEYRjCByoou4XN7PDsZ4my1mqkjv/1BZOnJqYmjsyQOQn2KPb4ChE4phhH2B38CMoN+WFM1nhPCoFBjOLSzIWAC6xTTpYk3J5kmk5H5FPR8lKsVkTdDAQgR4U208lrszuTM18nBzEdO53fT6LyFZKOZ6O93cf7j58e3j8dj79kOi+i5NwCaGKGItzq8VXobcmjQ1hb6Ypbd6SmVmalUow6ggEiDvcqxvVRiCLEvs4bFLsiKjEvhv6fpnT4z4UnatVd/FGFbSGNtqSp6UcglAIAdSR0Fke99nDYBKGaZ1raZzAuprCgEAxBmEmInILLkEQIqoWf3zgeOzjtEm06YNpfvf+B+IPIfYsHYcudpvY
b9HvPG2No3NAiMQRYGcx46ogIzevrbumZAVmNTACWWCJfWw5ctE6l3zKdcq6crrI3V0+nwD/LLSXWuZ8YiGQsCA4pz6lrmfp+y6JsALlPFePqr4smADJLq61LqonswVcXHqrW+Rxa92GrHJRTKdlf7ekQx20SBTuArrIFHSp1udDne/s8JGOj33OMYa4pY2iZ7XAhdCKFnNogXogG7rQJRLmWvykSw4Fz8bfSBJJB+mcoqnDlbSu1BG3s+k3WMlNgAhOoACilnpP83Q4HiotnYetgkhy1mnKuRgAThyJuxrNELqUUkdM5lo1u2bTtcvuOENlzWaYXNctuE5mK69sXqeVm7aq2v3Siz4pxp1dSMBR0nZzeXV9e/Xi1d3j++MhO6Z2hlZdlml/oPcpcwrKH07z3749ff9dufvopxPXuho3uQHGRCRs2tpuTCwSyZm9FDPVZa5MrkoS2sNvrFG4Ps0rr+yllbBefzKg3zQVSNgZGVpINZAGLuKl5lmXDjx0whuliuqANBYXRY1Sk1elRphX1ZrLMs/zNM/TVPNi5kAQHlLaxLiLadd1F8OwHTf9dttvNt0whL6XlAiwku100MNDPR1LKbXWqlabOxyjid956vjiUl7chuvrsNtBPtUlBOIGy65i6GvGT2vfjFf1rOYsjUZU57aXQkqxH6TrXQKsoGY7zRZdyVVEu4Fub+LXb+PrW9luERIZt7yg+dIbtD0InU6olYgRxLqBLi/55iZeXofNlkNqOT4JjAlgJnE4mdGzET44dFl8njkXqQozZ2qlDbdynwy2YpZuilp8mXE84OEedx9xf4euo2HwcUPj1ruOQoApiYGZFDzPsj9QCqh6xjpbOMSKRT/RkxtE7wQKFHpIhAFLwXHCacK0YC5UFKV6VSXKKdlzBSGso5t+FoJxwFXz/f3pr/+Y/vyP/Lfv6+O+5Q+tQiI7p6KEJmOg588F8FbWr06oxYtrrWWmtDK1z1GIz7/cVrDqjFs5edNDYwgHFlLSk3oM0JK8lY70c9d5oAXRdgTN8Bl2gN2TsvsRBPjsdg+7g+3hy5mKuAZpp+AUgUQcGQW0tKjPASRGBFDB6n3wtJKfPp+9FtHcG9qKFiDA1zDe5A2IhMVFWlFBzMKIhGePw73WPM/T4+Pd/f37+7sf59N7q/cSjxQLsTLbOmLXZJ+aedL6w51aerxumpZErLgXUUcIRG1sclZHtWZnFSSFkIKTVTWJIaWx77nrIDPMJtV1ZnZ9TK61zqWciISRzMho3SPPn4mfJ4Froxy5FTM1Z+bA0gC0lelnQpwAUL8L4647PcS8H4In9lqW/eFkDpEYJEVJKfWxG6nfUrdFN1I3cjdw7EHBPEDcPbbu0+pk7aZVq3pkJKEkkmQl2xETyA0wBxOMYc7uHuQzrPRz5zfX4poYLMRKQXh7MY6brZmkFNytFK8AsRhE3bS65oUmrW6nyd9P/v2Eh9nm2a5j/XJTb1/MlzUvD2m6D8f9cjjl4UhSSGLgrFLNCPXhofjxncwPodav4ja9uuxjF3CBecNzjVPo/KILtwseHybb7x+nepnoq8v45W54sQmmOi2l2HNoiMCRJHHsSZKBXNVVyZ1wPgyIQOIe4QkeGo2mCYOrolSdl7LolDVsN0vXUclai5XctnEWDjGxSOz6zWazTSkRo1rxmlXrWYf8ySxhPcFA0CaJwBzYXYi4uR2sM7Y/F9v6xVeDN9lBJF03bi+ub27fPOzfV3usujQYieCmeTntjydJWuXDVD++1+kYyCUJBFar1uqmrhVuxAwzV6g5hCQmDtEdtWQrucCtKjE3rSSW8+Cl+3pkncVIVsbT5/0rFuYYpE8eea7zKZ8WypXcQsh5yWURYIgxjUgxubO3LLP4kENPIQDkZFU1L8vxMB32h9PxNM9zUXcfmTZBbvr+arO92O42m92423XDmLouMAucavW8WMk2T/V0rMdDWeZSa1Ur6gpUIm2lkEQaRrm+ia9ehasr2W7o
mRh+889hFpHQ+K1MZ1pVm9xphUZTljobTAsBQqFLMgw0bDz1hllVzaoiVuE6dLZN4cvb7jdfx9evQr8BJzgTGETG5ARDKTWX46E+PshcxANzxHDBN6/Cq7fx5mXaXlKMjCZ9KWBtwbmBxyF+ZsTXaJ7QpgPTYrq7uqs517Xz+oTHNEbbPOGwx8MdPnzAxQW2W3SDjxvqeo+RrEKNgpA75UKHI/qOzLkow1ezT260fT+HlWf1PAtiAgmqYppxOGB/9OOEKWOuyOrVateV1Jl87k1AZAR1SJODMbNlPn33/f3/8X+e/vMv9d1HnbKtbPaGqJqRK/wpotfz53p2/qmAktVGC3RLRhFr27dxmhs6gEaw97Vp3bC4pn5iTOuIt6JUd4yx466P1kUP4vSpqbauK05NPQg8wI/wTLaHMlCJO8Dgs9kD7EB+JF9aG6NtLCMC9aAdeMPowUeEPZyAwM1awU9Au7bzbMLaYG+/u/VI1o6g01MJscrRuykaE7KRMYTYKQgzB+aOnwHy7panw+mwv//448P9j8fjXc2PhMlDdhiRE1EDthrTtGkiNX6cEa0WLtR6xwKSNoZAYEcP9C1tNno0mDqcmCWwiAGqvuSG0EcKlHrjw1zrBPUczbRho+pOZsU9EydmNcsVBJKfkJdVrdTaboi5VjVtRQ+RwUrJplVYAhEzOaS4IHT9uJXNWGdOUJTFqpdaVZ2hVYrKUuYT0yNC4jT0F1fDxXXwHfnIHE0iByXrirsaOQucFVaN4ShE1aQ4F3NmW+X5QSAOwkzw0AgBlP4FIL9O63Nxjw6TIN3QDWO3LNU9my5uKGqmOGY/Zpuy5QrSMmf7cG9/vcf/fqCHmazQbYKNRqWEOp3u7fQQD8fyuJS8QAtSjUGVaiHLnoqOnb3oN7cXV9d92vap6ztBV5e+HGS5h07JbWfzK8vBlx3Vy5jedPFChJSXrIfFM/z5SB84Yi3cI4htpcs4nNGMtpkh5IhAQnNOapPTRE5szqo+zWWp+Xg6OVBV3UgrmUPEo6DrI3Nwt7zMh+PxcDzM3YnqYmuut8Lbzdfc1+q81XGtk+QsTT2BnMn5bMry09ezPth6KBCe5jaIY+zGcXf94vZh/+p4+v40HVvAJSdXs7LUcsizycPi0wlaWABnp2a7bmbqWhsCSw0QMocahyhCQcSVvBY19VqJ2dQcREFEhERcDWu3Zz0xfEXsPwvtLcOOfR+7PoBZnUgB5xAQo9Zc3SosBe4CF2BxndWHatvoOVg4Gs2lzks+zqfj/ng6nOZlLoUMHSgSb0K46NLlMGyHcRiGLqbIHFSpZqvZS/a8aFlsWWrONS+llFyb96sroNR064hSCtutvHgZX78KFxfS98+TLSISlhUobNME6wNlafihyNmUWZ64tUJEzCFG6XrpR+v61rH01oWJAX3CduCri3B7I5cXlDpwxFOTmolcdcn6uNeP9/j4EKYSERA67TdydSMvX8nukvuRGkWPCKIm1IRmyYncST4TdFrbJ03E16Foqa9xVQKBz6wrd5ijGqhgnnE84u4eP77D5ZWPW5gjdZ46kkY7qMRCbFwrzQud5nZMkp8hjPVnPmsPt5v6NIWVCw5HPO7x8ID9wU8TzQty9WLqpF1fLi48pp9sjzPdGXC3ZckPj8e/f3v/v/5j+sf3uj+5qwq1KN6wxubh0Rr9BSiOimYSqE+hvVID2lHgySk0URhithWTOSfqa4+entT+3NFC++o6R4VY+nB5uZWrLXUJzD+TBgS4Bwt467R1f3Cf3CZSEAosgtQ9kx7hEyGj6Ug43NkpOKLTBnxBfEW8JaqMDAQgEDJwhL6DKmjGeXCwKcufD5FPMOH5XRmoYkUxDKxP466t78RoNXciGog+rSs3m46H0+HheLg7He+XZe86CS9m1d3ONPz13AKt8KWDDM2ICw0cb0fiKlXU+gA8MF0ECFdXSHUYhCWlrg8pgdkUiqgWq0UW6Qfrh6GfMqMLIYJcbSk1EHkp
U62LGtQJ5uZEFMj5eWyvZlntbJTKRBDG2bjCtYE+DhdxJyPAiSVutlvUC+g26dKTGUzcTdWhRM5McDXPXk4276st1RbkI/cb46gheR6QOgWDY0g9caoGNXYnJV4sVEVmJ+KGVTHbWlOtmwfr6MmzAPJZaIc6Si06E8NYQoghCAuXejL1WnrmUIrOk32w8v6gjzm+MIksk+L7vf3HO///vke2ftMNNi+0P6Ri3eLLsR4fsJ/LndZofqo+miYVn90K2aamrnv9py9f/s+3m+suDaEx3gSO5WSP7+3uo72/S/t8y345sG1TCLED2UE/lnq0fPA88WfTVpAAiRQicThTj1YR0bYr16qdOqcOzuRKXuEVRMSBJYIk51qXfDgeibz5A5qJuTE4RtmMXa12d7+vi5GFztN1t+tLaW12w9NSbs1OrD1zkDOc3Jic3ZmaoEr7+IXIfr6kZ9d21jEBAGYOqRsur1/cHF7d319Wf3AqZsoU2FyK0pz9RFgKuZGQh9aGdWs+5aYrx2qNIlBTrSQhMAcRduGSSy2FuRJR09ajKh6EQ4DD7cn5tJ2w5mpPPNynd9/sbbb99kq2k/eTHaJTSCm6myPXvKhW8sKYYEI2Jtn2kosPkXya6nwq+2V5XB7n6bjkUpTVLsBbopFlK9KHEIUFoLx4WcytusHUrFotWpfGw2ruSUU1V626XvxqCsUi/UAX1+H2Nr5+FXY7jp/ZwxDRuQHYsENqkbsJUJxD/iqywfQU5gnE3EL7MGrXG7ORg0miUEpIyVOivqNxRNeZBBd+kswiclGt+4P98J6+fy8/3qdTiQg5pDoMcn0lN9e02VCI5A7X5i3RGkC0Iqp4jj2shwCRMStzdS7mRV2KERTeUKVVIek8kOmg7Dzh4z3GH7C9QOxxmsEBHEFh7TK1mtaMc6Z5BoCqviYo5yXbDqOnxlP7a1XMGfsDThM+fsD9HfaPmCZfMkpVtQou46bevLDus9B+lllc40E5zacfPzz85e/3//sv9v4j19qGAxWubu5NvsAbi1sJxW0hL/AKU6zSfw2JaW6B2RHdQ+u/kjExw1eZsNazb0jyp+25IvPWDI2YrYvj5di9eTHevuC+a4MCP1WbpR7cOV0Yduy94+CWiSqQQQJy8gpb4GWdIyeCkzsbJaWN8xa8Y74iuaFwSXEH6twFfg/93jPBj9SylzPx7tmZ8hkRw9csKIMEkNY2bKX2Cpi3dS8C7py2/qwJaqan/cPp8DjPh1JOZrPbAhRzPatGkHmbwm9/IEOjF69Vu60TaecPJ3cXlj5uOFwRdVJLdanO4BRkiOM29j1LYCCECB7UEwfrN9jlnTvBUh9HCV71NOfqZtO8L3qS6MQgN2Ymqk+j/u2l7tksoOXqIcjaoamqZrqevedstYGWQWSz2w7hpusmmfdWMiyfmJqgSR9oGCIBblqXomWxQ5nzQfotd2PlaBKp6zl1CF3sh2F3FbqxGhVjo1BNimlBBYXVsZnAbAKT5mJdVUHuzKmMz/bHZ6H9Zge79eMpn2bLllRlnmeFPewfTEtMEkM3LZOWcqf6415/uNOrC06Ja0HDq3UhCrHrttGiZcvzkh+1ZDc3ShR24kQ5qFVnpWJmiMNFP3zx8vLXb25/9abrnVlLqW4gFtp0vulrt1t0qA80LccyZRLPhIPZtNhhLoto7fTQfYZFUBvNpAiO52Wq52V8Zp04n7tKIKuwAlcikpBi7GPotWLKy36/JzJzDzGliKoGrBi6qs7TtEwauH+5ezEvS9Qzo6a1n2lVgF//9qkWa4dcQyMJq7Ygfvaipw0IanjZeWeem5ZEHFLa7S6vrl/sLq5Py/vq97lUNoIapowTdFI6KeXspXgpXpoUbrFGdPdVwxDM5AZ1s1prbqfQ6itcqzeTiEZYYjETMQPIG8GE12MW5lgh+s9ezBKHcdxeXI1X+XBflpO6Fa41khtVRy3H4roQCMqwHCVzmIgiPJfDPB2WU15OdSol1xrUR8cF
IRACIO5Uqy2zmldiWFU3hjV+k1pVLavVsaO06F6rtfE9bnQI4ZhoexlevIyvbtPLlzIOLPLfZFveOsZNOGKV8fk0srq+Gv0uhJhC6mvsXMRQ3Qxqom7V0QxoAWVqbCK4c+P/5ILjUb/7ofz57/Td+3h3lFxdgm16u97Ji6twfYkuGRG58TlonFOtRu6inzwNEqYgGqSEkKlms6gubEQK9yDUpPSA5o3o7rpSCPd7fPjguwukHu4oBeY4x6u2tLkqW6F5AZHbarbUGpjkT43p58HQUQoORzQZ2g/vcPfRD3vMM3LxosWwxFB2O335wtNnaPaq1UkwOFVd7veHv35//Ov383fveX9K7k7NPtbrmQ9vICNXoMILkOEZXsmM1gFIArd+hoKiQ4BAEAIThOzJS4QArLqx9DTe7esMACm8QBG4f7G7+PLV5qu3m1cvQ59+WZGKE3h0unC6VIywCF7IKtzXdBsGL+3XYZ0oZCdxHihcg3fOAyg6RQq3lH7lPBAA/QHFUX8ApbW8+Xz5fgr0Twxdd7SUAnCE1uUgB8HA3hruxEJBIIPzDvQJKXX36XRa5qmWxayYr74q1ayq17bHff2sjuoQA/EZyHHw2YrXscIi3ggroeM0kndGfa6hVHFPLF2MXQiBmBghuLD0TpGCpqSbzUDEsC7JNiYyn6oWN3dUEJrSFUFCm1/5/MZUs9y8W7A6rTKtJl3UvFjk3JJrHkrEA2EruBr1cqs4xNPDvVddlpqIHdhs0nY3MJOp1iWXKedSSz5WLTYdi7OSUIiSUkyDjFsrs/UbMyaOIQ0iHVMsLsXYsQoCOxkxhGGmWbXxCntW/LPQ/s0X/E2Qv/2j/uP7ep9tyrh/JDvw4/7erChZ3/dTKW51X/393r/9Ua/6sI0BkIFwE+1t5zSEq5tuZ7yd6sCAzkSIPbY3sVpX96YHXQ5lnvRkoF2/+fXrq3/7ze7t664fdXlYlmMuxRwSu9QNw8VrCS9KvT78yH89vN+/e3TN6qTg41IOy8IbShfiu+cMAiISUHAKDZtq0b11xp68ufjMXgPMPbsXeGGmEGKfxiFtCLxMy+PD3uFd6rq+dwslF7OaZxPWZal5qe52PB1O81RVbRWJ9jbl9oQ7nrMJCNbkoo02NVr1WdkNqx3WL7zojGT5+c+faDQiYRi3FxfX2931/X4zHz5aXkJgy+SPaifVuWBSnxfMs0+z5eylWCleKtSaiDxAjeoNM1et2UnrSgZdFc+eDi8Cr+7mjSJL1MDlNdStKP1Pji+RNA7j7uJq9wKHPZeFlBZMVdgTzzW7ay6zNrlseK4yBwkVPOlhOuxPh2m2XIzUouEC2LS4ToBbLbWcjnlZhMWB7MpEHIMzO5G2eT14hRezolq1deC88XFFCBLRjXJ1HV+/Sbev0vW19P1PLqH1pdv1nnnZbg5uQfScwX2OirW+DFiChBhSJ6mDBEP1UjFnMiaCR7Ep65JR9VwGejCVWny/13fvyn/9pfyv/6Jv36fjbIY6iF5u6NVVuL2KV1skqa68MiIN7twYz9biGX02/EbgEChG7dKS4jKVXD2qB3KGsYOtdYta/mmAoUkvLYzjEfd39G7rMSFGWEXJjVrxdEdITbzykttdMloJ5GxtHRnhWa7R+obLgsdHnE4oM+4++MMdjoc1tFctRFOK+WJnr15iOiDPT8v/7KoLV/OlzO/uHv/j7/Pf3tn9gXJ1ZmM3h5IruTOMuM2eVjT5Vs+ODK9wazYooMZ7qaDqCGhsbpw10tbPdO4+tYYAPUXIlg6RK/zkNXT92y9uxz/86uLXX+1eveAutgGWn/TdiBLx4LwzunTfuvfkRlaYnKh5FtqZ4r4GPCNyYpKRwwvnwQCQOi1OifkWMhoZYYFtwR0Q19B+Hvt7Oj3O6KLTmpm5uxJlkAECYzc5A/dOHDgIsSBGk8H50p6FdjOblynnRVtHcp3xRVEU9azIimIohmAUnNhATcVo5fG5MNgdtg7ytZvLJC7B
WbRSVlpKyDWppeCpTequWJr0zB1xJHcO1A9JODINkTcpsWM2BBCFIOydG9fizBBu2crnBA6zUqoLu3kUtJPa3WDORDHEEAKLNDWdJNxF2Ui/4XjF9Brkj3Tv2XIpSy2BWPjicry83klgd69LzXN5eDw+PJ7mknNd5qxZjYljjOhHHjen4yOnoSJwN4wX191mN3RDJpnMq7G5gIhgIkIc4G7mRa2ol3+hRvflW3l5Gy+26Lr65x/r9w/zPGkBOWo/xoub7upqF6JE0q6eXqQ8XpqJnxaQqRS6BH0RSKF9njZiFx1tQxwjXEJHqd+Ey2spR8uH8ng33T8uqoTr7dXvXr/43ZfD1dYJyzwvx4OZOZGakUjCFt3Al6i764fQPzix1ZjQjWEwJTVKHAYvHebPtksbPk7gCMiZpdbqh3Vx0/nohbVpyepuTIgxbsfNxfayi52Zn6YFTDYihtROqWUui3spXmstWYlJTbWNu62k37U6OQPxn7bSp8J9/UyfJ9Q/p88+9drPAG3bg40p39pBLF3qxmG3214P/e7ukfNSu1p1gj4WPVCeWRb1pWBefJlRitfSpp/P/Rleu25nlpCbVjdCeJL5WA9u4ClSWMtMVuHoc/p/noX76WUwcYppt9u+eUMwicFPY9GHYNMjshGVuji5laUVQhWwUvVYysNyfJhPp1qyu3k0F3cGAsBwwBVeTOfFmAqYmlESMZM7RIzZHGpWHQVWdAXW3Lzl4AAxCaeOt7vw4mV68ya9eBF3W44R/+Tla2935Zu1h30uTM936szKXTENIgkpdIN0vcWoNKtWzwucnGGJcVrqaeFSG8uaWsA7HMu33y1/+ev0H3+Z//ad7A8MsrGz64vw1ev4zdt4e83b3oTNmwPbOYa3t2FnjRj7fM9HkT7a2JWhzFOZqkX3oEoEgpERpN2WdXsAjlpBRNOEsPcPH8ABY4/AWGY3pbUF5TAjVXJnrSRiTCpSq6o51JvWNbCOHVDbkOZYFnp8BLnPkz/e4XGP04RcrGp1LCFMw+CXl3T7Ej+Wp9CONszWbAJKrQ/H6dv3x//4e/7xo8/F3LJQbXIyaIwKGNDayIVal52U3JzUYQZt5wMZta/AKrE0yVNixhnrO2fq/hTRP30+U5cJeejiy+uLP/z6xb/9bvfmNm03LqFqpme7/bw3ImTL8ZbTW9cfzT6aF3Jds4zGbFsFrdYQvJoakhF0Jdb5BJ5gN643TjfECShNNdbPViOf/96nSSlfDQtXcHBtWbgHd4ECxq1k5WbLFBJCh3jh4RYE4FNN1URYfVWKC27RPVQvWX1RXypihYizONfmuNKotxTczV0MvD5OwNfBUajNefF6OC20P56mhav2wOAmNRtx5diieyASAI3bvix5mVwoWVDYqeYsHJgDc2Oax8atT3FIseefQBoNNHjqDbSbSCyRiSmEQER67qKbMeCiGqjM0bwLqeu3m60WA0JRJcZ2N1xsuhADEWr1vGiQwBL2x9lPc61VS23Odmzk2R+XU3YpEO6G3elxe3m1u7rquo1wNAmQNnAU1LGYAivjw3/2hD8L7W9fx3+76i8vedxy/T/q47zMU3aEi4vx5evr3/7hyy+/vr25HrY9JZ9SOfTzY1/2cz34VH2xTcVr0Lzk+v5jN/Jmx7tBLjYpjAO6calxWcQWL6f6/v3h3d1xWGBXV2//8OrV1y9Tj1wO0zzlZWqjEapaCpYchUceAl30ZdfrRUzA9RVevebQJQU3E4ofA3/32boN4ETSEccGefuqV/GJbUDngxjuWNtwxoQYZLvZ3lxeb4aNcCi5HjG7UorqRmWpp8Ncq4awOLxWTx01TeCnvf5pztv8KSivW//TjNt5bGplpf4T5PdZcvCT8+ATJg8KIfbduN1ejcMlIdbsFVqPpg9W9oYZnNVzQS6UM9VKpmu7c13WZ7+lhgAytfk9tdY5cH8+WL9yluwJRKAnMs5TRPm51jcRC4ft2H/5lsc+bjZ8d8mHd4Pu
f+SFUgQrH4SmvbsTsWop83zcL8cfjvmj1hOkUiJ08MHROWI7BBzqyObcSMCEcO5vw1q3zgwNAPRqXtWqaouCQtTGz3GWqYm3t+nN63R9FYaegvzsEP6kPHAO358GtVtcb19tf1htMtbejLDE2I/SjxpjIapa1VlcwGxzsGmR4xSXLO5soGq0n/yHj8t//PXwf/7/Tn/+6/Luo+QShg43l/zF6/63Xw2/+Sq9vOIhVW6ziOffu/p+rN5e5vQ5q5FiijJ0vq150umUT1lTLcGUzWmFkwzO4LPh2rnZAppBgnjnRNiM1EfkE5qDbuugmsGU3RhOzBZCCaEuWdXOEsZOtM5aN7x3rdoPj9CK0wnHvR+PjUCnahk0pThtNun6Kr18wfcfnm0AnIFp0mnJH+7nv/84/ee39f0jmSsjM8xN4dVdyQt5bSy5xp6jtju5Xaw5r3iR4zwTB3NWCnRe4mc1gzVBX61qn41OtdWh7gghXF1sv/7y9f/8tzf/9ofhxY3E6M8q/s+2OCeErXRvYN9o/c71R9eTY3EmNxDDudFOn667cUOMfCL76CTs6qRwQq3Ge/KvIbfkC/nyDAj0T++yaUQ/mZmBaI3u/ikhdXLntn/c4A1dDMFCsjB6uPH4BnIA7p92hJqrw4jAgbgDd+apes1mc/VYrPlJUVsotB4TDhi7sHGjm1HTfAKxQMSK1uNj1fI46cfHw7SQWk80mMpyWtytY4L0zUyljcfmvDw+nh7ucmTro6dIKbJwiGEYx8tx3KSw7fsxhhRjl1K/9mHOL27itkxnbiwTXIhiisxcVYvWUqtqhTkTJ5Ej8sEfuZsvtuUK1I3bS5cQh1KqQ1MXuyAphRhDVcrJJXZp3KSHgzweWE7Cs0MlcDeIkt3fP96dltmYQtrcv39xffNFfX1183LcXko3SkBKIXXdYdF3h8UNIKFWyctnWOnn9jBDvL4eDJ1LfZiX2fTlnCyOt29v3n51+81vv3z7xcvLq83YS/CZ88FP9/p4Vz9+fPj+4/vj3fEhL0dnRuppI/3V2G0HjokgnImXIAvFJs/VF750wQLf0UWvfcjEblwQyGOg0LAWDjFIIGqTKTx3Y+UXfDP2r2/5zRcpdQbkkjUvpid892nAh0ACiQg9SSIwnbkrtG7EZ0+yQYL6JHsJYd6M49Xl1eXuctNvTssxzwWVFqkMLkVPp7lWizGICHiddeZ1PZwb6mciJeBPR8GnbblWAOv08zk1/O9fn76NGjLvjdZGkBi7zeZys7kKYYRFGGFxTMWPpU5KWVGVauVSuSk48CqR3Ea08VSGnlvEbqaq6/tuvYvnOGo7B9axdpwBaawB5pfzEeIY427HIiEk2m3lcN3nx06P6bgNMcSuD4eUm4LsAqUCdc/KxXqlRNwJJbfetVeP7szkRJW4AGzWCLXBvM1UoZWShNZtVffaTE8aA+6sVUESqBvk4jLe3navX/evb+NuyyGefaL+5eN43k4/s2ueYzAraQxELCF1/WaXtzvruxK4ZLPGQHRzNZ9LPcxhylKNtNhxKj98KP/199N//m3+2/fl8aAsfHNJ4xjfvkrffNH/9pvui1dysfEgoHNtDbezPgydEZQmiPj8bQ+bMTC0mM5lOQ1Tte7k0sY/YE6rW7T72mMmNGk7cyrEMw5HcCBTaAcrUIM6WuXbVIxYpO/DMNQGFc655ErmrVdD3CwyfZWOqZXmCaaoBdMJ04mmxXOxagtwjGEex3p11d1cp5cv+G/dTx8BQA49Tqfv30/fvcvvPtppYkCpVeRNHtPNXAkVa3RXYJ1hWcGzs37syjf2dcIKINe1bDsTu6mh0+37GhP+U2brjRWbxv7qmy/f/vsfb3/768svXqdxaE3zn3W82zVE8MjpFexrlL9a/Zb8HnpoDJjWnWm8r5XTsLIGjWgBDuQCd/ICrm7q9ei4I3sFwHWC3YGWlVf/CYtvQF17O/a0cx0rz9Aaf+EsFgCAyYmFqHe+Un5j8tblS6e/
P4V2NABy3VhBQuc+mI0OrW7F6qIWqwv7ai0ANbDDHKTsQo1SvOZrYCcxqBMvvjws5fQ46cOhTDNyDcTBDK7ZCByCpApWouKt9VzrfFr2j3OSZH2nXTB1YbZoKQmhS2k7DrsQYwgpSLCf6GvBxdqwYFU4TGmdegAzF9WqtWjVdTiICups81ROMZ12tSjn3opWz84VbGZ1LrnWjQ0bHkESouxiGjab1HUpxSiSAqtVMDiKmqvXpSzH4op5XhbTwl6Pp1MYd9L1ImkYxt12NxkfZjtZmNFlcDX+V/YwQGAZr64ZEY/THJJnukzbmy+/efn265cvX99eXt6kfhQRokw2uR7n+7uHb9+9P8S/H08f3p3qMd9s5ettf7Htr692m6jsy36p99P8ADqBo5VUFvLScb0SYs7dcufHH4kTickQgmyEQ5AQhbsudkNfc635Efqw6abuFl+/2r5+HV+8ijFkKLTkWvzjj8Cz2V1ndg8IPaRbBwbWAETnWhdYx2bOg+dtutxBTEM/XF5cXl1eX+4uSy3LXKZlBjGDa/XTaXEHETcbEOEgHJu0kjB0je6N69JC3rrwnyAefvbxFBvOJf1PXv+sbD/nCuRwEEkIaRwvN9ub1O0CD1Kcc6FZfSp2XKhoc9U8Kzi0hdpMkvm8659COzczVD3zmUmEztUoPb2xMxBt+JRnrOfjL73ltaEfAm133A1yczUsb4bTY3+87x53sevSMMQu7g8Pp+PBawkskaVjJvZI1EsTcNfkdfCS1JjYWYy4+MowbNGo4TFuAJyaWFCTPmtvi5mNQEbU2mUdj9tw/aJ787Z/+7q/fRHGYc12Pg/tZ77DL77OuOmz6QVf78VKrY+xH7YX88VuHvoSpdRSfR0XJodntf2MUwlFVa3eP57+8cP0H3/Jf/s+f3w0EF1fyfVlfPVi/NVXw1dv0+tbubnULpk8MTDWMrhlNkxPb+Bs0HJ+jbvL2HdTrbbMZRkmtWjGE9hXBRtfIYBVd0eIABiMUEEZYaIgJC2BMldHMapO1ZpUEcUQd9t4eVFjKOZ5P5UpkznBqzkD4awTBlOUDBiWCSVjmXnJvhRTL45TCI8pzhc7f3EjL16km2tOz6atAGuUMreyPx6/++H0w7vysEfOcm4rnRsn5zYJrd4wdh6YaxuTiRmhUUBbj/gc4H2VD3H9RPNfW+veEvhnqBWZm8ERQ3918eqPv/vV/+t/vvzmq/HykmNsTXn+tDqe7XAScE/hJacJ6Usuf4H9g20dhGncdG9UebTZJz3j8wWYgCZSVuAzbIHde/1gfE0eyQ32gXACVTyR/NZ1QWdPLIZX98b1fsoGGcYrlbMxkhigaNgY3yr9yvkbki/A9892B4kk5uzORCGE3n1UnUBq7tVytjKr8eqio95oRsQAKXsrv9a708bd2YjdgGo2ZTlMdDj5aaGqJAHmTRqeJNZEJaTsZuZstVjVkusyFY8lisUooM4R3BPQMQ8pbvt+F2MUEWLWUj5jy5pBYVYdpKtADJhoWRYiWiGMM3PS3XNVlHJaFplOm2VaOPeWXbU2fROrVrJZubm6cJKYYghhGIfUDX2f+i4ERmAvtRSz5gsqIhKClZJLVXVXy0sO7z5qiE6BWHab7c3VDfWbHIYcNkU2hWJ1voyfbfPPQrubu1arYMf1Tr75qg/D6+Hy9up2e3kzbgfug0cOIgPxBnxJXIivat1231Ub3tdwcq7DkG5fbF++3O2ut9Gzz3ac6g/T9B70yBy9DrVuSx1NIzhqnn78/t3gcj3wtmfputjHOMbQBQkxxZSS5UOdZp8Pg89j0O0Qhj6kFEIkOEIKqZb48JPOqICjhx7StWfzdMD5TxmR54XeWDJEDMQQhm642FxcbC/3h+PR51KKanGn5oHCzISzdwFJ8/EOFIR95XbwGh1oFYJ6+p1rJDjX9+0f2H9hy//y66lCwFPMJyKHSOz6cRgv+uGyS9u4ZFGlbFiKLwVV8fQL2giLCIm0EfuGovuT
Uha1A6Wx31f/8fWEXFtx+HS+nZvKTzwf958Gxc/ePBELSIRTkK5D11E3SN+lceyHIQaB+Xw4eK1eK8NjpBhD18dtvxu7MVQLS06HQzdNwZ3V1o6teYXTioJ/6vq3/KON37Qqi+1sDiFMMfFmE25uurdvh6++6l+96i4vpEv0PAN8WiitzWJGz9oTTxe/fnr2nz1DhojAElK32abdxTT2pZO8lOpuAoiwRFf4MftxsTnnvBzu75fToZrybtuHL72P2I3dy+vu9qZ/c5te3tB29D4Zk62CDU1PAE0O/swvW2/E81FqAi6vX3ZeSetciiotwKlRf6eZtD49u08iunQGddQoF5K5eXDC29CEUa2o6tXc3CT4uO1evRpevsjjR3NfHo7TNJsWNquEQJRWQMOoEQ5adK+FcvaiVj0TH0X2Eg4x4fIyvXkTb27ibsfhWWhvoxTuXjTvD8cf3s33d3U5BVNhDg49N0M+QWUr3xsCehKk5zPWQiCyp6hPTwWyt7TG2doU+1PGgHMvzdtq8eqOIOPV5eVXb1//4Xevf/+b7c21pPQp9fsFPB4AOYRohFySXEJ2vs6LGxrAgQTegOCm1EI4GjYVQBGUyJtERgWy+wKvpCd4gAN+gi9wxadN+2mTtxW8ZjGrYm77SkuU14aiUwQlo0ultyq/sfAHjl9LesH8iWpKREM/5lpWPTlOCIPbrhEADaImRctyVmSm6sRK5O7N9g2NWrD+KxuJE5MBRZEzLwstCy1ZsjLyqdnx9J44JYma1grZa87LaVqmU8mz8AjGZndxc/0icJdCvxkvt5vLYdzErg8iq/Xs59KZeVkO+7zWYO17/LydW/bRjm9mJobDqkpdombzfKhz0Ql5rrUU1SYno2Wxmif1BXSxGXabvut7YYuBUmRmN2hVy9WyIVcXil3wFKgRF5aqPi2UVUXAgSXMRQ9z5m5L/ZY3ymOsjOqs+hkq9Hlo16J5no75dPQoF7cvxu2Lm/HyFsmJdFlO7jzUPqaR4kghsZD0F+P1sHl5GK7/urnbs9eX1+OXb65e3l70V6NnzlpOh/J+Wr6DfJQgrqPqbbHb6legUOvjd98u8x2/3Awvr69evL647HvpYtyKRJZIFLXOy36y4zHVuePCGr1aLa2jw5AEVshnfGYndkRID+mcpHHnzpDxWojiqVJuIBoEFNoOFJYudptxu9tc9ule6LDUmrPaShdiliASmAM5kQuTCAVhFnY5qx+v+gK+HhJt+fva0f6E8YHImyoT/Xzi9b9/PQH0zBLT0I+7Ybzsu106HEUXKu5ZUYqrcdNoZgY3RahAEoiFQKtPfBtmOwO6aFPsDgpN9HTVpzofiZ+663C3cx7zrKz/+Xv9BOkTQCwUuWO+iimN47i77PqO4NP+8YP93WrWmgENkVMvG4wXly8uLm6CQeaZ378XfEzzJLmsE1BwB7ShlmeXQ6yzW08udWdNAyYmZonc9XxxGV696r78cvjV1/2rl3GzkRjXJfJ5Eujuqioin+J6C/XcwGoyYpCtnhLrd5x7FSAJsdts427nm6F2oQiKuQqRiIToRjgtfpx1yvN0POwfzKtcbrqrbRcTbUa+2MSrXbza8W6DoTOCEszV3Mls7azDGzviGYnkuX74+hQuX7weA9W6WM1GlHlV3QmmPBu7rm4xDDHIunPWFgdq5Zy5ia3W6sIghhvUoOZE1ve4vOrefrn94vW06bPq8uHxeDzaqbK7CTtTJYTmZFOLu56bYpVKdUV1PkV+DPEQw5TScH2dvngTb65lGEmejSA3iU1zyzU/Ho7vP8yPD1pzcGOSAHZffW8YZGiufJBPCTEznNkNFAgFLka6murQOcAR0dnGo2m1rQncp7SW1oGJNqDmksL46sXtb3/95ve/uf3mq367PVNrgGe43ecb4yl89qCRqH9y6nSHe3QMxJdE7JhgB1gB4Aig3nlH6ECBfFr5X6jki0Mb3ABTeMXnHZnz9l33btOYsubcuk6dAWhj
7u3U6pU2yq+Uf+Phj0h/lPRF7C5YPs1aMfNm3C45C7GCAgfGgLhKrztEXapRMbApqbE6VwfMDK2R2aQ9mIyFhVvDQL2Z5xpr9Vq5VJ+zZZ2rupkpEDsdRhPRRm6ry+m0Py7TodbiWCT61c31l1/+uk+bFIYgKcWUujY1t4690jP1YgDTNN8/HBrtKMQoEtakvtERpInjEgmHEJlIrEbLo9RUy7Isx9PpcDhOy7xoSYHHLnjNWpZjtZP67XV20tgFES6llpqXMp+W6TjpnFGMi5F7TIH7wKolqzm8NG1ZEY6RQ7co9g9HCiX0dfBuTJcmoVojy35i+38W2qmSnfz4Qe/urUbibSRhRFeiolg8Z5w8HHvvIpJwZxwgHaduHPvbi7676iJ1bzZhKx6saJnMqjKFKGOSzlR8Mtfiyon6MIwUO45Hsmkqej8dK477+nF4DOkipW1KuxB7Dny6f//+r39f3r/nusTorKtpTpvvJO4ZHfHu871CzgFxRBzPo+1ETX3oKbicC2UCAwEUiWNTDhWRIEE4iIQYUgxJuPAq7tuQ+MAcaGWZMqOVwMIMDoGDcBBrLl/nQvZcOj1t7/Mmb/OJIjgT6j7b8u3/z0g58M+jP5HEmNK4GS7m7iKwkh6sGEqlJilKbM1ciAUiCAESWzPAFNbC37nQcUIT7VrvFTNLcHet1tCotceIc1D1tTpqF/BPm9TPSQ6AEZFIIO5FKEY1rbXuHz5+eP+PZTmVZdLE6CXWLsVhvLzYXd1EEsmlCWLJw304HrgWUm33xbz10lu/eFUq8zPfjdbuojOCSJBhlIvL+Op1//XXw9dfjm/fxIsLiYH4Gez6k7f/s688Yb3mzn4mzp35c6A1vsDBErph0+92Ybe1cSiPx5ytwsQNBmS142JzNq0UOV6OnCS+uokhxdRJ10vX8dhz33kKFthhMIU52dpqWWEKBpokkp/LMfhPxAE3F9cXQ1drFuYppiKiHGZCcOU98bxAq7uzuTCtEmxEDeI3d6/Vl6aCa4gBIu1xmwQdBn31Gr/+Vferr+317Skx8lLvHk/z7OQyL05iRNGNAFdlN68Ed9IWYLyAZuY90R60dD1dXcbXt/0Xb+JuR78wnE3upksuhynfP+ppZj0PBRBAYCIhNge7i69OhUQUmDoilzYiTBXIZkU1Qyu0WmMNwKj112glIq/beV39n/Jgh8KNgC4N11dvfvvrX/37H19+8XbY7SSG88r5F5gcASBXtwU6eZ2g1c2cHM6OANpAbsEC/eheCAcggEbwDeQVaEMusL17E+CaQEwIhDa3Pzc04Xk+fl6754nFNS1c5fXaoG4b2nQSRzTcKH1j/AcLf+T0a+nehO4ixI6fZVrMcnl1XbR8/NhVzW7GHLq0MXNVB8ic1KkaiRZm5bpWAmqQVY3ThSBC4h4DB2MSnEO+ERmTMikReVNRNKpqtVqtqlWJoVpLWZZ5qnVxqHkFWUppu7vYDlddHFYVEeEWzs9oymd3hThI6NZuHTNA6mZmvDrRBSJpgb5WFdLgeZDlZfKYPU/lMC0fjnO2QsEQSMTdSZX3qnY8VbKllodT3vSHWu14yj9+2L97nA4nO2UvFeYE5uq+FK1m6s7Cse+Gse/HQbqBQrcUnKZqCBz7GJNwMHMt9V+FdhTykxw/0t07oovQxThltZJNghE7a+CF7EToWXqWCCInA6GP/Gobri5TR91Nojgvda/qDBYD91283WLJWsq81CrwbT9chn6Hnl1Ya3WbJuR5su8fVV2RQthuxpsUN0yUj3eHD39Oy7sbzgJpqXWbiHSHYCB+gZ+FdnCgtKG4IemI5YxInSvlM1W9qc7iSd+GQ6Nyi4TGshYOMaTAWbiZShARB4lNabSd30yyfgg4Bo6Bo7C0UO2fOGfr3jlvshZBWlwPAWdtw+fXcYbB17+clyN+Enda0sAcYurGYTf3O9BJjUuuXgqpNnsOIhgxiXiL6xLaG0Rrs9kT9QC+KmOs03zM1KY/
rYm6nS9ifRf+1MjzM+r336APvrY825FCIrGXcLV96U4P9z/++P4vUz4s+Wh9psECYqxDv92M24supqAGFkgEM8F5OnFeWNcqx+DcgjxZw5HhRk1o50xlEmZJXdjuwouX3ZdfDL/+1fDVV/2rl3EzNuG2XzyGzxMNT/SIzx7Sp1H384s+H4VjljSO3cVFvLzI27FEzqVWU5i5qedi06zLYmZx02+GF4E5hkgSwDFAAhjExs1U3tcsE9ZCewNSm82bwddg3zSA3ezz0N5vdrurCxJKKd0HOQhn4YVd4CvIOLlrZTexVT69DTi03oapqZcACJxc0TT+OejQ69UL+/Wv6U9/SL/6hm4utuK2zPlhP02TqwY+NZFdqdWtuhnb+oxIYe4VmIWPInvQ0QnjGG5v+7dvhjevwmbEuePx9ECI4GaarU5TOZx8zmy+Blq4MxGYQeLUDK0EKyVNgkgQCYFYQKRuc9Wl1JlK1ppJi1uFK0GIKloRvS7z9tk+W/sobibcb8aLV6++/sMffvNv/3b96jaltOb1/2T85byuQDCyGbq3cmflAbpQ04UlcgrEG4SXYHIspI/eputpC76l8A3RBVygd26dY+OY2VuVcgLtne6B+tmW9E+fz1D8CkZi7eSspwUITtGwrfRW8SfwvyP+SbovYn8V4xACP7e8EJHLq+tc8jBsljxXnZlCSqlWtLvVgrEaqjFpbbQHa8YFBGZncmEPTiGsRtjiTAQhBPbAGsRjQFSqTrYmPa5quei8ZGauVUuupS5mhciBalZZOKVuGDZ92rite6EdV20j/8ROMKSuH6XJVTWgTlGIKHQcQmBO5GJqTfgKKIGmbay3O9ARP1idS9mXShHbbYqdkBAqIJ7NtZblUR9PUyeHyEEVudj98XR/nPeznhbT6m4IgYipOFVAQV0K/dBfXu0uLy/iMLp0S+Xj4kUZiNLvJEbNZrWaEp5p1nyuId9arhJCz2Ec+3HgGMyt1KIgEoh0TL3wGNCLB4M7qiGTl2Tm1XixstTDksMj05ZlSNx3owTquhD12uqiBLNb8LWBp3yc8H4u3zt8t6W+Ix3yXO8e5pr3Y+SIidTq8liW/UUomwuqPepc6kJajIIjOGhEq3t+smFYKAwUR4TeV1HupyOhPdVG4mCQPGnOE622dyJRzUup7t50h6gRshtS84mA5vJkT8jCBIqRuyRd0hicG//00xHgT5y6FteZIEIpcUok8i9Cop8nNPyMcOFZfdwQAQYiySDdQKlU0kWtVKvNvLUBtQxmrCbJgSSuwe/s7fL0sYpOEGGF4T2IMHNrqEIbznYO6c/Vi/+bc+zTrdB25LRuOBGIRVIfN9vu8nK42Y8f5vnRS0V1ouCVkKgptwcWGka6VHKjGPD4QKcjlky1Jb2t3PBVNqfFdawqqCzCIrLZxt1Fev2m//qr8de/Hn/1TX/7Mozjp1buL4IPrcV2vvErwvpkDnO2h6GmQURthbTMvy0qTl3XbbZhd5nH7RLijLqYuq/EIq3ZaoGrpGFIHYk0iMVJYKsOp5+JDm2UdQUjVmLxOb3y5k9kjCeo4jPIMcbYDSMRREKMsevHx67LIRTiqQl0MNk0o1ZqaD9RZAptpEPI2qvWQBA3NgMH6zt7+dJ++xv+0x/4D7/l2xc8dNvywpf8cDjlpSyKEh681lK1qi1WO1UxZXdSA0zdWxt5Mp2j6NB3b99s/vSH/ssv43ZHIZxbv58/Em9G6VpLrbWyG4FaDf203RpdQBobhEViSF2Xui6m1ORH1LG00F7KVPKpLFPJc82lkakcDGpW6tbG1s+AfGPcKVyZpe9ffPXFr/7Hn7764+9effPlsN2sBgQ/6+zgp9cA8gp98PKD5x+8fGCb1jYKgUjQ+h9u7gWoDAMJKIC24NeQF0By3qAyfAdU5w4cQffAe2hxTFh5ws+2IK1QloPIGNYAn6cPOg8DXBi9cf4t5N8p/U66NyFdxJgkyE8UNJl53Gwvlun65qZofnys
7g5EkcGjmzXcpMHyC1kmVjp7zzApuTEZe5MYW5l8xGt8IkaM1Kn31ap5cVIPDqjpNC/+YLnMIuxOy5JXHlEAP2k9N5MH4fOEnxPOdJkzCfbpQkQkRG4Ek1KqO1iEQ+j6EGNgj+QChVuBI6LuhK83fLELs0q2YuK762HYyOVlYGjNuRaICIHI2ZyORvtZay6l2JL1MOfjXJZiVT0wh0Ct63RO0teJvlqq1iIle3WhtEkdScfSIw7OQmYqHj+32P3c+c08m1GUbtf3l9v+ckf9YCymqm4RMaCLvOvkMmDDLu7VLdc6a1lsLuVY80Opy5Kp9j31W07XQ7pGv+Vxk64kKIUCMa3dQnysD4f68UP59pD/jjjSbtNdjHFEqfvDj/uPc2c5lOpLcTtAKi5o7mJefJlKnlTnwKIMg/dOO9j2p7uFhENPcaDYQ6K3ufP1n1tc5nP5Jc4R3DVfRQK1lN7Ucy4NxV+57I7z0Nh6hDs5MwcR4cDMBKYYqEvcJ45S+ewj4Vixu3N7H4RmU4wglBKlBOafhJKfh/k1wJ9R/jNJyAE0dytx7xE6C3Uxn4qValVhRiIOWlvsLCSROBIHIsAaXWplDLX48UQmckDdyTxGYYZYMDP3uppN+NqzOxOV1r/Sz9/654/HHArUJ7sFtHSFGbGTcZcud/3lcXNfSzE1ItbsFsxcUZVBHIJsNixMXYfU4fHBjkcsC9UKtU+d5nUlELeBhhA4Jum6eHkZXtx2X301/O63m9/+ZvPNV93NlaRETP+ixlrpEcD6hte4/uT1Ji26k6yTkERhxf4acMkUuy6OG9ld+ma3xDhhXlTJjbgZ8xS37FajCPeDM9lKHCD1cxa10tabn5GvyRWtAl6t59v8ec6seCZyfxba1+iWOgkhdn0/bvrNloM8Ms+EmYAm7u2OCdSiu2r7MSwMIq+utVUziEBwoAve9f72Df70B/7T7/mbr6lLDGyvrtx8nuZSas1mEv1wpFwWx8kslRJU2ZRKRa3VtLhl1QL3lPjqov/m6+2//6n/4q30IzHh55H9GfxTXKs39bl1aMQbZf2MmhCcHVE4pTRuxn4cu76LKUmMTlKKLVXnko/zfJhO+9NxPx2nUkjb5AxZk2pclWsBrEax6q6Ax5B22ze/+fXv/p//44vf/eby1a2kgE8Y27kh/092A3k2vdP8nZXvvL5nn3lNSFcHT/jsqLAjfHnK65168DXkDWgEknuFzQBcRnBy+pEcoANwTyhPT/9MpvMVT3ByFxjO7pd8ToPbr78y/srD7xH/KN1XsbuJqZMg8jOIkYj6od/uLl68uJ3zcppOVYs5M3UhiKqoikPMRT2SL6vKgKubEVVCZTaGVXhAu88r5NFa2zFyMu+qZ7PsLXsjUztN85zn/QEhiEgwAzOHKFEtBJFmgtKquIa4rnIYT/SnnyZezCyBsJIQTE1DCrELXRdjDGKBXcjAzuLac95Fvhql34T5RJNlTri92l1exusLXubTw8NS2DU2e/tQTUqhQ12OZZ5PZZrLaam5KBOlEMahG1NkYTOfcs3qBpIQ3EmL1lzIYb5IGoY+dn3shtGlKy4BJiZD+Bx+eP6Xjwf961TvDjKbvLq96LevZHPpqU+u6haE+27s0igSzguu1jwdH+/v7u7e3x8ePpyWH6Y45dFs1/NuK7u57qwSigjHYJFqdDN1PdnpEe8+zH//YfnHY/6B0lXMJOh3fTAOvohF5GiZUCKLhC5JmrLN0zLNxyk/1tKrqMpQET+6VK8ReP3sUlpTPFLoEQYPydZD/ml9M7E03QaHOEeXCElNSlIkxJiI2eHm5nBiYuHWYD7/lJXyKswxxBgjSyCQS6AUw9CXFNtIx9poOweadcM31UphhMBd5C45s/3s6Hpi47aj4WkV+qcyuf27ay3ldMqHox5nO8y6n/S0WK5mRuZgnLPXwByYzsv9iQ58HsRrJfvTcYjmB0ikZiAmEUnBAa96lp5twZOAJx7Dvwrs/uw4bmYt
1nhPQqpeq6JStDTwuEs73RSHK1WfULgUz+rkxlCFO4Uo4xYOxMjDBvNMOaOW5hTuKyjr3KQoQ+CUuO/DuEkvb/u3b8ff/Xb3h9+NX33ZXV+FvqdfnD386bunZxjMuWL+9NHacdJmANbyvc0gNEatSByHdHFBm21O3czHpXokSl2yIG6qy6Knk5ULoQ2In5IeXsvvNqB+Dt9uZ8OCNRX7RJcwJ19FPM98rGcvWhkjISbaXDTHuBiHfeiX0KnEpflCi9Bp9lJbFxzmgSHUjkeuIAMXiMSeX7zkr7+m3/+WfvsrfvmCh46Y3S123XCx2715DbPZOW82eveg86yOrDrXSqVQKVgWzIvnbFVdmFKXvnjT/ebXw59+3719I5vNOkX+i40eAjF5YGUqBMADEBulfW2G2GoW4I3h6NwUSIKkPsW+b6G92VTHWkPfhS6yCACmiXOr3b2B/IrG9z433gF1U+bN1eWrb77+5o+///oPv794ecMpknwK6v/NmrLieqr5Xc3/oPqO/ZFQ2rHhxECAG/QBWEj3ZA2rV7Ol9dEBIh7AAzCAEiG6XCAMqAybQD+SR/hyzrvPi7el1H5exgDWzhxALk4wAZFjNLp1eU3xVtJlSIOEwM1M/dN/eL5Olm4Yrm5ezHmZl2V/eCx5YSaRTgIRC1zcRRHJ45OGkEMJlVDJlc24WfWsVrJSjYMQiNSZyDmwBIiCtDkgUbOPKEWJK3Nt742IQ5AUuxT7GPsQOmne8A3Je8ajaVje8+to2rNwGLkwuTMRuZGZw4xcBc2aZRF7jNhHO5nqcZZjOVapacM3N/3lLmw6Y+MlkBWvatqOhxi6LiqbiSHAIyjx6Hx1sXt5dfn65Yur3YWEqIZpKXPWXNtIjqUkmyHCbZkXd2LxMfpulNB3kH6pNi99Sp/Z034W2u+ORJUeJzYOV9iG/kW/eSH92OY7WWoIIaVIBIe6mdZp2T8+vnv38cf3P354/PDhtP9xDkfdOK6S32xJzUJQCVmCSlBAycgLT0fc3du37+a//Ti/O9ijhLDRobdCzJ6CjtGJnQlMAanfDVdI/bHg45R9pnl5rCV6sgqt6LKnB9Sbz/e7EYlLQOgQe288eTx7rETEQmAYnFpoTyvhzqnhlBIC0WpYQEwiza28xea1bU+AiKQYY4whRHJ3EYoxDL10CcJgXYvatTt7hn9aaA+MKNylFtr1l/L6nxzL50q9oeDtyaip1nmaHh+nu/v8sC+Px/p40tNspcIMT1P4HFiCcGASIgGd9bTOS5yIz+D8+dcSmZmbFq3OgUQCrSZO5uqfROxWATj6vxXXyZuGjFqpqmYiAHFVrbVYcSmh92ErF7Wvyjb7pNCidbG5NzKNXIqrekukNltKiYeFloXygpxRipfaBFCJIK1eT4m7XsZRttv09ov+V9+Mv//d9ve/HV6+SNsNSzh3z//VUexnPuRzUOBTzFnnGVdU089gPZpBABMHiX3fXezCxe7U9Vlkqd4zdSmyiKvqPJfjIc4LDA0SpE8tHBDgDGjjxbeKtDGbW1xfC7EG3a8P6fxf/OQqVjYAUUg9S+TQdWkMIexjOAqr0NwmKRzuM0pBa9qYcyN+sihoASMk3uzi2y/TH/4Qf/9b+eYr2W5bTgxAQujGcffyBQtTCNht6/u7cpqqmaqiFsvZlxmnCYcT5sylpr7vLi/7P/x28+9/HL75JtxcUwjmxs/W/vPlBCYOhCgWJQsZLILQJAw+8ULNV9pts0gwgjNTiCH0ESEQCVzIwWocg4QAkKm2FIpqbSL06qs59Dm745akIoarV7df//533/zh929//c1wsbMnkPe/yxUBuBWrx5rf6fKt6HvyA1NlIm2AIkVygu7Jj9AD2QJzZ3WbXY9kB/jicEDgER7BG5KXCFt4Rb0jjHBZjx7/FNXO7BA+p6crANxyojYLYBaMe6MbyLXES45jiKm1JnFelZ9dCHFM3cXlda51KcUcH5cPps4SmUOQZCam7AjVIzzD
C7w6KnkFKUMJSmbsVs2rkToV52iNJIM2v0tixH423WCiYKZV1asCGkKQEIk4hpjS0Hdjl4aUOpHAbRapOcyeQZQmjfj8Klp339wJa2h3g1WzAnWLXJlcvIieuH5kegTmwrp3OeYjog9Durrudj1HzyrUxbDMapqLO5j6FLuROcaQPCWkjmrxFLqv37799Vdf/+ZXv3p1+yrEQZ2nuU5LneY6L0teZkLtIuX5+Hh/fzoey5w3VG86H7cc+2RgNT+WeZ+npwv5LLSn7ctBaj7UrLFSmqv3xCklhzkKyIizYSqqZaF6mvL93eMP393/7a/7/3y/vM95j2WSw+IPbgfDBKdeh6F0yfvA6B2MnOk02ftH/f5j/dvH5d2RYvfi5ebFJr2gpf/w/SEv5ePH/XzMUkPiru/GtNle3O66eNTZZjsdKw2zbk4eo4dAEFBahd8+PR7oOnQuidJAsTsbop/JUOf2uYPBAbGHJOfQBKNEQkpd3w/d0EsUYjTtW9ZPPe/WeSWmFOMwDMMwxBhJ1Ygohjj20ifI08D6ao11hoFARBCmGKRL0nWckjGZ/Ty2+9On9We0/7nDFM2xdMl5Oh33D3fvv7//8fvTh4/1ca/Hk825ObiQMIXIkljC2gwmfvrR69j9GoGejUOvhenakatVQRRjYGGKzZK2uDlWn3I5NwmaS+o/PdIaKtA0ZNStmpZSgzkRV9OiZuZQDjV2tR9orKnQxhfONutpmboi0WAlS85szS0B4i4cpGOKHQ8KVZg2JxImEDEH4ZSo62kcabeTt2/S11+mL97ElzeyGSHy3Jfd/8mB3MJn08N5coLxFQhfBeHokxbcamplT4/fjeCcwnB1MVxf7S+25f39tMx91b6o5+KB9Xiqj3udZ2r3kM65Az5TLV/f5eoAtg4rrLEM5z8YoMZtaOAZjc4BX+eG1i8Qc+yG8eqaSFMMMaU5dRqiCmciYvJDm0ByUmu+og5SohqDXV7wF2/kD7+l//FH+uINbUaE8JQ8u0NE+mEAnIj77U5f3dYltxyu5lyXuU6zTZOdJi4anIbtbry+7r/8ovvybby6bASU1memX4jtvo6r9B02Y+nSzNSZw13MSe0TNxQrSM9mMCeDOElzmCdSbiU+EUCBJXCKYei6kkuptbXT4er2NLAKUBOooTB0w4vrr//0uz/9f/4fb3791XixlSgtAfhv6/XzVVT32fXe9EPwAyMTzMHO0cNA4RI8AKeGFxAU5E063vXBy3dEGyCgHtwWUEcUgRG4BO6BeA7mz3mvTeKX3dnBK4dgBX3M2Vfm76oDIsQdSy8SRQKt/XWiX0xbiCTErvfLy6vmB1nVj6fjUkqM0nU9OROLeTQ/Vg+wDC7w2uihDuO2amEGN3NTVCBac2d0NS7G6mbkYCGOggTugGreptMNRAALB5ZxHC92u+tx2KXYN6+2M6TldC6ZiFg+owMCbt4m61S1VlUFmJu5sTnEiYrVk+tjsDuXE4K6U7XA0S6uxu3Yb8auD+DiQxp9G1WnpZw0a62oSw08M6EXd3EOCF3ajduvbi++fn351dubV69fp24L6UqlUlEqci7LMsFKFD/t7999/+2P33777vg9LfuwdLGjLpg5VdX5MzG6z0P7cPHq+nJHD4fjVEz4NJ/GOvfeU5vOqUf34mVy5Xqq5cN+/vuP+799//j3H+Yf7vFBZY5iPrmcvObqumA42uV9GTrrIoeRLfBh8oeD/fBQ/vFQvj/oo3bb8eX2+gtOY5n9/d3Hx8eH0zRpsWBx7LYhJen7zYvbELbTx8M03SWlMfs0W9chdoKm4fi5ew9gjQxOEimOFAZQ8HNsx9mdECRGIqGjNCK2yp4JJBwaqXIcNzHGVtAyt7WBc3bQOtfSdWkcx2EYYkzk2QGEEMZB+p6CgBk/wdnXNhYRM8fIfcd9xykasflPIMd/1ns3N/WaLS91Oi3H4/T4+Hh39/H9d4/v3y0f78r+YKcTSiYHSWg9Zg6RJRDJU2rS
IDk/08zQCMWrSO4a4p5kR1WN2EIkEiGwG6Gak3MIJMIs6/G51ke/eIatoe4c113NV39VB4momjYqs4JriKXreVCp3hsERj7bMlmKoKhZlrnp4QeSwBw5uEiIkahjPGM1nAFkT4n6HpsRFxf86ja+fRNvX4aLCw7nLfCpW/NP3ry7mhG4zc2f5XGakmuL+jA/31Ss6h9YgzrcDaYUeby8GF9cy+VlHT+c9ktf6zgvJMHhdjiWx71NM9lKkG4/gZ49/+d9QjivB/WKs/iK5KzysupuBNjn9jBmLbS3hagAibCMm5RC16WQ0iHGQ4O4nRxQMyyZ1Gg1vmNnVmZNwa4u/csv8bvf8B9+y9fXlNIKJ51vIxPFlFgkpK5eXWkpVqtWLaWUZc7zUua5zrMui4CixHF3OV5dx4udbEaSZvixghGf5A6fnoeDiCVFGQfebWvfHZmrKbklczEz8jMBEQQXUHAndzEER3BikOGsBr8SEQ3kzBRDSCFEDrVN5rXOawNlVkc8uNB4vbv91Ze/+Z9/+v3/+3+8/OJ1Gjow4NaM3f/lzNvTS+HF7QQ7kM/cDFdInDuXHcUbUAdz8j2tQ6nr8iN98PIPRwcwtMCOaFA59aANEN0Bt9Wxal1DdNahE0dwZ/dmOdCwNKzOt4B7m2lslipRwmqzdu5N/8J1NQ5SIGy3FyEmB5Wq/u7d3cMdO4ETUyAJXtmU1GcYwwRcxOGEdboYzfOgQVLa3BplVVXwYqG6ORlxEEnOHXkHsHpd9yExEUvoQtjutjeXly/HzUWKfQihZcnuq09AW0zM7PIZ/az1101Vq2qtpgo0P8UKU0R1zFruWR8Ej6BCYFBwiimFq+txOwzbTZ9afZMiwaZF0tFzyaa1LppduyRJ2BkhYNPHm13/cpuuN/Fq211dDP3mMqStUYIHUNBqOc9uRdj2H991ZHn/8OEfi57mehSTar6Yu5ZqMiB+mhT7LLR3Y9q9kLiNF8tSrR6nH/F+Op42gZms1tNBT7MubrPXqda7k37/kH/c2/tjf8CLZdzQcLOxfVcetZQ6qy4Ps32/rzXhyMDBC+sh6362B6NjCn4ZQ99TQJ6mZVqOy/L+8e4wHcwN5lLZYbH0FygeIzry/qLWcVrCSbGo5eqpimsg7+pzo1o4QRsFm4hJtpCtcw+aGqjT8lYFGUUOA/VXsr2lbqR1agrNKXe32V1dXL97/6NwJNI1JWxCxwxikhi6lMZh3G52Yz+GEC3nxlnj7Ya2g/edLxXZnvJgMDV9HCJyYXQdDSN3A4duJVs9j4pPp42vmPmKlGvVZV4OD/Pjw/zwMD8+LofDdDyUae/TyU6TTZMtC5mlFIU5hCghskSg+T3YmuSsvfZnH2cUHk/16MpIaoWgr9RVwIldhAAJkdpQfvsOwz9j0X3qsp9L9jbLrOaA1aqqrXncaiJhjaGmRF3l6uK1N4ctyKeKQSy6+zThdFqqMiAsIYQQg4QoQfgsNNAkesiZlYNZYqG+//+z919Nsi1ZeiC2hIstQqQ46oqqW7K7p9GNHjSHA/Id/wv/DG8wo/GFZqQRZgRgXeiuuuqoFBGxhbuvtfjge0dGnntLdA8xY6CVW548kZkhtnD3pb71fdj3vNtS2wCz0SUd8B8aZlZKxop4R1BCXXXaEdVITUWRKhS5zhEjlOVi21KqRQybvr29aV+9PH7/Yfx4PEpuT0enWkqmxwPeP5Zh1CIWKvHmkiF42kCXi4kLNrP2Iq/9bjUhvyjVq6Iqgslzvi1TUS1VCWYpxAAgGDnX9DsEqKIII3FiLB6FAe5PNkwilswYARCyY4mBrvf+zUu+vaa+B+9sJWiqfOew+IfViQ1MzmKoIIFSpORcvzRnkVLhqz42vmnZh5VMaU1U/TBkh7WS5Jzf9s3LG9tvTp6kFFfB2Vi7MJYuTq6aI0iuoqcVMSukgsBkBgiqZkUhJRknmUaZJyiF1NiAFdlQjRi01CWIZoHDdvPlX/7y13//dz//67+4
+exV7BuoxKwrSc2fNJCQHHFnbgvSmQXDhETIPfA1upeIDvKANXZdYR4IgnqA8i0CAEwADGaAoRJSGSBYBhsAZsQMYGthB8EYwBk4ALaqkF6bYHSBkGJNxwMZeMCA2CCFBSb8h3JbSEhQsygeW6Kbm1sAZOeNIKU5Z2FHjiM7A0JTUkCxWjIHPuee6k5JRrXmBAvtTiULUoQKYHPASF7FqzCSEFE9F0RGdCH2fX/74vb1q5efbTd75yNXhOwKToEL007PKWtMVXKp1h0BHBOBMGYHI8msmsQmlAPh4GMK0VxDLhiyAkBgt+v7m/0Nm814OpRhmod5liJgSmjORDQbBR9doMiKufPOg+bh8Xj/bjhcz9dXsdsSVzpTJnRK6jCoIkAJwTcxRu8YLE3Dw0fN89j0XS2bpu7m95p28hh7bratiL+/fzgc7od3dwDo0GFWfTjZ3aQPSU9FktpQ8JDgocCDueI9xQ0562kwfZR8Px7uTw+HnOVUDh52hOAhs00qM4B0ETZNu3U6ezvJNNw/DOV+HO/G06QzO2ZEFIOMPh1HmQtBiK3f3kjZj2Mz5MOYtJ3BTYYjafBT+iRqL2AImgAMsDXqjVrAE8AMNbOqCIqAHl3H7ZXrb5E7XHHICOjYb/vd9f62a3pmj5CW+QtrvM7knIsxtF237Xdt07PzBmiixoRdi30HXWPDbClpzcaiAVaGKzBEY4amwa6npqW6lz3PqNjFv/ozGpiKTlM6Ph7fvT28e3v88H66fyzTpJINleas4yTTpPMMKt475zz7QORqrtjUFA1QgIxqPFQhUmtr1RNcyc7VwroSVlRSrQYAAjERovPIBEtR0wAQdI2A4dPTsSe7btWO1y9QzSKmupT5agODMmfvNMQg4GF2c44llYzJkAkJTIpNo46j5QyIRMzeuxBc8Ow9Bo/OgV9I9zxiDBGJfDXtfY8hLk0BF0f7B/ZjU5WckMmAsRKCEyCv7ROCqxEHtEWL6Sy6WOlMa43Pd21zfd29fjV+83b45h0dpmY4NiVDSfSwgbuHchwkF1StztNSIX2aBecJsSr0mJ0vpZ7JxRbsm+IqM38eKsWkLHzDREszFwIi+hCdu2GqJXVUtsQoiBOxGpRUZjVmJuegbWm38y9fhM9eu6s9xlCZiYEIAdfs+Xn6Ym0jAHS1yu/NREwXgV1Vs0pNQos2+rL14uXs/wEcsBosZArbfvPZK//yOm8aK+LK0igmaIUMwMgsItbDcEQMgKKWsjKZKjhCAlSFVGSa02lIw1jGWVNCEdJKao5kiEi1pUQJ/Kbffv7653/zV3/zb//+s1/9vL/eE1WsK9AfB2Q+DURGathfg38F5UrtHYIAEWKHfIXuBgFAfD3fqhRQQ1uCweQtQgI4AnZAHeI1UmVqVYNkMBDMhrKWyCr0ggwYgGvr8LKyV9w/LtRZAMAG0bBFbIgCLaD4319kWBtBVZGZmXm/4xgbJBSTu7uPx9PJDAEdMyCiIIquaBA1gGriYT0/IDOGipNQhmVCGSIwEJojJnEFzFQRF8XZ5c/oYrPZ729vb1+/fPF6s9k556m+lQEs7A8rI2kl6LxY9zUVX9Nd7JgJSRPZ4ODAegKdVWe20YUSgsUGXUT2yAgEyOY2TXe1uwZRyPBwmIcpTXMuRVRhjawM0bkQGdEcBUSCPA8Ph3t8fLjZH2+77U00wWW7FwJwjIqkCswUgg/eO+ZTSnfzOExDM3XMVb6mhQt6l2emfZymh0NqYvCO2tjILN/f3X1493g8zPmY2sHak7ZHDQmAGL3j0JdNGYZRUsYsgX0XY+djF1oPLArvy8M35RjnspmkY981IXZds4nxaoubdpvocJC73w7340OWUqSgUwZX+8NRtZgehtNxGIpBs93vdt0Jp4/vf/c4nnoRTiITTPflwMM7ShcNPgaaTUTyqeSpqCkExRawQVUDMQNUJTVCx6GzuIG4AfNQqs0yAGCibb+52V91TeeqXVz31DqHHXPwrvGh
b7rdZrfpNp6dAJiJElr01rW46eU45jGJGSAyYQ31DFGZMUTsN7zbU9uh94iEq299cSZr8rVualJ0nse7j8d333/8+neP3387PTzk0wCqzOSbwKnYNJdxkpSwiCIoLcwptEb9YKqGlfZnMe2iJqJ1l10c27qTPFUIaoe7iCLKArDmT2U4YHEM/pBd1/OD6lMsORYDEVCDBVW+suECk3gahSdzCAqoKjMXCh4aCtuOzQzBBrU5lTSXGTIzOYfOITOGAE3j2jY0LVsDRNQ27mrndlvqWvLu7MFclrB/395lpiYZloSuCiiCEhqqgHpQMXUk3liIHUHNMlYmNEJjIDNgQCTHcbvZff75+N37w2+/Hab0OI85J5eTv+/g44McTjon6rTy58Bq161CmG213Lb2dqmaqKrUL6u3UVRVQRRM9YdReyl4riqtQSbW5j8zDrHd3ZgqIE3czL7Vpk0x6mkuSdg5bmK4uXGfvXG/+oX/6Ze039nSe10nVeU8WXJMz6/hGoHXxnxCqgIKVhV+VmhBXWZg6z5vFyQxT2MxNITd9f71r766+/77j2/fHn/z9cPbj0fJZAYODcEBRMRgC/8EAopImpMgoBTzbIyAJlJyKuM0j6dxmueUSg3eqkUARCMsILOJBPa79vVf/uKrv/tXv/43f/vmF1/1++25EIH2wyP9Q8PQA22IvzR/B+Wd2Z1aIlTAABARPKKYWS341x1oWWUmBCOoQUlGW7Nr1AzAAGgghhkgAwqsEQssa5OqXQerPNmK1Q8UBTRCZGYmEIsAPVBv1BmGFRL/e0eF0y7OogEi+OCR6fbFCyBsm+7jxw/jOJacDAgwMAERATCgMxAxWWsNgICgC1+WgRGCki31cERbML/V/0gK2VCQjIzM2MwBhhi3u/3t1dXtbnfVhJbg3AoEAJWL+Wlfe/p+Ma9CiM457xxhkelB02PAB7aTlYQqDsUReMfkCR0hg0MIxJFiFxtPvkgparmUOadccoW6QkU4eBOnwhlQgAXAgAW9A5qn6XA6PuzTZCZViUZVEdAxoGO1UNqm6/t+s+k2m4fHx8PjOGRtFZ1nRvKbcimM+My0D9N8/3jadm3fNYFDF7RMH999d/jtP70f3o+3Gl+WcDsSmIPW6Dro9WZq9SHbCCc5jJ3mFwrXGHehNaCTyfdpeiuPyLDx/KJzfN1tbndXt7v+dhe2/Zjk/m4s43R4mGHMmI2NHbraMI5qVmCaptMwzDmh85vdVXn4MKetPcZ+nvyMOMGj2YeSD1cCLy9Wi85aQNJQ8iRqYk6xNWgNMigYVEyQOWR00VwLrjHhGlTV205EXdvuNtuuaT07AlzU3pdCOXl2bWy2m83Vbnezv9r0vWMGMDVVBPMMbcRNp22cD4OiIaInVCIlFCb2AZqGtzu332PbgvPVm/lk9dj6DQFAVdOcj4+n77+/+90/3f32nx6//y6fTpazI/ZN9NZBmstpLMNUcsJSCA2QrFrhipNdasQAhrVQWz32xS7URb4Eg89mfIWfSBFEZGYkJOAL678CU85H/MMt7NLA21oVhopIMhNBs4VRx7kFpgCM6qkoizGZQ0sqGQUdQSTaNIxMldZK1bSYSJGMhYEIkCAGAjPHrmmBmYJ3m97fXLn9hmIE5iUPUru//2hO3sTKDERqhCC1IiigUNtx2Bl7YyFx4ByaNzMFZ0rATFyTjlClMHzfbd+8OXz2Dm9+M98/HE5DKdrkWR86vL8vh6NOE5YCzsFqLxf9F1UTXYz3Yt3FpKgWlSILN6WYyhKaq4CaPK+1mxSVXHdiWzga6map1bQDke+2LSKwY99QbFPbl66XwwhjMe+x7/HNZ/6nX/pf/cJ99ho3/WLaoYZDiguU41kSfQniV0fQlhQwwFoBOrdr1OQNwDlQtwXP8NxRON+vdrd58dUXn3381cP9/bcAH3OeT4OWbGhAENXAaAvOka+gkFwE5plUICXwldFBi5SU8jincUpzzkUkV2p8QwUTtIKQ
ATIj77vN56+++tv/6V/9X/6Xn/zlr67fvMKlM/aPt7r92PCAPbrPyA+Wvzb53uTB4AQGqAqaADPYfKaCP7PbLqg6FbMZtCA70xE0m2WAvCgqrDdgiUpstevgwBwYghUQBAETA65NsrXhpQXbAe2AesS46t49m0efnAbi2v+Dyz4ZndvTVYxNDNE7//Hjx4eHe1E0qHUgtuVgMliuwsArFxTV7KLBwhNVG+mxBrOIBmJQFLKCGAoQEDgAp+IAm9hsNtvr7faq77aOwzMG0KezOF+ZZydSJxk71zRN9ERgqcxaDpEObGOhYmaeyTM7x+wYyIjQMTbO9b7pQmTgrFlUspYsRRY1aTMAYkMPylIoAxbEXMl6zKFCmabheHwcx1ObRiJEKKAVprzAn0MIbdd1m2233blwN5eDSSk4e+8cc5/l95r2VPQ0FYDZzPrgEcE5h+SPh/LwYdq6gOScBjQckxZRaqhsW7mJ07uHu9++bR+zlcGR71zjA2xduNbmJXbhlq7edC9vdy9urq+urrf7q7jpuPFtyb6JaZiLagmn9G5Ok0oRQseIRGSMIibz9PD+/Ydtp5IeH0+HxxIPcEy4IYeN2wTHgVNoPlzOuTIZmKbBSgJVAxZo1VqQZAoKmQCJQAVAKCipLpzJAAaogIIIjiEGF4OP3hE+qWxUbHsM4Xq3/8mXX/7sJz998+rVbrPhajVrDhRJg4e+1b7N/aBZGKAYFMRMyMzUtry/8jc34eaau9Yc1a6ay1n25FnaUmKfDo/Ht99//O0/ffjNPzx+++308YOlhGbgHaRIJac5TY+H+TRoylgymCiAq2EZESADVn0/1QWCbQv7iamKqOja1L6Cri8X7oIxUWImAFg6iXQp29vTgE9itedf5zz8slMTmpioVAsD3nHToA/AleueUV1tBFUDp1Aki8HESoFAXYBN8A5jA/OspYCZEVV9aQqBus51few3brul6z3dXPHNFW16dLzA1gHPWd8/bN1NRdNY0xVqDkEMi0FR9agF1IMUYDZyoB61oBZUr8yoNUIyZDNgRaIQ2uub/vWr+OZl/vBhuL/XeUYVGo7u/kEeDnIacZvQBYUqeL0kV1QUztZdpFbNF8CPZKk2XuvvpZK5mmpK6dnmJUVyqsVRZAar1YV6k8/4N3KxbffkXAjtJu1v0qsHmBJl9bGJ/bZ98ap59crfXmPXoXdr6ajyAD5dR1spGS5SOyvyZPlXV95qyJ9KDzW3DSuC42m+XMzIaknMedfvtl/86ufMfH374rvPP3949/5wf3c8PE6nI41ZkgE6Ig9EgjZryVlBUk3/G6KaZtFcZC5lziWJFYViUAAy2ow2mYxQUsO061//xc9/+rd//eu//7uf/OWv9rc3Zyp1XINXW3TW/7SBBBSQr8i9MfdTKN+hfjATKIPhd4gFSEzeg41nbku0BctgaoAKIEgJbDD5YPI70BaFwBCsAQiAK3PWwjTn0byZB2NUw8Wug1oFXzAiGzrFrdlL4JfE18QdItPa0Pf7RoWdVCzOgrVDjCEyu/qgiY1z7nB4HMeTmRoyUkPkEAtgASxLRmrJvNdcfRUvqBd0yeQoiNpcNC9ZRiRyRBgIYiFH2IawadtdCK13npDMcO2wqCHIuch4rvY8TaxcyjhPNWlEqhGHYAfisfHKSIlZxTyhj+Q8MVWKHwrOdU3cdl3TeABVKaBKoExKVOmV1dSYuPHsCcyKajYtiEBAnMWnHMaxOR4+3r0T8oSN4xhc8N7PmZ1jIpZSnPdNt2k3V932vj+N0zyVImAGbJ8k5z4hmoWimIvmJMIEZoBkwNOsw1G0QQ7eWQCzKY1DSmbZ+hb6zdTIw3A35hQe0rbMN5owUBfo2vkU2u5NuP1ic/vi5ubmpt/eNt2evAOCoIk9lHRlCAV9hsP8bsinwkgO2HsHjEUAJJ/u7x8/9M7heBpLkiAIyozkY+y7sPX+IVwov5lZngHMymySAQzAKbTZOtVZRdSUgJ0ymGPz
WhE2S4tnNT2lkh0xQfQuhuDozEBRQy9qYnO1v/ry8y++/PyL2+ubxjVpnJbNp1aLY6BNi7sepglSMTUtIqoFsfgQNxt/dRWur/1uj00DjqHoDxeOLahOUxWd5+Hh/v67bz/89p8+/tM/ze/fl8OBTAnRvNNp1mme0zw/HPIwWs4opbbyAhITKyshgqGKiEilhKCl0Gwrs8cibgefOrO1lQnUFEWM2ZjrlRIRACA+x1Y/2IEvfJRzvXSB7a0ZiRXfApWp17Udx1hjdxRG59CMBLhiFpQLlsw6eQAgpIaddyFwylAK6GLaAZF8oKbhtuO24+2WrvZ8vXdXO2oboFXdawmD/ohdBwAwhTIDMwAjKVL1Am1RxFQ1FlUicmqC65epr2nlRfMATNEhu7jZdC9u+8/fzG/fTu/epnlKJfM48MPDfP8wPxxwv+fYKJrKYqF16Wg3EDUpKsUkmxSRLFJKySJZLtLyNXFvqvkT015ESlkRoUYL2RqsqU5aMvUuBPbOx9Bt09V1GkcrhQxC08V+G3dXcbOH4ICXuP9y/T27bBe2epkktla2Vo/gDA582nKXGgScOWcuYJ711Uv2Gw0cM7fNiy8+73a77fXN/rPX77/7/v333333u9+Wr7+1949FJ4UKnoMMVrSALncfERVqAGhZNKlWjppiWAAKWAKbSSeE2bG73e2+fPWTv/vrv/xf//4nf/Grl198xs7bYi4qHh7hj1jAH0wrREMH3ANfk3sD/JmV35o8mE1obw0OQGZ6qKa9ekDn+qCedSlQgEdz7yz8NyotAqPMZg6NF1GkJRvPAA7BA9SQXaqjbzXlD0ToEIJBo3Cl9NroJfOeqCNyy5vY701MLBNnmUXrA0DnvHe+a9pq44lQpKSMqmVRsiFFVCAxE0WlJe6gpSm34m/WFLxCUUtieYHsICIysWNuGVsqnqhtm33X7mPoiB0aGRjB2rG6TLoq3IRLu+9lrd20iKScHGm0hHwMOASfg7dahBQzQlvkHRAq3T0hO8chOCYoOZWcVAuCOgbHRqi1MsJInpDRFrpmsUxAZiTCWJq5jNP0cLgXdExN8E1FZwMCEzsfHDpCatpuu7ve7h8Px5OazdNopmZi8mzdPUfIu7Bpu76JXfSMOk/jMEzH41CSEFDNcppBLjrk+eFhGr8383Og3QhFoocYjUERCxXyQG3adgibprttrq/67aZrm8Z7j0ygxTSj5cjl5Yu+cSG4nv39mL+b5gOoMkPTeiavik3nUQvkHMxcQNjgNvLn2/Dyhd+/aih4RRfUPaORlwIIqEaGhA4oKvUFhySTZFE1D5FjT7z1ced8zxTF3Fr0y6Kk2VIaiyTnuG0b7z0xQU1lIzC7pml2u/3t7Yur6+vYNFXJoJI0MhM70rYpu00zp4ImU4ZUeJo1ZUXUNvJ+F66vebOh2JDzRIxrFuty1S/xjWmZ5/nweP/2+/df/+7+u2+O797qwwOMowEo4RLgkssppcOxTDOIkBkqqNZwXFSlXOR1cUGUAsIqJ7vsqHbx+NkGhAZVrVxU1kqqqkp91ZL0OAful69dv9vTb1bVrBWSC7DEe+Rj2Oz86cjHA5WMUipLARUjAiZQJEYSlIQioMUsGTTmApPTyOvOgkjkHIRgTZQu6razqx1uNxwbqnT9S7f0s9LbHxgEtXRbK+fGoA6UQcmUVAgB1cC4An8UlKhKfBMIGZAZqGrtiauMdc1ud/vVT+zh/uH+o6rk+yPMk9zfh/fv49u3uN+1/QZo5WxXUVETwQqWEDEpWpKWLCWXkotklbLY/+Wur6a9lMsTURErorW2DqCVyXNhAFhKE0+OHTtuOIboN3sEQyQXoguNiw16/2TX/2Ai+nJWXBp4W128xdurnpatnKyrPV9t+6cgusonQAZoCEi+ib1jcn7z4vbz4/Hx48f/+p/+0z/8v/7f9/qbfHqbrEJhjeHsQCwVAwXICtksiWXVDFAAK7o+oU0gEwj0TbzZfvZXv/rq7/76
q7/5yy//4pfbm2tyXE1Q7RB56jj4YwHu5ajr3ojQe8id8U5ha9ogzKgjwACshAkgAQieifAWLV+QQloI0dDNwN9DKIBAMqEMJCewsoSpAGYEyAi0yh5LPWnVrCpogEZoDqxTvVJ6rfw58kt0W+SwntQfOI1FKtkWLdwVwAEIgOiImW5ubn3wIQbn3f39/fF0KFqKGRuxY0bPDHT2uIAqERatJt6W+pOJVn5fMyIAR8zORUcdYYe+iX67273cbV/E2CMwrNBWWHAcSwsn0ZJSInrGRheC7zplMLLEMHmaO9a4cpc5Vzs3hNhMixQgNAErXOaUT+OYCK1MaZ5TmhEtRu/nTJyZwBjQ1KRYIUSkmiUxFMOxGIMaeXYxZ5nGsQkkQLPBaDqnGQyCj12z6dtN2zSvXr06HY/3dx/SPIEJiP5wD3ve1x7ivnVNDIFJpmE+Tukw2jC1Bux853xkHzgoFk6k4zR+fJQgyZkqOMOGY+vIkzcyI0UvTUd81fT7brfpmsY7JkBVnU0zyAyWEazhgJs47PVqP3WN87VfiNC3oW1aIudCo5LyeNLxFMt80+hNz5/dtPtXob1tZ4Np0lTsybSbaSqioFlNwIwNo1JXoJ91KiImhL4x3CPv2e2IWgBniiYGKIa5FBOyVCaR7IPr2jY20XmfTUEEENhx07b9ZrPd7dq+d95hXnNeiIxMjqxpZLcRVfAs46xjwmHCaSYi7hp/vQ9Xe9f15MPabv6pP7yEt6omkobhdPfx/vvvPn7z9ePbt+PdRxpGTql+qiAqkBnmnPNp1JRRl/h7gUuLSClae5SrXV+n81IhW9PvZzO87n9191mRXFXBSZac21KcN1uY9hE/ScU/ncuPbWiAZKjrPoBWuVG9D5ttPO3C6VC0ZK1M/Fa77g2BkQDJMFd2ymJVJYmj42DoDQirXhmZcxA8NJ66qNsOdxvqOw4BieGCvk0rPfmzfuwf27sAuEJhoJp2IzBSJRQyXOJeVKgk1WgmaIRADEampLVSUO0TAwH5vr/64nM9PtrD3QQg+DbNpeTEdx/5m2/c7U1zfW0xVNowXUxcxRepVqa09daKLLX2c1PhgrUT0R9ByKsUAURQNDMkQsJKyFtPlBYnFwAqzq52NS/cDug8sTszeuLqFlzM2iVkgwsTfmmaP/nd+rXqEqwOwPKlF5b9+TRaAvflw4kdkfeuaTcvblVlOg0QnBTBw3x3P6ShnIoWA49PB1GxnKKQrVo5y1bJzUEQEmhGy4FLE7dvXrz42Ze//D/967/8X/7nlz/9cv/ytmp+LCdqAEgLwvbHnPQ/OBbQA7AD35vfG+8NetQHhKmiOpAUUFcxw7VyXGWFMueRFRBZCB+cH0EZOZEZyhF0XMRsa5Jm0YCppcOFEFKkyjnWlgk2aEWvC71Sek18i64j9n/0fLAWeFbTDtXWL0WX2iLKu51r22bVXCA1G+bKlQ6gQMbVyYGF7ofqxFqJNkytBilQ1MxAEY2YMLILzreOe0d9cLtNd3O1f73b3sTQIa7CHOdbZbASzlLdcun5qUXvuxagTGyZYfQ4Bi6BLRcVVXaETLAoYqqJGZoSliLjOEs5oMwloYqozkUKMzpP3pF4IwSm2soHYMuupkCmILkkAnLRx9YMpAg4Q7WSci5pGI+mVnxDio2Lwfub65uHm/vvt5ucZ+dISjYR5mfW/NkPXYhXLZqYDmn68Di9/RDuDy9TdsHLNrzy7U3sb5srKuacC/LI83z8OIvesVF7wtsS3jTdFTnHKJY1I2GIjQtd4xqPzsRGzQLCINlKhlK0aJrwcJB33433359szB68GgI4CjHutpu+F7XD4XT3Nodyui4PuzLtere/6durRnv/9v7wTx8fvqdbWFPyZjAOWR3mrGVGKZzNCzaFuoJTQQQqRh3wNfI1YC/i5lFSkjQrewMErGUQTYYaYug3fb/ZtMOQtYgqIrFzsWlCE9E5IHrilTdkI6hacE1L20JILoRq2m2eMSXv
XNM0zYvbsN/5pmHniRwjCxpW2cmnYWYGaprz+Ph4//bt/fffP7x9O9zfp2HwKUGRKoyYVUXBFKRoSbNKwZVlzmo2X0RKAQCtdn31Y7WmNGkFz2OtFEKtmS4VPbNzHFdx8qoCYEi0sJmrgeliIdbc3w/H4jogGGC16SvUlRANyUxETZFd6Pr26rqUpGALHZSpgiEQAde+aiOoYpAKMpMU1lHUizkFZ+pqeOKMPLrG+b5pNi1vem4aYlcXvC3sck9N2H9kmFGlo6n94qKIiiiLgQEFYEAyIEBTBakNluuxID45U2qoCOS5ubna//xnBDDsb8br343vPo7H4XE86e/+W3x5tX114+iaYgRCMBMyqvaNEA3BEYED1Jr5ByMERkMjwcpTsqSsjflZa2g1+YYIhLbEvkRqi0Yx12iGLjBRi4j3UhlXg4qsIFit2Jp3gfOPy0pcJvFC5vO0Qs3OIj528X011yvz/RmTca7gPBtY2SMXsHWNq+oMNQTXhDdffgFz4WPioUz/+M3793c9clO5VwwX46Za1IpBMStghaBghXVpAssMzc3Viy8//+pv/vKX//PffPGrn7/68ot207kq+nye7PVaLRDxf6Zht7ryyDCCv8LwyvxLzVdQ3qMaWsEqo0yXmeP6MYyIIjRNlGYoAjFpC3Nsv6cwGSJCArgDyE+NNmtWfanBmapCETI1ZkAGI1SK2a6EbpVvibfIEckt1YY/mNo6R/bP0vK41CkQkIgp8vX1tXPsQ4hNc/dw93h6TDmpllKkUlwslSLAhefbFGp37uLlZbOCCIjMFAGic9G5zrtNcNur7asXN5/d3r7p+733sXqr+FRqX+CEF5ASoudsdN5jFwnIsGQHI+hRdRRMZlIV5WmhsgAyJQAmJCApcMrpMd/Po01DIYQQgdDUlAiaxhNxyUakzFodFhXJWcEIAZ0ScojNpm23SN5h8BQ9BVBD8hhbRAw+Nt6BFELXNWG/396+eEFM4zBISSqF++7ydjyH0Y35eCfzOKfjKd29z3f3/TB/wXjTBzHyhpHMO3LE29CVuZSs4ZDnNDNwW/wt8AsKDTkEnUuapmLZOWMzVJOSJ5EJMwMSlGK5yKxpKMNjebyfDx+m/DGFDB2aqFmRXEpRAYdQJKcTjI/HfNzwGHDuguu7xrXN4Oix2NfH+RjKhWm3h/vBOypllpwta0o0q5+hydgLeUQV6gvsZ+lodulQMJ9SspK03TTEAOKUVE2QoIlhs9nsd9vTMKaczcw717ZN3/dt27LjSk9Zs4qmKkmKKJCqIjrPsfGGzN5c1iaCimPHTbS+L87NYjYlQSYnKZd5zj9MZZtKSfn08PDw/dvHd+9OHz5Mj486jiiL1Leo5KKlwku1lmYrG6gZrC0UUmRla6wO9rIY6/xe8vJwTlo91aGXKN4WPwEAQEXBapF+2Xdrc4cS8Jq2+8FZPD1GAKv5tiV7SYaLF6+mhkzkmrjdV2roXLIAiCnlFeULaEuWgwBJoCiRsOQiqSiLeQNnyAQcyDWu6aLfdrTpXde5GBculHMscK7q/lHgkxmImCkYnzMcBpW2S8EYTY0qf6cCGJSammQERquRsdXkqKoJGRC6rutev2Yfmm43bK/d19/a998n1MePH3bv3o53H5vQuBDRnqzIEt/iElRXjVFZdeAWD/N8wX/MyKhozcfXsA9rXx5SpV0AALIawdESfa5Z8trKV6GiNcvyZDHWT1tM+9NatBUuAmZPYLnFumvtO1ntOizPWbIp516/My7zR00L1e785TWLPj0CMu1vbrxRfhxkyl+XcjeOmsqkyroA9CrUsCoQFoOCWuP1wihE2Pi46W5/9tOf/PVf/Prf/Otf/5u/uXpx23Yd1nVUESgXaQp4btT/dANf7xqgB95DeI3xJyjvAe6tnMzUdCXVoPPzK/sVVVndknE44jhaWxRRcXugbkAHSIY8AgmsbrmZgAooaTEtItlKAa2SaA7QOaWu0FWB1+LeoLsltyOq
2fga9f7Bs8Cn038y64u6fM3sIBJv+k0MwTkXQohN9Pf+dDqO01hKUikGYqaG1TOt9K6iJufqJFSyKEIiTxSIOIYuhk0Mu665enH7+etXX15dvWybDVEV3zOz6vjVjAJczCS8POY6PGHrTEoGG1CPagehSTDX/BYRMGENYlCBEBwTk0PklHU45YeH+fH+5B1utyEG5xwRYQiO0AqpWYHaP6JmRSUrIjl2Ica233f9ru02oOQoRtdE1yAAoBhERGBiUMzzhMCI1Dbx1auXbd/O05RzkpxG3w8XJ/LMtP/Db7797X/6bj5NNs83LC+c3Drs981s6WTl7u54PEz6OLEGzUJoVxi3EmgyVCTBHqGhwqzCOsF4Pw/wgLGf0XsEzxGIYUnlCWjS8VDGhzTez/Mhhblca0E0RM2lHGe612EYHw+HXXCOS27A/JRCSG0jcTnLZIBQDDngRS5CRL/5/o6pqu0Km5lkSVbEF9qYEzScoLHUjgegMsBBlU+aDQ1vX175yBwAEKiqDTfNfrO9vb5OKZeSA3MTm9v99c3+atNvHHtcYR5gkOdyOAzzPNDAZiIlSy5S1BSASF0A0MJc2EmR8ThwAT6O6L0RiVgu8rwsimggRfI0DQ8PD+/eDR/u0uMxD6POCVV0Me1aiqoaGtoKhTtHErSKE4ogrlG7nfVIARHBdG1pXkqeoOcfzusZl2UCADVOp9UPtoXhjgAWlu0fzcpf2JvFshMRmaoQkiED6kofgeibFvF26chDFNMVpUQ19Vj3idoCLyhGUoF2plbMDFCJITau73i7Cdtd2Gx827L3SGRrz8E68E+J2lU1p4JE6NiU1xI4MxNTZb5jYwJiFTEWkUKizpi8Q0KomhZKqEqkwAZESkghNNc3jYu73fXu1evtd99//PD28fQ4nE53b9/t203XbQ2hZk1F5AkDL0VLkpzynEpJWsRMsIIpVJbnqppZKeXSKhbRUqQm4Y1qEzoBGoEBaE01IlSVcKxGHxArLGFpnRRZipRLQL0GrbXgupigKlqFS1bWnlyB6oZcmnNbzfk69z4N2eGMUPv94wzcIAAxsJWx/4u//KV3HtEK2sPX3959fKgtMdWjVTQhULUCkEGTaTIjH5rd5vaLzz7/5c9+8pe//ulf/frlT77Y3t64JioCqa1H85R7R/wRO/GnjMXNBAUgwC24L7hNhKhYbDIpX5vdkxUyXYvmS7Ji1ZhigMW6iyqaQFYqGRqDqAS6OO5gYMUUrKAmKQlltpxEVZGMHXIkiH3hN4V/pvxr9D9z8aULG8eBkP4ElOlT5uL51cAns06ABISOCK/2VyGGrm+3283Hu7uHh4dhOM3zkAuKzrK4+aCmUvmScRW7WDYiR+A9ueBj22z7/mrb3+62L17efvbi9vWm2zJ7QgIAtcqdAWZVrPiZwMUPb5lD8ZC1HPP4EeCO+OhJHUFVdXREXHMQQIjAzCF4xx7RqRq5LDofh9kxkDNDa9AzsWOEmrqvVSdVU9RsoBBi7Pvt1e7q9YtXV1c3XbsBQYbQ+LZtuiYGdghYiuQ85+EwPD4cSpaqDPL61es3/BkgpDSN0/jtYfrHh/npRC7P6rdff3j47W/SMAWTv3qzeflqs+/8LUECeMgyPk6nefw4T1ACo2fvCLgVbmyRQPJk5GdDEBKhUXAkBUhJTlyYxVcBawfIWihNdrxL032GobiiWwfc0zXqzlkgezvkh1zmx0nm46ZpX3TNzvMOZEO58xIISWbM4gw7wpu+U4oPF/crm1fFZIjo0RH6lkPHUoIWqLs+OaZgjoojNVZBLcoKtbkITQnIIQbmLsZd3+f9tWWFLGPbtU338sWr2/3NrttE9gxrPscWtS0BBgqARhzQK4mhrvsVLuKqwK6wU3RFQLRU7YGin3bugoHkMo/j6eHx8OHjeP+Qj0OZZk0ZVMgUlpl/UZtcGvDrBl0jHlVFFYFlVzUwICI202rdAVBXqPxFlrQmOM+rdtmHzh7AuQu/ZvNt
VUz8MWnqpxh3fcmauayq99WhJzI2M0Ak7z1Tk/e5lCKl5AJLUiEDIqCAIGJ1SYhQlMRYQBZiPSMCYmpat9n4zTZut6HrfYzE/pw6PpsTWFCAa/nh4pifnYJqTgmZSFkdU4W2KdcsGzETEbAzEkUyImRWBUDP4Cp6CZAWn4qV1CrNLSFy0zgXXLeN2228vrLfdvLt14ZuOI1hnHxKRrgU1LU8WXcpkueS5pxSLbYtpr3a9Xp0pmaWc748kVqfQXti4UJERQUAAjBEWAVRzzQPZqaLYAiCqSpQzeYiAla2UsJPzdtSgn4K985wuaXLF+084dbpsUTqn2TpL6fRHxurga+YBnZdc/356xDjlMeMZm2Y/+l36TjIPMNiQEwrhBpQiNUHbnx/e3X7xWc//fUvf/nXf/XlL37++qsvm+2Go19s5LnuQGdQ/EUK+p89zjJMCNiio6V5Qee6jkxA8QgwI0j1fGEtPpuaCmghKaQZdUaZIJ9gRpNkLmJoyUWHLIuqgIqVItnyjGU2FUNUDsrRYQzmroR/Ivwr87/i8KULN851jI7O6YI/7bo/t+v1Ea35HANih8TsYhNjbLq275pN32wfHx+Pp8dpOszplMtcNAuoqCKALjVxQ0QmYkLH7J1rYtM2/XZzvdvd7Hcvr/Yvr/a3u+1V7Xmrt2IVaK9lvE+s+1IpeHYemqkMlg4yPRKd1E+qJMpARgvxzColWBeLOQBnxos7ChV4Y2JYxFIRT8BIFX5rCLg0NQAaMlAT2912f3v74sWLl7vtvm96zUbmvQvBx7btnCfRWUdL83A6DcfDIc0ZETe73dVu1203sW1yTuM0jN+8+8eH788n8lyv/X783e8eTfIukn4Z+5urfksdFWdSUt6c/GGcPozHYQSCxnF0JTTMyTAgeWQICi4jFWmSb9J1K24PYWMMM50MiAw5IWelYYLTUU8fk5z0JsabTXxzHXYd5SKPo3zzKL97KL/9OL4/5iHNEdJVR698fBHwJloXsoeMpbgSWh9ed47DllPzsDb4EPMXX/3KOZdzAiIfHCGilaVJCQBtUUpXAq1cYQAyzDanvmfPzEBmGJAb8uKiRcGNecWOXEq5adur65tX+5t9s2mdD0BchY4NfWw2V7cdYbPfklujHluy9QuSFxARqEatSAZQlj5zKyL8u+8A1o3YTE1LztNpGB8fx/uH+Xgq4yRz1lTUhExR9Yfgt1rSrnWClXCuSptALXueM6CMAET1mWuydAmTLpKka5UN664PNXJD1TPXCeKKeILKg/Mct3WZ4T8bT4SFm2iFqZoBM6yoKTBDCk3c7NuUSimGqACYGClLKUhFpBKfESEpESgbK5ihVUp5H7q+6Tex3/i2c7Ghyol7IcO6WpAzVb7ZWqRY/3pxFqolZxRCJ6RMLKqsykrMRFw7Y5wYsQIaMTGTolJCcMYAJLag0hyxQzZkwYXAlADRHNNu18XmxnnutwURXFOKzdOkCFLBb+fOdROVUtJc0iw5q+RFKktVFw9gJZs1SyldGsWabUCsSRasSRcisip3U1XL1jnwNB1AUQEJsDKCVs6c9VoR1Hr3cgXxotqKFVeFBLCY0qqcgxe34GnqrQx7F8H66mD+ILXyqQeJT0blqSRAyG3sX1x99bd/3V/t95+9/N1/+Ye3v/3649t3w+E4jVMpxRQQkLz3bdxe76/evPz851999etffvHzr15/+cXmeh/7jhwv/mitVCzd+LiytP1LTPr5sFdm3vo+DbpbQERTRa8AltDsawMhVAAlq7JoRYpJ1jxjGh0IRG9dq5tOyWh+dCMYsHZ76XbCUdAJIKpCESsqRc3A2Jvz5hpF79RtCr8p9Et1v0b/Mwqvne8dL2T+P8Q5/PA0VknJFRi7Yg8uHpzvEiAaEW0634S+b69v9sPDw/3D493j48fT8DDnIZW5aClSihbVypVrRMhEzBx8aGLcdNvtdr/f3VztX2y3N9vNVQitd2Gx6gjnjXDxMgFhQc0u
x0IL0ezFpMqTTo82HSCfyBcmBCQ1lCSCykRkvK4sLIxqmcjMKCdLSZl4t90SQWw8oMxzKlBcxcNrzWnVch46VOd40/W73W6/3e02203bdSEKAigxMVS64CzjND7c3b9/9+50OJWcTY2Qcsp5nq1tGNHFGIJrPx4vT+SZaZe5lGFuPG6D3zVh07oQ1YFaY80Gd9d+yHosedA8TZMUwVE8uWjQe9pE1wd0G3O7Era53xpuHW/AtWBWpGQDMnSz2iBYTnK6y+PHQiPEG38V3It9e33lwcpYdH+C68e839B3d8PdYWxc+XInrzq9DdB7aB06Z6CKIl7LlXMx8AcguDDtN6/eeO/HeTYA9hXeaWtX6PIsNFIARa1KTcUPehqQU8nZyKCgzQWTskAA7l2gdhONRTU2Tb/dbkMbjDCpUMqZJck8zaLo2o6apr+9Ye/XZQs14rS1rwiektwIZqW22hOpypMQWZ1npnme5+NpPBzGw2MeBplnSVlEUOUszXh2PG2BiiyZ9MrsZICVp30NV2uuq0aqa2n8jJKzc5YULv/HNdg9H1m1jojn7rX6K/1R0Xa7fC1CTeESVuWWiy0aAQRNxACA0IUYAaUUXSRtMLuxpIlSksKIAihERYlICXQ5L6yyvDHGbhM3m9D3oe1cjOQ90hNHNlSnxda46yLl+OPbmKlIRiUwIhVSVmM1NuKV8o9RFdkZoKEYOzM0SgIOGIykVoCBlFjJKZYzAT0JMznvonexaZEx9HNKxQyQU8p1IxfJsmbjqyxWSbPkWUteet9rl+PS6S6rmrul+Xnzm6qI0JqQecpgEMLSAqWGiKq6YIERF1owWInEQCvlcF1Lq+OnelbbuQjdkBZ+uur01WLPwm6zPHkJ2ZdujmcOxcWcxh+KHv/QusOaFSarqg1ITOy6G/+m3W7Cttu8uN1//vrt19/ef/h4OBznlFSUkULT9Nvt9esXr3/y5Zc//+mXv/jZ7atX2/2OglsdU6seypOyKX5q0/9FUfvySlg46hGxBySKdSEmIW/SmX5n9qg2k6kpqJAUKhlzYTFy3pzTflvaXqQ089TMyash+9nHCTAzlKrriKCIQqyA4p2xJwrBeC/0udDPzP0S/VfkXzq/ZedXiJn9yKX+5PABlgrb6oHVi/SDTP45pY6I6JibyF1j203uu812u3vcbo+nhymdpjSmPOecixZVBTOkKt7Bnl2MsWvazWa32+5326sq3hpjh8jLwS5zGp8f4xqlrOHP2Qs5P0nSWIYHnU9QZnSKte1csbYOixOHWIVrihkUSSUhFFWUApINzNom1t4REZGshsBcy49cO2lqD6xnZufb2LQx+sqKJmqitYUfANQsi4iU43E8nsZxTGoQYrtQXKiVlOdhckzsHdKnudJnhmTr4IuWXl31P3m5/WrDWxlpnDPMpsk3dv0yYOupd/7t+P3bNBzyUEyMCG3n8FVD/pqv3/D+BXY31Oxd2HkKCKySqWRAJCA3Fvd4wpz0dKegFhT3ZBtnPjJ2HgE7Nde57V5f3bqHU7w7RlTcB96w9igeDbHxgRAJFDBbgEROIzylHBExdB2zA7GUiyapNnUJAKpKRy18IGgtMZpKypIS6+Bwdp6JKOd5TinNWVLGJEGRfQRA50NAj1nyaZzE1KeEY846nKbjME4Gnl12UZ0HwNp3q2sOZi3Mr2m9amlZGbFtIiFcyhCZmYmUeZqOj/PhMZ2OaRpzTqKiVXb6qXi9JMXhHBwvwRNBTRAZqNkTDXRVIlNQVoZnjb5r9LT8DLqG8xfXd+2TXc4BEIBwrc+va+hirHZ98Z3XLBgSK0FFb69F0ioHRbWvrjK/8QK/Zybn5tNhHk6Fp5ISUsalaCw1YAVTRCRyPoTYtE3Xh64PbeeaxoXAztW0Oa4W9SnpcNGvdHmmn5wFaGUdRTVeKPDNkLSWBUCZ6xUDAqy4hxq1kzkzEoVKD6GqhmqEgsuVQHDOAJxHz0xt11DjcxERcSillFJySSvf
nCxssiJakpSsklWKiVaaOSmVvkZs5ZpPRS8XfZWsXNJJfJGoWbZdQ7Sah6oUZcAGyEsO0axu94iCAmCwqAytRhWeLmRdi7WBGNcE1hKl25mO0BYYWv3xLBZjtu5g+hTYf9LFd56mT7/CtdJTnTyj1Y4Ae9/utl/88he3r19/9etfPXy8u7+7ezwcTuMoRRxx27S7q/3V7c3ty5fb6123qRUcxiWT82TG159+JFb/w/YPLlbXxW8AYF1TsJoZ9OCuEH6GSOhfaPqppn8y+QbKR7STqWnxpUp644zt0HRz9KlpSoiWExfcMO9BWqMH1XvTETUjOkRiKi5kc7MBIClQFNiJ/VTsXwH+DflfUnjFsXOeqPIynTN6Pyi2Pz+RyglHqqs7tppQXJCGupzdeQ+o2TpkZvTee++6TXt1vZumYZyGYRqmcZhzEpGaUCdEYvKOvfdNjG3bdl3fd32MbQwNcyB6UgF+Io4+VyqXL6r70HpUtVX26TQkjel0r/NAUhDIgEtFHWYDMBUDAUI2AFVJRYoWqTS9AqDIyM45MUupqAgqeMfe+cCO0YlikkrrbcH5pmm9dwCW0zycjo8cMJvn6F0L5ECKzJByOp7GUmCzvWLmJsY8z6fDAQG0yOlwGE4HckyOTofD5d15ZtqvPe+78JN999Or7pWDJo8gY8EEqui59YStpy5QGzI+zjifHtNcDNi6DYc3Yfs5XH8OVzfY7TBszG8QyCmgZtCFghOTUNPylCSPbjgpZetRWigOCyKjZSZ1BK3DXeteXrWHEhRcw21Q5JJASC0ykjIUmlFnFHWk/DwD7ENg54OoUV4IfLGmvarVW8STqklCMDTBGDHnMo1pKjglMJNScim1J0OKgZIDQiICDwXSmLVAGjPzyOBL0WnOWYR8CG3f9D37YJVPY1GPX2JrshV0fOZyNWPGrmnOWc3z0lGRPE/z6TidjvNwSvOUF55ww6dtYFHGeR5tL++uCGLntKZdNHGaqZCQkpx7lS5jf1vfEZ+2m7oH2bpi8eIpF4nT87Gso3LFL+WQ806wpOSNFmJOMlI0IjCGpXULYbH8xIyV8MkFYp94QBqRZqSExFqKSam1fiJyzvsYQ9uFrvdt55vGhbjyAp1LgufaQFUhOW88Z78e5JNtWlWsQC0gmNYaHphV5KASM6kZsJohGRCKgoKRQ0MTM/KKBOSMFNiQjdZ+cKsFQMZCXNAFdrFhdoKlqGYpOeeccyoll7IG7qamxUpWKSpZSzFRKbkyYV2YdlPTpAjw1P9mZlVkTxdg3MKoh1jnC6oYgKIZLGQDjNXpXLNDhAi60OvRpbm1yggDzOc2zqfvtlaIV97DZ5n3eqgqFfUJZlpES6m/XrKpl1SaZiYiOed5nkUEcaUXtPX9dZU7qwZFAQBC18am6bbb/csXN8fjaRzGeZIijNzE2PebzXbTbzY+eKgYrjQDQOX2Wc/+nM2tS/pPDdPNbJ7nUmo758VCeB41rguUgVr0t0iO3N74DbivNH8D5T3IoxUF9sCArI4fsP3o+dD4k2NFInI3zr2y8kJh4+MHat6iuwM+ACmgImUybxAMwMgJbEU/E/wFuL+l8BccP+NmTz4wMy05vXq+f6zcjkBU+cF13SHOdffnLhAuZ12j/KX9DJEchei7ri1lO8/zNE/TPOacqz+HAEjETMzsHYcYmhBjjCEGx56Zz8b74lPO+QNTq0kiWhzU1QosrfMXQ0uSNIJmrDwVilYMyVDBACRpBnWOakOuFJtSKUVRjQAdEHnyjrToPBc18YRAxM559o78XAwWjCtgdD5EQEwpP8qxzOKUnbrtxoVAdccrUlIuIuZ8aJomxhh9mMZB1UDUO5qm8XQ6Fi2G9jA/m1fPTPtt9JtN90XXvIluo4nH2dxcqNQ9GJ22LYddF/pWAQwfih5pzhTxxRv+6i/CT7/k17fSNAVZ1UtSFW3VIlnjHLNP3udAEFs0jtHRh8nmeW6gcJ45DTQn0Mm0AHgg5z1x
G327x3YfmitS1GEqyRVpTSCXBOVg+cFZciZmzzp3iSgEvyVsda3vLtlpeLqpFxEsmknb5aa9/6AP05TGIad0Qb0KppU0HUEQDVAE8kSUgGq8hmIgaqFpN/vd1c3V1c0Vu3D2eM+oY1iFL1bTvh4AgmcWKZcJPTMVKXmepmGYhmEex5TncrZhKzF0fS5+Yk4BzEBq8ANI62Gs832JlooUg9qYdUE3tYZWsO5dC0a6/nEJ0ZaLdw7UbfXIYUm3XqwWW9ijxEwBdbWqdSNeCbZsIahbyawWWC0SxiqQwN4H5kAckB6MGImRiDArkhBVWB+z8yH42ISm9bF1IbL37Byxq7xWBmZqRqBmqBWtgALI+LQZ1Dmjz1GNalpKXlPPROZITWtiutLLMbulzZ8QCYnNAJEri5+RGDlDURRgRWeITBW0VjdR55RFSczIBLASyVVGmpIlVwufRCoRqi4geclWipb6xKqAPuecRcrZcGYggPZiZiz3H3UV7DYAW4w1qi41pHpBiJCr+Eu1nrVOXkNZATMlsiILOI7XUtDSj3ZuRqjz5Kw3q2eyo8sfRVbgn5qallIhA5U0kADw0iiqakqp1g14ESFce1BXEMWiNYJP93RZMgoK5rtm2/hOt2pGFVhKXEyOwxHHCzt+USdef1WnPK4e759k3c0s5zyO4yeU/rDYuR/+jg1bRMe0JfcZyK9NHk0OJkcTVSFXaqfrB9Pv0D6QPVgBlWD+hvuX7F6S2wR+5+gb0t+CfaP2YHY0cGCm4BU6gyuF18Y/Q/45x19w88Z11+QbYodAT/BguNhpnh/i02MAIkJiWxh8lyuPCPUvnww6M9cQPrX2MxEF512IsZO+9ng85QIr3R0hETJV6Vhirv7EOdFk53W8eBh1Fj8L3BF0gZPgDy+9KZpUwKsBFFXIRqhVZCbNxQRiJODK+IRVyYHMmDl4aiLH4JOkOYmZYqDaz8/eMTlSVbBikA0aImAuoqdhSJMEDh7iptnt2bVt69vWkMd5RoQQg3N+t9uFEOpibVMmgOh9yvJwOB2OjymNJ99CszufxzPTHtqu3+/9pscYE5EYkBFCFfH0zJF9R7EDNH9rcYSQME6JGoovWv9iS9csfZ64iEoWyuYLNKqdQ/QkASdPiZCNvDbIW6Gdg+2YvB84sIVc2KSYgoEjdkwOoFW3o7jX5goEi0wFXS6dFlVLJmCmTmcnZdRwuYQeHu6buVU723W49D8vbOr63UxzyTkNczqOaTxNeZ6XGbWUnle3fQForI7CUnADM1SAHiB0bcppHkd2xc5p6HPJyZ52jItVYQiQEEotKq8jpXQ8HE7DMOUkRNi2lDbMDFlAtRa1aP1aoC4IT5vruj2v39d5XLdhqyUCquiU+mCx5EuFc8XaY5XQoKdDP58YrCbf1nbn+qcL8gQDeHx4MAQxWxVdF6Z+XSlv1EzEZMFRLRoRVnvysIKztKSkRVRtPfZ6TrRyd9iSVSdGZKiaEGpSJKfM0wzIplZSyi44ZscV91YbWtAhOoSzc3M27eWiF1EMx5WllIxIa9EZCZEUCKsCttFCuwGEgEpoGRSNzUiAnCIrMlBBV6ppX65hdiAFU2aXHDgnhKooUjRnSSWnnFPOOZdUkXRr8royyZeFS7jkyiVf1UgXeWoDqd1S65hLHlOqGyIVJmYkQqylCmKGtW8XrCYsiYkckVtwAWfV3bWqQcxIXJGEVJWqnKPa6rTi7crS6lCWnrwfs+siT3G8qhWRsuztVj29Z7dD5PHx0TnnvV9M+9OysvNSf1pj8PQbxEqPS4s0fKXIKWvB4Fn54oeR+RqT/nPsOqxphpTS8Xi8dBlPp9Pd3d2PvuJ8JmgObGPqTTuTramqkqmCFIMAQCDRykYTSo5Ke+NrphuyLolxKSCTlaSKZmyQDdCsUdso3Bq8BnxDfONKx4JOJnKCyBce0XoZfjCmaTo/VtX5dECixRhfvO6y7nVh2+uUYVw1A9fnXnxWZYg6v5MholZx
OEFQgHKulde8lz0lDOBs2i8em14QJa2fmNN0eced97FtCT1SQVctpBGaJ/C175EYKBI7JnNQvCQicQDBUfQ+xhhCDOJ9ADP1kV30HDz5SOjJlLWwOmPl0KNvjVjN5iKl6DDLKZUhS8jiOSvIME1pziIlEM2qUoqKTrnMapUmYxQ9pnKY5mmccufgQkoF//2///dPP6ii6brDn6+pnWfx+aKvOtHLLaFVZbte5PVllwZ0XVxL6ndJzYGeDcRTzPjs4/BsmJ65YavFXciBtQr71Q9C9N7/cCL+8fEkR/YnAEJ/+Gp4Nnn/JQcAkPMTa011bM+RjqmsyKJPx7/ww/43v/b3DiK4wAN673EN+n90fPongx/sKedCw1MKd50P9slbnNOwa1S1rv7LGjA8O/PfdxEub8eZ1OfZm1yOH93rf5CLXK3Cj8RpPzic86k9nf7FH9aHBpfP/5GnPEu0wDP6rT+aUr5MbP7wbz/20h+d/D88qk8qzj/25E9HLUWdP6Wy7P2L19r/IaPO3sv0w+IG/clvsH4/kzrbxRdcYCZq1GFrh8onNOMr51FVYamVyeelgT86ROTpRBA/ITr9veNiNv3/7M79MGn5Jw8Dk4vuUK3RQ/3L+sbPEnoXS+ZyvT0t6kXqEKDi8xe05VMi4Sn4OueGDKBu+LjmTy+s3Nm2rEe85FQRFoqmiyzr0xV9jscmsj9aU6mjPoufzlcA5BNP+Y8OvKwA/mD8yLt9ukteeOMXr6vdPv/jjxrNLD9U7+n/0OP5F49P+qr/e4+zxf+XOGh/4G2f8ES/9xm/72j+4DP+O43fe6ifgNH+Bx22UPH8Dz/+O5yF/dhM+wMbe92//zd+pkn533WZ/3caRPTHgQXr+H1r7Pd5m3gOWn/s91ZtqP3zl+eP+WR/urf45/Hn8efx5/Hn8efx5/E/wPizaf/z+PP48/jz+PP48/j/q/EsIb9/fbO9uQqxCSEuAo7kEEmkiIhpATBmYqrgJqu4KDNTE7UCYMSVlqKWxC4EJOqzTREIiZk8s184wImrhFGFMRmYlHQ4fXz77r89PH6Y86Sqq9hOVUQ4I2BXPAYSpUhTrB8UY/x3/+7fxRj/O1yuPz2hahdpMbz4Wv980aVxzt6UUv7Df/gPx+OZVGgCTHHzqtm+bmITvWsDOcYh2ZQNiYnZMTfB7fumb7xzBJrm6WE4fH14+M+n47thmlQEKyTfVIqKSnAYIzM7pOBc5/0mNPsYd95vHXWYAcWIHYjOH9+luw9yOuZ5TgZZTcUAgT3HGPu+c9sr299auzEXDUnVxmE4HI6n0/H+/r6eg4j83/8f/7fTeKpdyk/XZC2PVZjQUvJ7KidBrSaigXOu6eKLn7y4/cltvGndJlRhEVyEnNcuvDopBEFcOubT++Hw9uHx+/vddvvzX375+qvrq8+b08Pw/W/ufvufv/vtf/kuzQlwrd0/4beWAludVv/r3/9fr/fX9ZC2nD7zQ9PEto3OEa+YxIU9Dc2gSvFoETFVdszsKoy3HqUBga19fYiiVmq5UrSi4bxzMXhecG1gYKWIlFLxH0xYkWqItSIIVPFgAKKWcslFq5S1KmSRlPKU0jznDG5oXq7TDbZ/8ej6YgYqOs1zyQm0OLTgWMWGUXQkmsOtv/5y8/o6bDsMwJYwPerdx/zh6w/vv37/8TFNoxZkYEfOU2zc9c3m+mq3ba8a35WsCuK6BH7KcBqm08PdMB60jKFrm5evm6YHowTqsfSN2/fxKrrWsVcrRec8TylNaR7nNJz0/qQPwzxPKfd3/zpMb+p5jHr6ev6vBgKmALqIgdZZs3S6ruVmqFhPXNOei/KtVZCeqFW4IDAg1xrfUsIEAzB2yJ58IO8Xwr0KWJzHkqaSk5RFPPoCnlv3sgWBhMjMjn10sfFNFyO3L/LPcQ2rfnUr142JVsQYIzKyI0ZmRVqYfEEZAGClUrAF511h3ljxTctJIyChIpzG6f394f/5//nH//KP3xG66+3u
X/3yJ7/48lXXeO9pBY6sSkLLawmRaVG5WQgyoNIMqBkooCAwgEPwYO43d9PXj0vps++7f/tv/8+AKKVM8zyO4zxNc0oqYlD5nGrTueacU0rTPEspbds2beudQyIRyTlN05RzwkqtIDrP+TScpIhzXK2Dcy6GWIGTzocQmr7v+35DTGY6nE7H02Gappyyc957DwY55/uH+9PpxMQxxt1+17bdWvxWUT08PNpapTrO85BzBRB75wix5JTnOaep5Nm0mBVYmEqAAMFQiuY5qxoguuBDE9CRQmWaKGnOac6yXAcgxhB8/WLmnEtKyxOcd84xO4eIld8BwIpISimXvNAmmr/av/zszVeff/azzz/7aRN7rEKj7N6//8233/yns4l5Ztqbvt29uu76bdv2SwshB0QupfJbJQCrn17LvrYoTqlozpIQlRnAxEqp/Q92ZqRArLBnRCJyjmNlGCBiZs/oz5bbTEuZP965Q/n6ocz5dChSmInIMTlEQlp6V2Ehp2Aih8oAiy13zv31X/9113VnI3oJvXlmS//Z4/w+9Z3/wNPOuBUAwDPS8wIBVuGHFZa8vCzn/B//4388m3azjDiGNmxvbjfdpm/CpnWe8WGQw2xGjtkF77ddfH29ud42bXSm4/D47f2Hx/duuqP3jo6llMUAiQhKKaWJtOkcu4DUOL8LQbred33fNC76zhVmYedbKHqCaZzv8ySzTYPhbFakgAEjdxGufRe3jb26sd2thU6RVfXx4RDCRwA4m3Yz/cff/ubu4WO2ShhZTdsqLlO/V4wHrdpWy66CYIaATRN3Nxt8o20b4Rr1xhRNUYmr9tsi82FmpmCZNOkA08d3Hz8c3354/+4zfvWzzZvuS776dZBvhvzt8d3h23/4x/86nCakRU3FdGF4W0CUleqf+W//p787m/aI8sqNfQO7jQsBPWPtk5DzKkRTNSmaS1FV58B59I6osukuHAeLzgoQFZF5llKsiEguWnIM0LfOeWSmKtiTkqS0+NCOyTlzDEiogmbGvIhyFZFxsjlJEitiajZnmSANMp1wGiwMZ3QqQPNqDjezKeRU8ulUxgEkIZmPXovBYy6GmKNz7qZ/9UUbbrgHKico35U0jneC3z6mb9+Pp4Nk8uA8BaQOQtfe0O2L7go2LUhSBXHbBPE0yl0+3BW4H6UU6ZvNbvPF9f6GzM2oDaZm491V37Zh4zmIzbnYNOVpgGEsp2Gw8mFM35XTcRzG5vgLWEHZBdI9vq3KbRXqs6y4RbAYbQGH4WLwcEE1Lcj8IpKLVm/IAMEheiQHQKuektZWTefRG0fixldhQysmqeTTPA/HNI0lTWKlNm4A2MKG5JjYkfPMznHwAUIMvuNGQmtuC+UJB/Ki08+3WgRFCJGRiB2zIw5CXF1CQvNgYFbM0CqZmRkiLFK2a1crIiIZORK0jwdJw/H08Lt//M1/YYry8hX/ZPeqfbHfYROJUM90+IpotQMGL7TNKtDOBAxUrApQwKLagGAOLX4cy9m0ex9++atfIGJK6XQ6PTw8nE6nYRhExMycc845RFTVaZ7GcaQT5py7vt1sNiEEZhaROc3uhNNEAGBqOUsudXEkA0ZEUTEIIfja4eAcN22z3V/d3Nw6x2by8HiPDolpdrN3wftgasSEj1iKGEMA8DF2m75aAVWVUo4Pj+e9fCrlcZqLqBkE7x1SSXOexjQNKY2qs2lBEER1tT1OscxlHGYRRSLfhtg36EhMcynznKZhmoZZapsxg/PcNKFpY9s03rlpnsdhHsex5OKDD9H74IlZ1l6hlPMwDtM0lZxVEC28fild/+rNm9j1r7abG6bgfQwhzvPpW/g9pl1ES8rFl+wKihILMxA7MERkJAemS0O1nfUFiVDVFCsNpEkVrgDAKrG1vBaJCZnrbCQAVjVTUVRTNEIkroSWdat37EJoPUcVyFlM0TlAQCasS8gWChhbWx2fWeuLFhoVkdpUhojOOWY+kyScTezlg99j/n8fRPAS
nbx01C54ZDMDrX9fWpOK1N4fs0qbzk3ThBB/IBm8vqEamraOrvrY96GNPnpHiN6Dz2UqpWQzo+AtZSuKQOyInUPHSly8kzZicYCgOZe5FABlNiZDUpUsyabRANNwOnXdx65/2/e3XXPTNbfYdKSt2+/dOIoAuSawMwUdp5xSNgUIDQWkwBSIHFayVYMY43a/u+yKgTMA7RnLzsU4U+9e/HGBpBoyUbdrr17uty+3zU1LLRtVMlQjMiS8rCmZgSpo0nScjx8e5uMYQ7i63dx+vt/cdOZ4TjocU5qKmgEZ8KKOWaV0KilajfXWDNGzWQWmWkqeE4Ox5yqtgzVkM7QlZGKHpqSMjIolFYBS91BErPIWyExGmkuepiLFzEAFNJtYyUhojgIqAgCB1fUFZgiERpVLu+YVqmtoqmhGAAzmQQhNAXzA3sdt5LF1D5k/pIsOBbJV1NdKKSlly8kYoqvqYjKVMgwjJG65hSwQriwP9+O734zf/Ofpd799vL/L0wRiaKJgopaNUh7S4VRgQ9I0RwrOoWPnVElmTicajzYNgJk9bHf+8xftxkVzsPF6E922CQ1X8XFT76P32DQ+jN41fj6mo0wOjMzwgr7Ce3ezvRbNJWfVpJpVimpWNbGVkXaZVPhjjGqGYEiVixHXzmsyQ13XfuXEVQUpIgVLRmeItHhahEAMtQ1QCdYmaWJG5ynEENvQdG3TNqGNTRtD40MTQgjeWjg9HYeIiagKmkHtla4+JmgxsEVUz6SadqgnZ1pdFMMaQa1CcAaoSABGoFKvhIICIoIiVuYsQaiEigu/OliF24MhyZJIW92iRYLGRK2YiYEgAiEhCFaWiovdL+eMiKWUGpenlGqDCa46BczMjl1xjl0MseZA5nkuRdyS4nIxRDOoLZxVubDemiKlSMk5z2kSKc45JGpyVsAQm6ZpQvDVd2ubhpkqvZOIVE77vu/rx/kQvKtir+s+v6Z2lhMRKEmmeZ7nGQwYMbDzRN63zgXRZCZM5hkDMyqUJMllwklV0RMFRk+iWiSnpCkVKYpgjpkYyTExIbIWLMlAtEwqs2iSGhCgFBDxIfjgwbEsnNDAhMBeAUFYSjkdHj9+fPv2+2/G0+h913fbvt+W52jlZ6ZdpdQOWpcjEqKSKjoFYodIhA5AK9ek6bkBsfqKjECgJiAqRYoQMyOthIKM6KpBhboxK9SsV+WZwEpUDTW7VEWzfAyd940ZSJZKGMmEVjc1WBsnceEiXXUkz5uwVRhwbSStg4i8r0kH5pVwdNW0sDM3wpkg6dm2vnyc4dMPcMZiw5rYqS3Z1S4YPDFzlFxyzhemHYi49uM65xDBjH7gTAACMbro3Ca6PnKI7JjNMLB61jmLFBFnIpAVZNU8JDRExZqYZq0HIii1GZqWBAqqWs6aUsllmobDPHxI88ci94CTa9C5hsjzbuPTK0UH/cQusJkdT3oayjRm5xI3ngOSJ3KV6M0MQhPJ8fF4urxwl22u57aR1ZMCRFh48J+vMQQgwuDd5qq/+uxq82IT9g1EVLSlCb9G+XjuhwQw0gJlKtPjePjwkIapb7urF7vrN9v2qhWUaS7DcZ7nrIskSPXIlnn81OdybmC8PB5EIlTRNCc0RXXeEzJalR1SAuDqAiMTmCEhGBQpKssmWFOz5H2lZLeSShpzOTepC6oQ2CICzZVh3mrm31a9VKzmnhArtSeiIiIoITIhGnB1wImccyLcR4LpSWFhvfhqgKKScp6npHNSB9ExAORchjR9GMYZzEGkoq5NOhw+3H3z307f/Nfpuw8yH1XKoswKKlBN+2kej5PuTHo3BIoOG7TeMslMaXDjEacBG3UB+114ddu+iI0PtAt4zeiRVWESOJECGiExsi/QBCth7gNtIqWGjfDJtDvn9jdXRXJOqZQkJZUylzKXkrGIXFA2ftJU+HRL14lUU3/Vm1NFUUMkrEKJimCqxUoWZkRA4lU2eZ2BldyAAGraOUSOjW82Tdd33XbTbrq2
bWMbffTsHCLT7GvCu47K/a91nQCDSdXY08W0E1YufDCwYqqr7JgCKpoaoAIBco240Crd5VnSDipb1rrLrqwWVYX6vIefU/MrFcjayYUAZiAGWayKq6IhIWQCftaBYpZSJsK8jhpQ1X1VRCoDwbkt2TkPNYifZoCZmbuuC8E77yMYAEgRW1uEDCyXktI8zxMxySqlLGrMYRzHYRhKCcQIYCGGELyqDqdxHIasCqBNE5xjVa3e8HrIC7PC5XmUnNM4jcNpnIaSCgFu2r5vWu89s0NkQ3UMwXF0DhUIkkmW4AyNI6EnIMgiBk4KFFRjYE9EyJ6RGSuhPLLD4JCVSAgKiKpAXhrvDIlDIKQqZbbYAgIyMiXN5XQ83H98/7b7ZhqmptlqEULK+dkif27aNZUylhJzbtg7RjQshosBRag3dYlPawJ0qZAjEbMaSl6IJ9AYKRDyKnhQEzwLT8SqAAlE5DgQOajWUXSJ4oFj6GLoiNgA1NAUl1iX0KzSYxkhMrFjD8SfdG/g0lxoqlpJoMzsidSQyDkXY4xxcR6rjwkAzBxjDCHU3X2N5i80WBYz8qy7qpQyTmMpRUXYsffeTEspKaV5nnMuJRdcihzOOe+cDyE45y5Z4D4ZPjRtbL2PpiKlCBFajdQwekpFATB6Dp4ZzbSUTGYpp1xSSUnmWea5iBQwzSuDmRHkBKpaIRNEgKAiklKm4QQIYKIyiI5d9xl2+/jqld/sNBdzLomGwzEcDtPxCCLUNMbRkM7EJo6YnPfBYmwuT+SymL3enYu/4hJyLBnr9e9k6L3vNt3uxe76s+t41UIAJVBTWnjOl2lbU/cmaEIyy/Q4DXfH0/3RCV7f7K9fXTVXASPlIuNYTsd5npOCrgnyNZyz9RsscfsnLSzOcd+3OZdhnHJ2UnzbxqYJVqmCkBBqjZzY4Tndp2aqSw24MsIQMSGdqzUppWEYVQoheO9zKQZA7CKRd44BAKkW8+ubI9OTI0uk1d9kx04N0dRVVrdzvysifOKmqmqRIsXmeR5P0+kwyZyCW6i0T6d0GqbjOM2KDN6Lcc46Dh+Pd98Ox/ucjlpmkKqtAQRooIZScBrteNDTqXSbufg58OBgUo1aSOcwH0MeShdd40Mf+o3feQ6eOwYPgGaqlrPMScZUTkM6neZhGKdhmsZ5ZOS+7Zwn/f6Jr4KI224jUkoolV2/5DGnKecp5znnDKVI1Sg8zzc7d8vaWpYGgsrIg1XfWBWxoKwNyKu+tmoBSQBWczKi5024chMxkeMQQ+xit4ndpmk3fdt3se1CbJwL7Bw7JmIAQnrW8ltziwB2Zp9aupS1AJgBAxBgtckCBmhUXTMDA1RbJHvO9pusamOYQ3OgVPNJtTx/hqTAUw4fltgVmAAIKpBq6Ytf8mlYHQQx0Eo2tzAamVychaU016i9tvPVzXaapmEYzKwa76ZpUspzmmuGQ6TM81zD/epMV6GsJjYIOE1JVXLOKeVcUi5ZTE2gss+SkgQxs1LyOE4ihZmcI3ZEzITkvUuOU5pTmmt5pUIUSimUEjGbWXVBLpOx83A83L3LJUlJmsUMkhacZ3Yema1eIzRC8I7QUIuWJDkVZmSP0YWm65BYFcdpGk5DSbPmbGBGQMzoHKNz7Ddt34QoKY3D8ePHd8fHO7MEWCo1tHgxxaIieVWcr0ZQoZQyT8Pp+Pj4cIdKUiz6pm37TwiMP0nITyk/uhycb5CACU1NARR4JUFYoShrthsJkbmGzwBYvbQanxBWN2cJi9ct80wfVndBcuwR2Uxl8dKkRnTeNyG0RFTdzAuuM1ogMwslKVFFm/zYwPOUgacguibniahpGhGp+I6aRKrmvxJm1eD+YupfOPwrmaWZqaioTtN4Op3qOzjnYgx1Us7zPM+pqnHEEL3z3ocQYnVga/3px9PUAN43bd8weym5ZL8IhyI5puhdUXQOQ6DABprTJEOZUE/T8TgepzyVMmuapYiAmda8FIAZ5gwpa/QYIy30T2Yi
muYEAKZZymhYzKRpYtheh82WjJEoF+HuwF3nm0ZScsQYIyx4sornohrYuOeUQWfLft5d7anyu5IvoJ3jBwColZbYxH7f717stq93YRfMLYlEqPi6eosrYFPBFExQJh3vh+H+mI5T7PrrF/urlzu/8cKQRp2GMp6mNOcq/mrnQL36FmdjXu/383vinGv7zk7jPBfIQkjeV3JbWkxo9Qeo8mrSMlmZGcBMsTLn1aABFxgBM4OtbqWZcznlogaADAvLG7m6ghbfY+FRq5kQNVipRxCZuQqBKgPKmd2lct5e3g0RgSwplWGYT8fp+DjLnAIvuoTHUz6d5mGaUYyMvSJK0Xm+Hx/epelYZAYtaMub1oywggjMg50Oejpqvy3WFUMTzGAd4IbAozTOtIv9pu362La+YQ6EjkwFRGzOMs75NMyPp+nx/nh3d3wY5znlDFiAShtd03Yn786BCRJ537BT9uIqEDEH54NL3iVHPGGaUykmF2jWi5wbAiABWt2nyDERkQGKkJ0ZQQwNUAW0cj4XAyQwVNO1rAbMBIGInHe+2bTdtt3sum7btV0X2s77htkDcMXUrfPsE9NuYAo1XW4EikAGCoBlOWyAhZ25IvVq2A2qC6NjzZsqAq9PryfnEB0A13t0UaGoF2TVTzGrunwLvyAQrzA6xOXInnbRhaW5ACZ4TkJlZpX3txIOnl8yTdPd3Z2ZVVwzIuac05ycc0hcg65hGOqebGbEwEzOOV5r86WUXIqoGCA5V4WgqxBkLUqJaEqzamGmEFzAVXITgdZ3WGsHyyK11RzU8sElbCqncT49ABqCohYtOudUYAB0SI68J1czZlYtGQDVXIgjDi70TbfbXvvQGPA4zqd4KimLlKqcgMzEDowYuW+6NjbUWemmgPHRt6UMucxzmRWMrMYxQGYODViW/DahY6+qKc3jeAyuYQo5JynFnvNVPDPtKR9pOrF3LjTsUAVM0FDAISgslL+uAnxUbVWkN7wAq1VZKCZaKLth4eI5fwgaVDpKV6tbC2x4rYuLFEGoFjd4z8gEaIYqJqwkRrSs1VoQUgVRWRCkz0fd7mKMzBxCqAb7sg4kIvM8X5bea/ponufj8dg0TWxi8MF7B+cZcc5agNWZMc/z6TSMwzBO58QADgPWs6k+RIxdCDHGJoamJnZWAN35k39ksPM+tGCY5+TYe+/MHCCE4Jwn57UIEBFDyqfj4ZBPUiQd0/B+Ot2loZgQGoGRigJQ8KiqIpaz5rLECWrOtCIVSI1zRoMiegR8a4Ky21m/bcKVdy0zci5FkkoDkrUEYs+xoxAWObXKVgtkl8RJy166JDmeZ/Ce6uuGhmsMXdOENSfbdu3+dr95sYlXDTWsaIaGaKvMDy4VmspSLmgZ8piPHw/D/QkVuq65eb3rbzsJICLzrNNQptOcUzaQhen6ybpXDPs5aVvDsafjJcdN6+rctiIIYAolKS1KKFqF0lRRtKzviuyIHZmuYgYAqlCDLWZqYiyd5FKGYZzmNEz5OMzjnKdUUi6i2sQQnKeF/b6i63Uly3qKyyteVdWKmBnU1gkDqBLvVJ7dDhXTXMZxOh6Gw+P0+DDJXAKjmiHBMJZhLGk2Uhlx/oiPVZT5JMMDpBlrzR/XnBaYmoBhxnmA4RGHxzLt2Ht0HjLMSOBi6PtwtdkK+89e3L64vmnbSEwAppAUZoGUbZzlOKTj/eHDh/t33398/+7+45wzoLY9dBtumtBEP12sclsYi2sSjhEU/WJiS4gunWgmGydN89OujguSjhCMABEZ0DF6rmhHNKAii16OFTA2rjklBTMTUcBa39alQ4Fd0zGiiyHGtum3Xbvtmrb1TeNcdM4TeUT3JAO7mPZnt2MR6KkTssL6FarRIkMEJqBqjKocsyksECsUg2JLo4gCybKYav2UickjsZ2Fc0iBFs3BmkQF1FWzgRAcATNw/eB6gdd+DgJ0AIBGWvuc0AyL6VkECNRsnmdErNimag8qq+6HDx/OZMDe
uVJEzUoRAxmG6XA4Pj4+VhblcRybNsQYYoyVhBMW1weIvQvsvEMEVVvUXWNTE/s1+U8EqiJavHNElHOqR0JVpRiw6lC0bet8qEYdF/N8eTuUSYgJgGcpamVOk2aVAogutF1s2hA8MSkoELIj551jv+m6q91mv9/vdnt2QQw9zwRe20UpDszIMSKNwzSO4ziUPA3bruvbXfdFX159Mc6H03h8PB3GPAORIbCK9yW4pKor+y8ROmbvnIeaRQGDyqD7HAr2iWk/6DC70IS4cRWka2jICAWA2JgREGpmfk3vLDhLNTMEJPJAQOR4EdqqEM4zZoFqUr1KKGONv7Q6CFVqsuScAVQ1E6Jj59kxshhUolUlFaFVrAJrtr++wSd2sX5iranX3PvZR5vnmZmnaRKRnLKZglVlAqruRS4FwNq+25R+s9mw+xSnVz0aEUlzenw8fPjwYRgGKYWZfAg1W1VzEiGEJjZt2zVtG0LjXaiHBPDskH/UuldwAABIzrqofQDUXgdkRCy5SJ4kzfN0LNOoaZJ5yNOj5KOZmiAoqFgpgAjsqickpVhKBiY12QJQTSRWPfeKOwRTFRC4Mdqz6wL3BAqWqcxYZjIBRt9Eahr0HtfAfSnO/QBreE72XJyzPfvvbOBhcawJyDnXbbrd7a697tzGa0DFVcPvDP1ZUuhmilZAZk3HdLo7zKfJs9/uN/tX2/Y6FtZSbBrLPKR5mOtNX0z7s8PAlb3zR6J2M5AlUGMDQDUVyynXHCCQGVbo8kIzaQBELgRmcqpV70kvCSGZGD22TSylM0MBEptV85wKngaq4k8AjogIarCCC7ZLzxDRyrQPiKogqkUEAHBxHA2ADQyf+70iqkmGYT4exuPjfHxMmjR7JEZgHMcyjZpnI5MZ0oPZJLOaziXPVkpFrqIhLGIstaArYGmy8QinQz4duesdIgEKsoZWNjt6ebvDtPnsxavb69sQG6Da3lPU5qzDXE6n6fR4Or67e//d+7dvP374+PAAZCGSa1GR0SE5Brp0DRddoSrkcNYdM8c++Ap1F5UiWUDOeK9qXRWhYneYMDj0TMEjEpkhFhK1s9gYAEGpCCFRVSgV2IaI5Lzzgb0PPjRN2zRd0/Zd07UuROaA6AAcAsPiCME6n36QpcM1bFgT5nDWSTIkJARa8ulLhaBKVqmZVWwhYDXSC9yGyJCAmGsIUb1XQzUyIwO2qlkICIa1b7Cu3ZpsZgIkVEM1kGVfB653fHEvanvSovi7rqA1aj9zaNYdsm65IYQQQi10plx0ETQsh8Px4eFwf38/zXOa0zRP222/3famVt+q5oPZOXLM3ocQiMhWGHIMDbEzwyKCIoggKiJUvGPCNKecE4C5isFGQGDvfWyi93Ge53qEFR113ppUSsmzx8jkXNVpRDFTAkBQVmADj+zYYZWoqRgCF7q2a2LnXcClyaLWbQDRO9fUMhwzIYIknEBySQLaRgCKfROdoymN7TSE4TSmSUwWMnfVuvnXSo0a1J4BZm/rnPlR8/HMtM/zYcR7FzYxXDFFgsDskVDVsE52RVUlRFVRkzW9Xh0rI3IheCIidrV7orpci6dKjMgKBEDIXIUEVEStaqKYSkWR5yLZNBEURxCd9+wkqZkqqZAQAcCi+kPMiItU1eXZrZ7EOb2K62wg51wMoeu6kktOKU1pHIdpmuY5iYqZqcmck6jMeRYtPrjYBkSzi6J4nQmlyPE0fHz/8euvv5nGgZliDE3T1HC/bbu+3zRNE2J0zrmnSP28ws/71O+rJhATE8BSYzGrCnJqhlpkmuZhmI7H+XTK40nSbCWrZCsTgBCTKOVcC5iGaOLRTEvWXEzEBKGwMimiCqiaOVACNJRS9HSSlLHot2q9463HCDnJ8TC8e3t8fBhSoabj2DjHwIxE58QwrBEFXJxb7UlYvcGl1gmXVfYVw7P8yaBezG7XbV5swiaqB0W1WlRYC+SLL2GLaK5mK6cyPUynu5PMZbvdXr+83r7sXO9m
k3my6ZTnU57HVHLWxSv91HxfPvqEpzbN84cPR8lFi0Tnow8iuWSNwSP6pe5cV8KTQIbWUi6Y1RsHi0BWdU8MgYJ3fdsQc2iaXFRUyoILy+Np8ISewdS5tYpYD7vkUnJe1HJqmgRJ1bQIICpaVTyxpcL1rNguYjnJcEqHw3w6pPFUQAGAixAalowlg2RQ1WRFVcYMhqAVUlbLcot/VXuxK1E2lmLTrIfH3D/w9U3wFMkBEwekCOw/3/vy4nb7Zre5ZmpFsGjOOmQ9jvkwTIf7w/Hu4fHt+/dvP34cp5kxbDbN7irGPoe2EEOxfAnJhrUqc67zIBKiq3lBD2qgWVIus80mKrBkkwzACAEJHYFjDJ6CR++QaAEDiaIuOiOGwEsypkrHAxoCB3bBhxBDbEJsYmx8jCFE9oG9Z/KEDtDVbiD8QRLrB6udANhMYTGiVPs/0BQBquLxeaipVt9xySrx+vQlr1MBUYtQDzNhvVuKqDX2RgbkJUdvCMveubTsLUBNW6+uAqhVqAeD0ZKlMQE8B3VPNyOnTLyUX5cUL1HXdS9evNhsNre3t7vdrolxnGZTOx2Hw/H4+PDw8PD48PgwTXOaU84ZQb1nx47ZmRkzxRjZefYeiQHBex9jJGawKiKAZqpSbY2KogqWkhEhp5Rzcs41oalXImchIu98CH49aqu353wm42m6/3hoO23b3nOMbYtRCcCRZ/KAntgxc4ihaRrn/TlYNsPjaRynmfjBAERtnPI4ZeYmhj746F1gYiRIWdD56AISZqSHMU1Zg/c+NN2mb7YvFE2t9sFjxV+oSsp5mudxmk6n4+l0JGJY791Z1/RyPDftacj2sW2v03wbfC/cE/rzNqRrsGCIC578KRlTe0CJKVQIumHtDlUwQSCqPT/kKugGkerMUNMiRbUQ1uJH7fLPoJnYmOrbEVgWNWIyJVVjIkJXgdkAddv89Mx+aNfrA2ZG72sfU5rTzDMBgOikWnIukkvJ4zTmknNOALrd9CIdQC1PVKqLiit0aZqH4+nh/uHjhw9pntumIYAmRMfcNu1ms91u97GJzi2A2LXqv5aUny34H7HuXCH9C+9J5XxAVU0lW8rT8TgdDuPhMJ1OZZ6lZDQBWwR+yAWF1iybUVXzKQqwtMIIktbyKyIy1XygYiXyQK1Y+lKU+Z3zfd+/blwPcymHx/Hd29PD/aDg9trevgAiZAbEOjFqBVkUnm/ByMRMLIprIvmiE+kyJ49rHI3ogm82TX/dtTcd9065csIsluXizW1JRwnqrNPjNN6d5seJFG9url+8vu6vW2ppUpvnMp/m6TRN07SA0qsS7mIYLr7boqH3CWO8iEzjWHKWXLBtg+e604qSSAUl4tPb1AcXWbJqDGuihAiXa4R18wrITE6qFN48jeNgKmUcB+/QO2qaSFSx8GSVZeWS/8fUFpg2MOGCZtA1ewufUlKLaCmaZ0mT5Fml1LjSEXkkrO0Vtc5niqWoqACCrekiZ+bMnAqrYIVzOVKkjGDF0lBOhzxPpuJidCGwQ2SE9iYE6bZN54NPJc+S53Ka8mHMD6fp8TAe7x8Od/eH+8fHw/HknN9t+uub7vo2UpzAjQpLwu+TVb5e6Donqq4GIpD3CighjyE4KSJleeLiDhIggCP0DoLH4NF7pKXSg1lABFQJuIIx1iS4AjE672ITY9u0Xde0XYytD9H5QOwAK/tW/aIL4eXzccLFj+fx/yXtv9YkSbI0MfAQEVFiZu4eJEl1VXUPMIsFLvZu3/8xgF3sLgbTpJIEcWJmqirkkL0QNQ9SNT3dDfmivsyozPQwNRWRw37SXQPtlue+zoH6BjIA2K/TffCC3Yu3j9h7XO7fAPfADYQYmJQxdNkMvDH9ejXEhPjF3Mt35ifcJmK9hMde4Lsb+TdtUQMgAIU+737d7O61teB8G7MCEcWUjqcTM0/jdJgPIQS/QQNqq8uyLOuWc25NRHRdNw781u85hO5W3ouxGGOMmMbJAUurgeM8
HYmps41qa0wcmLtYAYo33JGJKg3Aj8djSjGECABmxcxFGwm5GxHFGL7so/1JiCBEGsd4GNMQOYIZASSOTKGJm+3shEAxUJS9IlW1Tn/S/moAqfdHmaUEjTFFTtQV2gJTDBwjM0uTJrVKS02OFOaUhiEik5r00B6ZUggOXltb1gXxpdW24XbTydi32GsP63V9E9pFctZzyZdaLypv0KXjFW4dm94Q8Ff0BnTj406gdMN9M2KXUri5fRJQoJB6d0jUYKdk9Iao9SYP7JH3Nrrf3TKZOBChWjNHswAOBIEoBh6I+BVw8t8qfL9beLtf+29CiDRhYBqHYV3XZb0u12srWUopJWurAaFtq5bZELTVdVlKyW4QQjwcTrVqvi55XdtW1BQRhnG8f7i/f3g4nU7jdBiGkTlQZ2h9AaP+mz4qAIQY52lm5hjjNE1DSoootZbLpVzO9Xqty9JKkVrNDNxvTTMESG5HQAg8AgkHN+9tJWU2DhK1MUkMGoNHBkIHUDNTd1VVsG7ip/Is5WMtjyU9YItSvSxrPp8zchpm1Q6Y7Eba1gNM1yn6HqtJMXBo/c/ZAbqwn6hvqnZ4zUPTnI7vj9P7Ob0ZaCSDve+Er7I2/Ri6u6ILeMO6yuXz9fLp0tZ6fzj+/Icffvi798MxUSAw1FK3y7Yta85ZVHoG4TfLyBvGb+97AuxDl6+fgonGFDfdnVNFJDJjiA7QVIADOsM+ubtlCQi3HjkAwK2TZZ133UflRBgjqzveXOZCoGlMeZVcc9goxhBiTIgOLiKioqqAkGLohpj2+iIQ4626M1NXv+mXfbPlbrQo7ID9GENgnsZhnmdEEIGSPAZAhSFGJ0UoioBACcNI4WjtIHUWmbQFNCJXCoXDmfgKqNW3q1wvsm06z8MQI7liUBpq0M1oWUWWa83tetmer9v5sp6XdV229bqs12Uz18D4cH/48f2buzdxvgMIbgTu0S1dY9y+HOa9WeFfuvTutyNAEAMPMcaYQq0NG8KuTwfY+bvgzB4CxIgpYQqICGZohoGxsTO7AxgigZEbIDBQHNI4DdM897iexolDYo6EnRaFgK/pBfb+wV+F9u/vKtzxXgHAHajDY5Cwd0oRDMAd3NAc3ZEceUcKAFDXm+kzfHcGYKDgASyyG2HE3gYAIPCAGIkS8cDkBgYqruadDRAMqHeREQJicCBw7lw772Omfi1/6Zfp11AUd+hzzMBdVBSZeTocQgjjMIK7iOqyAUATCTEyMyAQYYxxno8xSpMKRKf7+/c/vA8cVVREqVUmCiHdne7UzM7ChDFEc2itlZxryykOh2k2M5XmJmaa81ZyYYY0pA62J6JeSpZSRIRDCCFyB9KH8PWD3N89/PHnvz/d3R8OJ3TQpmXZainZmluptTURB+cYlmXBwK3TokwNzEA7wCsN4zDOHOLI5IqipbUM7l1pcDoexvGAsZtRq7lqUzMJmcitFXbwUouqAEGM4TDPMUYgNNFWaqtN1Xo9tV80fxXX4a8Q8rXaUuul1atJdmvQB7bwFRfE9JvNuveIdC/F9n/nVX6tM+eYKCCy+66ehK+wKrt1A+AmU4ZIyB0f0EctvcAWA9JkngCIel4MvPOO/qOLmPCGvSGmDu1RkXVZtLa6rprzaRhIFVxL3p6eHtflauppmB7evAeg5+fz5emlbDmkMA7j6Xi8f3i4v7+f50OMw86Cwi+Zx3eTgn997VX7K00OXEvO27Y8PW0vz7puWoreONNf/WAEiGAH5BD4yOwRQF3VBEzd1b2ZNcLG2EKQyEqoCNZETMRdX9tbbqvpuZWXOiwM94Jc1VuTBoa51OuSpit4gJj28s12jc/vQvteKhDiK4O0N/16YxehM9n6PiVEDmE6Tacf7+a3czhGiOQ7IeJL7Qs3fI0ZeAMrXq7t8vG8PC/keDwdfvjTuzc/P9AUDQmcpFi+btuy1n5mXnfwlxmBvzbhdxjJd6E98GEaO0cvxRACxRSZ2XRn76rtCQruTwfQ
MQt7vNmJAH1csZ/KvRFKwYypk5EhRB7i1LmTapprHVpLqrx/YEPsyp27+oKKuntHPJhBl2hy096/AUD+3ikCESmEOMQ0pOQJQwjjOEzTCAi5WooagyECB+69fXJgxzvA9wDv0d6ynKDOKBENGQrBivCC8IT0LNA2zOd6Odf7u6nrqyG54Cr6VJqLUmnrdXs5L0+X9XpZli2XXGqtTUSmcTgdDvd3hzdvTsNReNiMGqEzBYaxk1luOwiJ+MYT6N+5I7y+AnYMIYSO+QH0mwot9MqDCZgxMsSAKVAMCIimIIwdSyYE1AcZHAABKRDjOI3TPA9xSmEKmMgSY5+s7wILbrcM7vXg++vg5/a/b6+sPbT3C5HQ0Y0MELqEF4JBh/WB72J7XYIRdwQGke1w2f6dQP8do3f8XWfTgRmqogq1BuTuamYNsCIKARAqIKmTG4o7eYD+zt3A5cbGuHE09qf7Xkqk1tqqgwogInEchjSO0PUZa/3CQANExJ5TzvMcQ2qiOdfn87MZcIwxjW6gpn0GYuZuRBDdO1fBwF3F8pavy2Ury2E+pjS4mTRRrSYtr1vOJUZmZvPbsIYohMAirdXa6jiMNAycAvM3r2McxuPxNI1TCNxKK6VsW65bAXNXa7WqqYNTDOaATFVF3RQM0JydAyMiMsU0xDAQBFWX1vmZimTIBNSQGuwFsJCLqqhCA8itUW/NtiIqThBTcKnDOIYYWy1Sq7bmpu49F4TXMdl3iK3veO2iuklbWr2KrKbFbDQ3cCPT3qdUAACkG0bWoKts2usLNtVqdqNGUiBCZHfcX1Rv7hDt1YvudOQu/ISOjEQhmGqtUlvXGVTVJuqINYbBE3b0SBeIARRARQeEf7dHez8wToBAaRw5cIwhhdhKXl5eXp5enmuxNT/9+pupbNv18fPHy+UsauN4ePfDTyGOa67rlnMpD+Ob+/u7hzdv7u7up2nmGF+/o3/vp/ry8frUmZhDBLO6bevlcn05b+dzXRdUBeszT/qSt+2ZDoOPBENn9kCkPvhwk/2smoAJeCNqQBVpJVgYNgMl63IXqmjgAbxJ22rNcX6rwyTDpDygiG85f/pE6sN9ocMJhkmRVBWQiILbN6F9H+L5q0y8v8LjvY8IcX9ed+dAKcX54XD/8/34MNHAzr0fjMgA9BXM3t0NXMkE2yb5ZXv5+LKdt2mc3/z49u3fP8w/HjS6OLmTNcvXray5tmb77AB2zP6+D/bf7nPy3q78KrrHwKfjYR6HdndA3zVhmFlERNTMRXrjw29zRnRG845kANiHetjrJDcnwtBJV4hqGAMBqJoPKaUhBUImULWmspXCgYdhiDHEuAtAdVVm9y4/BrSnzWYiNWdz6xR6JgrCX0M7iCiGME9jO1ieBZWYeByHcZwAfEgSY0tJTRRQDdXc2WAy+wO1/ym0P8fyB84HlAQK6OKwmi9mlfSK/M/OHxrWl/z0md++Ge5PIYbgDms+L9dLLZ+2TZe8LNtyXZctb1st0kzVCSkFPszp4e50Os5pDOLXZfvsXCjgnEJK8/dDa0L016rdce8T2v6UsAuEEHUcmuOty0SETBgYAnfmG+73uzuTM1IXUrddMgADc0yUpjBNh2E4QIm2cV28cRvvI9+R8619oo69SCCCLkR6O8qwQ0QAvw/tvksYoQO5kSGCoxP2shb2JHHHNeydCScEAiJgJCbbf6aj9Wu+tz/7nnYEADHcCrxcrYkFNGuFPM+pTkliAO4kR7Vs1TAiDRyGGEd0BRAEAxDYpwCAu3Cyf52muNuWt7pt1+cndw/jeP/m7Zsf3rcmy7LUUvs97mYUOrGT5nk6zkd02kp9OV/Py7VUWbd2WYoUraXUWkqWsomQJGrmWlaJg9dScm2Xy8vL5WUrVwA4HO4QvJmaqpsgUxpSDCGERMgOSEghBDoeY4rX66W1huiEGEKgb5VF1LVKqZdiL7at27ZkLQoKCUMAAncCoBA4xBQiBqYQFMwJgA2ChxSHYZqm
42G6C5QIuOtlujfzaiAOhkyEjczBmJoFUWwNmrtoA4a9UjBAp4DuXFBNSohDbdJJbjfjiv0yvKkrfBM+vvVrNzEtKll7XPcGrjdUJIG74V4x3Cgce4nTb8XbHLnjh/ZqrIta7N1H6wjnnuL2qy8ovGop9Mu6C72s23ZZ19W8i4RjB6SLtP7gu3qTm4MCNrLv1eO+Cy3wWifdzlMXYoAu+LMj7IgQQXVIAwHWLV+en63Ul0+p1Zzz9Xx5WrelNQ1xfHk5p2FSIHU0gOPdkYhiTGlIIcYbXG4fDf8Vfv/ftJgpxsiBCaCWXNb1+vyyni8tZ22VvqLUwJ77Yw+B4AQQAQghEEUMTOBIu4R7Fz5zUzdBEqSGtAAMBC+hz8HRRLUz5VRcaicVZJPugMBsirm2x+etim45PGR+eGcxiQEgEdlttnnbV68tePzSeQfsU8UvMjXg7o4cwnicDm8Oh/fHdBogwKuMJlIf+PT6qAN6EIysermU5WlZn65e2sO79z/89Ob0bkoHuoKImilLsXLNZc0qzc2QaMez7w2E15H7Lbr3AvvbXdTFDKPF3mjaK0VCYkJ07UqMIhw4Ru6Q7f7vmPW4ssskuJuKaNcFRHWHXOtaSm1NVR2RYwIOHAf1evupEmNACJ3+0iWtzFwVXuWjwLvQs6r2d00IASl8l14yM8c4DtAmm8Zs1Zk4pphidPAQOETiCGZae7AyS+pHgR/A/z7a/xDqH8c2siL4ueFThnXFpcHDQd8Fm9FO7n9ZsTyH7Vy2u3GYk4mfr+vTU15XW3Nb87blnEvJtVVpbkgYpjSkNBzG8XSY5znFCM1qLotgRgaAhJTU5evX0cHS+1XT4x/43gLCXfiXbv3q27xnx58xYkAMHQTUJ9LfGqQQQjd3oYAh0TDHYR6GMEeYyxXaRc0MgnHQEJUSAKFXcTWM0ftc5gavumkwfPm/vj3o7mAOvLc9EbALvMNNv2O/SnB/gA6O6RJ41HEV1lRFVAVNmbERptKq5AWkAXhTvGb9l4+X6b9+jIEJDFudQvvpHt4e4ZA8MDeQTfhSbRNsFmKaj/PpMIY5emDBTqzfP4Phjsn/0o8097xt15fnD7/+ambT8YjM8+kkTTrZuNORmrQ+FhEVDiFxQsN1KyWXbS1bqb/+8iJt0uYqIi1rE20QWSULgKylDjOlIedalmWttXQ1v9oqd6QnqINSF6wJgXk3BVBVwK411EtkcQt9cPUdWlZNq5SOtSq5ShOmEDkmTgkDdjR1oDCkYZ4wcHMzMCfw4BQ9pJjSNA7HaZgZAjr16wEwOARH7VQLU0AHckdAx6DkhhocCMxNyW/uGExAhO6g5mQAEDiEEJj4lo74fyu4fBvaXU13QWYwQdcdQnKLugDg3msF3mPxvt3QzAH2T9JRQq7uqg5GsQcdc7f+nEQ9b2aCfgU5Qsf6t1rX6/K8Lud1u5SyAmJKY4pJxFSlSRNpIUhXNjU3c3GrQUP4W4/39fou5b81yfCVc9J5Tf2g98jRaju/vJzdtvVS2wqkYq1K22pZ8sZpTMOBOALxYTkty1pr7bSN2/q6T/7vXiHGaZ4RwLVtLy/np8e8rC0XcCAicqeuRgEAu+Qf7DQa9Nd/2OtFQCTsWrBm2BNvAmKkAdEBJ4AZiRktRnSHZfNSvQnnSkMFyXWrj76qlcoKJODWpF1sy/V6SSUnCnA4CUdAQ/Um36gZm0tnH7zumBtgB75Gp3ZIbhzS8e3x+O44vZ1pitrD+Bc6e3+VAABuCEqoZE3W5/X6+VIu64Tx5/cPP/94P49I3rw0g2DOrUi+5LaVPnRAu6nc7DkpfOkcvCJ/vz007q577303MOhNJULiXsGT11JUBNGNsVuEUIfHu7j118DE5GKtVVEFoGZQmq2lbLmU1tT0kOWuOrir9L5Mj027N16XRO5NgZ6xmpmKGio4mvVWvLtqFTE1QnD9mpGIIQSK0RPUJCnGlpSQYwwc2NwoAAWjqFpbrhXVo8NY7b7Z
G7C3o75hvU8a2IvB4wX/98/8/33il8r/zz/I/2PWB7afyELFv1xbecrn43gIszV8eS4fP5xzKaXVprVKa6Ki2tQIOXJIMc7DNI/TPI3DQCEKVlHV3Iq4qHBt2uRP355owFfhlF7Z3maHgIaggPrVhoNbFwZ3rRPshkR0K3l2gB0BMkKXA8WAcQxpCmke0zCxDLglvZby3ICNR2zXRp7jYeCBoTRURQBkcrpJ37x+2i+/+QbYaNDHZNiVwXbCAXjnloF9Kd0J1dDRtdsSIiIRulotbV2367KVLK14hJAoAHrZFq2bu1XFx2v+X//LL3/5vDgQmg8m72f+n/90/PP74X50Zsjt+pzl41I+X8vz0tIwv3/z9u9+ePPnn+7uj3EcAZCwqxCCAyqQdbL7fjrMSs7Xy+Xx8yczO4pMh+P2sPY+bt+kpeZlXbecm7ZxHKdpbtZabh9++/1ffvnt46entVjJv/zToUUe0UHaBu6Bp3nQ5SBIpWg+3GFIqUrNWyHmw3yfUiplC8w3NpFhIIKOWONOZQIonrdScyl527I7BA6qkvP2BV4DAAAiWmqvISUN8TDP4zCPcRp4CM4gDuZAGFIYDhMFri7ian1kFYBDYB5SSIEC+U0207wrWXEgpmRqTt5ZC4wBiASasfWEEs0RbOcqBHJmYwIOEBKbIwVzr61BT0f30v31avwqfHz72x7ktAPoCB1czbW7RnRaT88hbpfdLjLmnT/dWtmqdwKvuourqLpN83w8HmJKHHsNQTvQBPZeiLuq5SZbLdctX5b1ed0uOa9qQsQpTeN4EIVam1u3CRDCgIQ3iyYl+x4GfDv8fzuour+mwXvxBR3d7db9a5EwjUOaxlq2bb2ez0+1rcMcODEOBOJbK1QcQhxCiCl2geKuxGRmTOGrqv0/2JNn4hBCyzkv1+VyXi8Xqc3VesERQXlHr4IC286S6KESd0KimKHcbrUdEdNl22C3oiNncmekACAIEgIhQhNVEw6BMKCRZcn5Ua/FlpW6jrCCSJVWsWYfB5ZKZo6mtUreyuXy9YOIiZrYF0m6/mr2Hs8t+wTs6qGH6e793fzmwHPEiAq2g+fxy9CyAzbA0Y1cQDYpT4s+XcZc7yO+bfXN9Xz/Cw5LSO4rDise8ONLfryWpfirSaWhK+yWmN/Itb1iS76pr8ys1rYbNhARYam1tZZiopQ6n4oQAHqLC17R0bs1TpfbAbzpeLqI1tbOa328bOc1Lzk3VQc4zOV0rAEhoE+R5iEgECOhu+92ZBhuRrF0G7J1GRVTE+ka3tXUIJmn+N1wl4i6DWQMnFKQFJA4hC5ZDcwQEw0TVyFvSM2Gaodqp6anpAfQETygd33KrcHnFR8zL0qGNCe7T3DP/uzWil5f2uVQ7wd1heXSLue8rluzAmTqnUVoqo4cuo9ppJBCTIlDNIrCLjF6NRSBpgUrqH3TDcLXzQ5wU1dz2AfrDmYAu7POju3p4/iuZOZ9ELRbLuydRMOdrEsUImEIYYxxiGmKMQ6BByoMG0wFWCEwJ2KqwCasRKObgiL6cGtNfYMD6il4/xv6UgEDwG0c1F9On1P1PgQ6ATBab/Y4AX5VT5EjqXnZ2vXl8vzy8vR8XddWiwUIA3EMuElZ1lXMqlK2po8vny5ZncB8Mn+aI6hczsNxMELbqr6U9nkrj9fyfK0hprcPL88vLyW/++OPdz++PQ5pQCR0N9hlUr7OXNw957XVAmYMEHqjVaQXHObWpJVS8raVWtQtDYOZvVzPz48vf/nLL7/99uFlKVU458swpmkwQq5lRYdxCGVWFaEgVWtREWhquix5nNM0BxXU1hAbESA1JIvMRAGB3bpEv7ZWas3Lclm3VUQ7Fx2RAZubfx3amUIKQ6CA4CmlcZjGaR7SFDCykTdzNQdnppAiMiJSANObkD9SYAxk/KV4AUe37gLkRhAgfMnXkZAAyTmgO3XyOwLtkGGiGIxZkYwYKAAqme979atZCHxdNd3Wd6EdX2VH
GJ0IwNVUDCPu0RB2PdsOCLIdQOkK2mxb1+VyWa7XZV2kqQuUVsXau3dvfv7ppzfv3t09PEBguJFRvQvRgKnlXJ5zPue85LzksjQtjooESDHG6TDfu4PbGbukSlNCQ0TruO4OuPt3L/+qT7dD+qq0teaqApEOD3dO8Pj4qW6XRUpum/kwhWkYp+jkawMPYUzT8XB393D3cB9jBHNpYqK9u/Z/Yc6+f0J33dbry+PjdrlKEwBk5uCaXCbTwZXcDCAjNWTBsPtd9JvCzBu4GNa2/zDfFSc6E9zRPbClDu5N7ncA4I6IGkKZ0GMKw5ASRShaPj23l4WbkSEhKYE1UTNAY7Ne+aCZXC/rh9/zF8t5cHAxEWs71vJ10O77G7ilIYDIMcbD3Xz/4930MGLq0GSDXavOb0MUvJVABAZSVZaqT9f4fP2xyVvz42+/HXh5e45v7gOkeE7HX/n+139+KR+e6pJ7G2NPb2z/MbfO4t4O2Gt3++b9mVkuVUVEpcsIdh9lAhxi7CNexh0h3H9Kx4ECAPiuCIK9WcLGHBClNnl8uf6fv3z++LK8bNnAmcM8LfN0GchHxp/f3g3vH4g4UAAHbWJuSGghcGdBEQViJBdVN1VptZacc6vVzHYXmW8RTz2/6EJuQ2IbAyCFgF0nnAOOU0AfzaEKcslzkWPVk9nsNqCje5fmdEc3RIe7AY6B7u94OOIYfED6M5CY//8ucv1Uz0MGxJKlbracm5imCQ1RBFXROs4WAwN3Jz9moNAo5kh1RISQQmMAAGu3Ofp+Zvec/JVgjbeMj9zBejX8Orj8sukc/OZF5YZuBNZJGuSKXdO6D2vDMA7znGKIzORMlXF1XORoPg5hHsM4RGkim7mIbZAj5SFUSELcN/qXG8D7MGA3Kv5mIQDthRfu9fjrfI3RGZ0RHZD2bBzcAYlYDaXJy/P24ddPnz4/Pj69XFctFQlC5DAkaC6Pl7WqNQdx81qLmjqBQVWwJib2y+8UWc21aMtqm3oWz00R6+eX7XK5LOfnvP4c4Y9v7tI0MJB305xvgSjg7nnL4H53PMUQjnd3x2nsVMwuBVZLKTU3qTGGOaU0jGb+4cPHf/6v//zhw6enl2tVNBwVxUUpGAGspYG5QqXQ0mDkWqRd6/nx0gAQIR4FOUqMiEBNRKQOI04z88jI0RVU3dRFJOflupyfnx/XdQHgcZxCiGaA8Oo8vq9pmB4ObwmRGWOIMQ1pHEJM6OgKFtQ7G8hMRcAQI/E+FwFQRO1wRjdUYuRuzNAJAloVwJhS4MiMvdoyMUMxUEBGcuLI5IjuhoQURgrRkbpdo2kRyU1URIi5b3k06wpD+K+Hdug9WxOzZloNC0IkHsIOdtrt+Xa5WjMwNRepuS1rfny8fPz4+Pnz49PjlltpkKVVsB9+/HFdtj81JQ7z4TCMXWu868M0kbW0y7I+rduLtFpblV7h7aoIQByG8aAmKkXE3ZtIcYAe2t3FQf3fHdu/yIfvTWtzUc05L8uaS1H3NA2zH56vT4qm7BbBB8Y50jSYojdDY4pxGKfj6W4Yhlrrui7rck0pxpToSxL1dYb+71it1vVyXS7X9XpttfY7gcGjyWBtAhvBsPuMuzuQcxKKQJ2E83qr30YD/ZO494GTeXckUzOD2AlGDDArNACNwWPYOIwRGIvYuul1sXUlDA5kAGoiUi1QGA88zSEN4N6Wa3t6rJ8+SG3A6fVBpN8ioF+uZv/mL33FFObjfHxzPL49xUNy6kWUE75mLLdXt2OKCMwht3hdf1iv2NYxyBHttL4MH9ejwdsXimOaaBY73P++TZ8fw7pB90dV6KaUXwTwbtnGK/bjG+2znmdJN6uoRMTMpooOptpqRbMuvR1CCDGGEIkDAGo3g+8a0mpkISABAsXIqlRUzK9b+fh8+fhy5RiPx6OzIWn1VkDfng4cYoyp76jOPgF1Ee1zrR3dQAS7pfYOZRWVHW0D3+8+76x4N0KIkTUFNXPQVouBgVuKcThF
YlYy1na3+nvSP7L9EOxIzu4quxP63QB/fgNviGlMf/iJ5nsMbiD6Rn0R/+0qV8wvadFAdVOrUBdvPcFjUAMDJEKGwB7Zud+eUr0VpdQcIVAYEhKbmbrZX1sfI9xEDOGWnCEikiNRdxza1Z5ulIwbJqFLNSmBSDe5JAAUQ3cKzIlSiNOYxpHT6JwqeHHdxC6GV50bHIknoKhQq5cqVkyD0RDRmSYqA7fg1keptwtg31QIf9U73TOSvg1vOCToQHd3FAczq63mcs21NFHkMaUjYlDRzx8uv/zl5ePj8+PLy2XVtQB4CBhjBAP58JJLD0ZgYmJg5oSGVfGqJmKfO7wIVEwV0ICdCIiYMSu+rPrr43oYX+7SRBrSu0QRHXfm8jfiFYjjkNgUS6EOeAUIPUKC15qRoEPih3FMw5BzeXp++fDhwy+//Xa+LFsRowDBHLLb1cXRKesFzLUpVOdKTDW3a6nn2lbCmMKhClSzcRxiHGqpteaHN9PheEhpGtJQc22t5ZzN27Yty7Js21ZKCSGpSN42d2RmQv76HkphOIzH1ylbCIE5dstB6zK9cLO6dadAiRJG3vkLvquKUAdfArF3S/imItrdIQg9AqauhMEUA2CQLhaIhMiRAxPupBskczInMRetueYt51qbmRHTa7vnu+ykr+9CO7mTmork2nJtW49eMR4Qd2vjHsxJoDPFwMQ8t/xSnj7n338rf/ll+/W3l98/PC35pfriWEJ8ebpuSzPjkMZ3hkTBIxOhu4iW0i7r9nRdnrd8AUDfqSwICF2nHgBjGiafXKdcaqu1NRNtAN1b2RHBQODfvXZ0Yf9iTFVq29bterlu69Zai4FiCojuoJQoDEO8m/k0eQhSrHaJL+KQ0jhOAHi9XmPi+TSN8zAdZvYAe9r+Hyze87Y9ff68Xa+1FDcjJjIhlahl0DYSDoDuhiZRmzhqMApuOCD1M8UcEwemV4t6AHAzUZEqNbeaRaq2pq1Z6DNHcjtApGEIMa6gASvauvriVBq7A7qCmWhrtdSN0mF8eBjfvE3j3GprT5/b4yd7eXIgOH4J7eZiJruzSw81e28d4Suc7TCmu7en47vT9DDzGHVXjAciJ/Z9DLknwehG6EQqtJbjdfmhXu9pG2YF9OpbrHlcdXYYVuZGdeEfH/XdU/0lC3Zz812QGwBupHrrUdG9C8e8Rvqv9gv2fdIaExN01QV2tZKLceuy7SnFYZxiSuagal0FyUSbSG2KgUeHNIQYOfowKqZxwMBV9OW6zofjfRym6TDPo+UF2sYchnEapikOA7iqNDAzVxFBRA6hy8vf/Bp2c0HkPjK+Edu/PfUqvTfoCJBC0KBaVVpVK+6g7inGwzAOhwBTQdtOF/gZ7f8W9O9GPzFEAxcwBIr48x2c7lHnSIf57nSYh0i52LaNku+a3Lf2JHhG2iK3zVBIN6wN3AATGAMwMXGAyB7IAxlJgfVigCbmkMg5MSmyQVAHo7+i8cHeFvmSwSDtDH+G1Am3HXjWX7TvzSs0AEUUxIJoQJEIkRQQIUwhTTwewzRDnArOonOzetX1oi2DNZwxHJm4GlS11qyJAiDK0ACMeUQeeAWtwXox3rdur9k7r/67a8hh9/Yi8Bs1ntDJHdW8NVlLPV8unx8/Xs+XvFbmcZrvKYwK9Onx5Zffrp9f1qdle1nbdTPXgBaJDEivrVZBA0fqyJTesGIAamaqPU82R3BgxEAYI4dh4mEIYwohYDH//Fz+ET8mDHeHeeCAwRxcHb+OJsz09u2b7cy+LLWUvKzzqRDhMA5DSiIt5xxCIKJpnonDP/7jP//26+8fPnx6ejlXMSXyQBDdwibotV3BQTwjQrPVaoZcmbS2py1f8poJhzGGterLcj0cj/PhJLW1Wg6n4zicpvEwxihFtlrNhDM0ybVWAAghDGkIIZRaVX0YxhBCoC/sqsAhxYGcCJCdyRgamPpuOi8qrbkIODBxxDiFKaQBiMxBb4I2TBwCA5i7qkjL1czQ
obdhHNmJOaVhGIZ5jsPgyApYm5sjcyBEdatV1pxrUwOrra3bWkpustVW4UvS6F+vr/fVt7x289a8tVbatuVLWs8hYYQYYmEeKZADSKsKBhqdidBVi+aLfPpd/uWfwq+/PHz4PT5+nq9PT1meGn7C4dHn/Hz+FX8bx2kcRwJKaRidOHjTLZfLdfm0bk8iGQC6Sh2SIUlviZoJIXKgcZwivw3r9XK+1FZUam8i0G5drf+hAPrN5Y23mFNrWy5XtVbLuq5XAE1jDJFoDMKWW85ru6wraSQYAo+Bh3meUwyq9nrZ/F9fUtt2vbZS3TrVxYNrtJZMkxsbQUcpqFirfaKASGGcwziHNKRh7CGh06+p83scVERbLXkt65LXS96uqlWtqZojECTgyBCiz1rENrTNsDYGRA7qINLKlpuqBU6n4/z+h3Q4Wsnt5VyfPvn1zCrM3xARrYNI964pAtwIFv2+dUAEjjydpocf749vj/GQMFIv2ZG8a4f1NGxvsxi4AKpT0cNlfbtc/jNsP40tza7kK5ZThDuQJB4EhwJ3V3i/wd+J/0Xxnw2zgfnrOL2bq9xIyXbr0v8VLIWIUoymqizUYdVE2D1xRdCp0zbYOYFT944GBFUzq01KrVtt4gBrTinOYwDArehS5JrbVlUMAUMIcUhpHhIHCx4e7o7HwzQMkUMXBWMHR0VA3GdZt1PNISBTMO+6vxRCq5VjhO8EQcFrq5iLNTNVImBCFc2liAozj+M4j+n+fhxA2gSyeBhgVLgnP6BHQ1QwB0Vg8inhNBG8OdHDzzw9EEXTj+XSnq7+cZEVqCVcddvG4IpaABVBUIuDA0QiJ4KAwI4o4KuJa8nrmp41TCUMwoPyABSpa7FL++5kfd38+VIS4FeCV8xd2ky6NEuHAsttvIO7FJwbOSEQ0kDxaOOdDneVRtOkbVAfKtRVaNFaTc2HHawAxM5gkb05CwCbBNFUPGQUA0mgwb1T8W/ySP7XoqBfx/Tdfg3dXJu3KltryyYvS/78fP748Xx5PufrisjDcIY4CqWXZfv8XF9Wu2S4ZFuyuDoaICqiKiqg844j6Ds8kCP1mgidyYkVCV45tEweyVMIw3AYY4iMYvJyLs/n6/lyPoUpckBEcvr6QQjpeDjoupacry9ndU/z9INIijENA3EXhYUuqbmuebkuHz58fD6fc2uKBIFwjDhECOTkauIGENwdDKXSdlUkldaWKlUcCBxNpEp2a4jV2MRcW6tmAq6m1ErJ63LtVnK9EzFNB0JIaUDk1tqu0G7m5K/UsRDjOBFox7gxERETMCCCubFzoAhBCTD1dn0aiAMgmrmgOjoFjDEMQ0QCd62Va+rehN1DCgMzUSROxIk4ckocExDHZqqAxA6IZuLVShVrqlprbTu471YdfbuH/nuhvXkpmlPJ9bpsZ44vo7FjCmFlGiIxONa6uTQPg8fAEa2s8vSov/yL/Zf/ff746931OeYNuK4TPU3xHyn9Hzz9Ivr4+WkYx2FI4zCdTndMAcDW7emyfLheP+VyRUgpjuCk3Z8Mb+BfVwAKGNM4puOBQ8hlrZLVBbtKh/f2z3d27f+mhbfKHbv1O3OKaUiDq13Ol+fnz8vyIpIdbZoGH0EjZCnX63Z9yeu5YAtts7LJuuQf3r3/6ecfUxqm6ZDS7gH/f3H1dM/dmBAB2C2aDiYDeEQCB+lMwVZbq9q3aBjSkMb7++l0Nx3v5ru7cZ6HceAYePcShW6Ak9dlvZ6X58fr8+N2ft6Ws5dNrAEP5CPbQHKU5arXSorkyCEoktRaa7kuF2MaT+/Gt++O73/kGLenp+3Tx/b0GcqWAif+thuEXwL5Xqvja3T3Po2LYzw8HB5+fji8PdDEHsBg72110P8OeHMAhy4QAxXD1t5clj8vl/8Jy98dJI7oAQQtoB9Ag3u/5maE94x/DvSPhKPRYi5+I/z2fKGHet9nG7fG6Tfnh4mnYbxBqf3m8eOdIOLuoirSoFKIKcbEHANRIwZA
EamlldquuZ7XjZhOx4lDaoYfntdPL9tSDHmIMSUOiWgMeDoc7qbw0w8P96c5pdBn4xRCZNoNlcC7hvwumZliJ4PFIcUhDTmv67qXf98YUkPOxXHDZiDQle9bk3UtOZdxHI7TcZ7H00OMhMtg6yfV2GfTgAoubvuMBAw9sAdgHN7g6e9p/hGc9aUs+cM/fpL/8iSfmC8zNkOdXTBIBURiJFWD1tPoABbcSMRXko38jAuiAjmFFkdJk44HHqaAEZ0J16+utB22sWOAAG8g0lsJ70DMMabIkYmxV5ru7r1GJUTrbEojBgIPBIH4iOGdDw8STs2CNtCGCmhEzVNVaLVKcXQhTyONTBhRiYtP2YjBo6tUo9WzQFGwEXe7wq/30rfDQ3d2B/DQSXndWFNVy9aWa31Z6+M1fz5vn58unz9tl/NWrqtbQ340TpXHorRWy4JZQjURVCTsyGzATtPfvyTrIEFjtICOEX0MPiYbBg+0D1rFzEgdAsOYOI3D4ZSGpEuT82XJj+dnHO00HBgSA9NX3gSIOKTkZtfzy+OnzyI2nU6t1hD47u40z9PxeOp9w+fn8/Pzy/PT86dPn9ZtMwRjghTiPNA0QYiO7EpgvcZ1hGjumy+uVWVzdB4SQTAyJwUQq9Q8oEEEqLXmbduia4N1uVyXMyHGFOZ5mqZxHIcYYwis6uu6ttYAwOwGxAAAgJASGoAQKhEEJuKb6pGD22DghuqBaIwDMalbZ4ORW3AjhBBwmsLhOIbIvfes5rVKLa2W1qrususcDLk5YJe4SIzEnUin7iDQKYbq1loVqQDezc7MGQRvkBnci/e/qkW+05D3ddEhtSqltLXUJcQ5WmltJR45DgjsIi1v4msMPEzBl3P+9Vf9y1/ot7/Mzx/etuXgEtBrCO8IkH1lvmI4N8prefr8/Pz4fHl3RtA44Mvy+2X5UOrVXFMcmJJpnzndVMfczdTRVAniEOMcY6FODnbthC5iZ3L6G4/2ry38qirrEzoi5BDmw+HhzZsffvqp1E29lrY6CqLzRD54pqoiYk2siQsCNBcMdDgd3/3wwx/+7k8//vR3d6c3Q5ppN4eA/3A3HuDG3DV3RHaLrsltcOgeXW6m5lW1NKl11zaOB5gPx+O794c37+bT3XiYh2GIKdDNCGCftrvFKaXDOMzTdDquz3fLy9P15XG9vqCC52bGBgFKIgGCrpXrAKbaxMUjh8Px+NPP85t36N4ul/z5Q315grqR2xc19du6QRX3oN6ViHe4pgOAcwjTYTq+OR7fn4a7AeJ+T3eOHtw8yTt6rmMBQZyzz0v9cVn+mJf32B6S8+AQwRHQnczg5urEBAeGH4K/ZZjRXwDkxqd/PRlwG6H9raT49Y1AYB5SMtMdRITIhBFCtzfKaqItrBsQx2iA1JpIF4IljgGY1cy3nLetiuEm/s8fXn5/ul62ag6EFBmHgFOk4xTvD2kaQtiVgKHzYdwQcNessBu419y1Pypid5rpemXuzsykX0nWOKxrFllQPTgOnNzN1FrRsklX2BtH51QYJCIhxOJpUXtRO6IN7BRAyUvDSpCaT2hzxdEDQjQnbbqt5dez/tcnuCYS4zhhiJBRq6kHcEMT8AZoCA2RsLgLaG86OahDM3Nz4SBxsHGOwxQhIpD/uPnhqzP8laxDt/CGG6Jkp4oDAu3qcrAb8jm4gbqzQiCbxR8ID46TeURiagkgeTTDbMgKKMagDAimpG3N+fOyIfoU8ScOpykoYNN4hWnDRF03qFZ39+TQesMUIexyTfjaRv1msQMiBMTuAIsimrf68rw+Py2fz/nTeft8yU/n7XwueTUtbiaOrVGpVJpFUTYzcgrOA3Gf56aAKXRVJVQDUVgbVkF0JggR4BDwhxO+PcLDyYfoBiiOzTlDLDAqz8AcycgKmpCDu4upmgH0dss3IinmtizXUrIDDON4Sunh4WGepxgCuDPzNI3Y3dNbO7+cz+fz9brUJp1zDIEhMiZ6zfjhptADrraD
/qpKIetIdHcXt+YuBo4Yg5Mh5G25nB8jHXGKpg3cSmu1dVGKSBRSGrp1eL9Sujnv12pnNw50DJAYYyDmgERgaOq7OguqMhLFCG5Sm5ggegx8Og7TEKd5OBymw3GOKSIzIJlhqS3nsqx5W4s0U4PduTqyMynupB1TF/WqVprmUmrbpGWR7CqETgGRAmBs0sG6Nwz+32JhfRvas18vOk3StDXNopt5ca8iK7Vh0JlpALVWSl6vRHo4jvjyvP7Lv/Bf/nL4/HHanu+wHMiYwShOBC/U3qI/pPQSR3BcLuv5+fz89Gi+hkHPy+/X7RHJQkyEzBRM200XEPcgZAqOgqLq7gEggu8iJkTIjCEQUR8k/c3b+F9deEP7AyBhiDgfDxxDiDyfpjgFGuj5+dNWrjyiRQF1ByWGEIkTElMc4v3b+3/4H//Tf/4f/qe//9Pfv33zZp6nGCPRl0L1P7w4cBqTtKbV2DSaDG4DQiBGRBUR8WpatDVpYoaAA+F0Ot29f3969+N4OHSpZgYwU+kyRt0phzAk5ngcpvH4cL89vF2eX+Jv/4Lg9fqi17X1TLVRhIEpdK8/NW1SFSHcneb3Pz786e/H+SDLNT89bp8/6LZxV1r4a+3f12+ib0X6qiHvAI4xpcPd8fjmOL+dwyHu7UncsVoO2E0GALH7TpkBqaVNT9f887L8XVlPKCF4R+8g+s3qFbvkNgUcI76J8MBwQE8I+dY+ALjlvP1w3UY0f+u8mKkhYgxBFUQVwZGoS2OoaS3F1y2X6rg2hTgIMTdVVXPHECOGpEC5aq3X5+flecnPS/nl6frh6dpEhxgJPSAmxhRwCJQCEahKJWToPshEXRP+CwoUEBzMobaGoszEzEgUU0IiMwcE/haIsixbyQupJQo0o/v+rUp1mDxFCEmMikmmFrxNm8JTtd+qzG4H9mDQCD83+NxgmuzB5Me3Zcgr0uJGnpeat8fNf9viZmMY4g+OAezs2kgtATiaggqg9CMOtZnVpgBK7tQAg4iKiIMiWUopDrHbVx/e6WF63VOIyLjL0bxqB+69l1fBLkRgcqbdJ6crYLN6MDugvTF4p/BG7KSaDBCoYVuRr4QXnoOHJHVAH6gFLODlZdv+z8eMBA8zn+aB0YvAufpHoksYjgEmcM1rzi4x2MgG5MAwBgg9qXUC/C6yOzAAI0XEgBTNvdR2vpRPH88fPr58fFo/X/LTtV62VoqpODoDsmKrLlnUTEAjAwSHQJwgGhq6nQY6zTwNHoMV9bXA56ubkDsThIRwl+hPD+M//BD+8BYOEyhCIyxIi08vcjxnPm/mpUApDDoESF2lgQJBIOhe8l+Wqn7+/Gnd1jikH+bD27fvfvrzn968edOn2vuoBFFErsvy+PR8uSylNDVHIuzifQSd0w1gtwTIHAA649NEtVlrbt0DAJHNvalXAsAwIhA45O3l+bmNUYd4R4Qhhly2WgsixRgPh6MZbFvOpbjbLoH67SlXN4Gu8zBHSp0uj+i6e0GZioKJu6G5S7tuq1mLgedxfvv2+O7heH93OBzn6TDFNHCMgMEcS21bKZfLcrku61Lz1sTQAJ3RCQ2gm3Jq1dykVMmt5VLzttaaXRq4MRKHwCkChtK6NQC+XpN/HWu+Ce2t+LZ4raam5s28ujezUhs5hJQOMYCaici2vHhd/ULy+PT8l1/Sx0+0LqAtRE+MIVJ3Rwxt0RYU1QIQB1DYluXp80exOMy+lpfachoGwkQUqXuSdEpf1z1W79cigFZuJXSh6S4u7F0yqrNF/oY6/r++8LvUuV/zSIETDndvHjBA8woB8C8gj7LpUtpWsPYUAwgo0jTNP/zww5/+05//4X/8T3/3pz++ffduHufI/BXX6BXhs/9h/y6OexrG0/39dr2UVqNZUo1uDIC3Kq073Gs3DFFFJHDnEGJKIRJik7qqNTBRkdaVzvbQzt0+GJEBeZhHJGz5Wpezr0tt4gJGHm7OkeYurbZWjShO890PP83v3g/j
aCVvnz/l5yfZMqg6kTu4qeL3hIWvv+y91tqltgCJ0jQc356mzmVP3FWR8SaAvWObe/Hex+EK2Hxey5vr9q5sD1qGoN0bzBD2Rwy3y8LMyZlgIpgIBsJAt8h4w8P713/z38zIEBkYiIBB0YUQgAhjioEDqjYD59S8lU1e8gK4GWB3aA+IzBxCcKDDdDAjN9qay2VrImrCgeYpjkNgAlUtteVIW+SX66ZqPWAP4zikRNTl1jp6ofd6sWtumyszBg4xJbw1fW5f3pex9LbVFTcUGzkMMQYI3AGBjEjmvCnUoqW2qs2q4gb0Qfn/yJQcj4yzgCD8ssD/ueFpgj+YT8/Xh7vfrRRzhO05or95uPuBDy84+Ezj1HzMhBuDD06U2MA8m4mrdTgNSEEzUAAndXJRbdK9nq0GoLALn7WjwvT1+6Au6a/NWlMPgGmXNu8tjZuv+TeYmujw4PhG4a34oQJkLwJHooFhDFjQK2glqS4AMkKN1IYoAcVMmAzRA8EQPAVjtCZ22fyTb0+0vEl4Ymu1barnyOsU3AJBYpqIIlGHAfw1Kba/K0RCBdtq/fjp8ePvnz/++vj50+XxXJ6XesmSm3blQQIFNEVQdAcN5EPwkcLAJE5FqVYz8x+P8cf74TBaZL1UeVx0zXIFd9Du0T4EvJ/iD8f08z2cDuiBCtIVcLRj1LdDS4dq1PLY1qFsg5S7aNFKsMboSO6uX/u1m9myXM397u3b43x49/bdw9u34zgBc9cmMhUAVNUtl1KrqHZp0l1TwtRqBeydm86Zd0QDR3R2AxcDaSACCrvTLJmDuKvD5ng1JENoZSgZRE4dtxJCMNOcMxEPw9CaqFqpLZdCSIhkr+qFt5dRal03szHgkAxQ3QgcwQ2guRXpfdvKbkNg17rmzV2REjMe5uHu7vBwf5znKQ0p9CE6kgGlxGngFGme4nZsuagoyP4A3lRqa1mlet3auuZtq6XWLFLUG1ADdwJ0D9ZCa820EKd+h9368vjdtfVtaG9YMrYGqqYmBtW8qlUzU8OUjmbYTJtK2S768hlqWT8//vbLL/PL01GrMCBziCEMARBMDepaL6VQrYOPcQhIZVsfP/8uHich9QIIhCnQRNR9cPC1WjLzbkbWO4+ElWjZ8tqju6k77+WVE35HQf53rR2Ds0NYEYjSNN6FN38gxQirLE/b8/r58bydIbqSiZoBYOD57vDzn//uz//57//4n/70/t37MU28e+je9DN2rbP/4GcbpvHuzRs08eWcXJNp6Hq/fbqjuw2YqZqqmaKqaRceU5OtmbVytrp5KyKtdYFFs36/EAfmyGkMaY7DHMc4HqZ5niQEVUdQIABgoA410ZJzlQYxTm/evvtP/+N4OMlyXZ8fl4+/1WVBR0RyR+v9ge9UDl9L9m8WAjgShsDjcTy9O433M47BGA13k7fXq+NmhN7ldhAEqflhLW+uy33JB6sBFQgUoTk2R8SQQmIApAbSHIXQE+FAmLrQ2K3B4F/Mj/ZfX0kNfXtaCDDsCmZOBKSAQEgYghObk6ErpeJ1yXUpJdfWmjggEw4hTEOap3EahiEOfIqBU1b4eNk4cIwUU7y/m+dpQIIqct2gS5/mKs/nlQAC0+nueDoexsSRicFv6uLUafrSTE0FQYMhEhHf2vXf2dxAzm216k00xtPcOBIHTimmgTiJ4lIMpbXSrDavZitCNsyFBqU3BG+CO8BfnvF/O+O7Axr4jw8vMhsNnxUJtucp8p9+elfev/8lx4V0PF5khsRlYCQKkgnRN5K8mLkRqxM4kCuBIaAbWc9a+8c3VWgGjg6o8jWvHZGIDFhcNmtbgwGJulgU+h7bb7T2G+mXHWbAnwF+drwrDqt/vkJTmmc+jTiyBwqLQ3EFrUx14u2YZBqMwJrAIdHDGIbo72c8pq4tq8umj+Xy0V0HbsFLbqvoS8A8h+Ap8TSNnAbuEoJ7S/LrjdW1aRiQvGo7b9d/+u23f/7HX58/XM4v+brpUmRtKmaI
QNDlkMAQnYBIx+h3I5wGPCZuSkvFZfPW/IdD+oc3p9MoTPXTUkztIwOBiUsfVDDxFOgw8DHB3USQwgpUGgYdYzjdTcf7EA7c7nDj86M/fY7cBlmDHhkN0BTNMQDshj1uvrUypnT37u7udDqe7oZ5ImYDdIdayrZtAKhqpRSz7tDL+0VpDqKaszZyxE5Y2kmuQIQBncCARFE7Uk06PQrAGR2hGlwQSQllMGnJHYgjO3AIDiDaai3d803VWpNaJYRAhO7fDQ9hy/nlnM3YkSNWBmZzNDAEBc/SijSthcGmGN1kKxujp4GQIA1xHIc+0e9TSdy9u5zJh0iBxnlMcuei3gxEvYpX0a2UJW8Ksoq3XLJdtnotdQUQx4akHXfmilapFm8CEb8oG8E+4vzmSb61h1FoFVoFaS5SRbYmhaWoNXYo9QJA5uBWNZ/by0d6ucjjc1xeZqlHwjEMnpJOE08zow6a32n+h9YuVjKs7iW42HbNzwvH0eMQhpDiEOMxxWPgERGJAqIA6O7O3bWx1R1UqEljlWYm5l2BDsxADYi+a3H929ZXhfttktuVIcjcrJtLhOBMCrDVdl0zBFCw2loTA2MnpMQYSdHUxUC9qxPBNyf3FbjYbxz4N9fuMYTDNFpgB+sYOnbDnX9oTWprVaSZCpiBg+185iz5XGhzK2V90rK4VFNV9dZERMwUwImYQwzjIY2nNJ2Yx3Z9hpKD+cCRfIeRmZs1bVLXnJVpvH+Y372Pw2CtbI+f1sePbb2aCFIgQHcQ0VJqdf7+YfD2RaO7f1EH5RCm43h8mA9vj+k4OANAt4Yy5C7S+OVlubs7uaE3p03TZUuXBUpRU0BzgmZ4Vv4sIw3Ht3f3B9aUn71dDRXAEZERYu8n3rwz4MYK/VoSDP46DwEAALvlVd26QMVV1cDUYCvtuubPz9en87LkupaWS20i2N1gY5hTnMeaYmTiXNrzZfnwfHl82ZrgOB7HaZgOx5AG69Kz4ksRN08BIwGjB0I1ldYOUxpT6OLngZgodKhOl47vX5OGsHttq4pqLQD+hYvYqpWmWgUj1CoppBh4nINCHA9CQ3Om7r1kJBpEYtvQzg2HRoPDH9gmt6cLrisMAM/RPn/Mn4KN8xoSkSvSyDxhG7cmF5cAEOdwfz8dphgDa7Xrc3n5XJ4+1boZmBNDoCArthVq0+YinZx3g0Ds2jQ3Zuzroe3qfKZaa81LAaGAIQwcIjh0h8Put9nRlBSMRsO3Cn90/SNaItAAYXRCejel44ApqDkHgQSGqDPLHGqiwq4ATuDHAX48hch+HACBcoPW3JrqttRal0jGkAUWh2uEhpzyaJWiSuytJNxl6b89GTusX8GuZfv4/PyPv/7+X//l93yWsmpuXkSrqYHxzU7THRXB3ACcwAi9y5yYq6o2FXFT192kFsxNVGUXqwLtVjQKoZhsLawNuJiKPVX7p5f6ccsvrVE8HQ7T3/9wfPjp7jgAhqblrF6DNjLzDg/Ab96GNK0oW21+XXKTLFrMQ0wGkHPZtrKbqLfGTOM4zIdJlqwdJiTeK6v9pbs6OO5WxdChheTAjAQABojYSekhEAKCqYOYg0k01X3yCMjE0zS5aRqGaRpDCNDfAtFN0+37VnatZVmXEIaUBgwOyCpqTdRUTKq2qk2lAVgTQnSlFhgz4urbKmuRWXSISgS7nfputo0YiGIISGyABqSO6iAGTa2J5FofttNluTtfTs8vh4+ff3t5aeu6VckAiujIaIjmpIxGsOvUwK4/ht90pgC+D+0CUrtkhLVaW9tay8ylE31ruyASQgItni96/iwfP/P5+lDyO5B3IYxxqOGwDnd2uJ9Iki0/4UujS8vWJC9YxSplVZM6H/kAYRzScBjSMcUjMxsYUkBiF7xJAuzXrrmqkGozaw5mrmKChtJJu2bfnPn/3nptkDvcRnPugF17z7s9wFrW87ZctmWrpUgrIlsVqyYmuRYzCHFopkXqkpeny1MK0dXmNI9xYGJ0
/Kb463/urYS/wcr+O4uJRuaGrq7ou6xsj0KqKrW2VqU1FekjRDNTrVIudfmk1bVd8/VRyoomvd5tTUopItW0EhGHFMdjnO6H8RRo1stq12s0pzjspCpwUWu1lpK3Vigdhrfvpoe3rlIu5+uHX/Pzk4l0iS0HM0WRlkup+K8q+t/IZoAYYzjezcc3x+nNHOZkuLMVkADJAG4lP3Z+R1cwBKxOq4TzRpel5ZpBR3RGqOZnjX+p9zj8SKc/hJADOm7VMXc8PCOGbsZEXY3sRn/rb+TLlP11d3z9qX2fgyiYoTrmoluWXHWrclnzedmeLtfzsuWqVbTVpqqIHommxDW0bavgkKs8X9YPjy/nrRYHCGmYTvNhTMOEzALUHINjrqK1DQFTwEAQGd1NW7U26TQEpsghxRQDIrm5m4qpgJuBmwQCR0RVKdtWMwF8pSBUoVaTrKRQisyjhxhmSsgxHjyOgGzejSxQLFQfZGW5qkNDFXoEf294zg4NWobrBT58bEeThzs6nTiNY0tpk3DO/vFpffaNf/B3Ez/8dBjewJCCNlvP5dNxc7xeXkorQgOEiSuRZrPeXwK12wQGEG/v59vk67bnTFtuLdeGij3DRSBDNRNTuekzIVmIHu8t/mT4R6x/pAYJlPDNhIHDaUyRwLRCA3JPoIlsJh1J2Iu2AoDuNMfw0ymqAxE0xaVAEwdVKgssviEUphXDyrw5aAxjA5RhUnN3ukWU7zqnvRcPSGJy2dYPz8//9Nvnf/ztEQqDsBqIm4IAKFDvKbEBiIG7EbopqkEVR9C1yCW3pYq6LbVdcyHQwOWay1JKVTYnd3cQA27errU8bTgtuBlUgN8u+f/96+VfnuFx/RDj8f3D3Wx//l/+cPdwOkx43M75sm7c/7zQJcy/ie2mXnJr9fLsL+Z+//DwQ6mH4zGmVGurtXYJZFUNMRwO8+l0LE1rlU7oull7795PgAb9niBzdOr8uBAogGvnznBMMQ2DipRSzLS3M75mgoUQTqfTPI8xpnGYUhq6BVyMEW/qft+t1lrO2zRurU2RydHFcmtFpIhUMREXNTGwakCMnNgDr6gXjS95Pm1pniIjYDByIw/7VRcYYwgxckzIDMSG5EAG5IjWday15ZKv2/Xx8cM//TP/5Vf9/cMqi4q2bl3kzIZoBi5dzaib9Hwd3b+s7/zaXarX7GXTWlpruUkOkh0JEEUWBCIbZLva9UKXc1rOU1kn13viAdOzTf9S5pAOp/H484H/MDwMPP2A/J9ZKNNj9DOreDOt3gYUZ4wxThwS9T52R81pl/I0AERiIu23OXQI8t5+djVAdRFD1N68/9tBBL5J87/9Rx1G23+sllpzzufL+eV6vubrJV8u68vj5fHXT//yuTwXVhyDm5gSQEAAjKFg+3j5SH/h83J+Mz/cpdOb+f7h8HB/ur87noZpikNi5u9aJf5VBQv/aoC3balPH+H6wjWjNOya2F0SWVqrpZUsrZlqVxdT8FrW9eUDxg2ptXLNy7NLY97BbbXUbd1aK6YNkZg5DodhfGrDmzHeYQWsyoAUYm+KNlUx3fKWW9WU4ukU5wMhluen8vRZtwVUaB+YQZ9t5lLXLRf62wa7XzgauhumDdN49+7++O4unQYa2HcIvb9KdrnfsGLeRUDRFHhTvDS4FLuWrekS4QDAhEw8jMfjmz/ju/9h+NOfWR+xPQE9oROBE9DN0esVpvNFutZ7twC+Ugv7PriD39RLW/Vc9bzU87Vcs1xLW4ssRZYsa/EmLgoOjIyMkCIf5zQnJoJlzR/Pz78/nj8+X3MzCHEYd8qLqqqRASIFDgwuqmJG6BSZx8QpEO/ubn1wBNrBug4A3sW/AIipS/eaO4BZx6J8vVRcqrdspFCLitg0hmHgNEqcaRrVzdbFLoteFpOG4zHGozfWTwW8wmL03kEFwECbXzf/y5Npwz9m+7FCvAuXFP7Lmv/3Z/mvH15Kqsc8PfB4uEvHdxyZXP0wJw7BAOIHfH7c
vEBQ160rW7s1N3Sg/Y3sM5K/dU7c3BE8Es9pBDc1cNfa4RaqoJpdM0CNQWKE+UDzG05v3E94OeKakhGiQSBKY4yiel7qterT5kp0GAOBmsvadKlOiJGpOanD1qyIjQGHgNcsVQRVomo2yA5X5jWSILEmhnRLF6kT//CLZ9frYxA4uYOor1t9vuSna3leJCiTQ+/GOSqAurshEaKBNzB3J3Ns6BvlCkSWmy1Fi7iDP24yX8rSNHB93NpLltzQbO9wiupW6+fLOrCo8zxic/twzR+f1+ezXvMWaYnt/PgAj7/G0xGTb6CN3RmQOzYQvvE5R8QUh1rrZblertdlWU73d9c1v3379u7uDhD15q2EADHw8Ti/fXtfmzaxKiq7BrZTVyVAgt35Z5feJfQxhuM4MKI2BXBiTsMwTlNru/NFCHw8zofDnIbU++3DEO/wgASBA3EgZHcYxwGJWmutqYjteKivCjCzjk3KNaCFUH1rVDwpDEaIEQP7rroBBMDYwKuWz/mcnrBLqOlJHqYjAYTdMAVAARAdmyO4kiMCMRBzCIjcx8qR48g0xTAxMvgYOTF++MhPz09b2dSle8KrBPMA8Dpjh79Fu/i+Ie+tWtlsW6wcW621tRJDJgoGILKCAbRclxe9XGhZxry+kfKWYKRAmH7X4f/TBqPhIQ3/y+F0PIyHON0h/gNeT9g+DPAx6uqyoa3uzTBhijwRRUAw7y1FURFpIiLgQEhE7L5rRZuayq5yq4qIIM0RFQG8U1/+Deu1Pd5bjmraaq05n88vT0+Pv/z2yy8ffnlcnl/yedV10fWaL1vZNClxxIYkGAYCRI7cQv398uG8nf/5L/8043ig+f3h3R/e/vzHn/74xz/88c3bt8e7uzh0PTh6ZVl9fUv96x9Z1mv+/Re9vFDeQAR2CXhr2lotreZWq4mYSkcbqmPN1+vzr80+m2cpS96uBJ7SwEQOkHNZl0VaM5Ge9g5pbMMBhkqDRBwZIiECB0NTgKoqKlvJWWW4fxfv78MweGv186f6/BlqDYhOpI5deU1by7kuuUgc/5uP1gFr5hQodDOYHx4Ob09hHjBS57vsivEODgiGTv3AoRmqoovTKnQudC2YawXNEYQAGGMMx+Pppx/+AX78n48//zme/xk//L/Mebe5Brz5de/KjAjdmrK3EW5zfdzF8r4/K+5ghu6oJrVuS72ct6dzflrqeWtFsRpUczEWA3dg5sCQCI9TeHc/zyObSZZ8yddP15fHrYhhBERtgzZp3BppZAAk5hCY1NAgEKTIh2k4TokJmDEwI/WeBokBgCHAbkXa589MhDuTD9Fj5CjfxHYVk2qtGCmULCoemceJgY0nGoa6bXV5lJdnPS8KTtMhTAfHWC8IW8OzwQfje7M78KpwydDErxuYgCuw8edA/6+Pl//1o/x2Pod7/zGz4jzOw/EUdpLqIXEkYMdgTaUuzs11AA6IgNCrry5VCvAlDuL3b6M3fjCGdKIwhbbWtlUXF1UnUVfNbitjDqHFmU93dLyn4RQ0gSZqd2MbAopHwBQ4rNmb6jnXTwtRwCGxmTTQc/aPVwpMp4GdoJg9r3reZIw4RyhNlipuyqhZ7VH8wpYBKYQBcCIA7mBm7l6XN3Ogr9eez6rYmttlKZdNrwUSUETCbrzpCqCG2N1sFL25OAACimPpDlBgVa2qizsiPm5KWI7FIrdLbS9ZS2M36HxSM9uKfr4YWC0apwHV7SXX61qlCXkBWdv1cvkAH/4Rj2+HNBl6gW4khcEguhN2UGt/BqRxnER028rj4/PHT5/m53OpkkttouMwMHfnSQBEDnw4TG/e3K9bzbliqdCkm84HpkBIDIi3QsAJAQL5PKT704GRWq3mhojDOI7zLCIhcAg8TcM4DfM8jeMQI6cYQwjDGFMKHXqyrblUGacpxHi9LrW2zoL7BuQLjm6qNbcNIgTk4lvjFhKF2DVnX3t6pm6iVlspUutWpWWVBu4gGgzZIXSTTvSOfAZV
a62zd5GZ9jo+9gQmADNQjCkd7xPS1H0ixNfLei3XrWUxB4wAQ2dW39Bze9//O0jNt6HdQBpsq10vcji1mksbVuGROAUA09zE26WVp4/15Xm4LlxrVE0ckHBF/qT8z5U2ortI82n4Mz6k+9N0//746Rf89Eskv4vbRrhwfByGcxgiDYjBzaX7YFrXGLfdcbqpmvaz7V1qpbvBWMdHYhewR7/ZLf/Nlvx3Vdc+W+31jJVWty2/PD0+ffr09Pjp8dPHXz7+8vvj78/1cvG1RW1BBUUHRUI0xuYozo5AyDE4Y4XaWj2Xcyg0tPQyPF8+n9eXZbuuP//8h/c//Xi6v5uPhxADBe5PclNM3efNrwymv16at/acPWcQQdvpPepdtLSoNFcxEW1NpKmrRtSi17MWYbemWlUqIkirZtDEWmtSi6r4rUMISolmDMhGzBSQdgEWRAPQJq1UMfUY4t19Ohwtb1ILbAuLIpEjGaAriGptLW/bmrcq1U3/Rkf+Vh67ORhgwGEc5vvj4d3dcD9j6kRZR3SkTnxB7P343TKOwNCa26b0Uuh5vW/5J273we4HjwExBB7H+e6e3/8AD3cRheoqOecsS4a10lVoc6oISt0QDAD29gDitzkX9Pzj+89PBm4ewAemFnka0jgYba17KImh9VPbybEIA9M8xPvD+HCap4FLzczc1It4c1IgdnJDUOtJAwEwM7qpGIMOke6O49u7w/1xOkxp/x52UDGag6iBA3dtUujIpI7V2kHhTBg4JnnV8wMAMAUV1+biJsWtGQMOgSEEZAHDuvnyIs8f2tNLi4nv7gIyxgnbhrX5GbExFMJscHSfDUsDMYxnuCiuq/yO6//2KL9cZGv1dIxdiZAQupMkICDxMIc37wfzg4NvL9YWihJw5ZZpXQTUDbuXob9Se25C/99sqE6TRUKKBFMANG3VpKjkJkVX8HXgnLgOgSbi5BgVqSpl8KOrqF2rG9iUYmm1tC3Xdi2BjXLjhCLgTxt+WDAFUFBCaGYvm35elAgjQzXPiqumC+IZ65XqSi7s48BxSsM0DeMUYuoG23+zUdcvZ3QDM2lamjb1drsc+v4AUASjbiPTCTLgpg6GglKRoCNjTK1LQAHUBpfNWtNAls2rALgHtG4HGwARsDa/ZiXwNQKgF5PoeIo4B42OB5JRl/L88expPOA4hjgNHAMHBqbdKPDLU+A0TG5e76WJqjsgNpHrdRnSoIfDOI6dsLFXbuCRaUxxGgc1R0diSinMYxyGQN0hxV3EahE3C0SHeXp4uEPEknNv7MfIKQZEVNWU4jAMsWtzgbtb/4HDkIYhIpKZuwOzOmKtlSh/AVh9tYgICUstulhDSRgbi5KWAEQUgMlwP+MOqipNSs7buoLICpiXrS55fblub9f3p/uHwynGSN0rnPdArGYKRoE5xjikECNyB82ELtHlrkHrIdC74+F8mH9F0q1cXp6rGIUxxmNKEbhratH+669Sxu/82lEVymbLVba1lVzruKWwhOhE7Cqmtl6e89Pven6K60pNGAEAFSkjXhwfFc7ZzsH+kOnF5zd3D6e7hGGEtgxtfUPXGqclTeM4hDA0iu68jy7J3WU30Lb+jTUz3dM38K74o94dMwmMHN3VCDnFEJz+jSry/io7ZlZKPV8uv/32+2///E9Pnz4+P3389PLp0/Xx7MvCRScwIhwQIzEjOGN1UgcnYKTIiKRqotJqgRVSZlnEs5Mi6BcNUA4MODLtQt//jSLkbywrtdWCZmRdJNO1O/GJiDQ3AVMTkdZqLc3FCblVXLYg+BpgEKHVVqpeNzEzwp14AAaMNDCyDwGHSEOgyEju2ntNZtZqbbUqIo1DOh3jOOrl4tcrlhIBnIM5qqOYqktpddnWXIuaIHxPftsTGHd3BHNXZ+LpOB0eDvO7YzqOwB1sAthBsoivMsk7bq23WRrIKvi84tPloW0/R7kfbBqBGYGZ48TjYTzOHkHXT3L+Xa6XZa2PGZaCa8OrYUGQ
29h2L81fhwqvU3d376Zt38x2AdT7SUpM85Cq4iYQl3LTm8BOpwazbhPRQ+aYwpRSClgrmYEYiJFBcGTAiMC7uR10ZByBmVQZIoyJ70/TD29PD8d5Hgc1baqtad1zs/6HmhoQIKDTjeraH673KAJz+FZ63dWtc24ArbmLs/cGH5hBK54XW85y/tyePrbx4IFUzWJCjigE2aGRN8BskIVO5rNhcygr/nOFjy/td7Pf1nZuGsJuULlPv9zMBRyQICQ4PiRkCJEuj7p89iSRc9jO/vQ5d+ueTtR06pYpe3j/ekPdIBF92yAndCIoqtsmdW1lkxVtiZgnktGRG3kOtpBexI6oB64N5fO1itX7Oaq00rbaNDcnp60hgZDa0waPGYfgQMqoqvK82ufFBNAQG1IDFhoLhSvjiirBYMA4j+PhMM2HcZxiTJ1Z2OcL/p3wQ09n0cD7cVaxbl7rAMZgCHZz+UTvDFBEBCbcXW/UvsgN7hK7CK6Qq0lzQlcEBSD0gR3RCDAABURwr80Xk8rADEA+Ek3MRD4izEQnbraeVwwvxhBOKR1piBgZGb+2tQMAQhyGREwcA8cQhrTlrKql1Ot16ZQNZnaAbrAkIu7eBaBElJHGcTgcxoe7aZ4j4n4gc6mXl6s0TTEeDof7hzsEXFPIWy6lhMDM6I4hcgjdUQH73dXBwhwoBApMgEQE0zTG5KYOAMzc3Y2/ex2dc1ta2SRL0DGOFtCiC6gZsDEZuBqaMSCISW1ly9u6SS0Xle265MuyXZZ8Wc93b94f76dxjCmlECJ3wLepq6FT4BC7cGXkEIkDcoA+e0BXMGrlGOluTCMjtLo+X7YqIbZp4nCYPTkYYlcphr9uBX0f2sEUW4OSvRZtrTUpojXEmSmGkExVylqWM+XVWzV3JVZgdkpuJ2g/cGEKoIH15FoxTnT/I7Zr3D7Gp3/y5bOAhTgsYCvgWT03ZUcnADV3EW0qTURUxFTM5Vbm7XgnRAwxcKXqhg6BYYg8DUPU8O8wiOlINLOuSihNiPjNmzf394e3+e379emX64fft88LlhIEGJBw91ZJARUFzBEgkPXYwOCB0hxP0/Hnu5//4f3f/3j3w9vTwzAMy3IdL9N4mCgwhQB7qf7FdvxfX2YmrmS2m5W4V6mtlVazSgMAQFBttZRcq5ASMgYwQDV0M0LkwACuYluR52ttCkTUmrUqY+D7MRBMIx8HGgMy38Y2Dq4mtZWS19IyTWM4HNgdttWWs28buwMyvIIGTKXVVktt2cHGIUCK31e8e8jcBdsRII3x+OZ4eHuIx4QDGfVcxG8dJnBHB0B0dzQFNHSDurb8vKXP5/r0wrSNUQb2GAAQwMDFfV3947/486Muz/r5X/T8XLb2nP05w6XZY/NNUcx3XT68eWffEJUAcIPiA9j3qC3pdulmCgzMBtbUS5NcWhZvBmqu5qripgGhNWrWXZ+UCc/X618+nF8utTcUiaiPnDjEEENMIQSmrtiljVOcxzRFSuQBPSAQM3blWDIzMAM3QzdTb2qAQOTsXdwFOjOGkMygyVcwOoe9awJEENgYG3kFzaqtVSjZy/m5LBcpm0nxVQR8a9las/519U1SHR2hOVwdD4wjk5u3DBfUFdFDmEMwVyQnRiJHsH6/7JwEJCQcp/jmHY3R5wCxgb5YTBSQydnUgBwMuhTf3+xs4esL62I1qMDNebOwCFxr28oa6xm8ZhAB0wv4FeOKDnjhmIPmBPXxDE1JayTQWoqrkbq4n1fLbKj1JWuuVhXEwcGa6OfVPmUrTorkIUCMSMmIDSP7wFOMx3B6c396uJ+Px2GcYozMjHTL678v2x1IAR1AvPttujqIoSgyogVEohgYmaM7SrPENI6chhAi59qua26iZkZA1PXPELpm577TAYkoMQLtVV4AjIgDU2JP7IGBCYiRGEKAyDAyThyOQxxHjlMMxyHeH9ObEx0GC947zN89CAEMKaYxGthaNtHW5+tNRERVrUsMqVqrsq15
uS6lZDAbQhxTOhzmu9N0fz/Pc3cXQGbatoJmtcphnA/H0+nuzhF5iBwDLmTmrVaDrgfbkcTot7wQwDomXxRiTCHEEDA5qoEDdK+aGCPA11MfMHIlE1FXk600BsRgDlVbNXUDUEe1ADByYEAUC4YDMHEwciffvHxcnrTV5+fH3+J8iMMU02EYpyHFwBw49DKqS2RQt5hjDoFCAO59UwACBTNpydv9mB7m6VMMpUif3enQ+d+vWPMOlf9vw+jA9558q96aiUqTKiZEHOMY49hqaaXWdUklu3YMG4oTu4+m77D+OfgMXtyPMof2gKCYJji95fsf8PortWsAgDYfTCaAq6mKABAQGKhZa9JabSJNtZPcpBd6PYPvKWkMibmAGwIwQWROKYT6V2yrf3X5zsOQVhshHg/H4/zucBhWXR/XJ/51rB/c6ovp2v3JuUu4ASGSdztK6hNadyKMnOJ0N775wx/+9H//T//Lj6f3c5y263q9npd1OaxbGoY0jvj9qP2/s8xdVMmh646YWym15FW1ukovWaS1UksRseQxkMf9KuwlKBOag5qtRZ6uLQshUa1Qi72d4ymNkeaJD4kGAkTz3bMXVKTWkre81lbDwylOI3bixHKl1oATEoNB91fofYvWslhjomlMlsL6N770/RcCUqBhHk7vTvPbA8+MEfbeOL0KFn0FPDR0QzBy9bK05WmVp3M9n30uISqRIyK4m7oV1fNZ2/9h5rY8+vXZrxcrLTc/V/xU9FFpM5Db5/i2ZL952ffjYu5uX0Oy1T3vSacCJaO0Nb3mtmx1zbWoC4Cam1lrzUwRnAnXWpdcLltx95eXy8en88u11GruxIhMxBy4Y3ZjCMyEaGLehDDNY0oB0ZprM2HgwEjOXUIA3MB2tlProR0Jg2M3j34lvQJpqeFr8pupuXrXYCVjFNSsBa162awskl+e6rpIy+YCpVgpxcWtgrkD78Yq5pAdN3BCXyIPgbZsWRwIQuLTcQ5M13Xl3uYPuIv3dDA0GgABcIghxThGGhhha/mxHE58OAzuXhTVVHtHgqGTxr4GoPVgj/ialKlDc8iGq8FVbW2llNXKFU3EDESKiUSJVwBOOQ45Sh68PL6YKGgNTF6alAamVsyfNid00LZW3ZqpwEWwmRW152wvxQuA9HsZQiSiSOAhUopHGu+nuzcPp4eH+XAcxjHEyN0H4stJ+GrtOCHrSmhuAiCGiiB9fzJz6nPjmFTMvIxDfPtwuL87zMfxfL1++PS4rLmWhkCBAkHnalsXcNrZXt0VEDlSDBQThkg8MCSGxB52K1cPQWOwFHwMNHI4DsNpGg534/QwD2/v48MdToOyu78Kxr0+kqu1IY3T6aBgL9eXXHITfWUE7FmYd0SX5C0vy9pKA/dxiMMwnE7z3d18d5rnOXKAEDjGuK5b2Uqren+6PxzvxsOsAGFIHBiItmWruQLh/g13ZiERMRJjz6VUVQRCCEgYKCKyO5p5CIG5AwC+YbYrWmNVVHDXKrogo7uSmBRtffJJbgNSSIkpsCMjUYiRQBGQXMhe6rIuy2ejg8cTp1MY7sb5OM5xiHEI42EaDiNgv9Q8IEQOMTLFAIF6nKZAQKgIWPPdEB/meR6G61KzaE+UdrXFXevDukDH19vq+6rdu8KVkRnoPtqjYZin6S6EOZvmVfLSuCqoO6O5mQkqRPC3ZBDsR9CK9e8az0uEx7kNiO0KQq6sAgDCrVBrqA1Ne2WkBmYqKqXWUkprTa1LJvU2fL9iqeuihJgC862n3oMQg8q/SrfqT/dlWbesNkOAwzQfh+F0muZ5eCkvKxaKYQdAYt8tX/odjkhAfkPhme3EO3OrKo6YpvH4cP8wnc7E27qoaim5P1FXmYR/c3g31dZaAOxzFDVvrZac0aR3LGttueQsVaPTRDQixQ5G6f+Fq0lrum71fK2fXtraEMgD+EA8x+lhvjvEmYFBzWrtGihq0mrblmVdlq217D46mIktZxXh2voAsPOSzK1/
pLytKi0SxhhGDo3+OtPynVzmEJiGeTi8ORx/OI13I0Y07N3XDqADt12SEABsVzgAULRq+Zyvny52zaVKG0wMuiC5G6qZSpH2rFsxVcgL5ky1HkB/DpAZfwdaza4GBfy1+gT4qnDHPjPYxeS/e01V9GlZaym1VMWgED5d6ofn9WnNS5NedlnX4zU1M+wicU2r5iW31uRyWV6u61a6qLx3sSUi6OPkLtTkqq4CKoweAyK6SCs5g0FIiULcT6kB9LkgkH3lFeGI1iklbr3n7ki1fXPmOyms++WaQM26XErOvklepW5i54tvV6gV9mOi5grWpzwO3qlp7mCADB6wRXSGQmABj8fxzcP9D+/exhB/+e0DTHVIKYZAyN1G10FviPeuysMh0nxC/5nY0sDj6XD69Ovy6dP1cl3XuiF5SBQHCgOG8B3T5MY1QQWo5qvqReoiJcticmXNpOIeBNhQBKtVr4v4Y2ujaxCZrV03E/G1iIIvzZ4qf6xcGQMygplQEd8EBcAEq2FWXBusasaMHJzcqSl5V6GMMQ1jnMf5MB/mw2GcppSGXiACOfztbt3XBMyerBihIigRBaZhiKfjOAwJOW5bKSUPA79/e/eHP7x/9/7N08slpvHjx8fn5wsBpv8/Z3/aJDuSY4mC2FSVpJm5+11iyayuJ/PeiMzI/P8/M/Omp7vr1ZYZcRdfzIykqgKYDyDN3W9kVi+MRe7qbiRVFcDBwTkiiA6gDg4YnI/QnkQ0zFI+3n96OD0cxmFMksgSWyJjdkJAcibj7Rcpi4wpHUoeD3k85nIaaBo8sYY94vtbMbOX5ydJfHc6cualLUAkciZK0zCN41RKAYDWWms9uvEIFNObOedhyNM0jKWUnIcyDGNOSYiQMH36pGD08eFTGUdnaqZpLJKzpCL8/OLP5s5M4zhM01hKHoZyOp2mw3Q4HMZhAFB3a7071JIplORT2jDeJImZ395Kh1ZhRUYGBDev5rOhUWJxwArVwBkhMYbibgIEIHZqRs3RwNS6mau7KZlrYj4IdfC5Lk/n59oqB0xRZChpTGnKaczFUwLpThBygcEl66bXVrH1BMDmrtaakXTf24lhAhROh/9Rrz22iRuakimYuhsySSnHYbhT5dZ9vvbl2kp3M2joFb2rGjqhntAmsgZrd7hbbXg2/w0XW5HYa+2ra8VshqmCVrR+G8Vzd1VrTWtta61dm1kHMEDbJAfd3UJjBJkkfC/UtGtbK9EMAO0/Du3+pvzagrupqaHBMAxjTqfTmAe5Pi9G0Exr72rmsJnGhvIzGqA7YZC8tq9hZg7ezdZWa28GkHI+no5tXlnE1GqtrTezTaD8f7xsDysFJA7w28x6671WBkO31uu6rkuvFTsPREeSkTltPPxNCUmttb6u/bq089yuFZH8bkjHoTxMx4/T/SGP7Oi9q2pg3137sq7z+Xy5XK61VSZSTeva2yK9E5JTAlMDUMPW2jLPy/WyLrNZT0IDSyH296IcO31xSys5y3g3HT4ep4+HfCrAsCmmkiMFPrEpLTlEPopg6B36YuvTsjxepWojXpxmxaEBAzhjb7DQukLrcEYz6S2rJbUJ4CeBrwzN/eJwdagbLrovCYzpJNhl8zaQ6/0UNXTV5+u8LMt8XZpjNf56rl9e1udlXTTmQ3FL+DaGFzmiumkNu+r6cr5el3VtXTcw02lrJPvGknZXU++Kvbvb5gAG6gZdYUDMRFumDgCAoYuBtItiRMESPYPIMhwcXfXdaKh21+4RAbrasjYiB7TLusytreqXKyxX723jpQTz4FYYmIH3QHmBGTCjM3QAJcdE0zF/+DT9/MuHLGVe18aXcRiyZAR069tOxp1FGh+cIA/An9KU02mkj3f+L9MzmUBzXZq4F+Fh4nzknP8ohRR3bw7NdTWbvVevZjPZlb0LCcHYsagoWaWWwFefXR5NpNLUYF29NW3aZ/WXDs+Gz+4qmJAcwDo19WbeARtgM1jVV8PuyEwpkSUHbi4K
WXLOOZdhKKVMpYylDCkVkcThaAVwo5L+ePnWqopbInSKLgxBFpzG9HB/LENpht06kEnGu7vp88f7n376KCm/XOq69LZqYhxLQuzmdWNTgBOQkAgydBjz+POnn/70658+3t+dDlmoJdJEyhQL2AFD1FEZUYhzkpKSFJJClBETGhKEccsPh5Xq08vj4XQYhpzHvLY1uPAIPJZxGMacU2t9lwcEQkqSEhOxDDmVknJJKaVQMx2GqZRk5gDy6aPkVD59/MySrm1de0vaUy4pZUJytdYaEk7T9OHDfdi7TdM0TdNxOpRSWo9hotaqIjBzJt7kRRCchZII2ms319CUekrCLNQJHWB1dM9FiJkRjJHJEmNiFkJxQAA2IkBQ6JvAMRiisasDDancHRImbHZ5evr67Wu1joKH4+F0OtxN4904nsZpzAWja0MOoGqta61aqymYxRC7m/XeeRNa3yd6tn3wY7f9h4DoAG7m2r017w3QU6KxyCScr/P1fD4v1+u61KX7i6E69uC+GhB69k6qAziA5bnb97a66jqvMl4Nvn+9zE/664E/jdbcbENmAUOP2KxrU+tbv9K6ewfQ2/kePp7BG+ndeu8ALeye1Ixzm/7D2H5TggMADDtktd5aq5VzEt6cNy/z9enl+Xw5z+vSQYGDk4Lo4GrQ1S3KSPMwTNmaBojm3dqyLJeX83y99sN9+NyYWa01+JzsLv8zorPm1k23gwtRVcGNAJkYwlR+Wbo4TSnfST5KDrGcoCLuBwURJMEp4/2IU8aS5NPx8Ke70693Hx4OpyIJwLQuIS+vbrXrUuv5fH6Z57NDQ4HrVawP1h3cMAETeHPstdt1Xp5eXi7X69pWJkzCgn9cY1v6BqEyR1imdPfT6fTzqdzvs+yBxuPWXo/Bdtzfe8QxXbW+ND13WDGP91KmF/v+1/YS9CMufEX89xUem67mA/hn9gd0BoDu2nFW/Gbw4lC3+YQ3R+zbH78p5n+YP3YzVxWnIWXs0NVdXbuCo7DE4IMjETl50JRjNoUBoHdT96YaJgSvRISbbTiiITdgNNTuXvVlbk/X1lUK45BhAIBuKD3mXZgJkcDB0InAAtASTikTghtbZ+3d3Z0o13cyA5sAK4GiVmhX817B3Oalrc3UoK7Wm4Y7FVIomoLhluyYgRsgAtE2r+PgZkAMhNChXtbL4/PjUKY88PF0Op6OQ8mMCrb7ZxohCqG4YXdF6ARABGXI9CEnTes3vf7bAmUZUisDTkeZPpbhIR3Ka1shoCliZCYAMkBw8i4gE0puqAw+HQWFYKiYG7mgZYLCxrmtNF9ensrj09Nlnq/rMtc+my8Ii/vqAG6CDYKL7mAACtYdOvZOXcVcHDL5QFiYB8lTLocyDFPJhYCJMnSyxUDiXwQOZAY2Q7p3C4s2BEyhq5s7Mw9JEnIRmZLcH6dfPj8g0+/fX3pvjhiTwN8fn18u8+PL9cuXZ1f7/HD4eD9+fBgde+3LZZnPyxJc8kwlYcLkJRUSIebxeLh7OJSkRWwQE3YIhShwdwXrYIYeuSKjYFjDmkaX54ZxvWlXmT0/P909nZaX63Cc7k93rZt2d4OcCjMH4E8EpaTDcWq9x7wWEaUkSUSEEMkN1RCAAcS8ieQPnw6H6Xh/9wCIdL3ysnBvDCiAiegwDPM8r3U9nY73d3fDMJSSx3GcxqmUIsy9t979cll6V+2g3VPOa60IzkxuXbvLG+WHnPgw5YQ5WZJGXMkVrCGAJ8EkCZNIAian0BAzhe4Sc0WG6mSARCQsGXnA9PHDhz/9/OshDdisgX05Pz0/ny/neWzzsV3HczmUfBoPp3GahmkqwzhmJmpNWw9VQQtdRmbi8HvbOtQ3E6zdI+F9n+fHeBi7tHdv1Vt1UGYoiQsBr+t6vb6s87LWflFvSo8Oi2ModwoAQhc3ARV0tNr6Cr23ZXmU42+e/+X7y/cX/X+R5TuLVrlvnCncwPXoAYK59xgXBHAkAkA3NPWAOT168r279zj5
W7fB+x8BiB+uvSKLtpC2cPa9XMhGLXmetdr6+5ff//rbb4/Pz/O6qlg4S4Vrrqt5MzANxAzQXNWaghIbg7v1Pp/PX798+Tjen9K0zLOpmuu6rrVW3Y/0PX1/26b68a3cft3c1GBzwesaCiRJyB26WbWOB8onKSfJ4wYsbX5GwcNFJKSS+TTJT3cGLscy/nQ6/en+4/10fygHRsIQMltX1d60r61f1/Zyvjyv61lEAeg651YnAiVyRjIMUbaltpfz5fHpea2ro3FKDMA/tOBelxUAACJw4nIcTj+fjp+PcsiYyWKFvjpNvLIMYy4iYKe29PV59cUypOFuZIav3/t/O9dOvqpNqbwg/ddL/ddLn6s+sMMIJcFArg0uDb53+N3xCaBv1Ht8dzLhPgjz5ld+8FMiwISUk4x5mJsZ9kSNHAmRiWwTGwQAIETfAjsigIXXucMmkuxgUaJvU76OCEDoyApU1VXdu81rf7lWVc/Cq1kH42wpqXCct0RIamGFuqURQcgTQnAypk5oDoAk9q7YZXZOTuyYtEtdqTc3VZu79RZEWg2YKWrieA9Ebrs0Z5TaKECMSBAIHxIgQ7X6fD2D0WFch7HcPYzH41CGRFjd6jbCGYP4iG7eraM7A6AlBihF5C4/3A3X4ySHeldhOOBhkvE0lPvsml7TIgSkYNsgIKmzaWIsgAnJGRuRlikNdwLDAtIEitCUy5Ep67LU55c589npcebnKzz3eTHvAhbDdO7UO4Arbsi2RZ6JPQxsiIgy8IAySh7LdJrG02Eap5yKVUdDULTVVJoRGwPFZDui/63NAYbm0A3UwB1zSodhyJSmPJyG8vPHuz/99KGqfn16MldOzCIA8Px0frlcX651Wfrdsfzy8eEffr375aeDQV/q8ni+fn2+/P715elpJkxJShEec0GCpde51UULCWYmLlIKJtla9CH2YaredwF0QgO4zX3AhhG92x1uNl+u55eX5++PLDyOw/3pVNdmasISErhELoLDWAAhnKrMDACZg8SOAKTqrWlrhqRqmlM6nU53d/fHw0nVtqrf1Jkx5yxymqbz5fzy8jIOpeRcSi6llFJyzkRsBr1ZXfvlMq/ragpmnnNWM0IXQdWm2jiPt60/lvRwmhizQOKVaIF21VZNK4BB4iTMKSGyWa+xYchcHAVIgBTYwBglSS6cJikPd/efP30+DROaPV9f0tehvtjjcrliv2BPC2eRw+VyGqb76XR/ON3rlBO1unRdDUxdq9m8hmjfdkTtZs7b8tnazO+X1Y/xELcmtpt6AHfWzLpab73ObZ3DTWzu6B274xmwMBTy0Z1DshrN0FVNu7ufxfEJl3/R9H8+n3+f+zTKp2bdsCGpg5oxKDgQAiME7d81nJ8agJMwIIVcicVQu5qGTpqpu7tCa3YvCocfbuU1QOw/D+KBmVmr6/Vyfn78/vX339rx4Fqv6/Xr05f//G//9b/+5Z9+X76vtuJIzORgPXaeOipQGPBAgKRuzUiBQMjcen/8+vU/r/8fPS/15ZqI21qBqFk/3p16rV4yumFMKr37iH+nlEd0ZHWzrqAa0tqMyMROmEYpSfg+y1E4A3Fwzwy3USMHcyRkgpL5w9045kkgjzw9DHcfxvupHEoqiAhmwcSqa1+XeV7qeVkfL/P32pacwCy5NZU1ACgJuReqXS/n+fl8fjm/mNtQUgxt7gDuj4sKAMCdmMtUpofp+NNp+DhRIQ9B1Oiyh+yy460rHwcqOnqHeq7L45KAh/u74/Hgbv/t6+Pv5/QvqP9wlD+lTwvL/1uf/8vl/PJy/RP3hwf8MOJR8Fr9Xxf85wZ/cXgOIyF809t8+7i3Q2vrvv8wUFJy+unwEHIx12Zpbpfm3+c6d/MWjRnb8cZto8VPa+9ra0Eh6e4a2leb0i7grmZCFOIOnVyZkcGttcWtItaqvfWS8DgyITITM27MtPCGiJ5CLO5dA9Q8JO7xHR6PcHjAPgGyi5jk5oIK2qu11auDNWzN
NIwGEEPKh8hNwkDCAWL4djNc3ISeDZAB0GtvevX1atex/vLnO5lSPmkaCXljXKp1ZiO0EFczs+DqWxNdBCt4tSnJzx8Op4UbH0mMk4sKrzwTtzcbHMEQN31x8oS9QDOrqtV6M1PN43D3acKhoFSCQWhM6YAo60LGhtiAHZiiQl+XRcWJUvKCTlHgdK8xaxZGwzEvI0kkSyo5D6VM43CYDsfjdDiM45BS0qaujk6q2i6VOrgpWElTRmRE8x9nKoNb5pFMiaTT4egmA6f7afr54e6Xzx9//uXj03X+l9++DsPAJPen03E6tFr72lD1WNJ/+vnD/+P/+PmXz4eHu+SordenuT48Xmr9t2/fFjMTxD/9/OHD3eFyXb58/bcv339PKR2n/OFu/Pnj6fOH6dP9dJhyklcfcCXfMhFHcArRp5jBu9ktv909xLzMy1///S+O8PHXn7Ok0/Fg5gRYa1tWZaacE4uIpNa0NW2tu7sIM1GMQbWuPi8AnoswEx0wtlM3dXciAve6rLXWTagy2liqy7wQkZshohATkPvae5/n+XI5n18ua12jjTBNEzOzsDAv89xaG9JwC+2n8UBiQILOXAmzz7AuVquadeDOopJAmNwSO7IpCXtGdJKG0Kz2+FwOhCgswizEiZmIprHcHafpUNIqOLBmcsaGNrfLY5t/v7yMj/lQckmMoIFJN9XrWr+9nL8+vcxrNX91egvOyq1k/aGB+C6030qlGM81hR7DaK1pW3tbWlua9tmsKi7Kq2ED/yjwQHZvwK6smtAVgdAZHHAG8KsvX5v8+6X/e+3/qdr37gzUmBVDHBLDMprQKSpqa9qragd0MEB0c3RHAzQDVVXtpta7moIScoM2/UFe5BbRd8aB2Z60qK7rcrm8vDw/Pj9917r0unx9/v5vX/7yz7/9y1+//75Q8wykiIqm3UFRHTchJ4Rw1Qwh3w7gEPINDFCX+dvy2+CUDE/TMUlCJulpvl7qsgzDANmCPAT7BJy/0oH+GOAJiUwd1Lx36A3ciJCYkLFMSaHkhywHCR9hCBkWRFXv6uCQ9snmMqYPY8kwjjBO6XjIhyQDS0ZAIGtqJh0Ae9NlWc6X+em6PDZVs+TeDKrawsjC2ZzdXHmt7XI5X87neZ6ZCYvcyjuzP2aQG3jHwsNxOHw8Tp+P5W6ARHYbB6L9ve2MO9/k9wAMrHp9afWlDpKmj8MwTVb1K41rHf6i/u+U/u96Z1z+fwb/1P1aVQiv1ar4avB9xX9a4Z8a/tXhDBsa768JH2xbZGu5wz5eD6/7JrYK8zQOSITImEwpTdM8lvSyVATdPcKDe7HT7N1a2PisrfV+Q6UcIWznghXBghIScta9rQntMPCUKTPE7LJp741MOwAQkwgTkQOibUftzWnQVA2AeMtko+9n73vtDz+hAGMMnLuau3ZVtx4tc8Nu5m63E8QhxEVxHywAo33Fbugg2JYOeVdvrS6qjvY5JTkN6VhkQiBV095UvZsjgZMhKJkCgxgWXItfBDuCain84dPhwUcvXVtT76rgV1gHentiBcfWnUBdG/Ur1RfSir1ab9jVKclwnGQUTIVwIBiIijt26GSZYCLs4sbdcO4A7twBhT2hg1s3611XR0NCFCZhSsJJ8pBzyWUoZSzDdBgOh8N0GKex5CLCampdtRr2fcplJU09lYSyrbX328Mc3FwNDNCTyN3hIJQL86fj8X//86+/fvp4uD844jCU42HkI/38cP/54931fD0/PSfCsZR/+Pn+//jffvl4X8aCgKreTrWnMvzzv34z815tEHy4v/vzL/d/+e338/Xpy5fLvPahlIfT4dtPD//wywfrH8COx6kwbSsnyBDRL8ewnIwAYlGHvQslhFiGoWv/7be/UpLx/pSHYZqmbXwBofUK4MxkhkQ8juO6duaqqkQUNcaWCldtvcrMw5hZuNZaW+Uq8S3dLIwzmAg2npbWUODQjgAsQkgWoGpr1+t8uVwul8u6VgyoXGQYShJJ
Sdyt9/b2ZRzSMCR0YgekTM5GFWA1sNYNuBN3psabtIwbemd3QYqZD4cQ0w2aKDHH3LQjmCCMWe6Pw+lQxjlbEchkjOZe57WvFbuzQ2bOTDlxEibi1vV8Xc7z8jLPtZmB3OYN8G/FjNfz6u1Pbi2UOG4dvPe+1qWus0iOgcvmuhpcjRbD6ji7z9pn9Cs5u7NBAwjHckZgVemrq2IVaG6Kq/kFMDNrSk7ortoV3N0b+DZI1Vvr0aUBMHNEc9gKe9jGJ7B3aM0RLEmMnfyxvXuL7u4e3NzNZ6b1tizz9XJu6wLWzy9P5/PT83yZ53mcpj/lP7/0y8Xmhk171036ycnAnADADQ03sgBucCznnMeU7/n4odydhkm1Xq8vhMwiueTr+WW+nMex+JDfNHNh10env/mOCIDj24aRKG6UIURngsOUswBPDAVaN+0A4QTkqN3n6uDuiRIDmwgNkxwmmkYciwyZUpBmHcgBWYoUkNpoXqtdLrW+1HbpmhEYyYBW8ysjqDpY8u4sa23XeV7X1V0JkBzAXdWAlLCHjOCbKxBwlCLHD8fjp+NwN9AozgBhpLMprYDbTTfRg48Nhn3RdtZ+7r5YfkjT/QBgTTscp3768M/f7a8v+uX3lkb5XRONx5+S/Amv9+lMppeO/776/7fi/6XwBFCjVbzvibCMv8ll4l7QbzKOf2DIP52vwSWoTms3dC9JipBw5H7xt8h1m4g3Q1XvLdLQ8KDHzQsDQoVHkSAnESI0tVa9L+NAP92NP92PH+/HkjOJaHdwmIoIobCwpJgaQCQiIXTHyC2A0AxRd0XPbepJ+c2bgF/+IX8YsgPWbvPc57kvV/NqvvFaERCQCQ12ycToatPe6NlkgG/wUIjjxCp2dyBk8XTS8nEtH6/pvtAABlEkqJlBIgDzrraCzayQQA7WBloSAhAanqRI4cnp6P251kur2tvVMOHtxArKYQir26rtpZ6/zi/fr5yEUlLzrhsVW4glE0FGF1XvTbUbuBODJE455aGM46TqC8wIgGbkJC7mSa0qADJJSWUY8jjkoeRSgi83jOM4jWUchiGllCJEETMIQwEGKpQZ2dFR0ECj7PUfoOxYCdgdO5GVhKfDkDhB74cx/fnzx58/fezgjFSSPNxNH47TP/z04f/26+f5cp1El2Vhwp8exuNUUiqBehFZER8yEoecNhrAUNLnj6dpoLvTMP7l629fn69z+/b0si7zcjlTnX355B8fypDU1cmBUSSlNBAwGLqCb4NFDoiO7zJ4Zv748eN8uX5/fOSc7j59ePj0sYwjEZtZV621IuJu2aHMPI6jiNwUbGK1AoJqX+YV0VUHIiq5ELEbiKSu6u7MTIiSRNXWugSZ19177ymlUoZgI5l5a21d13mer9d5XVciFEnTNCEOKeVhsFKKv9esSc4ZCxIDETr01GsiSSiK3oEcrXq1hoshd7fV1llNw7DX0NyBnIhYOBExEJqbajNlJBzYj1nucrrLaRUyRiVoqma1tllbd1N0Z6TEkiWJZHestTd1QmZEu1G+bCeY36rX9+fVf9SgdoeYLmt9bbqqNjNV9+beHLujOnWzqjCDXcgRLBkwIjuEe2YxG3pjtVHtzvCMnBAbEQqrsCG496DyRu6lm8qBupmFd4drHCTuMY+I2r131+69OboLImX+wSfxfdV+g+FNTWut8zzPy7W1KsKHw3S9XK7zVbXnUj5Px08Jf3/5+uX89dzPszbHqIXDNlXB0QgNYTtXGcgJEYVlkunj6eOfH34ZacCOZIiAIpySuOk8X5frWHJOJceQ5QapbD3fv5l+IToRcJhYOwuGlR8hE455GHKC4soKpsiIwGrezLv7vCI4CFHCRJATDiNME04jD4lz6DQHicuQSIgBOVcpC7A0wAbQHZKat14NMahARkIOYNZ1XdZ1XVtvoXq6qQOTOnZ0/IGSHeEChfKUDx83mRrK1NFgk8zzW7IDAIC+jW4punqb+/JcdTYBKacyfC79surq
48OdXvi358vL8/O3v56HsTvAQfhhnD6ij23uDb80+L9W/C8d/83wgtDx3afa1LxsmwMH2Dq47oHNv3strfWnft4+IyclEcJDSYNwQm/bn8YtJzEHAFXr3TR0EBwZiQm7b8yXKPLdLRhKVg21CfTjMHy6Hz/djx9OoduYYyC0JBZmERFJXR1iNB634XvtIWFhbtHeACQmJEBkftNrR/j8a/K7oatf5/79uzYFn3cN21jVBMi4jcb7xoRAQAQj3/oIYdm0jakLom0SmeCACDJAPnm+q3xcLF2bm1dtbV2XBhDTQwxV9Ez1hQUzDZN4gYZI3ZPLRHzIqaAMxoWICC4Ei5K/2+YWk7OGuvY29/Xa69XykWQUKSmNCQU1+lIk6OyGvWtrql0jcSREppRlKENrqubNTF0VDBk4QXIsSobCeRzGwzhOh2Gccsk5l3EYh3EoQylDYsEw/3YHBCICJszCWQohN9U3hjd/PGcD4jFAY7QscDeWzLrOlgWnMY8lPV+v1lsieDiU//TL/f/268OffzrWORXql5dzq/UwbGWPA0efhNGFwgR102fIie4O5eFU7u5GyamU/PX7+XKZrfdlvl4vl+t5XIZJFbp3JXOGlDxnYQRyQkd08rB5RKPtYNwuYj59+NBbfzlfSOTz98fD6XT/4QNLUg2gdAXw1jy8iRFxGIJeR7331tqWGCEqeNjEmRkRlTyIJEIpxeLAzDkDgIjUWtcVN30agNhxtVZmiZ3YWq9h1b6u67qGRk1rRzMLLYkI7W+TeDJkJ3QEAusGXdGVyIS3atm69aoGHbACLNhncwUiFkbhcKFnksSJid2891brrAlROBEcshxzOqSEZDV0tDFSp+5em3btigBCOSctikSsDoiUOFnv3XoMi2945mv380cs6A8M+dBh2X/kBMBB4QxzFlV1MBA3dyR3MXPTFezF3AAyoAAwQiZwguZezQvgL6Qt4QPgrwVTYUvU2Tv2gAy79dbWta211W6GxMgMahbaVTEw7BYjQLX11lrvZt0JjRBLCor737n20Kmmy7K8vLw8Pz2ty5py/vzTz/z5p6enp8enp+bqCdNxsOT4F5rhul7ntSMTIZGZa9fm3bp5WIglYWBUwg6qDkhlGD98/Okf/+F/H3nwVQkwseScc8nErK09Pz917YfjcTpMKSUW2Ro8fyey700VBkJydCZXAO3IiMyFM7IbtOa1Y3cyQkT0ZtbVryuR06mUBMOAw4jjACWjbBOu+/eNgpkISZKUksepTNO4LJPjQg0N1mZLW4SwZ0HMowmr1VaXeamtmlm0F6Jb7zEfyK79nTRgEMykyHAcDh8P48NEmbdTGgEZNp8n34U4A0xWB3XtsL6s87eVjIbTYfo05M8JoNNSTj+d7vzuy++///71y/W3L8SUS/75UPw+k3hv8K3Cv674f1b/Z4XvDkqvKPyWUMW/m3skbjp5r/z5d29FtV/qxVTNvIxTno6HkvQwPr+cxQ3VraOah5iEaQyq6CaoKZIBu1kzRQ0Ue9Nyr0u/nq/QmjEeCt2P6f44nA7DNJacBQFDc9vdCFG2zp1w4AOI4M6IRNSgqoYmb8ysI7NsnLX31m+nu0wf8trUyZ+fQdXq4m0G6MTAzOQJu29qEh4dhKhsHEBsU/EzcHdyt01naO9nGCB5Gl0mh6TV15fL2Xhxc+1aa2XiQ0sIB1lHe4L1samw3GfmjBuNVymTF3EGz44EhJCT8IVZXm/EfZs0MHRApCzj/ShDSWOWKcsg08M4HJkzAIB1cHProYrVIr7vTYhCZpyqjGvC1Ltb7aDOkBKXnDMWosJ5HMZpGoaxjONQShnyMAylDCmLCO2+lUGjICIsrImdQ00YZMsTEV49DV/vhNCAQBmQ0TNDmlKR9KLVXZ+vZ0B4eblcnp9Z1+Mh//owfL5LU7aBGH+6e2R4fHyqdf3+9OR4/3A6IAkaEAiBERYiIQyNQiP0IUtKI9NPn+5P58u8LEuvbRD66TjeTaMD
zEs9r/PcllVXByIuiUuWMuQy5FyS5MzymgtvFxEdTtPz85Nav16vz09PP9dfp3Esw9h6N9N1XcNis9VWawWUWwCO1NPMAHxT2zFtrUXJNI9LKUPOa2DpgY6YGQBErV9Kub+/Z+acCxHGphMRM+/7tR9QRoSn02mapviOzJxSesep6W61hwFprJZ1qd565EldoTXT2mubu14Q1sKGDEaAhuRMzEycMAkKqDVdl/kyX0sVMBqYIScZSh6TVKtVO5MA8ygZU18cGKADxtBNkVxSZhIjU7QOauTNdBde/4E69yM8/3er9q2+3/Bm3OSLDcwd3TO4oKljBjOz1f0FXBEKYgIQcANkg/DCErAPhJTwI9lYMGdZEyujEwC5maq3tS3LujRtPYDLYBSZkoc6ijqoG6hC712176ritJd9P358DwFq9wDugk9xPp+fnp6en58Ty3Q4HMdpTGU8HPMwVm1Gnk+DJX9en79fv131svg1zL3dzQFabc26OzKmFE6D0Ss2EE6H6e7+/uPDh88HGX3tWWQsZRiGlNOyrvM8t9pe7CWm/oZhyADMQsRxaOIftnxQUwA8WCvOZOghNgdIhEmQ1FYHTG6IzIQdvJMl7ImYQQaeJhlGzAOF6zdHKH+Vcg3mACIzsmTJZRjHwzQdnK6wXpe61Na1CzkAlMQK6I7atzZKZNngEE59tumMhRnFu7VETGUapodp+nDIpwES+nYfQLSNwd8o61Gym7p376vXl1af6kGm091heCh8QH8EAByGwgdKWdT05XxWtXEcBxufefyW+r81bxX+S4X/rPibwSXqzBvk/iaux3kcuPIWwfBvECCizdd7066SBF2HlOAwHIY8CF2rgqormG3c967eVc0dEZkJEJpKUq0hVxtWP8362hZakvVxkONx/Olu/HSa7g5lLCmxmEHvpn2D+ENri5CdABCj2Qgi7m6stwHDWDsbIuSO7zGtPLCMbGiI3rvVxdaL1ytiJwbGGPhgsw238xsdAjlSTNyEfRzMQ4jlzXN1J0YeEBM09cu58e/Xy5nctqnazNQLHTANbewX6y+GRXxIMGTaJuSqArCgikNC7+DdkZALYn3vTuC7dAoSZSwsxZGLUGIpPBwzJ+XszIzGpti79dZar7211lqrva+mK7gRJU6UjDJ079xRkYCz5JQlHVI6lHIYx2kqZcgR10vOOaecKYYQNyaPuwORMGHhJuRO4OxEDEhwS5V+4J8ZoSEhESADZkIWKUzQCxE+XV9qb21tru3A+DDKx2O6GzizARGdDt611baoff32jCzDUMbEgoIA6IqQgrsXQr8ELgyDpJLyw3GqrfXWem8MMLEkSoBpqQ0AVHVZwhtqYUzCuUgqJd+fpvvThETM+DacENHp7u7l5eVwOpnay8vLMs9MPJSScm6tXa/XeZ43z25VZiamYMZFva6qqg7ordXWWm8dEWvt69LWpS6liqQQ/4k/HEkzAJQyMktKKaW0LMuyLOta44NFBsAb1iUxihwVfHzmCO2ub+KHATbURVutsUysNzBDdEAO2xo092Z1qeArZWdhF0Jk9iSYE0vw+EKB+ip0Lum+JC2CCJK4pDSVNC9tUXN2RhKghOJs5CBgACiUEuciiVHUtZM6OMMmskc3owjfwcY/XH9DjW4/FmAD51iSlCQDUwLYAJmMRuSAMKAZ2GLGgAbQERO4ACB4IeyKDUDIR9LEeJ/ZMmnKTRJK4pQcQU23wr3X2ptqB9BN1Mu6OwG6QYfAs4ITj53IiAzB1dtal/a+TNw1wsHcTG1d6/V6fX5+enx8XJbFzE7H48ePH6cyCokjAdHT89N5vngzRDzm6dPdh8vyclleTFVdN9kv9Va7uYtYyFBYc1JmT1M+fPrw03G6M0M1EEnjON2fjuM4pJxa7/OynC+X6zxfL1dVOxz6ZFbKkBMCRSazLcTXixBo41FhuCNzAkR37YDiZCaoIM4ZRLCjoQAwGQ+WH1wgfRiHMaWELIBou4L7rTgF
JIh9jw5AiEQ85HI3HV6Mnzt9me1rq6ZeyBPbSUMxSFhUkqZkAA3cwa13NO9hB+BgVN+JAyKQZDk8HA6fTuV+4FECB0IK9zgHBPOQL7nFDDdzXa1dez83n238pZx+OfHBmrV1qXruT/rcz631FlKaYN67Pp2X/9bbIvY7te7w74Z/db9EZX4L4Htg2FMJ2Kl7ryzTqCDhzSEsQkcprWKtldnAKzMPIqdDOR3Hy3o5dwVHAjIMKxAjIg4bKgQHzIm7pa4K5tqVuiJ3r+SEZcCPw/Cn++kff3p4uBtOQypJiMgdAc08Dr7ArgzcI9dWMzeNkSRCYo6tS/t2/mOyGLdr3XRZ6+VSz8/98mjzE7QLptC0R9pC9E0bwd034DDAZgdAczSAV3w+ujtxaDCSoCleHh3Bzo9VckhPIQEL+oD9LrWPWdkQlCkTOjIIkxh0BQQztG6ERG7ZbDJnhyPYF4dXBeObiyUiAWWmSGsQAYETEScRTgKMCZ1rd7dmqr3VWtda67rUdW7LrK01TJ6yOBZS6skQKHFhlwQ8lvFwPE73x+nukEuRlHJKkiSeVsw2bm0cB0QSzsLC2AF0U7QK5cEQD3REAoO3HavwGyZyIRRCEaAsku5P3fTlel5rvT+cHviYUO+mPDIyGJkDYmY5TKf+wH/99vT124sBMfPnh+PdNMZj2EkQuvVkzdENAchRyDFhFkbgRDJISZwRpase6nBdxssyzq2vzZelz3P9fnnpXdePH5h/Yh5yTm8rKmK6u7831bqs3758nS+Xp8fH5TofjsdSyjAMpRQiDP2MTbPI/RahEUMCZF3XZa1BgMcUKXLXdW3LsuScxnGATdKutdZqbe4wjoPISUQA0MzneZnneVmWnEvYwUUIL6XUWpnRTJdlQcTQ9ieiuvTbXmEQBjQyvyVcbt609kWBUAahDDmDa1u5NW/WOUGmQXgYh1PKKRJgtb7WPs/zGewlyXwY9Tg5OBENWQ6lnPsqWlvt3bEvag3IJSNDChYvM4k4o4MZQDeL4WcmFiZmJHqjNev4h/D+96p234phJKaUpURLkZDNwcEYPaMTegZT98UBDQygI2SAgpARmkJFqAgEnhESwciwiFwlOScnQZJAYV+t2ltTbYim1rRXM0NioOCj2zYNBdt8JGzjcL3W+gcE2MKZQLvWVud5uV4u1+u11kpEwzAcj6fT6ZRTRsAJ3AFa7+tavTuAn/Lx14dfal3d7aWdZ5sVVGzt0GM5MkpWSZwBMZfhlO9++fDz5/tPh+HgCsCYcxmn6XA8DkNJSdQslwKIUY621uZ5dndV06IiiUkCZXp/HMeJYZuZODLhhlOYe3dEA8EkIEjJwCIaZfIhw5GRMY0pJ2YCdNMQ/AkwBm9WFcFfoE3glRAzy5DSIF2IV4On7mDQAQ8Gq4FuSy4lsSK2sWJM1d26qXUDMyOp78QBWTjncvx4PH4+plPBErPsYWYVITU4675N7pm74Sbee+mwQgIZjsP4eajDvNZ1Wef1surc2/e2rAsgSGImiuPm29qWpt/EleAJ4AK4AhhC3CPcAvgbAbAI5T/w5n74qQgfcumJWxFACoc6Jh+HfH+aztd2mTt0b+6ABE4sDIHquJk5kadEowtaLkSuloUPQ5qKHDPfj+nTsXw+jR9P41Qk+DLN3ELPZLMoCkMZREIKGZlYMNHwJqTN4wBgn7sDeCfWFFfT7rWdX+rz9/ryrZ+/2/oMtgIXIAGA3VTEN8Rsm1oIQCBM+WgT0Nuoh/sRjxAWI8hMYLhcwNT4DEiu5mgkTuKavdWhyV2dRqaE4BgBkFAQOMyNULsjI4EL2oCW1N3s6U3u64E6RUANYSkM/St3AzJiSEkyM0FyJYTm5qZq2lV7773V3lpX7U6WRmEGNBdPZk7MKRUy5kbDOI5lOoyH4/GUhyJJmJmJo2YyUPco6QABiDhJEhZ3Nu/uzTwcJ3a3AviRFhSNjPjwhExA4J6Ijofj2tu/f/laa/94fz8Oma0e
ClGAmB73LCXL6ZieLk3t8vR8BQRXJSDH5E6IvIvvR6a6S4ABIjgjEBETJUk5lWDhuFsZaBxpWnlpbe12vtRnWdx6azWE/YOosL2GuAvEXMrp/u7XP/1JW395enp5evr+9ds4TXc5E3HA77fpUDNz39CTINbVVud5vlwvra2EnFJmFmZxx9AqDXgcEZdl6drdHAlTSjnnoQxIZGbznBCpNQ3ee+jEi0gppbV2uVxaq7336/W6O7/xPqC/3Ujv1qu16q25du1N29rqui6tOXI5DEkSC7sZEbtDa10QAUgoJSnCCQBMe1eva5+v9QJ+KfOy1N7DShtzkmnIY5VLRVWH7tABOgVNOmyb3SHIYQCorWuLU9Zp822+HeC3VfSuDoG/KVmz8+6i2CAhzmkc0pg5I7EhRN8KwRM4uXfz2d0NDbADKAIi1Ajt5KuFXJ0zITg2lEqlQerGQcpRhd5dQ5FGu/bq3tRqa6s7xIBWDFCAd9ukszGl5Kpu3VUNeihn3W4ixiHWdV3m5Tpf13XtvQPA/f39UMo0TeM4ppyQyB0kp/EwHpdTq32e5z63u+nubjwd8vTp7vOXl6+P18e5zle4DLlc27VjJ8YBxyEPJY33pw8/Pfz8+e7zh+PHgQsaJJLj4XA8HIahpJSIKUQZkGgYh2Vdl3VtrT4/r/My51KGMg1lyLn8WLUHHhpkanSMWo3EDAy0uYHBZhjEeWt73swf02YqFpvaDA0N9pmmANM8lAvdnNAAUY0sHHMBTaE32xYUkGMFrIAtFMBZRKykTq5tA1LMDKxvUH1pbXpzH3lIlOn0093h00kmCXkjBMfd8BO2rlHEQYijkpxsdb20hJJPZXwY+F666dyWa7tc1/PLN1y/tJfrxd1LSSWlu7sjIrycz+e2zNANvSH0TTX2NvC2Byvf1FLxRsfYK3vYw9U7fg3RUAqWAghdvalXtdrWnOjjw3Gp2lQfL+t57YCCwSnnIAMaoDM5AkihUQrDkJgPQzodyph5TPhhyp/vpuOYGKF3rbVvE+ScmZMwc07DUHLOSRITQyTZ7mEiG3UP7dPGkdfeMsWub5zfAOraKq2P3+q33+rTb3r5avUF0VzQKBbDRvDbz4LdAS/K5I2sBw7om+fJLjwVe5N5CyZm0Bo0Q3Os3by6VJOuGRoelwMIURmKGFnzptYcsxuYERioKiATMVACUZOu2I3f4xB7nmYBvgIgaAgKIDqyMw6JMjhrhGHXbS3E+jIER07MSeROjJVWVFUgCqI1OUEHQQ4UZiumSyYMRlx0JyG+O6IjElOQWWLM1sOhPEYQcRtBQH+/y0NbFKM3huSIrbVEcjdNVfu///X35orkKZEKg3tdey1aCiNw4NlDkfv7u0+9f/n27b/+079fz9e69NPdBwMJK1WKqLGNJNNOygd0AnAwNIVOCtAdgNCZvQAS8VBQAe+O06ePcLnU63XNKU9jYRbfRBdvbGVo2gGxDEMZCjOfX87/8s//zDlxyWHQvsmtmLXWAR2A635tPPb5uiyzmY5jLsMwDCHWm5mFCHtv5/PF3WqtkmQcRiIyByYWlmCZhH7tSquHmXrvzJxznqaJCL99+/709Lgsc3DpmTmC5dugeJ7X5XHWXnutva6tLusyt7o6opRpkpKHyZxaVydW9xiWQgJAMNfe0c175ANLW5e+Iq2D1tVaM3RH9Cw05jSlNLKoBjTNBKjm3bRFx1mNkHJKTGRqcS9mEMUQ3M4oBLgNqrwP5X+rag8qJHoATkQozEmSSGIWi4YueHcXcHDvDt3QDBSguzs5k68Oi0M2SI5s2xyXE1fJNQ2NcgcGi3p9Zx5t7tOq1lWb9u5RqoIgknsItkfxQMxCRJuSu6u/D4ruEEONge0EHzLnMo7jNA7DMDIz0Jb1UJIMOB4OdW1taeu6DKkMZSwfxg93nz+fvz+eHy/L+Xq9XI/ndVm0d0DIJZdhmIbD3en+4/3HzCMqsZMQl5wP02EYh5QyMQUZPuz7Us6S
F04yX3FZ1967OYQ8/kZNfncXbmYbh3tXKojpqhj5VjdzQKCt6AdF860IIAo/nc1VmQgANo/laLW7m3ttujZ1BEd07a6m5tVsbX2ttfXeTQ1QgRVQATqAOhAAIxZh9kRIoN27VjBt1rrWZtReR0URcTgOPtnh02F8GKlwEMnj/9vK3A/qm1a5G4KCr+4LDLkcp1N5KHhAqoRCzlphvc7r9WW5tsXYhTgVyZMYunVcMXreADe43X1nK+6/GpF9l+XY94pvRdTuX/rDtfXtzLFps26qQnQYy8Pd2LQBA15xNWqGqK6dzGJqERiQhYQoEZckJclhzKfjMCRMDMckhyEzU7g6qXZEIGIhJwAOcHljZuwcAN98rswwkr7XeLM/zu2ydzdS1z5be/7en77o+ZsvT6ArEHoXQDRmtBv1eStmbiomvmHzuFmPBBYdxMNg2mEI0AiJBL8Mu3rtXpvbDHw1rtqgFuP5kMYjFRZLoNSNFGL5OrrdxhYIkVCEiI0J35Nl/c0Ltl1U2wEQjTYhICYS66CmXXvXrm6AyCwiwKTCziOQcLrLTkaMqkpMuZRhHAnZ1UGBFEMhHB2YQx8zmEgBhRJux2bozVk0ToJ7GM8kAjxtJOsfztv9CyAgIyB07eY2DqV4zsLzuqzLnHDzqrjMq+Q8HkGA4qgXwuNx+ASny3L5/vT4+HQR/DqvQGlY1nYbgQQnV/QYcETfT5N955lqKHYgEBqxJwR2dKRS0sHzYfR16XFSMbE7ub9rK3Q13c4jkZSXZfm3f/3XYRpP93dI3LsCIHMCoN7VTM3xxnEzs9ZarWvvnQijBR6a/CmlYI8CYu/dTFVVRFgiogNFCgiOiDmXUsqyruu6ttYRF0QUkfit0+nkbqpBxO5R74WP4uvu6P2yrtZbb7Uuc11m00ZIw+EwHR8Opw+pHGrrSLOBNeveK5HPdZFanIVJ3Lz33rT1ZoGK9u6taW2dQcGdEBNTES4ia1cGF6RIW90tzF40WjZIBmhhcRPp4uZM5tt0K8J+mv24rP4Q2hEIgRlYUASIDUiRDNlZmIWRyBDXWKqA5OAGZNgcm0N3B3BBzw5Xh+SUlBhCLp49paWMdTi0NHQQ0M0TdNf5QEDeJOV1+xVTR3QS3iciwSzYwQzObmEgs6nlv97EfiVJNG2NllxySpmJo7i5HV6IgEDDMLTp8PL4oqtebDaDDz9/+unuT7/+9A9LW+Z1rnXprVrvZoYIRChJci5EDIbX8/z48oyC5f40jeM0jTmX0Bi7daSQiBEnmmJyZl3XZV1ra733dV1LLjdgc9srAfjFewubMIyhLHJgh65hAGcePFwmCnaBG5h3cML4oDskF7q5QBAGBl3tZVm+n+fuDkQjUyKYu780f1z741Ln1rSrE0H8t+udgTZyLUwZc8rAXb02XWtrVdVqb+nt8BvC8eHgRx8fxnRMyOhxrOxd9jCb2CbCDE3BDKC5reYLYOXhw3T69S49JC82pPxgx/X+5fnwfIXluZ+7GAzkjo3tuc8GuuCi0mN03EOV4O1Dha2LHInvDW/egkIU67ee7ZtLVed5zUnAHYmTiHRHVAAkgtMhAx3TkMbL+nxt57ldVzcwQgAidEuEY07TUI7DUBInxqHIdNha6gkJkZvaZanhDjIMZSgDkcR574BuqOrYzV2BIAJCQN/xcsDDYt7ACXyz6H1l0OxXq7o0vT7Z+Tssz1hn9A5M2JsThcEtEm2dkhs3gyhYtIoUomQ3fGOn7gXei06IKfM4JhmSIerS3FsMM3oHrwAObW2t9g5umb2gJ3dRYN06wYAEFPVFtI8kJU5E/Pa82vg0uDfdo55Bd0QQ4ZKScEbgrm2t69rW2puZI3EiQRdPxpmgkOScT9nQCElVOXEpZRxHZnFAq2aLMrIu2ktPoyLR3tgK7Il2CSJw964dA3zyvWlx+8Due8X1huofiByEZTRCWAq5ibAInQ5Da/O3b1/mUg65CNN6bpTzqRtJkBgV0ccJP5ep
20dEv75cv357fn5pRvL9+5OZRTNHHZuiWRAinBDIkMAJnBAIFTbRq+j4qXk3NzcCREdJLDKKWWiN/EHNMebIe5+Xqu55HC8vL3/961+m0+HzT5/TMLamCFLKJLKaz2utvYfikAcwHr4eRCwSJuYUaXQMGeW8ecGZKSKxMG7kRQDeLkSMWfZlWZaFe48TaUOwAqz99Omzu18u5xi6i6qd8NVkwRGMAITQqZs1s1LK8Xj6+Zc/3334GelQO/SXJ3Vv2te29HptDd29dj10E8posWsMAHMqIpHNWF2bULfew3GUibOkxJhIlUEQJaWByaOD17urIYKarmvVwKJdmyq2pqa7sgTuWvI/4PF/y/kN9tSbONLTXnVZ+6LeY1LKidr+J2lbldgAO2xD6BIjcABu2AAvjgUgUZI89HHq09FztPU0UlZw2lJ1hxCmA2fYcl7aZ980nHd7hy3PU9coUsHs3eGFCEF9LESCAJJSKUVEbgJG+5m17y2AVHIZh5QzIC7zom4PHz8dyuG+PDhBC3U825zoYOsKOwDM8/z4+GxVXS0VOk7jdBhzKZwkJk72gI0QW4hFkouklDJLknWNBR1yDT/uFncE349Z3FHjjR+5OTybSey2CJiwU3VdYVNCwq0tCBieTY571DNT1UW1u69EhP6yrI9Le17bpfWqgfFDN5truyzrvA4nwmRKYEIEjITspB2odXOvUR69laxBxMPHgx89HzMXNgKAUMpxBAcPa3aPkt3UTdEUrJrN6g2EUjkO5VOBgxt1YsiTjJ+m45/v5kV9YOiEnbCBV23rWpdVUY10t2DfM9rXGe3Xx3uL6nB7Q1vYilr1HTWldX2q85BTV825kKQAkFzNzbPI/WHMuRynfjivz5fl5bLMSzU1MBO0InSa8nEsx2EQRrBOjLwZSzMiOXB3APWcKWUexnEaJ1VvLV4CmoOqOXRVi47bxnaNLjvAptoIW9s2SBQWknVvrnW2We3y5PMz1Bl1BXBABu2u7ERORIjMwojs7sHUg1BM21R+tpx4A+r3x0UEKXEZZBhTGTMm6QYOfU+8NyFcM7QYbUQzNmcHcWD3OKbQkZgwETBaUEkJOUFOSG8jImxa/HBLnsEBiFCIsuQhFcboGca8DBATUvLttXWlCgKcs4xZRjEPlwqTJOEexpKQ2JpZ7t4sjFI3i6pw+d4SwPjfDh3466J70wq9MTMd3u9zh3i/cQ4xAKmBmgF5Kfzhfmz9Ol/XFXwoxc2WZeGc7tdKIsIc2IkITpk+fZgQ/THL8+OlNg9zmEBxHHBe2stlTjlzCj0VRNtdmcgRDdHiIxtYTHhESyewPwJEYkO3MDh8v3d8mynvl+u8NiVJ3e355fnx+7fH798Op3vA5I5ECVDcodYWXLZbCA8iieSUUkqSNrQzIlHwO5lZmJ2ZCcDXZQ2iQs4eoZ0IRTjnlHIWkVBJiUgfHavD4TCOYyiaxKbY58fevJR4IrDpoATfK5UyHe8Op4eupc/VHLpp19a01VYbgDkbZkq9MGN42RMQMkpBCj/uXtfq1E1b672ruW3UB3NrqqsCCSWWYRiySHwKM2u9Lcu6LPN6nXmt1FRyJub9eAoexR+t+P4GIL8JL0AMYKKpt8v8bE5rvZg3JECmHhm7I/su5OXYIJjcmCOuO5wN0KAQDIh3lI/DgQ8nPByxFBFCRwNvnIiCYbHZyAIIAjAJIggTkrkFiaL3Bq1Zq9qWWlsL3+uIcG/vgYhEMlMCgFg9O+/gxs3cEFmL90oesnFlKmnKl8frfF7Ol5e79f5Y8lBKGQbwUB/3/R9rvc3zPF+fvvz25XqdiWg6THf3p8NhkiRRr++D0vCGbRSjnIiIxDIMk5khYMxmvN/zYA6bh2yopm2uc1shRgDdvKJSU3AUCYiA8AbGRhW3+e90IkJiIwhKvIOXxHfTgEt9WdfHeTnXOtd2XteX2hf17mhIAFi7fW0ra/tU0h15IhC68X4kIWdHpmYWnO33ZECE6ePB
DsajOG2PPXoFcdpHK9fRY7jRzd2gr72dOyvnoaRDxiNrauodoHmy6efDz4Mc/vSwnrt39BXtAtffLl//y1/qb6uv28nj8YJt7wnijtm+cr/3cj0iRTy7wJwBb9lwXGvTr9frmNOx5WHwnEEV3Kl3ra2J4JjSYUwO9OnUXq7L4/Plcl16NwAogmOW05jHIonZTNdlqb2ta4XGJDkJgyQkYIY85NPdeJzGIefrdV3XGYE85IZVURXAiZiYhZmYMHCDnYcP20qjDegh/KHRM1/9Uu367MsL9BXCgxYp8ir3rXMozCXlguAWIh192Sj6sAnXw2twdwAnREk8HfJ0KMOhSErdsKv2LQVHAgwNLUQjcmYjUofqWIDAETcNDgIkJk6IjIYYHQli4Ffy//49d4Yg7WkZADNnkSGNJQ1uqGpIxCIZBhbZ5i9UscHihICSUyqJiMBNUgKAlFIZShmGlIuwoAGM1uZal4rM8V32rGrHL+L7U1Tu27G8WXY5EhLsCEfQF95vcw0OzF7SiBp1c/MuOX36dHCoX8DBSXJqrT9fK/D88HIVoWlKGDCLIyGeJh7y8eNpuF7b49Py29fzX58fd2Uffzpffvv2JOmQh+lQck5spqFyG3H9DenEzLt5N0e3eO60KxsoggIquCIwwGuPxNTXtb1crvO6agxM9T5fL8+Pj4hSxpOqa4/EDmtt1+sVAETkcDgE6C6SwtwlOG5mptpD1knNxGNaGMFkma/ny9nUiHkYD4CopkzctSNRSlJKDvtgROy9E1Gtq6qK8Ol0HIYShOh1XXvXt8NvUfhob72u1tXd1ax1va41zash1a6tt9arWt+9nNyr5obuGbFEHowOwMTMDrq2vqx1XVcnM23r2ue1r017t9b7WuvLdb40pb6OfizjMB6mqYzCZGqtt2Vdl2Wp13mZ13ltLoXLQCz7GPOO9rwP5D+EdtygBEUzjMTMXOf1Yk61LQDKDMTUAA0QIXbjdmSaExIk9AIADtW9GiwABeCI/HMe6XCXxwOXQkyMTkQGiViIGAN13IWvdz6/g4MruHu4S/XutVpde6stlOtgN6l8dxtITEF/Afyh+IL96It0f8+tiTnlNEzjdJjOl3O3HjxaAGcO8bibHuwr0xBg6a3P89J7P51Oh2k6HKYyZNqMjHBnYuHtyN3AQ0QizMQucEOlgv/5fttvR178ZJP23KvQKLs7YkNFRQjryK2ewg153griwOA1+NK+CVQ5ERbhnribLK269a69thBp1m6qZtEqXLWv5K0114wiyEHU5kCEgzgVFUkwON8sKRzvx35QTGSv80pBA90IknHcBWXQDayDrWazZSrDYeQDWzZnNVRAA4ZyynlK48dDr+CdbMb23QTk6V+/QAiah1B7gCwGt37Nxo2Lx/kGl7/9iiEgugEQ2n6WbddS+1++X48lr5NNDYYCarg2X2prqjQIpzQkYUljkmORU6J5ygEsJaGUuCRhRndrFVWlhcC4EgJacEkSIjFKLsOUcxFh5hZOMIG/+G4+w+wA0WrgrZMKiGB7YyPQsKjK3jH/3f3lmz/NfnmE9QK9bqwu3xWutsx+g+8IERjdNaguGnT8rd3/5mEiggjlLHmQXBgZFLQp1FfptyhQXNCz0GHKd3fTdCgstG+q0DLQbkbmZrh1zoBpO2zoLY4dhxW+ykFs3X4mzilnGRIPHbRrC2pORjTmQAC19c6KxEBGQkGIISdJiYhCHz7lnHMSFnIE2cHqWzEZeg47s32jM22p9La44k/sD/8tFPT+OPJYoBSCgODhhgVmxgSnQzE9mJo5TYfycvZm/eU6f316kcxlYGZGc3Anh8yQGEvK05BFRM0Pv+fgzDKzSE6p5FSy5JxSSqiqprbnhPjKy/Id6tpYlAag2ymGhtgBdMdcb6EdRcQdlmVtvUtOwzgO42hmz09PeZjycDDV2mpv4WEGABAN720CrbaYd4/QzswAmy44N0FmkWTugROr2rqu2hWJ41mJyK66
6KWU0+lORIIx5xv9E3pvtTYAFEmIG/rrDv2VXbKZKvWmrW2gRWv9cp2/Pz11yHmwrlCtdVdDJ5E0jG5IUJALUWEZhRHBHQ3FUYysqtV51afnWaBZX87Xl+fr5VrrpbXr2pa1rrUttWnXtZuQxAzqWAZElJRGlpyKlalOdVnbatg4MdGeJe6gz/vrb1XthqqoiqbxN2xpszp1XQGNGZnZiQxMgSJ1trA4IGTEFWHG0JfxJ4BH8FH4U8rlcLy7e4BxEqYMStZZCklCZECOYVgACui9a3UzAmweK1NiRL333qrW2npt1m/EeAzZutd72PbfjbXiP4TM94c2Yih2pTRO4/F0usxXIGThG1yzBUnci77t114vEZmmaTpMZSgsAu/SCXzz/7cfEt/+INCFd1seYGO+EW75PzjeIJV4NxTRXZHiYN/sPcMK2TZwJWgXTgaGwTpU3+AWQIQhMUFCS2TVe79EQtt6r733DkSAnsATQUJnAhQGkeDiOyE43Y5XZi6EKb1ZVwjpVGBszqBusLlGOYSIhqPHsFXEdUdX9+5eASunYx4/jHygzg2oI5oTIJIkJIBcSDu5igrUuS3E2NVqh+6gCPF1dava9/wqKr/9hd52xA1ZdjeyCIo/lO1z1ecv17tJl4rHFYfiwY5RM0AXHsdiYEqOo9DI6cCTHXJgtU7RPfGm2lp4EggKSWgsGjbVxVvPTFTUCSg5kKoRYskJ3IMW6buo5kax3ICojV4NhLrncrFUIPKa95Tsx9/964ufv0G9oPVN9cX3v7cf6Nq1YWNiwG38uKp1901ODXetxFvSmpKkIkH8W3tVs9awVu/azWwDS8yZcBrzh4fjzz99HB/KmjXYcbHvm2rrHVmlGzEgEqFsHRID9HebCkMKAXaTXgREF+aSSpLMmBQAoO+jIuTMFhJZyhztEA4fNgff5m44yVBKLllkx0QCAxLEwhAt4CiTAGIELQ5iBFDzbRJxC99EbxPyHer4w0mwIXEOCqEx4Vt6SoBZ5HSchJMBBMhM5Nd1+eu3x1TS/cMpCQPoVhWBmSsACfPpmLsej8eSE3ekMZdPHz786edffvo43B9FsiE3Aujou6GRODJudPE47GCj+3hXqIQCQIAapK5AWG73QITjOKX0HKF0GAaCewIj8Kfn58Pp/sNHd+vrsrS2mm0hPDTeI199eXlprR0OB2YehoGZg2G3rutWqyWRKsbk2lrrpq5qoK56nZcltOpihn4cx2maLpd8PvM8z621nHNKqdZ2uVxuOO5ORSN4I4Rk6q1b69q7OaADLWtf9VL992vV+4/GUppVAwfiVIZxmhDYVYZy5DTkfJikIIJCdzJn8z57XS6r/f7tjDq3enk+P79cXqpZA1jUZnU1MMN5mc/neb0uL88vl48fT6dTkRRUg6GUVCYbbVnbpfWXZoTBIleItQLq8G6f/6FqB3RHM9DuvaMZuaOqIjQAY8KURJIAkQJ2gLA43c2soCFUgAWgOijAE8AToeV0OIztdKenBy0jxeitO/i2OhyAaCMAB8U0eO+RyhoSIrQGrVqPfKm7RnEf/sYA/mPVjm9C+61ie1O7bLjT9v84oYg5VIVD3AAQAwva/sJezOAeC9y9t66qiLgx9XKODBQRf9y/r0/4Nc+I1P7vXe5um4ffjuft/faoyhzczBWsk6EZoG1F3HbqbSHUI4Bun1vd2E3dMSQ8au89nDTCK3eDUDW0BlrvxCQEAh46gwRARKGWHMM+Tl4cxq6n3nPNappLeXsjPIoOHkM3WyygG/6wQZS2nRXoCt6cTIgxTyXdCQ5g1BA7sUVwDmlscAKKPq3Z0tt5adelr827ge5o/A3+2Z7jDwXT7XS69Y3dHTYHmTe5PADUpo/nVY3c29qprDH22gFBhDiJJI5wWkSEsCRCTkyIAB2smS3dzb11XZutzWr3brCxJAAAUB1q03ltl+vqXTNDJK9q1lWRUG+LFwAJN1lLFmFmYnDX1jf/DdMbgai29PaW
n77646MvL64tYIpY1jcCISFtvXXT6ubgvfW1a1fbNf8J9lVO6Ki2QWBE6AjdtLuqmRru9Z9HJeoOTDRmmYY8lJJT7tDBJdTSQsvAPIZ91BmYOXIN26ifP26mrS2A++5AZGLhxJts17azYxZuA3LIeFcIuJFOZGvkcsqplJCuDm9c98A9hBkQwqNj/4a4LSncJPZhQ8tuWz3WnL0G+9dP/WabB2cN9mN6S7ICDiWkLEwTOiARHqZ0dzo8ndfLtT69rOdLS5iGlGjjUKt7RyQESywlexJHUgQjtGmUh9N4OgzTQEjVoDsBA8BGZ5UQG4tKbWfXxQikoZuDQjyPfQG+7e8S0d396XI5D0MJu+2U8939/TJf53mer5e6LghQisSzTSmZWc41SMTrul6v1zirQ012p8WZqrZaa0rMLUjQpr3W1mOZ+37aE90Y7wDAzBbCZ6qB6d5k7FrrqpGd7Jju2/4hMVFC7ICCBCwE5k5YW7/M1zRf8mBm4YvIiAII2ywLhWIUsQgSAbCzAbuiW9fzcrnUi7Vrr9fL9TzPMxCCsBELUiLOhLO3Vutzf1nXWms7vpxLzuMwHqfDaZxOZWKgRJIwFM7DXHdrMSP9eLa9D+1Rg0N0Pb13cCMAjt4jEUqinHMqGYWMsBug7QAggqGrQ4Ww+/PmPhMpE44D353w/t7uHryMYc3tgOreTS3MM2JDRnQHhM3owtU9qF+19XXttZp2N0M32jhtURD8MBH+h5r47123ziTudfN+TjK4q+k2Z4BbIotv/mLIFvbWw3UghjRuPbgd3bp9h7/13f/Dz+buqobu6OEMeIP4411tB88rDXwPkkxA8cKJfMctwAlAHWAbtjMw97Xpea1rbWttT8v8NK8va629hUNPyCiLMzAl9owgGBJ521ImZiIhQZKMnCTldV1baz69GWtHwEQoGAS6QFV9x2DDBs0NXGGjE3T0ioySp5yPmU7kpQE2JCPynVvjZqDNtBG566LXb+eXL0/zy7UtIaW+1XmwycUj7oSHHWq8/YEtOtx2wN6A943xt19qfl0VqSP31SjVTftECHNmkOocQyKgGQpzNE4jt1LT1rWrtaZLs7n2ee21mzmEAFCOpo9r1365XL8h1CGPRZgQwGttrbfVNHfdqmpzB98KSyJhTiKMBG7Weg2pzt6DQnXVAvBwu5Hzo79897rcGFLBiY8B6/iXNoteqxqjtq12VQfb4jkhEwkzOoPjpmJiYJvd0+Z8Q4hRLUNI6KiDIxOVzDkIzh2AhJRZiTjiTHA8TU0BgVjAwMxCVvIHDsePNO3b4kTGTSgPbvjCTrZAdCIEQkLAYGWzARKnlHPA8Dkz0Q54BWMFkZgT78OkNyLh247f3v16/Xy32UAA8Fvr4IcEJUZ+HcC9O+zVsFt4N6Ft35PImfAwlp8+fXQ8f328ni/1++NSqIx3A2J379E43yAnU7PVfXGoqq5WhPsw+JCpJFLAENKx6C2AuCdwAQOArY+OwLeKADfJPNirBN6nHbeLmR8eHuZ5fnh46L0v65oTj4ejar+cz9fr5XJ+Hg53D/d3ret1WSLnM5vc/fHx8XK5rOsa4TyqowDYVbW21ru22ohqrD3T3mrVrq2rm7FIyjnie2vN3WtdEfF6vVwu51orvFGcJeJlOT8/v6zrAgCHw7GUXNLrdBJLymW07r0ZArt7EQambuZmtS6A6OZMwiiuuLaVgBKjSVdralU9EbAHzV4AnJXwuq7Xx2dtV+2r1qpqSVLilLlkTti7e63Nuvpaa63nZVm/pcckaRqG++Pp0+ne7z4MeXAgi5wPdgkrxkBb/sPQvi1CcLfeaRv1ivzGkYiS5GEch3GgxM5gcWwq3PpHjt4BVnQH6EQmnEsej9Ph4T4dj1AK8KYP4AhmvampVt0M2tuGn+6glrlbzOkb1K69afDsfGc/+W2E93/huhXt21Z1gFA/YGFh2mfwbkX/H76AmXdVB885D8MQzo60T7ttG3//q38v1bid
Vn5r399+y0xbj6M2pnQjHHnoFW0dMTADNcd44Vt/ADiEM+BNXzRI9f46FenuzXRuda51WetlrUtrbipuCSyBkXXUDmCMPDJPRAmRNkJGdCdC3jS07hEBk3Bda8+pvnvUGAzufRgab7TUqMiCRxf/ogEa55zHw5hOCQaHZEi6CX6FwmqgqE5g4B107vP3y+X7eZ2rqsIu5/T6sLfa6tagCcxgdznfaz8Hx80m5hUkfXMTSMyGVA2sKYW0u6oIZQCd6+o+NRtXnYqOSYQgISQOqNa62qq6duuGCqRAcd9IQIwpYRIBp8A1em2zu/YuQkjQNuOEtvb4fGBu4ECkW2COycpYk61HgqWqgQstrD7c1jq01dvqgCh5yz/NkBBo03rYwruHhafuG28Dt+K3mYkFGRzVwPWNN/xOnYgmORKIYu8eAvdCNGQ5jbkUVldUdSU0YhMCwr137RsxFonIiQxd3fqbzfi6hW+bBZwQtqkpwo1fss8LIG4ENyREIzd1dTQU4pRyKWUoQy4l55SSCEtkjxDzMDs1kQjeaH45vonn9poTBsTmG6sGf0Tm/Ady5qYbGMCH7ajg7QdRGW/juoQ45vzpITkkBwGHx+/PpO5Np5FzgX0xxN5UhI6ohNGUaoAVqSF1QIkWBKDQlpiQI3kg0Bs+H4I2cZO8fdWtX7MjE2/qKQTMKZ9Op19//dXMvnz5AogsSXLJJbdWnx6/Syqnh+M0jeMwBEXuVqCLSJyfb4HPiPQhIr6XLabd6rr0VrWHZi2lnMdxdPdwkFuWJSVhltbajYO1CaaYuUNvfV3r+XwFcJHEzOUNqqVda+3hA4hIxCQpAaG2ama9tjiLM6chD33N87KYNQPWvrY2r8JCSMLOaGbWrNe5LZfzcj0vV9OGm8sAYxJOOaeBJCEpAK9dF+1rba1rN19rY6S2VnY8cLHJKTNx6mhkPY60LfxtMiHv1tXfFJp1d9emrWqvFiUcJCSUlMt0OE6HWfJ3FERC6Fs+ihu90jXqFEIn5pLKNJzuT/cfHobDhCKBnsdH6dHO1dV0qfW61mvXxby56d7jhm6qXU2993g3EFKr79bUTTb1f/p6h55t5QhzKCS4u/budjv/93Y7vFtpoVwbQ5M55+i67e2A/9FP9ce4DgCm1lsLhjM6EzJuqMrOUo5eRowUGSLZNlJ56+nFzDEBwCZVEY3DGJJyBzWr2tfeFm3dmltPoMg+kg/o2ZW1iVMiOBQ5CueQblL1ZoEYogXUBcIC2dGdHFaRt6F9Rxj9NQX020MMtBVhY887OrFzHst4GvDonhRYkX0f/bP9lGEEAmdo3q/t+ni5Pl16ba9tdd+/2U5k3N+2v7bYbyf0nv/Eb21/4L0iKDGVkjmJEVYL809194TQEddrk9rL0oZcp6GMKWWhzJSYJRBu89qtm6szEJIQo7orETK5CJaMQokxsSOBh7Lf7szkDtC0Y9Mgzd04pubeu3a1TSCveW2trnWjFkd2UwD+/OZ1qLuDZCRGFnIAbQ4OQaffEexN/SDQcXDfoT8kZEZhEHYydW/mupnP25YNbXb3IgiAXa01bA7okJkPJd0fh3EQ9W6K5jmFgDzwvnEoKKIAQMRK3cHVrLn+sEcwoIV9iRGybNJr6O4BC7+2xjwiLgOCqlszcs5SxukwHKdSsqQkm5hc5KHkb7blBsG9QXf8TWpKe5sMMebqcYd+Xj8s7Cvf3x8LEXWc8CZiAZvWlCOFO6sTOaABeE78cJclDaWUx+/PT4+P6+Uyny+fP50+fZpSirWx/3UEohgscEB1rIqLgqgXdTMngIy4GfWCO4BCAO+uEdq3YVngHYHa+1zwhjJ4eyDuh+nwj//4j+Z+vV5bq46YSz7eHbW279++HI6nj59/GocyjmNrra5rAKXDMETnKJrlW6tnQ9mDk4tRoAJg73q9XLU3RAjQcBiGu9Op9R4ONJfLZRiGUJsPGDVeY/T1mVjNwCGm6ncBlddrrfV8ufS1au+JRZiB
CBDMwdS1dkOVxJnlMEzWrr1euipaN621XWdGQGNlI2jeVq3aVl3n63KZeyOElJJwyUyZObEkycxC7EA8q157E2lNFRDdfRuVa90BRNJQRpasvXFwyzaO0t7VfHdc/dHUdW+Aqnpvpt2tR6h1kVJKOh79cJzzUFiuoIF1hoXJts1sywLRGdKQptN0vDtNp7s0DLFU9qN3G8mK4eqNHmvhbUaMrFunNLi7Zmo702qv8iIh99ux/Teuvx/y92gNW16y/3ki5hBnhK328/38f80D3n59ZgnixlAGufm0/vc/wH//owabDQ2BLLS5bAMht9uPCs4Q0EKj24F8q0zDWtl3OQ+LgsrQNwq5EyJzyjIOGdGFgVAFugs0smVN10U+JLZGwnxX5OOQ7osUptCtdVPrGGMwaOCIbk6Awuwi+t5jN1bUhiAEoLnn+xv9Np6rASgwMKeUB+Ej+mDGHUlpG73Z0iy3sERhctTZ2vM6P13n81VV38Awr43YW/C+cUpvCd2OpbwP/K/IyNsXRMQCDr3pTQU7eqKmCOrYYKk2J12qjblnpiycE2fmMHvu6goY6n7hd61m6N7NszkCClFmEkKGfZwFIQJbuCiBhVNWCFTFOJy1bmvra7Olau1em/autluRgTujTq9FOyChJBomzgNzQjdfZ9PmaIBEge4huLmq9pufR1DDiQiNTMkQFHGbl1MH3F3B9+0dyleImJgSQwNDwiHxYUiHKZWRLTmwEyExsQggqoVI2d5EQiYUgwav8kM/Vu3BsAtyHxPECBUR7yv99RVurz7cV7u7Qs4ZhabDoYyDJKEYRcC9Fof9v0jp9wdI24Qb7L8J+5/cW/CRUL+rzvevAvDHkyqMgCGq4ch+MVR7CIg2Crp7uLchIROeUBKPmZShrUt/Ob8gKmA/HmQcMSXYxjfRaQ+KuGF6qhACHWBAHg4FQYB3Be/u3Tcxvdg8DIAAoT0HmySaeTDk3V/Dh5m+vLzEJNuHh4eHh4frMjPTMKTDNDx9/f35+/fz+Xmer+62G7NG9kZRstOm77uF9ltQFxEiHoZBUkbEZVm/f39cl5kQJMkwDofj8XA8IuIwDACgaikxEW2EidcWezufzwjYuiFSkmTuP0jRAUBKaZzGyqK9ISASBQUkAlHXxp2YAMCYKKdhHA6NU2S83a1ZZ+9oamprX5e2uHUwxST5cBDm8DXIFM5C0Emc2AGIJbtPAIoseQCAUAgskvIwci6epKJbq3PrTY1D5Ah4H0HkH5bW3wjtcfKZgjbwBqChkgssqXA6HPh4XMowcBZrBghIQFvhvncz0YOEkaZ8vD8e7k/j8SS5bPrm29kRJJwNS6BwDUci5CQJAaE1A2PYJqZwL9YCKCOAOBW24Pu/VLX/8e9gSEIRRfb+BkZ/12XfowYhhqCmDKF1TAxvUvX/2Q/1R6zeLFAwNFSLYwUhwnUcQRRDrQ6+jQzHg0IMfRqK6B6itArmgTMaRg6DIxQgmIq0KqPAlb31vjL2rr21dV0zWBY5DfnnY34Yc5aQz3FXdUCF7o5A4dMRGR0JSzyH22Xu7tEz3KrprSwOjpICxMGlCB0FuOQsI+NolrtTR7LYgO6RgKMbmlLIGem1r0/L9fEyn2fTmBR493Zfu52vFwAEBRT2JwYxiuyvokRv2qavb55UTa3uX98BguwBoU6OaEvV1qzmnghz2jRlUwol4GjuozmoU3fs6qYGjonMkgMDIWaRIUlOIkJqHpOvtYYNuycmFgmBYgXfkQ9s3ZpaCyB7Hwve7DjerysSSgMf7vN0EhbQbsRtvZq13YRmowf01pqp4n7+Rg9bu/c1vmc4026jbYELBpwUYS5gMGFIhILEiFOSaZBhEhlozQgJMSFJQBXetDVtahpMWwQmTIh1G+R8j2QHe+1NVujMmEWSJEHpwe6FXQYHwAPwcQ/jdgAs05CGNB7GNBTabDL3l77/YENw3CK7vxHzbAd/4iDYq9tIDwO4pl1y/9YCek0j321z
c1VHILet4t+dexlJAgzYShzsCI7Qh4RjTmM+Hif5/ffn3357fH5pZqvpQdKBmQLp2YHIrcMCQBYMJzdwMkd3jpcX423uzSK0b+xkCuHLQCUCHNxUh+DGk98uVfv+/ft0OByPx+Pp9Pnz55fLuauWRJm9zeff//Jv5+en56fvwMPWLtz8jjmlBADBcoiof7sQPaWUUp6mAzGH0M2XL1/PL8/glks+3R0/ffoUoGlrjZkRyUyjckSEcRyZ+fn56eXl5Xxu7p7zyCzDMJo7IZu5v3kp0/EIPq3rsi5rDz1a7b1v7ooK2qySBv0BUy4T3Kt20zipWYmUEBC66dp7rQ0JkqSS83ikxJIkMRECaIh77dgEgGfmY8pcptY3R03tPQkfpimNkwpfrM9zrU2r4+CIyISCKDFl9gPM+DdDuwOAK3on9MSQKIwLiAnzMPJ0OB0Ox/N4WXHR1nGfG+aUxrGUnHJmde3ej/d3H3+6//Dx4e7+LpVCiUsZh2EseZCURBqH3VmkDsSaimo3097qss6ttp6C7Gu9e+9Wm9WuvTfrHUOyzt3Nx2GEv3W9TfP/LpftzZ0HN3Sr07dTLjZ4BIgd4d33dewcB0AMeSD6W/Z6//0Ld3ry21+MbRS/RAZB84wRKN9ScnfaYGRzx0C2o3L3kMDcZON9f0eIDMRORCJEmBEQoSIwuJsiWO2NmYKfCK73Q8o5TSXdD+lQUkkSBwfSHi1jn+9sTUJGJqL36wrdadMBvf0XVVQMvIVpOhpiJ2ZOg9DgmrpxB9TI/nB/WVtNFz6Z3dbn9frtupzntq7uFmDGGxgV9lP1TQ3+5hW9/tBvoNXmQLoN6u+XO+hemEYSGl8QiUJNPPxawLx2BfdO0Dr1rrVr6spMgORI5qoAZqC9tdYIXHtIWzIaoDmFijRxOPq6swM7cDczs+5KGnHVTc3NyF0AMiM6ZQn0crutiFUqU32zsiRRAT7e5buHJBl7UwRA6PUCYOhb87eFIub+GPf6wMi69moO4GQ3gav9BUX/b//u8bg0GJDOgIIkBMDqYq8BnxEIumnta7fuAHsMJUQBpECW6Q99N9x60hZiq8KYU0qUiIRMX5dKKNoQbL43bhAuIiWXMUsS3jDrrTkOe2q3ax1FtR/reBdPjGWwpYa4qcdvJA3YEOud8RFJYZTs+IfOm8HuaGzESElkLGUohUk2BVi/Lc1ASS36BjQkJnTt4LU3pZ1/81r9bwuVCBmRwRmUTME0mifkW3S2ffH6dhhv5HHct4W774JCALbVWQRvDjp3i7myoCt8+vhxnKbaG5Mz6HQ4DsNYa/3+7dt4fJA05iRDyaUkM2PmdVkoeuxEAGBmt3n3nHPOQynFAcMp7nw+f3981N5Pp8M4DtE4ij95PB5713Vdal1rXXvXlBIimbl2W5dVzYRLzkkktd6u1yvOeHecYAcaf/7lH/J/ericLy/n5/PLy+V6rutSWzNT8HAAQhEmACml6KjaQ8VQARyQRCQlBxftPAylH0IjjynaX8wUyG70AjQSMCIEANZeej/2ZqZbuebGhDmlYRhgGBXQDFE8o0gZkRMg73uTfpi8+OEIhu34C7VXFfLEkAmjgUXIUko5HE6nu/vz+brpLamFyWMa8+nj/d3peDiMTdelzcfT8dPn+w8f7u7v7jCxMw3DOA7HkockotaSMGzQGzfJZhaCWLWt6XKuqYY/tVlM/kFtfal1XddeVyEqWcBBVafpAP9L1+uxvi3QLdZs+ofM4bqBe1zfQztGDRBF/qbr+gpI/k9Ed/yD5+btikeLe2loOxwfMpZhIuSbWuve3ICA4GPo7WYOE2ohgIhADMzAjMIsDJsuvJr2nLMDcGcmQgdGLwy1jRu7iFEIQxjs5nmz7f1t/hM2o1OHHwb0ASOo7US21xvcegsbD0CJlKkITwxFO1VARbI4bvdvh5s2rSM6erX1abl+PS/nudf2Gor3fG0jT70t+PD2052NANv5vf279Qjc3xOwAxjf
kRpivimYxkl3axm4mTdQM4isVHqXxixCxA7Y93JStffWCKwTgQMBhWJL79abdoWSAQDVvCk147XHLVZEpOBwe7C3XdBHhsJMknaROt5BdFig/PNrzYsslImPx/TwUFKhVrt1swa2aq8eG7C3pl3hJsqCtNV5itqhNzNAY0QGpBhi2Xote5cMwV01mLDoHUiRnBgR0RWbUd7ccoTCKKp7X/sadgbbpDgxIDuREQHylkG9eSGR00ZXiCK0SxIWAkZ0AAI3M9wOGEAg7CG+LURIacxpSCSBJEUw3PfivhJsW0NbpITg8O5J33YmxIMG9i1x2Js/cZy8PpxbWH+337d0wRAAhXnMiXwcy8CYzDAs5dw9PGG3DQAADoLIRfjDNI14vazL0koWilYo4A6LYphIADAYeyfX4BBF6o6+pbC2xW8j95irwI2rC2aAIWFiCKEZE3f/ljTg5nVdF5FQfP/w4cNJdWnNtPb1ejieTqf7Za3fv34FSvcfxqGk43FiEQBg4jOhdgWA8BKMifbNISblsNBUMwCISffL5dJqLSVLEiLqvfXeiLiUcn9/dz6Tmc7zdVlmxMAgu5nfxt6ISIRbqy/nF+36D3/+5XYjv/75H3/95f/5+Pj92/cvX7789v3713m+1LrGcRXhMyAe2Aai94Y3gCOFmcPNOjFefIxR4pYi0+08uf0TO8fcfRva2V0F9jZPnPDqnrM7EoGwDJQykdBmVfIfh3bYWMMOYAraQRuabmkphDDqOJ7u4OOnn5e1xXifRnbvIENKQx4Ow+HuYJ6HLmXIKQFxzGOEK7iHMoODmTXViujC1MLEysHA9xSIEGTjVIMRIjG5sxqoqvUeUpuw2Vm+iyX/Ix3uG64eGGLEDhZOOY/T6O4553EYb+LGeKs4t/JkU0iIuZyUQpT5Nj+E/yOf4T/+tIF2Ipjrtk0JMRx5AG6arZuwW+AXaLvUSri3KgKB+2bAEzdigDfqRTRFiUlEHBCJaUUwt6QwlERkZkkSMfkG0jltx0XUrbQpmLhDDOfCBtS+uxE03xgN7lvXYBva20oEA1dkpyxZEmM2la7QCHfRnbBbj9C5w+bYrF3a5evLy5fnOq9m9ncSKn/l9Mbx/R6AeUVs9vzM//ZXismqmCUPQvprtUpBEY+nhA47fdjNtVszp27E7EDqsCk5urmjOpgpLFUN1qpL4iJcEpWl5ZQ8pOPNm9naaq219+5mKdoetKMlAIBOxCHYRgBb5eURdX5cb0SUhEuWsXBDmlOfqV/B3VTB3NysOxjS1gQzBd+UH03dgMHBjZDIY9QVt5M+hNkQDNTAmoGCrqgrYsfkWJiSEBAYQ0zscRYSdsT/f3vXthtHb4N5kDQze4gdO2hi5D8U6IPmBfMURW/a5E/yN2vvjCSSvaA0u06Knu4aDAEHcLCz6x1pRPL7yI9iWkQMgDEQcQittdzA/M2+40dgJTfQBEEZMfLAFLtuZM9coRUAIiBYgAhgDIohRWJuwWY/pNsv/f5hX+JGpGP3/CuT08B+cMEQxAYRmM/mQQes+jVgprpe0pfDI3ElgmkixFRrTDFR0FoX1apNUtHUFJsc2+q5LUXXJw5DEubIFAEZzJnchOhHPxA2cUatUlmqgvh8xTZ8Yn1y0dTpZ0MFADEzMVqxTG0yGQjwLGsHhMAhELlGcYpJTIkpZ9BaxulwvL3PHz+eTk+73aMdXxx2036/M8Raxd97Oc+uCIsAXhXvWbs3HzGTqgAoEcQhTrtpGNLd/d2bh4fj8ej+3oe0MtM4JtVpnp+enmxZsohWcTYGpFW/akxBLfEjidTrLZWG3eHwUhSrWqlmyOP8WEruRTqX6Kq39HT5YEcSvXdHrRexNkjnEvj2DGA9m653MzQsCHyNPYZAAAWt6hNSDYAIgj/B47jf7w7j0JuzruyZa7/6EFCBWqBkq8XM2ihoIg5xdzym+/s/LHOe51lUc17EXXEKFIlTiEMiDsmQCAyqSK51YTYEVitVcxVG
1JLPOZ+17VNPO6VKWZaSc6nVz0nvsLcVZkQgn5TkojaNIPvusf/XnvU5rtfbngg5OCyxD8whxMEHrnenen0REYYQx3ECQJEaY+p40sWv/+fe/Z+aezEwUzBUBCGmVTXbuoN3BhFMgdCQrOvGIwq0IY3mTcXKRgTkUZWokrbJk4TEzIjMFEBVaxFmiGEIAQGIGMxKEXX1cnRN7FYxvValmXZS0+BblQG8BAD+0zS2XbFcERRBkIFSTCGRRVWqaqXlFC6th4hgJmjSxwYsWk7l8bfT6dPXMucVhP92fdu9hBbD9aq6DiCgD+kyuwCOq7t/9iVWkZgYPI/tY7ARWt92Y7fXaihDL5JSEENUJDVHIVrFCSKSmakUUVuqzaEOIaRAiTGGJTBbUy5BAcyl5FKWeZZaUnAKnyITty5zihGJFUCcXvfycGYuxPZsZFr7CyLxGGJQGogDEpqoiPnh3VQ+0BRVTQUaq+4UbnC/ae2Eb8vdUli/sVJNq1oGzWgZoWIgGkOIgYHQ1R5Dc+1kCGImJgAcoCmPebk7gIlWRenFmM/3lYceUBGUCQMnpqhGvRmIHJghwEbfewRLBoaEDJf295WJ6Sk7tDrVJv7U/Lw7dd8dhE3CA9Tn5vYbwH6ctTRb19I8h7e+ndeOPgJRKMC0o2FMYMwUmKXK2VTaZY1gIa9jaqkjKBFFjERxSNjZep+YZwiVICAwoiEQGKioVKsEVU0UTKlT+44IknVaCaDhg9aUmrHHb+siPMvaCXFMQ4rJY05mRkNVFiJEGsb94ebuy5evy/xpOT9ZWV4cj9PxmKuc50VFVAVVZwAicqlX/7bMPIzDMCQzKKBgykzTNMqLIxK+fvP67U9vb25uVHVZ5vX4jTHudtPplHzKe62C2Bp0zazWaiYpJSIYhvTNpmpSxXEch/1+n9VgHHe15ka12Koe3GgPP5d851iXLNJVEdqgUSJ+gnRdBL+//ZBp/I41uWjooGxbY09uvc5cVPuoo8CUhjSNw24cdzFGel68/My1//rzH8k7pRAD025Kr17dHqZDhARzkjpCCEpixQ7h8Pr2YcTxfPcktboU45Dibhp30zRNI5IZVABjoqAHnaNVwhmNpfKSgzBxrUutxZ20ZsY6kGrQZFaZJEUxbuNfxGfQI6tpHWuZSi2FEAKzb7bj7uX6LUop79+/99KM/9ZcyMtzI8/PUkrf172vL1585rqqF4P4cfQ/e3RVnef58us41ftXftKoAxuMPggEoHHt1PnHzk22WjV3Oi6oBaiGaqSIxBBYiSqxMYnX0YAqKURDMLaSYsGppuI6aN1/m0gT5sOrz2pOu3UEUpuFY1Cu1OjQ8PbzSz0JEuDKabQiCWjUoAIqBYyJBn5E/N2EWDEiuyo+AHhDfYMuzQiEsBB+sYfhYfd2PN/+VHJeUZhv/+0L0nBEj7rWF1gbEt/wALOWhiIeD8f1i6Td7u7tr8TERC1lW6M466hcH41ziS962d4auF9A7oYZqKkgOGZCTIQESlgQhdx/drkTURKJu8IqgQmYjVG8ghdRiCpTCYF6tGFtpVCuZlYCwi93fyrwcBuGfY4JOIi95CUe8w3lkhV6s1CD1p2h6FmJ9fvYQlfsJW5oIVEcKAxMjKJNct4ETHykCA7IL1LYHyJh1DkJDHaOEmLlgBByJcsMQJXi8vdkQ9ShnMLXudS5RAUGNHm8OrwKwt8SgIJFgGSo59/x88cT4uICmlVdA7wNEPB7rybaKXQEwc4erEvV9qz1dXPo9QrkgQ6lrpn8FSDUXtKDhYbor1ea0x3lWdj4eRkW5RYvmFmDbXtlgYa1sh4J2uQG9LMAoMl/tefIOlrnpQNLGWV48/LNsM8aQ5zT7V+ewlCB2fsNOgvRNmPfbNZhCPBABw1Q2zPe4mB/hk7lkibWKn/98CGEGFOMITKzgctVl1zy+ek8L8JpOtzc0zCds+Dpaa5aVXOppRQiGsbRI0w1CyEAWBoGZqolz2BmVkpBsONh
/8vPP+VlQcK7u7vA7BMvsJu7URFBoHGYWmESUopKiFJlGNM4jjFEJr69uan7ep3v/vbpz2qa87Is5/N8zstcpaoKrHel7QCAnql7ou6k3Lr0KzvrIbBjQXYtlbx6+PaDl+DRP6KXS2M7pdSs/5+hKalxzifTpPZ5nsfT6cPVewO+e/cONttss80222yzH8Xo379ks80222yzzTb7/7HNtW+22WabbbbZD2Wba99ss80222yzH8r+AT3gtgsKZW5kc3RyZWFtCmVuZG9iagozNSAwIG9iagoyNDAzMzIKZW5kb2JqCjIgMCBvYmoKPDwgL0NvdW50IDEgL0tpZHMgWyAxMCAwIFIgXSAvVHlwZSAvUGFnZXMgPj4KZW5kb2JqCjM2IDAgb2JqCjw8IC9DcmVhdGlvbkRhdGUgKEQ6MjAyMjA1MzExNjU5NDMrMDInMDAnKQovQ3JlYXRvciAoTWF0cGxvdGxpYiB2My4zLjIsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcpCi9Qcm9kdWNlciAoTWF0cGxvdGxpYiBwZGYgYmFja2VuZCB2My4zLjIpID4+CmVuZG9iagp4cmVmCjAgMzcKMDAwMDAwMDAwMCA2NTUzNSBmIAowMDAwMDAwMDE2IDAwMDAwIG4gCjAwMDAyNDc2NDUgMDAwMDAgbiAKMDAwMDAwNjg1NiAwMDAwMCBuIAowMDAwMDA2ODg4IDAwMDAwIG4gCjAwMDAwMDY5ODcgMDAwMDAgbiAKMDAwMDAwNzAwOCAwMDAwMCBuIAowMDAwMDA3MDI5IDAwMDAwIG4gCjAwMDAwMDAwNjUgMDAwMDAgbiAKMDAwMDAwMDM5NiAwMDAwMCBuIAowMDAwMDAwMjA4IDAwMDAwIG4gCjAwMDAwMDA2NzUgMDAwMDAgbiAKMDAwMDAwNzA2MSAwMDAwMCBuIAowMDAwMDA1NTkyIDAwMDAwIG4gCjAwMDAwMDUzOTIgMDAwMDAgbiAKMDAwMDAwNDk5NiAwMDAwMCBuIAowMDAwMDA2NjQ1IDAwMDAwIG4gCjAwMDAwMDA2OTUgMDAwMDAgbiAKMDAwMDAwMDg1NSAwMDAwMCBuIAowMDAwMDAxMTYwIDAwMDAwIG4gCjAwMDAwMDEzMDYgMDAwMDAgbiAKMDAwMDAwMTQyNyAwMDAwMCBuIAowMDAwMDAxNzI3IDAwMDAwIG4gCjAwMDAwMDIxMDQgMDAwMDAgbiAKMDAwMDAwMjQyMiAwMDAwMCBuIAowMDAwMDAyNTM5IDAwMDAwIG4gCjAwMDAwMDI4NjcgMDAwMDAgbiAKMDAwMDAwMzEwMSAwMDAwMCBuIAowMDAwMDAzMzg4IDAwMDAwIG4gCjAwMDAwMDM1NDAgMDAwMDAgbiAKMDAwMDAwMzg0OSAwMDAwMCBuIAowMDAwMDA0MjU0IDAwMDAwIG4gCjAwMDAwMDQzNDMgMDAwMDAgbiAKMDAwMDAwNDUwMiAwMDAwMCBuIAowMDAwMDA0NzEzIDAwMDAwIG4gCjAwMDAyNDc2MjIgMDAwMDAgbiAKMDAwMDI0NzcwNSAwMDAwMCBuIAp0cmFpbGVyCjw8IC9JbmZvIDM2IDAgUiAvUm9vdCAxIDAgUiAvU2l6ZSAzNyA+PgpzdGFydHhyZWYKMjQ3ODYyCiUlRU9GCg==\n", "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-05-31T16:59:43.377558\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.3.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def visualize_exmp(indices, orig_dataset):\n", "    images = [orig_dataset[idx][0] for idx in indices.reshape(-1)]\n", "    images = jax.device_get(jnp.stack(images, axis=0)).astype(np.float32)\n", "    images = torch.from_numpy(images)\n", "    images = images.permute(0, 3, 1, 2)\n", "    img_grid = torchvision.utils.make_grid(images, nrow=SET_SIZE, normalize=True, pad_value=0.5, padding=16)\n", "    img_grid = img_grid.permute(1, 2, 0)\n", "\n", "    plt.figure(figsize=(12,8))\n", "    plt.title(\"Anomaly examples on CIFAR100\")\n", "    plt.imshow(img_grid)\n", "    plt.axis('off')\n", "    plt.show()\n", "    plt.close()\n", "\n", "_, indices, _ = next(iter(anom_test_loader))\n", "visualize_exmp(indices[:4], test_set)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can already see that the task is easier for some sets than for others. Difficulties arise especially when the anomaly belongs to a different but visually similar class (e.g. insect vs mushroom, train vs bus).\n", "\n", "With the data prepared, we can look more closely at the model. Here, we classify the whole set at once. For the prediction to be permutation-equivariant, we output one logit for each image. Over these logits, we apply a softmax and train the model to give the anomaly image the highest score/probability. This differs from a standard classification layer, as the softmax is applied over the images in the set rather than over output classes. Consequently, if we swap the positions of two images, we also swap their positions in the output softmax. Hence, the prediction is equivariant with respect to permutations of the input. Furthermore, we need to remove the positional encoding, since these features would break the permutation equivariance. We implement this setup below in a subclass of the Trainer module."
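 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the two properties above concrete, the next cell is a minimal NumPy sketch rather than the tutorial's Flax model. It assumes a single attention head with random projection matrices `W_q`, `W_k`, `W_v`, no positional encoding, and a hypothetical linear read-out `w_out` that produces one logit per image. Under these assumptions, permuting the input set permutes the attention outputs, the logits, and the softmax probabilities over the set in exactly the same way. Adding position-dependent features to the inputs before the attention would in general break this property." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical toy setup, not the tutorial's TransformerPredictor: one set of 10 feature vectors\n", "np_rng = np.random.default_rng(0)\n", "set_size, dim = 10, 16\n", "x = np_rng.normal(size=(set_size, dim))\n", "W_q, W_k, W_v = (np_rng.normal(size=(dim, dim)) / np.sqrt(dim) for _ in range(3))\n", "w_out = np_rng.normal(size=(dim,))  # hypothetical read-out: one logit per set element\n", "\n", "def np_softmax(z, axis=-1):\n", "    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability\n", "    e = np.exp(z)\n", "    return e / e.sum(axis=axis, keepdims=True)\n", "\n", "def set_probs(x):\n", "    q, k, v = x @ W_q, x @ W_k, x @ W_v\n", "    attn = np_softmax(q @ k.T / np.sqrt(dim), axis=-1)  # (set_size, set_size); no positional encoding\n", "    logits = (attn @ v) @ w_out                          # one logit per set element\n", "    return np_softmax(logits, axis=-1)                   # softmax over the set elements, not over classes\n", "\n", "perm = np_rng.permutation(set_size)\n", "probs = set_probs(x)\n", "perm_probs = set_probs(x[perm])\n", "# Permuting the input permutes the output probabilities in the same way\n", "assert np.allclose(probs[perm], perm_probs), \"Set prediction is not permutation-equivariant\"\n", "# Hence the same element is predicted as the anomaly\n", "assert perm[perm_probs.argmax()] == probs.argmax()"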
] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "class AnomalyTrainer(TrainerModule):\n", "    \n", "    def batch_to_input(self, batch):\n", "        inp_data, _, _ = batch\n", "        return inp_data\n", "    \n", "    def get_loss_function(self):\n", "        # Function for calculating loss and accuracy for a batch\n", "        def calculate_loss(params, rng, batch, train):\n", "            inp_data, _, labels = batch\n", "            rng, dropout_apply_rng = random.split(rng)\n", "            logits = self.model.apply({'params': params}, inp_data, \n", "                                      add_positional_encoding=False,  # No positional encoding since this is a permutation equivariant task\n", "                                      train=train, \n", "                                      rngs={'dropout': dropout_apply_rng})\n", "            logits = logits.squeeze(axis=-1)\n", "            loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()\n", "            acc = (logits.argmax(axis=-1) == labels).astype(jnp.float32).mean()\n", "            return loss, (acc, rng)\n", "        return calculate_loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we write our training function below. It has the same structure as the one for the sequence reversal task, so little explanation is needed here."
] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "def train_anomaly(max_epochs=100, **model_args):\n", "    num_train_iters = len(anom_train_loader) * max_epochs\n", "    # Create a trainer module with specified hyperparameters\n", "    trainer = AnomalyTrainer(model_name='SetAnomalyTask', \n", "                             exmp_batch=next(iter(anom_train_loader)),\n", "                             max_iters=num_train_iters, \n", "                             **model_args)\n", "    if not trainer.checkpoint_exists():  # Skip training if pretrained model exists\n", "        trainer.train_model(anom_train_loader, anom_val_loader, num_epochs=max_epochs)\n", "        trainer.load_model()\n", "    else:\n", "        trainer.load_model(pretrained=True)\n", "    train_acc = trainer.eval_model(anom_train_loader)\n", "    val_acc = trainer.eval_model(anom_val_loader)\n", "    test_acc = trainer.eval_model(anom_test_loader)\n", "    # Bind parameters to model for easier inference\n", "    trainer.model_bd = trainer.model.bind({'params': trainer.state.params})\n", "    return trainer, {'train_acc': train_acc, 'val_acc': val_acc, 'test_acc': test_acc}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's finally train our model. We use 4 layers with 4 attention heads each. The hidden dimensionality of the model is 256, and we use a dropout of 0.1 throughout the model for regularization. Note that we also apply dropout to the input features, as this makes the model more robust against image noise and helps it generalize better. Again, we use a learning rate warmup to start training slowly." ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "anomaly_trainer, anomaly_result = train_anomaly(model_dim=256,\n", "                                                num_heads=4,\n", "                                                num_classes=1,\n", "                                                num_layers=4,\n", "                                                dropout_prob=0.1,\n", "                                                input_dropout_prob=0.1,\n", "                                                lr=5e-4,\n", "                                                warmup=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can print the achieved accuracy below."
] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train accuracy: 98.37%\n", "Val accuracy: 94.50%\n", "Test accuracy: 94.66%\n" ] } ], "source": [ "print(f\"Train accuracy: {(100.0*anomaly_result['train_acc']):4.2f}%\")\n", "print(f\"Val accuracy: {(100.0*anomaly_result['val_acc']):4.2f}%\")\n", "print(f\"Test accuracy: {(100.0*anomaly_result['test_acc']):4.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With about 94% validation and test accuracy, the model generalizes well. Note that you might see slightly different scores depending on the computer/device on which you run this notebook. This is because, despite setting the seed before generating the test dataset, the generated sets may not be identical across platforms and NumPy versions. Nevertheless, we can conclude that the model performs well and solves the task for most sets. Before trying to interpret the model, let's verify that it is permutation-equivariant, i.e. that permuting the input set permutes the predictions in the same way. For this, we sample a batch from the test set and run it through the model to obtain the probabilities. 
" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Preds\n", " [1.3902171e-07 4.3522828e-08 5.2554757e-08 1.2441276e-08 2.7709259e-08\n", " 9.9999952e-01 5.5640967e-08 3.5155960e-08 1.4563368e-08 4.8264688e-08]\n", "Permuted preds\n", " [1.3902175e-07 4.3522839e-08 5.2554768e-08 1.2441279e-08 2.7709264e-08\n", " 9.9999976e-01 5.5640978e-08 3.5155971e-08 1.4563372e-08 4.8264699e-08]\n" ] } ], "source": [ "inp_data, indices, labels = next(iter(anom_test_loader))\n", "preds = anomaly_trainer.model_bd(inp_data, add_positional_encoding=False, train=False)\n", "preds = jax.nn.softmax(preds.squeeze(axis=-1))\n", "\n", "permut = np.random.permutation(inp_data.shape[1])\n", "permut_inp_data = inp_data[:,permut]\n", "perm_preds = anomaly_trainer.model_bd(permut_inp_data, add_positional_encoding=False, train=False)\n", "perm_preds = jax.nn.softmax(perm_preds.squeeze(axis=-1))\n", "\n", "preds = jax.device_get(preds)\n", "perm_preds = jax.device_get(perm_preds)\n", "\n", "assert np.abs(preds[:,permut] - perm_preds).max() < 1e-5, \"Predictions are not permutation equivariant\"\n", "\n", "print(\"Preds\\n\", preds[0,permut])\n", "print(\"Permuted preds\\n\", perm_preds[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see that the predictions are almost exactly the same and differ only by small floating-point errors in the network's computations.\n", "\n", "To interpret the model further, we can plot the attention maps inside the model. This gives us an idea of what information the model shares between images, and what each head might represent. First, we need to extract the attention maps for the test batch above, and determine the discrete predictions for simplicity."
] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "attention_maps = anomaly_trainer.model_bd.get_attention_maps(inp_data, add_positional_encoding=False, train=False)\n", "predictions = preds.argmax(axis=-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below we write a plot function which plots the images in the input set, the prediction of the model, and the attention maps of the different heads on layers of the transformer. Feel free to explore the attention maps for different input examples as well." ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1BhZ2VzIDIgMCBSIC9UeXBlIC9DYXRhbG9nID4+CmVuZG9iago4IDAgb2JqCjw8IC9FeHRHU3RhdGUgNCAwIFIgL0ZvbnQgMyAwIFIgL1BhdHRlcm4gNSAwIFIKL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gL1NoYWRpbmcgNiAwIFIKL1hPYmplY3QgNyAwIFIgPj4KZW5kb2JqCjEwIDAgb2JqCjw8IC9Bbm5vdHMgWyBdIC9Db250ZW50cyA5IDAgUgovR3JvdXAgPDwgL0NTIC9EZXZpY2VSR0IgL1MgL1RyYW5zcGFyZW5jeSAvVHlwZSAvR3JvdXAgPj4KL01lZGlhQm94IFsgMCAwIDY4NCAxMDAuNDc1OTkzMzc3NSBdIC9QYXJlbnQgMiAwIFIgL1Jlc291cmNlcyA4IDAgUgovVHlwZSAvUGFnZSA+PgplbmRvYmoKOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDExIDAgUiA+PgpzdHJlYW0KeJxVjztvwzAMhHf+ihubRSJlW4pHJ2mMjA4EdA5cJa3hR1MDffz70gH6GojDHXj8QEFHthJcZjA6nXcIathdentu07HeoJ2JNR/Ir3PV/qbCbPJQlGWmAf+3T0QjXRGMu433pfEIbMqCdSELocBrwgNG2MotYFGwKJhRa8+HBccI8nOiHWAPgt2Ehhpcv3uMy9/u4mkTYfcCcYhncnmu1ExcgXVu5JcfH+muGqfh1H8ifZyGlz7NmEZsD/vqqM+sEDvcR2roC5CcQPsKZW5kc3RyZWFtCmVuZG9iagoxMSAwIG9iagoyMDMKZW5kb2JqCjE3IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggODggPj4Kc3RyZWFtCnicNYy7DcAwCER7prgR+DiA94lSkf3bEFsuuHvSE+c5wMg+D0foxC1kQ+GmeEk5oT5RNFpvOrZIc7+8ZDMXFf0z3H2F7eaAZDRJ5CHR5XLlWSl6PpfaG34KZW5kc3RyZWFtCmVuZG9iagoxOCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIzMiA+PgpzdHJlYW0KeJw1UTtyBTEI630KXSAz5m+fZzOvSu7fRrCTZmEBCQnnPdiIxJcY0h3lim9ZnWYZfieLvPhZKZy8F1GBVEVYIe3gWc5qhsFzI1PgciY+y8w
n02LHAqqJOM6OnGYwCDGN62g5HWaaBz0h1wcjbuw0y1UMab1bqtf3Wv5TRfnIupvl1imbWqlb9Iw9icvO66kt7QujjuKmINLhY4f3IF/EnMVFJ9LNfjPlsJI0BKcF8CMxlOrZ4TXCxM+MBE/Z0+l9lIbXPmi6vncv6MjNhEzlFspIxZOVxpgxVL8RzST1/T/Qsz5/mjBURwplbmRzdHJlYW0KZW5kb2JqCjE5IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNzQgPj4Kc3RyZWFtCnicMzU3VTBQsLQAEqaG5grmRpYKKYZcQD6IlcsFE8sBs8xMzIAsQ0tklomxIZBlYmGGxDI2sYDKIlgGQBpsTQ7M9ByuNAADcRiTCmVuZHN0cmVhbQplbmRvYmoKMjAgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA0OSA+PgpzdHJlYW0KeJwzsjRVMFCwtAAShpbmCuZGlgophlxAPoiVywUTywGzDIA0WGkOTEUOVxoApUQM5AplbmRzdHJlYW0KZW5kb2JqCjIxIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjI3ID4+CnN0cmVhbQp4nEWQS44DIRBD95zCR6D+cJ6OsurcfzsuOtFssCUo1zO5AxN78chMlG68ZLg7zBWf4Rkwc/hKmGzETOhOXCOUrhThVJ8IjsvevOmgiXtEzqOeBVnVzg1qAWeS5oLtgi7njBU3zsmtRuXN9KPXEL5pdx/XeYf2SOPew1S+zjnVzruKCGkLWdW0vpBsFMkOaz8qTdvOyxCx4GwaVugc3gi7V3cnSxh+v/IwJRM/D936UXxdN6PrFGcnVyZrz3noSelf9cqjD8VxKegXse3MJPdfp1OSqVN7Z+9p/ae4x/sPkG5WOQplbmRzdHJlYW0KZW5kb2JqCjIyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzA0ID4+CnN0cmVhbQp4nD2SO5LDMAxDe52CF8iM+JPk82Qnlff+7T4yyVaASYkAKC91mbKmPCBpJgn/0eHhYjvld9iezczAtUQvE8spz6ErxNxF+bKZjbqyOsWqwzCdW/SonIuGTZOa5ypLGbcLnsO1ieeWfcQPNzSoB3WNS8IN3dVoWQrNcHX/O71H2Xc1PBebVOrUF48XURXm+SFPoofpSuJ8PCghXHswRhYS5FPRQI6zXK3yXkL2DrcassJBaknnsyc82HV6Ty5uF80QD2S5VPhOUezt0DO+7EoJPRK24VjufTuasekamzjsfu9G1sqMrmghfshXJ+slYNxTJkUSZE62WG6L1Z7uoSimc4ZzGSDq2YqGUuZiV6t/DDtvLC/ZLMiUzAsyRqdNnjh4yH6NmvR5led4/QFs83M7CmVuZHN0cmVhbQplbmRvYmoKMjMgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNDUgPj4Kc3RyZWFtCnicRVC7jUMxDOs9BRcIYP0se553SJXbvz1KRnCFIVo/kloSmIjASwyxlG/iR0ZBPQu/F4XiM8TPF4VBzoSkQJz1GRCZeIbaRm7odnDOvMMzjDkCF8VacKbTmfZc2OScBycQzm2U8YxCuklUFXFUn3FM8aqyz43XgaW1bLPTkewhjYRLSSUml35TKv+0KVsq6NpFE7BI5IGTTTThLD9DkmLMoJRR9zC1jvRxspFHddDJ2Zw5LZnZ7qftTHwPWCaZUeUpnecyPiep81xOfe6zHdHkoqVV+5z93pGW8iK126HV6VclUZmN1aeQuDz/jJ/x/gOOoFk+CmVuZHN0cmVhbQplbmRvYmoKMjQgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA0NSA+PgpzdHJlYW0KeJwzMrdQMFCwNAEShhY
mCuZmBgophlyWEFYuF0wsB8wC0ZZwCiKeBgCffQy1CmVuZHN0cmVhbQplbmRvYmoKMjUgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNTUgPj4Kc3RyZWFtCnicRZFLkgMgCET3noIjgPzkPJmaVXL/7TSYTDZ2l6j9hEojphIs5xR5MP3I8s1ktum1HKudjQKKIhTM5Cr0WIHVnSnizLVEtfWxMnLc6R2D4g3nrpxUsrhRxjqqOhU4pufK+qru/Lgsyr4jhzIFbNY5DjZw5bZhjBOjzVZ3h/tEkKeTqaPidpBs+IOTxr7K1RW4Tjb76iUYB4J+oQlM8k2gdYZA4+YpenIJ9vFxu/NAsLe8CaRsCOTIEIwOQbtOrn9x6/ze/zrDnefaDFeOd/E7TGu74y8xyYq5gEXuFNTzPRet6wwd78mZY3LTfUPnXLDL3UGmz/wf6/cPUIpmiAplbmRzdHJlYW0KZW5kb2JqCjI2IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTYxID4+CnN0cmVhbQp4nEWQSxLDIAxD95xCR/BHBnyedLpK77+tIU2zgKexQAZ3JwSptQUT0QUvbUu6Cz5bCc7GeOg2bjUS5AR1gFak42iUUn25xWmVdPFoNnMrC60THWYOepSjGaAQOhXe7aLkcqbuzvlHcPVf9Uex7pzNxMBk5Q6EZvUp7nybHVFd3WR/0mNu1mt/FfaqsLSspeWE285dM6AE7qkc7f0FqXM6hAplbmRzdHJlYW0KZW5kb2JqCjI3IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjE0ID4+CnN0cmVhbQp4nD1QuxFDMQjrPQUL5M587TfPy6XL/m0knKRCNkISlJpMyZSHOsqSrClPHT5LYoe8h+VuZDYlKkUvk7Al99AK8X2J5hT33dWWs0M0l2g5fgszKqobHdNLNppwKhO6oNzDM/oNbXQDVocesVsg0KRg17YgcscPGAzBmROLIgxKTQb/rXL3UtzvPRxvooiUdPCu+eX0y88tvE49jkS6vfmKa3GmOgpEcEZq8op0YcWyyEOk1QQ1PQNrtQCu3nr5N2hHdBmA7BOJ4zSlHEP/1rjH6wOHilL0CmVuZHN0cmVhbQplbmRvYmoKMjggMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA4MCA+PgpzdHJlYW0KeJxFjLsNwDAIRHumYAR+JmafKJWzfxsgStxwT7p7uDoSMlPeYYaHBJ4MLIZT8QaZo2A1uEZSjZ3so7BuX3WB5npTq/X3BypPdnZxPc3LGfQKZW5kc3RyZWFtCmVuZG9iagoyOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIzNiA+PgpzdHJlYW0KeJxNUEtuRCEM23OKXOBJJCEBzkPVVef+27HDVO0qhhh/SA/pslUe61NidYns8qVNl8oyeRWo5U/b/1EMAm7/0MhBtLeMnWLmEtbFwiQ85TQjGyfXLB+PO08bZoXGxI3jnS4ZYJ8WATVblc2BOW06N0C6kBq3qrPeZFAMIupCzQeTLpyn0ZeIOZ6oYEp3JrWQG1w+1aEDcVq9Crlji5NvxBxZocBh0Exx1l8B1qjJslnIIEmGIc59o3uUCo2oynkrFcIPk6ER9YbVoAaVuYWiqeWS/B3aAjAFtox16QxKgaoAwd8qp32/ASSNXVMKZW5kc3RyZWFtCmVuZG9iagozMCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDMzMiA+PgpzdHJlYW0KeJwtUjmOJDEMy/0KfmAA6/Lxnh5M1Pv/dElVBQWqbMs85HLDRCV+LJDbUWvi10ZmoMLwr6vMhe9I28g6iGvIRVzJlsJnRCzkMcQ8xILv2/gZHvmszMm
zB8Yv2fcZVuypCctCxosztMMqjsMqyLFg6yKqe3hTpMOpJNjji/8+xXMXgha+I2jAL/nnqyN4vqRF2j1m27RbD5ZpR5UUloPtac7L5EvrLFfH4/kg2d4VO0JqV4CiMHfGeS6OMm1lRGthZ4OkxsX25tiPpQRd6MZlpDgC+ZkqwgNKmsxsoiD+yOkhpzIQpq7pSie3URV36slcs7m8nUkyW/dFis0UzuvCmfV3mDKrzTt5lhOlTkX4GXu2BA2d4+rZa5mFRrc5wSslfDZ2enLyvZpZD8mpSEgV07oKTqPIFEvYlviaiprS1Mvw35f3GX//ATPifAEKZW5kc3RyZWFtCmVuZG9iagozMSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDE3ID4+CnN0cmVhbQp4nDM2tFAwgMMUQy4AGpQC7AplbmRzdHJlYW0KZW5kb2JqCjMyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggODcgPj4Kc3RyZWFtCnicNU25EcAwCOuZghHMo9jsk0vl7N8G7LhBOn0glBtr5AGC4Z1vIfimLxmEdQhPKrslOmyhhrMKkonhVzZ4Va6K9rWSiexspjHYoGX60c63Sc8Hpd4bmAplbmRzdHJlYW0KZW5kb2JqCjMzIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTM4ID4+CnN0cmVhbQp4nD2PQQ4DMQgD73mFPxApdkJY3rNVT9v/X0ua3V7QCIwxFkJDb6hqDpuCDceLpUuo1vApiolKDsiZYA6lpNIdZ5F6YjgY3B60G87isen6EbuSVn3Q5ka6JWiCR+xTadyWcRPEAzUF6inqXKO8ELmfqVfYNJLdtLKSazim373nqev/01XeX1/fLowKZW5kc3RyZWFtCmVuZG9iagozNCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIxMCA+PgpzdHJlYW0KeJw1UMsNQzEIu2cKFqgUAoFknla9df9rbdA7YRH/QljIlAh5qcnOKelLPjpMD7Yuv7EiC611JezKmiCeK++hmbKx0djiYHAaJl6AFjdg6GmNGjV04YKmLpVCgcUl8Jl8dXvovk8ZeGoZcnYEEUPJYAlquhZNWLQ8n5BOAeL/fsPuLeShkvPKnhv5G5zt8DuzbuEnanYi0XIVMtSzNMcYCBNFHjx5RaZw4rPWd9U0EtRmC06WAa5OP4wOAGAiXlmA7K5EOUvSjqWfb7zH9w9AAFO0CmVuZHN0cmVhbQplbmRvYmoKMTUgMCBvYmoKPDwgL0Jhc2VGb250IC9EZWphVnVTYW5zIC9DaGFyUHJvY3MgMTYgMCBSCi9FbmNvZGluZyA8PAovRGlmZmVyZW5jZXMgWyAzMiAvc3BhY2UgNDggL3plcm8gL29uZSA2NSAvQSA2NyAvQyA3MCAvRiA3MyAvSSA4MiAvUiA5NyAvYSAxMDEgL2UgMTA4Ci9sIC9tIC9uIC9vIC9wIDExNSAvcyAxMjAgL3ggL3kgXQovVHlwZSAvRW5jb2RpbmcgPj4KL0ZpcnN0Q2hhciAwIC9Gb250QkJveCBbIC0xMDIxIC00NjMgMTc5NCAxMjMzIF0gL0ZvbnREZXNjcmlwdG9yIDE0IDAgUgovRm9udE1hdHJpeCBbIDAuMDAxIDAgMCAwLjAwMSAwIDAgXSAvTGFzdENoYXIgMjU1IC9OYW1lIC9EZWphVnVTYW5zCi9TdWJ0eXBlIC9UeXBlMyAvVHlwZSAvRm9udCAvV2lkdGhzIDEzIDAgUiA+PgplbmRvYmoKMTQgMCBvYmoKPDwgL0FzY2VudCA5MjkgL0NhcEhlaWdodCAwIC9EZXNjZW50IC0yMzYgL0ZsYWdzIDMyCi9Gb250QkJveCBbIC0xMDIxIC00NjMgMTc5NCAxMjMzIF0gL0Z
vbnROYW1lIC9EZWphVnVTYW5zIC9JdGFsaWNBbmdsZSAwCi9NYXhXaWR0aCAxMzQyIC9TdGVtViAwIC9UeXBlIC9Gb250RGVzY3JpcHRvciAvWEhlaWdodCAwID4+CmVuZG9iagoxMyAwIG9iagpbIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwCjYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgMzE4IDQwMSA0NjAgODM4IDYzNgo5NTAgNzgwIDI3NSAzOTAgMzkwIDUwMCA4MzggMzE4IDM2MSAzMTggMzM3IDYzNiA2MzYgNjM2IDYzNiA2MzYgNjM2IDYzNiA2MzYKNjM2IDYzNiAzMzcgMzM3IDgzOCA4MzggODM4IDUzMSAxMDAwIDY4NCA2ODYgNjk4IDc3MCA2MzIgNTc1IDc3NSA3NTIgMjk1CjI5NSA2NTYgNTU3IDg2MyA3NDggNzg3IDYwMyA3ODcgNjk1IDYzNSA2MTEgNzMyIDY4NCA5ODkgNjg1IDYxMSA2ODUgMzkwIDMzNwozOTAgODM4IDUwMCA1MDAgNjEzIDYzNSA1NTAgNjM1IDYxNSAzNTIgNjM1IDYzNCAyNzggMjc4IDU3OSAyNzggOTc0IDYzNCA2MTIKNjM1IDYzNSA0MTEgNTIxIDM5MiA2MzQgNTkyIDgxOCA1OTIgNTkyIDUyNSA2MzYgMzM3IDYzNiA4MzggNjAwIDYzNiA2MDAgMzE4CjM1MiA1MTggMTAwMCA1MDAgNTAwIDUwMCAxMzQyIDYzNSA0MDAgMTA3MCA2MDAgNjg1IDYwMCA2MDAgMzE4IDMxOCA1MTggNTE4CjU5MCA1MDAgMTAwMCA1MDAgMTAwMCA1MjEgNDAwIDEwMjMgNjAwIDUyNSA2MTEgMzE4IDQwMSA2MzYgNjM2IDYzNiA2MzYgMzM3CjUwMCA1MDAgMTAwMCA0NzEgNjEyIDgzOCAzNjEgMTAwMCA1MDAgNTAwIDgzOCA0MDEgNDAxIDUwMCA2MzYgNjM2IDMxOCA1MDAKNDAxIDQ3MSA2MTIgOTY5IDk2OSA5NjkgNTMxIDY4NCA2ODQgNjg0IDY4NCA2ODQgNjg0IDk3NCA2OTggNjMyIDYzMiA2MzIgNjMyCjI5NSAyOTUgMjk1IDI5NSA3NzUgNzQ4IDc4NyA3ODcgNzg3IDc4NyA3ODcgODM4IDc4NyA3MzIgNzMyIDczMiA3MzIgNjExIDYwNQo2MzAgNjEzIDYxMyA2MTMgNjEzIDYxMyA2MTMgOTgyIDU1MCA2MTUgNjE1IDYxNSA2MTUgMjc4IDI3OCAyNzggMjc4IDYxMiA2MzQKNjEyIDYxMiA2MTIgNjEyIDYxMiA4MzggNjEyIDYzNCA2MzQgNjM0IDYzNCA1OTIgNjM1IDU5MiBdCmVuZG9iagoxNiAwIG9iago8PCAvQSAxNyAwIFIgL0MgMTggMCBSIC9GIDE5IDAgUiAvSSAyMCAwIFIgL1IgMjEgMCBSIC9hIDIyIDAgUiAvZSAyMyAwIFIKL2wgMjQgMCBSIC9tIDI1IDAgUiAvbiAyNiAwIFIgL28gMjcgMCBSIC9vbmUgMjggMCBSIC9wIDI5IDAgUiAvcyAzMCAwIFIKL3NwYWNlIDMxIDAgUiAveCAzMiAwIFIgL3kgMzMgMCBSIC96ZXJvIDM0IDAgUiA+PgplbmRvYmoKMyAwIG9iago8PCAvRjEgMTUgMCBSID4+CmVuZG9iago0IDAgb2JqCjw8IC9BMSA8PCAvQ0EgMCAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+Ci9BMiA8PCAvQ0EgMSAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+ID4+CmVuZG9iago1IDA
gb2JqCjw8ID4+CmVuZG9iago2IDAgb2JqCjw8ID4+CmVuZG9iago3IDAgb2JqCjw8IC9JMSAxMiAwIFIgPj4KZW5kb2JqCjEyIDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDY3MCAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgNzEgL0xlbmd0aCAzNSAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCA2NzAgPj4Kc3RyZWFtCnic7P13syvJkh8IugiRCsAR91xR4r1+FM3u4awY221+PX47mnFtzXbauNaSfLrUVUcASBER7r5/RALnXFFPtHE4a7SKKsPFgUhkZkS4+PnP3fE//sf/CD+Nn8ZP46fx0/hp/DT+Zxn0f/YJ/DR+Gj+Nn8ZP46fx0/jvOX5S7T+Nn8ZP46fx0/hp/E81flLtP42fxk/jp/HT+Gn8TzXc0z/+4eLf/GbzpRkAAGD9DwHBTh9AAEA8/4GAj18+PzUAMLDHbz194/TsM2+f3jM8nUD9pEdoCXu2gSwiOMBitoiCGSF4RE/Q3347vPlVPYCofn//Zp7m6eFYUgYzfHpy9ZcRAImYiNkHH0Ns2ti2TWh8CI6YCQmRANBMzcxMwcwQDcHqHTHA0yUYghoYgBmYoYpKLiXnknPOOaWUUsopm4iJmimsxzQibpp4c3N9dbmNwTHzvJieDvv9D9/+/pvfAht6Q4foiBAJ0QRMQIuZIRISIzIAAtRjFwABUEMDNDMGcAAeMOB6xgIgiGKgYGiAAIxAYIyAYIagaBlMDLXOOCACkAGDEQCBIQACEoEBqEEBy4YKiGhOzenFcP2Lm3+7TofIf/pP/2kcx4+mGRGZ0DnnvGNmJgLEOlNPFoetKwKBCWP0L148e/nq+c3zm912MC2S5nmcxnHaH5b7/fTudn+clqIqokWKqqiaiEgpkovkYqpmP7byPh6I+Dd/8zeXl5f1z+++++5v//Zv/8Tv/lkDEQwAAUOIbew2w7Zru4fDw+3du1ySqPzYZkGsOxLrcvqx419eXv6H//Af6ofN7G//X//v+7s7A7TTZkYwQiRGQiAEQmTE4LgNPjr2TJ7REZiZimaRJKpAhnWAY3RMnskRAYKaJZE5l2kuKYuZMVNsvI+OAilJsiQoyoqERIyGJqiZNKEpIxCAAaqhAKhaMRBiIMYvX/777XBTL+phf/9f/v5//1SOIHw68MN/V8l1umFmn/sWPu7vH7n5f/Dd06jTsj7W9dy23f/2v/4NEddP/P7bv3/YvyU2IKi/SQRIxmRERnQ6RQBAMDURqKsYAZCMCJlQDUQhLZaSlgSlGBgQYvDoPIpWwYNmyA4cg3fmGBlJCoyLLNmymQKe5DuamBUDBTSsp28GZmCIxkTBUwz/6ouff/nsVb0KZt1e7NdrNAMDQESk+hciAGKVX4SIhIhQF895JtaXzlIAAMBUpJRckuQMOUvKRdUQkRnZERioWhEpkokde8ckoGW/X46TbLeb3cXQRCTMD+O7w3SfxYjDbni26S6iC44ZEUTw979lkfUnX3z1l5c3PzMzMHu6utaTwo+fw1lVnoUlwuP7p0WEn1kx+PjC+jH7/Junu4H4+Do+eRcRCeGHb3/zm1/+4/lbH6j232y++s/P/++rjKgTg1jleD3i6airQKmP9pHKX29KffpURj9R7Z8VQ3ZWj4BWLwMUrCHYMVw6u3I6IASkVOxYCpgxQkPQELxQPat2Nf3+7u3h/uH++3fLccbT1XxwU4nQEXvvY2i7buj7rQ7mDBtzDskhMhEiAKqiqZqogQKRIhqAKVRljgYGoAgKIAZmaArFJJec0rJM8zxP0zgdj8fpMJYlaU4qYlpUFQCI3Ha7iZE
uLhpmCo6W9Hif3r9/9w//9P/FoNgANcQNO0cMpMlksZLABMgxeyRfjRDQbLaAJcBVMRs4swawReoACMwMMmJCKAZigAYE4ME8GaMxmqAJ2gyWjATQEAmADVnNAXjQVcEjMqEhiMJsOhopEoFFLU35+tkvzqpdVf/2b//2/fv3T2caAYjQOQohNE303nvv1x3/ZMFUfWWAgNREvrjod1vqmquXz9sXLy6tzMuI+9vlPcgyLWnZv7998/ZuP+eSck5pKSJFNaec57SMUxqnkrOeN/EfG4j4V3/1V2fV/u7du//8n//zn/jdP3sgINKm217tnr188erZ1bPv3nz3m9/9cn/cL2mGH1Hbq2oFAIC6oj47vv7667/5m785q/Z/+ru///Z3v9cVsUM82XXOkyMkgmpG9iHs2qZvXOu58RQdmmrJZcplyiJAigwIRBA9NZ674KN3SCCmx5z3c7rfL+OURc073myaZhNdz8WVycYFF3GCjtg7UoJMZeI8khWP5gAVsBgVgyyWDDJ7ZI+Xuy/Oqv04Hv7L3/9/PitJEPFjjY9PZPETUXDSuI9vfjgl9UMfvHZ+Bc8v/IiSP2t1Oy9kMAC72F3+3/6X/8dZtb9++5sf3vxXcoZcZQoQm3PmvTmnRIYIqtX4AxErGUoxkbqJzDlkJjXIGcZJx6Muk6XFQMEh9D03DRXVnK0IqqEPEAM00RqHDFwy3O/LYdFJrRgAEACCombTxaAArsraqvVgROaYh84N7aYbzqqdSIftEVHAdJXh1TsyNLOq14mJmJiJiYiRiJAAAQwUqkxGpKpWzFStlFJSTss8UzbRRVJa5iJGRCEwMwMgGKikIin46GPjuZhM+e5wHPPm4nm/hU2PTOOCv7nP3y6gzg/N1q4ufd9Y4wMilozf/p5E1lm8vPnZz//y/2mqpvpEta/K7qnqg8fngIgISHD6BNmjUwRPnz8ulw81qgEYfla1r8vQsB4BH3UZnvQuIzAhAvyoaoePN8STC7PzcsbVDDMDxGrxGqAhIQCanbz8uko/1eL1zQ9fRFg9QVOs2r2uZDNEE7OjgKrNRQMYqYhaFjUwBmMwh0aLfHH+AbXpbj/vj5LL2dw1AELEurIcuxBcE5q2abq2a9u+a7u+6domNOw9MRPRyRUyVVVdr3M1c07vGKipmZjp6sajGaoJsHGgQB49cHDkmJBmOMylAJTzPakTr1JySotjU1DjD5wMRGAEZ9xw6H2MzITLIYkIFAIjYmAGRDUwUwA96UMFUzijDAAAiuujgIqhwYpKEADi6sJV62qVMKv5Dni+9NPEPoIWBnb+QYVq6uDjFXx+1B9RE9GcCxFVCIUZT/va1v1QBQOyY7652v7Fl8/++qurf/eif96WjT2IpRkWZUvBz327XFlWbbp4mOZpSUsuxUCRUi7zcT7ePxzobj4e0zKbKv6RU/wfPRCI0bWhu9hcvHr26qsvvmpio1K++f73b98ntc+bI2fVbmbVd/+Tfmw1l8weZ95w9f1BFQQNDZZcDrCUUhaP0ZF3CGaimoqkYgKoIAaGYIuj5MlaMwMiUFQ1I0IfXFDIuSCBkQEZOiBHpMzAZtUpNCDDYKgK2QARFAAUUAAFULG6dAxIf9Kk4VmLP2KDuMqT+v76yvrHj96zJ7L1dNdOf5+l7R87k09mBJ88roMdsieAijApoDkDI0M0RAWoWqbeBisF8gI5Q8mgZgB21v1FQARFqE5jFc0OqXWshpl1yVDMXEAXwDlAMhMBhcigHkktK4iaFtACWACqXAM0g+rLqIGqmRqq6WcAMAIzqLJwvf31iSKgGa7b+iSTEQgBCcGMYHUh0cxUJC35eJjubo/jcSlFp2O5u5sfHqb9cVY1713T+LYN3jsmQgJ2sLvEofdoVgqZERI6h94BooimIotoBqz4YAUGrJ7YR9NhKialgrWn6T6p0BOkgOs/53k+v1OV4QnsPqnzVcGfF9TZJjj9/IpfPPr6j4vsqYmwHvkEa59
+ERiRER1/MBkfqfaP1t0ZBrIn1qmdtXsV7lgxXNFV8hMqnnfDI7b+41LHznsE1+MbgqFVvWICNgksoHtQVDNRWG+7oSmYotnN8ij7zHS828/jJLmY2QkhQWLPzrsQfRND18S+7fq2H7quaZomxOC9Z8fGDERIq8CEkxFaT+e0DsxUTUVNTdWKSt1ksKo8QwYmQuc5ON9EIkZAK1LmxQTP7lX9EZFSck7JmaHZk/lBQAZwiAFdS3HgtnOO0SyXbCLVwAdiA7KqyVcs1VWxbdUKRzoJN0VUUAVTAzCkqobPuOyjPHyEOmw1Zk4zWB9w/dzZ0DQzA0M1NCP8UygcZiBiAIKYEYnZISIAnX8TTraEQ+iC++Jy81dfXv+75/2/2lgP927aZ1XOmk2Tw9w1hkyO+r65P47HOU1FM5CwW5Zy3E9ALqecc8a0/KmI/P+4gQTkyLeh23W7ZxfPXt288t4XyfM83T/cpqzwiZVc9ToR/WE0/tNxsvUA6+Z9FG/VFDdRNFDQomIp48QYGJ1DBDOzYlDUxFBtjd8shMkRASESO0AyRSPGGL2tS6HiWWoE6JCFSRnr5lFFB+AQRCAAAEBFm1AMxEwAFcGA0D7vID9KmE9u6lnoPC6o0+NqAKxOyucPdPKOHtf9R4c/ndB6qD86BfhZKUgO2aMqgmCN+2mFwCsKblojeGpgZpJBEuYFa3xv3YWmqigGSIhEa6TNAAAI0CEioWNE1AzGHoIHdkAAmhUVAiMCkGKqkABAEZBiICtaYAiGYLQqN0N8osGfDCOAulCh6gXA6nCcQh5mpmqEZ6gboToWaqoKKlpKkWVe9g/T+7cP33zz/uFhNvDzJHe34/39eLcfwaBpYtc1fRdjdN67GLhtvXfNpgckTskBOB/Ue3YOEEqRJZelaAZiZsdEK1hgenJunkwHKFUE9uRvnkLTCCdj5CPNWzUzwQmJx1VVnHT8+Y1H5Y0nfY6r1LP6AoF9gAycv1WP+fQgTywDh8AA/g+pdjydi61XYh/tiNM/J1sDWEsoSzPd98e3oJJdXGKX4iDsBFDIGXklWk9nvVuGZzva6iWtmqPeZFyjOucthQYgwALr0nkS0cEa6C5PF5lBmlJessqqfYiYQwixa5rON52Pje+a2DVNG2LjyZMhFTUraus9MyBAXe8DMRFh1XdmJiomputmQzWzOo9PjHir984QEclRaBsw0JR1SQCqKit8Wk0wszUqrI+WXV1l5owCcEu+pzBg7CCwohKDm9DSKERVoxsSsncOHAODQFlEs4gUIEVv1Z7GukGrJiAFBCBb/z/ZjbTO0imAcjofQzubYGsk7TRpZiuwb6sBCdUy+uOjhsqKEhXmDADeO1yN6tUvQoNtcC8G/6837t+0cj1+7775HgOAJyTPELm0QdqNa6HxqCEwbDbNMesh67HYVACOy5yBQkT2iPQIjP3/y0BEZHbRN4EDg5NFlzH1zfCzL35+d/f+h7ff23Gfcvrki494L/xJquX0ewAAQHVr1eAIANVQKIA+4semZqKWBGYE4rpVQQHMUB+DZqBqYDouQiwBiB0IGgCGGNghs0+SxWQuGYTIEIkYnBQRNSmqaMSqhOgQFKBuZjQwNZC6mdRWgsqnd+Djm2KGf2iG8SNx9uMf/NBr/4z7/Ue89g8PVj34z3wDEZCQyRE5QjVVdkYgItUzqcZudTLVFFTABEwQAIigImZQ0ZcqWm1VW8VgnopHCR7Ym3fGBOyBGQkBFGowlRGM0AOYrs6m1Smum7selInYEXsFVAAKDipc8GTYKr7w5MudVdQJiV0j2GoGpqZY4/BkKqWUcVwOh+n27nB3e7y/nd6+Pfzw3fv9cUEOzMxMqZSUChEBkpotOWUpCEQI3mEpJkJNE5gDUT9sYttGx2gmpWRRUQNG7yk69IyEVaJ9snGCp6GhlZ1QbyrSGT2vev1JHGy9xJOGPlt8+Oi+n9z+x7c+8PyrWQAnBY8nxsFp5T0CAPhEuz9R/xW
QB+APHapPvPbVVni07Fer9iRp4XF5GoHFMm/G95e3v7t8899A0hiGsb8c++vZx0Su+K7EXnwUHwAYkKxuvUfdsKLCVs1BU7BMUkgFzczQiBVIkQxXaNiQANXO57QqqCdq1azkIkUQEImImEMMTdt0m67fhNhxbH2Mvgk+MDEZaBYVVVYDqPINzAxPP8mIQASAaqpSw+4KuipAAMJHJs7JZKmo9krZQxc8GpZ+kXmRknJKdpZBCKbVtMWPExbIwBtFci36FkMHsdPGQSRq2LGWo4mWahIROuLgnAvBe1QoqZRUSkY1OZ1L3cxoiICrXjeqNuuTu7labSvicl715z8/eGX12tcZNQQgRIfIf9htf1xFpiCmGQsR1tXg0CHyGa5yYJfR/Wzwf9Ha1zhu7h/obm9OLBD4Dv3W8WXjrzh6CgEKew7ZhVExJuGp6DEd56KGYqBPTx0/PJv/80aFWrzzbWiiiwRcljIdpu3l5uWzF99dPd8Ou5xSLtnOgbDTV09O25/ntZ+8EUA0AiBCZmJmZjYwK1DDjAqmamJIagiGRRHrdkVY6VFGWAFVA4Alq8/KwTl2TIqE7FsFRg6Q5jlPWYWkOGVkJCQCFlMtCgSmYEjIBGygukotFTNdAwZmUEw/ucwTL+ijFz/+2BMtjR+/AJ8E5uGTNVH3x+p127rTV8D/o0n5zGlUvX5Sch8fHQmZCZHRSKleYwHLJWmNohEBEtXjmRgomtabVPWDiZ0Mb1thQzIUBVVbkjJW74a8A/TVi68fQ0SooD9ajX6s3BsGUjQC0DWGjMBETaDYCKCqgVsDgU/vkNWTwyqQrYrRk4u63hARUZWcz8ESNiORsszp7m589+7ww/d3794dDg/5/mG+u91Pc0LmftM+u940XehzRuS2CSE6dmaAKliKpCT+fnZuajtsmthumqHHpmmYQIoUyUVEFTx576IjR1WZnshdT0d02EcuokVNa8gIzvHtJ+y/R4G/vlIDDPDoT+NJr5+QefzwP3y8N6tqh8dXHm2Hs32wOul2NgzOd7Y6Ux/5U5+q9lMo/fGLHyzwR+cdjU03892Xr//pxfd/d/n930FexrA5NLtjuzuEdvRxai/n/tm8e7lsn4vvhKsOrp74ul2sSgs01hzyGNIhLA8uz2CmyOKazLFwKOyFnSKboYIZmRmaVuNZPz5VBEQkJOc9hya0fez60PauaZEjkFMgESjZABRIAY0QnCIRVgGHaGSAZoRKeHLA4Sl8h4BIiApm5xA3PrlVBjVIv0a+zZz3Tdcu87zMs+aqzwEAVdCEPIcuNnN+ZMgDGQSjBlyLrrEQpWmwD+galoZJBVTHo6YMBg7ZuYi+pdigIwJDFVbxpZSSU8kiWTSbZQVTdWqgStX1MkA+CTxUW1c0nPSFrldmgEAnyfiIsJ0hllM4gBiR/tSkSluReV1SBkAiJESm1Y1EwoBw0/hfdPRCDsP9fSj3pAcMCoHQRw4PXZt8C9k1TJAsG0BwaECjGUy6zPNxf3i4uz/cP8zjMadsP043+x81Ht1ABCBARooudE3XxMazM9WyLChD6+Ju2FxfXI/z8TiPBqd1vn4Va+7A+WhwMi3/6M9XF4EQHVPwPobI3gFCLmKwWBY9ibwauSGsOJ6CWqU+wQmsqh8iRAEC9m2/2W475wEIRXEpZuSVuTBkTFlUs3ogq4EiRTUAMRSs5A9FUxCsABOulG0EEFEtpx1zuuA/z575QK/j+SV7sqk/Gvbkn+qHno3SOg32Oce9+nR183yO2/j463U4572PiGvyiYrmbDmVnEEEEAwRiBUJGMkEraAVK0mA0AGammQpaqrATESEgmRa44i2xuk9gat5MCAKqDUPAhlFcRZLs86LZDEFNEVEdEzgAQ0UTBglEEaPXWSoZAi0U9jydFWGmFcvbXW3VkfgzD5TVckyTfM0zuO0zHPOGVKCtORpyg/388P99PAwLos63/jQPn8ZpOQl5WFovvriOkY3j1POIobOu9B
67xvvYimSU/LOIbvDcXk4zDdus7vaMHkwKJJTXnIpYkDoPUWmat2AVjDkw7kPDL23TFAUioJYTdk6XQMArsp1vepVia14oJ3d9LMRsCr1J2424qPmrjT0E6MJqSqWJ5r98QtVrz+u3JPSsyorP/zOJ6r9qVP+uXGOuSM4LW2ZLvfff/nD37/49m+HH/7R8jL5TR/6Q+jb0B5Ce+yuj5sXnCdDTgNq457ErRDNEIzM2MTLFNNhGG/78V0c33GeFEE45jgkPyTfptAW3xSOBYMAFgCtVubnNiXW4C2zD03oNqEfQtf70LD3gE4NUVWKIJoh1Yg+IqhaTeNRNURjMMbVEq6JAtXSq3pe63QQARoDqdaYfL1LJ9xJ1XQlZJgaM/sYfQzsHJWsFQVHUiNVcuSCC0tJjxKLAYJRBG7ARw3RmohDyy0TtASZdE0IAQNHhBzQtRA6C8FqopwZlEw5WV4wz1hmkxkyG5CqVT/2TI88bdMqMtdXbQXnV6TvcZHUT1Yg+FGbEAADMn1sQH5+KT0+FQUrQpjr/ccqhRA80uDwRUNfR7uWQ3M4+HJAmzApBATnOMxRwJvLbjCPsxVAylaSMUqRtCzjcXzYHx/u5+MhL7NJeWJLr5MFTx2sR1v8/6BRjfU1HIWAjji60IV2aPouNIEIJZVpr6NHLz3T9Xa4vW3e08mQhUcD/jRFH5zyH9XuuGa4gWOMIfRtu9lsfQhiOqc0TlNKRdRETEURsbqV7NA0i+R6kLqRV+S24kKIyNy07e5i13WBGI/Tsp+SYMmInkwURWfNWoOKK3agYAJWrEIIlYd6iu7W5YlqqAUk6xOz94Px1Hf/lJFwQqJOavxDx/2s3T884gfPnvpRlWSEUINX8HT1VOqD9x6JpBQROcc2Tgf7zMJy7ENo1tAuoqDkTCKQFsxppcUQIzExIRlBwZK1JAACMDS1kkEE1MAIHROokQEjAIN3GLyLoWljEyM7BoMMUBwbEzJwWeCoxqVgzkkVCMABAplijd8LWGbIDqWN0ngBMoEVqsZHC97A1JIVKxkQHflKvoCTmCBEVNG05P3D8f37h7u748N+GkcZR51nWWaZp5JmyTkTu82uHfp2twuEun84xOB2267vom27ZUnHOQGSb0LXDV03SJFlWUS0FDhO0zildojLDCVXMZ9TXrKIGjJ55yLhI1X5JOoeB6MFtpozwmilQrNnJ3tdQfYB0wIf99SjLid81PGPvvjj42mBrg9Yoc+VG/cIyOPp1x+X4OmLJy95NYA/WsOfodE9sU0/ePXRbAFAsyZPzx6+efnmH59//182734J86GImBqWxeUxLlE44njH+9dYkhiZgbgI7B6xa0RSafNxmO92+292D99s777vjm9peTArKcQch9xdprjJoU9hSM2w+GFx/UIhg8vA2VgA1QCfWseIRM45CrGJ7RD6rWtbDhHZGdAKf2tlwDsEqjEsAAOFTMB1gxIEB4TVsJNz7NcACKnmYKOhEdZJRTOqbrrW4wuoohqpag2G1RVERMzs1ogyETN5PCW+rEz8x1UGFJECcDAXIHjrIm9a7AOwAoqD4qdjniaFmt/O5pywU+cxBqqEc6tOemFNkEddjjKPeTrqssDKxFGD9bYwAuGJmrPal7hSLhAN6BwmOi2T1WsHQqgZVEho9C9QjVYvX3MxZnVMhIzQO34W8YVPzzEPeuQyMmSuZFszlALLhOUOc3DQNx1uyCO4wzxZ1jzlZX9cDvs87nU+YJ49CjlQcFJ5EgBgYCbwP86PP8sHgKoKkRrf7NrtRX9x0W02MTSknPey30/ynu6dX6Znzl4HCswLqJoRnDM11lDon+W9IgAREiEDBHZD215eXj2/edF2XTGdluVwHJecRWCe0ziOIgURuq7p+2aZx8PhXkpebdYTWWtdFQTE6D33Xbx6tvXR394f9P642EwmjiIXKTlJLmqKDMg1pkdQVG3lZ6AhIdo5nqVYqVdaTMuf66afhj1ZkKsRcDK
M4CNF/+Hnzc4Xd4LsTlgLgWNiciJapFJnzDmOMfb94L2f5nkaRzMrJX9yPh9sD3beuSCqKmImakVB1UAMi6AkqJncgIagYIaKRSxlA8JSKuiFKmCKiqY1/YAoevLObdu47drdZrMdhr5pG+/ICkH2rgQGz6yFxlGPx7w/znNOBmJgUKWioAAWhBl1RH1A2CONYgIrdP8BIK82T/M8pf19dj5uLjZN47xHrN4XGhiUIsuSxuNy2C/7fXq4Tw8P5XCUnM2MYmjbjgmVGZ33w+CeXQ/eYeNxWcrd/Tgv0nctUHCsKZdxP2lZeXslFyLnPXvvaCkP99Pvf/s2uG3bhpSWlCfRAkjM3rNfKSU1avCJqkNTlFIJSAxV+hlUSV9F/QmYWZ1vOi1/WiGWk5t+MsDxlG70aA48tb8f33mqeD/wzk9PPxKrjwdB+Eivw+dU+4eiAh899XMCO5sGK7vp7uX7X71480/b97+Mh9epKBr6Mptm1AWTJ/JuZHbvBePiNzluU3dtsQVmAiMTX5YmHS72P1w9/P763X+9fP+r/u6HMN4VWTLz0g+puyh5X5ptCX2OmyVvprg7xgviDjGaeTFXMTuQ8vSU2TkgcKFzsfNNxz4SMxLXoF11sNVAbWUXV7NbzHIBIiCCaiYTMhrIydS3muq4sioq06ziTlTZJ4QKUMwUKlX8RMGEs3dc4QTnmZ0UqfAzIBmgmorKU38BHWIkDMDegofWQx9w0+AQLZAGRDa3v+PpYIsCgnmy4CQ4aAI2DXpXk0cRKz6fKU+6HMt4tLHV8WjjZHnRXEBljRoAAOiJPIkrS4oI1gQ5WL2p024GWL9VUSmoer0ii39ofLIKKy+v4nXFFScOmdhw6+hFYzcuX2LpcHaUGJUJxaGRoonlBPMBFgfQO3Vtv1OCZVGci40zTBPmmaEEBmtcxFZKEFlLB1VmlpQkOauIqpyip59uov+e46TYgcgFH/p2c7G5uuy2u9h2jI3NYV7IJh1VPHTsbswuSDpGUcwr7bR6ilgzN+wUAUZ7PPwfOAFGcIwMGILru/Ziu7t5drPZXhhjKmWc51yKKhyP4/3d/XE8prxshv7qajceHghkHI9pmfWUnXi6qgp9KbF6D33n275VkGxyFHFamIzAQ+YiYFLQGXsEtLrerKgxICMCIJFIpYNb5Zuqggliran0OOzkEcKawXFGK89Xf07U/WBOT3q9ut1rnk9940MpD3AiGuHZfEEw7zB6bIJzzMc5H6cEUACsjXEYhs3uwofI+33V60XyCnH8yKIicsxeNRmYWlEtBtXSITMqgpLXzDHVldNmBqIIAEJmBqprBgIgEiETRefaGNom7rr2ou8uh+3lsNl1uz62bMpYgsveaWAGxWWxccr7cZqXRTWrCQCKgihmgAQ4oxyhvMkZUipLWUQfHcvTULNxXPYP45sfDj60RWl30W42YY31GRhozmWe07KUnFQFTKkUyEmLIDHFJm6GJgRD1JSESL3HJjjpo4k+HKeUhMl5XwGAMo5LSdlEiBDM2Hn2gQm8d+O4LPN0eYEXl32SOZVZTBHJsfPsCKFWJnk0754OUzStnAN7UtmF6Yx7wymyXlP3T5R5fEKjgzPUU5+dVD6cN6jBkzo1n7jxT63Oj9/6kVc+fvMzXntV408lXHXSEK3Csl7kMu9fPXz79eu/v3n7z266M82M7NDYiqpm09HkCOWADCzT4dbffs/DK958oeSBvbPc5sPlw3dXd7+5efNP12//a3/3+7h/A9NRpAi7HLvigroF+egASYS0oGZTKQbZFcDO1JsQACEh5OXJ6aP3ERXIBWCvdQfXmglIWg3sU/aalCeqiVDFchZmYsSijEZaSSWGZDWVE01ZjUVR1vxsNRUzXfdW9XxxRRNObP8VUDFAJGb27JyTgkiVMKqmolqk2FOtSIiByBs7jUydo8HTEHATS+Nt21DneHwf8oHe7SFL8UiNw95DF6ANFrwRATORIzDSoOIhN7z0Yd7
S8cjHQ5lGnSddZksLlKRSCqzIaM18t7N9WheimunK10G0FdGvPCdgMDZgUFLFf4EfbLpW+1Eqyqze86XTL6M987nnHIO4QARquNaGsGxQVOckcDR7a0Ie2QfkDC7nWJYO87bFctFx8HnZQs6lWBErokVUikqWZTwux2Na5rwsf1atun/RMHgMP7IPzdBtL4brq+2znfcbtihjyFMohyjHjvPgjDg0xs9k3LFmgaI1k9GUHllcp4Gn2MofuQQidIyBuY2ha5qua7uu3+0u2s3AwZ+y1eB4ON6+f/fmzes3b990XXt9fTm2nqG8M72bp5WWv1plAGhqKiaiuehcyhHUhpaKNYeSj5qnSUwImK1gKgWKsFRjG0zBxMgbETETEGQoIkVFJIsamiHDWiPyc7f0w3qZhh/o9pr+sWKiayzUEI3o5OzUvJhKKjt5ZVA9sBUeshNo7xzGwBe9vxxC33hH9O37aRG0BRzZduivri7azSVwEMNUyrxMVJbVdK4x209nx1ZxVMsmFilgtcyDMhkhFBARU7HqvtPZ0TCTcqLzIFItxMXcxWbbt5u+G7q2i7yJ7jrwDdOl931okRwReFccCQOYQWHYeLloci5JStJSVCwLJqVJbZQibJdOeZkSHOYyH3DRFWH5wGs/Hpb7++nN2wPSPM5aykWMuxj9StszTSkfx1kEQmw26LyPRLPjeV5UAZsWYwshsAocjmNKy1uHQ98QKDsgFCnzNKIG5xhMSl6mkmYpi2OmqmmJfRj6ps3LfDweb2/d7n1BP2bIhojkmR0zA4LWynz2mQl5AqivkE0lixIB4xm4BMTVCUSqHOvVuzlp2I9mGXG9W48G3tmifHTjP+eXf7zYf2R8qvU/jbWjnUXER988FS9pynJ9ePPq9jevXv/Txe1vbdmTZIcaDIIUMyxGo7mj8Xt04HQ67n3z3u/f+8Mte88oQ7rfHb5//uafbt78w/UP/zi8+xUd73Q+JrGCnJuhuABSOC9IRzCBsqjMXmayYgDZ60gA4k3YiIEdPCHYIAA7b4bkApI7RVUrZM2IWLkt1VfXlZlUg+0oogACaKhctGbVGhowGtlKwKpmMbvoiACwSE7LopoRzHtqIoOVUpa0zMs8ZytgYmSqhqAnv50dO2UHiGqmoqVIKtkXMvSPt54QHJIzZvNsLWPnqPfUBemiBE+dd+OrrkwCUA5ZOyetsz5QH7CLGLw5AnZGXE0QVI/SYG4pZdd3eOhoOup01PFo40EnsEW0uoW20uueOCwnL8nqQji57DVFcDV/GYwByD4XUvxjw8DAigGJYtHGWwC4YHgR4KLB2JLvW27BNKlmQCFTNcFkUNTyLHIL6igOTr3LHErpJV+w2hBC2w6FUzERLWJZrBTNWXLKac7j/T0xwwFURUoxETiTQf78i/iTBgITOe/7ptsNu4ths22bDWqnY5CDzw+xHBoZeyqDWODQg7+BfO3oWOwI53Vuj6Ezq0b3KSz8gW77zPAM0WF01HgXg298CCE0Xb+7uOw3Q2i9d8Rg83R8uBu+3bo2iGO33XWdM5I5jYf9HanpWj+U0TF773xwHIgdIBYrRyzahQgd3k98vzAmFMRaS0HEanUWZnQ1v72YA8AAjpgcAVIleKsBCpiRJ+/If6La7RFnAcRa4woe/aEz6fij/9eKnXTGPlepiogn5BVW/tqpNKpz5D03Dfct3+yam01sPKvaw2RMyVijw82w2e0uKG4FOTQltnOY9ilx1gI/7rWLaikqIlKkiEo5U+AIjMyoohciJ37z6R6ogpz4u0RARN75NsSh7XfdsOnaoQmd0wHtSvN1mXc5tV6BPbjAThm15tELYiRtnUhJkhfJRYol4xmCLss4Hh1a52lU6DkHyici2QdD1eYpT1Me55xSmiZhxu2uReQQGFdHQFUNEJ137DgGj0BMdDimJSuhmQoBIgET5FQOh9FENkNgAgYponlRhuibAColLVDvDhNVQQo4bDj6QJbLMj7cwZv
vl7gdsREDYvaOnaO1MJapVir0Z4z5yvI/+eLVl2CqzuHJSFxfPGO4Z3nx+Vn+wGU/PXuCkNeA9+e27aOX/6EvXw3S9dlntvxnvfYfHWhGpu2yv7n99as3//Ts3W92D68pj42WXqER9VlFIBnOxkfjgaMLNoXpdhrD8c4efoiUN4u/uf/Vzdt/uvr+7zZv/ps7vNXxYc4yi2Wg4pywJ2YHQjJzSqgOSzBpivURs2MsQKPzs6gVZ6EB7zA/TddHdJ4MyXtyjmpYGwGJnPe1zmjV7ohEtM5e5VOImJkwAxhpViuW0iIlV82mqkwuBtf3fjdcdG3nHC/zdHt3m5cFAXe7/ub6ktjSMt/dvn/3/h3BUjCriUgWUTglEDMzExugmeWSl5SWhYkgRPeEPFnDosAEnjAQRq4SmdpgTUNt8PrzJjC4cHh7P2MnkaXz3AfqA8UAjpAdIKkBiII6FMHiLRfzzkLE1LtlxukAh70+3KXDfVkWyflcMe9UMv6c0A4AcKaQPOLCdcXZqbz8o2j9c0ZFBYoqFUG1FmDj6CJy1wd/wW7XcMspj5pHwoJSCBKkRChYxGQ0uKNmhxKY2kZtZ+I87WK79/09t0fDSTWJZoGSNacyj2k6TOBZ0QSklGxmKo+VOAx/bJ/+ywciIlEMcWj7i377rO8uAw02NzZFufey9zI2mjornVmn2JEZyQuPL1v3vtC7rAqKZlbzmg0/UG4AaxGZP3j7I2PrMDI0bJ6MKjZDRM67ENomNAEjF+jCs34Y3EWL47JkQPNKsG0fuhgcS1YxIMfeuxh928Zh0252XTc0TSSnE6el5ZbJbb10AQ2tgApaTaEzQVE1qVwwUzV1BmpMFGJkp+ScjxJFSlYp5rHxGJifyKuTmfn4gj3BP88YxmP47FSi/WQHENTQ6SlpDM5g5UoyqRmBTYx9Gzd93PZx6NzQuMvBbztXktzvZ8LRlB2FJoS+u2jbywVCKgYUQ+hi0+W0qEyipc7Xpyq+pJLmJFJEVMW0WCmWF0uzpEVy0iImUj1zPIX7tTr6ViXKWveZmhA3w7Dth6HtAyLlEkoaMG8ZtssYC1AB6cnAoTEQ0yklVlVNigmAqFX2HLca+mN6+P5wH7xdUbBCTomsciI+sztKMTPyPk7T9O7d3gd/cblhZh+aWkDDOd92UWTJJSOCDzRsPDMgou1zmqQsShfWtnyx60qWtOSSF1BHYAxiWrAYemaIaKBZCCkQOwIwzaKqIMtY2KgcPEzH2+V74eFZ6a4AW9fE4MgzMqNDZDBQNSLVjxyRqj4JAZAIDRFOrjmt3KpzNF1rEUd4Qp//EQVvJ638RMOv5flOKF5Fy8/A3o/Knqcx+hOij5/hN/1x1f4hmq9eUj/f3dz++uX7X97sf9jND86kEemL+mKYqmqHZNQDA2sROrrD7N9vbn8nHrpp2Hp7dvtfr97+U/fml+7uu5SXsdhsPIITAkMgVAfF6eRKIkMwr9YKGWQC12IeCSeyGSshmx2ExpJ7epouRDT0IbL37BzRIyWy4su0BvVqmtZa3HZNXjMQUco5z0vOZV5mEYmxbUIMvvE+eh/7rnt+fX2528Xg7u/v0jSNRZn4cnv59RdfM8HhsNcED3cTNw07Vs0pzw+HhyWVStJz7CB4VWOq21WLSC7FhycoF2JdUkzgqDbhAM/gGYOj6Ikju5d1xeLFLSdI5KXvse+obzkG5FWgoYIVAVETBREoYsFbiJoz5URzg01TmbR6PMI06jSDLaCCqlgrPZ+zM/Gst+1DMVV9+zVf9M9TiedPG4CqqYoz7Rh3MVxumv6iddct73rqIuWjpiNBojwbHqAccRyBEkixlOx4AGgogAdsSw4xYBfGYdi0uwekQymTaCq2LHmeMrpZzPwcuQnsHTmHlJE+n3Px32lUWNB3odl1/VXbXHraURpkbPQY9T7I6HXpQHqCDeGGaXCKjC+NvlT/7Vy+XUyMmLQnbrjOBSa1JJqKFl1DiH/4XreeUiB
PGNlcrTxsAjXSgsakniH6EoNypKCNl81+P45Tnog8hLs+9tFL0Szmg2/aOPRxGNrdRX950V/sYt9gwMWVHEomCB1iA5kkWclmigiEZMCV9AAnjadFpYiKgiGz8+RcNDXLSSWbg8ZB/EC1/6FLrBUWT7jTWiisZtydgKhKaV7fq1KBa90Yx+w8e88x+hDjpmsu+uZiCBe97yI1HrvoouP7w/JwNACPxt75rh2aduebi2mxohmpcb6NsUthWpaEWU5e1sfLS4rkVESkiJYsJUvJmhdd9XoxlUemDq7U1ZUehGdUDSscSN7VugKIUjClzpYLTFvW3js2D9RavFBjAG/AZ3q4qkhRyaiCqqxGwA3Fjbq8KEMuOYNk03JKc/20eJBZSqICTAyGacn7h+nt20PThn4TiMgAidl5QjJRqSziGB0hpsVyMi1Z1bQIKDaeC2KeFy3VOltLhp0Lh9W8z3q9hCqlIBgBal7SlDXPqCmN8FBAAAtQfx1cFx04QkKodY8N4BRj/Xj5nDQtrdVk6hkgnnF1OAV2PvjeE0n2eeTsDMKb2LIsy5JyLqYaYwgh+ODY8Wlazzf26S888kIePZCaTPIJj+6janRPTuqJl39y2RBNYxm349ubu9++uP/9i/yw0+TFqBgsqtlygWJQAMDUgw2Kz4QU3g8iAgeafxsCRkj++MYd3sh0mFMZhUbF2aiatQSFbfKSfWEGZ+Az9RP2MzSztbOGqdARNVtS14B3EKM1PSzhyVWgb1oG8qF13hPVcsympjlnWD0zqLEpIlorwSPWXFsDkJJ1ScfDYTwepnlm57549eX1xcWzZ9dt06act8PmF1/dPLu6jMF9/x28++G7PKpjdzF0X798Zqo/qL4hb0Kbfvfs5pmajvMBX387TSnRjMjR+0AkJuxcE33wjIgfVeRAtGokEwITMBqREioRMSMhBYdxB13vht3u/jAc5yVLYk5NgL73nsFqLxrToqvtX9RELIs6L5w1Z8xevKMQqGl4s437ve0f5O6u7O91nkxUgVxt9IAEhARolZ5gdYOdVmKtV3OWl//CgQCmpNKgbh1eDN3l1XX3fMc3GxwGbBpviXVGm3HZm3sL+hYng2RYEIkgJaMDKCIQlxQ89YE3Q2w2Q0e0T+mw5HHJ+1KKCUjWGlyUqhONapqBKRio/R9Cm2fiNnRD01+23UWgrU2bvAw6NTo3NgXNAUrHNhBuA26D9QHYwZcL74F/c5TfkhWyLtBfdPSyoeBQAO8SvJvk9SHdzzLp2WX/UQulj04ax2aezUEmSSBJZdE8S3YlZSU1S+hSpHzZqX8W91Ee9noYKZJ7twm7TaNmc9a2i5tNu9u2F7v+2fX26rLf9jr4JeYl6OSWBOabFJosPB9wmQiUEYC5lmQUKWpCYEwgWZcpEzKQoxBwRU/ZeZCMDJEhkvN/8O5+5DadXPYVxaaVo0F40ufVb6qVBbzzoYmhaWLXhmFot5tuu2m3Q7tr/UVLg9eOC0oxEQPKSqo+aRB1hNzGtu93vr0Av7WcFRckY07edz4cmUbEdMrK/0S1i5Vcg0RlHtOyZMkqBUqxUuyczX9mvZwz/SrWUo9ICGCmRUsqyc0zCMnCsgyQn5FsWaN68NkUEL1yC+yRmaEGwdVsTouWXBBMAQGZ2aMPXei2TY+6gFjJlrIWsSfY8ONQtXFMy5JVzTE3TShF3ry+7wd/dd0yMQGaqWrJJaecCNixI0Qiig0Pg3cOS1HHpqWIihaFosSOISCaKZsBc0PoSimmFrz33nnvtBQthoQxsFmZprykImKEpIUPDyWpOM84RLJAyDXBZl0laGvFwyda+cR7f0xoO1PnTnr9nL9+1vePTvST1VjZnY+/tRp3CjmVd29v3719d9jvpeSLy4urq8vdxa7rO2Ku5IRHOPQ0VliqqmeEk94/lyv6YHxsBeO5YPz5TE9LyRDISpMO2+nts8N3z8c3VzIOVlBQE6RkUiAJFAXBtYwLogReLpaHiMn
f34fiDYpoymla8pzVRsPRaEbKSEpIjCe7Co2wEBeOR+7veXd0lyNvFxyS9ZP4GaAwgffgIzgP9AjII6APrQI675ldjWDUssxaM01XHGVN214rwp6Kra914XMu87gcD9M4xti2Ljy/vPz6yy9ijG/evo2e2sibPvRNnPZtF8MSg3exb5quifM0p2le5lmLtE13c/28qLhDnJacFmFgh4R5xpJECzH1MVS2LX1Y6aUyNZjMkTEBMzABMTAjEzGSIwhROpKuC5fJT5NPOZkt3mnXeiaslPMsUlQrdyyLliJZYM4aMyzZctLkIQZsI84dx8ZCJGIiVCKZJ6tEUailo3H1PWrO2weO++Ni/3Tj/+kDEYxMGrKLQBdNHIY+9BvsdtAMFhpEYSyACZoNClMyPGRKxoVJg4JhSqAjIJkId03rqWk99cEBeiy+KEtKaYTxaOOo0wg5kZojCj4QoHNFRGox6x+rDPovuy4AZOTo4jZ2l7G98LwjGWzsbRx07DV3VhqnjYe+oaHFoaEhUhuQGXWC3OA3ggck87Rt+a8G+nlL0ZkAvs/43Wi/vqPfP+TvR7lPmgzVHstgf2hoYRvYGkcA3rkYKLKyJcujLIc8WWaXQINPphl9aVwOAzSADUDL6JGeb/3dZctkx7m0vd9u/cXGXW75+c5d7aj1EjE7WXyZuGTQ0Bh0JiFPvEy1Djkxg6pgUQMtArX6oJgssmA2WLgFMvDIjomcQ3JkESEg8WfuK5w8kLNSP6Gj+JiKBHh23+sqrQRcYs8+hNi2fd93u6HbDnEzxO0m7oZm24dt5wdvPeeosxMpqSxJ50JZMAnNhdXYOdf2w7C7DN0OwwApIRGYES/MDVNEcqugX32mD6SwCUixvOg8yzSWNBcVU1nbsZyCqR9I+Qr/ril551RVqHQ8kZwNJOiykfmCZGfaojkjRVR2SIE4GAc8AYYoYpCLWC6CqFj7iIFhTkFlcE5KMcmpyFyk1FDmKXjxeBVmy5JzLmjgPbWtF4XjcXp4mA6HxTHH4FbMb62xA4Lo3MpmdIwhcNX/tSQOGjKF4JsYOzDzfnLOhu0uBAYrqtwVijFutkPOC6Ijxy64lOZ5Gp2rBEE0sJJUjzDf49xb2qj06tCMzqvlEwsYEYAfb/O5Knq9ytMys0fXd43+nA9YWzWdb8z6uEpMBADJMo3z29fvfvvr37x//y4v8/X11YuXz1++enl1fR3bxgdPxLjW/joTVp+ugLM1YqvC/+Q6PlDt+PTLHw8DNLbcLXcX49vn89tn+X5j2ZulDLlAFkgCWS0rFK19g3BkmFEQ5t6WPmM7wij2ILYAHABmwBlwJkxIWvucOLbgJESLTfZN8e3s+3t/8S7ePPjLo99m3xs1BViUBL3VRHmkD4FhDN5LLR+3pqudpsnWRm7VyFIRyVzd45qEJipmBVEdQmROziWkwLwb+ufX1188f44Eb16/Ph4O+/3Dpm+iY+do6Lsi5jg459OSb9/ffvvtN7fv3xVJPlDfd9OyMLnL3fMYN0O/e3j/Wo/vbX6wkpGga5sQAnqH7mlX01pUxByDJ/BsnsE5cA4dExN5qvUSFybhUIIPQ3RqEQGZLXgmIjArqkUlixTRLFKKpFxSKVOiKZVp0cXZkiCxJbZ6Ct5R9K5tIMb8cJ+XpFkE8ZxmtRLoajeos+NuiEpohKfitf+CsRrIjNAzXUW3DdQyODBT1SwEYqTAiL7F6GBQTIr7hTKysBWWhXIxm2cxLEBODZid5zYQFCXNsIz58MC39/r+oTwscsy4pAAAIboetHZ+LjmnBPNcSvmjZ/wnXRRALecaXBxCd90MNzFcoG5s3thxsHGAZUuyBRg6HHa0ueL+itveNZG8NzKLewy3Oke42jqOcdeHf9vCV0EDFkV4wPDDQv/tPvzDm/S/fzf/6j7fZljUEBTByJQ+PJvWMzfe174KTdO07CFBOubRLZw9eC+QQsoxi4ojDSSuKUGtQQtg42WYjn0M8DAusXF
9j9tOdmHZ+mnDFil7mAgTQ3EGBNyT9qhtST4tgsjOsQ/knHC2YqbqCIM5UgRVMRllxlRwyb6XoOpd54gUGeDjouWfyLMTkHqC3dcY5Ck0SpUCT6REwIQu+KbdDP3lbndzffX8+uLFdX+1jbuOuwiN00glQOZy5PRg5Shp1IKs3oyy+llwKihIPoR+s9lcXsd+Z6Gn4CkjgCrOgB7RrzlSuDq8H2HZZqgF0yLzlNOsKVltiIO12zmaEpy+eK4lVa/W1tiXncwWNAAjlSByYeUV6jVJ79Q5MucgRPONcjD0yA5d9Q6VDGtV9qIKJIzOe1JL6eEWj3OvspjORRaRySRZzRmorensyVVYLllViTgEdgxLklR0PC6374/eUfAtMznnCdnE8iJJrXZuy4tIVpOavecIa1E9bBpt2q4bLhBsnDI79+z5S+9dSYs/jEhjN/TXz67SMu/be2JyIczT0fl9kVKZzCkvBIRq473c8rTrpk2XavZxvYVPAe7H6agNFgwJYG0msjbeW6kZ1T40XbPdtHabObNe1pTLR2zglOlWsXhIKe/vjz989/o3v/zt2zc/zNPh9XZ4+/pm3O+Xr6fd5UW/2YQYXfDsHRLZqeDnSb0/Nm6vZsZnTcZPvPZPJNPTq/aSL8f754d3V9PDkCaXxQqIWJG1/XdRyApJIRkkhAlhcQAsng1NKZsJFMEFeSaeARaATCCIlehuziEHcK26IYfdEjaj2zyE3W242oeLKe4kDhg7QwQVaHpoevDh07NmcmaARHVD4DnCVvsVne/0+iKBgeRFRAzMO7q83G7bsL8fYngnWRCri0zBOQNQgZyk5GKqhNi13fPnN12/RXR91y9Lvru7f/Pm7TgenSPvCUlLWeZlDrHth13wofFBHqIdI+YZrEBgI9DaFvUDQWVExqgOzdFJqTtkrgZ3baYloMoEnoBiJQZSradBSAAkpkVZVERNVItIKiUXmVKeUxlnmbzMi80eFm/BqWNyDrxD72tpLT4cYZohFytyFi41jf+0kuwsTtFWefPnjacGKQIw4hD4uvEDA5dk06h7hKWAn4ABHGHXQfAYttQVbh8kZsxoBkpmojqlLDCR06SzYoPo0CKKWs55nuZ9nB/ctOdUSIxNHVfBZwUFoIgqIuNnqjL/C0el2njC1vEQ3C7yztuO0hbmLUxbTlunuwYvW7e94uHGdzeuvWLXsfNEWEgK3yn38tcRri/ZOeoDfsFyTcJWDGxmvFG+irBzTOo9wT/ey/tFRWsFO6QPN0gTnW9CcBxDDNF5B04XW/bpqAyLs8YbRZVsJVvxToktYvaheFNvkK+ClU3f4PsDEWkI2vk8sLbGjWo09azkmDF6DICNwXAp8mw37rM9pFKIvGdxFNCKIw2+YR5iALNSZBKZsmQrYmoM6Ig5IJ/TXD5eWZ8Fhx/7csAJTII1IE3MFHxoYtN2fT9sN8P1xe7mavvyevv8sr/ehl1HQ9CAmTSTJijJZFFLxTKYGjhFV8wtwlPWqQBwaPpNv7vuLm6w2RRsyCG5IsWpsiqpEgAj1tZWHxanAwCAlGSe8jyXtEgpBkZ4iiRUfpQYoDzutjVqaquOXdv1ICGvMbEA2qldotygbkg9GzkC74zZiA1JAQlrEXhAIzQhdoCsilmVUcCLqpYlW05OZVZZtEwis0p5rLr5EfyMRFhTFF0tZagmostcbt9PXRu3m0hEMca2adpmkZRTFgIiz8G33mEtbw3GhOwcI2DJ2nbtZndpKveH0fu4vbzx3i/TJLZPJXTDdnPxPKXZKJJjH4KLR3KdmqrJOI3zNKpm05Tm6fbNoWvuvPfMHII79e3+mIyphiJ46gVW9fpZm655lHgChrTWI9WVqWl1ncGZVXemJdSjmZS0TPndm7vvvnn9w/dv7+/2aUkqukzz/u7u3es3jnmeps3F2PZ9O/Td0LsQzEDkVH0EKqdvpWGt994Uam3/J+NH8trrzH0U+lVtcnp+uHv18H57HP1U6u7
PglIb+oIZgBgktQlxREgOSkQImJ0hopgdDSaABTCZZcNTEyMzQCVWdIpeMSYeRv/sGK5Gvx3dMPrN0my1u7DNJWwuDAmkAHvwzQmK/xAdOqWyI4CCASKtRFmtXROhBtrZsfMIKFKWw34+7L0P/cXF16++/PnXX75/9/b3v/vdNC6Hw2Gel+PhOB5nYtaCCM656H3jfdjt4s/IL4uoghZZ5vn+fv/w8CAqw3bjHaVl3B/uHvbvL69eDJuemLxjbTwcoktHzcd9mUfJaqCEzvqnsRWE1RN2hJ7JeXLMXHMwzExqcyYAKsjApLXTPBMxrWgdKVRnxQgNScGJ+Vykz37OZYpliuU4lXGW2cnE5lgdIVNmRiKLgZo9PuztcCjTXGvqPIJPp6z90+1fIyn/ouS3E7iFiJ5oE/2zLvZocDwKFsij8w6dA0cYvckVDjtyLcZLcm+JDqZZRdVIRctY5qwHtjSWZlHLFlXIitPsdW5kbHRuYYkEwZEqqWHJaiCiJRcpRaV2ffoc7/RPG4/4VzV1GCyi9s42QftYBgdbmC5t2kG+9HYx8MVzd/HKb17F/mXjLxwPBB6NAMpiC3CztLF81dPzEbCIy9KkwjkDCIAGmK+I+4G3xB35xtO+TFOxEUD1VEX5yWhCsCZ6Ru84eHKkJrPMJpKhLAxDAJ8Nskm2XLyos4ASULzX2AFZbDxtOtzeQy6TaQooDWIQ4KSBffSOm8Ehem6QO+82S2tflVDC7evb+ykLB09M0DVoSqpD8Nu2NdFpzu8OxzeHw0HTnIvlbJqRhH3tcaKfa03whBdUQVI8U4bwiV6ntb2ad23XXt9cPXt2fXN9dXN1cXMx3Gyb64GGoIGWAMnlBWWBPKuIqkkpKlgsFvYZQrZmNj8VnlJZErDvhrjpL5+3u5sEQRIaCqIzw6KWC4gAGBOwmnxWJ07jcjhMKaWSBRCdx5rODqYGaEikVhBVTk3YAcFWOWZgawUKAiQmRCZs0DYgl6CXqC0ZMYBzxl4RREVU1jobRESVk43kPHNU8MsyW8kZCRFFRbSIliRl1DxpmU3LYzv2DwYhxqYxFVASkZzEOfTCJdn9+3m3TSlZ03CMtN12klXyPi8TM8bYNM0m+K5W5KzVeZ1zgJhTadt2d3mdUnLv7tnH2G19iGqeJyNvLvSh26JvoiB7F2JE34HriAkZ4zRP42j5mKeH92+/vXt/a6Yi2nbNsGnRQ60Z/lEtShFLxRD0FEdffZZzIboTDGSrsYh4ru51KgRStc+q2uuUI4KCHg7j+7d3v/7lN7/91e/v3t2a4cXuMgZiMu8ozdPbH344HA7d3e2w3V08u74hapFEoRTL2UxrDjc6RnbEXMECrfVGisjTC/lYta+hnXPk/3RerNqU+ebw/hfvv/n5u++2hyMvkkrlXYPY2kLdAAQgAywIC0NyIAxGgICkYAaT2mK4AGTAolAABUAQBECQCoXMbXLD5LZHfzHG66m9WuK2NBtpBm0HG3aw2RqylQKAiGz2GJI6C1ZyAQyIGCqpqxLPDNQUqSCXar/X0B2u9LQCmkCAzXab7auXXwYf5nluf/vbw/FYSl5SEhFynpxzxM4F5yod1W2Am1ZFbH+3vz0cHx7243iMbTMMHTGO436ex1LyMo+H/f08LyktxN73uxCcLTRNipo/S8OoCKIjdA6dq5WdalShwuEARiBopEYFWcjX3e3ITkRgBKr5xPVwCGrgEQJDdNg6miO3kce5jHM+Ttk7c2xc/yd0rgY8azEs0EXsVKwEH0/zqSSFEy35XzYsMG/beNV3zzZ9HxjULGVDETQgoOhQWmgGEK3dKKEALCJTlkVzobTIMuVp0YNH2y/53f22aTaIntSWWfMCsrAmD9Iwt8HVooMArAaiKiKkzOrqnabPKZM/OvD8sD4xACOEQNAS9KSD0y2kC87XTq52eH0Ttl+F4WehfdnEZw00bB7XojTZbBGwhKBdA/2MOCmOhg8ZclE1MGNTR9IgcRP
sKhzV/XbCI+Rv9nlKtaf3B9ORi2jOpoRgzLSaiVokJ5DMoF5jVHaqToSjOm/MwqgBzXtzA7fBRW9do4dRl6WwQcO8aZttN7RdG9rIzOxc8C35QcLAiWR4Pjy7e/3m3XGcAICZgmeH5kF777ZNULVxyt+9v//d2/ff7u/fzAcTQcmMJXj1teb5J6H2D5ffR7d/xeRrFWbnQ9M0F5fbZ9cXX35x8+rm6no3XG26y95vG+x98bBonkCTatGcy1JEyTAqkCBk0myWlBfwR8Fjkjmlor7p+6bfDpevfP8sL6rLooaqteIhmCEAE3pEhyBPKk4/jlJKTsm0EBkzskPnEFG1mEolteFKCnr6f0WAzUDNtBaaYiR2RC3qFm0gaxl8cBid+cZcVDUtC1hBsJqkzQiApopAjC4ghlIo55xsJARUzaJFZdFylDRpzrXT6WoPfHAlZiBFETE4x4QIogpSQNXSIvv7+d3b/TCE4LkUQ2TvQ9Ni1w1Dv2vbTYgdkjeElLKqEnsVm2gmFxR8UcmFFDmLs8JzgpRJxBXzokEU1RoCjxSJwXljz+wdcudcL3OYrTC5XOzubiS+u35+OWy6za6J0Z9wlMehpqIF1igunkOkSHZOuUB8XFw1sf2k7FdDeqVBwJnpYYYgqg8Ph2+/e/3r3/z+N7/+XZoTgQ1dDIF9zXgClCU9pHQ8Hsd5UcRuGMgHURQBqRWxBUhQHbKZs9oPUKs+EPlxr93gI1C/otlg6Lzqs/HhL26/+8vXv/zF299uDkdabG0S9lEHBABBKGTZQWETMjVAhUXAFGbDGSEDZoRsmA0FUBCzYQZeOM5+GMNuDBej38zNLm1vdPPMup3Fzry30JiPBgTkYE3YEah9Mx43O/rQiOGa3EZEzM45ABMRJ1JEzOq2U5VCDhxDjA7agMpoENh3bZc2m812CDEAoqiqCjH54GNsRNV77xwzExMygRKaYc7p/v7+4eEhLUs3dMMwEMJ4PEgpnvmwv9/f3U3zVErphn5oo4GSiqWZOSnoJwpx7ZLKDl210VBVSymWAZkZqXoDCAhMoq6YIIJjCAAAjGBIaz9cgFPWTJ1fJgyBm0Cdur51UyqHIwQvnoVJmQqTEgGRQ2RAV4REsOjay2otbANPipueMPl/gVY/E1AQoQvu2dC/2G6ebbddBKBipKhqeRHNKBG8QzNAtFJknnUc5TDJMeXFsrplkmkuY7KDyHh/fPf7173IVVo2rWuhlLSUUsCEUaPzvQ9OgvdE5E5cFEUCdgTqqyP0ZFnh0/jZj18L4mlfV7VqYAqrwe2ROrIN65bLRSfXW7h56Z59Hbqvm/h15IuIXRDDLKt/AKgIxVqCDBSEWoWAgAZjARMQAqXawwBNWiovOvjXz+K/t/7BpbvyMOXFEPQJkc4A9scxPRyjwzZ6VTVVHxTIqeYsZVIljWyelEiMFViAvBILkQUCH7lpvPNtE+3uQfbHwkDRxYvt9W57FfvBtx04Rz762HHsIQw9+e3X8uX+ePv67fFhX5aZybroI5m30jjrHJnCnPX3b+6uNl34BtMP02wKpTiQ4KyNGJk84yc32s55P2ey3JNVVfW6Yx+avr+4vPiLn3/xi69f/vzVxcuLrmdtSTxOXjJJypKWNJshccg5zDMZNhR3RlGQE8piJYMtCvucjss8Fw/YDNtnl89fDZcvXbODPIqKZJNciqiqITriwOyJPGA+Wb0fiisTNHGk7Mx58B7ZmwHk2UpWMDLAWuy5Iq6129QqclfrAYwYiInYI3WkW9KWjR1ijNa2wB1Qq6qaJ9REIIzIyAQGthrr5AK6BsznZUrzhFQ8gxoWo0XToaRJcjHRU/TssRsWAACo6uGwRO/i0IRAzGaS81IEAM32D9Pvf5e3m9j3QcVyEuIwbPrd7mqzuXQcyAX2DhBTLiUbgEuL2AQpwfEo45jHUdjL/iDMaX83HffLNFuccRwhZxxHjILMtR84nVzbhkJIshTnfYjOhWVJ72/H77+
77/rWBQrRgcFHPYcMRCyvDceR7MSOr7S8s/KGajQaAT2J/+CTpqm4Fv6AU7FiKXZ/v//m2x9+9833v/vuNYgF56Z5WZb4bNd3sWtDJIK7w+EwjlnUhebi2eSabm2TCFRnXwVUyZ30bq2QjWsnlMfxWa/9Q9AVwJluluO/fvubf//tP/zi9a+e3b720yLnrgcfdGbBNbJHoAhCUNZSsFjLdgusulyABLAgClJmn1yzhGFsLsf2amyvp/7Z3D/P2+dy9dKGa2g35kL1rwH5VE2llk42sA8sr6pf8JQe8pgMqqqiKgpqWrLkJHkBLReXu93V7qofysV8+/YuL8vDw8P9/T0RDsMQY+Oc88GH6J3ningbnAqzi4GZCIiCio7j9O7du/3+QbSEGDbDBhHmcXLEnt3D/fvbd++WeTSwi5sbw8uM6Chk12opKovpB4gKGNYmS4YgoAWkqGWxhGyFx0yWQJKgQfDYNNa12reAahgBA7ADZANya9mtUzMXQKjd7gCNyRyAc+A9OHLeucAQ2KIzz7XPhFQNVQRUsSiqQckgUBUWYOXT4FoLHE7VHeDPGlbFMHjEq+i/3nQvumYTuGnYOY8sAFnTaGmxgLWeOKKpFS3JcrZUdCk6mxpoAQMqZlOx28MywfuQ0/tputw0Fw25skiWYkgVAZEiGXLGnLOKIAI7BlMCUCnnffp0eQN8mtP74ThRbSpTlogIyaP1gTeOLtiuSW8avdnizZV7dgMXL7h/AfHauDUksVLMyJQwtBwbgIJlMjUriZYCaFrYImFkDApIIDVRR02TOYot32w2f3X1/H2fv52/HeV2Ph4/BLVsLjYnUSPC4l1hzoiGVNQYVBKAI5l9E5g9kyd0CAyKqsCKbIzkiPumekXatIEphjhsh+thc+WagWNrPqCPLrTkG/RNYNcADMN0Ebv54b4c9qSldbWm/OixBFYAyIpeG9KLcTrc7e/fS1oEuQCpsZqjD+I8J/rMCSR90iwbAKCCzkhAnkPbDpsXX7z46qtX/+bnz//i1e5lTzuvriSULFqyqYBOc3n3MAuEOPQp4/3DbNwF27lmg74tzhbLWUqmMtI485G70Ee7ePby4uYVx62iN01WzKSYZNCyWtZE5/zoU6r9B8N7bRup3D5iI6dVwphjU7Jy6mSy2uS13nUtBbCm79naTce8WQ+2I7tw1gfygbEfoN2YNarOSgaZUBa2QgAExKBmYKJqZuSAPVEw45SywZy5CFAydyjpocyjpKLyZC19fCFSpCDmXLDipCtiDICQkx73hcCBKSKZBueCc22MW8eNKkgqagTEJVMpKkXSUpa5ls1fpinNMziB8VgQ5bCfp3FJi05TOh7GlMt4mFWMmUvRkswAkFawXVUApOs6kcuHhwc1fbgf373dX153wxAR8ePwoSlYqUU51ZSQDE91zU5pZnZqlgZkoLZ2uLMnlputbr0hVulmBqXoPOXDfjwcpsNxRoPEJS1Lmic28WSNd95zKZpLyTmXnKs1RW7tBaJqWFv8mYrgCuiYnXklT6/jQ9V+QnngfKYICBYkX493//67f/jffvu3X7/59Wb/PhUp9aqrPLdTEiAhW621W1kGKITIxAgktTH72hFVgYRQkLJz2cclbqb26ti9OPY3U3e9DM/K7rlcvITLV9BtzUVDWms31KQQ1fpoavghIG9gqkVrOzQjUDJFMRPRUtaay3kel3Gfp6OpXG6GZxc3F9sdmf1j+btvfv+773/47uLq4sXL667rmqYJPnZd23at8wyotdWTqalASkoIOWnKmnN5eDi8efPmYf8AiE1s+n4wUy1FgXKxcb//9pvf5WV0hIACrHMzeBeQO3VWikou/snWN0HNJKJZdVGZBRfhVBDVzdk/vNPD+zIfC6ptt/5yR1eXqFtDKSiGahAUfECnSL4G5wFXmkHtplwrvROqZ2AGRxS9bxy03qIjR4UJqcZYrGb/21KgCE4KJqCquLarQF1L0RlC5ed8Bjn9YwM
ZoGO8acIvNu2LyC1I9N51nlhNKc2YNYM5BAUyILMiqllrPdJskAxQSYmcB7GlwMNU3i+LjlNzv7/atTcX3cZhBM1KwN6SLtN0nGw/2ZIli4AZ1YxDFFWVcw5vPT9EIjpXKf747D+2AmrZQXY+ROc6hkuPN5FeBHvpyqvBbr5w11/C1Svudup6AZ9sRktimMx58g03netu0BnIaCK2jKCT5lLYFx+xZSqICVBq6m0xGZVRhtBtLv9y85d3F/LPd3g7yZtl+qDPOWBWTEaslI2KgZiKZFI1EDCTmvCSORVcCnkBX9bS6jW1yik6j+x83zUcQyfAfhviRdNdxmaHvjEfzUfwESkgOyRGBA/mIjUbKWCChsvodKG82HIgm5gyEDjiZ5Hhqn1/P7x+vy2H423OsIDNqiAiGbzCB1muVcys1WRXAXfKakVDtVreYrO7fvGLf/vv/uovf/7zF+2LjfXzezc/lGmZs85Ghi46fz+WX/2QE4Yd7pbi3t7eqosddC1fNOFCHWUQ0VRkmfxemqZpoPVhe/W8314XoXnOKmalWEmgiUGUTE6VQE9M7M9AWm2jwyBrGptp7ZhuQEzBGDULqlL1X2oBs1qGtuoarSY2qCloiZp35i5JLz10Dbsu0Gaj7ZVmJ4tqSSAL6YKa0azSsNRARURUCMF59pHImUKSZHnKADO4u1LuyzJaLiB6sk4+czFkRcs4L2quiavrfPowqficfU6B2REF79sYO8SYFluWpRQhV5B8EStZS8nLIvOkqljSsixLTmCG07Eg6jQt85xK0mkc9/uHnPN0nMzEOZZScgK/3uoisqQ0qeS+75romfA4HaZpeX+7Px52u4vOO7aP/BAzVK05ZVVzIlU645nifopWr3Co1f6gK0hXP7C2C0DAmjmMpipFS9acVKTCzaYlL5Kno0JZQFMT/WbTGwASMbuaDM2IwTtitlpamGthcqulM5HOufeEf0i1P17eefsQa7463v3s/e/+7et//sXb/7Yb3ztJuRI6aoqjrfsJCMjAA7QEhVEYgMEIgE65gWZQ6Qi1fSigEimTOFdCTM1mHp7Nu1dpeF62N7p5Bpsra4aq109pV0/1up3CTR9yHA0kr0YgEhkRIoCZFCmp1LZvlhc0AROTIlm0WNcOm77/7uKbH15/dxiPr9+9211tg/fOOeedD8F5L2ZpnveH+5xkvz8c+o1HL6q39w+H4zgex9/8+pfffPf7Jc1dP/jYpmyqkooaonOMbApJNYnCMh2nw97QKTrnGgQqaSmQnxr1piAZcsIlwbTQOHEkB9nrjNODvPk23b5OVbXvtunZFb96ga9ekN0Q1j5QVsyMVMgpnnLmV1GIrHACjmtFaENwTGgM4Akdg3fknFJtSmfZAEVtXjBnyglKWiHBNcoEtgocWxmVf9aoZmfj+KpxX7T+5w1fkbg8QjYTVSlWZkszlKSZNM24TDBPAIiVdCAIgiiGqAjIzoFAzjJlGU1SLvs5zdMyT8uui9voiViQFbOZlJzTnFMuudRIg0kuJeW8pJrg/uGu/1F/vbZOIkLnfN92fT9shmEzDNuh2wQcysNFOVznw3OXv+jLzQu6/Io3L113AY4VDsUOamTGaN6sb3Dr0DE2HTJaIeQGwFlGWNhwC/0ziBd4M5A4NDKAvIzTww9LnkroaPvF9V/89c8H+8V35fX9dH/37qOo6JJ1zuqZyfmmbfs2eAJEUCNFNiK3AjjRt9E3xMHICXBRS1kzGKKBc843DRg78OR3Llz6ZkdxY64BF8EHdB5qEpEqqKAWSIly8SJOFHOhNMG81+kObATOyEAOG2x3rrmOfN22bw/Lm8MisKhOKVr0qfElfFCbimqa0ZNa8WfpxQCM5JyPu6vrL77+6mc//+rLr57vmuzxCGYp28OxPBzLvjB4t9sNd4l+e/s6gT7vvWGzt1IyHw66CXQ5DMAum5gvBomsafyATC7Gpt2w74pkLbVDeGIQh2ooBAJW+7TWkrDnWOe
HwKnTGJUZK6NFBLOSKoN5R8xg0Rkoqpgq5CxLKqVWHlQzVDUDtIZxS3BN+oLkinVwFhyAJwsBQqPGkjOAoWbQjCYAuNb4PUtTJPTOt23IHS1RxWXFUctBy4OUvaQZVVYldirZ+BHPGlDVchJCIOJSVFRVEQyJPFAL2pgGQ2/gJHMyKfMCZsuSSinEBcmLooimnHLWktmASpElSSkECPOcASwlywVFICeZx6WUklJyzqelqGopNXtHTZPKLFkRnXfomGLw80zTnO7ux9u743bXDZvm4xKHqiAnaLt2+lMz0FPaAq5++FoVRQFx7YRb7+gKIBFUgh0CMyJBzmUa5/E4z1MqWc55RGpWShmn+eHA+3HyTXQhBuc2m83Q903wgckzIaOoGdrKvnJUuyrDGmGtTK0/zpA/7Q80Q3IiL+5f/6sffvkXb/7bq/vf+XJUUlNcjZja+qu2SSIDUM+AhsjADpHRKiZvegKSANeS+mvhCCNSJnVeQpuHy3z5qlx8odsb6C8gdEYezEDLegkndQ4f/fnhqee8iFhlwdfmAWCmuchSVNVAiTQ2jrGVXIrI/uGor6Drh+121/dDKnL7sJ9zDtETIRMiswJNKU/T9O79mzTnt5fXfbNxGOZ5/v3vf/fm7Zvb2/fffPO7b7793Wa3e7H7kn37cJhE8pIW7x15Cq3rhjBjMVGRvIwj+YZ9w7EH9jodJKenprCKSYK84DzTOLq991ziYu7hXXn3/fLd76b3P6Q8K4JuOnx2TQ/3Pi/Bc/SuFqkSkeJ8dr6QI3RUOx4gMiDUnvOAYFU+ghESsuNIwTnvXQyZfSbKZmomiChKy0LLguPRZjCDSvQ6rQI7BfPXHvB/zkBEoi66533zZeu/cnihC86LuKQQQLPlWaeD5cUWk9HD/oFdJN+AKCqAIBYgUSIkIGIExgxSSywsRUsqc5JxToehGzd917YheEVALFALZKWUUzE1VSlLzinnnE1Mn9Bnnxb4/PDk1wgvEYXg+75/8fzVq1dffPXFV1++evHq2eZZWDb3v+re/7p9+5u2HLrO2leu/7JzOyJEO5i+L1oQSaxR7QUY8KJDjxa8IpkWVIaCNqFNDuM1bv41D7+g7hVzi+C1qN7fPvzunx7u3mbiof/Fl1/+9ast/OJ3x99+++Y3v/rlR4bvlPK05D5678Nm2FxsWocKZmJQDAQoNKFpm64f+k3ftdQ4DVwcZiqzlUXRBInZUWgceoQGeUN+ANcpt+Ya8xFcAGJQMS2QZsiLlaTLLOMBxiOOE4yjTQcb72W6Q5jQF3SKHsinBmXDehmCLzQ/5DTB/ggxLtH7i8s0bD648TUb6KSontCckIAcsYtNe/Pi5ue/+NkXX95cXbQ4HsfxIAK5xNfH5fVtfj8VbvBVt7sX9+2ekgheaNN5C9uUy3xYoC1b9MhRtGAI7Ns29DBkZnaeDZwWUClWFtCElh2KsCpqNjEtqlm1goW6YrYfDiJzTmMk7xGQRHmeKRcH5tF86wlr5XZDUDz1Rc1LKrloUTVURBucXXl8GfCV12u2joypJj4SsFMiQ6kRPjSpzbQUWU11rdlKhoDeh64LeXZjhzIXyFOR+5TuNY2QFwIlXG01q/0dPzLhyUyLGWVFBClSsqoQGDrvHXQEETUAkirNqZgmUDPRkkUE0BUkZ8CqlnIuCmABEA00ZVNjUZqXgghFnKqYWSmQkkpRLVay5iw1487WYoeL6sSAjlrErJAqNWOaszxM794dttsuRN+0H0yJqVmRlb9B1RevPjyiERMjkikggK2pvmpoZgJ06vddAcyaV03oPbOjZU6Hh/GwP47HqWRBICIjAGCHYGo6p3ycUy86DMOwGba77eVu1zVNYHaECiA1oKyGgI5YUUXMtBpmiKj64XR8pNrXqzhh8ghmMS0v3n//sx9+fX33QzvvzbLyyvQ8+cu2KulatpGgIXAMgdfEtBEsIxphQVz7nxkgPDYMRQRURVBmotjQ5kI3l7rqdaj
tzz/W6HoygT/jtIMWkdrYxwiBFRFUVWoHpawmvg1dP/CWSAnJ7x8e9g/7i4tLItd2g1FNw7AiWkqe5/H9u9vvhjdAflnm92/fzdP8w/a6jQOz39/f/ddf/vN33/7+/u793f3tcdxvLnaxaZecD99/p1oAZLPdtG0bm7bfbIlcSYWcF9E8Z3IzUmQmo9rL5clsKGpCSZxnNx78bXFHNZjS7et0+0O6fSfHPYKS90Qeu0zHhR9GvH2QEEGVS4EQzQcJQXwgNCYmIDJSADXjmjGDK9sI0dhMAQqhIBhTxYWIcCHMTCaCy+ynkQ73towg2Soif+p1vVJBrSLknyysj8Y5RdkAkMkFv23jF0N8GehCU8xmixScLBFIhpIgzyCpaBZRMuZ5dt0Gstk4w1JgzDYVcAhkQGRCxaAAqpGqFZCyaJFUCizJ2ibHGHKReckpFVBDBVQrOeeUciolr718PrIaKyj6wcXU28fUd91uu3v18uUXX3zxs69+9tWXX7188fL51cX1QJvyxn/zxiv4MfOcuDHnzEvGI1ou9iBwqygIHkxNAxM13GwpdEhkKpYXKxmyQibIDTRXGL/Ci3+DFz9H1yN4LIW6e4+bcPEeTf3NV+Hiy6hT2w0xhE9J/rlIFq24Xx2rnDIgICP2MTZt1/VD1/VtS9FJpBIwk3jMgaAwGqIzBUMG8kABKAB5I2fEgAwApgXzbOmo0wGW0XKxlGBZYK6gy2zzpPOiSwLMoJmr3W8LsW9AekfeQCYdp1Qm9cF5xz9fPi4itMZBTqVcbA02sQERuthvLm9uXnz56uWXL4e+0Zz27x/S/b5r+yzDDyl/dyzv7lLcYI+7pelKuJzHfL9fjKUZNn5J48Pe1JiIiAoxATpEZE8aagBWi4kUkASaCJLDIiSMWukxtRqkSlGV1XH/ZCtoLU5R45+0NqhUdaAO0TE5jy4QBWRHDsyklMoUKqKiWjPTB7JL1q+d3bB2XBBBDUyFtIBmM6e1DflKkamQMdZiJIZr0JQB0cy3bWy3swlkWJI+pPlokt3a9nJV7aWYZpDydCJijCpqwI6A2bSoGZTK60ZidiVzJgQ01bK2/amRTUWrNawZAEENi4Aa1latgA6wMj68gkMzIIcrB4hV2UzNqIguS7bawgycAUkuWmYjQ3aiJRebk80ZjpNOZfnh9UPXh37TsIsfUv1NRSqkXv13qdwGPGUVIFUfnhGZgcmI1/oDBmoAaqoCqoYrEurROC95npZlXnJKYOYdM6IjImBHoW14t+2fPX/24uXL3cXFZjO0Xdv3XXAe1CSLmKVSUpFS1Co52lQqd0ylZuGlJT1dV5947efUiiqsVWOanr///svXv93s7zgtxVvF3s1AFXWNIlU1CormGIOH6KAlQwQxAMNjDasTVhyp6uaaJEAApMaSXVm85oCmsdGmV/JgUFf9icFgsKr5c0b1I+X7ySWsBk5FT6iK49pQ0EQ1F03kwrAZtsOu8e37d7cP+4fbu9vN7kIM226oLQBVIaWS0nI47r/77jsxViUp8v7t23mefxi+b5rBx+bdu9f//Mt//uZ3vxwP92LiQmRHzodxHF+/+cFMYxPYu6ZtQ2j74RKpWZZkqoqQUrbjhBx88ABAzE+9dhOwjJJ8nuJBXdpDecjT3fLwJh3uSs4M4HzA2JAf2PWMAZPZ7TGzK1J8TtS11kSzxkCIwKEjcATAQGxQSREO0dUE7oovmRmgBIfM5BwFT1TZkFikQEo8TnTf03jEZVYB1er5n6bSrB7jI3MeP3nygYAjIh/8ro1fdPHGYy+zy8VQclEDBSlYKoqotqBNE0wzP9yF3QVh0OPBpmT7xcZsgcwrEKmyGJa1hA6hmYjOpeS8jGP23jnHWKWqourq6GmRvOSci4rYJ+HzdWPYkys5GacxhJtnN7/4i7/4v/6v/5f/5a//+i9+/rMvXrzY9psusLc9vU9ya0AT24yWABAywn1SRJsKjIKjAQEIWOMUPYUddTcUBgD
CMkE6WJqtFCxs0gFeYngJ3dc4/By5A3CoheP10F65+ZilxH7H3Y29/0EVajPvj5SJ6AkjNiuiuUgxXfEbR+R9aJq279u2a9uuiRBcCewieVaPJZCk/x9rf9ZkyZGdCYJnUVUzu4u7x4LAlsiFa5FdPV0l1dUj8+N73noeZoYsVnUxWcwskrkhAUSEL3cz0+Us86DmHguQyZKWMQmEhEOAcLdrpnr0fOdb0JoDaTUjsMAcGIkfMy/R3UEFtHk5wvJgl3vLFxBF6ShKg5KhLlqKtaZqiAZrMg0CGXqL7huGBAgNcpUZlEPjgLV+yDP9gQv7i21AyHFzdfPJ559/9uUXn3z2SYjtfHp4/c39+e60f3Xtw/47kddVbk9lQ/gp7WFD8eoVycPlsqRNu9psnROfFwaICESojAgciMkMlcHERFwEWgHJqAt6IWwIDay5NtUqUlSbarOOO/7AERdUoDagAMgQGNcsZwjuDMBMIYW4jWGKcTOEPqkGBxdwN/duiAMbsD3os7o8r0vSRd3NnNS5NWzZNLpZFxQ9tkKdkOPqYMA9uZYQHTSM47S5XtwRoHo+VV/QNPgj5dtQDaSCZLD3SztN48bM3InJmNTEAcUMmiAgMbNUqubmJioqTbV1Zf5avPvolhGAu1YfMBAG5EAI5IkoACcAJzNiN0OkBBgBAUhUoRSBruihAEatqdbFOTqzuJem52zn7JfFPct3r0/TFF++uk7j9P4ScTM1cwc1aGpNtYmqOxIR9xpOaEhAATFGGhKlRCEiIpi5qau5iElTIrJIfRbex3zaxMyYaBxi5DBEDmRDxKur6eXLm5/85MdffPnF9fX1ZrMhJiZCRBF1sWaWm9Qmtaqpd9qUu5mLWZfTUF7K++/V92h06/royJGz2lDy/vxwdbyPObu4sSuhACiC0fqREIIzEIEgegAPQMGJMZgPAsmwAAhSDd7cBfomvbINUJ1EuOVQTulyV893bT7h+BwHNuTVR/Fx0r5yGXyle//gKRjAu6rUiYARiRHQsKM0rbVc69zG6Kq7afvi2SspejqclmU5ni8GNE67JedWZV6KmZZWa82Hh1tETiEB+Pl4rDU/PLzd7fbjZrq7fXN39+ZwvK/LOaQ4biYiEpFcymVemDGmoGrSFJ0ij5FNA3ZAyR2lyXKZW2X29nEYtbosXsnALGNDsXpq5VjzWVsBJI+DD1ewvYGrZ3z1nKdnSBstqIdizWQpfLXh3RZMAQwQNCQjQwwG0bpmA5AAvWforU6y3S2jSw8C+Ii2jwROBAAm6stsxx1fTjBfoDWA97r29RGZffRMeidlT8qkzs987z8igOi+JbyJYcNAqK6itYexC6qRat/9gRCaYlOoVUWUB1/MxZtYK6ZqVsCYHKMj4ztv50ddvolUR6r0dFZ0AsBOsSylNhGzVUXyR2R8+LhMAvOzm2efffbpX/+bv/qf/uqv//Iv//JPfvrTF89vrna7FGOARtlci5fFa2ZTcndBz25HdQfPAuIIAQmcEcZEV3u6egbb5xBHMLVy1ss9Xs44i9WkvrFw5fEKcEQLgITd9TttphgHf6HuFAbHOC/5/v7hdDqqflwOxbw0vSz14TQHpqWUyJyGNG0347CZttv9frfZDEMiRiF3MmEyBg+EHNgNTc2KKhiNmxBSdz3u75J30yxTkALL0ZeDlSO0GfsZSsRrtbJYXlqZpWU1YfLU515MSAkppOCbZFPkyIQuIqZu7PQR5Ogr++GJFd8Hg51RDshxf/P8ky9//OyzL6frF+X4+uHt5Ve/vrt/e36Bn46fDHV6CVdUvjvbUk9z3u52z569IIHL4VJbw5QYCJFdTWpJa9Q30YpxoVnnaS11OZ8Pb0/Ht5fT7Xy+z8upLOdlOS/LOS/nvFxEqvUa+kNXEywFHViNY2AiVgmmASwgBIdgxmpkDr2+xMgRiQbqbGUGYICkOopEcxEFM3UkWD8x0trHVoBosLbG7gambi7m1aE5iJu4qqmAG6E
ALmJzk7m1GmRVHTmAmauQVrKK/u7VQoQYqVvQEhChExMRGaCoh0gcInJ0hJ7h8RTcAQCIATAgMRADci/vRBxC4hAhBFBgaEQhDhOiETsxcUwppjiOKiSugGDIiBEoiYFobTVLXhplRmpeSi2X2UpBVXbQZZbzqZxPy2bK75d2cxMRdWgKpUppWqo07dsJrnCzEQIywBBpu43bTdxgZCZTF9FapVVpTTlQstinwWYWY9huN8+ur9xhGNKQ4pgCWo3sz5/tPnlx8/zZ9fXVbrOZYorQDYnMesve1KpYbVqrqKxcXqT1IN33o4/GI9/v2t8ZiCJAlDaVZXc5bc5nKqKKoiAIgq68Spq7KQoAKCESeAAL4AzGYA4EHtyDgSEVhuYm4J3z1BnWAE6gTDkup3R6kx6+rVdvaLgBipB4lRD6mm37/qLGJ3jhe5eLPLIMEPuIGRDE3EXqUpZzGVJbSuL07PrF8f4QmGqp58scmIZpO+ecl+V0nmuLpYlKa+3o5jFEIpwvJ9N6Ot7d3W3SOD483J/Op1wWs5Y4DeNIzLVJqU3UOEQOERylqSsSBMJAqEBo7magIvVyIcIp4RDgfTaQibdFTaUuYOLd7NGyuSEi8gBxZ5sXtH/p1y/h+oamHYTBle2s7ZL9MnMuqUmnMjqCuUHoBFxy5D4jZgAD5G6j1BVbDmgKAMYIQwDaxEgUCBGaiC2LHfd4PMAhQMH3yRv9degnrg+bcgCCJzeb9T9+v7aTWzKbALaRU0RjVwdofRon2GGeTl4hRwIWBRFT8zB6G4ywERVAKyauSs0CYqInffnT5MfUeyx2B3F6DHZnCvQp3VoI8fEl+6G360kUyohjSj/6/PP/+X/6t//bf/yP/+5//l8+//zzZzc3zJ2baCbF8+zzRZYMtQUzdAAByI6uYO5VgYnGAAl8MNyO4dkNXj/zzZUxo2ZbTna6C6czXUTatuHW4xWkPTpjbdChTwRkSsOOQnCgJr7kcnt7+913397d3bXWPvr5xTyLHueC4LW13TRO03h1RXE3pGm/v7raX03TGCMrWgNRciNwAmQEBlcVL1WLiYXEGLcJKGIXPLqu9bY1rLMvR8gnqBfUTOgIzTVb33BLybW0VtQlMmAIFAPHiCERxjT4doLtGKZEISOAmtPHAuTHV+7Dh+JrtgEwc9heP3/+6Ze7Z5/yePXw9u3r2+XXvzu+fX06b/OLLQ7Xzwdnj78vNZ8OxyGl5zfPqMr5/tBahY5QE6lKzTPHGMctrC/x6vFs0upymY+3D7e/v7v95ny8vZzvS76UPJey1LpIKyLFtK5o/PeMcgFAFGolNZIWUgocAngAD+jRIRoEUa6ITM6kREiEHCgyRsbIFPoEogKIVWezGMEDcgAJDsGAVMyjIjsFAAKKAGxmrs1Mm3oxa2bVXVREW9VWQbO2U15OdZk1N1LEiB0qVQMRssbQ8L3DCiIwd/y0HwH6giN3UHVECilRCICA5mwOGIAUVmyMARmJAckcO3zPIYY0ckgYGMRYiSjEcSQ0YePA7hhDHIahVCQTAEAOiANgEs2tXWqeZZkZFcGbWWmyLFYruTESaIO8yOVczpv8/nalplWbCFbxpchSZKlSm/k6zOjvF6EBuo2JiwzmEzGmxGZWa1vmXKtI0xBDhzFVmgMOKV1f7V++fE7Ml3GZxjSm0MqFoO3309V+s9sMQwqEbiortGYu6k1V1NWwibairamoEa1mJ0QMq7j+g97w+7P2p2dFZLpb5mfHh+vzcbPM3NwkOKmBG4FH4oExoKI1cW2m1o/N3UgDG+KCkMkboCJWDhcOlVzJmrsBKJEiGRA4slpq2ZYHPX9XD9+UzTUOGw9xJU//gRL+w5e7ShNVNiJys4ABgIEiYsMewaJN8mUuuYqu4tMq7bJcdtsNBpzz5Xg47K6vt7ttE0MiK60sl4e7t0iYlwu4zufD4X6MYTifzy1nUzWHOIzXNy822z0QAhKHIaYxDhvg2Mz7MELNVEV
73GpdFVYhUMLkHOE9cY+Jy6KSzbGBEii59AMUcMRpy/tnvH+O22eYdoCjWHAJihF7AKIJtFMrTZuwtACKZpDMGbphfz8zW4fZ+mjJQQENEIjQjda6TQgRfYPuQa3VoseD3d/7kGwhV+0LvKtrwbsC4cMD5DjGYYxNVPWdtvLpmSJARNwwTkNMuw1vAwS1irgUEIBmq4kxOAAYAQSEiMQMGCAOHreWooR925R2e1rmtii07peFjzUa1zmmIyFo/3dmbmamvqLW/q8J1p9+XEBwJ6Kbq+svPvvsP/zf/t3//T/+b3/x53/x5edf7LbbQIwIDuYgDkaujj3PZgvDAL5gcCRDUCSw4DAQ7AOMCKPD84me73w7aWCyhvmC84kvZzxnuIDDzsdXuHlFm2fMI3mfS+LKmaCgipclv3lz+6tf//pv/uZvfvGLX7x+/brW+r1bIHPMTXEpolYVrjlNECFMnCYOAcFU5gZC0kIwDwBKqAQI7qbL0i6LKDuFVQBjZq0AGDI7gKthq1BmXI5Yz2AZsRKZSbF61pKl1tpaaVJUFRwo4DDQGDEG4gAYNg7PGr24np5fjW9yg2wG7kAfu4t0NBfW3x89QcERKDBxUoV5lsNhpjhcMlScNO0q1dvDSd/efbGbxmna7Db5eLrcv50Ynl1dp5hcRUqWkh2ImM3scj5zHOIwYUeuRL1Jy3O+HO9vv7l/+7vD3TfHw5u8HJflPF/OOc8itePwZrLC3n+oD3FSR5fVgjwCMTFjIIzdXZqIHEEcqgK1TqgCYxdGYWIANqAmVJUE2Dh4ihgSWAJgG8gi4OgheQLA4HGnPIKBtaYqTaSpVbPmpiamWlVmKQ/5fHu+Pyyn5mKETNS7A29mtbEb4QcHFXdvrZi5GxECgkkTEXXnXnvMwKC7PCABgEY0e7SB6lg8OnQrHAJkpwQ8QUzIRN6QhZg4MQGIIgYmZIoJUyCIpAkQmaN7UMPaSi6Hks+y5EjGBGLQtBOwnGiNnna1PLf5XIb4zou5Np1zrQJVPFfNVUvV2tQeZ1juAEZdI1cbKYiYVG1DYkKUJjmX1tTUgqq4c6lEQMzd1GYYh2FIrbUUOQSsWUrJlzMdx+HqfEnjJkRhjkDsgKrWRGsTNQCgJlZyqVVEjZkshhCovyCA/NHq+D5D/nHDRYiqz86Hz+5fPzvdb/PCDdyCCwAqDEAD8dXgCRuoZJGzq7gjGIEiVoQFYCZYGCuAAi4cTzSVQKpOUkCbEgk+cl7cozQoJ7u8zcff592zsn8BaQMc3w1of7iN+gHGllnnrSAFMlMCBEYCDmNKbdA6uHm+zPNlmXvCMGGTtuR5mFIMNM/nu7s3V89fGpEZMAfxKjVfzgcAqCUTWL6cziEFGnLO2ho6IYVh3F09e7nZXXWr3xCGEMcQJ8fQ1Gtf6CqtVTER6ZM4ATCIZOMWHAHepd2YumRVNVNnTAETASEikIVAm+2wv0nbKxh3hlEFtboFMAqESI6UZ73MOWeQFlyQu4DQLXa6XiR0RjBAAzcAdBBwcdDHKTLDSvxBINqNkVDdqdRy/yBvdzokYwIVXGd3vUAamH9Ao0OAzWbIy7DkWqv0Ke8HJDqASLhNYbMd4vUWd4MHh5NiPbtVaAiPG6O5O4MbOjNjIp5g2PqwN97os1iv8lLhXE5zs/q46cOjmvjJvYiI+imB3M2w+zs9MTd+cPN9/3ryQYoxfvrJJ3/9F3/5H//9f/h//K//27Pnz6Zpg4gqikwdvsPew3Dk6QZ2z+HyxujCXDEYBCBAD+AbtKvgW/YN4YsJbza+CYYGLdN8xMuJzzNcxHP06QY2n9P2VRifMw3kazIpEgGSiF/m/M23r3/xy1/+7d/+zd/957/7xS9+8fbt2+937X2u2sygiqo5h3FPSgnD0I3QREpdMoaMXGMAiwyRUYKDmYlccp2L84aGTtFwbVVVSBgDgTuoQs1UZswnlDNRRW6Ahr5YPUvJKibqza0BKNDAgceRx0irHIU
momfAL2+mV9fT14eF3ETQvv+AVtjunbXx+uoycogcUy5yd/sQpze5CKh63A03L9PZzvPF3nz3yWfPd9txd7X1Ms8Pb88Iz3Z7ZjZVKbnOl26H726X83kYN9v9FYBZa96at1rm0/l4d/fm6zff/mo5vy3LobWidan5nOeLrUIRgyd/2R9q2fvzcKQezkAGZMTESIEpMAamSIiOpuZF3cG6o7gE7VgaAZABiVFVFifnCBQBB4QEyB7JR+Yt8WQ4IAVPV4FHNVBrIk1Exay5NTez5qqL1GOZ75fDm+PtsRwbCdBaJly7zqg5dpvn94hB7rUUVXMjRCSEjku7cR9kqLkYYCAmJiRkJ/PuxNIxWXO3fhw3dAyOyXggTs4A2kfdyAEBulyAkANG9kDgTJqQKYQkilalWS71tJRZSjP2xGtyYJf6MjqTM/bSXs+nkp69K+2ltdOcq0ARr+qtWa0qTbu9YOcGuHXI2KpCtZalzrVMQ0iBTK3WpmpuENiKGLi7W4iBA+XWVochF1Vojsuy5PmkWt1hGLYOIaUxphSHiTmoe2uaSzUD4iDNam61NVFlJnczYzIkCsz+rzHk3RyB3MeaX5we/vz3//xvf/2PXzy83kqO4EQUCD2SbcCuAl4HDWDZsXVFFSKhM1TyjLgALggLY0UsREsYT/HmEvYN43R+O8736mZraefuCsQiKV/Gy5ty/K5efaZpr+MeugmoPy3mjp9+yJl/b80j4rjZNBFTAyAVB3IKFOI0Dtspbbbjrs5zKeU0n0+XSxMlYjdrpWgVgiAiIg3AQwhpSCkNbcmqGgIjUmM21VZbLVXUAWKI22mDFHGzveEwqZEULaXVVrFgmKOahxCWXHNttdZaS6uLSGOgFDnFOAxxSMT4wVjU1a12T54nXwJHAmJkJMZIPriiVm+oFYS9EilET5E4BB/UxJdab0/N1UBNKt8YTEiJDYNSaI4MEBwRwADEofXhB/rq9EWP3RAiDoGvtzHfwOElvr317WAP6OKgTyv8UbzwQdeOePN861jTOS9LLaW1ptoDWB5H2imF7X6arjfhZsIxmCuUQDE6sRuCOuh6zOjJfUDgBM6GDLQdcL+XsJtjvv/udDyWi8vMJIGcCHzN0+thD90EAhAR1R2MwMjRPqCaPf3xeyY0775OKd1cXf/Zn/35f/wP/+tf/Pmfv3zxIo4D0GOgBDH0ySSyUcJhjzdfQj2rHe3S3B/Ia3cbAQeP7iP41YDPNnizhYkAK7YjzEc43uPDGQ5FCwvuffM5Xf0YwwvXwYwgOAJRiBRiafZwd/8vv/rNf/pP//m//v3f//KXv/jd7357f3/fWrOPFUqQAg4BCXoedKdPkbmXWi/zTIZWm8YMMYdYLaJZMA1KbGagreXamiKPELChWl0MFncn6mxKQ1OqFcoC+YQ6W2iUjBKYVNXiLogYE08xRndD2Ww4cg8D8K64DpQ2Q3y+T58+216/OScgNeiZHfDx9c6uG4AMCULEMIbNPm52Oddvv/7mcjrvr6+uX9xMm+mTH/8o7Xa//e3vZD6Ww8OIN8M4lWl7eXg4hUtVNIrIbFLz4W7aXU9TkqrzvIyXy74UAmi5WK2a8/lwe7h/necHsDxEiBRbs4WkZCoEoP6Y87BSgjpY9f0b6OSwdY7VITDsGdN9uLuS3x5hJlBFZVKjRMRdw2pAaqjOhuzEABGpAo3IkcZAG+IrDBuMRiHicG08qpOYqWhTEfNm2kxFq0h5mM9vjndvD/eHy3GWRZKxEnR/clFralUbGLGbfrBqWuvSLEAEAxN1c0cEDkhopqJCgG7UoxkQoE/lGYgQEM3RnQMTkDkDR2CyjoABIUcKRDEAGMXkDsjBmbrPFCUOIYY0oKAh0UwITowY2AEcwzBNA7EjKIhUhc6rXcr93ak1vLn2Jx3JZW5v75Zq3tTFQNW7ssoM1MAeheT9HzZo7uLe1HPRFAjBV966OYIStcedpKjL6XQ8HB7u7+5OxyOiE1jJi7U
6T0Ntrs6H0zIMw7TZXl1db7bb0HNNDcBcTVUNwNePjAkQRMXUECoRtfZHGPLeB9jIbvvl9MXdN3/123/8t7/5h88PrzeaI+hK/k+IG9Q9wRU1AGlACBSQCJ1RyRt6BpgBFsSM2JgyU07jPNw87L7I6eomRLJqUtzMgB37VkioFloeLvfj+fVyvmub5x4nezwzfqxld4dHo9kP1jriZr+vtbVSzc3UvTkDDmnc7/e8c9kvd6+/vX/z+nw5H8+nUhshi5rk0mqjlXzpIXIa0jgOw5AyMzhM40TMUueSW49JRYycOG2uaBjTEMbNlUHI1VTrvOSSZ9Vm7rW2GGPJS86llFLL0soFTDbb3W43TdMmJgbvQV7v3Yu5t+7TTtTHzOucjwjIhbWEeiEEsMF8JBd3R2VToCFEioyDN7V8ztrMG6gGhODkGICCcgTC5sBI5ogOAqCPQzIEAKTuoAhg6O4RcTuE53v87Dm+eeb7SYZgFeBpgT/i0Ab6/uOAm5c7jJKmkE55vpRlrrmCN+0sKEQMKU67zXC9CVcjskNuSM7MhqSOoODSW+s1i7fnmZgJsdBLpu1GttcXH95sN4d0bAoLY4tkiC7dXx0e+3Yi9G5wQg5s3avQiboT8b/SuONjW7iZNp988slf/vlf/Pt/9+9//OMfT5uNIagbEj4qZBDdHILT4OM1P/sRQFM8wKl6VmqnvheDgye2IcBui8+f+9UEKaA3KhUu9364x8MZTioytfgCd1/S1U+cbqQGZnByBkYMBjQv869/89u//bv/9L//7//Pv//539/d3i7LYvax1LVfPRmIgIgwMMfYSd+W83JhgwY+NBhyHIpaU0fzqBIaMqpYa6Vpc2R0jFC9ejmjKoiQC7oQOrmxVC/ZlzPogklpdHY2VbGsbh0un5h7enoMxlBBzWyd9WHapLC52vDL/XQ1pIRUfE0x/2ilP+6b/SRFToHiRJur4erFtH+uTg9398fb23FMn//sJ5//5KsXX3x68+r56XR3++3r/PAwUIhxStP29ts3xEtRMIocorWyHO6GELbPPpm1HOdlHi91WRig5Sx5aXk+ne5PxztpSwwQYkLAUgjBliXmzO7q+lTUn1jAP/RSBaDUZ/H2qKLG1XUFbHX7dEMwMhV0QVAmM5YeW9nH/+b4uJMyYgASYIU04CbSDmgCmphDiCPEbcAoZmou5qLWTKtK0VZaKWW+PR6+u3/79nB3Wk7FqiNiZY8KAN7Um2lTdANwfV+K6KsUDpH80ffFzRGcyMFFJQOYWVBiJCYMSIF7yjwSIDk5AhBHwCCOjgGIHDpmR8gRAwEHBOM4mAMwObiCOgIHDinGYQBGcQzMgTCE6ONI4kRhmK7CkJpLlWqq1pqq5sXu7s45w5/86btHc55FIYt5e+zRu92qO1g3B3hkcCM4GT6mmWNrHpkYAQF6ae+irH7ary3nPB+PDw+H+4eH+9PhQbWBKQIEok3V2iBXu78/p2HY7XcvX5Rnz55td7uUkvdBp4urE2JgYqbOYBLR2mpvpWr5Iwx5QHBi102Zv3rzm7/+9c//7Pf/+KPbX1+VYwRjMmRzthCNRrbRfBA0r6zMzgEU0BKpQxHL6tkwExamxlQjV441jJfdp6erL5GQvaXD67CcGqER+6N6Ck2D5JjP4XJPlwMNe6fwCGb5Ssd+MuHxlbr1AaEGcbfd1lgrU6tN1VRE1SFNiYftOIX9jSzL/ZvXS54Px4O1Zmqt5JLzZrsPMQFSD3WLkVKKKQ0xxsDh6mofYmjlbNoAmdMUpx1QGpuqtCEiQpwvM0BVxXk+LvOBCGs55zTGkFqtJc+lXKTOiXF/dfWzn/301aefNdF5uRzu3yzz7Lv3ZRhousr+nXokCoADKGix5ZhNLV44jhQHGybb7HWzt031aetT0sgQACFEH3ARe3vBZlCavVT8hAgZY3AARSjOPWKk218wAD9S2VcL30c7fmD3McDzXfjs2fDZM324bXn2prAmJDgQgBv
Sh7ECN5/sw+TjZpg2+XLK53O+nPOylNqHIUQ9jxb7By45XI40X6BUEMX3tsd3JPxmruLSIAorAFIBfDD/TvQglhwqkBE59giHlSLvT4Btn5cQ+CPx57EKrtTTxx/8Kfbp8eXrSD7S82fP//Snf/KTr3786uWrzbhZR/X9+/RMIugTxOg8wXADlChx2Lgft37Yab4DU1RFV01chxHSJzR+wcPA7KBnWA5+mX2eLYsLW3oBmx/zzU9p92WRjVTDKXBg5FCqvLn97p//5Vd/8zd/+5/+7j//8pe/ePvmTc5F9WMJ+NNFbgzGTDFwSmkaUgyE4NJqqyAEympBTNXUVLCBcFcYNGmlCoBQCO5q6jUbCDWhWlErWUM0RoumVKvOZ5MMUalBwOjgVYuYeENwJmboyldRaIpdQuFOCKGpBvdmCEYIDD3j6CMrzVVu+YiuICATD5urT/avvrz+7Ivts5etSbks5XDUvBze3MWQrp8/31zt9zf7+fBQL+eZ4nh9Q2EQ0ZqLFCHGOKaqOZ8e2mYTXnyKBmWec0z1cAyEUkstS85zrUURKI2Jr1wWbUuz1hQdmCggtJVXuhIAoLu/ff9xuKu79teyt+aiAt4MjVbbTnRwdEMwRGPwRtRiSOyRnbo7vTsZsEPnmhsRISOwKlDVWmclTWkzTnEbZABZSdvgAl6kzTUvbTnP5+Pp8Ob2zXdvvzucHqQVd/GMCiiGSORmVsWbNTMFl/ejxhA40OrMY662aoBMm+qZyEhqCClwcGD3wDxymELyYL1KrYO/brnvGJzYH7V65uDAjmzQXakMAZzYXNwUV2UEIFFPMEFwAhyGTUobaI0cKIxIMYRditWaNZPufrIszaG8f6Bv4p6ta7j6ZoCPD9EQrCdnuCN4t57rWKqLqxuq26PTgvtKy62t5pwv8+k8n0vNueS5tKKg4maOiAkpIRuyqC+lLqUsy1JyOZ2OV1fX291uGkcOER7j4plw5TX1qEpRMwVz/WOhrg5oPrRyfb7/6Tf//Ne/+q8/++6Xn5y+mbwRGKBCj4IKztE4IoRqhiEoRWclJPKBRLwoZPesUBALkTCXEBpz45SnF+ebnzBBhHrdasgzEnsInbvlHB2RrIV6CZcHvjzI9jmEETi9g0c/btM/Xi2IOI6JCdiUTIuIVGniEkdUn4Zpt9kc7t4SYS35fD4GZHdstdZaaimTATPHEJiByAJzr+vIeLXbxRTPh7GW4hg4jnHaUBgnNW01gIHrfD6roiiWfMj5AcCIOIYUQhLR1qq2Qmib3dUXn332V3/1b3701Vev3959/fXvH+7eLEt28ycinTu4rii8e88BQAQAc6u6aC5L4xRCIk6QJttc22a2bYHN4uPg04Bj8oCBYmhip1mPix4ukhUwMDGkAIzKaODoqyUwrjvmSn3v8U4dI2R0QPOEcD3xp9fwo1fD3S3c35eluQOvw0QHUng/yAMRr15s4xbGzTBuh2k7DIcUIzPh2V2kU2lIER2A3bgsfDzg+eRLhiaPCP/aeYPjaroJim4osBILmhxKfVPbSfTaUADtiVX1bpzz7kcCgK5T7V+vWJfAkyKfiJj5+/o3BIwhfvLikz/72Z99+dmXV7urwEHU+u6CDO8qjZEDA08wMow73G7i9aSnfZtGn9+iN3LpsTuFUPgT8B8NNm5ASV9DWaCwN1YbjTaw+4Kf/Snf/IS2n/nZpGogghiBw3JZfv2r3/6nv/27/+P/+H/9/Oc/f3t7Oy9Lv8cf9s7rPxlYIEqBxyGNQ4rMCOYqKqCK1um5usZbNgMAb26ttJybB6aB1MBERBZV4FJDydAKaiW0QOYIrCLzLDVbEBRMIQFDldqkae29VncpMUatqJ1WTYBMFJIbe1mw1WZq61OE7z8MfNrv/F1pf/H8sx8//+qr/SefSGv5eD5993p+c5tP8y1898XPfrq9vtrsdpvNVOcZlIb9DVEwNSm15jyMMQ5JFsiXY7lco5qrtlwqn8vhaIFMpdY
l57mpADPRhMrVsGjLFXN1VViFQ48Mv1WEufJWPn4inVrbT/AIiibNmiEyESERcP+fHBx6dXePRGIuAVJARAfvuWAYgCKSIa1VzrE0E8nndilG0/Z6ryiQtoZMPbHLmspc8zmfj8vp4fhwe3f75vbNm/u3cz6rNejMGXMRR2JAsCqm7ubyMVkWiPqk1HTNtO0K72rWOEjCCp4Ioig1CUrbIOaGZgzsQNyn6Yq9y3sk55o6dDfdTlxDYAAK7gBI7mbdpAbscWDbE6+V3EOaKA5WsktDCu4h8JTSXltzrW6irtYM+YOKKOam66zMoTfnfXoAq62KG66ZnBCIEmPo1Jo+VcfHxwVuZrXV4+n4cHg4no/n+QKEiCRAzgmQuqUGhMDDJk6bmCKB55zzsiw5n06n8/l8fX19c3MzTRNzJA4cQm85vLu0dO6Seo9wfP9GPijtrDKWy8uHb37y7b/86e/+20+//adn57fRFkIzcmV1NsUeQezEBuRIFLbMgEogTo2pmVXShlbRGpEEqpFb4IrY3BtxHXbnF39CacAq3ipYM6IWtxpG5IBEQMRax/wg862WT2zcWghP51/oy/mP4aYObfFWTWatRWotc56XDC7jOA7DNE57oIgcANHNwzTwMC55luXsYCFwjDEGNil1Oau2TuAidwYYmKeUckpFqTfRHHjcTK2w19JKzcultdbESjnWeg+uSMgUmWM/0cYY9tvdn/zkp3/1V//mz//0L65urh8Ol1LK5bKcz8sHCGqfs3Wnv7V37WD0yhg3VzOXRrh4yVarLdmWjOMGYqJxhHH0FCgwW4F2JmySoD4UuYBnSBjjTfcagV6YyL2nB8I7zwB8QhQdEamLaMmud/iTH42HE7y5rYvoAmSGKwPvUav2VEk2uymMEGNIQ9hs4jSFccAUkRmWuao4IYJ6ncv59rjMJ3m4hHnBUq2KVUN11JUrZQLeyfsDxZtdePUMr/cSw2mej6dTVVUmcVQkMFiNjXCtcI9CixXBJfZucEtsqojy5IHUWXIhpviBlZsDOMQYN5vNpy9f/fRHP3l+/ZwxWE8WJgSinn3kaLrORol4QIrEgDigDZRGosnLHWJFVEKxUvJluV3Cw6/jzbL/gq+3/CymZ7D5FNsdcsU90/5P+fqv6formG4IGteGIzkHJ6q1fff69a9//Zuvf//727u7WuvTUOEPDRfwMcOcOMSYYoyBmRB7VrUrqqo0bajFjIMjeXEwgXlpc5Y0bTYhDhV9bs2aNKFSsGasGbQSaiCfIgd3KbVVabWiQUqOEfuXbc7WtPdpwI9WloAEEBAjEUVuhLcPeHeUS249JPyH1jmsiGl/VZmAI3AA5lrqcjzFmLab7ebzL8r26s3vf19quXt9DxTBw7jZP5ze6nKxJgEwMaHJcrx3nZgDcsi5zOfzcj5raQxutS2no8aI6KKi7kAEFFRMBUvzpdic25Jbrq2JPCZsvifK+6G6DgAmpIVWOqArQmPCQBaYmJkxIJCvhBADd3Iw7ty5R/czdwYKRE9KEEEmYDMsrZxre3s+XVrb7W6eXV+Wmq+21ykkJja30sppPj2cHu6Ot7cPd7cP94fT4ZwvYm0N/nJwMTUBUsAepOYYugjvvbBg92VpDs5MDrBagzVtzdzhZje8/CRtNylFLtmXRcqytAoq5E4YBgwABOYILqAOXgCRkJEegyW9n4E6ZbCL0MzNQb1zEDAYCrTS6nJpJau2kIYYk4HbKs5WRkoh2TCij7VWMyUG5g+zrNCxm/KiraYJTt63YrDu8NF1C5EwBRxjiL0H8HeWguYmqmW53B/u7+7v7w8PVdQAIw8pJQwpJOEQArO7M9Pz/fZmN21TQNfz6bgss6m2WufLGcHdpGy2wzgOw5iGkUPEfoxyQ4Auf/N3ncp6fVjaTcZy+eT26598/Q8/+eaXX9z9ZlcPSE0QjF2CaTAgYIYhAgcHcifCDQOwGJSG2bA4FvLC3hxbQIkskWsI1akCNmKJ0zy
98N0zno9Qzmm5A5M23rRh75QQIUBxpCjzUA+lnLRdexy8T9zf254/cHf6oI13l8WlmhTTYlJauSzno5nGYdrsn109e+UUOA1E7G4hpnGY+OFWTd0NEZmJEFtZlstJWnMHVevU3+AwxjTEVLusRSuChMCuXBu2VstyLnmprbV2bnIEV0AgZKIAgEQ8DVfX+82Pv/rRv/mLv/ziiy97aPIyL4+l/QNmilsPafvgRrHvZgarrYGYg0Pta8lbw8sEFDQlH0ZPA4UUrZHMKMWgyaHhTGAB08hAFBgTeOinUHq0peuH1X4CRYduv+QIgAwQ0Pcb+tEX6XCKv/kdnxZr2Wo3soSPcysRII2J3ZgxDazbME08JogM6BYQ8yKRkNzanM93Ms+nelxiqdiaiVlzNCBbAX8zUARPSCnyzZ5fXts4ZLPz+TKfzw5OMZiDrOvyQ3TnySqnS27Qe0qwGSL2qI2OoiMzT2MaxuH9vHYA6MT43Xb7yYuXX7z64mp7jY4m7tQJEB3Fc3VxA+zCNAzECJGAR4cthA3RxuQEWIEVg9v5kt/cf/e7w3///fnlCZyvPn12fT1+ErZn4jNsGxrz5kve/gSmz5R30s5FWm0tom/jqGbLvByPx9PxuCwLuCOifxTd/NHV+a5EzBxCCDHSWto7iAgiJqQVjNVQXNFMrFY5LXouvgfHKUBxs9qatFKhZajZW0aphBoZbIgRSau0KsWaGwwJKaE0rVnqUrTU9Qkxdqvuzm9MhJHIuGTA7+7h9aGdcm3WDQN/EM1en2xnxkMIjtikXQ6HNufd/mq322+vrvfbfV6W+ubt/ZsHERgTDuNO7Y223HJmxESkWi+HW9c9B2aMtcpymS/HgyoyAqjky8VSCjFoP7oiIZKq5ypLrksucy5LKaVW6clc79YBPq3d76OMKtRKWLkkrgiV2ZXUIgcLToZIZt0Gw8GBEA0AyYC0e/qaGSN7318QFJCAGlBzP+Zyez59c//2mC9Xy2VupUrLNU/DhkN091zy4fRwd7h9c/f67nD3cD4tZRETQKP1WIJgYKrw2I8iAgamFDG8qyXuUGqH3lDNu2SrVDXzEHh3FT/9fLPfxxgwL3I5tsN9Ox2hSZAWqNtAsIM5qAO1Dn0wJeyEDEZAckOVAORNuwQHu+mFiYEbNEfWmpdyObYymwogBo5rWJW4mxJiDNHTCDaJZRMhwsDxwzVuRNrTSXsf2Y1jRRqYMkFAGjmlgJFwCDgNGLnzcq2pdU2RaquSz/PD3cObu8P98XShMKRhF+IY0kjRwX3cbIZhNDUivN6N17txNzCZEBEi5mVxk1bLggCuIm0jGzN18OjGHMyhm0oi9OUM/Ed07eQSZbk5fff57a+eX77d6IFDqxEzwYwwMwLDLsBuhLRD32KL0BCFuFbKBpdsS9UL2EyWA7RAGkkjClMlyj4sYathhDjYdFVwf/zsrxRwuv2XNN9pGDXt2nhtYUAURIAQCCyVoy0PztHjBtZl8QSWOACg+8eJuwgUmD1ETWogChwKAkqrl/NhWc5NhGPa7q7NoNViqsiBmAGs1jzP51arii7nhYG0KgG5WhMpuUxpiCEOMZ4uF1mOcr5ntwbUxERLk1zypeZLk6aau2IawQ0BCHp495B4txt2u2EaA7iULEtelpxzzqWUDzstfzLJJaDHPhoc/Z1Zj68gHyKRERiYolSw6iXrPAsFwaDu7C26EniSs9XXZlTFIJfRbXh2zVsC4i5V9TVsHFaCC3kXKyugETETRvZx0GdX5fNX9idfDUvF/G1TdV159fY+udwBTMRBmZDHgAOlhJEdwW1lhZQxcQzg2vKpLDmXItEcmMDddDW3eReTSEgp8TTSGAx0vn84OtVjobluOPI4utUGtnLq6XE4BrAKXLDnF8Gja44Toeqqj2OmGMIwpu12mjZD+OhEDzCkYb+72u+udttdjNEMnDrPBx1AxVSbm/XzAyHSap9CSAQYHTceXzhvDJsHgEhm2TY3t5d
f/Py//dPwLw/ffTv/m7/86i//7Mub/WbahmDmChi3EPeYrqTZd3dvv/n661qWaRr/4i//crPZfP7FF1999dXf//zvmUhV/1Cz/v5CX/t2ZCJGYoCePwHNvKlXtSLeO0Nrns1qrpe5XKoXY0y+aY6oUrXV0krxVqAVawW0MWpidPPIbNVas6zm7jIL9oTibDW7Fnc1A1DqFZsJIHTDO8bifhD59a18fVfv51q7mhL14+q+FvZeZQPGRMPQXE6nAz+cSPE0TNN+/+Lzz3c312m/34mUOWt++/KTZwEjYmitnO/vA1I3ITvevRFpVzefECYArqWcD3fMQ0BH9CaNIhNFB+rKZjNpNZflPF+O8/mQl1Otsz4a1GDnJD+CXo+/Pr60kVTqrsDgDqiMIGwqHFiIGBHVwNfEEWTk7sRg7tqftUFkZ0SjVVACDmqmDqdaDst8yss5L8ARKSKE1mwz5RiSuc3L+c3td7cPbx+OD5flsrSqJk4ACLZmmL73q88ZCJmZU/zo4MvMqlqqtKalyLzIknW3G16+3L36dP/8xTSNkRFiaNxpq4CXuZWy9JgT1IgcgLgfMQEQGNHcGcECBXIjFVSwUqp2/q0rmqw2bOyAkpfTMt9quxACUkAeCAgM3CqYMroxUkgUR6ojoRIC4nv6aoCAnsgCW2QIkYmCCNQil9ZaayHyEOLIkBgJjFxcQYEI0dxUWm41l3I8n+8Px8PpdDydm+owjCFu0rAliq6ERBSYeSIeukunQVDDbpvPIaSUXNUUI1NgZAIwbbUAuLQa0hDigMgO9JTXgQh/bNaObkHr1XL/yenb63o30owBSsS3SLcIF1Qm+DzqMLpuwEdcABqiRyqFFve5+mWWmX3ZQAnYCCWSRmxMBbliajQYMAI4RxnG88s/qZxqSNP912TqlNp43dLWGZEsujiHIDmWs6StIRvFxwX9rsatZ/YPtjMkJvbACYJhFAyxcAiqssyn+XzMOSPStNnlea7L3GpdT9gIpSzn87HWZgZlLuQI5j31T6Qt8zKmhIBM5FrbosvxVqUVSuKgqrXMNZ9rntWaucJq3YqAbmCEHa+CEJDJwJtIzrmVnEuptbYm7f1Jg/vjiO1paLr6YPUWxlfCNgA4ECB10ZWTiTfrTsPopI4KGNYAcUqSZXkQA5VWTZkwqDESDQQhdAp+/1Z9Jo3IfRQgaNZnf24+JkNur17Qz36Ujmd4e2jL3HefXiI/2LykVINKjIE5ppASBez7YkfLPBGHCGDScqulVFUhCIEAeml3AoCeFE0AgXgawmaCSFXr6XY+LlYyQqYByGNcuHV5vhN0YYE5Iazw2ip2XzctAFgtawCACJlosx2222m7mzabgcPHXXuIsZurD8NEFKxXRepSVTPXleDZP25e03gw9Apq7tF5C5Qc1QNZCD402mybf/tw36Rlk3R99fmXX73YPf+Mrq67sYaDO6Iztzq/fvvdL//7L87H0/XV1Weff/HFj65effrpF198vt/tYghrENQPdYeP99AfUCcSEhLT6lfZU7igKTSBSl21glXdalvmer6UrKgEQ9GlKJizq5TSyuKtuTZrFbQFMmF0wMhuzVqzLGZilQmjm5kUbxWloiqou2L3r+8hrN2hFI/N3mb57X399lBOWQRWc+mPYZgVj0dAQo5p2kxXV2kYwE2WYueWaV5OJU67sNul3W4H8PY3386Xi2y2PAZEUrPz/X0ABDEwOx/vHHC/fckQmaKKXE73KW2YkgOIIZsyuIGru0hreSnzeTkflvPDPB9aOUtbzAQeDWpWH1CAP/g4AFSpVTaDnq4BYArOZBZUmJAQVu/mvn6ZqUeAkbqICyEhECKq85oLAN5hjqZ2ruVccpbWVHOtIS/MJzOsojGEpvV0Pnz75pu7h9s5z01b/z59DrW+yB/kTHcgCoiJAuN746o+2lKVVjXnNq99Cm2346efXb94ud3vUgjcXa10sHGyUkHM1GutqOpkihZ6Orq7IhI6Apk
bIQNQ6D7FTSUvRc2QVqwcOwrVzL2U5VzyPWomIqAImJAIyQClr3oi5BBJR+KEKkRAmN6ncRB4IksMQ8RxCMyhVpsViqtaC44RISAwOriZmYiqIiA0qbnk83w5nM53D8fb+8NSqqjHYRo325S2IUzmaAaEjBABgnvoL4YaNYGmjgRIHGJ0FVcMjClyDMzkZlKrizRuLYRGHHC1IkR8pOW//159zJBnwEl1r2VDNQ3NR1hS+BqH32GqYDtsV2G5Si2gu9vJWRPxiLnhEm1BeGTFYw3YAkoIEkOBkD02RzelcubLrY6TxeibvdNXHmPdf5IOdyFfwJykSJg0bJQJ4gA8gjmVTBA8uSOvvIZHhxH8uK4/bmDYmzCOKaU0xDhYKbLM8+Hh+OYtkjMGa7KcL8u4mzY7NwPEvMyOd6rKcRDxsrTNOE7DJsZUluW8zMhETFVEzFTa3f0bni8SBiBGh7wcc75IKw6G+HhIJ3BAcDNHaTLPy/FwOh4fTuf7NI1mnanzLsru/c3riTn3+CWsB0z0Tqbot98JohycOn8AniicBEbYmUaAwAQICliV7s4G6mYlZzsucdH4wviKKaTe0DHSu8lg55k7AZgjAUdM7KByvaMvP41v7/i339FscHCXx2/1/k0sp4tBCSnAlMZxiinBRlttZUlmI7GDOKI6GLg6ugawSJAIEFCUALt7FBoYgSei3ej7qTIsy3y6u5yPWuogLYpiE22uQu4BiXEAQjYia6IOa/e3fqLu0ElM3YzCnJmHFPdXm6vrzTQNaYj8YV8CAITEFIhCV+xYz7VAKFVEmooCeAyBIyEAMnEIgGSGqGgO4MAQkJBQEQkhRArbAX/06Vf//q/+F1V59uz6y0+/2G42IRCQmtGKPTogiJX5fP/m7Te/Ox3PUvJ8ObsqE4XuiEHd3fyHFsT7twBOALE7vxN20iSsAyBUB3ESx9JzuYq0pbWitbICOKA0WfJiggxqpUjN0NRFTBto6xbU4CjsplabXGoTtAhIiQHdxKuCGKqh9OmxG5r2OAx1LEKvL/rNUb4+yUOx4oiE8O7V/3CBALojMlMarp+//OzHP5uunnEcy7Hkw6JGGMeY9ghjHIkxTOlhPuU2L6DMgAw+Hx7IfEgR3bVkzQuIBaQhjdIuy/leYh6GfYwDhiTK2KiZ5lbOp/vT8e5yur+c75f5UPLRJJs1cH3HU+k71BNz5Yeu1c3d3hfJWeejkeBTaUckokBshtLtXyKwQoicAidzVnNRI7BurWqAzbRKExVEDBxDSMwRgMSsSq2tnOfjw/H+/nh/zhdVse7MyURM7m5rcMP6Y/ZthwJxJKR1GPjekzCznuLgterl3GKKz57vXn26/+TVdreLgCai1izntixSqptTGhKGQIuVnFWqtjW0ra9SN0EOzoFAQUkEDa2J1LKoOTEj9vg1ZmIVt7ZIObtcCJBoJEpAqcMD5rY2O/BO8ElMHDjE4f0XCx3IvBMS2Z3UQASaMGCgQMTgKNKdVwERmoGZ1lYu59PD4f54PB3Ol2VpuRlQoDAQb4gnxNSdRt3BDNBAxQkdEInRDczMtJtdABGFGDHQECj20h4CBcb13GO15L65EEfm+DQref+9+rBrB0SH6Da4hKjEanu0Cc4WH2wwtQBQudSIRKCAZ0QligMvI85J5wiZqQQsEVsiCSgcGnOxkI3VnKWE+S4df1+HwYfB0yS755rGNt1I+DYdX8dyBFcHUgwWRuCBIJo65IwdJAoJMKyvmgP0AdVH/GcA7q8mszN6pJSGNIwi0pkJl8PdNA2JCUzyfMyXU1kuZkpEtWYxSzExJ6nFm2zGKaYhpAREc14cPKZYWxW11urhdKCcPY3EkRFrvrRWzIR4rbwAK66Fj8cNFZ3n5eHwcHd3G9LGMJZcRdQN1rSEP3B18dt6q+tfj09Ha0QnRuJHt31D8B7Z2iU3j/s3kSOK07lobdpUL0WzaQMtECHG7ZYGYg6AROCAPRoXCJ2A3AkRgBn
REMh2I7565p+/xFc3+JD9PLs4fP8mai7q2SyGQODAgVMK45Q2+2iemKHl5lU6K1vJNYJFcmZExce+AR1WpB/Rh2jTUNHnWk/H+XzXcrVqWhEbeAM1QmQIkQiRAwc2rOgu2qMKV6wUujq+13UACJHHTdrtx/3VlFIIgYk+ribYSzsyYrduJgN0tVxqydnVCBFHQArujt3UAwlKY3Xo3nGBGRGJgRCAA8Fm4C8//UL/SgF8u59efvZ8v9mmQO7SmtVi2DU/LpfT6eH29vbN63kuRPzm9ZvtbvfmzZvj4fChYcUfH7UDozNCoNU3jx6JbB0XMgcxU3VdpC21LQUUCFYvxNZkyYsGCKBWi9XizVzUVdyE0YXIDRqZmpWq59waOCvx4JTAzcVQnBRMO/PI15QBQl6MT82/mfV3J3m72FldAJF6vtEP3lT/gJniOG2vbp693GyviFPGkmgyjxDGYdgzjsOQMMTLZqOHQ50vrbhLI9PlckSxeH2N4Naa1aK1MHOMQZot51MNizQZhknTKF5Jliotl+V0vDsdbufLIS/HVi8iC3j9A3X9jz2X/vr5o7ay40nm7oq6IuAA3iO4FACd3FGBDJ0JIoEzkkFUNzHHxzBGBa+qzcSg+43HGGLnV5tbrbVJuX+4fzjenedLaRWhexhht6nvfJsnYcK6zSJwwBDZAxp+/7DS3NUURVTEtvv4yavdJ692NzdTGlaYTprk3OZZcobaiMM4RuoynJqlVTBjM+oP1b2hB/DoYEBkYGDcmrSczQFDJCKnnkjLIsXqReoMWokGwugeTNHWqqmmAo9DRqYnd+bQdWVP97B63pq7mJiAS60qraFDICJAc69NHt3FXN1KzZfL+XB/d3f79ng8z0s1YIpjHEbmicKINAKlTowA7EcNMAUjCOGRYP8IkzAHGIbASGhj4BiYGZmJA3e3gyamqn3n4j6n4V4L/kjX7g5gnaPZRpBI4dMw7dOrC8nsknVU5WA1ADJr4EosMTSKF4bL6MsGs2qJ1DZBAgpCI8rA2akKoOYB7uDITELkhcj2L228Ak66eVE+GWT7LM1vaTmCNBc1aK6oDN7AUL0pquO49dQjfx8/iccO/t3OhTCOSUWJBEHBLcY4jJOaSQgc0G1mhmGgU3TzVtq85IuacghNRFqJISBSq0sV2WzHGBMPQxiHqqrLHCWKNhERVTVgIObYhwPaxN2YOKXEzEQoKk2ECALTkMZxGIc0MMXTcf7699/OFRzT7d3bssxuTsgfYtmrqB3XJYZP6OMKwTyC9UiABF0b0tk2nc7RU18cyREcbT1rIDhCQzag26LlQRvVi7SsAwJ89mmKEZ0fwf6uHCYCCOv5oPtqKLLxwLgb5PlOP39Oby/2JksWdP94xXc/EXczVVVjQkAMkcZNdFdmLJHaTFRAwRpQAarmKUtYlPKqRTEBVIAuRldXsAy4GM5Kc6WctUitjJVBgzsSoyM5BDIgiaH337WKqPbPqGvZOxmzyzKGIWw2cbOJ4xSIPnwUTysfiTkwR6KAxO4gTZq0JRcpLYUQYwocCMhIHVxVRcvpshCHzXY3jqnzXYgNqR8UIAW6ubn2nzozTdvNdr/ZbkZmMtHT8fJwOBHFwKnV+vb13dff3r++Owfmcy7/5b/+n//9n//p22+//e//9E+vX79prb2bLvzhi7DHvRi6MhijB/RAHhAiYeSA6FU8X9rp4dSWjE2GEKeByQFMq7RSEMTssbRrPwiIuvWAODJVQlK1XPW0tKKABWjiMBEFcEejzvB2X5O6o3MSCudit7Xdt3ZUWRCMu8IBV0eHj/r2/m8QgRg51mbHh/P5uJiYQ0Aap81m3OxT2iQYxjhxSlfXez3dnQ8PeZkdRWu2tqCBaQMAVNNa83wMKbqLSl3moxlwuMRhSGmg0FUJNZecl3NZzq3NIotrIe/JGI+BCr0Yvq80+YOPZR0relfu91W6utOt5KLVdalDdGQcLQ4ek4WgjAqgDqpmAgbgCob
gBt4zwry7sgIQISEigYPluizz5XQ+XOazSAMHp9XUCQn7OIlT6NmZ/ccxMHcIkUMKxkGZ6L0VggghYmvemgDYdhtffjJ98aPtzfORAz1qAckBmtiS5XyRnHEY4jDGYRhiHErQnDVnrbWauXfBxMqIQXP3VoxY1Vpt0PO1iJ27sR1amyUfoRY2RGQ3bk0Ni1vtHhhulaAPScmkgisSEPeo1vdWByEzNtFauuv86kPXMWJzcVGHNZq2SSu1nC/nh8P96XA8n8+1mWPgkChOGAbotr/MyAwcEJz7Kw8EhEQYA44RNwNNESJiIIrbhGQqhVwTr+N27KNYBzdn5hhA1FTMtDYVYiZmsz88azcEIb6M08P+6th2N8N5fMl0Fa8H0CCNKrUagsmEOAQdYiNuIRmEGSEnL1ssSDVg2wQhUIdqVIyKYRMArYOfaHaCRjERh6ZVdqLj3uLU9oNMexsnvtzx5eCtKJA5gaKjuwtAg9AwKAR3/H5df2/ZIE4pCtOqgRYJMaQhObjZsNsPu63tJouJQjDzVuo8L0cwBXCTpqaaEhO0ukgtOW8NAZkpxNoutWpTAXAi5hB7fqnrOmjtH+44jlfXz6ZpE2LKJS/54iZMsNtstpstIYUQSi5v39xeFlHg2zev5/lkIoQ/1LU/rfZ1Z/PVuqbfvj3yw7oNGD56IPXCAb3hoXW7AQPE1f8HUZkVqArpsTbTMbVd4iHxZgpE/Tywfn8AAGBE6nDVujdDCGxTkpudf/qCvj7AcK8X8e+RnYACkxI4dINeQgdwjjxuEhGEEJg5AyGAuBag7JSrpqK4GORHoFIAtfOl3Kq0WrPz3HwRXsRK86baHBXAGNGd3CJxGBhjdAxzlv7pWe4nXvf+MSEw9wAlmKY4beIwhZS4S41+YBsGNzNRqa3VJsjcVGqrrYmZEXPsmv11mCDLPM9LOV9yGqbPPw8xDeRMgE+mqYjAjJvthphDjOM4xiGGQI4uNd/dPvzqt79VA3AuS7l9ffvP//K7b17fbbdTVrl7uGu13t3dvX775vbuTkT+gEzsg4tXlHddP13PEwgjY0BCIBEtVQ7HfHt30Vyiw3YAAw4BOVgUqRWQzEGsFm1F1XqcHpgHIgdHM0YX82rWBIq4qoAaO3LqHpm0FhxApgCUDGOFcLB61/QguoAKoUfsWb4AgD/gMvt4RxzjMLrjPM+SS7lcwrAdNzcxTDg6NPOqqBSGOG7GYeL71+f5cMuRvYfPApo0RCQHlzaf7zlFtVrKnOdTawJ04hhDjEhoDrW1WotINanu1b0BCK5e8U913d8V9xXN/gM/PQJQP78AERKtT7BP3p+Kez+0AwOScYCQnKMTGUBv11XBOrDbVbndptzA+l/rjzXbwdUk5+U8ny7LJdeiXV3dKySteCAzPgXtrobpjubOTMQE9NQoPK5xws0mmXqORhzCED95tX3xctpuI/WTPaCqt6YlyzzXy6Xm7KrswNNIMbIrmkKrAlC7jRtQWmPD3VWbAiuSqZsaIKOZE7txX9ZaZylnbjUYAbIqWVXx7NbABMXQADt5Sg1MmdgxINGHHGxAhJ6IWEsRETMjIuqye3RwFzVRKa0seVmWZV7m0/l0OB6XZalNiFIcIqcppk0YxpDGkIYQE4VI3BmCq0df39Ujw5hoO4YpMroGxnEaUiKzEUEDekAMBACmq8Wd94G9NK3YQwCqGKKyygdpER+UdqFwGaZvX375z+0v9uc8epvGyoPBlW5ZWyxQGwf1TbD9IJtBgKuzGGfwNkTZsQwm6BKwOYhCccyGraE1RzMGQV1CpXT87egyzw/L/qE8+7FevfI4QEht91zThNM11uxSQa1TTQEIYsI4ALM/lXNfXYE+6hMRYGAOAB7MBIWQCUPgwNOQ+NWL3RevrjiGXIVIRWtezny8J3TXVktWUw7kPrSWpZXLfGnd6AdRVLVVB4sxbjZ7RGyqqtpzQhEA3Zj5+ub5T3/2Fy9efrrb3Vzm8/3h7XJ5aPW
ym4bdNPZNkNHn8/l0XpYqb2/vTseHJhU+TnR+FFo/XittyFai1ApWd2IUrTG3PS+bVn9D7BujORgCIjLBqvMmhIgcI8RBOc5e3xztd9+27YY3G3b0TR/EdjKLASi6grsCGKI7giEhQxx0t4eXL/nFHey/04t4dv/oRuIwgIqrqGhdMlgKiZnDtKEUU4xCVNEKO4pptnYBHKsNF8XZY0UF1G5hYRAMqICeljr4MgyzhLlQ9ihoTmsCNCG6O4kMEPbbYdhuKQ7nuRIBuLXWpLmqdE5QCMzMSM7s0xSHkVMkZlT9oXhth9rqeT7dP9y/vn1LIe51r26qCo4xpRAjRwaC3meVnH/z299++/rN+bzcXL/Yba/3uysLoNTdbbzPVc0BmcIQKQQI7IDm5G61+tdff/e3f/t3t7d3p/NclnI6Xn7729/e3r4ZxhgDt1ZKKaXkfukf8Kj56CJmCtxBx77CiJhDYGQEErE5l/vz+f5hvj80b5IQ5uZz080Yt5sQIgzNmQxcvIk2EXPtTD/iECOnECIzAalCMAX1rEvTWqUCoBAnprBSwojBODjFRelU9LbKfZMZTONaMJ6Infhu7vvBYnfAmIbt7nocJwLUMpfTnVxmnwtls7Okzc14dUMRdRfExNnMFmtn9kTu5OKAIiVQTBzd7HJ8qwRiUpdTXS6i4kTYaC2N/oT3GLihNwRxUAB7X5D72HX/wRH7u8cRnGM/bAMzdGKBO6DBupMruPZ64ObIsP7Nrn2ioWTqpICK4EhusFYQdXV0ZCRDMCBCBDcTUZ+Xy2W5iLYnl0d8tHhCpLDSNqDP0LvYg4D7V1bFGZw/AICJ6OZmP6SUYqVA2/1w83w3DJEQAc0AXD1nOR3r8bicjvO8SK3e1GuTWtuQNm4kampZraqoGRENZMpuQNGADVC7vt8ciQHUnQ1IoJmytbO1BUXMu5U0IAlAcRU0Y499xsjckfkwjtS0mLXW2vuAiok0LK1maRWJUgy993AAMxNVkXZZLqfT6e7+/nQ+LWUppTZp5ogcQxrTuEnjlIYpTZth3KRxTMOw+u+6OxjaKupyN0SKgbebcTtGKdXdQohpSCFuGJ1M0I3AYHUA8m5mouKExMSVCWvNpdWSP8p4/KC0K1GO4+ubz/4F/uLmcBlLux4fRr7gVCEYBYTGHlg3yfdbickaavNWoSlqCLJxURU3cReF5lAdq5AKkDq7ISlYdUWZ3wYRya00h+EaNs+Ao8foNFkYME3YCtbsTUBXF1MIEdLgFD7Wknyva0fAgVkRTU1ZG2OMPI7DmPj59ebzl7tPX27m3OZ5Fmnm3mrJlxMToptKEZNWA4CJiprmksWhtWambgroMaX9/vrm2QtEPB2P8+W8LHOv/UQQY7y+ef6jr372+Zc/vb755DKf7+7fnA6v58tdYh8C1JxrKQTm2pblcl6KyhID4MDu4YdR4H497durZhahT4YQ1pDCPkC2bvhKqw9b/4QQ3B0RnHBlKCBgt+AZEhIZwtzq7Une3rerK+LAXedGSGRM3cWhb2dklB43WkZKNG7h+gafXcPNRk9FpX7YtiOGGJyCuYJZKxXckYYQUkoDRA9B3IIUpCpSSma+AI2KQwVQNEdFeEICIgCbSy75bHP12YalUjFSBOeuE+gkAXDV4D4F2o1hmNIY0W1yM1FDwAKAYESYVmsaJ/ZxDEMKzH24+7H7WQcmaymH48M3333zq9/8GjEgBGIC8BCIaG3/EV1NluXy3evX//CP//Cb3/5OBH70+U9+9pM/qy8+Ie67KBr642gWAaE3So7oiJ0nnUv77ts3//gP//jr3/7m7e1tya2WerlcSskhEKCXnFurK7vy0erhX72cyJCaOYqxKItF82COYG5aRA6X+c3D+eG0HBcBsYiwNM/NmgMGTtErO5Ohm4up9oAzpJhiGqbNNE3DECCQu0gtQqEhFz1naa7NFMwAAzIyQyCK5MTN6Vz9tsh9lZN6QYBIkZmVpRMsm7nAR4u
jj6T6axgokKO35rVAXdzFqpfKkKHtVNU5skiSJas0twpSCQABGU0BTYuaB3AVWcqxuYi71mw1u6vhE0X/sY9+moKDACh+yCkDeFfX/7i7FgBwgJj6C+bEq3mu9fOd9aE7mJgLurk7dZGMK/agTQRDNycjNsRuwt8PAaZuXYbeZRBECOBNpLU6lznXRUwf+/X3ZG4IfSDc8V/1bp2MDmDdjM0NGJDd36NkE9F+vx1SHIYaIu/2w7gZAtHjRwFqXqpcLvVyLvNcc5bazF0Wqjm3Ic0AbAo511rFzIkCkQfqDGR1I/PVDcsdkAOBGRE6uZICuFy8ZXTRbu3ghOagAtpQlZCQB2QAMoSC4ARGruYV8ANLZjPRVlSamUSKgTEGIiY1U7XaymWeD8fD/cP929u78+XcRMwdiTimkMbUHTfHbRqmNE7DOMUhcYgOYH2JQXcW6nwKI4IYadqMu01aEKRJ16PGEEJANCVXBjUTkYZmToDmhM7EytaN6VS1NfjoEPlBaXfARvF285JRr0JOM31Kv7vG18AXDDmEMbhxDDQkmLam6EvBuUJu7qQ8KIMqqoqpmJIYt4YtIzZI5gEtIACRgiytFJsVZxmKiT5V5z5CB4qe2MMI3o/GK4LcjYZ9dWWyx572e8sHIXJgNIsmYpl1GGJM4fnV9NWnVy+vpu0YL5e729vz5dIQIhhorUDUIS8EsNakf04URFVzblKkFUAbp+Hlq1efff7Vlz/6qbT2q3/5pbqINmidrBaZeXf97Nknnz17+dnV9Yvt9bPt9bPz+eXlclfmhzo/gCpKHTgyuKmB237/3MCXpdam4X0XCPjeXm3dW62L3xCftKdkuJ6zqRPogAnWBv2xdnAv9d3Isa94JIYQbYw4BQ6RhPGQ25s7izGRR+yNMoYInAAI3FwxWkDCAEZujB4jjTBu/GrvL/d0nG1pxh/+2Ex91A8Gak0FnGOIA1IYGBFBUwIO1ZiEqSBfMCRSDiCuI3dl/4rdCDlGENBc4CJ8UVpaqE5CsDozBXBCQjcAULVccIzjwCEiXQ+RKabheMqXObsJoaXIMQVCQ7AQOYR1goh9yvi9g1Yuy93d7T/98z+lsCFP2/F6s5liCL3CuK0QUin169///uf/+PP/z//3//2733/74tmr3XR9eHi4HE9EW8TkigBmKkTAIfQZEz4emQihaVuW+eH+7ttvv/nm91+/vb0VEVPrntH9Be0+3evb8j9U1gEAqnlWR7AGorkaB0ASMXZotZ3n5eGy3F+Wc665gRuFnu8HztVT9TH4SNatRHA1+XGmOG2vttc3VzdXu90USdnFW6k5j8eZGUwrAFTA5mgG6MjEMcY0hmKwFDnWdp/rsWl2cA4xElNiiDVYARVo4vLR41jNFxysSjtdciMLCYqMGNgpqFHJ7meFWCmcGMocXA71nKFZQGAHJIiB0cwkG1Rz6FTbaqLubgImDuvYDVak3QAAvHsjdDTp6RH0/cfX8de7LgT+SO/OAeLQNS+OHaRwR++ukACApuCCUkFqzzlgFYT13A5r/kowAkNURDZAdDdXMTVwwBX86GhWrTnnXGpu0szt3fm1v3y+uguDu6l14n53aPA+iOtrkRSItL0b7iLiMMQYcZoiEYXIhGS2cl8dwHoebGm1aqteq+dirYlpjefCHN3BzFszBBincbsd97sWAtZmtVVQdkU3cyeDgKZuisSgHeQw0AWtIAr2ozl3mqq5C7k4siMbOpAZuKiVVtVqDD4M73MGwFTVenCL9wLkJgYmormUw+F4fzjc3T8cT8fLvKg6c4zMFGOIQxymYdyO0y4N25SmkAbi4AYi0onx/VTmDm5A68TPY+RhjOM0dvWamtcq7p4ipYCdHC8NWsuiwsxElFJQNWgtQSAiohBT0pOc6+npRj7Oa1egc9wBwK9sCWhzoZdCMZ5SWKYgY/CQAqfonLyY68mK2VwdHTeEjAbeZRjNqDXWClCd1ANgQGREARKIBVLGscadDHu
NAxADwkr9Ws+ZYW1DH/HjblTk9q6ov/v14cpBgMhshOahRR8HT8MQh/jZi+1PPr/epKTNWrs/nksuzpzWEmoOa9AweIdYiRGDdr83VUQchmG3v3r56eeff/XTL7/62Xw5f/fmm3BIIXYTWQsxjpvN/vr57vrFtL2Kw4YTcJrSNE273eW0uZxSiIOO08TAboDHEMO031AIp9MlL5U/dDb9AI1/GuP5yqfzTqvETstwQEQD77Acka+f1BPatFLvHmm43WgaAnuKPg0QI2D0YnZcfHsIWK3MatXGaNvIu4QDg6HwABzBDZtCqZQzXxYwsCHh9Rb2I9xdgD/cxwIZoQOQNq5V3FVURd38Mc/BsYe8K3KlsFBIwTiBM+lqxOW9v8UAmEDYiuOl+SxWzAW69SZSRO9qW+yW1qClyZItYohhy2Sb5JRiSnEcVAVRQsDISKBg2sW5iLgyE+H781E01VLK5XI5HI7ny6XkOoQhrl4H5Oo111KXt7ev/+Ef//E//5//5R9+8Y939weEeLlczqfz5XyOAwM4MQKam3Lgbtr/XnUHAK+1Xs6n+4f727dvH+7vT8fDD8a4/V+4mlkWBUcyEMoKYGYlBDJrtZ4uyynXubQq3S6IpL9iBlkhN1+ajuwhWWDgjuA6EnEaxs12v9lfT7spkLI3lBJTdPNWy7JQEwAFJJQu5OKQQhxiVDEHreZZrVnXsoUYwxDGgKmYUCtFBRW+p1eADul5E1kyK2PQYJYgMvQIT4NWvWaZLwZaZkQ96nKx2tDAUR8xfhfJpkYAotLaItob0m4A81ihe11/OkI5+Qf+AY/P74MC/68D8hwgxDXs7XGShY9dCyC4B/Sea+3g6qLu3qlBYJ0nEWAl25gj9a4dzEx7wo8bAHQbKlHJJc95rq2oia9B8o+T9o6/ryyxXtrfAfaPk8/1R0U0k3cv5CNnJWD0x5XjZgpdRtuBfsIQKMaQkjWB2qCz5s2cUMxXVmuMcRjh6oqeP6cYYFl0yZ6z5GqlSlNySw7BNaAxEHd5OVpBa8iGyNQjTM0AzVXcKjojkIAgiXtzq4B1SLa/CtvNgO+9WO7dnWIda5hKUVHVuZTT5XL3cHg4HI6n85KLqPYRFMfEIYVhTOOUpu0wbVPahDASM3ZXGdXOZO/zY1uTftEQiYC7wi2FELkJqXlpauZuRBgiMwVCFTWXpqqWUhpSYHYiFFY248AxxUu7wPHde/W9UFd3dVpw+JpfNm4Hk08r3djhyper0TT5OGGIoIDeqqObqiwV0WP0gIjIhticWoO6oGXjZgSEBIJcMC0+XWB3js/n7av88k/ay5/61QtP0Tt3v7/Nj8X6yYymG2qAPUZrf6+6f7h6MARyJ3FPibaU0pCur7efvdx/+slOmt2+nZdGpaEjhzh0c1nkHsRJCOJMSNgbwD7jYE4pDmMKz168ePXZV88/+XzcXucqQAxEFAObqegwbW9evto/exnHjWGoTd1BDZjGzYZjHHb7GysXaEvy6vXC8bvT5bTd7yhGAo6UP1RbrTycp/t7tIlbP561DaduZoFr3VcHWMVvK65hAIRr6AKsDLz+/zFBJEtBhyjDIGlQSiBAhxPN93T/bS2n9mxvz6/Dyz3uN0DBBwUf3Cstme9O+PYeS3VAAPPNaJvRE8P7anAED0ERncekEnCpVUEMStNQGyNZk1JrqxVUO0eocpyTk6K7CThq11a5O1AkjKTBC+JFeHaoCBoAAiEDRiIG7307ISDVqst5IalxiJSGGIb9mIAnTEMzU1BGZRDU5iq9OzN3VFl5Mx/vy84cp3F89uzZZ59+uttt+8TT1AgDEavq5XL53e9/+8t//sXf/pe/+fkvfv7N6+/AsUmrteZlmS9zGqKaEmPXtAAiqXVM5embmXnO+XA4PNzfHw7HnPMT6v4/3p3/8OXQRKs0M0TRZlZbazknpgBmqqWJqXVvXOils+NDiOLY1ErzmTGyD8E76ZoNmDASEZKo5SaRIRExJ4x
GQ6IhckKOGDp/KDHFVakbiVLAFGNKFoMGM0WIIYxpGEIKEDGYMUJAjO/vwOvS78ModIPW3JsLAyLhwEiM3JcnmFkreqk2q7ejloPksrpKExq6mtS6iLYVErRud7SiRavg/OMijY8/xdOXj9tWt15/HBb+a5UdkJyCPQ0X3/116xN36J7tiO7UCkgzcSB0cDDAwODJCZzJlY36anfoIeMi2pr27aFpU7ElzzkvomIdEgbHp70BoIsR7T0FH8CK0T9yAR///EQPfPdmrdRUh1V2Swhu6EBIyMibafBnThQ4lDi0EGsMoRTpljvgCgBING3Si0+mFy/H58+GlELJsCw2X+rpXB6OecnQZFCL7tEhuQU3N1PUhiY9QQ/IiBr0e9TqWsW9gDoUgIKoKfpuSzfX8ZNXm91+w/z+rosrkcm9H49KLpf5cnc4PBwOx8tlKaXL2okDc6AQkQJgIIocp5imkCaOA1H0nu66DmtWAo7hI57jDqv1O/XMUyQARO3zBzV3joEsOgIRMSKrmhQBwGmcUooxxiat1cpMMcZ4/sAx9+PS3l+nCvEhXLWomeaDtmc1PS/nl6bPya8HH8AM1JEgiHFVX0SFsgZ3ZjINudKl4pLBWsdzCDAIDY2mOV1dxpu8e1Wuv5AXP7GbL3za9syg92fJ7yFaa8YHODwCY48Z7R/8eu/JIERmcw8WpgTTyOM0XO83V9txnMaztixYFA2465gcCBA5RAqhnyCIiYiAgjuaKLrHEKdhvNrvXrx89eqzH11dvyBO5ogUQogxJebAGPY3z1+8+myzvzZHUSFSBAJH4sgYYkoIO9BKktkWyceliSGmISHimCKIFny3v/8AIL++ee+tJlyVb11gZubaDLqFnCMCrSxbf+/vXJnZTmSBIDImthgsJQopCPip4HIEPfrrX+tyaC+f+fmly0uSG9xOyOQ2myBfznz7ln7zDTbz/c5VfBpsO/gQIHzYXTEaMcYUNWJVgKKiZrkBZgK0Jsucay5BDICUQ00R1d1B3RsAigKzm7kDRsbASlAdzgYzYQ1oCB6QIkGkLr/vDh9AqEBFnZoqYkByoL7lBwJFNGAmCOQg3g15vAOXBIRA3zMtdwAiijFups1+vx+GAdDVtGq1pV2yzvPp9v7Nf/+nX/y3X/63v//lP/zm69/OeZnGTWut1ppzyUuel8EROHJKgYjcQdW6DfETA8vMTufT69dvbu/vzpdzfRyo///h6sF+jwypJg1cyVofnhNiIOwnpKDe1FZ+3KPlgkEXFQBGogHInRSxOZlpq3W5KFipQ4o8RB4iopFSgBg9RkyVyJgZxkADEyMhkkNE2qS0GWAaXUAAPFCIQGyAbmQWwIy7sdgfuicnMAZnAEZ+1GYTEDuBuWhbpIlY0XqyevKWXRpg92tANVEXNfFe7N6Nyn1dL+935v/Ktc4O/8efFpIhrzDtas3Ti2jXAWH3VgYK0KNitXsPQ5ejoUdHhMBdz7k60aN3HYe2JrU2RCY0c5CmpZamzd36zPepqq8V259+f2LNP97Y4x/7+fKjU0vv9Lu8tVPy1jA0IF8DKTCmsN0NiBAijVPcbGNepFbtU2cTNXdiHjfx5lm8uom7fRgSTxOMk4+DIbXSMhIBkipKg6peRdXMRU0rWgNncHQvpJ2OqmbFvJIbgSLVEGQaYL8PL14ML19Mz19M0zRejvi0vNSsmqmoqIhqq/VyuZxOp/vjw+l8XloVNSRmDt3yhnpSDEcKKaSB40AcAdkRH1nQjzjpKi10B0B6ypiAvoeINnV1WHP2ANzNU8AUSGMApBhi5VBrlSYiEgIzEwCDBzZ3h/BH4mEAng6O2MJwHnY1Pr+DZXOu161+meWLZp8T7EEBDZ1wHH27kfNFZ7E5U+UUR7dwKeNDDudGaggIRsEwCW1a2tf983b10p5/Zs8+891Ln66durUiPNLAntr19Uy8bmr+yIR/r7T7u7r+wVIKhA4o6gPzNKUUOFiDWmtuS2mXJuL
IMXAIpM0AKXAaxjgOnTcWQuDAuppCWeSwm3Y3Vzcvnr+4efZyf/MixqGWKrUFikOa3DTF4fr6+f7mxe75izRuljzHMKRHazdABAKmQDhQEAgT2uhIw24epIkUbYXAhkhV371ka7v0+Pt7bNsVO8Oe+LqiyNjHY60KQJcEUFcXGWiH1Xqb74iAPa9WAlNkDMFDoBCZAiwiD4u1B1veyO3Xmg92OMs5g1gwoFcEw+BtkSJ+vh9uv6Xf/RoV8dUrD9GG5NPoY8DwIXKKXdfDQz9sGNTaxFupomButbVLrrmgglOwGBuCOTT3ap77mZ8MDAABAjuTAlbzxWwBF0BlgkQYuWdV9pfHEJ3YmDSyJmoBm6OV/19779Ylx3GkCdrNPSIys6pAgJTUOz2Xs0/7Mmd256V/Xv/H7d4dtVrslpqkeBFI1CUzI9zdLvvgEZlVAKS+cR9mDow8AKoqqyojwt3N7LPPPqvNy6JQGlQkI5YBJWMQOlAQBqJwJyoYXVKZdWdsR14nDBMhYZArNG11fjw+PL77w7dfff3177783ZdfffPV2/sfT/NsoR7etCMTrdRWlsoiAxMAEjEAmboHkAQH97xRVd+9e/fV11+9fft2ns9q2t3Fz+Lgh8TjIKbgHr42uAdBJKYhS87JAavaUttSrBnYNsmPKJghDzzt03SgaQLWAsXcWrRSjg+1VsgD5mEYh3Ecd9OYEqlz4+w5+1BdDIR5Sn2KgIe7qkjeD8Ody0kpoEIoGGAzdQ8FWxqYCoVkepFdrbVhiD7ivUvIp8TAtMLsfRC4WSxNo1mpOls9ejuBLbGmHV3EKYIApYuN907m2DKMSwHrz/v1rRjf/41bgym8dIAf/U70QFv7WtezDrYgD9YpRnApz2EE9spMOLgHYohcdBo6yc0hwsxVrda2LLX7Wu91dNMAACJCDI4te1qHogNclWWvl7uCVxeHf03in5t7/50GiEyM3u+a9aOJAIkgZyJOwwQ3d6m1UTVMvXePzotWtUBKWaYDprTe0pSJCDBiKTGOMIy43yf3NJ/x8ckfnhZ3C22uDVzDxDmsAWKHJXqm3oSABadMN7f58y/S51+MX3yxv70ZWAAAz09X+KU2i1Zaa6WU8/l8Op2Ox9N5PtdWmhkgiXSyLgNSdAYnJ0lDymNKI3EOoD6Qrx+41HuGgYF4ldWEPkfOOjPX3WotC7uqbm2kK9FhWUAIB2FhGIbRzVpr7j7PC0CMY0YESRwebsAvT933XXv054jgSDUNdXp1rvpU7OjSdFnm5fxQX9UlUwgAOSNmTztg4zL3yZgQUtt4xpvH6UYxA7OjGGXLOxtv/Pa13b3Buzdw8xryDjhBGKy99pdd8JwDv67qzeOt8fSzxP165l4sCZoFRgjjzZizEHhjV2+tVp2rtQCUzCmLu+Q87HaHm7vd4YaJiVhESKhT+DAgSTrsbu5u7l7dfbbb36Q8mnrxEg4pDbv9zTAON4e7X/ziP6Rp54S1LD/VH8IsCQ95Qk5bkZCBKIIRAjCH7Kb9KwLXNocWiham8/ePVS/klFVD4sUDujh+BFhVkYFo5Sq4gjUHCFMn586K9WsbF3YHHxTEIQlSgpRIhIhJg84NHo/x+BDzjz7/MR7fRTvF2WMBN/TmaH27KLrjcob5FMdjOPpyiN0eB6YpwZggvxyqouoYjuwWpAqtRStm7lgV3KNqlEYAxEKSAkSFFaBCqHr1YAvyIFgpQ4aoARpQ2RUARFgEEiphZ8/2AXQQGEQmrIlbEiMEBLO+c5zBBxJMNGYaMpl6M0TkoE5HDHR4T92pr0N3r609Pj5+98N3kvLSFiJSq/f3P7398Ydvv/3qD99989333727/6laU7ceLHd1i6baGyU3OXsGJLNQbSSYKXW/3pqfz+cffvjhd7/73du3b2ur4fFn/cq/znZTbvtBqzft7cORCYbEu0EO++nm5oBMpbbTvDwd59JMHc3CPfZTOhzy7d1w91k6TDBmg6IBbGpmZla
jhLcKtWgdW9WmmrIAWnNwSThN6AbCMA5BqOYQAApMxHnYDcPtXprOrZxVLdStulXrPdvMyIyEH96FCHALVS/mg3kJ6JPEerEXHXuMqE2XpovrHLaAtwjz6CrMjB0/EEZgd1fVbbATvAjtoBe3/GWAdRG3j+eg2CV1/2drKB3z3/D9Tk+/VNP66BMMQjM0g9732s9E8wiHPlem66352g8e4NEL7WraWruSBSKs16u6wPBKLOoarBGXOVvx7N1dDLeIZSsUvvccAszD1BSJ0AnRe3m5z4GFLnZPkBCZJef1nZvacm5n8Lmag0uiNKKkQPb+2wjdQNVahIlQyvLZKwGno0SYlrkqtKotTAGMOKUcjI6oXQmauLF4FhozH/bp7lV+88Xw+s1wdzcOg7i76QssaC5lOc2llHmeT6fT6XQupahpIAD3nAX7GBekzN2jD1MaxjztJI/MCbHPgOxFmWvvZk/aoaft1OMrcPdW6zIjhfanCQHg4B7qXigSY8lMQ0opxTg1ba01Va0FmVGEiQCZgfA9VeyXDPmAtf4CAeBIDIdbF9E8nI6vvisPR71/e//T3X29ZZ9IRZyNkQ4p4aTEqlERgjBGHb+Yp/9Yh5tIA5AAcww7mPaxv8HpJvIEMgLy2soQuMWFl6j1mWt/iZCtb+9SYr8y6bYViJAFa7hqE8aceZ8Y1YQCvFbVpVl1AsqcpoHk5vbVZ59/8ebzX7767I2wrGpv1JVtERGFZcjjmMc8DExsHh4VOjY7DIfbVymlN69/+Re/+k+llu/efnt8emi1aF2SpN3uLuepi7dTrwD3dlg0Bhl3tzf7IbNlMoEa1r5+9/9eXPsqRLd592dZO8CmbtVd+ypg5eAWZgEB1pzVOfpxuBLqt5QdgZAEJEV37Sw5SM4NarW3P+GPP0X5yds7aMXDoBZY7mNRP81YFiwLti9STlwVHIzIAsAN0EkSjwl2yeXZsoqAZVHA2pQtaDnXNtc2V4ugXohRE6CUx4GEWDxCvXNZw6pWNem60oSIYOEaoP2EYsHMeRoxpxoRZq20VY57PVUiKIABmYgFiILdyVh0Fz4MkqfEwkiwFIdixhScOmFp5SjE+75EVc/z+Ztvvy6lffv9t3efvXbXUub7+x/fPfz09PRwOh9LLU1bdCnAiIhQM+3SVtiHFwmzIBEEquo8L5JZMkeER9RaH58ev/nmm99++du3P779udhzl5Vzexip7Mpca+1ssRCIKfPNfnzz6vbzL95Ikrm1x6fTcP84L7Wpa3N3v9kPrz+bPn89vPksjdIEiiObS9gl2HYkj1DTMi9WrFJiEQT04EQTEaIzmYh5mGsxWwwG5gkGSXy7G5YlTri4qTbXalo1zDGiC/d/4Nn7AeDalgWOYaHaVhngXtUIt/AW1lzNqlkDb9j7xCG8I/LoDEzELCx9vk5FNHN/Idm0CRjGs6dxTTZeZu3rt22Y+j+XtXdR8V7f7bS2IFx9NrgFABKhOZh6Z6xDwNqLsZIDfJ0F5+7uPRN071Nk3NzM3GEtqzgSIlHqmv0I7i0amF+Q44+iDBt60d9wP4rez6gioMcYEEHEhBjgvqJogX1OUo9X1lJHYHiYLvNyf3/68aEtCnevdnlMDr2hEgPAzJalPh7n86JEaczDNBBCgFud2ymXgmW2AuHEMIy2u/EsIQI5Yc44jjyMMuY0DGm3S7tdnvZ5GAUQmtrqTZ/ZeT6/e3c+n+fT+byUok2RiCSJCBBad49dap+z5N2434/TPo+TpAFZEC8NzOt9wl5mxzU7jegyuoEEyGFmy1xO5GEqnPq39xdHuNao1EqWJJwlDeMICPN8ns9nVW2VMByFJXGSnOSFN38va1+d+nrjAZEHmNhRWtrZcluWm/MyPdTxUNPoR8EiQexp8nSDaRfz4NWDkTPuXvnr/6C7NzHuoHfypgzDAMMYaQBg6BKmm+T55qYvu2V9L9v0l8tX45Lcr359i0evewpxvx9pqfB4bApNm5KLNVX
3QvOsx3OtCiTDuGNmevXZm9ef//LVmy9u7z6TrjjkF1HnLgVKzOJAy1LNrLVWylLOp9PpWGt1ByJxj/P5fDw9Pbx9+/j4UymzqxHKbveU8oREKyEdcUMBQRim5PsM+5EoC0tCVvyoIN3VnrF1+nIh2IQgsbeyRk+sm2lj1kDevnFrQekt1JJlHGUYWIQ06FigVDvN9tNjPDyhHsHOAIoQpEqtUDMsNVSjFCyVbvZgGGaeUqWAMAlDDhmYDpPBy6KouTtotOJBbuZq1tQcQICxs0ApExFSIFm4IntIRLKu2B7ACLRObA8HQCIRliHLlNM0YOKkRqXq02xzA92gWgSgcOgFJALkIApBAUxo08TTSMBgAeogjszkTOqgGqEeah+Ku3l4a/X+8d1S6/3xfvph566lltPp6Xw+tlZ1k6q+pHzh0VRLq8Vac/POZkbwiGZemy5FByIPdEB1ezodv//h+6//8PVXX3/18HC/uvafAYlfF9Avf/XF3QgP94/Hx/NcWmsaHoCYRMYx73fjMA1jxLCfppv9UppWraVpbbtdevVq/OyOb25REFDNYqAgouy54yHhQIYczE7cwMJDnIghhA1FkZTIiJtbMdeqXnTQuoeKkqJTO3GduBNovWeGMIiR00d8ez+vzLW22cPVGvUm7u7pwj1MQ93Nw8L9otcZALF2PxOYAzgir0EdrCNZ+3PsOXlc8uj3XPhHXPtHMPg/4+C1QSu44WnYXXtf6G5gFgDIjJc4vksz+Xpqrg69wyara0eMrRDfXUinA0avJPU6H3V4CuL94HVDBT747Mu/r5yA54sLAeGaiUSn5DkEYngEBa0hoGOER7iptlpP53I8lvPijkLMKfU+oZUJbKa16HlutUFKkockzBDOZMItcSOs4BU58kj7W7x77bsRp8w505BpGHDInZYvOXPKa4Oru/di93tPZl6W+8fHpSxLqe4OhCRCklAYkXHFHxhJmFPKkww7yZOkkSUFXDDW2Mx7+ZN6114QRPS2SWZihj5+cR2MxSjMTMQYGGCqbl6rzUtNiach55xzkiTCRG5KiG7e3IWYM9KfBeT7Xb/IG2IgIiYYBNLkh7taX9vp9XJ8+/jwPc9vqTywFyE/4PgZTK/w8Q4elYSz8H5Pr97QzS99uokOSREBUVDPIW2dTRp97mnvNIDrruhlnS6Hf3XlzyvrF3bL+5uNEG9fHfB4wrfetJzmGZRzFPSkwU9nPZ5qbUEy7qe82+3vXr0+3L2W4WDBGIjhrVmrrZmZea+rMiNAtFZqrdpqq6XW+Xx8fHx47PWrstQ//vDH5Xx6fPhxWc5qTZu3annYcRo6+LUW3XsbCJMwTQPvx3Szy4fdsB+TEKh95Ajoa2RbNevz63WTNWvv1Cvz8MDA8FA1aI0a8po+9OkqYRGELCkNw7SbduNATLGUZV7mp0WfFp3PXhpGRWx9cjQ2RHdqM5YCtfrpHMezfv4Z3hyouufc3CEsrAkkyRQ3Exm+QOSDAsIdFACJgtDcLAyRE0saOAkhQ0CEujeIhhGEwdIjK8WOsK0K2yw8jMO4m8bDLk8jZnKE1pxOc41Q7wW2wA3jRIg+4gYCvLMHKZADE7C4E0YASeTAEHQkq26mrTSvrY+NeW+HWNhS52b1vDwSd0naro6vW5h+XcMQ4B611bnW2VrtUgYQHt6sqVuprVSjFB7sgU39/uHh62+++fqbr7/7/tvj6fjhevh3GeJf/qf/6J8fvv3m2x/wj+3dY6nmZs3QV8/ixDwNebq7e0Nsaq3U5XQup3POuNvLbufDqOgWKBhMNKYRwr01bWYaGEgk2ZC0qbobrWe+AdYgC4rg6nFucJ5tOS2JbV9hGEYWUdUuh8bhZH0EJhAEZ06J8f3ut/UWeyg4eDO10kG3DdVzAHf0AN/8Fa6TWxACKMDCPRzdmikRESBGn9C3xsx43V19g4U/O23wQ9f+JzPfP2GtYDl1nuIFWOtLBNesHdGFWICZMCL0yrIGuL4
3W0s93ll0nW3CgikToKtfouNNuRpxFaXe0ssLdHzNnd673etv/OjlYReeZ+5ZxwpYuSOtJTJcD20H9x56WWttXurp1E6LA/Iw5P0u76aE65TB3q2irWmp7kHDIDkLAHUVUPc+JcQClBMOBzq8ojdf8O0+H6YpCyWhLpeLK4UN3FzDyfrMOKSuEPTMSl2O56OZB4DkgSV1ldkABGLp8yM49ZmqLJkkB3JXo+nrDjds4xJp9TO4k3u7LF0PszhRzmkcJGfKiXJKIolI3ICRaoFaa1WlJYbMfhglDTmlcRzHYah1abXUurRahWkY83vo9QfjYS5fjssDXsu5gQkkBSfjodEI6RbOD6Qzoy2+VDu3GEI9gC04PEI1IoIFiGGDNjf595X3vnZNXW/uszV1AeR7DnbF4a95/LUk9XKNpZT3O/vis4OqTbucmNhBKdcYaqBjkowTQ07jNO1IcrXweZlr7TNWrDWtramp+QURD7dSl1qLaVVtqkuZ52VZ3LTWxSO0mdbSymJugDAvCzzes5yR+OLakXgdmsgszEOW45CezsNuGndjTkztY679xfbpf1+OgQvX61Jm7sezhTe3atipWuvOI2RKQx6ncRwmoak1bUuZl3Ze6knt3MJamAY2gApo/aRDQzAHt8BzmEdTvz/Gm9uQBC0AIBZTAiAQrUD4sjCMwJI9nAjBkcl6X1qHuTgNJGllAFpzMAsLsN4yg8CBGOjer5SAWTgPw81+d3OY9vs0DU7g4KzmRMtStVlYbWHPlEUCwHDr//N1eAaYR/OVXdiVNYHZAVWjQai7m72/tDbzMFdr+tEvAgBeIzHAzhQstbRexUUEoiAotRxPp+PxfD6Vm9sbFN7bRBTfffvHv//tl9/84Q/H07G1TQr058va33zxBbWdRVhEBXQirTpkSePAKfXoR/Iw7HZ5mrTW+fgkYJk8D7w7pJSUuHgzC+vwMA9MiNGqm0EHKocxkH0pXpuFe4SFq7sa9MqQAVpAM19US3N1mMyGcbBwyZIxgB2IJDO4E/iQeEj84YxdAOiJey/RhOsa8Mb6HDqpv+dmCNjbYWC9nX4BAS3Q3daU8xJEv0hA+1TQHi6s9/LjdLJ/pVnFNq8NaCvT43KwrcT51dOLADIwhz0Xv1gZc8rcidPSCbVu5q7EkTICEXpEj3FWAmLP42KdOvfiQv5cCeHD4/b6JQ/o4iAYBFvmhuABaxt+H3QR0HUqzLXWOs/lXKw04CzTlMdBspBBF01281hKK0XNgYhTJhGw9YTWLlKrpo6RcuSdjfsYd7Tb0W6kxBefDpuUIGA4uAYEE0HHxl9ej3uYr500kjJLQuyy3UTEkjJ32JwogHr00f8H37wSEmwlkUtKFgiI0ZuqO0Pq8qu79twwpHHMKSVE6QePmUUtZtY0qmoz9wBiYeauIkCIqqq6lNJElvZnNOSx8y/jUsfeIjRcUyAkjmmPKcfuAHe/8OWkdQYtbTmW0715ImocrTjpeYn7dyGHGPaRMl6qvAEb1/eC+j/LuV/CWtePAwACPa58uo6orf7e30NVAmg/jv/5V28sHCQHkGmcLWnLkWqaaMcWgEyJKDXzcj6uUYMbWJ+g3gUfvCOr0XGhVlSrWx8QULVVM9W61Lpoa2oKEQTEIpxSMBVV7IAarL5snbArwszGrMal1bm241xzziIE9mJ6z3uG648CAET0rRZP6/RFv0iZ+Vp3b0YU28xAFBbKeZx2u/0+p+zKSynn82mpc7GqhAYUHqROCqiITrCqtkagG8QSECXKYvf39m4fhx2PE6XkxKrNVWNxaNVQnl8FcppWQp86kRI7MSOK5JFyBiLryGO4e/MwALsQjAl7twsQAzNJSjKO6XDIh1vZ7TglAAs3Jsvm025qS7Nibn0mMgBFYCAYIQpJbHmKe9ROagkgQRbOiIGkAQODURjGv2TUyj9rHeOtrZZatC8GxJ5KnI7nr//w1Y9v352Oy2evP/fAu1e3LPhPX33761//5tt
vv2+trdHvz2gIh5u7gaZmzSIaEg+5lTYQ7Q47GacgcWSSJMM4TFO4Na3mldmHMY+7iUjdQcOad1JgZEmSuCFWN8xjGqZhOhBJnJeYl2LaWqtaVbWPWiDcKpC9kdd0KQHYB9dwGhJykMQwrHUmDE8MjPAx174dFui4DVRYVRsA4JKGXk+T1ecDAFxT+ZW5tg0N7qvvWb6Bvh2KzwFCur7gfZT+X2HR0Ash9hmjhExbQ1oQXdrhYnMDnVuzkrP6ezIzVSJqzCzSa72u1tQakqeBMAUrtOagG3YBPSjo3QAvWy/+lOeO658Am0bG9ZvCO5ENvbfEAFD0tkRYxfAd+pgagAgPNWtLLee5LMXUccyym2RIKOym7o69V2Ge61IUgkUkJUS03m+ylNZzkqLmCJQj7TSNsfLYwt16frAqygOsw28AgsAjBHup4L3EkJg5Ea+NbT1fJyTqXIyUiDiiz18wCERkImN2AHBfM67e5NL9OiJGADMQAbghIhCTBxi6hambBQkP47DbjTmnCG7NVG2p6Bgaxg7NvDZr6h7eB7wmSAG+lMUcllI9fFnK8wt54doZQLZI8fljxLh41wBETwmJIw0x7kIraNPlFMPNY858YqlHMzu3sNNT7E/RKrAA8coowW27RTAAQBf++cC1b71aBA5x+W8biNIrYX3U88fgoUBCiiwAgJjZadBIpnko6cb9CzzoOgmZEcgiNLpIgIFbp15hhJqqW/RmcXdV5caq4q7m5pa11ZaknEGtoRujEJFIzuM4TDvJiUX6t/cIpPt2ZiEWunwgknJKKachC5Mi/akTYk0hulDsmrZDHzt8xS/WjKIPigHQAAH06OETBTIwGUWJaq3SsizzvJTmZoDBhEBYHBbA6tBTUmTsUF4AoBtQMWwFltmX2fZj7A40jp4GHzVmRQU6G+T33jkIhEa4G5haRKQkQYmFu2oebLSc9eltxy9hEDhhpzmhJGIhZggPbU5F3UHdzdVca2nusd1WRwYmorTdaSJCNIvOkNdwpESC0tsHHZDQOxfBnQGEyIk/Utv9l1ps5+cWmyMxM3c+P6Fq++ndj3//2998/fUfnh7Pf/EXfwnIdw+3Hvrbv//yyy//4ccff1S1WMeo/5xGRCnJ4Wb/xt7wMLz+/E2dC0Xsk0w55ZzyNHEegVjN57I8HR+8LgMD0iCJ1Wyp0Rqap0CBhJhTJDIIbcAogQkxMYmQC3s1QAwRYnbItlazIyTJOA4QYdXAQ6S7NxAg6EcCd4VsR3AhEOqnyHPDay8W9hhuc+q93vfS475wXwgbgfnZj+ulmxen3yr5CJciYB/EDptPwJdp35bvX5HZ/qNw+7n4wZzz6DIK0IeuESIRI+IagYL3ahxErKhBH89K4NFFKwIAOgeImbp1127eKAUJiCBxH6EbYPAsWtzeC17IKds7fn5N+N6/V1Dk+coM93l+AlRi5HVgnEAIwKoMB+4bQRwB0Myb+jLbedamgUTDwNMkSa4kNHe35svSavUuUMiE7lZKm+c2n9s867yYBUimNHganBMICa29FBHhF4r2WqZxCIreIAjRVXVeLCoiEcnIjNSH5HC/3dxr44Du0eUDImiLGaB3vcRGVu7sB1h7DDcGAq09CeDWZQ5rwJk8c9zsBABEOCVpGr1dtjYtTcMjZQqgdT5MB/WRGYLNiQWQlrLMyzLPy/MLeeHaE8JIUA1a77bZaMlrHn9hRvQnLAlZEEaIiN3BptvztItxR8cf4fRQDHU5R5nRDDx6o35XR+4sEQbIGIhQPWoPqV+k7N2vR8YABF3ldtd9zAiC4bjxOp/Fz90cSUO9LkyQh4GTsOzNxiY5BpluKaKfGgSBPVtUdwsDcIoQIgSo2pr1fiVX89q01qVuWXuE1bLMp9P5KbFQqzUAiCXnYdrv97eHnLMkce9jrKOrNXdApg9vRcRV7ybnIedxHJjpjw//eOl9ex5MX1vgNuZBRw67aw+gvhVgJSYGAqIj+pbKB0IgOqK
GzTrXJUA1SrOqXoMASUgRHeKMcYYoARrAPXxgWCEXDiTzaEa1WVn8adZhWfseRqdRPRA1aP+etkifNmmmrWlrED4MCTijEIC7O7phOAJErJDdhj90106MnAV7TR7d2lLAWZuxiIU3U9XaWrXSwgGJKAsgSuI8JiaKAAgKJ3NVi6YW7kScEhn2aTmgsPK2XL3jisL873GrG9EoepuviIhIkiQpBcB5Pn/3/bf/4+/+n9/+9h8e7p/+y0/vxmG/2++X5fTrX//d73//+/v7d+52jWp/PjNrIZ6zvH5999mb16q+nGZvTSK6cg527VPEpdan4/H+4YG90j4F7pBBm89FvSFAliyShDMBR5hZc3B0DW8WSqROGlCdAtMwcGJGcLdSa0CoGSEOOXtTb95jH8Be1ITA/tz7rsFEyIzPO3ev6fV2v2GD2REottBqRQg3l7u9Ep/79SuufuVAbc/uxecBXviBLbfH508IX/7RvwdfVFReBoy9REcEJNg1CkUIidzcDRx76hMRoNV7FwICkCABxqrPGWYW61gUYCZAMGvm2nUaJRPj2gcPgOEr5t+3WDwDA5/ZR0OiNb3oAc3z63C34+knIkuD5JQRByIE5K2aagDaqYAIgsGmUYvPs53PZg4iPI48TSxCsJVA1FyrzYs2jTymnBICaLNlbudTPZ30fNZSDIiGSYZJJbvIKibzHKrZ3nZsP5ncYUsh8Dk+AwDELJIDO1eOiISkZ/AI0OV7PbxjLEBMzIxdsz9iDQ4RO6uRL6FW17hJ0gvRDg4ezdU10CqH3h5ysx0gEKGZllLnZTnPy1IbIozIQGKOHtDRX+gVBeZe8q/VTufjef7TWfvnA/3ve3lo/tSgd8UMDIlXmVcLaBGLQ/NNSRkBgAAB0uDIjSlyJh5AyZbFzcEU3NbBAmvWHhiQCHYMd4kGgqPFbFtjp0esRTOA8Ex4EGaAGqvuLgIKwMAwCVa1c1EPoMA7e8nbAlPVeWlMCEOk5ACNMIukSaaMAyERQtf3i4jOpQlXtEJWiQiQlXYKaOrqbhbNtNZS24orRri2upzn+XQ8Pz0ty1xqTTnvD4f94XC42WUOgmYhLZI7uveCT2zpNiBhSimlnFPKKeWUCeMn5ucF3I3hsi3PjWRwbZXsIH3vlVi7Sq6v74FiP3vQI5pp1IYKQQHmoI7u2FlFAe5gHmeDs+HiaGtM+/yXrfoa3mcLRgMs1WvCWikjpuaI6EHB8IvnV6HWhwaFKWEMg3AaA0Ud1CysQVh/6xgIwR2fQQRKXaQEHbC5ewtyIDL0xS0ANCKZh5mrqTVT60PDobPNgaiH+hHhhtrQHGrrBSwDIGFlBepcXIC1bcodvM8vtH8zyrouw+0hmnlrbZ6XeVmaanss33z7u1//5tf/8A9ffv3NV/NcUso3hxtJ6XR6+v0/ffn4+HBVoPtZ/ToEaFuMHMFzljRMAFiHZKVGawCInJAzSnIAtTpMu3Ga0DDlDjYFC+c88ihDymlIksSiFS3JIgczJSbRqk3rfJ7npVQzFMnTkHNmJg+jjereiE3MRF20z4RlFuZUSy1Q3AxcezetMCeS9+F4vFT4nl1ewDa2b6Pcrh9c/euV5nYpu2+f3Jz6elhdXvJx/GaDIbdI4eohsTd2YldruQYLHxozinQSTg+rkGXL9ABWSTcgd7cLAQ0QCZi71I731LRj0KrqjtDpIBGgAAREwQmIURKGg21V4T5oEAAg0B0cXxYcPjBcS4EA8eGE3SBciIIokAgwrSyH3sYGCqAAfZhQ1wVuy1LPcyvFkGQceRxkyMxMgA4Q5lZLq8WqWgDlLClzhNfFjo/1+FjPZ6vFHVASTnscRhBBYVqz9o/d7+53Izqx2x0MvaeQz6+RkGUlT3eEtY9kjPAwMwsHWj8vtJY/4jlGcymxr/X5biIk0o+CMHPTHmEZgmrUostcl7lgxLLYvNR5rvPS5qUxU1NoGs2iqatZCmJGgt4LwkRijqVYay+KuS9c+68
mlrv0/Wx/XKCYA8BNoknQHHo30NnwXfVjRFshie2WIUbKLqmNOwyCuYG+g1rW1jVfmwd7dIuIA8Kt4K8muk340PxJATYyvDtUiz4IZsf4ZpSMWNzdAyIEIRHcZL5JfK760xncUYheK8PD82eoajaXIMJkQG5kZ4REiIlTlpGYBaErUnaSCoajK5aGVR0wGGE6QJ42GATMvNVSW2uqvR/JzLTUZZnn0+l4fHo6HsdxfP36zeFmv99lsRPUR8Wx0cEiR2xSkCsbFZEo57Sp1DIzhzvT8xjlQl5dGT1X3BoBOyEIt7bbgIvOZN9+a02XtoQ/wJtqaabhCoRITMiIRB6mZqEazaM6NuAW7LC9lxV/xPUt9cnpfc4UGaIBViOu0HXNICLlF9mUq7o1t4rgOVEe87CbLOi8NFgstHoYBIADAkYHGPowiZw4Z7fQZqo1XJlD2Jg9XCUbeY6LmhGgO6hFM6gaDkAJ+izy9Ra1UItafFlMW8PAhJQ4EA3CESNPWVgQLKKFq5u+wCX/jYYRodrO8/nh4f7+/v718c3p9Ph//+3f/M3f/s0/ffX7d+/eRcAff/wOfhPucTofn57ua53jCin/zKZlaWBBwJxEiEjITcEVDIFkmHiYOE+BqNpMa1uOXp4ya8pCjMOQhfI07vb7/TAmEjqenx5OT1MgciYUdKjzMp9Pjw+PSynIMkyTEA4snCiC10QGsCK31gzRkACQOQ3DNOZxmecTQJnn2hTCkIAFBd/3r4jvMTZXp3zNwteEvZevLl0TH3rp9xzalq+/CBs++EUv8vlrjWyD5Du1KBAhkOOZIb64DuppehcrY+q4Lz57F2thHdGcwtaGTCQiAA5C6JpUvfcp3Ltr7ycvqoZDIAUQEmFKZApm3tu0OperHyEalyEeL8sFlygIr3/GJXC6XAXFOCgRcGLmoBX+72PsW+cpAhAAR5B7lNrOc5mXWlSnXZrGNA2ShYnQIQLc1BavZXG1YMY8cE4E5rXY6bEdn7QUUw0mzCOOe88jMhOTMArBNjYPHD4yVKiLCNl28r+3qgiwPwbpfHUkCgDzZuZmhoDMnHIiSYC84fDR0Xe/AMmbMbNIkl6R6tw6N9dGAMjUEarWbJnL6SSmthTvfn1Z2lKUmMdqpXmtVpvW1lImSdKF94mFOAP2nP7Fhbxw7QPBnUAMOCBVh0C/ER6Y1KF51Iin5tWjdNd7jWjXNR5IQENMN3T7BlCwnGPcA/E1XL68ELCD6iNjIrpLBBtDxQPUopov2oTgJsHAGEAQQACMkAgnxh3jgMjAEJiIbo8vBpJM44jDwMjEfLg7JEbyKh7A0Aj6mFtCEAhx7x4TCZHD2ZwaBAUKM1HOCOusOXe0cTIfdRuk6e6muszL6XTcTdN+t5t2u88//2K/G4fM0MYoyWh0uYnIELJpQa7UlQ7IM8vWYopu+nLX48W2PXaNC9a63Mryg0t6T4i9W7yjQdjnSKw1fwj1qB4VPDqeh45g3nkGAQboFN6HE8EzAHNN3DdsYH36iBjcW9OI3EEdzMF12T+XgQhCEzLkSMLjNAy7/bC/seDxXM/H0xmshrm3iL49AiBIQDKP+5ynSVtbZl2OVueqCkaUBmQJjxaAiMIiEWThurSl2PnUlkWbR60hnBE5DQwUDqrmqmqthdZQA1UiEI4IRKKUMGdpiK7ex3H8W+H4eO6VI0K1Pj6++93vfwsYf/zx2+Pp8dd/97f/+Lt/eHx6aK0g0un00LlDtZbW+jS7P5tA/Tus1CVjjPv9OOxT3gGAUfWwpgWRGXcp5+lwQ5LNlDAo1MqDxDJOaZpGAILgYRjHYZREQV6bDGmAfZ52xJLD4fj0YAjSKlMQMSUy19qqBCEhAw4iMAxC5Clpat6MOOU8TtN+Gnfz8ZiRzkQLgbYaoYTvJ4mbH32GlV+QxOfI+/X1V4B+ewF+7Ca/+PDDBYDPvPm1Tt03KV0rZ1s8vfnAXnHurv2lryFCZuyuvSOsfbtTn31LV7a
GMwX2LL2HMF1Un7ZkemODX/1zL4RBawEYknsuGkhr3EjsRIBCSr1V61Jw3S594/C9d+M/rBEhwm4AIEShTdbCIRSgYVTojDUU6Hp/zZe5nk61tq4sm3ZTzkmYus5NuHlratVaBUDqYzpyllDIOQ2jlGrmHgRMkCbjFCwovIkWdqhyvQq8PPW1gnMl9fRWxufdjEDELHml0TETcUAfJosRSJw6pQ76g7kCOmuC1e8j9okudLWtLyOiz9RzY+pTyNAB1KI2b82F3CxMXTsNUp0cSmnzUkfBaaRSU8rEHL3gz0TjOOx2+1JqiXLSPzHUlQASxqsEe6EWAUATUyJRh+pRIojsvpnU2HAsvNapttWGecTb1ygZlnPsb72Pso5nu229RDB3DLgTGdYYD1aGvsei9q64g2eBUWBgZEBEZEQhEABxRwFCJoBENKUXW3i32w853dy+4iTDOLCVOD0mJQSY0Qpqh+zIVaxROACiCBC00K4SER4EwIhCyBHoCgTBEiQRHEB9i5nZeZyzUBKexmG/33/x+efjmInCq+iQgEeUPUFC595Wbm5mFmZgjtS32iXFeCGM1NfHSsTYFk7fxN1pP0/JO1VvJfhij/37koIND4GwgObYAkuAgTuoh0ZEeEB0Hvqmq9Opd7TRitYFvBZGIgKgt/yjkDMDobmGYjQPfQkNIRApkWfBaUqH2910uMm7OwsZx5aYQRtardYCej8CEKFkyiPv72S6GUpxSt5qtdMSjQyZhAMxwAKMWJgYMLUw83mpejq206mU5rVGzllSSjl3lq51xNIaeEVH9CqIo1CABOVx4DwKIXkLoaov5YM2suu/0N0+f5mrtcend//4j795d//2t1/ezsvpD99+/fh4X8rS+6TLcq5l3kCa+OAn/JxWShmZDjKMuxtKg5o6YDOby0LEMikJT/tDHib3SEyJwsqO7JxT5MRETCDMIkyBpq5MMKSUhxHTLo2jA+CQlXAB9zN3cYxm6otLH2ydkjBjBmcGD1O1pjmP43Sz399M4+6UMrkLhlCUBbQFrTyvFxdycfcr7r6dMB/zxwBbOv9esRzgA9/1/vd+LMBbIf41Wb9IVlyr7Os/YnsVbgcc0EsmCtHaHkuEHf0FiIgOisUG1QEgCBJab98J83WqAlFnDm4s22cH4ebdQ5tHdEAeiII4uvg0MSTp8xDAWphuDUgb8PH8lm+MhbjwXF9cBcI0YCAGswN5VwZCRagAnd4lEBwQbt7U56Wez7VpIPEwpHFKKQmtOyzMXJuVBdQQqRONJWcGxrbTw+1gvZfdISRwUGBDSkSZae13e6Yec3nmGNA1+S8hC26u/dmFsEiCnm5jZzKt8sLQ2fPMjCwBaBaAzusxvLr23guwSgKtXyO8+PUtvetwDq0XDObeLNTAfFUkNPVeX/SAUtuylEWiFKo1t4qdi+juxDSO4/5wMPPZTu/O1wt5KU337rsJYoxwAAsHhITESBbQ1T1H83Gx/615seh9PHAtVfUnj2iKrSIU4BY+x/kRWGCrd/UXjYw7gZuENwlHorS59jWlj2ju+6YOMaQ+xhPpQjMEIAiMEI8UToCMKI8/XK5Czf/H33+dRAIQiSQJWotlVucCPzVMDaQ/d3GX0HUHdC0VXcwWCwpMMv6B88gQXSsIAYIYqOvodYcZ7tFaXUqptbbWTnk43X+fhAEirJlVREHOCNx1W2JtgvQwh/AOqsdW5TO31q5UiF/88pf/13//79QnkfS62+q/V7feSTfIABjhYWpu0YvUKwtEEAWAojcahUNYuEZogCM4eoBtR8emq4u9HoUY1HUPKXpjZ98M7uBtpZ6BIGWkTMjQ+1k64/GzmzfPdhUm+RVAZQYAaXWA01TrGEC1JWiUWWC8S7zAs9IDMUliginaQF4z3R12RaCGEQKlgfJAIkzIGAlCGHggP+xuOMou11pUNfIg+9tpN+VRknNwaIJyGKq1Cm5DinHAnFASAnCgpGGQlAiDSZNUvdGUp8u
FvHnz5q/+6q8+ONP+hYbMaRqnYRjzMLR2uL3b1bpEx0Av9On/f+zVq1fPveCT7tqC8wMOpSK7u9USdUm1HBDpSPRg5e35HacjBLS6lNlDE8TEFMKI6wxh6geoBdc6VE1AiUT4hIFwPo9nfdU4xVChH3kABmiEClSdEbepYH0Zkockx0FtPJdUg8qQ/OYV5SVpJdUuek7DzeWi9rv9f/0//tvltq0Z1EXQHa6fx6unxeef77ZJXD9LPdasd0vuX3j29xlx6x+I8L5nfd+2ZDqmYXru3f/yL//Lfn8ABCRkpu1M3eajbsT1/k/v/18qe8+u9BpWbOHnhdyPGMiRMjGDatg6jhxYoCvzm0Erpq2T8Ps3Xxzgy1/w7I786ld/8ez6RpL/MwCDEoJQSNeChFDABtA/EABmQEL/7FXLqVYDADocxv1+GLMII2Ig234siUxzeGD3/YMMjAwU09jiszKO7dWrZmHADuwonhLnnPZpxzB1xb4L5na9go132R+Vr+J8/JyHfdgzQo/VVpfsgRHkLhHBTB0xXR8oOqEiOeKazPTmWkRkdiIlrGt3DhEgurijGbVwY1LhxkTCbHo+n57Qh6ck7lBqK3WJKMIGiGFzWeTBWTUfz8OQRRIjQJcfVfWy1GUpp+WFvBX+9V//9Z9cjJ/sk32yT/bJPtkn+5/N/rxi+Sf7ZJ/sk32yT/bJ/iezT679k32yT/bJPtkn+1/KPrn2T/bJPtkn+2Sf7H8p+/8Al7epOAplbmRzdHJlYW0KZW5kb2JqCjM1IDAgb2JqCjYxNzg1CmVuZG9iagoyIDAgb2JqCjw8IC9Db3VudCAxIC9LaWRzIFsgMTAgMCBSIF0gL1R5cGUgL1BhZ2VzID4+CmVuZG9iagozNiAwIG9iago8PCAvQ3JlYXRpb25EYXRlIChEOjIwMjIwNTMxMTY1OTU4KzAyJzAwJykKL0NyZWF0b3IgKE1hdHBsb3RsaWIgdjMuMy4yLCBodHRwczovL21hdHBsb3RsaWIub3JnKQovUHJvZHVjZXIgKE1hdHBsb3RsaWIgcGRmIGJhY2tlbmQgdjMuMy4yKSA+PgplbmRvYmoKeHJlZgowIDM3CjAwMDAwMDAwMDAgNjU1MzUgZiAKMDAwMDAwMDAxNiAwMDAwMCBuIAowMDAwMDY5MDk1IDAwMDAwIG4gCjAwMDAwMDY4NTUgMDAwMDAgbiAKMDAwMDAwNjg4NyAwMDAwMCBuIAowMDAwMDA2OTg2IDAwMDAwIG4gCjAwMDAwMDcwMDcgMDAwMDAgbiAKMDAwMDAwNzAyOCAwMDAwMCBuIAowMDAwMDAwMDY1IDAwMDAwIG4gCjAwMDAwMDAzOTYgMDAwMDAgbiAKMDAwMDAwMDIwOCAwMDAwMCBuIAowMDAwMDAwNjc0IDAwMDAwIG4gCjAwMDAwMDcwNjAgMDAwMDAgbiAKMDAwMDAwNTU5MSAwMDAwMCBuIAowMDAwMDA1MzkxIDAwMDAwIG4gCjAwMDAwMDQ5OTUgMDAwMDAgbiAKMDAwMDAwNjY0NCAwMDAwMCBuIAowMDAwMDAwNjk0IDAwMDAwIG4gCjAwMDAwMDA4NTQgMDAwMDAgbiAKMDAwMDAwMTE1OSAwMDAwMCBuIAowMDAwMDAxMzA1IDAwMDAwIG4gCjAwMDAwMDE0MjYgMDAwMDAgbiAKMDAwMDAwMTcyNiAwMDAwMCBuIAowMDAwMDAyMTAzIDAwMDAwIG4gCjAwMDAwMDI0MjEgMDAwMDAgbiAKMDAwMDAwMjUzOCAwMDAwMCBuIAowMDAwMDAyODY2IDAwMDAwIG4gCjAwMDAwMDMxMDAgMDAwMDAgbiAKMDAwMDAwMzM4NyAwMDAwMCBuIAowMDAwMDAzNTM5IDAwMDAwIG4gCjAwMDAwMDM4NDggMDAwMDAgbiAKMDAwMDAwNDI1MyAwMDAwMCBuIAowMDAwMDA0MzQyIDAwMDAwIG4gCjAwMDAwMDQ
1MDEgMDAwMDAgbiAKMDAwMDAwNDcxMiAwMDAwMCBuIAowMDAwMDY5MDczIDAwMDAwIG4gCjAwMDAwNjkxNTUgMDAwMDAgbiAKdHJhaWxlcgo8PCAvSW5mbyAzNiAwIFIgL1Jvb3QgMSAwIFIgL1NpemUgMzcgPj4Kc3RhcnR4cmVmCjY5MzEyCiUlRU9GCg==\n", "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-05-31T16:59:58.867390\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.3.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Main class: aquarium_fish, Anomaly class: mountain\n", "Prediction: image 9\n" ] }, { "data": { "application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1BhZ2VzIDIgMCBSIC9UeXBlIC9DYXRhbG9nID4+CmVuZG9iago4IDAgb2JqCjw8IC9FeHRHU3RhdGUgNCAwIFIgL0ZvbnQgMyAwIFIgL1BhdHRlcm4gNSAwIFIKL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gL1NoYWRpbmcgNiAwIFIKL1hPYmplY3QgNyAwIFIgPj4KZW5kb2JqCjEwIDAgb2JqCjw8IC9Bbm5vdHMgWyBdIC9Db250ZW50cyA5IDAgUgovR3JvdXAgPDwgL0NTIC9EZXZpY2VSR0IgL1MgL1RyYW5zcGFyZW5jeSAvVHlwZSAvR3JvdXAgPj4KL01lZGlhQm94IFsgMCAwIDY3MC4zOTc3OTM5NzIzIDY5OC41MTY4NzUgXSAvUGFyZW50IDIgMCBSIC9SZXNvdXJjZXMgOCAwIFIKL1R5cGUgL1BhZ2UgPj4KZW5kb2JqCjkgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMSAwIFIgPj4Kc3RyZWFtCnicvZ1djyvHdUXf+Sv4mABBq+urPx4t2FES5EXJBfxsyNeKBV0HsoEE+fc5Tc6w9j5TVZtNQxcXRqCTIRfJ6a5aTa6ZCdefLt/8Jlx//Nt1vv5k//vfa7h+d/3mt5//588/fP6P7769/vC3y2zzL5dlnae0r+ue7T9/xv9c9m0qYdnWYvOZ//O/Lpe/XOz+7Tbf2V3/eLnEeSplT/b/LWWdyrbFtNvdh7RPYUlzCDj/GefLukxhD/c7rndDY+P96fLLtQHZlrTkawjbtMxpuf/z//nXz9ffX/9y/eY38XhFgr04wV6R+cMr8ovdcL8er8vxf7uwH75cv/nXcP3tf1+/v3x//eX9fmd7OcLxak/b273b5BKXKW+73QO/LjjO0/z+sly+PR7a5dtP12/+OVzDfP30p0tMU5r3bdlT2OK15Djte87b/bF8+uPlH+Z/vH766fq7T5cb+ZK2KeV9jwyEqeClMoV1jinuJZcWLzCvzFMMc9wyA3EsiHmd5jSXfSn7vLeIkYlLnOZ1XvLGRBwLYtmmzVjLHOOaWsTExDVN9vK7o/gLjQXRzqu1hHQ/hlrEzMStTOsePpxPOBbELU5LuJ+wDVxh3L5OSwmr3RnhcCxwe57yGucSwoFoEBcmhnm39eQ4RNyKgXPBDPMypRRzyOsa1xZ0ddAYprTGvCwOinMFtcUl7HFdtzzvoQXdHDSlyU6nLfmlsY4VMs3TXFLIcc6l+e3cGQlr15Ls5dxyPIh2nvlp52idjld/n9YYbBE4/rVWGqCsxV6/zV4SxtTxgLMGe1GCnY7Hv9YKA5zNbpH2YOc4cep4wNnSlLOd8Me/1FpXgLPv07zvdjAzp44HnN2e9BzfN5/GagKbXJinvezHOYocGPc5y7weW2hr8QBAjNMW57cjrgLqeAAI9mxjuj+P5qoBnJSnZZttQWVOHQ84dnRvW1rv35nWWgGcvEzFzjs7UYlTxwOOnTZrzvfjLLSWB+DYDpFnO0UZ85gOKHYyL3O+nzWltR6g49gmuQTb2hlTxwPOYgphm124/eNFIF7/7S5+NxVhyemIWke8Lv/ZUbgvXYWzm5yUQboF3NeQMd+e41257jr34+NFijePiVMocyxlOV6fxbbDt5s
fL9W//+H/Pv/1Gv7p+i+f//DH62O9eZPmu9nexflhuMHO65jW2/ZNZhXStOVgxy4/K5jTY4f7+ei4h08+vqCY9C2vqu7lrrrXZ1R3TLz5bnzWd+dg37H1eMnd8QVztd2ZduzbZruevTTNbd1vRPZSm2raieCgOFeabd6x5S3vaTOheMJ7Y8y2bG03SSYozhXUvGO1ryhrXubmxu63qLTYGrbdPJmgOJcXFLb5L7uxbcHcnrDfmI9jbHdH+BeeK2guU7ZLii2vZmxPCHBcZlvf9g+nG80V1G4YN+Ns6XZcSw22p2JC6YDvMwVb9mPBSSHOto4/YcB2bWRn/3HuOiDOFXQL02wavBbb3ZpXMn6L2xfbZuabLhMU5wq6mz2sYU72rd+a56fb8ZKpyh6CP2hxrC5K5zytJsBz3MMzAoyLNRrwbuebH3cWIXtlzVztcHtbE5UNI5J0GJjKhz1UqjFCyY0BquTYQ6UnI5REGaDKlD1USjNum2TNFSq12UEHBo00UmigKYf2NKnTCCWfBqgSag+Vbo1QkmuAKrv2UCnaCEXTBqZQbY+U1k3ahdoNTOXdHvqigvdMsmeGZrttyfzSlczD20/6Kt+ENHxAkRoe58VOk62EpZSQxx4en/Fw28pt1yjb7nYw2w1sQd8L798/45wePNxP28MfX7CsaQ9fw8PHxJuHpyc93C6l7FtmK5zfc3GuNl0zHzvG7aHMpf026Ye3ntfpeAc2f3jzGeYKavZTogmNSV5qyoXbApNtNWU5FkUHxbmCmv3Yg7RvYmo+T7cB5hBMZldbKdwhiHP1pve8TTGbswdbV5tvmLoN0PzOZHYt3k9prqB2nR1m849iO0hT/d0+mI9PA/L68WTDuYKmMNmFhl1gxdR+w9Ttg/n4QGC2ZdYxYayQOd0WnfV4YZq+6HbBfHwisGxLdJ9i0FxBS7G9edvTUo6TWbt4Pj4UsOuTzV0r01xBF7uhXbTOswle87rK7YP2VUbaUwl++YS5/MBmn24+YDKbm9dVTsdxzQYdT8cG6se9T6XMQcpSYnhbGpWOIxJ1HJlKxz1U6jhCUccRqnTcQ6WOIxR1HKFKxz1U6jjunqjjAJU67qADHUca6jjSlI57mtRxhKKOI1TpuIdKHUco6jhClY57qNRxhIKOI1PouEdKHSf7Ah1HptJxD31Rx3tC2RNEs962a37puuZh8Ce1lW9COj6g6HfF7RKmzMu6ruYC+1jH0zM6XrL52n58y9yn8bagh3tJQs8L5vTg4X7aOv74gtUuv9avoeNj4k3H85M6XspsVpuT/4Cd5ipaMP9JOZtd7HFp+oXbAssSzWrNJtxRSnMFtY02zmUOdqK137f1OciazWuPQ8xBca6gttHOS0nrHvb2+7ZuCyybHTFzyf6JwlgWIcaI9iDj1u4z3P5Xjs8oluIO7y88V8w9T+tm2+SS1/abt24bXI7PKGxJ/HCu4VxAF7setx3Evvthbr9567bBJQbb445V1uc2MFdQOx3tW29XV2suzUsAtw0ux2cU2WzafbhCcwW1zTYtqy02tug1r7DcNrgcn1HMa8zuJKW5ghrFZK3YdzWszSsstxEux2cUtuIFd81McwW9fcG6R1sccvMkdT6Oizb4eMl23vlx59i1i7qY8v3NsPUJH0ck+jgylY97qPRxhKKPI1T5uIdKH0co+jhClY97qPRx3D7RxwEqfdxBBz6ONPRxpCkf9zTp4whFH0eo8nEPlT6OUPRxhCof91Dp4wgFH0em8HGPlD5O+gU+jkzl4x76oo/3jLJniKa9bdn80pVNu8lZb+WbkI8PKNLHTaumsJU5bXGfy9jH8zM+Xksa0yc7GeLbp5C1pIE5hdh536a4Lbclv94LTh8u3oLc1uBTIfblOQ3vwm4GXk6H2PS64Pj4TsT+7sMhdlqK7bn7nPJj8+mF2ASEqeBxiN3gdUNsAuJYEDnEbhC7ITYRcSyIHGI3iN0Qm4g4FkQOsRvEbohNRBwLIoT
YDVw3xCYcjgWOQ+wGsR9i84qBc8F0IXYD2g+xGYpzBeUQuwHththuaaxjheQQu4HshtjJjhiTeV+uwrjfeh5v2mxhQWPoptjJDpWjseKktE4HlD1Ntn+hJHRDbHvC9g1PvlyFcZ+T5zKVtKAXdEPsbMdD2O0WK3PqeMCxr7fnfcfEcYid7RiYS3aZfJ0OKHGfol25wJVPt8bOJhD78bEZZ/IwHnBMxEJY8WKnW2Nnk4rVVjmXycN4wCl28WgKB9c33RrbLsSnJRWfycN4wFnytN0W4cclTbfGzts85b34TB7GA46d08veXA9QcuKUyuIzeRgPAMe7sfb1cOFyMsRmIWmaF3fYPbNztzipgnSLelcjwskIu8z5fuv5ca7cDTe+GGKzWdVKhZ4XzPHhw900HddF0S+K7t+RYbdsdzmfYbujC+Zqs+MMWwsv5NYMxbmSbM6wtfVCbs1QnCsoZ9hafCG3ZijO5eUEZdjafSG3ZijOFZQzbK2/kFu76yaYKyhn2FqC35JrBr7PFIwzbO2/kFszEOcKyhm29l/IrRmKcwXlDFv7b+2ted2EsbokpQxb6y8u1eS/NbqUAuyiS+3CCEUZBqawYY+UYox7CZlxZUo1dlBtyQglTQao8mQPlcqMUHRmYApp9kjpz4gkgQamMmgPlTKNULJpgCqd9lBp1ggltQaocmsPlZqNUPJsgCrR9tC+c5N1oXQDTVm3p70o4D2P7IghV9g98/TSftJV+SYo4X3G2QRbWvjZDJueGcQp9Mxgjo8f7qZp4S6J/goWPibeLHw9H2Hzdx/nasvlCPuZt50fsbV74xnmCsoRtrZwiK0ZinMFxQhbOzjE1nwA4ly94c0RtnZwiK0ZinMF5QhbOzjE1gzFuYJyhK0dvNbWzISxQnKErU0cYmtm4lxBOcLWJg6xNUNxrqAcYWsTh9jaLZ4wlx/WUIStZRxXbJRxSC6ljLvkUss4QkHGkSlk3COljOOWgjIOTCnjDqplHKEo4whVMu6hUsYRCjKOTCHjHillHJEo48hUMu6hUsYRijKOUCXjHiplHKEo4whVMu6hUsYRijKOUCXjHtqXcZIvkHGkKRn3tBdlvKeTHT/kBrsnoF7fTyor3wRlvM84G2BLGT8bYfMn8bVMoWcGc3z8cDdNGXdB9FeQ8THxJuPb+QSb+wicq1yBE2wt45BaMxTnCsoJ9hMhSE2tGYpzBeUEW+t4ba2ZCWPZgmCCrV0cUmtG4lwxOcHWLg6pNZ9pOBdQl2BrF4fU2oU2MFdQTrC1jUNqzVCcKygn2NrGIbVmKM4VlBNsbeOQWjMU5wrKCba2cVyy0cYhuJQ27oJLbeMIBRtHprBxj5Q2jnsK2jgwpY07qLZxhKKNI1TZuIdKG0co2DgyhY17pLRxRKKNI1PZuIdKG0co2jhClY17qLRxhKKNI1TZuIdKG0co2jhClY17aN/Gyb7AxpGmbNzTXrTxnk92BJEL7J6BulucdVa+Cdp4n3E2v5Y2fjLBPnILM/3754+1oqljCrCT3TBtbz9IAPEOjhsJ9vu93Trh8isn2A52M/D9dIKNr0udhuOb9/ay6AA7bKt965J53mPr6QXYiHsMFY3z6watm18jrk4Vj+PrBq8bXyOvThWP0+sGr5teI69OFY/D6wavG14jr04VD7LrBqybXSOsThWMo+sGrx9d0xpRx4rokusGsp9cE7KOJZKD6wayG1zzQvg+lUDOrRvAbm4d5+PTrtWnqTDuF51xtoPaXl8Qg25uHcPxAdfq01QYDzjBDuY1jEvrmI6PszafpsJ4AIh2AKeA23+3tI75+ARr87/BGcYDju1heQ/4Y3Td0jqW40OrzXfwMB5wih2rJeIFTre1tgu3aZ5318HX6YBiOhFDxGuabmltbjIdEYLr4GE84Kx2jbhGvIzpltZmJ9OaZt/Bw3jA2RZboxNeuXRL6zRHM6zZdfB1OqDYubzuCS9Wurm1ndFTLrPv4GE8+CmFMNuinPD
65GRuTfLRtizOrTsW529xSvs4z673NCSczK1TXO43L4/T5W6z6cXcmjyq9ij4vDBswUePIUzTZ138/KLU/h25dcts7bA73Vvz4fUYy72Oa2ttt1BVE7KOpU9za60VF5pqQtaxRHJprS0XimpC1rG+aqDOWosu9NSErGOJ5Mpauy7U1Hxp9BhLJDfW2njfemrC3UcSxYW1ll0oqQlXxxLJfbWWXeioCVnHEsl1tZbdmlHTMvmYyutNaqu169K6jLJbS0ppu66k1OKLUDJfgCr19dC+BSONNBhoyoM9TSoxQsmJAaqk2EOlHyOUBBmgypA9VMoyQtGWgSl02SOlOSOS1BmYyp09VGo0QsmjAapE2kOlU6PwoFTDTyEIq/ZIKdiIJMMGplJs/6MPL9p2Rxp7FshtdUczPxj6KTF1MTYa94Bxtq2Wyn22rcZnBtUJPjPMV/DhY+7SVG5XOn8F5R4T78r97F9WhIiavv11LHdcTqufeUP5kVDzW8qPsURyWK2VGwJqQtaxRGJWrYUb8mk69upYvpHNUbUWboinCVnHEslJtRZuSKcJWccSyUG1Fu5aThPxMZVAzqm1dkM2TcQ6lkiOqbV2QzRNyDqWSE6ptXZDMs1r5WOsP3ihkFqbNy3PYN6QTUrzdtmkNm+EonkjVJm3h/bNG2lo3khT5u1p0rwRiuaNUGXeHirNG6Fo3ghV5u2h0rwRCuaNTGHeHinNG5Fo3shU5u2h0rwRiuaNUGXeHirNG70HzBt/5ECYt0dK80YkmjcylXn7n3N40bw77tiTQQ6pO7b5wdVP+akrr9G8B4yzIbU077MhNX2qXgsTfGaYquDDx7Slad4ua/4K5j0m3s372b/xCMU0pQ51LNsD7qi1eUMvTcg6lkiuqJ/oOWotTcg6lkhuqLV711iaiI+pTjqwoNbiDaU0AetYErmf1uINnTSdYHWskK6e1uINlTSXMo+xRHI7rdUbGmlC1rFEcjmt1RsKaULWsURyN63VG/poQtaxRHI1rdWb1mdQb2gkpXq7RlKrN0JRvRGq1NtD++qNNFRvpCn19jSp3ghF9UaoUm8PleqNUFRvhCr19lCp3ggF9UamUG+PlOqNSFRvZCr19lCp3ghF9UaoUm8PleqN4gPqjT9fINTbI6V6IxLVG5lKvf0PNbyo3h157NkgV9Md3fS3OCeoLrNG9R4wzlbTUr3PVtPp+OT1w1+Ar2OqpoNtW3l9y/7rndC4VU0/7u1UMH28JJdTwfSDc5fsZ/+AI8TS8HJgQj2l91dDx9LmvusWosikAQTt9JjDmTRx+oE0gDCbHpM4kCZSP40GEgbTYxKn0UTqR9FAwlR6TOIomkj9HBpIGEmPSZBDE6YfQgMG8+gxhkNoIg0SaDzhsYwes1wCTbBB/IwwbKIFjONngvWzZ1rJagwtUJw9E6ofPNsF1Bw/BqLv01HvarpTcFfv1s62py5L9L8l+X04iDZXs5uCu3i3dy52iRSjT0PrtA/Jx5+gKbhrd2PnpdiGlfzvR67TPmQ5/uTMgj+o1i2dV/v6nHyBXqd9yHr8iZkFr0i6mfNm+6Dtwy4/r9M+ZDv+pMw6zpvDPE/7kVxxfA7j/v3vx1+PWfGao1s3h3D8rZjs23MY9zHHejHHFS8zunWzWbpd8mffnsN4wLG72bYNryy6fbOtfXadX1x7XqcDSjr+NsyGFxNn62a0hbYNubq5bVv+Fqf0jL4e7mlIOFk323n6dvPHWXK3zvxq3Yz6AxEJPC8Y06OHe2l7p2+NX5PPJ39arwu7G+izf7MQm2Y6qCB1FhsZN80jCcWaGWEYOQvh5Zp5ZKLYMSMM82YB4455JKNYMCMMw2al8lQwj3wU22WEYdIsYNwuj5QUq2W6RoGYWcC4Wh6J6XuvjKD3hFlAuFceOSmWygjCgFnAuFQeOSk2ygjDdFnAuFEeOSnUybjMQbMsLvOoTh4pKdWW4KQYW46l9ENVKvwUf2MqCCr8RtmxofrmUck
q8MhWK1Dpqv8NtspcgUjqWonKXR1RaiwQyWMrUYmsI0qnBSJJbSUqq3XEvuDidk2GW1lScR1M2i4ySXeBqXzXMbX6IpTcF6BKfj1UejBCUYSBKUzYI1+U4o7b9WTNRchtG/wg0qf8kW9AYjxgnI2QpRmfjpDhmWHkAc8MxvTw4V7aZuyT4F/VjLuwuxk/+7cEMT3GbzoWyWI75fR4/PZsjY7pDVpokQWMo+ORGWNujDCskAUMc+ORF2NojMcX9sfiDWEOjUdejIkxwrA8FjBOjEdejHExwrA5FjCOi0deDFkxsiA2FijOikd2jEExsrAzFjAOikd2jCkxwrAwFjBOiUd2jBExrXXQFqsPJygiHgkyRZFVkKmJHAvyh/hTCDL+FtMqyPhbXseC7NNEJcjAQ0EGoBJk/1tllSADEQUZiEqQHVEKMhBRkIGoBNkRpSADEQUZiEqQHbEvyLhroyADSwqyg0lBRiYKMjKVIDumFmSEoiAjVAmyh0pBRigIMjKFIHvki4LcUbyes7lWuC2FH5T6lEbyDUiQB4yzrbAU5NOtMH7MDCkGPDMY08OHe2kLsi93f1VB7sLugvzsnx/EQhg/68dwWHwEz4XwSJCxDUYYJsMCxm3wMGKAKhhhGAsLGFfBI0WGHhhZUAmrjgF74JEfYwmMKAyEBYtL4JEfYwOMJw6mwWOYa4BHfoz1L8UgEAULGNe/I0PG7hdhmAMLGHe/I0PG4hdhGAILGBe/I0PG1hdhmAALGLe+I0OmdrEaMqWLY0P+0GgKQ8bfLFoNGX/z6tiQfUGoDBl4aMgAVIbsf9OrMmQgoiEDURmyI0pDBiIaMhCVITuiNGQgoiEDURmyI/YNGbdtNGRgSUN2MGnIyERDRqYyZMfUhoxQNGSEKkP2UGnICAVDRqYwZI980ZA7jteTNpf0tq3Q3+KcR/INyJAHjLNJrzTkx5n3/eX/ATcNo6IKZW5kc3RyZWFtCmVuZG9iagoxMSAwIG9iago2MDczCmVuZG9iagozMiAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDc3ID4+CnN0cmVhbQp4nDM3NVIwULC0ABJmpiYK5kaWCimGXEA+iJXLZWhpDmblgFkmxgZAlqmpKRILIgvTC2HB5GC0sYk51AQECyQHtjYHZlsOVxoAnuAbmgplbmRzdHJlYW0KZW5kb2JqCjMzIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNTkgPj4Kc3RyZWFtCnicMzU1VzBQsLQAEqamRgrmRpYKKYZcQD6IlctlaGkOZuWAWRbGQAZIGZxhAKTBmnNgenK40gCp4RBaCmVuZHN0cmVhbQplbmRvYmoKMzQgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzMDQgPj4Kc3RyZWFtCnicPZI7ksMwDEN7nYIXyIz4k+TzZCeV9/7tPjLJVoBJiQAoL3WZsqY8IGkmCf/R4eFiO+V32J7NzMC1RC8TyynPoSvE3EX5spmNurI6xarDMJ1b9Kici4ZNk5rnKksZtwuew7WJ55Z9xA83NKgHdY1Lwg3d1WhZCs1wdf87vUfZdzU8F5tU6tQXjxdRFeb5IU+ih+lK4nw8KCFcezBGFhLkU9FAjrNcrfJeQvYOtxqywkFqSeezJzzYdXpPLm4XzRAPZLlU+E5R7O3QM77sSgk9ErbhWO59O5qx6RqbOOx+70bWyoyuaCF+yFcn6yVg3FMmRRJkTrZYbovVnu6hKKZzhnMZIOrZioZS5mJXq38MO28sL9ksyJTMCzJGp02eOHjIfo2a9HmV53j9AWzzczsKZW5kc3RyZWFtCmVuZG9iagozNSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDY2ID4+CnN0cmVhbQp4nDM2tFAwUDA3V9A1NDRVMDIyUDA0MlF
IMeQyNDQHM3O5YII5YJaJAZBhCCTBGnK4YFpzwDogslCtOVxpAE04EfUKZW5kc3RyZWFtCmVuZG9iagozNiAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIyNyA+PgpzdHJlYW0KeJw1TzuyAyEM6zmFLpAZjG1gz7OZVC/3b59ksg0S/kjy9ERHJl7myAis2fG2FhmIGfgWU/GvPe3DhOo9uIcI5eJCmGEknDXruJun48W/XeUz1sG7Db5ilhcEtjCT9ZXFmct2wVgaJ3FOshtj10RsY13r6RTWEUwoAyGd7TAlyBwVKX2yo4w5Ok7kiediqsUuv+9hfcGmMaLCHFcFT9BkUJY97yagHRf039WN30k0i14CMpFgYZ0k5s5ZTvjVa0fHUYsiMSekGeQyEdKcrmIKoQnFOjsKKhUFl+pzyt0+/2hdW00KZW5kc3RyZWFtCmVuZG9iagozNyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDI0NSA+PgpzdHJlYW0KeJxFULuNQzEM6z0FFwhg/Sx7nndIldu/PUpGcIUhWj+SWhKYiMBLDLGUb+JHRkE9C78XheIzxM8XhUHOhKRAnPUZEJl4htpGbuh2cM68wzOMOQIXxVpwptOZ9lzY5JwHJxDObZTxjEK6SVQVcVSfcUzxqrLPjdeBpbVss9OR7CGNhEtJJSaXflMq/7QpWyro2kUTsEjkgZNNNOEsP0OSYsyglFH3MLWO9HGykUd10MnZnDktmdnup+1MfA9YJplR5Smd5zI+J6nzXE597rMd0eSipVX7nP3ekZbyIrXbodXpVyVRmY3Vp5C4PP+Mn/H+A46gWT4KZW5kc3RyZWFtCmVuZG9iagozOCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDM5MiA+PgpzdHJlYW0KeJw9UktuBTEI288puECl8E1ynqne7t1/W5vMVKoKLwO2MZSXDKklP+qSiDNMfvVyXeJR8r1samfmIe4uNqb4WHJfuobYctGaYrFPHMkvyLRUWKFW3aND8YUoEw8ALeCBBeG+HP/xF6jB17CFcsN7ZAJgStRuQMZD0RlIWUERYfuRFeikUK9s4e8oIFfUrIWhdGKIDZYAKb6rDYmYqNmgh4SVkqod0vGMpPBbwV2JYVBbW9sEeGbQENnekY0RM+3RGXFZEWs/PemjUTK1URkPTWd88d0yUvPRFeik0sjdykNnz0InYCTmSZjncCPhnttBCzH0ca+WT2z3mClWkfAFO8oBA7393pKNz3vgLIxc2+xMJ/DRaaccE62+HmL9gz9sS5tcxyuHRRSovCgIftdBE3F8WMX3ZKNEd7QB1iMT1WglEAwSws7tMPJ4xnnZ3hW05vREaKNEHtSOET0ossXlnBWwp/yszbEcng8me2+0j5TMzKiEFdR2eqi2z2Md1Hee+/r8AS4AoRkKZW5kc3RyZWFtCmVuZG9iagozOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDI0NyA+PgpzdHJlYW0KeJxNUbttRDEM698UXOAA62t5ngtSXfZvQ8kIkMIgoS8ppyUW9sZLDOEHWw++5JFVQ38ePzHsMyw9yeTUP+a5yVQUvhWqm5hQF2Lh/WgEvBZ0LyIrygffj2UMc8734KMQl2AmNGCsb0kmF9W8M2TCiaGOw0GbVBh3TRQsrhXNM8jtVjeyOrMgbHglE+LGAEQE2ReQzWCjjLGVkMVyHqgKkgVaYNfpG1GLgiuU1gl0otbEuszgq+f2djdDL/LgqLp4fQzrS7DC6KV7LHyuQh/M9Ew7d0kjvfCmExFmDwVSmZ2RlTo9Yn23QP+fZSv4+8nP8/0LFShcKgplbmRzdHJlYW0KZW5kb2JqCjQwIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5
ndGggOTAgPj4Kc3RyZWFtCnicTY1BEsAgCAPvvCJPUETQ/3R60v9fq9QOvcBOAokWRYL0NWpLMO64MhVrUCmYlJfAVTBcC9ruosr+MklMnYbTe7cDg7LxcYPSSfv2cXoAq/16Bt0P0hwiWAplbmRzdHJlYW0KZW5kb2JqCjQxIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzIwID4+CnN0cmVhbQp4nDVRu3HFMAzrNQUX8J34lTSPc6/K278NQDsVYRoEQKq8ZEq5XOqSVbLC5EeH6hRN+T5gpvwO9ZDj6B7ZIbpT1pZ7GAjLxDyljlhNlnu4BYEvDE2JuYXz9wjoKwajMBOBusXfP0CzJDBpcPBTkGutWmKJDjwsFlizK8ytGilUyFV8Oza5BwVycbPQpxyaFLfcgvBliGRHarGvy2Up8rv1CRiEFeaITxSJheeBDmYi8ScDYnv22WJXVy+qERnWSYcHUgTSbG4SMDRFsuqDG9hXxzU/T0fZwclBv4rB+DY4mS9JeV8FoRCPF/4Oz9nIsZJDJBTyfbXAiCNsgBGhT+0jEGUgNEX37plSPiZViu8ARiEcfapXMrwXkdlqhs3/GV3ZKgoGVVkfn0ZwJoNJOPNkowrTUrXTv/vc4/MHY2N6gAplbmRzdHJlYW0KZW5kb2JqCjQyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggODAgPj4Kc3RyZWFtCnicRYy7DcAwCER7pmAEfiZmnyiVs38bIErccE+6e7g6EjJT3mGGhwSeDCyGU/EGmaNgNbhGUo2d7KOwbl91geZ6U6v19wcqT3Z2cT3Nyxn0CmVuZHN0cmVhbQplbmRvYmoKNDMgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxNTcgPj4Kc3RyZWFtCnicRZC5EUMxCERzVUEJErAI6rHH0Xf/qRf5SrRvAC2HryVTqh8nIqbc12j0MHkOn00lVizYJraTGnIbFkFKMZh4TjGro7ehmYfU67ioqrh1ZpXTacvKxX/zaFczkz3CNeon8E3o+J88tKnoW6CvC5R9QLU4nUlQMX2vYoGjnHZ/IpwY4D4ZR5kpI3Fibgrs9xkAZr5XuMbjBd0BN3kKZW5kc3RyZWFtCmVuZG9iago0NCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDY4ID4+CnN0cmVhbQp4nDMzNlMwULAwAhKmpoYK5kaWCimGXEA+iJXLBRPLAbPMLMyBLCMLkJYcLkMLYzBtYmykYGZiBmRZIDEgutIAcvgSkQplbmRzdHJlYW0KZW5kb2JqCjQ1IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzE3ID4+CnN0cmVhbQp4nDVSS3JDMQjbv1Nwgc6Yv32edLJq7r+thCcrsC1AQi4vWdJLftQl26XD5Fcf9yWxQj6P7ZrMUsX3FrMUzy2vR88Rty0KBFETPfgyJxUi1M/U6Dp4YZc+A68QTikWeAeTAAav4V94lE6DwDsbMt4Rk5EaECTBmkuLTUiUPUn8K+X1pJU0dH4mK3P5e3KpFGqjyQgVIFi52AekKykeJBM9iUiycr03VojekFeSx2clJhkQ3SaxTbTA49yVtISZmEIF5liA1XSzuvocTFjjsITxKmEW1YNNnjWphGa0jmNkw3j3wkyJhYbDElCbfZUJqpeP09wJI6ZHTXbtwrJbNu8hRKP5MyyUwccoJAGHTmMkCtKwgBGBOb2wir3mCzkWwIhlnZosDG1oJbt6joXA0JyzpWHG157X8/4HRVt7owplbmRzdHJlYW0KZW5kb2JqCjQ2IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTcgPj4Kc3RyZWFtCnicMza0UDCAwxRDLgAalAL
sCmVuZHN0cmVhbQplbmRvYmoKNDcgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzMzggPj4Kc3RyZWFtCnicNVI5rt1ADOt9Cl0ggHbNnOcFqX7u34aUXwpDtFaKmo4WlWn5ZSFVLZMuv+1JbYkb8vfJCokTklcl2qUMkVD5PIVUv2fLvL7WnBEgS5UKk5OSxyUL/gyX3i4c52NrP48jdz16YFWMhBIByxQTo2tZOrvDmo38PKYBP+IRcq5YtxxjFUgNunHaFe9D83nIGiBmmJaKCl1WiRZ+QfGgR61991hUWCDR7RxJcIyNUJGAdoHaSAw5sxa7qC/6WZSYCXTtiyLuosASScycYl06+g8+dCyovzbjy6+OSvpIK2tM2nejSWnMIpOul0VvN299PbhA8y7Kf17NIEFT1ihpfNCqnWMomhllhXccmgw0xxyHzBM8hzMSlPR9KH5fSya6KJE/Dg2hf18eo4ycBm8Bc9GftooDF/HZYa8cYIXSxZrkfUAqE3pg+v/X+Hn+/AMctoBUCmVuZHN0cmVhbQplbmRvYmoKNDggMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNDggPj4Kc3RyZWFtCnicLVE5kgNBCMvnFXpCc9PvscuR9//pCsoBg4ZDIDotcVDGTxCWK97yyFW04e+ZGMF3waHfynUbFjkQFUjSGFRNqF28Hr0HdhxmAvOkNSyDGesDP2MKN3pxeEzG2e11GTUEe9drT2ZQMisXccnEBVN12MiZw0+mjAvtXM8NyLkR1mUYpJuVxoyEI00hUkih6iapM0GQBKOrUaONHMV+6csjnWFVI2oM+1xL29dzE84aNDsWqzw5pUdXnMvJxQsrB/28zcBFVBqrPBAScL/bQ/2c7OQ33tK5s8X0+F5zsrwwFVjx5rUbkE21+Dcv4vg94+v5/AOopVsWCmVuZHN0cmVhbQplbmRvYmoKNDkgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMzggPj4Kc3RyZWFtCnicPY9BDgMxCAPveYU/ECl2Qljes1VP2/9fS5rdXtAIjDEWQkNvqGoOm4INx4ulS6jW8CmKiUoOyJlgDqWk0h1nkXpiOBjcHrQbzuKx6foRu5JWfdDmRrolaIJH7FNp3JZxE8QDNQXqKepco7wQuZ+pV9g0kt20spJrOKbfveep6//TVd5fX98ujAplbmRzdHJlYW0KZW5kb2JqCjUwIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjEwID4+CnN0cmVhbQp4nDVQyw1DMQi7ZwoWqBQCgWSeVr11/2tt0DthEf9CWMiUCHmpyc4p6Us+OkwPti6/sSILrXUl7MqaIJ4r76GZsrHR2OJgcBomXoAWN2DoaY0aNXThgqYulUKBxSXwmXx1e+i+Txl4ahlydgQRQ8lgCWq6Fk1YtDyfkE4B4v9+w+4t5KGS88qeG/kbnO3wO7Nu4SdqdiLRchUy1LM0xxgIE0UePHlFpnDis9Z31TQS1GYLTpYBrk4/jA4AYCJeWYDsrkQ5S9KOpZ9vvMf3D0AAU7QKZW5kc3RyZWFtCmVuZG9iagozMCAwIG9iago8PCAvQmFzZUZvbnQgL0RlamFWdVNhbnMgL0NoYXJQcm9jcyAzMSAwIFIKL0VuY29kaW5nIDw8Ci9EaWZmZXJlbmNlcyBbIDMyIC9zcGFjZSA0NCAvY29tbWEgNDggL3plcm8gL29uZSAvdHdvIC90aHJlZSAvZm91ciAvZml2ZSAvc2l4IC9zZXZlbgovZWlnaHQgL25pbmUgNzIgL0ggNzYgL0wgOTcgL2EgMTAwIC9kIC9lIDExNCAvciAxMjEgL3kgXQovVHlwZSAvRW5jb2RpbmcgPj4KL0ZpcnN0Q2hhciAwIC9Gb250QkJveCBbIC0xMDI
xIC00NjMgMTc5NCAxMjMzIF0gL0ZvbnREZXNjcmlwdG9yIDI5IDAgUgovRm9udE1hdHJpeCBbIDAuMDAxIDAgMCAwLjAwMSAwIDAgXSAvTGFzdENoYXIgMjU1IC9OYW1lIC9EZWphVnVTYW5zCi9TdWJ0eXBlIC9UeXBlMyAvVHlwZSAvRm9udCAvV2lkdGhzIDI4IDAgUiA+PgplbmRvYmoKMjkgMCBvYmoKPDwgL0FzY2VudCA5MjkgL0NhcEhlaWdodCAwIC9EZXNjZW50IC0yMzYgL0ZsYWdzIDMyCi9Gb250QkJveCBbIC0xMDIxIC00NjMgMTc5NCAxMjMzIF0gL0ZvbnROYW1lIC9EZWphVnVTYW5zIC9JdGFsaWNBbmdsZSAwCi9NYXhXaWR0aCAxMzQyIC9TdGVtViAwIC9UeXBlIC9Gb250RGVzY3JpcHRvciAvWEhlaWdodCAwID4+CmVuZG9iagoyOCAwIG9iagpbIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwCjYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgMzE4IDQwMSA0NjAgODM4IDYzNgo5NTAgNzgwIDI3NSAzOTAgMzkwIDUwMCA4MzggMzE4IDM2MSAzMTggMzM3IDYzNiA2MzYgNjM2IDYzNiA2MzYgNjM2IDYzNiA2MzYKNjM2IDYzNiAzMzcgMzM3IDgzOCA4MzggODM4IDUzMSAxMDAwIDY4NCA2ODYgNjk4IDc3MCA2MzIgNTc1IDc3NSA3NTIgMjk1CjI5NSA2NTYgNTU3IDg2MyA3NDggNzg3IDYwMyA3ODcgNjk1IDYzNSA2MTEgNzMyIDY4NCA5ODkgNjg1IDYxMSA2ODUgMzkwIDMzNwozOTAgODM4IDUwMCA1MDAgNjEzIDYzNSA1NTAgNjM1IDYxNSAzNTIgNjM1IDYzNCAyNzggMjc4IDU3OSAyNzggOTc0IDYzNCA2MTIKNjM1IDYzNSA0MTEgNTIxIDM5MiA2MzQgNTkyIDgxOCA1OTIgNTkyIDUyNSA2MzYgMzM3IDYzNiA4MzggNjAwIDYzNiA2MDAgMzE4CjM1MiA1MTggMTAwMCA1MDAgNTAwIDUwMCAxMzQyIDYzNSA0MDAgMTA3MCA2MDAgNjg1IDYwMCA2MDAgMzE4IDMxOCA1MTggNTE4CjU5MCA1MDAgMTAwMCA1MDAgMTAwMCA1MjEgNDAwIDEwMjMgNjAwIDUyNSA2MTEgMzE4IDQwMSA2MzYgNjM2IDYzNiA2MzYgMzM3CjUwMCA1MDAgMTAwMCA0NzEgNjEyIDgzOCAzNjEgMTAwMCA1MDAgNTAwIDgzOCA0MDEgNDAxIDUwMCA2MzYgNjM2IDMxOCA1MDAKNDAxIDQ3MSA2MTIgOTY5IDk2OSA5NjkgNTMxIDY4NCA2ODQgNjg0IDY4NCA2ODQgNjg0IDk3NCA2OTggNjMyIDYzMiA2MzIgNjMyCjI5NSAyOTUgMjk1IDI5NSA3NzUgNzQ4IDc4NyA3ODcgNzg3IDc4NyA3ODcgODM4IDc4NyA3MzIgNzMyIDczMiA3MzIgNjExIDYwNQo2MzAgNjEzIDYxMyA2MTMgNjEzIDYxMyA2MTMgOTgyIDU1MCA2MTUgNjE1IDYxNSA2MTUgMjc4IDI3OCAyNzggMjc4IDYxMiA2MzQKNjEyIDYxMiA2MTIgNjEyIDYxMiA4MzggNjEyIDYzNCA2MzQgNjM0IDYzNCA1OTIgNjM1IDU5MiBdCmVuZG9iagozMSAwIG9iago8PCAvSCAzMiAwIFIgL0wgMzMgMCBSIC9hIDM0IDAgUiAvY29tbWEgMzUgMCBSIC9kIDM2IDAgUiAvZSAzNyAwIFIKL2VpZ2h0IDM
4IDAgUiAvZml2ZSAzOSAwIFIgL2ZvdXIgNDAgMCBSIC9uaW5lIDQxIDAgUiAvb25lIDQyIDAgUiAvciA0MyAwIFIKL3NldmVuIDQ0IDAgUiAvc2l4IDQ1IDAgUiAvc3BhY2UgNDYgMCBSIC90aHJlZSA0NyAwIFIgL3R3byA0OCAwIFIgL3kgNDkgMCBSCi96ZXJvIDUwIDAgUiA+PgplbmRvYmoKMyAwIG9iago8PCAvRjEgMzAgMCBSID4+CmVuZG9iago0IDAgb2JqCjw8IC9BMSA8PCAvQ0EgMCAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+Ci9BMiA8PCAvQ0EgMSAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+ID4+CmVuZG9iago1IDAgb2JqCjw8ID4+CmVuZG9iago2IDAgb2JqCjw8ID4+CmVuZG9iago3IDAgb2JqCjw8IC9JMSAxMiAwIFIgL0kxMCAyMSAwIFIgL0kxMSAyMiAwIFIgL0kxMiAyMyAwIFIgL0kxMyAyNCAwIFIgL0kxNCAyNSAwIFIKL0kxNSAyNiAwIFIgL0kxNiAyNyAwIFIgL0kyIDEzIDAgUiAvSTMgMTQgMCBSIC9JNCAxNSAwIFIgL0k1IDE2IDAgUgovSTYgMTcgMCBSIC9JNyAxOCAwIFIgL0k4IDE5IDAgUiAvSTkgMjAgMCBSID4+CmVuZG9iagoxMiAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDUxIDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3N0JwjAYQFH/ZnAA13AGx3YZdxBcQDAi3CZ4znOh4ZKXfiTd7y633VKe98fHZ47Xc7CScYetF/AXVC6oXFC5oHJB5YLKBZULKhdULqhcOG29gK/NNqMYmavYywWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhfWu1cy2/8xRt5lLxdULqhcULmgckHlgsoFlQsqF1QuqFxYb47h/xi8p3JB5YLKBZULKhdULqhcULmgcmG9L2wnBXhP5YLKBZULKhdULqhcULmgckHlgsqFoTnGVKOD2U4KjLCXCyoXVC6oXFC5oHJB5YLKBZULKhdULgzNMVYcHWTceJiFygWVCyoXVC6oXFC5oHJB5YLKBZUL7pX8yr2SWahcULmgckHlgsoFlQsqF1QuqFxQufACAIgauwplbmRzdHJlYW0KZW5kb2JqCjUxIDAgb2JqCjM1NQplbmRvYmoKMTMgMCBvYmoKPDwgL0JpdHNQZXJDb21wb25lbnQgOCAvQ29sb3JTcGFjZSAvRGV2aWNlUkdCCi9EZWNvZGVQYXJtcyA8PCAvQ29sb3JzIDMgL0NvbHVtbnMgMTE5IC9QcmVkaWN0b3IgMTAgPj4KL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0hlaWdodCAxMTkgL0xlbmd0aCA1MiAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCAxMTkgPj4Kc3RyZWFtCnic7dDLCYRQEEVBPzkYgGkYg2GbzOQwYAizeZQMnlo39OXM035OI3yvz8+b9diG/Po7y9MDXqHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykK
VhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLNyiowRdCmVuZHN0cmVhbQplbmRvYmoKNTIgMCBvYmoKMzA4CmVuZG9iagoxNCAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDUzIDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3csJwzAUAEHnU0MKSBuuwWWnmfQQSAMG62AmMd45CyEWnQTiXabnMh3K5/XeXHObH+Ak466/PsApVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWF+14bsQEX/zYlY0R3WaiyUGWhykKVhSoLVRaqLFRZqLJQZWG3dwz2vNBE0KyrslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFvpXInSXhSoLVRaqLFRZqLJQZaHKQpWFKgtVFr7UoA5fCmVuZHN0cmVhbQplbmRvYmoKNTMgMCBvYmoKMzMwCmVuZG9iagoxNSAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDU0IDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3cENgkAUQEFQarAAm/BgDZZtLSb2QGIHQiJ5Spw572Hz8i/AbhiH823gA/P9ubjmEOwDlQsqF1QuqFxQuaByQeWCygWVCyoXpm9v4KeteUdxvJ4W15jlgsoFlQsqF1QuqFxQuaByQeWCygWVC+P8uCwuWvOozhtmuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuTLs7a7HVXY+SWS6oXFC5oHJB5YLKBZULKhdULqhcULngXknBLBdULqhcULmgckHlgsoFlQsqF1QupCcF9viRfxNmuaByQeWCygWVCyoXVC6oXFC5oHJB5cJmf9L423cUa5jlgsoFlQsqF1QuqFxQuaByQeWCygWVC248FMxyQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFx4AWh6GAMKZW5kc3RyZWFtCmVuZG9iago1NCAwIG9iagozNTAKZW5kb2JqCjE2IDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDExOSAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXI
gL0ZsYXRlRGVjb2RlIC9IZWlnaHQgMTE5IC9MZW5ndGggNTUgMCBSIC9TdWJ0eXBlIC9JbWFnZQovVHlwZSAvWE9iamVjdCAvV2lkdGggMTE5ID4+CnN0cmVhbQp4nO3dvWpTcRzG8d9JT02atrTFQqQWhToUpQQsiigoXbq56aS7m+AteBPiLDoIzt6BINFBERVEKUUUazE1TWNeTnrqHeQ5kPJM38/8kJN+c5b+yUsSFx+E1PkhJ/XTNblZXUrkptnORw/yI/kY0e7qC7X6enNnY1puHj77KDclucD4qOxAZQcqO1DZgcoOVHagsgOVHajsQGWHJNbuydHdSxW5aXzT5wvdTB8d7GXihV8si4OOiNju6AtFojdXTw3l5v7tstxwLztQ2YHKDlR2oLIDlR2o7EBlByo7UNmByg5pZG05+vrzhNwMc/2CyTOKiOj0/40enF/QT2b7r3iQiIh8oCf5vNzcvPFObriXHajsQGUHKjtQ2YHKDlR2oLIDlR2o7EBlhzQSHXp1WW8OCrwfY25Sbzp5dfRgdiqTDxKlyWPZnJzVT/jxi8v6UvrZYGxUdqCyA5UdqOxAZQcqO1DZgcoOVHagskMaE1Ny9OT1gdwszszKzbn5Q7lp/UlHD5pFPjNypC9UqczITSnRnyvZvLKjH0cuMD4qO1DZgcoOVHagsgOVHajsQGUHKjtQ2SGNw64cXV+Zl5utpn7rwnt1RhERPfW5kp2OPngpotfdl5vdffHmkIioX/gkN9zLDlR2oLIDlR2o7EBlByo7UNmByg5UdtD/8kbE2dqE3Pxq6/+we70iVxNvBKimBX5Ko8gnHnL9yYnpsr7Wqzfr+unoZ4OxUdmByg5UdqCyA5UdqOxAZQcqO1DZgcoOSaFfBC2gUuD1urakjw721Y95lgpcqLGrz2dWqvoXOTbrx3MXci87UNmByg5UdqCyA5UdqOxAZQcqO1DZgcoOaWT6Ew8x1N/csLa8IDcrS/p44cOW+NKFWxv6Gxcaz/UftTfQ3wDx/bc+eHn5RZ+HcC87UNmByg5UdqCyA5UdqOxAZQcqO1DZgcoOaQyacrR+piY3rb4+Fnj6Vm96ffFDnXMN/ROmkXXkZG+gzx/6mf5azc+P9PPhXnagsgOVHajsQGUHKjtQ2YHKDlR2oLIDlR3+AxOijdMKZW5kc3RyZWFtCmVuZG9iago1NSAwIG9iago3ODkKZW5kb2JqCjE3IDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDExOSAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgMTE5IC9MZW5ndGggNTYgMCBSIC9TdWJ0eXBlIC9JbWFnZQovVHlwZSAvWE9iamVjdCAvV2lkdGggMTE5ID4+CnN0cmVhbQp4nO3du2pUURxG8RNzzG0mZGIkxmhh8AJioWin2NnmIewVfAErwdpnsLC1UMQ2jXYGIaCSgCLRjIMimdxNJvEN5ttFWNX61X/OnLPYzWzOZeDmvQdV8mH1b5yZbZ2OMzfOHcSZ5c6J/gNftwbjQS6MHcaZ3lEcqcbqPHS2lX8rXJKOhZUJViZYmWBlgpUJViZYmWBlgpUJVibUQ3lXoKoOduLI+v5AnHnzJe9jzF892X9g+VM+mZXeSJx5PD8aZ568/B1npprDcca1TLAywcoEKxOsTLAywcoEKxOsTLAywcqEgcatR3Foa/NXnGk0z8SZ0cF8e0OdtkPa//KGSVVwr8WlZr6PYj+PVE/vh42XyrXMsDLBygQrE6xMsDLBygQrE6xMsDLByoR6brwXh5Y28n0UjYI9ik7BPRt3Zvb7D7TX8r5BdZQv6vaVfDLPF/MqfPsunHDlWmZYmWBlgpUJViZYmWBlgpUJViZYmWBlwkB
1/WGe2l7LM/VYnjkxFEdGRpr9B3a7q/mHGufjyOxI3ngpeVim+/p9nHEtE6xMsDLBygQrE6xMsDLBygQrE6xMsDKhnh3NoX/u5L/z0+OTcaZT8EhIK71as93MexRVL98jcXkqH2bhR77348Wru3HGtUywMsHKBCsTrEywMsHKBCsTrEywMqEu+ddb1eNxpNsrOE5vr+CU0t0EBW9lKLGxV3DCBSPTk/miXMsEKxOsTLAywcoEKxOsTLAywcoEKxOsTKivtfJbGT6ubsSZ3b06/9pg/njF3GR46UL7W/486cxE/jzpzES+8JmNvGmysJgf9XAtE6xMsDLBygQrE6xMsDLBygQrE6xMsDKh3i54OUHJAwQlb26YHsrbAsPpiYfqMO8/tIYLvpKRD1NdPJWHVr5vxxnXMsHKBCsTrEywMsHKBCsTrEywMsHKBCsT6uXuMYU+KviSxl7eXujElyUMTcSDfP6zE2cODvPGy8pmjrP+bCnOuJYJViZYmWBlgpUJViZYmWBlgpUJViZYmfAfKQ2JjwplbmRzdHJlYW0KZW5kb2JqCjU2IDAgb2JqCjc3MwplbmRvYmoKMTggMCBvYmoKPDwgL0JpdHNQZXJDb21wb25lbnQgOCAvQ29sb3JTcGFjZSAvRGV2aWNlUkdCCi9EZWNvZGVQYXJtcyA8PCAvQ29sb3JzIDMgL0NvbHVtbnMgMTE5IC9QcmVkaWN0b3IgMTAgPj4KL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0hlaWdodCAxMTkgL0xlbmd0aCA1NyAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCAxMTkgPj4Kc3RyZWFtCnic7d09T1NhAIbht+UUWijWRorY+BWI0WgIMTExRBMXBgcHB3cHV3dHF+PoonHxB8gfMEbj4CAaBxVNjMCgDUWCxWJrP04RKP4DniYmTxzua35S6M1Z+no8TTy4dysokxO/5abyKyM37U6yh9eJ9h5U6135ImfG9ebbqvhBIYRcdlduvq5sy41+2/h3VHagsgOVHajsQGUHKjtQ2YHKDlR2oLJD1J/SH9WbcUq/UF8vRwf9cjOkjkMSISFfZPmHPqNI698ljOb1GcXWto7DtexAZQcqO1DZgcoOVHagsgOVHajsQGUHKjtEKxUdutEelJuLU1W5eb+Ylpvvla29B8mkPsc4Oqzf1EJpR26uX1mSmzefJuSGa9mByg5UdqCyA5UdqOxAZQcqO1DZgcoOVHaIRnL6foxrMwty8/rjuNzEm/roID3Qt/egHet7JMoV8SIhhMGMvmfjyatTcnO82JIbrmUHKjtQ2YHKDlR2oLIDlR2o7EBlByo7UNkhasT69ob7s/rj/OGD+ocV8vpntWIxaKtBCOHQAb1p9vDG9w3pM5NL59/JDdeyA5UdqOxAZQcqO1DZgcoOVHagsgOVHajsEHX1LRJhbERvpifX5Wb2uX6hzT/iF9rV94+ERlufUdSb+oEeo/mO3CyXj8kN17IDlR2o7EBlByo7UNmByg5UdqCyA5UdovSAHv2s6Q+sa9Ws3OR6eKBCd0dskvo/M4QePoSHYkG/qfWaftREMx6TG65lByo7UNmByg5UdqCyA5UdqOxAZQcqO1DZITpxRP9j+NvP+rDjS0k/pbLZ0gcMxYIYfFjUv/CFKf01GWsb+kDk3Omy3MzNc6fA/4HKDlR2oLIDlR2o7EBlByo7UNmByg5Udojml/QtB3duvpCbu49m5KbV0Q9CKK2K2ySyGX3+UK3rzfqG+MqOEMLTOf1YzRtXX8oN17IDlR2o7EBlByo7UNmByg5UdqCyA5UdqOwQ5bL6HonbD/UZxeXpmtw8fpaRm4T6w3f1AxfCVg+Po9g/nJKbsyf115zGm8Nyw7XsQGUHKjtQ2YHKDlR2oLIDlR2o7EBlByo7/AWaapMBCmVuZHN0cmVhbQplbmRvYmoKNTcgMCBvYmoKODE1CmV
uZG9iagoxOSAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDU4IDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3T9uE1EURvE79kNJREQENKkQEgVULIKSlpIlsD8qxCJCg8QKEpGI8EeOk9ihouW7kocjivOrr2aso9f4+c14ev3mXSXvP53FmTo4jiMvH93FmZOL5d8H9scUL3K1+hZnXj27H2c+frmMM7Vdx5FFvop2ZmWClQlWJliZYGWClQlWJliZYGWClQlTvXg7z5VufrZul+W9jsY1tnlmcS/PbG/iyNWHi3yrfCftzMoEKxOsTLAywcoEKxOsTLAywcoEKxNG3a7y1OGTPNPax5hjj6JjCoc6qqruNo3r5FW4us7nOlzLBCsTrEywMsHKBCsTrEywMsHKBCsTrEwYna/qtTr995/kj3iUovOBW3sU8+x1PDjKz564lglWJliZYGWClQlWJliZYGWClQlWJliZMGrvcZ66bZy16DzK0dk6iNsUrRs1Vs9i5JlN3sf4fnmUb5XvpJ1ZmWBlgpUJViZYmWBlgpUJViZYmWBlwqj1eZ5a7uWZztZB57mSec5jNPY6Gu++6PA8xv/CygQrE6xMsDLBygQrE6xMsDLByoRRU+OtkOMwz8z1Bsr8BXqWi/Su0/DrR47jWiZYmWBlgpUJViZYmWBlgpUJViZYmWBlwqjlQWNsjl/4q/fnFdybG+a5zvp6P864lglWJliZYGWClQlWJliZYGWClQlWJliZMGrkP4Joab0sofHmhsZ2SEPjrEXnMY7G5szp+cM441omWJlgZYKVCVYmWJlgZYKVCVYmWJlgZcKo9dc8NTVe1tjZOth0XpYQz350ngdpHCDp/BVq4zzG86ef44xrmWBlgpUJViZYmWBlgpUJViZYmWBlgpUJvwG99l5tCmVuZHN0cmVhbQplbmRvYmoKNTggMCBvYmoKNjM2CmVuZG9iagoyMCAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDU5IDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3b1OUwEchvF/7fmoLU2L1ComOBg3ExcTYnQiGAaNk6uDXoGjcfAOXJi8AjcvQWcNI8GJ6CIQIFoLlZZC6fEO+p7E5J2e3/zmtDw9C4fDaWX93etQXn0Yy8239Z7c9I7qclPLJ7MHb9/n8iBr95tyc3tpJDdPVr7KTfVhV24uyQX+H5UdqOxAZQcqO1DZgcoOVHagsgOVHajskGRJIUcv7+nN541FuVloiWsUEbHfS2cPOu2hPMjzx1ty8/HTHbk5OLguN9G4ISecyw5UdqCyA5UdqOxAZQcqO1DZgcoOVHagskNSqejR3m99/WGleSE3Waqvh7Qa09mDyUQfpHfUkpt6TbxQRLTm+nIT40xOOJcdqOxAZQcqO1DZgcoOVHagsgOVHajsQGWHpNBXBWKpm8jNn0FVbsq81vGJ+OAv5/rMqOVncnNyqo9z2LsmN522vhGFc9mByg5UdqCyA5UdqOxAZQcqO1DZgcoOVHZIEn35IXYO9f0Yq8t6k6X6Foiq+tyL0FdDfu4vyM1
8U7/hpKo3v/p7csO57EBlByo7UNmByg5UdqCyA5UdqOxAZQcqOyQX+tJClPnfk7NzPUr1bR0xHIsPfjTW1zGuzg/k5vuOfhZou9mXmyiuyAnnsgOVHajsQGUHKjtQ2YHKDlR2oLIDlR1KPbmhWdcfxkj9ZhzlntxwPhFvKEv0O85S/Rd++UIRMTxtyE1kbTnhXHagsgOVHajsQGUHKjtQ2YHKDlR2oLIDlR1K/O0+Ikv1r/yjsd605vRrTdSDLEs8/SHqNf1tn9MSByqKEld5KvpM5Vx2oLIDlR2o7EBlByo7UNmByg5UdqCyA5UdkmmJ/3go83QH+cSFiBgM9YFydc/GYieXB/myeUtuyvzgm9s35ebZXX0czmUHKjtQ2YHKDlR2oLIDlR2o7EBlByo7UNmh1P0Y27v62z4fLetNLdebv0Pxln7s6nstXjw9lpvBsCs3qw825Gbtjb7WwbnsQGUHKjtQ2YHKDlR2oLIDlR2o7EBlByo7/AMxxYJeCmVuZHN0cmVhbQplbmRvYmoKNTkgMCBvYmoKNzgwCmVuZG9iagoyMSAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDYwIDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt2LENg0AQAEGMXYMLoA1qcNk0Qw9IBBSAE48l2IlP+tPqo3sM02e4nG1ZT2ee8xtschjZS3dWZaHKQpWFKgtVFqosVFmoslBlocrC698L/IS8UXxzM+kvC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRZ2gWoHXwplbmRzdHJlYW0KZW5kb2JqCjYwIDAgb2JqCjMxOQplbmRvYmoKMjIgMCBvYmoKPDwgL0JpdHNQZXJDb21wb25lbnQgOCAvQ29sb3JTcGFjZSAvRGV2aWNlUkdCCi9EZWNvZGVQYXJtcyA8PCAvQ29sb3JzIDMgL0NvbHVtbnMgMTE5IC9QcmVkaWN0b3IgMTAgPj4KL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0hlaWdodCAxMTkgL0xlbmd0aCA2MSAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCAxMTkgPj4Kc3RyZWFtCnic7d29ThtBGEbhtT02TtxAQaCOrCipcydcSS6M66EAiQbFIBo3UbKO/1JTvZ+U1UGWzlN/GtuHqUa7zGj+/UeX9Js/caZrH+LIxfSY10nW21Gc+bw4xJlVP44zl7O8zlNhnTyh/2dlgpUJViZYmWBlgpUJViZYmWBlgpUJo255k6eO+9JSgxilP/x4mhfZ94UPmuSZYz7H+Hmbl3EvE6xMsDLBygQrE6xMsDLBygQrE6xMsDKhdYfte3+Ht+KZSeVQpXD+UJopuP70HGfcywQrE6xMsDLBygQrE6xMsDLBygQrE6xMaF2b56l94awjPkdRNE6PSbRFXmSzzjOTWZ4p/PD7x69xxr1MsDLBygQrE6xMsDLBygQrE6xMsDLByoTW7QqvYFSU3j0piM+
HDPWFB1rny/IhzriXCVYmWJlgZYKVCVYmWJlgZYKVCVYmWJnQ3vsLnL5DPsBxLxOsTLAywcoEKxOsTLAywcoEKxOsTGiDvanAKfyry8pjCwP98E2frxA5ucQnycoEKxOsTLAywcoEKxOsTLAywcoEKxNa7UKJgd5mKEnHFJPCTRq7oa7+yFeY9n8/xhn3MsHKBCsTrEywMsHKBCsTrEywMsHKBCsTWjc5y1O734WlCscClUcg4oWfs/O8yP4lz1R+eOFm0dXrVZxxLxOsTLAywcoEKxOsTLAywcoEKxOsTLAyoXXbX9ynVZ7ryP+BsnKoUjDQOt+Wd3HGvUywMsHKBCsTrEywMsHKBCsTrEywMsHKhH8rBlNFCmVuZHN0cmVhbQplbmRvYmoKNjEgMCBvYmoKNTgxCmVuZG9iagoyMyAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDYyIDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt2LENwkAUBUFjqMEFuA1qoGyaoQckd+C7aB14Jn7BaXXRfyz7Zxn5f3/DzfO9DTe3tV79gFtQuaByQeWCygWVCyoXVC6oXFC5oHLhNTNyozgxc+TxlwsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXDs6cB18KZW5kc3RyZWFtCmVuZG9iago2MiAwIG9iagozMTcKZW5kb2JqCjI0IDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDExOSAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgMTE5IC9MZW5ndGggNjMgMCBSIC9TdWJ0eXBlIC9JbWFnZQovVHlwZSAvWE9iamVjdCAvV2lkdGggMTE5ID4+CnN0cmVhbQp4nO3QSQ2EAAAEQQ4PCMAGGlb2msEDCRLgVYTQ9Z7HpMdh/Q2vcvz3y828LeDJfdPTBz6hykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqiycZ5QEXQplbmRzdHJlYW0KZW5kb2JqCjYzIDAgb2JqCjMwNQplbmRvYmoKMjUgMCBvYmoKPDwgL0JpdHNQZXJDb21wb25lbnQgOCAvQ29sb3JTcGFjZSAvRGV2aWNlUkdCCi9EZWNvZGV
QYXJtcyA8PCAvQ29sb3JzIDMgL0NvbHVtbnMgMTE5IC9QcmVkaWN0b3IgMTAgPj4KL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0hlaWdodCAxMTkgL0xlbmd0aCA2NCAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCAxMTkgPj4Kc3RyZWFtCnic7dBJDYQAAARBDg8IwAYaVvaawQMJEuBVhND1nsekx2H9Da9y/PfLzbwt4Ml909MHPqHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLJxnlARdCmVuZHN0cmVhbQplbmRvYmoKNjQgMCBvYmoKMzA1CmVuZG9iagoyNiAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDY1IDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt2z1KA1EUQOEkJsYoCoJETGe0cgtWLsD1uQSXYeUWrLSzMIKIwd/MJFrYey/MeBjj+erLm8fhVY957db4tLV05pf34czK8RDYybcO9qX/zMoEKxOsTLAywcoEKxOsTLAywcoEKxOsTLAywcoEKxOsTLAywcoEKxOsTLAywcoEKxO6maHD0X44c313Gy+0KDKf+9loOA5nzs6PEitdhRMb2wfhzMvjTTjjWSZYmWBlgpUJViZYmWBlgpUJViZYmWBlQtt3JQDPMsHKBCsTrEywMsHKBCsTrEywMsHKBCsTUv9jNErT7igyPMsEKxOsTLAywcoEKxOsTLAywcoEKxOsTMjdY6xuxTOzacWtJD1PN+OhTi+eyTxyaSdO4eci3k68iiqzMsHKBCsTrEywMsHKBCsTrEywMmE5XzwUF5NwpneyGy/UXYtnyvdwxLNMsDLBygQrE6xMsDLBygQrE6xMsDLByoS/9+Ihoyj79SzUyazjPUYzWJlgZYKVCVYmWJlgZYKVCVYmWJlgZULuHqO7Hs+UrxW3UqPJw15i6i0eGezEM7OncMSzTLAywcoEKxOsTLAywcoEKxOsTLAywcqE3D3G/OOXt1GzQT/+RyJlMa9lGc8ywcoEKxOsTLAywcoEKxOsTLAywcoEKxO+AIovMa4KZW5kc3RyZWFtCmVuZG9iago2NSAwIG9iago1MDAKZW5kb2JqCjI3IDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDExOSAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgMTE5IC9MZW5ndGggNjYgMCBSIC9TdWJ0eXBlIC9JbWFnZQovVHlwZSAvWE9iamVjdCAvV2lkdGggMTE5ID4+CnN0cmVhbQp4nO3QSQ2EAAAEQQ4PCMAGGlb2msEDCRLgVYTQ9Z7HpMdh/Q2vcvz3y828LeDJfdP
TBz6hykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqiycZ5QEXQplbmRzdHJlYW0KZW5kb2JqCjY2IDAgb2JqCjMwNQplbmRvYmoKMiAwIG9iago8PCAvQ291bnQgMSAvS2lkcyBbIDEwIDAgUiBdIC9UeXBlIC9QYWdlcyA+PgplbmRvYmoKNjcgMCBvYmoKPDwgL0NyZWF0aW9uRGF0ZSAoRDoyMDIyMDUzMTE3MDAwMCswMicwMCcpCi9DcmVhdG9yIChNYXRwbG90bGliIHYzLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZykKL1Byb2R1Y2VyIChNYXRwbG90bGliIHBkZiBiYWNrZW5kIHYzLjMuMikgPj4KZW5kb2JqCnhyZWYKMCA2OAowMDAwMDAwMDAwIDY1NTM1IGYgCjAwMDAwMDAwMTYgMDAwMDAgbiAKMDAwMDAyNTYwMiAwMDAwMCBuIAowMDAwMDEzNDczIDAwMDAwIG4gCjAwMDAwMTM1MDUgMDAwMDAgbiAKMDAwMDAxMzYwNCAwMDAwMCBuIAowMDAwMDEzNjI1IDAwMDAwIG4gCjAwMDAwMTM2NDYgMDAwMDAgbiAKMDAwMDAwMDA2NSAwMDAwMCBuIAowMDAwMDAwNDAzIDAwMDAwIG4gCjAwMDAwMDAyMDggMDAwMDAgbiAKMDAwMDAwNjU1MSAwMDAwMCBuIAowMDAwMDEzODUwIDAwMDAwIG4gCjAwMDAwMTQ0NTQgMDAwMDAgbiAKMDAwMDAxNTAxMSAwMDAwMCBuIAowMDAwMDE1NTkwIDAwMDAwIG4gCjAwMDAwMTYxODkgMDAwMDAgbiAKMDAwMDAxNzIyNyAwMDAwMCBuIAowMDAwMDE4MjQ5IDAwMDAwIG4gCjAwMDAwMTkzMTMgMDAwMDAgbiAKMDAwMDAyMDE5OCAwMDAwMCBuIAowMDAwMDIxMjI3IDAwMDAwIG4gCjAwMDAwMjE3OTUgMDAwMDAgbiAKMDAwMDAyMjYyNSAwMDAwMCBuIAowMDAwMDIzMTkxIDAwMDAwIG4gCjAwMDAwMjM3NDUgMDAwMDAgbiAKMDAwMDAyNDI5OSAwMDAwMCBuIAowMDAwMDI1MDQ4IDAwMDAwIG4gCjAwMDAwMTIxNzAgMDAwMDAgbiAKMDAwMDAxMTk3MCAwMDAwMCBuIAowMDAwMDExNTUyIDAwMDAwIG4gCjAwMDAwMTMyMjMgMDAwMDAgbiAKMDAwMDAwNjU3MiAwMDAwMCBuIAowMDAwMDA2NzIxIDAwMDAwIG4gCjAwMDAwMDY4NTIgMDAwMDAgbiAKMDAwMDAwNzIyOSAwMDAwMCBuIAowMDAwMDA3MzY3IDAwMDAwIG4gCjAwMDAwMDc2NjcgMDAwMDAgbiAKMDAwMDAwNzk4NSAwMDAwMCBuIAowMDAwMDA4NDUwIDAwMDAwIG4gCjAwMDAwMDg3NzAgMDAwMDAgbiAKMDAwMDAwODkzMiAwMDAwMCBuIAowMDAwMDA5MzI1IDAwMDAwIG4gCjAwMDAwMDk0NzcgMDAwMDAgbiAKMDAwMDAwOTcwNyAwMDAwMCBuIAowMDAwMDA5ODQ3IDAwMDAwIG4gCjAwMDAwMTAyMzcgMDAwMDAgbiAKMDAwMDAxMDMyNiAwMDAwMCBuIAowMDAwMDEwNzM3IDAwMDAwIG4
gCjAwMDAwMTEwNTggMDAwMDAgbiAKMDAwMDAxMTI2OSAwMDAwMCBuIAowMDAwMDE0NDM0IDAwMDAwIG4gCjAwMDAwMTQ5OTEgMDAwMDAgbiAKMDAwMDAxNTU3MCAwMDAwMCBuIAowMDAwMDE2MTY5IDAwMDAwIG4gCjAwMDAwMTcyMDcgMDAwMDAgbiAKMDAwMDAxODIyOSAwMDAwMCBuIAowMDAwMDE5MjkzIDAwMDAwIG4gCjAwMDAwMjAxNzggMDAwMDAgbiAKMDAwMDAyMTIwNyAwMDAwMCBuIAowMDAwMDIxNzc1IDAwMDAwIG4gCjAwMDAwMjI2MDUgMDAwMDAgbiAKMDAwMDAyMzE3MSAwMDAwMCBuIAowMDAwMDIzNzI1IDAwMDAwIG4gCjAwMDAwMjQyNzkgMDAwMDAgbiAKMDAwMDAyNTAyOCAwMDAwMCBuIAowMDAwMDI1NTgyIDAwMDAwIG4gCjAwMDAwMjU2NjIgMDAwMDAgbiAKdHJhaWxlcgo8PCAvSW5mbyA2NyAwIFIgL1Jvb3QgMSAwIFIgL1NpemUgNjggPj4Kc3RhcnR4cmVmCjI1ODE5CiUlRU9GCg==\n", "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-05-31T16:59:59.848824\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.3.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def visualize_prediction(idx):\n", " visualize_exmp(indices[idx:idx+1], test_set)\n", " print(f'Main class: {class_idx_to_name[test_labels[indices[idx,0]]]}, Anomaly class: {class_idx_to_name[test_labels[indices[idx,-1]]]}')\n", " print(f'Prediction: image {predictions[idx].item()}')\n", " plot_attention_maps(input_data=None, attn_maps=attention_maps, idx=idx)\n", "\n", "visualize_prediction(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Depending on the random seed, you might see a slightly different input set. For the version on the website, we compare 9 aquarium fish images with a volcano/mountain. Several heads, for instance Layer 2 Head 1, Layer 2 Head 4, and Layer 3 Heads 2-4, focus on the last image. In contrast, all heads in Layer 4 seem to ignore the last image and assign it a very low attention probability. This suggests that the model has indeed recognized that the image does not fit the rest of the set, and hence predicted it to be the anomaly. Layer 2 Head 3 and Layer 3 Head 1 seem to take an only slightly weighted, almost uniform average over all images. This might indicate that the model extracts the \"average\" information of the set to compare it with the features of each individual image. \n", "\n", "Let's find the sets where the model actually makes a mistake. Since we ensured in the dataset that the anomaly is always at the last position in the set, these are exactly the sets for which the model predicts an index other than 9."
] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Indices with mistake: [10 58]\n" ] } ], "source": [ "mistakes = np.where(predictions != 9)[0]\n", "print(\"Indices with mistake:\", mistakes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As our model achieves ~94% accuracy, it makes only a few mistakes in a batch of 64 sets. Still, let's visualize one of them, for example the last one:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1BhZ2VzIDIgMCBSIC9UeXBlIC9DYXRhbG9nID4+CmVuZG9iago4IDAgb2JqCjw8IC9FeHRHU3RhdGUgNCAwIFIgL0ZvbnQgMyAwIFIgL1BhdHRlcm4gNSAwIFIKL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gL1NoYWRpbmcgNiAwIFIKL1hPYmplY3QgNyAwIFIgPj4KZW5kb2JqCjEwIDAgb2JqCjw8IC9Bbm5vdHMgWyBdIC9Db250ZW50cyA5IDAgUgovR3JvdXAgPDwgL0NTIC9EZXZpY2VSR0IgL1MgL1RyYW5zcGFyZW5jeSAvVHlwZSAvR3JvdXAgPj4KL01lZGlhQm94IFsgMCAwIDY4NCAxMDAuNDc1OTkzMzc3NSBdIC9QYXJlbnQgMiAwIFIgL1Jlc291cmNlcyA4IDAgUgovVHlwZSAvUGFnZSA+PgplbmRvYmoKOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDExIDAgUiA+PgpzdHJlYW0KeJxVjztvwzAMhHf+ihubRSJlW4pHJ2mMjA4EdA5cJa3hR1MDffz70gH6GojDHXj8QEFHthJcZjA6nXcIathdentu07HeoJ2JNR/Ir3PV/qbCbPJQlGWmAf+3T0QjXRGMu433pfEIbMqCdSELocBrwgNG2MotYFGwKJhRa8+HBccI8nOiHWAPgt2Ehhpcv3uMy9/u4mkTYfcCcYhncnmu1ExcgXVu5JcfH+muGqfh1H8ifZyGlz7NmEZsD/vqqM+sEDvcR2roC5CcQPsKZW5kc3RyZWFtCmVuZG9iagoxMSAwIG9iagoyMDMKZW5kb2JqCjE3IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggODggPj4Kc3RyZWFtCnicNYy7DcAwCER7prgR+DiA94lSkf3bEFsuuHvSE+c5wMg+D0foxC1kQ+GmeEk5oT5RNFpvOrZIc7+8ZDMXFf0z3H2F7eaAZDRJ5CHR5XLlWSl6PpfaG34KZW5kc3RyZWFtCmVuZG9iagoxOCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIzMiA+PgpzdHJlYW0KeJw1UTtyBTEI630KXSAz5m+fZzOvSu7fRrCTZmEBCQnnPdiIxJcY0h3lim9ZnWYZfieLvPhZKZy8F1GBVEVYIe3gWc5qhsFzI1PgciY+y8wn02LHAqqJOM6OnGYwCDGN62g5HWaaBz0h1wcjbuw0y1UMab1bqtf3Wv5TRfnIu
pvl1imbWqlb9Iw9icvO66kt7QujjuKmINLhY4f3IF/EnMVFJ9LNfjPlsJI0BKcF8CMxlOrZ4TXCxM+MBE/Z0+l9lIbXPmi6vncv6MjNhEzlFspIxZOVxpgxVL8RzST1/T/Qsz5/mjBURwplbmRzdHJlYW0KZW5kb2JqCjE5IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNzQgPj4Kc3RyZWFtCnicMzU3VTBQsLQAEqaG5grmRpYKKYZcQD6IlcsFE8sBs8xMzIAsQ0tklomxIZBlYmGGxDI2sYDKIlgGQBpsTQ7M9ByuNAADcRiTCmVuZHN0cmVhbQplbmRvYmoKMjAgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA0OSA+PgpzdHJlYW0KeJwzsjRVMFCwtAAShpbmCuZGlgophlxAPoiVywUTywGzDIA0WGkOTEUOVxoApUQM5AplbmRzdHJlYW0KZW5kb2JqCjIxIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjI3ID4+CnN0cmVhbQp4nEWQS44DIRBD95zCR6D+cJ6OsurcfzsuOtFssCUo1zO5AxN78chMlG68ZLg7zBWf4Rkwc/hKmGzETOhOXCOUrhThVJ8IjsvevOmgiXtEzqOeBVnVzg1qAWeS5oLtgi7njBU3zsmtRuXN9KPXEL5pdx/XeYf2SOPew1S+zjnVzruKCGkLWdW0vpBsFMkOaz8qTdvOyxCx4GwaVugc3gi7V3cnSxh+v/IwJRM/D936UXxdN6PrFGcnVyZrz3noSelf9cqjD8VxKegXse3MJPdfp1OSqVN7Z+9p/ae4x/sPkG5WOQplbmRzdHJlYW0KZW5kb2JqCjIyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzA0ID4+CnN0cmVhbQp4nD2SO5LDMAxDe52CF8iM+JPk82Qnlff+7T4yyVaASYkAKC91mbKmPCBpJgn/0eHhYjvld9iezczAtUQvE8spz6ErxNxF+bKZjbqyOsWqwzCdW/SonIuGTZOa5ypLGbcLnsO1ieeWfcQPNzSoB3WNS8IN3dVoWQrNcHX/O71H2Xc1PBebVOrUF48XURXm+SFPoofpSuJ8PCghXHswRhYS5FPRQI6zXK3yXkL2DrcassJBaknnsyc82HV6Ty5uF80QD2S5VPhOUezt0DO+7EoJPRK24VjufTuasekamzjsfu9G1sqMrmghfshXJ+slYNxTJkUSZE62WG6L1Z7uoSimc4ZzGSDq2YqGUuZiV6t/DDtvLC/ZLMiUzAsyRqdNnjh4yH6NmvR5led4/QFs83M7CmVuZHN0cmVhbQplbmRvYmoKMjMgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNDUgPj4Kc3RyZWFtCnicRVC7jUMxDOs9BRcIYP0se553SJXbvz1KRnCFIVo/kloSmIjASwyxlG/iR0ZBPQu/F4XiM8TPF4VBzoSkQJz1GRCZeIbaRm7odnDOvMMzjDkCF8VacKbTmfZc2OScBycQzm2U8YxCuklUFXFUn3FM8aqyz43XgaW1bLPTkewhjYRLSSUml35TKv+0KVsq6NpFE7BI5IGTTTThLD9DkmLMoJRR9zC1jvRxspFHddDJ2Zw5LZnZ7qftTHwPWCaZUeUpnecyPiep81xOfe6zHdHkoqVV+5z93pGW8iK126HV6VclUZmN1aeQuDz/jJ/x/gOOoFk+CmVuZHN0cmVhbQplbmRvYmoKMjQgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA0NSA+PgpzdHJlYW0KeJwzMrdQMFCwNAEShhYmCuZmBgophlyWEFYuF0wsB8wC0ZZwCiKeBgCffQy1CmVuZHN0cmVhbQplbmRvY
moKMjUgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNTUgPj4Kc3RyZWFtCnicRZFLkgMgCET3noIjgPzkPJmaVXL/7TSYTDZ2l6j9hEojphIs5xR5MP3I8s1ktum1HKudjQKKIhTM5Cr0WIHVnSnizLVEtfWxMnLc6R2D4g3nrpxUsrhRxjqqOhU4pufK+qru/Lgsyr4jhzIFbNY5DjZw5bZhjBOjzVZ3h/tEkKeTqaPidpBs+IOTxr7K1RW4Tjb76iUYB4J+oQlM8k2gdYZA4+YpenIJ9vFxu/NAsLe8CaRsCOTIEIwOQbtOrn9x6/ze/zrDnefaDFeOd/E7TGu74y8xyYq5gEXuFNTzPRet6wwd78mZY3LTfUPnXLDL3UGmz/wf6/cPUIpmiAplbmRzdHJlYW0KZW5kb2JqCjI2IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTYxID4+CnN0cmVhbQp4nEWQSxLDIAxD95xCR/BHBnyedLpK77+tIU2zgKexQAZ3JwSptQUT0QUvbUu6Cz5bCc7GeOg2bjUS5AR1gFak42iUUn25xWmVdPFoNnMrC60THWYOepSjGaAQOhXe7aLkcqbuzvlHcPVf9Uex7pzNxMBk5Q6EZvUp7nybHVFd3WR/0mNu1mt/FfaqsLSspeWE285dM6AE7qkc7f0FqXM6hAplbmRzdHJlYW0KZW5kb2JqCjI3IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjE0ID4+CnN0cmVhbQp4nD1QuxFDMQjrPQUL5M587TfPy6XL/m0knKRCNkISlJpMyZSHOsqSrClPHT5LYoe8h+VuZDYlKkUvk7Al99AK8X2J5hT33dWWs0M0l2g5fgszKqobHdNLNppwKhO6oNzDM/oNbXQDVocesVsg0KRg17YgcscPGAzBmROLIgxKTQb/rXL3UtzvPRxvooiUdPCu+eX0y88tvE49jkS6vfmKa3GmOgpEcEZq8op0YcWyyEOk1QQ1PQNrtQCu3nr5N2hHdBmA7BOJ4zSlHEP/1rjH6wOHilL0CmVuZHN0cmVhbQplbmRvYmoKMjggMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCA4MCA+PgpzdHJlYW0KeJxFjLsNwDAIRHumYAR+JmafKJWzfxsgStxwT7p7uDoSMlPeYYaHBJ4MLIZT8QaZo2A1uEZSjZ3so7BuX3WB5npTq/X3BypPdnZxPc3LGfQKZW5kc3RyZWFtCmVuZG9iagoyOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIzNiA+PgpzdHJlYW0KeJxNUEtuRCEM23OKXOBJJCEBzkPVVef+27HDVO0qhhh/SA/pslUe61NidYns8qVNl8oyeRWo5U/b/1EMAm7/0MhBtLeMnWLmEtbFwiQ85TQjGyfXLB+PO08bZoXGxI3jnS4ZYJ8WATVblc2BOW06N0C6kBq3qrPeZFAMIupCzQeTLpyn0ZeIOZ6oYEp3JrWQG1w+1aEDcVq9Crlji5NvxBxZocBh0Exx1l8B1qjJslnIIEmGIc59o3uUCo2oynkrFcIPk6ER9YbVoAaVuYWiqeWS/B3aAjAFtox16QxKgaoAwd8qp32/ASSNXVMKZW5kc3RyZWFtCmVuZG9iagozMCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDMzMiA+PgpzdHJlYW0KeJwtUjmOJDEMy/0KfmAA6/Lxnh5M1Pv/dElVBQWqbMs85HLDRCV+LJDbUWvi10ZmoMLwr6vMhe9I28g6iGvIRVzJlsJnRCzkMcQ8xILv2/gZHvmszMmzB8Yv2fcZVuypCctCxosztMMqjsMqyLFg6yKqe3hTpMOpJNjji/8+xXMXgha+I
2jAL/nnqyN4vqRF2j1m27RbD5ZpR5UUloPtac7L5EvrLFfH4/kg2d4VO0JqV4CiMHfGeS6OMm1lRGthZ4OkxsX25tiPpQRd6MZlpDgC+ZkqwgNKmsxsoiD+yOkhpzIQpq7pSie3URV36slcs7m8nUkyW/dFis0UzuvCmfV3mDKrzTt5lhOlTkX4GXu2BA2d4+rZa5mFRrc5wSslfDZ2enLyvZpZD8mpSEgV07oKTqPIFEvYlviaiprS1Mvw35f3GX//ATPifAEKZW5kc3RyZWFtCmVuZG9iagozMSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDE3ID4+CnN0cmVhbQp4nDM2tFAwgMMUQy4AGpQC7AplbmRzdHJlYW0KZW5kb2JqCjMyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggODcgPj4Kc3RyZWFtCnicNU25EcAwCOuZghHMo9jsk0vl7N8G7LhBOn0glBtr5AGC4Z1vIfimLxmEdQhPKrslOmyhhrMKkonhVzZ4Va6K9rWSiexspjHYoGX60c63Sc8Hpd4bmAplbmRzdHJlYW0KZW5kb2JqCjMzIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTM4ID4+CnN0cmVhbQp4nD2PQQ4DMQgD73mFPxApdkJY3rNVT9v/X0ua3V7QCIwxFkJDb6hqDpuCDceLpUuo1vApiolKDsiZYA6lpNIdZ5F6YjgY3B60G87isen6EbuSVn3Q5ka6JWiCR+xTadyWcRPEAzUF6inqXKO8ELmfqVfYNJLdtLKSazim373nqev/01XeX1/fLowKZW5kc3RyZWFtCmVuZG9iagozNCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIxMCA+PgpzdHJlYW0KeJw1UMsNQzEIu2cKFqgUAoFknla9df9rbdA7YRH/QljIlAh5qcnOKelLPjpMD7Yuv7EiC611JezKmiCeK++hmbKx0djiYHAaJl6AFjdg6GmNGjV04YKmLpVCgcUl8Jl8dXvovk8ZeGoZcnYEEUPJYAlquhZNWLQ8n5BOAeL/fsPuLeShkvPKnhv5G5zt8DuzbuEnanYi0XIVMtSzNMcYCBNFHjx5RaZw4rPWd9U0EtRmC06WAa5OP4wOAGAiXlmA7K5EOUvSjqWfb7zH9w9AAFO0CmVuZHN0cmVhbQplbmRvYmoKMTUgMCBvYmoKPDwgL0Jhc2VGb250IC9EZWphVnVTYW5zIC9DaGFyUHJvY3MgMTYgMCBSCi9FbmNvZGluZyA8PAovRGlmZmVyZW5jZXMgWyAzMiAvc3BhY2UgNDggL3plcm8gL29uZSA2NSAvQSA2NyAvQyA3MCAvRiA3MyAvSSA4MiAvUiA5NyAvYSAxMDEgL2UgMTA4Ci9sIC9tIC9uIC9vIC9wIDExNSAvcyAxMjAgL3ggL3kgXQovVHlwZSAvRW5jb2RpbmcgPj4KL0ZpcnN0Q2hhciAwIC9Gb250QkJveCBbIC0xMDIxIC00NjMgMTc5NCAxMjMzIF0gL0ZvbnREZXNjcmlwdG9yIDE0IDAgUgovRm9udE1hdHJpeCBbIDAuMDAxIDAgMCAwLjAwMSAwIDAgXSAvTGFzdENoYXIgMjU1IC9OYW1lIC9EZWphVnVTYW5zCi9TdWJ0eXBlIC9UeXBlMyAvVHlwZSAvRm9udCAvV2lkdGhzIDEzIDAgUiA+PgplbmRvYmoKMTQgMCBvYmoKPDwgL0FzY2VudCA5MjkgL0NhcEhlaWdodCAwIC9EZXNjZW50IC0yMzYgL0ZsYWdzIDMyCi9Gb250QkJveCBbIC0xMDIxIC00NjMgMTc5NCAxMjMzIF0gL0ZvbnROYW1lIC9EZWphVnVTYW5zIC9JdGFsaWNBbmdsZSAwCi9NYXhXaWR0aCAxM
zQyIC9TdGVtViAwIC9UeXBlIC9Gb250RGVzY3JpcHRvciAvWEhlaWdodCAwID4+CmVuZG9iagoxMyAwIG9iagpbIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwCjYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgMzE4IDQwMSA0NjAgODM4IDYzNgo5NTAgNzgwIDI3NSAzOTAgMzkwIDUwMCA4MzggMzE4IDM2MSAzMTggMzM3IDYzNiA2MzYgNjM2IDYzNiA2MzYgNjM2IDYzNiA2MzYKNjM2IDYzNiAzMzcgMzM3IDgzOCA4MzggODM4IDUzMSAxMDAwIDY4NCA2ODYgNjk4IDc3MCA2MzIgNTc1IDc3NSA3NTIgMjk1CjI5NSA2NTYgNTU3IDg2MyA3NDggNzg3IDYwMyA3ODcgNjk1IDYzNSA2MTEgNzMyIDY4NCA5ODkgNjg1IDYxMSA2ODUgMzkwIDMzNwozOTAgODM4IDUwMCA1MDAgNjEzIDYzNSA1NTAgNjM1IDYxNSAzNTIgNjM1IDYzNCAyNzggMjc4IDU3OSAyNzggOTc0IDYzNCA2MTIKNjM1IDYzNSA0MTEgNTIxIDM5MiA2MzQgNTkyIDgxOCA1OTIgNTkyIDUyNSA2MzYgMzM3IDYzNiA4MzggNjAwIDYzNiA2MDAgMzE4CjM1MiA1MTggMTAwMCA1MDAgNTAwIDUwMCAxMzQyIDYzNSA0MDAgMTA3MCA2MDAgNjg1IDYwMCA2MDAgMzE4IDMxOCA1MTggNTE4CjU5MCA1MDAgMTAwMCA1MDAgMTAwMCA1MjEgNDAwIDEwMjMgNjAwIDUyNSA2MTEgMzE4IDQwMSA2MzYgNjM2IDYzNiA2MzYgMzM3CjUwMCA1MDAgMTAwMCA0NzEgNjEyIDgzOCAzNjEgMTAwMCA1MDAgNTAwIDgzOCA0MDEgNDAxIDUwMCA2MzYgNjM2IDMxOCA1MDAKNDAxIDQ3MSA2MTIgOTY5IDk2OSA5NjkgNTMxIDY4NCA2ODQgNjg0IDY4NCA2ODQgNjg0IDk3NCA2OTggNjMyIDYzMiA2MzIgNjMyCjI5NSAyOTUgMjk1IDI5NSA3NzUgNzQ4IDc4NyA3ODcgNzg3IDc4NyA3ODcgODM4IDc4NyA3MzIgNzMyIDczMiA3MzIgNjExIDYwNQo2MzAgNjEzIDYxMyA2MTMgNjEzIDYxMyA2MTMgOTgyIDU1MCA2MTUgNjE1IDYxNSA2MTUgMjc4IDI3OCAyNzggMjc4IDYxMiA2MzQKNjEyIDYxMiA2MTIgNjEyIDYxMiA4MzggNjEyIDYzNCA2MzQgNjM0IDYzNCA1OTIgNjM1IDU5MiBdCmVuZG9iagoxNiAwIG9iago8PCAvQSAxNyAwIFIgL0MgMTggMCBSIC9GIDE5IDAgUiAvSSAyMCAwIFIgL1IgMjEgMCBSIC9hIDIyIDAgUiAvZSAyMyAwIFIKL2wgMjQgMCBSIC9tIDI1IDAgUiAvbiAyNiAwIFIgL28gMjcgMCBSIC9vbmUgMjggMCBSIC9wIDI5IDAgUiAvcyAzMCAwIFIKL3NwYWNlIDMxIDAgUiAveCAzMiAwIFIgL3kgMzMgMCBSIC96ZXJvIDM0IDAgUiA+PgplbmRvYmoKMyAwIG9iago8PCAvRjEgMTUgMCBSID4+CmVuZG9iago0IDAgb2JqCjw8IC9BMSA8PCAvQ0EgMCAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+Ci9BMiA8PCAvQ0EgMSAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+ID4+CmVuZG9iago1IDAgb2JqCjw8ID4+CmVuZG9iago2IDAgb2JqCjw8ID4+CmVuZG9iago3IDAgb2JqC
jw8IC9JMSAxMiAwIFIgPj4KZW5kb2JqCjEyIDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDY3MCAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgNzEgL0xlbmd0aCAzNSAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCA2NzAgPj4Kc3RyZWFtCnic7P1pjyRJkiUI0sHMIqJm5kdE5FVHV89ODbDzK/vfDRr7rYHG7s4A0zWzdeURER7uZqoiwkxEbz+QiB0emTWD3V4s0AhJSw9zVzVTOZjpePToEf+H//Af6Jfjl+OX45fjl+OX45fjv5VD/v99Ar8cvxy/HL8cvxy/HL8c/zWPX1z7L8cvxy/HL8cvxy/Hf1PHL679l+OX45fjl+OX45fjv6mjvP7Lpx/+8PjlxwggQEyqerksrbVwd3cPIAIgLXWal3m5TMsda0EEs4gqgtzN+25jQwzAfHQbg0AEJiIm4VrLtNy//3D37n1pRYuICAsRgwAgmCDMTERECHJ3c7PR99u6PV5tNxiICML7tq7b9e7h7sM374mbxXEtTDbpPzA7CTETMwmTMCmziDABgHtEBLMwC+fHMTGTKrMQiAIAKEAAI4jBRExgAOGOCGKIiBZlYQJHhHkwkaiKEDMDHhEszCzEBCKPCID5+QPzBIkAJggLsfx4/VuPmhfyq9/8+rd/9dutr+t+2/bbbmvAiFwLF+VSpKiIMIPczMzNI8DMEoC5IwCAmUWUmYkYiHy4APJW8XFPSJi1aC1aipZSioqIMDOBIoKCmFlFi1bVqlqYGCBiUVER1fNBiqhI+dMfv//P//n/mVcB0E8/dXMgwMKl1I8f3v/NX//2V7/6+OH9fZvKeSL8s8UJAuH8PsIiPDzcMHqMgTHQd6wrbmtsa2w9xggLApGF79aH7cO7x4gYEcMxEAY4EJR3ISK8u3f3gTCVwqzh5j7gA/D/6//wd/f3lzyDz59//Kd//F9r1dZEi6hQBCLgDndiYhGppU7T9HB3d5kXJuFcPMyl1FrbNE1alAg2xtZ392DhoqWUtizL5e7ucnc3LxcQudu2bX3fS9FairuNMbZtW9fterteb7en23Xd12E23Bzh7uZh5u4R4QEQgQXCJMr3l4e//7v/8fkOm3cQVIpqq22qbW7TrLUKs1BwdBUutQkHYycCwJAGbrZ9Htfvx/rJty9MXTjyY0q7K9ODTu9I77adAvX+46+Wh/eqlUWeHyET0/NDfvOXP38cj5+JiQEA+Kf/7X/+8un7fDXgu91ERERbm6ZpIoKF7du6712lTnW5v7y7v3tYLpd5nkSJGKMP96hapmm+v7ub51lVRfjYD28+///g9P4/Pq7X6//0H/8ndz8u5O5H1+u+2xiezygiABBLLfWyXB7uHt4/fLjMl1amCN/76tEBZwkVRHiE19JqnX1Q3/zpy/r0Zd17d/NWp9ZmkSIirMEKrXkfHMA01XmeLstcSomIfdufrtenp+v1+jSGUcDNbZhHgFCn0paJqTDKcnl3d3noK8Z6XEWP+N+uK4lKmVSVRYLcyEABhrtbXq+wFlWVILcY277ufTXf3QdggLMHR1CAEUTEfDwHRBABDAiIQAE4GMTIdZR2Pa04MzGBOYiD2JmDCMwkIsoiYCJmUi6lztNlbvPff/xtEc0L+Yd/+Id//ud/HqOPYUR02Orzz9NyEqfnOE8Rx2IB8GKziImZ0yyKpBkWJhapKkVK1TrN9x+n+/dtfqjToqWl/RVRTpfDp4MK976vjz9cf/zXvj1RmIhoKZoGN3dxLa3q8zJ749ofv3z6w7/87+nGmbnV+s03Hy93i5n13sdw96CgNs0PH77RotPdnRZ1Z2bRUuMwLmb75
rYitn277etKDoYwCbHIfJke4t2331ze38/LVNO7a7r+QBgj0hERKBxjjN73fY+wuNlqtw07QASlp8efPn/+kfHNxw+Vi7xcC0fVfxTpLMxKwqRCRaiKVBUGATGGDTNhVa18HGAmLayFA+GABTwIwQhhUkZa6rAxwh0MLdKmqqoR5Ba9GzOXWlRFlCLM3VhYVIk5iIabRYiwsKRfV86oBgyIKlP5vP7u2bW///j+3//9/+XL9dOnxx8+PwHrZuHgUSvVJnPT1rQVFaLRfQzr3S2IWN3R+0jvLiKqhUWYxD3S9iOQO0ZENVedyFR5mnSedZ5qa7XWDFs4HBQkLFVra0urcy0Ts4aDWUpptU6tNtWqUlRrKe1/+V/+y7NrJ6LrDb1HRKjqspQ2vf/rv/73f//f/+3vfvvdcmm5AYjTfOfWfHZDTAARQGHWzfrYo+9+fbL16uvq16f49BOUPYbt3dyjG4Jot37tT2u/7v3W7Wa+mcPcIiLCEEbhiIgw67fRb2Pc3HrVplLH2G1sPlZg/Lu/+c2za396+vJf/tf/x7KUy12dmpQiEWGG3n0MEIlKWabp4f6+xHczvy9cmJSD09otypdWa6tM2BjU1wET0iY8t/L+/v6bb7/9+O237z5+BNG+749fPj89PrZW51bH6Ou6fnl8/Clsu/nYrk9ffvzp8cva12304dbNtn30bn0MM/NwYghDC5XCv/rmN3//d//j8+MIijROUkqbLvPdw+X+3TTPqqpkYrdaaJqaslFcKRwkIZeQZX+0jf6wjS8dvxfcVAaYSEqt39aZ6v091fb0RM6XX//2rz/85m9KW0QLBehV8Hb69DcWM83i8f/DSL6K7NJuBz59//u3rv2qoqW0SWudC5HD4H3scWs8N5mnZf7w4ZsPHz8+vLsrjYljW/ewmNv8cPfw3be/evf+XWtVi+QHvjjzw/z//8S7f//D9//x//Yfn1075s/ePm3YthgsQkQ2zAEmmcsyXT62b9rHX+vHd/eX5T58PF0xDBEhSqVQuJuNeaqXue4rXT8D7Ou2ot8sRpP70rQW0SJcIM3LBIjzOgj8cKfv3umHD/PUmlk8PYG/f9zjausPu6/wGLv1be9jWPj8MN+1O+FZaLm0aXr3McKfXfsI/L/WjXSq81y4KIszdbJgBMUeYx8dRKxStRZRQ9/j9th/erp93sfTsDVihw8ezsMoXCKYwEwsAgq4gwIakAAHWWAEO6ejJmYEEBASYWEIh4izGPMAGQsko2hRDSKoaCttuTzcfXhY7v+7D78udDjFf/mXf/lP/+k/reu671suvpdk7HTt6XBVziSUheRYQkAgQIchy3RRVaXWUoqyiLCUspR6KdMyXd49fKf3cX8nbWnvq1xanUubiioxCRMLq7Coxti7f/5p/fTHP/ywfv5j+CgqtbWqpRSd7x6Wuw/v7i9/0bXnaQKEM7lkERElcoDd4Q4VFSkqRVWPnFiIED52Mxt9c+sMCJEHKEARBHqOv5mefS2YX9JLPly7U0QwKJCBwhjDzdxGpF0+DH3m0whgmG3bVqaJZD63JQAPOIMlCEIBeJCQqKgyMxMLgRAAR4gIQEAEXJ3FyDJxc/cAQYS1aitamTVjPxwZH4Uf35i5mbMIC5hBzBEEIDwiHEzgDMbDg4OZiZWIVYXljAQDFK+N27Bx3Z6+XD//9OXHH7/84afrD92uQZsULwW1Umu6zLXVIswAhoUZIdjMt73bcABMIgcWIRERERkN41h+LMJFtRStRdqq89yWqV0uyzLPc2utttKKsgqJCpUiWkiUAPFwImVBAWeUEiDJuBr+el3d3733mSJsnuuHb97/7d/+5je//ebd+4sUAgVzLq3offQxSinzXEspqpo7C0QIjuAxeNvp9oRPP46fPu1Pj+Px0R6f6PEa15vftti7d3eHd99u/anbtY+b+Wq2mW8Re8QAjBK7CEd4+IgYPvaxry5DpUZYeI9wQrwxyedTd881mokAMTMI7
jFsIIiJp1I5MNdpatPUltbKNFUR2ffNfLRaItzdATCIgDC30fu+933v+xbAtq3Xp6fHL5/z6dgY27Y9Xp++PH758uXz0+1p27dhw9z92CZu5sNsDDd3dycKFoiTGfVur69jmt+zatFSS6vzpdRJtaiWWquyiHTFin612GC3MIsI0gadbf1s4+bwYCFSYVMlKeQe42YcG9cnMy5NmIwBJvAb5IWIGMedAzEfFhCgP3OAcq8zEZhyvb59HBFGABHbGDYG2M1NhNpUlQsx9rF9uX7mipBNG4lCSLVU0swKPRDERMxA4CW04Nwhf/a0/qsfuQfbVFhYVCKwAdGdgmPQ6LFt/em6tnoTEaIImHnvYxOJBgIccDPbtn19wuOXsa0WzsyqAiIBSLTUOoV0D9tvV6NBwa02ViKOPnbAw8ndRbSUqqVorZAgCIPdsPW+X7uITlOZJiZixKv0lIiInSpLDZkgjYuqYtLJyQ02sGJQEJjEQsLIIrrFiDCCMwVziBICEixHnCVgOWBFTstPwqwMFhDIiQH2dOpBwXIEkCzBElJRFGLWvbvUUrSwaG44IZ6mZZruwDriLz0ZypXwvBASOspHRszyDDQKsxQWEQpEBDmBSXJRubuDghNhcCZWUhGdeXqQ5V2ImHcPD6YQMdbwGMBzvsMsIujrdv3x8/d/+MO//OM/Xj/9PmIIs2oRJhG5e//t/Ydf/81f/fr9w/x86uXra8F5BbnScXxFwN3DoVJEMtnTTHaFKeDubn0f+wYzSVAigxckGMKgIxBWosJQAZMDDneKECZO9xoBRLjZMDOzYREgINwDAUTe6QjKv5jZum+z3pX26iLgBD/Tk9w/AGsEiYqwEINO3D0CRGTm7kOUWGi4DRv9QC+kSKEp40cwSaaVCDpDyXTtYR4SLIKM9kGM4EAi0QBTEOLcDRlsKhH0OVdg4jd7xcP2flv3p8f18+enT58e/7SNx4EryRB1LVEr312mZZ6mqSqLGbkhDGP4vnVLII1YRJnkxZ2TMHOAzvSdWitTa7VILdKjDm9OF+KL6l1rS6lLq00AZdZ08sxmPGwQhAU1GKQBIIQQRAh/7Uu41TmUgXF/P//m19/87nfffvfd+/uHRQsFgpy2bVwf18en6/W2znN79/5umqZWq6qycG7sbfP1Fk9P/uXz+NMf9+//dPvyeXt8GreN1p32HvvwYT7CPHr3dbfHYTePzWwzS7/eAaNcRXF4aUSHm48++so0hPVcGEbkX1kvICLYzb1yCWYmYc7Y3T3MEBYIFFaO8OUihHmaSpNShSjWrRPzMk/hER6BgCAiwt3NbIzR99F3j9i29XZ7+vL4GR5M5Ga99+t6e7w+PV2vt3Xdex9mZodfH916t95z03iEgZB704XGW+vVpnutU1EppdRprm1SLcIsQkoQDvI1xqfo17DNRzcbJMqljr7buDksRBBCJCqhQiNiM3fbUNpU6zSH5iLLfZBmBQldUvr0vJsv3v21KX1+8TmRf/41bw1VhINBRMP6GJ04PIyZW6tCykK7bY/rF6pmfJMaWmlul7neiXP3bR97H3tpymA6rN3z7yZ6fVr/dY+vQxQwI1HYUkoAbuED7hwG23279cfpqZQiSqpsGBY2RhcNFmEKAsYw79vTYzx+7ts6IpD4MxFFBBOLaBCZ2a2vI7apTa0VUYBi37feBcE2nEVKrbW1NA6hUbj23cKw3zqIGNPUgMia5iuQhSlYmAu4QEqwspLWSmxwEw0WJwBMDnaP4dGHJ9QPZuSKYiU2EqLg8/ZzQu4H2CPEwhACcShlkkQAGQ4onkVI2EVcKpUC9TEwgotoqWAhShiWpzZN0+wQ+4sx3Eum/tq7n8+Q08kDEGKWBNI5y71ETCzuw20neAQ8gtyPwnSFkqBMVCaL6GMzeAY3zuwBijgRIxCDGett/enTj3/8wx/+9V/++enT7+FZLBBCEOPhm9v7jR4e7l6f/RvXfqYlEe4UGDy2dSWCu41h7o50Z
BlA0VHWIM6k9AwBjtQwQGBiVSVmhgQkUzBOpwc/CxRBiHAnAhMDYWZuw80jPD8iAGSkyAFygII4CMTsgLl7+OsrAQXBQWCkictnkLCeQJQJImllQJQRhZs5BxFj2Bhuh9+jCNiwnRAhVbgwNEO3jDCYE4hhBIVkZEIQYmJhASQhmiCQ8JnxH4sFkbeDWPLWvFlZwiglakGr1JpMTZ00XMGRaHqA1t3N9r2HsoQDTuSZlCZGkoCJs5CKPh8i6uZ9jN7NhpfSLst0WaZ5qkW4KLdCQl4ETXmpOtUiECFVUSZ2DDff9y2CI2NuhXANJ5Fay9TH9vppfP70pXf32N0ffve7b4ryM6fBzXq33//++//9H/7pD3/406fPn+/ul2+/+3C5LNM0LcuyLHOtE3N9vNKXL/746F++9B9/XH/6vF5v+7aNMRKuCHMbZha7xW6xDV/d9/TohJ2icwyCE14tUXfEoBgEo1xHYBFmIcBzib3Z68wHfDXChTO4FSEVFmYC3KJjbGVfi061OhwcAd9HN4vr0xqBfZlEJNwBGmEELlqISVRI+ED0IszGvm02BuKIKSIgrCqlSGFWBI8e2257t33YPnKPHnXalzM/mS6vtjkRiEW1tjYtrTVm8rFF7wOrxpXtM40fya4cW9bicoF7hIWDjJWGy9Y5AFBwkaqlG9ug+XI3P3xT53spjVgoEfgX6P01Ak9fff+zW42XV5l/7tyzbufhw0bvmygn8MokqqqixNF9e9q840YyWFDLPJVlaXf3l7Wbdevf4Nv7u7tSisgLp5jBz4HIf/3j7RWPMUw6KFiltspSvAeMNnc4bMS27V++fCEK0JjnKgJi18KqIioUcCc3846nx/H02PfNbVjAieEx+pDSV1AMbD1Wiw52YrAAwBhj2zw8ACEIs9Ta5mnJnR4jehn7Pmq97X3fnvai27KMMXrvm/sL/EtAwARG7Ob7GE4DXJmUwGChNtUjb4O7GYaRmQbaWRMnFrAQC5FmtfysqTtRgIwkOMCSUFAWsVmCSZhKFteZcKLxIQgKDxgQuQUKS2EmLVZqEiwCQiR/Jm1/DcK/usSMRU8A+iiDM5MIq9a5zYuWKqIshVnGftvXR7ctvBNODgyTjb2vX0TB2NzuRDnCjsvhN2Esjj/Cx749/bRdP49t9TFOGkoggpgcRFJY9PUl/My1J2jrAQYTb9sKCiLKfybQ6foz7/AIgTs8CCAEPMKcwsMtwokoEy8KIWdACEyJhXraLCEKgieULKoA3N3Mww1I6BrIIEmIBJGuHZIkMUrH/3XglYlXUADMR1yQLpzBYCKcBf10+OQOj2CAGFmQJmE5NnhEDEOQhDBUCGdnQX4ywIgEE+gIII61oZKMu7PGISR5Ew+fGxn+gI/CzJurEEbRKEq10FR5aupUyKoTgVWEhOAjwsiGC4ODBGeNJlmDHBm/qEgtUmuptdZStEh4jKHrNvrud8vy7v7+4W65LBMTGCFEVUqT0rQ2LU2LBAtEiAMxhvWtb7cnD/IwomAGSw0jlebVx+iv98PTl6dtM0efJ7VuY9jttpaqzLTv/enx+l/+yz/+3//z//xP//z7Hz99urtfvvnuw7LM0zQ9PNw/PDzMy72W5elaPz/q0xVPV3982q9r33s3GweuHhE+hneP3WL16IEeGAEDBqEzDaaTQ3f69XBPp87kTEER4cQqTAzyl831YpI5g28zqEZlYSWRpDSSC4WFWex9rHtfxhg2RrL5zHv3x+stHH3sU2u1VGGO8MzXSq1tarVWURUCC3vEPnrfdniIqLAwycFx1Mok7tS775tv3fqwPszCQfHCkGTmZLh81QiT0BfAzKXWUkuM3frN+2f2p4Ib+xf2TxJXoc1s9G7uFjHAHKwECRYLdGNmFqXaSimNXB2q0327+1imi2g5S3xndf0v+PU/693z9J69OxJ0+uphMD2HQWPslWtROUkaIHGPGGG27euAoUeYsFad76b7h7vbtu/DTEtRlWW5l
LdUTj5PnvB/gvL3f+54XZ14Ptx88CCBsrCgFK5Na9XBEQ4337c94MEDavcxL1OVwwkeDjCCRvd+8/XWt7Xb8IhgJhUmivC+7zezbtiN1lDTivxRN98D27aN4QRJOqdqmacpGc42nGiv01Rr6/sYfVi3g9O539hnovrqUgw0ED2Cx9gxgl20iBQpImVqQfCI3vcYQeZsXogg4sIeRJBgZSmkmamyhHM4zAkOCippgDKbYrCSkBxMYMWRG6pAGSLBDIYHjjqnEBWVSUpRMakIxPCNtMZfaBN7FYzySeakgPPL5mJiYUpsXlVbbZd2uavTRUtjkr49aW19e7Kx2r752AEiho2dCMyD0Fk44kNylEWE5QhpzpSRMjYP2/frT/v1s9l+lFSPSlaSmUtps5TXz+Ir1x44YG8AyVYMR7hIssyyem7uo/dtXW8QLqUSkAFMWLj7GH3sO6ITOjOUWVhFChOZgwjuPsYYfaCwSCH2cLPREaGqzJwplbkhnIUi4BYeyYUgKAUR46RCMich8/WFPO/vk5OV/HAKRmSx5aXWEFksAAUQLCzKhYscuDWfFcIEBSPgFCcmT+SRSAa5IxAifNA4CZIQJyeTGoeHZaXMqAECh5MBhVlEzg95awtA4cNtQ+zK1hTM6qCAg4iCWISJ4QBBKHkQykTuAXKSrAezMCmxBLEFkQnpVNu75RJ35I6pzcu8LNM81UkSS2IqWma9VFokJpiaB4UTuRn2Pa5P69PnJw/M3cmDAZUawaVchEv4m1q7qrRWRds83Znhx+8/u1ttxcMevzx+//2P//Ivf/znf/z95y9P27Z//vL4hz99n7TSeZ4vyzLNd3W6czwY7j0Wi7YP7MMthnvnMIIBgTCP3WOPWIEBcpADhjDC4CzqwQPBCI6IcAoDBoULRRFiQQQ4FwV9veVBcM/yElRQlFRJlYsm5hJC0dnCYe5bH9d1K7VIUQ8XqT4wRg8DRXBQWURKYZF5mh7e3b//8OH9x4+Xh/tpnjP9ZhF333t3s6m2WictUqiqFmb2QB+2d9v30fexjzHcHE5EnLgmcZZbi/DBWnixWUwI6/sQxuUOoWb72J/G+hONx4qb4kvBJ8YT0zrc+6Ax3N2DmFgJStBwEmFmBZS4sc6lLUQXbfekF5JKzw79+UNfefG/lKy/Pslnv85pxvjnP3QU1OLFhDMQw4f3PcAsRbQwhEbsfd37DidlvUx3D3e3bdsBTG1SLSQ650Pm5AOJUsb2r8/pWAf/3xx4BeqehxCxuwX5tq2lmHswQVXIOcw6U9CQElpBPIiWqkRhIiRGBAojG+xGBC7K4QC8tlJLST5QHzffgti5uEikuYRj30c4rtfVzFV1mrm1iYiFi6bVoHAEGKxSaiHiooUIe1/jkSb6WF9cOwoFwmEdxOxDmAUsATa0qbR5ZpUArtdH9B5EjiAFGBqkSnR0Ewmzsxi5qXcCAgNmB/Oc00ILEQU/38uTectcyzyVWVzIKNZhli0eFCTMtbSHOk1Ce/C695tb6DQxX17TOF7B78fvFFHOLQ7ACLCDMIfgpBlJvhpuHbFoqVobEZeY2/JAIryXcIzeD8sS7tbHkGJep7vLu19Nd+/PaPjcGkdIkShejL6t18/b7YtbB+L0d+l4RUtr01L+Ldd+ltnPK4wDyxCRoyIGooiwPrqsK4i0VCYuIkWru4W7jdH3LaILeVGWkxYYDOaMHtzGGL1zFVJOosEYBvcoka7d3d0twvlAcAJhwtDCVJK2l6gicbYW8FvXDiYWIBgkYAYnhs9Bnjl2JMMPAAk/cwZJlEsRSYjgJTPPO5w1h8jAjUgI5Ig4MIvsNzqAeSdkERbCJAxJ5ECIleikCQRREClpepXXIOrxOBggD3frwBDyVrhALdhDzMODOISIPMsZSiQH559oJFwgwombMEFA5JGZ1lzK3XIRKUxatbUy1dqqNmFJVmbRMtepahMuCPIkqEX0bren8fS03p5WIlZRq8VrhXqECtUo/
tp4MdOyTPMktU2X5WK7f/+nTz/8+EPAzcePnz7967/8/scfPn/5cnMHEQ8bfezDhruXorXW1uY6Xcr8sUwfa/uo9V1QcRKHAUMwKIwREcOjR7p2MuLEoJIPbwzj8Gx44Lz7YdmiiTCmEAYYJCf74kjUXu35XIcgAinzkFBhFYiIChcVVELIQATQh63bXoqocoTV0ig4HOFkES6MWqWUWuuyLPcPD/fvHu4e7ufLpdQSALNEoI+x7buPwcxSChMFhYcP9z7G3nvvfe+jdxvDRlggRIiJJFO2AyvM9sZXu+MgB5j13cYQ4d733lfrG/pqvlXagnahHbybxzAaFmM4IAl5EdLkMYuyVpYiKhXJGi7EDaSvGeb/Rnb+8+OFyfbqZ1//y9v3BvLZHF/sbjbGNsYwL20qrXGIR1zXdVtXHy7E67Tve49wUZnarFqI2S6XExMVFSlcihRRSbYUTqD0ZV18dUL42Qk+Uwpenz2+TtzzKRE4PHofZh5GAVYVLrkczZz7oH3n1sRaYTDBOIIDDKFQoDBJUbQGwImpFC5N3CPCxui9D61cNbN9FnBYbPveu12vm0dM06TFPcAH2J1IJ3l4EFi51KpcS6nMPMbeR+d6qeXh5QaMQSHgIVpUJKmZ2fs1lXlui9ZKTDSsl5sNHcJgsIA1gwdViCqJB7yT7UxBYkwHY49JlEQp29dAcTz7884zs7Y6LfO9QjAwcMMYJByUhmqa5nfzcgdcdwvbvnRfi1otP8dSwEe5h5iYRUULazkyXiTZ+dkBCYsQU7jZvo7S6nyhA4OFlFYAIvYxrO8RRgQttbSpTktdHpaH7+4//Ga+vC+tvWBD/PzhhODwYfttu37e10f3ASQ3hPMdIlJKrdOi/4Zrf16Ah2lOKjkLZ83s6BAUgMLc+xis0BBm0sI1OURZTzG4B0WQBCEkmD09GSTgYcN679KUlEEHoygAmBOF2fC0gggmQiAsCGCRUlWIqbvtBk73CH6mXDwfSctLfiMxiONgBCAO601x9DYTazIcpUotRbTyQec7Vg4ffzInhB7hiMwP+HhfovHMTOzwkzwvmu6FKL/kCJaQOAQcwgxobUJUIoC3oCNAgYxtMoonZQaXEjHM4eQ2tr7ZMIeLkDUFN21LUQnk/UGWgThhFyYmLrXO81xLY5JCWrRNdZnb0upUtbEUYU2i7FRbrVWEARsWDifW4fbl6fb0uJp5m9rU2jy11pRF3LOCjreunX/1q/dMtZQmytu6Pz592frN4cz06acf//mf//Xxy7V3F6m1tnwm2XQwhq3brrpqeazzU5t/mi+P0/IN64WkBR98SY7sahvh3dGBAcp0xuEW7hTG4YI4sZogOGARBgyEnc/0KDg5IMr8Nm0HYObCgGXlJR+lqoDl6B8homySTse8blIKMweaK5dsJKBA2AgbwvNlme/vL3eXZZqaFMli0L5v1+vTSZhbPYyUIHDEvvefnr58fvzp6fa09a17d7fhY5i5ByhIIc9LndkZTPB4A6KIFtGSG3X0DbDeuxk8SmD2gLkNMmFlmSPM4UEdtAU0MIErcRV2FtNKrXGrQ8ptZENJfOdueEuO+7ePJOWc3//s1ZNQB6Kv090055IVUHdjwD3MklvoLkUROsz6sG3t++aIECbhrro+Pn0uqoXVbKzr9eHh3TIvtVY5qthlbtO8ZNX5ODVhPhqN+etA/GcxyZ978XjDV+9hhlSdsnQ63K1bGKtMZSpChYjBLketuChXlbwfkUxPJi5F2jIpqkBFoJ1IyCMLQsO9g5ylHKobRIgY1tfb2Nax74Ol1DJRSDgzUQRHkIiEk1sEBVfRqKxa6sRSh1sffebx7EDgcfty0xYLze0yL5e5zlOdp2wlZ1UtwlKIqdVlmi49tp1uA2QUzqAiRaYqrUHYva+Pbp0gRKpSURXKUkSqEEeEZQkOThQv7GCVUmpr86SsZBCKcJN9JQsutU6X+4dvlruHvRfrXWkS66CwsLfr7jm7pZMqR
6JappmIEY4wnI0VWXQlYYL5MNs9+2a1zlwKS2FRZiltnu7esbDtK9wu77+5fPhuuf9w9/7bD7/5u/uPv27z5cxOX9NNiVkB99HHdu23L2O7nWbqZb0xp7VumfQ/Hz9jyJ+IRIIRaafTdUUQcHxoeLgNZXENiFDgyJFOf8ycVMQ0lg5G5ktHQdMOAjyZnF1smfOGw4eNCOcMlMEUB589SxtShC2Ogn1K3LzdK3zwzwU4GG0RIAYLMZNwqBwhl3teaYioiBCLZJqFZy8e51bko7mW4BSng+DsoYmgA0uPYEQECSGYgw+VGwc8KwEUYeEW7odrZ5aITPaP/puvDkkcX6WGgAQUAHEQeRwB1tiDgiWCmetoxCJV6lFvEAJLlvkTqZVSS21TKU1YmVW4FKlVp6ZzrU2kCJdUoVEtzAqEOfXhvTsz3bb90+fH29PaWltU6lRqUymUz484mRpvmCm1EjMVJTO73p6ut8fH6xM42lS/fLk+PT1dbzczFPV8flkS8oC7IWtoQmXc6niKMMBKeyf1EsnEI+cIhEeM8AF0oIMMhHAPs3DPfP18bMjyPMIiBsJAdiC6iFyuEUGs/DYPS4cNAsifUa1wqIoIn7+bmeVQaThD2KqsRFTArMRCFAgJN2aa5+myLPMyt6llNQoEN+/ZB7fvW9/dDYwBM7N13z8/PT6uT2vfug2PbBs8vpKroUJFWQsywREO5jePIxEqADbGeruqivkID6YKvkBKRPFQpotggDqoO548IqiCFnADK/EOCdbQEqXsIk9KJqiIHvC3Ge6fOc7k4XjX6wa5zEdwvi2daAKJbx8HcHbMUxoVh4D9jNKYiRFwG93Wta9r730wU1Eu6sP6dbuyEAFj7Ot6/fD+4/t375f5UrSmtMiyLA+IUkpCv8xUtGgpKiqafcwvedbxvyOmOflP58k+J/DhcH/zOOCAkLKyqIcdRCaQahTlppVJzEdVraXUozWNgaO3GE5AiFAVpSpRI6KCIqnIyfYIWNaus1mTghw+elwfb+utm6HWySf3HmM3YuojTtpPlrJVaxGwoNS2lNKsR7jj1TYH0Ne9uEx1yEJV69yWRIlFNSgTA3aQSK11brFMPJuv5CtIhWSul0u5TKxs/eZ9WKcUFJEKChaGJhGqh/fUDUMk4EHIpF9bba1Ok7BA3bpSYSiTitZa5st8eVju3pt0xqKxiPSINxAjne4PJ1U8C3DMSRqTkQTgw++y6NFPDXj2ubgNH7vWSWrTOmtbap20ttpmYbbSKPzh27/6+Nt/d3n/7eXdt3fvv13uP2htX5WqOIFDlTBy62Nb+3qzvsUZpuOMNUWk1FanWeu/kbXj5UgxsogYY7g7gdwc4BDEyTWPcJH0gubmHuE2QK4qQiXYM2v2pO04BViowA8mnplhPLNkKFXREMHHDT2K6SQAkYhEhEeQI9zcsjskvevX0f5hXlMkbnhQ4Gj/FxVW9rw6d4RnP1scJRkPy6AWHM4RDH5G5M8jU6KgUxtJiCQ71h0U5AqWAnEKyreSW7hFMDmHe7h58uLBHARzl5HUaHm9zJhJhVotyzxHzMDWR4wxtq3v++jD3KOoqs4BCxqG3Tw8CivXUoXl6D9MZJ5ZWauUUqoePOvKKAhxo8FeBKUchQnAzcJ9ZB45Rr9db3vfgfjy+fH3f/zjGPbNtx/vyx1XCvXuO5FGFKFKZP6q1g7EH/74T+Fc60Tg0W3r+75v2mRe6tT0skxjdPfdve8dyLTWxwE8pP0ijBGBUJ6Yy0RUKMAKOrpZgIymR369BFBmGZly7lXgkO/Kio8byEB+yEwcuE54BIEl3j6OXKJnUEjkAMxCNQny8szpSelDFZVsDDUzESE+y/nZ/x8ipCUjrdpabbVIKQBKKbXW7Er28H3sPTpvNNx6H7dt7dadguUl7Mx2DWZqVeZZp1lqEz2JYfPyM/jBfd93G33sm4iAQkVqraXNTBy2WV+AzmKIHbgNKsOJuIjesagwMwXRHhgRn
qRWppKlgAMozP/826Vp0J8pQz3/DI6zPSzTz37Www8zyBTOBEl6mDC3UiuBiL2Pfl3X63Zbex9eqtBU0QrgfWzXm1OY2bb3677f3PaHuw/zdCHmPsa0TdvYiygFlFlF5nlZ5qXUqqF5WlkgfuYsCp9t+kcSlMFHnrt4+L73bX/dP0JmNmKgFg0lVeWGkh1rJErLpRWpexct2lpttbZWilJ4Ar0y3MzcYcoWEYDnQoUnHdiBYDmpQmGIEsTmvm9ju63bbSCYnPbahIsbSNnCVGSiJqLLvKSImTNL1DYt0zyD2cMzon+9QbLevG8rFw3R0DKLtqIZ1Qy3bjHCWMs8X7g5mWO47exBD/OH99O7WYBxa74OgsyFnSN6wEKG+b73m7tbjAh/VWEXpkzYl9pmrQ0Is7H5vo7V4JKuvTVtlUtxLkYFugg52c5c6FUIeji3DByImYPICSN8oxAkIkUQlqMNoyRVgkSFeQIQPoBIKBERFAYfUkqpc22L1unDb/7dx9/+d5f338x3D3WatVTWc+PQC4GTmUXYmMKHjd3GHn4CDOdmYBWtrc3LfPdQ2/x6u/08a88cGkTJ/iMEHI6AeTAxkH0GECbmYI4s5HqaMOuIEGGQZDCNQ1zmSPoRLOEp5GJmqR7EzHKm3tkDxs8nQ2CchgKAH03AMSzMj/Z2968i+hR/zRK7e1h4ACxcVEM4hFUhImHwgEioHj9+suEkt20qk5y0whTfOSLwOEi7R9ejO4aFBIE4mIuc7YDgBLV8uAiY2c3DIwUC0y57RIoafhWiCJEyVdWmdRPlYB/R17Fvfd+7OYi5TkULg3Q4YmxHY0JE5vpElF0CQiwstehUWi2FcXCFKMg9OgbhEEIEPDtqDiFWwBMQWre99/B4ul4fr08BvOd3VIESIR7kTEiCS6r3vl5Pf/jjv47h07S0Mqs0j9j7rkFafNgmilq5FDxrGVuqwmTyxzgzNKfBY2zar1IqMYMLkRCTJDIZB2kuYOEjzMINlrKyZ+IXzw0eHmGJ6CV/8VwvlIiQhPwZb4IzwQTMUuKDWUiZVFiZVImEVDmbLcMxzIVZaDAxF2aWQAjEwyz1gQ99wDh/+UF2ICZEmFu37iMCnhjN1scwi3AcqcPRMSRCIqTKpcjUtE1SKrNQAK3R18fxWLsfzwqlpFpjyfb2kAhU5gBVB1mMET3r6IVZxZVCOQhhqZ4sVcsiuGMuR6PIc07+F47XzKWfv3Bg8Hhdof8zWTsIzBzBTAaSg1qgoszJTbPs8ty2vu378PAijJhrQGNk//hw72493BmwbneXhyBe+15qva7XqkWIq5ZWSh/d3KZpLrUmzFaymzQ7SpUCkqKHLIfCKHOKWwCgvvfruj7drq+XlpsbG5EQVKUUVSZ2Ngowo03aSmWl5I+xhEioiKAkS9wJYcbkwU6ACERJQs6UCcw4JbeR+iBJvO/bGH2EGUGCfV97BG9rl6pSuU21llq06HLnjp72BqXWqehU1IpWeUteJmUw3G3vm9/YmE3YERaVJJxsJGHHQUxTm6sQencZIyjA7y8fv13ea+wh1C7vo16a3AG899tmT7s/+t6jm2MEJ0H4oEenSDkXLa3VaZnmOyKIlNuXRzvarIqUlNGuXApYg6vopTAiuVVvVumZtOfBSJWLsB3E4YOOxuKD5aWS6B5zmet0T8RhHUgN7qq1SamltLbctfmuTEtb7j/8+q/ffffb+e5dnRd5LiefXu95vR8xIwHu4SPc4kUY5dgPoqXNy3x5uDx8aFMhuj2/+jbsOvvUibPViVVFVQAYItyIuKKKUC1SK4uyFOIsRQUCPkYnRBERPoqOiFMKON1hOLvFSaYHR+qVx3F3A4jUNM+wIBkUQpkmJ5gabj4sG3IOtY/yusWSSZiICelrIHzUt8lTEAmnQA4OOhsCSSc4dNeJUm1emCyZAJHiYYn9pJE5MLZUG80vCUL2nSG76igcNmDdw0KVRJLYEtnncPAvwo0ynnijzMVM6
Wwp4N32tW+3fV33vQ+zIGYtUkVbFVYZDkcHhZgeRG3JdJcYLMJVdJa2tEmIaVj2j4axjx7iKIGwMbZStJQDBD2mCRwRmtHR5hhahRlcCRou5izMUpSrlqlMTedpenEmAH749P2+j2W+f7j/+PHDPRN8fbo+ffnpcY0YDgfFNMGDI9DNEeMgPJ9kET7ITYUI4WOMlZhIq7AeDIJDg8Yifban2FwQ/GiKABFTPslwT+TzSGjOUQLHBnjGaAJfOaDwQyXqSCSJyCFMzlSVIJQhgQdrsCPcxV2yNodgNFKRQ3WA13q7zk+PbVnaNKtWkdpmsGhWazzG8GxqS+A9G/eftZwOPTs6XQipSEp5ABZQQIgkff9XTZUiOLinkiUAeJjZvu9uJkKEEejZ+QLzGB4uoAUgDqj0ynvVTdmIZPjk/Ctuvyv124ZvSKekW57Q9Iudoj9znMD886tvip4v//AzLTo6H8ShBEAgASuy+n5oVYcnsz7vQESMPgaLj9FqVYQHkwDCXFiLKAf3bb9ermDpZqL6pU1Ta0trU5vm1rqNbe/zcpmmOXUJay3TNE1Ta60xJDCu1+u276Voa21eZi0yrA/rZrbe9sen66fHn95EjWnWDMFUOeMENuMRHXCQSWmz1gAsxt5pmBSeilSCeDiHUQgRQUJUaqtgcnJ1VuVQRpwMHmEiMjMEfISZCXHRkmla3/d9HwGqS7t7v9RWiFlrK1rMsK59uFNIAuCZ3L1hDjJFZWYYG3wfW3TE5uO6ttaKVmIlFiFRBBjUWpF6cVpHbDExsX7z8M3H5WF/+sl4LA+/mrRd5vfh/uPn3//42G/Xvo3rbjdHl3KQnrNs6o5wg0Jqnea7h/uPU2sIi3W//vCpo0eAiTRjPi2qtWgTXQAmTNOzXuy5zY/MNquYhy30sD2Cwg04CqnZHA04i5RWlnffvP/V35Xpkmo/OOhXyiqqZX54f3n42JZLnS/z3cO83JfaVI6K/bmg87/Hsk9ehzDzgUs/bwLOyJeJtLT58rA8fLh7902bjfAXXPvRTnhsRTo0SlSTgB3hREwUIqRKKiyp2KtMGQsQ3J0RJJSWIyycj/17fnkKhUQ4u5Ajkn962F/PBjPAcQjOgY4JA4EAPNz8eRbKnzMWefZOCOJghuTJ/CwMo6O9+WBWpUhFGvSDLE2nFB8d5+95Jsd2PJK9AAfIg8whAAlI4xBqIAqnYXnBZzk+l04iFUJgCgoG42fMI84WdSIERvdt7eu6b1s/SnHlFAUU1aKl5G/A0i6VKxkiUpJBRKS0Ost0qcvddEGE9SFB7G4j9t2LOOIQsCzKoixKp6Y7CStATuGwbmPEzpVEBQXGY7cNEiLSSESz4UhZX8J5AI+3x3Xdd3Mu0wNR0aql+OqPT589uiYNWVWFyCHIQkjAA0wHl425lFJKU1WisLEGXEotmljWAa5ELouDLI2j4Z5z3MZZQzr0GI4q/stmet7VCaUi6XQvjyNA5tkzmxBS0iczVASDWYXgESzgCPZgl3CXI8EGMXFRJcoGj17X9Xq7zU9P07SUWlWru5da+772vtrYzPZ05MPN3M19mPU+xrAww6HiRZISQKxMIKYAeWJsnph9eLwp7trYwcKEZzXfcMuo2cKSLSsSkpZNlalSVPhgMqZBtBM2wk7sQHPch35L9TcqHwseLHiMHic1iflrn/wzLAT0qtZ+VueePfpfzvqZSCRVn44EPoKZBV+3KYtIrVqrDGcPz7s4ugCR0yU81Kxv2wrH6H3dVmYdALNoKfM0LfO8TPPUplqnWqbL5X5Z7rKm09q0LPNd3AWhFHOzH3768fHpsWiZ5unh4V6Lrv227eu+7+u23a7b45fH15RsZhZOFQop2loroGIqFJ7IqEiw1CRfAhbeqZaiMwUjmJE9MkHehSuLkkR+JZgcSClvcKR1x6HlknE/kYrkonf3Pjw4ysSt1ZhDtc7LxUbUcjPZPSc0edrpN3EYmFA5gAELJyGKwS40Yt+Ga
OFStdZSSj04+NDGddFLNKu6iJZ388OlXkjWUuJ+vtwt90u7bGO97j/Jjcz7sNWiQ7yoCstZPEPy3Mrcpsvlcrm/u7ybWg3rTZucQ2IkG0OIGFAprc7EzDwRxnQOCHu1IJ+X2EGfoEDAIoBDoPKlMC6ltnluy/27X/3Nt3/9309370GCwNEADEqK/Xz3bn5436ZFW8vwLSODl143frZFB1kjGwEllUTp7IZ7uwe01LZc5sv9fLmvZaX95bWvAHkcWd7BDEwGhWR+ExEvBSWmo7ErkV9RCQxVyllQTKpSVD3BBOBknh2a2YmLSgSCw/3sXfE4CGdOSY47I/NsmfORWtHmFoFMW0uptdb2VedupLgYjl4lVRYtZ8BzdL0f5ZSDwncEthEUCI7jM1PcXoghEjjmwR07AsQgJaIDBCRmDpB5sIGUgiDC4RTB8Wzb6byDwqxp50CcnXQg2NvHQXzKn/Xu22bbZnvPaTyHtlySyQuk1nl5f6m1ztOMwHpdu3U3Y9ba2izLpV3u5/uHy124d9kyqbXe9233OpMICbJDngPiJILUYyGRABvZHuO6rWtfQ1yKOo1tbLSh2iakrfhogmURNntLFOo+trEHlbZut22/u5T5cj9iva4/9bGZm6qWygR2gASlMjEZzhI5ETO31qZ5ybl23bYYa6kTWhNuIpo6S6dU4pmrCR96kKlefDDkjgIkXtbj8+GRg1X0mJL31nplfIbDG53SiwQSRhGFHoTNeFWq8gBOpocqUjs9KAJj733f93Vdb09PtVYVcVtKK7fb0+36ZdtuZrv7MLc+xj7yyLDWk5marS8qHMdQKaQrdY8xGEjFsdiXN+vq6csnnFp0dZpUNcJhw/sePhBOQiKqRcrBbG59xcpPYRvRiBg7hnmSM6YmDyTvqDyI3kk0W3ffNrNxdN9yet/DUL7266++//Me/OUNzG+ods+7Q5ggB7s1K2ZELCJ+kGkSgStF53lyBhfufQdg1vcdzMS1MDdhDoTZzkC4974Rl2AmFlFdt3q7tTxqmWqd77ftctkSL2ytXS6XEeZkqrLv2+//9C/f//ADk9TWHh4epPDT9nhbr9u+9T48Ymzj9X2opUCqSq2tLnOblsaS6IkDoYIj0iJ6boFkplIKgs0jJf+G9d16kVpKHTEsdscIjDFyuFbSlUhVa6nKAmYieIwcP6hFS20FxDrAsa2rFp3aLPeypGtvj7frtu1bELGIm4c56qsnwowmZuHWK9FSW6kqTUA0bDhOye1D45O9MHSaZC6zPgiJ6qUsGtx0lml6f//xMl2AiO1mNvrYu22GAYWwaClMDCcHjHBMyLt/9/7Dh4d3D5dlgtu2rn3dbB/wYCrJPzcbOoayLO1SyixiEaMSXvfC4NhY6QX5aGjOlCxew7UAglXb5d39x+8+fPu7D7/+2w+//fft8l6kHF1C+e6M3WrV2iT5N2eHxUtr1+nR6fxv7uXDwwq9dHwdlarj3apSW62t1lpFXoTC6GdqdPTMF0vfl6btKAUimIUPshCLMCsfw2JEESQiuUfleSRoBuHnVZ77ODuEXaAcDM5U/cilApYxQHKTmJ6pTTkKw3M+YI4I1FLaNM3LrLXGy1XA4BTmR0EytSeeHx5O6j9ROBEIjggWIQiCQfyCuadELWf+jwMvOdoE8EzuOdoD+QgXPCCeOw4efHS9g+j07ykXx2f/wamwFIC+tV8p4MuARLA5LMhBepDmRbKEVMs8tWWZlnme2qSqYx+xI3YCRLk0WZb6cD+/e7i7f7jcuY+Nb33f9r6NEdveJ9Jam0MoPLNRUUrd+QwsPWJ3u41+3W+3vjs5mHuMzXbfvKgK6VQYMav0ov6Va3f4CCPra1+v65OW0qqWOpe6yOjDd05aA2DuyTw5/WvebRaRWus0tQh0s7Bu4SJ0NKiThnt41l4PeBcH7SFSKtgPz5/sj8TEcTI4MmI7+jRSqTTLPW9cEVGAmFMdGKDj03PTeLp5AggCEWHNwX5BO
WJgDLcaRbRUYVYQZftdH+O23qQoCFtfterT7fHHTz88Pn7Z9y2duY388izNP5eN+Azn9RxVIWmODkIoEQdRfEXJ3reVmKb5QqU8o2mAIzpsjxhaWKgUksIkFEyutFVeXXbQiBjuw+HGvOSFkpGvzJXDY+zgFj5y0+M0Kpman9//7Dj/8QVb+xmA/xXdHkR2yOqdCCoxBYm7nxO6IoKIVLXWMlEDgznMDeTmIcIeAFoiNDlRFeFmnTg7qJVN3It534bqpiJVpT3dbpflLsdq1FaXeV73p+t231oZ1n/46fvvf/zePIR1uVxY+dav277uY7ecteOva7lUSqFShbUUKcq1SK2KJkUpwmsrzGByUZ1aa1VbKaUULcUt0pYavPuw4Spe4cNH9314P8QQe3cQEWVpvCYFA36qMkYIlKGFhYWU+uhj75usa9v2+5Hayzm9Ot0zMRBhwzDF68eD7MJ36OF1uLUaTBaiIvUYKEbhZh4qrCzTXKdplqpStIHEMZVLkbJMH1pt+7ha2Nb3rW/dzIEzKdJ0CjmyupRlbg/L8m6eF2Ha92tft6fPP23rzc3pVAYlwhidZAeo6rRMpRS4jwJ/vdj48LyiWkSLiCbuDBtB/rxCs528XR4evv3dN7/524+//pt33/z27sOvynTJUZ+55pMIzq+S7lefdBbXXxb+i4fHCZwLgymyVH2Wr172iD6PgajlK7zqZ1n7ecTXx9EkICyH5z4GdWuORANBWPOaU6QQIBxjMA5ojk8rklSmiBCSRKpzErf72fbGRMLBfGQ8JJlgBwLneGpiqrXOy3y5vwO3bbxcxvAgwA/uXgI2KXaU5Is0BEf3yMHAOsAPFcJRZTjcKkBn51/yUZgPeOuQoM9OGxyo8DP0gQPJySeSlDoQ5JCXJ6JITNiRaf3p/l+eRn6QEldIAVeSQqJUmApzISnc5np3uTzc393fXS7LUrWE4YatlWGFya1oa/VumT8+3H/zcPdwd7fY6IjJ7TEcY1x7d60RzA4KP1CEehAJU144xrDb3q9bf+p9MxsRJbjbkC4OV1HlgphUrHWfig97K8vMQexBY4zt6fpZRO4ud0Bp7d48aBRCEKt70v5HkOuRyPPJfJBatZTSR0c4wVP9TSin7+VywlEvPzhzJxSUw9rTJR7e+pjOx8dEscPjMyXlQiERwemqX+9EkYT0gKw4njQXeh5iCBBRKSSiVbWVkl2vFOQWZo5Wp9ZE1Zy0CAtbxLrvIdS9l1tl5cenLz/88MOnT59ut633bsNO+JTlmCvEqQ9xtK8/Lz9GBvh0JB/Z85yDGF/v8oiAjZ0YHsZEbjusk3WKDuwCltKEQ3y4bXtfYSvFTTmYaZB5DHMQ6dRcaZfxPa0ryjvgDl2g72EDETnj8FzUoP9DwvzPDj6Auz+TsgfQx3gVC5CypEAvH7Pvzl8hLKoFpSKAYCMiz05NgMy7mShzyMGQIXKicpRnWIMkyMN9s3CDO7X6eWpzorUqWmv94fN8f7/cP1xE5Mvtp+t+3dbdLORzJSGHOSwQZmPvOzne4e45UilFpRbK4aVkmbDUOl/u5kB4iny4laoP98vcmrBWnUTZLEb07nu3HI0UFjFiZIf16DZS9dM8gkik8elqiN1h5hGRaiPBEeSi3GoB+b75vm5P+jTNl2m5eFiOF9cipUgtulN4himvn5WqJKMM6hHMMk8TFwVzrbXViuwGtbH3DQFyLlLnNhepwsJujGh1aqWJLk7k8O791m/Xfd9HeCgrk1KktgqIudY6LdO7u+XjMt0z6fXp8fO29XVdn67r7UqHW2cwnGIfu6GKaJvaVC9TK+FD4cz26jKYRZSKMpd2KbWBKNzGdgW2RJ+zAlDmu7sPv/7V3/wP3/31v3/38TfT5V5Ky4wuU/Kjg+1VmelY0kfgyi+u/Tkff0bmkWrhJAQKCx+JNR65fv4WYa1tmi91mlVV4k3s+3XW/rYSxkk6YLYDGmXKYZBaixaVonT0FuPosyJK2
TtlimEHShbPBYW82qO7IBfV86TbbE0Nt+y9PIrwODIgeqYxHhlWsIoULa3WefKo9Mq1p6BAnK0n2aAiLwI8mQwyM6mIIKOQjD4yZGFVTqWboEBAhI5qfXBK9GUzFR+VwuCjX+AMq44X+VCoyVrWSWU4arUJimRqShzPEM6bJ3Lk62NEH9bHGG4QYhULL0SsrTRuk05zm+dZpXQa2cSO4ytDI2EpLJWlinKpRrI7xEHB2UGhxEerfCAMokHEBjKH92HrbVyv4+na9+EgqsxtuBQDQjVbg7LXlAJvFxGRKqcgK4uPse57K6rEkFJLnSzM0zAlLT4iwogo9R1ZCOEAeQzz7mbhxkSa01SCIhvRwgngzAsyuDpiwdQwGe5+5JB8wklMTIe2AQ43eFgCQFiIv4pPmIvq8cwPTTZ+hQeT+9FOTafiby0l3+3h3Vz30VplllJqSiUPG+t+sxh7rJtVLUqCxy9Pnz7/dL3exnAECYsKiugxUZqZOVIzhyOBw3OxZN05+BwinPvuLd2JSIsyHLFZ3ymMSMbYwrrABJ1jS72lEgS1GKtvjxRdxROhlOgSQ4JAQb6TXbFvwZ9RPgEX2u+oAOMR4wpaSEoaOToHrvCfSb/pK0D+zKIORP08/TerCkA3O+uUx95VokCY4eX3HKTHZ2Q1KyUQgTBEU9fcPUyDnUAcwsyCUwpLkkvjiGG2dxu7qaxHB+nJEllHvfXpul9K0S9PT7fttm29d49cUOUQbonAGAMWVF6GdAkriQaCGCJcirZSp2nSohF+W2/etwgHvBZptTCpskTE8GMScIoS2p4cYQrygFkf/RB1DEA5y36kERQefR99H36SSQLwMGGZ5kmUwoYZ+r6vt9u230rVtpRpLrWySAAjYriNeCWFxMy1NlGSKupSoCpFWWudtNRSa1Ed1od7EFmEZiyOUnguVI8HSKHSRFtAwsfu22a3zfrw1FRoxIeUgSiriHATnptOSuJ9bH5jDwxjktamUqskaejgTIGUpWQTSSs8ValBVcOZHt8G8el0tLZWpwuLuFsKiZxIHpc63b3/9sN3f/XNb//2w3d/tdw9SK1pC1JGgg+Lf6Br5zJ/3gNvs/azvE6vhpvIMR81SY8j3LLkROeiV5E2zcv9+3m51FLE/rJrf4XGEzOXUqZpam3uY9BBFpRSSoL7WiuxgMjM+OzeBkiKlqkp4B4CooM1caYSfO7Ds36RJMA4yCAeZqzCcpC4MjUmJcntkdS956RLhItyrWRv8hIWQcjzmzTLB3LUX4/fwsQFVAhDYhxzaElYVEVYlDgLq2Ccmtzuxy06p1AfMaoc3YCHejyTHsURvOjevGp+xRmenDkJCYgAwVv1naz6uEXfx7pu19vttt+GrzU4eTKiI0JBLbAHGngGiYWP8BFhYSMGGLvvu63ruNVepBIT5y0cEcEstdSp1qmah9vYRx/DpBMrcvSihw+LfY919afHMcy5yEIyZpRyqHazAJHwWG65t4tMpVWdWm1VRMx92/sxaCvnnoSFjR5wFmIcEu/uA4f8QADMm5hHagimPIGyItgjoZeMrviAG3IIMBILMvMRR6cG68k7pYSBnZ5JdhkI87kH+VWvKRGJcNX6HJKltQjGM0jjQA5NSN3ZoqLCyFGjDvMAUGt1x0RSlYG4rdd9bKVqszpFkcIgPN3W63odw5hysgayn0JMWDzVxuEQYgV7JHRwDLQjUByDaVmPwWtEb2X1pmXy2PvtEWYsTNIC7tkHCC/oGH148KQ8VwIUQRRMoDCCSYwGUyaQit9sD1cKJgwmNNq+4crY/hTbB6L3XC4gzUkmeUPpuR//jABfo6HPnp/P0S+5zb+OFrPuZkZMwsfAB5VjooyflZTsAgAIWRPp1scI9OWirTEzKZMcMhVmTozkmIgwZcqURgMcDjeYxRgY5mboBaWUIiTBwjZi691XZlrXvu02LGX8wKxNJi21FhWWvnfnN+KAIAGJRxCTljK3udWlaRNRIxPeK
YdXmYcPoKooAe5j7Pu2rdu6b2vf12GbhwEBKcxKo+c83qPRnblQFArxQeGpzdeTTEXMgRjWy1SmuS7L1Eq5Xbd1HW49Rtd5blOztd4+M5DjhLdDSv35yTHP7QJiparB6lJKQ0BI5zqzaIAsyDwMBGGutbSplrnIJKSEBJ8DXD10RAS223hc7WnAQwrrwmDEICFRbaXU0pSLQGNQvz1ZbCtND8vDu8u7u8siROjx9OOXnbcAnAChaZkvyz27KFWBwiVbR1/jSUdHu2jyzEotpTYg2/JHjtRjljYvH7793Te//dsP3/z6cncvKnxklC+h67NfTzbwQZo7k9w3Of3x/Tn35XDtSbSL8OHW44RJzjKcqNZ5uX/48O3l/l1rVeMvA/I42fUHgU5EtZRSmF/yBdHUqVQtSiw4NT5gOKdEsxaVyKltL/eMX2XFZz+BszNUkEqozyVrel355GQYoiDHkTodtXJiliJSVIpwCL0ezv5c5MBhTo5ah5CIisqhkw5iwXA4AD8/7ThNhqbxPuqF5xeyGv/K1uTt5mNeKx3qsnAiOWgKSc08S+5IO8zBqi9j1OkM714bL7cY3fd9jG4+3B0BuMMsiFgY67oWFR9j3/q6DZW672Pb9s02407Vg6ODnnbQo/W4rXYnrOHx5frlcX3crRNzEDvIHN2iD+/diAH2Ybt5D4QHuUsfOULUtbKomIUHKQjgCErBXaJTVPmVuW6tDCulMLObrUykqqXU08JHhAUGIXvDhA8ZMYcGM7t7loPGsMyAWpLqj4F7gWP0HOE4g2dJpEOaJiGi55LbM8clkaJUmUVEMCiSQ4zskXvDZCYueV2ZqZ8jCJ3oufitz3tHJZ2TZ+PBsIgYxqrl7rqrlHmeiNndjRzMEqNQOeBS2VmM1ZkhyqlRx0osRhYkTE75mSJ4vtOn/ELy/PiAjo6Kw5t11e3JbTO7McCwIgXqBoN1wo5TKZqHs3SJnbEhuh+yAU4p73kwgiLcbNjgweKBgsHgO4wfo/9AQokKHAqt/ELPIdHnbuvTwyG5L3HietkmmnTR1+988YkHfzkoZwGFHCI2OIr3Hk4RATb3MUbvex87SSeaSqm1irJyMCEsuUSqRAoRCgYp0g5QOIkjgpwkRIkQR/2TwCwQcXaKPnaKQN+997BBbgQIJXUIIGoIOpWIX+0PVqKKQBCFS7azuRxrtEhtZYJBRA4+hzARJ/ExR/6N3cfm4zasRzhKLdrEBvlIGUYi4gB6mFsUDURYdx9gARQswUxAVBsBr63cPywiHA7mMNvDRIsCdhRMgwGTVFZ8tTuaVpbayqQo4iyiQiqkwoVz8pKT+QFiJuydg+Eo9d0oHYp4AKOP2NZ93c1Iq7aLGoGaiLfKy1JqUWXhIBgszMwZtdX27u7Dr7/7zf3ljsg//+lTq9PKt1SOAJEUrXPjoQJlVgo5+FJvSNh0bh6Em48ttxjiUE9Jt6Zal8v93f27eVlqqy8W5u2jxYsnPjxTpvDPmfrh6F9B9s8lLMkqW5j3zfr23HeXv0KES63zcrm7f7csS1GVt02uP5/XDmbWFM9M5vfpe3Jz0uHGREtGxDnwLSxGuB3Jk/CJbb5qGUyvqVpryZIYmYHBIoWJk+xXlKGcUGoETispohms6TNBj7KCrnKgYj8H9PisMCRt+GA1iIhKTbMkR0nc3C1363FvEqKgo8EegpQ0PFiLchIDjwp5Vu1ImEkoRWpTI4eZUt1csgkTiAimYIozzkmAIRkFZ0p/HhGw4b2PfR/mYC6qhdAyj6URBH/Etq8mzLV+Wi6X0kocD5FYWAszUUe37Xbrnx7Xy93lgUls+O26Xq83Ziq1dMe62bDYe/RB3Qg5wsm6ec+FRKwgtaBhbqDSijkBwqQECVASzIEc2PM6EaM6leJKgPnwMcKj1kmUmSX9OmBEnhoQKuzMKdGSYl+RxEkLllHKVOsErXwEUJS0uJR6PbvcsrZzNH+f7A0+1VkQcRbYIxLhjxQmJPdEe
wCz4fb1nBvls3SXvdSWoUXgJP6KSg4ZYuGgiAizyKmu4SkjsrX6pCLz1GpRZkCCJVghleokUtgDe6d9hxYHzlISUpWGyY7ZBRJ+uPEDcQxCMInIczssnW7u9ZbH05c/DLsJZK4LM1eh1mBsm18tRkBBlZjNV2xPBVtBD+8W2VgIELFIVVEVEnXIGK60i3iQedwontx+iv4DhMl3OlINYS4ihUsjnUgbZGIprHp2hRq8x0l3TG0qRNQ61TZJKSL6VZpYlJ8734BwZ8s5XHzW+DztCJl739d9vw3bpDiiqLSp1VqqDc6kKEiEW5Y6AecYEuBwUoEnNk+qzFJyFgELSFwqtAYrgeGWhbMYI8agcGEuhOi973uxMYtyDot+HcQzVaYZTubeN96qU/TwUC2qXLXKRMpKjHAyQys5cgqIswg7wncfW4zNfZC3KMaAIMp5byiDXaKYahMWN4JLOEU4cTAR2PvY922tVS/zxDy7O4uY7ftOTrpua874LlVZok7ZrvqyzRsV0TZPl8qTouQeEao5Q12IKHWXDlDlxHGEQXJQjRHCBIT5vo1124c7l3ZpU+xDRHxu5W6pd3eVyW3stg+PwSEKupvvv333m7/+7b/73W//phTdtutleahlEirwkc1QJMyqQlVCWZSII77iN9Hh5ZgAH2P16G4rEe/reoxuOw7RWkstIqwHZ+v1Qz2Kx2eqjhNGxytg6vl7vP7R528PfxJjbLexrc86FumqRbS2aVqW5XLXWitCX/WZfq1GB5ygWFIKzUBiZnE08AiB4iwKH2mQcDAAJ3JmEuXUZfJjjpLknJUDj5YEDyPcWYgkCa1JcaMcF5QaMqRH80k682RpaCm1QYvS6fUP4e23gfDzQzovKhOB09XTSUA82FYRGZHxUUiNILfMslNf5LzPfPKqOB354V8isYqEXCKp9+AgloPklOhjttVRjpAjFpIcOo8jyeBXwMOxMOIkbRRttc5ORoGIEeFgghO57xw2BgvNy1amIkVKkTbVVgtJkTT88IgQ69o3d+zruD6tt+taap1n0eJEZiN6t2FuJwUcRy0la4/MglJVhg4bo3s4MalKzQQ6JXVTfPDtGuMy1WIVQQ73GOIS0SMU4DH66N3MCEHMIKdDAjDcPSfoHg6bjHCg7pZxQWYDBz5+wKaUXYvulh24Ea9O5nj0OEv68fweitQhB5zkHHHzc+7DSaQkOgb+UCRm4Uj4LQVklFVZlLN4c4Y9HqCt8+P1OrfycDcVWSQR+wjrtm8OlsYqJeY73nZsPcwJCCYSoHCQBhewEyuBaSRR4IAuknB4yEzJQc8/ApnXF+HjGrbWel9qKzkDIDaNHvxItDkKiM1AsioeiUcgLGIfMI9D4FOZhItGERJlp9K9FQZYLTQ8xrDed8EjSUf2EuVsAi1iCq1cZpR71pnRmIngMW6+P6UYT9Y4cjQDw5hcUakUvJ1NkI11R2gFBLmT5IzNQ2Yqw82s9PkIH0RZkhMWKapFJRm9uXYckAObA6ccgDtpyg3yoeNBnORMEPMxu4zBERHdfIzITpZhcBfNON9c3FgoG4mVmOqLNQ8XJ7aBFBZIDQwPEExIp1agAkO3fr12N1aaqyqChErVufAQdAqGH+pbcIRlbYlrJVVyDxtHEzIdXVWimgEKqzJrOJGHbds6zVXu5uVuIuYxUrSQiasNM3NmoOREsK9n3mIM4sJgLVq5EQkIWlOnMQeEuNsIP1jromf/E8UpWZYYVVhulyjMSy2YJummyvLucrlb6jIRRe997bTv3luDqHy8/+7X3/zVtx9/8/7dx5TfEi4MpRNgoTibxkSYVPLjmb5u2eBDT5NF0g6ZAQEbe5g98+2P9VNUheSow77g60SHSEmC63z6mpfc/fV736R052tZsw73vu+3p77ewo1efhoi0qZ5ni/zsrRaE1l4fR0/a37DcwiBCO+9D4s+Rnik1Ur1n4Q4ko57DKuF41CzkVKUPXN+EZGjr/UohALh8AgLK
Xqw/2ApKCDMJBpHf1pqvqtmVZa5qKIVcipryc2pRSVnq77lAJ93kOlsuCHKLik4B7Ez5bMAByIsYjDpoc3FlKOpE0tPIcNDNhsEOABKNS8ijiNzPMFFJiJ3cg8WFeKk2zO/JIsECMVR+X+N9Gdk8frxMDGzaql1mto8TYvncISAO5gpmOCCiG0dgWgr6lzrVKa5eEhAwFQKn/0MKnUiqT5s2/16609Pe62EaEIGH+PAjR0ELcyqSoUVKZlnRqw8Ly0C43HYAIUK1aKtiCDiFIeKXByvr0JL0VaRUrAUgQGYeffAvq9738NNOCRV/dNSu4eZn+FjVjRAETHMqBMDKOWYS3NWcBGZKR+NlIcoFB3d3gmZcKL959sy080ZjgBFygfxAdq8IXyBEBxCEkl55PPr0E1MjSqVwlpYalZbcgwrqx1P332sG19v5XZrVWWaZ5DY8C1sN5tMlyhaeF54uae1O9KegAik6pzzO51EBSTbRmf3SlIyD+nLY1/wKYP8VtxJEUyYW5mnpU4XLUX8EdabfCF57EHm7C5aRimrMLnLCNmGmCHcSFyUVGhSkWa11ojW0UgKSxlU3GQfvO0h/Uq8pvaFlqZSJL0dg8uCFlyCC5iCYxvrT/36wxjdLEiLaCUuLDVoOA1CAwrirfADDneRRgvgQ9f6INMdGSIz41CnihyecDA0Ked8ONjTZARRgOXg0pyynpFcmqNpNVjMYphrESItAZAgzN2GWR8RLh5sjhwbQawJLA1ni3TtQq8GeZhFN+99EFu4c+qTn8ahqBB4BffNrtttm6zwvEwqVFTaMj3sNW7UhU3UNctjL5L2pKWwyN47K4oXAqoetGuR1IYWnZjVB4LIt32dtwZ6mJd5mufrbf385WpjqGgcQ0CIiQ8e+OvdEWHrDSCvS8hEjUupokWUVSgVosKH9RHuidGWojmlKFj96CLn1NcQApMWuVSlqtpquSxtru279x8vkyjtFLvVy8b7Gl2qzjJ/+/5Xv/rmN+8fvml13sfmWQxzkKe86gmyUuqRpkaW5HyN14TZpPNpEVHNoR44CXSHWePUjxctOZGEhE/HebpmSkj/hJcPwvybV08Y/hUs/8z/obMmjzDv23b9sq9P7i9NesApMbss8zy3WpTfitD+fKgrzqJWfnZarufUko4E4WiZyGaRI+1N9fKzA4+TMSciKgckhFQMAYWVqeYKVklxrpOKdFiibGsWOUTzcYRdh5kKTwibuZR8D54pCc9XngvvLNsdZW+moKSwxKH9KwALSssogllBHO7wgWHhDuIQJS3HsByAONJWqTAxKAIaOaYsDQG9VE0OdEeSxswvZ8cIcj6W1Jki4qD2v1wFB9iDzGAjrId1t+7m4e55HggW4VJnEFQLkyDEB3VxIgqgFtGitSgRm1vn4VnbdexrD6MiRVkRGGZjjGyxAomAPAwUWoRVVLgqolApqgeZURRaqVYRYjSpNQPd4Xg1HoZAvdvog4hAIQoiH745uTmG7eEWOVs+4YszX86NeFJNmJN7iQNafFaHPd3+AY3HMTnraCx8hnOefy2D07EDZ8WLc0gfMRMFgs7JAV/l7KAc7CI4OST8XEkDCbFwnco0t9q0VCklB5pDlVRgnd2DQUIxRn+63lSUWLTmyF2Bse2xy2izlMpao87wHCuYuI4EOAIQRwB9HMDTWVkmOiMY9yMPCvx8tBVd5jtCndo01aKqJGQjzHv3zfyGcAYkR2K6ObF5MdcD6idW5VJiqTG3sbRtauReA63UOXh2YLdy22haXZVT8agoCSmkBtfIMp4HQcmM+ArvGI+2f+nrZ3MPCGuRUpkrSwlvblVrk1rcX9S2AJilbIwxsWgWtsLZmfiU4vUIFC3ColJUGsRFWLWkzDwQ5sPcg4rk5gd55DrJ/SgHkJoJMbtTDHNL2jmTDmIB4O6WVfbwsMFjwE1KUdUQDRY4DQoyj8qKV8qTfR87OuBaIRLCEBGRKgxm6ns3G+vtdlu3desRf
Js3hBQucCZS4apapmkuWsZkY1g+bjCYOQvZGhFAXWpVVRaARvO+e98MHKW0MrWpFHPr+xg2LFxUp3mB8NbH2G29rfs2EETC8MPzvM7Zgdj3GyNCSjCTKlVpqsQUEWaj9633bYw94FREVWotpYgIG8JBmm08KkoUXi0mIzTIROqoRHWp0938/lJFYg2vI4rUqV7QZLq0+/f3H+7vPrR2YSpu6LuN7tnXnjN74mS3Zwd36sGBIQC9ahvLMTxSCqvmGJEIRhDxODZ9MjezQX/s1reoJTWbD4fMOKvq+DN+nV9GGPP573xy0vjs8xSmyNz6dl0fv2y3K96OZtZSpuUyL5d5moumZ/+3XPuJxfOJw6rkzI/DwR9sZBymz7MnPc4hyZZFcjNTMAFH8VQAT9WZGGHE3Jb5KHur4tn9HuXPHDJEpWZz/OHssyPZ3fvoYwyPYKKS1fevumkOS8dnJpdVkAy8gZRgFcE5hEMLUp5AVQOwgHXbd9962AgSF0EtKbhwhCtFJEP/I5cEzALkHId75qOGkJh91j7pfIBZHaQMMkJwEPOIsonulSshBw/z3se67utt3fatj27RIwYYosTMpbblcikl5xel9HpYjwgf3rVKrdpqjRoICoNQqUWZyHonjy4iRBFj2DDrxCTKoKKFzQ0UBC4kKspCxqFEhbVyEahCC5VJijJXbU20BDAG7NXkN8Ltut5uqxYRggiYwnyFdw9K/ZDDoyZL6HDGEGa8iC8SPduTrK+ftXJmEeGIw6udrt0BHPJKp4BJAuxnF3sqlzEfrfOM894jTvbV2c/5/DhSn+7gkWTEJkcolxBlm+tyabVqKVyraJEGxqRjYtvV+pFhwuN63RKVWi5znYpKDckCorMEC4ittCgeDhwfctDtgxg8XuQNz3vDIHgEBhI2P7hAoLeSzHx//42QJXSf5Y4++tb73nuMrmQCVwIHjUEOdQcQqlo0mLlVnipfZlxmW6Z1qsN9djA37bQEfO96XVFXr1WqUtObxmACSXW+D2qU1E0w0RoO71fbPvlY3TtxIa0UhUKZlVm4s4hKbVJr2BvXPnJOvZmIVOZgEHnmAmaW+jAAkbJKKTrVAqPOhBTkYGYP76P34QQuUuig9yKS+yf6rLSf+tCGMLg5juE8nVlSbSncs8QeNmCDRidAVcHCXA4RRXcYwHgzfLP3fY+1tqhVVI8sRbiUIkTj6fpl3W63277uo5sz67puCCgXRSVXJi5a5SKinNpzo9swP4YiSYADbFLicj9flpkphavt9rT3PtzBUqe5TXe1j/6TfXa3MRxMbW5gvtvH5/749Pi0rRtly48h+RyvmyoBbPsV3ncKF3BTciYlJabA6Nu23ba+dttdULhl1l5LKUV7t/A4469SJRETDhIndamgRihzrZMsOQHenBEQnZe7dpnu7paHS7trbRGuEWwdfbfezT21xjl9R7aQEHOyqEWEoJzD188tUkqZ5ilfd0gEIQqRiPTcYvnzBIx9229P2+1JVWubVPWgt9EBsGc6x4eBeP3n6d1xOPPT7OfLR4M04GNft+vj7fHztl4RfgQPQJ7ncrnMl7uM6hjgt+Wqr4a6JjFPVElUDl6bqiaFJcACVVY98emcQh7HXcslSUwBZLcKJ4qOACE8krCOUxj3NEpZTz5KCJlbHVN1RJIbfQqJnfQ6BCWmUg4eHfNXzl2eazgewQeRI7/4lPg5URFJHNw9x2GZ9e57jzHCPGfYgsnj6DBWLcR60PRT1vRoho+zFekYy0GUAoYshz/Iwk7yuhMXAkk2DiVVkbm9IgIDACwr6pwaANmDE045ylhy1kGd2lRrjYC7DetuDg62UOcSfAgdgEkFygrAeL1tT483VT2d4uRhFoMZEkIcAY1IJb3h4FbVg/Y++uiHjg6DGarSWq2qVVoVFc4hMvZ6UW23bb2uddJapAqpEIsf0RwsNa6fk2nOJ8LPt42O8ScixBI4M3giOkTlDpwwY8BM4yLr00eHxBGR8
Ws/fXhyPnlmWZU9kXpHeJyRxekSiTJQLVmCyizu5JAws6q0SaelHH0bjWthZaVgK9WUh6iNw7vbsH3ft72WJnXSosKlkjCEGDn6yJhMC9qBjmUQKeYxhu/d9t0jsgSW3uWQ183CxDPuASJ/qyB0ef9XqhQ5c9t9jHW9PW23re8gl0paGJWMCG6cd0KEi5IKRDE1npvOk0yT1iIH9A8hrkSNqVNE79ttvc24REH0q/PN1VhrsEEuoo0l0Ec4bIywNexGALFKaSITRChJkdEJxhzFZ/Ep/EW8IiK2dT+mQ5WS5paCokTuBTpXiqRmjVbVo35OkJQINLd127fdmGIqXOZJRHGwYdLqBvOhocXgU4EKSEE7T2k2TiHs3m10HyN8kBkLo9bEG0ccg4OzOvbG6rq5h7VWBbqtgVj7qvMS81xEfYwxEqpDgtYwH2OQ81Ayjcoa06KR1kVLcDJ1M/TUCM/OTyBISKu2WkTYzUV5XfdtzQGpmJZWp9K7Ecnex74bWLQULYqI7bZv162vXRguolW0aLxSowNg3oMR3nvsu2+x87qvhVSZw8cY+7BhPjxAQqmtyIQiosoKVpVapBatGTSwgktI9bFHKFeeSp310oTFhRnBkFJauyzz/WW+q9qYK0gRnEMzpjbV1kSVw1Pag4gQRJqlVEo/8pXnYBHWmrX29Ax+TCZ7hZWzBGJbnx4//8Ai23pr06VOU22tlKqlHqXoA1rM3/uC86VnfwbaiI7MnjOsO5TQJGxsT4/XLz+t18exbclZOU+SS23LcjfPSy1VmF+06s7jba2dQEA69WMWbSmqpRZWpaTPlKK1alHOURg5gyS50ZmIHwV1OjrPpEghBSNYJaDEpCLJ5zvRxExwmQ/B1WPqjKTIAyNg5M9soOfKKAvrEX3LM0/h+QEdNGKP8GAJYeLMMVPr4ABLEuIFkGp4OcAt+h5jwB3no6CMYChFuSUJY5RkLGKRUo6loCRxpHVETByUvf902oqj0Pt8kuRIZbXM3anq66XmhMHswqESKhAOpsgx7KpaS2m1tFprKSqaTTw+opsFOReUFNAjYjCHUBFSHQHb/PHx9vnzUy0lwyowg9xhzBAwvQylhYdJEHGJoHXd120HEUmG1JBCpclUM0MrGWy+rrWDsN/27bYGlObapqqFRRJDj2NQWdAp/vJctcKzBz/mZeZYLxBAcvZJZE2UjpQ2Zf/dcl5yzjqNxJFZiCBCAMXBbTnZiwdQdAaPkTMF4Qf78dWeT1ctepoGD0jGICwirEXapNOcQsBSm7TCVZnBLjSYFd5JjJBRhJmNsZsVoKpMtSqrQMjZ3SxsIEyZpWngGCUQBjPsm9+uY13Ng0VJCyMYQQHOKPekjGejUbr2s2LFPL//q1rr2J729bFff1qvn6/Xx33d3EViBgdLr0IMd0qitYhAxUtFKWhN2qS1lZIK06SgCVSJKlEVMZEw27btiYXDadhV4rFqZ6lQSIk63RF7eIzee98JJkKqk9Q7LjPVKQsjbltYwG+MnTCU/DUmGRHbtifs3hpESrJFc3dl6+6Bp+QCkKIaQoMAd7IBihg2bmtf905hMck8XSoXSOZmAeaUsZFMeY7ldyiJH+kME0ECZm77Nnq3/zdf/7YdyZFtiWJzXczcA8gkq/buI+lB//9XepLGUZ/eXVXMBBDhbrYueljmkUB2l2KQHGQyE0CEu9u6zNuc4YZwEsFeKt6YmTPSCSy6KX9BqYthSSnp7f3HeMPYtnh5Ga/fte8k6kEIzlwO34h0i2QQ4ICzxvbCFmkWSObUYtk02YTbcTzinGFIwD0dKV33rSOTmT/eDjf3mO5WCa42c5x2HvP+GCUQyYRbjOM87o/z46BMEd623rYWL19axsisn8DhI8b5sHl4I91aYyBh5tPcDBGU4zzmODNCmFTEQU2qkZAmwqTEWgjOTIlgDtq1bfrSly93gElb3/fv+/bS287EGRTBTKS6vdxeX15e99vtrjItaiKlUmsXj
OzhWcPCb1xZAWmNF0XMKU5vsTX5GhUi/Li//fjnf43z7C/f+va63V5vr9+2/db3m2pj1TKq5bWyveryE1osshcT154m8fknoYDN8fH24/2vfz0+3uccKBRn3dHc+nZ7+bZvt9oWXPq+f1ParzX8+s4R7nMWcC0iYNKm2lS46G9zjhkeSMxhbpYE1aZNWYUigipnkZikKRUbRSve6Na0r0SxRUgqmLbaqOdUVtNUZvHrtbVsOWXKhcEn4B7TylbyC9a+5q+0hDFBlFWhDSrQWvUW+h9eYENZSLrBjdwRnmujW2znChoCBXK6wagxEy9EonJk1nUSRj2B8dRiFY5LC0oGgaQ+1UBmwiqbi0FfG8hr+xzpZbtSym9R3S5LUWEop8A5AZ8xTns8xrCRFNIQyR5h5tbCGlwpVGLm+THvH8c4nUkz2QPTPNI8jRgCllynYwLm4QmLdM/pbuUNB5sxZpwzTk8Fqyh11satUVP9JBRN2JxzDBIVxhRa3SwxC4mAmZICWbzoQPVlS7y19lrMDHASFyuKrueELjZFEdwrAn1t42kh9xFR6QaZT8/ESj6rAfeiyoddq/qSXSXl5zkfzNTaciCLIlxRYVNJSFXZNt022XchkBBtm/TGTYgiy9vUOZlTBCUUUQXg7nOOUWRRSUAJgUj4TJ8VDZ0ecA+zOIc9HvPxmOdpNh0pzCSMlKwQCy3T62rTLOqv36LQmBvrxjqJPmw+5vhApupNtSEn+UgcgUYcVEJvQNiFnBbrTGaKBItLZCfaEq/J35L+cLxwa9LcI4/jHRlTGG6ULNRYlJX6lqQg8rDTfZTbfCQFguGBqYGCCpiVW08oxRBm/W05V8vhSDMnYhEjAJHGLubcuVbu9RzVQyhMnBKIOeIga0oWmJPGzHRXcRCkMwG8pqzM8MzwIEpBcqX11PoRawgBS2lc/SmwgoBAIlnZa8gghCyWcTTOX50WkBE+7XzMecbH/R7hf/xJqvqSUhL4dc8SC3ET6SqttLsgRppfLp4ZCVRYF5UTqqq6ShMJ9cAc8fiYAuXU/db223573c9zjuFz+BxOpK31OfO8n/f3+/3nHYnzPsYx5jnnMccxmUGkScts6tdhRSR94977vou2KPaBO4gYLBc0Wo+he61C8zmnMeHizpSwCik0PMuZSlm575s04SZIUO9KTbrK1ttNZWNuazAFZUKYm7bqPSutg4VbYfsiGWTpQkFSsZH09ekQ5VZ2YqpKmocfxqP26k+3Mncf53k8PjJzjKH90R/38/HR95e+39q2tb61tvVqgetWvBb01/lFIFSIFP/6b6x+lOLxePz45z9+/Osf5+MeZpm/egNRadvWbi9ofXoRRoz8y5P+W6grnvSIiDSzcZ7ilpEiQiTaukgj4nQ3t3EcZk4gmzHdmFVUtakop0VSJOcykiPJunzC0rTfunQhySQLr1W04FKorreYT1lqEpOIqgoMrZ1yeaLUgz3OYdY/v5drc2+RBlgJPLSRapX2dS64lVKEFl3CUf9elKxyQVEl1RL3UzpFKVSALMPk9YR62RoTVlb0EjXHk7m8dhmZT/e91WNVorijtIVPxsV1QZb8wm2YDUNApdW+x8IyEy4+ySndKh1uPu5j+GAlSXiGGKbkVPee0SiE7IzH2/m4D/cECYsmsS3VgrMQmKOaFFUQOSyrrltEFM+0YujH8OP0Nlz3VOK9dd21N956/4ImFhfDJ5nQlFpwaIEaLKxKqMgMutI+aNEVaqNSGoPyDqgEouKsYNEn8+LWLSX7tZui5UNDxL+MXbIgd7pMAXMp4XxdrYvwUdHEn1dBzNS32hGxe6blkyPChNZ43/V20/0mZXKwb7o1ViHyHImYeUrBDVAhFW6N60k6j1EburaxJoLIg3ySTarogsJxx7TztMd9HIfN6eErG0o4U2pPRH1jlbLwy1kO40e9/U+Nb1h68W6mzfeYH8Ik24uIZpqfH/Du2ITBrMwh7JSD4swkT0FIWrFMmGkn/
kbyJ7U/Eq+GjfqruA97t+OeNkUUEGAHkYr23CDSnIUDaQQnRlY0onnEWch5671tXURFbokBt6ZdtX8KeVpHXLEL3MM9mJ4TibXUZ2nPQPjy/+HgCJqnp0d0DbCbuEtY1KJYGqmWuIXC0mb4zIikEIK4mdkMKmIts3JrpA2RiblgXWZGcnix4bI6NeJkgUgKh/IXPlS425gPI3f8fP9I5LbdkGhd+ibTZiYAZoQyN+FNW29LNJTJM+bwMS0sEpCnpwERkZA01k0DgUlm+Xgf5EwpvTfVtt+2vp/neIxhj48JaIVtjHPe3+4///nGoMfH47yPea7Q+daFlLVL25Q/69qJ2v5Kfdtu31i3UhGW7CmTomA/ENHCO2gliCYQS2uCvDyMmD5htkQQES2JPykhmZfoX7gxN2ZFLl3y2pmvdIeL6kdQYS0DO2llKyhl8Nu+mHAAEJIm3RNEtOumpTw8jhIA0WI9LwqLm81xZnE75pjjaMe99b3vt77ftv3FtpfWN9VGF1VrFdn1b9VE0q+p/ur3MvP9/f5XlfbzHmEL7a6sYVHpm2wvKe3woJzm3ty3f1faf1X2zPCYCJkjUvPSEwOoeg9ChM05M5KIPQrxWkRE95GlSqZgXVI2enrBXc6+tPj4i9VUAt/6Xol0s8w0swRJ48T6UC/bjPoAwtznmO7yfC+JMt2s1W6BY1TDW8Tyg13JnGU7LYSU1TSJqhMg4JHkLKGdROgSL9NC3mvGvK5KZESkB7JMTcp5Li5nwHwWlyQCZ/lRUSKKtlM7OxH5zRAUiYyKrShfJ610swBDGKmZcQwDueoEsYdPtwwWakygTLeICPKMcKT7PAfczhgPc0/VRsTlGw+hOkBYCl/gIF5UikCUk7FHZJJw72172bQLJM3nsGOYWr8RQ5vubeu9f3laFnOitkEUAfPgev9IVUECMFhWsviKNvlya9YAnYVeiTztMvLihPunTur6H4vvgeurLRrdygtev6XcTMt47Oqbnxfg04sZva+NA2Z6lF1UiT9p3+T1Vb9902/fBAkGbZ23LsoMT6bMiDk8gskcRNqYmYJymJkfZjEtdmv7TSHKtJWs290s3T3cSkmwsoiYiAUVNSQqLcruF72x6NOuHOcBYd62z0uUnPf/mQN2foz7P+P8L/IfRSHP5IxMhGhv+rocI2JEHEsu1FmaEpeUa1n/MUF7CKbb3eCJF1GBlaycQJxooA5qRA26k25JGyiIg4U5bQXZlEuAnYiBaIwbb9+gN5IdxMnNUz43vmvAWzMP4yLo0PVX/aZ1IkXtWcoJKQjIKFNjiWhFZ8mEpyd5K/dxIp9BhwNemh1EZlq4OzID2kSZRLNKuzo8CAnVxqy1bynpl5Awg2QSBXLC+fO5O47xeD+YrNwFtReBiPbe901AZm7CTJnCUOYm0lSYMN3HHMd5nmNmsoqWsJHhyJx2Rho4+k2SNBkIMHEgx7TjHKpEwq0pEucxf/71Pmds++Yz0vN4P/4R/2Bin/Z4e5z3w4ZRpDbdX/b99ba/7tp+vQ1iub3+CW2qNxJNMG8qm4RlTGdOJm3pGjMM4FQpR1MRlmudHBfR9ddjneV/vFa5RZmFcFOGMFHFxSxpW2ZGjVhrHJpmc2YEU6X/NdUmrI5apfKvCeHzY9422V9RpJ6+g5nUSA8iraOEiLTv2+ufr3//f7z+8bfC10lUVFmb1F+ts/YET48YY1qwSiXC0ZP3ev29KF+fD/7McPv59vbXv/75/vOvOUZk8kpxSSbS1rXv1G8GeZxGnB4Oj39f2tdGvsjAgUwTBa2eq5xUxxjXqVcp5MtRzN0KgQizedb6xYgQzKslKwx6ieAj4fXwLdeJAAHuDoQQISssJGwaiQja2p26u/2CcjMR7nPOjK9j4iVoXgThIPdkA9OK+axhj0DEUvfHwiHQsqAwJtAkCWlJFXYKoGzimZNLBoXL3wyR4aUJkeXmXZOEl3kx4kkgwNMUA3lVm2W68olCV2+PkEy48
gx4s7QwD698C46gxzHMjFjKpRucpEUuRFY0VFxExQzLmT5ipI8EqG8bi5iHeHDlkLBKU22NpWXSdK8EJwKbhZUyiXnft9vL3nclgeU87TxG2/uMDBZuvcmnZx4EEdbaZBMDHAm3iMpOS7CwZlbJJ4oag9d0fZFWCmTKtb+RWnHFMiWp28JrQK82ctXs6wswL9EmLsGb/Po9y5RImLNEWavf/w2BAzO1vq5RJMhWy8pMxLzt8vraqrojEsDWuZUFvCeBI3gO9iRMSkAaFV47ZqTNU2yfESX82pREBCTAiIdNW0n0SLruBm11+9b6slTaKZzaSBuEEUlzUFMmyi+lHbD7/3CMOT7O40fOvyQ/BIKg6WXJQNJvun1T4cxpM6cNIlVWVWm7RAy3MT1nAplM0XFqYOJ0GtS6to3Oi8sHJBhQok66a3+RvoE12RlgsOYJxMWvCI/Dxk9KUboJI7mDd+K9onl/6/cKUOendSYta0q+jGlrvFpbmcWqMPfakZGWXj0U2QAkaNr0MFbtmwizSYUKwq3SPLO26+7hZQLYazImBVoggpho23pvvUREEUGZXURKPJvhPn/r4McxHx8Hk4j07Xa7vdy23rTxtvV9byVKbTI9WfkZTyDENNwe5+Nxnuf0plvTnsElHcoMm2Mabq+9b1o0RxSIl2Th98dRUmhiycD5mOZv4/Tvf9TZjPM8j5/1U8nj7XF8nD4mAa3pdttv315ev90029Nki5hfXv/mi4MmRNLbtrV9HOORd2GI0oSrD08nCmFRUZUmIjXJJWuUSq9Mad2jlOlZdjOrOPFidedFU6aLkLtmpMhCuVbuXYG1zKwVhisKskri+SWu/nxT9a63bxSRSWi7E4UOyAex1FlPRG2/vfz5n3/8H//P73//b6338kImFhTjiqRYdCCKhJsjnV2kNVr04FVuaf2jzv0L/iv35vP8+fPtx49/fbz/MJsLkyGgDszWpW3UNkt5DCOKzNAvQXxfS3tpDWuXYZbL++lZiGIZc/M60so66Fp6ZLr5eQxmgDpT1c6kyKTKqYrMYObFVvIQFWbxuiBcJ3KZSH/Sga9W7blsW3vu0p7Q5bqHLw99ZhYjnIgZXj5lFMxOWvF+VSjWh5pIpJnPWVWWM0OFWJU0QFY+sFglpohTtWciPHc+a6eQiKggiooDyQpsjTAruwzJVXmKlLHAm5IYytdnnojr7m/a9+32cnvJRcQd5rPuiaIx2pyRCQYrS1QiOBM3BZW0jJ2qnwnHWrx3vu07l/+f8kriVW691c1KxIAhnUjC8zzH+TjHMUnSN8vLX90TY8ZHDqaH0J350Gb2VYYhqlS2X0YhqCpe0aK1mitPjxWCd/ENUdRmFlZlEUem12IziYv/5lHm35m4bOTrq1di75rrCERlBejXlYIXYlLb+6uHZr/SBjN/LyProaPnQPFkpNZNocJN6yctQ49aNDzD1YMltVMPhqKskGz6GD7PsBHKZjOIqTWFtibSpG24Wfi0mRzEEKLWWZqwApw0olacrZEoKYUqto7WSATuOM8kimmh+uW+yphMRtylfddtTpDZw+y48mNh44Rba0LkHuFGIHUqX4go60FhEVbQDuxJYgvDyKYCbnfiiIwwieo5gykbs8ouFIlZBneUQ9kYbsscisMcaeHDpnHbSb8zKXMHOdLzt4vBYCXJpZJBPU1MtE7VXIcB8nJNKLcno6Sg8ljFsyCMMT8+jkoLi8h9a7l2ZEkCUFIme9LMnHB3Y5jBPSMFFKK0bYQu+6attUIGbDrSu4hUEF8wJfg3y/IkBFF9oMq9S++y9dZ7b6oiLILeODiFWAQBD7Awe8xz3M2tXL+F2GuNaeY+y5Fi29rtdQck4iw6Bidz4rifGZ6O82PYcBs2p4dlHbyXBN0YLKz3j2McM9xVi2HA27a9fvse5+bHehNMsu+vTpy02qumvUlDS++9nu7IjLAMJypFB6H8abBCfcx9utUzNuu3XuZTwowrQ7IYkrTWe
8VFqGuYEWyLgBVzmtc2LJ0CnCsBk5iSgwXM9Sh/ybnpqq+33UEBIe0RicdbXNgbABZp+8v27c/t+9/6tz95hZxz3XXX1p2SSiWWv9YQy3MmPk9wn+r7xZlnDrPz8Xh8fJz393EeGfHr/xNIVPsufWfpCXKP5wby8231tbSrtN7q20REBC72vhdVKTzc3HjZti+uU60JaoDGqOlRtLYdmfDVUkVpdZiFxawCQqQcpK+lRGaCysCztmyLHS1lL7AIJesdVtX/1Vt8flgKQZDVg0SQ+xIjFG1tyeHWV0BGjGnjtMxlS8dK2liUrJiU9aEVtb7aDEQiY3n61g8QkZHuZj7mvOb14jSFuSmUmMPqonsiKme2PmWpfLlPL2YWaaKttbZt+22/mY3zPHzOx/EojqtqYyIzM/MkkHKLhlRC5UjWDjMQGbG2zQkkpXS53bbqWVVUWKTyerUJNyLJXK7kkezDH2/j8fGYc4ribLTt7LN5V5t8uI3zsNnC3om+t34O+yVSqo0xiywGXiKD4uoVi6NeDwBWba4EPEbZDS7kkxGWnEV7vwzlLNyWrHDlmK9BMS9DzcLGgdLCl9aFc0knfI3Cvx4wWtTF6zn58qysaICMSJuXQW3N7ciL9pfptTZOcw+HUaVsRiBIUzo14WlhFtPtGHYeNo8QcY+QLtveuBdopRvTtDH4JAQ4RImFu3HrTJzEPg0JaOfWqAttPfedeiNmmIM5x0RrYP1K4AAx96Y79HtCPXmMcRxzHMc090iTM+ehrVzie9AOgJGe6R69RVfSpr3txLfAzQNXoAB3IVLmUgFwgEOkKBbWKRqJkGfAY5I/hKZyMBngSR28h5IJJ8xjmA3yU7iDgjLqFPryRqq0g5mZpFh0FTRZOFesg6me8qrh8EzHisJ1JGV4uLvbAN3vRyGv4RkvoKVMq3YBBLCTDCJkWoaGu1nAMkEpSiqs3LbeVGUp3TkyopUdpRHlCk358pgvSe2STbXG26773npvlfkhhKaUtLAyUCQFmAI+bWSmipZxV3qGebojQ4R6033rt32fI4eYuwEpjLQ8H+d4jCxrl9Mqyzm9bHMlLMY5x2MgQCTncc4aDUFFT1ZpL7fX09uTNUBEfXt1IDIrVkwqcZNFtVmax7Rwsxk2mRIRyHonZe6Y7m7s06xWOFd8ZbqvY7Y2ZFIYB4A1NV3Pdi2PCZRkxa2usGizCEsLClCljS1RSxJFpMVXi8O9yfebBqmzBrcxnZmyZo+anFnbtm+3W9s2aZ3wS9a17iuip7QKAGqYK++Op+b6inJ5bufXkomIi6N3PM77+3g8fI4CFp9PL4vqdtN+I2kgLpAuf3ex/30hXwoByswyeC1d1bMjyLphQ7SJqC7PluXpWavU9Ig5LZKvXbUviDyiVqbChbFIWLAULYr6trVWHFi5pE1VilG1jZJCbMkFF0ayPjZmjq9Gs3XWc2lWZLE0IsgXa90JUbxrqX0vyDzNc7mUX/GxRNfGYnEn4jr5gYRjXcvatfmKOHBboME1MF7IfgXioKIrkIkUIQGBk2r+/poQXuuSej8RYb44YkQkTOEZ7iOiwKLC3tLNzXSINSGpTqYkrclSeJCoqO5afNGVUBecxhbkFj5DmrMgQXOGTQ+f8xzvb4/Hx2OeJyvCJyMaS8wQEU5B8GMLe6ji26bv98f5+Y2oamttxdxFmtnKElg3xlKdlr0mSmnFFREmxJUMHMQkrCxCtfjxMmXHhcGtZynrRudcB3LFDNadvyTJsdbwuVbcvGr680pfn/3XTsvMPz7OGg3sUlJcWArGaR9vown5cOaLF0OQNekjPOfM6eERM9wiZvg6hikBH2GPOfpxpCCQrSkJc7ZGmxCTWtugPTNjmIuwsN0fPmaGhSEF6UHh5EwRNC2npdd64Ws2AW9/k6Yqu2SwMHL68ZcJhwDESh2kluwGCpTaTziJoQwVVmmijUiShIUaM9MLRT/P4ZFhH5JHzrewRzBDuDXdZDb+2OjHn
j/INgt1nzHvSZ7KRIzCT1sHC0uvINJEn+cRZiHMFATP+NUy4ukGkZFrOQjkokSmmXsUNCUiSE5LcrCQJDNImVhWauUcbu7IFad2v5/heT6cgJpGLlSYWaJ1aobIUjoUrwcr3oJb41Y2WiCKSCLxGSj2fKQH3IIz8AkU1db6rROTdGo77y9ye+n7SxMB4Bkl3SkfQ27Str733lh4b7fX/dVGUGo6zvNEQEF768LCAu3MwDyGndNOTw8gLCOmz/s47sf5GHN6BMqohIjdIinTkR5u5cHgcxiS6lQ0s8f9McaskezzfRUOzzQ34tQEBGCUU4iFWU53y/CwYQib08zNfUwfY44x0FglpnsJWe1ax9f2E0LXbFx0sjVRESUlxWWZGpEOTI9ptqZ2K+JvofaBXPyYIsrQMyTket3Y/q4jyJzEcR457/7BdpAv22shKKb6o82fOrgaTNYGaDInuKAyPD+d50C1FpIrevh5yqzfsrpWyuTwOM/H+fiY4+E21qRa75dZW99evvXbq7S+hGW1UqB/v5Anvo5IkIgAWfCEu0dwXH5w4ut2J6EoMGRBgTW1hE2PcFoOsWtSq8oPouCiopFj3fDEpF1WGuh1xi+aAy/bBAQl/5IVVolFXhAMvryeGxEGaEU0Uqz+MCINSAKE4bSM48zhVRT42uAtGh6iIr0vB9k6zSOJQGU0ByqpGzwjExZpXldxmQsnIUFe+Iavur4GRSYCR16tZ/tUUDIXL2zBdrXDJ1Xp0Y3mRQiHqlZH5mYxsvZoXKhBBcgyWKn3hr1v+9b3rqpMXA1HBnlpAClZwZrEAaI5bY45xxzH+fg4j/s4H6doCmVT3fT0kcVISqfHnnG0zt/29uf9Pj5fDtFaUepTaUbktcYqKgQTEQvn4tXSctFeN/xSSYg844Aswm3FVa7t0UpeXEUs6wlb9MXnCP6rriMrzgWXxT/iImcUdzfz95Bds3h/O1dvEJRZu8R1l5wPe6MjLY4P07XTghCYSS7fCg8sy6GooNoSiAKKRBr8sMEHOcI99m3rvcNFsSULi/UefU+i2NyRlE5uZNMqt6PIRkiakwAyx3HGmOEZv8UOyfZ33XdmlZjI4WNvrbkydVII5NVTx4SFpwdDlEWo9MdoQipaqkh3Yglm56aEl2kwf4S9wyztHX6Ga6Yq5ybe+djSt1DklrHFcBsfjjDtorvoqyBFmuhN+4sH3BGz4szenYwpmDN8/3xk0dUeLh8ZLM2ruZETgI02WTs/eBgBwtxEqMJlhSyy+FbuLiwIhOc4zWceZGVn0nu73RSdQQJmadR6RoA0mGuWCdXeVPfWGndcrGcmSWE4VW5heOW9O6eh/7q32qbdOji0U995u8nttW27itQY7oU7UTKICaLUBUpOmm3n20nmhnn4eAytNDIuM2gwI6bZmOfHsMcs/nAQauMU08/jmNO1b6KKtWLMWOZsGZlmVvdrFiVQxT0ej+M8h5vlJ3C3DqmSIlOBlUnJmNPGNM/pWarUyukxN/Pwac7TzjHO8yTSpmW26FmuoEjPopPi6uGfpz/wy9Tq2njXaBUw9znNzMLWbq2G+vJOu+R2yTXzfzVf3yX+1GkZFoiExOj2znYiPMtRgyAxZX7I8UMaQTt0w/rwZNWkq7pXecr1CdHzw6KrfH36/Ip7REnpbuO4H4/3OY5nKsxVommV9v2FtYGfoXkLpH6+fnOj+7WvL5ygdjgRlwMTrvgNL8oPwqMOlwhn8BrrYTnXohXF5l2s1eqhpW+99aaqxDRtBrxGq7WTqHM2PQiiQsTuTkGL6Vrf3y+os0zlv27qKrGD1gh2nesXBn/h66hFe9VaDxTGX+i4OXIGeZqlO5XoqzD+yzocSIpLUxBB5uRBEZgGLxP9FeRajefSYl2nEpMUX6w2nRn2y/D8eQlK1E/MrbVt2zJfmKFNxjjnLJZbid7D3VWEsvbA7sXjXfYMJcsjH26nxUwE887SBJEUGT4z56qdCS7XLYAzBBEIp
dxUqDV2F6HX/nrjF57Nk8w8LMIQ20PGx03eX9vbuH+NsFs9DBPg01ZmXGrvsmIGmISLhRBu6eYUmeBIL8k+Ls1rNWCLNFdX8sr3vi79qujlYbzgrizonmJZRQJ1UsqiyVZ+aCJJWJokAPbf5vaIGOdcDQAtHGapGoPG4WFhIx4frrpSxVRoCTUv9VlQBtlytyKIMHqyUEQG4vDTHnbOeTzG68vr60uCa2ABJ8EtpoE8wgghguIlzuFz2sEQgXAF11GuMLviqtHnK9Fvf/Rtn/Ow+RjHDzt/IocIqEtAUsSzEfP0sHDiBhIwsTTwDLhnwAWWmRk+4YfkEHkThMDDkOmKxyZGBKzb2yNmxISdIueGjSgsjtNgI7Tr7YU4k2KWgztBldWFEeQ2zQ+iYEJ4ex5ZRCSi7kkr3GUBKeE1PSYAEVHVKwgcmRCs6YBBqLLhBK+k5QXiqQiBw6kAF2LIdFBFQLunLQ+zTPcYYyYFEzZVlBbHowib1Wm7w62asJjmc079IuEDK7WdpUvbZH/VfuPWqTWwBJEzAUE+JSbZmXbY+fEQIkQej+P9/TgeY5zhMzLQu2zdCzglSRZiZQDjbvPw3nXb2n7rzDhezv3W267HMa24/bE2TXWmlvo1kec4w5NApdWuejjG+X5/Z+vA0sLQsu7KYguLChW5359nGBNWfBQWEkfVPZzneZyHaLeIVqxjlG0wl7iTsjbytQa95tfCbqu5v4paJhLpn9ymkQAoHeGRl74Z5QKRJbD2z8gbM5GwnXaMEZHnMcd52pwRjqL/EC/irntmloYOdFnMUa6OE9ckvs4mWr9Giz6cwNMKr06aWkJmmNk8H/fzcTeb5QV3bfGJmKW1tu3cWhHuAGSVJ/73U3uuMbEaocgMd6lD7fNOfpUczyCkZxkuZkRNcFF2nenrs2BS4QpyaaKiIqrbtm37pr2x0DnOabPmpLx4B0gs0jwAIhgjUJFiVdfrdZHR7belSkQwP+nPmUFF4mCsq1gMupKwl7HcmsMS6ctBILIsu1EFO5Z1+XU5lsTx+o5JHlz4qxl5FaTVRS5qdjls1FVmJgEFs/NC967P+FMjt8aRNZuqamsd8EqH7L1HRCbc/X5/jHMIczDLnGZp7pGBFUhbzUemZYygFIZKCO9KqFij6R7SlKmXhKCeFIoURBJYWbdWfmPM/NJeOvY8eESc57Tp6emded5f2+N+e7j9tqmL8BDmLK84t9WhtrzcGBf5BELhXh0zgtb2qmxZouhwUZuhjGB6giPPuXxxMJgJkItTvZr+S92eicyA8DKowmJfJxgMEshanF3RTddNVYZftRVDyd7LiiADMXEOj4nZoMq9RW/aVJgvQIYrCzVDgnhJeLWSdiI8ckZY9cnTBlt582iT8lKGUMyYkUlhYXOEW6nwMUc8Hr7EnhlAJoGFRKl1bp1Zft1UBLTtW9s3s4fNj/n4ax4/EEMZRJzEwelIFdbQEZRQImUJEio/AwoAa1nHmEqHyiHomuoQm5KRRLOpO0lmRFBEGkEcZkmwUmaG25h8WPaktisHpc9MRCRLCjMzgikSZg5KpssL4qolwioSIlrXfNkSZOLKHiwb+ahcyAQlKemK5U3U98oAltC6CLKsLABPX5OUR5hZ0VwjzWNYpMWyEXMzEkf4assD4WXIURaHHp7pnF7Bshd68KlplMZK2m+y3bRv2jprIxYQV9OBdBp3jIeHO9FkPphAgTnG8Tge93HcZwaYeb+1fAmHeUxiiFJryiw+HJ7CvG399futNd323rfWtvbxcX58jPMwm4GsVAUEwELSBEzm7tMrK7dve8AdPm3eHx87ff/MGhBRAJxaXuUrWHk15sWRqVJa52AxYWOanWMc52h9bu6uySsOrMzAuXbOYM7LtJ6Ku7pKKFVXsvDLoiiHW4VO1XqUEJFrgl9WOU8+huNr7EoSB9GMPKeHxznmnLNizXFV1yROUICCRFjB7YrEWzPf1Wisf2SueZIuEW99see9T
L/izygyxzzP4+M87u6WmU/lWh1trCrbxqqR6eHrTk589cX+3+S1X2sNN1RhVq0mJYv331pvRfmsVFFSVuflYVH6V6LyF3qWMblmGF3TO4uotq1LEyjL1FwUtCBKqjyg9Vx5WpmH0fOHWx/dcqcwmyD6EjUWl5y8Cm0EmKprAXEkxSp3KG5EIDgDkVzPngIJ4hQitlhJa/5seuhJkq6bhAFaC/mgCHiwec3LF3KbDCLP9FgannKzN4c6tIOFPSpL70tRXFjgHPfH4+3tbYyH+wCyptIiH7qHiM4+fE53i97cywnKPZZAsW6NNcQ6ne8DRjFThep4MrewgrZSNAIo5QJAKk16Sw4nO6HhESfu46it1xgzIij59kKSt/ndctLnJKVEPj4e97ePcueuaJ8kAMYyVaR2FXYFe8Rin1+QC5NQAdaRaeaZ4Vibk1KY5JOUUJWcFrF+5S0X2MGEZXgTCKH0WiYt4kPRcxBUCg0QSMq25NcbqXMfdW8vfH/xPsBIl4JLFhhyWeEC100Q5eBRjPEU8HOfzABxEHGCAwSn9JgYD+NOqpDGkkaFKw3P0+I45+Pwx5HH4RWSWS115T8FkiMbccmcy+v41zPfWmvNBAqHnZiH+uR0IYAs8Zak2TaXP2f7D8s9gtKH+z2SJRtKgNRIObc29j57D+HDTjonn76ZSyYSbJ6R5h2uDGxnbm6NHbXoPVxnNqfd8ts5t2BuxX7MZJ8uVaaJeJcmCwajXxrXTLgnwKq9TseIAs1y3e3rHHOzCgjnpu2a2SkiznNGAWe1N02qpUo6wAmKImySULCRhmxIH+bHMf3xwG2/bftt23m/5cv+cts2pUYpxOTL/qJChJmFFASCCImSUKevl2OTrW+y9cpgyCo6Ng3pbjwe+Ou/Hj//+RiPkxjbTbe97Vun5ynmMYZFBPHW95t2broRpyjtfRNuQnaO7HsX7e6cSA+QSL/tIGXqnMfbcXdz6WCh4my2Lq1r65qRbkFEfWvJ4mEsFGmB+FXaierxqu0rgYhFFQk40mNcjlK131241nSjpAL9ptl0H+ZJjFIpFyouUg92hYrWgcZgJslAUizXjItOFhnmZjbDrSCvRQjzoMxVSFCnwrL3/ryQr94NxNq6s/NwD5h5Rq59Aav0m+zfaHultqe0a6OTtbap5+3zcvwaO35B7OtEef7zqvdlbjrHeR73ed6jIjSfqD0Ri7I2lpbE7gYuS3BK/P+d2n89NXlR3kRaa+aOafWle+vbtlcaX0X8lbD00qqBiUg4wRElUsbTVmI9P+tHLA90LSv4OYeNWb22VhwmSbW+gQB7Cbd/Hd+LwlLIGlj8eQgXfowEE0elJjuCU2pcl2JlcjnqV5ZCem0O2SyKkBlJokSc5jk9zNLLFo0pqXyIamvAxWqMIHd4ZbiAIxHpJX6vcaCw9ulRbWtRKcVSPTtElSK/WIY9XxFh047j+Pj4GOfDY2pbiUkXL0EIpCTG7CZZelBvV87RpbxmikzzGHOOMWdOcgplIaxHTjIsfSarR2YEIpOJZRNWQTIgjds0G/c5xjzOMca0aZkQkTTd27Sz4us+rRwTx/24fzx+IaMonyqjwakhUgvrgqfWVYnrMgOcJT+lpGqcwgnPNjYvmeTV9hESyQRmzVyLpaQFsAFEwuEIiqqy9aeIk4ViIXhEXHS9L86mzNRWDMniAZTEsFTVhdWU6VjfpHUuV9r6wdbwwrwIg/VnE8TIjGrGVaj26WEZM5J85onwjEYJSk0jDzkGP066H/k4/RxR+dcJXvPx1U7SL5r/71g7pXNMiVPiQX6wny2GpgkFwYFBTNK22G62txG3Y/I4eJgRiGirXa022jS7HKoPoYN8YloMnDOGb2BkYpgz5zRW5ZBm3s7cM5FxRrg7OW3cXqEvRjcKzUmZnulEzmzrwyflKt6oDv4XYmgWIGLW67n3WqrVH3zeb2WdQ8Rdu4qqCBFNt/Msjfu6X4kISeFkM8sDu2RvpCABN5IuM
S15nHN83JO5vSZ3bS+7vOx7104hZU6dRExRcg7isuCHNnZnM+JsXzDRpoHey7iwJHtB4TROd7PHu3/8sL/+6/HX//w5jqM1+vbnLnzD1rQrQSIwTy/KmJtEGOvWX1QYzLS1TaUjXSRYG0HHEZHmUUg5ibRtk3kkxZHm0CStHWsC3Pe2jx4WNgaQ0piVI6m1Egv9VqYWJoLVX7EoByARFsPN3KabRToIsWR6BqJps3hD5j7D4UwiuQZ3SmZkxNM4ernNsJLE0ijFZRebeW1Z5rzYAAkklbYLKIr8Usp5pqfXIvz58gjzZNGNdc7JMuvYvPapRMzadt1epN9It8rfxIX3Jlaq1IWDPj+YBQr/+k+sOf5Z1FYElMccxzju83yEf9koVGknUbAk0sLSK+ns6i0+vX7TtUtrrQb0+n7lELII6UTC0lrb901bI6bSvKbHs68ggIW0aRIypBKTkelmlafBykoKrNAOEEQFSDeu71jzaGsNIHePOJ+Ms3qgl475Olufv/aZSpfBSZIhyyrSwYTUTCRnQgsvLW+qTKd0KcmCOcJqUxeYAGXFOVs4CL0TsfIl3TCPsOWMEp52lfakBJFHejhxsi5VlBl8FpNjxYOypHoG0Bpfo8bnC4ILn6jtEcy8yKQiXNWdrrshPBCZnnOYuyHBwnvrrXWtmGaq0u6P43jwIyJh1TnBzc1WiCKLEUvdlwUNnWxMFJGLb+0+h885RzFVzIm4qU6ziCeq9OU1xzyPE9fGipm5IZxs2kKsmYByKiqiW3hEAry8xIkZJf0SyVRU30HkC4Ja/6htVkYy/WoAkyhAXj0xXev4slJ7Vrwonw6qSCXVRlQBh5+sNFX420vZU+h6GjOAbMJNiylH4BDJbde+qzZhJg+UR69HRCGEHIRlQS9UZiKAACrcRZpGIqySbjLZk0HrewoHecCNJ/MAIcbFtyhDoaAiUv/yqisw8XNpz/n2/+VHjo//y+7/RfaueSodDaeACE5kKtFkELMFv/sfx9jTIWQVXtmVVYk4I21MnidlnabH8OE28bBwiCcnoqkcs5GQeG06jUWYdogA0vRb2/+m24v0LXzO827zCDtouZASMWnbWruJbiKN3z+AXwzNWDhuFnO6KKP1HBWfvdg8dbGEuIm+bLetd2Z6nOfPtw93A0Ioaw9PkDAcZuCkFqVUlEZtI2lJYpQhjQLxOAfh0aTf9lehxlAE2XSfhhSAKIt0J6hwWiKAzPg8Ir/CVaKirNXTEi0LgJj6cZ73j+Mf//3nP//72/uPOwJ//PH6/c/b3//bt+9/e335dmMhD3+8n33b+o/7/f3OGmYTaE2lis00M0MkkUik+JmPx3GOE7DFHA2iVIQ03RCETCHeX7pHEllmEuAzjscZ4RGu0prqtu9932TK8xHKhFtYpkUQEyvKamDRPiIqZtdtujsYVhmbZkSwRXXJolt7ZD0Wi4Od6yxahLd1XxRgE1ShtbmajDqlxpznOM1mXhw8n7FGR6rnNmbA0w87KT31M9DDzNxUa4+PvMwuMwEiSmE0pablwFVfbCEFCSRR/vpxCKC8yO95Eex+Eb+LYr1KOxX5wM3meYzHxzyPCnL9VQxoUcwDJdicle9Xx0nKvy/ttcishcqF+xJdAaW1kCzxc9+0Ot/KxlwHdBV/kTIgXMkrXjZg4QFwCoSk7PpQeHVxnuecNFfxZiL5bPyBpSP6hT6j9lz8b2TtKIoVIGbhJmZR+7pEcqFqdQ54mkU4Z9FIktwRUaQ4ymU+XfZvycKiVAAriACPSPPi8KU72YQHFoxG6ZkWzoyV5ZZkBre6lqWbSnZE1oiQJU/4fHGeN0R4lK5tjHkcR42tramKrp6wgiGIIsp214mogblX5Ou1Zcl0dwZTwOasrgsJTgjgnjHDEFlNTRXGRMRRCRR5MQPd1magjL6Ya1F2iTbrbvj0cnebNW8RE6UWMT1hXuARlgMjVajTE0Zb3RsBlMypCkaldPiSu5b17HIayNqE4
rJSWKoT8syZS6ZQz38hsHV/0Ro26sF8mrtRSSN/vYum/Mf3rq1pGZJ4RBgimlacrTITKFjydmvbrUkTEjbHsBjDx3SLyMwi3sGTAQEBKaAUImXpTbomISJWagAyOZIdHAxh0q4Ujd3YJg1Ohq2zp7iaBd4zSSkqPeeZNvLiISATx4//t9Mxzp/2+Av2wRiMk3EykXCSmLIpO6er2znfaP7BrgwW3pWEUOiZBab79GnjiDlSK8rcYNMP7wYVSWbyFEercD9iKAvrztSTum7ft5f/1O1GIvN8j/hwC5t+lfYkBhdepY11Jz6elyOLMRqRmar6i5tGREQqqk2rBtSvqrStba+319u21RNDREsKu9xghEkyyCyTgyhIK5OZiGkNSQi6hG3HMd758cf37vNWwpbKXQRAtcUpGbUQaxZhU2a6Leuf54tXMK5FJAsTNEPmoMfd//rn8T/+z59v/3y3Yfut//0/v/3n//HH3//bH9/+uG0vPZFjTtVWgfBEOe1wNzeHA0xlq5uZIhuLRmoaxonz4YBXNJeQbH3besc3Gec5/RCF1hZRo3WlpMd2lNwpESK8bVsvErR/wd2KS+ARBKZYMusIVHdrBY24lfjcwqabzglKWxbKy3bg4kBxudDFdRhebrJrGKbrcMxrxxclFlz+u8ec43Kf4IIDwpeFRmRO94xxnw+E5ev2XHongFzC/OM4Fik5osZO4ue6eZkkPVGCus2etTAX0H7RAOpH5KuYMdYUV79tbQQ4Im2c47jP427nsThkn2R0tMQpZjZsjkrOSmQkRfz70l69R5Gu3WNRDRcCwOXIsjYZZVlolrmKFbOIctta21tv6u5umeGIqNDs0oAjy38io6oopSgzpFm3afOclz7qKsZXE1Cf4FNlVwcvK9eO6JMuEEiYIRNs6Ubm4k5ZQZpIqR25EK0OQLLavmV5gDIoqz5gzZAAMQt4WtHxIquWLtV4fazsQearzQSVbCWJk1fjxWFUYSUXYToik6KEc5QBpsSXq4NImPs5xnEc98dxvz8e93t1o+Xloyx1bXmVONQ8SpRmfh6jQt5bUxEtuJ1BW+tNNDPLFLFm8bWonzbdLoQ6I3JWCPw0X06u5TnAokJMLStHpm17b72xCnixRD7fV7T4bUCRznyV8PxUZmvdhE7M5B5eDzWVw1ZEllukCJE7hVGEI5PLcbsAtEu7UTDroimHmc9lEEREWSYExZIqRe8zTxb0NKCNOih+vZHW5Y+/37Rpa5plwWQS4Y1JhaW8fjNYcHvp+0svv1iLGDOOY8qgWmyQXNbntQOOLE+EIIAj0uokYEFrmpFFCzNzpRVXdFOhTSg569Ccj3IsJ6HWWNaRk3PGOCMzXvqXqf2vf/y/JN+ICWEeswZ7JlNOliBNp5gWHCfzX2lD48hQpIT3cWzl5f7a/bXNapOG+Tm5QSSNcmoabET2prfWtr5/b9v3JyU5q7TLxtlItyCZ7mQz59FykgrL99UekjGntl3aDuJakf16NCIrwwJIJl4s/gJSqJyxKs54OZL2tr3evr3eXvbWZ0yuzo2RVFEXIiqV+x5JQU5lUG00z+JlQSdACEfT/nJjO3E8zve3+/uLKJHQTmAWRlJETJuRppqCDA5hZmoskAYC0WffB87khCeQKk21R9DxiPcf8+c/zp//+JjH/OOPl7//x/f//L//8cffXl++7dq14IcAUlg27be2zw0Pm9PsmA8+Koo7khIsTXXbibe2MZi3rbud7sPdm/Y/vv+96+4zjsf97ePHcd6P45wW5oik5eEmix/VtPfWmdjNJb4wgyKzDubCniPCE7YMYy9XuVKgEcJj2hQbRMsvZ03tCM+gpJKprVqOvI6juB7vyLWrqmTKCggICxs+jvE4jvsYwws3JYqI8FIr1Zng4WPOx9vjHTHzpT/LZ5klbdlBPGbcj3EOW30zkbbet33f9966ECj9WZuLsbfOuSTG2ssvVXiVyVgL8FX4104iL45Quvl4fIz7+
3jcbZ7IfE74uGDITJ/jmOdD245kUar0Ttcv2/v/lUZ35VeXfV1c2AYoQIgFb0cSJRar1x117DZte29dmancCYoYwESXCKhMKotEeh3pq06oqJaaLj3cgiif0MCFZFytRbXSQpevCSMpr/eVwLTihmc4eXA5EWYATsmQkk0QMimTI1FWEuaRxYazxb5aF4BqnqbSgXviyp4gFmQSFXkbiISv8IG1WarptxYf4YjA5dJHz4uWyREUnn7pJJ4vQqlFfc55nudxHMfjMJvhTkTK0lpTEVpMOTCx6kKI3SOKvTJna731VlvK4pbXskS1DD3SvUyXMir+AsBzM25mNqdNcyeU86s0bixc0sQyFtTeRJadxP+OM/CcxUsQAHgmR4qHw42i0h7Ke1mFhARcz0tmEJwgxFGETOZ0Tvco3IGVF+bh7IIKDFx8AYq8BpSodfiSsy5Up6b5fKpeLwKrW6zEquslTV7/2ERZlJEZQemUzmW/Vu+uUgW0MyuxEATM4IQ0KAhSZpmlDL/09LFIj+XVmmlVBQlcmNdi5sdMIWUWqGjBFBQRZj6GjVqKMESuLt5yTtiEWz5evtxV7x//k/zHtt9UhGCgSPaEgwrchCOGJZGpeEa2dE7N4OFtWB/RHNLDgQnKDAonc4AQICZrHErhIOWbqjZtvW0BiSR3B3NyI1ZmBXOk53T44DiVQ3RTfUkIAMpJMG27yFYD2teTKuccNcIpxwI2LzerdY/VnSekrFvbb9tt73tX8emglMayiQTDkwQkRELMolS2pJ6Rc4R7zBGtU+ssSiDuun173Q6Yj/Dhx32c2+gqhIrYcXM3G5FGpaBwzzrPk0iC4yt1i6rdLTQJ4TgPi2lv/3q8/fNxvJ8q+PPP1//8v/35x3+87i+bqJQzeRTSQiAhadI2dW/p00ccfvad2qaxaJ1N2ya6AUJEKmJDbLCZ977/8f377faageO46Q/+60f+669znB4QQgFQTaS11nvbVhp3wIdTfKLRJXxZqzBAkajlvHlYXCXZYxGJK8vKJtsAYk3FKFepoOXahmVEsmzKuZ7qEI+gOmWHz2GnI0Hs6dNt2Dzm4zjvx/ExxlGrfq4QSA9bERxpET7HeTzu9ztiZv7H830cx/jx9gC4b+3t43j/OMa0cjYkpqatt66VkR4Gn6XMy0XIW4r2+ldce/lamf3aP9eEc1X2vGTZlOJzjMfH8fFzHvdKVP9MHLkMPDJ82DxsHsQtSWt76l+NmP9XGt065nLxHLHYviI8HRVBaEYziVFOfjanm8vSXncRcpt1UmF5TYuyFlVlEaNKHuu1EneSBNEK0y2Bm1lWHlCU6uZpIJKrvatawQt1KUOv58uuwPWlOqsbjhaGDk9acgnKJHdYrUyHx9oGoVb+T1QkInLW1iDFUwQsYObeJZXD9QTmdCyTs2Iq1JciXJm8z1UTyC+CyHLnQcJsqWGf72Jtf0REpD55t9qFxyKdEalI01bjr7tFRtVHLMv/CushbWfvvffeWivcq0pyXpsQMxtjmBUcVZgOR8ZqFRchnxb7+IpkAJ4py7QOq9Vk/3ZTrYaT+ZfjEAFlX1E2UQRoRohWbyk1oAuTkLtlmjYSxkLfKMFBJe8jzwxelLpEmRynSET5YjBIIR41iNdOS0jpydXzSHNUwtpa8S4LWnwu7aQkL1IMIGQCKcTURIikrChXKgGmmx2Ok6qBWXsEkCiLat+6Ni0aUma4hZnltJgzLnwHSR6ZBrcMS4MxhTdsik3RRBsxqHnsHjE9AnTO6VkGauUcEmO4T3dLs2sdDwA40Yk2pdYYwsEyOZ0QjrJxD8lshmKBMahzenk8ly+OiEpvWh4Bg2NKkKCsRVKYe+MbqaZqo0Yu8VCo6M1JT6tABw9MLR8x44xMG8wmrcv+nbc/kzdAMiZiFqHhKhkfn86pmDYpwSQZsQKBPRMIDp+OSEaFSOutbXvbum4qSpQeFuRt19u3HRJnLQvhDiNhURaiSLIY45gJZ46Mr
tLAJEy3Lu0PtQ4fuO0kJDbm/QPEApS1l4OCJEvb4QGbPs2qtWX+6iEPAjgd0wI+7ERT8zP++sePj58fnPj2sv/nf/75x5/fiDDHXDt/rF4ExR0hB0ObUuzzsHEvZJFIhXsjbmsEAomwqIQLS1NI095a0yaZuXH/m/wZ6e/vx/lAOLXW9207Nm/tsW0vr7dvquzmQUFJRCGfgGArNglfppMR43IiWXyQ9V9JDPOYPlFTe81RqKpejsJlzFC+tk41LHlFYMAiKIWSTht3P+qJHuGHzft83M+Px/F+PN7HPC2qKvOy05l+DpvuNv0853GMMabUnux6/fXz/f/zf/6Pv15fetefP9//+vl+DgO4IODWpCkjpp2Pktgy7SyC4uRScVYXoaou7mKh1ROYz0U9rqH9Saiqe/o8P36eb3/ZeGTYUuVfsnYmZm0sTHDEDB8RsxScFvFVUPe/GM3+pjN+vsrzq05jMyNLFiwcwhyZhcG31hJeHoWxGJisLMqSS9mVlQ5SEIl7xHTy5WDAxVReaQFV5UPKz/aiR9Y5+fzh8JvPC2ohX6h5lCx6va9fSQSFt1CCIskszTBGnqO8yq6ivv5eKwOiYIc7pKE10uJ4MScEKURlVpdrH15sjeRatj9hTlRvSkllMLJ01RQBmxdX49Nr6ToKeLw+gQVq04XaMDFxICrRJ1L5ootf+6tqkirW2iNizvLw13AxpgifZjbnXPJNei6WmIu4y7I0kBFXCNN1b17fKy+K4++owqcbHJ9uKVp3YvEnIzyCeW1YymqFWKU1jUgPqFAhOhfiFp7FAly5MFSaniSmdpFR148qIiAOeFlMPrkvdaCve2CGz2URXDABEX0mzxIT95LNlu9ROZqxMLEgguJKcXAvW+BiN9YtRSyi1MpJGXJpCJKSKai2GNVk1E/GbuETNcsXwco9s4Ov3YaCttZe9n1YBEAHnWPUZzJnzOlL/uOfTcPqe7akHsnhJjmRlgjPSsYsKSfICcwZQiRFHHVK5mDxJgkhkMxg8nRzylTQTARIQCyysQophIAMP8O1NSJqBrZlEVHeSDPSMhkZJELaue3aX0h2kIbPsDPtyBisjbUz//jyRiKq11w3nme4e40kDMq1mu/S9rYXPZ6IImPYGDZIqO89KJP5eBxVOISjNQFxBsXgVZgoowMp5RqtIm3TFMmNmQPpxzHGPFZeFoGFWxdlziT3tMhIT3KiVBFG65+kUcoN3JPCI80zpyXzfMz7z/v5cSBCRXpTIR6jDEyCyk1dWZqkIR0EUW0CAW0f4z7t9AGTFDCrINbCtXhqmShfPkpS2QqqTwptqvr6eIxNf36k2Tk5E52Fag+/bX0LuM1RzE5cuqg6+mb6yk8gJGhWafdYOzT3dAubkU5BHjZspgwi97BMrnklaAGhKHA9ImJShBOcxFzHdAU7MRKHnR/zTsQi/Zh+n/N+3j+O9/vj/Tjepx2RHhlUzCePc9jjnBY+p51jjmF2mVo+X/f78Y9//bg/jq56Px4f94eZV7lh4W3TfROFc5w0GQLWZGpCUoNigPwSpcS1bF8MPKqKg9XILI4d1QIZABA+znF/Oz/ebJwZfo3rQJ3zqtqaqAglwXKZhVkUCROET/EEX0q7ivTWwmsyX26suHh9Ze3rHu4mDiSlL898YtZW7nIVhTIrU0WIpWAaLuX3Ep5WAYiAr/0MKJDuIK68W2Iu0PHXZF7LmaKCZ3B5KJAg4NOBX6hPAuZkhkpvFIEolQH0JWqlwj/cwzznhBmmkTlfUQRMxOkZK5YdRSRlgTbqBZIyl2lvRLjlnKsRWbY2fBExndy5kmSqLUE6KLQhG4toJrtHRo7hbr9v5PPqeHjFKbEIB7geUHebxkSkorie2ooo/lRT6wOkiBhjjDHqxyxf98lEgC0/yQX4cEkUo9JwSFWIBQR3fxynmeXKTacsKlNmJso5cnm//AYrLIMHrruargGelVuXtikroyjpTTIDzhExhjNDdXnJE18cU
aLMMPPjGGaTdIXDERDmGcRlC1KpPpTPbUGyMpjBnj4zbMaY0yOj6JOeYZFW5naLAvW5C67rSWUce/mQK1eUC+UzRCET03JkemHkEZ5JYA734hMlt1JIFaEy5vQ5bMzps3oBSo85fJ5e83IsAxbPyLqaykxgpgqq2urih4eNsJlzlrPp89P+cjk6awbnOGceiAflWCoMkFJkkGCtDiuRjEmd2JAh5bWaoJzOY2wUtQAzZpecHunRg4W0CalBZuRhTmMQg0hyJiJBEYY5BdQTL6I37TtLs5ScCDLRwewRI+wR9iBM0f/oL99Z/+enuwosXNBpHZ/VszpTZnJnVm7a977t2771ra26nubz4/H4uN/PmCDa+p5J5zncK6w162FDFhGnVbTE0jhSfQ3PYreolG7Fzof5GRnEvO/77ba3pgSd0919hgVcJIkwKRT28uly7O0Wjdj5tBPBkqLUPBAj52kEOx/nz78+Lk8SJsrKNU5BjAoO4EZ7v93IKTXjVHu8pfM8IilI0kZMNRgi6TxsjkAoodZi3Z3NQBXClKi4OM73x/t9yJluYxzKrMwAfPp5jF22pk2In096JsYKlK39GVmsGwJeNmPmNsxGwkE5bQ6bISeRmntlYFaPW8Dv1a5Z2EBMkI3kA52yZSozUeZjPN7HB5KF2ml5H3Y/Pj7u7+/3t/v9bnPUHFXkrun+GKccRwDTbHrMSA8wfRkOI93mnJRpPI5zzhGrxLKI3Lb+7da+bfSqzjzYk8cp0bXvILFAEBQcRMnskbE2gdVm6qL4OiIvj2uWdbbUCTJPO+523MOsOFulDyuOuqq2rbfehCtONTIsfKwkNPqyDfo9+a1v3SPM5xgcizZ1EfUJBcD6kv8/idG5+E0VSYIaNglSRLc6lgmEvOK51pspi1hac1rtFGjNocU08mfbU3NdlmJ9zazEhdxfKbDP15yYBvek5QMgRMkKShJmsFQHOW2N7GY0J5vBPCNygaA1AWXplSFCkgCThHiwGQUhwyPgRmaXHS/yuYOhFVJGEUvwtn5LYTGcXt8iECtM7LfsYGC5qbAWaNmatZbCZZdY4HRhFJ8I9cX0XmEKdJGEA+lX2n0izW2JwWJFXxHh2UfVlwzPJBK5otvCp9Xip3gCV0Za/TS5nIX/17ewbW3b21X4CltJVm5dt1vbbx1MkVm5lgCxUJUmXCqSAsU9sqxtCj4YY8w5yWrByFwS9kAZtU8zUIgWX6paAy6CrFvY8qqiiDT3Ku2wzFiETPpqRVefLRBEXLS9RfMQEmVWKsFObUjK/IQM64GxjEhH2sQcoTNIEFTbD3ZPMzfzWQo5r40ezdPt9LLFtDLJLG4BMSU1ESVJkDBvrXtgjvAZc8ZIR+GVuUASwpc3onqLPP080wbDGPBUS/EMJp+cjUkZy7KcwEwBMiAkwUFkmdNdPIiDJKEMpiKhljkW1aaymCsWZAGPVHJKp/DI6eUlQp1kEgwtMilckyxpZEJUMkbEBJyZtPW+v14SdtSbEubISoR4Xp/rf5GoaG996/vWVjoqESzsMc63j48fb28TngLRdtkUYk1Ty6aQSBpxmMF88pKGJIUjmEISJOBSRRcx092IyV3d0w3IHMPPcZ52Rpo0qq/aWP5++0TPbK1CawUNzpiI0+YxbbhPB+I45l//ejOLCmAmzqbZOlWCcEIIm2pZ6sLYRAbraebTIkft6qd5cvMEjzPdUQZiRAqoG9vEAgYt3RCWPnwex0BmWnhUn+vT57A5bHvZm26cil8Mp5xhTkyc5cDtAVsJHFYguNnpfiYCghlz2OkMIg1P5RbhkWaZK6n68g9JO9IPx4EgyS2in7MB4TGOcb+fjwziVAseno/zfr+/Pe4f53E3GyUKA+AZ0+c5TzkfECkT2mWJ8/U5V5Ft066yYjmvlXmRX1R4a3Lr/NKIJUFWZlWaiGCfljW+lqlTRCXd5TSwQDQrMycFUFIhUsQ1mdci0qafh51nXB5x9RMwaInCexdmh
MMNYQijNCwO07/H2lW19+7uczZmcl+5WGuK/bQMLQ+OxUCqty0iKtIUgs3diGO6JCkJJ1Fk+aYEMmsGvzqVwsC9PDZtEnFrvXy7CJQRYWk2KTmWZ+8zOqSqCy0S/XM1lBgz54iadURgjjGjREGqpELuMUZtieDB7pgTYxapumDbghuidqHMlGAuXJUVIHPPcJ+z7pGolHBKpiLcVXsgTMvDIsoGnK50nkQGuWEiUZw7p98YA4kEQoT71rat7VsfW0dsTNBF20JYuVdGlOGaFCyi2lrTputE48x09zHnHKMwmaUGqW1v3RORiU+2hVl3XZX5VGhkZDpyObzSRQ0RsJDI1ZGVmPXXEUz07c9vVgpbszFmRoKTlbZbv71u+62ByS8JCZfsHT0zKiaxlqJPUmNEutucNku/V7COXdYxARs2jnkcIzN0BXExsxCWc1NYpDkLb9uNbJbtS8ZF2nteg9/Qkcz0YCJhFA9AVgZPBU0sXDAQwQFOUrDVjiF9WVAHq/uwZEw3X9qKrD38IvcXspzsM2Oi5BtjunkFdQBBadi0dU1mKQSua3vZ9vA0L6PeDL5Mn7+8JQCU/e8JtuPBlhszc3cSj7zPETmFqEtuLbtkQypc0xwSjKhcK0yAmIUaqR8tH4qgzElwcCABK/MBrKe1Ce/CTRkEQ47jNJsOn8wmHSAHHaJD2rLvASG514MkbdOmbXvV/vKltBfOEohFSC0LTlXh3vrWt73ve9+3vmlhSkRJOWy+Pe7/evv5z7/+MgQ3eXl5Sfxa7RRix0yq0oh743PQORKIMY5MzzRhFe6ScKRW191fEs18ZobqRpBxRsQ4jvN+PD4eb+ZDWnlRtFvn3H9BVMLUmip/i05kdH6c//Of/9ePf/04zyMykXQe8dc/P8YZ235T7czRWvSNepfeKjyLNEVSps1xmgWSOcg9wocNx2M4d+m3ztrcCancOkunZCSHwyxL/++nPd6P95/vj4+7T/OwdK/wD3c/jvMcc5wBSGsbuXw2aZ1hce3W4jLoLK6ZzWOOD7NHxABlJJmfpz0YTqQ0AQ3zc84DTAlaDo0JxIzxSL9nvIcS5cs5WwSNcTzO99Mew0YawZm4E7c5x/Hxfjzex/kIm6U7AaWFTRvneLR5MHouGhlJU/1yXOH1Zf9vf/+zdyEg/6JzTpr2C3aMiEgR7b3pWgODOCnHmGYfj8c5prm01veNkOnTzvM8z0XG1g2ypb6QvrBsBC6AD0RccnGfMafPS5G/EHqASneuTUSQMYbzIe2gtjEckPy0t/7flHZmUtFKeyO6FMVPpDiWPiGzTHISwYtYLEv9SbTuVE4EOQU4lglszYQLX11YMeqBLGizLA0uV8GLvke1HqZfM+UC/6uSFIkuP88lmThPn3PJFSKJPC9yG1Sz6bJQKOOTSPKgOXNOLBpdIjPdEbU5yWSBryvAcn1m6TlH1DVILCR/3dtY1B8i4pqvc7H2kMu03K0IBKgkUnf8b5jlBBKosqq0rltXRFehVr7iCRvTpleOOKizytY37ZVCyctPjSi8AngiK7mTeYyxNsXXYiYXS3zRx4sxsOD8cjwq6Tw/f4Eu78Gm0pq2JqqkwvxVFIPbyzZsn3OOweAMDyLuW7u9bPut911BML9iWglES0Ja25O0heNf6he/shlQ3OKSw4IDzOlpw+dwG/UnKK5sZuRC2akM3RtvWyemyCyZTiH/Tw7rL6tYADUmnCNDL0iJPZhWt4Bac2Rc+QmfOgW3mKeXR7cot9AkTPNpPuwXHr88iIuNl5yGdMqgSEwzz5V1wWAERc9QqCaLgIiTlLiL3rR7i/V7w+rJ+l2LuOLRnD2CJagndcscoOkpBIenVBxgaIaQO8GqZ0MQBhBdSspwapwKIDFAAzLQHD3QKCW8eA1PpgmJaASTahq7Tw6HTDICZYZQipCQbsw3libUmJo2aX1r+x+iN7rGawBEqNJOyNKtCRdXjLtqb71r69q6XO0tExCnjY/j/v64v
3/cHaG9jjtdxo7rii+BDFOysDb10CJ7urv5VAFaPQrBitZKttGiZFzBbnicxzjtPO04j/t5mE8SUm29J/YvtEazQe5Ct6bNp9vw++P+OO/UqHOvZ/RxTAtsA6qu2lqLNnLr2Dr13nvXDLj5eYz74zjnCqgK4AK5g8Mh3EgyhFDrmHrXyFyu0mE27sfxeJiZqLx+e5lzVrBrIuYYH++P45zHmK+nuX2xSEmkpzkWHfNXafcBf5h9zPlu9ggfYCDI/Jz2QBiB2ZnCxvF2dIosA7bC64jSYUfYR8ZHBhFHJh+nf9zf3j7+Ne2M9DTAuOmubQ+zcX+M+7sdj3QjAjElUYbbPOf4GOe78guJJjm45sLf1owE4khOXM/wxabOiPOcHx+Pv368m3ndbKrcVLRJeMAH+cg5IkbSjAyz6WP6sMwgEOvgZkwqba+478QijkVmwVQxTv8CqhJqk1T4CeGJ2tZqTNbE/bvf2W8M+SpFJWP8xEe/OpaMK/vFV0DeZRqximq4g7LsfJIkLXJGkayIyu+/dAGrPldMZ8k+nlK3+hlSIKkJJyZtjcGScaquXTxBVFpTaowrfOV5m53D51ywWdH3auE5p4tQa1KktFXIwZlsXvdU9YyZUWxMKiUVZhJjDptmkbltIgxE2sSls48SM9WUFh6ZlMFY2MHCAldcQv1CYOFLVILARan7fHhRueNQliKsNUW0VdqJKWGgkKVJk6Z92/b9Jip1dlfqa9Qo544IvjwHhS5N9fVk1o/nlz4CtXxgFmmiKiwR6douz1a6foOott62W9t33bq2xs0+1xKCdm6bQAABCcKDwftt21/27aZSAG+uvL9wsxmluOm9qbYib1I8RRJFVgWTRF7JLpnplLm8ctNJK6yiPtLyIqiA04p+lrJW1OpS3GOOCcHVMTIWHeDX5TDz9/d7U2l9lZHFgCkqJycxyhI3HVnMdq+wLzsrMDuzdVUVEiFPHzEOm7NsNfLSlmL1lI5ckgsExVVWUabOGRSdekIdRBwRaSFJu/bcFhE3LDw8rOgpzzsr8/E/8/hHjLdIn9Sh26TNyX1asDF5MAyBsui1pExQEhNnMBlRCHyT3Dk4jJYlOF64pWj4y8xXopbJ5q4Z4RZ+2ji5qbZX3V42UdDHwz1yBsiD4Epwwbu2rvJH611fvguLMGnf2raL7qyVl/rrxlJRBqfk1nvfNiHKDE4wobEyVQCQyBoBEBTnPD/O4xjjtCuN3EMbeu+ZWFKUSDMzCxGwUCJFpLyQi5FaTTwr96a9aV1SFmR29zgf9hjnjx8f94/DvR4xIrBPpAfcB76olO6PN/Lz1iCRH3/df/zzXx/Hh6vdvm/M4mfYMJ+Vuv6hOnvfVF2PqTqanvu+3W5dGzPneT7uj/v0w3IATCQFOUGKmkDSOE0uu69KSCKU9XvEnOPj/j7m2fb+N/07ER3H8f72dh7HnPN4nJl4jHlOu33fPz5eb+31V6ZgZuT0RMFeca1lkSf87vYx57vbPeIkUAbcjzk/QErJ7ETSHx8S+Zg+PJ1EVXVrqpzkZ8aZfiaEnafl28fj59tfP9/+MW0AQU7k3Nu+tT0t/THs/u7Hg9yZGExRSIM95uPn+dihKbRlWmJknr9Nux+P+V//+mCiyLg/Ho/TItaI4xHv9zv+kT/f3kUkklTby237/u32t++vL3tvrb8wi1BkIMzMzjk9mGWLsuRJoszG6I1Iy9+EwQRiD/eHxzxtnm4jVzrahbMTsSyWUaaKdmm9rTAXUGbACf8ea4+yq7zyLS9Gciwd8OfSHghKFPnoUqZlhNusLQWzQjPcLRG4SNW0ZlZc/OtFantiDc+hjMu8LIlFpMxumVPKJoWYSCqOTEg5PdzoaaaZiTFizkCmygotN+PzxHE4kNpchERoqdEuFcIiUBXklohFgqOoiFGkzXRHps2RykSIipdgBgvxk0YQ5fxKxeXPXFKxemt1pWppVbuP6qKWW+HXzus6jRcssnq2SC8TK5GNp
cQTKvry+nJ7edn3GzPPMc4xznHOaU7uxE6sICfRpr33KSKlJKNyZywnsSr/mblUYlXaK7AvMuGsZGtkf9Z+ba31rfWS0ahowj+/CxYSZTBYoMrhgaS+tdZFVFgyg4jBQgL4pWPNKPCarvYSl8hiLQzAJLXz8AQqwK8wZgiLSqtdSUS4+bMHpSSWEk4nkKqUpOat4JtnX0ury/00tUeMMctTq6a7hWNXkLeClZuosqQjLNMRtRQRYiHSBcaYOyLNMp6GegkuYWEiPWzEGB4GLLfARDGcssIRqDR/1RemYrEtLCiyMae27NWLZ0SlIPLnfZDM/8r5PyiPpPKxZM/iIJuSb3BOD4/TMwp3JhVmSTCIkqWAMvEQuyydkAkHEZMkS3LkGtpibX7g4dOYJEm1t1sGzXbMeZ/mWRRKEmFqrfXt1veXtn8ruEFaJ+lg+V08wty3LSMoqLeuKkKL/MZAKyktX/Y1tExBh81jnp4J4lLh1sarinpVKDNzz8wUJdViqnAhrx5JFsA6bajMvMWTqIQw4RjD7o/z/e394+MgEtYFmwCIILOY9uXpmHNkPMRvZPj548fPtx/Bcftj+/bHTVTPexzv43jHPIaNITPMoS2UlWgQjn3v59z7Jk0x/ZxxGIZjEpTAYIgQNXCLTLMxEcSQMiwOIJh5mZ4XrBfa5Puf30tBcx7ny+v+468fP3/8nHOeP+bjnDP8z8f3OW1bz9H1yeW8gLzrUI1Mu8f8OeZPsw/3I2MCnMEZZ/gDEASl0aTz8RZz6PBHpElvrbfoqgoOYwQTU7RzHo/zfLv/fL//eDx+uk8gySEOspPthKXfpx/3HIM8nz4HCI/xmPcfvCkpFC8R5n66HaVBfXLLx7T74wQQGWMMs7X9JapIoRGRb3TPhAdE9LZv319vb9/vf35//fbt1pSXxGixrlhYiLQw+5qDtq31rmDKxS8GOGd4zocddzsPt/kLRKt99YpDFyLW1vvtW99fpXUioYzMytb69/Ew5jGGneccwyPWcjlwrSVq7vdMz5RyBs4obUGVtHS3JIiUfXaSrUWlJINIQEQIEC58spDyKmdUp+ecc4xR9WyNrXJ5UQixykrbEpImugkrwRmglVi/nhbMkYhMTQFAguAMn8YRPj1UWbWUcMtSNxdXqwKmiJLDEU5h5MER5TeR4eTmqiHLXS9U0Tvrr5hBSg+fdXYv9kWRDTPczYgX2AGU4fwKUvTEb9eGit99BRxSUljMYSNcmLDvsum2b0IyxlDV769//PH9+7Z1JE45u+jGYrIC3G05o4SKaG+TRZMs0xNWWZUxKx43Mvhi1jNJ46bcRJSIFT08VlN2vQtUElgraF9ExH9rUBgkUGFRQteICEtRWQ6FTyoHVZN9beKRNqf7nNVvrtKOq1ta7LBauWXW6gMZJBBSbtpBMPcxhtuskNt1CJU3T5g5s7I2vr10VcGS+zHKXgT47CGfWJqJeMJgSBCksXZJLNU/ILksMUGU2rTsa9usZKAYZmsJj9RG2lt9oQhM84iZ7nNYLJojARV3Aw83z/ByKa7Hh5EkCFxkDQYaU7Z2tSZMhN7b58mk5X8x/rsRJiQwpx/hDMtmJjleeGTEY9LhdISq9pftBgLSbHKMJpSNMjgnU4I9l5iNiMHJNDc6TmM3AYSEWRtJN7dwzumdrDWk9rH9OZ2P460DfRNt+7Zv2+3P7eXPfvve9luCqiELM84oD+znu2Dm220v5qaSrmcPzExK1KS1RetdwF2NPMPnaRPCrW9MLFIruhRlEF0ElEUHzRSAhK/+rHiFDA+fM1tzD3MkJ5shHek8R7y93d/fPj4+HscxWu9ddM06xBW5FF+tRRAZiOEjTv/586/74+P1b/31b99e/+jEfG8TIC/XgzF9zgQiIjQiMKc/Tj1tf3lt+4uyBPdcIUoZQKiqNpaeJGnzGIdRpjJzTA5AJVMaK0VRsbm11pr23mtRZ3P+8ee3vjW3+c9//vz5dn8cZxLmmBSf2
8XS/Y4qEU9H2Ez38TaPf875w/wj4kQakhGcMZEjy2pswoIeedAB83uytW0LbxmikhTetO/9z0CO8bg/3j/u/zjPD8Sgyl8vfSEsYDgjHjOPB6aVSHDhQW4xHuP9r+SghsAwn2aHjYdfiRO/rkd6dTr5ayNdz+FSqVyzX4DofDze3z/++a+fr68vf/7x7du327eXbd9678osTWt3SCpFgtyk7bq/SG+BFY0MEAWHTX+8j4+fdh5ulohFKKv1SxmDadfWt9vr6x9/k74nK5LKGh1ZDt6/Xl/+w93HmHPYnFZGs3mdYZcMqM4+ERUAYba0Xpm17aRqVHhFYODKvim3wup9icEspdROi0J5Pk+rde7WDp+JC/2lTLoC7KuLKccGEuInJ339cbiRTcoIRDoncWZQBJfjQQSVhRauJzgziYtmH8RgCChLk1rDR4ngkSvv3T0XfYqB5DLLxhVqVw3U4igziCsMJkDOkiKQ8qbGSiysJIJF8f+6HaKoTMRE1gK5975F5Z0Tg1haV9GKpozMOScikTHOc8Uk2vK4R+3o5wzmwlHKDX855mXlUjBjzeplpc4syiqiNQRzcHA8wZr10T3BFVmBdJ8halrRwL+MdCIoBMQcGeZJFdZU8dZl2hOIC1vPTKvc51wT+cWbSCrT9qVyz1oeLfVMPYSZ5j6HFR3h18XJiEhzYoOSMpiZWtOFQlyshF9+hOudrJS3FUJYdyJDG4vStZpZj2IKmJFCUGrK3tln2Ewzn9M5uD0LERFR+SblccyMNA1ldl4GLNUzhKd7koVbuj9tmAQBQak5A4jSiXYlVqHWuCU3vL72TzdWTh9mw0IcNM2YkC4ZEGBTflEGiYRy7JQvrL014RxZsuSAZRrlCaFabCEBAaQrdyXR0REJSSioCTemJKJkTWJPtiAFWFvrL81izFnBauANcgu0MSOPhwdXgmIdKKX79E/9OxNtrRezs7TQoGSWvgCkGnGo+rMays1jjGHTmKT3PhfOmpnB3GosKxYOinDMK68SCxRBOEVFg1LW1UPlxg73meFsM8eY5kYMbbLtrW+NBYk0o1ogmfiX4y5gFvP8GPd4PD4iZ+u3l9ft9tqYAEfMtEfY4V4EDSsfwBJ1uTvNOY7Tk2V/kb6zULMgOygm9c7b1rhFIsJ8zqBw0DwDae5NcutSGdoRlOitMcu2b/XpqYiA7reXbduFPsIiLFGG98Nz+zS1Z5rdS2/qERY1AprNn/P84fMj/ET60lBTEAwxMkvVEumIweSR8YAGb8bo7nAv49KX1r6RxznOx/39uL/ZecCCo9bKhrC0jMw8w8+ZNhFZIci1Lg6PwMhHpqbet2Sfdto8bDwg/HmmKliyjoCvJf8qTLkYznXuuPs0O89xjvk4z2/327fX28ttv+1bCTP2rret9a6tKWsn6aQKLth2NfkEyjH88T4fHzZHfO4namKkRU8r3cdtv5H2WUmHGbnmwH8/tbtHEY8rYQ8Ul9b5uTyHiNRKITN8Bsifq+ZIbpd9GlVTs+zlK9xw5VlyiamYy8j7ohavk/eTQ0tGBjxDEOEgwNKvdLHFT1vBHr/rrcLJrbD0mCNYuUbiTESUIJtijZa5RDMczKFtOeIiyVYMxkUAAC5m/lphZPnbCWdoJq9UeaLkECnaULKAlMwiKleR0DbprVWFYgIxqYKJ2AA8B+G6rlQ+4WlJyU36bX9REfcZZTVVrY0oi4f7+/vHcX9IcUbGLOc6IMsCyCLO8zweB5DM2ro2bZaw5WAORFJAqfxpuopWJyflKFv6y9rFgzzd5rTwjJq/mWhp5OhraQcgJaJvmogxznRSYSQi3M5Fs7AKj/BkYhW5yHJrk76G3KgLVloUEhFiCl8GhsgUYiGhFDcb83TzWCY9S+HNxMSITHOwgwyBZGciIVDJHZGITJuzOCW/aglzIeVt5ekufixLgmLOxXkHkrV2CQtngiOMbbqNSkBsdemuj4oyMWechwstnB6OOdzMS7tTHXR4e
CRzhK8Vh9Aq7ZXpyJwsyQRVEoGKtFvvg7//uX86n+h99jE3C04SDxFShjIJS7but5s2lW+0n/L3h/xnABSPefwcH2e1UB4YATO2kBIcCKtyK5l3i8l0JpS4WyZRoxzIjVWIWmSzlJnMLL3vZaFBTMnd0Gd0Pxz2xu8HsRA3UiXtrK08pc3mr6eDSEkd4VkmGqbaVaS33ltbMRxMKK/UpjN9TpvntGFC1LSFWaRlbUqZiIVjKUGJEsLlH7CkcZnhXrWNmJpq06asVdrn8HF6ruEhiGl/2XbQfrtp00y3iomMNPMp9kUIY2Snf/x43N9Pt6Ebi4Drkijddo1b2hax75w46TzPEebJwazauzRhZnc/ztn28tzrRHm827jnbWv73iAREXayraBIP2e4zL41BrJpMmc6AU02Eebak2fAEg6GdO1b27puoQCBneyYcXP0Z9mLebwbEYm427QjcgLTzo853nIeZd2FVUeS0hGjTBsQzkhN4gjCIKKt79rD0qYPGzOD990BP4/zfBzzOGIcHMkVN7UOiPSInBnD4MHJRUEVLkzWMs0pcNA83lJszMPGw+bJ2vK3wdK9Vrj5e22p7W45ryN5lYKqOWMcZuN+v//rL9337eV2+/768se31+0//ri9vOxba01AlWESkRbpSE8fMIvMOE473ub54W618Vt69ud6PzwjGdSYW1l610+4ZtjL6O5/W9qrzjFT/cEngQ4XikwEFmlNt75F+ODJMDAt6Xr116LaWiBylJw7o7SXFYhRrO36UBb8yYmksjfCJZEvj971fbPYZ2lYvLCITC40oDZB+eUCUAQV7Fq9qaAO4FXLKwPsIqQnURmUpAi2TbZNQRxOHuYeFKXGf45YWGvk2k8Q66808Tpu63sUOErEK1iHKEWJwL2xtkvQQEkMVRAlycqM/lzaeRHpuGh3+aRg+/LvSaKmDQAlpru7l9DN1hK7KGeSa7LNNQ/DNZVEOJMzSBhKwuwihbEXBh5LqUDCKJ1nccBrYlhLSoVoCe1WSEsuS9fPb6RuC01KzwBbXf2K+TK7KO8e7s7Ezn4Z560bMJ1+RRUBlW/m6Qk8LfaqtDOcoW45xq+syETZEK0eXoSlUVEIEmnTmcBgtzDzWAa9XluDrw9HgWes9X5rE0HlyeD1xs2npIjQs7onJygFCEIrezcRCK+RMTKuFDhl3lSzJwUNdjM3T8vgSHaaAOrimduwk2dj5SShxQplTqVsIGLiJtyZobLry+v2+XLMfBnxzRe/hIVEoSoqytIADd420e/S/6Nv/2F22OPkCQhTMCFGpEeclsPIQSBp2ra2j2C2dLtzRrA4moVggu8j8txuN1EBbeDdqScE5EmSpMwssrH0pGYWZnciZVHWTXITMMAJK9Hj58tBsf5a8luWonrIkqnUaQZRFpXHMY7jGGOEGVaoX7ncDBKIiGiF8aS7M6VWfRUWaQVNZjqSI0CMAhyjFreW44zzdCIUKERCXbqI9K0zsdnighQm519DuM/7/Pg43n7ez2PcXtr+oiQog0hibV1eXxmnNupn2452PNrhFgkIQVVaa23TZAMbr9xNZSZvM9QIFJZp6ZFwkkpHgoWHBQunT0E4r8lEyxdcSJDhjvTg4E2277fv8Tc02sY0YvqPP/94aS9Cv8pHhh/3fwYza/eYY35kTmD6PMIeSCsjAiqusCDCcq5bkLJknUkRgiARshEjho9hw6dT381DCJTMQWyZMyicE5S8fCkRhsCROSI9UfpMElldu2dUF3bOx0dgmg+3M6b9ph8pt8cayr8ugxOVLl1hz9WUL5+w9YrwGDHnnHOOc7oZIb699MxvwqLaiKVWsbVkDBAHGae7U1j65R95WcevMlFgcRQfaqa7EuqEqb26I/2q0c/Xb0azYKUGAdSmTPPafj7JZQCVV3zvPdyFHmUUtuy+Fygl0jvSks2RM9eSkJ4CuatGrMVqSahYVBc7WVVEOYQyF9cyPBAeE8VcXXEDHuZWVMb/Z
XNCTzgk3Atv+zTeFyiahQ4Ujq1KfZOX1+3ltUfkHFVXfPWaeBbvMjlAcZtW6A07no5xNUNxSEmzQJFr3BdlYa1Rs/4PM1jAknytRfC1tAuxEiszEm52nuM4Hu7TLxDofn+otpfby957Lj5egrMc6pFJZa0FwEMjWkZGErNuvfXOEUTOIlwE+GXRA1zE8oyoD6rsZjPJ/KJAogFgqQCQ3rRRrogB/1raiytGJCzQlkQUXg6P6b4SOK69OizXej48a81fC914ktOQFeNeKrM5raIykPB0AlOWj1A1MbEAJcFzJaRNt731W9NNzvM8z8GZBPYZBUV54VHA5/sqAz5SnqZsxMrMQkU5retfdPSUQAgJg0BB6Rkzw/z/196X7EiSJNnJExFVM/eIyKzunoXAYEiAIC/E8MQP7K+dAUgMu7syI8LdVGXhQdRiKWKGZzZCkYfKzEKkuy2qIk/ekhYANxFuSiCPsOk2zWb4TJ/EQV2EGnGiSZjFMJtmnumRanXexwobnnOqCqovoQxnjtCFYKumtKasTLJdPh3thAfCt8CsOs2TgwSim26kbbKAH7R9l/3bfnm0u99fZ5KjiVAOJg8j80I+ZnKCIVvDZYaWbzjFRH8gvbhxThlz7POW1PYrtHfIJbFHkntMzxkJldZ21Q0s5agPJS2TP2nl9jfnsDHyY/B8Usw1BOHy89AyxMTJmEkwFymHVabZ6+ttHCPMoVRBZV7shTSibK0zpILKyr2akSrcWhNuYBCF8AFwpFeshA3PSJ857jlniHLJs4io9aatCUvVuu7v6Gv+xtn05/3HX15ur/dE9v1yfdwhdMyJexLoul/aY9PkS/NjG/f9OI5xe72/3m61G/fWrvsl2QOzqSKFQomkxLg+0g+LZZhQGZmZ8BmewRRM4ZypgHBj5mLVKHO4zxiVsXzdHug7rtvT3/ximQnm7dL2vUfb3oqUTD9e/reDtW2e0+wlcoA9zCMdlMwCZlBW+E645d2oFJUxOaa7l/0hkP56j8F3uw/zTCg8LLSha99ku6ciiNxhSUa0zJwpMuiedBAZAcyqKsopFIlgCmREms/bq8Ugsgwnd7B+nI+o8tbUI5yWF8zajImIIMJb16ZafXNmlstZrI6mdO8Z7sc45JZN8va4z+Mptg20KQuElyqXkMKGNCw9w6kRrUJweWcTEXghD+Fz3m8+Dk5vTBXa6Vl+HcH/ztHOwq1JOZIcKhGFmtbRGW/QRDWAyxuuQhq1IFshKnuvN96Wh3vVN9WBshQh/B0PoEUZP6f5dAKnUmyFEmNFGMXMsFNcXyOYGQB/+GjnfTgV2UVrAxKcXFPt0qNhdeE1WwZna9R33jZRhXu4hDb04HPK/u6Sw6fzXXFqzshe8qgyrIgbqAbbncxpzoq8YgJHUGYQojKdq2woK7MzW/r9WxS3vOSMVaasnL03loBnpJvYpGIzAETpMSsBmRKR5QRrVq38Gsub+YGRmV4sQXieQrLVGq/AWiKAeZobc9nxsaoIN66nEWVoIa0SZs+C4MMOTMdtvN7v05wFRSQq2aTbibiv4QiEpUxbSioOnEGsBFYCI4KTAsKQio4Hey4M3cOTlu9PFrBCpTo42QCnKcLJA0x6q//WDaOkFQsb8TZQf3u5OZkC6eSWZb8ER6Z7JgUzSRJlBDEqKIFK0FZFdbLK6qpQJpSWNmLcfRzulmGlRIYQlIXKO5JJ2/ogY/px+DjMLMqHCBmEKN+4elrQM5W8RLkRp7/vp6+x9QfKST5nWLh7EkOpXfXhl3a5SlMSNXQ3xv1m99sx3TwjyIv+TFEyPIA4kVCVTfXCykTu1J2463fdHqEws2Pcpr8kxLLvdO0IrdTI6RZE0qVfZXsk5lkJg+5NGotK27RdZOFg+E1Xkqf/boEnUm7+9FaOLasxVREVYrodt1+ff329vR7jYJegNJvuTpkAbDrICt4ikCi31kTY3TMGyEVEFNULFtzoFshMS58xhh+H1QhrTs9MTVDCL
abN++2ogFEgtl32/ROr8f467s9HJrWuCgVJOmLCmEOEmgp6V0H3HICzchMSgbgHM7q0xg2qySoJPxAzkOmTaErZHwmRMveaVYiAYDMooG3bL9dv35+u10tTEdHVuAMZaT6nuZmNYxzH3eb08mWuTASmH5Oe7e1QCJvPTqCykvU7qP4j4cTBEkxZZp+xXP+RJAyhyLRIuFGEItnJ/XChSKPMTDi73ydzbm3j/Rrb9TaOcTtyGA1Kz0VyBjBBVkRtIs3UpakOLK+N9PBjwA31/maCP4+ohbeu7j5B5u6Os6sHA6q69Xbdt72rKgOIkycWWcZTsRKrGZetPV7266acETZ8CDI4K0um8idq8hFzzjC3WW1RnkP29ZNVeGv9etlUJOaY91e7vbaiCIHrIJOq2z5+kY+/KUPT5EBma1p0rXCvAVLNGcoWd45RbNKMrJ9d+GQmmTnuhx9mY5pZuK3ZNSVz8slkLXi8dlyASy8a5S9dbxcvaUpkphfVKtPLog0USMuYwUuh9+lbLWwcLCAWsCRLalLv7JkErAOdWVSYk5laQ+tQpczyaHMR9K7VtZ5n8DJUqUZemARJVcAnmdHiaxMiUF6nY+ac6Z4RUAUTh2dSAAEOFkSAnQrbYXyilmN9eK6k9UzPPJXNkDrX66Pk8OkjmZ3BjAyfNn2F7lAuWeMKV6xL7p5j2Dp731GZtXlX6154zRlHBxbZKqxYtr3vW+soDnNmEilYRZTlc+1IlPT6fP/x/AwBhFWZFSoILzTew2u0VNp8EEksxne8ueeAiJsAkikF4RETM7MnJeBEBTme2cjnAV5cuOXCX39Ey7cg3D0t10FumX5SUhedcXHuPpyJECgnyOEISjMmKj/G5byrJ6AhDCFCZEmZkQlmtHJsVPag9PAR8x7jFkvtsiJVOZdIHsIsSpBSJ+A47N7szhiHU5IIpB7sTtLAIqXFglTQNc0ZcCPKvdvHG3K5PjRNTKcxB90zglrn6++27/9wefiuzOlzHK92f/GXlzl+jiN9ClmFXk8vVieSQSoKlq5967tKgNJdM7Vfv/f9d5Q0bi8/7i/3426BEerYkzpBKWlOi4ToLv2RtycvAreNzFAQa9O+a9+LU14iS3yqfTOz3NR10XNAJe6onZayYlpUlCPj5f78lx9/ebk938cBIIjGnJEhjYm4CoWqgZhRvpwAj1Fm59mabnuDCBEyC7YgSqSllW5lGI0goggH0IScliPyy+uL2aw647Jtj9f9Y5kyD5t32/Zt7ztny4OpMaUQtzxaqqZw7RacLFAoy0W6btOmm58nPYMpPezF3S08ueiDrCqt3NMerpfrw37dt940PJEs2rf98vD0dLle930TUTopUbUpmHtJluYYVvzqOafNwtnmn38+/zyz+DLS7xFpZJzJ5FmgXFA6lEgIYelWIcJUrFd0EBUZDmQgS8lImnZ3F0DPuTCm3e7o/nB53K+Jhxe83ub44a/Gk8gJeeZ0OpBIJGmGRqqVuCs41j6SmcPTHafrFvjziJp5a+xMIJqT54mMV0PQRPbeHy/bw2XbtjWFLGu1qOFgpEcwQ0S2LpfervvWhdLGuEVYyWMbq3JVizXaj5xzjjHnmFWwvvXwAER037anxwcCbvfjeH2+Pf8qaa01bk1LGtpE5d8+2hcD0P303azSv8wzs2qJXBoud7NzU+azfACVaZcPO+a8H3aYTyc6LdxAIFSW+SL9vTVGuS4fnxr3kzJem22lrZSel5bLapAPpzp5fzPcRZmHsCwLkhNPpnWH618XgUjt+8lcZOxAxDmnLmoAck3KVtTravkXabFOQc4sI3UqOK5O+jlzzJzlRJRwKhJjJAXK4YQW/7nKhk+GHOX6HjFtHmPc7vfb7X57vR3jrsJa1kQs4eGexxyUVNx0wnlunbhInoYBxWlg5jzvRHU2p4wMBKYaYpaj2rKcr5vPjIADgwp4zwhhqaM9kkhk80Xl5/ikGbvfxsvzPbJyJ0WatMbIXBHlN
bYvVWPxDyuu7MOsfXH1m7IwmN6ShsKDnMgy5pug48PpvEoxLq/ZIqxRhhvNCRJiiBVyMCOXCpZOD6Yk+k29yHrO2NYDWYLFyCRaMUhQSJYLw0mwESpLS7AoAzDLOWMMn/ewg3wgJtLI30MZq0KAIKHMStLK6pgyKuYLSGLh3qVv0vYmnYkpOYgiagBXVZF5po8+Pn6P7fEPPR7FaDvux+1PYYe0/fHx++O3v7s+/IKkebzQmOY4DszZzB8jNUmmxZz3OdNmmFNl4ylSYUJDyEFzbw5u+6X1fQ8PskM5LEa62ryN47W1vfWNoSxCp4fAnMPmGDYpSbQRwcxkHssPgQgibRFc329Ha62O9nPHoEqAIF5QTWu67T0R93m83F9e7j+PeTe3kry6e5Zfy6mGPMvcYMT9TpSYh4/hc4Sq7Jdt2/dt3whQacLJSEt3q/G6EjEBwg0ApUTRfi1BYGFVXC796enh8fr4EYDo2i5937Z9112zwVSIJYUS4WnhYE4LHwFLISg0VTo31x4ezCIlSxHOMrPPSEpVVVl8kHMiv3feGu9dGhQM0dZb31vbVffWLq01OnEyAAUdT/fN3Kb5HLbO+VnvyL/enH6+x+xGeFhSzoyqlZKc4IQVnAYYYIV8BhBv+ZoEigAMPlHaQ+dwAXVJQWYGjG6jP+Yv+yVYvF3u6OISxrDkQCRlcuIkiTOREGmUR2jKKr5rxJ3uy3ejPpZ8Cg819/sxC+asvKs34nZETPMx55w6l+Sn6nUUSs+yeC9F6FGRrqIMZqbMqCzGySwq2tK99c4oB9gmopFkXl17HceLpQeQCu+tecaL+/319cef/mT32771bb/0Imiotn+na3cPKwjmmBV0W7j3G/mK3uxNCkhdjTcD5dYEAGFhZvM2xuvhs/AbJ7xp8Ne5uIRwhTmcPPzC1d7w+kpRSSrL27XNM4hqbJtwizLlLITwbXGxh1kYVLWuVAgdZAmqiIhSJEVy0YvTbDqZU6UTBGVK1XkocHdVBHQ++assSaIkjqQIeIA9z/onp1P5xazgPErL0r0EhJa+fVlgF/78abIQmdPtPsbL7fbyent+ff358nIc96ayte1yubAggsx8DouowQdlJJ23az0hZ19eoDQxGMKs0pSZuIJOa4utEQXBPKbZ/Th8rGTC8vSK9Nu4jzHu9z72qdpoZRcimfucHsEszJ9mosd93l4P8xkUrKwqvQujVBt8htKvbreOczcbc5wO1RARQKRJhZ17msWKGKyj3YcnLyEm1oxjPW98LnorXsvWREhR4QXh08OigHrKt0HJp8WgJqWUXtOe5eLqGZEQZoU0Ea1LWJrJVezL4tWRux/DjsPnCDvCRoaBQsr7pRCHd3yhNJmcYBCDhaVJs4VNs3Dbte3ar40bW0ahHQvrXGV2msWxzY9fpD38QYT2YB+v95/h4yfr5frw9PDwbb9+LyoZpEV2c7cQJw9sAb3HfD1ebXpameWhaWXSHkhQBCO25q23vrPsYgOuaEzOAbL0MY/X2fe4XGTFmSCZKOP++mw2zYZql7YneIx7MhKpbVPpUq35B5sBZmzbVk9PKdbelLq1zTBj623f++Hz9f7ycn9+Ha/DzcMzlnU/MU4jxBXDnGnu5jbchlvcbzYOmyNEZd/60/cngPvee1PmpIyZ7h4At8ooxHrSSudW7LHWW2dW5YfH/dsvj4+XB3w426+XS87H1nprvZGoq5JIMs2MI+ZhKWXYQOQkxKLytvuc+1HR08oesVgCpNpUlSgpqRpEgYaJDYCgqlBNapEyBoGDNYNylf1Yr00CpA0cygHZoJY8UiYBDLT+5w9v+QovC3M4MoKNEExGuWTGSQG4AFTaTZIgMBCJpEQafGKOYA8HhTACIRwUCadjds+ntkfyC3pP0ZBIfntVM/HuCysEJZLC/LGingJkQb4SMk7RLqF96tqPYa+3I5cpYby9mADcbQzcmATwyD68d927Zvatydbler22pnxWMuUzRqtfd
nc3i8hksIj6HD73bd/Lxqb3jYAzOpNOjIEK8BamrlwhhPfb7VfKcXu9bO3x6UkoN9Vt2/TzjvVbhnyeGGi+nYG13xCdvjT5Po8tuGCFboKWg2CMMccxjmOm1x7jzM7MITXYQHJyBjETleSaEhXJoLKMjVe6cAFkpeBN9kqTBkOUtYmoQsjT/68G60z0WQ9VronAGsPW7l3TmaoVKiWrSM0cmcGUWZ6VIFoDFXxEbJfB8LoMeT7cp5tJBtVIhZhLzFkVQmYur4csczGvqII3xP/jzSAPMs9pYZH1THokWYCMZYauDBuLcAv+0K5+vCBJxZmGFuVNi2+kKsJSF6qcdaq2Uoi0CJkeKA5nEtBaYy6SETExWEMQBXRXZURxzHEsR/tPIIpbMf8iEERkNVMBEChuzVLSvycGLHpNES1779u2bfvWt65NWJmpcbI7jIBJqSEq6QnUyCbPgeuncIH1fGfSsmuvO8QiUv5JONGLVWj9hnIKWsUmEZV18NrKIrwcg3lZNIGIKDzDiYyRLJRMANLM78csI+SwzCLGerovGRYhImtHyMgkS9SmBXIPtyg+YhXIFhnDXSYZRjn7rOHxoqwupuH+CZBHOpJFmvRr7t+NBdyI+Hh9phRpO0G5PUI9EQQw2+JUtIAm4pZxpFuke3IkL9cK2cAMNpaNZWfu4AB6krhTxvS8BW3SbvvlLmBAwQ3a3H0ex5yH+8DyWl5cGFZtfe/9suxi5eN+hRO0zHP6+6a7TBG0rtveW28vz68vt+dj3p0sOUqKjcVBrjOyNmNmqdFVjGFzHD5zjjCjcAJQSKW5szlRCidluEVEMqucZV09FpQrwBqgvXftBSG0akjog9FAE93a1rSpNBVtXN02qrchQwaVLWKCVLRvGzMXSlGhI8JaZqXv7SeokKr6nYgwK4uSNINk8DTiSHbjmSyuh/X7rNHGwjSZoVIanTpqwqPwWjOrWnl8YDUiSS1zBlnCiDzJl9UCBRcXt+aV6U7pQCADFiSU6elOI3JkMcQjM4U4Ap27th4ih89fn3/8y/+0MX/+65/vP24xIhynnwOCcA5UVnRV3VI0hjCcKkE8YzFGVk4Xg+zTttt62y7XRUJSV1v0gszlzwHw9KD7OKZvUyO6CpsHEZpqb+0EnhcsncTlQ7larFzHRbjbHKpCRdDJ5Sp9bsWo/YuZe9PL3h6v2zR/fmnm0QR744dLf9zbw6YXxYbUz4yzzwz5AmUXAI7z4ALOv/jUCJ5CYWaRE5DPzOWYNuaYs+Kt4cmWohmROD1BKZnOMpkAVqhya02VF7L2xk8AqHIgZCZTIphZG7Rp6y0ru/wDkw4nbH7OVdceTGu+swYAxdRavHfQElxy+X8sY/hcXxtFKAOIqDBbWh1SZhUFWXs9VRjheVgylxkeJXFiKQFxkiSwwKHzOaSgTw9ZJnnCkzzKw2/JDCJpumNMj6RERHiZO5aLQHm6FWq7Wves5qb3tm17a62t5LIFg0d6hBdHSHsXbZnEWsNujiTmkvHIsBmewsJYUJ9qS0IJ0g+bxxzTy8/rfYXHsnGt6xnkVjcCDbkeIVk5SljyQjBXHhy11q7XS997a1osjcIYmAmZqWEKaRJOqDgVt/Nkp7OKyiIZnI9cFWbgk7RPvq7+22lextEfK4MkSgQR8owAKu6Cu9fkEhkZBkRVdkXnT2MEc0q95mZ+P+yYRVGu6RRlkFekC6gexMw3oWOEledAvu/mpddMqnx3uAfl/Zg2LSJAtKJmRQgUGXN+UlvFvHOAWgM3aU9JmkE26eXXP81jXL79gdC4PXLz5AMwYST1oNaCmjHRs+eL+2vQ3RNGHNCUDXoFb4kZUKILUSfKpB7RpsFjwmB5036bx01FVAIiEAmPMW0ed/dDtEUEKzGLaOt93y7XfXuolpT5o0t2vi2iJa4RqWjLEJGtrI9VzOfr7Xn4EYgi/iGXuinX2whazAZ2BIxs2DymG1EI5cqdAbhYlnOYO4QTlDV0f+PnM
3jOaZlMy+yGmbatbZdWWURmdo+R13dykLA0liaqqk1UtfXWmJnefapLGIyg0KaXy87MxakNT9XW+oaPyZy1PwOVTg0GWCEKaSTqYE8OoyDPdACnn/RrXUNagcZc08o3iD6yjJwic2VG38d7yfjhaCeySFvWkElCjMLJVkSfR5qXszt7kGW4xzQanjPCc+2/HJmpQNO+UeMR9z//+F8vcxzHr7/++vLybIeXX5LT4nau9ggngkwMiGhjKFnknOm22Brl91FCqk97FW3b/vTtm9migS9bjHLaIqqsIY/0MJpm7szYe5tW7LeaTb+xv7HQiveTp5oZYeGkSimbGeHT4hyh1tXm0luCipP/cOlPD/uc/nPfp9t1a0+Pl999e3x8eLhct9ZEOBT/9tG+X79/r1PV4umY7l6T0Rp+lvXx1ltrKsIZeX0yMHrv2peaNDO1edvmvJpNo8UGEpEicwhkwdn8UesFglRGspTMnc7d962sZg7dLtenB9HvDPSt63aJ3DOSwpHv8h4W+cf/+F/DrY52yuRKQSDK0sGvc5PwlkKKta2/w+wLEKET7TnHuOsxXjjNqrLeDVAX+o2zL6/WvMqi9RzVz1ieGvT2TwAkzK31ty/y8+fPf/6Xf/7x45mgT0+/I+Lj93ebk5YNQJlIUebywaa3Ie15KJ2gfNb3XKF+p1HXugMrZ3xxN0SFWZNoi2iX7w9zZrW2KgDXFIbXGQyGCNc8YzHpieXH88v7NICImf/7P/3T6+1lgT/vwAdAJCKqTVTWrJ2LNeVupXePzKyWffE0zxtHy6nCbbgPt+kxT9pdeJ54U0EWq84rAt1SY3LrKl0y0s3d6gQ9b1wV3YTHx8f3d74//uEP/4Xe2Xi0jl6PiLJUZSCwAgcoT1eHkmjUD4mIaWGn1g9ngnABYWe/CuDN//j91+KjrItYPMdVzQVlOT4sD8d6sHkVM48PTx9fcw/NhFMQ4NninDdFZI7Ilxt4ukcmWtuTq+fIJLBcW/vmdoQNsyNslsJo23rvXXhjViYHc/AjRUuk9KeH3/2DXr5HBLFI27f9yvqQ2TwknDAsLRisujELQ908c7hFRtgY9+efqr2emOP2PtkV0afrL29XhMq/iVEMXVUB9H5wkNvoF/39f/jOD/1v3gZimchFIZEaEolCGJE+5zHn9FlhQnxKnSEq29Z733htYgki3z08hQtvFIDcVxJjRJhPQrZNWteaybg7Qz+WjPvDN229QoKXVQLLygyoKhTkLATKyKl6K/PxquY5J2FE4H3XOhGr1StUlZxwBwy8WtpFlT2nyG+r6FsnQMnni4q3bfDc3wDw7cPt2Fv/H//439xi+bSuOV6xjpnwNlePNCO39ccNJFipJP4RLU8q88em2ptWqiSrQNz8/v0+x7CSzK62Cu879KpMaOGQrAzOMt+1GbMcSWm1daDeetE1av3d3/89zsqpgOuooqYqSCxMr66ECu+9XfZ+3fv1uvfLzqoFcONtE0xiIo6QWFQ+xrnzEqCSYNF4pP6f/jOefv+3xxjuXhtWRPYmjw+X74/XX54ePXL79rce0ZQv+3a9bL33puUCLqzvZwcR4Y9//CN9ra/1tb7W1/paX+uvZfH/+3/5Wl/ra32tr/W1vtb/P+vraP9aX+trfa2v9bX+qtbX0f61vtbX+lpf62v9Va3/A4SYuYUKZW5kc3RyZWFtCmVuZG9iagozNSAwIG9iago1ODAzMAplbmRvYmoKMiAwIG9iago8PCAvQ291bnQgMSAvS2lkcyBbIDEwIDAgUiBdIC9UeXBlIC9QYWdlcyA+PgplbmRvYmoKMzYgMCBvYmoKPDwgL0NyZWF0aW9uRGF0ZSAoRDoyMDIyMDUzMTE3MDAwMSswMicwMCcpCi9DcmVhdG9yIChNYXRwbG90bGliIHYzLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZykKL1Byb2R1Y2VyIChNYXRwbG90bGliIHBkZiBiYWNrZ
W5kIHYzLjMuMikgPj4KZW5kb2JqCnhyZWYKMCAzNwowMDAwMDAwMDAwIDY1NTM1IGYgCjAwMDAwMDAwMTYgMDAwMDAgbiAKMDAwMDA2NTM0MCAwMDAwMCBuIAowMDAwMDA2ODU1IDAwMDAwIG4gCjAwMDAwMDY4ODcgMDAwMDAgbiAKMDAwMDAwNjk4NiAwMDAwMCBuIAowMDAwMDA3MDA3IDAwMDAwIG4gCjAwMDAwMDcwMjggMDAwMDAgbiAKMDAwMDAwMDA2NSAwMDAwMCBuIAowMDAwMDAwMzk2IDAwMDAwIG4gCjAwMDAwMDAyMDggMDAwMDAgbiAKMDAwMDAwMDY3NCAwMDAwMCBuIAowMDAwMDA3MDYwIDAwMDAwIG4gCjAwMDAwMDU1OTEgMDAwMDAgbiAKMDAwMDAwNTM5MSAwMDAwMCBuIAowMDAwMDA0OTk1IDAwMDAwIG4gCjAwMDAwMDY2NDQgMDAwMDAgbiAKMDAwMDAwMDY5NCAwMDAwMCBuIAowMDAwMDAwODU0IDAwMDAwIG4gCjAwMDAwMDExNTkgMDAwMDAgbiAKMDAwMDAwMTMwNSAwMDAwMCBuIAowMDAwMDAxNDI2IDAwMDAwIG4gCjAwMDAwMDE3MjYgMDAwMDAgbiAKMDAwMDAwMjEwMyAwMDAwMCBuIAowMDAwMDAyNDIxIDAwMDAwIG4gCjAwMDAwMDI1MzggMDAwMDAgbiAKMDAwMDAwMjg2NiAwMDAwMCBuIAowMDAwMDAzMTAwIDAwMDAwIG4gCjAwMDAwMDMzODcgMDAwMDAgbiAKMDAwMDAwMzUzOSAwMDAwMCBuIAowMDAwMDAzODQ4IDAwMDAwIG4gCjAwMDAwMDQyNTMgMDAwMDAgbiAKMDAwMDAwNDM0MiAwMDAwMCBuIAowMDAwMDA0NTAxIDAwMDAwIG4gCjAwMDAwMDQ3MTIgMDAwMDAgbiAKMDAwMDA2NTMxOCAwMDAwMCBuIAowMDAwMDY1NDAwIDAwMDAwIG4gCnRyYWlsZXIKPDwgL0luZm8gMzYgMCBSIC9Sb290IDEgMCBSIC9TaXplIDM3ID4+CnN0YXJ0eHJlZgo2NTU1NwolJUVPRgo=\n", "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-05-31T17:00:00.931089\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.3.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Main class: forest, Anomaly class: sea\n", "Prediction: image 3\n" ] }, { "data": { "application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1BhZ2VzIDIgMCBSIC9UeXBlIC9DYXRhbG9nID4+CmVuZG9iago4IDAgb2JqCjw8IC9FeHRHU3RhdGUgNCAwIFIgL0ZvbnQgMyAwIFIgL1BhdHRlcm4gNSAwIFIKL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gL1NoYWRpbmcgNiAwIFIKL1hPYmplY3QgNyAwIFIgPj4KZW5kb2JqCjEwIDAgb2JqCjw8IC9Bbm5vdHMgWyBdIC9Db250ZW50cyA5IDAgUgovR3JvdXAgPDwgL0NTIC9EZXZpY2VSR0IgL1MgL1RyYW5zcGFyZW5jeSAvVHlwZSAvR3JvdXAgPj4KL01lZGlhQm94IFsgMCAwIDY3MC4zOTc3OTM5NzIzIDY5OC41MTY4NzUgXSAvUGFyZW50IDIgMCBSIC9SZXNvdXJjZXMgOCAwIFIKL1R5cGUgL1BhZ2UgPj4KZW5kb2JqCjkgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMSAwIFIgPj4Kc3RyZWFtCnicvZ1djyvHdUXf+Sv4mABBq+urPx4t2FES5EXJBfxsyNeKBV0HsoEE+fc5Tc6w9j5TVZtNQxcXRqCTIRfJ6a5aTa6ZCdefLt/8Jlx//Nt1vv5k//vfa7h+d/3mt5//588/fP6P7769/vC3y2zzL5dlnae0r+ue7T9/xv9c9m0qYdnWYvOZ//O/Lpe/XOz+7Tbf2V3/eLnEeSplT/b/LWWdyrbFtNvdh7RPYUlzCDj/GefLukxhD/c7rndDY+P96fLLtQHZlrTkawjbtMxpuf/z//nXz9ffX/9y/eY38XhFgr04wV6R+cMr8ovdcL8er8vxf7uwH75cv/nXcP3tf1+/v3x//eX9fmd7OcLxak/b273b5BKXKW+73QO/LjjO0/z+sly+PR7a5dtP12/+OVzDfP30p0tMU5r3bdlT2OK15Djte87b/bF8+uPlH+Z/vH766fq7T5cb+ZK2KeV9jwyEqeClMoV1jinuJZcWLzCvzFMMc9wyA3EsiHmd5jSXfSn7vLeIkYlLnOZ1XvLGRBwLYtmmzVjLHOOaWsTExDVN9vK7o/gLjQXRzqu1hHQ/hlrEzMStTOsePpxPOBbELU5LuJ+wDVxh3L5OSwmr3RnhcCxwe57yGucSwoFoEBcmhnm39eQ4RNyKgXPBDPMypRRzyOsa1xZ0ddAYprTGvCwOinMFtcUl7HFdtzzvoQXdHDSlyU6nLfmlsY4VMs3TXFLIcc6l+e3cGQlr15Ls5dxyPIh2nvlp52idjld/n9YYbBE4/rVWGqCsxV6/zV4SxtTxgLMGe1GCnY7Hv9YKA5zNbpH2YOc4cep4wNnSlLOd8Me/1FpXgLPv07zvdjAzp44HnN2e9BzfN5/GagKbXJinvezHOYocGPc5y7weW2hr8QBAjNMW57cjrgLqeAAI9mxjuj+P5qoBnJSnZZttQWVOHQ84dnRvW1rv35nWWgGcvEzFzjs7UYlTxwOOnTZrzvfjLLSWB+DYDpFnO0UZ85gOKHYyL3O+nzWltR6g49gmuQTb2hlTxwPOYgphm124/eNFIF7/7S5+NxVhyemIWke8Lv/ZUbgvXYWzm5yUQboF3NeQMd+e41257jr34+NFijePiVMocyxlOV6fxbbDt5sfL9W//+H/Pv/
1Gv7p+i+f//DH62O9eZPmu9nexflhuMHO65jW2/ZNZhXStOVgxy4/K5jTY4f7+ei4h08+vqCY9C2vqu7lrrrXZ1R3TLz5bnzWd+dg37H1eMnd8QVztd2ZduzbZruevTTNbd1vRPZSm2raieCgOFeabd6x5S3vaTOheMJ7Y8y2bG03SSYozhXUvGO1ryhrXubmxu63qLTYGrbdPJmgOJcXFLb5L7uxbcHcnrDfmI9jbHdH+BeeK2guU7ZLii2vZmxPCHBcZlvf9g+nG80V1G4YN+Ns6XZcSw22p2JC6YDvMwVb9mPBSSHOto4/YcB2bWRn/3HuOiDOFXQL02wavBbb3ZpXMn6L2xfbZuabLhMU5wq6mz2sYU72rd+a56fb8ZKpyh6CP2hxrC5K5zytJsBz3MMzAoyLNRrwbuebH3cWIXtlzVztcHtbE5UNI5J0GJjKhz1UqjFCyY0BquTYQ6UnI5REGaDKlD1USjNum2TNFSq12UEHBo00UmigKYf2NKnTCCWfBqgSag+Vbo1QkmuAKrv2UCnaCEXTBqZQbY+U1k3ahdoNTOXdHvqigvdMsmeGZrttyfzSlczD20/6Kt+ENHxAkRoe58VOk62EpZSQxx4en/Fw28pt1yjb7nYw2w1sQd8L798/45wePNxP28MfX7CsaQ9fw8PHxJuHpyc93C6l7FtmK5zfc3GuNl0zHzvG7aHMpf026Ye3ntfpeAc2f3jzGeYKavZTogmNSV5qyoXbApNtNWU5FkUHxbmCmv3Yg7RvYmo+T7cB5hBMZldbKdwhiHP1pve8TTGbswdbV5tvmLoN0PzOZHYt3k9prqB2nR1m849iO0hT/d0+mI9PA/L68WTDuYKmMNmFhl1gxdR+w9Ttg/n4QGC2ZdYxYayQOd0WnfV4YZq+6HbBfHwisGxLdJ9i0FxBS7G9edvTUo6TWbt4Pj4UsOuTzV0r01xBF7uhXbTOswle87rK7YP2VUbaUwl++YS5/MBmn24+YDKbm9dVTsdxzQYdT8cG6se9T6XMQcpSYnhbGpWOIxJ1HJlKxz1U6jhCUccRqnTcQ6WOIxR1HKFKxz1U6jjunqjjAJU67qADHUca6jjSlI57mtRxhKKOI1TpuIdKHUco6jhClY57qNRxhIKOI1PouEdKHSf7Ah1HptJxD31Rx3tC2RNEs962a37puuZh8Ce1lW9COj6g6HfF7RKmzMu6ruYC+1jH0zM6XrL52n58y9yn8bagh3tJQs8L5vTg4X7aOv74gtUuv9avoeNj4k3H85M6XspsVpuT/4Cd5ipaMP9JOZtd7HFp+oXbAssSzWrNJtxRSnMFtY02zmUOdqK137f1OciazWuPQ8xBca6gttHOS0nrHvb2+7ZuCyybHTFzyf6JwlgWIcaI9iDj1u4z3P5Xjs8oluIO7y88V8w9T+tm2+SS1/abt24bXI7PKGxJ/HCu4VxAF7setx3Evvthbr9567bBJQbb445V1uc2MFdQOx3tW29XV2suzUsAtw0ux2cU2WzafbhCcwW1zTYtqy02tug1r7DcNrgcn1HMa8zuJKW5ghrFZK3YdzWszSsstxEux2cUtuIFd81McwW9fcG6R1sccvMkdT6Oizb4eMl23vlx59i1i7qY8v3NsPUJH0ck+jgylY97qPRxhKKPI1T5uIdKH0co+jhClY97qPRx3D7RxwEqfdxBBz6ONPRxpCkf9zTp4whFH0eo8nEPlT6OUPRxhCof91Dp4wgFH0em8HGPlD5O+gU+jkzl4x76oo/3jLJniKa9bdn80pVNu8lZb+WbkI8PKNLHTaumsJU5bXGfy9jH8zM+Xksa0yc7GeLbp5C1pIE5hdh536a4Lbclv94LTh8u3oLc1uBTIfblOQ3vwm4GXk6H2PS64Pj4TsT+7sMhdlqK7bn7nPJj8+mF2ASEqeBxiN3gdUNsAuJYEDnEbhC7ITYRcSyIHGI3iN0Qm4g4FkQOsRvEbohNRBwLIoTYDVw3xCYcjgW
OQ+wGsR9i84qBc8F0IXYD2g+xGYpzBeUQuwHththuaaxjheQQu4HshtjJjhiTeV+uwrjfeh5v2mxhQWPoptjJDpWjseKktE4HlD1Ntn+hJHRDbHvC9g1PvlyFcZ+T5zKVtKAXdEPsbMdD2O0WK3PqeMCxr7fnfcfEcYid7RiYS3aZfJ0OKHGfol25wJVPt8bOJhD78bEZZ/IwHnBMxEJY8WKnW2Nnk4rVVjmXycN4wCl28WgKB9c33RrbLsSnJRWfycN4wFnytN0W4cclTbfGzts85b34TB7GA46d08veXA9QcuKUyuIzeRgPAMe7sfb1cOFyMsRmIWmaF3fYPbNztzipgnSLelcjwskIu8z5fuv5ca7cDTe+GGKzWdVKhZ4XzPHhw900HddF0S+K7t+RYbdsdzmfYbujC+Zqs+MMWwsv5NYMxbmSbM6wtfVCbs1QnCsoZ9hafCG3ZijO5eUEZdjafSG3ZijOFZQzbK2/kFu76yaYKyhn2FqC35JrBr7PFIwzbO2/kFszEOcKyhm29l/IrRmKcwXlDFv7b+2ted2EsbokpQxb6y8u1eS/NbqUAuyiS+3CCEUZBqawYY+UYox7CZlxZUo1dlBtyQglTQao8mQPlcqMUHRmYApp9kjpz4gkgQamMmgPlTKNULJpgCqd9lBp1ggltQaocmsPlZqNUPJsgCrR9tC+c5N1oXQDTVm3p70o4D2P7IghV9g98/TSftJV+SYo4X3G2QRbWvjZDJueGcQp9Mxgjo8f7qZp4S6J/goWPibeLHw9H2Hzdx/nasvlCPuZt50fsbV74xnmCsoRtrZwiK0ZinMFxQhbOzjE1nwA4ly94c0RtnZwiK0ZinMF5QhbOzjE1gzFuYJyhK0dvNbWzISxQnKErU0cYmtm4lxBOcLWJg6xNUNxrqAcYWsTh9jaLZ4wlx/WUIStZRxXbJRxSC6ljLvkUss4QkHGkSlk3COljOOWgjIOTCnjDqplHKEo4whVMu6hUsYRCjKOTCHjHillHJEo48hUMu6hUsYRijKOUCXjHiplHKEo4whVMu6hUsYRijKOUCXjHtqXcZIvkHGkKRn3tBdlvKeTHT/kBrsnoF7fTyor3wRlvM84G2BLGT8bYfMn8bVMoWcGc3z8cDdNGXdB9FeQ8THxJuPb+QSb+wicq1yBE2wt45BaMxTnCsoJ9hMhSE2tGYpzBeUEW+t4ba2ZCWPZgmCCrV0cUmtG4lwxOcHWLg6pNZ9pOBdQl2BrF4fU2oU2MFdQTrC1jUNqzVCcKygn2NrGIbVmKM4VlBNsbeOQWjMU5wrKCba2cVyy0cYhuJQ27oJLbeMIBRtHprBxj5Q2jnsK2jgwpY07qLZxhKKNI1TZuIdKG0co2DgyhY17pLRxRKKNI1PZuIdKG0co2jhClY17qLRxhKKNI1TZuIdKG0co2jhClY17aN/Gyb7AxpGmbNzTXrTxnk92BJEL7J6BulucdVa+Cdp4n3E2v5Y2fjLBPnILM/3754+1oqljCrCT3TBtbz9IAPEOjhsJ9vu93Trh8isn2A52M/D9dIKNr0udhuOb9/ay6AA7bKt965J53mPr6QXYiHsMFY3z6watm18jrk4Vj+PrBq8bXyOvThWP0+sGr5teI69OFY/D6wavG14jr04VD7LrBqybXSOsThWMo+sGrx9d0xpRx4rokusGsp9cE7KOJZKD6wayG1zzQvg+lUDOrRvAbm4d5+PTrtWnqTDuF51xtoPaXl8Qg25uHcPxAdfq01QYDzjBDuY1jEvrmI6PszafpsJ4AIh2AKeA23+3tI75+ARr87/BGcYDju1heQ/4Y3Td0jqW40OrzXfwMB5wih2rJeIFTre1tgu3aZ5318HX6YBiOhFDxGuabmltbjIdEYLr4GE84Kx2jbhGvIzpltZmJ9OaZt/Bw3jA2RZboxNeuXRL6zRHM6zZdfB1OqDYubzuCS9Wurm1ndFTLrPv4GE8+CmFMNuinPD65GRuTfLRtiz
OrTsW529xSvs4z673NCSczK1TXO43L4/T5W6z6cXcmjyq9ij4vDBswUePIUzTZ138/KLU/h25dcts7bA73Vvz4fUYy72Oa2ttt1BVE7KOpU9za60VF5pqQtaxRHJprS0XimpC1rG+aqDOWosu9NSErGOJ5Mpauy7U1Hxp9BhLJDfW2njfemrC3UcSxYW1ll0oqQlXxxLJfbWWXeioCVnHEsl1tZbdmlHTMvmYyutNaqu169K6jLJbS0ppu66k1OKLUDJfgCr19dC+BSONNBhoyoM9TSoxQsmJAaqk2EOlHyOUBBmgypA9VMoyQtGWgSl02SOlOSOS1BmYyp09VGo0QsmjAapE2kOlU6PwoFTDTyEIq/ZIKdiIJMMGplJs/6MPL9p2Rxp7FshtdUczPxj6KTF1MTYa94Bxtq2Wyn22rcZnBtUJPjPMV/DhY+7SVG5XOn8F5R4T78r97F9WhIiavv11LHdcTqufeUP5kVDzW8qPsURyWK2VGwJqQtaxRGJWrYUb8mk69upYvpHNUbUWboinCVnHEslJtRZuSKcJWccSyUG1Fu5aThPxMZVAzqm1dkM2TcQ6lkiOqbV2QzRNyDqWSE6ptXZDMs1r5WOsP3ihkFqbNy3PYN6QTUrzdtmkNm+EonkjVJm3h/bNG2lo3khT5u1p0rwRiuaNUGXeHirNG6Fo3ghV5u2h0rwRCuaNTGHeHinNG5Fo3shU5u2h0rwRiuaNUGXeHirNG70HzBt/5ECYt0dK80YkmjcylXn7n3N40bw77tiTQQ6pO7b5wdVP+akrr9G8B4yzIbU077MhNX2qXgsTfGaYquDDx7Slad4ua/4K5j0m3s372b/xCMU0pQ51LNsD7qi1eUMvTcg6lkiuqJ/oOWotTcg6lkhuqLV711iaiI+pTjqwoNbiDaU0AetYErmf1uINnTSdYHWskK6e1uINlTSXMo+xRHI7rdUbGmlC1rFEcjmt1RsKaULWsURyN63VG/poQtaxRHI1rdWb1mdQb2gkpXq7RlKrN0JRvRGq1NtD++qNNFRvpCn19jSp3ghF9UaoUm8PleqNUFRvhCr19lCp3ggF9UamUG+PlOqNSFRvZCr19lCp3ghF9UaoUm8PleqN4gPqjT9fINTbI6V6IxLVG5lKvf0PNbyo3h157NkgV9Md3fS3OCeoLrNG9R4wzlbTUr3PVtPp+OT1w1+Ar2OqpoNtW3l9y/7rndC4VU0/7u1UMH28JJdTwfSDc5fsZ/+AI8TS8HJgQj2l91dDx9LmvusWosikAQTt9JjDmTRx+oE0gDCbHpM4kCZSP40GEgbTYxKn0UTqR9FAwlR6TOIomkj9HBpIGEmPSZBDE6YfQgMG8+gxhkNoIg0SaDzhsYwes1wCTbBB/IwwbKIFjONngvWzZ1rJagwtUJw9E6ofPNsF1Bw/BqLv01HvarpTcFfv1s62py5L9L8l+X04iDZXs5uCu3i3dy52iRSjT0PrtA/Jx5+gKbhrd2PnpdiGlfzvR67TPmQ5/uTMgj+o1i2dV/v6nHyBXqd9yHr8iZkFr0i6mfNm+6Dtwy4/r9M+ZDv+pMw6zpvDPE/7kVxxfA7j/v3vx1+PWfGao1s3h3D8rZjs23MY9zHHejHHFS8zunWzWbpd8mffnsN4wLG72bYNryy6fbOtfXadX1x7XqcDSjr+NsyGFxNn62a0hbYNubq5bVv+Fqf0jL4e7mlIOFk323n6dvPHWXK3zvxq3Yz6AxEJPC8Y06OHe2l7p2+NX5PPJ39arwu7G+izf7MQm2Y6qCB1FhsZN80jCcWaGWEYOQvh5Zp5ZKLYMSMM82YB4455JKNYMCMMw2al8lQwj3wU22WEYdIsYNwuj5QUq2W6RoGYWcC4Wh6J6XuvjKD3hFlAuFceOSmWygjCgFnAuFQeOSk2ygjDdFnAuFEeOSnUybjMQbMsLvOoTh4pKdWW4KQYW46l9ENVKvwUf2MqCCr8RtmxofrmUckq8MhWK1Dpqv8
NtspcgUjqWonKXR1RaiwQyWMrUYmsI0qnBSJJbSUqq3XEvuDidk2GW1lScR1M2i4ySXeBqXzXMbX6IpTcF6BKfj1UejBCUYSBKUzYI1+U4o7b9WTNRchtG/wg0qf8kW9AYjxgnI2QpRmfjpDhmWHkAc8MxvTw4V7aZuyT4F/VjLuwuxk/+7cEMT3GbzoWyWI75fR4/PZsjY7pDVpokQWMo+ORGWNujDCskAUMc+ORF2NojMcX9sfiDWEOjUdejIkxwrA8FjBOjEdejHExwrA5FjCOi0deDFkxsiA2FijOikd2jEExsrAzFjAOikd2jCkxwrAwFjBOiUd2jBExrXXQFqsPJygiHgkyRZFVkKmJHAvyh/hTCDL+FtMqyPhbXseC7NNEJcjAQ0EGoBJk/1tllSADEQUZiEqQHVEKMhBRkIGoBNkRpSADEQUZiEqQHbEvyLhroyADSwqyg0lBRiYKMjKVIDumFmSEoiAjVAmyh0pBRigIMjKFIHvki4LcUbyes7lWuC2FH5T6lEbyDUiQB4yzrbAU5NOtMH7MDCkGPDMY08OHe2kLsi93f1VB7sLugvzsnx/EQhg/68dwWHwEz4XwSJCxDUYYJsMCxm3wMGKAKhhhGAsLGFfBI0WGHhhZUAmrjgF74JEfYwmMKAyEBYtL4JEfYwOMJw6mwWOYa4BHfoz1L8UgEAULGNe/I0PG7hdhmAMLGHe/I0PG4hdhGAILGBe/I0PG1hdhmAALGLe+I0OmdrEaMqWLY0P+0GgKQ8bfLFoNGX/z6tiQfUGoDBl4aMgAVIbsf9OrMmQgoiEDURmyI0pDBiIaMhCVITuiNGQgoiEDURmyI/YNGbdtNGRgSUN2MGnIyERDRqYyZMfUhoxQNGSEKkP2UGnICAVDRqYwZI980ZA7jteTNpf0tq3Q3+KcR/INyJAHjLNJrzTkx5n3/eX/ATcNo6IKZW5kc3RyZWFtCmVuZG9iagoxMSAwIG9iago2MDczCmVuZG9iagozMiAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDc3ID4+CnN0cmVhbQp4nDM3NVIwULC0ABJmpiYK5kaWCimGXEA+iJXLZWhpDmblgFkmxgZAlqmpKRILIgvTC2HB5GC0sYk51AQECyQHtjYHZlsOVxoAnuAbmgplbmRzdHJlYW0KZW5kb2JqCjMzIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggNTkgPj4Kc3RyZWFtCnicMzU1VzBQsLQAEqamRgrmRpYKKYZcQD6IlctlaGkOZuWAWRbGQAZIGZxhAKTBmnNgenK40gCp4RBaCmVuZHN0cmVhbQplbmRvYmoKMzQgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzMDQgPj4Kc3RyZWFtCnicPZI7ksMwDEN7nYIXyIz4k+TzZCeV9/7tPjLJVoBJiQAoL3WZsqY8IGkmCf/R4eFiO+V32J7NzMC1RC8TyynPoSvE3EX5spmNurI6xarDMJ1b9Kici4ZNk5rnKksZtwuew7WJ55Z9xA83NKgHdY1Lwg3d1WhZCs1wdf87vUfZdzU8F5tU6tQXjxdRFeb5IU+ih+lK4nw8KCFcezBGFhLkU9FAjrNcrfJeQvYOtxqywkFqSeezJzzYdXpPLm4XzRAPZLlU+E5R7O3QM77sSgk9ErbhWO59O5qx6RqbOOx+70bWyoyuaCF+yFcn6yVg3FMmRRJkTrZYbovVnu6hKKZzhnMZIOrZioZS5mJXq38MO28sL9ksyJTMCzJGp02eOHjIfo2a9HmV53j9AWzzczsKZW5kc3RyZWFtCmVuZG9iagozNSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDY2ID4+CnN0cmVhbQp4nDM2tFAwUDA3V9A1NDRVMDIyUDA0MlFIMeQyNDQHM3O
5YII5YJaJAZBhCCTBGnK4YFpzwDogslCtOVxpAE04EfUKZW5kc3RyZWFtCmVuZG9iagozNiAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDIyNyA+PgpzdHJlYW0KeJw1TzuyAyEM6zmFLpAZjG1gz7OZVC/3b59ksg0S/kjy9ERHJl7myAis2fG2FhmIGfgWU/GvPe3DhOo9uIcI5eJCmGEknDXruJun48W/XeUz1sG7Db5ilhcEtjCT9ZXFmct2wVgaJ3FOshtj10RsY13r6RTWEUwoAyGd7TAlyBwVKX2yo4w5Ok7kiediqsUuv+9hfcGmMaLCHFcFT9BkUJY97yagHRf039WN30k0i14CMpFgYZ0k5s5ZTvjVa0fHUYsiMSekGeQyEdKcrmIKoQnFOjsKKhUFl+pzyt0+/2hdW00KZW5kc3RyZWFtCmVuZG9iagozNyAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDI0NSA+PgpzdHJlYW0KeJxFULuNQzEM6z0FFwhg/Sx7nndIldu/PUpGcIUhWj+SWhKYiMBLDLGUb+JHRkE9C78XheIzxM8XhUHOhKRAnPUZEJl4htpGbuh2cM68wzOMOQIXxVpwptOZ9lzY5JwHJxDObZTxjEK6SVQVcVSfcUzxqrLPjdeBpbVss9OR7CGNhEtJJSaXflMq/7QpWyro2kUTsEjkgZNNNOEsP0OSYsyglFH3MLWO9HGykUd10MnZnDktmdnup+1MfA9YJplR5Smd5zI+J6nzXE597rMd0eSipVX7nP3ekZbyIrXbodXpVyVRmY3Vp5C4PP+Mn/H+A46gWT4KZW5kc3RyZWFtCmVuZG9iagozOCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDM5MiA+PgpzdHJlYW0KeJw9UktuBTEI288puECl8E1ynqne7t1/W5vMVKoKLwO2MZSXDKklP+qSiDNMfvVyXeJR8r1samfmIe4uNqb4WHJfuobYctGaYrFPHMkvyLRUWKFW3aND8YUoEw8ALeCBBeG+HP/xF6jB17CFcsN7ZAJgStRuQMZD0RlIWUERYfuRFeikUK9s4e8oIFfUrIWhdGKIDZYAKb6rDYmYqNmgh4SVkqod0vGMpPBbwV2JYVBbW9sEeGbQENnekY0RM+3RGXFZEWs/PemjUTK1URkPTWd88d0yUvPRFeik0sjdykNnz0InYCTmSZjncCPhnttBCzH0ca+WT2z3mClWkfAFO8oBA7393pKNz3vgLIxc2+xMJ/DRaaccE62+HmL9gz9sS5tcxyuHRRSovCgIftdBE3F8WMX3ZKNEd7QB1iMT1WglEAwSws7tMPJ4xnnZ3hW05vREaKNEHtSOET0ossXlnBWwp/yszbEcng8me2+0j5TMzKiEFdR2eqi2z2Md1Hee+/r8AS4AoRkKZW5kc3RyZWFtCmVuZG9iagozOSAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDI0NyA+PgpzdHJlYW0KeJxNUbttRDEM698UXOAA62t5ngtSXfZvQ8kIkMIgoS8ppyUW9sZLDOEHWw++5JFVQ38ePzHsMyw9yeTUP+a5yVQUvhWqm5hQF2Lh/WgEvBZ0LyIrygffj2UMc8734KMQl2AmNGCsb0kmF9W8M2TCiaGOw0GbVBh3TRQsrhXNM8jtVjeyOrMgbHglE+LGAEQE2ReQzWCjjLGVkMVyHqgKkgVaYNfpG1GLgiuU1gl0otbEuszgq+f2djdDL/LgqLp4fQzrS7DC6KV7LHyuQh/M9Ew7d0kjvfCmExFmDwVSmZ2RlTo9Yn23QP+fZSv4+8nP8/0LFShcKgplbmRzdHJlYW0KZW5kb2JqCjQwIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggOTAgPj4
Kc3RyZWFtCnicTY1BEsAgCAPvvCJPUETQ/3R60v9fq9QOvcBOAokWRYL0NWpLMO64MhVrUCmYlJfAVTBcC9ruosr+MklMnYbTe7cDg7LxcYPSSfv2cXoAq/16Bt0P0hwiWAplbmRzdHJlYW0KZW5kb2JqCjQxIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzIwID4+CnN0cmVhbQp4nDVRu3HFMAzrNQUX8J34lTSPc6/K278NQDsVYRoEQKq8ZEq5XOqSVbLC5EeH6hRN+T5gpvwO9ZDj6B7ZIbpT1pZ7GAjLxDyljlhNlnu4BYEvDE2JuYXz9wjoKwajMBOBusXfP0CzJDBpcPBTkGutWmKJDjwsFlizK8ytGilUyFV8Oza5BwVycbPQpxyaFLfcgvBliGRHarGvy2Up8rv1CRiEFeaITxSJheeBDmYi8ScDYnv22WJXVy+qERnWSYcHUgTSbG4SMDRFsuqDG9hXxzU/T0fZwclBv4rB+DY4mS9JeV8FoRCPF/4Oz9nIsZJDJBTyfbXAiCNsgBGhT+0jEGUgNEX37plSPiZViu8ARiEcfapXMrwXkdlqhs3/GV3ZKgoGVVkfn0ZwJoNJOPNkowrTUrXTv/vc4/MHY2N6gAplbmRzdHJlYW0KZW5kb2JqCjQyIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggODAgPj4Kc3RyZWFtCnicRYy7DcAwCER7pmAEfiZmnyiVs38bIErccE+6e7g6EjJT3mGGhwSeDCyGU/EGmaNgNbhGUo2d7KOwbl91geZ6U6v19wcqT3Z2cT3Nyxn0CmVuZHN0cmVhbQplbmRvYmoKNDMgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxNTcgPj4Kc3RyZWFtCnicRZC5EUMxCERzVUEJErAI6rHH0Xf/qRf5SrRvAC2HryVTqh8nIqbc12j0MHkOn00lVizYJraTGnIbFkFKMZh4TjGro7ehmYfU67ioqrh1ZpXTacvKxX/zaFczkz3CNeon8E3o+J88tKnoW6CvC5R9QLU4nUlQMX2vYoGjnHZ/IpwY4D4ZR5kpI3Fibgrs9xkAZr5XuMbjBd0BN3kKZW5kc3RyZWFtCmVuZG9iago0NCAwIG9iago8PCAvRmlsdGVyIC9GbGF0ZURlY29kZSAvTGVuZ3RoIDY4ID4+CnN0cmVhbQp4nDMzNlMwULAwAhKmpoYK5kaWCimGXEA+iJXLBRPLAbPMLMyBLCMLkJYcLkMLYzBtYmykYGZiBmRZIDEgutIAcvgSkQplbmRzdHJlYW0KZW5kb2JqCjQ1IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMzE3ID4+CnN0cmVhbQp4nDVSS3JDMQjbv1Nwgc6Yv32edLJq7r+thCcrsC1AQi4vWdJLftQl26XD5Fcf9yWxQj6P7ZrMUsX3FrMUzy2vR88Rty0KBFETPfgyJxUi1M/U6Dp4YZc+A68QTikWeAeTAAav4V94lE6DwDsbMt4Rk5EaECTBmkuLTUiUPUn8K+X1pJU0dH4mK3P5e3KpFGqjyQgVIFi52AekKykeJBM9iUiycr03VojekFeSx2clJhkQ3SaxTbTA49yVtISZmEIF5liA1XSzuvocTFjjsITxKmEW1YNNnjWphGa0jmNkw3j3wkyJhYbDElCbfZUJqpeP09wJI6ZHTXbtwrJbNu8hRKP5MyyUwccoJAGHTmMkCtKwgBGBOb2wir3mCzkWwIhlnZosDG1oJbt6joXA0JyzpWHG157X8/4HRVt7owplbmRzdHJlYW0KZW5kb2JqCjQ2IDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMTcgPj4Kc3RyZWFtCnicMza0UDCAwxRDLgAalALsCmVuZHN0cmV
hbQplbmRvYmoKNDcgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAzMzggPj4Kc3RyZWFtCnicNVI5rt1ADOt9Cl0ggHbNnOcFqX7u34aUXwpDtFaKmo4WlWn5ZSFVLZMuv+1JbYkb8vfJCokTklcl2qUMkVD5PIVUv2fLvL7WnBEgS5UKk5OSxyUL/gyX3i4c52NrP48jdz16YFWMhBIByxQTo2tZOrvDmo38PKYBP+IRcq5YtxxjFUgNunHaFe9D83nIGiBmmJaKCl1WiRZ+QfGgR61991hUWCDR7RxJcIyNUJGAdoHaSAw5sxa7qC/6WZSYCXTtiyLuosASScycYl06+g8+dCyovzbjy6+OSvpIK2tM2nejSWnMIpOul0VvN299PbhA8y7Kf17NIEFT1ihpfNCqnWMomhllhXccmgw0xxyHzBM8hzMSlPR9KH5fSya6KJE/Dg2hf18eo4ycBm8Bc9GftooDF/HZYa8cYIXSxZrkfUAqE3pg+v/X+Hn+/AMctoBUCmVuZHN0cmVhbQplbmRvYmoKNDggMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAyNDggPj4Kc3RyZWFtCnicLVE5kgNBCMvnFXpCc9PvscuR9//pCsoBg4ZDIDotcVDGTxCWK97yyFW04e+ZGMF3waHfynUbFjkQFUjSGFRNqF28Hr0HdhxmAvOkNSyDGesDP2MKN3pxeEzG2e11GTUEe9drT2ZQMisXccnEBVN12MiZw0+mjAvtXM8NyLkR1mUYpJuVxoyEI00hUkih6iapM0GQBKOrUaONHMV+6csjnWFVI2oM+1xL29dzE84aNDsWqzw5pUdXnMvJxQsrB/28zcBFVBqrPBAScL/bQ/2c7OQ33tK5s8X0+F5zsrwwFVjx5rUbkE21+Dcv4vg94+v5/AOopVsWCmVuZHN0cmVhbQplbmRvYmoKNDkgMCBvYmoKPDwgL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0xlbmd0aCAxMzggPj4Kc3RyZWFtCnicPY9BDgMxCAPveYU/ECl2Qljes1VP2/9fS5rdXtAIjDEWQkNvqGoOm4INx4ulS6jW8CmKiUoOyJlgDqWk0h1nkXpiOBjcHrQbzuKx6foRu5JWfdDmRrolaIJH7FNp3JZxE8QDNQXqKepco7wQuZ+pV9g0kt20spJrOKbfveep6//TVd5fX98ujAplbmRzdHJlYW0KZW5kb2JqCjUwIDAgb2JqCjw8IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9MZW5ndGggMjEwID4+CnN0cmVhbQp4nDVQyw1DMQi7ZwoWqBQCgWSeVr11/2tt0DthEf9CWMiUCHmpyc4p6Us+OkwPti6/sSILrXUl7MqaIJ4r76GZsrHR2OJgcBomXoAWN2DoaY0aNXThgqYulUKBxSXwmXx1e+i+Txl4ahlydgQRQ8lgCWq6Fk1YtDyfkE4B4v9+w+4t5KGS88qeG/kbnO3wO7Nu4SdqdiLRchUy1LM0xxgIE0UePHlFpnDis9Z31TQS1GYLTpYBrk4/jA4AYCJeWYDsrkQ5S9KOpZ9vvMf3D0AAU7QKZW5kc3RyZWFtCmVuZG9iagozMCAwIG9iago8PCAvQmFzZUZvbnQgL0RlamFWdVNhbnMgL0NoYXJQcm9jcyAzMSAwIFIKL0VuY29kaW5nIDw8Ci9EaWZmZXJlbmNlcyBbIDMyIC9zcGFjZSA0NCAvY29tbWEgNDggL3plcm8gL29uZSAvdHdvIC90aHJlZSAvZm91ciAvZml2ZSAvc2l4IC9zZXZlbgovZWlnaHQgL25pbmUgNzIgL0ggNzYgL0wgOTcgL2EgMTAwIC9kIC9lIDExNCAvciAxMjEgL3kgXQovVHlwZSAvRW5jb2RpbmcgPj4KL0ZpcnN0Q2hhciAwIC9Gb250QkJveCBbIC0xMDIxIC00NjMgMTc
5NCAxMjMzIF0gL0ZvbnREZXNjcmlwdG9yIDI5IDAgUgovRm9udE1hdHJpeCBbIDAuMDAxIDAgMCAwLjAwMSAwIDAgXSAvTGFzdENoYXIgMjU1IC9OYW1lIC9EZWphVnVTYW5zCi9TdWJ0eXBlIC9UeXBlMyAvVHlwZSAvRm9udCAvV2lkdGhzIDI4IDAgUiA+PgplbmRvYmoKMjkgMCBvYmoKPDwgL0FzY2VudCA5MjkgL0NhcEhlaWdodCAwIC9EZXNjZW50IC0yMzYgL0ZsYWdzIDMyCi9Gb250QkJveCBbIC0xMDIxIC00NjMgMTc5NCAxMjMzIF0gL0ZvbnROYW1lIC9EZWphVnVTYW5zIC9JdGFsaWNBbmdsZSAwCi9NYXhXaWR0aCAxMzQyIC9TdGVtViAwIC9UeXBlIC9Gb250RGVzY3JpcHRvciAvWEhlaWdodCAwID4+CmVuZG9iagoyOCAwIG9iagpbIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwCjYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgNjAwIDYwMCA2MDAgMzE4IDQwMSA0NjAgODM4IDYzNgo5NTAgNzgwIDI3NSAzOTAgMzkwIDUwMCA4MzggMzE4IDM2MSAzMTggMzM3IDYzNiA2MzYgNjM2IDYzNiA2MzYgNjM2IDYzNiA2MzYKNjM2IDYzNiAzMzcgMzM3IDgzOCA4MzggODM4IDUzMSAxMDAwIDY4NCA2ODYgNjk4IDc3MCA2MzIgNTc1IDc3NSA3NTIgMjk1CjI5NSA2NTYgNTU3IDg2MyA3NDggNzg3IDYwMyA3ODcgNjk1IDYzNSA2MTEgNzMyIDY4NCA5ODkgNjg1IDYxMSA2ODUgMzkwIDMzNwozOTAgODM4IDUwMCA1MDAgNjEzIDYzNSA1NTAgNjM1IDYxNSAzNTIgNjM1IDYzNCAyNzggMjc4IDU3OSAyNzggOTc0IDYzNCA2MTIKNjM1IDYzNSA0MTEgNTIxIDM5MiA2MzQgNTkyIDgxOCA1OTIgNTkyIDUyNSA2MzYgMzM3IDYzNiA4MzggNjAwIDYzNiA2MDAgMzE4CjM1MiA1MTggMTAwMCA1MDAgNTAwIDUwMCAxMzQyIDYzNSA0MDAgMTA3MCA2MDAgNjg1IDYwMCA2MDAgMzE4IDMxOCA1MTggNTE4CjU5MCA1MDAgMTAwMCA1MDAgMTAwMCA1MjEgNDAwIDEwMjMgNjAwIDUyNSA2MTEgMzE4IDQwMSA2MzYgNjM2IDYzNiA2MzYgMzM3CjUwMCA1MDAgMTAwMCA0NzEgNjEyIDgzOCAzNjEgMTAwMCA1MDAgNTAwIDgzOCA0MDEgNDAxIDUwMCA2MzYgNjM2IDMxOCA1MDAKNDAxIDQ3MSA2MTIgOTY5IDk2OSA5NjkgNTMxIDY4NCA2ODQgNjg0IDY4NCA2ODQgNjg0IDk3NCA2OTggNjMyIDYzMiA2MzIgNjMyCjI5NSAyOTUgMjk1IDI5NSA3NzUgNzQ4IDc4NyA3ODcgNzg3IDc4NyA3ODcgODM4IDc4NyA3MzIgNzMyIDczMiA3MzIgNjExIDYwNQo2MzAgNjEzIDYxMyA2MTMgNjEzIDYxMyA2MTMgOTgyIDU1MCA2MTUgNjE1IDYxNSA2MTUgMjc4IDI3OCAyNzggMjc4IDYxMiA2MzQKNjEyIDYxMiA2MTIgNjEyIDYxMiA4MzggNjEyIDYzNCA2MzQgNjM0IDYzNCA1OTIgNjM1IDU5MiBdCmVuZG9iagozMSAwIG9iago8PCAvSCAzMiAwIFIgL0wgMzMgMCBSIC9hIDM0IDAgUiAvY29tbWEgMzUgMCBSIC9kIDM2IDAgUiAvZSAzNyAwIFIKL2VpZ2h0IDM4IDAgUiAvZml
2ZSAzOSAwIFIgL2ZvdXIgNDAgMCBSIC9uaW5lIDQxIDAgUiAvb25lIDQyIDAgUiAvciA0MyAwIFIKL3NldmVuIDQ0IDAgUiAvc2l4IDQ1IDAgUiAvc3BhY2UgNDYgMCBSIC90aHJlZSA0NyAwIFIgL3R3byA0OCAwIFIgL3kgNDkgMCBSCi96ZXJvIDUwIDAgUiA+PgplbmRvYmoKMyAwIG9iago8PCAvRjEgMzAgMCBSID4+CmVuZG9iago0IDAgb2JqCjw8IC9BMSA8PCAvQ0EgMCAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+Ci9BMiA8PCAvQ0EgMSAvVHlwZSAvRXh0R1N0YXRlIC9jYSAxID4+ID4+CmVuZG9iago1IDAgb2JqCjw8ID4+CmVuZG9iago2IDAgb2JqCjw8ID4+CmVuZG9iago3IDAgb2JqCjw8IC9JMSAxMiAwIFIgL0kxMCAyMSAwIFIgL0kxMSAyMiAwIFIgL0kxMiAyMyAwIFIgL0kxMyAyNCAwIFIgL0kxNCAyNSAwIFIKL0kxNSAyNiAwIFIgL0kxNiAyNyAwIFIgL0kyIDEzIDAgUiAvSTMgMTQgMCBSIC9JNCAxNSAwIFIgL0k1IDE2IDAgUgovSTYgMTcgMCBSIC9JNyAxOCAwIFIgL0k4IDE5IDAgUiAvSTkgMjAgMCBSID4+CmVuZG9iagoxMiAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDUxIDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3T1KHVEYgGH/CgVTiZImXhGygpRCdmDpdizcTtbhKgKiVmJIJxoLITtwDtzhgdH3qT/mDC+nuYcz3M2N0/ONT+nt+nFyZvvsaJa1tmZ5St5XZaHKQpWFKgtVFqosVFmoslBlocrCzlwP+nX5fXLm4ur3XMutb/vnt4Gp18mJ8x+ryZn2slBlocpClYUqC1UWqixUWaiyUGWhysJm9zHe0X2MJamyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWhr4rkVcXGPnC7WWhykKVhSoLVRaqLFRZqLJQZaHKQpWFoXOMxZ1RjOi7ko+mykKVhSoLVRaqLFRZqLJQZaHKwtAv7K+Hp5MzD39u1n4Z6u7+ZGDqeZa12stClYUqC1UWqixUWaiyUGWhykKVhSoLQ+cYizujGLE6vh2Y6qbAclRZqLJQZaHKQpWFKgtVFqosVFmosjD2j6C7B9Mz//6u+SrYy9MeW6u9LFRZqLJQZaHKQpWFKgtVFqosVFmosjB2jrG0M4oRe/svA1NfZlmrvSxUWaiyUGWhykKVhSoLVRaqLFRZqLLwH/ObJCYKZW5kc3RyZWFtCmVuZG9iago1MSAwIG9iago0NDkKZW5kb2JqCjEzIDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDExOSAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgMTE5IC9MZW5ndGggNTIgMCBSIC9TdWJ0eXBlIC9JbWFnZQovVHlwZSAvWE9iamVjdCAvV2lkdGggMTE5ID4+CnN0cmVhbQp4nO3QOw2AUAAEQT4eEIANNCAbM3ggQQF5Ddlqpr7isvO0nxMfnusebtZjG26WP84woHJB5YLKBZU
LKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXXiOFBF0KZW5kc3RyZWFtCmVuZG9iago1MiAwIG9iagozMDMKZW5kb2JqCjE0IDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDExOSAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgMTE5IC9MZW5ndGggNTMgMCBSIC9TdWJ0eXBlIC9JbWFnZQovVHlwZSAvWE9iamVjdCAvV2lkdGggMTE5ID4+CnN0cmVhbQp4nO3csQnCUBRAUaOZwVVEnMDKWR3CLeyF7CC4Qb6FXEHPqV/xubwyedPzcdiM7E774Qwrtt9+wF9QuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF+bd+TgcWq734cz+Mn3iPb/JLhdULqhcULmgckHlgsoFlQsqF1QuqFyY/FcSsMsFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULmgckHlgsoFlQsqF1QuqFxQuaByQeWCygWVCyoXVC6oXFC5oHJB5YLKBZULKhdULqhcULkwv3OV4XlbhjOuO6ywywWVCyoXVC6oXFC5oHJB5YLKBZULKhdcoCzY5YLKBZULKhdULqhcULmgckHlgsoFlQu+xyjY5YLKBZULKhdULqhcULmgckHlgsoFlQsvnOESSQplbmRzdHJlYW0KZW5kb2JqCjUzIDAgb2JqCjM0MQplbmRvYmoKMTUgMCBvYmoKPDwgL0JpdHNQZXJDb21wb25lbnQgOCAvQ29sb3JTcGFjZSAvRGV2aWNlUkdCCi9EZWNvZGVQYXJtcyA8PCAvQ29sb3JzIDMgL0NvbHVtbnMgMTE5IC9QcmVkaWN0b3IgMTAgPj4KL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0hlaWdodCAxMTkgL0xlbmd0aCA1NCAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCAxMTkgPj4Kc3RyZWFtCnic7dzBSQRBEEDR0RHMQPBqGl5MwJxFMJI9CGbgYcAQumGHZyv/nQu2+fSp6Nmb7el1yxWOj6/hzC04R6osVFmoslBlocpClYUqC1UWqixUWbj77QMsbWZHsT8/DGe6y0KVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiy8D/fY5z1jmJmZkZ3WaiyUGWhykKVhSoLVRaqLFRZqLJQZYHuMY63z+HM/vJ4/Q+dtX/Y9vvxzPE9HOkuC1UWqixUWaiyUGWhykKVhSoLVRaqLNA9xik7Cul4vwxn+n+MVVRZqLJQZaHKQpWFKgtVFqosVFmwLwVO+hCB6YuHv6TKQpWFKgtVFqosVFmoslBlocpClQX7UkDtKFZbmHSXhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqb2GKutBYaWOszWXTaqLFRZqLJQZaHKQpWFKgtVFqosVFmY2mOsthYYWm3x0l0WqixUWaiyUGWhykKVhSoLVRaqLFRZ+AG+yCBjCmVuZHN0cmVhbQplbmRvYmo
KNTQgMCBvYmoKMzk4CmVuZG9iagoxNiAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDU1IDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3btOFAEARuFhGe4IbiQQ8BYjMcZEMSaaYGFs7Gx8AWsLay18Bh/BwtbGB6AxlkpjRAtjgtEYFYHlzgLLrm/AP4nmxOJ89Z/d5TjNjjOzXUvv7xbJ7Yfn4ub5k924uf7gV9zcuXzm6MHcwvf4IkV3T970j+dNVy1vtr7GSYVX0V+zMsHKBCsTrEywMsHKBCsTrEywMsHKhPLV/IU4Wts7jJsXc/X8brXVOJm9cuzowdziyfxG++txMjPWFTeLG91xMzp+Nm48lglWJliZYGWClQlWJliZYGWClQlWJliZUP5YyZcuDPUcxE2tlk8LFOVgnJweb+fX+RdWdvIRNtqTP8y3Zn4dj2WClQlWJliZYGWClQlWJliZYGWClQlWJpTTp/L9IJ1OPkdR6TzGwVac/FxJ//B7+aKOopX/qIGefAHJcoVzFFN9nbjxWCZYmWBlgpUJViZYmWBlgpUJViZYmWBlQjk5ls8tHLSH46Z+rMJ5jCJ/5V9eS9d+tPPFIUWtjJO+fMtI0WjlP6qxn2+68VgmWJlgZYKVCVYmWJlgZYKVCVYmWJlgZUL57tOJOLpxrhk3S438db6o5XtYllb3jx6MjJ6KL7LR3Imb4QrXUdTLvGnshg9ceCwzrEywMsHKBCsTrEywMsHKBCsTrEwo7997HUcPn83GzaPpCv/zfrgXJ9vN8KV2Y2cjv1FrM3+W9kTcVLlSoMqVCx7LBCsTrEywMsHKBCsTrEywMsHKBCsTrEwoHz+9FUfnR/KDEN58yHdOFMPh1z6Lojg9Ea4mmFnOj7H8vD4SNwO9+fzD1XorbgZ7B+LGY5lgZYKVCVYmWJlgZYKVCVYmWJlgZYKVCWWnky/q/7KZH4Rw81KFSxcWfsfJtYuTRw9evs2/brG9l+942Gz2xk2jwhMoW22fQPl/sDLBygQrE6xMsDLBygQrE6xMsDLByoTy4DB/DZ8cyE9lGOrvi5up4/kpEfMfwyUQQxV+pbPozh9mrZlPvNT783vlfB7LDCsTrEywMsHKBCsTrEywMsHKBCsTrEz4A2h5jRsKZW5kc3RyZWFtCmVuZG9iago1NSAwIG9iago3OTEKZW5kb2JqCjE3IDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDExOSAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgMTE5IC9MZW5ndGggNTYgMCBSIC9TdWJ0eXBlIC9JbWFnZQovVHlwZSAvWE9iamVjdCAvV2lkdGggMTE5ID4+CnN0cmVhbQp4nO3dTU5TURyG8VN6KqU00BIwhhZNNDYkIgNjdGjiyJEuwbgA3IDr0MTEiTJgaqIuwDjQkdEwMYoNGrUGkGI/aSmtO+h7E8w7en5T/jlcnt5JD6e3qUs31oLSGaTkzMP7ciSsvxrJmY13v8cPrN1akotcWU7LmXuPGnLm7vVJOfP09Q85MyEncHJUdqCyA5UdqOxAZQcqO1DZgcoOVHagskMs5IZyaLmo9zGevdQzxZmoZ2YXxw9UzmXkIptfj+XM+Vm9R1Gt6XVC/qwc4V52oLIDlR2o7EBlByo7UNmByg5UdqC
yA5UdYr2jQ38/0HsUt6/qIxBJ1Lvd8QMXyvpQx99WTs5U2wM5s1LSmzzhuCdHuJcdqOxAZQcqO1DZgcoOVHagsgOVHajsQGWHWC7qt+rpBK9Fs61nRnoHIoSM2ILY2dfnMWp7em9haUr/4cME2xghrc91cC87UNmByg5UdqCyA5UdqOxAZQcqO1DZgcoOcaGgz1HU/uiPV6xW9BGIb7/6+oqG4ne9eCMObIQQSgun5MxkgvMjE0luwt6+XifBMjgpKjtQ2YHKDlR2oLIDlR2o7EBlByo7UNkhvv2ih/Z6+ghE/n1LzmzvJrii3sH4nz9+oPcNnjxflTOdgd6cmZ7SH6i5tjQjZ7iXHajsQGUHKjtQ2YHKDlR2oLIDlR2o7BDnp/Wp/qmoP6pwNNAz2039Tj2kxAvf7uojCRn9pMuQTesLnp/VF7zxUQfkXnagsgOVHajsQGUHKjtQ2YHKDlR2oLIDlR1i41D/M/zMjH6r/rOu17lzWa+z/mFu/EA+tyMXqe3qUwCHx/qCN6v6CRCLWb1pwr3sQGUHKjtQ2YHKDlR2oLIDlR2o7EBlByo7xK2WfoLBp6aeuVk+kjOn5/QDFcJIfMFFoViXa6T0FkXY6euhil4mZCb0uQ7uZQcqO1DZgcoOVHagsgOVHajsQGUHKjtQ2SGuFPQXYybZFogJHujYSPAlnFG97p+3LspFen19HqOU/T9fIdIe6Drcyw5UdqCyA5UdqOxAZQcqO1DZgcoOVHagssM/EKaG6AplbmRzdHJlYW0KZW5kb2JqCjU2IDAgb2JqCjgwMAplbmRvYmoKMTggMCBvYmoKPDwgL0JpdHNQZXJDb21wb25lbnQgOCAvQ29sb3JTcGFjZSAvRGV2aWNlUkdCCi9EZWNvZGVQYXJtcyA8PCAvQ29sb3JzIDMgL0NvbHVtbnMgMTE5IC9QcmVkaWN0b3IgMTAgPj4KL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0hlaWdodCAxMTkgL0xlbmd0aCA1NyAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCAxMTkgPj4Kc3RyZWFtCnic7d3NSlRhHMfxo3PUcTDf3xAsEmxhEeSiVURBt9Cuiwi6oFatWkVE0F6IIIgojFAQnUody9FJM8eX7sDfQPBdfT/rP8fTt2dzHp9z7Hrx7HGRLK9V48zM+EmcmR47ijPvPg9cPLCycRgv8uTRXpx5+nIizkyOlnGmvnkcZ7rjhP6flQlWJliZYGWClQlWJliZYGWClQlWJpT17b441Gyd5QtV8iP/wtyfOLO9G+5nYa4WL7Jaz6un2leJM7W+8zgzf6UnzriWCVYmWJlgZYKVCVYmWJlgZYKVCVYmWJlQNltdcajam2earfzDbt34FGfevL1/8cDvvBdS/NjpjTOV7vyPOsjnR4qevH/jWkZYmWBlgpUJViZYmWBlgpUJViZYmWBlQjk/m1+L+LmXn9WPjvO2wPPXd+LMwVE4AtHJ/sOv/bx6JkfyWYtaNR9EqW/ncx2uZYKVCVYmWJlgZYKVCVYmWJlgZYKVCVYmlKdneVugkz2K4YH8yH/0N/+nnp2GgenRvP9wqZZvZmMr7z+cnOYbbh2mO3YtM6xMsDLBygQrE6xMsDLBygQrE6xMsDKh/NbIn3cY6M/bAo1m3haYGs3fAu2vhj2T1Xo7XuT29ThSXJ7O+yE7HRxEmZ3KK9W1TLAywcoEKxOsTLAywcoEKxOsTLAyoTzJv+cu1rfySYHRwXyd98v5Sf38PLyBcXcx/6Cx4fzJhZ5Kvpn9g6E8c5jjuJYJViZYmWBlgpUJViZYmWBlgpUJViZYmVBWOujcbudH/vXNPPPwQf5jnksfhy8eaOzm3/Cf5Xsp1jbzVyp78umHYnwo/zDXMsHKBCsTrEywMsHKBCsTrEywMsHKBCsTyu+N/ALB7n6eqVXzM//Naytx5tVSOG9xb3E3XqR1mP/M6Ye
v+VWPmYk40tHXLl3LBCsTrEywMsHKBCsTrEywMsHKBCsTrEwor87kjxO02/kIxMlpPpawuTPZ0U1d6MvacJwZGcwbL0V+HaTo5KWbmfH8OQrXMsHKBCsTrEywMsHKBCsTrEywMsHKBCsT/gEe5pWJCmVuZHN0cmVhbQplbmRvYmoKNTcgMCBvYmoKNzkxCmVuZG9iagoxOSAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDU4IDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3T9rE3Ecx/GL+aUxpqUpqbEqOqhYMNJFFMdOgghiB8XJxdXVB+Fj6OpUoYPgUgMOLSIWqRYRIzXWWhBiUzVNmn9N4zPI50D4TO/X/OVyffe33C93l8Tc/UeR8uBWkDMLL/ty5sqlrJxZXmsMH3j3/Yg8yHRBn0xhQh/n+ceknLmYO5Az+pPw/6jsQGUHKjtQ2YHKDlR2oLIDlR2o7EBlh7C6qUMPXuhL9ZWKvuS/M9uVM7t7g+ED+WOH8iDfavpkiudScqbZqsmZt4uf5Axr2YHKDlR2oLIDlR2o7EBlByo7UNmByg5UdgiJhB5qdcTeQhRFIcZxiud35EwqOTZ84MKU/qAPW3pm40dbD43k5Mj78oycYS07UNmByg5UdqCyA5UdqOxAZQcqO1DZgcoO4cy4fgTj4e20nHk8r+/ZWF7TexCZ9P7wgXpTb6pst/TquZrVx4kGOs7Sm0k5w1p2oLIDlR2o7EBlByo7UNmByg5UdqCyA5UdQlZvUUSdnv5n5I7qxz0GA33Txq+6GKj80S/rmMnrTZXcmH6uJCT1XkdbPyvDWragsgOVHajsQGUHKjtQ2YHKDlR2oLIDlR3C1x39KolnJf0Ixl5H/8Om8jGu+ZXLBb1Hsd/VGyb9Q71HcRDjlo252aqcYS07UNmByg5UdqCyA5UdqOxAZQcqO1DZIUyf0Df1n8zrL9XLP/VxnjztyZnRtLiqzaT11XNbf050rTgiZ0rremvhy9aEnGEtO1DZgcoOVHagsgOVHajsQGUHKjtQ2YHKDmF1Wz9AkAr6y/lyXd9xMH9PP16xUGoOH+j29Nf3tX29el6vd+RMpaH3TO7eXJEzrGUHKjtQ2YHKDlR2oLIDlR2o7EBlByo7UNkhTMZ440Klqi/nT2X0h8V4xiDqqi2T0Yw+mdPj+o/629BnU0jrzZl+R58Pa9mByg5UdqCyA5UdqOxAZQcqO1DZgcoOVHYIu+0YPzpxVj8z8rmiL+dvXN+UM4uvjg8f6MS4H2MjxusoBjF+JaPa0q+a+F0XJxyxlj2o7EBlByo7UNmByg5UdqCyA5UdqOxAZYd/896b3AplbmRzdHJlYW0KZW5kb2JqCjU4IDAgb2JqCjgxMQplbmRvYmoKMjAgMCBvYmoKPDwgL0JpdHNQZXJDb21wb25lbnQgOCAvQ29sb3JTcGFjZSAvRGV2aWNlUkdCCi9EZWNvZGVQYXJtcyA8PCAvQ29sb3JzIDMgL0NvbHVtbnMgMTE5IC9QcmVkaWN0b3IgMTAgPj4KL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0hlaWdodCAxMTkgL0xlbmd0aCA1OSAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCAxMTkgPj4Kc3RyZWFtCnic7d0/axNxAMbx/Pkll5xt/vRKbRQRB0FRcRAU3BQHX4CuDu5ODr4BNwdfQUcHF1fBzUFXUaRLcSmaQJs2bZomzV1MfAd9bigPDt/P/HB
Nv73pyO9a/Pn5WUE5HNXlJmmO5Ob1RiI3ca10+uDty2/yIt3dC3Lz/tNluXn6aFtuXrxpy434lXAmqOxAZQcqO1DZgcoOVHagsgOVHajsQGWHcP3Kphz92LolN0nrQG62d1py8+SBeGby7uM9eZHzK6nc3L46kZtqZSY3qZ5wL1tQ2YHKDlR2oLIDlR2o7EBlByo7UNmByg6h178kR9M0yM3+YVNu1tv6j1pSk3o0lxcplxZyM830hxmf1OQmaRTlhnvZgcoOVHagsgOVHajsQGUHKjtQ2YHKDlR2CJ1OV44Gw4bcrCV9ufnd10cwjsbisUBzSV6jkM30s4WFftRRaJzTh2V2DpblhnvZgcoOVHagsgOVHajsQGUHKjtQ2YHKDlR2CNlEf9disdCPBbKsKjdRRX+gufq2RT36Ky8S1/RmcKQ/cDbTn7ik23AvW1DZgcoOVHagsgOVHajsQGUHKjtQ2YHKDqHXvyhHo3EkN3nOlXRW9B+1qB4LRBV9rqSU41xJnu9sTFP9HGO1ybmS/wOVHajsQGUHKjtQ2YHKDlR2oLIDlR1CXMvxIsaq/lJ/uzGUm95+S26Kv8anD149/yovsrl1Q26+fNdHJx7fH8jNn72W3HAvO1DZgcoOVHagsgOVHajsQGUHKjtQ2YHKDqG1vCdH3d3VM/lheV6WcPem+E8aGx8eyousJ1O5uXNNP8A5nsRyE1f1b8W97EBlByo7UNmByg5UdqCyA5UdqOxAZQcqO4TRRL9dMs302x2mqX4RQqwPTuiDCGtt/S8485yKGB7r0wzzub4L87yOgnvZgcoOVHagsgOVHajsQGUHKjtQ2YHKDlR2CFHlRI+CfqFjuaw3qX4CUaiprzeMxmV5kaV6jucPVf2B85jluAz3sgOVHajsQGUHKjtQ2YHKDlR2oLIDlR2o7PAP9HiFmAplbmRzdHJlYW0KZW5kb2JqCjU5IDAgb2JqCjc4MwplbmRvYmoKMjEgMCBvYmoKPDwgL0JpdHNQZXJDb21wb25lbnQgOCAvQ29sb3JTcGFjZSAvRGV2aWNlUkdCCi9EZWNvZGVQYXJtcyA8PCAvQ29sb3JzIDMgL0NvbHVtbnMgMTE5IC9QcmVkaWN0b3IgMTAgPj4KL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0hlaWdodCAxMTkgL0xlbmd0aCA2MCAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCAxMTkgPj4Kc3RyZWFtCnic7dAxDYBAAATBBzwgABtoQDZm8ECCBLqhYKe+4rLT2I6h3Of1uln2FTzB5q8P/EKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWXgANqMEXQplbmRzdHJlYW0KZW5kb2JqCjYwIDAgb2JqCjMwMwplbmRvYmoKMjIgMCBvYmoKPDwgL0JpdHNQZXJDb21wb25lbnQgOCAvQ29sb3JTcGFjZSAvRGV2aWNlUkdCCi9EZWNvZGVQYXJtcyA8PCAvQ29sb3JzIDMgL0NvbHVtbnMgMTE5IC9QcmVkaWN0b3IgMTAgPj4KL0ZpbHRlciAvRmxhdGVEZWNvZGUgL0hlaWdodCAxMTkgL0xlbmd0aCA2MSAwIFIgL1N1YnR5cGUgL0ltYWdlCi9UeXBlIC9YT2JqZWN0IC9XaWR0aCAxMTkgPj4Kc3R
yZWFtCnic7d2/alNhHIfxNz0nSf9ooFIQRHByKQguiuAoeAteg5tXIQ5ekVtxcFPpUHBRB6loq9HaP0lzkngH/b5geXB4PvOPN8nTs+THOWmv3H1Wks3+Ms6Mp12cefNiFmeevhxcPPB+3MZDymIeR0b9Js4cTc7izPMnm3FmJU7o31mZYGWClQlWJliZYGWClQlWJliZYGVCW5Z5R7GWv/GX8SLvMXbeXYszs/nJxQNbFUuVw+NwSCnltBnFmbLM+5DHDw7ijNcywcoEKxOsTLAywcoEKxOsTLAywcoEKxPastKLQ4uKXUdpVvM5i3zMdB7ez9WaPUa7Fmfa/LlL1wzjzM/f+SCvZYKVCVYmWJlgZYKVCVYmWJlgZYKVCVYmtGWen/VYLCse5ejyLRDffuRz1vth2bH7q+LNTPM9EpPmej5n9ieOPHr4Ic54LROsTLAywcoEKxOsTLAywcoEKxOsTLAyoS29HLriNopSevnhk3vb+daFnd2aF0uaS7ofo5d3Jq9e348zXssEKxOsTLAywcoEKxOsTLAywcoEKxOsTGhX27x/mNc8V9LPq4PPX/Na4PA0/OG3R/k3K/aO8m9fVO0xhvmcs4nPlfwfrEywMsHKBCsTrEywMsHKBCsTrExoR03+9lzx4w5lo+LvtfcxPxWxtR7ez/5xxe9hVvwnja1h/lRfpvlT3brxO854LROsTLAywcoEKxOsTLAywcoEKxOsTLAyof0+y1/nbw7zMw8nXZ65c3sjzrz9FHYdw4rFS5mfx5HDWb61oXTTfM44fyivZYKVCVYmWJlgZYKVCVYmWJlgZYKVCVYmtKOKuxvOF3nX0TYVvwBxKf9JY5D3GKvDvKMY1Dzx0M//SWP/oB9nvJYJViZYmWBlgpUJViZYmWBlgpUJViZYmdCeVuwWRm1eHXQVu46mYmcyWAmvNenyC026/IYH/XxON8/nXFnPz7B4LROsTLAywcoEKxOsTLAywcoEKxOsTLAy4S+CWoh5CmVuZHN0cmVhbQplbmRvYmoKNjEgMCBvYmoKNzUyCmVuZG9iagoyMyAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDYyIDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3K9KQ2EAhvEzN5nTJMxZBGHNBYNoWrOJClq8AK9CMAteiwgrGi2CTTDKgmJwgiaTgn+O3sHegePB8PzyyzkfD18ah1WKzn4RlR95U5vJm7enODnYWxo+OD65zS+amouT3eXpuOld5wMX5WecTOSn6M+sTLAywcoEKxOsTLAywcoEKxOsTLAyoVK0t7GXfV+9xE212wJOAvMuE6xMsDLBygQrE6xMsDLBygQrE6xMsDLBygQrE6xMsDLBygQrE6xMsDLBygQrE6xMsDJhpO8xzo7acbN1eB83rWZ+zmy9HD7oDx7iQ/4b7zLBygQrE6xMsDLBygQrE6xMsDLBygQrE2qjjE4vfsbysmYj/EZRFMXGan34oN8by1lQ3mWClQlWJliZYGWClQlWJliZYGWClQlWJvj/GATvMsHKBCsTrEywMsHKBCsTrEywMsHKhNrmymIcnd88xs3O2kLcVLsjnKgxHwbvz/khlXx71jv5wJeDybj5er2LG+8ywcoEKxOsTLAywcoEKxOsTLAywcoEKxP8UoDgXSZYmWBlgpUJViZYmWBlgpUJViZYmWBlgpUJViZYmWBlgpUJViZYmWBlgpUJViZYmWBlgpUJViZYmWBlgpUJViZYmWBlgpUJViZYmWBlwi804y+LCmV
uZHN0cmVhbQplbmRvYmoKNjIgMCBvYmoKNDkyCmVuZG9iagoyNCAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDYzIDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3cuNwjAUQFEyUAPSbGkDUQJdM8VAD3xKcCSiw0Tcs7Zi68qrt3CmzeG8Ue5/t+Ga7XEPToL9fPoAX6HKQpWFKgtVFqosVFmoslBlocpClYWd3GyRGcUahyHdZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocoCfYFykccj/9vrknN0l4UqC1UWqixUWaiyUGWhykKVhSoLVRZW+CeNy3W80el3/KHn4/3DzNRdFqosVFmoslBlocpClYUqC1UWqixUWaBzjGVM03gNnFHM0V0WqixUWaiyUGWhykKVhSoLVRaqLFRZWOEc43Efr5m24zXPGd9ZSHdZqLJQZaHKQpWFKgtVFqosVFmoslBl4QU+jxZdCmVuZHN0cmVhbQplbmRvYmoKNjMgMCBvYmoKMzcyCmVuZG9iagoyNSAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGVpZ2h0IDExOSAvTGVuZ3RoIDY0IDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3b1KHFEYgOF112ARSJl0EWzFK0glBJPGO8hthNyaTYoUuQ9tLAS7BBL8SS7hHHF4dPF96sPOx8tXDTM7O6uD05Vy9/NqeGbz4S2YBFs/9QAvQpWFKgtVFqosVFmoslBlocpClYUqC7vyYovco/j7/Xp4Zu/z/viHbn49fphJ7bJQZaHKQpWFKgtVFqosVFmoslBlocoCvY+xiPX6dnzo38QZqF0WqixUWaiyUGWhykKVhSoLVRaqLFRZoPcxFnmv5NXxu4lL/ZmbCGmXhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGVh+/4f47mZecikXRaqLFRZqLJQZaHKQpWFKgtVFqosbN8bD8/NzMDtslBlocpClYUqC1UWqixUWaiyUGWhysL2PSnw8ej98MzXL+Pt+fTt/PHDTGqXhSoLVRaqLFRZqLJQZaHKQpWFKgtVFnZWB6dPPcPD3P24HJ7ZHE98EfT+ZoFp5rTLQpWFKgtVFqosVFmoslBlocpClYUqC9v3RdDNyeHwzO+zi+GZ1ydvlhhnSrssVFmoslBlocpClYUqC1UWqixUWaiy8B8GHiOGCmVuZHN0cmVhbQplbmRvYmoKNjQgMCBvYmoKNDU5CmVuZG9iagoyNiAwIG9iago8PCAvQml0c1BlckNvbXBvbmVudCA4IC9Db2xvclNwYWNlIC9EZXZpY2VSR0IKL0RlY29kZVBhcm1zIDw8IC9Db2xvcnMgMyAvQ29sdW1ucyAxMTkgL1ByZWRpY3RvciAxMCA+PgovRmlsdGVyIC9GbGF0ZURlY29kZSAvSGV
pZ2h0IDExOSAvTGVuZ3RoIDY1IDAgUiAvU3VidHlwZSAvSW1hZ2UKL1R5cGUgL1hPYmplY3QgL1dpZHRoIDExOSA+PgpzdHJlYW0KeJzt3TFOFVEYgNEHDwkVhQXGlo41gJ2tiQ0lDRuwoXYFLsCKDdCwB3QFNJQWJkQKEjtDeOAOvJMwOWL8Tv3nzc2X28zNzLy1xe67xXOy+nrz54Hl/o5ZyYzW//YC/gtVFqosVFmoslBlocpClYUqC1UWqixsyIt9/rA3nFnug4XM6fjt7nCmvSxUWaiyUGWhykKVhSoLVRaqLFRZqLJAzzE+nf2SlzNOjm6HM+1locpClYUqC1UWqixUWaiyUGWhykKVhbXeKwHay0KVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyQN8rmWL4uMXl6cvhjxx+3B7OXH3/NnFJT9deFqosVFmoslBlocpClYUqC1UWqixMu8Pe2BrP3M/zVYYf54OBV+/HX0pYbK7GM+svxjOPDxNmxtdqLwtVFqosVFmoslBlocpClYUqC1UWqixMO8eY6Yxi+FWGxVwfZrj7OcOPzKe9LFRZqLJQZaHKQpWFKgtVFqosVFmoskDfeFgevJ4wNeFRiudkdXE9nGkvC1UWqixUWaiyUGWhykKVhSoLVRaqLNBzjNWX8S3/P/dHGcs348OZ9rJQZaHKQpWFKgtVFqosVFmoslBlocrCb+XYK3IKZW5kc3RyZWFtCmVuZG9iago2NSAwIG9iago0ODgKZW5kb2JqCjI3IDAgb2JqCjw8IC9CaXRzUGVyQ29tcG9uZW50IDggL0NvbG9yU3BhY2UgL0RldmljZVJHQgovRGVjb2RlUGFybXMgPDwgL0NvbG9ycyAzIC9Db2x1bW5zIDExOSAvUHJlZGljdG9yIDEwID4+Ci9GaWx0ZXIgL0ZsYXRlRGVjb2RlIC9IZWlnaHQgMTE5IC9MZW5ndGggNjYgMCBSIC9TdWJ0eXBlIC9JbWFnZQovVHlwZSAvWE9iamVjdCAvV2lkdGggMTE5ID4+CnN0cmVhbQp4nO3dwQnCQBRAQaOpwQJswoM1WLa1CPYQsAEhOYQx0TfnhSyPvWT5JMPhcj8o0+M1u+Z0O4OdYMdvb+AvVFmoslBlocpClYUqC1UWqixUWaiyMMqHrXJHscfLkM6yUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqizQeYxVRim2NmuxRGdZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRb2932MrVkyZNJZFqosVFmoslBlocpClYUqC1UWqiwM0/M6u+gn34ylzrJQZaHKQpWFKgtVFqosVFmoslBlocrCuLs7iv6kkc+qLFRZqLJQZaHKQpWFKgtVFqosVFmoslBlocpClYUqC1UWqixUWaiyUGWhykKVhSoLVRaqLFRZqLJQZaHKQpWFKgtVFqosVFl4A4yDE88KZW5kc3RyZWFtCmVuZG9iago2NiAwIG9iagozNTEKZW5kb2JqCjIgMCBvYmoKPDwgL0NvdW50IDEgL0tpZHMgWyAxMCAwIFIgXSAvVHlwZSAvUGFnZXMgPj4KZW5kb2JqCjY3IDAgb2JqCjw8IC9DcmVhdGlvbkRhdGUgKEQ6MjAyMjA1MzExNzAwMDIrMDInMDAnKQovQ3JlYXRvciAoTWF0cGxvdGxpYiB2My4zLjIsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcpCi9Qcm9kdWNlciAoTWF0cGxvdGxpYiBwZGYgYmFja2VuZCB2My4zLjIpID4+CmVuZG9iagp4cmVmCjAgNjgKMDAwMDAwMDAwMCA2NTUzNSBmIAowMDAwMDAwMDE2IDAwMDAwIG4gCjAwMDAwMjY1MTggMDAwMDAgbiAKMDAwMDAxMzQ3MyA
wMDAwMCBuIAowMDAwMDEzNTA1IDAwMDAwIG4gCjAwMDAwMTM2MDQgMDAwMDAgbiAKMDAwMDAxMzYyNSAwMDAwMCBuIAowMDAwMDEzNjQ2IDAwMDAwIG4gCjAwMDAwMDAwNjUgMDAwMDAgbiAKMDAwMDAwMDQwMyAwMDAwMCBuIAowMDAwMDAwMjA4IDAwMDAwIG4gCjAwMDAwMDY1NTEgMDAwMDAgbiAKMDAwMDAxMzg1MCAwMDAwMCBuIAowMDAwMDE0NTQ4IDAwMDAwIG4gCjAwMDAwMTUxMDAgMDAwMDAgbiAKMDAwMDAxNTY5MCAwMDAwMCBuIAowMDAwMDE2MzM3IDAwMDAwIG4gCjAwMDAwMTczNzcgMDAwMDAgbiAKMDAwMDAxODQyNiAwMDAwMCBuIAowMDAwMDE5NDY2IDAwMDAwIG4gCjAwMDAwMjA1MjYgMDAwMDAgbiAKMDAwMDAyMTU1OCAwMDAwMCBuIAowMDAwMDIyMTEwIDAwMDAwIG4gCjAwMDAwMjMxMTEgMDAwMDAgbiAKMDAwMDAyMzg1MiAwMDAwMCBuIAowMDAwMDI0NDczIDAwMDAwIG4gCjAwMDAwMjUxODEgMDAwMDAgbiAKMDAwMDAyNTkxOCAwMDAwMCBuIAowMDAwMDEyMTcwIDAwMDAwIG4gCjAwMDAwMTE5NzAgMDAwMDAgbiAKMDAwMDAxMTU1MiAwMDAwMCBuIAowMDAwMDEzMjIzIDAwMDAwIG4gCjAwMDAwMDY1NzIgMDAwMDAgbiAKMDAwMDAwNjcyMSAwMDAwMCBuIAowMDAwMDA2ODUyIDAwMDAwIG4gCjAwMDAwMDcyMjkgMDAwMDAgbiAKMDAwMDAwNzM2NyAwMDAwMCBuIAowMDAwMDA3NjY3IDAwMDAwIG4gCjAwMDAwMDc5ODUgMDAwMDAgbiAKMDAwMDAwODQ1MCAwMDAwMCBuIAowMDAwMDA4NzcwIDAwMDAwIG4gCjAwMDAwMDg5MzIgMDAwMDAgbiAKMDAwMDAwOTMyNSAwMDAwMCBuIAowMDAwMDA5NDc3IDAwMDAwIG4gCjAwMDAwMDk3MDcgMDAwMDAgbiAKMDAwMDAwOTg0NyAwMDAwMCBuIAowMDAwMDEwMjM3IDAwMDAwIG4gCjAwMDAwMTAzMjYgMDAwMDAgbiAKMDAwMDAxMDczNyAwMDAwMCBuIAowMDAwMDExMDU4IDAwMDAwIG4gCjAwMDAwMTEyNjkgMDAwMDAgbiAKMDAwMDAxNDUyOCAwMDAwMCBuIAowMDAwMDE1MDgwIDAwMDAwIG4gCjAwMDAwMTU2NzAgMDAwMDAgbiAKMDAwMDAxNjMxNyAwMDAwMCBuIAowMDAwMDE3MzU3IDAwMDAwIG4gCjAwMDAwMTg0MDYgMDAwMDAgbiAKMDAwMDAxOTQ0NiAwMDAwMCBuIAowMDAwMDIwNTA2IDAwMDAwIG4gCjAwMDAwMjE1MzggMDAwMDAgbiAKMDAwMDAyMjA5MCAwMDAwMCBuIAowMDAwMDIzMDkxIDAwMDAwIG4gCjAwMDAwMjM4MzIgMDAwMDAgbiAKMDAwMDAyNDQ1MyAwMDAwMCBuIAowMDAwMDI1MTYxIDAwMDAwIG4gCjAwMDAwMjU4OTggMDAwMDAgbiAKMDAwMDAyNjQ5OCAwMDAwMCBuIAowMDAwMDI2NTc4IDAwMDAwIG4gCnRyYWlsZXIKPDwgL0luZm8gNjcgMCBSIC9Sb290IDEgMCBSIC9TaXplIDY4ID4+CnN0YXJ0eHJlZgoyNjczNQolJUVPRgo=\n", "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-05-31T17:00:01.902281\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.3.2, 
https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Probabilities:\n", "Image 0: 0.84%\n", "Image 1: 1.53%\n", "Image 2: 1.00%\n", "Image 3: 54.88%\n", "Image 4: 2.39%\n", "Image 5: 6.54%\n", "Image 6: 4.91%\n", "Image 7: 1.44%\n", "Image 8: 1.39%\n", "Image 9: 25.09%\n" ] } ], "source": [ "visualize_prediction(mistakes[0])\n", "print(\"Probabilities:\")\n", "for i, p in enumerate(preds[mistakes[0]]):\n", " print(f\"Image {i}: {100.0*p:4.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, the model confuses a picture of the sea with a forest: it assigns a probability of ~55% to image 3 and only ~25% to the actual anomaly, image 9. The difficulty here is that the picture of the sea contains trees in the foreground, which makes its class somewhat ambiguous. It may be a picture of the sea taken from within a forest, which confuses the model. Nevertheless, the model performs quite well in general." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "In this tutorial, we took a closer look at the Multi-Head Attention layer, which uses a scaled dot product between queries and keys to find correlations and similarities between input elements. The Transformer architecture is based on the Multi-Head Attention layer and applies multiple of them in ResNet-like blocks. The Transformer is a very important, recent architecture that can be applied to many tasks and datasets. Although it is best known for its success in NLP, there is much more to it. We have seen it applied to sequence-to-sequence tasks and set anomaly detection. Because it is permutation-equivariant when no positional encodings are provided, it generalizes to many settings. Hence, it is important to know not only the architecture but also its possible issues, such as the gradient problem during the first iterations, which is solved by learning rate warm-up. If you are interested in continuing with the study of the Transformer architecture, please have a look at the blog posts listed at the beginning of the tutorial notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "[![Star our repository](https://img.shields.io/static/v1.svg?logo=star&label=⭐&message=Star%20Our%20Repository&color=yellow)](https://github.com/phlippe/uvadlc_notebooks/) If you found this tutorial helpful, consider ⭐-ing our repository. \n", "[![Ask questions](https://img.shields.io/static/v1.svg?logo=star&label=❔&message=Ask%20Questions&color=9cf)](https://github.com/phlippe/uvadlc_notebooks/issues) For any questions, typos, or bugs that you found, please raise an issue on GitHub. \n", "\n", "---" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 4 }