Skip to content

Commit

Permalink
Fix parser for sleap-export (#1085)
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys authored Dec 16, 2022
1 parent 1595bec commit c4861e3
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
12 changes: 10 additions & 2 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,17 @@ usage: sleap-export [-h] [-m MODELS] [-e [EXPORT_PATH]]
optional arguments:
-h, --help show this help message and exit
-m MODELS, --model MODELS
Path to trained model directory (with training_config.json). Multiple models can be specified, each preceded by --model.
Path to trained model directory (with training_config.json). Multiple
models can be specified, each preceded by --model.
-e [EXPORT_PATH], --export_path [EXPORT_PATH]
Path to output directory where the frozen model will be exported to. Defaults to a folder named 'exported_model'.
Path to output directory where the frozen model will be exported to.
Defaults to a folder named 'exported_model'.
-u, --unrag UNRAG
Convert ragged tensors into regular tensors with NaN padding.
Defaults to True.
-i, --max_instances MAX_INSTANCES
Limit maximum number of instances in multi-instance models.
Defaults to None.
```

## Inference and Tracking
Expand Down
27 changes: 23 additions & 4 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4353,9 +4353,29 @@ def export_model(
)


def export_cli():
def export_cli(args: Optional[list] = None):
"""CLI for sleap-export."""

parser = _make_export_cli_parser()

args, _ = parser.parse_known_args(args=args)
print("Args:")
pprint(vars(args))
print()

export_model(
args.models,
args.export_path,
unrag_outputs=args.unrag,
max_instances=args.max_instances,
)


def _make_export_cli_parser() -> argparse.ArgumentParser:
"""Create argument parser for sleap-export CLI."""

parser = argparse.ArgumentParser()

parser.add_argument(
"-m",
"--model",
Expand Down Expand Up @@ -4388,7 +4408,7 @@ def export_cli():
),
)
parser.add_argument(
"-m",
"-i",
"--max_instances",
type=int,
help=(
Expand All @@ -4397,8 +4417,7 @@ def export_cli():
),
)

args, _ = parser.parse_known_args()
export_model(args.models, args.export_path, unrag_outputs=args.unrag)
return parser


def _make_cli_parser() -> argparse.ArgumentParser:
Expand Down
3 changes: 3 additions & 0 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_make_cli_parser,
_make_tracker_from_cli,
main as sleap_track,
export_cli as sleap_export,
)

from sleap.gui.learning import runners
Expand Down Expand Up @@ -1088,6 +1089,8 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm

# high level export (with unragging)
export_model(min_single_instance_robot_model_path, save_path=tmp_path.as_posix())
cmd = f"-m {min_single_instance_robot_model_path} -e {tmp_path.as_posix()}"
sleap_export(cmd.split())

# high level export (without unragging)
export_model(
Expand Down

0 comments on commit c4861e3

Please sign in to comment.