diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 2b1c2b3c4f..cd6a9dd444 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -312,7 +312,7 @@ def forward(self, query_layer, key_layer, value_layer.size(3)) # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), + value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk]