resize_longformer.py
import torch
from collections import OrderedDict
from tempfile import TemporaryDirectory
from typing import Tuple
from transformers import LongformerPreTrainedModel, LongformerTokenizerFast, LongformerConfig

def resize_longformer(
        model: LongformerPreTrainedModel,
        tokenizer: LongformerTokenizerFast,
        longformer_max_length: int = 512) -> Tuple[LongformerPreTrainedModel, LongformerTokenizerFast]:
    """
    Resize any Longformer model (with its task-specific head) to a new maximum sequence length.
    """
    # Add 2 to account for the padding offset Longformer inherits from RoBERTa's
    # position embeddings.
    longformer_max_length += 2

    ###############################
    # Create longformer tokenizer #
    ###############################
    # Round-trip through a temporary directory to get an independent copy of the
    # tokenizer, then set its maximum length (without the +2 offset).
    with TemporaryDirectory() as temp_dir:
        tokenizer.save_pretrained(temp_dir)
        longformer_tokenizer = LongformerTokenizerFast.from_pretrained(temp_dir)
        longformer_tokenizer.model_max_length = longformer_max_length - 2
        longformer_tokenizer.init_kwargs["model_max_length"] = longformer_max_length - 2

    ##################################
    # Create new longformer instance #
    ##################################
    longformer_config = LongformerConfig()
    longformer_config.update(model.config.to_dict())
    longformer_config.max_position_embeddings = longformer_max_length
    longformer_model = model.__class__(longformer_config)

    # We can copy all weights directly except the position embeddings and position ids.
    orig_weights = model.state_dict()
    # Older transformers versions store position_ids as a buffer in the state dict;
    # fall back to a fresh arange if the key is absent.
    embedding_position_ids = orig_weights.pop(
        "longformer.embeddings.position_ids",
        torch.arange(model.config.max_position_embeddings).unsqueeze(0))
    embedding_position_weights = orig_weights.pop("longformer.embeddings.position_embeddings.weight")
    longformer_model.load_state_dict(orig_weights, strict=False)

    if longformer_max_length < tokenizer.model_max_length:
        # New length is shorter than the original length:
        # slice the position ids and the position embedding weights.
        new_embedding_position_ids = embedding_position_ids[:, :longformer_max_length]
        new_embedding_position_weights = embedding_position_weights[:longformer_max_length, :]
    else:
        # New length is longer than the original length:
        # create new position ids and tile the original position embeddings
        # until the new length is reached.
        new_embedding_position_ids = torch.arange(longformer_max_length).unsqueeze(0)
        n_copies = longformer_max_length // embedding_position_weights.size(0)
        n_pos_embs_left = longformer_max_length - (n_copies * embedding_position_weights.size(0))
        new_embedding_position_weights = embedding_position_weights.repeat(n_copies, 1)
        new_embedding_position_weights = torch.cat([
            new_embedding_position_weights,
            embedding_position_weights[:n_pos_embs_left]
        ], 0)

    embedding_states = OrderedDict({
        "longformer.embeddings.position_ids": new_embedding_position_ids,
        "longformer.embeddings.position_embeddings.weight": new_embedding_position_weights
    })
    longformer_model.load_state_dict(embedding_states, strict=False)
    return longformer_model, longformer_tokenizer
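

# Minimal usage sketch (not part of the original module): resize a pretrained
# Longformer with a sequence-classification head down to 1,024 tokens. The
# checkpoint name and head class below are illustrative assumptions, not the
# author's own setup; substitute whatever task-specific class you actually use.
if __name__ == "__main__":
    from transformers import LongformerForSequenceClassification

    model = LongformerForSequenceClassification.from_pretrained("allenai/longformer-base-4096")
    tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096")

    resized_model, resized_tokenizer = resize_longformer(model, tokenizer, longformer_max_length=1024)

    print(resized_model.config.max_position_embeddings)  # 1026 (1024 + 2 RoBERTa-style offset)
    print(resized_tokenizer.model_max_length)             # 1024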