Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mesh_plotter improvements #2772

Merged
merged 11 commits into from
Sep 23, 2024
1 change: 1 addition & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ aa04d1f7d86cc2503b98b7e2b2d84dbfff6c316b
1a49e547ba3c48fa483f9ae81a8f05adcd6b888c
045d90f1d80f713eb3ae0ac58f6c2352937f1eb0
753fda3ff0147837231a73c9c728dd9ce47b5997
f112ba0bbf96a61d5a4d354dc0dcbd8b0c68145c
26 changes: 23 additions & 3 deletions python/ctsm/mesh_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,26 @@ def get_parser():

parser.add_argument(
"--overwrite",
help="If plots xists, overwrite them.",
help="If plots exist, overwrite them.",
action="store_true",
dest="overwrite",
required=False,
)

parser.add_argument(
"--no-center-coords",
help="Do not include red Xs at center of grid cells.",
action="store_true",
required=False,
)

default_dpi = 300
parser.add_argument(
"--dpi",
help=f"Dots per square inch in output; default {default_dpi}",
type=float,
)

add_logging_args(parser)
return parser

Expand Down Expand Up @@ -98,9 +112,10 @@ def process_and_check_args(args):

today = datetime.today()
today_string = today.strftime("%y%m%d")
input_filename = os.path.basename(args.input)
args.output = os.path.join(
args.out_dir,
os.path.splitext(args.input)[0] + "_c" + today_string,
os.path.splitext(input_filename)[0] + "_c" + today_string,
)

if not os.path.isfile(args.input):
Expand Down Expand Up @@ -148,10 +163,15 @@ def main():
this_mesh.read_file(ds)

plot_regional = os.path.splitext(mesh_out)[0] + "_regional" + ".png"
file_exists_msg = "File already exists but --overwrite not given: "
if os.path.exists(plot_regional) and not args.overwrite:
raise FileExistsError(file_exists_msg + plot_regional)

plot_global = os.path.splitext(mesh_out)[0] + "_global" + ".png"
if os.path.exists(plot_global) and not args.overwrite:
raise FileExistsError(file_exists_msg + plot_global)

this_mesh.make_mesh_plot(plot_regional, plot_global)
this_mesh.make_mesh_plot(plot_regional, plot_global, args)


if __name__ == "__main__":
Expand Down
19 changes: 11 additions & 8 deletions python/ctsm/site_and_regional/mesh_plot_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class MeshPlotType(MeshType):
Extend mesh type with some advanced plotting capability
"""

def make_mesh_plot(self, plot_regional, plot_global):
def make_mesh_plot(self, plot_regional, plot_global, args):
"""
Create plots for the ESMF mesh file

Expand All @@ -36,10 +36,10 @@ def make_mesh_plot(self, plot_regional, plot_global):
The path to write the ESMF meshfile global plot
"""

self.mesh_plot(plot_regional, regional=True)
self.mesh_plot(plot_global, regional=False)
self.mesh_plot(plot_regional, args, regional=True)
self.mesh_plot(plot_global, args, regional=False)

def mesh_plot(self, plot_file, regional):
def mesh_plot(self, plot_file, args, regional):
"""Make a plot of a mesh file in either a regional or global grid"""
# -- regional settings
if regional:
Expand All @@ -49,7 +49,7 @@ def mesh_plot(self, plot_file, regional):
plot_type = "regional"
line_width = 1
marker = "x"
marker_size = 1
marker_size = 50
# global settings
else:
fig = plt.figure(num=None, figsize=(15, 10), facecolor="w", edgecolor="k")
Expand All @@ -58,7 +58,9 @@ def mesh_plot(self, plot_file, regional):
plot_type = "global"
line_width = 0.5
marker = "o"
marker_size = None
marker_size = 0.1
if args.no_center_coords:
marker_size = 0

ax.add_feature(cfeature.COASTLINE, edgecolor="black")
ax.add_feature(cfeature.BORDERS, edgecolor="black")
Expand Down Expand Up @@ -129,8 +131,9 @@ def mesh_plot(self, plot_file, regional):
*[(k, mpatches.Rectangle((0, 0), 1, 1, facecolor=v)) for k, v in lc_colors.items()]
)

ax.legend(handles, labels)
if not args.no_center_coords:
ax.legend(handles, labels)

plt.savefig(plot_file, bbox_inches="tight")
plt.savefig(plot_file, bbox_inches="tight", dpi=args.dpi)

logger.info("Successfully created %s plots for ESMF Mesh file : %s", plot_type, plot_file)
52 changes: 47 additions & 5 deletions python/ctsm/test/test_advanced_sys_mesh_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,27 @@ class SysTestMeshMaker(unittest.TestCase):

def setUp(self):
"""Setup for all tests"""
testinputs_path = os.path.join(path_to_ctsm_root(), "python/ctsm/test/testinputs")
testinputs_path = os.path.join(
path_to_ctsm_root(),
"python",
"ctsm",
"test",
"testinputs",
)
self._testinputs_path = testinputs_path
self._infile = os.path.join(
testinputs_path,
"ESMF_mesh_5x5pt_amazon_from_domain_c230308.nc",
)
self._tempdir = tempfile.mkdtemp()
self.mesh_out = self._tempdir + "/mesh_out"
self.mesh_out = os.path.join(self._tempdir, "mesh_out")
self.test_basic_argv = [
"mesh_plotter",
"--input",
self._infile,
"--output",
self.mesh_out,
]

def tearDown(self):
"""
Expand All @@ -43,15 +56,44 @@ def tearDown(self):

def test_basic(self):
"""Do a simple basic test"""
sys.argv = self.test_basic_argv
main()
plotfiles = glob.glob(os.path.join(self._tempdir, "*.png"))
if not plotfiles:
self.fail("plot files were NOT created as they should have")

def test_dpi(self):
"""Test setting dpi"""
sys.argv = self.test_basic_argv + [
"--dpi",
"198.7",
]
main()
plotfiles = glob.glob(os.path.join(self._tempdir, "*.png"))
if not plotfiles:
self.fail("plot files were NOT created as they should have")

def test_need_overwrite(self):
"""Ensure failure if output file exists but --overwrite not given"""
sys.argv = self.test_basic_argv
main()
with self.assertRaisesRegex(
FileExistsError, "File already exists but --overwrite not given"
):
main()

def test_outdir(self):
"""Test that --outdir option works"""
outdir = os.path.join(self._tempdir, "abc123")
sys.argv = [
"mesh_plotter",
"--input",
self._infile,
"--output",
self.mesh_out,
"--outdir",
outdir,
]
main()
plotfiles = glob.glob(self._tempdir + "/*.png")
plotfiles = glob.glob(os.path.join(outdir, "*.png"))
if not plotfiles:
self.fail("plot files were NOT created as they should have")

Expand Down
Loading