Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix import issue when running huggingface_lowresource.sh #4

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions taa/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .augmentation import augment_list
try:
from .augmentation import augment_list
except:
from augmentation import augment_list


def remove_deplicates(policies):
Expand Down Expand Up @@ -172,6 +175,56 @@ def trec():
('tfidf_word_insert', 0.9572797299031379, 0.5632362843917429)]]


def sst2():
    """Return the pre-searched augmentation policy registered under ``'sst2'``.

    The policy is a list of 12 sub-policies. Each sub-policy is a list of two
    ``(operation_name, float, float)`` tuples. The two floats are presumably
    the operation's probability and magnitude -- TODO confirm against the
    code that consumes the policy.

    A new list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    return [[('synonym_word_substitute', 0.5364303169444116, 0.19413963688692676),
             ('synonym_word_substitute', 0.5031232065494572, 0.1525844904189979)],
            [('tfidf_word_substitute', 0.6418241968074692, 0.3370626777277068),
             ('tfidf_word_substitute', 0.9196749316584165, 0.38708167926026316)],
            [('random_word_delete', 0.46983690625301894, 0.8914079607975975),
             ('random_word_delete', 0.1736560332764212, 0.04829071948477086)],
            [('random_word_delete', 0.06996736110631846, 0.26554522796740154),
             ('random_word_swap', 0.3465013433062492, 0.012844513493466714)],
            [('random_word_swap', 0.08932276061312096, 0.8109482577155689),
             ('random_word_swap', 0.18372409409505108, 0.05053298441875576)],
            [('synonym_word_substitute', 0.3379990774213545, 0.5933999181049157),
             ('tfidf_word_substitute', 0.9983343578421985, 0.5781089454270496)],
            [('tfidf_word_insert', 0.5113998603783605, 0.41212940260619846),
             ('random_word_delete', 0.5293255489346063, 0.14480959698732754)],
            [('tfidf_word_insert', 0.6745183870263084, 0.95335175846966),
             ('random_word_swap', 0.4128642761364277, 0.00691588489202144)],
            [('synonym_word_substitute', 0.6352531215702312, 0.18311186039533067),
             ('random_word_delete', 0.5926874262625883, 0.5364992599119727)],
            [('tfidf_word_substitute', 0.0007766212863096478, 0.9951597620125248),
             ('tfidf_word_insert', 0.9284908255001041, 0.8151066122389705)],
            [('random_word_delete', 0.5543678813405379, 0.5189422763173466),
             ('tfidf_word_substitute', 0.4301374956162636, 0.8941465004308808)],
            [('random_word_delete', 0.5649901290207588, 0.4147469438768546),
             ('synonym_word_substitute', 0.7379541454486046, 0.39346847798276585)]]


def cola():
    """Return the pre-searched augmentation policy registered under ``'cola'``.

    The policy is a list of 10 sub-policies. Each sub-policy is a list of two
    ``(operation_name, float, float)`` tuples. The two floats are presumably
    the operation's probability and magnitude -- TODO confirm against the
    code that consumes the policy.

    A new list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    return [[('random_word_delete', 0.3561054823049778, 0.38653604140692843),
             ('tfidf_word_substitute', 0.6119904498638091, 0.7480956486081061)],
            [('tfidf_word_substitute', 0.7333089032420526, 0.8074409858237845),
             ('random_word_swap', 0.5336904643341456, 0.3061304517754367)],
            [('synonym_word_substitute', 0.8703484961446668, 0.8713429056373105),
             ('synonym_word_substitute', 0.3888761904351198, 0.44534142264101106)],
            [('random_word_swap', 0.38489549841780535, 0.5162453703445143),
             ('random_word_delete', 0.6245419459874113, 0.20536909806810433)],
            [('tfidf_word_substitute', 0.22666456210358554, 0.39372880970524793),
             ('random_word_swap', 0.811864987249894, 0.9555265437891366)],
            [('tfidf_word_insert', 0.30050492672021645, 0.9916786328306386),
             ('tfidf_word_substitute', 0.06255950877127048, 0.6309235049048048)],
            [('random_word_delete', 0.6719132372044447, 0.9836329858619379),
             ('tfidf_word_insert', 0.8420599342164212, 0.404672629267566)],
            [('synonym_word_substitute', 0.5918495436499459, 0.8293310716446014),
             ('tfidf_word_insert', 0.9111906303284524, 0.5205727423907498)],
            [('synonym_word_substitute', 0.8506006593326765, 0.6472527298713463),
             ('synonym_word_substitute', 0.43707935799244513, 0.5494767725442937)],
            [('tfidf_word_insert', 0.6572641245063084, 0.32120987775289295),
             ('random_word_swap', 0.4009335761117499, 0.3015697007069029)]]


