Update the data processing for the benchmark dataset #1205

Closed
wants to merge 1 commit into from
9 changes: 7 additions & 2 deletions examples/trans/README.md
@@ -28,7 +28,12 @@ optional arguments:
   --n_layers int        transformer model n_layers default 6
 ```
 
-run the example
+**run the example**
+
+step 1: Download the dataset to the cmn-eng directory.
+
+step 2: Run the following script.
+
 ```
-python train.py --dataset cmn.txt --max-epoch 100 --batch-size 32 --lr 0.01
+python train.py --dataset cmn-eng/cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
 ```
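For reference, step 1 above can be scripted. A minimal sketch, assuming the dataset is still published as cmn-eng.zip at https://www.manythings.org/anki/ (the exact URL, archive name, and the cmn.txt member inside it are assumptions, not part of this diff):

```python
import os
import urllib.request
import zipfile

# Assumed download location; the site referenced in data.py's docstring.
URL = "https://www.manythings.org/anki/cmn-eng.zip"

os.makedirs("cmn-eng", exist_ok=True)
archive = os.path.join("cmn-eng", "cmn-eng.zip")

# Some hosts reject Python's default user agent, so send a browser-like one.
req = urllib.request.Request(URL, headers={"User-Agent": "Mozilla/5.0"})
with urllib.request.urlopen(req) as resp, open(archive, "wb") as f:
    f.write(resp.read())

with zipfile.ZipFile(archive) as zf:
    zf.extract("cmn.txt", path="cmn-eng")  # yields cmn-eng/cmn.txt
```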
4 changes: 2 additions & 2 deletions examples/trans/data.py
@@ -56,12 +56,12 @@ def __len__(self):
 
 
 class CmnDataset:
-    def __init__(self, path='cmn-eng/cmn.txt', shuffle=False, batch_size=32, train_ratio=0.8, random_seed=0):
+    def __init__(self, path, shuffle=False, batch_size=32, train_ratio=0.8, random_seed=0):
         """
         cmn dataset, download from https://www.manythings.org/anki/, contains 29909 Chinese and English translation
         pairs, the pair format: English + TAB + Chinese + TAB + Attribution
         Args:
-            path: the path of the dataset, default 'cmn-eng/cnn.txt'
+            path: the path of the dataset
             shuffle: shuffle the dataset, default False
             batch_size: the size of every batch, default 32
             train_ratio: the proportion of the training set to the total data set, default 0.8
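With the default removed, every caller has to name the dataset file explicitly. A minimal usage sketch based only on the signature shown in this diff (the import assumes the script runs from examples/trans):

```python
from data import CmnDataset  # assumes examples/trans is the working directory

# `path` no longer has a default, so the dataset file must be passed explicitly.
dataset = CmnDataset(path="cmn-eng/cmn-2000.txt",
                     shuffle=False,
                     batch_size=32,
                     train_ratio=0.8)
```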
2 changes: 1 addition & 1 deletion examples/trans/run.sh
@@ -18,4 +18,4 @@
 #
 
 # run this example
-python train.py --dataset cmn-2000.txt --max-epoch 300 --batch-size 32 --lr 0.01
+python train.py --dataset cmn-eng/cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
5 changes: 2 additions & 3 deletions examples/trans/train.py
@@ -35,7 +35,7 @@ def run(args):
     np.random.seed(args.seed)
 
     batch_size = args.batch_size
-    cmn_dataset = CmnDataset(path="cmn-eng/"+args.dataset, shuffle=args.shuffle, batch_size=batch_size, train_ratio=0.8)
+    cmn_dataset = CmnDataset(path=args.dataset, shuffle=args.shuffle, batch_size=batch_size, train_ratio=0.8)
 
     print("【step-0】 prepare dataset...")
     src_vocab_size, tgt_vocab_size = cmn_dataset.en_vab_size, cmn_dataset.cn_vab_size
@@ -151,8 +151,7 @@ def run(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Training Transformer Model.")
-    parser.add_argument('--dataset', choices=['cmn.txt', 'cmn-15000.txt',
-                                              'cmn-2000.txt'], default='cmn-2000.txt')
+    parser.add_argument('--dataset', default='cmn-eng/cmn-2000.txt')
     parser.add_argument('--max-epoch', default=100, type=int, help='maximum epochs.', dest='max_epoch')
     parser.add_argument('--batch-size', default=64, type=int, help='batch size', dest='batch_size')
     parser.add_argument('--shuffle', default=True, type=bool, help='shuffle the dataset', dest='shuffle')
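Dropping `choices` means --dataset now accepts any path and forwards it to CmnDataset unchanged. A standalone sketch that mirrors just this flag to show the new behavior:

```python
import argparse

# Mirrors the --dataset flag from this diff: with `choices` removed, any path parses.
parser = argparse.ArgumentParser(description="Training Transformer Model.")
parser.add_argument('--dataset', default='cmn-eng/cmn-2000.txt')

args = parser.parse_args(['--dataset', 'cmn-eng/cmn-15000.txt'])
print(args.dataset)  # -> cmn-eng/cmn-15000.txt, passed straight to CmnDataset(path=...)
```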