Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor softmax templates to use outer dims
Summary: Previously, the softmax templates assumed that reduction would always be done over the last dim, so the only parameter passed to the templates was the rank of the tensor. To set the stage for generalizing softmax, we pass the reduction dim instead. The output is functionally identical, though the codegen changes slightly in the case where all the inner dimensions are 1: we now pass only the outer dimensions to the function call, dropping the redundant inner dimension parameters. For the `tail_shapes_all_1_bf16` softmax test case, we have Before: ``` softmax_0( X, Y, &input_batch, &X_dim_1, &X_dim_2, stream ); ``` After: ``` softmax_0( X, Y, &input_batch, stream ); ``` Differential Revision: D47732859 fbshipit-source-id: 512ae692034ce7208c4648955f6ccaf93c0a27aa
- Loading branch information