-
Notifications
You must be signed in to change notification settings - Fork 0
/
Seq_deepCpf1_torch.py
168 lines (139 loc) · 5.59 KB
/
Seq_deepCpf1_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Reference: https://github.com/MyungjaeSong/Paired-Library
# PyTorch implementation
import argparse
import os
# Ignore warnings
import warnings
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
from deepcpf1_network import SequenceDataset
import deepcpf1_network
def _build_arg_parser():
    """Create the command-line parser for the Seq-deepCpf1 train/test script."""
    parser = argparse.ArgumentParser("DeepCpf1 meets PyTorch")
    parser.add_argument("--train", type=str, default="./data/train.csv")
    parser.add_argument("--test", type=str, default="./data/test.csv")
    # NOTE: ``store_true`` combined with ``default=True`` means this option is
    # always True; the definition is kept unchanged for CLI compatibility.
    parser.add_argument("--load_weights", action="store_true", default=True)
    parser.add_argument("--model_path", type=str, default="weights")
    parser.add_argument(
        "--mps", action="store_true", default=False, help="Apple Silicon MPS"
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=1,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument("--training_mode", action="store_true", help="training")
    parser.add_argument(
        "--sequence_length", type=int, default=34, help="target sequence length"
    )
    parser.add_argument("--kernel_size", type=int, default=5, help="kernel size")
    parser.add_argument("--pool_size", type=int, default=2, help="pooling filter size")
    parser.add_argument(
        "--max_epoch", type=int, default=500, help="maximum training epoch"
    )
    # NOTE: same as --load_weights, this option is always True as declared.
    parser.add_argument(
        "--save_model",
        action="store_true",
        default=True,
        help="Saving network state_dict",
    )
    parser.add_argument(
        "--save_freq", type=int, default=10, help="Model save frequency"
    )
    parser.add_argument("--alias", type=str, default="corr_testing")
    return parser


def _select_device(args):
    """Return the MPS device if requested and available, else CUDA, else CPU."""
    if args.mps and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _natural_sort_key(path):
    """Sort key ordering integers embedded in a file name numerically (2 < 10)."""
    import re  # local import: ``re`` is not imported at the top of this file

    # re.split with a capturing group alternates text / digit-run tokens, so
    # odd positions are always digit runs; keys stay comparable element-wise.
    tokens = re.split(r"(\d+)", os.path.basename(path))
    return [int(tok) if pos % 2 else tok for pos, tok in enumerate(tokens)]


def _plot_loss(x_values, losses, fmt, label, xlabel="Epochs"):
    """Draw a single loss curve on the current matplotlib figure and show it."""
    plt.plot(x_values, losses, fmt, label=label)
    plt.title("Loss")
    plt.xlabel(xlabel)
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


def main():
    """Train Seq-deepCpf1 (optional) and evaluate saved checkpoints on the test set.

    Steps, driven by the command-line arguments:

    1. Print usage/background information and parse the arguments.
    2. Build the network and an Adam optimizer on the selected device.
    3. With ``--training_mode``: train for ``--max_epoch`` epochs and plot the
       per-epoch value returned by ``deepcpf1_network.train``.
    4. Collect every ``*.pt`` file in ``--model_path``, load each one into the
       network, run ``deepcpf1_network.test`` on the test set and plot the
       resulting values.

    Raises:
        FileNotFoundError: if ``--model_path`` does not exist (from
            ``os.listdir``).
    """
    # The usage line lists only options this parser actually defines; the
    # previous text advertised an "output" argument that argparse rejects.
    # TODO: confirm the input-format notes below still match the CSV files.
    print(
        "Usage: python Seq_deepCpf1_torch.py --train ./data/train.csv --test ./data/test.csv [--training_mode]"
    )
    print("input.txt must include 3 columns with single header row")
    print("\t1st column: sequence index")
    print("\t2nd column: 34bp target sequence")
    print("\t3rd column: binary chromain information of the target sequence\n")
    print("DeepCpf1 currently requires python=3.9, PyTorch=1.12")
    print(
        "DeepCpf1 available on GitHub requires pre-obtained binary chromatin information (DNase-seq narraow peak data from ENCODE)"
    )
    print(
        "DeepCpf1 web tool, available at http://data.snu.ac.kr/DeepCpf1, provides entire pipeline including binary chromatin accessibility for 125 cell lines\n"
    )

    args = _build_arg_parser().parse_args()

    # Device configuration and reproducibility
    device = _select_device(args)
    torch.manual_seed(args.seed)

    train_kwargs = {"batch_size": int(1e4), "shuffle": True}
    test_kwargs = {"batch_size": 50, "shuffle": True}

    seq_deep_cpf1 = deepcpf1_network.SeqDeepCpf1Net(args).to(device)
    opt_seq_deep_cpf1 = optim.Adam(seq_deep_cpf1.parameters(), lr=0.001)

    if args.training_mode:
        # -----------------------------------------------------
        # Seq-deepCpf1 PreTrain
        # -----------------------------------------------------
        print("Training Seq-deepCpf1")
        print("Loading train data")
        training_data = SequenceDataset(csv_file=args.train, args=args)
        train_dataloader = DataLoader(training_data, **train_kwargs)

        train_history = []
        for epoch in range(1, args.max_epoch + 1):
            loss = deepcpf1_network.train(
                args, seq_deep_cpf1, device, train_dataloader, opt_seq_deep_cpf1, epoch
            )
            train_history.append(loss)

        _plot_loss(
            list(range(1, len(train_history) + 1)),
            train_history,
            "g",
            "Training loss",
        )

    model_state_paths = []
    if args.load_weights:
        print("Listing weights for the models")
        # os.listdir returns entries in arbitrary order; sort so the loss
        # curve below follows a deterministic, numerically ordered sequence.
        model_state_paths = sorted(
            (
                os.path.join(args.model_path, name)
                for name in os.listdir(args.model_path)
                if name.endswith(".pt")
            ),
            key=_natural_sort_key,
        )

    seq_deep_cpf1.eval()

    print("Loading test data")
    testing_data = SequenceDataset(csv_file=args.test, args=args)
    test_dataloader = DataLoader(testing_data, **test_kwargs)

    test_history = []
    total = len(model_state_paths)
    for idx, model_path in enumerate(model_state_paths, start=1):
        # Deserialize onto the CPU so checkpoints saved on another device
        # (e.g. CUDA) still load here; load_state_dict then copies the values
        # into the model's existing parameters on `device`.
        state_dict = torch.load(model_path, map_location="cpu")
        seq_deep_cpf1.load_state_dict(state_dict)
        print(f"Predicting on test data: {idx}/{total}")
        # Per the original author's note, test() returns the average loss.
        test_history.append(
            deepcpf1_network.test(seq_deep_cpf1, device, test_dataloader)
        )

    if not test_history:
        print(f"No .pt checkpoints found in {args.model_path}; nothing to plot")
        return

    # Plot the test loss over the evaluated checkpoints. The epoch axis is only
    # valid when the checkpoint count matches the max_epoch/save_freq schedule;
    # otherwise fall back to a plain checkpoint index instead of letting
    # plt.plot fail on mismatched x/y lengths.
    x_values = list(range(1, args.max_epoch + 1, args.save_freq))
    xlabel = "Epochs"
    if len(x_values) != len(test_history):
        x_values = list(range(1, len(test_history) + 1))
        xlabel = "Checkpoint"
    _plot_loss(x_values, test_history, "b", "Testing loss", xlabel=xlabel)
# Script entry point: run the train/test pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()