def default_policy():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have deleted the default policy since it is not used anymore.

return [[('synonym_word_substitute', 0.7492730962660217, 0.8816452863413866),
('synonym_word_substitute', 0.33184334794125936, 0.5208169910984721)],
Expand All @@ -197,4 +250,5 @@ def default_policy():
('tfidf_word_insert', 0.9572797299031379, 0.5632362843917429)]]


policy_map = {'imdb': imdb(), 'sst5': sst5(), 'trec': trec(), 'yelp2': yelp2(), 'yelp5': yelp5()}
policy_map = {'imdb': imdb(), 'sst5': sst5(), 'trec': trec(),
'yelp2': yelp2(), 'yelp5': yelp5(), 'sst2':sst2(), 'cola':cola()}
36 changes: 26 additions & 10 deletions taa/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,38 @@
import random
import torch
from torch.utils.data import SubsetRandomSampler, Sampler, Subset, ConcatDataset
from .transforms import Compose
from sklearn.model_selection import StratifiedShuffleSplit, KFold
from theconf import Config as C
import numpy as np
from .custom_dataset import general_dataset
from datasets import load_dataset
from .augmentation import get_augment, apply_augment, random_augment
from .common import get_logger
import pandas as pd
from .utils.raw_data_utils import get_processor, general_subsample_by_classes, get_examples, general_split
from transformers import BertTokenizer, BertTokenizerFast
from .text_networks import num_class
import math
import copy
from .archive import policy_map, default_policy
import multiprocessing
from functools import partial
import time
from .utils.get_data import download_data
from .utils.metrics import n_dist

try:
from .utils.get_data import download_data
from .utils.metrics import n_dist
from .archive import policy_map, default_policy
from .text_networks import num_class
from .utils.raw_data_utils import get_processor, general_subsample_by_classes, get_examples, general_split
from .common import get_logger
from .augmentation import get_augment, apply_augment, random_augment
from .custom_dataset import general_dataset
from .transforms import Compose
except:
from utils.get_data import download_data
from utils.metrics import n_dist
from archive import policy_map, default_policy
from text_networks import num_class
from utils.raw_data_utils import get_processor, general_subsample_by_classes, get_examples, general_split
from common import get_logger
from augmentation import get_augment, apply_augment, random_augment
from custom_dataset import general_dataset
from transforms import Compose

logger = get_logger('Text AutoAugment')
logger.setLevel(logging.INFO)
Expand All @@ -45,7 +56,7 @@ def get_datasets(dataset, policy_opt):
elif aug in list(policy_map.keys()): # use pre-searched policy
transform_train.transforms.insert(0, Augmentation(policy_map[aug]))
else:
transform_train.transforms.insert(0, Augmentation(default_policy()))
pass
Comment on lines -48 to +59
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been fixed.


# load dataset
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
Expand Down Expand Up @@ -95,6 +106,11 @@ def __init__(self, policy):
def __call__(self, texts, labels):
texts = np.array(texts)
labels = np.array(labels)
if C.get()['ir'] < 1 and C.get()['method'] != 'bt':
# rebalanced data
ir_index = np.where(labels == 0)
texts = np.append(texts, texts[ir_index].repeat(int(1 / C.get()['ir']) - 1))
labels = np.append(labels, labels[ir_index].repeat(int(1 / C.get()['ir']) - 1))
Comment on lines +109 to +113
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been fixed.

# generate multiple augmented data if necessary
labels = labels.repeat(C.get()['n_aug'])
texts = texts.repeat(C.get()['n_aug'])
Expand Down
22 changes: 16 additions & 6 deletions taa/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,28 @@
from ray import tune
from tqdm import tqdm
from datetime import datetime
from .archive import policy_decoder, remove_deplicates
from .augmentation import augment_list
from .common import get_logger, add_filehandler
from .utils.train_tfidf import train_tfidf
from .train import train_and_eval
from theconf import Config as C, ConfigArgumentParser
import joblib
import random
import logging
from pystopwatch2 import PyStopwatch
import json

