diff --git a/data/data b/data/data
deleted file mode 120000
index 945c9b4..0000000
--- a/data/data
+++ /dev/null
@@ -1 +0,0 @@
-.
\ No newline at end of file
diff --git a/model_files/classes_sherlock.npy b/model_files/classes_sherlock.npy
new file mode 100644
index 0000000..8483b8c
Binary files /dev/null and b/model_files/classes_sherlock.npy differ
diff --git a/model_files/sherlock_weights.h5 b/model_files/sherlock_weights.h5
index 2f9fba3..709d54f 100644
Binary files a/model_files/sherlock_weights.h5 and b/model_files/sherlock_weights.h5 differ
diff --git a/notebooks/00-use-sherlock-out-of-the-box.ipynb b/notebooks/00-use-sherlock-out-of-the-box.ipynb
index 178f6d7..5d137cc 100644
--- a/notebooks/00-use-sherlock-out-of-the-box.ipynb
+++ b/notebooks/00-use-sherlock-out-of-the-box.ipynb
@@ -8,7 +8,8 @@
"# Using Sherlock out-of-the-box\n",
"This notebook shows how to predict a semantic type for a given table column.\n",
"The steps are basically:\n",
- "- Extract features from a column.\n",
+ "- Download the files needed for word embedding and paragraph vector feature extraction (downloaded only once) and initialize the feature extraction models.\n",
+ "- Extract features from table columns.\n",
"- Initialize Sherlock.\n",
"- Make a prediction for the feature representation of the column."
]
@@ -44,11 +45,14 @@
"metadata": {},
"outputs": [
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "UsageError: Environment does not have key: PYTHONHASHSEED\n"
- ]
+ "data": {
+ "text/plain": [
+ "'13'"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
@@ -57,20 +61,10 @@
},
{
"cell_type": "markdown",
- "id": "2b3b7967",
+ "id": "f1101303",
"metadata": {},
"source": [
- "## Extract features"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "164f74ff",
- "metadata": {},
- "outputs": [],
- "source": [
- "# helpers.download_data()"
+ "## Initialize feature extraction models"
]
},
{
@@ -93,9 +87,9 @@
" \n",
"All files for extracting word and paragraph embeddings are present.\n",
"Initialising word embeddings\n",
- "Initialise Word Embeddings process took 0:00:05.607905 seconds.\n",
- "Initialise Doc2Vec Model, 400 dim, process took 0:00:02.443327 seconds. (filename = ../sherlock/features/par_vec_trained_400.pkl)\n",
- "Initialised NLTK, process took 0:00:00.181374 seconds.\n"
+ "Initialise Word Embeddings process took 0:00:05.513540 seconds.\n",
+ "Initialise Doc2Vec Model, 400 dim, process took 0:00:04.191875 seconds. (filename = ../sherlock/features/par_vec_trained_400.pkl)\n",
+ "Initialised NLTK, process took 0:00:00.209930 seconds.\n"
]
},
{
@@ -117,9 +111,17 @@
"initialise_nltk()"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "2b3b7967",
+ "metadata": {},
+ "source": [
+ "## Extract features"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 35,
+ "execution_count": 4,
"id": "db04ccf9",
"metadata": {},
"outputs": [],
@@ -128,6 +130,7 @@
" [\n",
" [\"Jane Smith\", \"Lute Ahorn\", \"Anna James\"],\n",
" [\"Amsterdam\", \"Haarlem\", \"Zwolle\"],\n",
+ " [\"Chabot Street 19\", \"1200 fifth Avenue\", \"Binnenkant 22, 1011BH\"]\n",
" ],\n",
" name=\"values\"\n",
")"
@@ -135,19 +138,20 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 5,
"id": "4875f6c7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "0 [Jane Smith, Lute Ahorn, Anna James]\n",
- "1 [Amsterdam, Haarlem, Zwolle]\n",
+ "0 [Jane Smith, Lute Ahorn, Anna James]\n",
+ "1 [Amsterdam, Haarlem, Zwolle]\n",
+ "2 [Chabot Street 19, 1200 fifth Avenue, Binnenka...\n",
"Name: values, dtype: object"
]
},
- "execution_count": 36,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -158,7 +162,7 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 8,
"id": "f7f2c846",
"metadata": {},
"outputs": [
@@ -166,7 +170,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Extracting Features: 100%|██████████| 2/2 [00:00<00:00, 62.37it/s]\n"
+ "Extracting Features: 100%|██████████| 3/3 [00:00<00:00, 167.51it/s]"
]
},
{
@@ -175,6 +179,13 @@
"text": [
"Exporting 1588 column features\n"
]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
}
],
"source": [
@@ -182,12 +193,12 @@
" \"../temporary.csv\",\n",
" data\n",
")\n",
- "feature_vector = pd.read_csv(\"../temporary.csv\", dtype=np.float32)"
+ "feature_vectors = pd.read_csv(\"../temporary.csv\", dtype=np.float32)"
]
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 9,
"id": "0c42ce71",
"metadata": {},
"outputs": [
@@ -241,7 +252,7 @@
"
0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
- " 0.0 | \n",
+ " 0.000000 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
@@ -249,23 +260,23 @@
" -3.0 | \n",
" 0.0 | \n",
" ... | \n",
- " -0.115819 | \n",
- " 0.023961 | \n",
- " -0.130739 | \n",
- " 0.006393 | \n",
- " -0.135118 | \n",
- " -0.071956 | \n",
- " -0.051051 | \n",
- " -0.068307 | \n",
- " 0.087342 | \n",
- " -0.145716 | \n",
+ " -0.116468 | \n",
+ " 0.023982 | \n",
+ " -0.130867 | \n",
+ " 0.006825 | \n",
+ " -0.135098 | \n",
+ " -0.070616 | \n",
+ " -0.052172 | \n",
+ " -0.067250 | \n",
+ " 0.086256 | \n",
+ " -0.144385 | \n",
" \n",
" \n",
" 1 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
- " 0.0 | \n",
+ " 0.000000 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
@@ -273,63 +284,84 @@
" -3.0 | \n",
" 0.0 | \n",
" ... | \n",
- " -0.054351 | \n",
- " 0.023650 | \n",
- " -0.165681 | \n",
- " -0.016137 | \n",
- " -0.059402 | \n",
- " 0.008454 | \n",
- " -0.044624 | \n",
- " 0.025160 | \n",
- " 0.037831 | \n",
- " -0.086235 | \n",
+ " -0.054949 | \n",
+ " 0.024502 | \n",
+ " -0.166001 | \n",
+ " -0.014375 | \n",
+ " -0.058199 | \n",
+ " 0.009978 | \n",
+ " -0.046423 | \n",
+ " 0.025163 | \n",
+ " 0.036946 | \n",
+ " -0.086611 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.666667 | \n",
+ " 0.0 | \n",
+ " 2.0 | \n",
+ " 1.0 | \n",
+ " 3.0 | \n",
+ " -1.5 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " -0.022804 | \n",
+ " 0.001741 | \n",
+ " 0.047479 | \n",
+ " 0.118293 | \n",
+ " -0.093435 | \n",
+ " 0.036759 | \n",
+ " -0.004508 | \n",
+ " -0.087898 | \n",
+ " -0.117796 | \n",
+ " -0.191386 | \n",
"
\n",
" \n",
"\n",
- "2 rows × 1588 columns
\n",
+ "3 rows × 1588 columns
\n",
""
],
"text/plain": [
" n_[0]-agg-any n_[0]-agg-all n_[0]-agg-mean n_[0]-agg-var n_[0]-agg-min \\\n",
- "0 0.0 0.0 0.0 0.0 0.0 \n",
- "1 0.0 0.0 0.0 0.0 0.0 \n",
+ "0 0.0 0.0 0.0 0.000000 0.0 \n",
+ "1 0.0 0.0 0.0 0.000000 0.0 \n",
+ "2 1.0 0.0 1.0 0.666667 0.0 \n",
"\n",
" n_[0]-agg-max n_[0]-agg-median n_[0]-agg-sum n_[0]-agg-kurtosis \\\n",
"0 0.0 0.0 0.0 -3.0 \n",
"1 0.0 0.0 0.0 -3.0 \n",
+ "2 2.0 1.0 3.0 -1.5 \n",
"\n",
" n_[0]-agg-skewness ... par_vec_390 par_vec_391 par_vec_392 \\\n",
- "0 0.0 ... -0.115819 0.023961 -0.130739 \n",
- "1 0.0 ... -0.054351 0.023650 -0.165681 \n",
+ "0 0.0 ... -0.116468 0.023982 -0.130867 \n",
+ "1 0.0 ... -0.054949 0.024502 -0.166001 \n",
+ "2 0.0 ... -0.022804 0.001741 0.047479 \n",
"\n",
" par_vec_393 par_vec_394 par_vec_395 par_vec_396 par_vec_397 \\\n",
- "0 0.006393 -0.135118 -0.071956 -0.051051 -0.068307 \n",
- "1 -0.016137 -0.059402 0.008454 -0.044624 0.025160 \n",
+ "0 0.006825 -0.135098 -0.070616 -0.052172 -0.067250 \n",
+ "1 -0.014375 -0.058199 0.009978 -0.046423 0.025163 \n",
+ "2 0.118293 -0.093435 0.036759 -0.004508 -0.087898 \n",
"\n",
" par_vec_398 par_vec_399 \n",
- "0 0.087342 -0.145716 \n",
- "1 0.037831 -0.086235 \n",
+ "0 0.086256 -0.144385 \n",
+ "1 0.036946 -0.086611 \n",
+ "2 -0.117796 -0.191386 \n",
"\n",
- "[2 rows x 1588 columns]"
+ "[3 rows x 1588 columns]"
]
},
- "execution_count": 38,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "feature_vector"
+ "feature_vectors"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "52047a6b",
- "metadata": {},
- "outputs": [],
- "source": []
- },
{
"cell_type": "code",
"execution_count": null,
@@ -343,18 +375,18 @@
"id": "9027fa4a",
"metadata": {},
"source": [
- "## Initialize Sherlock."
+ "## Initialize Sherlock"
]
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": 11,
"id": "b9ec13ec",
"metadata": {},
"outputs": [],
"source": [
"model = SherlockModel();\n",
- "model.initialize_model_from_json(with_weights=True);"
+ "model.initialize_model_from_json(with_weights=True, model_id=\"sherlock\");"
]
},
{
@@ -375,27 +407,27 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 12,
"id": "fc079fa9",
"metadata": {},
"outputs": [],
"source": [
- "predicted_labels = model.predict(feature_vector, \"sherlock\")"
+ "predicted_labels = model.predict(feature_vectors, \"sherlock\")"
]
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 13,
"id": "0feb9584",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "array(['creator', 'city'], dtype=object)"
+ "array(['person', 'city', 'address'], dtype=object)"
]
},
- "execution_count": 41,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
diff --git a/notebooks/01-data-preprocessing.ipynb b/notebooks/01-data-preprocessing.ipynb
index 34be765..33ddbc2 100644
--- a/notebooks/01-data-preprocessing.ipynb
+++ b/notebooks/01-data-preprocessing.ipynb
@@ -1,5 +1,12 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Preprocess data and extract features"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 1,
@@ -12,7 +19,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 28,
"metadata": {},
"outputs": [
{
@@ -21,34 +28,21 @@
"'13'"
]
},
- "execution_count": 2,
+ "execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# If you need fully deterministic results between runs, set the following environment value prior to launching jupyter.\n",
+ "# Instructions can be found in HOW-TO-ENVIRONMENT.md.\n",
"# See comment in sherlock.features.paragraph_vectors.infer_paragraph_embeddings_features for more info.\n",
"%env PYTHONHASHSEED"
]
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Extract features, retrain Sherlock and generate predictions."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The script below first downloads the data (roughly 700K samples), then extract features from the raw data values.
"
- ]
- },
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -74,14 +68,14 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-10 18:08:09.074960.\n"
+ "Started at 2022-02-21 12:55:47.936774.\n"
]
}
],
@@ -101,19 +95,22 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Downloading the raw and preprocessed data into ../data/data.zip.\n",
+ "Downloading the raw data into ../data/.\n",
"Data was downloaded.\n",
- "Preparing feature extraction by downloading 2 files:\n",
+ "Preparing feature extraction by downloading 4 files:\n",
" \n",
- " ../sherlock/features/glove.6B.50d.txt and \n",
- " ../sherlock/features/par_vec_trained_400.pkl.docvecs.vectors_docs.npy.\n",
+ " ../sherlock/features/glove.6B.50d.txt, \n",
+ " ../sherlock/features/par_vec_trained_400.pkl.docvecs.vectors_docs.npy,\n",
+ " \n",
+ " ../sherlock/features/par_vec_trained_400.pkl.trainables.syn1neg.npy, and \n",
+ " ../sherlock/features/par_vec_trained_400.pkl.wv.vectors.npy.\n",
" \n",
"All files for extracting word and paragraph embeddings are present.\n"
]
@@ -126,7 +123,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -174,7 +171,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -195,23 +192,26 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Preparing feature extraction by downloading 2 files:\n",
+ "Preparing feature extraction by downloading 4 files:\n",
+ " \n",
+ " ../sherlock/features/glove.6B.50d.txt, \n",
+ " ../sherlock/features/par_vec_trained_400.pkl.docvecs.vectors_docs.npy,\n",
" \n",
- " ../sherlock/features/glove.6B.50d.txt and \n",
- " ../sherlock/features/par_vec_trained_400.pkl.docvecs.vectors_docs.npy.\n",
+ " ../sherlock/features/par_vec_trained_400.pkl.trainables.syn1neg.npy, and \n",
+ " ../sherlock/features/par_vec_trained_400.pkl.wv.vectors.npy.\n",
" \n",
"All files for extracting word and paragraph embeddings are present.\n",
"Initialising word embeddings\n",
- "Initialise Word Embeddings process took 0:00:05.470982 seconds.\n",
- "Initialise Doc2Vec Model, 400 dim, process took 0:00:02.812897 seconds. (filename = ../sherlock/features/par_vec_trained_400.pkl)\n",
- "Initialised NLTK, process took 0:00:00.243027 seconds.\n"
+ "Initialise Word Embeddings process took 0:00:05.196631 seconds.\n",
+ "Initialise Doc2Vec Model, 400 dim, process took 0:00:03.018509 seconds. (filename = ../sherlock/features/par_vec_trained_400.pkl)\n",
+ "Initialised NLTK, process took 0:00:00.184036 seconds.\n"
]
},
{
@@ -236,7 +236,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -245,7 +245,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -270,16 +270,16 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Starting ../data/data/processed/test_20220210-180929.csv at 2022-02-10 18:09:45.839366. Rows=137353, using 8 CPU cores\n",
+ "Starting ../data/data/processed/test_20220221-125552.csv at 2022-02-21 12:56:03.739007. Rows=137353, using 8 CPU cores\n",
"Exporting 1588 column features\n",
- "Finished. Processed 137353 rows in 0:04:43.132315, key_count=8\n"
+ "Finished. Processed 137353 rows in 0:04:41.140014, key_count=8\n"
]
}
],
@@ -293,14 +293,14 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Finished at 2022-02-10 18:14:29.113958\n"
+ "Finished at 2022-02-21 13:00:45.006991\n"
]
}
],
@@ -317,16 +317,16 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Starting ../data/data/processed/train_20220210-180929.csv at 2022-02-10 18:14:32.033627. Rows=412059, using 8 CPU cores\n",
+ "Starting ../data/data/processed/train_20220221-125552.csv at 2022-02-21 13:00:46.766509. Rows=412059, using 8 CPU cores\n",
"Exporting 1588 column features\n",
- "Finished. Processed 412059 rows in 0:13:59.233387, key_count=8\n"
+ "Finished. Processed 412059 rows in 0:13:46.382904, key_count=8\n"
]
}
],
@@ -340,14 +340,14 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Finished at 2022-02-10 18:28:31.533353\n"
+ "Finished at 2022-02-21 13:14:33.436746\n"
]
}
],
@@ -364,16 +364,16 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Starting ../data/data/processed/validation_20220210-180929.csv at 2022-02-10 18:28:32.398357. Rows=137353, using 8 CPU cores\n",
+ "Starting ../data/data/processed/validation_20220221-125552.csv at 2022-02-21 13:14:34.282570. Rows=137353, using 8 CPU cores\n",
"Exporting 1588 column features\n",
- "Finished. Processed 137353 rows in 0:04:29.888694, key_count=8\n"
+ "Finished. Processed 137353 rows in 0:04:20.682436, key_count=8\n"
]
}
],
@@ -387,14 +387,14 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Finished at 2022-02-10 18:33:02.388902\n"
+ "Finished at 2022-02-21 13:18:55.069282\n"
]
}
],
@@ -411,14 +411,14 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Load Features (test) process took 0:00:31.909566 seconds.\n"
+ "Load Features (test) process took 0:00:30.504465 seconds.\n"
]
}
],
@@ -432,7 +432,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -644,7 +644,7 @@
"[5 rows x 1588 columns]"
]
},
- "execution_count": 25,
+ "execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -655,14 +655,14 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Load Features (train) process took 0:01:40.003595 seconds.\n"
+ "Load Features (train) process took 0:01:39.241069 seconds.\n"
]
}
],
@@ -676,7 +676,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 21,
"metadata": {},
"outputs": [
{
@@ -888,7 +888,7 @@
"[5 rows x 1588 columns]"
]
},
- "execution_count": 27,
+ "execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -899,14 +899,14 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Load Features (validation) process took 0:00:30.514030 seconds.\n"
+ "Load Features (validation) process took 0:00:30.066667 seconds.\n"
]
}
],
@@ -920,7 +920,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -1132,7 +1132,7 @@
"[5 rows x 1588 columns]"
]
},
- "execution_count": 29,
+ "execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
@@ -1150,14 +1150,14 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Transpose process took 0:00:11.975168 seconds.\n"
+ "Transpose process took 0:00:12.315641 seconds.\n"
]
}
],
@@ -1171,14 +1171,14 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "FillNA process took 0:00:04.992048 seconds.\n"
+ "FillNA process took 0:00:04.335100 seconds.\n"
]
}
],
@@ -1196,14 +1196,14 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Save parquet process took 0:01:03.475468 seconds.\n"
+ "Save parquet process took 0:00:58.970663 seconds.\n"
]
}
],
@@ -1219,14 +1219,14 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Completed at 2022-02-10 18:37:05.949284.\n"
+ "Completed at 2022-02-21 13:22:51.962526.\n"
]
}
],
diff --git a/notebooks/02-1-train-and-test-sherlock.ipynb b/notebooks/02-1-train-and-test-sherlock.ipynb
index add40fd..f0001f0 100644
--- a/notebooks/02-1-train-and-test-sherlock.ipynb
+++ b/notebooks/02-1-train-and-test-sherlock.ipynb
@@ -29,12 +29,14 @@
"metadata": {},
"outputs": [],
"source": [
+ "# This will be the ID for the retrained model.\n",
+ "# Further down, predictions can also be made with the original model: \"sherlock\".\n",
"model_id = 'retrained_sherlock'"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@@ -66,8 +68,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-10 18:47:37.581529\n",
- "Load data (train) process took 0:00:06.286789 seconds.\n"
+ "Started at 2022-02-21 13:49:29.901066\n",
+ "Load data (train) process took 0:00:07.129129 seconds.\n"
]
}
],
@@ -111,8 +113,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-10 18:47:44.189431\n",
- "Load data (validation) process took 0:00:01.713314 seconds.\n"
+ "Started at 2022-02-21 13:49:37.357866\n",
+ "Load data (validation) process took 0:00:01.622250 seconds.\n"
]
}
],
@@ -137,8 +139,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-10 18:47:45.963358\n",
- "Finished at 2022-02-10 18:47:48.373023, took 0:00:02.409678 seconds\n"
+ "Started at 2022-02-21 13:49:39.041588\n",
+ "Finished at 2022-02-21 13:49:41.373401, took 0:00:02.331826 seconds\n"
]
}
],
@@ -168,12 +170,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Load Sherlock with pretrained weights"
+ "### Option 1: load Sherlock with pretrained weights"
]
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 14,
"metadata": {
"scrolled": true
},
@@ -182,9 +184,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-10 18:55:58.217413\n",
+ "Started at 2022-02-21 14:05:25.056032\n",
"Initialized model.\n",
- "Finished at 2022-02-10 18:55:59.172945, took 0:00:00.955544 seconds\n"
+ "Finished at 2022-02-21 14:05:25.937041, took 0:00:00.881019 seconds\n"
]
}
],
@@ -192,8 +194,8 @@
"start = datetime.now()\n",
"print(f'Started at {start}')\n",
"\n",
- "model = SherlockModel()\n",
- "model.initialize_model_from_json(with_weights=True)\n",
+ "model = SherlockModel();\n",
+ "model.initialize_model_from_json(with_weights=True, model_id=\"sherlock\");\n",
"\n",
"print('Initialized model.')\n",
"print(f'Finished at {datetime.now()}, took {datetime.now() - start} seconds')"
@@ -203,12 +205,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Fit Sherlock from scratch (and save for later use)"
+ "### Option 2: fit Sherlock from scratch (and save for later use)"
]
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_id = \"retrained_sherlock\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
"metadata": {
"scrolled": true
},
@@ -217,94 +228,128 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-10 20:48:03.023090\n",
+ "Started at 2022-02-21 11:26:21.704064\n",
"Train on 412059 samples, validate on 137353 samples\n",
"Epoch 1/100\n",
- "412059/412059 [==============================] - 73s 177us/sample - loss: 1.6144 - categorical_accuracy: 0.6971 - val_loss: 1.0390 - val_categorical_accuracy: 0.8300\n",
+ "412059/412059 [==============================] - 70s 170us/sample - loss: 1.6072 - categorical_accuracy: 0.6988 - val_loss: 1.0404 - val_categorical_accuracy: 0.8266\n",
"Epoch 2/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.9628 - categorical_accuracy: 0.8359 - val_loss: 0.9472 - val_categorical_accuracy: 0.8491\n",
+ "412059/412059 [==============================] - 67s 163us/sample - loss: 0.9624 - categorical_accuracy: 0.8364 - val_loss: 0.9520 - val_categorical_accuracy: 0.8498\n",
"Epoch 3/100\n",
- "412059/412059 [==============================] - 70s 171us/sample - loss: 0.8533 - categorical_accuracy: 0.8586 - val_loss: 0.8953 - val_categorical_accuracy: 0.8580\n",
+ "412059/412059 [==============================] - 65s 159us/sample - loss: 0.8501 - categorical_accuracy: 0.8590 - val_loss: 0.8929 - val_categorical_accuracy: 0.8579\n",
"Epoch 4/100\n",
- "412059/412059 [==============================] - 68s 166us/sample - loss: 0.7860 - categorical_accuracy: 0.8715 - val_loss: 0.8617 - val_categorical_accuracy: 0.8637\n",
+ "412059/412059 [==============================] - 66s 159us/sample - loss: 0.7858 - categorical_accuracy: 0.8718 - val_loss: 0.8561 - val_categorical_accuracy: 0.8645\n",
"Epoch 5/100\n",
- "412059/412059 [==============================] - 69s 168us/sample - loss: 0.7388 - categorical_accuracy: 0.8789 - val_loss: 0.8268 - val_categorical_accuracy: 0.8682\n",
+ "412059/412059 [==============================] - 66s 159us/sample - loss: 0.7399 - categorical_accuracy: 0.8788 - val_loss: 0.8267 - val_categorical_accuracy: 0.8682\n",
"Epoch 6/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.7010 - categorical_accuracy: 0.8852 - val_loss: 0.8113 - val_categorical_accuracy: 0.8699\n",
+ "412059/412059 [==============================] - 66s 160us/sample - loss: 0.6976 - categorical_accuracy: 0.8855 - val_loss: 0.8073 - val_categorical_accuracy: 0.8703\n",
"Epoch 7/100\n",
- "412059/412059 [==============================] - 71s 172us/sample - loss: 0.6674 - categorical_accuracy: 0.8896 - val_loss: 0.7839 - val_categorical_accuracy: 0.8720\n",
+ "412059/412059 [==============================] - 68s 165us/sample - loss: 0.6674 - categorical_accuracy: 0.8892 - val_loss: 0.7841 - val_categorical_accuracy: 0.8726\n",
"Epoch 8/100\n",
- "412059/412059 [==============================] - 70s 171us/sample - loss: 0.6406 - categorical_accuracy: 0.8939 - val_loss: 0.7652 - val_categorical_accuracy: 0.8754\n",
+ "412059/412059 [==============================] - 71s 173us/sample - loss: 0.6381 - categorical_accuracy: 0.8937 - val_loss: 0.7663 - val_categorical_accuracy: 0.8753\n",
"Epoch 9/100\n",
- "412059/412059 [==============================] - 70s 171us/sample - loss: 0.6147 - categorical_accuracy: 0.8970 - val_loss: 0.7455 - val_categorical_accuracy: 0.8763\n",
+ "412059/412059 [==============================] - 77s 187us/sample - loss: 0.6152 - categorical_accuracy: 0.8965 - val_loss: 0.7514 - val_categorical_accuracy: 0.8760\n",
"Epoch 10/100\n",
- "412059/412059 [==============================] - 69s 168us/sample - loss: 0.5918 - categorical_accuracy: 0.9005 - val_loss: 0.7366 - val_categorical_accuracy: 0.8781\n",
+ "412059/412059 [==============================] - 80s 195us/sample - loss: 0.5921 - categorical_accuracy: 0.9001 - val_loss: 0.7471 - val_categorical_accuracy: 0.8779\n",
"Epoch 11/100\n",
- "412059/412059 [==============================] - 71s 173us/sample - loss: 0.5734 - categorical_accuracy: 0.9022 - val_loss: 0.7283 - val_categorical_accuracy: 0.8807\n",
+ "412059/412059 [==============================] - 78s 190us/sample - loss: 0.5713 - categorical_accuracy: 0.9031 - val_loss: 0.7266 - val_categorical_accuracy: 0.8799\n",
"Epoch 12/100\n",
- "412059/412059 [==============================] - 70s 169us/sample - loss: 0.5543 - categorical_accuracy: 0.9051 - val_loss: 0.7156 - val_categorical_accuracy: 0.8799\n",
+ "412059/412059 [==============================] - 71s 171us/sample - loss: 0.5537 - categorical_accuracy: 0.9051 - val_loss: 0.7146 - val_categorical_accuracy: 0.8812\n",
"Epoch 13/100\n",
- "412059/412059 [==============================] - 70s 169us/sample - loss: 0.5373 - categorical_accuracy: 0.9069 - val_loss: 0.7066 - val_categorical_accuracy: 0.8812\n",
+ "412059/412059 [==============================] - 76s 184us/sample - loss: 0.5372 - categorical_accuracy: 0.9073 - val_loss: 0.7075 - val_categorical_accuracy: 0.8824\n",
"Epoch 14/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.5219 - categorical_accuracy: 0.9093 - val_loss: 0.6935 - val_categorical_accuracy: 0.8824\n",
+ "412059/412059 [==============================] - 72s 175us/sample - loss: 0.5213 - categorical_accuracy: 0.9092 - val_loss: 0.7004 - val_categorical_accuracy: 0.8830\n",
"Epoch 15/100\n",
- "412059/412059 [==============================] - 71s 172us/sample - loss: 0.5074 - categorical_accuracy: 0.9109 - val_loss: 0.6861 - val_categorical_accuracy: 0.8837\n",
+ "412059/412059 [==============================] - 71s 173us/sample - loss: 0.5080 - categorical_accuracy: 0.9107 - val_loss: 0.6898 - val_categorical_accuracy: 0.8830\n",
"Epoch 16/100\n",
- "412059/412059 [==============================] - 71s 172us/sample - loss: 0.4940 - categorical_accuracy: 0.9128 - val_loss: 0.6828 - val_categorical_accuracy: 0.8859\n",
+ "412059/412059 [==============================] - 74s 180us/sample - loss: 0.4934 - categorical_accuracy: 0.9131 - val_loss: 0.6923 - val_categorical_accuracy: 0.8844\n",
"Epoch 17/100\n",
- "412059/412059 [==============================] - 70s 169us/sample - loss: 0.4819 - categorical_accuracy: 0.9147 - val_loss: 0.6756 - val_categorical_accuracy: 0.8857\n",
+ "412059/412059 [==============================] - 73s 177us/sample - loss: 0.4804 - categorical_accuracy: 0.9148 - val_loss: 0.6735 - val_categorical_accuracy: 0.8835\n",
"Epoch 18/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.4706 - categorical_accuracy: 0.9159 - val_loss: 0.6664 - val_categorical_accuracy: 0.8861\n",
+ "412059/412059 [==============================] - 76s 183us/sample - loss: 0.4709 - categorical_accuracy: 0.9156 - val_loss: 0.6644 - val_categorical_accuracy: 0.8862\n",
"Epoch 19/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.4593 - categorical_accuracy: 0.9172 - val_loss: 0.6627 - val_categorical_accuracy: 0.8868\n",
+ "412059/412059 [==============================] - 73s 177us/sample - loss: 0.4605 - categorical_accuracy: 0.9172 - val_loss: 0.6749 - val_categorical_accuracy: 0.8868\n",
"Epoch 20/100\n",
- "412059/412059 [==============================] - 71s 172us/sample - loss: 0.4491 - categorical_accuracy: 0.9186 - val_loss: 0.6623 - val_categorical_accuracy: 0.8867\n",
+ "412059/412059 [==============================] - 72s 175us/sample - loss: 0.4495 - categorical_accuracy: 0.9187 - val_loss: 0.6570 - val_categorical_accuracy: 0.8871\n",
"Epoch 21/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.4399 - categorical_accuracy: 0.9198 - val_loss: 0.6573 - val_categorical_accuracy: 0.8883\n",
+ "412059/412059 [==============================] - 73s 176us/sample - loss: 0.4405 - categorical_accuracy: 0.9198 - val_loss: 0.6547 - val_categorical_accuracy: 0.8881\n",
"Epoch 22/100\n",
- "412059/412059 [==============================] - 70s 171us/sample - loss: 0.4331 - categorical_accuracy: 0.9207 - val_loss: 0.6486 - val_categorical_accuracy: 0.8884\n",
+ "412059/412059 [==============================] - 74s 181us/sample - loss: 0.4319 - categorical_accuracy: 0.9208 - val_loss: 0.6499 - val_categorical_accuracy: 0.8885\n",
"Epoch 23/100\n",
- "412059/412059 [==============================] - 70s 169us/sample - loss: 0.4241 - categorical_accuracy: 0.9221 - val_loss: 0.6499 - val_categorical_accuracy: 0.8889\n",
+ "412059/412059 [==============================] - 74s 180us/sample - loss: 0.4245 - categorical_accuracy: 0.9219 - val_loss: 0.6415 - val_categorical_accuracy: 0.8895\n",
"Epoch 24/100\n",
- "412059/412059 [==============================] - 76s 183us/sample - loss: 0.4168 - categorical_accuracy: 0.9230 - val_loss: 0.6432 - val_categorical_accuracy: 0.8906\n",
+ "412059/412059 [==============================] - 71s 172us/sample - loss: 0.4176 - categorical_accuracy: 0.9225 - val_loss: 0.6470 - val_categorical_accuracy: 0.8888\n",
"Epoch 25/100\n",
- "412059/412059 [==============================] - 71s 171us/sample - loss: 0.4098 - categorical_accuracy: 0.9241 - val_loss: 0.6417 - val_categorical_accuracy: 0.8896\n",
+ "412059/412059 [==============================] - 73s 177us/sample - loss: 0.4091 - categorical_accuracy: 0.9243 - val_loss: 0.6343 - val_categorical_accuracy: 0.8888\n",
"Epoch 26/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.4029 - categorical_accuracy: 0.9246 - val_loss: 0.6344 - val_categorical_accuracy: 0.8901\n",
+ "412059/412059 [==============================] - 76s 184us/sample - loss: 0.4009 - categorical_accuracy: 0.9254 - val_loss: 0.6438 - val_categorical_accuracy: 0.8894\n",
"Epoch 27/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.3965 - categorical_accuracy: 0.9261 - val_loss: 0.6409 - val_categorical_accuracy: 0.8911\n",
+ "412059/412059 [==============================] - 73s 177us/sample - loss: 0.3960 - categorical_accuracy: 0.9260 - val_loss: 0.6343 - val_categorical_accuracy: 0.8910\n",
"Epoch 28/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.3909 - categorical_accuracy: 0.9266 - val_loss: 0.6341 - val_categorical_accuracy: 0.8901\n",
+ "412059/412059 [==============================] - 72s 174us/sample - loss: 0.3895 - categorical_accuracy: 0.9272 - val_loss: 0.6549 - val_categorical_accuracy: 0.8897\n",
"Epoch 29/100\n",
- "412059/412059 [==============================] - 71s 173us/sample - loss: 0.3844 - categorical_accuracy: 0.9275 - val_loss: 0.6272 - val_categorical_accuracy: 0.8906\n",
+ "412059/412059 [==============================] - 71s 173us/sample - loss: 0.3846 - categorical_accuracy: 0.9280 - val_loss: 0.6252 - val_categorical_accuracy: 0.8912\n",
"Epoch 30/100\n",
- "412059/412059 [==============================] - 75s 182us/sample - loss: 0.3798 - categorical_accuracy: 0.9283 - val_loss: 0.6249 - val_categorical_accuracy: 0.8915\n",
+ "412059/412059 [==============================] - 74s 180us/sample - loss: 0.3791 - categorical_accuracy: 0.9288 - val_loss: 0.6246 - val_categorical_accuracy: 0.8913\n",
"Epoch 31/100\n",
- "412059/412059 [==============================] - 72s 174us/sample - loss: 0.3744 - categorical_accuracy: 0.9297 - val_loss: 0.6229 - val_categorical_accuracy: 0.8920\n",
+ "412059/412059 [==============================] - 74s 180us/sample - loss: 0.3743 - categorical_accuracy: 0.9292 - val_loss: 0.6318 - val_categorical_accuracy: 0.8917\n",
"Epoch 32/100\n",
- "412059/412059 [==============================] - 70s 170us/sample - loss: 0.3700 - categorical_accuracy: 0.9302 - val_loss: 0.6198 - val_categorical_accuracy: 0.8920\n",
+ "412059/412059 [==============================] - 77s 187us/sample - loss: 0.3697 - categorical_accuracy: 0.9299 - val_loss: 0.6297 - val_categorical_accuracy: 0.8909\n",
"Epoch 33/100\n",
- "412059/412059 [==============================] - 71s 172us/sample - loss: 0.3661 - categorical_accuracy: 0.9307 - val_loss: 0.6238 - val_categorical_accuracy: 0.8919\n",
+ "412059/412059 [==============================] - 79s 192us/sample - loss: 0.3649 - categorical_accuracy: 0.9310 - val_loss: 0.6250 - val_categorical_accuracy: 0.8920\n",
"Epoch 34/100\n",
- "412059/412059 [==============================] - 70s 171us/sample - loss: 0.3614 - categorical_accuracy: 0.9311 - val_loss: 0.6186 - val_categorical_accuracy: 0.8926\n",
+ "412059/412059 [==============================] - 77s 186us/sample - loss: 0.3609 - categorical_accuracy: 0.9312 - val_loss: 0.6217 - val_categorical_accuracy: 0.8921\n",
"Epoch 35/100\n",
- "412059/412059 [==============================] - 68s 166us/sample - loss: 0.3579 - categorical_accuracy: 0.9320 - val_loss: 0.6219 - val_categorical_accuracy: 0.8924\n",
+ "412059/412059 [==============================] - 73s 176us/sample - loss: 0.3570 - categorical_accuracy: 0.9318 - val_loss: 0.6203 - val_categorical_accuracy: 0.8923\n",
"Epoch 36/100\n",
- "412059/412059 [==============================] - 71s 171us/sample - loss: 0.3532 - categorical_accuracy: 0.9326 - val_loss: 0.6191 - val_categorical_accuracy: 0.8922\n",
+ "412059/412059 [==============================] - 77s 187us/sample - loss: 0.3525 - categorical_accuracy: 0.9328 - val_loss: 0.6304 - val_categorical_accuracy: 0.8923\n",
"Epoch 37/100\n",
- "412059/412059 [==============================] - 71s 172us/sample - loss: 0.3517 - categorical_accuracy: 0.9328 - val_loss: 0.6137 - val_categorical_accuracy: 0.8932\n",
+ "412059/412059 [==============================] - 74s 178us/sample - loss: 0.3496 - categorical_accuracy: 0.9332 - val_loss: 0.6173 - val_categorical_accuracy: 0.8923\n",
"Epoch 38/100\n",
- "412059/412059 [==============================] - 69s 168us/sample - loss: 0.3464 - categorical_accuracy: 0.9336 - val_loss: 0.6201 - val_categorical_accuracy: 0.8937\n",
+ "412059/412059 [==============================] - 73s 177us/sample - loss: 0.3464 - categorical_accuracy: 0.9338 - val_loss: 0.6186 - val_categorical_accuracy: 0.8924\n",
"Epoch 39/100\n",
- "412059/412059 [==============================] - 68s 165us/sample - loss: 0.3445 - categorical_accuracy: 0.9340 - val_loss: 0.6211 - val_categorical_accuracy: 0.8929\n",
+ "412059/412059 [==============================] - 73s 176us/sample - loss: 0.3435 - categorical_accuracy: 0.9342 - val_loss: 0.6116 - val_categorical_accuracy: 0.8938\n",
"Epoch 40/100\n",
- "412059/412059 [==============================] - 67s 162us/sample - loss: 0.3394 - categorical_accuracy: 0.9350 - val_loss: 0.6199 - val_categorical_accuracy: 0.8934\n",
+ "412059/412059 [==============================] - 73s 177us/sample - loss: 0.3414 - categorical_accuracy: 0.9345 - val_loss: 0.6105 - val_categorical_accuracy: 0.8930\n",
"Epoch 41/100\n",
- "412059/412059 [==============================] - 68s 165us/sample - loss: 0.3377 - categorical_accuracy: 0.9351 - val_loss: 0.6216 - val_categorical_accuracy: 0.8934\n",
+ "412059/412059 [==============================] - 82s 198us/sample - loss: 0.3386 - categorical_accuracy: 0.9351 - val_loss: 0.6201 - val_categorical_accuracy: 0.8936\n",
"Epoch 42/100\n",
- "412059/412059 [==============================] - 70s 169us/sample - loss: 0.3356 - categorical_accuracy: 0.9352 - val_loss: 0.6155 - val_categorical_accuracy: 0.8939\n",
+ "412059/412059 [==============================] - 77s 187us/sample - loss: 0.3342 - categorical_accuracy: 0.9356 - val_loss: 0.6132 - val_categorical_accuracy: 0.8913\n",
+ "Epoch 43/100\n",
+ "412059/412059 [==============================] - 80s 194us/sample - loss: 0.3318 - categorical_accuracy: 0.9362 - val_loss: 0.6141 - val_categorical_accuracy: 0.8932\n",
+ "Epoch 44/100\n",
+ "412059/412059 [==============================] - 78s 190us/sample - loss: 0.3295 - categorical_accuracy: 0.9368 - val_loss: 0.6101 - val_categorical_accuracy: 0.8940\n",
+ "Epoch 45/100\n",
+ "412059/412059 [==============================] - 78s 188us/sample - loss: 0.3270 - categorical_accuracy: 0.9370 - val_loss: 0.6088 - val_categorical_accuracy: 0.8939\n",
+ "Epoch 46/100\n",
+ "412059/412059 [==============================] - 76s 185us/sample - loss: 0.3252 - categorical_accuracy: 0.9373 - val_loss: 0.6152 - val_categorical_accuracy: 0.8939\n",
+ "Epoch 47/100\n",
+ "412059/412059 [==============================] - 82s 200us/sample - loss: 0.3237 - categorical_accuracy: 0.9377 - val_loss: 0.6105 - val_categorical_accuracy: 0.8939\n",
+ "Epoch 48/100\n",
+ "412059/412059 [==============================] - 78s 188us/sample - loss: 0.3217 - categorical_accuracy: 0.9377 - val_loss: 0.6070 - val_categorical_accuracy: 0.8945\n",
+ "Epoch 49/100\n",
+ "412059/412059 [==============================] - 75s 181us/sample - loss: 0.3187 - categorical_accuracy: 0.9386 - val_loss: 0.6067 - val_categorical_accuracy: 0.8943\n",
+ "Epoch 50/100\n",
+ "412059/412059 [==============================] - 78s 190us/sample - loss: 0.3184 - categorical_accuracy: 0.9382 - val_loss: 0.6094 - val_categorical_accuracy: 0.8947\n",
+ "Epoch 51/100\n",
+ "412059/412059 [==============================] - 76s 185us/sample - loss: 0.3155 - categorical_accuracy: 0.9391 - val_loss: 0.6054 - val_categorical_accuracy: 0.8948\n",
+ "Epoch 52/100\n",
+ "412059/412059 [==============================] - 78s 190us/sample - loss: 0.3143 - categorical_accuracy: 0.9392 - val_loss: 0.6060 - val_categorical_accuracy: 0.8943\n",
+ "Epoch 53/100\n",
+ "412059/412059 [==============================] - 75s 181us/sample - loss: 0.3105 - categorical_accuracy: 0.9400 - val_loss: 0.6130 - val_categorical_accuracy: 0.8954\n",
+ "Epoch 54/100\n",
+ "412059/412059 [==============================] - 78s 188us/sample - loss: 0.3109 - categorical_accuracy: 0.9397 - val_loss: 0.6030 - val_categorical_accuracy: 0.8952\n",
+ "Epoch 55/100\n",
+ "412059/412059 [==============================] - 78s 189us/sample - loss: 0.3091 - categorical_accuracy: 0.9406 - val_loss: 0.6148 - val_categorical_accuracy: 0.8949\n",
+ "Epoch 56/100\n",
+ "412059/412059 [==============================] - 78s 189us/sample - loss: 0.3076 - categorical_accuracy: 0.9404 - val_loss: 0.6075 - val_categorical_accuracy: 0.8947\n",
+ "Epoch 57/100\n",
+ "412059/412059 [==============================] - 76s 185us/sample - loss: 0.3056 - categorical_accuracy: 0.9411 - val_loss: 0.6167 - val_categorical_accuracy: 0.8954\n",
+ "Epoch 58/100\n",
+ "412059/412059 [==============================] - 78s 190us/sample - loss: 0.3033 - categorical_accuracy: 0.9411 - val_loss: 0.6130 - val_categorical_accuracy: 0.8955\n",
+ "Epoch 59/100\n",
+ "412059/412059 [==============================] - 68s 166us/sample - loss: 0.3041 - categorical_accuracy: 0.9408 - val_loss: 0.6100 - val_categorical_accuracy: 0.8950\n",
"Trained and saved new model.\n",
- "Finished at 2022-02-10 21:37:31.052443, took 0:49:28.030848 seconds\n"
+ "Finished at 2022-02-21 12:39:29.005837, took 1:13:07.304278 seconds\n"
]
}
],
@@ -312,7 +357,8 @@
"start = datetime.now()\n",
"print(f'Started at {start}')\n",
"\n",
- "sherlock_model = SherlockModel()\n",
+ "model = SherlockModel()\n",
+ "# Model will be stored with ID `model_id`\n",
"model.fit(X_train, y_train, X_validation, y_validation, model_id=model_id)\n",
"\n",
"print('Trained and saved new model.')\n",
@@ -321,7 +367,7 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
@@ -344,7 +390,7 @@
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -354,7 +400,7 @@
},
{
"cell_type": "code",
- "execution_count": 51,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -367,10 +413,10 @@
{
"data": {
"text/plain": [
- "0.8937685721454983"
+ "0.8951410029373902"
]
},
- "execution_count": 51,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -386,11 +432,12 @@
},
{
"cell_type": "code",
- "execution_count": 52,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# If using the original model, model_id should be replaced with \"sherlock\"\n",
+ "#model_id = \"sherlock\"\n",
"classes = np.load(f\"../model_files/classes_{model_id}.npy\", allow_pickle=True)\n",
"\n",
"report = classification_report(y_test, predicted_labels, output_dict=True)\n",
@@ -409,7 +456,7 @@
},
{
"cell_type": "code",
- "execution_count": 53,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -417,11 +464,11 @@
"output_type": "stream",
"text": [
"\t\tf1-score\tprecision\trecall\t\tsupport\n",
- "grades\t\t0.993\t\t0.994\t\t0.993\t\t1765\n",
- "isbn\t\t0.990\t\t0.992\t\t0.987\t\t1430\n",
- "jockey\t\t0.986\t\t0.980\t\t0.992\t\t2819\n",
- "currency\t0.980\t\t0.987\t\t0.973\t\t405\n",
- "industry\t0.980\t\t0.975\t\t0.985\t\t2958\n"
+ "grades\t\t0.993\t\t0.993\t\t0.993\t\t1765\n",
+ "isbn\t\t0.991\t\t0.993\t\t0.988\t\t1430\n",
+ "jockey\t\t0.985\t\t0.982\t\t0.988\t\t2819\n",
+ "industry\t0.984\t\t0.983\t\t0.985\t\t2958\n",
+ "currency\t0.975\t\t0.982\t\t0.968\t\t405\n"
]
}
],
@@ -446,7 +493,7 @@
},
{
"cell_type": "code",
- "execution_count": 54,
+ "execution_count": 17,
"metadata": {
"scrolled": true
},
@@ -456,11 +503,11 @@
"output_type": "stream",
"text": [
"\t\tf1-score\tprecision\trecall\t\tsupport\n",
- "rank\t\t0.706\t\t0.631\t\t0.802\t\t2983\n",
- "person\t\t0.665\t\t0.721\t\t0.617\t\t579\n",
- "director\t0.581\t\t0.673\t\t0.511\t\t225\n",
- "sales\t\t0.555\t\t0.554\t\t0.556\t\t322\n",
- "ranking\t\t0.486\t\t0.722\t\t0.367\t\t439\n"
+ "rank\t\t0.693\t\t0.625\t\t0.778\t\t2983\n",
+ "person\t\t0.664\t\t0.717\t\t0.618\t\t579\n",
+ "director\t0.568\t\t0.591\t\t0.547\t\t225\n",
+ "sales\t\t0.556\t\t0.586\t\t0.528\t\t322\n",
+ "ranking\t\t0.441\t\t0.753\t\t0.312\t\t439\n"
]
}
],
@@ -485,7 +532,7 @@
},
{
"cell_type": "code",
- "execution_count": 55,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -494,88 +541,88 @@
"text": [
" precision recall f1-score support\n",
"\n",
- " address 0.944 0.936 0.940 3003\n",
- " affiliate 0.922 0.814 0.865 204\n",
- " affiliation 0.978 0.951 0.964 1768\n",
- " age 0.847 0.960 0.900 3033\n",
- " album 0.863 0.904 0.883 3035\n",
- " area 0.911 0.802 0.853 1987\n",
- " artist 0.810 0.870 0.839 3043\n",
- " birth date 0.979 0.977 0.978 479\n",
- " birth place 0.970 0.916 0.942 418\n",
- " brand 0.779 0.707 0.742 574\n",
- " capacity 0.773 0.715 0.743 362\n",
- " category 0.915 0.886 0.900 3087\n",
- " city 0.856 0.889 0.872 2966\n",
- " class 0.920 0.911 0.915 2971\n",
- "classification 0.924 0.848 0.885 587\n",
- " club 0.962 0.962 0.962 2977\n",
- " code 0.931 0.902 0.916 2956\n",
- " collection 0.982 0.933 0.957 476\n",
- " command 0.929 0.919 0.924 1045\n",
- " company 0.895 0.890 0.892 3041\n",
- " component 0.875 0.878 0.877 1226\n",
- " continent 0.905 0.877 0.890 227\n",
- " country 0.894 0.942 0.917 3038\n",
- " county 0.937 0.960 0.948 2959\n",
- " creator 0.829 0.827 0.828 347\n",
- " credit 0.860 0.840 0.849 941\n",
- " currency 0.987 0.973 0.980 405\n",
- " day 0.938 0.893 0.915 3038\n",
- " depth 0.922 0.930 0.926 947\n",
- " description 0.804 0.861 0.832 3042\n",
- " director 0.673 0.511 0.581 225\n",
- " duration 0.945 0.936 0.941 3000\n",
- " education 0.805 0.789 0.797 313\n",
- " elevation 0.961 0.944 0.952 1299\n",
- " family 0.961 0.901 0.930 746\n",
- " file size 0.913 0.867 0.889 361\n",
- " format 0.973 0.955 0.964 2956\n",
- " gender 0.806 0.787 0.797 1030\n",
- " genre 0.951 0.954 0.953 1163\n",
- " grades 0.994 0.993 0.993 1765\n",
- " industry 0.975 0.985 0.980 2958\n",
- " isbn 0.992 0.987 0.990 1430\n",
- " jockey 0.980 0.992 0.986 2819\n",
- " language 0.837 0.963 0.896 1474\n",
- " location 0.865 0.834 0.849 2949\n",
- " manufacturer 0.848 0.819 0.833 945\n",
- " name 0.732 0.748 0.740 3017\n",
- " nationality 0.823 0.748 0.784 424\n",
- " notes 0.769 0.842 0.804 2303\n",
- " operator 0.808 0.832 0.820 404\n",
- " order 0.949 0.826 0.883 1462\n",
- " organisation 0.844 0.828 0.836 262\n",
- " origin 0.957 0.887 0.921 1439\n",
- " owner 0.939 0.874 0.905 1673\n",
- " person 0.721 0.617 0.665 579\n",
- " plays 0.842 0.900 0.870 1513\n",
- " position 0.831 0.820 0.826 3057\n",
- " product 0.890 0.864 0.877 2647\n",
- " publisher 0.943 0.881 0.911 880\n",
- " range 0.839 0.766 0.801 577\n",
- " rank 0.631 0.802 0.706 2983\n",
- " ranking 0.722 0.367 0.486 439\n",
- " region 0.831 0.827 0.829 2740\n",
- " religion 0.963 0.909 0.935 340\n",
- " requirement 0.949 0.800 0.868 300\n",
- " result 0.969 0.934 0.951 2920\n",
- " sales 0.554 0.556 0.555 322\n",
- " service 0.972 0.921 0.946 2222\n",
- " sex 0.921 0.915 0.918 2997\n",
- " species 0.921 0.951 0.936 819\n",
- " state 0.922 0.961 0.941 3030\n",
- " status 0.951 0.930 0.940 3100\n",
- " symbol 0.952 0.952 0.952 1752\n",
- " team 0.867 0.858 0.863 3011\n",
- " team name 0.899 0.828 0.862 1639\n",
- " type 0.871 0.882 0.877 2909\n",
- " weight 0.954 0.930 0.942 2963\n",
- " year 0.964 0.936 0.950 3015\n",
+ " address 0.931 0.943 0.937 3003\n",
+ " affiliate 0.943 0.809 0.871 204\n",
+ " affiliation 0.973 0.957 0.965 1768\n",
+ " age 0.866 0.950 0.906 3033\n",
+ " album 0.892 0.889 0.890 3035\n",
+ " area 0.870 0.820 0.844 1987\n",
+ " artist 0.816 0.873 0.844 3043\n",
+ " birth date 0.985 0.969 0.977 479\n",
+ " birth place 0.934 0.921 0.928 418\n",
+ " brand 0.830 0.671 0.742 574\n",
+ " capacity 0.793 0.721 0.755 362\n",
+ " category 0.924 0.890 0.906 3087\n",
+ " city 0.864 0.904 0.883 2966\n",
+ " class 0.901 0.915 0.908 2971\n",
+ "classification 0.927 0.862 0.893 587\n",
+ " club 0.974 0.955 0.964 2977\n",
+ " code 0.916 0.907 0.912 2956\n",
+ " collection 0.984 0.931 0.957 476\n",
+ " command 0.938 0.904 0.921 1045\n",
+ " company 0.912 0.888 0.900 3041\n",
+ " component 0.888 0.880 0.884 1226\n",
+ " continent 0.875 0.894 0.885 227\n",
+ " country 0.892 0.950 0.920 3038\n",
+ " county 0.943 0.957 0.950 2959\n",
+ " creator 0.770 0.841 0.804 347\n",
+ " credit 0.868 0.813 0.840 941\n",
+ " currency 0.982 0.968 0.975 405\n",
+ " day 0.953 0.892 0.921 3038\n",
+ " depth 0.931 0.916 0.923 947\n",
+ " description 0.804 0.869 0.835 3042\n",
+ " director 0.591 0.547 0.568 225\n",
+ " duration 0.924 0.948 0.936 3000\n",
+ " education 0.856 0.818 0.837 313\n",
+ " elevation 0.956 0.946 0.951 1299\n",
+ " family 0.964 0.895 0.928 746\n",
+ " file size 0.941 0.845 0.891 361\n",
+ " format 0.966 0.959 0.963 2956\n",
+ " gender 0.868 0.721 0.788 1030\n",
+ " genre 0.965 0.952 0.958 1163\n",
+ " grades 0.993 0.993 0.993 1765\n",
+ " industry 0.983 0.985 0.984 2958\n",
+ " isbn 0.993 0.988 0.991 1430\n",
+ " jockey 0.982 0.988 0.985 2819\n",
+ " language 0.939 0.953 0.946 1474\n",
+ " location 0.896 0.827 0.861 2949\n",
+ " manufacturer 0.865 0.819 0.841 945\n",
+ " name 0.724 0.759 0.741 3017\n",
+ " nationality 0.907 0.691 0.784 424\n",
+ " notes 0.724 0.842 0.779 2303\n",
+ " operator 0.794 0.847 0.819 404\n",
+ " order 0.869 0.887 0.878 1462\n",
+ " organisation 0.832 0.832 0.832 262\n",
+ " origin 0.947 0.900 0.923 1439\n",
+ " owner 0.931 0.869 0.899 1673\n",
+ " person 0.717 0.618 0.664 579\n",
+ " plays 0.846 0.903 0.873 1513\n",
+ " position 0.806 0.839 0.822 3057\n",
+ " product 0.868 0.878 0.873 2647\n",
+ " publisher 0.888 0.899 0.893 880\n",
+ " range 0.855 0.759 0.804 577\n",
+ " rank 0.625 0.778 0.693 2983\n",
+ " ranking 0.753 0.312 0.441 439\n",
+ " region 0.882 0.810 0.845 2740\n",
+ " religion 0.972 0.921 0.946 340\n",
+ " requirement 0.927 0.807 0.863 300\n",
+ " result 0.962 0.940 0.951 2920\n",
+ " sales 0.586 0.528 0.556 322\n",
+ " service 0.964 0.925 0.944 2222\n",
+ " sex 0.903 0.945 0.924 2997\n",
+ " species 0.921 0.950 0.935 819\n",
+ " state 0.939 0.957 0.948 3030\n",
+ " status 0.943 0.936 0.940 3100\n",
+ " symbol 0.958 0.946 0.952 1752\n",
+ " team 0.850 0.870 0.860 3011\n",
+ " team name 0.893 0.827 0.859 1639\n",
+ " type 0.916 0.875 0.895 2909\n",
+ " weight 0.958 0.931 0.944 2963\n",
+ " year 0.965 0.937 0.951 3015\n",
"\n",
- " accuracy 0.894 137353\n",
- " macro avg 0.887 0.866 0.875 137353\n",
- " weighted avg 0.896 0.894 0.894 137353\n",
+ " accuracy 0.895 137353\n",
+ " macro avg 0.890 0.866 0.876 137353\n",
+ " weighted avg 0.898 0.895 0.895 137353\n",
"\n"
]
}
@@ -593,7 +640,7 @@
},
{
"cell_type": "code",
- "execution_count": 56,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -605,193 +652,171 @@
"[2420] expected \"address\" but predicted \"club\"\n",
"[2616] expected \"address\" but predicted \"city\"\n",
"[3398] expected \"address\" but predicted \"city\"\n",
- "[4380] expected \"address\" but predicted \"county\"\n",
- "[4422] expected \"address\" but predicted \"city\"\n",
"[5112] expected \"address\" but predicted \"location\"\n",
"[5546] expected \"address\" but predicted \"name\"\n",
- "[5647] expected \"address\" but predicted \"team name\"\n",
- "[7119] expected \"address\" but predicted \"day\"\n",
+ "[6526] expected \"address\" but predicted \"location\"\n",
"[8797] expected \"address\" but predicted \"location\"\n",
"[9354] expected \"address\" but predicted \"location\"\n",
"[9574] expected \"address\" but predicted \"location\"\n",
"[9806] expected \"address\" but predicted \"city\"\n",
"[10035] expected \"address\" but predicted \"creator\"\n",
- "[10067] expected \"address\" but predicted \"education\"\n",
- "[11055] expected \"address\" but predicted \"city\"\n",
- "[11902] expected \"address\" but predicted \"location\"\n",
+ "[10067] expected \"address\" but predicted \"order\"\n",
+ "[10665] expected \"address\" but predicted \"area\"\n",
+ "[11055] expected \"address\" but predicted \"county\"\n",
+ "[11902] expected \"address\" but predicted \"jockey\"\n",
+ "[11993] expected \"address\" but predicted \"location\"\n",
"[12072] expected \"address\" but predicted \"artist\"\n",
- "[12639] expected \"address\" but predicted \"location\"\n",
+ "[14200] expected \"address\" but predicted \"description\"\n",
"[14677] expected \"address\" but predicted \"location\"\n",
"[15232] expected \"address\" but predicted \"city\"\n",
"[15461] expected \"address\" but predicted \"location\"\n",
"[15496] expected \"address\" but predicted \"city\"\n",
- "[16212] expected \"address\" but predicted \"location\"\n",
+ "[15987] expected \"address\" but predicted \"artist\"\n",
"[19953] expected \"address\" but predicted \"county\"\n",
- "[20829] expected \"address\" but predicted \"language\"\n",
- "[21408] expected \"address\" but predicted \"description\"\n",
- "[22148] expected \"address\" but predicted \"location\"\n",
- "[22355] expected \"address\" but predicted \"location\"\n",
+ "[20425] expected \"address\" but predicted \"location\"\n",
+ "[20829] expected \"address\" but predicted \"notes\"\n",
+ "[21408] expected \"address\" but predicted \"position\"\n",
+ "[21666] expected \"address\" but predicted \"area\"\n",
+ "[22148] expected \"address\" but predicted \"product\"\n",
"[23915] expected \"address\" but predicted \"location\"\n",
"[24636] expected \"address\" but predicted \"team name\"\n",
- "[24803] expected \"address\" but predicted \"day\"\n",
+ "[24803] expected \"address\" but predicted \"position\"\n",
"[26171] expected \"address\" but predicted \"name\"\n",
"[26184] expected \"address\" but predicted \"name\"\n",
- "[26210] expected \"address\" but predicted \"state\"\n",
- "[26296] expected \"address\" but predicted \"file size\"\n",
+ "[26210] expected \"address\" but predicted \"name\"\n",
+ "[26393] expected \"address\" but predicted \"position\"\n",
"[26559] expected \"address\" but predicted \"location\"\n",
"[26872] expected \"address\" but predicted \"command\"\n",
- "[30403] expected \"address\" but predicted \"language\"\n",
+ "[30403] expected \"address\" but predicted \"notes\"\n",
"[31391] expected \"address\" but predicted \"location\"\n",
"[31515] expected \"address\" but predicted \"location\"\n",
- "[31830] expected \"address\" but predicted \"location\"\n",
- "[31836] expected \"address\" but predicted \"language\"\n",
- "[31976] expected \"address\" but predicted \"location\"\n",
- "[32429] expected \"address\" but predicted \"artist\"\n",
+ "[31830] expected \"address\" but predicted \"result\"\n",
+ "[31836] expected \"address\" but predicted \"notes\"\n",
"[32551] expected \"address\" but predicted \"city\"\n",
- "[32634] expected \"address\" but predicted \"location\"\n",
- "[32762] expected \"address\" but predicted \"name\"\n",
- "[33207] expected \"address\" but predicted \"region\"\n",
- "[33975] expected \"address\" but predicted \"location\"\n",
+ "[32762] expected \"address\" but predicted \"artist\"\n",
+ "[33207] expected \"address\" but predicted \"order\"\n",
"[34547] expected \"address\" but predicted \"location\"\n",
- "[34849] expected \"address\" but predicted \"location\"\n",
+ "[34711] expected \"address\" but predicted \"location\"\n",
"[35467] expected \"address\" but predicted \"location\"\n",
- "[35756] expected \"address\" but predicted \"city\"\n",
- "[35907] expected \"address\" but predicted \"county\"\n",
- "[35938] expected \"address\" but predicted \"language\"\n",
- "[36033] expected \"address\" but predicted \"name\"\n",
+ "[35938] expected \"address\" but predicted \"notes\"\n",
"[37084] expected \"address\" but predicted \"person\"\n",
- "[37318] expected \"address\" but predicted \"language\"\n",
- "[37536] expected \"address\" but predicted \"language\"\n",
- "[38118] expected \"address\" but predicted \"location\"\n",
+ "[37318] expected \"address\" but predicted \"notes\"\n",
+ "[37536] expected \"address\" but predicted \"notes\"\n",
"[40184] expected \"address\" but predicted \"notes\"\n",
"[40457] expected \"address\" but predicted \"country\"\n",
- "[41540] expected \"address\" but predicted \"location\"\n",
- "[42439] expected \"address\" but predicted \"description\"\n",
+ "[42439] expected \"address\" but predicted \"area\"\n",
+ "[44530] expected \"address\" but predicted \"artist\"\n",
"[44906] expected \"address\" but predicted \"location\"\n",
"[44918] expected \"address\" but predicted \"area\"\n",
- "[46430] expected \"address\" but predicted \"location\"\n",
+ "[44932] expected \"address\" but predicted \"name\"\n",
+ "[46430] expected \"address\" but predicted \"birth place\"\n",
"[46463] expected \"address\" but predicted \"notes\"\n",
- "[47140] expected \"address\" but predicted \"team name\"\n",
- "[47249] expected \"address\" but predicted \"language\"\n",
- "[47497] expected \"address\" but predicted \"location\"\n",
- "[47636] expected \"address\" but predicted \"location\"\n",
+ "[47140] expected \"address\" but predicted \"age\"\n",
+ "[47249] expected \"address\" but predicted \"notes\"\n",
"[47713] expected \"address\" but predicted \"city\"\n",
- "[47810] expected \"address\" but predicted \"category\"\n",
+ "[47810] expected \"address\" but predicted \"description\"\n",
"[48016] expected \"address\" but predicted \"location\"\n",
- "[48289] expected \"address\" but predicted \"location\"\n",
"[48631] expected \"address\" but predicted \"location\"\n",
- "[50317] expected \"address\" but predicted \"name\"\n",
- "[50329] expected \"address\" but predicted \"location\"\n",
- "[50783] expected \"address\" but predicted \"city\"\n",
+ "[50329] expected \"address\" but predicted \"area\"\n",
"[51643] expected \"address\" but predicted \"city\"\n",
- "[51897] expected \"address\" but predicted \"location\"\n",
- "[52120] expected \"address\" but predicted \"city\"\n",
+ "[51887] expected \"address\" but predicted \"location\"\n",
+ "[53960] expected \"address\" but predicted \"service\"\n",
"[54248] expected \"address\" but predicted \"area\"\n",
- "[55535] expected \"address\" but predicted \"language\"\n",
- "[56030] expected \"address\" but predicted \"description\"\n",
+ "[56030] expected \"address\" but predicted \"code\"\n",
+ "[56085] expected \"address\" but predicted \"name\"\n",
+ "[56176] expected \"address\" but predicted \"product\"\n",
"[57015] expected \"address\" but predicted \"location\"\n",
- "[57851] expected \"address\" but predicted \"location\"\n",
"[58524] expected \"address\" but predicted \"location\"\n",
+ "[58958] expected \"address\" but predicted \"team name\"\n",
"[59506] expected \"address\" but predicted \"name\"\n",
- "[59632] expected \"address\" but predicted \"location\"\n",
- "[59734] expected \"address\" but predicted \"location\"\n",
- "[60105] expected \"address\" but predicted \"location\"\n",
- "[60225] expected \"address\" but predicted \"language\"\n",
- "[60441] expected \"address\" but predicted \"location\"\n",
- "[60565] expected \"address\" but predicted \"region\"\n",
+ "[59734] expected \"address\" but predicted \"language\"\n",
+ "[60105] expected \"address\" but predicted \"area\"\n",
+ "[60225] expected \"address\" but predicted \"notes\"\n",
+ "[60441] expected \"address\" but predicted \"collection\"\n",
+ "[60565] expected \"address\" but predicted \"ranking\"\n",
"[61783] expected \"address\" but predicted \"location\"\n",
- "[61894] expected \"address\" but predicted \"language\"\n",
- "[61975] expected \"address\" but predicted \"capacity\"\n",
- "[64427] expected \"address\" but predicted \"region\"\n",
+ "[61894] expected \"address\" but predicted \"notes\"\n",
+ "[66355] expected \"address\" but predicted \"artist\"\n",
"[67142] expected \"address\" but predicted \"code\"\n",
- "[68698] expected \"address\" but predicted \"area\"\n",
"[69794] expected \"address\" but predicted \"city\"\n",
"[70071] expected \"address\" but predicted \"location\"\n",
"[71228] expected \"address\" but predicted \"location\"\n",
- "[71784] expected \"address\" but predicted \"symbol\"\n",
+ "[71784] expected \"address\" but predicted \"location\"\n",
+ "[72001] expected \"address\" but predicted \"product\"\n",
"[72226] expected \"address\" but predicted \"notes\"\n",
- "[73451] expected \"address\" but predicted \"name\"\n",
+ "[73360] expected \"address\" but predicted \"species\"\n",
"[73573] expected \"address\" but predicted \"name\"\n",
- "[74228] expected \"address\" but predicted \"elevation\"\n",
- "[75861] expected \"address\" but predicted \"location\"\n",
+ "[74228] expected \"address\" but predicted \"location\"\n",
"[75893] expected \"address\" but predicted \"notes\"\n",
"[76551] expected \"address\" but predicted \"language\"\n",
- "[77502] expected \"address\" but predicted \"location\"\n",
- "[77796] expected \"address\" but predicted \"city\"\n",
- "[78437] expected \"address\" but predicted \"location\"\n",
- "[80712] expected \"address\" but predicted \"name\"\n",
- "[81288] expected \"address\" but predicted \"city\"\n",
- "[82082] expected \"address\" but predicted \"artist\"\n",
+ "[77796] expected \"address\" but predicted \"location\"\n",
+ "[80669] expected \"address\" but predicted \"location\"\n",
+ "[80712] expected \"address\" but predicted \"category\"\n",
+ "[82082] expected \"address\" but predicted \"location\"\n",
"[82779] expected \"address\" but predicted \"range\"\n",
- "[83204] expected \"address\" but predicted \"name\"\n",
- "[84979] expected \"address\" but predicted \"country\"\n",
+ "[83478] expected \"address\" but predicted \"city\"\n",
+ "[84979] expected \"address\" but predicted \"weight\"\n",
"[85206] expected \"address\" but predicted \"location\"\n",
"[85353] expected \"address\" but predicted \"creator\"\n",
- "[85752] expected \"address\" but predicted \"language\"\n",
+ "[85752] expected \"address\" but predicted \"notes\"\n",
"[85930] expected \"address\" but predicted \"language\"\n",
- "[85971] expected \"address\" but predicted \"location\"\n",
- "[86084] expected \"address\" but predicted \"name\"\n",
"[86891] expected \"address\" but predicted \"rank\"\n",
- "[87332] expected \"address\" but predicted \"capacity\"\n",
+ "[87332] expected \"address\" but predicted \"code\"\n",
"[87413] expected \"address\" but predicted \"location\"\n",
"[87891] expected \"address\" but predicted \"sales\"\n",
- "[87958] expected \"address\" but predicted \"affiliation\"\n",
- "[89058] expected \"address\" but predicted \"city\"\n",
+ "[87958] expected \"address\" but predicted \"format\"\n",
+ "[88056] expected \"address\" but predicted \"city\"\n",
+ "[89504] expected \"address\" but predicted \"area\"\n",
"[89800] expected \"address\" but predicted \"location\"\n",
- "[90054] expected \"address\" but predicted \"name\"\n",
+ "[90054] expected \"address\" but predicted \"location\"\n",
"[90466] expected \"address\" but predicted \"location\"\n",
"[90582] expected \"address\" but predicted \"city\"\n",
- "[92393] expected \"address\" but predicted \"location\"\n",
- "[93361] expected \"address\" but predicted \"city\"\n",
+ "[90584] expected \"address\" but predicted \"location\"\n",
"[93557] expected \"address\" but predicted \"country\"\n",
+ "[95220] expected \"address\" but predicted \"notes\"\n",
"[95411] expected \"address\" but predicted \"location\"\n",
- "[95748] expected \"address\" but predicted \"location\"\n",
- "[95769] expected \"address\" but predicted \"city\"\n",
+ "[95769] expected \"address\" but predicted \"team name\"\n",
"[96379] expected \"address\" but predicted \"location\"\n",
- "[96640] expected \"address\" but predicted \"product\"\n",
- "[96728] expected \"address\" but predicted \"state\"\n",
- "[99038] expected \"address\" but predicted \"name\"\n",
+ "[96640] expected \"address\" but predicted \"format\"\n",
+ "[96728] expected \"address\" but predicted \"birth place\"\n",
+ "[97594] expected \"address\" but predicted \"artist\"\n",
+ "[98529] expected \"address\" but predicted \"artist\"\n",
"[99237] expected \"address\" but predicted \"location\"\n",
"[100797] expected \"address\" but predicted \"symbol\"\n",
"[101634] expected \"address\" but predicted \"area\"\n",
"[102060] expected \"address\" but predicted \"name\"\n",
- "[103165] expected \"address\" but predicted \"language\"\n",
- "[103681] expected \"address\" but predicted \"location\"\n",
- "[105085] expected \"address\" but predicted \"city\"\n",
+ "[103165] expected \"address\" but predicted \"notes\"\n",
"[107281] expected \"address\" but predicted \"location\"\n",
"[108550] expected \"address\" but predicted \"grades\"\n",
- "[109367] expected \"address\" but predicted \"region\"\n",
+ "[109367] expected \"address\" but predicted \"team name\"\n",
"[109427] expected \"address\" but predicted \"result\"\n",
- "[109740] expected \"address\" but predicted \"city\"\n",
"[111142] expected \"address\" but predicted \"location\"\n",
+ "[112597] expected \"address\" but predicted \"location\"\n",
"[113711] expected \"address\" but predicted \"location\"\n",
- "[114356] expected \"address\" but predicted \"day\"\n",
+ "[114356] expected \"address\" but predicted \"position\"\n",
"[114372] expected \"address\" but predicted \"area\"\n",
"[114485] expected \"address\" but predicted \"description\"\n",
- "[115607] expected \"address\" but predicted \"region\"\n",
- "[116630] expected \"address\" but predicted \"county\"\n",
- "[118623] expected \"address\" but predicted \"language\"\n",
+ "[118623] expected \"address\" but predicted \"notes\"\n",
"[120003] expected \"address\" but predicted \"owner\"\n",
- "[120180] expected \"address\" but predicted \"description\"\n",
- "[122561] expected \"address\" but predicted \"language\"\n",
+ "[122561] expected \"address\" but predicted \"notes\"\n",
"[123013] expected \"address\" but predicted \"rank\"\n",
"[123713] expected \"address\" but predicted \"location\"\n",
"[124794] expected \"address\" but predicted \"company\"\n",
- "[124809] expected \"address\" but predicted \"location\"\n",
- "[125936] expected \"address\" but predicted \"location\"\n",
- "[126147] expected \"address\" but predicted \"location\"\n",
+ "[126147] expected \"address\" but predicted \"region\"\n",
"[126442] expected \"address\" but predicted \"capacity\"\n",
"[126729] expected \"address\" but predicted \"location\"\n",
- "[127753] expected \"address\" but predicted \"education\"\n",
- "[131254] expected \"address\" but predicted \"company\"\n",
+ "[127753] expected \"address\" but predicted \"order\"\n",
+ "[129842] expected \"address\" but predicted \"birth place\"\n",
+ "[130902] expected \"address\" but predicted \"rank\"\n",
+ "[131254] expected \"address\" but predicted \"owner\"\n",
+ "[131471] expected \"address\" but predicted \"location\"\n",
"[131693] expected \"address\" but predicted \"location\"\n",
"[132446] expected \"address\" but predicted \"location\"\n",
- "[132881] expected \"address\" but predicted \"artist\"\n",
- "[132989] expected \"address\" but predicted \"location\"\n",
- "[133669] expected \"address\" but predicted \"album\"\n",
+ "[132881] expected \"address\" but predicted \"album\"\n",
+ "[133665] expected \"address\" but predicted \"region\"\n",
"[133805] expected \"address\" but predicted \"name\"\n",
- "[136727] expected \"address\" but predicted \"status\"\n",
- "[136771] expected \"address\" but predicted \"location\"\n",
+ "[135969] expected \"address\" but predicted \"area\"\n",
+ "[136727] expected \"address\" but predicted \"result\"\n",
"[137027] expected \"address\" but predicted \"location\"\n"
]
},
@@ -799,93 +824,93 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Total mismatches: 14623 (F1 score: 0.8937685721454983)\n"
+ "Total mismatches: 14419 (F1 score: 0.8951410029373902)\n"
]
},
{
"data": {
"text/plain": [
- "[('name', 761),\n",
- " ('rank', 592),\n",
- " ('position', 551),\n",
- " ('location', 489),\n",
- " ('region', 473),\n",
- " ('team', 428),\n",
- " ('description', 422),\n",
- " ('artist', 395),\n",
- " ('area', 394),\n",
+ "[('name', 727),\n",
+ " ('rank', 663),\n",
+ " ('region', 521),\n",
+ " ('location', 509),\n",
+ " ('position', 491),\n",
+ " ('description', 400),\n",
+ " ('team', 390),\n",
+ " ('artist', 385),\n",
" ('notes', 364),\n",
- " ('product', 361),\n",
- " ('category', 353),\n",
- " ('type', 342),\n",
- " ('company', 335),\n",
- " ('city', 330),\n",
- " ('day', 326),\n",
- " ('album', 292),\n",
- " ('code', 290),\n",
- " ('team name', 282),\n",
- " ('ranking', 278),\n",
- " ('class', 264),\n",
- " ('order', 254),\n",
- " ('sex', 254),\n",
- " ('person', 222),\n",
- " ('gender', 219),\n",
- " ('status', 217),\n",
- " ('owner', 211),\n",
- " ('weight', 206),\n",
- " ('result', 194),\n",
- " ('year', 193),\n",
- " ('address', 193),\n",
- " ('duration', 191),\n",
- " ('country', 177),\n",
- " ('service', 176),\n",
+ " ('type', 363),\n",
+ " ('area', 357),\n",
+ " ('category', 341),\n",
+ " ('company', 340),\n",
+ " ('album', 338),\n",
+ " ('day', 329),\n",
+ " ('product', 322),\n",
+ " ('ranking', 302),\n",
+ " ('gender', 287),\n",
+ " ('city', 286),\n",
+ " ('team name', 283),\n",
+ " ('code', 274),\n",
+ " ('class', 253),\n",
+ " ('person', 221),\n",
+ " ('owner', 219),\n",
+ " ('weight', 203),\n",
+ " ('status', 197),\n",
+ " ('brand', 189),\n",
+ " ('year', 189),\n",
+ " ('credit', 176),\n",
+ " ('result', 174),\n",
" ('manufacturer', 171),\n",
- " ('brand', 168),\n",
- " ('origin', 162),\n",
- " ('plays', 152),\n",
- " ('credit', 151),\n",
- " ('component', 149),\n",
- " ('sales', 143),\n",
- " ('range', 135),\n",
- " ('format', 133),\n",
- " ('age', 122),\n",
- " ('county', 118),\n",
- " ('state', 117),\n",
- " ('club', 113),\n",
- " ('director', 110),\n",
- " ('nationality', 107),\n",
- " ('publisher', 105),\n",
- " ('capacity', 103),\n",
- " ('classification', 89),\n",
- " ('affiliation', 87),\n",
- " ('command', 85),\n",
- " ('symbol', 84),\n",
- " ('family', 74),\n",
- " ('elevation', 73),\n",
- " ('operator', 68),\n",
- " ('education', 66),\n",
- " ('depth', 66),\n",
- " ('creator', 60),\n",
- " ('requirement', 60),\n",
- " ('language', 54),\n",
- " ('genre', 53),\n",
- " ('file size', 48),\n",
- " ('organisation', 45),\n",
- " ('industry', 43),\n",
- " ('species', 40),\n",
- " ('affiliate', 38),\n",
- " ('birth place', 35),\n",
- " ('collection', 32),\n",
- " ('religion', 31),\n",
- " ('continent', 28),\n",
- " ('jockey', 23),\n",
- " ('isbn', 18),\n",
- " ('grades', 13),\n",
- " ('currency', 11),\n",
- " ('birth date', 11)]"
+ " ('address', 171),\n",
+ " ('service', 167),\n",
+ " ('order', 165),\n",
+ " ('sex', 164),\n",
+ " ('duration', 155),\n",
+ " ('age', 153),\n",
+ " ('sales', 152),\n",
+ " ('country', 152),\n",
+ " ('plays', 147),\n",
+ " ('component', 147),\n",
+ " ('origin', 144),\n",
+ " ('range', 139),\n",
+ " ('club', 133),\n",
+ " ('nationality', 131),\n",
+ " ('state', 129),\n",
+ " ('county', 127),\n",
+ " ('format', 120),\n",
+ " ('director', 102),\n",
+ " ('capacity', 101),\n",
+ " ('command', 100),\n",
+ " ('symbol', 94),\n",
+ " ('publisher', 89),\n",
+ " ('classification', 81),\n",
+ " ('depth', 80),\n",
+ " ('family', 78),\n",
+ " ('affiliation', 76),\n",
+ " ('elevation', 70),\n",
+ " ('language', 69),\n",
+ " ('operator', 62),\n",
+ " ('requirement', 58),\n",
+ " ('education', 57),\n",
+ " ('file size', 56),\n",
+ " ('genre', 56),\n",
+ " ('creator', 55),\n",
+ " ('industry', 44),\n",
+ " ('organisation', 44),\n",
+ " ('species', 41),\n",
+ " ('affiliate', 39),\n",
+ " ('jockey', 33),\n",
+ " ('collection', 33),\n",
+ " ('birth place', 33),\n",
+ " ('religion', 27),\n",
+ " ('continent', 24),\n",
+ " ('isbn', 17),\n",
+ " ('birth date', 15),\n",
+ " ('currency', 13),\n",
+ " ('grades', 12)]"
]
},
- "execution_count": 56,
+ "execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -913,7 +938,7 @@
},
{
"cell_type": "code",
- "execution_count": 57,
+ "execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -922,7 +947,7 @@
},
{
"cell_type": "code",
- "execution_count": 58,
+ "execution_count": 24,
"metadata": {},
"outputs": [
{
diff --git a/notebooks/02-2-train-and-test-sherlock-rf-ensemble.ipynb b/notebooks/02-2-train-and-test-sherlock-rf-ensemble.ipynb
index 05e2280..80d1399 100644
--- a/notebooks/02-2-train-and-test-sherlock-rf-ensemble.ipynb
+++ b/notebooks/02-2-train-and-test-sherlock-rf-ensemble.ipynb
@@ -1,12 +1,22 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Train and test Sherlock when ensembled with a random forest (RF) classifier\n",
+ "To boost the performance of Sherlock, it can be combined with an RF classifier.\n",
+ "\n",
+ "The cells below show the procedure for doing so."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
- "nn_model_id = 'retrained_sherlock'"
+ "model_id = 'sherlock'"
]
},
{
@@ -27,6 +37,7 @@
],
"source": [
"# If you need fully deterministic results between runs, set the following environment value prior to launching jupyter.\n",
+ "\n",
"# See comment in sherlock.features.paragraph_vectors.infer_paragraph_embeddings_features for more info.\n",
"%env PYTHONHASHSEED"
]
@@ -43,7 +54,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
@@ -60,8 +71,7 @@
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import classification_report, f1_score\n",
"\n",
- "from sherlock.deploy.predict_sherlock import predict_sherlock_proba, _transform_predictions_to_classes\n",
- "from sherlock.deploy.train_sherlock import train_sherlock"
+ "from sherlock.deploy.model import SherlockModel"
]
},
{
@@ -73,15 +83,15 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-07 15:46:02.658556\n",
- "Load data (train) process took 0:00:04.837128 seconds.\n"
+ "Started at 2022-02-21 14:44:58.387328\n",
+ "Load data (train) process took 0:00:07.072707 seconds.\n"
]
}
],
@@ -92,12 +102,14 @@
"X_train = pd.read_parquet('../data/data/processed/train.parquet')\n",
"y_train = pd.read_parquet('../data/data/raw/train_labels.parquet').values.flatten()\n",
"\n",
+ "y_train = np.array([x.lower() for x in y_train])\n",
+ "\n",
"print(f'Load data (train) process took {datetime.now() - start} seconds.')"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -116,15 +128,15 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-07 15:46:07.682626\n",
- "Load data (validation) process took 0:00:01.803970 seconds.\n"
+ "Started at 2022-02-21 14:16:45.455219\n",
+ "Load data (validation) process took 0:00:02.024156 seconds.\n"
]
}
],
@@ -142,7 +154,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -151,7 +163,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
@@ -167,15 +179,15 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-07 15:46:11.682256\n",
- "Finished at 2022-02-07 16:01:59.050932, took 0:15:47.368743 seconds\n"
+ "Started at 2022-02-21 14:17:08.147857\n",
+ "Finished at 2022-02-21 14:38:09.947917, took 0:21:01.802720 seconds\n"
]
}
],
@@ -199,7 +211,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -217,16 +229,16 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Started at 2022-02-07 16:01:59.120891\n",
+ "Started at 2022-02-21 14:38:10.601540\n",
"Trained and saved new model.\n",
- "Finished at 2022-02-07 16:02:00.975332, took 0:00:01.854455 seconds\n"
+ "Finished at 2022-02-21 14:38:12.493349, took 0:00:01.891821 seconds\n"
]
}
],
@@ -243,6 +255,39 @@
"print(f'Finished at {datetime.now()}, took {datetime.now() - start} seconds')"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array(['address', 'affiliate', 'affiliation', 'age', 'album', 'area',\n",
+ " 'artist', 'birth date', 'birth place', 'brand', 'capacity',\n",
+ " 'category', 'city', 'class', 'classification', 'club', 'code',\n",
+ " 'collection', 'command', 'company', 'component', 'continent',\n",
+ " 'country', 'county', 'creator', 'credit', 'currency', 'day',\n",
+ " 'depth', 'description', 'director', 'duration', 'education',\n",
+ " 'elevation', 'family', 'file size', 'format', 'gender', 'genre',\n",
+ " 'grades', 'industry', 'isbn', 'jockey', 'language', 'location',\n",
+ " 'manufacturer', 'name', 'nationality', 'notes', 'operator',\n",
+ " 'order', 'organisation', 'origin', 'owner', 'person', 'plays',\n",
+ " 'position', 'product', 'publisher', 'range', 'rank', 'ranking',\n",
+ " 'region', 'religion', 'requirement', 'result', 'sales', 'service',\n",
+ " 'sex', 'species', 'state', 'status', 'symbol', 'team', 'team name',\n",
+ " 'type', 'weight', 'year'], dtype='\n",
- "f1 score 0.8909529156786774\n"
+ "f1 score 0.8912755744265719\n"
]
}
],
@@ -337,7 +383,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
@@ -346,7 +392,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 76,
"metadata": {},
"outputs": [
{
@@ -354,7 +400,7 @@
"output_type": "stream",
"text": [
"prediction count 137353, type = \n",
- "f1 score 0.8884613184751746\n"
+ "f1 score 0.8883526561931331\n"
]
}
],
@@ -371,7 +417,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
@@ -380,7 +426,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 78,
"metadata": {},
"outputs": [
{
@@ -388,7 +434,7 @@
"output_type": "stream",
"text": [
"prediction count 137353, type = \n",
- "f1 score 0.8933550546229518\n"
+ "f1 score 0.8940645473980389\n"
]
}
],
@@ -405,39 +451,18 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 79,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "WARNING: Logging before flag parsing goes to stderr.\n",
- "W0207 16:02:52.102345 4654583296 deprecation.py:506] From /Users/lowecg/source/private-github/sherlock-project-1/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
- "Instructions for updating:\n",
- "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
- "W0207 16:02:52.104378 4654583296 deprecation.py:506] From /Users/lowecg/source/private-github/sherlock-project-1/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
- "Instructions for updating:\n",
- "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
- "W0207 16:02:52.108886 4654583296 deprecation.py:506] From /Users/lowecg/source/private-github/sherlock-project-1/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
- "Instructions for updating:\n",
- "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
- "W0207 16:02:52.138780 4654583296 deprecation.py:506] From /Users/lowecg/source/private-github/sherlock-project-1/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
- "Instructions for updating:\n",
- "If using Keras pass *_constraint arguments to layers.\n",
- "2022-02-07 16:02:52.704754: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA\n",
- "2022-02-07 16:02:52.762490: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fc083ef6bc0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n",
- "2022-02-07 16:02:52.762507: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "predicted_sherlock_proba = predict_sherlock_proba(X_test, nn_id=nn_model_id)"
+ "model = SherlockModel()\n",
+ "model.initialize_model_from_json(with_weights=True, model_id=\"sherlock\")\n",
+ "predicted_sherlock_proba = model.predict_proba(X_test)"
]
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 80,
"metadata": {},
"outputs": [
{
@@ -445,7 +470,7 @@
"output_type": "stream",
"text": [
"prediction count 137353, type = \n",
- "f1 score 0.8940572197723697\n"
+ "f1 score 0.8951410029373902\n"
]
}
],
@@ -462,7 +487,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
@@ -475,13 +500,12 @@
" x = nn_probs + voting_probs\n",
" x = x / 2\n",
"\n",
- " combined.append(x)\n",
- " "
+ " combined.append(x)"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 82,
"metadata": {},
"outputs": [
{
@@ -489,7 +513,7 @@
"output_type": "stream",
"text": [
"prediction count 137353, type = \n",
- "f1 score 0.9047220789717997\n"
+ "f1 score 0.905491661885665\n"
]
}
],
@@ -501,7 +525,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
@@ -521,7 +545,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
@@ -539,7 +563,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 85,
"metadata": {},
"outputs": [
{
@@ -547,11 +571,11 @@
"output_type": "stream",
"text": [
"\t\tf1-score\tprecision\trecall\t\tsupport\n",
- "grades\t\t0.994\t\t0.994\t\t0.994\t\t1765\n",
- "isbn\t\t0.991\t\t0.993\t\t0.989\t\t1430\n",
- "jockey\t\t0.986\t\t0.980\t\t0.991\t\t2819\n",
- "industry\t0.983\t\t0.979\t\t0.988\t\t2958\n",
- "birth date\t0.978\t\t0.981\t\t0.975\t\t479\n"
+ "grades\t\t0.995\t\t0.994\t\t0.995\t\t1765\n",
+ "isbn\t\t0.990\t\t0.992\t\t0.989\t\t1430\n",
+ "industry\t0.986\t\t0.985\t\t0.988\t\t2958\n",
+ "jockey\t\t0.985\t\t0.984\t\t0.987\t\t2819\n",
+ "currency\t0.979\t\t0.985\t\t0.973\t\t405\n"
]
}
],
@@ -568,7 +592,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 86,
"metadata": {},
"outputs": [
{
@@ -576,11 +600,11 @@
"output_type": "stream",
"text": [
"\t\tf1-score\tprecision\trecall\t\tsupport\n",
- "rank\t\t0.751\t\t0.710\t\t0.796\t\t2983\n",
- "person\t\t0.690\t\t0.702\t\t0.679\t\t579\n",
- "sales\t\t0.633\t\t0.747\t\t0.550\t\t322\n",
- "director\t0.598\t\t0.648\t\t0.556\t\t225\n",
- "ranking\t\t0.594\t\t0.855\t\t0.456\t\t439\n"
+ "rank\t\t0.738\t\t0.678\t\t0.810\t\t2983\n",
+ "person\t\t0.695\t\t0.767\t\t0.636\t\t579\n",
+ "sales\t\t0.615\t\t0.667\t\t0.571\t\t322\n",
+ "director\t0.604\t\t0.661\t\t0.556\t\t225\n",
+ "ranking\t\t0.569\t\t0.823\t\t0.435\t\t439\n"
]
}
],
@@ -597,7 +621,7 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 87,
"metadata": {},
"outputs": [
{
@@ -606,88 +630,88 @@
"text": [
" precision recall f1-score support\n",
"\n",
- " address 0.926 0.947 0.937 3003\n",
- " affiliate 0.976 0.814 0.888 204\n",
- " affiliation 0.978 0.958 0.968 1768\n",
- " age 0.882 0.963 0.921 3033\n",
- " album 0.883 0.901 0.892 3035\n",
- " area 0.888 0.840 0.863 1987\n",
- " artist 0.807 0.886 0.845 3043\n",
- " birth date 0.981 0.975 0.978 479\n",
- " birth place 0.974 0.904 0.938 418\n",
- " brand 0.795 0.723 0.757 574\n",
- " capacity 0.879 0.746 0.807 362\n",
- " category 0.913 0.901 0.907 3087\n",
- " city 0.857 0.912 0.883 2966\n",
- " class 0.906 0.926 0.916 2971\n",
- "classification 0.955 0.867 0.909 587\n",
- " club 0.972 0.961 0.967 2977\n",
+ " address 0.929 0.951 0.940 3003\n",
+ " affiliate 0.949 0.819 0.879 204\n",
+ " affiliation 0.975 0.958 0.966 1768\n",
+ " age 0.891 0.955 0.922 3033\n",
+ " album 0.894 0.895 0.894 3035\n",
+ " area 0.892 0.836 0.863 1987\n",
+ " artist 0.810 0.884 0.846 3043\n",
+ " birth date 0.983 0.969 0.976 479\n",
+ " birth place 0.939 0.919 0.929 418\n",
+ " brand 0.849 0.695 0.764 574\n",
+ " capacity 0.851 0.771 0.809 362\n",
+ " category 0.927 0.898 0.912 3087\n",
+ " city 0.870 0.910 0.890 2966\n",
+ " class 0.921 0.923 0.922 2971\n",
+ "classification 0.946 0.874 0.909 587\n",
+ " club 0.975 0.957 0.966 2977\n",
" code 0.921 0.925 0.923 2956\n",
- " collection 0.968 0.943 0.955 476\n",
- " command 0.933 0.930 0.931 1045\n",
- " company 0.890 0.904 0.897 3041\n",
- " component 0.893 0.895 0.894 1226\n",
- " continent 0.908 0.916 0.912 227\n",
- " country 0.896 0.952 0.923 3038\n",
- " county 0.938 0.963 0.950 2959\n",
- " creator 0.848 0.839 0.843 347\n",
- " credit 0.894 0.823 0.857 941\n",
- " currency 0.982 0.968 0.975 405\n",
- " day 0.944 0.915 0.929 3038\n",
- " depth 0.959 0.944 0.952 947\n",
- " description 0.795 0.886 0.838 3042\n",
- " director 0.648 0.556 0.598 225\n",
- " duration 0.948 0.953 0.950 3000\n",
- " education 0.902 0.853 0.877 313\n",
- " elevation 0.953 0.960 0.956 1299\n",
- " family 0.970 0.903 0.935 746\n",
- " file size 0.925 0.886 0.905 361\n",
- " format 0.977 0.957 0.967 2956\n",
- " gender 0.859 0.834 0.846 1030\n",
- " genre 0.968 0.948 0.958 1163\n",
- " grades 0.994 0.994 0.994 1765\n",
- " industry 0.979 0.988 0.983 2958\n",
- " isbn 0.993 0.989 0.991 1430\n",
- " jockey 0.980 0.991 0.986 2819\n",
- " language 0.927 0.945 0.936 1474\n",
- " location 0.877 0.844 0.860 2949\n",
- " manufacturer 0.904 0.798 0.848 945\n",
- " name 0.788 0.725 0.755 3017\n",
- " nationality 0.853 0.738 0.791 424\n",
- " notes 0.779 0.832 0.804 2303\n",
- " operator 0.850 0.854 0.852 404\n",
- " order 0.891 0.863 0.877 1462\n",
- " organisation 0.873 0.817 0.844 262\n",
- " origin 0.964 0.899 0.930 1439\n",
- " owner 0.935 0.877 0.905 1673\n",
- " person 0.702 0.679 0.690 579\n",
- " plays 0.826 0.919 0.870 1513\n",
- " position 0.842 0.856 0.849 3057\n",
- " product 0.865 0.893 0.879 2647\n",
- " publisher 0.897 0.918 0.907 880\n",
- " range 0.917 0.801 0.855 577\n",
- " rank 0.710 0.796 0.751 2983\n",
- " ranking 0.855 0.456 0.594 439\n",
- " region 0.894 0.850 0.871 2740\n",
- " religion 0.975 0.932 0.953 340\n",
- " requirement 0.938 0.807 0.867 300\n",
- " result 0.969 0.948 0.958 2920\n",
- " sales 0.747 0.550 0.633 322\n",
- " service 0.978 0.925 0.951 2222\n",
- " sex 0.939 0.940 0.939 2997\n",
- " species 0.919 0.954 0.936 819\n",
- " state 0.942 0.960 0.951 3030\n",
- " status 0.957 0.939 0.948 3100\n",
- " symbol 0.964 0.967 0.966 1752\n",
- " team 0.878 0.867 0.873 3011\n",
- " team name 0.912 0.840 0.875 1639\n",
- " type 0.888 0.894 0.891 2909\n",
- " weight 0.946 0.951 0.949 2963\n",
- " year 0.969 0.941 0.955 3015\n",
+ " collection 0.987 0.935 0.960 476\n",
+ " command 0.941 0.918 0.929 1045\n",
+ " company 0.913 0.898 0.905 3041\n",
+ " component 0.904 0.892 0.898 1226\n",
+ " continent 0.887 0.930 0.908 227\n",
+ " country 0.897 0.957 0.926 3038\n",
+ " county 0.944 0.964 0.954 2959\n",
+ " creator 0.807 0.841 0.824 347\n",
+ " credit 0.890 0.832 0.860 941\n",
+ " currency 0.985 0.973 0.979 405\n",
+ " day 0.948 0.914 0.931 3038\n",
+ " depth 0.945 0.945 0.945 947\n",
+ " description 0.809 0.884 0.845 3042\n",
+ " director 0.661 0.556 0.604 225\n",
+ " duration 0.935 0.955 0.945 3000\n",
+ " education 0.887 0.856 0.872 313\n",
+ " elevation 0.961 0.955 0.958 1299\n",
+ " family 0.967 0.905 0.935 746\n",
+ " file size 0.946 0.867 0.905 361\n",
+ " format 0.969 0.960 0.964 2956\n",
+ " gender 0.860 0.836 0.848 1030\n",
+ " genre 0.969 0.953 0.961 1163\n",
+ " grades 0.994 0.995 0.995 1765\n",
+ " industry 0.985 0.988 0.986 2958\n",
+ " isbn 0.992 0.989 0.990 1430\n",
+ " jockey 0.984 0.987 0.985 2819\n",
+ " language 0.923 0.947 0.935 1474\n",
+ " location 0.901 0.838 0.868 2949\n",
+ " manufacturer 0.876 0.828 0.851 945\n",
+ " name 0.733 0.769 0.751 3017\n",
+ " nationality 0.906 0.708 0.795 424\n",
+ " notes 0.750 0.847 0.796 2303\n",
+ " operator 0.819 0.854 0.836 404\n",
+ " order 0.860 0.877 0.869 1462\n",
+ " organisation 0.852 0.855 0.853 262\n",
+ " origin 0.955 0.905 0.930 1439\n",
+ " owner 0.941 0.874 0.906 1673\n",
+ " person 0.767 0.636 0.695 579\n",
+ " plays 0.856 0.915 0.885 1513\n",
+ " position 0.848 0.850 0.849 3057\n",
+ " product 0.878 0.886 0.882 2647\n",
+ " publisher 0.904 0.903 0.904 880\n",
+ " range 0.885 0.797 0.839 577\n",
+ " rank 0.678 0.810 0.738 2983\n",
+ " ranking 0.823 0.435 0.569 439\n",
+ " region 0.905 0.844 0.873 2740\n",
+ " religion 0.975 0.926 0.950 340\n",
+ " requirement 0.942 0.817 0.875 300\n",
+ " result 0.968 0.947 0.958 2920\n",
+ " sales 0.667 0.571 0.615 322\n",
+ " service 0.970 0.929 0.949 2222\n",
+ " sex 0.940 0.938 0.939 2997\n",
+ " species 0.930 0.952 0.941 819\n",
+ " state 0.940 0.962 0.951 3030\n",
+ " status 0.949 0.943 0.946 3100\n",
+ " symbol 0.963 0.971 0.967 1752\n",
+ " team 0.867 0.871 0.869 3011\n",
+ " team name 0.898 0.842 0.869 1639\n",
+ " type 0.923 0.882 0.902 2909\n",
+ " weight 0.963 0.950 0.956 2963\n",
+ " year 0.966 0.946 0.956 3015\n",
"\n",
" accuracy 0.905 137353\n",
- " macro avg 0.904 0.880 0.890 137353\n",
- " weighted avg 0.906 0.905 0.905 137353\n",
+ " macro avg 0.903 0.880 0.890 137353\n",
+ " weighted avg 0.907 0.905 0.905 137353\n",
"\n"
]
}
@@ -705,100 +729,100 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 88,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Total mismatches: 13058 (F1 score: 0.9047220789717997)\n"
+ "Total mismatches: 12994 (F1 score: 0.905491661885665)\n"
]
},
{
"data": {
"text/plain": [
- "[('name', 830),\n",
- " ('rank', 608),\n",
- " ('location', 460),\n",
- " ('position', 441),\n",
- " ('region', 412),\n",
- " ('team', 399),\n",
- " ('notes', 388),\n",
- " ('artist', 348),\n",
- " ('description', 347),\n",
- " ('area', 318),\n",
- " ('type', 307),\n",
- " ('category', 307),\n",
- " ('album', 299),\n",
- " ('company', 291),\n",
- " ('product', 283),\n",
- " ('team name', 262),\n",
- " ('city', 261),\n",
- " ('day', 258),\n",
- " ('ranking', 239),\n",
+ "[('name', 697),\n",
+ " ('rank', 566),\n",
+ " ('location', 479),\n",
+ " ('position', 460),\n",
+ " ('region', 427),\n",
+ " ('team', 388),\n",
+ " ('artist', 353),\n",
+ " ('notes', 352),\n",
+ " ('description', 352),\n",
+ " ('type', 342),\n",
+ " ('area', 326),\n",
+ " ('album', 320),\n",
+ " ('category', 316),\n",
+ " ('company', 310),\n",
+ " ('product', 301),\n",
+ " ('city', 266),\n",
+ " ('day', 261),\n",
+ " ('team name', 259),\n",
+ " ('ranking', 248),\n",
+ " ('class', 229),\n",
" ('code', 222),\n",
- " ('class', 220),\n",
- " ('owner', 205),\n",
- " ('order', 200),\n",
- " ('manufacturer', 191),\n",
- " ('status', 190),\n",
- " ('person', 186),\n",
- " ('sex', 181),\n",
- " ('year', 177),\n",
- " ('gender', 171),\n",
- " ('credit', 167),\n",
- " ('service', 167),\n",
- " ('brand', 159),\n",
- " ('address', 158),\n",
- " ('result', 151),\n",
- " ('country', 147),\n",
- " ('origin', 146),\n",
- " ('weight', 145),\n",
- " ('sales', 145),\n",
- " ('duration', 140),\n",
- " ('component', 129),\n",
- " ('format', 126),\n",
- " ('plays', 123),\n",
- " ('state', 121),\n",
- " ('club', 116),\n",
- " ('range', 115),\n",
- " ('nationality', 111),\n",
- " ('age', 111),\n",
- " ('county', 110),\n",
+ " ('person', 211),\n",
+ " ('owner', 210),\n",
+ " ('sex', 185),\n",
+ " ('order', 180),\n",
+ " ('status', 178),\n",
+ " ('brand', 175),\n",
+ " ('gender', 169),\n",
+ " ('manufacturer', 163),\n",
+ " ('year', 163),\n",
+ " ('credit', 158),\n",
+ " ('service', 158),\n",
+ " ('result', 154),\n",
+ " ('weight', 149),\n",
+ " ('address', 146),\n",
+ " ('sales', 138),\n",
+ " ('duration', 136),\n",
+ " ('age', 136),\n",
+ " ('origin', 136),\n",
+ " ('component', 133),\n",
+ " ('country', 130),\n",
+ " ('club', 129),\n",
+ " ('plays', 128),\n",
+ " ('nationality', 124),\n",
+ " ('format', 119),\n",
+ " ('range', 117),\n",
+ " ('state', 115),\n",
+ " ('county', 108),\n",
" ('director', 100),\n",
- " ('capacity', 92),\n",
- " ('language', 81),\n",
- " ('classification', 78),\n",
+ " ('command', 86),\n",
+ " ('publisher', 85),\n",
+ " ('capacity', 83),\n",
+ " ('language', 78),\n",
" ('affiliation', 75),\n",
- " ('command', 73),\n",
- " ('family', 72),\n",
- " ('publisher', 72),\n",
- " ('genre', 60),\n",
+ " ('classification', 74),\n",
+ " ('family', 71),\n",
" ('operator', 59),\n",
- " ('requirement', 58),\n",
- " ('symbol', 57),\n",
- " ('creator', 56),\n",
- " ('depth', 53),\n",
- " ('elevation', 52),\n",
- " ('organisation', 48),\n",
- " ('education', 46),\n",
- " ('file size', 41),\n",
- " ('birth place', 40),\n",
- " ('affiliate', 38),\n",
- " ('species', 38),\n",
+ " ('elevation', 58),\n",
+ " ('creator', 55),\n",
+ " ('requirement', 55),\n",
+ " ('genre', 55),\n",
+ " ('depth', 52),\n",
+ " ('symbol', 51),\n",
+ " ('file size', 48),\n",
+ " ('education', 45),\n",
+ " ('species', 39),\n",
+ " ('organisation', 38),\n",
+ " ('affiliate', 37),\n",
+ " ('jockey', 36),\n",
" ('industry', 36),\n",
- " ('collection', 27),\n",
- " ('jockey', 25),\n",
- " ('religion', 23),\n",
- " ('continent', 19),\n",
+ " ('birth place', 34),\n",
+ " ('collection', 31),\n",
+ " ('religion', 25),\n",
+ " ('continent', 16),\n",
" ('isbn', 16),\n",
- " ('currency', 13),\n",
- " ('birth date', 12),\n",
- " ('grades', 10)]"
+ " ('birth date', 15),\n",
+ " ('currency', 11),\n",
+ " ('grades', 8)]"
]
},
- "execution_count": 32,
+ "execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
@@ -824,7 +848,7 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
@@ -833,7 +857,7 @@
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": 90,
"metadata": {},
"outputs": [
{
@@ -855,14 +879,14 @@
},
{
"cell_type": "code",
- "execution_count": 35,
+ "execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Completed at 2022-02-07 16:03:14.044571\n"
+ "Completed at 2022-02-21 14:53:39.160195\n"
]
}
],
diff --git a/notebooks/03-train-paragraph-vector-features-optional.ipynb b/notebooks/03-retrain-paragraph-vector-features.ipynb
similarity index 95%
rename from notebooks/03-train-paragraph-vector-features-optional.ipynb
rename to notebooks/03-retrain-paragraph-vector-features.ipynb
index 90bc4f6..bc3aa4f 100644
--- a/notebooks/03-train-paragraph-vector-features-optional.ipynb
+++ b/notebooks/03-retrain-paragraph-vector-features.ipynb
@@ -1,5 +1,15 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Load training set and train paragraph vectors\n",
+ "Note: a pretrained paragraph vector model is already downloaded by the `prepare_feature_extraction()` function.\n",
+ "\n",
+ "Retraining is therefore optional; run this notebook only if you want to train the paragraph vectors yourself."
+ ]
+ },
{
"cell_type": "code",
"execution_count": 1,
@@ -32,13 +42,6 @@
"%env PYTHONHASHSEED"
]
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Load training set and train paragraph vectors"
- ]
- },
{
"cell_type": "code",
"execution_count": 3,
@@ -87,9 +90,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Download and read in raw data\n",
- "\n",
- "You can skip this step if you want to use a preprocessed data file."
+ "## Download and read in raw data\n"
]
},
{
diff --git a/requirements.txt b/requirements.txt
index d1679bf..00d2898 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,14 +7,15 @@ certifi==2019.6.16
chardet==3.0.4
docutils==0.14
gast==0.2.2
+gdown==4.3.0
gensim==3.8.0
google-pasta==0.1.7
-googledrivedownloader==0.4
grpcio==1.22.0
h5py==2.9.0
idna==2.8
jmespath==0.9.4
joblib==0.13.2
+jupyter==1.0.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
line_profiler==3.3.1
diff --git a/requirements38.txt b/requirements38.txt
index 1e2b04b..39926b5 100644
--- a/requirements38.txt
+++ b/requirements38.txt
@@ -8,14 +8,15 @@ chardet==3.0.4
cloudpickle==1.6.0
docutils==0.14
gast==0.2.2
+gdown==4.3.0
gensim==3.8.0
google-pasta==0.1.7
-googledrivedownloader==0.4
grpcio==1.34.1
h5py==3.1.0
idna==2.8
jmespath==0.9.4
joblib==1.0.0
+jupyter==1.0.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
loky==2.9.0
diff --git a/sherlock/deploy/model.py b/sherlock/deploy/model.py
index aca7d93..d95b925 100644
--- a/sherlock/deploy/model.py
+++ b/sherlock/deploy/model.py
@@ -101,6 +101,8 @@ def fit(
self.model = model
+ _ = helpers._get_categorical_label_encodings(y_train, y_val, model_id)
+
def predict(self, X: pd.DataFrame, model_id: str = "sherlock") -> np.array:
"""Use sherlock model to generate predictions for X.
diff --git a/sherlock/features/preprocessing.py b/sherlock/features/preprocessing.py
index 874f87a..3ac3b5f 100644
--- a/sherlock/features/preprocessing.py
+++ b/sherlock/features/preprocessing.py
@@ -8,7 +8,6 @@
import pandas as pd
from functools import partial
-from google_drive_downloader import GoogleDriveDownloader as gd
from pyarrow.parquet import ParquetFile
from tqdm import tqdm
diff --git a/sherlock/helpers.py b/sherlock/helpers.py
index 7765a6e..13afc13 100644
--- a/sherlock/helpers.py
+++ b/sherlock/helpers.py
@@ -9,18 +9,17 @@ def download_data():
The data is downloaded from Google Drive and stored in the 'data/' directory.
"""
data_dir = "../data/data/"
- data_zip = "../data.zip"
+ zip_filepath = "../data/data.zip"
print(f"Downloading the raw data into {data_dir}.")
if not os.path.exists(data_dir):
print("Downloading data directory.")
- dir_name = data_zip
gdown.download(
url="https://drive.google.com/uc?id=1-g0zbKFAXz7zKZc0Dnh74uDBpZCv4YqU",
- output=dir_name,
+ output=zip_filepath,
)
- with zipfile.ZipFile(data_zip, "r") as zf:
+ with zipfile.ZipFile(zip_filepath, "r") as zf:
zf.extractall("../data/")
print("Data was downloaded.")
diff --git a/tests/test_helper.py b/tests/test_helper.py
index 65d6737..b0c0a74 100644
--- a/tests/test_helper.py
+++ b/tests/test_helper.py
@@ -1,4 +1,4 @@
-from sherlock.deploy.model_helpers import categorize_features
+from sherlock.deploy.helpers import categorize_features
from collections import OrderedDict
CATEGORY_FEATURE_KEYS: dict = categorize_features()