-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathconfig.py
49 lines (42 loc) · 1.38 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from transformers import BertConfig
class HyperParameters(object):
    """
    Hyper-parameters for model training.

    All values are plain attributes set from the constructor arguments,
    so instances can be freely inspected or mutated before training.
    """

    def __init__(
        self,
        max_length: int = 128,
        epochs: int = 4,
        batch_size: int = 32,
        learning_rate: float = 2e-5,
        fp16: bool = True,
        fp16_opt_level: str = "O1",
        max_grad_norm: float = 1.0,
        warmup_steps: float = 0.1,
    ) -> None:
        # Maximum length of an input sentence (in tokens).
        self.max_length = max_length
        # Number of training epochs.
        self.epochs = epochs
        # Size of a mini-batch.
        self.batch_size = batch_size
        # Optimizer learning rate.
        self.learning_rate = learning_rate
        # Enable FP16 mixed-precision training.
        self.fp16 = fp16
        # NVIDIA APEX AMP opt level, one of 'O0', 'O1', 'O2', 'O3';
        # see https://nvidia.github.io/apex/amp.html
        self.fp16_opt_level = fp16_opt_level
        # Max gradient norm for gradient clipping.
        self.max_grad_norm = max_grad_norm
        # Learning-rate warm-up.
        # NOTE(review): the default 0.1 looks like a warm-up *ratio*
        # (fraction of total steps), not a step count, despite the name —
        # confirm against the scheduler that consumes it.
        self.warmup_steps = warmup_steps

    def __repr__(self) -> str:
        """Return a dict-style repr of all hyper-parameter attributes."""
        return self.__dict__.__repr__()
class Config(BertConfig):
    """
    Configuration for MatchModel.

    Extends ``BertConfig`` with a maximum sequence length and the name
    of the matching-algorithm class to instantiate.
    """

    def __init__(self, max_len=512, algorithm="BertForSimMatchModel", **kwargs):
        # Everything else is forwarded verbatim to BertConfig.
        super().__init__(**kwargs)
        # Maximum sequence length accepted by the model.
        self.max_len = max_len
        # Name of the matching-algorithm implementation to use.
        self.algorithm = algorithm