From 04358d690f5483f5522255e28d77efa51bbd936e Mon Sep 17 00:00:00 2001 From: patel-zeel Date: Sun, 26 Nov 2023 20:40:16 +0530 Subject: [PATCH] add torch dataloaders notebook --- .gitignore | 8 +- posts/2023-11-26-Torch-DataLoaders.ipynb | 1196 ++++++++++++++++++++++ 2 files changed, 1203 insertions(+), 1 deletion(-) create mode 100644 posts/2023-11-26-Torch-DataLoaders.ipynb diff --git a/.gitignore b/.gitignore index 11a4213..7a86fc8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,9 @@ /.quarto/ _site/ -posts/logs \ No newline at end of file +posts/logs + +data/ +*.pdf +*.xlsx +launch.json +settings.json \ No newline at end of file diff --git a/posts/2023-11-26-Torch-DataLoaders.ipynb b/posts/2023-11-26-Torch-DataLoaders.ipynb new file mode 100644 index 0000000..214e58a --- /dev/null +++ b/posts/2023-11-26-Torch-DataLoaders.ipynb @@ -0,0 +1,1196 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "author: Zeel B Patel\n", + "badges: true\n", + "categories: ML\n", + "description: An exploratory analysis of various dataset handling processes to optimize memory, diskspace and speed.\n", + "title: Data Handling for Large Scale ML\n", + "date: '2023-09-30'\n", + "toc: true\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from numcodecs import GZip, Zstd, Blosc\n", + "\n", + "from time import time, sleep\n", + "from tqdm import tqdm\n", + "from glob import glob\n", + "from os.path import join\n", + "from torch.utils.data import DataLoader, Dataset\n", + "from joblib import Parallel, delayed\n", + "import xarray as xr\n", + "import numpy as np\n", + "\n", + "from torchvision.models import vit_b_16\n", + "from astra.torch.models import ViTClassifier\n", + "from astra.torch.utils import train_fn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating Custom Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset>\n",
+       "Dimensions:  (channel: 3, col: 224, lat_lag: 5, lon_lag: 5, row: 224)\n",
+       "Coordinates:\n",
+       "  * channel  (channel) uint8 0 1 2\n",
+       "  * col      (col) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223\n",
+       "    lat      float64 ...\n",
+       "  * lat_lag  (lat_lag) int8 -2 -1 0 1 2\n",
+       "    lon      float64 ...\n",
+       "  * lon_lag  (lon_lag) int8 -2 -1 0 1 2\n",
+       "  * row      (row) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223\n",
+       "Data variables:\n",
+       "    data     (lat_lag, lon_lag, row, col, channel) uint8 dask.array<chunksize=(3, 3, 112, 112, 3), meta=np.ndarray>\n",
+       "    label    (lat_lag, lon_lag) int8 dask.array<chunksize=(5, 5), meta=np.ndarray>
" + ], + "text/plain": [ + "\n", + "Dimensions: (channel: 3, col: 224, lat_lag: 5, lon_lag: 5, row: 224)\n", + "Coordinates:\n", + " * channel (channel) uint8 0 1 2\n", + " * col (col) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223\n", + " lat float64 ...\n", + " * lat_lag (lat_lag) int8 -2 -1 0 1 2\n", + " lon float64 ...\n", + " * lon_lag (lon_lag) int8 -2 -1 0 1 2\n", + " * row (row) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223\n", + "Data variables:\n", + " data (lat_lag, lon_lag, row, col, channel) uint8 dask.array\n", + " label (lat_lag, lon_lag) int8 dask.array" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "base_path = \"/home/patel_zeel/bkdb/bangladesh_pnas_pred/team1\"\n", + "xr.open_zarr(join(base_path, \"21.11,92.18.zarr\"), consolidated=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class XarrayDataset(Dataset):\n", + " def __init__(self, path, max_files):\n", + " self.base_path = path\n", + " self.all_files = glob(join(path, \"*.zarr\"))[:max_files]\n", + " self.all_files.sort()\n", + " self.lat_lags = [-2, -1, 0, 1, 2]\n", + " self.lon_lags = [-2, -1, 0, 1, 2]\n", + " \n", + " def __len__(self):\n", + " return len(self.all_files) * 25\n", + " \n", + " def __getitem__(self, idx):\n", + " file_idx = idx // 25\n", + " local_idx = idx % 25\n", + " lat_lag = self.lat_lags[local_idx // 5]\n", + " lon_lag = self.lon_lags[local_idx % 5]\n", + " \n", + " with xr.open_zarr(self.all_files[file_idx], consolidated=False) as ds:\n", + " img = ds.isel(lat_lag=lat_lag, lon_lag=lon_lag)['data']\n", + " # swap dims to make it [\"channel\", \"row\", \"col\"]\n", + " img = img.transpose(\"channel\", \"row\", \"col\").values\n", + " return img.astype(np.float32) / 255" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def process_it(dataset, batch_size, num_workers):\n", + " dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, pin_memory_device='cuda', prefetch_factor=num_workers//2)\n", + "\n", + " model = ViTClassifier(vit_b_16, None, 2).to('cuda')\n", + " optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "\n", + " pbar = tqdm(dataloader)\n", + "\n", + " train_init = time()\n", + " iter_times = []\n", + " for batch in pbar:\n", + " init = time()\n", + " optimizer.zero_grad()\n", + " out = model(batch.to('cuda'))\n", + " loss = nn.CrossEntropyLoss()(out, torch.randint(0, 2, (batch.shape[0],)).to('cuda'))\n", + " loss.backward()\n", + " optimizer.step()\n", + " time_taken = time() - init\n", + " pbar.set_description(f\"Time: {time_taken:.4f}\")\n", + " iter_times.append(time_taken)\n", + " \n", + " total_time = time() - train_init\n", + " print(f\"Average Iteration Processing Time: {np.mean(iter_times):.4f} +- {np.std(iter_times):.4f}\")\n", + " print(f\"Total time for all iterations: {np.sum(iter_times):.4f}\")\n", + " print(f\"Total Wall Time per iteration: {total_time / len(dataloader):.4f}\")\n", + " print(f\"Total Wall Time: {total_time:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Global config" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "max_files = 500" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + "Time: 1.5727: 100%|██████████| 49/49 [01:27<00:00, 1.78s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Iteration Processing Time: 1.6474 +- 0.2618\n", + "Total time for all iterations: 80.7246\n", + "Total Wall Time per iteration: 1.7799\n", + "Total Wall Time: 87.2134\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "batch_size = 256\n", + "num_workers = 32\n", + "\n", + "dataset = XarrayDataset(base_path, max_files=max_files)\n", + "process_it(dataset, batch_size, num_workers)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Time: 2.6731: 100%|██████████| 25/25 [01:32<00:00, 3.69s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Iteration Processing Time: 3.1956 +- 0.3949\n", + "Total time for all iterations: 79.8897\n", + "Total Wall Time per iteration: 3.6910\n", + "Total Wall Time: 92.2762\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "batch_size = 512\n", + "num_workers = 16\n", + "\n", + "dataset = XarrayDataset(base_path, max_files=max_files)\n", + "process_it(dataset, batch_size, num_workers)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Time: 2.6726: 100%|██████████| 25/25 [01:32<00:00, 3.69s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Iteration Processing Time: 3.1938 +- 0.4043\n", + "Total time for all iterations: 79.8451\n", + "Total Wall Time per iteration: 3.6908\n", + "Total Wall Time: 92.2689\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "batch_size = 512\n", + "num_workers = 32\n", + "\n", + "dataset = XarrayDataset(base_path, max_files=max_files)\n", + "process_it(dataset, batch_size, num_workers)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Time: 0.8377: 9%|▉ | 9/98 [00:11<01:19, 1.12it/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Time: 0.7455: 100%|██████████| 98/98 [01:25<00:00, 1.15it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Iteration Processing Time: 0.8269 +- 0.0551\n", + "Total time for all iterations: 81.0315\n", + "Total Wall Time per iteration: 0.8716\n", + "Total Wall Time: 85.4156\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "batch_size = 128\n", + "num_workers = 32\n", + "\n", + "dataset = XarrayDataset(base_path, max_files=max_files)\n", + "process_it(dataset, batch_size, num_workers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Is .nc better than zarr?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.8G\t/home/patel_zeel/bkdb/bangladesh_pnas_pred/team1\n" + ] + }, + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "os.system(f\"du -sh {base_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1501 [00:00chw\", img).astype(np.float32) / 255)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "nc_path = \"/tmp/nc_check_compressed\"" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 500/500 [00:02<00:00, 246.27it/s]\n", + "Time: 0.7414: 100%|██████████| 98/98 [01:25<00:00, 1.15it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average Iteration Processing Time: 0.8260 +- 0.0530\n", + "Total time for all iterations: 80.9527\n", + "Total Wall Time per iteration: 0.8725\n", + "Total Wall Time: 85.5034\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "batch_size = 128\n", + "num_workers = 32\n", + "\n", + "dataset = XarrayDatasetWithNC(nc_path, max_files=max_files)\n", + "process_it(dataset, batch_size, num_workers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Additional experiments" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time to process 60000 images: 6.793048000000001 minutes\n" + ] + } + ], + "source": [ + "n_images = 60000\n", + "t = 84.9131/500/25 * n_images\n", + "print(f\"Time to process {n_images} images: \", t/60, \"minutes\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1501/1501 [02:44<00:00, 9.13it/s]\n" + ] + } + ], + "source": [ + "files = glob(join(base_path, \"*.zarr\"))\n", + "data_tensors = []\n", + "for file in tqdm(files):\n", + " with xr.open_zarr(file, consolidated=False) as ds:\n", + " # print(ds['data'].values.reshape(-1, 224, 224, 3))\n", + " data_tensors.append(torch.tensor(np.einsum(\"nhwc->nchw\", ds['data'].values.reshape(-1, 224, 224, 3)).astype(np.float16) / 255))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([37525, 3, 224, 224])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_in_one = torch.concat(data_tensors, dim=0)\n", + "all_in_one.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "all_in_one = all_in_one.to('cuda')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Insights" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* GPU Memory consumption is `17776MiB / 81920MiB` for batch size 128 for ViT model\n", + "* Uploading torch.Size([37525, 3, 224, 224]) of float32 data to GPU takes `22054MiB / 81920MiB` of GPU Memory. 
Same data with float16 takes `11202MiB / 81920MiB` of GPU Memory.\n", + "* It seems the choice between `.nc` and `.zarr` does not make much difference in terms of time or memory." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}
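
As a quick sanity check on the GPU-memory insight above, the nvidia-smi readings quoted in the notebook line up with the raw size of the torch.Size([37525, 3, 224, 224]) tensor. Below is a minimal sketch of that arithmetic; the shape and the two MiB readings are taken from the notebook output, while attributing the leftover few hundred MiB to the CUDA context and allocator is an assumption.

    import numpy as np

    # Shape of all_in_one as reported in the notebook
    shape = (37525, 3, 224, 224)
    n_elements = int(np.prod(shape, dtype=np.int64))

    # nvidia-smi readings quoted in the Insights cell
    for dtype, observed_mib in [(np.float32, 22054), (np.float16, 11202)]:
        tensor_mib = n_elements * np.dtype(dtype).itemsize / 1024**2
        print(f"{np.dtype(dtype).name}: tensor alone ~{tensor_mib:.0f} MiB, "
              f"observed {observed_mib} MiB "
              f"(~{observed_mib - tensor_mib:.0f} MiB presumably CUDA context/allocator overhead)")

By the same arithmetic, a single float32 batch of 128 images is only about 74 MiB, so the 17776MiB consumed while training the ViT at batch size 128 is presumably dominated by weights, activations, and optimizer state rather than the input batch itself.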