forked from aceimnorstuvwxz/chatbot-zh-torch7
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.lua
executable file
·125 lines (103 loc) · 3.67 KB
/
train.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
require 'neuralconvo'
require 'xlua'
-- Command-line interface for the training run.
-- Each entry is { flag, default value, help text }. Entries are registered
-- with the parser in the order they are listed here.
local optionSpecs = {
  { '--dataset', 0, 'approximate size of dataset to use (0 = all)' },
  { '--minWordFreq', 1, 'minimum frequency of words kept in vocab' },
  { '--cuda', false, 'use CUDA' },
  { '--opencl', false, 'use opencl' },
  { '--hiddenSize', 300, 'number of hidden units in LSTM' },
  { '--learningRate', 0.05, 'learning rate at t=0' },
  { '--momentum', 0.9, 'momentum' },
  { '--minLR', 0.00001, 'minimum learning rate' },
  { '--saturateEpoch', 20, 'epoch at which linear decayed LR will reach minLR' },
  { '--maxEpoch', 30, 'maximum number of epochs to run' },
  { '--batchSize', 10, 'number of examples to load at once' },
}

-- `cmd` and `options` are intentionally left global, as in the original script.
cmd = torch.CmdLine()
cmd:text('Options:')
for _, spec in ipairs(optionSpecs) do
  cmd:option(spec[1], spec[2], spec[3])
end
cmd:text()

options = cmd:parse(arg)

-- A requested size of 0 means "all" (see the --dataset help text); it is
-- turned into nil before being handed on as the dataset's `loadFirst` setting.
if options.dataset == 0 then
  options.dataset = nil
end
-- Data
-- Build the training dataset from the Cornell Movie Dialogs corpus directory.
print("-- Loading dataset")

local corpusLoader = neuralconvo.CornellMovieDialogs("data/cornell_movie_dialogs")
local datasetConfig = {
  -- nil when --dataset was 0; presumably treated as "no limit" by DataSet
  -- (TODO confirm in neuralconvo.DataSet).
  loadFirst = options.dataset,
  minWordFreq = options.minWordFreq,
}

-- `dataset` is deliberately global: code loaded later in this script
-- (e.g. via require) may look it up by name.
dataset = neuralconvo.DataSet(corpusLoader, datasetConfig)

print("\nDataset stats:")
print(" Vocabulary size: " .. dataset.wordsCount)
print(" Examples: " .. dataset.examplesCount)
-- Model
-- Sequence-to-sequence network sized by the vocabulary and --hiddenSize.
-- `model` is deliberately global: code loaded later in this script may look
-- it up by name.
model = neuralconvo.Seq2Seq(dataset.wordsCount, options.hiddenSize)
model.goToken = dataset.goToken
model.eosToken = dataset.eosToken

-- Training parameters
-- The per-step loss is ClassNLL. When training on batches of more than one
-- example it is wrapped in MaskZeroCriterion -- presumably so that zero-padded
-- positions do not contribute to the loss (TODO confirm against the rnn docs).
local stepCriterion = nn.ClassNLLCriterion()
if options.batchSize > 1 then
  stepCriterion = nn.MaskZeroCriterion(stepCriterion, 1)
end
model.criterion = nn.SequencerCriterion(stepCriterion)

model.learningRate = options.learningRate
model.momentum = options.momentum

-- Amount added to the learning rate after every epoch. It is negative when
-- minLR < learningRate, giving a linear decay that reaches minLR after
-- `saturateEpoch` epochs.
local decayFactor = (options.minLR - options.learningRate) / options.saturateEpoch

-- Lowest mean epoch error seen so far; nil until the first epoch completes.
local minMeanError
-- Enable a GPU backend (CUDA takes precedence over OpenCL) when requested.

-- Loads the backend packages in order, then converts the global model.
-- The conversion runs only after the packages are loaded, matching the
-- require-then-convert order of a hand-written branch.
local function enableBackend(packages, convertModel)
  for _, packageName in ipairs(packages) do
    require(packageName)
  end
  convertModel(model)
end

if options.cuda then
  enableBackend({ 'cutorch', 'cunn' }, function(m) m:cuda() end)
elseif options.opencl then
  enableBackend({ 'cltorch', 'clnn' }, function(m) m:cl() end)
end
-- Run the experiment
print("dgk ending")

for epoch = 1, options.maxEpoch do
  print("\n-- Epoch " .. epoch .. " / " .. options.maxEpoch)
  print("")

  -- model:train is called once per batch, so exactly one loss value is
  -- recorded per batch. The values are gathered in a table rather than in a
  -- tensor pre-sized to dataset.examplesCount: with batchSize > 1 such a
  -- tensor would keep unused zero-filled slots, forcing min to 0 and
  -- diluting mean/median/std.
  local batchErrors = {}
  local timer = torch.Timer()

  for encInputs, decInputs, decTargets in dataset:batches(options.batchSize) do
    collectgarbage()

    -- Move the batch to the same backend the model was converted to.
    if options.cuda then
      encInputs = encInputs:cuda()
      decInputs = decInputs:cuda()
      decTargets = decTargets:cuda()
    elseif options.opencl then
      encInputs = encInputs:cl()
      decInputs = decInputs:cl()
      decTargets = decTargets:cl()
    end

    local err = model:train(encInputs, decInputs, decTargets)

    -- Check if error is NaN (NaN is the only value not equal to itself).
    -- If so, it's probably a bug.
    if err ~= err then
      error("Invalid error! Exiting.")
    end

    batchErrors[#batchErrors + 1] = err

    -- The last batch may be partial, so clamp the processed-example estimate
    -- to keep the reported progress from exceeding the total.
    local processed = math.min(#batchErrors * options.batchSize, dataset.examplesCount)
    xlua.progress(processed, dataset.examplesCount)
  end

  timer:stop()
  local elapsed = timer:time().real

  print("\nFinished in " .. xlua.formatTime(elapsed) .. " " .. (dataset.examplesCount / elapsed) .. ' examples/sec.')

  -- Statistics over an empty tensor are undefined; fail with a clear message
  -- instead of an obscure tensor error.
  if #batchErrors == 0 then
    error("No batches were produced by the dataset; nothing to train on.")
  end

  local errors = torch.Tensor(batchErrors)
  local meanError = errors:mean()

  print("\nEpoch stats:")
  print(" LR= " .. model.learningRate)
  print(" Errors: min= " .. errors:min())
  print(" max= " .. errors:max())
  print(" median= " .. errors:median()[1])
  print(" mean= " .. meanError)
  print(" std= " .. errors:std())

  -- Save the model if it improved.
  if minMeanError == nil or meanError < minMeanError then
    print("\n(Saving model ...)")
    torch.save("data/model.t7", model)
    minMeanError = meanError
  end

  -- Apply the per-epoch learning-rate step, never going below minLR.
  model.learningRate = math.max(options.minLR, model.learningRate + decayFactor)
end
-- Training is done: load the evaluation script.
-- NOTE(review): eval presumably picks up the `model` / `dataset` globals
-- defined by this script -- confirm in eval.lua.
require('eval')