From a20c987b84df20ebe0a7f9e842c05abade5c8a0f Mon Sep 17 00:00:00 2001 From: rly Date: Wed, 10 Apr 2024 02:54:42 -0700 Subject: [PATCH 1/2] Add draft example scripts --- example.ipynb | 654 ++++++++++++++++++++++++++++++++++++++++++++++++++ example.py | 110 +++++++++ 2 files changed, 764 insertions(+) create mode 100644 example.ipynb create mode 100644 example.py diff --git a/example.ipynb b/example.ipynb new file mode 100644 index 0000000..babcc9a --- /dev/null +++ b/example.ipynb @@ -0,0 +1,654 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "6ff4d1a4-2aca-413e-9c68-5243cec63647", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rly/Documents/NWB/hdmf/src/hdmf/utils.py:668: UserWarning: Ignoring cached namespace 'hdmf-common' version 1.5.1 because version 1.8.0 is already loaded.\n", + " return func(args[0], **pargs)\n", + "/Users/rly/Documents/NWB/hdmf/src/hdmf/utils.py:668: UserWarning: Ignoring cached namespace 'core' version 2.5.0 because version 2.6.0-alpha is already loaded.\n", + " return func(args[0], **pargs)\n", + "/Users/rly/Documents/NWB/hdmf/src/hdmf/utils.py:668: UserWarning: Ignoring cached namespace 'hdmf-experimental' version 0.2.0 because version 0.5.0 is already loaded.\n", + " return func(args[0], **pargs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " spike_times unit_name \\\n", + "id \n", + "0 [0.019766666666666665, 0.028533333333333334, 0... 12 \n", + "1 [0.019766666666666665, 0.07186666666666666, 0.... 424 \n", + "2 [0.020266666666666665, 0.04423333333333333, 0.... 78 \n", + "3 [0.020766666666666666, 0.023266666666666668, 0... 105 \n", + "4 [0.022333333333333334, 0.1083, 0.1305, 0.13916... 110 \n", + ".. ... ... \n", + "553 [1422.7791666666667, 2570.4848666666667, 2575.... 543 \n", + "554 [1448.522, 1477.2797666666668, 1522.4593666666... 291 \n", + "555 [1481.3955333333333, 1573.5977333333333, 1575.... 515 \n", + "556 [1527.9567333333334, 1750.9875666666667, 2421.... 246 \n", + "557 [2161.375666666667, 2454.7655666666665, 2710.4... 422 \n", + "\n", + " presence_ratio_standard_deviation contamination noise_cutoff \\\n", + "id \n", + "0 5.258554 0.083082 136.186874 \n", + "1 287.292619 0.007035 38.420545 \n", + "2 42.179534 0.006504 15.857428 \n", + "3 75.238846 0.003428 56.012060 \n", + "4 52.149411 0.002114 0.515300 \n", + ".. ... ... ... \n", + "553 1.736014 0.000000 25.459525 \n", + "554 5.888234 0.781129 -0.719935 \n", + "555 1.159141 0.000000 10.750188 \n", + "556 1.320937 9.734769 6.934761 \n", + "557 1.328322 0.000000 -0.691095 \n", + "\n", + " mean_relative_depth sliding_refractory_period_violation cosmos_location \\\n", + "id \n", + "0 80.0 0.0 MB \n", + "1 100.0 1.0 MB \n", + "2 100.0 1.0 MB \n", + "3 120.0 1.0 MB \n", + "4 160.0 1.0 MB \n", + ".. ... ... ... \n", + "553 2600.0 0.0 HPF \n", + "554 2600.0 0.0 HPF \n", + "555 3160.0 0.0 Isocortex \n", + "556 3400.0 0.0 Isocortex \n", + "557 2520.0 0.0 HPF \n", + "\n", + " maximum_amplitude maximum_amplitude_channel ... median_amplitude \\\n", + "id ... \n", + "0 0.000175 6.0 ... 0.000087 \n", + "1 0.000159 8.0 ... 0.000077 \n", + "2 0.000159 9.0 ... 0.000090 \n", + "3 0.000201 10.0 ... 0.000083 \n", + "4 0.000211 14.0 ... 0.000115 \n", + ".. ... ... ... ... \n", + "553 0.000139 258.0 ... 0.000080 \n", + "554 0.000276 259.0 ... 0.000118 \n", + "555 0.000121 315.0 ... 0.000085 \n", + "556 0.000126 338.0 ... 0.000074 \n", + "557 0.000342 251.0 ... 
0.000236 \n", + "\n", + " drift minimum_amplitude spike_count firing_rate \\\n", + "id \n", + "0 6.848070e+04 0.000082 4319.0 0.995299 \n", + "1 2.249384e+06 0.000054 155672.0 35.874101 \n", + "2 3.517302e+05 0.000066 37810.0 8.713190 \n", + "3 2.860285e+06 0.000060 265558.0 61.196968 \n", + "4 1.021966e+06 0.000065 157876.0 36.382005 \n", + ".. ... ... ... ... \n", + "553 9.665525e+03 0.000074 873.0 0.201180 \n", + "554 9.882481e+03 0.000069 1992.0 0.459050 \n", + "555 5.706065e+03 0.000076 534.0 0.123059 \n", + "556 2.543755e+03 0.000059 399.0 0.091948 \n", + "557 7.951559e+02 0.000143 318.0 0.073282 \n", + "\n", + " missed_spikes_estimate standard_deviation_amplitude allen_location \\\n", + "id \n", + "0 0.500000 0.725555 RN \n", + "1 0.148918 1.481336 RN \n", + "2 0.059499 1.184573 RN \n", + "3 0.183454 1.535417 RN \n", + "4 0.006965 1.608456 MRN \n", + ".. ... ... ... \n", + "553 0.500000 0.822730 CA1 \n", + "554 0.009546 1.394312 CA1 \n", + "555 0.500000 0.829071 VISpm5 \n", + "556 NaN 1.284709 VISam2/3 \n", + "557 NaN 1.050623 CA1 \n", + "\n", + " spike_amplitudes presence_ratio \n", + "id \n", + "0 [8.288184885804525e-05, 0.00011442759068298115... 0.997701 \n", + "1 [4.6580973513486687e-05, 5.296578081321374e-05... 0.816092 \n", + "2 [6.32776225415224e-05, 9.328839624452988e-05, ... 0.997701 \n", + "3 [0.00016554799180735244, 0.0001820722989910208... 0.997701 \n", + "4 [0.00015009149841748693, 0.0001974641043535484... 0.997701 \n", + ".. ... ... \n", + "553 [6.087924810443521e-05, 6.153988903027088e-05,... 0.779310 \n", + "554 [0.0001497319921455542, 0.00012844935458089068... 0.802299 \n", + "555 [9.970061409456434e-05, 7.378086503872224e-05,... 0.701149 \n", + "556 [9.996891360609634e-05, 9.484602794567695e-05,... 0.471264 \n", + "557 [5.5044274830256504e-05, 5.894106134079688e-05... 0.326437 \n", + "\n", + "[558 rows x 25 columns]\n", + "There are 558 units in this dataset.\n" + ] + } + ], + "source": [ + "from hdmf_ai import ResultsTable\n", + "from hdmf.common import get_hdf5io, HERD\n", + "import numpy as np\n", + "import pandas as pd\n", + "from pynwb import NWBHDF5IO\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "\n", + "filepath = \"/Users/rly/Documents/NWB_Data/dandisets/000409/sub-CSHL047/sub-CSHL047_ses-b52182e7-39f6-4914-9717-136db589706e_behavior+ecephys+image.nwb\"\n", + "io = NWBHDF5IO(filepath, \"r\")\n", + "nwbfile = io.read()\n", + "\n", + "# the NWB Units table stores information about the sorted single units (putative neurons) after preprocessing\n", + "# and spike sorting. each row represents a single unit. 
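(aside, not in the original patch: nwbfile.units.colnames lists the column names and len(nwbfile.units) gives the row count, both standard hdmf DynamicTable calls that read no array data.)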
this dataset includes many metadata fields (table columns) for\n", + "# each unit.\n", + "units = nwbfile.units\n", + "\n", + "# the Units table can be most readiy viewed as a pandas DataFrame\n", + "units_df = units.to_dataframe()\n", + "print(units_df)\n", + "print(f\"There are {len(units_df)} units in this dataset.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "68611331-8162-46bc-b408-0ae1aa482395", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "There are 5 unique Cosmos locations in this dataset.\n", + "[[3.41640428e-01 3.72742163e-01 1.99089260e-01 4.73519749e-02\n", + " 3.91761738e-02]\n", + " [1.57087821e-02 5.40603789e-02 8.93193920e-01 3.70369084e-02\n", + " 1.06054673e-08]\n", + " [2.53179848e-01 3.42376210e-01 3.54873547e-01 4.80723731e-02\n", + " 1.49802201e-03]\n", + " ...\n", + " [3.46285279e-01 3.63496691e-01 1.87693063e-01 4.75699137e-02\n", + " 5.49550530e-02]\n", + " [3.72337933e-01 3.69315148e-01 1.78041745e-01 3.73962858e-02\n", + " 4.29088877e-02]\n", + " [3.48455886e-01 3.57365528e-01 1.91957666e-01 4.80535150e-02\n", + " 5.41674053e-02]]\n", + "The logistic regression achieved a score of 0.46706586826347307 on the test set!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rly/mambaforge/envs/hdmf-ml/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + } + ], + "source": [ + "# run a simple classifier on the units data to predict the location (brain area) of the unit based on\n", + "# the amplitude, firing rate, spike count, and presence ratio. there are several ways to label the brain\n", + "# area of a unit. here, we use the label using the coarsest atlas, the Cosmos atlas, which has a total\n", + "# of 12 annotation regions. 
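(encoded below with sklearn's LabelEncoder, which assigns integer codes in sorted label order; enc.inverse_transform(predictions_all) recovers the location acronyms for reporting.)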
in this dataset, there are 5 unique Cosmos locations.\n", + "enc = LabelEncoder()\n", + "y = np.uint(enc.fit_transform(units[\"cosmos_location\"].data[:]))\n", + "unique_labels = enc.classes_\n", + "print(f\"There are {len(unique_labels)} unique Cosmos locations in this dataset.\")\n", + "\n", + "# split the data into training and test sets\n", + "# TODO integrate with sklearn.model_selection.train_test_split\n", + "proportion_train = 0.7\n", + "n_train_samples = int(np.round(proportion_train * len(units_df)))\n", + "n_test_samples = len(units) - n_train_samples\n", + "# train = 0, validate = 1, test = 2\n", + "tvt = np.array([0] * n_train_samples + [2] * n_test_samples)\n", + "np.random.shuffle(tvt)\n", + "\n", + "feature_names = [\n", + " \"median_amplitude\",\n", + " \"standard_deviation_amplitude\",\n", + " \"firing_rate\",\n", + " \"spike_count\",\n", + " \"presence_ratio\",\n", + "]\n", + "X = units_df[feature_names]\n", + "X_train = X[tvt == 0]\n", + "y_train = y[tvt == 0]\n", + "X_test = X[tvt == 2]\n", + "y_test = y[tvt == 2]\n", + "\n", + "# run logistic regression\n", + "logreg = LogisticRegression()\n", + "logreg.fit(X_train, y_train)\n", + "predictions_all = logreg.predict(X)\n", + "prediction_proba_all = logreg.predict_proba(X)\n", + "print(prediction_proba_all)\n", + "score = logreg.score(X_test, y_test)\n", + "print(f\"The logistic regression achieved a score of {score} on the test set!\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1f93fc0f-6098-4769-9ae1-d29b1d2eeb60", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " tvt_split true_label predicted_probability \\\n", + "id \n", + "0 train 2 [0.341640428457359, 0.37274216307915675, 0.199... \n", + "1 test 2 [0.01570878205315649, 0.054060378948318714, 0.... \n", + "2 test 2 [0.2531798480163758, 0.3423762096888793, 0.354... \n", + "3 test 2 [0.0005548355525437238, 0.004663567070195263, ... \n", + "4 train 2 [0.016385208959930442, 0.05624561954206042, 0.... \n", + ".. ... ... ... \n", + "553 train 0 [0.34832839660017373, 0.3663682449303219, 0.18... \n", + "554 train 0 [0.38878744558319256, 0.38241354770404473, 0.1... \n", + "555 train 1 [0.3462852791010175, 0.36349669071444046, 0.18... \n", + "556 train 1 [0.37233793317845687, 0.3693151482651456, 0.17... \n", + "557 train 0 [0.3484558855856061, 0.35736552827870616, 0.19... \n", + "\n", + " predicted_class samples \n", + "id \n", + "0 1 spik... \n", + "1 2 spik... \n", + "2 2 spik... \n", + "3 2 spik... \n", + "4 2 spik... \n", + ".. ... ... \n", + "553 1 spi... \n", + "554 0 spi... \n", + "555 1 spi... \n", + "556 0 spi... \n", + "557 1 spi... \n", + "\n", + "[558 rows x 5 columns]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rly/Documents/NWB/hdmf/src/hdmf/container.py:286: UserWarning: TrainValidationTestSplit is experimental -- it may be removed in the future and is not guaranteed to maintain backward compatibility\n", + " warn(_exp_warn_msg(cls))\n", + "/Users/rly/Documents/NWB/hdmf/src/hdmf/utils.py:668: UserWarning: Column 'samples' is predefined in ResultsTable with table=True which does not match the entered table argument. The predefined table spec will be ignored. Please ensure the new column complies with the spec. 
This will raise an error in a future version of HDMF.\n", + " return func(args[0], **pargs)\n", + "/Users/rly/Documents/NWB/hdmf/src/hdmf/utils.py:668: UserWarning: Column 'samples' is predefined in ResultsTable with class= which does not match the entered col_cls argument. The predefined class spec will be ignored. Please ensure the new column complies with the spec. This will raise an error in a future version of HDMF.\n", + " return func(args[0], **pargs)\n" + ] + } + ], + "source": [ + "results_table = ResultsTable(\n", + " name=\"logistic_regression_results\",\n", + " description=\"Results of a simplelogisitic regression on the units table\",\n", + " n_samples=len(units),\n", + ")\n", + "results_table.add_tvt_split(tvt)\n", + "results_table.add_true_label(y)\n", + "results_table.add_predicted_probability(prediction_proba_all)\n", + "results_table.add_predicted_class(predictions_all)\n", + "results_table.add_samples(data=np.arange(len(X)), description=\"all the samples\", table=units)\n", + "print(results_table.to_dataframe())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6282e998-39ec-4458-9b78-fecda2544ede", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[flattened HTML render of the DataFrame elided; it shows row 0 of the units table and duplicates the text/plain repr below]
1 rows × 25 columns
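(the "samples" column is a DynamicTableRegion pointing into the units table, which is why indexing results_table["samples"][0] returns this full 25-column units row)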

" + ], + "text/plain": [ + " spike_times unit_name \\\n", + "id \n", + "0 [0.019766666666666665, 0.028533333333333334, 0... 12 \n", + "\n", + " presence_ratio_standard_deviation contamination noise_cutoff \\\n", + "id \n", + "0 5.258554 0.083082 136.186874 \n", + "\n", + " mean_relative_depth sliding_refractory_period_violation cosmos_location \\\n", + "id \n", + "0 80.0 0.0 MB \n", + "\n", + " maximum_amplitude maximum_amplitude_channel ... median_amplitude \\\n", + "id ... \n", + "0 0.000175 6.0 ... 0.000087 \n", + "\n", + " drift minimum_amplitude spike_count firing_rate \\\n", + "id \n", + "0 68480.697449 0.000082 4319.0 0.995299 \n", + "\n", + " missed_spikes_estimate standard_deviation_amplitude allen_location \\\n", + "id \n", + "0 0.5 0.725555 RN \n", + "\n", + " spike_amplitudes presence_ratio \n", + "id \n", + "0 [8.288184885804525e-05, 0.00011442759068298115... 0.997701 \n", + "\n", + "[1 rows x 25 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results_table[\"samples\"][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "66edd762-d727-425b-ac2c-95b67d786d3f", + "metadata": {}, + "outputs": [], + "source": [ + "# from hdmf.common import get_hdf5io, SimpleMultiContainer\n", + "\n", + "# units.reset_parent()\n", + "# cannot reset container source\n", + "# container = SimpleMultiContainer(name=\"root\", containers=[units, results_table])\n", + "\n", + "# # write to a new file\n", + "# with get_hdf5io(\"results.nwb\", \"w\") as write_io:\n", + "# # TODO allow storage of results in a different file from raw data but maintain DTR link\n", + "# write_io.write(container)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5653415e-fdd4-455d-a75a-7ce51f7b4aad", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + "

[flattened HTML repr elided: "logistic_regression_results (ResultsTable)", description "Results of a simple logistic regression on the units table", followed by the first four rows and "... and 554 more rows", duplicating the DataFrame printed above]
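(this repr is the value returned by nwbfile.add_analysis(results_table); in the raw cell JSON the source line appears after its outputs. Once exported, the table is read back via read_nwbfile.analysis["logistic_regression_results"], as example.py demonstrates.)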

" + ], + "text/plain": [ + "logistic_regression_results hdmf_ml.results_table.ResultsTable at 0x13262690576\n", + "Fields:\n", + " colnames: ['tvt_split' 'true_label' 'predicted_probability' 'predicted_class'\n", + " 'samples']\n", + " columns: (\n", + " tvt_split ,\n", + " tvt_split_elements ,\n", + " true_label ,\n", + " predicted_probability ,\n", + " predicted_class ,\n", + " samples \n", + " )\n", + " description: Results of a simplelogisitic regression on the units table\n", + " id: id " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nwbfile.add_analysis(results_table)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f49c5210-b071-4259-b741-2ddb2d2353db", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "results_table.model = \"Bloom v1.3\" # not the actual model, just a placeholder\n", + "# annotate the model with a DOI using HDMF HERD\n", + "herd = HERD()\n", + "herd.add_ref(\n", + " file=nwbfile,\n", + " container=results_table,\n", + " attribute=\"model\",\n", + " key='Bloom v1.3',\n", + " entity_id='doi:10.57967/hf/0003',\n", + " entity_uri='https://doi.org/10.57967/hf/0003'\n", + ")\n", + "herd.to_zip(path='./HERD.zip')\n", + "\n", + "for x in list(nwbfile.acquisition.keys()):\n", + " nwbfile.acquisition.pop(x)\n", + "\n", + "with NWBHDF5IO(\"results.nwb\", \"w\") as export_io:\n", + " # TODO allow storage of results in a different file from raw data but maintain DTR link\n", + " export_io.export(src_io=io, nwbfile=nwbfile)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22fc1189-0070-4040-9078-cc928bafc4ee", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/example.py b/example.py new file mode 100644 index 0000000..427d5fc --- /dev/null +++ b/example.py @@ -0,0 +1,110 @@ +from hdmf_ai import ResultsTable # NOTE: because hdmf_ai modifies the hdmf common namespace, it is important that hdmf_ai is imported before pynwb +from hdmf.common import HERD +import numpy as np +from pynwb import NWBHDF5IO +from sklearn.linear_model import LogisticRegression +from sklearn.preprocessing import LabelEncoder + +filepath = "/Users/rly/Documents/NWB_Data/dandisets/000409/sub-CSHL047/sub-CSHL047_ses-b52182e7-39f6-4914-9717-136db589706e_behavior+ecephys+image.nwb" +io = NWBHDF5IO(filepath, "r") +nwbfile = io.read() + +# the NWB Units table stores information about the sorted single units (putative neurons) after preprocessing +# and spike sorting. each row represents a single unit. this dataset includes many metadata fields (table columns) for +# each unit. +units = nwbfile.units + +# the Units table can be most readiy viewed as a pandas DataFrame +units_df = units.to_dataframe() +print(units_df) +print(f"There are {len(units_df)} units in this dataset.") + +# run a simple classifier on the units data to predict the location (brain area) of the unit based on +# the amplitude, firing rate, spike count, and presence ratio. there are several ways to label the brain +# area of a unit. 
here, we use the label using the coarsest atlas, the Cosmos atlas, which has a total +# of 12 annotation regions. in this dataset, there are 5 unique Cosmos locations. +cosmos_location = units_df["cosmos_location"].to_numpy() +enc = LabelEncoder() +y = np.uint(enc.fit_transform(cosmos_location)) +unique_labels = enc.classes_ +print(f"There are {len(unique_labels)} unique Cosmos locations in this dataset.") + +# split the data into training and test sets +# TODO integrate with sklearn.model_selection.train_test_split +proportion_train = 0.7 +n_train_samples = int(np.round(proportion_train * len(units_df))) +n_test_samples = len(units) - n_train_samples +# train = 0, validate = 1, test = 2 +tvt = np.array([0] * n_train_samples + [2] * n_test_samples) +np.random.shuffle(tvt) + +feature_names = [ + "median_amplitude", + "standard_deviation_amplitude", + "firing_rate", + "spike_count", + "presence_ratio", +] +X = units_df[feature_names] +X_train = X[tvt == 0] +y_train = y[tvt == 0] +X_test = X[tvt == 2] +y_test = y[tvt == 2] + +# run logistic regression +logreg = LogisticRegression() +logreg.fit(X_train, y_train) +predictions_all = logreg.predict(X) +prediction_proba_all = logreg.predict_proba(X) +print(prediction_proba_all) +score = logreg.score(X_test, y_test) +print(f"The logistic regression achieved a score of {score} on the test set!") + +results_table = ResultsTable( + name="logistic_regression_results", + description="Results of a simple logisitic regression on the units table", + n_samples=len(units), +) +results_table.add_tvt_split(tvt) +results_table.add_true_label(cosmos_location) # use the text labels which will become an EnumData with uint encoding +results_table.add_predicted_probability(prediction_proba_all) +results_table.add_predicted_class(predictions_all) +results_table.add_samples(data=np.arange(len(X)), description="all the samples", table=units) +# TODO address len(id) mismatch when adding as first column +# TODO address warnings about mismatch with predefined spec +# TODO demonstrate adding custom column +results_table.add_column(name="custom_metadata", data=np.random.rand(len(X)), description="random data") +print(results_table.to_dataframe()) + +# add the results table to the in-memory NWB file +nwbfile.add_analysis(results_table) + +# store metadata about the model +# NOTE: not the actual model, just a placeholder for demonstration purposes +results_table.pre_trained_model = "Bloom v1.3" +# annotate the model with a DOI using HDMF HERD +herd = HERD() +herd.add_ref( + file=nwbfile, + container=results_table, + attribute="pre_trained_model", + key=results_table.pre_trained_model, + entity_id='doi:10.57967/hf/0003', + entity_uri='https://doi.org/10.57967/hf/0003' +) +herd.to_zip(path='./HERD.zip') + +# remove all the voltage recording raw data which is not needed for this analysis and takes up a lot of space +for x in list(nwbfile.acquisition.keys()): + nwbfile.acquisition.pop(x) + +with NWBHDF5IO("results.nwb", "w") as export_io: + # TODO allow storage of results in a different file from input data but maintain DTR link + export_io.export(src_io=io, nwbfile=nwbfile) + +io.close() + +with NWBHDF5IO("results.nwb", "r") as read_io: + read_nwbfile = read_io.read() + print(read_nwbfile.analysis["logistic_regression_results"].to_dataframe()) + # TODO check values From e34d812bf9f14cd70f65ece3a2c4acb964ddacce Mon Sep 17 00:00:00 2001 From: rly Date: Wed, 10 Apr 2024 02:56:40 -0700 Subject: [PATCH 2/2] Update ignore --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff 
--git a/pyproject.toml b/pyproject.toml
index 515ec96..1f9e4fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -103,6 +103,7 @@ exclude = [
 
 [tool.ruff.lint.per-file-ignores]
 "src/hdmf_ai/__init__.py" = ["E402", "F401"]
+"example.py" = ["E501", "T201"]
 
 [tool.ruff.lint.mccabe]
 max-complexity = 17
\ No newline at end of file
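
Follow-up sketches for two of the TODOs above (not part of the patch; variable names are assumed to match example.py).

First, the "TODO integrate with sklearn.model_selection.train_test_split" item: a minimal sketch that builds the same 0/2-coded vector that add_tvt_split expects (train = 0, validate = 1, test = 2), while letting sklearn handle shuffling and, optionally, stratification:

    import numpy as np
    from sklearn.model_selection import train_test_split

    idx_train, idx_test = train_test_split(
        np.arange(len(units_df)),  # split row indices, not the feature rows themselves
        train_size=0.7,
        stratify=y,       # assumes every class has at least 2 members; drop otherwise
        random_state=0,   # reproducible split
    )
    tvt = np.zeros(len(units_df), dtype=np.uint8)
    tvt[idx_test] = 2

Second, the "TODO check values" item: the HERD annotation written to HERD.zip can be reloaded and inspected next to the exported table, assuming HERD.from_zip and HERD.to_dataframe, the counterparts of the to_zip call used above:

    from hdmf.common import HERD

    reloaded = HERD.from_zip(path="./HERD.zip")
    print(reloaded.to_dataframe())  # should include doi:10.57967/hf/0003 for the model key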