From ed6dcbdbe59d1d41e6570d582f05fe4867176526 Mon Sep 17 00:00:00 2001 From: Isabella Basso Date: Mon, 1 Apr 2024 09:49:10 -0300 Subject: [PATCH] py: add user auth using SA --- clients/python/src/model_registry/_client.py | 7 +- clients/python/src/model_registry/core.py | 34 ++++++++-- .../src/model_registry/store/wrapper.py | 61 ++++++++++++----- clients/python/tests/conftest.py | 24 +++---- clients/python/tests/store/test_wrapper.py | 12 ++-- clients/python/tests/test_core.py | 68 +++++++++---------- 6 files changed, 127 insertions(+), 79 deletions(-) diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index ddd58f4ff..1d05d1a4b 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -20,6 +20,7 @@ def __init__( server_address: str, port: int = 443, custom_ca: str | None = None, + user_token: str | None = None, ): """Constructor. @@ -27,11 +28,13 @@ def __init__( author: Name of the author. server_address: Server address. port: Server port. Defaults to 443. - custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT. + custom_ca: The PEM-encoded root certificates as a byte string. Defaults to contents of path on envvar CERT. + user_token: The PEM-encoded user token as a byte string. Defaults to content of path on envvar + KF_PIPELINES_SA_TOKEN_PATH. """ # TODO: get args from env self._author = author - self._api = ModelRegistryAPIClient(server_address, port, custom_ca) + self._api = ModelRegistryAPIClient(server_address, port, custom_ca, user_token) def _register_model(self, name: str) -> RegisteredModel: if rm := self._api.get_registered_model_by_params(name): diff --git a/clients/python/src/model_registry/core.py b/clients/python/src/model_registry/core.py index 07925d647..1723379f0 100644 --- a/clients/python/src/model_registry/core.py +++ b/clients/python/src/model_registry/core.py @@ -6,7 +6,7 @@ from collections.abc import Sequence from pathlib import Path -from ml_metadata.proto import MetadataStoreClientConfig +import grpc from .exceptions import StoreException from .store import MLMDStore, ProtoType @@ -23,15 +23,16 @@ def __init__( server_address: str, port: int = 443, custom_ca: str | None = None, + user_token: str | None = None, ): """Constructor. Args: server_address: Server address. custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT. + user_token: The PEM-encoded user token as a byte string. Defaults to envvar KF_PIPELINES_SA_TOKEN_PATH. port: Server port. Defaults to 443. """ - config = MetadataStoreClientConfig() if port == 443: if not custom_ca: ca_cert = os.environ.get("CERT") @@ -41,11 +42,30 @@ def __init__( root_certs = Path(ca_cert).read_bytes() else: root_certs = custom_ca - - config.ssl_config.custom_ca = root_certs - config.host = server_address - config.port = port - self._store = MLMDStore(config) + channel_credentials = grpc.ssl_channel_credentials(root_certs) + if not user_token: + # /var/run/secrets/kubernetes.io/serviceaccount/token + sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH") + if not sa_token: + msg = "Access token must be provided" + raise StoreException(msg) + token = Path(sa_token).read_bytes() + else: + token = user_token + call_credentials = grpc.access_token_call_credentials(token) + composite_credentials = grpc.composite_channel_credentials( + channel_credentials, + call_credentials, + ) + channel = grpc.secure_channel( + f"{server_address}:443", + composite_credentials, + ) + self._store = MLMDStore.from_channel(channel) + else: + # TODO: add Auth header for insecure connection + # chan = grpc.insecure_channel(f"{server_address}:{port}") + self._store = MLMDStore.from_config(host=server_address, port=port) def _map(self, py_obj: ProtoBase) -> ProtoType: """Map a Python object to a proto object. diff --git a/clients/python/src/model_registry/store/wrapper.py b/clients/python/src/model_registry/store/wrapper.py index 57c56dd78..014739be3 100644 --- a/clients/python/src/model_registry/store/wrapper.py +++ b/clients/python/src/model_registry/store/wrapper.py @@ -3,8 +3,10 @@ from __future__ import annotations from collections.abc import Sequence +from dataclasses import dataclass from typing import ClassVar +from grpc import Channel from ml_metadata import errors from ml_metadata.metadata_store import ListOptions, MetadataStore from ml_metadata.proto import ( @@ -14,6 +16,7 @@ MetadataStoreClientConfig, ParentContext, ) +from ml_metadata.proto.metadata_store_service_pb2_grpc import MetadataStoreServiceStub from model_registry.exceptions import ( DuplicateException, @@ -25,19 +28,43 @@ from .base import ProtoType +@dataclass class MLMDStore: """MLMD storage backend.""" + store: MetadataStore # cache for MLMD type IDs _type_ids: ClassVar[dict[str, int]] = {} - def __init__(self, config: MetadataStoreClientConfig): + @classmethod + def from_config(cls, host: str, port: int): """Constructor. Args: - config: MLMD config. + host: MLMD store server host. + port: MLMD store server port. """ - self._mlmd_store = MetadataStore(config) + return cls( + MetadataStore( + MetadataStoreClientConfig( + host=host, + port=port, + ) + ) + ) + + @classmethod + def from_channel(cls, chan: Channel): + """Constructor. + + Args: + chan: gRPC channel to the MLMD store. + """ + store = MetadataStore( + MetadataStoreClientConfig(host="localhost", port=8080), + ) + store._metadata_store_stub = MetadataStoreServiceStub(chan) + return cls(store) def get_type_id(self, pt: type[ProtoType], type_name: str) -> int: """Get backend ID for a type. @@ -59,7 +86,7 @@ def get_type_id(self, pt: type[ProtoType], type_name: str) -> int: pt_name = pt.__name__.lower() try: - _type = getattr(self._mlmd_store, f"get_{pt_name}_type")(type_name) + _type = getattr(self.store, f"get_{pt_name}_type")(type_name) except errors.NotFoundError as e: msg = f"{pt_name} type {type_name} does not exist" raise TypeNotFoundException(msg) from e @@ -85,7 +112,7 @@ def put_artifact(self, artifact: Artifact) -> int: StoreException: If the artifact isn't properly formed. """ try: - return self._mlmd_store.put_artifacts([artifact])[0] + return self.store.put_artifacts([artifact])[0] except errors.AlreadyExistsError as e: msg = f"Artifact {artifact.name} already exists" raise DuplicateException(msg) from e @@ -111,7 +138,7 @@ def put_context(self, context: Context) -> int: StoreException: If the context isn't propertly formed. """ try: - return self._mlmd_store.put_contexts([context])[0] + return self.store.put_contexts([context])[0] except errors.AlreadyExistsError as e: msg = f"Context {context.name} already exists" raise DuplicateException(msg) from e @@ -152,12 +179,12 @@ def get_context( StoreException: Invalid arguments. """ if name is not None: - return self._mlmd_store.get_context_by_type_and_name(ctx_type_name, name) + return self.store.get_context_by_type_and_name(ctx_type_name, name) if id is not None: - contexts = self._mlmd_store.get_contexts_by_id([id]) + contexts = self.store.get_contexts_by_id([id]) elif external_id is not None: - contexts = self._mlmd_store.get_contexts_by_external_ids([external_id]) + contexts = self.store.get_contexts_by_external_ids([external_id]) else: msg = "Either id, name or external_id must be provided" raise StoreException(msg) @@ -183,7 +210,7 @@ def get_contexts( # TODO: should we make options optional? # if options is not None: try: - contexts = self._mlmd_store.get_contexts(options) + contexts = self.store.get_contexts(options) except errors.InvalidArgumentError as e: msg = f"Invalid arguments for get_contexts: {e}" raise StoreException(msg) from e @@ -213,7 +240,7 @@ def put_context_parent(self, parent_id: int, child_id: int): ServerException: If there was an error putting the parent context. """ try: - self._mlmd_store.put_parent_contexts( + self.store.put_parent_contexts( [ParentContext(parent_id=parent_id, child_id=child_id)] ) except errors.AlreadyExistsError as e: @@ -235,7 +262,7 @@ def put_attribution(self, context_id: int, artifact_id: int): """ attribution = Attribution(context_id=context_id, artifact_id=artifact_id) try: - self._mlmd_store.put_attributions_and_associations([attribution], []) + self.store.put_attributions_and_associations([attribution], []) except errors.InvalidArgumentError as e: if "artifact" in str(e).lower(): msg = f"Artifact with ID {artifact_id} does not exist" @@ -272,12 +299,12 @@ def get_artifact( StoreException: Invalid arguments. """ if name is not None: - return self._mlmd_store.get_artifact_by_type_and_name(art_type_name, name) + return self.store.get_artifact_by_type_and_name(art_type_name, name) if id is not None: - artifacts = self._mlmd_store.get_artifacts_by_id([id]) + artifacts = self.store.get_artifacts_by_id([id]) elif external_id is not None: - artifacts = self._mlmd_store.get_artifacts_by_external_ids([external_id]) + artifacts = self.store.get_artifacts_by_external_ids([external_id]) else: msg = "Either id, name or external_id must be provided" raise StoreException(msg) @@ -299,7 +326,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact: Artifact. """ try: - artifacts = self._mlmd_store.get_artifacts_by_context(ctx_id) + artifacts = self.store.get_artifacts_by_context(ctx_id) except errors.InternalError as e: msg = f"Couldn't get artifacts by context {ctx_id}" raise ServerException(msg) from e @@ -322,7 +349,7 @@ def get_artifacts( Artifacts. """ try: - artifacts = self._mlmd_store.get_artifacts(options) + artifacts = self.store.get_artifacts(options) except errors.InvalidArgumentError as e: msg = f"Invalid arguments for get_artifacts: {e}" raise StoreException(msg) from e diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py index 04e81fe48..3cc164459 100644 --- a/clients/python/tests/conftest.py +++ b/clients/python/tests/conftest.py @@ -21,7 +21,7 @@ # ruff: noqa: PT021 supported @pytest.fixture(scope="session") -def mlmd_conn(request) -> MetadataStoreClientConfig: +def mlmd_port(request) -> int: model_registry_root_dir = model_registry_root(request) print( "Assuming this is the Model Registry root directory:", model_registry_root_dir @@ -46,10 +46,8 @@ def mlmd_conn(request) -> MetadataStoreClientConfig: wait_for_logs(container, "Server listening on") os.system('docker container ls --format "table {{.ID}}\t{{.Names}}\t{{.Ports}}" -a') # noqa governed test print("waited for logs and port") - cfg = MetadataStoreClientConfig( - host="localhost", port=int(container.get_exposed_port(8080)) - ) - print(cfg) + port = int(container.get_exposed_port(8080)) + print("port:", port) # this callback is needed in order to perform the container.stop() # removing this callback might result in mlmd container shutting down before the tests had chance to fully run, @@ -63,10 +61,12 @@ def teardown(): time.sleep( 3 ) # allowing some time for mlmd grpc to fully stabilize (is "spent" once per pytest session anyway) - _throwaway_store = metadata_store.MetadataStore(cfg) + _throwaway_store = metadata_store.MetadataStore( + MetadataStoreClientConfig(host="localhost", port=port) + ) wait_for_grpc(container, _throwaway_store) - return cfg + return port def model_registry_root(request): @@ -74,7 +74,7 @@ def model_registry_root(request): @pytest.fixture() -def plain_wrapper(request, mlmd_conn: MetadataStoreClientConfig) -> MLMDStore: +def plain_wrapper(request, mlmd_port: int) -> MLMDStore: sqlite_db_file = ( model_registry_root(request) / "test/config/ml-metadata/metadata.sqlite.db" ) @@ -89,7 +89,7 @@ def teardown(): request.addfinalizer(teardown) - return MLMDStore(mlmd_conn) + return MLMDStore.from_config("localhost", mlmd_port) def set_type_attrs(mlmd_obj: ProtoTypeType, name: str, props: list[str]): @@ -114,7 +114,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore: ], ) - plain_wrapper._mlmd_store.put_artifact_type(ma_type) + plain_wrapper.store.put_artifact_type(ma_type) mv_type = set_type_attrs( ContextType(), @@ -127,7 +127,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore: ], ) - plain_wrapper._mlmd_store.put_context_type(mv_type) + plain_wrapper.store.put_context_type(mv_type) rm_type = set_type_attrs( ContextType(), @@ -138,7 +138,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore: ], ) - plain_wrapper._mlmd_store.put_context_type(rm_type) + plain_wrapper.store.put_context_type(rm_type) return plain_wrapper diff --git a/clients/python/tests/store/test_wrapper.py b/clients/python/tests/store/test_wrapper.py index c2d379de1..b3f033a40 100644 --- a/clients/python/tests/store/test_wrapper.py +++ b/clients/python/tests/store/test_wrapper.py @@ -26,7 +26,7 @@ def artifact(plain_wrapper: MLMDStore) -> Artifact: art = Artifact() art.name = "test_artifact" - art.type_id = plain_wrapper._mlmd_store.put_artifact_type(art_type) + art.type_id = plain_wrapper.store.put_artifact_type(art_type) return art @@ -38,7 +38,7 @@ def context(plain_wrapper: MLMDStore) -> Context: ctx = Context() ctx.name = "test_context" - ctx.type_id = plain_wrapper._mlmd_store.put_context_type(ctx_type) + ctx.type_id = plain_wrapper.store.put_context_type(ctx_type) return ctx @@ -61,7 +61,7 @@ def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact): def test_put_duplicate_artifact(plain_wrapper: MLMDStore, artifact: Artifact): - plain_wrapper._mlmd_store.put_artifacts([artifact]) + plain_wrapper.store.put_artifacts([artifact]) with pytest.raises(DuplicateException): plain_wrapper.put_artifact(artifact) @@ -74,7 +74,7 @@ def test_put_invalid_context(plain_wrapper: MLMDStore, context: Context): def test_put_duplicate_context(plain_wrapper: MLMDStore, context: Context): - plain_wrapper._mlmd_store.put_contexts([context]) + plain_wrapper.store.put_contexts([context]) with pytest.raises(DuplicateException): plain_wrapper.put_context(context) @@ -83,7 +83,7 @@ def test_put_duplicate_context(plain_wrapper: MLMDStore, context: Context): def test_put_attribution_with_invalid_context( plain_wrapper: MLMDStore, artifact: Artifact ): - art_id = plain_wrapper._mlmd_store.put_artifacts([artifact])[0] + art_id = plain_wrapper.store.put_artifacts([artifact])[0] with pytest.raises(StoreException) as store_error: plain_wrapper.put_attribution(0, art_id) @@ -94,7 +94,7 @@ def test_put_attribution_with_invalid_context( def test_put_attribution_with_invalid_artifact( plain_wrapper: MLMDStore, context: Context ): - ctx_id = plain_wrapper._mlmd_store.put_contexts([context])[0] + ctx_id = plain_wrapper.store.put_contexts([context])[0] with pytest.raises(StoreException) as store_error: plain_wrapper.put_attribution(ctx_id, 0) diff --git a/clients/python/tests/test_core.py b/clients/python/tests/test_core.py index bdb3ba144..2e4469d81 100644 --- a/clients/python/tests/test_core.py +++ b/clients/python/tests/test_core.py @@ -61,7 +61,7 @@ def test_upsert_registered_model( ): mr_api.upsert_registered_model(registered_model.py) - rm_proto = mr_api._store._mlmd_store.get_context_by_type_and_name( + rm_proto = mr_api._store.store.get_context_by_type_and_name( RegisteredModel.get_proto_type_name(), registered_model.proto.name ) assert rm_proto is not None @@ -73,7 +73,7 @@ def test_get_registered_model_by_id( mr_api: ModelRegistryAPIClient, registered_model: Mapped, ): - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] assert (mlmd_rm := mr_api.get_registered_model_by_id(str(rm_id))) assert mlmd_rm.id == str(rm_id) @@ -85,7 +85,7 @@ def test_get_registered_model_by_name( mr_api: ModelRegistryAPIClient, registered_model: Mapped, ): - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] assert ( mlmd_rm := mr_api.get_registered_model_by_params(name=registered_model.py.name) @@ -101,7 +101,7 @@ def test_get_registered_model_by_external_id( ): registered_model.py.external_id = "external_id" registered_model.proto.external_id = "external_id" - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] assert ( mlmd_rm := mr_api.get_registered_model_by_params( @@ -116,9 +116,9 @@ def test_get_registered_model_by_external_id( def test_get_registered_models( mr_api: ModelRegistryAPIClient, registered_model: Mapped ): - rm1_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm1_id = mr_api._store.store.put_contexts([registered_model.proto])[0] registered_model.proto.name = "model2" - rm2_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm2_id = mr_api._store.store.put_contexts([registered_model.proto])[0] mlmd_rms = mr_api.get_registered_models() assert len(mlmd_rms) == 2 @@ -130,12 +130,12 @@ def test_upsert_model_version( model_version: Mapped, registered_model: Mapped, ): - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] rm_id = str(rm_id) mr_api.upsert_model_version(model_version.py, rm_id) - mv_proto = mr_api._store._mlmd_store.get_context_by_type_and_name( + mv_proto = mr_api._store.store.get_context_by_type_and_name( ModelVersion.get_proto_type_name(), f"{rm_id}:{model_version.proto.name}" ) assert mv_proto is not None @@ -145,7 +145,7 @@ def test_upsert_model_version( def test_get_model_version_by_id(mr_api: ModelRegistryAPIClient, model_version: Mapped): model_version.proto.name = f"1:{model_version.proto.name}" - ctx_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + ctx_id = mr_api._store.store.put_contexts([model_version.proto])[0] id = str(ctx_id) assert (mlmd_mv := mr_api.get_model_version_by_id(id)) @@ -158,7 +158,7 @@ def test_get_model_version_by_name( mr_api: ModelRegistryAPIClient, model_version: Mapped ): model_version.proto.name = f"1:{model_version.proto.name}" - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] assert ( mlmd_mv := mr_api.get_model_version_by_params( @@ -176,7 +176,7 @@ def test_get_model_version_by_external_id( model_version.proto.name = f"1:{model_version.proto.name}" model_version.proto.external_id = "external_id" model_version.py.external_id = "external_id" - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] assert ( mlmd_mv := mr_api.get_model_version_by_params( @@ -193,14 +193,14 @@ def test_get_model_versions( model_version: Mapped, registered_model: Mapped, ): - rm_id = mr_api._store._mlmd_store.put_contexts([registered_model.proto])[0] + rm_id = mr_api._store.store.put_contexts([registered_model.proto])[0] model_version.proto.name = f"{rm_id}:version" - mv1_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv1_id = mr_api._store.store.put_contexts([model_version.proto])[0] model_version.proto.name = f"{rm_id}:version2" - mv2_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv2_id = mr_api._store.store.put_contexts([model_version.proto])[0] - mr_api._store._mlmd_store.put_parent_contexts( + mr_api._store.store.put_parent_contexts( [ ParentContext(parent_id=rm_id, child_id=mv1_id), ParentContext(parent_id=rm_id, child_id=mv2_id), @@ -220,12 +220,12 @@ def test_upsert_model_artifact( ): monkeypatch.setattr(ModelArtifact, "mlmd_name_prefix", "test_prefix") - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] mv_id = str(mv_id) mr_api.upsert_model_artifact(model.py, mv_id) - ma_proto = mr_api._store._mlmd_store.get_artifact_by_type_and_name( + ma_proto = mr_api._store.store.get_artifact_by_type_and_name( ModelArtifact.get_proto_type_name(), f"test_prefix:{model.proto.name}" ) assert ma_proto is not None @@ -236,11 +236,11 @@ def test_upsert_model_artifact( def test_upsert_duplicate_model_artifact_with_different_version( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv1_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv1_id = mr_api._store.store.put_contexts([model_version.proto])[0] mv1_id = str(mv1_id) model_version.proto.name = "version2" - mv2_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv2_id = mr_api._store.store.put_contexts([model_version.proto])[0] mv2_id = str(mv2_id) ma1 = evolve(model.py) @@ -248,9 +248,7 @@ def test_upsert_duplicate_model_artifact_with_different_version( ma2 = evolve(model.py) mr_api.upsert_model_artifact(ma2, mv2_id) - ma_protos = mr_api._store._mlmd_store.get_artifacts_by_id( - [int(ma1.id), int(ma2.id)] - ) + ma_protos = mr_api._store.store.get_artifacts_by_id([int(ma1.id), int(ma2.id)]) assert ma1.name == ma2.name assert ma1.name != str(ma_protos[0].name) assert ma2.name != str(ma_protos[1].name) @@ -259,7 +257,7 @@ def test_upsert_duplicate_model_artifact_with_different_version( def test_upsert_duplicate_model_artifact_with_same_version( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] mv_id = str(mv_id) ma1 = evolve(model.py) @@ -271,7 +269,7 @@ def test_upsert_duplicate_model_artifact_with_same_version( def test_get_model_artifact_by_id(mr_api: ModelRegistryAPIClient, model: Mapped): model.proto.name = f"test_prefix:{model.proto.name}" - id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + id = mr_api._store.store.put_artifacts([model.proto])[0] id = str(id) assert (mlmd_ma := mr_api.get_model_artifact_by_id(id)) @@ -283,12 +281,12 @@ def test_get_model_artifact_by_id(mr_api: ModelRegistryAPIClient, model: Mapped) def test_get_model_artifact_by_model_version_id( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv_id = mr_api._store.store.put_contexts([model_version.proto])[0] model.proto.name = f"test_prefix:{model.proto.name}" - ma_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma_id = mr_api._store.store.put_artifacts([model.proto])[0] - mr_api._store._mlmd_store.put_attributions_and_associations( + mr_api._store.store.put_attributions_and_associations( [Attribution(context_id=mv_id, artifact_id=ma_id)], [] ) @@ -305,7 +303,7 @@ def test_get_model_artifact_by_external_id( model.proto.external_id = "external_id" model.py.external_id = "external_id" - id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + id = mr_api._store.store.put_artifacts([model.proto])[0] id = str(id) assert ( @@ -318,9 +316,9 @@ def test_get_model_artifact_by_external_id( def test_get_all_model_artifacts(mr_api: ModelRegistryAPIClient, model: Mapped): model.proto.name = "test_prefix:model1" - ma1_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma1_id = mr_api._store.store.put_artifacts([model.proto])[0] model.proto.name = "test_prefix:model2" - ma2_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma2_id = mr_api._store.store.put_artifacts([model.proto])[0] mlmd_mas = mr_api.get_model_artifacts() assert len(mlmd_mas) == 2 @@ -330,17 +328,17 @@ def test_get_all_model_artifacts(mr_api: ModelRegistryAPIClient, model: Mapped): def test_get_model_artifacts_by_mv_id( mr_api: ModelRegistryAPIClient, model: Mapped, model_version: Mapped ): - mv1_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv1_id = mr_api._store.store.put_contexts([model_version.proto])[0] model_version.proto.name = "version2" - mv2_id = mr_api._store._mlmd_store.put_contexts([model_version.proto])[0] + mv2_id = mr_api._store.store.put_contexts([model_version.proto])[0] model.proto.name = "test_prefix:model1" - ma1_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma1_id = mr_api._store.store.put_artifacts([model.proto])[0] model.proto.name = "test_prefix:model2" - ma2_id = mr_api._store._mlmd_store.put_artifacts([model.proto])[0] + ma2_id = mr_api._store.store.put_artifacts([model.proto])[0] - mr_api._store._mlmd_store.put_attributions_and_associations( + mr_api._store.store.put_attributions_and_associations( [ Attribution(context_id=mv1_id, artifact_id=ma1_id), Attribution(context_id=mv2_id, artifact_id=ma2_id),