"""
Utilities for data processing.
"""
import tensorflow as tf
import numpy as np
import os
"""
File formatting note.
Data should be preprocessed as a sequence of comma-seperated ints with
sequences /n seperated
"""
# Lookup tables
aa_to_int = {
    'M': 1,
    'R': 2,
    'H': 3,
    'K': 4,
    'D': 5,
    'E': 6,
    'S': 7,
    'T': 8,
    'N': 9,
    'Q': 10,
    'C': 11,
    'U': 12,
    'G': 13,
    'P': 14,
    'A': 15,
    'V': 16,
    'I': 17,
    'F': 18,
    'Y': 19,
    'W': 20,
    'L': 21,
    'O': 22,  # Pyrrolysine
    'X': 23,  # Unknown
    'Z': 23,  # Glutamic acid or glutamine
    'B': 23,  # Asparagine or aspartic acid
    'J': 23,  # Leucine or isoleucine
    'start': 24,
    'stop': 25,
}
int_to_aa = {value: key for key, value in aa_to_int.items()}


def get_aa_to_int():
    """
    Get the aa_to_int lookup table (for easy import).
    """
    return aa_to_int


def get_int_to_aa():
    """
    Get the int_to_aa lookup table (for easy import).
    """
    return int_to_aa

# Helper functions
def aa_seq_to_int(s):
    """
    Return the int sequence as a list for a given string of amino acids,
    bracketed by the start (24) and stop (25) tokens.
    """
    return [24] + [aa_to_int[a] for a in s] + [25]

def int_seq_to_aa(s):
    """
    Return the amino acid string for a given int sequence.
    """
    return "".join([int_to_aa[i] for i in s])

def tf_str_len(s):
    """
    Returns the length of tf.string s.
    """
    # Splitting on the empty delimiter yields one token per character.
    return tf.size(tf.string_split([s], ""))

def tf_rank1_tensor_len(t):
    """
    Returns the length of a rank 1 tensor t as a rank 0 int32.
    Only non-zero entries are counted, so zero padding is ignored.
    """
    l = tf.reduce_sum(tf.sign(tf.abs(t)), 0)
    return tf.cast(l, tf.int32)

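# Illustrative: for a zero-padded row [24, 1, 4, 16, 25, 0, 0],
# tf_rank1_tensor_len evaluates to 5.
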
def tf_seq_to_tensor(s):
    """
    Input a tf.string of comma-separated integers.
    Returns a rank 1 tensor the length of the input sequence, of type int32.
    """
    return tf.string_to_number(
        tf.sparse_tensor_to_dense(tf.string_split([s], ","), default_value='0'),
        out_type=tf.int32
    )[0]

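# Illustrative (under a TF 1.x graph/session): tf_seq_to_tensor applied to
# tf.constant("24,1,4,16,25") evaluates to the int32 tensor [24, 1, 4, 16, 25].
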
def smart_length(length, bucket_bounds=tf.constant([128, 256])):
    """
    Hash the given length into the windows given by bucket_bounds.
    """
    # Subtract length so that smaller bins are negative, then take the sign.
    # Eg: length 129 with bounds [128, 256] gives sign [-1, 1].
    signed = tf.sign(bucket_bounds - length)
    # Now make 1 everywhere that length reaches or exceeds the bound, else 0.
    greater = tf.sign(tf.abs(signed - tf.constant(1)))
    # Now simply sum to count the number of bounds at or below length.
    key = tf.cast(tf.reduce_sum(greater), tf.int64)
    # This will be between 0 and len(bucket_bounds).
    return key

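# Worked examples with the default bounds [128, 256]:
#   length 100 -> signed [ 1,  1] -> greater [0, 0] -> key 0
#   length 129 -> signed [-1,  1] -> greater [1, 0] -> key 1
#   length 300 -> signed [-1, -1] -> greater [1, 1] -> key 2
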
def pad_batch(ds, batch_size, padding=None, padded_shapes=([None])):
    """
    Helper for bucketbatchpad; pads with zeros by default.
    """
    return ds.padded_batch(batch_size,
                           padded_shapes=padded_shapes,
                           padding_values=padding)

def aas_to_int_seq(aa_seq):
    """
    Convert an amino acid string to a comma-separated int string,
    bracketed by the start and stop tokens.
    """
    int_seq = ""
    for aa in aa_seq:
        int_seq += str(aa_to_int[aa]) + ","
    return str(aa_to_int['start']) + "," + int_seq + str(aa_to_int['stop'])

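# Illustrative: aas_to_int_seq("MKV") returns "24,1,4,16,25".
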
# Preprocessing in Python
def fasta_to_input_format(source, destination):
    # I don't know exactly how to do this in tf, so resorting to python.
    # Goes line by line so the whole file is never loaded into memory.
    sourcefile = os.path.join(source)
    destination = os.path.join(destination)
    with open(sourcefile, 'r') as f:
        with open(destination, 'w') as dest:
            seq = ""
            for line in f:
                if line[0] == '>' and not seq == "":
                    dest.write(aas_to_int_seq(seq) + '\n')
                    seq = ""
                elif not line[0] == '>':
                    seq += line.replace("\n", "")
            # Flush the final record, which has no following '>' header.
            if seq != "":
                dest.write(aas_to_int_seq(seq) + '\n')

# Real data pipelines
def bucketbatchpad(
        batch_size=256,
        path_to_data=os.path.join("./data/SwissProt/sprot_ints.fasta"),  # Preprocessed- see note above
        compressed="",  # See tf.contrib.data.TextLineDataset init args
        bounds=[128, 256],  # Default buckets of <128, 128-256, >256
        window_size=256,  # Elements gathered per bucket before batching; should probably equal batch_size. NOT a tensor.
        padding=None,  # Use default padding of zero, otherwise see Dataset docs
        shuffle_buffer=None,  # None or the size of the buffer to shuffle with
        pad_shape=([None]),
        repeat=1,
        filt=None
):
    """
    Streams data from path_to_data that is correctly preprocessed.
    Divides it into buckets given by bounds and pads to full length.
    Returns a dataset which will yield a padded batch of batch_size
    on iteration.
    """
    batch_size = tf.constant(batch_size, tf.int64)
    bounds = tf.constant(bounds)
    window_size = tf.constant(window_size, tf.int64)
    path_to_data = os.path.join(path_to_data)
    # Parse strings to tensors
    dataset = tf.contrib.data.TextLineDataset(path_to_data).map(tf_seq_to_tensor)
    if filt is not None:
        dataset = dataset.filter(filt)
    if shuffle_buffer:
        # Stream elements uniformly randomly from a buffer
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    # Apply a repeat. Because this is after the shuffle, all elements of the
    # dataset should be seen before the repeat.
    # See https://stackoverflow.com/questions/44132307/tf-contrib-data-dataset-repeat-with-shuffle-notice-epoch-end-mixed-epochs
    dataset = dataset.repeat(count=repeat)
    # Apply grouping to bucket and pad
    grouped_dataset = dataset.group_by_window(
        key_func=lambda seq: smart_length(tf_rank1_tensor_len(seq),
                                          bucket_bounds=bounds),  # choose a bucket
        reduce_func=lambda key, ds: pad_batch(ds, batch_size,
                                              padding=padding,
                                              padded_shapes=pad_shape),  # pad within the bucket
        window_size=window_size)
    return grouped_dataset

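# Usage sketch (an assumption-laden example, not part of the original file:
# it presumes a TF 1.x contrib-era runtime and that the default preprocessed
# file exists at ./data/SwissProt/sprot_ints.fasta):
#   batched = bucketbatchpad(batch_size=32, shuffle_buffer=10000)
#   next_batch = batched.make_one_shot_iterator().get_next()
#   with tf.Session() as sess:
#       padded = sess.run(next_batch)  # shape: (<=32, longest_seq_in_window)
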
def shufflebatch(
        batch_size=256,
        shuffle_buffer=None,
        repeat=1,
        path_to_data="./data/SwissProt/sprot_ints.fasta"
):
    """
    Draws from an (optionally shuffled) dataset, repeats dataset repeat times,
    and serves batches of the specified size.
    """
    path_to_data = os.path.join(path_to_data)
    # Parse strings to tensors
    dataset = tf.contrib.data.TextLineDataset(path_to_data).map(tf_seq_to_tensor)
    if shuffle_buffer:
        # Stream elements uniformly randomly from a buffer
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    # Apply a repeat. Because this is after the shuffle, all elements of the
    # dataset should be seen before the repeat.
    # See https://stackoverflow.com/questions/44132307/tf-contrib-data-dataset-repeat-with-shuffle-notice-epoch-end-mixed-epochs
    dataset = dataset.repeat(count=repeat)
    dataset = dataset.batch(batch_size)
    return dataset
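# Note (an observation about plain Dataset.batch, not stated in the original):
# unlike bucketbatchpad, this applies no padding, so batching will only
# succeed if every sequence in the file already has the same length.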