# samogitia.py (forked from evogytis/baltic)
import argparse
import re
import datetime as dt
import baltic as bt
import sys
import collections
import numpy as np
def overlap(a,b):
    """
    Compare two lists as multisets, i.e. repeated elements are counted.

    The result is a tuple of three lists:
    [elements present in both],[elements left over in a],[elements left over in b]
    """
    left=collections.Counter(a)
    right=collections.Counter(b)
    shared=left&right ## intersection keeps the smaller count of every element
    only_left=left-right ## subtraction discards elements whose count drops to zero or below
    only_right=right-left
    return list(shared.elements()),list(only_left.elements()),list(only_right.elements())
############ arguments
## Command line interface of the script. Every option is parsed once at start-up and unpacked into module level names further down.
samogitia = argparse.ArgumentParser(description="samogitia.py analyses trees drawn from the posterior distribution by BEAST.\n")
samogitia.add_argument('-b','--burnin', default=0, type=int, help="Number of states to remove as burnin (default 0).\n")
## store_false: the stored value stays True unless the flag is given, so it is read back below as "calibration"
samogitia.add_argument('-nc','--nocalibration', default=True, action='store_false', help="Use flag to prevent calibration of trees into absolute time (default True). Should be used if tip names do not contain information about *when* each sequence was collected.\n")
samogitia.add_argument('-t','--treefile', type=open, required=True, help="File with trees sampled from the posterior distribution (usually with suffix .trees).\n")
samogitia.add_argument('-a','--analyses', type=str, required=True, nargs='+', help="Analysis to be performed, can be a list separated by spaces.\n")
samogitia.add_argument('-o','--output', type=argparse.FileType('w'), default='samogitia.out.txt', help="Output file name (default samogitia.out.txt).\n")
samogitia.add_argument('-s','--states', type=str, default='0-inf', help="Define range of states for analysis.\n")
## argparse %-formats help strings, hence the doubled percent signs
samogitia.add_argument('-df','--date_format', type=str, default='%Y-%m-%d', help="Define date format encoded in tips (default '%%Y-%%m-%%d').\n")
## the regex default is a raw string and the help text escapes its backslashes, so neither relies on invalid escape sequences
samogitia.add_argument('-tf','--tip_format', type=str, default=r'\|([0-9]+)\-*([0-9]+)*\-*([0-9]+)*$', help="Define regex for capturing dates encoded in tips (default '\\|([0-9]+)\\-*([0-9]+)*\\-*([0-9]+)*$').\n")

args = vars(samogitia.parse_args())
burnin, treefile, analyses, outfile, calibration, states, dformat, tformat = args['burnin'], args['treefile'], args['analyses'], args['output'], args['nocalibration'], args['states'], args['date_format'], args['tip_format']
## unpack the requested range of MCMC states, given as "lower-upper" where upper may be the literal "inf"
lower,upper=states.split('-')
lower=int(lower)
upper=np.inf if upper=='inf' else int(upper) ## no upper bound unless a number was given
## print the banner to stderr if it can be read - it is purely cosmetic, so a missing or unreadable file is skipped silently
try:
    with open('banner_samogitia.txt','r') as banner: ## context manager closes the file even if writing fails
        for line in banner:
            sys.stderr.write('%s'%(line))
except (EnvironmentError,UnicodeError): ## only I/O and decoding problems are ignored, anything else should surface
    pass
## keeps track of things in the tree file at the beginning
plate=True ## True until the first 'tree STATE_' line is seen, i.e. while the lines before the trees are being read
taxonlist=False ## switched to True once a line containing 'Translate' has been seen
treecount=0 ## counts trees; the output header is written while this is still 0
##
tips={} ## remembers tip encodings: number used in the tree string -> tip name from the Translate block
############################# progress bar stuff
Ntrees=10000 ## assume 10 000 trees in posterior sample
barLength=30 ## number of ticks in a full progress bar
## NOTE(review): the progress bar code appears to rely on this division giving an integer (Python 2 semantics) - confirm the target interpreter
progress_update=Ntrees/barLength ## update progress bar every time a tick has to be added
threshold=progress_update ## tree count at which the progress bar is updated next
processingRate=[] ## remember how quickly script processes trees (seconds per tree, one entry per progress update)
#############################
available_analyses=['treeLength','RC','Sharp','tmrcas','transitions','subtrees'] ## analysis names that are possible

