From b2ea737a76f469a42cb949b6dbb5c29555cf566e Mon Sep 17 00:00:00 2001 From: Nicholas Gao Date: Thu, 22 Aug 2024 22:12:18 +0200 Subject: [PATCH 1/2] switch from docopt to docopt-ng --- dev-requirements.txt | 2 +- requirements.txt | 2 +- sacred/arg_parser.py | 7 +++++++ sacred/experiment.py | 7 ++++--- tests/test_arg_parser.py | 4 ++-- tests/test_observers/test_s3_observer.py | 16 ++++++++-------- 6 files changed, 23 insertions(+), 15 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index c3d9c470..430dec99 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,6 +1,6 @@ pytest==7.1.2 # tests/test_utils.py depends on that pytest version is exactly 7.1.2 colorama -docopt +docopt-ng gitdb2 GitPython hashfs diff --git a/requirements.txt b/requirements.txt index a20916af..f53b9b0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -docopt>=0.3, <1.0 +docopt-ng>=0.9, <1.0 jsonpickle>=2.2.0 munch>=2.5, <5.0 wrapt>=1.0, <2.0 diff --git a/sacred/arg_parser.py b/sacred/arg_parser.py index ba564fa1..2319bf44 100644 --- a/sacred/arg_parser.py +++ b/sacred/arg_parser.py @@ -9,6 +9,7 @@ import ast import textwrap import inspect +import re from shlex import quote from sacred.serializer import restore @@ -199,6 +200,12 @@ def format_usage(program_name, description, commands=None, options=()): return usage +def printable_usage(doc): + # in python < 2.7 you can't pass flags=re.IGNORECASE + usage_split = re.split(r"([Uu][Ss][Aa][Gg][Ee]:)", doc) + return re.split(r"\n\s*\n", "".join(usage_split[1:]))[0].strip() + + def _get_first_line_of_docstring(func): return textwrap.dedent(func.__doc__ or "").strip().split("\n")[0] diff --git a/sacred/experiment.py b/sacred/experiment.py index a0b7e9dc..8f59e8d3 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -1,4 +1,5 @@ """The Experiment class, which is central to sacred.""" + import inspect import os.path import sys @@ -6,10 +7,10 @@ from collections import OrderedDict from typing import Sequence, Optional, List -from docopt import docopt, printable_usage +from docopt import docopt from sacred import SETTINGS -from sacred.arg_parser import format_usage, get_config_updates +from sacred.arg_parser import get_config_updates, format_usage, printable_usage from sacred import commandline_options from sacred.commandline_options import CLIOption from sacred.commands import ( @@ -294,7 +295,7 @@ def run_commandline(self, argv=None) -> Optional[Run]: """ argv = ensure_wellformed_argv(argv) short_usage, usage, internal_usage = self.get_usage() - args = docopt(internal_usage, [str(a) for a in argv[1:]], help=False) + args = docopt(internal_usage, [str(a) for a in argv[1:]], default_help=False) cmd_name = args.get("COMMAND") or self.default_command config_updates, named_configs = get_config_updates(args["UPDATE"]) diff --git a/tests/test_arg_parser.py b/tests/test_arg_parser.py index 2508ab98..d0cf0649 100644 --- a/tests/test_arg_parser.py +++ b/tests/test_arg_parser.py @@ -51,8 +51,8 @@ def test_parse_individual_arguments(argv, expected): options = gather_command_line_options() usage = format_usage("test.py", "", {}, options) argv = shlex.split(argv) - plain = docopt(usage, [], help=False) - args = docopt(usage, argv, help=False) + plain = docopt(usage, [], default_help=False) + args = docopt(usage, argv, default_help=False) plain.update(expected) assert args == plain diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index c4e2111d..a9000cef 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -77,7 +77,7 @@ def _get_file_data(bucket_name, key): return s3.Object(bucket_name, key).get()["Body"].read() -@moto.mock_s3 +@moto.mock_aws def test_fs_observer_started_event_creates_bucket(observer, sample_run): _id = observer.started_event(**sample_run) run_dir = s3_join(BASEDIR, str(_id)) @@ -102,7 +102,7 @@ def test_fs_observer_started_event_creates_bucket(observer, sample_run): } -@moto.mock_s3 +@moto.mock_aws def test_fs_observer_started_event_increments_run_id(observer, sample_run): _id = observer.started_event(**sample_run) _id2 = observer.started_event(**sample_run) @@ -119,7 +119,7 @@ def test_s3_observer_equality(): assert obs_one != different_bucket -@moto.mock_s3 +@moto.mock_aws def test_raises_error_on_duplicate_id_directory(observer, sample_run): observer.started_event(**sample_run) sample_run["_id"] = 1 @@ -127,7 +127,7 @@ def test_raises_error_on_duplicate_id_directory(observer, sample_run): observer.started_event(**sample_run) -@moto.mock_s3 +@moto.mock_aws def test_completed_event_updates_run_json(observer, sample_run): observer.started_event(**sample_run) run = json.loads( @@ -145,7 +145,7 @@ def test_completed_event_updates_run_json(observer, sample_run): assert run["status"] == "COMPLETED" -@moto.mock_s3 +@moto.mock_aws def test_interrupted_event_updates_run_json(observer, sample_run): observer.started_event(**sample_run) run = json.loads( @@ -163,7 +163,7 @@ def test_interrupted_event_updates_run_json(observer, sample_run): assert run["status"] == "SERVER_EXPLODED" -@moto.mock_s3 +@moto.mock_aws def test_failed_event_updates_run_json(observer, sample_run): observer.started_event(**sample_run) run = json.loads( @@ -181,7 +181,7 @@ def test_failed_event_updates_run_json(observer, sample_run): assert run["status"] == "FAILED" -@moto.mock_s3 +@moto.mock_aws def test_queued_event_updates_run_json(observer, sample_run): del sample_run["start_time"] sample_run["queue_time"] = T2 @@ -194,7 +194,7 @@ def test_queued_event_updates_run_json(observer, sample_run): assert run["status"] == "QUEUED" -@moto.mock_s3 +@moto.mock_aws def test_artifact_event_works(observer, sample_run, tmpfile): observer.started_event(**sample_run) observer.artifact_event("test_artifact.py", tmpfile.name) From cba5224fad573b1594ea7926958aa230e6d2190b Mon Sep 17 00:00:00 2001 From: Nicholas Gao Date: Fri, 23 Aug 2024 09:24:45 +0200 Subject: [PATCH 2/2] remove unnecessary parentheses (trigger CI) --- sacred/config/custom_containers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sacred/config/custom_containers.py b/sacred/config/custom_containers.py index 746e8e96..b02bf0c6 100644 --- a/sacred/config/custom_containers.py +++ b/sacred/config/custom_containers.py @@ -100,7 +100,7 @@ def update(self, iterable=None, **kwargs): for key in iterable: self[key] = iterable[key] else: - for (key, value) in iterable: + for key, value in iterable: self[key] = value for key in kwargs: self[key] = kwargs[key]