Skip to content

Commit

Permalink
fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Jul 25, 2024
1 parent a3e5f05 commit 4f6cbb4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 11 additions & 4 deletions models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ def _output_cast(self, output):
return output.to_host().astype(np_dtypes[self.dest_dtype])
case _:
return output

def save_output(self, function_name, output):
if isinstance(output, tuple) or isinstance(output, list):
for i in output:
self.save_output(function_name, i)
else:
np.save(
f"{function_name}_output_{self.output_counter}.npy", output.to_host()
)
self.output_counter += 1

def _run(self, function_name, inputs: list):
return self.module[function_name](*inputs)
Expand All @@ -247,10 +257,7 @@ def __call__(self, function_name, inputs: list):
else:
output = self._run(function_name, inputs)
if self.save_outputs:
np.save(
f"{function_name}_output_{self.output_counter}.npy", output.to_host()
)
self.output_counter += 1
self.save_output(function_name, output)
output = self._output_cast(output)
return output

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def __init__(
batch_prompts: bool = False,
punet_quant_paths: dict[str] = None,
vae_weight_path: str = None,
vae_harness: bool = False,
vae_harness: bool = True,
add_tk_kernels: bool = False,
save_outputs: bool | dict[bool] = False,
):
Expand Down

0 comments on commit 4f6cbb4

Please sign in to comment.