Skip to content

Commit

Permalink
Merge branch 'master' into web-react-coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
chakru-r authored Dec 27, 2024
2 parents dc0acd2 + 3ca8d09 commit 03e21ba
Show file tree
Hide file tree
Showing 14 changed files with 220 additions and 205 deletions.
1 change: 1 addition & 0 deletions docs/how/updating-datahub.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ This file documents any backwards-incompatible changes in DataHub and assists pe
- #12077: `Kafka` source no longer ingests schemas from schema registry as separate entities by default, set `ingest_schemas_as_entities` to `true` to ingest them
- OpenAPI Update: PIT Keep Alive parameter added to scroll. NOTE: This parameter requires the `pointInTimeCreationEnabled` feature flag to be enabled and the `elasticSearch.implementation` configuration to be `elasticsearch`. This feature is not supported for OpenSearch at this time and the parameter will not be respected without both of these set.
- OpenAPI Update 2: Previously there was an incorrectly marked parameter named `sort` on the generic list entities endpoint for v3. This parameter is deprecated and only supports a single string value while the documentation indicates it supports a list of strings. This documentation error has been fixed and the correct field, `sortCriteria`, is now documented which supports a list of strings.
- #12223: For dbt Cloud ingestion, the "View in dbt" link will point at the "Explore" page in the dbt Cloud UI. You can revert to the old behavior of linking to the dbt Cloud IDE by setting `external_url_mode: ide".

### Breaking Changes

Expand Down
13 changes: 10 additions & 3 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from datetime import datetime
from json import JSONDecodeError
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Literal, Optional, Tuple
from urllib.parse import urlparse

import dateutil.parser
Expand Down Expand Up @@ -62,6 +62,11 @@ class DBTCloudConfig(DBTCommonConfig):
description="The ID of the run to ingest metadata from. If not specified, we'll default to the latest run.",
)

external_url_mode: Literal["explore", "ide"] = Field(
default="explore",
description='Where should the "View in dbt" link point to - either the "Explore" UI or the dbt Cloud IDE',
)

@root_validator(pre=True)
def set_metadata_endpoint(cls, values: dict) -> dict:
if values.get("access_url") and not values.get("metadata_endpoint"):
Expand Down Expand Up @@ -527,5 +532,7 @@ def _parse_into_dbt_column(
)

