Skip to content

Commit

Permalink
Merge pull request #21 from gabe-l-hart/LoadPrompt
Browse files Browse the repository at this point in the history
Load prompt
  • Loading branch information
gabe-l-hart authored Jul 27, 2023
2 parents f95f734 + 98cc844 commit 06af8f4
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 4 deletions.
12 changes: 12 additions & 0 deletions caikit_tgis_backend/managed_tgis_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
bootup_poll_delay: float = 1,
load_timeout: float = 30,
num_gpus: int = 1,
prompt_dir: Optional[str] = None,
):
"""Create a ManagedTGISSubprocess
Expand All @@ -77,6 +78,9 @@ def __init__(
checks during bootup. Defaults to 1.
load_timeout (float, optional): number of seconds to wait for TGIS to
boot before cancelling. Defaults to 30.
num_gpus (int): The number of GPUs to use for this instance
prompt_dir (Optional[str]): A directory with write access to use as
the prompt cache for this instance
"""
# parameters of the TGIS subprocess
self._model_id = None
Expand All @@ -86,6 +90,10 @@ def __init__(
self._health_poll_timeout = health_poll_timeout
self._load_timeout = load_timeout
self._num_gpus = num_gpus
self._prompt_dir = prompt_dir
error.value_check(
"<TGB54435438E>", prompt_dir is None or os.path.isdir(prompt_dir)
)
log.debug("Managing local TGIS with %d GPU(s)", self._num_gpus)

self._hostname = f"localhost:{self._grpc_port}"
Expand All @@ -108,6 +116,7 @@ def get_connection(self):
"""Get the TGISConnection object for this local connection"""
return TGISConnection(
hostname=self._hostname,
prompt_dir=self._prompt_dir,
_client=self.get_client(),
)

Expand Down Expand Up @@ -208,6 +217,9 @@ def _launch(self):
log.debug2("Launching TGIS with command: [%s]", launch_cmd)
env = os.environ.copy()
env["GRPC_PORT"] = str(self._grpc_port)
if self._prompt_dir is not None:
env["PREFIX_STORE_PATH"] = self._prompt_dir

# Long running process
# pylint: disable=consider-using-with
self._tgis_proc = subprocess.Popen(shlex.split(launch_cmd), env=env)
Expand Down
12 changes: 11 additions & 1 deletion caikit_tgis_backend/tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(self, config: Optional[dict] = None):
health_poll_timeout=local_cfg.get("health_poll_timeout", 10),
load_timeout=local_cfg.get("load_timeout", 30),
num_gpus=local_cfg.get("num_gpus", 1),
prompt_dir=local_cfg.get("prompt_dir"),
)

def __del__(self):
Expand Down Expand Up @@ -189,7 +190,6 @@ def get_client(self, model_id: str) -> generation_pb2_grpc.GenerationServiceStub
# Return the client to the server
return model_conn.get_client()

# pylint: disable=unused-argument
def unload_model(self, model_id: str):
"""Unload the model from TGIS"""
# If running locally, shut down the managed instance
Expand All @@ -201,6 +201,16 @@ def unload_model(self, model_id: str):
# Remove the connection for this model
self._model_connections.pop(model_id, None)

def load_prompt_artifacts(self, model_id: str, prompt_id: str, *prompt_artifacts):
"""Load the given prompt artifacts for the given prompt against the base
model
"""
conn = self.get_connection(model_id)
error.value_check(
"<TGB00822514E>", conn is not None, "Unknown model {}", model_id
)
conn.load_prompt_artifacts(prompt_id, *prompt_artifacts)

@property
def local_tgis(self) -> bool:
return self._local_tgis
Expand Down
65 changes: 62 additions & 3 deletions caikit_tgis_backend/tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dataclasses import dataclass
from typing import Optional
import os
import shutil

# Third Party
import grpc
Expand All @@ -41,18 +42,31 @@ class TLSFilePair:
@dataclass
class TGISConnection:

# Class members
#################
# Class members #
#################

# The URL (with port) for the connection
hostname: str
# Path to CA cert when TGIS is running with TLS
ca_cert_file: Optional[str] = None
# Paths to client key/cert pair when TGIS requires mTLS
client_tls: Optional[TLSFilePair] = None
# Mounted directory where TGIS will look for prompt vector artifacts
prompt_dir: Optional[str] = None
# Private member to hold the client once created
_client: Optional[generation_pb2_grpc.GenerationServiceStub] = None

