diff --git a/taa/archive.py b/taa/archive.py index 27ffae7..1fb69d5 100644 --- a/taa/archive.py +++ b/taa/archive.py @@ -2,7 +2,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from .augmentation import augment_list +try: + from .augmentation import augment_list +except: + from augmentation import augment_list def remove_deplicates(policies): @@ -172,6 +175,56 @@ def trec(): ('tfidf_word_insert', 0.9572797299031379, 0.5632362843917429)]] +def sst2(): + return [[('synonym_word_substitute', 0.5364303169444116, 0.19413963688692676), + ('synonym_word_substitute', 0.5031232065494572, 0.1525844904189979)], + [('tfidf_word_substitute', 0.6418241968074692, 0.3370626777277068), + ('tfidf_word_substitute', 0.9196749316584165, 0.38708167926026316)], + [('random_word_delete', 0.46983690625301894, 0.8914079607975975), + ('random_word_delete', 0.1736560332764212, 0.04829071948477086)], + [('random_word_delete', 0.06996736110631846, 0.26554522796740154), + ('random_word_swap', 0.3465013433062492, 0.012844513493466714)], + [('random_word_swap', 0.08932276061312096, 0.8109482577155689), + ('random_word_swap', 0.18372409409505108, 0.05053298441875576)], + [('synonym_word_substitute', 0.3379990774213545, 0.5933999181049157), + ('tfidf_word_substitute', 0.9983343578421985, 0.5781089454270496)], + [('tfidf_word_insert', 0.5113998603783605, 0.41212940260619846), + ('random_word_delete', 0.5293255489346063, 0.14480959698732754)], + [('tfidf_word_insert', 0.6745183870263084, 0.95335175846966), + ('random_word_swap', 0.4128642761364277, 0.00691588489202144)], + [('synonym_word_substitute', 0.6352531215702312, 0.18311186039533067), + ('random_word_delete', 0.5926874262625883, 0.5364992599119727)], + [('tfidf_word_substitute', 0.0007766212863096478, 0.9951597620125248), + ('tfidf_word_insert', 0.9284908255001041, 0.8151066122389705)], + [('random_word_delete', 0.5543678813405379, 0.5189422763173466), + ('tfidf_word_substitute', 0.4301374956162636, 0.8941465004308808)], + [('random_word_delete', 0.5649901290207588, 0.4147469438768546), + ('synonym_word_substitute', 0.7379541454486046, 0.39346847798276585)]] + + +def cola(): + return [[('random_word_delete', 0.3561054823049778, 0.38653604140692843), + ('tfidf_word_substitute', 0.6119904498638091, 0.7480956486081061)], + [('tfidf_word_substitute', 0.7333089032420526, 0.8074409858237845), + ('random_word_swap', 0.5336904643341456, 0.3061304517754367)], + [('synonym_word_substitute', 0.8703484961446668, 0.8713429056373105), + ('synonym_word_substitute', 0.3888761904351198, 0.44534142264101106)], + [('random_word_swap', 0.38489549841780535, 0.5162453703445143), + ('random_word_delete', 0.6245419459874113, 0.20536909806810433)], + [('tfidf_word_substitute', 0.22666456210358554, 0.39372880970524793), + ('random_word_swap', 0.811864987249894, 0.9555265437891366)], + [('tfidf_word_insert', 0.30050492672021645, 0.9916786328306386), + ('tfidf_word_substitute', 0.06255950877127048, 0.6309235049048048)], + [('random_word_delete', 0.6719132372044447, 0.9836329858619379), + ('tfidf_word_insert', 0.8420599342164212, 0.404672629267566)], + [('synonym_word_substitute', 0.5918495436499459, 0.8293310716446014), + ('tfidf_word_insert', 0.9111906303284524, 0.5205727423907498)], + [('synonym_word_substitute', 0.8506006593326765, 0.6472527298713463), + ('synonym_word_substitute', 0.43707935799244513, 0.5494767725442937)], + [('tfidf_word_insert', 0.6572641245063084, 0.32120987775289295), + ('random_word_swap', 0.4009335761117499, 0.3015697007069029)]] + + def default_policy(): return [[('synonym_word_substitute', 0.7492730962660217, 0.8816452863413866), ('synonym_word_substitute', 0.33184334794125936, 0.5208169910984721)], @@ -197,4 +250,5 @@ def default_policy(): ('tfidf_word_insert', 0.9572797299031379, 0.5632362843917429)]] -policy_map = {'imdb': imdb(), 'sst5': sst5(), 'trec': trec(), 'yelp2': yelp2(), 'yelp5': yelp5()} +policy_map = {'imdb': imdb(), 'sst5': sst5(), 'trec': trec(), + 'yelp2': yelp2(), 'yelp5': yelp5(), 'sst2':sst2(), 'cola':cola()} diff --git a/taa/data.py b/taa/data.py index 9f604d9..a6f0501 100644 --- a/taa/data.py +++ b/taa/data.py @@ -4,27 +4,38 @@ import random import torch from torch.utils.data import SubsetRandomSampler, Sampler, Subset, ConcatDataset -from .transforms import Compose from sklearn.model_selection import StratifiedShuffleSplit, KFold from theconf import Config as C import numpy as np -from .custom_dataset import general_dataset from datasets import load_dataset -from .augmentation import get_augment, apply_augment, random_augment -from .common import get_logger import pandas as pd -from .utils.raw_data_utils import get_processor, general_subsample_by_classes, get_examples, general_split from transformers import BertTokenizer, BertTokenizerFast -from .text_networks import num_class import math import copy -from .archive import policy_map, default_policy import multiprocessing from functools import partial import time -from .utils.get_data import download_data -from .utils.metrics import n_dist +try: + from .utils.get_data import download_data + from .utils.metrics import n_dist + from .archive import policy_map, default_policy + from .text_networks import num_class + from .utils.raw_data_utils import get_processor, general_subsample_by_classes, get_examples, general_split + from .common import get_logger + from .augmentation import get_augment, apply_augment, random_augment + from .custom_dataset import general_dataset + from .transforms import Compose +except: + from utils.get_data import download_data + from utils.metrics import n_dist + from archive import policy_map, default_policy + from text_networks import num_class + from utils.raw_data_utils import get_processor, general_subsample_by_classes, get_examples, general_split + from common import get_logger + from augmentation import get_augment, apply_augment, random_augment + from custom_dataset import general_dataset + from transforms import Compose logger = get_logger('Text AutoAugment') logger.setLevel(logging.INFO) @@ -45,7 +56,7 @@ def get_datasets(dataset, policy_opt): elif aug in list(policy_map.keys()): # use pre-searched policy transform_train.transforms.insert(0, Augmentation(policy_map[aug])) else: - transform_train.transforms.insert(0, Augmentation(default_policy())) + pass # load dataset tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') @@ -95,6 +106,11 @@ def __init__(self, policy): def __call__(self, texts, labels): texts = np.array(texts) labels = np.array(labels) + if C.get()['ir'] < 1 and C.get()['method'] != 'bt': + # rebalanced data + ir_index = np.where(labels == 0) + texts = np.append(texts, texts[ir_index].repeat(int(1 / C.get()['ir']) - 1)) + labels = np.append(labels, labels[ir_index].repeat(int(1 / C.get()['ir']) - 1)) # generate multiple augmented data if necessary labels = labels.repeat(C.get()['n_aug']) texts = texts.repeat(C.get()['n_aug']) diff --git a/taa/search.py b/taa/search.py index e2c1a7a..51b0deb 100644 --- a/taa/search.py +++ b/taa/search.py @@ -18,11 +18,6 @@ from ray import tune from tqdm import tqdm from datetime import datetime -from .archive import policy_decoder, remove_deplicates -from .augmentation import augment_list -from .common import get_logger, add_filehandler -from .utils.train_tfidf import train_tfidf -from .train import train_and_eval from theconf import Config as C, ConfigArgumentParser import joblib import random @@ -30,6 +25,21 @@ from pystopwatch2 import PyStopwatch import json +try: + from .archive import policy_decoder, remove_deplicates + from .augmentation import augment_list + from .common import get_logger, add_filehandler + from .utils.train_tfidf import train_tfidf + from .train import train_and_eval +except: + from archive import policy_decoder, remove_deplicates + from augmentation import augment_list + from common import get_logger, add_filehandler + from utils.train_tfidf import train_tfidf + from train import train_and_eval + + + logging.basicConfig(level=logging.INFO) @@ -127,7 +137,7 @@ def search_policy(dataset, abspath, configfile=None, num_search=200, num_policy= logger.info('Initialize ray...') # ray.init(num_cpus=num_cpus, local_mode=True) # used for debug ray.init(num_gpus=num_gpus, num_cpus=num_cpus) - + train_tfidf(dataset_type) # calculate tf-idf score for TS and TI operations if method == 'taa': diff --git a/taa/search_and_augment.py b/taa/search_and_augment.py index 36df40d..8952eef 100644 --- a/taa/search_and_augment.py +++ b/taa/search_and_augment.py @@ -1,10 +1,17 @@ import pkg_resources -from .data import augment -from .search import search_policy from datasets import load_dataset from theconf import Config as C -from .archive import policy_map -from .utils.train_tfidf import train_tfidf + +try: + from .archive import policy_map + from .utils.train_tfidf import train_tfidf + from .data import augment + from .search import search_policy +except: + from archive import policy_map + from utils.train_tfidf import train_tfidf + from data import augment + from search import search_policy def search_and_augment(configfile=None): diff --git a/taa/search_augment_train.py b/taa/search_augment_train.py index 8ca4862..4d877f5 100644 --- a/taa/search_augment_train.py +++ b/taa/search_augment_train.py @@ -4,21 +4,67 @@ import shutil import logging import transformers -from .data import augment -from .train import compute_metrics -from .search import search_policy from datasets import load_dataset import pandas as pd import numpy as np from theconf import Config as C -from .utils.raw_data_utils import get_examples -from .custom_dataset import general_dataset -from .text_networks import get_model, num_class, get_num_class from transformers import BertForSequenceClassification, Trainer, TrainingArguments, BertTokenizerFast -from .common import get_logger, add_filehandler -from .utils.train_tfidf import train_tfidf import joblib +try: + from .data import augment + from .train import compute_metrics + from .search import search_policy + from .utils.raw_data_utils import get_examples + from .custom_dataset import general_dataset + from .text_networks import get_model, num_class, get_num_class + from .common import get_logger, add_filehandler + from .utils.train_tfidf import train_tfidf + from .transforms import Compose + from .data import Augmentation + from .custom_dataset import general_dataset + from .utils.raw_data_utils import get_examples, general_split + +except: + from data import augment + from train import compute_metrics + from search import search_policy + from utils.raw_data_utils import get_examples + from custom_dataset import general_dataset + from text_networks import get_model, num_class, get_num_class + from common import get_logger, add_filehandler + from utils.train_tfidf import train_tfidf + from transforms import Compose + from data import Augmentation + from custom_dataset import general_dataset + from utils.raw_data_utils import get_examples, general_split + + + +def get_all_datasets(dataset, n_aug, policy, test_size): + """ get augmented train, valid and full test datasets """ + C.get()['n_aug'] = n_aug + text_key = C.get()['dataset']['text_key'] + + transform_train = Compose([]) + tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + transform_train.transforms.insert(0, Augmentation(policy)) + + train_dataset = load_dataset(dataset, split='train') + test_dataset = load_dataset(dataset, split='test') + + class_num = train_dataset.features['label'].num_classes + all_train_examples = get_examples(train_dataset, text_key) + + train_examples, valid_examples = general_split(all_train_examples, test_size=test_size, train_size=1-test_size) + + test_examples = get_examples(test_dataset, text_key) + + train_dataset = general_dataset(train_examples, tokenizer, text_transform=transform_train) + val_dataset = general_dataset(valid_examples, tokenizer, text_transform=None) + test_dataset = general_dataset(test_examples, tokenizer, text_transform=None) + + return train_dataset, val_dataset, test_dataset def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy_opt, save_path=None, only_eval=False): transformers.logging.set_verbosity_info() @@ -32,13 +78,7 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy text_key = C.get()['dataset']['text_key'] tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') - - valid_examples = get_examples(valid_dataset, text_key) - test_examples = get_examples(test_dataset, text_key) - train_dataset = augmented_train_dataset - val_dataset = general_dataset(valid_examples, tokenizer, text_transform=None) - test_dataset = general_dataset(test_examples, tokenizer, text_transform=None) do_train = True logging_dir = os.path.join('logs/%s_%s/%s' % (dataset_type, model_type, tag)) @@ -86,7 +126,7 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy model=model, # the instantiated 🤗 Transformers model to be trained args=training_args, # training arguments, defined above train_dataset=train_dataset, # training dataset - eval_dataset=val_dataset, # evaluation dataset + eval_dataset=valid_dataset, # evaluation dataset compute_metrics=compute_metrics, ) @@ -102,20 +142,21 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy if file.startswith("checkpoint-"): shutil.rmtree(os.path.join(save_path, file)) - # logger.info("evaluating on test set") + logger.info("evaluating on test set") # note that when policy_opt, the test_dataset is only a subset of true test_dataset, used for evaluating policy - # result = trainer.evaluate(eval_dataset=test_dataset) + result = trainer.evaluate(eval_dataset=test_dataset) + + result['n_dist'] = train_dataset.aug_n_dist + result['opt_object'] = result['eval_accuracy'] - # result['n_dist'] = train_dataset.aug_n_dist - # result['opt_object'] = result['eval_accuracy'] - logger.info("Predicting on test set") - result = trainer.predict(test_dataset) - logits = result.predictions - predictions = np.argmax(logits, axis=1) + # logger.info("Predicting on test set") + # result = trainer.predict(test_dataset) + # logits = result.predictions + # predictions = np.argmax(logits, axis=1) - predict_df = pd.read_csv('%s.tsv' % dataset_type, sep='\t') - predict_df['prediction'] = predictions - predict_df.to_csv('%s.tsv' % dataset_type, sep='\t', index=False) + # predict_df = pd.read_csv('%s.tsv' % dataset_type, sep='\t') + # predict_df['prediction'] = predictions + # predict_df.to_csv('%s.tsv' % dataset_type, sep='\t', index=False) if policy_opt: shutil.rmtree(save_path) @@ -123,31 +164,27 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy if __name__ == '__main__': - _ = C('confs/bert_sst2_example.yaml') + _ = C('confs/bert_imdb_example.yaml') + dataset_type = C.get()['dataset']['name'] + model_type = C.get()['model']['type'] # search augmentation policy for specific dataset - # search_policy(dataset='sst2', configfile='bert_sst2_example.yaml', abspath='/home/renshuhuai/text-autoaugment' ) - - train_dataset = load_dataset('glue', 'sst2', split='train') - valid_dataset = load_dataset('glue', 'sst2', split='validation') - test_dataset = load_dataset('glue', 'sst2', split='test') + # search_policy(dataset='imdb', abspath='/home/renshuhuai/text-autoaugment' ) - # generate augmented dataset - configfile = 'bert_sst2_example.yaml' - policy_path = '/home/renshuhuai/text-autoaugment/final_policy/sst2_Bert_seed59_train-npc50_n-aug8_ir1.00_taa.pkl' + # get augmented train dataset, valid and full test datasets + configfile = 'bert_imdb_example.yaml' + policy_path = '/home/wangyuxiang/Text_AutoAugment/text-autoaugment/final_policy/imdb_Bert_seed59_train-npc50_n-aug4_ir1.00_taa.pkl' policy = joblib.load(policy_path) - augmented_train_dataset = augment(dataset=train_dataset, policy=policy, n_aug=8, configfile=configfile) + augmented_train_dataset, valid_dataset, test_dataset = get_all_datasets(dataset_type, n_aug=4, policy=policy, test_size=0.3) # training - dataset_type = C.get()['dataset']['name'] - model_type = C.get()['model']['type'] - train_tfidf(dataset_type) # calculate tf-idf score for TS and TI operations tag = '%s_%s_with_found_policy' % (dataset_type, model_type) save_path = os.path.join('models', tag) result = train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy_opt=False, - save_path=save_path, only_eval=True) - # for k,v in result.items(): - # print('%s:%s' % (k,v)) + save_path=save_path, only_eval=False) + + for k,v in result.items(): + print('%s:%s' % (k,v)) diff --git a/taa/train.py b/taa/train.py index 97a8cd0..a4a0aaf 100644 --- a/taa/train.py +++ b/taa/train.py @@ -10,13 +10,22 @@ import numpy as np import torch from theconf import Config as C, ConfigArgumentParser -from .common import get_logger, add_filehandler -from .data import get_datasets -from .text_networks import get_model, num_class, get_num_class import transformers from transformers import BertForSequenceClassification, Trainer, TrainingArguments -from .utils.metrics import accuracy, f1, accuracy_score -from .utils.train_tfidf import train_tfidf + +try: + from .utils.metrics import accuracy, f1, accuracy_score + from .utils.train_tfidf import train_tfidf + from .common import get_logger, add_filehandler + from .data import get_datasets + from .text_networks import get_model, num_class, get_num_class +except: + from utils.metrics import accuracy, f1, accuracy_score + from utils.train_tfidf import train_tfidf + from common import get_logger, add_filehandler + from data import get_datasets + from text_networks import get_model, num_class, get_num_class + transformers.logging.set_verbosity_info() logger = get_logger('Text AutoAugment') diff --git a/taa/utils/metrics.py b/taa/utils/metrics.py index cbccbd1..0e71e9d 100644 --- a/taa/utils/metrics.py +++ b/taa/utils/metrics.py @@ -3,7 +3,10 @@ import torch import numpy as np from collections import defaultdict -from .distinct_n import distinct_n_corpus_level +try: + from .distinct_n import distinct_n_corpus_level +except: + from distinct_n import distinct_n_corpus_level from torch import nn diff --git a/taa/utils/raw_data_utils.py b/taa/utils/raw_data_utils.py index 408c0fe..77d89a3 100644 --- a/taa/utils/raw_data_utils.py +++ b/taa/utils/raw_data_utils.py @@ -5,14 +5,21 @@ import torchtext import csv import os +import sys import numpy as np import random import pandas as pd import re from theconf import Config as C -from ..common import get_logger import logging + +try: + from ..common import get_logger +except: + sys.path.append("..") + from common import get_logger + logger = get_logger('Text AutoAugment') logger.setLevel(logging.INFO) diff --git a/taa/utils/train_tfidf.py b/taa/utils/train_tfidf.py index 0517994..115c497 100644 --- a/taa/utils/train_tfidf.py +++ b/taa/utils/train_tfidf.py @@ -1,10 +1,13 @@ import re import argparse import nlpaug.model.word_stats as nmw -from .raw_data_utils import get_examples import os from theconf import Config as C from datasets import load_dataset +try: + from .raw_data_utils import get_examples +except: + from raw_data_utils import get_examples def _tokenizer(text, token_pattern=r"(?u)\b\w\w+\b"):