Skip to content

Commit

Permalink
Update flexi_tree_map doc
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDuffield authored Jan 31, 2024
1 parent c042518 commit 1141441
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions uqlib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,20 @@ def flexi_tree_map(
namespace: str = "",
) -> TensorTree:
"""Applies a pure function to each tensor in a PyTree, with inplace argument.
```
out_tensor = func(tensor, *rest_tensors)
```
If inplace = True, uses uqlib.tree_map_inplacify_ to modify the tree in-place
(and return modified tree).
If inplace = False, uses optree.tree_map to return a new tree.
where `out_tensor` is of the same shape as `tensor`.
Therefore
```
out_tree = func(tree, *rests, inplace=True)
```
will return `out_tree` a pointer to the original `tree` with leaves (tensors) modified in place.
If `inplace=False`, `flexi_tree_map` is equivalent to `optree.tree_map` and returns a new tree.
Args:
func: A pure function that takes a tensor as its first argument and a returns
Expand All @@ -340,9 +350,8 @@ def flexi_tree_map(
(default: :const:`''`, i.e., the global namespace)
Returns:
The original ``tree`` with the value at each leaf is given by the side-effect of function
``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
Either the original tree modified in-place or a new tree depending on the `inplace`
argument.
"""
tm = tree_map_inplacify_ if inplace else tree_map
return tm(
Expand Down

0 comments on commit 1141441

Please sign in to comment.