diff --git a/charms/worker/k8s/src/snap.py b/charms/worker/k8s/src/snap.py index db2e5e02..cac24fae 100644 --- a/charms/worker/k8s/src/snap.py +++ b/charms/worker/k8s/src/snap.py @@ -253,16 +253,51 @@ def management(charm: ops.CharmBase) -> None: cache = snap_lib.SnapCache() for args in _parse_management_arguments(charm): which = cache[args.name] + if block_refresh(which, args): + continue + install_args = args.dict(exclude_none=True) if isinstance(args, SnapFileArgument) and which.revision != "x1": - snap_lib.install_local(**args.dict(exclude_none=True)) + snap_lib.install_local(**install_args) elif isinstance(args, SnapStoreArgument) and args.revision: if which.revision != args.revision: log.info("Ensuring %s snap revision=%s", args.name, args.revision) - which.ensure(**args.dict(exclude_none=True)) + which.ensure(**install_args) which.hold() elif isinstance(args, SnapStoreArgument): log.info("Ensuring %s snap channel=%s", args.name, args.channel) - which.ensure(**args.dict(exclude_none=True)) + which.ensure(**install_args) + + +def block_refresh(which: snap_lib.Snap, args: SnapArgument) -> bool: + """Block snap refreshes if the snap is in a specific state. + + Arguments: + which: The snap to check + args: The snap arguments + + Returns: + bool: True if the snap should be blocked from refreshing + """ + if snap_lib.SnapState(which.state) == snap_lib.SnapState.Available: + log.info("Allowing %s snap installation", args.name) + return False + if _overridden_snap_installation().exists(): + log.info("Allowing %s snap refresh due to snap installation override", args.name) + return False + if isinstance(args, SnapStoreArgument) and args.revision: + if block := which.revision != args.revision: + log.info("Blocking %s snap refresh to revision=%s", args.name, args.revision) + else: + log.info("Allowing %s snap refresh to same revision", args.name) + return block + if isinstance(args, SnapStoreArgument): + if block := which.channel != args.channel: + log.info("Blocking %s snap refresh to channel=%s", args.name, args.channel) + else: + log.info("Allowing %s snap refresh to same channel (%s)", args.name, args.channel) + return block + log.info("Blocking %s snap refresh", args.name) + return True def version(snap: str) -> Tuple[Optional[str], bool]: diff --git a/charms/worker/k8s/tests/unit/test_snap.py b/charms/worker/k8s/tests/unit/test_snap.py index 0ff07b62..e77df102 100644 --- a/charms/worker/k8s/tests/unit/test_snap.py +++ b/charms/worker/k8s/tests/unit/test_snap.py @@ -66,6 +66,82 @@ def resource_snap_installation(tmp_path): yield mock_path +@pytest.fixture() +def block_refresh(): + """Block snap refresh.""" + with mock.patch("snap.block_refresh") as mocked: + mocked.return_value = False + yield mocked + + +@mock.patch("snap.snap_lib.SnapCache") +@pytest.mark.parametrize( + "state, as_file", + [ + [("present", "1234", None), False], + [("present", None, "edge"), False], + [("present", None, None), True], + ], + ids=[ + "installed & store-by-channel", + "installed & store-by-revision", + "installed & file-without-override", + ], +) +def test_block_refresh(cache, state, as_file, caplog, resource_snap_installation): + """Test block refresh.""" + caplog.set_level(0) + k8s_snap = cache()["k8s"] + k8s_snap.state, k8s_snap.revision, k8s_snap.channel = state + if as_file: + args = snap.SnapFileArgument( + name="k8s", + filename=resource_snap_installation.parent / "k8s_1234.snap", + ) + else: + args = snap.SnapStoreArgument( + name="k8s", + channel="beta" if k8s_snap.channel else None, + revision="5678" if k8s_snap.revision else None, + ) + assert snap.block_refresh(k8s_snap, args) + assert "Blocking k8s snap refresh" in caplog.text + + +@mock.patch("snap.snap_lib.SnapCache") +@pytest.mark.parametrize( + "state, overridden", + [ + [("available", None, None), None], + [("present", "1234", None), None], + [("present", None, "edge"), None], + [("present", None, None), True], + ], + ids=[ + "not installed yet", + "installed & store-by-same-channel", + "installed & store-by-same-revision", + "installed & override", + ], +) +def test_not_block_refresh(cache, state, overridden, caplog, resource_snap_installation): + """Test block refresh.""" + caplog.set_level(0) + k8s_snap = cache()["k8s"] + k8s_snap.state, k8s_snap.revision, k8s_snap.channel = state + if overridden: + resource_snap_installation.write_text( + "amd64:\n- install-type: store\n name: k8s\n channel: edge" + ) + args = snap.SnapStoreArgument( + name="k8s", + channel=k8s_snap.channel, + revision=k8s_snap.revision, + ) + assert not snap.block_refresh(k8s_snap, args) + assert "Allowing k8s snap" in caplog.text + + @pytest.mark.usefixtures("missing_snap_installation") def test_parse_no_file(harness): """Test no file exists.""" @@ -222,6 +298,7 @@ def test_parse_valid_file(mock_checkoutput, snap_installation, harness): ] +@pytest.mark.usefixtures("block_refresh") @mock.patch("snap._parse_management_arguments") @mock.patch("snap.snap_lib.install_local") @mock.patch("snap.snap_lib.SnapCache") @@ -234,6 +311,7 @@ def test_management_installs_local(cache, install_local, args, harness): install_local.assert_called_once_with(filename=Path("path/to/thing")) +@pytest.mark.usefixtures("block_refresh") @mock.patch("snap._parse_management_arguments") @mock.patch("snap.snap_lib.install_local") @mock.patch("snap.snap_lib.SnapCache") @@ -248,6 +326,7 @@ def test_management_installs_store_from_channel(cache, install_local, args, revi k8s_snap.ensure.assert_called_once_with(state=snap.snap_lib.SnapState.Present, channel="edge") +@pytest.mark.usefixtures("block_refresh") @mock.patch("snap._parse_management_arguments") @mock.patch("snap.snap_lib.install_local") @mock.patch("snap.snap_lib.SnapCache")