Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: add dilution to speed up. #128

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cnmaps/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def clip_clabels_by_map(
if ax is None:
ax = plt.gca()
map_polygon = _transform_polygon(map_polygon, ccrs.PlateCarree(), ax.projection)
if not map_polygon.is_valid:
map_polygon = map_polygon.buffer(0)

for cbt in clabel_text:
point = sgeom.Point(cbt.get_position())
Expand Down
31 changes: 25 additions & 6 deletions cnmaps/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,25 @@ def make_mask_array(self, lons: np.ndarray, lats: np.ndarray):
return ~contains(self, lons, lats)


def read_mapjson(fp, wgs84=True):
def read_mapjson(fp, wgs84=True, dilution_interval=1):
"""
读取geojson地图边界文件

参数:
fp (str, 可选): geojson文件名.
wgs84 (bool, 可选): 是否使用 WGS84 坐标
dilution_interval (int, 可选): 稀疏间隔, 默认为1, 即不稀疏.
该值越大, 稀释得越强烈. 默认为1.

返回值:
MapPolygon: 地图边界对象
"""
if not isinstance(dilution_interval, int):
raise ValueError("dilution_interval必须为整数")

if dilution_interval < 1:
raise ValueError("dilution_interval必须大于等于1")

with open(fp, encoding="utf-8") as f:
map_json = orjson.loads(f.read())

Expand All @@ -221,11 +229,14 @@ def read_mapjson(fp, wgs84=True):
if "Polygon" in geometry["type"]:
for _coords in geometry["coordinates"]:
for coords in _coords:
__coords = coords[::dilution_interval]
if len(__coords) < 3:
__coords = coords
if wgs84:
wgs84_coords = [gcj02_to_wgs84(*coord) for coord in coords]
wgs84_coords = [gcj02_to_wgs84(*coord) for coord in __coords]
polygon_list.append(sgeom.Polygon(wgs84_coords))
else:
polygon_list.append(sgeom.Polygon(coords))
polygon_list.append(sgeom.Polygon(__coords))

return MapPolygon(polygon_list)

Expand Down Expand Up @@ -295,6 +306,7 @@ def get_adm_maps(
only_polygon: bool = False,
wgs84=True,
simplify=False,
dilution_interval=1,
*args,
**kwargs,
):
Expand Down Expand Up @@ -332,6 +344,8 @@ def get_adm_maps(
Defaults to False.
wgs84 (bool, 可选): 是否使用 WGS84 坐标系, 若为 True 则转为 WGS84 坐标,
若为 False 则使用高德默认的 GCJ02 火星坐标。Defaults to True.
dilution_interval (int, 可选): 稀疏间隔, 默认为1, 即不稀疏.
该值越大, 稀释得越强烈. Defaults to 1.
simplify (bool, 可选): 是否对边界进行简化, 若为 True 则进行简化处理, 否则不做简化。Defaults to True.

异常:
Expand Down Expand Up @@ -412,7 +426,9 @@ def get_adm_maps(
elif level in ["区", "县", "区县", "区/县"]:
level_sql = "level='区县'"
else:
raise ValueError(f'无法识别level等级: {level}, level参数请从"国", "省", "市", "区县"中选择')
raise ValueError(
f'无法识别level等级: {level}, level参数请从"国", "省", "市", "区县"中选择'
)

meta_sql = (
"SELECT country, province, city, district, level, source, kind"
Expand All @@ -432,13 +448,16 @@ def get_adm_maps(
map_polygons = []
for path in gemo_rows:
mapjson = read_mapjson(
os.path.join(DATA_DIR, "geojson.min/", path[0]), wgs84=wgs84
os.path.join(DATA_DIR, "geojson.min/", path[0]),
wgs84=wgs84,
dilution_interval=dilution_interval,
)

map_polygons.append(mapjson)

gdf = gpd.GeoDataFrame(
data=meta_rows, columns=["国家", "省/直辖市", "市", "区/县", "级别", "来源", "类型"]
data=meta_rows,
columns=["国家", "省/直辖市", "市", "区/县", "级别", "来源", "类型"],
)
gdf.set_geometry(map_polygons, inplace=True)

Expand Down
30 changes: 22 additions & 8 deletions tests/test_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@
sample_districts = [random.choice(districts) for _ in range(100)]

map_args = [
{"only_polygon": True, "record": "first", "name": "中华人民共和国", "simplify": True}
{
"only_polygon": True,
"record": "first",
"name": "中华人民共和国",
"dilution_interval": 10,
}
] + [
{
"province": p,
"only_polygon": True,
"record": "first",
"name": p,
"simplify": True,
"dilution_interval": 10,
}
for p in ["黑龙江省", "内蒙古自治区"]
]
Expand All @@ -41,15 +46,22 @@
def test_draw_maps():
"""测试多地图绘制功能"""
map_args = (
[{"level": "国", "name": "中华人民共和国", "simplify": True}]
+ [{"level": "省", "name": "中华人民共和国-分省", "simplify": True}]
+ [{"level": "国", "engine": "geopandas", "name": "中华人民共和国", "simplify": True}]
[{"level": "国", "name": "中华人民共和国", "dilution_interval": 10}]
+ [{"level": "省", "name": "中华人民共和国-分省", "dilution_interval": 10}]
+ [
{
"level": "国",
"engine": "geopandas",
"name": "中华人民共和国",
"dilution_interval": 10,
}
]
+ [
{
"level": "省",
"engine": "geopandas",
"name": "中华人民共和国-分省",
"simplify": True,
"dilution_interval": 10,
}
]
)
Expand Down Expand Up @@ -221,7 +233,7 @@ def test_clip_clabel():

lons, lats, data = load_dem()

map_polygon = get_adm_maps(record="first", only_polygon=True)
map_polygon = get_adm_maps(record="first", only_polygon=True, dilution_interval=10)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree())
contours = ax.contour(
Expand Down Expand Up @@ -272,7 +284,9 @@ def test_projection():
lons, lats, data = load_dem()

for projection in PROJECTIONS:
map_polygon = get_adm_maps(province="河南省", record="first", only_polygon=True)
map_polygon = get_adm_maps(
record="first", only_polygon=True, dilution_interval=10
)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection=projection)
contours = ax.contourf(
Expand Down
32 changes: 25 additions & 7 deletions tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ def test_get_map_by_fp():
read_mapjson(fp)


def test_read_mapjson_with_dilution():
pattern = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../cnmaps/data/geojson.min/*/*/*/*.geojson",
)
fps = sorted(glob(pattern))
for fp in fps:
read_mapjson(fp, dilution_interval=10)


def test_map_load():
"""测试各级地图数量是否完整,以及各种规则是否都能加载成功."""
assert len(get_adm_maps(level="国")) == 2
Expand All @@ -153,7 +163,11 @@ def test_map_load():
)

beijing = get_adm_maps(city="北京市")[0]
assert beijing["市"] == "北京市" and beijing["区/县"] is None and beijing["级别"] == "市"
assert (
beijing["市"] == "北京市"
and beijing["区/县"] is None
and beijing["级别"] == "市"
)

chaoyang = get_adm_maps(district="朝阳区")
assert len(chaoyang) == 2
Expand Down Expand Up @@ -261,7 +275,7 @@ def test_province_orthogonality():

couples = sorted([couple for couple in combinations(province_names, r=2)])

for (one, another) in couples:
for one, another in couples:
assert (
get_adm_maps(province=one)[0]["geometry"]
& get_adm_maps(province=another)[0]["geometry"]
Expand Down Expand Up @@ -295,7 +309,7 @@ def test_city_orthogonality():
sorted(("阿勒泰地区", "北屯市")),
]

for (one, another) in couples:
for one, another in couples:
if sorted([one, another]) in problem_set:
continue
area = (
Expand All @@ -312,7 +326,7 @@ def test_province_union():

couples = sorted([couple for couple in combinations(province_names, r=2)])

for (one, another) in couples:
for one, another in couples:
_ = (
get_adm_maps(province=one)[0]["geometry"]
+ get_adm_maps(province=another)[0]["geometry"]
Expand All @@ -326,7 +340,7 @@ def test_province_difference():

couples = sorted([couple for couple in product(province_names, repeat=2)])

for (one, another) in couples:
for one, another in couples:
_ = (
get_adm_maps(province=one)[0]["geometry"]
- get_adm_maps(province=another)[0]["geometry"]
Expand Down Expand Up @@ -360,13 +374,17 @@ def test_get_extent():

def test_only_polygon_and_record():
"""测试only_polygon参数和record参数功能."""
polygons = get_adm_maps(city="北京市", record="all", level="区县", only_polygon=True)
polygons = get_adm_maps(
city="北京市", record="all", level="区县", only_polygon=True
)
assert isinstance(polygons, list)
assert len(polygons) == 16
for p in polygons:
assert isinstance(p, MapPolygon)

polygon = get_adm_maps(city="北京市", record="first", level="区县", only_polygon=True)
polygon = get_adm_maps(
city="北京市", record="first", level="区县", only_polygon=True
)
assert isinstance(polygon, MapPolygon)

meta = get_adm_maps(city="北京市", record="first", level="市")
Expand Down
29 changes: 23 additions & 6 deletions tests/test_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"only_polygon": True,
"record": "first",
"name": "中华人民共和国",
"simplify": True,
"dilution_interval": 10,
}


Expand Down Expand Up @@ -233,7 +233,9 @@ def test_clip_clabel(benchmark):
def inner():
lons, lats, data = load_dem()

map_polygon = get_adm_maps(record="first", only_polygon=True, simplify=True)
map_polygon = get_adm_maps(
record="first", only_polygon=True, dilution_interval=10
)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree())
contours = ax.contour(
Expand Down Expand Up @@ -269,7 +271,9 @@ def test_projection(benchmark):
def inner():
lons, lats, data = load_dem()

map_polygon = get_adm_maps(record="first", only_polygon=True, simplify=True)
map_polygon = get_adm_maps(
record="first", only_polygon=True, dilution_interval=10
)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection=ccrs.Orthographic(central_longitude=100))
contours = ax.contourf(
Expand Down Expand Up @@ -304,7 +308,10 @@ def inner():
mask_array = np.load(casefp)

map_polygon = get_adm_maps(
province="宁夏回族自治区", only_polygon=True, record="first", wgs84=False
province="宁夏回族自治区",
only_polygon=True,
record="first",
wgs84=False,
)

lons, lats, data = load_dem()
Expand Down Expand Up @@ -350,7 +357,12 @@ def inner():
lat = np.linspace(0, 60, 1000)
lons, lats = np.meshgrid(lon, lat)

china = get_adm_maps(level="国", record="first", only_polygon=True, wgs84=False)
china = get_adm_maps(
level="国",
record="first",
only_polygon=True,
wgs84=False,
)
china_maskout_array = china.make_mask_array(lons, lats)

assert (china_maskout_array == mask_array).all()
Expand All @@ -362,7 +374,12 @@ def inner():
lat = np.linspace(0, 60, 1000)
lons, lats = np.meshgrid(lon, lat)

china = get_adm_maps(level="国", record="first", only_polygon=True, wgs84=True)
china = get_adm_maps(
level="国",
record="first",
only_polygon=True,
wgs84=True,
)
china_maskout_array = china.make_mask_array(lons, lats)

assert (china_maskout_array == mask_array).all()
Expand Down
Loading