## user input is checked with explicit exceptions, because assert statements are skipped when python runs with -O
if not analyses:
    raise ValueError('No analyses were selected.')

for queued_analysis in analyses: ## for each queued analysis check if samogitia can do anything about them (i.e. whether they're known analysis types)
    if queued_analysis not in available_analyses:
        raise ValueError('%s is not a known analysis type\n\nAvailable analysis types are: \n* %s\n'%(queued_analysis,'\n* '.join(available_analyses)))

begin=dt.datetime.now() ## start timer
## Main loop: reads the tree file line by line. Lines before the first tree provide the number of tips and the tip
## encodings, every following 'tree STATE_' line is parsed into a baltic tree and, if its state passes the burnin
## and the requested state range, summarised into one row of the output file.
## NOTE(review): the indentation of this loop was reconstructed from its logic - please check it against the upstream file.
dateCerberus=re.compile(tformat) ## search pattern + brackets on actual calendar date, compiled once for all trees
maxDate=None ## date of the most recent tip, worked out when the first tree is calibrated

for line in treefile: ## iterate through each line
    ###################################################################################
    if plate==True and 'state' not in line.lower():
        cerberus=re.search(r'Dimensions ntax\=([0-9]+)\;',line) ## Extract useful information from the bits preceding the actual trees.
        if cerberus is not None:
            tipNum=int(cerberus.group(1))

        if 'Translate' in line:
            taxonlist=True ## taxon list to follow

        if taxonlist==True and ';' not in line and 'Translate' not in line: ## remember tip encodings
            cerberus=re.search(r'([0-9]+) ([\'\"A-Za-z0-9\?\|\-\_\.\/]+)',line)
            tips[cerberus.group(1)]=cerberus.group(2).strip("'")

    if 'tree STATE_' in line and plate==True: ## starting actual analysis
        plate=False
        assert (tipNum == len(tips)),'Expected number of tips: %s\nNumber of tips found: %s'%(tipNum,len(tips)) ## check that correct numbers of tips have been parsed
    ################################################################################### start analysing trees
    cerberus=re.match(r'tree\sSTATE\_([0-9]+).+\[\&R\]\s',line) ## search for crud at the beginning of the line that's not a tree string

    if cerberus is not None: ## tree identified
        state=int(cerberus.group(1)) ## MCMC state at which this tree was sampled
        ################################################################# at state 0 - create the header for the output file and read the tree (in case the output log file requires information encoded in the tree)
        if treecount==0: ## At tree state 0 insert header into output file
            ll=bt.tree() ## empty tree object
            start=len(cerberus.group()) ## index of where tree string starts in the line
            treestring=str(line[start:]) ## grab tree string
            bt.make_tree(treestring,ll) ## read tree string
            ## NOTE(review): only the 'state' column is assumed to be guarded by the chunk check, so that the set-up done
            ## for individual analyses below (tmrcas dict, copy import) also happens when a chunk is analysed - confirm
            if lower==0 and upper==np.inf: ## only add a header if not doing a chunk
                outfile.write('state') ## begin the output log file
            ########################################### add header to output log file
            if 'treeLength' in analyses:
                outfile.write('\ttreeLength')
            ###########################################
            if 'RC' in analyses:
                outfile.write('\tN\tS\tuN\tuS\tdNdS')
            ###########################################
            if 'tmrcas' in analyses:
                tmrcas={'A':[],'B':[],'C':[]} ## dict of clade names
                ll.renameTips(tips)
                for k in ll.Objects: ## iterate over branches
                    if isinstance(k,bt.leaf): ## only interested in tips
                        if 'A' in k.name: ## if name of tip satisfies condition
                            tmrcas['A'].append(k.numName) ## add the tip's numName to tmrca list - tips will be used to ID the common ancestor
                        ## NOTE(review): Conakry_tips is not defined anywhere in this script, so this branch raises a
                        ## NameError for the first tip without 'A' in its name - the collection of clade 'B' tip names
                        ## has to be supplied before the tmrcas analysis can run
                        elif k.name in Conakry_tips:
                            tmrcas['B'].append(k.numName)
                        tmrcas['C'].append(k.numName) ## clade 'C' receives every tip

                outfile.write('\t%s'%('\t'.join(sorted(tmrcas.keys()))))
            ###########################################
            if 'transitions' in analyses:
                ## NOTE(review): 'state' has already been written above when no chunk is requested, which makes this
                ## header begin with 'statestate' - confirm whether that is intended
                outfile.write('state\ttotalChangeCount\tcompleteHistory_1')
            ###########################################
            if 'subtrees' in analyses:
                import copy
            ###########################################
            ## your custom header making code goes here
            ## if 'custom' in analyses:
            ##     trait='yourTrait'
            ##     trait_vals=[]
            ##     for k in ll.Objects:
            ##         if trait in k.traits:
            ##             trait_vals.append(k.traits[trait])
            ##     available_trait_values=sorted(bt.unique(trait_vals))
            ##     for tr in available_trait_values:
            ##         outfile.write('\t%s.time'%(tr))
            ###########################################
            outfile.write('\n') ## newline for first tree
        #################################################################
        if state >= burnin and lower <= state < upper: ## After burnin start processing
            ll=bt.tree() ## ll is the tree object
            start=len(cerberus.group()) ## find start of tree string in line
            treestring=str(line[start:]) ## get tree string
            bt.make_tree(treestring,ll) ## pass it to make_tree function
            ll.traverse_tree() ## Traverse the tree - sets the height of each object in the tree
            #### renaming tips
            if len(tips)>0:
                ll.renameTips(tips) ## Rename tips so their name refers to sequence name
            else:
                for k in ll.Objects:
                    if isinstance(k,bt.leaf): ## leaf class lives in the baltic module
                        k.name=k.numName ## otherwise every tip gets a name that's the same as tree string names
            #### calibration
            if calibration==True: ## Calibrate tree so everything has a known position in actual time
                if maxDate is None: ## tip names are the same for every tree, so the most recent tip is only identified once
                    tipDatesRaw=[dateCerberus.search(x).group(1) for x in tips.values()]
                    tipDates=[bt.decimalDate(tip_date,fmt=dformat,variable=True) for tip_date in tipDatesRaw]
                    maxDate=max(tipDates) ## identify most recent tip
                ll.setAbsoluteTime(maxDate)

            outfile.write('%s'%cerberus.group(1)) ## write MCMC state number to output log file
            ################################################################################
            if 'treeLength' in analyses:
                treeL=sum([k.length for k in ll.Objects]) ## do analysis
                outfile.write('\t%s'%(treeL)) ## output to file
            ###################################################
            if 'RC' in analyses: ## 'RC' was queued as an analysis
                Ns=[] ## empty list
                Ss=[]
                uNs=[]
                uSs=[]
                for k in ll.Objects: ## iterate over branch objects in the tree
                    if 'N' in k.traits: ## if branch has a trait labelled "N"...
                        Ns.append(k.traits['N']) ## add it to empty list
                        Ss.append(k.traits['S']) ## likewise for every other trait
                        uNs.append(k.traits['b_u_N'])
                        uSs.append(k.traits['b_u_S'])

                tNs=sum(Ns) ## sum of numbers in list
                tSs=sum(Ss)
                tuNs=sum(uNs)
                tuSs=sum(uSs)
                dNdS=(tNs/tSs)/(tuNs/tuSs) ## calculate dNdS
                outfile.write('\t%s\t%s\t%s\t%s\t%s'%(tNs,tSs,tuNs,tuSs,dNdS)) ## output to file, separated by tabs
            ###################################################
            if 'tmrcas' in analyses:
                assert calibration==True,'This analysis type requires time-calibrated trees'
                nodes={x:None for x in tmrcas} ## each TMRCA will correspond to a single object
                score={x:len(ll.Objects)+1 for x in tmrcas} ## this will be used to score the common ancestor candidate
                for required in tmrcas: ## iterate over TMRCAs
                    searchNodes=sorted([nd for nd in ll.Objects if nd.branchType=='node' and len(nd.leaves)>=len(tmrcas[required])],key=lambda n:len(n.leaves)) ## common ancestor candidates must have at least as many descendants as the list of search tips
                    for k in searchNodes: ## iterate over candidates
                        common,queryLeft,targetLeft=overlap(k.leaves,tmrcas[required]) ## find how many query tips exist as descendants of candidate nodes
                        if len(targetLeft)==0 and len(queryLeft)<=score[required]: ## all of query tips must be descended from common ancestor, every extra descendant of common ancestor not in query contributes to a score
                            nodes[required]=k ## if score improved - assign new common ancestor
                            score[required]=len(queryLeft) ## score is extra descendants not in the list of known tips

                outTMRCA=['%.6f'%(nodes[n].absoluteTime) for n in sorted(nodes.keys())] ## fetch absoluteTime of each identified common ancestor
                outfile.write('\t%s'%('\t'.join(outTMRCA)))
            ###################################################
            if 'Sharp' in analyses:
                assert calibration==True,'This analysis type requires time-calibrated trees'
                assert len(analyses)==1,'More that one analysis queued in addition to Sharp, which is inadvisable'
                outSharp=[]
                for k in ll.Objects:
                    if 'N' in k.traits:
                        N=k.traits['N']
                        S=k.traits['S']
                        halfBranch=k.length*0.5
                        if isinstance(k,bt.node): ## node class lives in the baltic module
                            all_leaves=[tips[lf] for lf in k.leaves]
                            t=min(map(bt.decimalDate,[dateCerberus.search(x).group(1) for x in all_leaves]))-k.absoluteTime+halfBranch
                        else:
                            t=halfBranch

                        outSharp.append('(%d,%d,%.4f)'%(N,S,t))

                outfile.write('\t%s'%('\t'.join(outSharp)))
            ###################################################
            if 'transitions' in analyses:
                assert calibration==True,'This analysis type requires time-calibrated trees'
                assert len(analyses)==1,'More that one analysis queued in addition to transitions, which is inadvisable'
                outTransitions=[]
                for k in ll.Objects:
                    if 'location.states' in k.traits and 'location.states' in k.parent.traits:
                        cur_value=k.traits['location.states']
                        par_value=k.parent.traits['location.states']
                        if cur_value!=par_value: ## branch sits in a different state than its parent - record the change at the middle of the branch
                            outTransitions.append('{1,%s,%s,%s}'%(ll.treeHeight-k.height-0.5*k.length,par_value,cur_value))

                outfile.write('\t%d\t%s'%(len(outTransitions),'\t'.join(outTransitions)))
            ###################################################
            if 'subtrees' in analyses:
                traitName='location.states'
                assert any(traitName in k.traits for k in ll.Objects),'No branches have the trait "%s"'%(traitName)
                for k in ll.Objects:
                    if traitName in k.traits and traitName in k.parent.traits and k.parent.index!='Root' and k.traits[traitName]!=k.parent.traits[traitName] and k.traits[traitName]=='human':
                        proceed=False ## assume we still can't proceed forward
                        kloc=k.traits[traitName]
                        if isinstance(k,bt.leaf): ## if dealing with a leaf - proceed
                            proceed=True
                        else:
                            if [ch.traits[traitName] for ch in k.children].count(kloc)>0: ## at least one child stays in the same trait state
                                proceed=True

                        if proceed==True: ## if at least one valid tip and no hanging nodes
                            subtree=copy.deepcopy(ll.traverseWithinTrait(k,traitName))
                            subtree_leaves=[x.name for x in subtree if isinstance(x,bt.leaf)]

                            if len(subtree_leaves)>0:
                                mostRecentTip=max([bt.decimalDate(x.strip("'").split('|')[-1]) for x in subtree_leaves])
                                while sum([len(nd.children)-sum([1 if ch in subtree else 0 for ch in nd.children]) for nd in subtree if isinstance(nd,bt.node) and nd.index!='Root'])>0: ## keep removing nodes as long as there are nodes with children that are not entirely within subtree
                                    for nd in sorted([q for q in subtree if isinstance(q,bt.node)],key=lambda x:(sum([1 if ch in subtree else 0 for ch in x.children]),x.height)): ## iterate over nodes in subtree, starting with ones that have fewest valid children and are more recent
                                        child_status=[1 if ch in subtree else 0 for ch in nd.children] ## check how many children of current node are under the right trait value
                                        if sum(child_status)<2 and nd.index!='Root': ## if less than 2 children in subtree (i.e. not all children are under the same trait state)
                                            grand_parent=nd.parent ## fetch grandparent of node to be removed
                                            grand_parent.children.remove(nd) ## remove node from its parent's children
                                            if sum(child_status)==0: ## node has no valid children - current grandparent will be removed on next iteration, since it will only have one child
                                                pass
                                            else: ## at least one child is still valid - reconnect the one valid child to grandparent
                                                child=nd.children[child_status.index(1)] ## identify the valid child
                                                child.parent=grand_parent ## child's parent is now its grandparent
                                                grand_parent.children.append(child) ## child is now child of grandparent
                                                child.length+=nd.length ## child's length now includes it's former parent's length
                                            subtree.remove(nd) ## remove node from subtree

                                outfile.write('\t{%s,%s,%s,%s,%d}'%(k.absoluteTime,mostRecentTip,k.parent.traits[traitName],k.traits[traitName],len(subtree_leaves)))
                                sys.stderr.write('\t{%s,%s,%s,%s,%d}'%(k.absoluteTime,mostRecentTip,k.parent.traits[traitName],k.traits[traitName],len(subtree_leaves)))
                                ##########
                                ## Comment out to output stats rather than trees
                                ##########
                                # if len(subtree)>0: ## only proceed if there's at least one tip in the subtree
                                #     local_tree=bt.tree() ## create a new tree object where the subtree will be
                                #     local_tree.Objects=subtree ## assign branches to new tree object
                                #     local_tree.root.children.append(subtree[0]) ## connect tree object's root with subtree
                                #     subtree[0].parent=local_tree.root ## subtree's root's parent is tree object's root
                                #     #local_tree.root.absoluteTime=subtree[0].absoluteTime-subtree[0].length ## root's absolute time is subtree's root time
                                #     local_tree.sortBranches() ## sort branches, draw small tree
                                #     subtreeString=local_tree.toString()
                                #     outfile.write('\t%s'%(subtreeString))
            ###################################################
            ## your analysis and output code goes here, e.g.
            ## if 'custom' in analyses:
            ##     out={x:0.0 for x in available_trait_values}
            ##     for k in ll.Objects:
            ##         if trait in k.traits:
            ##             out[k.traits[trait]]+=k.length
            ##     for tr in available_trait_values:
            ##         outfile.write('\t%s'%(out[tr]))
            ###################################################
            outfile.write('\n') ## newline for post-burnin tree

        ## NOTE(review): the counter is assumed to advance for every tree that is read, not only for the ones that
        ## pass the burnin - otherwise the header above would be written again for every discarded tree
        treecount+=1 ## increment tree counter
        ################################################################################
        if treecount==threshold: ## tree passed progress bar threshold
            timeTakenSoFar=dt.datetime.now()-begin ## time elapsed
            timeElapsed=float(divmod(timeTakenSoFar.total_seconds(),60)[0]+(divmod(timeTakenSoFar.total_seconds(),60)[1])/float(60)) ## elapsed time in minutes
            timeRate=float(divmod(timeTakenSoFar.total_seconds(),60)[0]*60+divmod(timeTakenSoFar.total_seconds(),60)[1])/float(treecount+1) ## rate at which trees have been processed
            processingRate.append(timeRate) ## remember rate
            ETA=(sum(processingRate)/float(len(processingRate))*(Ntrees-treecount))/float(60)/float(60) ## estimate how long it'll take, given mean processing rate

            excessiveTrees=treecount
            if treecount>=10000: ## never draw more ticks than fit into the bar
                excessiveTrees=10000
            if timeElapsed>60.0: ## took over 60 minutes
                reportElapsed=timeElapsed/60.0 ## switch to hours
                reportUnit='h' ## unit is hours
            else:
                reportElapsed=timeElapsed ## keep minutes
                reportUnit='m'

            sys.stderr.write('\r') ## output progress bar
            ## the number of ticks is cast to int, because a string can only be repeated an integer number of times
            sys.stderr.write("[%-30s] %4d%% trees: %5d elapsed: %5.2f%1s ETA: %5.2fh (%6.1e s/tree)" % ('='*int(excessiveTrees/progress_update),treecount/float(Ntrees)*100.0,treecount,reportElapsed,reportUnit,ETA,processingRate[-1]))
            sys.stderr.flush()

            threshold+=progress_update ## increment to next threshold
    ################################################################################
    if 'End;' in line: ## end of the trees block - nothing left to do for this line
        pass
treefile.close() ## the input tree file has been read to the end, release its handle
outfile.close() ## flush and close the output log file
sys.stderr.write('\nDone!\n') ## done!