From 3c62e11ebb0e350e11436c501ac47f8ba83f1c57 Mon Sep 17 00:00:00 2001 From: Clarmy Lee Date: Sat, 20 Jul 2024 00:14:39 +0800 Subject: [PATCH 1/3] perf: add dilution_interval to speed up. --- cnmaps/drawing.py | 2 ++ cnmaps/maps.py | 31 +++++++++++++++++++++++++------ tests/test_drawing.py | 30 ++++++++++++++++++++++-------- tests/test_map.py | 32 +++++++++++++++++++++++++------- tests/test_perf.py | 38 +++++++++++++++++++++++++++++++------- 5 files changed, 105 insertions(+), 28 deletions(-) diff --git a/cnmaps/drawing.py b/cnmaps/drawing.py index d861987..b00844d 100644 --- a/cnmaps/drawing.py +++ b/cnmaps/drawing.py @@ -237,6 +237,8 @@ def clip_clabels_by_map( if ax is None: ax = plt.gca() map_polygon = _transform_polygon(map_polygon, ccrs.PlateCarree(), ax.projection) + if not map_polygon.is_valid: + map_polygon = map_polygon.buffer(0) for cbt in clabel_text: point = sgeom.Point(cbt.get_position()) diff --git a/cnmaps/maps.py b/cnmaps/maps.py index 4906fc8..be2c8e8 100644 --- a/cnmaps/maps.py +++ b/cnmaps/maps.py @@ -198,17 +198,25 @@ def make_mask_array(self, lons: np.ndarray, lats: np.ndarray): return ~contains(self, lons, lats) -def read_mapjson(fp, wgs84=True): +def read_mapjson(fp, wgs84=True, dilution_interval=1): """ 读取geojson地图边界文件 参数: fp (str, 可选): geojson文件名. wgs84 (bool, 可选): 是否使用 WGS84 坐标 + dilution_interval (int, 可选): 稀疏间隔, 默认为1, 即不稀疏. + 该值越大, 稀释得越强烈. 默认为1. 返回值: MapPolygon: 地图边界对象 """ + if not isinstance(dilution_interval, int): + raise ValueError("dilution_interval必须为整数") + + if dilution_interval < 1: + raise ValueError("dilution_interval必须大于等于1") + with open(fp, encoding="utf-8") as f: map_json = orjson.loads(f.read()) @@ -221,11 +229,14 @@ def read_mapjson(fp, wgs84=True): if "Polygon" in geometry["type"]: for _coords in geometry["coordinates"]: for coords in _coords: + __coords = coords[::dilution_interval] + if len(__coords) < 3: + __coords = coords if wgs84: - wgs84_coords = [gcj02_to_wgs84(*coord) for coord in coords] + wgs84_coords = [gcj02_to_wgs84(*coord) for coord in __coords] polygon_list.append(sgeom.Polygon(wgs84_coords)) else: - polygon_list.append(sgeom.Polygon(coords)) + polygon_list.append(sgeom.Polygon(__coords)) return MapPolygon(polygon_list) @@ -295,6 +306,7 @@ def get_adm_maps( only_polygon: bool = False, wgs84=True, simplify=False, + dilution_interval=1, *args, **kwargs, ): @@ -332,6 +344,8 @@ def get_adm_maps( Defaults to False. wgs84 (bool, 可选): 是否使用 WGS84 坐标系, 若为 True 则转为 WGS84 坐标, 若为 False 则使用高德默认的 GCJ02 火星坐标。Defaults to True. + dilution_interval (int, 可选): 稀疏间隔, 默认为1, 即不稀疏. + 该值越大, 稀释得越强烈. Defaults to 1. simplify (bool, 可选): 是否对边界进行简化, 若为 True 则进行简化处理, 否则不做简化。Defaults to True. 异常: @@ -412,7 +426,9 @@ def get_adm_maps( elif level in ["区", "县", "区县", "区/县"]: level_sql = "level='区县'" else: - raise ValueError(f'无法识别level等级: {level}, level参数请从"国", "省", "市", "区县"中选择') + raise ValueError( + f'无法识别level等级: {level}, level参数请从"国", "省", "市", "区县"中选择' + ) meta_sql = ( "SELECT country, province, city, district, level, source, kind" @@ -432,13 +448,16 @@ def get_adm_maps( map_polygons = [] for path in gemo_rows: mapjson = read_mapjson( - os.path.join(DATA_DIR, "geojson.min/", path[0]), wgs84=wgs84 + os.path.join(DATA_DIR, "geojson.min/", path[0]), + wgs84=wgs84, + dilution_interval=dilution_interval, ) map_polygons.append(mapjson) gdf = gpd.GeoDataFrame( - data=meta_rows, columns=["国家", "省/直辖市", "市", "区/县", "级别", "来源", "类型"] + data=meta_rows, + columns=["国家", "省/直辖市", "市", "区/县", "级别", "来源", "类型"], ) gdf.set_geometry(map_polygons, inplace=True) diff --git a/tests/test_drawing.py b/tests/test_drawing.py index 9b9c475..f50f63c 100644 --- a/tests/test_drawing.py +++ b/tests/test_drawing.py @@ -25,14 +25,19 @@ sample_districts = [random.choice(districts) for _ in range(100)] map_args = [ - {"only_polygon": True, "record": "first", "name": "中华人民共和国", "simplify": True} + { + "only_polygon": True, + "record": "first", + "name": "中华人民共和国", + "dilution_interval": 10, + } ] + [ { "province": p, "only_polygon": True, "record": "first", "name": p, - "simplify": True, + "dilution_interval": 10, } for p in ["黑龙江省", "内蒙古自治区"] ] @@ -41,15 +46,22 @@ def test_draw_maps(): """测试多地图绘制功能""" map_args = ( - [{"level": "国", "name": "中华人民共和国", "simplify": True}] - + [{"level": "省", "name": "中华人民共和国-分省", "simplify": True}] - + [{"level": "国", "engine": "geopandas", "name": "中华人民共和国", "simplify": True}] + [{"level": "国", "name": "中华人民共和国", "dilution_interval": 10}] + + [{"level": "省", "name": "中华人民共和国-分省", "dilution_interval": 10}] + + [ + { + "level": "国", + "engine": "geopandas", + "name": "中华人民共和国", + "dilution_interval": 10, + } + ] + [ { "level": "省", "engine": "geopandas", "name": "中华人民共和国-分省", - "simplify": True, + "dilution_interval": 10, } ] ) @@ -221,7 +233,7 @@ def test_clip_clabel(): lons, lats, data = load_dem() - map_polygon = get_adm_maps(record="first", only_polygon=True) + map_polygon = get_adm_maps(record="first", only_polygon=True, dilution_interval=10) fig = plt.figure(figsize=(10, 10)) ax = fig.add_subplot(111, projection=ccrs.PlateCarree()) contours = ax.contour( @@ -272,7 +284,9 @@ def test_projection(): lons, lats, data = load_dem() for projection in PROJECTIONS: - map_polygon = get_adm_maps(province="河南省", record="first", only_polygon=True) + map_polygon = get_adm_maps( + record="first", only_polygon=True, dilution_interval=10 + ) fig = plt.figure(figsize=(10, 10)) ax = fig.add_subplot(111, projection=projection) contours = ax.contourf( diff --git a/tests/test_map.py b/tests/test_map.py index 4cc58a3..9607298 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -127,6 +127,16 @@ def test_get_map_by_fp(): read_mapjson(fp) +def test_read_mapjson_with_dilution(): + pattern = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "../cnmaps/data/geojson.min/*/*/*/*.geojson", + ) + fps = sorted(glob(pattern)) + for fp in fps: + read_mapjson(fp, dilution_interval=10) + + def test_map_load(): """测试各级地图数量是否完整,以及各种规则是否都能加载成功.""" assert len(get_adm_maps(level="国")) == 2 @@ -153,7 +163,11 @@ def test_map_load(): ) beijing = get_adm_maps(city="北京市")[0] - assert beijing["市"] == "北京市" and beijing["区/县"] is None and beijing["级别"] == "市" + assert ( + beijing["市"] == "北京市" + and beijing["区/县"] is None + and beijing["级别"] == "市" + ) chaoyang = get_adm_maps(district="朝阳区") assert len(chaoyang) == 2 @@ -261,7 +275,7 @@ def test_province_orthogonality(): couples = sorted([couple for couple in combinations(province_names, r=2)]) - for (one, another) in couples: + for one, another in couples: assert ( get_adm_maps(province=one)[0]["geometry"] & get_adm_maps(province=another)[0]["geometry"] @@ -295,7 +309,7 @@ def test_city_orthogonality(): sorted(("阿勒泰地区", "北屯市")), ] - for (one, another) in couples: + for one, another in couples: if sorted([one, another]) in problem_set: continue area = ( @@ -312,7 +326,7 @@ def test_province_union(): couples = sorted([couple for couple in combinations(province_names, r=2)]) - for (one, another) in couples: + for one, another in couples: _ = ( get_adm_maps(province=one)[0]["geometry"] + get_adm_maps(province=another)[0]["geometry"] @@ -326,7 +340,7 @@ def test_province_difference(): couples = sorted([couple for couple in product(province_names, repeat=2)]) - for (one, another) in couples: + for one, another in couples: _ = ( get_adm_maps(province=one)[0]["geometry"] - get_adm_maps(province=another)[0]["geometry"] @@ -360,13 +374,17 @@ def test_get_extent(): def test_only_polygon_and_record(): """测试only_polygon参数和record参数功能.""" - polygons = get_adm_maps(city="北京市", record="all", level="区县", only_polygon=True) + polygons = get_adm_maps( + city="北京市", record="all", level="区县", only_polygon=True + ) assert isinstance(polygons, list) assert len(polygons) == 16 for p in polygons: assert isinstance(p, MapPolygon) - polygon = get_adm_maps(city="北京市", record="first", level="区县", only_polygon=True) + polygon = get_adm_maps( + city="北京市", record="first", level="区县", only_polygon=True + ) assert isinstance(polygon, MapPolygon) meta = get_adm_maps(city="北京市", record="first", level="市") diff --git a/tests/test_perf.py b/tests/test_perf.py index 9553fa9..062f658 100644 --- a/tests/test_perf.py +++ b/tests/test_perf.py @@ -34,7 +34,7 @@ "only_polygon": True, "record": "first", "name": "中华人民共和国", - "simplify": True, + "dilution_interval": 10 } @@ -233,7 +233,9 @@ def test_clip_clabel(benchmark): def inner(): lons, lats, data = load_dem() - map_polygon = get_adm_maps(record="first", only_polygon=True, simplify=True) + map_polygon = get_adm_maps( + record="first", only_polygon=True, dilution_interval=10 + ) fig = plt.figure(figsize=(10, 10)) ax = fig.add_subplot(111, projection=ccrs.PlateCarree()) contours = ax.contour( @@ -269,7 +271,9 @@ def test_projection(benchmark): def inner(): lons, lats, data = load_dem() - map_polygon = get_adm_maps(record="first", only_polygon=True, simplify=True) + map_polygon = get_adm_maps( + record="first", only_polygon=True, dilution_interval=10 + ) fig = plt.figure(figsize=(10, 10)) ax = fig.add_subplot(111, projection=ccrs.Orthographic(central_longitude=100)) contours = ax.contourf( @@ -304,7 +308,11 @@ def inner(): mask_array = np.load(casefp) map_polygon = get_adm_maps( - province="宁夏回族自治区", only_polygon=True, record="first", wgs84=False + province="宁夏回族自治区", + only_polygon=True, + record="first", + wgs84=False, + dilution_interval=10, ) lons, lats, data = load_dem() @@ -322,7 +330,11 @@ def inner(): mask_array = np.load(casefp) map_polygon = get_adm_maps( - province="宁夏回族自治区", only_polygon=True, record="first", wgs84=True + province="宁夏回族自治区", + only_polygon=True, + record="first", + wgs84=True, + dilution_interval=10, ) lons, lats, data = load_dem() @@ -350,7 +362,13 @@ def inner(): lat = np.linspace(0, 60, 1000) lons, lats = np.meshgrid(lon, lat) - china = get_adm_maps(level="国", record="first", only_polygon=True, wgs84=False) + china = get_adm_maps( + level="国", + record="first", + only_polygon=True, + wgs84=False, + dilution_interval=10, + ) china_maskout_array = china.make_mask_array(lons, lats) assert (china_maskout_array == mask_array).all() @@ -362,7 +380,13 @@ def inner(): lat = np.linspace(0, 60, 1000) lons, lats = np.meshgrid(lon, lat) - china = get_adm_maps(level="国", record="first", only_polygon=True, wgs84=True) + china = get_adm_maps( + level="国", + record="first", + only_polygon=True, + wgs84=True, + dilution_interval=10, + ) china_maskout_array = china.make_mask_array(lons, lats) assert (china_maskout_array == mask_array).all() From 831be6ace8311944b60fc8fe33ee6ecb1f7a848e Mon Sep 17 00:00:00 2001 From: Clarmy Lee Date: Sat, 20 Jul 2024 00:27:58 +0800 Subject: [PATCH 2/3] tests: fix --- tests/test_perf.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_perf.py b/tests/test_perf.py index 062f658..2f2de51 100644 --- a/tests/test_perf.py +++ b/tests/test_perf.py @@ -34,7 +34,7 @@ "only_polygon": True, "record": "first", "name": "中华人民共和国", - "dilution_interval": 10 + "dilution_interval": 10, } @@ -330,11 +330,7 @@ def inner(): mask_array = np.load(casefp) map_polygon = get_adm_maps( - province="宁夏回族自治区", - only_polygon=True, - record="first", - wgs84=True, - dilution_interval=10, + province="宁夏回族自治区", only_polygon=True, record="first", wgs84=True ) lons, lats, data = load_dem() @@ -385,7 +381,6 @@ def inner(): record="first", only_polygon=True, wgs84=True, - dilution_interval=10, ) china_maskout_array = china.make_mask_array(lons, lats) From 3a4a5c1e8a63b9c06278c9ada655087712016a43 Mon Sep 17 00:00:00 2001 From: Clarmy Lee Date: Sat, 20 Jul 2024 00:46:35 +0800 Subject: [PATCH 3/3] tests: fix --- tests/test_perf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_perf.py b/tests/test_perf.py index 2f2de51..b670bb3 100644 --- a/tests/test_perf.py +++ b/tests/test_perf.py @@ -312,7 +312,6 @@ def inner(): only_polygon=True, record="first", wgs84=False, - dilution_interval=10, ) lons, lats, data = load_dem() @@ -363,7 +362,6 @@ def inner(): record="first", only_polygon=True, wgs84=False, - dilution_interval=10, ) china_maskout_array = china.make_mask_array(lons, lats)