diff --git a/README.md b/README.md index 95fa73ab..54bf2d1a 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,8 @@ this extension and your Python version needs to be 3.8 or greater. ## Using ZenML in VSCode - **Manage Server Connections**: Connect or disconnect from ZenML servers and refresh server status. -- **Stack Operations**: View stack details, rename, copy, or set active stacks directly from VSCode. +- **Stack Operations**: View stack details, register, update, delete, copy, or set active stacks directly from VSCode. +- **Stack Component Operations**: View stack component details, register, update, or delete stack components directly from VSCode. - **Pipeline Runs**: Monitor and manage pipeline runs, including deleting runs from the system and rendering DAGs. - **Environment Information**: Get detailed snapshots of the development environment, aiding troubleshooting. diff --git a/bundled/tool/lsp_zenml.py b/bundled/tool/lsp_zenml.py index f8f4e3fe..0677c2c5 100644 --- a/bundled/tool/lsp_zenml.py +++ b/bundled/tool/lsp_zenml.py @@ -32,7 +32,9 @@ from zen_watcher import ZenConfigWatcher from zenml_client import ZenMLClient -zenml_init_error = {"error": "ZenML is not initialized. Please check ZenML version requirements."} +zenml_init_error = { + "error": "ZenML is not initialized. Please check ZenML version requirements." +} class ZenLanguageServer(LanguageServer): @@ -58,7 +60,9 @@ async def is_zenml_installed(self) -> bool: if process.returncode == 0: self.show_message_log("✅ ZenML installation check: Successful.") return True - self.show_message_log("❌ ZenML installation check failed.", lsp.MessageType.Error) + self.show_message_log( + "❌ ZenML installation check failed.", lsp.MessageType.Error + ) return False except Exception as e: self.show_message_log( @@ -93,7 +97,9 @@ async def initialize_zenml_client(self): # initialize watcher self.initialize_global_config_watcher() except Exception as e: - self.notify_user(f"Failed to initialize ZenML client: {str(e)}", lsp.MessageType.Error) + self.notify_user( + f"Failed to initialize ZenML client: {str(e)}", lsp.MessageType.Error + ) def initialize_global_config_watcher(self): """Sets up and starts the Global Configuration Watcher.""" @@ -133,7 +139,9 @@ def wrapper(*args, **kwargs): with suppress_stdout_temporarily(): if wrapper_name: - wrapper_instance = getattr(self.zenml_client, wrapper_name, None) + wrapper_instance = getattr( + self.zenml_client, wrapper_name, None + ) if not wrapper_instance: return {"error": f"Wrapper '{wrapper_name}' not found."} return func(wrapper_instance, *args, **kwargs) @@ -177,25 +185,33 @@ def _construct_version_validation_response(self, meets_requirement, version_str) def send_custom_notification(self, method: str, args: dict): """Sends a custom notification to the LSP client.""" - self.show_message_log(f"Sending custom notification: {method} with args: {args}") + self.show_message_log( + f"Sending custom notification: {method} with args: {args}" + ) self.send_notification(method, args) def update_python_interpreter(self, interpreter_path): """Updates the Python interpreter path and handles errors.""" try: self.python_interpreter = interpreter_path - self.show_message_log(f"LSP_Python_Interpreter Updated: {self.python_interpreter}") + self.show_message_log( + f"LSP_Python_Interpreter Updated: {self.python_interpreter}" + ) # pylint: disable=broad-exception-caught except Exception as e: self.show_message_log( f"Failed to update Python interpreter: {str(e)}", lsp.MessageType.Error ) - def notify_user(self, message: str, msg_type: lsp.MessageType = lsp.MessageType.Info): + def notify_user( + self, message: str, msg_type: lsp.MessageType = lsp.MessageType.Info + ): """Logs a message and also notifies the user.""" self.show_message(message, msg_type) - def log_to_output(self, message: str, msg_type: lsp.MessageType = lsp.MessageType.Log) -> None: + def log_to_output( + self, message: str, msg_type: lsp.MessageType = lsp.MessageType.Log + ) -> None: """Log to output.""" self.show_message_log(message, msg_type) @@ -261,6 +277,60 @@ def rename_stack(wrapper_instance, args): def copy_stack(wrapper_instance, args): """Copies a specified ZenML stack to a new stack.""" return wrapper_instance.copy_stack(args) + + @self.command(f"{TOOL_MODULE_NAME}.registerStack") + @self.zenml_command(wrapper_name="stacks_wrapper") + def register_stack(wrapper_instance, args): + """Registers a new ZenML stack.""" + return wrapper_instance.register_stack(args) + + @self.command(f"{TOOL_MODULE_NAME}.updateStack") + @self.zenml_command(wrapper_name="stacks_wrapper") + def update_stack(wrapper_instance, args): + """Updates a specified ZenML stack .""" + return wrapper_instance.update_stack(args) + + @self.command(f"{TOOL_MODULE_NAME}.deleteStack") + @self.zenml_command(wrapper_name="stacks_wrapper") + def delete_stack(wrapper_instance, args): + """Deletes a specified ZenML stack .""" + return wrapper_instance.delete_stack(args) + + @self.command(f"{TOOL_MODULE_NAME}.registerComponent") + @self.zenml_command(wrapper_name="stacks_wrapper") + def register_component(wrapper_instance, args): + """Registers a Zenml stack component""" + return wrapper_instance.register_component(args) + + @self.command(f"{TOOL_MODULE_NAME}.updateComponent") + @self.zenml_command(wrapper_name="stacks_wrapper") + def update_component(wrapper_instance, args): + """Updates a ZenML stack component""" + return wrapper_instance.update_component(args) + + @self.command(f"{TOOL_MODULE_NAME}.deleteComponent") + @self.zenml_command(wrapper_name="stacks_wrapper") + def delete_component(wrapper_instance, args): + """Deletes a specified ZenML stack component""" + return wrapper_instance.delete_component(args) + + @self.command(f"{TOOL_MODULE_NAME}.listComponents") + @self.zenml_command(wrapper_name="stacks_wrapper") + def list_components(wrapper_instance, args): + """Get paginated list of stack components from ZenML""" + return wrapper_instance.list_components(args) + + @self.command(f"{TOOL_MODULE_NAME}.getComponentTypes") + @self.zenml_command(wrapper_name="stacks_wrapper") + def get_component_types(wrapper_instance, args): + """Get list of component types from ZenML""" + return wrapper_instance.get_component_types() + + @self.command(f"{TOOL_MODULE_NAME}.listFlavors") + @self.zenml_command(wrapper_name="stacks_wrapper") + def list_flavors(wrapper_instance, args): + """Get paginated list of component flavors from ZenML""" + return wrapper_instance.list_flavors(args) @self.command(f"{TOOL_MODULE_NAME}.getPipelineRuns") @self.zenml_command(wrapper_name="pipeline_runs_wrapper") @@ -273,13 +343,13 @@ def fetch_pipeline_runs(wrapper_instance, args): def delete_pipeline_run(wrapper_instance, args): """Deletes a specified ZenML pipeline run.""" return wrapper_instance.delete_pipeline_run(args) - + @self.command(f"{TOOL_MODULE_NAME}.getPipelineRun") @self.zenml_command(wrapper_name="pipeline_runs_wrapper") def get_pipeline_run(wrapper_instance, args): """Gets a specified ZenML pipeline run.""" return wrapper_instance.get_pipeline_run(args) - + @self.command(f"{TOOL_MODULE_NAME}.getPipelineRunStep") @self.zenml_command(wrapper_name="pipeline_runs_wrapper") def get_run_step(wrapper_instance, args): @@ -291,9 +361,9 @@ def get_run_step(wrapper_instance, args): def get_run_artifact(wrapper_instance, args): """Gets a specified ZenML pipeline artifact""" return wrapper_instance.get_run_artifact(args) - + @self.command(f"{TOOL_MODULE_NAME}.getPipelineRunDag") @self.zenml_command(wrapper_name="pipeline_runs_wrapper") def get_run_dag(wrapper_instance, args): """Gets graph data for a specified ZenML pipeline run""" - return wrapper_instance.get_pipeline_run_graph(args) \ No newline at end of file + return wrapper_instance.get_pipeline_run_graph(args) diff --git a/bundled/tool/type_hints.py b/bundled/tool/type_hints.py index f5441944..ce524886 100644 --- a/bundled/tool/type_hints.py +++ b/bundled/tool/type_hints.py @@ -1,4 +1,16 @@ -from typing import Any, TypedDict, Dict, List, Union +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +from typing import Any, TypedDict, Dict, List, Optional from uuid import UUID @@ -20,8 +32,6 @@ class GraphEdge(TypedDict): source: str target: str - - class GraphResponse(TypedDict): nodes: List[GraphNode] edges: List[GraphEdge] @@ -37,9 +47,9 @@ class RunStepResponse(TypedDict): id: str status: str author: Dict[str, str] - startTime: Union[str, None] - endTime: Union[str, None] - duration: Union[str, None] + startTime: Optional[str] + endTime: Optional[str] + duration: Optional[str] stackName: str orchestrator: Dict[str, str] pipeline: Dict[str, str] @@ -71,7 +81,7 @@ class ZenmlStoreInfo(TypedDict): class ZenmlStoreConfig(TypedDict): type: str url: str - api_token: Union[str, None] + api_token: Optional[str] class ZenmlServerInfoResp(TypedDict): store_info: ZenmlStoreInfo @@ -84,4 +94,37 @@ class ZenmlGlobalConfigResp(TypedDict): version: str active_stack_id: str active_workspace_name: str - store: ZenmlStoreConfig \ No newline at end of file + store: ZenmlStoreConfig + +class StackComponent(TypedDict): + id: str + name: str + flavor: str + type: str + config: Dict[str, Any] + +class ListComponentsResponse(TypedDict): + index: int + max_size: int + total_pages: int + total: int + items: List[StackComponent] + +class Flavor(TypedDict): + id: str + name: str + type: str + logo_url: str + config_schema: Dict[str, Any] + docs_url: Optional[str] + sdk_docs_url: Optional[str] + connector_type: Optional[str] + connector_resource_type: Optional[str] + connector_resource_id_attr: Optional[str] + +class ListFlavorsResponse(TypedDict): + index: int + max_size: int + total_pages: int + total: int + items: List[Flavor] \ No newline at end of file diff --git a/bundled/tool/zenml_wrappers.py b/bundled/tool/zenml_wrappers.py index b4953ff2..f33fd99c 100644 --- a/bundled/tool/zenml_wrappers.py +++ b/bundled/tool/zenml_wrappers.py @@ -13,9 +13,18 @@ """This module provides wrappers for ZenML configuration and operations.""" import pathlib -from typing import Any, Tuple, Union -from type_hints import GraphResponse, ErrorResponse, RunStepResponse, RunArtifactResponse, ZenmlServerInfoResp, ZenmlGlobalConfigResp +from typing import Any, Tuple, Union, List, Optional, Dict from zenml_grapher import Grapher +from type_hints import ( + GraphResponse, + ErrorResponse, + RunStepResponse, + RunArtifactResponse, + ZenmlServerInfoResp, + ZenmlGlobalConfigResp, + ListComponentsResponse, + ListFlavorsResponse +) class GlobalConfigWrapper: @@ -48,7 +57,9 @@ def get_global_config_directory(self): def RestZenStoreConfiguration(self): """Returns the RestZenStoreConfiguration class for store configuration.""" # pylint: disable=not-callable - return self.lazy_import("zenml.zen_stores.rest_zen_store", "RestZenStoreConfiguration") + return self.lazy_import( + "zenml.zen_stores.rest_zen_store", "RestZenStoreConfiguration" + ) def get_global_config_directory_path(self) -> str: """Get the global configuration directory path. @@ -189,7 +200,9 @@ def get_server_info(self) -> ZenmlServerInfoResp: # Handle both 'store' and 'store_configuration' depending on version store_attr_name = ( - "store_configuration" if hasattr(self.gc, "store_configuration") else "store" + "store_configuration" + if hasattr(self.gc, "store_configuration") + else "store" ) store_config = getattr(self.gc, store_attr_name) @@ -229,7 +242,9 @@ def connect(self, args, **kwargs) -> dict: try: # pylint: disable=not-callable access_token = self.web_login(url=url, verify_ssl=verify_ssl) - self._config_wrapper.set_store_configuration(remote_url=url, access_token=access_token) + self._config_wrapper.set_store_configuration( + remote_url=url, access_token=access_token + ) return {"message": "Connected successfully.", "access_token": access_token} except self.AuthorizationException as e: return {"error": f"Authorization failed: {str(e)}"} @@ -245,7 +260,9 @@ def disconnect(self, args) -> dict: try: # Adjust for changes from 'store' to 'store_configuration' store_attr_name = ( - "store_configuration" if hasattr(self.gc, "store_configuration") else "store" + "store_configuration" + if hasattr(self.gc, "store_configuration") + else "store" ) url = getattr(self.gc, store_attr_name).url store_type = self.BaseZenStore.get_store_type(url) @@ -314,15 +331,21 @@ def fetch_pipeline_runs(self, args): "version": run.body.pipeline.body.version, "stackName": run.body.stack.name, "startTime": ( - run.metadata.start_time.isoformat() if run.metadata.start_time else None + run.metadata.start_time.isoformat() + if run.metadata.start_time + else None ), "endTime": ( - run.metadata.end_time.isoformat() if run.metadata.end_time else None + run.metadata.end_time.isoformat() + if run.metadata.end_time + else None ), "os": run.metadata.client_environment.get("os", "Unknown OS"), "osVersion": run.metadata.client_environment.get( "os_version", - run.metadata.client_environment.get("mac_version", "Unknown Version"), + run.metadata.client_environment.get( + "mac_version", "Unknown Version" + ), ), "pythonVersion": run.metadata.client_environment.get( "python_version", "Unknown" @@ -357,10 +380,10 @@ def delete_pipeline_run(self, args) -> dict: return {"message": f"Pipeline run `{run_id}` deleted successfully."} except self.ZenMLBaseException as e: return {"error": f"Failed to delete pipeline run: {str(e)}"} - + def get_pipeline_run(self, args: Tuple[str]) -> dict: """Gets a ZenML pipeline run. - + Args: args (list): List of arguments. Returns: @@ -371,33 +394,39 @@ def get_pipeline_run(self, args: Tuple[str]) -> dict: run = self.client.get_pipeline_run(run_id, hydrate=True) run_data = { "id": str(run.id), - "name": run.body.pipeline.name, - "status": run.body.status, - "version": run.body.pipeline.body.version, - "stackName": run.body.stack.name, - "startTime": ( - run.metadata.start_time.isoformat() if run.metadata.start_time else None - ), - "endTime": ( - run.metadata.end_time.isoformat() if run.metadata.end_time else None - ), - "os": run.metadata.client_environment.get("os", "Unknown OS"), - "osVersion": run.metadata.client_environment.get( - "os_version", - run.metadata.client_environment.get("mac_version", "Unknown Version"), - ), - "pythonVersion": run.metadata.client_environment.get( - "python_version", "Unknown" + "name": run.body.pipeline.name, + "status": run.body.status, + "version": run.body.pipeline.body.version, + "stackName": run.body.stack.name, + "startTime": ( + run.metadata.start_time.isoformat() + if run.metadata.start_time + else None + ), + "endTime": ( + run.metadata.end_time.isoformat() if run.metadata.end_time else None + ), + "os": run.metadata.client_environment.get("os", "Unknown OS"), + "osVersion": run.metadata.client_environment.get( + "os_version", + run.metadata.client_environment.get( + "mac_version", "Unknown Version" ), + ), + "pythonVersion": run.metadata.client_environment.get( + "python_version", "Unknown" + ), } return run_data except self.ZenMLBaseException as e: return {"error": f"Failed to retrieve pipeline run: {str(e)}"} - - def get_pipeline_run_graph(self, args: Tuple[str]) -> Union[GraphResponse, ErrorResponse]: + + def get_pipeline_run_graph( + self, args: Tuple[str] + ) -> Union[GraphResponse, ErrorResponse]: """Gets a ZenML pipeline run step DAG. - + Args: args (list): List of arguments. Returns: @@ -415,7 +444,7 @@ def get_pipeline_run_graph(self, args: Tuple[str]) -> Union[GraphResponse, Error def get_run_step(self, args: Tuple[str]) -> Union[RunStepResponse, ErrorResponse]: """Gets a ZenML pipeline run step. - + Args: args (list): List of arguments. Returns: @@ -424,7 +453,9 @@ def get_run_step(self, args: Tuple[str]) -> Union[RunStepResponse, ErrorResponse try: step_run_id = args[0] step = self.client.get_run_step(step_run_id, hydrate=True) - run = self.client.get_pipeline_run(step.metadata.pipeline_run_id, hydrate=True) + run = self.client.get_pipeline_run( + step.metadata.pipeline_run_id, hydrate=True + ) step_data = { "name": step.name, @@ -435,18 +466,22 @@ def get_run_step(self, args: Tuple[str]) -> Union[RunStepResponse, ErrorResponse "email": step.body.user.name, }, "startTime": ( - step.metadata.start_time.isoformat() if step.metadata.start_time else None + step.metadata.start_time.isoformat() + if step.metadata.start_time + else None ), "endTime": ( - step.metadata.end_time.isoformat() if step.metadata.end_time else None + step.metadata.end_time.isoformat() + if step.metadata.end_time + else None ), "duration": ( - str(step.metadata.end_time - step.metadata.start_time) if step.metadata.end_time and step.metadata.start_time else None + str(step.metadata.end_time - step.metadata.start_time) + if step.metadata.end_time and step.metadata.start_time + else None ), "stackName": run.body.stack.name, - "orchestrator": { - "runId": str(run.metadata.orchestrator_run_id) - }, + "orchestrator": {"runId": str(run.metadata.orchestrator_run_id)}, "pipeline": { "name": run.body.pipeline.name, "status": run.body.status, @@ -454,15 +489,17 @@ def get_run_step(self, args: Tuple[str]) -> Union[RunStepResponse, ErrorResponse }, "cacheKey": step.metadata.cache_key, "sourceCode": step.metadata.source_code, - "logsUri": step.metadata.logs.body.uri + "logsUri": step.metadata.logs.body.uri, } return step_data except self.ZenMLBaseException as e: return {"error": f"Failed to retrieve pipeline run step: {str(e)}"} - - def get_run_artifact(self, args: Tuple[str]) -> Union[RunArtifactResponse, ErrorResponse]: + + def get_run_artifact( + self, args: Tuple[str] + ) -> Union[RunArtifactResponse, ErrorResponse]: """Gets a ZenML pipeline run artifact. - + Args: args (list): List of arguments. Returns: @@ -528,6 +565,11 @@ def IllegalOperationError(self) -> Any: def StackComponentValidationError(self): """Returns the ZenML StackComponentValidationError class.""" return self.lazy_import("zenml.exceptions", "StackComponentValidationError") + + @property + def StackComponentType(self): + """Returns the ZenML StackComponentType enum.""" + return self.lazy_import("zenml.enums", "StackComponentType") @property def ZenKeyError(self) -> Any: @@ -540,7 +582,9 @@ def fetch_stacks(self, args): return {"error": "Insufficient arguments provided."} page, max_size = args try: - stacks_page = self.client.list_stacks(page=page, size=max_size, hydrate=True) + stacks_page = self.client.list_stacks( + page=page, size=max_size, hydrate=True + ) stacks_data = self.process_stacks(stacks_page.items) return { @@ -653,17 +697,23 @@ def copy_stack(self, args) -> dict: target_stack_name = args[1] if not source_stack_name_or_id or not target_stack_name: - return {"error": "Both source stack name/id and target stack name are required"} + return { + "error": "Both source stack name/id and target stack name are required" + } try: - stack_to_copy = self.client.get_stack(name_id_or_prefix=source_stack_name_or_id) + stack_to_copy = self.client.get_stack( + name_id_or_prefix=source_stack_name_or_id + ) component_mapping = { c_type: [c.id for c in components][0] for c_type, components in stack_to_copy.components.items() if components } - self.client.create_stack(name=target_stack_name, components=component_mapping) + self.client.create_stack( + name=target_stack_name, components=component_mapping + ) return { "message": ( f"Stack `{source_stack_name_or_id}` successfully copied " @@ -675,3 +725,207 @@ def copy_stack(self, args) -> dict: self.StackComponentValidationError, ) as e: return {"error": str(e)} + + def register_stack(self, args: Tuple[str, Dict[str, str]]) -> Dict[str, str]: + """Registers a new ZenML Stack. + + Args: + args (list): List containing the name and chosen components for the stack. + Returns: + Dictionary containing a message relevant to whether the action succeeded or failed + """ + [name, components] = args + + try: + self.client.create_stack(name, components) + return {"message": f"Stack {name} successfully registered"} + except self.ZenMLBaseException as e: + return {"error": str(e)} + + def update_stack(self, args: Tuple[str, str, Dict[str, List[str]]]) -> Dict[str, str]: + """Updates a specified ZenML Stack. + + Args: + args (list): List containing the id of the stack being updated, the new name, and the chosen components. + Returns: + Dictionary containing a message relevant to whether the action succeeded or failed + """ + [id, name, components] = args + + try: + old = self.client.get_stack(id) + if old.name == name: + self.client.update_stack(name_id_or_prefix=id, component_updates=components) + else: + self.client.update_stack(name_id_or_prefix=id, name=name, component_updates=components) + + return {"message": f"Stack {name} successfully updated."} + except self.ZenMLBaseException as e: + return {"error": str(e)} + + def delete_stack(self, args: Tuple[str]) -> Dict[str, str]: + """Deletes a specified ZenML stack. + + Args: + args (list): List containing the id of the stack to delete. + Returns: + Dictionary containing a message relevant to whether the action succeeded or failed + """ + [id] = args + + try: + self.client.delete_stack(id) + + return {"message": f"Stack {id} successfully deleted."} + except self.ZenMLBaseException as e: + return {"error": str(e)} + + def register_component(self, args: Tuple[str, str, str, Dict[str, str]]) -> Dict[str, str]: + """Registers a new ZenML stack component. + + Args: + args (list): List containing the component type, flavor used, name, and configuration of the desired new component. + Returns: + Dictionary containing a message relevant to whether the action succeeded or failed + """ + [component_type, flavor, name, configuration] = args + + try: + self.client.create_stack_component(name, flavor, component_type, configuration) + + return {"message": f"Stack Component {name} successfully registered"} + except self.ZenMLBaseException as e: + return {"error": str(e)} + + def update_component(self, args: Tuple[str, str, str, Dict[str, str]]) -> Dict[str, str]: + """Updates a specified ZenML stack component. + + Args: + args (list): List containing the id, component type, new name, and desired configuration of the desired component. + Returns: + Dictionary containing a message relevant to whether the action succeeded or failed + """ + [id, component_type, name, configuration] = args + + try: + old = self.client.get_stack_component(component_type, id) + + new_name = None if old.name == name else name + + self.client.update_stack_component(id, component_type, name=new_name, configuration=configuration) + + return {"message": f"Stack Component {name} successfully updated"} + except self.ZenMLBaseException as e: + return {"error": str(e)} + + def delete_component(self, args: Tuple[str, str]) -> Dict[str, str]: + """Deletes a specified ZenML stack component. + + Args: + args (list): List containing the id and component type of the desired component. + Returns: + Dictionary containing a message relevant to whether the action succeeded or failed + """ + [id, component_type] = args + + try: + self.client.delete_stack_component(id, component_type) + + return {"mesage": f"Stack Component {id} successfully deleted"} + except self.ZenMLBaseException as e: + return {"error": str(e)} + + def list_components(self, args: Tuple[int, int, Union[str, None]]) -> Union[ListComponentsResponse,ErrorResponse]: + """Lists stack components in a paginated way. + + Args: + args (list): List containing the page, maximum items per page, and an optional type filter used to retrieve expected components. + Returns: + A Dictionary containing the paginated results or an error message specifying why the action failed. + """ + if len(args) < 2: + return {"error": "Insufficient arguments provided."} + + page = args[0] + max_size = args[1] + filter = None + + if len(args) >= 3: + filter = args[2] + + try: + components = self.client.list_stack_components(page=page, size=max_size, type=filter, hydrate=True) + + return { + "index": components.index, + "max_size": components.max_size, + "total_pages": components.total_pages, + "total": components.total, + "items": [ + { + "id": str(item.id), + "name": item.name, + "flavor": item.body.flavor, + "type": item.body.type, + "config": item.metadata.configuration, + } + for item in components.items + ], + } + except self.ZenMLBaseException as e: + return {"error": f"Failed to retrieve list of stack components: {str(e)}"} + + def get_component_types(self) -> Union[List[str], ErrorResponse]: + """Gets a list of all component types. + + Returns: + A list of component types or a dictionary containing an error message specifying why the action failed. + """ + try: + return self.StackComponentType.values() + except self.ZenMLBaseException as e: + return {"error": f"Failed to retrieve list of component types: {str(e)}"} + + def list_flavors(self, args: Tuple[int, int, Optional[str]]) -> Union[ListFlavorsResponse, ErrorResponse]: + """Lists stack component flavors in a paginated way. + + Args: + args (list): List containing page, max items per page, and an optional component type filter used to retrieve expected component flavors. + Returns: + A Dictionary containing the paginated results or an error message specifying why the action failed. + """ + if len(args) < 2: + return {"error": "Insufficient arguments provided."} + + page = args[0] + max_size = args[1] + filter = None + if len(args) >= 3: + filter = args[2] + + try: + flavors = self.client.list_flavors(page=page, size=max_size, type=filter, hydrate=True) + + return { + "index": flavors.index, + "max_size": flavors.max_size, + "total_pages": flavors.total_pages, + "total": flavors.total, + "items": [ + { + "id": str(flavor.id), + "name": flavor.name, + "type": flavor.body.type, + "logo_url": flavor.body.logo_url, + "config_schema": flavor.metadata.config_schema, + "docs_url": flavor.metadata.docs_url, + "sdk_docs_url": flavor.metadata.sdk_docs_url, + "connector_type": flavor.metadata.connector_type, + "connector_resource_type": flavor.metadata.connector_resource_type, + "connector_resource_id_attr": flavor.metadata.connector_resource_id_attr, + } for flavor in flavors.items + ] + } + + except self.ZenMLBaseException as e: + return {"error": f"Failed to retrieve list of flavors: {str(e)}"} \ No newline at end of file diff --git a/package-lock.json b/package-lock.json index 8960e019..df46f7e2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -15,6 +15,7 @@ "axios": "^1.6.7", "dagre": "^0.8.5", "fs-extra": "^11.2.0", + "hbs": "^4.2.0", "svg-pan-zoom": "github:bumbu/svg-pan-zoom", "svgdom": "^0.1.19", "vscode-languageclient": "^9.0.1" @@ -22,6 +23,7 @@ "devDependencies": { "@types/dagre": "^0.7.52", "@types/fs-extra": "^11.0.4", + "@types/hbs": "^4.0.4", "@types/mocha": "^10.0.6", "@types/node": "^18.19.18", "@types/sinon": "^17.0.3", @@ -514,6 +516,16 @@ "@types/node": "*" } }, + "node_modules/@types/hbs": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/@types/hbs/-/hbs-4.0.4.tgz", + "integrity": "sha512-GH3SIb2tzDBnTByUSOIVcD6AcLufnydBllTuFAIAGMhqPNbz8GL4tLryVdNqhq0NQEb5mVpu2FJOrUeqwJrPtg==", + "dev": true, + "license": "MIT", + "dependencies": { + "handlebars": "^4.1.0" + } + }, "node_modules/@types/json-schema": { "version": "7.0.15", "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", @@ -2579,6 +2591,12 @@ "unicode-trie": "^2.0.0" } }, + "node_modules/foreachasync": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/foreachasync/-/foreachasync-3.0.0.tgz", + "integrity": "sha512-J+ler7Ta54FwwNcx6wQRDhTIbNeyDcARMkOcguEqnEdtm0jKvN3Li3PDAb2Du3ubJYEWfYL83XMROXdsXAXycw==", + "license": "Apache2" + }, "node_modules/foreground-child": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.1.1.tgz", @@ -2798,6 +2816,36 @@ "lodash": "^4.17.15" } }, + "node_modules/handlebars": { + "version": "4.7.7", + "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.7.tgz", + "integrity": "sha512-aAcXm5OAfE/8IXkcZvCepKU3VzW1/39Fb5ZuqMtgI/hT8X2YgoMvBY5dLhq/cpOvw7Lk1nK/UF71aLG/ZnVYRA==", + "license": "MIT", + "dependencies": { + "minimist": "^1.2.5", + "neo-async": "^2.6.0", + "source-map": "^0.6.1", + "wordwrap": "^1.0.0" + }, + "bin": { + "handlebars": "bin/handlebars" + }, + "engines": { + "node": ">=0.4.7" + }, + "optionalDependencies": { + "uglify-js": "^3.1.4" + } + }, + "node_modules/handlebars/node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/has-flag": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", @@ -2855,6 +2903,20 @@ "node": ">= 0.4" } }, + "node_modules/hbs": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.2.0.tgz", + "integrity": "sha512-dQwHnrfWlTk5PvG9+a45GYpg0VpX47ryKF8dULVd6DtwOE6TEcYQXQ5QM6nyOx/h7v3bvEQbdn19EDAcfUAgZg==", + "license": "MIT", + "dependencies": { + "handlebars": "4.7.7", + "walk": "2.3.15" + }, + "engines": { + "node": ">= 0.8", + "npm": "1.2.8000 || >= 1.4.16" + } + }, "node_modules/he": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/he/-/he-1.2.0.tgz", @@ -3645,8 +3707,6 @@ "version": "1.2.8", "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", "integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==", - "dev": true, - "optional": true, "funding": { "url": "https://github.com/sponsors/ljharb" } @@ -3929,8 +3989,7 @@ "node_modules/neo-async": { "version": "2.6.2", "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz", - "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==", - "dev": true + "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==" }, "node_modules/nise": { "version": "5.1.9", @@ -5481,6 +5540,19 @@ "integrity": "sha512-8Y75pvTYkLJW2hWQHXxoqRgV7qb9B+9vFEtidML+7koHUFapnVJAZ6cKs+Qjz5Aw3aZWHMC6u0wJE3At+nSGwA==", "dev": true }, + "node_modules/uglify-js": { + "version": "3.19.0", + "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.19.0.tgz", + "integrity": "sha512-wNKHUY2hYYkf6oSFfhwwiHo4WCHzHmzcXsqXYTN9ja3iApYIFbb2U6ics9hBcYLHcYGQoAlwnZlTrf3oF+BL/Q==", + "license": "BSD-2-Clause", + "optional": true, + "bin": { + "uglifyjs": "bin/uglifyjs" + }, + "engines": { + "node": ">=0.8.0" + } + }, "node_modules/underscore": { "version": "1.13.6", "resolved": "https://registry.npmjs.org/underscore/-/underscore-1.13.6.tgz", @@ -5627,6 +5699,15 @@ "resolved": "https://registry.npmjs.org/vscode-languageserver-types/-/vscode-languageserver-types-3.17.5.tgz", "integrity": "sha512-Ld1VelNuX9pdF39h2Hgaeb5hEZM2Z3jUrrMgWQAu82jMtZp7p3vJT3BzToKtZI7NgQssZje5o0zryOrhQvzQAg==" }, + "node_modules/walk": { + "version": "2.3.15", + "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.15.tgz", + "integrity": "sha512-4eRTBZljBfIISK1Vnt69Gvr2w/wc3U6Vtrw7qiN5iqYJPH7LElcYh/iU4XWhdCy2dZqv1ToMyYlybDylfG/5Vg==", + "license": "(MIT OR Apache-2.0)", + "dependencies": { + "foreachasync": "^3.0.0" + } + }, "node_modules/watchpack": { "version": "2.4.1", "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.1.tgz", @@ -5807,6 +5888,12 @@ "integrity": "sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==", "dev": true }, + "node_modules/wordwrap": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-1.0.0.tgz", + "integrity": "sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==", + "license": "MIT" + }, "node_modules/workerpool": { "version": "6.2.1", "resolved": "https://registry.npmjs.org/workerpool/-/workerpool-6.2.1.tgz", diff --git a/package.json b/package.json index 99bfd6b0..bfcdf6fb 100644 --- a/package.json +++ b/package.json @@ -185,8 +185,14 @@ "category": "ZenML Stacks" }, { - "command": "zenml.renameStack", - "title": "Rename Stack", + "command": "zenml.registerStack", + "title": "Register New Stack", + "icon": "$(add)", + "category": "ZenML Stacks" + }, + { + "command": "zenml.updateStack", + "title": "Update Stack", "icon": "$(edit)", "category": "ZenML Stacks" }, @@ -200,13 +206,49 @@ "command": "zenml.copyStack", "title": "Copy Stack", "icon": "$(copy)", - "category": "ZenML" + "category": "ZenML Stacks" }, { "command": "zenml.goToStackUrl", "title": "Go to URL", "icon": "$(globe)", - "category": "ZenML" + "category": "ZenML Stacks" + }, + { + "command": "zenml.deleteStack", + "title": "Delete Stack", + "icon": "$(trash)", + "category": "ZenML Stacks" + }, + { + "command": "zenml.setComponentItemsPerPage", + "title": "Set Components Per Page", + "icon": "$(layers)", + "category": "ZenML Components" + }, + { + "command": "zenml.refreshComponentView", + "title": "Refresh Component View", + "icon": "$(refresh)", + "category": "ZenML Components" + }, + { + "command": "zenml.registerComponent", + "title": "Register New Component", + "icon": "$(add)", + "category": "ZenML Components" + }, + { + "command": "zenml.updateComponent", + "title": "Update Component", + "icon": "$(edit)", + "category": "ZenML Components" + }, + { + "command": "zenml.deleteComponent", + "title": "Delete Component", + "icon": "$(trash)", + "category": "ZenML Components" }, { "command": "zenml.setPipelineRunsPerPage", @@ -285,6 +327,11 @@ "name": "Stacks", "icon": "$(layers)" }, + { + "id": "zenmlComponentView", + "name": "Stack Components", + "icon": "$(extensions)" + }, { "id": "zenmlPipelineView", "name": "Pipeline Runs", @@ -322,14 +369,34 @@ }, { "when": "stackCommandsRegistered && view == zenmlStackView", - "command": "zenml.setStackItemsPerPage", + "command": "zenml.registerStack", "group": "navigation@1" }, + { + "when": "stackCommandsRegistered && view == zenmlStackView", + "command": "zenml.setStackItemsPerPage", + "group": "navigation@2" + }, { "when": "stackCommandsRegistered && view == zenmlStackView", "command": "zenml.refreshStackView", + "group": "navigation@3" + }, + { + "when": "componentCommandsRegistered && view == zenmlComponentView", + "command": "zenml.registerComponent", + "group": "navigation@1" + }, + { + "when": "componentCommandsRegistered && view == zenmlComponentView", + "command": "zenml.setComponentItemsPerPage", "group": "navigation@2" }, + { + "when": "componentCommandsRegistered && view == zenmlComponentView", + "command": "zenml.refreshComponentView", + "group": "navigation@3" + }, { "when": "pipelineCommandsRegistered && view == zenmlPipelineView", "command": "zenml.setPipelineRunsPerPage", @@ -354,7 +421,7 @@ }, { "when": "stackCommandsRegistered && view == zenmlStackView && viewItem == stack", - "command": "zenml.renameStack", + "command": "zenml.updateStack", "group": "inline@2" }, { @@ -367,6 +434,21 @@ "command": "zenml.goToStackUrl", "group": "inline@4" }, + { + "when": "stackCommandsRegistered && view == zenmlStackView && viewItem == stack", + "command": "zenml.deleteStack", + "group": "inline@5" + }, + { + "when": "componentCommandsRegistered && view == zenmlComponentView && viewItem == stackComponent", + "command": "zenml.updateComponent", + "group": "inline@1" + }, + { + "when": "componentCommandsRegistered && view == zenmlComponentView && viewItem == stackComponent", + "command": "zenml.deleteComponent", + "group": "inline@2" + }, { "when": "pipelineCommandsRegistered && view == zenmlPipelineView && viewItem == pipelineRun", "command": "zenml.deletePipelineRun", @@ -393,6 +475,7 @@ "devDependencies": { "@types/dagre": "^0.7.52", "@types/fs-extra": "^11.0.4", + "@types/hbs": "^4.0.4", "@types/mocha": "^10.0.6", "@types/node": "^18.19.18", "@types/sinon": "^17.0.3", @@ -420,6 +503,7 @@ "axios": "^1.6.7", "dagre": "^0.8.5", "fs-extra": "^11.2.0", + "hbs": "^4.2.0", "svg-pan-zoom": "github:bumbu/svg-pan-zoom", "svgdom": "^0.1.19", "vscode-languageclient": "^9.0.1" diff --git a/resources/components-form/components.css b/resources/components-form/components.css new file mode 100644 index 00000000..d8b7766a --- /dev/null +++ b/resources/components-form/components.css @@ -0,0 +1,87 @@ +/* Copyright(c) ZenML GmbH 2024. All Rights Reserved. + Licensed under the Apache License, Version 2.0(the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + or implied.See the License for the specific language governing + permissions and limitations under the License. */ +.container { + max-width: 600px; + margin: auto; +} + +.block { + padding: 10px 10px 5px 10px; + border: 2px var(--vscode-editor-foreground) solid; + margin-bottom: 10px; +} + +.logo { + float: right; + max-width: 100px; +} + +.docs { + clear: both; + display: flex; + justify-content: space-around; + padding: 5px 0px; +} + +.button, +button { + padding: 2px 5px; + background-color: var(--vscode-editor-background); + color: var(--vscode-editor-foreground); + border: 2px var(--vscode-editor-foreground) solid; + border-radius: 10px; +} + +.field { + margin-bottom: 10px; +} + +.value { + width: 100%; + box-sizing: border-box; + padding-left: 20px; +} + +.input { + width: 100%; +} + +.center { + display: flex; + align-items: center; + justify-content: center; +} + +.loader { + width: 20px; + height: 20px; + border: 5px solid #fff; + border-bottom-color: #ff3d00; + border-radius: 50%; + display: inline-block; + box-sizing: border-box; + animation: rotation 1s linear infinite; +} + +@keyframes rotation { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} + +.hidden { + display: none; +} diff --git a/resources/components-form/components.js b/resources/components-form/components.js new file mode 100644 index 00000000..f759ea55 --- /dev/null +++ b/resources/components-form/components.js @@ -0,0 +1,160 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. +const form = document.querySelector('form'); +const submit = document.querySelector('input[type="submit"]'); +const spinner = document.querySelector('.loader'); +const title = document.querySelector('h2'); + +let mode = 'register'; +let type = ''; +let flavor = ''; +let id = ''; + +const inputs = {}; + +document.querySelectorAll('.input').forEach(element => { + inputs[element.id] = element; + if (element instanceof HTMLTextAreaElement) { + element.addEventListener('input', evt => { + try { + const val = JSON.parse(evt.target.value); + if (evt.target.dataset.array && !Array.isArray(val)) { + element.setCustomValidity('Must be an array'); + element.reportValidity(); + return; + } + } catch { + element.setCustomValidity('Invalid JSON value'); + element.reportValidity(); + return; + } + element.setCustomValidity(''); + }); + } +}); + +const setValues = (name, config) => { + document.querySelector('[name="name"]').value = name; + + for (const key in config) { + if ( + config[key] === null || + !inputs[key] || + (inputs[key].classList.contains('hidden') && !config[key]) + ) { + continue; + } + + if (typeof config[key] === 'boolean' && config[key]) { + inputs[config].checked = 'on'; + } + + if (typeof config[key] === 'object') { + inputs[key].value = JSON.stringify(config[key]); + } else { + inputs[key].value = String(config[key]); + } + + if (inputs[key].classList.contains('hidden')) { + inputs[key].classList.toggle('hidden'); + button = document.querySelector(`[data-id="${inputs[key].id}"]`); + button.textContent = '-'; + } + } +}; + +form.addEventListener('click', evt => { + const target = evt.target; + if (!(target instanceof HTMLButtonElement)) { + return; + } + + evt.preventDefault(); + + const current = target.textContent; + target.textContent = current === '+' ? '-' : '+'; + const fieldName = target.dataset.id; + const field = document.getElementById(fieldName); + field.classList.toggle('hidden'); +}); + +(() => { + const vscode = acquireVsCodeApi(); + + form.addEventListener('submit', evt => { + evt.preventDefault(); + + const data = Object.fromEntries(new FormData(form)); + + for (const id in inputs) { + if (inputs[id].classList.contains('hidden')) { + data[id] = null; + continue; + } + + if (inputs[id] instanceof HTMLTextAreaElement) { + data[id] = JSON.parse(inputs[id].value); + } + + if (inputs[id].type === 'checkbox') { + data[id] = !!inputs[id].checked; + } + + if (inputs[id].type === 'number') { + data[id] = Number(inputs[id].value); + } + } + + data.flavor = flavor; + data.type = type; + + submit.disabled = true; + spinner.classList.remove('hidden'); + + if (mode === 'update') { + data.id = id; + } + + vscode.postMessage({ + command: mode, + data, + }); + }); +})(); + +window.addEventListener('message', evt => { + const message = evt.data; + + switch (message.command) { + case 'register': + mode = 'register'; + type = message.type; + flavor = message.flavor; + id = ''; + break; + + case 'update': + mode = 'update'; + type = message.type; + flavor = message.flavor; + id = message.id; + title.innerText = title.innerText.replace('Register', 'Update'); + setValues(message.name, message.config); + break; + + case 'fail': + spinner.classList.add('hidden'); + submit.disabled = false; + break; + } +}); diff --git a/resources/dag-view/dag.css b/resources/dag-view/dag.css index fe8764e5..25b6ce6d 100644 --- a/resources/dag-view/dag.css +++ b/resources/dag-view/dag.css @@ -1,3 +1,15 @@ +/* Copyright(c) ZenML GmbH 2024. All Rights Reserved. + Licensed under the Apache License, Version 2.0(the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + or implied.See the License for the specific language governing + permissions and limitations under the License. */ body { background-color: var(--vscode-editor-background); color: var(--vscode-editor-foreground); diff --git a/resources/dag-view/dag.js b/resources/dag-view/dag.js index b7b1885c..08253921 100644 --- a/resources/dag-view/dag.js +++ b/resources/dag-view/dag.js @@ -1,3 +1,15 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. import svgPanZoom from 'svg-pan-zoom'; (() => { diff --git a/resources/stacks-form/stacks.css b/resources/stacks-form/stacks.css new file mode 100644 index 00000000..765dcd89 --- /dev/null +++ b/resources/stacks-form/stacks.css @@ -0,0 +1,115 @@ +/* Copyright(c) ZenML GmbH 2024. All Rights Reserved. + Licensed under the Apache License, Version 2.0(the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + or implied.See the License for the specific language governing + permissions and limitations under the License. */ +h2 { + text-align: center; +} + +input { + background-color: var(--vscode-editor-background); + color: var(--vscode-editor-foreground); +} + +input[type='radio'] { + appearance: none; + width: 15px; + height: 15px; + border-radius: 50%; + background-clip: content-box; + border: 2px solid var(--vscode-editor-foreground); + background-color: var(--vscode-editor-background); +} + +input[type='radio']:checked { + background-color: var(--vscode-editor-foreground); + padding: 2px; +} + +p { + margin: 0; + padding: 0; +} + +.options { + display: flex; + flex-wrap: nowrap; + overflow-x: auto; + align-items: center; + width: 100%; + border: 2px var(--vscode-editor-foreground) solid; + border-radius: 5px; + scrollbar-color: var(--vscode-editor-foreground) var(--vscode-editor-background); +} + +.single-option { + display: flex; + flex-direction: row; + align-items: start; + justify-content: center; + padding: 5px; + margin: 10px; + flex-shrink: 0; +} + +.single-option input { + margin-right: 5px; +} + +.single-option label { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + background-color: #eee; + color: #111; + text-align: center; + padding: 5px; + border: 2px var(--vscode-editor-foreground) solid; + border-radius: 5px; +} + +.single-option img { + width: 50px; + height: 50px; + flex-shrink: 0; +} + +.center { + display: flex; + justify-content: center; + align-items: center; + margin: 10px; +} + +.loader { + width: 20px; + height: 20px; + border: 5px solid #fff; + border-bottom-color: #ff3d00; + border-radius: 50%; + display: inline-block; + box-sizing: border-box; + animation: rotation 1s linear infinite; +} + +@keyframes rotation { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} + +.hidden { + display: none; +} diff --git a/resources/stacks-form/stacks.js b/resources/stacks-form/stacks.js new file mode 100644 index 00000000..f3be1915 --- /dev/null +++ b/resources/stacks-form/stacks.js @@ -0,0 +1,99 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. +document.querySelector('input[name="orchestrator"]').toggleAttribute('required'); +document.querySelector('input[name="artifact_store"]').toggleAttribute('required'); + +const form = document.querySelector('form'); +const submit = document.querySelector('input[type="submit"]'); +const spinner = document.querySelector('.loader'); +let previousValues = {}; +let id = undefined; +let mode = 'register'; + +form.addEventListener('click', evt => { + const target = evt.target; + let input = null; + + if (target instanceof HTMLLabelElement) { + input = document.getElementById(target.htmlFor); + } else if (target instanceof HTMLInputElement && target.type === 'radio') { + input = target; + } + + if (!input) { + return; + } + + const value = input.value; + const name = input.name; + if (previousValues[name] === value) { + delete previousValues[name]; + input.checked = false; + } else { + previousValues[name] = value; + } +}); + +(() => { + const vscode = acquireVsCodeApi(); + + form.addEventListener('submit', evt => { + evt.preventDefault(); + submit.disabled = true; + spinner.classList.remove('hidden'); + const data = Object.fromEntries(new FormData(evt.target)); + + if (id) { + data.id = id; + } + + vscode.postMessage({ + command: mode, + data, + }); + }); +})(); + +const title = document.querySelector('h2'); +const nameInput = document.querySelector('input[name="name"]'); + +window.addEventListener('message', evt => { + const message = evt.data; + + switch (message.command) { + case 'register': + mode = 'register'; + title.innerText = 'Register Stack'; + id = undefined; + previousValues = {}; + form.reset(); + break; + + case 'update': + mode = 'update'; + title.innerText = 'Update Stack'; + id = message.data.id; + nameInput.value = message.data.name; + previousValues = message.data.components; + Object.entries(message.data.components).forEach(([type, id]) => { + const input = document.querySelector(`[name="${type}"][value="${id}"]`); + input.checked = true; + }); + break; + + case 'fail': + spinner.classList.add('hidden'); + submit.disabled = false; + break; + } +}); diff --git a/src/commands/components/ComponentsForm.ts b/src/commands/components/ComponentsForm.ts new file mode 100644 index 00000000..3cc31c07 --- /dev/null +++ b/src/commands/components/ComponentsForm.ts @@ -0,0 +1,410 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. + +import * as vscode from 'vscode'; +import WebviewBase from '../../common/WebviewBase'; +import { handlebars } from 'hbs'; +import Panels from '../../common/panels'; +import { Flavor } from '../../types/StackTypes'; +import { LSClient } from '../../services/LSClient'; +import { traceError, traceInfo } from '../../common/log/logging'; +import { ComponentDataProvider } from '../../views/activityBar/componentView/ComponentDataProvider'; + +const ROOT_PATH = ['resources', 'components-form']; +const CSS_FILE = 'components.css'; +const JS_FILE = 'components.js'; + +interface ComponentField { + is_string?: boolean; + is_integer?: boolean; + is_boolean?: boolean; + is_string_object?: boolean; + is_json_object?: boolean; + is_array?: boolean; + is_optional?: boolean; + is_required?: boolean; + defaultValue: any; + title: string; + key: string; +} + +export default class ComponentForm extends WebviewBase { + private static instance: ComponentForm | null = null; + + private root: vscode.Uri; + private javaScript: vscode.Uri; + private css: vscode.Uri; + private template: HandlebarsTemplateDelegate; + + /** + * Retrieves a singleton instance of ComponentForm + * @returns {ComponentForm} The singleton instance + */ + public static getInstance(): ComponentForm { + if (!ComponentForm.instance) { + ComponentForm.instance = new ComponentForm(); + } + + return ComponentForm.instance; + } + + constructor() { + super(); + + if (WebviewBase.context === null) { + throw new Error('Extension Context Not Propagated'); + } + + this.root = vscode.Uri.joinPath(WebviewBase.context.extensionUri, ...ROOT_PATH); + this.javaScript = vscode.Uri.joinPath(this.root, JS_FILE); + this.css = vscode.Uri.joinPath(this.root, CSS_FILE); + + this.template = handlebars.compile(this.produceTemplate()); + } + + /** + * Opens a webview panel based on the flavor config schema to register a new + * component + * @param {Flavor} flavor Flavor of component to register + */ + public async registerForm(flavor: Flavor) { + const panel = await this.getPanel(); + const description = flavor.config_schema.description.replaceAll('\n', ''); + panel.webview.html = this.template({ + type: flavor.type, + flavor: flavor.name, + logo: flavor.logo_url, + description, + docs_url: flavor.docs_url, + sdk_docs_url: flavor.sdk_docs_url, + cspSource: panel.webview.cspSource, + js: panel.webview.asWebviewUri(this.javaScript), + css: panel.webview.asWebviewUri(this.css), + fields: this.toFormFields(flavor.config_schema), + }); + + panel.webview.postMessage({ command: 'register', type: flavor.type, flavor: flavor.name }); + } + + /** + * Opens a webview panel based on the flavor config schema to update a + * specified component + * @param {Flavor} flavor Flavor of the selected component + * @param {string} name Name of the selected component + * @param {string} id ID of the selected component + * @param {object} config Current configuration settings of the selected + * component + */ + public async updateForm( + flavor: Flavor, + name: string, + id: string, + config: { [key: string]: any } + ) { + const panel = await this.getPanel(); + const description = flavor.config_schema.description.replaceAll('\n', ''); + panel.webview.html = this.template({ + type: flavor.type, + flavor: flavor.name, + logo: flavor.logo_url, + description, + docs_url: flavor.docs_url, + sdk_docs_url: flavor.sdk_docs_url, + cspSource: panel.webview.cspSource, + js: panel.webview.asWebviewUri(this.javaScript), + css: panel.webview.asWebviewUri(this.css), + fields: this.toFormFields(flavor.config_schema), + }); + + panel.webview.postMessage({ + command: 'update', + type: flavor.type, + flavor: flavor.name, + name, + id, + config, + }); + } + + private async getPanel(): Promise { + const panels = Panels.getInstance(); + const existingPanel = panels.getPanel('component-form', true); + if (existingPanel) { + existingPanel.reveal(); + return existingPanel; + } + + const panel = panels.createPanel('component-form', 'Component Form', { + enableForms: true, + enableScripts: true, + retainContextWhenHidden: true, + }); + + this.attachListener(panel); + return panel; + } + + private attachListener(panel: vscode.WebviewPanel) { + panel.webview.onDidReceiveMessage( + async (message: { command: string; data: { [key: string]: string } }) => { + let success = false; + const data = message.data; + const { name, flavor, type, id } = data; + delete data.name; + delete data.type; + delete data.flavor; + delete data.id; + + switch (message.command) { + case 'register': + success = await this.registerComponent(name, type, flavor, data); + break; + case 'update': + success = await this.updateComponent(id, name, type, data); + break; + } + + if (!success) { + panel.webview.postMessage({ command: 'fail' }); + return; + } + + panel.dispose(); + ComponentDataProvider.getInstance().refresh(); + } + ); + } + + private async registerComponent( + name: string, + type: string, + flavor: string, + data: object + ): Promise { + const lsClient = LSClient.getInstance(); + try { + const resp = await lsClient.sendLsClientRequest('registerComponent', [ + type, + flavor, + name, + data, + ]); + + if ('error' in resp) { + vscode.window.showErrorMessage(`Unable to register component: "${resp.error}"`); + console.error(resp.error); + traceError(resp.error); + return false; + } + + traceInfo(resp.message); + } catch (e) { + vscode.window.showErrorMessage(`Unable to register component: "${e}"`); + console.error(e); + traceError(e); + return false; + } + + return true; + } + + private async updateComponent( + id: string, + name: string, + type: string, + data: object + ): Promise { + const lsClient = LSClient.getInstance(); + try { + const resp = await lsClient.sendLsClientRequest('updateComponent', [id, type, name, data]); + + if ('error' in resp) { + vscode.window.showErrorMessage(`Unable to update component: "${resp.error}"`); + console.error(resp.error); + traceError(resp.error); + return false; + } + + traceInfo(resp.message); + } catch (e) { + vscode.window.showErrorMessage(`Unable to update component: "${e}"`); + console.error(e); + traceError(e); + return false; + } + + return true; + } + + private toFormFields(configSchema: { [key: string]: any }) { + const properties = configSchema.properties; + const required = configSchema.required ?? []; + + const converted: Array = []; + for (const key in properties) { + const current: ComponentField = { + key, + title: properties[key].title, + defaultValue: properties[key].default, + }; + converted.push(current); + + if ('anyOf' in properties[key]) { + if (properties[key].anyOf.find((obj: { type: string }) => obj.type === 'null')) { + current.is_optional = true; + } + + if ( + properties[key].anyOf.find( + (obj: { type: string }) => obj.type === 'object' || obj.type === 'array' + ) + ) { + current.is_json_object = true; + } else if (properties[key].anyOf[0].type === 'string') { + current.is_string = true; + } else if (properties[key].anyOf[0].type === 'integer') { + current.is_integer = true; + } else if (properties[key].anyOf[0].type === 'boolean') { + current.is_boolean = true; + } + } + + if (required.includes(key)) { + current.is_required = true; + } + + if (!properties[key].type) { + continue; + } + + current.is_boolean = properties[key].type === 'boolean'; + current.is_string = properties[key].type === 'string'; + current.is_integer = properties[key].type === 'integer'; + if (properties[key].type === 'object' || properties[key].type === 'array') { + current.is_json_object = true; + current.defaultValue = JSON.stringify(properties[key].default); + } + + if (properties[key].type === 'array') { + current.is_array = true; + } + } + + return converted; + } + + private produceTemplate(): string { + return ` + + + + + + + + Stack Form + + + + Register {{type}} Stack Component ({{flavor}}) + + + {{{description}}} + + {{#if docs_url}} + Documentation + {{/if}} + + {{#if sdk_docs_url}} + SDK Documentation + {{/if}} + + + + + + Component Name* + + + + + + + {{#each fields}} + + + + {{title}} {{#if is_required}}*{{/if}} + + {{#if is_optional}} + + + {{/if}} + + + + {{#if is_string}} + + {{/if}} + + {{#if is_boolean}} + + {{/if}} + + {{#if is_integer}} + + {{/if}} + + {{#if is_json_object}} + {{defaultValue}} + {{/if}} + + + {{/each}} + + + + + + + + `; + } +} diff --git a/src/commands/components/cmds.ts b/src/commands/components/cmds.ts new file mode 100644 index 00000000..36c28fb0 --- /dev/null +++ b/src/commands/components/cmds.ts @@ -0,0 +1,159 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. +import * as vscode from 'vscode'; + +import ZenMLStatusBar from '../../views/statusBar'; +import { LSClient } from '../../services/LSClient'; +import { showInformationMessage } from '../../utils/notifications'; +import Panels from '../../common/panels'; +import { ComponentDataProvider } from '../../views/activityBar/componentView/ComponentDataProvider'; +import { ComponentTypesResponse, Flavor, FlavorListResponse } from '../../types/StackTypes'; +import { getFlavor, getFlavorsOfType } from '../../common/api'; +import ComponentForm from './ComponentsForm'; +import { StackComponentTreeItem } from '../../views/activityBar'; +import { traceError, traceInfo } from '../../common/log/logging'; + +/** + * Refreshes the stack component view. + */ +const refreshComponentView = async () => { + try { + await vscode.window.withProgress( + { + location: vscode.ProgressLocation.Notification, + title: 'Refreshing Component View...', + cancellable: false, + }, + async progress => { + await ComponentDataProvider.getInstance().refresh(); + } + ); + } catch (e) { + vscode.window.showErrorMessage(`Failed to refresh component view: ${e}`); + traceError(`Failed to refresh component view: ${e}`); + console.error(`Failed to refresh component view: ${e}`); + } +}; + +/** + * Allows one to choose a component type and flavor, then opens the component + * form webview panel to a form specific to register a new a component of that + * type and flavor. + */ +const registerComponent = async () => { + const lsClient = LSClient.getInstance(); + try { + const types = await lsClient.sendLsClientRequest('getComponentTypes'); + + if ('error' in types) { + throw new Error(String(types.error)); + } + + const type = await vscode.window.showQuickPick(types, { + title: 'What type of component to register?', + }); + if (!type) { + return; + } + + const flavors = await getFlavorsOfType(type); + if ('error' in flavors) { + throw flavors.error; + } + + const flavorNames = flavors.map(flavor => flavor.name); + const selectedFlavor = await vscode.window.showQuickPick(flavorNames, { + title: `What flavor of a ${type} component to register?`, + }); + if (!selectedFlavor) { + return; + } + + const flavor = flavors.find(flavor => selectedFlavor === flavor.name); + await ComponentForm.getInstance().registerForm(flavor as Flavor); + } catch (e) { + vscode.window.showErrorMessage(`Unable to open component form: ${e}`); + traceError(e); + console.error(e); + } +}; + +const updateComponent = async (node: StackComponentTreeItem) => { + try { + const flavor = await getFlavor(node.component.flavor); + + await ComponentForm.getInstance().updateForm( + flavor, + node.component.name, + node.component.id, + node.component.config + ); + } catch (e) { + vscode.window.showErrorMessage(`Unable to open component form: ${e}`); + traceError(e); + console.error(e); + } +}; + +/** + * Deletes a specified Stack Component + * @param {StackComponentTreeItem} node The specified stack component to delete + */ +const deleteComponent = async (node: StackComponentTreeItem) => { + const lsClient = LSClient.getInstance(); + + const answer = await vscode.window.showWarningMessage( + `Are you sure you want to delete ${node.component.name}? This cannot be undone.`, + { modal: true }, + 'Delete' + ); + + if (!answer) { + return; + } + + await vscode.window.withProgress( + { + location: vscode.ProgressLocation.Window, + title: `Deleting stack component ${node.component.name}...`, + }, + async () => { + try { + const resp = await lsClient.sendLsClientRequest('deleteComponent', [ + node.component.id, + node.component.type, + ]); + + if ('error' in resp) { + throw resp.error; + } + + vscode.window.showInformationMessage(`${node.component.name} deleted`); + traceInfo(`${node.component.name} deleted`); + + ComponentDataProvider.getInstance().refresh(); + } catch (e) { + vscode.window.showErrorMessage(`Failed to delete component: ${e}`); + traceError(e); + console.error(e); + } + } + ); +}; + +export const componentCommands = { + refreshComponentView, + registerComponent, + updateComponent, + deleteComponent, +}; diff --git a/src/commands/components/registry.ts b/src/commands/components/registry.ts new file mode 100644 index 00000000..086f7c74 --- /dev/null +++ b/src/commands/components/registry.ts @@ -0,0 +1,69 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. +import { componentCommands } from './cmds'; +import { registerCommand } from '../../common/vscodeapi'; +import { ZenExtension } from '../../services/ZenExtension'; +import { ExtensionContext, commands } from 'vscode'; +import { ComponentDataProvider } from '../../views/activityBar/componentView/ComponentDataProvider'; +import { StackComponentTreeItem } from '../../views/activityBar'; + +/** + * Registers stack component-related commands for the extension. + * + * @param {ExtensionContext} context - The context in which the extension operates, used for registering commands and managing their lifecycle. + */ +export const registerComponentCommands = (context: ExtensionContext) => { + const componentDataProvider = ComponentDataProvider.getInstance(); + try { + const registeredCommands = [ + registerCommand( + 'zenml.setComponentItemsPerPage', + async () => await componentDataProvider.updateItemsPerPage() + ), + registerCommand( + 'zenml.refreshComponentView', + async () => await componentCommands.refreshComponentView() + ), + registerCommand( + 'zenml.registerComponent', + async () => await componentCommands.registerComponent() + ), + registerCommand( + 'zenml.updateComponent', + async (node: StackComponentTreeItem) => await componentCommands.updateComponent(node) + ), + registerCommand( + 'zenml.deleteComponent', + async (node: StackComponentTreeItem) => await componentCommands.deleteComponent(node) + ), + registerCommand( + 'zenml.nextComponentPage', + async () => await componentDataProvider.goToNextPage() + ), + registerCommand( + 'zenml.previousComponentPage', + async () => await componentDataProvider.goToPreviousPage() + ), + ]; + + registeredCommands.forEach(cmd => { + context.subscriptions.push(cmd); + ZenExtension.commandDisposables.push(cmd); + }); + + commands.executeCommand('setContext', 'componentCommandsRegistered', true); + } catch (e) { + console.error('Error registering component commands:', e); + commands.executeCommand('setContext', 'componentCommandsRegistered', false); + } +}; diff --git a/src/commands/pipelines/DagRender.ts b/src/commands/pipelines/DagRender.ts index d081edca..e06d3eec 100644 --- a/src/commands/pipelines/DagRender.ts +++ b/src/commands/pipelines/DagRender.ts @@ -20,38 +20,47 @@ import { LSClient } from '../../services/LSClient'; import { ServerStatus } from '../../types/ServerInfoTypes'; import { JsonObject } from '../../views/panel/panelView/PanelTreeItem'; import { PanelDataProvider } from '../../views/panel/panelView/PanelDataProvider'; +import Panels from '../../common/panels'; +import WebviewBase from '../../common/WebviewBase'; const ROOT_PATH = ['resources', 'dag-view']; const CSS_FILE = 'dag.css'; const JS_FILE = 'dag-packed.js'; const ICONS_DIRECTORY = '/resources/dag-view/icons/'; -export default class DagRenderer { +export default class DagRenderer extends WebviewBase { private static instance: DagRenderer | undefined; - private openPanels: { [id: string]: vscode.WebviewPanel }; private createSVGWindow: Function = () => {}; private iconSvgs: { [name: string]: string } = {}; private root: vscode.Uri; private javaScript: vscode.Uri; private css: vscode.Uri; - constructor(context: vscode.ExtensionContext) { - DagRenderer.instance = this; - this.openPanels = {}; - this.root = vscode.Uri.joinPath(context.extensionUri, ...ROOT_PATH); + constructor() { + super(); + + if (WebviewBase.context === null) { + throw new Error('Extension Context Not Propagated'); + } + + this.root = vscode.Uri.joinPath(WebviewBase.context.extensionUri, ...ROOT_PATH); this.javaScript = vscode.Uri.joinPath(this.root, JS_FILE); this.css = vscode.Uri.joinPath(this.root, CSS_FILE); this.loadSvgWindowLib(); - this.loadIcons(context.extensionPath + ICONS_DIRECTORY); + this.loadIcons(WebviewBase.context.extensionPath + ICONS_DIRECTORY); } /** * Retrieves a singleton instance of DagRenderer * - * @returns {DagRenderer | undefined} The singleton instance if it exists + * @returns {DagRenderer} The singleton instance */ - public static getInstance(): DagRenderer | undefined { + public static getInstance(): DagRenderer { + if (!DagRenderer.instance) { + DagRenderer.instance = new DagRenderer(); + } + return DagRenderer.instance; } @@ -60,7 +69,6 @@ export default class DagRenderer { */ public deactivate(): void { DagRenderer.instance = undefined; - Object.values(this.openPanels).forEach(panel => panel.dispose()); } /** @@ -69,30 +77,21 @@ export default class DagRenderer { * @returns */ public async createView(node: PipelineTreeItem) { - const existingPanel = this.getDagPanel(node.id); + const p = Panels.getInstance(); + const existingPanel = p.getPanel(node.id); if (existingPanel) { existingPanel.reveal(); return; } - const panel = vscode.window.createWebviewPanel( - `DAG-${node.id}`, - node.label as string, - vscode.ViewColumn.One, - { - enableScripts: true, - localResourceRoots: [this.root], - } - ); - - panel.webview.html = this.getLoadingContent(); + const panel = p.createPanel(node.id, node.label as string, { + enableScripts: true, + localResourceRoots: [this.root], + }); panel.webview.onDidReceiveMessage(this.createMessageHandler(panel, node)); this.renderDag(panel, node); - - // To track which DAGs are currently open - this.registerDagPanel(node.id, panel); } private createMessageHandler( @@ -211,7 +210,14 @@ export default class DagRenderer { const title = `${dagData.name} - v${dagData.version}`; // And set its HTML content - panel.webview.html = this.getWebviewContent({ svg, cssUri, jsUri, updateButton, title }); + panel.webview.html = this.getWebviewContent({ + svg, + cssUri, + jsUri, + updateButton, + title, + cspSource: panel.webview.cspSource, + }); } private async loadSvgWindowLib() { @@ -240,22 +246,6 @@ export default class DagRenderer { }); } - private deregisterDagPanel(runId: string) { - delete this.openPanels[runId]; - } - - private getDagPanel(runId: string): vscode.WebviewPanel | undefined { - return this.openPanels[runId]; - } - - private registerDagPanel(runId: string, panel: vscode.WebviewPanel) { - this.openPanels[runId] = panel; - - panel.onDidDispose(() => { - this.deregisterDagPanel(runId); - }, null); - } - private layoutDag(dagData: PipelineRunDag): Dagre.graphlib.Graph { const { nodes, edges } = dagData; const graph = new Dagre.graphlib.Graph().setDefaultEdgeLabel(() => ({})); @@ -350,56 +340,27 @@ export default class DagRenderer { return canvas.svg(); } - private getLoadingContent(): string { - return ` - - - - - - - Loading - - - - - -`; - } - private getWebviewContent({ svg, cssUri, jsUri, updateButton, title, + cspSource, }: { svg: string; cssUri: vscode.Uri; jsUri: vscode.Uri; updateButton: boolean; title: string; + cspSource: string; }): string { return ` - + DAG diff --git a/src/commands/stack/StackForm.ts b/src/commands/stack/StackForm.ts new file mode 100644 index 00000000..65f6597f --- /dev/null +++ b/src/commands/stack/StackForm.ts @@ -0,0 +1,274 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. + +import * as vscode from 'vscode'; +import WebviewBase from '../../common/WebviewBase'; +import { handlebars } from 'hbs'; +import Panels from '../../common/panels'; +import { getAllFlavors, getAllStackComponents } from '../../common/api'; +import { Flavor, StackComponent } from '../../types/StackTypes'; +import { LSClient } from '../../services/LSClient'; +import { StackDataProvider } from '../../views/activityBar'; +import { traceError, traceInfo } from '../../common/log/logging'; + +type MixedComponent = { name: string; id: string; url: string }; + +const ROOT_PATH = ['resources', 'stacks-form']; +const CSS_FILE = 'stacks.css'; +const JS_FILE = 'stacks.js'; + +export default class StackForm extends WebviewBase { + private static instance: StackForm | null = null; + + private root: vscode.Uri; + private javaScript: vscode.Uri; + private css: vscode.Uri; + private template: HandlebarsTemplateDelegate; + + /** + * Retrieves a singleton instance of ComponentForm + * @returns {StackForm} The singleton instance + */ + public static getInstance(): StackForm { + if (!StackForm.instance) { + StackForm.instance = new StackForm(); + } + + return StackForm.instance; + } + + constructor() { + super(); + + if (WebviewBase.context === null) { + throw new Error('Extension Context Not Propagated'); + } + + this.root = vscode.Uri.joinPath(WebviewBase.context.extensionUri, ...ROOT_PATH); + this.javaScript = vscode.Uri.joinPath(this.root, JS_FILE); + this.css = vscode.Uri.joinPath(this.root, CSS_FILE); + + handlebars.registerHelper('capitalize', (str: string) => { + return str + .split('_') + .map(word => word[0].toUpperCase() + word.slice(1).toLowerCase()) + .join(' '); + }); + + this.template = handlebars.compile(this.produceTemplate()); + } + + /** + * Opens a webview panel with a form to register a new stack + */ + public async registerForm() { + const panel = await this.display(); + panel.webview.postMessage({ command: 'register' }); + } + + /** + * Opens a webview panel with a form to update a specified stack + * @param {string} id The id of the specified stack + * @param {string} name The current name of the specified stack + * @param {object} components The component settings of the sepcified stack + */ + public async updateForm(id: string, name: string, components: { [type: string]: string }) { + const panel = await this.display(); + panel.webview.postMessage({ command: 'update', data: { id, name, components } }); + } + + private async display(): Promise { + const panels = Panels.getInstance(); + const existingPanel = panels.getPanel('stack-form'); + if (existingPanel) { + existingPanel.reveal(); + return existingPanel; + } + + const panel = panels.createPanel('stack-form', 'Stack Form', { + enableForms: true, + enableScripts: true, + retainContextWhenHidden: true, + }); + + await this.renderForm(panel); + this.attachListener(panel); + return panel; + } + + private attachListener(panel: vscode.WebviewPanel) { + panel.webview.onDidReceiveMessage( + async (message: { command: string; data: { [key: string]: string } }) => { + let success = false; + const data = message.data; + const { name, id } = data; + delete data.name; + delete data.id; + + switch (message.command) { + case 'register': + success = await this.registerStack(name, data); + break; + case 'update': { + const updateData = Object.fromEntries( + Object.entries(data).map(([type, id]) => [type, [id]]) + ); + success = await this.updateStack(id, name, updateData); + break; + } + } + + if (!success) { + panel.webview.postMessage({ command: 'fail' }); + return; + } + + panel.dispose(); + StackDataProvider.getInstance().refresh(); + } + ); + } + + private async registerStack( + name: string, + components: { [type: string]: string } + ): Promise { + const lsClient = LSClient.getInstance(); + try { + const resp = await lsClient.sendLsClientRequest('registerStack', [name, components]); + + if ('error' in resp) { + vscode.window.showErrorMessage(`Unable to register stack: "${resp.error}"`); + console.error(resp.error); + traceError(resp.error); + return false; + } + + traceInfo(resp.message); + } catch (e) { + vscode.window.showErrorMessage(`Unable to register stack: "${e}"`); + console.error(e); + traceError(e); + return false; + } + + return true; + } + + private async updateStack( + id: string, + name: string, + components: { [key: string]: string[] } + ): Promise { + const lsClient = LSClient.getInstance(); + try { + const types = await lsClient.sendLsClientRequest('getComponentTypes'); + if (!Array.isArray(types)) { + throw new Error('Could not get Component Types from LS Server'); + } + + // adding missing types to components object, in case we removed that type. + types.forEach(type => { + if (!components[type]) { + components[type] = []; + } + }); + + const resp = await lsClient.sendLsClientRequest('updateStack', [id, name, components]); + + if ('error' in resp) { + vscode.window.showErrorMessage(`Unable to update stack: "${resp.error}"`); + console.error(resp.error); + traceError(resp.error); + return false; + } + + traceInfo(resp.message); + } catch (e) { + vscode.window.showErrorMessage(`Unable to update stack: "${e}"`); + console.error(e); + traceError(e); + return false; + } + + return true; + } + + private async renderForm(panel: vscode.WebviewPanel) { + const flavors = await getAllFlavors(); + const components = await getAllStackComponents(); + const options = this.convertComponents(flavors, components); + const js = panel.webview.asWebviewUri(this.javaScript); + const css = panel.webview.asWebviewUri(this.css); + const cspSource = panel.webview.cspSource; + + panel.webview.html = this.template({ options, js, css, cspSource }); + } + + private convertComponents( + flavors: Flavor[], + components: { [type: string]: StackComponent[] } + ): { [type: string]: MixedComponent[] } { + const out: { [type: string]: MixedComponent[] } = {}; + + Object.keys(components).forEach(key => { + out[key] = components[key].map(component => { + return { + name: component.name, + id: component.id, + url: + flavors.find( + flavor => flavor.type === component.type && flavor.name === component.flavor + )?.logo_url ?? '', + }; + }); + }); + + return out; + } + + private produceTemplate(): string { + return ` + + + + + + + + Stack Form + + + Register Stack + + Stack Name: + {{#each options}} + {{capitalize @key}} + + {{#each this}} + + + {{capitalize name}} + + {{/each}} + + {{/each}} + + + + + + + `; + } +} diff --git a/src/commands/stack/cmds.ts b/src/commands/stack/cmds.ts index 9525824a..026fb78b 100644 --- a/src/commands/stack/cmds.ts +++ b/src/commands/stack/cmds.ts @@ -11,11 +11,15 @@ // or implied.See the License for the specific language governing // permissions and limitations under the License. import * as vscode from 'vscode'; -import { StackDataProvider, StackTreeItem } from '../../views/activityBar'; +import { StackComponentTreeItem, StackDataProvider, StackTreeItem } from '../../views/activityBar'; import ZenMLStatusBar from '../../views/statusBar'; import { getStackDashboardUrl, switchActiveStack } from './utils'; import { LSClient } from '../../services/LSClient'; import { showInformationMessage } from '../../utils/notifications'; +import Panels from '../../common/panels'; +import { randomUUID } from 'crypto'; +import StackForm from './StackForm'; +import { traceError, traceInfo } from '../../common/log/logging'; /** * Refreshes the stack view. @@ -178,6 +182,78 @@ const goToStackUrl = (node: StackTreeItem) => { } }; +/** + * Opens the stack form webview panel to a form specific to registering a new + * stack. + */ +const registerStack = () => { + StackForm.getInstance().registerForm(); +}; + +/** + * Opens the stack form webview panel to a form specific to updating a specified stack. + * @param {StackTreeItem} node The specified stack to update. + */ +const updateStack = async (node: StackTreeItem) => { + const { id, label: name } = node; + const components: { [type: string]: string } = {}; + + node.children?.forEach(child => { + if (child instanceof StackComponentTreeItem) { + const { type, id } = (child as StackComponentTreeItem).component; + components[type] = id; + } + }); + + StackForm.getInstance().updateForm(id, name, components); +}; + +/** + * Deletes a specified stack. + * + * @param {StackTreeItem} node The Stack to delete + */ +const deleteStack = async (node: StackTreeItem) => { + const lsClient = LSClient.getInstance(); + + const answer = await vscode.window.showWarningMessage( + `Are you sure you want to delete ${node.label}? This cannot be undone.`, + { modal: true }, + 'Delete' + ); + + if (!answer) { + return; + } + + await vscode.window.withProgress( + { + location: vscode.ProgressLocation.Window, + title: `Deleting stack ${node.label}...`, + }, + async () => { + const { id } = node; + + try { + const resp = await lsClient.sendLsClientRequest('deleteStack', [id]); + + if ('error' in resp) { + throw resp.error; + } + + vscode.window.showInformationMessage(`${node.label} deleted`); + traceInfo(`${node.label} deleted`); + + StackDataProvider.getInstance().refresh(); + } catch (e) { + vscode.window.showErrorMessage(`Failed to delete component: ${e}`); + traceError(e); + console.error(e); + } + } + ); +}; + export const stackCommands = { refreshStackView, refreshActiveStack, @@ -185,4 +261,7 @@ export const stackCommands = { copyStack, setActiveStack, goToStackUrl, + registerStack, + updateStack, + deleteStack, }; diff --git a/src/commands/stack/registry.ts b/src/commands/stack/registry.ts index e8973ccd..2a6cef3f 100644 --- a/src/commands/stack/registry.ts +++ b/src/commands/stack/registry.ts @@ -35,6 +35,14 @@ export const registerStackCommands = (context: ExtensionContext) => { 'zenml.refreshActiveStack', async () => await stackCommands.refreshActiveStack() ), + registerCommand('zenml.registerStack', async () => stackCommands.registerStack()), + registerCommand('zenml.updateStack', async (node: StackTreeItem) => + stackCommands.updateStack(node) + ), + registerCommand( + 'zenml.deleteStack', + async (node: StackTreeItem) => await stackCommands.deleteStack(node) + ), registerCommand( 'zenml.renameStack', async (node: StackTreeItem) => await stackCommands.renameStack(node) diff --git a/src/common/WebviewBase.ts b/src/common/WebviewBase.ts new file mode 100644 index 00000000..52530645 --- /dev/null +++ b/src/common/WebviewBase.ts @@ -0,0 +1,30 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. +import * as vscode from 'vscode'; + +/** + * Provides functionality to share extension context among classes that inherit + * from it. + */ +export default class WebviewBase { + protected static context: vscode.ExtensionContext | null = null; + + /** + * Sets the extension context so that descendant classes can correctly + * path to their resources + * @param {vscode.ExtensionContext} context ExtensionContext + */ + public static setContext(context: vscode.ExtensionContext) { + WebviewBase.context = context; + } +} diff --git a/src/common/api.ts b/src/common/api.ts new file mode 100644 index 00000000..90f96e93 --- /dev/null +++ b/src/common/api.ts @@ -0,0 +1,114 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. +import { LSClient } from '../services/LSClient'; +import { + ComponentsListResponse, + Flavor, + FlavorListResponse, + StackComponent, +} from '../types/StackTypes'; + +let flavors: Flavor[] = []; + +/** + * Gets all component flavors and caches them + * @returns {Flavor[]} List of flavors + */ +export const getAllFlavors = async (): Promise => { + if (flavors.length > 0) { + return flavors; + } + const lsClient = LSClient.getInstance(); + + let [page, maxPage] = [0, 1]; + do { + page++; + const resp = await lsClient.sendLsClientRequest('listFlavors', [ + page, + 10000, + ]); + + if ('error' in resp) { + console.error(`Error retrieving flavors: ${resp.error}`); + throw new Error(`Error retrieving flavors: ${resp.error}`); + } + + maxPage = resp.total_pages; + flavors = flavors.concat(resp.items); + } while (page < maxPage); + return flavors; +}; + +/** + * Gets all flavors of a specified component type + * @param {string} type Type of component to filter by + * @returns {Flavor[]} List of flavors that match the component type filter + */ +export const getFlavorsOfType = async (type: string): Promise => { + const flavors = await getAllFlavors(); + return flavors.filter(flavor => flavor.type === type); +}; + +/** + * Gets a specific flavor + * @param {string} name The name of the flavor to get + * @returns {Flavor} The specified flavor. + */ +export const getFlavor = async (name: string): Promise => { + const flavors = await getAllFlavors(); + const flavor = flavors.find(flavor => flavor.name === name); + + if (!flavor) { + throw Error(`getFlavor: Flavor ${name} not found`); + } + + return flavor; +}; + +/** + * Gets all stack components + * @returns {object} Object containing all components keyed by each type. + */ +export const getAllStackComponents = async (): Promise<{ + [type: string]: StackComponent[]; +}> => { + const lsClient = LSClient.getInstance(); + let components: StackComponent[] = []; + let [page, maxPage] = [0, 1]; + + do { + page++; + const resp = await lsClient.sendLsClientRequest('listComponents', [ + page, + 10000, + ]); + + if ('error' in resp) { + console.error(`Error retrieving components: ${resp.error}`); + throw new Error(`Error retrieving components: ${resp.error}`); + } + + maxPage = resp.total_pages; + components = components.concat(resp.items); + } while (page < maxPage); + + const out: { [type: string]: StackComponent[] } = {}; + components.forEach(component => { + if (!(component.type in out)) { + out[component.type] = []; + } + out[component.type].push(component); + }); + + return out; +}; diff --git a/src/common/panels.ts b/src/common/panels.ts new file mode 100644 index 00000000..76365b9b --- /dev/null +++ b/src/common/panels.ts @@ -0,0 +1,113 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. +import * as vscode from 'vscode'; + +/** + * Handles creation and monitoring of webview panels. + */ +export default class Panels { + private static instance: Panels | undefined; + private openPanels: { [id: string]: vscode.WebviewPanel }; + + constructor() { + this.openPanels = {}; + } + + /** + * Retrieves a singleton instance of Panels + * @returns {Panels} The singleton instance + */ + public static getInstance(): Panels { + if (Panels.instance === undefined) { + Panels.instance = new Panels(); + } + return Panels.instance; + } + + /** + * Creates a webview panel + * @param {string} id ID of the webview panel to create + * @param {string} label Title of webview panel tab + * @param {vscode.WebviewPanelOptions & vscode.WebviewOptions} options + * Options applied to the webview panel + * @returns {vscode.WebviewPanel} The webview panel created + */ + public createPanel( + id: string, + label: string, + options?: vscode.WebviewPanelOptions & vscode.WebviewOptions + ) { + const panel = vscode.window.createWebviewPanel(id, label, vscode.ViewColumn.One, options); + panel.webview.html = this.getLoadingContent(); + + this.openPanels[id] = panel; + + panel.onDidDispose(() => { + this.deregisterPanel(id); + }, null); + + return panel; + } + + /** + * Gets existing webview panel + * @param {string} id ID of webview panel to retrieve. + * @param {boolean} forceSpinner Whether to change the html content or not + * @returns {vscode.WebviewPanel | undefined} The webview panel if it exists, + * else undefined + */ + public getPanel(id: string, forceSpinner: boolean = false): vscode.WebviewPanel | undefined { + const panel = this.openPanels[id]; + + if (panel && forceSpinner) { + panel.webview.html = this.getLoadingContent(); + } + + return panel; + } + + private deregisterPanel(id: string) { + delete this.openPanels[id]; + } + + private getLoadingContent(): string { + return ` + + + + + + + Loading + + + + + +`; + } +} diff --git a/src/extension.ts b/src/extension.ts index 45abea59..8bc37721 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -20,6 +20,7 @@ import { registerEnvironmentCommands } from './commands/environment/registry'; import { LSP_ZENML_CLIENT_INITIALIZED } from './utils/constants'; import { toggleCommands } from './utils/global'; import DagRenderer from './commands/pipelines/DagRender'; +import WebviewBase from './common/WebviewBase'; export async function activate(context: vscode.ExtensionContext) { const eventBus = EventBus.getInstance(); @@ -48,7 +49,7 @@ export async function activate(context: vscode.ExtensionContext) { }) ); - new DagRenderer(context); + WebviewBase.setContext(context); } /** diff --git a/src/services/ZenExtension.ts b/src/services/ZenExtension.ts index 3836cabf..43a0f8e7 100644 --- a/src/services/ZenExtension.ts +++ b/src/services/ZenExtension.ts @@ -41,6 +41,8 @@ import ZenMLStatusBar from '../views/statusBar'; import { LSClient } from './LSClient'; import { toggleCommands } from '../utils/global'; import { PanelDataProvider } from '../views/panel/panelView/PanelDataProvider'; +import { ComponentDataProvider } from '../views/activityBar/componentView/ComponentDataProvider'; +import { registerComponentCommands } from '../commands/components/registry'; export interface IServerInfo { name: string; @@ -61,6 +63,7 @@ export class ZenExtension { private static dataProviders = new Map>([ ['zenmlServerView', ServerDataProvider.getInstance()], ['zenmlStackView', StackDataProvider.getInstance()], + ['zenmlComponentView', ComponentDataProvider.getInstance()], ['zenmlPipelineView', PipelineDataProvider.getInstance()], ['zenmlPanelView', PanelDataProvider.getInstance()], ]); @@ -68,6 +71,7 @@ export class ZenExtension { private static registries = [ registerServerCommands, registerStackCommands, + registerComponentCommands, registerPipelineCommands, ]; diff --git a/src/types/StackTypes.ts b/src/types/StackTypes.ts index 5ada4087..00fbdcde 100644 --- a/src/types/StackTypes.ts +++ b/src/types/StackTypes.ts @@ -40,8 +40,57 @@ interface StackComponent { name: string; flavor: string; type: string; + config: { [key: string]: any }; } export type StacksResponse = StacksData | ErrorMessageResponse | VersionMismatchError; -export { Stack, Components, StackComponent, StacksData }; +interface ComponentsListData { + index: number; + max_size: number; + total_pages: number; + total: number; + items: Array; +} + +export type ComponentsListResponse = + | ComponentsListData + | ErrorMessageResponse + | VersionMismatchError; + +interface Flavor { + id: string; + name: string; + type: string; + logo_url: string; + config_schema: { [key: string]: any }; + docs_url: string | null; + sdk_docs_url: string | null; + connector_type: string | null; + connector_resource_type: string | null; + connector_resource_id_attr: string | null; +} + +interface FlavorListData { + index: number; + max_size: number; + total_pages: number; + total: number; + items: Flavor[]; +} + +export type FlavorListResponse = FlavorListData | ErrorMessageResponse | VersionMismatchError; + +type ComponentTypes = string[]; + +export type ComponentTypesResponse = ComponentTypes | VersionMismatchError | ErrorMessageResponse; + +export { + Stack, + Components, + StackComponent, + StacksData, + ComponentsListData, + Flavor, + ComponentTypes, +}; diff --git a/src/utils/global.ts b/src/utils/global.ts index 8645a153..a4f94a3b 100644 --- a/src/utils/global.ts +++ b/src/utils/global.ts @@ -123,6 +123,7 @@ export function getDefaultPythonInterpreterPath(): string { */ export async function toggleCommands(state: boolean): Promise { await vscode.commands.executeCommand('setContext', 'stackCommandsRegistered', state); + await vscode.commands.executeCommand('setContext', 'componentCommandsRegistered', state); await vscode.commands.executeCommand('setContext', 'serverCommandsRegistered', state); await vscode.commands.executeCommand('setContext', 'pipelineCommandsRegistered', state); await vscode.commands.executeCommand('setContext', 'environmentCommandsRegistered', state); diff --git a/src/views/activityBar/common/LoadingTreeItem.ts b/src/views/activityBar/common/LoadingTreeItem.ts index 9e0fe9b6..df52fdb5 100644 --- a/src/views/activityBar/common/LoadingTreeItem.ts +++ b/src/views/activityBar/common/LoadingTreeItem.ts @@ -23,6 +23,7 @@ export class LoadingTreeItem extends TreeItem { export const LOADING_TREE_ITEMS = new Map([ ['server', new LoadingTreeItem('Refreshing Server View...')], ['stacks', new LoadingTreeItem('Refreshing Stacks View...')], + ['components', new LoadingTreeItem('Refreshing Components View...')], ['pipelineRuns', new LoadingTreeItem('Refreshing Pipeline Runs...')], ['environment', new LoadingTreeItem('Refreshing Environments...')], ['lsClient', new LoadingTreeItem('Waiting for Language Server to start...', '')], diff --git a/src/views/activityBar/common/PaginatedDataProvider.ts b/src/views/activityBar/common/PaginatedDataProvider.ts new file mode 100644 index 00000000..00732048 --- /dev/null +++ b/src/views/activityBar/common/PaginatedDataProvider.ts @@ -0,0 +1,154 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. + +import { Event, EventEmitter, TreeDataProvider, TreeItem, window } from 'vscode'; +import { ITEMS_PER_PAGE_OPTIONS } from '../../../utils/constants'; +import { CommandTreeItem } from './PaginationTreeItems'; +import { LoadingTreeItem } from './LoadingTreeItem'; +import { ErrorTreeItem } from './ErrorTreeItem'; + +/** + * Provides a base class to other DataProviders that provides all functionality + * for pagination in a tree view. + */ +export class PaginatedDataProvider implements TreeDataProvider { + protected _onDidChangeTreeData = new EventEmitter(); + readonly onDidChangeTreeData: Event = + this._onDidChangeTreeData.event; + protected pagination: { + currentPage: number; + itemsPerPage: number; + totalItems: number; + totalPages: number; + } = { + currentPage: 1, + itemsPerPage: 10, + totalItems: 0, + totalPages: 0, + }; + public items: TreeItem[] = []; + protected viewName: string = ''; + + /** + * Loads the next page. + */ + public async goToNextPage(): Promise { + try { + if (this.pagination.currentPage < this.pagination.totalPages) { + this.pagination.currentPage++; + await this.refresh(); + } + } catch (e) { + console.error(`Failed to go the next page: ${e}`); + } + } + + /** + * Loads the previous page + */ + public async goToPreviousPage(): Promise { + try { + if (this.pagination.currentPage > 1) { + this.pagination.currentPage--; + await this.refresh(); + } + } catch (e) { + console.error(`Failed to go the previous page: ${e}`); + } + } + + /** + * Sets the item count per page + */ + public async updateItemsPerPage(): Promise { + try { + const selected = await window.showQuickPick(ITEMS_PER_PAGE_OPTIONS, { + placeHolder: 'Choose the max number of items to display per page', + }); + if (selected) { + this.pagination.itemsPerPage = parseInt(selected, 10); + this.pagination.currentPage = 1; + await this.refresh(); + } + } catch (e) { + console.error(`Failed to update items per page: ${e}`); + } + } + + /** + * Refreshes the view. + */ + public async refresh(): Promise { + this._onDidChangeTreeData.fire(undefined); + } + + /** + * Returns the provided tree item. + * + * @param element element The tree item to return. + * @returns The corresponding VS Code tree item + */ + public getTreeItem(element: TreeItem): TreeItem { + return element; + } + + /** + * Gets the children of the selected element. This will insert + * PaginationTreeItems for navigation if there are other pages. + * @param {TreeItem} element The selected element + * @returns Children of the selected element + */ + public async getChildren(element?: TreeItem): Promise { + if (!element) { + if (this.items[0] instanceof LoadingTreeItem || this.items[0] instanceof ErrorTreeItem) { + return this.items; + } + return this.addPaginationCommands(this.items.slice()); + } + + if ('children' in element && Array.isArray(element.children)) { + return element.children; + } + + return undefined; + } + + private addPaginationCommands(treeItems: TreeItem[]): TreeItem[] { + const NEXT_PAGE_LABEL = 'Next Page'; + const PREVIOUS_PAGE_LABEL = 'Previous Page'; + const NEXT_PAGE_COMMAND = `zenml.next${this.viewName}Page`; + const PREVIOUS_PAGE_COMMAND = `zenml.previous${this.viewName}Page`; + + if (treeItems.length === 0 && this.pagination.currentPage === 1) { + return treeItems; + } + + if (this.pagination.currentPage < this.pagination.totalPages) { + treeItems.push( + new CommandTreeItem(NEXT_PAGE_LABEL, NEXT_PAGE_COMMAND, undefined, 'arrow-circle-right') + ); + } + + if (this.pagination.currentPage > 1) { + treeItems.unshift( + new CommandTreeItem( + PREVIOUS_PAGE_LABEL, + PREVIOUS_PAGE_COMMAND, + undefined, + 'arrow-circle-left' + ) + ); + } + return treeItems; + } +} diff --git a/src/views/activityBar/componentView/ComponentDataProvider.ts b/src/views/activityBar/componentView/ComponentDataProvider.ts new file mode 100644 index 00000000..1dda42f9 --- /dev/null +++ b/src/views/activityBar/componentView/ComponentDataProvider.ts @@ -0,0 +1,152 @@ +// Copyright(c) ZenML GmbH 2024. All Rights Reserved. +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied.See the License for the specific language governing +// permissions and limitations under the License. +import { State } from 'vscode-languageclient'; +import { EventBus } from '../../../services/EventBus'; +import { LSClient } from '../../../services/LSClient'; +import { + LSCLIENT_STATE_CHANGED, + LSP_ZENML_CLIENT_INITIALIZED, + LSP_ZENML_STACK_CHANGED, +} from '../../../utils/constants'; +import { ErrorTreeItem, createErrorItem, createAuthErrorItem } from '../common/ErrorTreeItem'; +import { LOADING_TREE_ITEMS } from '../common/LoadingTreeItem'; +import { CommandTreeItem } from '../common/PaginationTreeItems'; +import { ComponentsListResponse, StackComponent } from '../../../types/StackTypes'; +import { StackComponentTreeItem } from '../stackView/StackTreeItems'; +import { PaginatedDataProvider } from '../common/PaginatedDataProvider'; + +export class ComponentDataProvider extends PaginatedDataProvider { + private static instance: ComponentDataProvider | null = null; + private eventBus = EventBus.getInstance(); + private zenmlClientReady = false; + + constructor() { + super(); + this.subscribeToEvents(); + this.items = [LOADING_TREE_ITEMS.get('components')!]; + this.viewName = 'Component'; + } + + /** + * Subscribes to relevant events to trigger a refresh of the tree view. + */ + public subscribeToEvents(): void { + this.eventBus.on(LSCLIENT_STATE_CHANGED, (newState: State) => { + if (newState === State.Running) { + this.refresh(); + } else { + this.items = [LOADING_TREE_ITEMS.get('lsClient')!]; + this._onDidChangeTreeData.fire(undefined); + } + }); + + this.eventBus.on(LSP_ZENML_CLIENT_INITIALIZED, (isInitialized: boolean) => { + this.zenmlClientReady = isInitialized; + + if (!isInitialized) { + this.items = [LOADING_TREE_ITEMS.get('components')!]; + this._onDidChangeTreeData.fire(undefined); + return; + } + this.refresh(); + this.eventBus.off(LSP_ZENML_STACK_CHANGED, () => this.refresh()); + this.eventBus.on(LSP_ZENML_STACK_CHANGED, () => this.refresh()); + }); + } + + /** + * Retrieves the singleton instance of ComponentDataProvider + * + * @returns {ComponentDataProvider} The signleton instance. + */ + public static getInstance(): ComponentDataProvider { + if (!ComponentDataProvider.instance) { + ComponentDataProvider.instance = new ComponentDataProvider(); + } + + return ComponentDataProvider.instance; + } + + /** + * Refreshes the view. + */ + public async refresh(): Promise { + this.items = [LOADING_TREE_ITEMS.get('components')!]; + this._onDidChangeTreeData.fire(undefined); + + const page = this.pagination.currentPage; + const itemsPerPage = this.pagination.itemsPerPage; + + try { + const newComponentsData = await this.fetchComponents(page, itemsPerPage); + this.items = newComponentsData; + } catch (e) { + this.items = createErrorItem(e); + } + + this._onDidChangeTreeData.fire(undefined); + } + + private async fetchComponents(page: number = 1, itemsPerPage: number = 10) { + if (!this.zenmlClientReady) { + return [LOADING_TREE_ITEMS.get('zenmlClient')!]; + } + + try { + const lsClient = LSClient.getInstance(); + const result = await lsClient.sendLsClientRequest('listComponents', [ + page, + itemsPerPage, + ]); + + if (Array.isArray(result) && result.length === 1 && 'error' in result[0]) { + const errorMessage = result[0].error; + if (errorMessage.includes('Authentication error')) { + return createAuthErrorItem(errorMessage); + } + } + + if (!result || 'error' in result) { + if ('clientVersion' in result && 'serverVersion' in result) { + return createErrorItem(result); + } else { + console.error(`Failed to fetch stack components: ${result.error}`); + return []; + } + } + + if ('items' in result) { + const { items, total, total_pages, index, max_size } = result; + this.pagination = { + currentPage: index, + itemsPerPage: max_size, + totalItems: total, + totalPages: total_pages, + }; + + const components = items.map( + (component: StackComponent) => new StackComponentTreeItem(component) + ); + return components; + } else { + console.error('Unexpected response format:', result); + return []; + } + } catch (e: any) { + console.error(`Failed to fetch components: ${e}`); + return [ + new ErrorTreeItem('Error', `Failed to fetch components: ${e.message || e.toString()}`), + ]; + } + } +} diff --git a/src/views/activityBar/pipelineView/PipelineDataProvider.ts b/src/views/activityBar/pipelineView/PipelineDataProvider.ts index f707cc91..fc83d062 100644 --- a/src/views/activityBar/pipelineView/PipelineDataProvider.ts +++ b/src/views/activityBar/pipelineView/PipelineDataProvider.ts @@ -25,27 +25,20 @@ import { ErrorTreeItem, createErrorItem, createAuthErrorItem } from '../common/E import { LOADING_TREE_ITEMS } from '../common/LoadingTreeItem'; import { PipelineRunTreeItem, PipelineTreeItem } from './PipelineTreeItems'; import { CommandTreeItem } from '../common/PaginationTreeItems'; +import { PaginatedDataProvider } from '../common/PaginatedDataProvider'; /** * Provides data for the pipeline run tree view, displaying detailed information about each pipeline run. */ -export class PipelineDataProvider implements TreeDataProvider { - private _onDidChangeTreeData = new EventEmitter(); - readonly onDidChangeTreeData = this._onDidChangeTreeData.event; - +export class PipelineDataProvider extends PaginatedDataProvider { private static instance: PipelineDataProvider | null = null; private eventBus = EventBus.getInstance(); private zenmlClientReady = false; - private pipelineRuns: PipelineTreeItem[] | TreeItem[] = [LOADING_TREE_ITEMS.get('pipelineRuns')!]; - - private pagination = { - currentPage: 1, - itemsPerPage: 20, - totalItems: 0, - totalPages: 0, - }; constructor() { + super(); + this.items = [LOADING_TREE_ITEMS.get('pipelineRuns')!]; + this.viewName = 'PipelineRuns'; this.subscribeToEvents(); } @@ -57,7 +50,7 @@ export class PipelineDataProvider implements TreeDataProvider { if (newState === State.Running) { this.refresh(); } else { - this.pipelineRuns = [LOADING_TREE_ITEMS.get('lsClient')!]; + this.items = [LOADING_TREE_ITEMS.get('lsClient')!]; this._onDidChangeTreeData.fire(undefined); } }); @@ -66,7 +59,7 @@ export class PipelineDataProvider implements TreeDataProvider { this.zenmlClientReady = isInitialized; if (!isInitialized) { - this.pipelineRuns = [LOADING_TREE_ITEMS.get('pipelineRuns')!]; + this.items = [LOADING_TREE_ITEMS.get('pipelineRuns')!]; this._onDidChangeTreeData.fire(undefined); return; } @@ -83,10 +76,10 @@ export class PipelineDataProvider implements TreeDataProvider { * @returns {PipelineDataProvider} The singleton instance. */ public static getInstance(): PipelineDataProvider { - if (!this.instance) { - this.instance = new PipelineDataProvider(); + if (!PipelineDataProvider.instance) { + PipelineDataProvider.instance = new PipelineDataProvider(); } - return this.instance; + return PipelineDataProvider.instance; } /** @@ -95,74 +88,21 @@ export class PipelineDataProvider implements TreeDataProvider { * @returns A promise resolving to void. */ public async refresh(): Promise { - this.pipelineRuns = [LOADING_TREE_ITEMS.get('pipelineRuns')!]; + this.items = [LOADING_TREE_ITEMS.get('pipelineRuns')!]; this._onDidChangeTreeData.fire(undefined); const page = this.pagination.currentPage; const itemsPerPage = this.pagination.itemsPerPage; try { const newPipelineData = await this.fetchPipelineRuns(page, itemsPerPage); - this.pipelineRuns = newPipelineData; + this.items = newPipelineData; } catch (error: any) { - this.pipelineRuns = createErrorItem(error); + this.items = createErrorItem(error); } this._onDidChangeTreeData.fire(undefined); } - /** - * Retrieves the tree item for a given pipeline run. - * - * @param element The pipeline run item. - * @returns The corresponding VS Code tree item. - */ - getTreeItem(element: TreeItem): TreeItem { - return element; - } - - /** - * Retrieves the children for a given tree item. - * - * @param element The parent tree item. If undefined, root pipeline runs are fetched. - * @returns A promise resolving to an array of child tree items or undefined if there are no children. - */ - async getChildren(element?: TreeItem): Promise { - if (!element) { - if (Array.isArray(this.pipelineRuns) && this.pipelineRuns.length > 0) { - return this.pipelineRuns; - } - - // Fetch pipeline runs for the current page and add pagination controls if necessary - const runs = await this.fetchPipelineRuns( - this.pagination.currentPage, - this.pagination.itemsPerPage - ); - if (this.pagination.currentPage < this.pagination.totalPages) { - runs.push( - new CommandTreeItem( - 'Next Page', - 'zenml.nextPipelineRunsPage', - undefined, - 'arrow-circle-right' - ) - ); - } - if (this.pagination.currentPage > 1) { - runs.unshift( - new CommandTreeItem( - 'Previous Page', - 'zenml.previousPipelineRunsPage', - undefined, - 'arrow-circle-left' - ) - ); - } - return runs; - } else if (element instanceof PipelineTreeItem) { - return element.children; - } - return undefined; - } /** * Fetches pipeline runs from the server and maps them to tree items for display. * @@ -234,29 +174,4 @@ export class PipelineDataProvider implements TreeDataProvider { ]; } } - - public async goToNextPage() { - if (this.pagination.currentPage < this.pagination.totalPages) { - this.pagination.currentPage++; - await this.refresh(); - } - } - - public async goToPreviousPage() { - if (this.pagination.currentPage > 1) { - this.pagination.currentPage--; - await this.refresh(); - } - } - - public async updateItemsPerPage() { - const selected = await window.showQuickPick(ITEMS_PER_PAGE_OPTIONS, { - placeHolder: 'Choose the max number of pipeline runs to display per page', - }); - if (selected) { - this.pagination.itemsPerPage = parseInt(selected, 10); - this.pagination.currentPage = 1; - await this.refresh(); - } - } } diff --git a/src/views/activityBar/stackView/StackDataProvider.ts b/src/views/activityBar/stackView/StackDataProvider.ts index 9dfe1b76..b611767d 100644 --- a/src/views/activityBar/stackView/StackDataProvider.ts +++ b/src/views/activityBar/stackView/StackDataProvider.ts @@ -25,26 +25,18 @@ import { ErrorTreeItem, createErrorItem, createAuthErrorItem } from '../common/E import { LOADING_TREE_ITEMS } from '../common/LoadingTreeItem'; import { StackComponentTreeItem, StackTreeItem } from './StackTreeItems'; import { CommandTreeItem } from '../common/PaginationTreeItems'; +import { PaginatedDataProvider } from '../common/PaginatedDataProvider'; -export class StackDataProvider implements TreeDataProvider { - private _onDidChangeTreeData = new EventEmitter(); - readonly onDidChangeTreeData: Event = - this._onDidChangeTreeData.event; - +export class StackDataProvider extends PaginatedDataProvider { private static instance: StackDataProvider | null = null; private eventBus = EventBus.getInstance(); private zenmlClientReady = false; - public stacks: StackTreeItem[] | TreeItem[] = [LOADING_TREE_ITEMS.get('stacks')!]; - - private pagination = { - currentPage: 1, - itemsPerPage: 20, - totalItems: 0, - totalPages: 0, - }; constructor() { + super(); this.subscribeToEvents(); + this.items = [LOADING_TREE_ITEMS.get('stacks')!]; + this.viewName = 'Stack'; } /** @@ -55,7 +47,7 @@ export class StackDataProvider implements TreeDataProvider { if (newState === State.Running) { this.refresh(); } else { - this.stacks = [LOADING_TREE_ITEMS.get('lsClient')!]; + this.items = [LOADING_TREE_ITEMS.get('lsClient')!]; this._onDidChangeTreeData.fire(undefined); } }); @@ -64,7 +56,7 @@ export class StackDataProvider implements TreeDataProvider { this.zenmlClientReady = isInitialized; if (!isInitialized) { - this.stacks = [LOADING_TREE_ITEMS.get('stacks')!]; + this.items = [LOADING_TREE_ITEMS.get('stacks')!]; this._onDidChangeTreeData.fire(undefined); return; } @@ -86,23 +78,13 @@ export class StackDataProvider implements TreeDataProvider { return this.instance; } - /** - * Returns the provided tree item. - * - * @param {TreeItem} element The tree item to return. - * @returns The corresponding VS Code tree item. - */ - getTreeItem(element: TreeItem): TreeItem { - return element; - } - /** * Refreshes the tree view data by refetching stacks and triggering the onDidChangeTreeData event. * * @returns {Promise} A promise that resolves when the tree view data has been refreshed. */ public async refresh(): Promise { - this.stacks = [LOADING_TREE_ITEMS.get('stacks')!]; + this.items = [LOADING_TREE_ITEMS.get('stacks')!]; this._onDidChangeTreeData.fire(undefined); const page = this.pagination.currentPage; @@ -110,9 +92,9 @@ export class StackDataProvider implements TreeDataProvider { try { const newStacksData = await this.fetchStacksWithComponents(page, itemsPerPage); - this.stacks = newStacksData; + this.items = newStacksData; } catch (error: any) { - this.stacks = createErrorItem(error); + this.items = createErrorItem(error); } this._onDidChangeTreeData.fire(undefined); @@ -179,69 +161,6 @@ export class StackDataProvider implements TreeDataProvider { } } - public async goToNextPage() { - if (this.pagination.currentPage < this.pagination.totalPages) { - this.pagination.currentPage++; - await this.refresh(); - } - } - - public async goToPreviousPage() { - if (this.pagination.currentPage > 1) { - this.pagination.currentPage--; - await this.refresh(); - } - } - - public async updateItemsPerPage() { - const selected = await window.showQuickPick(ITEMS_PER_PAGE_OPTIONS, { - placeHolder: 'Choose the max number of stacks to display per page', - }); - if (selected) { - this.pagination.itemsPerPage = parseInt(selected, 10); - this.pagination.currentPage = 1; - await this.refresh(); - } - } - - /** - * Retrieves the children of a given tree item. - * - * @param {TreeItem} element The tree item whose children to retrieve. - * @returns A promise resolving to an array of child tree items or undefined if there are no children. - */ - async getChildren(element?: TreeItem): Promise { - if (!element) { - if (Array.isArray(this.stacks) && this.stacks.length > 0) { - return this.stacks; - } - - const stacks = await this.fetchStacksWithComponents( - this.pagination.currentPage, - this.pagination.itemsPerPage - ); - if (this.pagination.currentPage < this.pagination.totalPages) { - stacks.push( - new CommandTreeItem('Next Page', 'zenml.nextStackPage', undefined, 'arrow-circle-right') - ); - } - if (this.pagination.currentPage > 1) { - stacks.unshift( - new CommandTreeItem( - 'Previous Page', - 'zenml.previousStackPage', - undefined, - 'arrow-circle-left' - ) - ); - } - return stacks; - } else if (element instanceof StackTreeItem) { - return element.children; - } - return undefined; - } - /** * Helper method to determine if a stack is the active stack. * diff --git a/src/views/activityBar/stackView/StackTreeItems.ts b/src/views/activityBar/stackView/StackTreeItems.ts index 524e8993..ea54c802 100644 --- a/src/views/activityBar/stackView/StackTreeItems.ts +++ b/src/views/activityBar/stackView/StackTreeItems.ts @@ -44,13 +44,15 @@ export class StackTreeItem extends vscode.TreeItem { export class StackComponentTreeItem extends vscode.TreeItem { constructor( public component: StackComponent, - public stackId: string + public stackId?: string ) { super(component.name, vscode.TreeItemCollapsibleState.None); - this.tooltip = `Type: ${component.type}, Flavor: ${component.flavor}, ID: ${stackId}`; + this.tooltip = stackId + ? `Type: ${component.type}, Flavor: ${component.flavor}, ID: ${stackId}` + : `Type: ${component.type}, Flavor: ${component.flavor}`; this.description = `${component.type} (${component.flavor})`; this.contextValue = 'stackComponent'; - this.id = `${stackId}-${component.id}`; + this.id = stackId ? `${stackId}-${component.id}` : `${component.id}`; } } diff --git a/src/views/statusBar/index.ts b/src/views/statusBar/index.ts index c6db6759..d931a52e 100644 --- a/src/views/statusBar/index.ts +++ b/src/views/statusBar/index.ts @@ -116,17 +116,17 @@ export default class ZenMLStatusBar { */ private async switchStack(): Promise { const stackDataProvider = StackDataProvider.getInstance(); - const { stacks } = stackDataProvider; + const { items } = stackDataProvider; - const containsErrors = stacks.some(stack => stack instanceof ErrorTreeItem); + const containsErrors = items.some(stack => stack instanceof ErrorTreeItem); - if (containsErrors || stacks.length === 0) { + if (containsErrors || items.length === 0) { window.showErrorMessage('No stacks available.'); return; } - const activeStack = stacks.find(stack => stack.id === this.activeStackId); - const otherStacks = stacks.filter(stack => stack.id !== this.activeStackId); + const activeStack = items.find(stack => stack.id === this.activeStackId); + const otherStacks = items.filter(stack => stack.id !== this.activeStackId); const quickPickItems = [ {
{{capitalize name}}