Skip to content

Commit

Permalink
py: add user auth using SA
Browse files Browse the repository at this point in the history
  • Loading branch information
isinyaaa committed Apr 15, 2024
1 parent 88730ff commit ed6dcbd
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 79 deletions.
7 changes: 5 additions & 2 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ def __init__(
server_address: str,
port: int = 443,
custom_ca: str | None = None,
user_token: str | None = None,
):
"""Constructor.
Args:
author: Name of the author.
server_address: Server address.
port: Server port. Defaults to 443.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to contents of path on envvar CERT.
user_token: The PEM-encoded user token as a byte string. Defaults to content of path on envvar
KF_PIPELINES_SA_TOKEN_PATH.
"""
# TODO: get args from env
self._author = author
self._api = ModelRegistryAPIClient(server_address, port, custom_ca)
self._api = ModelRegistryAPIClient(server_address, port, custom_ca, user_token)

def _register_model(self, name: str) -> RegisteredModel:
if rm := self._api.get_registered_model_by_params(name):
Expand Down
34 changes: 27 additions & 7 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Sequence
from pathlib import Path

from ml_metadata.proto import MetadataStoreClientConfig
import grpc

from .exceptions import StoreException
from .store import MLMDStore, ProtoType
Expand All @@ -23,15 +23,16 @@ def __init__(
server_address: str,
port: int = 443,
custom_ca: str | None = None,
user_token: str | None = None,
):
"""Constructor.
Args:
server_address: Server address.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT.
user_token: The PEM-encoded user token as a byte string. Defaults to envvar KF_PIPELINES_SA_TOKEN_PATH.
port: Server port. Defaults to 443.
"""
config = MetadataStoreClientConfig()
if port == 443:
if not custom_ca:
ca_cert = os.environ.get("CERT")
Expand All @@ -41,11 +42,30 @@ def __init__(
root_certs = Path(ca_cert).read_bytes()
else:
root_certs = custom_ca

config.ssl_config.custom_ca = root_certs
config.host = server_address
config.port = port
self._store = MLMDStore(config)
channel_credentials = grpc.ssl_channel_credentials(root_certs)
if not user_token:
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if not sa_token:
msg = "Access token must be provided"
raise StoreException(msg)
token = Path(sa_token).read_bytes()
else:
token = user_token
call_credentials = grpc.access_token_call_credentials(token)
composite_credentials = grpc.composite_channel_credentials(
channel_credentials,
call_credentials,
)
channel = grpc.secure_channel(
f"{server_address}:443",
composite_credentials,
)
self._store = MLMDStore.from_channel(channel)
else:
# TODO: add Auth header for insecure connection
# chan = grpc.insecure_channel(f"{server_address}:{port}")
self._store = MLMDStore.from_config(host=server_address, port=port)

def _map(self, py_obj: ProtoBase) -> ProtoType:
"""Map a Python object to a proto object.
Expand Down
61 changes: 44 additions & 17 deletions clients/python/src/model_registry/store/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import ClassVar

from grpc import Channel
from ml_metadata import errors
from ml_metadata.metadata_store import ListOptions, MetadataStore
from ml_metadata.proto import (
Expand All @@ -14,6 +16,7 @@
MetadataStoreClientConfig,
ParentContext,
)
from ml_metadata.proto.metadata_store_service_pb2_grpc import MetadataStoreServiceStub

from model_registry.exceptions import (
DuplicateException,
Expand All @@ -25,19 +28,43 @@
from .base import ProtoType


@dataclass
class MLMDStore:
"""MLMD storage backend."""

store: MetadataStore
# cache for MLMD type IDs
_type_ids: ClassVar[dict[str, int]] = {}

def __init__(self, config: MetadataStoreClientConfig):
@classmethod
def from_config(cls, host: str, port: int):
"""Constructor.
Args:
config: MLMD config.
host: MLMD store server host.
port: MLMD store server port.
"""
self._mlmd_store = MetadataStore(config)
return cls(
MetadataStore(
MetadataStoreClientConfig(
host=host,
port=port,
)
)
)

@classmethod
def from_channel(cls, chan: Channel):
"""Constructor.
Args:
chan: gRPC channel to the MLMD store.
"""
store = MetadataStore(
MetadataStoreClientConfig(host="localhost", port=8080),
)
store._metadata_store_stub = MetadataStoreServiceStub(chan)
return cls(store)

def get_type_id(self, pt: type[ProtoType], type_name: str) -> int:
"""Get backend ID for a type.
Expand All @@ -59,7 +86,7 @@ def get_type_id(self, pt: type[ProtoType], type_name: str) -> int:
pt_name = pt.__name__.lower()

try:
_type = getattr(self._mlmd_store, f"get_{pt_name}_type")(type_name)
_type = getattr(self.store, f"get_{pt_name}_type")(type_name)
except errors.NotFoundError as e:
msg = f"{pt_name} type {type_name} does not exist"
raise TypeNotFoundException(msg) from e
Expand All @@ -85,7 +112,7 @@ def put_artifact(self, artifact: Artifact) -> int:
StoreException: If the artifact isn't properly formed.
"""
try:
return self._mlmd_store.put_artifacts([artifact])[0]
return self.store.put_artifacts([artifact])[0]
except errors.AlreadyExistsError as e:
msg = f"Artifact {artifact.name} already exists"
raise DuplicateException(msg) from e
Expand All @@ -111,7 +138,7 @@ def put_context(self, context: Context) -> int:
StoreException: If the context isn't propertly formed.
"""
try:
return self._mlmd_store.put_contexts([context])[0]
return self.store.put_contexts([context])[0]
except errors.AlreadyExistsError as e:
msg = f"Context {context.name} already exists"
raise DuplicateException(msg) from e
Expand Down Expand Up @@ -152,12 +179,12 @@ def get_context(
StoreException: Invalid arguments.
"""
if name is not None:
return self._mlmd_store.get_context_by_type_and_name(ctx_type_name, name)
return self.store.get_context_by_type_and_name(ctx_type_name, name)

if id is not None:
contexts = self._mlmd_store.get_contexts_by_id([id])
contexts = self.store.get_contexts_by_id([id])
elif external_id is not None:
contexts = self._mlmd_store.get_contexts_by_external_ids([external_id])
contexts = self.store.get_contexts_by_external_ids([external_id])
else:
msg = "Either id, name or external_id must be provided"
raise StoreException(msg)
Expand All @@ -183,7 +210,7 @@ def get_contexts(
# TODO: should we make options optional?
# if options is not None:
try:
contexts = self._mlmd_store.get_contexts(options)
contexts = self.store.get_contexts(options)
except errors.InvalidArgumentError as e:
msg = f"Invalid arguments for get_contexts: {e}"
raise StoreException(msg) from e
Expand Down Expand Up @@ -213,7 +240,7 @@ def put_context_parent(self, parent_id: int, child_id: int):
ServerException: If there was an error putting the parent context.
"""
try:
self._mlmd_store.put_parent_contexts(
self.store.put_parent_contexts(
[ParentContext(parent_id=parent_id, child_id=child_id)]
)
except errors.AlreadyExistsError as e:
Expand All @@ -235,7 +262,7 @@ def put_attribution(self, context_id: int, artifact_id: int):
"""
attribution = Attribution(context_id=context_id, artifact_id=artifact_id)
try:
self._mlmd_store.put_attributions_and_associations([attribution], [])
self.store.put_attributions_and_associations([attribution], [])
except errors.InvalidArgumentError as e:
if "artifact" in str(e).lower():
msg = f"Artifact with ID {artifact_id} does not exist"
Expand Down Expand Up @@ -272,12 +299,12 @@ def get_artifact(
StoreException: Invalid arguments.
"""
if name is not None:
return self._mlmd_store.get_artifact_by_type_and_name(art_type_name, name)
return self.store.get_artifact_by_type_and_name(art_type_name, name)

if id is not None:
artifacts = self._mlmd_store.get_artifacts_by_id([id])
artifacts = self.store.get_artifacts_by_id([id])
elif external_id is not None:
artifacts = self._mlmd_store.get_artifacts_by_external_ids([external_id])
artifacts = self.store.get_artifacts_by_external_ids([external_id])
else:
msg = "Either id, name or external_id must be provided"
raise StoreException(msg)
Expand All @@ -299,7 +326,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact:
Artifact.
"""
try:
artifacts = self._mlmd_store.get_artifacts_by_context(ctx_id)
artifacts = self.store.get_artifacts_by_context(ctx_id)
except errors.InternalError as e:
msg = f"Couldn't get artifacts by context {ctx_id}"
raise ServerException(msg) from e
Expand All @@ -322,7 +349,7 @@ def get_artifacts(
Artifacts.
"""
try:
artifacts = self._mlmd_store.get_artifacts(options)
artifacts = self.store.get_artifacts(options)
except errors.InvalidArgumentError as e:
msg = f"Invalid arguments for get_artifacts: {e}"
raise StoreException(msg) from e
Expand Down
24 changes: 12 additions & 12 deletions clients/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# ruff: noqa: PT021 supported
@pytest.fixture(scope="session")
def mlmd_conn(request) -> MetadataStoreClientConfig:
def mlmd_port(request) -> int:
model_registry_root_dir = model_registry_root(request)
print(
"Assuming this is the Model Registry root directory:", model_registry_root_dir
Expand All @@ -46,10 +46,8 @@ def mlmd_conn(request) -> MetadataStoreClientConfig:
wait_for_logs(container, "Server listening on")
os.system('docker container ls --format "table {{.ID}}\t{{.Names}}\t{{.Ports}}" -a') # noqa governed test
print("waited for logs and port")
cfg = MetadataStoreClientConfig(
host="localhost", port=int(container.get_exposed_port(8080))
)
print(cfg)
port = int(container.get_exposed_port(8080))
print("port:", port)

# this callback is needed in order to perform the container.stop()
# removing this callback might result in mlmd container shutting down before the tests had chance to fully run,
Expand All @@ -63,18 +61,20 @@ def teardown():
time.sleep(
3
) # allowing some time for mlmd grpc to fully stabilize (is "spent" once per pytest session anyway)
_throwaway_store = metadata_store.MetadataStore(cfg)
_throwaway_store = metadata_store.MetadataStore(
MetadataStoreClientConfig(host="localhost", port=port)
)
wait_for_grpc(container, _throwaway_store)

return cfg
return port


def model_registry_root(request):
return (request.config.rootpath / "../..").resolve() # resolves to absolute path


@pytest.fixture()
def plain_wrapper(request, mlmd_conn: MetadataStoreClientConfig) -> MLMDStore:
def plain_wrapper(request, mlmd_port: int) -> MLMDStore:
sqlite_db_file = (
model_registry_root(request) / "test/config/ml-metadata/metadata.sqlite.db"
)
Expand All @@ -89,7 +89,7 @@ def teardown():

request.addfinalizer(teardown)

return MLMDStore(mlmd_conn)
return MLMDStore.from_config("localhost", mlmd_port)


def set_type_attrs(mlmd_obj: ProtoTypeType, name: str, props: list[str]):
Expand All @@ -114,7 +114,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore:
],
)

plain_wrapper._mlmd_store.put_artifact_type(ma_type)
plain_wrapper.store.put_artifact_type(ma_type)

mv_type = set_type_attrs(
ContextType(),
Expand All @@ -127,7 +127,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore:
],
)

plain_wrapper._mlmd_store.put_context_type(mv_type)
plain_wrapper.store.put_context_type(mv_type)

rm_type = set_type_attrs(
ContextType(),
Expand All @@ -138,7 +138,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore:
],
)

