Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingestion/glue): Implement unified AWS Glue table representation #12335

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@
"types-click==0.1.12",
# The boto3-stubs package seems to have regularly breaking minor releases,
# we pin to a specific version to avoid this.
"boto3-stubs[s3,glue,sagemaker,sts,dynamodb]==1.28.15",
"boto3-stubs[s3,glue,sagemaker,sts,dynamodb,athena]==1.28.15",
"mypy-boto3-sagemaker==1.28.15", # For some reason, above pin only restricts `mypy-boto3-sagemaker<1.29.0,>=1.28.0`
"types-tabulate",
# avrogen package requires this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from mypy_boto3_athena import AthenaClient
from mypy_boto3_dynamodb import DynamoDBClient
from mypy_boto3_glue import GlueClient
from mypy_boto3_s3 import S3Client, S3ServiceResource
Expand Down Expand Up @@ -445,6 +446,9 @@
resource.meta.client.meta.events.unregister("before-sign.s3", fix_s3_host)
return resource

def get_athena_client(self) -> "AthenaClient":
    """Return a boto3 Athena client bound to this connection's session and config."""
    session = self.get_session()
    return session.client("athena", config=self._aws_config())

Check warning on line 450 in metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py#L450

Added line #L450 was not covered by tests

def get_glue_client(self) -> "GlueClient":
    """Return a boto3 Glue client bound to this connection's session and config."""
    session = self.get_session()
    return session.client("glue", config=self._aws_config())

Expand All @@ -454,6 +458,9 @@
def get_sagemaker_client(self) -> "SageMakerClient":
    """Return a boto3 SageMaker client bound to this connection's session and config."""
    session = self.get_session()
    return session.client("sagemaker", config=self._aws_config())

def get_sts_client(self) -> "STSClient":
    """Return a boto3 STS client bound to this connection's session and config."""
    session = self.get_session()
    return session.client("sts", config=self._aws_config())


class AwsSourceConfig(EnvConfigMixin, AwsConnectionConfig):
"""
Expand Down
81 changes: 72 additions & 9 deletions metadata-ingestion/src/datahub/ingestion/source/aws/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import botocore.exceptions
import yaml
from mypy_boto3_glue.type_defs import DatabasePaginatorTypeDef, TablePaginatorTypeDef
from pydantic import validator
from pydantic.fields import Field

Expand Down Expand Up @@ -117,6 +118,7 @@
logger = logging.getLogger(__name__)

DEFAULT_PLATFORM = "glue"
DEFAULT_CATALOG_NAME = "awsdatacatalog"
VALID_PLATFORMS = [DEFAULT_PLATFORM, "athena"]


Expand Down Expand Up @@ -155,6 +157,9 @@
default=None,
description="The aws account id where the target glue catalog lives. If None, datahub will ingest glue in aws caller's account.",
)
athena_catalog_name: Optional[str] = Field(
default=None, description="The aws athena catalog name"
)
ignore_resource_links: Optional[bool] = Field(
default=False,
description="If set to True, ignore database resource links.",
Expand Down Expand Up @@ -198,6 +203,14 @@
def s3_client(self):
return self.get_s3_client()

@property
def athena_client(self):
    """Athena client built from this source's AWS connection settings."""
    return self.get_athena_client()

Check warning on line 208 in metadata-ingestion/src/datahub/ingestion/source/aws/glue.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/aws/glue.py#L208

Added line #L208 was not covered by tests

@property
def sts_client(self):
    """STS client built from this source's AWS connection settings."""
    return self.get_sts_client()

@validator("glue_s3_lineage_direction")
def check_direction(cls, v: str) -> str:
if v.lower() not in ["upstream", "downstream"]:
Expand All @@ -215,6 +228,45 @@
f"'platform' can only take following values: {VALID_PLATFORMS}"
)

def __init__(self, **data: Any):
    """Run standard pydantic validation, then derive the Athena catalog name.

    The derivation must happen after field validation because it reads
    `platform` and `catalog_id` and may call AWS (STS/Athena) to resolve
    or validate `athena_catalog_name`.
    """
    super().__init__(**data)
    self._set_athena_catalog_name()

def _set_athena_catalog_name(self) -> None:
    """Resolve the effective Athena catalog name, or raise on misconfiguration.

    Rules:
      * platform != "athena"  -> athena_catalog_name is forced to None.
      * no catalog_id         -> default catalog ("awsdatacatalog").
      * catalog_id == caller's own AWS account -> default catalog.
      * catalog_id in another account -> the user-supplied
        athena_catalog_name is validated against that catalog_id.
    """
    if self.platform != "athena":
        self.athena_catalog_name = None
        return

    if not self.catalog_id:
        self.athena_catalog_name = DEFAULT_CATALOG_NAME
        return

    # Compare the configured catalog account with the caller's account to
    # decide between the implicit default catalog and a cross-account one.
    caller_account = self.sts_client.get_caller_identity().get("Account")
    if self.catalog_id == caller_account:
        self.athena_catalog_name = DEFAULT_CATALOG_NAME
    else:
        self._validate_athena_catalog_name()

