Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

W2D3 Post-Course Update (TA Feedback) #414

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 136 additions & 25 deletions tutorials/W2D3_Microlearning/W2D3_Tutorial1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,26 @@
"losses_data = pd.read_csv(\"losses.csv\")\n",
"snr_data = pd.read_csv(\"snr.csv\")\n",
"\n",
"losses_weight_perturbation_solution = losses_data[\"weight_perturbation\"]\n",
"losses_node_perturbation_solution = losses_data[\"node_perturbation\"]\n",
"losses_feedback_alignment_solution = losses_data[\"feedback_alignment\"]\n",
"losses_kolen_pollack_solution = losses_data[\"kolen_pollack\"]\n",
"losses_backpropagation_solution = losses_data[\"backpropagation\"]\n",
"\n",
"cosine_similarity_feedback_alignment_solution = cosine_similarity_data[\"feedback_alignment\"]\n",
"cosine_similarity_kolen_pollack_solution = cosine_similarity_data[\"kolen_pollack\"]\n",
"cosine_similarity_backpropagation_solution = cosine_similarity_data[\"backpropagation\"]\n",
"accuracy_weight_perturbation_solution = accuracy_data[\"weight_perturbation\"]\n",
"\n",
"accuracy_node_perturbation_solution = accuracy_data[\"node_perturbation\"]\n",
"accuracy_feedback_alignment_solution = accuracy_data[\"feedback_alignment\"]\n",
"accuracy_kolen_pollack_solution = accuracy_data[\"kolen_pollack\"]\n",
"accuracy_backpropagation_solution = accuracy_data[\"backpropagation\"]\n",
"\n",
"snr_weight_perturbation_solution = snr_data[\"weight_perturbation\"][0]\n",
"snr_node_perturbation_solution = snr_data[\"node_perturbation\"][0]\n",
"snr_backpropagation_solution = snr_data[\"backpropagation\"][0]\n",
"\n",
"with contextlib.redirect_stdout(io.StringIO()):\n",
" # Load the MNIST dataset, 50K training images, 10K validation, 10K testing\n",
" train_set = datasets.MNIST('./', transform=transforms.ToTensor(), train=True, download=True)\n",
Expand Down Expand Up @@ -896,7 +916,7 @@
},
"outputs": [],
"source": [
"# @title Train and observe the performance of WeightPerturbMLP\n",
"# @title Train WeightPerturbMLP\n",
"\n",
"rng_wp = np.random.default_rng(seed=seed)\n",
"losses_perturb = np.zeros((numupdates,))\n",
Expand All @@ -912,9 +932,26 @@
" learning_rate=learnrate, batch_size=batchsize, algorithm='perturb', noise=noise, \\\n",
" report=report, report_rate=rep_rate)\n",
"\n",
"# save metrics for plots\n",
"losses_weight_perturbation_solution = losses_perturb\n",
"accuracy_weight_perturbation_solution = accuracy_perturb\n",
"snr_weight_perturbation_solution = snr_perturb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {}
},
"outputs": [],
"source": [
"# @title Observe the performance of WeightPerturbMLP\n",
"\n",
"# plot performance over time\n",
"with plt.xkcd():\n",
" plt.plot(losses_data['weight_perturbation'], label=\"Weight Perturbation\", color='b') #pre-saved history of loss\n",
" plt.plot(losses_weight_perturbation_solution, label=\"Weight Perturbation\", color='b')\n",
" plt.xlabel(\"Updates\")\n",
" plt.ylabel(\"MSE\")\n",
" plt.legend()\n",
Expand Down Expand Up @@ -1025,7 +1062,7 @@
},
"outputs": [],
"source": [
"# @title Train and observe the performance of NodePerturbMLP\n",
"# @title Train NodePerturbMLP\n",
"\n",
"losses_node_perturb = np.zeros((numupdates,))\n",
"accuracy_node_perturb = np.zeros((numepochs,))\n",
Expand All @@ -1044,10 +1081,27 @@
" learning_rate=learnrate, batch_size=batchsize, algorithm='node_perturb', noise=noise, \\\n",
" report=report, report_rate=rep_rate)\n",
"\n",
"# save metrics for plots\n",
"losses_node_perturbation_solution = losses_node_perturb\n",
"accuracy_node_perturbation_solution = accuracy_node_perturb\n",
"snr_node_perturbation_solution = snr_node_perturb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {}
},
"outputs": [],
"source": [
"# @title Observe the performance of NodePerturbMLP\n",
"\n",
"# plot performance over time\n",
"with plt.xkcd():\n",
" plt.plot(losses_data['node_perturbation'], label=\"Node Perturbation\", color='c') #pre-saved history of loss\n",
" plt.plot(losses_data['weight_perturbation'], label=\"Weight Perturbation\", color='b') #pre-saved history of loss\n",
" plt.plot(losses_node_perturbation_solution, label=\"Node Perturbation\", color='c') #pre-saved history of loss\n",
" plt.plot(losses_weight_perturbation_solution, label=\"Weight Perturbation\", color='b') #pre-saved history of loss\n",
" plt.xlabel(\"Updates\")\n",
" plt.ylabel(\"MSE\")\n",
" plt.legend()\n",
Expand Down Expand Up @@ -1302,11 +1356,28 @@
" learning_rate=learnrate, batch_size=batchsize, algorithm='backprop', noise=noise, \\\n",
" report=report, report_rate=rep_rate)\n",
"\n",
"# save metrics for plots\n",
"losses_backpropagation_solution = losses_backprop\n",
"accuracy_backpropagation_solution = accuracy_backprop\n",
"snr_backpropagation_solution = snr_backprop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {}
},
"outputs": [],
"source": [
"# @title Compare the performance and SNRs for Weight Perturbation, Node Perturbation, and Backpropagation\n",
"\n",
"# plot performance over time\n",
"with plt.xkcd():\n",
" plt.plot(losses_data['node_perturbation'], label=\"Node Perturbation\", color='c') #pre-saved history of loss\n",
" plt.plot(losses_data['weight_perturbation'], label=\"Weight Perturbation\", color='b') #pre-saved history of loss\n",
" plt.plot(losses_data['backpropagation'], label=\"Backprop\", color='r') #pre-saved history of loss\n",
" plt.plot(losses_node_perturbation_solution, label=\"Node Perturbation\", color='c')\n",
" plt.plot(losses_weight_perturbation_solution, label=\"Weight Perturbation\", color='b')\n",
" plt.plot(losses_backpropagation_solution, label=\"Backprop\", color='r')\n",
" plt.xlabel(\"Updates\")\n",
" plt.ylabel(\"MSE\")\n",
" plt.legend()\n",
Expand All @@ -1317,7 +1388,7 @@
"with plt.xkcd():\n",
" plt.figure()\n",
" x = [0, 1, 2]\n",
" snr_vals = [snr_data['weight_perturbation'][0], snr_data['node_perturbation'][0], snr_data['backpropagation'][0]] #pre-saved snrs\n",
" snr_vals = [snr_weight_perturbation_solution, snr_node_perturbation_solution, snr_backpropagation_solution] #pre-saved snrs\n",
" colors = ['b', 'c', 'r']\n",
" labels = ['Weight Perturbation', 'Node Perturbation', 'Backprop']\n",
" plt.bar(x, snr_vals, color=colors, tick_label=labels)\n",
Expand Down Expand Up @@ -1626,7 +1697,7 @@
},
"outputs": [],
"source": [
"# @title Train and observe the performance of FeedbackAlignmentMLP\n",
"# @title Train FeedbackAlignmentMLP\n",
"\n",
"rng_fa = np.random.default_rng(seed=seed)\n",
"\n",
Expand All @@ -1644,6 +1715,11 @@
" learning_rate=learnrate, batch_size=batchsize, algorithm='feedback', noise=noise, \\\n",
" report=report, report_rate=rep_rate)\n",
"\n",
"# save metrics for plots\n",
"losses_feedback_alignment_solution = losses_feedback\n",
"accuracy_feedback_alignment_solution = accuracy_feedback\n",
"cosine_similarity_feedback_alignment_solution = cosine_sim_feedback\n",
"\n",
"# Train a network with Backpropagation for comparison\n",
"\n",
"# set the random seed to the current time\n",
Expand All @@ -1663,10 +1739,28 @@
" learning_rate=learnrate, batch_size=batchsize, algorithm='backprop', noise=noise, \\\n",
" report=report, report_rate=rep_rate)\n",
"\n",
"\n",
"# save metrics for plots\n",
"losses_backpropagation_solution = losses_backprop\n",
"accuracy_backpropagation_solution = accuracy_backprop\n",
"cosine_similarity_backpropagation_solution = cosine_sim_backprop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {}
},
"outputs": [],
"source": [
"# @title Observe the performance of FeedbackAlignmentMLP\n",
"\n",
"# plot performance over time\n",
"with plt.xkcd():\n",
" plt.plot(losses_data['feedback_alignment'], label=\"Feedback Alignment\", color='g') #pre-saved history of loss\n",
" plt.plot(losses_data['backpropagation'], label=\"Backprop\", color='r') #pre-saved history of loss\n",
" plt.plot(losses_feedback_alignment_solution, label=\"Feedback Alignment\", color='g')\n",
" plt.plot(losses_backpropagation_solution, label=\"Backprop\", color='r')\n",
" plt.xlabel(\"Updates\")\n",
" plt.ylabel(\"MSE\")\n",
" plt.legend()\n",
Expand Down Expand Up @@ -1932,7 +2026,7 @@
},
"outputs": [],
"source": [
"# @title Train and observe the performance of KolenPollackMLP\n",
"# @title Train KolenPollackMLP\n",
"rng_kp = np.random.default_rng(seed=seed)\n",
"\n",
"losses_kolepoll = np.zeros((numupdates,))\n",
Expand All @@ -1948,11 +2042,28 @@
" learning_rate=learnrate, batch_size=batchsize, algorithm='kolepoll', noise=noise, \\\n",
" report=report, report_rate=rep_rate)\n",
"\n",
"# save metrics for plots\n",
"losses_kolen_pollack_solution = losses_kolepoll\n",
"accuracy_kolen_pollack_solution = accuracy_kolepoll\n",
"cosine_similarity_kolen_pollack_solution = cosine_sim_kolepoll"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {}
},
"outputs": [],
"source": [
"# @title Observe the performance of KolenPollackMLP\n",
"\n",
"# plot performance over time\n",
"with plt.xkcd():\n",
" plt.plot(losses_data['feedback_alignment'], label=\"Feedback Alignment\", color='g') #pre-saved history of loss\n",
" plt.plot(losses_data['backpropagation'], label=\"Backprop\", color='r') #pre-saved history of loss\n",
" plt.plot(losses_data['kolen_pollack'], label=\"Kolen-Pollack\", color='k') #pre-saved history of loss\n",
" plt.plot(losses_feedback_alignment_solution, label=\"Feedback Alignment\", color='g')\n",
" plt.plot(losses_backpropagation_solution, label=\"Backprop\", color='r')\n",
" plt.plot(losses_kolen_pollack_solution, label=\"Kolen-Pollack\", color='k')\n",
" plt.xlabel(\"Updates\")\n",
" plt.ylabel(\"MSE\")\n",
" plt.legend()\n",
Expand Down Expand Up @@ -2013,11 +2124,11 @@
},
"outputs": [],
"source": [
"# @title Plot the gradient similarity to backprop over training with shaded error regions\n",
"# @title Plot the gradient similarity to backpropagation over training with shaded error regions\n",
"with plt.xkcd():\n",
" plt.plot(cosine_similarity_data[\"backpropagation\"], label=\"Backprop\", color='r') #pre-saved cosine similarities\n",
" plt.plot(cosine_similarity_data[\"feedback_alignment\"], label=\"Feedback Alignment\", color='g')\n",
" plt.plot(cosine_similarity_data[\"kolen_pollack\"], label=\"Kolen-Pollack\", color='k')\n",
" plt.plot(cosine_similarity_backpropagation_solution, label=\"Backprop\", color='r')\n",
" plt.plot(cosine_similarity_feedback_alignment_solution, label=\"Feedback Alignment\", color='g')\n",
" plt.plot(cosine_similarity_kolen_pollack_solution, label=\"Kolen-Pollack\", color='k')\n",
" plt.xlabel(\"Epochs\")\n",
" plt.ylabel(\"Cosine Sim\")\n",
" plt.legend()\n",
Expand Down Expand Up @@ -2045,11 +2156,11 @@
"source": [
"# @title Classification accuracy comparison\n",
"with plt.xkcd():\n",
" plt.plot(accuracy_data['weight_perturbation']) #pre-saved accuracies\n",
" plt.plot(accuracy_data['node_perturbation'])\n",
" plt.plot(accuracy_data['feedback_alignment'])\n",
" plt.plot(accuracy_data['kolen_pollack'])\n",
" plt.plot(accuracy_data['backpropagation'])\n",
" plt.plot(accuracy_weight_perturbation_solution)\n",
" plt.plot(accuracy_node_perturbation_solution)\n",
" plt.plot(accuracy_feedback_alignment_solution)\n",
" plt.plot(accuracy_kolen_pollack_solution)\n",
" plt.plot(accuracy_backpropagation_solution)\n",
" plt.legend(['Weight perturbation', 'Node perturbation', 'Feedback alignment', 'Kolen-Pollack', 'Backprop'])\n",
" plt.xlabel('Epochs')\n",
" plt.ylabel('Accuracy (%)')\n",
Expand Down
Loading