Skip to content

Commit

Permalink
Add city mode
Browse files Browse the repository at this point in the history
  • Loading branch information
JulesL2 committed Sep 13, 2024
1 parent c153362 commit d72253d
Show file tree
Hide file tree
Showing 7 changed files with 446 additions and 39 deletions.
87 changes: 87 additions & 0 deletions pretty_gpx/drawing/drawing_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,90 @@ def draw(self,

self.title.plot(ax, theme_colors.peak_color, self.ref_img_shape, img.shape)
self.stats.plot(ax, theme_colors.background_color, self.ref_img_shape, img.shape)

class CityDrawingFigure:
"""Base Drawing Figure displaying an image with plt.imshow.
paper_size: Expected intrinsic Figure width/height in mm (when saved as a vectorized image)
Drawing Figure displaying annotations on top of an image plotted with plt.imshow.
w_display_pix: Expected screen Figure width in pixels (when displayed on screen)
ref_img_shape: Shape of the reference background image. It defines the scale at which the X-Y coordinates of the
annotations have been saved.
track_data: Drawing Data to plot with the Track Color
peak_data: Drawing Data to plot with the Peak Color
title: Text Drawing Data for the title at the top of the image (Text String can be updated)
stats: Text Drawing Data for the statistics at the bottom of the image (Text String can be updated)
"""
paper_size: PaperSize
latlon_aspect_ratio: float

w_display_pix: int
ref_img_shape: tuple[int, ...]

gpx_segments: list[tuple[list[float],list[float]]]
roads_segments: list[tuple[list[float],list[float]]]

start: BaseDrawingData
end: BaseDrawingData

title: TextData
distance: TextData
duration: TextData

def draw(self,
fig: Figure,
ax: Axes,
theme_colors: ThemeColors,
title_txt: str,
stats_txt: str) -> None:
"""Plot the background image and the annotations on top of it."""
self.plot_segments(fig, ax, img)
self.adjust_display_width(fig, self.w_display_pix)

self.title.s = title_txt
self.stats.s = stats_txt

for data in self.track_data:
data.plot(ax, theme_colors.track_color, self.ref_img_shape, img.shape)

for data in self.peak_data:
data.plot(ax, theme_colors.peak_color, self.ref_img_shape, img.shape)

self.title.plot(ax, theme_colors.peak_color, self.ref_img_shape, img.shape)
self.stats.plot(ax, theme_colors.background_color, self.ref_img_shape, img.shape)

def plot_segments(self,
fig: Figure,
ax: Axes,
img: np.ndarray) -> None:
"""Display the image, with the appropriate size in inches and dpi."""
h, w = img.shape[:2]
ax.cla()
ax.axis('off')
fig.tight_layout(pad=0)
ax.imshow(img)
ax.set_aspect(self.latlon_aspect_ratio)

w_inches = mm_to_inch(self.paper_size.w_mm)
h_inches = mm_to_inch(self.paper_size.h_mm)

margin_inches = mm_to_inch(self.paper_size.margin_mm)
assert_close(self.latlon_aspect_ratio*h * (w_inches-2*margin_inches)/w,
h_inches-2*margin_inches,
eps=mm_to_inch(2.0),
msg="Wrong aspect ratio for image")

fig.set_size_inches(w_inches, h_inches)

fig.subplots_adjust(left=margin_inches/w_inches, right=1-margin_inches/w_inches,
bottom=margin_inches/h_inches, top=1-margin_inches/h_inches)

ax.autoscale(False)

def adjust_display_width(self, fig: Figure, w_display_pix: int) -> None:
"""Adjust height in pixels on the screen of the displayed figure."""
fig.set_dpi(w_display_pix / fig.get_size_inches()[0])
191 changes: 169 additions & 22 deletions pretty_gpx/drawing/poster_image_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@
from pretty_gpx.drawing.drawing_data import PolyFillData
from pretty_gpx.drawing.drawing_data import ScatterData
from pretty_gpx.drawing.drawing_data import TextData
from pretty_gpx.drawing.drawing_figure import CityDrawingFigure
from pretty_gpx.drawing.drawing_figure import DrawingFigure
from pretty_gpx.drawing.drawing_params import DrawingSizeParams
from pretty_gpx.drawing.drawing_params import DrawingStyleParams
from pretty_gpx.drawing.hillshading import CachedHillShading
from pretty_gpx.drawing.text_allocation import allocate_text
from pretty_gpx.drawing.theme_colors import hex_to_rgb
from pretty_gpx.drawing.theme_colors import ThemeColors
from pretty_gpx.gpx.augmented_gpx_data import AugmentedGpxData
from pretty_gpx.gpx.augmented_gpx_data import MountainAugmentedGpxData
from pretty_gpx.gpx.augmented_gpx_data import CityAugmentedGpxData
from pretty_gpx.gpx.elevation_map import download_elevation_map
from pretty_gpx.gpx.elevation_map import rescale_elevation
from pretty_gpx.gpx.city_map import download_city_roads_map
from pretty_gpx.layout.paper_size import PaperSize
from pretty_gpx.layout.vertical_layout import get_bounds
from pretty_gpx.layout.vertical_layout import VerticalLayout
Expand All @@ -35,31 +38,59 @@


@dataclass
class PosterImageCaches:
"""Low and High resolution PosterImageCache."""
low_res: 'PosterImageCache'
high_res: 'PosterImageCache'
class MountainPosterImageCaches:
"""Low and High resolution MountainPosterImageCache."""
low_res: 'MountainPosterImageCache'
high_res: 'MountainPosterImageCache'

gpx_data: AugmentedGpxData
gpx_data: MountainAugmentedGpxData

def __post_init__(self) -> None:
assert self.low_res.dpi < self.high_res.dpi

@staticmethod
def from_gpx(list_gpx_path: str | bytes | list[str] | list[bytes],
paper_size: PaperSize) -> 'PosterImageCaches':
"""Create a PosterImageCaches from a GPX file."""
paper_size: PaperSize) -> 'MountainPosterImageCaches':
"""Create a MountainPosterImageCaches from a GPX file."""
# Extract GPX data and retrieve close mountain passes/huts
gpx_data = AugmentedGpxData.from_path(list_gpx_path)
return PosterImageCaches.from_gpx_data(gpx_data, paper_size)
gpx_data = MountainAugmentedGpxData.from_path(list_gpx_path)
return MountainPosterImageCaches.from_gpx_data(gpx_data, paper_size)

