Skip to content

Commit

Permalink
Merge pull request #87 from kleok/dev
Browse files Browse the repository at this point in the history
CPU support, visualization, vector results
  • Loading branch information
kleok authored Sep 16, 2024
2 parents e8de3ae + f5b5fc7 commit 8615f8f
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 258 deletions.
307 changes: 56 additions & 251 deletions Floodpyapp_Vit.ipynb

Large diffs are not rendered by default.

33 changes: 29 additions & 4 deletions floodpy/FLOODPYapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@


from floodpy.utils.read_AOI import Coords_to_geojson, Input_vector_to_geojson
from floodpy.utils.geo_utils import create_polygon
from floodpy.utils.geo_utils import create_polygon, convert_to_vector
from floodpy.Download.Query_Sentinel_1_products import query_Sentinel_1
from floodpy.Download.Sentinel_1_download import download_S1_data
from floodpy.Download.Sentinel_1_orbits_download import download_S1_POEORB_orbits
from floodpy.Preprocessing_S1_data.DEM_funcs import calc_slope_mask

# Visualization
from floodpy.Visualization.plot_ERA5_data import plot_ERA5
from floodpy.Visualization.interactive_plotting import plot_interactive_map

# Processing
from floodpy.Download.Download_ERA5_precipitation import Get_ERA5_data
from floodpy.Download.Download_LandCover import worldcover
from floodpy.Preprocessing_S1_data.Preprocessing_S1_data import Run_Preprocessing
Expand Down Expand Up @@ -188,7 +190,30 @@ def calc_floodmap_dataset(self):
self.Flood_map_dataset_filename = os.path.join(self.Results_dir, 'Flood_map_dataset_{}.nc'.format(self.flood_datetime_str))
Calc_flood_map(self)

def calc_flooded_regions_ViT(self, ViT_model_filename):

def calc_flooded_regions_ViT(self, ViT_model_filename, device = 'cuda', generate_vector = True, overwrite = True):
assert device in ['cuda', 'cpu'], 'device parameter must be cuda or cpu'

self.Flood_map_dataset_filename = os.path.join(self.Results_dir, 'Flood_map_ViT_{}.nc'.format(self.flood_datetime_str))
predict_flooded_regions(self, ViT_model_filename)
self.Flood_map_vector_dataset_filename = os.path.join(self.Results_dir, 'Flood_map_ViT_{}.geojson'.format(self.flood_datetime_str))

if os.path.exists(self.Flood_map_dataset_filename):
if overwrite:
os.remove(self.Flood_map_dataset_filename)
predict_flooded_regions(self, ViT_model_filename, device)
else:
predict_flooded_regions(self, ViT_model_filename, device)

if generate_vector:
if os.path.exists(self.Flood_map_vector_dataset_filename):
if overwrite:
os.remove(self.Flood_map_vector_dataset_filename)
convert_to_vector(self)
else:
convert_to_vector(self)

