[WIP] ONNX conversion #6
base: master
Conversation
DeBERTa/apps/train.py
Outdated
with torch.no_grad():
    trainer.train_step(batch['input_ids'], batch['type_ids'], batch['position_ids'], batch['input_mask'], batch['labels'])
    # conversion fails now with:
    # site-packages/torch/onnx/utils.py:617: UserWarning: ONNX export failed on ATen operator broadcast_tensors
broadcast_tensors and mse_loss are ATen ops that are currently not implemented in the ONNX exporter. To get unblocked, functional.py needs to be modified as per the comment below.
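For context, this is a minimal sketch of how the failure can be reproduced outside the training loop (the `MseModule` wrapper and the tensor shapes are illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F

class MseModule(torch.nn.Module):
    # Tiny wrapper so F.mse_loss itself goes through the ONNX tracer.
    def forward(self, input, target):
        return F.mse_loss(input, target)

dummy_input = torch.randn(2, 8)
dummy_target = torch.randn(2, 8)
# F.mse_loss internally calls torch.broadcast_tensors, which is what
# triggers the UserWarning quoted in the diff above.
torch.onnx.export(MseModule(), (dummy_input, dummy_target), "mse.onnx")
```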
The mse_loss implementation in https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L2682 uses two ops that are not implemented in the ONNX exporter: broadcast_tensors() and the ATen mse_loss(). Working around this to get unblocked, I made a patch:
# expanded_input, expanded_target = torch.broadcast_tensors(input, target)
# broadcast_tensors is not exportable; emulate broadcasting by adding
# zeros of the other tensor's shape:
expanded_input = input + torch.zeros(target.size())
expanded_target = target + torch.zeros(input.size())
# ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
# the ATen mse_loss kernel is likewise not exportable; compute the loss
# from primitive ops instead (this covers only reduction='mean'):
t = expanded_input - expanded_target
t = t * t
ret = torch.mean(t)
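For reference, here is the same workaround as a standalone, self-contained function. The name `onnx_friendly_mse_loss` and the parity check are illustrative additions, and `torch.zeros_like` is a small deviation from the patch above to keep dtype and device consistent:

```python
import torch
import torch.nn.functional as F

def onnx_friendly_mse_loss(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Emulate torch.broadcast_tensors by adding zeros shaped like the
    # other operand; plain elementwise addition broadcasts fine in ONNX.
    expanded_input = input + torch.zeros_like(target)
    expanded_target = target + torch.zeros_like(input)
    diff = expanded_input - expanded_target
    return torch.mean(diff * diff)  # reduction='mean' only

# Quick parity check against the stock implementation:
a, b = torch.randn(4, 3), torch.randn(4, 3)
assert torch.allclose(onnx_friendly_mse_loss(a, b), F.mse_loss(a, b))
```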
self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
# Looks like the params below are never updated and stay constant, so removing them
# self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
q_bias and v_bias are always constant (initialized to zeros and never updated), so I am commenting them out.
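If the zero tensors still need to exist for downstream code, an alternative to deleting them is registering them as non-trainable buffers. This is a sketch of that idea under assumed names (`AttentionBiases`, `all_head_size`), not what the PR does:

```python
import torch

class AttentionBiases(torch.nn.Module):
    # Hypothetical module illustrating the buffer approach; all_head_size
    # stands in for the corresponding DeBERTa config value.
    def __init__(self, all_head_size: int):
        super().__init__()
        # Buffers follow .to()/.cuda() and appear in state_dict, but get
        # no gradients and add no trainable parameters to the optimizer.
        self.register_buffer("q_bias", torch.zeros(all_head_size))
        self.register_buffer("v_bias", torch.zeros(all_head_size))
```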
In previous iterations I tried to redefine StableDropout to inherit from nn.Dropout, but it led to a regression in model stats and I could not figure out why. With the change done this way there is no regression; something was apparently missing when just redefining StableDropout.
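One way to narrow down such a discrepancy is to check a replacement dropout against expectations in both eval and train mode. A hypothetical debugging sketch, where nn.Dropout stands in for the redefined StableDropout:

```python
import torch

torch.manual_seed(0)
drop = torch.nn.Dropout(p=0.1)  # stand-in for the redefined StableDropout
x = torch.randn(4, 16)

# In eval mode dropout must be an exact identity.
drop.eval()
assert torch.equal(drop(x), x)

# In train mode outputs are random, so compare statistics: inverted
# dropout scales by 1/(1-p), which keeps the mean roughly unchanged.
drop.train()
mean_est = sum(drop(x).mean().item() for _ in range(1000)) / 1000
print(mean_est, x.mean().item())  # expected to be close
```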
force-pushed from 5315a01 to c81eb40
force-pushed from 79dbe25 to e59f09f
Changes needed to convert DeBERTa to ONNX