diff --git a/README.md b/README.md index 4c36c91..9f56f39 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,114 @@ loader_cls = get_loader( data = yaml.load(content, Loader=loader_cls) ``` +## Custom Tag Handling + +Yamling provides a `YAMLParser` class for handling custom YAML tags. This allows you to define how specific tagged values should be processed during YAML loading. + +### Basic Tag Registration + +You can register tag handlers using either a decorator or explicit registration: + +```python +from yamling import YAMLParser +from dataclasses import dataclass + +@dataclass +class Person: + name: str + age: int + +# Create parser instance +yaml_parser = YAMLParser() + +# Register handler using decorator +@yaml_parser.register("person") +def handle_person(data: dict) -> Person: + return Person(**data) + +# Or register handler explicitly +def handle_uppercase(data: str) -> str: + return data.upper() + +yaml_parser.register_handler("uppercase", handle_uppercase) +``` + +### Using Custom Tags + +Once registered, you can use the custom tags in your YAML: + +```yaml +# config.yaml +user: !person + name: John Doe + age: 30 +message: !uppercase "hello world" +``` + +Load the YAML using the parser: + +```python +# Load from string +data = yaml_parser.load_yaml(""" +user: !person + name: John Doe + age: 30 +message: !uppercase "hello world" +""") + +# Or load from file +data = yaml_parser.load_yaml_file("config.yaml") + +print(data["user"]) # Person(name='John Doe', age=30) +print(data["message"]) # "HELLO WORLD" +``` + +### Complex Structures + +Custom tags can be used in nested structures and lists: + +```yaml +team: + manager: !person + name: Alice Smith + age: 45 + members: + - !person + name: Bob Johnson + age: 30 + - !person + name: Carol White + age: 28 +messages: + - !uppercase "welcome" + - !uppercase "goodbye" +``` + +### Combining with Other Features + +The `YAMLParser` class supports all of Yamling's standard features: + +```python +data = yaml_parser.load_yaml_file( + "config.yaml", + mode="safe", # Safety mode + include_base_path="configs/", # For !include directives + resolve_strings=True, # Enable Jinja2 template resolution + resolve_inherit=True, # Enable inheritance + jinja_env=jinja_env # Custom Jinja2 environment +) +``` + +### Available Tags + +You can list all registered tags: + +```python +tags = yaml_parser.list_tags() +print(tags) # ['!person', '!uppercase'] +``` + + ## Universal load / dump interface Yamling provides a universal load function that can handle YAML, JSON, TOML, and INI files. diff --git a/src/yamling/__init__.py b/src/yamling/__init__.py index c43d8b3..b963ce6 100644 --- a/src/yamling/__init__.py +++ b/src/yamling/__init__.py @@ -5,6 +5,7 @@ from yamling.load_universal import load, load_file, ParsingError from yamling.yaml_dumpers import dump_yaml from yamling.dump_universal import DumpingError, dump, dump_file +from yamling.yamlparser import YAMLParser YAMLError = yaml.YAMLError # Reference for external libs that need to catch this error @@ -22,4 +23,5 @@ "ParsingError", "DumpingError", "YAMLInput", + "YAMLParser", ] diff --git a/src/yamling/yamlparser.py b/src/yamling/yamlparser.py new file mode 100644 index 0000000..c21c2b1 --- /dev/null +++ b/src/yamling/yamlparser.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +from typing import TYPE_CHECKING, Any, TypeVar + +import yaml +from yaml import MappingNode, Node, SafeLoader, ScalarNode, SequenceNode + +from yamling import yaml_loaders + + +if TYPE_CHECKING: + import os + + import fsspec + import jinja2 + + from yamling import yamltypes + + +# Type for the handler function +T = TypeVar("T") +HandlerFunc = Callable[[Any], T] + + +class YAMLParser: + """Manages custom YAML tags and provides YAML loading capabilities.""" + + def __init__(self) -> None: + self._tag_handlers: dict[str, HandlerFunc] = {} + self._tag_prefix: str = "!" # Default prefix for tags + + def register(self, tag_name: str) -> Callable[[HandlerFunc[T]], HandlerFunc[T]]: + """Decorator to register a new tag handler. + + Args: + tag_name: Name of the tag without prefix + + Returns: + Decorator function that registers the handler + + Usage: + @yaml_parser.register("person") + def handle_person(data: dict) -> Person: + return Person(**data) + """ + + def decorator(func: HandlerFunc[T]) -> HandlerFunc[T]: + full_tag = f"{self._tag_prefix}{tag_name}" + self._tag_handlers[full_tag] = func + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> T: + return func(*args, **kwargs) + + return wrapper + + return decorator + + def register_handler(self, tag_name: str, handler: HandlerFunc[T]) -> None: + """Explicitly register a tag handler function. + + Args: + tag_name: Name of the tag without prefix + handler: Function that processes the tagged data + """ + full_tag = f"{self._tag_prefix}{tag_name}" + self._tag_handlers[full_tag] = handler + + def process_tag(self, tag: str, data: Any) -> Any: + """Process data with the registered handler for the given tag. + + Args: + tag: Full tag name (including prefix) + data: Data to be processed + + Raises: + ValueError: If no handler is registered for the tag + """ + if tag not in self._tag_handlers: + msg = f"No handler registered for tag: {tag}" + raise ValueError(msg) + return self._tag_handlers[tag](data) + + def get_handler(self, tag: str) -> HandlerFunc | None: + """Get the handler function for a specific tag. + + Args: + tag: Full tag name (including prefix) + + Returns: + Handler function if found, None otherwise + """ + return self._tag_handlers.get(tag) + + def list_tags(self) -> list[str]: + """Return a list of registered tags. + + Returns: + List of registered tag names + """ + return list(self._tag_handlers.keys()) + + def create_constructor(self, tag_name: str) -> Callable[[yaml.Loader, Node], Any]: + """Create a YAML constructor function for a specific tag. + + Args: + tag_name: Name of the tag without prefix + + Returns: + Constructor function for the YAML loader + """ + full_tag = f"{self._tag_prefix}{tag_name}" + + def constructor(loader: yaml.Loader, node: Node) -> Any: + if isinstance(node, ScalarNode): + value = loader.construct_scalar(node) + elif isinstance(node, SequenceNode): + value = loader.construct_sequence(node) + elif isinstance(node, MappingNode): + value = loader.construct_mapping(node) + else: + msg = f"Unsupported node type for tag {full_tag}" + raise TypeError(msg) + + return self.process_tag(full_tag, value) + + return constructor + + def register_with_loader( + self, loader_class: yamltypes.LoaderType = SafeLoader + ) -> None: + """Register all tags with a YAML loader class. + + Args: + loader_class: The YAML loader class to register with + """ + for tag in self._tag_handlers: + loader_class.add_constructor(tag, self.create_constructor(tag[1:])) + + def load_yaml( + self, + text: yaml_loaders.YAMLInput, + *, + mode: yamltypes.LoaderStr | yamltypes.LoaderType = "unsafe", + include_base_path: str + | os.PathLike[str] + | fsspec.AbstractFileSystem + | None = None, + resolve_strings: bool = False, + resolve_dict_keys: bool = False, + resolve_inherit: bool = False, + jinja_env: jinja2.Environment | None = None, + ) -> Any: + """Load YAML content with custom tag handlers. + + Args: + text: The YAML content to load + mode: YAML loader safety mode ('unsafe', 'full', or 'safe') + Custom YAML loader classes are also accepted + include_base_path: Base path for resolving !include directives + resolve_strings: Whether to resolve Jinja2 template strings + resolve_dict_keys: Whether to resolve Jinja2 templates in dictionary keys + resolve_inherit: Whether to resolve INHERIT directives + jinja_env: Optional Jinja2 environment for template resolution + + Returns: + Parsed YAML data with custom tag handling + + Example: + ```python + yaml_parser = YAMLParser() + + @yaml_parser.register("person") + def handle_person(data: dict) -> Person: + return Person(**data) + + data = yaml_parser.load_yaml( + "person: !person {name: John, age: 30}", + mode="safe", + resolve_strings=True + ) + ``` + """ + loader = yaml_loaders.LOADERS[mode] if isinstance(mode, str) else mode + self.register_with_loader(loader) + try: + return yaml_loaders.load_yaml( + text, + mode=loader, + include_base_path=include_base_path, + resolve_strings=resolve_strings, + resolve_dict_keys=resolve_dict_keys, + resolve_inherit=resolve_inherit, + jinja_env=jinja_env, + ) + except yaml.constructor.ConstructorError as e: + # Convert YAML ConstructorError to ValueError + msg = f"No handler registered for tag: {e.problem.split()[-1]}" + raise ValueError(msg) from e + + def load_yaml_file( + self, + path: str | os.PathLike[str], + *, + mode: yamltypes.LoaderStr | yamltypes.LoaderType = "unsafe", + include_base_path: str + | os.PathLike[str] + | fsspec.AbstractFileSystem + | None = None, + resolve_inherit: bool = False, + resolve_strings: bool = False, + resolve_dict_keys: bool = False, + jinja_env: jinja2.Environment | None = None, + ) -> Any: + """Load YAML file with custom tag handlers. + + Args: + path: Path to the YAML file + mode: YAML loader safety mode ('unsafe', 'full', or 'safe') + Custom YAML loader classes are also accepted + include_base_path: Base path for resolving !include directives + resolve_inherit: Whether to resolve INHERIT directives + resolve_strings: Whether to resolve Jinja2 template strings + resolve_dict_keys: Whether to resolve Jinja2 templates in dictionary keys + jinja_env: Optional Jinja2 environment for template resolution + + Returns: + Parsed YAML data with custom tag handling + + Example: + ```python + yaml_parser = YAMLParser() + + @yaml_parser.register("config") + def handle_config(data: dict) -> Config: + return Config(**data) + + data = yaml_parser.load_yaml_file( + "config.yml", + resolve_inherit=True, + include_base_path="configs/" + ) + ``` + """ + loader = yaml_loaders.LOADERS[mode] if isinstance(mode, str) else mode + self.register_with_loader(loader) + try: + return yaml_loaders.load_yaml_file( + path, + mode=loader, + include_base_path=include_base_path, + resolve_inherit=resolve_inherit, + resolve_strings=resolve_strings, + resolve_dict_keys=resolve_dict_keys, + jinja_env=jinja_env, + ) + except yaml.constructor.ConstructorError as e: + msg = f"No handler registered for tag: {e.problem.split()[-1]}" + raise ValueError(msg) from e + + +# Usage example: +if __name__ == "__main__": + from dataclasses import dataclass + + @dataclass + class Person: + name: str + age: int + + yaml_parser = YAMLParser() + + @yaml_parser.register("person") + def handle_person(data: dict[str, Any]) -> Person: + return Person(**data) + + def handle_uppercase(data: str) -> str: + return data.upper() + + yaml_parser.register_handler("uppercase", handle_uppercase) + + yaml_content = """ + person: !person + name: John Doe + age: 30 + message: !uppercase "hello world" + """ + + data = yaml_parser.load_yaml(yaml_content) + print("Parsed data:", data) + print("Available tags:", yaml_parser.list_tags()) diff --git a/tests/test_yamlparser.py b/tests/test_yamlparser.py new file mode 100644 index 0000000..52d18f8 --- /dev/null +++ b/tests/test_yamlparser.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import pytest + +from yamling.yamlparser import YAMLParser + + +if TYPE_CHECKING: + import pathlib + + +@dataclass +class Person: + name: str + age: int + + +@pytest.fixture +def yaml_parser(): + """Create a fresh YAMLParser instance for each test.""" + return YAMLParser() + + +@pytest.fixture +def setup_basic_handlers(yaml_parser: YAMLParser): + """Setup basic handlers for common test cases.""" + + @yaml_parser.register("person") + def handle_person(data: dict[str, Any]) -> Person: + return Person(**data) + + @yaml_parser.register("uppercase") + def handle_uppercase(data: str) -> str: + return data.upper() + + return yaml_parser + + +def test_register_decorator(yaml_parser: YAMLParser): + """Test handler registration using decorator syntax.""" + + @yaml_parser.register("test") + def handle_test(data: str) -> str: + return f"Test: {data}" + + assert "!test" in yaml_parser.list_tags() + + yaml_content = "value: !test hello" + result = yaml_parser.load_yaml(yaml_content) + assert result["value"] == "Test: hello" + + +def test_register_handler_method(yaml_parser: YAMLParser): + """Test explicit handler registration using register_handler method.""" + + def handle_lowercase(data: str) -> str: + return data.lower() + + yaml_parser.register_handler("lowercase", handle_lowercase) + assert "!lowercase" in yaml_parser.list_tags() + + yaml_content = "value: !lowercase HELLO" + result = yaml_parser.load_yaml(yaml_content) + assert result["value"] == "hello" + + +def test_multiple_tags(setup_basic_handlers: YAMLParser): + """Test handling multiple custom tags in the same YAML document.""" + yaml_content = """ + person: !person + name: John Doe + age: 30 + greeting: !uppercase hello + """ + + result = setup_basic_handlers.load_yaml(yaml_content) + assert isinstance(result["person"], Person) + assert result["person"].name == "John Doe" + assert result["person"].age == 30 # noqa: PLR2004 + assert result["greeting"] == "HELLO" + + +def test_nested_tags(setup_basic_handlers: YAMLParser): + """Test handling nested tags in complex structures.""" + yaml_content = """ + people: + - !person + name: John Doe + age: 30 + - !person + name: Jane Doe + age: 25 + messages: + - !uppercase hello + - !uppercase world + """ + + result = setup_basic_handlers.load_yaml(yaml_content) + assert len(result["people"]) == 2 # noqa: PLR2004 + assert all(isinstance(p, Person) for p in result["people"]) + assert result["messages"] == ["HELLO", "WORLD"] + + +def test_load_yaml_file(setup_basic_handlers: YAMLParser, tmp_path: pathlib.Path): + """Test loading YAML from a file.""" + yaml_content = """ + person: !person + name: John Doe + age: 30 + """ + + # Create temporary YAML file + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text(yaml_content) + + result = setup_basic_handlers.load_yaml_file(yaml_file) + assert isinstance(result["person"], Person) + assert result["person"].name == "John Doe" + assert result["person"].age == 30 # noqa: PLR2004 + + +def test_invalid_tag(yaml_parser: YAMLParser): + """Test handling of unregistered tags.""" + yaml_content = "value: !invalid_tag data" + + with pytest.raises(ValueError, match="No handler registered for tag"): + yaml_parser.load_yaml(yaml_content) + + +def test_list_tags(setup_basic_handlers: YAMLParser): + """Test listing registered tags.""" + tags = setup_basic_handlers.list_tags() + assert "!person" in tags + assert "!uppercase" in tags + assert len(tags) == 2 # noqa: PLR2004 + + +def test_different_node_types(yaml_parser: YAMLParser): + """Test handling different YAML node types (scalar, sequence, mapping).""" + + @yaml_parser.register("process") + def handle_process(data: Any) -> str: + return str(data) + + yaml_content = """ + scalar: !process value + sequence: !process [1, 2, 3] + mapping: !process + key: value + """ + + result = yaml_parser.load_yaml(yaml_content) + assert result["scalar"] == "value" + assert result["sequence"] == "[1, 2, 3]" + assert result["mapping"] == "{'key': 'value'}"