fix import issue when running huggingface_lowresource.sh #4
base: main
@@ -4,27 +4,38 @@
 import random
 import torch
 from torch.utils.data import SubsetRandomSampler, Sampler, Subset, ConcatDataset
-from .transforms import Compose
 from sklearn.model_selection import StratifiedShuffleSplit, KFold
 from theconf import Config as C
 import numpy as np
-from .custom_dataset import general_dataset
 from datasets import load_dataset
-from .augmentation import get_augment, apply_augment, random_augment
-from .common import get_logger
 import pandas as pd
-from .utils.raw_data_utils import get_processor, general_subsample_by_classes, get_examples, general_split
 from transformers import BertTokenizer, BertTokenizerFast
-from .text_networks import num_class
 import math
 import copy
-from .archive import policy_map, default_policy
 import multiprocessing
 from functools import partial
 import time
-from .utils.get_data import download_data
-from .utils.metrics import n_dist
+
+try:
+    from .utils.get_data import download_data
+    from .utils.metrics import n_dist
+    from .archive import policy_map, default_policy
+    from .text_networks import num_class
+    from .utils.raw_data_utils import get_processor, general_subsample_by_classes, get_examples, general_split
+    from .common import get_logger
+    from .augmentation import get_augment, apply_augment, random_augment
+    from .custom_dataset import general_dataset
+    from .transforms import Compose
+except:
+    from utils.get_data import download_data
+    from utils.metrics import n_dist
+    from archive import policy_map, default_policy
+    from text_networks import num_class
+    from utils.raw_data_utils import get_processor, general_subsample_by_classes, get_examples, general_split
+    from common import get_logger
+    from augmentation import get_augment, apply_augment, random_augment
+    from custom_dataset import general_dataset
+    from transforms import Compose
 
 logger = get_logger('Text AutoAugment')
 logger.setLevel(logging.INFO)
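The hunk above is the substance of the fix: the package-relative imports are wrapped in a try/except so the module also works when the file is executed as a top-level script, where relative imports raise ImportError. A condensed sketch of the same fallback pattern, written with an explicit exception type instead of the bare except: used in the diff (module names are taken from the diff; everything else is illustrative, not this PR's code):

try:
    # resolves when the module is imported as part of the package
    from .common import get_logger
    from .custom_dataset import general_dataset
except ImportError:
    # resolves when the file is run directly as a script and has no package context
    from common import get_logger
    from custom_dataset import general_dataset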
@@ -45,7 +56,7 @@ def get_datasets(dataset, policy_opt):
     elif aug in list(policy_map.keys()): # use pre-searched policy
         transform_train.transforms.insert(0, Augmentation(policy_map[aug]))
     else:
-        transform_train.transforms.insert(0, Augmentation(default_policy()))
+        pass
Comment on lines -48 to +59: Have fixed.

     # load dataset
     tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
@@ -95,6 +106,11 @@ def __init__(self, policy):
     def __call__(self, texts, labels):
         texts = np.array(texts)
         labels = np.array(labels)
+        if C.get()['ir'] < 1 and C.get()['method'] != 'bt':
+            # rebalanced data
+            ir_index = np.where(labels == 0)
+            texts = np.append(texts, texts[ir_index].repeat(int(1 / C.get()['ir']) - 1))
+            labels = np.append(labels, labels[ir_index].repeat(int(1 / C.get()['ir']) - 1))
Comment on lines +109 to +113: Have fixed.
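For reference, the rebalancing branch added above repeats the class-0 samples so that an imbalance ratio ir < 1 is compensated before augmentation. A small self-contained numpy sketch of the same arithmetic (the toy texts, labels and ir value are invented for illustration):

import numpy as np

ir = 0.5                          # hypothetical imbalance ratio from the config
texts = np.array(['neg a', 'neg b', 'pos a', 'pos b', 'pos c', 'pos d'])
labels = np.array([0, 0, 1, 1, 1, 1])

ir_index = np.where(labels == 0)  # indices of the minority class (label 0)
n_extra = int(1 / ir) - 1         # each minority sample is appended this many extra times
texts = np.append(texts, texts[ir_index].repeat(n_extra))
labels = np.append(labels, labels[ir_index].repeat(n_extra))

print(labels)                     # [0 0 1 1 1 1 0 0] -> class 0 now appears as often as class 1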

         # generate multiple augmented data if necessary
         labels = labels.repeat(C.get()['n_aug'])
         texts = texts.repeat(C.get()['n_aug'])
@@ -4,21 +4,67 @@
 import shutil
 import logging
 import transformers
-from .data import augment
-from .train import compute_metrics
-from .search import search_policy
 from datasets import load_dataset
 import pandas as pd
 import numpy as np
 from theconf import Config as C
-from .utils.raw_data_utils import get_examples
-from .custom_dataset import general_dataset
-from .text_networks import get_model, num_class, get_num_class
 from transformers import BertForSequenceClassification, Trainer, TrainingArguments, BertTokenizerFast
-from .common import get_logger, add_filehandler
-from .utils.train_tfidf import train_tfidf
 import joblib
+
+try:
+    from .data import augment
+    from .train import compute_metrics
+    from .search import search_policy
+    from .utils.raw_data_utils import get_examples
+    from .custom_dataset import general_dataset
+    from .text_networks import get_model, num_class, get_num_class
+    from .common import get_logger, add_filehandler
+    from .utils.train_tfidf import train_tfidf
+    from .transforms import Compose
+    from .data import Augmentation
+    from .custom_dataset import general_dataset
+    from .utils.raw_data_utils import get_examples, general_split
+
+except:
+    from data import augment
+    from train import compute_metrics
+    from search import search_policy
+    from utils.raw_data_utils import get_examples
+    from custom_dataset import general_dataset
+    from text_networks import get_model, num_class, get_num_class
+    from common import get_logger, add_filehandler
+    from utils.train_tfidf import train_tfidf
+    from transforms import Compose
+    from data import Augmentation
+    from custom_dataset import general_dataset
+    from utils.raw_data_utils import get_examples, general_split
+
+
+def get_all_datasets(dataset, n_aug, policy, test_size):
+    """ get augmented train, valid and full test datasets """
+    C.get()['n_aug'] = n_aug
+    text_key = C.get()['dataset']['text_key']
+
+    transform_train = Compose([])
+    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
+    transform_train.transforms.insert(0, Augmentation(policy))
+
+    train_dataset = load_dataset(dataset, split='train')
+    test_dataset = load_dataset(dataset, split='test')
+
+    class_num = train_dataset.features['label'].num_classes
+    all_train_examples = get_examples(train_dataset, text_key)
+
+    train_examples, valid_examples = general_split(all_train_examples, test_size=test_size, train_size=1-test_size)
Comment: a check should be added here. If the dataset already provides a validation split, there is no need to split one off from the training set.
+
+    test_examples = get_examples(test_dataset, text_key)
+
+    train_dataset = general_dataset(train_examples, tokenizer, text_transform=transform_train)
+    val_dataset = general_dataset(valid_examples, tokenizer, text_transform=None)
+    test_dataset = general_dataset(test_examples, tokenizer, text_transform=None)
+
+    return train_dataset, val_dataset, test_dataset
+
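A minimal sketch of the check the reviewer asks for above: reuse a validation split when the dataset already ships one, and only fall back to splitting the training data otherwise. get_examples and general_split are the helpers from this diff; the split handling itself is an assumption, not part of the PR:

from datasets import load_dataset
from utils.raw_data_utils import get_examples, general_split  # helpers as imported in this diff

def get_train_valid_examples(dataset, text_key, test_size):
    # load all available splits so we can see whether a validation set exists
    splits = load_dataset(dataset)
    train_examples = get_examples(splits['train'], text_key)
    if 'validation' in splits:
        # the dataset already provides a validation split: use it directly
        valid_examples = get_examples(splits['validation'], text_key)
    else:
        # otherwise carve a validation set out of the training data, as get_all_datasets does
        train_examples, valid_examples = general_split(
            train_examples, test_size=test_size, train_size=1 - test_size)
    return train_examples, valid_examples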
 def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy_opt, save_path=None, only_eval=False):
     transformers.logging.set_verbosity_info()
@@ -32,13 +78,7 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy
     text_key = C.get()['dataset']['text_key']

-    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
-
-    valid_examples = get_examples(valid_dataset, text_key)
-    test_examples = get_examples(test_dataset, text_key)

     train_dataset = augmented_train_dataset
-    val_dataset = general_dataset(valid_examples, tokenizer, text_transform=None)
-    test_dataset = general_dataset(test_examples, tokenizer, text_transform=None)

     do_train = True
     logging_dir = os.path.join('logs/%s_%s/%s' % (dataset_type, model_type, tag))
@@ -86,7 +126,7 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy
         model=model,  # the instantiated 🤗 Transformers model to be trained
         args=training_args,  # training arguments, defined above
         train_dataset=train_dataset,  # training dataset
-        eval_dataset=val_dataset,  # evaluation dataset
+        eval_dataset=valid_dataset,  # evaluation dataset
         compute_metrics=compute_metrics,
     )
@@ -102,52 +142,49 @@ def train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy
         if file.startswith("checkpoint-"):
             shutil.rmtree(os.path.join(save_path, file))

-    # logger.info("evaluating on test set")
+    logger.info("evaluating on test set")
     # note that when policy_opt, the test_dataset is only a subset of true test_dataset, used for evaluating policy
-    # result = trainer.evaluate(eval_dataset=test_dataset)
+    result = trainer.evaluate(eval_dataset=test_dataset)

-    result['n_dist'] = train_dataset.aug_n_dist
-    result['opt_object'] = result['eval_accuracy']
+    # result['n_dist'] = train_dataset.aug_n_dist
+    # result['opt_object'] = result['eval_accuracy']

-    logger.info("Predicting on test set")
-    result = trainer.predict(test_dataset)
-    logits = result.predictions
-    predictions = np.argmax(logits, axis=1)
+    # logger.info("Predicting on test set")
+    # result = trainer.predict(test_dataset)
+    # logits = result.predictions
+    # predictions = np.argmax(logits, axis=1)

-    predict_df = pd.read_csv('%s.tsv' % dataset_type, sep='\t')
-    predict_df['prediction'] = predictions
-    predict_df.to_csv('%s.tsv' % dataset_type, sep='\t', index=False)
+    # predict_df = pd.read_csv('%s.tsv' % dataset_type, sep='\t')
+    # predict_df['prediction'] = predictions
+    # predict_df.to_csv('%s.tsv' % dataset_type, sep='\t', index=False)

     if policy_opt:
         shutil.rmtree(save_path)
     return result


 if __name__ == '__main__':
-    _ = C('confs/bert_sst2_example.yaml')
+    _ = C('confs/bert_imdb_example.yaml')
-    dataset_type = C.get()['dataset']['name']
-    model_type = C.get()['model']['type']

     # search augmentation policy for specific dataset
-    # search_policy(dataset='sst2', configfile='bert_sst2_example.yaml', abspath='/home/renshuhuai/text-autoaugment' )

-    train_dataset = load_dataset('glue', 'sst2', split='train')
-    valid_dataset = load_dataset('glue', 'sst2', split='validation')
-    test_dataset = load_dataset('glue', 'sst2', split='test')
+    # search_policy(dataset='imdb', abspath='/home/renshuhuai/text-autoaugment' )

-    # generate augmented dataset
-    configfile = 'bert_sst2_example.yaml'
-    policy_path = '/home/renshuhuai/text-autoaugment/final_policy/sst2_Bert_seed59_train-npc50_n-aug8_ir1.00_taa.pkl'
+    # get augmented train dataset, valid and full test datasets
+    configfile = 'bert_imdb_example.yaml'
+    policy_path = '/home/wangyuxiang/Text_AutoAugment/text-autoaugment/final_policy/imdb_Bert_seed59_train-npc50_n-aug4_ir1.00_taa.pkl'
     policy = joblib.load(policy_path)
-    augmented_train_dataset = augment(dataset=train_dataset, policy=policy, n_aug=8, configfile=configfile)
+    augmented_train_dataset, valid_dataset, test_dataset = get_all_datasets(dataset_type, n_aug=4, policy=policy, test_size=0.3)

     # training
+    dataset_type = C.get()['dataset']['name']
+    model_type = C.get()['model']['type']

     train_tfidf(dataset_type)  # calculate tf-idf score for TS and TI operations

     tag = '%s_%s_with_found_policy' % (dataset_type, model_type)
     save_path = os.path.join('models', tag)

     result = train_bert(tag, augmented_train_dataset, valid_dataset, test_dataset, policy_opt=False,
-                        save_path=save_path, only_eval=True)
-    # for k,v in result.items():
-    #     print('%s:%s' % (k,v))
+                        save_path=save_path, only_eval=False)

+    for k,v in result.items():
+        print('%s:%s' % (k,v))
Comment: I have deleted the default policy since it is no longer used.