From 892f38e6e76e4465e54add49f80ed3e11076bd8f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 19 Dec 2024 11:47:07 -0500 Subject: [PATCH] Fix for floating point precision kept through np.power Pass use_logp and check for it --- pyFV3/stencils/nh_p_grad.py | 14 ++++++++++++-- tests/savepoint/translate/translate_nh_p_grad.py | 1 + 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pyFV3/stencils/nh_p_grad.py b/pyFV3/stencils/nh_p_grad.py index 6a951c4..03988d0 100644 --- a/pyFV3/stencils/nh_p_grad.py +++ b/pyFV3/stencils/nh_p_grad.py @@ -1,3 +1,4 @@ +import numpy as np from gt4py.cartesian.gtscript import PARALLEL, computation, interval from ndsl import QuantityFactory, StencilFactory, orchestrate @@ -127,7 +128,8 @@ def __init__( stencil_factory: StencilFactory, quantity_factory: QuantityFactory, grid_data: GridData, - grid_type, + grid_type: int, + use_logp: bool, ): orchestrate( obj=self, @@ -142,6 +144,14 @@ def __init__( self.nk = grid_indexing.domain[2] self._rdx = grid_data.rdx self._rdy = grid_data.rdy + self._use_logp = use_logp + if self._use_logp: + # Requires computing and carrying `peln1` see below on + # top_level calculation + raise NotImplementedError( + "Non Hydrostatic Pressure Gradient (nh_p_grad) with" + " `use_logp` is not implemented." + ) self._tmp_wk = quantity_factory.zeros( [X_DIM, Y_DIM, Z_INTERFACE_DIM], @@ -226,7 +236,7 @@ def __call__( # Fortran names: # u=u v=v pp=pkc gz=gz pk3=pk3 delp=delp dt=dt - ptk = ptop ** akap + ptk = np.power(ptop, akap, dtype=Float) top_value = ptk # = peln1 if spec.namelist.use_logp else ptk # TODO: make it clearer that each of these a2b outputs is updated diff --git a/tests/savepoint/translate/translate_nh_p_grad.py b/tests/savepoint/translate/translate_nh_p_grad.py index 524dd5b..8220f83 100644 --- a/tests/savepoint/translate/translate_nh_p_grad.py +++ b/tests/savepoint/translate/translate_nh_p_grad.py @@ -39,6 +39,7 @@ def compute(self, inputs): self.grid.quantity_factory, grid_data=self.grid.grid_data, grid_type=self.namelist.grid_type, + use_logp=self.namelist.use_logp, ) self.make_storage_data_input_vars(inputs) self.compute_func(**inputs)