Skip to content

Commit

Permalink
Update README and weights saving imports
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Sep 18, 2024
1 parent 32d3fa0 commit 174e73d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion models/turbine_models/custom_models/torchbench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ cd ..
### Export and compile

```shell
python ./export.py --model_id=All --target=gfx942 --device=hip --compile_to=vmfb --accuracy --inference
python ./export.py --target=gfx942 --device=rocm --compile_to=vmfb --performance --inference --precision=fp16 --float16 --external_weights=safetensors --external_weights_dir=./torchbench_weights/
```
3 changes: 2 additions & 1 deletion models/turbine_models/custom_models/torchbench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import safetensors
import safetensors.numpy as safe_numpy
import safetensors.torch as safe_torch
import re
import glob

Expand Down Expand Up @@ -461,7 +462,7 @@ def save_external_weights(
mod_params = vae_params
if external_weight_file and not os.path.isfile(external_weight_file):
if not force_format:
safetensors.torch.save_file(mod_params, external_weight_file)
safe_torch.save_file(mod_params, external_weight_file)
else:
for x in mod_params.keys():
mod_params[x] = mod_params[x].numpy()
Expand Down

0 comments on commit 174e73d

Please sign in to comment.