\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "id": "58a30939-f570-45cd-a736-d6f21aeb2a0c",
- "metadata": {},
- "source": [
- "***"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6c3441f3-b5a9-42c4-ba21-2bc682b0d8ac",
- "metadata": {},
- "source": [
- "## **Session 3 - Optimization and Simulation in TT format**"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6f6cc702",
- "metadata": {},
- "source": [
- "***"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "858a33fe-e16e-4dbb-b369-511d7c71fe71",
- "metadata": {},
- "source": [
- "## Exercise 3.1\n",
- "\n",
- "Now, let's look at a well-known example for supervised learning problem which involves the classification of handwritten digits. One of the frequently used datasets in this context is the so-called MNIST dataset. It actually contains 60,000 training images and 10,000 test images with their corresponding classes (i.e., 0,...,9). To reduce computation time, we will examine a reduced dataset extracted from the MNIST dataset. This consists of images with 7x7 pixels, representing only the digits 0 and 1. We have 500 training images and 100 test images at our disposal.\n",
- "\n",
- "**a)**$\\quad$Load the dataset and take a look at some of the training images:\n",
- "\n",
- "> data = np.load('MNIST_full.npz')\n",
- "\n",
- "$\\hspace{0.8cm}$You can access the arrays ```x_train, y_train, x_test, y_test``` by, e.g., ```data['x_train']```.\n",
- "\n",
- "$\\hspace{0.8cm}$The arrays ```x_train``` and ```x_test``` have shape $49 \\times 500$ and $49 \\times 100$, respectively and contain the (flattened) images.\n",
- "\n",
- "$\\hspace{0.8cm}$The arrays ```y_train``` and ```y_test``` have shape $2 \\times 500$ and $2 \\times 100$, respectively and contain the corresponding classes (in one-hot encoding).\n",
- "\n",
- "**b)**$\\quad$For the construction of the transformed data tensor $\\mathbf{\\Theta}$, we choose the two basis functions $\\sin(\\alpha x)$ and $\\cos(\\alpha x)$ with $\\alpha=0.5 \\pi$.\n",
- "\n",
- "$\\hspace{0.8cm}$For this purpose, we use the functions from scikit_tt.data_driven.transform, i.e.,\n",
- "\n",
- "> import scikit_tt.data_driven.transform as tdt\n",
- ">\n",
- "> basis_list = []\n",
- "> \n",
- "> for i in range(order):\n",
- "> \n",
- "> basis_list.append([tdt.Cos(i, alpha), tdt.Sin(i, alpha)])\n",
- "\n",
- "$\\hspace{0.8cm}$Note that ```order``` is simply the number of pixels.\n",
- "\n",
- "**c)**$\\quad$In the next step, we define the initial guess $\\mathbf{\\Xi}$ for the optimization problems\n",
- "\n",
- "$\\hspace{1.25cm}$$\\displaystyle \\min_{\\mathbf{\\Xi} \\in \\mathbb{T}} \\lVert \\mathbf{\\Xi}^\\top \\mathbf{\\Theta} - Y_i \\rVert_F$,\n",
- "\n",
- "$\\hspace{0.8cm}$where $Y_i$ denotes the $i$th row of ```y_train```. We specify that $\\mathbb{T}$ here consists only of tensor trains with a TT rank of 1, i.e.,\n",
- "\n",
- "> cores = [np.ones([1, 2, 1, 1]) for i in range(order)]\n",
- ">\n",
- "> initial_guess = TT(cores).ortho()\n",
- "\n",
- "**d)**$\\quad$Finally, we use the ARR routine from ```scikit_tt.data_driven.regression``` to optimize the tensors for the individual learning problems:\n",
- "\n",
- "> import scikit_tt.data_driven.regression as reg\n",
- ">\n",
- "> xi = reg.arr(x_train, y_train, basis_list, initial_guess, repeats=5, progress=False)\n",
- "\n",
- "**e)**$\\quad$To apply our coefficient tensors to the test data, we construct the corresponding transformed data tensor:\n",
- "\n",
- "> Theta = tdt.basis_decomposition(x_test, basis_list).transpose(cores=49)\n",
- "\n",
- "$\\hspace{0.8cm}$The corresponding (approximate) label vectors can then be computed by contracting the coefficient tensors with $\\mathbf{\\Theta}$. \n",
- "\n",
- "$\\hspace{0.8cm}$For example, the label vector for class 0 can be computed as follows:\n",
- "\n",
- "> xi_0 = TT(xi[0].cores + [np.ones([1,1,1,1])])\n",
- ">\n",
- "> y_0 = (xi_0.transpose()@Theta).matricize()\n",
- "\n",
- "$\\hspace{0.8cm}$Give these lines some thought!\n",
- "\n",
- "**e)**$\\quad$The row indices of the largest entries of $\\begin{pmatrix} - y_0 - \\\\ - y_1 - \\end{pmatrix}$ determine the detected labels. Compute the classification rate!"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "bc6abad4-3b47-447c-9e87-86b44381164b",
- "metadata": {},
- "source": [
- "***"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "d3a9188e-d384-4fc5-8022-34050c91807b",
- "metadata": {},
- "source": [
- "**a)**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "1a02a147-592d-4c28-981e-5859967afa0c",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGdCAYAAAAv9mXmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAW10lEQVR4nO3df4zV9b3n8ffAOAeqwwjIrykD6ooi4lBlhEvR+ovqZZVoN2uNi+lc2jTRjFUkJmZ2s8WmqUP/aKNtySi2Ff8oRdsEtW6BUiqQrlJhWLKoGxWlcRSB2tiZYZI94szZP7qde+cq1DNzPnw5p49H8k08J9/D93Ui+uScMzNUFQqFQgBAiY3IegAAlUlgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEii+mRfsL+/Pw4ePBi1tbVRVVV1si8PwDAUCoXo6emJ+vr6GDHixK9RTnpgDh48GA0NDSf7sgCUUGdnZ0ydOvWE55z0wNTW1kZExOXxH6M6TjvZlwdgGD6KY/H7+PXA/8tP5KQH5m9vi1XHaVFdJTAAZeX///TKT/MRhw/5AUhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhiSIFZvXp1nH322TFq1KiYP39+vPTSS6XeBUCZKzowTz75ZKxYsSJWrlwZe/bsiTlz5sT1118fR44cSbEPgDJVdGC+//3vx9e//vVYtmxZzJo1Kx555JH4zGc+Ez/96U9T7AOgTBUVmA8//DA6Ojpi0aJF//oLjBgRixYtihdffPETH5PP56O7u3vQAUDlKyow77//fvT19cWkSZMG3T9p0qQ4dOjQJz6mra0t6urqBo6GhoahrwWgbCT/KrLW1tbo6uoaODo7O1NfEoBTQHUxJ5911lkxcuTIOHz48KD7Dx8+HJMnT/7Ex+RyucjlckNfCEBZKuoVTE1NTcydOze2bt06cF9/f39s3bo1FixYUPJxAJSvol7BRESsWLEimpubo6mpKebNmxcPPfRQ9Pb2xrJly1LsA6BMFR2YW2+9Nf70pz/FN7/5zTh06FB87nOfi02bNn3sg38A/rFVFQqFwsm8YHd3d9TV1cVVcVNUV512Mi8NwDB9VDgW2+KZ6OrqijFjxpzwXD+LDIAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkqrMewKmresrkrCeU1KE1J/77w8vRf5/566wnlNx/++lXsp5QclPbXsh6Qia8ggEgCYEBIAmBASAJgQEgCYEBIAmBASAJgQEgCYEBIAmBASAJgQEgCYEBIAmBASAJgQEgCYEBIAmBASAJgQEgCYEBIAmBASAJgQEgCYEBIAmBASAJgQEgiaIDs2PHjliyZEnU19dHVVVVPP300wlmAVDuig5Mb29vzJkzJ1avXp1iDwAVorrYByxevDgWL16cYgsAFaTowBQrn89HPp8fuN3d3Z36kgCcApJ/yN/W1hZ1dXUDR0NDQ+pLAnAKSB6Y1tbW6OrqGjg6OztTXxKAU0Dyt8hyuVzkcrnUlwHgFOP7YABIouhXMEePHo39+/cP3D5w4EDs3bs3xo0bF9OmTSvpOADKV9GB2b17d1x99dUDt1esWBEREc3NzbF27dqSDQOgvBUdmKuuuioKhUKKLQBUEJ/BAJCEwACQhMAAkITAAJCEwACQhMAAkITAAJCEwACQhMAAkITAAJCEwACQhMAAkITAAJCEwACQhMAAkITAAJCEwACQhMAAkITAAJBEddYDKkX15ElZTyi5/9GxKesJJfXdP8/IekLJ/c+eyntO/+Gf38p6Qsnl27JekA2vYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIoqjAtLW1xWWXXRa1tbUxceLEuPnmm+O1115LtQ2AMlZUYLZv3x4tLS2xc+fO2LJlSxw7diyuu+666O3tTbUPgDJVXczJmzZtGnR77dq1MXHixOjo6IgvfOELJR0GQHkrKjD/XldXV0REjBs37rjn5PP5yOfzA7e7u7uHc0kAysSQP+Tv7++P5cuXx8KFC2P27NnHPa+trS3q6uoGjoaGhqFeEoAyMuTAtLS0xMsvvxzr168/4Xmtra3R1dU1cHR2dg71kgCUkSG9RXbXXXfFc889Fzt27IipU6ee8NxcLhe5XG5I4wAoX0UFplAoxDe+8Y3YsGFDbNu2Lc4555xUuwAoc0UFpqWlJdatWxfPPPNM1NbWxqFDhyIioq6uLkaPHp1kIADlqajPYNrb26OrqyuuuuqqmDJlysDx5JNPptoHQJkq+i0yAPg0/CwyAJIQGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSEBgAkijqr0zm+P7Pfz076wklN2Pbv2Q9oaTO/S97s55Qch/8yz9lPaHkJi87kPUESsQrGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSEBgAkhAYAJIQGACSKCow7e3t0djYGGPGjIkxY8bEggULYuPGjam2AVDGigrM1KlTY9WqVdHR0RG7d++Oa665Jm666aZ45ZVXUu0DoExVF3PykiVLBt3+zne+E+3t7bFz58646KKLSjoMgPJWVGD+rb6+vvjFL34Rvb29sWDBguOel8/nI5/PD9zu7u4e6iUBKCNFf8i/b9++OOOMMyKXy8Udd9wRGzZsiFmzZh33/La2tqirqxs4GhoahjUYgPJQdGAuuOCC2Lt3b/zhD3+IO++8M5qbm+PVV1897vmtra3R1dU1cHR2dg5rMADloei3yGpqauK8886LiIi5c+fGrl274uGHH45HH330E8/P5XKRy+WGtxKAsjPs74Pp7+8f9BkLAEQU+QqmtbU1Fi9eHNOmTYuenp5Yt25dbNu2LTZv3pxqHwBlqqjAHDlyJL7yla/Ee++9F3V1ddHY2BibN2+OL37xi6n2AVCmigrMT37yk1Q7AKgwfhYZAEkIDABJCAwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQgMAElUZz2gUnzxn/531hNK7sC9F2Q9gb8jP7Yq6wkl99qhiVlPKLmz41DWEzLhFQwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQgMAEkIDABJCAwASQwrMKtWrYqqqqpYvnx5ieYAUCmGHJhdu3bFo48+Go2NjaXcA0CFGFJgjh49GkuXLo3HHnssxo4dW+pNAFSAIQWmpaUlbrjhhli0aNHfPTefz0d3d/egA4DKV13sA9avXx979uyJXbt2farz29ra4lvf+lbRwwAob0W9guns7Ix77rknfvazn8WoUaM+1WNaW1ujq6tr4Ojs7BzSUADKS1GvYDo6OuLIkSNx6aWXDtzX19cXO3bsiB/96EeRz+dj5MiRgx6Ty+Uil8uVZi0AZaOowFx77bWxb9++QfctW7YsZs6cGffff//H4gLAP66iAlNbWxuzZ88edN/pp58e48eP/9j9APxj8538ACRR9FeR/Xvbtm0rwQwAKo1XMAAkITAAJCEwACQhMAAkITAAJCEwACQhMAAkITAAJCEwACQhMAAkITAAJCEwACQhMAAkITAAJCEwACQhMAAkITAAJCEwACQhMAAkUZ31gErx0nvTs55Qclf+4OWsJ5TU7j+dm/WEkvtPU7ZlPaHk/vCfZ2Y9oeT6sh6QEa9gAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEiiqMA88MADUVVVNeiYOXNmqm0AlLHqYh9w0UUXxW9/+9t//QWqi/4lAPgHUHQdqqurY/LkySm2AFBBiv4M5o033oj6+vo499xzY+nSpfH222+f8Px8Ph/d3d2DDgAqX1GBmT9/fqxduzY2bdoU7e3tceDAgbjiiiuip6fnuI9pa2uLurq6gaOhoWHYowE49RUVmMWLF8ctt9wSjY2Ncf3118evf/3r+Mtf/hJPPfXUcR/T2toaXV1dA0dnZ+ewRwNw6hvWJ/RnnnlmnH/++bF///7jnpPL5SKXyw3nMgCUoWF9H8zRo0fjzTffjClTppRqDwAVoqjA3HfffbF9+/b44x//GC+88EJ86UtfipEjR8Ztt92Wah8AZaqot8jeeeeduO222+LPf/5zTJgwIS6//PLYuXNnTJgwIdU+AMpUUYFZv359qh0AVBg/iwyAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJKqzHlAp6lt6sp5QcjuvbMp6QkmN/LCQ9YSSe/Hpd7KeUHKFY29lPYES8QoGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCQEBoAkBAaAJAQGgCSKDsy7774bt99+e4wfPz5Gjx4dF198cezevTvFNgDKWHUxJ3/wwQexcOHCuPrqq2Pjxo0xYcKEeOONN2Ls2LGp9gFQpooKzHe/+91oaGiIxx9/fOC+c845p+SjACh/Rb1F9uyzz0ZTU1PccsstMXHixLjkkkviscceO+Fj8vl8dHd3DzoAqHxFBeatt96K9vb2mDFjRmzevDnuvPPOuPvuu+OJJ5447mPa2tqirq5u4GhoaBj2aABOfVWFQqHwaU+uqamJpqameOGFFwbuu/vuu2PXrl3x4osvfuJj8vl85PP5gdvd3d3R0NAQV8VNUV112jCmn1qqp3426wkl9+crK+sPAyM//NS/1ctG7dP/K+sJJVc49mHWEziBjwrHYls8E11dXTFmzJgTnlvUK5gpU6bErFmzBt134YUXxttvv33cx+RyuRgzZsygA4DKV1RgFi5cGK+99tqg+15//fWYPn16SUcBUP6KCsy9994bO3fujAcffDD2798f69atizVr1kRLS0uqfQCUqaICc9lll8WGDRvi5z//ecyePTu+/e1vx0MPPRRLly5NtQ+AMlXU98FERNx4441x4403ptgCQAXxs8gASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIoui/Mnm4CoVCRER8FMciCif76gn157NeUHJ9H/7frCeUVOFYJf2G+6uPCseynlByhQp8TpXko/jrv5+//b/8RKoKn+asEnrnnXeioaHhZF4SgBLr7OyMqVOnnvCckx6Y/v7+OHjwYNTW1kZVVVWy63R3d0dDQ0N0dnbGmDFjkl3nZPKcTn2V9nwiPKdycbKeU6FQiJ6enqivr48RI078KctJf4tsxIgRf7d6pTRmzJiK+Q30N57Tqa/Snk+E51QuTsZzqqur+1Tn+ZAfgCQEBoAkKjYwuVwuVq5cGblcLuspJeM5nfoq7flEeE7l4lR8Tif9Q34A/jFU7CsYALIlMAAkITAAJCEwACRRkYFZvXp1nH322TFq1KiYP39+vPTSS1lPGpYdO3bEkiVLor6+PqqqquLpp5/OetKwtLW1xWWXXRa1tbUxceLEuPnmm+O1117LetawtLe3R2Nj48A3uS1YsCA2btyY9aySWrVqVVRVVcXy5cuznjJkDzzwQFRVVQ06Zs6cmfWsYXn33Xfj9ttvj/Hjx8fo0aPj4osvjt27d2c9KyIqMDBPPvlkrFixIlauXBl79uyJOXPmxPXXXx9HjhzJetqQ9fb2xpw5c2L16tVZTymJ7du3R0tLS+zcuTO2bNkSx44di+uuuy56e3uznjZkU6dOjVWrVkVHR0fs3r07rrnmmrjpppvilVdeyXpaSezatSseffTRaGxszHrKsF100UXx3nvvDRy///3vs540ZB988EEsXLgwTjvttNi4cWO8+uqr8b3vfS/Gjh2b9bS/KlSYefPmFVpaWgZu9/X1Ferr6wttbW0ZriqdiChs2LAh6xkldeTIkUJEFLZv3571lJIaO3Zs4cc//nHWM4atp6enMGPGjMKWLVsKV155ZeGee+7JetKQrVy5sjBnzpysZ5TM/fffX7j88suznnFcFfUK5sMPP4yOjo5YtGjRwH0jRoyIRYsWxYsvvpjhMk6kq6srIiLGjRuX8ZLS6Ovri/Xr10dvb28sWLAg6znD1tLSEjfccMOg/67K2RtvvBH19fVx7rnnxtKlS+Ptt9/OetKQPfvss9HU1BS33HJLTJw4MS655JJ47LHHsp41oKIC8/7770dfX19MmjRp0P2TJk2KQ4cOZbSKE+nv74/ly5fHwoULY/bs2VnPGZZ9+/bFGWecEblcLu64447YsGFDzJo1K+tZw7J+/frYs2dPtLW1ZT2lJObPnx9r166NTZs2RXt7exw4cCCuuOKK6OnpyXrakLz11lvR3t4eM2bMiM2bN8edd94Zd999dzzxxBNZT4uIDH6aMvxbLS0t8fLLL5f1++B/c8EFF8TevXujq6srfvnLX0Zzc3Ns3769bCPT2dkZ99xzT2zZsiVGjRqV9ZySWLx48cA/NzY2xvz582P69Onx1FNPxde+9rUMlw1Nf39/NDU1xYMPPhgREZdcckm8/PLL8cgjj0Rzc3PG6yrsFcxZZ50VI0eOjMOHDw+6//DhwzF58uSMVnE8d911Vzz33HPx/PPPn9S/wiGVmpqaOO+882Lu3LnR1tYWc+bMiYcffjjrWUPW0dERR44ciUsvvTSqq6ujuro6tm/fHj/4wQ+iuro6+vr6sp44bGeeeWacf/75sX///qynDMmUKVM+9geYCy+88JR526+iAlNTUxNz586NrVu3DtzX398fW7durYj3witFoVCIu+66KzZs2BC/+93v4pxzzsl6UhL9/f2Rz5fvX6V97bXXxr59+2Lv3r0DR1NTUyxdujT27t0bI0eOzHrisB09ejTefPPNmDJlStZThmThwoUf+xL/119/PaZPn57RosEq7i2yFStWRHNzczQ1NcW8efPioYceit7e3li2bFnW04bs6NGjg/6EdeDAgdi7d2+MGzcupk2bluGyoWlpaYl169bFM888E7W1tQOfj9XV1cXo0aMzXjc0ra2tsXjx4pg2bVr09PTEunXrYtu2bbF58+aspw1ZbW3txz4XO/3002P8+PFl+3nZfffdF0uWLInp06fHwYMHY+XKlTFy5Mi47bbbsp42JPfee298/vOfjwcffDC+/OUvx0svvRRr1qyJNWvWZD3tr7L+MrYUfvjDHxamTZtWqKmpKcybN6+wc+fOrCcNy/PPP1+IiI8dzc3NWU8bkk96LhFRePzxx7OeNmRf/epXC9OnTy/U1NQUJkyYULj22msLv/nNb7KeVXLl/mXKt956a2HKlCmFmpqawmc/+9nCrbfeWti/f3/Ws4blV7/6VWH27NmFXC5XmDlzZmHNmjVZTxrgx/UDkERFfQYDwKlDYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASEJgAEhCYABIQmAASOL/AdOXMbYjLxvdAAAAAElFTkSuQmCC",
- "text/plain": [
- "