From 328076b3d2fa526c3cceaa7c891cba65fd45d979 Mon Sep 17 00:00:00 2001 From: parul-l Date: Mon, 4 Mar 2024 16:02:12 -0500 Subject: [PATCH] Allow multiple folders in `class_package` node registration key (#123) * Allowing packages in metadata to overwrite env variable * Adding documentation for class_package key * Restoring plot builds * remove extra line --- README_METADATA.md | 74 ++++++++++++++++++++++ primrose/configuration/configuration.py | 26 +++++--- pyproject.toml | 2 +- test/hello_world_read_process_write.json | 53 ++++++++++++++++ test/sample_project/change_column_order.py | 49 ++++++++++++++ test/sample_project/data/tennis_output.csv | 15 +++++ test/test_configuration.py | 12 ++-- test/test_dag_runner.py | 9 +++ 8 files changed, 223 insertions(+), 17 deletions(-) create mode 100644 test/hello_world_read_process_write.json create mode 100644 test/sample_project/change_column_order.py create mode 100644 test/sample_project/data/tennis_output.csv diff --git a/README_METADATA.md b/README_METADATA.md index 47d6710a..70320096 100644 --- a/README_METADATA.md +++ b/README_METADATA.md @@ -7,6 +7,7 @@ Users are free to add whatever keys and configuration they wish in this section. In this document, we will cover the following metadata keys: - traverser + - class_package - data_object - section_registry - section_run @@ -19,6 +20,8 @@ For instance, you might have a configuration "traverser": "DepthFirstTraverser", + "class_package": ["./src", "./sample_project"], + "data_object": { "read_from_cache": false, "read_filename": "/tmp/data_object_20190618.dill", @@ -70,6 +73,77 @@ The way users defined what to use is with the `traverser` key in `metadata`. 
If } ``` +# class_package +When creating additional nodes for your projects, you can specify where to look for potential nodes to register, via + +(1) setting the `PRIMROSE_EXT_NODE_PACKAGE` environment variable, or +(2) specifying the folders in the `metadata` portion of your primrose config via the `class_package` key. + + +For example, if all custom nodes are in the `src` folder, +``` +├── config +│ ├── my_primrose_config.yml +├── src +│ ├── readers +│ │ ├── reader_node.py +│ ├── pipelines +│ │ ├── pipeline_node.py +``` +we can set `PRIMROSE_EXT_NODE_PACKAGE=./src`, or define the primrose config as +``` +metadata: + section_registry: + - reader_config + - writer_config + class_package: + - './src' + +implementation_config: + reader_config: + ... + writer_config: + ... +``` + +The latter is particularly useful when multiple projects are being built off the same primrose package. For example, given the following folder structure +``` +├── config +│ ├── primrose_config1.yml +│ ├── primrose_config2.yml +├── src +│ ├── readers +│ │ ├── reader_node.py +│ ├── pipelines +│ │ ├── pipeline_node.py +├── projects +│ ├── sample_project1 +│ │ ├── sample_project1_node.py +│ ├── sample_project2 +│ │ ├── sample_project2_node.py +``` +we can define the primrose config for `sample_project1` as +``` +metadata: + class_package: + - './src' + - './projects/sample_project1' + +implementation_config: + ... +``` +and the primrose config for `sample_project2` as +``` +metadata: + class_package: + - './src' + - './projects/sample_project2' + +implementation_config: + ... +``` +This ensures all classes within the `src` and `projects/sample_project{i}` folders are considered when registering nodes specified in the primrose configs. 
+ # section_registry key The default assumption of ``` diff --git a/primrose/configuration/configuration.py b/primrose/configuration/configuration.py index 74a35186..0114bf14 100644 --- a/primrose/configuration/configuration.py +++ b/primrose/configuration/configuration.py @@ -455,8 +455,9 @@ def _import_file(full_name, path): def _get_file_candidates(self): """Get file candidates to search through when specifying a class package. - Priority will first consider environment variable PRIMROSE_EXT_NODE_PACKAGE. If unset, will - search the configuration metadata for key `class_package`. If nothing is specified, in either + First consider value in PRIMROSE_EXT_NODE_PACKAGE but give priority to `class_package` + in configuration metadata. That is, if `class_package` is set in the configuration metadata, + it will override the PRIMROSE_EXT_NODE_PACKAGE variable. If nothing is specified, in either location, an empty list is returned. Returns: @@ -464,27 +465,32 @@ def _get_file_candidates(self): """ # for now assume packages/top level only if CLASS_ENV_PACKAGE_KEY in os.environ: + logging.info("Using package from environment variable") pkg_name = os.environ[CLASS_ENV_PACKAGE_KEY] - elif self.config_metadata: - if "class_package" in self.config_metadata: - pkg_name = self.config_metadata["class_package"] - else: - return [] + if self.config_metadata and "class_package" in self.config_metadata: + # overwrites package set in environment variable CLASS_ENV_PACKAGE_KEY + logging.info("Using package from configuration metadata") + pkg_name = self.config_metadata["class_package"] else: return [] # look for path to module to find potential file candidates try: # if we are passed something like __init__.py, grab the package - if os.path.isfile(pkg_name): + if isinstance(pkg_name, str) and os.path.isfile(pkg_name): pkg_name = os.path.dirname(pkg_name) # if we have an actual package from pip install - if not os.path.isdir(pkg_name): + if isinstance(pkg_name, str) and not 
os.path.isdir(pkg_name): pkg_name = os.path.dirname(importlib.import_module(pkg_name).__file__) except ModuleNotFoundError: logging.warning("Could not find module specified for external node configuration") return [] - candidates = glob.glob(os.path.join(pkg_name, "**", "*.py"), recursive=True) + if isinstance(pkg_name, str): + pkg_name = [pkg_name] + + candidates = [] + for pkg in pkg_name: + candidates += glob.glob(os.path.join(pkg, "**", "*.py"), recursive=True) return candidates diff --git a/pyproject.toml b/pyproject.toml index 64c86599..84ff1465 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "mysql-connector-python>=8.0.32", "slackclient>=2.9.4", "testfixtures>=7.1.0", - "moto>=4.1.4", + "moto==4.1.4", "nltk>=3.8.1", "pydot>=1.4.2", ] diff --git a/test/hello_world_read_process_write.json b/test/hello_world_read_process_write.json new file mode 100644 index 00000000..5c6b7200 --- /dev/null +++ b/test/hello_world_read_process_write.json @@ -0,0 +1,53 @@ +{ + /* + * simple ETL: read from a CSV, reorder the columns and then write data to a CSV + */ + + "metadata": { + "section_registry": [ + "reader_config", + "pipeline_config", + "writer_config" + ], + "class_package": [ + "./src", + "./test/sample_project" + ], + }, + "implementation_config": { + "reader_config": { + "read_data": { + "class": "CsvReader", + "filename": "data/tennis.csv", + "destinations": [ + "reorder_cols" + ] + } + }, + "pipeline_config": { + "reorder_cols": { + "class": "ColumnReorder", + "cols_order": [ + "id", + "outlook", + "humidity", + "play", + "temp", + "windy" + ], + "destinations": [ + "write_output" + ] + } + }, + "writer_config": { + "write_output": { + "class": "CsvWriter", + "key": "data", + "dir": "test/sample_project/data", + "filename": "tennis_output.csv" + } + } + } +} + diff --git a/test/sample_project/change_column_order.py b/test/sample_project/change_column_order.py new file mode 100644 index 00000000..2c120dfe --- /dev/null +++ 
b/test/sample_project/change_column_order.py @@ -0,0 +1,49 @@ +"""Module to reorder columns in a dataframe + +Author(s): + Parul Laul (parul.laul@ww.com) + +""" +import logging +from primrose.base.pipeline import AbstractPipeline + + +class ColumnReorder(AbstractPipeline): + """Reorder columns in a dataframe""" + + @staticmethod + def necessary_config(node_config): + """Return the necessary configuration keys for the DataFrameJoiner object + + Args: + node_config (dict): set of parameters / attributes for the node + + Note: + cols_order: list of column names in desired order + + Returns: + set of keys + + """ + return set(["cols_order"]) + + def transform(self, data_object): + """Get DataFrames from the data object and reorder columns + + Args: + data_object (DataObject): instance of DataObject + + Returns: + data_object (DataObject): instance of DataObject + + """ + + upstream_data = data_object.get_upstream_data( + self.instance_name, pop_data=False + ) + logging.info("Reordering columns") + data = upstream_data['data'][self.node_config["cols_order"]] + + data_object.add(self, data, overwrite=False) + + return data_object \ No newline at end of file diff --git a/test/sample_project/data/tennis_output.csv b/test/sample_project/data/tennis_output.csv new file mode 100644 index 00000000..9d2ee778 --- /dev/null +++ b/test/sample_project/data/tennis_output.csv @@ -0,0 +1,15 @@ +id,outlook,humidity,play,temp,windy +1,sunny,high,no,hot,False +2,sunny,high,no,hot,True +3,overcast,high,yes,hot,False +4,rainy,high,yes,mild,False +5,rainy,normal,yes,cool,False +6,rainy,normal,no,cool,True +7,overcast,normal,yes,cool,True +8,sunny,high,no,mild,False +9,sunny,normal,yes,cool,False +10,rainy,normal,yes,mild,False +11,sunny,normal,yes,mild,True +12,overcast,high,yes,mild,True +13,overcast,normal,yes,hot,False +14,rainy,high,no,mild,True diff --git a/test/test_configuration.py b/test/test_configuration.py index 17439008..436ac9a1 100644 --- a/test/test_configuration.py +++ 
b/test/test_configuration.py @@ -603,19 +603,19 @@ def test_class_package(mock_env): NodeFactory().unregister("TestExtNode") -def test_env_override_class_package(mock_env): +def test_class_package_override_env(mock_env): config = { "metadata": {"class_package": "junk"}, "implementation_config": { "reader_config": {"read_data": {"class": "TestExtNode", "destinations": []}} }, } - config = Configuration( - config_location=None, is_dict_config=True, dict_config=config + with pytest.raises(Exception) as e: + Configuration(config_location=None, is_dict_config=True, dict_config=config) + assert ( + "Cannot register node class TestExtNode" + in str(e) ) - assert config.config_string - assert config.config_hash - NodeFactory().unregister("TestExtNode") def test_incorrect_class_package(): diff --git a/test/test_dag_runner.py b/test/test_dag_runner.py index e414ddbf..15a1439a 100644 --- a/test/test_dag_runner.py +++ b/test/test_dag_runner.py @@ -18,6 +18,7 @@ from abc import abstractmethod from primrose.base.writer import AbstractWriter +TEST_DIR = os.path.dirname(os.path.abspath(__file__)) def test_run(): config = { @@ -766,3 +767,11 @@ def test_run_pruned(): ("root", "INFO", "left node!"), ("root", "INFO", "All done. Bye bye!"), ) + + +def test_class_packages_run(): + config_loc = os.path.join(TEST_DIR, "hello_world_read_process_write.json") + configuration = Configuration(config_location=config_loc) + + runner = DagRunner(configuration) + runner.run() \ No newline at end of file