Skip to content

Commit

Permalink
use a config.yaml instead of cli args
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrudenell committed Jul 8, 2024
1 parent 60da68d commit e75f2a1
Show file tree
Hide file tree
Showing 7 changed files with 582 additions and 80 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ repos:
- arrow>=1,<2
- backports-zoneinfo; python_version<'3.9'
- boto3-stubs[boto3,s3]>=1,<2
- cfgv>=3,<4
- types-pyyaml>=6,<7
- rich>=13,<14
- typing-extensions>=4.4,<5
# test dependencies (corresponds to test-requirements.txt)
Expand Down
15 changes: 15 additions & 0 deletions example-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Example config for btrfs2s3. Structure must match the schema in
# src/btrfs2s3/config.py: top-level "timezone", "sources" and "remotes",
# with each source naming the remote(s) it uploads to by id.
timezone: America/Los_Angeles
sources:
  - path: /foo
    snapshots: /snapshots
    upload_to_remotes:
      - id: aws
        preserve: 1y 3m 30d 24h 60M 60s
        # Each pipe_through entry is a pre-split argv list (not a shell
        # string); numbers like -12 must be quoted to stay strings.
        pipe_through:
          - ["zstd", "-T0", "-12"]
          - ["gpg", "--encrypt", "-r", "0xFE592029B2CB9D04", "--trust-model", "always"]
