Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UnloadPrompt: Add unload_prompt_artifacts #30

Merged
merged 1 commit into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions caikit_tgis_backend/tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Standard
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional
import os
import shutil

Expand Down Expand Up @@ -163,7 +163,7 @@ def tls_enabled(self) -> bool:
def mtls_enabled(self) -> bool:
return None not in [self.ca_cert_file, self.client_tls]

def load_prompt_artifacts(self, prompt_id: str, *artifact_paths):
def load_prompt_artifacts(self, prompt_id: str, *artifact_paths: List[str]):
"""Load the given artifact paths to this TGIS connection

As implemented, this is a simple copy to the TGIS instance's prompt dir,
Expand All @@ -172,6 +172,10 @@ def load_prompt_artifacts(self, prompt_id: str, *artifact_paths):
TODO: If two copies of the runtime attempt to perform the same copy at
the same time, it could race and cause errors with the mounted
directory system.

Args:
prompt_id (str): The ID that this prompt should use
*artifact_paths (List[str]): The paths to the artifacts to laod
"""
error.value_check(
"<TGB07970356E>",
Expand All @@ -192,6 +196,35 @@ def load_prompt_artifacts(self, prompt_id: str, *artifact_paths):
log.debug3("Copying %s -> %s", artifact_path, target_file)
shutil.copyfile(artifact_path, target_file)

def unload_prompt_artifacts(self, *prompt_ids: List[str]):
"""Unload the given prompts from TGIS

As implemented, this simply removes the prompt artifacts for these IDs
and does not explicitly unload them from the TGIS in-memory cache.

NOTE: This intentionally ignores all errors. It's very likely that
multiple replicas of the runtime will attempt to unload the same
prompt, so we need to let the first one win and the rest quietly
accept that it's already deleted.

Args:
*prompt_ids (List[str]): The IDs to unload
"""
error.value_check(
"<TGB07970365E>",
self.prompt_dir is not None,
"No prompt_dir configured for {}",
self.hostname,
)
error.type_check_all(
"<TGB41380075E>",
str,
prompt_ids=prompt_ids,
)
for prompt_id in prompt_ids:
prompt_id_dir = os.path.join(self.prompt_dir, prompt_id)
shutil.rmtree(prompt_id_dir, ignore_errors=True)

def get_client(self) -> generation_pb2_grpc.GenerationServiceStub:
"""Get a grpc client for the connection"""
if self._client is None:
Expand Down
51 changes: 51 additions & 0 deletions tests/test_tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,57 @@ def test_load_prompt_artifacts_no_prompt_dir():
conn.load_prompt_artifacts(prompt_id, *source_files)


def test_unload_prompt_artifacts_ok():
"""Make sure that prompt artifacts can be unloaded cleanly"""
with tempfile.TemporaryDirectory() as source_dir:
with tempfile.TemporaryDirectory() as prompt_dir:
# Make some source files
source_fnames = ["foo.pt", "bar.pt"]
source_files = [os.path.join(source_dir, fname) for fname in source_fnames]
for source_file in source_files:
with open(source_file, "w") as handle:
handle.write("stub")

# Make the connection with the prompt dir
conn = TGISConnection.from_config(
"",
{
TGISConnection.HOSTNAME_KEY: "foo.bar:1234",
TGISConnection.PROMPT_DIR_KEY: prompt_dir,
},
)

# Copy the artifacts over
prompt_id = "some-prompt-id"
conn.load_prompt_artifacts(prompt_id, *source_files)

# Make sure the artifacts are available
for fname in source_fnames:
assert os.path.exists(os.path.join(prompt_dir, prompt_id, fname))

# Unload all of the prompts and make sure they're gone
conn.unload_prompt_artifacts(prompt_id)
assert not os.path.exists(os.path.join(prompt_dir, prompt_id))


def test_unload_prompt_artifacts_bad_prompt_id():
"""Make sure that unloading a bad prompt ID is a no-op"""
with tempfile.TemporaryDirectory() as source_dir:
with tempfile.TemporaryDirectory() as prompt_dir:
conn = TGISConnection.from_config(
"",
{
TGISConnection.HOSTNAME_KEY: "foo.bar:1234",
TGISConnection.PROMPT_DIR_KEY: prompt_dir,
},
)

# Unload all of the prompts and make sure they're gone
prompt_id = "some-prompt-id"
conn.unload_prompt_artifacts(prompt_id)
assert not os.path.exists(os.path.join(prompt_dir, prompt_id))


def test_connection_valid_endpoint(tgis_mock_insecure):
"""Make sure that a connection test works with a valid server"""
conn = TGISConnection(hostname=tgis_mock_insecure.hostname, model_id="asdf")
Expand Down