Skip to content

Commit

Permalink
Fix/constrain parameters (#60)
Browse files Browse the repository at this point in the history
* constrain initial stability and difficulty

* constrain range of parameters

* update default parameters
  • Loading branch information
L-M-Sherlock authored Oct 6, 2022
1 parent cbdba6e commit 805bba9
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 62 deletions.
114 changes: 57 additions & 57 deletions fsrs4anki_optimizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# FSRS4Anki v3.0.0 Optimizer"
"# FSRS4Anki v3.0.1 Optimizer"
]
},
{
Expand All @@ -13,7 +13,7 @@
"id": "lurCmW0Jqz3s"
},
"source": [
"[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v3.0.0/fsrs4anki_optimizer.ipynb)\n",
"[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v3.0.1/fsrs4anki_optimizer.ipynb)\n",
"\n",
"↑ Click the above button to open the optimizer on Google Colab.\n",
"\n",
Expand Down Expand Up @@ -145,7 +145,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5166/5166 [00:16<00:00, 322.61it/s]\n"
"100%|██████████| 5166/5166 [00:28<00:00, 182.68it/s]\n"
]
},
{
Expand Down Expand Up @@ -264,11 +264,11 @@
"from torch import nn\n",
"from sklearn.utils import shuffle\n",
"\n",
"w = [1, 1, 5, -1, -1, 0.1, 1.5, -0.2, 0.8, 2, -0.2, 0.2, 1]\n",
"init_w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.12, 0.8, 2, -0.2, 0.2, 1]\n",
"\n",
"\n",
"class FSRS(nn.Module):\n",
" def __init__(self):\n",
" def __init__(self, w):\n",
" super(FSRS, self).__init__()\n",
" self.w = nn.Parameter(torch.FloatTensor(w))\n",
" self.zero = torch.FloatTensor([0.0])\n",
Expand Down Expand Up @@ -319,15 +319,15 @@
" w[0] = w[0].clamp(0.1, 10) # initStability\n",
" w[1] = w[1].clamp(0.1, 5) # initStabilityRatingFactor\n",
" w[2] = w[2].clamp(1, 10) # initDifficulty\n",
" w[3] = w[3].clamp(-5, -0.5) # initDifficultyRatingFactor\n",
" w[4] = w[4].clamp(-5, -0.5) # updateDifficultyRatingFactor\n",
" w[3] = w[3].clamp(-5, -0.1) # initDifficultyRatingFactor\n",
" w[4] = w[4].clamp(-5, -0.1) # updateDifficultyRatingFactor\n",
" w[5] = w[5].clamp(0, 0.5) # difficultyMeanReversionFactor\n",
" w[6] = w[6].clamp(0, 5) # recallFactor\n",
" w[7] = w[7].clamp(-0.2, -0.01) # recallStabilityDecay\n",
" w[8] = w[8].clamp(0.01, 2) # recallRetrievabilityFactor\n",
" w[9] = w[9].clamp(0.5, 5) # forgetFactor\n",
" w[10] = w[10].clamp(-2, -0.01) # forgetDifficultyDecay\n",
" w[11] = w[11].clamp(0.01, 1) # forgetStabilityDecay\n",
" w[11] = w[11].clamp(0.01, 0.5) # forgetStabilityDecay\n",
" w[12] = w[12].clamp(0.01, 2) # forgetRetrievabilityFactor\n",
" module.w.data = w\n",
"\n",
Expand Down Expand Up @@ -367,7 +367,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56910/56910 [00:03<00:00, 15627.86it/s]\n"
"100%|██████████| 56910/56910 [00:05<00:00, 10099.40it/s]\n"
]
},
{
Expand All @@ -381,179 +381,179 @@
"name": "stderr",
"output_type": "stream",
"text": [
"pre—train: 100%|\u001b[31m██████████\u001b[0m| 5166/5166 [00:01<00:00, 4229.38it/s]\n"
"pre—train: 100%|\u001b[31m██████████\u001b[0m| 5166/5166 [00:02<00:00, 2537.41it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"w: [1.1471, 1.4016, 5.0, -1.0, -1.0, 0.1, 1.5, -0.2, 0.8, 2.0, -0.2, 0.2, 1.0]\n"
"w: [1.1471, 1.4016, 5.0, -0.5, -0.5, 0.2, 1.4, -0.12, 0.8, 2.0, -0.2, 0.2, 1.0]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 0%|\u001b[31m \u001b[0m| 35/51744 [00:00<02:33, 337.74it/s]"
"train: 0%|\u001b[31m \u001b[0m| 23/51744 [00:00<03:49, 224.92it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 1\n",
"w: [1.1471, 1.4016, 4.9984, -1.0, -0.9984, 0.1016, 1.5016, -0.1984, 0.8016, 2.0, -0.2, 0.2, 1.0]\n"
"w: [1.1471, 1.4016, 4.9984, -0.5, -0.4984, 0.2016, 1.4016, -0.1184, 0.8016, 2.0, -0.2, 0.2, 1.0]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 10%|\u001b[31m█ \u001b[0m| 5251/51744 [00:11<01:35, 486.52it/s]"
"train: 10%|\u001b[31m█ \u001b[0m| 5220/51744 [00:20<02:36, 297.50it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 5175\n",
"w: [1.1475, 1.401, 5.0617, -1.1341, -1.0653, 0.0161, 1.4589, -0.1759, 0.755, 2.0445, -0.1573, 0.2593, 1.056]\n"
"w: [1.1475, 1.401, 5.1333, -0.7564, -0.672, 0.0173, 1.3335, -0.1686, 0.7295, 2.0383, -0.1622, 0.2498, 1.0586]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 20%|\u001b[31m██ \u001b[0m| 10423/51744 [00:23<01:36, 430.17it/s]"
"train: 20%|\u001b[31m██ \u001b[0m| 10385/51744 [00:39<02:34, 268.22it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 10349\n",
"w: [1.1475, 1.401, 5.0993, -1.1936, -1.1382, 0.0367, 1.4792, -0.1577, 0.7771, 2.0359, -0.1656, 0.2812, 1.056]\n"
"w: [1.1475, 1.401, 5.1568, -0.836, -0.7068, 0.049, 1.3445, -0.1561, 0.7417, 2.0282, -0.1683, 0.2665, 1.0575]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 30%|\u001b[31m███ \u001b[0m| 15591/51744 [00:36<01:50, 326.22it/s]"
"train: 30%|\u001b[31m███ \u001b[0m| 15556/51744 [00:57<01:51, 323.60it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 15523\n",
"w: [1.1475, 1.401, 5.1569, -1.2927, -1.2003, 0.0128, 1.4649, -0.1561, 0.7619, 2.0086, -0.1942, 0.2576, 1.0371]\n"
"w: [1.1475, 1.401, 5.194, -0.9428, -0.7409, 0.0241, 1.3317, -0.1497, 0.7283, 2.005, -0.1913, 0.2483, 1.0384]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 40%|\u001b[31m████ \u001b[0m| 20733/51744 [00:48<01:43, 298.54it/s]"
"train: 40%|\u001b[31m████ \u001b[0m| 20740/51744 [01:17<01:51, 277.96it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 20697\n",
"w: [1.1475, 1.401, 5.1201, -1.2971, -1.154, 0.0268, 1.4802, -0.1373, 0.7772, 2.0576, -0.1472, 0.349, 1.1001]\n"
"w: [1.1475, 1.401, 5.2174, -1.0217, -0.7856, 0.0172, 1.3633, -0.1154, 0.7596, 2.0549, -0.1446, 0.3413, 1.0974]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 50%|\u001b[31m█████ \u001b[0m| 25900/51744 [01:01<01:19, 323.16it/s]"
"train: 50%|\u001b[31m█████ \u001b[0m| 25903/51744 [01:35<02:16, 189.91it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 25871\n",
"w: [1.1475, 1.401, 5.1519, -1.3537, -1.1958, 0.0282, 1.4693, -0.1262, 0.7643, 2.0058, -0.1967, 0.3025, 1.0299]\n"
"w: [1.1475, 1.401, 5.2727, -1.0847, -0.8636, 0.0202, 1.3582, -0.1003, 0.7526, 2.0054, -0.1897, 0.295, 1.036]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 60%|\u001b[31m██████ \u001b[0m| 31126/51744 [01:14<00:48, 420.89it/s]"
"train: 60%|\u001b[31m██████ \u001b[0m| 31100/51744 [01:53<01:07, 305.09it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 31045\n",
"w: [1.1475, 1.401, 5.1399, -1.3126, -1.2124, 0.0358, 1.4722, -0.1264, 0.766, 2.0067, -0.1936, 0.3119, 1.0347]\n"
"w: [1.1475, 1.401, 5.3162, -1.1032, -0.9451, 0.0233, 1.3689, -0.095, 0.7622, 2.0071, -0.1852, 0.3045, 1.0371]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 70%|\u001b[31m███████ \u001b[0m| 36256/51744 [01:26<00:41, 373.30it/s]"
"train: 70%|\u001b[31m███████ \u001b[0m| 36252/51744 [02:11<00:54, 283.86it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 36219\n",
"w: [1.1475, 1.401, 5.1487, -1.3606, -1.1984, 0.0471, 1.4642, -0.1297, 0.7553, 2.0158, -0.1834, 0.3526, 1.0356]\n"
"w: [1.1475, 1.401, 5.3463, -1.1919, -0.9538, 0.0337, 1.3707, -0.0909, 0.7611, 2.0197, -0.1707, 0.3475, 1.0439]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 80%|\u001b[31m████████ \u001b[0m| 41442/51744 [01:38<00:24, 419.56it/s]"
"train: 80%|\u001b[31m████████ \u001b[0m| 41450/51744 [02:30<00:35, 286.34it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 41393\n",
"w: [1.1475, 1.401, 5.1195, -1.3066, -1.2153, 0.0485, 1.4812, -0.1294, 0.7698, 2.0114, -0.1864, 0.339, 1.0391]\n"
"w: [1.1475, 1.401, 5.3314, -1.1615, -0.9797, 0.0396, 1.3902, -0.092, 0.778, 2.0144, -0.1743, 0.333, 1.0483]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 90%|\u001b[31m█████████ \u001b[0m| 46619/51744 [01:51<00:11, 450.40it/s]"
"train: 90%|\u001b[31m█████████ \u001b[0m| 46591/51744 [02:52<00:24, 207.36it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 46567\n",
"w: [1.1475, 1.401, 5.1217, -1.3611, -1.2017, 0.0174, 1.4664, -0.1351, 0.7537, 1.9796, -0.218, 0.338, 0.9664]\n"
"w: [1.1475, 1.401, 5.3568, -1.2385, -0.9972, 0.01, 1.3793, -0.0969, 0.7657, 1.9809, -0.2075, 0.3306, 0.9755]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 100%|\u001b[31m██████████\u001b[0m| 51744/51744 [02:02<00:00, 422.20it/s]"
"train: 100%|\u001b[31m██████████\u001b[0m| 51744/51744 [03:13<00:00, 267.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 51741\n",
"w: [1.1475, 1.401, 5.1481, -1.4223, -1.228, 0.0348, 1.4669, -0.1285, 0.754, 1.9673, -0.2305, 0.3203, 0.9453]\n",
"w: [1.1475, 1.401, 5.3721, -1.2625, -1.0177, 0.0318, 1.3768, -0.0973, 0.7628, 1.9672, -0.221, 0.3114, 0.9563]\n",
"\n",
"Training finished!\n"
]
Expand All @@ -567,7 +567,7 @@
}
],
"source": [
"model = FSRS()\n",
"model = FSRS(init_w)\n",
"clipper = WeightClipper()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)\n",
"\n",
Expand Down Expand Up @@ -667,7 +667,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"var w = [1.1475, 1.401, 5.1483, -1.4221, -1.2282, 0.035, 1.4668, -0.1286, 0.7539, 1.9671, -0.2307, 0.32, 0.9451];\n"
"var w = [1.1475, 1.401, 5.3723, -1.2624, -1.0178, 0.032, 1.3767, -0.0975, 0.7627, 1.9671, -0.2212, 0.3111, 0.9561];\n"
]
}
],
Expand Down Expand Up @@ -703,23 +703,23 @@
"\n",
"first rating: 1\n",
"rating history: 1,3,3,3,3,3,3,3,3,3,3\n",
"interval history: 0,1,2,4,8,15,27,49,85,146,245\n",
"difficulty history: 0,8.0,7.9,7.8,7.7,7.6,7.5,7.4,7.4,7.3,7.2\n",
"interval history: 0,1,2,4,8,15,27,49,87,152,262\n",
"difficulty history: 0,7.9,7.8,7.7,7.7,7.6,7.5,7.4,7.4,7.3,7.3\n",
"\n",
"first rating: 2\n",
"rating history: 2,3,3,3,3,3,3,3,3,3,3\n",
"interval history: 0,3,7,15,31,63,121,225,406,710,1210\n",
"difficulty history: 0,6.6,6.5,6.5,6.4,6.4,6.3,6.3,6.3,6.2,6.2\n",
"interval history: 0,3,6,13,28,56,110,211,392,713,1266\n",
"difficulty history: 0,6.6,6.6,6.6,6.5,6.5,6.4,6.4,6.4,6.3,6.3\n",
"\n",
"first rating: 3\n",
"rating history: 3,3,3,3,3,3,3,3,3,3,3\n",
"interval history: 0,4,11,27,62,134,276,542,1021,1853,3252\n",
"difficulty history: 0,5.1,5.1,5.1,5.1,5.1,5.1,5.1,5.1,5.1,5.1\n",
"interval history: 0,4,10,24,55,121,255,518,1015,1929,3560\n",
"difficulty history: 0,5.4,5.4,5.4,5.4,5.4,5.4,5.4,5.4,5.4,5.4\n",
"\n",
"first rating: 4\n",
"rating history: 4,3,3,3,3,3,3,3,3,3,3\n",
"interval history: 0,5,15,41,103,239,522,1076,2112,3966,7160\n",
"difficulty history: 0,3.7,3.8,3.8,3.9,3.9,4.0,4.0,4.0,4.1,4.1\n",
"interval history: 0,5,14,38,94,222,498,1067,2193,4341,8300\n",
"difficulty history: 0,4.1,4.2,4.2,4.2,4.3,4.3,4.3,4.4,4.4,4.4\n",
"\n"
]
}
Expand Down Expand Up @@ -780,21 +780,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(tensor(3.9495), tensor(5.1483))\n",
"(tensor(10.6076), tensor(5.1483))\n",
"(tensor(26.7267), tensor(5.1483))\n",
"(tensor(61.8732), tensor(5.1483))\n",
"(tensor(134.3306), tensor(5.1483))\n",
"(tensor(6.5113), tensor(7.5187))\n",
"(tensor(2.3411), tensor(9.8062))\n",
"(tensor(3.1687), tensor(9.6432))\n",
"(tensor(4.4993), tensor(9.4858))\n",
"(tensor(6.3669), tensor(9.3340))\n",
"(tensor(9.2792), tensor(9.1875))\n",
"(tensor(13.7638), tensor(9.0462))\n",
"(tensor(3.9494), tensor(5.3723))\n",
"(tensor(10.1311), tensor(5.3723))\n",
"(tensor(24.2353), tensor(5.3723))\n",
"(tensor(55.3255), tensor(5.3723))\n",
"(tensor(121.0634), tensor(5.3723))\n",
"(tensor(6.1930), tensor(7.3427))\n",
"(tensor(2.3271), tensor(9.2501))\n",
"(tensor(3.4134), tensor(9.1259))\n",
"(tensor(5.0832), tensor(9.0057))\n",
"(tensor(7.9122), tensor(8.8893))\n",
"(tensor(12.4769), tensor(8.7767))\n",
"(tensor(19.3531), tensor(8.6676))\n",
"rating history: 3,3,3,3,3,1,1,3,3,3,3,3\n",
"interval history: 0,4,11,27,62,134,7,2,3,4,6,9,14\n",
"difficulty history: 0,5.1,5.1,5.1,5.1,5.1,7.5,9.8,9.6,9.5,9.3,9.2,9.0\n"
"interval history: 0,4,10,24,55,121,6,2,3,5,8,12,19\n",
"difficulty history: 0,5.4,5.4,5.4,5.4,5.4,7.3,9.3,9.1,9.0,8.9,8.8,8.7\n"
]
}
],
Expand Down
10 changes: 5 additions & 5 deletions fsrs4anki_scheduler.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// FSRS4Anki v3.0.0 Scheduler
// FSRS4Anki v3.0.1 Scheduler
set_version();
// The latest version will be released on https://github.com/open-spaced-repetition/fsrs4anki

// Default parameters of FSRS4Anki for global
var w = [1, 1, 5, -1, -1, 0.1, 1.5, -0.2, 0.8, 2, -0.2, 0.2, 1];
var w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.12, 0.8, 2, -0.2, 0.2, 1];
// The above parameters can be optimized via FSRS4Anki optimizer.

// User's custom parameters for global
Expand Down Expand Up @@ -147,11 +147,11 @@ function init_states() {
}

function init_difficulty(rating) {
return +(w[2] + w[3] * (ratings[rating] - 3)).toFixed(2);
return +constrain_difficulty(w[2] + w[3] * (ratings[rating] - 3)).toFixed(2);
}

function init_stability(rating) {
return +(w[0] + w[1] * (ratings[rating] - 1)).toFixed(2);
return +Math.max(w[0] + w[1] * (ratings[rating] - 1), 0.1).toFixed(2);
}

function convert_states() {
Expand Down Expand Up @@ -226,7 +226,7 @@ function is_empty() {
}

function set_version() {
const version = "3.0.0";
const version = "3.0.1";
customData.again.v = version;
customData.hard.v = version;
customData.good.v = version;
Expand Down

0 comments on commit 805bba9

Please sign in to comment.