
move modeling.py and modeling_nv.py to transformers #9676

Open: Li-Z-Q wants to merge 5 commits into develop

Conversation

Li-Z-Q (Contributor) commented Dec 23, 2024

move modeling.py and modeling_nv.py to transformers

paddle-bot (bot) commented Dec 23, 2024

Thanks for your contribution!

codecov (bot) commented Dec 23, 2024

Codecov Report

Attention: Patch coverage is 17.38149% with 366 lines in your changes missing coverage. Please review.

Project coverage is 52.62%. Comparing base (97ae9ad) to head (86a05c3).
Report is 11 commits behind head on develop.

Files with missing lines                         Patch %   Lines
paddlenlp/transformers/nv_embed/modeling.py      15.18%    229 Missing ⚠️
paddlenlp/transformers/llm_embed/modeling.py     18.93%    137 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9676      +/-   ##
===========================================
+ Coverage    52.00%   52.62%   +0.62%     
===========================================
  Files          721      722       +1     
  Lines       116703   112813    -3890     
===========================================
- Hits         60690    59373    -1317     
+ Misses       56013    53440    -2573     

☔ View full report in Codecov by Sentry.

DrownFish19 (Collaborator) commented Dec 24, 2024

The lint failures require installing pre-commit and then formatting the code. Reference steps:

# Install
pip install pre-commit

# Register pre-commit in the project folder; the code is formatted on every commit
pre-commit install

# Run it separately on existing code files
pre-commit run --file XXXX.py

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Collaborator: Is the copyright here correct?

Li-Z-Q (Contributor Author): Fixed.

# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Collaborator: Add from .modeling import * here.

Li-Z-Q (Contributor Author): Added.

PretrainedModel,
)
from paddlenlp.transformers.model_outputs import ModelOutput
from paddlenlp.utils.log import logger
Collaborator: These imports should be changed to relative imports. For example, change

from paddlenlp.transformers.model_outputs import ModelOutput

to

from ..transformers.model_outputs import ModelOutput

Li-Z-Q (Contributor Author) commented Dec 24, 2024: Changing it to from ..transformers.model_outputs import ModelOutput raises an error, so I have not changed it yet.
[screenshot of the error: 2024-12-25 19 18 55]

Collaborator: It should be from ..model_outputs import ModelOutput; sorry, I wrote it incorrectly before.

Li-Z-Q (Contributor Author): Fixed.
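For context, a minimal sketch of how the corrected relative import resolves, assuming the file sits at paddlenlp/transformers/llm_embed/modeling.py (the path shown in the Codecov report above):

# paddlenlp/transformers/llm_embed/modeling.py (assumed location)
# One leading dot is the current package (paddlenlp.transformers.llm_embed);
# two dots reach its parent (paddlenlp.transformers), where model_outputs lives.
from ..model_outputs import ModelOutput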

# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Collaborator: Same as above, add from .modeling import * here.

Li-Z-Q (Contributor Author): Added.
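A minimal sketch of what the requested re-export could look like in the new subpackage's __init__.py; only the from .modeling import * line comes from the review, the path and comments are assumed:

# e.g. paddlenlp/transformers/nv_embed/__init__.py (assumed path)
# Re-export the modeling symbols so they are reachable from the package root,
# e.g. from paddlenlp.transformers import NVEncodeModel.
from .modeling import *  # noqa: F401, F403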

@@ -0,0 +1,517 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Collaborator: This file should be renamed to modeling.py.

Li-Z-Q (Contributor Author): Renamed.

from mteb import MTEB

from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from paddlenlp.transformers.llm_embed.modeling import BiEncoderModel
from paddlenlp.transformers.nv_embed.modeling_nv import NVEncodeModel
Collaborator: The imports here can be simplified:

from paddlenlp.transformers import BiEncoderModel, NVEncodeModel

Li-Z-Q (Contributor Author): Fixed.


from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.trainer import PdArgumentParser, Trainer, get_last_checkpoint, set_seed
from paddlenlp.transformers import AutoTokenizer
from paddlenlp.transformers.llm_embed.modeling import BiEncoderModel
from paddlenlp.transformers.nv_embed.modeling_nv import NVEncodeModel
Collaborator: Same as above.

Li-Z-Q (Contributor Author): Fixed.

Li-Z-Q (Contributor Author) commented Dec 25, 2024

The lint failures require installing pre-commit and then formatting the code. Reference steps:

# Install
pip install pre-commit

# Register pre-commit in the project folder; the code is formatted on every commit
pre-commit install

# Run it separately on existing code files
pre-commit run --file XXXX.py

I ran pre-commit before committing, following the steps you described.

dtype=str(self.latents.weight.dtype).split(".")[-1],
)
self_latents_weight_T = self.latents(one).T
latents = repeat(self_latents_weight_T, "d h -> b d h", b=last_hidden_states.shape[0])
Collaborator: Please also replace the einops repeat here with Paddle operators.
Li-Z-Q (Contributor Author): Changed to:

latents = paddle.tile(self_latents_weight_T, repeat_times=last_hidden_states.shape[0]).reshape(
    self_latents_weight_T.shape[0], last_hidden_states.shape[0], self_latents_weight_T.shape[1]
)
latents = latents.transpose([1, 0, 2])
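For reference, a standalone sketch with toy shapes (not the PR's exact variables) comparing the einops repeat pattern with Paddle-native equivalents:

import paddle

# Toy sizes for illustration only.
b, d, h = 2, 3, 5
weight_t = paddle.randn([d, h])  # stands in for self_latents_weight_T

# einops: repeat(weight_t, "d h -> b d h", b=b)
# Option 1: broadcast with expand after adding a leading batch axis.
latents_expand = paddle.expand(weight_t.unsqueeze(0), [b, d, h])

# Option 2: tile along the last axis, then reshape and transpose
# (similar in spirit to the change quoted above).
latents_tile = paddle.tile(weight_t, repeat_times=[1, b]).reshape([d, b, h]).transpose([1, 0, 2])

assert latents_expand.shape == latents_tile.shape == [b, d, h]
assert bool(paddle.allclose(latents_expand, latents_tile))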

k = kv[:, :, : self.config.max_position_embeddings]
v = kv[:, :, self.config.max_position_embeddings :]

q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.config.num_key_value_heads), (q, k, v))
Collaborator: Please replace rearrange with Paddle operators.

Li-Z-Q (Contributor Author): Fixed.
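A minimal standalone sketch (toy shapes, hypothetical names) of the reshape that a Paddle-native version of this rearrange pattern amounts to:

import paddle

# Toy dimensions for illustration only, not the model's real config values.
b, n, num_heads, head_dim = 2, 7, 4, 8
q = paddle.randn([b, n, num_heads * head_dim])

# einops: rearrange(q, "b n (h d) -> b n h d", h=num_heads)
# Paddle equivalent: split the last axis into (heads, head_dim) with a reshape.
q_heads = q.reshape([b, n, num_heads, -1])
assert q_heads.shape == [b, n, num_heads, head_dim]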

# v.stop_gradient = False
# out = paddle.nn.functional.scaled_dot_product_attention(q, k, v) # if use this, must set k and v stop_gradient to False
out = scaled_dot_product_attention(q, k, v) # if use this, no need to manually set k and v
out = rearrange(out, "b n h d -> b n (h d)", h=self.config.num_key_value_heads)
Collaborator: Same change as above.

Li-Z-Q (Contributor Author): Fixed.
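And the inverse pattern, again as a toy sketch rather than the PR's exact code:

import paddle

# Toy shapes for illustration only.
b, n, num_heads, head_dim = 2, 7, 4, 8
out = paddle.randn([b, n, num_heads, head_dim])

# einops: rearrange(out, "b n h d -> b n (h d)", h=num_heads)
# Paddle equivalent: merge the last two axes back together.
out_merged = out.reshape([b, n, num_heads * head_dim])
# paddle.flatten(out, start_axis=2, stop_axis=3) gives the same result.
assert out_merged.shape == [b, n, num_heads * head_dim]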
