RF cross_entropy (matmul, gather) should maybe have allow_broadcast? #1636

Open
albertz opened this issue Oct 19, 2024 · 0 comments

albertz (Member) commented Oct 19, 2024

I had this bug:

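```python
import returnn.frontend as rf

log_prob = ...  # [B,T+1,D], log-probabilities over classes D
targets = ...  # [B,T], sparse class indices into D
loss = rf.cross_entropy(target=targets, estimated=log_prob, ...)
loss.mark_as_loss(...)
```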

This gives no error; it just runs. I only noticed in my log that training did not seem to converge, but otherwise everything looked reasonable.

The bug: `loss` here has dims [B,T+1,T]. Since T and T+1 are different dims, they are not matched against each other but broadcast.
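To make the failure mode concrete, here is a minimal pure-Python sketch of the dim bookkeeping for a sparse-target cross entropy. It is not the RF implementation: dims are modeled as plain strings matched by name (standing in for RF `Dim` objects, which are matched by identity), and the helper `sparse_ce_out_dims` is hypothetical. Because `T` and `T+1` are different dims, neither is merged away and both end up in the output.

```python
from typing import Sequence, Tuple


def sparse_ce_out_dims(
    estimated_dims: Sequence[str], target_dims: Sequence[str], axis: str
) -> Tuple[str, ...]:
    """Hypothetical helper: output dims of a gather-based cross entropy.

    The class axis is removed from estimated_dims, then every target dim
    not already present is appended, i.e. the two are broadcast together.
    """
    if axis not in estimated_dims:
        raise ValueError(f"class axis {axis!r} not in {tuple(estimated_dims)}")
    out = [d for d in estimated_dims if d != axis]
    out += [d for d in target_dims if d not in out]
    return tuple(out)


# Intended case: the time dims match, so the result is [B,T].
print(sparse_ce_out_dims(["B", "T", "D"], ["B", "T"], axis="D"))  # ('B', 'T')
# The bug: T+1 and T are different dims, so both are kept.
print(sparse_ce_out_dims(["B", "T+1", "D"], ["B", "T"], axis="D"))  # ('B', 'T+1', 'T')
```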

Other functions such as `rf.combine`, `rf.compare`, `rf.clip_by_value` and `rf.where` have the argument `allow_broadcast_all_sources`, and `rf.concat` has `allow_broadcast`. In all of these, the default is False.
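Roughly, `allow_broadcast_all_sources=False` means that at least one input must already have all the output dims, so something like [A] + [B] -> [A,B], where every input gets broadcast, is rejected. A minimal pure-Python sketch of that check, assuming dims are plain strings and using a hypothetical helper `broadcast_out_dims` (the real RF check may differ in detail):

```python
from typing import List, Sequence, Tuple


def broadcast_out_dims(
    *sources: Sequence[str], allow_broadcast_all_sources: bool = False
) -> Tuple[str, ...]:
    """Sketch: output dims of an elementwise op with a broadcast check.

    The output has the union of all source dims (in first-seen order).
    Unless allow_broadcast_all_sources is set, at least one source must
    already have all output dims, i.e. not every source may be broadcast.
    """
    out: List[str] = []
    for dims in sources:
        out += [d for d in dims if d not in out]
    if not allow_broadcast_all_sources and not any(set(dims) == set(out) for dims in sources):
        raise ValueError(f"all sources {sources} require broadcasting to {tuple(out)}")
    return tuple(out)


print(broadcast_out_dims(["B", "T", "F"], ["F"]))  # ('B', 'T', 'F'), only [F] is broadcast
try:
    broadcast_out_dims(["B", "T+1"], ["B", "T"])
except ValueError as exc:
    print("error:", exc)
print(broadcast_out_dims(["B", "T+1"], ["B", "T"], allow_broadcast_all_sources=True))
```

Applied to the non-class dims from the bug, [B,T+1] and [B,T], this check fails by default, which is exactly the kind of error that would have caught it.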

Note that, in the general case, `rf.cross_entropy` is implemented with just `rf.gather` or `rf.matmul`, and neither of them has such an argument. For matmul and gather, there are also many valid cases where broadcasting is wanted (although "broadcasting" may not be the right term there). So if we added such an argument to them, its default should maybe be True.
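Presumably the two code paths correspond to the two target types: for a sparse target, cross entropy is the negative log-probability at the target index (a gather along the class axis), and for a dense target distribution it is a dot product over the class axis (a matmul contracting D). A toy pure-Python sketch for a single frame with made-up numbers, showing that both forms agree when the dense target is one-hot (`ce_sparse` and `ce_dense` are hypothetical names, not RF functions):

```python
import math
from typing import Sequence


def ce_sparse(log_prob: Sequence[float], target: int) -> float:
    """Gather form: pick the log-prob of the target class."""
    return -log_prob[target]


def ce_dense(log_prob: Sequence[float], target_probs: Sequence[float]) -> float:
    """Matmul form: contract the class axis D with a target distribution."""
    return -sum(p * lp for p, lp in zip(target_probs, log_prob))


log_prob = [math.log(0.7), math.log(0.2), math.log(0.1)]  # D = 3 classes
one_hot = [0.0, 1.0, 0.0]  # dense version of sparse target 1
print(ce_sparse(log_prob, 1), ce_dense(log_prob, one_hot))  # both equal -log(0.2)
```

In the batched case, any dim that appears only in `log_prob` or only in the target is carried through to the output by both gather and matmul, which is where the unintended [B,T+1,T] comes from.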

I'm not sure there are many valid cases where broadcasting makes sense for `rf.cross_entropy`, so here the default should rather be False. I think this would break almost no setups, except ones like my buggy code above, which was unintentionally wrong anyway.
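One possible form of this proposal, sketched in pure Python with hypothetical names (`cross_entropy_out_dims` and its `allow_broadcast` flag are not existing RF API): by default, the non-class dims of `estimated` must match the dims of the sparse `target` exactly, which would have turned the bug into an error. Whether the real check should be this strict, or should only reject the case where both sides need broadcasting, is an open design choice.

```python
from typing import Sequence, Tuple


def cross_entropy_out_dims(
    estimated_dims: Sequence[str],
    target_dims: Sequence[str],
    axis: str,
    allow_broadcast: bool = False,
) -> Tuple[str, ...]:
    """Hypothetical: output dims of a sparse-target cross entropy, broadcasting opt-in."""
    other_dims = [d for d in estimated_dims if d != axis]
    if not allow_broadcast and set(other_dims) != set(target_dims):
        raise ValueError(
            f"estimated dims {tuple(other_dims)} (without class axis {axis!r}) "
            f"do not match target dims {tuple(target_dims)}; "
            "pass allow_broadcast=True if this is intended"
        )
    # Same result as before when broadcasting is explicitly allowed.
    return tuple(other_dims + [d for d in target_dims if d not in other_dims])


print(cross_entropy_out_dims(["B", "T", "D"], ["B", "T"], axis="D"))  # ('B', 'T')
try:
    cross_entropy_out_dims(["B", "T+1", "D"], ["B", "T"], axis="D")  # the buggy case
except ValueError as exc:
    print("error:", exc)
```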

But I'm not sure.
