diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index b5417cec..767610c3 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -40,12 +40,12 @@ cp = None -def dace_inhibitor(func: Callable): +def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" return func -def _upload_to_device(host_data: List[Any]): +def _upload_to_device(host_data: List[Any]) -> None: """Make sure any ndarrays gets uploaded to the device This will raise an assertion if cupy is not installed. @@ -60,22 +60,11 @@ def _download_results_from_dace( config: DaceConfig, dace_result: Optional[List[Any]], args: List[Any] ): """Move all data from DaCe memory space to GT4Py""" - gt4py_results = None - if dace_result is not None: - if config.is_gpu_backend(): - gt4py_results = [ - gt4py.storage.from_array( - r, - backend=config.get_backend(), - ) - for r in dace_result - ] - else: - gt4py_results = [ - gt4py.storage.from_array(r, backend=config.get_backend()) - for r in dace_result - ] - return gt4py_results + if dace_result is None: + return None + + backend = config.get_backend() + return [gt4py.storage.from_array(result, backend=backend) for result in dace_result] def _to_gpu(sdfg: dace.SDFG): @@ -99,8 +88,8 @@ def _to_gpu(sdfg: dace.SDFG): else: arr.storage = dace.StorageType.GPU_Global - # All maps will be scedule on GPU - for mapentry, state in topmaps: + # All maps will be schedule on GPU + for mapentry, _state in topmaps: mapentry.schedule = dace.ScheduleType.GPU_Device # Deactivate OpenMP sections @@ -108,25 +97,30 @@ def _to_gpu(sdfg: dace.SDFG): sd.openmp_sections = False -def _simplify(sdfg: dace.SDFG, validate=True, verbose=False): +def _simplify( + sdfg: dace.SDFG, + *, + validate: bool = True, + validate_all: bool = False, + verbose: bool = False, +): """Override of sdfg.simplify to skip failing transformation per https://github.com/spcl/dace/issues/1328 """ return SimplifyPass( validate=validate, + validate_all=validate_all, verbose=verbose, skip=["ConstantPropagation"], ).apply_pass(sdfg, {}) def _build_sdfg( - daceprog: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs + dace_program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs ): """Build the .so out of the SDFG on the top tile ranks only""" - if DEACTIVATE_DISTRIBUTED_DACE_COMPILE: - is_compiling = True - else: - is_compiling = config.do_compile + is_compiling = True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile + if is_compiling: # Make the transients array persistents if config.is_gpu_backend(): @@ -136,18 +130,18 @@ def _build_sdfg( # Upload args to device _upload_to_device(list(args) + list(kwargs.values())) else: - for sd, _aname, arr in sdfg.arrays_recursive(): + for _sd, _aname, arr in sdfg.arrays_recursive(): if arr.shape == (1,): arr.storage = DaceStorageType.Register make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.CPU) # Build non-constants & non-transients from the sdfg_kwargs - sdfg_kwargs = daceprog._create_sdfg_args(sdfg, args, kwargs) - for k in daceprog.constant_args: + sdfg_kwargs = dace_program._create_sdfg_args(sdfg, args, kwargs) + for k in dace_program.constant_args: if k in sdfg_kwargs: del sdfg_kwargs[k] sdfg_kwargs = {k: v for k, v in sdfg_kwargs.items() if v is not None} - for k, tup in daceprog.resolver.closure_arrays.items(): + for k, tup in dace_program.resolver.closure_arrays.items(): if k in sdfg_kwargs and tup[1].transient: del sdfg_kwargs[k] @@ -204,13 +198,16 @@ def _build_sdfg( # On BuildAndRun: all ranks sync, then load the SDFG from # the expected path (made available by build). # We use a "FrozenCompiledSDFG" to minimize re-entry cost at call time + + mode = config.get_orchestrate() # DEV NOTE: we explicitly use MPI.COMM_WORLD here because it is # a true multi-machine sync, outside of our own communicator class. - if config.get_orchestrate() == DaCeOrchestration.Build: + if mode == DaCeOrchestration.Build: MPI.COMM_WORLD.Barrier() # Protect against early exist which kill SLURM jobs ndsl_log.info(f"{DaCeProgress.default_prefix(config)} Build only, exiting.") exit(0) - elif config.get_orchestrate() == DaCeOrchestration.BuildAndRun: + + if mode == DaCeOrchestration.BuildAndRun: if not is_compiling: ndsl_log.info( f"{DaCeProgress.default_prefix(config)} Rank is not compiling." @@ -219,52 +216,50 @@ def _build_sdfg( MPI.COMM_WORLD.Barrier() with DaCeProgress(config, "Loading"): - sdfg_path = get_sdfg_path(daceprog.name, config, override_run_only=True) - csdfg, _ = daceprog.load_precompiled_sdfg(sdfg_path, *args, **kwargs) - config.loaded_precompiled_SDFG[daceprog] = FrozenCompiledSDFG( - daceprog, csdfg, args, kwargs + sdfg_path = get_sdfg_path(dace_program.name, config, override_run_only=True) + compiledSDFG, _ = dace_program.load_precompiled_sdfg( + sdfg_path, *args, **kwargs + ) + config.loaded_precompiled_SDFG[dace_program] = FrozenCompiledSDFG( + dace_program, compiledSDFG, args, kwargs ) - return _call_sdfg(daceprog, sdfg, config, args, kwargs) + return _call_sdfg(dace_program, sdfg, config, args, kwargs) def _call_sdfg( - daceprog: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs + dace_program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs ): """Dispatch the SDFG execution and/or build""" # Pre-compiled SDFG code path does away with any data checks and # cached the marshalling - leading to almost direct C call # DaceProgram performs argument transformation & checks for a cost ~200ms # of overhead - if daceprog in config.loaded_precompiled_SDFG: + if dace_program in config.loaded_precompiled_SDFG: with DaCeProgress(config, "Run"): if config.is_gpu_backend(): _upload_to_device(list(args) + list(kwargs.values())) - res = config.loaded_precompiled_SDFG[daceprog]() + res = config.loaded_precompiled_SDFG[dace_program]() res = _download_results_from_dace( config, res, list(args) + list(kwargs.values()) ) return res + + mode = config.get_orchestrate() + if mode in [DaCeOrchestration.Build, DaCeOrchestration.BuildAndRun]: + ndsl_log.info("Building DaCe orchestration") + return _build_sdfg(dace_program, sdfg, config, args, kwargs) + + if mode == DaCeOrchestration.Run: + # We should never hit this, it should be caught by the + # loaded_precompiled_SDFG check above + raise RuntimeError("Unexpected call - pre-compiled SDFG failed to load") else: - if ( - config.get_orchestrate() == DaCeOrchestration.Build - or config.get_orchestrate() == DaCeOrchestration.BuildAndRun - ): - ndsl_log.info("Building DaCe orchestration") - res = _build_sdfg(daceprog, sdfg, config, args, kwargs) - elif config.get_orchestrate() == DaCeOrchestration.Run: - # We should never hit this, it should be caught by the - # loaded_precompiled_SDFG check above - raise RuntimeError("Unexpected call - csdfg didn't get caught") - else: - raise NotImplementedError( - f"Mode {config.get_orchestrate()} unimplemented at call time" - ) - return res + raise NotImplementedError(f"Mode '{mode}' unimplemented at call time") def _parse_sdfg( - daceprog: DaceProgram, + dace_program: DaceProgram, config: DaceConfig, *args, **kwargs, @@ -273,45 +268,46 @@ def _parse_sdfg( Either parses, load a .sdfg or load .so (as a compiled sdfg) Attributes: - daceprog: the DaceProgram carrying reference to the original method/function + dace_program: the DaceProgram carrying reference to the original method/function config: the DaceConfig configuration for this execution """ # Check cache for already loaded SDFG - if daceprog in config.loaded_precompiled_SDFG: - return config.loaded_precompiled_SDFG[daceprog] + if dace_program in config.loaded_precompiled_SDFG: + return config.loaded_precompiled_SDFG[dace_program] # Build expected path - sdfg_path = get_sdfg_path(daceprog.name, config) + sdfg_path = get_sdfg_path(dace_program.name, config) if sdfg_path is None: - if DEACTIVATE_DISTRIBUTED_DACE_COMPILE: - is_compiling = True - else: - is_compiling = config.do_compile + is_compiling = ( + True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile + ) + if not is_compiling: # We can not parse the SDFG since we will load the proper # compiled SDFG from the compiling rank return None - with DaCeProgress(config, f"Parse code of {daceprog.name} to SDFG"): - sdfg = daceprog.to_sdfg( + + with DaCeProgress(config, f"Parse code of {dace_program.name} to SDFG"): + sdfg = dace_program.to_sdfg( *args, - **daceprog.__sdfg_closure__(), + **dace_program.__sdfg_closure__(), **kwargs, save=False, simplify=False, ) return sdfg - else: - if os.path.isfile(sdfg_path): - with DaCeProgress(config, "Load .sdfg"): - sdfg, _ = daceprog.load_sdfg(sdfg_path, *args, **kwargs) - return sdfg - else: - with DaCeProgress(config, "Load precompiled .sdfg (.so)"): - csdfg, _ = daceprog.load_precompiled_sdfg(sdfg_path, *args, **kwargs) - config.loaded_precompiled_SDFG[daceprog] = FrozenCompiledSDFG( - daceprog, csdfg, args, kwargs - ) - return csdfg + + if os.path.isfile(sdfg_path): + with DaCeProgress(config, "Load .sdfg"): + sdfg, _ = dace_program.load_sdfg(sdfg_path, *args, **kwargs) + return sdfg + + with DaCeProgress(config, "Load precompiled .sdfg (.so)"): + compiledSDFG, _ = dace_program.load_precompiled_sdfg(sdfg_path, *args, **kwargs) + config.loaded_precompiled_SDFG[dace_program] = FrozenCompiledSDFG( + dace_program, compiledSDFG, args, kwargs + ) + return compiledSDFG class _LazyComputepathFunction(SDFGConvertible): @@ -448,7 +444,7 @@ def orchestrate( ): """ Orchestrate a method of an object with DaCe. - The method object is patched in place, replacing the orignal Callable with + The method object is patched in place, replacing the original Callable with a wrapper that will trigger orchestration at call time. If the model configuration doesn't demand orchestration, this won't do anything. @@ -463,72 +459,71 @@ def orchestrate( dace_compiletime_args = [] if config.is_dace_orchestrated(): - if hasattr(obj, method_to_orchestrate): - func = type.__getattribute__(type(obj), method_to_orchestrate) - - # Flag argument as dace.constant - for argument in dace_compiletime_args: - func.__annotations__[argument] = DaceCompiletime - - # Build DaCe orchestrated wrapper - # This is a JIT object, e.g. DaCe compilation will happen on call - wrapped = _LazyComputepathMethod(func, config).__get__(obj) - - if method_to_orchestrate == "__call__": - # Grab the function from the type of the child class - # Dev note: we need to use type for dunder call because: - # a = A() - # a() - # resolved to: type(a).__call__(a) - # therefore patching the instance call (e.g a.__call__) is not enough. - # We could patch the type(self), ergo the class itself - # but that would patch _every_ instance of A. - # What we can do is patch the instance.__class__ with a local made class - # in order to keep each instance with it's own patch. - # - # Re: type:ignore - # Mypy is unhappy about dynamic class name and the devs (per github - # issues discussion) is to make a plugin. Too much work -> ignore mypy - - class _(type(obj)): # type: ignore - __qualname__ = f"{type(obj).__qualname__}_patched" - __name__ = f"{type(obj).__name__}_patched" - - def __call__(self, *arg, **kwarg): - return wrapped(*arg, **kwarg) - - def __sdfg__(self, *args, **kwargs): - return wrapped.__sdfg__(*args, **kwargs) - - def __sdfg_closure__(self, reevaluate=None): - return wrapped.__sdfg_closure__(reevaluate) - - def __sdfg_signature__(self): - return wrapped.__sdfg_signature__() - - def closure_resolver( - self, constant_args, given_args, parent_closure=None - ): - return wrapped.closure_resolver( - constant_args, given_args, parent_closure - ) - - # We keep the original class type name to not perturb - # the workflows that uses it to build relevant info (path, hash...) - previous_cls_name = type(obj).__name__ - obj.__class__ = _ - type(obj).__name__ = previous_cls_name - else: - # For regular attribute - we can just patch as usual - setattr(obj, method_to_orchestrate, wrapped) - - else: + if not hasattr(obj, method_to_orchestrate): raise RuntimeError( f"Could not orchestrate, " f"{type(obj).__name__}.{method_to_orchestrate} " "does not exists" ) + func = type.__getattribute__(type(obj), method_to_orchestrate) + + # Flag argument as dace.constant + for argument in dace_compiletime_args: + func.__annotations__[argument] = DaceCompiletime + + # Build DaCe orchestrated wrapper + # This is a JIT object, e.g. DaCe compilation will happen on call + wrapped = _LazyComputepathMethod(func, config).__get__(obj) + + if method_to_orchestrate == "__call__": + # Grab the function from the type of the child class + # Dev note: we need to use type for dunder call because: + # a = A() + # a() + # resolved to: type(a).__call__(a) + # therefore patching the instance call (e.g a.__call__) is not enough. + # We could patch the type(self), ergo the class itself + # but that would patch _every_ instance of A. + # What we can do is patch the instance.__class__ with a local made class + # in order to keep each instance with it's own patch. + # + # Re: type:ignore + # Mypy is unhappy about dynamic class name and the devs (per github + # issues discussion) is to make a plugin. Too much work -> ignore mypy + + class _(type(obj)): # type: ignore + __qualname__ = f"{type(obj).__qualname__}_patched" + __name__ = f"{type(obj).__name__}_patched" + + def __call__(self, *arg, **kwarg): + return wrapped(*arg, **kwarg) + + def __sdfg__(self, *args, **kwargs): + return wrapped.__sdfg__(*args, **kwargs) + + def __sdfg_closure__(self, reevaluate=None): + return wrapped.__sdfg_closure__(reevaluate) + + def __sdfg_signature__(self): + return wrapped.__sdfg_signature__() + + def closure_resolver( + self, constant_args, given_args, parent_closure=None + ): + return wrapped.closure_resolver( + constant_args, given_args, parent_closure + ) + + # We keep the original class type name to not perturb + # the workflows that uses it to build relevant info (path, hash...) + previous_cls_name = type(obj).__name__ + obj.__class__ = _ + type(obj).__name__ = previous_cls_name + else: + # For regular attribute - we can just patch as usual + setattr(obj, method_to_orchestrate, wrapped) + def orchestrate_function( config: DaceConfig = None, @@ -553,9 +548,6 @@ def _wrapper(*args, **kwargs): func.__annotations__[argument] = DaceCompiletime return _LazyComputepathFunction(func, config) - if config.is_dace_orchestrated(): - return _wrapper(func) - else: - return func + return _wrapper(func) if config.is_dace_orchestrated() else func return _decorator