# Class constants
###################
# Class constants #
###################

HOSTNAME_KEY = "hostname"
HOSTNAME_TEMPLATE_MODEL_ID = "model_id"
CA_CERT_FILE_KEY = "ca_cert_file"
CLIENT_CERT_FILE_KEY = "client_cert_file"
CLIENT_KEY_FILE_KEY = "client_key_file"
PROMPT_DIR_KEY = "prompt_dir"

@classmethod
def from_config(cls, model_id: str, config: dict) -> Optional["TGISConnection"]:
Expand All @@ -66,6 +80,17 @@ def from_config(cls, model_id: str, config: dict) -> Optional["TGISConnection"]:
)
log.debug("Resolved hostname [%s] for model %s", hostname, model_id)

# Look for the prompt dir
prompt_dir = config.get(cls.PROMPT_DIR_KEY) or None
error.type_check(
"<TGB17909870E>",
str,
allow_none=True,
**{cls.PROMPT_DIR_KEY: prompt_dir},
)
if prompt_dir:
error.dir_check("<RGB69837665E>", prompt_dir)

# Pull out the TLS info
ca_cert = config.get(cls.CA_CERT_FILE_KEY) or None
client_cert = config.get(cls.CLIENT_CERT_FILE_KEY) or None
Expand Down Expand Up @@ -120,7 +145,12 @@ def from_config(cls, model_id: str, config: dict) -> Optional["TGISConnection"]:
if client_cert
else None
)
return cls(hostname=hostname, ca_cert_file=ca_cert, client_tls=client_tls)
return cls(
hostname=hostname,
ca_cert_file=ca_cert,
client_tls=client_tls,
prompt_dir=prompt_dir,
)

@property
def tls_enabled(self) -> bool:
Expand All @@ -130,6 +160,35 @@ def tls_enabled(self) -> bool:
def mtls_enabled(self) -> bool:
return None not in [self.ca_cert_file, self.client_tls]

def load_prompt_artifacts(self, prompt_id: str, *artifact_paths):
"""Load the given artifact paths to this TGIS connection
As implemented, this is a simple copy to the TGIS instance's prompt dir,
but it could extend to API interactions in the future.
TODO: If two copies of the runtime attempt to perform the same copy at
the same time, it could race and cause errors with the mounted
directory system.
"""
error.value_check(
"<TGB07970356E>",
self.prompt_dir is not None,
"No prompt_dir configured for {}",
self.hostname,
)
error.type_check_all(
"<TGB23973965E>",
str,
artifact_paths=artifact_paths,
)
target_dir = os.path.join(self.prompt_dir, prompt_id)
os.makedirs(target_dir, exist_ok=True)
for artifact_path in artifact_paths:
error.file_check("<TGB14818050E>", artifact_path)
target_file = os.path.join(target_dir, os.path.basename(artifact_path))
log.debug3("Copying %s -> %s", artifact_path, target_file)
shutil.copyfile(artifact_path, target_file)

def get_client(self) -> generation_pb2_grpc.GenerationServiceStub:
"""Get a grpc client for the connection"""
if self._client is None:
Expand Down
125 changes: 125 additions & 0 deletions tests/test_tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# Standard
from unittest import mock
import os
import tempfile
import time

# Third Party
Expand Down Expand Up @@ -369,6 +370,35 @@ def test_local_tgis_autorecovery(mock_tgis_fixture: MockTGISFixture):
)


def test_local_tgis_with_prompt_dir(mock_tgis_fixture: MockTGISFixture):
"""Test that a "local tgis" (mocked) can manage prompts"""
mock_tgis_server: TGISMock = mock_tgis_fixture.mock_tgis_server
with tempfile.TemporaryDirectory() as source_dir:
with tempfile.TemporaryDirectory() as prompt_dir:
tgis_be = TGISBackend(
{
"local": {
"grpc_port": int(mock_tgis_server.hostname.split(":")[-1]),
"http_port": mock_tgis_server.http_port,
"health_poll_delay": 0.1,
"prompt_dir": prompt_dir,
},
}
)
assert tgis_be.local_tgis
assert not mock_tgis_fixture.server_launched()
local_model_id = "local_model"
tgis_be.get_client(local_model_id)

