diff --git a/data/data b/data/data deleted file mode 120000 index 945c9b4..0000000 --- a/data/data +++ /dev/null @@ -1 +0,0 @@ -. \ No newline at end of file diff --git a/model_files/classes_sherlock.npy b/model_files/classes_sherlock.npy new file mode 100644 index 0000000..8483b8c Binary files /dev/null and b/model_files/classes_sherlock.npy differ diff --git a/model_files/sherlock_weights.h5 b/model_files/sherlock_weights.h5 index 2f9fba3..709d54f 100644 Binary files a/model_files/sherlock_weights.h5 and b/model_files/sherlock_weights.h5 differ diff --git a/notebooks/00-use-sherlock-out-of-the-box.ipynb b/notebooks/00-use-sherlock-out-of-the-box.ipynb index 178f6d7..5d137cc 100644 --- a/notebooks/00-use-sherlock-out-of-the-box.ipynb +++ b/notebooks/00-use-sherlock-out-of-the-box.ipynb @@ -8,7 +8,8 @@ "# Using Sherlock out-of-the-box\n", "This notebook shows how to predict a semantic type for a given table column.\n", "The steps are basically:\n", - "- Extract features from a column.\n", + "- Download files for word embedding and paragraph vector feature extraction (downloads only once) and initialize feature extraction models.\n", + "- Extract features from table columns.\n", "- Initialize Sherlock.\n", "- Make a prediction for the feature representation of the column." ] @@ -44,11 +45,14 @@ "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "UsageError: Environment does not have key: PYTHONHASHSEED\n" - ] + "data": { + "text/plain": [ + "'13'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -57,20 +61,10 @@ }, { "cell_type": "markdown", - "id": "2b3b7967", + "id": "f1101303", "metadata": {}, "source": [ - "## Extract features" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "164f74ff", - "metadata": {}, - "outputs": [], - "source": [ - "# helpers.download_data()" + "## Initialize feature extraction models" ] }, { @@ -93,9 +87,9 @@ " \n", "All files for extracting word and paragraph embeddings are present.\n", "Initialising word embeddings\n", - "Initialise Word Embeddings process took 0:00:05.607905 seconds.\n", - "Initialise Doc2Vec Model, 400 dim, process took 0:00:02.443327 seconds. (filename = ../sherlock/features/par_vec_trained_400.pkl)\n", - "Initialised NLTK, process took 0:00:00.181374 seconds.\n" + "Initialise Word Embeddings process took 0:00:05.513540 seconds.\n", + "Initialise Doc2Vec Model, 400 dim, process took 0:00:04.191875 seconds. (filename = ../sherlock/features/par_vec_trained_400.pkl)\n", + "Initialised NLTK, process took 0:00:00.209930 seconds.\n" ] }, { @@ -117,9 +111,17 @@ "initialise_nltk()" ] }, + { + "cell_type": "markdown", + "id": "2b3b7967", + "metadata": {}, + "source": [ + "## Extract features" + ] + }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 4, "id": "db04ccf9", "metadata": {}, "outputs": [], @@ -128,6 +130,7 @@ " [\n", " [\"Jane Smith\", \"Lute Ahorn\", \"Anna James\"],\n", " [\"Amsterdam\", \"Haarlem\", \"Zwolle\"],\n", + " [\"Chabot Street 19\", \"1200 fifth Avenue\", \"Binnenkant 22, 1011BH\"]\n", " ],\n", " name=\"values\"\n", ")" @@ -135,19 +138,20 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 5, "id": "4875f6c7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0 [Jane Smith, Lute Ahorn, Anna James]\n", - "1 [Amsterdam, Haarlem, Zwolle]\n", + "0 [Jane Smith, Lute Ahorn, Anna James]\n", + "1 [Amsterdam, Haarlem, Zwolle]\n", + "2 [Chabot Street 19, 1200 fifth Avenue, Binnenka...\n", "Name: values, dtype: object" ] }, - "execution_count": 36, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -158,7 +162,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 8, "id": "f7f2c846", "metadata": {}, "outputs": [ @@ -166,7 +170,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Extracting Features: 100%|██████████| 2/2 [00:00<00:00, 62.37it/s]\n" + "Extracting Features: 100%|██████████| 3/3 [00:00<00:00, 167.51it/s]" ] }, { @@ -175,6 +179,13 @@ "text": [ "Exporting 1588 column features\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] } ], "source": [ @@ -182,12 +193,12 @@ " \"../temporary.csv\",\n", " data\n", ")\n", - "feature_vector = pd.read_csv(\"../temporary.csv\", dtype=np.float32)" + "feature_vectors = pd.read_csv(\"../temporary.csv\", dtype=np.float32)" ] }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 9, "id": "0c42ce71", "metadata": {}, "outputs": [ @@ -241,7 +252,7 @@ " 0.0\n", " 0.0\n", " 0.0\n", - " 0.0\n", + " 0.000000\n", " 0.0\n", " 0.0\n", " 0.0\n", @@ -249,23 +260,23 @@ " -3.0\n", " 0.0\n", " ...\n", - " -0.115819\n", - " 0.023961\n", - " -0.130739\n", - " 0.006393\n", - " -0.135118\n", - " -0.071956\n", - " -0.051051\n", - " -0.068307\n", - " 0.087342\n", - " -0.145716\n", + " -0.116468\n", + " 0.023982\n", + " -0.130867\n", + " 0.006825\n", + " -0.135098\n", + " -0.070616\n", + " -0.052172\n", + " -0.067250\n", + " 0.086256\n", + " -0.144385\n", " \n", " \n", " 1\n", " 0.0\n", " 0.0\n", " 0.0\n", - " 0.0\n", + " 0.000000\n", " 0.0\n", " 0.0\n", " 0.0\n", @@ -273,63 +284,84 @@ " -3.0\n", " 0.0\n", " ...\n", - " -0.054351\n", - " 0.023650\n", - " -0.165681\n", - " -0.016137\n", - " -0.059402\n", - " 0.008454\n", - " -0.044624\n", - " 0.025160\n", - " 0.037831\n", - " -0.086235\n", + " -0.054949\n", + " 0.024502\n", + " -0.166001\n", + " -0.014375\n", + " -0.058199\n", + " 0.009978\n", + " -0.046423\n", + " 0.025163\n", + " 0.036946\n", + " -0.086611\n", + " \n", + " \n", + " 2\n", + " 1.0\n", + " 0.0\n", + " 1.0\n", + " 0.666667\n", + " 0.0\n", + " 2.0\n", + " 1.0\n", + " 3.0\n", + " -1.5\n", + " 0.0\n", + " ...\n", + " -0.022804\n", + " 0.001741\n", + " 0.047479\n", + " 0.118293\n", + " -0.093435\n", + " 0.036759\n", + " -0.004508\n", + " -0.087898\n", + " -0.117796\n", + " -0.191386\n", " \n", " \n", "\n", - "

2 rows × 1588 columns

\n", + "

3 rows × 1588 columns

\n", "" ], "text/plain": [ " n_[0]-agg-any n_[0]-agg-all n_[0]-agg-mean n_[0]-agg-var n_[0]-agg-min \\\n", - "0 0.0 0.0 0.0 0.0 0.0 \n", - "1 0.0 0.0 0.0 0.0 0.0 \n", + "0 0.0 0.0 0.0 0.000000 0.0 \n", + "1 0.0 0.0 0.0 0.000000 0.0 \n", + "2 1.0 0.0 1.0 0.666667 0.0 \n", "\n", " n_[0]-agg-max n_[0]-agg-median n_[0]-agg-sum n_[0]-agg-kurtosis \\\n", "0 0.0 0.0 0.0 -3.0 \n", "1 0.0 0.0 0.0 -3.0 \n", + "2 2.0 1.0 3.0 -1.5 \n", "\n", " n_[0]-agg-skewness ... par_vec_390 par_vec_391 par_vec_392 \\\n", - "0 0.0 ... -0.115819 0.023961 -0.130739 \n", - "1 0.0 ... -0.054351 0.023650 -0.165681 \n", + "0 0.0 ... -0.116468 0.023982 -0.130867 \n", + "1 0.0 ... -0.054949 0.024502 -0.166001 \n", + "2 0.0 ... -0.022804 0.001741 0.047479 \n", "\n", " par_vec_393 par_vec_394 par_vec_395 par_vec_396 par_vec_397 \\\n", - "0 0.006393 -0.135118 -0.071956 -0.051051 -0.068307 \n", - "1 -0.016137 -0.059402 0.008454 -0.044624 0.025160 \n", + "0 0.006825 -0.135098 -0.070616 -0.052172 -0.067250 \n", + "1 -0.014375 -0.058199 0.009978 -0.046423 0.025163 \n", + "2 0.118293 -0.093435 0.036759 -0.004508 -0.087898 \n", "\n", " par_vec_398 par_vec_399 \n", - "0 0.087342 -0.145716 \n", - "1 0.037831 -0.086235 \n", + "0 0.086256 -0.144385 \n", + "1 0.036946 -0.086611 \n", + "2 -0.117796 -0.191386 \n", "\n", - "[2 rows x 1588 columns]" + "[3 rows x 1588 columns]" ] }, - "execution_count": 38, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "feature_vector" + "feature_vectors" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "52047a6b", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -343,18 +375,18 @@ "id": "9027fa4a", "metadata": {}, "source": [ - "## Initialize Sherlock." + "## Initialize Sherlock" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 11, "id": "b9ec13ec", "metadata": {}, "outputs": [], "source": [ "model = SherlockModel();\n", - "model.initialize_model_from_json(with_weights=True);" + "model.initialize_model_from_json(with_weights=True, model_id=\"sherlock\");" ] }, { @@ -375,27 +407,27 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 12, "id": "fc079fa9", "metadata": {}, "outputs": [], "source": [ - "predicted_labels = model.predict(feature_vector, \"sherlock\")" + "predicted_labels = model.predict(feature_vectors, \"sherlock\")" ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 13, "id": "0feb9584", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array(['creator', 'city'], dtype=object)" + "array(['person', 'city', 'address'], dtype=object)" ] }, - "execution_count": 41, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } diff --git a/notebooks/01-data-preprocessing.ipynb b/notebooks/01-data-preprocessing.ipynb index 34be765..33ddbc2 100644 --- a/notebooks/01-data-preprocessing.ipynb +++ b/notebooks/01-data-preprocessing.ipynb @@ -1,5 +1,12 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Preprocess data and extract features." + ] + }, { "cell_type": "code", "execution_count": 1, @@ -12,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -21,34 +28,21 @@ "'13'" ] }, - "execution_count": 2, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# If you need fully deterministic results between runs, set the following environment value prior to launching jupyter.\n", + "# Instructions can be found in HOW-TO-ENVIRONMENT.md.\n", "# See comment in sherlock.features.paragraph_vectors.infer_paragraph_embeddings_features for more info.\n", "%env PYTHONHASHSEED" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Extract features, retrain Sherlock and generate predictions." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The script below first downloads the data (roughly 700K samples), then extract features from the raw data values.
" - ] - }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -74,14 +68,14 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-10 18:08:09.074960.\n" + "Started at 2022-02-21 12:55:47.936774.\n" ] } ], @@ -101,19 +95,22 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Downloading the raw and preprocessed data into ../data/data.zip.\n", + "Downloading the raw data into ../data/.\n", "Data was downloaded.\n", - "Preparing feature extraction by downloading 2 files:\n", + "Preparing feature extraction by downloading 4 files:\n", " \n", - " ../sherlock/features/glove.6B.50d.txt and \n", - " ../sherlock/features/par_vec_trained_400.pkl.docvecs.vectors_docs.npy.\n", + " ../sherlock/features/glove.6B.50d.txt, \n", + " ../sherlock/features/par_vec_trained_400.pkl.docvecs.vectors_docs.npy,\n", + " \n", + " ../sherlock/features/par_vec_trained_400.pkl.trainables.syn1neg.npy, and \n", + " ../sherlock/features/par_vec_trained_400.pkl.wv.vectors.npy.\n", " \n", "All files for extracting word and paragraph embeddings are present.\n" ] @@ -126,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -174,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -195,23 +192,26 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Preparing feature extraction by downloading 2 files:\n", + "Preparing feature extraction by downloading 4 files:\n", + " \n", + " ../sherlock/features/glove.6B.50d.txt, \n", + " ../sherlock/features/par_vec_trained_400.pkl.docvecs.vectors_docs.npy,\n", " \n", - " ../sherlock/features/glove.6B.50d.txt and \n", - " ../sherlock/features/par_vec_trained_400.pkl.docvecs.vectors_docs.npy.\n", + " ../sherlock/features/par_vec_trained_400.pkl.trainables.syn1neg.npy, and \n", + " ../sherlock/features/par_vec_trained_400.pkl.wv.vectors.npy.\n", " \n", "All files for extracting word and paragraph embeddings are present.\n", "Initialising word embeddings\n", - "Initialise Word Embeddings process took 0:00:05.470982 seconds.\n", - "Initialise Doc2Vec Model, 400 dim, process took 0:00:02.812897 seconds. (filename = ../sherlock/features/par_vec_trained_400.pkl)\n", - "Initialised NLTK, process took 0:00:00.243027 seconds.\n" + "Initialise Word Embeddings process took 0:00:05.196631 seconds.\n", + "Initialise Doc2Vec Model, 400 dim, process took 0:00:03.018509 seconds. (filename = ../sherlock/features/par_vec_trained_400.pkl)\n", + "Initialised NLTK, process took 0:00:00.184036 seconds.\n" ] }, { @@ -236,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -245,7 +245,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -270,16 +270,16 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Starting ../data/data/processed/test_20220210-180929.csv at 2022-02-10 18:09:45.839366. Rows=137353, using 8 CPU cores\n", + "Starting ../data/data/processed/test_20220221-125552.csv at 2022-02-21 12:56:03.739007. Rows=137353, using 8 CPU cores\n", "Exporting 1588 column features\n", - "Finished. Processed 137353 rows in 0:04:43.132315, key_count=8\n" + "Finished. Processed 137353 rows in 0:04:41.140014, key_count=8\n" ] } ], @@ -293,14 +293,14 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Finished at 2022-02-10 18:14:29.113958\n" + "Finished at 2022-02-21 13:00:45.006991\n" ] } ], @@ -317,16 +317,16 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Starting ../data/data/processed/train_20220210-180929.csv at 2022-02-10 18:14:32.033627. Rows=412059, using 8 CPU cores\n", + "Starting ../data/data/processed/train_20220221-125552.csv at 2022-02-21 13:00:46.766509. Rows=412059, using 8 CPU cores\n", "Exporting 1588 column features\n", - "Finished. Processed 412059 rows in 0:13:59.233387, key_count=8\n" + "Finished. Processed 412059 rows in 0:13:46.382904, key_count=8\n" ] } ], @@ -340,14 +340,14 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Finished at 2022-02-10 18:28:31.533353\n" + "Finished at 2022-02-21 13:14:33.436746\n" ] } ], @@ -364,16 +364,16 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Starting ../data/data/processed/validation_20220210-180929.csv at 2022-02-10 18:28:32.398357. Rows=137353, using 8 CPU cores\n", + "Starting ../data/data/processed/validation_20220221-125552.csv at 2022-02-21 13:14:34.282570. Rows=137353, using 8 CPU cores\n", "Exporting 1588 column features\n", - "Finished. Processed 137353 rows in 0:04:29.888694, key_count=8\n" + "Finished. Processed 137353 rows in 0:04:20.682436, key_count=8\n" ] } ], @@ -387,14 +387,14 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Finished at 2022-02-10 18:33:02.388902\n" + "Finished at 2022-02-21 13:18:55.069282\n" ] } ], @@ -411,14 +411,14 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Load Features (test) process took 0:00:31.909566 seconds.\n" + "Load Features (test) process took 0:00:30.504465 seconds.\n" ] } ], @@ -432,7 +432,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -644,7 +644,7 @@ "[5 rows x 1588 columns]" ] }, - "execution_count": 25, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -655,14 +655,14 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Load Features (train) process took 0:01:40.003595 seconds.\n" + "Load Features (train) process took 0:01:39.241069 seconds.\n" ] } ], @@ -676,7 +676,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -888,7 +888,7 @@ "[5 rows x 1588 columns]" ] }, - "execution_count": 27, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -899,14 +899,14 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Load Features (validation) process took 0:00:30.514030 seconds.\n" + "Load Features (validation) process took 0:00:30.066667 seconds.\n" ] } ], @@ -920,7 +920,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -1132,7 +1132,7 @@ "[5 rows x 1588 columns]" ] }, - "execution_count": 29, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -1150,14 +1150,14 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Transpose process took 0:00:11.975168 seconds.\n" + "Transpose process took 0:00:12.315641 seconds.\n" ] } ], @@ -1171,14 +1171,14 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "FillNA process took 0:00:04.992048 seconds.\n" + "FillNA process took 0:00:04.335100 seconds.\n" ] } ], @@ -1196,14 +1196,14 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Save parquet process took 0:01:03.475468 seconds.\n" + "Save parquet process took 0:00:58.970663 seconds.\n" ] } ], @@ -1219,14 +1219,14 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Completed at 2022-02-10 18:37:05.949284.\n" + "Completed at 2022-02-21 13:22:51.962526.\n" ] } ], diff --git a/notebooks/02-1-train-and-test-sherlock.ipynb b/notebooks/02-1-train-and-test-sherlock.ipynb index add40fd..f0001f0 100644 --- a/notebooks/02-1-train-and-test-sherlock.ipynb +++ b/notebooks/02-1-train-and-test-sherlock.ipynb @@ -29,12 +29,14 @@ "metadata": {}, "outputs": [], "source": [ + "# This will be the ID for the retrained model,\n", + "#further down predictions can also be made with the original model: \"sherlock\"\n", "model_id = 'retrained_sherlock'" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -66,8 +68,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-10 18:47:37.581529\n", - "Load data (train) process took 0:00:06.286789 seconds.\n" + "Started at 2022-02-21 13:49:29.901066\n", + "Load data (train) process took 0:00:07.129129 seconds.\n" ] } ], @@ -111,8 +113,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-10 18:47:44.189431\n", - "Load data (validation) process took 0:00:01.713314 seconds.\n" + "Started at 2022-02-21 13:49:37.357866\n", + "Load data (validation) process took 0:00:01.622250 seconds.\n" ] } ], @@ -137,8 +139,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-10 18:47:45.963358\n", - "Finished at 2022-02-10 18:47:48.373023, took 0:00:02.409678 seconds\n" + "Started at 2022-02-21 13:49:39.041588\n", + "Finished at 2022-02-21 13:49:41.373401, took 0:00:02.331826 seconds\n" ] } ], @@ -168,12 +170,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Load Sherlock with pretrained weights" + "### Option 1: load Sherlock with pretrained weights" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 14, "metadata": { "scrolled": true }, @@ -182,9 +184,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-10 18:55:58.217413\n", + "Started at 2022-02-21 14:05:25.056032\n", "Initialized model.\n", - "Finished at 2022-02-10 18:55:59.172945, took 0:00:00.955544 seconds\n" + "Finished at 2022-02-21 14:05:25.937041, took 0:00:00.881019 seconds\n" ] } ], @@ -192,8 +194,8 @@ "start = datetime.now()\n", "print(f'Started at {start}')\n", "\n", - "model = SherlockModel()\n", - "model.initialize_model_from_json(with_weights=True)\n", + "model = SherlockModel();\n", + "model.initialize_model_from_json(with_weights=True, model_id=\"sherlock\");\n", "\n", "print('Initialized model.')\n", "print(f'Finished at {datetime.now()}, took {datetime.now() - start} seconds')" @@ -203,12 +205,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Fit Sherlock from scratch (and save for later use)" + "### Option 2: fit Sherlock from scratch (and save for later use)" ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "model_id = \"retrained_sherlock\"" + ] + }, + { + "cell_type": "code", + "execution_count": 26, "metadata": { "scrolled": true }, @@ -217,94 +228,128 @@ "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-10 20:48:03.023090\n", + "Started at 2022-02-21 11:26:21.704064\n", "Train on 412059 samples, validate on 137353 samples\n", "Epoch 1/100\n", - "412059/412059 [==============================] - 73s 177us/sample - loss: 1.6144 - categorical_accuracy: 0.6971 - val_loss: 1.0390 - val_categorical_accuracy: 0.8300\n", + "412059/412059 [==============================] - 70s 170us/sample - loss: 1.6072 - categorical_accuracy: 0.6988 - val_loss: 1.0404 - val_categorical_accuracy: 0.8266\n", "Epoch 2/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.9628 - categorical_accuracy: 0.8359 - val_loss: 0.9472 - val_categorical_accuracy: 0.8491\n", + "412059/412059 [==============================] - 67s 163us/sample - loss: 0.9624 - categorical_accuracy: 0.8364 - val_loss: 0.9520 - val_categorical_accuracy: 0.8498\n", "Epoch 3/100\n", - "412059/412059 [==============================] - 70s 171us/sample - loss: 0.8533 - categorical_accuracy: 0.8586 - val_loss: 0.8953 - val_categorical_accuracy: 0.8580\n", + "412059/412059 [==============================] - 65s 159us/sample - loss: 0.8501 - categorical_accuracy: 0.8590 - val_loss: 0.8929 - val_categorical_accuracy: 0.8579\n", "Epoch 4/100\n", - "412059/412059 [==============================] - 68s 166us/sample - loss: 0.7860 - categorical_accuracy: 0.8715 - val_loss: 0.8617 - val_categorical_accuracy: 0.8637\n", + "412059/412059 [==============================] - 66s 159us/sample - loss: 0.7858 - categorical_accuracy: 0.8718 - val_loss: 0.8561 - val_categorical_accuracy: 0.8645\n", "Epoch 5/100\n", - "412059/412059 [==============================] - 69s 168us/sample - loss: 0.7388 - categorical_accuracy: 0.8789 - val_loss: 0.8268 - val_categorical_accuracy: 0.8682\n", + "412059/412059 [==============================] - 66s 159us/sample - loss: 0.7399 - categorical_accuracy: 0.8788 - val_loss: 0.8267 - val_categorical_accuracy: 0.8682\n", "Epoch 6/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.7010 - categorical_accuracy: 0.8852 - val_loss: 0.8113 - val_categorical_accuracy: 0.8699\n", + "412059/412059 [==============================] - 66s 160us/sample - loss: 0.6976 - categorical_accuracy: 0.8855 - val_loss: 0.8073 - val_categorical_accuracy: 0.8703\n", "Epoch 7/100\n", - "412059/412059 [==============================] - 71s 172us/sample - loss: 0.6674 - categorical_accuracy: 0.8896 - val_loss: 0.7839 - val_categorical_accuracy: 0.8720\n", + "412059/412059 [==============================] - 68s 165us/sample - loss: 0.6674 - categorical_accuracy: 0.8892 - val_loss: 0.7841 - val_categorical_accuracy: 0.8726\n", "Epoch 8/100\n", - "412059/412059 [==============================] - 70s 171us/sample - loss: 0.6406 - categorical_accuracy: 0.8939 - val_loss: 0.7652 - val_categorical_accuracy: 0.8754\n", + "412059/412059 [==============================] - 71s 173us/sample - loss: 0.6381 - categorical_accuracy: 0.8937 - val_loss: 0.7663 - val_categorical_accuracy: 0.8753\n", "Epoch 9/100\n", - "412059/412059 [==============================] - 70s 171us/sample - loss: 0.6147 - categorical_accuracy: 0.8970 - val_loss: 0.7455 - val_categorical_accuracy: 0.8763\n", + "412059/412059 [==============================] - 77s 187us/sample - loss: 0.6152 - categorical_accuracy: 0.8965 - val_loss: 0.7514 - val_categorical_accuracy: 0.8760\n", "Epoch 10/100\n", - "412059/412059 [==============================] - 69s 168us/sample - loss: 0.5918 - categorical_accuracy: 0.9005 - val_loss: 0.7366 - val_categorical_accuracy: 0.8781\n", + "412059/412059 [==============================] - 80s 195us/sample - loss: 0.5921 - categorical_accuracy: 0.9001 - val_loss: 0.7471 - val_categorical_accuracy: 0.8779\n", "Epoch 11/100\n", - "412059/412059 [==============================] - 71s 173us/sample - loss: 0.5734 - categorical_accuracy: 0.9022 - val_loss: 0.7283 - val_categorical_accuracy: 0.8807\n", + "412059/412059 [==============================] - 78s 190us/sample - loss: 0.5713 - categorical_accuracy: 0.9031 - val_loss: 0.7266 - val_categorical_accuracy: 0.8799\n", "Epoch 12/100\n", - "412059/412059 [==============================] - 70s 169us/sample - loss: 0.5543 - categorical_accuracy: 0.9051 - val_loss: 0.7156 - val_categorical_accuracy: 0.8799\n", + "412059/412059 [==============================] - 71s 171us/sample - loss: 0.5537 - categorical_accuracy: 0.9051 - val_loss: 0.7146 - val_categorical_accuracy: 0.8812\n", "Epoch 13/100\n", - "412059/412059 [==============================] - 70s 169us/sample - loss: 0.5373 - categorical_accuracy: 0.9069 - val_loss: 0.7066 - val_categorical_accuracy: 0.8812\n", + "412059/412059 [==============================] - 76s 184us/sample - loss: 0.5372 - categorical_accuracy: 0.9073 - val_loss: 0.7075 - val_categorical_accuracy: 0.8824\n", "Epoch 14/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.5219 - categorical_accuracy: 0.9093 - val_loss: 0.6935 - val_categorical_accuracy: 0.8824\n", + "412059/412059 [==============================] - 72s 175us/sample - loss: 0.5213 - categorical_accuracy: 0.9092 - val_loss: 0.7004 - val_categorical_accuracy: 0.8830\n", "Epoch 15/100\n", - "412059/412059 [==============================] - 71s 172us/sample - loss: 0.5074 - categorical_accuracy: 0.9109 - val_loss: 0.6861 - val_categorical_accuracy: 0.8837\n", + "412059/412059 [==============================] - 71s 173us/sample - loss: 0.5080 - categorical_accuracy: 0.9107 - val_loss: 0.6898 - val_categorical_accuracy: 0.8830\n", "Epoch 16/100\n", - "412059/412059 [==============================] - 71s 172us/sample - loss: 0.4940 - categorical_accuracy: 0.9128 - val_loss: 0.6828 - val_categorical_accuracy: 0.8859\n", + "412059/412059 [==============================] - 74s 180us/sample - loss: 0.4934 - categorical_accuracy: 0.9131 - val_loss: 0.6923 - val_categorical_accuracy: 0.8844\n", "Epoch 17/100\n", - "412059/412059 [==============================] - 70s 169us/sample - loss: 0.4819 - categorical_accuracy: 0.9147 - val_loss: 0.6756 - val_categorical_accuracy: 0.8857\n", + "412059/412059 [==============================] - 73s 177us/sample - loss: 0.4804 - categorical_accuracy: 0.9148 - val_loss: 0.6735 - val_categorical_accuracy: 0.8835\n", "Epoch 18/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.4706 - categorical_accuracy: 0.9159 - val_loss: 0.6664 - val_categorical_accuracy: 0.8861\n", + "412059/412059 [==============================] - 76s 183us/sample - loss: 0.4709 - categorical_accuracy: 0.9156 - val_loss: 0.6644 - val_categorical_accuracy: 0.8862\n", "Epoch 19/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.4593 - categorical_accuracy: 0.9172 - val_loss: 0.6627 - val_categorical_accuracy: 0.8868\n", + "412059/412059 [==============================] - 73s 177us/sample - loss: 0.4605 - categorical_accuracy: 0.9172 - val_loss: 0.6749 - val_categorical_accuracy: 0.8868\n", "Epoch 20/100\n", - "412059/412059 [==============================] - 71s 172us/sample - loss: 0.4491 - categorical_accuracy: 0.9186 - val_loss: 0.6623 - val_categorical_accuracy: 0.8867\n", + "412059/412059 [==============================] - 72s 175us/sample - loss: 0.4495 - categorical_accuracy: 0.9187 - val_loss: 0.6570 - val_categorical_accuracy: 0.8871\n", "Epoch 21/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.4399 - categorical_accuracy: 0.9198 - val_loss: 0.6573 - val_categorical_accuracy: 0.8883\n", + "412059/412059 [==============================] - 73s 176us/sample - loss: 0.4405 - categorical_accuracy: 0.9198 - val_loss: 0.6547 - val_categorical_accuracy: 0.8881\n", "Epoch 22/100\n", - "412059/412059 [==============================] - 70s 171us/sample - loss: 0.4331 - categorical_accuracy: 0.9207 - val_loss: 0.6486 - val_categorical_accuracy: 0.8884\n", + "412059/412059 [==============================] - 74s 181us/sample - loss: 0.4319 - categorical_accuracy: 0.9208 - val_loss: 0.6499 - val_categorical_accuracy: 0.8885\n", "Epoch 23/100\n", - "412059/412059 [==============================] - 70s 169us/sample - loss: 0.4241 - categorical_accuracy: 0.9221 - val_loss: 0.6499 - val_categorical_accuracy: 0.8889\n", + "412059/412059 [==============================] - 74s 180us/sample - loss: 0.4245 - categorical_accuracy: 0.9219 - val_loss: 0.6415 - val_categorical_accuracy: 0.8895\n", "Epoch 24/100\n", - "412059/412059 [==============================] - 76s 183us/sample - loss: 0.4168 - categorical_accuracy: 0.9230 - val_loss: 0.6432 - val_categorical_accuracy: 0.8906\n", + "412059/412059 [==============================] - 71s 172us/sample - loss: 0.4176 - categorical_accuracy: 0.9225 - val_loss: 0.6470 - val_categorical_accuracy: 0.8888\n", "Epoch 25/100\n", - "412059/412059 [==============================] - 71s 171us/sample - loss: 0.4098 - categorical_accuracy: 0.9241 - val_loss: 0.6417 - val_categorical_accuracy: 0.8896\n", + "412059/412059 [==============================] - 73s 177us/sample - loss: 0.4091 - categorical_accuracy: 0.9243 - val_loss: 0.6343 - val_categorical_accuracy: 0.8888\n", "Epoch 26/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.4029 - categorical_accuracy: 0.9246 - val_loss: 0.6344 - val_categorical_accuracy: 0.8901\n", + "412059/412059 [==============================] - 76s 184us/sample - loss: 0.4009 - categorical_accuracy: 0.9254 - val_loss: 0.6438 - val_categorical_accuracy: 0.8894\n", "Epoch 27/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.3965 - categorical_accuracy: 0.9261 - val_loss: 0.6409 - val_categorical_accuracy: 0.8911\n", + "412059/412059 [==============================] - 73s 177us/sample - loss: 0.3960 - categorical_accuracy: 0.9260 - val_loss: 0.6343 - val_categorical_accuracy: 0.8910\n", "Epoch 28/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.3909 - categorical_accuracy: 0.9266 - val_loss: 0.6341 - val_categorical_accuracy: 0.8901\n", + "412059/412059 [==============================] - 72s 174us/sample - loss: 0.3895 - categorical_accuracy: 0.9272 - val_loss: 0.6549 - val_categorical_accuracy: 0.8897\n", "Epoch 29/100\n", - "412059/412059 [==============================] - 71s 173us/sample - loss: 0.3844 - categorical_accuracy: 0.9275 - val_loss: 0.6272 - val_categorical_accuracy: 0.8906\n", + "412059/412059 [==============================] - 71s 173us/sample - loss: 0.3846 - categorical_accuracy: 0.9280 - val_loss: 0.6252 - val_categorical_accuracy: 0.8912\n", "Epoch 30/100\n", - "412059/412059 [==============================] - 75s 182us/sample - loss: 0.3798 - categorical_accuracy: 0.9283 - val_loss: 0.6249 - val_categorical_accuracy: 0.8915\n", + "412059/412059 [==============================] - 74s 180us/sample - loss: 0.3791 - categorical_accuracy: 0.9288 - val_loss: 0.6246 - val_categorical_accuracy: 0.8913\n", "Epoch 31/100\n", - "412059/412059 [==============================] - 72s 174us/sample - loss: 0.3744 - categorical_accuracy: 0.9297 - val_loss: 0.6229 - val_categorical_accuracy: 0.8920\n", + "412059/412059 [==============================] - 74s 180us/sample - loss: 0.3743 - categorical_accuracy: 0.9292 - val_loss: 0.6318 - val_categorical_accuracy: 0.8917\n", "Epoch 32/100\n", - "412059/412059 [==============================] - 70s 170us/sample - loss: 0.3700 - categorical_accuracy: 0.9302 - val_loss: 0.6198 - val_categorical_accuracy: 0.8920\n", + "412059/412059 [==============================] - 77s 187us/sample - loss: 0.3697 - categorical_accuracy: 0.9299 - val_loss: 0.6297 - val_categorical_accuracy: 0.8909\n", "Epoch 33/100\n", - "412059/412059 [==============================] - 71s 172us/sample - loss: 0.3661 - categorical_accuracy: 0.9307 - val_loss: 0.6238 - val_categorical_accuracy: 0.8919\n", + "412059/412059 [==============================] - 79s 192us/sample - loss: 0.3649 - categorical_accuracy: 0.9310 - val_loss: 0.6250 - val_categorical_accuracy: 0.8920\n", "Epoch 34/100\n", - "412059/412059 [==============================] - 70s 171us/sample - loss: 0.3614 - categorical_accuracy: 0.9311 - val_loss: 0.6186 - val_categorical_accuracy: 0.8926\n", + "412059/412059 [==============================] - 77s 186us/sample - loss: 0.3609 - categorical_accuracy: 0.9312 - val_loss: 0.6217 - val_categorical_accuracy: 0.8921\n", "Epoch 35/100\n", - "412059/412059 [==============================] - 68s 166us/sample - loss: 0.3579 - categorical_accuracy: 0.9320 - val_loss: 0.6219 - val_categorical_accuracy: 0.8924\n", + "412059/412059 [==============================] - 73s 176us/sample - loss: 0.3570 - categorical_accuracy: 0.9318 - val_loss: 0.6203 - val_categorical_accuracy: 0.8923\n", "Epoch 36/100\n", - "412059/412059 [==============================] - 71s 171us/sample - loss: 0.3532 - categorical_accuracy: 0.9326 - val_loss: 0.6191 - val_categorical_accuracy: 0.8922\n", + "412059/412059 [==============================] - 77s 187us/sample - loss: 0.3525 - categorical_accuracy: 0.9328 - val_loss: 0.6304 - val_categorical_accuracy: 0.8923\n", "Epoch 37/100\n", - "412059/412059 [==============================] - 71s 172us/sample - loss: 0.3517 - categorical_accuracy: 0.9328 - val_loss: 0.6137 - val_categorical_accuracy: 0.8932\n", + "412059/412059 [==============================] - 74s 178us/sample - loss: 0.3496 - categorical_accuracy: 0.9332 - val_loss: 0.6173 - val_categorical_accuracy: 0.8923\n", "Epoch 38/100\n", - "412059/412059 [==============================] - 69s 168us/sample - loss: 0.3464 - categorical_accuracy: 0.9336 - val_loss: 0.6201 - val_categorical_accuracy: 0.8937\n", + "412059/412059 [==============================] - 73s 177us/sample - loss: 0.3464 - categorical_accuracy: 0.9338 - val_loss: 0.6186 - val_categorical_accuracy: 0.8924\n", "Epoch 39/100\n", - "412059/412059 [==============================] - 68s 165us/sample - loss: 0.3445 - categorical_accuracy: 0.9340 - val_loss: 0.6211 - val_categorical_accuracy: 0.8929\n", + "412059/412059 [==============================] - 73s 176us/sample - loss: 0.3435 - categorical_accuracy: 0.9342 - val_loss: 0.6116 - val_categorical_accuracy: 0.8938\n", "Epoch 40/100\n", - "412059/412059 [==============================] - 67s 162us/sample - loss: 0.3394 - categorical_accuracy: 0.9350 - val_loss: 0.6199 - val_categorical_accuracy: 0.8934\n", + "412059/412059 [==============================] - 73s 177us/sample - loss: 0.3414 - categorical_accuracy: 0.9345 - val_loss: 0.6105 - val_categorical_accuracy: 0.8930\n", "Epoch 41/100\n", - "412059/412059 [==============================] - 68s 165us/sample - loss: 0.3377 - categorical_accuracy: 0.9351 - val_loss: 0.6216 - val_categorical_accuracy: 0.8934\n", + "412059/412059 [==============================] - 82s 198us/sample - loss: 0.3386 - categorical_accuracy: 0.9351 - val_loss: 0.6201 - val_categorical_accuracy: 0.8936\n", "Epoch 42/100\n", - "412059/412059 [==============================] - 70s 169us/sample - loss: 0.3356 - categorical_accuracy: 0.9352 - val_loss: 0.6155 - val_categorical_accuracy: 0.8939\n", + "412059/412059 [==============================] - 77s 187us/sample - loss: 0.3342 - categorical_accuracy: 0.9356 - val_loss: 0.6132 - val_categorical_accuracy: 0.8913\n", + "Epoch 43/100\n", + "412059/412059 [==============================] - 80s 194us/sample - loss: 0.3318 - categorical_accuracy: 0.9362 - val_loss: 0.6141 - val_categorical_accuracy: 0.8932\n", + "Epoch 44/100\n", + "412059/412059 [==============================] - 78s 190us/sample - loss: 0.3295 - categorical_accuracy: 0.9368 - val_loss: 0.6101 - val_categorical_accuracy: 0.8940\n", + "Epoch 45/100\n", + "412059/412059 [==============================] - 78s 188us/sample - loss: 0.3270 - categorical_accuracy: 0.9370 - val_loss: 0.6088 - val_categorical_accuracy: 0.8939\n", + "Epoch 46/100\n", + "412059/412059 [==============================] - 76s 185us/sample - loss: 0.3252 - categorical_accuracy: 0.9373 - val_loss: 0.6152 - val_categorical_accuracy: 0.8939\n", + "Epoch 47/100\n", + "412059/412059 [==============================] - 82s 200us/sample - loss: 0.3237 - categorical_accuracy: 0.9377 - val_loss: 0.6105 - val_categorical_accuracy: 0.8939\n", + "Epoch 48/100\n", + "412059/412059 [==============================] - 78s 188us/sample - loss: 0.3217 - categorical_accuracy: 0.9377 - val_loss: 0.6070 - val_categorical_accuracy: 0.8945\n", + "Epoch 49/100\n", + "412059/412059 [==============================] - 75s 181us/sample - loss: 0.3187 - categorical_accuracy: 0.9386 - val_loss: 0.6067 - val_categorical_accuracy: 0.8943\n", + "Epoch 50/100\n", + "412059/412059 [==============================] - 78s 190us/sample - loss: 0.3184 - categorical_accuracy: 0.9382 - val_loss: 0.6094 - val_categorical_accuracy: 0.8947\n", + "Epoch 51/100\n", + "412059/412059 [==============================] - 76s 185us/sample - loss: 0.3155 - categorical_accuracy: 0.9391 - val_loss: 0.6054 - val_categorical_accuracy: 0.8948\n", + "Epoch 52/100\n", + "412059/412059 [==============================] - 78s 190us/sample - loss: 0.3143 - categorical_accuracy: 0.9392 - val_loss: 0.6060 - val_categorical_accuracy: 0.8943\n", + "Epoch 53/100\n", + "412059/412059 [==============================] - 75s 181us/sample - loss: 0.3105 - categorical_accuracy: 0.9400 - val_loss: 0.6130 - val_categorical_accuracy: 0.8954\n", + "Epoch 54/100\n", + "412059/412059 [==============================] - 78s 188us/sample - loss: 0.3109 - categorical_accuracy: 0.9397 - val_loss: 0.6030 - val_categorical_accuracy: 0.8952\n", + "Epoch 55/100\n", + "412059/412059 [==============================] - 78s 189us/sample - loss: 0.3091 - categorical_accuracy: 0.9406 - val_loss: 0.6148 - val_categorical_accuracy: 0.8949\n", + "Epoch 56/100\n", + "412059/412059 [==============================] - 78s 189us/sample - loss: 0.3076 - categorical_accuracy: 0.9404 - val_loss: 0.6075 - val_categorical_accuracy: 0.8947\n", + "Epoch 57/100\n", + "412059/412059 [==============================] - 76s 185us/sample - loss: 0.3056 - categorical_accuracy: 0.9411 - val_loss: 0.6167 - val_categorical_accuracy: 0.8954\n", + "Epoch 58/100\n", + "412059/412059 [==============================] - 78s 190us/sample - loss: 0.3033 - categorical_accuracy: 0.9411 - val_loss: 0.6130 - val_categorical_accuracy: 0.8955\n", + "Epoch 59/100\n", + "412059/412059 [==============================] - 68s 166us/sample - loss: 0.3041 - categorical_accuracy: 0.9408 - val_loss: 0.6100 - val_categorical_accuracy: 0.8950\n", "Trained and saved new model.\n", - "Finished at 2022-02-10 21:37:31.052443, took 0:49:28.030848 seconds\n" + "Finished at 2022-02-21 12:39:29.005837, took 1:13:07.304278 seconds\n" ] } ], @@ -312,7 +357,8 @@ "start = datetime.now()\n", "print(f'Started at {start}')\n", "\n", - "sherlock_model = SherlockModel()\n", + "model = SherlockModel()\n", + "# Model will be stored with ID `model_id`\n", "model.fit(X_train, y_train, X_validation, y_validation, model_id=model_id)\n", "\n", "print('Trained and saved new model.')\n", @@ -321,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -344,7 +390,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -354,7 +400,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -367,10 +413,10 @@ { "data": { "text/plain": [ - "0.8937685721454983" + "0.8951410029373902" ] }, - "execution_count": 51, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -386,11 +432,12 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# If using the original model, model_id should be replaced with \"sherlock\"\n", + "#model_id = \"sherlock\"\n", "classes = np.load(f\"../model_files/classes_{model_id}.npy\", allow_pickle=True)\n", "\n", "report = classification_report(y_test, predicted_labels, output_dict=True)\n", @@ -409,7 +456,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -417,11 +464,11 @@ "output_type": "stream", "text": [ "\t\tf1-score\tprecision\trecall\t\tsupport\n", - "grades\t\t0.993\t\t0.994\t\t0.993\t\t1765\n", - "isbn\t\t0.990\t\t0.992\t\t0.987\t\t1430\n", - "jockey\t\t0.986\t\t0.980\t\t0.992\t\t2819\n", - "currency\t0.980\t\t0.987\t\t0.973\t\t405\n", - "industry\t0.980\t\t0.975\t\t0.985\t\t2958\n" + "grades\t\t0.993\t\t0.993\t\t0.993\t\t1765\n", + "isbn\t\t0.991\t\t0.993\t\t0.988\t\t1430\n", + "jockey\t\t0.985\t\t0.982\t\t0.988\t\t2819\n", + "industry\t0.984\t\t0.983\t\t0.985\t\t2958\n", + "currency\t0.975\t\t0.982\t\t0.968\t\t405\n" ] } ], @@ -446,7 +493,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 17, "metadata": { "scrolled": true }, @@ -456,11 +503,11 @@ "output_type": "stream", "text": [ "\t\tf1-score\tprecision\trecall\t\tsupport\n", - "rank\t\t0.706\t\t0.631\t\t0.802\t\t2983\n", - "person\t\t0.665\t\t0.721\t\t0.617\t\t579\n", - "director\t0.581\t\t0.673\t\t0.511\t\t225\n", - "sales\t\t0.555\t\t0.554\t\t0.556\t\t322\n", - "ranking\t\t0.486\t\t0.722\t\t0.367\t\t439\n" + "rank\t\t0.693\t\t0.625\t\t0.778\t\t2983\n", + "person\t\t0.664\t\t0.717\t\t0.618\t\t579\n", + "director\t0.568\t\t0.591\t\t0.547\t\t225\n", + "sales\t\t0.556\t\t0.586\t\t0.528\t\t322\n", + "ranking\t\t0.441\t\t0.753\t\t0.312\t\t439\n" ] } ], @@ -485,7 +532,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -494,88 +541,88 @@ "text": [ " precision recall f1-score support\n", "\n", - " address 0.944 0.936 0.940 3003\n", - " affiliate 0.922 0.814 0.865 204\n", - " affiliation 0.978 0.951 0.964 1768\n", - " age 0.847 0.960 0.900 3033\n", - " album 0.863 0.904 0.883 3035\n", - " area 0.911 0.802 0.853 1987\n", - " artist 0.810 0.870 0.839 3043\n", - " birth date 0.979 0.977 0.978 479\n", - " birth place 0.970 0.916 0.942 418\n", - " brand 0.779 0.707 0.742 574\n", - " capacity 0.773 0.715 0.743 362\n", - " category 0.915 0.886 0.900 3087\n", - " city 0.856 0.889 0.872 2966\n", - " class 0.920 0.911 0.915 2971\n", - "classification 0.924 0.848 0.885 587\n", - " club 0.962 0.962 0.962 2977\n", - " code 0.931 0.902 0.916 2956\n", - " collection 0.982 0.933 0.957 476\n", - " command 0.929 0.919 0.924 1045\n", - " company 0.895 0.890 0.892 3041\n", - " component 0.875 0.878 0.877 1226\n", - " continent 0.905 0.877 0.890 227\n", - " country 0.894 0.942 0.917 3038\n", - " county 0.937 0.960 0.948 2959\n", - " creator 0.829 0.827 0.828 347\n", - " credit 0.860 0.840 0.849 941\n", - " currency 0.987 0.973 0.980 405\n", - " day 0.938 0.893 0.915 3038\n", - " depth 0.922 0.930 0.926 947\n", - " description 0.804 0.861 0.832 3042\n", - " director 0.673 0.511 0.581 225\n", - " duration 0.945 0.936 0.941 3000\n", - " education 0.805 0.789 0.797 313\n", - " elevation 0.961 0.944 0.952 1299\n", - " family 0.961 0.901 0.930 746\n", - " file size 0.913 0.867 0.889 361\n", - " format 0.973 0.955 0.964 2956\n", - " gender 0.806 0.787 0.797 1030\n", - " genre 0.951 0.954 0.953 1163\n", - " grades 0.994 0.993 0.993 1765\n", - " industry 0.975 0.985 0.980 2958\n", - " isbn 0.992 0.987 0.990 1430\n", - " jockey 0.980 0.992 0.986 2819\n", - " language 0.837 0.963 0.896 1474\n", - " location 0.865 0.834 0.849 2949\n", - " manufacturer 0.848 0.819 0.833 945\n", - " name 0.732 0.748 0.740 3017\n", - " nationality 0.823 0.748 0.784 424\n", - " notes 0.769 0.842 0.804 2303\n", - " operator 0.808 0.832 0.820 404\n", - " order 0.949 0.826 0.883 1462\n", - " organisation 0.844 0.828 0.836 262\n", - " origin 0.957 0.887 0.921 1439\n", - " owner 0.939 0.874 0.905 1673\n", - " person 0.721 0.617 0.665 579\n", - " plays 0.842 0.900 0.870 1513\n", - " position 0.831 0.820 0.826 3057\n", - " product 0.890 0.864 0.877 2647\n", - " publisher 0.943 0.881 0.911 880\n", - " range 0.839 0.766 0.801 577\n", - " rank 0.631 0.802 0.706 2983\n", - " ranking 0.722 0.367 0.486 439\n", - " region 0.831 0.827 0.829 2740\n", - " religion 0.963 0.909 0.935 340\n", - " requirement 0.949 0.800 0.868 300\n", - " result 0.969 0.934 0.951 2920\n", - " sales 0.554 0.556 0.555 322\n", - " service 0.972 0.921 0.946 2222\n", - " sex 0.921 0.915 0.918 2997\n", - " species 0.921 0.951 0.936 819\n", - " state 0.922 0.961 0.941 3030\n", - " status 0.951 0.930 0.940 3100\n", - " symbol 0.952 0.952 0.952 1752\n", - " team 0.867 0.858 0.863 3011\n", - " team name 0.899 0.828 0.862 1639\n", - " type 0.871 0.882 0.877 2909\n", - " weight 0.954 0.930 0.942 2963\n", - " year 0.964 0.936 0.950 3015\n", + " address 0.931 0.943 0.937 3003\n", + " affiliate 0.943 0.809 0.871 204\n", + " affiliation 0.973 0.957 0.965 1768\n", + " age 0.866 0.950 0.906 3033\n", + " album 0.892 0.889 0.890 3035\n", + " area 0.870 0.820 0.844 1987\n", + " artist 0.816 0.873 0.844 3043\n", + " birth date 0.985 0.969 0.977 479\n", + " birth place 0.934 0.921 0.928 418\n", + " brand 0.830 0.671 0.742 574\n", + " capacity 0.793 0.721 0.755 362\n", + " category 0.924 0.890 0.906 3087\n", + " city 0.864 0.904 0.883 2966\n", + " class 0.901 0.915 0.908 2971\n", + "classification 0.927 0.862 0.893 587\n", + " club 0.974 0.955 0.964 2977\n", + " code 0.916 0.907 0.912 2956\n", + " collection 0.984 0.931 0.957 476\n", + " command 0.938 0.904 0.921 1045\n", + " company 0.912 0.888 0.900 3041\n", + " component 0.888 0.880 0.884 1226\n", + " continent 0.875 0.894 0.885 227\n", + " country 0.892 0.950 0.920 3038\n", + " county 0.943 0.957 0.950 2959\n", + " creator 0.770 0.841 0.804 347\n", + " credit 0.868 0.813 0.840 941\n", + " currency 0.982 0.968 0.975 405\n", + " day 0.953 0.892 0.921 3038\n", + " depth 0.931 0.916 0.923 947\n", + " description 0.804 0.869 0.835 3042\n", + " director 0.591 0.547 0.568 225\n", + " duration 0.924 0.948 0.936 3000\n", + " education 0.856 0.818 0.837 313\n", + " elevation 0.956 0.946 0.951 1299\n", + " family 0.964 0.895 0.928 746\n", + " file size 0.941 0.845 0.891 361\n", + " format 0.966 0.959 0.963 2956\n", + " gender 0.868 0.721 0.788 1030\n", + " genre 0.965 0.952 0.958 1163\n", + " grades 0.993 0.993 0.993 1765\n", + " industry 0.983 0.985 0.984 2958\n", + " isbn 0.993 0.988 0.991 1430\n", + " jockey 0.982 0.988 0.985 2819\n", + " language 0.939 0.953 0.946 1474\n", + " location 0.896 0.827 0.861 2949\n", + " manufacturer 0.865 0.819 0.841 945\n", + " name 0.724 0.759 0.741 3017\n", + " nationality 0.907 0.691 0.784 424\n", + " notes 0.724 0.842 0.779 2303\n", + " operator 0.794 0.847 0.819 404\n", + " order 0.869 0.887 0.878 1462\n", + " organisation 0.832 0.832 0.832 262\n", + " origin 0.947 0.900 0.923 1439\n", + " owner 0.931 0.869 0.899 1673\n", + " person 0.717 0.618 0.664 579\n", + " plays 0.846 0.903 0.873 1513\n", + " position 0.806 0.839 0.822 3057\n", + " product 0.868 0.878 0.873 2647\n", + " publisher 0.888 0.899 0.893 880\n", + " range 0.855 0.759 0.804 577\n", + " rank 0.625 0.778 0.693 2983\n", + " ranking 0.753 0.312 0.441 439\n", + " region 0.882 0.810 0.845 2740\n", + " religion 0.972 0.921 0.946 340\n", + " requirement 0.927 0.807 0.863 300\n", + " result 0.962 0.940 0.951 2920\n", + " sales 0.586 0.528 0.556 322\n", + " service 0.964 0.925 0.944 2222\n", + " sex 0.903 0.945 0.924 2997\n", + " species 0.921 0.950 0.935 819\n", + " state 0.939 0.957 0.948 3030\n", + " status 0.943 0.936 0.940 3100\n", + " symbol 0.958 0.946 0.952 1752\n", + " team 0.850 0.870 0.860 3011\n", + " team name 0.893 0.827 0.859 1639\n", + " type 0.916 0.875 0.895 2909\n", + " weight 0.958 0.931 0.944 2963\n", + " year 0.965 0.937 0.951 3015\n", "\n", - " accuracy 0.894 137353\n", - " macro avg 0.887 0.866 0.875 137353\n", - " weighted avg 0.896 0.894 0.894 137353\n", + " accuracy 0.895 137353\n", + " macro avg 0.890 0.866 0.876 137353\n", + " weighted avg 0.898 0.895 0.895 137353\n", "\n" ] } @@ -593,7 +640,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -605,193 +652,171 @@ "[2420] expected \"address\" but predicted \"club\"\n", "[2616] expected \"address\" but predicted \"city\"\n", "[3398] expected \"address\" but predicted \"city\"\n", - "[4380] expected \"address\" but predicted \"county\"\n", - "[4422] expected \"address\" but predicted \"city\"\n", "[5112] expected \"address\" but predicted \"location\"\n", "[5546] expected \"address\" but predicted \"name\"\n", - "[5647] expected \"address\" but predicted \"team name\"\n", - "[7119] expected \"address\" but predicted \"day\"\n", + "[6526] expected \"address\" but predicted \"location\"\n", "[8797] expected \"address\" but predicted \"location\"\n", "[9354] expected \"address\" but predicted \"location\"\n", "[9574] expected \"address\" but predicted \"location\"\n", "[9806] expected \"address\" but predicted \"city\"\n", "[10035] expected \"address\" but predicted \"creator\"\n", - "[10067] expected \"address\" but predicted \"education\"\n", - "[11055] expected \"address\" but predicted \"city\"\n", - "[11902] expected \"address\" but predicted \"location\"\n", + "[10067] expected \"address\" but predicted \"order\"\n", + "[10665] expected \"address\" but predicted \"area\"\n", + "[11055] expected \"address\" but predicted \"county\"\n", + "[11902] expected \"address\" but predicted \"jockey\"\n", + "[11993] expected \"address\" but predicted \"location\"\n", "[12072] expected \"address\" but predicted \"artist\"\n", - "[12639] expected \"address\" but predicted \"location\"\n", + "[14200] expected \"address\" but predicted \"description\"\n", "[14677] expected \"address\" but predicted \"location\"\n", "[15232] expected \"address\" but predicted \"city\"\n", "[15461] expected \"address\" but predicted \"location\"\n", "[15496] expected \"address\" but predicted \"city\"\n", - "[16212] expected \"address\" but predicted \"location\"\n", + "[15987] expected \"address\" but predicted \"artist\"\n", "[19953] expected \"address\" but predicted \"county\"\n", - "[20829] expected \"address\" but predicted \"language\"\n", - "[21408] expected \"address\" but predicted \"description\"\n", - "[22148] expected \"address\" but predicted \"location\"\n", - "[22355] expected \"address\" but predicted \"location\"\n", + "[20425] expected \"address\" but predicted \"location\"\n", + "[20829] expected \"address\" but predicted \"notes\"\n", + "[21408] expected \"address\" but predicted \"position\"\n", + "[21666] expected \"address\" but predicted \"area\"\n", + "[22148] expected \"address\" but predicted \"product\"\n", "[23915] expected \"address\" but predicted \"location\"\n", "[24636] expected \"address\" but predicted \"team name\"\n", - "[24803] expected \"address\" but predicted \"day\"\n", + "[24803] expected \"address\" but predicted \"position\"\n", "[26171] expected \"address\" but predicted \"name\"\n", "[26184] expected \"address\" but predicted \"name\"\n", - "[26210] expected \"address\" but predicted \"state\"\n", - "[26296] expected \"address\" but predicted \"file size\"\n", + "[26210] expected \"address\" but predicted \"name\"\n", + "[26393] expected \"address\" but predicted \"position\"\n", "[26559] expected \"address\" but predicted \"location\"\n", "[26872] expected \"address\" but predicted \"command\"\n", - "[30403] expected \"address\" but predicted \"language\"\n", + "[30403] expected \"address\" but predicted \"notes\"\n", "[31391] expected \"address\" but predicted \"location\"\n", "[31515] expected \"address\" but predicted \"location\"\n", - "[31830] expected \"address\" but predicted \"location\"\n", - "[31836] expected \"address\" but predicted \"language\"\n", - "[31976] expected \"address\" but predicted \"location\"\n", - "[32429] expected \"address\" but predicted \"artist\"\n", + "[31830] expected \"address\" but predicted \"result\"\n", + "[31836] expected \"address\" but predicted \"notes\"\n", "[32551] expected \"address\" but predicted \"city\"\n", - "[32634] expected \"address\" but predicted \"location\"\n", - "[32762] expected \"address\" but predicted \"name\"\n", - "[33207] expected \"address\" but predicted \"region\"\n", - "[33975] expected \"address\" but predicted \"location\"\n", + "[32762] expected \"address\" but predicted \"artist\"\n", + "[33207] expected \"address\" but predicted \"order\"\n", "[34547] expected \"address\" but predicted \"location\"\n", - "[34849] expected \"address\" but predicted \"location\"\n", + "[34711] expected \"address\" but predicted \"location\"\n", "[35467] expected \"address\" but predicted \"location\"\n", - "[35756] expected \"address\" but predicted \"city\"\n", - "[35907] expected \"address\" but predicted \"county\"\n", - "[35938] expected \"address\" but predicted \"language\"\n", - "[36033] expected \"address\" but predicted \"name\"\n", + "[35938] expected \"address\" but predicted \"notes\"\n", "[37084] expected \"address\" but predicted \"person\"\n", - "[37318] expected \"address\" but predicted \"language\"\n", - "[37536] expected \"address\" but predicted \"language\"\n", - "[38118] expected \"address\" but predicted \"location\"\n", + "[37318] expected \"address\" but predicted \"notes\"\n", + "[37536] expected \"address\" but predicted \"notes\"\n", "[40184] expected \"address\" but predicted \"notes\"\n", "[40457] expected \"address\" but predicted \"country\"\n", - "[41540] expected \"address\" but predicted \"location\"\n", - "[42439] expected \"address\" but predicted \"description\"\n", + "[42439] expected \"address\" but predicted \"area\"\n", + "[44530] expected \"address\" but predicted \"artist\"\n", "[44906] expected \"address\" but predicted \"location\"\n", "[44918] expected \"address\" but predicted \"area\"\n", - "[46430] expected \"address\" but predicted \"location\"\n", + "[44932] expected \"address\" but predicted \"name\"\n", + "[46430] expected \"address\" but predicted \"birth place\"\n", "[46463] expected \"address\" but predicted \"notes\"\n", - "[47140] expected \"address\" but predicted \"team name\"\n", - "[47249] expected \"address\" but predicted \"language\"\n", - "[47497] expected \"address\" but predicted \"location\"\n", - "[47636] expected \"address\" but predicted \"location\"\n", + "[47140] expected \"address\" but predicted \"age\"\n", + "[47249] expected \"address\" but predicted \"notes\"\n", "[47713] expected \"address\" but predicted \"city\"\n", - "[47810] expected \"address\" but predicted \"category\"\n", + "[47810] expected \"address\" but predicted \"description\"\n", "[48016] expected \"address\" but predicted \"location\"\n", - "[48289] expected \"address\" but predicted \"location\"\n", "[48631] expected \"address\" but predicted \"location\"\n", - "[50317] expected \"address\" but predicted \"name\"\n", - "[50329] expected \"address\" but predicted \"location\"\n", - "[50783] expected \"address\" but predicted \"city\"\n", + "[50329] expected \"address\" but predicted \"area\"\n", "[51643] expected \"address\" but predicted \"city\"\n", - "[51897] expected \"address\" but predicted \"location\"\n", - "[52120] expected \"address\" but predicted \"city\"\n", + "[51887] expected \"address\" but predicted \"location\"\n", + "[53960] expected \"address\" but predicted \"service\"\n", "[54248] expected \"address\" but predicted \"area\"\n", - "[55535] expected \"address\" but predicted \"language\"\n", - "[56030] expected \"address\" but predicted \"description\"\n", + "[56030] expected \"address\" but predicted \"code\"\n", + "[56085] expected \"address\" but predicted \"name\"\n", + "[56176] expected \"address\" but predicted \"product\"\n", "[57015] expected \"address\" but predicted \"location\"\n", - "[57851] expected \"address\" but predicted \"location\"\n", "[58524] expected \"address\" but predicted \"location\"\n", + "[58958] expected \"address\" but predicted \"team name\"\n", "[59506] expected \"address\" but predicted \"name\"\n", - "[59632] expected \"address\" but predicted \"location\"\n", - "[59734] expected \"address\" but predicted \"location\"\n", - "[60105] expected \"address\" but predicted \"location\"\n", - "[60225] expected \"address\" but predicted \"language\"\n", - "[60441] expected \"address\" but predicted \"location\"\n", - "[60565] expected \"address\" but predicted \"region\"\n", + "[59734] expected \"address\" but predicted \"language\"\n", + "[60105] expected \"address\" but predicted \"area\"\n", + "[60225] expected \"address\" but predicted \"notes\"\n", + "[60441] expected \"address\" but predicted \"collection\"\n", + "[60565] expected \"address\" but predicted \"ranking\"\n", "[61783] expected \"address\" but predicted \"location\"\n", - "[61894] expected \"address\" but predicted \"language\"\n", - "[61975] expected \"address\" but predicted \"capacity\"\n", - "[64427] expected \"address\" but predicted \"region\"\n", + "[61894] expected \"address\" but predicted \"notes\"\n", + "[66355] expected \"address\" but predicted \"artist\"\n", "[67142] expected \"address\" but predicted \"code\"\n", - "[68698] expected \"address\" but predicted \"area\"\n", "[69794] expected \"address\" but predicted \"city\"\n", "[70071] expected \"address\" but predicted \"location\"\n", "[71228] expected \"address\" but predicted \"location\"\n", - "[71784] expected \"address\" but predicted \"symbol\"\n", + "[71784] expected \"address\" but predicted \"location\"\n", + "[72001] expected \"address\" but predicted \"product\"\n", "[72226] expected \"address\" but predicted \"notes\"\n", - "[73451] expected \"address\" but predicted \"name\"\n", + "[73360] expected \"address\" but predicted \"species\"\n", "[73573] expected \"address\" but predicted \"name\"\n", - "[74228] expected \"address\" but predicted \"elevation\"\n", - "[75861] expected \"address\" but predicted \"location\"\n", + "[74228] expected \"address\" but predicted \"location\"\n", "[75893] expected \"address\" but predicted \"notes\"\n", "[76551] expected \"address\" but predicted \"language\"\n", - "[77502] expected \"address\" but predicted \"location\"\n", - "[77796] expected \"address\" but predicted \"city\"\n", - "[78437] expected \"address\" but predicted \"location\"\n", - "[80712] expected \"address\" but predicted \"name\"\n", - "[81288] expected \"address\" but predicted \"city\"\n", - "[82082] expected \"address\" but predicted \"artist\"\n", + "[77796] expected \"address\" but predicted \"location\"\n", + "[80669] expected \"address\" but predicted \"location\"\n", + "[80712] expected \"address\" but predicted \"category\"\n", + "[82082] expected \"address\" but predicted \"location\"\n", "[82779] expected \"address\" but predicted \"range\"\n", - "[83204] expected \"address\" but predicted \"name\"\n", - "[84979] expected \"address\" but predicted \"country\"\n", + "[83478] expected \"address\" but predicted \"city\"\n", + "[84979] expected \"address\" but predicted \"weight\"\n", "[85206] expected \"address\" but predicted \"location\"\n", "[85353] expected \"address\" but predicted \"creator\"\n", - "[85752] expected \"address\" but predicted \"language\"\n", + "[85752] expected \"address\" but predicted \"notes\"\n", "[85930] expected \"address\" but predicted \"language\"\n", - "[85971] expected \"address\" but predicted \"location\"\n", - "[86084] expected \"address\" but predicted \"name\"\n", "[86891] expected \"address\" but predicted \"rank\"\n", - "[87332] expected \"address\" but predicted \"capacity\"\n", + "[87332] expected \"address\" but predicted \"code\"\n", "[87413] expected \"address\" but predicted \"location\"\n", "[87891] expected \"address\" but predicted \"sales\"\n", - "[87958] expected \"address\" but predicted \"affiliation\"\n", - "[89058] expected \"address\" but predicted \"city\"\n", + "[87958] expected \"address\" but predicted \"format\"\n", + "[88056] expected \"address\" but predicted \"city\"\n", + "[89504] expected \"address\" but predicted \"area\"\n", "[89800] expected \"address\" but predicted \"location\"\n", - "[90054] expected \"address\" but predicted \"name\"\n", + "[90054] expected \"address\" but predicted \"location\"\n", "[90466] expected \"address\" but predicted \"location\"\n", "[90582] expected \"address\" but predicted \"city\"\n", - "[92393] expected \"address\" but predicted \"location\"\n", - "[93361] expected \"address\" but predicted \"city\"\n", + "[90584] expected \"address\" but predicted \"location\"\n", "[93557] expected \"address\" but predicted \"country\"\n", + "[95220] expected \"address\" but predicted \"notes\"\n", "[95411] expected \"address\" but predicted \"location\"\n", - "[95748] expected \"address\" but predicted \"location\"\n", - "[95769] expected \"address\" but predicted \"city\"\n", + "[95769] expected \"address\" but predicted \"team name\"\n", "[96379] expected \"address\" but predicted \"location\"\n", - "[96640] expected \"address\" but predicted \"product\"\n", - "[96728] expected \"address\" but predicted \"state\"\n", - "[99038] expected \"address\" but predicted \"name\"\n", + "[96640] expected \"address\" but predicted \"format\"\n", + "[96728] expected \"address\" but predicted \"birth place\"\n", + "[97594] expected \"address\" but predicted \"artist\"\n", + "[98529] expected \"address\" but predicted \"artist\"\n", "[99237] expected \"address\" but predicted \"location\"\n", "[100797] expected \"address\" but predicted \"symbol\"\n", "[101634] expected \"address\" but predicted \"area\"\n", "[102060] expected \"address\" but predicted \"name\"\n", - "[103165] expected \"address\" but predicted \"language\"\n", - "[103681] expected \"address\" but predicted \"location\"\n", - "[105085] expected \"address\" but predicted \"city\"\n", + "[103165] expected \"address\" but predicted \"notes\"\n", "[107281] expected \"address\" but predicted \"location\"\n", "[108550] expected \"address\" but predicted \"grades\"\n", - "[109367] expected \"address\" but predicted \"region\"\n", + "[109367] expected \"address\" but predicted \"team name\"\n", "[109427] expected \"address\" but predicted \"result\"\n", - "[109740] expected \"address\" but predicted \"city\"\n", "[111142] expected \"address\" but predicted \"location\"\n", + "[112597] expected \"address\" but predicted \"location\"\n", "[113711] expected \"address\" but predicted \"location\"\n", - "[114356] expected \"address\" but predicted \"day\"\n", + "[114356] expected \"address\" but predicted \"position\"\n", "[114372] expected \"address\" but predicted \"area\"\n", "[114485] expected \"address\" but predicted \"description\"\n", - "[115607] expected \"address\" but predicted \"region\"\n", - "[116630] expected \"address\" but predicted \"county\"\n", - "[118623] expected \"address\" but predicted \"language\"\n", + "[118623] expected \"address\" but predicted \"notes\"\n", "[120003] expected \"address\" but predicted \"owner\"\n", - "[120180] expected \"address\" but predicted \"description\"\n", - "[122561] expected \"address\" but predicted \"language\"\n", + "[122561] expected \"address\" but predicted \"notes\"\n", "[123013] expected \"address\" but predicted \"rank\"\n", "[123713] expected \"address\" but predicted \"location\"\n", "[124794] expected \"address\" but predicted \"company\"\n", - "[124809] expected \"address\" but predicted \"location\"\n", - "[125936] expected \"address\" but predicted \"location\"\n", - "[126147] expected \"address\" but predicted \"location\"\n", + "[126147] expected \"address\" but predicted \"region\"\n", "[126442] expected \"address\" but predicted \"capacity\"\n", "[126729] expected \"address\" but predicted \"location\"\n", - "[127753] expected \"address\" but predicted \"education\"\n", - "[131254] expected \"address\" but predicted \"company\"\n", + "[127753] expected \"address\" but predicted \"order\"\n", + "[129842] expected \"address\" but predicted \"birth place\"\n", + "[130902] expected \"address\" but predicted \"rank\"\n", + "[131254] expected \"address\" but predicted \"owner\"\n", + "[131471] expected \"address\" but predicted \"location\"\n", "[131693] expected \"address\" but predicted \"location\"\n", "[132446] expected \"address\" but predicted \"location\"\n", - "[132881] expected \"address\" but predicted \"artist\"\n", - "[132989] expected \"address\" but predicted \"location\"\n", - "[133669] expected \"address\" but predicted \"album\"\n", + "[132881] expected \"address\" but predicted \"album\"\n", + "[133665] expected \"address\" but predicted \"region\"\n", "[133805] expected \"address\" but predicted \"name\"\n", - "[136727] expected \"address\" but predicted \"status\"\n", - "[136771] expected \"address\" but predicted \"location\"\n", + "[135969] expected \"address\" but predicted \"area\"\n", + "[136727] expected \"address\" but predicted \"result\"\n", "[137027] expected \"address\" but predicted \"location\"\n" ] }, @@ -799,93 +824,93 @@ "name": "stdout", "output_type": "stream", "text": [ - "Total mismatches: 14623 (F1 score: 0.8937685721454983)\n" + "Total mismatches: 14419 (F1 score: 0.8951410029373902)\n" ] }, { "data": { "text/plain": [ - "[('name', 761),\n", - " ('rank', 592),\n", - " ('position', 551),\n", - " ('location', 489),\n", - " ('region', 473),\n", - " ('team', 428),\n", - " ('description', 422),\n", - " ('artist', 395),\n", - " ('area', 394),\n", + "[('name', 727),\n", + " ('rank', 663),\n", + " ('region', 521),\n", + " ('location', 509),\n", + " ('position', 491),\n", + " ('description', 400),\n", + " ('team', 390),\n", + " ('artist', 385),\n", " ('notes', 364),\n", - " ('product', 361),\n", - " ('category', 353),\n", - " ('type', 342),\n", - " ('company', 335),\n", - " ('city', 330),\n", - " ('day', 326),\n", - " ('album', 292),\n", - " ('code', 290),\n", - " ('team name', 282),\n", - " ('ranking', 278),\n", - " ('class', 264),\n", - " ('order', 254),\n", - " ('sex', 254),\n", - " ('person', 222),\n", - " ('gender', 219),\n", - " ('status', 217),\n", - " ('owner', 211),\n", - " ('weight', 206),\n", - " ('result', 194),\n", - " ('year', 193),\n", - " ('address', 193),\n", - " ('duration', 191),\n", - " ('country', 177),\n", - " ('service', 176),\n", + " ('type', 363),\n", + " ('area', 357),\n", + " ('category', 341),\n", + " ('company', 340),\n", + " ('album', 338),\n", + " ('day', 329),\n", + " ('product', 322),\n", + " ('ranking', 302),\n", + " ('gender', 287),\n", + " ('city', 286),\n", + " ('team name', 283),\n", + " ('code', 274),\n", + " ('class', 253),\n", + " ('person', 221),\n", + " ('owner', 219),\n", + " ('weight', 203),\n", + " ('status', 197),\n", + " ('brand', 189),\n", + " ('year', 189),\n", + " ('credit', 176),\n", + " ('result', 174),\n", " ('manufacturer', 171),\n", - " ('brand', 168),\n", - " ('origin', 162),\n", - " ('plays', 152),\n", - " ('credit', 151),\n", - " ('component', 149),\n", - " ('sales', 143),\n", - " ('range', 135),\n", - " ('format', 133),\n", - " ('age', 122),\n", - " ('county', 118),\n", - " ('state', 117),\n", - " ('club', 113),\n", - " ('director', 110),\n", - " ('nationality', 107),\n", - " ('publisher', 105),\n", - " ('capacity', 103),\n", - " ('classification', 89),\n", - " ('affiliation', 87),\n", - " ('command', 85),\n", - " ('symbol', 84),\n", - " ('family', 74),\n", - " ('elevation', 73),\n", - " ('operator', 68),\n", - " ('education', 66),\n", - " ('depth', 66),\n", - " ('creator', 60),\n", - " ('requirement', 60),\n", - " ('language', 54),\n", - " ('genre', 53),\n", - " ('file size', 48),\n", - " ('organisation', 45),\n", - " ('industry', 43),\n", - " ('species', 40),\n", - " ('affiliate', 38),\n", - " ('birth place', 35),\n", - " ('collection', 32),\n", - " ('religion', 31),\n", - " ('continent', 28),\n", - " ('jockey', 23),\n", - " ('isbn', 18),\n", - " ('grades', 13),\n", - " ('currency', 11),\n", - " ('birth date', 11)]" + " ('address', 171),\n", + " ('service', 167),\n", + " ('order', 165),\n", + " ('sex', 164),\n", + " ('duration', 155),\n", + " ('age', 153),\n", + " ('sales', 152),\n", + " ('country', 152),\n", + " ('plays', 147),\n", + " ('component', 147),\n", + " ('origin', 144),\n", + " ('range', 139),\n", + " ('club', 133),\n", + " ('nationality', 131),\n", + " ('state', 129),\n", + " ('county', 127),\n", + " ('format', 120),\n", + " ('director', 102),\n", + " ('capacity', 101),\n", + " ('command', 100),\n", + " ('symbol', 94),\n", + " ('publisher', 89),\n", + " ('classification', 81),\n", + " ('depth', 80),\n", + " ('family', 78),\n", + " ('affiliation', 76),\n", + " ('elevation', 70),\n", + " ('language', 69),\n", + " ('operator', 62),\n", + " ('requirement', 58),\n", + " ('education', 57),\n", + " ('file size', 56),\n", + " ('genre', 56),\n", + " ('creator', 55),\n", + " ('industry', 44),\n", + " ('organisation', 44),\n", + " ('species', 41),\n", + " ('affiliate', 39),\n", + " ('jockey', 33),\n", + " ('collection', 33),\n", + " ('birth place', 33),\n", + " ('religion', 27),\n", + " ('continent', 24),\n", + " ('isbn', 17),\n", + " ('birth date', 15),\n", + " ('currency', 13),\n", + " ('grades', 12)]" ] }, - "execution_count": 56, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -913,7 +938,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -922,7 +947,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 24, "metadata": {}, "outputs": [ { diff --git a/notebooks/02-2-train-and-test-sherlock-rf-ensemble.ipynb b/notebooks/02-2-train-and-test-sherlock-rf-ensemble.ipynb index 05e2280..80d1399 100644 --- a/notebooks/02-2-train-and-test-sherlock-rf-ensemble.ipynb +++ b/notebooks/02-2-train-and-test-sherlock-rf-ensemble.ipynb @@ -1,12 +1,22 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train and test Sherlock when ensembled with a RF classifier\n", + "To boost the performance of Sherlock, it can be combined with a RF classifier.\n", + "\n", + "The scripts below show the procedure for doing so." + ] + }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ - "nn_model_id = 'retrained_sherlock'" + "model_id = 'sherlock'" ] }, { @@ -27,6 +37,7 @@ ], "source": [ "# If you need fully deterministic results between runs, set the following environment value prior to launching jupyter.\n", + "\n", "# See comment in sherlock.features.paragraph_vectors.infer_paragraph_embeddings_features for more info.\n", "%env PYTHONHASHSEED" ] @@ -43,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -60,8 +71,7 @@ "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.metrics import classification_report, f1_score\n", "\n", - "from sherlock.deploy.predict_sherlock import predict_sherlock_proba, _transform_predictions_to_classes\n", - "from sherlock.deploy.train_sherlock import train_sherlock" + "from sherlock.deploy.model import SherlockModel" ] }, { @@ -73,15 +83,15 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-07 15:46:02.658556\n", - "Load data (train) process took 0:00:04.837128 seconds.\n" + "Started at 2022-02-21 14:44:58.387328\n", + "Load data (train) process took 0:00:07.072707 seconds.\n" ] } ], @@ -92,12 +102,14 @@ "X_train = pd.read_parquet('../data/data/processed/train.parquet')\n", "y_train = pd.read_parquet('../data/data/raw/train_labels.parquet').values.flatten()\n", "\n", + "y_train = np.array([x.lower() for x in y_train])\n", + "\n", "print(f'Load data (train) process took {datetime.now() - start} seconds.')" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -116,15 +128,15 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-07 15:46:07.682626\n", - "Load data (validation) process took 0:00:01.803970 seconds.\n" + "Started at 2022-02-21 14:16:45.455219\n", + "Load data (validation) process took 0:00:02.024156 seconds.\n" ] } ], @@ -142,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -151,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -167,15 +179,15 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-07 15:46:11.682256\n", - "Finished at 2022-02-07 16:01:59.050932, took 0:15:47.368743 seconds\n" + "Started at 2022-02-21 14:17:08.147857\n", + "Finished at 2022-02-21 14:38:09.947917, took 0:21:01.802720 seconds\n" ] } ], @@ -199,7 +211,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -217,16 +229,16 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Started at 2022-02-07 16:01:59.120891\n", + "Started at 2022-02-21 14:38:10.601540\n", "Trained and saved new model.\n", - "Finished at 2022-02-07 16:02:00.975332, took 0:00:01.854455 seconds\n" + "Finished at 2022-02-21 14:38:12.493349, took 0:00:01.891821 seconds\n" ] } ], @@ -243,6 +255,39 @@ "print(f'Finished at {datetime.now()}, took {datetime.now() - start} seconds')" ] }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['address', 'affiliate', 'affiliation', 'age', 'album', 'area',\n", + " 'artist', 'birth date', 'birth place', 'brand', 'capacity',\n", + " 'category', 'city', 'class', 'classification', 'club', 'code',\n", + " 'collection', 'command', 'company', 'component', 'continent',\n", + " 'country', 'county', 'creator', 'credit', 'currency', 'day',\n", + " 'depth', 'description', 'director', 'duration', 'education',\n", + " 'elevation', 'family', 'file size', 'format', 'gender', 'genre',\n", + " 'grades', 'industry', 'isbn', 'jockey', 'language', 'location',\n", + " 'manufacturer', 'name', 'nationality', 'notes', 'operator',\n", + " 'order', 'organisation', 'origin', 'owner', 'person', 'plays',\n", + " 'position', 'product', 'publisher', 'range', 'rank', 'ranking',\n", + " 'region', 'religion', 'requirement', 'result', 'sales', 'service',\n", + " 'sex', 'species', 'state', 'status', 'symbol', 'team', 'team name',\n", + " 'type', 'weight', 'year'], dtype='\n", - "f1 score 0.8909529156786774\n" + "f1 score 0.8912755744265719\n" ] } ], @@ -337,7 +383,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ @@ -346,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 76, "metadata": {}, "outputs": [ { @@ -354,7 +400,7 @@ "output_type": "stream", "text": [ "prediction count 137353, type = \n", - "f1 score 0.8884613184751746\n" + "f1 score 0.8883526561931331\n" ] } ], @@ -371,7 +417,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 77, "metadata": {}, "outputs": [], "source": [ @@ -380,7 +426,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 78, "metadata": {}, "outputs": [ { @@ -388,7 +434,7 @@ "output_type": "stream", "text": [ "prediction count 137353, type = \n", - "f1 score 0.8933550546229518\n" + "f1 score 0.8940645473980389\n" ] } ], @@ -405,39 +451,18 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 79, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING: Logging before flag parsing goes to stderr.\n", - "W0207 16:02:52.102345 4654583296 deprecation.py:506] From /Users/lowecg/source/private-github/sherlock-project-1/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "Call initializer instance with the dtype argument instead of passing it to the constructor\n", - "W0207 16:02:52.104378 4654583296 deprecation.py:506] From /Users/lowecg/source/private-github/sherlock-project-1/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "Call initializer instance with the dtype argument instead of passing it to the constructor\n", - "W0207 16:02:52.108886 4654583296 deprecation.py:506] From /Users/lowecg/source/private-github/sherlock-project-1/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/init_ops.py:97: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "Call initializer instance with the dtype argument instead of passing it to the constructor\n", - "W0207 16:02:52.138780 4654583296 deprecation.py:506] From /Users/lowecg/source/private-github/sherlock-project-1/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "If using Keras pass *_constraint arguments to layers.\n", - "2022-02-07 16:02:52.704754: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA\n", - "2022-02-07 16:02:52.762490: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fc083ef6bc0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n", - "2022-02-07 16:02:52.762507: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version\n" - ] - } - ], + "outputs": [], "source": [ - "predicted_sherlock_proba = predict_sherlock_proba(X_test, nn_id=nn_model_id)" + "model = SherlockModel()\n", + "model.initialize_model_from_json(with_weights=True, model_id=\"sherlock\")\n", + "predicted_sherlock_proba = model.predict_proba(X_test)" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 80, "metadata": {}, "outputs": [ { @@ -445,7 +470,7 @@ "output_type": "stream", "text": [ "prediction count 137353, type = \n", - "f1 score 0.8940572197723697\n" + "f1 score 0.8951410029373902\n" ] } ], @@ -462,7 +487,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ @@ -475,13 +500,12 @@ " x = nn_probs + voting_probs\n", " x = x / 2\n", "\n", - " combined.append(x)\n", - " " + " combined.append(x)" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 82, "metadata": {}, "outputs": [ { @@ -489,7 +513,7 @@ "output_type": "stream", "text": [ "prediction count 137353, type = \n", - "f1 score 0.9047220789717997\n" + "f1 score 0.905491661885665\n" ] } ], @@ -501,7 +525,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 83, "metadata": {}, "outputs": [], "source": [ @@ -521,7 +545,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ @@ -539,7 +563,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 85, "metadata": {}, "outputs": [ { @@ -547,11 +571,11 @@ "output_type": "stream", "text": [ "\t\tf1-score\tprecision\trecall\t\tsupport\n", - "grades\t\t0.994\t\t0.994\t\t0.994\t\t1765\n", - "isbn\t\t0.991\t\t0.993\t\t0.989\t\t1430\n", - "jockey\t\t0.986\t\t0.980\t\t0.991\t\t2819\n", - "industry\t0.983\t\t0.979\t\t0.988\t\t2958\n", - "birth date\t0.978\t\t0.981\t\t0.975\t\t479\n" + "grades\t\t0.995\t\t0.994\t\t0.995\t\t1765\n", + "isbn\t\t0.990\t\t0.992\t\t0.989\t\t1430\n", + "industry\t0.986\t\t0.985\t\t0.988\t\t2958\n", + "jockey\t\t0.985\t\t0.984\t\t0.987\t\t2819\n", + "currency\t0.979\t\t0.985\t\t0.973\t\t405\n" ] } ], @@ -568,7 +592,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 86, "metadata": {}, "outputs": [ { @@ -576,11 +600,11 @@ "output_type": "stream", "text": [ "\t\tf1-score\tprecision\trecall\t\tsupport\n", - "rank\t\t0.751\t\t0.710\t\t0.796\t\t2983\n", - "person\t\t0.690\t\t0.702\t\t0.679\t\t579\n", - "sales\t\t0.633\t\t0.747\t\t0.550\t\t322\n", - "director\t0.598\t\t0.648\t\t0.556\t\t225\n", - "ranking\t\t0.594\t\t0.855\t\t0.456\t\t439\n" + "rank\t\t0.738\t\t0.678\t\t0.810\t\t2983\n", + "person\t\t0.695\t\t0.767\t\t0.636\t\t579\n", + "sales\t\t0.615\t\t0.667\t\t0.571\t\t322\n", + "director\t0.604\t\t0.661\t\t0.556\t\t225\n", + "ranking\t\t0.569\t\t0.823\t\t0.435\t\t439\n" ] } ], @@ -597,7 +621,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 87, "metadata": {}, "outputs": [ { @@ -606,88 +630,88 @@ "text": [ " precision recall f1-score support\n", "\n", - " address 0.926 0.947 0.937 3003\n", - " affiliate 0.976 0.814 0.888 204\n", - " affiliation 0.978 0.958 0.968 1768\n", - " age 0.882 0.963 0.921 3033\n", - " album 0.883 0.901 0.892 3035\n", - " area 0.888 0.840 0.863 1987\n", - " artist 0.807 0.886 0.845 3043\n", - " birth date 0.981 0.975 0.978 479\n", - " birth place 0.974 0.904 0.938 418\n", - " brand 0.795 0.723 0.757 574\n", - " capacity 0.879 0.746 0.807 362\n", - " category 0.913 0.901 0.907 3087\n", - " city 0.857 0.912 0.883 2966\n", - " class 0.906 0.926 0.916 2971\n", - "classification 0.955 0.867 0.909 587\n", - " club 0.972 0.961 0.967 2977\n", + " address 0.929 0.951 0.940 3003\n", + " affiliate 0.949 0.819 0.879 204\n", + " affiliation 0.975 0.958 0.966 1768\n", + " age 0.891 0.955 0.922 3033\n", + " album 0.894 0.895 0.894 3035\n", + " area 0.892 0.836 0.863 1987\n", + " artist 0.810 0.884 0.846 3043\n", + " birth date 0.983 0.969 0.976 479\n", + " birth place 0.939 0.919 0.929 418\n", + " brand 0.849 0.695 0.764 574\n", + " capacity 0.851 0.771 0.809 362\n", + " category 0.927 0.898 0.912 3087\n", + " city 0.870 0.910 0.890 2966\n", + " class 0.921 0.923 0.922 2971\n", + "classification 0.946 0.874 0.909 587\n", + " club 0.975 0.957 0.966 2977\n", " code 0.921 0.925 0.923 2956\n", - " collection 0.968 0.943 0.955 476\n", - " command 0.933 0.930 0.931 1045\n", - " company 0.890 0.904 0.897 3041\n", - " component 0.893 0.895 0.894 1226\n", - " continent 0.908 0.916 0.912 227\n", - " country 0.896 0.952 0.923 3038\n", - " county 0.938 0.963 0.950 2959\n", - " creator 0.848 0.839 0.843 347\n", - " credit 0.894 0.823 0.857 941\n", - " currency 0.982 0.968 0.975 405\n", - " day 0.944 0.915 0.929 3038\n", - " depth 0.959 0.944 0.952 947\n", - " description 0.795 0.886 0.838 3042\n", - " director 0.648 0.556 0.598 225\n", - " duration 0.948 0.953 0.950 3000\n", - " education 0.902 0.853 0.877 313\n", - " elevation 0.953 0.960 0.956 1299\n", - " family 0.970 0.903 0.935 746\n", - " file size 0.925 0.886 0.905 361\n", - " format 0.977 0.957 0.967 2956\n", - " gender 0.859 0.834 0.846 1030\n", - " genre 0.968 0.948 0.958 1163\n", - " grades 0.994 0.994 0.994 1765\n", - " industry 0.979 0.988 0.983 2958\n", - " isbn 0.993 0.989 0.991 1430\n", - " jockey 0.980 0.991 0.986 2819\n", - " language 0.927 0.945 0.936 1474\n", - " location 0.877 0.844 0.860 2949\n", - " manufacturer 0.904 0.798 0.848 945\n", - " name 0.788 0.725 0.755 3017\n", - " nationality 0.853 0.738 0.791 424\n", - " notes 0.779 0.832 0.804 2303\n", - " operator 0.850 0.854 0.852 404\n", - " order 0.891 0.863 0.877 1462\n", - " organisation 0.873 0.817 0.844 262\n", - " origin 0.964 0.899 0.930 1439\n", - " owner 0.935 0.877 0.905 1673\n", - " person 0.702 0.679 0.690 579\n", - " plays 0.826 0.919 0.870 1513\n", - " position 0.842 0.856 0.849 3057\n", - " product 0.865 0.893 0.879 2647\n", - " publisher 0.897 0.918 0.907 880\n", - " range 0.917 0.801 0.855 577\n", - " rank 0.710 0.796 0.751 2983\n", - " ranking 0.855 0.456 0.594 439\n", - " region 0.894 0.850 0.871 2740\n", - " religion 0.975 0.932 0.953 340\n", - " requirement 0.938 0.807 0.867 300\n", - " result 0.969 0.948 0.958 2920\n", - " sales 0.747 0.550 0.633 322\n", - " service 0.978 0.925 0.951 2222\n", - " sex 0.939 0.940 0.939 2997\n", - " species 0.919 0.954 0.936 819\n", - " state 0.942 0.960 0.951 3030\n", - " status 0.957 0.939 0.948 3100\n", - " symbol 0.964 0.967 0.966 1752\n", - " team 0.878 0.867 0.873 3011\n", - " team name 0.912 0.840 0.875 1639\n", - " type 0.888 0.894 0.891 2909\n", - " weight 0.946 0.951 0.949 2963\n", - " year 0.969 0.941 0.955 3015\n", + " collection 0.987 0.935 0.960 476\n", + " command 0.941 0.918 0.929 1045\n", + " company 0.913 0.898 0.905 3041\n", + " component 0.904 0.892 0.898 1226\n", + " continent 0.887 0.930 0.908 227\n", + " country 0.897 0.957 0.926 3038\n", + " county 0.944 0.964 0.954 2959\n", + " creator 0.807 0.841 0.824 347\n", + " credit 0.890 0.832 0.860 941\n", + " currency 0.985 0.973 0.979 405\n", + " day 0.948 0.914 0.931 3038\n", + " depth 0.945 0.945 0.945 947\n", + " description 0.809 0.884 0.845 3042\n", + " director 0.661 0.556 0.604 225\n", + " duration 0.935 0.955 0.945 3000\n", + " education 0.887 0.856 0.872 313\n", + " elevation 0.961 0.955 0.958 1299\n", + " family 0.967 0.905 0.935 746\n", + " file size 0.946 0.867 0.905 361\n", + " format 0.969 0.960 0.964 2956\n", + " gender 0.860 0.836 0.848 1030\n", + " genre 0.969 0.953 0.961 1163\n", + " grades 0.994 0.995 0.995 1765\n", + " industry 0.985 0.988 0.986 2958\n", + " isbn 0.992 0.989 0.990 1430\n", + " jockey 0.984 0.987 0.985 2819\n", + " language 0.923 0.947 0.935 1474\n", + " location 0.901 0.838 0.868 2949\n", + " manufacturer 0.876 0.828 0.851 945\n", + " name 0.733 0.769 0.751 3017\n", + " nationality 0.906 0.708 0.795 424\n", + " notes 0.750 0.847 0.796 2303\n", + " operator 0.819 0.854 0.836 404\n", + " order 0.860 0.877 0.869 1462\n", + " organisation 0.852 0.855 0.853 262\n", + " origin 0.955 0.905 0.930 1439\n", + " owner 0.941 0.874 0.906 1673\n", + " person 0.767 0.636 0.695 579\n", + " plays 0.856 0.915 0.885 1513\n", + " position 0.848 0.850 0.849 3057\n", + " product 0.878 0.886 0.882 2647\n", + " publisher 0.904 0.903 0.904 880\n", + " range 0.885 0.797 0.839 577\n", + " rank 0.678 0.810 0.738 2983\n", + " ranking 0.823 0.435 0.569 439\n", + " region 0.905 0.844 0.873 2740\n", + " religion 0.975 0.926 0.950 340\n", + " requirement 0.942 0.817 0.875 300\n", + " result 0.968 0.947 0.958 2920\n", + " sales 0.667 0.571 0.615 322\n", + " service 0.970 0.929 0.949 2222\n", + " sex 0.940 0.938 0.939 2997\n", + " species 0.930 0.952 0.941 819\n", + " state 0.940 0.962 0.951 3030\n", + " status 0.949 0.943 0.946 3100\n", + " symbol 0.963 0.971 0.967 1752\n", + " team 0.867 0.871 0.869 3011\n", + " team name 0.898 0.842 0.869 1639\n", + " type 0.923 0.882 0.902 2909\n", + " weight 0.963 0.950 0.956 2963\n", + " year 0.966 0.946 0.956 3015\n", "\n", " accuracy 0.905 137353\n", - " macro avg 0.904 0.880 0.890 137353\n", - " weighted avg 0.906 0.905 0.905 137353\n", + " macro avg 0.903 0.880 0.890 137353\n", + " weighted avg 0.907 0.905 0.905 137353\n", "\n" ] } @@ -705,100 +729,100 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 88, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Total mismatches: 13058 (F1 score: 0.9047220789717997)\n" + "Total mismatches: 12994 (F1 score: 0.905491661885665)\n" ] }, { "data": { "text/plain": [ - "[('name', 830),\n", - " ('rank', 608),\n", - " ('location', 460),\n", - " ('position', 441),\n", - " ('region', 412),\n", - " ('team', 399),\n", - " ('notes', 388),\n", - " ('artist', 348),\n", - " ('description', 347),\n", - " ('area', 318),\n", - " ('type', 307),\n", - " ('category', 307),\n", - " ('album', 299),\n", - " ('company', 291),\n", - " ('product', 283),\n", - " ('team name', 262),\n", - " ('city', 261),\n", - " ('day', 258),\n", - " ('ranking', 239),\n", + "[('name', 697),\n", + " ('rank', 566),\n", + " ('location', 479),\n", + " ('position', 460),\n", + " ('region', 427),\n", + " ('team', 388),\n", + " ('artist', 353),\n", + " ('notes', 352),\n", + " ('description', 352),\n", + " ('type', 342),\n", + " ('area', 326),\n", + " ('album', 320),\n", + " ('category', 316),\n", + " ('company', 310),\n", + " ('product', 301),\n", + " ('city', 266),\n", + " ('day', 261),\n", + " ('team name', 259),\n", + " ('ranking', 248),\n", + " ('class', 229),\n", " ('code', 222),\n", - " ('class', 220),\n", - " ('owner', 205),\n", - " ('order', 200),\n", - " ('manufacturer', 191),\n", - " ('status', 190),\n", - " ('person', 186),\n", - " ('sex', 181),\n", - " ('year', 177),\n", - " ('gender', 171),\n", - " ('credit', 167),\n", - " ('service', 167),\n", - " ('brand', 159),\n", - " ('address', 158),\n", - " ('result', 151),\n", - " ('country', 147),\n", - " ('origin', 146),\n", - " ('weight', 145),\n", - " ('sales', 145),\n", - " ('duration', 140),\n", - " ('component', 129),\n", - " ('format', 126),\n", - " ('plays', 123),\n", - " ('state', 121),\n", - " ('club', 116),\n", - " ('range', 115),\n", - " ('nationality', 111),\n", - " ('age', 111),\n", - " ('county', 110),\n", + " ('person', 211),\n", + " ('owner', 210),\n", + " ('sex', 185),\n", + " ('order', 180),\n", + " ('status', 178),\n", + " ('brand', 175),\n", + " ('gender', 169),\n", + " ('manufacturer', 163),\n", + " ('year', 163),\n", + " ('credit', 158),\n", + " ('service', 158),\n", + " ('result', 154),\n", + " ('weight', 149),\n", + " ('address', 146),\n", + " ('sales', 138),\n", + " ('duration', 136),\n", + " ('age', 136),\n", + " ('origin', 136),\n", + " ('component', 133),\n", + " ('country', 130),\n", + " ('club', 129),\n", + " ('plays', 128),\n", + " ('nationality', 124),\n", + " ('format', 119),\n", + " ('range', 117),\n", + " ('state', 115),\n", + " ('county', 108),\n", " ('director', 100),\n", - " ('capacity', 92),\n", - " ('language', 81),\n", - " ('classification', 78),\n", + " ('command', 86),\n", + " ('publisher', 85),\n", + " ('capacity', 83),\n", + " ('language', 78),\n", " ('affiliation', 75),\n", - " ('command', 73),\n", - " ('family', 72),\n", - " ('publisher', 72),\n", - " ('genre', 60),\n", + " ('classification', 74),\n", + " ('family', 71),\n", " ('operator', 59),\n", - " ('requirement', 58),\n", - " ('symbol', 57),\n", - " ('creator', 56),\n", - " ('depth', 53),\n", - " ('elevation', 52),\n", - " ('organisation', 48),\n", - " ('education', 46),\n", - " ('file size', 41),\n", - " ('birth place', 40),\n", - " ('affiliate', 38),\n", - " ('species', 38),\n", + " ('elevation', 58),\n", + " ('creator', 55),\n", + " ('requirement', 55),\n", + " ('genre', 55),\n", + " ('depth', 52),\n", + " ('symbol', 51),\n", + " ('file size', 48),\n", + " ('education', 45),\n", + " ('species', 39),\n", + " ('organisation', 38),\n", + " ('affiliate', 37),\n", + " ('jockey', 36),\n", " ('industry', 36),\n", - " ('collection', 27),\n", - " ('jockey', 25),\n", - " ('religion', 23),\n", - " ('continent', 19),\n", + " ('birth place', 34),\n", + " ('collection', 31),\n", + " ('religion', 25),\n", + " ('continent', 16),\n", " ('isbn', 16),\n", - " ('currency', 13),\n", - " ('birth date', 12),\n", - " ('grades', 10)]" + " ('birth date', 15),\n", + " ('currency', 11),\n", + " ('grades', 8)]" ] }, - "execution_count": 32, + "execution_count": 88, "metadata": {}, "output_type": "execute_result" } @@ -824,7 +848,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 89, "metadata": {}, "outputs": [], "source": [ @@ -833,7 +857,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 90, "metadata": {}, "outputs": [ { @@ -855,14 +879,14 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 91, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Completed at 2022-02-07 16:03:14.044571\n" + "Completed at 2022-02-21 14:53:39.160195\n" ] } ], diff --git a/notebooks/03-train-paragraph-vector-features-optional.ipynb b/notebooks/03-retrain-paragraph-vector-features.ipynb similarity index 95% rename from notebooks/03-train-paragraph-vector-features-optional.ipynb rename to notebooks/03-retrain-paragraph-vector-features.ipynb index 90bc4f6..bc3aa4f 100644 --- a/notebooks/03-train-paragraph-vector-features-optional.ipynb +++ b/notebooks/03-retrain-paragraph-vector-features.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load training set and train paragraph vectors\n", + "Note: the paragraph vector model has been trained and is downloaded in the `prepare_feature_extraction()` function.\n", + "\n", + "Retraining is therefore not needed, but optional" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -32,13 +42,6 @@ "%env PYTHONHASHSEED" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load training set and train paragraph vectors" - ] - }, { "cell_type": "code", "execution_count": 3, @@ -87,9 +90,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Download and read in raw data\n", - "\n", - "You can skip this step if you want to use a preprocessed data file." + "## Download and read in raw data\n" ] }, { diff --git a/requirements.txt b/requirements.txt index d1679bf..00d2898 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,14 +7,15 @@ certifi==2019.6.16 chardet==3.0.4 docutils==0.14 gast==0.2.2 +gdown==4.3.0 gensim==3.8.0 google-pasta==0.1.7 -googledrivedownloader==0.4 grpcio==1.22.0 h5py==2.9.0 idna==2.8 jmespath==0.9.4 joblib==0.13.2 +jupyter==1.0.0 Keras-Applications==1.0.8 Keras-Preprocessing==1.1.0 line_profiler==3.3.1 diff --git a/requirements38.txt b/requirements38.txt index 1e2b04b..39926b5 100644 --- a/requirements38.txt +++ b/requirements38.txt @@ -8,14 +8,15 @@ chardet==3.0.4 cloudpickle==1.6.0 docutils==0.14 gast==0.2.2 +gdown==4.3.0 gensim==3.8.0 google-pasta==0.1.7 -googledrivedownloader==0.4 grpcio==1.34.1 h5py==3.1.0 idna==2.8 jmespath==0.9.4 joblib==1.0.0 +jupyter==1.0.0 Keras-Applications==1.0.8 Keras-Preprocessing==1.1.0 loky==2.9.0 diff --git a/sherlock/deploy/model.py b/sherlock/deploy/model.py index aca7d93..d95b925 100644 --- a/sherlock/deploy/model.py +++ b/sherlock/deploy/model.py @@ -101,6 +101,8 @@ def fit( self.model = model + _ = helpers._get_categorical_label_encodings(y_train, y_val, model_id) + def predict(self, X: pd.DataFrame, model_id: str = "sherlock") -> np.array: """Use sherlock model to generate predictions for X. diff --git a/sherlock/features/preprocessing.py b/sherlock/features/preprocessing.py index 874f87a..3ac3b5f 100644 --- a/sherlock/features/preprocessing.py +++ b/sherlock/features/preprocessing.py @@ -8,7 +8,6 @@ import pandas as pd from functools import partial -from google_drive_downloader import GoogleDriveDownloader as gd from pyarrow.parquet import ParquetFile from tqdm import tqdm diff --git a/sherlock/helpers.py b/sherlock/helpers.py index 7765a6e..13afc13 100644 --- a/sherlock/helpers.py +++ b/sherlock/helpers.py @@ -9,18 +9,17 @@ def download_data(): The data is downloaded from Google Drive and stored in the 'data/' directory. """ data_dir = "../data/data/" - data_zip = "../data.zip" + zip_filepath = "../data/data.zip" print(f"Downloading the raw data into {data_dir}.") if not os.path.exists(data_dir): print("Downloading data directory.") - dir_name = data_zip gdown.download( url="https://drive.google.com/uc?id=1-g0zbKFAXz7zKZc0Dnh74uDBpZCv4YqU", - output=dir_name, + output=zip_filepath, ) - with zipfile.ZipFile(data_zip, "r") as zf: + with zipfile.ZipFile(zip_filepath, "r") as zf: zf.extractall("../data/") print("Data was downloaded.") diff --git a/tests/test_helper.py b/tests/test_helper.py index 65d6737..b0c0a74 100644 --- a/tests/test_helper.py +++ b/tests/test_helper.py @@ -1,4 +1,4 @@ -from sherlock.deploy.model_helpers import categorize_features +from sherlock.deploy.helpers import categorize_features from collections import OrderedDict CATEGORY_FEATURE_KEYS: dict = categorize_features()