def _validate_athena_catalog_name(self) -> None:
    """Validate that `athena_catalog_name` maps to the configured `catalog_id`.

    Guards against pointing the source at an `athena_catalog_name` that does
    not exist (or belongs to a different account) in the target AWS account.

    Raises:
        ValueError: if the catalog's registered account id does not match
            the configured `catalog_id`.
    """
    # NOTE(review): if athena_catalog_name was never configured this passes
    # Name=None to the Athena API — presumably surfaced as a boto3 error;
    # confirm upstream config validation covers that case.
    data_catalog = self.athena_client.get_data_catalog(
        Name=self.athena_catalog_name
    )["DataCatalog"]
    parameters = data_catalog.get("Parameters", {})
    effective_catalog_id = parameters.get("catalog-id", "")
    if effective_catalog_id != self.catalog_id:
        raise ValueError(
            f"Catalog configuration mismatch for catalog name {self.athena_catalog_name}."
            f"Effective catalog_id: {effective_catalog_id}, configured catalog_id: {self.catalog_id}."
        )


@dataclass
class GlueSourceReport(StaleEntityRemovalSourceReport):
Expand Down Expand Up @@ -443,6 +495,14 @@

yield s3_uri, extension

def _gen_full_table_name(self, database_name: str, table_name: str) -> str:
    """Build the dotted table identifier, prefixed with the Athena catalog when set.

    Yields "catalog.db.table" for athena-platform configs and "db.table"
    otherwise (athena_catalog_name is None for the glue platform).
    """
    parts = [self.source_config.athena_catalog_name, database_name, table_name]
    return ".".join(part for part in parts if part)

