diff --git a/docs/_config.yml b/docs/_config.yml index 1cb5dc24..c0711748 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -12,6 +12,7 @@ execute: execute_notebooks: cache exclude_patterns: - clay-v0-*.ipynb + - partial-inputs.ipynb # Define the name of the latex output file for PDF builds latex: diff --git a/docs/_toc.yml b/docs/_toc.yml index 299721a6..e05334c7 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -40,6 +40,8 @@ parts: file: clay-v0-location-embeddings - title: Interpolating images in embedding space file: clay-v0-interpolation + - title: Generate embeddings from partial inputs + file: partial-inputs - caption: About Clay chapters: - title: GitHub diff --git a/docs/partial-inputs.ipynb b/docs/partial-inputs.ipynb new file mode 100644 index 00000000..54cf3bc3 --- /dev/null +++ b/docs/partial-inputs.ipynb @@ -0,0 +1,451 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8dd35554-a9dd-49cf-b9fa-24fa8ae6cecf", + "metadata": {}, + "source": [ + "# Burn scar analysis using embeddings from partial inputs\n", + "This notebook contains a complete example for how to run Clay. It\n", + "combines the following three different aspects\n", + "\n", + "1. Create single-chip datacubes with time series data for a location and a date range\n", + "2. Run the model with partial inputs, in this case RGB + NIR\n", + "3. 
Study burn scars through the embeddings generated for that datacube\n", + "\n", + "## Let's start with importing and creating constants" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "b7bcff1e-bdb5-47f8-aa0e-d68d6fdd3476", + "metadata": {}, + "outputs": [], + "source": [ + "# Ensure working directory is the repo home\n", + "import os\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "15d65ec9-86aa-4275-89ba-ec79fdbad361", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "from pathlib import Path\n", + "\n", + "import geopandas as gpd\n", + "import matplotlib.pyplot as plt\n", + "import numpy\n", + "import pandas as pd\n", + "import pystac_client\n", + "import rasterio\n", + "import rioxarray # noqa: F401\n", + "import stackstac\n", + "import torch\n", + "from rasterio.enums import Resampling\n", + "from shapely import Point\n", + "from sklearn import decomposition\n", + "\n", + "from src.datamodule import ClayDataModule\n", + "from src.model_clay import CLAYModule\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "BAND_GROUPS = {\n", + " \"rgb\": [\"red\", \"green\", \"blue\"],\n", + " \"rededge\": [\"rededge1\", \"rededge2\", \"rededge3\", \"nir08\"],\n", + " \"nir\": [\n", + " \"nir\",\n", + " ],\n", + " \"swir\": [\"swir16\", \"swir22\"],\n", + " \"sar\": [\"vv\", \"vh\"],\n", + "}\n", + "\n", + "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", + "COLLECTION = \"sentinel-2-l2a\"" + ] + }, + { + "cell_type": "markdown", + "id": "a6341305-9c44-4a1e-847c-80d77b01c0bf", + "metadata": {}, + "source": [ + "## Search for imagery over an area of interest\n", + "In this example we use a location and date range to visualize a forest fire that happened in [Monchique in 2018](https://pt.wikipedia.org/wiki/Inc%C3%AAndio_de_Monchique_de_2018)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a1886f5a-8669-40e7-8fae-e45619570e3c", +
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 12 items\n" + ] + } + ], + "source": [ + "# Point over Monchique Portugal\n", + "poi = 37.30939, -8.57207\n", + "\n", + "# Dates of a large forest fire\n", + "start = \"2018-07-01\"\n", + "end = \"2018-09-01\"\n", + "\n", + "catalog = pystac_client.Client.open(STAC_API)\n", + "\n", + "search = catalog.search(\n", + " collections=[COLLECTION],\n", + " datetime=f\"{start}/{end}\",\n", + " bbox=(poi[1] - 1e-5, poi[0] - 1e-5, poi[1] + 1e-5, poi[0] + 1e-5),\n", + " max_items=100,\n", + " query={\"eo:cloud_cover\": {\"lt\": 80}},\n", + ")\n", + "\n", + "items = search.get_all_items()\n", + "\n", + "print(f\"Found {len(items)} items\")" + ] + }, + { + "cell_type": "markdown", + "id": "c4ba5c36-90a6-427c-80c5-2a83ad11a1b0", + "metadata": {}, + "source": [ + "## Download the data\n", + "Get the data into a numpy array and visualize the imagery. The burn scar is visible in the last five images." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c371501c-3ef0-4507-9073-0521a1c733be", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract coordinate system from first item\n", + "epsg = items[0].properties[\"proj:epsg\"]\n", + "\n", + "# Convert point into the image projection\n", + "poidf = gpd.GeoDataFrame(\n", + " pd.DataFrame(),\n", + " crs=\"EPSG:4326\",\n", + " geometry=[Point(poi[1], poi[0])],\n", + ").to_crs(epsg)\n", + "\n", + "coords = poidf.iloc[0].geometry.coords[0]\n", + "\n", + "# Create bounds of the correct size, the model\n", + "# requires 512x512 pixels at 10m resolution.\n", + "bounds = (\n", + " coords[0] - 2560,\n", + " coords[1] - 2560,\n", + " coords[0] + 2560,\n", + " coords[1] + 2560,\n", + ")\n", + "\n", + "# Retrieve the pixel values, for the bounding box in\n", + "# the target projection. 
In this example we use only\n", + "# the RGB and NIR band groups.\n", + "stack = stackstac.stack(\n", + " items,\n", + " bounds=bounds,\n", + " snap_bounds=False,\n", + " epsg=epsg,\n", + " resolution=10,\n", + " dtype=\"float32\",\n", + " rescale=False,\n", + " fill_value=0,\n", + " assets=BAND_GROUPS[\"rgb\"] + BAND_GROUPS[\"nir\"],\n", + " resampling=Resampling.nearest,\n", + ")\n", + "\n", + "stack = stack.compute()\n", + "\n", + "stack.sel(band=[\"red\", \"green\", \"blue\"]).plot.imshow(\n", + " row=\"time\", rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ce633fb1-fc82-4c88-8204-cda47aa9c874", + "metadata": {}, + "source": [ + "![Minicube visualization](https://github.com/Clay-foundation/model/assets/901647/c6e924e5-6ba1-4924-b99a-df8b90731a5f)" + ] + }, + { + "cell_type": "markdown", + "id": "77e7c22c-1bfd-4281-bb12-8330c3eedc25", + "metadata": {}, + "source": [ + "## Write data to tif files\n", + "To use the mini datacube in the Clay dataloader, we need to write the\n", + "images to tif files on disk. These tif files are then used by the Clay\n", + "data loader for creating embeddings below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6509c3b2-a67c-447d-a7a1-e5fbcc1e35b5", + "metadata": {}, + "outputs": [], + "source": [ + "outdir = Path(\"data/minicubes\")\n", + "assert outdir.exists()\n", + "\n", + "# Write tile to output dir\n", + "for tile in stack:\n", + " # Grid code like MGRS-29SNB\n", + " mgrs = str(tile.coords[\"grid:code\"].values).split(\"-\")[1]\n", + " date = str(tile.time.values)[:10]\n", + "\n", + " name = \"{dir}/claytile_{mgrs}_{date}.tif\".format(\n", + " dir=outdir,\n", + " mgrs=mgrs,\n", + " date=date.replace(\"-\", \"\"),\n", + " )\n", + " tile.rio.to_raster(name, compress=\"deflate\")\n", + "\n", + " with rasterio.open(name, \"r+\") as rst:\n", + " rst.update_tags(date=date)" + ] + }, + { + "cell_type": "markdown", + "id": "ebc4b6ee-db58-4005-9689-a7d0acdc6a79", + "metadata": { + "scrolled": true + }, + "source": [ + "## Create embeddings\n", + "Now switch gears and load the tiles to create embeddings and analyze them. \n", + "\n", + "The model checkpoint can be loaded directly from huggingface, and the data\n", + "directory points to the directory we created in the steps above.\n", + "\n", + "Note that the normalization parameters for the data module need to be \n", + "adapted based on the band groups that were selected as partial input. The\n", + "full set of normalization parameters can be found [here](https://github.com/Clay-foundation/model/blob/main/src/datamodule.py#L108)." 
+ ] + }, + { + "cell_type": "markdown", + "id": "d89e0135-9473-4f76-9f09-e4e295dd51c9", + "metadata": {}, + "source": [ + "### Load the model and set up the data module" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "301ee2db-c5fc-4628-b837-12e6ea477415", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of chips: 12\n" + ] + } + ], + "source": [ + "DATA_DIR = \"data/minicubes\"\n", + "CKPT_PATH = \"https://huggingface.co/made-with-clay/Clay1/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt\"\n", + "\n", + "# Load model\n", + "rgb_model = CLAYModule.load_from_checkpoint(\n", + " CKPT_PATH,\n", + " mask_ratio=0.0,\n", + " band_groups={\"rgb\": (2, 1, 0), \"nir\": (3,)},\n", + " bands=4,\n", + " strict=False, # ignore the extra parameters in the checkpoint\n", + ")\n", + "# Set the model to evaluation mode\n", + "rgb_model.eval()\n", + "\n", + "\n", + "# Load the datamodule, with the reduced set of normalization parameters (RGB + NIR)\n", + "class ClayDataModuleRGB(ClayDataModule):\n", + " MEAN = [\n", + " 1369.03, # red\n", + " 1597.68, # green\n", + " 1741.10, # blue\n", + " 2858.43, # nir\n", + " ]\n", + " STD = [\n", + " 2026.96, # red\n", + " 2011.88, # green\n", + " 2146.35, # blue\n", + " 2016.38, # nir\n", + " ]\n", + "\n", + "\n", + "data_dir = Path(DATA_DIR)\n", + "\n", + "dm = ClayDataModuleRGB(data_dir=str(data_dir.absolute()), batch_size=20)\n", + "dm.setup(stage=\"predict\")\n", + "trn_dl = iter(dm.predict_dataloader())" + ] + }, + { + "cell_type": "markdown", + "id": "db3f3e5e-8668-4830-9c77-cc1d8cb35234", + "metadata": {}, + "source": [ + "### Create the embeddings for the images over the forest fire\n", + "This will loop through the images returned by the data loader\n", + "and evaluate the model for each one of the images. The raw\n", + "embeddings are reduced to mean values to simplify the data."
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c5762240-9d22-4ebd-8e39-83fc6594a459", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average embeddings have shape (12, 768)\n" + ] + } + ], + "source": [ + "embeddings = []\n", + "\n", + "for batch in trn_dl:\n", + " with torch.inference_mode():\n", + " # Move data to the device of the model\n", + " batch[\"pixels\"] = batch[\"pixels\"].to(rgb_model.device)\n", + " # Pass just the specific band through the model\n", + " batch[\"timestep\"] = batch[\"timestep\"].to(rgb_model.device)\n", + " batch[\"latlon\"] = batch[\"latlon\"].to(rgb_model.device)\n", + "\n", + " # Pass pixels, latlon, timestep through the encoder to create encoded patches\n", + " (\n", + " unmasked_patches,\n", + " unmasked_indices,\n", + " masked_indices,\n", + " masked_matrix,\n", + " ) = rgb_model.model.encoder(batch)\n", + "\n", + " embeddings.append(unmasked_patches.detach().cpu().numpy())\n", + "\n", + "embeddings = numpy.vstack(embeddings)\n", + "\n", + "embeddings_mean = embeddings[:, :-2, :].mean(axis=1)\n", + "\n", + "print(f\"Average embeddings have shape {embeddings_mean.shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "72db5745-21c6-4b8e-b8f7-cb48c0f9c9ef", + "metadata": {}, + "source": [ + "## Analyze embeddings\n", + "Now we can make a simple analysis of the embeddings. We reduce all the\n", + "embeddings to a single number using Principal Component Analysis. Then\n", + "we can plot the principal components. The effect of the fire on the\n", + "embeddings is clearly visible.
We use the following color code in the graph:\n", + "\n", + "| Color | Interpretation |\n", + "|---|---|\n", + "| Green | Cloudy Images |\n", + "| Blue | Before the fire |\n", + "| Red | After the fire |" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "88f3b2dc-8f2a-447b-a6af-b04e0d1ff61c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pca = decomposition.PCA(n_components=1)\n", + "pca_result = pca.fit_transform(embeddings_mean)\n", + "\n", + "plt.xticks(rotation=-30)\n", + "# All points\n", + "plt.scatter(stack.time, pca_result, color=\"blue\")\n", + "\n", + "# Cloudy images\n", + "plt.scatter(stack.time[0], pca_result[0], color=\"green\")\n", + "plt.scatter(stack.time[2], pca_result[2], color=\"green\")\n", + "\n", + "# After fire\n", + "plt.scatter(stack.time[-5:], pca_result[-5:], color=\"red\")" + ] + }, + { + "cell_type": "markdown", + "id": "a16fbdb8-1c2d-4c84-8526-283fa14faa53", + "metadata": {}, + "source": [ + "In the plot above, each image embedding is one point. One can clearly \n", + "distinguish the two cloudy images and the values after the fire are\n", + "consistently low." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/datamodule.py b/src/datamodule.py index 8a7d6d5c..e653bb49 100644 --- a/src/datamodule.py +++ b/src/datamodule.py @@ -157,7 +157,7 @@ def setup(self, stage: Literal["fit", "predict"] | None = None) -> None: dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.data_dir]) chips_path = list(dp.list_files_by_s3(masks="*.tif")) else: # if self.data_dir is a local data path - chips_path = list(Path(self.data_dir).glob("**/*.tif")) + chips_path = sorted(list(Path(self.data_dir).glob("**/*.tif"))) print(f"Total number of chips: {len(chips_path)}") if stage == "fit": diff --git a/src/model_clay.py b/src/model_clay.py index cffea57e..7b9dc134 100644 --- a/src/model_clay.py +++ 
b/src/model_clay.py @@ -791,6 +791,7 @@ def __init__( # noqa: PLR0913 b1=0.9, b2=0.95, embeddings_level: Literal["mean", "patch", "group"] = "mean", + band_groups=None, ): super().__init__() self.save_hyperparameters(logger=True) @@ -801,11 +802,14 @@ def __init__( # noqa: PLR0913 "large": clay_large, } if model_size in model_map: - self.model = model_map[model_size]( - mask_ratio=mask_ratio, - image_size=image_size, - patch_size=patch_size, - ) + model_args = { + "mask_ratio": mask_ratio, + "image_size": image_size, + "patch_size": patch_size, + } + if band_groups: + model_args["band_groups"] = band_groups + self.model = model_map[model_size](**model_args) else: raise ValueError( f"Invalid model size {model_size}. Expected one of {model_map.keys()}"