move modeling.py and modeling_nv.py to transformers #9676

Open

wants to merge 5 commits into base: develop

2 changes: 2 additions & 0 deletions paddlenlp/transformers/__init__.py
@@ -315,3 +315,5 @@
from .jamba.modeling import *
from .jamba.configuration import *
from .jamba.tokenizer import *
from .llm_embed import *
from .nv_embed import *
15 changes: 15 additions & 0 deletions paddlenlp/transformers/llm_embed/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Collaborator

Add from .modeling import * here.

Contributor Author

Added.


from .modeling import *
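
With this line and the new from .llm_embed import * entry in paddlenlp/transformers/__init__.py, the bi-encoder class defined in modeling.py below should presumably be importable from the package root as well as from the subpackage. A minimal sketch of the import paths this export chain is meant to enable (the class name is taken from the diff and not verified against a released build):

# Either import path should resolve once this export chain is merged:
from paddlenlp.transformers import BiEncoderModel              # package-root re-export
# from paddlenlp.transformers.llm_embed import BiEncoderModel  # or directly from the subpackage
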
294 changes: 294 additions & 0 deletions paddlenlp/transformers/llm_embed/modeling.py
@@ -0,0 +1,294 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy as np
import paddle
import paddle.distributed as dist
import paddle.nn as nn
from tqdm import tqdm

from ...utils.log import logger
from .. import AutoConfig, AutoModel, PretrainedModel
from ..model_outputs import ModelOutput


@dataclass
class EncoderOutput(ModelOutput):
q_reps: Optional[paddle.Tensor] = None
p_reps: Optional[paddle.Tensor] = None
loss: Optional[paddle.Tensor] = None
scores: Optional[paddle.Tensor] = None


class BiEncoderModel(PretrainedModel):
def __init__(
self,
model_name_or_path: str = None,
normalized: bool = False,
sentence_pooling_method: str = "cls",
negatives_cross_device: bool = False,
temperature: float = 1.0,
use_inbatch_neg: bool = True,
margin: float = 0.3,
matryoshka_dims: Optional[List[int]] = None,
matryoshka_loss_weights: Optional[List[float]] = None,
query_instruction: Optional[str] = None,
document_instruction: Optional[str] = None,
eval_batch_size: int = 8,
tokenizer=None,
max_seq_length: int = 4096,
):
super().__init__()
self.model = AutoModel.from_pretrained(model_name_or_path, convert_from_torch=True)
self.model_config = AutoConfig.from_pretrained(model_name_or_path)
self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")

self.normalized = normalized
self.sentence_pooling_method = sentence_pooling_method
self.temperature = temperature
self.use_inbatch_neg = use_inbatch_neg
self.config = self.model_config
self.margin = margin
self.matryoshka_dims = matryoshka_dims

self.query_instruction = query_instruction
self.document_instruction = document_instruction
self.eval_batch_size = eval_batch_size
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length

if self.matryoshka_dims:
self.matryoshka_loss_weights = (

matryoshka_loss_weights if matryoshka_loss_weights else [1] * len(self.matryoshka_dims)
)
else:
self.matryoshka_loss_weights = None

if not normalized:
self.temperature = 1.0
logger.info("reset temperature = 1.0 due to using inner product to compute similarity")

self.negatives_cross_device = negatives_cross_device
if self.negatives_cross_device:
if not dist.is_initialized():
raise ValueError("Distributed training has not been initialized for representation all gather.")
self.process_rank = dist.get_rank()
self.world_size = dist.get_world_size()

def sentence_embedding(self, hidden_state, mask):
if self.sentence_pooling_method == "mean":
s = paddle.sum(hidden_state * mask.unsqueeze(-1).float(), axis=1)
d = mask.sum(axis=1, keepdim=True).float()
return s / d
elif self.sentence_pooling_method == "cls":
return hidden_state[:, 0]
elif self.sentence_pooling_method == "last":

# return hidden_state[:, -1]  # only correct when the padding side is left
sequence_lengths = mask.sum(axis=1)
last_token_indices = sequence_lengths - 1
embeddings = hidden_state[paddle.arange(hidden_state.shape[0]), last_token_indices]
return embeddings

else:
raise ValueError(f"Invalid sentence pooling method: {self.sentence_pooling_method}")

def get_model_config(
self,
):
return self.model_config.to_dict()

def encode(self, features):
psg_out = self.model(**features, return_dict=True, output_hidden_states=True)
p_reps = self.sentence_embedding(psg_out.hidden_states[-1], features["attention_mask"])
return p_reps

def compute_similarity(self, q_reps, p_reps):
# q_reps [batch_size, embedding_dim]
# p_reps [batch_size, embedding_dim]
return paddle.matmul(q_reps, p_reps.transpose([1, 0]))

def hard_negative_loss(self, q_reps, p_reps):
scores = self.compute_similarity(q_reps, p_reps)
scores = scores / self.temperature
scores = scores.reshape([q_reps.shape[0], -1])

target = paddle.arange(scores.shape[0], dtype="int64")
target = target * (p_reps.shape[0] // q_reps.shape[0])
loss = self.compute_loss(scores, target)
return scores, loss

def in_batch_negative_loss(self, q_reps, p_reps):
# In batch negatives
scores = self.compute_similarity(q_reps, p_reps)

# Subtract the margin from all positive samples' cosine similarities
margin_diag = paddle.full(shape=[q_reps.shape[0]], fill_value=self.margin, dtype=q_reps.dtype)
scores = scores - paddle.diag(margin_diag)

# Scale cosine similarities to ease training convergence
scores = scores / self.temperature
target = paddle.arange(0, q_reps.shape[0], dtype="int64")
loss = self.compute_loss(scores, target)
return scores, loss

def forward(
self,
query: Dict[str, paddle.Tensor] = None,
passage: Dict[str, paddle.Tensor] = None,
teacher_score: paddle.Tensor = None,
):
q_reps = self.encode(query)
p_reps = self.encode(passage)

# For non-matryoshka loss, we normalize the representations
if not self.matryoshka_dims:
if self.normalized:
q_reps = paddle.nn.functional.normalize(q_reps, axis=-1)
p_reps = paddle.nn.functional.normalize(p_reps, axis=-1)

if self.training:

# Cross device negatives
if self.negatives_cross_device:
q_reps = self._dist_gather_tensor(q_reps)
p_reps = self._dist_gather_tensor(p_reps)

if self.matryoshka_dims:
loss = 0.0
scores = 0.0
for loss_weight, dim in zip(self.matryoshka_loss_weights, self.matryoshka_dims):
reduced_q = q_reps[:, :dim]
reduced_d = p_reps[:, :dim]
if self.normalized:
reduced_q = paddle.nn.functional.normalize(reduced_q, axis=-1)
reduced_d = paddle.nn.functional.normalize(reduced_d, axis=-1)

if self.use_inbatch_neg:
dim_score, dim_loss = self.in_batch_negative_loss(reduced_q, reduced_d)

else:
dim_score, dim_loss = self.hard_negative_loss(reduced_q, reduced_d)
scores += dim_score
loss += loss_weight * dim_loss

elif self.use_inbatch_neg:
scores, loss = self.in_batch_negative_loss(q_reps, p_reps)

else:
scores, loss = self.hard_negative_loss(q_reps, p_reps)

else:
scores = self.compute_similarity(q_reps, p_reps)
loss = None
return EncoderOutput(

loss=loss,
scores=scores,
q_reps=q_reps,
p_reps=p_reps,
)

def compute_loss(self, scores, target):
return self.cross_entropy(scores, target)

def _dist_gather_tensor(self, t: Optional[paddle.Tensor]):
if t is None:
return None

all_tensors = [paddle.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(all_tensors, t)

all_tensors[self.process_rank] = t
all_tensors = paddle.concat(all_tensors, axis=0)

return all_tensors

def save_pretrained(self, output_dir: str, **kwargs):
state_dict = self.model.state_dict()
state_dict = type(state_dict)({k: v.clone().cpu() for k, v in state_dict.items()})
self.model.save_pretrained(output_dir, state_dict=state_dict)

@paddle.no_grad()
def encode_sentences(self, sentences: List[str], **kwargs) -> np.ndarray:
self.model.eval()
all_embeddings = []
for start_index in tqdm(range(0, len(sentences), self.eval_batch_size), desc="Batches"):
sentences_batch = sentences[start_index : start_index + self.eval_batch_size]

inputs = self.tokenizer(

sentences_batch,
padding=True,
truncation=True,
return_tensors="pd",
max_length=self.max_seq_length,
return_attention_mask=True,
)
outputs = self.model(

input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
return_dict=True,
output_hidden_states=True,
)
last_hidden_state = outputs.hidden_states[-1]

if self.sentence_pooling_method == "last":
if self.tokenizer.padding_side == "right":
sequence_lengths = inputs.attention_mask.sum(axis=1)
last_token_indices = sequence_lengths - 1
embeddings = last_hidden_state[paddle.arange(last_hidden_state.shape[0]), last_token_indices]
elif self.tokenizer.padding_side == "left":
embeddings = last_hidden_state[:, -1]

else:
raise NotImplementedError(f"Padding side {self.tokenizer.padding_side} not supported.")
elif self.sentence_pooling_method == "cls":
embeddings = last_hidden_state[:, 0]  # CLS token at index 0, consistent with sentence_embedding
elif self.sentence_pooling_method == "mean":
s = paddle.sum(last_hidden_state * inputs.attention_mask.unsqueeze(-1), axis=1)
d = inputs.attention_mask.sum(axis=1, keepdim=True)
embeddings = s / d

else:
raise NotImplementedError(f"Pooling method {self.sentence_pooling_method} not supported.")

embeddings = paddle.nn.functional.normalize(embeddings, p=2, axis=-1)

all_embeddings.append(embeddings.cpu().numpy().astype("float32"))

return np.concatenate(all_embeddings, axis=0)

def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
"""
Encode queries for the retrieval task.
If there is an instruction for queries, it is prepended to the query text.
"""
if self.query_instruction is not None:
input_texts = [f"{self.query_instruction}{query}" for query in queries]

else:
input_texts = queries
return self.encode_sentences(input_texts)

def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
"""
Encode a corpus for the retrieval task.
If there is an instruction for documents, it is prepended to the document text.
"""
if isinstance(corpus[0], dict):
if self.document_instruction is not None:
input_texts = [

"{}{} {}".format(self.document_instruction, doc.get("title", ""), doc["text"]).strip()
for doc in corpus
]
else:
input_texts = ["{} {}".format(doc.get("title", ""), doc["text"]).strip() for doc in corpus]

else:
if self.document_instruction is not None:
input_texts = [f"{self.document_instruction}{doc}" for doc in corpus]

else:
input_texts = corpus
return self.encode_sentences(input_texts)

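For reviewers, a minimal usage sketch of the BiEncoderModel added above. The checkpoint name, instruction strings, and similarity step below are placeholders and not part of this PR; the sketch only assumes the constructor and encode_queries/encode_corpus signatures shown in the diff.

import numpy as np
from paddlenlp.transformers import AutoTokenizer
from paddlenlp.transformers.llm_embed import BiEncoderModel

model_name = "your-embedding-checkpoint"  # placeholder; any checkpoint loadable by AutoModel
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = BiEncoderModel(
    model_name_or_path=model_name,
    normalized=True,                   # cosine-style similarity
    sentence_pooling_method="last",    # "cls", "mean", or "last"
    query_instruction="query: ",       # placeholder instruction
    document_instruction="passage: ",  # placeholder instruction
    eval_batch_size=8,
    tokenizer=tokenizer,
    max_seq_length=512,
)

queries = ["What is PaddleNLP?"]
docs = [{"title": "PaddleNLP", "text": "An NLP library built on PaddlePaddle."}]

q_embs = model.encode_queries(queries)  # np.ndarray of shape [num_queries, hidden_size]
d_embs = model.encode_corpus(docs)      # np.ndarray of shape [num_docs, hidden_size]
scores = np.matmul(q_embs, d_embs.T)    # query-document similarity matrix
print(scores)
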
15 changes: 15 additions & 0 deletions paddlenlp/transformers/nv_embed/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Collaborator

Same as above: add from .modeling import * here.

Contributor Author

Added.


from .modeling import *