@staticmethod
def from_gpx_data(gpx_data: AugmentedGpxData,
paper_size: PaperSize) -> 'PosterImageCaches':
"""Create a PosterImageCaches from a GPX file."""
high_res = PosterImageCache.from_gpx_data(gpx_data, dpi=HIGH_RES_DPI, paper=paper_size)
def from_gpx_data(gpx_data: MountainAugmentedGpxData,
paper_size: PaperSize) -> 'MountainPosterImageCaches':
"""Create a MountainPosterImageCaches from a GPX file."""
high_res = MountainPosterImageCache.from_gpx_data(gpx_data, dpi=HIGH_RES_DPI, paper=paper_size)
low_res = high_res.change_dpi(WORKING_DPI)
return PosterImageCaches(low_res=low_res, high_res=high_res, gpx_data=gpx_data)
return MountainPosterImageCaches(low_res=low_res, high_res=high_res, gpx_data=gpx_data)


@dataclass
class CityPosterImageCaches:
"""Low and High resolution MountainPosterImageCache."""
low_res: 'CityPosterImageCache'
high_res: 'CityPosterImageCache'

gpx_data: CityAugmentedGpxData

def __post_init__(self) -> None:
assert self.low_res.dpi < self.high_res.dpi

@staticmethod
def from_gpx(list_gpx_path: str | bytes | list[str] | list[bytes],
paper_size: PaperSize) -> 'CityPosterImageCaches':
"""Create a CityPosterImageCaches from a GPX file."""
# Extract GPX data and retrieve close mountain passes/huts
gpx_data = CityAugmentedGpxData.from_path(list_gpx_path)
return CityPosterImageCaches.from_gpx_data(gpx_data, paper_size)

@staticmethod
def from_gpx_data(gpx_data: CityAugmentedGpxData,
paper_size: PaperSize) -> 'CityPosterImageCaches':
"""Create a CityPosterImageCaches from a GPX file."""
high_res = CityPosterImageCache.from_gpx_data(gpx_data, dpi=HIGH_RES_DPI, paper=paper_size)
low_res = high_res.change_dpi(WORKING_DPI)
return CityPosterImageCaches(low_res=low_res, high_res=high_res, gpx_data=gpx_data)


@dataclass
Expand All @@ -72,7 +103,7 @@ class PosterDrawingData:


