
[feature request] support multiple simultaneous right-hand-sides #3

Open
teichert opened this issue May 27, 2021 · 10 comments
Labels: enhancement (New feature or request)

Comments

@teichert

I have a situation where many of my formulas come in sets like the following:

'bijklmn,bij,bkl,bm,bn->'
'bijklmn,bij,bkl,bm,bn->bij'
'bijklmn,bij,bkl,bm,bn->bkl'
'bijklmn,bij,bkl,bm,bn->bm'
'bijklmn,bij,bkl,bm,bn->bn'

It seems like a substantial amount of work could be shared between these, but I might be misunderstanding the way you are doing the reduction.

If there is a possibility to support more efficient bulk work (or maybe even if there isn't), I'd suggest allowing the above set of formulas to be replaced with the following (which would indicate that einsum should return a tuple of tensors with the respective results):

'bijklmn,bij,bkl,bm,bn->,bij,bkl,bm,bn'
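For illustration, a hypothetical call using the proposed notation might look like the following; compile_equation and einsum are the library's existing entry points (assumed imported as tse), while the multi-output equation string and the tuple return are exactly what is being requested here, not something that exists today. X and F1..F4 stand for the five input tensors.

# hypothetical usage of the proposed multi-output notation (not currently supported)
eq = tse.compile_equation('bijklmn,bij,bkl,bm,bn->,bij,bkl,bm,bn')
# would return one tensor per comma-separated output
z, m_bij, m_bkl, m_bm, m_bn = tse.einsum(eq, X, F1, F2, F3, F4, block_size=4)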

@bdusell would you please comment on the following when you get a chance:

  1. Are you open to the proposed notation?
  2. Are you open to pull requests?
  3. Do you see an opportunity for reused computation in the above scenario? (after a quick look, it doesn't look as easy as I had initially thought)

Thanks!

@bdusell
Owner

bdusell commented May 27, 2021

This is a very interesting idea -- I'm afraid I won't have time to address it for a while (NeurIPS deadline!), but I wanted to leave some comments now.

  1. Yes, this looks like a reasonable, backwards-compatible extension of einsum syntax.
  2. Absolutely yes.
  3. At the very least, it's clear that you can use 'bijklmn,bij,bkl,bm,bn->bij' to compute 'bijklmn,bij,bkl,bm,bn->' (see the sketch below). So in the general case I think you could build a tree of intermediate results with subsets of variables summed out. However, this might negate a lot of the memory savings. I bet we could do something fancy like yield results one at a time so some of the intermediate tensors can be freed (i.e., a DFS traversal of the tree). The derivative could be tricky because there are multiple outputs.
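For concreteness, a minimal sketch of the re-use in point 3, using the library's existing single-output interface; the shapes are made up and torch_semiring_einsum is assumed imported as tse.

import torch
import torch_semiring_einsum as tse

# made-up sizes: b=2, i=j=3, k=l=4, m=n=5
X  = torch.rand(2, 3, 3, 4, 4, 5, 5)
F1 = torch.rand(2, 3, 3)
F2 = torch.rand(2, 4, 4)
F3 = torch.rand(2, 5)
F4 = torch.rand(2, 5)

eq_bij = tse.compile_equation('bijklmn,bij,bkl,bm,bn->bij')
m_bij = tse.einsum(eq_bij, X, F1, F2, F3, F4, block_size=4)
# the scalar query 'bijklmn,bij,bkl,bm,bn->' is then just the sum of that marginal
total = m_bij.sum()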

@teichert
Author

Excellent; thanks, and best wishes on your NeurIPS push!

@bdusell
Owner

bdusell commented May 29, 2021

Been thinking about this some more... Every operation in that tree I mentioned is another einsum operation, so I think all we need to do is define einsum to dispatch to e.g. einsum_single and einsum_multi depending on which syntax is used, where einsum_multi executes said tree using multiple calls to einsum_single. That's great because it involves no custom backward pass. I don't think modifying einsum directly instead would be more efficient, but I would need to think about it more.
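A rough sketch of that dispatch, just to make the idea concrete; einsum_single and einsum_multi are the hypothetical names from the paragraph above, and the attributes on the compiled equation are made up for illustration.

def einsum(equation, *args, block_size):
    # dispatch on whether the compiled equation has one output or several
    if len(equation.outputs) == 1:  # hypothetical attribute
        return einsum_single(equation, *args, block_size=block_size)
    return einsum_multi(equation, *args, block_size=block_size)

def einsum_multi(equation, *args, block_size):
    # naive version: one einsum_single call per output; a smarter version
    # would share intermediate results via the tree of partial reductions
    return tuple(
        einsum_single(sub, *args, block_size=block_size)
        for sub in equation.single_output_equations  # hypothetical attribute
    )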

Figuring out the optimal tree structure, on the other hand, definitely seems like a difficult problem. Numbering your equations:

No.  Equation                      Summed Vars
1    bijklmn,bij,bkl,bm,bn->       bijklmn
2    bijklmn,bij,bkl,bm,bn->bij    klmn
3    bijklmn,bij,bkl,bm,bn->bkl    ijmn
4    bijklmn,bij,bkl,bm,bn->bm     ijkln
5    bijklmn,bij,bkl,bm,bn->bn     ijklm

We're looking for the largest subset of summed variables shared by the largest number of outputs. I don't know yet if there's an algorithm that does this in better than exponential time. Do we sum out klm first and use it for 1, 2, and 5, or do we sum out klmn first and use it for 1 and 2? Are we going for best space efficiency or best time efficiency? And the result depends on the actual dimensions of the inputs, which are not known by the equation at compile time.

@teichert
Author

teichert commented May 29, 2021

Excellent; thanks for the reply! Also, sorry for the distraction from NeurIPS, and sorry for the long brain-dump here; I might not be thinking about this for a while, so I want to make sure I'll be able to jog my memory when needed. I don't expect you to even read this until after your paper is submitted. :)

My original idea

Originally I actually wasn't even thinking of being clever about how to reuse the summing work (even though that does seem like a good idea, since multiple outputs will typically need to sum out shared sets of variables---that sounds like an extra bonus!). I was actually just expecting/hoping that at least the product computation could be reused (but in the code I don't actually see a full product followed by summing out---I expect this to be because of some combination of: (1) tricks to save memory, (2) tricks to avoid redundant multiplication, (3) tricks to avoid redundant summing, and (4) me not looking closely enough :). Is it true that we can typically reuse the multiplication work between outputs?

Clever summation reuse

However, considering the optimal summing problem for a moment and assuming that there is a dense, full product tensor sitting around somewhere before the summing happens, we can imagine a graph where nodes are variable subsets and there is an edge to subset A from subset B whenever A is a proper subset of B---the weight on the edge being the amount of work needed to do the summing (ignore the fact that this depends on actual dimension sizes, which aren't known at compile time; maybe the right answer would be to recompute the optimal tree before each einsum or something; just ignore it for now). The goal, then, is to find a min-cost tree that touches the variable subsets included in our output factors (and the root node that has all variables). This sounds like exactly the Steiner Tree Problem (on graphs), which is NP-Hard (even if our graph were polynomially sized!). Stack Overflow suggests some approximations---including, ironically, a modified belief propagation approach (I say ironically because my interest in einsum is that I'm using a special case of it as a subroutine in my own differentiable loopy belief propagation implementation [which may possibly be subsumed by einsum :)]; for more, see below on factor graphs). The thing that differentiates Steiner from a regular minimum spanning tree (MST) is the freedom to choose an auxiliary subset of nodes to connect. One heuristic approach that comes to mind is to instead try running MST on a few carefully chosen node subsets and then picking whichever is cheapest. Subsets could include, e.g., (A) the root + the set of output subsets; (B) A + all pairwise unions. Or maybe the junction tree trick mentioned below might give some ideas(?).

Factor Graphs

However, your discussion of variable elimination order encourages me and highlights that Einstein notation is essentially a concise way to specify a factor graph along with a query for marginal inference. (Which also means that einsum is NP-Hard in the general case.) Allowing multiple right-hand sides lets us specify more marginal inference queries at once and also allows us to represent "batched" (or "conditional") factor graphs (the dimensions shared by all queries do not contribute to the treewidth of the factor graph). [Einsum with multiple right-hand sides technically describes slightly more: it allows some factors to come "uncollapsed" (multiple dimensions for the same variable) and it allows describing the order of the output dimensions, but those both just amount to a small amount of pre- and post-processing; I'll ignore that below.]

In the example above, we were "lucky" to have a tree, but that is not the case in general. For the general factor graph, marginal inference is NP-Hard (see citations in this UAI 2008 paper), but its cost is exponential only in the treewidth of the factor graph (so polynomial when the treewidth is bounded). Again, the question boils down to which products to store all at once and in what order to sum out the variables. Let's consider the following cases (things might get more complicated if there is special support for sparse tensors, but here I'm assuming that all input tensors are dense; also, let's only consider output factors that are a subset of some input factor---this is without loss of generality as long as it is possible to always include a no-op factor touching the variables of some desired output factor; however, such a factor would necessarily depend on the semiring and might impact computational complexity, so let's not consider it for now):

  • Case A) If one of the factors includes all of the variables: In this case, we are within 2OPT in terms of memory, so we might as well just go ahead and compute the full product tensor, and then separately sum out [or do something fancy as you suggested]---anyway, at least we didn't need to recompute the product multiple times.
  • Case B) If the factor graph is a tree: (Note: for the effective structure of the factor graph, there is a node for each variable that does not appear in at least one output factor and a node for each factor in which any of those variables appear; there is an edge between a factor node and a variable node if the variable appears in the factor; a connected factor graph with n nodes is a tree if it has n-1 edges.) We can reduce this type (B) einsum to a linear number of type (A) einsums [basically running the sum-product algorithm (BP) from leaves to root and back to the leaves]. Pick any node to be the root; depth-first search will give the unique directed tree from that root; repeatedly use einsum to propagate "forward" marginals to the root and then to propagate backward marginals back to the leaves. Now marginals are known for all input factors; if any output factor is only a strict subset of the closest input factor (the one that would require the least work to sum out the remaining variables), do that additional marginalization as well. (See the sketch after this list for a concrete chain example.)
  • Case C) Otherwise (i.e., if the factor graph has treewidth > 1): In this case, find a junction tree (triangulate, find maximal cliques, find a maximum-weight spanning tree of the clique graph where edge weights between cliques are the number of nodes shared by the two cliques). Given the junction tree, assign each original factor to a clique (that has sufficient scope), and for each clique, use type (A) einsum to compute a new "input" factor representing the product of all original input factors assigned to that clique (with the output being all dimensions from any factor in the clique [including batch dimensions]). Note that there might be some additional fanciness to consider here to recursively make the complexity of this subproblem smaller: e.g., if/since we didn't pick the optimal junction tree, the full marginal of the clique may not be needed (e.g. only some dimensions are needed to serve BP in the type (B) einsum), or it can be sparsely represented or only heuristically computed (some ideas). Call type (B) einsum on the resulting input factors and the original output factors.
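To make Case (B) concrete, here is a minimal sketch on the chain i--j--k--l with factors A(i,j), B(j,k), C(k,l), computing the marginals for all three input factors. It uses plain torch operations on the ordinary sum/product semiring purely for illustration (the real thing would route each step through semiring einsum calls); the sizes are made up.

import torch

I, J, K, L = 3, 4, 5, 6
A, B, C = torch.rand(I, J), torch.rand(J, K), torch.rand(K, L)

# messages toward the "root" factor A (right to left)
m_C_to_k = C.sum(dim=1)      # shape (K,): sum_l C[k,l]
m_B_to_j = B @ m_C_to_k      # shape (J,): sum_k B[j,k] * m_C_to_k[k]
# messages away from the root (left to right)
m_A_to_j = A.sum(dim=0)      # shape (J,): sum_i A[i,j]
m_B_to_k = m_A_to_j @ B      # shape (K,): sum_j m_A_to_j[j] * B[j,k]

marg_ij = A * m_B_to_j[None, :]                      # 'ij,jk,kl->ij'
marg_jk = m_A_to_j[:, None] * B * m_C_to_k[None, :]  # 'ij,jk,kl->jk'
marg_kl = m_B_to_k[:, None] * C                      # 'ij,jk,kl->kl'

# check against the brute-force full product
P = A[:, :, None, None] * B[None, :, :, None] * C[None, None, :, :]
assert torch.allclose(marg_ij, P.sum(dim=(2, 3)))
assert torch.allclose(marg_jk, P.sum(dim=(0, 3)))
assert torch.allclose(marg_kl, P.sum(dim=(0, 1)))

Apart from the brute-force check at the end, no tensor larger than the input factors is ever materialized, which is the point of Case (B).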

@bdusell
Owner

bdusell commented May 29, 2021

No worries, we made the NeurIPS deadline without a hitch. That just leaves me here at home recovering from oral surgery, so I'm happy to take a look now.

The multiplication code is here: https://github.com/bdusell/semiring-einsum/blob/master/torch_semiring_einsum/extend.py#L150-L156

Your hunch is correct that the memory-saving technique complicates the re-use of the products. The memory-saving strategy I employ is based on the insight that you don't need to store the entire tensor of products (which can easily become extremely large) before doing the reduction. Instead, I break the input tensors into smaller slices, essentially doing a zip over little matching slices from each input tensor. At each iteration I reshape the slices so the dimensions line up as needed, compute the mini-product, sum out a subset of variables from the mini-product to get a mini-sum, then add the result to a running total (the full summation, a.k.a. the output returned by einsum). The only scratch memory I need is a tensor to hold the mini-product, a tensor for the mini-sum (always no larger than the mini-product), and the output tensor itself. The first two are bounded in size by block_size. To save memory, the products are not preserved for the backward pass and are re-computed from scratch during the backward pass.

Now, can the mini-product tensor be re-used in some way? My hunch at first was no, but now I'm thinking this is the exact right direction to go in... we line up the dimensions and do the product, but then we perform a different mini-sum for each equation that gets rid of different variables, and then we accumulate each sum in a different tensor for each equation. This way you re-use the products for all outputs, but unlike the tree approach, you don't allocate any intermediate tensors whose sizes are proportional to the full inputs. I think this is exactly what you were looking for originally.
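To make that concrete, a rough sketch (not the library's actual code) of a two-output query such as 'ij,jk->ij' together with 'ij,jk->ik', where each mini-product slice feeds two running totals; slices() stands in for a helper that yields index slices of width block_size, and I, J, K, A, B, block_size are assumed defined.

out_ij = torch.zeros(I, J)   # accumulator for the 'ij' output
out_ik = torch.zeros(I, K)   # accumulator for the 'ik' output
for j in slices(J, block_size):
    for k in slices(K, block_size):
        Z = A[:, j, None] * B[None, j, k]  # shared mini-product, size I x block_size x block_size
        out_ij[:, j] += Z.sum(dim=2)       # sum out k for the 'ij' output
        out_ik[:, k] += Z.sum(dim=1)       # sum out j for the 'ik' output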

Moreover, any tree optimizations can now be applied to the mini-sum, but with the bonus that the memory usage is constant because we're working on slices limited by block_size. It also makes optimizing at compile time easier, since the slice sizes are more predictable.

@bdusell
Owner

bdusell commented May 29, 2021

Foregoing the tree optimizations and computing the mini-summation for each output directly from the mini-product might be faster anyway due to the parallelism, and definitely more memory-efficient, if not as eco-friendly. Definitely the first thing we should try.

@teichert
Author

That all sounds great!
Thanks for clarifying your method. I think I finally get it:
For 'ij,jk,kl->il' with, say, size 100 for each of those four dimensions: rather than computing the full 100x100x100x100 product tensor and then summing down to the desired 100x100 output tensor, your method can sweep over 10000 separate 10x10x10x10 slices of the full product one at a time and accumulate the respective marginals into the correct portion of the 100x100 output tensor, which means much less memory usage. Very nice!

As you point out, the idea of separately accumulating into different output tensors seems perfectly consistent (and you are right that that would satisfy my initial desire). The only additional complication (and the thing that threw me off originally) is that you have deliberately structured the product slices so that all reduced dimensions come last (which, I think, is required for the summation), so there may be an extra permutation of dimensions for each extra output (but that's not a big deal, I think).

This ticket was really just about (1) allowing the interface that allows multiple right-hand sides and (2) if possible, allowing the marginals for each of the three input factors while only multiplying the summation time by 3 rather than multiplying the product and summation time each by three, so I think your plan is a good one to handle both of these already.

(Tangentially, I think that all of that BP and junction tree stuff continues to be applicable as a way of avoiding the need to even slice over the entire full product. For example, the BP approach for the example above would perform the einsum 'ij,jk->ik' and then pass the result into a second einsum 'ik,kl->il', which would achieve the same answer with 2 full product tensors of 100x100x100, amounting to only 2000 10x10x10 slices. Much better than 10000 10x10x10x10 slices! Do you follow me?)

Anyway, no pressure on any of this, and thanks very much for your library, attention, and comments.

@bdusell
Owner

bdusell commented May 30, 2021

Ah, I think the comment above the line of code you linked to is actually out of date -- I changed it a while back so that lookup does a single [] followed by a .permute (https://github.com/bdusell/semiring-einsum/blob/master/torch_semiring_einsum/equation.py#L153-L159), and sum_block gets a tuple for dim that handles the rest (https://github.com/bdusell/semiring-einsum/blob/master/torch_semiring_einsum/extend.py#L162). But I think it's all manipulating tensor views anyway, so it doesn't matter that much.

I don't understand the BP approach yet, but are you saying it would be able to avoid creating an intermediate tensor of size I x K? That would be extremely useful if so. Do you think this is also related to #4?

This library was definitely born out of necessity, and without it much of my current research would be impossible. I hope by accommodating new, interesting use cases like this we can spread the joy to more users and open up more research possibilities.

@teichert
Author

Ah, I think the comment above the line of code you linked to is actually out of date -- [...] But I think it's all manipulating tensor views anyway, so it doesn't matter that much.

Great; that sounds right.

I don't understand the BP approach yet, but are you saying it would be able to avoid creating an intermediate tensor of size I x K?

We won't avoid IxK (which is already used in the input factors), nor IxL (which is used in the output), but we will avoid needing slices of (IxJxKxL). Even with your slicing trick, you have to represent the final sum, but you only have to represent a fixed-size slice of the intermediate (full) product at any given time. Even so, the size of that slice (in terms of number of elements) grows exponentially with the number of variables in the entire formula and the number of slices you need to iterate over grows with the product of the dimension sizes of all variables. With tree-structured factorizations, the BP version can avoid ever storing even a slice of the full product. Instead, the size and number of product slices that do need to be stored can be linear in the size and number of input factors. The junction-tree stuff is a way of representing any factorization as a tree of factor clusters (no magic here of course: if the factorization has high treewidth then the dense representation of some of the clusters will explode exponentially). To see a pretty trivial example of the difference:

import torch
import torch_semiring_einsum as tse

# three 100x100 matrices
a, b, c = [torch.rand(100, 100) for _ in range(3)]

# formula to compute all at once
eq_full = tse.compile_equation('ij,jk,kl->il')
# formula analogous to separating the two matrix multiplies
eq_part_out, eq_part_in = tse.compile_equation('ik,kl->il'), tse.compile_equation('ij,jk->ik')

# block_size 1
tse.einsum(eq_full, a,b,c, block_size=1)
# %timeit: 654 ms ± 6.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# block_size 10
tse.einsum(eq_full, a,b,c, block_size=10)
# %timeit: 395 ms ± 9.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# "BP" version with block_size 1
tse.einsum(eq_part_out, tse.einsum(eq_part_in, a,b, block_size=1), c, block_size=1)
# %timeit: 11.7 ms ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

# "BP" version with block_size 10
tse.einsum(eq_part_out, tse.einsum(eq_part_in, a,b, block_size=10), c, block_size=10)
# %timeit: 6.17 ms ± 1.3 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

# same answers:
torch.isclose(tse.einsum(eq_full, a,b,c, block_size=10), tse.einsum(eq_part_out, tse.einsum(eq_part_in, a,b, block_size=10), c, block_size=10)).all()
# output: tensor(True)

# of course, native pytorch matrix multiply still wins handily (as does pytorch einsum, but with no option for a semiring)
a@b@c
# %timeit: 50.4 µs ± 2.82 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

torch.einsum('ik,kl->il', torch.einsum('ij,jk->ik', a,b), c)
# %timeit: 88.2 µs ± 4.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

torch.einsum('ij,jk,kl->il', a,b,c)
# %timeit: 86.8 µs ± 7.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

torch.isclose(tse.einsum(eq_full, a,b,c, block_size=10), a@b@c).all()
# output: tensor(True)

Is that clear?

Do you think this is also related to #4?

Yes; I'm not well versed in that, but I do see connections made on pages 55 and 57 of the Eisner and Blatz paper that Tim cited.

@bdusell
Owner

bdusell commented Jun 7, 2021

Thanks for the great and thorough examples! To clarify, for Y = einsum('ij,jk,kl->il', A, B, C), the way it currently works in pseudocode is the following:

Y = torch.zeros(I, L)
# slices(dim_size, block_size) yields index slices of width up to block_size
for j in slices(J, block_size):
    for k in slices(K, block_size):
        Z = A[:, j, None, None] * B[None, j, k, None] * C[None, None, k, :] # size: I x block_size x block_size x L
        Y += torch.sum(Z, dim=(1, 2)) # size: I x L
return Y

Iterations: O(JK)
Multiplications: O(IJKL) total (O(block_size^2 IL) per iteration)
Memory: O(block_size^2 IL)

With the BP version this becomes:

Y1 = torch.zeros(I, K)
for j in slices(J, block_size):
    Z = A[:, j, None] * B[None, j, :] # size: I x block_size x K
    Y1 += torch.sum(Z, dim=1) # size: I x K
Y = torch.zeros(I, L)
for k in slices(K, block_size):
    Z = Y1[:, k, None] * C[None, k, :] # size: I x block_size x L
    Y += torch.sum(Z, dim=1) # size: I x L
return Y

Iterations: O(J + K)
Multiplications: O(IJK + IKL) total (O(block_size IK) and O(block_size IL) per iteration)
Memory: O(block_size IK + block_size IL)

bdusell added the enhancement label on Jun 21, 2023