Skip to content

Commit

Permalink
updated data processor and metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
thomwolf committed Sep 24, 2019
1 parent 0b82e3d commit b5ec526
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 155 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,5 @@ runs
examples/runs

# data
data
/data
serialization_dir
7 changes: 5 additions & 2 deletions examples/run_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@

from pytorch_transformers import AdamW, WarmupLinearSchedule

from pytorch_transformers.preprocessing import (compute_metrics, output_modes, processors, convert_examples_to_glue_features)
from pytorch_transformers import glue_compute_metrics as compute_metrics
from pytorch_transformers import glue_output_modes as output_modes
from pytorch_transformers import glue_processors as processors
from pytorch_transformers import glue_convert_examples_to_features as convert_examples_to_features

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -275,7 +278,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1]
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_glue_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
Expand Down
7 changes: 7 additions & 0 deletions pytorch_transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,10 @@
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)

from .data import (is_sklearn_available,
InputExample, InputFeatures, DataProcessor,
glue_output_modes, glue_convert_examples_to_features, glue_processors)

if is_sklearn_available():
from .data import glue_compute_metrics
6 changes: 6 additions & 0 deletions pytorch_transformers/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .processors import (InputExample, InputFeatures, DataProcessor,
glue_output_modes, glue_convert_examples_to_features, glue_processors)
from .metrics import is_sklearn_available

if is_sklearn_available():
from .metrics import glue_compute_metrics
83 changes: 83 additions & 0 deletions pytorch_transformers/data/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import sys
import logging

logger = logging.getLogger(__name__)

# The metric backends are optional dependencies: probe for them once at import
# time and remember the outcome instead of failing the whole package import.
try:
    from scipy.stats import pearsonr, spearmanr
    from sklearn.metrics import matthews_corrcoef, f1_score
    _has_sklearn = True
except ImportError:
    # Catch ImportError by name: the handler expression is only evaluated when
    # an exception is actually raised, so an undefined name here would turn a
    # missing package into a NameError instead of the intended fallback.
    logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
    _has_sklearn = False


def is_sklearn_available():
    """Tell whether the optional metric dependencies could be imported.

    Returns:
        bool: True when both the scipy and scikit-learn imports at the top of
        this module succeeded, False otherwise.
    """
    return _has_sklearn

if _has_sklearn:

    def simple_accuracy(preds, labels):
        """Return the mean of the element-wise ``preds == labels`` comparison.

        NOTE(review): presumably both arguments are numpy arrays, since the
        comparison result must expose ``.mean()`` -- confirm with callers.
        """
        hits = preds == labels
        return hits.mean()

    def acc_and_f1(preds, labels):
        """Return accuracy, F1 and the average of the two as a dict."""
        accuracy = simple_accuracy(preds, labels)
        f1_value = f1_score(y_true=labels, y_pred=preds)
        averaged = (accuracy + f1_value) / 2
        return {
            "acc": accuracy,
            "f1": f1_value,
            "acc_and_f1": averaged,
        }

    def pearson_and_spearman(preds, labels):
        """Return Pearson and Spearman correlations and their average as a dict."""
        pearson_result = pearsonr(preds, labels)
        spearman_result = spearmanr(preds, labels)
        # Only the correlation coefficient (first item) of each result is kept.
        pearson_corr = pearson_result[0]
        spearman_corr = spearman_result[0]
        return {
            "pearson": pearson_corr,
            "spearmanr": spearman_corr,
            "corr": (pearson_corr + spearman_corr) / 2,
        }

    def glue_compute_metrics(task_name, preds, labels):
        """Compute the metric dict associated with a GLUE task.

        Args:
            task_name: lower-case task identifier such as "cola" or "mnli-mm".
            preds: model predictions.
            labels: reference labels, same length as ``preds``.

        Returns:
            dict mapping metric names to values; which metrics are present
            depends on the task.

        Raises:
            AssertionError: if ``preds`` and ``labels`` differ in length.
            KeyError: if ``task_name`` is not one of the handled tasks.
        """
        assert len(preds) == len(labels)
        # Tasks are grouped by the metric they share rather than listed one by one.
        if task_name == "cola":
            return {"mcc": matthews_corrcoef(labels, preds)}
        if task_name == "sts-b":
            return pearson_and_spearman(preds, labels)
        if task_name in ("mrpc", "qqp"):
            return acc_and_f1(preds, labels)
        if task_name in ("sst-2", "mnli", "mnli-mm", "qnli", "rte", "wnli"):
            return {"acc": simple_accuracy(preds, labels)}
        raise KeyError(task_name)
