You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I am trying to get a model summary of the donut model but am unable to define the input for the torch summary.
###########################################################
import argparse
import gradio as gr
import torch
from PIL import Image
from donut.donut.model import DonutModel
from torchvision import models
from torchsummary import summary
I am trying to get a model summary of the donut model but am unable to define the input for the torch summary.
###########################################################
import argparse
import gradio as gr
import torch
from PIL import Image
from donut.donut.model import DonutModel
from torchvision import models
from torchsummary import summary
def demo_process_vqa(input_img, question):
    """Answer ``question`` about ``input_img`` with the loaded model.

    The question replaces the ``{user_input}`` placeholder in the
    module-level ``task_prompt``. The filled prompt and the selected
    prediction are both echoed to stdout.

    Args:
        input_img: Array-like image accepted by ``PIL.Image.fromarray``.
        question: Text substituted into the prompt template.

    Returns:
        The first entry of the ``"predictions"`` list returned by
        ``pretrained_model.inference``.
    """
    pil_image = Image.fromarray(input_img)
    filled_prompt = task_prompt.replace("{user_input}", question)
    print(filled_prompt)

    result = pretrained_model.inference(pil_image, prompt=filled_prompt)
    answer = result["predictions"][0]
    print('inf_out', answer)
    return answer
def demo_process(input_img):
    """Run the loaded model on ``input_img`` using the fixed task prompt.

    Args:
        input_img: Array-like image accepted by ``PIL.Image.fromarray``.

    Returns:
        The first entry of the ``"predictions"`` list returned by
        ``pretrained_model.inference`` when called with the module-level
        ``task_prompt``.
    """
    pil_image = Image.fromarray(input_img)
    result = pretrained_model.inference(image=pil_image, prompt=task_prompt)
    return result["predictions"][0]
# Command-line options. Arguments the parser does not recognise are kept in
# ``left_argv`` instead of raising an error.
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="docvqa")
parser.add_argument(
    "--pretrained_path",
    type=str,
    default="train_docvqa_for_all_atts/donut/result/train_docvqa/20220915_125713",
)
args, left_argv = parser.parse_known_args()

# The prompt template depends on the task: the docvqa template carries a
# ``{user_input}`` placeholder that demo_process_vqa fills in per request;
# every other task (rvlcdip, cord, ...) uses a bare ``<s_TASK>`` start token.
task_name = args.task
task_prompt = (
    "<s_taco_eiko_pdf_donut>{user_input}</s_question><s_answer>"
    if task_name == "docvqa"
    else f"<s_{task_name}>"
)

pretrained_model = DonutModel.from_pretrained(args.pretrained_path)

if not torch.cuda.is_available():
    # No GPU: only the encoder is cast to bfloat16.
    pretrained_model.encoder.to(torch.bfloat16)
else:
    # pretrained_model.half()  (half-precision cast left disabled)
    device = torch.device("cuda")
    pretrained_model.to(device)

# NOTE(review): torchsummary appears to prepend its own batch dimension to
# each entry of ``input_size`` and to generate random *float* tensors. If so,
# the leading 1 in each tuple adds an extra dimension and the token-id inputs
# would not be integer typed -- confirm against the torchsummary docs.
summary(pretrained_model, [(1, 3, 1280, 960), (1, 21), (1, 21)])
The shapes of the encoder and decoder inputs are as follows.
Encoder : torch.Size([1, 3, 1280, 960])
Decoder : torch.Size([1, 21])
##Model forward architecture looks like this
Could you please explain how to pass the model inputs to `summary`?
The text was updated successfully, but these errors were encountered: