diff --git a/docs/guides/cli.md b/docs/guides/cli.md index ce047d59e..acdba57ef 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -91,9 +91,17 @@ usage: sleap-export [-h] [-m MODELS] [-e [EXPORT_PATH]] optional arguments: -h, --help show this help message and exit -m MODELS, --model MODELS - Path to trained model directory (with training_config.json). Multiple models can be specified, each preceded by --model. + Path to trained model directory (with training_config.json). Multiple + models can be specified, each preceded by --model. -e [EXPORT_PATH], --export_path [EXPORT_PATH] - Path to output directory where the frozen model will be exported to. Defaults to a folder named 'exported_model'. + Path to output directory where the frozen model will be exported to. + Defaults to a folder named 'exported_model'. + -u, --unrag UNRAG + Convert ragged tensors into regular tensors with NaN padding. + Defaults to True. + -i, --max_instances MAX_INSTANCES + Limit maximum number of instances in multi-instance models. + Defaults to None. ``` ## Inference and Tracking diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index dc7962c63..bed6dd1c4 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4353,9 +4353,29 @@ def export_model( ) -def export_cli(): +def export_cli(args: Optional[list] = None): """CLI for sleap-export.""" + + parser = _make_export_cli_parser() + + args, _ = parser.parse_known_args(args=args) + print("Args:") + pprint(vars(args)) + print() + + export_model( + args.models, + args.export_path, + unrag_outputs=args.unrag, + max_instances=args.max_instances, + ) + + +def _make_export_cli_parser() -> argparse.ArgumentParser: + """Create argument parser for sleap-export CLI.""" + parser = argparse.ArgumentParser() + parser.add_argument( "-m", "--model", @@ -4388,7 +4408,7 @@ def export_cli(): ), ) parser.add_argument( - "-m", + "-i", "--max_instances", type=int, help=( @@ -4397,8 +4417,7 @@ def export_cli(): ), ) - args, _ = parser.parse_known_args() - export_model(args.models, args.export_path, unrag_outputs=args.unrag) + return parser def _make_cli_parser() -> argparse.ArgumentParser: diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index a1fbb7353..7e817bd67 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -41,6 +41,7 @@ _make_cli_parser, _make_tracker_from_cli, main as sleap_track, + export_cli as sleap_export, ) from sleap.gui.learning import runners @@ -1088,6 +1089,8 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm # high level export (with unragging) export_model(min_single_instance_robot_model_path, save_path=tmp_path.as_posix()) + cmd = f"-m {min_single_instance_robot_model_path} -e {tmp_path.as_posix()}" + sleap_export(cmd.split()) # high level export (without unragging) export_model(