Showing 7 changed files with 23,252 additions and 0 deletions.

**README.md**

# TextCNN
A simple PyTorch implementation of TextCNN.

Based on:
[https://github.com/Shawn1993/cnn-text-classification-pytorch](https://github.com/Shawn1993/cnn-text-classification-pytorch)

Main changes:
* Simplified the parameter configuration, aiming to present a minimal version
* Fixed syntax errors caused by interface changes across PyTorch versions
* Fixed a runtime error caused by model padding
* Decoupled the model (model.py) from the training/test/prediction logic
* Customizable tokenizer, defaulting to Chinese word segmentation with jieba
* Uses torchtext's TabularDataset to read the dataset, one `text \t label` pair per line (sample rows below)
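
For reference, a minimal sketch of what lines in `data/train.tsv` look like under this format; the two rows below are invented placeholders, not actual weibo_senti_100k entries:

```
今天天气真好,心情舒畅	1
排队两个小时,服务太差了	0
```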

The dataset is a subset of weibo_senti_100k, with 20,000 training examples and 3,000 test examples.

# Requirements
pytorch==1.3.1
torchtext==0.4.0

# Train
`python main.py -train`

# Test
`python main.py -test -snapshot snapshot/best_steps_400.pt`

Output:
```
Evaluation - loss: 0.061201 acc: 98.053% (2518/2568)
```

# Predict
`python main.py -predict -snapshot snapshot/best_steps_400.pt`

Output:
```
>>内牛满面~[泪]
0 | 内牛满面~[泪]
>>啧啧啧,好幸福好幸福
1 | 啧啧啧,好幸福好幸福
```

**data/train.tsv** and **data/test.tsv**

Large diffs are not rendered by default.

**main.py**

```python
import argparse

import jieba
import torch
from torchtext import data

from model import TextCnn
from operation import *


def parse_arguments():
    parser = argparse.ArgumentParser(description='CNN text classifier')
    # learning
    parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate [default: 0.001]')
    parser.add_argument('-epochs', type=int, default=10, help='number of epochs for train [default: 10]')
    parser.add_argument('-batch-size', type=int, default=128, help='batch size for training [default: 128]')
    parser.add_argument('-log-interval', type=int, default=100, help='how many steps to wait before logging training status [default: 100]')
    parser.add_argument('-test-interval', type=int, default=200, help='how many steps to wait before testing [default: 200]')
    parser.add_argument('-save-interval', type=int, default=1000, help='how many steps to wait before saving [default: 1000]')
    parser.add_argument('-save-dir', type=str, default='snapshot', help='directory to save the snapshot')
    # model
    parser.add_argument('-dropout', type=float, default=0.5, help='dropout probability [default: 0.5]')
    parser.add_argument('-embed-dim', type=int, default=128, help='number of embedding dimensions [default: 128]')
    parser.add_argument('-kernel-num', type=int, default=10, help='number of kernels')
    parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='comma-separated kernel sizes to use for convolution')
    # option
    parser.add_argument('-snapshot', type=str, default=None, help='filename of model snapshot [default: None]')
    parser.add_argument('-train', action='store_true', default=False, help='train a new model')
    parser.add_argument('-test', action='store_true', default=False, help='test on the test set; combine with -snapshot to load a model')
    parser.add_argument('-predict', action='store_true', default=False, help='predict the label of console input')
    return parser.parse_args()


def tokenize(text):
    # segment Chinese text with jieba, dropping whitespace-only tokens
    return [word for word in jieba.cut(text) if word.strip()]


args = parse_arguments()

text_field = data.Field(lower=True, tokenize=tokenize)
label_field = data.Field(sequential=False)
fields = [('text', text_field), ('label', label_field)]
train_dataset, test_dataset = data.TabularDataset.splits(
    path='./data/', format='tsv', skip_header=False,
    train='train.tsv', test='test.tsv', fields=fields
)
text_field.build_vocab(train_dataset, test_dataset, min_freq=5, max_size=50000)
label_field.build_vocab(train_dataset, test_dataset)
train_iter, test_iter = data.Iterator.splits(
    (train_dataset, test_dataset),
    batch_sizes=(args.batch_size, args.batch_size), sort_key=lambda x: len(x.text))

embed_num = len(text_field.vocab)
class_num = len(label_field.vocab) - 1  # index 0 of the label vocab is '<unk>'
kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]

args.cuda = torch.cuda.is_available()

print("Parameters:")
for attr, value in sorted(args.__dict__.items()):
    print("{}={}".format(attr.upper(), value))

cnn = TextCnn(embed_num, args.embed_dim, class_num, args.kernel_num, kernel_sizes, args.dropout)
if args.snapshot is not None:
    print('Loading model from {}...'.format(args.snapshot))
    cnn.load_state_dict(torch.load(args.snapshot))
pytorch_total_params = sum(p.numel() for p in cnn.parameters() if p.requires_grad)
print("Model parameters: " + str(pytorch_total_params))
if args.cuda:
    cnn = cnn.cuda()

if args.train:
    train(train_iter, test_iter, cnn, args)

if args.test:
    test(test_iter, cnn, args)

if args.predict:
    while True:
        text = input(">>")
        label = predict(text, cnn, text_field, label_field, args.cuda)
        print(str(label) + " | " + text)
```
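
One detail worth spelling out: `class_num = len(label_field.vocab) - 1` works because torchtext's `Field(sequential=False)` still reserves index 0 of the label vocabulary for `'<unk>'`, so real labels occupy indices 1..C and must be shifted down by one before `cross_entropy`, which is what `target.sub_(1)` in operation.py does. A small illustrative sketch, with an assumed label set of `{'0', '1'}`:

```python
# What label_field.vocab.itos resembles after build_vocab: '<unk>' sits at index 0.
label_vocab_itos = ['<unk>', '0', '1']
class_num = len(label_vocab_itos) - 1           # 2 real classes

batch_labels = [1, 2, 2, 1]                     # vocab indices produced by the iterator
targets = [i - 1 for i in batch_labels]         # shift into [0, class_num) for cross_entropy
assert all(0 <= t < class_num for t in targets)

# predict() reverses the shift when mapping a class index back to a label string:
predicted_index = 0                             # argmax over the model's class_num logits
print(label_vocab_itos[predicted_index + 1])    # -> '0'
```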

**model.py**

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCnn(nn.Module):
    def __init__(self, embed_num, embed_dim, class_num, kernel_num, kernel_sizes, dropout=0.5):
        super(TextCnn, self).__init__()

        Ci = 1           # input channels: one plane per sentence
        Co = kernel_num  # output channels per kernel size

        self.embed = nn.Embedding(embed_num, embed_dim)
        # Pad two rows of zeros above and below each sentence so inputs shorter
        # than the largest kernel size do not trigger a runtime error.
        self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (f, embed_dim), padding=(2, 0)) for f in kernel_sizes])

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(Co * len(kernel_sizes), class_num)

    def forward(self, x):
        x = self.embed(x)                                         # (N, token_num, embed_dim)
        x = x.unsqueeze(1)                                        # (N, Ci, token_num, embed_dim)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]  # [(N, Co, token_num') per kernel size]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]    # [(N, Co) per kernel size]
        x = torch.cat(x, 1)                                       # (N, Co * len(kernel_sizes))
        x = self.dropout(x)                                       # (N, Co * len(kernel_sizes))
        logit = self.fc(x)                                        # (N, class_num)
        return logit
```
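
As a quick sanity check (not part of the repo), one can push a dummy batch through `TextCnn` and confirm the output shape; the second batch is shorter than the largest kernel size, which is exactly the case the `padding=(2, 0)` above exists to handle. The vocabulary size and hyperparameters here are arbitrary:

```python
import torch
from model import TextCnn

model = TextCnn(embed_num=1000, embed_dim=128, class_num=2,
                kernel_num=10, kernel_sizes=[3, 4, 5])
model.eval()

long_batch = torch.randint(0, 1000, (4, 20))   # 4 sentences, 20 tokens each
short_batch = torch.randint(0, 1000, (4, 2))   # shorter than the size-5 kernel

with torch.no_grad():
    print(model(long_batch).shape)   # torch.Size([4, 2])
    print(model(short_batch).shape)  # torch.Size([4, 2]), thanks to the padding
```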

**operation.py**

```python
import os
import sys

import torch
import torch.nn.functional as F


def train(train_iter, dev_iter, model, args):
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    best_acc = 0
    last_step = 0

    for epoch in range(1, args.epochs + 1):
        for batch in train_iter:
            feature, target = batch.text, batch.label
            feature.t_(), target.sub_(1)  # transpose to batch-first; shift labels past '<unk>'
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()

            optimizer.zero_grad()
            logit = model(feature)

            loss = F.cross_entropy(logit, target)
            loss.backward()
            optimizer.step()

            steps += 1
            if steps % args.log_interval == 0:
                corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
                accuracy = 100.0 * float(corrects) / batch.batch_size
                sys.stdout.write('\rBatch[{}] - loss: {:.6f} acc: {:.3f}%({}/{})'.format(
                    steps, loss.data, accuracy, corrects, batch.batch_size))
            if steps % args.test_interval == 0:
                dev_acc = test(dev_iter, model, args)
                model.train()  # test() switched the model to eval mode; switch back
                if dev_acc > best_acc:
                    best_acc = dev_acc
                    last_step = steps
                    save(model, args.save_dir, 'best', steps)
            if steps % args.save_interval == 0:
                save(model, args.save_dir, 'snapshot', steps)


def test(data_iter, model, args):
    model.eval()
    corrects, avg_loss = 0, 0
    with torch.no_grad():
        for batch in data_iter:
            feature, target = batch.text, batch.label
            feature.t_(), target.sub_(1)  # transpose to batch-first; shift labels past '<unk>'
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()

            logit = model(feature)
            loss = F.cross_entropy(logit, target, reduction='sum')

            avg_loss += loss.data
            corrects += (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()

    size = len(data_iter.dataset)
    avg_loss /= size
    accuracy = 100.0 * float(corrects) / size
    print('Evaluation - loss: {:.6f} acc: {:.3f}% ({}/{}) \n'.format(avg_loss, accuracy, corrects, size))
    return accuracy


def predict(text, model, text_field, label_field, cuda_flag):
    assert isinstance(text, str)
    model.eval()

    text = text_field.preprocess(text)
    text = [[text_field.vocab.stoi[x] for x in text]]
    x = torch.tensor(text)
    if cuda_flag:
        x = x.cuda()

    with torch.no_grad():
        output = model(x)
    _, predicted = torch.max(output, 1)
    # index 0 of the label vocab is '<unk>', so shift the class index up by one
    return label_field.vocab.itos[predicted.item() + 1]


def save(model, save_dir, save_prefix, steps):
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, '{}_steps_{}.pt'.format(save_prefix, steps))
    torch.save(model.state_dict(), save_path)
```
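
A short usage sketch of the save/load round trip (the path and hyperparameters below are made up): `save()` writes only the `state_dict`, so restoring a snapshot, as `main.py` does with `-snapshot`, requires rebuilding `TextCnn` with the same constructor arguments first, in particular an `embed_num` matching the vocabulary the snapshot was trained with.

```python
import torch
from model import TextCnn
from operation import save

model = TextCnn(embed_num=1000, embed_dim=128, class_num=2,
                kernel_num=10, kernel_sizes=[3, 4, 5])
save(model, 'snapshot', 'best', 400)  # writes snapshot/best_steps_400.pt

# a state_dict holds only weights, so rebuild the same architecture first
restored = TextCnn(embed_num=1000, embed_dim=128, class_num=2,
                   kernel_num=10, kernel_sizes=[3, 4, 5])
restored.load_state_dict(torch.load('snapshot/best_steps_400.pt'))
restored.eval()
```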

**VS Code workspace file**

```json
{
    "folders": [
        {
            "path": "."
        }
    ],
    "settings": {}
}
```