diff --git a/modeling.py b/modeling.py
index a7d719cfb..b9136c3f7 100644
--- a/modeling.py
+++ b/modeling.py
@@ -232,8 +232,6 @@ def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type,
       bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -clamp_len, clamp_len)
 
     if bsz is not None:
-      # With bi_data, the batch size should be divisible by 2.
-      assert bsz%2 == 0
       fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
       bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
     else:
@@ -472,6 +470,10 @@ def transformer_xl(inp_k, n_token, n_layer, d_model, n_head,
     mlen = tf.shape(mems[0])[0] if mems is not None else 0
     klen = mlen + qlen
 
+    # With bi_data, the batch size should be divisible by 2.
+    if bi_data:
+      assert (inp_k.shape[1] % 2) == 0
+
     ##### Attention mask
     # causal attention mask
     if attn_type == 'uni':
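
For reference, a minimal sketch of why the relocated assertion requires an even batch size: with bi_data, the forward and backward positional embeddings are each built for bsz // 2 streams and then concatenated back along the batch axis, so an odd bsz would yield bsz - 1 rows and a downstream shape mismatch. The helper name check_bi_data_batch below is hypothetical and not part of modeling.py.

def check_bi_data_batch(bsz):
  # Hypothetical helper (not in modeling.py): mirrors the integer split that
  # relative_positional_encoding performs when bi_data is enabled.
  fwd = bsz // 2  # rows tiled for fwd_pos_emb
  bwd = bsz // 2  # rows tiled for bwd_pos_emb
  # The two halves are concatenated back along the batch axis, so they must
  # add up to the original batch size.
  assert fwd + bwd == bsz, 'bi_data requires an even batch size, got %d' % bsz

check_bi_data_batch(8)    # passes
# check_bi_data_batch(7)  # would fail: 3 + 3 != 7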