try:
from .archive import policy_decoder, remove_deplicates
from .augmentation import augment_list
from .common import get_logger, add_filehandler
from .utils.train_tfidf import train_tfidf
from .train import train_and_eval
except:
from archive import policy_decoder, remove_deplicates
from augmentation import augment_list
from common import get_logger, add_filehandler
from utils.train_tfidf import train_tfidf
from train import train_and_eval



logging.basicConfig(level=logging.INFO)


Expand Down Expand Up @@ -127,7 +137,7 @@ def search_policy(dataset, abspath, configfile=None, num_search=200, num_policy=
logger.info('Initialize ray...')
# ray.init(num_cpus=num_cpus, local_mode=True) # used for debug
ray.init(num_gpus=num_gpus, num_cpus=num_cpus)

train_tfidf(dataset_type) # calculate tf-idf score for TS and TI operations

if method == 'taa':
Expand Down
15 changes: 11 additions & 4 deletions taa/search_and_augment.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import pkg_resources
from .data import augment
from .search import search_policy
from datasets import load_dataset
from theconf import Config as C
from .archive import policy_map
from .utils.train_tfidf import train_tfidf

try:
from .archive import policy_map
from .utils.train_tfidf import train_tfidf
from .data import augment
from .search import search_policy
except:
from archive import policy_map
from utils.train_tfidf import train_tfidf
from data import augment
from search import search_policy


def search_and_augment(configfile=None):
Expand Down
121 changes: 79 additions & 42 deletions taa/search_augment_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,67 @@
import shutil
import logging
import transformers
from .data import augment
from .train import compute_metrics
from .search import search_policy
from datasets import load_dataset
import pandas as pd
import numpy as np
from theconf import Config as C
from .utils.raw_data_utils import get_examples
from .custom_dataset import general_dataset
from .text_networks import get_model, num_class, get_num_class
from transformers import BertForSequenceClassification, Trainer, TrainingArguments, BertTokenizerFast
from .common import get_logger, add_filehandler
from .utils.train_tfidf import train_tfidf
import joblib

try:
from .data import augment
from .train import compute_metrics
from .search import search_policy
from .utils.raw_data_utils import get_examples
from .custom_dataset import general_dataset
from .text_networks import get_model, num_class, get_num_class
from .common import get_logger, add_filehandler
from .utils.train_tfidf import train_tfidf
from .transforms import Compose
from .data import Augmentation
from .custom_dataset import general_dataset
from .utils.raw_data_utils import get_examples, general_split

except:
from data import augment
from train import compute_metrics
from search import search_policy
from utils.raw_data_utils import get_examples
from custom_dataset import general_dataset
from text_networks import get_model, num_class, get_num_class
from common import get_logger, add_filehandler
from utils.train_tfidf import train_tfidf
from transforms import Compose
from data import Augmentation
from custom_dataset import general_dataset
from utils.raw_data_utils import get_examples, general_split



def get_all_datasets(dataset, n_aug, policy, test_size):
""" get augmented train, valid and full test datasets """
C.get()['n_aug'] = n_aug
text_key = C.get()['dataset']['text_key']

transform_train = Compose([])
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
transform_train.transforms.insert(0, Augmentation(policy))

train_dataset = load_dataset(dataset, split='train')
test_dataset = load_dataset(dataset, split='test')

class_num = train_dataset.features['label'].num_classes
all_train_examples = get_examples(train_dataset, text_key)

train_examples, valid_examples = general_split(all_train_examples, test_size=test_size, train_size=1-test_size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A conditional check should be added: if the dataset already has a validation set, there is no need to split a validation set off from the training set.


test_examples = get_examples(test_dataset, text_key)

train_dataset = general_dataset(train_examples, tokenizer, text_transform=transform_train)
val_dataset = general_dataset(valid_examples, tokenizer, text_transform=None)
test_dataset = general_dataset(test_examples, tokenizer, text_transform=None)

return train_dataset, val_dataset, test_dataset

def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy_opt, save_path=None, only_eval=False):
transformers.logging.set_verbosity_info()
Expand All @@ -32,13 +78,7 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy
text_key = C.get()['dataset']['text_key']

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

valid_examples = get_examples(valid_dataset, text_key)
test_examples = get_examples(test_dataset, text_key)

train_dataset = augmented_train_dataset
val_dataset = general_dataset(valid_examples, tokenizer, text_transform=None)
test_dataset = general_dataset(test_examples, tokenizer, text_transform=None)

do_train = True
logging_dir = os.path.join('logs/%s_%s/%s' % (dataset_type, model_type, tag))
Expand Down Expand Up @@ -86,7 +126,7 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy
model=model, # the instantiated 🤗 Transformers model to be trained
args=training_args, # training arguments, defined above
train_dataset=train_dataset, # training dataset
eval_dataset=val_dataset, # evaluation dataset
eval_dataset=valid_dataset, # evaluation dataset
compute_metrics=compute_metrics,
)