@dataclass
class PosterImageCache:
class MountainPosterImageCache:
"""Class leveraging cache to avoid reprocessing GPX when chaning color them, title, sun azimuth..."""

elevation_map: np.ndarray
Expand All @@ -86,11 +117,11 @@ class PosterImageCache:
dpi: float

@staticmethod
def from_gpx_data(gpx_data: AugmentedGpxData,
def from_gpx_data(gpx_data: MountainAugmentedGpxData,
paper: PaperSize,
layout: VerticalLayout = VerticalLayout(),
dpi: float = HIGH_RES_DPI) -> 'PosterImageCache':
"""Create a PosterImageCache from a GPX file."""
dpi: float = HIGH_RES_DPI) -> 'MountainPosterImageCache':
"""Create a MountainPosterImageCache from a GPX file."""
# Download the elevation map at the correct layout
bounds, latlon_aspect_ratio = get_bounds(gpx_data.track, layout, paper)
elevation = download_elevation_map(bounds)
Expand Down Expand Up @@ -245,17 +276,17 @@ def from_gpx_data(gpx_data: AugmentedGpxData,
stats=stats)

logger.info("Successful GPX Processing")
return PosterImageCache(elevation_map=elevation,
return MountainPosterImageCache(elevation_map=elevation,
elevation_shading=CachedHillShading(elevation),
stats_dist_km=gpx_data.dist_km,
stats_uphill_m=gpx_data.uphill_m,
plotter=plotter,
dpi=dpi)

def change_dpi(self, dpi: float) -> 'PosterImageCache':
def change_dpi(self, dpi: float) -> 'MountainPosterImageCache':
"""Scale the elevation map to a new DPI."""
new_elevation_map = rescale_elevation_to_dpi(self.elevation_map, self.plotter.paper_size, dpi)
return PosterImageCache(elevation_map=new_elevation_map,
return MountainPosterImageCache(elevation_map=new_elevation_map,
elevation_shading=CachedHillShading(new_elevation_map),
stats_dist_km=self.stats_dist_km,
stats_uphill_m=self.stats_uphill_m,
Expand Down Expand Up @@ -352,3 +383,119 @@ def rescale_elevation_to_dpi(elevation_map: np.ndarray, paper: PaperSize, target
"""Rescale the elevation map to a target DPI."""
current_dpi = elevation_map.shape[0]/mm_to_inch(paper.h_mm)
return rescale_elevation(elevation_map, target_dpi/current_dpi)

@dataclass
class CityPosterImageCache:
"""
Same Class as MountainPosterImageCache, but it is storing information for run in the cities
Class leveraging cache to avoid reprocessing GPX when changing color theme,
title, sun azimuth..."""

stats_dist_km: float
stats_time: float

plotter: DrawingFigure

dpi: float

@staticmethod
def from_gpx_data(gpx_data: CityAugmentedGpxData,
paper: PaperSize,
layout: VerticalLayout = VerticalLayout(),
dpi: float = HIGH_RES_DPI) -> 'CityPosterImageCache':
"""Create a CityPosterImageCache from a GPX file."""
# Download the elevation map at the correct layout
bounds, latlon_aspect_ratio = get_bounds(gpx_data.track, layout, paper)

## Get the map of the city
roads_segments = download_city_roads_map(bounds)

x_gpx, y_gpx = gpx_data.track.mercator_projection()
activity_segments = [(x_gpx, y_gpx)]

# Use default drawing params
drawing_size_params = DrawingSizeParams.default(paper)
drawing_style_params = DrawingStyleParams()

# Allocate non-overlapping text annotations on the map
list_x: list[float] = []
list_y: list[float] = []
list_text: list[str] = []

start_idx = len(list_x)
list_x.append(x_gpx[0])
list_y.append(y_gpx[0])

end_idx = len(list_x)
list_x.append(x_gpx[-1])
list_y.append(y_gpx[-1])


# i = safe(start_idx)
# start_point = ScatterData(x=[list_x[i]],
# y=[list_y[i]],
# marker=drawing_style_params.start_marker,
# markersize=drawing_size_params.start_markersize)

# i = safe(end_idx)
# end_point = ScatterData(x=[list_x[i]],
# y=[list_y[i]],
# marker=drawing_style_params.end_marker,
# markersize=drawing_size_params.end_markersize)

# title = TextData(x=0.5 * w, y=0.8 * h * layout.title_relative_h,
# s="", fontsize=drawing_size_params.title_fontsize,
# fontproperties=drawing_style_params.pretty_font, ha="center")

plotter = CityDrawingFigure(ref_img_shape=(h, w),
paper_size=paper,
w_display_pix=W_DISPLAY_PIX,
latlon_aspect_ratio=latlon_aspect_ratio,
track_data=track_data,
peak_data=peak_data,
title=title,
stats=stats)

logger.info("Successful GPX Processing")
return CityPosterImageCache(stats_time=gpx_data.track.activity_duration,
stats_dist_km=gpx_data.dist_km,
plotter=plotter,
dpi=dpi)

def change_dpi(self, dpi: float) -> 'CityPosterImageCache':
"""Scale the elevation map to a new DPI."""
return CityPosterImageCache(stats_time=gpx_data.track.activity_duration,
stats_dist_km=gpx_data.dist_km,
plotter=plotter,
dpi=dpi)

def update_drawing_data(self,
azimuth: int,
theme_colors: ThemeColors,
title_txt: str,
uphill_m: str,
dist_km: str) -> PosterDrawingData:
"""Update the drawing data (can run in a separate thread)."""
grey_hillshade = self.elevation_shading.render_grey(azimuth)[..., None]
background_color_rgb = hex_to_rgb(theme_colors.background_color)
color_0 = (0, 0, 0) if theme_colors.dark_mode else background_color_rgb
color_1 = background_color_rgb if theme_colors.dark_mode else (255, 255, 255)
colored_hillshade = grey_hillshade * (np.array(color_1) - np.array(color_0)) + np.array(color_0)

img = colored_hillshade.astype(np.uint8)

dist_km_int = int(dist_km if dist_km != '' else self.stats_dist_km)
uphill_m_int = int(uphill_m if uphill_m != '' else self.stats_uphill_m)
stats_text = f"{dist_km_int} km - {uphill_m_int} m D+"

return PosterDrawingData(img, theme_colors, title_txt=title_txt, stats_text=stats_text)

def draw(self, fig: Figure, ax: Axes, poster_drawing_data: PosterDrawingData) -> None:
"""Draw the updated drawing data (Must run in the main thread because of matplotlib backend)."""
self.plotter.draw(fig, ax,
poster_drawing_data.img,
poster_drawing_data.theme_colors,
poster_drawing_data.title_txt,
poster_drawing_data.stats_text)
logger.info("Drawing updated "
f"(Elevation Map {poster_drawing_data.img.shape[1]}x{poster_drawing_data.img.shape[0]})")
8 changes: 4 additions & 4 deletions pretty_gpx/explore_new_color_themes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import matplotlib.pyplot as plt
from tqdm import tqdm

from pretty_gpx.drawing.poster_image_cache import PosterImageCache
from pretty_gpx.drawing.poster_image_cache import MountainPosterImageCache
from pretty_gpx.drawing.theme_colors import hex_to_rgb
from pretty_gpx.drawing.theme_colors import ThemeColors
from pretty_gpx.gpx.augmented_gpx_data import AugmentedGpxData
from pretty_gpx.gpx.augmented_gpx_data import MountainAugmentedGpxData
from pretty_gpx.layout.paper_size import PAPER_SIZES
from pretty_gpx.utils.logger import logger
from pretty_gpx.utils.paths import COLOR_EXPLORATION_DIR
Expand Down Expand Up @@ -62,8 +62,8 @@ def main(color_palettes: list[tuple[str, str, str]]) -> None:
Tune Color:
- https://mdigi.tools/darken-color/#f1effc
"""
gpx_data = AugmentedGpxData.from_path(os.path.join(CYCLING_DIR, "marmotte.gpx"))
cache = PosterImageCache.from_gpx_data(gpx_data, paper=PAPER_SIZES["A4"], dpi=60)
gpx_data = MountainAugmentedGpxData.from_path(os.path.join(CYCLING_DIR, "marmotte.gpx"))
cache = MountainPosterImageCache.from_gpx_data(gpx_data, paper=PAPER_SIZES["A4"], dpi=60)

shutil.rmtree(COLOR_EXPLORATION_DIR, ignore_errors=True)
os.makedirs(COLOR_EXPLORATION_DIR, exist_ok=True)
Expand Down
Loading

0 comments on commit d72253d

Please sign in to comment.