diff --git a/python/lsst/daf/butler/queries/expression_factory.py b/python/lsst/daf/butler/queries/expression_factory.py index c880c18c2b..4f0fc75289 100644 --- a/python/lsst/daf/butler/queries/expression_factory.py +++ b/python/lsst/daf/butler/queries/expression_factory.py @@ -44,7 +44,7 @@ # This module uses ExpressionProxy and its subclasses to wrap ColumnExpression, # but it just returns OrderExpression and Predicate objects directly, because -# we don't need to overload an operators or define any methods on those. +# we don't need to overload any operators or define any methods on those. class ExpressionProxy: @@ -135,7 +135,7 @@ def in_iterable(self, others: Iterable) -> rt.Predicate: Parameters ---------- - others : `Iterable` + others : `collections.abc.Iterable` An iterable of `ExpressionProxy` or values to be interpreted as literals. @@ -268,7 +268,7 @@ class DimensionProxy(ScalarExpressionProxy, DimensionElementProxy): Parameters ---------- - element : `DimensionElement` + dimension : `DimensionElement` Element this object wraps. Notes @@ -355,22 +355,61 @@ def __getitem__(self, name: str | EllipsisType) -> DatasetTypeProxy: return DatasetTypeProxy(name) def not_(self, operand: rt.Predicate) -> rt.Predicate: - """Apply a logical NOT operation to a boolean expression.""" - return rt.LogicalNot.model_construct(operand=operand) + """Apply a logical NOT operation to a boolean expression. + + Parameters + ---------- + operand : `relation_tree.Predicate` + Expression to invert. + + Returns + ------- + logical_not : `relation_tree.Predicate` + A boolean expression that evaluates to the opposite of ``operand``. + """ + return operand.logical_not() - def all(self, *args: rt.Predicate) -> rt.Predicate: - """Combine a sequence of boolean expressions with logical AND.""" - operands: list[rt.Predicate] = [] + def all(self, first: rt.Predicate, /, *args: rt.Predicate) -> rt.Predicate: + """Combine a sequence of boolean expressions with logical AND. + + Parameters + ---------- + first : `relation_tree.Predicate` + First operand (required). + *args + Additional operands. + + Returns + ------- + logical_and : `relation_tree.Predicate` + A boolean expression that evaluates to `True` only if all operands + evaluate to `True. + """ + result = first for arg in args: - operands.extend(arg._flatten_and()) - return rt.LogicalAnd.model_construct(operands=tuple(operands)) + result = result.logical_and(arg) + return result + + def any(self, first: rt.Predicate, /, *args: rt.Predicate) -> rt.Predicate: + """Combine a sequence of boolean expressions with logical OR. + + Parameters + ---------- + first : `relation_tree.Predicate` + First operand (required). + *args + Additional operands. - def any(self, *args: rt.Predicate) -> rt.Predicate: - """Combine a sequence of boolean expressions with logical OR.""" - operands: list[rt.Predicate] = [] + Returns + ------- + logical_or : `relation_tree.Predicate` + A boolean expression that evaluates to `True` if any operand + evaluates to `True. + """ + result = first for arg in args: - operands.extend(arg._flatten_or()) - return rt.LogicalOr.model_construct(operands=tuple(operands)) + result = result.logical_or(arg) + return result @staticmethod def literal(value: object) -> ExpressionProxy: @@ -379,6 +418,16 @@ def literal(value: object) -> ExpressionProxy: Expression proxy objects obtained from this factory can generally be compared directly to literals, so calling this method directly in user code should rarely be necessary. + + Parameters + ---------- + value : `object` + Value to include as a literal in an expression tree. + + Returns + ------- + expression : `ExpressionProxy` + Expression wrapper for this literal. """ expression = rt.make_column_literal(value) match expression.expression_type: diff --git a/python/lsst/daf/butler/queries/relation_tree/_base.py b/python/lsst/daf/butler/queries/relation_tree/_base.py index cdd8339a86..5a9f18748a 100644 --- a/python/lsst/daf/butler/queries/relation_tree/_base.py +++ b/python/lsst/daf/butler/queries/relation_tree/_base.py @@ -39,7 +39,7 @@ from abc import ABC, abstractmethod from types import EllipsisType -from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, cast +from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast import pydantic @@ -279,6 +279,11 @@ class PredicateBase(RelationTreeBase, ABC): """Base class for objects that represent boolean column expressions in a relation tree. + A `Predicate` tree is always in conjunctive normal form (ANDs of ORs of + NOTs). This is enforced by type annotations (and hence Pydantic + validation) and the `logical_and`, `logical_or`, and `logical_not` factory + methods. + This is a closed hierarchy whose concrete, `~typing.final` derived classes are members of the `Predicate` union. That union should generally be used in type annotations rather than the formally-open base class. @@ -308,14 +313,51 @@ def gather_required_columns(self) -> set[ColumnReference]: """ return set() - def _flatten_and(self) -> tuple[Predicate, ...]: - """Convert this expression to a sequence of predicates that should be - combined with logical AND. + # The 'other' arguments of the methods below are annotated as Any because + # MyPy doesn't correctly recognize subclass implementations that use + # @overload, and the signature of this base class doesn't really matter, + # since it's the union of all concrete implementations that's public; + # the base class exists largely as a place to hang docstrings. + + @abstractmethod + def logical_and(self, other: Any) -> Predicate: + """Return the logical AND of this predicate and another. + + Parameters + ---------- + other : `Predicate` + Other operand. + + Returns + ------- + result : `Predicate` + A predicate presenting the logical AND. """ - return (self,) # type: ignore[return-value] + raise NotImplementedError() - def _flatten_or(self) -> tuple[Predicate, ...]: - """Convert this expression to a sequence of predicates that should be - combined with logical OR. + @abstractmethod + def logical_or(self, other: Any) -> Predicate: + """Return the logical OR of this predicate and another. + + Parameters + ---------- + other : `Predicate` + Other operand. + + Returns + ------- + result : `Predicate` + A predicate presenting the logical OR. + """ + raise NotImplementedError() + + @abstractmethod + def logical_not(self) -> Predicate: + """Return the logical NOTof this predicate. + + Returns + ------- + result : `Predicate` + A predicate presenting the logical OR. """ - return (self,) # type: ignore[return-value] + raise NotImplementedError() diff --git a/python/lsst/daf/butler/queries/relation_tree/_predicate.py b/python/lsst/daf/butler/queries/relation_tree/_predicate.py index 4aa8559f18..1a592aa0c7 100644 --- a/python/lsst/daf/butler/queries/relation_tree/_predicate.py +++ b/python/lsst/daf/butler/queries/relation_tree/_predicate.py @@ -37,12 +37,12 @@ "InContainer", "InRange", "InRelation", - "StringPredicate", "DataCoordinateConstraint", "ComparisonOperator", ) -from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, Union, final +import itertools +from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, Union, final, overload import pydantic @@ -66,7 +66,7 @@ class LogicalAnd(PredicateBase): predicate_type: Literal["and"] = "and" - operands: tuple[Predicate, ...] = pydantic.Field(min_length=2) + operands: tuple[LogicalAndOperand, ...] = pydantic.Field(min_length=2) """Upstream boolean expressions to combine.""" def gather_required_columns(self) -> set[ColumnReference]: @@ -81,16 +81,42 @@ def precedence(self) -> int: # Docstring inherited. return 6 + def logical_and(self, other: Predicate) -> LogicalAnd: + # Docstring inherited. + match other: + case LogicalAnd(): + return LogicalAnd.model_construct(operands=self.operands + other.operands) + case _: + return LogicalAnd.model_construct(operands=self.operands + (other,)) + + def logical_or(self, other: Predicate) -> LogicalAnd: + # Docstring inherited. + match other: + case LogicalAnd(): + return LogicalAnd.model_construct( + operands=tuple( + [a.logical_or(b) for a, b in itertools.product(self.operands, other.operands)] + ), + ) + case _: + return LogicalAnd.model_construct( + operands=tuple([a.logical_or(other) for a in self.operands]), + ) + + def logical_not(self) -> Predicate: + # Docstring inherited. + first, *rest = self.operands + result: Predicate = first.logical_not() + for operand in rest: + result = result.logical_or(operand.logical_not()) + return result + def __str__(self) -> str: return " AND ".join( str(operand) if operand.precedence <= self.precedence else f"({operand})" for operand in self.operands ) - def _flatten_and(self) -> tuple[Predicate, ...]: - # Docstring inherited. - return self.operands - @final class LogicalOr(PredicateBase): @@ -100,7 +126,7 @@ class LogicalOr(PredicateBase): predicate_type: Literal["or"] = "or" - operands: tuple[Predicate, ...] = pydantic.Field(min_length=2) + operands: tuple[LogicalOrOperand, ...] = pydantic.Field(min_length=2) """Upstream boolean expressions to combine.""" def gather_required_columns(self) -> set[ColumnReference]: @@ -115,16 +141,39 @@ def precedence(self) -> int: # Docstring inherited. return 7 + def logical_and(self, other: Predicate) -> LogicalAnd: + return _base_logical_and(self, other) + + @overload + def logical_or(self, other: LogicalAnd) -> LogicalAnd: + ... + + @overload + def logical_or(self, other: LogicalAndOperand) -> LogicalOr: + ... + + def logical_or(self, other: Predicate) -> Predicate: + # Docstring inherited. + match other: + case LogicalAnd(): + return LogicalAnd.model_construct( + operands=tuple([self.logical_or(b) for b in other.operands]) + ) + case LogicalOr(): + return LogicalOr.model_construct(operands=self.operands + other.operands) + case _: + return LogicalOr.model_construct(operands=self.operands + (other,)) + + def logical_not(self) -> LogicalAnd: + # Docstring inherited. + return LogicalAnd.model_construct(operands=tuple([x.logical_not() for x in self.operands])) + def __str__(self) -> str: return " OR ".join( str(operand) if operand.precedence <= self.precedence else f"({operand})" for operand in self.operands ) - def _flatten_or(self) -> tuple[Predicate, ...]: - # Docstring inherited. - return self.operands - @final class LogicalNot(PredicateBase): @@ -132,7 +181,7 @@ class LogicalNot(PredicateBase): predicate_type: Literal["not"] = "not" - operand: Predicate + operand: LogicalNotOperand """Upstream boolean expression to invert.""" def gather_required_columns(self) -> set[ColumnReference]: @@ -144,6 +193,25 @@ def precedence(self) -> int: # Docstring inherited. return 4 + def logical_and(self, other: Predicate) -> LogicalAnd: + return _base_logical_and(self, other) + + @overload + def logical_or(self, other: LogicalAnd) -> LogicalAnd: + ... + + @overload + def logical_or(self, other: LogicalAndOperand) -> LogicalOr: + ... + + def logical_or(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_or(self, other) + + def logical_not(self) -> LogicalNotOperand: + # Docstring inherited. + return self.operand + def __str__(self) -> str: if self.operand.precedence <= self.precedence: return f"NOT {self.operand}" @@ -169,6 +237,26 @@ def precedence(self) -> int: # Docstring inherited. return 5 + def logical_and(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_and(self, other) + + @overload + def logical_or(self, other: LogicalAnd) -> LogicalAnd: + ... + + @overload + def logical_or(self, other: LogicalAndOperand) -> LogicalOr: + ... + + def logical_or(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_or(self, other) + + def logical_not(self) -> LogicalOrOperand: + # Docstring inherited. + return LogicalNot.model_construct(operand=self) + def __str__(self) -> str: if self.operand.precedence <= self.precedence: return f"{self.operand} IS NULL" @@ -204,6 +292,26 @@ def precedence(self) -> int: # Docstring inherited. return 5 + def logical_and(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_and(self, other) + + @overload + def logical_or(self, other: LogicalAnd) -> LogicalAnd: + ... + + @overload + def logical_or(self, other: LogicalAndOperand) -> LogicalOr: + ... + + def logical_or(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_or(self, other) + + def logical_not(self) -> LogicalOrOperand: + # Docstring inherited. + return LogicalNot.model_construct(operand=self) + def __str__(self) -> str: a = str(self.a) if self.a.precedence <= self.precedence else f"({self.a})" b = str(self.b) if self.b.precedence <= self.precedence else f"({self.b})" @@ -236,6 +344,26 @@ def precedence(self) -> int: # Docstring inherited. return 5 + def logical_and(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_and(self, other) + + @overload + def logical_or(self, other: LogicalAnd) -> LogicalAnd: + ... + + @overload + def logical_or(self, other: LogicalAndOperand) -> LogicalOr: + ... + + def logical_or(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_or(self, other) + + def logical_not(self) -> LogicalOrOperand: + # Docstring inherited. + return LogicalNot.model_construct(operand=self) + def __str__(self) -> str: m = str(self.member) if self.member.precedence <= self.precedence else f"({self.member})" return f"{m} IN [{', '.join(str(item) for item in self.container)}]" @@ -270,6 +398,26 @@ def precedence(self) -> int: # Docstring inherited. return 5 + def logical_and(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_and(self, other) + + @overload + def logical_or(self, other: LogicalAnd) -> LogicalAnd: + ... + + @overload + def logical_or(self, other: LogicalAndOperand) -> LogicalOr: + ... + + def logical_or(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_or(self, other) + + def logical_not(self) -> LogicalOrOperand: + # Docstring inherited. + return LogicalNot.model_construct(operand=self) + def __str__(self) -> str: s = f"{self.start if self.start else ''}..{self.stop if self.stop is not None else ''}" if self.step != 1: @@ -309,39 +457,30 @@ def precedence(self) -> int: # Docstring inherited. return 5 - def __str__(self) -> str: - m = str(self.member) if self.member.precedence <= self.precedence else f"({self.member})" - c = str(self.column) if self.column.precedence <= self.precedence else f"({self.column})" - return f"{m} IN [{{{self.relation}}}.{c}]" - - -@final -class StringPredicate(PredicateBase): - """A wrapper for boolean column expressions created by parsing a string - expression. - - Remembering the original string is useful for error reporting. - """ - - predicate_type: Literal["string_predicate"] = "string_predicate" + def logical_and(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_and(self, other) - where: str - """The string expression.""" + @overload + def logical_or(self, other: LogicalAnd) -> LogicalAnd: + ... - tree: Predicate - """Boolean expression tree created from the string expression.""" + @overload + def logical_or(self, other: LogicalAndOperand) -> LogicalOr: + ... - def gather_required_columns(self) -> set[ColumnReference]: + def logical_or(self, other: Predicate) -> Predicate: # Docstring inherited. - return self.tree.gather_required_columns() + return _base_logical_or(self, other) - @property - def precedence(self) -> int: + def logical_not(self) -> LogicalOrOperand: # Docstring inherited. - return 5 + return LogicalNot.model_construct(operand=self) def __str__(self) -> str: - return f'parsed("{self.where}")' + m = str(self.member) if self.member.precedence <= self.precedence else f"({self.member})" + c = str(self.column) if self.column.precedence <= self.precedence else f"({self.column})" + return f"{m} IN [{{{self.relation}}}.{c}]" @final @@ -363,22 +502,77 @@ def precedence(self) -> int: # Docstring inherited. return 5 + def logical_and(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_and(self, other) + + @overload + def logical_or(self, other: LogicalAnd) -> LogicalAnd: + ... + + @overload + def logical_or(self, other: LogicalAndOperand) -> LogicalOr: + ... + + def logical_or(self, other: Predicate) -> Predicate: + # Docstring inherited. + return _base_logical_or(self, other) + + def logical_not(self) -> LogicalOrOperand: + # Docstring inherited. + return LogicalNot.model_construct(operand=self) + def __str__(self) -> str: return str(DataCoordinate.from_required_values(self.dimensions, self.values)) +def _base_logical_and(a: LogicalAndOperand, b: Predicate) -> LogicalAnd: + match b: + case LogicalAnd(): + return LogicalAnd.model_construct(operands=(a,) + b.operands) + case _: + return LogicalAnd.model_construct(operands=(a, b)) + + +@overload +def _base_logical_or(a: LogicalOrOperand, b: LogicalAnd) -> LogicalAnd: + ... + + +@overload +def _base_logical_or(a: LogicalOrOperand, b: LogicalAndOperand) -> LogicalAndOperand: + ... + + +def _base_logical_or(a: LogicalOrOperand, b: Predicate) -> Predicate: + match b: + case LogicalAnd(): + return LogicalAnd.model_construct( + operands=tuple(_base_logical_or(a, b_operand) for b_operand in b.operands) + ) + case LogicalOr(): + return LogicalOr.model_construct(operands=(a,) + b.operands) + case _: + return LogicalOr.model_construct(operands=(a, b)) + + +_LogicalNotOperand = Union[ + IsNull, + Comparison, + InContainer, + InRange, + InRelation, + DataCoordinateConstraint, +] +_LogicalOrOperand = Union[_LogicalNotOperand, LogicalNot] +_LogicalAndOperand = Union[_LogicalOrOperand, LogicalOr] + + +LogicalNotOperand = Annotated[_LogicalNotOperand, pydantic.Field(discriminator="predicate_type")] +LogicalOrOperand = Annotated[_LogicalOrOperand, pydantic.Field(discriminator="predicate_type")] +LogicalAndOperand = Annotated[_LogicalAndOperand, pydantic.Field(discriminator="predicate_type")] + Predicate = Annotated[ - Union[ - LogicalAnd, - LogicalOr, - LogicalNot, - IsNull, - Comparison, - InContainer, - InRange, - InRelation, - StringPredicate, - DataCoordinateConstraint, - ], + Union[_LogicalAndOperand, LogicalAnd], pydantic.Field(discriminator="predicate_type"), ]