Bidirectional attention or causal attention for embedding? #15
The last hidden state is produced via bidirectional attention in the model itself.
Hi, I'm currently trying to train GritLM on Gemma 2B to generate embeddings. While reviewing the training script for Mistral 7B, I noticed the use of bidirectional attention with attn='bbcc'. For embeddings, would it be more advantageous to train with 'bbcc' or 'cccc'? However, when I tried to use attn='bbcc' with Gemma, I encountered an error: TypeError: GemmaModel.forward() received an unexpected keyword argument 'is_causal'. To fix this, I commented out the following line in gritlm.py:
Is this correct?
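Commenting out the call avoids the TypeError, but it likely does not give you bidirectional attention. Assuming the stock `GemmaModel` in Hugging Face Transformers always applies a causal mask, the embedding pass would then silently run causally, so attn='bbcc' would behave like causal attention for embeddings unless the Gemma modeling code is modified the way the repository's Mistral modeling file is. A slightly safer pattern than deleting the argument is to drop unsupported keyword arguments with a visible warning. The helper `filter_supported_kwargs` and the stand-in `stock_forward` below are hypothetical (not part of gritlm); the sketch only prevents the crash and does not add bidirectional attention to a model that lacks it.

```python
import inspect
import warnings
from typing import Any, Callable, Dict


def filter_supported_kwargs(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return only the keyword arguments that `fn` can accept, warning about the rest.

    If `fn` declares **kwargs, everything is passed through unchanged.
    """
    params = inspect.signature(fn).parameters
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    # Only parameters that can be passed by keyword are valid targets.
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    kept = {}
    for name, value in kwargs.items():
        if name in accepted:
            kept[name] = value
        else:
            # Make the fallback visible instead of silently changing behaviour.
            warnings.warn(f"{getattr(fn, '__qualname__', fn)} does not accept {name!r}; dropping it")
    return kept


# Hypothetical stand-in for a forward() that, like the one in the error above,
# has no is_causal parameter.
def stock_forward(input_ids, attention_mask=None):
    return len(input_ids)


kwargs = filter_supported_kwargs(stock_forward, {"attention_mask": None, "is_causal": False})
stock_forward([1, 2, 3], **kwargs)  # runs, but attention is whatever the model implements
```

The warning makes it obvious in training logs that the requested attention mode was not applied, which a commented-out line does not.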
Hi @Muennighoff, amazing work! I have a similar confusion to @yonxie's. I can see here that you did a final pooling. I was also looking at the query-doc caching example on page 63. To reuse the key-value cache (if I understand correctly, the keys and values are produced during the forward pass using bidirectional attention), does that mean GritLM functions as a prefix LM with two independent prefixes during RAG?
Sorry for the confusion. I mean that inside the model, bidirectional attention is applied in every transformer layer. The attention mask for that is created here: gritlm/scripts/modeling_mistral_gritlm.py, line 1018 (commit 47b7fe6).
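To make the difference concrete: with a causal mask, token i can only attend to tokens 0..i, so the first token's hidden state never sees later tokens; with a bidirectional mask, every token attends to every other token in every layer, so even early positions carry information about the whole input before any pooling happens. The numpy sketch below illustrates this with a single toy attention layer (one head, identity projections). It is not the GritLM implementation, and `attention_mask` and `self_attention` are hypothetical names.

```python
import numpy as np


def attention_mask(seq_len: int, bidirectional: bool) -> np.ndarray:
    """Boolean mask where mask[i, j] is True if query token i may attend to key token j."""
    if bidirectional:
        return np.ones((seq_len, seq_len), dtype=bool)
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))  # token i sees tokens 0..i


def self_attention(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Single-head attention with identity projections (Q = K = V = x), for illustration only."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores = np.where(mask, scores, -np.inf)  # blocked positions get zero weight
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
x_changed = x.copy()
x_changed[-1] += 1.0  # perturb only the last token

for bidirectional in (False, True):
    mask = attention_mask(4, bidirectional)
    changed = not np.allclose(self_attention(x, mask)[0], self_attention(x_changed, mask)[0])
    print(f"bidirectional={bidirectional}: first token depends on last token -> {changed}")
```

Under the causal mask the first token's output is unaffected by the perturbed last token; under the bidirectional mask it changes, which is the property the thread attributes to every layer of the model.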
The pooling that you point to is then applied to the final hidden state returned from the model to remove the sequence length dimension.
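A minimal sketch of such a pooling step, assuming mean pooling over non-padding tokens (the thread does not say which pooling GritLM uses, and the real code may differ, for example in which tokens it includes). `mean_pool` is a hypothetical helper. It only averages the final hidden states over the sequence-length axis; it contains no attention, so whatever bidirectional mixing the embedding has already happened inside the transformer layers.

```python
import numpy as np


def mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token states over the sequence axis, ignoring padding.

    last_hidden_state: (batch, seq_len, hidden)
    attention_mask:    (batch, seq_len), 1 for real tokens and 0 for padding
    returns:           (batch, hidden)
    """
    mask = attention_mask[..., None].astype(last_hidden_state.dtype)  # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # avoid division by zero for empty rows
    return summed / counts


hidden = np.arange(2 * 3 * 2, dtype=np.float32).reshape(2, 3, 2)
pad_mask = np.array([[1, 1, 1], [1, 1, 0]])  # second sequence has one padding token
embeddings = mean_pool(hidden, pad_mask)
```

The output has shape (batch, hidden): the sequence-length dimension is gone, which is all the pooling step is responsible for.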
Yes
The two caches (or prefixes, if you will) are concatenated, and they have not attended to one another (maybe this is what you mean by independent). You may find it helpful to look at this code example: https://github.com/ContextualAI/gritlm?tab=readme-ov-file#caching
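The attention pattern described in this answer can be written down as a visibility mask. The sketch assumes two cached segments (for example a query and a document, as in the query-doc caching question) that were each encoded separately with bidirectional attention, followed by tokens generated causally on top of the concatenated caches. `concatenated_cache_mask` is a hypothetical helper for illustration, not a gritlm API; in practice the caches are concatenated key/value tensors, and this mask only shows which positions ever influenced which.

```python
import numpy as np


def concatenated_cache_mask(segment_lengths, num_new_tokens):
    """Boolean visibility mask (True = may attend) for separately cached segments plus new tokens.

    Each cached segment was encoded on its own with bidirectional attention, so its
    tokens see only tokens of the same segment. Newly generated tokens see every
    cached position and, causally, earlier new tokens.
    """
    cached = sum(segment_lengths)
    total = cached + num_new_tokens
    mask = np.zeros((total, total), dtype=bool)
    start = 0
    for length in segment_lengths:
        mask[start:start + length, start:start + length] = True  # bidirectional block per segment
        start += length
    mask[cached:, :cached] = True  # new tokens read both caches
    mask[cached:, cached:] = np.tril(np.ones((num_new_tokens, num_new_tokens), dtype=bool))
    return mask


m = concatenated_cache_mask([2, 3], num_new_tokens=2)
print(m.astype(int))
```

The cached rows form a block-diagonal pattern (the two segments "have not paid attention to one another"), and only the generated rows at the bottom see both caches, which matches the prefix-LM-with-two-prefixes reading in the question.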
You mention that bidirectional attention is used for the embedding task, but it appears that you only use the last hidden states from the pretrained LLM to generate embeddings. Is the final projection the only bidirectional part?