def plot_flood_map(self):
self.interactive_map = plot_interactive_map(self)
return self.interactive_map



Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ def save_to_netcdf(S1_dataset, prediction_data, flooded_region_filename, flooded
Flooded_xarray.rio.write_crs("epsg:4326", inplace=True)
Flooded_xarray.to_netcdf(flooded_region_filename, format='NETCDF4')

def predict_flooded_regions(Floodpy_app, ViT_model_filename ):
def predict_flooded_regions(Floodpy_app, ViT_model_filename, device):

# loading pretrained ViT model
sys.path.insert(0, Floodpy_app.src)
vit_model = torch.load(ViT_model_filename)
device = 'cuda'
vit_model.to(device)

batch_size = 224
Expand Down Expand Up @@ -121,3 +120,4 @@ def predict_flooded_regions(Floodpy_app, ViT_model_filename ):




115 changes: 115 additions & 0 deletions floodpy/Visualization/interactive_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import geopandas as gpd
import rasterio as rio
import rasterio.mask
import os
import pandas as pd
import xarray as xr
import numpy as np

# plotting functionalities
import matplotlib.pyplot as plt
import folium
import matplotlib
from branca.element import Template, MacroElement
import branca.colormap as cm
from folium.plugins import MeasureControl, Draw
from xyzservices.lib import TileProvider

# FLOODPY libraries
from floodpy.utils.folium_categorical_legend import get_folium_categorical_template

def plot_interactive_map(Floodpy_app):

# Read AOI
aoi = gpd.read_file(Floodpy_app.geojson_bbox)

# AOI bounds
left, bottom, right, top = aoi.total_bounds

# Define map bounds
map_bounds = [[bottom, left], [top, right]]

# Create a map located to the AOI
m = folium.Map(location=[aoi.centroid.y[0], aoi.centroid.x[0]], tiles="openstreetmap", zoom_start=13)

folium.TileLayer("openstreetmap").add_to(m)
folium.TileLayer('cartodbdark_matter').add_to(m)

# measuring funcs
MeasureControl('bottomleft').add_to(m)

# drawing funcs
draw = Draw(export = True,
filename=os.path.join(Floodpy_app.projectfolder,'myJson.json'),
position='topleft').add_to(m)

# add geojson AOI
folium.GeoJson(aoi["geometry"],
show = False,
name='Area of Interest').add_to(m)

#------------------------------------------------------------------------------
# ESA worldcover

with rio.open(Floodpy_app.lc_mosaic_filename) as src:
LC_cover, out_transform = rasterio.mask.mask(src, aoi.geometry, crop=True)
LC_cover = LC_cover[0,:,:]

LC_map = folium.raster_layers.ImageOverlay(image = LC_cover,
name = 'ESA Worldcover 2021',
opacity = 1,
bounds = map_bounds,
show = False,
colormap = lambda x: Floodpy_app.LC_COLORBAR[x])

m.add_child(LC_map)

legend_categories = {Floodpy_app.LC_CATEGORIES[x]: Floodpy_app.LC_COLORBAR[x] for x in np.unique(LC_cover)}

template = get_folium_categorical_template(legend_categories)
macro = MacroElement()
macro._template = Template(template)
m.get_root().add_child(macro)

#------------------------------------------------------------------------------
# S1 VV backscatter Flood image

S1_stack_dB = xr.open_dataset(Floodpy_app.S1_stack_filename)['VV_dB']
Flood_data = S1_stack_dB.sel(time = pd.to_datetime(Floodpy_app.flood_datetime_str)).values

vmin = np.nanquantile(Flood_data, 0.01)
vmax = np.nanquantile(Flood_data, 0.99)

S1_data = np.clip(Flood_data, vmin, vmax)

cmap = cm.LinearColormap(['black', 'white'],
index=[vmin, vmax],
vmin=vmin, vmax=vmax)

cmap.caption = "Backscatter coefficient VV (dB)"
cmap_func = lambda x: matplotlib.colors.to_rgba(cmap(x)) if ~np.isnan(x) else (0,0,0,0)

folium.raster_layers.ImageOverlay(image = S1_data,
name = "Sentinel-1 ({})".format(Floodpy_app.flood_datetime_str),
opacity = 1,
bounds = map_bounds,
colormap = cmap_func).add_to(m)
m.add_child(cmap)

#------------------------------------------------------------------------------
# Flood binary mask

Flood_map_dataset = xr.open_dataset(Floodpy_app.Flood_map_dataset_filename)
flooded_regions = Flood_map_dataset.flooded_regions.data.astype(np.int32)

raster_to_coloridx = {1: (0.0, 0.0, 1.0, 0.8),
0: (0.0, 0.0, 0.0, 0.0)}

m.add_child(folium.raster_layers.ImageOverlay(image = flooded_regions,
name = 'Flooded Regions {} (UTC)'.format(Floodpy_app.flood_datetime_str),
bounds = map_bounds,
colormap = lambda x: raster_to_coloridx[x]))

folium.LayerControl('bottomleft', collapsed=False).add_to(m)

return m
26 changes: 25 additions & 1 deletion floodpy/utils/geo_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,28 @@
from shapely.geometry import Polygon
import geopandas as gpd
from rasterio.features import shapes
from shapely.geometry import shape
import rasterio
import numpy as np

def create_polygon(coordinates):
return Polygon(coordinates['coordinates'][0])
return Polygon(coordinates['coordinates'][0])


def convert_to_vector(Floodpy_app):
with rasterio.open(Floodpy_app.Flood_map_dataset_filename) as src:
data = src.read(1).astype(np.int16)

# Use a generator instead of a list
shape_gen = ((shape(s), v) for s, v in shapes(data, transform=src.transform))

# either build a pd.DataFrame
# df = DataFrame(shape_gen, columns=['geometry', 'class'])
# gdf = GeoDataFrame(df["class"], geometry=df.geometry, crs=src.crs)

# or build a dict from unpacked shapes
gdf = gpd.GeoDataFrame(dict(zip(["geometry", "flooded_regions"], zip(*shape_gen))), crs=src.crs)
gdf = gdf.loc[gdf.flooded_regions == 1,:]
gdf.datetime = Floodpy_app.flood_datetime_str

gdf.to_file(Floodpy_app.Flood_map_vector_dataset_filename, driver='GeoJSON')

0 comments on commit 8615f8f

Please sign in to comment.