2 changes: 2 additions & 0 deletions pytorch_transformers/data/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .utils import InputExample, InputFeatures, DataProcessor
from .glue import output_modes, processors, convert_examples_to_glue_features

# Task-prefixed aliases. The parent ``data`` package imports the GLUE helpers
# from this package under these ``glue_*`` names, so they have to exist here;
# the short names above stay available for existing callers.
glue_output_modes = output_modes
glue_processors = processors
glue_convert_examples_to_features = convert_examples_to_glue_features
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,50 @@
# limitations under the License.
""" GLUE processors and helpers """

from .utils import DataProcessor
import logging
import os

from .utils import DataProcessor, InputExample, InputFeatures

logger = logging.getLogger(__name__)

# Number of labels per GLUE task. Every task listed here has two labels except
# "mnli" (three) and "sts-b" (one); those two are overridden after the default.
GLUE_TASKS_NUM_LABELS = dict.fromkeys(
    ("cola", "mnli", "mrpc", "sst-2", "sts-b", "qqp", "qnli", "rte", "wnli"), 2
)
GLUE_TASKS_NUM_LABELS.update({"mnli": 3, "sts-b": 1})

# Maps each GLUE task name to the processor class that handles it.
# NOTE(review): this table is evaluated at import time, but the processor
# classes it names (e.g. MrpcProcessor) are defined further down in this
# module, so building it here looks like it would raise NameError. It probably
# needs to move below the class definitions -- confirm against the full file.
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mnli-mm": MnliMismatchedProcessor,
"mrpc": MrpcProcessor,
"sst-2": Sst2Processor,
"sts-b": StsbProcessor,
"qqp": QqpProcessor,
"qnli": QnliProcessor,
"rte": RteProcessor,
"wnli": WnliProcessor,
}

# Output mode per GLUE task: "sts-b" is the only regression task, every other
# task listed here is a classification task.
output_modes = {
    task: ("regression" if task == "sts-b" else "classification")
    for task in (
        "cola", "mnli", "mnli-mm", "mrpc", "sst-2",
        "sts-b", "qqp", "qnli", "rte", "wnli",
    )
}

