diff --git a/pretty_gpx/common/drawing/elevation_stats_section.py b/pretty_gpx/common/drawing/elevation_stats_section.py index b9fb59e..d0782b7 100644 --- a/pretty_gpx/common/drawing/elevation_stats_section.py +++ b/pretty_gpx/common/drawing/elevation_stats_section.py @@ -11,6 +11,7 @@ from pretty_gpx.common.gpx.gpx_track import GpxTrack from pretty_gpx.common.layout.elevation_vertical_layout import ElevationVerticalLayout from pretty_gpx.common.utils.asserts import assert_same_len +from pretty_gpx.common.utils.logger import logger def downsample(x: np.ndarray, y: np.ndarray, n: int) -> tuple[np.ndarray, np.ndarray]: @@ -35,6 +36,7 @@ def __init__(self, paper_fig: BaseDrawingFigure, track: GpxTrack, pts_per_mm: float = 2.0) -> None: + self.layout = layout b = paper_fig.gpx_bounds y_lat_up = b.lat_min + b.dlat * (layout.stats_relative_h + layout.elevation_relative_h) @@ -58,6 +60,47 @@ def __init__(self, self.section_center_lat_y = 0.5 * (y_lat_bot + b.lat_min) self.section_center_lon_x = b.lon_center + + def update_section_with_new_layout(self, + paper_fig: BaseDrawingFigure, + new_layout: ElevationVerticalLayout) -> None: + """Update in place the ElevationStatsSection to a new layout.""" + old_layout = self.layout + b = paper_fig.gpx_bounds + + + # Compute old and new positions of the elevation stat panel + lat_up_old = b.lat_min + b.dlat * (old_layout.stats_relative_h + old_layout.elevation_relative_h) + lat_bot_old = b.lat_min + b.dlat * old_layout.stats_relative_h + + lat_up_new = b.lat_min + b.dlat * (new_layout.stats_relative_h + new_layout.elevation_relative_h) + lat_bot_new = b.lat_min + b.dlat * new_layout.stats_relative_h + + # Factor of the scaling to update the layout + translation = lat_bot_new - lat_bot_old + scaling = (lat_up_new-lat_bot_new)/(lat_up_old-lat_bot_old) + + logger.info(f"Update elevation section: Translation in lat={translation:.2e} Zoom factor={scaling:.2f}") + + # Old layout + elevation_y_old = self.fill_poly.y[2:-2] + + # Old part where the scaling is applied + elevation_y_old_scale_part = np.array(elevation_y_old) - lat_bot_old + + new_elevation_y = elevation_y_old_scale_part*scaling + lat_bot_new + + + #Add the new starting point and ending point + elevation_y_new = [b.lat_min, lat_bot_new] + new_elevation_y.tolist() + [lat_bot_new, b.lat_min] + + self.fill_poly.y = elevation_y_new + + self.section_center_lat_y = 0.5 * (lat_bot_new + b.lat_min) + self.section_center_lon_x = b.lon_center + + self.layout = new_layout + def get_profile_lat_y(self, k: int) -> float: """Get the latitude of the elevation profile on the poster at index k in the original GPX track.""" return self.__elevation_poly_y_lat[k] diff --git a/pretty_gpx/common/layout/base_vertical_layout.py b/pretty_gpx/common/layout/base_vertical_layout.py index 62e71bd..7b8b258 100644 --- a/pretty_gpx/common/layout/base_vertical_layout.py +++ b/pretty_gpx/common/layout/base_vertical_layout.py @@ -2,6 +2,7 @@ """Base Vertical Layout.""" from dataclasses import dataclass from dataclasses import fields +from typing import ClassVar import matplotlib.pyplot as plt @@ -58,6 +59,55 @@ class BaseVerticalLayout: └───► X """ + # Class attribute to be defined by child classes + _LAYOUTS: ClassVar[set[str]] = set() + + @classmethod + def get_max_download_bounds_across_layouts(cls, + gpx_track: GpxTrack, + paper: PaperSize) -> GpxBounds: + """Get maximum possible GPX bounds by trying all available layouts, considering X and Y independently.""" + max_dlon = 0.0 + max_dlat = 0.0 + max_x_layout_name = None + max_y_layout_name = None + + bounds_by_layout = {} + + # Get all layout methods + layout_methods = {name: getattr(cls, name) for name in cls._LAYOUTS} + + for layout_name, layout_method in layout_methods.items(): + layout : BaseVerticalLayout = layout_method() + bounds = layout.get_layout_download_bounds(gpx_track, paper) + + bounds_by_layout[layout_name] = bounds + + if bounds.dlon > max_dlon: + max_dlon = bounds.dlon + max_x_layout_name = layout_name + + if bounds.dlat > max_dlat: + max_dlat = bounds.dlat + max_y_layout_name = layout_name + + if max_x_layout_name is None or max_y_layout_name is None: + raise ValueError("No valid layouts found") + + # Get the center points from the bounds that gave us max X and Y + x_center = bounds_by_layout[max_x_layout_name].lon_center + y_center = bounds_by_layout[max_y_layout_name].lat_center + + # Create the final maximum bounds combining the largest X and Y extents + max_bounds = GpxBounds.from_center( + lon_center=x_center, + lat_center=y_center, + dlon=max_dlon, + dlat=max_dlat + ) + + return max_bounds + ######### METHODS TO IMPLEMENT ######### def _get_download_y_bounds(self) -> 'RelativeYBounds': @@ -85,6 +135,66 @@ def __post_init__(self) -> None: assert_float_eq(sum_fields, 1.0, msg="Sum of fields must be 1.0") + # Verify layouts + layout_methods = {name for name, method in self.__class__.__dict__.items() + if isinstance(method, staticmethod) + and name != 'get_layout_with_largest_map_and_title' + and not name.startswith('_')} + + missing_layouts = layout_methods - self.__class__._LAYOUTS + extra_layouts = self.__class__._LAYOUTS - layout_methods + + if missing_layouts: + raise ValueError(f"Layout methods {missing_layouts} exist but are not included in _LAYOUTS") + if extra_layouts: + raise ValueError(f"Layouts {extra_layouts} are included in _LAYOUTS but don't exist as methods") + + + def get_layout_download_bounds(self, + gpx_track: GpxTrack, + paper: PaperSize) -> GpxBounds: + """Get GPX bounds around the GPX track to match the input vertical layout and paper size.""" + # Remove the margins + paper_w_mm = (paper.w_mm - 2*paper.margin_mm) + paper_h_mm = (paper.h_mm - 2*paper.margin_mm) + + # Track area + track_w_mm = paper_w_mm + track_h_mm = paper_h_mm * self._get_track_y_bounds().height + + # Tight Track area (after removing the margins) + tight_w_mm = track_w_mm * (1. - 2*self._get_track_margin()) + tight_h_mm = track_h_mm * (1. - 2*self._get_track_margin()) + + # Analyze the GPX track + bounds = gpx_track.get_bounds() + + # Aspect ratio of the lat/lon map + latlon_aspect_ratio = bounds.latlon_aspect_ratio + + # Tight fit + lat_deg_per_mm = max( + bounds.dlon / (tight_w_mm * latlon_aspect_ratio), # Track touching left/right of tight area + bounds.dlat / tight_h_mm # Track touching bot/top of tight area + ) + + # Compute the Paper Bounds + paper_dlon = paper_w_mm * lat_deg_per_mm * latlon_aspect_ratio + paper_dlat = paper_h_mm * lat_deg_per_mm + + # Compute the Download Bounds + download_dlon = paper_dlon + download_dlat = paper_dlat * self._get_download_y_bounds().height + + download_lat_offset = paper_dlat*(self._get_download_y_bounds().center - self._get_track_y_bounds().center) + download_bounds = GpxBounds.from_center(lon_center=bounds.lon_center, + lat_center=bounds.lat_center + download_lat_offset, + dlon=download_dlon, + dlat=download_dlat) + + return download_bounds + + def get_download_bounds_and_paper_figure(self, gpx_track: GpxTrack, paper: PaperSize) -> tuple[GpxBounds, BaseDrawingFigure]: @@ -124,15 +234,9 @@ def get_download_bounds_and_paper_figure(self, dlat=paper_dlat) drawing_fig = BaseDrawingFigure(paper, latlon_aspect_ratio, paper_bounds) - # Compute the Download Bounds - download_dlon = paper_dlon - download_dlat = paper_dlat * self._get_download_y_bounds().height - - download_lat_offset = paper_dlat*(self._get_download_y_bounds().center - self._get_track_y_bounds().center) - download_bounds = GpxBounds.from_center(lon_center=bounds.lon_center, - lat_center=bounds.lat_center + download_lat_offset, - dlon=download_dlon, - dlat=download_dlat) + # Get the Download Bounds + download_bounds = self.get_max_download_bounds_across_layouts(gpx_track=gpx_track, + paper=paper) if DEBUG: _debug(self, paper_bounds, download_bounds, bounds, gpx_track, latlon_aspect_ratio) diff --git a/pretty_gpx/rendering_modes/city/city_vertical_layout.py b/pretty_gpx/rendering_modes/city/city_vertical_layout.py index 3809e1c..69228ec 100644 --- a/pretty_gpx/rendering_modes/city/city_vertical_layout.py +++ b/pretty_gpx/rendering_modes/city/city_vertical_layout.py @@ -9,8 +9,10 @@ class CityVerticalLayout(ElevationVerticalLayout): """City Vertical Layout.""" + _LAYOUTS = {'single_line_stats', 'two_lines_stats'} + @staticmethod - def default() -> 'CityVerticalLayout': + def single_line_stats() -> 'CityVerticalLayout': """Return the default City Vertical Layout.""" return CityVerticalLayout(title_relative_h=0.18, map_relative_h=0.65, diff --git a/pretty_gpx/rendering_modes/city/drawing/city_drawer.py b/pretty_gpx/rendering_modes/city/drawing/city_drawer.py index 9f37da0..028f089 100644 --- a/pretty_gpx/rendering_modes/city/drawing/city_drawer.py +++ b/pretty_gpx/rendering_modes/city/drawing/city_drawer.py @@ -7,7 +7,6 @@ from pretty_gpx.common.drawing.drawing_data import BaseDrawingData from pretty_gpx.common.drawing.drawing_data import LineCollectionData from pretty_gpx.common.drawing.drawing_data import PlotData -from pretty_gpx.common.drawing.drawing_data import PolyFillData from pretty_gpx.common.drawing.drawing_data import PolygonCollectionData from pretty_gpx.common.drawing.drawing_data import ScatterData from pretty_gpx.common.drawing.drawing_data import TextData @@ -59,16 +58,16 @@ def get_stats_items(dist_km_int: int, return stats_items @staticmethod - def build_stats_text(stats_items: list[str]) -> str: + def build_stats_text(stats_items: list[str]) -> tuple[str, CityVerticalLayout]: """Transform the stats items list into a string to display.""" if len(stats_items) == 0: - return "" + return "", CityVerticalLayout.single_line_stats() elif len(stats_items) == 1: - return stats_items[0] + return stats_items[0], CityVerticalLayout.single_line_stats() elif len(stats_items) == 2: - return f"{stats_items[0]} - {stats_items[1]}" + return f"{stats_items[0]} - {stats_items[1]}", CityVerticalLayout.single_line_stats() elif len(stats_items) == 3: - return f"{stats_items[0]} - {stats_items[1]}\n{stats_items[2]}" + return f"{stats_items[0]} - {stats_items[1]}\n{stats_items[2]}", CityVerticalLayout.two_lines_stats() else: raise ValueError("The stat items list length is not valid") @@ -83,8 +82,6 @@ class CityDrawer(Drawer[CityAugmentedGpxData, stats_uphill_m: float stats_duration_s: float | None - # TODO (upgrade): enhance support for duration with a 2 stats line layout - @staticmethod def get_gpx_data_cls() -> type[CityAugmentedGpxData]: """Return the template AugmentedGpxData class (Because Python doesn't allow to use T as a type).""" @@ -95,17 +92,26 @@ def get_gpx_data_cls() -> type[CityAugmentedGpxData]: def from_gpx_data(gpx_data: CityAugmentedGpxData, paper_size: PaperSize) -> 'CityDrawer': """Create a CityDrawer from a GPX file.""" - layout = CityVerticalLayout.default() + stats_items = CityDrawingInputs.get_stats_items(dist_km_int=int(gpx_data.dist_km), + uphill_m_int=int(gpx_data.uphill_m), + duration_s_float=gpx_data.duration_s) + + stats_txt, layout = CityDrawingInputs.build_stats_text(stats_items=stats_items) # Download the elevation map at the correct layout - img_bounds, paper_fig = layout.get_download_bounds_and_paper_figure(gpx_data.track, paper_size) + download_bounds, paper_fig = layout.get_download_bounds_and_paper_figure(gpx_data.track, paper_size) # Use default drawing params - drawing_size_config = CityDrawingSizeConfig.default(paper_size, img_bounds.diagonal_m) + drawing_size_config = CityDrawingSizeConfig.default(paper_size, paper_fig.gpx_bounds.diagonal_m) drawing_style_config = CityDrawingStyleConfig() - plotter = init_and_populate_drawing_figure(gpx_data, paper_fig, img_bounds, layout, drawing_size_config, - drawing_style_config) + plotter = init_and_populate_drawing_figure(gpx_data=gpx_data, + base_fig=paper_fig, + download_bounds=download_bounds, + layout=layout, + stats_text=stats_txt, + drawing_size_config=drawing_size_config, + drawing_style_config=drawing_style_config) logger.info("Successful GPX Processing") @@ -131,24 +137,26 @@ def get_params(self, inputs: CityDrawingInputs) -> CityDrawingParams: uphill_m_int=int(uphill_m_int), duration_s_float=stats_duration_s) - stats_text = CityDrawingInputs.build_stats_text(stats_items=new_stats_items) + stats_text, layout = CityDrawingInputs.build_stats_text(stats_items=new_stats_items) return CityDrawingParams(theme_colors=inputs.theme_colors, title_txt=title_txt, - stats_txt=stats_text) + stats_txt=stats_text, + layout=layout) def init_and_populate_drawing_figure(gpx_data: CityAugmentedGpxData, base_fig: BaseDrawingFigure, - city_bounds: GpxBounds, + download_bounds: GpxBounds, layout: CityVerticalLayout, + stats_text: str, drawing_size_config: CityDrawingSizeConfig, drawing_style_config: CityDrawingStyleConfig ) -> CityDrawingFigure: """Set up and populate the DrawingFigure for the poster.""" gpx_track = gpx_data.track - caracteristic_distance_m = city_bounds.diagonal_m + caracteristic_distance_m = base_fig.gpx_bounds.diagonal_m logger.info(f"Domain diagonal is {caracteristic_distance_m/1000.:.1f}km") total_query = OverpassQuery() @@ -156,20 +164,20 @@ def init_and_populate_drawing_figure(gpx_data: CityAugmentedGpxData, prepare_download_city_rivers, prepare_download_city_forests]: - prepare_func(query=total_query, bounds=city_bounds) + prepare_func(query=total_query, bounds=download_bounds) # Merge and run all queries total_query.launch_queries() # Retrieve the data roads = process_city_roads(query=total_query, - bounds=city_bounds) + bounds=download_bounds) rivers = process_city_rivers(query=total_query, - bounds=city_bounds) + bounds=download_bounds) forests, farmland = process_city_forests(query=total_query, - bounds=city_bounds) + bounds=download_bounds) forests.interior_polygons = [] track_data: list[BaseDrawingData] = [PlotData(x=gpx_track.list_lon, y=gpx_track.list_lat, linewidth=2.0)] @@ -189,13 +197,7 @@ def init_and_populate_drawing_figure(gpx_data: CityAugmentedGpxData, ha="center", va="center") - stats_text = f"{gpx_track.list_cumul_dist_km[-1]:.2f} km - {int(gpx_track.uphill_m)} m D+" - - if gpx_track.duration_s is not None: - stats_text += f"\n{format_timedelta(gpx_track.duration_s)}" - - ele = ElevationStatsSection(layout, base_fig, gpx_track) - track_data.append(ele.fill_poly) + elevation_profile = ElevationStatsSection(layout, base_fig, gpx_track) stats = TextData(x=b.lon_center, y=b.lat_min + 0.5 * b.dlat * layout.stats_relative_h, s=stats_text, fontsize=mm_to_point(18.5), @@ -207,10 +209,6 @@ def init_and_populate_drawing_figure(gpx_data: CityAugmentedGpxData, point_data.append(ScatterData(x=[gpx_track.list_lon[-1]], y=[gpx_track.list_lat[-1]], marker="s", markersize=mm_to_point(3.5))) - h_top_stats = b.lat_min + b.dlat * layout.stats_relative_h - track_data.append(PolyFillData(x=[b.lon_min, b.lon_max, b.lon_max, b.lon_min], - y=[h_top_stats, h_top_stats, b.lat_min, b.lat_min])) - plotter = CityDrawingFigure(paper_size=base_fig.paper_size, latlon_aspect_ratio=base_fig.latlon_aspect_ratio, gpx_bounds=base_fig.gpx_bounds, @@ -220,6 +218,7 @@ def init_and_populate_drawing_figure(gpx_data: CityAugmentedGpxData, rivers_data=rivers_data, forests_data=forests_data, farmland_data=farmland_data, + elevation_profile=elevation_profile, title=title, stats=stats) diff --git a/pretty_gpx/rendering_modes/city/drawing/city_drawing_figure.py b/pretty_gpx/rendering_modes/city/drawing/city_drawing_figure.py index 95dedb6..708adf2 100644 --- a/pretty_gpx/rendering_modes/city/drawing/city_drawing_figure.py +++ b/pretty_gpx/rendering_modes/city/drawing/city_drawing_figure.py @@ -8,9 +8,11 @@ from pretty_gpx.common.drawing.drawing_data import BaseDrawingData from pretty_gpx.common.drawing.drawing_data import PolygonCollectionData from pretty_gpx.common.drawing.drawing_data import TextData +from pretty_gpx.common.drawing.elevation_stats_section import ElevationStatsSection from pretty_gpx.common.structure import DrawingFigure from pretty_gpx.common.structure import DrawingParams from pretty_gpx.common.utils.profile import profile +from pretty_gpx.rendering_modes.city.city_vertical_layout import CityVerticalLayout from pretty_gpx.rendering_modes.city.drawing.city_colors import CityColors @@ -20,6 +22,7 @@ class CityDrawingParams(DrawingParams): theme_colors: CityColors title_txt: str stats_txt: str + layout: CityVerticalLayout @dataclass @@ -38,6 +41,7 @@ class CityDrawingFigure(DrawingFigure[CityDrawingParams]): rivers_data: list[PolygonCollectionData] forests_data: list[PolygonCollectionData] farmland_data: list[PolygonCollectionData] + elevation_profile: ElevationStatsSection title: TextData stats: TextData @@ -47,6 +51,14 @@ def draw(self, fig: Figure, ax: Axes, params: CityDrawingParams) -> None: """Plot the annotations.""" road_color = "black" if params.theme_colors.dark_mode else "white" + old_layout = self.elevation_profile.layout + new_layout = params.layout + + if old_layout != new_layout: + self.elevation_profile.update_section_with_new_layout(self, + new_layout=new_layout) + self.stats.y = self.gpx_bounds.lat_min + 0.5 * self.gpx_bounds.dlat * new_layout.stats_relative_h + self.setup(fig, ax) self.title.s = params.title_txt @@ -70,6 +82,8 @@ def draw(self, fig: Figure, ax: Axes, params: CityDrawingParams) -> None: for data6 in self.point_data: data6.plot(ax, params.theme_colors.point_color) + self.elevation_profile.fill_poly.plot(ax, params.theme_colors.track_color) + ax.set_facecolor(params.theme_colors.background_color) self.title.plot(ax, params.theme_colors.point_color)