From 9b36dfbcb3e6ee32f74dc239932d97a25eb581dd Mon Sep 17 00:00:00 2001 From: Muhammad Ammar Nabil Date: Sat, 26 Aug 2023 11:05:55 +0700 Subject: [PATCH] Upload C2W2 Assignment --- .../Week 2/C2W2_Assignment.ipynb | 1594 +++++++++++++++++ 1 file changed, 1594 insertions(+) create mode 100644 Custom and Distributed Training with Tensorflow/Week 2/C2W2_Assignment.ipynb diff --git a/Custom and Distributed Training with Tensorflow/Week 2/C2W2_Assignment.ipynb b/Custom and Distributed Training with Tensorflow/Week 2/C2W2_Assignment.ipynb new file mode 100644 index 0000000..32cd231 --- /dev/null +++ b/Custom and Distributed Training with Tensorflow/Week 2/C2W2_Assignment.ipynb @@ -0,0 +1,1594 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "AWqcoPhU3RJN" + }, + "source": [ + "# Breast Cancer Prediction\n", + "\n", + "In this exercise, you will train a neural network on the [Breast Cancer Dataset](https://archive.ics.uci.edu/ml/datasets/breast+cancer+wisconsin+(original)) to predict if the tumor is malignant or benign.\n", + "\n", + "If you get stuck, we recommend that you review the ungraded labs for this week." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "st5AIBFZ5mEQ" + }, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "JkMXve8XuN5X" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.keras.models import Model\n", + "from tensorflow.keras.layers import Dense, Input\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as mticker\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import confusion_matrix\n", + "import itertools\n", + "from tqdm import tqdm\n", + "from typing import Union, Optional\n", + "import tensorflow_datasets as tfds\n", + "\n", + "tf.get_logger().setLevel('ERROR')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "yUc3HpEQ5s6U" + }, + "source": [ + "## Load and Preprocess the Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7-TQFUXu5wS_" + }, + "source": [ + "We first load the dataset and create a data frame using pandas. We explicitly specify the column names because the CSV file does not have column headers." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "HVh-W73J5TjS" + }, + "outputs": [], + "source": [ + "data_file: str = './data/data.csv'\n", + "col_names = [\n", + " \"id\",\n", + " \"clump_thickness\",\n", + " \"un_cell_size\",\n", + " \"un_cell_shape\",\n", + " \"marginal_adheshion\",\n", + " \"single_eph_cell_size\",\n", + " \"bare_nuclei\",\n", + " \"bland_chromatin\",\n", + " \"normal_nucleoli\",\n", + " \"mitoses\",\n", + " \"class\"\n", + "]\n", + "df: pd.DataFrame = pd.read_csv(data_file, names=col_names, header=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "XEv8vS_P6HaV" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idclump_thicknessun_cell_sizeun_cell_shapemarginal_adheshionsingle_eph_cell_sizebare_nucleibland_chromatinnormal_nucleolimitosesclass
010000255111213112
1100294554457103212
210154253111223112
310162776881343712
410170234113213112
\n", + "
" + ], + "text/plain": [ + " id clump_thickness un_cell_size un_cell_shape marginal_adheshion \\\n", + "0 1000025 5 1 1 1 \n", + "1 1002945 5 4 4 5 \n", + "2 1015425 3 1 1 1 \n", + "3 1016277 6 8 8 1 \n", + "4 1017023 4 1 1 3 \n", + "\n", + " single_eph_cell_size bare_nuclei bland_chromatin normal_nucleoli \\\n", + "0 2 1 3 1 \n", + "1 7 10 3 2 \n", + "2 2 2 3 1 \n", + "3 3 4 3 7 \n", + "4 2 1 3 1 \n", + "\n", + " mitoses class \n", + "0 1 2 \n", + "1 1 2 \n", + "2 1 2 \n", + "3 1 2 \n", + "4 1 2 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NvvbnFL36L85" + }, + "source": [ + "We have to do some preprocessing on the data. We first pop the id column since it is of no use for our problem at hand." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "nDeXwHdA5uUN" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0 1000025\n", + "1 1002945\n", + "2 1015425\n", + "3 1016277\n", + "4 1017023\n", + " ... \n", + "694 776715\n", + "695 841769\n", + "696 888820\n", + "697 897471\n", + "698 897471\n", + "Name: id, Length: 699, dtype: int64" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.pop(\"id\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ubw5LueA6ZEY" + }, + "source": [ + "Upon inspection of data, you can see that some values of the **bare_nuclei** column are unknown. We drop the rows with these unknown values. We also convert the **bare_nuclei** column to numeric. This is required for training the model." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "MCcOrl1ITVhr" + }, + "outputs": [], + "source": [ + "df = df[df[\"bare_nuclei\"] != '?' ]\n", + "df.bare_nuclei = pd.to_numeric(df.bare_nuclei)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UQMhcTQG7LzY" + }, + "source": [ + "We check the class distribution of the data. You can see that there are two classes, 2.0 and 4.0\n", + "According to the dataset:\n", + "* **2.0 = benign**\n", + "* **4.0 = malignant**\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "SaAdQrBv8daS" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAQXElEQVR4nO3df6zddX3H8efLguioo2Vo17Rs7bL+MZD5gxvGZH/cyhKqMMuSkdQwUxeSZgtLXPaz+IeLfzTDP1jMULI0YqwBvWlQV4KySSp3blNkVtFakNFJgxXSRoHqZYQF9t4f90s8Xu7l/Lj33Fs/PB/Jzf2ez/f7Pd/X+fbT1z3ne+89N1WFJKktr1rpAJKkpWe5S1KDLHdJapDlLkkNstwlqUFnrHQAgPPOO682bdo08v7PPPMMZ5999tIFWiLmGo65hmOu4bSY69ChQz+sqtfPu7KqVvzj4osvrsW49957F7X/uJhrOOYajrmG02Iu4Ou1QK96WUaSGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhp0Wrz9wGId/sEp3rv78yPte+zGK5c4jSStPJ+5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaNHC5J1mV5JtJ7upun5vkniSPdJ/X9mx7Q5KjSR5OcsU4gkuSFjbMM/f3AQ/13N4NHKyqLcDB7jZJLgB2ABcC24BbkqxamriSpEEMVO5JNgJXAh/rGd4O7OuW9wFX94xPVdVzVfUocBS4ZGniSpIGkarqv1FyB/B3wOuAv6yqq5I8XVVrerZ5qqrWJvkIcF9V3daN3wrcXVV3zLnPXcAugHXr1l08NTU18oM4+eQpTjw72r4XbThn5OP2MzMzw+rVq8d2/6My13DMNRxzDWcxubZu3XqoqibmW9f3D2QnuQo4WVWHkkwOcLzMM/aSryBVtRfYCzAxMVGTk4Pc9fxuvv0ANx0e7W99H7t29OP2Mz09zWIe17iYazjmGo65hjOuXIM04mXAu5K8E3gN8ItJbgNOJFlfVU8kWQ+c7LY/Dpzfs/9G4PGlDC1Jenl9r7lX1Q1VtbGqNjH7jdIvVdUfAncCO7vNdgIHuuU7gR1JzkqyGdgC3L/kySVJCxrtWsasG4H9Sa4DHgOuAaiqI0n2Aw8CzwPXV9ULi04qSRrYUOVeVdPAdLf8I+DyBbbbA+xZZDZJ0oj8DVVJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1KC+5Z7kNUnuT/KtJEeSfLAbPzfJPUke6T6v7dnnhiRHkzyc5IpxPgBJ0ksN8sz9OeDtVfUm4M3AtiSXAruBg1W1BTjY3SbJBcAO4EJgG3BLklXjCC9Jml/fcq9ZM93NM7uPArYD+7rxfcDV3fJ2YKqqnquqR4GjwCVLmlqS9LJSVf03mn3mfQj4deCjVfU3SZ6uqjU92zxVVWuTfAS4r6pu68ZvBe6uqjvm3OcuYBfAunXrLp6amhr5QZx88hQnnh1t34s2nDPycfuZmZlh9erVY7v/UZlrOOYajrmGs5hcW7duPVRVE/OtO2OQO6iqF4A3J1kDfC7JG19m88x3F/Pc515gL8DExERNTk4OEmVeN99+gJsOD/RQXuLYtaMft5/p6WkW87jGxVzDMddwzDWcceUa6qdlquppYJrZa+knkqwH6D6f7DY7Dpzfs9tG4PFFJ5UkDWyQn5Z5ffeMnSSvBX4X+C5wJ7Cz22wncKBbvhPYkeSsJJuBLcD9Sx1ckrSwQa5lrAf2ddfdXwXsr6q7knwV2J/kOuAx4BqAqjqSZD/wIPA8cH13WUeStEz6lntVfRt4yzzjPwIuX2CfPcCeRaeTJI3E31CVpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNOmOlA0jSz7tNuz8/8r6f2Hb2Eib5KZ+5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QG9S33JOcnuTfJQ0mOJHlfN35uknuSPNJ9Xtuzzw1JjiZ5OMkV43wAkqSXGuSZ+/PAX1TVbwCXAtcnuQDYDRysqi3Awe423bodwIXANuCWJKvGEV6SNL++5V5VT1TVN7rlnwAPARuA7cC+brN9wNXd8nZgqqqeq6pHgaPAJUsdXJK0sKGuuSfZBLwF+BqwrqqegNkvAMAbus02AN/v2e14NyZJWiapqsE2TFYD/wrsqarPJnm6qtb0rH+qqtYm+Sjw1aq6rRu/FfhCVX1mzv3tAnYBrFu37uKpqamRH8TJJ09x4tnR9r1owzkjH7efmZkZVq9ePbb7H5W5hmOu4bwScx3+wamR9918zqqRc23duvVQVU3Mt26gv6Ga5EzgM8DtVfXZbvhEkvVV9USS9cDJbvw4cH7P7huBx+feZ1XtBfYCTExM1OTk5CBR5nXz7Qe46fBofw722LWjH7ef6elpFvO4xsVcwzHXcF6Jud67yL+hOo5cg/y0TIBbgYeq6u97Vt0J7OyWdwIHesZ3JDkryWZgC3D/0kWWJPUzyNPdy4D3AIeTPNCNvR+4Edif5DrgMeAagKo6kmQ/8CCzP2lzfVW9sOTJJUkL6lvuVfXvQBZYffkC++wB9iwilyRpEfwNVUlqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUoL7lnuTjSU4m+U7P2LlJ7knySPd5bc+6G5IcTfJwkivGFVyStLBBnrl/Atg2Z2w3cLCqtgAHu9skuQDYAVzY7XNLklVLllaSNJC+5V5VXwaenDO8HdjXLe8Dru4Zn6qq56rqUeAocMkSZZUkDShV1X+jZBNwV1W9sbv9dFWt6Vn/VFWtTfIR4L6quq0bvxW4u6rumOc+dwG7ANatW3fx1NTUyA/i5JOnOPHsaPtetOGckY/bz8zMDKtXrx7b/Y/KXMMx13BeibkO/+DUyPtuPmfVyLm2bt16qKom5lt3xsiJ5pd5xub96lFVe4G9ABMTEzU5OTnyQW++/QA3HR7toRy7dvTj9jM9Pc1iHte4mGs45hrOKzHXe3d/fuR9P7Ht7LHkGvWnZU4kWQ/QfT7ZjR8Hzu/ZbiPw+OjxJEmjGLXc7wR2dss7gQM94zuSnJVkM7AFuH9xESVJw+p7LSPJp4FJ4Lwkx4G/BW4E9ie5DngMuAagqo4k2Q88CDwPXF9VL4wpuyRpAX3LvarevcCqyxfYfg+wZzGhJEmL42+oSlKDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGja3ck2xL8nCSo0l2j+s4kqSXGku5J1kFfBR4B3AB8O4kF4zjWJKklxrXM/dLgKNV9b2q+l9gCtg+pmNJkuY4Y0z3uwH4fs/t48Bv9W6QZBewq7s5k+ThRRzvPOCHo+yYDy3iqP2NnGvMzDUccw3HXEPY+qFF5frVhVaMq9wzz1j9zI2qvcDeJTlY8vWqmliK+1pK5hqOuYZjruG80nKN67LMceD8ntsbgcfHdCxJ0hzjKvf/BLYk2Zzk1cAO4M4xHUuSNMdYLstU1fNJ/hT4F2AV8PGqOjKOY3WW5PLOGJhrOOYajrmG84rKlarqv5Uk6eeKv6EqSQ2y3CWpQadtuSc5P8m9SR5KciTJ++bZJkn+oXuLg28neWvPurG8/cGAua7t8nw7yVeSvKln3bEkh5M8kOTry5xrMsmp7tgPJPlAz7qVPF9/1ZPpO0leSHJut25c5+s1Se5P8q0u1wfn2WYl5tcguVZifg2SayXm1yC5ln1+9Rx7VZJvJrlrnnXjnV9VdVp+AOuBt3bLrwP+C7hgzjbvBO5m9ufqLwW+1o2vAv4b+DXg1cC35u475lxvA9Z2y+94MVd3+xhw3gqdr0ngrnn2XdHzNWf73wO+tAznK8DqbvlM4GvApafB/Bok10rMr0FyrcT86ptrJeZXz/3/OfCpBc7LWOfXafvMvaqeqKpvdMs/AR5i9jdfe20HPlmz7gPWJFnPGN/+YJBcVfWVqnqqu3kfsz/nP1YDnq+FrOj5muPdwKeX4th9clVVzXQ3z+w+5v50wUrMr765Vmh+DXK+FrKi52uOZZlfAEk2AlcCH1tgk7HOr9O23Hsl2QS8hdmvyr3me5uDDS8zvly5el3H7FfnFxXwxSSHMvsWDEuuT67f7l7C3p3kwm7stDhfSX4B2AZ8pmd4bOere8n8AHASuKeqTov5NUCuXss2vwbMtezza9DztdzzC/gw8NfA/y2wfqzza1xvP7Bkkqxm9h/jz6rqx3NXz7NLvcz4cuV6cZutzP7n+52e4cuq6vEkbwDuSfLdqvryMuX6BvCrVTWT5J3APwFbOE3OF7Mvmf+jqp7sGRvb+aqqF4A3J1kDfC7JG6vqO72x59vtZcaXxAC5ZsMt8/waINeKzK9BzxfLOL+SXAWcrKpDSSYX2myesSWbX6f1M/ckZzJbCLdX1Wfn2WShtzkY69sfDJCLJL/J7Mux7VX1oxfHq+rx7vNJ4HPMvgRbllxV9eMXX8JW1ReAM5Ocx2lwvjo7mPOSeZznq+cYTwPTzD6r67Ui82uAXCsyv/rlWqn51S9Xj+WcX5cB70pyjNnLKm9PctucbcY7v4a9SL9cH8x+9fok8OGX2eZKfvYbEvd342cA3wM289NvSFy4jLl+BTgKvG3O+NnA63qWvwJsW8Zcv8xPf3HtEuCxbr8VPV/dducATwJnL9P5ej2wplt+LfBvwFWnwfwaJNdKzK9Bcq3E/OqbayXm15xjTzL/N1THOr9O58sylwHvAQ5319MA3s/sxKaq/hH4ArPfcT4K/A/wR926cb79wSC5PgD8EnBLEoDna/Zd39Yx+7IRZv8BP1VV/7yMuf4A+JMkzwPPAjtqdjat9PkC+H3gi1X1TM++4zxf64F9mf3DMq8C9lfVXUn+uCfXSsyvQXKtxPwaJNdKzK9BcsHyz695Lef88u0HJKlBp/U1d0nSaCx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1KD/B7SOaRoKrKp/AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "df['class'].hist(bins=20) " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ENjMKvxQ6sWy" + }, + "source": [ + "We are going to model this problem as a binary classification problem which detects whether the tumor is malignant or not. Hence, we change the dataset so that:\n", + "* **benign(2.0) = 0**\n", + "* **malignant(4.0) = 1**" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1MVzeUwf_A3E", + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "df['class'] = np.where(df['class'] == 2, 0, 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EGbKO1bR8S9h" + }, + "source": [ + "We then split the dataset into training and testing sets. Since the number of samples is small, we will perform validation on the test set." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "aNUy7JcuAXjC" + }, + "outputs": [], + "source": [ + "train: pd.DataFrame\n", + "test: pd.DataFrame\n", + "train, test = train_test_split(df, test_size = 0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "H_ZKokUP8kP3" + }, + "source": [ + "We get the statistics for training. We can look at statistics to get an idea about the distribution of plots. If you need more visualization, you can create additional data plots. We will also be using the mean and standard deviation from statistics for normalizing the data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "k86tBT_QAm2P" + }, + "outputs": [], + "source": [ + "train_stats: pd.DataFrame = train.describe()\n", + "train_stats.pop('class')\n", + "train_stats = train_stats.transpose()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "l8AJ0Crc8u9t" + }, + "source": [ + "We pop the class column from the training and test sets to create train and test outputs." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "V7EGUV-tA5LZ" + }, + "outputs": [], + "source": [ + "train_Y: pd.Series = train.pop(\"class\")\n", + "test_Y: pd.Series = test.pop(\"class\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "N9wVRO5E9AgA" + }, + "source": [ + "Here we normalize the data by using the formula: **X = (X - mean(X)) / StandardDeviation(X)**" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NDo__q_AA3j0" + }, + "outputs": [], + "source": [ + "def norm(x: pd.DataFrame) -> pd.DataFrame:\n", + " return (x - train_stats['mean']) / train_stats['std']" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "pdARlWaDA_8G" + }, + "outputs": [], + "source": [ + "norm_train_X: pd.DataFrame = norm(train)\n", + "norm_test_X: pd.DataFrame = norm(test)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "P6LIVZbj9Usv" + }, + "source": [ + "We now create Tensorflow datasets for training and test sets to easily be able to build and manage an input pipeline for our model." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1S0RtsP1Xsj8" + }, + "outputs": [], + "source": [ + "train_dataset: tf.data.Dataset = tf.data.Dataset.from_tensor_slices(\n", + " (norm_train_X.values, train_Y.values)\n", + ")\n", + "test_dataset: tf.data.Dataset = tf.data.Dataset.from_tensor_slices(\n", + " (norm_test_X.values, test_Y.values)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-Nb44PpV9hR4" + }, + "source": [ + "We shuffle and prepare a batched dataset to be used for training in our custom training loop." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "h9qdsNPen5-F" + }, + "outputs": [], + "source": [ + "batch_size: int = 32\n", + "train_dataset = train_dataset.shuffle(buffer_size=len(train)).batch(batch_size)\n", + "\n", + "test_dataset = test_dataset.batch(batch_size=batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18\n" + ] + } + ], + "source": [ + "a: enumerate = enumerate(train_dataset)\n", + "\n", + "print(len(list(a)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GcbOJ6C79qT5" + }, + "source": [ + "## Define the Model\n", + "\n", + "Now we will define the model. Here, we use the Keras Functional API to create a simple network of two `Dense` layers. We have modelled the problem as a binary classification problem and hence we add a single layer with sigmoid activation as the final layer of the model." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "HU3qcM9WBcMh" + }, + "outputs": [], + "source": [ + "def base_model() -> Model:\n", + " inputs: Input = Input(shape=(len(train.columns)))\n", + "\n", + " x: Dense = Dense(128, activation='relu')(inputs)\n", + " x: Dense = Dense(64, activation='relu')(x)\n", + " outputs: Dense = tf.keras.layers.Dense(1, activation='sigmoid')(x)\n", + " model: Model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", + " return model\n", + "\n", + "model: Model = base_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NBhKIcKQ-Bwe" + }, + "source": [ + "## Define Optimizer and Loss\n", + "\n", + "We use RMSprop optimizer and binary crossentropy as our loss function." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "v5B3vh6fs84i" + }, + "outputs": [], + "source": [ + "optimizer: tf.keras.optimizers.Optimizer = tf.keras.optimizers.RMSprop(\n", + " learning_rate=0.001\n", + ")\n", + "loss_object: tf.keras.losses.Loss = tf.keras.losses.BinaryCrossentropy()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YSNDewgovSZ8" + }, + "source": [ + "## Evaluate Untrained Model\n", + "We calculate the loss on the model before training begins." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "TUScS3GbtPXt" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss before training 0.6629\n" + ] + } + ], + "source": [ + "outputs: tf.Tensor = model(norm_test_X.values)\n", + "loss_value: tf.Tensor = loss_object(\n", + " y_true=test_Y.values, y_pred=tf.reshape(outputs, outputs.shape[0])\n", + ")\n", + "print(f\"Loss before training {loss_value:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "jPPb5ewkzMBY" + }, + "source": [ + "We also plot the confusion matrix to visualize the true outputs against the outputs predicted by the model." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ueenYwWZvQM_" + }, + "outputs": [], + "source": [ + "def plot_confusion_matrix(\n", + " y_true: np.ndarray,\n", + " y_pred: tf.Tensor,\n", + " title: str='',\n", + " labels: list=[0,1]\n", + ") -> None:\n", + " cm: np.ndarray = confusion_matrix(y_true, y_pred)\n", + " fig: plt.Figure = plt.figure()\n", + " ax: plt.Axes = fig.add_subplot(111)\n", + " cax: object = ax.matshow(cm)\n", + " plt.title(title)\n", + " fig.colorbar(cax)\n", + " ax.set_xticklabels([''] + labels)\n", + " ax.set_yticklabels([''] + labels)\n", + " plt.xlabel('Predicted')\n", + " plt.ylabel('True')\n", + " fmt: str = 'd'\n", + " thresh: np.float64 = cm.max() / 2.\n", + " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", + " plt.text(\n", + " j,\n", + " i,\n", + " format(cm[i, j], fmt),\n", + " horizontalalignment=\"center\",\n", + " color=\"black\" if cm[i, j] > thresh else \"white\"\n", + " )\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "FApnBUNWv-ZR" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAATIAAAEQCAYAAAAzovj4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAc+UlEQVR4nO3deZgcZbn38e8vM9lZJKthCYkXAVlkkRzWA4KghkWDvnIEESOyuACvaDwcwB3BV87hoPGIS4SYCIhGBQmLJHkDCIisMazBBCGEkJhJhhDIApOZuc8fVRM6w2S6a9I93TX5fa6rrpla+qm7q7vvfp6nnqpWRGBmlme9qh2AmdmWciIzs9xzIjOz3HMiM7PccyIzs9xzIjOz3OvRiUxSf0m3Slot6XdbUM5pkmaVM7ZqkPQnSRO6+NjLJK2U9M9yx1Utki6RdE2Fyl4k6dhKlL2l+5Q0SlJIqu+OuLpDTSQySZ+U9KikNZKWpR+4fy1D0R8HhgODI+LkrhYSETdExAfLEM8mJB2VvqFuard8v3T5PSWW821J1xfbLiKOi4hpXYhzF2AisFdEvDPr4zsor8MPkqSpki4rsYx7JJ21JXFExPciYovK6Ir0eYakj7Rb/sN0+We6O6a8q3oik/QV4IfA90iSzkjgJ8D4MhS/K7AgIprLUFalrAAOkzS4YNkEYEG5dqDElrzWuwKNEdHQhX1X5Vs/B7WNBSSvM7Ax3pOBf1QtojyLiKpNwPbAGuDkTrbpS5LolqbTD4G+6bqjgCUktYUGYBlwRrruO0ATsCHdx5nAt4HrC8oeBQRQn85/BngeeB14ATitYPn9BY87DHgEWJ3+Paxg3T3Ad4G/pOXMAoZs5rm1xf8z4Nx0WV267JvAPQXbTgJeAl4DHgOOSJePa/c8Hy+I4/I0jvXAbumys9L1PwV+X1D+FcAcQO1iPDZ9fGta/tR0+UeAp4FX03L3LHjMIuA/gCeAN9uO7+aOe8HyqcBlhcccuBJYlb4ex6XrLgdagDfSmH6cLg/gXGAh8EJnxy1dt/H9UBDTBGAxsBL4WsG2vYCLSBJNIzAdGFSw/nTgxXTd19JjcOxmXvep6fP6J7BDuuxE4E/pc/5MwT6/npbbAPwK2L6UfXYW7+aOf56naieycUBzZwcUuBR4EBgGDAUeAL6brjsqffylQG/geGBdwZtj4xt1M/MbX1BgYPpm3yNdNwLYu/BDlf4/iOSDdXr6uFPT+cHp+nvSN8/uQP90/vubeW5HkSStw4CH0mXHAzOBs9g0kX0KGJzuc2L6IejX0fMqiGMxsHf6mN5smsgGkNQKPgMcQfLB3bmzOAvmdwfWAh9Iy70QeA7ok65fBMwDdgH6d1Behx8k3p7INgBnkyT3L5B8kang+Z3V7vEBzE5fo/5ZjltBTL9IX7f9SJLwnun6C0jehzuTfLn+HLgxXbcXSUI9Ml13Fcn7srNEdhkwGfhCumw6yXupMJF9Nj2u7wK2AW4Critln0Xi7fD453mqdtNyMLAyOm/6nQZcGhENEbGCpKZ1esH6Den6DRFxB8mLu0cX42kF9pHUPyKWRcTTHWxzArAwIq6LiOaIuBF4FvhwwTa/jIgFEbGe5A26f2c7jYgHgEGS9gA+TfLN236b6yOiMd3nf5O8OYs9z6kR8XT6mA3tyltH8iG/CrgeOD8ilhQpr80ngNsjYnZa7pUkH/7DCrb5UUS8lB6DrnoxIn4RES3ANJIvl+FFHvP/IuKVtv124bh9JyLWR8TjwOMkCQ3gcyQ1tCUR8SZJEvx42iT8OHBbRNybrvsGyXupmF8Bn5a0PfA+4I/t1p8GXBURz0fEGuBi4JQS99lZvD1OtRNZIzCkyMHdkaT63ObFdNnGMtolwnUk316ZRMRakg/o54Flkm6X9O4S4mmLaaeC+cIze6XGcx1wHnA0cHP7lZImSpqfnoF9laRZPqRImS91tjIiHiZpSosk4ZZqk2MQEa3pvgqPQWf7bnu9erdb3pvki6nNxuOYJl4ofiw32W8XjtvmXrtdgZslvZqWM5+keTuc5Hhs3G/6XmosEicRcT9JK+PrJEmpfdLv6L1fX+I+O4u3x6l2IvsrST/HSZ1ss5TkRWkzMl3WFWtJmlRtNjkDFxEzI+IDJN/8z5I0M4rF0xbTy12Mqc11wBeBOwo+tABIOoKkz+nfSJrN7yDpn1Nb6Jsps9Nbm0g6l6SGspSkeViqTY6BJJE0IwuPQWf7XkaSsEa1Wz6at39JbE7R51zCccviJZI+uncUTP0i4mWS57NLwX4HkLQ2SnE9SZP3bbVwOn7vNwPLS9hnZ/H2OFVNZBGxmqRT+2pJJ0kaIKm3pOMk/We62Y3A1yUNlTQk3b7oUIPNmAccKWlkWp2/uG2FpOGSPiJpIEnfyBqSb7D27gB2T4eM1Ev6BEl/xW1djAmAiHiBpHnxtQ5Wb0vyBl4B1Ev6JrBdwfrlwKgsZyYl7U7ST/Mpkqb6hZI6bQIXmA6cIOkYSb1JPohvkvRfFpU2Ff8AXC5pcPqan0pyHP9UYgzLSfqOOlPsuGXxszTeXQHS92PbmfXfAydK+ldJfUj6bEt9LX5E0td4bwfrbgS+LGm0pG1Izuz/Nm2BFNtnZ/H2ONWukRERVwFfIaleryD5JjmPt/oLLgMeJTkD9iQwN13WlX3NBn6blvUYmyafXiQfyKXAKyRJ5YsdlNFIcoZpIklV/kLgxIhY2ZWY2pV9f0R0VNucSfIBX0BSY3mDTZtQbYN9GyXNLbaftCl/PXBFRDweEQuBS4DrJPUtIc6/kyTA/yE5SfBh4MMR0VTssQW+SHKcnyA5I3cecEJELC/x8ZNI+nxWSfrRZrYpdtyymATMAGZJep2kI/1ggLQv9Vzg1yQ1pVUkJ3GKSvvz5kRERzXMKSQ19XtJztq+AZxf4j43G29PpI6Pn5WDpHEkb6g64JqI+H6VQ7IiJE0h+aJqiIh9qh2PlabqNbKeSlIdcDVwHEmT6VRJe1U3KivBVJJhQZYjTmSVcxDwXHrqvAn4DeW5WsEqKCLuJWnyWo44kVXOTmzaH7OETYcnmFmZOJFVTken+N0haVYBTmSVs4SCcT4kl4p0dfybmXXCiaxyHgHGpGOA+gCnkJwON7MycyKrkHTQ4nkkY5nmA9M3c+2m1RBJN5JccbKHpCWSzqx2TFacx5GZWe65RmZmuedEZma550RmZrnnRGZmuedE1g0knVPtGCwbv2b54kTWPfyhyB+/ZjniRGZmuVdT48iGDKqLUbu0v417/q1obGHo4Lpqh1ERf3+x2M8G5NOGprX07jOw2mGU3RvrVrGhaW1XbvW90YeOHhiNr3R08+S3e+yJN2dGRMVvi1RTv6gyapfePDxzl+IbWs143+fcAsuTeX+etMVlNL7SwsMzR5a0bd2Ihd3yTVdTiczMal8ArSX92l33cSIzs0yCYEOU1rTsLk5kZpaZa2RmlmtB0FJDJwnBiczMuqC1xm527ERmZpkE0OJEZmZ55xqZmeVaABvcR2ZmeRaEm5ZmlnMBLbWVx5zIzCybZGR/bXEiM7OMREuHvz9dPU5kZpZJ0tnvRGZmOZaMI3MiM7Oca3WNzMzyzDUyM8u9QLTU2F3yncjMLDM3Lc0s1wLRFLX1GxS1VT80s5qXDIjtVdJUjKR3SPq9pGclzZd0qKRBkmZLWpj+3aFYOU5kZpZZSzootthUgknAnRHxbmA/YD5wETAnIsYAc9L5TjmRmVkmEaIlepU0dUbSdsCRwLVJudEUEa8C44Fp6WbTgJOKxeREZmaZtaKSpiLeBawAfinpb5KukTQQGB4RywDSv8OKFeREZmaZJJ399SVNwBBJjxZMhT+EWg+8F/hpRBwArKWEZmRHfNbSzDJp6+wv0cqIGLuZdUuAJRHxUDr/e5JEtlzSiIhYJmkE0FBsJ66RmVlmLaGSps5ExD+BlyTtkS46BngGmAFMSJdNAG4pFo9rZGaWSZlH9p8P3CCpD/A8cAZJBWu6pDOBxcDJxQpxIjOzzFqLnJEsVUTMAzpqeh6TpRwnMjPLJLlovLZ6pZzIzCyTQGyosUuUnMjMLJMIig527W5OZGaWUUmDXbuVE5mZZRK4RmZmPYA7+80s1wL5xopmlm/Jz8HVVuqorWjMLAf8A71mlnNB+Ub2l4sTmZll5hqZmeVahFwjM7N8Szr7fYmSmeWaPCDWzPIt6ex3H5mZ5ZxH9ptZrnlk/1bg1dUtnD2xgaefbUKCa34wjEPH9ufH177K1b9cTX2dOP7YAVzxjSHVDtWAYUO25WsXHM+gdwwkIpgx83F+f9tczjjlMD78wX15dfV6ACZffy8PPvZClaOtHRl+fKRbOJGV2QXfWMmHjh7A764ZQVNTsG59K3f/ZR0zZq5l3pyR9O0rGlY2VztMS7W0tHL1lLtZ8HwD/fv35tr//jSPPv4iANNnPMZv/vhIlSOsPRGwodWJrMd67fVW7ntwPb+clPyeaJ8+ok+fOn427TUuPG8H+vZNquPDhviw14rGVWtpXLUWgPXrN7BoSSNDBm1T5ahqW9K0rK1EVlvR5NzzL25g6OA6PntBAwd+YDFnT2xg7bpWFj7fxP0PrefQ41/i6I8u4ZF5b1Q7VOvAO4dtx+7vGs4zC5YB8LHjD2DqpM9w0fnj2GZg3ypHV1ta0usti03dpaKJTNI4SX+X9JykLv2CcJ40Nwdzn3yTz0/Ynsdmj2Rgf3HF/6yiuRlWrW7lgdt35opvDuGUc/5JRFQ7XCvQv19vLvuP8fzomrtYt76JP/5pHqd8/hecccFUGlet4bzPHl3tEGtG2/CLUqbuUrFEJqkOuBo4DtgLOFXSXpXaXy3Yecd6dh5Rz8Hv7QfA/zlxG+Y++SY7jajno8cPRBIHHdCPXr1gZWNrlaO1NnV1vbjsovHM/vN87n1wIQCrVq+jtTWIgFtnPcGeY95Z5ShrSdK0LGXqLpXc00HAcxHxfEQ0Ab8Bxldwf1X3zmH17LJjPX9/rgmAu+5fx16792H8uIHcfX9y9mvBP5po2gBDBrtVXysuOn8ci15q5LczHt24bPAOAzf+f+QhY3hh8cpqhFazWtP79hebuksle513Al4qmF8CHFzB/dWESZcP5fRzl9O0IRg9sjdTfjiMgQN6ceaXl7PvUYvp01v8ctIwpNoah7O1es+eOzHu6L35x6IVTPnBBCAZanHsEXuy2+jkpM2yhtVc+ZNZ1QyzpiRnLbeeay07+qS+rWNI0jnAOQAjd8r/2bz99+nLwzN3edvy665206QWPTn/ZY4Y/19vW+4xY5tXiwNiK9m+WQIUfqJ3Bpa23ygiJkfE2IgYO3RwbWV5M+vY1tS0fAQYI2k08DJwCvDJCu7PzLrBVnXReEQ0SzoPmAnUAVMi4ulK7c/Muk+tDYitaKdURNwB3FHJfZhZ94oQzVtTIjOznmmraVqaWc9Uzj4ySYuA14EWoDkixkoaBPwWGAUsAv4tIlZ1Vk5t1Q/NLBfKfInS0RGxf0SMTecvAuZExBhgTjrfKScyM8ukbRxZBa+1HA9MS/+fBpxU7AFOZGaWWRnHkQUwS9Jj6eB4gOERsQwg/TusWCHuIzOzTCKgufQbKw6R9GjB/OSImFwwf3hELJU0DJgt6dmuxOREZmaZZWg2rizo+3qbiFia/m2QdDPJzSaWSxoREcskjQAaiu3ETUszy6RcfWSSBkratu1/4IPAU8AMYEK62QTglmIxuUZmZplFeYZfDAduTu8EUw/8OiLulPQIMF3SmcBi4ORiBTmRmVlm5bggPCKeB/brYHkjcEyWspzIzCyTCI/sN7PcEy3+OTgzy7sy9ZGVjROZmWWyVd2PzMx6qEj6yWqJE5mZZdadt7EuhROZmWUS7uw3s57ATUszyz2ftTSzXItwIjOzHsDDL8ws99xHZma5FohWn7U0s7yrsQqZE5mZZeTOfjPrEWqsSuZEZmaZuUZmZrkWQGurE5mZ5VkArpGZWd55HJmZ5Z8TmZnlm9zZb2Y9gGtkZpZrAeGzlmaWf05kZpZ3blqaWe45kZlZrnlArJn1BB4Qa2b5V2NnLYve5lGJT0n6Zjo/UtJBlQ/NzGqVorSpu5Ryv9qfAIcCp6bzrwNXVywiM6ttkWHqJqUksoMj4lzgDYCIWAX0qWhUZlbDlHT2lzKVUppUJ+lvkm5L5wdJmi1pYfp3h2JllJLINkiqI82vkoYCrSVFaGY9U3lrZF8C5hfMXwTMiYgxwJx0vlOlJLIfATcDwyRdDtwPfK/kEM2s52ktcSpC0s7ACcA1BYvHA9PS/6cBJxUrp+hZy4i4QdJjwDEk1yWcFBHzizzMzHqqbOPIhkh6tGB+ckRMLpj/IXAhsG3BsuERsQwgIpZJGlZsJ0UTmaSRwDrg1sJlEbG42GPNrGfKcEZyZUSM7bAM6USgISIek3TUlsRTyjiy20lysIB+wGjg78DeW7JjM8ux8pyRPBz4iKTjSXLLdpKuB5ZLGpHWxkYADcUKKtpHFhHviYh9079jgINI+snMzLosIi6OiJ0jYhRwCnBXRHwKmAFMSDebANxSrKzMI/sjYq6kf8n6uFIseGIAH9px/0oUbRWy8mJfHJInzQ+XZ0R+hQe7fh+YLulMYDFwcrEHlNJH9pWC2V7Ae4EVXY3QzHIuKPslShFxD3BP+n8jycnFkpXydVp4NqGZpM/sD1l2YmY9TJ4uGk8Hwm4TEf/eTfGYWQ5053WUpdhsIpNUHxHNkt7bnQGZWQ7kJZEBD5P0h82TNAP4HbC2bWVE3FTh2MysVuUokbUZBDQC7+et8WQBOJGZbYW6+xY9pegskQ1Lz1g+xVsJrE2NPQ0z61Y1dmPFzhJZHbANHf/ukxOZ2VYsTzWyZRFxabdFYmb5kaNEVlt1RzOrDTnrI8s0stbMtiJ5SWQR8Up3BmJm+aEau0d0KXeINTOrab51gZlll5empZlZh3LW2W9m1jEnMjPLPScyM8szUXtnLZ3IzCwb95GZWY/gRGZmuedEZmZ556almeWfE5mZ5Vr4rKWZ9QSukZlZ3rmPzMzyz4nMzHItcCIzs3wTblqaWQ/gRGZm+edEZma5V2OJzPfsN7Ns0rtflDJ1RlI/SQ9LelzS05K+ky4fJGm2pIXp3x2KheREZmbZRYlT594E3h8R+wH7A+MkHQJcBMyJiDHAnHS+U05kZpaZWkubOhOJNels73QKYDwwLV0+DTipWDxOZGaWWYam5RBJjxZM52xSjlQnaR7QAMyOiIeA4RGxDCD9O6xYPO7sN7Nssg2IXRkRYzdbVEQLsL+kdwA3S9qnKyG5RmZm2ZWnj+yt4iJeBe4BxgHLJY0ASP82FHu8E5mZZdI2sr8MZy2HpjUxJPUHjgWeBWYAE9LNJgC3FIvJTUszy0ytZRlINgKYJqmOpFI1PSJuk/RXYLqkM4HFwMnFCnIiM7NsynTReEQ8ARzQwfJG4JgsZTmRmVlmvtbSzPLPiczM8s41MjPLPycyM8s1/4pSzzfx2i9w8AkH8mrDas7ZdyIAZ//n6Rxy4oE0NzWz9B/LufKzV7N29boqR2oArc0bWHzdj4mWZqK1lW3fvR9Djxy3cX3jg3ez4q5b2e2CS6kfsE0VI60dtXiHWA+ILbNZU+/hkuMu32TZ3NmPc/Z7vsLn9v8qLy9cyqkXf7RK0Vl7qqtn5GlfZPRZ/87oM7/K2uefZf3LiwDY8Noq1r2wgPrtit5FZusTUdrUTZzIyuzJ++bz+itrNln22OwnaG1J6uLzH1zIkJ0GVyM064AkevXpC0C0thAtLSR1DmiYfQtD339i26wVKMfI/nJy07KbfeiMo/nz9AeqHYYViNZWFk25iqZVK9nhwMPpv9OuvL7gKeq33Z5+w3eqdni1pwZ/RaliNTJJUyQ1SHqqUvvIm09e8jFamluZc8N91Q7FCqhXL0af9VV2O/9bvLF0MW80LKXxgf/PkIK+MttUOe5HVk6VbFpOJbmS3YAPfPp9HHzCgXz/U5OqHYptRl2//gzYdTfWLHiKDa++wgvXXslzV3+X5tdWs2jKVTSvea3aIdaMWktkFWtaRsS9kkZVqvw8Gfuh/fnEhScx8ahv8eb6pmqHYwWa165BdXXU9etP64Ym1r6wgMGHvp8xF1y6cZvnrv4uo874ss9atgm6tSO/FFXvI0vvGHkOQD8GVDmaLXfJDV9i36P2Zvsh2/LrxT/jV9+ezikXfZTefeu5YtY3AJj/0AImfeEXVY7UAJrXvsayW2+E1lYigu323I9txuxd7bBqXq0Nv6h6IouIycBkgO00qMYOT3bfO+3tTcc7p9xVhUisFP2G7cjoMyd2us1u536jm6LJkRr7pFY9kZlZvtTigFgnMjPLJqJcN1Ysm0oOv7gR+Cuwh6Ql6d0ezawnKPM9+7dUJc9anlqpss2suty0NLN8C6DGmpZOZGaWXW3lMScyM8vOTUszy71aO2vpRGZm2dTg3S+cyMwsk2RAbG1lMicyM8vO9+w3s7xzjczM8s19ZGaWf7V3raUTmZll56almeWaf6DXzHqEGquR+XctzSy7MtzGR9Iuku6WNF/S05K+lC4fJGm2pIXp36K/kOxEZmaZqbW1pKmIZmBiROwJHAKcK2kv4CJgTkSMAeak851yIjOzbIJkQGwpU2fFRCyLiLnp/68D84GdgPHAtHSzacBJxUJyH5mZZSKi7ANi05+OPAB4CBgeEcsgSXaShhV7vBOZmWVXeiIbIunRgvnJ6S+nbSRpG+APwAUR8ZqkzOE4kZlZdqUnspURMXZzKyX1JkliN0TETeni5ZJGpLWxEUBDsZ24j8zMsilTH5mSqte1wPyIuKpg1QxgQvr/BOCWYiG5RmZmmZVwRrIUhwOnA09KmpcuuwT4PjA9/eW1xcDJxQpyIjOzjKIsA2Ij4n6S25t15JgsZTmRmVk2Qc2N7HciM7PsfK2lmeWdb6xoZvnnRGZmuRYBLbXVtnQiM7PsXCMzs9xzIjOzXAvA9+w3s3wLCPeRmVmeBe7sN7MewH1kZpZ7TmRmlm/luWi8nJzIzCybAMpzG5+ycSIzs+xcIzOzfPMlSmaWdwHhcWRmlnse2W9muec+MjPLtQiftTSzHsA1MjPLtyBaWqodxCacyMwsG9/Gx8x6BA+/MLM8CyBcIzOzXAvfWNHMeoBa6+xX1NBpVEkrgBerHUcFDAFWVjsIy6Snvma7RsTQLSlA0p0kx6cUKyNi3JbsrxQ1lch6KkmPRsTYasdhpfNrli+9qh2AmdmWciIzs9xzIusek6sdgGXm1yxHnMi6QURU9UMhqUXSPElPSfqdpAFbUNZUSR9P/79G0l6dbHuUpMO6sI9FkkrtTK6Iar9mlo0T2dZhfUTsHxH7AE3A5wtXSqrrSqERcVZEPNPJJkcBmROZWVZOZFuf+4Dd0trS3ZJ+DTwpqU7Sf0l6RNITkj4HoMSPJT0j6XZgWFtBku6RNDb9f5ykuZIelzRH0iiShPnltDZ4hKShkv6Q7uMRSYenjx0saZakv0n6OaDuPSSWdx4QuxWRVA8cB9yZLjoI2CciXpB0DrA6Iv5FUl/gL5JmAQcAewDvAYYDzwBT2pU7FPgFcGRa1qCIeEXSz4A1EXFlut2vgR9ExP2SRgIzgT2BbwH3R8Slkk4AzqnogbAex4ls69Bf0rz0//uAa0mafA9HxAvp8g8C+7b1fwHbA2OAI4EbI6IFWCrprg7KPwS4t62siHhlM3EcC+wlbaxwbSdp23QfH0sfe7ukVV18nraVciLbOqyPiP0LF6TJZG3hIuD8iJjZbrvjSa4T7oxK2AaSroxDI2J9B7F4ZLZ1mfvIrM1M4AuSegNI2l3SQOBe4JS0D20EcHQHj/0r8D5Jo9PHDkqXvw5sW7DdLOC8thlJbcn1XuC0dNlxwA5le1a2VXAiszbXkPR/zZX0FPBzkhr7zcBC4Engp8Cf2z8wIlaQ9GvdJOlx4LfpqluBj7Z19gP/Fxibnkx4hrfOnn4HOFLSXJIm7uIKPUfroXytpZnlnmtkZpZ7TmRmlntOZGaWe05kZpZ7TmRmlntOZGaWe05kZpZ7/wu6zcKr4il7iQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_confusion_matrix(\n", + " test_Y.values,\n", + " tf.round(outputs),\n", + " title='Confusion Matrix for Untrained Model'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7-HTkbQb-gYp" + }, + "source": [ + "## Define Metrics (Please complete this section)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "AYUyRka1-j87" + }, + "source": [ + "### Define Custom F1Score Metric\n", + "In this example, we will define a custom F1Score metric using the formula. \n", + "\n", + "**F1 Score = 2 * ((precision * recall) / (precision + recall))**\n", + "\n", + "**precision = true_positives / (true_positives + false_positives)**\n", + "\n", + "**recall = true_positives / (true_positives + false_negatives)**\n", + "\n", + "We use `confusion_matrix` defined in `tf.math` to calculate precision and recall.\n", + "\n", + "Here you can see that we have subclassed `tf.keras.Metric` and implemented the three required methods `update_state`, `result` and `reset_states`.\n", + "\n", + "### Please complete the result() method:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "PdUe6cqvbzXy" + }, + "outputs": [], + "source": [ + "class F1Score(tf.keras.metrics.Metric):\n", + "\n", + " def __init__(self, name: str='f1_score', **kwargs) -> None:\n", + " '''initializes attributes of the class'''\n", + "\n", + " # call the parent class init\n", + " super(F1Score, self).__init__(name=name, **kwargs)\n", + "\n", + " # Initialize Required variables\n", + " # true positives\n", + " self.tp: tf.Variable = tf.Variable(0, dtype='int32')\n", + " # false positives\n", + " self.fp: tf.Variable = tf.Variable(0, dtype='int32')\n", + " # true negatives\n", + " self.tn: tf.Variable = tf.Variable(0, dtype='int32')\n", + " # false negatives\n", + " self.fn: tf.Variable = tf.Variable(0, dtype='int32')\n", + "\n", + " def update_state(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> None:\n", + " '''\n", + " Accumulates statistics for the metric\n", + "\n", + " Args:\n", + " y_true: target values from the test data\n", + " y_pred: predicted values by the model\n", + " '''\n", + "\n", + " # Calulcate confusion matrix.\n", + " conf_matrix: tf.Tensor = tf.math.confusion_matrix(\n", + " y_true, y_pred, num_classes=2\n", + " )\n", + "\n", + " # Update values of true positives, true negatives, false positives and\n", + " # false negatives from confusion matrix.\n", + " self.tn.assign_add(conf_matrix[0][0])\n", + " self.tp.assign_add(conf_matrix[1][1])\n", + " self.fp.assign_add(conf_matrix[0][1])\n", + " self.fn.assign_add(conf_matrix[1][0])\n", + "\n", + " def result(self) -> tf.Tensor:\n", + " '''Computes and returns the metric value tensor.'''\n", + "\n", + " # Calculate precision and recall\n", + " if (tf.add(self.tp, self.fp) == 0):\n", + " precision: float = 1.0\n", + " recall: float = 1.0\n", + " else:\n", + " precision = self.tp / (tf.add(self.tp, self.fp))\n", + " recall = self.tp / (tf.add(self.tp, self.fn))\n", + "\n", + " # Return F1 Score\n", + " ### START CODE HERE ###\n", + " f1_score: tf.Tensor = 2 * (\n", + " (precision * recall)/(tf.add(precision, recall))\n", + " )\n", + " ### END CODE HERE ###\n", + "\n", + " return f1_score\n", + "\n", + " def reset_states(self) -> None:\n", + " '''Resets all of the metric state variables.'''\n", + "\n", + " # The state of the metric will be reset at the start of each epoch.\n", + " self.tp.assign(0)\n", + " self.tn.assign(0)\n", + " self.fp.assign(0)\n", + " self.fn.assign(0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING: AutoGraph could not transform > and will run it as-is.\n", + "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n", + "Cause: annotated name 'precision' can't be nonlocal (tmpr2a9i7d5.py, line 21)\n", + "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Test Code:\n", + "\n", + "test_F1Score: F1Score = F1Score()\n", + "\n", + "test_F1Score.tp = tf.Variable(2, dtype='int32')\n", + "test_F1Score.fp = tf.Variable(5, dtype='int32')\n", + "test_F1Score.tn = tf.Variable(7, dtype='int32')\n", + "test_F1Score.fn = tf.Variable(9, dtype='int32')\n", + "test_F1Score.result()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Expected Output:**\n", + "\n", + "```txt\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xiTa2CePAOTa" + }, + "source": [ + "We initialize the seprate metrics required for training and validation. In addition to our custom F1Score metric, we are also using `BinaryAccuracy` defined in `tf.keras.metrics`" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "7Pa_x-5-CH_V" + }, + "outputs": [], + "source": [ + "train_f1score_metric: F1Score = F1Score()\n", + "val_f1score_metric: F1Score = F1Score()\n", + "\n", + "train_acc_metric: tf.keras.metrics.Metric = tf.keras.metrics.BinaryAccuracy()\n", + "val_acc_metric: tf.keras.metrics.Metric = tf.keras.metrics.BinaryAccuracy()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1huOxRpEAxvf" + }, + "source": [ + "## Apply Gradients (Please complete this section)\n", + "\n", + "The core of training is using the model to calculate the logits on specific set of inputs and compute the loss(in this case **binary crossentropy**) by comparing the predicted outputs to the true outputs. We then update the trainable weights using the optimizer algorithm chosen. The optimizer algorithm requires our computed loss and partial derivatives of loss with respect to each of the trainable weights to make updates to the same.\n", + "\n", + "We use gradient tape to calculate the gradients and then update the model trainable weights using the optimizer.\n", + "\n", + "### Please complete the following function:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "MMPe25Dstn0v" + }, + "outputs": [], + "source": [ + "def apply_gradient(\n", + " optimizer: tf.keras.optimizers.Optimizer,\n", + " loss_object: tf.keras.losses.Loss,\n", + " model: Model,\n", + " x: Union[tf.data.Dataset, pd.DataFrame, pd.Series, np.ndarray],\n", + " y: Optional[Union[tf.data.Dataset, pd.DataFrame, pd.Series, np.ndarray]]\n", + ") -> tuple:\n", + " '''\n", + " applies the gradients to the trainable model weights\n", + "\n", + " Args:\n", + " optimizer: optimizer to update model weights\n", + " loss_object: type of loss to measure during training\n", + " model: the model we are training\n", + " x: input data to the model\n", + " y: target values for each input\n", + " '''\n", + "\n", + " with tf.GradientTape() as tape:\n", + " ### START CODE HERE ###\n", + " logits: tf.Tensor = model(x)\n", + " reshape_logits: tf.Tensor = tf.reshape(logits, logits.shape[0])\n", + " loss_value: tf.Tensor = loss_object(y, reshape_logits)\n", + "\n", + " gradients: list = tape.gradient(loss_value, model.trainable_weights)\n", + " optimizer.apply_gradients(zip(gradients, model.trainable_weights))\n", + " ### END CODE HERE ###\n", + "\n", + " return logits, loss_value" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.5393396 ]\n", + " [0.5312672 ]\n", + " [0.5367919 ]\n", + " [0.54731965]\n", + " [0.5336758 ]\n", + " [0.5345792 ]\n", + " [0.48115993]\n", + " [0.5367919 ]]\n", + "0.75001687\n" + ] + } + ], + "source": [ + "# Test Code:\n", + "\n", + "test_model: Model = tf.keras.models.load_model('./test_model')\n", + "test_logits: tf.Tensor\n", + "test_loss: tf.Tensor\n", + "test_logits, test_loss = apply_gradient(\n", + " optimizer, loss_object, test_model, norm_test_X.values, test_Y.values\n", + ")\n", + "\n", + "print(test_logits.numpy()[:8])\n", + "print(test_loss.numpy())\n", + "\n", + "del test_model\n", + "del test_logits\n", + "del test_loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Expected Output:**\n", + "\n", + "The output will be close to these values:\n", + "```txt\n", + "[[0.5516499 ]\n", + " [0.52124363]\n", + " [0.5412698 ]\n", + " [0.54203206]\n", + " [0.50022954]\n", + " [0.5459626 ]\n", + " [0.47841492]\n", + " [0.54381996]]\n", + "0.7030578\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JYM6GZPjB40r" + }, + "source": [ + "## Training Loop (Please complete this section)\n", + "\n", + "This function performs training during one epoch. We run through all batches of training data in each epoch to make updates to trainable weights using our previous function.\n", + "You can see that we also call `update_state` on our metrics to accumulate the value of our metrics. \n", + "\n", + "We are displaying a progress bar to indicate completion of training in each epoch. Here we use `tqdm` for displaying the progress bar. \n", + "\n", + "### Please complete the following function:" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3fHoh_hgz2PC" + }, + "outputs": [], + "source": [ + "def train_data_for_one_epoch(\n", + " train_dataset: tf.data.Dataset,\n", + " optimizer: tf.keras.optimizers.Optimizer,\n", + " loss_object: tf.keras.losses.Loss,\n", + " model: Model,\n", + " train_acc_metric: tf.keras.metrics.Metric,\n", + " train_f1score_metric: F1Score,\n", + " verbose: bool=True\n", + ") -> list:\n", + " '''\n", + " Computes the loss then updates the weights and metrics for one epoch.\n", + "\n", + " Args:\n", + " train_dataset: the training dataset\n", + " optimizer: optimizer to update model weights\n", + " loss_object: type of loss to measure during training\n", + " model: the model we are training\n", + " train_acc_metric: calculates how often predictions match labels\n", + " train_f1score_metric: custom metric we defined earlier\n", + " '''\n", + " losses: list = []\n", + "\n", + " #Iterate through all batches of training data\n", + " for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n", + "\n", + " #Calculate loss and update trainable variables using optimizer\n", + " ### START CODE HERE ###\n", + " logits: tf.Tensor\n", + " loss_value: tf.Tensor\n", + " logits, loss_value = apply_gradient(\n", + " optimizer, loss_object, model, x_batch_train, y_batch_train\n", + " )\n", + " losses.append(loss_value)\n", + " ### END CODE HERE ###\n", + "\n", + " #Round off logits to nearest integer and cast to integer for calulating\n", + " # metrics\n", + " logits: tf.Tensor = tf.round(logits)\n", + " logits = tf.cast(logits, 'int64')\n", + "\n", + " #Update the training metrics\n", + " ### START CODE HERE ###\n", + " train_acc_metric.update_state(y_batch_train, logits)\n", + " train_f1score_metric.update_state(y_batch_train, logits)\n", + " ### END CODE HERE ###\n", + "\n", + " #Update progress\n", + " if verbose:\n", + " print(f\"Training loss for step {step}: {loss_value:.4f}\")\n", + "\n", + " return losses" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7590486\n", + "0.6305345\n", + "0.5635759\n", + "0.5032073\n", + "0.4497043\n", + "0.41916853\n", + "0.42087567\n", + "0.34722382\n", + "0.30006957\n", + "0.2471482\n", + "0.24749258\n", + "0.3069599\n", + "0.30252105\n", + "0.19269618\n", + "0.20125803\n", + "0.22097261\n", + "0.3462088\n", + "0.031088982\n" + ] + } + ], + "source": [ + "# TEST CODE\n", + "\n", + "test_model: Model = tf.keras.models.load_model('./test_model')\n", + "\n", + "test_losses: list = train_data_for_one_epoch(\n", + " train_dataset,\n", + " optimizer,\n", + " loss_object,\n", + " test_model,\n", + " train_acc_metric,\n", + " train_f1score_metric,\n", + " verbose=False\n", + ")\n", + "\n", + "for test_loss in test_losses:\n", + " print(test_loss.numpy())\n", + "\n", + "del test_model\n", + "del test_losses" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Expected Output:**\n", + "\n", + "The losses should generally be decreasing and will start from around 0.75. For example:\n", + "\n", + "```\n", + "0.7600615\n", + "0.6092045\n", + "0.5525634\n", + "0.4358902\n", + "0.4765755\n", + "0.43327087\n", + "0.40585428\n", + "0.32855004\n", + "0.35755336\n", + "0.3651728\n", + "0.33971977\n", + "0.27372319\n", + "0.25026917\n", + "0.29229593\n", + "0.242178\n", + "0.20602849\n", + "0.15887335\n", + "0.090397514\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "d9RJq8BLCsSF" + }, + "source": [ + "At the end of each epoch, we have to validate the model on the test dataset. The following function calculates the loss on test dataset and updates the states of the validation metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "5gLJyAJE0YRc" + }, + "outputs": [], + "source": [ + "def perform_validation() -> list:\n", + " losses: list = []\n", + "\n", + " #Iterate through all batches of validation data.\n", + " for x_val, y_val in test_dataset:\n", + "\n", + " #Calculate validation loss for current batch.\n", + " val_logits: tf.Tensor = model(x_val)\n", + " val_loss: tf.Tensor = loss_object(y_true=y_val, y_pred=val_logits)\n", + " losses.append(val_loss)\n", + "\n", + " #Round off and cast outputs to either or 1\n", + " val_logits: tf.Tensor = tf.cast(tf.round(model(x_val)), 'int64')\n", + "\n", + " #Update validation metrics\n", + " val_acc_metric.update_state(y_val, val_logits)\n", + " val_f1score_metric.update_state(y_val, val_logits)\n", + "\n", + " return losses" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DLymSCkUC-CL" + }, + "source": [ + "Next we define the training loop that runs through the training samples repeatedly over a fixed number of epochs. Here we combine the functions we built earlier to establish the following flow:\n", + "1. Perform training over all batches of training data.\n", + "2. Get values of metrics.\n", + "3. Perform validation to calculate loss and update validation metrics on test data.\n", + "4. Reset the metrics at the end of epoch.\n", + "5. Display statistics at the end of each epoch.\n", + "\n", + "**Note** : We also calculate the training and validation losses for the whole epoch at the end of the epoch." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "OOO1x3VyuPUV" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Start of epoch 0\n", + "Training loss for step 0: 0.6414\n", + "Training loss for step 1: 0.5039\n", + "Training loss for step 2: 0.4533\n", + "Training loss for step 3: 0.4241\n", + "Training loss for step 4: 0.3874\n", + "Training loss for step 5: 0.3564\n", + "Training loss for step 6: 0.3103\n", + "Training loss for step 7: 0.2521\n", + "Training loss for step 8: 0.2243\n", + "Training loss for step 9: 0.2442\n", + "Training loss for step 10: 0.1588\n", + "Training loss for step 11: 0.1670\n", + "Training loss for step 12: 0.1737\n", + "Training loss for step 13: 0.1456\n", + "Training loss for step 14: 0.1666\n", + "Training loss for step 15: 0.1047\n", + "Training loss for step 16: 0.1636\n", + "Training loss for step 17: 0.1142\n", + "\n", + "Epoch 0: Train loss: 0.2773 Validation Loss: 0.1318, Train Accuracy: 0.9340, Validation Accuracy 0.9688, Train F1 Score: 0.9066, Validation F1 Score: 0.9451\n", + "\n", + "Start of epoch 1\n", + "Training loss for step 0: 0.1433\n", + "Training loss for step 1: 0.0735\n", + "Training loss for step 2: 0.1354\n", + "Training loss for step 3: 0.0889\n", + "Training loss for step 4: 0.0968\n", + "Training loss for step 5: 0.0775\n", + "Training loss for step 6: 0.0712\n", + "Training loss for step 7: 0.1093\n", + "Training loss for step 8: 0.2101\n", + "Training loss for step 9: 0.1881\n", + "Training loss for step 10: 0.1327\n", + "Training loss for step 11: 0.0880\n", + "Training loss for step 12: 0.0340\n", + "Training loss for step 13: 0.0994\n", + "Training loss for step 14: 0.1014\n", + "Training loss for step 15: 0.0458\n", + "Training loss for step 16: 0.1274\n", + "Training loss for step 17: 0.0341\n", + "\n", + "Epoch 1: Train loss: 0.1032 Validation Loss: 0.0905, Train Accuracy: 0.9722, Validation Accuracy 0.9688, Train F1 Score: 0.9592, Validation F1 Score: 0.9451\n", + "\n", + "Start of epoch 2\n", + "Training loss for step 0: 0.0601\n", + "Training loss for step 1: 0.1801\n", + "Training loss for step 2: 0.0864\n", + "Training loss for step 3: 0.0265\n", + "Training loss for step 4: 0.0535\n", + "Training loss for step 5: 0.0345\n", + "Training loss for step 6: 0.0427\n", + "Training loss for step 7: 0.0828\n", + "Training loss for step 8: 0.0267\n", + "Training loss for step 9: 0.2132\n", + "Training loss for step 10: 0.0169\n", + "Training loss for step 11: 0.0367\n", + "Training loss for step 12: 0.1306\n", + "Training loss for step 13: 0.1070\n", + "Training loss for step 14: 0.0723\n", + "Training loss for step 15: 0.0167\n", + "Training loss for step 16: 0.1844\n", + "Training loss for step 17: 0.0057\n", + "\n", + "Epoch 2: Train loss: 0.0765 Validation Loss: 0.0823, Train Accuracy: 0.9722, Validation Accuracy 0.9688, Train F1 Score: 0.9592, Validation F1 Score: 0.9451\n", + "\n", + "Start of epoch 3\n", + "Training loss for step 0: 0.0538\n", + "Training loss for step 1: 0.0361\n", + "Training loss for step 2: 0.0543\n", + "Training loss for step 3: 0.1211\n", + "Training loss for step 4: 0.1162\n", + "Training loss for step 5: 0.1754\n", + "Training loss for step 6: 0.0680\n", + "Training loss for step 7: 0.1673\n", + "Training loss for step 8: 0.0212\n", + "Training loss for step 9: 0.0246\n", + "Training loss for step 10: 0.0806\n", + "Training loss for step 11: 0.0286\n", + "Training loss for step 12: 0.0149\n", + "Training loss for step 13: 0.0249\n", + "Training loss for step 14: 0.1345\n", + "Training loss for step 15: 0.0230\n", + "Training loss for step 16: 0.0803\n", + "Training loss for step 17: 0.0189\n", + "\n", + "Epoch 3: Train loss: 0.0691 Validation Loss: 0.0807, Train Accuracy: 0.9757, Validation Accuracy 0.9688, Train F1 Score: 0.9641, Validation F1 Score: 0.9451\n", + "\n", + "Start of epoch 4\n", + "Training loss for step 0: 0.1294\n", + "Training loss for step 1: 0.0910\n", + "Training loss for step 2: 0.0628\n", + "Training loss for step 3: 0.1121\n", + "Training loss for step 4: 0.0556\n", + "Training loss for step 5: 0.0828\n", + "Training loss for step 6: 0.1098\n", + "Training loss for step 7: 0.1582\n", + "Training loss for step 8: 0.0331\n", + "Training loss for step 9: 0.0428\n", + "Training loss for step 10: 0.0114\n", + "Training loss for step 11: 0.0088\n", + "Training loss for step 12: 0.0115\n", + "Training loss for step 13: 0.0953\n", + "Training loss for step 14: 0.0174\n", + "Training loss for step 15: 0.1059\n", + "Training loss for step 16: 0.0140\n", + "Training loss for step 17: 0.1329\n", + "\n", + "Epoch 4: Train loss: 0.0708 Validation Loss: 0.0817, Train Accuracy: 0.9757, Validation Accuracy 0.9688, Train F1 Score: 0.9641, Validation F1 Score: 0.9451\n", + "\n" + ] + } + ], + "source": [ + "# Iterate over epochs.\n", + "epochs: int = 5\n", + "epochs_val_losses: list = []\n", + "epochs_train_losses: list = []\n", + "\n", + "for epoch in range(epochs):\n", + " print(f'Start of epoch {epoch}')\n", + " #Perform Training over all batches of train data\n", + " losses_train: list = train_data_for_one_epoch(\n", + " train_dataset,\n", + " optimizer,\n", + " loss_object,\n", + " model,\n", + " train_acc_metric,\n", + " train_f1score_metric\n", + " )\n", + "\n", + " # Get results from training metrics\n", + " train_acc: tf.Tensor = train_acc_metric.result()\n", + " train_f1score: tf.Tensor = train_f1score_metric.result()\n", + "\n", + " #Perform validation on all batches of test data\n", + " losses_val: list = perform_validation()\n", + "\n", + " # Get results from validation metrics\n", + " val_acc: tf.Tensor = val_acc_metric.result()\n", + " val_f1score: tf.Tensor = val_f1score_metric.result()\n", + "\n", + " #Calculate training and validation losses for current epoch\n", + " losses_train_mean: np.float32 = np.mean(losses_train)\n", + " losses_val_mean: np.float32 = np.mean(losses_val)\n", + " epochs_val_losses.append(losses_val_mean)\n", + " epochs_train_losses.append(losses_train_mean)\n", + "\n", + " print(f'\\nEpoch {epoch}: Train loss: {losses_train_mean:.4f} \\\n", + " Validation Loss: {losses_val_mean:.4f}, Train Accuracy: {train_acc:.4f}, \\\n", + " Validation Accuracy {val_acc:.4f}, Train F1 Score: {train_f1score:.4f}, \\\n", + " Validation F1 Score: {val_f1score:.4f}\\n')\n", + "\n", + " #Reset states of all metrics\n", + " train_acc_metric.reset_states()\n", + " val_acc_metric.reset_states()\n", + " val_f1score_metric.reset_states()\n", + " train_f1score_metric.reset_states()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JoLxueMdzm14" + }, + "source": [ + "## Evaluate the Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "6EGW3HVUzqBX" + }, + "source": [ + "### Plots for Evaluation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "t8Wsr6wG0T4h" + }, + "source": [ + "We plot the progress of loss as training proceeds over number of epochs." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "MsmF_2n307SP" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAWOUlEQVR4nO3dfZBkdX2o8ec7PdO7vOyy7osKu4AYEVyIII6Eqtybi8SXxTfKxJQggUiRUCRgYSQRvDfJVXMrFevWNcQSpDZIiAEhJhJFahUwyouCwIAIrkhqQ2B3s1AsLPvGwvbOzPf+cXrYmWF2pmfp2Z75zfOpOtV9XnrmRwvPHE+fczoyE0nSzNfV6QFIktrDoEtSIQy6JBXCoEtSIQy6JBXCoEtSIQy6JBXCoGtWiIgnIuJdnR6HNJUMuiQVwqBr1oqIORFxWURsaE6XRcSc5rrFEXFzRGyOiE0RcVdEdDXXXRIR/xUR2yLisYj4zc7+k0iV7k4PQOqg/wWcBBwPJPBt4M+APwcuBtYDS5rbngRkRBwFXAi8IzM3RMQbgNq+HbY0NvfQNZudCXw+M5/JzI3A54Czmut2AQcDh2fmrsy8K6sbHw0Ac4DlEdGTmU9k5n90ZPTSKAZds9khwJPD5p9sLgP4v8Aa4NaIeDwiLgXIzDXAJ4HPAs9ExA0RcQjSNGDQNZttAA4fNn9YcxmZuS0zL87MNwIfBD41dKw8M7+emf+t+doEvrBvhy2NzaBrNumJiLlDE3A98GcRsSQiFgN/AVwLEBEfiIg3RUQAW6kOtQxExFERcUrzw9OXgBeb66SOM+iaTVZRBXhomgv0AQ8DjwAPAv+nue2RwPeB7cA9wBWZeTvV8fO/Bp4FngZeC/zPffZPII0j/IILSSqDe+iSVIgJgx4RV0fEMxHx8z2sj4j4UkSsiYiHI+KE9g9TkjSRVvbQrwFWjLP+VKrjjUcC5wFfefXDkiRN1oRBz8w7gU3jbHIa8LWs/ARYEBEHt2uAkqTWtOPS/6XAumHz65vLnhq9YUScR7UXzwEHHPD2o48+ug2/XpJmjwceeODZzFwy1rp2BD3GWDbmqTOZuRJYCdDb25t9fX1t+PWSNHtExJN7WteOs1zWA4cOm19G82o7SdK+046g3wSc3Tzb5SRgS2a+4nCLJGlqTXjIJSKuB04GFkfEeuB/Az0AmXkl1dV376O6kdEO4JypGqwkac8mDHpmnjHB+gQuaNuIJEl7xStFJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQLQU9IlZExGMRsSYiLh1j/UER8Z2I+FlErI6Ic9o/VEnSeCYMekTUgMuBU4HlwBkRsXzUZhcAv8jM44CTgf8XEfU2j1WSNI5W9tBPBNZk5uOZ2QBuAE4btU0C8yIigAOBTUB/W0cqSRpXK0FfCqwbNr++uWy4LwNvATYAjwAXZebg6B8UEedFRF9E9G3cuHEvhyxJGksrQY8xluWo+fcCDwGHAMcDX46I+a94UebKzOzNzN4lS5ZMerCSpD1rJejrgUOHzS+j2hMf7hzgxqysAf4TOLo9Q5QktaKVoN8PHBkRRzQ/6DwduGnUNmuB3wSIiNcBRwGPt3OgkqTxdU+0QWb2R8SFwC1ADbg6M1dHxPnN9VcCfwlcExGPUB2iuSQzn53CcUuSRpkw6ACZuQpYNWrZlcOebwDe096hSZImwytFJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCmHQJakQBl2SCtFS0CNiRUQ8FhFrIuLSPWxzckQ8FBGrI+KO9g5TkjSR7ok2iIgacDnwbmA9cH9E3JSZvxi2zQLgCmBFZq6NiNdO1YAlSWNrZQ/9RGBNZj6emQ3gBuC0Udt8DLgxM9cCZOYz7R2mJGkirQR9KbBu2Pz65rLh3gy8JiJuj4gHIuLssX5QRJwXEX0R0bdx48a9G7EkaUytBD3GWJaj5ruBtwPvB94L/HlEvPkVL8pcmZm9mdm7ZMmSSQ9WkrRnEx5Dp9ojP3TY/DJgwxjbPJuZLwAvRMSdwHHAv7dllJKkCbWyh34/cGREHBERdeB04KZR23wb+O8R0R0R+wO/Bjza3qFKksYz4R56ZvZHxIXALUANuDozV0fE+c31V2bmoxHxPeBhYBC4KjN/PpUDlySNFJmjD4fvG729vdnX19eR3y1JM1VEPJCZvWOt80pRSSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSqEQZekQhh0SSpES0GPiBUR8VhErImIS8fZ7h0RMRARH2nfECVJrZgw6BFRAy4HTgWWA2dExPI9bPcF4JZ2D1KSNLFW9tBPBNZk5uOZ2QBuAE4bY7tPAN8Enmnj+CRJLWol6EuBdcPm1zeXvSwilgIfBq4c7wdFxHkR0RcRfRs3bpzsWCVJ42gl6DHGshw1fxlwSWYOjPeDMnNlZvZmZu+SJUtaHaMkqQXdLWyzHjh02PwyYMOobXqBGyICYDHwvojoz8xvtWWUkqQJtRL0+4EjI+II4L+A04GPDd8gM48Yeh4R1wA3G3NJ2rcmDHpm9kfEhVRnr9SAqzNzdUSc31w/7nFzSdK+0coeOpm5Clg1atmYIc/Mj7/6YUmSJssrRSWpEAZdkgph0CWpEAZdkgph0CWpEAZdkgph0CWpEAZdkgph0CWpEAZdkgph0CWpEAZdkgph0CWpEAZdkgph0CWpEAZdkgph0CWpEAZdkgph0CWpEDMu6C+9BKtWQWanRyJJ08uMC/p118H73w/veAd85zuGXZKGzLign302XH01PP88fOhD0NsLN91k2CVpxgW9pwfOOQd++Uv4+7+HLVvgtNPg7W+Hb33LsEuavWZc0If09MDHP16F/ZprYNs2+PCH4YQT4F//FQYHOz1CSdq3ZmzQh3R3w+/9Hjz6KPzDP8ALL8Bv/Ra87W1w442GXdLsMeODPqS7uzq+/otfwNe+Vp0N89u/DccfD9/8pmGXVL5igj6kuxvOOqsK+7XXws6d8JGPwHHHwb/8i2GXVK7igj6kVoMzz6zCft11sGsX/M7vVGH/xjcMu6TyFBv0IbUafOxjsHo1fP3rMDAAH/0ovPWt8E//VM1LUgmKD/qQWg3OOAMeeQSuv77aQz/9dPjVX4UbbjDskma+WRP0IbVaFfJHHqlCHlGF/thjq9Abdkkz1awL+pBarTr08sgj1TH1oUMzxx67+9CMJM0kszboQ7q6qg9LH34Y/vmfq7NkzjwTjjmmOkumv7/TI5Sk1sz6oA/p6qpOb/zZz6rTG+v16vTHY46Bf/xHwy5p+msp6BGxIiIei4g1EXHpGOvPjIiHm9PdEXFc+4e6b3R1VRckPfRQdUHS3LnVBUvLl1cXLBl2SdPVhEGPiBpwOXAqsBw4IyKWj9rsP4H/kZlvBf4SWNnuge5rXV3VLQR++tPqFgIHHFDdYuAtb6luMWDYJU03reyhnwisyczHM7MB3ACcNnyDzLw7M59vzv4EWNbeYXZOV1d1068HH6zu5njggdVNwY4+uropmGGXNF20EvSlwLph8+uby/bkXOC7Y62IiPMioi8i+jZu3Nj6KKeBiOo2vQ8+CN/+NsyfX93G96ijqvuz79rV6RFKmu1aCXqMsWzMu45HxDupgn7JWOszc2Vm9mZm75IlS1of5TQSUX2xxgMPVF+ssWABnHtuFfavftWwS+qcVoK+Hjh02PwyYMPojSLircBVwGmZ+Vx7hjd9RcAHPwh9fdVX4S1aBL//+/DmN8NVVxl2SfteK0G/HzgyIo6IiDpwOnDT8A0i4jDgRuCszPz39g9z+oqAD3wA7rsPbr4ZliyBP/gDOPJI+Lu/g0aj0yOUNFtMGPTM7AcuBG4BHgW+kZmrI+L8iDi/udlfAIuAKyLioYjom7IRT1MR1ZdX33svrFoFr3sdnHdeFfaVKw27pKkX2aEv4ezt7c2+vnK7nwm33AKf/WwV+cMOg898pvogdc6cTo9O0kwVEQ9kZu9Y67xSdIpEwIoVcM898L3vwSGHwB/+YbXH/pWvVF+8IUntZNCnWAS8971w993VHvuyZfBHfwRvehNccYVhl9Q+Bn0fiYD3vAd+/GO49dbqEMwFF8Cv/Apcfnn1HaiS9GoY9H0sAt79bvjRj+C22+CII+DCC6uwf/nLhl3S3jPoHRIB73oX3Hkn/Nu/VUH/xCeqxy99CV58sdMjlDTTGPQOi4BTToE77oAf/KA6tn7RRVXY//ZvDbuk1hn0aSIC3vnOKuw//GF1xeknPwlvfCNcdplhlzQxgz4NnXwy3H57Ffajj4Y//uMq7H/zN7BjR6dHJ2m6mnFBf+jphzj/5vP5q7v+imsfvpY7n7yTJzY/wa6B8m6ecvLJVdTvuKP6go1PfaoK+xe/aNglvVJ3pwcwWU9ufpIbH72RjTtG3n63K7o4ZN4hHHbQYdU0/zAOX3D47vmDDmPB3AUdGvWr8xu/UX1wetdd8LnPwcUXwxe+AJ/+NJx/fvXlG5I0Yy/937FrB+u2rGPtlrW7p61reXLzk6zdspZ1W9fRGBh5A5X5c+aPCP7Q86HwHzLvELq7pv/fuB/9qAr7979f3QzsT/+0uljJsEvlG+/S/xkb9IkM5iDPvPAMa7fsjvxQ9IeWPffiyLv8dkUXS+ct3R35YdEfWjZ/zvwpG/Nk/fjHVdhvu60K+5/8SRX2Aw/s9MgkTZVZGfRWvNB4gXVbd+/lP7n5yZeDv3bLWtZtWceuwZHH5g+ac9DIyB808rDOwfMO3ud7+XffXYX91lth8eIq7BdcYNilEhn0vTSYgzy9/emRh3WG4r+l2uvf9OKmEa+pRY1l85eNiPzo8M+bM29KxnvPPfD5z1c3A1u0aHfY503Nr5PUAQZ9Cm1vbH/5WP5Q5IdP67auo39w5DdJL5i7YI97+IcfdDivP/D11Lpqez2me++t9ti/+11YuLC6UOm446rIL15cPS5cCLW9/xWSOsSgd9DA4MAr9vJHh//5l54f8Zruru6Re/ljnLFzYH3i4yn33VeFfdWqV66LqL4PdSjwrTwuXAg9Pe16ZyTtDYM+zW3bue2Vh3WGnbGzfut6BnJgxGteM/c1uyM/xhk7rz/w9XRFdZnBhg3w9NPw7LPw3HMTP453jvtBB03uj8CiRVCvT+W7J80uBn2GGxgc4KntT+3xjJ21W9ay+aXNI17T09Xz8l7+svnL2L9nf+q1eksTA3V27qjz4vZqemFrnR3b6mzbXGf7ljpbNtXZ+nydzc9V06aN1TYM1GGwBsSIscybN7k/AIsXw9y5+/ANlmYQgz4LbN25dY8f3K7fup6X+l+iMdAYMU2FIOiOOt1RpyuricE69NfJ/jqDu+oMNOr076weGahD/5zqcdjU3VVnv3qd/efUOWBunQP3qzNv/zrzD6imBfOqaeFBdRbOr7PoNXXm79/aH6x6rU5Prefl/wcjzSTjBX36X0WjlsyfM59jX3ssx7722Ja2z0z6B/tfEfkpmQbHXr5z14vsaGzhxUaDl3Y12NnfoNFfbd+fDV7MBttpkF0jP1Tmheb09Kt7z7q7useOfVcV+4ggiBHPRz9O+3XN513RNeZr2rVuaMcwSTKTJF/+92xo2UTrp/Q1w+b3Zqzt/uf76DEf5dwTzn11/wKPwaDPUhFBT62HnloPBzC9LzEdzEF2DeyiMdBgx84GGzc1ePrZ6nHjpgbPPd/g2c0Nnt/SYNPWBpu3NdiyrcHWF6pp+0sN6GpAbeTUX2swWG8QBzToOqBBbb8G/fs1YE6Drtog0ZV0dSURSXRV89XzkY/E4LDnu9cRCQzu3qa5fmibZLC5Tb4iHoM5+Ipl03HdUKSGGx5+YMQfgYnWl/aa0euHlu0cmJrvnjTomva6oos53XOY0z2HeXPgdfPh2De0/vqBAdi8eeIPhJ99vHr+3HPVd73u2gWNBvT3T/w7Xq16vTqDaOhx+PN6Heo9469v57LJvKanh+oPWIz83CRz9zQ46Pzo+eVTdLaYQVfxarXqw9ZFi6r7zE9WZhX1RmN35HftGvl8Xy/bsaP11071x2S1Wrz8Pg1NGt8ll8Dxx7f/5xp0aQIRu/dGZ6KBgan/gwPV+9TVVT0OTc6PPb9w4dT8b23QpcLVarDfftWksnneliQVwqBLUiEMuiQVwqBLUiEMuiQVwqBLUiEMuiQVwqBLUiEMuiQVwqBLUiFaCnpErIiIxyJiTURcOsb6iIgvNdc/HBEntH+okqTxTBj0iKgBlwOnAsuBMyJi+ajNTgWObE7nAV9p8zglSRNoZQ/9RGBNZj6emQ3gBuC0UducBnwtKz8BFkTEwW0eqyRpHK3cbXEpsG7Y/Hrg11rYZinw1PCNIuI8qj14gO0R8dikRrvbYuDZvXztbOV7Njm+X5Pj+zU5r+b9OnxPK1oJeoyxbPQt7FvZhsxcCaxs4XeOP6CIvj19SarG5ns2Ob5fk+P7NTlT9X61cshlPXDosPllwIa92EaSNIVaCfr9wJERcURE1IHTgZtGbXMTcHbzbJeTgC2Z+dToHyRJmjoTHnLJzP6IuBC4BagBV2fm6og4v7n+SmAV8D5gDbADOGfqhgy04bDNLOR7Njm+X5Pj+zU5U/J+RfqNrpJUBK8UlaRCGHRJKsSMC/pEtyHQbhFxdUQ8ExE/7/RYZoKIODQifhgRj0bE6oi4qNNjms4iYm5E3BcRP2u+X5/r9JhmgoioRcRPI+Lmdv/sGRX0Fm9DoN2uAVZ0ehAzSD9wcWa+BTgJuMB/v8a1EzglM48DjgdWNM9y0/guAh6dih88o4JOa7chUFNm3gls6vQ4ZorMfCozH2w+30b1H93Szo5q+mre6mN7c7anOXmWxTgiYhnwfuCqqfj5My3oe7rFgNRWEfEG4G3AvZ0dyfTWPHzwEPAMcFtm+n6N7zLg08DgVPzwmRb0lm4xIL0aEXEg8E3gk5m5tdPjmc4ycyAzj6e6OvzEiDi202OariLiA8AzmfnAVP2OmRZ0bzGgKRURPVQxvy4zb+z0eGaKzNwM3I6f2Yzn14EPRcQTVIeLT4mIa9v5C2Za0Fu5DYG0VyIigK8Cj2bmFzs9nukuIpZExILm8/2AdwG/7Oyopq/M/ExmLsvMN1C16weZ+bvt/B0zKuiZ2Q8M3YbgUeAbmbm6s6OaviLieuAe4KiIWB8R53Z6TNPcrwNnUe05PdSc3tfpQU1jBwM/jIiHqXa2bsvMtp+Kp9Z56b8kFWJG7aFLkvbMoEtSIQy6JBXCoEtSIQy6JBXCoEtSIQy6JBXi/wNBbZue3L6YSAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "def plot_metrics(\n", + " train_metric: list,\n", + " val_metric: list,\n", + " metric_name: str,\n", + " title: str,\n", + " ylim: float=5\n", + ") -> None:\n", + " plt.title(title)\n", + " plt.ylim(0,ylim)\n", + " plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))\n", + " plt.plot(train_metric,color='blue',label=metric_name)\n", + " plt.plot(val_metric,color='green',label='val_' + metric_name)\n", + "\n", + "plot_metrics(epochs_train_losses, epochs_val_losses, \"Loss\", \"Loss\", ylim=1.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "27fXX7Yqyu5S" + }, + "source": [ + "We plot the confusion matrix to visualize the true values against the values predicted by the model." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_9n2XJ9MwpDS" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAATIAAAEQCAYAAAAzovj4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAdI0lEQVR4nO3deZwdVZ338c833QkkhC0LiFkIgiCLD4sRBQTjsIsizqgDA6I+IAqiozLiOiJug48MAjOgRERQBIEAirIkiGZClEmAENawbwlBAlmALGTp/j1/VHW43XT3rercpar7+3696pXcqntP/W7de399zqlTdRQRmJmV2aBmB2BmtqGcyMys9JzIzKz0nMjMrPScyMys9JzIzKz0+nUikzRU0h8kvSzpmg0o51hJ02oZWzNIulnSJ/r42u9LeknS32sdV7NI+oaki+tU9tOSDqpH2Ru6T0kTJIWk1kbE1QiFSGSS/kXSXZKWS3o+/cG9pwZFfwTYGhgZER/tayER8ZuIOKQG8XQiaVL6hbquy/rd0/XTM5bzHUmXV3teRBweEZf1Ic5xwGnALhHxpryv76a8bn9Iki6V9P2MZUyXdOKGxBERP4yIDSqjL9L3GZKO7LL+3HT9JxsdU9k1PZFJ+jJwLvBDkqQzHrgQ+FANit8WeDQi1tWgrHp5EdhX0siKdZ8AHq3VDpTYkM96W2BxRCzqw76b8le/BLWNR0k+Z2B9vB8FnmhaRGUWEU1bgM2B5cBHe3nORiSJbmG6nAtslG6bBCwgqS0sAp4HPpVuOxNYA6xN93EC8B3g8oqyJwABtKaPPwk8CbwKPAUcW7F+ZsXr9gXuBF5O/923Ytt04HvAX9NypgGjenhvHfH/DPhcuq4lXfdtYHrFc88D5gOvAHcD+6frD+vyPu+tiOMHaRyrgB3SdSem238KTKko/0fAbYC6xHhQ+vr2tPxL0/VHAg8Cy9Jyd654zdPAV4H7gNUdx7en416x/lLg+5XHHDgbWJp+Hoen234AtAGvpTH9d7o+gM8BjwFP9Xbc0m3rvw8VMX0CeBZ4CfhmxXMHAV8jSTSLgauBERXbPw48k277ZnoMDurhc780fV9/B7ZM130AuDl9z5+s2Oe30nIXAb8CNs+yz97i7en4l3lpdiI7DFjX2wEFvgv8L7AVMBr4G/C9dNuk9PXfBQYD7wdWVnw51n9Re3i8/gMFNkm/7Dul27YBdq38UaX/H0Hyw/p4+rpj0scj0+3T0y/PjsDQ9PFZPby3SSRJa19gVrru/cBU4EQ6J7LjgJHpPk9LfwQbd/e+KuJ4Ftg1fc1gOieyYSS1gk8C+5P8cMf2FmfF4x2BFcDBabmnA48DQ9LtTwNzgXHA0G7K6/aHxBsT2Vrg0yTJ/WSSP2SqeH8ndnl9ALemn9HQPMetIqafp5/b7iRJeOd0+xdJvodjSf64XgRcmW7bhSShHpBuO4fke9lbIvs+MBk4OV13Ncl3qTKR/d/0uL4FGA5cB/w6yz6rxNvt8S/z0uym5Ujgpei96Xcs8N2IWBQRL5LUtD5esX1tun1tRNxE8uHu1Md42oHdJA2NiOcj4sFunnME8FhE/Doi1kXElcDDwAcrnvPLiHg0IlaRfEH36G2nEfE3YISknYDjSf7ydn3O5RGxON3nf5J8Oau9z0sj4sH0NWu7lLeS5Ed+DnA58PmIWFClvA7/DNwYEbem5Z5N8uPft+I550fE/PQY9NUzEfHziGgDLiP547J1ldf8R0Qs6dhvH47bmRGxKiLuBe4lSWgAnyGpoS2IiNUkSfAjaZPwI8AfI2JGuu3fSb5L1fwKOF7S5sB7gd912X4scE5EPBkRy4GvA0dn3Gdv8fY7zU5ki4FRVQ7um0mqzx2eSdetL6NLIlxJ8tcrl4hYQfID/SzwvKQbJb0tQzwdMY2peFx5Zi9rPL8GTgXeB1zfdaOk0yTNS8/ALiNplo+qUub83jZGxGySprRIEm5WnY5BRLSn+6o8Br3tu+PzGtxl/WCSP0wd1h/HNPFC9WPZab99OG49fXbbAtdLWpaWM4+kebs1yfFYv9/0u7S4SpxExEySVsa3SJJS16Tf3Xe/NeM+e4u332l2IruDpJ/jqF6es5DkQ+kwPl3XFytImlQdOp2Bi4ipEXEwyV/+h0maGdXi6YjpuT7G1OHXwCnATRU/WgAk7U/S5/QxkmbzFiT9c+oIvYcye721iaTPkdRQFpI0D7PqdAwkiaQZWXkMetv38yQJa0KX9dvxxj8SPan6njMctzzmk/TRbVGxbBwRz5G8n3EV+x1G0trI4nKSJu8bauF0/91fB7yQYZ+9xdvvNDWRRcTLJJ3aF0g6StIwSYMlHS7p/6VPuxL4lqTRkkalz6861KAHc4EDJI1Pq/Nf79ggaWtJR0rahKRvZDnJX7CubgJ2TIeMtEr6Z5L+ij/2MSYAIuIpkubFN7vZvCnJF/hFoFXSt4HNKra/AEzIc2ZS0o4k/TTHkTTVT5fUaxO4wtXAEZIOlDSY5Ie4mqT/sqq0qXgt8ANJI9PP/BiS43hzxhheIOk76k2145bHz9J4twVIv48dZ9anAB+Q9B5JQ0j6bLN+FueT9DXO6GbblcCXJG0naTjJmf2r0hZItX32Fm+/0+waGRFxDvBlkur1iyR/SU7l9f6C7wN3kZwBux+Yk67ry75uBa5Ky7qbzslnEMkPciGwhCSpnNJNGYtJzjCdRlKVPx34QES81JeYupQ9MyK6q21OJfmBP0pSY3mNzk2ojsG+iyXNqbaftCl/OfCjiLg3Ih4DvgH8WtJGGeJ8hCQB/hfJSYIPAh+MiDXVXlvhFJLjfB/JGblTgSMi4oWMrz+PpM9nqaTze3hOteOWx3nADcA0Sa+SdKS/CyDtS/0ccAVJTWkpyUmcqtL+vNsiorsa5iUkNfUZJGdtXwM+n3GfPcbbH6n742e1IOkwki9UC3BxRJzV5JCsCkmXkPyhWhQRuzU7Hsum6TWy/kpSC3ABcDhJk+kYSbs0NyrL4FKSYUFWIk5k9bM38Hh66nwN8Ftqc7WC1VFEzCBp8lqJOJHVzxg698csoPPwBDOrESey+unuFL87JM3qwImsfhZQMc6H5FKRvo5/M7NeOJHVz53AW9MxQEOAo0lOh5tZjTmR1Uk6aPFUkrFM84Cre7h20wpE0pUkV5zsJGmBpBOaHZNV53FkZlZ6rpGZWek5kZlZ6TmRmVnpOZGZWek5kTWApJOaHYPl48+sXJzIGsM/ivLxZ1YiTmRmVnqFGkc2akRLTBjX9Tbu5ffi4jZGj2xpdhh18eh9w6o/qYTWsprBVL3HZOm8xgrWxOq+3Op7vUPft0ksXtLdzZPf6O77Vk+NiLrfFqlQM6pMGDeY2VPHVX+iFcahb856d2wrgllx2waXsXhJG7Onjs/03JZtHqs2QU5NFCqRmVnxBdCeaba7xnEiM7NcgmBtZGtaNooTmZnl5hqZmZVaELQV6CQhOJGZWR+0F+xmx05kZpZLAG1OZGZWdq6RmVmpBbDWfWRmVmZBuGlpZiUX0FasPOZEZmb5JCP7i8WJzMxyEm3dzj/dPE5kZpZL0tnvRGZmJZaMI3MiM7OSa3eNzMzKzDUyMyu9QLQV7C75TmRmlpublmZWaoFYE7WZg0LSl4ATSVqs9wOfAoYBVwETgKeBj0XE0t7KKVb90MwKLxkQOyjT0htJY4AvABMjYjegBTga+BpwW0S8FbgtfdwrJzIzy60tHRRbbcmgFRgqqZWkJrYQ+BBwWbr9MuCoLIWYmWUWIdoicx1olKS7Kh5PjojJSTnxnKSzgWeBVcC0iJgmaeuIeD59zvOStqq2EycyM8utPfvwi5ciYmJ3GyRtSVL72g5YBlwj6bi+xONEZma5JJ39NUkdBwFPRcSLAJKuA/YFXpC0TVob2wZYVK0g95GZWS616uwnaVK+W9IwSQIOBOYBNwCfSJ/zCeD31QpyjczMcmurwTiyiJglaQowB1gH3ANMBoYDV0s6gSTZfbRaWU5kZpZLLUf2R8QZwBldVq8mqZ1l5kRmZrm1Zz9r2RBOZGaWS3LRuBOZmZVYINbW6BKlWnEiM7NcIsgzILYhnMjMLCflGRDbEE5kZpZL4BqZmfUD7uw3s1IL5Bsrmlm5JdPBFSt1FCsaMysBT9BrZiUXeGS/mfUDrpGZWalFyDUyMyu3pLPflyiZWanlumd/QziRmVkuSWd/sfrIipVWzawU2hiUaemNpJ0kza1YXpH0RUkjJN0q6bH03y2rxeNEZma5dIzsz7L0Wk7EIxGxR0TsAbwDWAlcTx8m6HXTssbOvWgZv7jiFSTYbechXPKTrXjkibWc8tVFLF8RbDuulcsveBObbeq/IUUzeuxITr/sVEa8aQva24Obfv4nrj//pmaHVUgZJhbJ60DgiYh4RtKHgEnp+suA6cBXe3uxf0019Nzz6/ivXyxj9i1juW/6eNra4Le/X85Jpy3ih98Yxb1/Gc9Rhw/n7AuXNjtU60bbujYu+rdfccKuX+IL+3yDI085lPE7j212WIUTAWvbB2VaSCforVhO6qHYo4Er0/93mqAXqDpBrxNZja1rg1WvBevWBStXtfPmrVt55Ik1HLDPxgAcfMBQrrtxeZOjtO4s+fsyHr/nKQBWLX+NZ+c9x6gxI5ocVfEkTctBmRbSCXorlsldy5M0BDgSuKavMTmR1dCYbVo57bNbMGHi04zZ/Sk233QQh0waxm5v24gbpq4AYMofljN/4bomR2rVbL3taHbYczsenvVYs0MppLb0estqS0aHA3Mi4oX08QvpxLwUYoJeSYdJekTS45KqdtiV3dJlbdwwdQVPzJrAgrnbsWJlcPmUV7n4nK248Jcv885D5vPqimDIkGKdurbONt5kY7495d/46Zd+ycpXVzU7nMLpGH6xoZ39FY7h9WYlFGmCXkktwAXAwcAC4E5JN0TEQ/XaZ7P96fZVTBjfyuhRyajnD79/E+64axXHfWRTpl41BoBHn1jDTX9a0cwwrRctrS2cMeU0/nzF7cy8fnazwymo2l2iJGkYSY74TMXqsyjQBL17A49HxJMAkn4LfAjot4ls/JhWZt29mpUr2xk6VPx55iom7r4Ri15ax1ajWmlvD35w7lJOOn7zZodqPTjt4pN59uHnuPYnf2x2KIVWq3v2R8RKYGSXdYsp0AS9Y4D5FY8XAO+q4/6a7l17bcw/fWATJh4yn9ZWscduG/Hp4zbnol+9zIWXvgwktbRPHb1pkyO17uy639s4+Pj38uR9z/CzOT8G4JJvXsHsm+9pcmTFkpy1HDjXWnaXsuMNT0pOx54ESY2m7L7zlZF85yud/sDwhU9vwRc+vUWTIrKsHvzrwxw8qGorZsAr4q2u69nZvwAYV/F4LLCw65MiYnLHqdnRI4uV5c2se+3plHDVlkapZxXoTuCtkrYDniMZ8PYvddyfmTVAES8ar1sii4h1kk4FpgItwCUR8WC99mdmjTOgbqwYETcBvljNrB+JEOsGUiIzs/5pwDQtzax/GlB9ZGbWfzmRmVmpFXEcmROZmeXWyDFiWTiRmVkuEbCu3Wctzazk3LQ0s1JzH5mZ9QvhRGZmZVe0zv5i9diZWeFF1O5W15K2kDRF0sOS5knaxxP0mlkDiLb2QZmWDM4DbomItwG7A/PowwS9TmRmlluEMi29kbQZcADwi6TMWBMRy0huiX9Z+rTLgKOqxeNEZma55JxFqbcJet8CvAj8UtI9ki6WtAl9mKDXnf1mlk8k/WQZvRQRE3vY1grsBXw+ImZJOo8MzcjuuEZmZrnV6FbXC4AFETErfTyFJLEVa4JeM+t/okad/RHxd2C+pJ3SVQeSTBdZnAl6zaz/ytG0rObzwG8kDQGeBD5FUsEqzAS9ZtZP1Wpkf0TMBbrrQyvMBL1m1g9F+BIlM+sHfNG4mZVeDfvIasKJzMxyCUS7b6xoZmVXsAqZE5mZ5eTOfjPrFwpWJXMiM7PcXCMzs1ILoL3diczMyiwA18jMrOw8jszMys+JzMzKrfptrBvNiczM8nONzMxKLSB81tLMys+JzMzKrkZNS0lPA68CbcC6iJgoaQRwFTABeBr4WEQs7a2cYl3CbmblEBmXbN4XEXtUzLbkCXrNrM46BsRmWfrGE/SaWf1FZFvofYJeSNLiNEl3V2zzBL1m1gDZz1r2NkEvwH4RsVDSVsCtkh7uSzhVa2RKHCfp2+nj8ZL27svOzKx/UGRbqomIhem/i4Drgb2p0wS9FwL7AMekj18FLsjwOjPrj7J29FdJZJI2kbRpx/+BQ4AHqNMEve+KiL0k3QMQEUvTyTTNbEDaoI78SlsD10uCJBddERG3SLqTOkzQu1ZSC2l+lTQaaO9r5GbWD9RgHFlEPAns3s36xeScoDdL0/J8krbrVpJ+AMwEfphnJ2bWz7RnXBqkao0sIn4j6W6SDCngqIiYV/fIzKyYynhjRUnjgZXAHyrXRcSz9QzMzIoryxnJRsrSR3YjSQ4WsDGwHfAIsGsd4zKzIitbIouIt1c+lrQX8Jm6RWRmllPukf0RMUfSO+sRzKP3DePQse+oR9FWJ0+d5bHRZbLm/P+tSTmla1pK+nLFw0HAXsCLdYvIzIotyHOJUkNkqZFtWvH/dSR9ZtfWJxwzK4Uy1cjSgbDDI+IrDYrHzEqgNE1LSa0RsS7t3Dcze11ZEhkwm6Q/bK6kG4BrgBUdGyPiujrHZmZFVaJE1mEEsBj4B14fTxaAE5nZAJT1Fj2N1Fsi2yo9Y/kAryewDgV7G2bWUCU6a9kCDKf7eZ+cyMwGsDLVyJ6PiO82LBIzK48SJbJi1R3NrBgK2EfW2/3Ict3YzMwGkBrOaympRdI9kv6YPh4h6VZJj6X/blmtjB4TWUQsyRaGmQ00as+2ZPSvQOU9Dj1Br5mVh6SxwBHAxRWrc0/Q63ktzSy/7H1koyTdVfF4ckRMrnh8LnA6na/p7jRBbzrnZa+cyMwsn3yd/T1O0CvpA8CiiLhb0qQNCcmJzMzyq81Zy/2AIyW9n+Tu05tJupx0gt60NlazCXrNzDqrwVnLiPh6RIyNiAnA0cCfI+I46jRBr5nZeiLXGcm+OIs6TNBrZva6OgyIjYjpwPT0/7kn6HUiM7P8Cjay34nMzPJzIjOzsivatZZOZGaWnxOZmZVa1P2sZW5OZGaWn2tkZlZ27iMzs/JzIjOzUstx08RGcSIzs1yEm5Zm1g84kZlZ+TmRmVnpOZGZWakVcDo4JzIzy8+JzMzKrmiXKPlW12aWmyLb0msZ0saSZku6V9KDks5M19dugl4zs25lvV9/9ebnauAfImJ3YA/gMEnvxhP0mllD1GbykYiI5enDwekS9GGCXicyM8ulY2R/xqblKEl3VSwndSpLapE0l2TKt1sjYhZdJugFPEGvmdWe2jOftuxxgl6AiGgD9pC0BXC9pN36Eo9rZGaWT+36yF4vMmIZySxKh5FO0AvgCXrNrG5qdNZydFoTQ9JQ4CDgYTxBr5k1RG0GxG4DXCaphaRSdXVE/FHSHXiCXjOrt1pcohQR9wF7drPeE/SaWQP4EiUzKzXPojRwDN5oMOdMP4PBQwbT0jqI26+bxa/OnNLssKwHgyR+d/yx/H35ck669nd88T37ctAO29MewZKVKzn95qksWr6i2WEWgu8QO4CsXb2Wrxz0PV5bsZqW1hZ+MuNM7rxlLvNmPd7s0Kwbn3zHnjy+eAnDNxoCwMWz7+LcmX8D4Pi99uTUfd/Nt6fd1swQiyWKlck8/KKOXluxGoDWwS20trYU7bO31JuGD2fS9m/h6vvuX79u+Zo16/8/bHCrP7suajH8opZcI6ujQYPEhbP/gzfv8CZu+Ok0Hp7t2lgRfevASfxo+gyGDxnSaf2X99+PD++6C6+uXs1xv72mSdEVUAFnUapbjUzSJZIWSXqgXvsouvb24LMTv8Yx257CTu/cngm7jm12SNbF+7bfjsUrV/LgC28cPH7O7X9l/5/9nBsemsfH99qjCdEVl9qzLY1Sz6blpSSXGwx4K15eyb3/8xATD/WPoWjeMWYMB+6wPdM/cwLnfvAI9hk/jv884vBOz7lh3sMcuuNbmxRhMRUtkdWtaRkRMyRNqFf5Rbf5qE1Zt7aNFS+vZMjGg9nrwLdz1Y9vaHZY1sXZM2Zy9oyZALxr3FhO2Hsip914M9tuuQXPLF0GwIE7bM+TS5Y0M8xiCQrX2d/0PrL0th4nAWzMsCZHUzsjttmS0y85mUEtg9CgQcyYcgezbpzT7LAso68csD9vGbEl7REsfOUV/t1nLDvx8IsuImIyMBlgM40o2OHpu6fuf5aT3/n1ZodhOcyav4BZ8xcAcOrv/9DkaAquYL/UpicyMysXD4g1s/KLyHNjxYao5/CLK4E7gJ0kLUhvyWFm/UGNb6y4oep51vKYepVtZs3lpqWZlVsAA6VpaWb9WA2alpLGSfqLpHnpBL3/mq73BL1mVn81umh8HXBaROwMvBv4nKRd8AS9ZtYIao9MS28i4vmImJP+/1VgHjCGPkzQ6z4yM8sn3xnJUZLuqng8OR0E30l6OeOewBsm6JXkCXrNrLaSAbG1maAXQNJw4FrgixHxiqTcMblpaWb5tWdcqpA0mCSJ/SYirktXe4JeM6s/RWRaei0jqXr9ApgXEedUbPIEvWZWZ7Ubtb8f8HHgfklz03XfAM7CE/SaWX3V5lrLiJhJ0uXWHU/Qa2Z15hsrmlmpeYJeM+sXXCMzs9IrVh5zIjOz/NRerLalE5mZ5RNkGuzaSE5kZpaLqD7YtdGcyMwsPycyMys9JzIzKzX3kZlZf+CzlmZWcuGmpZmVXOBEZmb9QLFalk5kZpafx5GZWfkVLJH5Vtdmlk8EtLVnW6qQdImkRZIeqFjnCXrNrAEisi3VXQoc1mWdJ+g1swaoUSKLiBnAki6rPUGvmdVZANnv2Z9pgt4uPEGvmdVbQGQef1F1gt5acCIzs3yCTB35G+AFSduktTFP0GtmdVK7zv7u5J6g14nMzPKrUSKTdCVwB7CTpAXppLxnAQdLegw4OH3cKzctzSyn2l00HhHH9LDJE/SaWR0F4Nv4mFnpFewSJScyM8sp6n3WMjcnMjPLJyCyjyNrCCcyM8sv+8j+hnAiM7P83EdmZqUW4bOWZtYPuEZmZuUWRFtbs4PoxInMzPLJdxufhnAiM7P8PPzCzMosgHCNzMxKLXLdWLEhnMjMLLeidfYrCnQaVdKLwDPNjqMORgEvNTsIy6W/fmbbRsToDSlA0i0kxyeLlyKi6yxJNVeoRNZfSbqrEfctt9rxZ1YuvkOsmZWeE5mZlZ4TWWNUm8fPisefWYk4kTVAhglJ60pSm6S5kh6QdI2kYRtQ1qWSPpL+/2JJu/Ty3EmS9u3DPp6WlLUzuS6a/ZlZPk5kA8OqiNgjInYD1gCfrdwoqaUvhUbEiRHxUC9PmQTkTmRmeTmRDTy3AzuktaW/SLoCuF9Si6QfS7pT0n2SPgOgxH9LekjSjcD66eslTZc0Mf3/YZLmSLpX0m2SJpAkzC+ltcH9JY2WdG26jzsl7Ze+dqSkaZLukXQRoMYeEis7D4gdQCS1AocDt6Sr9gZ2i4inJJ0EvBwR75S0EfBXSdOAPYGdgLcDWwMPAZd0KXc08HPggLSsERGxRNLPgOURcXb6vCuAn0TETEnjganAzsAZwMyI+K6kI4CT6nogrN9xIhsYhkqam/7/duAXJE2+2RHxVLr+EOD/dPR/AZsDbwUOAK6MiDZgoaQ/d1P+u4EZHWVFxJIe4jgI2EVaX+HaTNKm6T7+MX3tjZKW9vF92gDlRDYwrIqIPSpXpMlkReUq4PMRMbXL895Pcp1wb5ThOZB0ZewTEau6icUjs63P3EdmHaYCJ0saDCBpR0mbADOAo9M+tG2A93Xz2juA90raLn3tiHT9q8CmFc+bBpza8UBSR3KdARybrjsc2LJm78oGBCcy63AxSf/XHEkPABeR1NivBx4D7gd+CvxP1xdGxIsk/VrXSboXuCrd9Afgwx2d/cAXgInpyYSHeP3s6ZnAAZLmkDRxn63Te7R+ytdamlnpuUZmZqXnRGZmpedEZmal50RmZqXnRGZmpedEZmal50RmZqX3/wFUR9SM+GDONwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "test_outputs: Model = model(norm_test_X.values)\n", + "plot_confusion_matrix(\n", + " test_Y.values,\n", + " tf.round(test_outputs),\n", + " title='Confusion Matrix for Untrained Model'\n", + ")" + ] + } + ], + "metadata": { + "coursera": { + "schema_names": [ + "TF3C2W2-1", + "TF3C2W2-2", + "TF3C2W2-3" + ] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}