Skip to content

Commit

Permalink
Trust remote code param (#385)
Browse files Browse the repository at this point in the history
* 所有涉及到hf_model的算子,都加了一个trust_remote_code的参数并且传递给prepare_model函数

* 进行了pre-commit检查

* trust_remote_code

---------

Co-authored-by: Zheng Chaoxu <[email protected]>
Co-authored-by: zhengchaoxu <[email protected]>
  • Loading branch information
3 people authored Aug 15, 2024
1 parent 986b43d commit 7a00933
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ def __init__(self, config: Blip2Config) -> None:


def prepare_simple_aesthetics_model(pretrained_model_name_or_path,
return_model=True):
return_model=True,
trust_remote_code=False):
"""
Prepare and load a simple aesthetics model.
Expand All @@ -344,21 +345,25 @@ def prepare_simple_aesthetics_model(pretrained_model_name_or_path,
AestheticsPredictorV2ReLU)
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path)
processor = CLIPProcessor.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
if not return_model:
return processor
else:
if 'v1' in pretrained_model_name_or_path:
model = AestheticsPredictorV1.from_pretrained(
pretrained_model_name_or_path)
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code)
elif ('v2' in pretrained_model_name_or_path
and 'linear' in pretrained_model_name_or_path):
model = AestheticsPredictorV2Linear.from_pretrained(
pretrained_model_name_or_path)
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code)
elif ('v2' in pretrained_model_name_or_path
and 'relu' in pretrained_model_name_or_path):
model = AestheticsPredictorV2ReLU.from_pretrained(
pretrained_model_name_or_path)
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code)
else:
raise ValueError(
'Not support {}'.format(pretrained_model_name_or_path))
Expand Down Expand Up @@ -439,7 +444,8 @@ def decompress_model(compressed_model_path):
def prepare_diffusion_model(pretrained_model_name_or_path,
diffusion_type,
torch_dtype='fp32',
revision='main'):
revision='main',
trust_remote_code=False):
"""
Prepare and load an Diffusion model from HuggingFace.
Expand Down Expand Up @@ -493,7 +499,8 @@ def prepare_diffusion_model(pretrained_model_name_or_path,

model = pipeline.from_pretrained(pretrained_model_name_or_path,
revision=revision,
torch_dtype=torch_dtype)
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code)

return model

Expand Down

0 comments on commit 7a00933

Please sign in to comment.