Expand All @@ -102,52 +142,49 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy
if file.startswith("checkpoint-"):
shutil.rmtree(os.path.join(save_path, file))

# logger.info("evaluating on test set")
logger.info("evaluating on test set")
# note that when policy_opt, the test_dataset is only a subset of true test_dataset, used for evaluating policy
# result = trainer.evaluate(eval_dataset=test_dataset)
result = trainer.evaluate(eval_dataset=test_dataset)

result['n_dist'] = train_dataset.aug_n_dist
result['opt_object'] = result['eval_accuracy']

# result['n_dist'] = train_dataset.aug_n_dist
# result['opt_object'] = result['eval_accuracy']
logger.info("Predicting on test set")
result = trainer.predict(test_dataset)
logits = result.predictions
predictions = np.argmax(logits, axis=1)
# logger.info("Predicting on test set")
# result = trainer.predict(test_dataset)
# logits = result.predictions
# predictions = np.argmax(logits, axis=1)

predict_df = pd.read_csv('%s.tsv' % dataset_type, sep='\t')
predict_df['prediction'] = predictions
predict_df.to_csv('%s.tsv' % dataset_type, sep='\t', index=False)
# predict_df = pd.read_csv('%s.tsv' % dataset_type, sep='\t')
# predict_df['prediction'] = predictions
# predict_df.to_csv('%s.tsv' % dataset_type, sep='\t', index=False)

if policy_opt:
shutil.rmtree(save_path)
return result


if __name__ == '__main__':
_ = C('confs/bert_sst2_example.yaml')
_ = C('confs/bert_imdb_example.yaml')
dataset_type = C.get()['dataset']['name']
model_type = C.get()['model']['type']

# search augmentation policy for specific dataset
# search_policy(dataset='sst2', configfile='bert_sst2_example.yaml', abspath='/home/renshuhuai/text-autoaugment' )

train_dataset = load_dataset('glue', 'sst2', split='train')
valid_dataset = load_dataset('glue', 'sst2', split='validation')
test_dataset = load_dataset('glue', 'sst2', split='test')
# search_policy(dataset='imdb', abspath='/home/renshuhuai/text-autoaugment' )

# generate augmented dataset
configfile = 'bert_sst2_example.yaml'
policy_path = '/home/renshuhuai/text-autoaugment/final_policy/sst2_Bert_seed59_train-npc50_n-aug8_ir1.00_taa.pkl'
# get augmented train dataset, valid and full test datasets
configfile = 'bert_imdb_example.yaml'
policy_path = '/home/wangyuxiang/Text_AutoAugment/text-autoaugment/final_policy/imdb_Bert_seed59_train-npc50_n-aug4_ir1.00_taa.pkl'
policy = joblib.load(policy_path)
augmented_train_dataset = augment(dataset=train_dataset, policy=policy, n_aug=8, configfile=configfile)
augmented_train_dataset, valid_dataset, test_dataset = get_all_datasets(dataset_type, n_aug=4, policy=policy, test_size=0.3)

# training
dataset_type = C.get()['dataset']['name']
model_type = C.get()['model']['type']

train_tfidf(dataset_type) # calculate tf-idf score for TS and TI operations

tag = '%s_%s_with_found_policy' % (dataset_type, model_type)
save_path = os.path.join('models', tag)

result = train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy_opt=False,
save_path=save_path, only_eval=True)
# for k,v in result.items():
# print('%s:%s' % (k,v))
save_path=save_path, only_eval=False)

for k,v in result.items():
print('%s:%s' % (k,v))
Loading