diff --git a/llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc b/llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc index 0db391c..53c4a23 100644 --- a/llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc +++ b/llm/src/nn_modules/Fp32GPTBigCodeDecoder.cc @@ -104,8 +104,8 @@ struct Fp32GPTBigCodeDecoder_output Fp32GPTBigCodeDecoder::forward(const struct // Position embeddings // Matrix3D pos_embeds = this->get_position_embed(sqlen, past_key_values_length); #ifdef _WIN32 - std::vector<float> position_ids_buf_vec(sqlen); - float *position_ids_buf = &position_ids_buf_vec.front(); + std::vector<int> position_ids_buf_vec(sqlen); + int *position_ids_buf = &position_ids_buf_vec.front(); std::vector<float> pos_embeds_buf_vec(sqlen * this->embed_dim); float *pos_embeds_buf = &pos_embeds_buf_vec.front(); #else diff --git a/llm/src/nn_modules/Int4GPTBigCodeDecoder.cc b/llm/src/nn_modules/Int4GPTBigCodeDecoder.cc index 9684ca5..b7893f4 100644 --- a/llm/src/nn_modules/Int4GPTBigCodeDecoder.cc +++ b/llm/src/nn_modules/Int4GPTBigCodeDecoder.cc @@ -110,8 +110,8 @@ struct Int4GPTBigCodeDecoder_output Int4GPTBigCodeDecoder::forward(const struct // printf(("Before get_position_embed\n"); // Matrix3D pos_embeds = this->get_position_embed(sqlen, past_key_values_length); #ifdef _WIN32 - std::vector<float> position_ids_buf_vec(sqlen); - float *position_ids_buf = &position_ids_buf_vec.front(); + std::vector<int> position_ids_buf_vec(sqlen); + int *position_ids_buf = &position_ids_buf_vec.front(); std::vector<float> pos_embeds_buf_vec(sqlen * this->embed_dim); float *pos_embeds_buf = &pos_embeds_buf_vec.front(); #else