
Add factory for dynamic TransformOp #427

Merged
merged 20 commits into master from dynamic-transform-op on Jan 27, 2021

Conversation

fritzo commented Jan 17, 2021

Pair-coded with @eb8680, @fehiepsi, and @ordabayevy.

This introduces classes WrappedTransformOp and LogAbsDetJacobianOp for ops created dynamically from backend Transform objects and TransformedDistributions. Unlike statically created ops, instances of these ops will never have custom rules, so we do not create a unique subclass for each new op.
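As a rough illustration of the idea (in plain torch, not the actual op API this PR adds): a wrapped op simply pairs a backend Transform's forward call with its log-abs-det-Jacobian, so both become ordinary callables.

import torch
from torch.distributions.transforms import ExpTransform

transform = ExpTransform()

def wrapped_transform(x):
    # Forward pass of the backend Transform.
    return transform(x)

def log_abs_det_jacobian(x, y):
    # Companion op computing log|dy/dx| for the same Transform.
    return transform.log_abs_det_jacobian(x, y)

x = torch.randn(3)
y = wrapped_transform(x)
print(log_abs_det_jacobian(x, y))  # for ExpTransform this equals x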

I have also refactored Op, CachedOpMeta, and WrappedOpMeta to more finely control whether and how op instance creation is cached.
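For context, here is a minimal sketch of the caching-metaclass pattern (a simplification assuming hashable positional args only, not funsor's actual implementation): constructing a class twice with the same arguments returns the same instance.

class CachedMeta(type):
    # Metaclass that memoizes instance creation by constructor arguments.
    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        cls._cache = {}

    def __call__(cls, *args):
        try:
            return cls._cache[args]  # reuse a cached instance
        except KeyError:
            instance = super().__call__(*args)
            cls._cache[args] = instance
            return instance

class MyOp(metaclass=CachedMeta):
    def __init__(self, name):
        self.name = name

assert MyOp("exp") is MyOp("exp")  # same cached instance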

Forwards compatibility

This PR xfails the affected tests until the following PRs are implemented upstream. I have tested locally with these changes. It is safe to merge this PR before the upstream changes have merged, because the upstream interfaces are fairly stable:

Tested

  • cache test for WrappedTransformOp and LogAbsDetJacobianOp
  • gc test for WrappedTransformOp
  • distribution test with PowerTransform (passes locally)
  • distribution test with HaarTransform (passes locally)

@fritzo fritzo added Blocked Blocked by other issues WIP labels Jan 17, 2021
@fritzo fritzo removed the WIP label Jan 17, 2021
fritzo commented Jan 17, 2021

Recording discussion from Slack:

Du Phan: Do you expect that we will change anything in those to_funsor(ExpTransform) implementations (https://github.com/pyro-ppl/funsor/pull/427/files#diff-2aa6bbdd9f351b5925977d25eacad85b7fb658e38f07a4e1b33efa76d47d4b50R277)? It seems to me that the default transform_to_funsor should take care of the job.

Fritz Obermeyer: If we wanted to replace the existing ExpOp with the new transform_to_funsor() logic, we would need to address caching and .__hash__(). That is a complex issue, because PyTorch already uses the .__eq__() method for expensive tensor comparisons, and that usage is incompatible with lightweight caching. However, I see no need to address this issue right away. @fehiepsi, what is your motivation?

Du Phan: Ah, I just thought that if to_funsor(ExpTransform) is not needed, then implementing #365 for the jax backend is not needed.

Fritz Obermeyer: Yeah, let's just think about this more. A naive solution might simply define a custom hash method like

def hash_op(fn):
    if isinstance(fn, backend_dist.transforms.Transform):
        return id(fn)  # Hash Transform instances by id.
    return hash(fn)

that would at least allow us to delete the existing custom to_funsor() implementations for singleton Transforms. But that would still leak memory once we start using flows.
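One way to avoid that leak (a sketch of the general technique, not necessarily what this PR implements) is to key the op cache on the transform's id() while registering a weakref finalizer, so cache entries die with their transforms:

import weakref

_op_cache = {}

def get_wrapped_op(transform, make_op):
    # Key by id() to sidestep Transform.__eq__/__hash__ entirely.
    key = id(transform)
    if key not in _op_cache:
        _op_cache[key] = make_op(transform)
        # Evict the entry once the transform is garbage collected.
        weakref.finalize(transform, _op_cache.pop, key, None)
    return _op_cache[key]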

@fritzo fritzo added the WIP label Jan 17, 2021
@eb8680 eb8680 mentioned this pull request Jan 19, 2021
@fritzo fritzo added awaiting review and removed Blocked Blocked by other issues WIP labels Jan 21, 2021
@fritzo fritzo requested review from eb8680 and fehiepsi January 21, 2021 15:10
@fritzo fritzo mentioned this pull request Jan 21, 2021
@fritzo fritzo added the Blocked Blocked by other issues label Jan 21, 2021
fritzo commented Jan 21, 2021

@fehiepsi

It seems to me that the default transform_to_funsor should take care of the job.

Yes, I think I better understand your point now. This PR now adds a default to_funsor(numpyro.distributions.Transform) that should work on all transforms that are not batched. We would need to register additional patterns for batching to work. That should be easy in a subsequent PR: we can move the patterns from funsor/torch/distributions.py up to funsor/distribution.py and then register them in funsor/jax/distributions.py.

EDIT: to get this free behavior, you will need to add a Transform.forward_shape() method in NumPyro, following pytorch/pytorch#50581.
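For reference, forward_shape maps the shape of an input to __call__ to the shape of the output. A toy sketch of the contract (illustration only, not PyTorch's or NumPyro's code):

class ReshapeTransform:
    # Toy transform mapping event shape in_shape to out_shape.
    def __init__(self, in_shape, out_shape):
        self.in_shape = tuple(in_shape)
        self.out_shape = tuple(out_shape)

    def forward_shape(self, shape):
        # Split off batch dims, then replace the event shape.
        batch_shape = shape[: len(shape) - len(self.in_shape)]
        return batch_shape + self.out_shape

t = ReshapeTransform((2, 3), (6,))
assert t.forward_shape((10, 2, 3)) == (10, 6)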

fehiepsi commented Jan 21, 2021

will need to add a Transform.forward_shape() method in NumPyro

@fritzo This will resolve current issues with TransformedDistribution and init_strategy, so I want to make a PR for it. But I am a bit worried whether it will be enough for funsor in complicated situations like the normalizing flows that Eli pointed out in our last discussions. Do we need batch shape or something else?

Btw, can you change the Travis settings to let me rerun jobs? Currently I can rerun jobs in pyro and numpyro but not in funsor.

fritzo commented Jan 21, 2021

Do we need batch shape or something else?

I believe .batch_shape is unnecessary for TransformedDistribution; we need only a small interface to fully support normalizing flows and batched transforms inside TransformedDistribution:

That would be sufficient metadata to implement to_funsor(TransformedDistribution), but we still haven't implemented the generic conversion rule in Funsor. If we additionally want to construct flows with lazy inputs, we would need at least:

@fehiepsi fehiepsi removed the Blocked Blocked by other issues label Jan 21, 2021
@fehiepsi fehiepsi left a comment
LGTM; the implementation fits my intuition (though for now I couldn't evaluate the caching and weakref details).

(resolved review thread on funsor/ops/op.py)
@eb8680 eb8680 left a comment
The code all looks good to me, but it would be good to add new TransformedDistribution test cases to our generic distribution tests, to make sure sampling and other methods also behave correctly. If getting those working turns out to be a hassle, we can defer to a follow-up PR.



@pytest.fixture
def dist():
This seems like a useful pattern; we may want to copy it in the distribution tests in a later PR.
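For readers unfamiliar with the pattern being praised: a pytest fixture hands each test a freshly built distribution, keeping the cases declarative. A minimal sketch with a hypothetical test case (not the actual suite):

import pytest
import torch
import torch.distributions as td

@pytest.fixture
def dist():
    # Hypothetical example; the real suite parametrizes over many cases.
    base = td.Independent(td.Normal(torch.zeros(3), torch.ones(3)), 1)
    return td.TransformedDistribution(base, [td.transforms.ExpTransform()])

def test_sample_shape(dist):
    assert dist.sample().shape == (3,)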

(resolved review thread on test/test_distribution.py)
True,
xfail_param(False, reason="bug in to_funsor(TransformedDistribution)"),
])
def test_haar_transform(shape, to_event):
Ditto: try adding a HaarTransform test case to test/test_distribution_generic.py to test sampling, to_funsor/to_data conversion and other distribution methods?

fritzo commented Jan 24, 2021
Can you suggest a fix for this broken test_generic_distribution_to_funsor()?
(note: to run this you'll need the infer-shapes branch of Pyro)

$ FUNSOR_BACKEND=torch pytest -vx test/test_distribution_generic.py -k Haar --pdb
===================================================== test session starts ======================================================
platform darwin -- Python 3.7.0, pytest-6.1.2, py-1.9.0, pluggy-0.13.1 -- /Users/fobermey/opt/miniconda3/envs/pyro/bin/python
cachedir: .pytest_cache
benchmark: 3.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /Users/fobermey/github/pyro-ppl/funsor, configfile: setup.cfg
plugins: forked-1.2.0, nbval-0.9.6, xdist-2.1.0, benchmark-3.2.3
collected 2244 items / 2211 deselected / 33 selected

test/test_distribution_generic.py::test_generic_distribution_to_funsor[dist.TransformedDistribution( dist.Normal(loc=case.loc, scale=1.).to_event(1), dist.transforms.HaarTransform(dim=-1)) (('loc', 'rand(() + (3,))'),)] FAILED [  3%]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> traceback >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

case = <test.test_distribution_generic.DistTestCase object at 0x7fb9b08c8cc0>

    @pytest.mark.parametrize("case", TEST_CASES, ids=str)
    def test_generic_distribution_to_funsor(case):

        HIGHER_ORDER_DISTS = [
            backend_dist.Independent,
            backend_dist.TransformedDistribution,
        ] + ([backend_dist.torch_distribution.ExpandedDistribution] if get_backend() == "torch"
             else [backend_dist.ExpandedDistribution])

        raw_dist = case.get_dist()
        expected_value_domain = case.expected_value_domain

        dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
        with interpretation(normalize_with_subs):
            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
        assert funsor_dist.inputs["value"] == expected_value_domain

        while isinstance(funsor_dist, funsor.cnf.Contraction):
            funsor_dist = [term for term in funsor_dist.terms
>                          if isinstance(term, (funsor.distribution.Distribution, funsor.terms.Independent))][0]
E           IndexError: list index out of range

test/test_distribution_generic.py:590: IndexError
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> entering PDB >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> PDB post_mortem (IO-capturing turned off) >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
> /Users/fobermey/github/pyro-ppl/funsor/test/test_distribution_generic.py(590)test_generic_distribution_to_funsor()
-> if isinstance(term, (funsor.distribution.Distribution, funsor.terms.Independent))][0]
(Pdb) print(funsor_dist.pretty())
Contraction(ops.nullop, ops.add,
 frozenset(),
 (Unary(ops.neg,
   Binary(ops.log_abs_det_jacobian,
   │Unary(ops._InverseTransform,
   │ Variable('value', Reals[3])),
   │Variable('value', Reals[3]))),
  Contraction(ops.add, ops.nullop,
   frozenset({Variable('_pyro_event_dim_-1_...
   (Normal(
   │ Tensor(
   │  torch.tensor([0.4366315007209778, 0.8093...
   │  (('_pyro_event_dim_-1__BOUND_1',
   │   │Bint[3, ],),),
   │  'real'),
   │ Tensor(
   │  torch.tensor([1.0, 1.0, 1.0], dtype=torc...
   │  (('_pyro_event_dim_-1__BOUND_1',
   │   │Bint[3, ],),),
   │  'real'),
   │ Binary(ops.GetitemOp(0),
   │  Unary(ops._InverseTransform,
   │   Variable('value', Reals[3])),
   │  Variable('_pyro_event_dim_-1__BOUND_1', ...

eb8680 replied:
I think the underlying reason is that the way we use a lazy funsor.terms.Independent to convert Independent distributions to funsors (via the pattern funsor.distribution.indepdist_to_funsor) is sort of fiddly and affects pattern matching, especially when combined with transforms. Now that #402 is done, we should try to represent them directly by ops.expand-ing their parameters to the appropriate event_shape. I'll see if I can get this working in a separate PR.
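To illustrate the intended equivalence in plain torch.distributions (a sketch of the idea, not the funsor pattern itself): expanding parameters up front matches broadcasting a scalar distribution lazily.

import torch
import torch.distributions as td

loc, scale = torch.tensor(0.0), torch.tensor(1.0)

# Lazy style: broadcast a scalar Normal, then reinterpret the batch dim.
lazy = td.Independent(td.Normal(loc, scale).expand((3,)), 1)

# Eager style: expand the parameters to the event shape up front.
eager = td.Independent(td.Normal(loc.expand(3), scale.expand(3)), 1)

x = torch.randn(3)
assert torch.allclose(lazy.log_prob(x), eager.log_prob(x))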

(resolved review thread on funsor/ops/op.py)
fritzo commented Jan 27, 2021

@eb8680 thanks for #443! This PR now passes locally and should be safe to merge.

@eb8680 eb8680 left a comment
LGTM

@eb8680 eb8680 merged commit 90a488c into master Jan 27, 2021
@eb8680 eb8680 deleted the dynamic-transform-op branch January 27, 2021 15:20