def get_external_url(self, node: DBTNode) -> Optional[str]:
# TODO: Once dbt Cloud supports deep linking to specific files, we can use that.
return f"{self.config.access_url}/develop/{self.config.account_id}/projects/{self.config.project_id}"
if self.config.external_url_mode == "explore":
return f"{self.config.access_url}/explore/{self.config.account_id}/projects/{self.config.project_id}/environments/production/details/{node.dbt_name}"
else:
return f"{self.config.access_url}/develop/{self.config.account_id}/projects/{self.config.project_id}"
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,16 @@ def resolve_includes(
f"traversal_path={traversal_path}, included_files = {included_files}, seen_so_far: {seen_so_far}"
)
if "*" not in inc and not included_files:
reporter.report_failure(
reporter.warning(
title="Error Resolving Include",
message=f"Cannot resolve include {inc}",
context=f"Path: {path}",
message="Cannot resolve included file",
context=f"Include: {inc}, path: {path}, traversal_path: {traversal_path}",
)
elif not included_files:
reporter.report_failure(
reporter.warning(
title="Error Resolving Include",
message=f"Did not resolve anything for wildcard include {inc}",
context=f"Path: {path}",
message="Did not find anything matching the wildcard include",
context=f"Include: {inc}, path: {path}, traversal_path: {traversal_path}",
)
# only load files that we haven't seen so far
included_files = [x for x in included_files if x not in seen_so_far]
Expand Down Expand Up @@ -231,9 +231,7 @@ def resolve_includes(
source_config,
reporter,
seen_so_far,
traversal_path=traversal_path
+ "."
+ pathlib.Path(included_file).stem,
traversal_path=f"{traversal_path} -> {pathlib.Path(included_file).stem}",
)
)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ def __init__(
tenant_id: str,
metadata_api_timeout: int,
):
self.__access_token: Optional[str] = None
self.__access_token_expiry_time: Optional[datetime] = None
self.__tenant_id = tenant_id
self._access_token: Optional[str] = None
self._access_token_expiry_time: Optional[datetime] = None

self._tenant_id = tenant_id
# Test connection by generating access token
logger.info(f"Trying to connect to {self._get_authority_url()}")
# Power-Bi Auth (Service Principal Auth)
self.__msal_client = msal.ConfidentialClientApplication(
self._msal_client = msal.ConfidentialClientApplication(
client_id,
client_credential=client_secret,
authority=DataResolverBase.AUTHORITY + tenant_id,
Expand Down Expand Up @@ -168,18 +169,18 @@ def _get_app(
pass

def _get_authority_url(self):
return f"{DataResolverBase.AUTHORITY}{self.__tenant_id}"
return f"{DataResolverBase.AUTHORITY}{self._tenant_id}"

def get_authorization_header(self):
return {Constant.Authorization: self.get_access_token()}

def get_access_token(self):
if self.__access_token is not None and not self._is_access_token_expired():
return self.__access_token
def get_access_token(self) -> str:
if self._access_token is not None and not self._is_access_token_expired():
return self._access_token

logger.info("Generating PowerBi access token")

auth_response = self.__msal_client.acquire_token_for_client(
auth_response = self._msal_client.acquire_token_for_client(
scopes=[DataResolverBase.SCOPE]
)

Expand All @@ -193,24 +194,24 @@ def get_access_token(self):

logger.info("Generated PowerBi access token")

self.__access_token = "Bearer {}".format(
self._access_token = "Bearer {}".format(
auth_response.get(Constant.ACCESS_TOKEN)
)
safety_gap = 300
self.__access_token_expiry_time = datetime.now() + timedelta(
self._access_token_expiry_time = datetime.now() + timedelta(
seconds=(
max(auth_response.get(Constant.ACCESS_TOKEN_EXPIRY, 0) - safety_gap, 0)
)
)

logger.debug(f"{Constant.PBIAccessToken}={self.__access_token}")
logger.debug(f"{Constant.PBIAccessToken}={self._access_token}")

return self.__access_token
return self._access_token

def _is_access_token_expired(self) -> bool:
if not self.__access_token_expiry_time:
if not self._access_token_expiry_time:
return True
return self.__access_token_expiry_time < datetime.now()
return self._access_token_expiry_time < datetime.now()

def get_dashboards(self, workspace: Workspace) -> List[Dashboard]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,20 @@ class SnowflakeIdentifierConfig(
description="Whether to convert dataset urns to lowercase.",
)


class SnowflakeUsageConfig(BaseUsageConfig):
email_domain: Optional[str] = pydantic.Field(
default=None,
description="Email domain of your organization so users can be displayed on UI appropriately.",
)

email_as_user_identifier: bool = Field(
default=True,
description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is "
"provided, generates email addresses for snowflake users with unset emails, based on their "
"username.",
)


class SnowflakeUsageConfig(BaseUsageConfig):
apply_view_usage_to_tables: bool = pydantic.Field(
default=False,
description="Whether to apply view's usage to its base tables. If set to True, usage is applied to base tables only.",
Expand Down Expand Up @@ -267,13 +275,6 @@ class SnowflakeV2Config(
" Map of share name -> details of share.",
)

email_as_user_identifier: bool = Field(
default=True,
description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is "
"provided, generates email addresses for snowflake users with unset emails, based on their "
"username.",
)

include_assertion_results: bool = Field(
default=False,
description="Whether to ingest assertion run results for assertions created using Datahub"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@

logger = logging.getLogger(__name__)

# Define a type alias
UserName = str
UserEmail = str
UsersMapping = Dict[UserName, UserEmail]


class SnowflakeQueriesExtractorConfig(ConfigModel):
# TODO: Support stateful ingestion for the time windows.
Expand Down Expand Up @@ -114,11 +119,13 @@ class SnowflakeQueriesSourceConfig(
class SnowflakeQueriesExtractorReport(Report):
copy_history_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
query_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
users_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)

audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
sql_aggregator: Optional[SqlAggregatorReport] = None

num_ddl_queries_dropped: int = 0
num_users: int = 0


@dataclass
Expand Down Expand Up @@ -225,6 +232,9 @@ def is_allowed_table(self, name: str) -> bool:
def get_workunits_internal(
self,
) -> Iterable[MetadataWorkUnit]:
with self.report.users_fetch_timer:
users = self.fetch_users()

# TODO: Add some logic to check if the cached audit log is stale or not.
audit_log_file = self.local_temp_path / "audit_log.sqlite"
use_cached_audit_log = audit_log_file.exists()
Expand All @@ -248,7 +258,7 @@ def get_workunits_internal(
queries.append(entry)

with self.report.query_log_fetch_timer:
for entry in self.fetch_query_log():
for entry in self.fetch_query_log(users):
queries.append(entry)

with self.report.audit_log_load_timer:
Expand All @@ -263,6 +273,25 @@ def get_workunits_internal(
shared_connection.close()
audit_log_file.unlink(missing_ok=True)

def fetch_users(self) -> UsersMapping:
users: UsersMapping = dict()
with self.structured_reporter.report_exc("Error fetching users from Snowflake"):
logger.info("Fetching users from Snowflake")
query = SnowflakeQuery.get_all_users()
resp = self.connection.query(query)

for row in resp:
try:
users[row["NAME"]] = row["EMAIL"]
self.report.num_users += 1
except Exception as e:
self.structured_reporter.warning(
"Error parsing user row",
context=f"{row}",
exc=e,
)
return users

def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
# Derived from _populate_external_lineage_from_copy_history.

Expand Down Expand Up @@ -298,7 +327,7 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
yield result

def fetch_query_log(
self,
self, users: UsersMapping
) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap]]:
query_log_query = _build_enriched_query_log_query(
start_time=self.config.window.start_time,
Expand All @@ -319,7 +348,7 @@ def fetch_query_log(

assert isinstance(row, dict)
try:
entry = self._parse_audit_log_row(row)
entry = self._parse_audit_log_row(row, users)
except Exception as e:
self.structured_reporter.warning(
"Error parsing query log row",
Expand All @@ -331,7 +360,7 @@ def fetch_query_log(
yield entry

def _parse_audit_log_row(
self, row: Dict[str, Any]
self, row: Dict[str, Any], users: UsersMapping
) -> Optional[Union[TableRename, TableSwap, PreparsedQuery]]:
json_fields = {
"DIRECT_OBJECTS_ACCESSED",
Expand Down Expand Up @@ -430,9 +459,11 @@ def _parse_audit_log_row(
)
)

# TODO: Fetch email addresses from Snowflake to map user -> email
# TODO: Support email_domain fallback for generating user urns.
user = CorpUserUrn(self.identifiers.snowflake_identifier(res["user_name"]))
user = CorpUserUrn(
self.identifiers.get_user_identifier(
res["user_name"], users.get(res["user_name"])
)
)

timestamp: datetime = res["query_start_time"]
timestamp = timestamp.astimezone(timezone.utc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -947,4 +947,8 @@ def dmf_assertion_results(start_time_millis: int, end_time_millis: int) -> str:
AND METRIC_NAME ilike '{pattern}' escape '{escape_pattern}'
ORDER BY MEASUREMENT_TIME ASC;
"""
"""

@staticmethod
def get_all_users() -> str:
return """SELECT name as "NAME", email as "EMAIL" FROM SNOWFLAKE.ACCOUNT_USAGE.USERS"""
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,9 @@ def _map_user_counts(
filtered_user_counts.append(
DatasetUserUsageCounts(
user=make_user_urn(
self.get_user_identifier(
self.identifiers.get_user_identifier(
user_count["user_name"],
user_email,
self.config.email_as_user_identifier,
)
),
count=user_count["total"],
Expand Down Expand Up @@ -453,9 +452,7 @@ def _get_operation_aspect_work_unit(
reported_time: int = int(time.time() * 1000)
last_updated_timestamp: int = int(start_time.timestamp() * 1000)
user_urn = make_user_urn(
self.get_user_identifier(
user_name, user_email, self.config.email_as_user_identifier
)
self.identifiers.get_user_identifier(user_name, user_email)
)

# NOTE: In earlier `snowflake-usage` connector this was base_objects_accessed, which is incorrect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,28 @@ def get_quoted_identifier_for_schema(db_name, schema_name):
def get_quoted_identifier_for_table(db_name, schema_name, table_name):
return f'"{db_name}"."{schema_name}"."{table_name}"'

# Note - decide how to construct user urns.
# Historically urns were created using part before @ from user's email.
# Users without email were skipped from both user entries as well as aggregates.
# However email is not mandatory field in snowflake user, user_name is always present.
def get_user_identifier(
self,
user_name: str,
user_email: Optional[str],
) -> str:
if user_email:
return self.snowflake_identifier(
user_email
if self.identifier_config.email_as_user_identifier is True
else user_email.split("@")[0]
)
return self.snowflake_identifier(
f"{user_name}@{self.identifier_config.email_domain}"
if self.identifier_config.email_as_user_identifier is True
and self.identifier_config.email_domain is not None
else user_name
)


class SnowflakeCommonMixin(SnowflakeStructuredReportMixin):
platform = "snowflake"
Expand All @@ -315,24 +337,6 @@ def structured_reporter(self) -> SourceReport:
def identifiers(self) -> SnowflakeIdentifierBuilder:
return SnowflakeIdentifierBuilder(self.config, self.report)

# Note - decide how to construct user urns.
# Historically urns were created using part before @ from user's email.
# Users without email were skipped from both user entries as well as aggregates.
# However email is not mandatory field in snowflake user, user_name is always present.
def get_user_identifier(
self,
user_name: str,
user_email: Optional[str],
email_as_user_identifier: bool,
) -> str:
if user_email:
return self.identifiers.snowflake_identifier(
user_email
if email_as_user_identifier is True
else user_email.split("@")[0]
)
return self.identifiers.snowflake_identifier(user_name)

# TODO: Revisit this after stateful ingestion can commit checkpoint
# for failures that do not affect the checkpoint
# TODO: Add additional parameters to match the signature of the .warning and .failure methods
Expand Down
Loading

0 comments on commit 03e21ba

Please sign in to comment.