diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 541123f..9b89dc2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -44,6 +44,8 @@ repos:
       - arrow>=1,<2
       - backports-zoneinfo; python_version<'3.9'
       - boto3-stubs[boto3,s3]>=1,<2
+      - cfgv>=3,<4
+      - types-pyyaml>=6,<7
       - rich>=13,<14
       - typing-extensions>=4.4,<5
       # test dependencies (corresponds to test-requirements.txt)
diff --git a/example-config.yaml b/example-config.yaml
new file mode 100644
index 0000000..41cf96c
--- /dev/null
+++ b/example-config.yaml
@@ -0,0 +1,16 @@
+timezone: America/Los_Angeles
+sources:
+- path: /foo
+  snapshots: /snapshots
+  upload_to_remotes:
+  - id: aws
+    preserve: 1y 3m 30d 24h 60M 60s
+    pipe_through:
+    - ["zstd", "-T0", "-12"]
+    - ["gpg", "--encrypt", "-r", "0xFE592029B2CB9D04", "--trust-model", "always"]
+remotes:
+- id: aws
+  s3:
+    bucket: btrfs2s3-test
+    endpoint:
+      endpoint_url: https://foo
diff --git a/pyproject.toml b/pyproject.toml
index 65055a8..16b652c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,8 @@ dependencies = [
     "arrow>=1,<2",
     "backports-zoneinfo; python_version<'3.9'",
     "boto3[crt]>=1,<2",
+    "cfgv>=3,<4",
+    "pyyaml>=6,<7",
     "rich>=13,<14",
     "typing-extensions>=4.4,<5",
     "tzdata",
@@ -103,3 +105,7 @@ legacy_tox_ini = """
 [tool.mypy]
 mypy_path = "typeshed"
 strict = true
+
+[[tool.mypy.overrides]]
+module = "cfgv"
+ignore_missing_imports = true
diff --git a/src/btrfs2s3/commands/run.py b/src/btrfs2s3/commands/run.py
index 7b5f942..f4c4022 100644
--- a/src/btrfs2s3/commands/run.py
+++ b/src/btrfs2s3/commands/run.py
@@ -4,7 +4,7 @@
 
 from collections import defaultdict
 from pathlib import Path
-import shlex
+from typing import cast
 from typing import TYPE_CHECKING
 
 import arrow
@@ -26,6 +26,8 @@
 from btrfs2s3.assessor import assessment_to_actions
 from btrfs2s3.assessor import BackupAssessment
 from btrfs2s3.assessor import SourceAssessment
+from btrfs2s3.config import Config
+from btrfs2s3.config import load_from_path
 from btrfs2s3.preservation import Params
 from btrfs2s3.preservation import Policy
 from btrfs2s3.preservation import TS
@@ -335,18 +337,9 @@ def print_actions(*, console: Console, actions: Actions) -> None:
 
 def add_args(parser: argparse.ArgumentParser) -> None:
     """Add args for "btrfs2s3 run" to an ArgumentParser."""
+    parser.add_argument("config_file", type=load_from_path)
     parser.add_argument("--force", action="store_true")
     parser.add_argument("--pretend", action="store_true")
-    parser.add_argument("--region")
-    parser.add_argument("--profile")
-    parser.add_argument("--endpoint-url")
-    parser.add_argument("--no-verify", action="store_false", dest="verify")
-    parser.add_argument("--source", action="append", type=Path, required=True)
-    parser.add_argument("--snapshot-dir", type=Path, required=True)
-    parser.add_argument("--bucket", required=True)
-    parser.add_argument("--timezone", type=get_zoneinfo, required=True)
-    parser.add_argument("--preserve", type=Params.parse, required=True)
-    parser.add_argument("--pipe-through", action="append", type=shlex.split, default=[])
 
 
 def command(*, console: Console, args: argparse.Namespace) -> int:
@@ -355,14 +348,63 @@ def command(*, console: Console, args: argparse.Namespace) -> int:
         console.print("to run in unattended mode, use --force")
         return 1
 
-    session = Session(region_name=args.region, profile_name=args.profile)
-    s3 = session.client("s3", verify=args.verify, endpoint_url=args.endpoint_url)
-    policy = Policy(tzinfo=args.timezone, params=args.preserve)
+    config = cast(Config, args.config_file)
+    tzinfo = get_zoneinfo(config["timezone"])
+    assert len(config["remotes"]) == 1  # noqa: S101
+    s3_remote = config["remotes"][0]["s3"]
+    s3_endpoint = s3_remote.get("endpoint", {})
+
+    sources = config["sources"]
+    assert len({source["snapshots"] for source in sources}) == 1  # noqa: S101
+    assert (  # noqa: S101
+        len(
+            {
+                upload["id"]
+                for source in sources
+                for upload in source["upload_to_remotes"]
+            }
+        )
+        == 1
+    )
+    assert (  # noqa: S101
+        len(
+            {
+                upload["preserve"]
+                for source in sources
+                for upload in source["upload_to_remotes"]
+            }
+        )
+        == 1
+    )
+    assert (  # noqa: S101
+        len(
+            {
+                tuple(tuple(cmd) for cmd in upload.get("pipe_through", []))
+                for source in sources
+                for upload in source["upload_to_remotes"]
+            }
+        )
+        == 1
+    )
+
+    session = Session(
+        region_name=s3_endpoint.get("region_name"),
+        profile_name=s3_endpoint.get("profile_name"),
+    )
+    s3 = session.client(
+        "s3",
+        verify=s3_endpoint.get("verify"),
+        endpoint_url=s3_endpoint.get("endpoint_url"),
+    )
+    policy = Policy(
+        tzinfo=tzinfo,
+        params=Params.parse(sources[0]["upload_to_remotes"][0]["preserve"]),
+    )
     asmt = assess(
-        snapshot_dir=args.snapshot_dir,
-        sources=args.source,
+        snapshot_dir=Path(sources[0]["snapshots"]),
+        sources=[Path(source["path"]) for source in sources],
         s3=s3,
-        bucket=args.bucket,
+        bucket=s3_remote["bucket"],
         policy=policy,
     )
     actions = Actions()
@@ -372,9 +414,9 @@ def command(*, console: Console, args: argparse.Namespace) -> int:
 
     print_assessment(
         console=console,
         asmt=asmt,
-        tzinfo=args.timezone,
-        snapshot_dir=args.snapshot_dir,
-        bucket=args.bucket,
+        tzinfo=tzinfo,
+        snapshot_dir=Path(sources[0]["snapshots"]),
+        bucket=s3_remote["bucket"],
     )
     print_actions(console=console, actions=actions)
@@ -386,6 +428,10 @@ def command(*, console: Console, args: argparse.Namespace) -> int:
         return 0
 
     if args.force or Confirm(console=console).ask("continue?"):
-        actions.execute(s3, args.bucket, pipe_through=args.pipe_through)
+        actions.execute(
+            s3,
+            s3_remote["bucket"],
+            pipe_through=sources[0]["upload_to_remotes"][0].get("pipe_through", []),
+        )
 
     return 0
diff --git a/src/btrfs2s3/config.py b/src/btrfs2s3/config.py
new file mode 100644
index 0000000..f4bd79f
--- /dev/null
+++ b/src/btrfs2s3/config.py
@@ -0,0 +1,198 @@
+"""Code for manipulating config for btrfs2s3."""
+
+from __future__ import annotations
+
+from collections import namedtuple
+from typing import Any
+from typing import cast
+from typing import TYPE_CHECKING
+from typing import TypedDict
+
+from cfgv import Array
+from cfgv import check_array
+from cfgv import check_string
+from cfgv import check_type
+from cfgv import load_from_filename
+from cfgv import Map
+from cfgv import OptionalNoDefault
+from cfgv import OptionalRecurse
+from cfgv import Required
+from cfgv import RequiredRecurse
+from typing_extensions import NotRequired
+from yaml import safe_load
+
+from btrfs2s3.preservation import Params
+
+if TYPE_CHECKING:
+    from os import PathLike
+
+
+class Error(Exception):
+    """The top-level class for exceptions generated by this module."""
+
+
+class InvalidConfigError(Error):
+    """An error for invalid config."""
+
+
+def _check_preserve(v: Any) -> None:  # noqa: ANN401
+    check_string(v)
+    try:
+        Params.parse(v)
+    except ValueError as ex:
+        msg = "Expected a valid schedule"
+        raise InvalidConfigError(msg) from ex
+
+
+# this is the same style used in cfgv
+_OptionalRecurseNoDefault = namedtuple(  # noqa: PYI024
+    "_OptionalRecurseNoDefault", ("key", "schema")
+)
+_OptionalRecurseNoDefault.check = OptionalRecurse.check  # type: ignore[attr-defined]
+_OptionalRecurseNoDefault.check_fn = OptionalRecurse.check_fn  # type: ignore[attr-defined]
+_OptionalRecurseNoDefault.apply_default = OptionalNoDefault.apply_default  # type: ignore[attr-defined]
+_OptionalRecurseNoDefault.remove_default = OptionalNoDefault.remove_default  # type: ignore[attr-defined]
+
+
+class S3EndpointConfig(TypedDict):
+    """A config dict for how to talk to an S3 endpoint."""
+
+    aws_access_key_id: NotRequired[str]
+    aws_secret_access_key: NotRequired[str]
+    region_name: NotRequired[str]
+    profile_name: NotRequired[str]
+    api_version: NotRequired[str]
+    verify: NotRequired[bool | str]
+    endpoint_url: NotRequired[str]
+
+
+_S3_ENDPOINT_SCHEMA = Map(
+    "S3EndpointConfig",
+    None,
+    OptionalNoDefault("aws_access_key_id", check_string),
+    OptionalNoDefault("aws_secret_access_key", check_string),
+    OptionalNoDefault("region_name", check_string),
+    OptionalNoDefault("profile_name", check_string),
+    OptionalNoDefault("api_version", check_string),
+    OptionalNoDefault("verify", check_type((bool, str), typename="bool or path")),
+    OptionalNoDefault("endpoint_url", check_string),
+)
+
+
+class S3RemoteConfig(TypedDict):
+    """A config dict for how to access an S3 remote."""
+
+    bucket: str
+    endpoint: NotRequired[S3EndpointConfig]
+
+
+_S3_SCHEMA = Map(
+    "S3RemoteConfig",
+    None,
+    Required("bucket", check_string),
+    _OptionalRecurseNoDefault("endpoint", _S3_ENDPOINT_SCHEMA),
+)
+
+
+class RemoteConfig(TypedDict):
+    """A config dict for how to access a remote."""
+
+    id: str
+    s3: NotRequired[S3RemoteConfig]
+
+
+_REMOTE_SCHEMA = Map(
+    "RemoteConfig",
+    None,
+    Required("id", check_string),
+    RequiredRecurse("s3", _S3_SCHEMA),
+)
+
+
+class UploadToRemoteConfig(TypedDict):
+    """A config dict for uploading a source to a remote."""
+
+    id: str
+    preserve: str
+    pipe_through: NotRequired[list[list[str]]]
+
+
+_UPLOAD_TO_REMOTE_SCHEMA = Map(
+    "UploadToRemoteConfig",
+    "id",
+    Required("preserve", _check_preserve),
+    OptionalNoDefault("pipe_through", check_array(check_array(check_string))),
+)
+
+
+class SourceConfig(TypedDict):
+    """A config dict for a source."""
+
+    path: str
+    snapshots: str
+    upload_to_remotes: list[UploadToRemoteConfig]
+
+
+_SOURCE_SCHEMA = Map(
+    "SourceConfig",
+    "path",
+    Required("path", check_string),
+    Required("snapshots", check_string),
+    RequiredRecurse(
+        "upload_to_remotes", Array(_UPLOAD_TO_REMOTE_SCHEMA, allow_empty=False)
+    ),
+)
+
+
+class Config(TypedDict):
+    """The top-level config dict.
+
+    This just matches the data as it's stored in config.yaml. We don't do any
+    transformation up front (for example "preserve" values are just their
+    strings, not Policy objects).
+    """
+
+    timezone: str
+    sources: list[SourceConfig]
+    remotes: list[RemoteConfig]
+
+
+_SCHEMA = Map(
+    "Config",
+    None,
+    Required("timezone", check_string),
+    RequiredRecurse("sources", Array(_SOURCE_SCHEMA, allow_empty=False)),
+    RequiredRecurse("remotes", Array(_REMOTE_SCHEMA, allow_empty=False)),
+)
+
+
+def load_from_path(path: str | PathLike[str]) -> Config:
+    """Load config from a file path.
+
+    This performs some basic syntactic validation on the config, to ensure it
+    really conforms to the return type.
+
+    Args:
+        path: The path to the config.
+
+    Returns:
+        A Config instance.
+
+    Raises:
+        InvalidConfigError: If the config does not pass validation.
+ """ + config = cast( + Config, load_from_filename(path, _SCHEMA, safe_load, exc_tp=InvalidConfigError) + ) + + remote_ids = {remote["id"] for remote in config["remotes"]} + for source in config["sources"]: + for upload_to_remote in source["upload_to_remotes"]: + if upload_to_remote["id"] not in remote_ids: + msg = ( + f'remote id {upload_to_remote["id"]!r} for source ' + f'{source["path"]!r} is not defined in the list of remotes' + ) + raise InvalidConfigError(msg) + + return config diff --git a/tests/commands/run/invocation_test.py b/tests/commands/run/invocation_test.py index 69043ac..88279e8 100644 --- a/tests/commands/run/invocation_test.py +++ b/tests/commands/run/invocation_test.py @@ -19,7 +19,10 @@ def test_pretend( - btrfs_mountpoint: Path, bucket: str, capsys: pytest.CaptureFixture[str] + tmp_path: Path, + btrfs_mountpoint: Path, + bucket: str, + capsys: pytest.CaptureFixture[str], ) -> None: # Create a subvolume source = btrfs_mountpoint / "source" @@ -32,21 +35,21 @@ def test_pretend( btrfsutil.sync(source) console = Console(force_terminal=True, theme=THEME, width=88, height=30) - argv = [ - "run", - "--pretend", - "--source", - str(source), - "--snapshot-dir", - str(snapshot_dir), - "--bucket", - bucket, - "--timezone", - "UTC", - "--preserve", - "1y", - ] - assert main(console=console, argv=argv) == 0 + config_path = tmp_path / "config.yaml" + config_path.write_text(f""" + timezone: UTC + sources: + - path: {source} + snapshots: {snapshot_dir} + upload_to_remotes: + - id: aws + preserve: 1y + remotes: + - id: aws + s3: + bucket: {bucket} + """) + assert main(console=console, argv=["run", "--pretend", str(config_path)]) == 0 (out, err) = capsys.readouterr() # No idea how to stabilize this for golden testing @@ -55,6 +58,7 @@ def test_pretend( def test_force( + tmp_path: Path, btrfs_mountpoint: Path, s3: S3Client, bucket: str, @@ -72,20 +76,21 @@ def test_force( btrfsutil.sync(source) console = Console(force_terminal=True, theme=THEME, width=88, height=30) - argv = [ - "run", - "--force", - "--source", - str(source), - "--snapshot-dir", - str(snapshot_dir), - "--bucket", - bucket, - "--timezone", - "UTC", - "--preserve", - "1y", - ] + config_path = tmp_path / "config.yaml" + config_path.write_text(f""" + timezone: UTC + sources: + - path: {source} + snapshots: {snapshot_dir} + upload_to_remotes: + - id: aws + preserve: 1y + remotes: + - id: aws + s3: + bucket: {bucket} + """) + argv = ["run", "--force", str(config_path)] assert main(console=console, argv=argv) == 0 (out, err) = capsys.readouterr() @@ -109,26 +114,28 @@ def test_force( def test_refuse_to_run_unattended_without_pretend_or_force( - goldifyconsole: Console, + tmp_path: Path, goldifyconsole: Console ) -> None: # This shouldn't get to the point of verifying arguments - argv = [ - "run", - "--source", - "dummy_source", - "--snapshot-dir", - "dummy_snapshot_dir", - "--bucket", - "dummy_bucket", - "--timezone", - "UTC", - "--preserve", - "1y", - ] - assert main(argv=argv, console=goldifyconsole) == 1 + config_path = tmp_path / "config.yaml" + config_path.write_text(""" + timezone: UTC + sources: + - path: dummy_source + snapshots: dummy_snapshot_dir + upload_to_remotes: + - id: aws + preserve: 1y + remotes: + - id: aws + s3: + bucket: dummy_bucket + """) + assert main(argv=["run", str(config_path)], console=goldifyconsole) == 1 def test_reject_continue_prompt( + tmp_path: Path, btrfs_mountpoint: Path, bucket: str, capsys: pytest.CaptureFixture[str], @@ -145,21 +152,22 @@ def test_reject_continue_prompt( 
     btrfsutil.sync(source)
 
     console = Console(force_terminal=True, theme=THEME, width=88, height=30)
-    argv = [
-        "run",
-        "--source",
-        str(source),
-        "--snapshot-dir",
-        str(snapshot_dir),
-        "--bucket",
-        bucket,
-        "--timezone",
-        "UTC",
-        "--preserve",
-        "1y",
-    ]
+    config_path = tmp_path / "config.yaml"
+    config_path.write_text(f"""
+        timezone: UTC
+        sources:
+        - path: {source}
+          snapshots: {snapshot_dir}
+          upload_to_remotes:
+          - id: aws
+            preserve: 1y
+        remotes:
+        - id: aws
+          s3:
+            bucket: {bucket}
+        """)
     with patch("rich.console.input", return_value="n"):
-        assert main(console=console, argv=argv) == 0
+        assert main(console=console, argv=["run", str(config_path)]) == 0
 
     (out, err) = capsys.readouterr()
     # No idea how to stabilize this for golden testing
diff --git a/tests/config_test.py b/tests/config_test.py
new file mode 100644
index 0000000..68901cc
--- /dev/null
+++ b/tests/config_test.py
@@ -0,0 +1,227 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from btrfs2s3.config import Config
+from btrfs2s3.config import InvalidConfigError
+from btrfs2s3.config import load_from_path
+from btrfs2s3.config import RemoteConfig
+from btrfs2s3.config import S3EndpointConfig
+from btrfs2s3.config import S3RemoteConfig
+from btrfs2s3.config import SourceConfig
+from btrfs2s3.config import UploadToRemoteConfig
+import pytest
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+
+@pytest.fixture()
+def path(tmp_path: Path) -> Path:
+    return tmp_path / "config.yaml"
+
+
+def test_malformed(path: Path) -> None:
+    path.write_text("malformed, bad text")
+    with pytest.raises(InvalidConfigError):
+        load_from_path(path)
+
+
+def test_basic(path: Path) -> None:
+    path.write_text("""
+        timezone: a
+        sources:
+        - path: b
+          snapshots: c
+          upload_to_remotes:
+          - id: aws
+            preserve: 1y 1m
+        remotes:
+        - id: aws
+          s3:
+            bucket: d
+        """)
+    config = load_from_path(path)
+    assert config == Config(
+        {
+            "timezone": "a",
+            "sources": [
+                SourceConfig(
+                    {
+                        "path": "b",
+                        "snapshots": "c",
+                        "upload_to_remotes": [
+                            UploadToRemoteConfig({"id": "aws", "preserve": "1y 1m"})
+                        ],
+                    }
+                )
+            ],
+            "remotes": [
+                RemoteConfig({"id": "aws", "s3": S3RemoteConfig({"bucket": "d"})})
+            ],
+        }
+    )
+
+
+def test_multiple_sources_with_anchors_and_refs(path: Path) -> None:
+    path.write_text("""
+        timezone: a
+        sources:
+        - path: b
+          snapshots: &snapshots c
+          upload_to_remotes: &my_remotes
+          - id: aws
+            preserve: 1y 1m
+        - path: otherpath
+          snapshots: *snapshots
+          upload_to_remotes: *my_remotes
+        remotes:
+        - id: aws
+          s3:
+            bucket: d
+        """)
+    config = load_from_path(path)
+    source0 = config["sources"][0]
+    source1 = config["sources"][1]
+    assert source0["snapshots"] == source1["snapshots"]
+    assert source0["upload_to_remotes"] == source1["upload_to_remotes"]
+
+
+def test_s3_endpoint_config(path: Path) -> None:
+    path.write_text("""
+        timezone: a
+        sources:
+        - path: b
+          snapshots: c
+          upload_to_remotes:
+          - id: aws
+            preserve: 1y 1m
+        remotes:
+        - id: aws
+          s3:
+            bucket: d
+            endpoint:
+              aws_access_key_id: key
+              aws_secret_access_key: secret
+              region_name: region
+              profile_name: profile
+              api_version: version
+              verify: true
+              endpoint_url: https://example.com
+        """)
+    endpoint = load_from_path(path)["remotes"][0]["s3"]["endpoint"]
+    assert endpoint == S3EndpointConfig(
+        {
+            "aws_access_key_id": "key",
+            "aws_secret_access_key": "secret",
+            "region_name": "region",
+            "profile_name": "profile",
+            "api_version": "version",
+            "verify": True,
+            "endpoint_url": "https://example.com",
+        }
+    )
+
+
+def test_pipe_through(path: Path) -> None:
+    path.write_text("""
+        timezone: a
+        sources:
+        - path: b
+          snapshots: c
+          upload_to_remotes:
+          - id: aws
+            preserve: 1y 1m
+            pipe_through:
+            - [gzip]
+            - [gpg, encrypt, -r, me@example.com]
+        remotes:
+        - id: aws
+          s3:
+            bucket: d
+        """)
+    config = load_from_path(path)
+    assert config["sources"][0]["upload_to_remotes"][0]["pipe_through"] == [
+        ["gzip"],
+        ["gpg", "encrypt", "-r", "me@example.com"],
+    ]
+
+
+def test_no_sources(path: Path) -> None:
+    path.write_text("""
+        timezone: a
+        sources: []
+        remotes:
+        - id: aws
+          s3:
+            bucket: d
+        """)
+    with pytest.raises(InvalidConfigError):
+        load_from_path(path)
+
+
+def test_no_remotes(path: Path) -> None:
+    path.write_text("""
+        timezone: a
+        sources:
+        - path: b
+          snapshots: c
+          upload_to_remotes:
+          - id: aws
+            preserve: 1y 1m
+        remotes: []
+        """)
+    with pytest.raises(InvalidConfigError):
+        load_from_path(path)
+
+
+def test_source_with_no_upload_to_remotes(path: Path) -> None:
+    path.write_text("""
+        timezone: a
+        sources:
+        - path: b
+          snapshots: c
+          upload_to_remotes: []
+        remotes:
+        - id: aws
+          s3:
+            bucket: d
+        """)
+    with pytest.raises(InvalidConfigError):
+        load_from_path(path)
+
+
+def test_invalid_preserve(path: Path) -> None:
+    path.write_text("""
+        timezone: a
+        sources:
+        - path: b
+          snapshots: c
+          upload_to_remotes:
+          - id: aws
+            preserve: invalid
+        remotes:
+        - id: aws
+          s3:
+            bucket: d
+        """)
+    with pytest.raises(InvalidConfigError):
+        load_from_path(path)
+
+
+def test_invalid_upload_to_remote_id(path: Path) -> None:
+    path.write_text("""
+        timezone: a
+        sources:
+        - path: b
+          snapshots: c
+          upload_to_remotes:
+          - id: does_not_exist
+            preserve: 1y
+        remotes:
+        - id: aws
+          s3:
+            bucket: d
+        """)
+    with pytest.raises(InvalidConfigError):
+        load_from_path(path)
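
Usage sketch (illustrative only, not part of the patch): with this change,
"btrfs2s3 run" takes a single positional config file path instead of per-option
flags. The snippet below shows the new load_from_path helper from
src/btrfs2s3/config.py being called directly; the "config.yaml" path is an
assumption for the example.

    from btrfs2s3.config import InvalidConfigError, load_from_path

    try:
        config = load_from_path("config.yaml")
    except InvalidConfigError as ex:
        raise SystemExit(f"invalid config: {ex}")

    # The validated Config dict mirrors the YAML one-to-one; "preserve" stays
    # a plain string until the run command parses it with Params.parse.
    print(config["timezone"])
    for source in config["sources"]:
        print(source["path"], source["upload_to_remotes"][0]["preserve"])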