plain_wrapper._mlmd_store.put_context_type(rm_type)
plain_wrapper.store.put_context_type(rm_type)

return plain_wrapper

Expand Down
12 changes: 6 additions & 6 deletions clients/python/tests/store/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def artifact(plain_wrapper: MLMDStore) -> Artifact:

art = Artifact()
art.name = "test_artifact"
art.type_id = plain_wrapper._mlmd_store.put_artifact_type(art_type)
art.type_id = plain_wrapper.store.put_artifact_type(art_type)

return art

Expand All @@ -38,7 +38,7 @@ def context(plain_wrapper: MLMDStore) -> Context:

ctx = Context()
ctx.name = "test_context"
ctx.type_id = plain_wrapper._mlmd_store.put_context_type(ctx_type)
ctx.type_id = plain_wrapper.store.put_context_type(ctx_type)

return ctx

Expand All @@ -61,7 +61,7 @@ def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact):


def test_put_duplicate_artifact(plain_wrapper: MLMDStore, artifact: Artifact):
plain_wrapper._mlmd_store.put_artifacts([artifact])
plain_wrapper.store.put_artifacts([artifact])
with pytest.raises(DuplicateException):
plain_wrapper.put_artifact(artifact)

Expand All @@ -74,7 +74,7 @@ def test_put_invalid_context(plain_wrapper: MLMDStore, context: Context):


