diff --git a/src/databricks/labs/ucx/assessment/clusters.py b/src/databricks/labs/ucx/assessment/clusters.py
index 0d605deeb7..5199b62cd5 100644
--- a/src/databricks/labs/ucx/assessment/clusters.py
+++ b/src/databricks/labs/ucx/assessment/clusters.py
@@ -49,6 +49,18 @@ class ClusterInfo:
 
     __id_attributes__: ClassVar[tuple[str, ...]] = ("cluster_id",)
 
+    @classmethod
+    def from_cluster_details(cls, details: ClusterDetails):
+        return ClusterInfo(
+            cluster_id=details.cluster_id if details.cluster_id else "",
+            cluster_name=details.cluster_name,
+            policy_id=details.policy_id,
+            spark_version=details.spark_version,
+            creator=details.creator_user_name or None,
+            success=1,
+            failures="[]",
+        )
+
 
 class CheckClusterMixin(CheckInitScriptMixin):
     _ws: WorkspaceClient
@@ -155,7 +167,7 @@ def _crawl(self) -> Iterable[ClusterInfo]:
         all_clusters = list(self._ws.clusters.list())
         return list(self._assess_clusters(all_clusters))
 
-    def _assess_clusters(self, all_clusters):
+    def _assess_clusters(self, all_clusters: Iterable[ClusterDetails]):
         for cluster in all_clusters:
             if cluster.cluster_source == ClusterSource.JOB:
                 continue
@@ -165,15 +177,7 @@ def _assess_clusters(self, all_clusters):
                     f"Cluster {cluster.cluster_id} have Unknown creator, it means that the original creator "
                     f"has been deleted and should be re-created"
                 )
-            cluster_info = ClusterInfo(
-                cluster_id=cluster.cluster_id if cluster.cluster_id else "",
-                cluster_name=cluster.cluster_name,
-                policy_id=cluster.policy_id,
-                spark_version=cluster.spark_version,
-                creator=creator,
-                success=1,
-                failures="[]",
-            )
+            cluster_info = ClusterInfo.from_cluster_details(cluster)
             failures = self._check_cluster_failures(cluster, "cluster")
             if len(failures) > 0:
                 cluster_info.success = 0
diff --git a/src/databricks/labs/ucx/assessment/jobs.py b/src/databricks/labs/ucx/assessment/jobs.py
index 1f87b26770..fe23e42fa0 100644
--- a/src/databricks/labs/ucx/assessment/jobs.py
+++ b/src/databricks/labs/ucx/assessment/jobs.py
@@ -21,6 +21,7 @@
     RunType,
     SparkJarTask,
     SqlTask,
+    Job,
 )
 
 from databricks.labs.ucx.assessment.clusters import CheckClusterMixin
@@ -43,6 +44,17 @@ class JobInfo:
 
     __id_attributes__: ClassVar[tuple[str, ...]] = ("job_id",)
 
+    @classmethod
+    def from_job(cls, job: Job):
+        job_name = job.settings.name if job.settings and job.settings.name else "Unknown"
+        return JobInfo(
+            job_id=str(job.job_id),
+            success=1,
+            failures="[]",
+            job_name=job_name,
+            creator=job.creator_user_name or None,
+        )
+
 
 class JobsMixin:
     @classmethod
@@ -127,17 +139,7 @@ def _prepare(all_jobs) -> tuple[dict[int, set[str]], dict[int, JobInfo]]:
             job_settings = job.settings
             if not job_settings:
                 continue
-            job_name = job_settings.name
-            if not job_name:
-                job_name = "Unknown"
-
-            job_details[job.job_id] = JobInfo(
-                job_id=str(job.job_id),
-                job_name=job_name,
-                creator=creator_user_name,
-                success=1,
-                failures="[]",
-            )
+            job_details[job.job_id] = JobInfo.from_job(job)
         return job_assessment, job_details
 
     def _try_fetch(self) -> Iterable[JobInfo]:
diff --git a/src/databricks/labs/ucx/assessment/sequencing.py b/src/databricks/labs/ucx/assessment/sequencing.py
new file mode 100644
index 0000000000..137752d84a
--- /dev/null
+++ b/src/databricks/labs/ucx/assessment/sequencing.py
@@ -0,0 +1,276 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from collections.abc import Iterable
+from dataclasses import dataclass
+from pathlib import Path
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service import jobs
+
+from databricks.labs.blueprint.paths import WorkspacePath
+
+from databricks.labs.ucx.assessment.clusters import ClusterOwnership, ClusterInfo
+from databricks.labs.ucx.assessment.jobs import JobOwnership, JobInfo
+from databricks.labs.ucx.framework.owners import AdministratorLocator, WorkspacePathOwnership
+from databricks.labs.ucx.source_code.graph import DependencyGraph
+from databricks.labs.ucx.source_code.path_lookup import PathLookup
+from databricks.labs.ucx.source_code.used_table import UsedTablesCrawler
+
+
+@dataclass
+class MigrationStep:
+    step_id: int
+    step_number: int
+    object_type: str
+    object_id: str
+    object_name: str
+    object_owner: str
+    required_step_ids: list[int]
+
+    @property
+    def key(self) -> tuple[str, str]:
+        return self.object_type, self.object_id
+
+
+@dataclass
+class MigrationNode:
+    node_id: int
+    object_type: str
+    object_id: str
+    object_name: str
+    object_owner: str
+
+    @property
+    def key(self) -> tuple[str, str]:
+        return self.object_type, self.object_id
+
+    def as_step(self, step_number: int, required_step_ids: list[int]) -> MigrationStep:
+        return MigrationStep(
+            step_id=self.node_id,
+            step_number=step_number,
+            object_type=self.object_type,
+            object_id=self.object_id,
+            object_name=self.object_name,
+            object_owner=self.object_owner,
+            required_step_ids=required_step_ids,
+        )
+
+
+class MigrationSequencer:
+
+    def __init__(
+        self,
+        ws: WorkspaceClient,
+        path_lookup: PathLookup,
+        admin_locator: AdministratorLocator,
+        used_tables_crawler: UsedTablesCrawler,
+    ):
+        self._ws = ws
+        self._path_lookup = path_lookup
+        self._admin_locator = admin_locator
+        self._used_tables_crawler = used_tables_crawler
+        self._last_node_id = 0
+        self._nodes: dict[tuple[str, str], MigrationNode] = {}
+        self._outgoing: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
+
+    def register_workflow_task(self, task: jobs.Task, job: jobs.Job, graph: DependencyGraph) -> MigrationNode:
+        task_id = f"{job.job_id}/{task.task_key}"
+        task_node = self._nodes.get(("TASK", task_id), None)
+        if task_node:
+            return task_node
+        job_node = self.register_workflow_job(job)
+        self._last_node_id += 1
+        task_node = MigrationNode(
+            node_id=self._last_node_id,
+            object_type="TASK",
+            object_id=task_id,
+            object_name=task.task_key,
+            object_owner=job_node.object_owner,  # no task owner so use job one
+        )
+        self._nodes[task_node.key] = task_node
+        self._outgoing[task_node.key].add(job_node.key)
+        if task.existing_cluster_id:
+            cluster_node = self.register_cluster(task.existing_cluster_id)
+            if cluster_node:
+                self._outgoing[task_node.key].add(cluster_node.key)
+                # also make the cluster dependent on the job
+                self._outgoing[job_node.key].add(cluster_node.key)
+        graph.visit(self._visit_dependency, None)
+        return task_node
+
+    def _visit_dependency(self, graph: DependencyGraph) -> bool | None:
+        lineage = graph.dependency.lineage[-1]
+        parent_node = self._nodes[(lineage.object_type, lineage.object_id)]
+        for dependency in graph.local_dependencies:
+            lineage = dependency.lineage[-1]
+            self.register_dependency(parent_node, lineage.object_type, lineage.object_id)
+        # TODO tables and dfsas
+        return False
+
+    def register_dependency(self, parent_node: MigrationNode | None, object_type: str, object_id: str) -> MigrationNode:
+        dependency_node = self._nodes.get((object_type, object_id), None)
+        if not dependency_node:
+            dependency_node = self._create_dependency_node(object_type, object_id)
+            list(self._register_used_tables_for(dependency_node))
+        if parent_node:
+            self._outgoing[dependency_node.key].add(parent_node.key)
+        return dependency_node
+
+    def _create_dependency_node(self, object_type: str, object_id: str) -> MigrationNode:
+        object_name: str = ""
+        object_owner: str = ""
+        if object_type in {"NOTEBOOK", "FILE"}:
+            path = Path(object_id)
+            for library_root in self._path_lookup.library_roots:
+                if not path.is_relative_to(library_root):
+                    continue
+                object_name = path.relative_to(library_root).as_posix()
+                break
+            ws_path = WorkspacePath(self._ws, object_id)
+            object_owner = WorkspacePathOwnership(self._admin_locator, self._ws).owner_of(ws_path)
+        else:
+            raise ValueError(f"{object_type} not supported yet!")
+        self._last_node_id += 1
+        dependency_node = MigrationNode(
+            node_id=self._last_node_id,
+            object_type=object_type,
+            object_id=object_id,
+            object_name=object_name,
+            object_owner=object_owner,
+        )
+        self._nodes[dependency_node.key] = dependency_node
+        return dependency_node
+
+    def _register_used_tables_for(self, parent_node: MigrationNode) -> Iterable[MigrationNode]:
+        if parent_node.object_type not in {"NOTEBOOK", "FILE"}:
+            return
+        used_tables = self._used_tables_crawler.for_lineage(parent_node.object_type, parent_node.object_id)
+        for used_table in used_tables:
+            self._last_node_id += 1
+            table_node = MigrationNode(
+                node_id=self._last_node_id,
+                object_type="TABLE",
+                object_id=used_table.fullname,
+                object_name=used_table.fullname,
+                object_owner="",  # TODO
+            )
+            self._nodes[table_node.key] = table_node
+            self._outgoing[table_node.key].add(parent_node.key)
+            yield table_node
+
+    def register_workflow_job(self, job: jobs.Job) -> MigrationNode:
+        job_node = self._nodes.get(("WORKFLOW", str(job.job_id)), None)
+        if job_node:
+            return job_node
+        self._last_node_id += 1
+        job_name = job.settings.name if job.settings and job.settings.name else str(job.job_id)
+        job_node = MigrationNode(
+            node_id=self._last_node_id,
+            object_type="WORKFLOW",
+            object_id=str(job.job_id),
+            object_name=job_name,
+            object_owner=JobOwnership(self._admin_locator).owner_of(JobInfo.from_job(job)),
+        )
+        self._nodes[job_node.key] = job_node
+        if job.settings and job.settings.job_clusters:
+            for job_cluster in job.settings.job_clusters:
+                cluster_node = self.register_job_cluster(job_cluster)
+                if cluster_node:
+                    self._outgoing[job_node.key].add(cluster_node.key)
+        return job_node
+
+    def register_job_cluster(self, cluster: jobs.JobCluster) -> MigrationNode | None:
+        if cluster.new_cluster:
+            return None
+        return self.register_cluster(cluster.job_cluster_key)
+
+    def register_cluster(self, cluster_id: str) -> MigrationNode:
+        cluster_node = self._nodes.get(("CLUSTER", cluster_id), None)
+        if cluster_node:
+            return cluster_node
+        details = self._ws.clusters.get(cluster_id)
+        object_name = details.cluster_name if details and details.cluster_name else cluster_id
+        self._last_node_id += 1
+        cluster_node = MigrationNode(
+            node_id=self._last_node_id,
+            object_type="CLUSTER",
+            object_id=cluster_id,
+            object_name=object_name,
+            object_owner=ClusterOwnership(self._admin_locator).owner_of(ClusterInfo.from_cluster_details(details)),
+        )
+        self._nodes[cluster_node.key] = cluster_node
+        # TODO register warehouses and policies
+        return cluster_node
+
+    def generate_steps(self) -> Iterable[MigrationStep]:
+        """The algorithm below is adapted from Kahn's topological sort.
+        The differences are as follows:
+        1) we want the same step number for all nodes with the same dependency depth,
+           so instead of pushing 'leaf' nodes to a queue, we fetch them again once all current 'leaf' nodes are processed
+           (these are transient 'leaf' nodes, i.e. they only become 'leaf' during processing)
+        2) Kahn only supports DAGs, but Python code allows cyclic dependencies (A -> B -> C -> A is not a DAG),
+           so when fetching 'leaf' nodes we relax the 0-incoming-vertex rule in order
+           to avoid an infinite loop. We also avoid side effects (such as negative counts).
+        This algorithm works correctly for simple cases, but is not tested on large trees.
+        """
+        incoming_keys = self._collect_incoming_keys()
+        incoming_counts = self._compute_incoming_counts(incoming_keys)
+        step_number = 1
+        sorted_steps: list[MigrationStep] = []
+        while len(incoming_counts) > 0:
+            leaf_keys = self._get_leaf_keys(incoming_counts)
+            for leaf_key in leaf_keys:
+                del incoming_counts[leaf_key]
+                sorted_steps.append(
+                    self._nodes[leaf_key].as_step(step_number, list(self._required_step_ids(incoming_keys[leaf_key])))
+                )
+                self._on_leaf_key_processed(leaf_key, incoming_counts)
+            step_number += 1
+        return sorted_steps
+
+    def _on_leaf_key_processed(self, leaf_key: tuple[str, str], incoming_counts: dict[tuple[str, str], int]):
+        for dependency_key in self._outgoing[leaf_key]:
+            # prevent re-instantiation of already deleted keys
+            if dependency_key not in incoming_counts:
+                continue
+            # prevent negative count with cyclic dependencies
+            if incoming_counts[dependency_key] > 0:
+                incoming_counts[dependency_key] -= 1
+
+    def _collect_incoming_keys(self) -> dict[tuple[str, str], set[tuple[str, str]]]:
+        result: dict[tuple[str, str], set[tuple[str, str]]] = defaultdict(set)
+        for source, outgoing in self._outgoing.items():
+            for target in outgoing:
+                result[target].add(source)
+        return result
+
+    def _required_step_ids(self, required_step_keys: set[tuple[str, str]]) -> Iterable[int]:
+        for source_key in required_step_keys:
+            yield self._nodes[source_key].node_id
+
+    def _compute_incoming_counts(
+        self, incoming: dict[tuple[str, str], set[tuple[str, str]]]
+    ) -> dict[tuple[str, str], int]:
+        result = defaultdict(int)
+        for node_key in self._nodes:
+            result[node_key] = len(incoming[node_key])
+        return result
+
+    @classmethod
+    def _get_leaf_keys(cls, incoming_counts: dict[tuple[str, str], int]) -> Iterable[tuple[str, str]]:
+        max_count = 0
+        leaf_keys = list(cls._yield_leaf_keys(incoming_counts, max_count))
+        # if we're not finding nodes with 0 incoming counts, it's likely caused by cyclic dependencies
+        # in which case it's safe to process nodes with a higher incoming count
+        while not leaf_keys:
+            max_count += 1
+            leaf_keys = list(cls._yield_leaf_keys(incoming_counts, max_count))
+        return leaf_keys
+
+    @classmethod
+    def _yield_leaf_keys(cls, incoming_counts: dict[tuple[str, str], int], max_count: int) -> Iterable[tuple[str, str]]:
+        for node_key, incoming_count in incoming_counts.items():
+            if incoming_count > max_count:
+                continue
+            yield node_key
diff --git a/src/databricks/labs/ucx/hive_metastore/tables.py b/src/databricks/labs/ucx/hive_metastore/tables.py
index 08f1864586..1258706212 100644
--- a/src/databricks/labs/ucx/hive_metastore/tables.py
+++ b/src/databricks/labs/ucx/hive_metastore/tables.py
@@ -15,6 +15,7 @@
 from databricks.labs.lsql.backends import SqlBackend
 from databricks.sdk.errors import NotFound
 
+from databricks.labs.ucx.source_code.base import UsedTable
 from databricks.labs.ucx.framework.crawlers import CrawlerBase
 from databricks.labs.ucx.framework.utils import escape_sql_identifier
 
@@ -86,6 +87,16 @@ def __post_init__(self) -> None:
         if isinstance(self.table_format, str):  # Should not happen according to type hint, still safer
             self.table_format = self.table_format.upper()
 
+    @staticmethod
+    def from_used_table(used_table: UsedTable):
+        return Table(
+            catalog=used_table.catalog_name,
+            database=used_table.schema_name,
+            name=used_table.table_name,
+            object_type="UNKNOWN",
+            table_format="UNKNOWN",
+        )
+
     @property
     def is_delta(self) -> bool:
         if self.table_format is None:
diff --git a/src/databricks/labs/ucx/source_code/base.py b/src/databricks/labs/ucx/source_code/base.py
index 659b38b2b7..cfc8cff773 100644
--- a/src/databricks/labs/ucx/source_code/base.py
+++ b/src/databricks/labs/ucx/source_code/base.py
@@ -266,6 +266,10 @@ def parse(cls, value: str, default_schema: str, is_read=True, is_write=False) ->
     is_read: bool = True
     is_write: bool = False
 
+    @property
+    def fullname(self) -> str:
+        return f"{self.catalog_name}.{self.schema_name}.{self.table_name}"
+
 
 class TableCollector(ABC):
diff --git a/src/databricks/labs/ucx/source_code/jobs.py b/src/databricks/labs/ucx/source_code/jobs.py
index 60594588bb..e9791595b3 100644
--- a/src/databricks/labs/ucx/source_code/jobs.py
+++ b/src/databricks/labs/ucx/source_code/jobs.py
@@ -78,8 +78,8 @@ def as_message(self) -> str:
 
 class WorkflowTask(Dependency):
 
-    def __init__(self, ws: WorkspaceClient, task: jobs.Task, job: jobs.Job):
-        loader = WrappingLoader(WorkflowTaskContainer(ws, task, job))
+    def __init__(self, ws: WorkspaceClient, task: jobs.Task, job: jobs.Job, cache: WorkspaceCache | None = None):
+        loader = WrappingLoader(WorkflowTaskContainer(ws, task, job, cache))
         super().__init__(loader, Path(f'/jobs/{task.task_key}'), inherits_context=False)
         self._task = task
         self._job = job
@@ -99,11 +99,11 @@ def lineage(self) -> list[LineageAtom]:
 
 class WorkflowTaskContainer(SourceContainer):
 
-    def __init__(self, ws: WorkspaceClient, task: jobs.Task, job: jobs.Job):
+    def __init__(self, ws: WorkspaceClient, task: jobs.Task, job: jobs.Job, cache: WorkspaceCache | None = None):
         self._task = task
         self._job = job
         self._ws = ws
-        self._cache = WorkspaceCache(ws)
+        self._cache = cache or WorkspaceCache(ws)
         self._named_parameters: dict[str, str] | None = {}
         self._parameters: list[str] | None = []
         self._spark_conf: dict[str, str] | None = {}
diff --git a/src/databricks/labs/ucx/source_code/used_table.py b/src/databricks/labs/ucx/source_code/used_table.py
index b5cdb77c0b..5fd38f9fde 100644
--- a/src/databricks/labs/ucx/source_code/used_table.py
+++ b/src/databricks/labs/ucx/source_code/used_table.py
@@ -52,3 +52,10 @@ def _try_fetch(self) -> Iterable[UsedTable]:
 
     def _crawl(self) -> Iterable[UsedTable]:
         return []  # TODO raise NotImplementedError() once CrawlerBase supports empty snapshots
+
+    def for_lineage(self, object_type: str, object_id: str) -> Iterable[UsedTable]:
+        sql = f"SELECT * FROM ( \
+            SELECT *, explode(source_lineage) as lineage FROM {escape_sql_identifier(self.full_name)} \
+        ) where lineage.object_type = '{object_type}' and lineage.object_id = '{object_id}'"
+        for row in self._sql_backend.fetch(sql):
+            yield self._klass.from_dict(row.as_dict())
diff --git a/tests/unit/assessment/test_sequencing.py b/tests/unit/assessment/test_sequencing.py
new file mode 100644
index 0000000000..090e53719f
--- /dev/null
+++ b/tests/unit/assessment/test_sequencing.py
@@ -0,0 +1,163 @@
+import dataclasses
+from datetime import datetime
+from unittest.mock import create_autospec
+
+from pathlib import Path
+
+from databricks.sdk.service import iam, jobs
+
+from databricks.sdk.service.compute import ClusterDetails
+from databricks.sdk.service.jobs import NotebookTask
+
+from databricks.labs.ucx.assessment.sequencing import MigrationSequencer, MigrationStep
+from databricks.labs.ucx.framework.owners import AdministratorLocator, AdministratorFinder
+from databricks.labs.ucx.mixins.cached_workspace_path import WorkspaceCache
+from databricks.labs.ucx.source_code.base import CurrentSessionState, UsedTable, LineageAtom
+from databricks.labs.ucx.source_code.graph import DependencyGraph, Dependency
+from databricks.labs.ucx.source_code.jobs import WorkflowTask
+from databricks.labs.ucx.source_code.linters.files import FileLoader
+from databricks.labs.ucx.source_code.used_table import UsedTablesCrawler
+
+
+def admin_locator(ws, user_name: str):
+    admin_finder = create_autospec(AdministratorFinder)
+    admin_user = iam.User(user_name=user_name, active=True, roles=[iam.ComplexValue(value="account_admin")])
+    admin_finder.find_admin_users.return_value = (admin_user,)
+    return AdministratorLocator(ws, finders=[lambda _ws: admin_finder])
+
+
+def test_sequencer_builds_cluster_and_children_from_task(ws, simple_dependency_resolver, mock_path_lookup):
+    ws.clusters.get.return_value = ClusterDetails(cluster_name="my-cluster", creator_user_name="John Doe")
+    task = jobs.Task(task_key="test-task", existing_cluster_id="cluster-123")
+    settings = jobs.JobSettings(name="test-job", tasks=[task])
+    job = jobs.Job(job_id=1234, settings=settings)
+    ws.jobs.get.return_value = job
+    dependency = WorkflowTask(ws, task, job)
+    graph = DependencyGraph(dependency, None, simple_dependency_resolver, mock_path_lookup, CurrentSessionState())
+    used_tables_crawler = create_autospec(UsedTablesCrawler)
+    used_tables_crawler.assert_not_called()
+    sequencer = MigrationSequencer(ws, mock_path_lookup, admin_locator(ws, "John Doe"), used_tables_crawler)
+    sequencer.register_workflow_task(task, job, graph)
+    steps = list(sequencer.generate_steps())
+    step = steps[-1]
+    # we don't know the ids of the steps, so let's zero them
+    step = dataclasses.replace(step, step_id=0, required_step_ids=[0] * len(step.required_step_ids))
+    assert step == MigrationStep(
+        step_id=0,
+        step_number=3,
+        object_type="CLUSTER",
+        object_id="cluster-123",
+        object_name="my-cluster",
+        object_owner="John Doe",
+        required_step_ids=[0, 0],
+    )
+
+
+def test_sequencer_builds_steps_from_dependency_graph(ws, simple_dependency_resolver, mock_path_lookup):
+    functional = mock_path_lookup.resolve(Path("functional"))
+    mock_path_lookup.append_path(functional)
+    mock_path_lookup = mock_path_lookup.change_directory(functional)
+    notebook_path = Path("grand_parent_that_imports_parent_that_magic_runs_child.py")
+    task = jobs.Task(
+        task_key="test-task",
+        existing_cluster_id="cluster-123",
+        notebook_task=NotebookTask(notebook_path=notebook_path.as_posix()),
+    )
+    settings = jobs.JobSettings(name="test-job", tasks=[task])
+    job = jobs.Job(job_id=1234, settings=settings)
+    ws.jobs.get.return_value = job
+    ws_cache = create_autospec(WorkspaceCache)
+    ws_cache.get_workspace_path.side_effect = Path
+    dependency = WorkflowTask(ws, task, job, ws_cache)
+    container = dependency.load(mock_path_lookup)
+    graph = DependencyGraph(dependency, None, simple_dependency_resolver, mock_path_lookup, CurrentSessionState())
+    problems = container.build_dependency_graph(graph)
+    assert not problems
+    used_tables_crawler = create_autospec(UsedTablesCrawler)
+    used_tables_crawler.assert_not_called()
+    sequencer = MigrationSequencer(ws, mock_path_lookup, admin_locator(ws, "John Doe"), used_tables_crawler)
+    sequencer.register_workflow_task(task, job, graph)
+    all_steps = list(sequencer.generate_steps())
+    # expected order of steps: child < parent < grand-parent < TASK
+    parent_name = "parent_that_magic_runs_child_that_uses_value_from_parent.py"
+    steps = [
+        next((step for step in all_steps if step.object_name == "_child_that_uses_value_from_parent.py"), None),
+        next((step for step in all_steps if step.object_name == parent_name), None),
+        next((step for step in all_steps if step.object_name == notebook_path.as_posix()), None),
+        next((step for step in all_steps if step.object_type == "TASK"), None),
+    ]
+    # ensure steps have a consistent step_number
+    for i in range(0, len(steps) - 1):
+        assert steps[i]
+        assert steps[i].step_number < steps[i + 1].step_number
+
+
+class _DependencyGraph(DependencyGraph):
+
+    def add_dependency(self, graph: DependencyGraph):
+        self._dependencies[graph.dependency] = graph
+
+
+class _MigrationSequencer(MigrationSequencer):
+
+    def visit_graph(self, graph: DependencyGraph):
+        graph.visit(self._visit_dependency, None)
+
+
+def test_sequencer_supports_cyclic_dependencies(ws, simple_dependency_resolver, mock_path_lookup):
+    root = Dependency(FileLoader(), Path("root.py"))
+    root_graph = _DependencyGraph(root, None, simple_dependency_resolver, mock_path_lookup, CurrentSessionState())
+    child_a = Dependency(FileLoader(), Path("a.py"))
+    child_graph_a = _DependencyGraph(
+        child_a, root_graph, simple_dependency_resolver, mock_path_lookup, CurrentSessionState()
+    )
+    child_b = Dependency(FileLoader(), Path("b.py"))
+    child_graph_b = _DependencyGraph(
+        child_b, root_graph, simple_dependency_resolver, mock_path_lookup, CurrentSessionState()
+    )
+    # root imports a and b
+    root_graph.add_dependency(child_graph_a)
+    root_graph.add_dependency(child_graph_b)
+    # a imports b
+    child_graph_a.add_dependency(child_graph_b)
+    # b imports a (using local import)
+    child_graph_b.add_dependency(child_graph_a)
+    used_tables_crawler = create_autospec(UsedTablesCrawler)
+    used_tables_crawler.assert_not_called()
+    sequencer = _MigrationSequencer(ws, mock_path_lookup, admin_locator(ws, "John Doe"), used_tables_crawler)
+    sequencer.register_dependency(None, root.lineage[-1].object_type, root.lineage[-1].object_id)
+    sequencer.visit_graph(root_graph)
+    steps = list(sequencer.generate_steps())
+    assert len(steps) == 3
+    assert steps[2].object_id == "root.py"
+
+
+def test_sequencer_builds_steps_from_used_tables(ws, simple_dependency_resolver, mock_path_lookup):
+    used_tables_crawler = create_autospec(UsedTablesCrawler)
+    used_tables_crawler.for_lineage.side_effect = lambda object_type, object_id: (
+        []
+        if object_id != "/some-folder/some-notebook"
+        else [
+            UsedTable(
+                source_id="/some-folder/some-notebook",
+                source_timestamp=datetime.now(),
+                source_lineage=[LineageAtom(object_type="NOTEBOOK", object_id="/some-folder/some-notebook")],
+                catalog_name="my-catalog",
+                schema_name="my-schema",
+                table_name="my-table",
+                is_read=False,
+                is_write=False,
+            )
+        ]
+    )
+    sequencer = _MigrationSequencer(ws, mock_path_lookup, admin_locator(ws, "John Doe"), used_tables_crawler)
+    sequencer.register_dependency(None, object_type="FILE", object_id="/some-folder/some-file")
+    all_steps = list(sequencer.generate_steps())
+    assert len(all_steps) == 1
+    sequencer.register_dependency(None, object_type="NOTEBOOK", object_id="/some-folder/some-notebook")
+    all_steps = list(sequencer.generate_steps())
+    assert len(all_steps) == 3
+    step = next((step for step in all_steps if step.object_type == "TABLE"), None)
+    assert step
+    assert step.step_number == 1
+    assert step.object_id == "my-catalog.my-schema.my-table"
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 2c8cbfd3b2..675f2012e6 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -10,13 +10,17 @@
 from databricks.labs.ucx.hive_metastore import TablesCrawler
 from databricks.labs.ucx.hive_metastore.tables import FasterTableScanCrawler
-from databricks.labs.ucx.source_code.graph import BaseNotebookResolver
+from databricks.labs.ucx.source_code.graph import BaseNotebookResolver, DependencyResolver
+from databricks.labs.ucx.source_code.known import KnownList
+from databricks.labs.ucx.source_code.linters.files import ImportFileResolver, FileLoader
+from databricks.labs.ucx.source_code.notebooks.loaders import NotebookResolver, NotebookLoader
 from databricks.labs.ucx.source_code.path_lookup import PathLookup
 from databricks.sdk import AccountClient
 from databricks.sdk.config import Config
 
 from databricks.labs.ucx.config import WorkspaceConfig
 from databricks.labs.ucx.contexts.workflow_task import RuntimeContext
+from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver
 from . import mock_workspace_client
@@ -57,8 +61,10 @@ class CustomIterator:
     def __init__(self, values):
         self._values = iter(values)
         self._has_next = True
+        self._next_value = None
 
-    def hasNext(self):  # pylint: disable=invalid-name
+    # pylint: disable=invalid-name
+    def hasNext(self):
         try:
             self._next_value = next(self._values)
             self._has_next = True
@@ -150,9 +156,11 @@ def inner(cb, **replace) -> RuntimeContext:
         ctx.tables_crawler._spark._jsparkSession.sharedState().externalCatalog().listDatabases.return_value = (
             mock_list_databases_iterator
         )
+        # pylint: disable=protected-access
         ctx.tables_crawler._spark._jsparkSession.sharedState().externalCatalog().listTables.return_value = (
             mock_list_tables_iterator
        )
+        # pylint: disable=protected-access
         ctx.tables_crawler._spark._jsparkSession.sharedState().externalCatalog().getTable.return_value = (
             get_table_mock
         )
@@ -165,8 +173,9 @@ def inner(cb, **replace) -> RuntimeContext:
 
 
 @pytest.fixture
 def acc_client():
-    acc = create_autospec(AccountClient)  # pylint: disable=mock-no-usage
+    acc = create_autospec(AccountClient)
     acc.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123")
+    acc.assert_not_called()
     return acc
 
@@ -201,3 +210,12 @@ def mock_backend() -> MockBackend:
 @pytest.fixture
 def ws():
     return mock_workspace_client()
+
+
+@pytest.fixture
+def simple_dependency_resolver(mock_path_lookup: PathLookup) -> DependencyResolver:
+    allow_list = KnownList()
+    library_resolver = PythonLibraryResolver(allow_list)
+    notebook_resolver = NotebookResolver(NotebookLoader())
+    import_resolver = ImportFileResolver(FileLoader(), allow_list)
+    return DependencyResolver(library_resolver, notebook_resolver, import_resolver, import_resolver, mock_path_lookup)
diff --git a/tests/unit/source_code/conftest.py b/tests/unit/source_code/conftest.py
index 6029ce4d82..9c999d92dc 100644
--- a/tests/unit/source_code/conftest.py
+++ b/tests/unit/source_code/conftest.py
@@ -1,12 +1,6 @@
 import pytest
 
 from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationIndex, TableMigrationStatus
-from databricks.labs.ucx.source_code.graph import DependencyResolver
-from databricks.labs.ucx.source_code.known import KnownList
-from databricks.labs.ucx.source_code.linters.files import ImportFileResolver, FileLoader
-from databricks.labs.ucx.source_code.notebooks.loaders import NotebookLoader, NotebookResolver
-from databricks.labs.ucx.source_code.path_lookup import PathLookup
-from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver
 
 
 @pytest.fixture
@@ -51,12 +45,3 @@ def extended_test_index():
             ),
         ]
     )
-
-
-@pytest.fixture
-def simple_dependency_resolver(mock_path_lookup: PathLookup) -> DependencyResolver:
-    allow_list = KnownList()
-    library_resolver = PythonLibraryResolver(allow_list)
-    notebook_resolver = NotebookResolver(NotebookLoader())
-    import_resolver = ImportFileResolver(FileLoader(), allow_list)
-    return DependencyResolver(library_resolver, notebook_resolver, import_resolver, import_resolver, mock_path_lookup)
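
Note on the sequencing algorithm: the generate_steps docstring in sequencing.py describes a levelled, cycle-tolerant variant of Kahn's topological sort. The following is a minimal standalone sketch of that idea, using plain string keys and a bare dict instead of MigrationNode objects; it is an illustration under those simplifying assumptions, not the shipped implementation. As in MigrationSequencer._outgoing, an edge X -> Y means Y requires X, so X receives an earlier (or equal) step number.

from collections import defaultdict


def sequence(outgoing: dict[str, set[str]]) -> list[tuple[int, str]]:
    """Return (step_number, node) pairs; nodes at the same dependency depth share a step number."""
    nodes = set(outgoing) | {target for targets in outgoing.values() for target in targets}
    # invert the edges to count how many prerequisites point at each node
    incoming: dict[str, set[str]] = defaultdict(set)
    for source, targets in outgoing.items():
        for target in targets:
            incoming[target].add(source)
    counts = {node: len(incoming[node]) for node in nodes}
    steps: list[tuple[int, str]] = []
    step_number = 1
    while counts:
        # relax the 0-incoming-vertex rule when a cycle leaves no true leaves
        threshold = 0
        leaves = sorted(node for node, count in counts.items() if count <= threshold)
        while not leaves:
            threshold += 1
            leaves = sorted(node for node, count in counts.items() if count <= threshold)
        for leaf in leaves:
            del counts[leaf]
            steps.append((step_number, leaf))
            for dependent in outgoing.get(leaf, set()):
                # skip already-processed nodes and never go negative on cycles
                if dependent in counts and counts[dependent] > 0:
                    counts[dependent] -= 1
        step_number += 1
    return steps


# edges mirror the unit tests: task -> job -> cluster, plus a b/c cycle
print(sequence({"task": {"job", "cluster"}, "job": {"cluster"}, "b": {"c"}, "c": {"b"}}))

Running this prints [(1, 'task'), (2, 'job'), (3, 'cluster'), (4, 'b'), (4, 'c')]: task, job and cluster land on successive levels, while the b/c cycle is broken by relaxing the zero-incoming rule so both members share the final step number.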
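
For orientation, here is a rough sketch of how the pieces added in this patch fit together end to end, modelled on the unit tests above. The sequence_job helper and its parameter wiring are hypothetical: in the real application the path lookup, administrator locator, used-tables crawler and dependency resolver would come from the existing UCX context, and error handling is omitted.

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.assessment.sequencing import MigrationSequencer
from databricks.labs.ucx.source_code.base import CurrentSessionState
from databricks.labs.ucx.source_code.graph import DependencyGraph
from databricks.labs.ucx.source_code.jobs import WorkflowTask


def sequence_job(ws: WorkspaceClient, path_lookup, admin_locator, used_tables_crawler, resolver, job_id: int):
    # assumed inputs: path_lookup, admin_locator, used_tables_crawler and resolver come from the UCX context
    job = ws.jobs.get(job_id)
    sequencer = MigrationSequencer(ws, path_lookup, admin_locator, used_tables_crawler)
    tasks = job.settings.tasks if job.settings and job.settings.tasks else []
    for task in tasks:
        dependency = WorkflowTask(ws, task, job)
        graph = DependencyGraph(dependency, None, resolver, path_lookup, CurrentSessionState())
        container = dependency.load(path_lookup)
        if container:
            # populate the graph so notebooks, files and used tables become migration nodes
            container.build_dependency_graph(graph)
        sequencer.register_workflow_task(task, job, graph)
    for step in sequencer.generate_steps():
        print(step.step_number, step.object_type, step.object_id, step.required_step_ids)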