def process_dataflow_node(
self,
node: Dict[str, Any],
Expand All @@ -459,8 +519,9 @@

# if data object is Glue table
if "database" in node_args and "table_name" in node_args:
full_table_name = f"{node_args['database']}.{node_args['table_name']}"

full_table_name = self._gen_full_table_name(
node_args["database"], node_args["table_name"]
)
# we know that the table will already be covered when ingesting Glue tables
node_urn = make_dataset_urn_with_platform_instance(
platform=self.platform,
Expand Down Expand Up @@ -681,7 +742,7 @@

return MetadataWorkUnit(id=f'{job_name}-{node["Id"]}', mce=mce)

def get_all_databases(self) -> Iterable[Mapping[str, Any]]:
def get_all_databases(self) -> Iterable[DatabasePaginatorTypeDef]:
logger.debug("Getting all databases")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue/paginator/GetDatabases.html
paginator = self.glue_client.get_paginator("get_databases")
Expand Down Expand Up @@ -709,7 +770,9 @@
self.report.databases.processed(database["Name"])
yield database

def get_tables_from_database(self, database: Mapping[str, Any]) -> Iterable[Dict]:
def get_tables_from_database(
self, database: DatabasePaginatorTypeDef
) -> Iterable[TablePaginatorTypeDef]:
logger.debug(f"Getting tables from database {database['Name']}")
# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue/paginator/GetTables.html
paginator = self.glue_client.get_paginator("get_tables")
Expand All @@ -736,7 +799,7 @@

def get_all_databases_and_tables(
self,
) -> Tuple[List[Mapping[str, Any]], List[Dict]]:
) -> Tuple[List[DatabasePaginatorTypeDef], List[TablePaginatorTypeDef]]:
all_databases = [*self.get_all_databases()]
all_tables = [
tables
Expand Down Expand Up @@ -1004,7 +1067,7 @@
)

def gen_database_containers(
self, database: Mapping[str, Any]
self, database: DatabasePaginatorTypeDef
) -> Iterable[MetadataWorkUnit]:
domain_urn = self._gen_domain_urn(database["Name"])
database_container_key = self.gen_database_key(database["Name"])
Expand Down Expand Up @@ -1079,10 +1142,10 @@
if self.extract_transforms:
yield from self._transform_extraction()

def _gen_table_wu(self, table: Dict) -> Iterable[MetadataWorkUnit]:
def _gen_table_wu(self, table: TablePaginatorTypeDef) -> Iterable[MetadataWorkUnit]:
database_name = table["DatabaseName"]
table_name = table["Name"]
full_table_name = f"{database_name}.{table_name}"
full_table_name = self._gen_full_table_name(database_name, table_name)
self.report.report_table_scanned()
if not self.source_config.database_pattern.allowed(
database_name
Expand All @@ -1097,7 +1160,7 @@
platform_instance=self.source_config.platform_instance,
)

mce = self._extract_record(dataset_urn, table, full_table_name)
mce = self._extract_record(dataset_urn, dict(table), full_table_name)
yield MetadataWorkUnit(full_table_name, mce=mce)

# We also want to assign "table" subType to the dataset representing glue table - unfortunately it is not
Expand Down
106 changes: 71 additions & 35 deletions metadata-ingestion/tests/unit/glue/test_glue_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
import pytest
from botocore.stub import Stubber
from freezegun import freeze_time
from moto import mock_athena, mock_sts
from moto.core import DEFAULT_ACCOUNT_ID

import datahub.metadata.schema_classes as models
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.sink.file import write_metadata_file
from datahub.ingestion.source.aws.glue import (
DEFAULT_CATALOG_NAME,
GlueProfilingConfig,
GlueSource,
GlueSourceConfig,
Expand All @@ -35,7 +38,6 @@
validate_all_providers_have_committed_successfully,
)
from tests.unit.glue.test_glue_source_stubs import (
empty_database,
flights_database,
get_bucket_tagging,
get_databases_delta_response,
Expand Down Expand Up @@ -311,40 +313,6 @@ def test_config_without_platform():
assert source.platform == "glue"


def test_get_databases_filters_by_catalog():
    """get_all_databases() drops databases whose CatalogId differs from the configured one."""

    def format_databases(databases):
        # Compare by name only; the full database dicts carry extra metadata.
        return set(d["Name"] for d in databases)

    # Without a catalog_id, every database in the stubbed response is kept.
    all_catalogs_source: GlueSource = GlueSource(
        config=GlueSourceConfig(aws_region="us-west-2"),
        ctx=PipelineContext(run_id="glue-source-test"),
    )
    with Stubber(all_catalogs_source.glue_client) as glue_stubber:
        glue_stubber.add_response("get_databases", get_databases_response, {})

        expected = [flights_database, test_database, empty_database]
        actual = all_catalogs_source.get_all_databases()
        assert format_databases(actual) == format_databases(expected)
        assert all_catalogs_source.report.databases.dropped_entities.as_obj() == []

    # With a catalog_id, databases from other catalogs are filtered out and
    # reported as dropped.
    catalog_id = "123412341234"
    single_catalog_source: GlueSource = GlueSource(
        config=GlueSourceConfig(catalog_id=catalog_id, aws_region="us-west-2"),
        ctx=PipelineContext(run_id="glue-source-test"),
    )
    with Stubber(single_catalog_source.glue_client) as glue_stubber:
        glue_stubber.add_response(
            "get_databases", get_databases_response, {"CatalogId": catalog_id}
        )

        expected = [flights_database, test_database]
        actual = single_catalog_source.get_all_databases()
        assert format_databases(actual) == format_databases(expected)
        assert single_catalog_source.report.databases.dropped_entities.as_obj() == [
            "empty-database"
        ]


@freeze_time(FROZEN_TIME)
def test_glue_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
deleted_actor_golden_mcs = "{}/glue_deleted_actor_mces_golden.json".format(
Expand Down Expand Up @@ -750,3 +718,71 @@ def test_glue_ingest_with_profiling(
output_path=tmp_path / mce_file,
golden_path=test_resources_dir / mce_golden_file,
)


@mock_athena
@mock_sts
@pytest.mark.parametrize(
    ("platform", "catalog_id", "expected"),
    [
        # athena platform: no catalog_id, or catalog_id equal to the caller's
        # (moto default) account, both resolve to the default catalog name.
        ("athena", None, DEFAULT_CATALOG_NAME),
        ("athena", DEFAULT_ACCOUNT_ID, DEFAULT_CATALOG_NAME),
        # glue platform never sets an Athena catalog name.
        ("glue", None, None),
        ("glue", DEFAULT_ACCOUNT_ID, None),
    ],
)
def test_athena_catalog_name(
    platform: str, catalog_id: Optional[str], expected: Optional[str]
) -> None:
    """athena_catalog_name is derived from platform/catalog_id at config init time."""
    pipeline_context = PipelineContext(run_id="glue-source-test")
    source = GlueSource(
        ctx=pipeline_context,
        config=GlueSourceConfig(
            aws_region="us-west-2",
            platform=platform,
            catalog_id=catalog_id,
        ),
    )
    assert source.source_config.athena_catalog_name == expected


@mock_athena
@mock_sts
@pytest.mark.parametrize(
    ("platform", "catalog_id", "database", "table", "expected"),
    [
        # athena platform prefixes names with the resolved catalog.
        (
            "athena",
            None,
            "test_db",
            "test_table",
            f"{DEFAULT_CATALOG_NAME}.test_db.test_table",
        ),
        (
            "athena",
            DEFAULT_ACCOUNT_ID,
            "test_db",
            "test_table",
            f"{DEFAULT_CATALOG_NAME}.test_db.test_table",
        ),
        # glue platform keeps the legacy "db.table" form.
        ("glue", None, "test_db", "test_table", "test_db.test_table"),
        ("glue", DEFAULT_ACCOUNT_ID, "test_db", "test_table", "test_db.test_table"),
    ],
)
def test_gen_full_table_name(
    platform: str,
    catalog_id: Optional[str],
    database: str,
    table: str,
    expected: Optional[str],
) -> None:
    """_gen_full_table_name prefixes the Athena catalog only for the athena platform."""
    pipeline_context = PipelineContext(run_id="glue-source-test")
    source = GlueSource(
        ctx=pipeline_context,
        config=GlueSourceConfig(
            aws_region="us-west-2",
            platform=platform,
            catalog_id=catalog_id,
        ),
    )
    assert source._gen_full_table_name(database, table) == expected
Loading