remotes:
  - id: aws
    s3:
      bucket: btrfs2s3-test
      endpoint:
        endpoint_url: https://foo
        # Optional endpoint settings; uncomment to override boto3
        # defaults (empty values are not allowed by the schema):
        # region_name: us-west-2
        # profile_name: default
        # verify: false
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ dependencies = [
"arrow>=1,<2",
"backports-zoneinfo; python_version<'3.9'",
"boto3[crt]>=1,<2",
"cfgv>=3,<4",
"pyyaml>=6,<7",
"rich>=13,<14",
"typing-extensions>=4.4,<5",
"tzdata",
Expand Down Expand Up @@ -103,3 +105,7 @@ legacy_tox_ini = """
[tool.mypy]
mypy_path = "typeshed"
strict = true

[[tool.mypy.overrides]]
module = "cfgv"
ignore_missing_imports = true
88 changes: 67 additions & 21 deletions src/btrfs2s3/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections import defaultdict
from pathlib import Path
import shlex
from typing import cast
from typing import TYPE_CHECKING

import arrow
Expand All @@ -26,6 +26,8 @@
from btrfs2s3.assessor import assessment_to_actions
from btrfs2s3.assessor import BackupAssessment
from btrfs2s3.assessor import SourceAssessment
from btrfs2s3.config import Config
from btrfs2s3.config import load_from_path
from btrfs2s3.preservation import Params
from btrfs2s3.preservation import Policy
from btrfs2s3.preservation import TS
Expand Down Expand Up @@ -335,18 +337,9 @@ def print_actions(*, console: Console, actions: Actions) -> None:

def add_args(parser: argparse.ArgumentParser) -> None:
"""Add args for "btrfs2s3 run" to an ArgumentParser."""
parser.add_argument("config_file", type=load_from_path)
parser.add_argument("--force", action="store_true")
parser.add_argument("--pretend", action="store_true")
parser.add_argument("--region")
parser.add_argument("--profile")
parser.add_argument("--endpoint-url")
parser.add_argument("--no-verify", action="store_false", dest="verify")
parser.add_argument("--source", action="append", type=Path, required=True)
parser.add_argument("--snapshot-dir", type=Path, required=True)
parser.add_argument("--bucket", required=True)
parser.add_argument("--timezone", type=get_zoneinfo, required=True)
parser.add_argument("--preserve", type=Params.parse, required=True)
parser.add_argument("--pipe-through", action="append", type=shlex.split, default=[])


def command(*, console: Console, args: argparse.Namespace) -> int:
Expand All @@ -355,14 +348,63 @@ def command(*, console: Console, args: argparse.Namespace) -> int:
console.print("to run in unattended mode, use --force")
return 1

session = Session(region_name=args.region, profile_name=args.profile)
s3 = session.client("s3", verify=args.verify, endpoint_url=args.endpoint_url)
policy = Policy(tzinfo=args.timezone, params=args.preserve)
config = cast(Config, args.config_file)
tzinfo = get_zoneinfo(config["timezone"])
assert len(config["remotes"]) == 1 # noqa: S101
s3_remote = config["remotes"][0]["s3"]
s3_endpoint = s3_remote.get("endpoint", {})

sources = config["sources"]
assert len({source["snapshots"] for source in sources}) == 1 # noqa: S101
assert ( # noqa: S101
len(
{
upload["id"]
for source in sources
for upload in source["upload_to_remotes"]
}
)
== 1
)
assert ( # noqa: S101
len(
{
upload["preserve"]
for source in sources
for upload in source["upload_to_remotes"]
}
)
== 1
)
assert ( # noqa: S101
len(
{
tuple(tuple(cmd) for cmd in upload.get("pipe_through", []))
for source in sources
for upload in source["upload_to_remotes"]
}
)
== 1
)

session = Session(
region_name=s3_endpoint.get("region_name"),
profile_name=s3_endpoint.get("profile_name"),
)
s3 = session.client(
"s3",
verify=s3_endpoint.get("verify"),
endpoint_url=s3_endpoint.get("endpoint_url"),
)
policy = Policy(
tzinfo=tzinfo,
params=Params.parse(sources[0]["upload_to_remotes"][0]["preserve"]),
)
asmt = assess(
snapshot_dir=args.snapshot_dir,
sources=args.source,
snapshot_dir=Path(sources[0]["snapshots"]),
sources=[Path(source["path"]) for source in sources],
s3=s3,
bucket=args.bucket,
bucket=s3_remote["bucket"],
policy=policy,
)
actions = Actions()
Expand All @@ -372,9 +414,9 @@ def command(*, console: Console, args: argparse.Namespace) -> int:
print_assessment(
console=console,
asmt=asmt,
tzinfo=args.timezone,
snapshot_dir=args.snapshot_dir,
bucket=args.bucket,
tzinfo=tzinfo,
snapshot_dir=Path(sources[0]["snapshots"]),
bucket=s3_remote["bucket"],
)
print_actions(console=console, actions=actions)

Expand All @@ -386,6 +428,10 @@ def command(*, console: Console, args: argparse.Namespace) -> int:
return 0

if args.force or Confirm(console=console).ask("continue?"):
actions.execute(s3, args.bucket, pipe_through=args.pipe_through)
actions.execute(
s3,
s3_remote["bucket"],
pipe_through=sources[0]["upload_to_remotes"][0].get("pipe_through", []),
)

return 0
198 changes: 198 additions & 0 deletions src/btrfs2s3/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""Code for manipulating config for btrfs2s3."""

from __future__ import annotations

from collections import namedtuple
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import TypedDict

from cfgv import Array
from cfgv import check_array
from cfgv import check_string
from cfgv import check_type
from cfgv import load_from_filename
from cfgv import Map
from cfgv import OptionalNoDefault
from cfgv import OptionalRecurse
from cfgv import Required
from cfgv import RequiredRecurse
from typing_extensions import NotRequired
from yaml import safe_load

from btrfs2s3.preservation import Params

if TYPE_CHECKING:
from os import PathLike


class Error(Exception):
    """Base class for all exceptions raised by this module."""


class InvalidConfigError(Error):
    """Raised when a config file fails validation."""


def _check_preserve(v: Any) -> None:  # noqa: ANN401
    """Validate that ``v`` is a string parseable as a preservation schedule."""
    check_string(v)
    try:
        Params.parse(v)
    except ValueError as parse_error:
        msg = "Expected a valid schedule"
        raise InvalidConfigError(msg) from parse_error


# cfgv has no built-in schema item that is both optional-without-default
# and recursively validated, so assemble one: a namedtuple with the same
# (key, schema) shape as cfgv's own items, borrowing its methods from
# the two existing item types.
# this is the same style used in cfgv
_OptionalRecurseNoDefault = namedtuple(  # noqa: PYI024
    "_OptionalRecurseNoDefault", ("key", "schema")
)
# Validate the nested value recursively, like OptionalRecurse...
_OptionalRecurseNoDefault.check = OptionalRecurse.check  # type: ignore[attr-defined]
_OptionalRecurseNoDefault.check_fn = OptionalRecurse.check_fn  # type: ignore[attr-defined]
# ...but handle defaults like OptionalNoDefault (presumably: never write
# a default value back into the parsed config — confirm against cfgv).
_OptionalRecurseNoDefault.apply_default = OptionalNoDefault.apply_default  # type: ignore[attr-defined]
_OptionalRecurseNoDefault.remove_default = OptionalNoDefault.remove_default  # type: ignore[attr-defined]


class S3EndpointConfig(TypedDict):
    """A config dict for how to talk to an S3 endpoint.

    The keys appear to mirror boto3 arguments: the run command forwards
    region_name/profile_name to Session() and verify/endpoint_url to
    Session.client("s3"). All keys are optional, so boto3's own defaults
    apply when omitted.
    """

    aws_access_key_id: NotRequired[str]
    aws_secret_access_key: NotRequired[str]
    region_name: NotRequired[str]
    profile_name: NotRequired[str]
    api_version: NotRequired[str]
    # either a bool (enable/disable TLS verification) or a path to a CA
    # bundle, matching boto3's "verify" argument
    verify: NotRequired[bool | str]
    endpoint_url: NotRequired[str]


# cfgv schema matching S3EndpointConfig: every key optional, no defaults
# written back into the parsed config.
_S3_ENDPOINT_SCHEMA = Map(
    "S3EndpointConfig",
    None,  # no id key to name instances in validation error messages
    OptionalNoDefault("aws_access_key_id", check_string),
    OptionalNoDefault("aws_secret_access_key", check_string),
    OptionalNoDefault("region_name", check_string),
    OptionalNoDefault("profile_name", check_string),
    OptionalNoDefault("api_version", check_string),
    OptionalNoDefault("verify", check_type((bool, str), typename="bool or path")),
    OptionalNoDefault("endpoint_url", check_string),
)


class S3RemoteConfig(TypedDict):
    """A config dict for how to access an S3 remote."""

    # name of the S3 bucket backups are uploaded to
    bucket: str
    # connection/credential settings; boto3 defaults apply when omitted
    endpoint: NotRequired[S3EndpointConfig]


# cfgv schema for S3RemoteConfig: "bucket" is mandatory; "endpoint" is
# optional, validated recursively, and gets no default applied.
_S3_SCHEMA = Map(
    "S3RemoteConfig",
    None,
    Required("bucket", check_string),
    _OptionalRecurseNoDefault("endpoint", _S3_ENDPOINT_SCHEMA),
)


class RemoteConfig(TypedDict):
    """A config dict for how to access a remote."""

    # unique identifier, referenced by each source's upload_to_remotes
    # entries and cross-checked in load_from_path
    id: str
    # NOTE(review): typed NotRequired (presumably so other remote kinds
    # can be added later), yet the runtime schema below makes "s3"
    # required — confirm which is intended
    s3: NotRequired[S3RemoteConfig]


# cfgv schema for RemoteConfig; "s3" is required at runtime and
# validated recursively against _S3_SCHEMA.
_REMOTE_SCHEMA = Map(
    "RemoteConfig",
    None,
    Required("id", check_string),
    RequiredRecurse("s3", _S3_SCHEMA),
)


class UploadToRemoteConfig(TypedDict):
    """A config dict for uploading a source to a remote."""

    # must match the "id" of an entry in the top-level remotes list
    # (cross-checked in load_from_path)
    id: str
    # preservation schedule string; must be parseable by Params.parse
    preserve: str
    # optional commands (each a pre-split argv list) to pipe backup data
    # through before upload, e.g. compression or encryption
    pipe_through: NotRequired[list[list[str]]]


# cfgv schema for UploadToRemoteConfig; "id" doubles as the identifying
# key in validation error messages.
_UPLOAD_TO_REMOTE_SCHEMA = Map(
    "UploadToRemoteConfig",
    "id",
    # cfgv uses the Map id key only for error messages and does not
    # implicitly require it, so enforce its presence explicitly (as
    # _REMOTE_SCHEMA and _SOURCE_SCHEMA do for their key fields);
    # otherwise a missing "id" would surface later as a KeyError in
    # load_from_path instead of an InvalidConfigError.
    Required("id", check_string),
    Required("preserve", _check_preserve),
    OptionalNoDefault("pipe_through", check_array(check_array(check_string))),
)


class SourceConfig(TypedDict):
    """A config dict for a source."""

    # path of the source (a btrfs subvolume, per project context) to
    # back up
    path: str
    # directory in which snapshots of this source are kept
    snapshots: str
    # one or more remotes to upload this source's backups to
    upload_to_remotes: list[UploadToRemoteConfig]


# cfgv schema for SourceConfig; "path" identifies the source in
# validation error messages, and upload_to_remotes must be non-empty.
_SOURCE_SCHEMA = Map(
    "SourceConfig",
    "path",
    Required("path", check_string),
    Required("snapshots", check_string),
    RequiredRecurse(
        "upload_to_remotes", Array(_UPLOAD_TO_REMOTE_SCHEMA, allow_empty=False)
    ),
)


class Config(TypedDict):
    """The top-level config dict.

    This just matches the data as it's stored in config.yaml. We don't do any
    transformation up front (for example "preserve" values are just their
    strings, not Policy objects).
    """

    # timezone name passed to get_zoneinfo by the run command
    # (e.g. "America/Los_Angeles")
    timezone: str
    sources: list[SourceConfig]
    remotes: list[RemoteConfig]


# Top-level cfgv schema: a timezone plus non-empty lists of sources and
# remotes, each validated recursively.
_SCHEMA = Map(
    "Config",
    None,
    Required("timezone", check_string),
    RequiredRecurse("sources", Array(_SOURCE_SCHEMA, allow_empty=False)),
    RequiredRecurse("remotes", Array(_REMOTE_SCHEMA, allow_empty=False)),
)


def load_from_path(path: str | PathLike[str]) -> Config:
    """Load and validate config from a file path.

    Only basic syntactic validation is done (the config keeps the shape
    it has on disk — e.g. "preserve" stays a plain string), plus a
    cross-reference check that every remote id used by a source is
    actually defined.

    Args:
        path: The path to the config file.

    Returns:
        The validated Config dict.

    Raises:
        InvalidConfigError: If the config does not pass validation.
    """
    config = cast(
        Config, load_from_filename(path, _SCHEMA, safe_load, exc_tp=InvalidConfigError)
    )

    known_ids = {remote["id"] for remote in config["remotes"]}
    for source in config["sources"]:
        for upload in source["upload_to_remotes"]:
            if upload["id"] in known_ids:
                continue
            msg = (
                f'remote id {upload["id"]!r} for source '
                f'{source["path"]!r} is not defined in the list of remotes'
            )
            raise InvalidConfigError(msg)

    return config
Loading

0 comments on commit e75f2a1

Please sign in to comment.