Skip to content

Commit

Permalink
update figures
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangir-azerbayev committed Oct 30, 2023
1 parent 72fdb87 commit f0c460e
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 20 deletions.
2 changes: 1 addition & 1 deletion analysis/hf_reanalysis/D_effective-raw.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion analysis/hf_reanalysis/data-constrained-scaling-raw.html

Large diffs are not rendered by default.

78 changes: 66 additions & 12 deletions analysis/hf_reanalysis/reanalysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 3,
"id": "f850551f",
"metadata": {},
"outputs": [],
Expand All @@ -534,7 +534,7 @@
" z_key: str = None,\n",
" z_type: Literal['log', 'linear'] = 'log',\n",
" color_key: str = None,\n",
" color_type: Literal['log', 'log2', 'linear'] = 'linear',\n",
" color_type: Literal['log', 'log2', 'log10', 'linear'] = 'linear',\n",
" fit_fn = None,\n",
" savepath: str = None,\n",
"):\n",
Expand Down Expand Up @@ -587,6 +587,8 @@
" color_variable = np.log(runs[color_key])\n",
" elif color_type==\"log2\":\n",
" color_variable = np.log2(runs[color_key])\n",
" elif color_type==\"log10\":\n",
" color_variable = np.log10(runs[color_key])\n",
" else:\n",
" color_variable = runs[color_key]\n",
" \n",
Expand Down Expand Up @@ -625,7 +627,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 4,
"id": "969c0e98",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -725,7 +727,7 @@
"metadata": {},
"source": [
"## Fit parametric scaling law\n",
"$$\\mathcal{L}(N, D) = E + \\frac{A}{N^\\alpha} + \\frac{B}{N^\\beta}$$"
"$$\\mathcal{L}(N, D) = E + \\frac{A}{N^\\alpha} + \\frac{B}{D^\\beta}$$"
]
},
{
Expand Down Expand Up @@ -845,7 +847,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4.02it/s]"
"100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3.85it/s]"
]
},
{
Expand Down Expand Up @@ -924,7 +926,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 12,
"id": "cf7a0b36",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -982,7 +984,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 14,
"id": "8d99e451-c54a-4134-877e-ad0d5cab1929",
"metadata": {
"tags": []
Expand All @@ -995,15 +997,15 @@
" y_key='D', \n",
" z_key='L', \n",
" color_key='R',\n",
" color_type='log',\n",
" color_type='log10',\n",
" fit_fn=single_epoch_fit,\n",
" savepath='single-epoch-runs-fitted-multiepoch-D.html'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 15,
"id": "64a2b62c-3716-4f37-bedb-8cbaac612bb6",
"metadata": {
"tags": []
Expand All @@ -1029,7 +1031,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 16,
"id": "1e63c213-4286-468c-a0ea-93d336c98293",
"metadata": {
"tags": []
Expand All @@ -1046,8 +1048,60 @@
},
{
"cell_type": "code",
"execution_count": 43,
"id": "831a3693-9b00-4bca-b36e-9df1bf170374",
"execution_count": 17,
"id": "0cdac3f8-231b-4a63-93f3-d2f739b24777",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def is_overfit(tpl):\n",
" \"\"\"\n",
" removes overfit runs that aren't filtered out by early stopping\n",
" \"\"\"\n",
" return tpl['N'] > 1.5e9 and tpl['D'] < 4.01e8 and tpl['R'] > 1\n",
"\n",
"runs_tuples = [{k: v[i] for k, v in runs.items()} for i in range(len(runs['N']))] # dict of lists into list of dicts\n",
"\n",
"runs_tuples.sort(key=lambda x: (x['N'], x['D'], x['R']))\n",
"\n",
"for i in range(1, len(runs_tuples)):\n",
" left = runs_tuples[i-1]\n",
" right = runs_tuples[i]\n",
" \n",
" if left['N']==right['N'] and left['D']==right['D'] and left['L'] < right['L']:\n",
" runs_tuples[i]['L'] = left['L']\n",
" \n",
"runs_tuples = [x for x in runs_tuples if not is_overfit(x)]\n",
" \n",
"runs_early_stopped = {k: [x[k] for x in runs_tuples] for k in runs_tuples[0]}"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "1077585e-8e3a-4cac-a8b0-ac5f63416c77",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"scaling_scatter(\n",
" runs_early_stopped, \n",
" x_key='N', \n",
" y_key='D', \n",
" z_key='L', \n",
" color_key='R',\n",
" color_type='log10',\n",
" fit_fn=single_epoch_fit,\n",
" savepath='single-epoch-runs-fitted-multiepoch-early-stopped-D.html'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a2a488a9-d86e-439c-ac1a-4e12fa96bdd8",
"metadata": {
"tags": []
},
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion analysis/hf_reanalysis/single-epoch-runs-fitted.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion analysis/hf_reanalysis/single-epoch-runs-raw.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion analysis/hf_reanalysis/single-epoch-runs-residuals.html

Large diffs are not rendered by default.

0 comments on commit f0c460e

Please sign in to comment.