Skip to content

Commit

Permalink
raise error with non-existence image prompts (#1322)
Browse files Browse the repository at this point in the history
* print non-existence image prompt

* reformat
  • Loading branch information
Gasoonjia authored Oct 23, 2024
1 parent 76c1cd2 commit 7fe2c86
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
Binary file added assets/view.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 24 additions & 13 deletions torchchat/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,6 @@

from PIL import Image

# torchtune model definition dependencies
from torchtune.data import Message, padded_collate_tiled_images_and_mask

from torchtune.generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer

from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
from torchtune.training import set_default_dtype

from torchchat.cli.builder import (
_initialize_model,
_initialize_tokenizer,
Expand All @@ -43,6 +34,15 @@
from torchchat.utils.build_utils import device_sync, set_precision
from torchchat.utils.device_info import get_device_info

# torchtune model definition dependencies
from torchtune.data import Message, padded_collate_tiled_images_and_mask

from torchtune.generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer

from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
from torchtune.training import set_default_dtype


class _ChatFormatter(ABC):
def __init__(self, tokenizer):
Expand Down Expand Up @@ -179,8 +179,15 @@ def from_args(cls, args):

# Validate that all image prompts exist before expensive model load
if image_prompts := getattr(args, "image_prompts", None):
if not all(os.path.exists(image_prompt) for image_prompt in image_prompts):
raise RuntimeError(f"Image prompt {image_prompt} does not exist")
non_existent_image_prompts = [
image_prompt
for image_prompt in image_prompts
if (not os.path.exists(image_prompt))
]
if len(non_existent_image_prompts):
raise RuntimeError(
f"Image prompt {non_existent_image_prompts} does not exist"
)

return cls(
prompt=getattr(args, "prompt", ""),
Expand Down Expand Up @@ -938,7 +945,8 @@ def chat(
TransformerCrossAttentionLayer,
TransformerSelfAttentionLayer,
)
decoder = self.model.model.decoder

decoder = self.model.model.decoder
for m in reversed(list(decoder.modules())):
if isinstance(m, TransformerSelfAttentionLayer) or isinstance(
m, TransformerCrossAttentionLayer
Expand Down Expand Up @@ -984,7 +992,10 @@ def chat(
# `is_torchtune_model` is a misnomer since it doesn't capture all
# torchtune models (i.e. Flamingo)
# See Issue: https://github.com/pytorch/torchchat/issues/1273
elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
elif (
not generator_args.is_torchtune_model
and self.model.config.model_type != ModelType.Flamingo
):
max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens,
(
Expand Down

0 comments on commit 7fe2c86

Please sign in to comment.