// decode.h: interface of the Transformer decoder used for NMT inference.
#ifndef _DECODE_
#define _DECODE_

#include <string>
#include <vector>

namespace NMT
{
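// Transformer decoder for neural machine translation. Judging from the
// interface below, it performs greedy step-by-step decoding over the encoder
// output, caching keys and values for both self-attention and
// encoder-decoder attention.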
class Decoder
{
private:
    size_t head_num;       // number of attention heads
    size_t hidden_num;     // model (hidden) dimension
    size_t layer_num;      // number of decoder layers
    size_t decode_length;  // maximum number of decoding steps
    size_t vocabe_size;    // target vocabulary size
    size_t filter_size;    // feed-forward inner dimension
    std::vector<std::vector<std::vector<float>>> weight;  // per-layer weight tensors
    std::vector<float> weight_embedding;  // token embedding table
    std::vector<float> weight_language;   // language embedding table
    std::vector<float> weight_scale;      // layer-norm scale parameters
    std::vector<float> weight_bias;       // layer-norm bias parameters
    std::vector<float> logit_weight;      // projection to vocabulary logits
public:
    // Run greedy decoding for a whole batch of encoder outputs.
    void Translate(std::vector<float>& encode_out, const size_t& batch_size, const size_t& length, std::vector<int>& language_id, std::vector<int>& mask);
    // Precompute the position-encoding table (decode_length x hidden_num).
    void GetPosWeight(std::vector<float>& position_weight, size_t decode_length, size_t hidden_num);
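    // In the standard Transformer formulation, which GetPosWeight presumably
    // follows, the table holds
    //   PE(pos, 2i)   = sin(pos / 10000^(2i / hidden_num))
    //   PE(pos, 2i+1) = cos(pos / 10000^(2i / hidden_num)).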
    // Build an additive attention bias from the padding mask and apply it.
    void BuildBias(const size_t& batch_size, const size_t& length, int* mask, float* bias);
    void AddBias(float* input, const float* bias, const size_t& batch_size, const size_t& length);
    // Map token ids (or the initial language id) to embedding vectors.
    void EmbeddingLookup(std::vector<size_t>& input, std::vector<float>& embedding_word);
    void EmbeddingInit(std::vector<int>& language_id, std::vector<float>& embedding_word);
    // Add the positional encoding for the given time step.
    void PositionEncode(const size_t& time, std::vector<float>& embedding_word, std::vector<float>& weight_position);
    // Combine a sublayer's output with the layer input (residual connection).
    void LayerPostprocess(std::vector<float>& layer_input, const std::vector<float>& temp);
    // Softmax over batched per-head attention scores, and a generic softmax.
    void BatchSoftmax(float* input_qk, int k, int head_num, const size_t& batch_size, const size_t& length);
    void GenSoftmax(float* input, int num);
    // Project hidden states to vocabulary logits.
    void ToLogits(float* input, const size_t& batch_size, float* weight, float* output);
    // Argmax over the logits for each batch element.
    std::vector<size_t> GetMax(std::vector<float>& logit, const size_t& batch_size);
    // Relative-position attention helpers.
    void GetPositionX(const float* position_embedding, const size_t max_length, const size_t& length, std::vector<float>& position_x);
    void MulPositionKey(const size_t& batch_size, const size_t& length, float* input, float* position_key, float* out);
    void MulPositionValue(const size_t& batch_size, const size_t& length, float* input, float* position_val, float* out);
    // Construct a decoder from the model hyper-parameters and trained weights.
    Decoder(size_t& head_num,
            size_t& hidden_num,
            size_t& layer_num,
            size_t& decode_length,
            size_t& vocabe_size,
            size_t& filter_size,
            std::vector<std::vector<std::vector<float>>>& weight,
            std::vector<float>& weight_embedding,
            std::vector<float>& weight_language,
            std::vector<float>& weight_scale,
            std::vector<float>& weight_bias,
            std::vector<float>& logit_weight);
    // Precompute the encoder-side key/value caches used by EncdecAttention.
    void SetCache(const float* encode_out,
                  const size_t& batch_size,
                  const size_t& length,
                  std::vector<std::vector<float>>& cache_out_k,
                  std::vector<std::vector<float>>& cache_out_v);
    // Layer normalization applied before a sublayer, using the given
    // learned scale and bias.
    void LayerPreprocess(std::vector<float>& layer_input,
                         const size_t& batch_size,
                         const size_t& length,
                         const float* scale,
                         const float* bias);
    // Position-wise feed-forward sublayer with the named activation.
    void FeedForward(const std::vector<float>& input,
                     std::vector<float>& output,
                     const size_t& batch_size,
                     const size_t& length,
                     int filter,
                     const float* weight,
                     float* bias,
                     std::string activation);
    // Run one decoding step through all layers and return the selected
    // token ids; cache_k/cache_v hold the growing self-attention cache.
    std::vector<size_t> Decode(std::vector<float>& embedding_word,
                               const size_t& batch_size,
                               const size_t& length,
                               std::vector<std::vector<float>>& encode_out_k,
                               std::vector<std::vector<float>>& encode_out_v,
                               std::vector<std::vector<float>>& cache_k,
                               std::vector<std::vector<float>>& cache_v,
                               std::vector<float>& self_bias,
                               std::vector<float>& encdec_bias);
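    // The conventional Transformer ordering, which this interface suggests,
    // is per layer: LayerPreprocess -> SelfAttention -> LayerPostprocess ->
    // LayerPreprocess -> EncdecAttention -> LayerPostprocess ->
    // LayerPreprocess -> FeedForward -> LayerPostprocess, followed by a
    // final LayerPreprocess and ToLogits.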
    // Masked multi-head self-attention over the tokens decoded so far;
    // k_value/v_value accumulate the key/value cache across steps.
    void SelfAttention(float* input,
                       const size_t& batch_size,
                       const size_t& length,
                       const float* q_weight,
                       const float* k_weight,
                       const float* v_weight,
                       const float* key_weight,
                       const float* value_weight,
                       const float* weight,
                       float* output,
                       std::vector<float>& k_value,
                       std::vector<float>& v_value,
                       const float* bias);
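    // The attention computed here is presumably the scaled dot-product form
    //   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V,
    // evaluated per head with d_k = hidden_num / head_num.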
    // Multi-head attention from decoder states to the cached encoder output.
    void EncdecAttention(float* input,
                         const size_t& batch_size,
                         const size_t& length,
                         const float* q_weight,
                         const float* weight,
                         float* output,
                         std::vector<float>& k_value,
                         std::vector<float>& v_value,
                         const float* bias);
};
}
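/*
  A minimal usage sketch, assuming the weights have already been loaded in the
  layout expected by the implementation (not shown here); the variable names
  follow the constructor and Translate signatures above:

    NMT::Decoder decoder(head_num, hidden_num, layer_num, decode_length,
                         vocabe_size, filter_size, weight, weight_embedding,
                         weight_language, weight_scale, weight_bias,
                         logit_weight);

    // One call per batch: encode_out is the flattened encoder output,
    // mask marks real tokens vs. padding, language_id selects the start token.
    decoder.Translate(encode_out, batch_size, length, language_id, mask);
*/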
#endif // !_DECODE_