def convert_examples_to_glue_features(examples, label_list, max_seq_length,
tokenizer, output_mode,
Expand Down Expand Up @@ -91,37 +129,6 @@ def convert_examples_to_glue_features(examples, label_list, max_seq_length,
return features


class InputExample(object):
    """One raw training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Store the fields of a single example.

        Args:
            guid: Unique id for the example.
            text_a: string. Untokenized text of the first sequence; the only
                text needed for single-sequence tasks.
            text_b: (Optional) string. Untokenized text of the second
                sequence; only needed for sequence-pair tasks.
            label: (Optional) string. Label of the example; expected for
                train and dev examples, not for test examples.
        """
        self.guid, self.text_a = guid, text_a
        self.text_b, self.label = text_b, label


class InputFeatures(object):
    """Container for one set of model input features and its label id."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        """Keep the four given values as attributes of the same name."""
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id


class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""

Expand Down Expand Up @@ -420,15 +427,3 @@ def _create_examples(self, lines, set_type):
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples

# Label count for each GLUE task, written as (task, count) pairs in the same
# order as before.
GLUE_TASKS_NUM_LABELS = dict([
    ("cola", 2),
    ("mnli", 3),
    ("mrpc", 2),
    ("sst-2", 2),
    ("sts-b", 1),
    ("qqp", 2),
    ("qnli", 2),
    ("rte", 2),
    ("wnli", 2),
])
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,34 @@
import csv
import sys

from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For
                single sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second
                sequence. Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

    def __repr__(self):
        """Return a constructor-like string showing all four fields.

        Makes logged or printed examples readable instead of showing only the
        default object address.
        """
        return "{}(guid={!r}, text_a={!r}, text_b={!r}, label={!r})".format(
            type(self).__name__, self.guid, self.text_a, self.text_b, self.label)


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        """Store the given values unchanged as attributes of the same name.

        Args:
            input_ids: stored as ``self.input_ids``.
            input_mask: stored as ``self.input_mask``.
            segment_ids: stored as ``self.segment_ids``.
            label_id: stored as ``self.label_id``.
        """
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id

    def __repr__(self):
        """Return a constructor-like string showing all four fields.

        Makes logged or printed features readable instead of showing only the
        default object address.
        """
        return "{}(input_ids={!r}, input_mask={!r}, segment_ids={!r}, label_id={!r})".format(
            type(self).__name__, self.input_ids, self.input_mask,
            self.segment_ids, self.label_id)


class DataProcessor(object):
Expand Down Expand Up @@ -47,53 +73,3 @@ def _read_tsv(cls, input_file, quotechar=None):
line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line)
return lines


def simple_accuracy(preds, labels):
    """Return the mean of the element-wise ``preds == labels`` comparison.

    NOTE(review): presumably both arguments are numpy arrays, since the
    comparison result must expose ``.mean()`` -- confirm with callers.
    """
    matches = (preds == labels)
    return matches.mean()


def acc_and_f1(preds, labels):
    """Return accuracy, F1 and the average of the two as a dict.

    The result has the keys "acc", "f1" and "acc_and_f1".
    """
    accuracy = simple_accuracy(preds, labels)
    f1_value = f1_score(y_true=labels, y_pred=preds)
    averaged = (accuracy + f1_value) / 2
    return {"acc": accuracy, "f1": f1_value, "acc_and_f1": averaged}


def pearson_and_spearman(preds, labels):
    """Return Pearson and Spearman correlations and their average as a dict.

    The result has the keys "pearson", "spearmanr" and "corr".
    """
    pearson_result = pearsonr(preds, labels)
    spearman_result = spearmanr(preds, labels)
    # Only the correlation coefficient (first item) of each result is kept.
    pearson_corr = pearson_result[0]
    spearman_corr = spearman_result[0]
    mean_corr = (pearson_corr + spearman_corr) / 2
    return {"pearson": pearson_corr, "spearmanr": spearman_corr, "corr": mean_corr}


def compute_metrics(task_name, preds, labels):
    """Compute the metric dict associated with a GLUE task.

    Args:
        task_name: lower-case task identifier such as "cola" or "mnli-mm".
        preds: model predictions.
        labels: reference labels, same length as ``preds``.

    Returns:
        dict mapping metric names to values; which metrics are present depends
        on the task.

    Raises:
        AssertionError: if ``preds`` and ``labels`` differ in length.
        KeyError: if ``task_name`` is not one of the handled tasks.
    """
    assert len(preds) == len(labels)
    # Tasks are grouped by the metric they share rather than listed one by one.
    if task_name == "cola":
        return {"mcc": matthews_corrcoef(labels, preds)}
    if task_name == "sts-b":
        return pearson_and_spearman(preds, labels)
    if task_name in ("mrpc", "qqp"):
        return acc_and_f1(preds, labels)
    if task_name in ("sst-2", "mnli", "mnli-mm", "qnli", "rte", "wnli"):
        return {"acc": simple_accuracy(preds, labels)}
    raise KeyError(task_name)
56 changes: 0 additions & 56 deletions pytorch_transformers/preprocessing/__init__.py

This file was deleted.

0 comments on commit b5ec526

Please sign in to comment.