diff --git a/.github/workflows/docker-unified.yml b/.github/workflows/docker-unified.yml index 03a9b3afc3bc5..47c26068347c0 100644 --- a/.github/workflows/docker-unified.yml +++ b/.github/workflows/docker-unified.yml @@ -1011,18 +1011,39 @@ jobs: needs: setup outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} + cypress_batch_count: ${{ steps.set-batch-count.outputs.cypress_batch_count }} + python_batch_count: ${{ steps.set-batch-count.outputs.python_batch_count }} steps: + - id: set-batch-count + # Tests are split simply to ensure the configured number of batches for parallelization. These counts may need to + # be increased as new tests are added and the run duration grows to where an additional parallel batch helps. + # python_batch_count is used to split pytests in the smoke-test (batches of actual test functions) + # cypress_batch_count is used to split the collection of cypress test specs into batches. + run: | + echo "cypress_batch_count=11" >> "$GITHUB_OUTPUT" + echo "python_batch_count=5" >> "$GITHUB_OUTPUT" + - id: set-matrix + # For m batches for python and n batches for cypress, we need a test matrix of python x m + cypress x n. + # While the GitHub Actions matrix generation can handle these two parts individually, there isn't a way to use the + # two generated matrices for the same job. So, produce that matrix with scripting and use the include directive + # to add it to the test matrix. 
run: | - if [ '${{ needs.setup.outputs.frontend_only }}' == 'true' ]; then - echo 'matrix=["cypress_suite1","cypress_rest"]' >> "$GITHUB_OUTPUT" - elif [ '${{ needs.setup.outputs.ingestion_only }}' == 'true' ]; then - echo 'matrix=["no_cypress_suite0","no_cypress_suite1"]' >> "$GITHUB_OUTPUT" - elif [[ '${{ needs.setup.outputs.backend_change }}' == 'true' || '${{ needs.setup.outputs.smoke_test_change }}' == 'true' ]]; then - echo 'matrix=["no_cypress_suite0","no_cypress_suite1","cypress_suite1","cypress_rest"]' >> "$GITHUB_OUTPUT" - else - echo 'matrix=[]' >> "$GITHUB_OUTPUT" + python_batch_count=${{ steps.set-batch-count.outputs.python_batch_count }} + python_matrix=$(printf "{\"test_strategy\":\"pytests\",\"batch\":\"0\",\"batch_count\":\"$python_batch_count\"}"; for ((i=1;i> "$GITHUB_OUTPUT" smoke_test: name: Run Smoke Tests @@ -1043,8 +1064,7 @@ jobs: ] strategy: fail-fast: false - matrix: - test_strategy: ${{ fromJson(needs.smoke_test_matrix.outputs.matrix) }} + matrix: ${{ fromJson(needs.smoke_test_matrix.outputs.matrix) }} if: ${{ always() && !failure() && !cancelled() && needs.smoke_test_matrix.outputs.matrix != '[]' }} steps: - name: Free up disk space @@ -1220,6 +1240,8 @@ jobs: CYPRESS_RECORD_KEY: ${{ secrets.CYPRESS_RECORD_KEY }} CLEANUP_DATA: "false" TEST_STRATEGY: ${{ matrix.test_strategy }} + BATCH_COUNT: ${{ matrix.batch_count }} + BATCH_NUMBER: ${{ matrix.batch }} run: | echo "$DATAHUB_VERSION" ./gradlew --stop @@ -1230,25 +1252,25 @@ jobs: if: failure() run: | docker ps -a - TEST_STRATEGY="-${{ matrix.test_strategy }}" + TEST_STRATEGY="-${{ matrix.test_strategy }}-${{ matrix.batch }}" source .github/scripts/docker_logs.sh - name: Upload logs uses: actions/upload-artifact@v3 if: failure() with: - name: docker-logs-${{ matrix.test_strategy }} + name: docker-logs-${{ matrix.test_strategy }}-${{ matrix.batch }} path: "docker_logs/*.log" retention-days: 5 - name: Upload screenshots uses: actions/upload-artifact@v3 if: failure() with: - name: 
cypress-snapshots-${{ matrix.test_strategy }} + name: cypress-snapshots-${{ matrix.test_strategy }}-${{ matrix.batch }} path: smoke-test/tests/cypress/cypress/screenshots/ - uses: actions/upload-artifact@v3 if: always() with: - name: Test Results (smoke tests) ${{ matrix.test_strategy }} + name: Test Results (smoke tests) ${{ matrix.test_strategy }} ${{ matrix.batch }} path: | **/build/reports/tests/test/** **/build/test-results/test/** diff --git a/.github/workflows/qodana-scan.yml b/.github/workflows/qodana-scan.yml deleted file mode 100644 index 750cf24ad38e5..0000000000000 --- a/.github/workflows/qodana-scan.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Qodana -on: - workflow_dispatch: - pull_request: - push: - branches: - - master - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - qodana: - runs-on: ubuntu-latest - steps: - - uses: acryldata/sane-checkout-action@v3 - - name: "Qodana Scan" - uses: JetBrains/qodana-action@v2022.3.4 - - uses: github/codeql-action/upload-sarif@v2 - with: - sarif_file: ${{ runner.temp }}/qodana/results/qodana.sarif.json - cache-default-branch-only: true diff --git a/build.gradle b/build.gradle index a3d807a733349..8929b4e644972 100644 --- a/build.gradle +++ b/build.gradle @@ -35,7 +35,7 @@ buildscript { ext.pegasusVersion = '29.57.0' ext.mavenVersion = '3.6.3' ext.versionGradle = '8.11.1' - ext.springVersion = '6.1.13' + ext.springVersion = '6.1.14' ext.springBootVersion = '3.2.9' ext.springKafkaVersion = '3.1.6' ext.openTelemetryVersion = '1.18.0' diff --git a/docs/how/updating-datahub.md b/docs/how/updating-datahub.md index a742ebe0cd896..d6620fde0bf79 100644 --- a/docs/how/updating-datahub.md +++ b/docs/how/updating-datahub.md @@ -42,6 +42,7 @@ This file documents any backwards-incompatible changes in DataHub and assists pe - #12077: `Kafka` source no longer ingests schemas from schema registry as separate entities by default, set 
`ingest_schemas_as_entities` to `true` to ingest them - OpenAPI Update: PIT Keep Alive parameter added to scroll. NOTE: This parameter requires the `pointInTimeCreationEnabled` feature flag to be enabled and the `elasticSearch.implementation` configuration to be `elasticsearch`. This feature is not supported for OpenSearch at this time and the parameter will not be respected without both of these set. - OpenAPI Update 2: Previously there was an incorrectly marked parameter named `sort` on the generic list entities endpoint for v3. This parameter is deprecated and only supports a single string value while the documentation indicates it supports a list of strings. This documentation error has been fixed and the correct field, `sortCriteria`, is now documented which supports a list of strings. +- #12223: For dbt Cloud ingestion, the "View in dbt" link will point at the "Explore" page in the dbt Cloud UI. You can revert to the old behavior of linking to the dbt Cloud IDE by setting `external_url_mode: ide`. ### Breaking Changes diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py index 66c5ef7179af4..5042f6d69b261 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py @@ -1,7 +1,7 @@ import logging from datetime import datetime from json import JSONDecodeError -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Literal, Optional, Tuple from urllib.parse import urlparse import dateutil.parser @@ -62,6 +62,11 @@ class DBTCloudConfig(DBTCommonConfig): description="The ID of the run to ingest metadata from. 
If not specified, we'll default to the latest run.", ) + external_url_mode: Literal["explore", "ide"] = Field( + default="explore", + description='Where should the "View in dbt" link point to - either the "Explore" UI or the dbt Cloud IDE', + ) + @root_validator(pre=True) def set_metadata_endpoint(cls, values: dict) -> dict: if values.get("access_url") and not values.get("metadata_endpoint"): @@ -527,5 +532,7 @@ def _parse_into_dbt_column( ) def get_external_url(self, node: DBTNode) -> Optional[str]: - # TODO: Once dbt Cloud supports deep linking to specific files, we can use that. - return f"{self.config.access_url}/develop/{self.config.account_id}/projects/{self.config.project_id}" + if self.config.external_url_mode == "explore": + return f"{self.config.access_url}/explore/{self.config.account_id}/projects/{self.config.project_id}/environments/production/details/{node.dbt_name}" + else: + return f"{self.config.access_url}/develop/{self.config.account_id}/projects/{self.config.project_id}" diff --git a/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py b/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py index 4eecbb4d9d717..168b787b85e8b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py +++ b/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py @@ -34,6 +34,7 @@ SoftDeletedEntitiesCleanupConfig, SoftDeletedEntitiesReport, ) +from datahub.ingestion.source_report.ingestion_stage import IngestionStageReport logger = logging.getLogger(__name__) @@ -86,6 +87,7 @@ class DataHubGcSourceReport( DataProcessCleanupReport, SoftDeletedEntitiesReport, DatahubExecutionRequestCleanupReport, + IngestionStageReport, ): expired_tokens_revoked: int = 0 @@ -139,31 +141,40 @@ def get_workunits_internal( ) -> Iterable[MetadataWorkUnit]: if self.config.cleanup_expired_tokens: try: + self.report.report_ingestion_stage_start("Expired Token Cleanup") self.revoke_expired_tokens() except Exception as e: 
self.report.failure("While trying to cleanup expired token ", exc=e) if self.config.truncate_indices: try: + self.report.report_ingestion_stage_start("Truncate Indices") self.truncate_indices() except Exception as e: self.report.failure("While trying to truncate indices ", exc=e) if self.config.soft_deleted_entities_cleanup.enabled: try: + self.report.report_ingestion_stage_start( + "Soft Deleted Entities Cleanup" + ) self.soft_deleted_entities_cleanup.cleanup_soft_deleted_entities() except Exception as e: self.report.failure( "While trying to cleanup soft deleted entities ", exc=e ) - if self.config.execution_request_cleanup.enabled: - try: - self.execution_request_cleanup.run() - except Exception as e: - self.report.failure("While trying to cleanup execution request ", exc=e) if self.config.dataprocess_cleanup.enabled: try: + self.report.report_ingestion_stage_start("Data Process Cleanup") yield from self.dataprocess_cleanup.get_workunits_internal() except Exception as e: self.report.failure("While trying to cleanup data process ", exc=e) + if self.config.execution_request_cleanup.enabled: + try: + self.report.report_ingestion_stage_start("Execution request Cleanup") + self.execution_request_cleanup.run() + except Exception as e: + self.report.failure("While trying to cleanup execution request ", exc=e) + # Otherwise last stage's duration does not get calculated. 
+ self.report.report_ingestion_stage_start("End") yield from [] def truncate_indices(self) -> None: @@ -281,6 +292,8 @@ def revoke_expired_tokens(self) -> None: list_access_tokens = expired_tokens_res.get("listAccessTokens", {}) tokens = list_access_tokens.get("tokens", []) total = list_access_tokens.get("total", 0) + if tokens == []: + break for token in tokens: self.report.expired_tokens_revoked += 1 token_id = token["id"] diff --git a/metadata-ingestion/src/datahub/ingestion/source/gc/execution_request_cleanup.py b/metadata-ingestion/src/datahub/ingestion/source/gc/execution_request_cleanup.py index 3baf858e44cdc..170a6ada3e336 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/gc/execution_request_cleanup.py +++ b/metadata-ingestion/src/datahub/ingestion/source/gc/execution_request_cleanup.py @@ -1,3 +1,4 @@ +import datetime import logging import time from typing import Any, Dict, Iterator, Optional @@ -42,16 +43,28 @@ class DatahubExecutionRequestCleanupConfig(ConfigModel): description="Global switch for this cleanup task", ) + runtime_limit_seconds: int = Field( + default=3600, + description="Maximum runtime in seconds for the cleanup task", + ) + + max_read_errors: int = Field( + default=10, + description="Maximum number of read errors before aborting", + ) + def keep_history_max_milliseconds(self): return self.keep_history_max_days * 24 * 3600 * 1000 class DatahubExecutionRequestCleanupReport(SourceReport): - execution_request_cleanup_records_read: int = 0 - execution_request_cleanup_records_preserved: int = 0 - execution_request_cleanup_records_deleted: int = 0 - execution_request_cleanup_read_errors: int = 0 - execution_request_cleanup_delete_errors: int = 0 + ergc_records_read: int = 0 + ergc_records_preserved: int = 0 + ergc_records_deleted: int = 0 + ergc_read_errors: int = 0 + ergc_delete_errors: int = 0 + ergc_start_time: Optional[datetime.datetime] = None + ergc_end_time: Optional[datetime.datetime] = None class CleanupRecord(BaseModel): 
@@ -124,6 +137,13 @@ def _scroll_execution_requests( params.update(overrides) while True: + if self._reached_runtime_limit(): + break + if self.report.ergc_read_errors >= self.config.max_read_errors: + self.report.failure( + f"ergc({self.instance_id}): too many read errors, aborting." + ) + break try: url = f"{self.graph.config.server}/openapi/v2/entity/{DATAHUB_EXECUTION_REQUEST_ENTITY_NAME}" response = self.graph._session.get(url, headers=headers, params=params) @@ -141,7 +161,7 @@ def _scroll_execution_requests( logger.error( f"ergc({self.instance_id}): failed to fetch next batch of execution requests: {e}" ) - self.report.execution_request_cleanup_read_errors += 1 + self.report.ergc_read_errors += 1 def _scroll_garbage_records(self): state: Dict[str, Dict] = {} @@ -150,7 +170,7 @@ def _scroll_garbage_records(self): running_guard_timeout = now_ms - 30 * 24 * 3600 * 1000 for entry in self._scroll_execution_requests(): - self.report.execution_request_cleanup_records_read += 1 + self.report.ergc_records_read += 1 key = entry.ingestion_source # Always delete corrupted records @@ -171,7 +191,7 @@ def _scroll_garbage_records(self): # Do not delete if number of requests is below minimum if state[key]["count"] < self.config.keep_history_min_count: - self.report.execution_request_cleanup_records_preserved += 1 + self.report.ergc_records_preserved += 1 continue # Do not delete if number of requests do not exceed allowed maximum, @@ -179,7 +199,7 @@ def _scroll_garbage_records(self): if (state[key]["count"] < self.config.keep_history_max_count) and ( entry.requested_at > state[key]["cutoffTimestamp"] ): - self.report.execution_request_cleanup_records_preserved += 1 + self.report.ergc_records_preserved += 1 continue # Do not delete if status is RUNNING or PENDING and created within last month. 
If the record is >month old and it did not @@ -188,7 +208,7 @@ def _scroll_garbage_records(self): "RUNNING", "PENDING", ]: - self.report.execution_request_cleanup_records_preserved += 1 + self.report.ergc_records_preserved += 1 continue # Otherwise delete current record @@ -200,7 +220,7 @@ def _scroll_garbage_records(self): f"record timestamp: {entry.requested_at}." ) ) - self.report.execution_request_cleanup_records_deleted += 1 + self.report.ergc_records_deleted += 1 yield entry def _delete_entry(self, entry: CleanupRecord) -> None: @@ -210,17 +230,31 @@ def _delete_entry(self, entry: CleanupRecord) -> None: ) self.graph.delete_entity(entry.urn, True) except Exception as e: - self.report.execution_request_cleanup_delete_errors += 1 + self.report.ergc_delete_errors += 1 logger.error( f"ergc({self.instance_id}): failed to delete ExecutionRequest {entry.request_id}: {e}" ) + def _reached_runtime_limit(self) -> bool: + if ( + self.config.runtime_limit_seconds + and self.report.ergc_start_time + and ( + datetime.datetime.now() - self.report.ergc_start_time + >= datetime.timedelta(seconds=self.config.runtime_limit_seconds) + ) + ): + logger.info(f"ergc({self.instance_id}): max runtime reached.") + return True + return False + def run(self) -> None: if not self.config.enabled: logger.info( f"ergc({self.instance_id}): ExecutionRequest cleaner is disabled." ) return + self.report.ergc_start_time = datetime.datetime.now() logger.info( ( @@ -232,8 +266,11 @@ def run(self) -> None: ) for entry in self._scroll_garbage_records(): + if self._reached_runtime_limit(): + break self._delete_entry(entry) + self.report.ergc_end_time = datetime.datetime.now() logger.info( f"ergc({self.instance_id}): Finished cleanup of ExecutionRequest records." 
) diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_dataclasses.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_dataclasses.py index 327c9ebf99bd2..d771821a14d88 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_dataclasses.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_dataclasses.py @@ -186,16 +186,16 @@ def resolve_includes( f"traversal_path={traversal_path}, included_files = {included_files}, seen_so_far: {seen_so_far}" ) if "*" not in inc and not included_files: - reporter.report_failure( + reporter.warning( title="Error Resolving Include", - message=f"Cannot resolve include {inc}", - context=f"Path: {path}", + message="Cannot resolve included file", + context=f"Include: {inc}, path: {path}, traversal_path: {traversal_path}", ) elif not included_files: - reporter.report_failure( + reporter.warning( title="Error Resolving Include", - message=f"Did not resolve anything for wildcard include {inc}", - context=f"Path: {path}", + message="Did not find anything matching the wildcard include", + context=f"Include: {inc}, path: {path}, traversal_path: {traversal_path}", ) # only load files that we haven't seen so far included_files = [x for x in included_files if x not in seen_so_far] @@ -231,9 +231,7 @@ def resolve_includes( source_config, reporter, seen_so_far, - traversal_path=traversal_path - + "." 
- + pathlib.Path(included_file).stem, + traversal_path=f"{traversal_path} -> {pathlib.Path(included_file).stem}", ) ) except Exception as e: diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_resolver.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_resolver.py index e1301edef10b8..161975fa635fd 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_resolver.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/rest_api_wrapper/data_resolver.py @@ -84,13 +84,14 @@ def __init__( tenant_id: str, metadata_api_timeout: int, ): - self.__access_token: Optional[str] = None - self.__access_token_expiry_time: Optional[datetime] = None - self.__tenant_id = tenant_id + self._access_token: Optional[str] = None + self._access_token_expiry_time: Optional[datetime] = None + + self._tenant_id = tenant_id # Test connection by generating access token logger.info(f"Trying to connect to {self._get_authority_url()}") # Power-Bi Auth (Service Principal Auth) - self.__msal_client = msal.ConfidentialClientApplication( + self._msal_client = msal.ConfidentialClientApplication( client_id, client_credential=client_secret, authority=DataResolverBase.AUTHORITY + tenant_id, @@ -168,18 +169,18 @@ def _get_app( pass def _get_authority_url(self): - return f"{DataResolverBase.AUTHORITY}{self.__tenant_id}" + return f"{DataResolverBase.AUTHORITY}{self._tenant_id}" def get_authorization_header(self): return {Constant.Authorization: self.get_access_token()} - def get_access_token(self): - if self.__access_token is not None and not self._is_access_token_expired(): - return self.__access_token + def get_access_token(self) -> str: + if self._access_token is not None and not self._is_access_token_expired(): + return self._access_token logger.info("Generating PowerBi access token") - auth_response = self.__msal_client.acquire_token_for_client( + auth_response = 
self._msal_client.acquire_token_for_client( scopes=[DataResolverBase.SCOPE] ) @@ -193,24 +194,24 @@ def get_access_token(self): logger.info("Generated PowerBi access token") - self.__access_token = "Bearer {}".format( + self._access_token = "Bearer {}".format( auth_response.get(Constant.ACCESS_TOKEN) ) safety_gap = 300 - self.__access_token_expiry_time = datetime.now() + timedelta( + self._access_token_expiry_time = datetime.now() + timedelta( seconds=( max(auth_response.get(Constant.ACCESS_TOKEN_EXPIRY, 0) - safety_gap, 0) ) ) - logger.debug(f"{Constant.PBIAccessToken}={self.__access_token}") + logger.debug(f"{Constant.PBIAccessToken}={self._access_token}") - return self.__access_token + return self._access_token def _is_access_token_expired(self) -> bool: - if not self.__access_token_expiry_time: + if not self._access_token_expiry_time: return True - return self.__access_token_expiry_time < datetime.now() + return self._access_token_expiry_time < datetime.now() def get_dashboards(self, workspace: Workspace) -> List[Dashboard]: """ diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 2b2dcf860cdb0..12e5fb72b00de 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -138,12 +138,20 @@ class SnowflakeIdentifierConfig( description="Whether to convert dataset urns to lowercase.", ) - -class SnowflakeUsageConfig(BaseUsageConfig): email_domain: Optional[str] = pydantic.Field( default=None, description="Email domain of your organization so users can be displayed on UI appropriately.", ) + + email_as_user_identifier: bool = Field( + default=True, + description="Format user urns as an email, if the snowflake user's email is set. 
If `email_domain` is " + "provided, generates email addresses for snowflake users with unset emails, based on their " + "username.", + ) + + +class SnowflakeUsageConfig(BaseUsageConfig): apply_view_usage_to_tables: bool = pydantic.Field( default=False, description="Whether to apply view's usage to its base tables. If set to True, usage is applied to base tables only.", @@ -267,13 +275,6 @@ class SnowflakeV2Config( " Map of share name -> details of share.", ) - email_as_user_identifier: bool = Field( - default=True, - description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is " - "provided, generates email addresses for snowflake users with unset emails, based on their " - "username.", - ) - include_assertion_results: bool = Field( default=False, description="Whether to ingest assertion run results for assertions created using Datahub" diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 174aad0bddd4a..36825dc33fe7d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -66,6 +66,11 @@ logger = logging.getLogger(__name__) +# Define a type alias +UserName = str +UserEmail = str +UsersMapping = Dict[UserName, UserEmail] + class SnowflakeQueriesExtractorConfig(ConfigModel): # TODO: Support stateful ingestion for the time windows. 
@@ -114,11 +119,13 @@ class SnowflakeQueriesSourceConfig( class SnowflakeQueriesExtractorReport(Report): copy_history_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) query_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) + users_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) sql_aggregator: Optional[SqlAggregatorReport] = None num_ddl_queries_dropped: int = 0 + num_users: int = 0 @dataclass @@ -225,6 +232,9 @@ def is_allowed_table(self, name: str) -> bool: def get_workunits_internal( self, ) -> Iterable[MetadataWorkUnit]: + with self.report.users_fetch_timer: + users = self.fetch_users() + # TODO: Add some logic to check if the cached audit log is stale or not. audit_log_file = self.local_temp_path / "audit_log.sqlite" use_cached_audit_log = audit_log_file.exists() @@ -248,7 +258,7 @@ def get_workunits_internal( queries.append(entry) with self.report.query_log_fetch_timer: - for entry in self.fetch_query_log(): + for entry in self.fetch_query_log(users): queries.append(entry) with self.report.audit_log_load_timer: @@ -263,6 +273,25 @@ def get_workunits_internal( shared_connection.close() audit_log_file.unlink(missing_ok=True) + def fetch_users(self) -> UsersMapping: + users: UsersMapping = dict() + with self.structured_reporter.report_exc("Error fetching users from Snowflake"): + logger.info("Fetching users from Snowflake") + query = SnowflakeQuery.get_all_users() + resp = self.connection.query(query) + + for row in resp: + try: + users[row["NAME"]] = row["EMAIL"] + self.report.num_users += 1 + except Exception as e: + self.structured_reporter.warning( + "Error parsing user row", + context=f"{row}", + exc=e, + ) + return users + def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: # Derived from _populate_external_lineage_from_copy_history. 
@@ -298,7 +327,7 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: yield result def fetch_query_log( - self, + self, users: UsersMapping ) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap]]: query_log_query = _build_enriched_query_log_query( start_time=self.config.window.start_time, @@ -319,7 +348,7 @@ def fetch_query_log( assert isinstance(row, dict) try: - entry = self._parse_audit_log_row(row) + entry = self._parse_audit_log_row(row, users) except Exception as e: self.structured_reporter.warning( "Error parsing query log row", @@ -331,7 +360,7 @@ def fetch_query_log( yield entry def _parse_audit_log_row( - self, row: Dict[str, Any] + self, row: Dict[str, Any], users: UsersMapping ) -> Optional[Union[TableRename, TableSwap, PreparsedQuery]]: json_fields = { "DIRECT_OBJECTS_ACCESSED", @@ -430,9 +459,11 @@ def _parse_audit_log_row( ) ) - # TODO: Fetch email addresses from Snowflake to map user -> email - # TODO: Support email_domain fallback for generating user urns. 
- user = CorpUserUrn(self.identifiers.snowflake_identifier(res["user_name"])) + user = CorpUserUrn( + self.identifiers.get_user_identifier( + res["user_name"], users.get(res["user_name"]) + ) + ) timestamp: datetime = res["query_start_time"] timestamp = timestamp.astimezone(timezone.utc) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py index a94b39476b2c2..40bcfb514efd2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py @@ -947,4 +947,8 @@ def dmf_assertion_results(start_time_millis: int, end_time_millis: int) -> str: AND METRIC_NAME ilike '{pattern}' escape '{escape_pattern}' ORDER BY MEASUREMENT_TIME ASC; -""" + """ + + @staticmethod + def get_all_users() -> str: + return """SELECT name as "NAME", email as "EMAIL" FROM SNOWFLAKE.ACCOUNT_USAGE.USERS""" diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py index aff15386c5083..4bdf559f293b5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py @@ -342,10 +342,9 @@ def _map_user_counts( filtered_user_counts.append( DatasetUserUsageCounts( user=make_user_urn( - self.get_user_identifier( + self.identifiers.get_user_identifier( user_count["user_name"], user_email, - self.config.email_as_user_identifier, ) ), count=user_count["total"], @@ -453,9 +452,7 @@ def _get_operation_aspect_work_unit( reported_time: int = int(time.time() * 1000) last_updated_timestamp: int = int(start_time.timestamp() * 1000) user_urn = make_user_urn( - self.get_user_identifier( - user_name, user_email, self.config.email_as_user_identifier - ) + 
self.identifiers.get_user_identifier(user_name, user_email) ) # NOTE: In earlier `snowflake-usage` connector this was base_objects_accessed, which is incorrect diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index 8e0c97aa135e8..885bee1ccdb90 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -300,6 +300,28 @@ def get_quoted_identifier_for_schema(db_name, schema_name): def get_quoted_identifier_for_table(db_name, schema_name, table_name): return f'"{db_name}"."{schema_name}"."{table_name}"' + # Note - decide how to construct user urns. + # Historically urns were created using part before @ from user's email. + # Users without email were skipped from both user entries as well as aggregates. + # However email is not mandatory field in snowflake user, user_name is always present. + def get_user_identifier( + self, + user_name: str, + user_email: Optional[str], + ) -> str: + if user_email: + return self.snowflake_identifier( + user_email + if self.identifier_config.email_as_user_identifier is True + else user_email.split("@")[0] + ) + return self.snowflake_identifier( + f"{user_name}@{self.identifier_config.email_domain}" + if self.identifier_config.email_as_user_identifier is True + and self.identifier_config.email_domain is not None + else user_name + ) + class SnowflakeCommonMixin(SnowflakeStructuredReportMixin): platform = "snowflake" @@ -315,24 +337,6 @@ def structured_reporter(self) -> SourceReport: def identifiers(self) -> SnowflakeIdentifierBuilder: return SnowflakeIdentifierBuilder(self.config, self.report) - # Note - decide how to construct user urns. - # Historically urns were created using part before @ from user's email. - # Users without email were skipped from both user entries as well as aggregates. 
- # However email is not mandatory field in snowflake user, user_name is always present. - def get_user_identifier( - self, - user_name: str, - user_email: Optional[str], - email_as_user_identifier: bool, - ) -> str: - if user_email: - return self.identifiers.snowflake_identifier( - user_email - if email_as_user_identifier is True - else user_email.split("@")[0] - ) - return self.identifiers.snowflake_identifier(user_name) - # TODO: Revisit this after stateful ingestion can commit checkpoint # for failures that do not affect the checkpoint # TODO: Add additional parameters to match the signature of the .warning and .failure methods diff --git a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py index 508500ffe489b..df59cae3fad23 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py +++ b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py @@ -186,6 +186,15 @@ except ImportError: REAUTHENTICATE_ERRORS = (NonXMLResponseError,) +RETRIABLE_ERROR_CODES = [ + 408, # Request Timeout + 429, # Too Many Requests + 500, # Internal Server Error + 502, # Bad Gateway + 503, # Service Unavailable + 504, # Gateway Timeout +] + logger: logging.Logger = logging.getLogger(__name__) # Replace / with | @@ -287,7 +296,7 @@ def make_tableau_client(self, site: str) -> Server: max_retries=Retry( total=self.max_retries, backoff_factor=1, - status_forcelist=[429, 500, 502, 503, 504], + status_forcelist=RETRIABLE_ERROR_CODES, ) ) server._session.mount("http://", adapter) @@ -1212,9 +1221,11 @@ def get_connection_object_page( except InternalServerError as ise: # In some cases Tableau Server returns 504 error, which is a timeout error, so it worths to retry. - if ise.code == 504: + # Extended with other retryable errors. 
+ if ise.code in RETRIABLE_ERROR_CODES: if retries_remaining <= 0: raise ise + logger.info(f"Retrying query due to error {ise.code}") return self.get_connection_object_page( query=query, connection_type=connection_type, diff --git a/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py b/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py index 42b3b648bd298..ce683e64b3f46 100644 --- a/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py +++ b/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py @@ -42,4 +42,5 @@ def report_ingestion_stage_start(self, stage: str) -> None: self._timer = PerfTimer() self.ingestion_stage = f"{stage} at {datetime.now(timezone.utc)}" + logger.info(f"Stage started: {self.ingestion_stage}") self._timer.start() diff --git a/metadata-ingestion/tests/integration/powerbi/test_admin_only_api.py b/metadata-ingestion/tests/integration/powerbi/test_admin_only_api.py index b636c12cfda06..00dc79ed38cfb 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_admin_only_api.py +++ b/metadata-ingestion/tests/integration/powerbi/test_admin_only_api.py @@ -1,5 +1,3 @@ -import logging -import sys from typing import Any, Dict from unittest import mock @@ -483,12 +481,6 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None ) -def enable_logging(): - # set logging to console - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging.getLogger().setLevel(logging.DEBUG) - - def mock_msal_cca(*args, **kwargs): class MsalClient: def acquire_token_for_client(self, *args, **kwargs): @@ -527,8 +519,6 @@ def default_source_config(): @freeze_time(FROZEN_TIME) @mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) def test_admin_only_apis(mock_msal, pytestconfig, tmp_path, mock_time, requests_mock): - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" 
register_mock_admin_api(request_mock=requests_mock) @@ -567,8 +557,6 @@ def test_admin_only_apis(mock_msal, pytestconfig, tmp_path, mock_time, requests_ def test_most_config_and_modified_since( mock_msal, pytestconfig, tmp_path, mock_time, requests_mock ): - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" register_mock_admin_api(request_mock=requests_mock) diff --git a/metadata-ingestion/tests/integration/powerbi/test_powerbi.py b/metadata-ingestion/tests/integration/powerbi/test_powerbi.py index edde11ff87d29..739be7cc8408d 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_powerbi.py +++ b/metadata-ingestion/tests/integration/powerbi/test_powerbi.py @@ -1,8 +1,6 @@ import datetime import json -import logging import re -import sys from pathlib import Path from typing import Any, Dict, List, Optional, Union, cast from unittest import mock @@ -31,29 +29,21 @@ FROZEN_TIME = "2022-02-03 07:00:00" -def enable_logging(): - # set logging to console - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging.getLogger().setLevel(logging.DEBUG) - - -class MsalClient: - call_num = 0 - token: Dict[str, Any] = { - "access_token": "dummy", - } - - @staticmethod - def acquire_token_for_client(*args, **kwargs): - MsalClient.call_num += 1 - return MsalClient.token +def mock_msal_cca(*args, **kwargs): + class MsalClient: + def __init__(self): + self.call_num = 0 + self.token: Dict[str, Any] = { + "access_token": "dummy", + } - @staticmethod - def reset(): - MsalClient.call_num = 0 + def acquire_token_for_client(self, *args, **kwargs): + self.call_num += 1 + return self.token + def reset(self): + self.call_num = 0 -def mock_msal_cca(*args, **kwargs): return MsalClient() @@ -154,8 +144,6 @@ def test_powerbi_ingest( mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" 
register_mock_api(pytestconfig=pytestconfig, request_mock=requests_mock) @@ -199,8 +187,6 @@ def test_powerbi_workspace_type_filter( mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" register_mock_api( @@ -260,8 +246,6 @@ def test_powerbi_ingest_patch_disabled( mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" register_mock_api(pytestconfig=pytestconfig, request_mock=requests_mock) @@ -327,8 +311,6 @@ def test_powerbi_platform_instance_ingest( mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" register_mock_api(pytestconfig=pytestconfig, request_mock=requests_mock) @@ -515,8 +497,6 @@ def test_extract_reports( mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" register_mock_api(pytestconfig=pytestconfig, request_mock=requests_mock) @@ -561,8 +541,6 @@ def test_extract_lineage( mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" register_mock_api(pytestconfig=pytestconfig, request_mock=requests_mock) @@ -660,8 +638,6 @@ def test_admin_access_is_not_allowed( mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" register_mock_api( @@ -723,8 +699,6 @@ def test_workspace_container( mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" register_mock_api(pytestconfig=pytestconfig, request_mock=requests_mock) @@ -764,85 +738,84 @@ def test_workspace_container( ) 
-@mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) def test_access_token_expiry_with_long_expiry( - mock_msal: MagicMock, pytestconfig: pytest.Config, tmp_path: str, mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - register_mock_api(pytestconfig=pytestconfig, request_mock=requests_mock) - pipeline = Pipeline.create( - { - "run_id": "powerbi-test", - "source": { - "type": "powerbi", - "config": { - **default_source_config(), + mock_msal = mock_msal_cca() + + with mock.patch("msal.ConfidentialClientApplication", return_value=mock_msal): + pipeline = Pipeline.create( + { + "run_id": "powerbi-test", + "source": { + "type": "powerbi", + "config": { + **default_source_config(), + }, }, - }, - "sink": { - "type": "file", - "config": { - "filename": f"{tmp_path}/powerbi_access_token_mces.json", + "sink": { + "type": "file", + "config": { + "filename": f"{tmp_path}/powerbi_access_token_mces.json", + }, }, - }, - } - ) + } + ) # for long expiry, the token should only be requested once. 
- MsalClient.token = { + mock_msal.token = { "access_token": "dummy2", "expires_in": 3600, } + mock_msal.reset() - MsalClient.reset() pipeline.run() # We expect the token to be requested twice (once for AdminApiResolver and one for RegularApiResolver) - assert MsalClient.call_num == 2 + assert mock_msal.call_num == 2 -@mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) def test_access_token_expiry_with_short_expiry( - mock_msal: MagicMock, pytestconfig: pytest.Config, tmp_path: str, mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - register_mock_api(pytestconfig=pytestconfig, request_mock=requests_mock) - pipeline = Pipeline.create( - { - "run_id": "powerbi-test", - "source": { - "type": "powerbi", - "config": { - **default_source_config(), + mock_msal = mock_msal_cca() + with mock.patch("msal.ConfidentialClientApplication", return_value=mock_msal): + pipeline = Pipeline.create( + { + "run_id": "powerbi-test", + "source": { + "type": "powerbi", + "config": { + **default_source_config(), + }, }, - }, - "sink": { - "type": "file", - "config": { - "filename": f"{tmp_path}/powerbi_access_token_mces.json", + "sink": { + "type": "file", + "config": { + "filename": f"{tmp_path}/powerbi_access_token_mces.json", + }, }, - }, - } - ) + } + ) # for short expiry, the token should be requested when expires. 
- MsalClient.token = { + mock_msal.token = { "access_token": "dummy", "expires_in": 0, } + mock_msal.reset() + pipeline.run() - assert MsalClient.call_num > 2 + assert mock_msal.call_num > 2 def dataset_type_mapping_set_to_all_platform(pipeline: Pipeline) -> None: @@ -940,8 +913,6 @@ def test_dataset_type_mapping_error( def test_server_to_platform_map( mock_msal, pytestconfig, tmp_path, mock_time, requests_mock ): - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" new_config: dict = { **default_source_config(), @@ -1416,8 +1387,6 @@ def test_powerbi_cross_workspace_reference_info_message( mock_time: datetime.datetime, requests_mock: Any, ) -> None: - enable_logging() - register_mock_api( pytestconfig=pytestconfig, request_mock=requests_mock, @@ -1495,8 +1464,6 @@ def common_app_ingest( output_mcp_path: str, override_config: dict = {}, ) -> Pipeline: - enable_logging() - register_mock_api( pytestconfig=pytestconfig, request_mock=requests_mock, diff --git a/metadata-ingestion/tests/integration/powerbi/test_profiling.py b/metadata-ingestion/tests/integration/powerbi/test_profiling.py index 4b48bed003b1e..78d35cf31a26d 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_profiling.py +++ b/metadata-ingestion/tests/integration/powerbi/test_profiling.py @@ -1,5 +1,3 @@ -import logging -import sys from typing import Any, Dict from unittest import mock @@ -271,12 +269,6 @@ def register_mock_admin_api(request_mock: Any, override_data: dict = {}) -> None ) -def enable_logging(): - # set logging to console - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging.getLogger().setLevel(logging.DEBUG) - - def mock_msal_cca(*args, **kwargs): class MsalClient: def acquire_token_for_client(self, *args, **kwargs): @@ -311,8 +303,6 @@ def default_source_config(): @freeze_time(FROZEN_TIME) @mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) def test_profiling(mock_msal, pytestconfig, 
tmp_path, mock_time, requests_mock): - enable_logging() - test_resources_dir = pytestconfig.rootpath / "tests/integration/powerbi" register_mock_admin_api(request_mock=requests_mock) diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py index 82f5691bcee3d..ae0f23d93215d 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py @@ -22,3 +22,58 @@ def test_source_close_cleans_tmp(snowflake_connect, tmp_path): # This closes QueriesExtractor which in turn closes SqlParsingAggregator source.close() assert len(os.listdir(tmp_path)) == 0 + + +@patch("snowflake.connector.connect") +def test_user_identifiers_email_as_identifier(snowflake_connect, tmp_path): + source = SnowflakeQueriesSource.create( + { + "connection": { + "account_id": "ABC12345.ap-south-1.aws", + "username": "TST_USR", + "password": "TST_PWD", + }, + "email_as_user_identifier": True, + "email_domain": "example.com", + }, + PipelineContext("run-id"), + ) + assert ( + source.identifiers.get_user_identifier("username", "username@example.com") + == "username@example.com" + ) + assert ( + source.identifiers.get_user_identifier("username", None) + == "username@example.com" + ) + + # We'd do best effort to use email as identifier, but would keep username as is, + # if email can't be formed. 
+ source.identifiers.identifier_config.email_domain = None + + assert ( + source.identifiers.get_user_identifier("username", "username@example.com") + == "username@example.com" + ) + + assert source.identifiers.get_user_identifier("username", None) == "username" + + +@patch("snowflake.connector.connect") +def test_user_identifiers_username_as_identifier(snowflake_connect, tmp_path): + source = SnowflakeQueriesSource.create( + { + "connection": { + "account_id": "ABC12345.ap-south-1.aws", + "username": "TST_USR", + "password": "TST_PWD", + }, + "email_as_user_identifier": False, + }, + PipelineContext("run-id"), + ) + assert ( + source.identifiers.get_user_identifier("username", "username@example.com") + == "username" + ) + assert source.identifiers.get_user_identifier("username", None) == "username" diff --git a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py index 902ff243c802a..71e5ad10c2fc5 100644 --- a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py +++ b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py @@ -1,7 +1,5 @@ import json -import logging import pathlib -import sys from typing import Any, Dict, List, cast from unittest import mock @@ -88,12 +86,6 @@ } -def enable_logging(): - # set logging to console - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging.getLogger().setLevel(logging.DEBUG) - - def read_response(file_name): response_json_path = f"{test_resources_dir}/setup/{file_name}" with open(response_json_path) as file: @@ -376,7 +368,6 @@ def tableau_ingest_common( @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_tableau_ingest(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_mces.json" golden_file_name: str = "tableau_mces_golden.json" tableau_ingest_common( @@ -454,7 +445,6 @@ def mock_data() -> List[dict]: @freeze_time(FROZEN_TIME) 
@pytest.mark.integration def test_tableau_cll_ingest(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_mces_cll.json" golden_file_name: str = "tableau_cll_mces_golden.json" @@ -481,7 +471,6 @@ def test_tableau_cll_ingest(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_project_pattern(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_project_pattern_mces.json" golden_file_name: str = "tableau_mces_golden.json" @@ -505,7 +494,6 @@ def test_project_pattern(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_project_path_pattern(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_project_path_mces.json" golden_file_name: str = "tableau_project_path_mces_golden.json" @@ -529,8 +517,6 @@ def test_project_path_pattern(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_project_hierarchy(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() - output_file_name: str = "tableau_nested_project_mces.json" golden_file_name: str = "tableau_nested_project_mces_golden.json" @@ -554,7 +540,6 @@ def test_project_hierarchy(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_extract_all_project(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_extract_all_project_mces.json" golden_file_name: str = "tableau_extract_all_project_mces_golden.json" @@ -644,7 +629,6 @@ def test_project_path_pattern_deny(pytestconfig, tmp_path, mock_datahub_graph): def test_tableau_ingest_with_platform_instance( pytestconfig, tmp_path, mock_datahub_graph ): - enable_logging() output_file_name: str = "tableau_with_platform_instance_mces.json" golden_file_name: str = 
"tableau_with_platform_instance_mces_golden.json" @@ -691,7 +675,6 @@ def test_tableau_ingest_with_platform_instance( def test_lineage_overrides(): - enable_logging() # Simple - specify platform instance to presto table assert ( TableauUpstreamReference( @@ -745,7 +728,6 @@ def test_lineage_overrides(): def test_database_hostname_to_platform_instance_map(): - enable_logging() # Simple - snowflake table assert ( TableauUpstreamReference( @@ -916,7 +898,6 @@ def test_tableau_stateful(pytestconfig, tmp_path, mock_time, mock_datahub_graph) def test_tableau_no_verify(): - enable_logging() # This test ensures that we can connect to a self-signed certificate # when ssl_verify is set to False. @@ -941,7 +922,6 @@ def test_tableau_no_verify(): @freeze_time(FROZEN_TIME) @pytest.mark.integration_batch_2 def test_tableau_signout_timeout(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_signout_timeout_mces.json" golden_file_name: str = "tableau_signout_timeout_mces_golden.json" tableau_ingest_common( @@ -1073,7 +1053,6 @@ def test_get_all_datasources_failure(pytestconfig, tmp_path, mock_datahub_graph) @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_tableau_ingest_multiple_sites(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_mces_multiple_sites.json" golden_file_name: str = "tableau_multiple_sites_mces_golden.json" @@ -1135,7 +1114,6 @@ def test_tableau_ingest_multiple_sites(pytestconfig, tmp_path, mock_datahub_grap @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_tableau_ingest_sites_as_container(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_mces_ingest_sites_as_container.json" golden_file_name: str = "tableau_sites_as_container_mces_golden.json" @@ -1159,7 +1137,6 @@ def test_tableau_ingest_sites_as_container(pytestconfig, tmp_path, mock_datahub_ @freeze_time(FROZEN_TIME) @pytest.mark.integration def 
test_site_name_pattern(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_site_name_pattern_mces.json" golden_file_name: str = "tableau_site_name_pattern_mces_golden.json" @@ -1183,7 +1160,6 @@ def test_site_name_pattern(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_permission_ingestion(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_permission_ingestion_mces.json" golden_file_name: str = "tableau_permission_ingestion_mces_golden.json" @@ -1209,7 +1185,6 @@ def test_permission_ingestion(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_no_hidden_assets(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_no_hidden_assets_mces.json" golden_file_name: str = "tableau_no_hidden_assets_mces_golden.json" @@ -1232,7 +1207,6 @@ def test_no_hidden_assets(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_ingest_tags_disabled(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_ingest_tags_disabled_mces.json" golden_file_name: str = "tableau_ingest_tags_disabled_mces_golden.json" @@ -1254,7 +1228,6 @@ def test_ingest_tags_disabled(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_hidden_asset_tags(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() output_file_name: str = "tableau_hidden_asset_tags_mces.json" golden_file_name: str = "tableau_hidden_asset_tags_mces_golden.json" @@ -1277,8 +1250,6 @@ def test_hidden_asset_tags(pytestconfig, tmp_path, mock_datahub_graph): @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_hidden_assets_without_ingest_tags(pytestconfig, tmp_path, mock_datahub_graph): - enable_logging() - new_config = 
config_source_default.copy() new_config["tags_for_hidden_assets"] = ["hidden", "private"] new_config["ingest_tags"] = False diff --git a/smoke-test/.gitignore b/smoke-test/.gitignore index b8af2eef535a0..d8cfd65ff81b9 100644 --- a/smoke-test/.gitignore +++ b/smoke-test/.gitignore @@ -29,6 +29,8 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +**/cypress/node_modules + # PyInstaller # Usually these files are written by a python script from a template @@ -132,4 +134,4 @@ dmypy.json # Pyre type checker .pyre/ junit* -tests/cypress/onboarding.json \ No newline at end of file +tests/cypress/onboarding.json diff --git a/smoke-test/build.gradle b/smoke-test/build.gradle index 73ecdcb08ea14..60d08e0206cda 100644 --- a/smoke-test/build.gradle +++ b/smoke-test/build.gradle @@ -91,39 +91,31 @@ task pythonLintFix(type: Exec, dependsOn: installDev) { * The following tasks assume an already running quickstart. * ./gradlew quickstart (or another variation `quickstartDebug`) */ -task noCypressSuite0(type: Exec, dependsOn: [installDev, ':metadata-ingestion:installDev']) { - environment 'RUN_QUICKSTART', 'false' - environment 'TEST_STRATEGY', 'no_cypress_suite0' - - workingDir = project.projectDir - commandLine 'bash', '-c', - "source ${venv_name}/bin/activate && set -x && " + - "./smoke.sh" -} +// ./gradlew :smoke-test:pytest -PbatchNumber=2 (default 0) +task pytest(type: Exec, dependsOn: [installDev, ':metadata-ingestion:installDev']) { + // Get BATCH_NUMBER from command line argument with default value of 0 + def batchNumber = project.hasProperty('batchNumber') ? 
project.property('batchNumber') : '0' -task noCypressSuite1(type: Exec, dependsOn: [installDev, ':metadata-ingestion:installDev']) { environment 'RUN_QUICKSTART', 'false' - environment 'TEST_STRATEGY', 'no_cypress_suite1' + environment 'TEST_STRATEGY', 'pytests' + environment 'BATCH_COUNT', 5 + environment 'BATCH_NUMBER', batchNumber workingDir = project.projectDir commandLine 'bash', '-c', "source ${venv_name}/bin/activate && set -x && " + - "./smoke.sh" + "./smoke.sh" } -task cypressSuite1(type: Exec, dependsOn: [installDev, ':metadata-ingestion:installDev']) { - environment 'RUN_QUICKSTART', 'false' - environment 'TEST_STRATEGY', 'cypress_suite1' - - workingDir = project.projectDir - commandLine 'bash', '-c', - "source ${venv_name}/bin/activate && set -x && " + - "./smoke.sh" -} +// ./gradlew :smoke-test:cypressTest -PbatchNumber=2 (default 0) +task cypressTest(type: Exec, dependsOn: [installDev, ':metadata-ingestion:installDev']) { + // Get BATCH_NUMBER from command line argument with default value of 0 + def batchNumber = project.hasProperty('batchNumber') ? 
project.property('batchNumber') : '0' -task cypressRest(type: Exec, dependsOn: [installDev, ':metadata-ingestion:installDev']) { environment 'RUN_QUICKSTART', 'false' - environment 'TEST_STRATEGY', 'cypress_rest' + environment 'TEST_STRATEGY', 'cypress' + environment 'BATCH_COUNT', 11 + environment 'BATCH_NUMBER', batchNumber workingDir = project.projectDir commandLine 'bash', '-c', diff --git a/smoke-test/conftest.py b/smoke-test/conftest.py index 6d148db9886a4..d48a92b22ab48 100644 --- a/smoke-test/conftest.py +++ b/smoke-test/conftest.py @@ -1,6 +1,8 @@ import os import pytest +from typing import List, Tuple +from _pytest.nodes import Item import requests from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph @@ -45,3 +47,53 @@ def graph_client(auth_session) -> DataHubGraph: def pytest_sessionfinish(session, exitstatus): """whole test run finishes.""" send_message(exitstatus) + + +def get_batch_start_end(num_tests: int) -> Tuple[int, int]: + batch_count_env = os.getenv("BATCH_COUNT", 1) + batch_count = int(batch_count_env) + + batch_number_env = os.getenv("BATCH_NUMBER", 0) + batch_number = int(batch_number_env) + + if batch_count == 0 or batch_count > num_tests: + raise ValueError( + f"Invalid batch count {batch_count}: must be >0 and <= {num_tests} (num_tests)" + ) + if batch_number >= batch_count: + raise ValueError( + f"Invalid batch number: {batch_number}, must be less than {batch_count} (zer0 based index)" + ) + + batch_size = round(num_tests / batch_count) + + batch_start = batch_size * batch_number + batch_end = batch_start + batch_size + # We must have exactly as many batches as specified by BATCH_COUNT. + if ( + num_tests - batch_end < batch_size + ): # We must have exactly as many batches as specified by BATCH_COUNT, put the remaining in the last batch. 
+ batch_end = num_tests + + if batch_count > 0: + print(f"Running tests for batch {batch_number} of {batch_count}") + + return batch_start, batch_end + + +def pytest_collection_modifyitems( + session: pytest.Session, config: pytest.Config, items: List[Item] +) -> None: + if os.getenv("TEST_STRATEGY") == "cypress": + return # We launch cypress via pytests, but needs a different batching mechanism at cypress level. + + # If BATCH_COUNT and BATCH_ENV vars are set, splits the pytests to batches and runs filters only the BATCH_NUMBER + # batch for execution. Enables multiple parallel launches. Current implementation assumes all test are of equal + # weight for batching. TODO. A weighted batching method can help make batches more equal sized by cost. + # this effectively is a no-op if BATCH_COUNT=1 + start_index, end_index = get_batch_start_end(num_tests=len(items)) + + items.sort(key=lambda x: x.nodeid) # we want the order to be stable across batches + # replace items with the filtered list + print(f"Running tests for batch {start_index}-{end_index}") + items[:] = items[start_index:end_index] diff --git a/smoke-test/smoke.sh b/smoke-test/smoke.sh index 888a60f488e1f..ec8188ebf5f4d 100755 --- a/smoke-test/smoke.sh +++ b/smoke-test/smoke.sh @@ -34,15 +34,20 @@ source ./set-cypress-creds.sh # set environment variables for the test source ./set-test-env-vars.sh -# no_cypress_suite0, no_cypress_suite1, cypress_suite1, cypress_rest -if [[ -z "${TEST_STRATEGY}" ]]; then - pytest -rP --durations=20 -vv --continue-on-collection-errors --junit-xml=junit.smoke.xml +# TEST_STRATEGY: +# if set to pytests, runs all pytests, skips cypress tests(though cypress test launch is via a pytest). +# if set tp cypress, runs all cypress tests +# if blank, runs all. +# When invoked via the github action, BATCH_COUNT and BATCH_NUM env vars are set to run a slice of those tests per +# worker for parallelism. docker-unified.yml generates a test matrix of pytests/cypress in batches. 
As number of tests +# increase, the batch_count config (in docker-unified.yml) may need adjustment. +if [[ "${TEST_STRATEGY}" == "pytests" ]]; then + #pytests only - github test matrix runs pytests in one of the runners when applicable. + pytest -rP --durations=20 -vv --continue-on-collection-errors --junit-xml=junit.smoke-pytests.xml -k 'not test_run_cypress' +elif [[ "${TEST_STRATEGY}" == "cypress" ]]; then + # run only cypress tests. The test inspects BATCH_COUNT and BATCH_NUMBER and runs only a subset of tests in that batch. + # github workflow test matrix will invoke this in multiple runners for each batch. + pytest -rP --durations=20 -vv --continue-on-collection-errors --junit-xml=junit.smoke-cypress${BATCH_NUMBER}.xml tests/cypress/integration_test.py else - if [ "$TEST_STRATEGY" == "no_cypress_suite0" ]; then - pytest -rP --durations=20 -vv --continue-on-collection-errors --junit-xml=junit.smoke_non_cypress.xml -k 'not test_run_cypress' -m 'not no_cypress_suite1' - elif [ "$TEST_STRATEGY" == "no_cypress_suite1" ]; then - pytest -rP --durations=20 -vv --continue-on-collection-errors --junit-xml=junit.smoke_non_cypress.xml -m 'no_cypress_suite1' - else - pytest -rP --durations=20 -vv --continue-on-collection-errors --junit-xml=junit.smoke_cypress_${TEST_STRATEGY}.xml tests/cypress/integration_test.py - fi + pytest -rP --durations=20 -vv --continue-on-collection-errors --junit-xml=junit.smoke-all.xml fi diff --git a/smoke-test/tests/cypress/integration_test.py b/smoke-test/tests/cypress/integration_test.py index 0d824a96810d0..33c67a923c278 100644 --- a/smoke-test/tests/cypress/integration_test.py +++ b/smoke-test/tests/cypress/integration_test.py @@ -1,10 +1,11 @@ import datetime import os import subprocess -from typing import List, Set +from typing import List import pytest +from conftest import get_batch_start_end from tests.setup.lineage.ingest_time_lineage import ( get_time_lineage_urns, ingest_time_lineage, @@ -169,10 +170,29 @@ def 
ingest_cleanup_data(auth_session, graph_client): print("deleted onboarding data") -def _get_spec_map(items: Set[str]) -> str: - if len(items) == 0: - return "" - return ",".join([f"**/{item}/*.js" for item in items]) +def _get_js_files(base_path: str): + file_paths = [] + for root, dirs, files in os.walk(base_path): + for file in files: + if file.endswith(".js"): + file_paths.append(os.path.relpath(os.path.join(root, file), base_path)) + return sorted(file_paths) # sort to make the order stable across batch runs + + +def _get_cypress_tests_batch(): + """ + Batching is configured via env vars BATCH_COUNT and BATCH_NUMBER. All cypress tests are split into exactly + BATCH_COUNT batches. When BATCH_NUMBER env var is set (zero based index), that batch alone is run. + Github workflow via test_matrix, runs all batches in parallel to speed up the test elapsed time. + If either of these vars are not set, all tests are run sequentially. + :return: + """ + all_tests = _get_js_files("tests/cypress/cypress/e2e") + + batch_start, batch_end = get_batch_start_end(num_tests=len(all_tests)) + + return all_tests[batch_start:batch_end] + # return test_batches[int(batch_number)] #if BATCH_NUMBER was set, we this test just runs that one batch. 
def test_run_cypress(auth_session): @@ -182,24 +202,23 @@ def test_run_cypress(auth_session): test_strategy = os.getenv("TEST_STRATEGY", None) if record_key: record_arg = " --record " - tag_arg = f" --tag {test_strategy} " + batch_number = os.getenv("BATCH_NUMBER") + batch_count = os.getenv("BATCH_COUNT") + if batch_number and batch_count: + batch_suffix = f"-{batch_number}{batch_count}" + else: + batch_suffix = "" + tag_arg = f" --tag {test_strategy}{batch_suffix}" else: record_arg = " " rest_specs = set(os.listdir("tests/cypress/cypress/e2e")) cypress_suite1_specs = {"mutations", "search", "views"} rest_specs.difference_update(set(cypress_suite1_specs)) - strategy_spec_map = { - "cypress_suite1": cypress_suite1_specs, - "cypress_rest": rest_specs, - } print(f"test strategy is {test_strategy}") test_spec_arg = "" - if test_strategy is not None: - specs = strategy_spec_map.get(test_strategy) - assert specs is not None - specs_str = _get_spec_map(specs) - test_spec_arg = f" --spec '{specs_str}' " + specs_str = ",".join([f"**/{f}" for f in _get_cypress_tests_batch()]) + test_spec_arg = f" --spec '{specs_str}' " print("Running Cypress tests with command") command = f"NO_COLOR=1 npx cypress run {record_arg} {test_spec_arg} {tag_arg}"