dqn.hpp (forked from mhauskn/dqn)

#ifndef DQN_HPP_
#define DQN_HPP_

#include <array>
#include <deque>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ale_interface.hpp>
#include <caffe/caffe.hpp>
#include <boost/functional/hash.hpp>
#include <boost/optional.hpp>

namespace dqn {
constexpr auto kRawFrameHeight = 210;
constexpr auto kRawFrameWidth = 160;
constexpr auto kCroppedFrameSize = 84;
constexpr auto kCroppedFrameDataSize = kCroppedFrameSize * kCroppedFrameSize;
constexpr auto kInputFrameCount = 4;
constexpr auto kInputDataSize = kCroppedFrameDataSize * kInputFrameCount;
constexpr auto kMinibatchSize = 32;
constexpr auto kMinibatchDataSize = kInputDataSize * kMinibatchSize;
constexpr auto kOutputCount = 18;
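
// These sizes follow the DeepMind DQN setup: each raw 210x160 ALE screen
// is preprocessed to an 84x84 grayscale frame, and the network input is a
// stack of the kInputFrameCount most recent frames. kOutputCount is 18
// because the full ALE action set has 18 actions; the network emits one
// Q-value per action.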
using FrameData = std::array<uint8_t, kCroppedFrameDataSize>;
using FrameDataSp = std::shared_ptr<FrameData>;
using InputFrames = std::array<FrameDataSp, kInputFrameCount>;
using Transition = std::tuple<InputFrames, Action, float,
                              boost::optional<FrameDataSp>>;
using FramesLayerInputData = std::array<float, kMinibatchDataSize>;
using TargetLayerInputData = std::array<float, kMinibatchSize * kOutputCount>;
using FilterLayerInputData = std::array<float, kMinibatchSize * kOutputCount>;
using ActionValue = std::pair<Action, float>;
using SolverSp = std::shared_ptr<caffe::Solver<float>>;
using NetSp = boost::shared_ptr<caffe::Net<float>>;
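
// A Transition is one replay-memory entry: (state frames, action, reward,
// next frame). The optional next frame is boost::none when the episode
// terminated on this step. A sketch with hypothetical FrameDataSp
// pointers f0..f4:
//
//   InputFrames state{{f0, f1, f2, f3}};
//   Transition alive(state, PLAYER_A_FIRE, 1.0f, f4);
//   Transition terminal(state, PLAYER_A_FIRE, -1.0f, boost::none);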

/**
 * Deep Q-Network
 */
class DQN {
public:
  DQN(const ActionVect& legal_actions,
      const caffe::SolverParameter& solver_param,
      const int replay_memory_capacity,
      const double gamma,
      const int clone_frequency) :
        legal_actions_(legal_actions),
        solver_param_(solver_param),
        replay_memory_capacity_(replay_memory_capacity),
        gamma_(gamma),
        clone_frequency_(clone_frequency),
        random_engine(0) {}

  // Initialize DQN. Must be called before any other method.
  void Initialize();

  // Load a trained model from a file.
  void LoadTrainedModel(const std::string& model_file);

  // Restore training state from a solver snapshot file.
  void RestoreSolver(const std::string& solver_file);

  // Snapshot the current model.
  void Snapshot() { solver_->Snapshot(); }

  // Select an action with epsilon-greedy exploration.
  Action SelectAction(const InputFrames& input_frames, double epsilon);

  // Select a batch of actions with epsilon-greedy exploration.
  ActionVect SelectActions(const std::vector<InputFrames>& frames_batch,
                           double epsilon);

  // Add a transition to the replay memory.
  void AddTransition(const Transition& transition);

  // Update the DQN using one minibatch sampled from the replay memory.
  void Update();

  // Clear the replay memory.
  void ClearReplayMemory() { replay_memory_.clear(); }

  // Get the current size of the replay memory.
  int memory_size() const { return replay_memory_.size(); }

  // Return the current iteration of the solver.
  int current_iteration() const { return solver_->iter(); }
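
  // A minimal usage sketch (hypothetical locals and parameter values;
  // assumes an ALEInterface `ale`, a parsed caffe::SolverParameter, and
  // an epsilon schedule):
  //
  //   DQN dqn(ale.getMinimalActionSet(), solver_param,
  //           /*replay_memory_capacity=*/400000, /*gamma=*/0.95,
  //           /*clone_frequency=*/10000);
  //   dqn.Initialize();
  //   const Action a = dqn.SelectAction(last_four_frames, epsilon);
  //   const float reward = ale.act(a);
  //   dqn.AddTransition(Transition(last_four_frames, a, reward, next_frame));
  //   if (dqn.memory_size() >= kMinibatchSize) dqn.Update();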

protected:
  // Clone the primary network and store the result in clone_net_.
  void ClonePrimaryNet();

  // Given a set of input frames and a network, select an action.
  // Returns the action and its estimated Q-value.
  ActionValue SelectActionGreedily(caffe::Net<float>& net,
                                   const InputFrames& last_frames);

  // Given a batch of input frames, return a batch of selected actions
  // and their estimated Q-values.
  std::vector<ActionValue> SelectActionGreedily(
      caffe::Net<float>& net,
      const std::vector<InputFrames>& last_frames);

  // Input data into the Frames/Target/Filter layers of the given net.
  // This must be done before Forward is called.
  void InputDataIntoLayers(caffe::Net<float>& net,
                           const FramesLayerInputData& frames_data,
                           const TargetLayerInputData& target_data,
                           const FilterLayerInputData& filter_data);
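
  // Update() is expected to fill these inputs following the standard
  // Q-learning loss: for each sampled transition (s, a, r, s'), the
  // target at the taken action is r + gamma_ * max_a' Q_clone(s', a')
  // (just r when s' is terminal), and the filter is 1 at the taken
  // action and 0 elsewhere, so untaken actions receive no gradient.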

protected:
  const ActionVect legal_actions_;
  const caffe::SolverParameter solver_param_;
  const int replay_memory_capacity_;
  const double gamma_;
  const int clone_frequency_;  // How often (in steps) clone_net_ is updated.
  std::deque<Transition> replay_memory_;
  SolverSp solver_;
  NetSp net_;        // The primary network, used for action selection.
  NetSp clone_net_;  // Clone of the primary net, used to generate targets.
  TargetLayerInputData dummy_input_data_;
  std::mt19937 random_engine;
};

/**
 * Preprocess an ALE screen: convert the raw 210x160 frame to grayscale
 * and downsample it to kCroppedFrameSize x kCroppedFrameSize.
 */
FrameDataSp PreprocessScreen(const ALEScreen& raw_screen);
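
// Illustrative call (assuming an ALEInterface named `ale`):
//
//   FrameDataSp frame = dqn::PreprocessScreen(ale.getScreen());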

}  // namespace dqn

#endif /* DQN_HPP_ */