from settings.var import *
import os
import pickle
import argparse
from copy import deepcopy

import numpy as np
import torch
import matplotlib.pyplot as plt

from training.utils import *

# parse the task name, used to locate the dataset and checkpoint directories
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='')
args = parser.parse_args()
task = args.task

# configs
task_cfg = TASK_CONFIG
train_cfg = TRAIN_CONFIG
policy_config = POLICY_CONFIG
checkpoint_dir = os.path.join(train_cfg['checkpoint_dir'], task)

# device
device = os.environ['DEVICE']


def forward_pass(data, policy):
    # unpack a batch and move every tensor to the target device before the policy call
    image_data, qpos_data, action_data, is_pad = data
    qpos_data = qpos_data.float()
    action_data = action_data.float()
    image_data, qpos_data, action_data, is_pad = (
        image_data.to(device), qpos_data.to(device), action_data.to(device), is_pad.to(device))
    return policy(qpos_data, image_data, action_data, is_pad)


def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    # save training curves, one figure per logged metric
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png')
        plt.figure()
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        plt.plot(np.linspace(0, num_epochs - 1, len(train_history)), train_values, label='train')
        plt.plot(np.linspace(0, num_epochs - 1, len(validation_history)), val_values, label='validation')
        # plt.ylim([-0.1, 1])
        plt.tight_layout()
        plt.legend()
        plt.title(key)
        plt.savefig(plot_path)
        plt.close()  # free the figure so repeated calls do not accumulate memory
    print(f'Saved plots to {ckpt_dir}')


def train_bc(train_dataloader, val_dataloader, policy_config):
    # load policy
    policy = make_policy(policy_config['policy_class'], policy_config)
    policy.to(device)

    # load optimizer
    optimizer = make_optimizer(policy_config['policy_class'], policy)

    # create checkpoint dir if it does not exist
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_history = []
    validation_history = []
    min_val_loss = np.inf
    best_ckpt_info = None
    for epoch in range(train_cfg['num_epochs']):
        print(f'\nEpoch {epoch}')

        # validation
        with torch.inference_mode():
            policy.eval()
            epoch_dicts = []
            for batch_idx, data in enumerate(val_dataloader):
                forward_dict = forward_pass(data, policy)
                epoch_dicts.append(forward_dict)
            epoch_summary = compute_dict_mean(epoch_dicts)
            validation_history.append(epoch_summary)

            epoch_val_loss = epoch_summary['loss']
            if epoch_val_loss < min_val_loss:
                min_val_loss = epoch_val_loss
                best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))
        print(f'Val loss: {epoch_val_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        # training
        policy.train()
        optimizer.zero_grad()
        for batch_idx, data in enumerate(train_dataloader):
            forward_dict = forward_pass(data, policy)
            # backward
            loss = forward_dict['loss']
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_history.append(detach_dict(forward_dict))
        # average the stats logged for this epoch's slice of train_history
        num_batches = batch_idx + 1
        epoch_summary = compute_dict_mean(train_history[num_batches * epoch:num_batches * (epoch + 1)])
        epoch_train_loss = epoch_summary['loss']
        print(f'Train loss: {epoch_train_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        # periodically checkpoint and plot the curves so far
        if epoch % 200 == 0:
            ckpt_path = os.path.join(checkpoint_dir, f"policy_epoch_{epoch}_seed_{train_cfg['seed']}.ckpt")
            torch.save(policy.state_dict(), ckpt_path)
            plot_history(train_history, validation_history, epoch, checkpoint_dir, train_cfg['seed'])

    # save the final policy and the checkpoint with the lowest validation loss
    ckpt_path = os.path.join(checkpoint_dir, 'policy_last.ckpt')
    torch.save(policy.state_dict(), ckpt_path)
    if best_ckpt_info is not None:
        best_epoch, best_val_loss, best_state_dict = best_ckpt_info
        torch.save(best_state_dict, os.path.join(checkpoint_dir, 'policy_best.ckpt'))
        print(f'Best val loss {best_val_loss:.5f} at epoch {best_epoch}')


if __name__ == '__main__':
    # set seed
    set_seed(train_cfg['seed'])

    # create ckpt dir if it does not exist
    os.makedirs(checkpoint_dir, exist_ok=True)

    # number of training episodes
    data_dir = os.path.join(task_cfg['dataset_dir'], task)
    num_episodes = len(os.listdir(data_dir))

    # load data
    train_dataloader, val_dataloader, stats, _ = load_data(data_dir, num_episodes, task_cfg['camera_names'],
                                                           train_cfg['batch_size_train'], train_cfg['batch_size_val'])

    # save dataset stats so inference can apply the same normalization
    stats_path = os.path.join(checkpoint_dir, 'dataset_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(stats, f)

    # train
    train_bc(train_dataloader, val_dataloader, policy_config)

'''
The data loader is important because it is where our input data enters the pipeline.
A data loader acts like an iterator (or sampler) that yields batches of the data.
load_data returns the training data, the validation data, and statistics that are
used to normalize the data.
'''
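
# A minimal sketch of how these loaders are typically consumed, assuming each
# batch is the (image_data, qpos_data, action_data, is_pad) tuple that
# forward_pass unpacks above (the exact batch layout depends on load_data in
# training.utils). Kept commented out so it does not run with the script:
#
#   for image_data, qpos_data, action_data, is_pad in train_dataloader:
#       # image_data: camera observations, qpos_data: joint positions,
#       # action_data: target action sequence, is_pad: padding mask
#       forward_dict = forward_pass((image_data, qpos_data, action_data, is_pad), policy)
#       break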