Skip to content

Commit

Permalink
Merge branch 'master' into postgres-for-update-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
david-leifker authored Dec 19, 2024
2 parents 8fb6594 + 4392d72 commit 4792a15
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
PropertyValueClass,
StructuredPropertyDefinitionClass,
)
from datahub.metadata.urns import StructuredPropertyUrn, Urn
from datahub.metadata.urns import DataTypeUrn, StructuredPropertyUrn, Urn
from datahub.utilities.urns._urn_base import URN_TYPES

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -86,19 +86,31 @@ class StructuredProperties(ConfigModel):

@validator("type")
def validate_type(cls, v: str) -> str:
# Convert to lowercase if needed
if not v.islower():
# This logic is somewhat hacky, since we need to deal with
# 1. fully qualified urns
# 2. raw data types, that need to get the datahub namespace prefix
# While keeping the user-facing interface and error messages clean.

if not v.startswith("urn:li:") and not v.islower():
# Convert to lowercase if needed
v = v.lower()
logger.warning(
f"Structured property type should be lowercase. Updated to {v.lower()}"
f"Structured property type should be lowercase. Updated to {v}"
)
v = v.lower()

urn = Urn.make_data_type_urn(v)

# Check if type is allowed
if not AllowedTypes.check_allowed_type(v):
data_type_urn = DataTypeUrn.from_string(urn)
unqualified_data_type = data_type_urn.id
if unqualified_data_type.startswith("datahub."):
unqualified_data_type = unqualified_data_type[len("datahub.") :]
if not AllowedTypes.check_allowed_type(unqualified_data_type):
raise ValueError(
f"Type {v} is not allowed. Allowed types are {AllowedTypes.values()}"
f"Type {unqualified_data_type} is not allowed. Allowed types are {AllowedTypes.values()}"
)
return v

return urn

@property
def fqn(self) -> str:
Expand Down
35 changes: 30 additions & 5 deletions metadata-ingestion/src/datahub/ingestion/source/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,30 @@
class MLflowConfig(EnvConfigMixin):
tracking_uri: Optional[str] = Field(
default=None,
description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)",
description=(
"Tracking server URI. If not set, an MLflow default tracking_uri is used"
" (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)"
),
)
registry_uri: Optional[str] = Field(
default=None,
description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)",
description=(
"Registry server URI. If not set, an MLflow default registry_uri is used"
" (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)"
),
)
model_name_separator: str = Field(
default="_",
description="A string which separates model name from its version (e.g. model_1 or model-1)",
)
base_external_url: Optional[str] = Field(
default=None,
description=(
"Base URL to use when constructing external URLs to MLflow."
" If not set, tracking_uri is used if it's an HTTP URL."
" If neither is set, external URLs are not generated."
),
)


@dataclass
Expand Down Expand Up @@ -279,12 +293,23 @@ def _make_ml_model_urn(self, model_version: ModelVersion) -> str:
)
return urn

def _make_external_url(self, model_version: ModelVersion) -> Union[None, str]:
def _get_base_external_url_from_tracking_uri(self) -> Optional[str]:
if isinstance(
self.client.tracking_uri, str
) and self.client.tracking_uri.startswith("http"):
return self.client.tracking_uri
else:
return None

def _make_external_url(self, model_version: ModelVersion) -> Optional[str]:
"""
Generate URL for a Model Version to MLflow UI.
"""
base_uri = self.client.tracking_uri
if base_uri.startswith("http"):
base_uri = (
self.config.base_external_url
or self._get_base_external_url_from_tracking_uri()
)
if base_uri:
return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}"
else:
return None
Expand Down
13 changes: 13 additions & 0 deletions metadata-ingestion/tests/unit/test_mlflow_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,16 @@ def test_make_external_link_remote(source, model_version):
url = source._make_external_url(model_version)

assert url == expected_url


def test_make_external_link_remote_via_config(source, model_version):
    """The configured base_external_url is used even when the client has an HTTP tracking URI."""
    override_base = "https://custom-server.org"
    # Give the client a different HTTP tracking URI so the assertion can
    # only pass if the configured base URL is the one that gets used.
    source.client = MlflowClient(
        tracking_uri="https://dummy-mlflow-tracking-server.org"
    )
    source.config.base_external_url = override_base
    wanted = f"{override_base}/#/models/{model_version.name}/versions/{model_version.version}"

    actual = source._make_external_url(model_version)

    assert actual == wanted

0 comments on commit 4792a15

Please sign in to comment.