
Inference error: assert len(logits) == num_tokens and logits[0].shape[0] == bsz #17

Open
YoungerGao opened this issue Nov 13, 2024 · 2 comments


@YoungerGao

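When running infer_A_md_brief.sh, an error is raised at line 484 of DepictQA/src/model/depictqa.py. The code on that line is: assert len(logits) == num_tokens and logits[0].shape[0] == bsz
Printing the values involved gives:
len(logits): 12
num_tokens: 11
logits[0].shape[0]: 1
bsz: 1
Is this expected, or did I do something wrong?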

@zhiyuanyou
Collaborator

This is not expected. Have you modified any of the code?

@zhiyuanyou
Collaborator

zhiyuanyou commented Nov 14, 2024

Normally, output_ids has one extra start token at the beginning, so 1 is subtracted when computing num_tokens on this line ( https://github.com/XPixelGroup/DepictQA/blob/main/src/model/depictqa.py#L475 ), and 1 is added when looking up the corresponding token here ( https://github.com/XPixelGroup/DepictQA/blob/main/src/model/depictqa.py#L487 ).
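
A minimal sketch of the off-by-one alignment described above, written in plain Python rather than copied from DepictQA. It assumes, based only on the assertion quoted in this issue, that logits holds one entry per generated token and that each entry has bsz rows of vocabulary scores. The helper name token_confidences and the has_start_token flag are hypothetical. The offset of 1 is subtracted from the length and added to the lookup index, mirroring the two lines linked above.

```python
import math
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    # Numerically stable softmax over one row of vocabulary scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def token_confidences(
    output_ids: Sequence[int],
    logits: Sequence[Sequence[Sequence[float]]],
    has_start_token: bool = True,  # hypothetical flag, not a DepictQA option
) -> List[float]:
    """Probability of each generated token (single sample, bsz == 1).

    Assumed layout: logits[i] holds the scores for generation step i,
    with bsz rows of vocabulary scores.
    """
    bsz = 1
    offset = 1 if has_start_token else 0
    # Skip the leading start token when counting generated tokens (cf. L475).
    num_tokens = len(output_ids) - offset
    assert len(logits) == num_tokens and (num_tokens == 0 or len(logits[0]) == bsz)
    confs = []
    for i in range(num_tokens):
        probs = softmax(logits[i][0])
        # Step i predicts output_ids[i + offset] (cf. the +1 at L487).
        confs.append(probs[output_ids[i + offset]])
    return confs
```

If has_start_token is True but output_ids is the same length as logits, as in the report above, num_tokens comes out one short of len(logits) and the assertion fails.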

In your case, the number of output_ids equals the length of logits. This may be caused by a code modification, or it may be a version issue.

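  1. Try changing output_confidence: true to false in the config file; inference should then run normally.
  2. Print output_ids and check whether it has the extra start token. If it does not, removing the "minus 1" and "plus 1" from the two lines mentioned above should also fix the problem.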
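
As a hedged alternative to editing the two lines by hand in option 2, the offset could be inferred from the lengths. This is a hypothetical helper, not part of DepictQA. It assumes the only two valid cases are "output_ids has exactly one extra start token" and "output_ids has no start token".

```python
def infer_start_offset(num_output_ids: int, num_logit_steps: int) -> int:
    """Return 1 if output_ids carries an extra start token, else 0.

    Hypothetical helper; assumes one logits entry per generated token.
    """
    if num_output_ids == num_logit_steps + 1:
        return 1  # output_ids begins with an extra start token
    if num_output_ids == num_logit_steps:
        return 0  # no start token: drop the "-1" and "+1"
    raise ValueError(
        f"unexpected lengths: {num_output_ids} output ids vs {num_logit_steps} logits"
    )
```

With the values reported in this issue (12 logits and 12 output ids), this returns 0, which corresponds to removing the "minus 1" and "plus 1".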
