Skip to content

Commit

Permalink
ENH: improve how proposal kwargs are determined
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will committed Aug 22, 2024
1 parent 20f6cb0 commit c2df454
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 50 deletions.
68 changes: 39 additions & 29 deletions nessai/proposal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import Callable, Union
from warnings import warn

from ..utils.settings import _get_all_kwargs


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -33,15 +35,9 @@ def check_proposal_kwargs(ProposalClass, kwargs, strict=False):
dict
Dictionary of updated kwargs.
"""
from ..proposal import AugmentedFlowProposal, FlowProposal
from ..gw.proposal import AugmentedGWFlowProposal, GWFlowProposal

proposals = {
AugmentedFlowProposal,
AugmentedGWFlowProposal,
FlowProposal,
GWFlowProposal,
}
proposals = list(available_base_flow_proposal_classes().values()) + list(
available_external_flow_proposal_classes(load=True).values()
)

class_keys = set()
for cls in getmro(ProposalClass):
Expand All @@ -65,7 +61,7 @@ def check_proposal_kwargs(ProposalClass, kwargs, strict=False):
allowed_extra_keys = set()

for proposal in proposals:
allowed_extra_keys.update(set(signature(proposal).parameters.keys()))
allowed_extra_keys.update(set(_get_all_kwargs(proposal)))

invalid_keys = extra_keys - allowed_extra_keys

Expand All @@ -84,6 +80,37 @@ def check_proposal_kwargs(ProposalClass, kwargs, strict=False):
return kwargs_out


def available_base_flow_proposal_classes():
from .flowproposal import FlowProposal
from .augmented import AugmentedFlowProposal
from ..gw.proposal import GWFlowProposal, AugmentedGWFlowProposal
from ..experimental.proposal.clustering import ClusteringFlowProposal
from ..experimental.gw.proposal import ClusteringGWFlowProposal

base_proposals = {
"clusteringgwflowproposal": ClusteringGWFlowProposal,
"augmentedgwflowproposal": AugmentedGWFlowProposal,
"gwflowproposal": GWFlowProposal,
"clusteringflowproposal": ClusteringFlowProposal,
"augmentedflowproposal": AugmentedFlowProposal,
"flowproposal": FlowProposal,
}
return base_proposals


def available_external_flow_proposal_classes(load: bool = False):
from ..utils.entry_points import get_entry_points

external_proposals = get_entry_points("nessai.proposals")
logger.debug(
f"Found the following external proposals: {external_proposals.keys()}"
)
if load:
for key in external_proposals:
external_proposals[key] = external_proposals[key].load()
return external_proposals


def get_flow_proposal_class(
proposal_class: Union[str, None, Callable],
) -> Callable:
Expand All @@ -103,26 +130,9 @@ def get_flow_proposal_class(
Proposal class
"""
from .flowproposal import FlowProposal
from .augmented import AugmentedFlowProposal
from ..gw.proposal import GWFlowProposal, AugmentedGWFlowProposal
from ..experimental.proposal.clustering import ClusteringFlowProposal
from ..experimental.gw.proposal import ClusteringGWFlowProposal
from ..utils.entry_points import get_entry_points

base_proposals = {
"augmentedflowproposal": AugmentedFlowProposal,
"flowproposal": FlowProposal,
"gwflowproposal": GWFlowProposal,
"augmentedgwflowproposal": AugmentedGWFlowProposal,
"clusteringflowproposal": ClusteringFlowProposal,
"clusteringgwflowproposal": ClusteringGWFlowProposal,
}

external_proposals = get_entry_points("nessai.proposals")

logger.debug(
f"Found the following external proposals: {external_proposals.keys()}"
)
base_proposals = available_base_flow_proposal_classes()
external_proposals = available_external_flow_proposal_classes(load=False)

if proposal_class is None:
return FlowProposal
Expand Down
55 changes: 34 additions & 21 deletions nessai/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,47 @@
Used for bilby and pycbc-inference.
"""

from inspect import signature
from typing import List, Callable, Tuple
from inspect import getmro, signature
from typing import Any, List, Callable, Tuple


def _get_kwargs(func: Callable) -> dict[str, Any]:
return {
k: v.default
for k, v in signature(func).parameters.items()
if v.default is not v.empty
}


def _get_all_kwargs(callable: Callable) -> dict[str, Any]:
try:
parameters = {}
for cls in getmro(callable):
parameters.update(_get_kwargs(cls))
except AttributeError:
parameters = _get_kwargs(callable)
return parameters


def _get_standard_methods() -> Tuple[List[Callable], List[Callable]]:
"""Get a list of the methods used by the standard sampler and the run
method.
"""
from ..flowsampler import FlowSampler
from ..proposal import AugmentedFlowProposal, FlowProposal
from ..proposal.utils import (
available_base_flow_proposal_classes,
available_external_flow_proposal_classes,
)
from ..samplers import NestedSampler

methods = [
AugmentedFlowProposal,
FlowProposal,
NestedSampler,
FlowSampler,
]
methods = (
list(available_external_flow_proposal_classes(load=True).values())
+ list(available_base_flow_proposal_classes().values())
+ [
NestedSampler,
FlowSampler,
]
)
run_methods = [
FlowSampler.run_standard_sampler,
]
Expand Down Expand Up @@ -75,13 +98,7 @@ def get_all_kwargs(
run_kwargs = {}
for kwds, methods in zip([kwargs, run_kwargs], [methods, run_methods]):
for m in methods:
kwds.update(
{
k: v.default
for k, v in signature(m).parameters.items()
if v.default is not v.empty
}
)
kwds.update(_get_all_kwargs(m))

if split_kwargs:
return kwargs, run_kwargs
Expand Down Expand Up @@ -109,9 +126,5 @@ def get_run_kwargs_list(importance_nested_sampler: bool = False) -> List[str]:
method = FlowSampler.run_importance_nested_sampler
else:
method = FlowSampler.run_standard_sampler
run_kwargs_list = [
k
for k, v in signature(method).parameters.items()
if v.default is not v.empty
]
run_kwargs_list = list(_get_all_kwargs(method).keys())
return run_kwargs_list

0 comments on commit c2df454

Please sign in to comment.