From af04458ff9ac53d6a6dedab95b3d08528719b6e8 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Mon, 11 Dec 2023 07:03:30 -0500 Subject: [PATCH] Propapagate type aliases through sphinx (#453) * type annotations fix in dynamical systems * add type aliases for interventional * add type aliases for observational * add counterfactual doc type aliases * lint --- chirho/counterfactual/handlers/counterfactual.py | 2 ++ chirho/counterfactual/handlers/explanation.py | 2 ++ chirho/counterfactual/ops.py | 2 ++ chirho/dynamical/handlers/interruption.py | 2 ++ chirho/dynamical/ops.py | 2 ++ chirho/interventional/handlers.py | 6 ++++-- chirho/interventional/ops.py | 2 ++ chirho/observational/handlers/condition.py | 2 ++ .../observational/handlers/soft_conditioning.py | 2 ++ chirho/observational/internals.py | 2 ++ chirho/observational/ops.py | 2 ++ docs/source/conf.py | 16 ++++++++++++++++ 12 files changed, 40 insertions(+), 2 deletions(-) diff --git a/chirho/counterfactual/handlers/counterfactual.py b/chirho/counterfactual/handlers/counterfactual.py index 0fbec2c25..cbed31726 100644 --- a/chirho/counterfactual/handlers/counterfactual.py +++ b/chirho/counterfactual/handlers/counterfactual.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Dict, Generic, Mapping, TypeVar import pyro diff --git a/chirho/counterfactual/handlers/explanation.py b/chirho/counterfactual/handlers/explanation.py index 693c17a49..47afb350d 100644 --- a/chirho/counterfactual/handlers/explanation.py +++ b/chirho/counterfactual/handlers/explanation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections.abc import contextlib import functools diff --git a/chirho/counterfactual/ops.py b/chirho/counterfactual/ops.py index 151a33502..8817d6042 100644 --- a/chirho/counterfactual/ops.py +++ b/chirho/counterfactual/ops.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools from typing import Optional, Tuple, TypeVar diff --git a/chirho/dynamical/handlers/interruption.py b/chirho/dynamical/handlers/interruption.py index b9987ea30..2b74dcf76 100644 --- a/chirho/dynamical/handlers/interruption.py +++ b/chirho/dynamical/handlers/interruption.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numbers import typing from typing import Callable, Generic, Tuple, TypeVar, Union diff --git a/chirho/dynamical/ops.py b/chirho/dynamical/ops.py index 26c0af016..09dd0c36c 100644 --- a/chirho/dynamical/ops.py +++ b/chirho/dynamical/ops.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numbers from typing import Callable, Mapping, Optional, Tuple, TypeVar, Union diff --git a/chirho/interventional/handlers.py b/chirho/interventional/handlers.py index c55ad0d8e..3bfbd694c 100644 --- a/chirho/interventional/handlers.py +++ b/chirho/interventional/handlers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections import functools from typing import Callable, Dict, Generic, Hashable, Mapping, Optional, TypeVar @@ -41,7 +43,7 @@ def _intervene_atom( def _intervene_atom_distribution( obs: pyro.distributions.Distribution, act: Optional[AtomicIntervention[pyro.distributions.Distribution]] = None, - **kwargs + **kwargs, ) -> pyro.distributions.Distribution: """ Intervene on a distribution in a probabilistic program. @@ -70,7 +72,7 @@ def _dict_intervene( def _intervene_callable( obs: collections.abc.Callable, act: Optional[CompoundIntervention[T]] = None, - **call_kwargs + **call_kwargs, ) -> Callable[..., T]: if act is None: return obs diff --git a/chirho/interventional/ops.py b/chirho/interventional/ops.py index cea13923a..27e761afa 100644 --- a/chirho/interventional/ops.py +++ b/chirho/interventional/ops.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools from typing import Callable, Hashable, Mapping, Optional, Tuple, TypeVar, Union diff --git a/chirho/observational/handlers/condition.py b/chirho/observational/handlers/condition.py index a097743c6..01a52fa3b 100644 --- a/chirho/observational/handlers/condition.py +++ b/chirho/observational/handlers/condition.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Callable, Generic, Hashable, Mapping, TypeVar, Union import pyro diff --git a/chirho/observational/handlers/soft_conditioning.py b/chirho/observational/handlers/soft_conditioning.py index b79f3978b..f048e4dd3 100644 --- a/chirho/observational/handlers/soft_conditioning.py +++ b/chirho/observational/handlers/soft_conditioning.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import operator from typing import Callable, Literal, Optional, Protocol, TypedDict, TypeVar, Union diff --git a/chirho/observational/internals.py b/chirho/observational/internals.py index 14b5a3698..d61a13b71 100644 --- a/chirho/observational/internals.py +++ b/chirho/observational/internals.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Mapping, Optional, TypeVar import pyro diff --git a/chirho/observational/ops.py b/chirho/observational/ops.py index a32b357fe..021bc8c0d 100644 --- a/chirho/observational/ops.py +++ b/chirho/observational/ops.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools from typing import Callable, Hashable, Mapping, Optional, TypeVar, Union diff --git a/docs/source/conf.py b/docs/source/conf.py index 4308258b0..9f8a92ec3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,6 +23,22 @@ author = 'Basis' +# -- Type hints configuration ------------------------------------------------ + +autodoc_type_aliases = { + 'R': 'R', + 'State': 'State', + 'Dynamics': 'Dynamics', + 'AtomicIntervention': 'AtomicIntervention', + 'CompoundIntervention': 'CompoundIntervention', + 'Intervention': 'Intervention', + 'AtomicObservation': 'AtomicObservation', + 'CompoundObservation': 'CompoundObservation', + 'Observation': 'Observation', + 'Kernel': 'Kernel', + +} + # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be