We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
执行infer_A_md_brief.sh时,在DepictQA/src/model/depictqa.py文件484行处报错,该处代码为:assert len(logits) == num_tokens and logits[0].shape[0] == bsz 打印几个长度分别为: len(logits):12 num_tokens:11 logits[0].shape[0]:1 bsz:1 这个正常么?还是我哪里操作不当?
The text was updated successfully, but these errors were encountered:
不正常,你有修改过什么代码嘛?
Sorry, something went wrong.
正常来说,output_ids会在开头多一个start token,所以在这一行计算num_tokens时需要减去1 ( https://github.com/XPixelGroup/DepictQA/blob/main/src/model/depictqa.py#L475 ),以及这里找对应token时候需要加上1( https://github.com/XPixelGroup/DepictQA/blob/main/src/model/depictqa.py#L487 )。
你这里output_ids的个数和logits长度是一致的。可能是修改代码造成的,也可能是版本问题。
No branches or pull requests
执行infer_A_md_brief.sh时,在DepictQA/src/model/depictqa.py文件484行处报错,该处代码为:assert len(logits) == num_tokens and logits[0].shape[0] == bsz
打印几个长度分别为:
len(logits):12
num_tokens:11
logits[0].shape[0]:1
bsz:1
这个正常么?还是我哪里操作不当?
The text was updated successfully, but these errors were encountered: