Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jdahm committed Nov 10, 2022
1 parent c8a3236 commit 0902b3a
Show file tree
Hide file tree
Showing 26 changed files with 170 additions and 304 deletions.
6 changes: 3 additions & 3 deletions driver/pace/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,9 @@ def _critical_path_step_all(
self.end_of_step_update(
dycore_state=self.state.dycore_state,
phy_state=self.state.physics_state,
u_dt=self.state.tendency_state.u_dt.storage,
v_dt=self.state.tendency_state.v_dt.storage,
pt_dt=self.state.tendency_state.pt_dt.storage,
u_dt=self.state.tendency_state.u_dt.data,
v_dt=self.state.tendency_state.v_dt.data,
pt_dt=self.state.tendency_state.pt_dt.data,
dt=float(dt),
)
self._end_of_step_actions(step)
Expand Down
16 changes: 10 additions & 6 deletions dsl/pace/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,22 @@
from pace.util.mpi import MPI


try:
import cupy as cp
except ImportError:
cp = None


def dace_inhibitor(func: Callable):
"""Triggers callback generation wrapping `func` while doing DaCe parsing."""
return func


def _upload_to_device(host_data: List[Any]):
"""Make sure any data that are still a gt4py.storage gets uploaded to device"""
for data in host_data:
if isinstance(data, gt4py.storage.Storage):
data.host_to_device()
for i, data in enumerate(host_data):
if isinstance(data, cp.ndarray):
host_data[i] = cp.asarray(data)


def _download_results_from_dace(
Expand All @@ -55,9 +61,7 @@ def _download_results_from_dace(
gt4py_results = None
if dace_result is not None:
for arg in args:
if isinstance(arg, gt4py.storage.Storage) and hasattr(
arg, "_set_device_modified"
):
if isinstance(arg, cp.ndarray) and hasattr(arg, "_set_device_modified"):
arg._set_device_modified()
if config.is_gpu_backend():
gt4py_results = [
Expand Down
75 changes: 42 additions & 33 deletions dsl/pace/dsl/gt4py_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import gt4py.backend
import gt4py.storage as gt_storage
import gt4py
import numpy as np

from pace.dsl.typing import DTypes, Field, Float, FloatField
Expand Down Expand Up @@ -50,6 +49,30 @@ def wrapper(*args, **kwargs) -> Any:
return inner


def _mask_to_dimensions(
mask: Tuple[bool, ...], shape: Sequence[int]
) -> List[Union[str, int]]:
assert len(mask) == 3
dimensions: List[Union[str, int]] = []
for i, axis in enumerate(("I", "J", "K")):
if mask[i]:
dimensions.append(axis)
offset = int(sum(mask))
dimensions.extend(shape[offset:])
return dimensions


def _interpolate_origin(origin: Tuple[int, ...], mask: Tuple[bool, ...]) -> List[int]:
assert len(mask) == 3
final_origin: List[int] = []
for i, has_axis in enumerate(mask):
if has_axis:
final_origin.append(origin[i])

final_origin.extend(origin[len(mask) :])
return final_origin


def make_storage_data(
data: Field,
shape: Optional[Tuple[int, ...]] = None,
Expand Down Expand Up @@ -118,6 +141,11 @@ def make_storage_data(
default_mask = (n_dims * (True,)) + ((max_dim - n_dims) * (False,))
mask = default_mask

# Convert to `dimensions` which is the new parameter type that gt4py accepts.
zip(
shape,
)

if n_dims == 1:
data = _make_storage_data_1d(
data, shape, start, dummy, axis, read_only, backend=backend
Expand All @@ -129,14 +157,12 @@ def make_storage_data(
else:
data = _make_storage_data_3d(data, shape, start, backend=backend)

storage = gt_storage.from_array(
data=data,
storage = gt4py.storage.from_array(
data,
dtype,
backend=backend,
default_origin=origin,
shape=shape,
dtype=dtype,
mask=mask,
managed_memory=managed_memory,
aligned_index=_interpolate_origin(origin, mask),
dimensions=_mask_to_dimensions(mask, data.shape),
)
return storage

Expand Down Expand Up @@ -264,13 +290,12 @@ def make_storage_from_shape(
mask = (False, False, True) # Assume 1D is a k-field
else:
mask = (n_dims * (True,)) + ((3 - n_dims) * (False,))
storage = gt_storage.zeros(
storage = gt4py.storage.zeros(
shape,
dtype,
backend=backend,
default_origin=origin,
shape=shape,
dtype=dtype,
mask=mask,
managed_memory=managed_memory,
aligned_index=_interpolate_origin(origin, mask),
dimensions=_mask_to_dimensions(mask, shape),
)
return storage

Expand Down Expand Up @@ -340,8 +365,6 @@ def k_split_run(func, data, k_indices, splitvars_values):


def asarray(array, to_type=np.ndarray, dtype=None, order=None):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
if cp and (isinstance(array, list)):
if to_type is np.ndarray:
order = "F" if order is None else order
Expand Down Expand Up @@ -379,19 +402,15 @@ def is_gpu_backend(backend: str) -> bool:
def zeros(shape, dtype=Float, *, backend: str):
storage_type = cp.ndarray if is_gpu_backend(backend) else np.ndarray
xp = cp if cp and storage_type is cp.ndarray else np
return xp.zeros(shape)
return xp.zeros(shape, dtype=dtype)


def sum(array, axis=None, dtype=Float, out=None, keepdims=False):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.sum(array, axis, dtype, out, keepdims)


def repeat(array, repeats, axis=None):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.repeat(array, repeats, axis)

Expand All @@ -401,22 +420,16 @@ def index(array, key):


def moveaxis(array, source: int, destination: int):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.moveaxis(array, source, destination)


def tile(array, reps: Union[int, Tuple[int, ...]]):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.tile(array, reps)


def squeeze(array, axis: Union[int, Tuple[int]] = None):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.squeeze(array, axis)

Expand Down Expand Up @@ -444,17 +457,13 @@ def unique(
return_counts: bool = False,
axis: Union[int, Tuple[int]] = None,
):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.unique(array, return_index, return_inverse, return_counts, axis)


def stack(tup, axis: int = 0, out=None):
array_tup = []
for array in tup:
if isinstance(array, gt_storage.storage.Storage):
array = array.data
array_tup.append(array)
xp = cp if cp and type(array_tup[0]) is cp.ndarray else np
return xp.stack(array_tup, axis, out)
Expand Down
21 changes: 11 additions & 10 deletions dsl/pace/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import gt4py
import numpy as np
from gt4py import gtscript
from gt4py.storage.storage import Storage
from gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline

import pace.dsl.gt4py_utils as gt4py_utils
Expand All @@ -34,6 +33,12 @@
from pace.util.mpi import MPI


try:
import cupy as cp
except ImportError:
cp = np


def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id):
report_head = f"comparing against numpy for func {function_name}, gt_id {gt_id}:"
report_segments = []
Expand Down Expand Up @@ -431,7 +436,7 @@ def __call__(self, *args, **kwargs) -> None:
f"after calling {self._func_name}"
)

def _mark_cuda_fields_written(self, fields: Mapping[str, Storage]):
def _mark_cuda_fields_written(self, fields: Mapping[str, cp.ndarray]):
if self.stencil_config.is_gpu_backend:
for write_field in self._written_fields:
fields[write_field]._set_device_modified()
Expand Down Expand Up @@ -520,15 +525,11 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None):

def _convert_quantities_to_storage(args, kwargs):
for i, arg in enumerate(args):
try:
args[i] = arg.storage
except AttributeError:
pass
if isinstance(arg, pace.util.Quantity):
args[i] = arg.data
for name, arg in kwargs.items():
try:
kwargs[name] = arg.storage
except AttributeError:
pass
if isinstance(arg, pace.util.Quantity):
kwargs[name] = arg.data


class GridIndexing:
Expand Down
7 changes: 2 additions & 5 deletions fv3core/examples/standalone/runfile/acoustics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,10 @@ def get_state_from_input(
) -> Dict[str, SimpleNamespace]:
"""
Transforms the input data from the dictionary of strings
to arrays into a state we can pass in
Input is a dict of arrays. These are transformed into Storage arrays
useable in GT4Py
to arrays into a state we can pass in.
This will also take care of reshaping the arrays into same sized
fields as required by the acoustics
fields as required by the acoustics.
"""
driver_object = TranslateDynCore([grid], namelist, stencil_config)
driver_object._base.make_storage_data_input_vars(input_data)
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/initialization/dycore_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def init_zeros(cls, quantity_factory: pace.util.QuantityFactory):
if "dims" in _field.metadata.keys():
initial_storages[_field.name] = quantity_factory.zeros(
_field.metadata["dims"], _field.metadata["units"], dtype=float
).storage
).data
return cls.init_from_storages(
storages=initial_storages, sizer=quantity_factory.sizer
)
Expand Down
8 changes: 4 additions & 4 deletions physics/pace/physics/physics_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,14 +317,14 @@ def __post_init__(

@classmethod
def init_zeros(cls, quantity_factory, active_packages: List[str]) -> "PhysicsState":
initial_storages = {}
initial_arrays = {}
for _field in fields(cls):
if "dims" in _field.metadata.keys():
initial_storages[_field.name] = quantity_factory.zeros(
initial_arrays[_field.name] = quantity_factory.zeros(
_field.metadata["dims"], _field.metadata["units"], dtype=float
).storage
).data
return cls(
**initial_storages,
**initial_arrays,
quantity_factory=quantity_factory,
active_packages=active_packages,
)
Expand Down
3 changes: 1 addition & 2 deletions stencils/pace/stencils/testing/temporaries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
from typing import List

import gt4py
import numpy as np

import pace.util
Expand All @@ -15,7 +14,7 @@ def copy_temporaries(obj, max_depth: int) -> dict:
attr = getattr(obj, attr_name)
except AttributeError:
attr = None
if isinstance(attr, (gt4py.storage.storage.Storage, pace.util.Quantity)):
if isinstance(attr, pace.util.Quantity):
temporaries[attr_name] = copy.deepcopy(np.asarray(attr.data))
elif attr.__class__.__module__.split(".")[0] in ( # type: ignore
"fv3core",
Expand Down
3 changes: 1 addition & 2 deletions tests/main/driver/test_restart_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import shutil
from datetime import datetime

import gt4py
import numpy as np
import xarray as xr
import yaml
Expand Down Expand Up @@ -114,7 +113,7 @@ def test_restart_save_to_disk():
for var in driver_state.physics_state.__dict__.keys():
if isinstance(
driver_state.physics_state.__dict__[var],
gt4py.storage.storage.CPUStorage,
np.ndarray,
):
np.testing.assert_allclose(
driver_state.physics_state.__dict__[var].data,
Expand Down
6 changes: 2 additions & 4 deletions tests/main/dsl/test_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@ def _make_storage(
stencil_config: pace.dsl.StencilConfig,
*,
dtype=float,
mask=None,
default_origin=(0, 0, 0),
aligned_index=(0, 0, 0),
):
return func(
backend=stencil_config.compilation_config.backend,
shape=grid_indexing.domain,
dtype=dtype,
mask=mask,
default_origin=default_origin,
aligned_index=aligned_index,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/main/dsl/test_stencil_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_convert_quantities_to_storage_one_arg_quantity():
kwargs = {}
_convert_quantities_to_storage(args, kwargs)
assert len(args) == 1
assert args[0] == quantity.storage
assert args[0] == quantity.data
assert len(kwargs) == 0


Expand All @@ -326,7 +326,7 @@ def test_convert_quantities_to_storage_one_kwarg_quantity():
_convert_quantities_to_storage(args, kwargs)
assert len(args) == 0
assert len(kwargs) == 1
assert kwargs["val"] == quantity.storage
assert kwargs["val"] == quantity.data


def test_convert_quantities_to_storage_one_arg_nonquantity():
Expand Down
4 changes: 2 additions & 2 deletions tests/main/fv3core/test_dycore_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def test_call_does_not_allocate_storages():
def error_func(*args, **kwargs):
raise AssertionError("call not allowed")

with unittest.mock.patch("gt4py.storage.storage.zeros", new=error_func):
with unittest.mock.patch("gt4py.storage.storage.empty", new=error_func):
with unittest.mock.patch("gt4py.storage.zeros", new=error_func):
with unittest.mock.patch("gt4py.storage.empty", new=error_func):
dycore.step_dynamics(state, timer)


Expand Down
Loading

0 comments on commit 0902b3a

Please sign in to comment.