Skip to content

Commit

Permalink
Rename program -> dace_program
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Cattaneo committed Dec 4, 2024
1 parent 9efb5f4 commit c7d6c4f
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions ndsl/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _simplify(


def _build_sdfg(
program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs
dace_program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs
):
"""Build the .so out of the SDFG on the top tile ranks only"""
is_compiling = True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile
Expand All @@ -136,12 +136,12 @@ def _build_sdfg(
make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.CPU)

# Build non-constants & non-transients from the sdfg_kwargs
sdfg_kwargs = program._create_sdfg_args(sdfg, args, kwargs)
for k in program.constant_args:
sdfg_kwargs = dace_program._create_sdfg_args(sdfg, args, kwargs)
for k in dace_program.constant_args:
if k in sdfg_kwargs:
del sdfg_kwargs[k]
sdfg_kwargs = {k: v for k, v in sdfg_kwargs.items() if v is not None}
for k, tup in program.resolver.closure_arrays.items():
for k, tup in dace_program.resolver.closure_arrays.items():
if k in sdfg_kwargs and tup[1].transient:
del sdfg_kwargs[k]

Expand Down Expand Up @@ -216,26 +216,30 @@ def _build_sdfg(
MPI.COMM_WORLD.Barrier()

with DaCeProgress(config, "Loading"):
sdfg_path = get_sdfg_path(program.name, config, override_run_only=True)
compiledSDFG, _ = program.load_precompiled_sdfg(sdfg_path, *args, **kwargs)
config.loaded_precompiled_SDFG[program] = FrozenCompiledSDFG(
program, compiledSDFG, args, kwargs
sdfg_path = get_sdfg_path(dace_program.name, config, override_run_only=True)
compiledSDFG, _ = dace_program.load_precompiled_sdfg(
sdfg_path, *args, **kwargs
)
config.loaded_precompiled_SDFG[dace_program] = FrozenCompiledSDFG(
dace_program, compiledSDFG, args, kwargs
)

return _call_sdfg(program, sdfg, config, args, kwargs)
return _call_sdfg(dace_program, sdfg, config, args, kwargs)


def _call_sdfg(program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs):
def _call_sdfg(
dace_program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs
):
"""Dispatch the SDFG execution and/or build"""
# Pre-compiled SDFG code path does away with any data checks and
# cached the marshalling - leading to almost direct C call
# DaceProgram performs argument transformation & checks for a cost ~200ms
# of overhead
if program in config.loaded_precompiled_SDFG:
if dace_program in config.loaded_precompiled_SDFG:
with DaCeProgress(config, "Run"):
if config.is_gpu_backend():
_upload_to_device(list(args) + list(kwargs.values()))
res = config.loaded_precompiled_SDFG[program]()
res = config.loaded_precompiled_SDFG[dace_program]()
res = _download_results_from_dace(
config, res, list(args) + list(kwargs.values())
)
Expand All @@ -244,7 +248,7 @@ def _call_sdfg(program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args,
mode = config.get_orchestrate()
if mode in [DaCeOrchestration.Build, DaCeOrchestration.BuildAndRun]:
ndsl_log.info("Building DaCe orchestration")
return _build_sdfg(program, sdfg, config, args, kwargs)
return _build_sdfg(dace_program, sdfg, config, args, kwargs)

if mode == DaCeOrchestration.Run:
# We should never hit this, it should be caught by the
Expand All @@ -255,7 +259,7 @@ def _call_sdfg(program: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args,


def _parse_sdfg(
program: DaceProgram,
dace_program: DaceProgram,
config: DaceConfig,
*args,
**kwargs,
Expand All @@ -264,15 +268,15 @@ def _parse_sdfg(
Either parses, load a .sdfg or load .so (as a compiled sdfg)
Attributes:
program: the DaceProgram carrying reference to the original method/function
dace_program: the DaceProgram carrying reference to the original method/function
config: the DaceConfig configuration for this execution
"""
# Check cache for already loaded SDFG
if program in config.loaded_precompiled_SDFG:
return config.loaded_precompiled_SDFG[program]
if dace_program in config.loaded_precompiled_SDFG:
return config.loaded_precompiled_SDFG[dace_program]

# Build expected path
sdfg_path = get_sdfg_path(program.name, config)
sdfg_path = get_sdfg_path(dace_program.name, config)
if sdfg_path is None:
is_compiling = (
True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile
Expand All @@ -283,10 +287,10 @@ def _parse_sdfg(
# compiled SDFG from the compiling rank
return None

with DaCeProgress(config, f"Parse code of {program.name} to SDFG"):
sdfg = program.to_sdfg(
with DaCeProgress(config, f"Parse code of {dace_program.name} to SDFG"):
sdfg = dace_program.to_sdfg(
*args,
**program.__sdfg_closure__(),
**dace_program.__sdfg_closure__(),
**kwargs,
save=False,
simplify=False,
Expand All @@ -295,13 +299,13 @@ def _parse_sdfg(

if os.path.isfile(sdfg_path):
with DaCeProgress(config, "Load .sdfg"):
sdfg, _ = program.load_sdfg(sdfg_path, *args, **kwargs)
sdfg, _ = dace_program.load_sdfg(sdfg_path, *args, **kwargs)
return sdfg

with DaCeProgress(config, "Load precompiled .sdfg (.so)"):
compiledSDFG, _ = program.load_precompiled_sdfg(sdfg_path, *args, **kwargs)
config.loaded_precompiled_SDFG[program] = FrozenCompiledSDFG(
program, compiledSDFG, args, kwargs
compiledSDFG, _ = dace_program.load_precompiled_sdfg(sdfg_path, *args, **kwargs)
config.loaded_precompiled_SDFG[dace_program] = FrozenCompiledSDFG(
dace_program, compiledSDFG, args, kwargs
)
return compiledSDFG

Expand Down

0 comments on commit c7d6c4f

Please sign in to comment.