From ffb9d886b7b73a44aed0a90637ae96719e62346f Mon Sep 17 00:00:00 2001 From: lyhue1991 Date: Fri, 10 Feb 2023 01:06:00 +0800 Subject: [PATCH] 3.7.2 --- "1\357\274\214kerasmodel_example.ipynb" | 4472 +---------------------- push2github.md | 18 +- setup.py | 4 +- torchkeras/data.py | 14 + 4 files changed, 59 insertions(+), 4449 deletions(-) create mode 100644 torchkeras/data.py diff --git "a/1\357\274\214kerasmodel_example.ipynb" "b/1\357\274\214kerasmodel_example.ipynb" index 69693dd..659ea9b 100644 --- "a/1\357\274\214kerasmodel_example.ipynb" +++ "b/1\357\274\214kerasmodel_example.ipynb" @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 1, "id": "3e915b9b", "metadata": {}, "outputs": [], @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "id": "a1920bc7", "metadata": {}, "outputs": [], @@ -47,14 +47,6 @@ "import torchkeras #Attention this line \n" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "101cd9c3-5061-450e-a022-df09afed935e", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "id": "e73ca619", @@ -65,113 +57,10 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "25ab4939", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", - "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "383bed3915db4058b97803e6ce7dfb56", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/9912422 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-02-09T19:34:21.192036\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.6.2, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = 'svg'\n", @@ -1047,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "6174ae8a", "metadata": {}, "outputs": [], @@ -1095,40 +168,10 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "ad17716f", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--------------------------------------------------------------------------\n", - "Layer (type) Output Shape Param #\n", - "==========================================================================\n", - "Conv2d-1 [-1, 32, 26, 26] 320\n", - "MaxPool2d-2 [-1, 32, 13, 13] 0\n", - "Conv2d-3 [-1, 64, 9, 9] 51,264\n", - "MaxPool2d-4 [-1, 64, 4, 4] 0\n", - "Dropout2d-5 [-1, 64, 4, 4] 0\n", - "AdaptiveMaxPool2d-6 [-1, 64, 1, 1] 0\n", - "Flatten-7 [-1, 64] 0\n", - "Linear-8 [-1, 32] 2,080\n", - "ReLU-9 [-1, 32] 0\n", - "Linear-10 [-1, 10] 330\n", - "==========================================================================\n", - "Total params: 53,994\n", - "Trainable params: 53,994\n", - "Non-trainable params: 0\n", - "--------------------------------------------------------------------------\n", - "Input size (MB): 0.000069\n", - "Forward/backward pass size (MB): 0.263016\n", - "Params size (MB): 0.205971\n", - "Estimated Total Size (MB): 0.469055\n", - "--------------------------------------------------------------------------\n" - ] - } - ], + "outputs": [], "source": [ "model = torchkeras.KerasModel(net,\n", " loss_fn = nn.CrossEntropyLoss(),\n", @@ -1158,135 +201,10 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "54bcc553", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[0;31m<<<<<< 🐌 cpu is used >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:34:30\n", - "Epoch 1 / 15\n", - "\n", - "100%|████████████████| 469/469 [00:18<00:00, 24.99it/s, lr=0.001, train_acc=0.845, train_loss=0.527]\n", - "100%|████████████████████████████████| 79/79 [00:01<00:00, 46.42it/s, val_acc=0.954, val_loss=0.149]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.9544000029563904 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:34:50\n", - "Epoch 2 / 15\n", - "\n", - "100%|█████████████████| 469/469 [00:18<00:00, 25.71it/s, lr=0.001, train_acc=0.96, train_loss=0.134]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 48.45it/s, val_acc=0.973, val_loss=0.0844]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.9728000164031982 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:35:10\n", - "Epoch 3 / 15\n", - "\n", - "100%|████████████████| 469/469 [00:18<00:00, 25.63it/s, lr=0.001, train_acc=0.97, train_loss=0.0977]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 47.24it/s, val_acc=0.974, val_loss=0.0777]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.974399983882904 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:35:30\n", - "Epoch 4 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 25.35it/s, lr=0.001, train_acc=0.975, train_loss=0.0808]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 46.35it/s, val_acc=0.977, val_loss=0.0656]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.977400004863739 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:35:51\n", - "Epoch 5 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 25.23it/s, lr=0.001, train_acc=0.978, train_loss=0.0702]\n", - "100%|████████████████████████████████| 79/79 [00:01<00:00, 46.67it/s, val_acc=0.98, val_loss=0.0643]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.9797000288963318 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:36:11\n", - "Epoch 6 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 25.24it/s, lr=0.001, train_acc=0.981, train_loss=0.0611]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 47.87it/s, val_acc=0.982, val_loss=0.0554]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.9815000295639038 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:36:31\n", - "Epoch 7 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 25.14it/s, lr=0.001, train_acc=0.983, train_loss=0.0528]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 48.14it/s, val_acc=0.982, val_loss=0.0566]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.9815999865531921 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:36:51\n", - "Epoch 8 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 25.29it/s, lr=0.001, train_acc=0.984, train_loss=0.0482]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 46.61it/s, val_acc=0.983, val_loss=0.0568]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.9825000166893005 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:37:12\n", - "Epoch 9 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 24.91it/s, lr=0.001, train_acc=0.986, train_loss=0.0432]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 46.04it/s, val_acc=0.985, val_loss=0.0444]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.9850999712944031 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:37:32\n", - "Epoch 10 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 25.09it/s, lr=0.001, train_acc=0.987, train_loss=0.0409]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 47.21it/s, val_acc=0.984, val_loss=0.0473]\n", - "\n", - "================================================================================2023-02-09 19:37:52\n", - "Epoch 11 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 25.14it/s, lr=0.001, train_acc=0.988, train_loss=0.0372]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 47.82it/s, val_acc=0.987, val_loss=0.0445]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.9868000149726868 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:38:13\n", - "Epoch 12 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 24.91it/s, lr=0.001, train_acc=0.988, train_loss=0.0343]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 46.46it/s, val_acc=0.985, val_loss=0.0479]\n", - "\n", - "================================================================================2023-02-09 19:38:33\n", - "Epoch 13 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 25.03it/s, lr=0.001, train_acc=0.989, train_loss=0.0335]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 47.64it/s, val_acc=0.988, val_loss=0.0418]\n", - "\u001b[0;31m<<<<<< reach best val_acc : 0.9876000285148621 >>>>>>\u001b[0m\n", - "\n", - "================================================================================2023-02-09 19:38:54\n", - "Epoch 14 / 15\n", - "\n", - "100%|████████████████| 469/469 [00:18<00:00, 24.98it/s, lr=0.001, train_acc=0.99, train_loss=0.0287]\n", - "100%|███████████████████████████████| 79/79 [00:01<00:00, 46.02it/s, val_acc=0.985, val_loss=0.0487]\n", - "\n", - "================================================================================2023-02-09 19:39:14\n", - "Epoch 15 / 15\n", - "\n", - "100%|███████████████| 469/469 [00:18<00:00, 25.03it/s, lr=0.001, train_acc=0.991, train_loss=0.0273]\n", - "100%|████████████████████████████████| 79/79 [00:01<00:00, 46.63it/s, val_acc=0.987, val_loss=0.041]\n", - " epoch train_loss train_acc lr val_loss val_acc\n", - "0 1 0.526868 0.845133 0.001 0.148538 0.9544\n", - "1 2 0.134293 0.959750 0.001 0.084383 0.9728\n", - "2 3 0.097699 0.970250 0.001 0.077715 0.9744\n", - "3 4 0.080804 0.974733 0.001 0.065645 0.9774\n", - "4 5 0.070243 0.977833 0.001 0.064307 0.9797\n", - "5 6 0.061092 0.981017 0.001 0.055448 0.9815\n", - "6 7 0.052803 0.982733 0.001 0.056596 0.9816\n", - "7 8 0.048157 0.984267 0.001 0.056754 0.9825\n", - "8 9 0.043176 0.986050 0.001 0.044400 0.9851\n", - "9 10 0.040888 0.986500 0.001 0.047307 0.9843\n", - "10 11 0.037161 0.987867 0.001 0.044459 0.9868\n", - "11 12 0.034350 0.988250 0.001 0.047916 0.9849\n", - "12 13 0.033464 0.989000 0.001 0.041772 0.9876\n", - "13 14 0.028671 0.990467 0.001 0.048730 0.9850\n", - "14 15 0.027258 0.991050 0.001 0.041001 0.9870\n" - ] - } - ], + "outputs": [], "source": [ "# if gpu/mps is available, will auto use it, otherwise cpu will be used.\n", "dfhistory=model.fit(train_data=dl_train, \n", @@ -1300,853 +218,10 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "1ddc3317", "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-02-09T19:39:40.860267\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.6.2, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "# visual the prediction\n", "device = None\n", @@ -2178,7 +253,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "id": "520c6a1b", "metadata": {}, "outputs": [], @@ -2204,2146 +279,30 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "id": "a179179f-65ff-434e-b317-17ca56355e7e", "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-02-09T19:48:32.362008\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.6.2, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "plot_metric(dfhistory,'loss')" ] }, { "cell_type": "code", - "execution_count": 51, + "execution_count": null, "id": "960119f4", "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-02-09T19:48:17.566941\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.6.2, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "plot_metric(dfhistory,\"acc\")" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "id": "62f8cf7b", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████| 79/79 [00:01<00:00, 46.50it/s, val_acc=0.988, val_loss=0.0418]\n" - ] - }, - { - "data": { - "text/plain": [ - "{'val_loss': 0.041772333724491295, 'val_acc': 0.9876000285148621}" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "model.evaluate(dl_val)" ] @@ -4366,371 +325,10 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "id": "91b21a4e-be1c-4fba-ae99-7fe4962c67b6", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y_pred = 0\n", - "y_prob = 0.9997945427894592\n" - ] - }, - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2023-02-09T19:48:42.484325\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.6.2, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "net = model.net\n", "net.eval();\n", @@ -4758,28 +356,10 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "id": "45089d19", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████| 79/79 [00:01<00:00, 48.51it/s, val_acc=0.988, val_loss=0.0418]\n" - ] - }, - { - "data": { - "text/plain": [ - "{'val_loss': 0.041772333724491295, 'val_acc': 0.9876000285148621}" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# used the saved model parameters \n", "net_clone = create_net() \n", diff --git a/push2github.md b/push2github.md index 64957d6..5bbaa8b 100644 --- a/push2github.md +++ b/push2github.md @@ -2,6 +2,7 @@ ```python %run setup.py sdist bdist_wheel + ``` ```python @@ -58,7 +59,7 @@ ``` ```python -!git push -f origin master +!git push origin master ``` ```python @@ -68,3 +69,18 @@ ```python ``` + +## gitignore + +```python +# %load .gitignore +.DS_store +.ipynb_checkpoints +.ipynb_checkpoints/* +torchkeras.egg-info/* +dist/* +build/* +torchkeras/__pycache__/ +.idea + +``` diff --git a/setup.py b/setup.py index 6f51607..95bcd08 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ # -*- coding:utf-8 -*- import re from pathlib import Path -from setuptools import find_packages, setup +from setuptools import find_packages, setup, find_namespace_packages # Settings FILE = Path(__file__).resolve() @@ -32,7 +32,7 @@ def get_version(): ], long_description_content_type="text/markdown", url="https://github.com/lyhue1991/torchkeras", - packages=find_packages(), + packages=find_namespace_packages(exclude=['torchkeras.assets','data']), include_package_data=True, classifiers=[ "Programming Language :: Python :: 3", diff --git a/torchkeras/data.py b/torchkeras/data.py new file mode 100644 index 0000000..7c0628d --- /dev/null +++ b/torchkeras/data.py @@ -0,0 +1,14 @@ +from pathlib import Path +from PIL import Image +import os + +path = Path(__file__) + +def get_example_image(img_name='park.jpg'): + 'name can be bus.jpg / park.jpg / zidane.jpg' + img_path = str(path.parent/f"assets/{img_name}") + assert os.path.exists(img_path), 'img_name can only be bus.jpg / park.jpg / zidane.jpg' + return Image.open(img_path) + + + \ No newline at end of file