Hi, thanks for your great work. I intend to compute the attention scores between tokens and here is my code:
import torch
from transformers import BertModel, BertConfig, DNATokenizer

dir_to_pretrained_model = "./6-new-12w-0/"
config = BertConfig.from_pretrained('../src/transformers/dnabert-config/bert-config-6/config.json')
tokenizer = DNATokenizer.from_pretrained('dna6')
print(config)
model = BertModel.from_pretrained(dir_to_pretrained_model, config=config).cuda()

sequence = "AATCTAATCTAGTCTAGCCTAGCA"
inputs = tokenizer.encode_plus(sequence, add_special_tokens=True, max_length=512)
model_input = torch.tensor(inputs["input_ids"], dtype=torch.long).cuda()
model_input = model_input.unsqueeze(0)  # fake batch with batch size one

output = model(model_input)
print(output[-1][-1])        # output[-1] holds the per-layer attentions (if enabled in the config); this is the last layer
print(output[-1][-1].shape)
I think output[-1] contains the attention matrices, so I took out the last item, whose shape is [1, 12, 3, 3]. Does the 12 mean the 12 heads? And does 3, 3 correspond to the two sets of tokens? How should I compute the correct attention between these tokens? Is it simply a matter of averaging all the attention in each layer? Thanks.
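For reference, here is a minimal sketch (continuing from the snippet above) of one way the returned attentions could be aggregated, assuming output[-1] is a tuple with one tensor per layer, each of shape [batch, num_heads, seq_len, seq_len], as this transformers version returns when output_attentions is enabled. Averaging over heads, and optionally over layers, is only one common convention, not the single "correct" attention score:

attentions = output[-1]  # tuple of per-layer tensors, each [batch, num_heads, seq_len, seq_len]

# stack into [num_layers, batch, num_heads, seq_len, seq_len]
attn = torch.stack(attentions, dim=0)

# average over the heads within each layer -> [num_layers, batch, seq_len, seq_len]
attn_per_layer = attn.mean(dim=2)

# optionally also average over the layers -> [batch, seq_len, seq_len]
attn_avg = attn_per_layer.mean(dim=0)

print(attn_per_layer[-1].shape)  # head-averaged token-to-token attention of the last layer
print(attn_avg.shape)            # attention averaged over all heads and layers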