From db3bcd655dce96ff1327c70b15b8caa43cb84dd2 Mon Sep 17 00:00:00 2001
From: Amit Portnoy <1131991+amitport@users.noreply.github.com>
Date: Sat, 23 Nov 2024 16:25:39 +0200
Subject: [PATCH] align model_card_templates.py with code

See:
https://github.com/UKPLab/sentence-transformers/blob/348190d46b0c010c7a4693f198f0ddf70c6ceb35/sentence_transformers/models/Pooling.py#L163
https://github.com/UKPLab/sentence-transformers/blob/348190d46b0c010c7a4693f198f0ddf70c6ceb35/sentence_transformers/models/Pooling.py#L170
---
 sentence_transformers/model_card_templates.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sentence_transformers/model_card_templates.py b/sentence_transformers/model_card_templates.py
index 04a471371..9758b1cc2 100644
--- a/sentence_transformers/model_card_templates.py
+++ b/sentence_transformers/model_card_templates.py
@@ -130,7 +130,7 @@ def model_card_get_pooling_function(pooling_mode):
 # Max Pooling - Take the max value over time for every dimension.
 def max_pooling(model_output, attention_mask):
     token_embeddings = model_output[0] #First element of model_output contains all token embeddings
-    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
     token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
     return torch.max(token_embeddings, 1)[0]
 """,
@@ -142,7 +142,7 @@ def max_pooling(model_output, attention_mask):
 #Mean Pooling - Take attention mask into account for correct averaging
 def mean_pooling(model_output, attention_mask):
     token_embeddings = model_output[0] #First element of model_output contains all token embeddings
-    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 """,
 )
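
For context, a minimal standalone sketch (not part of the patched file) of the corrected mean_pooling applied to a half-precision output; the toy tensor shapes and the fp16 choice are illustrative assumptions. It shows the point of the change: casting the mask to token_embeddings.dtype, as Pooling.py does, keeps pooling in the model's precision, whereas the old .float() silently promoted the result to float32.

import torch

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    # Cast the mask to the embeddings' dtype instead of hard-coding float32,
    # so the pooled output stays fp16/bf16 when the model runs in reduced precision.
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Toy usage (assumed shapes): batch of 2 sequences, 4 tokens, 8-dim fp16 embeddings.
token_embeddings = torch.randn(2, 4, 8, dtype=torch.float16)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
pooled = mean_pooling((token_embeddings,), attention_mask)
print(pooled.dtype)  # torch.float16 with the patch; the previous .float() yielded torch.float32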