Skip to content

Commit

Permalink
feat: support 2 lines stats/layout update
Browse files Browse the repository at this point in the history
  • Loading branch information
JulesL2 committed Nov 3, 2024
1 parent 397e549 commit 9f6d5d8
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 42 deletions.
43 changes: 43 additions & 0 deletions pretty_gpx/common/drawing/elevation_stats_section.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pretty_gpx.common.gpx.gpx_track import GpxTrack
from pretty_gpx.common.layout.elevation_vertical_layout import ElevationVerticalLayout
from pretty_gpx.common.utils.asserts import assert_same_len
from pretty_gpx.common.utils.logger import logger


def downsample(x: np.ndarray, y: np.ndarray, n: int) -> tuple[np.ndarray, np.ndarray]:
Expand All @@ -35,6 +36,7 @@ def __init__(self,
paper_fig: BaseDrawingFigure,
track: GpxTrack,
pts_per_mm: float = 2.0) -> None:
self.layout = layout
b = paper_fig.gpx_bounds

y_lat_up = b.lat_min + b.dlat * (layout.stats_relative_h + layout.elevation_relative_h)
Expand All @@ -58,6 +60,47 @@ def __init__(self,
self.section_center_lat_y = 0.5 * (y_lat_bot + b.lat_min)
self.section_center_lon_x = b.lon_center


def update_section_with_new_layout(self,
paper_fig: BaseDrawingFigure,
new_layout: ElevationVerticalLayout) -> None:
"""Update in place the ElevationStatsSection to a new layout."""
old_layout = self.layout
b = paper_fig.gpx_bounds


# Compute old and new positions of the elevation stat panel
lat_up_old = b.lat_min + b.dlat * (old_layout.stats_relative_h + old_layout.elevation_relative_h)
lat_bot_old = b.lat_min + b.dlat * old_layout.stats_relative_h

lat_up_new = b.lat_min + b.dlat * (new_layout.stats_relative_h + new_layout.elevation_relative_h)
lat_bot_new = b.lat_min + b.dlat * new_layout.stats_relative_h

# Factor of the scaling to update the layout
translation = lat_bot_new - lat_bot_old
scaling = (lat_up_new-lat_bot_new)/(lat_up_old-lat_bot_old)

logger.info(f"Update elevation section: Translation in lat={translation:.2e} Zoom factor={scaling:.2f}")

# Old layout
elevation_y_old = self.fill_poly.y[2:-2]

# Old part where the scaling is applied
elevation_y_old_scale_part = np.array(elevation_y_old) - lat_bot_old

new_elevation_y = elevation_y_old_scale_part*scaling + lat_bot_new


#Add the new starting point and ending point
elevation_y_new = [b.lat_min, lat_bot_new] + new_elevation_y.tolist() + [lat_bot_new, b.lat_min]

self.fill_poly.y = elevation_y_new

self.section_center_lat_y = 0.5 * (lat_bot_new + b.lat_min)
self.section_center_lon_x = b.lon_center

self.layout = new_layout

def get_profile_lat_y(self, k: int) -> float:
"""Get the latitude of the elevation profile on the poster at index k in the original GPX track."""
return self.__elevation_poly_y_lat[k]
Expand Down
122 changes: 113 additions & 9 deletions pretty_gpx/common/layout/base_vertical_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Base Vertical Layout."""
from dataclasses import dataclass
from dataclasses import fields
from typing import ClassVar

import matplotlib.pyplot as plt

Expand Down Expand Up @@ -58,6 +59,55 @@ class BaseVerticalLayout:
└───► X
"""

# Class attribute to be defined by child classes
_LAYOUTS: ClassVar[set[str]] = set()

@classmethod
def get_max_download_bounds_across_layouts(cls,
gpx_track: GpxTrack,
paper: PaperSize) -> GpxBounds:
"""Get maximum possible GPX bounds by trying all available layouts, considering X and Y independently."""
max_dlon = 0.0
max_dlat = 0.0
max_x_layout_name = None
max_y_layout_name = None

bounds_by_layout = {}

# Get all layout methods
layout_methods = {name: getattr(cls, name) for name in cls._LAYOUTS}

for layout_name, layout_method in layout_methods.items():
layout : BaseVerticalLayout = layout_method()
bounds = layout.get_layout_download_bounds(gpx_track, paper)

bounds_by_layout[layout_name] = bounds

if bounds.dlon > max_dlon:
max_dlon = bounds.dlon
max_x_layout_name = layout_name

if bounds.dlat > max_dlat:
max_dlat = bounds.dlat
max_y_layout_name = layout_name

if max_x_layout_name is None or max_y_layout_name is None:
raise ValueError("No valid layouts found")

# Get the center points from the bounds that gave us max X and Y
x_center = bounds_by_layout[max_x_layout_name].lon_center
y_center = bounds_by_layout[max_y_layout_name].lat_center

# Create the final maximum bounds combining the largest X and Y extents
max_bounds = GpxBounds.from_center(
lon_center=x_center,
lat_center=y_center,
dlon=max_dlon,
dlat=max_dlat
)

return max_bounds

######### METHODS TO IMPLEMENT #########

def _get_download_y_bounds(self) -> 'RelativeYBounds':
Expand Down Expand Up @@ -85,6 +135,66 @@ def __post_init__(self) -> None:

assert_float_eq(sum_fields, 1.0, msg="Sum of fields must be 1.0")

# Verify layouts
layout_methods = {name for name, method in self.__class__.__dict__.items()
if isinstance(method, staticmethod)
and name != 'get_layout_with_largest_map_and_title'
and not name.startswith('_')}

missing_layouts = layout_methods - self.__class__._LAYOUTS
extra_layouts = self.__class__._LAYOUTS - layout_methods

if missing_layouts:
raise ValueError(f"Layout methods {missing_layouts} exist but are not included in _LAYOUTS")
if extra_layouts:
raise ValueError(f"Layouts {extra_layouts} are included in _LAYOUTS but don't exist as methods")


def get_layout_download_bounds(self,
gpx_track: GpxTrack,
paper: PaperSize) -> GpxBounds:
"""Get GPX bounds around the GPX track to match the input vertical layout and paper size."""
# Remove the margins
paper_w_mm = (paper.w_mm - 2*paper.margin_mm)
paper_h_mm = (paper.h_mm - 2*paper.margin_mm)

# Track area
track_w_mm = paper_w_mm
track_h_mm = paper_h_mm * self._get_track_y_bounds().height

# Tight Track area (after removing the margins)
tight_w_mm = track_w_mm * (1. - 2*self._get_track_margin())
tight_h_mm = track_h_mm * (1. - 2*self._get_track_margin())

# Analyze the GPX track
bounds = gpx_track.get_bounds()

# Aspect ratio of the lat/lon map
latlon_aspect_ratio = bounds.latlon_aspect_ratio

# Tight fit
lat_deg_per_mm = max(
bounds.dlon / (tight_w_mm * latlon_aspect_ratio), # Track touching left/right of tight area
bounds.dlat / tight_h_mm # Track touching bot/top of tight area
)

# Compute the Paper Bounds
paper_dlon = paper_w_mm * lat_deg_per_mm * latlon_aspect_ratio
paper_dlat = paper_h_mm * lat_deg_per_mm

# Compute the Download Bounds
download_dlon = paper_dlon
download_dlat = paper_dlat * self._get_download_y_bounds().height

download_lat_offset = paper_dlat*(self._get_download_y_bounds().center - self._get_track_y_bounds().center)
download_bounds = GpxBounds.from_center(lon_center=bounds.lon_center,
lat_center=bounds.lat_center + download_lat_offset,
dlon=download_dlon,
dlat=download_dlat)

return download_bounds


def get_download_bounds_and_paper_figure(self,
gpx_track: GpxTrack,
paper: PaperSize) -> tuple[GpxBounds, BaseDrawingFigure]:
Expand Down Expand Up @@ -124,15 +234,9 @@ def get_download_bounds_and_paper_figure(self,
dlat=paper_dlat)
drawing_fig = BaseDrawingFigure(paper, latlon_aspect_ratio, paper_bounds)

# Compute the Download Bounds
download_dlon = paper_dlon
download_dlat = paper_dlat * self._get_download_y_bounds().height

download_lat_offset = paper_dlat*(self._get_download_y_bounds().center - self._get_track_y_bounds().center)
download_bounds = GpxBounds.from_center(lon_center=bounds.lon_center,
lat_center=bounds.lat_center + download_lat_offset,
dlon=download_dlon,
dlat=download_dlat)
# Get the Download Bounds
download_bounds = self.get_max_download_bounds_across_layouts(gpx_track=gpx_track,
paper=paper)

if DEBUG:
_debug(self, paper_bounds, download_bounds, bounds, gpx_track, latlon_aspect_ratio)
Expand Down
4 changes: 3 additions & 1 deletion pretty_gpx/rendering_modes/city/city_vertical_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
class CityVerticalLayout(ElevationVerticalLayout):
"""City Vertical Layout."""

_LAYOUTS = {'single_line_stats', 'two_lines_stats'}

@staticmethod
def default() -> 'CityVerticalLayout':
def single_line_stats() -> 'CityVerticalLayout':
"""Return the default City Vertical Layout."""
return CityVerticalLayout(title_relative_h=0.18,
map_relative_h=0.65,
Expand Down
63 changes: 31 additions & 32 deletions pretty_gpx/rendering_modes/city/drawing/city_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pretty_gpx.common.drawing.drawing_data import BaseDrawingData
from pretty_gpx.common.drawing.drawing_data import LineCollectionData
from pretty_gpx.common.drawing.drawing_data import PlotData
from pretty_gpx.common.drawing.drawing_data import PolyFillData
from pretty_gpx.common.drawing.drawing_data import PolygonCollectionData
from pretty_gpx.common.drawing.drawing_data import ScatterData
from pretty_gpx.common.drawing.drawing_data import TextData
Expand Down Expand Up @@ -59,16 +58,16 @@ def get_stats_items(dist_km_int: int,
return stats_items

@staticmethod
def build_stats_text(stats_items: list[str]) -> str:
def build_stats_text(stats_items: list[str]) -> tuple[str, CityVerticalLayout]:
"""Transform the stats items list into a string to display."""
if len(stats_items) == 0:
return ""
return "", CityVerticalLayout.single_line_stats()
elif len(stats_items) == 1:
return stats_items[0]
return stats_items[0], CityVerticalLayout.single_line_stats()
elif len(stats_items) == 2:
return f"{stats_items[0]} - {stats_items[1]}"
return f"{stats_items[0]} - {stats_items[1]}", CityVerticalLayout.single_line_stats()
elif len(stats_items) == 3:
return f"{stats_items[0]} - {stats_items[1]}\n{stats_items[2]}"
return f"{stats_items[0]} - {stats_items[1]}\n{stats_items[2]}", CityVerticalLayout.two_lines_stats()
else:
raise ValueError("The stat items list length is not valid")

Expand All @@ -83,8 +82,6 @@ class CityDrawer(Drawer[CityAugmentedGpxData,
stats_uphill_m: float
stats_duration_s: float | None

# TODO (upgrade): enhance support for duration with a 2 stats line layout

@staticmethod
def get_gpx_data_cls() -> type[CityAugmentedGpxData]:
"""Return the template AugmentedGpxData class (Because Python doesn't allow to use T as a type)."""
Expand All @@ -95,17 +92,26 @@ def get_gpx_data_cls() -> type[CityAugmentedGpxData]:
def from_gpx_data(gpx_data: CityAugmentedGpxData,
paper_size: PaperSize) -> 'CityDrawer':
"""Create a CityDrawer from a GPX file."""
layout = CityVerticalLayout.default()
stats_items = CityDrawingInputs.get_stats_items(dist_km_int=int(gpx_data.dist_km),
uphill_m_int=int(gpx_data.uphill_m),
duration_s_float=gpx_data.duration_s)

stats_txt, layout = CityDrawingInputs.build_stats_text(stats_items=stats_items)

# Download the elevation map at the correct layout
img_bounds, paper_fig = layout.get_download_bounds_and_paper_figure(gpx_data.track, paper_size)
download_bounds, paper_fig = layout.get_download_bounds_and_paper_figure(gpx_data.track, paper_size)

# Use default drawing params
drawing_size_config = CityDrawingSizeConfig.default(paper_size, img_bounds.diagonal_m)
drawing_size_config = CityDrawingSizeConfig.default(paper_size, paper_fig.gpx_bounds.diagonal_m)
drawing_style_config = CityDrawingStyleConfig()

plotter = init_and_populate_drawing_figure(gpx_data, paper_fig, img_bounds, layout, drawing_size_config,
drawing_style_config)
plotter = init_and_populate_drawing_figure(gpx_data=gpx_data,
base_fig=paper_fig,
download_bounds=download_bounds,
layout=layout,
stats_text=stats_txt,
drawing_size_config=drawing_size_config,
drawing_style_config=drawing_style_config)

logger.info("Successful GPX Processing")

Expand All @@ -131,45 +137,47 @@ def get_params(self, inputs: CityDrawingInputs) -> CityDrawingParams:
uphill_m_int=int(uphill_m_int),
duration_s_float=stats_duration_s)

stats_text = CityDrawingInputs.build_stats_text(stats_items=new_stats_items)
stats_text, layout = CityDrawingInputs.build_stats_text(stats_items=new_stats_items)

return CityDrawingParams(theme_colors=inputs.theme_colors,
title_txt=title_txt,
stats_txt=stats_text)
stats_txt=stats_text,
layout=layout)


def init_and_populate_drawing_figure(gpx_data: CityAugmentedGpxData,
base_fig: BaseDrawingFigure,
city_bounds: GpxBounds,
download_bounds: GpxBounds,
layout: CityVerticalLayout,
stats_text: str,
drawing_size_config: CityDrawingSizeConfig,
drawing_style_config: CityDrawingStyleConfig
) -> CityDrawingFigure:
"""Set up and populate the DrawingFigure for the poster."""
gpx_track = gpx_data.track

caracteristic_distance_m = city_bounds.diagonal_m
caracteristic_distance_m = base_fig.gpx_bounds.diagonal_m
logger.info(f"Domain diagonal is {caracteristic_distance_m/1000.:.1f}km")

total_query = OverpassQuery()
for prepare_func in [prepare_download_city_roads,
prepare_download_city_rivers,
prepare_download_city_forests]:

prepare_func(query=total_query, bounds=city_bounds)
prepare_func(query=total_query, bounds=download_bounds)

# Merge and run all queries
total_query.launch_queries()

# Retrieve the data
roads = process_city_roads(query=total_query,
bounds=city_bounds)
bounds=download_bounds)

rivers = process_city_rivers(query=total_query,
bounds=city_bounds)
bounds=download_bounds)

forests, farmland = process_city_forests(query=total_query,
bounds=city_bounds)
bounds=download_bounds)
forests.interior_polygons = []

