Refactor softmax templates to use outer dims
Summary:
Previously, the softmax templates assumed that reduction would always be done over the last dim, so the only parameter passed to the templates was the rank of the tensor. To set the stage for generalizing softmax, we pass the reduction dim instead.

The output is functionally identical, though the codegen changes slightly in the case where all the inner dimensions are 1: we now pass only the outer dimensions to the function call, dropping the redundant inner dimension parameters.

For the `tail_shapes_all_1_bf16` softmax test case, we have

Before:
```
      softmax_0(
         X,
         Y,
         &input_batch,
         &X_dim_1,
         &X_dim_2,
         stream
      );
```

After:
```
      softmax_0(
         X,
         Y,
         &input_batch,
         stream
      );
```
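
To make the call-site difference concrete, here is a minimal, runnable sketch of the new argument selection. The `Dim` class is a stand-in for AITemplate's dimension objects (which expose their name via `_attrs["name"]`), and the rank-4 shape reduced over dim 1 is an assumption chosen to be consistent with the call above, not a value taken from the test itself.

```python
# Sketch of how the dim-pointer argument names are chosen (hypothetical shapes).
class Dim:
    def __init__(self, name: str) -> None:
        self._attrs = {"name": name}


# Assumed: rank-4 input, softmax over dim 1, trailing dims of size 1.
shape = [Dim("input_batch"), Dim("X_dim_1"), Dim("X_dim_2"), Dim("X_dim_3")]
reduction_dim = 1

old_arg_names = [d._attrs["name"] for d in shape[:-1]]             # every dim except the last
new_arg_names = [d._attrs["name"] for d in shape[:reduction_dim]]  # outer dims only

print(old_arg_names)  # ['input_batch', 'X_dim_1', 'X_dim_2']
print(new_arg_names)  # ['input_batch']
```

The names that disappear are exactly the redundant inner-dimension parameters described above.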

Differential Revision: D47732859

fbshipit-source-id: 512ae692034ce7208c4648955f6ccaf93c0a27aa
Authored by int3, committed by facebook-github-bot on Jul 25, 2023 (commit 190212f, 1 parent: 2efce5e)
Showing 1 changed file with 11 additions and 11 deletions.

```diff
--- a/python/aitemplate/backend/cuda/softmax/softmax.py
+++ b/python/aitemplate/backend/cuda/softmax/softmax.py
@@ -114,7 +114,7 @@
 SHAPE_FUNCTIONS = jinja2.Template(
     """
 int64_t M = 1;
-{% for idx in range(input_ndim - 1) %}
+{% for idx in range(dim) %}
 M *= *in_{{idx}};
 {% endfor %}
     """
@@ -124,7 +124,7 @@
     """
 void {{func_name}}(void* input,
 void* output,
-{% for idx in range(input_ndim - 1) %}
+{% for idx in range(dim) %}
 int64_t* in_{{idx}},
 {% endfor %}
 cudaStream_t stream)
@@ -143,7 +143,7 @@
 {{indent}}{{func_name}}(
 {{indent}}   {{input}},
 {{indent}}   {{output}},
-{% for name in input_dim_names[:-1] %}
+{% for name in outer_dim_names %}
 {{indent}}   &{{name}},
 {% endfor %}
 {{indent}}   stream
@@ -154,10 +154,9 @@
 
 
 def get_func_signature(func_attrs: Dict[str, Any]) -> str:
-    input_ndim = func_attrs["inputs"][0]._rank()
     return FUNC_SIGNATURE.render(
         func_name=func_attrs["name"],
-        input_ndim=input_ndim,
+        dim=func_attrs["dim"],
     ).strip()
 
 
@@ -197,7 +196,7 @@ def softmax_gen_function(func_attrs: Dict[str, Any]) -> str:
             os.path.dirname(__file__), "softmax.cuh"
         ),
         func_signature=get_func_signature(func_attrs),
-        shape_functions=SHAPE_FUNCTIONS.render(input_ndim=rank),
+        shape_functions=SHAPE_FUNCTIONS.render(dim=dim),
         dtype=elem_input_type,
         K=k,
         m=find_tile_size(k),
@@ -217,17 +216,18 @@ def softmax_gen_function_call(func_attrs, indent=" "):
     input_name = func_attrs["inputs"][0]._attrs["name"]
     output_name = func_attrs["outputs"][0]._attrs["name"]
 
-    shapes = func_attrs["inputs"][0]._attrs["shape"]
+    shape = func_attrs["inputs"][0]._attrs["shape"]
     assert (
-        len(shapes) >= 2
-    ), f"Softmax only supports input with rank >= 2, current rank: {len(shapes)}"
+        len(shape) >= 2
+    ), f"Softmax only supports input with rank >= 2, current rank: {len(shape)}"
 
-    input_dim_names = [shape._attrs["name"] for shape in shapes]
+    reduction_dim = func_attrs["dim"]
+    outer_dim_names = [dim._attrs["name"] for dim in shape[:reduction_dim]]
 
     return FUNC_CALL_TEMPLATE.render(
         func_name=func_attrs["name"],
         input=input_name,
         output=output_name,
-        input_dim_names=input_dim_names,
+        outer_dim_names=outer_dim_names,
         indent=indent,
     )
```

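As a sanity check on the "functionally identical" claim for the common last-dim case, the sketch below renders the old and new `SHAPE_FUNCTIONS` bodies (copied from the diff above) and confirms they produce the same code once `dim = rank - 1`, since `range(dim)` then covers the same indices as `range(input_ndim - 1)`. The rank value is an arbitrary example.

```python
import jinja2

# Old and new shape-function templates, with bodies taken from the diff above.
OLD_SHAPE_FUNCTIONS = jinja2.Template(
    """
int64_t M = 1;
{% for idx in range(input_ndim - 1) %}
M *= *in_{{idx}};
{% endfor %}
"""
)

NEW_SHAPE_FUNCTIONS = jinja2.Template(
    """
int64_t M = 1;
{% for idx in range(dim) %}
M *= *in_{{idx}};
{% endfor %}
"""
)

rank = 3  # arbitrary example; softmax over the last dim means dim == rank - 1
assert OLD_SHAPE_FUNCTIONS.render(input_ndim=rank) == NEW_SHAPE_FUNCTIONS.render(dim=rank - 1)
```

Passing the reduction dim instead of the rank keeps all three templates driven by a single parameter, which is the groundwork the summary refers to for generalizing softmax beyond last-dim reduction.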