Support Pytorch MaxP Feature/ptmaxp #184
Conversation
…; remove unnecessary sort
MSMARCO reproduction logs - nima
repro commit for both tf and pt: 56ccaeb to db5e1ee
# REF-TODO: save scheduler state along with optimizer
# self.lr_scheduler.step()
# hacky: use the step instead of the internally calculated epoch to support step-wise lr updates
self.lr_scheduler.step(epoch=cur_step)
It's a bit hacky here: by default, lr_scheduler.step takes in the epoch. We change it because when we pass epoch=0 into our lr_multiplier and the warmup iteration count is also 1, the lr would be almost 0 for the entire first epoch.
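A minimal pure-Python sketch of why this matters (`lr_multiplier` and the warmup horizon here are hypothetical stand-ins, not Capreolus's actual implementation): if the scheduler is stepped per epoch with a 1-epoch warmup, the multiplier at epoch 0 is 0 and the lr stays near zero for the whole first epoch, whereas stepping with the global step ramps the lr up within the first epoch.

```python
base_lr = 1e-3


def lr_multiplier(unit, warmup):
    """Linear warmup to 1.0 over `warmup` units (hypothetical helper)."""
    return min(1.0, unit / warmup)


# Epoch-wise stepping with warmup == 1 epoch: the multiplier at epoch 0
# is 0, so the lr is effectively 0 for the entire first epoch.
first_epoch_lr = base_lr * lr_multiplier(0, warmup=1)  # -> 0.0

# Step-wise stepping (passing cur_step, with warmup measured in steps)
# ramps the lr up smoothly inside the first epoch instead.
step_lrs = [base_lr * lr_multiplier(s, warmup=1000) for s in (0, 250, 500, 1000)]
# -> [0.0, 0.00025, 0.0005, 0.001]
```

This is why the trainer calls `self.lr_scheduler.step(epoch=cur_step)` with the step counter rather than letting the scheduler advance once per epoch.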
@@ -222,6 +207,30 @@ def parse_label_tensor(x):
    label = tf.map_fn(parse_label_tensor, parsed_example["label"], dtype=tf.float32)

    return (pos_bert_input, pos_mask, pos_seg, neg_bert_input, neg_mask, neg_seg), label

def _filter_inputs(self, bert_inputs, bert_masks, bert_segs, n_valid_psg):
Explicitly for training, this function randomly selects one passage from the n passages. This is done in the extractor now so that the PyTorch and TensorFlow trainers can both use it.
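A minimal sketch of what such a filter could look like (the argument shapes and the `filter_inputs` name here are illustrative assumptions; the real `_filter_inputs` lives in the extractor shown in the diff above):

```python
import random


def filter_inputs(bert_inputs, bert_masks, bert_segs, n_valid_psg):
    """Randomly keep one of the first `n_valid_psg` passages (sketch).

    Each argument is a list with one entry per passage slot; only the
    first `n_valid_psg` entries correspond to real (non-padding) passages.
    """
    idx = random.randrange(n_valid_psg)  # sample among valid passages only
    return bert_inputs[idx], bert_masks[idx], bert_segs[idx]


# usage: 4 passage slots, 2 of them valid (the rest are padding)
inputs = [[101, 7592], [101, 2088], [0, 0], [0, 0]]
masks = [[1, 1], [1, 1], [0, 0], [0, 0]]
segs = [[0, 0], [0, 0], [0, 0], [0, 0]]
inp, mask, seg = filter_inputs(inputs, masks, segs, n_valid_psg=2)
```

Because the sampling happens at the extractor level, both trainers receive a single passage per query and need no framework-specific filtering logic.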
…ain_feature into two MixIns (depending on whether they generate a list of passages or a single passage per query at training time), so that they can be shared by each extractor as needed
This pull request introduces 9 alerts when merging db0e405 into a568304 - view on LGTM.com
@Extractor.register
class BirchBertPassage(MultipleTrainingPassagesMixin, BertPassage):
Inherit create_train_features and parse_train_features from MultipleTrainingPassagesMixin, and the other functions from BertPassage.
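A minimal sketch of how Python's method resolution order makes this work (simplified stand-ins for illustration, not the real Capreolus classes or their actual method bodies):

```python
class BertPassage:
    """Base extractor: single-passage training features (stand-in)."""

    def create_train_features(self):
        return "single-passage features"

    def id2vec(self):
        return "shared id2vec"


class MultipleTrainingPassagesMixin:
    """Mixin that overrides only the training-feature methods (stand-in)."""

    def create_train_features(self):
        return "multi-passage features"

    def parse_train_features(self):
        return "multi-passage parsing"


class BirchBertPassage(MultipleTrainingPassagesMixin, BertPassage):
    # Listing the mixin first means its methods shadow BertPassage's
    # in the MRO, while everything else falls through to BertPassage.
    pass


b = BirchBertPassage()
```

Here `b.create_train_features()` resolves to the mixin's version, while `b.id2vec()` still comes from `BertPassage`, which is exactly the split the class definition above relies on.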
I got a reasonable dev MRR with pytorch: 0.3548
id2vec function, so that it's compatible with both tf-maxp and pt-maxp