prompt_id = "some-prompt"
artifact_fname = "artifact.pt"
source_fname = os.path.join(source_dir, artifact_fname)
with open(source_fname, "w") as handle:
handle.write("stub")
tgis_be.load_prompt_artifacts(local_model_id, prompt_id, source_fname)
assert os.path.exists(os.path.join(prompt_dir, prompt_id, artifact_fname))


## Remote Models ###############################################################


Expand Down Expand Up @@ -462,6 +492,101 @@ def test_tgis_backend_unload_multi_connection():
assert not tgis_be.get_connection("bar", False)


def test_tgis_backend_config_load_prompt_artifacts():
"""Make sure that loading prompt artifacts behaves as expected"""
with tempfile.TemporaryDirectory() as source_dir:
with tempfile.TemporaryDirectory() as prompt_dir:

# Make some source files
source_fnames = ["prompt1.pt", "prompt2.pt"]
source_files = [os.path.join(source_dir, fname) for fname in source_fnames]
for source_file in source_files:
with open(source_file, "w") as handle:
handle.write("stub")

# Set up a separate prompt dir for foo and bar
foo_prompt_dir = os.path.join(prompt_dir, "foo")
bar_prompt_dir = os.path.join(prompt_dir, "bar")
os.makedirs(foo_prompt_dir)
os.makedirs(bar_prompt_dir)

# Make the backend with two remotes that support prompts and one
# that does not
tgis_be = TGISBackend(
{
"remote_models": {
"foo": {"hostname": "foo:123", "prompt_dir": foo_prompt_dir},
"bar": {"hostname": "bar:123", "prompt_dir": bar_prompt_dir},
"baz": {"hostname": "bar:123"},
},
}
)

# Make sure loading prompts lands on the right model and prompt
prompt_id1 = "prompt-one"
prompt_id2 = "prompt-two"
tgis_be.load_prompt_artifacts("foo", prompt_id1, source_files[0])
assert os.path.exists(
os.path.join(foo_prompt_dir, prompt_id1, source_fnames[0])
)
assert not os.path.exists(
os.path.join(foo_prompt_dir, prompt_id2, source_fnames[1])
)
assert not os.path.exists(
os.path.join(bar_prompt_dir, prompt_id1, source_fnames[0])
)
assert not os.path.exists(
os.path.join(bar_prompt_dir, prompt_id2, source_fnames[1])
)
tgis_be.load_prompt_artifacts("foo", prompt_id2, source_files[1])
assert os.path.exists(
os.path.join(foo_prompt_dir, prompt_id1, source_fnames[0])
)
assert os.path.exists(
os.path.join(foo_prompt_dir, prompt_id2, source_fnames[1])
)
assert not os.path.exists(
os.path.join(bar_prompt_dir, prompt_id1, source_fnames[0])
)
assert not os.path.exists(
os.path.join(bar_prompt_dir, prompt_id2, source_fnames[1])
)
tgis_be.load_prompt_artifacts("bar", prompt_id1, source_files[0])
assert os.path.exists(
os.path.join(foo_prompt_dir, prompt_id1, source_fnames[0])
)
assert os.path.exists(
os.path.join(foo_prompt_dir, prompt_id2, source_fnames[1])
)
assert os.path.exists(
os.path.join(bar_prompt_dir, prompt_id1, source_fnames[0])
)
assert not os.path.exists(
os.path.join(bar_prompt_dir, prompt_id2, source_fnames[1])
)
tgis_be.load_prompt_artifacts("bar", prompt_id2, source_files[1])
assert os.path.exists(
os.path.join(foo_prompt_dir, prompt_id1, source_fnames[0])
)
assert os.path.exists(
os.path.join(foo_prompt_dir, prompt_id2, source_fnames[1])
)
assert os.path.exists(
os.path.join(bar_prompt_dir, prompt_id1, source_fnames[0])
)
assert os.path.exists(
os.path.join(bar_prompt_dir, prompt_id2, source_fnames[1])
)

# Make sure non-prompt models raise
with pytest.raises(ValueError):
tgis_be.load_prompt_artifacts("baz", prompt_id1, source_files[0])

# Make sure unknown model raises
with pytest.raises(ValueError):
tgis_be.load_prompt_artifacts("buz", prompt_id1, source_files[0])


## Failure Tests ###############################################################


Expand Down
Loading

0 comments on commit 06af8f4

Please sign in to comment.