{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "e2h9qOh2KgZy"
},
"source": [
"# Design of Type Promotion Semantics for JAX\n",
"\n",
"[Open in Colab](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [Open in Kaggle](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
"\n",
"*Jake VanderPlas, December 2021*\n",
"\n",
"One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rod6OOyUVbQ8"
},
"source": [
"## Goals of JAX Type Promotion\n",
"\n",
"JAX's numerical computing API is modeled after that of NumPy, with a few enhancements, including the ability to target accelerators such as GPUs and TPUs.\n",
"This makes adopting NumPy's type promotion system disadvantageous for JAX users: NumPy's type promotion rules heavily favor 64-bit outputs, which is problematic for computation on accelerators. Devices such as GPUs and TPUs often pay a significant performance penalty for 64-bit floating point types, and in some cases do not natively support 64-bit floating point types at all.\n",
"\n",
"A simple example of these problematic type promotion semantics can be seen in binary operations between 32-bit integers and floats:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "PTu3TMUxX8Xq"
},
"outputs": [
{
"data": {
"text/plain": [
"dtype('float64')"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"np.dtype(np.int32(1) + np.float32(1))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0RkJcO-OY_pN"
},
"source": [
"NumPy's tendency to produce 64-bit values is a [long-standing issue](https://github.com/numpy/numpy/issues/6860) with using NumPy's API for accelerator computations, for which there isn't yet a good solution.\n",
"For this reason, JAX has sought to re-think NumPy-style type promotion with accelerators in mind."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rh_dYVHALFQO"
},
"source": [
"## Stepping Back: Tables and Lattices\n",
"\n",
"Before we dive into the details, let's take a moment to step back and think about *how* to think about the problem of type promotion. Consider arithmetic operations between built-in numerical types in Python, namely those of type `int`, `float`, and `complex`. With a few lines of code we can generate the type promotion table used by Python for addition between values of these types:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "J-bym22gLpfe"
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"  <thead>\n",
"    <tr><th></th><th>int</th><th>float</th><th>complex</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>int</th><td>int</td><td>float</td><td>complex</td></tr>\n",
"    <tr><th>float</th><td>float</td><td>float</td><td>complex</td></tr>\n",
"    <tr><th>complex</th><td>complex</td><td>complex</td><td>complex</td></tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
" int float complex\n",
"int int float complex\n",
"float float float complex\n",
"complex complex complex complex"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"types = [int, float, complex]\n",
"name = lambda t: t.__name__\n",
"pd.DataFrame([[name(type(t1(1) + t2(1))) for t1 in types] for t2 in types],\n",
" index=[name(t) for t in types], columns=[name(t) for t in types])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z9-VjJKHQ45U"
},
"source": [
"This table enumerates Python's numerical type promotion behavior, but there is a complementary representation that is much more compact: a [lattice](https://en.wikipedia.org/wiki/Lattice_(order)), in which the [supremum](https://en.wikipedia.org/wiki/Infimum_and_supremum) of any two nodes is the type that they promote to. The lattice representation of Python's promotion table looks like this:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"cellView": "form",
"id": "SY8leGvMRnV5",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcwAAAB7CAYAAAD5Y7D/AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAQu0lEQVR4nO3dbWwTdQAG8OfatXthL4S9hK3lyyAxapboEgZsOCCSbSDIQLIYEnwBFz4AOkIggjEwzAYqEl7MDMwJin6QIAwT4mbiCxNRRogxaEJUmGg35jaEbuO6dW3PD7PNxt5u7V2vvXt+CQmw9e7f69M+d73/tYIkSRKIiIhoXCatB0BERBQNWJhEREQysDCJiIhkYGESERHJwMIkIiKSgYVJREQkAwuTiIhIBhYmERGRDCxMIiIiGViYREREMrAwiYiIZGBhEhERycDCJCIikoGFSUREJAMLk4iISAYWJhERkQwsTCIiIhlYmERERDKwMImIiGRgYRIREcnAwiQiIpKBhUlERCQDC5OIiEgGFiYREZEMLEwiIiIZWJhEREQysDCJiIhkYGESERHJEKP1ACYiSRLcbjdcLhf6+vrg9XohSRJ8Ph9MJhMEQYDZbEZcXBzi4+NhtVohCILWwyaShfkmvdNTxiOuML1eL3p6eiCKIu7fvw+32x3YeD6fb8zbmUyDB8uSJMFqtWLKlClISEhAUlISzGZzWMZONBHmm/ROzxkXJEmStB4EALhcLnR1daG7uxuCIIy7YeUymUyQJAnJyclIS0tDfHy8AiMlmjzmm/TOCBnXtDB9Ph+cTie6urrgdruh5lAEQYDVakV6ejqSk5MDezNEamG+Se+MlnHNCtPpdKK1tRXA+IfpSvNvZJvNhpSUlLCtl4yF+Sa9M2LGw16YHo8Hra2t6O3tVXVvZCKCICAxMRE2mw0xMRF3KpeiFPNNemfkjIe1MP17JOHcG5mIyWTi3jgpgvkmvTN6xsNSmF6vFw6HQ/M9krH491TsdnvEzMai6MF8k94x4/+vR+3C9Hg8aGlpUf2EcKj8J5Szs7P5okKyMd+kd8z4kHWoWZgejwc3btzAwMCAWqtQnMViwcyZM3nehybEfJPeMePDqTYvNxo3NAAMDAzgxo0b8Hq9Wg+FIhjzTXrHjI+kSmF6vV60tLRE3Yb2GxgYwM2bN/miQqNivknvmPHRqVKYDocDbrdbjUWHjdvthsPh0HoYFIGYb9I7Znx0ihem0+mM2JlUkyFJEnp7e9Hd3a31UCiCMN+kd8z42BQtTP8FrdG+of0kSYLD4YDH49F6KBQBmG/SO2Z8fIoWZqRd0KoEn88X+PgnMjbmm/SOGR+fYoXpP4zXo97eXjidTq2HQRpivknvmPGJKVKY/gbXy2H8gyRJ0uWeF8nDfJPeMePyKFKYTqdTtxt6KE6QMCbmm/SOGZdHkcLs6urS/cb2+Xzo7OzUehikAeab9I4ZlyfkwnS5XFF/vY5cbrcbLpdL62FQGDHfpHfMuHwhF6YR9kz8JElCV1eX1sOgMGK+Se+YcflCKkyv12u48x7d3d38SDGDYL5J75jxyQmpMHt6eiAIQiiLiDqCIKCnp0frYVAYMN+kd8z45IRUmKIoGm4qus/ngyiKWg+DwoD5Jr1jxicnpMK8f/9+KDePWnxBMQbmm/SOGZ+coAtTkiTVZ1aVlpbiypUrqq4jGP39/YY5SW5USue7paUFq1evxpw5c1BQUIDDhw8rtmylMd/GEI7X8FC8+OKL+Oyzz1RZdrAZD7ow3W636u9919fXY/bs2RP+XnFxMX744QdVx/KgSA4ajU/OE0XpfB8/fhyzZ8/G5cuXsXDhQkWWWVNTg1dffVWRZT2I+Y5uWmQ82gST8aAL0+jXaxn9/kezxx57DAsXLhz33QulH9+2tjbMmjVL0WWqifmOXmfOnEFqaioOHDgw7uNo9Mc4mPsfdGH29fWpfr
LYf+RYU1ODrVu3YufOnZgzZw5KS0vx66+/AgB27NiB27dvY/PmzcjLy8MHH3yg6piAwb23vr4+1ddD6rhz5w4uXLiABQsWjFmcSuZ7/fr1uHLlCqqrq5GXlzfiW+xPnz6NpUuXoqCgAJs3b0ZHR0fgZ/v27cPixYsxd+5clJWV4erVqwCAixcvora2Fo2NjcjLy8MzzzyjyFgB5jvadXd3QxRFvP7668jKyhqzOCfKeHt7OyoqKlBYWIj58+ejqqoKPp8PR48eRVFRERYsWICdO3cGZpy2trYiJycHZ8+exeLFi5Gfn49Tp07hl19+wapVq5Cfn4+qqqrA8uvr67F27VpUVVVh3rx5WL58OX788ccxx3P27Fk8/fTTyM/Px4YNG9DW1gYAqKurw5o1awJf4fXpp5+itLQU/f39Yy4r2IwHXZjhvlbr22+/RUlJCS5duoSFCxeiuroaALB3715kZmbiyJEjaG5uxrp168IyHqPNLNMjl8uFpqamUYtTyXzX1dUhNzcXO3fuRHNzMywWS+Bnly9fxqFDh7B//3588803yMzMxLZt2wI/f/TRR3H69GlcvHgRTz31FLZu3Yr+/n7Mnz8f5eXlKC4uRnNzs+Lnepjv6BYTEwNRFHHv3r0xi3O8jHu9XmzcuBGZmZloaGjAV199hSVLluDcuXM4d+4c6urq8MUXX0AUxcBrsd+1a9dw/vx57N+/H2+++SaOHTuG2tpanD17Fl9++eWw59m1a9cwY8YMNDU1YePGjdiyZcuo3yry9ddfo7a2FgcPHkRTUxNyc3Oxfft2AIPnOi0WC44dO4Zbt27h0KFD2Lt3L2JjY8fdRsFkPGbSt/hfuCcFPP744ygsLAQALF++HB9//HFY1/+gvr4+fiVSlBqaXUmShhVnbm4uqqurkZ2dHZaxnD9/HitXrsQjjzwCAKioqEBBQQFaW1ths9mwfPnywO8+//zzOHr0KP7880889NBDqo6L+Y5eoigOy7goioEjzj179mD79u14+eWXx30Nv3btGjo7O7F161bExAzWRG5uLmpqavDcc89hxowZAIBXXnkFq1atwhtvvBG47YYNGxAbG4v8/HzEx8dj6dKlSE1NDSzj+vXrgbkp06ZNw9q1ayEIAkpKSvDhhx+iqalpWO4B4NSpU3jppZcCz8vy8nK8//77aGtrQ1ZWFqqrq1FWVoaGhgasW7cODz/88ITbKayFGe490LS0tMDf4+Li0N/fD4/HE3gww+3zzz/Hrl27NFk3hWa0i5b9xfn9999jwYIFuH79eljG0tHRMezJnZCQgJSUFHR0dMBms+HEiRM4c+YMOjs7IQgCent7cffuXdXHxXxHL7fbPerbjf5LKV577TVcunQJNTU1Yy6jvb0dmZmZI15fOzo6kJWVFfh3VlYWPB4P7ty5E/g/fzkCQGxs7Ih/D72kIyMjY9jEo8zMzGGnJPxu376Nffv2Yf/+/YH/kyQpMB6bzYa8vDx89913ePbZZ8e8X0OFtTBNJsW+ezpkWsz0Wr16NSoqKsK+Xgqd3W4f8Q3sFosFZrMZa9asQWVlJQDg3r17qo8lIyMDt2/fDvxbFEU4nU5kZGTg6tWrOH78OGprazFr1iyYTCbk5+cHjgzUzD3zHb1OnDiBTZs2jbjGMjExEdOnT8fbb7+NFStWjHgODDV9+nS0t7ePOCjJyMgInDsEBossJiYGqamp+OeffyY91o6ODkiSFMhye3s7Fi1aNOp4ysvLsWzZslGX09TUhJ9//hlz5szBO++8I2tnL5gOC7r1Imk6cmpqKhwOR1jXGUk7DBQ8i8WCuLg4rF27Fr///jvq6upgt9vDlu8lS5agvr4e169fh9vtxuHDh5GTkwObzYb79+/DbDZj2rRp8Hq9eO+994a9CKampqKtrU2Vd3uYb/1ITEzErFmzcPLkSfz2228oLS2FIAjjZjwnJwdpaWk4ePAgRFFEf38/fvrpJyxZsgQnT56Ew+GAKIo4fP
gwiouLg36n799//8Unn3yCgYEBNDY24ubNm3jiiSdG/F5ZWRnq6urwxx9/ABh8l6ixsREAcPfuXezatQu7d+9GVVUVLly4gKampgnXHdbCNJvNwd5UcevXr8exY8eQn5+PEydOhGWdfEGJbjExMaMWpV+48j1v3jxs2rQJW7ZswaJFi/D333/jrbfeAgAUFBSgoKAAy5YtQ1FREWJjYzF9+vTAbYuKigAA8+fPR1lZmaLjYr6j28DAwJhF6Tdexs1mM95991389ddfKCoqwuLFi9HQ0ICVK1di2bJleOGFF1BSUgKr1YodO3YEPc6cnBzcunULhYWFOHLkCA4cOICpU6eO+L0nn3wS69atw7Zt2zB37lysXLkSFy9eBABUVlZi0aJFKCwsxNSpU1FZWYndu3dP+A5RMBkXpCBn79y7d0+1vdtIJwgCbDbbqA8sRb6SkhLYbDZUVlYOK8mhmG/mO1o1NjZi27Zt2LNnD1asWDHmkaTWGa+vr8eZM2fw0UcfhX3dwWY86HOY8fHxwd5UF4x+/6NZQ0PDhL9j9MfX6Pc/mhUXF6O4uHjC3zP6YxzM/Q/6fRer1Wroz5u0Wq1aD4FUxHwz33rHjE8+4yFN+jHqkyo2NjaiJj2R8phv5lvvtM54aWmpJm/HAsFnPKQz+1OmTAnl5lErISFB6yFQGDDfpHfM+OSEVJgJCQmGm01nMpn4gmIQzDfpHTM+yduGsuKkpCTDvQcuSRKSkpK0HgaFAfNNeseMT05IhWk2m5GcnBzKIqJOcnJyRF2DSuphvknvmPHJCflYPC0tzTATBARBGPaZtqR/zDfpHTMuX8iFGR8fb5jZhFar1fDXLhkN8016x4zLp8jZXiPsoZhMJqSnp2s9DNIA8016x4zLXIYSA0lJSdH9xgZguPf6aRDzTXrHjMujSGGaTCbYbDbdbnD/5w4abfo1DWK+Se+YcXkUe4akpKQgMTFRqcVFlMTERKSkpGg9DNIQ8016x4xPTNFdSj3upfr3vIiYb9I7ZnyCZSmylP/FxMTo6rBeEATY7fagvxyV9IX5Jr1jxsen+K6E/7A+2je4IAhITEzkRAgahvkmvWPGx6bKsbfdbo/663qsVitmzJih9TAoAjHfpHfM+OhUKUyz2Yzs7GxYLBY1Fq86q9WK7Oxs3b2XT8pgvknvmPHRqfaMMZvNmDlzZtRtcIvFguzsbH6eJo2L+Sa9Y8ZHEiSVP6re4/GgpaUFbrc7oj8V3/9lqnwxoclgvknvmPEh61C7MAHA6/XC4XCgt7c3Ije4/+Sw3W7niwlNGvNNeseM/7+ecBSmn9PpRGtrK3w+X7hWOSGTyQS73c7ZghQy5pv0zugZD2thAoOH962trZrvqfj3SGw2G69DI8Uw36R3Rs542AvTz7+nAiCseyv+WVM2m40fB0aqYb5J74yYcc0KExjcyN3d3ejs7FT9hLL/hHB6ejqSk5M5pZ5Ux3yT3hkt45oW5lAulwtdXV3o7u6GIAiK7LGYTCZIkoTk5GSkpaXxy3FJM8w36Z0RMh4xhenn9XrR09MDURQhiiL6+/sDPxtvqEM/xik2NhYJCQlISEhAUlISZwZSxGC+Se/0nPGIK8wHSZIEt9sNl8uFvr4++Hy+wB+TyRT4ExcXh/j4eFit1qj/DEQyDuab9E5PGY/4wiQiIooEnBlAREQkAwuTiIhIBhYmERGRDCxMIiIiGViYREREMrAwiYiIZGBhEhERycDCJCIikoGFSUREJAMLk4iISAYWJhERkQwsTCIiIhlYmERERDKwMImIiGRgYRIREcnAwiQiIpKBhUlERCQDC5OIiEgGFiYREZEMLEwiIiIZWJhEREQysDCJiIhkYGESERHJwMIkIiKSgYVJREQkAwuTiIhIBhYmERGRDCxMIiIiGf4Dqv6edZEELNkAAAAASUVORK5CYII=\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {'int': ['float'], 'float': ['complex']}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {'int': [0, 0], 'float': [1, 0], 'complex': [2, 0]}\n",
"fig, ax = plt.subplots(figsize=(8, 2))\n",
"nx.draw(graph, with_labels=True, node_size=4000, node_color='lightgray', pos=pos, ax=ax, arrowsize=20)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "80qo0-xqSbYH"
},
"source": [
"This lattice is a compact encoding of the information in the promotion table above. You can find the result of a type promotion for two inputs by following the arrows from each node to the first node that both can reach (counting each node as reachable from itself); mathematically, this node is known as the *supremum*, *least upper bound*, or *join* of the pair on the lattice; here we will refer to this operation as the **join**. For example, the join of `int` and `float` is `float`, matching the corresponding entry of the table.\n",
"\n",
"Conceptually, an arrow means that *implicit type promotion is allowed* between the source and the destination: for example, implicit promotion from integer to float is allowed, but implicit promotion from float to integer is not.\n",
"\n",
"Keep in mind that in general not every directed acyclic graph (DAG) will satisfy the properties of a lattice. A lattice requires the existence of a unique least upper bound for every pair of nodes; so, for example, the following two DAGs are not lattices:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"cellView": "form",
"id": "qfKmOZG3xRzl",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAB7CAYAAABwzVpnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAefklEQVR4nO3df1BU1/k/8Pe9u+zCLltZQF2JfASiYtUYixYKTpT4M9HUpCbaWGusnSnaWnASm6bpTLUqOibtpKPIxImNNk7zw5Z2KkYn2gzYmhp/EvE3DikYqiKCoMuG3YW79/uHs3wBQRHO3b3n3Oc147SDm3NP9nm4ee557j1XUlVVBSGEEEKIwORwT4AQQgghRGtU8BBCCCFEeFTwEEIIIUR4VPAQQgghRHhU8BBCCCFEeFTwEEIIIUR4VPAQQgghRHhU8BBCCCFEeFTwEEIIIUR4VPAQQgghRHjmcE+AkHBTVRV+vx9+vx+BQACqqkKSJMiyDIvFAovFAkmSwj1NQu6L8ti4KPa9QwUPMRxVVeHxeOB2u+HxeODz+QCg2xNC8FVzVqsVdrsdDocDdrudTh4k7CiPjYti3zcSvTyUGIWiKLh16xYaGhoQCAQQCAQeegxZliHLMuLi4hAbGwuTyaTBTAnpGeWxcVHs+4cKHiK8QCCA2tpaNDY2Avj/Vzz9Ebw6cjqdcLlckGW6HY5oi/LYuCj2bFDBQ4Tm8XhQU1MDRVGYnCS6kiQJJpMJiYmJsNvtzMcnBKA8NjKKPTtU8BAhdbwiCkWKS5JkqCslEhqUx8ZFsWePCh4iHEVRUF1dDa/XG5ITRZAkSYiMjERSUpKh+uJEG5THxkWx1wYVPEQoiqKgqqoKPp8vpCeKIEmSYLVakZycLOQJg4QG5bFxUey1I+a6FTGkQCCA6urqsJ0ogLs3E/p8PlRXV/fpCQpCKI+Ni2KvLSp4iDBqa2tDvgTcHVVV4fV6UVtbG9Z5ED5RHhsXxV5bVPAQIXg8npDd3NcbqqqisbERHo8n3FMhHKE8Ni6Kvfao4CHcCwQCqKmp0c2JIkhVVdTU1Ai3LEy0QXlsXBT70KCCh3CvtrYWiqKEexrdUhRFuGVhog3KY+Oi2IcGFTyEa4qi6GoZuKvgsrBeT2ZEHyiPjYtiHzpU8BCu3bp1K9xT6JXglvCEdIfy2Lgo9qFDBQ/hlqqqaGho0O2VUZCqqqivr9f9PEl4UB4bF8U+tKjgIdzyeDzc3EynKIpQTzsQdiiPjYtiH1pU8BBuud1ubk4WqqrC7XaHexpEhyiPjYtiH1rmcE+AkL5iebUxa9YsNDQ0QJZlmM1mjB8/HqtXr4bL5WJ2DN6vjog2tMiLpUuXoqKiAocOHYLFYmE6NuXx/amqiurqaiQnJz/ws6y/y3379mHXrl2oqqqC3W5HamoqcnJykJaWxmR83mNPKzyES8Htz1kqKCjA8ePHUVpairi4OGzcuJHp+OHcLp7okxZ5fPXqVZSVlUGSJJSWljIdG6A8fpBz584hJSUFU6ZMwfHjx3v8HOvYv/fee3jzzTfxk5/8BIcOHcLBgwfx4osvMs0B3mNPBQ/hkt/v12xsq9WKGTNm4L///S/zsbWcN+GPFvlQXFyMcePG4dlnn0VxcTHz8QHK4/tpa2uDw+HAv//9bzz55JM9Fj4sv0O3243CwkL8+te/xvTp02Gz2RAREYHs7GysWrWK2XEAvmNPBQ/hkt/vhyRJmozd0tKCAwcOYNy4cUzHlSSJ65MFYU+LPN67dy/mzJmDOXPm4MiRI6ivr2c6PuXxgwVj+vXXX/dY+LCMfXl5Ofx+P6ZNm8ZkvJ7wHnu6h4dwSYsb/VauXAmz2YyWlhY4nU5s27aN+TF4uUGRhAbrfCgrK8P169cxa9YsOJ1ODB06FPv378dLL73E7BiKouD8+fNobW1lNqZILl++fM8mfcHCZ/LkyRg1ahQ2btyISZMmMT
tmU1MTYmJiYDZr/590ns9hVPAQLmnRR968eTMyMzOhKApKS0uxdOlS7NmzB/Hx8cyOwXP/m7DHOh+Ki4uRmZkJp9MJAJg9ezaKi4uZFjw+nw9/+ctfcOLECWZjiqS5uRler7fbv/P5fCgvL8fq1avx6aefMjtmTEwMmpqa0NbWpnnRw/M5jAoewiWt2lkAYDKZMH36dKxbtw5lZWWYOXMms7G1nDfhD8t88Hq9OHDgABRFQXZ2NoC7bRO3242KigqkpqYyOY7NZsOaNWswYMAAJuOJ5osvvkB2djbu3LnT6ec2mw3jx4/H7373O2RlZeH27dvMjvn444/DYrGgpKSE6fmqOzyfw6jgIbrz8ccf49q1a5g2bRpSUlK6/QWTZe1uP1NVFaWlpbhz5w5SUlKYjq3lvAl/WOZDSUkJZFnG3//+d0RERLT/fNWqVSguLsarr77K7FiUx73XtdAJYvkdOhwOrFixAhs2bIDJZEJWVhbMZjOOHj2KEydO4JVXXmF2LJ5jTwUP0Z2ioiK8//77iIiIgM1mw7Rp0/DMM89g6tSpeOSRRwAAFouF+dJqbm4uZFmGJElISEhAfn4+hg8fzmx8VVWZ74lC+MYyj4uLi/Hcc89hyJAhnX6+cOFCbNq0CS+//DKTdgfl8YO1tLT0WOgEsT6HLVmyBHFxcXjnnXfw+uuvw2azYfTo0cjJyWF2DN5jL6k8N+QYU1UVfr8ffr8fgUAAqqpCkiTIsgyLxQKLxcL1ch4v/vSnP+HnP/95p02uoqOj0draitjYWGRlZWHy5MmYNm0aV/1kSZIwevRoyiHSTlVVXLhwgfJYIHV1dVixYgVefvnlbgudIIp96Bm64FFVFR6PB263Gx6Pp30TqO6CGfyarFYr7HY7HA4H7HY7t4HXs4sXLyI9PR3Nzc09fsblcuHw4cM93hyoR5GRkUxXjIgYKisrKY8NimIfWoZsaSmKglu3bqGhoQGBQOCex+zuVwN6vV54vV40NjZClmXExcUhNjYWJpNJ62kLr7W1FWfOnMHnn3+OlpaWbj8TFRWF+fPnY/v27WhoaODqZGG328M9BaJDdrud8tigKPahZaiCJxAIoLa2Fo2NjQD693hdsFCqq6tDXV0dnE4nXC4X1zd0hZKqqrhy5QqOHTvW/qe8vBxJSUnIyMjA0KFDceXKlU7/TFRUFN566y0sX74cwN0b9RobG7nYF0KSJDgcjnBPg+gQ5bFxUexDyzAFj8fjQU1NDRRFYdozDY7V2NiIO3fuIDExkfsqWAtNTU04ceJEe3Fz/PhxmEwmZGRkICMjA/n5+Zg4cWL7L9Rrr72G3//+9wgEAjCbzXA4HNi3bx8yMzPbx7Tb7ZBlmYuThclkorwg3eIpjyVJojzuBbfbjc2bNyMqKgoDBgzo9CcmJqb9/9tsNm5iL8I5TPh7eDqu6oTiX1WSJMOv9rS2tuLs2bOdVm9qamqQlpbWXuAEV3F6ugequLgYixYtgqIoGDFiBD755JN7nj4BgJs3b6Kurk7XN/5JkoTBgwcz3cCQiIWHPG5tbcW2bdvg9XqRm5uLiRMnhntKutXY2IhBgwYBuPs0lslkgizLUFUVgUCg/eGYtLQ0fPLJJ7qPvSjnMKELHkVRUF1dDa/XG9JkkiQJkZGRSEpKEv7enge1poJ/xowZ81CPxN64cQMulwuLFy/G9u3bYbVau/2coii4dOmS7k8Wo0aNEj4XSN/xkseDBg3Czp07UVhYiEceeQR5eXl4/vnnO+37Q+76wQ9+gN27d/e4emOz2fDFF1/g0Ucf5SL2IpzDhC14FEVBVVVV2F5nL0kSrFYrkpOTuU+Sjm7fvt2pNXXs2LH21lR6ejoyMjIwceJEfOMb3+j3sSorK3v1RMC1a9dCtoL3sIIrfgkJCeGeCtE5nvK4ra0Ne/fuxZYtW3D58mUsX74cOTk5GDx4cJhnqh+nT5/GpEmT8PXXX9/zd1FRUSgpKcF3vvMdAHzFnmdCFj
yBQABVVVUhX9npKrjSk5yczGV7i0VrKhQCgQAuX76Mtra2sM2hJ2azGSNHjuQy/iS0eM3jM2fOYOvWrfjrX/+KuXPnUrsLdy+49+/fj0WLFsHtdnf6u6ioKLz77rtYuHBh+894jT1vhCx49FQt81Idd2xNHT9+HMeOHcPp06c7tabS09MxduzYkLyR92F5PB5UV1frIuZBkiQhKSmJ+xv9SOjwnMcNDQ149913Dd3uampqwo4dO7B161bEx8cjMzMTO3bsaN9TzGazYdWqVVi3bt09/yzPseeFcAUPJU3vdNeakmW508oNq9ZUqFChS0TAex4bsd114cIFFBQUYPfu3Xj66aeRl5eHjIwMtLW1weVyoaGhAVFRUXj66adRVFTU44o477HXO6EKHloW7B4vran+olYmEYFIedy13ZWXl4cJEyYwnml4BNtWW7Zswblz57Bs2TIsW7bsnqdJN2zYgDVr1mDs2LE4evQoIiMjexxTpNjrkVAFj56q465CVS3frzUVvKk4IyNDt62p/qKb1YkIRMvjju2uoUOHIjc3l9t2V9e21cqVK/HCCy/0+CRpfX09Fi9ejF27dmHgwIEPHF+02OuJMAUPL491sn60T8TWVH/RdgREBCLmcdd2109/+lPk5OS071mjZz21rbQgYuz1QJiCh4eNu/q7eZNRWlMs0IaTRAQi5zEP7a7etq20IHLsw0WIgkdVVVRUVOjy3p2uzGYzUlNTH1iQGL01xYpWrxQJkiQJJpOJXilCNCVyHuux3dW1bZWXl4f58+f32LbSksixDzUhCp7m5mZ89dVXXLyPRJIkDBs2DNHR0Z1+Tq0p7bB8aWxQsGAV/YqI6IfoeayHdteFCxewdetWfPTRR5q3rR6G6LEPFSEKnuvXr6OhoSHc0+g1p9OJGzduUGsqxBRFQWNjI+rr6/t8tRS8GoqPj4fT6RSyz030zQh5HMp2VzjbVg/LCLHXkhAFT2VlJbxeL9Mxly5dioqKChw6dAgWi4Xp2BcvXsTatWupNRUmqqrC4/HA7XbD4/HA5/MBQLfFZfDXw2q1wm63w+FwwG63UyFKws4Ieaxlu0tPbauHZYTYa4H7gkdVVVy4cIFpb/Pq1auYPXs2oqOjsXr1asyaNYvZ2MDdpBw9erQhE06PVFVtf3txIBCAqqqQJAmyLMNiscBisVCsiO6JnMdtbW0oLi5GQUFBv9tdem1b9YfIsWeJ+4LH5/OhsrKSacHz9ttv48iRI3jsscdw5coVFBYWMhsbuFvwDB8+nIsrCUII0ZMzZ86goKAARUVFvW538dS2Itrh/i4lv9/PvHLdu3cv5syZgzlz5uDIkSOor69nOr4kSfD7/UzHJIQQIxg3bhy2b9+OyspKjBkzBvPmzcOkSZOwe/dutLa2dvpsU1MT3nrrLYwcORLr16/HkiVLUF1djd/+9rdU7BgQ9wUP6yezysrKcP36dcyaNQtjxozB0KFDsX//fqbHANjPmxBCjCQuLg6//OUv8eWXX2LVqlXYtm0bkpKSkJ+fj8OHD+NnP/sZkpOTcerUKXzwwQc4fvw4fvjDH9LKuoFxX/Cw7sgVFxcjMzMTTqcTADB79mwUFxczPQbAft6EEGJEZrMZ8+bNw6efforXXnsNb7/9NqZMmYIjR47go48+wvvvv8/9PTqEDe4fC2LZzvJ6vThw4AAURUF2djaAuy0zt9uNiooKpKamMjsW3UBGCCH919TUhJ07d2Lr1q2Ii4vDG2+8galTp+LPf/4zcnJyMHToUOTl5WHevHlcvruLsMP9Cg/LzZJKSkogyzL27NmDoqIiFBUVobi4GGlpacxXeYywyRMhhGjl4sWL7W2rkydPdmpbJSQk3NPuSk5ORn5+Purq6sI9dRIm3P9X12KxMGsPFRcX47nnnsOQIUMQHx/f/mfhwoXYt28fs1dXqKrKfG8fQggRnaIo2Lt3L2bOnImpU6di4MCBOH/+fI9tq2C7q7S0FPv378eVK1
eQmpqKH/3oRzh16lQY/g1IOHH/WLoW+/CEwqhRo2ijQUII6YWubav+bBLYdTNDancZB/cFD6DNTstaqqysxJIlSzBx4sROuy3TY5LhQZt2ERGImMcXL15EQUEBPvzwQ8yePZvpJoHBzQy3bNmCyspKLF++POTv7mJFxNhrQYiCh7d3acXFxcFisbS/BT34RnS73d5e/KSnp2PChAnCv702HGhbdiICUfM4uElgQUEBzp49i5ycHCxbtgwJCQmaHbPjZobPPvsscnNzNXt3Fwuixl5rQhQ8IrwtXVVVVFZWdiqCzp07hxEjRnR6oeioUaMM9bI3lhRFwa1bt9DQ0IBAINCnfJFlGbIsIy4uDrGxsRQLEnKi5nHHtlVsbCxWrlwZ8ndbNTQ04I9//CMKCwuRmJiou3aXqLEPFSEKHlVVUVFRweymYi2ZzWakpqb2qrr2+Xw4ffp0p7eq37x5k1phDykQCKC2thaNjY0A2OyBFIyf0+mEy+Wip+6I5kTN44sXL2Lr1q348MMP8dRTT7W3rcK5AqG3dpeosQ81IQoeALh58ybq6up0ffOyJEkYPHgw4uPj+zxGQ0MDtcIegsfjQU1NDRRF0SQ3JEmCyWRCYmIifedEM6LlcSAQaH+31ZkzZ9rfbaVl26qvysvLUVBQgL/97W9haXeJFvtwEqbgURQFly5d0n3Bw7olFWyFdSyAqBXW+YooFDkhSZKhrpRIaIiWx8G2VWFhIZxOZ1jaVn0V6naXaLHXA2EKHgC4du1ayJLjYQWTKRRXMEZvhSmKgurqani93pDmgiRJiIyMRFJSkqGKS6INkfJYj22rvnrYdtf//vc/zJgxAx9//DEeffTRXh1DpNjriVAFTyAQwOXLl3V5L4/ZbMbIkSPDVjnX19e3t8KC/xsdHS1cK0xRFFRVVcHn84Wl8JUkCVarFcnJyUKeMEhoiJDHPLWt+qo37a5XX30Vf/jDH5CQkIAzZ84gJibmvmOKEHu9EqrgAe72O6urq3W1yiNJEpKSknRVTHRthQWfChs5ciS3rbBAIICqqqqQXxV1FbxKSk5OFnZpmGiH9zzu2rbKy8vDggULuGhb9VVP7a62tjYMHjwYbrcbFosFEyZMwL/+9a8e22C8x17vhCt4AH21tkLZyuovr9fb3goLrgLx1AqjuBMR8JrHIrWt+qqtrQ179uzBli1b8OWXXyI9PR0HDx6Ex+MBAERFRWHBggXYuXNnt98Lr7HnhZAFD1XJ7HRshQULoY6tsIyMDEyYMAE2my2s86SVPSIC3vI42LYqKChAeXm5kG2rvjp9+jQmT54Mt9vd6ec2mw1r167FL37xi04/5y32PBKy4AGoD6qVnlphqampnVaBvvnNb/a7yDt06BCeeOKJB35/dO8WEQFPeXz79u32TQKN0rZ6WJ999hmeeuqp9tWdjqKiorB7925897vfBcBX7HkmbMED0J3uoRJshXVcCQq2wjquBLlcrl6PWVNTg//7v//DlClT8I9//OO+N/rpaRm4KxGXhYk2eMjj27dvG75t1Vtz5szB/v37e/x7s9mMf/7zn8jOzuYi9iKcw4QueADayyBcumuFORyOTqtA92uFFRUV4cc//jF8Ph/i4+Nx8OBBjBkz5p7PGXX/JSIWHvK4tbUV8+fPx6JFi6ht9QB1dXUYMmQIoqKiYDaboaoqAoEA2tra4Pf7Adx9xUNWVhZKSkp0H3tRzmHCFzxBtFtleHXXCjt//ny3T4XJsoyVK1eioKCgPVY2mw27du3C888/32lco+ywTcTGQx6rqoqBAwc+1EqtUXm9XhQVFSEqKgoDBgy4509kZGT7Z3mIvSjnMMMUPAC9j0Rv7tcKO3v2LG7evNnp81FRUVixYgU2bdoEk8kk7DvUiLFQHhsXxT60DFXwBCmKgsbGRtTX1/d5xSe4ohMfHw+n08n9Up9e1NfX4/PPP8f3vvc9KIpyz99brVZ8+9vfxt69e2E2m/HVV1/16Y3BoS
ZJEoYNG4bo6OhwT4XoTHNzM+WxQVHsQ8sc7gmEQ7BQiYuLg8fjgdvthsfjgc/nA4BuK9hgUWS1WmG32+FwOGC327mudvUoPj4eiYmJsNls9zzOCdx9bcZnn32GYcOG4dKlS8xPFPv27cOuXbtQVVUFu92O1NRU5OTkIC0trV/jqqoKt9vN9cmCaMPtdjPL41mzZqGhoQGyLMNsNmP8+PFYvXo1szYU5TFbrGIfjLvJZILJZEJKSgrmzp2LF154gVnXQYTYG7LgCZIkCdHR0e0BVFUVfr8ffr8fgUAAqqpCkiTIsgyLxQKLxUIFTggcO3as0+qO2WyGzWaD1+tFWloaZs+ejczMzG4f9+yP9957Dzt27MBvfvMbZGVlISIiAv/5z39QWlra74IHAPP5EjGwzouCggJkZmbC5/MhPz8fGzduxJYtW5iNT3l8f1evXsXcuXOxatUqfP/737/v6j/L7zIYd7fbjZMnT+KNN97AmTNnkJ+fz+wYvMfe0AVPV8G9c2gvifA6fPgwWltbERERgdGjR2Pu3LmYMWMGMjIyYLFYANwtTi9cuMDsmG63G4WFhVi/fj2mT5/e/vPs7GxkZ2czOUZwTygqmkmQqqrtK8usWa1WzJgxA2+++SbTcSmP76+urg7nz59HTk4OfvWrX2HTpk3dFj5axd7hcODJJ59EfHw8Fi1ahCVLlmDEiBFMxuY99nSHLdGdvLw87Nu3D01NTTh9+jTWrVuHJ554or3YAdD+aCcr5eXl8Pv9mDZtGtNxu2I9b8I3LfOhpaUFBw4cwLhx45iPTXl8f1artf3J4JycHCQnJ+ODDz7otHKt9Xf42GOPYfDgwSgrK2M6Ls+xpxUeojvp6ekP/Izf74ckScwe5WxqakJMTAzMZu1+JSRJgt/vpxVE0o51HgPAypUrYTab0dLSAqfTiW3btjEbG6A8flgejwcej+eeFR8tYt/VoEGDcPv2bWbj8R57KngIl1jfrBwTE4Ompia0tbVpWvTw8DQGCR0t8mHz5s3IzMyEoigoLS3F0qVLsWfPHmZ7qLjdbuTk5ODIkSNMxhON3+/v9oGLYOGzaNEirFmzBidPntR8Ljdu3MCAAQOYjsnzOYwKHsIl1ldFjz/+OCwWC0pKSjBz5kymY3dkwF0gyH1omQ8mkwnTp0/HunXrUFZWxiyvo6OjsX79ekRERDAZTzTnzp3Diy++iObm5k4/j4iIgCzLmDt3LtatW6f5ueDcuXOoq6vDt771Labj8nwOo4KHcIn1TXMOhwMrVqzAhg0bYDKZkJWVBbPZjKNHj+LEiRN45ZVXmByH15v9iDa0zAdVVVFaWoo7d+4gJSWF2biSJGHgwIHMVw5EcevWrU6PgkdERMBkMmHhwoVYu3YtEhMTAYBpq6mj5uZmnDp1Cps2bcIzzzyDkSNHMh2f53MYFTyES1rsaL1kyRLExcXhnXfeweuvvw6bzYbRo0cjJyeH2TFoJ27SkRb5kJubC1mWIUkSEhISkJ+fj+HDhzM9BuXx/SmK0mOhE8T6O8zNzYXJZIIsy0hJScFLL72EBQsWMD0GwHfsDbnTMuGfz+dDZWUlV8urkiRh+PDh3N7wR9ijPBZPRUUFxo4di8WLF3db6ARR7EOPVngIlzo+os4TXudNtMFrPvA671BITU3FnTt3EBUVdd/P8fod8jpvgPbhIZwKbhLJE6vVynX/m7BHeSymBxU7AMU+HKjgIdyy2+3hnsJD4W2+JDR4ywve5qtnvH2XvM23Kyp4CLccDgc3N9BJkgSHwxHuaRAdojw2Lop9aPHxTRPSDbvdzs3JwmQycX91RLRBeWxcFPvQ4uObJqQbkiQhLi5O9z1lSZIQHx+v+3mS8KA8Ni6KfWhRwUO4FhsbG+4p9IrT6Qz3FIiOUR4bF8U+dKjgIVwzmUxwOp26vfKQJAlOpxMmkyncUyE6RnlsXBT70KGCh3DP5XLp9pfRZDLB5XKFexqEA5THxkWxDw0qeAj3ZF
lGYmKi7q6QJElCYmIiNzclkvCiPDYuin1oiPFvQQzPbrfralk4uAzM+1MNJLQoj42LYq89KniIMFwuFyIjI8N+wpAkCZGRkcIsA5PQojw2Loq9tqjgIcKQZRlJSUlh3f48uF18UlKSMMvAJLQoj42LYq8tels6EY6iKKiurobX6w3pm4iDV0VJSUm6vQGR8IPy2Lgo9tqggocIKRAIoLa2Fo2NjSE5YQT73S6XS7irIhI+lMfGRbFnjwoeIjSPx4OamhooiqLJSUOSJJhMJiQmJgp1cx/RF8pj46LYs0MFDxFexyslAExOGsH+uuhXREQ/KI+Ni2LPBhU8xDAURUFjYyPq6+v7fLUUvBqKj48XZvdRwhfKY+Oi2PcPFTzEcFRVhcfjgdvthsfjgc/nA4Bun4oI/npYrVbY7XY4HA7Y7fawPzZKCOWxcVHs+4YKHmJ4qqrC7/fD7/cjEAhAVVVIkgRZlmGxWGCxWAx5ciB8oTw2Lop971DBQwghhBDhiX+XEiGEEEIMjwoeQgghhAiPCh5CCCGECI8KHkIIIYQIjwoeQgghhAiPCh5CCCGECI8KHkIIIYQIjwoeQgghhAiPCh5CCCGECO//AcN1Ixa6KgO0AAAAAElFTkSuQmCC\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"\n",
"fig, ax = plt.subplots(1, 2, figsize=(10, 2))\n",
"\n",
"lattice = {'A': ['B', 'C']}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {'A': [0, 0], 'B': [1, 0.5], 'C': [1, -0.5]}\n",
"nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[0], arrowsize=20)\n",
"ax[0].set(xlim=[-0.5, 1.5], ylim=[-1, 1])\n",
"\n",
"lattice = {'A': ['C', 'D'], 'B': ['C', 'D']}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {'A': [0, 0.5], 'B': [0, -0.5], 'C': [1, 0.5], 'D': [1, -0.5]}\n",
"nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[1], arrowsize=20)\n",
"ax[1].set(xlim=[-0.5, 1.5], ylim=[-1, 1]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aDBWlvmezJju"
},
"source": [
"The left DAG is not a lattice because nodes `B` and `C` have no upper bound. The right DAG fails on two counts: first, nodes `C` and `D` have no upper bound; second, the least upper bound of nodes `A` and `B` cannot be *uniquely* determined, because both `C` and `D` are candidates and neither is ordered before the other."
]
},
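{
"cell_type": "markdown",
"metadata": {},
"source": [
"The join lookup described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the helper names `upper_bounds` and `join` are hypothetical (they are not part of JAX or NumPy), and the brute-force search assumes a small DAG given as a dict mapping each node to the nodes it can be directly promoted to. `join` returns `None` whenever a pair has no upper bound or no unique least one, which is how the two DAGs above fail to be lattices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def upper_bounds(dag, node):\n",
"    # Return `node` together with every node reachable from it by following arrows.\n",
"    seen, stack = {node}, [node]\n",
"    while stack:\n",
"        for child in dag.get(stack.pop(), []):\n",
"            if child not in seen:\n",
"                seen.add(child)\n",
"                stack.append(child)\n",
"    return seen\n",
"\n",
"def join(dag, a, b):\n",
"    # Return the unique least upper bound of a and b, or None if there is none.\n",
"    common = upper_bounds(dag, a) & upper_bounds(dag, b)\n",
"    # A least upper bound is a common upper bound that reaches all the others.\n",
"    least = [n for n in common if common <= upper_bounds(dag, n)]\n",
"    return least[0] if len(least) == 1 else None\n",
"\n",
"python_lattice = {'int': ['float'], 'float': ['complex']}\n",
"print(join(python_lattice, 'int', 'complex'))  # complex\n",
"print(join(python_lattice, 'float', 'int'))    # float\n",
"\n",
"left_dag = {'A': ['B', 'C']}\n",
"right_dag = {'A': ['C', 'D'], 'B': ['C', 'D']}\n",
"print(join(left_dag, 'B', 'C'))   # None: B and C have no upper bound\n",
"print(join(right_dag, 'A', 'B'))  # None: C and D are both minimal upper bounds"
]
},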
{
"cell_type": "markdown",
"metadata": {
"id": "o50FMh8_VGHx"
},
"source": [
"### Properties of a Type Promotion Lattice\n",
"\n",
"Specifying type promotions in terms of a lattice ensures a number of useful properties. Denoting the join on the lattice with the $\\vee$ operator, we have:\n",
"\n",
"**Existence:** A lattice by definition requires that a unique lattice join exists for every pair of elements: $\\forall (a, b): \\exists !(a \\vee b)$.\n",
"\n",
"**Commutativity:** The lattice join is commutative: $\\forall (a, b): a\\vee b = b \\vee a$.\n",
"\n",
"**Associativity:** The lattice join is associative: $\\forall (a, b, c): a \\vee (b \\vee c) = (a \\vee b) \\vee c$.\n",
"\n",
"On the other hand, these properties restrict the type promotion systems that a lattice can represent; in particular, **not every type promotion table can be represented by a lattice**. A ready example is NumPy's full type promotion table, as a quick counterexample shows: here are three scalar types whose promotion behavior in NumPy is non-associative:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "AbApKMiPXls8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float32\n",
"float16\n"
]
}
],
"source": [
"import numpy as np\n",
"a, b, c = np.int8(1), np.uint8(1), np.float16(1)\n",
"print(np.dtype((a + b) + c))\n",
"print(np.dtype(a + (b + c)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_VGHxo50FMh8"
},
"source": [
"Such a result may come as a surprise to users: we generally expect mathematical expressions to map to mathematical concepts, so, for example, `a + b + c` should be equivalent to `c + b + a`, and `x * (y + z)` should be equivalent to `x * y + x * z`. If type promotion is non-associative or non-commutative, such equivalent expressions can produce results of different types.\n",
"\n",
"Further, a lattice-based type promotion system is simpler to conceptualize and understand when compared to a table-based system. For example, JAX recognizes 18 distinct types: a promotion lattice consisting of 18 nodes and sparse, well-motivated connections between them is far easier to hold in one's mind than a table of 324 entries.\n",
"\n",
"For these reasons, we opt to use a lattice-based type promotion system for JAX."
]
},
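{
"cell_type": "markdown",
"metadata": {},
"source": [
"The non-associativity counterexample above can be generalized with a brute-force search. The following is a minimal sketch, assuming that `np.promote_types` (NumPy's dtype-level promotion function) is an acceptable stand-in for the scalar arithmetic used in the counterexample; the names `promote`, `non_commutative`, and `non_associative` are illustrative only. The final line checks that the `int8`, `uint8`, `float16` triple from the counterexample appears among the associativity failures."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import numpy as np\n",
"\n",
"# A selection of NumPy integer, floating point, and complex dtypes.\n",
"names = ['uint8', 'uint16', 'uint32', 'uint64', 'int8', 'int16', 'int32', 'int64',\n",
"         'float16', 'float32', 'float64', 'complex64', 'complex128']\n",
"dtypes = [np.dtype(name) for name in names]\n",
"\n",
"def promote(a, b):\n",
"    # Dtype-level promotion; an assumption standing in for scalar arithmetic.\n",
"    return np.promote_types(a, b)\n",
"\n",
"# Pairs whose promotion depends on argument order (commutativity failures).\n",
"non_commutative = [(a, b) for a, b in itertools.product(dtypes, repeat=2)\n",
"                   if promote(a, b) != promote(b, a)]\n",
"\n",
"# Triples whose promotion depends on grouping (associativity failures).\n",
"non_associative = [(a, b, c) for a, b, c in itertools.product(dtypes, repeat=3)\n",
"                   if promote(promote(a, b), c) != promote(a, promote(b, c))]\n",
"\n",
"print(len(non_commutative), 'non-commutative pairs')\n",
"print(len(non_associative), 'non-associative triples')\n",
"print((np.dtype('int8'), np.dtype('uint8'), np.dtype('float16')) in non_associative)"
]
},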
{
"cell_type": "markdown",
"metadata": {
"id": "cdkJV3qqUrO_"
},
"source": [
"## Type Promotion within Categories\n",
"\n",
"Numerical computing libraries generally provide more than just `int`, `float`, and `complex`; within each of these categories there are a variety of possible precisions, denoted by the number of bits used in the numerical representation. The categories we will consider here are:\n",
"\n",
"- *unsigned integers*, which include `uint8`, `uint16`, `uint32` & `uint64` (we'll use `u8`, `u16`, `u32`, `u64` for short)\n",
"- *signed integers*, which include `int8`, `int16`, `int32` & `int64` (we'll use `i8`, `i16`, `i32`, `i64` for short)\n",
"- *floating point numbers*, which include `float16`, `float32` & `float64` (we'll use `f16`, `f32`, `f64` for short)\n",
"- *complex floating point numbers*, which include `complex64` & `complex128` (we'll use `c64`, `c128` for short)\n",
"\n",
"NumPy's type promotion semantics **within** each of these four categories are relatively straightforward: the ordered hierarchy of types translates directly to four separate lattices representing in-category type promotion rules:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"cellView": "form",
"id": "hi6YuTfyW03b",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de1xT9f8H8NcZAwY6kLVyjmlKhqlQiYqJgaaVmJfUsr6VZpmm5q30J17Thw81tAv21dAyCzOtLDI1K7Xs8tBUstRvKpq3MmGOGiBMx2CHnd8fPLZHBOouZ5+dc3w//9Jxds77vHb23jmfnbPDCYIggBBCCBOqUBdACCHXE2q6hBDCEDVdQghhiJouIYQwRE2XEEIYUoe6AEKUiOd5lJeXw2azweFwQBAEcBzXYDr34xqNBlqtFnFxcVCr6W3pKznlzdEpY4SIx263w2q1wmazAah7k3vL3SS0Wi30ej2io6ODUqOSyDFvarqEiIDneZjNZthsNp/e+FfCcRy0Wi2MRiPt+TZCznlT0yUkQJWVlSgqKoIgCKI0ADeO48BxHEwmE2JiYkSbr9zJPW9quoT4SRAEWCwWlJWVifrm/zeO46DT6WAwGBodp7xeKCVvarqE+EEQBBQXF6OioiKoDcCN4zjExsYiPj7+umy8SsqbThkjxA8Wi4VZAwDqmk5FRQUsFguT5UmNkvKmpkuIjyorK4N+iNsYQRBQVlaGyspKpssNNaXlTU2XEB/wPO/5EicUBEFAUVEReJ4PyfJZU2Le1HQJ8YHZbA5ZA3ATBAFmszmkNbCixLyp6RLiJbvdLtp5oYEQBAE2mw12uz2kdQSbUvOmpkuIl6xWa8gbgJsgCLBaraEuI6iUmjc1XUK8wPO851JTqbDZbIod21Vy3tR0CfFCeXl5qEtolFTrCpRU10uMuqjpEuIFKYwt/pt7rFGJlJw3NV1CvOBwOAJ6fllZGbKystC9e3ekpaVhxowZDaapqKhARkYGnnzySWZ1SZUv6/X3339j0qRJ6N27N5KTk1FcXFzv76+++ir69++Pbt26YeDAgdi6dWu9vxcUFOCRRx7BXXfdhczMTHzyySei1HUl9PNFhFyD0+kMeK/rhRdeQMeOHbFz505oNBqcPn26wTTLli1DmzZtfFqWIAhwOp0IDw8PqL5QOX78OG677bZ6l9r6mjfHcejRoweeeeYZjBgxosHfo6KisGLFCrRu3RpHjx7FuHHj0KpVK9x5551wOp14/vnn8cILL2DYsGE4duwYRo0ahdtvvx3t2rVrMC8x8qY9XUKuweFweH39vcViwfPPP4+MjAzcfffdWLx4Mfbu3QuLxYJp06ZBq9UiPDwc7du3r/e8w4cP49SpUxg8eLBPtXEcJ9u93aqqKnTo0AGJiYnYvHmzp9FeLe/G8tXr9fjPf/6DpKSkRp8zYcIEJCQkQKVS4fbbb0fnzp3xv//9D0Dd0cWlS5cwcOBAcByHpKQkJCQk4MyZM43OS4y8aU+XkGtwuVxeTVdbW4sJEyYgNTUV27dvR1hYGI4dO4aCggK0bt0ac+bMwZ49e2AymTBt2jR07drV87yXXnoJ8+fPx6lTp3yqjed5HD58WJaN1+FwQKVS4fTp03jssceg0+kwa9asRvdWgSvn6+syjx49ikcffRQAoNfr0a9fP2zevBmPPPIIjhw5ggsXLiAlJeWK8/B2e7gSarqEXIO3h7pHjhzB33//jWnTpnl+CDslJQVbt27F3r17sWDBAixcuBDffPMNpkyZgi+++AJxcXHYsGEDkpOT0bFjR5+bbk1NDbZt24aDBw/6vF6hVltbW2/v9sKFC1iyZAmGDx/e6PRXytcXCxcuRLt27dCjRw/PYw888ADmz5+PpUuXAgDmzp0Lg8FwxXkEOtRETZeQa/BlaKFFixYN7jyg0WgQHx+PoUOHAgD69euH1atX49ChQ0hKSsIHH3yAjR
s3+lVbdHQ0Zs+ejdjYWL+eH0p2ux1NmzZFkyZNkJCQgJycHPTp0+eKPzBzpXy99dprr+HUqVN49913Pa/p2bNnkZWVhWXLlqF79+44d+4cJk6ciJtuugkZGRmNzifQn3qkpkvINahU3n31YTAYYLFYwPN8vcaQmJiI77//vt607jeue+/twQcfBABUV1fD4XCgV69e2LVrF8LCwkSrT2oiIyMxfPhwPPnkk+jTp48nkyutz5Xy9UZubi727NmDvLw8NG3a1PP46dOncfPNN3v2fNu0aYOMjAzs3r37ik030Lzl+WoRwpBGo/HqkDI5ORl6vR6vv/467HY7qqurcejQIc/e25YtW1BbW4udO3eipKQEnTp1Qnp6Onbs2IH8/Hzk5+djwoQJaN++PfLz871quIIgQKPRiLGazIWFhWHdunW499576+09XinvK+UL1H1Y1dTUAKgbcqmurvY8b82aNfjyyy/x9ttvo1mzZvXm2b59e5w7dw4FBQUQBAHnz5/HDz/8gMTExEZrFiNvunMEIV4oLCz06guUCxcuIDs7GwcPHgTHcXjggQcwa9Ys/PLLL1i0aBGKi4vRpk0bZGVloXPnzg2ev3nzZmzatAnr1q3zqi6VSoUOHTr4vD5Sd6W8r5RvcnJyg2mPHDkCoK5Zh4eH19s7HjNmDMaMGQMA2L59O9566y2YzWY0bdoU/fv3x/PPP9/oHq0YeVPTJcQLZ8+eleSvekVHRyMhISHUZYhOyXnT8AIhXtBqtZK7N5n7tuFKpOS8qekS4oW4uLhQl9AoqdYVKKmulxh1UdMlxAtqtVpye5Vardbv06ekTsl5U9MlxEt6vV4yh7wcx0Gv14e6jKBSat7UdAnxUnR0tCTGGt1ji9HR0SGtI9iUmjc1XUJ8YDQaJdEE4uPjQ1oDK0rMm5ouIT5Qq9UwmUwhawQcx8FkMnl14YQSKDFvarqE+CgmJgY6nY55I+A4DjqdDjExMUyXG2pKy5uaLiF+MBgMiI2NZdYIOI5Ds2bNrvrrV0qmpLzpijRC/CQIAiwWC8rKyoJ6Py/3HpfBYAj5+GYoKSVvarqEBKiyshJFRUUQBEHUZsBxnGdM8XobUrgauedNTZcQEfA8D7PZLNpdbN2nKRmNRsVeABEIOedNTZcQEdntdlitVs+tun29wSJQd+WTXq9X/Hm4YpBj3tR0CQkCnudRXl4Om80Gh8MBQRAaHR90P67RaKDVahEXF0d7tn6QU97UdAlhwOl0wuFwwOVyed74KpUKGo1GtrdPlzIp501NlxBCGKLzdAkhhCFquoQQwhA1XUIIYYiaLiGEMERNlxBCGKKmSwghDFHTJYQQhqjpEkIIQ9R0CSGEIWq6hBDCEDVdQghhiJouIYQwRE2XEEIYoqZLCCEMUdMlhBCGqOkSQghD1HQJIYQharqEEMIQ3QHvOiTl+0cpFWXOlpTzpqZ7HZDTnVKVgjJnS055040pFcxut8NqtcJmswGo2+C85d5gtVot9Ho9oqOjg1Kj0lDmbMkxb2q6CsTzPMxmM2w2m08b4ZVwHAetVguj0Uh7YVdAmbMl57yp6SpMZWUlioqKIAiCKBujG8dx4DgOJpMJMTExos1XCShztuSeNzVdhRAEARaLBWVlZaJuiP/GcRx0Oh0MBkOjY2bXE8qcLaXkTU1XAQRBQHFxMSoqKoK6MbpxHIfY2FjEx8dft02AMmdLSXnTeboKYLFYmG2MQN0boKKiAhaLhcnypIgyZ0tJeVPTlbnKysqgH241RhAElJWVobKykulypYAyZ0tpeVPTlTGe5z1fKISCIAgoKioCz/MhWX4oUOZsKTFvaroyZjabQ7YxugmCALPZHNIaWKLM2VJi3tR0Zcput4t2jmIgBEGAzWaD3W4PaR0sUOZsKTVvaroyZbVaQ74xugmCAKvVGuoygo4yZ0upeVPTlSGe5z2XPUqFzWZT9DgjZc6WkvOmpitD5eXlXk/7+++/4+GHH0a3bt2wYc
OGIFblW11y4+26scwbUG7mSs6bmq4M+TLOlZeXh65du6KgoAC33norRo0ahe7du6Nv376NTr9+/XpkZmYiNTUVgwYNwh9//OHVctzjXkrlbeb/zLu2thaZmZm466670Lt3byxdutSzp1RaWoqsrCz07t0b3bt3x4gRI/Drr7/6VJOSM/cn7yeeeAKFhYUYOXIkUlNT0bNnT6xfv77Bcw4cOIDk5GQsX77cp5rEypuargw5HA6vpzWbzWjbti0AICoqCkOGDMHUqVMbnfbTTz/Fpk2bkJubi4KCAuTm5iIuLi4odcmNt+v2z7zvuecefPzxx9i/fz8+++wznDx50rM3Zrfb0bFjR2zcuBF79uzBoEGDMGHCBJ+/rFFq5v7kXV5ejvHjx2PYsGHYs2cPvvzyS6SlpdWb3ul0YunSpbj99tuDWtfVUNOVGafT6fVe7jPPPIMDBw7gpZdeQmpqKrRaLQYOHAiTydRgWpfLhVWrViErKwu33HILOI5Dy5YtERsb63VtgiDA6XR6Pb3U2O127N69u0G+3mb+77xra2s9P5zi/h3X8+fPAwBatmyJkSNH4sYbb0RYWBiGDRsGp9OJ33//3aea5Z75999/j+rq6nqP+Zt3dnY20tLSMGDAAERERKBJkyZISEio95z33nsPaWlpaN26tV/1ipE3NV2ZcTgcXl8L/s477yAlJQWzZ8/GTz/9dNUNraSkBCUlJTh9+jTuvfdeZGZmIjc3Fy6Xy+vaOI6T9Z7Xjz/+iIyMDHTo0AFffPGF543vbeaN5f3FF1/grrvuQnp6Ok6ePIlhw4Y1+twTJ07A6XSiVatWPtUs58ztdjvuuecexMfHIzc319N8/c27tLQUsbGxGD58OHr27ImJEyfiwoULnunNZjM2b96McePG+V2zGHnTD3XKjC9N0BclJSUAgL1792LTpk2w2WwYO3Ysmjdvjocfftjr2r799ltcvHgxKDUG29GjRxEVFYUTJ07goYceQlxcHLKzszFkyBC/59m/f3/0798f586dw9atW3HDDTc0mObSpUuYNWsWxo8fD61W69P8nU6nbDOvrq6GSqVCaWkppk2bhpkzZ2LMmDGYP3++X/MrKSnB8ePHsXr1atx6663IyclBVlYW3n//fQBAdnY2Jk6cGPCPlQf6HqSmKzPBOm8xMjISAPD0008jJiYGMTExGDZsGHbv3u11062trcXx48dx/PjxoNQYbGaz2fNFl9PpxF9//YVvv/0WgwcPDnjeN998M9q2bYtFixbh9ddf9zzucDgwceJE3HHHHRg9erTP8+V5XraZ8zzv2Z6rq6vhdDqxc+dOzJs3z6/5RUZGonfv3khKSgIAjB8/Hunp6bDZbPjll19gt9uRmZkZcN2Bvgep6cpMsH7Wr3Xr1ggPDw9o/uHh4Rg/frxP48BSsnPnTnz//feIiorC008/jRdffBEGgwEVFRWizJ/nec+YLgDU1NRgypQpaN68ud+NJioqSraZ2+12fPjhh4iKikKvXr2wdOlSJCcn+513YmJive33n/8uKCjAsWPH0KtXLwB1RxcqlQqnTp3CihUrfFpOoO9BGtOVGZXK/5fM5XKhurras4fh3rsA6t68mZmZyMvLw+XLl2GxWJCfn4+ePXsyqy/UkpKSMHPmTJw9exa5ubkwGAwA/F+nTz/9FKWlpQCAM2fO4J133kG3bt0A1O1JT506FZGRkVi8eHFAuck1c41Gg1mzZqGgoABffvklkpOTAfi/PoMHD8auXbs84+NvvvkmUlJSoNVqMXHiRGzbtg35+fnIz89Hr1698NBDD2HRokU+LyfQvGlPV2Y0Go3fhze//PILRo0a5fl/ly5d0KVLF+Tl5QEAZs+ejQULFqB3797QarV46KGHfBrPFAQBGo3Gr9qkwGg0YuHChQ0e9zfzQ4cOYfny5aiqqkJcXBzuv/9+TJw4EQBw+PBh/PDDD9BoNPVOa1q1ahU6d+7s9TLknLlKpcLixYsbPO
5v3t26dcOUKVMwYcIEVFVVISUlBUuXLgUANGnSBE2aNPFMGxkZiaioKJ+PEMTIm+4cIUOFhYVB+0ItECqVCh06dAh1GUFBmbOl5LzleVxynZPqno1U6xKDVNdNqnUFSqrrJUZd1HRlSKvVSu4+We5bWCsVZc6WkvOmpitDvlyay5JU6xKDVNdNqnUFSqrrJUZd1HRlSK1WS24PR6vVQq1W7veylDlbSs6bmq5M6fV6yRx+cRwHvV4f6jKCjjJnS6l5U9OVqejoaEmMe7nHuQK9tFIOKHO2lJo3NV0ZMxqNktgg4+PjQ1oDS5Q5W0rMm5qujKnVaphMppBtlBzHwWQyISwsLCTLDwXKnC0l5k1NV+ZiYmKg0+mYb5Qcx0Gn03l+L/Z6QpmzpbS8qekqgMFgQGxsLLONkuM4NGvWzPPbBNcjypwtJeVNlwErhCAIsFgsKCsrC+ptq92f/gaDIeRjbaFGmbOllLyp6SpMZWUlioqKIAiCqBsmx3Ge8a3r7fD2WihztuSeNzVdBeJ5Hmaz2ae7Bl+N+5QZo9Go2JPxA0WZsyXnvKnpKpjdbofVavXcNtqXl9p9WKXVaqHX6xV/TqhYKHO25Jg3Nd3rAM/zKC8vh81mg8Ph8NyZ9t/cj2s0Gmi1WsTFxdFelp8oc7bklDc13euQ0+mEw+GAy+XybIQqlQoajQbh4eGhLk+RKHO2pJw3NV1CCGGIztMlhBCGqOkSQghD1HQJIYQharqEEMIQNV1CCGGImi4hhDBETZcQQhiipksIIQxR0yWEEIao6RJCCEPUdAkhhCFquoQQwhA1XUIIYYiaLiGEMERNlxBCGKKmSwghDFHTJYQQhqjpEkIIQ5K4A56U72ekRJQ3e5Q5W1LOOyRNV0537lQCyps9ypwtOeXN9MaUcrxHvZxR3uxR5mzJMW8mTZfneZjNZthsNp9CuRKO46DVamE0GmmvoBGUN3uUOVtyzjvoTbeyshJFRUUQBEGUcNw4jgPHcTCZTIiJiRFtvnJHebNHmbMl97yD1nQFQYDFYkFZWZmowfwbx3HQ6XQwGAyNjuFcLyhv9ihztpSSd1CariAIKC4uRkVFRVDDceM4DrGxsYiPj78uN0rKmz3KnC0l5R2U83QtFguzcIC6F6SiogIWi4XJ8qSG8maPMmdLSXmL3nQrKyuDvvvfGEEQUFZWhsrKSqbLDTXKmz3KnC2l5S1q0+V53jPAHQqCIKCoqAg8z4dk+axR3uxR5mwpMW9Rm67ZbA5ZOG6CIMBsNoe0BlYob/Yoc7aUmLdoTddut4t2zlwgBEGAzWaD3W4PaR3BRnmzR5mzpdS8RWu6Vqs15OG4CYIAq9Ua6jKCivJmjzJnS6l5i9J0eZ73XIYnFTabTbHjXpQ3e5Q5W0rOW5SmW15e7vdzBw8ejAMHDohRRgOB1CVlvqxXMPP9N6XmDVDmrCk5b1Eujjh79mzA4x2CIGDFihXYsmUL7HY7brvtNsyZMwdt27b1e57R0dFISEgIqC4p8ifvU6dO4dVXX0VhYSEuXryII0eONJjmq6++wqpVq2CxWHDDDTdg0aJF6Ny5s9fLUGregH+Zf/XVV1i5ciWsVisiIiJw9913Y9asWWjatClqamqwaNEi7N+/HxUVFWjZsiWmTJmC9PR0n5ah1Mz97Snnz5/HkiVL8PPPPyMiIgJDhgzB1KlT601z7tw5DB06FPfddx+WLFni0/zFyFuUPV2HwxHwPHbs2IHNmzdj7dq12LNnD+644w7Mnj075HVJkT/rpVar0bdvXyxYsKDRv+/duxfLli3DwoULsX//fqxduxYmkynodcmFP+vWqVMnrFu3Dvv27cNXX30FnuexYsUKAHWHzwaDAXl5edi3bx8mTZqE//u//0NxcXHQ65IDf9bL6XTi2WefRWpqKr777jt888036N+/f4PpFi9ejKSkJGZ1/VvATdfpdAY02N23b1/s27cPxcXF6NSpE1q2bImwsDAMGDAAZ86cCa
g2QRDgdDoDmkco/fbbb/jmm2/q5etr3u5827Rpg6FDh17xyGHlypUYN24c7rjjDqhUKjRv3hzNmzf3qV65511RUYGNGzeipqam3uP+Zm4wGBAXF+d5PCwsDH/++SeAuj2m5557DvHx8VCpVOjZsyfi4+NRWFjoU81yzlwQBHzwwQcoKyur97i/eW/evBk33XQTRo4ciejoaERGRqJdu3b1pv3qq6+g1WrRrVs3v2sONO+Am67D4RDl2uR+/frh/Pnz+OOPP+B0OrF161b06NEjoHlyHCfrPYENGzbg/vvvR3JyMnbs2AFBEETL+59qa2tx7NgxlJWV4YEHHkCfPn2wePFin7OTe94FBQV4/PHHYTKZ8NZbb3mabyCZHzx4EN27d0e3bt3wzTffYMSIEY1OZ7Vace7cOdxyyy0+zV/OmVdVVeGJJ55Ay5YtMWPGDE/z9TfvX3/9FUajEePGjUN6ejqefvppnDx50vP3S5cuITc3F9OnT/e7ZjHyDviHI10uV6CzAADceOONSElJwcCBAxEWFgaDwYA1a9YENE+n04lNmzbhwoULotTI2o8//ghBEHDs2DEMGjQIcXFx+Oijj3DTTTeJupzS0lLwPI+vv/4a7733HtRqNSZPnozVq1dj8uTJXs+ntrZW1nmfPHkSERER+PvvvzF58mRMmzYNM2fOxKRJk/yeZ0pKCvbt24eSkhJ8+umnMBqNDaZxOp2YOXMmBg0a5PN4YXV1tWwzr6mpgUqlgt1uR05ODpYtW4aBAwfi3Xff9Wt+JSUlOHDgAJYvX4677roL69evx+TJk/H5558jPDwcb7zxBoYMGQKDwRBQ3YH2vICbrljn0a1atQpHjx7F119/Db1ej23btmH06NH47LPPEBUV5fd8q6qqcPHiRVFqZO2fn6iCIKCqqgrV1dWiLycyMhIA8Pjjj+PGG28EADz55JM+N113jXLN+/Lly57t2eVywel04uLFi6Js482bN0ePHj2QlZWFjz/+2PO4y+XC7NmzER4e7vd3GHLN/J+H6S6XCxzHoby83O+8IyMj0alTJ8+XkU899RRWr16Ns2fPQhAE7N+/H5988knAdQe6PQTcdMU61P3tt9+QmZnp+RQaPHgwXn75ZZw9exYdO3b0a57h4eF47LHHEBsbK0qNrM2bNw/79++HyWTCyy+/jGHDhsFms/n8Zcu1xMbGonnz5vVeS39eV7VaLeu8d+7ciU8//RQxMTGYNWsWJk2ahCZNmqCiokKU+dfW1uL8+fOe/wuCgHnz5qG0tBQrV67064aJkZGRss3cbrdj2bJliIqKwogRIzB//nwYjUa/805MTMThw4cb/duBAwdgNptx3333eZbtcrnwyCOP1PsQ9EagPS/gpqtSiXNRW1JSEnbu3InMzEzodDp88cUX4HkeLVu2lER9oTB06FDcfvvtGDp0qGc9/F0fQRBQU1Pj2buorq4Gx3GIiIgAUPch98EHH6BHjx5Qq9V4//33kZGR4fNy5Jx3SkoKcnNzMXz4cDRp0sTzuL/rtG3bNnTu3BktWrSA2WzG8uXL632Bs3DhQvz+++94++23odFo/K5brplHRUXhjTfewODBg+sNu/i7PgMGDPCcLZKamooNGzagWbNmSEhIQKtWrdCvXz/PtGvXroXZbMbcuXN9Xk6geQfcdDUajSiHX6NGjUJpaSmGDRuGqqoqtGrVCjk5OQHdNkMQhIA25lC78847ceedd9Z7zN+8zWYzMjMzPf/v0qULjEYjduzYAQAYO3YsLl68iIEDByIiIgJ9+/bFs88+69My5J63Xq/H2LFjGzzub+Znz57FsmXLYLPZoNVqkZ6ejueffx5A3evxySefICIiAr169fI8Z968eRgwYIDXy5Bz5hzH4bnnnmvwuL95t2nTBtnZ2Vi4cCHKysrQvn17rFixAuHh4QgPD683TBkdHY2IiAjodDqfliFG3qJcHFFYWCjaF2piUqlU6NChQ6jLEB3lzR5lzp
aS8xbluESqn7RSrStQUl0vqdYlBqmum1TrCpRU10uMukRpulqtVnL3bXLfUlmJKG/2KHO2lJy3KE33n1fdSIlU6wqUVNdLqnWJQarrJtW6AiXV9RKjLlGarlqtltwnrlarhVod8PeEkkR5s0eZs6XkvEU710Sv10vmcIDjOOj1+lCXEVSUN3uUOVtKzVu0phsdHS2JcRj3uEt0dHRI6wg2yps9ypwtpeYt6lnVRqNREgHFx8eHtAZWKG/2KHO2lJi3qE1XrVbDZDKFLCSO42AymRAWFhaS5bNGebNHmbOlxLxFv34wJiYGOp2OeUgcx0Gn0wV0BZscUd7sUeZsKS3voFy0bTAYEBsbyywkjuPQrFmzgH+yTa4ob/Yoc7aUlLcolwE3RhAEWCwWlJWVBfU2yu5PI4PBEPKxn1CivNmjzNlSSt5Ba7pulZWVKCoqgiAIogbFcZxnvOV6O9y6GsqbPcqcLbnnHfSmC9TdhM9sNsNms4kSkvsUDqPRqNiTwwNBebNHmbMl57yZNF03u90Oq9UKm80GwLdfYHfv5mu1Wuj1esWfoygGyps9ypwtOebNtOm68TyP8vJy2Gw2OBwOCILQ6NiJ+3GNRgOtVou4uDj61PcD5c0eZc6WnPIOSdP9N6fTCYfDAZfL5QlFpVJBo9H4dQsTcnWUN3uUOVtSzlsSTZcQQq4X8ry5EiGEyBQ1XUIIYYiaLiGEMERNlxBCGKKmSwghDFHTJYQQhqjpEkIIQ9R0CSGEIWq6hBDCEDVdQghhiJouIYQwRE2XEEIYoqZLCCEMUdMlhBCGqOkSQghD1HQJIYQharqEEMIQNV1CCGFIEnfAk/L9jJSI8maPMmdLynmHpOnK6c6dSkB5s0eZsyWnvJnemFKO96iXM8qbPcqcLTnmzaTp8jwPs9kMm83mUyhXwnEctFotjEYj7RU0gvJmjzJnS855B73pVlZWoqioCIIgiBKOG8dx4DgOJpMJMTExos1X7ihv9ihztuSed9CariAIsFgsKCsrEzWYf+M4DjqdDgaDodExnOsF5c0eZc6WUvIOStMVBAHFxcWoqKgIajhuHMchNjYW8fHx1+VGSXmzR5mzpaS8g3KersViYRYOUPeCVFRUwGKxMNpkoEQAAAqKSURBVFme1FDe7FHmbCkpb9GbbmVlZdB3/xsjCALKyspQWVnJdLmhRnmzR5mzpbS8RW26PM97BrhDQRAEFBUVgef5kCyfNcqbPcqcLSXmLWrTNZvNIQvHTRAEmM3mkNbACuXNHmXOlhLzFq3p2u120c6ZC4QgCLDZbLDb7SGtI9gob/Yoc7aUmrdoTddqtYY8HDdBEGC1WkNdRlBR3uxR5mwpNW9Rmi7P857L8KTCZrMpdtyL8maPMmdLyXmL0nTLy8vFmI3opFpXoKS6XlKtSwxSXTep1hUoqa6XGHWJ0nSDMe6yfft2DBo0CN26dcODDz6IXbt2+fR89ziMEgWat9PpxNSpU9G3b18kJyfjwIEDDaYpLCzEyJEjkZqaip49e2L9+vVXnaeS8wYCy/zMmTN49NFHkZaWhrS0NIwePRpnzpzx/D0vLw9DhgxBt27dkJmZiby8PK/mq+TMxegpVVVVWLRoEdLT09G9e3eMHDmywTROpxODBg1Cnz59rjk/sfIW5ZcdHA6HGLPxKCkpwaxZs7B8+XLcfffd2L17N6ZNm4bt27fjhhtuCFldUiHGenXq1AnDhw/HtGnTGvytvLwc48ePx/Tp03H//ffD6XSipKSESV1SFci63XjjjcjJyYHRaITL5cJHH32E6dOnY9OmTQDq3syLFy9GYmIizp8/j7Fjx8JgMKBfv35BrUvKxFivBQsWoLa2Flu2bEFsbCxOnDjRYJq8vDzExcXh8uXLzOoKeE/X6XT69YmUnJyMP//80/P/OXPmYPny5QDqmm5MTAzS09PBcRwyMjIQFRWF8+fP+7QMQRDgdDp9rk0qDhw4gPfee6/eOJK3eV8t3/DwcI
wYMQIpKSlQqRpuAuvWrUNaWhoGDBiAiIgINGnSBAkJCddcptzz/uuvv7Bs2TJcvHix3uPeZH61vGNiYjyXkwqCAJVKVW9bHjVqFDp06AC1Wo02bdrgnnvuwaFDh7yqWc6Zu1wuvPrqq/j999/rPS7GNn727Fl8//33mD9/PnQ6HcLCwtCxY8d6zy8qKsK2bdswevRor2sWI++Am67D4RD92uSOHTuiTZs2+O6771BbW4tdu3YhPDwciYmJPs2H4zhZ7wl8/vnneOaZZ9CyZUu8++674Hk+KHn/26+//orY2FgMHz4cPXv2xMSJE3HhwoVrPk/ueR8+fBhZWVkwmUx48cUXPc1XrMzT0tLQpUsXZGdnX/GNLggCDh48iLZt23o1Tzln7nA4MH36dHTo0AFPPPGEp/mKkffRo0fRokUL5ObmIj09HUOGDMHXX39db5rs7GxMmTIFGo3G6/mKkXfAwwsulyvQWTQQFhaGQYMGYcaMGaipqUF4eDhee+01n39kuKamBmvWrGnwSSoX+/fvR21tLSwWC8aOHYtJkybho48+wi233BLU5ZaUlOD48eNYvXo1br31VuTk5CArKwvvv//+VZ/H87ys8/7zzz+hUqlw+fJlZGdnY8mSJRgzZgyys7NFmf/evXtht9uxdetWGI3GRqdZuXIlXC4XBg8e7NU8HQ6HbDN3Op2eJvbhhx9i48aN6Nq1K7Zv3x7wvEtKSnD69Gncd999+Pbbb3H48GFMmDABt9xyCxISErBr1y7U1taiT58+jX6ncTWB9ryAm24wzqPbt28fcnJykJeXh/bt26OwsBCTJk3CqlWrcNttt3k9H5VKBaPRiLCwMNFrZOHkyZOef6tUKuj1euh0uqAvNzIyEr1790ZSUhIAYPz48UhPT4fNZoNWq73i8ziOk3XePM97hlvCwsIQFRWFW2+9VdRtPDo6Go888ggyMjKwZcuWet9RfPDBB/j888+xdu1aREREeDU/OW/jNTU1niEXtVoNlUqFdu3aiZJ3ZGQk1Go1nn32WajVanTt2hWpqanYu3cvDAYDcnJysHLlSr/mHWh9ATddfw8DoqKiUFVV5fl/aWkpmjdvDgD47bff0LlzZ88YTFJSEpKTk7F//36fmq5arUZmZiZiY2P9qjHUrFYrduzYgdTUVLzyyivIyMhARUUFiouLr/ncq+V7LYmJifVeV29f47CwMFnnvXPnTrz55puIj4/HkiVL8NhjjyEsLAwVFRXXfK4vebtcLjgcDvz111+epvvZZ5/hnXfewdq1a2EwGLyuOSIiQraZ2+12TJ8+HU2bNsWMGTMwZcoUaLVar/IGrp55Y0OR7u34zz//hNls9pzN4HQ6cenSJfTq1QsbNmxAfHz8VZcb6NBHwGO6jX0R44127drhyy+/RG1tLfbs2YOff/7Z87eOHTvi4MGDnm8bjx8/joMHD/o8phtIfVLw1FNP4ccff0RBQQEyMjIAeL8+V8sXqNvLqK6uBlC30VVXV3s+wQcPHoxdu3bhxIkTcDqdePPNN5GSknLVvVw3OefdvXt3bNmyBefOncPw4cM9e4/erNPV8t67dy+OHz+O2tpaXLp0Ca+88gpiYmI8X05u27YN//3vf/H222+jZcuWPtct18yjoqLw2WefwWw2Y+7cuZ7tS4xtvHPnzmjRogXWrFkDnudx6NAh/PTTT0hLS0Pbtm3x9ddfIz8/H/n5+ViwYAFuuOEG5Ofne/WBF2jeAf+IudPpxMmTJ33e5T527BjmzJmDCxcuoHfv3qitrYXJZMLkyZMB1B1qrV+/HqWlpYiLi8Njjz3W6Hl2V8NxHBITE0N+y2UxeZv3tfLt27dvgx/x2L59u+dTfuPGjVi9ejWqqqqQkpKCuXPnXnODVGLegHeZXy3vHTt24I033kBJSQk0Gg2SkpIwZcoUtGvXDgCQmZmJkpKSerkNGDAA8+bNu2ZtSsxcrG389OnTmD9/Pk6dOoUWLVpg8uTJjZ6Pe+
DAAcycOdOrawHEyFuUO0cUFhYG5Qu1QKlUKnTo0CHUZYiO8maPMmdLyXmLclziyykXLEm1rkBJdb2kWpcYpLpuUq0rUFJdLzHqEqXparVayd23yX1LZSWivNmjzNlSct6iNN24uDgxZiM6qdYVKKmul1TrEoNU102qdQVKquslRl2iNF21Wi25T1ytVgu1WpSflpAcyps9ypwtJect2rkmer1eMocDHMdBr9eHuoygorzZo8zZUmreojXd6OhoSYzDuMddfL1kWG4ob/Yoc7aUmreoZ1UbjUZJBHStK0qUgvJmjzJnS4l5i9p01Wo1TCZTyELiOA4mk0mW16H7g/JmjzJnS4l5i379YExMDHQ6HfOQOI6DTqdDTEwM0+WGGuXNHmXOltLyDspF2waDAbGxscxC4jgOzZo18+mHQpSE8maPMmdLSXmLchlwYwRBgMViQVlZWVBvo+z+NDIYDCEf+wklyps9ypwtpeQdtKbrVllZiaKiIgiCIGpQHMd5xluut8Otq6G82aPM2ZJ73kFvukDdj0ObzWbR7hrsPoXDaDQq9uTwQFDe7FHmbMk5byZN181ut8NqtXpuY+zLot27+VqtFnq9XvHnKIqB8maPMmdLjnkzbbpuPM+jvLwcNpsNDocDgiA0Onbiflyj0UCr1SIuLo4+9f1AebNHmbMlp7xD0nQJIeR6Jc/7fBBCiExR0yWEEIao6RJCCEPUdAkhhCFquoQQwtD/A04qa8YM/BNxAAAAAElFTkSuQmCC\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],\n",
" 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],\n",
" 'f16': ['f32'], 'f32': ['f64'],\n",
" 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'u8': [0, 0], 'u16': [1, 0], 'u32': [2, 0], 'u64': [3, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [1, 2], 'f32': [2, 2], 'f64': [3, 2],\n",
" 'c64': [2, 3], 'c128': [3, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 4))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3m8_BFDqdxvQ"
},
"source": [
"In terms of promotion of values to 64-bit that JAX seeks to avoid, these same-kind promotion semantics within each type category are unproblematic: the only way to produce a 64-bit output is to have a 64-bit input."
]
},
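{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough check of this claim, the same-kind lattices can be written as a plain dictionary, with the promotion of two dtypes computed as their least upper bound. The following is a minimal sketch rather than JAX's implementation: the names `upper_bounds`, `promote_same_kind`, and `WIDE` are hypothetical, and only dtypes within a single category are handled.\n",
"\n",
"```python\n",
"from itertools import product\n",
"\n",
"# Same-kind lattice from the figure above: each dtype points to its next-wider neighbor.\n",
"LATTICE = {\n",
"    'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],\n",
"    'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],\n",
"    'f16': ['f32'], 'f32': ['f64'],\n",
"    'c64': ['c128'],\n",
"}\n",
"\n",
"# Dtypes built from 64-bit components (hypothetical grouping for this sketch).\n",
"WIDE = {'u64', 'i64', 'f64', 'c128'}\n",
"\n",
"def upper_bounds(dtype):\n",
"    # The dtype itself plus everything reachable from it, narrowest first.\n",
"    chain = [dtype]\n",
"    while chain[-1] in LATTICE:\n",
"        chain.append(LATTICE[chain[-1]][0])\n",
"    return chain\n",
"\n",
"def promote_same_kind(a, b):\n",
"    # Least upper bound: the first dtype above a that is also above b.\n",
"    b_chain = set(upper_bounds(b))\n",
"    for candidate in upper_bounds(a):\n",
"        if candidate in b_chain:\n",
"            return candidate\n",
"    raise TypeError(f'{a} and {b} are not in the same category')\n",
"\n",
"# Check every same-category pair: a wide output requires a wide input.\n",
"dtypes = sorted(set(LATTICE) | {d for targets in LATTICE.values() for d in targets})\n",
"for a, b in product(dtypes, repeat=2):\n",
"    try:\n",
"        out = promote_same_kind(a, b)\n",
"    except TypeError:\n",
"        continue  # cross-category pairs are outside this sketch\n",
"    if out in WIDE:\n",
"        assert a in WIDE or b in WIDE\n",
"```\n",
"\n",
"Under these assumptions the loop finishes without raising, which matches the observation that a 64-bit result only arises from a 64-bit input."
]
},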
{
"cell_type": "markdown",
"metadata": {
"id": "pspgwrv2gNJw"
},
"source": [
"## Enter Python Scalars\n",
"\n",
"Let's now think about where Python scalars fit into the mix.\n",
"\n",
"In NumPy, promotion behavior differs depending on whether the inputs are arrays or scalars. For example, when operating on two scalars, normal promotion rules apply:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "-5boZVhbhG-k"
},
"outputs": [
{
"data": {
"text/plain": [
"dtype('int64')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.int8(0) # int8 scalar\n",
"y = 1 # Python int = int64 scalar\n",
"(x + y).dtype"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9TXL8PYfPptN"
},
"source": [
"Here the Python value `1` is treated as an `int64`, and straightforward within-category rules lead to an `int64` result.\n",
"\n",
"In operations between Python scalars and NumPy arrays, however, scalars defer to the dtype of the array. For example:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "U2P8nbEskxC_"
},
"outputs": [
{
"data": {
"text/plain": [
"dtype('int8')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.zeros(1, dtype='int8') # int8 array\n",
"y = 1 # Python int = int64 scalar\n",
"(x + y).dtype"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sRiTUEOWP_7O"
},
"source": [
"Here the bit width of the `int64` scalar is ignored, deferring to the bit width of the array.\n",
"\n",
"There is another detail here: when NumPy type promotion involves a scalar, the output dtype is value-dependent: if the Python scalar is too large for the given dtype, it is promoted to a compatible type:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "JNpNwSwjihCb"
},
"outputs": [
{
"data": {
"text/plain": [
"dtype('int16')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.zeros(1, dtype='int8') # int8 array\n",
"y = 1000 # int64 scalar\n",
"(x + y).dtype"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LajRXAuvRLbW"
},
"source": [
"For the purposes of JAX, **value-dependent promotion is a non-starter** because of the nature of JIT compilation and other transformations, which act on abstract representations of data without reference to their value."
]
},
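{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the problem concrete, here is a deliberately simplified, pure-Python model of the two behaviors for signed integers. It is a sketch under stated assumptions, not NumPy's or JAX's actual logic: `min_signed_width`, `promote_value_dependent`, and `promote_dtype_only` are hypothetical helpers, and dtypes are represented only by their bit width.\n",
"\n",
"```python\n",
"SIGNED_WIDTHS = [8, 16, 32, 64]\n",
"\n",
"def min_signed_width(value):\n",
"    # Smallest signed integer width whose range contains the value.\n",
"    for bits in SIGNED_WIDTHS:\n",
"        if -(2 ** (bits - 1)) <= value < 2 ** (bits - 1):\n",
"            return bits\n",
"    raise OverflowError(f'{value} does not fit in a 64-bit signed integer')\n",
"\n",
"def promote_value_dependent(array_bits, scalar_value):\n",
"    # Simplified model of the behavior shown above: the result depends on the scalar's value.\n",
"    return max(array_bits, min_signed_width(scalar_value))\n",
"\n",
"def promote_dtype_only(array_bits):\n",
"    # Value-independent alternative: the Python scalar always defers to the array dtype.\n",
"    return array_bits\n",
"\n",
"for value in (1, 1000):\n",
"    print(value,\n",
"          'int' + str(promote_value_dependent(8, value)),\n",
"          'int' + str(promote_dtype_only(8)))\n",
"```\n",
"\n",
"The first function cannot be evaluated without the concrete value of the scalar, which is exactly the information that is unavailable when a transformation works on abstract representations of data. The second depends only on dtypes."
]
},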
{
"cell_type": "markdown",
"metadata": {
"id": "Ep3RJciFk_aX"
},
"source": [
"Ignoring value-dependent effects, the signed integer branch of NumPy's type promotion can be represented in the following lattice, where we'll use `*` to mark scalar dtypes:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"cellView": "form",
"id": "wf0FonWhlWwU",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqsAAADnCAYAAAA5Hh/PAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dfZzM9f7/8efYXbs2s9jjYu0ui4OEykVrSb5kk+tQuIkj61qUnFtKJBG6wFGUklwUUpKoaDk4jnKLjpIucEpJLjay1sWKbXd2Pr8//HaOsS4+M7t23sPjfrt1u7Wf+czMa5+f2bfnzn5mxmFZliUAAADAQMUCPQAAAABwKZRVAAAAGIuyCgAAAGNRVgEAAGAsyioAAACMRVkFAACAsSirAAAAMBZlFQAAAMairAIAAMBYlFUAAAAYi7IKAAAAY1FWAQAAYCzKKgAAAIxFWQUAAICxKKsAAAAwFmUVAAAAxqKsAgAAwFiUVQAAABiLsgoAAABjUVYBAABgLMoqAAAAjEVZBQAAgLEoqwAAADAWZRUAAADGoqwCAADAWJRVAAAAGIuyCgAAAGNRVgEAAGAsyioAAACMRVkFAACAsSirAAAAMBZlFQAAAMairAIAAMBYlFUAAAAYi7IKAAAAY1FWAQAAYCzKKgAAAIxFWQUAAICxKKsAAAAwFmUVAAAAxqKsAgAAwFiUVQAAABiLsgoAAABjUVYBAABgLMoqAAAAjEVZBQAAgLEoqwAAADAWZRUAAADGoqwCAADAWJRVAAAAGIuyCgAAAGNRVgEAAGAsyioAAACMRVkFAACAsSirAAAAMBZlFQAAAMairAIAAMBYlFUAAAAYi7IKAAAAY1FWAQAAYCzKKgAAAIxFWQUAAICxKKsAAAAwFmUVAAAAxqKsAgAAwFiUVQAAABiLsgoAAABjUVYBAABgLMoqAAAAjEVZBQAAgLEoqwAAADAWZRUBMX78eNWtWzfQYwAAAMNRVpHP0aNHNXToUFWpUkXh4eGqUKGCkpOTtW7dukCPdkUOh+Oy/6WkpAR6RADXuO3btyskJERNmzYN9Cj5tGjR4rJrZJUqVQI9IpBPaKAHgHnuu+8+nTlzRvPmzVP16tX1+++/a9OmTTp27FigR/OSnZ2t4sWLe2377bffPP+/atUqDRw40GtbiRIlimw+ANenuXPnaujQoVq4cKF2796tm2666bL75+TkKCwszGvbxda3wvDBBx8oOztbkpSRkaE6depo+fLluv322yVJISEhhX6fQEHxzCq8nDhxQp999pmef/55JScnKyEhQYmJiRo5cqR69Ojh2S87O1tjxoxRQkKCwsPDVa1aNc2cOVOSlJubq/79+6tq1aoqUaKEatSooSlTpsjtdl/yfrdt26a7775bZcuWVVRUlO644w5t2bLFax+Hw6FZs2bp3nvv1Q033KAxY8bku52YmBjPf6VLl/Zsq1Chgu644w698cYbXvvv2bNHDodD27dv99zHK6+8ovbt2ysyMlIJCQlavHix13UOHTqkHj16qEyZMipTpozat2+vPXv2eC4/cOCAOnXqpOjoaEVGRqpWrVp699137cQPIMidPXtWS5Ys0aBBg9S1a1fNmzfP6/J9+/bJ4XDonXfeUcuWLVWiRAm9/vrrSklJUYcOHfTCCy8oPj5e8fHxkqTFixcrMTFRTqdT5cuXV7du3XTo0CFJkmVZql69uqZNm+Z1Hxeua+eLjo72rJHly5f32jZmzBj17dvXa3+3263KlStr+vTpks49MztkyBA98sgjnjXwscce81rfs7OzNWrUKMXHxysyMlKJiYlau3at5/KcnBwNHz5csbGxCg8PV6VKlfTEE0/4GzmuA5RVeClZsqRKliypjz76SFlZWZfcr0+fPlq4cKGmT5+u3bt3a968eZ5y6Ha7FRcXp/fee0+7d+/W5MmT9eyzz2rBggWXvL3MzEz17t1bn3
32mf7zn/+oXr16ateuXb5ncydMmKB27drpu+++07Bhw2x/Xw6HQ/379883w/z581WvXj01aNDAs+3pp5/WPffcox07dmjQoEF64IEH9OWXX0qSzpw5ozvvvFMRERHatGmTtmzZoooVK+quu+7SmTNnJElDhw7VmTNntHHjRu3cuVMvvfSSJxsA17b3339fCQkJuvnmm9W7d28tXLhQOTk5+fYbPXq0hg4dql27dqlz586SpE2bNunbb7/VmjVrtGHDBknnit+ECRP0zTffaNWqVUpPT9f9998vybd1zY6BAwdqzZo1Xn+NWrdunQ4fPqzevXt7tr399ttyu93asmWLXn/9dc2ZM0cvvfSS5/K+fftq06ZNWrJkib7//nv16dNHHTt21DfffCNJmjlzplasWKF3331Xe/bs0dKlS3XjjTf6NCuuMxZwgffff98qU6aMFR4ebjVu3Nh69NFHra1bt3ou//HHHy1JVmpqqu3bHDVqlJWcnOz5+umnn7bq1Klzyf3dbrcVExNjLVq0yLNNkvXQQw/Zvs9ly5ZZ5z/Ef/vtNys0NNTasmWLZVmW5XK5rNjYWOvll1/2uo8BAwZ43U5ycrLVq1cvy7Isa968eVb16tUtt9vtudzlclnR0dHW0qVLLcuyrJtvvtkaP3687TkBXDuaN29uTZ061bKsc+tYQkKCtWzZMs/lv/zyiyXJmjZtmtf1+vTpY5UtW9bKysq67O3v3r3bkmQdOHDAsix769qlHD161JJkbdy40bOtTp061nPPPef5unv37tZ9993n9f3VqFHDaw2cOHGiFRcXZ1mWZf3000+Ww+Gwfv31V6/76tSpk/Xggw9almVZDz/8sNWyZUuv2wAuh2dWkc99992ntLQ0ffzxx2rbtq0+//xzNW7cWM8++6wk6euvv1axYsV05513XvI2Zs+erdtuu03lypVTyZIl9eKLL2r//v2X3P/333/X4MGDVbNmTZUqVUpOp1O///57vuvcdtttfn9fMTEx6tChg+bPny9JWrNmjTIyMtSrVy+v/Zo0aZLv6127dkmSvvrqK/3yyy9yOp2eZ6FLlSql48eP6+eff5YkPfLII5o0aZKaNGmisWPH6quvvvJ7ZgDB46efftLmzZvVs2dPSeee+ezVq1e+UwGki69ldevWVXh4uNe27du3q1OnTkpISJDT6fRcL29ttLuu2TVw4EDPM7UZGRn68MMP1b9/f699GjduLIfD4fm6SZMmOnTokE6dOqXt27fLsizVrl3bs0aWLFlSq1ev9qyRKSkp2rFjh2rWrKlhw4Zp9erVlz1NDOAFVrioiIgItWrVSq1atdK4ceM0YMAAjR8/XiNHjrzidZcuXaoRI0Zo2rRpuv322xUVFaVZs2ZpxYoVl7xOnz59dOTIEb344ouedyFITk72vBAgzw033FCg72vAgAHq2bOnXnrpJc2fP19dunRRmTJlbF/f7XarXr16Fz0HNTo6WpLUv39/tW7dWp988onWr1+v22+/XaNHj9b48eMLNDsAs82dO1e5ubmqXLmyZ5tlWZLOncteqVIlz/aLrWUXbvvjjz/UunVr3XXXXVq0aJHKly+v9PR0NWvWzGttLOi6dr7evXtr1KhR2rx5s77++muVK1dOrVu3tn19t9sth8Ohbdu25XvRWN4LXBs0aKB9+/Zp7dq12rBhg/r06aNbb71V69atU7FiPIeG/CirsKV27dpyuVzKyspSvXr15Ha7tXHjRrVp0ybfvps3b1ZSUpIeeughz7a836gvZfPmzZo5c6bat28vSTpy5IjXeVOFpU2bNoqKitLs2bP18ccf65NPPsm3z9atW9WvXz+vr/NezdugQQO98847Klu27GXPQ42Pj9egQYM0aNAgvfDCC5oxYwZlFbiGuVwuvfXWW3ruuefUoUMHr8t69+6tBQsWaNy4cT7d5n//+1+lp6fr2WefVdWqVSWdezX/heysa3ZFR0fr3nvv1fz58/X111+rT5
8++QrkF198IcuyPM+ubt26VbGxsYqKilL9+vVlWZYOHz582b++OZ1Ode3aVV27dlVKSooaN26sn376STVr1vR7dly7KKvwcuzYMXXr1k39+vXTLbfcIqfTqS+//FJTpkxRcnKyoqKiFBUVpe7du2vAgAGaMWOGGjRooIMHD2rfvn3q3bu3atasqTfffFOpqamqXr263n33XW3atOmyv+nXrFlTixcvVlJSkv744w89/vjjV+VtW0JCQtSvXz+NHj1acXFxSk5OzrfPBx98oMTERLVo0ULvv/++NmzYoC+++EKS1KtXL02bNk2dOnXSM888o8qVK+vAgQP68MMPNWTIENWoUUOPPPKI2rZtq5o1a+rUqVNas2aNateuXejfCwBzrF69Wunp6Ro4cKD+8pe/eF3Wo0cPzZ49W0899ZRPt1m5cmWFh4frlVde0bBhw7R79+6L3oaddc0XAwcOVJs2bZSTk6Ply5fnuzwtLU0jRozQ0KFD9d1332nq1KkaO3aspHNrea9evZSSkqJ//OMfatCggTIyMvTvf/9b1apV07333qvp06erYsWKqlevnsLCwrRkyRJFRUV53gEBuBDPt8NLyZIl1bhxY82YMUPNmzdXnTp1NGbMGPXs2VNLly717Ldw4UL17NlTw4cPV61atZSSkqKTJ09KkgYPHqzu3burZ8+eSkxM1L59+/Too49e9n7nz5+v06dPq2HDhurRo4f69et31d6cul+/fsrOzlbfvn29zrvKM378eC1fvly33HKLXnvtNS1YsECJiYmSpMjISH366aeqVq2aunXrplq1aqlPnz46fvy4p4y73W49/PDDql27tlq1aqUKFSrorbfeuirfCwAzzJs3T3feeWe+oipJ3bp10759+3z+YJVy5crprbfe0sqVK1W7dm1NmDDB8xZSF7rSuuaLFi1aKD4+Xi1atFC1atXyXd6rVy/l5uYqKSlJAwcOVP/+/fX3v//dc/mCBQvUt29fPf7446pVq5Y6dOigTz/9VAkJCZLOPas6depUNWrUSA0aNNCOHTuUmpqqyMjIAs2Na5fDyjuhBrhOfPHFF2ratKn27t3rdW6ZdO4FEcuWLVPXrl0DNB0A+O5y65qvzp49q7i4OL388sv5XqjVokUL1a1bV6+88kqB7gPwBacB4Lrx559/6ujRo3rqqafUpUuXAi/oABBohbmuud1upaena8aMGSpRooS6d+9eiJMC/uM0AFw33nnnHSUkJCg9Pf2Sf0oDgGBSmOva/v37VaFCBS1cuFALFizI92p+IFA4DQAAAADG4plVAAAAGIuyCgAAAGNRVgEAAGAsyioAAACMRVkFAACAsSirAAAAMBZlFQAAAMairAIAAMBY1+3Hrebk5CgrK0tut1uWZcnhcKhYsWKKiIjgUzvOQ072kJM95BScOG72kJM95GQPOf3PdVNWXS6Xjh8/rszMTGVlZXkO/IXytkdERMjpdKpMmTIKDb1uYiInm8jJHnIKThw3e8jJHnKyh5wu7Zr/uNUzZ84oPT1dmZmZks4dZLvyHiROp1Nly5ZVZGTkVZnRBORkDznZQ07BieNmDznZQ072kNOVXbNl1eVyKS0tTZmZmT4d+EtxOBxyOp2KjY29pn6DISd7yMkecgpOHDd7yMkecrKHnOy7JsvqqVOndPDgQVmWVSgPgDwOh0MOh0Px8fGKiooqtNsNFHKyh5zsIafgxHGzh5zsISd7yMk311RZtSxLhw8fVkZGRqEe/As5HA5FR0crJibmoueTmI6c7CEne8gpOHHc7CEne8jJHnLyzzVTVi3L0qFDh3Ty5Mmr+gDI43A4VKpUKcXFxQXVA4Gc7CEne8gpOHHc7CEne8jJHnLy3zXzPquHDx8usgeAdO5Bd/LkSR0+fLhI7q+wkJM95GQPOQUnjps95GQPOdlDTv67JsrqqVOnrvpT6hdjWZYyMjJ06tSpIr1ff5GTPeRkDzkFJ46bPeRkDznZQ04FE/Rl1e
VyeU5SDgTLsnTw4EG5XK6A3L9d5GQPOdlDTsGJ42YPOdlDTvaQU8EFfVlNS0sL2AMgj2VZSktLC+gMV0JO9pCTPeQUnDhu9pCTPeRkDzkVXFCX1TNnzhTa+5MVhGVZyszM1JkzZwI6x6WQkz3kZA85BSeOmz3kZA852UNOhSOoy2p6enrAHwB5LMtSenp6oMe4KHKyh5zsIafgxHGzh5zsISd7yKlwBG1Zdblcno8mM0VmZqZx54SQkz3kZA85BSeOmz3kZA852UNOhSdoy+rx48d92r9z587atm3bVZrmf3yd62q70jxFlcuFyMkecrLHtJxMU5B8ruYxNe24+TJPUT7WyckecrLHtJzsCNqy6us5ICtXrlRiYqIsy9LMmTOVnJysJk2aqG/fvvrpp5+89t22bZtWrlzp80x554SY5Eo55eWyZ88eDR48WM2aNdPNN9980X1TU1N1zz33qFGjRmrbtq2++uorr8uffPJJWzMFc06pqanq2LGjmjRpoubNm+vJJ5/U6dOnJUnZ2dkaN26c7r77biUlJalr16767LPP8t3W9ZCTJB04cEDDhg1TUlKSmjVrpunTp+fb/9dff1XDhg31xBNP5LssmHMyTUHOmfNl7fSFicfNl5wKc+28nGslp8JYOy/nWslJKvjaeTkm5mRH0JbVrKwsv663du1arVy5Um+++aY2b96sW2+9VWPGjJEkbdiwQcuWLfPsu379er333ntFMtfVYnee0NBQtW7dWhMmTLjo5Z9//rlefPFFTZw4UVu3btWbb76p+Ph4WZalZ555xvMqw+PHj2v8+PE6e/ZsocxVVOzOU79+fS1cuFBbtmxRamqqXC6XXn75ZUnn/uQTExOjBQsWaMuWLXr44Yc1cuRIHTp06LrLKScnR4MGDVKjRo20ceNGrV+/Xu3bt8+33+TJk1W3bl3P12fPntWECRN04sQJSedeRTthwoQrLvim5WSawsjncmtnIOcqTP7M4+/aebXnupr8mcfftfNqz3U1+TOPv2vn1Z4r0IKyrObk5Pj8LEHr1q21ZcsWHTp0SPXr11elSpUUEhKiDh066Oeff5YktWzZUqGhoZoxY4YWL16sAwcOqEuXLj7dj2VZysnJ8ek6hWHDhg06evSo1zY7OeXlUrVqVd17772qXr36Rfd79dVXNWTIEN16660qVqyYKlSooAoVKsjhcGjAgAGaNWuWtm/frkmTJun+++9XiRIlLnu/gcjJsix9+OGH+X5QfckpJiZGZcqU8WwPCQnR/v37JUmRkZEaOnSo4uLiVKxYMTVv3lxxcXHatWtXUOV04sQJrV27Vm6322u7LzmtXLlS5cuXV58+fRQZGanw8HDdeOONXvumpqbK6XQqKSnJs61EiRLq2bOnJk2apO3bt2vWrFnq37//FT8qMFA/d6b59ttvtWvXLq9t/qyX57OzdvorUMdt48aN+T7Vx9ecCrp2+iJQ6+VHH32U79Xj/ubk79rp68xFnVNmZqZSU1P9Wi/PV9C10xfBuF4GZVnNysry+3Nu27ZtqwMHDmjfvn3KycnRRx99pKZNm15032LFfI/H4XAE5LeW7t27q1KlShoxYoSntBYkp/Pl5uZq586dysjIULt27ZScnKzJkyd7fZ8Oh0OWZcnhcNjKLRA5paWlqUuXLqpYsaJeeuklz/37mtP27dvVpEkTJSUlaf369erdu/dF90tPT9evv/6qv/71r55twZDTv/71L7Vt21Y1atTQ8uXLPYuwLzl9++23io2N1ZAhQ9SsWTP17dtXP/74o+fy06dPa9asWXrssccueRum52Sip556SnXr1lXHjh21c+dOSYW3DviydtoVqOPWq1cvJSQk6KGHHvKU1sLK6Xx21k47ApHT8ePH1alTJ8XGxmrq1Kme0lqQnAqydtoRiJw2b96s9u3bq2rVqlq6dKlf6+X5CmPtvJJgXC9DAz2APy78DcYX5cqVU4MGDdSxY0eFhIQoJiZGc+fOlX
TuH+mcnBwNHz5caWlpKlmypFasWKHu3bvbvn2Xy6VPPvkk37OcV9vZs2f1559/atasWXr11VfVsmVLLVmypFBu+9ixY3K5XFq3bp3eeusthYaGavjw4ZozZ44efvhhzZ07V0OHDtWsWbM0cuRIzZgxQ6NGjbrss4Y5OTlFntOJEycUGhqqEydOaNSoURozZowGDRp0yT/fXUqDBg20ZcsWHTlyRMuXL1dsbGy+fXJycvTEE0/onnvuUbVq1WRZll85ZWdnF3lOO3bsUHh4uPbu3auePXvK6XTqueee8+nn4MiRI9q2bZtmzpypxo0ba/HixRo+fLg+/vhjhYWF6ZVXXlGXLl0UExPjdb2zZ8/q7bff1tixYzV16lQNHTpUc+fO1VNPPXXZhT8QOZlo7969sixLq1ev1tq1a3XjjTd6ndpUEJdbO/0VqPXy9OnTys7O1uzZszVnzhw1a9bM51O+7Ljc2jl8+HDbtxOI9fKPP/5QaGioTp48qSeffFLjxo1TSkqKnn/+eb9v05+10xeBWAd27typ8PBw7d+/Xw888IAefPBBPfPMM5cs4lfi79rpq4L0qEAIyrJakD9pvfbaa/r++++1bt06lS1bVqtWrdKAAQO0YsUKJScnS5LnFXl33XWXz7fvdrt1+PBh7dmzx+8Z/ZH3wHO5XAoJCdEPP/xQaG9PER4eLknq2bOnypUrJ0l64IEHPAvuuHHjPPuWKVNG48ePtzVvUed0+vRpz2PH5XLJ5XLp+++/9/vxVKFCBTVt2lSPP/641z90brdbY8aMUVhYmOecPofDETQ5HTp0yPN4ys3N1cmTJ/XDDz/4lFN4eLjq16+vZs2aSZJSUlI0Z84cT5naunXrRUtUiRIlvHKJjY31yu1SAvVzZ5q8F6xYliXLsnTgwAGdPHlSTqezwLd9ubXzSqezXIplWQE5bnlrY25uriRpz549ys7OLvT7udLaaVcgHt9ZWVle62Vubm6B1svz+bJ2+iIQOR08eNBrvTx16pR2797td07+rp2+MuW9X+0KyrJakD/V/PDDD2rTpo3nt5LOnTtrypQp2rt3r+rUqSNJSkxM9Lwqz1fFixdXSkqKSpUq5feM/liyZIksy9Idd9yhKVOmqGHDhjp58mSh3HapUqU856fmudgxmDx5su3bDA8PL/KcDh06pEWLFikiIkL333+/JkyYoEqVKhUop9zcXB04cMDztWVZGjdunI4dO6ZXX31VYWFh+a7jS04RERFFntMHH3yg9evXq2TJknrsscc0YsQIRUVF+ZRTzZo1tWPHjotetm3bNqWlpalVq1aSzn3Ci9vtVvfu3b3+4TI9JxPt379fBw8eVEJCgqZMmaLOnTsrMzPT5xeqXIydtdNXYWFhATlu77//vnJzc5WUlKSpU6cqKSlJJ0+eLPS39LG7dl5JINbLjIwMzZ8/XxEREeratasmTpyoKlWqFNq/K/6snVcSiHUgNTVVa9as0Q033KARI0Zo5MiRKl26tN85FcbaaUdhn/JytQVlWfXnXNI8devW1T//+U+1adNG0dHRWr16tVwulypVqmTEfP6aPHmyGjVqpIYNG/o1h2VZys7O9px0/eeff8rhcKh48eKSzv3DtGTJEjVt2lShoaFatGiR/u///q9AMxd1TuXLl9ekSZP0t7/9zet4+zLHqlWr1LBhQ1WsWFFpaWmaOXOm10nuEydO1C+//KI33nhDERERhTJ3UeeUmJioqVOnql+/foqKivJrjg4dOnhe+duoUSO9/fbbKl26tKpVq6bKlSurbdu2nn3ffPNNpaWlaezYsQWaOxA/d6YZMmSIUlJS1LlzZ08ehZXL1Vo7A3HcJk6cqFtuucXrZ9ffOYpq7SzqnEqVKqVnn31WPXr0UJUqVQo8R1GtnUWdU7169fT888+rf//+Kl26dIHnKKq1M9jWy6AsqxEREX4/hd2vXz8dO3ZM3bp109mzZ1W5cmVNnz7d6x/lgrAsq9BKii8efP
DBfNt8ySktLU1t2rTxfH3bbbcpNjZWa9eulSQNHjxYJ06cUMeOHVW8eHG1bt1agwYN8nveQOQUFham0aNH59vuS0579+7Viy++qMzMTDmdTjVr1kwjRoyQdC7DZcuWqXjx4mrRooXnOuPGjVOHDh38mjkQOeW9UO9CvuRUtWpVPffcc5o4caIyMjJ000036eWXX1ZYWJjCwsK8/mwcGRmp4sWLKzo62u+ZA/VzZ5p27drl21aQ9fJ8V2PtDNRxGzhwYL5t/uZUFGtnIHIKCQm56Ht4+ptTUaydgcipYsWKevTRR/Nt9zenolg7g3G9dFjBduLC/7dr1y4jTxAuVqyYateuHegxPMjJHnKyh5yCE8fNHnKyh5zsIafCE1zPA5/H1N8KTJvLtHnymDaXafPkMW0u0+bJY+pcpjA1H9PmMm2ePKbNZdo8eUyby7R58pg61+UEbVl1Op3GnSDscDgK5VW3hYmc7CEne8gpOHHc7CEne8jJHnIqPEFbVs//JAyTmDaXafPkMW0u0+bJY9pcps2Tx9S5TGFqPqbNZdo8eUyby7R58pg2l2nz5DF1rssJ2rIaGhpq3G8HTqdToaFmvWaNnOwhJ3vIKThx3OwhJ3vIyR5yKjxBW1YlqWzZssY8xe5wOFS2bNlAj3FR5GQPOdlDTsGJ42YPOdlDTvaQU+EI6rIaGRlpxDkheeeAREZGBnSOSyEne8jJHnIKThw3e8jJHnKyh5wKR1CXVencxzGa8CCIi4sL6AxXQk72kJM95BScOG72kJM95GQPORVc0JfV0NBQxcfHB+yB4HA4FB8fr5CQkIDcv13kZA852UNOwYnjZg852UNO9pBTwQV9WZWkqKgoRUdHF/kDweFwKDo6utA+/epqIyd7yMkecgpOHDd7yMkecrKHnArmmiirkhQTE6NSpUoV2QPB4XCodOnSiomJKZL7KyzkZA852UNOwYnjZg852UNO9pCT/4L241YvxrIsHT58WBkZGYXyWdiXkvebSkxMTMDPQ/EHOdlDTvaQU3DiuNlDTvaQkz3k5J9rqqzmOXXqlA4ePCjLsgr1weBwODznfgT7U+oSOdlFTvaQU3DiuNlDTvaQkz3k5JtrsqxKksvlUlpamjIzM84RVyQAAANKSURBVAvlgZD3tg+xsbFB+Ya6l0JO9pCTPeQUnDhu9pCTPeRkDznZd82W1TxnzpxRenq6MjMzJcmnB0TeU+dOp1Nly5YN2vcns4Oc7CEne8gpOHHc7CEne8jJHnK6smu+rOZxuVw6fvy4MjMzlZWVJcuyLnoeR972iIgIOZ1OlSlT5pr7DeVyyMkecrKHnIITx80ecrKHnOwhp0u7bsrqhXJycpSVlSW32+058MWKFVNERITCwsICPZ4xyMkecrKHnIITx80ecrKHnOwhp/+5bssqAAAAzHfNvM8qAAAArj2UVQAAABiLsgoAAABjUVYBAABgLMoqAAAAjEVZBQAAgLEoqwAAADAWZRUAAADGoqwCAADAWJRVAAAAGIuyCgAAAGNRVgEAAGAsyioAAACMRVkFAACAsSirAAAAMBZlFQAAAMairAIAAMBYlFUAAAAYi7IKAAAAY1FWAQAAYCzKKgAAAIxFWQUAAICxKKsAAAAwFmUVAAAAxqKsAgAAwFiUVQAAABiLsgoAAABjUVYBAABgLMoqAAAAjEVZBQAAgLEoqwAAADAWZRUAAADGoqwCAADAWJRVAAAAGIuyCgAAAGNRVgEAAGAsyioAAACMRVkFAACAsSirAAAAMBZlFQAAAMairAIAAMBYlFUAAAAYi7IKAAAAY1FWAQAAYCzKKgAAAIxFWQUAAICxKKsAAAAwFmUVAAAAxqKsAgAAwFiUVQAAABiLsgoAAABjUVYBAABgLMoqAAAAjEVZBQAAgLEoqwAAADAWZRUAAADGoqwCAADAWJRVAAAAGIuyCgAAAGNRVgEAAGAsyioAAACMRVkFAACAsSirAAAAMBZlFQAAAMairA
IAAMBYlFUAAAAYi7IKAAAAY1FWAQAAYCzKKgAAAIxFWQUAAICxKKsAAAAwFmUVAAAAxqKsAgAAwFiUVQAAABiLsgoAAABjUVYBAABgLMoqAAAAjEVZBQAAgLEoqwAAADAWZRUAAADGoqwCAADAWJRVAAAAGIuyCgAAAGNRVgEAAGAsyioAAACMRVkFAACAsSirAAAAMBZlFQAAAMairAIAAMBYlFUAAAAYi7IKAAAAY1FWAQAAYCzKKgAAAIxFWQUAAICx/h8OkxgJvFWlFQAAAABJRU5ErkJggg==\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'i8*': ['i16*'], 'i16*': ['i32*'], 'i32*': ['i64*'], 'i64*': ['i8'],\n",
" 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'i8*': [0, 1], 'i16*': [2, 1], 'i32*': [4, 1], 'i64*': [6, 1],\n",
" 'i8': [9, 1], 'i16': [11, 1], 'i32': [13, 1], 'i64': [15, 1],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(12, 4))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)\n",
"ax.text(3, 1.6, \"Scalar Types\", ha='center', fontsize=14)\n",
"ax.text(12, 1.6, \"Array Types\", ha='center', fontsize=14)\n",
"ax.set_ylim(-1, 3);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SyVphPnfSwbt"
},
"source": [
"A similar pattern holds within the `uint`, `float`, and `complex` lattices.\n",
"\n",
"For the sake of simplicity, let's collapse each category of scalar types into a single node, denoted by `u*`, `i*`, `f*`, and `c*` respectively. Our set of in-category lattices can now be represented like this:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"cellView": "form",
"id": "y6eib8KQT1ge",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2de3wTVfr/P5OmbVpMS0uQkLasXCwCrUiBAoUCoiyVm4CCi8qiCMhNQJGCoCALFQEtK2xhF5GCXLwhAqKAgKsvuW4VUKTcUaRN06+90IZXmjaX+f3RV/KjtIUmczK5nOf9F0xnnnk+z5zzycnJzBxBFEURBEEQhCwovJ0AQRAET5DpEgRByAiZLkEQhIyQ6RIEQcgImS5BEISMKL1xUqvVitLSUhiNRpjNZoiiCEEQau3n2K5SqaBWqxEVFQWl0ispE4RfwGvf8ifdgpy3jJlMJhQVFcFoNAKoLkBDcRRQrVZDo9EgPDzcIzkShD/Ca9/yR92ymK7VaoVer4fRaHSpKPUhCALUajV0Op1ffzoThFR47Vv+rNvjplteXo68vDyIosikOA4EQYAgCIiNjUVERASzuAThL/Dat/xdt8dMVxRFGAwGlJSUMC3M7QiCgOjoaGi12jrncAgi0OC1bwWKbo+YriiKyM/PR1lZmUeL40AQBERGRiImJsYnGgdBeApe+1Yg6fbILWMGg0G24gDVF6SsrAwGg0GW8xGEt+C1bwWSbuamW15e7vHhf12IooiSkhKUl5fLel6CkAte+1ag6WZqular1TnB7Q1EUUReXh6sVqtXzk8QnoLXvhWIupmarl6v91pxHIiiCL1e79UcCII1vPatQNTNzHRNJhOze+akIIoijEYjTCaTV/MgCFbw2rcCVTcz0y0qKvJ6cRyIooiioiJvp0EQTOC1bwWqbiama7VanY/h+QpGo5Hmdgm/h9e+Fci6mZhuaWkpizDM8dW8CKKh+Gob9nRegaybien6wrzL7TjmYQjCn+G1bwWybiamazabWYSpxfz58yUd76m8CEIupLbhkpISpKeno0ePHkhJScGcOXNq7VNWVobevXvj73//u2x5sYz/559/4qWXXkK/fv2QmJiI/Pz8Gn9/5513MGjQIHTr1g1DhgzB7t27a/z9xIkTGDVqFLp37460tDR89tlnTPKqD8mma7FYmH4iVVRUYNGiRbhx4waA6ltGFi1a5NY5RFGExWJhlhtBeIqSkpJaTz+x6Fsvv/wyNBoNvvnmG3z//fd47rnnau2zcuVKtGzZ0qW4rPrW+fPna2l0VbcgCOjZsycyMzPr/HtYWBhWr16NY8eOISMjA2+//TZOnz7tPNfMmTPx5JNP4tixY3jnnXewYsUKXLhwoc5YLHRLNl2z2Szp2WSDwYCZM2eid+/e6NWrFzIzM/H0009jyZIlOHnyJLKysvDCCy+4dQ5BEGi0S/gFK1asQGxsLF588UXnSM2VvnV7P8rIyMDRo0dhMBgwa9YsqNVqBAcHo127djWOO336NC5duoRhw4a5lC+LvmU2m9GuXTvEx8fjiy++cBrtnXTXpVOj0eBvf/sbEhIS6jxm6tSpaNWqFRQKBR588EF07twZP//8M4DqUf7NmzcxZMgQCIKAhIQEtGrVCleuXPGYbskvjrTb7W4fa7PZMHXqVCQnJ2Pfvn0ICgrC2bNnnX93vOVdoXDvs8Fut+P06dNkvITPc/nyZdhsNmzYsAEbN27Eo48+ivfee69Bx9bXj06cOIH77rsP8+fPx+HDhxEbG4tZs2aha9euzuPeeustLFy4EJcuXXIpXxZ9q7KyEgqFApcvX8bTTz+N6OhozJ07t95pjrv5RUMwm8349ddf8dRTTwEANBoNHnvsMezcuROjRo3CmTNnUFBQgKSkpHpjSPE8gIHpSvn6c+bMGfz555+YNWuW88XB7dq1w7Jly/D6669jxYoVmDJlCtavX4
833njD5dFuVVUV9uzZg5MnT7qdI0HIgcP0rFYrBEHAvn378OOPP+LBBx+867F19aOkpCTs3r0bR48exaJFi7B48WIcPHgQM2bMwFdffYWoqChs3boViYmJ6NChg8uma7VaJfctm81WY3RbUFCAZcuWYcyYMS7pdIXFixejbdu26Nmzp3PbwIEDsXDhQixbtgwA8Prrr0Or1dYbQ+qUj2TTlTq10Lx58xpvag8LC8Obb77p/L9Op8OCBQvciq9SqTBv3jxERka6nSNByMHs2bORmZkJlUqFyZMnY+7cuQgODq71o1Bd1NWPgOr2HxMTgxEjRgAAHnvsMaxbtw6nTp1CQkICtm3bhk8++cStfENCQiT3rYqKCjRq1AiNGjVCq1at8O677+LRRx+t9wUz9elsKO+++y4uXbqEDRs2OH3r6tWrSE9Px8qVK9GjRw9cu3YN06ZNw7333ovevXvXGUfqqx4lm667X/0BQKvVwmAwwGq11lnIjIwMKakBkJYfQchFamoqBEFAeno6NBoNADT49qT6+lF8fDy+++67Gvs6DMMxanz88ccBVH/VN5vN6Nu3Lw4dOoSgoKC7nldq3woJCcGYMWPw7LPP4tFHH3XmVl/cu/nFncjKysLhw4eRnZ2Ne+65x7n98uXL+Mtf/uIc+bZs2RK9e/fGDz/8UK/pStUt2ZFUKpXbw+3ExERoNBr885//hMlkQmVlJU6dOiU1JSeiKEKlUjGLRxCeYujQoVi+fLnTcIGG9636+tEjjzyC8vJy7Nq1CzabDd988w0KCwvRqVMnpKamYv/+/di+fTu2b9+OqVOnol27dti+fXuDDJdF3woKCsKmTZvQv3//GqPH+nTfyS8qKytRVVUFoHpasbKy0nnc+vXr8fXXX+P9999H48aNa8Rs164drl27hhMnTkAURVy/fh3ff/894uPjPaabycoRubm5bk8uFxQUYOnSpTh58iQEQcDAgQPx2muvSU0JQPUnUvv27ZnEIghv0NC+VV8/+umnn7BkyRLk5+ejZcuWSE9PR+fOnWsdv3PnTuzYsQMffvhhg/LydN+qT3d9OhMTE2vte+bMGQDVZh0cHFxjdDxhwgRMmDABALBv3z785z//gV6vxz333INBgwZh5syZdY5oWehmYrpXr171ybd6hYeHo1WrVt5OgyDchte+Fci6mUx4qtVqn1ubzLGkMkH4M7z2rUDWzcR0o6KiWIRhjq/mRRANxVfbsKfzCmTdTExXqVT63KhSrVa7fWsJQfgKvPatQNbN7H4qjUbjM18HBEGo8SswQfgzvPatQNXNzHTDw8N9Yh7GMe8SHh7u1TwIghW89q1A1c30yQGdTucTBYqJifFqDgTBGl77ViDqZmq6SqUSsbGxXiuSIAiIjY1t0M3dBOFP8Nq3AlE382dkIyIiEB0dLXuRBEFAdHQ0IiIiZD0vQcgFr30r0HR75MUEWq0WkZGRshVJEAQ0btz4jm8GIohAgNe+FUi6mTyRVheiKMJgMKCkpMSjax05Po20Wq3X534IQg547VuBottjpuugvLwceXl5EEWRaaEEQXDOt9CUAsEjvPYtf9ftcdMFql94rNfrma3w6biFQ6fT0QMQBNfw2rf8WbcspuvAZDKhqKjI+Z5QVxefA6qfCtFoNHQfLkHcAq99yx91y2q6DqxWK0pLS2E0GmE2m51rod2OY7tKpYJarUZUVJRPf/oShLfhtW/5k26vmO7tWCwWmM1m2O32GotRqlQqBAcHezs9gvBbeO1bvqzbJ0yXIAiCF2gBMYIgCBkh0yUIgpARMl2CIAgZIdMlCIKQETJdgiAIGSHTJQiCkBEyXYIgCBkh0yUIgpARMl2CIAgZIdMlCIKQETJdgiAIGSHTJQiCkBEyXYIgCBkh0yUIgpARMl2CIAgZIdMlCIKQETJdgiAIGSHTJQiCkBGfWInOl9czIgIHHtsZj5oB39btFdP1p5U7Cf+Fx3bGo2bAv3TLujClP65RT/
gfPLYzHjUD/qlbFtO1Wq3Q6/UwGo0uFaU+BEGAWq2GTqfz609ngi08tjMeNQP+rdvjplteXo68vDyIosikOA4EQYAgCIiNjUVERASzuIR/wmM741Ez4P+6PWa6oijCYDCgpKSEaWFuRxAEREdHQ6vV1jmHQwQ2PLYzHjUDgaPbI6YriiLy8/NRVlbm0eI4EAQBkZGRiImJ8YnGQcgDj+2MR81AYOn2yH26BoNBtuIA1RekrKwMBoNBlvMRvgGP7YxHzUBg6WZuuuXl5R4f/teFKIooKSlBeXm5rOclvAOP7YxHzUDg6WZqular1TnB7Q1EUUReXh6sVqtXzk/IA4/tjEfNQGDqZmq6er3ea8VxIIoi9Hq9V3MgPAuP7YxHzUBg6mZmuiaTidk9c1IQRRFGoxEmk8mreRCegcd2xqNmIHB1MzPdoqIirxfHgSiKKCoq8nYahAfgsZ3xqBkIXN1MTNdqtTofw/MVjEYjze0GGDy2Mx41A4Gtm4nplpaWun3sb7/9hieffBLdunXD1q1bWaTjREpehO/R0OvpyTZVF55sZ67EllO3p/tWIF9rJqYrZd4lOzsbXbt2xYkTJ/DMM88gPz8fa9askZyTYx6GCBwa2s5ubVP3338/xo0bhx49emDAgAF17r9lyxakpaUhOTkZQ4cOxe+//97gnDzdzlzpW7fqttlsSEtLQ/fu3dGvXz8sW7bMOUorLi5Geno6+vXrhx49emDMmDH45ZdfGpyTHH3LnWv9zDPPIDc3F2PHjkVycjL69OmDLVu21DomJycHiYmJWLVqlUs5sdLNxHTNZrPbx+r1erRp0wY///wz1q1bB5vNBgD48ccfsW7dOq/lRfgeDb2ejjYFAGFhYRg+fDheeeWVOvf9/PPPsWPHDmRlZeHEiRPIyspCVFSUR/JyB1di36r74Ycfxqefforjx4/jiy++wMWLF50jQZPJhA4dOuCTTz7B4cOHMXToUEydOtWlH4o83bfcudalpaWYPHkyRo4cicOHD+Prr79GSkpKjf0tFguWLVuGBx980KN53QnJpmuxWNwe5b7wwgvIycnBW2+9hQkTJiA8PByLFy/G3r17cfjwYTz77LOSchNFERaLRVIMQn6+//77Wo27oe3s1jaVnJwMtVqNIUOGIDY2tta+drsda9euRXp6Olq3bg1BEBAXF4fIyEiX8mXRzq5fv47c3Nwa21zpW7frttlszpe2ON4he/36dQBAXFwcxo4di6ZNmyIoKAgjR46ExWLBb7/91uB8WfWt7777DpWVlTW2uXutly5dipSUFAwePBghISFo1KgRWrVqVeOYTZs2ISUlBffdd59b+bLQLdl0zWaz288mf/DBB0hKSsK8efPwv//9DzExMf8/MYX0QbggCDTa9TNsNhv69u0LnU6H9957z3n9GtrObm9Td+pchYWFKCwsxOXLl/Hoo48iLS0NWVlZsNvtLuXMop1lZWWhQ4cOGDhwIM6cOQPAtb5Vl+6vvvoK3bt3R2pqKi5evIiRI0fWeez58+dhsVjQokWLBufLQnNFRQUefvhhxMTEICsry2m+7l7r4uJiREZG4tlnn0WfPn0wbdo0FBQUOPfX6/XYuXMnJk2a5HbOLHRLfnGkqw20Pn7++WdcvHgRb7zxBvbs2YOuXbti69atmDBhgtsxrVYrvv32W9y4cYNJjoTncbSn0tJSzJkzB/Pnz8e4ceOwePFi5ucqLCwEABw9ehQ7duyA0WjEiy++iGbNmuHJJ59scByLxSK5nTmMdt++fTh48CAeeOABbNq0CaGhoW7HHDRoEAYNGoRr165h9+7daNKkSa19bt68iddeew2TJ0+GWq1ucGybzSZZc1VVFRQKBYqLizFr1izMmTMHEydOxMKFC92KV1hYiHPnzmHdunW4//77kZmZifT0dGzevBkAsHTpUkybNk3yy8qlep5k02V1H13Hjh3RsWNH5OfnAwC6du2Krl27Sopps9lw7tw5nDt3jk
WKhAzc2qArKythsVjwzTff4B//+AfzczkM7fnnn0dERAQiIiIwcuRI/PDDDy6ZLot2du3aNQDV/clut+PChQu4du0a4uPj3Y7p4C9/+QvatGmDJUuW4J///Kdzu9lsxrRp09CxY0eMHz/epZh2u12yZqvV6vSPW6/1ggUL3IoXGhqKfv36ISEhAQAwefJkpKamwmg04qeffoLJZEJaWprb+TqQ6nmSTZf1a89iYmIwZcoUJrFCQ0MxefJkl+foCO9hs9mwZcsWhIWFoVevXli2bBk6deqEsrIy5ue67777EBwcLLkNq1Qqye1szpw5OH/+PNRqNRYsWIBJkyahqqrKOQiRitVqdc7pAtWjzBkzZqBZs2ZumVxwcLBkzRUVFdi2bRvCwsLQp08f5w9c7l7r+Pj4Gtfy1n+fOHECZ8+eRd++fQFUj/AVCgUuXbqE1atXu3Qeqe1FsumymHv1JL6eH1EThUKB+fPn44knnkCnTp1qbHcHu90Oi8XiHFVVVlZCoVAgODgYYWFhSEtLQ3Z2Ntq1awej0Yjt27fj+eefdytvKQwcOBA6nQ4TJ05EWFgYAEi6Ef/zzz9H37590aRJE1y5cgUffPCB85d8i8WCV155BaGhocjIyHA7d6maQ0NDMW/ePDz11FM17iZwN+6wYcPw8ssv45lnnkHr1q3x73//G0lJSVCr1Zg2bRpeeOEF575vv/02mjZt6tb8rlTdkk1XpVL5zKN6tyOKIlQqlbfTIFxAEAQsWbKk1nZ329lPP/2EcePGOf/fpUsXdOnSBdnZ2QCAefPmYdGiRejXrx/UajWeeOIJDB8+3KVzsGhnffr0QZ8+fWpsk9K3Tp06hVWrVqGiogJRUVH461//imnTpgEATp8+je+//x4qlarGLVVr165F586dGxSfhWaFQoGMjIxa293V3a1bN8yYMQNTp05FRUUFkpKSsGzZMgBAo0aN0KhRI+e+oaGhCAsLc+tOFam6mawckZuby+wHNZYoFAq0b9/e22kQjOCxnfGoGQhs3Uy+e/vqaNJX8yLcw1evpyfz4lGzHPHdhUVeTExXrVb73NpkjiWVicCBx3bGo2YgsHUzMV1XH5uUC1/Ni3APX72ensyLR81yxHcXFnkxMV2lUulzo0q1Wg2lUvLvhIQPwWM741EzENi6md1PpdFofObrgCAI0Gg03k6D8AA8tjMeNQOBq5uZ6YaHh/vEPIxj3kXqo36Eb8JjO+NRMxC4upk+OaDT6XyiQLe+OIcIPHhsZzxqBgJTN1PTVSqViI2N9VqRBEFAbGwsgoKCvHJ+Qh54bGc8agYCUzfzZ2QjIiIQHR0te5EEQUB0dLTzHaJEYMNjO+NRMxB4uj3yYgKtVovIyEjZiiQIAho3bgytVivL+QjfgMd2xqNmILB0M3kMuC5EUYTBYEBJSYlH383g+DTSarVen/sh5IfHdsajZiBwdHvMdB2Ul5cjLy8PoigyLZQgCM75FppSIHhsZzxqBvxft8dNF6h+RZ1er5e0avCtOG7h0Ol09AAE4YTHdsajZsC/dctiug5MJhOKioqcyxi7cmrHMF+tVkOj0dB9uES98NjOeNQM+KduWU3XgdVqRWlpKYxGI8xms3O10ttxbFepVFCr1YiKivLpT1/Ct+CxnfGoGfAv3V4x3duxWCwwm82w2+3OoigUCqhUKgQHB3s7PSJA4LGd8agZ8G3dPmG6BEEQvEALiBEEQcgImS5BEISMkOkSBEHICJkuQRCEjJDpEgRByAiZLkEQhIyQ6RIEQcgImS5BEISMkOkSBEHICJkuQRCEjJDpEgRByAiZLkEQhIyQ6RIEQcgImS5BEISMkOkSBEHICJkuQRCEjJDpEgRByAiZLkEQhIz4xEp0vryeUSDCa7151M2jZsC3dXvFdP1p5c5AgNd686ibR82Af+mWdWFKf1yj3p/htd486uZRM+CfumUxXavVCr1eD6PR6FJR6kMQBKjVauh0Or/+dPYUvNabR9
08agb8W7fHTbe8vBx5eXkQRZFJcRwIggBBEBAbG4uIiAhmcf0dXuvNo24eNQP+r9tjpiuKIgwGA0pKSpgW5nYEQUB0dDS0Wm2dczi8wGu9edTNo2YgcHR7xHRFUUR+fj7Kyso8WhwHgiAgMjISMTExPtE45IbXevOom0fNQGDp9sh9ugaDQbbiANUXpKysDAaDQZbz+Rq81ptH3TxqBgJLN3PTLS8v9/jwvy5EUURJSQnKy8tlPa+34bXePOrmUTMQeLqZmq7VanVOcHsDURSRl5cHq9XqlfPLDa/15lE3j5qBwNTN1HT1er3XiuNAFEXo9Xqv5iAXvNabR908agYCUzcz0zWZTMzumZOCKIowGo0wmUxezcPT8FpvHnXzqBkIXN3MTLeoqMjrxXEgiiKKioq8nYZH4bXePOrmUTMQuLqZmK7VanU+hucrGI3GgJ3b5bXePOrmUTMQ2LqZmG5paamk44cNG4acnBwWqdRAal6+ihRdnqo14Pl6uxLfkzpvx5O6edTsanx/083k4YirV68yneeZP38+MjIyJMcJDw9Hq1atGGTkW7CotyiKWL16NXbt2gWTyYQHHngA8+fPR5s2bdyO6el6u6P70qVLeOedd5Cbm4sbN27gzJkztfbZu3cv1q5dC4PBgCZNmmDJkiXo3Llzg8/hSd3uaN67dy/WrFmDoqIihISEoFevXnjttddwzz33oKqqCkuWLMHx48dRVlaGuLg4zJgxA6mpqS6dwxevNQBcv34db7/9Nn788UeEhIRg+PDheOWVV2rsc+3aNYwYMQL9+/fH22+/7VJ8FrqZjHTNZrPkGAUFBVi2bBkqKioAABcvXsQ777zj9bx8ERa69u/fj507d2Ljxo04fPgwOnbsiHnz5nk9L9bxlUolBgwYgEWLFtX596NHj2LlypVYvHgxjh8/jo0bNyI2NtbjeXkydqdOnfDhhx/i2LFj2Lt3L6xWK1avXg2g+mu7VqtFdnY2jh07hpdeegmvvvoq8vPzPZ6Xp+NbLBZMnDgRycnJ+O9//4uDBw9i0KBBtfbLyMhAQkKCbHndjmTTtVgskie7BwwYgN9//x39+/fHvHnzkJOTg08//RTjxo2TFFcURVgsFkkxvMmFCxdw8ODBGvWVWu8BAwbg2LFjyM/PR6dOnRAXF4egoCAMHjwYV65ckZQvq3p//PHHtX60cFW3Q2fLli0xYsSIekfwa9aswaRJk9CxY0coFAo0a9YMzZo1cylfFrrPnTuHQ4cOSbrWDs1arRZRUVHO7UFBQfjjjz8AVI/UpkyZgpiYGCgUCvTp0wcxMTHIzc11KV8WmkVRxEcffYSSkpIa293VvXPnTtx7770YO3YswsPDERoairZt29bYd+/evVCr1ejWrZvbOUvVLdl0zWazR57JVigUkuMKguDXo92tW7fir3/9KxITE7F//36Iosis3o899hiuX7+O33//HRaLBbt370bPnj0lxWRRb7vdjtGjR6NFixaYNWuW03w90c5sNhvOnj2LkpISDBw4EI888ggyMjJc1sBC98aNG9G/f3889NBDOHDggORrffLkSfTo0QPdunXDwYMHMWbMmDr3KyoqwrVr19C6dWuX4rPQbDab8fTTTyMuLg5z5sxxmq+7un/55RfodDpMmjQJqampeP7553Hx4kXn32/evImsrCzMnj3b7ZxZ6Jb84ki73S41BADgzz//xLlz5/DWW29hyZIleOKJJ7BhwwbMmjXL7ZgWiwU7duxAQUEBkxzl5siRIxBFEWfPnsXQoUMRFRWFjz/+GPfee6/k2E2bNkVSUhKGDBmCoKAgaLVarF+/XlJMm80mud6O9lRRUYFVq1Zh9erVSEtLw+bNmyXlVhfFxcWwWq04cOAANm3aBKVSienTp2PdunWYPn16g+NUVVVJ1n38+HGIoohffvkFgwcPRnR0ND799FM0adLErXhJSUk4duwYCgsL8fnnn0On09Xax2KxYO7cuRg6dKjL85
QsrrXFYoFCoYDJZEJmZiZWrlyJIUOGYMOGDW7FKywsRE5ODlatWoXu3btjy5YtmD59Or788ksEBwfjX//6F4YPHw6tVut2zoB0z5Nsuqzuo2vatCmGDh3q/H/btm1rfTVwh4qKCty4cUNyHG9w6yeqKIqoqKhAZWUlk9hr167Fr7/+igMHDkCj0WDPnj0YP348vvjiC4SFhbkV05GjlHrf2p7sdjsEQUBpaalH7tcMDQ0FADz99NNo2rQpAODvf/+7y6bLQvet15XltW7WrBl69uyJ9PR0fPrpp87tdrsd8+bNQ3BwsFtz+Sw03zqN4LjWUt6xEBoaik6dOjl/FHzuueewbt06XL16FaIo4vjx4/jss8/czteB1LYo2XQ9MbXA4s4FAAgODsbo0aMRGRnJJJ7cLFiwAMePH0dsbCyWL1+OkSNHwmg0uvyjR11cuHABaWlpzk/9YcOGYfny5bh69So6dOjgVkylUim53na7HcuXL4dKpcLo0aOxaNEixMXFoayszO2Y9REZGYlmzZrVaMPutOfQ0FDJuufMmYOcnBy0aNECy5cvxxNPPMHsWttsNly/ft35f1EUsWDBAhQXF2PNmjVuLdTI4lpXVFQgMzMTKpUKY8aMwcKFCxETE+P2tY6Pj8fp06fr/FtOTg70ej369+8PoPppN7vdjlGjRtX4MGoIUj1PsukqFL69iruv53cnRowYgQcffBAjRoxw6mClJyEhAd988w3S0tIQHR2Nr776ClarFXFxcZLiSs1PEASsXbsWgwYNqpGLu3FFUURVVZXzx4/KykoIgoCQkBAA1R8227ZtQ8+ePaFUKrF582b07t3b5fNI1f3UU08hOTkZw4cPl3yt9+zZg86dO6N58+bQ6/VYtWpVjR+OFi9ejN9++w3vv/8+VCqV2zlL1axSqZCVlYXHH38cMTExkuMOHjzYeddGcnIytm7disaNG6NVq1Zo0aIFHnvsMee+GzduhF6vx+uvv+7yeaTqlmy6KpXKZx7Vux1RFCU1Km/z0EMP4aGHHqqxjVW9x40bh+LiYowcORIVFRVo0aIFMjMzJS1TwqLegiBg0qRJtba7q1uv1yMtLc35/y5dukCn02H//v0AgBdffBE3btzAkHx6doAAAA+uSURBVCFDEBISggEDBmDixIkunYOF7qSkJCQlJdXY5q7mq1evYuXKlTAajVCr1UhNTcXMmTMBVNfjs88+Q0hICPr27es8ZsGCBRg8eHCDz8HqWk+ZMqXWdnd1t2zZEkuXLsXixYtRUlKCdu3aYfXq1QgODkZwcHCNabPw8HCEhIQgOjrapXMw0c3i4Yjc3FxmP6ixRKFQoH379t5Ogzm81ptH3TxqBgJbN5Pvqr46mvTVvKTiq7o8nRePunnULEd8d2GRFxPTVavVPrc2mWNJ5UCE13rzqJtHzUBg62Ziurc+/eJL+GpeUvFVXZ7Oi0fdPGqWI767sMiLiekqlUqfG1Wq1WoolZJ/J/RJeK03j7p51AwEtm5m91NpNBqf+TogCAI0Go230/AovNabR908agYCVzcz0w0PD/eJeRjHvEt4eLhX8/A0vNabR908agYCVzfTJwd0Op1PFOjWG60DGV7rzaNuHjUDgambqekqlUrExsZ6rUiCICA2NhZBQUFeOb/c8FpvHnXzqBkITN3Mn5GNiIhAdHS07EUSBAHR0dGSnqjyR3itN4+6edQMBJ5uj7yYQKvVIjIyUrYiCYKAxo0bS35lm7/Ca7151M2jZiCwdDN5DLguRFGEwWCQ9Kq2huD4NNJqtV6f+/EmvNabR908agYCR7fHTNdBeXk58vLyIIoi00IJguCcb+FtSuFO8FpvHnXzqBnwf90eN12gejE8vV4Po9HIpEiOWzh0Ol3APgAhBV7rzaNuHjUD/q1bFtN1YDKZUFRUBKPRCMC1N7A7hvlqtRoajSbg78NlAa/15lE3j5oB/9Qtq+k6sFqtKC0thdFohNlshiiKdc6dOLarVC
qo1WpERUX59Kevr8JrvXnUzaNmwL90e8V0b8discBsNsNutzuLolAooFKp3FpKhLgzvNabR908agZ8W7dPmC5BEAQv+O8CYgRBEH4ImS5BEISMkOkSBEHICJkuQRCEjJDpEgRByAiZLkEQhIyQ6RIEQcgImS5BEISMkOkSBEHICJkuQRCEjJDpEgRByAiZLkEQhIyQ6RIEQcgImS5BEISMkOkSBEHICJkuQRCEjJDpEgRByAiZLkEQhIz4xEp0vryeUSDCa7151M2jZsC3dXvFdP1p5c5AgNd686ibR82Af+mWdWFKf1yj3p/htd486uZRM+CfumUxXavVCr1eD6PR6FJR6kMQBKjVauh0Or/+dPYUvNabR908agb8W7fHTbe8vBx5eXkQRZFJcRwIggBBEBAbG4uIiAhmcf0dXuvNo24eNQP+r9tjpiuKIgwGA0pKSpgW5nYEQUB0dDS0Wm2dczi8wGu9edTNo2YgcHR7xHRFUUR+fj7Kyso8WhwHgiAgMjISMTExPtE45IbXevOom0fNQGDp9sh9ugaDQbbiANUXpKysDAaDQZbz+Rq81ptH3TxqBgJLN3PTLS8v9/jwvy5EUURJSQnKy8tlPa+34bXePOrmUTMQeLqZmq7VanVOcHsDURSRl5cHq9XqlfPLDa/15lE3j5qBwNTN1HT1er3XiuNAFEXo9Xqv5iAXvNabR908agYCUzcz0zWZTMzumZOCKIowGo0wmUxezcPT8FpvHnXzqBkIXN3MTLeoqMjrxXEgiiKKioq8nYZH4bXePOrmUTMQuLqZmK7VanU+hucrGI3GgJ3b5bXePOrmUTMQ2LqZmG5paSmLMMzx1byk4qu6PJ0Xj7p51CxHfHdhkRcT0/XkvMv8+fPdOs4xDxOIeKLe+/btw9ChQ9GtWzc8/vjjOHTokEvHy1FvqbotFgteeeUVDBgwAImJicjJyam1T25uLsaOHYvk5GT06dMHW7ZsuWNMT+uWqvnKlSt46qmnkJKSgpSUFIwfPx5Xrlxx/j07OxvDhw9Ht27dkJaWhuzs7LvG9IdrDQAVFRVYsmQJUlNT0aNHD4wdO7bWPhaLBUOHDsUjjzxy13isdDMxXbPZzCKME1EU8Y9//MP5i2FpaSnefPNNVFRUeDUvX4G1rsLCQrz22muYPXs2jh8/jlmzZmHu3LkoLi72al6eiN+pUycsXboUGo2m1t9KS0sxefJkjBw5EocPH8bXX3+NlJQUWfLyVOymTZsiMzMTR44cwQ8//ICHH34Ys2fPdv5dFEVkZGTgyJEjWLt2LT766CPs3bvX43nJEX/RokUoKyvDrl27cPjwYaSnp9faJzs7G1FRUbLmJdl0LRaL259IiYmJ+OOPP5z/nz9/PlatWgVBEDB+/HhkZWXh5MmTWLJkCUaPHo2wsDCX4ouiCIvF4lZuvkBOTg42bdpUYx7J3XrXV2ug2nQjIiKQmpoKQRDQu3dvhIWF4fr16y6dg0W9RVFEZmYmLl++XGN7Q3XfSWdwcDDGjBmDpKQkKBS1m/6HH36IlJQUDB48GCEhIWjUqBFatWrVoJyl6j5x4gQ2b97s1rW+k+aIiAjno6yiKEKhUNS4ruPGjUP79u2hVCrRsmVLPPzwwzh16tRdz8nqWr/77ru4evVqje0sdF+9ehXfffcdFi5ciOjoaAQFBaFDhw41js/Ly8OePXswfvx4l3KWqluy6ZrNZo89k+1oKI63vrtzvD+Pdr/88ku88MILiIuLw4YNG2C1Wj1S7w4dOqBly5b473//C5vNhkOHDiE4OBjx8fEuxWFRb1EUMWvWLCQmJmLUqFFO8/VkO3Pwyy+/IDIyEs8++yz69OmDadOmoaCg4K7HsdC9Y8cOPP/882jRooXzg5al5pSUFHTp0gVLly6t12REUcTJkyfRpk2bu8ZjodlsNuPVV19Fhw4d8MwzzzjNl4XuX3/9Fc2bN0dWVhZSU1MxfPhwHDhwoMY+S5cuxYwZM6BSqR
ocl4VuyS+OtNvtUkPUQhRFrF+/HlOmTEFWVhZeffVVvPfee5gzZ45Lo92qqiqsX78ev/32G/Mc5eD48eOw2WwwGAx48cUX8dJLL+Hjjz9G69atmZ4nKCgIQ4cOxZw5c1BVVYXg4GC8++67Lr/U2Wq1Sq63Y4RjNpuxfft27NixAx07dsS3337rdsyGUlhYiHPnzmHdunW4//77kZmZifT0dGzevPmOx5nNZsm6jxw5ApvNhoKCAowfPx5TpkzBZ599hvvuu8/tmLdy9OhRmEwm7N69Gzqdrs591qxZA7vdjmHDht01HotrbbVanSb20Ucf4ZNPPkGXLl2wf/9+t2M6KCwsxOXLl9G/f398++23OH36NKZOnYrWrVujVatWOHToEGw2Gx555JE65/bvhFTPk2y6nvgBTRAELFiwwPn/qKgovPnmmy7HUSgU0Ol0CAoKYpidfFy8eNH5b4VCAY1Gg+joaObnOXbsGDIzM5GdnY127dohNzcXL730EtauXYsHHnigwXEEQZBc71sbtFKphCAIaNu2rSz3a4aGhqJfv35ISEgAAEyePBmpqakwGo1Qq9X1HseineXm5taI17RpU+bXOjw8HKNGjULv3r2xa9cuNGnSxPm3bdu24csvv8TGjRsREhJy11gsrnVVVZXz26xSqYRCoWB2rUNDQ6FUKjFx4kQolUp07doVycnJOHr0KLRaLTIzM7FmzRq3YkvNT7LpSvkaEBYWVuPHseLiYjRr1qzGPhkZGW7HVyqVSEtLQ2RkpNsxvElRURH279+P5ORkrFixAr1790ZZWRny8/NdjnWnWl+4cAGdO3d2znklJCQgMTERx48fd8l0g4KCJNfbbrfj5ZdfRqNGjfDqq6/i5ZdfRmRkJMrKyhp0fEPaVH3Ex8fXaM8NbdshISGSdefn5+PQoUPo3r07VqxYgV69ejX4Wrui2W63w2w24//+7/+cpvvFF1/ggw8+wMaNG6HVahuUL4trXVFRgdmzZ+Oee+5Beno6Zs6cCbVazeRa1zU15rief/zxB/R6vfNuBovFgps3b6Jv377YunUrYmJi7nheqVMfkud03ZlrddC2bVt8/fXXsNlsOHz4MH788Uep6dRCSn7e5rnnnsORI0dw4sQJ9O7dG4D7eu5U6w4dOuDkyZM4f/48AODcuXM4efKky3O6UvK79fhdu3YhPz8fb775prNTNzTu3dpUVVUVKisrAVR3tsrKSufIZdiwYTh06BDOnz8Pi8WCf//730hKSrrjKPfWvKUwYcIEHD16FMeOHUOvXr1cinknzUePHsW5c+dgs9lw8+ZNrFixAhEREc4fCPfs2YP33nsP77//PuLi4lzKWarmsLAw7Ny5E3q9Hm+88Yazzix0d+7cGc2bN8f69ethtVpx6tQp/O9//0NKSgratGmDAwcOYPv27di+fTsWLVqEJk2aYPv27Q360JGqW/JLzC0WCy5evOjWkPvs2bOYP38+CgoK0K9fP9hsNsTGxmL69OlSUnIiCALi4+O9vuQyS9yt991qvW3bNmzZsgXFxcWIiorC6NGj67yv8U54st4N1X03nQMGDKj18pJ9+/Y5RzeffPIJ1q1bh4qKCiQlJeH111+/a0f0lG4Wmvfv349//etfKCwshEqlQkJCAmbMmIG2bdsCANLS0lBYWFgj98GDB9eY3qsLf7jWly9fxsKFC3Hp0iU0b94c06dPr/N+3JycHMydO7dB96az0M1k5Yjc3FyP/KAmFYVCgfbt23s7DebwWm8edfOoGQhs3Uy+e7tyy4Wc+GpeUvFVXZ7Oi0fdPGqWI767sMiLiemq1WqfW5vMsaRyIMJrvXnUzaNmILB1MzFdVx6jkxNfzUsqvqrL03nxqJtHzXLEdxcWeTExXaVS6XOjSrVaDaVS8h1xPgmv9eZRN4+agcDWzex+Ko1G4zNfBwRBqPOFJoEEr/XmUTePmoHA1c3MdMPDw31iHsYx7+LqI6z+Bq/15lE3j5qBwN
XN9MkBnU7nEwW62xMlgQKv9eZRN4+agcDUzdR0lUolYmNjvVYkQRAQGxvrt+9acBVe682jbh41A4Gpm/kzshEREYiOjpa9SIIgIDo6GhEREbKe19vwWm8edfOoGQg83R55MYFWq0VkZKRsRRIEAY0bN27wyzoCDV7rzaNuHjUDgaWbyWPAdSGKIgwGA0pKSjz6Wj7Hp5FWq/X63I834bXePOrmUTMQOLo9ZroOysvLkZeXB1EUmRZKEATnfAtvUwp3gtd686ibR82A/+v2uOkC1W+I1+v1zFaxddzCodPpAvYBCCnwWm8edfOoGfBv3bKYrgOTyYSioiLnMsaunNoxzFer1dBoNAF/Hy4LeK03j7p51Az4p25ZTdeB1WpFaWkpjEYjzGazc/HJ23FsV6lUUKvViIqK8ulPX1+F13rzqJtHzYB/6faK6RIEQfCK/65lQxAE4YeQ6RIEQcgImS5BEISMkOkSBEHICJkuQRCEjPw/uRKVFUaXuc0AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'u*': ['u8'], 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],\n",
" 'i*': ['i8'], 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],\n",
" 'f*': ['f16'], 'f16': ['f32'], 'f32': ['f64'],\n",
" 'c*': ['c64'], 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'u*': [0, 0], 'u8': [3, 0], 'u16': [5, 0], 'u32': [7, 0], 'u64': [9, 0],\n",
" 'i*': [0, 1], 'i8': [3, 1], 'i16': [5, 1], 'i32': [7, 1], 'i64': [9, 1],\n",
" 'f*': [0, 2], 'f16': [5, 2], 'f32': [7, 2], 'f64': [9, 2],\n",
" 'c*': [0, 3], 'c64': [7, 3], 'c128': [9, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 4))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yft3cGgtUyYx"
},
"source": [
"In some sense, putting scalars at the left is a strange choice: a scalar type may hold values of any width, but when it interacts with an array of a given type, the promotion result defers to the array type.\n",
"The benefit of this is that when you perform an operation like `x + 2` for an array `x`, the type of `x` carries over to the result no matter its width:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "_MzOX_HCVfOT"
},
"outputs": [],
"source": [
"for dtype in [np.int8, np.int16, np.int32, np.int64]:\n",
" x = np.arange(10, dtype=dtype)\n",
" assert (x + 2).dtype == dtype"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wmbZEheuVuDH"
},
"source": [
"This behavior motivates our `*` notation for scalar values: the `*` is reminiscent of a wildcard that can take on any desired value.\n",
"\n",
"The benefit of these semantics is that you can readily express sequences of operations with clean Python code, without having to explicitly cast scalars to the appropriate type. Imagine if rather than writing this:\n",
"```python\n",
"3 * (x + 1) ** 2\n",
"```\n",
"you had to write this:\n",
"```python\n",
"np.int32(3) * (x + np.int32(1)) ** np.int32(2)\n",
"```\n",
"Although the second form is explicit, numerical code written this way would be tedious to read and write. With the scalar promotion semantics described above, given an array `x` of type `int32`, the types in the second statement are implicit within the first."
]
},
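{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of this claim, the following sketch evaluates both spellings in NumPy for a small `int32` array. The names `implicit` and `explicit` are introduced here purely for illustration:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"x = np.arange(5, dtype=np.int32)\n",
"\n",
"# Python scalars defer to the dtype of the array they interact with.\n",
"implicit = 3 * (x + 1) ** 2\n",
"# The same computation with every scalar cast by hand.\n",
"explicit = np.int32(3) * (x + np.int32(1)) ** np.int32(2)\n",
"\n",
"# Both spellings stay in int32 and agree elementwise.\n",
"assert implicit.dtype == explicit.dtype == np.int32\n",
"assert (implicit == explicit).all()\n",
"```"
]
},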
{
"cell_type": "markdown",
"metadata": {
"id": "HX7yeIf6jbjz"
},
"source": [
"## Combining Lattices\n",
"\n",
"Recall that we began our discussion by introducing the lattice representing type promotion within Python: `int -> float -> complex`. Let's rewrite this as `i* -> f* -> c*`, and let's further allow `i*` to subsume `u*` (after all, there is no unsigned integer scalar type in Python).\n",
"\n",
"Putting these all together, we get the following partial lattice representing type promotion between Python scalars and NumPy arrays:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"cellView": "form",
"id": "koA5VFHp7tjo",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAEeCAYAAAApRMZ1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxMZ98/8M/JOgmTSERNFlqKFEHFmmgEtSStveVRS5Tq3VLFw69aFHXbuqi2FC3SEMtT+3rXrnXXbUstte9EkjEhEsmQTMyZOb8/8sw8EkFm5sw5Z+b6vl+vvl41mXOd65Nz8p1rznZxgiAIIIQQIgkPuTtACCEsoaJLCCESoqJLCCESoqJLCCESoqJLCCESoqJLCCESoqJLCCESoqJLCCESoqJLCCESoqJLCCES8pK7A6RijEYjDAYDzGYzBEEAx3Hw8PCASqWCt7e33N0jdmBlm7KSs6Ko6CoUz/PIy8uDXq+HwWCw7qxlWV5XqVRQq9UICgqClxdtViViZZuyktNeHD3wRlkKCwuRk5MDvV4PoGTHrCjLjq1WqxESEgJ/f3+n9JHYhpVtykpOR1HRVQie56HVaqHX623aWZ+G4zio1WqEhYUxMXpQIla2KSs5xUJFVwEKCgqQmZkJQRBE2WktOI4Dx3GIiIhAQECAaO2S52Nlm7KSU0xUdGUkCAJ0Oh1yc3NF3WHL4jgOwcHB0Gg05R5bI+JhZZuyktMZqOjKRBAEZGVlIT8/36k7rQXHcQgMDER4eLjb7LxKw8o2ZSWns9B1ujLR6XSS7bRAyR9Kfn4+dDqdJOtjESvblJWczkJFVwYFBQVO/1pWHkEQkJubi4KCAknXywJWtikrOZ2Jiq7EeJ63nniQgyAIyMzMBM/zsqzfHbGyTVnJ6WxUdCWm1Wpl22ktBEGAVquVtQ/uhJVtykpOZ6OiK6HCwkLRrmV0hCAI0Ov1KCwslLUf7oCVbcpKTilQ0ZVQTk6O7DuthSAIyMnJkbsbLo+VbcpKTilQ0ZUIz/PW2yOVQq/Xu/zxMTmxsk1ZySkVKroSycvLc0q7kyZNcmh5Z/WLBY787nJzczF+/HjExMQgNjYWn3766RPvyc/PR9u2bZGUlCRZvxxt7+7du/j444/RoUMHNGrUCFlZWaV+PmfOHLz55pto1aoVunXrhq1bt5b6+dGjR9G3b1+0bt0aCQkJWLdunSj9UhIquhIR83hYUVERpk2bhvv37wMoOcExbdo0m9u3HB8j9nFkm/73f/83QkJCsHv3bhw4cADvvvvuE+/57rvvUKtWLZvadcY2tSUnx3Fo06YN5s6dW+7P/fz8MH/+fBw+fBgzZ87El19+iVOnTgEoeQTkmDFj8Pbbb+Pw4cOYM2cOvvnmG1y6dOmJdlx533W/p0kolMFgsHtZnU6HL7/8EidOnIDZbEZiYiL69++PGTNm4Ny5c1iwYAGGDx9u1906jvSLdRX53ZW37dq3bw+dTodffvkFnp6eAID69euXWu7UqVO4cuUK3n77bWzatEn0fonRXnnZJk2ahH79+j31q/9HH31k/f/GjRujWbNm+Pvvv/Hqq68iPz8fDx48QLdu3cBxHKKiolC7dm1cu3YNkZGRFe6X0tFIVwJGo9HuEZHJZMJHH32E0NBQ7Ny5E/v27UNiYqL1548/FNoegiDAaDTatSwLtm3bhlatWmHv3r2ltmFFtunTtt3ff/+Nl156CZMmTcJrr72Gfv36IS0trdRys2bNwsSJE+36ILVnm2ZnZ6NBgwb4+eef8ejRI+vrT8v5vP2yIgwGA86ePYuXX34ZABASEoLExERs3rwZJpMJp06dwu3btxEdHV3u8q6671LRlYDBYLD7nvEzZ87g7t27GDduHPz9/eHr64v69etj1apV+Pzzz9GsWTOMGDECS5cutauwcxznsiMGKWRkZODEiRPo2bMnXn31VWvxrcg2LW/bRUdHIz
s7G4cOHULLli3x+++/Y/DgwRg9erT1GOWqVavQqFEjNGzY0K4+27NNc3Nzcf36dYwbNw4RERHW4vu0nE/LZovp06cjMjISbdq0sb72xhtv4KeffkKzZs3w7rvv4uOPP4ZGoxEtpxLQ4QUJmM1mu5fV6XQIDQ0t9VxRPz8/fPHFF9Z/h4WFYcqUKXa1/+DBAwwZMgT//ve/7e6jOysqKgLP8+B5HqdPn0anTp3Qr18//PTTT89dtrxtBwAqlQrh4eHo3bs3ACAxMRGLFy/GyZMnERUVhdWrV2PNmjV291mv19u8TXmex6NHj1BcXIyHDx/iww8/xIwZM3D27FmbslXUt99+iytXruCXX36xFvXr169j/Pjx+O677xATE4P09HSMHDkSL7zwAtq2bVtuO478bcmFiq4EHDmBptFooNPpwPN8uTv4zJkzHekaKlWqhLlz50KlUjnUjrtKSUnB5MmT4eXlBW9vb4wdOxZjx46FyWR67rJP23b16tXDH3/8Ueq9lsJjGUH26NEDAFBcXAyDwYB27dph37591mPAz1K5cmWbt+nly5fRuXNn8DwPT09P9O3bF1OmTHnqvvu8/fJZFixYgIMHDyIlJQWVK1e2vn716lW8+OKL1pFvrVq10LZtW/z5559PLbpKuXbYFlR0JeDI4+gaNWqEkJAQfP/99xgxYgQ8PT1x/vx5NG3aVLS+BQYGIjAwUJT23E1oaCgqVaqEyZMnY/jw4fDz8wNQcjnX8zxt273++uuYM2cOtmzZgq5du2Lfvn3Izs5G06ZNUalSJezatcvaxs6dO/Hbb79h3rx5FSq4gH3b1DJx5NChQzFlyhSEhYU9M+ez9svi4mLrh5Jl9Ozr6wsAWLp0KX777TcsX74cVapUKdVm/fr1kZ6ejqNHj6Jly5bIzMzEgQMHMGTIkGdmdTX0PF0J6PV6ZGRk2P1V6Pbt25g9ezZOnDgBjuPwxhtvYMKECaL0zcPDAzVq1IBarRalPXcjCALMZvMTBa+i2/Rp2+748eOYMWMGsrKyUKtWLYwfPx7NmjV7YvnNmzdj48aNSE1NrXCf7d2mJpPJppxPy9aoUaMn3nvmzBkAJcXa29u71Oj4/fffx/vvvw+g5EPm559/hlarReXKlfHmm29izJgx5Z4odtV9l4quBIxGIy5fvqzIr0Icx6FevXpMToXtCFa2KSs5pURXL0jA29tbsV+DOI5zuZ1WCVjZpqzklBIVXYko9USVUvvlCpT6uxO7X6zklAoVXYmo1WrFjRgsU10T+7CyTVnJKRUquhIJCgqSuwvlUmq/XIFSf3di94uVnFKhoisRLy8vxX0yq9Vquy9uJ+xsU1ZySoWKroRCQkIU8zWN4ziEhITI3Q2Xx8o2ZSWnFKjoSsjf318Rx8csx8P8/f1l7Yc7YGWbspJTClR0JRYWFqaIHTc8PFzWPrgTVrYpKzmdjYquxLy8vBARESHbzstxHCIiIip8Syl5Pla2KSs5nY2KrgwCAgIQHBws+c7LcRyCg4MREBAg6XpZwMo2ZSWnM1HRlYlGo0FgYKBkOy/HcahSpcpTn01KHMfKNmUlp7PQsxdkJAgCdDodcnNznXpvu2WUoNFoZD8m5+5Y2aas5HQGKroKUFBQgMzMTAiCIOoOzHGc9TiYO3wtcyWsbFNWcoqJiq5C8DwPrVYr2qzBlktrwsLCXPYiclfHyjZlJadYqOgqTHp6OrZv3474+HhwHGfTTmz5+qVWqxESEuLS1zK6k8LCQuTk5FinDHfXbcpKTkdR0VWQy5cvo3Xr1sjLy8OhQ4dQp04d6PV6GAwG66y/ZVleV6lUUKvVCAoKcsvRgTvgeR55eXluv01ZyWkvKroKkZaWhk6dOiE/Px++vr749ddf0bNnT+vPjUajdUqVx6ddV6lULvlMUcLONmUlZ0W5/8eKC8jPz0fbtm2t00lzHAedTlfqPd7e3kzuoO6MlW3KSs6Kout0FSAwMBBbtmxBVFQUPDw8YD
QaodVq5e4WIcQJaKSrEJ07d4Zarcby5ctx+/Zt0Wb7JYQoCx3TVYhz586hU6dOuHXrFhMnEwhhFR1eUIjk5GS8++67VHAJcXM00lWA4uJi1KhRw3qZGCHEfdFIVwG2bt2KqKgoKriEMICKrgIkJyfjvffek7sbhBAJ0OEFmaWnpyM6OhqZmZnw8/OTuzuEECejka7Mli1bhnfeeYcKLiGMoJGujEwmE2rXro0tW7bg1Vdflbs7hBAJ0EhXRnv37kW1atWo4BLCECq6MqITaISwhw4vyOTu3buoW7cubt68iSpVqsjdHUKIRGikK5OVK1eiR48eVHAJYQwVXRkIgoClS5fSoQVCGERFVwZHjx6F0WhEXFyc3F0hhEiMiq4MLKNcd5lSmhBScXQiTWJ6vR41a9bEhQsXoNFo5O4OIURiNNKV2Nq1axEfH08FlxBGUdGV2NKlSzFs2DC5u0EIkQkVXQmdO3cOt27dQkJCgtxdIYTIhIquhGh2CEIInUiTiGV2iMOHD+Pll1+WuzuEEJnQSFciltkhqOASwjYquhJJTk6mE2iEEDq8IIX09HQ0a9YMGRkZ9LByQhhHI10JpKSk0OwQhBAANNJ1OsvsEFu3bkWTJk3k7g4hRGY00nUyy+wQVHAJIQAVXaejE2iEkMfR4QUnsswOkZ6ejsDAQLm7QwhRABrpOpFldggquIQQC7of1Q5GoxEGgwFmsxmCIIDjOHh4eEClUsHb2xvA/80OsWjRIpl7a7+K5HQHrOQE2MqqVFR0K4DneeTl5UGv18NgMFh31rIsr6tUKuTm5sLX19elZoewJ6darUZQUJBLPU+ClZwAW1ldBR3TfYbCwkLk5ORAr9cDKNkxK8pkMoHjOFSpUgUhISHw9/d3Vjcd5khOyx+wWq2mnArCUlZXQ0W3HDzPQ6vVQq/X27SzPg3HcVCr1QgLC1PU6IFy2kepOQG2sroqKrplFBQUIDMzE4IgiLLTWnAcB47jEBERgYCAANHatRfldIzScgJsZXVlVHT/lyAI0Ol0yM3NFXWHLYvjOAQHB0Oj0cgyMSXlFJfcOQG2sroDKroo2WmzsrKQn5/v1J3WguM4BAYGIjw8XNKdl3I6h1w5Abayugu6TheATqeTbKcFSv5Q8vPzodPpJFmfBeV0DrlyAmxldRfMF92CggKnfy0rjyAIyM3NRUFBgSTro5zOJXVOgK2s7oTposvzvPXEgxwEQUBmZiZ4nnfqeiinNKTKCbCV1d0wXXS1Wq1sO62FIAjQarVOXQfllI4UOQG2srobZotuYWGhaNcyOkIQBOj1ehQWFjqlfcopLWfnBNjK6o6YLbo5OTmy77QWgiAgJyfHKW1TTuk5MyfAVlZ3xGTR5XneenukUuj1etGPj1FO+TgjJ8BWVnfFZNHNy8uze9kbN27g7bffRqtWrbBq1SoRe+VYvxxtz5m5yqKczm9TypyAc7K6KyaLriPHw1JSUtCiRQscPXoUAwYMQFZWFhYuXOhwnyzHx8RkS87Hc9WtWxdDhw5FTEwMunTpUu77V65ciYSEBLRs2RLdu3fHzZs3K7QeJeU0mUxISEhA69at0aFDB3z11VfWEdu9e/cwfvx4dOjQATExMRg0aBBOnz5d4T45IydQ8axl99Pz589j8ODBaNmyJeLj47Fy5conlklLS0OjRo0wb948m/rkrKzuismiazAY7F5Wq9WiTp06+Pvvv7F48WKYTCYAwF9//YXFixfL1i9H27PkAgA/Pz/06tULY8eOLfe9GzZswMaNG7FgwQIcPXoUCxYsQFBQkFP6JXZ7j+ds37491q5diyNHjmDTpk24fPmydVRYWFiIhg0bYs2aNTh48CC6d++Ojz76yKaTRmLntKXNx3Pm5eVh+PDh6NOnDw4ePIjffvsNsbGxpd5vNBrx1VdfoXHjxk7tF2Gw6BqNRrtHue+99x7S0tIwa9YsvP/++/D398f06dOxY8cOHDx4EAMHDnSob4IgwGg0Ot
SGhS05H8/VsmVLqNVqdOvWDREREU+812w2Y9GiRRg/fjxefvllcByHGjVq2DQ7hlJymkwm6wNcLM+TzcjIAADUqFEDgwcPRrVq1eDp6Yk+ffrAaDTixo0bFe6bmDmBimctm3P27NmIjY1F165d4ePjg0qVKqF27dqlllm+fDliY2Px0ksv2dU3sbO6M+aKrsFgsPue8eTkZERHR2PixIk4duwYwsPDrT/z8HD8V8lxnF0jhvj4eLRr1w5paWnW12zJWTbXs/7wsrOzkZ2djatXr6Jjx45ISEjAggULYDabK9xfe3POnDkTtWvXxrp166zrczTnv/71L7Ru3RpxcXG4fPky+vTpU+6yFy9ehNFoRM2aNSvcX3tzHjp0CFWrVsWXX36Jhw8fWl+vaNayOe/du4fAwEAMHDgQ8fHxGDlyJG7fvm19v1arxebNm/Hhhx/a3FcLe7OyiLmia0txeJa///4bly9fxuTJk5GYmIiYmBiHT1gIgoCioiIUFhba9N/t27dx4MABxMfHIy4uDgcPHhQtZ1nZ2dkASgrDxo0bkZycjB07dmDjxo1Oz5mdnY0bN25gyJAhqF27NlauXOnwWfM333wTR44cwfbt29GnTx9UrVr1ifc8ePAAEyZMwPDhw6FWq21q396cRUVF+Oc//4nQ0FBMnz4dBQUFdm/T7OxsbN26FZ999hl2796N8PBwjB8/3vrz2bNnY+TIkQ4/rNxZ+5y7Ye6pxGJd39ikSRM0adIEWVlZAIAWLVqgRYsWDrX54MEDfP7559i1a5dNy1lGGEVFRTh48CDi4uJw9uxZpzwFytfXFwAwZMgQBAQEICAgAH369MGff/6Jt99+u0JtFBUV4YMPPrA556NHjwAADx8+xMOHDzFo0CAsWbLkieOT9njxxRdRp04dzJgxA99//731dYPBgJEjR6JJkyYYNmyYTW0WFxfbldNkMlmzAsCUKVOwZcsW7N2716Z2LHx9fdGhQwdERUUBAIYPH464uDjo9XocP34chYWFSEhIsKvtxynl2mGlY67oil2IwsPDMWLECFHaUqvVWLZsmc2zB0dGRuLq1avw8fFB165dMWvWLLzwwgvWDwQxvfTSS/D29nbo9+jv729XznHjxmHu3LmoVKkSIiMjMWfOHDRt2lS0W1F5nrce0wVKivzo0aNRvXp1TJkyxeb2fH197cq5Y8cO9O7dG56enqhcuTJmzpyJpKQku+/8qlevXqnt9fj/Hz16FOfOnUO7du0AlHzwe3h44MqVK5g/f75N66FHPVYMc4cXxDj26kz29C8yMhK9e/fG6dOnsW7dOtStW9ehnGazGcXFxeB5HoIgoLi42HqSxM/PDwkJCUhJScHDhw+h0+mwfv16xMfH27QOe/pXo0YNNGvWDNu2bcNff/2F9u3bw9PT0+Z2LDZs2IB79+4BAK5du4bk5GS0atUKQMlJq7Fjx8LX1xczZ860+/dpz3LVqlVDWFgYfvjhB2RkZOC9996Dt7e33X3o2bMn9u3bZz0u/dNPPyE6OhpqtRojR47E9u3bsX79eqxfvx7t2rXDW2+9hRkzZti8HqX/bSkFcyNdlUql2K9BgiBApVLZvNzWrVufeM2RnMePH8fQoUOt/27evDmaN2+OlJQUAMDEiRMxbdo0dOjQAWq1Gm+99RZ69epV4fbtzTlmzBiMGTOm1GuO5Dx58iTmzZuHoqIiBAUFoXPnzhg5ciQA4NSpUzhw4ABUKlWpwxeLFi1Cs2bNKtS+vTmbN2+Oa9euPfG6vVlbtWqF0aNH46OPPkJRURGio6Px1VdfAQAqVaqESpUqWd/r6+sLPz8/m0fn9mZlEZMzR5w/f16RB/09PDzQoEED0dqjnPISOyfAVlZ3xeT3AaV+IovdL8opL2f0i6Ws7orJoqtWqxV30N8y1bWYKKd8nJETYCuru2Ky6Npyy6qUxO4X5ZSXM/rFUlZ3xWTR9fLyUtwns1qthpeXuOc1Kad8nJETYCuru2
Ky6AJASEiIYr6mcRyHkJAQp7RNOaXnzJwAW1ndEbNF19/fXxHHxyzHwxy9BfNpKKe0nJ0TYCurO2K26AJAWFiYInbcxx+c4wyUUzpS5ATYyupumC66Xl5eiIiIkG3n5TgOERERDt1VVRGUUxpS5QTYyupumC66ABAQEIDg4GDJd16O4xAcHGx9nquzUU7nkjonwFZWd8J80QUAjUaDwMBAyXZejuNQpUoVaDQaSdZnQTmdQ66cAFtZ3QWTtwGXRxAE6HQ65ObmOvXZDJZRgkajkeWrIeUUl9w5AbayugMqumUUFBQgMzMTgiCIugNzHGc9DqaEr2WU0zFKywmwldWVUdEtB8/z0Gq1Ds0a/DjLpTVhYWGKuoicctpHqTkBtrK6Kiq6z1BYWIicnBzr9NK2/KosX7/UajVCQkIUfS0j5Xw+V8oJsJXV1VDRfYadO3firbfewr1796DX66HX62EwGKwzx5ZleV2lUkGtViMoKMilRgc8zyMvL49y/i9XzwmwldVVUNF9ik2bNqFfv34wGo3IyMgodRG40WiEwWCA2Wy27qweHh5QqVTw9vaWsdfiopzulRNgK6tS0UdZORYtWoRx48bh0aNH8Pf3h06nK1V0vb29mdhBKaf7YSmrUlHRLePChQsYMWKE9SuYl5cXdDqdzL0ihLgLujmijPr16+PYsWN45ZVX4OXlhcLCQiq6hBDR0Ei3HI0bN8bdu3dx8uRJnDp1qtTEhIQQ4ggquuXYsmULmjRpgqioKERFRcndHUKIG6HDC+VYunQp3nvvPbm7QQhxQ3TJWBk3b95E8+bNkZmZSTOcEkJERyPdMlJSUtC/f38quIQQp6CR7mNMJhNq1aqF7du3o3HjxnJ3hxDihmik+5g9e/ZAo9FQwSWEOA0V3cfQCTRCiLPR4YX/defOHURGRiI9PZ2eGUoIcRoa6f6vFStWoEePHlRwCSFORUUXJY+1S05OxrBhw+TuCiHEzVHRBXD48GGYzWa0adNG7q4QQtwcFV383wk0mmyPEOJszJ9IKygowIsvvoiLFy+ievXqcneHEOLmmB/prlmzBu3bt6eCSwiRBPNFl06gEUKkxHTRPXv2LDIzM9GlSxe5u0IIYQTTRTc5ORlDhgyBp6en3F0hhDCC2RNpxcXFiIiIwNGjR1G7dm25u0MIYQSzI93NmzejSZMmVHAJIZJitugmJyfTw20IIZJj8vACzQ5BCJELkyPdlJQUDBgwgAouIURyzI10aXYIQoicmBvp0uwQhBA5MVd0aXYIQoicmDq8QLNDEELkxtRIl2aHIITIjZmiKwgCli5dSg+3IYTIipmie/jwYQiCQLNDEEJkxUzRpdkhCCFKwMSJNHeYHcJoNMJgMMBsNkMQBHAcBw8PD6hUKnh7e8vdPdGwkhNgJysrOSvKS+4OSMEVZ4fgeR55eXnQ6/UwGAzWnbUsy+sqlQpqtRpBQUHw8nKdzcpKToCdrKzktBcTI93WrVtj8uTJePPNN+XuynMVFhYiJycHer0eQMmOWVGWHVutViMkJAT+/v5O6aMYWMkJsJOVlZyOcvuie+bMGSQmJiI9PV3RDyvneR5arRZ6vd6mnfVpOI6DWq1GWFiYokYPrOQE2MnKSk6xuH3RHTNmDNRqNaZPny53V56qoKAAmZmZEARBlJ3WguM4cByHiIgIRVybzEpOgJ2srOQUk1sXXcvsEMeOHUOtWrXk7s4TBEGATqdDbm6uqDtsWRzHITg4GBqNRparN1jJCbCTlZWczuDWl4xZZodQasHNyspy+k5rWVdubi6ysrKcvq7y1s1CTsv6WcjKSk5nceuiq+Q70HQ6HfLz8yXbkQRBQH5+PnQ6nSTrs2AlJ8BOVlZyOovbFt0bN27g5MmT6Nmzp9xdeUJBQYEko4SyLKOGgoICSdbHSk6Anays5HQmty26Sp0dgud564kHOQiCgMzMTPA879T1sJITYCcrKzmdzS2LrslkQkpKiiKfm6vVamU/Ni
UIArRarVPXwUpOgJ2srOR0Nrcsurt370ZoaKjiZocoLCwU7VpGRwiCAL1ej8LCQqe0z0pOgJ2srOSUglsWXaWeQMvJyZF9p7UQBAE5OTlOaZuVnAA7WVnJKQW3K7p37tzBvn370K9fP7m7UgrP89bbI5VCr9eLfnyMlZwAO1lZySkVtyu6qamp6NWrl+LuYsnLy7N72Z49eyItLU3E3vwfR/rlaHvOzFWW2DltaVPKnABtU6VzqzvSBEFAgwYNsGTJErz22mtyd6eU69evO3wcShAEzJ8/H1u2bEFhYSFeeeUVTJo0CXXq1LG7TX9/f9SuXduhfj3OnpxXrlzBnDlzcP78edy/fx9nzpx54j07duzAokWLoNPpULVqVcyYMQPNmjWr8DrEzgnYnnXHjh1YuHAhcnJy4OPjg9deew0TJkxA5cqV8ejRI8yYMQNHjhxBfn4+atSogdGjRyMuLs7mfilhmwJARkYGvvzyS/z111/w8fFBr169MHbs2FLvSU9PR+/evdGpUyd8+eWXNrXvjG0qBbca6R46dEixs0MYDAaH29i1axc2b96MZcuW4eDBg2jSpAkmTpwoe78cbc/LywtdunTBtGnTyv35oUOH8N1332H69Ok4cuQIli1bhoiICKf3S+w2mzZtitTUVBw+fBg7duwAz/OYP38+gJKv8BqNBikpKTh8+DA+/vhj/L//9/+QlZXl9H45oz2j0Yh//OMfaNmyJX7//Xfs3bu33Kf8zZw5E1FRUZL1SwncquhaTqAp7R5to9Ho0EmILl264PDhw8jKykLTpk1Ro0YNeHp6omvXrrh27ZpDfRMEAUaj0eblHjx48MRytua05KpVqxZ69+791BH7woUL8eGHH6JJkybw8PBA9erVbX42sr05jUYjHjx4UO7rFc1qyanRaBAUFGR93dPTE7du3QJQMmobMWIEwsPD4eHhgfj4eISHh+P8+fM299nerPfv33/iNXu36ebNm/HCCy9g8ODB8Pf3h6+vLyIjI0u9d8eOHVCr1WjVqpXNfQXszyk3tym6BQUF2LRpE5KSkuTuyhMMBoMoHwSJiYnIyMjAzZs3YTQasXXrVodH9RzH2TViePfddxEeHo7Fixfj0aNHAMTL+TVAjNcAABtmSURBVDiTyYRz584hNzcXb7zxBl5//XXMnDnT5j7bm3PlypWoVq0axo8fj3v37llftzfriRMnEBMTg1atWmHv3r0YNGhQue/LyclBeno6Xn75ZZvXYU/W69evIzg4GL169cLFixetr9ub8/Tp0wgLC8OHH36IuLg4DBkyBJcvX7b+/MGDB1iwYAE++eQTm9u2sHebys1tHlb566+/4vXXX8cLL7wgd1eeYDabRWmnWrVqiI6ORrdu3eDp6QmNRoOlS5c63LeMjAyb/7Cys7Nx9+5djBkzBhMmTMDo0aMxcuRIh/pSnnv37oHneezZswfLly+Hl5cXRo0ahcWLF2PUqFEVbkcQBLty3rx5EyaTCfPmzcP8+fPRv39/TJ8+HZUqVbI1CgAgOjoahw8fRnZ2NjZs2ICwsLAn3mM0GvHZZ5+he/fudh2ztGebXrt2DSqVClu3bsWOHTvw2muvYebMmXjllVdsXj9Qsn+kpaVh3rx5aN26NVauXIlRo0Zh27Zt8Pb2xo8//ohevXpBo9HY1b6FWH9bUnKbopucnIypU6fK3Y1yiXWuctGiRTh79iz27NmDkJAQbN++HcOGDcOmTZvg5+dnV5tFRUVYvHgxDh48aNNyGRkZ1uWLioowdepUxMTEIDw83K5+PI2vry8AoH///qhWrRoAICkpyeaiW1xcjEWLFtmcMy8vDzzPW7fhL7/8Ap7n8cMPP9jUTlnVq1dHmzZtMH78eKxdu9b6utlsxsSJE+Ht7W338frCwkIsWbLEpqzFxcUoLi6G2WxGcXEx9u3bhxs3buD48eN29cHX1xdNmza1ngh89913sXjxYly/fh2CIODIkSNYt26dXW0/zh
WvA3CLonv69GlotVp06dJF7q6US6yv3JcuXUJCQoJ1dNCzZ098/fXXuH79Oho2bGhXm5UqVcKcOXMQGBho03KJiYnYs2cPvL29kZSUhKlTp6JSpUp2nfh5lsDAQFSvXr3U79Ce36dKpbIr56JFizBq1Ch4e3sjJiYG33zzDaKjo5Gfn29zH8oymUzWDy+gpIBMmTIF9+7dw8KFC+2etLFy5co2Z71w4QJeffVVeHt7o3r16vjmm2/Qs2dPu6/PrVevHk6dOlXuz9LS0qDVatGpUycAJR8SZrMZffv2LfUBVBFKO39TEW5RdJOTkzFkyBDFTsfj4SHOofOoqCjs3r0bCQkJCA4Oxr/+9S/wPI8aNWpI3r/Y2FjUrFkTU6dOtX5FtvcPVBAEPHr0yHpSpLi4GBzHwcfHB0DJh8vq1avRpk0beHl5YcWKFWjbtq3N67EnZ926ddG5c2dMnz4d0dHRDrW1fft2NGvWDKGhodBqtZg3b16pk0jTp0/HjRs3sGTJEocf1GRr/6pVq4ZWrVphzJgx6Nmzp3V5e/fdrl27Wq/UaNmyJVatWoUqVaqgdu3aqFmzJhITE63vXbZsGbRaLT7//HOb1yPW35aUXL7oGgwGrFq1StKLz22lUqlE+Ro0dOhQ3Lt3D3369EFRURFq1qyJuXPnOnQjiCAIdv2BT548+YnX7M2p1WqRkJBg/Xfz5s0RFhaGXbt2AQA++OAD3L9/H926dYOPjw+6dOmCf/zjHzatw96cHTt2RMeOHZ943Z6s169fx3fffQe9Xg+1Wo24uDiMGTMGQMnvYN26dfDx8UG7du2sy0yZMgVdu3a1aT32ZA0JCcG///3vJ163d5vWqlULs2fPxvTp05Gbm4v69etj/vz58Pb2hre3d6nDYf7+/vDx8UFwcLBN67B3m8rN5W+O+PXXX5GcnIw9e/bI3ZVnOn/+vCIP+nt4eKBBgwaitcdKToCdrKzklIrrjc3LSE5OVuQjHMtS6iey2P1iJaez2hQDbVNlc+mie+PGDZw6dUqRs0OUpVarFXfQ3zLVtZhYyQmwk5WVnFJx6aL7yy+/oH///i7xiff4nUhKIna/WMnprDbFQNtU2Vy26Cp5dojyeHl5Ke6TWa1Ww8tL3HOprOQE2MnKSk6puGzR3bVrF8LDwxU3OwRQckVFu3bt0KhRIzRo0AAvvfQSAgMDcebMGcV8TeM4DiEhIU5pOyQkhImcADtZWckpBdf8qEDJw22UOsr19fWFTqfDpUuXrK8FBgYiJiYGDx48kH3aE8vxMH9/f6e07+/vD7Va7fY5AXayspJTCi450s3Ozsb+/fsVNzuExY0bN0o9etDPzw9//PEHgoODERYWJvuIgeM40W/XLYuVnAA7WVnJ6WwuWXSVOjtERkYGPvjgA7Ro0QJt2rRBw4YN4eXlhdmzZ+PVV18FUHJ8LCIiQradl+M4REREOP3uPVZyAuxkZSWns7lc0RUEAcnJyYqaePL27dsYNWoUmjRpgqCgIFy+fBnTpk3DkiVLMGTIkCcezBIQEIDg4GDJd16O4xAcHCzZhxUrOQF2srKS05lc7pjuf/7zH3Ach9jYWLm7gpycHHz99ddYunQpBg8ejAsXLpR6uHZMTAxiYmLKXVaj0cBkMiE/P1+SY2Qcx6FKlSoOP0rPVqzkBNjJykpOZ3G5ka7lBJqcx5bu37+PyZMnIzIyEg8ePMDp06fx3Xff2TSbgeXYlBSjBssoQY5jcqzktKyfhays5HQWl3r2Qn5+Pl588UVcvnxZloeV6/V6/PDDD/j+++/RvXt3TJ48GbVq1XK43YKCAmRmZkIQBFFHDhzHWY+DKeFrGSs5AXayspJTTC5VdH/++Wfs3r0bGzZskHS9hYWFWLBgAebMmYOOHTti6tSpqFevnqjr4HkeWq1WtEtyLJfWhIWFKeoiclZyAuxkZSWnWBRRdI1GIwwGA8xmMwRBAMdx8PDwgEqlKvUg55YtW2LatGmlnsXpTA
aDAUuWLMHs2bMRGxuLadOm2f2w8IoqLCxETk6O9dm0tmwey9cvtVqNkJAQRV/LyEpOgJ2srOR0lCxFl+d55OXlQa/Xw2AwWAttWZbXVSoVeJ5Hr169cPr0aadfMmI0GpGSkoIZM2agcePGmD59Opo2berUdZZlz+9IrVYjKCjIpUYHrOQE2MnKSk57SVp0Hf0kFAQBAQEBTvsk5Hkeq1atwrRp01CnTh3885//ROvWrUVfjz0q+m3A1bGSE2AnKys5K0qSoqv0Yz5msxlr167FF198gRdeeAHTp09HfHy8w+0SQkhZTi+6Sj67KQgCtmzZgilTpsDPzw8zZsxAx44d3ebSFEKI8jit6AqCAJ1Oh9zcXKdeQG25jk+j0VS4WAqCgJ07d2Ly5MkwmUyYPn063nzzTSq2hBCnc0rRFQQBWVlZkt6xEhgYiPDw8OcWzv379+Pzzz9Hfn4+pk2bht69e7vkjKKEENfklFOFOp1OsoILlBT5/Px8eHp6IjQ0tNz3/Oc//8HkyZORkZGBL774Av369XP5B2cQQlyP6EO8goICpx9SKI8gCMjNzUVBQUGp1//66y8kJiZiwIABGDhwIC5cuIABAwZQwSWEyELUosvzvPWkmRwEQUBmZiZ4nsfp06fRq1cv9OzZE926dcOlS5cwdOhQJq4DJIQol6hFV6vVyvpUeeD/rkjo3Lkz2rZtiytXrmDEiBHw9fWVtV+EEAKIeCKtsLAQN27ckL3oAiXX3YaHh6Nq1apyd4UQQkoRbaSbk5OjiIILAB4eHnj48KHc3SCEkCeIUnR5nrfe2muPnj17Ii0tTYyuWOn1evA8L2qbhBDiKFGKbl5enkPLb968GS1atLD+e9KkSY52CYDj/SKEELGJUnTFeKbC7du38dVXX6GoqAgAcPnyZcyZM8fu9gRBcGj0TQghziBK0TUYDA4t36VLF9y8eROdOnXCxIkTkZaWhrVr12Lo0KGy9osQQsTmcNE1Go1OOYHm4eHh8LMQBEGA0WgUqUeEEOI4h4uuwWAQ5UExd+/exZ49ezBr1iy0aNECb731Fn755ReH2uQ4jka7hBBFcfj2LLPZLEY/UK1aNXTv3t3678jISERGRjrcrlj9I4QQMTg80nXGoYWZM2eK1pZSrh0mhBBAhKKr9GfQKr1/hBC2OFx0lf4sWqX3jxDCFoeP6apUKoe/wu/atcvRbpRLEASoVCqntE0IIfZweBjo7e2t2K/wHMcxOdsoIUS5RPnurdTRpFL7RQhhlyhFV61WK260azAYsH79emzYsAHFxcVyd4cQQgCIVHSDgoLEaEZUKpUKoaGh+PHHHxEWFobhw4fj8OHDdAkZIURWoj3E/NatW0/MTyangIAA1KxZEwCQnp6OlStXIjU1FWazGUlJSRg0aBBeeukleTtJCGGOW84cwXEcatWqBX9//1KvC4KAY8eOITU1FWvWrEHDhg2RlJSEt99+G4GBgTL1lhDCEtGKLlAy2hXjMY+O4DgOarXaOsp9muLiYvz2229ITU3F/v378cYbbyApKQmdOnWiySsJIU4jatHleR6XL1+W9XkHHh4eiIyMtGmK9Xv37mHNmjVITU1Feno6BgwYgKSkJDRu3NiJPSWEsEjUogsABQUFyMjIkGW0y3EcatSogYCAALvbuHTpElasWIEVK1YgKCgISUlJ6N+/PzQajYg9JYSwSvSiC5TMApGbmytp4eU4DsHBwQgNDRWlPbPZjAMHDiA1NRWbN29GTEwMkpKS0KNHD/j5+YmyDkIIe5xSdAVBQFZWFvLz8yUpvBzHoUqVKggLC3PK9cIPHz7E5s2bkZqairS0NPTu3RtJSUl47bXX6NkOhBCbOKXoAiWFV6fTOX3EaxnhajQaSW7QyMrKwurVq5GamooHDx5g0KBBGDRoEOrWrev0dRNCXJ/Tiq5FQUEBMjMzIQiCqMWX4zhwHIeIiAiHjuHaSxAE/P3330hNTcXq1atRu3ZtJCUloW/fvg
gODpa8P4QQ1+D0oguUXNWg1WpFu5zMcllYWFiYIi7vMhqN2LNnD1JTU7Fz50507NgRSUlJSExMdOkH7hiNRhgMBpjNZgiCAI7j4OHhAZVK5dK5ymIlJ8BOViXnlKToWhQWFiInJ8c6Nbotq7YcOlCr1QgJCXnixgeluH//PtavX4/ly5fj0qVL6NevH5KSktCsWTPFPZ+iLJ7nkZeXB71eD4PBYN1Zy7K8rlKpoFarERQUpIgPv4piJSfATlZXyilp0bVwpV+QI65du2a9/djX1xdJSUkYMGAAatSoIXfXSmHhwxBgJyfATlZXzClL0S1LyV8FxCAIAg4dOoTU1FSsX78eTZs2RVJSEnr37o3KlSvL1i93P+xjwUpOgJ2srpxTEUWXJQaDAdu2bUNqair+/PNPdO/eHUlJSWjfvr1Nd9E5yl1PcJbFSk6AnayunpOKrozu3LmD//mf/0Fqaiqys7MxcOBAJCUloUGDBk5bp7teylcWKzkBdrK6S04qugpx9uxZrFixAitXrkRoaCiSkpLwzjvvoFq1aqKtQ46bVgIDAxEeHi7pHykrOQF2srpTTiq6CmMymbB//36kpqZi27ZtaNu2LZKSktCtWzf4+vo61LY73J5dEazkBNjJ6k45qegqmF6vx8aNG5GamopTp06hb9++SEpKQuvWrW3+9HX1BxFVFCs5AXayultOKrou4tatW1i1ahWWL18Ok8mEpKQkDBw4ELVq1bK+Z/v27QgODkZsbGypZZXyyM169eo59cwwKzkBdrK6Y056WouLqFmzJiZMmIALFy5g9erVuHPnDlq2bIn4+HgkJyfj/v37eP/999GxY0ccO3as1LJarVb2GT0EQYBWq3XqOljJCbCT1R1z0kjXhT169Ag7duxAamoqdu3aheLiYvA8D7VajUOHDiEqKsolplESAys5AXayumtOGum6MB8fH/To0QMbNmxAt27dYDKZAJQcC27evDnS0tKQk5OjiJ0WKBkx5OTkOKVtVnIC7GR115zKucWEOGT//v0AAH9/f1SpUgWenp64ePGi4m7h1Ov14Hle1OOAPM9bbwNVCmfkBNjJ6s45qei6ibNnz8LPz6/UbcV3797FnTt3ZOxV+fLy8kS9/jgvL0+0tsQkdk5Lm0pE27Ti6PCCm6hWrdoTz3FwxszMO3fuRPfu3dGqVSv06NED+/bts2l5QRBEH8E4mtNoNGLs2LHo0qULGjVqhLS0tCfec/78eQwePNh68nLlypXPbNMZOQHHsl67dg3/9V//hdjYWMTGxmLYsGG4du2a9ecpKSno1asXWrVqhYSEBKSkpFSoXSVuUwAoKirCjBkzEBcXh5iYGAwePPiJ9xiNRnTv3h2vv/76c9sTKyeNdN2YwWAQtb3s7GxMmDAB8+bNw2uvvYY///wT48aNw86dO1G1alXZ+iVGe02bNsXAgQMxbty4J36Wl5eH4cOH45NPPkHnzp1hNBqRnZ0tSb/EbLNatWqYO3cuwsLCYDab8euvv+KTTz7Bxo0bAZQUlZkzZ6JevXrIyMjABx98AI1Gg8TERKf2y1ntTZs2DSaTCVu2bEFgYCAuXrz4xHtSUlIQFBSEhw8fStYvGum6KaPRaNdIoVGjRrh165b135MmTcK8efMAlBTdgIAAxMXFgeM4tG3bFn5+fsjIyLBpHYIgwGg02ty3nJycJy7dqWjOZ+Xy9vbGoEGDEB0dXe6cd6mpqYiNjUXXrl3h4+ODSpUqoXbt2s9dp705Hz16VG6BqEjWZ+UMCAiw3tYqCAI8PDxKbbuhQ4eiQYMG8PLyQq1atdC+fXucPHmyQn22N+vZs2efuAZXjG16/fp1/PHHH5g6dSqCg4Ph6emJhg0bllo+MzMT27dvx7BhwyrcX3tzPo6KrpsyGAyi3zPesGFD1KpVC7///jtMJhP27dsHb29v1KtXz6Z2OI6za8Tw9ddfo2bNmhg2bBgyMzMBOCdnWadPn0ZgYCAGDhyI+Ph4jB
w5Erdv337ucvbm/OOPP1C/fn20a9eu1KEOsbLGxsaiefPmmD179lMLjiAIOHHiBOrUqVOhNu3JWlxcjEaNGqF27dpYu3attfiKkfPs2bMIDQ3FggULEBcXh169emHPnj2l3jN79myMHj0aKpWqwu3au00fR4cX3JQz7uDx9PRE9+7d8emnn+LRo0fw9vbGt99+a/MVEiaTCcePH0dRUZFNy129ehUmkwnLly/HihUr0K5dO/z44482tWGP7OxsXLhwAYsXL0bdunUxd+5cjB8/HitWrHjmcmaz2a6caWlp8PPzw4EDBxAXF4d69erh66+/RkxMjCMxrA4dOoTCwkJs3boVYWFh5b5n4cKFMJvN6NmzZ4Xa5Hne5qyPHj2Ch4cH0tPTMXjwYHz88cf49NNP8d5771W4jafJzs7G1atX0alTJ+zfvx+nTp3CRx99hJdffhm1a9fGvn37YDKZ8Prrr5d7DP9ZHP3boqLrppxxfePhw4cxd+5cpKSkoH79+jh//jw+/vhjLFq0CK+88kqF2zEajdi9e3eFv7paWL5y8zwPjuOwd+9enDp16omvjWLz9fVFhw4dEBUVBQAYPnw44uLioNfroVarn7ocz/N25bx79671K2xxcTHOnDmDn3/+Ga1bt7Y/RBn+/v7o27cv2rZtiy1btpQ6Jr969Wps27YNy5Ytg4+PT4Xae/Tokc1ZTSaTdT81GAwoLi7G999/j6FDh9oWphy+vr7w8vLCP/7xD3h5eaFFixZo2bIlDh06BI1Gg7lz52LhwoV2te3o3xYVXTdl79czPz+/UqOVe/fuoXr16gCAS5cuoVmzZtYiFxUVhUaNGuHIkSM2FV2VSoVPP/0UgYGBNvVt3Lhx+OGHH+Dr64vRo0fjk08+gYeHB7KyshzK9Tz16tUr9fus6O/Wx8fHrpw7duxAjx49oFKp0LVrV8yaNQt169ZFfn7+c5e1JafZbIbBYMCdO3esRXfTpk1ITk7GsmXLoNFoKtxnf39/m7MaDAb4+/ujUqVKqFevHr799lu0a9cOBQUFFVr+WVnLO+Rl2W63bt2CVqu1Xs1gNBrx4MEDtGvXDqtWrUJ4ePgz1+vooQ86puumyjshVBGRkZH47bffYDKZcPDgQfz111/WnzVs2BAnTpywjjgvXLiAEydO2HxM197+xcfH47PPPkNmZiZmzZqFoKCgCrfzrFxAyUituLgYQMkfYXFxsXVE07NnT+zbtw8XL16E0WjETz/9hOjo6GeOci3syRkZGYkhQ4bg9OnTWLduHerWrVvhtp6V89ChQ7hw4QJMJhMePHiAb775BgEBAdaTgtu3b8cPP/yAJUuW2DWPn61ZfXx8MHToUGzbtg3Hjx9H+/btrVN1VcSzsjZr1gyhoaFYunQpeJ7HyZMncezYMcTGxqJOnTrYs2cP1q9fj/Xr12PatGmoWrUq1q9fX6EPGnv/tizo2Qtuymg04vLlyzZ/FTp37hwmTZqE27dvo0OHDjCZTIiIiMCoUaMAlHz1XLlyJe7du4egoCC888475V7/+Cwcx6FevXqizH9X0ZzPy9WlS5cnrozYuXOnddSzZs0aLF68GEVFRYiOjsbnn3/+3D9QMXMCFcv6rJy7du3Cjz/+iOzsbKhUKkRFRWH06NGIjIwEACQkJCA7O7tUf7t27YopU6Y8t29K3KZXr17F1KlTceXKFYSGhmLUqFHlXo+blpaGzz77rELXnIuRk4quGzt//rysj8R7Gg8PD1GnJGIlJ8BOVnfOSYcX3Jgtl8JISex+sZLTWW2KgbZpxVHRdWNqtVq2yRKfxjLVtZhYyQmwk9Wdc1LRdWNBQUFyd6FcYveLlZzOalMMtE0rjoquG/Py8nLKaMsRarVa9McdspITYCerO+ekouvmQkJCFPM1jeM4hISEOKVtVnIC7GR115xUdN2cv7+/Io6PWY6HOeuh6qzkBNjJ6q45qegyICwsTBE77vPu9HEUKzkBdrK6Y04qugzw8v
JCRESEbDsvx3GIiIiAp6enU9fDSk6AnazumJOKLiMCAgIQHBws+c7LcRyCg4MREBAgyfpYyQmwk9XdclLRZYhGo0FgYKBkOy/HcahSpYpND04RAys5AXayulNOug2YMYIgQKfTITc316nTW1tGCRqNRpavhqzkBNjJ6i45qegyqqCgAJmZmRAEQdQdmOM463EwKb9qPw0rOQF2srp6Tiq6DON5HlqtVrRZgy2X1oSFhTnlxgB7sZITYCerK+ekoktQWFiInJwc6/TStuwSlq9farUaISEhTr0+1VGs5ATYyeqKOanoEiue55GXlwe9Xg+DwQBBEMo9pmV5XaVSQa1WIygoSFGjoOdhJSfATlZXyklFlzyV0WiEwWCA2Wy27qweHh5QqVSiPZhbCVjJCbCTVck5qegSQoiE6DpdQgiREBVdQgiREBVdQgiREBVdQgiREBVdQgiREBVdQgiREBVdQgiREBVdQgiREBVdQgiR0P8HW76b6aUjlL8AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
" 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],\n",
" 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],\n",
" 'f16': ['f32'], 'f32': ['f64'],\n",
" 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],\n",
" 'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],\n",
" 'c64': [2, 3], 'c128': [3, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 5))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WVjEMlmm9RS5"
},
"source": [
"Notice that this is not (yet) a true lattice: there are many pairs of nodes for which a join does not exist. However, we can think of this as a *partial* lattice, in which some pairs of nodes do not have a defined promotion behavior, and the defined portion of this partial lattice does correctly describe NumPy's array promotion behavior (leaving aside value-dependent semantics mentioned above).\n",
"\n",
"This sets up a nice framework by which we can think about filling out these undefined promotion rules, by adding connections on this graph. But which connections to add?\n",
"Broadly speaking, we want any additional connections to satisfy a few properties:\n",
"\n",
"1. Promotion should satisfy the commutative and associative properties: in other words, the graph should remain a (partial) lattice.\n",
"\n",
"2. Promotion should never allow for dropping entire components of data: for example, we should never promote `complex` to `float`, as it would discard any imaginary parts.\n",
"\n",
"3. Promotion should never lead to an unhandled overflow. For example, the maximum possible `uint32` is roughly twice as large as the maximum possible `int32`, so we should not implicitly promote `uint32` to `int32`.\n",
"\n",
"4. Wherever possible, promotion should avoid loss of precision. For example, an `int64` value may carry up to 63 significant bits, while `float64` has only a 53-bit mantissa, so promoting `int64` to `float64` represents a possible loss of precision. However, the maximum representable `float64` is larger than the maximum representable `int64`, so in this case criterion #3 is still satisfied.\n",
"\n",
"5. Wherever possible, binary promotion should avoid resulting in types that are wider than the inputs. This is to ensure that JAX's implicit promotions remain friendly to accelerator-based workflows, in which users often want to restrict types to 32-bit (or in some cases 16-bit) values.\n",
"\n",
"Each new connection on the lattice introduces some level of convenience to the user (a new set of types that can interact without explicit casting), but the convenience may become too costly if any of the above criteria are violated. Developing a full promotion lattice involves striking a balance between this convenience and this cost."
]
},
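{
"cell_type": "markdown",
"metadata": {},
"source": [
"Criteria 3 and 4 can be checked numerically. The cell below is a minimal sketch rather than part of the original design discussion: it relies only on `np.iinfo`, `np.finfo`, and scalar casts, and the variable names are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Criterion 3: uint32 can hold values that do not fit in int32.\n",
"uint32_max = np.iinfo(np.uint32).max\n",
"int32_max = np.iinfo(np.int32).max\n",
"overflow_possible = uint32_max > int32_max\n",
"print('uint32 max:', uint32_max, ' int32 max:', int32_max)\n",
"\n",
"# Criterion 4: float64 has a 53-bit mantissa, so 2**53 + 1 is rounded\n",
"# when an int64 holding it is cast to float64.\n",
"big = 2**53 + 1\n",
"roundtrip_exact = int(np.float64(np.int64(big))) == big\n",
"print('int64 -> float64 round trip exact:', roundtrip_exact)\n",
"\n",
"# The float64 range still covers every int64 value, so criterion 3 holds.\n",
"range_ok = float(np.finfo(np.float64).max) > np.iinfo(np.int64).max\n",
"print('float64 range covers int64:', range_ok)"
]
},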
{
"cell_type": "markdown",
"metadata": {
"id": "GSqwTTS8nYdn"
},
"source": [
"## Mixed Promotion: Float and Complex\n",
"\n",
"Let's begin with what is perhaps the easiest case, that of promotion between float and complex values.\n",
"\n",
"Complex numbers are made up of pairs of floating point numbers, and so we have a natural path of promotion between them: cast float to complex while maintaining the width of the real part. In terms of our partial lattice representation, it would look like this:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"cellView": "form",
"id": "5DJ59qZSoY6J",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAEeCAYAAAApRMZ1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2dd1gUV9vG76Eu4NIEXYoabERFjVjBIGpsJPaor7FgLCkao776xUSNGl8sKcYkGjVREUU0sdfEbmLia8MWe2/AuiiCsAgLO7vz/cG3+0lR2d1pu+f8rivXFYedM8+95+w9z5w5heE4jgOFQqFQRMFJ6gAoFAqFJKjpUigUiohQ06VQKBQRoaZLoVAoIkJNl0KhUESEmi6FQqGICDVdCoVCERFquhQKhSIi1HQpFApFRKjpUigUioi4SB0ApXLo9XrodDoYjUZwHAeGYeDk5ASFQgFXV1epw6NYASl1SorOykJNV6awLIucnBxotVrodDpzYy2L6bhCoYBSqYSfnx9cXGi1yhFS6pQUndbC0AVv5EVBQQGysrKg1WoBlDTMymJq2EqlEgEBAfD09BQkRoplkFKnpOi0FWq6MoFlWajVami1Wosa6/NgGAZKpRLBwcFEZA9yhJQ6JUUnX1DTlQF5eXlIT08Hx3G8NFoTDMOAYRiEhobC29ubt3IpL4eUOiVFJ59Q05UQjuOg0WiQnZ3Na4MtC8Mw8Pf3h0qlqrBvjcIfpNQpKTqFgJquRHAch4yMDOTm5graaE0wDAMfHx+EhIQ4TOOVG6TUKSk6hYKO05UIjUYjWqMFSn4oubm50Gg0olyPREipU1J0CgU1XQnIy8sT/LGsIjiOQ3Z2NvLy8kS9LgmQUqek6BQSaroiw7Ks+cWDFHAch/T0dLAsK8n1HRFS6pQUnUJDTVdk1Gq1ZI3WBMdxUKvVksbgSJBSp6ToFBpquiJSUFDA21hGW+A4DlqtFgUFBZLG4QiQUqek6BQDaroikpWVJXmjNcFxHLKysqQOw+4hpU5J0SkG1HRFgmVZ8/RIuaDVau2+f0xKSKlTUnSKBTVdkcjJyRGk3GnTptl0vlBxkYAt3112djYmT56MqKgoREdH49NPPy33mdzcXLRr1w7x8fGixWVreY8ePcLHH3+Mjh07onHjxsjIyCj19/nz5+Ott95C69at0aNHD+zYsaPU30+cOIEBAwagTZs26NatGzZu3MhLXHKCmq5I8NkfVlhYiFmzZuHJkycASl5wzJo1y+LyTf1jFOuwpU7//e9/IyAgAPv27cPhw4fx7rvvlvvMd999h7CwMIvKFaJOLdHJMAzatm2LBQsWVPh3Dw8PLFq0CMeOHcOcOXPw5Zdf4ty5cwBKloCcMGEC+vXrh2PHjmH+/Pn45ptvcO3atXLl2HPbdbzVJGSKTqez+lyNRoMvv/wSZ86cgdFoRFxcHAYNGoTZs2fj0qVLWLx4MUaPHm3VbB1b4iKdynx3FdVdhw4doNFosHLlSjg7OwMAGjRoUOq8c+fO4caNG+jXrx+2bt3Ke1x8lFeRtmnTpmHgwIHPffT/6KOPzP/fpEkTNG/eHP/88w9ee+015ObmIj8/Hz169ADDMIiIiEDt2rVx69YthIeHVzouuUMzXRHQ6/VWZ0QGgwEfffQRgoKCsGfPHhw8eBBxcXHmvz+7KLQ1cBwHvV5v1bkksHPnTrRu3RoHDhwoVYeVqdPn1d0///yDV155BdOmTcPrr7+OgQMHIjU1tdR5c+fOxdSpU626kVpTp5mZmWjYsCF+/vlnFBcXm48/T+fL2mVl0Ol0uHjxIurUqQMACAgIQFxcHLZt2waDwYBz587hwYMHiIyMrPB8e2271HRFQKfTWT1n/MKFC3j06BEmTZoET09PuLu7o0GDBli7di0+//xzNG/eHGPGjMGKFSusMnaGYew2YxCDtLQ0nDlzBr1798Zrr71mNt/K1GlFdR
cZGYnMzEwcPXoUrVq1wh9//IFhw4Zh/Pjx5j7KtWvXonHjxmjUqJFVMVtTp9nZ2bh9+zYmTZqE0NBQs/k+T+fztFlCQkICwsPD0bZtW/OxN998Ez/99BOaN2+Od999Fx9//DFUKhVvOuUA7V4QAaPRaPW5Go0GQUFBpdYV9fDwwBdffGH+d3BwMGbMmGFV+fn5+Rg+fDj++usvq2N0ZAoLC8GyLFiWxfnz59G5c2cMHDgQP/3000vPrajuAEChUCAkJAR9+/YFAMTFxWHZsmU4e/YsIiIisG7dOqxfv97qmLVarcV1yrIsiouLUVRUhKdPn+LDDz/E7NmzcfHiRYu0VZZvv/0WN27cwMqVK82mfvv2bUyePBnfffcdoqKicO/ePYwdOxbVqlVDu3btKizHlt+WVFDTFQFbXqCpVCpoNBqwLFthA58zZ44tocHLywsLFiyAQqGwqRxHJSkpCdOnT4eLiwtcXV0xceJETJw4EQaD4aXnPq/u6tevjz///LPUZ03GY8oge/XqBQAoKiqCTqdD+/btcfDgQXMf8IuoUqWKxXV6/fp1dOnSBSzLwtnZGQMGDMCMGTOe23Zf1i5fxOLFi3HkyBEkJSWhSpUq5uM3b95ErVq1zJlvWFgY2rVrh7///vu5piuXscOWQE1XBGxZjq5x48YICAjA999/jzFjxsDZ2RmXL19Gs2bNeIvNx8cHPj4+vJTnaAQFBcHLywvTp0/H6NGj4eHhAaBkONfLeF7dvfHGG5g/fz62b9+O7t274+DBg8jMzESzZs3g5eWFvXv3msvYs2cPfv/9dyxcuLBShgtYV6emjSNHjBiBGTNmIDg4+IU6X9Qui4qKzDclU/bs7u4OAFixYgV+//13rF69Gr6+vqXKbNCgAe7du4cTJ06gVatWSE9Px+HDhzF8+PAXarU36Hq6IqDVapGWlmb1o9CDBw8wb948nDlzBgzD4M0338SUKVN4ic3JyQk1atSAUqnkpTxHg+M4GI3GcoZX2Tp9Xt2dPn0as2fPRkZGBsLCwjB58mQ0b9683Pnbtm3Dli1bkJycXOmYra1Tg8Fgkc7naWvcuHG5z164cAFAiVm7urqWyo7fe+89vPfeewBKbjI///wz1Go1qlSpgrfeegsTJkyo8EWxvbZdaroioNfrcf36dVk+CjEMg/r16xO5FbYtkFKnpOgUEzp6QQRcXV1l+xjEMIzdNVo5QEqdkqJTTKjpioRcX1TJNS57QK7fHd9xkaJTLKjpioRSqZRdxmDa6ppiHaTUKSk6xYKarkj4+flJHUKFyDUue0Cu3x3fcZGiUyyo6YqEi4uL7O7MSqXS6sHtFHLqlBSdYkFNV0QCAgJk85jGMAwCAgKkDsPuIaVOSdEpBtR0RcTT01MW/WOm/jBPT09J43AESKlTUnSKATVdkQkODpZFww0JCZE0BkeClDolRafQUNMVGRcXF4SGhkrWeBmGQWhoaKWnlFJeDil1SopOoaGmKwHe3t7w9/cXvfEyDAN/f394e3uLel0SIKVOSdEpJNR0JUKlUsHHx0e0xsswDHx9fZ+7NinFdkipU1J0CgVde0FCOI6DRqNBdna2oHPbTVmCSqWSvE/O0SGlTknRKQTUdGVAXl4e0tPTwXEcrw2YYRhzP5gjPJbZE6TUKSk6+YSarkxgWRZqtZq3XYNNQ2uCg4PtdhC5vUNKnZKiky9on65McHFxQc2aNVGlShUcPnwYgOULNJuyA29vb4SFhaFmzZoO2WjtBVOdbtmyBefPnzfXjyXYQ52adKanp+PIkSMOq5MvHFOVnXL9+nW0adMGOTk5OHr0KOrWrQutVgudTmfe9bcspuMKhQJKpRJ+fn4O21jtDY7j8Omnn+Kbb75BnTp1cPXqVeTk5Dhkne7YsQP9+vUDy7IoKirCkydPHFInHzi+QjshNTUVnTt3Rm5uLtzd3ZGZmYmoqCgEBgYCKFlM2rSlyrPbrisUCrtcU9TRYVkWw4cPx+bNm8
3/dnFxQWBgoMPV6fLlyzF+/Hjo9Xq4uroiPz/fIXXyBTVdGZCbm4t27dqZt5NmGAYajabUZ1xdXYlsoPbKF198gXXr1pm3uTFtr/4sjlCnhw4dwvvvv2/+t0KhgEajKbUCmCPo5BPapysDfHx8sH37dkRERMDJyQl6vR5qtVrqsCg2MG7cOMycORNOTk7w8PCAVqut1A7C9kbbtm2xcuVK+Pr6wtXVFYWFheUSBkppaKYrE7p06QKlUonVq1fjwYMHvO32S5GGatWqISoqCk2bNsW3336LvXv3Osw402dxd3dHnz598O9//xt//vknfvvtN9SuXVvqsGQNHTImEy5duoTOnTvj/v37RLxMIIF//etfiI2NxZgxY6QORVCWLl2KQ4cOYePGjVKHYhfQ7gWZkJiYiHfffZcaroOQlZWFvXv3YtCgQVKHIjgrVqzAqFGjpA7DbqC/cBlQVFSElJQUHD16VOpQKDyRkpKCHj16wNfXV+pQBOXs2bN49OgROnXqJHUodgPNdGXAjh07EBERgbp160odCoUHOI4jJvtLTEzEiBEj7H65RTGhma4MSExMxMiRI6UOg8ITJ0+eRFFREdq1ayd1KIJSWFiIX375BWfOnJE6FLuCmq7E3Lt3D6mpqdi6davUoVB4YsWKFRg5cqRDjlZ4li1btqBFixaoVauW1KHYFdR0JWbVqlV455134OHhIXUoFB7Iz8/H5s2bcenSJalDEZzExESMHj1a6jDsDmq6EmIwGLBy5Ups375d6lAoPLFhwwbExMQgKChI6lAE5ebNm7h48SJ69uwpdSh2B32RJiEHDhxAYGAgXnvtNalDofAEKS/QVq5ciSFDhsDd3V3qUOwOmulKCH2B5lhcvnwZd+/eRVxcnNShCArLsli1ahX2798vdSh2Cc10JeLRo0fYt28f3nnnHalDofAEKRNcdu/ejVq1aqFRo0ZSh2KXOHbrkDEpKSno1auXww+eJ4Xi4mKsWbOGiAkuiYmJRHShCAXNdCXANHiedi04Djt27ECjRo0cfoLLgwcPcPjwYQwYMEDqUOwWaroScOLECej1esTExEgdCoUnSOmfT05ORt++faFUKqUOxW6h3QsSQMrgeVK4f/8+Tp48iS1btkgdiqCYntCSk5OlDsWuoaYrMlqtFps3b8aVK1ekDoXCE0lJSURMcPnrr7/g5uaGNm3aSB2KXUNNV2Q2bNiA2NhYqFQqqUOh8IBpgsu2bdukDkVwTF0o9AnNNmifrsiQMnieFA4ePIiAgACH3+njyZMn2LFjB4YOHSp1KHYPNV0RuXTpEu7fv49u3bpJHQqFJ0gZPvXLL7+gS5cu5h1+KdZDTVdESBk8Twqm3SFImOBChzjyB/31i4Rpd4hjx45JHQqFJ1JSUtCzZ0+Hn+By9uxZZGVl0d0heIJmuiJh2h2iTp06UodC4QGSJrgkJiZi+PDhdHcInqCZrkiQ0vdHCnR3CIq10ExXBO7du4dTp06hT58+UodC4QlSJrjQ3SH4h2a6IkDK4HlSyM/Px6ZNm3D58mWpQxGcFStWYMyYMVKH4VBQ0xUYg8GApKQk7NixQ+pQKDxhmuBCwu4Qly5dortD8AztXhAY0+4QTZs2lToUCk+Q8gKN7g4hDDTTFRj6As2xuHz5Mu7du0d3h6BYDc10BYTuDuF4JCYmYtiwYQ4/wYXuDiEcjt1yJMa0O4SPj4/UoVB4gO4OQeEDarpWoNfrodPpYDQawXEcGIaBk5MTFAoFXF1dAfz/4PmlS5dKHK31VEanI1BZnY6wO0RltJp2h1izZo3E0Tom1HQrAcuyyMnJgVarhU6nMzfWspiOKxQKZGdnw93d3a52h7BGp1KphJ+fn109blurc926dXaX/VmjNTU1FYMHD6a7QwgEw3EcJ3UQcqWgoABZWVnQarUAShpmZTEYDGAYBr6+vggICICnp6dQYdqMLTpNP2ClUunwOouLi+Hr64vq1avLWidgm1a9Xg8XFx
f4+PjIvk7tEWq6FcCyLNRqNbRarUWN9XkwDAOlUong4GBZZYRUp3XIVSdAllZ7hZpuGfLy8pCeng6O43hptCYYhgHDMAgNDYW3tzdv5VoL1WkbctMJkKXVnqGm+39wHAeNRoPs7GxeG2xZGIaBv78/VCqVJPP2qU5+kVonQJZWR4CaLkoabUZGBnJzcwVttCYYhoGPjw9CQkJEbbxUpzBIpRMgS6ujQCdHANBoNKI1WqDkh5KbmwuNRiPK9UxQncIglU6ALK2OAvGmm5eXJ/hjWUVwHIfs7Gzk5eWJcj2qU1jE1gmQpdWRINp0WZY1v3iQAo7jkJ6eDpZlBb0O1SkOYukEyNLqaBBtumq1WrJGa4LjOKjVakGvQXWKhxg6AbK0OhrEmm5BQQFvYxltgeM4aLVaFBQUCFI+1SkuQusEyNLqiBBrullZWZI3WhMcxyErK0uQsqlO8RFSJ0CWVkeESNNlWdY8PVIuaLVa3vvHqE7pEEInQJZWR4VI083JybH63Dt37qBfv35o3bo11q5dy2NUtsVla3lC6ioL1Sl8mWLqBITR6qgQabq29IclJSWhZcuWOHHiBAYPHoyMjAwsWbLE5phM/WN8YonOZ3XVq1cPI0aMQFRUFLp27Vrh51NSUtCtWze0atUKPXv2xN27dyt1HTnpNBgM6NatG9q0aYOOHTviq6++Mmdsjx8/xuTJk9GxY0dERUVh6NChOH/+fKVjEkInUHmtZdvp5cuXMWzYMLRq1QqxsbFISUkpd05qaioaN26MhQsXWhSTUFodFSJNV6fTWX2uWq1G3bp18c8//2DZsmUwGAwAgFOnTmHZsmWSxWVreSZdAODh4YE+ffpg4sSJFX528+bN2LJlCxYvXowTJ05g8eLF8PPzEyQuvst7VmeHDh2wYcMGHD9+HFu3bsX169fNWWFBQQEaNWqE9evX48iRI+jZsyc++ugji14a8a3TkjKf1ZmTk4PRo0ejf//+OHLkCH7//XdER0eX+rxer8dXX32FJk2aCBoXhUDT1ev1Vme5I0eORGpqKubOnYv33nsPnp6eSEhIwO7du3HkyBEMGTLEptg4joNer7epDBOW6HxWV6tWraBUKtGjRw+EhoaW+6zRaMTSpUsxefJk1KlTBwzDoEaNGhbtjiEXnQaDwbyAi2k92bS0NABAjRo1MGzYMAQGBsLZ2Rn9+/eHXq/HnTt3Kh0bnzqBymstq3PevHmIjo5G9+7d4ebmBi8vL9SuXbvUOatXr0Z0dDReeeUVq2LjW6sjQ5zp6nQ6q+eMJyYmIjIyElOnTsXJkycREhJi/puTk+1fJcMwVmUMsbGxaN++PVJTU83HLNFZVteLfniZmZnIzMzEzZs30alTJ3Tr1g2LFy+G0WisdLzW6pwzZw5q166NjRs3mq9nq87ffvsNbdq0QUxMDK5fv47+/ftXeO7Vq1eh1+tRs2bNSsdrrc6jR4+iatWq+PLLL/H06VPz8cpqLavz8ePH8PHxwZAhQxAbG4uxY8fiwYMH5s+r1Wps27YNH374ocWxmrBWK4kQZ7qWmMOL+Oeff3D9+nVMnz4dcXFxiIqKsvmFBcdxKCwsREFBgUX/mbZXiY2NRUxMDI4cOcKbzrJkZmYCKDGGLVu2IDExEbt378aWLVsE15mZmYk7d+5g+PDhqF27NlJSUmx+a/7WW2/h+PHj2LVrF/r374+qVauW+0x+fj6mTJmC0aNHW7ybgrU6CwsL8Z///AdBQUFISEhAXl6e1XWamZmJHTt24LPPPsO+ffsQEhKCyZMnm/8+b948jB071ubFyoVqc44GcasS8zW+sWnTpmjatCkyMjIAAC1btkTLli1tKjM/Px+ff/459u7da9F5pgyjsLAQR44cQUxMDC5evCjIKlDu7u4AgOHDh8Pb2xve3t7o378//v77b/Tr169SZRQWFuKDDz6wWGdxcTEA4OnTp3j69CmGDh2K5cuXl+uftIZatWqhbt26mD17Nr7//n
vzcZ1Oh7Fjx6Jp06YWb9VTVFRklU6DwWDWCgAzZszA9u3bceDAAYvKMeHu7o6OHTsiIiICADB69GjExMRAq9Xi9OnTKCgoQLdu3awq+1nkMnZY7hBnunwbUUhICMaMGcNLWUqlEqtWrbJ49+Dw8HDcvHkTbm5u6N69O+bOnYtq1aqZbwh88sorr8DV1dWm79HT09MqnZMmTcKCBQvg5eWF8PBwzJ8/H82aNeNtKirLsuY+XaDE5MePH4/q1atjxowZFpfn7u5ulc7du3ejb9++cHZ2RpUqVTBnzhzEx8dbPfOrfv36perr2f8/ceIELl26hPbt2wMoufE7OTnhxo0bWLRokUXXoUs9Vg7iuhf46HsVEmviCw8PR9++fXH+/Hls3LgR9erVs0mn0WhEUVERWJYFx3EoKioyvyTx8PBAt27dkJSUhKdPn0Kj0WDTpk2IjY216BrWxFejRg00b94cO3fuxKlTp9ChQwc4OztbXI6JzZs34/HjxwCAW7duITExEa1btwZQ8tJq4sSJcHd3x5w5c6z+Pq05LzAwEMHBwfjhhx+QlpaGkSNHwtXV1eoYevfujYMHD5r7pX/66SdERkZCqVRi7Nix2LVrFzZt2oRNmzahffv2ePvttzF79myLryP335ZcIC7TVSgUsn0M4jgOCoXC4vN27NhR7pgtOk+fPo0RI0aY/92iRQu0aNECSUlJAICpU6di1qxZ6NixI5RKJd5++2306dOn0uVbq3PChAmYMGFCqWO26Dx79iwWLlyIwsJC+Pn5oUuXLhg7diwA4Ny5czh8+DAUCkWp7oulS5eiefPmlSrfWp0tWrTArVu3yh23Vmvr1q0xfvx4fPTRRygsLERkZCS++uorAICXlxe8vLzMn3V3d4eHh4fF2bm1WkmEyJ0jLl++LMtOfycnJzRs2JC38qhOaeFbJ0CWVkeFyOcBud6R+Y6L6pQWIeIiSaujQqTpKpVK2XX6m7a65hOqUzqE0AmQpdVRIdJ0LZmyKiZ8x0V1SosQcZGk1VEh0nRdXFxkd2dWKpVwceH3vSbVKR1C6ATI0uqoEGm6ABAQECCbxzSGYRAQECBI2VSn+AipEyBLqyNCrOl6enrKon/M1B9m6xTM50F1iovQOgGytDoixJouAAQHB8ui4T67cI4QUJ3iIYZOgCytjgbRpuvi4oLQ0FDJGi/DMAgNDbVpVlVloDrFQSydAFlaHQ2iTRcAvL294e/vL3rjZRgG/v7+5vVchYbqFBaxdQJkaXUkiDddAFCpVPDx8RGt8TIMA19fX6hUKlGuZ4LqFAapdAJkaXUUiJwGXBEcx0Gj0SA7O1vQtRlMWYJKpZLk0ZDq5BepdQJkaXUEqOmWIS8vD+np6eA4jtcGzDCMuR9MDo9lVKdtyE0nQJZWe4aabgWwLAu1Wm3TrsHPYhpaExwcLKtB5FSndchVJ0CWVnuFmu4LKCgoQFZWlnl7aUu+KtPjl1KpREBAgKzHMlKdL8eedAJkabU3qOm+gD179uDtt9/G48ePodVqodVqodPpzDvHlsV0XKFQQKlUws/Pz66yA5ZlkZOTQ3X+H/auEyBLq71ATfc5bN26FQMHDoRer0daWlqpQeB6vR46nQ5Go9HcWJ2cnKBQKODq6iph1PxCdTqWToAsrXKF3soqYOnSpZg0aRKKi4vh6ekJjUZTynRdXV2JaKBUp+NBkla5Qk23DFeuXMGYMWPMj2AuLi7QaDQSR0WhUBwFOjmiDA0aNMDJkyfx6quvwsXFBQUFBdR0KRQKb9BMtwKaNGmCR48e4ezZszh37lypjQkpFArFFqjpVsD27dvRtGlTREREICIiQupwKBSKA0G7FypgxYoVGDlypNRhUCgUB4QOGSvD3bt30aJFC6Snp9MdTikUCu/QTLcMSUlJGDRoEDVcCoUiCDTTfQaDwYCwsDDs2rULTZo0kTocCoXigNBM9xn2798PlUpFDZdCoQgGNd1noC/QKBSK0NDuhf/j4c
OHCA8Px7179+iaoRQKRTBopvt/rFmzBr169aKGS6FQBIWaLkqWtUtMTMSoUaOkDoVCoTg41HQBHDt2DEajEW3btpU6FAqF4uBQ08X/v0Cjm+1RKBShIf5FWl5eHmrVqoWrV6+ievXqUodDoVAcHOIz3fXr16NDhw7UcCkUiigQb7r0BRqFQhETok334sWLSE9PR9euXaUOhUKhEALRppuYmIjhw4fD2dlZ6lAoFAohEPsiraioCKGhoThx4gRq164tdTgUCoUQiM10t23bhqZNm1LDpVAookKs6SYmJtLFbSgUiugQ2b1Ad4egUChSQWSmm5SUhMGDB1PDpVAookNcpkt3h6BQKFJCXKZLd4egUChSQpzp0t0hKBSKlBDVvUB3h6BQKFJDVKZLd4egUChSQ4zpchyHFStW0MVtKBSKpBBjuseOHQPHcXR3CAqFIinEmC7dHYJCocgBIl6kOcLuEHq9HjqdDkajERzHgWEYODk5QaFQwNXVVerweIMUnQA5WknRWVlcpA5ADOxxdwiWZZGTkwOtVgudTmdurGUxHVcoFFAqlfDz84OLi/1UKyk6AXK0kqLTWojIdNu0aYPp06fjrbfekjqUl1JQUICsrCxotVoAJQ2zspgatlKpREBAADw9PQWJkQ9I0QmQo5UUnbbi8KZ74cIFxMXF4d69e7JerJxlWajVami1Wosa6/NgGAZKpRLBwcGyyh5I0QmQo5UUnXzh8KY7YcIEKJVKJCQkSB3Kc8nLy0N6ejo4juOl0ZpgGAYMwyA0NFQWY5NJ0QmQo5UUnXzi0KZr2h3i5MmTCAsLkzqccnAcB41Gg+zsbF4bbFkYhoG/vz9UKpUkozdI0QmQo5UUnULg0EPGTLtDyNVwMzIyBG+0pmtlZ2cjIyND8GtVdG0SdJquT4JWUnQKhUObrpxnoGk0GuTm5orWkDiOQ25uLjQajSjXM0GKToAcraToFAqHNd07d+7g7Nmz6N27t9ShlCMvL0+ULKEspqwhLy9PlOuRohMgRyspOoXEYU1XrrtDsCxrfvEgBRzHIT09HSzLCnodUnQC5GglRafQOKTpGgwGJCUlyXLdXLVaLXnfFMdxUKvVgl6DFJ0AOVpJ0Sk0Dmm6+/btQ0vRr04AAB6JSURBVFBQkOx2hygoKOBtLKMtcBwHrVaLgoICQconRSdAjlZSdIqBQ5quXF+gZWVlSd5oTXAch6ysLEHKJkUnQI5WUnSKgcOZ7sOHD3Hw4EEMHDhQ6lBKwbKseXqkXNBqtbz3j5GiEyBHKyk6xcLhTDc5ORl9+vSR3SyWnJwcq8/t3bs3UlNTeYzm/7ElLlvLE1JXWfjWaUmZYuoEaJ3KHYeakcZxHBo2bIjly5fj9ddflzqcUty+fdvmfiiO47Bo0SJs374dBQUFePXVVzFt2jTUrVvX6jI9PT1Ru3Ztm+J6Fmt03rhxA/Pnz8fly5fx5MkTXLhwodxndu/ejaVLl0Kj0aBq1aqYPXs2mjdvXulr8K0TsFzr7t27sWTJEmRlZcHNzQ2vv/46pkyZgipVqqC4uBizZ8/G8ePHkZubixo1amD8+PGIiYmxOC451CkApKWl4csvv8SpU6fg5uaGPn36YOLEiaU+c+/ePfTt2xedO3fGl19+aVH5QtSpGDhUpnv06FHZ7g6h0+lsLmPv3r3Ytm0bVq1ahSNHjqBp06aYOnWq5HHZWp6Liwu6du2KWbNmVfj3o0eP4rvvvkNCQgKOHz+OVatWITQ0VPC4+C6zWbNmSE5OxrFjx7B7926wLItFixYBKHmEV6lUSEpKwrFjx/Dxxx/jf/7nf5CRkSF4XEKUp9fr8f7776NVq1b4448/cODAgQpX+ZszZw4iIiJEi0sOOJTpml6gyW2Otl6vt+klRNeuXXHs2DFkZGSgWbNmqFGjBpydndG9e3fcunXLptg4joNer7f4vPz8/HLnWarTpCssLAx9+/Z9bsa+ZMkSfPjhh2jatCmcnJxQvXp1i9
dGtlanXq9Hfn5+hccrq9WkU6VSwc/Pz3zc2dkZ9+/fB1CStY0ZMwYhISFwcnJCbGwsQkJCcPnyZYtjtlbrkydPyh2ztk63bduGatWqYdiwYfD09IS7uzvCw8NLfXb37t1QKpVo3bq1xbEC1uuUGocx3by8PGzduhXx8fFSh1IOnU7Hy40gLi4OaWlpuHv3LvR6PXbs2GFzVs8wjFUZw7vvvouQkBAsW7YMxcXFAPjT+SwGgwGXLl1CdnY23nzzTbzxxhuYM2eOxTFbqzMlJQWBgYGYPHkyHj9+bD5urdYzZ84gKioKrVu3xoEDBzB06NAKP5eVlYV79+6hTp06Fl/DGq23b9+Gv78/+vTpg6tXr5qPW6vz/PnzCA4OxocffoiYmBgMHz4c169fN/89Pz8fixcvxieffGJx2SasrVOpcZjFKn/99Ve88cYbqFatmtShlMNoNPJSTmBgICIjI9GjRw84OztDpVJhxYoVNseWlpZm8Q8rMzMTjx49woQJEzBlyhSMHz8eY8eOtSmWinj8+DFYlsX+/fuxevVquLi4YNy4cVi2bBnGjRtX6XI4jrNK5927d2EwGLBw4UIsWrQIgwYNQkJCAry8vCyVAgCIjIzEsWPHkJmZic2bNyM4OLjcZ/R6PT777DP07NnTqj5La+r01q1bUCgU2LFjB3bv3o3XX38dc+bMwauvvmrx9YGS9pGamoqFCxeiTZs2SElJwbhx47Bz5064urrixx9/RJ8+faBSqawq3wRfvy0xcRjTTUxMxMyZM6UOo0L4ele5dOlSXLx4Efv370dAQAB27dqFUaNGYevWrfDw8LCqzMLCQixbtgxHjhyx6Ly0tDTz+YWFhZg5cyaioqIQEhJiVRzPw93dHQAwaNAgBAYGAgDi4+MtNt2ioiIsXbrUYp05OTlgWdZchytXrgTLsvjhhx8sKqcs1atXR9u2bTF58mRs2LDBfNxoNGLq1KlwdXW1ur++oKAAy5cvt0hrUVERioqKYDQaUVRUhIMHD+LOnTs4ffq0VTG4u7ujWbNm5heB7777LpYtW4bbt2+D4zgcP34cGzdutKrsZ7HHcQAOYbrnz5+HWq1G165dpQ6lQvh65L527Rq6detmzg569+6Nr7/+Grdv30ajRo2sKtPLywvz58+Hj4+PRefFxcVh//79cHV1RXx8PGbOnAkvLy+rXvy8CB8fH1SvXr3Ud2jN96lQKKzSuXTpUowbNw6urq6IiorCN998g8jISOTm5locQ1kMBoP55gWUGMiMGTPw+PFjLFmyxOpNG6tUqWKx1itXruC1116Dq6srqlevjm+++Qa9e/e2enxu/fr1ce7cuQr/lpqaCrVajc6dOwMouUkYjUYMGDCg1A2oMsjt/U1lcAjTTUxMxPDhw2W7HY+TEz9d5xEREdi3bx+6desGf39//Pbbb2BZFjVq1BA9vujoaNSsWRMzZ840PyJb+wPlOA7FxcXmlyJFRUVgGAZubm4ASm4u69atQ9u2beHi4oI1a9agXbt2Fl/HGp316tVDly5dkJCQgMjISJvK2rVrF5o3b46goCCo1WosXLiw1EukhIQE3LlzB8uXL7d5oSZL4wsMDETr1q0xYcIE9O7d23y+tW23e/fu5pEarVq1wtq1a+Hr64vatWujZs2aiIuLM3921apVUKvV+Pzzzy2+Dl+/LTGxe9PV6XRYu3atqIPPLUWhUPDyGDRixAg8fvwY/fv3R2FhIWrWrIkFCxbYNBGE4zirfuDTp08vd8xanWq1Gt26dTP/u0WLFggODsbevXsBAB988AGePHmCHj16wM3NDV27dsX7779v0TWs1dmpUyd06tSp3HFrtN6+fRvfffcdtFotlEolYmJiMGHCBAAl38HGjRvh5uaG9u3bm8+ZMWMGunfvbtF1rNEaEBCAv/76q9xxa+s0LCwM8+bNQ0JCArKzs9GgQQMsWrQIrq6ucHV1LdUd5unpCTc3N/j7+1t0DWvrVGrsfnLEr7
/+isTEROzfv1/qUF7I5cuXZdnp7+TkhIYNG/JWHik6AXK0kqJTLOwvNy9DYmKiLJdwLItc78h8x0WKTqHK5ANap/LGrk33zp07OHfunCx3hyiLUqmUXae/aatrPiFFJ0COVlJ0ioVdm+7KlSsxaNAgu7jjPTsTSU7wHRcpOoUqkw9oncobuzVdOe8OUREuLi6yuzMrlUq4uPD7LpUUnQA5WknRKRZ2a7p79+5FSEiI7HaHAEpGVLRv3x6NGzdGw4YN8corr8DHxwcXLlyQzWMawzAICAgQpOyAgAAidALkaCVFpxjY560CJYvbyDXLdXd3h0ajwbVr18zHfHx8EBUVhfz8fMm3PTH1h3l6egpSvqenJ5RKpcPrBMjRSopOMbDLTDczMxOHDh2S3e4QJu7cuVNq6UEPDw/8+eef8Pf3R3BwsOQZA8MwvE/XLQspOgFytJKiU2js0nTlujtEWloaPvjgA7Rs2RJt27ZFo0aN4OLignnz5uG1114DUNI/FhoaKlnjZRgGoaGhgs/eI0UnQI5WUnQKjd2ZLsdxSExMlNXGkw8ePMC4cePQtGlT+Pn54fr165g1axaWL1+O4cOHl1uYxdvbG/7+/qI3XoZh4O/vL9rNihSdADlaSdEpJHbXp/vf//4XDMMgOjpa6lCQlZWFr7/+GitWrMCwYcNw5cqVUotrR0VFISoqqsJzVSoVDAYDcnNzRekjYxgGvr6+Ni+lZymk6ATI0UqKTqGwu0zX9AJNyr6lJ0+eYPr06QgPD0d+fj7Onz+P7777zqLdDEx9U2JkDaYsQYo+OVJ0mq5PglZSdAqFXa29kJubi1q1auH69euSLFau1Wrxww8/4Pvvv0fPnj0xffp0hIWF2VxuXl4e0tPTwXEcr5kDwzDmfjA5PJaRohMgRyspOvnErkz3559/xr59+7B582ZRr1tQUIDFixdj/vz56NSpE2bOnIn69evzeg2WZaFWq3kbkmMaWhMcHCyrQeSk6ATI0UqKTr6Qhenq9XrodDoYjUZwHAeGYeDk5ASFQlFqIedWrVph1qxZpdbiFBKdTofly5dj3rx5iI6OxqxZs6xeLLyyFBQUICsry7w2rSXVY3r8UiqVCAgIkPVYRlJ0AuRoJUWnrUhiuizLIicnB1qtFjqdzmy0ZTEdVygUYFkWffr0wfnz5wUfMqLX65GUlITZs2ejSZMmSEhIQLNmzQS9Zlms+Y6USiX8/PzsKjsgRSdAjlZSdFqLqKZr652Q4zh4e3sLdidkWRZr167FrFmzULduXfznP/9BmzZteL+ONVT2acDeIUUnQI5WUnRWFlFMV+59PkajERs2bMAXX3yBatWqISEhAbGxsTaXS6FQKGUR3HTl/HaT4zhs374dM2bMgIeHB2bPno1OnTo5zNAUCoUiPwQzXY7joNFokJ2dLegAatM4PpVKVWmz5DgOe/bswfTp02EwGJCQkIC33nqLmi2FQhEcQUyX4zhkZGSIOmPFx8cHISEhLzXOQ4cO4fPPP0dubi5mzZqFvn372uWOohQKxT4R5FWhRqMRzXCBEpPPzc2Fs7MzgoKCKvzMf//7X0yfPh1paWn44osvMHDgQLtfOINCodgfvKd4eXl5gncpVATHccjOzkZeXl6p46dOnUJcXBwGDx6MIUOG4MqVKxg8eDA1XAqFIgm8mi7LsuaXZlLAcRzS09PBsizOnz+PPn36oHfv3ujRoweuXbuGESNGEDEOkEKhyBdeTVetVku6qjzw/yMSunTpgnbt2uHGjRsYM2YM3N3dJY2LQqFQAB5fpBUUFODOnTuSmy5QMu42JCQEVatWlToUCoVCKQVvmW5WVpYsDBcAnJyc8PTpU6nDoFAolHLwYrosy5qn9lpD7969kZqaykcoZrRaLViW5bVMCoVCsRVeTDcnJ8em87dt24aWLVua/z1t2jRbQwJge1wUCoXCN7yYLh9rKjx48ABfffUVCgsLAQDXr1/H/PnzrS6P4zibsm8KhU
IRAl5MV6fT2XR+165dcffuXXTu3BlTp05FamoqNmzYgBEjRkgaF4VCofCNzaar1+sFeYHm5ORk81oIHMdBr9fzFBGFQqHYjs2mq9PpeFko5tGjR9i/fz/mzp2Lli1b4u2338bKlSttKpNhGJrtUigUWWHz9Cyj0chHHAgMDETPnj3N/w4PD0d4eLjN5fIVH4VCofCBzZmuEF0Lc+bM4a0suYwdplAoFIAH05X7GrRyj49CoZCFzaYr97Vo5R4fhUIhC5v7dBUKhc2P8Hv37rU1jArhOA4KhUKQsikUCsUabE4DXV1dZfsIzzAMkbuNUigU+cLLs7dcs0m5xkWhUMiFF9NVKpWyy3Z1Oh02bdqEzZs3o6ioSOpwKBQKBQBPpuvn58dHMbyiUCgQFBSEH3/8EcHBwRg9ejSOHTtGh5BRKBRJ4W0R8/v375fbn0xKvL29UbNmTQDAvXv3kJKSguTkZBiNRsTHx2Po0KF45ZVXpA2SQqEQh0PuHMEwDMLCwuDp6VnqOMdxOHnyJJKTk7F+/Xo0atQI8fHx6NevH3x8fCSKlkKhkARvpguUZLt8LPNoCwzDQKlUmrPc51FUVITff/8dycnJOHToEN58803Ex8ejc+fOdPNKCoUiGLyaLsuyuH79uqTrHTg5OSE8PNyiLdYfP36M9evXIzk5Gffu3cPgwYMRHx+PJk2aCBgphUIhEV5NFwDy8vKQlpYmSbbLMAxq1KgBb29vq8u4du0a1qxZgzVr1sDPzw/x8fEYNGgQVCoVj5FSKBRS4d10gZJdILKzs0U1XoZh4O/vj6CgIF7KMxqNOHz4MJKTk7Ft2zZERUUhPj4evXr1goeHBy/XoFAo5CGI6XIch4yMDOTm5opivAzDwNfXF8HBwYKMF3769Cm2bduG5ORkpKamom/fvoiPj8frr79O13agUCgWIYjpAiXGq9FoBM94TRmuSqUSZYJGRkYG1q1bh+TkZOTn52Po0KEYOnQo6tWrJ/i1KRSK/SOY6ZrIy8tDeno6OI7j1XwZhgHDMAgNDbWpD9daOI7DP//8g+TkZKxbtw61a9dGfHw8BgwYAH9/f9HjoVAo9oHgpguUjGpQq9W8DSczDQsLDg6WxfAuvV6P/fv3Izk5GXv27EGnTp0QHx+PuLg4u15wR6/XQ6fTwWg0guM4MAwDJycnKBQKu9ZVFlJ0AuRolbNOUUzXREFBAbKyssxbo1tyaVPXgVKpREBAQLmJD3LhyZMn2LRpE1avXo1r165h4MCBiI+PR/PmzWW3PkVZWJZFTk4OtFotdDqdubGWxXRcoVBAqVTCz89PFje/ykKKToAcrfakU1TTNWFPX5At3Lp1yzz92N3dHfHx8Rg8eDBq1KghdWilIOFmCJCjEyBHqz3qlMR0yyLnRwE+4DgOR48eRXJyMjZt2oRmzZohPj4effv2RZUqVSSLy9G7fUyQohMgR6s965SF6ZKETqfDzp07kZycjL///hs9e/ZEfHw8OnToYNEsOltx1BecZSFFJ0COVnvXSU1XQh4+fIhffvkFycnJyMzMxJAhQxAfH4+GDRsKdk1HHcpXFlJ0AuRodRSd1HRlwsWLF7FmzRqkpKQgKCgI8fHxeOeddxAYGMjbNaSYtOLj44OQkBBRf6Sk6ATI0epIOqnpygyDwYBDhw4hOTkZO3fuRLt27RAfH48ePXrA3d3dprIdYXp2ZSBFJ0COVkfSSU1Xxmi1WmzZsgXJyck4d+4cBgwYgPj4eLRp08biu6+9L0RUWUjRCZCj1dF0UtO1E+7fv4+1a9di9erVMBgMiI+Px5AhQxAWFmb+zK5du+Dv74/o6OhS58plyc369esL+maYFJ0AOVodUSddrcVOqFmzJqZMmYIrV65g3bp1ePjwIVq1aoXY2FgkJibiyZMneO+999CpUyecPHmy1LlqtVryHT04joNarRb0GqToBMjR6og6aaZrxxQXF2P37t1ITk7G3r17UVRUBJZloVQqcfToUU
RERNjFNkp8QIpOgBytjqqTZrp2jJubG3r16oXNmzejR48eMBgMAEr6glu0aIHU1FRkZWXJotECJRlDVlaWIGWTohMgR6uj6pTPFBOKTRw6dAgA4OnpCV9fXzg7O+Pq1auym8Kp1WrBsiyv/YAsy5qngcoFIXQC5Gh1ZJ3UdB2EixcvwsPDo9S04kePHuHhw4cSRlUxOTk5vI4/zsnJ4a0sPuFbp6lMOULrtPLQ7gUHITAwsNw6DkLszLxnzx707NkTrVu3Rq9evXDw4EGLzuc4jvcMxlader0eEydORNeuXdG4cWOkpqaW+8zly5cxbNgw88vLlJSUF5YphE7ANq23bt3Cv/71L0RHRyM6OhqjRo3CrVu3zH9PSkpCnz590Lp1a3Tr1g1JSUmVKleOdQoAhYWFmD17NmJiYhAVFYVhw4aV+4xer0fPnj3xxhtvvLQ8vnTSTNeB0el0vJaXmZmJKVOmYOHChXj99dfx999/Y9KkSdizZw+qVq0qWVx8lNesWTMMGTIEkyZNKve3nJwcjB49Gp988gm6dOkCvV6PzMxMUeLis8zAwEAsWLAAwcHBMBqN+PXXX/HJJ59gy5YtAEpMZc6cOahfvz7S0tLwwQcfQKVSIS4uTtC4hCpv1qxZMBgM2L59O3x8fHD16tVyn0lKSoKfnx+ePn0qWlw003VQ9Hq9VZlC48aNcf/+ffO/p02bhoULFwIoMV1vb2/ExMSAYRi0a9cOHh4eSEtLs+gaHMdBr9dbHFtWVla5oTuV1fkiXa6urhg6dCgiIyMr3PMuOTkZ0dHR6N69O9zc3ODl5YXatWu/9JrW6iwuLq7QICqj9UU6vb29zdNaOY6Dk5NTqbobMWIEGjZsCBcXF4SFhaFDhw44e/ZspWK2VuvFixfLjcHlo05v376NP//8EzNnzoS/vz+cnZ3RqFGjUuenp6dj165dGDVqVKXjtVbns1DTdVB0Oh3vc8YbNWqEsLAw/PHHHzAYDDh48CBcXV1Rv359i8phGMaqjOHrr79GzZo1MWrUKKSnpwMQRmdZzp8/Dx8fHwwZMgSxsbEYO3YsHjx48NLzrNX5559/okGDBmjfvn2prg6+tEZHR6NFixaYN2/ecw2H4zicOXMGdevWrVSZ1mgtKipC48aNUbt2bWzYsMFsvnzovHjxIoKCgrB48WLExMSgT58+2L9/f6nPzJs3D+PHj4dCoah0udbW6bPQ7gUHRYgZPM7OzujZsyc+/fRTFBcXw9XVFd9++63FIyQMBgNOnz6NwsJCi867efMmDAYDVq9ejTVr1qB9+/b48ccfLSrDGjIzM3HlyhUsW7YM9erVw4IFCzB58mSsWbPmhecZjUardKampsLDwwOHDx9GTEwM6tevj6+//hpRUVG2yDBz9OhRFBQUYMeOHQgODq7wM0uWLIHRaETv3r0rVSbLshZrLS4uhpOTE+7du4dhw4bh448/xqeffoqRI0dWuoznkZmZiZs3b6Jz5844dOgQzp07h48++gh16tRB7dq1cfDgQRgMBrzxxhsV9uG/CFt/W9R0HRQhxjceO3YMCxYsQFJSEho0aIDLly/j448/xtKlS/Hqq69Wuhy9Xo99+/ZV+tHVhOmRm2VZMAyDAwcO4Ny5c+UeG/nG3d0dHTt2REREBABg9OjRiImJgVarhVKpfO55LMtapfPRo0fmR9iioiJcuHABP//8M9q0aWO9iDJ4enpiwIABaNeuHbZv316qT37dunXYuXMnVq1aBTc3t0qVV1xcbLFWg8Fgbqc6nQ5FRUX4/vvvMWLECMvEVIC7uztcXFzw/vvvw8XFBS1btkSrVq1w9OhRqFQqLFiwAEuWLLGqbFt/W9R0HRRrH888PDxKZSuPHz9G9erVAQDXrl1D8+bNzSYXERGBxo0b4/jx4xaZrkKhwKeffgofHx+LYps0aRJ++OEHuLu7Y/z48fjkk0/g5OSEjIwMm3S9jPr165f6Piv73bq5uVmlc/fu3ejVqx
cUCgW6d++OuXPnol69esjNzX3puZboNBqN0Ol0ePjwodl0t27disTERKxatQoqlarSMXt6elqsVafTwdPTE15eXqhfvz6+/fZbtG/fHnl5eZU6/0VaK+ryMtXb/fv3oVarzaMZ9Ho98vPz0b59e6xduxYhISEvvK6tXR+0T9dBqeiFUGUIDw/H77//DoPBgCNHjuDUqVPmvzVq1AhnzpwxZ5xXrlzBmTNnLO7TtTa+2NhYfPbZZ0hPT8fcuXPh5+dX6XJepAsoydSKiooAlPwIi4qKzBlN7969cfDgQVy9ehV6vR4//fQTIiMjX5jlmrBGZ3h4OIYPH47z589j48aNqFevXqXLepHOo0eP4sqVKzAYDMjPz8c333wDb29v80vBXbt24YcffsDy5cut2sfPUq1ubm4YMWIEdu7cidOnT6NDhw7mrboqw4u0Nm/eHEFBQVixYgVYlsXZs2dx8uRJREdHo27duti/fz82bdqETZs2YdasWahatSo2bdpUqRuNtb8tE3TtBQdFr9fj+vXrFj8KXbp0CdOmTcODBw/QsWNHGAwGhIaGYty4cQBKHj1TUlLw+PFj+Pn54Z133qlw/OOLYBgG9evX52X/u8rqfJmurl27lhsZsWfPHnPWs379eixbtgyFhYWIjIzE559//tIfKJ86gcppfZHOvXv34scff0RmZiYUCgUiIiIwfvx4hIeHAwC6deuGzMzMUvF2794dM2bMeGlscqzTmzdvYubMmbhx4waCgoIwbty4Csfjpqam4rPPPqvUmHM+dFLTdWAuX74s6ZJ4z8PJyYnXLYlI0QmQo9WRddLuBQfGkqEwYsJ3XKToFKpMPqB1Wnmo6TowSqVSss0Sn4dpq2s+IUUnQI5WR9ZJTdeB8fPzkzqECuE7LlJ0ClUmH9A6rTzUdB0YFxcXQbItW1Aqlbwvd0iKToAcrY6sk5qugxMQECCbxzSGYRAQECBI2aToBMjR6qg6qek6OJ6enrLoHzP1hwm1qDopOgFytDqqTmq6BBAcHCyLhvuymT62QopOgBytjqiTmi4BuLi4IDQ0VLLGyzAMQkND4ezsLOh1SNEJkKPVEXVS0yUEb29v+Pv7i954GYaBv78/vL29RbkeKToBcrQ6mk5qugShUqng4+MjWuNlGAa+vr4WLZzCB6ToBMjR6kg66TRgwuA4DhqNBtnZ2YJub23KElQqlSSPhqToBMjR6ig6qekSSl5eHtLT08FxHK8NmGEYcz+YmI/az4MUnQA5Wu1dJzVdgmFZFmq1mrddg01Da4KDgwWZGGAtpOgEyNFqzzqp6VJQUFCArKws8/bSljQJ0+OXUqlEQECAoONTbYUUnQA5Wu1RJzVdihmWZZGTkwOtVgudTgeO4yrs0zIdVygUUCqV8PPzk1UW9DJI0QmQo9WedFLTpTwXvV4PnU4Ho9FobqxOTk5QKBS8LcwtB0jRCZCjVc46qelSKBSKiNBxuhQKhSIi1HQpFApFRKjpUigUiohQ06VQKBQRoaZLoVAoIkJNl0KhUESEmi6FQqGICDVdCoVCERFquhQKhSIi/wvXLtN3+LCyRgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
" 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],\n",
" 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],\n",
" 'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
" 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],\n",
" 'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],\n",
" 'c64': [2, 3], 'c128': [3, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 5))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nfIoOEXizLPQ"
},
"source": [
"This turns out to represent exactly the semantics used by NumPy in mixed float/complex type promotion."
]
},
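{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to spot-check this claim is to ask NumPy for the promoted dtype of a few float/complex pairs. This is a small sketch using `np.promote_types`; the list of pairs and the `results` dictionary are illustrative choices, not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Float/complex pairs to compare against the edges added to the lattice.\n",
"pairs = [\n",
"    ('float16', 'complex64'),\n",
"    ('float32', 'complex64'),\n",
"    ('float64', 'complex64'),\n",
"    ('float32', 'complex128'),\n",
"    ('float64', 'complex128'),\n",
"]\n",
"\n",
"# Record the dtype NumPy promotes each pair to.\n",
"results = {}\n",
"for f, c in pairs:\n",
"    results[(f, c)] = np.promote_types(f, c).name\n",
"    print(f, '+', c, '->', results[(f, c)])"
]
},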
{
"cell_type": "markdown",
"metadata": {
"id": "obx6SlFAhTFA"
},
"source": [
"## Mixed Promotion: Signed & Unsigned Integers\n",
"\n",
"For the next case, let's consider something a bit more difficult: promotion between signed and unsigned integers. For example, when promoting `uint8` to a signed integer, how many bits do we need?\n",
"\n",
"At first glance, you might think it natural to promote `uint8` to `int8`; but the largest `uint8` numbers are not representable in `int8`. For this reason, it makes more sense to promote unsigned integers to signed integers with twice the number of bits; this promotion behavior can be represented by adding the following connections to the promotion lattice:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"cellView": "form",
"id": "Irp8qFnC_EB8",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAEeCAYAAAApRMZ1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2deVgT1/7/38MawLAJyqYWq1IV9xUUUeuCrQvY1mtdUOyqterVX+1i1Xrduljbaq2tiiiirfva4obVWysqrVrcqRsCISiCJBgCmWR+f/BNrihaksxkkjnn9Tw+jySZM+93zslnzpw553wYjuM4UCgUCsUmOIktgEKhUEiCBl0KhUKxITToUigUig2hQZdCoVBsCA26FAqFYkNo0KVQKBQbQoMuhUKh2BAadCkUCsWG0KBLoVAoNoQGXQqFQrEhLmILoNQNnU4HrVYLg8EAjuPAMAycnJwgk8ng6uoqtjyKBZBSp6T4rCs06NopLMuitLQUarUaWq3W1Fgfxfi6TCaDXC6Hn58fXFxotdojpNQpKT4thaEb3tgXGo0GxcXFUKvVAKobZl0xNmy5XI6AgAB4enoKopFiHqTUKSk+rYUGXTuBZVkoFAqo1WqzGuuTYBgGcrkcISEhRPQe7BFS6pQUn3xBg64doFKpkJ+fD47jeGm0RhiGAcMwCAsLg7e3N2/lUv4ZUuqUFJ98QoOuiHAcB6VSiZKSEl4b7KMwDAN/f38EBQXVOrZG4Q9S6pQUn0JAg65IcByHgoIClJWVCdpojTAMAx8fH4SGhkqm8dobpNQpKT6Fgs7TFQmlUmmzRgtU/1DKysqgVCptcj4SIaVOSfEpFDToioBKpRL8tqw2OI5DSUkJVCqVTc9LAqTUKSk+hYQGXRvDsqzpwYMYcByH/Px8sCwryvmlCCl1SopPoaFB18YoFArRGq0RjuOgUChE1SAlSKlTUnwKDQ26NkSj0fA2l9EaOI6DWq2GRqMRVYcUIKVOSfFpC2jQtSHFxcWiN1ojHMehuLhYbBkODyl1SopPW0CDro1gWda0PNJeUKvVDj8+Jiak1CkpPm0FDbo2orS0VJByZ82aZdXxQukiAWu+u5KSEsycORNRUVGIjo7G+++//9hnysrK0KtXLyQmJtpMl7Xl3b17F++++y769u2LNm3aoKCgoMb7S5YswYsvvohu3bphyJAh2LNnT433T506hREjRqB79+6Ii4vD1q1bedFlT9CgayP4HA+rqKjAvHnzcP/+fQDVDzjmzZtndvnG8TGKZVhTp//+978REBCAgwcP4tixYxg/fvxjn/nqq68QHh5uVrlC1Kk5PhmGQY8ePbB06dJa3/fw8MDy5cuRmZmJhQsX4tNPP8W5c+cAVG8BOW3aNLz88svIzMzEkiVL8MUXX+Dq1auPlePIbVd6u0nYKVqt1uJjlUolPv30U5w5cwYGgwGDBg3CqFGjsGDBAly8eBErVqzAxIkTLVqtY40u0qnLd1db3fXp0wdKpRJr166Fs7MzAKBly5Y1jjt37hz+/vtvvPzyy9i5cyfvuvgorzZvs2bNwsiRI5946//OO++Y/t+2bVt06tQJf/31F9q3b4+ysjKUl5djyJAhYBgGkZGRaNq0Ka5fv46IiIg667J3aE/XBuh0Oot7RHq9Hu+88w6Cg4Oxf/9+ZGRkYNCgQab3H94U2hI4joNOp7PoWBLYu3cvunXrhsOHD9eow7rU6ZPq7q+//sIzzzyDWbNmoWfPnhg5ciSysrJqHLdo0SJ89NFHFl1ILanToqIitGrVCj/88AOqqqpMrz/J5z+1y7qg1Wpx4cIFPPvsswCAgIAADBo0CLt27YJer8e5c+dQWFiIjh071nq8o7ZdGnRtgFartXjN+Pnz53H37l3MmDEDnp6ecHd3R8uWLbFx40Z8/PHH6NSpEyZNmoQ1a9ZYFNgZhnHYHoMtyMvLw5kzZxAfH4/27d
ubgm9d6rS2uuvYsSOKiopw4sQJdO3aFb/++ivGjRuHqVOnmsYoN27ciDZt2qB169YWabakTktKSnDjxg3MmDEDYWFhpuD7JJ9P8mYO8+fPR0REBHr06GF67YUXXsD333+PTp06Yfz48Xj33XcRFBTEm097gA4v2ACDwWDxsUqlEsHBwTX2FfXw8MAnn3xi+jskJARz5syxqPzy8nIkJSXhv//9r8UapUxFRQVYlgXLssjOzkb//v0xcuRIfP/99/94bG11BwAymQyhoaEYPnw4AGDQoEFYtWoVzp49i8jISGzatAmbN2+2WLNarTa7TlmWRVVVFSorK/HgwQO8/fbbWLBgAS5cuGCWt7ry5Zdf4u+//8batWtNQf3GjRuYOXMmvvrqK0RFRSE3NxeTJ09GgwYN0KtXr1rLsea3JRY06NoAax6gBQUFQalUgmXZWhv4woULrZEGLy8vLF26FDKZzKpypEpKSgpmz54NFxcXuLq6Yvr06Zg+fTr0ev0/HvukumvRogWOHj1a47PGwGPsQQ4bNgwAUFlZCa1Wi969eyMjI8M0Bvw06tWrZ3ad5uTkYMCAAWBZFs7OzhgxYgTmzJnzxLb7T+3yaaxYsQLHjx9HSkoK6tWrZ3r92rVraNKkiannGx4ejl69euG33357YtC1l7nD5kCDrg2wZju6Nm3aICAgAF9//TUmTZoEZ2dnXLp0CR06dOBNm4+PD3x8fHgpT2oEBwfDy8sLs2fPxsSJE+Hh4QGgejrXP/Gkunv++eexZMkS7N69G4MHD0ZGRgaKiorQoUMHeHl54cCBA6Yy9u/fj19++QXLli2rU8AFLKtTY+LICRMmYM6cOQgJCXmqz6e1y8rKStNFydh7dnd3BwCsWbMGv/zyC9avXw9fX98aZbZs2RK5ubk4deoUunbtivz8fBw7dgxJSUlP9epo0P10bYBarUZeXp7Ft0KFhYVYvHgxzpw5A4Zh8MILL+DDDz/kRZuTkxMaNWoEuVzOS3lSg+M4GAyGxwJeXev0SXX3559/YsGCBSgoKEB4eDhmzpyJTp06PXb8rl27sGPHDqSmptZZs6V1qtfrzfL5JG9t2rR57LPnz58HUB2sXV1da/SO33jjDbzxxhsAqi8yP/zwAxQKBerVq4cXX3wR06ZNq/VBsaO2XRp0bYBOp0NOTo5d3goxDIMWLVoQmQrbGkipU1J82hI6e8EGuLq62u1tEMMwDtdo7QFS6pQUn7aEBl0bYa8PquxVlyNgr98d37pI8WkraNC1EXK53O56DMZU1xTLIKVOSfFpK2jQtRF+fn5iS6gVe9XlCNjrd8e3LlJ82goadG2Ei4uL3V2Z5XK5xZPbKeTUKSk+bQUNujYkICDAbm7TGIZBQECA2DIcHlLqlBSftoAGXRvi6elpF+NjxvEwT09PUXVIAVLqlBSftoAGXRsTEhJiFw03NDRUVA1SgpQ6JcWn0NCga2NcXFwQFhYmWuNlGAZhYWF1XlJK+WdIqVNSfAoNDboi4O3tDX9/f5s3XoZh4O/vD29vb5uelwRIqVNSfAoJDboiERQUBB8fH5s1XoZh4Ovr+8S9SSnWQ0qdkuJTKOjeCyLCcRyUSiVKSkoEXdtu7CUEBQWJPiYndUipU1J8CgENunaASqVCfn4+OI7jtQEzDGMaB5PCbZkjQUqdkuKTT2jQtRNYloVCoeAta7Bxak1ISIjDTiJ3dEipU1J88gUd07UTXFxc0LhxY9SrVw/Hjh0DYP4Gzcbegbe3N8LDw9G4cWNJNlpHwVinO3bsQHZ2tql+zMER6tToMz8/H8ePH5esT76QpisHJScnB927d0dpaSlOnDiBZs2aQa1WQ6vVmrL+PorxdZlMBrlcDj8/P8k2VkeD4zi8//77+OKLL/Dss8/iypUrKC0tlWSd7tmzBy+//DJYlkVlZSXu378vSZ98IH2HDkJWVhb69++PsrIyuLu7o6ioCFFRUQgMDARQvZm0MaXKw2nXZTKZQ+
4pKnVYlkVSUhK2b99u+tvFxQWBgYGSq9PVq1dj6tSp0Ol0cHV1RXl5uSR98gUNunZAWVkZevXqZUonzTAMlEpljc+4uroS2UAdlU8++QSbNm0ypbkxpld/GCnU6ZEjR/Dmm2+a/pbJZFAqlTV2AJOCTz6hY7p2gI+PD3bv3o3IyEg4OTlBp9NBoVCILYtiBVOmTMHcuXPh5OQEDw8PqNXqOmUQdjR69OiBtWvXwtfXF66urqioqHisw0CpCe3p2gkDBgyAXC7H+vXrUVhYyFu2X4o4NGjQAFFRUWjXrh2+/PJLHDhwQDLzTB/G3d0dCQkJ+Pe//42jR4/i559/RtOmTcWWZdfQKWN2wsWLF9G/f3/cvn2biIcJJPCvf/0LsbGxmDRpkthSBGXlypU4cuQItm7dKrYUh4AOL9gJycnJGD9+PA24EqG4uBgHDhzAqFGjxJYiOGvWrMHrr78utgyHgf7C7YDKykqkpaXhxIkTYkuh8ERaWhqGDBkCX19fsaUIytmzZ3H37l3069dPbCkOA+3p2gF79uxBZGQkmjVrJrYUCg9wHEdM7y85ORkTJkxw+O0WbQnt6doBycnJeO2118SWQeGJ06dPo7KyEr169RJbiqBUVFTgxx9/xJkzZ8SW4lDQoCsyubm5yMrKws6dO8WWQuGJNWvW4LXXXpPkbIWH2bFjBzp37owmTZqILcWhoEFXZNatW4dXX30VHh4eYkuh8EB5eTm2b9+Oixcvii1FcJKTkzFx4kSxZTgcNOiKiF6vx9q1a7F7926xpVB4YsuWLYiJiUFwcLDYUgTl2rVruHDhAoYOHSq2FIeDPkgTkcOHDyMwMBDt27cXWwqFJ0h5gLZ27VqMGTMG7u7uYktxOGhPV0ToAzRpcenSJdy6dQuDBg0SW4qgsCyLdevW4dChQ2JLcUhoT1ck7t69i4MHD+LVV18VWwqFJ0hZ4JKeno4mTZqgdevWYktxSKTdOuyYtLQ0DBs2TPKT50mhqqoKGzZsIGKBS3JyMhFDKEJBe7oiYJw8T4cWpMOePXvQunVryS9wKSwsxLFjxzBixAixpTgsNOiKwKlTp6DT6RATEyO2FApPkDI+n5qaiuHDh0Mul4stxWGhwwsiQMrkeVK4ffs2Tp8+jR07dogtRVCMd2ipqaliS3FoaNC1MWq1Gtu3b8fly5fFlkLhiZSUFCIWuPz3v/+Fm5sbunfvLrYUh4YGXRuzZcsWxMbGIigoSGwpFB4wLnDZtWuX2FIExziEQu/QrIOO6doYUibPk0JGRgYCAgIkn+nj/v372LNnD8aOHSu2FIeHBl0bcvHiRdy+fRtxcXFiS6HwBCnTp3788UcMGDDAlOGXYjk06NoQUibPk4IxOwQJC1zoFEf+oL9+G2HMDpGZmSm2FApPpKWlYejQoZJf4HL27FkUFxfT7BA8QXu6NsKYHeLZZ58VWwqFB0ha4JKcnIykpCSaHYInaE/XRpAy9kcKNDsExVJoT9cG5Obm4o8//kBCQoLYUig8QcoCF5odgn9oT9cGkDJ5nhTKy8uxbds2XLp0SWwpgrNmzRpMmjRJbBmSggZdgdHr9UhJScGePXvElkLhCeMCFxKyQ1y8eJFmh+AZOrwgMMbsEO3atRNbCoUnSHmARrNDCAPt6QoMfYAmLS5duoTc3FyaHYJiMbSnKyA0O4T0SE5Oxrhx4yS/wIVmhxAOabcckTFmh/Dx8RFbCoUHaHYICh/QoGsBOp0OWq0WBoMBHMeBYRg4OTlBJpPB1dUVwP8mz69cuVJktZZTF59SoK4+pZAdoi5ejdkhNmzYILJaaUKDbh1gWRalpaVQq9XQarWmxvooxtdlMhlKSkrg7u7uUNkhLPEpl8vh5+fnULfblvrctGmTw/X+LPGalZWF0aNH0+wQAsFwHMeJLcJe0Wg0KC4uhlqtBlDdMOuKXq8HwzDw9fVFQEAAPD09hZJpNdb4NP6A5XK55H1WVVXB19cXDRs2tGufgHVedT
odXFxc4OPjY/d16ojQoFsLLMtCoVBArVab1VifBMMwkMvlCAkJsaseIfVpGfbqEyDLq6NCg+4jqFQq5Ofng+M4XhqtEYZhwDAMwsLC4O3tzVu5lkJ9Woe9+QTI8urI0KD7f3AcB6VSiZKSEl4b7KMwDAN/f38EBQWJsm6f+uQXsX0CZHmVAjToorrRFhQUoKysTNBGa4RhGPj4+CA0NNSmjZf6FAaxfAJkeZUKdHEEAKVSabNGC1T/UMrKyqBUKm1yPiPUpzCI5RMgy6tUID7oqlQqwW/LaoPjOJSUlEClUtnkfNSnsNjaJ0CWVylBdNBlWdb04EEMOI5Dfn4+WJYV9DzUp22wlU+ALK9Sg+igq1AoRGu0RjiOg0KhEPQc1KftsIVPgCyvUoPYoKvRaHiby2gNHMdBrVZDo9EIUj71aVuE9gmQ5VWKEBt0i4uLRW+0RjiOQ3FxsSBlU5+2R0ifAFlepQiRQZdlWdPySHtBrVbzPj5GfYqHED4BsrxKFSKDbmlpqcXH3rx5Ey+//DK6deuGjRs38qjKOl3Wliekr0ehPoUv05Y+AWG8ShUig64142EpKSno0qULTp06hdGjR6OgoADfffed1ZqM42N8Yo7Ph301b94cEyZMQFRUFAYOHFjr59PS0hAXF4euXbti6NChuHXrVp3OY08+9Xo94uLi0L17d/Tt2xefffaZqcd27949zJw5E3379kVUVBTGjh2L7OzsOmsSwidQd6+PttNLly5h3Lhx6Nq1K2JjY5GWlvbYMVlZWWjTpg2WLVtmliahvEoVIoOuVqu1+FiFQoFmzZrhr7/+wqpVq6DX6wEAf/zxB1atWiWaLmvLM/oCAA8PDyQkJGD69Om1fnb79u3YsWMHVqxYgVOnTmHFihXw8/MTRBff5T3ss0+fPtiyZQtOnjyJnTt3Iicnx9Qr1Gg0aN26NTZv3ozjx49j6NCheOedd8x6aMS3T3PKfNhnaWkpJk6ciFdeeQXHjx/HL7/8gujo6Bqf1+l0+Oyzz9C2bVtBdVEIDLo6nc7iXu5rr72GrKwsLFq0CG+88QY8PT0xf/58pKen4/jx4xgzZoxV2jiOg06ns6oMI+b4fNhX165dIZfLMWTIEISFhT32WYPBgJUrV2LmzJl49tlnwTAMGjVqZFZ2DHvxqdfrTRu4GPeTzcvLAwA0atQI48aNQ2BgIJydnfHKK69Ap9Ph5s2bddbGp0+g7l4f9bl48WJER0dj8ODBcHNzg5eXF5o2bVrjmPXr1yM6OhrPPPOMRdr49ipliAu6Wq3W4jXjycnJ6NixIz766COcPn0aoaGhpvecnKz/KhmGsajHEBsbi969eyMrK8v0mjk+H/X1tB9eUVERioqKcO3aNfTr1w9xcXFYsWIFDAZDnfVa6nPhwoVo2rQptm7dajqftT5//vlndO/eHTExMcjJycErr7xS67FXrlyBTqdD48aN66zXUp8nTpxA/fr18emnn+LBgwem1+vq9VGf9+7dg4+PD8aMGYPY2FhMnjwZhYWFps8rFArs2rULb7/9ttlajVjqlUSIC7rmBIen8ddffyEnJwezZ8/GoEGDEBUVZfUDC47jUFFRAY1GY9Y/Y3qV2NhYxMTE4Pjx47z5fJSioiIA1YFhx44dSE5ORnp6Onbs2CG4z6KiIty8eRNJSUlo2rQp0tLSrH5q/uKLL+LkyZPYt28fXnnlFdSvX/+xz5SXl+PDDz/ExIkTzc6mYKnPiooK/Oc//0FwcDDmz58PlUplcZ0WFRVhz549+OCDD3Dw4EGEhoZi5syZpvcXL16MyZMnW71ZuVBtTmoQtysxX/Mb27Vrh3bt2qGgoAAA0KVLF3Tp0sWqMsvLy/Hxxx/jwIEDZh1n7GFUVFTg+PHjiImJwYULFwTZBcrd3R0AkJSUBG9vb3h7e+OVV17Bb7/9hpdffrlOZVRUVOCtt94y22dVVRUA4MGDB3jw4AHGjh2L1atXPzY+aQlNmjRBs2bNsG
DBAnz99dem17VaLSZPnox27dqZnaqnsrLSIp96vd7kFQDmzJmD3bt34/Dhw2aVY8Td3R19+/ZFZGQkAGDixImIiYmBWq3Gn3/+CY1Gg7i4OIvKfhh7mTts7xAXdPkORKGhoZg0aRIvZcnlcqxbt87s7MERERG4du0a3NzcMHjwYCxatAgNGjQwXRD45JlnnoGrq6tV36Onp6dFPmfMmIGlS5fCy8sLERERWLJkCTp06MDbUlSWZU1jukB1kJ86dSoaNmyIOXPmmF2eu7u7RT7T09MxfPhwODs7o169eli4cCESExMtXvnVokWLGvX18P9PnTqFixcvonfv3gCqL/xOTk74+++/sXz5crPOQ7d6rBvEDS/wMfYqJJboi4iIwPDhw5GdnY2tW7eiefPmVvk0GAyorKwEy7LgOA6VlZWmhyQeHh6Ii4tDSkoKHjx4AKVSiW3btiE2Ntasc1iir1GjRujUqRP27t2LP/74A3369IGzs7PZ5RjZvn077t27BwC4fv06kpOT0a1bNwDVD62mT58Od3d3LFy40OLv05LjAgMDERISgm+++QZ5eXl47bXX4OrqarGG+Ph4ZGRkmMalv//+e3Ts2BFyuRyTJ0/Gvn37sG3bNmzbtg29e/fGSy+9hAULFph9Hnv/bdkLxPV0ZTKZ3d4GcRwHmUxm9nF79ux57DVrfP7555+YMGGC6e/OnTujc+fOSElJAQB89NFHmDdvHvr27Qu5XI6XXnoJCQkJdS7fUp/Tpk3DtGnTarxmjc+zZ89i2bJlqKiogJ+fHwYMGIDJkycDAM6dO4djx45BJpPVGL5YuXIlOnXqVKfyLfXZuXNnXL9+/bHXLfXarVs3TJ06Fe+88w4qKirQsWNHfPbZZwAALy8veHl5mT7r7u4ODw8Ps3vnlnolESIzR1y6dMkuB/2dnJzQqlUr3sqjPsWFb58AWV6lCpH3A/Z6ReZbF/UpLkLoIsmrVCEy6Mrlcrsb9DemuuYT6lM8hPAJkOVVqhAZdM1ZsmpL+NZFfYqLELpI8ipViAy6Li4udndllsvlcHHh97km9SkeQvgEyPIqVYgMugAQEBBgN7dpDMMgICBAkLKpT9sjpE+ALK9ShNig6+npaRfjY8bxMGuXYD4J6tO2CO0TIMurFCE26AJASEiIXTTchzfOEQLq03bYwidAllepQXTQdXFxQVhYmGiNl2EYhIWFWbWqqi5Qn7bBVj4BsrxKDaKDLgB4e3vD39/f5o2XYRj4+/ub9nMVGupTWGztEyDLq5QgPugCQFBQEHx8fGzWeBmGga+vL4KCgmxyPiPUpzCI5RMgy6tUIHIZcG1wHAelUomSkhJB92Yw9hKCgoJEuTWkPvlFbJ8AWV6lAA26j6BSqZCfnw+O43htwAzDmMbB7OG2jPq0DnvzCZDl1ZGhQbcWWJaFQqGwKmvwwxin1oSEhNjVJHLq0zLs1SdAlldHhQbdp6DRaFBcXGxKL23OV2W8/ZLL5QgICLDruYzU5z/jSD4Bsrw6GjToPoX9+/fjpZdewr1796BWq6FWq6HVak2ZYx/F+LpMJoNcLoefn59D9Q5YlkVpaSn1+X84uk+ALK+OAg26T2Dnzp0YOXIkdDod8vLyakwC1+l00Gq1MBgMpsbq5OQEmUwGV1dXEVXzC/UpLZ8AWV7tFXopq4WVK1dixowZqKqqgqenJ5RKZY2g6+rqSkQDpT6lB0le7RUadB/h8uXLmDRpkukWzMXFBUqlUmRVFApFKtDFEY/QsmVLnD59Gs899xxcXFyg0Who0KVQKLxBe7q10LZtW9y9exdnz57FuXPnaiQmpFAoFGugQbcWdu/ejXbt2iEyMhKRkZFiy6FQKBKCDi/Uwpo1a/Daa6+JLYNCoUgQOmXsEW7duoXOnTsjPz+fZjilUCi8Q3u6j5CSkoJRo0bRgEuhUASB9nQfQq/XIzw8HPv27UPbtm3FlkOhUCQI7ek+xKFDhxAUFEQDLoVCEQwadB+CPk
CjUChCQ4cX/o87d+4gIiICubm5dM9QCoUiGLSn+39s2LABw4YNowGXQqEICg26qN7WLjk5Ga+//rrYUigUisShQRdAZmYmDAYDevToIbYUCoUicWjQxf8eoNFkexQKRWiIf5CmUqnQpEkTXLlyBQ0bNhRbDoVCkTjE93Q3b96MPn360IBLoVBsAvFBlz5Ao1AotoTooHvhwgXk5+dj4MCBYkuhUCiEQHTQTU5ORlJSEpydncWWQqFQCIHYB2mVlZUICwvDqVOn0LRpU7HlUCgUQiC2p7tr1y60a9eOBlwKhWJTiA26ycnJdHMbCoVic4gcXqDZISgUilgQ2dNNSUnB6NGjacClUCg2h7ieLs0OQaFQxIS4ni7NDkGhUMSEuKBLs0NQKBQxIWp4gWaHoFAoYkNUT5dmh6BQKGJDTNDlOA5r1qyhm9tQKBRRISboZmZmguM4mh2CQqGICjFBl2aHoFAo9gARD9KkkB1Cp9NBq9XCYDCA4zgwDAMnJyfIZDK4urqKLY83SPEJkOOVFJ91xUVsAbbAEbNDsCyL0tJSqNVqaLVaU2N9FOPrMpkMcrkcfn5+cHFxnGolxSdAjldSfFoKET3d7t27Y/bs2XjxxRfFlvKPaDQaFBcXQ61WA6humHXF2LDlcjkCAgLg6ekpiEY+IMUnQI5XUnxai+SD7vnz5zFo0CDk5uba9WblLMtCoVBArVab1VifBMMwkMvlCAkJsaveAyk+AXK8kuKTLyQfdKdNmwa5XI758+eLLeWJqFQq5Ofng+M4XhqtEYZhwDAMwsLC7GJuMik+AXK8kuKTTyQddI3ZIU6fPo3w8HCx5TwGx3FQKpUoKSnhtcE+CsMw8Pf3R1BQkCizN0jxCZDjlRSfQiDpKWPG7BD2GnALCgoEb7TGc5WUlKCgoEDwc9V2bhJ8Gs9PgldSfAqFpIOuPa9AUyqVKCsrs1lD4jgOZWVlUCqVNjmfEVJ8AuR4JcWnUEg26N68eRNnz55FfHy82FIeQ6VS2aSX8CjGXoNKpbLJ+UjxCZDjlRSfQiLZoGuv2SFYljU9eBADjuOQn58PlmUFPQ8pPgFyvJLiU2gkGXT1ej1SUlLsct9chUIh+tgUx3FQKBSCnoMUnwA5XqyHSZIAACAASURBVEnxKTSSDLoHDx5EcHCw3WWH0Gg0vM1ltAaO46BWq6HRaAQpnxSfADleSfFpCyQZdO31AVpxcbHojdYIx3EoLi4WpGxSfALkeCXFpy2QXNC9c+cOMjIyMHLkSLGl1IBlWdPySHtBrVbzPj5Gik+AHK+k+LQVkgu6qampSEhIsLtVLKWlpRYfGx8fj6ysLB7V/A9rdFlbnpC+HoVvn+aUaUufAK1Te0dSK9I4jkOrVq2wevVq9OzZU2w5Nbhx44bV41Acx2H58uXYvXs3NBoNnnvuOcyaNQvNmjWzuExPT080bdrUKl0PY4nPv//+G0uWLMGlS5dw//59nD9//rHPpKenY+XKlVAqlahfvz4WLFiATp061fkcfPsEzPeanp6O7777DsXFxXBzc0PPnj3x4Ycfol69eqiqqsKCBQtw8uRJlJWVoVGjRpg6dSpiYmLM1mUPdQoAeXl5+PTTT/HHH3/Azc0NCQkJmD59eo3P5ObmYvjw4ejfvz8+/fRTs8oXok5tgaR6uidOnLDb7BBardbqMg4cOIBdu3Zh3bp1OH78ONq1a4ePPvpIdF3Wlufi4oKBAwdi3rx5tb5/4sQJfPXVV5g/fz5OnjyJdevWISwsTHBdfJfZoUMHpKamIjMzE+np6WBZFsuXLwdQfQsfFBSElJQUZGZm4t1338X/+3//DwUFBYLrEqI8nU6HN998E127dsWvv/6Kw4cP17rL38KFCxEZGWkzXfaApIKu8QGava3R1ul0Vj2EGDhwIDIzM1FQUIAOHTqgUaNGcHZ2xuDBg3H9+nWrtHEcB51OZ/Zx5eXljx1nrk+jr/DwcAwfPvyJPfbvvvsOb7/9Ntq1aw
cnJyc0bNjQ7L2RLfWp0+lQXl5e6+t19Wr0GRQUBD8/P9Przs7OuH37NoDqXtukSZMQGhoKJycnxMbGIjQ0FJcuXTJbs6Ve79+//9hrltbprl270KBBA4wbNw6enp5wd3dHREREjc+mp6dDLpejW7duZmsFLPcpNpIJuiqVCjt37kRiYqLYUh5Dq9XyciEYNGgQ8vLycOvWLeh0OuzZs8fqXj3DMBb1GMaPH4/Q0FCsWrUKVVVVAPjz+TB6vR4XL15ESUkJXnjhBTz//PNYuHCh2Zot9ZmWlobAwEDMnDkT9+7dM71uqdczZ84gKioK3bp1w+HDhzF27NhaP1dcXIzc3Fw8++yzZp/DEq83btyAv78/EhIScOXKFdPrlvrMzs5GSEgI3n77bcTExCApKQk5OTmm98vLy7FixQq89957ZpdtxNI6FRvJbFb5008/4fnnn0eDBg3ElvIYBoOBl3ICAwPRsWNHDBkyBM7OzggKCsKaNWus1paXl2f2D6uoqAh3797FtGnT8OGHH2Lq1KmYPHmyVVpq4969e2BZFocOHcL69evh4uKCKVOmYNWqVZgyZUqdy+E4ziKft27dgl6vx7Jly7B8+XKMGjUK8+fPh5eXl7lWAAAdO3ZEZmYmioqKsH37doSEhDz2GZ1Ohw8++ABDhw61aMzSkjq9fv06ZDIZ9uzZg/T0dPTs2RMLFy7Ec889Z/b5ger2kZWVhWXLlqF79+5IS0vDlClTsHfvXri6uuLbb79FQkICgoKCLCrfCF+/LVsimaCbnJyMuXPnii2jVvh6Vrly5UpcuHABhw4dQkBAAPbt24fXX38dO3fuhIeHh0VlVlRUYNWqVTh+/LhZx+Xl5ZmOr6iowNy5cxEVFYXQ0FCLdDwJd3d3AMCoUaMQGBgIAEhMTDQ76FZWVmLlypVm+ywtLQXLsqY6XLt2LViWxTfffGNWOY/SsGFD9OjRAzNnzsSWLVtMrxsMBnz00UdwdXW1eLxeo9Fg9erVZnmtrKxEZWUlDAYDKisrkZGRgZs3b+LPP/+0SIO7uzs6dOhgehA4fvx4rFq1Cjdu3ADHcTh58iS2bt1qUdkP44jzACQRdLOzs6FQKDBw4ECxpdQKX7fcV69eRVxcnKl3EB8fj88//xw3btxA69atLSrTy8sLS5YsgY+Pj1nHDRo0CIcOHYKrqysSExMxd+5ceHl5WfTg52n4+PigYcOGNb5DS75PmUxmkc+VK1diypQpcHV1RVRUFL744gt07NgRZWVlZmt4FL1eb7p4AdUBZM6cObh37x6+++47i5M21qtXz2yvly9fRvv27eHq6oqGDRviiy++QHx8vMXzc1u0aIFz587V+l5WVhYUCgX69+8PoPoiYTAYMGLEiBoXoLpgb89v6oIkgm5ycjKSkpLsNh2PkxM/Q+eRkZE4ePAg4uLi4O/vj59//hksy6JRo0Y21xcdHY3GjRtj7ty5pltkS3+gHMehqqrK9FCksrISDMPAzc0NQPXFZdOmTejRowdcXFywYcMG9OrVy+zzWOKzefPmGDBgAObPn4+OHTtaVda+ffvQqVMnBAcHQ6FQYNmyZTUeIs2fPx83b97E6tWrrd6oyVx9gYGB6NatG6ZNm4b4+HjT8Za23cGDB5tmanTt2hUbN26Er68vmjZtisaNG2PQoEGmz65btw4KhQIff/yx2efh67dlSxw+6Gq1WmzcuNGmk8/NRSaT8XIbNGHCBNy7dw+vvPIKKioq0LhxYyxdutSqhSAcx1n0A589e/Zjr1nqU6FQIC4uzvR3586dERISggMHDgAA3nrrLdy/fx9DhgyBm5sbBg4ciDfffNOsc1jqs1+/fujXr99jr1vi9caNG/jqq6+gVqshl8sRExODadOmAaj+DrZu3Qo3Nzf07t3bdMycOXMwePBgs85jideAgAD897//fex1S+s0PDwcixcvxvz581FSUoKWLVti+fLlcHV1haura43hME9PT7i5ucHf39+sc1
hap2Lj8IsjfvrpJyQnJ+PQoUNiS3kqly5dsstBfycnJ7Rq1Yq38kjxCZDjlRSftsLx+uaPkJycbJdbOD6KvV6R+dZFik+hyuQDWqf2jUMH3Zs3b+LcuXN2mR3iUeRyud0N+htTXfMJKT4BcryS4tNWOHTQXbt2LUaNGuUQV7yHVyLZE3zrIsWnUGXyAa1T+8Zhg649Z4eoDRcXF7u7Msvlcri48PsslRSfADleSfFpKxw26B44cAChoaF2lx0CqJ5R0bt3b7Rp0watWrXCM888Ax8fH5w/f95ubtMYhkFAQIAgZQcEBBDhEyDHKyk+bYFjXipQvbmNvfZy3d3doVQqcfXqVdNrPj4+iIqKQnl5uehpT4zjYZ6enoKU7+npCblcLnmfADleSfFpCxyyp1tUVIQjR47YXXYIIzdv3qyx9aCHhweOHj0Kf39/hISEiN5jYBiG9+W6j0KKT4Acr6T4FBqHDLr2mh0iLy8Pb731Frp06YIePXqgdevWcHFxweLFi9G+fXsA1eNjYWFhojVehmEQFhYm+Oo9UnwC5HglxafQOFzQ5TgOycnJdpV4srCwEFOmTEG7du3g5+eHnJwczJs3D6tXr0ZSUtJjG7N4e3vD39/f5o2XYRj4+/vb7GJFik+AHK+k+BQShxvT/f3338EwDKKjo8WWguLiYnz++edYs2YNxo0bh8uXL9fYXDsqKgpRUVG1HhsUFAS9Xo+ysjKbjJExDANfX1+rt9IzF1J8AuR4JcWnUDhcT9f4AE3MsaX79+9j9uzZiIiIQHl5ObKzs/HVV1+Zlc3AODZli16DsZcgxpgcKT6N5yfBKyk+hcKh9l4oKytDkyZNkJOTI8pm5Wq1Gt988w2+/vprDB06FLNnz0Z4eLjV5apUKuTn54PjOF57DgzDmMbB7OG2jBSfADleSfHJJw4VdH/44QccPHgQ27dvt+l5NRoNVqxYgSVLlqBfv36YO3cuWrRowes5WJaFQqHgbUqOcWpNSEiIXU0iJ8UnQI5XUnzyhV0EXZ1OB61WC4PBAI7jwDAMnJycIJPJamzk3LVrV8ybN6/GXpxCotVqsXr1aixevBjR0dGYN2+exZuF1xWNRoPi4mLT3rTmVI/x9ksulyMgIMCu5zJqNBrcvXvXlPRRqj4BcryS4tNaRLmMsCyL0tJSqNVqaLVaU6B9FOPrMpkMLMviwYMHGDBggOD6dDodUlJSsGDBArRt2xY///wzOnToIPh5gepJ6I0bN7boO5LL5fDz83OI3kFJSQkiIyMxdepUTJ06VbI+KysrERUVBZ1Oh+zsbEnX6ezZs7FixQqUl5dL2qe12NTh03pxT7oqchwHjUYDhmGwefNmFBQUCHYlZFkWGzduxLx589CsWTNs2bIF3bt35/08dcHFxQWBgYGmvGB1vRtwBC5fvoyePXuivLwcer1esj5VKhUGDBiA7OxsNGvWTLJ1ynEc/v3vf2PFihXQ6/VgGEaSPvnCJkGXjzEf43Eqlcq08z5fYz4GgwFbtmzBJ598ggYNGiAlJQWxsbFWl8snxh33HZ3MzEzExcVBpVIBAJRKZY33peKzqKgIvXr1wq1btwBUJ/B8FCl41el0GDNmDPbt2weWZeHu7o7i4uIaM3mk4JNPBA+6Qjzd5DgOarUaOTk5Vj3d5DgOu3fvxpw5c+Dh4YHly5ejX79+kpmaYo9MnToVlZWVpr9v374tohrh+Oabb3D9+nXo9XoA1VmFpcixY8ewZcsW0yox474j5kyfJA3B5ulyHIfCwkLk5eWZbiv4Lt9gMCAvLw+FhYVmlc9xHNLT09GlSxfMmzcPixYtwsmTJ9G/f38acAXm+PHjmD17NmQyGZycnHD37l2xJQnCwoULcfDgQbi5ucHd3R0ajcaUeFNK9OvXD9evX0dYWBg8PT2hVqtx584dsWXZNYL0dDmOQ0FBgU1WrHAch5KSEuj1eoSGhv5j0Dxy5Ag+/vhjlJWVYd68eRg+fLhDZhR1VN
zc3HDt2jX85z//wb/+9a/HhhekAsMwyM/PR79+/fDjjz/i5MmTkr3FdnJyQnl5OYqKinDq1Cl07dpVbEl2jSBTxgoLC1FSUmLTLeCMK1eCg4Nrff/333/H7NmzkZeXh08++QQjR450+I0zHBGVSoXGjRvj6tWrkr8FjY2NxbRp05CQkCC2FEGZO3cuSktLsWzZMrGlOAS893RVKpXNAy7wvx6vl5dXjTHeP/74A7Nnz8bly5cxZ84cJCYmEjEtxV756aef0LdvX8kH3JycHFy9etXs9OmOhl6vx9q1a7Fv3z6xpTgMvN5XsyxremgmBhzHIT8/HyzLIjs7GwkJCYiPj8eQIUNw9epVTJgwgQZckVmzZo1d7RAnFMnJyUhMTJTskIKRgwcPIigoCO3atRNbisPAa9BVKBSi7ioP/G9GwoABA9CrVy/8/fffmDRpEtzd3UXVRQGys7NRWFiIgQMHii1FUHQ6HVJTU+02swmf2Ns2q44Ab90+jUYjeioPoDroNm/eHBcvXkT9+vVF1UKpSXJyMsaPHy/5sfSff/4ZzZo1Q0REhNhSBOXOnTs4fPgwkpOTxZbiUPAWdIuLi0UPuEacnJzw4MEDGnTtCK1Wi40bNyIrK0tsKYJDSu9vw4YNiI+Ph4+Pj9hSHApehhdYljUt7bWE+Ph43n+MarUaLMvyWibFcnbt2oX27dvzshWmPVNQUIDff/8dL7/8sthSBIXjOLtODmvP8BJ0rV1ts2vXLnTp0sX096xZs6yVBEC6q4AcEVJ6f+vXr8eIESPg5eUlthRByczMhMFgQM+ePcWW4nDwEnT5GMstLCzEZ599ZlqjnpOTgyVLllhcnnGpMEV8bt68ibNnzyI+Pl5sKYJiMBiQnJxMRO/PHjK4OCq8jOlqtVqrjh84cCA++eQT9O/fHx999BEuXrwIDw8PTJo0SVRdFH5Yu3YtRo8eDZlMJrYUQTl69Cjq1auHzp07iy1FUFQqFXbs2IErV66ILcUhsTro6nQ6QR6gOTk5WX0V5TgOOp1O8nMl7Rm9Xo+UlBT88ssvYksRHOMQitR7f5s3b0bfvn0lkyjS1lg9vKDVanlpZHfv3sWhQ4ewaNEidOnSBS+99BLWrl1rVZkMw9DersgcOHAAISEhaNu2rdhSBKW0tBQ///wzRo8eLbYUwaEP0KzD6p6uwWDgQwcCAwMxdOhQ098RERG8zHPkSx/FMkh5gLZx40a88MIL8Pf3F1uKoFy4cAEFBQWSX+AiJFb3dIUYWli4cCFvZdnL3GESKSoqQkZGBkaOHCm2FEEhafrUmjVrMH78eLqc3gqs/ubsffzK3vVJmQ0bNiAhIUFyKbQf5cyZM1CpVOjTp4/YUgSlsrISGzduxKlTp8SW4tBY3dO1971o7V2fVCGt9zdhwgTJt7Vdu3ahbdu2aNq0qdhSHBqre7oymczqW/gDBw5YK6NWOI6T/DQle+X3338HAPTo0UNkJcKi0WiwefNmZGdniy1FcEjZIU5orL40u7q62u0tPMMwdLqYSJAyfWrbtm2Ijo5GWFiY2FIExbjAReobstsCXu6H7LU3aa+6pI5KpcLOnTuRmJgothTBIWUIJSUlBaNGjaK/KR7gJejK5XK769FotVps27YN27dvr5F9liI8P/30E55//nk0aNBAbCmCkpOTg5ycHCKyQ6SkpNChBZ7gJej6+fnxUQyvyGQyBAcH49tvv0VISAgmTpyIzMxMOoXMBpDS+yMtO4TUF7jYCt4SU96+fRsqlYqPonjB29sbjRs3BgDk5uYiLS0NqampMBgMSExMxNixY/HMM8+IK1KCZGdn48UXX8StW7ckvVm5TqdDo0aNcOzYMclvVv7SSy9hwIABeOutt8SWIgl4m+MSEBBgN0MMDMMgICDA9HeTJk0wa9YsXLlyBWlpaVAqlejcuTNiY2ORnJyMsrIyEdVKi+TkZCQlJUk64ALV2SFatGgh+YB7584dZGRk4NVXXxVbimTgLe
h6enraxdguwzCQy+Xw9PSs9b1u3bphxYoVKCgowLRp07Bv3z40btwYr776KtLT0+nG51ZgzA6RlJQkthTBIWUIJTU1FfHx8ZJf4GJLeBteAKozSOTk5Ii634GTkxMiIiLM6mndu3cPmzdvRmpqKnJzczF69GgkJibSMSwz+emnn5CcnIxDhw6JLUVQCgoK0KZNG+Tl5Ul6s3KO49CqVSusXr2ablbOI7wuoXFxcUFYWJhovV2GYRAWFmb2rW39+vUxadIknDx5EkePHoVMJsOQIUPQvn17LF26FEqlUiDF0oKU3t+6deuIyA5x4sQJcBwn+QUutobXnq6RwsJClJSU2HSmAMMw8Pf3R3BwMC/lGQwGHDt2DKmpqdi1axeioqKQmJiIYcOGwcPDg5dzSIkbN26ga9euyM/Pl/RcToPBgGbNmmHLli2S36w8KSkJrVq1wnvvvSe2FEkhSNDlOA4FBQUoKyuzSeBlGAa+vr4ICQkRpJf94MED7Nq1C6mpqcjKysLw4cORmJiInj17Sn69fV2ZPXs21Go1vv76a7GlCMqRI0cwffp0nD17VvTnF0KiUqnQuHFjXL16FQ0bNhRbjqQQJOgC1YFXqVQK3uM19nCDgoJs8iMoKCjApk2bkJqaivLycowdOxZjx45F8+bNBT+3vaLX69GkSROkp6ejTZs2YssRlFGjRiEqKgrvvvuu2FIEZdWqVdi/fz927NghthTJIVg3jWEYBAcHo1GjRryk3qmtfCcnJzRq1AjBwcE263WEhobivffeQ3Z2Nnbu3Iny8nLExMQgOjoa33//PUpKSmyiw544cOAAQkNDJR9wS0pK8MsvvxCRHYKUzefFQLCe7sOwLAuFQsFL1mDgf9PCQkJC7GIzZZ1Oh0OHDiE1NRX79+9Hv379kJiYiEGDBjn0aiWdTgetVguDwQCO40wXOplMVsPX8OHDERcXhzfffFNEtZZTV5/Lly9HZmYmNm3aJKJa66iL1/Pnz2PQoEHIzc112PnWda1TMbBJ0DWi0WhQXFxsSo1uzqmNPVm5XI6AgIBa5+HaA/fv38e2bduwfv16XL16FSNHjkRiYiI6depk92OALMuitLQUarUaWq3W1Fgfxfi6TCYDwzCIjY3FhQsXHGYupyU+5XI5EhIS8J///Ad9+/YVQbVlWOL1119/hUKhwOzZs0VQbBmW1qmfn5/NO242DbpGHOkLsobr16+blh+7u7sjMTERo0ePRqNGjcSWVgNrLoZA9RN9X19fu74YAtZf9FmWha+vLwIDA+3aJ2CdV71eD2dnZ3h7e0u+TgHbd+RECbqPYs+3AnzAcRxOnDiB1NRUbNu2DR06dEBiYiKGDx+OevXqiaZL6sM+RkjxCZDj1ZF92kXQJQmtVou9e/ciNTUVv/32G4YOHYrExET06dPHpuNnKpUK+fn54DiO19klDMOYFqnYw3ADKT4Bcrw6uk8adEXkzp07+PHHH5GamoqioiKMGTMGiYmJaNWqlWDnlOpUvkchxSdAjlep+KRB1064cOECNmzYgLS0NAQHByMxMRGvvvoqAgMDeTuHGItWfHx8EBoaatMfKSk+AXK8SsknDbp2hl6vx5EjR5Camoq9e/eiV69eSExMxJAhQ+Du7m5V2VJYnl0XSPEJkONVSj5p0LVj1Go1duzYgdTUVJw7dw4jRoxAYmIiunfvbvbVV6VSIS8vT5TMGQzDoFGjRjYZDyTFJ0COV6n5pEHXQbh9+zY2btyI9evXQ6/XIzExEWPGjEF4eLjpM/v27YO/vz+io6NrHGsvW262aNFC0CfDpPgEyPEqRZ90txYHoXHjxvjwww9x+fJlbNq0CXfu3EHXrl1N2S/u37+PN954A/369cPp06drHKtQKETPDcdxHBQKhaDnIMUnQI5XKfqkPV0HpqqqCunp6UhNTcWBAwdQWVkJlmUhl8tx4sQJREZGQqPR4ObNm6I3XKD6Vi08PFyQSeik+ATI8SpVn7Sn68C4ub
lh2LBh2L59O4YMGQK9Xg+geiy4c+fOyMrKQnFxsV00WqC6x1BcXCxI2aT4BMjxKlWf9rPEhGIVR44cAVCdq87X1xfOzs64cuWK3S3hVKvVYFmW13FAlmVNy0DtBSF8AuR4lbJPGnQlwoULF+Dh4VFjWfHdu3dx584dEVXVTmlpKa/zj0tLS3kri0/49mks0x6hdVp36PCCRAgMDHxsHwe+1qU/zP79+zF06FB069YNw4YNQ0ZGhlnHcxzHew/GWp86nQ7Tp0/HwIED0aZNG2RlZT32mUuXLmHcuHGmh5dpaWlPLVMIn4B1Xq9fv45//etfiI6ORnR0NF5//XVcv37d9H5KSgoSEhLQrVs3xMXFISUlpU7l2mOdAkBFRQUWLFiAmJgYREVFYdy4cY99RqfTYejQoXj++ef/sTy+fNKeroTRarW8lldUVIQPP/wQy5YtQ8+ePfHbb79hxowZ2L9/P+rXry+aLj7K69ChA8aMGYMZM2Y89l5paSkmTpyI9957DwMGDIBOp0NRUZFNdPFZZmBgIJYuXYqQkBAYDAb89NNPeO+990zZITiOw8KFC9GiRQvk5eXhrbfeQlBQEAYNGiSoLqHKmzdvHvR6PXbv3g0fHx9cuXLlsc+kpKTAz88PDx48sJku2tOVKDqdzqKeQps2bXD79m3T37NmzcKyZcsAVAddb29vxMTEgGEY9OrVCx4eHsjLyzPrHBzHQafTma2tuLj4sak7dfX5NF+urq4YO3YsOnbsWGvOu9TUVERHR2Pw4MFwc3ODl5cXmjZt+o/ntNRnVVVVrQGiLl6f5tPb29u0rJXjODg5OdWouwkTJqBVq1ZwcXFBeHg4+vTpg7Nnz9ZJs6VeL1y48NgcXD7q9MaNGzh69Cjmzp0Lf39/ODs7o3Xr1jWOz8/Px759+8zKkGGpz4ehQVeiaLVa3teMt27dGuHh4fj111+h1+uRkZEBV1dXtGjRwqxyGIaxqMfw+eefo3Hjxnj99deRn58PQBifj5KdnQ0fHx+MGTMGsbGxmDx5MgoLC//xOEt9Hj16FC1btkTv3r1rDHXw5TU6OhqdO3fG4sWLnxhwOI7DmTNn0KxZszqVaYnXyspKtGnTBk2bNsWWLVtMwZcPnxcuXEBwcDBWrFiBmJgYJCQk4NChQzU+s3jxYkydOtWs7NWW1unD0OEFiSLECh5nZ2cMHToU77//PqqqquDq6oovv/zS7BkSer0ef/75JyoqKsw67tq1a9Dr9Vi/fj02bNiA3r1749tvvzWrDEsoKirC5cuXsWrVKjRv3hxLly7FzJkzsWHDhqceZzAYLPKZlZUFDw8PHDt2DDExMWjRogU+//xzREVFWWPDxIkTJ6DRaLBnzx6EhITU+pnvvvsOBoMB8fHxdSqTZVmzvVZVVcHJyQm5ubkYN24c3n33Xbz//vt47bXX6lzGkygqKsK1a9fQv39/HDlyBOfOncM777yDZ599Fk2bNkVGRgb0ej2ef/75Wsfwn4a1vy0adCWKEPMbMzMzsXTpUqSkpKBly5a4dOkS3n33XaxcuRLPPfdcncvR6XQ4ePBgnW9djRhvuVmWBcMwOHz4MM6dO/fYbSPfuLu7o2/fvoiMjAQATJw4ETExMVCr1ZDL5U88jmVZi3zevXvXdAtbWVmJ8+fP44cffkD37t0tN/EInp6eGDFiBHr16oXdu3fXGJPftGkT9u7di3Xr1sHNza1O5VVVVZntVa/Xm9qpVqtFZWUlvv76a0yYMME8M7Xg7u4OFxcXvPnmm3BxcUGXLl3QtWtXnDhxAkFBQVi6dCm+++47i8q29rdFg65EsfT2zMPDo0Zv5d69e2jYsCEA4OrVq+jUqZMpyEVGRqJNmzY4efKkWUFXJpPh/fffh4+Pj1naZsyYgW+++Qbu7u6YOnUq3nvvPTg5OaGgoMAqX/9EixYtanyfdf1u3dzcLPKZnp6OYcOGQSaTYfDgwVi0aBGaN2+OsrKyfzzWHJ8GgwFarR
Z37twxBd2dO3ciOTkZ69atQ1BQUJ01e3p6mu1Vq9XC09MTXl5eaNGiBb788kv07t0bKpWqRxjT1QAABEtJREFUTsc/zWttQ17Gert9+zYUCoVpNoNOp0N5eTl69+6NjRs3IjQ09KnntXbog47pSpTaHgjVhYiICPzyyy/Q6/U4fvw4/vjjD9N7rVu3xpkzZ0w9zsuXL+PMmTNmj+laqi82NhYffPAB8vPzsWjRIvj5+dW5nKf5Aqp7apWVlQCqf4SVlZWmHk18fDwyMjJw5coV6HQ6fP/99+jYseNTe7lGLPEZERGBpKQkZGdnY+vWrWjevHmdy3qazxMnTuDy5cvQ6/UoLy/HF198AW9vb9NDwX379uGbb77B6tWrLcrjZ65XNzc3TJgwAXv37sWff/6JPn36mFJ11YWnee3UqROCg4OxZs0asCyLs2fP4vTp04iOjkazZs1w6NAhbNu2Ddu2bcO8efNQv359bNu2rU4XGkt/W0bo3gsSRafTIScnx+xboYsXL2LWrFkoLCxE3759odfrERYWhilTpgCovvVMS0vDvXv34Ofnh1dffbXW+Y9Pg2EYtGjRgpf8d3X1+U++Bg4c+NjMiP3795t6PZs3b8aqVatQUVGBjh074uOPP/7HHyifPoG6eX2azwMHDuDbb79FUVERZDIZIiMjMXXqVERERAAA4uLiUFRUVEPv4MGDMWfOnH/UZo91eu3aNcydOxd///03goODMWXKlFrn42ZlZeGDDz6o05xzPnzSoCthLl26JOqWeE/CycmJ15REpPgEyPEqZZ90eEHCmDMVxpbwrYsUn0KVyQe0TusODboSRi6Xi5Ys8UkYU13zCSk+AXK8StknDboSxs/PT2wJtcK3LlJ8ClUmH9A6rTs06EoYFxcXQXpb1iCXy3nf7pAUnwA5XqXskwZdiRMQEGA3t2kMwyAgIECQsknxCZDjVao+adCVOJ6ennYxPmYcDxNqU3VSfALkeJWqTxp0CSAkJMQuGu4/rfSxFlJ8AuR4laJPGnQJwMXFBWFhYaI1XoZhEBYWBmdnZ0HPQ4pPgByvUvRJgy4heHt7w9/f3+aNl2EY+Pv7w9vb2ybnI8UnQI5XqfmkQZcggoKC4OPjY7PGyzAMfH19zdo4hQ9I8QmQ41VKPukyYMLgOA5KpRIlJSWCprc29hKCgoJEuTUkxSdAjlep+KRBl1BUKhXy8/PBcRyvDZhhGNM4mC1vtZ8EKT4Bcrw6uk8adAmGZVkoFAresgYbp9aEhIQIsjDAUkjxCZDj1ZF90qBLgUajQXFxsSm9tDlNwnj7JZfLERAQIOj8VGshxSdAjldH9EmDLsUEy7IoLS2FWq2GVqsFx3G1jmkZX5fJZJDL5fDz87OrXtA/QYpPgByvjuSTBl3KE9HpdNBqtTAYDKbG6uTkBJlMxtvG3PYAKT4Bcrzas08adCkUCsWG0Hm6FAqFYkNo0KVQKBQbQoMuhUKh2BAadCkUCsWG0KBLoVAoNoQGXQqFQrEhNOhSKBSKDaFBl0KhUGwIDboUCoViQ/4/OWoKyLZR9gQAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
" 'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],\n",
" 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],\n",
" 'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
" 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],\n",
" 'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],\n",
" 'c64': [2, 3], 'c128': [3, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 5))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ng3SCvpnA2-p"
},
"source": [
"Again, the connections added here are precisely the promotion semantics implemented by NumPy for mixed-integer promotion."
]
},
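{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of that claim, the following minimal sketch queries NumPy's `np.promote_types` for each signed/unsigned pair of equal width below 64 bits. The `pairs` and `promoted` names are illustrative and not part of the original design document:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Signed/unsigned pairs of equal width, excluding the 64-bit case.\n",
"pairs = [(np.uint8, np.int8), (np.uint16, np.int16), (np.uint32, np.int32)]\n",
"\n",
"# Map each pair of type names to the name of NumPy's promoted dtype.\n",
"promoted = {(u.__name__, s.__name__): np.promote_types(u, s).name for u, s in pairs}\n",
"print(promoted)\n",
"```\n",
"\n",
"Under NumPy's rules, each pair promotes to the signed integer type with twice the width: `int16`, `int32`, and `int64` respectively."
]
},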
{
"cell_type": "markdown",
"metadata": {
"id": "EuzHht0CjbWf"
},
"source": [
"### How to handle `uint64`?\n",
"\n",
"The approach to mixed signed/unsigned integer promotion leaves out one type: `uint64`. Following the pattern above, a mixed-integer operation involving `uint64` would have to produce `int128`, but this is not a standard dtype.\n",
"\n",
"NumPy's choice here is to promote to `float64`:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "ZKUCX9ryjzhN"
},
"outputs": [
{
"data": {
"text/plain": [
"dtype('float64')"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(np.uint64(1) + np.int64(1)).dtype"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yV-9Uka-j98S"
},
"source": [
"However, this may be a surprising convention: it's the only case in which promotion of integer types does not result in an integer.\n",
"For now, we will leave `uint64` promotion undefined, and return to it later."
]
},
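{
"cell_type": "markdown",
"metadata": {},
"source": [
"One practical consequence of promoting to `float64` is that large `uint64` values can no longer be represented exactly, because `float64` carries only a 53-bit significand. The sketch below is a small illustration under that assumption; the variable names are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# NumPy's promotion result for the uint64/int64 pair.\n",
"result_dtype = np.promote_types(np.uint64, np.int64)\n",
"\n",
"# The largest uint64 value needs 64 bits of precision, more than float64 provides.\n",
"big = np.uint64(2**64 - 1)\n",
"\n",
"# Convert to float64 and back to a Python int, then compare with the original.\n",
"roundtrip_exact = int(np.float64(big)) == int(big)\n",
"print(result_dtype, roundtrip_exact)\n",
"```\n",
"\n",
"Here the round trip is not exact: the value is rounded during conversion, so the integer recovered from the `float64` differs from the original."
]
},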
{
"cell_type": "markdown",
"metadata": {
"id": "BBNC59v_pTSY"
},
"source": [
"## Mixed Promotion: Integer and Floating\n",
"\n",
"When promoting integers to floating point, we might start with the same thought process as mixed promotion between signed and unsigned integers. A 16-bit signed or unsigned integer cannot be represented at full precision by a 16-bit float, which has only 10 bits of mantissa. Therefore, it might make sense to promote integers to floats represented by twice the number of bits:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"cellView": "form",
"id": "GT5uPYlMs3sw",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAEeCAYAAAApRMZ1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2deVgTV/v+7wkBAhgQhMqmrVaxVtwXBIv72qoVq9atWK1dtFZ927d2sWqtW1fb2lr7KgFBtO7ivlttLS60aHFBcReIoAhCMAQyyfz+4Jd8RVBJMpNJ5pzPdfW6ajJz5r45J8+cOXPOeRiO4zhQKBQKxS7IxBZAoVAoJEGDLoVCodgRGnQpFArFjtCgS6FQKHaEBl0KhUKxIzToUigUih2hQZdCoVDsCA26FAqFYkdo0KVQKBQ7QoMuhUKh2BG52AIotUOv10On08FoNILjODAMA5lMBoVCAVdXV7HlUayAlDolxWdtoUHXQWFZFkVFRdBoNNDpdObG+jCmzxUKBZRKJXx9fSGX02p1REipU1J8WgtDN7xxLLRaLQoKCqDRaABUNszaYmrYSqUS/v7+8PT0FEQjxTJIqVNSfNoKDboOAsuyUKvV0Gg0FjXWR8EwDJRKJYKDg4noPTgipNQpKT75ggZdB6CkpAQ5OTngOI6XRmuCYRgwDIPQ0FB4e3vzVi7lyZBSp6T45BMadEWE4zjk5eWhsLCQ1wb7MAzDwM/PD4GBgTWOrVH4g5Q6JcWnENCgKxIcxyE3NxfFxcWCNloTDMPAx8cHISEhkmm8jgYpdUqKT6Gg83RFIi8vz26NFqj8oRQXFyMvL88u1yMRUuqUFJ9CQYOuCJSUlAj+WFYTHMehsLAQJSUldr0uCZBSp6T4FBIadO0My7LmFw9iwHEccnJywLKsKNeXIqTUKSk+hYYGXTujVqtFa7QmOI6DWq0WVYOUIKVOSfEpNDTo2hGtVsvbXEZb4DgOGo0GWq1WVB1SgJQ6JcWnPaBB144UFBSI3mhNcByHgoICsWU4PaTUKSk+7QENunaCZVnz8khHQaPROP34mJiQUqek+LQXNOjaiaKiIkHKnTlzpk3nC6WLBGz52xUWFmLGjBmIjIxEVFQUPvroo2rHFBcXo2vXroiNjbWbLlvLu3PnDt577z307NkTLVu2RG5ubpXvv/32W7z00kuIiIjAoEGDsG3btirfnzhxAiNGjEDnzp3Rv39/bNiwgRddjgQNunaCz/GwsrIyzJ07F/fu3QNQ+YJj7ty5FpdvGh+jWIctdfqf//wH/v7+2LdvH44cOYLXX3+92jHff/89GjVqZFG5QtSpJT4ZhkGXLl2wePHiGr/38PDATz/9hGPHjmHBggX48ssvcfr0aQCVW0BOnz4dw4YNw7Fjx/Dtt9/im2++wcWLF6uV48xtV3q7STgoOp3O6nPz8vLw5ZdfIj09HUajEQMGDMDo0aMxf/58nDt3DkuXLsWkSZOsWq1jiy7Sqc3frqa669GjB/Ly8hAfHw8XFxcAQPPmzaucd/r0aVy6dAnDhg3Dli1beNfFR3k1eZs5cyZGjhz5yEf/d9991/z/rVq1Qvv27fHvv/+iTZs2KC4uRmlpKQYNGgSGYRAeHo7GjRvjypUraNasWa11OTq0p2sH9Hq91T0ig8GAd999F0FBQdizZw8OHjyIAQMGmL9/cFNoa+A4Dnq93qpzSWD79u2IiIjAgQMHqtRhber0UXX377//4plnnsHMmTPxwgsvYOTIkUhLS6ty3sKFC/Hpp59adSO1pk7z8/Px/PPP43//+x8qKirMnz/K55PaZW3Q6XQ4e/Ysnn32WQCAv78/BgwYgJSUFBgMBpw+fRq3bt1Cu3btajzfWdsuDbp2QKfTWb1m/MyZM7hz5w4++OADeHp6wt3dHc2bN8fq1avx2WefoX379pg8eTLi4uKsCuwMwzhtj8EeZGdnIz09HUOGDE
GbNm3Mwbc2dVpT3bVr1w75+flITU1Fp06d8Pvvv2PcuHGYNm2aeYxy9erVaNmyJVq0aGGVZmvqtLCwEFevXsUHH3yA0NBQc/B9lM9HebOEefPmoVmzZujSpYv5sxdffBG//vor2rdvj9dffx3vvfceAgMDefPpCNDhBTtgNBqtPjcvLw9BQUFV9hX18PDA559/bv53cHAwZs+ebVX5paWlGD9+PP744w+rNUqZsrIysCwLlmWRkZGBPn36YOTIkfj111+feG5NdQcACoUCISEhGDp0KABgwIABWL58OU6dOoXw8HCsWbMG69ats1qzRqOxuE5ZlkVFRQXKy8tx//59vPPOO5g/fz7Onj1rkbfa8t133+HSpUuIj483B/WrV69ixowZ+P777xEZGYkbN25gypQpeOqpp9C1a9cay7HltyUWNOjaAVteoAUGBiIvLw8sy9bYwBcsWGCLNHh5eWHx4sVQKBQ2lSNVEhISMGvWLMjlcri6uuL999/H+++/D4PB8MRzH1V3YWFhOHz4cJVjTYHH1IN8+eWXAQDl5eXQ6XTo3r07Dh48aB4Dfhx16tSxuE6zsrLQt29fsCwLFxcXjBgxArNnz35k231Su3wcS5cuxdGjR5GQkIA6deqYP798+TKefvppc8+3UaNG6Nq1K/78889HBl1HmTtsCTTo2gFbtqNr2bIl/P398cMPP2Dy5MlwcXHB+fPn0bZtW960+fj4wMfHh5fypEZQUBC8vLwwa9YsTJo0CR4eHgAqp3M9iUfVXa9evfDtt99i69atGDhwIA4ePIj8/Hy0bdsWXl5e2Lt3r7mMPXv2YNeuXViyZEmtAi5gXZ2aEkdOmDABs2fPRnBw8GN9Pq5dlpeXm29Kpt6zu7s7ACAuLg67du1CYmIi6tatW6XM5s2b48aNGzhx4gQ6deqEnJwcHDlyBOPHj3+sV2eD7qdrBzQaDbKzs61+FLp16xYWLVqE9PR0MAyDF198EZ988gkv2mQyGRo0aAClUslLeVKD4zgYjcZqAa+2dfqouvvnn38wf/585ObmolGjRpgxYwbat29f7fyUlBRs3rwZSUlJtdZsbZ0aDAaLfD7KW8uWLasde+bMGQCVwdrV1bVK7/jNN9/Em2++CaDyJvO///0ParUaderUwUsvvYTp06fX+KLYWdsuDbp2QK/XIysryyEfhRiGQVhYGJGpsG2BlDolxac9obMX7ICrq6vDPgYxDON0jdYRIKVOSfFpT2jQtROO+qLKUXU5A476t+NbFyk+7QUNunZCqVQ6XI/BlOqaYh2k1CkpPu0FDbp2wtfXV2wJNeKoupwBR/3b8a2LFJ/2ggZdOyGXyx3uzqxUKq2e3E4hp05J8WkvaNC1I/7+/g7zmMYwDPz9/cWW4fSQUqek+LQHNOjaEU9PT4cYHzONh3l6eoqqQwqQUqek+LQHNOjameDgYIdouCEhIaJqkBKk1CkpPoWGBl07I5fLERoaKlrjZRgGoaGhtV5SSnkypNQpKT6FhgZdEfD29oafn5/dGy/DMPDz84O3t7ddr0sCpNQpKT6FhAZdkQgMDISPj4/dGi/DMKhbt+4j9yal2A4pdUqKT6Ggey+ICMdxyMvLQ2FhoaBr2029hMDAQNHH5KQOKXVKik8hoEHXASgpKUFOTg44juO1ATMMYx4Hk8JjmTNBSp2S4pNPaNB1EFiWhVqt5i1rsGlqTXBwsNNOInd2SKlTUnzyBR3TdRDkcjkaNmyIOnXq4MiRIwAs36DZ1Dvw9vZGo0aN0LBhQ0k2WmfBVKebN29GRkaGuX4swRnq1OQzJycHR48elaxPvpCmKyclKysLnTt3RlFREVJTU9GkSRNoNBrodDpz1t+HMX2uUCigVCrh6+sr2cbqbHAch48++gjffPMNnn32WVy4cAFFRUWSrNNt27Zh2LBhYFkW5eXluHfvniR98oH0HToJaWlp6NOnD4qLi+Hu7o78/HxERkYiICAAQOVm0qaUKg+mXV
coFE65p6jUYVkW48ePx6ZNm8z/lsvlCAgIkFydrlixAtOmTYNer4erqytKS0sl6ZMvaNB1AIqLi9G1a1dzOmmGYZCXl1flGFdXVyIbqLPy+eefY82aNeY0N6b06g8ihTo9dOgQ3nrrLfO/FQoF8vLyquwAJgWffELHdB0AHx8fbN26FeHh4ZDJZNDr9VCr1WLLotjA1KlTMWfOHMhkMnh4eECj0dQqg7Cz0aVLF8THx6Nu3bpwdXVFWVlZtQ4DpSq0p+sg9O3bF0qlEomJibh16xZv2X4p4vDUU08hMjISrVu3xnfffYe9e/dKZp7pg7i7uyMmJgb/+c9/cPjwYezcuRONGzcWW5ZDQ6eMOQjnzp1Dnz59cPPmTSJeJpDAq6++im7dumHy5MliSxGUZcuW4dChQ9iwYYPYUpwCOrzgIKhUKrz++us04EqEgoIC7N27F6NHjxZbiuDExcVh4sSJYstwGugv3AEoLy9HcnIyUlNTxZZC4Ynk5GQMGjQIdevWFVuKoJw6dQp37txB7969xZbiNNCergOwbds2hIeHo0mTJmJLofAAx3HE9P5UKhUmTJjg9Nst2hPa03UAVCoV3njjDbFlUHji5MmTKC8vR9euXcWWIihlZWX47bffkJ6eLrYUp4IGXZG5ceMG0tLSsGXLFrGlUHgiLi4Ob7zxhiRnKzzI5s2b0aFDBzz99NNiS3EqaNAVmZUrV2LUqFHw8PAQWwqFB0pLS7Fp0yacO3dObCmCo1KpMGnSJLFlOB006IqIwWBAfHw8tm7dKrYUCk+sX78e0dHRCAoKEluKoFy+fBlnz57F4MGDxZbidNAXaSJy4MABBAQEoE2bNmJLofAEKS/Q4uPjMXbsWLi7u4stxemgPV0RoS/QpMX58+dx/fp1DBgwQGwpgsKyLFauXIn9+/eLLcUpoT1dkbhz5w727duHUaNGiS2FwhOkLHDZvXs3nn76abRo0UJsKU6JtFuHA5OcnIyXX35Z8pPnSaGiogKrVq0iYoGLSqUiYghFKGhPVwRMk+fp0IJ02LZtG1q0aCH5BS63bt3CkSNHMGLECLGlOC006IrAiRMnoNfrER0dLbYUCk+QMj6flJSEoUOHQqlUii3FaaHDCyJAyuR5Urh58yZOnjyJzZs3iy1FUExPaElJSWJLcWpo0LUzGo0GmzZtQmZmpthSKDyRkJBAxAKXP/74A25ubujcubPYUpwaGnTtzPr169GtWzcEBgaKLYXCA6YFLikpKWJLERzTEAp9QrMNOqZrZ0iZPE8KBw8ehL+/v+Qzfdy7dw/btm3Da6+9JrYUp4cGXTty7tw53Lx5E/379xdbCoUnSJk+9dtvv6Fv377mDL8U66FB146QMnmeFEzZIUhY4EKnOPIH/fXbCVN2iGPHjokthcITycnJGDx4sOQXuJw6dQoFBQU0OwRP0J6unTBlh3j22WfFlkLhAZIWuKhUKowfP55mh+AJ2tO1E6SM/ZECzQ5BsRba07UDN27cwN9//42YmBixpVB4gpQFLjQ7BP/Qnq4dIGXyPCmUlpZi48aNOH/+vNhSBCcuLg6TJ08WW4akoEFXYAwGAxISErBt2zaxpVB4wrTAhYTsEOfOnaPZIXiGDi8IjCk7ROvWrcWWQuEJUl6g0ewQwkB7ugJDX6BJi/Pnz+PGjRs0OwTFamhPV0BodgjpoVKpMG7cOMkvcKHZIYRD2i1HZEzZIXx8fMSWQuEBmh2Cwgc06FqBXq+HTqeD0WgEx3FgGAYymQwKhQKurq4A/m/y/LJly0RWaz218SkFautTCtkhauPVlB1i1apVIquVJjTo1gKWZVFUVASNRgOdTmdurA9j+lyhUKCwsBDu7u5OlR3CGp9KpRK+vr5O9bhtrc81a9Y4Xe/PGq9paWkYM2YMzQ4hEAzHcZzYIhwVrVaLgoICaDQaAJUNs7YYDAYwDIO6devC398fnp6eQsm0GVt8mn7ASqVS8j4rKipQt25d1K9f36F9Ar
Z51ev1kMvl8PHxcfg6dUZo0K0BlmWhVquh0WgsaqyPgmEYKJVKBAcHO1SPkPq0Dkf1CZDl1VmhQfchSkpKkJOTA47jeGm0JhiGAcMwCA0Nhbe3N2/lWgv1aRuO5hMgy6szQ4Pu/4fjOOTl5aGwsJDXBvswDMPAz88PgYGBoqzbpz75RWyfAFlepQANuqhstLm5uSguLha00ZpgGAY+Pj4ICQmxa+OlPoVBLJ8AWV6lAl0cASAvL89ujRao/KEUFxcjLy/PLtczQX0Kg1g+AbK8SgXig25JSYngj2U1wXEcCgsLUVJSYpfrUZ/CYm+fAFlepQTRQZdlWfOLBzHgOA45OTlgWVbQ61Cf9sFePgGyvEoNooOuWq0WrdGa4DgOarVa0GtQn/bDHj4BsrxKDWKDrlar5W0uoy1wHAeNRgOtVitI+dSnfRHaJ0CWVylCbNAtKCgQvdGa4DgOBQUFgpRNfdofIX0CZHmVIkQGXZZlzcsjHQWNRsP7+Bj1KR5C+ATI8ipViAy6RUVFVp977do1DBs2DBEREVi9ejWPqmzTZWt5Qvp6GOpT+DLt6RMQxqtUITLo2jIelpCQgI4dO+LEiRMYM2YMcnNz8csvv9isyTQ+xieW+HzQV9OmTTFhwgRERkaiX79+NR6fnJyM/v37o1OnThg8eDCuX79eq+s4kk+DwYD+/fujc+fO6NmzJ7766itzj+3u3buYMWMGevbsicjISLz22mvIyMiotSYhfAK19/pwOz1//jzGjRuHTp06oVu3bkhOTq52TlpaGlq2bIklS5ZYpEkor1KFyKCr0+msPletVqNJkyb4999/sXz5chgMBgDA33//jeXLl4umy9byTL4AwMPDAzExMXj//fdrPHbTpk3YvHkzli5dihMnTmDp0qXw9fUVRBff5T3os0ePHli/fj2OHz+OLVu2ICsry9wr1Gq1aNGiBdatW4ejR49i8ODBePfddy16acS3T0vKfNBnUVERJk2ahOHDh+Po0aPYtWsXoqKiqhyv1+vx1VdfoVWrVoLqohAYdPV6vdW93DfeeANpaWlYuHAh3nzzTXh6emLevHnYvXs3jh49irFjx9qkjeM46PV6m8owYYnPB3116tQJSqUSgwYNQmhoaLVjjUYjli1bhhkzZuDZZ58FwzBo0KCBRdkxHMWnwWAwb+Bi2k82OzsbANCgQQOMGzcOAQEBcHFxwfDhw6HX63Ht2rVaa+PTJ1B7rw/7XLRoEaKiojBw4EC4ubnBy8sLjRs3rnJOYmIioqKi8Mwzz1iljW+vUoa4oKvT6axeM65SqdCuXTt8+umnOHnyJEJCQszfyWS2/ykZhrGqx9CtWzd0794daWlp5s8s8fmwr8f98PLz85Gfn4/Lly+jd+/e6N+/P5YuXQqj0Vhrvdb6XLBgARo3bowNGzaYr2erz507d6Jz586Ijo5GVlYWhg8fXuO5Fy5cgF6vR8OGDWut11qfqampqFevHr788kvcv3/f/HltvT7s8+7du/Dx8cHYsWPRrVs3TJkyBbdu3TIfr1arkZKSgnfeecdirSas9UoixAVdS4LD4/j333+RlZWFWbNmYcCAAYiMjLT5hQXHcSgrK4NWq7XoP1N6lW7duiE6OhpHjx7lzefD5OfnA6gMDJs3b4ZKpcLu3buxefNmwX3m5+fj2rVrGD9+PBo3bozk5GSb35q/9NJLOH78OHbs2IHhw4ejXr161Y4pLS3FJ598gkmTJlmcTcFan2VlZfjiiy8QFBSEefPmoaSkxOo6zc/Px7Zt2/Dxxx9j3759CAkJwYwZM8zfL1q0CFOmTLF5s3Kh2pzUIG5XYr7mN7Zu3RqtW7dGbm4uAKBjx47o2LGjTWWWlpbis88+w969ey06z9TDKCsrw9GjRxEdHY2zZ88KsguUu7s7AGD8+PHw9vaGt7c3hg8fjj///BPDhg2rVRllZWV4++23LfZZUVEBALh//z7u37+P1157DStWrKg2PmkNTz/9NJ
o0aYL58+fjhx9+MH+u0+kwZcoUtG7d2uJUPeXl5Vb5NBgMZq8AMHv2bGzduhUHDhywqBwT7u7u6NmzJ8LDwwEAkyZNQnR0NDQaDf755x9otVr079/fqrIfxFHmDjs6xAVdvgNRSEgIJk+ezEtZSqUSK1eutDh7cLNmzXD58mW4ublh4MCBWLhwIZ566inzDYFPnnnmGbi6utr0d/T09LTK5wcffIDFixfDy8sLzZo1w7fffou2bdvythSVZVnzmC5QGeSnTZuG+vXrY/bs2RaX5+7ubpXP3bt3Y+jQoXBxcUGdOnWwYMECxMbGWr3yKywsrEp9Pfj/J06cwLlz59C9e3cAlTd+mUyGS5cu4aeffrLoOnSrx9pB3PACH2OvQmKNvmbNmmHo0KHIyMjAhg0b0LRpU5t8Go1GlJeXg2VZcByH8vJy80sSDw8P9O/fHwkJCbh//z7y8vKwceNGdOvWzaJrWKOvQYMGaN++PbZv346///4bPXr0gIuLi8XlmNi0aRPu3r0LALhy5QpUKhUiIiIAVL60ev/99+Hu7o4FCxZY/fe05ryAgAAEBwfjxx9/RHZ2Nt544w24urparWHIkCE4ePCgeVz6119/Rbt27aBUKjFlyhTs2LEDGzduxMaNG9G9e3e88sormD9/vsXXcfTflqNAXE9XoVA47GMQx3FQKBQWn7dt27Zqn9ni859//sGECRPM/+7QoQM6dOiAhIQEAMCnn36KuXPnomfPnlAqlXjllVcQExNT6/Kt9Tl9+nRMnz69yme2+Dx16hSWLFmCsrIy+Pr6om/fvpgyZQoA4PTp0zhy5AgUCkWV4Ytly5ahffv2tSrfWp8dOnTAlStXqn1urdeIiAhMmzYN7777LsrKytCuXTt89dVXAAAvLy94eXmZj3V3d4eHh4fFvXNrvZIIkZkjzp8/75CD/jKZDM8//zxv5VGf4sK3T4Asr1KFyOcBR70j862L+hQXIXSR5FWqEBl0lUqlww36m1Jd8wn1KR5C+ATI8ipViAy6lixZtSd866I+xUUIXSR5lSpEBl25XO5wd2alUgm5nN/3mtSneAjhEyDLq1QhMugCgL+/v8M8pjEMA39/f0HKpj7tj5A+AbK8ShFig66np6dDjI+ZxsNsXYL5KKhP+yK0T4Asr1KE2KALAMHBwQ7RcB/cOEcIqE/7YQ+fAFlepQbRQVculyM0NFS0xsswDEJDQ21aVVUbqE/7YC+fAFlepQbRQRcAvL294efnZ/fGyzAM/Pz8zPu5Cg31KSz29gmQ5VVKEB90ASAwMBA+Pj52a7wMw6Bu3boIDAy0y/VMUJ/CIJZPgCyvUoHIZcA1wXEc8vLyUFhYKOjeDKZeQmBgoCiPhtQnv4jtEyDLqxSgQfchSkpKkJOTA47jeG3ADMOYx8Ec4bGM+rQNR/MJkOXVmaFBtwZYloVarbYpa/CDmKbWBAcHO9QkcurTOhzVJ0CWV2eFBt3HoNVqUVBQYE4vbcmfyvT4pVQq4e/v79BzGanPJ+NMPgGyvDobNOg+hj179uCVV17B3bt3odFooNFooNPpzJljH8b0uUKhgFKphK+vr1P0DtRqNcLDw5GQkICoqCjJ+qyoqEBERATat2+PRYsWSdYnAEydOhU7duxAVlYWioqKJO3V2aB/1UewZcsWjBw5Enq9Hnfv3kVISAgCAgIAVGYV0Ol0MBqN5sYqk8mgUCjg6uoqsnLLuHjxIqKiolBUVIR79+4hICCgms9Lly4hPT0dw4cPd1qfGo0Gffv2xenTpxEYGFijz/LycqhUKrz99ttwcXFxSp8cx+HDDz/E0qVL4enpCblcXqPXFStW4OWXX4a/v7/T1qnTwlGq8csvv3AeHh4cAM7T05P7+++/xZYkCMePH+e8vb05AJxMJuO+/fbbGo/r1KkTJ5PJuDt37thZIT/k5+dzzz33HOfq6soB4KKjo2s87ocffuAAcHv27LGzQn7Q6/XcqFGjOE9PTw4A5+bmVu
NxmZmZHABu9OjRdlZI4bjKt5yUBzh//jwHgGMYhgPAeXt7czt27BBbFu8YjUauXr16nIuLCweAA8BNmzat2nGZmZmcm5sbJ5PJuPfff18EpbbzyiuvVPHZtGnTaseUl5dz9erV4wBwLVq04IxGowhKbWPZsmVmjwA4uVzOaTSaasfFxMRwADh3d3fu9u3bIiglG7o44iGaN2+OkydP4rnnnoNcLodWq0VeXp7YsniHYRikp6cjJiYGcrkccrkcN27cqHbcJ598ApZlYTQasWzZMhQUFIig1jbi4+Px5ZdfQiaTwc3NzZyM8kESEhLMqeyvX7+Offv22VumzUycOBHbtm1DnTp14OrqCqPRWK3tXrhwAXv27AFQORSxaNEiMaQSDQ26NdCqVSvcuXMHp06dQkJCAnr06CG2JEFo2LAhysvL8euvv2Lnzp2YOnVqle+vXLmClJQU8zzNiooKfPfddyKptR5vb2/IZDKMHTsWaWlpWLJkSZXvOY7D7NmzzUG3rKwMn3zyiRhSbUIulyMkJAT16tXDtWvXsHTpUoSGhlY55vPPP0dFRQUAwGAwYOnSpbh3754YcomFzl6ogfXr12P58uU4cOCA2FIEJTc3F+Hh4cjOzkadOnWqfV9UVITExERkZGTg5MmTmDBhAjp37lwlO64zwHEcwsPD8euvvyI6OrrGY+Li4nD79m188cUXWLhwIUJCQvDqq6/aWantvPvuu6hfvz5mz55d4/d79+7FuXPnsHjxYowYMQKNGzfGxIkTaY4zeyLq4IaD0qdPH27NmjViyxCcBQsWcG+++eYTj0tJSeEGDx5sB0XCkJqayoWFhT1xnLa4uJhTKpV2UsU/Wq2W8/Pz427evPnEY9u2bcv9888/dlBFeRg6Zewhrl+/jvT0dGzbtk1sKYJiNBoRHx+PNWvWiC1FcOLi4vDGG29Ifr+ATZs2oVOnTmjQoIHYUiiPgQbdh0hISMDo0aMl/7h15MgReHh4oGPHjmJLERSNRoPNmzcjMzNTbCmCExcXh/fee09sGZQnQIPuAxgMBiQkJGDHjh1iSxEclUqFiRMnSr73t27dOnTv3l3yWxFeunQJmZmZGDRokNhSKE+Azl54gP379yMwMBCtWtt2I4YAACAASURBVLUSW4qgFBUVYceOHRg7dqzYUgTHdHOROvHx8Xjttdfg5uYmthTKE6A93Qcwjf1JnTVr1qB///6oV6+e2FIE5ezZs7h58yb69esnthRBYVkWiYmJOHjwoNhSKLWA9nT/P7dv38bBgwcxatQosaUICsdxWLFiBRG9P5VKhfHjx0t+45Zdu3ahcePGaN68udhSKLVA2q3RAlatWoWXX35Z8ps0p6eno7i4GD179hRbiqCUl5cjOTkZJ06cEFuK4JDyhCYVaE8Xlb0/Usb+VCoVJkyYAJlM2lW/detWtGrVCo0bNxZbiqCo1WocPXoUw4cPF1sKpZbQni6AY8eOwWg0okuXLmJLERStVou1a9fi33//FVuK4MTFxRFxE01MTMSwYcNqXFFIcUxo0AVZk+c7d+4s+cnzpCxwMT2hkbDARUpI+xmzFpSUlGDLli2IjY0VW4rgkNL7I2mBi6enp+QXuEgN4nu669atQ48ePVC/fn2xpQhKVlYWLly4gIEDB4otRVBMC1y2b98uthTBUalURDyhSQ3ie7qkvECLj49HbGys5CfP79+/H/Xr10fr1q3FliIoRUVF2L59OxELXKQG0T3ds2fPIicnR/KT5/V6PRITE/H777+LLUVwSJk+RcoCFylCdE/XNHnexcVFbCmCsmvXLjz77LN47rnnxJYiKLdv38aBAwckv8AFIOcJTYoQ29Olk+elh2mBi4+Pj9hSBCU9PR1FRUWSX+AiVYjt6aakpKB169aSnzyfm5tLxOR5usCF4iwQ29M1vfmVOomJiRgxYoTkJ88fO3YMBoMBL7zwgthSBMW0wOX06dNiS6FYCZFBl5TJ80ajESqVCr/99p
vYUgSHpAUuERERkl/gImWIDLoJCQkYM2YMEZPnvby8JD953rTAhYTsECqVqlrWZopzQVzQJSk7BCm9P9KyQ0h9gYvUIW4knqTsEDt37iRi8jwpL9BIWeAidYjr6ZIyfWr16tVETJ4/e/YssrOzJb/AhWVZrFy5kogFLlKHqJ4uSdkhSNnchqTsECQscCEBabfUh6DZIaQFaQtcSLiJkgAxPV2Sen9xcXFETJ5PSUmh2SEoTgcxPd1jx46B4zgiskOsW7eOiOwQpLxAS0xMxPDhw+Hl5SW2FAoPEBN0SZk+tXHjRpodQkKYFrjQ7BDSgYiga5o8v2jRIrGlWI1er4dOp4PRaATHcWAYBjKZDAqFAq6urubjVCoVpk2bJqJS26itTylkh6iN1z/++MPps0PUtk5JgYig64zZIViWRVFRETQaDXQ6nbmxPozpc4VCAZ1OB7Va7VST563xWadOHWzYsMHpljdb4zUjIwNvv/22Uz2hWeNTqVTC19dX8rNQAEKCrkqlwqxZs8SWUSu0Wi0KCgqg0WgAVDZMEw/+/4NwHAetVguO47Bx40bk5eXB398fnp6edtFsDbb4LCsrw5o1a+Dr6wutVuvQPgHbvEZHR8PV1RU3b96UfJ3evn0bSqXS4X3aiuSD7pkzZ5CTk4P+/fuLLeWxsCwLtVoNjUbzyAb6JBiGgYuLC0pKSqDRaKBUKhEcHOxQvQc+fHIcB7lc7tA+AX68mh6/HdkrX3UKOLZPvpD2nCI4R3aIkpISZGVl2dRoH4bjOGg0GmRlZaGkpISXMm2FFJ8AOV5J8ckn0ruNPEB5eTlWr16NkydPii2lRjiOQ15eHgoLC3lrsA+Xz3EcsrOz4efnh8DAQFHGBknxadJCgldSfAqBpHu6puwQjRo1EltKNTiOQ25urmCN9uFrFRYWIjc3V/Br1XRtEnyark+CV1J8CoWkg64jr0DLy8tDcXGx3RoSx3EoLi5GXl6eXa5nghSfADleSfEpFJINuteuXcOpU6cwZMgQsaVUo6SkxC69hIcx9RrsNU5Gik+AHK+k+BQSyQZdR80OwbIscnJyRHtU4jgOOTk5YFlW0OuQ4hMgxyspPoVGkkHXlB3CEffNVavVoo9NcRwHtVot6DVI8QmQ45UUn0IjyaC7b98+BAUFOVx2CK1Wy+vUGmsxTcnRarWClE+KT4Acr6T4tAeSDLqO+gKtoKBA9EZrguM4FBQUCFI2KT4BcryS4tMeSC7omrJDjBw5UmwpVWBZ1rw80lHQaDS8j4+R4hMgxyspPu2F5IJuUlISYmJiHC47RFFRkdXnDhkyBGlpaTyq+T9s0WVreUL6ehi+fVpSpj19ArROHR1JBV2O46BSqRzyBZot42EpKSno2LEjOI7DkiVL0KtXL0RGRmL8+PG4fPmy1ZpM42N8YolPk69Lly7h7bffRnR0NFq2bFnjsbt378bgwYPRqVMnDBgwAP/880+tNQnhE6i9V5PP3bt3Y9CgQYiMjES3bt0wc+ZMlJaWAgAqKiowe/Zs9O3bFxERERg2bBj+/PNPizU5Sp0CQHZ2Nt59911EREQgOjoaixcvrnb8jRs30L59e3z88ccWaRKqTu2BpIJuamoqOM4xs0PodDqby9i7dy9SUlKwcuVKHD16FK1bt8ann34qui5by5PL5ejXrx/mzp1b4/epqan4/vvvMW/ePBw/fhwrV65EaGio4Lr4LrNt27ZISkrCsWPHsHv3brAsi59++glA5SN8YGAgEhIScOzYMbz33nv473//i9zcXMF1CVGeXq/HW2+9hU6dOuH333/HgQMH8NJLL1U7bsGCBQgPD7ebLkdAUkHX9ALN0dZo6/V6m15C9OvXD8eOHUNubi7atm2LBg0awMXFBQMHDsSVK1ds0sZxHPR6vcXnlZaWVjvPUp8mX40aNcLQoUPRpEmTGo/75Zdf8M4776B169aQyWSoX7++xXsjW+tTr9ebe6MPf15bryafgYGB8PX1NX
/u4uKCmzdvAgA8PT0xefJkhISEQCaToVu3bggJCcH58+ct1myt13v37lX7zNo6TUlJwVNPPYVx48bB09MT7u7uaNasWZVjd+/eDaVSiYiICIu1Atb7FBvJBF1TdojY2FixpVRDp9PxciMYMGAAsrOzcf36dej1emzbts3mXj3DMFb1GF5//XWEhIRg+fLlqKioAMCfzwcxGAw4d+4cCgsL8eKLL6JXr15YsGCBxZqt9ZmcnIyAgADMmDEDd+/eNX9urdf09HRERkYiIiICBw4cwGuvvVbjcQUFBbhx4waeffZZi69hjderV6/Cz88PMTExuHDhgvlza31mZGQgODgY77zzDqKjozF+/HhkZWWZvy8tLcXSpUvx4YcfWly2CWvrVGwks8vY2rVr0atXLzz11FNiS6mG0WjkpZyAgAC0a9cOgwYNgouLCwIDAxEXF2eztuzs7Mf+sLKzs6HRaJCZmWn+LD8/H3fu3MH06dPxySefYNq0aZgyZYpNWmri7t27YFkW+/fvR2JiIuRyOaZOnYrly5dj6tSptS7HtGPV43yWlpbCYDBU8Xn9+nUYDAYsWbIEP/30E0aPHo158+ZZnSSyXbt2OHbsGPLz87Fp0yYEBwdXO0av1+Pjjz/G4MGDrcp0XJs61el0uHbtGjw8PAAAV65cgUKhwLZt27B792688MILWLBgAZ577jmLrw9Uto+0tDQsWbIEnTt3RnJyMqZOnYrt27fD1dUVP//8M2JiYhAYGGhV+Sb4+m3ZE8kEXZVKhTlz5ogto0b4mt+4bNkynD17Fvv374e/vz927NiBiRMnYsuWLeYfj6WUlZVh+fLlOHr06COPKS0txb179zB06FDzZ9nZ2ebzy8rKMGfOHERGRiIkJMQqHY/C3d0dADB69GgEBAQAAGJjYy0OuuXl5Vi2bNljfRqNRuh0uio+i4qKwLKsuQ7j4+PBsix+/PFHa+yYqV+/Prp06YIZM2Zg/fr1VTR8+umncHV1tXq8XqvVYsWKFY/1eu3aNfz3v/81L5MvLy9HeXk5jEYjysvLcfDgQVy7ds2iF5YP4u7ujrZt2yI6OhpA5ZPR8uXLcfXqVXAch+PHj2PDhg1Wlf0gjjJ32BIkEXQzMjKgVqvRr18/saXUCF+P3BcvXkT//v3NvYMhQ4bg66+/xtWrV9GiRQuryvTy8sK3334LHx+fRx6zdetWxMfHY+vWrebPBgwYgP3798PV1RWxsbGYM2cOvLy8rHrx8zh8fHxQv379Kn9Da/6eCoXiiT5LSkoQGhpapae7bNkyTJ06Fa6uroiMjMQ333yDdu3aobi42GIND2MwGMw3L6AygMyePRt3797FL7/8YnXSxjp16jzRa7t27RAXF4d27doBADIzM9GmTRu4urqifv36+OabbzBkyBCrZwiEhYXh9OnTNX6XlpYGtVqNPn36AKi8SRiNRowYMaLKDag2ONr7m9ogiaDr6NkhZDJ+hs7Dw8Oxb98+9O/fH35+fti5cydYlrU53bo1+qKiotCwYUPMmTPH/Ihs7Q+U4zhUVFSYX4qUl5eDYRi4ubkBqLy5rFmzBl26dIFcLseqVavQtWtXi69jjc+mTZuib9++mDdvnjlAWVvWjh070L59ewQFBUGtVmPJkiVVXiLNmzcP165dw4oVK2zeqMlSfQEBAYiIiMD06dMxZMgQ8/nWtt2BAweaZ2p06tQJq1evRt26ddG4cWM0bNgQAwYMMB+7cuVKqNVqfPbZZxZfh6/flj1x+qCr0+mwevVqu04+txSFQsHLY9CECRNw9+5dDB8+HGVlZWjYsCEWL15s00IQjuOs+oHXlOjTWp9qtbpKDrsOHTogODgYe/fuBQC8/fbbuHfvHgYNGgQ3Nzf069cPb731lkXXsNZn79690bt372qfW+P16tWr+P777805wKKjozF9+nQAlX+DDRs2wM3NDd27dzefM3v2bIuzO1vj1d/fH3/88Ue1z62t00aNGmHRokWYN2
8eCgsL0bx5c/z0009wdXWFq6trleEwT09PuLm5wc/Pz6JrWFunYsNwzjgo8gBr166FSqXC/v37xZbyWM6fP++Qg/4ymQzPP//8Y4+paXjhUTizT9PwQm33bHVmrw8PLzwOZ/bpiDhf3/whHHUF2sM46h2Zb12k+BSqTD6gderYOHXQvXbtGk6fPu2Q2SEeRqlUOtygP8MwUCqVvJZJik+AHK+k+LQXTh104+PjMXr0aKe44z24EsmR4FsXKT6FKpMPaJ06Nk4bdB05O0RNyOVyh7szK5VKyOX8vkslxSdAjldSfNoLpw26e/fuRUhIiMNlhwAqZ1R0794dLVu2xPPPP49nnnkGPj4+OHPmjMM8pjEMA39/f0HK9vf3J8InQI5XUnzaA+e8VaBycxtH7eW6u7sjLy8PFy9eNH/m4+ODyMhIlJaWip72xDQe5unpKUj5np6eUCqVkvcJkOOVFJ/2wCl7uvn5+Th06JDDZYcwce3atSpbD3p4eODw4cPw8/NDcHCw6D0GhmF4X677MKT4BMjxSopPoXHKoOuo2SGys7Px9ttvo2PHjujSpQtatGgBuVyORYsWoU2bNgAqx8dCQ0NFa7wMwyA0NFTw1Xuk+ATI8UqKT6FxuqBryg7hSIknb926halTp6J169bw9fVFVlYW5s6dixUrVmD8+PHVNmbx9vaGn5+f3RsvwzDw8/Oz282KFJ8AOV5J8SkkTjem+9dff4FhGERFRYktBQUFBfj6668RFxeHcePGITMzs8rm2pGRkYiMjKzx3MDAQBgMBhQXF9tljIxhGNStW9fmrfQshRSfADleSfEpFE7X0zW9QBNzbOnevXuYNWsWmjVrhtLSUmRkZOD777+3KJuBaWzKHr0GUy9BjDE5Unyark+CV1J8CoVTBd3i4mKkpKSIlh1Co9Fg/vz5aNKkCXJzc/H333/jl19+sThflwmGYRAUFIQGDRpAJpPx3qgYhoFMJkODBg0QFBQk6lgcCT5NWkjwSopPIXCqoCtWdgitVotvvvkGTZo0QWZmJlJTUxEfH49GjRrxUr63tzfCwsJ4XW5pmloTFhbmMONgpPgEyPFKik8+cYgxXb1eD51OB6PRCI7jzHc5hUJRZSNnlUr1yIyxQqDT6bBixQosWrQIUVFROHTokNWbhT8JuVyOhg0bQqvVoqCgwLw3rSVjZqZGr1Qq4e/v75BzGR/0eefOHXPSR6n5BMjxSopPvhAl6LIsi6KiImg0Guh0OnOgfRjT5wqFAizL4v79++jbt6/g+vR6PRISEjB//ny0atUKO3fuRNu2bQW/LlA5Cb1hw4ZW/Y2USiV8fX2dYnlkYWEhwsPDMW3aNEybNk2yPsvLyxEZGQm9Xo+MjAxJ1+msWbOwdOlSlJaWStqnrdjV4eN6cY+6K3IcB61WC4ZhsG7dOuTm5gp2J2RZFqtXr8bcuXPRpEkTrF+/Hp07d+b9OrVBLpcjICDAnBestk8DzkBmZiZeeOEFcyJIqfosKSlB3759kZGRgSZNmki2TjmOw3/+8x8sXboUBoMBDMNI0idf2CXosiwLtVpt0xJC03klJSXmnfeDg4N5uTMajUasX78en3/+OZ566ikkJCSgW7duNpfLJ6Yd952dY8eOoX///uaNwvPy8qp8LxWf+fn56Nq1K65fvw6gMoHnw0jBq16vx9ixY7Fjxw6wLAt3d3cUFBRUmckjBZ98InjQLSkpQU5ODjiO421OH8dx0Gg0yMrKQmhoqNWD7RzHYevWrZg9ezY8PDzw008/oXfv3pJ6U+poTJs2DeXl5eZ/37x5U0Q1wvHjjz/iypUrMBgMACqzCkuRI0eOYP369eZVYqZ9RyyZPkkags1e4DgOt27dQnZ2tvmxgu/yjUYjsrOzcevWLYvK5zgOu3fvRseOHTF37lwsXLgQx48fR58+fWjAFZijR49i1qxZUCgUkMlkuHPnjtiSBGHBggXYt28f3Nzc4O7uDq1Wa068KSV69+6NK1
euIDQ0FJ6entBoNLh9+7bYshwaQXq6HMchNzfXLitWOI5DYWEhDAYDQkJCnhg0Dx06hM8++wzFxcWYO3cuhg4d6pQZRZ0VNzc3XL58GV988QVeffXVasMLUoFhGOTk5KB379747bffcPz4cck+YstkMpSWliI/Px8nTpxAp06dxJbk0AgSdPPy8uy2RBCoDLzFxcVwcXFBUFBQjcf89ddfmDVrFrKzs/H5559j5MiRTr9xhjNSUlKCLVu24Msvv0T9+vXRsGFDsSUJhkqlwvTp0+Ht7W2XWTdikZCQgNGjR6NOnTro1auX2HIcHt6DbklJCQoLC+2+56apx+vl5VVljPfvv//GrFmzkJmZidmzZyM2NpaIaSmOytq1a9GzZ0/Jj/llZWXh4sWLFqdPdzYMBgPi4+OxY8cOsaU4Dbw+V7Msa35pJgYcxyEnJwcsyyIjIwMxMTEYMmQIBg0ahIsXL2LChAk04IpMXFycQ+0QJxQqlQqxsbGSHVIwsW/fPgQGBqJ169ZiS3EaeA26arVa1F3lgf+bkdC3b1907doVly5dwuTJk+Hu7i6qLgqQkZGBW7duoV+/fmJLERS9Xo+kpCSHzWzCJ462zaozwFu3T6vVip7KA6gMuk2bNsW5c+dQr149UbVQqqJSqfD6669Lfix9586daNKkCZo1aya2FEG5ffs2Dhw4AJVKJbYUp4K3oFtQUCB6wDUhk8lw//59GnQdCJ1Oh9WrVyMtLU1sKYJDSu9v1apVGDJkCHx8fMSW4lTwMrzAsqx5aa81DBkyhPcfo0ajAcuyvJZJsZ6UlBS0adOGt53ZHJXc3Fz89ddfGDZsmNhSBIXjOIdODuvI8BJ0bV1tk5KSgo4dO5r/PXPmTFslAZDuKiBnhJTeX2JiIkaMGAEvLy+xpQjKsWPHYDQa8cILL4gtxengJejyMZZ769YtfPXVV+Y16llZWfj222+tLs+0VJgiPteuXcOpU6cwZMgQsaUIitFohEqlIqL35wgZXJwVXsZ0dTqdTef369cPn3/+Ofr06YNPP/0U586dg4eHByZPniyqLgo/xMfHY8yYMVAoFGJLEZTDhw+jTp066NChg9hSBKWkpASbN2/GhQsXxJbilNgcdPV6vSAv0PhIAcJxHPR6veTnSjoyBoMBCQkJ2LVrl9hSBMc0hCL13t+6devQs2dPySSKtDc2Dy/odDpeGtmdO3ewf/9+LFy4EB07dsQrr7yC+Ph4m8pkGIb2dkVm7969CA4ORqtWrcSWIihFRUXYuXMnxowZI7YUwaEv0GzD5p6u0WjkQwcCAgIwePBg87+bNWvGyzxHvvRRrIOUF2irV6/Giy++CD8/P7GlCMrZs2eRm5sr+QUuQmJzT1eIoYUFCxbwVpajzB0mkfz8fBw8eBAjR44UW4qgkDR9Ki4uDq+//jpdTm8DNv/lHH38ytH1SZlVq1YhJiZGkhldHyQ9PR0lJSXo0aOH2FIEpby8HKtXr8aJEyfEluLU2NzTdfS9aB1dn1Qhrfc3YcIEybe1lJQUtGrVCo0bNxZbilNjc09XoVDY/Ai/d+9eW2XUCMdxkp+m5Kj89ddfAIAuXbqIrERYtFot1q1bh4yMDLGlCA4pO8QJjc23ZldXV4d9hGcYhk4XEwlSpk9t3LgRUVFRCA0NFVuKoJgWuMTExIgtxenh5XnIUXuTjqpL6piyQ8TGxootRXBIGUIxZYegvynb4SXoKpVKh+vR6HQ6bNy4EZs2baqSfZYiPGvXrkWvXr3w1FNPiS1FULKyspCVlUVEdoiEhAQ6tMATvARdX19fPorhFYVCgaCgIPz8888IDg7GpEmTcOzYMTqFzA6Q0vsjLTuE1Be42Ategq5cLodSqeSjKN7w9vbGa6+9ht9//x3p6ekIDQ3F66+/jrCwMMybNw/Xr18XW6IkISk7RGJiIhE3F/oCjV94m+Pi7+/vMEMMDMPA39/f/O+nn34aM2fOxIULF5CcnIy8vDx06NAB3b
p1g0qlQnFxsYhqpYVKpcL48eOJyA4RFhZGRHaIgwcPYtSoUWJLkQy8BV1PT0+HGNtlGAZKpRKenp41fhcREYGlS5ciNzcX06dPx44dO9CwYUOMGjUKu3fvphuf24ApO8T48ePFliI4pAyhJCUlYciQIZJf4GJPeJ3NHRwc7BBBNyQk5InHubu7IyYmBlu2bMHVq1cRHR2NuXPnokGDBvjvf/9LxLxLvklJSUHbtm2JyA6RmppKRHYIUvbOsCe8Bl25XI7Q0FDRAi/DMAgNDbX40bZevXqYPHkyjh8/jsOHD0OhUGDQoEFo06YNFi9ejLy8PIEUSwtSen8rV64kIjtEamoqOI6T/AIXe8P7ukVvb2/4+fnZPfAyDAM/Pz+bH4OaNWuG+fPn49q1a/j+++9x5swZNG/eHC+++CLWrl1rzmxBqcrVq1dx+vRpYrJDkND7o9khhEGQxeKBgYHw8fGxW2UxDIO6devyuqmyTCZDjx49kJCQgJycHIwZMwYJCQkICQnBxIkT8ccff9BtIx8gISEBY8eOlfzk+cOHD8Pb2xvt27cXW4qgkLTAxd4IEnRN46r26PGaerhCjid7eXlhzJgx2Lt3L86cOYNmzZrh3XffxbPPPovZs2fj0qVLglzXWTBNnidhaIGU3t/atWvRs2dP1K9fX2wpkkOwbZEYhkFQUBAaNGjAS+qdmsqXyWRo0KABgoKC7PYjCAkJwYcffoiMjAxs2bIFpaWliI6ORlRUFH799VcUFhbaRYcjsXfvXoSEhKBly5ZiSxGUwsJC7Nq1i4jsEKQMoYiB4HvReXt7IywsjNfpZKZpYWFhYaJNZWEYxvyiLTs7G5999hkOHz6Mxo0bY9iwYdi2bRv0er0o2vhCr9dDo9GgTp06aNOmDYqLi6HRaKr5cvYXaCafWq0W/fr1e6RPKWSHMHnt1KkTZDJZjV7PnDnj9NkhTD6Li4tx7969R9apGDCcHdfFarVaFBQUmFOjW3JpU8BWKpXw9/evcR6uI3Dv3j1s3LgRiYmJuHjxIkaOHInY2Fi0b9/e4R9JWZZFUVERNBoNdDodOI4DwzDgOA4cx5n3izV9rlAowDAMunXrhrNnzzrNXM5H+QQqh0pMs18e9KlUKhETE4MvvvgCPXv2FFO+RTzKq8FgqPIE+qDX33//HWq1GrNmzRJZfe15XJ0+yMN16uvra/csGHYNuiac6Q9kC1euXEFycjKSkpLg7u6O2NhYjBkzBg0aNBBbWhVsuRkClW/069at69A3Q8D2mz7Lsqhbty4CAgIc2idgm1fTjcfb21vydQrYvyMnStB9GL1eD51OB6PRaA60MpkMCoVCEpuJcByH1NRUJCUlYePGjWjbti1iY2MxdOhQ1KlTRzRdLMtCrVZDo9HwshGQadgnODjYoW6OpPgEyPHqzD4dIuiShE6nw/bt25GUlIQ///wTgwcPRmxsLHr06GHX/QpKSkqQk5NjHjrgC4ZhzItUHGG4gRSfADlend0nDboicvv2bfz2229ISkpCfn4+xo4di9jYWDz//POCXZPjOOTl5aGwsFDQbS5NU/kCAwNFGcsmxSdAjlep+KRB10E4e/YsVq1aheTkZAQFBSE2NhajRo1CQEAAb9fgOA65ubkoLi62y77CDMPAx8cHISEhdv2RkuITIMerlHzSoOtgGAwGHDp0CElJSdi+fTu6du2K2NhYDBo0CO7u7jaVfevWLcF7CQ9j6jUEBQXZ7Zqk+ATI8SolnzToOjAajQabN29GUlISTp8+jREjRiA2NhadO3e2+O5bUlKC7OxsUTJnMAyDBg0a2GU8kBSfADlepeaTBl0n4ebNm1i9ejUSExNhMBgQGxuLsWPHVtlGcceOHfDz80NUVFSVc1mWRVZWlqh7RchkMoSFhQn6ZpgUnwA5XqXoU/AVaRR+aNiwIT755BNkZmZizZo1uH37Njp16mTOfnHv3j28+eab6N27N06ePFnlXLVaLXpuOI7joFarBb
0GKT4BcrxK0Sft6ToxFRUV2L17N5KSkrB3716Ul5eDZVkolUqkpqYiPDwcWq0W165dE73hrCq4DwAACmFJREFUApWPao0aNRJkEjopPgFyvErVJ+3pOjFubm54+eWXsWnTJgwaNAgGgwFA5Vhwhw4dkJaWhoKCAodotEBlj6GgoECQsknxCZDjVao+HWeJCcUmDh06BKAyV13dunXh4uKCCxcuONwSTo1GA5ZleR0HZFnWvAzUURDCJ0COVyn7pEFXIpw9exYeHh5VlhXfuXMHt2/fFlFVzRQVFfE6/7ioqIi3sviEb5+mMh0RWqe1hw4vSISAgIBq+zjwtS79Qfbs2YPBgwcjIiICL7/8Mg4ePGjR+RzH8d6DsdWnXq/H+++/j379+qFly5ZIS0urdsz58+cxbtw488vL5OTkx5YphE/ANq9XrlzBq6++iqioKERFRWHixIm4cuWK+fuEhATExMQgIiIC/fv3R0JCQq3KdcQ6BYCysjLMnz8f0dHRiIyMxLhx46odo9frMXjwYPTq1euJ5fHlk/Z0JYxOp+O1vPz8fHzyySdYsmQJXnjhBfz555/44IMPsGfPHtSrV080XXyU17ZtW4wdOxYffPBBte+KioowadIkfPjhh+jbty/0ej3y8/PtoovPMgMCArB48WIEBwfDaDRi7dq1+PDDD7F582YAlUFlwYIFCAsLQ3Z2Nt5++20EBgZiwIABguoSqry5c+fCYDBg69at8PHxwYULF6odk5CQAF9fX9y/f99uumhPV6Lo9XqregotW7bEzZs3zf+eOXMmlixZAqAy6Hp7eyM6OhoMw6Br167w8PBAdna2RdfgOM6qzaQLCgqqTd2prc/H+XJ1dcVrr72Gdu3amfcMfpCkpCRERUVh4MCBcHNzg5eXFxo3bvzEa1rrs6KiosYAURuvj/Pp7e1tXtZq2h/5wbqbMGECnn/+ecjlcjRq1Ag9evTAqVOnaqXZWq9nz56tNgeXjzq9evUqDh8+jDlz5sDPzw8uLi5o0aJFlfNzcnKwY8cOizJkWOvzQWjQlSg6nY73NeMtWrRAo0aN8Pvvv8NgMODgwYNwdXVFWFiYReUwDGNVj+Hrr79Gw4YNMXHiROTk5AAQxufDZGRkwMfHB2PHjkW3bt0wZcoU3Lp164nnWevz8OHDaN68Obp3715lqIMvr1FRUejQoQMWLVr0yIDDcRzS09PRpEmTWpVpjdfy8nK0bNkSjRs3xvr1683Blw+fZ8+eRVBQEJYuXYro6GjExMRg//79VY5ZtGgRpk2bZlEyVWvr9EHo8IJEEWIFj4uLCwYPHoyPPvoIFRUVcHV1xXfffWfxDAmDwYB//vnH4nT2ly9fhsFgQGJiIlatWoXu3bvj559/tqgMa8jPz0dmZiaWL1+Opk2bYvHixZgxYwZWrVr12POMRqNVPtPS0uDh4YEjR44gOjoaYWFh+PrrrxEZGWmLDTOpqanQarXYtm0bgoODazzml19+gdFoxJAhQ2pVJsuyFnutqKiATCbDjRs3MG7cOLz33nv46KOPeEn9lJ+fj8uXL6NPnz44dOgQTp8+bU4m27hxYxw8eBAGgwG9evWqcQz/cdj626JBV6IIMb/x2LFjWLx4MRISEtC8eXOcP38e7733HpYtW4bnnnuu1uXo9Xrs27ev1o+uJkyP3CzLgmEYHDhwAKdPn6722Mg37u7u6NmzJ8LDwwEAkyZNQnR0NDQaDZRK5SPPY1nWKp937twxP8KWl5fjzJkz+N///ofOnTtbb+IhPD09MWLECHTt2hVbt26tMia/Zs0abN++HStXroSbm1utyquoqLDYq8FgMLdTnU6H8vJy/PDDD5gwYYJlZmrA3d0dcrkcb731FuRyOTp27IhOnTohNTUVgYGBWLx4MX755Reryrb1t0WDrkSx9vHMw8OjSm/l7t275jTcFy9eRPv27c1BLjw8HC1btsTx48ctCroKhQIfffQRfHx8LNL2wQcf4Mcff4
S7uzumTZuGDz/8EDKZDLm5uTb5ehJhYWFV/p61/du6ublZ5XP37t14+eWXoVAoMHDgQCxcuBBNmzZFcXHxE8+1xKfRaIROp8Pt27fNQXfLli1QqVRYuXIlAgMDa63Z09PTYq86nQ6enp7w8vJCWFgYvvvuO3Tv3h0lJSW1Ov9xXmsa8jLV282bN6FWq82zGfR6PUpLS9G9e3esXr0aISEhj72urUMfdExXotT0Qqg2NGvWDLt27YLBYMDRo0fx999/m79r0aIF0tPTzT3OzMxMpKenWzyma62+bt264eOPP0ZOTg4WLlwIX1/fWpfzOF9AZU+tvLwcQOWPsLy83NyjGTJkCA4ePIgLFy5Ar9fj119/Rbt27R7byzVhjc9mzZph/PjxyMjIwIYNG9C0adNal/U4n6mpqcjMzITBYEBpaSm++eYbeHt7m18K7tixAz/++CNWrFhhVR4/S726ublhwoQJ2L59O/755x/06NHDnKqrNjzOa/v27REUFIS4uDiwLItTp07h5MmTiIqKQpMmTbB//35s3LgRGzduxNy5c1GvXj1s3LixVjcaa39bJujeCxJFr9cjKyvL4kehc+fOYebMmbh16xZ69uwJg8GA0NBQTJ06FUDlo2dycjLu3r0LX19fjBo1qsb5j4+DYRiEhYXxkv+utj6f5Ktfv37VZkbs2bPH3OtZt24dli9fjrKyMrRr1w6fffbZE3+gfPoEauf1cT737t2Ln3/+Gfn5+VAoFAgPD8e0adPQrFkzAED//v2Rn59fRe/AgQMxe/bsJ2pzxDq9fPky5syZg0uXLiEoKAhTp06tcT5uWloaPv7441rNOefDJw26Eub8+fOibon3KGQyGa8piUjxCZDjVco+6fCChLFkKow94VsXKT6FKpMPaJ3WHhp0JYxSqRQtWeKjMKW65hNSfALkeJWyTxp0JYyvr6/YEmqEb12k+BSqTD6gdVp7aNCVMHK5XJDeli0olUretzskxSdAjlcp+6RBV+L4+/s7zGMawzDw9/cXpGxSfALkeJWqTxp0JY6np6dDjI+ZxsOE2lSdFJ8AOV6l6pMGXQIIDg52iIb7pJU+tkKKT4Acr1L0SYMuAcjlcoSGhorWeBmGQWhoKFxcXAS9Dik+AXK8StEnDbqE4O3tDT8/P7s3XoZh4OfnB29vb7tcjxSfADlepeaTBl2CCAwMhI+Pj90aL8MwqFu3rkUbp/ABKT4BcrxKySddBkwYHMchLy8PhYWFgqa3NvUSAgMDRXk0JMUnQI5XqfikQZdQSkpKkJOTA47jeG3ADMOYx8Hs+aj9KEjxCZDj1dl90qBLMCzLQq1W85Y12DS1Jjg4WJCFAdZCik+AHK/O7JMGXQq0Wi0KCgrM6aUtaRKmxy+lUgl/f39B56faCik+AXK8OqNPGnQpZliWRVFRETQaDXQ6HTiOq3FMy/S5QqGAUqmEr6+vQ/WCngQpPgFyvDqTTxp0KY9Er9dDp9PBaDSaG6tMJoNCoeBtY25HgBSfADleHdknDboUCoViR+g8XQqFQrEjNOhSKBSKHaFBl0KhUOwIDboUCoViR2jQpVAoFDtCgy6FQqHYERp0KRQKxY7QoEuhUCh2hAZdCoVCsSP/D8kCIQZIH9jZAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"# Adjacency lists: each key promotes directly to the types in its list.\n",
"lattice = {\n",
" 'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
" 'u8': ['u16', 'i16', 'f16'], 'u16': ['u32', 'i32', 'f32'], 'u32': ['u64', 'i64', 'f64'],\n",
" 'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],\n",
" 'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
" 'c64': ['c128']\n",
"}\n",
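"# Build a directed graph whose edges point from a type to what it promotes to.\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"# Hand-placed (x, y) coordinates for each node in the drawing.\n",
"pos = {\n",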
" 'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],\n",
" 'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],\n",
" 'c64': [2, 3], 'c128': [3, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 5))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gYIJaqkCuh35"
},
"source": [
"This is effectively what NumPy type promotion does, but in doing so it breaks the lattice property of the graph. For example, the pair *{i8, u8}* no longer has a unique least upper bound: the candidates are *i16* and *f16*, and neither is reachable from the other on the graph. This turns out to be the source of NumPy's non-associative type promotion highlighted above.\n",
"\n",
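"The loss of a unique least upper bound can be checked directly on the graph. What follows is a minimal sketch, not anything NumPy or JAX implements: it copies the `lattice` dictionary from the plotting cell above, and `upper_bounds` and `minimal_upper_bounds` are hypothetical helper names introduced only for this illustration.\n",
"\n",
"```python\n",
"# Edges copied from the lattice plotted above: each key promotes to its listed values.\n",
"lattice = {\n",
"  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
"  'u8': ['u16', 'i16', 'f16'], 'u16': ['u32', 'i32', 'f32'], 'u32': ['u64', 'i64', 'f64'],\n",
"  'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],\n",
"  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
"  'c64': ['c128']\n",
"}\n",
"\n",
"def upper_bounds(node):\n",
"  # Hypothetical helper: `node` plus every type reachable from it along the edges.\n",
"  seen, stack = set(), [node]\n",
"  while stack:\n",
"    n = stack.pop()\n",
"    if n not in seen:\n",
"      seen.add(n)\n",
"      stack.extend(lattice.get(n, []))\n",
"  return seen\n",
"\n",
"def minimal_upper_bounds(a, b):\n",
"  # Hypothetical helper: common upper bounds that lie above no other common upper bound.\n",
"  common = upper_bounds(a) & upper_bounds(b)\n",
"  return {u for u in common\n",
"          if not any(u in upper_bounds(c) for c in common if c != u)}\n",
"\n",
"print(sorted(minimal_upper_bounds('i8', 'u8')))   # ['f16', 'i16']\n",
"print(sorted(minimal_upper_bounds('i8', 'i16')))  # ['i16']\n",
"```\n",
"\n",
"A pair with a single minimal upper bound, such as *{i8, i16}*, has a well-defined promotion; a pair with two incomparable candidates, such as *{i8, u8}*, does not.\n",
"\n",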
"Can we modify NumPy's promotion rules so that they satisfy the lattice property while still giving sensible results for mixed-type promotion? There are a few approaches we could take here."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nLXKOk48lfY2"
},
"source": [
"### Option 0: Leave integer/floating mixed precision undefined\n",
"\n",
"To make behavior utterly predictable (at some cost to user convenience), a defensible choice would be to leave any mixed integer/float promotion beyond Python scalars undefined, stopping at the partial lattice from the previous section. The downside is that users would have to cast explicitly whenever an operation mixes integer and floating-point values."
]
},
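{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the cost of Option 0 concrete, here is a minimal sketch of the explicit cast a user would have to write. NumPy is used only because it is already imported earlier in this document; the assumption that the bare expression `x + y` would be rejected is hypothetical and is not something NumPy itself enforces.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"x = np.arange(3, dtype=np.int32)\n",
"y = np.ones(3, dtype=np.float32)\n",
"\n",
"# Hypothetical under Option 0: `x + y` would be an error, so the user chooses the result type.\n",
"z = x.astype(np.float32) + y\n",
"print(z.dtype)  # float32\n",
"```\n",
"\n",
"Because both operands share a dtype after the cast, no promotion rule is consulted at all; the price is one `astype` call at every such mixed operation."
]
},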
{
"cell_type": "markdown",
"metadata": {
"id": "TETvnofnEiG0"
},
"source": [
"### Option 1: Avoiding All Precision Loss\n",
"\n",
"If our focus is on avoiding precision loss at all costs, we can restore the lattice property by promoting unsigned integers to float via their existing signed integer paths:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"cellView": "form",
"id": "zEfVDpewv6z3",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAEeCAYAAAApRMZ1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO29eVgTV//+f08IEMCAICibWvequK+giFoXbN2w6uOKS1et26Nf7WLV+rh1sbbVWq0SUUTrvu9btbWi0qrFHXeBCIogCUIgk8zvD37JRxSVJDOZZM55XZfXJUnmzH3nnLznzJlzzpvhOI4DhUKhUOyCTGwBFAqFQhI06FIoFIodoUGXQqFQ7AgNuhQKhWJHaNClUCgUO0KDLoVCodgRGnQpFArFjtCgS6FQKHaEBl0KhUKxIzToUigUih2Riy2AUj70ej10Oh2MRiM4jgPDMJDJZFAoFHB1dRVbHsUKSKlTUnyWFxp0HRSWZZGbmwutVgudTmdurM9jel2hUECpVMLX1xdyOa1WR4SUOiXFp7UwdMMbx6KgoADZ2dnQarUAShpmeTE1bKVSCX9/f3h6egqikWIZpNQpKT5thQZdB4FlWajVami1Wosa68tgGAZKpRLBwcFE9B4cEVLqlBSffEGDrgOg0WiQnp4OjuN4abQmGIYBwzAIDQ2Ft7c3b+VSXg8pdUqKTz6hQVdEOI5DZmYmcnJyeG2wz8MwDPz8/BAYGFjm2BqFP0ipU1J8CgENuiLBcRwyMjKQl5cnaKM1wTAMfHx8EBISIpnG62iQUqek+BQKOk9XJDIzM+3WaIGSH0peXh4yMzPtcj4SIaVOSfEpFDToioBGoxH8tqwsOI5DTk4ONBqNXc9LAqTUKSk+hYQGXTvDsqz5wYMYcByH9PR0sCwryvmlCCl1SopPoaFB186o1WrRGq0JjuOgVqtF1SAlSKlTUnwKDQ26dqSgoIC3uYy2wHEctFotCgoKRNUhBUipU1J82gMadO1Idna26I3WBMdxyM7OFluG00NKnZLi0x7QoGsnWJY1L490FLRardOPj4kJKXVKik97QYOuncjNzRWk3OnTp9t0vFC6SMCW7y4nJwfTpk1DeHg4IiIi8Omnn77wmby8PHTo0AGxsbF202VreY8ePcL48ePRuXNnNGrUCBkZGaXeX7hwId555x20adMGvXr1wq5du0q9f+bMGQwcOBBt27ZFdHQ0Nm/ezIsuR4IGXTvB53hYYWEhZs+ejSdPngAoecAxe/Zsi8s3jY9RrMOWOv3vf/8Lf39/HDp0CCdOnMDIkSNf+MwPP/yAGjVqWFSuEHVqiU+GYdCuXTssWrSozPc9PDywZMkSJCUlYd68efj6669x4cIFACVbQE6aNAn9+/dHUlISFi5ciO+++w7Xr19/oRxnbrvS203CQdHpdFYfm5mZia+//hrnzp2D0WhEjx49MGTIEMydOxeXL1/G0qVLMWbMGKtW69iii3TK892VVXedOnVCZmYmVq1aBRcXFwBA/fr1Sx134cIF3LhxA/3798f27dt518VHeWV5mz59OgYNGvTSW/9PPvnE/P/GjRujRYsW+Pfff9G0aVPk5eUhPz8fvXr1AsMwCAsLQ82aNXHr1i3Uq1ev3LocHdrTtQN6vd7qHpHBYMAnn3yCoKAgHDhwAEePHkWPHj3M7z+7KbQ1cBwHvV5v1bEksHv3brRp0wZHjhwpVYflqdOX1d2///6LN954A9OnT0f79u0xaNAgJCcnlzpu/vz5+OKLL6y6kFpTp1lZWWjQoAF+/fVXFBcXm19/mc/XtcvyoNPpcOnSJdSqVQsA4O/vjx49emDHjh0wGAy4cOECHjx4gObNm5d5vLO2XRp07YBOp7N6zfjFixfx6NEjTJkyBZ6ennB3d0f9+vWxbt06fPnll2jRogXGjh2LuLg4qwI7wzBO22OwB2lpaTh37hz69u2Lpk2bmo
Nveeq0rLpr3rw5srKycOrUKbRu3Rq///47RowYgYkTJ5rHKNetW4dGjRqhYcOGVmm2pk5zcnJw+/ZtTJkyBaGhoebg+zKfL/NmCXPmzEG9evXQrl0782tvv/02li9fjhYtWmDkyJEYP348AgMDefPpCNDhBTtgNBqtPjYzMxNBQUGl9hX18PDAV199Zf47ODgYM2fOtKr8/Px8jBo1Cn/88YfVGqVMYWEhWJYFy7JISUlB165dMWjQICxfvvy1x5ZVdwCgUCgQEhKCfv36AQB69OiBFStW4Pz58wgLC8P69euxceNGqzVrtVqL65RlWRQXF6OoqAhPnz7Fxx9/jLlz5+LSpUsWeSsv33//PW7cuIFVq1aZg/rt27cxbdo0/PDDDwgPD8e9e/cwbtw4VK5cGR06dCizHFt+W2JBg64dsOUBWmBgIDIzM8GybJkNfN68ebZIg5eXFxYtWgSFQmFTOVIlPj4eM2bMgFwuh6urKyZPnozJkyfDYDC89tiX1V3dunVx/PjxUp81BR5TD7JPnz4AgKKiIuh0OnTs2BFHjx41jwG/igoVKlhcp6mpqejWrRtYloWLiwsGDhyImTNnvrTtvq5dvoqlS5fi5MmTiI+PR4UKFcyv37x5E9WrVzf3fGvUqIEOHTrgzz//fGnQdZS5w5ZAg64dsGU7ukaNGsHf3x8//vgjxo4dCxcXF1y5cgXNmjXjTZuPjw98fHx4KU9qBAUFwcvLCzNmzMCYMWPg4eEBoGQ61+t4Wd299dZbWLhwIXbu3ImePXvi6NGjyMrKQrNmzeDl5YWDBw+ayzhw4AD27duHxYsXlyvgAtbVqSlx5OjRozFz5kwEBwe/0uer2mVRUZH5omTqPbu7uwMA4uLisG/fPqxZswYVK1YsVWb9+vVx7949nDlzBq1bt0Z6ejpOnDiBUaNGvdKrs0H307UDWq0WaWlpVt8KPXjwAAsWLMC5c+fAMAzefvttfP7557xok8lkqFq1KpRKJS/lSQ2O42A0Gl8IeOWt05fV3T///IO5c+ciIyMDNWrUwLRp09CiRYsXjt+xYwe2bduGhISEcmu2tk4NBoNFPl/mrVGjRi989uLFiwBKgrWrq2up3vEHH3yADz74AEDJRebXX3+FWq1GhQoV8M4772DSpEllPih21rZLg64d0Ov1SE1NdchbIYZhULduXSJTYdsCKXVKik97Qmcv2AFXV1eHvQ1iGMbpGq0jQEqdkuLTntCgaycc9UGVo+pyBhz1u+NbFyk+7QUNunZCqVQ6XI/BlOqaYh2k1CkpPu0FDbp2wtfXV2wJZeKoupwBR/3u+NZFik97QYOunZDL5Q53ZVYqlVZPbqeQU6ek+LQXNOjaEX9/f4e5TWMYBv7+/mLLcHpIqVNSfNoDGnTtiKenp0OMj5nGwzw9PUXVIQVIqVNSfNoDGnTtTHBwsEM03JCQEFE1SAlS6pQUn0JDg66dkcvlCA0NFa3xMgyD0NDQci8ppbweUuqUFJ9CQ4OuCHh7e8PPz8/ujZdhGPj5+cHb29uu5yUBUuqUFJ9CQoOuSAQGBsLHx8dujZdhGFSsWPGle5NSbIeUOiXFp1DQvRdEhOM4ZGZmIicnR9C17aZeQmBgoOhjclKHlDolxacQ0KDrAGg0GqSnp4PjOF4bMMMw5nEwKdyWOROk1CkpPvmEBl0HgWVZqNVq3rIGm6bWBAcHO+0kcmeHlDolxSdf0DFdB0Eul6NatWqoUKECTpw4AcDyDZpNvQNvb2/UqFED1apVk2SjdRZMdbpt2zakpKSY68cSnKFOTT7T09Nx8uRJyfrkC2m6clJSU1PRtm1b5Obm4tSpU6hduza0Wi10Op056+/zmF5XKBRQKpXw9fWVbGN1NjiOw6efforvvvsOtWrVwrVr15CbmyvJOt21axf69+8PlmVRVFSEJ0+eSNInH0jfoZOQnJyMrl27Ii8vD+7u7sjKykJ4eDgCAgIAlGwmbUqp8mzadYVC4ZR7ik
odlmUxatQobN261fy3XC5HQECA5Op05cqVmDhxIvR6PVxdXZGfny9Jn3xBg64DkJeXhw4dOpjTSTMMg8zMzFKfcXV1JbKBOitfffUV1q9fb05zY0qv/ixSqNNjx47hww8/NP+tUCiQmZlZagcwKfjkEzqm6wD4+Phg586dCAsLg0wmg16vh1qtFlsWxQYmTJiAWbNmQSaTwcPDA1qttlwZhJ2Ndu3aYdWqVahYsSJcXV1RWFj4QoeBUhra03UQunXrBqVSiTVr1uDBgwe8ZfuliEPlypURHh6OJk2a4Pvvv8fBgwclM8/0Wdzd3RETE4P//ve/OH78OPbu3YuaNWuKLcuhoVPGHITLly+ja9euuH//PhEPE0jgP//5D6KiojB27FixpQjKsmXLcOzYMWzevFlsKU4BHV5wEFQqFUaOHEkDrkTIzs7GwYMHMWTIELGlCE5cXBzef/99sWU4DfQX7gAUFRUhMTERp06dElsKhScSExPRq1cvVKxYUWwpgnL+/Hk8evQIXbp0EVuK00B7ug7Arl27EBYWhtq1a4sthcIDHMcR0/tTqVQYPXq002+3aE9oT9cBUKlUeO+998SWQeGJs2fPoqioCB06dBBbiqAUFhbit99+w7lz58SW4lTQoCsy9+7dQ3JyMrZv3y62FApPxMXF4b333pPkbIVn2bZtG1q2bInq1auLLcWpoEFXZFavXo3BgwfDw8NDbCkUHsjPz8fWrVtx+fJlsaUIjkqlwpgxY8SW4XTQoCsiBoMBq1atws6dO8WWQuGJTZs2ITIyEkFBQWJLEZSbN2/i0qVL6N27t9hSnA76IE1Ejhw5goCAADRt2lRsKRSeIOUB2qpVqzBs2DC4u7uLLcXpoD1dEaEP0KTFlStXcPfuXfTo0UNsKYLCsixWr16Nw4cPiy3FKaE9XZF49OgRDh06hMGDB4sthcITpCxw2b9/P6pXr46GDRuKLcUpkXbrcGASExPRp08fyU+eJ4Xi4mKsXbuWiAUuKpWKiCEUoaA9XREwTZ6nQwvSYdeuXWjYsKHkF7g8ePAAJ06cwMCBA8WW4rTQoCsCZ86cgV6vR2RkpNhSKDxByvh8QkIC+vXrB6VSKbYUp4UOL4gAKZPnSeH+/fs4e/Ystm3bJrYUQTHdoSUkJIgtxamhQdfOaLVabN26FVevXhVbCoUn4uPjiVjg8scff8DNzQ1t27YVW4pTQ4Oundm0aROioqIQGBgothQKD5gWuOzYsUNsKYJjGkKhd2i2Qcd07Qwpk+dJ4ejRo/D395d8po8nT55g165dGD58uNhSnB4adO3I5cuXcf/+fURHR4sthcITpEyf+u2339CtWzdzhl+K9dCga0dImTxPCqbsECQscKFTHPmD/vrthCk7RFJSkthSKDyRmJiI3r17S36By/nz55GdnU2zQ/AE7enaCVN2iFq1aokthcIDJC1wUalUGDVqFM0OwRO0p2snSBn7IwWaHYJiLbSnawfu3buHv//+GzExMWJLofAEKQtcaHYI/qE9XTtAyuR5UsjPz8eWLVtw5coVsaUITlxcHMaOHSu2DElBg67AGAwGxMfHY9euXWJLofCEaYELCdkhLl++TLND8AwdXhAYU3aIJk2aiC2FwhOkPECj2SGEgfZ0BYY+QJMWV65cwb1792h2CIrV0J6ugNDsENJDpVJhxIgRkl/gQrNDCIe0W47ImLJD+Pj4iC2FwgM0OwSFD2jQtQK9Xg+dTgej0QiO48AwDGQyGRQKBVxdXQH83+T5ZcuWiazWesrjUwqU16cUskOUx6spO8TatWtFVitNaNAtByzLIjc3F1qtFjqdztxYn8f0ukKhQE5ODtzd3Z0qO4Q1PpVKJXx9fZ3qdttan+vXr3e63p81XpOTkzF06FCaHUIgGI7jOLFFOCoFBQXIzs6GVqsFUNIwy4vBYADDMKhYsSL8/f3h6ekplEybscWn6QesVCol77O4uBgVK1ZElSpVHNonYJtXvV4PuV
wOHx8fh69TZ4QG3TJgWRZqtRpardaixvoyGIaBUqlEcHCwQ/UIqU/rcFSfAFlenRUadJ9Do9EgPT0dHMfx0mhNMAwDhmEQGhoKb29v3sq1FurTNhzNJ0CWV2eGBt3/H47jkJmZiZycHF4b7PMwDAM/Pz8EBgaKsm6f+uQXsX0CZHmVAjTooqTRZmRkIC8vT9BGa4JhGPj4+CAkJMSujZf6FAaxfAJkeZUKdHEEgMzMTLs1WqDkh5KXl4fMzEy7nM8E9SkMYvkEyPIqFYgPuhqNRvDbsrLgOA45OTnQaDR2OR/1KSz29gmQ5VVKEB10WZY1P3gQA47jkJ6eDpZlBT0P9Wkf7OUTIMur1CA66KrVatEarQmO46BWqwU9B/VpP+zhEyDLq9QgNugWFBTwNpfRFjiOg1arRUFBgSDlU5/2RWifAFlepQixQTc7O1v0RmuC4zhkZ2cLUjb1aX+E9AmQ5VWKEBl0WZY1L490FLRaLe/jY9SneAjhEyDLq1QhMujm5uZafeydO3fQv39/tGnTBuvWreNRlW26bC1PSF/PQ30KX6Y9fQLCeJUqRAZdW8bD4uPj0apVK5w5cwZDhw5FRkYGfvnlF5s1mcbH+MQSn8/6qlOnDkaPHo3w8HB07969zM8nJiYiOjoarVu3Ru/evXH37t1ynceRfBoMBkRHR6Nt27bo3LkzvvnmG3OP7fHjx5g2bRo6d+6M8PBwDB8+HCkpKeXWJIRPoPxen2+nV65cwYgRI9C6dWtERUUhMTHxhWOSk5PRqFEjLF682CJNQnmVKkQGXZ1OZ/WxarUatWvXxr///osVK1bAYDAAAP7++2+sWLFCNF22lmfyBQAeHh6IiYnB5MmTy/zs1q1bsW3bNixduhRnzpzB0qVL4evrK4guvst71menTp2wadMmnD59Gtu3b0dqaqq5V1hQUICGDRti48aNOHnyJHr37o1PPvnEoodGfPu0pMxnfebm5mLMmDEYMGAATp48iX379iEiIqLU5/V6Pb755hs0btxYUF0UAoOuXq+3upf73nvvITk5GfPnz8cHH3wAT09PzJkzB/v378fJkycxbNgwm7RxHAe9Xm9TGSYs8fmsr9atW0OpVKJXr14IDQ194bNGoxHLli3DtGnTUKtWLTAMg6pVq1qUHcNRfBoMBvMGLqb9ZNPS0gAAVatWxYgRIxAQEAAXFxcMGDAAer0ed+7cKbc2Pn0C5ff6vM8FCxYgIiICPXv2hJubG7y8vFCzZs1Sx6xZswYRERF44403rNLGt1cpQ1zQ1el0Vq8ZV6lUaN68Ob744gucPXsWISEh5vdkMtu/SoZhrOoxREVFoWPHjkhOTja/ZonP53296oeXlZWFrKws3Lx5E126dEF0dDSWLl0Ko9FYbr3W+pw3bx5q1qyJzZs3m89nq8+9e/eibdu2iIyMRGpqKgYMGFDmsdeuXYNer0e1atXKrddan6dOnUKlSpXw9ddf4+nTp+bXy+v1eZ+PHz+Gj48Phg0bhqioKIwbNw4PHjwwf16tVmPHjh34+OOPLdZqwlqvJEJc0LUkOLyKf//9F6mpqZgxYwZ69OiB8PBwmx9YcByHwsJCFBQUWPTPlF4lKioKkZGROHnyJG8+nycrKwtASWDYtm0bVCoV9u/fj23btgnuMysrC3fu3MGoUaNQs2ZNJCYm2vzU/J133sHp06exZ88eDBgwAJUqVXrhM/n5+fj8888xZswYi7MpWOuzsLAQ//vf/xAUFIQ5c+ZAo9FYXadZWVnYtWsXPvvsMxw6dAghISGYNm2a+f0FCxZg3LhxNm9WLlSbkxrE7UrM1/zGJk2aoEmTJsjIyAAAtGrVCq1atbKpzPz8fHz55Zc4ePCgRceZehiFhYU4efIkIiMjcenSJUF2gXJ3dwcAjBo1Ct7e3vD29saAAQPw559/on///uUqo7CwEB999JHFPouLiwEAT58+xdOnTzF8+HCsXLnyhfFJa6hevTpq166NuXPn4s
cffzS/rtPpMG7cODRp0sTiVD1FRUVW+TQYDGavADBz5kzs3LkTR44csagcE+7u7ujcuTPCwsIAAGPGjEFkZCS0Wi3++ecfFBQUIDo62qqyn8VR5g47OsQFXb4DUUhICMaOHctLWUqlEqtXr7Y4e3C9evVw8+ZNuLm5oWfPnpg/fz4qV65sviDwyRtvvAFXV1ebvkdPT0+rfE6ZMgWLFi2Cl5cX6tWrh4ULF6JZs2a8LUVlWdY8pguUBPmJEyeiSpUqmDlzpsXlubu7W+Vz//796NevH1xcXFChQgXMmzcPsbGxVq/8qlu3bqn6evb/Z86cweXLl9GxY0cAJRd+mUyGGzduYMmSJRadh271WD6IG17gY+xVSKzRV69ePfTr1w8pKSnYvHkz6tSpY5NPo9GIoqIisCwLjuNQVFRkfkji4eGB6OhoxMfH4+nTp8jMzMSWLVsQFRVl0Tms0Ve1alW0aNECu3fvxt9//41OnTrBxcXF4nJMbN26FY8fPwYA3Lp1CyqVCm3atAFQ8tBq8uTJcHd3x7x586z+Pq05LiAgAMHBwfjpp5+QlpaG9957D66urlZr6Nu3L44ePWoel16+fDmaN28OpVKJcePGYc+ePdiyZQu2bNmCjh074t1338XcuXMtPo+j/7YcBeJ6ugqFwmFvgziOg0KhsPi4Xbt2vfCaLT7/+ecfjB492vx3y5Yt0bJlS8THxwMAvvjiC8yePRudO3eGUqnEu+++i5iYmHKXb63PSZMmYdKkSaVes8Xn+fPnsXjxYhQWFsLX1xfdunXDuHHjAAAXLlzAiRMnoFAoSg1fLFu2DC1atChX+db6bNmyJW7duvXC69Z6bdOmDSZOnIhPPvkEhYWFaN68Ob755hsAgJeXF7y8vMyfdXd3h4eHh8W9c2u9kgiRmSOuXLnikIP+MpkMDRo04K086lNc+PYJkOVVqhB5P+CoV2S+dVGf4iKELpK8ShUig65SqXS4QX9Tqms+oT7FQwifAFlepQqRQdeSJav2hG9d1Ke4CKGLJK9ShcigK5fLHe7KrFQqIZfz+1yT+hQPIXwCZHmVKkQGXQDw9/d3mNs0hmHg7+8vSNnUp/0R0idAllcpQmzQ9fT0dIjxMdN4mK1LMF8G9WlfhPYJkOVVihAbdAEgODjYIRrusxvnCAH1aT/s4RMgy6vUIDroyuVyhIaGitZ4GYZBaGioTauqygP1aR/s5RMgy6vUIDroAoC3tzf8/Pzs3ngZhoGfn595P1ehoT6Fxd4+AbK8Sgnigy4ABAYGwsfHx26Nl2EYVKxYEYGBgXY5nwnqUxjE8gmQ5VUqELkMuCw4jkNmZiZycnIE3ZvB1EsIDAwU5daQ+uQXsX0CZHmVAjToPodGo0F6ejo4juO1ATMMYx4Hc4TbMurTNhzNJ0CWV2eGBt0yYFkWarXapqzBz2KaWhMcHOxQk8ipT+twVJ8AWV6dFRp0X0FBQQGys7PN6aUt+apMt19KpRL+/v4OPZeR+nw9zuQTIMurs0GD7is4cOAA3n33XTx+/BharRZarRY6nc6cOfZ5TK8rFAoolUr4+vo6Re9ArVYjLCwM8fHxiIiIkKzP4uJitGnTBi1atMCCBQsk6xMAJkyYgD179iA1NRW5ubmS9ups0G/1JWzfvh2DBg2CXq/H48ePERISgoCAAAAlWQV0Oh2MRqO5scpkMigUCri6uoqs3DKuX7+OiIgI5Obm4smTJwgICJCkT61Wi27duuHChQsIDAyUrE+O4zB16lQsXboUnp6ekMvlkvXqrNCgWwbLli3DlClTUFxcDE9PT2RmZpZaeePq6iqJBnrmzBl069YNGo0GMpkM2dnZpd6Xis+HDx8iKirKnI3h2bTmgHR8siyL2NhY7Ny5E0ajsVRySxNS8erM0KD7HFevXsXYsWPNt2ByuRyZmZkiq+IfjuPwzjvvmAOQ0WgslZRRSowdOxY3btyAwWAAAEnWJwDExcXht99+M/9tNBqRn5+PCh
UqiKiK8jx0ccRz1K9fH2fPnsWbb74JuVyOgoICSf5IGYbBuXPnEBMTA7lcDrlcjnv37oktSxBWrVqFr7/+GjKZDG5ubuZklFLj/fffx65du1ChQgW4urrCaDRKsu06O7SnWwaNGzfGo0ePcP78eVy4cKFUYkIpUa1aNRQVFWH58uWoWrWqZG87vb29IZPJMGzYMEyZMgUXL14UW5IgyOVyhISEoFKlSvjrr7+we/duhIaGii2L8hx09kIZbNq0CStWrMCRI0fEliIoGRkZCAsLQ1pamqRvQTmOQ1hYGJYvX47IyEix5QjKJ598gipVqmDmzJliS6G8BNrTLYO4uDi89957YssQnDVr1mDAgAGSDrgAcPr0abAsi/bt24stRVAKCwuxYcMGXLhwQWwplFdAg+5z3L17F+fOncOuXbvEliIoRqMRq1atwvr168WWIjimi6jU9wvYunUrWrdujapVq4othfIKaNB9jvj4eAwZMkTyKaVPnDgBDw8PtGrVSmwpgqLVarFt2zZcvXpVbCmCExcXh/Hjx4stg/IaaNB9BoPBgPj4eOzZs0dsKYKjUqnw/vvvS773t3HjRnTs2FHyWxHeuHEDV69eRa9evcSWQnkNdMrYMxw+fBiBgYFo3Lix2FIEJTc3F3v27MGwYcPEliI4pouL1Fm1ahWGDx8ONzc3saVQXgPt6T4DKQ/Q1q9fj+joaFSqVElsKYJy6dIl3L9/H927dxdbiqCwLIs1a9bg6NGjYkuhlAPa0/3/efjwIY4ePYrBgweLLUVQOI7DypUriej9qVQqjBo1SvIbt+zbtw81a9ZE/fr1xZZCKQfSbo0WsHbtWvTp00fymzSfO3cOeXl56Ny5s9hSBKWoqAiJiYk4c+aM2FIEh5Q7NKlAe7oo6f2RMvanUqkwevRoyGTSrvqdO3eicePGqFmzpthSBEWtVuPkyZMYMGCA2FIo5YT2dAEkJSXBaDSiXbt2YksRlIKCAmzYsAH//vuv2FIEJy4ujoiL6Jo1a9C/f3/JL3CREjTogqzJ823btpX85HlSFriY7tBIWOAiJaR9j1kONBoNtm/fjtjYWLGlCA4pvT+SFrh4enpKfoGL1CC+p7tx40Z06tQJVapUEVuKoGLf7ZIAACAASURBVKSmpuLatWvo2bOn2FIExbTAZffu3WJLERyVSkXEHZrUIL6nS8oDtFWrViE2Nlbyk+cPHz6MKlWqoEmTJmJLEZTc3Fzs3r2biAUuUoPonu6lS5eQnp4u+cnzer0ea9aswe+//y62FMEhZfoUKQtcpAjRPV3T5HkXFxexpQjKvn37UKtWLbz55ptiSxGUhw8f4siRI5Jf4AKQc4cmRYjt6dLJ89LDtMDFx8dHbCmCcu7cOeTm5kp+gYtUIbanu2PHDjRp0kTyk+czMjKImDxPF7hQnAVie7qmJ79SZ82aNRg4cKDkJ88nJSXBYDBIPjuEaYELzQ7hvBAZdEmZPG80GqFSqUql5ZYqJC1wadOmjeQXuEgZIoNufHw8hg4dSsTkeS8vL8lPnjctcCEhO4RKpcKECRPElkGxAeKCLknZIUjp/ZGWHULqC1ykDnEj8SRlh9i7dy8Rk+dJeYBGygIXqUNcT5eU6VPr1q0jYvL8pUuXkJaWJvkFLizLYvXq1UQscJE6RPV0ScoOQcrmNiRlhyBhgQsJSLulPgfNDiEtSFvgQsJFlASI6emS1PuLi4sjYvL8jh07aHYIitNBTE83KSkJHMcRkR1i48aNRGSHIOUB2po1azBgwAB4eXmJLYXCA8QEXVKmT23ZsoVmh5AQpgUuNDuEdCAi6Jomzy9YsEBsKVaj1+uh0+lgNBrBcRwYhoFMJoNCoYCrq6v5cyqVChMnThRRqW2U16cUskOUx+sff/zh9NkhylunpEBE0HXG7BAsyyI3NxdarRY6nc7cWJ/H9LpCoYBOp4NarXaqyfPW+KxQoQI2b97sdMubrfGakpKCjz76yKnu0KzxqVQq4e
vrK/lZKADAcBzHiS1CaNq2bYsZM2bgnXfeEVvKaykoKEB2dja0Wi2AkoZZXjiOg9FohK+vL/z9/eHp6SmUTJuxxSfDMNDr9fDz83N4n4BtXvV6PVxdXeHt7e3wXm2tUwBQKpUO79NWJB90L168iB49euDevXsOvVk5y7JQq9XQarUWNdaXwTAMlEolgoODHar3QIpPgByvpPjkC8kH3UmTJkGpVGLOnDliS3kpGo0G6enp4DiOl0ZrgmEYMAyD0NBQh5ibTIpPgByvpPjkE0kH3aKiIoSGhuLs2bOoUaOG2HJegOM4ZGZmIicnh9cG+zwMw8DPzw+BgYGijA2S4hMgxyspPoVA0rPnTdkhHDXgZmRkCN5oTefKyclBRkaG4Ocq69wk+DSdnwSvpPgUCkkHXUdegZaZmYm8vDy7NSSO45CXl4fMzEy7nM8EKT4BcryS4lMoJBt079y5g/Pnz6Nv375iS3kBjUZjl17C85h6DRqNxi7nI8UnQI5XUnwKiWSDrqNmh2BZ1vzgQQw4jkN6ejpYlhX0PKT4BMjxSopPoZFk0DVlh3DEfXPVarXoY1Mcx0GtVgt6DlJ8AuR4JcWn0Egy6B46dAhBQUEOlx2ioKCAt7mMtsBxHLRaLQoKCgQpnxSfADleSfFpDyQZdB31AVp2drbojdYEx3HIzs4WpGxSfALkeCXFpz2QXNA1ZYcYNGiQ2FJKwbKseXmko6DVankfHyPFJ0COV1J82gvJBd2EhATExMQ43CqW3Nxcq4/t27cvkpOTeVTzf9iiy9byhPT1PHz7tKRMe/oEaJ06OpJakcZxHBo0aICVK1eiffv2Ysspxe3bt20eh+I4DkuWLMHOnTtRUFCAN998E9OnT0ft2rWtLtPT05PXzAvW+Lxx4wYWLlyIK1eu4MmTJ7h48eILn9m/fz+WLVuGzMxMVKpUCXPnzkWLFi3KfQ6+fQKWe92/fz9++eUXZGdnw83NDe3bt8fnn3+OChUqoLi4GHPnzsXp06eRl5eHqlWrYuLEiYiMjLRYlyPUKQCkpaXh66+/xt9//w03NzfExMRg8uTJpT5z79499OvXD127dsXXX39tUflC1Kk9kFRP99SpUw6bHUKn09lcxsGDB7Fjxw6sXr0aJ0+eRJMmTfDFF1+IrsvW8uRyObp3747Zs2eX+f6pU6fwww8/YM6cOTh9+jRWr16N0NBQwXXxXWazZs2QkJCApKQk7N+/HyzLYsmSJQBKbuEDAwMRHx+PpKQkjB8/Hv/v//0/ZGRkCK5LiPL0ej0+/PBDtG7dGr///juOHDlS5i5/8+bNQ1hYmN10OQKSCrqmB2iOtkZbr9fb9BCie/fuSEpKQkZGBpo1a4aqVavCxcUFPXv2xK1bt2zSxnEc9Hq9xcfl5+e/cJylPk2+atSogX79+r20x/7LL7/g448/RpMmTSCTyVClShWL90a21qder0d+fn6Zr5fXq8lnYGAgfH19za+7uLjg/v37AEp6bWPHjkVISAhkMhmioqIQEhKCK1euWKzZWq9Pnjx54TVr63THjh2oXLkyRowYAU9PT7i7u6NevXqlPrt//34olUq0adPGYq2A9T7FRjJB15QdIjY2VmwpL6DT6Xi5EPTo0QNpaWm4e/cu9Ho9du3aZXOvnmEYq3oMI0eOREhICFasWIHi4mIA/Pl8FoPBgMuXLyMnJwdvv/023nrrLcybN89izdb6TExMREBAAKZNm4bHjx+bX7fW67lz5xAeHo42bdrgyJEjGD58eJmfy87Oxr1791CrVi2Lz2GN19u3b8PPzw8xMTG4du2a+XVrfaakpCA4OBgff/wxIiMjMWrUKKSmpprfz8/Px9KlSzF16lSLyzZhbZ2KjWQ2q9ywYQPeeustVK5cWWwpL2A0GnkpJyAgAM2bN0evXr3g4uKCwMBAxMXF2awtLS3N4h9WVlYWHj16hEmTJuHzzz/HxIkTMW7cOJu0lMXjx4/BsiwOHz6MNWvWQC6XY8
KECVixYgUmTJhQ7nI4jrPK5927d2EwGLB48WIsWbIEQ4YMwZw5c6xOEtm8eXMkJSUhKysLW7duRXBw8Auf0ev1+Oyzz9C7d2+rxiytqdNbt25BoVBg165d2L9/P9q3b4958+bhzTfftPj8QEn7SE5OxuLFi9G2bVskJiZiwoQJ2L17N1xdXfHzzz8jJiYGgYGBVpVvgq/flj2RTNBVqVSYNWuW2DLKhK9nlcuWLcOlS5dw+PBh+Pv7Y8+ePXj//fexfft2eHh4WFVmYWEhVqxYgZMnT1p0XFpamvn4wsJCzJo1C+Hh4QgJCbFKx8twd3cHAAwZMgQBAQEAgNjYWIuDblFREZYtW2axz9zcXLAsa67DVatWgWVZ/PTTTxaV8zxVqlRBu3btMG3aNGzatMn8utFoxBdffAFXV1erx+sLCgqwcuVKi7wWFRWhqKgIRqMRRUVFOHr0KO7cuYN//vnHKg3u7u5o1qyZ+UHgyJEjsWLFCty+fRscx+H06dPYvHmzVWU/izPOA5BE0E1JSYFarUb37t3FllImfN1yX79+HdHR0ebeQd++ffHtt9/i9u3baNiwoVVlenl5YeHChfDx8bHouB49euDw4cNwdXVFbGwsZs2aBS8vL6se/LwKHx8fVKlSpdR3aM33qVAorPK5bNkyTJgwAa6urggPD8d3332H5s2bIy8vz2INz2MwGMwXL6AkgMycOROPHz/GL7/8YnXSxgoVKljs9erVq2jatClcXV1RpUoVfPfdd+jbt6/V83Pr1q2LCxculPlecnIy1Go1unbtCqDkImE0GjFw4MBSF6Dy4GjPb8qDJIKuSqXCqFGjHDYdj0zGz9B5WFgYDh06hOjoaPj5+WHv3r1gWdbmdOvW6IuIiEC1atUwa9Ys8y2ytT9QjuNQXFxsfihSVFQEhmHg5uYGoOTisn79erRr1w5yuRxr165Fhw4dLD6PNT7r1KmDbt26Yc6cOWjevLlNZe3ZswctWrRAUFAQ1Go1Fi9eXOoh0pw5c3Dnzh2sXLnS5o2aLNUXEBCANm3aYNKkSejbt6/5eGvbbs+ePc0zNVq3bo1169ahYsWKqFmzJqpVq4YePXqYP7t69Wqo1Wp8+eWXFp+Hr9+WPXH6oKvT6bBu3Tq7Tj63FIVCwctt0OjRo/H48WMMGDAAhYWFqFatGhYtWmTTQhCO46z6gc+YMeOF16z1qVarER0dbf67ZcuWCA4OxsGDBwEAH330EZ48eYJevXrBzc0N3bt3x4cffmjROaz12aVLF3Tp0uWF163xevv2bfzwww/QarVQKpWIjIzEpEmTAJR8B5s3b4abmxs6duxoPmbmzJkWZ3e2xqu/vz/++OOPF163tk5r1KiBBQsWYM6cOcjJyUH9+vWxZMkSuLq6wtXVtdRwmKenJ9zc3ODn52fROaytU7Fx+sURGzZsgEqlwuHDh8WW8kquXLnikIP+MpkMDRo04K08UnwC5Hglxae9cL6++XOoVCqH3MLxeRz1isy3LlJ8ClUmH9A6dWycOujeuXMHFy5ccMjsEM+jVCodbtDflOqaT0jxCZDjlRSf9sKpg+6qVaswZMgQp7jiPbsSyZHgWxcpPoUqkw9onTo2Tht0HTk7RFnI5XKHuzIrlUrI5fw+SyXFJ0COV1J82gunDboHDx5ESEiIw2WHAEpmVHTs2BGNGjVCgwYN8MYbb8DHxwcXL150mNs0hmHg7+8vSNn+/v5E+ATI8UqKT3vgnJcKlGxu46i9XHd3d2RmZuL69evm13x8fBAeHo78/HzR056YxsM8PT0FKd/T0xNKpVLyPgFyvJLi0x44ZU83KysLx44dc7jsECbu3LlTautBDw8PHD9+HH5+fggODha9x8AwDO/LdZ+HFJ8AOV5J8Sk0Thl0HTU7RFpaGj766CO0atUK7dq1Q8OGDSGXy7FgwQI0bdoUQMn4WGhoqGiNl2EYhIaGCr56jxSfADleSfEpNE4XdDmOg0qlcqjEkw8ePM
CECRPQpEkT+Pr6IjU1FbNnz8bKlSsxatSoFzZm8fb2hp+fn90bL8Mw8PPzs9vFihSfADleSfEpJE43pvvXX3+BYRhERESILQXZ2dn49ttvERcXhxEjRuDq1aulNtcODw9HeHh4mccGBgbCYDAgLy/PLmNkDMOgYsWKNm+lZymk+ATI8UqKT6Fwup6u6QGamGNLT548wYwZM1CvXj3k5+cjJSUFP/zwg0XZDExjU/boNZh6CWKMyZHi03R+EryS4lMonGrvhby8PFSvXh2pqamibFau1Wrx008/4ccff0Tv3r0xY8YM1KhRw+ZyNRoN0tPTwXEcrz0HhmHM42COcFtGik+AHK+k+OQTpwq6v/76Kw4dOoStW7fa9bwFBQVYunQpFi5ciC5dumDWrFmoW7cur+dgWRZqtZq3KTmmqTXBwcEONYmcFJ8AOV5J8ckXDhF09Xo9dDodjEYjOI4DwzCQyWRQKBSlNnJu3bo1Zs+eXWovTiHR6XRYuXIlFixYgIiICMyePdvqzcLLS0FBAbKzs81701pSPabbL6VSCX9/f4eey1hQUIBHjx6Zkz5K1SdAjldSfNqKKJcRlmWRm5sLrVYLnU5nDrTPY3pdoVCAZVk8ffoU3bp1E1yfXq9HfHw85s6di8aNG2Pv3r1o1qyZ4OcFSiahV6tWzarvSKlUwtfX1yl6Bzk5OQgLC8PEiRMxceJEyfosKipCeHg49Ho9UlJSJF2nM2bMwNKlS5Gfny9pn7ZiV4ev6sW97KrIcRwKCgrAMAw2btyIjIwMwa6ELMti3bp1mD17NmrXro1Nmzahbdu2vJ+nPMjlcgQEBJjzgpX3bsAZuHr1Ktq3b4/8/HwYDAbJ+tRoNOjWrRtSUlJQu3ZtydYpx3H473//i6VLl8JgMIBhGEn65Au7BF0+xnxMx2k0GvPO+3yN+RiNRmzatAlfffUVKleujPj4eERFRdlcLp+Ydtx3dpKSkhAdHQ2NRgMAyMzMLPW+VHxmZWWhQ4cOuHv3LoCSBJ7PIwWver0ew4YNw549e8CyLNzd3ZGdnV1qJo8UfPKJ4EFXiKebHMdBq9UiNTXVpqebHMdh586dmDlzJjw8PLBkyRJ06dJFMlNTHJGJEyeiqKjI/Pf9+/dFVCMcP/30E27dugWDwQCgJKuwFDlx4gQ2bdpkXiVm2nfEkumTpCHYPF2O4/DgwQOkpaWZbyv4Lt9oNCItLQ0PHjywqHyO47B//360atUKs2fPxvz583H69Gl07dqVBlyBOXnyJGbMmAGFQgGZTIZHjx6JLUkQ5s2bh0OHDsHNzQ3u7u4oKCgwJ96UEl26dMGtW7cQGhoKT09PaLVaPHz4UGxZDo0gPV2O45CRkWGXFSscxyEnJwcGgwEhISGvDZrHjh3Dl19+iby8PMyePRv9+vVzyoyizoqbmxtu3ryJ//3vf/jPf/7zwvCCVGAYBunp6ejSpQt+++03nD59WrK32DKZDPn5+cjKysKZM2fQunVrsSU5NIJMGXvw4AFycnLsugWcaeVKUFBQme//9ddfmDFjBtLS0vDVV19h0KBBTr9xhjOi0WhQrVo1XL9+XfK3oFFRUZg0aRJiYmLEliIos2bNQm5uLhYvXiy2FKeA956uRqOxe8AF/q/H6+XlVWqM9++//8aMGTNw9epVzJw5E7GxsURMS3FUNmzYgM6dO0s+4KampuL69esWp093NgwGA1atWoU9e/aILcVp4PW+mmVZ80MzMeA4Dunp6WBZFikpKYiJiUHfvn3Rq1cvXL9+HaNHj6YBV2Ti4uIcaoc4oVCpVIiNjZXskIKJQ4cOITAwEE2aNBFbitPAa9BVq9Wi7ioP/N+MhG7duqFDhw64ceMGxo4dC3d3d1F1UYCUlBQ8ePAA3bt3F1uKoOj1eiQkJDhsZhM+cbRtVp0B3rp9BQUFoqfyAEqCbp06dXD58mVUqlRJVC2U0qhUKowcOVLyY+l79+5F7dq1Ua9ePbGlCMrDhw9x5M
gRqFQqsaU4FbwF3ezsbNEDrgmZTIanT5/SoOtA6HQ6rFu3DsnJyWJLERxSen9r165F37594ePjI7YUp4KX4QWWZc1Le62hb9++vP8YtVotWJbltUyK9ezYsQNNmzblZStMRyYjIwN//fUX+vfvL7YUQeE4zqGTwzoyvARdW1fb7NixA61atTL/PX36dFslAZDuKiBnhJTe35o1azBw4EB4eXmJLUVQkpKSYDQa0b59e7GlOB28BF0+xnIfPHiAb775xrxGPTU1FQsXLrS6PNNSYYr43LlzB+fPn0ffvn3FliIoRqMRKpWKiN6fI2RwcVZ4GdPV6XQ2Hd+9e3d89dVX6Nq1K7744gtcvnwZHh4eGDt2rKi6KPywatUqDB06FAqFQmwpgnL8+HFUqFABLVu2FFuKoGg0Gmzbtg3Xrl0TW4pTYnPQ1ev1gjxAk8lkNl9FOY6DXq+X/FxJR8ZgMCA+Ph779u0TW4rgmIZQpN7727hxIzp37iyZRJH2xubhBZ1Ox0sje/ToEQ4fPoz58+ejVatWePfdd7Fq1SqbymQYhvZ2RebgwYMIDg5G48aNxZYiKLm5udi7dy+GDh0qthTBoQ/QbMPmnq7RaORDBwICAtC7d2/z3/Xq1eNlniNf+ijWQcoDtHXr1uHtt9+Gn5+f2FIE5dKlS8jIyJD8AhchsbmnK8TQwrx583gry1HmDpNIVlYWjh49ikGDBoktRVBImj4VFxeHkSNH0uX0NmDzN+fo41eOrk/KrF27FjExMZJLof08586dg0ajQadOncSWIihFRUVYt24dzpw5I7YUp8bmnq6j70Xr6PqkCmm9v9GjR0u+re3YsQONGzdGzZo1xZbi1Njc01UoFDbfwh88eNBWGWXCcZzkpyk5Kn/99RcAoF27diIrEZaCggJs3LgRKSkpYksRHFJ2iBMamy/Nrq6uDnsLzzAMnS4mEqRMn9qyZQsiIiIQGhoqthRBMS1wkfqG7PaAl/shR+1NOqouqaPRaLB9+3bExsaKLUVwSBlCiY+Px5AhQ+hvigd4CbpKpdLhejQ6nQ5btmzB1q1bS2WfpQjPhg0b8NZbb6Fy5cpiSxGU1NRUpKamEpEdIj4+ng4t8AQvQdfX15ePYnhFoVAgKCgIP//8M4KDgzFmzBgkJSXRKWR2gJTeH2nZIaS+wMVe8JaY8v79+9BoNHwUxQve3t6oVq0aAODevXtITExEQkICjEYjYmNjMXz4cLzxxhviipQgKSkpeOedd3D37l1Jb1au1+tRtWpVnDhxQvKblb/77rvo1q0bPvroI7GlSALe5rj4+/s7zBADwzDw9/c3/129enVMnz4d165dQ2JiIjIzM9GyZUtERUVBpVIhLy9PRLXSQqVSYdSoUZIOuEBJdoi6detKPuA+fPgQR48exeDBg8WWIhl4C7qenp4OMbbLMAyUSiU8PT3LfK9NmzZYunQpMjIyMGnSJOzZswfVqlXD4MGDsX//frrxuQ2YskOMGjVKbCmCQ8oQSkJCAvr27Sv5BS72hLfhBaAkg0Rqaqqo+x3IZDLUq1fPop7W48ePsXHjRiQkJODevXsYOnQoYmNj6RiWhWzYsAEqlQqHDx8WW4qgZGRkoFGjRkhLS5P0ZuUcx6FBgwZYuXIl3aycR3hdQiOXyxEaGipab5dhGISGhlp8a1upUiWMHTsWp0+fxvHjx6FQKNCrVy80bdoUixYtQmZmpkCKpQUpvb/Vq1cTkR3i1KlT4DhO8gtc7A2vPV0TDx48QE5Ojl1nCjAMAz8/PwQFBfFSntFoxIkTJ5CQkIAdO3YgPDwcsbGx6NOnDzw8PHg5h5S4ffs2WrdujfT0dEnP5TQajahduzY2bdok+c3KR40ahQYNGmDq1KliS5EUggRdjuOQkZGBvLw8uwRehmFQsWJFBAcHC9LLfvr0KXbs2IGEhAQkJyejX79+iI2NRfv27SW/3r68zJgxA1qtFj/++KPYUgTl2LFjmDx5Ms6fPy/68w
sh0Wg0qFatGq5fv44qVaqILUdSCBJ0gZLAm5mZKXiP19TDDQwMtMuPICMjA+vXr0dCQgLy8/MxfPhwDB8+HHXq1BH83I6KwWBA9erVsX//fjRq1EhsOYIyZMgQhIeHY/z48WJLEZQVK1bgwIED2LZtm9hSJIdg3TSGYRAUFISqVavyknqnrPJlMhmqVq2KoKAgu/U6QkJCMHXqVKSkpGD79u3Iz89HZGQkIiIisHz5cuTk5NhFhyNx8OBBhISESD7g5uTkYN++fURkhyBl83kxEKyn+ywsy0KtVvOSNRj4v2lhwcHBDrGZsl6vx+HDh5GQkIADBw6gS5cuiI2NRY8ePZx6tZJer4dOp4PRaATHceYLnUKhKOWrX79+iI6OxocffiiiWuspr88lS5YgKSkJ69evF1GtbZTH68WLF9GjRw/cu3fPaedbl7dOxcAuQddEQUEBsrOzzanRLTm1qSerVCrh7+9f5jxcR+DJkyfYsmUL1qxZg+vXr2PQoEGIjY1FixYtHH4MkGVZ5ObmQqvVQqfTmRvr85heVygUYBgGUVFRuHTpktPM5bTGp1KpRExMDP73v/+hc+fOIqi2Dmu8/v7771Cr1ZgxY4YIiq3D2jr19fW1e8fNrkHXhDN9QbZw69Yt8/Jjd3d3xMbGYujQoahatarY0kphy8UQKHmiX7FiRYe+GAK2X/RZlkXFihUREBDg0D4B27waDAa4uLjA29tb8nUK2L8jJ0rQfR5HvhXgA47jcOrUKSQkJGDLli1o1qwZYmNj0a9fP1SoUEE0XVIf9jFBik+AHK/O7NMhgi5J6HQ67N69GwkJCfjzzz/Ru3dvxMbGolOnTnYdP9NoNEhPTwfHcbzOLmEYxrxIxRGGG0jxCZDj1dl90qArIg8fPsRvv/2GhIQEZGVlYdiwYYiNjUWDBg0EO6dUp/I9Dyk+AXK8SsUnDboOwqVLl7B27VokJiYiKCgIsbGxGDx4MAICAng7hxiLVnx8fBASEmLXHykpPgFyvErJJw26DobBYMCxY8eQkJCA3bt3o0OHDoiNjUWvXr3g7u5uU9lSWJ5dHkjxCZDjVUo+adB1YLRaLbZt24aEhARcuHABAwcORGxsLNq2bWvx1Vej0SAtLU2UzBkMw6Bq1ap2GQ8kxSdAjlep+aRB10m4f/8+1q1bhzVr1sBgMCA2NhbDhg1DjRo1zJ/Zs2cP/Pz8EBERUepYR9lys27duoI+GSbFJ0COVyn6pLu1OAnVqlXD559/jqtXr2L9+vV4+PAhWrdubc5+8eTJE3zwwQfo0qULzp49W+pYtVotem44juOgVqsFPQcpPgFyvErRJ+3pOjHFxcXYv38/EhIScPDgQRQVFYFlWSiVSpw6dQphYWEoKCjAnTt3RG+4QMmtWo0aNQSZhE6KT4Acr1L1SXu6Toybmxv69OmDrVu3olevXjAYDABKxoJbtmyJ5ORkZGdnO0SjBUp6DNnZ2YKUTYpPgByvUvXpOEtMKDZx7NgxACW56ipWrAgXFxdcu3bN4ZZwarVasCzL6zggy7LmZaCOghA+AXK8StknDboS4dKlS/Dw8Ci1rPjRo0d4+PChiKrKJjc3l9f5x7m5ubyVxSd8+zSV6YjQOi0/dHhBIgQEBLywjwNf69Kf5cCBA+jduzfatGmDPn364OjRoxYdz3Ec7z0YW33q9XpMnjwZ3bt3R6NGjZCcnPzCZ65cuYIRI0aYH14mJia+skwhfAK2eb116xb+85//ICIiAhEREXj//fdx69Yt8/vx8fGIiYlBmzZtEB0djfj4+HKV64h1CgCFhYWYO3cuIiMjER4ejhEjRrzwGb1ej969e+Ott956bXl8+aQ9XQmj0+l4LS8rKwuff/45Fi9ejPbt2+PPP//ElClTcODAAVSqVEk0XXyU16xZMwwbNgxTpkx54b3c3FyMGTMGU6dORbdu3aDX65GVlWUXXXyWGRAQgEWLFiE4OBhGoxEbNmzA1KlTzdkhOI7DvH
nzULduXaSlpeGjjz5CYGAgevToIaguocqbPXs2DAYDdu7cCR8fH1y7du2Fz8THx8PX1xdPVs+IrgAACHhJREFUnz61my7a05Uoer3eqp5Co0aNcP/+ffPf06dPx+LFiwGUBF1vb29ERkaCYRh06NABHh4eSEtLs+gcHMdBr9dbrC07O/uFqTvl9fkqX66urhg+fDiaN29eZs67hIQEREREoGfPnnBzc4OXlxdq1qz52nNa67O4uLjMAFEer6/y6e3tbV7WynEcZDJZqbobPXo0GjRoALlcjho1aqBTp044f/58uTRb6/XSpUsvzMHlo05v376N48ePY9asWfDz84OLiwsaNmxY6vj09HTs2bPHogwZ1vp8Fhp0JYpOp+N9zXjDhg1Ro0YN/P777zAYDDh69ChcXV1Rt25di8phGMaqHsO3336LatWq4f3330d6ejoAYXw+T0pKCnx8fDBs2DBERUVh3LhxePDgwWuPs9bn8ePHUb9+fXTs2LHUUAdfXiMiItCyZUssWLDgpQGH4zicO3cOtWvXLleZ1ngtKipCo0aNULNmTWzatMkcfPnweenSJQQFBWHp0qWIjIxETEwMDh8+XOozCxYswMSJEy3KXm1tnT4LHV6QKEKs4HFxcUHv3r3x6aefori4GK6urvj+++8tniFhMBjwzz//oLCw0KLjbt68CYPBgDVr1mDt2rXo2LEjfv75Z4vKsIasrCxcvXoVK1asQJ06dbBo0SJMmzYNa9eufeVxRqPRKp/Jycnw8PDAiRMnEBkZibp16+Lbb79FeHi4LTbMnDp1CgUFBdi1axeCg4PL/Mwvv/wCo9GIvn37lqtMlmUt9lpcXAyZTIZ79+5hxIgRGD9+PD799FO899575S7jZWRlZeHmzZvo2rUrjh07hgsXLuCTTz5BrVq1ULNmTRw9ehQGgwFvvfVWmWP4r8LW3xYNuhJFiPmNSUlJWLRoEeLj41G/fn1cuXIF48ePx7Jly/Dmm2+Wuxy9Xo9Dhw6V+9bVhOmWm2VZMAyDI0eO4MKFCy/cNvKNu7s7OnfujLCwMADAmDFjEBkZCa1WC6VS+dLjWJa1yuejR4/Mt7BFRUW4ePEifv31V7Rt29Z6E8/h6emJgQMHokOHDti5c2epMfn169dj9+7dWL16Ndzc3MpVXnFxscVeDQaDuZ3qdDoUFRXhxx9/xOjRoy0zUwbu7u6Qy+X48MMPIZfL0apVK7Ru3RqnTp1CYGAgFi1ahF9++cWqsm39bdGgK1GsvT3z8PAo1Vt5/PgxqlSpAgC4fv06WrRoYQ5yYWFhaNSoEU6fPm1R0FUoFPj000/h4+NjkbYpU6bgp59+gru7OyZOnIipU6dCJpMhIyPDJl+vo27duqW+z/J+t25ublb53L9/P/r06QOFQoGePXti/vz5qFOnDvLy8l57rCU+jUYjdDodHj58aA6627dvh0qlwurVqxEYGFhuzZ6enhZ71el08PT0hJeXF+rWrYvvv/8eHTt2hEajKdfxr/Ja1pCXqd7u378PtVptns2g1+uRn5+Pjh07Yt26dQgJCXnleW0d+qBjuhKlrAdC5aFevXrYt28fDAYDTp48ib///tv8XsOGDXHu3Dlzj/Pq1as4d+6cxWO61uqLiorCZ599hvT0dMyfPx++vr7lLudVvoCSnlpRURGAkh9hUVGRuUfTt29fHD16FNeuXYNer8fy5cvRvHnzV/ZyTVjjs169ehg1ahRSUlKwefNm1KlTp9xlvcrnqVOncPXqVRgMBuTn5+O7776Dt7e3+aHgnj178NNPP2HlypVW5fGz1KubmxtGjx6N3bt3459//kGnTp3MqbrKw6u8tmjRAkFBQYiLiwPLsjh//jzOnj2LiIgI1K5dG4cPH8aWLVuwZcsWzJ49G5UqVcKWLVvKdaGx9rdlgu69IFH0ej1SU1MtvhW6fPkypk+fjgcPHqBz584wGAwIDQ3FhAkTAJTceiYmJuLx48fw9fXF4MGDy5
z/+CoYhkHdunV5yX9XXp+v89W9e/cXZkYcOHDA3OvZuHEjVqxYgcLCQjRv3hxffvnla3+gfPoEyuf1VT4PHjyIn3/+GVlZWVAoFAgLC8PEiRNRr149AEB0dDSysrJK6e3Zsydmzpz5Wm2OWKc3b97ErFmzcOPGDQQFBWHChAllzsdNTk7GZ599Vq4553z4pEFXwly5ckXULfFehkwm4zUlESk+AXK8StknHV6QMJZMhbEnfOsixadQZfIBrdPyQ4OuhFEqlaIlS3wZplTXfEKKT4Acr1L2SYOuhPH19RVbQpnwrYsUn0KVyQe0TssPDboSRi6XC9LbsgWlUsn7doek+ATI8SplnzToShx/f3+HuU1jGAb+/v6ClE2KT4Acr1L1SYOuxPH09HSI8THTeJhQm6qT4hMgx6tUfdKgSwDBwcEO0XBft9LHVkjxCZDjVYo+adAlALlcjtDQUNEaL8MwCA0NhYuLi6DnIcUnQI5XKfqkQZcQvL294efnZ/fGyzAM/Pz84O3tbZfzkeITIMer1HzSoEsQgYGB8PHxsVvjZRgGFStWtGjjFD4gxSdAjlcp+aTLgAmD4zhkZmYiJydH0PTWpl5CYGCgKLeGpPgEyPEqFZ806BKKRqNBeno6OI7jtQEzDGMeB7PnrfbLIMUnQI5XZ/dJgy7BsCwLtVrNW9Zg09Sa4OBgQRYGWAspPgFyvDqzTxp0KSgoKEB2drY5vbQlTcJ0+6VUKuHv7y/o/FRbIcUnQI5XZ/RJgy7FDMuyyM3NhVarhU6nA8dxZY5pmV5XKBRQKpXw9fV1qF7Q6yDFJ0COV2fySYMu5aXo9XrodDoYjUZzY5XJZFAoFLxtzO0IkOITIMerI/ukQZdCoVDsCJ2nS6FQKHaEBl0KhUKxIzToUigUih2hQZdCoVDsCA26FAqFYkdo0KVQKBQ7QoMuhUKh2BEadCkUCsWO0KBLoVAoduT/A4zrKQxDm00EAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
" 'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],\n",
" 'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],\n",
" 'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
" 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],\n",
" 'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],\n",
" 'c64': [2, 3], 'c128': [3, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 5))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zRfInAL21i_m"
},
"source": [
"A disadvantage of this approach is that it still leaves `int64` and `uint64` promotion undefined, because no standard floating-point type has enough mantissa bits to represent their full range of values. We could relax the precision constraint and complete the lattice by drawing connections from `i64->f64` and `u64->f64`, but those links would run counter to the motivation for this promotion scheme.\n",
"\n",
"A second disadvantage is that this lattice makes it difficult to find a sensible place to insert `bfloat16` (see below) while maintaining the lattice property.\n",
"\n",
"A third disadvantage of this approach, more important for JAX's accelerator backends, is that some operations result in types that are much wider than necessary; for example, mixed operations between `uint16` and `float16` would promote all the way to `float64`, which is not ideal."
]
},
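{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of how the promotions implied by this lattice can be read off programmatically, the cell below computes the join (least upper bound) of two types by intersecting the sets of types reachable from each. The helper names `upper_bounds` and `join` are hypothetical, introduced here only for illustration; they are not part of JAX or NumPy. The `lattice` dictionary is copied from the plotting cell above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only: `upper_bounds` and `join` are hypothetical helpers,\n",
"# not JAX or NumPy APIs. The lattice is the Option 1 lattice from the cell above.\n",
"lattice = {\n",
"    'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
"    'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],\n",
"    'i8': ['i16', 'f16'], 'i16': ['i32', 'f32'], 'i32': ['i64', 'f64'],\n",
"    'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
"    'c64': ['c128']\n",
"}\n",
"\n",
"def upper_bounds(node):\n",
"    # Return `node` together with every type reachable from it along lattice edges.\n",
"    seen, stack = {node}, [node]\n",
"    while stack:\n",
"        for parent in lattice.get(stack.pop(), []):\n",
"            if parent not in seen:\n",
"                seen.add(parent)\n",
"                stack.append(parent)\n",
"    return seen\n",
"\n",
"def join(a, b):\n",
"    # The join is the common upper bound from which all other common upper\n",
"    # bounds are reachable; return None when no such type exists.\n",
"    common = upper_bounds(a) & upper_bounds(b)\n",
"    least = [t for t in common if common <= upper_bounds(t)]\n",
"    return least[0] if len(least) == 1 else None\n",
"\n",
"print(join('u16', 'f16'))  # prints f64: wider than either input\n",
"print(join('i64', 'f64'))  # prints None: this promotion is undefined in the lattice"
]
},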
{
"cell_type": "markdown",
"metadata": {
"id": "Ksu9PCrTFyJo"
},
"source": [
"### Option 2: Avoid most wider-than-necessary promotions\n",
"\n",
"To address the unnecessary promotions to wider types, we could accept the possibility of some precision loss in integer/float promotion, promoting signed integers to floats of the same width:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"cellView": "form",
"id": "8tLGLvGM2h6O",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAEeCAYAAAApRMZ1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2dd1gUV/v+71l2YUGXJioCdoVY0IiCFWtUUDT2GpNorNFEY5QkGs3PFmNiezWWqFGTCPYau/JaI5bYaywoAgvEFYTVZWHL/P7gu/uKUrbM7syePZ/rynXFZefMfe955plnzpw5w7Asy4JCoVAodkHEtwAKhUJxJmjSpVAoFDtCky6FQqHYEZp0KRQKxY7QpEuhUCh2hCZdCoVCsSM06VIoFIodoUmXQqFQ7AhNuhQKhWJHaNKlUCgUOyLmWwDlf2g0GqjVauj1erAsC4ZhIBKJIJVKIZFI+JZHIRQS407InmjS5RGtVovs7GwolUqo1WpjcLyJ4XOpVAqZTAYfHx+IxbTrKJZBYtw5kieGLnhjf1QqFRQKBZRKJYDCQDAVQyDJZDL4+fnBw8PDJhop5EFi3DmiJ5p07YhWq4VcLodSqTQrOEqCYRjIZDIEBAQItgKh8A+JcefInmjStRO5ublITU0Fy7KcBIkBhmHAMAyCgoLg6enJWbsUMiAx7hzdE026NoZlWWRkZCArK4vTAHkThmHg6+sLf3//YseyKM4FiXFHiieadG0Iy7JIS0tDTk6OTYPEAMMw8PLyQmBgIE28TgyJcUeSJzpP14ZkZGTYLUiAwsDMyclBRkaGXfZHESYkxh1JnmjStRG5ubk2vwwqDpZlkZWVhdzcXLvulyIMSIw70jzRpGsDtFqtcaCfD1iWRWpqKrRaLS/7p/ADiXFHoieadG2AXC7nLUgMsCwLuVzOqwaKfSEx7kj0RJMux6hUKs7mDloDy7JQKpVQqVS86qDYBxLjjkRPAE26nKNQKHgPEgMsy0KhUPAtg2IHSIw7Ej0BNOlyilarNT6OKBSUSiUd2yUcEuOORE8GaNLlkOzsbJu0O336dKu2t5UuijCwtn+zsrIQGxuLli1bolWrVvjqq6/e+k5OTg7atm2LDz/80C66zNn22bNn+Oyzz9CxY0eEhoYiLS2tyN8XLlyI7t27o3nz5ujRowf27dtX5O8XLlzAgAED0KJFC0RFRWH79u2c6CoJmnQ5hMvxp7y8PMyaNQsvXrwAUHhDYdasWWa3bxiPopCLtXH3xRdfwM/PD0ePHsWpU6fw8ccfv/WdJUuWoGbNmia3aW3cmeOJYRi0bt0aixcvLvbv7u7uWL58ORITEzFv3jz88MMPuHbtGoDCJSAnTZqEfv36ITExEQsXLsRPP/2Ef/75h3NPBugqKRyiVqst3jYjIwM//PADrly5Ar1ej+joaAwZMgRz587F7du3sWLFCowbN86ip2Os0UURPqb2b3Ex1qFDB2RkZGD9+vVwcXEBANSrV6/IdteuXcODBw/Qr18/7N69m3Nd5mxbnIfp06dj0KBBJV76jx8/3vj/jRo1QtOmTXH9+nW8++67yMnJwcuXL9GjRw8wDIOGDRuiVq1aePToEUJCQjj1ZIBWuhyh0WgsrjZ0Oh3Gjx+PKlWq4PDhw0hISEB0dLTx768vwmwJLMtCo9FYtC1FOHzyyScYNmwYnjx5YvzM1LgrKcauX7+OGjVqYPr06WjTpg0GDRqES5cuFdnu+++/x7Rp08w+4ZcVd1qtFqGhoZg/fz5evXpVpqeyjhNTUKvVuHXrFmrXrg0A8PPzQ3R0NPbs2QOdTodr164hPT0dYWFhFnkyBZp0OUKtVlv8jPbNmzfx7NkzfPnll/Dw8ICbmxvq1auHuLg4fPvtt2jatCk+/fRTrFu3zqLEzjAMrXYJ4Nq1a4iPj0e9evWMydfUuCsuxsLCwpCZmYlz584hIiICJ06cwEcffY
SJEycaxy7j4uIQGhqKBg0amK23rLjTarW4ffs25s6di4CAAGPyLclTSR7MYc6cOQgJCUHr1q2Nn3Xr1g2rV69G06ZN8fHHH+Ozzz6Dv7+/RZ5MgQ4vcIRer7d424yMDFSpUqXIOp7u7u74f//v/xn/HRAQgJkzZ1rU/qtXrzB8+HCcPn3aYo0U/snKyoJer4darcamTZuwadMmXLx4EeXKlStz2+JiDACkUikCAwPRp08fAEB0dDTWrFmDq1evomHDhoiPj8fWrVst0qtSqUqNO8PSjIb5r9OmTcO8efPeuhFWlgdTWbRoER48eID169cbk3pSUhJiY2OxZMkStGzZEsnJyZgwYQIqVaqEtm3bFtuONcc6QJMuZ1hzI8Pf3x8ZGRnQarXFBtS8efOskQYPDw8sXrwYUqnUqnYo/NK5c2fcuHEDHh4eqFatGubMmYO6deua9LRUSTEWHByMkydPFvmuISEZKsv3338fAJCfnw+1Wo327dsjISHBOAZcEu7u7qXGnVqtRs2aNcEwDCQSCTp37ozZs2eXeCyVdZyUxooVK3D27Fls2LAB5cuXN37+8OFDVK9e3Vj51qxZE23btsWZM2dKTLrW3iynSZcjrFn+LTQ0FH5+fli6dCk+/fRTuLi44M6dO2jSpAln2ry8vODl5cVJexR+qFSpEho2bIhFixahc+fOYBgGOTk5Jm1bUox16tQJCxcuxN69exETE4OEhARkZmaiSZMmKFeuHI4cOWJs4/Dhwzh48CCWLVtWZsIFyo47jUYDNzc3xMTEYN68eahbty4AlOiptOMkPz8fOp0OAFBQUID8/Hy4ubkBANatW4eDBw/it99+g7e3d5E269Wrh+TkZFy4cAERERFITU3FqVOnMHz48FJ9WQNdT5cjlEolUlJSLL70SE9Px/z583HlyhUwDINu3brhm2++4USbSCRC1apVIZPJOGmPwg86nQ4ikajIQW9O3JUUY5cvX8bcuXORlpaGmjVrIjY2Fk2bNn1r+z179mDXrl34/fffTdJrStzpdLq3EnhpnkryEBoa+tZ3b968CaAwWUskkiLV8ahRozBq1CgAhSeTX375BXK5HOXLl0f37t0xadKkYm9cc3Es0aTLERqNBvfv3xfMY4uvwzAMgoODeX/1NIV7SIw7Ej29Dp29wBESiUSwb2swjJlRyIPEuCPR0+vQpMshQr1RJVRdFG4Qav9ao4tETwZo0uUQmUwmuDO04dXSFHIhMe5I9GSAJl0O8fHx4VtCsQhVF4UbhNq/1ugi0ZMBmnQ5RCwWC66qlMlkFk8mpzgGJMYdiZ4M0KTLMX5+foK5LGIYBn5+fnzLoNgBEuOORE8ATbqc4+HhIYjxKMP4k4eHB686KPaBxLgj0RNAk65NCAgIEESgBAYG8qqBYl9IjDsSPdGkawPEYjGCgoJ4CxaGYRAUFGTSo5oUciAx7kj0RJOujfD09ISvr6/dg4VhGPj6+sLT09Ou+6UIAxLjjjRPNOnaEH9/f3h5edktWBiGgbe3d4lrgVKcAxLjjiRPdO0FG8OyLDIyMpCVlWXTZ8kNZ2V/f3/ex8Ao/ENi3JHiiSZdO5Gbm4vU1FTjws1cwTCMcdyJDilQ3oTEuHN0TzTp2hGtVgu5XM7ZW4MNU1kCAgLoAxCUEiEx7hzZE026PKBSqZCUlAStVguJRGJW0Bgud2QyGfz8/Og8XIrJqFQqpKam4tWrV3B1dSUi7lQqFa5du4Zy5cpBLBY7hCdaHvFARkYGWrZsCTc3N9y9exdKpRJqtdr41t83MXwulUohk8ng4+NDK1uK2bx69Qrt27eHSqXCvXv3oFKpHD7u/v77b3To0AEhISE4deqUQxxLwvoFnYB//vkHrVq1wsuXL6FSqVCuXDlUrFgRQOHizWq1Gnq9vshr16VSKV0Pl2IV6enpaNmyJTIyMuDq6gqlUml8PY6jxt2xY8fQq1cv6PV6KBQKVKxY0SGOJZp07ciTJ0/QvHlz5O
bmAih8zDE5ORn16tUDULh4M98BQSGP7OxshIeHIyMjAyzLws3NDU+ePDEmXUeMu5MnT6J79+7QaDQACj3q9XrjK3aE7InO07UjDMMgIiICIpEIYrEYGo0GT5484VsWhXBYlkV4eDhEIhFcXFyQl5fn8HHn6uqK0NBQiEQi432RjIwMvmWZBK107Uj16tVx9OhRNGzYEN26dUNiYiLKlSvHtywK4fj6+mL37t2IiopCnTp1cOfOHcGuV2sqrVq1wqlTpxAQEIDJkyfj0KFDgnynWnHQ2Qt25tatW4iKisLTp0+LfdsohWILFAoF6tSpg9TUVJQvX55vOZwQHx+PuLg4HDhwgG8pZkGPejuzefNmDB48mCZcil3ZsWMHoqOjiUm4QGHSHTJkCN8yzIZWunaEZVnUqlULu3fvxrvvvsu3HIoTERkZia+++goxMTF8S+EER67cabllRxITE+Hu7o7GjRvzLYXiRCQnJ+Pu3bvo0qUL31I4Y/v27ejWrZvDJVyAJl27YrgcogvSUOzJli1b0K9fP7i6uvIthTMcdWgBoMMLdkOj0SAwMBCJiYmoXbs233IoTkTjxo2xfPlytG3blm8pnJCcnIymTZtCLpc75ImEVrp2IiEhAbVq1aIJl2JXbt26haysLLRp04ZvKZzh6JU7Tbp2wpEvhyiOC4mzZRz9WKLDC3ZApVIhICAA9+7do291oNgNEmfL3Lp1C9HR0UhOTnbYE4ljqnYw9u/fj4iICJpwKXbl/PnzxM2WIaFyp48B2wFHvxyiOCakzZZhWRbx8fHYvXs331Ksgg4v2Jjs7GzUqFEDT58+hZeXF99yKE6CVqtFYGAgzp07R8zN28TERHzyySe4ffu2Q59IHLdGdxB27NiBLl260IRLsSvHjx8nbrZMXFwchg4d6tAJF6DDCzYnPj4en332Gd8yKE5GfHw8Bg8ezLcMztBoNNi+fTvOnTvHtxSroZWuDUlNTcX169fRrVs3vqVQnAiVSoU///wTAwcO5FsKZ5A0z50mXRuydetW9O7dG1KplG8pFCfCMFumcuXKfEvhDJJuRtOka0NIChSK40Ba3Bkq9wEDBvAthRNo0rUR9+7dg1wuR/v27fmWQnEisrOzceLECfTu3ZtvKZxBWuVOk66N2Lx5MwYNGgQXFxe+pVCciJ07d6JLly7w9PTkWwpnkFa506RrAwyTuEkKFIpjQFrckVi506RrA/7++28AQLNmzXhWQnEm0tLScO3aNURHR/MthTNIrNxp0rUB8fHxREzipjgWW7ZsIW62TFxcHFGVO0AfA+YcnU6HoKAgnDx5EiEhIXzLoTgRTZs2xY8//ohOnTrxLYUTUlNT0ahRI8jlcqJOJLTS5ZiTJ08iMDCQJlyKXSFxtgyp89xp0uUY0m5kUBwDEmfLkHos0eEFDlGr1QgICMDNmzcRGBjItxyKk8CyLIKDgxEfH4/w8HC+5XDCvXv30KFDB6SmphJ1IgFopcsphw4dwrvvvksTLsWukDhbhsTK3QBNuhxC6uUQRdiQulg5qccSHV7giJycHFSrVg1PnjyBj48P33IoTgKJs2UuXbqEoUOH4p9//iHmRPI6tNLliN27d6NDhw404VLsComzZUir3N+ELmJuAhqNBmq1Gnq9HizLgmEYiEQiSKVSSCQSAIWBMmrUKJ6VOi+m9JGjYYonR3t4oCxPOp0OW7ZswcmTJ/mWajNo0i0GrVaL7OxsKJVKqNVqY3C8yetBU7t2baIevxQ65vaRVCqFTCaDj48PxGJhhr25nlxdXeHm5ob+/fvzoNY0zPWkVqsxYsQIIhYrLwk6pvsaKpUKCoUCSqUSQGEgmIpOp4NYLIZMJoOfnx88PDxsJdOpsaaPDAe70PrIGk9arRYSiYQoT3q9Hi4uLoLzxBU06aIwcOVyOZRKpVnBURIMw0AmkyEgIECwVZWjQWIfUU9lIwRPXOP0STc3NxepqalgWZaTIDHAMAwYhkFQUBBRKyTxAY
l9RD2ZDmnHktMmXZZlkZGRgaysLE4D5E0YhoGvry/8/f2JvRtrK0jsI+rJckg5lpwy6bIsi7S0NOTk5Ng0SAwwDAMvLy8EBgY6dLDYExL7iHqyHhKOJaecp5uRkWG3IAEKAzMnJwcZGRl22R8JkNhH1JP1kHAsOV3Szc3NtfllUHGwLIusrCzk5ubadb+OCIl9RD1xh6MfS06VdLVarXGgnw9YlkVqaiq0Wi0v+3cESOwj6ol7HPlYcqqkK5fLeQsSAyzLQi6X86pByJDYR9STbXDUY8lpkq5KpeJs7qA1sCwLpVIJlUrFqw4hQmIfUU+2w1GPJadJugqFgvcgMcCyLBQKBd8yBAeJfUQ92RZHPJacIulqtVrj44hCQalUOuR4lK0gsY+oJ/vgaMeSUyTd7Oxsi7d9/Pgx+vXrh+bNmyMuLo5DVdbpIg1zfgtb9smbWNNHpm5rTz8A9cQ3TpF0rRl/2rBhA8LDw3HhwgUMHToUaWlpWLlypdWaDONRlELM6aPX+6Ru3boYMWIEWrZsia5duxb7/U2bNiEqKgoRERHo2bMnnjx5YtJ+rO0jUz297ken0yEqKgotWrRAx44dsWDBAmMV9/z5c8TGxqJjx45o2bIlhg0bhhs3bpiliQ9PQ4cOxZ07d/DRRx8hIiIC7dq1w6ZNm97a5tKlSwgNDcWyZcvM0uRox5JTJF21Wm3xtnK5HHXq1MH169exZs0a6HQ6AIXvpVqzZg1vukjDnN/C0CcA4O7ujt69e2Py5MnFfnfnzp3YtWsXVqxYgQsXLmDFihVmLTRvTR+Zuu3rfjp06IBt27bh/Pnz2L17N+7fv2+sFFUqFRo0aICtW7fi7Nmz6NmzJ8aPH2/2jSR7e8rOzsa4cePQv39/nD17FgcPHkSrVq2KfF+j0WDBggVo1KiRTXUJAeKTrkajsbjK/eSTT3Dp0iV8//33GDVqFDw8PDBnzhwcOnQIZ8+exQcffGCVNpZlodForGqDBMzpo9f7JCIiAjKZDD169EBQUNBb39Xr9Vi1ahViY2NRu3ZtMAyDqlWrwsvLy2RtlvaRqZ7e9KPT6YyLuhjWmE1JSQEAVK1aFR999BEqVqwIFxcX9O/fHxqNBo8fPzZLm709zZ8/H61atUJMTAxcXV1Rrlw51KpVq8g2v/32G1q1aoUaNWqYrQtwrGOJ+KSrVqstfkb7119/RVhYGKZNm4aLFy8WecuvSGT9T2dYtNmZaNOmDTp06GB8gy1gXh+92SelHaSZmZnIzMzEw4cP8d577yEqKgorVqyAXq83WW9ZfZSbmwsvLy+MHz++yKOppnoqzs+BAwfQokULREZG4v79+yUuUn7v3j1oNBpUq1bNZD+meNq1axcqV66MlStXIj8/32pPz58/h5eXFz744AO0a9cOEyZMQHp6uvH7crkce/bswdixY83yYY4nIUHGApWlYM4BVhrXr1/H/fv3MWPGDOzfvx/h4eGIi4uz6hU9LMsiLy+PyNdMl0R6ejqSkpLQtm1bNGvWDPPnz0fDhg1tsq/MzEwAwLlz57Br1y4olUqMGTMGlStXRr9+/Uxqo6w+ysrKQl5eHtatW4f169dj2LBhmDFjhlVLEHbv3h3du3dHcnIy9u3bhwoVKrz1nZcvX+Kbb77BuHHjIJPJzGq/LE+ZmZl48eIFYmNjMWPGDMyYMQNjxoyx+FjKzMzE3bt3sWbNGtStWxeLFy9GbGws/vjjDwDA/PnzMWHCBKsXK+fqWLc1xCddruYTNm7cGI0bN0ZaWhoAIDw8HOHh4Va1+fLlS3z77bc4cuQIFxIdAkM1kpeXhzNnzqBNmza4c+eOTfbl5uYGABg+fDg8PT3h6emJ/v3748yZMyYnXZVKhTFjxpTYR29e1q5duxbbtm0z+WZdaVSvXh116tTB3LlzsXTpUuPnarUaEyZMQOPGjTFy5Eiz2y0r7rRaLTQaDQoKCvDq1St88cUXSEhIMCZJc3Fzc0PHjh2NJ9dx48
YhMjISSqUSly9fhkqlQlRUlEVtv45Q5g6XBfFJl+vl3wIDA/Hpp59y0pZMJsPGjRvNGmN0dOrWrYukpCS4urqiZ8+emDdvHipWrGg8mXFJjRo1IJFIrIqBcuXKldpH//77L4KCgiAWi+Hi4oKpU6di0qRJnCUArVZrHNMFgIKCAkycOBGVK1fGzJkzLWqzrLhbv349xo0bB4lEgooVK+LHH39E3759LZ4hEBwcXKQPXv//Cxcu4Pbt22jfvj2AwhOCSCTCgwcPsHz5crP24yhLPRKfdLkYe7UlQtfHNSEhIQgLC8O8efOMd7etme6j1+uh0Wig1WrBsizy8/MhEokgkUjg7u6OqKgobNiwAfXq1YNSqcSOHTswfPhws/ZRWh95eHggICAAI0aMwKRJk4zDCpZ62rlzJ9q3b48KFSrg0aNH+PXXX413+jUaDSZPngw3NzfMmzfPqtgpbdvKlSujZs2amDNnDvr27Wv8rqX769WrF7744gsMHToUtWvXxurVqxEWFgaZTIYJEybgk08+MX73hx9+QMWKFS0a33WUY4n4pCuVSgV72cGyLKRSKd8y7Mr+/fvf+syaPrp8+TJGjBhh/HezZs3QrFkzbNiwAQAwbdo0zJo1Cx07doRMJkPfvn3Ru3dvk9svq4/Kly9f7FCCpZ6uXr2KZcuWIS8vDz4+PujSpQsmTJgAALh27RpOnToFqVRaZMrVqlWr0LRpU5P3UZYnw5jym1jqqXnz5pg4cSLGjx+PvLw8hIWFYcGCBQAKryTKlStn/K6bmxvc3d3NvvpzpGPJKd4ccefOHUEOsotEItSvX59vGYKAxD6inuyHIx1LjlGPW4lQz4BC1cUHQv0trNFFPdkPoeoqDqdIujKZTHCD7IZXS1MKIbGPqCf74GjHklMkXXMe+7QnQtXFB0L9LazRRT3ZD6HqKg6nSLpisVhwZ0KZTAaxmPj7mCZDYh9RT/bB0Y4lp0i6AODn5yeYyyKGYeDn58e3DMFBYh9RT7bFEY8lp0m6Hh4eghiPMow/WfvII4mQ2EfUk+1w1GPJaZIuAAQEBAgiUF5fOIdSFBL7iHqyDY56LDlV0hWLxQgKCuItWBiGQVBQkFMtcGMuJPYR9cQ9jnwsOVXSBQBPT0/4+vraPVgYhoGvr69Vq085CyT2EfXEHY5+LDld0gUAf39/eHl52S1YGIaBt7c3/P397bI/EiCxj6gn6yHhWHKKx4CLg2VZZGRkICsry6ZrMxjOyv7+/ryPgTkaJPYR9WQ5pBxLTpt0DeTm5iI1NRUsy3IaMAzDGMedHPUySCiQ2EfUk+mQdiw5fdIFCtcslcvlVr01+HUMU1kCAgIcatK2kCGxj6inshGCJ66hSfc1VCoVFAqFcS1Uc34aw+WOTCaDn5+fw80ddBRI7CPqqShC9cQVNOm+wZgxY9CgQQMMHjwYSqUSarXa+FbWNzF8LpVKIZPJ4OPjQ8zZWMhcunQJM2bMwB9//EFMH82aNQsqlQpTpkwhwpNOp0NUVBSmT5+OBg0aEOGJK2jSfY3vvvsOs2fPRrdu3XDgwAHj5xqNBmq1Gnq93hgcIpEIUqkUEomER8XOx61bt9C8eXPk5eVBq9Ua3xbgyH20du1ajB49GvXq1SvyvjhH9aTRaNCnTx/s378fixYtwuTJk4v8zRE9cQmZpxIzYVkWsbGx+PnnnwEU3hB4HYlE4jQBIWSuXr2K9u3bQ6VSwdXVFenp6cYnkhy1j5YvX46vvvoKQOEl+es4oqeCggL07NkTJ06cAAA8e/asyN8d0RPXOOU83TeZM2cOlixZYnxT7dOnT3lWRHmTpKQktG7d2nhClEqlnLxxl09+//13TJo0CXl5eQAKX3Lp6PTv3x8JCQkoKCgAAJu96dmRoUkXQMeOHY1vI5VKpcjMzBTse9WclQoVKmD06NHw8PCARCJBXl6ewyfdd999Fz179jReXufl5eHFixd8y7KKoUOHokGDBgAKq9pHjx7xrE
h40OEFAG3atMH333+PwYMH4+uvv8apU6dKHPCn8IOXlxeWLl2K06dP48MPP8TNmzdRvXp1vmVZRaNGjbB8+XKcOnUK//nPf3DgwAGHXEvgdQYMGICcnBzs3bsX7dq1g0Kh4FuS4KA30v6PSZMmwcfHB9999x3fUiglcO/ePXTq1AlPnz51+ORkYNGiRbh79y7WrVvHtxTO6NChAyZOnIhevXrxLUWQ0EoXhRO6t2zZgtOnT/MthVIK8fHxGDRoEDEJFwDi4uKwaNEivmVwRmpqKm7cuIHo6Gi+pQgWmnQBnDhxAlWrVkVwcDDfUiglwLIs4uPjsXXrVr6lcMbdu3eRmZmJtm3b8i2FM7Zs2YI+ffrAzc2NbymChd5IQ2EFNWTIEL5lUErh0qVLcHFxQVhYGN9SOGPz5s3EVe70WCobp0+6eXl52LNnDwYOHMi3FEopGA5mUm5uGip3khIUiZW7LXD64YWDBw8iLCwMAQEBfEuhlIBOp8PWrVtx6tQpvqVwBq3cnRenT7qkVRskcuLECQQFBRE15k5q5U7SmLutcOqk++LFCxw/fpyo6TokQtqJkVbuzo1Tj+nu3r0bHTt2hI+PD99SKCWgVquJG3Onlbtz49SVbnx8PMaMGcO3DEopHDx4EE2aNCFqzJ1W7s6N01a66enp+Pvvv9G9e3e+pVBKgbQERSt3itMm3W3btuH999+Hu7s731IoJZCTk4Njx46hb9++fEvhDFq5U5w26dJAET67d+9Gp06d4O3tzbcUzoiPj8fQoUP5lsEZhsp90KBBfEtxGJwy6T58+BDJycno2LEj31IopRAXF4fBgwfzLYMzDJV7nz59+JbCGQcOHECTJk1QpUoVvqU4DE6ZdOPj4zFgwABi38FEAoYx95iYGL6lcMauXbto5U5xvqRL4uOXJELimDtpcWeY505S5W4PnC7pXr16FQUFBWjevDnfUiilQFqCInG2DIlj7vbA6ZIuncQtfEgcc6eVO8WAUyVdnU6HzZs300AROJs3byZuzJ20BEVi5W4vnCrpnjlzBpUqVUL9+vX5lkIpAZZlERcXR1SCopU75XWcKumSVm2QCIlj7iTOliHtxGhPyImCMsjPz8fOnTtx7do1vqVQSp7/d3QAACAASURBVIG0MXfDbJmNGzfyLYUzHjx4QFzlbk+cJukeOXIEDRs2RNWqVfmWQikBw5j7sWPH+JbCGSRW7ps3b8bAgQOJqtztidMML9ChBeFz5swZVKxYkagxd1Ird3osWY5TJF2lUolDhw6hX79+fEuhlAJpBzOJs2WuXr0KjUZDVOVub5zi+mDv3r1o27YtKlSowLcUSgmQOOZO4mwZ0ip3PnCKpBsfH49hw4bxLYNSCiSOuZNauR8/fpxvKQ4N8Un333//xblz57B9+3a+pVBKgbQERXLlXq9ePb6lODTEj+lu374dMTExKFeuHN9SKCVA4pj74cOHiavc6dxcbiC+0o2Pj8f06dP5lkEphT179hA35k5i5b5r1y6iKne+IDrpPn78GPfv30fnzp35lmISGo0GarUaer0eLMuCYRiIRCJIpVJIJBK+5VmEKZ7i4+Px4Ycf8qzUdMrypFQqcfjwYaxcuZJvqSZhSh+RWLnzBdFJd8uWLejfv79gE5ZWq0V2djaUSiXUarUx4N/E8LlUKoVMJoOPj49gJ6ab68nFxQXVq1cX9MIp5nrKzc3FtGnT4OXlxYPasrEk7pKSkhzqxChkGJZlWb5F2IrQ0FCsXLkSkZGRfEspgkqlgkKhgFKpBFAY3KZiODhkMhn8/Pzg4eFhE43mYo0nvV4PFxcXojyxLAuRSCQoT9b40el0cHFxgaenp2D8OCrEJt2bN28iJiYGjx8/hkgkjPuFWq0WcrkcSqXSrIAvCYZhIJPJEBAQwFvlSz2VDd+eSPPj6BCbdL/55huwLIsffviBbykAgNzcXKSmpoJlWU4C3wDDMGAYBkFBQfD09OSsXVOgnkyHL0+k+S
EBIpOuXq9HrVq1sG/fPjRq1IhXLSzLIiMjA1lZWZwG/ZswDANfX1/4+/vb/Gkh6sly7OWJND8kIYzrbo5JTExE+fLlERoayqsOlmWRlpZm88A37CsrKwtpaWk23Rf1ZP2+bO2JND+kQWTSFcrz4RkZGcjJybFbMLIsi5ycHGRkZNhsH9ST9djaE2l+SIO4pKvRaLB9+3YMHjyYVx25ubl2qTTexFB55Obmct429cQdtvJEmh8SIS7pHjt2DHXr1kXNmjV506DVao03L/iAZVmkpqZCq9Vy1ib1xD1ceyLND6kQl3SF8PilXC7nfXyLZVnI5XLO2qOebAOXnkjzQypEJd1Xr15h//796N+/P28aVCoVZ/MhrYFlWSiVSqhUKqvbop5sB1eeSPNDMkQl3T///BMtW7ZEpUqVeNOgUCh4D3wDLMtCoVBY3Q71ZFu48ESaH5IhKunyPbSg1WqNj1gKBaVSadUYG/VkH6zxRJof0iEm6T5//hynTp1Cr169eNOQnZ1t8ba9evXCpUuXOFTzP6zRZc62tvTwJvbwZE8/gOWeSOwjkiEm6e7cuRNRUVGQyWS8abBmTG3Pnj0IDw8Hy7JYtmwZOnXqhJYtW2L48OF4+PChxZoMY2yWYo4ng4cHDx5gzJgxiIyMLPEBlUOHDqFnz56IiIhAdHQ0Ll++bLIme3ky+Dl06BB69OiBli1bol27dpg+fTpevnwJACgoKMDMmTPRpUsXNG/eHP369cOZM2fM1mSNJ0v6CABSUlIwfvx4NG/eHJGRkVi8ePFb309OTkbTpk3x9ddfm6XJ2j4iGWKSLt9DCwCgVqutbuPIkSPYs2cPNm7ciLNnz6Jx48aYNm0ab7os2VYsFqNr166YNWtWsX8/d+4clixZgjlz5uD8+fPYuHEjgoKCbK7L0m2bNGmC33//HYmJiTh06BC0Wi2WL18OoPDS3t/fHxs2bEBiYiI+++wzTJkyBWlpaTbXZc12Go0Go0ePRkREBE6cOIHjx48Xu7zmvHnz0LBhQ7vpcgaISLopKSm4efMmoqKieNOg0WisupHRtWtXJCYmIi0tDU2aNEHVqlXh4uKCmJgYPHr0yCptLMtCo9GU+p0XL1689Zm5ngweatasiT59+qBOnTrFfm/lypUYO3YsGjduDJFIhMqVK6Ny5com7wco21NBQQFevXr11ufmeDL48ff3h4+Pj/FzFxcXPH36FADg4eGBTz/9FIGBgRCJRGjXrh0CAwNx584ds/wAZXvKzc2FTqez2A/wP0979uxBpUqV8NFHH8HDwwNubm4ICQkp8t1Dhw5BJpNZ/Lp1U+LOGSEi6W7ZsgV9+/aFm5sbbxrUajUnjx1HR0cjJSUFT548gUajwb59+9C6dWur2mQYptSq4+nTp/D19cX7779fJFlw5el1dDodbt++jaysLHTr1g2dOnXCvHnzzK6KyvL066+/ws/PD9OmTSsytmippytXrqBly5Zo3rw5jh8/XuLbpRUKBZKTk1G7dm2z91GWp549e6JGjRrYtGmTMfla6ufGjRsICAjA2LFjERkZieHDh+P+/fvGv798+RIrVqzA1KlTzW7bQFl+nBUiFsOMj4/HkiVLeNWg1+s5aadixYoICwtDjx494OLiAn9/f6xbt85qbSkpKSUenE+ePIGbmxv279+PI0eOoFWrVpg7dy4aNGhg1X6L4/nz59BqtTh27Bh+++03iMVifP7551izZg0+//xzk9spy1NycjI0Gg0WL16MpUuX4oMPPsDs2bPh7u5uke6wsDAkJiYiMzMTO3fuREBAwFvf0Wg0+Prrr9GzZ0/UqlXL7H2U5UmhUCA1NRVjxozBlClTMGXKFIwcOdLs/QBAZmYmLl26hGXLlqFFixbYtGkTPv/8c/z555+QSCT4+eef0bt3b/j7+1vUvgGujguScPike+fOHTx79oz3t0NwNUdy1apVuHXrFo4dOwY/Pz/s378fI0eOxO7duy1OGGq1Gr/88g
vOnj1b7N8LCgpQUFAAvV6P/Px8nDhxAkOHDsXVq1etsVIshquRIUOGoGLFigCADz/80Oykq1KpsHbt2hI9ZWVlQafTGSvCtWvXQiQSWb2+cuXKldG6dWvExsZi27Ztxs/1ej2mTZsGiURi8Rh8WZ4eP35s/J5KpUJsbCzatWtn0Zuu3dzc0KRJE+Nx8/HHH2PNmjVISkoCy7I4f/48tm/fbpGP1xHK3GEh4fBJd/PmzRg0aBBcXFx41cHVZfg///yDqKgoY4XRq1cv/Pjjj0hKSrK48vTw8MDChQtLfGfXgwcP0LBhQ0gkElSqVAk//vgj+vTpY5O7z15eXqhcuXKR38uS3658+fKlelqyZAmmTp0KV1dXtG3bFgsWLEDjxo2Rk5NjsXYDOp0OKSkpxn+zLIuZM2fi+fPnWLlypcXv5CvLU7NmzXD9+nVIJBKMGzcO33zzDSQSiUU37YKDg0t8s++lS5cgl8uNL3RVqVTQ6/UYMGBAkRONKfC90p8Qceiky7Is4uPjOTkjWwtXrwRq2LAhjh49iqioKPj6+uLAgQPQarVWv4W1NH0VKlRA8+bN8fnnn6NPnz7G71rqiWVZFBQUGG+i5Ofng2EYuLq6Aig8kcTHx6N169YQi8X4448/0LZtW7P3U5q+kJAQdOvWDXPmzEHjxo1N2qYk9u/fj6ZNm6JKlSqQy+VYtmxZkZtLc+bMwePHj7F27VpIpVKz23+d0vS1a9cO7dq1wzfffAM/Pz8AsPjEGBMTY5yRERERgbi4OHh7e6NWrVqoVq0aoqOjjd/duHEj5HI5vv32W7P3I5RXZQkJh066Fy9ehEQiQZMmTfiWAqlUysml1IgRI/D8+XP0798feXl5qFatGhYvXmzVK1FYli01Gfj6+uL06dNvfW6pJ7lcXmQmSbNmzRAQEIAjR44AAMaMGYMXL16gR48ecHV1RdeuXTF69Giz9lGWp27duqFbt25vfW6Jp6SkJCxZsgRKpRIymQyRkZGYNGkSgEKv27dvh6urK9q3b2/cZubMmYiJiTFrP2V5WrRo0VufWdpHNWvWxPz58zFnzhxkZWWhXr16WL58OSQSCSQSSZGhLA8PD7i6usLX19esfZTlx1lx6Nf1TJw4ERUqVMDMmTP5lgKgcHxZiDcORCIR6tevb9G21JP9sNQTaX5Ix2Frf61Wi61bt/K+WPnrCPWsbo0u6sl+WKqLND+k47BJ98SJE6hWrRrq1q3LtxQjMplMcDcODK/LthTqyT5Y44k0P6TjsElXCI/9vsnrTy0JCWt0UU/2w1JdpPkhHYdMunl5edi7dy8GDhzIt5QiiMViwZ3dZTIZxGLL75dST/bBGk+k+SEdh0y6Bw4cME7hERJ6vR6+vr6CudRjGMY4tcga/Pz8qCcbwoUn0vyQjEMmXaEMLSxfvhyBgYHw8fGBm5sbXFxcEBwcLIgxNsOYmoeHh9VteXh4UE82gitPpPkhGYdLui9evEBCQgL69OnDtxS88847UCgUePHiBQoKCiAWi7Fx40YEBAQIIvgDAwM5a496sg1ceiLND6k4XNLdtWsX3nvvvRIflbQXer0eqampAAoDTSqVYtSoUejSpQvEYjGCgoJ4OwAYhkFQUBCnj0ZTT9zDtSfS/JCKwyVdIQwtnDp1Cs2aNcPatWsRFxcHsVgMX19fLFy40PgdT09PXsZ3GYaBr6+vVU+wlQT1xB228kSaHxJxqNuL6enpuHz5crGPd9qDR48eITY2FpcvX8aCBQswYMAAMAwDhUKB8PDwt8ax/P39odPpkJOTY5fVlhiGgbe3t9XL8ZUG9WQ9tvZEmh/ScKhKd+vWrejVq5fFSxxaSk5ODmJjY9G8eXM0bdoUd+/excCBA43VxNixY9G0adO3tjOMb9mj8jBUGrYe16OerN+XrT2R5oc0HCrp2ntoQafT4ZdffsE777yD58+f4+bNm5g2bZpZSZ9hGFSpUgVVq1aFSCTiPDAZhoFIJELVqlVRpU
oVuwQ+9WRZ+/b0RJofknCYBW8ePHiAtm3bIiUlxS6TrhMSEvDFF1/A19cXS5Ys4WQlM61WC7lcbtVbg1/HMD0nICCAt4no1FPZ8O2JND+OjiCSrkajgVqthl6vB8uyxrOoVCo1Lgg9a9YsZGdnY+nSpTbVcv/+fUyZMgW3b9/GTz/9hN69e3N+FlepVFAoFMa1UM3pAoMWmUwGPz8/wcyHpJ6KIkRPpPlxVHg5TWm1WmRnZ0OpVEKtVhsT7ZsYPpdKpShfvjx69OhhM03Z2dmYPXs2Nm3ahNjYWGzfvt1mL7r08PBAtWrVLPodZDIZfHx8BFdhuLu7Y+LEiahWrRq+/fZbIjxptVq0b98ev/32G9555x2H9+Th4YHHjx/j008/xcmTJx3ej6Ni11+wtDNtSWddlmWhUqmMi2I/ffqU0zOtRqPBL7/8gtmzZ6Nv3764ffs2KlWqxEnbZSEWi1GxYkXju8JMqfiFCMuymDBhAvbs2YMWLVoQ4Sk7OxstWrTA48ePkZSUhMjISIf3dOHCBXTt2hX5+fmQSqUO78dRsUvS5WJMybBdbm6ucQV/a8eUDh06hMmTJyMoKAgJCQkIDQ21uC0uMKza70jo9XqMHDkSW7ZsAVDYP6/jiJ6ePXuG1q1bIykpCQzDvPUOMkf0dObMGURHRyM/Px/u7u548uSJMd4d0Y8jY/Okm5ubi9TUVLAsy9mcQZZloVQqcf/+fQQFBZk9Ifv27dv48ssv8fjxYyxatAjdu3end18tZM6cOfjtt9+Mby6Qy+U8K7Kezp0749GjR0ZPd+/e5VmRdSQnJ+O9995DQUEBgMIk+3rSpdgXm00ZY1kW6enpSElJMV62cN2+Xq9HSkoK0tPTTWpfoVBg/Pjx6NChA6Kjo3Hr1i3ExMTQhGsFw4YNw4QJEyASieDu7o4XL14gPz+fb1lWsWTJEnTs2BEA4OrqigcPHvCsyDoCAwOxfPlyBAUFQSQSQaVS4cmTJ3zLclpsUumyLIu0tDS7PBHDsiyysrKg0+kQGBhYbAItKCjAzz//jPnz52Pw4MG4e/cuKlSoYFNdzkKtWrUwbNgw7N+/H+vXr8fx48cd/tn7Dh064PTp0wgODkabNm2gUqn4lmQVYrEYo0ePxtq1a/Hjjz8iNTW1yNuMKfbFJlPG0tPTkZWVZZdHEA0Ynox5fY1dlmWxb98+TJkyBcHBwVi4cCHq1atnN03OwuTJk1G+fHnMnj2bbymcwLIs3nnnHfzxxx+IiIjgWw4n3L9/H+3atUNqaqrDnxQdHc4r3dzcXLsnXOB/FW+5cuXg6emJ69ev44svvkBmZiZ+/vlndO3a1a56nAWdToctW7YgISGBbymcceXKFeh0OoSHh/MthTM2b96MgQMH0oQrADgd09VqtcabZnzAsixSUlIwYcIEdOnSBf3798f169dpwrUhp06dgr+/P1FXEIbHzUkZ62dZVhCr81EK4bTSlcvlvCVcA3q9Ht26dcPcuXPh7e3NqxZngLSD2VC5Hz9+nG8pnEFi5e7IcJZ0VSoVZ892WwPDMKhevTpcXV151eEM5OfnY9euXbhx4wbfUjjj9OnTqFy5Mq3cKTaDs+EFhULBe8I1wLIsFAoF3zKI59ChQ2jUqBGCgoL4lsIZpFbugwcP5lsK5f/gJOlqtVrjo72W0KtXL1y6dIkLKUaUSiW0Wi2nbVKKQlqCMlTugwYN4lsKZ5BYuTs6nCTd7Oxsq7bfs2dPkfGm6dOnWysJgPW6KCWTm5uLI0eOoG/fvnxL4YzDhw8jNDSUVu4Um8JJ0uViLDc9PR0LFixAXl4egMJ5ha+/c8xcDI8KU2zDnj170K5dO6IeMiEtQZFYuZMAJ0lXrVZbtX3Xrl3x5MkTdO7cGdOmTcOlS5ewbds2jBgxglddlJKJi4vD0KFD+ZbBGYbKvV+/fnxL4QwSx9xJwOqkq9FobHIDjYtXjLAsC41Gw5EiioHMzExcuHDBpusb2xtD5e7r68
u3FM4grXInBauTrlqt5mQqyrNnz3Ds2DF8//33CA8PR9++fbF+/Xqr2mQYhla7NmDbtm3o0aMHUW8PIC1BkTjmTgpWz9M1LH9nLRUrVkTPnj2N/w4JCUFISIjV7XKlj/I/4uPjMXPmTL5lcMa///6L8+fPY9euXXxL4Yw9e/agffv2RFXupGB1pWuLoYV58+Zx1pZQ5g6TQlJSEh4+fIj33nuPbymcsW3bNsTExBBXudO5ucLE6qQr9KdchK7P0di8eTMGDBhA1JsG4uPjibopmJmZifPnzxe5cqQIB6uTrkhks3XQOUHo+hwJlmURFxdH1NgniZX79u3biRtzJwmrx3SlUqnVl/BHjhyxVkaxsCwLqVRqk7adkRs3bkClUqFly5Z8S+GMLVu2oH///sRV7jNmzOBbBqUErC4DJRKJYC/hGYYh6mDiG8M4ISlXD7Ryp/ABJ6uMSaVSQb7ShFa53KHX67F582YcOHCAbymccePGDbx69Yqoyn3z5s3EVe6kwUnJIpPJBFft6vV6uLu78y2DGP766y94e3sT9QZZWrlT+ICTaPPx8eGiGU7R6XSIjIzE1KlTHf5trkKAtIcHDJU7SZ5IHHMnEU6Srlgshkwm46IpzqhQoQKOHDkCkUiE1q1b47333sOOHTvoY8EWUFBQgB07dhC1cMpff/0FLy8vWrlT7A5nvePn5yeYIQaGYeDn54fatWtjwYIFSElJwSeffILly5ejWrVq+Pbbb5GcnMy3TIfh6NGjCAkJQY0aNfiWwhmkzc01VO4keSIVzpKuh4eHIMZ2GYaBTCYrMkfRzc0NgwcPxqlTp5CQkICXL18iLCwMMTEx2L9/P3Q6HY+KhQ9pQwukVu7e3t5o2LAh31IoZcDpdUhAQIAgkm5gYGCJf69fvz6WLl2KlJQU9OvXD3PnzkXNmjUxd+5cpKen21GpY/Dy5UscOHAA/fv351sKZxw7dozIyp2kEyPJcJp0xWIxgoKCeEu8DMMgKCgILi4uZX7Xw8MDH3/8Mc6fP4+9e/ciNTUV9evXR9++fXHs2DG6UM7/sW/fPrRu3RoVK1bkWwpnkJagSKzcSYbzEXdPT0/4+vraPfEyDANfX194enqavW2TJk2wevVqJCcno3Pnzpg6dSqCg4Px008/4dmzZzZQ6ziQNvZJYuVO4pg7ydjkNqe/vz+8vLzslngZhoG3tzf8/f2tasfT0xNjx47F1atXsWnTJty+fRt169bF0KFDcebMGadbsUyhUODs2bN4//33+ZbCGfv27UOrVq1o5U7hDZskXcO4qj0qXkOFy+V4MsMwaNGiBTZu3IikpCSEh4dj9OjRaNCgAZYtW4YXL15wsh+hs2PHDkRHR6N8+fJ8S+EM0hIUiZU76dhsQh/DMKhSpQqqVq3Kyat3imtfJBKhatWqqFKlis2Su6+vLyZNmoQ7d+5g1apVSExMRI0aNTBixAhcvHiR6OqXtASlUChw5swZ4ip30sbcSYdh7ZA1tFot5HI5J28NBv43LSwgIABiMSfLR5jFv//+iw0bNuCXX36Bt7c3xo4diyFDhjhMRajRaKBWq6HX68GyrPEEJpVKjc/sP336FGFhYZDL5XB1deVZcdmY4mn16tU4efIktmzZwrNa0zDFU0xMDAYNGoQPPviAZ7WmYYon0rFL0jWgUqmgUCiMr0Y3Z9eGSlYmk8HPz08Qa4Xq9XocO3YMq1atwunTpzFo0CCMGTMGjRs35ltaEbRaLbKzs6FUKqFWq43B/iaGz6VSKa5cuYIzZ87gP//5Dw+Ky8YST3FxcWjSpAliYmJ4UFw25noSi8VYvHgxfvjhB3h7e/OguGws6SeZTAYfHx9eCip7YNeka4DEjkhNTcWvv/6KdevWISgoCGPHjsWAAQN4XXTHmpOcAU9PT8Gc5ADrPOn1eri4uAjqxA1QT28ixAKLS3hJum9C0iWHVqvFwYMHsXr1aly8eBHDhg3DmDFj8M4779hVA0nDOQD1ZA
rUk2MgiKRLKo8fP8batWuxfv161KtXD2PHjkXv3r1tOkaam5uL1NRUsCzL6U0+hmGMD59YMhfaGqgn06GehA9NunagoKAAe/bswerVq3H79m2MGDECo0aNQq1atTjbB8uyyMjIQFZWlk1nVBim6Pn7+9t8OiD1ZDnUk3Cha8DZAVdXVwwYMAD//e9/cerUKeTn5yMiIgLR0dHYu3cvtFqtVe2zLIu0tDSbB71hX1lZWUhLS7Ppvqgn6/dFPQkTWunyRF5eHnbs2GF8/HjkyJEYOXIkgoKCzG4rPT3dLkH/Ooaqo0qVKjZpn3riBupJeNBKlyfc3d0xbNgw/PXXXzh48CCePXuGRo0aoVevXjh8+PBbC+78/fffxS7Ck5uba/egB/5XdeTm5nLeNvXEHdST8KBJVwA0atQIK1aswNOnT9G9e3dMnz4dderUwQ8//IDMzEw8evQI4eHhiI2NLbKdVqs13rjgA5ZlkZqaavXwyOtQT9xDPQkLmnQFRPny5TFq1ChcvnwZ27Ztw8OHD/HOO++gW7ducHFxwapVq7BixQrj9+VyOe9jWyzLQi6Xc9Ye9WQbqCfhQMd0BU5mZiaqV6+O/Px8AIBEIsGGDRvQu3dvPH78mPfABwrH2GrWrGn1JHaVSkU92RDqSRjQSlfg/PXXX9BqtfDy8oKHhwe0Wi2GDx8OhUIhiKAHCisOhUJhdTvUk22hnoQBGY94EEybNm2wdetW+Pn5Gf/z8vLCo0eP+JZWBKVSCa1Wa/FTQ1qt1vjIqFCgnt6GRE/2xjFUOjGVKlVC3759i3wm1LdZZGdnW7zEYHZ2NsdquIF6entbIWKNJ3tDhxccEK6ea3+dw4cPo2fPnmjevDnef/99JCQkmLU9y7JWVUDWetJoNJg8eTK6du2K0NBQXLp06a3v3LlzBx999BEiIiLQrl07bNq0qdQ2+fT06NEjDBw4EK1atUKrVq0wcuTIIlc3hnH95s2bIyoqChs2bDCpXb77KS8vD3PnzkVkZCRatmyJjz766K3vaDQa9OzZE506dTKpTWs92Rta6TogarWa0/YyMzPxzTffYNmyZWjTpg3OnDmDL7/8EocPH0aFChXsoosLT02aNMEHH3yAL7/88q2/ZWdnY9y4cZg6dSq6dOkCjUaDzMxMm+qyZtuKFSti8eLFCAgIgF6vx5YtWzB16lTs2rULQGGimTdvHoKDg5GSkoIxY8bA398f0dHRNtVlbT/NmjULOp0Oe/fuhZeXF+7du/fWdzZs2AAfHx+8evXKbrrsCa10HQyNRmNRpREaGoqnT58a/z19+nQsW7YMQGHS9fT0RGRkJBiGQdu2beHu7o6UlBSz9sGyLDQaTanfefTo0VsHk6meSvMgkUgwbNgwhIWFQSR6O6x///13tGrVCjExMXB1dUW5cuVMWvvCFE83b958S78pnkrz4+npicDAQDAMA5ZlIRKJivTHiBEjUL9+fYjFYtSsWRMdOnTA1atXy/RjiqeXL18iKSnprc+t9ZSUlISTJ0/iu+++g6+vL1xcXNCgQYMi26empmL//v0YOXKkSV4MmNJPQoEmXQdDrVZzvthHgwYNULNmTZw4cQI6nQ4JCQmQSCQIDg42qx2GYcqsON5//31UqVIFCxYsMCZfW3h6kxs3bsDLywsffPAB2rVrhwkTJiA9Pb3M7crypFAo0KhRI4SEhGDfvn3GpMSVp1atWqFZs2aYP39+iYmIZVlcuXIFderUManNsjzt3LkTtWvXRo8ePXD79m3j59Z6unXrFqpUqYIVK1YgMjISvXv3xrFjx4p8Z/78+Zg4cSKkUqlZbZsSe0KBDi84GMU9CmwtLi4u6NmzJ7766isUFBRAIpFg0aJFZs991Ov1uHz5MvLy8kr8jmHx+u+++w6zZ8/GkCFDsGDBAmstlElmZibu3r2LNWvWoG7duli8eDFiY2Pxxx9/lLqdTqcr1dOLFy/g6uqKBw8eYODAgfD19cWMGTMwePBgTn
SfO3cOKpUK+/btQ0BAQLHfWblyJfR6PXr16mVSm1qttlRP165dg5ubGw4ePIijR4/i3XffPlS2+gAABfFJREFUxdKlS1G/fn2LfQCFffDw4UN07twZ//3vf3Ht2jWMHz8etWvXRq1atZCQkACdTodOnToVOyZfFrY4NmwBTboOhi3mRyYmJmLx4sXYsGED6tWrhzt37uCzzz7DqlWrzFp8XafT4ejRo6Ve5hrepJyfnw+RSIRNmzZhypQpVnsoCzc3N3Ts2BENGzYEAIwbNw6RkZFQKpWQyWQlbqfRaEr1lJ+fb3wUVa1WQy6XY8mSJRg0aBBn2j08PDBgwAC0bdsWe/fuLTLOHh8fjz///BMbN240eZ3mgoKCUj0ZHq/V6/XQaDS4ePEi/vjjD3z//fdW+XBzc4NYLMbo0aMhFosRHh6OiIgInDt3Dv7+/li8eDFWrlxpcftCmTtcFjTpOhiWXt65u7sXqWyeP3+OypUrAwD++ecfNG3a1Di+1rBhQ4SGhuL8+fNmJV2JRIKvvvoKXl5eJX4nODgYT58+hbe3N77//nsMGzYMKpUKaWlpVnkoi+Dg4CK/nam/o1QqLdXTv//+i4CAAHh4eKBRo0ZYuHAhWrdujZycnDLbNsePXq+HWq3Gv//+a0y6u3fvxq+//oqNGzfC39/fJD9AYRIvzdP69esxZswYSKVSDBkyBLNmzUJQUJDVnoobrjL0w9OnTyGXy42zGTQaDV6+fIn27dsjLi4OgYGBZe7bUdbYpWO6DkZxN4lMISQkBAcPHoROp8PZs2fx999/G//WoEEDXLlyxXgn+e7du7hy5YrZY7qm6Bs2bBhWrlyJlJQUjBgxAhKJxGRPpXkACis4w+PSGo0G+fn5xuqnV69eSEhIwL1796DRaLB69WqEhYWVWuWa4snLywvDhg3D0aNHkZiYiNatW5e5jSl+zp07h7t370Kn0+Hly5f46aef4Onpabz5t3//fvznP//B2rVrUbVq1TL3ZY6nxo0bY+zYsXjw4AF+/fVX43Kj1npq2rQpqlSpgnXr1kGr1eLq1au4ePEiWrVqhTp16uDYsWPYsWMHduzYgVmzZqFChQrYsWOHyScUS48Ne0PXXnAwNBoN7t+/b/al1O3btzF9+nSkp6ejY8eO0Ol0CAoKwueffw6g8DJ106ZNeP78OXx8fDB48OBi51CWBsMwCA4ONvu9dqZ6KstD165d31oA5fDhw8YqaevWrVizZg3y8vIQFhaGb7/9tswD2paeSvNz5MgR/Pzzz8jMzIRUKkXDhg0xceJEhISEAACioqKQmZlZRFdMTAxmzpxZpja+PAHAw4cP8d133+HBgweoUqUKPv/882Ln4166dAlff/21yfPFLfXEBzTpOiB37twR5E0DkUhk8c0W6sl+UE/84hj1OKUI5k6nsRfW6KKe7Af1xC806TogMplMcDcNDK/KthTqyT5QT/xDk64D4uPjw7eEYrFGF/VkP6gnfqFJ1wERi8WCO7PLZDKrltajnuwD9cQ/NOk6KH5+foK5zGMYBn5+fla3Qz3ZFupJGNCk66B4eHgIYnzNMJ7GxetSqCfbQT0JB5p0HZiAgABBBL4pTwuZCvVkG6gn4UCTrgMjFosRFBTEW/AzDIOgoCC4uLhw1ib1xD3Uk7CgSdfB8fT0hK+vr92Dn2EY+Pr6wtPTk/O2qSfuoJ6EB026BODv7w8vLy+7BT/DMPD29jZrkRVzoZ6sh3oSJvQxYEJgWRYZGRnIysqy6RJ3hirD39/f5gca9WQ51JNwoUmXMHJzc5GamgqWZTk9ABiGMY6j2fuyjnoyHepJ+NCkSyBarRZyuZyztwYbpuYEBATwNgmdeiob6skxoEmXYFQqFRQKhfH11OZ0teHyTSaTwc/PTzBzIamnolBPjgdNuk6AVqs1vptMrVaDZdlix8QMn0ulUshkMvj4+Ai2uqCeqCdHhSZdJ0Sj0UCtVkOv1xuDXSQSQSqVOsQi0MVBPT
kGJHoyF5p0KRQKxY7QeboUCoViR2jSpVAoFDtCky6FQqHYEZp0KRQKxY7QpEuhUCh2hCZdCoVCsSM06VIoFIodoUmXQqFQ7AhNuhQKhWJH/j9EVgWqGyVozgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
" 'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],\n",
" 'i8': ['i16'], 'i16': ['f16', 'i32'], 'i32': ['f32', 'i64'], 'i64': ['f64'],\n",
" 'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
" 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],\n",
" 'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],\n",
" 'c64': [3, 3], 'c128': [4, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 5))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BfHlmmF_GOo_"
},
"source": [
"While this does allow for precision-losing promotions between integers and floats, these promotions will not misrepresent the *magnitude* of the result: though the floating point mantissa is not wide enough to represent all values, the exponent is wide enough to approximate them.\n",
"\n",
"This approach also allows a natural promotion path from `int64` to `float64`, though `uint64` remains unpromotable in this scheme. That said, a connection from `u64` to `f64` could be justified more readily here than before.\n",
"\n",
"This promotion scheme still results in some wider-than-necessary promotion paths; for example, operations between `float32` and `uint32` result in `float64`. Additionally, this lattice makes it difficult to find a sensible place to insert `bfloat16` (see below) while maintaining the lattice property."
]
},
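{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough check of the promotion paths described above, the join (least upper bound) of two types can be computed directly from the graph. This is a minimal sketch, not how JAX implements promotion: the `join` helper and the `lattice_opt2` / `graph_opt2` names are introduced here purely for illustration, and the lattice dictionary is copied from the plotting cell above. Pairs with no common upper bound in this graph (such as `u64` with any signed integer) raise an error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx\n",
"\n",
"# Option 2 lattice, copied from the plotting cell above.\n",
"lattice_opt2 = {\n",
"  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
"  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],\n",
"  'i8': ['i16'], 'i16': ['f16', 'i32'], 'i32': ['f32', 'i64'], 'i64': ['f64'],\n",
"  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
"  'c64': ['c128']\n",
"}\n",
"graph_opt2 = nx.from_dict_of_lists(lattice_opt2, create_using=nx.DiGraph)\n",
"\n",
"def join(a, b, graph=graph_opt2):\n",
"  # Hypothetical helper (not a JAX API): least upper bound of a and b.\n",
"  def upper(n):\n",
"    # n together with every node reachable from n.\n",
"    return nx.descendants(graph, n) | {n}\n",
"  common = upper(a) & upper(b)\n",
"  # The join is the common upper bound from which all others are reachable.\n",
"  candidates = [n for n in common if common <= upper(n)]\n",
"  if len(candidates) != 1:\n",
"    raise ValueError(f'no unique join for {a} and {b}')\n",
"  return candidates[0]\n",
"\n",
"print(join('f32', 'u32'))  # f64: wider than either input\n",
"print(join('i64', 'f32'))  # f64"
]
},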
{
"cell_type": "markdown",
"metadata": {
"id": "7M6EiBDqHNm-"
},
"source": [
"### Option 3: Avoid all wider-than-necessary promotions\n",
"\n",
"We can avoid *all* non-ideal 64-bit promotions if we're willing to fundamentally change our thinking around integer and float promotions.\n",
"Just as scalars always defer to the widths of array types, we can make integers always defer to the width of float types:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"cellView": "form",
"id": "JJ__tn0VJJRD",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAEeCAYAAAApRMZ1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2deVgTV/v+7wkJBDQsEZRNKy7gAlpRcMWtVQFx3xdqRRSs1lpb7Vtt+/6sWrW12lerxl2r4K5oqTuvUq1L3a27FRcggI0gBEMgy/z+4Ju8oghZJplkOJ/r6nXVkDnz3DnP3PPMmTNnKJqmaRAIBALBKvDYDoBAIBBqEsR0CQQCwYoQ0yUQCAQrQkyXQCAQrAgxXQKBQLAixHQJBALBihDTJRAIBCtCTJdAIBCsCDFdAoFAsCLEdAkEAsGK8NkOgPA/VCoVlEoltFotaJoGRVHg8XgQCoUQCARsh0fgKFzMO1vWREyXRdRqNQoKCiCXy6FUKvXJ8Tq6z4VCIUQiETw8PMDnk64jmAYX886eNFFkwRvro1AoIJPJIJfLAZQngqHoEkkkEsHT0xMuLi4WiZHAPbiYd/aoiZiuFVGr1ZBKpZDL5UYlx9ugKAoikQi+vr42W4EQ2IeLeWfPmojpWomioiJkZWWBpmlGkkQHRVGgKAr+/v5wdXVlrF0CN+Bi3tm7JmK6FoamaeTm5iI/P5/RBHkdiqIgFovh7e1d6VgWoWbBxbzjiiZiuhaEpmlkZ2ejsLDQokmig6IouLm5wc/PjxhvDYaLecclTWSergXJzc21WpIA5YlZWFiI3Nxcq+yPYJtwMe+4pImYroUoKiqy+GVQZdA0jfz8fBQVFVl1vwTbgIt5xzVNxHQtgFqt1g/0swFN08jKyoJarWZl/wR24GLecVETMV0LIJVKWUsSHTRNQyqVshoDwbpwMe+4qImYLsMoFArG5g6aA03TkMvlUCgUrMZBsA5czDsuagKI6TKOTCZjPUl00DQNmUzGdhgEK8DFvOOiJoCYLqOo1Wr944i2glwuJ2O7HIeLecdFTTqI6TJIQUGBRdqdM2eOWdtbKi6CbWBu/+bn52PWrFno2LEjOnXqhC+++OKN7xQWFqJr16744IMPrBKXMdv+888/+Pjjj9GzZ0+EhIQgOzu7wt+XLFmCvn37on379ujXrx8OHjxY4e8XLlzA8OHD0aFDB0RGRmL37t2MxPU2iOkyCJPjTyUlJZg7dy5evHgBoPyGwty5c41uXzceReAu5ubdp59+Ck9PTxw7dgzp6en48MMP3/jOsmXLEBAQYHCb5uadMZooikLnzp2xdOnSSv/u7OyMFStW4Ny5c1iwYAEWLVqEa9euAShfAnL69OkYOnQozp07hyVLluCHH37AvXv3GNekg6ySwiBKpdLkbXNzc7Fo0SJcuXIFWq0WUVFRGD16NObPn49bt25h5cqVmDx5sklPx5gTF8H2MbR/K8uxHj16IDc3Fxs3boSDgwMAoHnz5hW2u3btGh48eIChQ4di//79jMdlzLaVaZgzZw5Gjhz51kv/KVOm6P+/VatWaNu2La5fv453330XhYWFKC4uRr9+/UBRFIKDg9GoUSM8fPgQQUFBjGrSQSpdhlCpVCZXGxqNBlOmTIGPjw+OHDmCtLQ0REVF6f/+6iLMpkDTNFQqlUnbEmyHCRMmIDY2Fo8fP9Z/ZmjevS3Hrl+/joYNG2LOnDno0qULRo4ciYsXL1bY7rvvvsPs2bONPuFXl3dqtRohISFYuHAhXr58Wa2m6o4TQ1Aqlbh58yYaN24MAPD09ERUVBRSUlKg0Whw7do15OTkIDQ01CRNhkBMlyGUSqXJz2j/9ddf+Oeff/DZZ5/BxcUFTk5OaN68OZKSkvDVV1+hbdu2+Oijj7B+/XqTjJ2iKFLtcoBr164hOTkZzZs315uvoXlXWY6FhoYiLy8PZ8+eRXh4OE6ePI
lx48bhk08+0Y9dJiUlISQkBC1btjQ63uryTq1W49atW5g/fz58fX315vs2TW/TYAzz5s1DUFAQOnfurP8sOjoaEokEbdu2xYcffoiPP/4Y3t7eJmkyBDK8wBBardbkbXNzc+Hj41NhHU9nZ2f8v//3//T/9vX1xTfffGNS+y9fvsT48ePx+++/mxwjgX3y8/Oh1WqhVCqxbds2bNu2DX/++Sdq1apV7baV5RgACIVC+Pn5YfDgwQCAqKgorF27FlevXkVwcDCSk5Oxc+dOk+JVKBRV5p1uaUbd/NfZs2djwYIFb9wIq06Dofz444948OABNm7cqDf1jIwMzJo1C8uWLUPHjh3x5MkTTJ06FXXr1kXXrl0rbcecYx0gpssY5tzI8Pb2Rm5uLtRqdaUJtWDBAnNCg4uLC5YuXQqhUGhWOwR26dWrF27cuAEXFxc0aNAA8+bNQ9OmTQ16WuptORYYGIhTp05V+K7OkHSV5YABAwAApaWlUCqV6N69O9LS0vRjwG/D2dm5yrxTKpUICAgARVEQCATo1asXvv3227ceS9UdJ1WxcuVKnDlzBps2bULt2rX1n//9999455139JVvQEAAunbtitOnT7/VdM29WU5MlyHMWf4tJCQEnp6e+Omnn/DRRx/BwcEBt2/fRps2bRiLzc3NDW5uboy0R2CHunXrIjg4GD/++CN69eoFiqJQWFho0LZvy7H33nsPS5YswYEDBxATE4O0tDTk5eWhTZs2qFWrFo4ePapv48iRIzh06BCWL19ereEC1eedSqWCk5MTYmJisGDBAjRt2hQA3qqpquOktLQUGo0GAFBWVobS0lI4OTkBANavX49Dhw5hy5YtcHd3r9Bm8+bN8eTJE1y4cAHh4eHIyspCeno6xo8fX6UucyDr6TKEXC5HZmamyZceOTk5WLhwIa5cuQKKohAdHY0vv/ySkdh4PB7q168PkUjESHsEdtBoNODxeBUOemPy7m05dvnyZcyfPx/Z2dkICAjArFmz0LZt2ze2T0lJwb59+/DLL78YFK8heafRaN4w8Ko0vU1DSEjIG9/966+/AJSbtUAgqFAdT5w4ERMnTgRQfjJZs2YNpFIpateujb59+2L69OmV3rhm4lgipssQKpUK9+/ft5nHFl+FoigEBgay/uppAvNwMe+4qOlVyOwFhhAIBDb7tgbdmBmBe3Ax77io6VWI6TKIrd6ostW4CMxgq/1rTlxc1KSDmC6DiEQimztD614tTeAuXMw7LmrSQUyXQTw8PNgOoVJsNS4CM9hq/5oTFxc16SCmyyB8Pt/mqkqRSGTyZHKCfcDFvOOiJh3EdBnG09PTZi6LKIqCp6cn22EQrAAX846LmgBiuozj4uJiE+NRuvEnFxcXVuMgWAcu5h0XNQHEdC2Cr6+vTSSKn58fqzEQrAsX846LmojpWgA+nw9/f3/WkoWiKPj7+xv0qCaBO3Ax77ioiZiuhXB1dYVYLLZ6slAUBbFYDFdXV6vul2AbcDHvuKaJmK4F8fb2hpubm9WShaIouLu7v3UtUELNgIt5xyVNZO0FC0PTNHJzc5Gfn2/RZ8l1Z2Vvb2/Wx8AI7MPFvOOKJmK6VqKoqAhZWVn6hZuZgqIo/bgTGVIgvA4X887eNRHTtSJqtRpSqZSxtwbrprL4+vqSByAIb4WLeWfPmojpsoBCoUBGRgbUajUEAoFRSaO73BGJRPD09CTzcAkGo1AokJWVhZcvX8LR0ZETeadQKHDt2jXUqlULfD7fLjSR8ogFcnNz0bFjRzg5OeHOnTuQy+VQKpX6t/6+ju5zoVAIkUgEDw8PUtkSjObly5fo3r07FAoF7t69C4VCYfd5d+nSJfTo0QNBQUFIT0+3i2PJtn7BGsC9e/fQqVMnFBcXQ6FQoFatWvDy8gJQvnizUqmEVqut8Np1oVBI1sMlmEVOTg46duyI3NxcODo6Qi6X61+PY695d/z4cQwcOBBarRYymQxeXl52cSwR07Uijx
8/Rvv27VFUVASg/DHHJ0+eoHnz5gDKF29mOyEI3KOgoABhYWHIzc0FTdNwcnLC48eP9aZrj3l36tQp9O3bFyqVCkC5Rq1Wq3/Fji1rIvN0rQhFUQgPDwePxwOfz4dKpcLjx4/ZDovAcWiaRlhYGHg8HhwcHFBSUmL3eefo6IiQkBDweDz9fZHc3Fy2wzIIUulakXfeeQfHjh1DcHAwoqOjce7cOdSqVYvtsAgcRywWY//+/YiMjESTJk1w+/Ztm12v1lA6deqE9PR0+Pr6YsaMGTh8+LBNvlOtMsjsBStz8+ZNREZG4unTp5W+bZRAsAQymQxNmjRBVlYWateuzXY4jJCcnIykpCT89ttvbIdiFOSotzLbt2/HqFGjiOESrMqePXsQFRXFGcMFyk139OjRbIdhNKTStSI0TaNRo0bYv38/3n33XbbDIdQgIiIi8MUXXyAmJobtUBjBnit3Um5ZkXPnzsHZ2RmtW7dmOxRCDeLJkye4c+cOevfuzXYojLF7925ER0fbneECxHStiu5yiCxIQ7AmO3bswNChQ+Ho6Mh2KIxhr0MLABlesBoqlQp+fn44d+4cGjduzHY4hBpE69atsWLFCnTt2pXtUBjhyZMnaNu2LaRSqV2eSEilayXS0tLQqFEjYrgEq3Lz5k3k5+ejS5cubIfCGPZeuRPTtRL2fDlEsF+4OFvG3o8lMrxgBRQKBXx9fXH37l3yVgeC1eDibJmbN28iKioKT548sdsTiX1GbWekpqYiPDycGC7Bqpw/f55zs2W4ULmTx4CtgL1fDhHsE67NlqFpGsnJydi/fz/boZgFGV6wMAUFBWjYsCGePn0KNzc3tsMh1BDUajX8/Pxw9uxZzty8PXfuHCZMmIBbt27Z9YnEfmt0O2HPnj3o3bs3MVyCVTlx4gTnZsskJSVhzJgxdm24ABlesDjJycn4+OOP2Q6DUMNITk7GqFGj2A6DMVQqFXbv3o2zZ8+yHYrZkErXgmRlZeH69euIjo5mOxRCDUKhUODXX3/FiBEj2A6FMbg0z52YrgXZuXMnBg0aBKFQyHYohBqEbrZMvXr12A6FMbh0M5qYrgXhUqIQ7Aeu5Z2uch8+fDjboTACMV0LcffuXUilUnTv3p3tUAg1iIKCApw8eRKDBg1iOxTG4FrlTkzXQmzfvh0jR46Eg4MD26EQahB79+5F79694erqynYojMG1yp2YrgXQTeLmUqIQ7AOu5R0XK3diuhbg0qVLAIB27dqxHAmhJpGdnY1r164hKiqK7VAYg4uVOzFdC5CcnMyJSdwE+2LHjh2cmy2TlJTEqcodII8BM45Go4G/vz9OnTqFoKAgtsMh1CDatm2L77//Hu+99x7boTBCVlYWWrVqBalUyqkTCal0GebUqVPw8/MjhkuwKlycLcPVee7EdBmGazcyCPYBF2fLcPVYIsMLDKJUKuHr64u//voLfn5+bIdDqCHQNI3AwEAkJycjLCyM7XAY4e7du+jRoweysrI4dSIBSKXLKIcPH8a7775LDJdgVbg4W4aLlbsOYroMwtXLIYJtw9XFyrl6LJHhBYYoLCxEgwYN8PjxY3h4eLAdDqGGwMXZMhcvXsSYMWNw7949zpxIXoVUugyxf/9+9OjRgxguwapwcbYM1yr31yGLmBuASqWCUqmEVqsFTdOgKAo8Hg9CoRACgQBAeaJMnDiR5UhrLob0kb1hiCZ7e3igOk0ajQY7duzAqVOn2A7VYhDTrQS1Wo2CggLI5XIolUp9crzOq0nTuHFjTj1+aesY20dCoRAikQgeHh7g820z7Y3V5OjoCCcnJwwbNoyFaA3DWE1KpRJxcXGcWKz8bZAx3VdQKBSQyWSQy+UAyhPBUDQaDfh8PkQiETw9PeHi4mKpMGs05vSR7mC3tT4yR5NarYZAIOCUJq1WCwcHB5vTxBTEdFGeuFKpFHK53KjkeBsURUEkEsHX19dmqyp7g4t9RDRVjy1oYpoab7pFRUXIysoCTd
OMJIkOiqJAURT8/f05tUISG3Cxj4gmw+HasVRjTZemaeTm5iI/P5/RBHkdiqIgFovh7e3N2buxloKLfUQ0mQ5XjqUaabo0TSM7OxuFhYUWTRIdFEXBzc0Nfn5+dp0s1oSLfUQ0mQ8XjqUaOU83NzfXakkClCdmYWEhcnNzrbI/LsDFPiKazIcLx1KNM92ioiKLXwZVBk3TyM/PR1FRkVX3a49wsY+IJuaw92OpRpmuWq3WD/SzAU3TyMrKglqtZmX/9gAX+4hoYh57PpZqlOlKpVLWkkQHTdOQSqWsxmDLcLGPiCbLYK/HUo0xXYVCwdjcQXOgaRpyuRwKhYLVOGwRLvYR0WQ57PVYqjGmK5PJWE8SHTRNQyaTsR2GzcHFPiKaLIs9Hks1wnTVarX+cURbQS6X2+V4lKXgYh8RTdbB3o6lGmG6BQUFJm/76NEjDB06FO3bt0dSUhKDUZkXF9cw5rewZJ+8jjl9ZOi21tQDEE1sUyNM15zxp02bNiEsLAwXLlzAmDFjkJ2djVWrVpkdk248ilCOMX30ap80bdoUcXFx6NixI/r06VPp97dt24bIyEiEh4ejf//+ePz4sUH7MbePDNX0qh6NRoPIyEh06NABPXv2xOLFi/VV3PPnzzFr1iz07NkTHTt2RGxsLG7cuGFUTGxoGjNmDG7fvo1x48YhPDwc3bp1w7Zt297Y5uLFiwgJCcHy5cuNisnejqUaYbpKpdLkbaVSKZo0aYLr169j7dq10Gg0AMrfS7V27VrW4uIaxvwWuj4BAGdnZwwaNAgzZsyo9Lt79+7Fvn37sHLlSly4cAErV640aqF5c/rI0G1f1dOjRw/s2rUL58+fx/79+3H//n19pahQKNCyZUvs3LkTZ86cQf/+/TFlyhSjbyRZW1NBQQEmT56MYcOG4cyZMzh06BA6depU4fsqlQqLFy9Gq1atLBqXLcB501WpVCZXuRMmTMDFixfx3XffYeLEiXBxccG8efNw+PBhnDlzBmPHjjUrNpqmoVKpzGqDCxjTR6/2SXh4OEQiEfr16wd/f/83vqvVarF69WrMmjULjRs3BkVRqF+/Ptzc3AyOzdQ+MlTT63o0Go1+URfdGrOZmZkAgPr162PcuHHw8vKCg4MDhg0bBpVKhUePHhkVm7U1LVy4EJ06dUJMTAwcHR1Rq1YtNGrUqMI2W7ZsQadOndCwYUOj4wLs61jivOkqlUqTn9HesGEDQkNDMXv2bPz5558V3vLL45n/0+kWba5JdOnSBT169NC/wRYwro9e75OqDtK8vDzk5eXh77//xvvvv4/IyEisXLkSWq3W4Hir66OioiK4ublhypQpFR5NNVRTZXp+++03dOjQAREREbh///5bFym/e/cuVCoVGjRoYLAeQzTt27cP9erVw6pVq1BaWmq2pufPn8PNzQ1jx45Ft27dMHXqVOTk5Oi/L5VKkZKSgsTERKN0GKPJluDGApVVYMwBVhXXr1/H/fv38fXXXyM1NRVhYWFISkoy6xU9NE2jpKSEk6+Zfhs5OTnIyMhA165d0a5dOyxcuBDBwcEW2VdeXh4A4OzZs9i3bx/kcjkSEhJQr149DB061KA2quuj/Px8lJSUYP369di4cSNiY2Px9ddfm7UEYd++fdG3b188efIEBw8eRJ06dd74TnFxMb788ktMnjwZIpHIqPar05SXl4cXL15g1qxZ+Prrr/H1118jISHB5GMpLy8Pd+7cwdq1a9G0aVMsXboUs2bNwtatWwEACxcuxNSpU81erJypY93ScN50mZpP2Lp1a7Ru3RrZ2dkAgLCwMISFhZnVZnFxMb766iscPXqUiRDtAl01UlJSgtOnT6NLly64ffu2Rfbl5OQEABg/fjxcXV3h6uqKYcOG4fTp0wabrkKhQEJCwlv76PXL2nXr1mHXrl0G36yrinfeeQdNmjTB/Pnz8dNPP+k/VyqVmDp1Klq3bo34+Hij260u79RqNVQqFcrKyvDy5Ut8+u
mnSEtL05uksTg5OaFnz576k+vkyZMREREBuVyOy5cvQ6FQIDIy0qS2X8VW5g5XB+dNl+nl3/z8/PDRRx8x0pZIJMLmzZuNGmO0d5o2bYqMjAw4Ojqif//+WLBgAby8vPQnMyZp2LAhBAKBWTlQq1atKvvo2bNn8Pf3B5/Ph4ODA2bOnInp06czZgBqtVo/pgsAZWVl+OSTT1CvXj188803JrVZXd5t3LgRkydPhkAggJeXF77//nsMGTLE5BkCgYGBFfrg1f+/cOECbt26he7duwMoPyHweDw8ePAAK1asMGo/9rLUI+dNl4mxV0ti6/ExTVBQEEJDQ7FgwQL93W1zpvtotVqoVCqo1WrQNI3S0lLweDwIBAI4OzsjMjISmzZtQvPmzSGXy7Fnzx6MHz/eqH1U1UcuLi7w9fVFXFwcpk+frh9WMFXT3r170b17d9SpUwcPHz7Ehg0b9Hf6VSoVZsyYAScnJyxYsMCs3Klq23r16iEgIADz5s3DkCFD9N81dX8DBw7Ep59+ijFjxqBx48aQSCQIDQ2FSCTC1KlTMWHCBP13Fy1aBC8vL5PGd+3lWOK86QqFQpu97KBpGkKhkO0wrEpqauobn5nTR5cvX0ZcXJz+3+3atUO7du2wadMmAMDs2bMxd+5c9OzZEyKRCEOGDMGgQYMMbr+6Pqpdu3alQwmmarp69SqWL1+OkpISeHh4oHfv3pg6dSoA4Nq1a0hPT4dQKKww5Wr16tVo27atwfuoTpNuTPl1TNXUvn17fPLJJ5gyZQpKSkoQGhqKxYsXAyi/kqhVq5b+u05OTnB2djb66s+ejqUa8eaI27dv2+QgO4/HQ4sWLdgOwybgYh8RTdbDno4l+6jHzcRWz4C2Ghcb2OpvYU5cRJP1sNW4KqNGmK5IJLK5QXbdq6UJ5XCxj4gm62BvxxLnx3QBwMPDA8+ePWM7jApoNBokJyejuLgYOTk5kEqleP/99xmbGWFv2GIfATDqkeHKtiWarIM5mqxNjTBdPp8PkUhkU+9UunbtGqZPn67/N5/Px7vvvstiROxii30kEonA55t+iBBN1sFcTdamRgwvAICnp6fNXBZRFIWBAweiY8eO+mTRaDTQaDR4/vw5y9Gxh631kaenp9ntEE2WhSlN1qTGmK6Li4tNjEfpxp9cXV1x9OhRBAQEgMfjoUOHDnj8+DEaN26M2NhY/PHHHzY71c1S2FofmftYKkA0WRImNVmTGmO6AODr62sTiaJbOEckEiE9PR0NGjTAokWLsGXLFmRkZCA0NBRxcXEICQnBzz//jMLCQlZjtia21kdMQDRZBqY1WYsaMU/3VYqKipCZmclKFalbWvD1xVB0S/i9/tmpU6cgkUhw7NgxDBkyBImJiWjXrp01Q2YFW+wjcyGamMVSmqxBjap0AcDV1RVisdjqZ2mKoiAWiytNkspioSgKPXr0wM6dO3H37l00adIEw4YNQ7t27bB+/Xq8fPnSGmGzgi32kbkQTcxhSU3WoMZVukB5FZmdnY3CwkKrnKUpioK7u7vZl2RarRbHjh2DRCLB77//jtGjRyMhIQEhISEMRmsb2GsfVQXRZD7W0GRpaqTpAuXJkpubi/z8fIsmi+6s7O3tzWiSZGVlYf369Vi3bh0CAgKQmJiIoUOH2tWTOdVh731UGUST6VhTkyWpsaaro6ioCFlZWaBpmtGEoSgKFEXB39/fopdBarUaqampkEgkuHz5Mj744AMkJCQgMDDQYvu0NvbeR5VBNBkOm5osQY03XaDcuKRSqVlvDX4V3VQWX19fq07azsjIwLp167Bx40YEBwcjMTERAwYMgKOjo9VisBRc6aNXIZqqxxY0MQ0x3VdQKBSQyWT6tVCN+Wl0lzsikQienp6szh0sLS1FSkoKJBIJ7t69i7i4OEycONHkl/7ZElzpo1chmipiq5qYgphuJajVahQUFEAul0OpVFY6pQv431QvoVAIkUgEDw8Pmzsb695NtXXrVrRv3x
6JiYmIjo62+/eycamPdBBN9qHJXIjpGoBKpYJSqYRWq9UnB4/Hg1AohEAgYDs8gygpKcGuXbsgkUiQnZ2N+Ph4TJgwwS4nl1cGF/rodYgmbkJMtwZy7do1rFmzBjt27ECPHj2QmJiI999/325ed0Ig2DPEdGswcrkcycnJkEgkKCoqQkJCAsaPHw8vLy+2QyMQOAsxXQJomsbFixchkUiwf/9+REVFITExEREREXY9H5JAsEWI6RIqUFBQgK1bt0IikQAAEhIS8MEHH9jVItEEgi1DTJdQKTRN4/Tp05BIJDh06BAGDRqExMREhIeHk+qXQDADYrqEann27Bk2b96MtWvXQiQSITExEaNHj7ar91IRCLYCMV2CwWi1WqSlpUEikeDkyZMYMWIEEhMT0bp1a7ZDIxDsBmK6BJOQSqXYsGED1q5dC39/fyQkJGD48OGce3qIQGAaYroEs1Cr1Th8+DAkEgkuXLiAsWPHIiEhAc2bN2c7NALBJiGz4Qlmwefz0a9fP/z222+4dOkSatWqhR49eqB79+7YsWMHSktL2Q6RQLApSKVLYJyysjIcOHAAEokEN2/exPjx4zFp0iQ0atSI7dAIBNYhlS6BcRwdHTFs2DCkpaXh9OnTUKvVaN++PSIjI5GSkgK1Ws12iAQCa5BKl2AVlEol9uzZA4lEgsePHyM+Ph7x8fHw9/dnOzQCwaqQSpdgFYRCIcaOHYszZ87g8OHDkMlkaNWqFQYMGIDDhw9Do9GwHSKBYBVIpUtgjeLiYuzYsQMSiQTPnz/HpEmTEBcXh3r16rEdGoFgMUilS2CN2rVrIz4+HpcuXcLu3bvx8OFDNGvWDCNGjMDJkyet8nZZAsHakEqXYFMUFhZi27ZtWL16NVQqFRITEzFu3DiIxWK2QyMQGIGYLsEmoWkaZ8+ehUQiwa+//or+/fsjMTERHTt2JAvuEOwaYroEm0cmk2HLli2QSCRwdnZGYmIixo4dy4nXcRNqHsR0CXaDVqvFyZMnIZFIcOLECQwbNgyJiYkIDQ1lOzQCwWCI6RLsktzcXGzcuBFr165F3bp1kZiYiBEjRqBWrVpsh0YgVAkxXYJdo9FocPToUUgkEvzxxx8YPXo0EhISEBwczOqTBhgAACAASURBVHZoBEKlkCljBLvGwcEB0dHROHjwIK5evQoPDw/07t0bERERSEpKglKpZDtEAqECpNIlcA6VSoVff/0VEokE165dw7hx4zBp0iQ0bdqU7dAIBFLpEriHQCDA4MGDcezYMZw9exYURaFz587o1asX9u7dC5VKxXaIhBoMqXQJNYLS0lLs3bsXa9aswYMHDzBhwgRMnDgRDRo0YDs0Qg2DVLqEGoGTkxNGjx6N9PR0nDhxAkVFRWjTpg369euH1NRUsuAOwWqQSpdQY1EoFNi5cyckEglyc3MxceJETJgwAT4+PmyHRuAwpNIl1FhcXFwwfvx4XLhwAfv370dmZiZatGiBoUOH4sSJE9BqtWyHSOAgpNIlEF6hqKgISUlJkEgkUCgUSEhIwIcffghPT0+2QyNwBGK6BEIl0DSNCxcuQCKRICUlBTExMUhISECXLl3IgjsEsyCmSyBUQ35+Pn755RdIJBI4ODggMTERsbGxcHd3Zzs0gh1CTJdAMBCappGeng6JRIIjR45gyJAhSExMRLt27Uj1SzAYYroEggnk5eVh06ZNWLNmDcRiMRITEzFq1CjUrl2b7dAINg4xXRtCpVJBqVRCq9WCpmlQFAUejwehUAiBQMB2eCbBdU1arRY3btxAamoqUlJS0Lt3byQkJKBVq1Zsh2kwXOwjW4aYLouo1WoUFBRALpdDqVTqE/51dJ8LhUKIRCJ4eHiAz+ezEHH11GRNWq0WZWVluHfvHu7evYvAwEAMHjwYzs7OLET9drjYR/YEMV0WUCgUkMlkkMvlAGDUCxh1B4dIJIKnpydcXFwsEqOxEE0V0Wq10Gg0OH36NIqLizFkyBAEBQVZKl
SD4GIf2SPEdK2IWq2GVCqFXC5n5E23FEVBJBLB19eXtQqEaDKsvbNnz+LQoUP44IMPMHDgQDg6OjIQqeH751of2TPEdK1EUVERsrKyQNM0o68WpygKFEXB39/f6u8MI5qMQ6PRYP369di9ezfi4uIwceJEBAQEMLqP1+FiH9k7xHQtDE3TyM3NRX5+PuMH8atQFAWxWAxvb2+LT18imkxHp2P9+vXYunUrwsLCkJiYiL59+zJaNXKxj7gCMV0LQtM0srOzUVhYaNHE10FRFNzc3ODn52exA4BoMh+dJrFYjN27d2PNmjXIzMxEfHw84uPj4efnZ1b7XOwjLkEWvLEgubm5Vkt8oPxgKywsRG5ursX2QTSZj05TYWEhxo0bh7NnzyI1NRV5eXkICQnBoEGDcPToUZMX3OFiH3EJUulaiKKiImRmZlot8V+FoijUr1+f8bE2oolZKtMkl8uxfft2rF69GoWFhUhISMD48eNRt25dg9q0NT2ENyGVrgVQq9X6mxdsQNM0srKyoFarGWuTaGKeyjSJRCJMmjQJV65cwY4dO3Dv3j0EBgZi1KhRSE9PrzJWW9RDeBNiuhZAKpWylvg6aJqGVCplrD2iyTK8TRNFUQgPD8fGjRvx6NEjdOzYER999BFatGiBn376Cfn5+W9sY8t6CP+DmC7DKBQKxuZDmgNN05DL5VAoFGa3RTRZDkM0eXh4YNq0abh58ybWrl2LixcvolGjRvjwww9x/vx50DRtV3pqOsR0GUYmk7Ge+DpomoZMJjO7HaLJshiqiaIoREREICkpCQ8ePEDLli0xduxYtGnTBtevX7c7PTUVciONQdRqNe7du2czyQ+UH6hBQUEmzwElmqyDqZq0Wi3++9//om7duuDxbKeGMrePuIzt9BIHKCgoMHnbgQMH4uLFiwxG8z/MicuYbS2p4XWsocmaegDTNPF4PLRu3RoODg4Gfd9e+ojLkEqXQTIyMswey6JpGitWrMCBAwegUCjQrFkzzJkzB02aNDG5TRcXFzRq1MikbU3R9ODBAyxZsgS3b9/Gixcv8Ndff73xncOHD2P16tXIzc1FnTp1MH/+fLRt29bgfVhT0+HDh7Fq1SrIZDI4OjqiS5cu+PLLL1G7dm2UlZVh/vz5OH/+PAoLC1G/fn188skniIiIMDouUzWZmneZmZlYtGgRLl26BEdHRwwaNAgzZsyo8J0nT55g8ODB6NWrFxYtWmRU++b0EZchlS6DKJVKs9s4evQoUlJSsHnzZpw5cwatW7fG7NmzWYvLlG35fD769OmDuXPnVvr3s2fPYtmyZZg3bx7Onz+PzZs3w9/f3+JxmbptmzZt8Msvv+DcuXM4fPgw1Go1VqxYAaB8qMLb2xubNm3CuXPn8PHHH+Pzzz9Hdna2xeMyZzuVSoVJkyYhPDwcJ0+exIkTJ9C3b983vrdgwQIEBwdbLa6aADFdhlCpVGaNEfbp0wfnzp1DdnY22rRpg/r168PBwQExMTF4+PChWbHRNA2VSlXld168ePHGZ8Zq0mkICAjA4MGD31qdr1q1ComJiWjdujV4PB7q1auHevXqGbwfoHpNZWVlePny5RufG6NJp8fb2xseHh76zx0cHPD06VMA5dXcRx99BD8/P/B4PHTr1g1+fn64ffu2UXqA6jUVFRVBo9GYrAf4n6aUlBTUrVsX48aNg4uLC5ycnN5YevLw4cMQiURo3769cUL+D0PyriZCTJchlEolI8+dR0VFITMzE48fP4ZKpcLBgwfRuXNns9qkKKrKquPp06cQi8UYMGBABbNgStOraDQa3Lp1C/n5+YiOjsZ7772HBQsWGF0VVadpw4YN8PT0xOzZsyuMLZqq6cqVK+jYsSPat2+PEydOIDY2ttLvyWQyPHnyBI0bNzZ6H9Vp6t+/Pxo2bIht27bpzddUPTdu3ICvry8SExMRERGB8ePH4/79+/q/FxcXY+XKlZg5c6bRbeuoTk9NhdxaZAhTn5N/HS
8vL4SGhqJfv35wcHCAt7c31q9fb3ZsmZmZbz04Hz9+DCcnJ6SmpuLo0aPo1KkT5s+fj5YtW5q138p4/vw51Go1jh8/ji1btoDP52PatGlYu3Ytpk2bZnA71Wl68uQJVCoVli5dip9++gljx47Ft99+a/JbHEJDQ3Hu3Dnk5eVh79698PX1feM7KpUK//rXv9C/f3+TxjKr0ySTyZCVlYWEhAR8/vnn+PzzzxEfH2/0foDyd7xdvHgRy5cvR4cOHbBt2zZMmzYNv/76KwQCAX7++WcMGjQI3t7eJrWvg6njgksQ02UIpu5Hrl69Gjdv3sTx48fh6emJ1NRUxMfHY//+/SYbhlKpxJo1a3DmzJlK/15WVoaysjJotVqUlpbi5MmTGDNmDK5evWqOlEpxcnICAIwePRpeXl4AgA8++MBo01UoFFi3bt1bNeXn50Oj0egrwnXr1oHH4xl9M+h16tWrh86dO2PWrFnYtWuX/nOtVovZs2dDIBCYPAZfnaZHjx7pv6dQKDBr1ix069YNtWrVMnpfTk5OaNOmjf6G34cffoi1a9ciIyMDNE3j/Pnz2L17t0k6XoXcp38TYroMwdRl+L179xAZGamvMAYOHIjvv/8eGRkZJleeLi4uWLJkCdzc3Cr9+4MHDxAcHAyBQIC6devi+++/x+DBg/WvdWESNzc31KtXr8LvZcpvV7t27So1LVu2DDNnzoSjoyO6du2KxYsXo3Xr1igsLDQ5dh0ajQaZmZn6f9M0jW+++QbPnz/HqlWrTH6ZY3Wa2rVrh+vXr0MgEGDy5Mn48ssvIRAITLppFxgYiGvXrlX6t4sXL0IqlaJXr14Ayk1eq9Vi+PDhFU40hkCWenwTYroMwdTE9ODgYBw7dgyRkZEQi8X47bffoFarUb9+fYvFV6dOHbRv3x7Tpk3D4MGD9d81VRNN0ygrK9PfRCktLQVFUfpX1AwcOBDJycno3Lkz+Hw+tm7diq5duxq9n6riCwoKQnR0NObNm4fWrVsbtM3bSE1NRdu2beHj4wOpVIrly5dXuLk0b948PHr0COvWrYNQKDS6/VepKr5u3bqhW7du+PLLL+Hp6QkAJp8YY2Ji9DMywsPDkZSUBHd3dzRq1AgNGjRAVFSU/rubN2+GVCrFV199ZfR+bOmBDVuBmC5DCIVCRi6l4uLi8Pz5cwwbNgwlJSVo0KABli5datZyeTRNV2kGYrEYv//++xufm6pJKpUiMjJS/+927drB19cXR48eBQAkJCTgxYsX6NevHxwdHdGnTx9MmjTJqH1Upyk6OhrR0dFvfG6KpoyMDCxbtgxyuRwikQgRERGYPn06gHKtu3fvhqOjI7p3767f5ptvvkFMTIxR+6lO048//vjGZ6b2UUBAABYuXIh58+YhPz8fzZs3x4oVKyAQCCAQCCoMZbm4uMDR0RFisdiofVSnp6ZCHo5gkNu3b9vkjQMej4cWLVqYtC3RZD1M1cQ1PVyH1P4MYqtndXPiIpqsh6lxcU0P1yGmyyAikcjmbhzoXpdtKkSTdTBHE9f0cB1iugzy6lNLtoQ5cRFN1sPUuLimh+sQ02UQPp9vc2d3kUhk1vJ6RJN1MEcT1/RwHWK6DKLVaiEWi23mUo+iKP3UInPw9PQkmiwIE5q4pofLENM1gxUrVsDPzw8eHh5wcnKCg4MDAgMDbWKMTTem5uLiYnZbLi4uRJOFYEoT1/RwGWK6ZtCsWTPIZDK8ePECZWVl4PP52Lx5M3x9fW0i+f38/Bhrj2iyDExq4poerkJM10S0Wi2ysrIAlCeaUCjExIkT0bt3b/D5fPj7+7N2AFAUBX9/f4PfJmAIRBPzMK2Ja3q4CjFdE0hPT0e7du2wbt06JCUlgc/nQywWY8mSJfrvuLq6sjK+S1EUxGKxWU+wvQ2iiTkspYlrergIub1oBA8fPsSsWbNw+fJlLF68GMOHDwdFUZDJZAgLC3tjHMvb2xsajQaFhYVWWW2Joii4u7ubvRxfVR
BN5mNpTVzTwzXIY8AGUFhYiAULFmDjxo2YMWMGPv30U4OXWaRpGrm5ucjPz7foAaCrNLy9vS1e5RBNpmMtTVzTwyXI8EIVaDQarFmzBs2aNcPz58/x119/Yfbs2Uata0tRFHx8fFC/fn3weDzGE5OiKPB4PNSvXx8+Pj5WSXyiybT2ramJa3q4BKl030JaWho+/fRTiMViLFu2DG3atDG7TbVaDalUCrlczkj1oZue4+vry9pEdKKpetjWxDU99o5NmK5KpYJSqYRWqwVN0/qzqFAoNHlBaFO5f/8+Pv/8c9y6dQs//PADBg0axPhZXKFQQCaT6ddCNaYLdLGIRCJ4enrazHxIoqkitqiJa3rsFVZMV61Wo6CgAHK5HEqlUm+0r6P7XCgUQiQSwcPDw2Jn1oKCAnz77bfYtm0bZs2ahWnTpulfLWMpbPF3MBWapjF48GA0aNAAX331FSc0FRUV4d1338WWLVvQrFkzTmhKT0/HRx99hFOnTnFCjz1i1V+wqjPt27yfpmkoFAqUlJTg2bNnjJ9pVSoV1qxZg2+//RZDhgzBrVu3ULduXUbarg4+nw8vLy/9u8JsqeI3BpqmMXXqVKSkpKBDhw6c0FRQUIAOHTrg0aNHyMjIQEREhN1runDhAvr06YPS0lIIhUK712OvWMV0mRhT0m1XVFSkX8Hf3DGlw4cPY8aMGfD390daWhpCQkJMbosJdKv22xNarRbx8fHYsWMHgPL+eRV71PTPP/+gc+fOyMjIAEVRb7yDzB41nT59GlFRUSgtLYWzszMeP36sz3d71GPPWNx0i4qKkJWVBZqmGZu6QtM05HI57t+/D39/f6MnZN+6dQufffYZHj16hB9//BF9+/Yld19NZN68ediyZYv+zQVSqZTliMynV69eePjwoV7TnTt3WI7IPJ48eYL3338fZWVlAMpN9lXTJVgXi00Zo2kaOTk5yMzM1F+2MN2+VqtFZmYmcnJyDGpfJpNhypQp6NGjB6KionDz5k3ExMQQwzWD2NhYTJ06FTweD87Oznjx4gVKS0vZDsssli1bhp49ewIAHB0d8eDBA5YjMg8/Pz+sWLEC/v7+4PF4UCgUePz4Mdth1VgsUunSNI3s7GyrPBFD0zTy8/Oh0Wjg5+dXqYGWlZXh559/xsKFCzFq1CjcuXMHderUsWhcNYVGjRohNjYWqamp2LhxI06cOGH3z9736NEDv//+OwIDA9GlSxcoFAq2QzILPp+PSZMmYd26dfj++++RlZVV4W3GBOtikdkLOTk5Fn8S5nV0T8b4+PjoP6NpGgcPHsTnn3+OwMBALFmyBM2bN7daTDWFGTNmoHbt2vj222/ZDoURaJpGs2bNsHXrVoSHh7MdDiPcv38f3bp1Q1ZWlt2fFO0dxivdoqIiqxsu8L+Kt1atWnB1dcX169fx6aefIi8vDz///DP69Olj1XhqChqNBjt27EBaWhrboTDGlStXoNFoEBYWxnYojLF9+3aMGDGCGK4NwOiYrlqt1t80YwOappGZmYmpU6eid+/eGDZsGK5fv04M14Kkp6fD29ubU1cQycnJGD16NGfG+mma1msisA+jla5UKmXNcHVotVpER0dj/vz5cHd3ZzWWmgDXDmZd5X7ixAm2Q2EMLlbu9gxjpqtQKBh7ttscKIrCO++8A0dHR1bjqAmUlpZi3759uHHjBtuhMMbvv/+OevXqkcqdYDEYG16QyWSsG64OmqYhk8nYDoPzHD58GK1atYK/vz/boTAGVyv3UaNGsR0K4f9gxHTVarX+0V5TGDhwIC5evMhEKHrkcjnUajWjbRIqwjWD0lXuI0eOZDsUxuBi5W7vMGK6BQUFZm2fkpJSYbxpzpw55oYEwPy4CG+nqKgIR48exZAhQ9gOhTGOHDmCkJAQUrkTLAojpsvEWG5OTg4WL16MkpISAOXzCl9955ix6B4VJliGlJQUdOvWjVMPmXDNoLhYuXMBRkxXqVSatX2fPn3w+PFj9O
rVC7Nnz8bFixexa9cuxMXFsRoX4e0kJSVhzJgxbIfBGLrKfejQoWyHwhhcHHPnAmabrkqlssgNNCZeMULTNFQqFUMREXTk5eXhwoUL6NevH9uhMIaucheLxWyHwhhcq9y5gtmmq1QqGZmK8s8//+D48eP47rvvEBYWhiFDhmDjxo1mtUlRFKl2LcCuXbvQr18/Tr09gGsGxcUxd65g9jxd3fJ35uLl5YX+/fvr/x0UFISgoCCz22UqPsL/SE5OxjfffMN2GIzx7NkznD9/Hvv27WM7FMZISUlB9+7dOVW5cwWzK11LDC0sWLCAsbZsZe4wV8jIyMDff/+N999/n+1QGGPXrl2IiYnhXOVO5ubaJmabrq0/5WLr8dkb27dvx/Dhwzn1poHk5GRO3RTMy8vD+fPnK1w5EmwHs02Xx7PYOuiMYOvx2RM0TSMpKYlTY59crNx3797NuTF3LmH2mK5QKDT7Ev7o0aPmhlEpNE1DKBRapO2ayI0bN6BQKNCxY0e2Q2GMHTt2YNiwYZyr3L/++mu2wyC8BbPLQIFAYLOX8BRFcepgYhvdOCFXrh5I5U5gA0ZWGRMKhTb5ShNS5TKHVqvF9u3b8dtvv7EdCmPcuHEDL1++5FTlvn37ds5V7lyDkZJFJBLZXLWr1Wrh7OzMdhic4Y8//oC7uzun3iBLKncCGzCSbR4eHkw0wygajQYRERGYOXOm3b/N1Rbg2sMDusqdS5q4OObORRgxXT6fD5FIxERTjFGnTh0cPXoUPB4PnTt3xvvvv489e/aQx4JNoKysDHv27OHUwil//PEH3NzcSOVOsDqM9Y6np6fNDDFQFAVPT080btwYixcvRmZmJiZMmIAVK1agQYMG+Oqrr/DkyRO2w7Qbjh07hqCgIDRs2JDtUBiDa3NzdZU7lzRxFcZM18XFxSbGdimKgkgkqjBH0cnJCaNGjUJ6ejrS0tJQXFyM0NBQxMTEIDU1FRqNhsWIbR+uDS1wtXJ3d3dHcHAw26EQqoHR6xBfX1+bMF0/P7+3/r1Fixb46aefkJmZiaFDh2L+/PkICAjA/PnzkZOTY8VI7YPi4mL89ttvGDZsGNuhMMbx48c5Wblz6cTIZRg1XT6fD39/f9aMl6Io+Pv7w8HBodrvuri44MMPP8T58+dx4MABZGVloUWLFhgyZAiOHz9OFsr5Pw4ePIjOnTvDy8uL7VAYg2sGxcXKncswPuLu6uoKsVhsdeOlKApisRiurq5Gb9umTRtIJBI8efIEvXr1wsyZMxEYGIgffvgB//zzjwWitR+4NvbJxcqdi2PuXMYitzm9vb3h5uZmNeOlKAru7u7w9vY2qx1XV1ckJibi6tWr2LZtG27duoWmTZtizJgxOH36dI1bsUwmk+HMmTMYMGAA26EwxsGDB9GpUydSuRNYwyKmqxtXtUbFq6twmRxPpigKHTp0wObNm5GRkYGwsDBMmjQJLVu2xPLly/HixQtG9mPr7NmzB1FRUahduzbboTAG1wyKi5U717HYhD6KouDj44P69esz8uqdytrn8XioX78+fHx8LGbuYrEY06dPx+3bt7F69WqcO3cODRs2RFxcHP78809OV79cMyiZTIbTp09zrnLn2pg716FoK7iGWq2GVCpl5K3BwP+mhfn6+oLPZ2T5CKN49uwZNm3ahDVr1sDd3R2JiYkYPXq03VSEKpUKSqUSWq0WNE3rT2BCoVD/zP7Tp08RGhoKqVQKR0dHliOuHkM0SSQSnDp1Cjt27GA5WsMwRFNMTAxGjhyJsWPHshytYRiiietYxXR1KBQKyGQy/avRjdm1rpIViUTw9PS0ibVCtVotjh8/jtWrV+P333/HyJEjkZCQgNatW7MdWgXUajUKCgogl8uhVCr1yf46us+FQiGuXLmC06dP4z//+Q8LEVePKZqSkpLQpk0bxMTEsBBx9Riric/nY+nSpVi0aBHc3d1ZiLh6TOknkUgEDw8PVgoqa2BV09XBxY7IysrChg
0bsH79evj7+yMxMRHDhw9nddEdc05yOlxdXW3mJAeYp0mr1cLBwcGmTtwA0fQ6tlhgMQkrpvs6XLrkUKvVOHToECQSCf7880/ExsYiISEBzZo1s2oMXBrOAYgmQyCa7AObMF2u8ujRI6xbtw4bN25E8+bNkZiYiEGDBll0jLSoqAhZWVmgaZrRm3wURekfPjFlLrQ5EE2GQzTZPsR0rUBZWRlSUlIgkUhw69YtxMXFYeLEiWjUqBFj+6BpGrm5ucjPz7fojArdFD1vb2+LTwckmkyHaLJdyBpwVsDR0RHDhw/Hf//7X6Snp6O0tBTh4eGIiorCgQMHoFarzWqfpmlkZ2dbPOl1+8rPz0d2drZF90U0mb8vosk2IZUuS5SUlGDPnj36x4/j4+MRHx8Pf39/o9vKycmxStK/iq7q8PHxsUj7RBMzEE22B6l0WcLZ2RmxsbH4448/cOjQIfzzzz9o1aoVBg4ciCNHjryx4M6lS5cqXYSnqKjI6kkP/K/qKCoqYrxtook5iCbbg5iuDdCqVSusXLkST58+Rd++fTFnzhw0adIEixYtQl5eHh4+fIiwsDDMmjWrwnZqtVp/44INaJpGVlaW2cMjr0I0MQ/RZFsQ07UhateujYkTJ+Ly5cvYtWsX/v77bzRr1gzR0dFwcHDA6tWrsXLlSv33pVIp62NbNE1DKpUy1h7RZBmIJtuBjOnaOHl5eXjnnXdQWloKABAIBNi0aRMGDRqER48esZ74QPkYW0BAgNmT2BUKBdFkQYgm24BUujbOH3/8AbVaDTc3N7i4uECtVmP8+PGQyWQ2kfRAecUhk8nMbodosixEk23AjUc8OEyXLl2wc+dOeHp66v9zc3PDw4cP2Q6tAnK5HGq12uSnhtRqtf6RUVuBaHoTLmqyNvYRZQ2mbt26GDJkSIXPbPVtFgUFBSYvMVhQUMBwNMxANL25rS1ijiZrQ4YX7BCmnmt/lSNHjqB///5o3749BgwYgLS0NKO2p2narArIXE0qlQozZsxAnz59EBISgosXL77xndu3b2PcuHEIDw9Ht27dsG3btirbZFPTw4cPMWLECHTq1AmdOnVCfHx8hasb3bh++/btERkZiU2bNhnULtv9VFJSgvnz5yMiIgIdO3bEuHHj3viOSqVC//798d577xnUprmarA2pdO0QpVLJaHt5eXn48ssvsXz5cnTp0gWnT5/GZ599hiNHjqBOnTpWiYsJTW3atMHYsWPx2WefvfG3goICTJ48GTNnzkTv3r2hUqmQl5dn0bjM2dbLywtLly6Fr68vtFotduzYgZkzZ2Lfvn0Ayo1mwYIFCAwMRGZmJhISEuDt7Y2oqCiLxmVuP82dOxcajQYHDhyAm5sb7t69+8Z3Nm3aBA8PD7x8+dJqcVkTUunaGSqVyqRKIyQkBE+fPtX/e86cOVi+fDmActN1dXVFREQEKIpC165d4ezsjMzMTKP2QdM0VCpVld95+PDhGweToZqq0iAQCBAbG4vQ0FDweG+m9S+//IJOnTohJiYGjo6OqFWrlkFrXxii6a+//nojfkM0VaXH1dUVfn5+oCgKNE2Dx+NV6I+4uDi0aNECfD4fAQEB6NGjB65evVqtHkM0FRcXIyMj443PzdWUkZGBU6dO4d///jfEYjEcHBzQsmXLCttnZWUhNTUV8fHxBmnRYUg/2QrEdO0MpVLJ+GIfLVu2REBAAE6ePAmNRoO0tDQIBAIEBgYa1Q5FUdVWHAMGDICPjw8WL16sN19LaHqdGzduwM3NDWPHjkW3bt0wdepU5OTkVLtddZpkMhlatWqFoKAgHDx4UG9KTGnq1KkT2rVrh4ULF77ViGiaxpUrV9CkSROD2qxO0969e9G4cWP069cPt27d0n9urqabN2/Cx8cHK1euREREBAYNGoTjx49X+M7ChQvxySefQCgUGtW2IblnK5DhBTujskeBzcXBwQH9+/fHF198gbKyMggEAvz4449Gz33UarW4fPkySk
pK3vod3eL1//73v/Htt99i9OjRWLx4sbkSqiUvLw937tzB2rVr0bRpUyxduhSzZs3C1q1bS8rOYwAABkRJREFUq9xOo9FUqenFixdwdHTEgwcPMGLECIjFYnz99dcYNWoUI3GfPXsWCoUCBw8ehK+vb6XfWbVqFbRaLQYOHGhQm2q1ukpN165dg5OTEw4dOoRjx47h3XffxU8//YQWLVqYrAMo74O///4bvXr1wn//+19cu3YNU6ZMQePGjdGoUSOkpaVBo9Hgvffeq3RMvjoscWxYAmK6doYl5keeO3cOS5cuxaZNm9C8eXPcvn0bH3/8MVavXm3U4usajQbHjh2r8jJX9ybl0tJS8Hg8bNu2DZ9//rnZGqrDyckJPXv2RHBwMABg8uTJiIiIgFwuh0gkeut2KpWqSk2lpaX6R1GVSiWkUimWLVuGkSNHMha7i4sLhg8fjq5du+LAgQMVxtmTk5Px66+/YvPmzQav01xWVlalJt3jtVqtFiqVCn/++Se2bt2K7777ziwdTk5O4PP5mDRpEvh8PsLCwhAeHo6zZ8/C29sbS5cuxapVq0xu31bmDlcHMV07w9TLO2dn5wqVzfPnz1GvXj0AwL1799C2bVv9+FpwcDBCQkJw/vx5o0xXIBDgiy++gJub21u/ExgYiKdPn8Ld3R3fffcdYmNjoVAokJ2dbZaG6ggMDKzw2xn6OwqFwio1PXv2DL6+vnBxcUGrVq2wZMkSdO7cGYWFhdW2bYwerVYLpVKJZ8+e6U13//792LBhAzZv3gxvb2+D9ADlJl6Vpo0bNyIhIQFCoRCjR4/G3Llz4e/vb7amyoardP3w9OlTSKVS/WwGlUqF4uJidO/eHUlJSfDz86t23/ayxi4Z07UzKrtJZAhBQUE4dOgQNBoNzpw5g0uXLun/1rJlS1y5ckV/J/nOnTu4cuWK0WO6hsQXGxuLVatWITMzE3FxcRAIBAZrqkoDUF7B6R6XVqlUKC0t1Vc/AwcORFpaGu7evQuVSgWJRILQ0NAqq1xDNLm5uSE2NhbHjh3DuXPn0Llz52q3MUTP2bNncefOHWg0GhQXF+OHH36Aq6ur/uZfamoq/vOf/2DdunWoX79+tfsyRlPr1q2RmJiIBw8eYMOGDfrlRs3V1LZtW/j4+GD9+vVQq9W4evUq/vzzT3Tq1AlNmjTB8ePHsWfPHuzZswdz585FnTp1sGfPHoNPKKYeG9aGrL1gZ6hUKty/f9/oS6lbt25hzpw5yMnJQc+ePaHRaODv749p06YBKL9M3bZtG54/fw4PDw+MGjWq0jmUVUFRFAIDA41+r52hmqrT0KdPnzcWQDly5Ii+Stq5cyfWrl2LkpIShIaG4quvvqr2gLakpqr0HD16FD///DPy8vIgFAoRHByMTz75BEFBQQCAyMhI5OXlVYgrJiYG33zzTbWxsaUJAP7++2/8+9//xoMHD+Dj44Np06ZVOh/34sWL+Ne//mXwfHFTNbEBMV075Pbt2zZ504DH45l8s4Vosh5EE7vYRz1OqICx02mshTlxEU3Wg2hiF2K6dohIJLK5mwa6V2WbCtFkHYgm9iGma4d4eHiwHUKlmBMX0WQ9iCZ2IaZrh/D5fJs7s4tEIrOW1iOarAPRxD7EdO0UT09Pm7nMoygKnp6eZrdDNFkWosk2IKZrp7i4uNjE+JpuPI2J16UQTZaDaLIdiOnaMb6+vjaR+IY8LWQoRJNlIJpsB2K6dgyfz4e/vz9ryU9RFPz9/eHg4MBYm0QT8xBNtgUxXTvH1dUVYrHY6slPURTEYjFcXV0Zb5toYg6iyfYgpssBvL294ebmZrXkpygK7u7uRi2yYixEk/kQTbYJeQyYI9A0jdzcXOTn51t0iTtdleHt7W3xA41oMh2iyXYhpssxioqKkJWVBZqmGT0AKIrSj6NZ+7KOaDIcosn2IabLQdRqNaRSKWNvDdZNzfH19WVtEjrRVD1Ek31ATJfDKBQKyGQy/eupjelq3e
WbSCSCp6enzcyFJJoqQjTZH8R0awBqtVr/bjKlUgmapisdE9N9LhQKIRKJ4OHhYbPVBdFENNkrxHRrICqVCkqlElqtVp/sPB4PQqHQLhaBrgyiyT7goiZjIaZLIBAIVoTM0yUQCAQrQkyXQCAQrAgxXQKBQLAixHQJBALBihDTJRAIBCtCTJdAIBCsCDFdAoFAsCLEdAkEAsGKENMlEAgEK/L/ATJ9BP0q9puSAAAAAElFTkSuQmCC\n",
"text/plain": [
"
"<Figure size 432x360 with 1 Axes>"
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'i*': ['u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],\n",
" 'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],\n",
" 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],\n",
" 'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
" 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],\n",
" 'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [1.5, 2], 'f32': [2.5, 2], 'f64': [3.5, 2],\n",
" 'c64': [3, 3], 'c128': [4, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 5))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xBDCy0AnGsbJ"
},
"source": [
"This involves a small sleight of hand: previously we had used `f*` to refer to a scalar type. In this lattice, `f*` might be applied to the array output of a mixed computation. Instead of thinking of `f*` as a scalar, we could think of it as a special kind of `float` value with distinct promotion rules: in JAX we refer to this as a *weak float*; see below.\n",
"\n",
"The advantage of this approach is that, outside unsigned ints, it avoids *all* wider-than-necessary promotions: you can never get an f64 output without a 64-bit input, and you can never get an f32 output without a 32-bit input: this results in convenient semantics for working on accelerators while avoiding inadvertent 64-bit values.\n",
"\n",
"This feature of giving primacy to floating point types resembles the type promotion behavior of PyTorch.\n",
"This lattice also happens to generate a promotion table that very closely resembles JAX's original *ad hoc* type promotion scheme, which was not based on a lattice but had the property of giving primacy to floating point types.\n",
"\n",
"This lattice additionally offers a natural location to insert `bfloat16`, without the need to impose an ordering between `bf16` and `f16`:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"cellView": "form",
"id": "inqdnEmioq7W",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAEeCAYAAAApRMZ1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2dd1gUV/v+76GzujRRl6qowYKoVMWCLbF3U0w1KtZYEoym6GuKGs0bg4nGaFTEJCYxRrHGGt+oMTZAwQaiIkpZUIrs6rKwZX5/8Nv9sgK6Ozu7M7Ocz3XliiwzZ+6zPHPPmWfOnIeiaZoGgUAgEKyCHdcCCAQCoTFBTJdAIBCsCDFdAoFAsCLEdAkEAsGKENMlEAgEK0JMl0AgEKwIMV0CgUCwIsR0CQQCwYoQ0yUQCAQrQkyXQCAQrIgD1wIaIyqVCkqlElqtFjRNg6Io2NnZwcXFBY6OjlzLIxAEH6N81k9M1wqo1WqUl5dDLpdDqVTqg+BJdJ+7uLhALBbD09MTDg7kT0SwPEKPUSHpp8iCN5ZDoVCgpKQEcrkcQM0f3Fh0ASMWi+Ht7Q2RSGQRjYTGjdBjVIj6ielaALVajcLCQsjlcpOCoCEoioJYLIavry8vRhUE4SP0GBWyfmK6LCOTyZCfnw+aplkJBh0URYGiKPj7+8PNzY21dgmND6HHqND1E9NlCZqmUVRUhLKyMlYD4UkoioKXlxckEkm9OSsCoSGEHqNC169vn5iu+dA0jYKCAlRUVFg0GHRQFAV3d3f4+fkR4yUYhdBjVOj6a0Pm6bJAUVGR1YIBqAnAiooKFBUVWeV4BOEj9BgVuv7aENM1E5lMZvHbnfqgaRplZWWQyWRWPS5BeAg9RoWu/0mI6ZqBWq3WJ/S5gKZp5OfnQ61Wc3J8Av8ReowKXX99ENM1g8LCQs6CQQdN0ygsLORUA4G/CD1Gha6/PojpMkShULA2R9AcaJqGXC6HQqHgVAeBfwg9RoWuvyGI6TKkpKSE82DQQdM0SkpKuJZB4BlCj1Gh628IYroMUKvV+tcO+YJcLie5XYIeoceo0PU/DWK6DCgvL7dIu4sWLTJrf0vpIggPc2OhrKwMCxcuRExMDHr27IkPPvigzjYVFRWIjY3FW2+9xbouU/Q/ePAAc+bMwYABAxAaGoqCggKD369atQrDhw9H9+7dMXLkSOzbt8/g9+fPn8fLL7+MHj16YMiQIfjjjz9Y0dUQxHQZwGaeqbKyEp999hkePnwIoObBwWeffWZy+7q8E4EAmB+j7733Hry9vXH06FGcPHkSb7/9dp1tVq9ejaCgIKPbNCVGTdFPURR69eqFhISEen/v6uqKtWvX4uzZs1i+fDlWrlyJ9PR0ADVLQL777rt48cUXcfbsWaxatQpfffUVbty4YZb+p0FWT2GAUqlkvG9RURFWrlyJixcvQqvVYujQoXjttdewbNkyXLt2DevWrcPMmTMZvQVjji6CbWFsLNQXj/3790dRURG2bNkCe3t7AEDHjh0N9ktPT8fNmzfx4osvYvfu3azrami7+vQuWrQIEyZMaPDW/5133tH/u0uXLoiIiEBGRga6deuGiooKPHr0CCNHjgRFUejcuTPatGmD27dvo3379oz1Pw0y0jURlUrFeASh0WjwzjvvwMfHB4cPH8bx48cxdOhQ/e9rL7bMBJqmoVKpGO1LEB7FxcUICQnBDz/8gOrqav3nxsZoQ/GYkZGB1q1bY9GiRejduzcmTJiAlJQUg/2++OILfPzxxyYPDmrHqEqlQmhoKFauXInHjx8/U/+zzh9jUCqVuHr1Ktq2bQsA8Pb2xtChQ7Fnzx5oNBqkp6dDKpUiPDz8mfqZQkzXRJRKJeN3sa9cuYIHDx5g/vz5EIlEcHZ2RseOHfHLL79g8eLFiIiIwKxZs7B582ZGxk5RFBntNiLKy8tx+/ZtzJ8/H/7+/nrzNTZG64vH8P
BwFBcX48yZM4iOjsbff/+NiRMnYt68efp85i+//ILQ0FCEhISYrLl2jKpUKly/fh1Lly6Fr6+v3nwb0t+QXlNYunQp2rdvj169euk/GzZsGDZs2ICIiAi8/fbbmDNnDiQSyTP1M4WkF0xEq9Uy3reoqAg+Pj4G63W6urri008/1f/s6+uLJUuWMGpfqVRi0qRJOHXqFGONBOGgVqtRXV2NqqoqPH78GDNmzMCyZctw9epVo/avLx4BwMXFBX5+fhg3bhwAYOjQodi4cSMuXbqEzp0749dff8Xvv//OSLNcLtfHKE3T0Gq1+vmvH330EZYtW1bnQdiz9BrL119/jZs3b2LLli16U8/JycHChQuxevVqxMTE4O7du5g9ezZatGiB2NjYetsxxwMAYromY87DCYlEgqKiIqjV6noDZ/ny5eZIg7OzMxISEuDi4mJWOwRhcPPmTTz//PPQaDSwt7fHyy+/jCVLlhgdow3FY3BwME6cOGGwrc6kdKPN0aNHAwCqqqqgVCrRr18/HD9+XJ8DboimTZvqY1ShUKBt27agKAqOjo4YNGgQPv/88wb1P+v8eRrr1q3D6dOnkZSUhKZNm+o/v3XrFlq1aqUf+QYFBSE2Nhb//PNPg6Zr7kN0YromYs4yb6GhofD29sY333yDWbNmwd7eHtevX0dYWBhr2tzd3eHu7s5KewR+U1VVBZqmMXnyZCxZsgS+vr4AaqZyGUND8Thw4ECsWrUKe/fuxYgRI3D8+HEUFxcjLCwMTZo0wZEjR/RtHD58GAcPHsSaNWueabiAYYxWVVXByckJI0eOxBdffIF27do9Vf/Tzp+qqipoNBoA0I/+nZ2dAQCbN2/GwYMH8eOPP8LDw8OgzY4dO+Lu3bs4f/48oqOjkZ+fj5MnT2LSpElP7YM5kPV0TUQulyMvL4/xLYZUKsWKFStw8eJFUBSFYcOG4aOPPmJFm52dHQICAiAWi1lpj8B/dKPc2pgSow3FY1pamv5WPygoCAsXLkRERESd/ffs2YPk5GT89NNPRul9MkZN1d+Q3tDQ0DrbXrlyBUCNWTs6OhqMjqdOnYqpU6cCqLlw/PDDDygsLETTpk0xfPhwvPvuu/U+0GbjHCOmayIqlQrZ2dm8eT2xNhRFITg4mPMS0wRuEXqMCl3/syCzF0zE0dGRt9UadLkxQuNG6DEqdP3PgpguA/j6oIqvugjWh6+xYKwuoet/GsR0GSAWi3l3JdaVkCYQAOHHqND1Pw1iugzw9PTkWkK98FUXwfrwNRaM1SV0/U+DmC4DHBwceDeqFIvFjCeNE2wPoceo0PU/DWK6DPH29ubN7Q9FUfD29uZaBoFnCD1Gha6/IYjpMkQkEvEi76TLM4lEIk51EPiH0GNU6PobgpiuGfj6+vIiIPz8/DjVQOAvQo9RoeuvD2K6ZuDg4AB/f3/OgoKiKPj7+xv1+iWhcSL0GBW6/vogpmsmbm5u8PLysnpQUBQFLy8vuLm5WfW4BOEh9BgVuv4nIabLAhKJBO7u7lYLCoqi4OHh0eCanwTCkwg9RoWu36BtsvYCO9A0jaKiIpSVlVn0nXHd1VcikXCe6yIIC6HHqND169snpssuMpkM+fn5oGma1cCgKEqfXyIpBYI5CD1Gha6fmK4FUKvVKCwsZK1qsG7Kiq+vL3kBgsAKQo9RIesnpmtBFAoFSkpKIJPJoNVqTSo4qbutEYvF8Pb2JvNwCRZBF6Pl5eVwcHAwycD4EKNC1E8epFkQkUgEe3t7DBo0CImJiRCJRLCzs9NX/H3yP93nIpEILVq0QPv27REYGEgMl2AxXF1d8cknn6Bv376wt7cXXIyKRCJs374dffv2RUlJiSD0k5GuBZFKpYiKikJBQQG8vLxQWloKoGaRZqVSCa1Wa1B23cXFhayHS7AaNE1jwYIFWLt2LbRaLfbv348hQ4YAEE6Mrlu3DvHx8VCpVFi/fj2mT58OgN/6SYLQQty/fx89evRAUVERgJq6T7
q6TY6Ojpz/4QmERYsWYf369aiuroajoyMyMzP1piuEGN24cSMWLFiA6upqAMDly5f1v+OzfpJesBBpaWm4f/++Pinv6OiIW7ducayKQKhBq9Viz549+jpkKpUKGRkZHKsyjZ07dxrUUUtPT+dQjfEQ07UQQ4cOxcOHD+Hn54eBAwfC1dXV6CqtBIKlsbOzw/Xr17Fo0SKEhIQgODgYCoWCa1kmcfToUSQmJiIgIADdunVjXCzW2pCcrgXJyclBTEwMCgsLyfoIBF4yfPhwvPnmm5gwYQLXUhgxa9YstGrVCh988AHXUoyGmK4F+frrr3Hjxg1s3LiRaykEQh0qKioQEBCA/Px8Qb5wo9Vq4efnh5MnTyI4OJhrOUZD0gsWZNeuXRg3bhzXMgiEevnzzz8RGxsrSMMFgLNnz8Lb21tQhgsQ07UYhYWFyMrKwoABA7iWQiDUS3JyMsaPH8+1DMYIVT8xXQuxZ88eDB8+HE5OTlxLIRDqoFAocOzYMYwcOZJrKYygaRrJycmCvJMkpmshhBoQhMbB0aNHERkZKdjaeunp6bC3t0doaCjXUkyGmK4FKC0tRUpKCgYPHsy1FAKhXoT+vEGnX4jLmxLTtQD79u3DCy+8QNZMIPCS6upq/Pnnnxg7dizXUhgj1HwuQF4DtgjJycl49dVXuZbRKOHzO/d84e+//0aHDh3g6+vLtRRGZGZmQiaTISoqimspjCCmyzJyuRwnT57Etm3buJbSKFCr1SgvL4dcLodSqdQb7ZPoPndxcYFYLIanp2ejXZtY6M8bdu/ejbFjx5q0VCqfaJxRZ0H+/PNP9O7dG+7u7lxLsWl066jK5XIAMFhHtaH3fWiahkKhQGVlJe7fv98o1yrWaDTYs2cPzp49y7UUxuzatQurVq3iWgZjiOmyjJBzTUKAjYoBuv1kMhnkcnmjqsrx77//wtfXF23atOFaCiNyc3Nx79499OnTh2spjBHm+JynVFZW4siRIxg1ahTXUmwSmUyG7Oxs1kq0ADUGLJfLkZ2dDZlMxkqbfMYWUgujR48W9AWSmC6LHD16FGFhYWjevDnXUmwKmqYhlUqRl5enf0DGdvtarRZ5eXmQSqUWrTTLJboXCoQ8a2HXrl2C1g8Q02UVklpgH5qmUVBQYPGy27pjlZWVoaCgwCaNNzU1Fa6urggJCeFaCiOkUimuXbuG559/nmspZkFMlyVUKhUOHDgg+Ksw3ygqKkJFRYXVTJCmaVRUVOgrftgSukGBEF8oAIC9e/di2LBhcHZ25lqKWRDTZYkTJ07gueeeg7+/P9dSbAaZTGaVEe6T6Ea8tpTjpWla8G+hCT0frYOYLksIPaD5hlqtRn5+Pme3+TRNIz8/H2q1mpPjs821a9egVCoRERHBtRRGlJWV4dy5c/oabkKGmC4L6OY+EtNlj8LCQs7zqjRNo7CwkFMNbKEbJQo1tbB//34MHDgQTZo04VqK2RDTZYGzZ8+iZcuWaNeuHddSbAKFQsHqtDCm6KaTCa12WH0I/SGv0PXXhpguC9hKrokvlJSUcG64OmiaRklJCdcyzOL27duQSqXo2bMn11IY8ejRI/z9998YMWIE11JYQbgzjHmCbu7j/v37uZZiE6jVav2rvXxBLpdDrVYLdkJ+cnIyxowZI9jiqAcPHkTPnj3h4eHBtRRWICNdM7l48SIcHR3RuXNnrqXYBOXl5Q3+bvDgwQ2uGfD777+jb9++iI6OxsOHD62qi+8I/U5M6PqfhJiumQh97iPfYJLLValU+Oqrr7Bx40ZcuHABHh4eWLt2LcaOHYtu3brh+++/r7NPWVkZFi5ciJiYGPTs2fOpJbx1uV0hUlBQgBs3bqB///5cS2GEUqnE4cOHMWbMGK6lsIYw75d4gm7u408//cS1FJtBqVSavE9paSmqqqrQtm1b/WeBgYGIj4/Hjh076t3nvffeQ0hICI4ePQoXFxfcunWLdV18YPfu3RgxYoRga/UdO3YMXbt2RY
sWLbiWwhpkpGsGmZmZePz4MSIjI7mWYhOoVKpnjnKvXbuG0aNHo2fPnli8eDFu3LihX2CoZ8+emDJlCgBg9OjR6NOnT71TjM6cOYOioiLMnz8fYrEYjo6O6Nix41OPS9M0VCoVw55xh9BvzYWuvz6I6ZqBLiCEupgy31Aqlc9M0/z555/44YcfcOjQIdy9exdHjhzB7t27AdSYaWJi4jOPk5GRgdatW2PRokXo3bs3JkyYgJSUlKfuQ1GU4Ea7JSUlSEtLE2ytPpVKhf379xPTJfwftngV5hKtVvvMbV599VVIJBK4u7tj6tSpOHTokMnHKS4uxpkzZxAdHY2///4bEydOxLx58575sMwYfXxi3759GDRoEFxdXbmWwohTp06hTZs2CAgI4FoKqxDTZUhOTg7y8/PRu3dvrqXYDMY8QJNIJPp/+/r64sGDByYfx8XFBX5+fhg3bhwcHR0xdOhQtGzZEpcuXTJbH58Q+qvpQtffEMR0GaJbTFmocx/5iDEzQGqv/iWVShmtXRwcHMzo2EKaoVJRUYF//vkHw4cP51oKI7RaLXbv3k1Mt7GgUqkgl8tRUVGBhw8foqKiAnK53OBBii29lsgF3333HZKSkpCfn6//zJjc+G+//aZf7nHTpk0NLoCiUqlQVVUFrVYLtVqNqqoqaDQaAMDAgQMhk8mwd+9eaDQaHD16FMXFxQgLC3vqsfmUu39WjB48eBCxsbFwc3PjWGn9PEv/uXPn4O3tXe8FUuiQKWMwvaKsnZ0dunTpgr59+3Kg1jbYuHEjbty4ATs7O7i7uyMkJAQrV658ZpHI4cOHY/r06Xjw4AH69++PadOmoaysrM52n376Kfbt26f/edOmTVi6dCnGjBkDd3d3rF27FsuWLcPy5csRFBSENWvWwNPTs8Hj0jQNFxcX5h02E1Nj1M3NDbNmzeLNm3Sm6lcoFPjggw94o59NKFpoiSoWeVpF2Weh1Wphb2/fKCvKMoWmady5cwepqalYsWIF0tPT9b+zt7fH9u3b0alTJ14+sLKzs0OnTp2sflxzYlRnalzGqDn6gZo+2No51ihNl42KsrXRBUZjqShrDDRNIy8vD6mpqfr/0tLS4OrqisjISDg6OuLAgQMAal5kOHz4MIKCgpCTk8PLVb1EIpFVK+gKPUaFrt+SNDrTlclk+sWx2ew6RVGgKAr+/v68zaNZksLCQgODTU1NhZ2dHSIjI/X/RUREwMfHBwBw48YNdOjQAa+88gqSkpL005oePHiA+/fv82qmgFarRfPmzfXaLY3QY1To+i1NozFdmqZRVFRk8fIvFEXBy8sLEolEUE+7TeH+/ft1DLa6uhpRUVF6c42MjISfn1+D3wFN07h06RLCwsIMtlGr1bhx4wavTFelUmH8+PEYPnw44uLi0LVrV4scR+gxKnT91qJRmK6uoqy1ChxSFAV3d/enmo5QKC0tRVpamoHByuVygxFsZGQkAgMDWevrvXv3eFWfzM3NDTRNIykpCVu2bEHLli0xZcoUvPrqq3B3d2flGEKPUaHrtyaNwnSlUqnVCxzqrsbWuiVlg4cPH+LixYsGBltaWorw8HADg23Tpo1FA12hUODOnTu8GO1SFIWgoCD9QxyNRoO//voLmzdvxl9//YVRo0YhLi4OvXv3Nus7EXqMCl2/NbF505XJZMjLy+PkBKYoCgEBAbzMP8nlcly8eFE/ik1JSYFUKkVYWJiBwT733HOczE+9d+8e5yV7dA9vAgMD6/39gwcP8PPPPyMxMRFqtRpTpkzBW2+9ZfDWnDEIPUaFrt/a2LTpqtVqZGdnczoFyc7ODsHBwZw+cX38+DHS09MNRrD37t1DaGioPg8bGRmJDh068OYNO7787dq3b//M74SmaZw7dw6bN29GcnIy+vXrh7i4OAwePFj/d585cybi4uLqVOPlSz+ZxqjQ9XOBTZuuEEZLbKNUKpGRkWFgsLdv30ZISIjBCLZTp05wdHS0iiamCHEEJZ
fL8fvvv2Pz5s3Iy8vDpEmT0KtXL4wePRoikQgpKSl47rnn9NsLPUaFrp8LbNZ0+ZwXZIvq6mpcuXLFwGB1U7FqG2znzp0Fu4i1kHOFV69eRWJiItavX4+qqipQFIUWLVogPT0dEolE8DEqdP1cYbOmy8cn4OZciVUqFa5fv25gsNeuXUO7du0MDLZLly6cvq7KNlw8Fffw8ICvry8rDwurqqrQrFkzPH78WP+Zu7s78vLyUF5eLugYtbVzzFoIIwliIkKvKKvRaJCZmal/iys1NRWXL19Gq1at9Ob65ptvomvXrvVWRrAlKIqCn58f7O3tBTn/MycnB46OjvD390eLFi3g4eEBjUaDqqoqQceo0M8xLrHJka45bzXduXMHCxYsQF5eHubOnYvXX3+dFU26W8snlyLUarXIzs42GMGmp6fD19fXYAQbFhYGsVjMihahYktvOhkbo5aKx/poKEbrQ+j6uYTflwSGmJPYT0pKQlRUFHbu3Amgpprq3r17MWvWLLM06SrKVlRUGBjsxYsX4e3trTfXzz//HOHh4fDw8DDreLaIm5sbgoODbeKdfmP1147Hn376CUOGDMHDhw8hEokwePBgzJ8/Hw4ODigtLcWXX36J1NRUVFZWol27dliwYAG6dOlitCZdjBpjWkz0A8D169fx5ZdfIjMzE66urpg6dSreeOMNg31SUlIwefJkTJ06FXPnzrWIfi6xSdM1p5ZVYWEhhg4dioyMDJw/f16/XqvOIKdNm8a47eLiYkyYMEH/quzHH3+MiIgINGvWjHGbjQ0HBwcEBgYKfvUtY2NUF48A0L9/f4wZMwZubm6oqKhAfHw8fvnlF0ycOBEKhQIhISFYsGABvLy8kJycjHfeeQdHjhwxqX/G6mKiv7y8HDNnzsSCBQswaNAgqFQqFBcXG2yvUqnw5ZdfmnSxYKKLS2wuvaBSqZCdnc1oFDRlyhSkpqbCwcEB9vb2mDt3Lk6ePImioiIMHDgQ06ZNM/sEbd++Pe+nagkJU9dpdXFxgVgshqenJ2e5P2Nj9Ml43LFjB1q3bg2g5u3B999/H61bt8bixYvr3b9Hjx5ITExESEiI0dooikJwcPBTY5Sp/n79+sHe3h4rVqxocJ/NmzdDJpOhtLQULVu2NGmka6x+ruHPUvgsYUxF2YZITExEeHg4Pv74Y1y4cAF+fn7637HxVpadnZ0grsRCwsHBAc2bN0ebNm3QqVMnBAcHIyAgAH5+fvD19YWfnx8CAgIQHByMTp06oU2bNmjevDmnD1uMjdEn47F169b4888/0aNHD/Tp0wfZ2dl46aWX6t03KysLKpXK5Kf5taseP3z4EO7u7pgzZ47BiJSp/tLSUri7u+ONN95A3759MXv2bEilUv32hYWF2LNnD2bMmGGS5ob08xWbM1223ozJyMhAdnY2/vOf/2Do0KGIiYnBL7/8Yna7fFyg25ZwdHSEWCyGu7s7PDw84O7uDrFYzKuRjzkxMHz4cJw7dw4HDhzASy+9VG9q6tGjR/joo48wc+ZMkx++0jSNyspKKBQKlJaWQqlUYuPGjQgKCsKMGTOQn5/PWH9xcTH27duHDz/8EEePHoWfnx8WLlyo//2KFSswe/Zss+8m+X6O2ZzpspUt6dq1K6ZPn65/BTQqKgpTp041u10by+YQGMBGDLRq1Qrt2rXDsmXLDD5XKpWYPXs2unbtiri4OJPbffToEaZPnw5vb2907twZ1dXVqK6uRmVlJX744QeEhoYy1u/s7IwBAwagc+fOcHZ2xsyZM5Geng65XI4TJ05AoVA0WPPOFPh+jtncgzS2V7/y8/Mze+ZCbYS2DB2BfdiKAbVajby8PP3P1dXVmDdvHlq2bIklS5YwalMsFmPr1q1wd3eHVCpFq1at9DnZhQsX4t1332U8kgwODjboe+1/nz9/HteuXUO/fv0A1Ji/nZ0dbt68ibVr15p0HL6fYzZnunyq2FoffNdHsDxMY2DXrl3o168fmj
Vrhtu3byMxMRE9e/YEUPNwKz4+Hs7Ozli+fLlZcabbt0mTJvDx8UFcXBzeffddfaqC6UsRY8aMwXvvvYfXX38dbdu2xYYNGxAeHg6xWIzZs2djypQp+m1XrlyJ5s2bM8rv8v0csznTdXFx4e3thVar5VVu0RZRqVRQKpXQarUG1ZtdXFx4890zjdFLly5hzZo1qKyshKenJwYNGoTZs2cDANLT03Hy5Em4uLjojRgA1q9fX2dls6dRu+qxm5sb7t69y5r+7t27Y968eXjnnXdQWVmJ8PBwfPnllwBqDL7225XOzs5wdXU1eZF4rqs2G4PNTRkDaiZg8zGZ/ujRIwwaNAhhYWH6kjaRkZFo164d76/OfEWIU8YA/saosVWPha6fS2zSdPlcUdbDw6NOdYaHDx/Wqc4QFBTE+9wUlwj95Qg+x6gxVY+Frp9LbNJ0+VhR9mnvhT948KBOHTKFQlGnDllAQECjN2Ihlvaub/QttBh9EqHr5xKbNF0+VpSlKArt27c3+sSWSqV6I05LS0NKSgq0Wm0dI/b19bWwcv4gxAVvrl69im7dukEkEsHLywtNmzYFABw7dszq6wQ/C1Ni1BbOMa6wSdMF+LfWZ0lJCV555RU0adIEKpUKlZWV8PLyQnZ2tlH70zSNwsJCg9FwSkoKnJycDEw4IiICLVu2tHBvrIuQS3ur1Wo0a9bMIBZ9fHxw+/ZtPHjwgFcxStbTtQ78viSYgbe3N+dlRHTobnsqKipw//59/efR0dEmteHn5wc/Pz+MHj0aQI0Z3bt3T2/Cq1evRmpqKsRisd6Adf/39vZmvV/WwJqLmNM0jbKyMmg0GrNLe9M0jZMnTyIxMVH/2ixFUWjVqhXS0tLg6urKuxg1NUaErp8rbHakC/CvftOFCxfQr18/VFZW6m+FlixZghdffJG16Uw0TSMnJ6fO8pFeXl51RsRCWD5SaOV6CgsL8eOPP2LLli1wdnZGXFwcYmJi0Lt3b3h6euLSpUsGa3rwLUZNRej6ucCmTZcvlUprV5T9+eefMXnyZHTo0AFffPEFEhIScOvWLcydOxdTp061iBFqtVrcvHmzzkLpEomkzkLpfCplLZTClGq1GgcPHkRiYiJOnTqFF198Ec4KgIwAACAASURBVHFxcYiOjtaPlpcsWYLXXnsNHTp0qLMv32LUFISunwts2nQBfp64//3vfxEVFYX+/fsDAC5evIjVq1fjzz//xBtvvIF58+ahbdu2FtWm0WiQlZVlMGsiIyMDgYGBBkbcrVs3TkoC8eVkflpp75s3b2LLli348ccf0bp1a0yZMgWvvPKK/mGZsfAxRk1B6Pqtjc2bLiCcW9SCggKsW7cOmzZtQp8+ffDee++hd+/eVpsmplar6xS/vHr1Ktq2bWuQI+7atStcXV0tqoWvt62VlZXYtWsXNm/ejOvXr+PNN9/ElClTzJ6QL5QYbQih67cmjcJ0hVZR9vHjx/jxxx+xevVqeHh4ID4+ntW8rylUV1fj6tWrBkaclZWF4OBggxFxaGgonJ2dWTkmH0t7Z2VlITExEdu3b0d0dDSmTJmCUaNGsVbaXmgx+iRC129NGoXpAsKcdqTVanHgwAEkJCTg9u3bFs37moJSqcTly5f1JpyWloabN2+iU6dOBkYcEhLS4IWCpmkcO3YMzz//fJ1XoPk0FYmmaZw/fx5Lly7F5MmTMWnSJIs9sBFijNZG6PqtRaMxXR1CnGAPcJP3NQWFQoGMjAyDEXFubi5CQ0MNjLhDhw5wcHDAtWvX0LlzZwwYMAC7du3SX0j4OOmepmkEBwezNpJ/FkKNUR1C129pGp3pAsJ8lVRHQUEBvvvuO2zatAmxsbGIj49Hr169eHnFf/ToES5dumRgxAUFBfo3tE6dOgWaptGsWTMcPnwYXbp0IaW9/z9CjlFA+PotSaM0XR1CXjSldt7X09MT7733Hmd5X1OoqKjAxYsX8eGHH+LChQv6zymKwpYtWxAbG2vUQi
pLlixBkyZN8MEHH+DChQvYsGEDMjMz4ebmhiNHjtTZftu2bdi2bRvKysogkUiwZs0afZFHY+BqIRUhxyggfP2WoFGbrg6hLg8IGOZ9c3JyMGfOHF7kfZ9FSEgIcnJyoNVq8dxzz+kLGGq1WqOmicXFxWHo0KEYP348rly5gtzcXCiVSmzevLmO6e7atQu//PILvvrqK7Rp0wb5+flwc3Mzaa1WrpcMFHKMAsLXzybEdOtBCAth10ftvO+bb76JuXPn8irvW5vffvsNTZs2Rb9+/fQVCcwtTX727Fl8+umnBqar1WoxaNAgLFu2DD169GCsl2+lvYUaozqErt8cbOsSwhKOjo6C/MOHh4fj559/1ud9u3fvztu876uvvlrnM90aBc8y3cTEREyaNAkjRozA+PHjn7ptcXExiouLcevWLSxevBgODg4YOXIkZs6cadLC8brS3nyJC6HGqA6h6zcHUq7ABvHz88OKFSuQm5uLgQMHYtKkSejevTu2b98OlUrFtbwGscTbZ8XFxQCAM2fOIDk5GYmJiTh06BCSk5NNbouPlRIIwoOYrg3TtGlTvPPOO8jKysLixYuxYcMGtG3bFl999RUePnzItbw6WCLTpZvmNWnSJLi5ucHPzw8vvfQS/vnnH5PbIpk4AhsQ020E2NvbY9SoUThx4gR2796NjIwMtGnTBvPmzUNOTg7X8vRYIv3RunVrODo6stI2n9IzBOFCTLeRERERgW3btuHKlSsQiUSIjo7G+PHjcfr0ac5HckyLc2q1WlRVVUGtVoOmaVRVVenTKK6urhgyZAiSkpLw+PFjFBUVYefOnejbt6/V9BEItSFR1EipnfcdMGAAL/K+TEt7p6WlITIyErNmzYJUKkVkZCSmTZum//3HH38MkUiEAQMG4I033sCwYcMwduxYk44hhNLeBGFApowRANQs9aib73vnzh3O5vuS0t4EW4eMdAkAavK+o0ePxsmTJznN+/J1NMlXXQThQUyXUAdd3vfy5ctwdXXV533//fdfi+d9xWIx7x5Y6d77JxDYgJguoUH8/f2xcuVKfd737bffRo8ePSya9/X09LRIu+bCV10E4UFyugSjeTLvO3fuXMTFxbGe9+XTerqAcEp7E4QBGekSjObJvG96erpF8r7e3t68STEIqbQ3QRgQ0yUwwpJ5X5FIxIvcri6XaytLChL4AUkvEFjh0aNH2Lp1K7755hs0a9YM8fHxGD9+PONl+fhSDVhIpb0JwoCYLoFVaud9c3NzMWfOHMZ5X1Lam2CLkPQCgVVq53137dqFS5cuoU2bNnj33XdNzvu6ubnBy8vL6mkGXeFDYrgES0BMl2AxIiMj8csvv+Dy5ctwcXFhlPeVSCRwd3e3mvHqSntLJBKrHI/Q+CDpBYLVYJr3JaW9CbYEMV2C1dFoNNi/fz9Wr16N3Nxc/XzfZ9UsI6W9CbYAMV0Cp6SmpmL16tU4dOgQ3nrrLcybNw9BQUENbk9KexOEDsnpEjhFl/fNyMiAs7MzoqKi8OKLL+LMmTP1mqqDgwMCAwMRFBQENzc3/SjVFHT7uLm5ISgoCIGBgcRwCVaDjHQJvMLUvC8p7U0QGsR0CbxEl/dNSEjA3bt3jc77NubS3gRhQEyXwHtq530nTpyIuXPnPjXvSyDwGZLTJfCe2nlfJyenZ+Z9CQQ+Q0a6BMHx6NEjJCUl4dtvv4W3tzfi4+Mxbtw4kqMlCAJiugTBwjTvSyBwCUkvEASLvb09xowZg1OnTmHnzp1IS0tDUFAQ3nvvPdy5c4dreQRCvRDTJdgEUVFR+PXXXw3yvi+99BLOnj3LtTQCwQCSXiDYJLq87zfffIPmzZuTvC+BNxDTJdg0Go0G+/btw+rVq3H37l3MmzcPU6ZMIXlfAmeQ9ALBprG3t8fYsWP1ed/U1FSS9yVwCjFdQqOhdt7X0dERkZGRJO9LsDokvUBotMjlcv06Dy1atEB8fDzGjh1L8r4Ei0JMlwNsYX0Aof
ehtn6NRoO0tDTs3LkTFy5cwJtvvsn7vK/Qv//GDDFdK2ALK2EJvQ/G6tdoNKiqqkJWVhaqq6vRv39/XqzzIPTvn/B/ENO1IAqFAiUlJZDL5QBg0joBuhNKLBbD29sbIpHIIhqfhdD7YI5+lUoFrVaLmzdvIigoCN27d7eUzAYR+vdPqAsxXQtgC9UNhN4HNvVrtVpUV1fj2rVraNq0KUaPHm3xPgj9+yc0DDFdlrGFOl5C74Ol9Gu1WlRWVuLrr79GTEwMpkyZYpF+CP37Jzwd+08//fRTrkXYArqKtcXFxRZbbpCmachkMmg0GjRt2pT1irVC74Ol9VMUBScnJ8TGxuLBgwcYPnw4iouL0b59e3h4eJjdvtC/f4JxkHm6LEDTNAoKCixeIlx3rLKyMhQUFLB6LKH3wZr6KYpCTEwMzp49CwcHB0RERODll182a76v0L9/gvEQ02WBoqIiVFRUWC2AaZpGRUUFioqKWGtT6H3gQr9arcZ7772H3Nxc9O7dG6+//jpiYmLwxx9/QK1Wm9Se0L9/gvEQ0zUTmUxmldHJk+hGKzKZzOy2hN4HrvXTNI25c+fi5s2bWLBgAb799lu0a9cOq1evNqpvXOtnI4YIxkNM1wzUarX+gQcX0DSN/Px8k0dVtRF6H/ik397eHuPGjcPp06exY8cOXLhwAUFBQZg/fz5yc3Pr3Z9P+gnWgZiuGRQWFnKeE6NpGoWFhYz3F3of+Ko/Ojoav/32Gy5dugQ7Ozt93vfcuXMG2/FVP8FyENNliEKhYG0OpTnQNA25XA6FQmHyvkLvgxD0BwYG4quvvkJubi569eqF1157DT179sQff/wBuVzOe/0E9iGmy5CSkhLOTxYdNE2jpKTE5P2E3gch6ReLxZg3bx5u3ryJ999/H2vWrEFaWppg9BPYg5guA9Rqtf61TL4gl8tNyssJvQ9C1a/L+/79999o3ry5lZQZh6kxRGAGMV0GlJeXM953zJgxSElJYVHN/2GKLmO3taTe+jBWV2PRD1i3D+bENsE4yGvADMjJyTE7/0XTNNauXYu9e/dCoVCgQ4cOWLRoEdq1a8e4TZFIhDZt2hi1ral9uHnzJlatWoXr16/j4cOHuHLlSp1tDh06hPXr16OoqAjNmjXDsmXLEBERYfQxAOP7YKr+Q4cO4fvvv0dJSQmcnJzQu3dvfPTRR2jatCmqq6uxbNkynDt3DhUVFQgICMC8efPQp08fk7RbUr+OvLw8rFy5EqmpqXBycsLYsWMRHx9vsM3du3cxbtw4vPDCC1i5cqVJ7ZsSQwRmkJEuA5RKpdltHDlyBHv27MHWrVtx+vRpdO3aFR9//LHVdJnaBwcHBwwePBifffZZvb8/c+YMVq9ejaVLl+LcuXPYunUr/P39TTqGKbpM1R8WFoaffvoJZ8+exaFDh6BWq7F27VoANakKiUSCpKQknD17FnPmzMH777+PgoIC3ugHalY9mzZtGqKjo/H333/jr7/+wvDhw+tst3z5cnTu3Nnk9pnqIpgGMV0TUalUZj38GDx4MM6ePYuCggKEhYUhICAA9vb2GDFiBG7fvm2WNpqmoVKp9D9XV1fXO5oypQ86vUFBQRg3blyDI/Hvv/8eM2bMQNeuXWFnZ4eWLVuiZcuWZvVBo9HUO3GfiX6JRAJPT0/95/b29rh37x6AmtHdrFmz4OfnBzs7O/Tt2xd+fn64fv26WfppmsbDhw/N0l+7D3v27EGLFi0wceJEiEQiODs7o3379gbbHjp0CGKxmPEylE/GEIF9iOmaiFKpZGWRkKFDhyIvLw+5ublQqVTYt28fevXqZVabFEUZjFQ2bNgAb29v/Oc//0FFRYX+c7b6oEOj0eDatWsoKyvDsGHDMHDgQCxfvpzRqKl2Hw4cOIDmzZtj9uzZBq+rMtV/8eJFxMTEoHv37vjrr7/w5ptv1rtdSUkJ7t69i7Zt25qlPyMjA82aNcNrr72GnJwcs/VfvnwZvr6+mDFjBv
r06YNJkyYhOztb//tHjx5h3bp1WLBggclt16efYBmI6ZqIVqtlpZ3mzZsjPDwcI0eORFRUFI4ePYqFCxea1SZN08jLy0NmZiYyMzNx7949VFVV4auvvoKPjw9mzpyJBw8esNYHHaWlpVCr1Th27Bh+/PFH7Ny5E1lZWdi4caPJbWm1Wn0fbt++DTs7O2zatAmtW7fG66+/jpycHMb6w8PDcfbsWfz11194++234evrW2cblUqFDz/8EKNGjWKU26ytPysrCy4uLtixYwc6deqE4cOHIz09nbH+4uJiHD58GK+//jr+97//ITY2FnPnztWPTL/77juMHTsWEomEUfu1+0CwHGQ1YxNh67nj+vXrcfXqVRw7dgze3t44cOAA4uLisHv3bri6ujJqs7q6Gt9//z1Onz4NoMYMtVotqqqqANSMfJ2cnBrMyzLF2dkZAPDaa6/pp0G99dZb2LhxI+bOnWtSWwqFAps2bcLp06chk8lQXV2tN4Fff/0VFRUV2LZtm1l6W7ZsiV69emHhwoXYsWOH/nOtVouPP/4Yjo6OjPPrtfVXVlYa1GE7ePAg8vPzcfLkSUZtOzs7IywsTP+A7+2338bGjRuRk5MDmqZx7tw5/PHHH4zarg15tm5ZyEjXRNi6Lb9x4waGDBkCiUQCBwcHjBkzBjKZzOA21FScnZ2xatUq/Uj3/fffh52dHVxdXTFs2DBcvnwZ3377LetrqLq7u6Nly5YG7TI9RtOmTfV9SEhIgIODA0QiEWJiYnD69GkcOHCAFf0ajQZ5eXn6n2maxpIlS1BaWorVq1czLu5YW//27dvh6OiIJk2aoGPHjjh48CDS09MZ6w8ODm5w35SUFBQWFuKFF15Av379sHXrVvz11194+eWXTT4OWWPXspCRronY2bFznercuTOOHj2KIUOGwMvLC3/++SfUajUCAgJY09ehQweMGDECy5YtQ2hoaL3bGAtN06iurtbfylZVVekX9QZq5pL++uuv6NWrFxwcHPDzzz8jNjbWrD4EBQWhX79+WLJkiUG+m4n+AwcOICIiAj4+PigsLMSaNWsMHjYtXboUd+7cwaZNm+Di4sJI95P6fH190aNHD3zwwQcYMmSI3syYxtCIESP0MzCio6Pxyy+/wMPDA23atEFgYCCGDh2q33br1q0oLCzE4sWLGesnWAZiuibi4uLCyu3X5MmTUVpaipdeegmVlZUIDAxEQkKCWWVUaJo2MIxRo0Zh1KhRdbZj0ofCwkIMGTJE/3NkZCR8fX1x5MgRAMD06dPx8OFDjBw5Ek5OThg8eDCmTZtmVh+io6P17ZurPycnB6tXr4ZcLodYLEafPn3w7rvv6vv2xx9/wMnJCf369dPvs2TJEowYMYKx/sDAQJw4cYIV/UDNRWjFihVYunQpysrK0LFjR6xduxaOjo5wdHQ0SEuJRCI4OTnBy8uLsX6CZSAvRzDg+vXrvHzYYGdnh06dOhm1rdD7QPRbBlNiiMAMch/BAL6OBEzRJfQ+EP2Wga+6bAliugwQi8W8e9igK7FtLELvA9HPPqbGEIEZxHQZUPvNJj5hii6h94Hotwx81WVLENNlgIODA+9GBGKxGA4Oxj8XFXofiH72MTWGCMwgpssQb29v3tweUhQFb29vk/cTeh+IfvZgGkME0yGmawLV1dXIz89Hamoqtm/fDq1Wy/lJo8vDiUQik/cViUS8yC0y7QPRzw7mxBDBdIjpGklUVBRcXV0RHByM3r17Y8qUKUhOTubFCePn58d4f19fX0H3geg3H3NjiGAaxHSNZNy4cXB0dERlZSWqqqrQoUMH/Oc//4G/vz9nJw1FUfD394e9vT3jNhwcHATdB6LfPNiIIYJpENM1gpKSEmRlZenrR7m6umL79u2ws7ODm5sbvLy8rH7SUBQFLy8vs95g0yH0PhD9zGAzhgjGQ0z3KdA0jaSkJISEhMDLywuHDx8GUPMKb9euXfXbSSQSuLu7W+2koSgKHh4eZi
/hVxuh94HoNw1LxBDBOMj8kAbIzMzEjBkzoFAocOjQIYSHhwOoWTilb9++BtvqcmL29vYoKyuz6NJ4utGJRCJh9QQVeh+IftOOZYkYIhgHWXvhCSorK/HFF19g/fr1+OSTTzBr1iyT8l0ymQz5+fmgaZrVE4eiKH3+zdK3g0LvA9FfP9aMIULDENOtxbFjxzBz5kyEhYXhm2++YfxEV61Wo7CwEHK5nJWTRjelx9fX12qT14XeB6LfEC5iiFA/vDBdlUqlX2GfpmlQFAU7Ozu4uLgwXkzaFIqLixEfH49///0X69atq7fCKhMUCgVKSkogl8sBmLYiv+62TywWw9vbm7M5lELvA9HPfQwRDOHkkqdWq1FeXg65XA6lUqk32ifRfe7i4gKxWAxPT09Wr9JarRabN2/G4sWL8fbbb+PatWto0qQJa+2LRCIEBgbypr9MEIlEyMrKwo4dO7BixQrB9UEkEkEul+PTTz/F999/L0j9IpEIM2bMwI8//ig4/YS6WHWky6er9pUrVzBjxgxotVr88MMP6NKli1ntmQLXI3tTuHr1KiIiImBvb29Qzl0ofcjPz0fnzp0hk8mgUqn0+Xmh6K+oqEDXrl1x9+5d5Ofn61NeQtFPqItVLoFs5Kd0+8lkMv3q/0zyUwqFAp9//jkSExOxbNkyTJ061erlSXQr/fOdzMxM9O7dG9XV1bCzs4NSqdSvtyqEPkilUvTo0QMymQxOTk7Izc3Vl1UXgn65XI4+ffogPz8fzs7OuHHjht50haCfUD8WdxuZTIbs7GzWHggANQYsl8uRnZ0NmUxm9H4HDx5ESEgI7t27hytXrmD69OmkHlQD5OXlISYmBhUVFQBqbnNv3brFsSrjkclk6NGjB4qKikDTNJycnJCVlcW1LKPRaDTo27cvsrKyoNFoAEBQ+gkNYzHHoWkaUqkUeXl5+lsgttvXarXIy8uDVCp9avuFhYV4+eWXMXfuXPzwww/49ddfyaTwZ6DVahEVFQUHBwc4OjqiqqpKUCe9RqNBeHg47O3t9amRzMxMrmUZjUajQZcuXeDo6AiKolBVVYXLly9zLYvAAhZJL9A0jYKCAlRUVFh0krfuWGVlZdBoNPDz8zN4uKDRaLBhwwZ8+umnmD59On788UeD4n2EhmnVqhWOHTuGmJgYDBgwAFlZWWjevDnXsozG09MTu3fvxoQJE+Dt7Y3y8nKzKy1bEycnJ2zduhUSiQR5eXlo2rQp2rdvz7UsAgtY5EGaVCq1+Fs1T6J7y8bHxwcAkJ6ejmnTpsHFxQUbNmwgxfYYIJVK0alTJxQVFcHZ2ZlrOSZTVVUFiUSCzMxMQd7Z0DSNDh06YNu2bYiKiuJaDoElWE8vyGQyqxsu8H8j3vv372P+/PkYPHgwZsyYgRMnThDDZciePXswbNgwQRouABw/fhydO3cWpOECNRWDKysrERkZybUUAouwml5Qq9X61xe5gKZp3L17F48ePcLVq1cFdTvMR5KTkzFz5kyuZTAmOTkZ48aN41oGY3T6yfoItgWr6YV79+6xOkuBCbrXHQMDAznTYAuUlpYiKCgIUqmU1RdGrIVarYaPjw9SUlLQunVrruUwolu3blizZg1iY2O5lkJgEdbSCwqFgnPDBf5vOlntifwE09m/fz+ef/55QRouAPzzzz9o1aqVYA339u3bkEql6NWrF9dSCCzDmumWlJRwbrg6aJpGSUkJ1zIEja3cmguV3bt3Y8yYMaSigw3Ciumq1Wr9q71MGDNmDFJSUtiQokcul+srPRBMQy6X48SJExgxYgTXUhih1Wqxe/duQZuu0C8ahIZhxXTLy8vN2n/Pnj0GU2IWLVpkriQA5utqrBw8eBC9evWCh4cH11IYceHCBbi7u6NDhw5cS2FEQUEBsrKy0L9/f66lECwAK6bLRi5XKpXiyy+/RGVlJQAgOzsbq1atYtyeLrdLMB2hj7KErn/Pnj0YMWIEnJycuJZCsACsmK5SqTRr/8
GDByM3NxcvvPACPv74Y6SkpGDHjh2YPHkyp7oaI0qlEkeOHMHo0aO5lsIImqYFb7pC1094OmabrkqlssgDNDs7O7PnJ9I0DZVKxZKixsGxY8fQrVs3tGjRgmspjLhy5Qo0Gg26devGtRRGlJSUIDU1FYMGDeJaCsFCmG26SqWSlcnbDx48wLFjx/DFF18gKioK48ePx5YtW8xqk6IoMto1kV27dgl6lLVr1y6MHz9esC8U7Nu3D4MGDSJVHmwYs99I02q1bOhA8+bNMWrUKP3P7du3Z2WBD7b0NQZUKhX279+PpUuXci2FMcnJyfjhhx+4lsGY5ORkvP7661zLIFgQs0e6lkgtLF++nLW2+DJ3WAicPHkSbdu2FdRqXLXJzs5GaWkpevTowbUURshkMpw6dYq1Gn0EfmK26fL9No7v+vhEcnIyxo8fz7UMxuzevRtjx44V7ML0Bw8eRGxsLCmPbuOYHZ18D3C+6+MLuhcKxo4dy7UUxthCPlrI+gnGYfaCNyqVCtnZ2by8jacoCsHBwaSWlBH8+++/mDFjBq5cucK1FEbcu3cP4eHhkEqlgvx7V1ZWQiKR4Pbt2/D29uZaDsGCmD0M1JUT4SMURQnyBOQCoacW9uzZg5EjRwr273306FFERkYSw20EsHLvrasQyzf4qotv2MoLBUK+aAj9+ycYDyumKxaLeTfa1a2rS3g2ly5dgr29PUJDQ7mWwoj79+8jPT0dzz//PNdSGFFdXY0DBw5gzJgxXEshWAFWTNfT05ONZliluroaOTk5XMsQBEKvULB3714MGTJEsHc2J06cQHBwMPz8/LiWQrACrJiug4MD70aVpaWlmDBhAnr16oWdO3eSZR6fArk15xahf/8E02CtXI9CocCdO3d4MYuBoigEBQXB2dkZe/bsQUJCAgoLCzFv3jxMnjyZzIOsRWZmJl544QXcu3dPkNPrHj58iMDAQBQUFPDuwm8MGo0Gfn5++Pfff9G2bVuu5RCsAGtnmUgk4kVuV5fLFYlEsLe3x/jx4/Hvv/9i+/btOHfuHIKCgjB//nzcvXuXU518ITk5WdAvFBw4cAD9+/cXpOECwJkzZyCRSIjhNiJYPdN8fX15Ybr15ca6d++O7du34+LFi6AoCuHh4XjllVdw/vx5DlTyB1u4NSf6CUKC1WrAQM3743l5eZykGSiKQkBAgFHpA7lcji1btuCbb76Br68v4uPjG11NqtzcXERFRUEqlcLBwey1j6zO48eP4evrizt37sDLy4trOSZD0zRat26NgwcPIiQkhGs5BCvB+j2lm5sbvLy8rD7ipSgKXl5eRudrxWIx5s2bh1u3biE+Ph4JCQl47rnn8M0330Amk1lYLT/YvXs3Ro8eLUjDBYAjR44gOjpakIYLABcvXoSLiws6derEtRSCFbFIIk8ikcDd3d1qxktRFDw8PCCRSEzet3be97ffftPnfd9//33cu3fPAmr5g9Df9detnStUdN8/1yk5gnWxiOnq8qrWGPHqRrhs5JNr530BICwsDBMmTMCFCxfYkMorpFIprl27hoEDB3IthRFVVVU4ePCgoMsKCf2iQWCGxR5ZUxQFHx8fBAQEsFJ6p7727ezsEBAQAB8fH1bbb9WqFVatWoU7d+6gR48eeOWVV9CrVy/s2rULGo2GteNwyd69ezFs2DA4OztzLYUR//vf/xASEgIfHx+upTAiMzMTlZWViIiI4FoKwcqw/iCtPtRqNQoLC1mpGgz837QwX19fq+Qj1Wo19u7di4SEBEilUv18X75OU1KpVFAqldBqtaBpWn+BcnFx0S8IM2jQIEyfPp23I61n9WHq1Kno2LEj4uPjuZZaL8/Sv2zZMpSUlOCbb77hWmq9GBNDBGZYxXR1KBQKlJSU6Eujm3Jo3UhWLBbD29ubsxpS58+fR0JCAo4fP45JkyZhzpw5CAwM5ESLDrVajfLycsjlciiVSv1J8iS6z52cnLBq1Sr897//5c2LIqb2IS
MjA+Hh4Wjbti0vHgSaqj8rKwv+/v4IDw8XpH4XFxeIxWJ4enryQr+QsKrp6rCFP/Ddu3exAPI6SgAADB9JREFUdu1aJCUl4YUXXkB8fDyio6OtqsGcixjwf3cMXF7EhH4hJvq5HwgJDU5M90mEfCsjk8mwZcsWfPvtt/Dz80N8fDxGjx5t0fm+Qk/XAMLvA9FvCBcxJFR4Ybq2gLXyvjKZDPn5+aBpmtUXUCiKAkVR8Pf3t3jKQeh9IPrrx5oxJGSI6VqAc+fOYfXq1fjrr78wefJkVvK+NE2jqKgIZWVlFn3bTzcFTyKRsD7jROh9IPqNw5IxZAsIc5UTntOjRw/8/vvvSEtLg1arNXu+L03TKCgosPjJojtWWVkZCgoKWD2W0PtA9Jt2LEvEkK1ATNeCtG7dGl9//bVR8321Wm2D7RQVFaGiosJqAUzTNCoqKlBUVMRam0LvA9FvGpaIIVuBpBesiFqtxp49e7B69eo6ed/Ro0fD09MTSUlJBrdkQllA6GkIvQ9EP3PYiiFbgox0rYiDgwNefPFF/Pvvv/j1119x5swZtG7dGnFxcThy5Aj++OMPg8nyarVa/8CDC2iaRn5+vllVN4TeB6LfPNiIIVuDjHQ5Jjc3F6NHj8bly5cBAM7Ozti/f7++mgNbU3qYopsKxPRBoND7QPSbj7kxZGuQkS7HeHp6IisrS/9zVVUVhgwZgszMTM5PFqBmpCKXy6FQKEzeV6FQCLoPRD87mBNDtgiZxcwxNE0jLi4Ozs7OcHd3h0gkglQqhaOjI5RKJdfyANRoLCkpMXmkUlJSwvkJr4NJH4h+9mAaQ7YIMV2O8fDwwLp16ww+U6vVuHHjBkeK6kcul0OtVhv9tpFarda/WsoXTOkD0c8+psaQrULSCzykvLycawn1YoouofeB6LcMfNVlTYjp8hBL5OEOHz6MUaNGoXv37hg9ejSOHz9u0v66vJyxmNMHlUqF+Ph4DB48GKGhoUhJSamzzfXr1zFx4kRER0ejb9++2LZt2zPbNaUP5ui/ffs2XnnlFfTs2RM9e/ZEXFwcbt++rf99UlISxo4di+7du2PIkCFISkoyql1r6QeAyspKLFu2DH369EFMTAwmTpxYZxuVSoVRo0YZvRC+qTFkqzTucT5PYTuXW1xcjI8++ghr1qxB79698c8//2D+/Pk4fPgwmjVrZhFd5vYhLCwMb7zxBubPn1/nd+Xl5Zg5cyYWLFiAQYMGQaVSobi4mFVd5uhv3rw5EhIS4OvrC61Wi+3bt2PBggVITk4GUGM+y5cvR3BwMPLy8jB9+nRIJBIMHTqUF/oB4LPPPoNGo8HevXvh7u5u8LBXR1JSEjw9PfH48WOj2+XLcwouISNdnqFSqRiNUEJDQw1qui1atAhr1qwBUGO6bm5u6NOnDyiKQmxsLFxdXZGXl2fSMWiahkql0v8sl8uRk5PDqA9P0+vo6Ig333wT4eHhsLOrG6I//fQTevbsiREjRsDJyQlNmjRBmzZtTO5DdXV1vWZirn43Nzf4+fmBoijQNA07OzuD73ry5Mno1KkTHBwcEBQUhP79++PSpUsm66dpGleuXKmj1Vz9OTk5OHHiBD755BN4eXnB3t6+TrXi/Px8HDhwAHFxcUbprk9/Y4WYLs9QKpWsLxISEhKCoKAg/P3339BoNDh+/DgcHR0RHBxsUjsURRmMVHbs2IG2bdtizJgxyMzM1H9uiT7U5vLly3B3d8cbb7yBvn37Yvbs2ZBKpUbtW7sPp06dQseOHdG3b1+DdTHY0t+zZ09ERkZixYoVDZoTTdO4ePEi2rVrZ7L+mzdvokuXLujWrRuOHTumN1pz9V+9ehU+Pj5Yt24d+vTpg7Fjx+LYsWMG26xYsQLz5s2Di4uLSW0/GUONEZJe4BlPW4OBKfb29hg1ahQ++OADVFdXw9HREV9//bXJi05rtVqkpaWhsrISAJCRkQEnJyfs378fhw4dQnh4OL799lu0b9+e9T7Upr
i4GJmZmdi4cSOee+45JCQkYOHChfj555+fua9ardb3IS0tDa6urjh16hRiY2Px3HPP4b///S969uzJis4zZ85AoVBg37598PX1rXeb77//HlqtFmPGjDGqzdr67927B1dXV1y+fBmjRo2CRCLB559/jlGjRpmlu7i4GLdu3cILL7yA//3vf0hPT8c777yDtm3bok2bNjh+/Dg0Gg0GDhxYb779WVgixoUEMV2eYYl5lWfPnkVCQgKSkpLQsWNHXL9+HXPmzMH69evRoUMHo9vRaDQ4evSo/lY4Ly8ParUaWq0WKpUK58+fx7Zt2/D555+z3ofaODs7Y8CAAejcuTMAYObMmejTpw/kcvkz1y+urq7W9+HBgwf6W92qqipcvXoVGzZsQExMDGtaRSIRXn75ZcTGxmLv3r0GOfRff/0V+/fvx9atW+Hk5GRUe7X1P3r0CNXV1QBqRre5ubn49ttvMXLkSLM0Ozs7w8HBAdOmTYODgwOioqIQHR2NM2fOQCKRICEhAd9//z3j9vkyd5grSHqBZzC9LXR1ddWPQAGgtLRU/+8bN24gIiICISEhsLOzQ+fOnREaGopz586ZdAxHR0d88MEHOHToEA4dOoS5c+fqK3zExcUhLy8Pa9asMaoPT9P7LIKDgw2OYcp3JhKJ9H1Yvny5vhzU+PHjcePGDezdu5d1/VqtFkqlEvfv39d/tnv3biQmJmLTpk2QSCSM9G/cuBEURcHV1RUDBgxAamoqUlNTzdZfX9pJ1+a9e/dQWFiIiRMnol+/fnj33XdRUlKCfv36oaCgwKg+NPY1donp8oz6HhwZQ/v27XHw4EFoNBqcPn0aqamp+t+FhITg4sWL+odGmZmZuHjxosk53Sf1hYWFYdasWbh16xY2btwIPz8/o/vwNL1AzYiuqqoKQM2DoaqqKv0IacyYMTh+/DiysrKgUqmwYcMGhIeHG12lQ6cvODgYb7/9NjIyMrBz507992Gu/jNnziAzMxMajQaPHj3CV199BTc3N/3DvgMHDuDbb7/Fpk2bEBAQYJTm+vT7+vrijTfewD///IPjx4/ry7mbqz8iIgI+Pj7YvHkz1Go1Ll26hAsXLqBnz55o164djh07hp07d2Lnzp347LPP0KxZM+zcudPoiwfTGLcVyII3PEOlUiE7O9vkW7Br165h0aJFkEqlGDBgADQaDfz9/TF37lwANbey27ZtQ2lpKTw9PfHqq6/WO/fyaVAUheDg4GfWrTOmD8/SO3jwYBQWFhrsc/jwYb2x//7779i4cSMqKysRHh6OxYsXG3XSG9MHc/UfOXIE3333HYqLi+Hi4oLOnTtj3rx5+lz3kCFDUFxcbKBhxIgRWLJkCS/0A8CtW7fwySef4ObNm/Dx8cHcuXPrnY+bkpKCDz/80Oh538bGkC1DTJeHXL9+nZcPG+zs7NCpUyejthV6H4h+y2BKDNkqjXucz1NMnYZjLUzRJfQ+EP2Wga+6rAkxXR4iFot597BBtyaqsQi9D0Q/+5gaQ7YKMV0e4unpybWEejFFl9D7QPRbBr7qsibEdHmIg4MD70YEYrHYpCX5hN4Hop99TI0hW4WYLk/x9vbmze0hRVHw9vY2eT+h94HoZw+mMWSLENPlKSKRiBd5OV0eztRXhgHh94HoZwdzYsgWIabLY3x9fXlxwujmxjJB6H0g+s3H3BiyNYjp8hgHBwf4+/tzdtJQFAV/f3/Y29szbkPofSD6zYONGLI1iOnyHDc3N3h5eVn9pKEoCl5eXnBzczO7LaH3gehnBpsxZEsQ0xUAEokE7u7uVjtpKIqCh4eHSQuxPAuh94HoNw1LxJCtQF4DFgg0TaOoqAhlZWUWXRpPNzqRSCSsn6BC7wPRbxyWjCFbgJiuwJDJZMjPzwdN06yeOBRF6fNvlr4dFHofiP76sWYMCRliugJErVajsLCQtarBuik9vr6+Vpu8LvQ+EP2GcBFDQoWYroBRKBQoKSnRl7U25U+pu+0Ti8
Xw9vbmbA6l0PtA9HMfQ0KDmK4NoFarUV5eDrlcDqVSCZqm682l6T53cXGBWCyGp6cnb0YlQu8D0U8wFmK6NohKpYJSqYRWq9WfJLqyOkJZPFrofSD6CQ1BTJdAIBCsCJmnSyAQCFaEmC6BQCBYEWK6BAKBYEWI6RIIBIIVIaZLIBAIVoSYLoFAIFgRYroEAoFgRYjpEggEghUhpksgEAhW5P8Bs4E5a3hLUVEAAAAASUVORK5CYII=\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],\n",
" 'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],\n",
" 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],\n",
" 'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
" 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],\n",
" 'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [1.8, 1.7], 'bf16': [1.8, 2.3], 'f32': [3.0, 2], 'f64': [4.0, 2],\n",
" 'c64': [3.5, 3], 'c128': [4.5, 3],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(6, 5))\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6oNhanFa2UoQ"
},
"source": [
"This is important because `f16` and `bf16` are not comparable because they utilize their bits differently: `bf16` represents a larger range at lower precision, while `f16` represents a smaller range at higher precision."
]
},
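{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of why neither type dominates the other, the sketch below compares the two formats with `jnp.finfo` and checks their promotion with `jnp.promote_types`. It assumes an installed JAX whose promotion rules follow the lattice drawn above; the variable names are illustrative only, and the printed values are deliberately not reproduced here.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"f16 = jnp.finfo(jnp.float16)\n",
"bf16 = jnp.finfo(jnp.bfloat16)\n",
"\n",
"# bfloat16 covers the larger range of magnitudes ...\n",
"print(float(f16.max), float(bf16.max))\n",
"# ... while float16 has the finer spacing near 1.0 (smaller eps).\n",
"print(float(f16.eps), float(bf16.eps))\n",
"\n",
"# Neither type can hold every value of the other, so they meet at float32.\n",
"print(jnp.promote_types(jnp.float16, jnp.bfloat16))\n",
"```\n",
"\n",
"Because each format wins on one axis and loses on the other, placing them side by side under `f32` avoids having to declare either one the wider type."
]
},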
{
"cell_type": "markdown",
"metadata": {
"id": "FZ6ana2UooNh"
},
"source": [
"However, these advantages comes with a few tradeoffs:\n",
"\n",
"- mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \\times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \\times 10^4$), meaning most representable values will become `inf`.\n",
"- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n",
"\n",
"Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`."
]
},
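{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both tradeoffs can be observed in JAX itself, whose promotion rules follow a variant of this lattice. The following is a minimal sketch, assuming JAX's default configuration (64-bit types disabled, so `int32` stands in for the widest integer); the array names are made up for the example.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"ints = jnp.array([1, 1000, 100000], dtype=jnp.int32)\n",
"halves = jnp.ones(3, dtype=jnp.float16)\n",
"\n",
"# Integer/float promotion keeps the float width, so values beyond the\n",
"# float16 range overflow instead of widening the result.\n",
"mixed = ints + halves\n",
"print(mixed.dtype, mixed)\n",
"\n",
"# A Python float acts like f*: it does not widen the float16 array.\n",
"print((halves + 1.0).dtype)\n",
"```\n",
"\n",
"The first result shows the cost of giving primacy to the float type, and the second shows the weakly-typed scalar deferring to the array's bit width."
]
},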
{
"cell_type": "markdown",
"metadata": {
"id": "hLAcEHg7Vm-B"
},
"source": [
"## Type Promotion in JAX\n",
"\n",
"In designing the type promotion semantics of JAX, we kept in mind many of these ideas, and leaned heavily on a few things:\n",
"\n",
"1. We chose to constrain JAX's type promotion semantics to graphs that satisfy the lattice property: this is to ensure associativity and commutativity, but also to allow the semantics to be compactly described in a DAG, rather than requiring a large table.\n",
"\n",
"2. We leaned toward semantics that avoid inadvertent promotion to wider types, particularly when it comes to float values, in order to benefit computation on accelerators.\n",
"\n",
"3. We were fine accepting potential loss of precision (but not loss of magnitude) in mixed type promotion if it were required to maintain (1) and (2)\n",
"\n",
"With this in mind, JAX has adopted Option 3. Or rather, a slightly modified version of Option 3 that draws the connection between `u64` and `f*`, in order to create a true lattice.\n",
"Rearranging the nodes for clarity, JAX's type promotion lattice then looks like this:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"cellView": "form",
"id": "I5_GcCGwXMDV",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAADnCAYAAAAaaYxfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeVgT1/4/8PeEBGJs2ARFNgUFq6JWVBAU9wWr4t7etlr3tm5ttdVatNV+rVu1rtdqtZZatXa1LrUqiysFrRVxQwUBZUcQSKJJIMv8/vCXXFG0WYbMTDiv57nPc4Vk8m7OMPOZM2fOoWiapkEQBEEQBGHHBGwHIAiCIAiCqG+k4CEIgiAIwu6RgocgCIIgCLtHCh6CIAiCIOweKXgIgiAIgrB7pOAhCIIgCMLukYKHIAiCIAi7RwoegiAIgiDsHil4CIIgCIKwe6TgIQiCIAjC7pGChyAIgiAIu0cKHoIgCIIg7B4peAiCIAiCsHuk4CEIgiAIwu6RgocgCIIgCLtHCh6CIAiCIOweKXgIgiAIgrB7pOAhCIIgCMLukYKHIAiCIAi7RwoegiAIgiDsHil4CIIgCIKwe6TgIQiCIAjC7pGChyAIgiAIu0cKHoIgCIIg7B4peAiCIAiCsHuk4CEIgiAIwu6RgocgCIIgCLtHCh6CIAiCIOweKXgIgiAIgrB7QrYDEARBNBQajQZqtRp6vR40TYOiKAgEAojFYohEIrbjERbie7vyPb+pSMFDEARRT7RaLSorK6FQKKBWq40nkycZfi4WiyGVSuHm5gahkByeuYrv7cr3/JaiaJqm2Q5BEARhT5RKJcrLy6FQKAA8OnGYynDikUql8PDwgEQiqZeMhPn43q58z28tUvAQBEEwRKvVoqioCAqFwqyTybNQFAWpVApvb29eX1nzHd/ble/5mUIKHoIgCAbI5XIUFBSApmlGTioGFEWBoij4+vrC2dmZse0SpuF7u/I9P5NIwUMQBGEFmqZRUlKCiooKRk8oT6IoCu7u7vDy8qpzvAXBLL63K9/z1wdS8BAEQViIpmkUFhZCJpPV60nFgKIouLi4wMfHh/MnFz7je7vyPX99IfPwEARBWKikpMRmJxXg0YlMJpOhpKTEJp/XUPG9Xfmev76QgocgCMICcrm83m8X1IWmaVRUVEAul9v0cxsKvrcr3/PXJ1LwEARBmEmr1RoHgrKBpmkUFBRAq9Wy8vn2iu/tyvf89Y0UPARBEGYqKipi7aRiQNM0ioqKWM1gb/jernzPX99IwUMQBGEGpVLJ2Hwm1qBpGgqFAkqlktUc9oLv7cr3/LZACh6CIAgzlJeXs35SMaBpGuXl5WzHsAt8b1e+57cFUvAQBEGYSKvVGqfl5wqFQsHZMRN8wfd25Xt+WyEFD0EQhIkqKyvZjlAnrubiC65+f6bm4nt+WyEFD0EQhInqa4zEokWLLH6vYcwEYTlr27WiogILFixAREQEIiMj8dFHHz31GplMhl69euHNN980aZvmtKs5+cvKyjBnzhz069cPHTp0QGFhYa3fr127FkOHDkV4eDiGDx+OQ4cO1fr9+fPn8corr6B79+6Ijo7GL7/8YnV+WyEFD0EQhInUajVj21KpVPjss89QVVUF4NETNp999plFJ14mczVE1n5/c+fOhYeHB+Lj43H69GlMmjTpqdesX78eAQEB9ZLLnPwURaFHjx5Yt25dnb9v1KgRNm/ejNTUVCxfvhyrVq1Ceno6AECj0eD999/H2LFjkZqairVr12LNmjW4deuW1blsgRQ8BEEQJtBoNFb1ApSUlOD9999Hr1690LNnT6xbtw6vv/46Pv/8c6SlpWHLli2YOnWqRVPz0zQNjUZjcbaG4syZM1i9ejUePHhg/Jk57fpkGy5fvhwpKSkoKSnBBx98AKlUCpFIhLZt29Z6X3p6OrKysjBy5Eiz8j7erhUVFfjwww+Rm5tb6z
XPy19XXg8PD/znP/9BSEhIne+ZNWsWAgMDIRAI0LFjR3Tp0gWXL18G8KiX6sGDBxg+fDgoikJISAgCAwORnZ39r/m5gBQ8BEEQJlCr1RavE6TT6TBr1iw0b94cx44dQ1JSEoYMGWL8PU3ToCgKAoFlh2SKojh3Nc1FJ06cwKJFi+Dt7Y0VK1bgwYMHJrfrs9rw8uXLaNmyJRYtWoSePXviP//5Dy5cuFDrfStWrEBsbKzZ+8/j7ZqdnY2NGzeiXbt2eOONN4yFz7Py/9s+Zwq1Wo1r166hVatWAAAPDw8MGTIEBw4cgE6nQ3p6OoqLixEaGvqv+blAyHYAgiAIPtDr9Ra/9+rVqygrK8MHH3wAofDRYbdt27ZYvXo1Fi9ejDVr1mDmzJn45ptv8Mknn5h9Yqyursa6detw8+ZNizM2BFevXoVOp4NCocAnn3yCJUuWYM+ePejQoYNJ732yDUNDQ3Ho0CGkpKTgs88+w7Jly5CYmIj33nsPR44cgZubG/bu3YsOHTqgffv2yMrKMiuvRqMxtuv9+/eNBcS+ffuwb98+REdHY9++fWblNceyZcvQpk0b9OjRw/izl19+GUuWLMHq1asBAIsXL4aXl9czt2HN3w3TSMFDEARhAmtvZzVv3tx44gEejZVYunSp8d/e3t749NNPLdq+g4MDwsLCEBwcbHHGhqCmpgbXr18HRVFwcHBASEgI2rRpY9J762pDABCLxfDx8cHo0aMBAEOGDMH27dtx6dIlhISE4IcffsBPP/1kUd7H2zU7OxvJycnQaDRwdHSEm5sbRowY8dzbWXXlNdWXX36JrKwsfPvtt8YCPCcnBwsWLMD69esRERGBu3fvYvbs2WjatCl69epV53a4MjcQQAoegiAIk1h6OwsAvLy8UFJSAq1WW+cJaPny5dZEg1AoRGRkJFxcXKzajr3LzMzE/v37MXToUKxatQrt27eHTCZ76kmlujyrDYODg3Hq1KlarzXsK4ZelhEjRgB41BOnVqvRp08fJCUlwcHB4bmfKRAIjO36zz//4JNPPkFQUBDWrl1rHEcjk8nMymuKLVu2IDk5GXFxcXjhhReMP799+zZatGhh7PEJCAhAr169cPbs2WcWPNb83TCNjOEhCIIwgaXjawCgQ4cO8PDwwIYNG6BUKlFdXY1Lly4xmM66fA3FW2+9hRs3buDw4cNo3749ANO/t2e1Yf/+/SGXy3Hw4EHodDrEx8ejtLQUnTt3RlRUFI4fP45ff/0Vv/76K2bNmoW2bdvi119//ddix8CQ76WXXsLZs2dx69YtxMTEGAuJZ+V/3j5XXV2NmpoaAI96vaqrq43v++abb/Dnn39ix44dcHV1rbXNtm3b4u7duzh//jxomkZ+fj5Onz793J5FLu2XFM2l/iaCIAiO0mg0yMzMtLiLvri4GCtXrkRaWhooisLLL7+Mjz/+mJFsFEUhODgYIpGIke01JOa067Pa8OLFi/j8889RWFiIgIAALFiwAF26dHnq/QcOHMD+/fvx/fffm5TNlHZ9Xv5n5a1rzNLVq1cBPCqURCJRrV6h6dOnY/r06QCAY8eO4euvv0ZRURFeeOEFDB06FO+//36dhQ3X9ktS8BAEQZgoIyODU4MwDQQCAdq1a8d2DN7ie7vyPb+tcKeviSAIguPEYjHbEerE1Vx8wdXvz9RcfM9vK6TgIQiCMJFUKuXUIEzg0W0DqVTKdgxe43u78j2/rZCChyAIwkRubm5sR6gTV3PxBVe/P1Nz8T2/rZCChyAIwkRCoZBzV61SqdTiuVaIR/jernzPbyuk4CEIgjCDh4cHZ24fUBQFDw8PtmPYBb63K9/z2wIpeAiCIMwgkUg4MWbCMEZCIpGwmsNe8L1d+Z7fFkjBQxAEYSZvb29OnFh8fHxYzWBv+N6ufM9f30jBQxAEYSahUAhfX1/WTi4URcHX19fk2XoJ0/C9Xfmev76RgocgCMICzs7OcHd3t/nJhaIouLu7w9nZ2a
af21DwvV35nr8+kYKHIAjCQl5eXnBxcbHZyYWiKLi6usLLy8smn9dQ8b1d+Z6/vpClJQiCIKxA0zRKSkpQUVFh8TpbpjBcQXt5ebE+TqMh4Hu78j1/fSAFD0EQBAPkcjkKCgpA0zSjJxiKooxjI7h8u8Be8b1d+Z6fSaTgIQiCYIhWq0VRUREUCgUjJxfDI77e3t6cm8StIeF7u/I9P1NIwUMQBMEwpVKJ8vJyKBQKADDrJGO4LSCVSuHh4cHJ+UwaKr63K9/zW4sUPARBEPWkuroan3zyCYYPH44mTZqApuk6xzkYfi4WiyGVSuHm5sarK+eGRqvVYuvWrfD390dQUBDv2jU+Ph6HDh3CvHnzoFareZffUqTgIQiCqAc0TWP8+PH44YcfMGHCBHz//ffQaDRQq9XQ6/XGk4lAIIBYLIZIJGI7MmGi7777DlOmTIGnpydKS0t51a6ZmZl46aWXoFarUVRUBC8vL17ltwZ/SzWCIAgOW7RoEfbv3w8AuHr1KgBAJBLZ1QmkITp27BhmzpwJmqZRUVGBmpoaODo68qJdS0pK0KtXL6hUKojFYly8eBFDhw5tMPslmYeHIAiCYT/88ANWr14NtVoNALh58yb0ej3LqQhrZWZmYsSIEVCpVACARo0a4dq1ayynMg1N0+jbty/KysoAADU1Nfj7779ZTmVbDkuXLl3KdgjCchqNBkqlEmq1Gmq1GtXV1dBoNBAIBJyd3tuAZGcP3/NznVqtRnl5ObKysuDo6Ijq6mq8/vrraNKkCdvRCCtoNBrIZDJkZWVBq9VCq9UiNDQUXbp0YTvav6JpGg8fPkReXh7u378PkUgEvV6PSZMmsR3NZsgYHp7RarWorKyEQqHg3WAzkp09fM/PR2q1Gh4eHjh+/DguX76M8ePH82a+EuL5+vXrhwkTJgAAIiIi8OKLL7KcyHTr1q3DtWvXMGrUKAgEAgwdOpTtSDZDCh6e4PPjhCQ7eQy1IUpKSsLixYuRmprKdhSCQQ8fPoSXlxeKi4vxwgsvsB3HbNHR0XjrrbcwevRotqPYHLl04zgmJowyvE8ul0OhUNhswiiSnZ3sAP/z24OEhAQMGjSI7RgEw86cOYMuXbrwsthRq9VISUnBjz/+yHYUVpBByxwml8uRmZnJ2OyYwKOTmEKhQGZmJuRyOSPbrAvJXputsgP8z28v4uPjMXDgQLZjEAzjc7v+9ddfaN++PVxdXdmOwgpS8HAQTdMoLi5Gfn6+cV4Eprev1+uRn5+P4uJiRrdPsj9/+/WV3bB9Pue3J/fu3UNOTg7Cw8PZjkIwLD4+nrc9d3zOzgRS8HAMTdMoLCys9xVuDZ9VUVGBwsJCRj6LZDf9s5jMbtgmn/Pbm6SkJPTp06dBzG3SkBQWFqKkpAShoaFsR7FIQ7/NSgoejikpKYFMJrPZiYSmachkMpSUlFi9LZLddExmB/if397w+bYH8WwJCQno378/L6duMPQ6hoWFsR2FNaTg4RC5XG6TK/QnGa7YrRmbQbKbj4nsAP/z2xuaphv8lbS94nO7kl5HUvBwhlarRUFBAWu3CGiaRkFBAbRardnvJdktZ012gP/57dGNGzcgFArRunVrtqMQDNLr9UhISOBtzx3pdSQFD2cUFRWxPh6CpmkUFRWZ/T6S3TqWZgf4n98eGQaG1jWxI8Ffly9fhpubG1q0aMF2FLPRNN3gBywDpODhBKVSyegjxJYyPHqsVCpNfg/Jbj1LsgP8z2+v+Hzbg3g2PrfrjRs34Ojo2OB7HUnBwwHl5eWsn7QMaJpGeXm5ya8n2ZlhbnaA//ntUXV1Nc6ePYt+/fqxHYVgGJ9vCRmyN/ReR1LwsEyr1Rqn/ecKhUJh0pgMkp1ZpmYH+J/fXqWmpqJt27Zwd3dnOwrBIKVSifPnz6NPnz5sR7EIn3unmEQKHpZVVlayHaFOpuQi2Zlnai6+57dXfO4FIJ7t7NmzeOmll3i5+CvpdfwfUvCwzNIxGCNHjs
SFCxfqIdH/xmT8G3Oy12fexzWE7AD/89srMjDUPvG5XVNSUkiv4/9HCh6WqdVqi9534MABdOvWDTRNY9OmTejfvz8iIiIwefJk3L592ya5zMluyJuVlYW3334bUVFR6NChQ52vPXr0KGJiYhAWFoYhQ4bg4sWLJn+OqbksyX706FEMHz4cERER6N27NxYtWoQHDx4AAGpqavDpp59i0KBBCA8Px9ixY3H27FmzcpuTy5L8AJCfn49Zs2YhPDwcUVFRWLdu3VOvv3v3Lrp06YKFCxea/BmW5LI35eXlyMrKQvfu3dmOQjCMz7eE+JydaaTgYZFGo7F60Onx48dx4MABfPfdd0hOTkanTp0QGxtrdTaapqHRaAAAVVVVT/3e0uxCoRCDBw/GZ599VufvU1JSsH79eixbtgznzp3Dd999B19fX4uzP3z4EDU1NYxk79y5M77//nukpqbi6NGj0Gq12Lx5M4BHY2q8vLwQFxeH1NRUzJkzBx9++CEKCwstzq7Vauuc1M/S/BqNBm+99RbCwsJw8uRJJCYmYujQoU+9bvny5QgJCTF7+0Dt/A1NUlISevXqBUdHR7ajEAwqLi5GQUEBunbtynYUi5DbrP9DCh4WqdVqi0fNDx48GKmpqSgsLETnzp3h5+cHBwcHDBs2DNnZ2VZnoyjKeLUeGBiILl264NSpUxZnN+QNCAjA6NGjn/l45FdffYV33nkHnTp1gkAgQLNmzdCsWTOLs0+fPh0+Pj7Ytm2bsfCxNLuXlxfc3NyMP3dwcEBeXh4AQCKRYObMmfDx8YFAIEDv3r3h4+ODjIwMi7MfPnwYnp6emD17dq0lHCzNf+DAATRt2hQTJ06ERCKBk5MT2rRpU+u1R48ehVQqtXjRy8fzNzR8vJLWaDRQKBSQyWSoqqqCTCaDQqFosEVrXRITE9GvXz9eLidBeh1rIwUPi/R6vdXbGDJkCPLz83Hnzh1oNBocOnQIPXr0YCRbXl4eMjIy8PDhQ6SlpeHll19G+/bt8csvvzCS/Uk6nQ7Xr19HRUUFXn75ZfTv3x/Lly83+wT6ePbi4mKUl5dj3rx5aNasGZYsWfJUj4850tLSEBERgfDwcCQmJmLChAl1vq68vBx3795Fq1atzNo+TdPG7NnZ2RAIBNixYwdatmyJ119/Hbm5uRZ/91euXIG3tzfeeecdREVFYfLkycjMzDT+/sGDB9iyZQvmz59v0fYN6mPf4DrDxG5cv5LWarUoKytDTk4OMjIykJmZifz8fBQWFqKoqAiFhYXIz89HZmYmMjIykJOTg7Kysgb99B0f2vVZSK9jbUK2AzRkTMyh4unpidDQUAwfPhwODg7w8vLCN998Y/V2lUolvvnmGyQnJxuv9lQqFTIyMjBhwgQUFxdb/RlPun//PrRaLRISErBr1y4IhUK8++672L59O959912Tt6NSqbBjxw4kJycjPz/f+DOVSoX/+7//Q0REBPz9/S3KGBoaitTUVJSWluK3336Dt7f3U6/RaDRYuHAhYmJiEBgYaNb2q6ursW3bNiQnJ0Mmk6GmpsZYQOzbtw8KhQK7d++2KHtpaSkuXLiATZs2oXv37tizZw/effddHD58GCKRCP/9738xatQoeHl5WbR9A67MDWRLt27dAoCnesy4QqlUory83Dio/PE2elZ70TQNpVIJlUqFe/fuQSqVwsPDAxKJxCaZucCwLtqzbsFzHZ+LtfpAenhYxMQkUFu3bsW1a9eQkJCAf/75B++88w6mTZsGlUpl1XZfeOEFrFmzBhkZGXBycoJEIoGbmxvWr1+PqqoqCATM7zpOTk4AgNdffx2enp5wc3PDm2++afbg38aNGxuz9+3bFw4ODhCLxZg2bRoKCgoQERFhddZmzZqhR48eWLBgQa2f6/V6xMbGQiQSWTSWSiwWG7N/+eWXcHBwgEQiQWRkJP766y8cPnzY4v3GyckJnTt3RlRUFEQiESZNmgSZTIacnBzcvHkT586dw5
tvvmnRth/XECc3M9zO4tp/u1arRV5eHnJzcyGXy0HTtNkFqeE9crkcubm5yMvLazA9PlevXoVUKjX7woULDMXa4MGD2Y7CGaSHh0VMFA23bt1CdHS08ap85MiR+OKLL5CTk4P27dszkq9///4YMGAA3nnnHYjFYgCol3v8Li4uaNasWa2ThqUnEEP2iIgINGvWDEuXLoWPjw8AMPbotE6nM/YgAY8OMJ9++inu37+Pr776yuJViQ3ZAwIC0LdvXyxZsgSRkZFP/d5cwcHBSE9Pr/N3Fy5cQFFRkfFqUKlUQq/X45VXXsHPP/9sUf6GJD4+HuPHj2c7Ri1yudy4sCxTvW6GqQcyMzPh6+vLy3lpzMHnHhJDr2NwcDDLSbiDFDwsEovFVh+IQkJCEB8fj+joaLi7u+PIkSPQarXw8/Ozars0TRuLmz/++OOp31uanaZp1NTUGAum6upqUBRlvMc8cuRI/PDDD+jRoweEQiF2796NXr16WZz9448/Ziz7H3/8gS5duqB58+YoKirCpk2bag3uXbZsGXJzc7Fjxw7j55vr8ezh4eE4fvw4Y/mHDRtmfMosLCwMe/fuhaurKwIDA+Hv748hQ4YYX/vdd9+hqKgIixcvtjh/Q1FTU4MzZ84gLi6O7SgAHrVBSUkJKioq6uX2oqGAys/Ph7u7O7y8vDjXs8WUhIQEzJgxg+0YFuFqryObSMHDIpFIBIqirDooTZkyBffv38e4ceOgUqng7++PdevWWX3lRVHUc3soLM1eVFSE6Oho47+7du0Kb29v44n97bffRlVVFYYPHw5HR0cMHjwYb731Fiey5+TkYP369VAoFJBKpYiKisL7779v/O/65Zdf4OjoWGv6+U8//RTDhg1jLLs1+QMCArBy5UosW7YMFRUVaNu2LTZv3gyRSASRSIRGjRoZXyuRSODo6Gj2ZGWm5Lc3586dQ1BQEDw8PNiOApqmUVhYCJlMVu9jqWiaRkVFBXQ6HXx8fOzuxKpSqZCSkmJ2DydXcLHXkW0U3RBHGHJITk4OJ1eZlkgk/3rfmmRnninZAf7ntyeLFy+GXq/HihUr2I6C4uLieuvZeRaKouDu7o7mzZvb7DNtISEhAUuXLsVff/3FdhSz1dTUwNPTE9nZ2ZwoxLmi4d1s5xipVMq5KyOKoiCVSv/1dSQ7s0zNDvA/vz3hyvw7crnc5sUO8L+enromyeQzrrSrJbjU68glpOBh2eMT2XGJKblIduaZmovv+e1FRUUFbty4wciTf9bQarXGAcpsoGkaBQUFdvX0Fp8HLPN57a/6RAoelgmFQs5dFUulUgiF/z68i2RnlqnZAf7ntxdJSUmIiooyTqnAlqKiItbnP6JpGkVFRaxmYEppaSnu3r2LsLAwtqNYhM+9U/WJFDwc4OHhwZnbExRFmdUNSrIzw9zsAP/z2wMunFiUSiUUCgUnCh6FQsHJsWXmSkxMRN++fXlZwHOl15GLSMHDARKJhBNjMgxjMMyZSZVkt54l2QHu5K+pqYFCoWhQM/AC3FlOory8nPVix4CmaZSXl7Mdw2pcaFdLnThxghO9jlxECh6OaN68OesnLoqijJPzmcPb25tkt4Kl2QFu5BcKhZg5cyZiYmJw9+5dVrPY0u3bt6HVatG2bVvWMmi1WsYm0mSKQqHg9VgewwzFbPfcWYrPxVp9IwUPC/Lz8/HNN99gxowZCA0NhaOjIzp06ABfX1/WTl4URcHX19eiFYGFQiHJbiFrsgPcyN+yZUvjZIZdunTBqlWrrFqglS8MJxY2C87KykqTXpebm4uxY8ciPDwce/furedUpufiouvXr0MsFpu98C8XGHod+Vqs1TdS8LDg66+/xjvvvINt27bh0qVL0Ol0WLlyJZydneHu7m7zA6hhHg1rJisk2c3HRHaAG/mdnJywePFi/P333zhz5gw6d+6M06dP2zSPrXHhxGLq2J24uDh069YN58+fR1BQEKZMmYKIiIhnrrO0Z88eRE
dHIywsDDExMbhz547JmQxjefiKC+1qKS70OnIZKXhYEBsbW+vx3ZCQEIwcORIA4OXlBRcXF5udvCiKgqurq9UrZAMkuzmYzA5wJ39gYCCOHDmC//u//8P48eMxceJE3Lt3zyaZbEmj0eD06dMYMGAAqznUarVJrysqKkLr1q0BAI0aNcKoUaMwb968Ol/722+/Yf/+/diyZQvOnz+PLVu2mD3dgKm5uMgebmexfZubq0jBY2NVVVWYNGkS3N3dIRaLIRaLsWHDBuMOahjPYYsrdsMVOlPjQEh20z+LyeyGbXIlP0VRGDNmDDIyMuDh4YGQkBB8/fXX0Ov19ZrLls6fP4/AwEB4enqylkGj0ZjUuzN16lRcuHABK1asQFhYGKRSKYYPHw5fX9+nXqvX67F161YsWLAArVq1AkVR8PPzg4uLi1nZaJqulwWG65tarUZycjL69evHdhSL8LlYswVS8NjQ+fPnERoaiqZNm+Ly5cvYunUr+vbti759+9Z6HUVRaN68Ofz8/CAQCBg/gVEUBYFAAD8/P8YHS5Psz99+fWU3bJ9L+aVSKb788kskJiZi165diIyMxKVLlxjNxBYu3PZQq9UmtfHOnTsRGhqK2NhY/P3332jZsuUzX1taWorS0lLcvn0bAwYMQHR0NLZs2WJ2sUpRFC97ef766y+EhITA1dWV7Shm02g0OHXqFPr37892FM4iBY8N6PV6rFmzBsOHD8fatWvx3//+F2KxGJMmTcKff/75zPc5OzsjODiY0UePDY9ABwcHWz125HlI9tpslR3gXv6OHTsiOTkZ06dPR3R0NN5//33eL0PAhSvp+ugxKy0tBQCkpKRg//792LlzJ44ePYr9+/ebvS0+9uhxoV0tZeh1bNq0KdtROIsUPPXs3r17GDp0KH7//XdcuHABo0ePNuv9QqEQ/v7+CAgIgLOzMyiKMvskZniPs7MzAgIC4O/vb5MJtUh2drID3MsvEAgwdepUXL9+HQqFAu3atcPPP//MmfljzFFZWYlr166hR48erOaoj+/OMHfL5MmT4ezsDB8fH4wbNw5nz541e1t8bFs+P9LN52LNVkjBU4+SkpLQuXNn4xMrLVq0sHhbEokE/v7+aNOmDZo2bQaOZlwAACAASURBVAqJRGK8bSEQCJ76n+HnEokETZs2RZs2beDv78/K5HCPZ799+zbEYjEvs8vlcuh0Ot5kfzx/kyZNkJ+fz/p+4+HhgZ07d+LHH3/EsmXLEB0djdu3bzOybVs5efIkevbsyfrEbvUxVqtly5YQiUSMjYvjk7KyMuTk5CA8PJztKBbhc7FmK/ybN5sHtFotli5dim+//Ra7du1idCcUCoXw9PQ0DpbUaDRQq9XQ6/Wgadp4whKLxRCJRIx9LhN27dqFadOmYf/+/Rg1ahSvsufk5CAqKgp9+vRBYmIir7JXV1ejR48eKCoqgkqlAsD+ftOzZ0+kpaVh48aN6N69O+bMmYOPPvoIYrG43j/bWlwYvwM86jWzhF6vh0ajgVarBU3TqK6uhkAggEgkQqNGjRAdHY24uDi0bdsWCoUCv/76KyZPnmyzfGxJTExEnz59OPf3a4qqqipcv36d9V5HriMFD8Py8/Px2muvoXHjxrh06RKaNWtWr58nEol48Qd69OhRzJo1CwCQmZkJgD/ZS0pK0KtXL+h0OuPiiHzJrtfrMW7cOOMMyFVVVXB1deVEfpFIhA8//BCvvvoq3nvvPXTo0AFbtmzhRDHxPPHx8Zg9ezbbMSAWiy26bXTx4kVMmTLF+O+uXbuia9euiIuLA/Bo2ozPPvsM/fr1g1QqxZgxYzBq1CizPoOmaV4Ur4/jcw/JiRMnEBkZybvv3NYomo83WjnqwIEDePvttzF37lwsWLCAd1c49eXixYvo1auXcVHBIUOGPHewNpcoFAp07doV2dnZ0Ol0cHJyglKp5E3bzp49G3FxcVAqlZBKpThw4ABnH7k9cuQI5syZg27dumH9+v
Xw9vZmO9JTsrOzERUVhcLCQk7cssnIyODk4GCBQIB27dqxHcNkNE3Dz88PJ0+eRFBQENtxzDZjxgwEBQU9c24l4hF+HLU5Tq1W491338XcuXNx4MABLFy4kDcnRFtYsWIFtFqt8QSRlpbGciLTHTx4EFlZWbWWfjD0UHFdaWkpduzYYewFUKlU+Oeff1hO9WxDhw7FtWvX0Lp1a3Ts2BEbN27k3JpMXJvYjatX9FzN9Sw3btyAUCg0Ts7IN3zunbIluzwrazQaKBQKyGQyVFVVQSaTQaFQ1MtEWJmZmYiIiEBRUREuXbqEiIgIxj+D73799VecPHkSIpEI3bt3B0VRnLwqrcv48eNRWFiIwMBAREZGQiqV8mbm4GbNmqGiogKvv/462rdvD29vb85nl0gkWL58Oc6ePYuDBw8iLCwM58+ft2mG5x0/uDJ+x4DJqQeYYpjCgGtMaVeufZcGz8uenZ0NpVKJkJAQtmNynl3c0tJqtaisrIRCoYBarTYOwnyS4edisRhSqRRubm5WPSa8e/duzJs3D8uWLcPbb7/N2T8WLti/fz+2b9+OY8eOsR3FbA8ePICXlxdKS0vRuHFjtuOYbdCgQZg5c6Zx+RK+oGkaP/zwA+bPn4+YmBisXLnS7CUOTGHO8UOj0UAikcDNzc3q44e53njjDZSVlcHFxQXOzs4oLy/H7Nmz4e3tzalHwLVaLZo0aQI/Pz/Wc5jarjqdDgDQvHlzm7drXczJDjy69d6qVStOZOcyXhc8SqUS5eXlxoXqzPlPMew8UqkUHh4eZj12++DBA8yaNQt///03fvrpJ3Ts2NG84A0Qn+8xHzlyBF9++SVOnDjBdhSzqVQqNG3aFAUFBWYvD8AVVVVVWLRoEfbv34/Vq1djwoQJjFxcsHX8sFT37t2f6u365ZdfEBYWxqmJHG/fvo0pU6ZgyJAhmDt3Lrp27WrTz+dbuz6Oz9n5gJe3tLRaLfLy8pCbmwu5XA6aps2+wjG8Ry6XIzc3F3l5eSaNF0hPT0eXLl0gFArxzz//kGLHRHy+x8zn7MnJyejYsSNvix0AcHV1xZYtW3Do0CFs2rQJffv2RUZGhsXbY/P4YamsrCx4eHgY/y0Wi7Fjxw6MHTsWHh4enOldpigKgwYNQk5ODl566SWMHj0aUVFR2L9/v7EXpb7wsV0N+JydT3hX8MjlcmRmZkKhUDDWjUvTNBQKBTIzM595pUTTNP773/9i4MCBWLJkCXbu3MnL2xts4Ps9Zq6N2zAHn7M/qVu3bjh//jzGjh2L3r17Y+HChXj48KHx94YTxfOwdfywdLsnT55ETEwMevTogU6dOsHLywuOjo6YOnUqpk2bBuDRuCcujOUxjN2RSCRwdXXF/PnzkZOTgzlz5mDNmjUICgrChg0b6qU3ik/t+iQ+Z+cb3hQ8NE2juLgY+fn5xsnSmN6+Xq9Hfn4+iouLa22/oqICo0ePRlxcHFJTU/H6668z+tn2zjDlOdsHZEvk5+ejvLwcnTt3ZjuKRextunkHBwfMnj0bV65cQV5eHtq3b49Dhw4ZxzCsWrWqzvexefwwV3V1NXbt2oXQ0FDMnDkTQ4cOxZ07d7B8+XKsXr0aAwYMwMaNG2u951kr19sSRVHw8fGp9TOhUIhXXnkFqamp+OGHH5CamoqAgADMmzcPd+7csfoz+dSudW2br9n5ihdjeGiaRmFhIWQymU0ajaIouLi4wMfHBykpKXj99dcxZswYrFy5kvXp5Plo9OjRGD16NMaPH892FLN9++23iI+Px48//sh2FLOVlpbixRdfRFlZmd0OZExMTMSsWbNQU1ODwsJCiEQiXL9+vdaK4GweP8wpQsrKyrBt2zZ89dVX6NChA+bNm4dBgwaZPMWFXC5Hfn4+Kyc2iqLg5+dn0sKyeXl52Lx5M7799lv07dsX8+bNQ0REhNkFG1/atS58zs5nvOjhKSkpsdmOATzaGWUyGQ4ePIgxY8Zgy5YtWLduHSl2LK
DVanHy5EkMGDCA7SgW4XMPSWJiIvr27Wu3xQ4ADBgwAHv37kVBQYFxuYyJEyfWOlawdfwoKSkx6fXXr1/H9OnTERwcjLt37yIhIQHx8fGIjo42az4vZ2dnuLu72/yERlEU3N3dTSp2AMDf3x9r1qzBnTt30Lt3b7z55pvo3r079u3bV2vqkIcPH+Lvv/9+5na43q7Pw+fsfMb5gkcul6OiosLmVy00TcPf3x+pqakYNmyYTT/bnly4cAH+/v7w8vJiO4rZ9Ho9EhMTeTtgmc+Drc0RGxtrXAdMr9fjzJkz+PLLLwGwe/yoqKh45vgJvV6PY8eOYfDgwRgwYAD8/Pxw69YtfPPNN1aNdfPy8oKLi4vNih6KouDq6mrR37dUKsWcOXNw69YtxMbG4uuvv0ZgYCBWr16NyspKrF27FpGRkThz5sxT7+Vqu5qCz9n5jtOXflqtFgUFBazde3R0dIRKpYJWq7Xrq+T6xOdBs5cuXYKHhwfr84lYgqZpJCQk4NNPP2U7Sr3bunUrrl27hnv37qG0tBRnzpzBgwcPWD9+0DSNgoICBAcHG48fSqUSu3fvxsaNGyESiTB37lwcOnSIsd5jwzgaBweHej+pGnp2vLy8rCqwHBwcMGLECIwYMQJpaWnYsGEDAgICoFKpoNPpMHz4cKSnpyMgIAAA++eFutrVVHzObg84PYYnLy+P0ZHrljA8eeDv789aBj7r2bMnlixZwsuehlWrVqG4uPipAaJ8cO3aNYwYMQLZ2dlsR2ENl44fIpEIW7Zswfbt2xEeHo65c+eib9++9doTI5fLjSdXJr8DiqJAURR8fX1Nvo1lro0bN2L+/PnGW1xeXl64desWnJ2dOdWu5p4X+JzdHnD2lpZSqWR9xwD+93ifYeFLwnQymQyXL19Gz5492Y5iET7fEuJzdiZw6fhx//59jB49GpWVlUhOTsbhw4fRr1+/er/t5OzsjODgYEYfWTecLIODg+ut2AGAPXv2QCAQQCqVwtHRESUlJZgwYQKn2tXc8wKfs9sLzvZplZeXs75jGNA0jfLy8gZZEVvj1KlTiIiIQKNGjdiOYraHDx/iwoUL6NOnD9tRLJKQkIDp06ezHYM1XDp+ODg4YPfu3awsTCkUCuHv78+7GXx37dqFqqoqvPDCC8b/NWnSBIWFhZxpV3PPC1zaJxvqOY2TBY9WqzX+YXKFQqEgY3nMxOdehjNnziA0NBQvvPAC21HMplarkZycjH379rEdhRVcO34IBAJUV1ezevyQSCTw9/dnbd1Bc7Vr1+6pn3GtXQHTzwt8zm5POHlLq7Ky0uL3jhw5EhcuXGAwzf9Yk6sh4vOAZT5n/+uvvxASEgJXV1e2o7DCmr/T3NxcjB07FuHh4di7dy+Dqbhx/BAKhfD09ERgYCDatWuH4OBg+Pn5wcfHB97e3vDx8YGfnx+Cg4PRrl07BAYGwtPTkxMnRVO/v/psw7qYksuctrdlfi7sk7bEyYLHmvucBw4cQLdu3Yz/XrRoESOZDPc9CdPk5uZCoVCgQ4cObEexCJ/n3+FzdiZYc/yIi4szLmHxxhtvoLCwEF999ZXVmbh6/BCJRJBKpXBxcYGrqytcXFyMg6y5xtR2fbwNdTodoqOj0b17d/Tr1w+rV682ri91//59LFiwAP369UNERAQmTJiAK1eumJXJ1HY1Z598ch/MyMjAxIkTERYWht69e2PPnj1PvefChQvo0KEDNm3axHh2e8LJgketVlv1/uLiYqxevRoqlQoAkJmZibVr17KeqyFJSEjAgAEDzJo4jSsKCwtRXFyM0NBQtqNYhM+3Eplgzd9pUVERWrdujcuXL2P79u3GBS//+ecfbN++nbVchOnfn6ENAaBv3774+eefce7cOfz+++/IzMw09poolUq0b98eP/30E5KTkxETE4NZs2aZPZjXlFzmtP3j+SsrKzFjxgyMGzcOycnJ+PPPPxEZGVnr9RqNBqtXr7ZoIeuGtk9y7myk0W
isGtg1ePBg3LlzBwMHDkRsbCwuXLiAn3/+GVOmTLE6G03TtWYCJZ6Nz70MiYmJ6N+/PxwcHNiOYraysjLk5OQgPDyc7SissOb4MXXqVFy4cAErVqzA9OnTIZFIsGzZMhw9ehTJyclWL41Cjh+WM7VdH2/DsLAw6HQ649NkhnFJ+fn5AAA/Pz9MnDgRnp6ecHBwwLhx46DRaJCbm2tWtsfbddmyZTh48GCtrObsk0/mX7lyJSIjIzFs2DA4OjqicePGCAwMrPWeXbt2ITIystZyKpZkbwg4V/Co1WrGH9cUCASMbJOiqAZXEVtCp9MhKSmJt8tJ8LmHJDExEb179+bkLQlbsOb4sXPnToSGhiI2NhZ///13rYUwmeipJMcPy5nark+2YcuWLXHkyBF0794dUVFRyMzMxLhx4+p8782bN6HRaMx+cunxdt2yZQv+85//IDg42Fj4mLNPPpn//v37cHFxwfjx49G7d2/Mnj0bxcXFxtcXFRXhwIEDeOedd8zKXFf2hoBzBY9er7d6G2VlZUhISMCKFSvQrVs3jBkzBt9++y0D6ZjJZ+/++ecf4yBIvtHr9UhISOBtwcPnwdZMYOrv8/Lly8jMzMQnn3yCIUOGICIigpEBpOT4YRlrvrehQ4fi3Llz+OOPPzBu3Dg0adLkqdc8ePAAH3/8MWbMmAGpVGrW9pVKJb744guMGTMGVVVVUKvVuH37NkaPHg1XV1erlnIoLS3FoUOHsHDhQsTHx8PHxwcLFiww/n7lypWYPXu2VVMFNKR9kv2h909gYp4CT09PxMTEGP/dpk0btGnTxurtAszks3d8LhiuXLkCV1dXi7qH2WZYTuLjjz9mOwprmPr77NSpEzp16oTCwkIAQLdu3Wo9DGEpcvywDBPfW4sWLdC6dWt8/vnn2LBhg/HnarUas2fPRqdOnTBt2jSztysSiRAVFYWXXnoJSUlJqK6uhpOTEwQCAcaMGYNGjRpBJpNZlNnJyQn9+vUzrq82Y8YMREVFQaFQ4OLFi1AqlYiOjrZo2wYNaZ/kXMHD9O2s5cuXM7o9W69EzEfx8fGIjY1lO4ZF+NxDcuPGDQiFQgQFBbEdhTVM/336+Phg5syZjG2PHD8sw9T3ptVqjWN4AKCmpgbvvfcemjVrZvG6cyKRCOHh4XBxccH7778PvV6P2NhYzJkzB40bN7a42AGA4ODgWv/tj///8+fP4/r168bJUR88eACBQICsrCxs3rzZ5M9oSPsk5woerj/Vw/V8bFMoFLh06RJ69erFdhSLJCQkYM6cOWzHsIhhoHhDOoA9iet/n1zPx1WWfm+//fYb+vTpgyZNmiA7Oxs7d+40PuWk0Wgwb948ODk5Yfny5Va1jeG9f/75JwIDA9G4cWOrswOP5pWbO3cu3njjDbRq1Qrbtm1DaGgopFIpZs+ejalTpxpfu2rVKnh6epo9nqch7ZOcK3jEYjFnu9homoZYLGY7BqedOnUKYWFhNpl+nmlKpRLnzp3Db7/9xnYUi8THx2PSpElsx2AVOX7YJ0vb9dKlS9i0aRNUKhXc3NwwaNAgzJ49GwCQnp6O06dPQywW13rUe+vWrejSpYvJn/F4u9Y175g1+2R4eDjee+89zJo1CyqVCqGhoVi9ejUAoHHjxrUKKycnJzRq1AguLi4WZW8IOLlaekZGBicHUgkEgjqnPCf+Z86cOfD19cVHH33EdhSzHT9+HJ9//jnOnj3LdhSzVVdXw9PTE3fu3IG7uzvbcVhFjh/2ic/tyufs9oSTfVlcrTi5motL+Dz/Dp+zp6amom3btg2+2AG4+3fK1Vx8wdXvz5RcfM5uTzhZ8EilUs6NQ6AoyuzHFRuau3fvoqKiAp06dWI7ikX4PP8On7MzjRw/TKfRaKBQKCCTyVBVVQWZTAaFQsHJyej43K58zm5PODeGBwDc3Nxw7949tmM8xc3Nje0InMbn5SSKi4tRUFCArl27sh3FIgkJCfjyyy/ZjsEJXDx+6H
Q6Toxr48tq6StXrkRWVhbc3Nzg4uICuVyOAQMGwM/Pz2YZTGXKeYGL+yTQ8M5pnCx4hEIhpFKpVRM2MUmv18PFxYUTKwZzWUJCAoYMGcJ2DIskJiaiX79+vGzj+/fvIzMzE927d2c7Cidw7fhB0zQuX76MQYMGYcaMGZgxYwaaNm1q0wxKpRLl5eXGxSIfH7r5rGGcNE1DqVRCpVLh3r17kEql8PDwsEnhdu7cORw6dKjWzwQCAWbPns2ZdgUe9dyYcszg2j4JmJ7dnnD2UtzDw4MzXYAajQbvv/8+/vjjD84+AcI2nU6HxMRE3t5W4fMtoaSkJPTq1QuOjo5sR+EMLh0/DBPQnThxAoWFhWjTpg2mTp2Kq1ev1vtna7Va5OXlITc3F3K5HDRNm30MM7xHLpcjNzcXeXl5xhXH64NCoUDr1q2N7efo6IjJkyfjiy++4FS7UhQFDw8Pk1/P5+z2grMFj0Qi4cR9T8OOMWnSJCxcuBA9e/bE6dOnWc3ERZcuXYKXl1et9Yf4wjBDMV8HLPO5WKsvXDp+SKVSSCQStGvXDtu3b0dmZiYCAgIwePBgDBw4EH/++We9PMEjl8uRmZkJhULB2IUaTdNQKBTIzMxkvLfi7t27+PDDD9GyZUvk5+cjMDAQDg4OCA0Nxddffw2Am+1qKj5ntxecLXgAwNvbmxM7h6+vL0aMGIHLly9jxowZmDx5MqKjo3Hx4kVWs3EJn2covnr1KqRSKQICAtiOYjaapnn93dcnrhw/nrwI8PT0xOLFi5Gbm4sJEyZg0aJFaNeuHbZt2walUmn1Z9I0jeLiYuTn50Ov1zPeK03TNPR6PfLz81FcXGz19lNTU/HKK68gNDQUAJCWloaff/4Z69evR4sWLXDkyJFai+FytV1Nwefs9oDTBY9QKISvry9rO4ih2HFwcAAAODg4YPz48bh58yZGjBiBmJgYjB07Fjdu3GAlH5fwuZeBz9kzMzMBgLG14uwJ144fT3JycsKbb76JtLQ0bNu2DUePHkWLFi0QGxtrXMPLXDRNo7CwEBUVFfV++52maVRUVKCwsNDsz9Jqtfjpp5/QvXt3jB8/Hj179sSdO3ewdu1atGjRAgAwfPhw3L59+6mpFrjers/D5+z2gNMFDwA4OzvD3d3d5jsIRVFwd3eHs7PzU79zdHTEjBkzkJWVhbCwMPTu3RuTJ0/G3bt3bZqRKx48eICLFy+id+/ebEexiD3czmL7qpFLampqcOLECUyYMAEvvvgi544fdb22T58+OHjwIFJSUqBQKNChQweMHz/+qV7kBw8e1FoL6kklJSWQyWQ2G2tI0zRkMhlKSkpMen1VVRXWrFmDwMBAfPXVV1i4cCEyMzPx7rvv1vmI9LPajYvnBVPxOTvfcb7gAQAvLy+4uLjYbAehKAqurq7w8vJ67uskEgkWLFiArKws+Pr6IjQ0FO+++y5KS0ttkpMrTp8+ja5du9aa5pwvVCoVUlJS0LdvX7ajWITPxRrTSktLMXDgQLi4uCAmJgZ79uyBWCzm7PGjLkFBQdi8eTOys7PRqVMnjBo1Cr169cLvv/8OnU6H+fPno0OHDnUWPXK53CY9O08y9PQ8b0zP7du3MWfOHAQGBuLKlSv4/fffcfr0aYwcOdLi3gY+teuT+Jydz3hR8BjuOdqiKjZUwebca3VxccGyZctw48YN41TdixYtQlVVVb1m5Qo+jyFJTk5Gx44dzVp/hitqampw+vRp9O/fn+0onODk5ISMjAxUV1fj4cOHEIvFWLZsGeePH3Vxc3PD/PnzkZ2djVmzZmH16tVo3bo1du7cCYVCgQEDBuDhw4fG12u1WhQUFLD2FClN0ygoKKj19BZN0zh16hRGjBiByMhIODs749q1a9i9e7dZa1U9Cx/b9fHt8TU7n/Gi4AEeNVrz5s3h5+cHgUDAeMNRFAWBQAA/Pz80b97cou03bdoUGzZswKVLl1BSUoKgoCCsWrWq1oHJHvG5l4HP2c
+dO4egoKAG+XhpXVxdXbF48WLjvwUCAUaPHg2AH8ePuohEIrz66qs4d+4chg4dCr1eD71ej+zsbAwbNsz4dFdRURHrU2bQNI2ioiJUV1dj165dCA0NxYwZM/Dyyy/jzp07WL58Oby9vRn9TL62q2HbfM3OV7wpeAycnZ0RHBzM6ON9hsf0goODGbm/6e/vj507d+Ls2bNIS0tDUFAQtmzZgpqaGgbScktBQQHu3buHzp07sx3FInwesMznYq0+7NixA8uWLcNPP/0EiUSCsWPHolGjRrVew4fjx7McOHAAIpEIUqkUAoEAp06dwuLFi6FUKhl99NxShltbgwcPxt69e7FixQpcv34db7/9dr0/As3nduVzdr7h5Grppnre7KH/xrBj2WL20LS0NCxatAg3b97EZ599hjfeeMNuRsnHxcXh+PHj+PHHH9mOYrbS0lK8+OKLKCsr4+WMo+Hh4Vi1ahVvxx8xhaZprFq1Cjt27MDx48cRFBSErKwsuLq6wtPT85nv48vxw+DYsWNQqVSQSqV44YUXIBQK0b59e5SVlXFmBl+9Xg+BQICQkBDWMvCtXR/H5+x8wOuCx4Av68OcOXMGsbGxqKysxOeff46RI0fyvpvxtddew8CBAzFlyhS2o5ht7969+O2337B//362o5itoqICLVu2RFlZGZycnNiOwxq9Xo8PP/wQCQkJOH78uEW3TPhy/KiLVqvFrVu3WO/deRxFUWjTpg0nvhs+tytfs3OZXRQ8T9JoNFCr1cZJtwz3MsVica0JrNhA0zSOHj2K2NhYODo6YsWKFRgwYACrmSyl1+vRrFkzpKWlcXJRv38zceJEdO/eHTNmzGA7itl+/fVXfPvtt/jzzz/ZjsIajUaDadOm4fbt2/jjjz8YWwiRy8ePJ5WVleHevXtPFTyDBw/G0qVLERER8dR7fvrpJ3z11VdQqVSIj4+Hq6sro5koikLTpk2f27vGBj6165P4nJ1L7LIUFIlEnN0JKIrCyy+/jOjoaPzyyy+YOXMm/Pz8sHz5ct4t/pieng4PDw9eFjuG5SQ+/fRTtqNYpKGP31GpVHjllVeg0+mQkJDAaPc9l48fTzJ37I5Go8GaNWuwd+9e42SVmzdvxokTJ5Cbm4u33noLM2fOrPWeiooKrFq1CmfPngVFUYiKisLq1auf+RmG5Se4VvDwqV2fxOfsXMK7Qcv2QiAQ4NVXX8X169fx2muvYdy4cRgxYoRNFhRkCp8H/F6/fh1isRitWrViO4rZaJrG8ePHefvdW6uqqgqDBw+Gs7MzDh482KDHKqjVarNef//+fVRXV9fa7/39/TFv3jxERUXV+Z65c+fCw8MD8fHxOH36NCZNmsR4LoKwBVLwsEwkEmHatGnIyspCnz59MGDAAIwfPx7Z2dlsR4NGo4FCoYBMJkNVVRVkMhkUCgU0Gg0A7s+/87z8fM5++/ZtaDQatGvXju2Ydfq3/cYaJSUl6NOnDzp37ozdu3c36KtejUbz3N6d69evG+fAWbx4MW7duoWYmBgAQGRkJKZOnQoAGDFiBKKiouqcODQlJQUlJSX44IMPIJVKIRKJ0LZt23/NRtM0I+1NEEyyyzE8fCaXy7FhwwZs2rQJ48aNwyeffML43BXPYu5AufT0dPTq1QteXl6cGChnTn6VSgWaptGyZUtODPQzJ7tWq0VJSQlCQ0N5l92aAZY5OTkYOHAgJk+ejEWLFvF+wL+1FAqFcYHQJw0ePBgSiQRbt25Fo0aNMHv2bHTr1g1jxoxBdHQ0Ll269NR3v3DhQvj7+9e6pbV161akp6fDzc0NycnJ8PX1xQcffIBu3bo9N5th/pe6losgCLaQHh6OcXZ2xqeffoqbN2+icePGCAkJwUcffYT79+/X22cqlUrk5eXh1q1buHfvHpRKpXFwnGGis8f/Z/h5x44dIZPJcOvWLeTl5TGy0rOt8ovFYjRq1Aj37t1jNb8l2R0cHODj48PL7Hq9Hkql0uzsV6
5cQVRUFD788EMsXry4wRc7AOosdB732muvGZcwmD59Oo4ePWr2Z5SWliIlJQVhYWE4efIk7Va0oAAAD41JREFUJk6ciPfeew+VlZVW5yMIWyMFD0d5eHhg7dq1uHLlCmQyGdq0aYPPP//cOD8DE7RaLfLy8pCbmwu5XA6aps1+vNXwHrlcjtzcXOTl5dWaXr4+8Tk/yW569uTkZAwcOBDr16/n5RN19eXfvvPH103y9vZGWVmZ2Z8hFovh4+OD0aNHQyQSYciQIWjWrBkuXbpkdT6CsDVS8HCcr68vtm3bhnPnzuHGjRsICgrChg0brB4UKJfLkZmZyegMrYanMzIzM+t9IjQ+5yfZa3te9iNHjmD06NHYs2cPXnnlFUY+z178Wy/X4yuYFxcXW/TUVHBwsNmfa+7rCMJWSMHDE61bt8bevXsRHx+PpKQkBAcHY+fOnWZf0dM0jeLiYuO9f6avwgy3LfLz81FcXFwv2+drfpL9+dt/Mvvu3bsxdepU/PHHHw32ibTnEQief/jet28fSkpKIJPJsGPHDkRHR9f5Oo1Gg+rqauj1emi1WlRXV0On0wEA+vfvD7lcjoMHD0Kn0yE+Ph6lpaUmLSXzb/kIwtbIHskzHTt2xOHDh/Hjjz9i9+7dCAkJwS+//FLrfnl6enqdhRBN0ygsLERFRUW9dzcb1tUpLCxktCeAr/lJdtM/q6KiAvHx8Vi0aBFOnjyJsLCwev1MvhKLxc9tj6FDh+Ltt9/GkCFD4Ofnh7feeqvO1y1duhRdu3bF0aNHsWPHDnTt2hWHDx8GALi4uGDz5s347rvvEBERgZ07d2LTpk3/OsmjYZwcQXAJeUqLxwyT58XGxkKv12P58uVo27YtWrdujXnz5uGLL76o9fri4mKbnLQeR1EU3N3d0bx5c6u3xef8JLt5NBoNGjduXOctFeJ/MjIyODk4WCAQcHbaBKLhIgWPHaBpGvv378fixYtRXl6OiooKODk54ejRo+jduzeAR2Mv8vPzWRlISFEU/Pz8rFq1l8/5SXbLMLHf2LucnBzWno58HolEgsDAQLZjEEQt5JaWHaAoCmPGjMHPP/8MmUwGvV4PlUqFkSNHorKyElqtFgUFBaw9NUHTNAoKCix+gojP+Ul2y1m73zQEUqmUc4ODKYoi8+8QnEQKHjuyceNGAI/m8hGLxaiqqkL//v1RVFTE+iOiNE2jqKjIovfyOT/Jbh1r9puGgKkFU5nG1VxEw0ZuadmR8vJy5OTkQKPRQKPRQC6Xw83NDe7u7qyfuIBHV34BAQFmrX2kVCqRm5vLy/wkOzMs2W8akry8vHqfBsIczs7O8Pf3ZzsGQTyF/fUACMZ4eHjAw8Oj1s+4dDCkaRrl5eVmHQzLy8s5cdIFzM9PsjPDkv2mIfHw8GB0XiRrUBT11DGIILiC3NKyY1qtltGZmZmgUChMHpPB5/wkO7PM2W8aGolEwomxPIaxO6QnjuAqUvDYMVPWu2GDqbn4nJ9kZx5Xc3GBt7c3JwoeHx8fVjMQxPOQgseOcaWb+3GGZQRMwef8JDuzzNlvGiKhUAhfX1/Wih6KouDr6wsHBwdWPp8gTEEKHjtm7XpbdTl27BhiYmIQHh6OESNGICkpqd5yWZNfo9Fg3rx5GDx4MDp06IALFy489ZqMjAxMnDgRYWFh6N27N/bs2cNYLmu/++zsbLz66quIjIxEZGQkpk2bhuzsbOPv4+LiMGrUKISHhyM6OhpxcXEmbdcW2QFApVLh888/R1RUFCIiIjBx4sSnXqPRaBATE4P+/fubtM362J/tibOzM9zd3W1e9BgmuCTzJRFcRwYt2ymNRsP4VXppaSk+/vhjbNq0CT179sTZs2fxwQcf4NixY2jSpInJ26FpGhqNBiKRCDqdDlVVVU+9n4n8nTt3xvjx4/HBBx889bvKykrMmDED8+fPx6BBg6DRaFBaWmp2/nv37sHT07PWSYaJ7J6enli3bh28vb2h1+vx448/Yv
78+di/f78xw/LlyxEcHIz8/Hy8/fbb8PLywpAhQ0zOXlVVBYlEAkdHR0azA8Bnn30GnU6HgwcPwsXFBTdv3nzqNXFxcXBzc8PDhw9N2ubj2Ym6eXl5QafTQSaT2aSXjqIouLq61lqZnSC4ivTw2Cm1Wm3RlV6HDh2Ql5dn/PeiRYuwadMmAI8KHmdnZ0RFRYGiKPTq1QuNGjVCfn6+WZ9BUZTxav3EiRNo1qwZJkyYgDt37piV/3lZRSIRJkyYgNDQ0DoXMfz+++8RGRmJYcOGwdHREY0bNzZ5ZtjH8/v4+KBr1644efKk8QRj6nf/vPzOzs7w8fEBRVGgaRoCgaDW9zxlyhS0a9cOQqEQAQEB6Nu3Ly5dumRW9tdeew0+Pj7YunUrampqGMuek5ODU6dOYcmSJXB3d4eDgwPat29f6/0FBQX4448/MG3atH/9rLqyE3UzjKOxRU+PoWeHC+OHCMIUpOCxU/Wxvk779u0REBCAkydPQqfTISkpCSKRyOz1jmiaRk5ODtLT03Hz5k04Ojpi3759aNOmDYYNG4arV6/W+/pAV65cgYuLC8aPH4/evXtj9uzZKC4uNum9Op3OmF+r1SItLQ1Dhw5Fu3bt8NNPPzGaPTIyEl27dsXKlSufWRzQNI20tDS0bt36X7en1+uN2e/du4fy8nJ88MEHaNasGRYtWgSNRmN15mvXrqF58+bYsmULoqKiMGrUKCQkJNR6zcqVK/Hee++ZvcAkF9eN4hqKotC8eXP4+flBIBAwXoxQFAWBQAA/Pz80b96cFDsEb5BbWnaqPrqzHRwcEBMTg48++gg1NTUQiUT48ssvzX4MtaamBjt27EBKSgrkcjmqq6uh1+uh0+lw5MgRqFQq/Pbbb4znf1xpaSlu3LiB7du3IygoCOvWrcOCBQuwe/fuf32vWq1GXFwcUlJSjD9TqVS4efMmpk+fjrt37zKWMyUlBUqlEocOHYK3t3edr/nqq6+g1+sxcuTIf92eSqUyZr99+7bxZyqVCitXrkR0dLRZtyfrUlpaitu3b2PgwIE4ceIE0tPTMWvWLLRq1QqBgYFISkqCTqdD//796xxb9TxcG0zNZc7OzggODkZRURFjA9ENj557e3tDKCSnD4JfSA+PnaqPq67U1FSsW7cOcXFxSEtLQ1xcHJYsWVLn+IzncXJywsqVK5Geno7NmzdDKBRCIpEgMjISf/31F5KSkur9qtHJyQn9+vVDSEgInJycMGPGDKSnp5v0JFDjxo2N+Q3/9vLywrfffouKioo6b6FZQyKR4JVXXkFsbCzu379f63c//PADDh8+jC1bttQai2NK9i5duhi/+w8//BBlZWXo2LGj1XmdnJwgFArx1ltvQSQSoVu3bggLCzMWb+vWrcPHH39s0bZJb4J5hEIh/P39ERAQAGdnZ/y/9u7ftYk3gOP452KUM5BoSoqtxKAobdG6tLg4qQgqOLgIirg5V7opFXVyFx1ErDjooIuLFv0HVFREEH+hOIhUK9aCkeYkMfcdyp2t5tvcXdMm9+T92lpI8iE87fNJ7rnnsSwr9HvoPSaTyWjDhg0qFAqUHcQSo9ZQUSfdlStXqlQq+T9PTk5qzZo1kqS3b99qcHDQX4/R39+vrVu36tGjR+rr64uUb/369dq9e7dGRka0ffv2UPnny1pPT0/PnH/8YScBL9++fft08OBBHT161J8Egr73YfJXq1U5jqOvX7/638Dcvn1bo6OjunbtWqhFo16+nTt3atu2bTp58qT/nEFv/Z4ve61LnN77+/HjR42Pj/t3bZXLZf38+VM7duzQjRs36u7j0ugy2S5SqZQKhYIqlYqmpqZULBblOI5c16059r3f27atdDqtbDZLyUHsMYINZdt2pK+we3t7NTY2pk2bNunhw4d6+vSpNm/eLGlmDc/o6KjevHmjvr4+vX79Ws+ePdOhQ4dCvYbruv7ajS1btuju3buR8s+XVZq5dOY9R7lc1q9fv7RixQpZlq
UDBw5oeHhYR44c0caNG3Xp0iUNDAwEOuV5dv6xsbFI2evlf/DggbLZrHp6elQqlXThwgVlMhl/YfWdO3d0/vx5Xb16VevWrav7WrWynz17dlGyDw4Oqru7W1euXNGxY8f04sULPX78WMPDwyoUCnPW8zx//lznzp3TrVu36h44OTs7okkmk+rs7FRnZ6ekmb8Lx3FUrVb9kpNIJGTbNnfDwTgcHmqwV69ehV7k+fLlS42MjOjz58/atWuXfv/+rXw+r6GhIUkzl1CuX7+uyclJZbNZHT58uOYeK/NJJBJziknU/PWy7tmz55+Ttu/du+d/i3Dz5k1dvnxZpVJJAwMDOnXqVKBvSoLkD/Lez5f//v37unjxoiYmJmTbtvr7+3X8+HH19vZKkvbu3auJiYk5k9L+/ft1+vTppmeXpPfv3+vMmTN69+6duru7NTQ0VHO/nSdPnujEiROB9nMKOm4AoBYKj8E+fPig6enpZsf4RyqVCnQLeJzzk73xgo4bAKiFC+IGa4UDBf/m3eURRJzzk72xwowbAKiFwmOwemsimiVorjjnJ3vjtWouAPFA4TFYMplsuU/F6XQ68N0ecc5P9sYKM24AoBYKj+FyuVzLXJ6wLEu5XC7UY+Kcn+yNEWXcAMDfKDyGS6VSLbEmw1uDEXZX5jjnJ/vCRR03APA3Ck8baIXD/bxDDaOIc36yL8xCxg0AzEbhaQPJZFL5fL5pk5dlWcrn81q2bFmkx8c5P9mjW+i4AYDZKDxtIpPJqKOjY8knL8uy1NHRoUwms6DniXN+sofXqHEDAB4KTxvp6urSqlWrlmzysixLq1evDnXO03zinJ/swTV63ACAxE7Lbcd1XX358kXfv3+PdNZWUN4n9K6uroZOlHHOT/b6FmvcAACFp039+PFDnz59kuu6DZ3ALMvy114s5uWIOOcn+7+WatwAaF8UnjZWqVQ0Pj6uYrHYkMnLu4V47dq1S7JJXJzzk/2PpR43ANoThQeanp7Wt2/fVCwWJSnUJOZddkin08rlck3ZLyXO+cnevHEDoL1QeOCrVCqamppSsViU4zhyXbfmOgrv97ZtK51OK5vNtsQn8zjnJzsALC4KD/5XuVyW4ziqVqv+ZJVIJGTbtpYvX97seHXFOT/ZAaCxKDwAAMB47MMDAACMR+EBAADGo/AAAADjUXgAAIDxKDwAAMB4FB4AAGA8Cg8AADAehQcAABiPwgMAAIxH4QEAAMaj8AAAAONReAAAgPEoPAAAwHgUHgAAYDwKDwAAMB6FBwAAGI/CAwAAjEfhAQAAxqPwAAAA41F4AACA8Sg8AADAeBQeAABgPAoPAAAwHoUHAAAYj8IDAACMR+EBAADGo/AAAADjUXgAAIDxKDwAAMB4/wGAynchHE2UagAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 720x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#@title\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"lattice = {\n",
" 'i*': ['u8', 'i8'], 'f*': ['c*', 'f16', 'bf16'], 'c*': ['c64'],\n",
" 'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'], 'u64': ['f*'],\n",
" 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'], 'i64': ['f*'],\n",
" 'f16': ['f32'], 'bf16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],\n",
" 'c64': ['c128']\n",
"}\n",
"graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)\n",
"pos = {\n",
" 'i*': [-1.25, 0.5], 'f*': [4.5, 0.5], 'c*': [5, 1.5],\n",
" 'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],\n",
" 'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],\n",
" 'f16': [5.75, 0.8], 'bf16': [5.75, 0.2], 'f32': [7, 0.5], 'f64': [8, 0.5],\n",
" 'c64': [7.5, 1.5], 'c128': [8.5, 1.5],\n",
"}\n",
"fig, ax = plt.subplots(figsize=(10, 4))\n",
"ax.set_ylim(-0.5, 2)\n",
"nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)\n",
"# ax.patches[12].set_linestyle((0, (2, 4)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o0-E2KWjYEXO"
},
"source": [
"The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n",
"\n",
"For those interested, the appendix below prints the full promotion tables used by NumPy, TensorFlow, PyTorch, and JAX."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gkbaKRmOtXJ4"
},
"source": [
"## Appendix: Example Type Promotion Tables\n",
"\n",
"The following are some examples of implicit type promotion tables implemented by various Python array computing libraries."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KFgIKR70s1gw"
},
"source": [
"### NumPy Type Promotion\n",
"\n",
"Note that NumPy does not include the `bfloat16` dtype, and that the table below ignores value-dependent effects."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"id": "aJELZ70OheaC",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>b</th><th>u8</th><th>u16</th><th>u32</th><th>u64</th><th>i8</th><th>i16</th><th>i32</th><th>i64</th><th>bf16</th><th>f16</th><th>f32</th><th>f64</th><th>c64</th><th>c128</th><th>i*</th><th>f*</th><th>c*</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>b</th><td>b</td><td>u8</td><td>u16</td><td>u32</td><td>u64</td><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>-</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i64</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>u8</th><td>u8</td><td>u8</td><td>u16</td><td>u32</td><td>u64</td><td>i16</td><td>i16</td><td>i32</td><td>i64</td><td>-</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>u8</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>u16</th><td>u16</td><td>u16</td><td>u16</td><td>u32</td><td>u64</td><td>i32</td><td>i32</td><td>i32</td><td>i64</td><td>-</td><td>f32</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>u16</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>u32</th><td>u32</td><td>u32</td><td>u32</td><td>u32</td><td>u64</td><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>-</td><td>f64</td><td>f64</td><td>f64</td><td>c128</td><td>c128</td><td>u32</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>u64</th><td>u64</td><td>u64</td><td>u64</td><td>u64</td><td>u64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>-</td><td>f64</td><td>f64</td><td>f64</td><td>c128</td><td>c128</td><td>u64</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>i8</th><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>f64</td><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>-</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i8</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>i16</th><td>i16</td><td>i16</td><td>i32</td><td>i64</td><td>f64</td><td>i16</td><td>i16</td><td>i32</td><td>i64</td><td>-</td><td>f32</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i16</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>i32</th><td>i32</td><td>i32</td><td>i32</td><td>i64</td><td>f64</td><td>i32</td><td>i32</td><td>i32</td><td>i64</td><td>-</td><td>f64</td><td>f64</td><td>f64</td><td>c128</td><td>c128</td><td>i32</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>i64</th><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>f64</td><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>-</td><td>f64</td><td>f64</td><td>f64</td><td>c128</td><td>c128</td><td>i64</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>bf16</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>f16</th><td>f16</td><td>f16</td><td>f32</td><td>f64</td><td>f64</td><td>f16</td><td>f32</td><td>f64</td><td>f64</td><td>-</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>f16</td><td>f16</td><td>c64</td></tr>\n",
"    <tr><th>f32</th><td>f32</td><td>f32</td><td>f32</td><td>f64</td><td>f64</td><td>f32</td><td>f32</td><td>f64</td><td>f64</td><td>-</td><td>f32</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>f32</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>f64</th><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>-</td><td>f64</td><td>f64</td><td>f64</td><td>c128</td><td>c128</td><td>f64</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>c64</th><td>c64</td><td>c64</td><td>c64</td><td>c128</td><td>c128</td><td>c64</td><td>c64</td><td>c128</td><td>c128</td><td>-</td><td>c64</td><td>c64</td><td>c128</td><td>c64</td><td>c128</td><td>c64</td><td>c64</td><td>c64</td></tr>\n",
"    <tr><th>c128</th><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>-</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td></tr>\n",
"    <tr><th>i*</th><td>i64</td><td>u8</td><td>u16</td><td>u32</td><td>u64</td><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>-</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i64</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>f*</th><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>-</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>f64</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>c*</th><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>-</td><td>c64</td><td>c64</td><td>c128</td><td>c64</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td></tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
""
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# @title\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from IPython import display\n",
"\n",
"# Short codes used as row and column labels. 'bf16' has no NumPy dtype, so it\n",
"# maps to a placeholder that np.zeros rejects with a TypeError.\n",
"np_dtypes = {\n",
"  'b': np.bool_,\n",
"  'u8': np.uint8, 'u16': np.uint16, 'u32': np.uint32, 'u64': np.uint64,\n",
"  'i8': np.int8, 'i16': np.int16, 'i32': np.int32, 'i64': np.int64,\n",
"  'bf16': 'invalid', 'f16': np.float16, 'f32': np.float32, 'f64': np.float64,\n",
"  'c64': np.complex64, 'c128': np.complex128,\n",
"  'i*': int, 'f*': float, 'c*': complex}\n",
"\n",
"np_dtype_to_code = {val: key for key, val in np_dtypes.items()}\n",
"\n",
"def make_np_zero(dtype):\n",
"  # Python scalar types stay Python scalars; everything else becomes an array.\n",
"  if dtype in {int, float, complex}:\n",
"    return dtype(0)\n",
"  else:\n",
"    return np.zeros(1, dtype=dtype)\n",
"\n",
"def np_result_code(dtype1, dtype2):\n",
"  # Return the code of the result dtype, or '-' if the addition is not supported.\n",
"  try:\n",
"    out = np.add(make_np_zero(dtype1), make_np_zero(dtype2))\n",
"  except TypeError:\n",
"    return '-'\n",
"  else:\n",
"    if type(out) in {int, float, complex}:\n",
"      return np_dtype_to_code[type(out)]\n",
"    else:\n",
"      return np_dtype_to_code[out.dtype.type]\n",
"\n",
"\n",
"grid = [[np_result_code(dtype1, dtype2)\n",
"         for dtype2 in np_dtypes.values()]\n",
"        for dtype1 in np_dtypes.values()]\n",
"table = pd.DataFrame(grid, index=np_dtypes.keys(), columns=np_dtypes.keys())\n",
"display.HTML(table.to_html())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JTMHTqQBs8Jv"
},
"source": [
"### TensorFlow Type Promotion\n",
"\n",
"TensorFlow avoids defining implicit type promotion, except for Python scalars in limited cases. The table is asymmetric because in `tf.add(x, y)`, the type of `y` must be coercible to the type of `x`."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"cellView": "form",
"id": "RvfJd7X-YBvY",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>b</th><th>u8</th><th>u16</th><th>u32</th><th>u64</th><th>i8</th><th>i16</th><th>i32</th><th>i64</th><th>bf16</th><th>f16</th><th>f32</th><th>f64</th><th>c64</th><th>c128</th><th>i*</th><th>f*</th><th>c*</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>b</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>u8</th><td>-</td><td>u8</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>u8</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>u16</th><td>-</td><td>-</td><td>u16</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>u16</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>u32</th><td>-</td><td>-</td><td>-</td><td>u32</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>u32</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>u64</th><td>-</td><td>-</td><td>-</td><td>-</td><td>u64</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>u64</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>i8</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i8</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i8</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>i16</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i16</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i16</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>i32</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i32</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i32</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>i64</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i64</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i64</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>bf16</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>bf16</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>bf16</td><td>bf16</td><td>-</td></tr>\n",
"    <tr><th>f16</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>f16</td><td>-</td><td>-</td><td>-</td><td>-</td><td>f16</td><td>f16</td><td>-</td></tr>\n",
"    <tr><th>f32</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>f32</td><td>-</td><td>-</td><td>-</td><td>f32</td><td>f32</td><td>-</td></tr>\n",
"    <tr><th>f64</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>f64</td><td>-</td><td>-</td><td>f64</td><td>f64</td><td>-</td></tr>\n",
"    <tr><th>c64</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>c64</td><td>-</td><td>c64</td><td>c64</td><td>c64</td></tr>\n",
"    <tr><th>c128</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td></tr>\n",
"    <tr><th>i*</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i32</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>i32</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>f*</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>f32</td><td>-</td><td>-</td><td>-</td><td>f32</td><td>f32</td><td>-</td></tr>\n",
"    <tr><th>c*</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td></tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
""
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# @title\n",
"\n",
"import tensorflow as tf\n",
"import pandas as pd\n",
"from IPython import display\n",
"\n",
"# Short codes used as row and column labels, mapped to TensorFlow dtypes.\n",
"tf_dtypes = {\n",
"  'b': tf.bool,\n",
"  'u8': tf.uint8, 'u16': tf.uint16, 'u32': tf.uint32, 'u64': tf.uint64,\n",
"  'i8': tf.int8, 'i16': tf.int16, 'i32': tf.int32, 'i64': tf.int64,\n",
"  'bf16': tf.bfloat16, 'f16': tf.float16, 'f32': tf.float32, 'f64': tf.float64,\n",
"  'c64': tf.complex64, 'c128': tf.complex128,\n",
"  'i*': int, 'f*': float, 'c*': complex}\n",
"\n",
"tf_dtype_to_code = {val: key for key, val in tf_dtypes.items()}\n",
"\n",
"def make_tf_zero(dtype):\n",
"  # Python scalar types stay Python scalars; everything else becomes a tensor.\n",
"  if dtype in {int, float, complex}:\n",
"    return dtype(0)\n",
"  else:\n",
"    return tf.zeros(1, dtype=dtype)\n",
"\n",
"def result_code(dtype1, dtype2):\n",
"  # Return the code of the result dtype, or '-' if tf.add rejects the pair.\n",
"  try:\n",
"    out = tf.add(make_tf_zero(dtype1), make_tf_zero(dtype2))\n",
"  except (TypeError, tf.errors.InvalidArgumentError):\n",
"    return '-'\n",
"  else:\n",
"    if type(out) in {int, float, complex}:\n",
"      return tf_dtype_to_code[type(out)]\n",
"    else:\n",
"      return tf_dtype_to_code[out.dtype]\n",
"\n",
"\n",
"grid = [[result_code(dtype1, dtype2)\n",
"         for dtype2 in tf_dtypes.values()]\n",
"        for dtype1 in tf_dtypes.values()]\n",
"table = pd.DataFrame(grid, index=tf_dtypes.keys(), columns=tf_dtypes.keys())\n",
"display.HTML(table.to_html())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mff8P-dptB1P"
},
"source": [
"### PyTorch Type Promotion\n",
"\n",
"Notice that PyTorch does not include unsigned integer types larger than `uint8`.\n",
"Aside from this and some details about promotion with scalar/weak types, the table is close to that used by `jax.numpy`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"cellView": "form",
"id": "U2demrM6da9Y",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>b</th><th>u8</th><th>u16</th><th>u32</th><th>u64</th><th>i8</th><th>i16</th><th>i32</th><th>i64</th><th>bf16</th><th>f16</th><th>f32</th><th>f64</th><th>c64</th><th>c128</th><th>i*</th><th>f*</th><th>c*</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>b</th><td>b</td><td>u8</td><td>-</td><td>-</td><td>-</td><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i64</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>u8</th><td>u8</td><td>u8</td><td>-</td><td>-</td><td>-</td><td>i16</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>u8</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>u16</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>u32</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>u64</th><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td><td>-</td></tr>\n",
"    <tr><th>i8</th><td>i8</td><td>i16</td><td>-</td><td>-</td><td>-</td><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i8</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>i16</th><td>i16</td><td>i16</td><td>-</td><td>-</td><td>-</td><td>i16</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i16</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>i32</th><td>i32</td><td>i32</td><td>-</td><td>-</td><td>-</td><td>i32</td><td>i32</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i32</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>i64</th><td>i64</td><td>i64</td><td>-</td><td>-</td><td>-</td><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i64</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>bf16</th><td>bf16</td><td>bf16</td><td>-</td><td>-</td><td>-</td><td>bf16</td><td>bf16</td><td>bf16</td><td>bf16</td><td>bf16</td><td>f32</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>bf16</td><td>bf16</td><td>c64</td></tr>\n",
"    <tr><th>f16</th><td>f16</td><td>f16</td><td>-</td><td>-</td><td>-</td><td>f16</td><td>f16</td><td>f16</td><td>f16</td><td>f32</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>f16</td><td>f16</td><td>c64</td></tr>\n",
"    <tr><th>f32</th><td>f32</td><td>f32</td><td>-</td><td>-</td><td>-</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>f32</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>f64</th><td>f64</td><td>f64</td><td>-</td><td>-</td><td>-</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>c128</td><td>c128</td><td>f64</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>c64</th><td>c64</td><td>c64</td><td>-</td><td>-</td><td>-</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c128</td><td>c64</td><td>c128</td><td>c64</td><td>c64</td><td>c64</td></tr>\n",
"    <tr><th>c128</th><td>c128</td><td>c128</td><td>-</td><td>-</td><td>-</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td></tr>\n",
"    <tr><th>i*</th><td>i64</td><td>u8</td><td>-</td><td>-</td><td>-</td><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i64</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>f*</th><td>f32</td><td>f32</td><td>-</td><td>-</td><td>-</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>f32</td><td>f64</td><td>c64</td></tr>\n",
"    <tr><th>c*</th><td>c64</td><td>c64</td><td>-</td><td>-</td><td>-</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c128</td><td>c64</td><td>c128</td><td>c64</td><td>c64</td><td>c128</td></tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# @title\n",
"import torch\n",
"import pandas as pd\n",
"from IPython import display\n",
"\n",
"# Short codes used as row and column labels. This table treats u16, u32, and\n",
"# u64 as unavailable; the placeholder makes torch.zeros raise a TypeError.\n",
"torch_dtypes = {\n",
"  'b': torch.bool,\n",
"  'u8': torch.uint8, 'u16': 'invalid', 'u32': 'invalid', 'u64': 'invalid',\n",
"  'i8': torch.int8, 'i16': torch.int16, 'i32': torch.int32, 'i64': torch.int64,\n",
"  'bf16': torch.bfloat16, 'f16': torch.float16, 'f32': torch.float32, 'f64': torch.float64,\n",
"  'c64': torch.complex64, 'c128': torch.complex128,\n",
"  'i*': int, 'f*': float, 'c*': complex}\n",
"\n",
"torch_dtype_to_code = {val: key for key, val in torch_dtypes.items()}\n",
"\n",
"def make_torch_zero(dtype):\n",
"  # Python scalar types stay Python scalars; everything else becomes a tensor.\n",
"  if dtype in {int, float, complex}:\n",
"    return dtype(0)\n",
"  else:\n",
"    return torch.zeros(1, dtype=dtype)\n",
"\n",
"def torch_result_code(dtype1, dtype2):\n",
"  # Return the code of the result dtype, or '-' if the addition is not supported.\n",
"  try:\n",
"    out = torch.add(make_torch_zero(dtype1), make_torch_zero(dtype2))\n",
"  except TypeError:\n",
"    return '-'\n",
"  else:\n",
"    if type(out) in {int, float, complex}:\n",
"      return torch_dtype_to_code[type(out)]\n",
"    else:\n",
"      return torch_dtype_to_code[out.dtype]\n",
"\n",
"\n",
"grid = [[torch_result_code(dtype1, dtype2)\n",
"         for dtype2 in torch_dtypes.values()]\n",
"        for dtype1 in torch_dtypes.values()]\n",
"table = pd.DataFrame(grid, index=torch_dtypes.keys(), columns=torch_dtypes.keys())\n",
"display.HTML(table.to_html())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-7FLQxLqtIwp"
},
"source": [
"### JAX Type Promotion: `jax.numpy`\n",
"\n",
"`jax.numpy` follows the type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here `i*`, `f*`, and `c*` denote both Python scalars and weakly-typed arrays."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"cellView": "form",
"id": "-AGKe0f9iQ4Z",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>b</th><th>u8</th><th>u16</th><th>u32</th><th>u64</th><th>i8</th><th>i16</th><th>i32</th><th>i64</th><th>bf16</th><th>f16</th><th>f32</th><th>f64</th><th>c64</th><th>c128</th><th>i*</th><th>f*</th><th>c*</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>b</th><td>b</td><td>u8</td><td>u16</td><td>u32</td><td>u64</td><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i*</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>u8</th><td>u8</td><td>u8</td><td>u16</td><td>u32</td><td>u64</td><td>i16</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>u8</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>u16</th><td>u16</td><td>u16</td><td>u16</td><td>u32</td><td>u64</td><td>i32</td><td>i32</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>u16</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>u32</th><td>u32</td><td>u32</td><td>u32</td><td>u32</td><td>u64</td><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>u32</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>u64</th><td>u64</td><td>u64</td><td>u64</td><td>u64</td><td>u64</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>u64</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>i8</th><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>f*</td><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i8</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>i16</th><td>i16</td><td>i16</td><td>i32</td><td>i64</td><td>f*</td><td>i16</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i16</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>i32</th><td>i32</td><td>i32</td><td>i32</td><td>i64</td><td>f*</td><td>i32</td><td>i32</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i32</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>i64</th><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>f*</td><td>i64</td><td>i64</td><td>i64</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i64</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>bf16</th><td>bf16</td><td>bf16</td><td>bf16</td><td>bf16</td><td>bf16</td><td>bf16</td><td>bf16</td><td>bf16</td><td>bf16</td><td>bf16</td><td>f32</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>bf16</td><td>bf16</td><td>c64</td></tr>\n",
"    <tr><th>f16</th><td>f16</td><td>f16</td><td>f16</td><td>f16</td><td>f16</td><td>f16</td><td>f16</td><td>f16</td><td>f16</td><td>f32</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>f16</td><td>f16</td><td>c64</td></tr>\n",
"    <tr><th>f32</th><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>f32</td><td>f32</td><td>c64</td></tr>\n",
"    <tr><th>f64</th><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>c128</td><td>c128</td><td>f64</td><td>f64</td><td>c128</td></tr>\n",
"    <tr><th>c64</th><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c64</td><td>c128</td><td>c64</td><td>c128</td><td>c64</td><td>c64</td><td>c64</td></tr>\n",
"    <tr><th>c128</th><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td><td>c128</td></tr>\n",
"    <tr><th>i*</th><td>i*</td><td>u8</td><td>u16</td><td>u32</td><td>u64</td><td>i8</td><td>i16</td><td>i32</td><td>i64</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>i*</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>f*</th><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>bf16</td><td>f16</td><td>f32</td><td>f64</td><td>c64</td><td>c128</td><td>f*</td><td>f*</td><td>c*</td></tr>\n",
"    <tr><th>c*</th><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c64</td><td>c64</td><td>c64</td><td>c128</td><td>c64</td><td>c128</td><td>c*</td><td>c*</td><td>c*</td></tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
""
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# @title\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import pandas as pd\n",
"from IPython import display\n",
"# Enable 64-bit types so the 64-bit rows and columns are not truncated to 32 bits.\n",
"jax.config.update('jax_enable_x64', True)\n",
"\n",
"# Short codes used as row and column labels, mapped to JAX dtypes.\n",
"jnp_dtypes = {\n",
"  'b': jnp.bool_.dtype,\n",
"  'u8': jnp.uint8.dtype, 'u16': jnp.uint16.dtype, 'u32': jnp.uint32.dtype, 'u64': jnp.uint64.dtype,\n",
"  'i8': jnp.int8.dtype, 'i16': jnp.int16.dtype, 'i32': jnp.int32.dtype, 'i64': jnp.int64.dtype,\n",
"  'bf16': jnp.bfloat16.dtype, 'f16': jnp.float16.dtype, 'f32': jnp.float32.dtype, 'f64': jnp.float64.dtype,\n",
"  'c64': jnp.complex64.dtype, 'c128': jnp.complex128.dtype,\n",
"  'i*': int, 'f*': float, 'c*': complex}\n",
"\n",
"\n",
"jnp_dtype_to_code = {val: key for key, val in jnp_dtypes.items()}\n",
"\n",
"def make_jnp_zero(dtype):\n",
"  # Python scalar types stay Python scalars; everything else becomes an array.\n",
"  if dtype in {int, float, complex}:\n",
"    return dtype(0)\n",
"  else:\n",
"    return jnp.zeros((), dtype=dtype)\n",
"\n",
"def jnp_result_code(dtype1, dtype2):\n",
"  try:\n",
"    out = jnp.add(make_jnp_zero(dtype1), make_jnp_zero(dtype2))\n",
"  except TypeError:\n",
"    return '-'\n",
"  else:\n",
"    # Weakly-typed results are reported with a trailing '*'.\n",
"    if hasattr(out, 'aval') and out.aval.weak_type:\n",
"      return out.dtype.kind + '*'\n",
"    elif type(out) in {int, float, complex}:\n",
"      return jnp_dtype_to_code[type(out)]\n",
"    else:\n",
"      return jnp_dtype_to_code[out.dtype]\n",
"\n",
"grid = [[jnp_result_code(dtype1, dtype2)\n",
"         for dtype2 in jnp_dtypes.values()]\n",
"        for dtype1 in jnp_dtypes.values()]\n",
"table = pd.DataFrame(grid, index=jnp_dtypes.keys(), columns=jnp_dtypes.keys())\n",
"display.HTML(table.to_html())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cCVohsfUtP0m"
},
"source": [
"### JAX Type Promotion: `jax.lax`\n",
"\n",
"`jax.lax` is lower-level and does not do any implicit promotion. Here `i*`, `f*`, and `c*` denote both Python scalars and weakly-typed arrays."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"cellView": "form",
"id": "ES97obW6iRjf",
"tags": [
"hide-input"
]
},
"outputs": [
{
"data": {
"text/html": [
"