Merge pull request #17 from VietHoang1512/demo
Python package release
duclong1009 authored Jul 1, 2021
2 parents d39fc57 + 80e81b3 commit 7cdb840
Showing 13 changed files with 123 additions and 58 deletions.
93 changes: 76 additions & 17 deletions README.md
@@ -11,15 +11,57 @@

</div>

## Installation


## Keypoint Analysis

This library is based on the Transformers library by HuggingFace. **Keypoint Analysis** quickly embeds statements together with their topic and their stance toward that topic.

### What's New

#### July 1, 2021

- First release of [keypoint-analysis](https://pypi.org/project/keypoint-analysis/) python package

### Installation

```bash
pip install keypoint-analysis
```

## TODO:
### Quick example

```python
# Import needed libraries
from qs_kpa import KeyPointAnalysis

# Create a KeyPointAnalysis model
encoder = KeyPointAnalysis()

# Model configuration
print(encoder)

# Prepare the data: a single (topic, statement, stance) tuple or a list of such tuples
inputs = [
(
"Assisted suicide should be a criminal offence",
"a cure or treatment may be discovered shortly after having ended someone's life unnecessarily.",
1,
),
(
"Assisted suicide should be a criminal offence",
"Assisted suicide should not be allowed because many times people can still get better",
1,
),
("Assisted suicide should be a criminal offence", "Assisted suicide is akin to killing someone", 1),
]

# Embed everything
output = encoder.encode(inputs, convert_to_numpy=True)
```
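
This commit also adds a `to` method to `KeyPointAnalysis` (see the class diff further down), so the encoder can be moved to another device before encoding. A minimal sketch, assuming a CUDA device is available:

```python
# Move the encoder to GPU, then embed a single example;
# `show_progress_bar` is the keyword used in test.py below.
encoder.to("cuda")

single = encoder.encode(inputs[0], show_progress_bar=False, convert_to_numpy=True)
print("Embedding shape", single.shape)
```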

### Detailed training

### Pair-wise keypoint-argument
Given a pair of key point and argument (along with their topic & stance) and their matching score, similar pairs (label 1) are pulled together in the embedding space, while dissimilar pairs are pushed apart.
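
To make the objective concrete, here is a minimal contrastive-loss sketch in PyTorch; the function name, the cosine distance, and the margin value are illustrative assumptions, not the library's exact code:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(kp_emb: torch.Tensor, arg_emb: torch.Tensor,
                     label: torch.Tensor, margin: float = 0.5) -> torch.Tensor:
    """label is 1 for matching key point/argument pairs, 0 otherwise."""
    distance = 1.0 - F.cosine_similarity(kp_emb, arg_emb)      # cosine distance
    pulled = label * distance.pow(2)                           # pull matches together
    pushed = (1 - label) * F.relu(margin - distance).pow(2)    # push non-matches apart
    return 0.5 * (pulled + pushed).mean()
```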

#### Model
@@ -32,29 +74,46 @@

#### Loss

- [x] Contrastive
- [x] Online Contrastive
- [x] Triplet
- [x] Online Triplet (Hard negative/positive mining)
- Contrastive
- Online Contrastive
- Triplet
- Online Triplet (Hard negative/positive mining; see the sketch below)
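
A minimal sketch of the online (batch-hard) mining idea behind the last variant; the function name and margin are illustrative, not the library's exact implementation:

```python
import torch

def batch_hard_triplet_loss(embeddings: torch.Tensor,
                            labels: torch.Tensor,
                            margin: float = 0.5) -> torch.Tensor:
    # All pairwise Euclidean distances within the batch.
    dist = torch.cdist(embeddings, embeddings)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    # Hardest positive: the farthest embedding sharing the anchor's label.
    hardest_pos = (dist * same.float()).max(dim=1).values
    # Hardest negative: the closest embedding with a different label.
    hardest_neg = dist.masked_fill(same, float("inf")).min(dim=1).values
    return torch.relu(hardest_pos - hardest_neg + margin).mean()
```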

#### Distance

- [x] Euclidean
- [x] Cosine
- [x] Manhattan
- Euclidean
- Cosine
- Manhattan
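
For reference, the three distances could be written as follows; this is a sketch, and the library's own implementations may differ:

```python
import torch.nn.functional as F

def euclidean(a, b):
    return F.pairwise_distance(a, b, p=2)

def manhattan(a, b):
    return F.pairwise_distance(a, b, p=1)

def cosine(a, b):
    # cosine *distance*, i.e. 1 minus cosine similarity
    return 1 - F.cosine_similarity(a, b)
```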

#### Utils

- K-folds
- Full-flow

### Pseudo-label

Group the arguments by their key point and consider the order of that key point within the topic as their labels (see [pseudo_label](src/pseudo_label)). We can now utilize the distances, losses, miners and reducers from the great open-source [pytorch-metric-learning](https://github.com/KevinMusgrave/pytorch-metric-learning) library in the main training workflow. This is also our best (single-model) approach so far.
Group the arguments by their key point and consider the order of that key point within the topic as their labels (see [pseudo_label](qs_kpa/pseudo_label)). We can now utilize the distances, losses, miners and reducers from the great open-source [pytorch-metric-learning](https://github.com/KevinMusgrave/pytorch-metric-learning) library in the main training workflow. This is also our best (single-model) approach so far.
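
A toy sketch of this labeling scheme (column names here are hypothetical, not the actual ArgKP field names):

```python
import pandas as pd

# Toy rows; real training uses the ArgKP dataset.
df = pd.DataFrame({
    "topic": ["T1", "T1", "T1", "T2"],
    "key_point": ["kp_a", "kp_b", "kp_a", "kp_c"],
    "argument": ["arg 1", "arg 2", "arg 3", "arg 4"],
})

# The order in which a key point first appears within its topic
# becomes the pseudo label for every argument grouped under it.
df["pseudo_label"] = df.groupby("topic")["key_point"].transform(
    lambda kp: pd.factorize(kp)[0]
)
```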

![Model architecture](assets/model.png "Model architecture")
![Model architecture](https://user-images.githubusercontent.com/52401767/124059293-0ec81100-da55-11eb-94a4-cf9914479a78.png)

### Utils
### Training data

- [x] K-folds
- [x] Full-flow
**ArgKP** dataset ([Bar-Haim et al., ACL-2020](https://www.aclweb.org/anthology/2020.acl-main.371.pdf))

## Contributors
### Contributors

- Phan Viet Hoang
- Nguyen Duc Long
- Nguyen Duc Long

### BibTeX

```bibtex
@misc{hoang2021qskpa,
author = {Phan, Viet Hoang and Nguyen, Duc Long},
title = {Keypoint Analysis},
year = {2021},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/VietHoang1512/KPA}}
}
```
2 changes: 1 addition & 1 deletion bin/train_baselines.sh
@@ -1,7 +1,7 @@
export PYTHONPATH=$PWD
for fold_id in 1 2 3 4
do
python src/scripts/main.py \
python scripts/main.py \
--experiment "baseline" \
--output_dir "outputs/baselines/fold_$fold_id" \
--model_name_or_path "roberta-base" \
2 changes: 1 addition & 1 deletion bin/train_mixed.sh
@@ -2,7 +2,7 @@ export PYTHONPATH=$PWD
for fold_id in 1 2 3 4
do
echo "TRAINING ON FOLD $fold_id"
python src/scripts/main.py \
python scripts/main.py \
--experiment "mixed" \
--output_dir "outputs/mixed/fold_$fold_id" \
--model_name_or_path "roberta-base" \
6 changes: 3 additions & 3 deletions bin/train_pseudo_label.sh
@@ -4,12 +4,12 @@ echo "OUTPUT DIRECTORY $OUTPUT_DIR"

mkdir -p $OUTPUT_DIR

cp src/pseudo_label/models.py $OUTPUT_DIR
cp pseudo_label/models.py $OUTPUT_DIR

for fold_id in 1 2 3 4
do
echo "TRAINING ON FOLD $fold_id"
python src/scripts/main.py \
python scripts/main.py \
--experiment "pseudolabel" \
--output_dir "$OUTPUT_DIR/fold_$fold_id" \
--model_name_or_path roberta-base \
@@ -46,7 +46,7 @@ do
done

echo "INFERENCE"
python src/scripts/main.py \
python scripts/main.py \
--experiment "pseudolabel" \
--output_dir "$OUTPUT_DIR" \
--model_name_or_path roberta-base \
2 changes: 1 addition & 1 deletion bin/train_qa.sh
@@ -1,7 +1,7 @@
export PYTHONPATH=$PWD
for fold_id in 1 2 3 4
do
python src/scripts/main.py \
python scripts/main.py \
--experiment "qa" \
--output_dir "outputs/qa/fold_$fold_id" \
--model_name_or_path "roberta-base" \
2 changes: 1 addition & 1 deletion bin/train_triplet.sh
@@ -2,7 +2,7 @@ export PYTHONPATH=$PWD
for fold_id in 1 2 3 4
do
echo "TRAINING ON FOLD $fold_id"
python src/scripts/main.py \
python scripts/main.py \
--experiment "triplet" \
--output_dir "outputs/triplet/fold_$fold_id" \
--model_name_or_path "roberta-base" \
14 changes: 14 additions & 0 deletions qs_kpa/KeyPointAnalysis.py
@@ -79,6 +79,15 @@ def __init__(
self.model.to(self.device)
self.model.eval()

def __str__(self):
return self.__repr__()

def __repr__(self):
s = ""
s += f"Device: {self.device} \n"
s += f"Backbone configuration:\n{self.model}"
return s

@classmethod
def _download_and_cache(self, model_path: str) -> None:
gdown.download(URL, model_path, quiet=False)
@@ -89,6 +98,11 @@ def _load_model(self, model_path: str, model: PseudoLabelModel) -> None:
model.load_state_dict(torch.load(model_path))
logger.info(f"Loaded model from {model_path}")

def to(self, device: str):
self.device = torch.device(device)
self.model.to(self.device)
self.model.eval()

def encode(
self,
examples: Union[List[Tuple[str, str, int]], Tuple[str, str, int]],
2 changes: 1 addition & 1 deletion qs_kpa/backbone/base_arguments.py
@@ -53,7 +53,7 @@ class BaseModelArguments:
text_dim: int = field(default=256, metadata={"help": "Hidden representation dimension used for encoding text."})

distance: str = field(
default="cosine", metadata={"help": "Function that returns a distance between two emeddings."}
default="cosine", metadata={"help": "Function that returns a distance between two embeddings."}
)

normalize: bool = field(default=False, metadata={"help": "Whether to normalize the embedding vectors or not."})
2 changes: 1 addition & 1 deletion qs_kpa/pseudo_label/model_argument.py
@@ -8,4 +8,4 @@ class PseudoLabelModelArguments(BaseModelArguments):

"""Pseudo Label Arguments pertaining to which model/config/tokenizer we are going to fine- tune, or train from scratch."""

margin: float = field(default=0.5, metadata={"help": "Margin distance value for circle loss."})
margin: float = field(default=0.3, metadata={"help": "Margin distance value for circle loss."})
16 changes: 16 additions & 0 deletions qs_kpa/pseudo_label/models.py
@@ -120,3 +120,19 @@ def get_embeddings(
statement_rep = F.normalize(statement_rep, p=2, dim=1)

return statement_rep

def __str__(self):
return self.__repr__()

def __repr__(self):
s = ""
s += f"\tBackbone: {self.args.model_name_or_path if self.args.model_name_or_path else self.args.model_name}\n"
s += f"\tNumber of hidden state: {self.args.n_hiddens}\n"
s += f"\tDropout rate: {self.args.drop_rate}\n"
s += f"\tUse batch normalization: {self.args.batch_norm}\n"
s += f"\tHidden representation dimension used for encoding stance: {self.args.stance_dim}\n"
s += f"\tHidden representation dimension used for encoding text: {self.args.text_dim}\n"
s += f"\tDistance metric: {self.args.distance}\n"
s += f"\tNormalize embedding: {self.args.normalize}\n"
s += f"\tMargin: {self.args.margin}"
return s
35 changes: 5 additions & 30 deletions qs_kpa/utils/hf_argparser.py
@@ -76,38 +76,13 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
def parse_args_into_dataclasses(
self, args=None, return_remaining_strings=False, look_for_args_file=True
) -> Tuple[DataClass, ...]:
"""
Parse command-line args into instances of the specified dataclass types.
This relies on argparse's `ArgumentParser.parse_known_args`.
See the doc at:
docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
Args:
args:
List of strings to parse. The default is taken from sys.argv.
(same as argparse.ArgumentParser)
return_remaining_strings:
If true, also return a list of remaining argument strings.
look_for_args_file:
If true, will look for a ".args" file with the same base name
as the entry point script for this process, and will append its
potential content to the command line args.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they
were passed to the initializer.abspath
- if applicable, an additional namespace for more
(non-dataclass backed) arguments added to the parser
after initialization.
- The potential list of remaining argument strings.
(same as argparse.ArgumentParser.parse_known_args)
"""

if look_for_args_file and len(sys.argv):
args_file = Path(sys.argv[0]).with_suffix(".args")
if args_file.exists():
fargs = args_file.read_text().split()
args = fargs + args if args is not None else fargs + sys.argv[1:]
# in case of duplicate arguments the first one has precedence
# so we append rather than prepend.

namespace, remaining_args = self.parse_known_args(args=args)
outputs = []
for dtype in self.dataclass_types:
@@ -142,9 +117,9 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:


if __name__ == "__main__":
from src.baselines.data_argument import DataArguments
from src.baselines.model_argument import ModelArguments
from src.train_utils.training_argument import TrainingArguments
from qs_kpa.baselines.data_argument import DataArguments
from qs_kpa.baselines.model_argument import ModelArguments
from qs_kpa.train_utils.training_argument import TrainingArguments

parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@

setuptools.setup(
name="keypoint-analysis",
version="0.0.1",
version="1.0.1",
description="Quantitative Summarization – Key Point Analysis",
long_description=long_description,
long_description_content_type="text/markdown",
3 changes: 2 additions & 1 deletion test.py
@@ -5,6 +5,7 @@
if __name__ == "__main__":

encoder = KeyPointAnalysis(from_pretrained=False)
print(encoder)
inputs = [
(
"Assisted suicide should be a criminal offence",
@@ -19,7 +20,7 @@
("Assisted suicide should be a criminal offence", "Assisted suicide is akin to killing someone", 1),
]

output = encoder.encode(inputs[0], convert_to_numpy=True)
output = encoder.encode(inputs[0], show_progress_bar=False, convert_to_numpy=True)
print("Embedding shape", output.shape)

output = encoder.encode(inputs, convert_to_numpy=True)
