Skip to content

Commit

Permalink
Fix translate_init_case
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Dec 19, 2024
1 parent 470e188 commit c9a40e8
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions tests/savepoint/translate/translate_init_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest

import ndsl.constants as constants
import ndsl.dsl.gt4py_utils as utils
import pyFV3.initialization.analytic_init as analytic_init
import pyFV3.initialization.init_utils as init_utils
import pyFV3.initialization.test_cases.initialize_baroclinic as baroclinic_init
Expand All @@ -20,7 +19,6 @@
)
from ndsl.grid import GridData, MetricTerms
from ndsl.stencils.testing import ParallelTranslateBaseSlicing
from ndsl.stencils.testing.grid import TRACER_DIM # type: ignore
from pyFV3.testing import TranslateDycoreFortranData2Py


Expand Down Expand Up @@ -112,7 +110,7 @@ class TranslateInitCase(ParallelTranslateBaseSlicing):
},
"q4d": {
"name": "tracers",
"dims": [X_DIM, Y_DIM, Z_DIM, TRACER_DIM],
"dims": [X_DIM, Y_DIM, Z_DIM, "tracers"],
"units": "kg/kg",
},
}
Expand Down Expand Up @@ -166,6 +164,10 @@ def __init__(
self.ignore_near_zero_errors[var] = {"near_zero": 2e-13}
self.namelist = namelist # type: ignore
self.stencil_factory = stencil_factory
self._quantity_factory = QuantityFactory.from_backend(
sizer=stencil_factory.grid_indexing._sizer,
backend=stencil_factory.backend,
)

def compute_sequential(self, *args, **kwargs):
pytest.skip(
Expand All @@ -177,10 +179,8 @@ def outputs_from_state(self, state: dict):
outputs = {}
arrays = {}
for name, properties in self.outputs.items():
if isinstance(state[name], dict):
for tracer, quantity in state[name].items():
state[name][tracer] = state[name][tracer].data
arrays[name] = state[name]
if name == "q4d":
arrays[name] = state["tracers"].as_4D_array()
elif len(self.outputs[name]["dims"]) > 0:
arrays[name] = state[name].data
else:
Expand Down Expand Up @@ -229,7 +229,6 @@ def compute_parallel(self, inputs, communicator):
)

grid_data = GridData.new_from_metric_terms(metric_terms)
quantity_factory = QuantityFactory()

state = analytic_init.init_analytic_state(
analytic_init_case="baroclinic",
Expand All @@ -241,9 +240,6 @@ def compute_parallel(self, inputs, communicator):
comm=communicator,
)

state.q4d = {}
for tracer in utils.tracer_variables:
state.q4d[tracer] = getattr(state, tracer)
return self.outputs_from_state(state.__dict__)


Expand Down

0 comments on commit c9a40e8

Please sign in to comment.