Skip to content

Commit

Permalink
Fix GEOS_WRAPPER tracer initialization, ingest, outgest
Browse files Browse the repository at this point in the history
Remove unused "make_mapping"
  • Loading branch information
FlorianDeconinck committed Dec 19, 2024
1 parent 81bcaeb commit 1cea564
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 78 deletions.
6 changes: 0 additions & 6 deletions pyFV3/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,3 @@ def make_from_4D_array(
unit=cls.unit,
)
return tracers

@staticmethod
def make_mapping(tracer_data: np.ndarray):
if len(tracer_data.shape) != 4:
raise ValueError("Expected 4D field as input")
return tracer_data.shape[3] * [None]
114 changes: 42 additions & 72 deletions pyFV3/wrappers/geos_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mpi4py import MPI

import pyFV3
import pyFV3.tracers
from ndsl import (
CompilationConfig,
CubedSphereCommunicator,
Expand All @@ -36,6 +37,17 @@
from ndsl.utils import safe_assign_array


GEOS_TRACER_MAPPING = [
"vapor",
"liquid",
"ice",
"rain",
"snow",
"graupel",
"cloud",
]


class StencilBackendCompilerOverride:
"""Override the Pace global stencil JIT to allow for 9-rank build
on any setup.
Expand Down Expand Up @@ -104,8 +116,23 @@ def __init__(
bdt: int,
comm: Comm,
backend: str,
water_tracers_count: int,
all_tracers_count: int,
fortran_mem_space: MemorySpace = MemorySpace.HOST,
):
# Check for water species configuration not handled by the interface
if water_tracers_count != 6:
raise NotImplementedError(
"[pyFV3 Bridge] Bridge expect 6 water species,"
f" got {water_tracers_count}."
)

# Build the full tracer mapping by appending None to the expected tracer list
# based on parameter
self._tracers_mapping = GEOS_TRACER_MAPPING
for i in range(all_tracers_count, len(GEOS_TRACER_MAPPING)):
self._tracers_mapping.append(f"tracer_#{i}")

# Look for an override to run on a single node
gtfv3_single_rank_override = int(os.getenv("GTFV3_SINGLE_RANK_OVERRIDE", -1))
if gtfv3_single_rank_override >= 0:
Expand Down Expand Up @@ -137,7 +164,7 @@ def __init__(
metric_terms = MetricTerms(
quantity_factory=quantity_factory,
communicator=self.communicator,
eta_file=namelist["grid_config"]["config"]["eta_file"],
eta_file=namelist["grid_config"]["config"]["eta_file"], # type: ignore
)
grid_data = GridData.new_from_metric_terms(metric_terms)

Expand Down Expand Up @@ -173,7 +200,8 @@ def __init__(
)

self.dycore_state = pyFV3.DycoreState.init_zeros(
quantity_factory=quantity_factory
quantity_factory=quantity_factory,
tracer_list=self._tracers_mapping,
)
self.dycore_state.bdt = self.dycore_config.dt_atmos

Expand All @@ -190,6 +218,7 @@ def __init__(
timestep=timedelta(seconds=self.dycore_state.bdt),
phis=self.dycore_state.phis,
state=self.dycore_state,
exclude_tracers=[],
)

self._fortran_mem_space = fortran_mem_space
Expand All @@ -198,7 +227,6 @@ def __init__(
)

self.output_dict: Dict[str, np.ndarray] = {}
self._allocate_output_dir()

# Feedback information
device_ordinal_info = (
Expand Down Expand Up @@ -368,15 +396,11 @@ def _put_fortran_data_in_dycore(
safe_assign_array(state.omga.view[:], omga[isc:iec, jsc:jec, :])
safe_assign_array(state.diss_estd.view[:], diss_estd[isc:iec, jsc:jec, :])

# tracer quantities should be a 4d array in order:
# vapor, liquid, ice, rain, snow, graupel, cloud
safe_assign_array(state.qvapor.view[:], q[isc:iec, jsc:jec, :, 0])
safe_assign_array(state.qliquid.view[:], q[isc:iec, jsc:jec, :, 1])
safe_assign_array(state.qice.view[:], q[isc:iec, jsc:jec, :, 2])
safe_assign_array(state.qrain.view[:], q[isc:iec, jsc:jec, :, 3])
safe_assign_array(state.qsnow.view[:], q[isc:iec, jsc:jec, :, 4])
safe_assign_array(state.qgraupel.view[:], q[isc:iec, jsc:jec, :, 5])
safe_assign_array(state.qcld.view[:], q[isc:iec, jsc:jec, :, 6])
# Copy tracer data
for index, name in enumerate(self._tracers_mapping):
safe_assign_array(
state.tracers[name].view[:], q[isc:iec, jsc:jec, :, index]
)

return state

Expand All @@ -388,6 +412,7 @@ def _prep_outputs_for_geos(self) -> Dict[str, np.ndarray]:
jec = self._grid_indexing.jec + 1

if self._fortran_mem_space != self._pace_mem_space:
self._allocate_output_dir()
safe_assign_array(output_dict["u"], self.dycore_state.u.data[:-1, :, :-1])
safe_assign_array(output_dict["v"], self.dycore_state.v.data[:, :-1, :-1])
safe_assign_array(output_dict["w"], self.dycore_state.w.data[:-1, :-1, :-1])
Expand Down Expand Up @@ -453,27 +478,8 @@ def _prep_outputs_for_geos(self) -> Dict[str, np.ndarray]:
self.dycore_state.diss_estd.data[:-1, :-1, :-1],
)

safe_assign_array(
output_dict["qvapor"], self.dycore_state.qvapor.data[:-1, :-1, :-1]
)
safe_assign_array(
output_dict["qliquid"], self.dycore_state.qliquid.data[:-1, :-1, :-1]
)
safe_assign_array(
output_dict["qice"], self.dycore_state.qice.data[:-1, :-1, :-1]
)
safe_assign_array(
output_dict["qrain"], self.dycore_state.qrain.data[:-1, :-1, :-1]
)
safe_assign_array(
output_dict["qsnow"], self.dycore_state.qsnow.data[:-1, :-1, :-1]
)
safe_assign_array(
output_dict["qgraupel"], self.dycore_state.qgraupel.data[:-1, :-1, :-1]
)
safe_assign_array(
output_dict["qcld"], self.dycore_state.qcld.data[:-1, :-1, :-1]
)
# Copy tracer data
safe_assign_array(output_dict["q"], self.dycore_state.tracers.as_4D_array())
else:
output_dict["u"] = self.dycore_state.u.data[:-1, :, :-1]
output_dict["v"] = self.dycore_state.v.data[:, :-1, :-1]
Expand Down Expand Up @@ -504,23 +510,18 @@ def _prep_outputs_for_geos(self) -> Dict[str, np.ndarray]:
output_dict["q_con"] = self.dycore_state.q_con.data[:-1, :-1, :-1]
output_dict["omga"] = self.dycore_state.omga.data[:-1, :-1, :-1]
output_dict["diss_estd"] = self.dycore_state.diss_estd.data[:-1, :-1, :-1]
output_dict["qvapor"] = self.dycore_state.qvapor.data[:-1, :-1, :-1]
output_dict["qliquid"] = self.dycore_state.qliquid.data[:-1, :-1, :-1]
output_dict["qice"] = self.dycore_state.qice.data[:-1, :-1, :-1]
output_dict["qrain"] = self.dycore_state.qrain.data[:-1, :-1, :-1]
output_dict["qsnow"] = self.dycore_state.qsnow.data[:-1, :-1, :-1]
output_dict["qgraupel"] = self.dycore_state.qgraupel.data[:-1, :-1, :-1]
output_dict["qcld"] = self.dycore_state.qcld.data[:-1, :-1, :-1]
output_dict["q"] = self.dycore_state.tracers.as_4D_array()

return output_dict

def _allocate_output_dir(self):
if len(self.output_dict) != 0:
return
if self._fortran_mem_space != self._pace_mem_space:
nhalo = self._grid_indexing.n_halo
shape_centered = self._grid_indexing.domain_full(add=(0, 0, 0))
shape_x_interface = self._grid_indexing.domain_full(add=(1, 0, 0))
shape_y_interface = self._grid_indexing.domain_full(add=(0, 1, 0))
shape_z_interface = self._grid_indexing.domain_full(add=(0, 0, 1))
shape_2d = shape_centered[:-1]

self.output_dict["u"] = np.empty((shape_y_interface))
Expand Down Expand Up @@ -573,34 +574,3 @@ def _allocate_output_dir(self):
self.output_dict["qsnow"] = np.empty((shape_centered))
self.output_dict["qgraupel"] = np.empty((shape_centered))
self.output_dict["qcld"] = np.empty((shape_centered))
else:
self.output_dict["u"] = None
self.output_dict["v"] = None
self.output_dict["w"] = None
self.output_dict["ua"] = None
self.output_dict["va"] = None
self.output_dict["uc"] = None
self.output_dict["vc"] = None
self.output_dict["delz"] = None
self.output_dict["pt"] = None
self.output_dict["delp"] = None
self.output_dict["mfxd"] = None
self.output_dict["mfyd"] = None
self.output_dict["cxd"] = None
self.output_dict["cyd"] = None
self.output_dict["ps"] = None
self.output_dict["pe"] = None
self.output_dict["pk"] = None
self.output_dict["peln"] = None
self.output_dict["pkz"] = None
self.output_dict["phis"] = None
self.output_dict["q_con"] = None
self.output_dict["omga"] = None
self.output_dict["diss_estd"] = None
self.output_dict["qvapor"] = None
self.output_dict["qliquid"] = None
self.output_dict["qice"] = None
self.output_dict["qrain"] = None
self.output_dict["qsnow"] = None
self.output_dict["qgraupel"] = None
self.output_dict["qcld"] = None

0 comments on commit 1cea564

Please sign in to comment.