From cbde34b72f6b3bea01b42be795a8b9348dbdd442 Mon Sep 17 00:00:00 2001 From: Qihui Xie Date: Fri, 23 Feb 2024 12:15:50 +0800 Subject: [PATCH] fix matrix3d int type error for windows (#81) --- llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc | 4 ++-- llm/src/nn_modules/Int4GPTBigCodeDecoder.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc b/llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc index 0db391cb..53c4a235 100644 --- a/llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc +++ b/llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc @@ -104,8 +104,8 @@ struct Fp32GPTBigCodeDecoder_output Fp32GPTBigCodeDecoder::forward(const struct // Position embeddings // Matrix3D pos_embeds = this->get_position_embed(sqlen, past_key_values_length); #ifdef _WIN32 - std::vector position_ids_buf_vec(sqlen); - float *position_ids_buf = &position_ids_buf_vec.front(); + std::vector position_ids_buf_vec(sqlen); + int *position_ids_buf = &position_ids_buf_vec.front(); std::vector pos_embeds_buf_vec(sqlen * this->embed_dim); float *pos_embeds_buf = &pos_embeds_buf_vec.front(); #else diff --git a/llm/src/nn_modules/Int4GPTBigCodeDecoder.cc b/llm/src/nn_modules/Int4GPTBigCodeDecoder.cc index 9684ca53..b7893f4f 100644 --- a/llm/src/nn_modules/Int4GPTBigCodeDecoder.cc +++ b/llm/src/nn_modules/Int4GPTBigCodeDecoder.cc @@ -110,8 +110,8 @@ struct Int4GPTBigCodeDecoder_output Int4GPTBigCodeDecoder::forward(const struct // printf(("Before get_position_embed\n"); // Matrix3D pos_embeds = this->get_position_embed(sqlen, past_key_values_length); #ifdef _WIN32 - std::vector position_ids_buf_vec(sqlen); - float *position_ids_buf = &position_ids_buf_vec.front(); + std::vector position_ids_buf_vec(sqlen); + int *position_ids_buf = &position_ids_buf_vec.front(); std::vector pos_embeds_buf_vec(sqlen * this->embed_dim); float *pos_embeds_buf = &pos_embeds_buf_vec.front(); #else