Skip to content

Commit

Permalink
Add XLnet
Browse files Browse the repository at this point in the history
  • Loading branch information
VietHoang1512 committed Jul 1, 2021
1 parent 8b799ed commit 9ae08dd
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 34 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Given a pair of key point and argument (along with their supported topic & stanc
| Model | BERT/ConvBERT | DistilBERT | ALBERT | XLNet | RoBERTa | ELECTRA | BART |
| ------------------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
| Siamese Baseline | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Siamese Question Answering-like | ✔️ | ✔️ | ✔️ | | ✔️ | ✔️ | ✔️ |
| Siamese Question Answering-like | ✔️ | ✔️ | ✔️ |✔️| ✔️ | ✔️ | ✔️ |
| Custom loss Baseline | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |

#### Loss
Expand Down
2 changes: 1 addition & 1 deletion bin/train_pseudo_label.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ echo "OUTPUT DIRECTORY $OUTPUT_DIR"

mkdir -p $OUTPUT_DIR

cp pseudo_label/models.py $OUTPUT_DIR
cp qs_kpa/pseudo_label/models.py $OUTPUT_DIR

for fold_id in 1 2 3 4
do
Expand Down
13 changes: 8 additions & 5 deletions qs_kpa/backbone/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ def _tokenize(self, text: str, max_len: int) -> Tuple[torch.Tensor, torch.Tensor
token_type_ids = torch.tensor(inputs["token_type_ids"], dtype=torch.long)

if self.tokenizer_type in ["xlnet"]:
input_ids = torch.squeeze(input_ids, 1)
attention_mask = torch.squeeze(attention_mask, 1)
token_type_ids = torch.squeeze(token_type_ids, 1)
if input_ids.size(0) > 1:
logger.warning(f"String `{text}` is truncated with maximum length {max_len}")
input_ids = input_ids[0]
attention_mask = attention_mask[0]
token_type_ids = token_type_ids[0]
else:
if inputs.get("num_truncated_tokens", 0) > 0:
logger.warning(f"String `{text}` is truncated with maximum length {max_len}")

if len(inputs.get("overflowing_tokens", [])) > 0:
logger.warning(f"String `{text}` is truncated with maximum length {max_len}")
return input_ids, attention_mask, token_type_ids

def _process_data(self, df: pd.DataFrame) -> List[Dict]:
Expand Down
24 changes: 12 additions & 12 deletions qs_kpa/backbone/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def _init_weights(self, module: nn.Module):
module.weight.data.fill_(1.0)

def _forward_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor):
if self.model_type in ["xlnet"]:
input_ids = torch.squeeze(input_ids, 1)
attention_mask = torch.squeeze(attention_mask, 1)
token_type_ids = torch.squeeze(token_type_ids, 1)
# if self.model_type in ["xlnet"]:
# input_ids = torch.squeeze(input_ids, 1)
# attention_mask = torch.squeeze(attention_mask, 1)
# token_type_ids = torch.squeeze(token_type_ids, 1)
if self.model_type in ["t5", "distilbert", "electra", "bart", "xlm", "xlnet", "camembert", "longformer"]:
output = self.bert_model(input_ids, attention_mask=attention_mask)
else:
Expand All @@ -76,14 +76,14 @@ def _forward_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, t
if self.model_type == "bart":
hidden_states_key = "decoder_hidden_states"

if self.model_type == "xlnet" and (not self.training):
# FIXME: XLnet behaves differenctly between train-eval
output = torch.cat(
[torch.transpose(output[hidden_states_key][-i], 0, 1)[:, 0, :] for i in range(self.n_hiddens)],
axis=-1,
)
else:
output = torch.cat([output[hidden_states_key][-i][:, 0, :] for i in range(self.n_hiddens)], axis=-1)
# if self.model_type == "xlnet" and (not self.training):
# # FIXME: XLnet behaves differenctly between train-eval
# output = torch.cat(
# [torch.transpose(output[hidden_states_key][-i], 0, 1)[:, 0, :] for i in range(self.n_hiddens)],
# axis=-1,
# )
# else:
output = torch.cat([output[hidden_states_key][-i][:, 0, :] for i in range(self.n_hiddens)], axis=-1)
else:
output = output["last_hidden_state"][:, 0, :]

Expand Down
46 changes: 32 additions & 14 deletions qs_kpa/question_answering/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,37 @@ def __getitem__(self, idx) -> Dict[str, List]:

stance = torch.tensor([self.stance[idx]], dtype=torch.float)

key_point_input_ids = torch.tensor(encoded_key_point["input_ids"], dtype=torch.long)
key_point_attention_mask = torch.tensor(encoded_key_point["attention_mask"], dtype=torch.long)
key_point_token_type_ids = torch.tensor(encoded_key_point["token_type_ids"], dtype=torch.long)
argument_input_ids = torch.tensor(encoded_argument["input_ids"], dtype=torch.long)
argument_attention_mask = torch.tensor(encoded_argument["attention_mask"], dtype=torch.long)
argument_token_type_ids = torch.tensor(encoded_argument["token_type_ids"], dtype=torch.long)

if self.tokenizer_type in ["xlnet"]:
if key_point_input_ids.size(0) > 1:
logger.warning(
f"Topic `{topic}` and {key_point} is truncated with maximum length {self.max_topic_length} and {self.max_statement_length}"
)
key_point_input_ids = key_point_input_ids[0]
key_point_attention_mask = key_point_attention_mask[0]
key_point_token_type_ids = key_point_token_type_ids[0]

if argument_input_ids.size(0) > 1:
logger.warning(
f"Topic `{topic}` and {argument} is truncated with maximum length {self.max_topic_length} and {self.max_statement_length}"
)
argument_input_ids = argument_input_ids[0]
argument_attention_mask = argument_attention_mask[0]
argument_token_type_ids = argument_token_type_ids[0]

sample = {
"key_point_input_ids": torch.tensor(encoded_key_point["input_ids"], dtype=torch.long),
"key_point_attention_mask": torch.tensor(encoded_key_point["attention_mask"], dtype=torch.long),
"key_point_token_type_ids": torch.tensor(encoded_key_point["token_type_ids"], dtype=torch.long),
"argument_input_ids": torch.tensor(encoded_argument["input_ids"], dtype=torch.long),
"argument_attention_mask": torch.tensor(encoded_argument["attention_mask"], dtype=torch.long),
"argument_token_type_ids": torch.tensor(encoded_argument["token_type_ids"], dtype=torch.long),
"key_point_input_ids": key_point_input_ids,
"key_point_attention_mask": key_point_attention_mask,
"key_point_token_type_ids": key_point_token_type_ids,
"argument_input_ids": argument_input_ids,
"argument_attention_mask": argument_attention_mask,
"argument_token_type_ids": argument_token_type_ids,
"stance": stance,
"label": torch.tensor(self.label[idx], dtype=torch.float),
}
Expand Down Expand Up @@ -135,10 +159,7 @@ def _process(
return_token_type_ids=True,
return_overflowing_tokens=True,
)
if (
len(encoded_key_point.get("overflowing_tokens", [])) > 0
or len(encoded_argument.get("overflowing_tokens", [])) > 0
):
if encoded_key_point.get("num_truncated_tokens", 0) > 0 or encoded_argument.get("num_truncated_tokens", 0) > 0:
logger.warning(f"String is truncated with maximum length {max_length}")

return encoded_key_point, encoded_argument
Expand Down Expand Up @@ -273,10 +294,7 @@ def _process(
return_token_type_ids=True,
return_overflowing_tokens=True,
)
if (
len(encoded_key_point.get("overflowing_tokens", [])) > 0
or len(encoded_argument.get("overflowing_tokens", [])) > 0
):
if encoded_key_point.get("num_truncated_tokens", 0) > 0 or encoded_argument.get("num_truncated_tokens", 0) > 0:
logger.warning(f"String is truncated with maximum length {max_length}")

return encoded_key_point, encoded_argument
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setuptools.setup(
name="keypoint-analysis",
version="1.0.1",
version="1.0.0",
description="Quantitative Summarization – Key Point Analysis",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 9ae08dd

Please sign in to comment.