track_data: list[BaseDrawingData] = [PlotData(x=gpx_track.list_lon, y=gpx_track.list_lat, linewidth=2.0)]
Expand All @@ -189,13 +197,7 @@ def init_and_populate_drawing_figure(gpx_data: CityAugmentedGpxData,
ha="center",
va="center")

stats_text = f"{gpx_track.list_cumul_dist_km[-1]:.2f} km - {int(gpx_track.uphill_m)} m D+"

if gpx_track.duration_s is not None:
stats_text += f"\n{format_timedelta(gpx_track.duration_s)}"

ele = ElevationStatsSection(layout, base_fig, gpx_track)
track_data.append(ele.fill_poly)
elevation_profile = ElevationStatsSection(layout, base_fig, gpx_track)

stats = TextData(x=b.lon_center, y=b.lat_min + 0.5 * b.dlat * layout.stats_relative_h,
s=stats_text, fontsize=mm_to_point(18.5),
Expand All @@ -207,10 +209,6 @@ def init_and_populate_drawing_figure(gpx_data: CityAugmentedGpxData,
point_data.append(ScatterData(x=[gpx_track.list_lon[-1]], y=[gpx_track.list_lat[-1]],
marker="s", markersize=mm_to_point(3.5)))

h_top_stats = b.lat_min + b.dlat * layout.stats_relative_h
track_data.append(PolyFillData(x=[b.lon_min, b.lon_max, b.lon_max, b.lon_min],
y=[h_top_stats, h_top_stats, b.lat_min, b.lat_min]))

plotter = CityDrawingFigure(paper_size=base_fig.paper_size,
latlon_aspect_ratio=base_fig.latlon_aspect_ratio,
gpx_bounds=base_fig.gpx_bounds,
Expand All @@ -220,6 +218,7 @@ def init_and_populate_drawing_figure(gpx_data: CityAugmentedGpxData,
rivers_data=rivers_data,
forests_data=forests_data,
farmland_data=farmland_data,
elevation_profile=elevation_profile,
title=title,
stats=stats)

Expand Down
Loading

0 comments on commit 9f6d5d8

Please sign in to comment.