diff --git a/examples/yelp/yelp_subspace_ekf_diag_hessian.ipynb b/examples/yelp/yelp_subspace_ekf_diag_hessian.ipynb new file mode 100644 index 00000000..a2a07a4b --- /dev/null +++ b/examples/yelp/yelp_subspace_ekf_diag_hessian.ipynb @@ -0,0 +1,337 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from tqdm.auto import tqdm\n", + "from optree import tree_map, tree_map_\n", + "import pickle\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import uqlib\n", + "\n", + "from load import load_dataloaders, load_model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training data size: 1000\n" + ] + } + ], + "source": [ + "# Load data\n", + "train_dataloader, eval_dataloader = load_dataloaders(small=True, batch_size=32)\n", + "num_data = len(train_dataloader.dataset)\n", + "print(\"Training data size: \", num_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "# Load model (with standard Gaussian prior)\n", + "model, param_to_log_lik = load_model(num_data=num_data, prior_sd=torch.inf)\n", + "\n", + "# Turn off Dropout\n", + "model.eval()\n", + "\n", + "# Load to GPU\n", + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "model.to(device);" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Only train the last layer\n", + "for name, param in model.named_parameters():\n", + " if 'bert' in name:\n", + " param.requires_grad = False\n", + "\n", + "# Extract only the parameters to be trained\n", + "sub_params, sub_param_to_log_lik = uqlib.extract_requires_grad_and_func(dict(model.named_parameters()), param_to_log_lik)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Store initial values of sub_params to check against later\n", + "init_sub_params = tree_map(lambda x: x.detach().clone(), sub_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Initiate Normal parameters\n", + "init_mean = sub_params\n", + "init_log_sds = tree_map(\n", + " lambda x: (torch.zeros_like(x) - 2.0).requires_grad_(True), init_mean\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Optimization setup\n", + "num_epochs = 30\n", + "num_training_steps = num_epochs * len(train_dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3e0f2d203bac4f51a26e5f9c9fef20a5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/960 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot moving average of log_likelhood\n", + "plot_moving_average(log_liks, 50)" + ] + }, + { 
+ "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAh30lEQVR4nO3dfXBU1f3H8c8mJAsBkpiQZJMxgSBP1sqDICFWK0iU0I6PTFsZp4J1fALrSErVtAKKvzEq1VAdhMGpoNNR1JmCrVbaGANUCbEwUK1YngyNaBJtaFhACQ85vz80W5ZsQrLZPXd3837N3DF79+Tud08S98M5957rMsYYAQAAWBLndAEAAKB3IXwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsKqP0wWcqbW1VZ9//rkGDhwol8vldDkAAKALjDE6fPiwcnJyFBfX+dhGxIWPzz//XLm5uU6XAQAAgvDpp5/q3HPP7bRNxIWPgQMHSvqm+OTkZIerAQAAXeH1epWbm+v7HO9MxIWPtqmW5ORkwgcAAFGmK6dMcMIpAACwivABAACsInwAAACrIu6cDwAAwsUYo5MnT+rUqVNOlxKVEhISFB8f3+PjED4AAL3C8ePHVV9fr6+++srpUqKWy+XSueeeqwEDBvToOIQPAEDMa21tVW1treLj45WTk6PExEQWsuwmY4y+/PJLHThwQMOHD+/RCAjhAwAQ844fP67W1lbl5uYqKSnJ6XKiVkZGhvbv368TJ070KHxwwikAoNc427Lf6FyoRov4KQAAAKsIHwAAwCrO+QAA9GrlFbutvt68K0eE9fgul0tr167VddddF9bX6QlGPgAAiGCzZ8/uVpCor6/X9OnTJUn79++Xy+XSjh07wlNckBj5AAAghng8HqdLOCtGPgAAiBKTJ0/WPffco/vuu09paWnyeDx66KGH/Nq4XC6tW7dOkpSfny9JGjdunFwulyZPnmy34A4QPgBEpPKK3dbn4oFo8MILL6h///6qqanRE088ocWLF6uioiJg2/fff1+S9Pbbb6u+vl5/+MMfbJbaIcIHAABRZPTo0Vq0aJGGDx+um2++WRMmTFBlZWXAthkZGZKk9PR0eTwepaWl2Sy1Q4QPAACiyOjRo/0eZ2dn64svvnComuAQPgAAiCIJCQl+j10ul1pbWx2qJjiEDwAAYlRiYqIk6dSpUw5X4o/wAQBAjMrMzFS/fv20fv16NTY26tChQ06XJIl1PgBEEK5ugRPCveKok/r06aOnn35aixcv1sKFC3XZZZdpw4YNTpcllzHGOF3E6bxer1JSUnTo0CElJyc7XQ4AiwKFj1j+YIA9x44dU21trfLz89W3b1+ny4lanfVjdz6/mXYBAABWMe0CIGqdPlLCCAkQPRj5AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAKLQ5MmTde+994b0mA899JDGjh0b0mMGwjofAIDerarM7utNKe1W89mzZ+uFF15ot7+mpkbnn39+qKqyivABAECEKy4u1qpVq/z2ZWRkKD4+3qGKeoZpFwAAIpzb7ZbH4/Hbpk6d6jftMmTIED366KP62c9+poEDByovL08rV670O87999+vESNGKCkpSUOHDtWCBQt04sQJy++G8AEAQMx48sknNWHCBG3fvl1z5szRXXfdpV27dvmeHzhwoFavXq2dO3fqt7/9rZ577jmVl5dbr5PwAQBAhHvjjTc0YMAA3/ajH/0oYLsf/OAHmjNnjoYNG6b7779fgwYNUlVVle/5Bx98UJdccomGDBmiq6++WvPnz9err75q6234cM4HAAARbsqUKVq+fLnvcf/+/TVz5sx27UaPHu372uVyyePx6IsvvvDte+WVV/T0009r3759OnLkiE6ePKnk5OTwFh8A4QMAgAjXv39/DRs27KztEhIS/B67XC61trZKkqqrq3XTTTfp4Ycf1rRp05SSkqI1a9boySefDEvNnSF8AADQC2zevFmDBw/Wr3/9a9++f//7347UQvgAAKAXGD58uOrq6rRmzRpdfPHFevPNN7V27VpHauGEUwAAeoFrrrlG8+bN0913362xY8dq8+bNWrBggSO1uIwxxpFX7oDX61VKSooOHTrkyEkwAJxTXrG73b55V47oUvvO2gHHjh1TbW2t8vPz1bdvX6fLiVqd9WN3Pr8Z+QAAAFYRPgAAgFWEDwAAYBXhAwAAWNWt8FFWVqaLL75YAwcOVGZmpq677jq/NeOlb05GmTt3rtLT0zVgwADNmDFDjY2NIS0aAABEr26Fj40bN2ru3LnasmWLKioqdOLECV111VU6evSor828efP0pz/9Sa+99po2btyozz//XDfccEPICwcAoLsi7ALPqBOq/uvWImPr16/3e7x69WplZmZq27Zt+v73v69Dhw7pd7/7nV566SVdccUVkqRVq1bp/PPP15YtWzRp0qSQFA0gunBJLJzWtuz4V199pX79+jlcTfQ6fvy4JCk+Pr5Hx+nRCqeHDh2SJKWlpUmStm3bphMnTqioqMjXZtSoUcrLy1N1dXXA8NHS0qKWlhbfY6/X25OSAABoJz4+Xqmpqb6brCUlJcnlcjlcVXRpbW3Vl19+qaSkJPXp07MF0oP+7tbWVt1777363ve+p+9+97uSpIaGBiUmJio1NdWvbVZWlhoaGgIep6ysTA8//HCwZQAA0CUej0eS/O7yiu6Ji4tTXl5ej4Nb0OFj7ty5+uc//6l33323RwWUlpaqpKTE99jr9So3N7dHxwQQO9qmbM42XdPd1VHR+7hcLmVnZyszM1MnTpxwupyolJiYqLi4nl8oG1T4uPvuu/XGG29o06ZNOvfcc337PR6Pjh8/rubmZr/Rj8bGRl/iPJPb7Zbb7Q6mDAAAui0+Pr7H5yygZ7oVX4wxuvvuu7V27Vq98847ys/P93t+/PjxSkhIUGVlpW/frl27VFdXp8LCwtBUDAAAolq3Rj7mzp2rl156Sa+//roGDhzoO48jJSVF/fr1U0pKim699VaVlJQoLS1NycnJ+vnPf67CwkKudAEgqevTKB19XzDfCyCydCt8LF++XJI0efJkv/2rVq3S7NmzJUnl5eWKi4vTjBkz1NLSomnTpunZZ58NSbEAACD6dSt8dGVxkb59+2rZsmVatmxZ0EUBAIDYxb1dAACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGBVz25LBwBBYtEwoPdi5AMAAFjFyAcAxwW6I62t12TUBbCPkQ8AAGAV4QMAAFhF+AAQ88ordjsytQMgMMIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAq/o4XQAAhAv3cwEiEyMfAADAKkY+AEQdRj
SA6MbIBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqFhkDEF5VZZpU1+R7uCXvdgeLARAJGPkAAABWET4AAIBVTLsACJvyit1+Uy4dmVS30u8xUzNAbGPkAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVrPMBwKoz1/SIJOUVuyVJ864c4XAlQGxj5AMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWMXVLgAiTqArYrjTLRA7GPkAAABWMfIBIGa1G0GpSpemlDpTDAAfRj4AAIBVhA8AAGAV0y4AokJXlmXnpFQgOjDyAQAArCJ8AAAAq5h2ARBybXeHjQXc6RYIPUY+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFZ1O3xs2rRJV199tXJycuRyubRu3Tq/52fPni2Xy+W3FRcXh6peAAha9SdNKq/YHVOLoAHRqNvh4+jRoxozZoyWLVvWYZvi4mLV19f7tpdffrlHRQIAgNjR7eXVp0+frunTp3faxu12y+PxBF0UAACIXWE552PDhg3KzMzUyJEjddddd6mpqanDti0tLfJ6vX4bAACIXSEPH8XFxXrxxRdVWVmpxx9/XBs3btT06dN16tSpgO3LysqUkpLi23Jzc0NdEgAAiCAhv6vtjTfe6Pv6wgsv1OjRo3Xeeedpw4YNmjp1arv2paWlKikp8T32er0EEAAAYljYL7UdOnSoBg0apL179wZ83u12Kzk52W8DAACxK+zh48CBA2pqalJ2dna4XwoAAESBbk+7HDlyxG8Uo7a2Vjt27FBaWprS0tL08MMPa8aMGfJ4PNq3b5/uu+8+DRs2TNOmTQtp4QAAIDp1O3xs3bpVU6ZM8T1uO19j1qxZWr58uT744AO98MILam5uVk5Ojq666io98sgjcrvdoasaAABErW6Hj8mTJ8sY0+Hzf/nLX3pUEAAAiG3c2wUAAFgV8kttAfRiVWWSpEl1HS8sCACMfAAAAKsY+QAQMybVrex+m6r0055rG7H5TQirAnAmRj4AAIBVhA8AAGAV0y4AcIbyit0d7pt35Qjb5QAxh5EPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4ANAj5RW7Ay5HDgAdIXwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArOrjdAEA4KTqT5qcLgHodQgfAIJXVaZJdd9+eFelO1tLCE2qW+n3eEve7Q5VAsQmpl0AAIBVjHwA6NTp922Zd+UIBysBECsY+QAAAFYRPgAAgFWEDwBdVl6x228aBgCCQfgAAABWccIpgJBgvQwAXcXIBwAAsIrwAQAArCJ8AAAAqwgfAADAKk44BYCz8LvXS9s9bKaUOlMMEAMY+QAAAFYRPgAAgFWEDwAAYBXhAwAAWMUJpwC6JOBJlwAQBEY+AACAVYQPAAgSd/kFgkP4AAAAVhE+AACAVZxwCqDbqj9pcroEAFGMkQ8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWMWltgAQjKoyTar79pLjqnRpSqmz9QBRhJEPAABgFeEDAABYRfgAAABWcc4HAD9td2mdd+UIhyuJTCwtD/QcIx8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArOp2+Ni0aZOuvvpq5eTkyOVyad26dX7PG2O0cOFCZWdnq1+/fioqKtKePXtCVS8ARIXyit2+DYC/boePo0ePasyYMVq2bFnA55944gk9/fTTWrFihWpqatS/f39NmzZNx44d63GxAAAg+nV7hdPp06dr+vTpAZ8zxmjp0qV68MEHde2110qSXnzxRWVlZWndunW68cYbe1YtAACIeiFdXr22tlYNDQ0qKiry7UtJSVFBQYGqq6sDho+Wlha1tLT4Hnu93lCWBABWMc0CnF1ITzhtaGiQJGVlZfntz8rK8j13prKyMqWkpPi23NzcUJYEAAAijOM3listLVVJSYnvsdfrJYAAtlWV+b6cVNd247TfOFMLgJgX0pEPj8cjSWpsbPTb39jY6HvuTG63W8nJyX4bAACIXSENH/n5+fJ4PKqsrPTt83q9qqmpUWFhYShfCgAARKluT7scOXJEe/fu9T2ura3Vjh07lJaWpry8PN177736v//7Pw0fPlz5+flasGCBcnJydN1114WybgDh9u1UzP+mYQAgNLodPrZu3aopU6b4HredrzFr1iytXr1a9913n44eParbb79dzc3NuvTSS7V+/Xr17ds3dFUDCKnqTwgYAOzpdviYPHmyjDEdPu9yubR48WItXry4R4UBAIDYxL1dAACAVYQPAABgFeEDAABYRfgAAABWOb7CKQBEu+pPmrTlJPd0AbqKkQ8AAGAVIx8AECaT6lZ+80VV+jf/nVLqXDFABGHkAwAAWEX4AAAAVhE+AMCS8ordKq/gxFSA8AEAAKwifAAAAKu42gUAQsB3ZQuAs2LkAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAVi4wBvczp9xaZd+UIBysB0Fsx8gEAAKwifAAAAKuYdgEQUPUnTU6XACBGMfIBAACsInwAMa68YrffSaYA4DTCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwinU+gF5oUt3Kb76oSne2EAC9EiMfAADAKsIHAACwimkXIMadOcUyqY5l050ScLprSqkzxQAOYuQDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFjF1S4AEEmqytrv44oYxBhGPgAAgFWMfAC9RPUnrO8BIDIw8gEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwikXGACDMWOAN8MfIBwAAsIrwAQAArCJ8AAAAqwgfAADAKk44BQAnVZU5XQFgHeEDiDVVZV2+uoKrMAA4gWkXAABgFeEDAABYRfgAAABWET4AAIBVhA8AiDDVnzRxMjBiGuEDAABYRfgAAABWsc4HAFh2+pRK4dB0BysBnMHIBwAAsCrk4eOhhx6Sy+Xy20aNGhXqlwEAAFEqLNMuF1xwgd5+++3/vUgfZneAkDjzPiBTSp2pAwB6ICypoE+fPvJ4POE4NAAAiHJhOedjz549ysnJ0dChQ3XTTTeprq6uw7YtLS3yer1+GwAAiF0hDx8FBQVavXq11q9fr+XLl6u2tlaXXXaZDh8+HLB9WVmZUlJSfFtubm6oSwIAABHEZYwx4XyB5uZmDR48WE899ZRuvfXWds+3tLSopaXF99jr9So3N1eHDh1ScnJyOEsDos+Z53x0gNUxo0egS23bfn6+5zi3B1HA6/UqJSWlS5/fYT8TNDU1VSNGjNDevXsDPu92u+V2u8NdBgAAiBBhX+fjyJEj2rdvn7Kzs8P9UgAAIAqEPHzMn
z9fGzdu1P79+7V582Zdf/31io+P18yZM0P9UgAAIAqFfNrlwIEDmjlzppqampSRkaFLL71UW7ZsUUZGRqhfCsAZONcj+vAzQ28U8vCxZs2aUB8SAADEEJYeBYBIF+gqJ66AQRTjxnIAAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwiuXVASAWsAQ7oggjHwAAwCrCBxClqj9p4nbsAKIS4QMAAFhF+AAAAFZxwikARKjTp9UKh6Y7WAkQWoQPAIhGga5uAaIE0y4AAMAqwgcAALCK8AEAAKzinA8gyrHWB4Bow8gHAACwivABAACsYtoFiCJMsfRebT/7QOt9dPYcEIkY+QAAAFYRPgAAgFWEDwAAYBXnfACRiuWzAcQoRj4AAIBVjHwANnRlFGNKafjrQO9y5u8dv2OIEIx8AAAAqwgfAADAKqZdgEjBCaYAeglGPgAAgFWEDwAAYBXhAwAAWEX4AAAAVnHCKRBhTr9zLXcpxZm4gy1iAeEDAHqLQFdUsfAYHED4AEKNS2YBoFOED8BBXZ1iOb0d0BGm7BAtCB8AgP9hagYWED6AnmKaBQC6hfABOCDQNApTKwB6C9b5AAAAVjHyAXTmzCkV5r4BoMcIHwAAlVfsliTN41MBFvBrBgC92beje5Pqvj3niEt0YQHnfAAAAKsY+QCAGMbCY4hEjHwAAACrGPkAwoy7kCIczrYuTFfWjWFtGTiFkQ8AAGAVIx9Ad7CUOgD0GOEDiGAMiyMidCV0swAfuoFpFwAAYBXhAwAAWMW0C9AFwayVwJQJIg1XXiFSED4AAKEX6DwRzgvBtwgfAICe40owdAPhAwgBplgAoOs44RQAAFjFyAfQhmFjALCC8AEA8Ak0hRjo6hiunEFPED7QOzCqAQARg/ABAIhuXbmsl0t/IwrhA+gmrmxBtAvl7/DZjsW0DAIJW/hYtmyZlixZooaGBo0ZM0bPPPOMJk6cGK6XQ7DO/NdApP1LINgbWjHNAkQem3+X/D8gooUlfLzyyisqKSnRihUrVFBQoKVLl2ratGnatWuXMjMzw/GSXRfpH7aBdKXmcA4phqrPgn0f3wrFCW6d/Svt9OOGY3SDERNEq1D87nKCKk4XlnU+nnrqKd1222265ZZb9J3vfEcrVqxQUlKSnn/++XC8HAAAiCIhH/k4fvy4tm3bptLS//2rNi4uTkVFRaqurm7XvqWlRS0tLb7Hhw4dkiR5vd5Ql/aNo8f8H7+xqH2b7/8iPK8drK7UHEhX+vDMYwf6nq606Ypg30fbt3/9ze+JN5iazzhGIKcft7N2ALqu7e+qw79fJ4Xrc2bTk+33deVzJdD3BXOcUNXTTW2f28aYszc2IfbZZ58ZSWbz5s1++3/5y1+aiRMntmu/aNEiI4mNjY2NjY0tBrZPP/30rFnB8atdSktLVVJS4nvc2tqqgwcPKj09XS6Xq0vH8Hq9ys3N1aeffqrk5ORwldpr0J+hRX+GDn0ZWvRnaPX2/jTG6PDhw8rJyTlr25CHj0GDBik+Pl6NjY1++xsbG+XxeNq1d7vdcrvdfvtSU1ODeu3k5ORe+QMPF/oztOjP0KEvQ4v+DK3e3J8pKSldahfyE04TExM1fvx4VVZW+va1traqsrJShYWFoX45AAAQZcIy7VJSUqJZs2ZpwoQJmjhxopYuXaqjR4/qlltuCcfLAQCAKBKW8PGTn/xEX375pRYuXKiGhgaNHTtW69evV1ZWVjheTm63W4sWLWo3fYPg0J+hRX+GDn0ZWvRnaNGfXecypivXxAAAAIRGWBYZAwAA6AjhAwAAWEX4AAAAVhE+AACAVVERPg4ePKibbrpJycnJSk1N1a233qojR450+j0rV67U5MmTlZycLJfLpebm5nZthgwZIpfL5bc99thjYXoXkSNc/RnMcWNBMO/72LFjmjt3rtLT0zVgwADNmDGj3cJ8Z/5uulwurVmzJpxvxRHLli3TkCFD1LdvXxUUFOj999/vtP1rr72mUaNGqW/fvrrwwgv15z//2e95Y4wWLlyo7Oxs9evXT0VFRdqzZ08430JECXV/zp49u93vYXFxcTjfQsToTl9+9NFHmjFjhu9zZenSpT0+ZkwLyQ1dwqy4uNiMGTPGbNmyxfztb38zw4YNMzNnzuz0e8rLy01ZWZkpKyszksx///vfdm0GDx5sFi9ebOrr633bkSNHwvQuIke4+jOY48aCYN73nXfeaXJzc01lZaXZunWrmTRpkrnkkkv82kgyq1at8vv9/Prrr8P5Vqxbs2aNSUxMNM8//7z56KOPzG233WZSU1NNY2NjwPbvvfeeiY+PN0888YTZuXOnefDBB01CQoL58MMPfW0ee+wxk5KSYtatW2f+8Y9/mGuuucbk5+fHXN8FEo7+nDVrlikuLvb7PTx48KCtt+SY7vbl+++/b+bPn29efvll4/F4THl5eY+PGcsiPnzs3LnTSDJ///vfffveeust43K5zGeffXbW76+qquo0fAT6BYll4erPnh43WgXzvpubm01CQoJ57bXXfPs+/vhjI8lUV1f79kkya9euDVvtkWDixIlm7ty5vsenTp0yOTk5pqysLGD7H//4x+aHP/yh376CggJzxx13GGOMaW1tNR6PxyxZssT3fHNzs3G73ebll18OwzuILKHuT2O+CR/XXnttWOqNZN3ty9N19NnSk2PGmoifdqmurlZqaqomTJjg21dUVKS4uDjV1NT0+PiPPfaY0tPTNW7cOC1ZskQnT57s8TEjWbj6M9w/p0gVzPvetm2bTpw4oaKiIt++UaNGKS8vT9XV1X5t586dq0GDBmnixIl6/vnnu3ar6ihx/Phxbdu2za8f4uLiVFRU1K4f2lRXV/u1l6Rp06b52tfW1qqhocGvTUpKigoKCjo8ZqwIR3+22bBhgzIzMzVy5EjdddddampqCv0biCDB9KUTx4xmjt/V9mwaGhqUmZnpt69Pnz5KS0tTQ0NDj459zz336KKLLlJaWpo2b96s0tJS1dfX66mnnurRcSNZuPoznD+nSBbM+25oaFBiYmK7GyhmZWX5fc/ixYt1xRVXKCkpSX/96181Z84cHTlyRPfcc0/I34cT/vOf/+jUqVPtVj7OysrSv/71r4Df09DQELB9W7+1/bezNrEqHP0pScXFxbrhhhuUn5+vffv26Ve/+pWmT5+u6upqxcfHh/6NRIBg+tKJY0Yzx8LHAw88oMcff7zTNh9//HFYaygpKfF9PXr0aCUmJuqOO+5QWVlZ1C2PGwn9GUsioT8XLFjg+3rcuHE6evSolixZEjPhA9Hhxhtv9H194YUXavTo0TrvvPO0YcMGTZ061cHKEM0cCx+/+MUvNHv27E7bDB06VB6PR1988YXf/pMnT+rgwYPyeDwhramgoEAnT57U/v37NXLkyJAeO9yc7k+bPycbwtmfHo9Hx48fV3Nzs9/oR2NjY6d9VVBQoEceeUQtLS1RF44DGTRokOLj49td5dNZP3g8nk7bt/23sbFR2dnZfm3G
jh0bwuojTzj6M5ChQ4dq0KBB2rt3b8yGj2D60oljRjPHzvnIyMjQqFGjOt0SExNVWFio5uZmbdu2zfe977zzjlpbW1VQUBDSmnbs2KG4uLh2w+jRwOn+tPlzsiGc/Tl+/HglJCSosrLSt2/Xrl2qq6tTYWFhhzXt2LFD55xzTkwED0lKTEzU+PHj/fqhtbVVlZWVHfZDYWGhX3tJqqio8LXPz8+Xx+Pxa+P1elVTU9Np38aCcPRnIAcOHFBTU5NfuIs1wfSlE8eMak6f8doVxcXFZty4caampsa8++67Zvjw4X6XMh44cMCMHDnS1NTU+PbV19eb7du3m+eee85IMps2bTLbt283TU1NxhhjNm/ebMrLy82OHTvMvn37zO9//3uTkZFhbr75Zuvvz7Zw9GdXjhurgunPO++80+Tl5Zl33nnHbN261RQWFprCwkLf83/84x/Nc889Zz788EOzZ88e8+yzz5qkpCSzcOFCq+8t3NasWWPcbrdZvXq12blzp7n99ttNamqqaWhoMMYY89Of/tQ88MADvvbvvfee6dOnj/nNb35jPv74Y7No0aKAl9qmpqaa119/3XzwwQfm2muv7VWX2oayPw8fPmzmz59vqqurTW1trXn77bfNRRddZIYPH26OHTvmyHu0pbt92dLSYrZv3262b99usrOzzfz588327dvNnj17unzM3iQqwkdTU5OZOXOmGTBggElOTja33HKLOXz4sO/52tpaI8lUVVX59i1atMhIaretWrXKGGPMtm3bTEFBgUlJSTF9+/Y1559/vnn00Udj/g/KmPD0Z1eOG6uC6c+vv/7azJkzx5xzzjkmKSnJXH/99aa+vt73/FtvvWXGjh1rBgwYYPr372/GjBljVqxYYU6dOmXzrVnxzDPPmLy8PJOYmGgmTpxotmzZ4nvu8ssvN7NmzfJr/+qrr5oRI0aYxMREc8EFF5g333zT7/nW1lazYMECk5WVZdxut5k6darZtWuXjbcSEULZn1999ZW56qqrTEZGhklISDCDBw82t912W6/5sOxOX7b9nZ+5XX755V0+Zm/iMiaGrt0DAAARL+LX+QAAALGF8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMCq/wclEuyr0uXuGQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize trained sub_params vs their initial values\n", + "final_sub_params = tree_map(lambda p: p.detach().clone(), dict(model.named_parameters()))\n", + "\n", + "init_untrained_params = torch.cat([v.flatten() for k, v in init_sub_params.items() if 'bert' not in k])\n", + "final_untrained_params = torch.cat([v.flatten() for k, v in final_sub_params.items() if 'bert' not in k])\n", + "\n", + "plt.hist(init_untrained_params.cpu().numpy(), bins=100, alpha=0.5, label='Init', density=True)\n", + "plt.hist(final_untrained_params.cpu().numpy(), bins=100, alpha=0.5, label='Final', density=True)\n", + "plt.legend();" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcE0lEQVR4nO3df3SWdf348dfGZJCyERQbqxHTzB9pmKg08XxMna3iePTIST2RkalUjgopDUogSp16yjgYSnoM9BxRs6NWaqhnllRO1IGdfhjqAZOyjTy63YgxkF3fPzre38+EVOge9/ve5/E45zrHve/rvnjd18Q9vXbdW1mWZVkAACSkvNgDAAC8kUABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgORXFHmBP9PX1xQsvvBAjRoyIsrKyYo8DALwNWZbF5s2bo66uLsrL3/waSUkGygsvvBD19fXFHgMA2AMbN26M9773vW+6T0kGyogRIyLi3y+wqqqqyNMAAG9HLpeL+vr6/NfxN1OSgfL6t3WqqqoECgCUmLdze4abZAGA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDm7HSirVq2KU045Jerq6qKsrCzuvvvufo9nWRbz58+PsWPHxvDhw6OpqSmeeeaZfvu89NJLMW3atKiqqoqRI0fGueeeG6+88sp/9UIAgMFjtwNly5YtMWHChFiyZMkuH7/qqqti8eLFsXTp0li9enXsu+++0dzcHFu3bs3vM23atPjTn/4UDz74YNxzzz2xatWqmDFjxp6/CgBgUCnLsizb4yeXlcVdd90Vp512WkT8++pJXV1dfO1rX4uvf/3rERHR09MTNTU1sXz58jjrrLPiqaeeikMPPTQef/zxOOqooyIiYuXKlfHJT34y/va3v0VdXd1b/rm5XC6qq6ujp6fHLwsEgBKxO1+/C3oPyoYNG6KzszOamprya9XV1TFp0qRob2+PiIj29vYYOXJkPk4iIpqamqK8vDxWr169y+P29vZGLpfrtwEAg1dFIQ/W2dkZERE1NTX91mtqavKPdXZ2xpgxY/oPUVERo0aNyu/zRq2trbFw4cJCjgokbPyce3dae+6KKUWYBCiWkngXz9y5c6Onpye/bdy4sdgjAQADqKCBUltbGxERXV1d/da7urryj9XW1samTZv6Pf7aa6/FSy+9lN/njSorK6OqqqrfBgAMXgUNlIaGhqitrY22trb8Wi6Xi9WrV0djY2NERDQ2NkZ3d3d0dHTk93nooYeir68vJk2aVMhxAIAStdv3oLzyyivx7LPP5j/esGFDPPnkkzFq1KgYN25czJo1Ky699NI48MADo6GhIebNmxd1dXX5d/occsgh8fGPfzzOP//8WLp0aWzfvj1mzpwZZ5111tt6Bw8AMPjtdqA88cQTccIJJ+Q/nj17dkRETJ8+PZYvXx4XX3xxbNmyJWbMmBHd3d1x3HHHxcqVK2PYsGH559xyyy0xc+bMOOmkk6K8vDymTp0aixcvLsDLAQAGg//q56AUi5+DAoObd/HA4FS0n4MCAFAIAgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5BQ8UHbs2BHz5s2LhoaGGD58eBxwwAHx3e9+N7Isy++TZVnMnz8/xo4dG8OHD4+mpqZ45plnCj0KAFCiCh4oV155ZVx33XXxwx/+MJ566qm48sor46qrroprrrkmv89VV10VixcvjqVLl8bq1atj3333jebm5ti6dWuhxwEASlBFoQ/4yCOPxKmnnhpTpkyJiIjx48fHrbfeGo899lhE/PvqyaJFi+KSSy6JU089NSIibr755qipqYm77747zjrrrEKPBACUmIJfQTn22GOjra0tnn766YiI+P3vfx+//e1v4xOf+ERERGzYsCE6Ozujqakp/5zq6uqYNGlStLe37/KYvb29kcvl+m0AwOBV8Csoc+bMiVwuFwcffHAMGTIkduzYEZdddllMmzYtIiI6OzsjIqKmpqbf82pqavKPvVFra2ssXLiw0KMCAIkq+BWUn/zkJ3HLLbfEihUrYs2aNXHTTTfF9773vbjpppv2+Jhz586Nnp6e/LZx48YCTgwApKbgV1AuuuiimDNnTv5eksMPPzz++te/Rmtra0yfPj1qa2sjIqKrqyvGjh2bf15XV1ccccQRuzxmZWVlVFZWFnpUACBRBb+C8uqrr0Z5ef/DDhkyJPr6+iIioqGhIWpra6OtrS3/eC6Xi9WrV0djY2OhxwEASlDBr6Cccsopcdlll8W4cePigx/8YKxduzauvvrq+PznPx8REWVlZTFr1qy49NJL48ADD4yGhoaYN29e1NXVxWmnnVbocQCAEl
TwQLnmmmti3rx5ccEFF8SmTZuirq4uvvCFL8T8+fPz+1x88cWxZcuWmDFjRnR3d8dxxx0XK1eujGHDhhV6HACgBJVl//tHvJaIXC4X1dXV0dPTE1VVVcUeByiw8XPu3WntuSumFGESoJB25+u338UDACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIGJFD+/ve/x2c+85kYPXp0DB8+PA4//PB44okn8o9nWRbz58+PsWPHxvDhw6OpqSmeeeaZgRgFAChBBQ+Ul19+OSZPnhz77LNP/PKXv4w///nP8f3vfz/e+c535ve56qqrYvHixbF06dJYvXp17LvvvtHc3Bxbt24t9DgAQAmqKPQBr7zyyqivr49ly5bl1xoaGvL/nGVZLFq0KC655JI49dRTIyLi5ptvjpqamrj77rvjrLPOKvRIAECJKfgVlJ///Odx1FFHxac+9akYM2ZMfPjDH44bbrgh//iGDRuis7Mzmpqa8mvV1dUxadKkaG9v3+Uxe3t7I5fL9dsAgMGr4IGyfv36uO666+LAAw+M+++/P770pS/FV77ylbjpppsiIqKzszMiImpqavo9r6amJv/YG7W2tkZ1dXV+q6+vL/TYAEBCCh4ofX19ceSRR8bll18eH/7wh2PGjBlx/vnnx9KlS/f4mHPnzo2enp78tnHjxgJODACkpuCBMnbs2Dj00EP7rR1yyCHx/PPPR0REbW1tRER0dXX126erqyv/2BtVVlZGVVVVvw0AGLwKHiiTJ0+OdevW9Vt7+umn433ve19E/PuG2dra2mhra8s/nsvlYvXq1dHY2FjocQCAElTwd/FceOGFceyxx8bll18eZ5xxRjz22GNx/fXXx/XXXx8REWVlZTFr1qy49NJL48ADD4yGhoaYN29e1NXVxWmnnVbocQCAElTwQDn66KPjrrvuirlz58Z3vvOdaGhoiEWLFsW0adPy+1x88cWxZcuWmDFjRnR3d8dxxx0XK1eujGHDhhV6HACgBJVlWZYVe4jdlcvlorq6Onp6etyPAoPQ+Dn37rT23BVTijAJUEi78/Xb7+IBAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkDHihXXHFFlJWVxaxZs/JrW7dujZaWlhg9enTst99+MXXq1Ojq6hroUQCAEjGggfL444/Hj370o/jQhz7Ub/3CCy+MX/ziF3HHHXfEww8/HC+88EKcfvrpAzkKAFBCBixQXnnllZg2bVrccMMN8c53vjO/3tPTEzfeeGNcffXVceKJJ8bEiRNj2bJl8cgjj8Sjjz46UOMAACVkwAKlpaUlpkyZEk1NTf3WOzo6Yvv27f3WDz744Bg3bly0t7fv8li9vb2Ry+X6bQDA4FUxEAe97bbbYs2aNfH444/v9FhnZ2cMHTo0Ro4c2W+9pqYmOjs7d3m81tbWWLhw4UCMCgAkqOBXUDZu3Bhf/epX45Zbbolhw4YV5Jhz586Nnp6e/LZx48aCHBcASFPBA6WjoyM2bdoURx55ZFRUVERFRUU8/PDDsXjx4qioqIiamprYtm1bdHd393teV1dX1NbW7vKYlZWVUVVV1W8DAAavgn+L56STToo//OEP/dbOOeecOPjgg+Mb3/hG1NfXxz777BNtbW0xderUiIhYt25dPP/889HY2FjocQCAElTwQBkxYkQcdthh/db23XffGD16dH793HPPjdmzZ8eoUaOiqqoqvvzlL0djY2N85CMfKfQ4AEAJGpCbZN/KD37wgygvL4+pU6dGb29vNDc3x7XXXluMUQCABJVlWZYVe4jdlcvlorq6Onp6etyPAoPQ+Dn37rT23BVTijAJUEi78/Xb7+IBAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBITsEDpbW1NY4++ugYMWJEjBkzJk477bRYt25dv322bt0aLS0tMXr06Nhvv/1i6tSp0dXVVehRAIASVfBAefjhh6OlpSUeffTRePDBB2P79u3xsY99LLZs2ZLf58ILL4xf/OIXcccdd8TDDz8cL7zwQpx++umFHgUAKFEVhT7gypUr+328fPnyGDNmTHR0dMT//M//RE9PT9x4442xYsWKOPHEEyMiYtmyZXHIIYfEo48+Gh/5yEcKPRIAUGIG/B6Unp6eiIgYNWpURER0dHTE9u3bo6mpKb/PwQcfHOPGjYv29vZdHqO3tzdyuVy/DQAYvAY0UPr6+mLWrFkxefLkOOywwyIiorOzM4YOHRojR47st29NTU10dnbu8jitra1RXV2d3+rr6wdybACgyAY0UFpaWuKPf/xj3Hbbbf/VcebOnRs9PT35bePGjQWaEABIUcHvQXndzJkz45577olVq1bFe9/73vx6bW1tbNu2Lbq7u/tdRenq6ora2tpdHquysjIqKysHalQAIDEFv4KSZVnMnDkz7rrrrnjooYeioaGh3+MTJ06MffbZJ9ra2vJr69ati+effz4aGxsLPQ4AUIIKfgWlpaUlVqxYET/72c9ixIgR+ftKqqurY/jw4VFdXR3nnntuzJ49O0aNGhVVV
VXx5S9/ORobG72DBwCIiAEIlOuuuy4iIj760Y/2W1+2bFl87nOfi4iIH/zgB1FeXh5Tp06N3t7eaG5ujmuvvbbQowAAJarggZJl2VvuM2zYsFiyZEksWbKk0H88ADAI+F08AEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACSnqIGyZMmSGD9+fAwbNiwmTZoUjz32WDHHAQASUbRAuf3222P27NmxYMGCWLNmTUyYMCGam5tj06ZNxRoJAEhE0QLl6quvjvPPPz/OOeecOPTQQ2Pp0qXxjne8I3784x8XayQAIBEVxfhDt23bFh0dHTF37tz8Wnl5eTQ1NUV7e/tO+/f29kZvb2/+456enoiIyOVyAz8ssNf19b6605q/71D6Xv97nGXZW+5blEB58cUXY8eOHVFTU9NvvaamJv7yl7/stH9ra2ssXLhwp/X6+voBmxFIS/WiYk8AFMrmzZujurr6TfcpSqDsrrlz58bs2bPzH/f19cVLL70Uo0ePjrKysiJOloZcLhf19fWxcePGqKqqKvY4g5bzvHc4z3uH87x3OM/9ZVkWmzdvjrq6urfctyiB8q53vSuGDBkSXV1d/da7urqitrZ2p/0rKyujsrKy39rIkSMHcsSSVFVV5S/AXuA87x3O897hPO8dzvP/91ZXTl5XlJtkhw4dGhMnToy2trb8Wl9fX7S1tUVjY2MxRgIAElK0b/HMnj07pk+fHkcddVQcc8wxsWjRotiyZUucc845xRoJAEhE0QLlzDPPjH/+858xf/786OzsjCOOOCJWrly5042zvLXKyspYsGDBTt8Go7Cc573Ded47nOe9w3nec2XZ23mvDwDAXuR38QAAyREoAEByBAoAkByBAgAkR6CUiCVLlsT48eNj2LBhMWnSpHjsscfedP/u7u5oaWmJsWPHRmVlZXzgAx+I++67by9NW7p29zwvWrQoDjrooBg+fHjU19fHhRdeGFu3bt1L05amVatWxSmnnBJ1dXVRVlYWd99991s+59e//nUceeSRUVlZGe9///tj+fLlAz5nqdvd83znnXfGySefHO9+97ujqqoqGhsb4/777987w5awPfn3+XW/+93voqKiIo444ogBm6+UCZQScPvtt8fs2bNjwYIFsWbNmpgwYUI0NzfHpk2bdrn/tm3b4uSTT47nnnsufvrTn8a6devihhtuiPe85z17efLSsrvnecWKFTFnzpxYsGBBPPXUU3HjjTfG7bffHt/85jf38uSlZcuWLTFhwoRYsmTJ29p/w4YNMWXKlDjhhBPiySefjFmzZsV5553ni+db2N3zvGrVqjj55JPjvvvui46OjjjhhBPilFNOibVr1w7wpKVtd8/z67q7u+Ozn/1snHTSSQM02SCQkbxjjjkma2lpyX+8Y8eOrK6uLmttbd3l/tddd122//77Z9u2bdtbIw4Ku3ueW1pashNPPLHf2uzZs7PJkycP6JyDSURkd91115vuc/HFF2cf/OAH+62deeaZWXNz8wBONri8nfO8K4ceemi2cOHCwg80SO3OeT7zzDOzSy65JFuwYEE2YcKEAZ2rVLmCkrht27ZFR0dHNDU15dfKy8ujqakp2tvbd/mcn//859HY2BgtLS1RU1MThx12WFx++eWxY8eOvTV2ydmT83zsscdGR0dH/ttA69evj/vuuy8++clP7pWZ/69ob2/v93mJiGhubv6PnxcKo6+vLzZv3hyjRo0q9iiDzrJly2L9+vWxYMGCYo+StJL4bcb/l7344ouxY8eOnX7Cbk1NTfzlL3/Z5XPWr18fDz30UEybNi3uu+++ePbZZ+OCCy6I7du3+wvxH+zJef70pz8dL774Yhx33HGRZVm89tpr8cUvftG3eAqss7Nzl5+XXC4X//rXv2L48OFFmmxw+973vhevvPJKnHHGGcUeZVB55plnYs6cOfGb3/wmKip8CX4zrqAMQn19fTFmzJi4/vrrY+LEiXHmmWfGt771rVi6dGmxRxtUfv3rX8fll18e1157baxZsybuvPPOuPfee+O73/1usUeD/8qKFSti4cKF8ZOf/CTGjBlT7HEGjR07dsSnP/3pWLhwYXzgAx8o9jjJk2+Je9e73hVDhgyJrq6ufutdXV1RW1u7y+eMHTs29tlnnxgyZEh+7ZBDDonOzs7Ytm1bDB06dEBnLkV7cp7nzZsXZ599dpx33nkREXH44YfHli1bYsaMGfGtb30rysv1fyHU1tbu8vNSVVXl6skAuO222+K8886LO+64Y6dvrfHf2bx5czzxxBOxdu3amDlzZkT8+38osyyLioqKeOCBB+LEE08s8pTp8F/QxA0dOjQmTpwYbW1t+bW+vr5oa2uLxsbGXT5n8uTJ8eyzz0ZfX19+7emnn46xY8eKk/9gT87zq6++ulOEvB6FmV9xVTCNjY39Pi8REQ8++OB//Lyw52699dY455xz4tZbb40pU6YUe5xBp6qqKv7whz/Ek08+md+++MUvxkEHHRRPPvlkTJo0qdgjpqXIN+nyNtx2221ZZWVltnz58uzPf/5zNmPGjGzkyJFZZ2dnlmVZdvbZZ2dz5szJ7//8889nI0aMyGbOnJmtW7cuu+eee7IxY8Zkl156abFeQknY3fO8YMGCbMSIEdmtt96arV+/PnvggQeyAw44IDvjjDOK9RJKwubNm7O1a9dma9euzSIiu/rqq7O1a9dmf/3rX7Msy7I5c+ZkZ599dn7/9evXZ+94xzuyiy66KHvqqaeyJUuWZEOGDMlWrlxZrJdQEnb3PN9yyy1ZRUVFtmTJkuwf//hHfuvu7i7WSygJu3ue38i7eP4zgVIirrnmmmzcuHHZ0KFDs2OOOSZ79NFH848df/zx2fTp0/vt/8gjj2STJk3KKisrs/333z+77LLLstdee20vT116duc8b9++Pfv2t7+dHXDAAdmwYcOy+vr67IILLshefvnlvT94CfnVr36VRcRO2+vndvr06dnxxx+/03OOOOKIbOjQodn++++fLVu2bK/PXWp29zwff/zxb7o/u7Yn/z7/bwLlPyvLMteiAYC0uAcFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOf8PN8IvwJdpcogAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize the standard deviations of the final Normal distribution\n", + "sd_diag = torch.cat([v.detach().cpu().flatten() for v in ekf_state.sd_diag.values()]).numpy()\n", + "\n", + "plt.hist(sd_diag, bins=100, density=True);" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# Save state\n", + "def detach(x):\n", + " if isinstance(x, torch.Tensor):\n", + " return x.detach().cpu()\n", + "\n", + "\n", + "ekf_state = tree_map_(detach, ekf_state)\n", + "pickle.dump(ekf_state, open(\"yelp_ekf_state.pkl\", \"wb\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Alternative implementation that updates mu and log_sigma directly without using the\n", + "# uqlib init+update API\n", + "\n", + "# from torch.optim import AdamW\n", + "# from transformers import get_scheduler\n", + "\n", + "\n", + "# mu = dict(model.named_parameters())\n", + "# log_sigma = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), mu)\n", + "\n", + "# vi_params_tensors = list(mu.values()) + list(log_sigma.values())\n", + "\n", + "# vi_optimizer = AdamW(vi_params_tensors, lr=5e-5)\n", + "# vi_lr_scheduler = get_scheduler(\n", + "# name=\"linear\",\n", + "# optimizer=vi_optimizer,\n", + "# num_warmup_steps=0,\n", + "# num_training_steps=num_training_steps,\n", + "# )\n", + "\n", + "# progress_bar = tqdm(range(num_training_steps))\n", + "\n", + "# nelbos = []\n", + "\n", + "# # model.train()\n", + "# for epoch in range(num_epochs):\n", + "# for batch in train_dataloader:\n", + "# batch = {k: v.to(device) for k, v in batch.items()}\n", + "# vi_optimizer.zero_grad()\n", + "\n", + "# sigma = tree_map(torch.exp, log_sigma)\n", + "\n", + "# nelbo = uqlib.vi.diag.nelbo(\n", + "# mu,\n", + "# sigma,\n", + "# batch,\n", + "# param_to_log_posterior,\n", + "# )\n", + "\n", + "# nelbo.backward()\n", + "# nelbos.append(nelbo.item())\n", + "\n", + "# vi_optimizer.step()\n", + "# vi_lr_scheduler.step()\n", + "# progress_bar.update(1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/yelp/yelp_subspace_vi_diag.ipynb b/examples/yelp/yelp_subspace_vi_diag.ipynb index 1f4906b0..85b5d4e7 100644 --- a/examples/yelp/yelp_subspace_vi_diag.ipynb +++ b/examples/yelp/yelp_subspace_vi_diag.ipynb @@ -242,7 +242,7 @@ } ], "source": [ - "# Visualize the standard deviations of the Laplace approximation\n", + "# Visualize the standard deviations of the final Normal distribution\n", "sd_diag = torch.cat([v.exp().detach().cpu().flatten() for v in vi_state.log_sd_diag.values()]).numpy()\n", "\n", "plt.hist(sd_diag, bins=100, density=True);" diff --git a/tests/ekf/__init__.py b/tests/ekf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ekf/test_diag_fisher.py b/tests/ekf/test_diag_fisher.py new file mode 100644 index 00000000..9596c08f --- /dev/null +++ b/tests/ekf/test_diag_fisher.py @@ -0,0 +1,71 @@ +from functools import partial +from typing import Any +import torch +from optree import tree_map + +from uqlib import 
ekf +from uqlib.utils import diag_normal_log_prob + + +def batch_normal_log_prob( + p: dict, batch: Any, mean: dict, sd_diag: dict +) -> torch.Tensor: + return diag_normal_log_prob(p, mean, sd_diag) + + +def test_ekf_diag(): + torch.manual_seed(42) + target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)} + target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean) + + batch_normal_log_prob_spec = partial( + batch_normal_log_prob, mean=target_mean, sd_diag=target_sds + ) + + init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean) + + batch = torch.arange(3).reshape(-1, 1) + + n_steps = 1000 + transform = ekf.diag_fisher.build(batch_normal_log_prob_spec, lr=1e-3) + + state = transform.init(init_mean) + + log_liks = [] + + for _ in range(n_steps): + state = transform.update(state, batch) + log_liks.append(state.log_likelihood) + + for key in state.mean: + assert torch.allclose(state.mean[key], target_mean[key], atol=1e-1) + + # Test inplace + state_ip = transform.init(init_mean) + state_ip2 = transform.update( + state_ip, + batch, + inplace=True, + ) + + for key in state_ip2.mean: + assert torch.allclose(state_ip2.mean[key], state_ip.mean[key], atol=1e-8) + assert torch.allclose(state_ip2.sd_diag[key], state_ip.sd_diag[key], atol=1e-8) + + # Test not inplace + state_ip_false = transform.init( + tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean) + ) + state_ip_false2 = transform.update( + state_ip_false, + batch, + inplace=False, + ) + + for key in state_ip.mean: + assert not torch.allclose( + state_ip_false2.mean[key], state_ip_false.mean[key], atol=1e-8 + ) + assert not torch.allclose( + state_ip_false2.sd_diag[key], state_ip_false.sd_diag[key], atol=1e-8 + ) diff --git a/tests/ekf/test_diag_hessian.py b/tests/ekf/test_diag_hessian.py new file mode 100644 index 00000000..4b79fb2f --- /dev/null +++ b/tests/ekf/test_diag_hessian.py @@ -0,0 +1,71 @@ +from functools import partial +from typing import Any +import torch +from optree import tree_map + +from uqlib import ekf +from uqlib.utils import diag_normal_log_prob + + +def batch_normal_log_prob( + p: dict, batch: Any, mean: dict, sd_diag: dict +) -> torch.Tensor: + return diag_normal_log_prob(p, mean, sd_diag) + + +def test_ekf_diag(): + torch.manual_seed(42) + target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)} + target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean) + + batch_normal_log_prob_spec = partial( + batch_normal_log_prob, mean=target_mean, sd_diag=target_sds + ) + + init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean) + + batch = torch.arange(3).reshape(-1, 1) + + n_steps = 1000 + transform = ekf.diag_hessian.build(batch_normal_log_prob_spec, lr=1e-3) + + state = transform.init(init_mean) + + log_liks = [] + + for _ in range(n_steps): + state = transform.update(state, batch) + log_liks.append(state.log_likelihood) + + for key in state.mean: + assert torch.allclose(state.mean[key], target_mean[key], atol=1e-1) + + # Test inplace + state_ip = transform.init(init_mean) + state_ip2 = transform.update( + state_ip, + batch, + inplace=True, + ) + + for key in state_ip2.mean: + assert torch.allclose(state_ip2.mean[key], state_ip.mean[key], atol=1e-8) + assert torch.allclose(state_ip2.sd_diag[key], state_ip.sd_diag[key], atol=1e-8) + + # Test not inplace + state_ip_false = transform.init( + tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean) + ) + state_ip_false2 = transform.update( + 
state_ip_false,
+        batch,
+        inplace=False,
+    )
+
+    for key in state_ip.mean:
+        assert not torch.allclose(
+            state_ip_false2.mean[key], state_ip_false.mean[key], atol=1e-8
+        )
+        assert not torch.allclose(
+            state_ip_false2.sd_diag[key], state_ip_false.sd_diag[key], atol=1e-8
+        )
diff --git a/tests/laplace/__init__.py b/tests/laplace/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/uqlib/__init__.py b/uqlib/__init__.py
index 896b74f2..0130ca87 100644
--- a/uqlib/__init__.py
+++ b/uqlib/__init__.py
@@ -1,7 +1,8 @@
+from uqlib import ekf
 from uqlib import laplace
-from uqlib import vi
 from uqlib import sgmcmc
 from uqlib import types
+from uqlib import vi
 
 from uqlib.utils import model_to_function
 from uqlib.utils import hvp
@@ -16,3 +17,5 @@ from uqlib.utils import insert_requires_grad_
 from uqlib.utils import extract_requires_grad_and_func
 from uqlib.utils import inplacify
+from uqlib.utils import tree_map_inplacify_
+from uqlib.utils import flexi_tree_map
diff --git a/uqlib/ekf/__init__.py b/uqlib/ekf/__init__.py
new file mode 100644
index 00000000..619940fa
--- /dev/null
+++ b/uqlib/ekf/__init__.py
@@ -0,0 +1,2 @@
+from uqlib.ekf import diag_fisher
+from uqlib.ekf import diag_hessian
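Both new submodules share `EKFDiagState` and the same predict/correct recursion below; they differ only in the curvature estimate plugged into the variance update: `diag_fisher` uses the negative mean squared per-sample gradient (the empirical Fisher), `diag_hessian` the exact diagonal Hessian. For a Gaussian log-likelihood the two agree in expectation, which the following standalone sketch (independent of uqlib) illustrates:

```python
import torch

# For log N(y | mu, s^2): per-sample gradient g_i = (y_i - mu) / s^2 and
# diagonal Hessian -1/s^2. With y_i ~ N(mu, s^2), E[g_i^2] = 1/s^2, so the
# empirical Fisher curvature -mean(g_i^2) matches the Hessian in expectation.
torch.manual_seed(0)
mu, s = 0.0, 0.5
y = mu + s * torch.randn(100_000)
g = (y - mu) / s**2
print((-(g**2).mean()).item())  # ≈ -4.0
print(-1 / s**2)                # exactly -4.0
```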
diff --git a/uqlib/ekf/diag_fisher.py b/uqlib/ekf/diag_fisher.py
new file mode 100644
index 00000000..4d2abd04
--- /dev/null
+++ b/uqlib/ekf/diag_fisher.py
@@ -0,0 +1,170 @@
+from typing import Callable, Any, NamedTuple
+from functools import partial
+
+import torch
+from torch.func import vmap, jacrev
+from optree import tree_map
+
+from uqlib.types import TensorTree, Transform
+from uqlib.utils import diag_normal_sample, flexi_tree_map
+
+
+class EKFDiagState(NamedTuple):
+    """State encoding a diagonal Normal distribution over parameters.
+
+    Args:
+        mean: Mean of the Normal distribution.
+        sd_diag: Square-root diagonal of the covariance matrix of the
+            Normal distribution.
+        log_likelihood: Log likelihood of the data given the parameters.
+    """
+
+    mean: TensorTree
+    sd_diag: TensorTree
+    log_likelihood: float = 0
+
+
+def init(
+    params: TensorTree,
+    init_sds: TensorTree | None = None,
+) -> EKFDiagState:
+    """Initialise diagonal Normal distribution over parameters.
+
+    Args:
+        params: Initial mean of the variational distribution.
+        init_sds: Initial square-root diagonal of the covariance matrix
+            of the variational distribution. Defaults to ones.
+
+    Returns:
+        Initial EKFDiagState.
+    """
+    if init_sds is None:
+        init_sds = tree_map(
+            lambda x: torch.ones_like(x, requires_grad=x.requires_grad), params
+        )
+
+    return EKFDiagState(params, init_sds)
+
+
+def update(
+    state: EKFDiagState,
+    batch: Any,
+    log_likelihood: Callable[[TensorTree, Any], float],
+    lr: float,
+    transition_sd: float = 0.0,
+    per_sample: bool = False,
+    inplace: bool = True,
+) -> EKFDiagState:
+    """Applies an extended Kalman Filter update to the diagonal Normal distribution.
+    The update is first order in the sense that it only uses gradient information,
+    i.e. the log-likelihood is approximated by the quadratic expansion
+
+    log p(y | x, p) ≈ log p(y | x, μ) + lr * g(μ)ᵀ(p - μ)
+    + lr * 1/2 (p - μ)ᵀ F_d(μ) (p - μ)
+
+    where μ is the mean of the variational distribution, lr = T⁻¹ is the likelihood
+    inverse temperature (which behaves like a learning rate), whilst g(μ) is the
+    gradient and F_d(μ) the negative diagonal empirical Fisher of the
+    log-likelihood with respect to the parameters.
+
+    Args:
+        state: Current state.
+        batch: Input data to log_likelihood.
+        log_likelihood: Function that takes parameters and input batch and
+            returns the log-likelihood.
+        lr: Inverse temperature of the update, which behaves like a learning rate.
+            See https://arxiv.org/abs/1703.00209 for details.
+        transition_sd: Standard deviation of the transition noise, to additively
+            inflate the diagonal covariance before the update. Defaults to zero.
+        per_sample: If True, then log_likelihood is assumed to return a vector of
+            log likelihoods for each sample in the batch. If False, then log_likelihood
+            is assumed to return a scalar log likelihood for the whole batch; in this
+            case torch.func.vmap will be called, which is typically slower than
+            writing log_likelihood to be per-sample directly.
+        inplace: Whether to update the state parameters in-place.
+
+    Returns:
+        Updated EKFDiagState.
+    """
+
+    if per_sample:
+        log_likelihood_per_sample = log_likelihood
+    else:
+        # per-sample gradients following https://pytorch.org/tutorials/intermediate/per_sample_grads.html
+        @partial(vmap, in_dims=(None, 0))
+        def log_likelihood_per_sample(params, batch):
+            batch = tree_map(lambda x: x.unsqueeze(0), batch)
+            return log_likelihood(params, batch)
+
+    predict_sd_diag = flexi_tree_map(
+        lambda x: (x**2 + transition_sd**2) ** 0.5, state.sd_diag, inplace=inplace
+    )
+    with torch.no_grad():
+        log_lik = log_likelihood_per_sample(state.mean, batch).mean()
+        jac = jacrev(log_likelihood_per_sample)(state.mean, batch)
+        grad = tree_map(lambda x: x.mean(0), jac)
+        diag_lik_hessian_approx = tree_map(lambda x: -(x**2).mean(0), jac)
+
+    update_sd_diag = flexi_tree_map(
+        lambda sig, h: (sig**-2 - lr * h) ** -0.5,
+        predict_sd_diag,
+        diag_lik_hessian_approx,
+        inplace=inplace,
+    )
+    update_mean = flexi_tree_map(
+        lambda mu, sig, g: mu + sig**2 * lr * g,
+        state.mean,
+        update_sd_diag,
+        grad,
+        inplace=inplace,
+    )
+    return EKFDiagState(update_mean, update_sd_diag, log_lik.item())
+
+
+def build(
+    log_likelihood: Callable[[TensorTree, Any], float],
+    lr: float,
+    transition_sd: float = 0.0,
+    per_sample: bool = False,
+    init_sds: TensorTree | None = None,
+) -> Transform:
+    """Builds a transform for an extended Kalman filter with a diagonal Normal
+    distribution over parameters.
+
+    Args:
+        log_likelihood: Function that takes parameters and input batch and
+            returns the log-likelihood.
+        lr: Inverse temperature of the update, which behaves like a learning rate.
+            See https://arxiv.org/abs/1703.00209 for details.
+        transition_sd: Standard deviation of the transition noise, to additively
+            inflate the diagonal covariance before the update. Defaults to zero.
+        per_sample: If True, then log_likelihood is assumed to return a vector of
+            log likelihoods for each sample in the batch. If False, then log_likelihood
+            is assumed to return a scalar log likelihood for the whole batch; in this
+            case torch.func.vmap will be called, which is typically slower than
+            writing log_likelihood to be per-sample directly.
+        init_sds: Initial square-root diagonal of the covariance matrix
+            of the variational distribution. Defaults to ones.
+
+    Returns:
+        Diagonal EKF transform (uqlib.types.Transform instance).
+    """
+    init_fn = partial(init, init_sds=init_sds)
+    update_fn = partial(
+        update,
+        log_likelihood=log_likelihood,
+        lr=lr,
+        transition_sd=transition_sd,
+        per_sample=per_sample,
+    )
+    return Transform(init_fn, update_fn)
+
+
+def sample(state: EKFDiagState, sample_shape: torch.Size = torch.Size([])):
+    """Sample from the diagonal Normal distribution over parameters.
+
+    Args:
+        state: State encoding mean and standard deviations.
+        sample_shape: Shape of the desired sample batch. Defaults to an empty
+            tuple, i.e. a single sample.
+
+    Returns:
+        Sample(s) from the Normal distribution.
+    """
+    return diag_normal_sample(state.mean, state.sd_diag, sample_shape=sample_shape)
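A minimal usage sketch of the transform built above, modelled directly on tests/ekf/test_diag_fisher.py (assumes the code in this diff is installed as uqlib):

```python
import torch
from optree import tree_map

from uqlib import ekf
from uqlib.utils import diag_normal_log_prob

# Toy target: a diagonal Gaussian whose log-density ignores the batch.
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)

def log_likelihood(params, batch):
    return diag_normal_log_prob(params, target_mean, target_sds)

transform = ekf.diag_fisher.build(log_likelihood, lr=1e-3)
state = transform.init(
    tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
)

batch = torch.arange(3).reshape(-1, 1)
for _ in range(1000):
    state = transform.update(state, batch)  # inplace=True by default

# state.mean ≈ target_mean; draw posterior samples via
samples = ekf.diag_fisher.sample(state, sample_shape=torch.Size([100]))
```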
diff --git a/uqlib/ekf/diag_hessian.py b/uqlib/ekf/diag_hessian.py
new file mode 100644
index 00000000..31b56b01
--- /dev/null
+++ b/uqlib/ekf/diag_hessian.py
@@ -0,0 +1,130 @@
+from typing import Callable, Any
+from functools import partial
+
+import torch
+from torch.func import grad_and_value
+from optree import tree_map
+
+from uqlib.types import TensorTree, Transform
+from uqlib.utils import diag_normal_sample, hessian_diag, flexi_tree_map
+from uqlib.ekf.diag_fisher import EKFDiagState
+
+
+def init(
+    params: TensorTree,
+    init_sds: TensorTree | None = None,
+) -> EKFDiagState:
+    """Initialise diagonal Normal distribution over parameters.
+
+    Args:
+        params: Initial mean of the variational distribution.
+        init_sds: Initial square-root diagonal of the covariance matrix
+            of the variational distribution. Defaults to ones.
+
+    Returns:
+        Initial EKFDiagState.
+    """
+    if init_sds is None:
+        init_sds = tree_map(
+            lambda x: torch.ones_like(x, requires_grad=x.requires_grad), params
+        )
+
+    return EKFDiagState(params, init_sds)
+
+
+def update(
+    state: EKFDiagState,
+    batch: Any,
+    log_likelihood: Callable[[TensorTree, Any], float],
+    lr: float,
+    transition_sd: float = 0.0,
+    inplace: bool = True,
+) -> EKFDiagState:
+    """Applies an extended Kalman Filter update to the diagonal Normal distribution.
+    The update is second order, i.e. the log-likelihood is approximated by the
+    quadratic expansion
+
+    log p(y | x, p) ≈ log p(y | x, μ) + lr * g(μ)ᵀ(p - μ)
+    + lr * 1/2 (p - μ)ᵀ H_d(μ) (p - μ)
+
+    where μ is the mean of the variational distribution, lr = T⁻¹ is the likelihood
+    inverse temperature (which behaves like a learning rate), whilst g(μ) is the
+    gradient and H_d(μ) the diagonal Hessian of the log-likelihood with respect to
+    the parameters.
+
+    Args:
+        state: Current state.
+        batch: Input data to log_likelihood.
+        log_likelihood: Function that takes parameters and input batch and
+            returns the log-likelihood.
+        lr: Inverse temperature of the update, which behaves like a learning rate.
+            See https://arxiv.org/abs/1703.00209 for details.
+        transition_sd: Standard deviation of the transition noise, to additively
+            inflate the diagonal covariance before the update. Defaults to zero.
+        inplace: Whether to update the state parameters in-place.
+
+    Returns:
+        Updated EKFDiagState.
+    """
+    predict_sd_diag = flexi_tree_map(
+        lambda x: (x**2 + transition_sd**2) ** 0.5, state.sd_diag, inplace=inplace
+    )
+    with torch.no_grad():
+        grad, log_lik = grad_and_value(log_likelihood)(state.mean, batch)
+        diag_hessian = hessian_diag(log_likelihood)(state.mean, batch)
+
+    update_sd_diag = flexi_tree_map(
+        lambda sig, h: (sig**-2 - lr * h) ** -0.5,
+        predict_sd_diag,
+        diag_hessian,
+        inplace=inplace,
+    )
+    update_mean = flexi_tree_map(
+        lambda mu, sig, g: mu + sig**2 * lr * g,
+        state.mean,
+        update_sd_diag,
+        grad,
+        inplace=inplace,
+    )
+    return EKFDiagState(update_mean, update_sd_diag, log_lik.item())
+
+
+def build(
+    log_likelihood: Callable[[TensorTree, Any], float],
+    lr: float,
+    transition_sd: float = 0.0,
+    init_sds: TensorTree | None = None,
+) -> Transform:
+    """Builds a transform for an extended Kalman filter with a diagonal Normal
+    distribution over parameters.
+
+    Args:
+        log_likelihood: Function that takes parameters and input batch and
+            returns the log-likelihood.
+        lr: Inverse temperature of the update, which behaves like a learning rate.
+            See https://arxiv.org/abs/1703.00209 for details.
+        transition_sd: Standard deviation of the transition noise, to additively
+            inflate the diagonal covariance before the update. Defaults to zero.
+        init_sds: Initial square-root diagonal of the covariance matrix
+            of the variational distribution. Defaults to ones.
+
+    Returns:
+        Diagonal EKF transform (uqlib.types.Transform instance).
+    """
+    init_fn = partial(init, init_sds=init_sds)
+    update_fn = partial(
+        update,
+        log_likelihood=log_likelihood,
+        lr=lr,
+        transition_sd=transition_sd,
+    )
+    return Transform(init_fn, update_fn)
+
+
+def sample(state: EKFDiagState, sample_shape: torch.Size = torch.Size([])):
+    """Sample from the diagonal Normal distribution over parameters.
+
+    Args:
+        state: State encoding mean and standard deviations.
+        sample_shape: Shape of the desired sample batch. Defaults to an empty
+            tuple, i.e. a single sample.
+
+    Returns:
+        Sample(s) from the Normal distribution.
+    """
+    return diag_normal_sample(state.mean, state.sd_diag, sample_shape=sample_shape)
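Since this variant uses the exact Hessian, a single update with lr=1 and the default unit prior reproduces the exact conjugate posterior of a linear-Gaussian model. A sanity-check sketch (again assuming the diff is installed as uqlib; the toy model and numbers are illustrative):

```python
import torch
from uqlib import ekf

obs_sd = 0.5

def log_likelihood(params, batch):
    # log N(batch | params["w"], obs_sd^2), summed over the batch
    return torch.distributions.Normal(params["w"], obs_sd).log_prob(batch).sum()

y = torch.tensor([2.0])
transform = ekf.diag_hessian.build(log_likelihood, lr=1.0)
state = transform.init({"w": torch.zeros(1, requires_grad=True)})  # prior N(0, 1)
state = transform.update(state, y, inplace=False)

# Exact conjugate posterior: precision = 1 + 1/obs_sd^2, mean = (y/obs_sd^2)/precision
post_prec = 1.0 + 1 / obs_sd**2
assert torch.allclose(state.sd_diag["w"], torch.tensor([post_prec**-0.5]))
assert torch.allclose(state.mean["w"], torch.tensor([(2.0 / obs_sd**2) / post_prec]))
```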
diff --git a/uqlib/laplace/diag_fisher.py b/uqlib/laplace/diag_fisher.py
index 0de05b3d..d46223ad 100644
--- a/uqlib/laplace/diag_fisher.py
+++ b/uqlib/laplace/diag_fisher.py
@@ -2,10 +2,10 @@ from typing import Callable, Any, NamedTuple
 
 import torch
 from torch.func import jacrev, vmap
-from optree import tree_map, tree_map_
+from optree import tree_map
 
 from uqlib.types import TensorTree, Transform
-from uqlib.utils import diag_normal_sample, inplacify
+from uqlib.utils import diag_normal_sample, flexi_tree_map
 
 
 class DiagLaplaceState(NamedTuple):
@@ -34,7 +34,9 @@ def init(
         Initial DiagVIState.
     """
     if init_prec_diag is None:
-        init_prec_diag = tree_map(lambda x: torch.zeros_like(x), params)
+        init_prec_diag = tree_map(
+            lambda x: torch.zeros_like(x, requires_grad=x.requires_grad), params
+        )
 
     return DiagLaplaceState(params, init_prec_diag)
 
@@ -84,12 +86,9 @@ def log_posterior_per_sample(params, batch):
     def update_func(x, y):
         return x + y
 
-    if inplace:
-        prec_diag = tree_map_(
-            inplacify(update_func), state.prec_diag, batch_diag_score_sq
-        )
-    else:
-        prec_diag = tree_map(update_func, state.prec_diag, batch_diag_score_sq)
+    prec_diag = flexi_tree_map(
+        update_func, state.prec_diag, batch_diag_score_sq, inplace=inplace
+    )
 
     return DiagLaplaceState(state.mean, prec_diag)
 
diff --git a/uqlib/laplace/diag_hessian.py b/uqlib/laplace/diag_hessian.py
index d34a5cca..77a69f3d 100644
--- a/uqlib/laplace/diag_hessian.py
+++ b/uqlib/laplace/diag_hessian.py
@@ -1,10 +1,10 @@
 from functools import partial
 from typing import Callable, Any
 
 import torch
-from optree import tree_map, tree_map_, tree_flatten
+from optree import tree_map, tree_flatten
 
 from uqlib.types import TensorTree, Transform
-from uqlib.utils import hessian_diag, diag_normal_sample, inplacify
+from uqlib.utils import hessian_diag, diag_normal_sample, flexi_tree_map
 from uqlib.laplace.diag_fisher import DiagLaplaceState
 
 
@@ -22,7 +22,9 @@ def init(
         Initial DiagVIState.
""" if init_prec_diag is None: - init_prec_diag = tree_map(lambda x: torch.zeros_like(x), params) + init_prec_diag = tree_map( + lambda x: torch.zeros_like(x, requires_grad=x.requires_grad), params + ) return DiagLaplaceState(params, init_prec_diag) @@ -56,10 +58,9 @@ def update( def update_func(x, y): return x - y * batch_size - if inplace: - prec_diag = tree_map_(inplacify(update_func), state.prec_diag, batch_diag_hess) - else: - prec_diag = tree_map(update_func, state.prec_diag, batch_diag_hess) + prec_diag = flexi_tree_map( + update_func, state.prec_diag, batch_diag_hess, inplace=inplace + ) return DiagLaplaceState(state.mean, prec_diag) diff --git a/uqlib/sgmcmc/sghmc.py b/uqlib/sgmcmc/sghmc.py index 410f1082..75fc853b 100644 --- a/uqlib/sgmcmc/sghmc.py +++ b/uqlib/sgmcmc/sghmc.py @@ -2,10 +2,10 @@ from functools import partial import torch from torch.func import grad_and_value -from optree import tree_map, tree_map_ +from optree import tree_map from uqlib.types import TensorTree, Transform -from uqlib.utils import inplacify +from uqlib.utils import flexi_tree_map class SGHMCState(NamedTuple): @@ -80,13 +80,10 @@ def transform_momenta(m, g): * torch.randn_like(m) ) - if inplace: - params = tree_map_(inplacify(transform_params), state.params, state.momenta) - momenta = tree_map_(inplacify(transform_momenta), state.momenta, grads) - - else: - params = tree_map(transform_params, state.params, state.momenta) - momenta = tree_map(transform_momenta, state.momenta, grads) + params = flexi_tree_map( + transform_params, state.params, state.momenta, inplace=inplace + ) + momenta = flexi_tree_map(transform_momenta, state.momenta, grads, inplace=inplace) return SGHMCState(params, momenta, log_post.item()) diff --git a/uqlib/utils.py b/uqlib/utils.py index 12c54d30..bb6d01fa 100644 --- a/uqlib/utils.py +++ b/uqlib/utils.py @@ -256,3 +256,109 @@ def func_(tens, *args, **kwargs): return tens return func_ + + +def tree_map_inplacify_( + func: Callable, + tree: TensorTree, + *rests: TensorTree, + is_leaf: Callable[[TensorTree], bool] | None = None, + none_is_leaf: bool = False, + namespace: str = "", +) -> TensorTree: + """Applies a pure function to each tensor in a PyTree in-place. + + Like optree.tree_map_ but takes a pure function as input + (and takes replaces its first argument with its output in-place) + rather than a side-effect function. + + Args: + func: A function that takes a tensor as its first argument and a returns + a modified version of said tensor. + tree (pytree): A pytree to be mapped over, with each leaf providing the first + positional argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same + structure as ``tree`` or has ``tree`` as a prefix. + is_leaf (callable, optional): An optionally specified function that will be called at each + flattening step. It should return a boolean, with :data:`True` stopping the traversal + and the whole subtree being treated as a leaf, and :data:`False` indicating the + flattening should traverse the current object. + none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`, + :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the + treespec rather than in the leaves list and :data:`None` will be remain in the result + pytree. (default: :data:`False`) + namespace (str, optional): The registry namespace used for custom pytree node types. 
+            (default: :const:`''`, i.e., the global namespace)
+
+    Returns:
+        The original ``tree``, with the value at each leaf replaced in-place by
+        ``func(x, *xs)``, where ``x`` is the value at the corresponding leaf in
+        ``tree`` and ``xs`` is the tuple of values at the corresponding nodes in
+        ``rests``.
+    """
+    return tree_map_(
+        inplacify(func),
+        tree,
+        *rests,
+        is_leaf=is_leaf,
+        none_is_leaf=none_is_leaf,
+        namespace=namespace,
+    )
+
+
+def flexi_tree_map(
+    func: Callable,
+    tree: TensorTree,
+    *rests: TensorTree,
+    inplace: bool = False,
+    is_leaf: Callable[[TensorTree], bool] | None = None,
+    none_is_leaf: bool = False,
+    namespace: str = "",
+) -> TensorTree:
+    """Applies a pure function to each tensor in a PyTree, with an inplace argument.
+
+    ```
+    out_tensor = func(tensor, *rest_tensors)
+    ```
+
+    where `out_tensor` has the same shape as `tensor`. Therefore
+
+    ```
+    out_tree = flexi_tree_map(func, tree, *rests, inplace=True)
+    ```
+
+    will return `out_tree`, a pointer to the original `tree` with its leaves (tensors)
+    modified in-place.
+    If `inplace=False`, `flexi_tree_map` is equivalent to `optree.tree_map` and returns a new tree.
+
+    Args:
+        func: A pure function that takes a tensor as its first argument and returns
+            a modified version of said tensor.
+        tree (pytree): A pytree to be mapped over, with each leaf providing the first
+            positional argument to function ``func``.
+        rests (tuple of pytree): A tuple of pytrees, each of which has the same
+            structure as ``tree`` or has ``tree`` as a prefix.
+        inplace (bool, optional): Whether to modify the tree in-place or not.
+        is_leaf (callable, optional): An optionally specified function that will be called at each
+            flattening step. It should return a boolean, with :data:`True` stopping the traversal
+            and the whole subtree being treated as a leaf, and :data:`False` indicating the
+            flattening should traverse the current object.
+        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
+            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
+            treespec rather than in the leaves list and :data:`None` will remain in the result
+            pytree. (default: :data:`False`)
+        namespace (str, optional): The registry namespace used for custom pytree node types.
+            (default: :const:`''`, i.e., the global namespace)
+
+    Returns:
+        Either the original tree modified in-place or a new tree, depending on the
+        `inplace` argument.
+    """
+    tm = tree_map_inplacify_ if inplace else tree_map
+    return tm(
+        func,
+        tree,
+        *rests,
+        is_leaf=is_leaf,
+        none_is_leaf=none_is_leaf,
+        namespace=namespace,
+    )
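For reference, a toy sketch of the two helpers' semantics, mirroring the `update_func(x, y): return x + y` usage in uqlib/laplace/diag_fisher.py above; it assumes plain tensors without autograd and relies on the in-place contract documented in the docstrings:

```python
import torch
from uqlib.utils import tree_map_inplacify_, flexi_tree_map

prec = {"a": torch.ones(2)}
batch_sq = {"a": torch.full((2,), 3.0)}

# In-place: the input tree itself is returned, with its leaves overwritten.
out = tree_map_inplacify_(lambda x, y: x + y, prec, batch_sq)
assert out is prec
assert torch.all(prec["a"] == 4.0)

# Out-of-place: flexi_tree_map with inplace=False is plain optree.tree_map.
prec2 = {"a": torch.ones(2)}
new = flexi_tree_map(lambda x, y: x + y, prec2, batch_sq, inplace=False)
assert new["a"] is not prec2["a"]
assert torch.all(prec2["a"] == 1.0)  # original untouched
```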