def test_put_duplicate_context(plain_wrapper: MLMDStore, context: Context):
plain_wrapper._mlmd_store.put_contexts([context])
plain_wrapper.store.put_contexts([context])

with pytest.raises(DuplicateException):
plain_wrapper.put_context(context)
Expand All @@ -83,7 +83,7 @@ def test_put_duplicate_context(plain_wrapper: MLMDStore, context: Context):
def test_put_attribution_with_invalid_context(
plain_wrapper: MLMDStore, artifact: Artifact
):
art_id = plain_wrapper._mlmd_store.put_artifacts([artifact])[0]
art_id = plain_wrapper.store.put_artifacts([artifact])[0]

with pytest.raises(StoreException) as store_error:
plain_wrapper.put_attribution(0, art_id)
Expand All @@ -94,7 +94,7 @@ def test_put_attribution_with_invalid_context(
def test_put_attribution_with_invalid_artifact(
plain_wrapper: MLMDStore, context: Context
):
ctx_id = plain_wrapper._mlmd_store.put_contexts([context])[0]
ctx_id = plain_wrapper.store.put_contexts([context])[0]

with pytest.raises(StoreException) as store_error:
plain_wrapper.put_attribution(ctx_id, 0)
Expand Down
Loading

0 comments on commit ed6dcbd

Please sign in to comment.