Skip to content

Commit

Permalink
Plotting in slingshot script is now disabled by default and can be en…
Browse files Browse the repository at this point in the history
…abled via a command line argument. Also fixes Murali-group#7
  • Loading branch information
matthieubulte committed Oct 9, 2019
1 parent c5ec584 commit 039e843
Showing 1 changed file with 86 additions and 83 deletions.
169 changes: 86 additions & 83 deletions scripts/runSlingshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
from pathlib import Path
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from optparse import OptionParser
import seaborn as sns
from optparse import OptionParser

seed = 0
np.random.seed(seed)
Expand Down Expand Up @@ -46,19 +44,95 @@ def parseArgs(args):

parser.add_option('', '--noEnd', action='store_true',default= False,
help='Do not force SlingShot to have an end state.')



parser.add_option('-r', '--perplexity', type='int',default=500,
help='Perplexity for tSNE.')


parser.add_option('', '--enableDebugPlots', action='store_true', default=False,
help='Plot both the deterministic and slingshot pseudotime trajectories on top of '
'the original cell clustering.')

(opts, args) = parser.parse_args(args)

return opts, args



def computeSSPT(ExpDF, ptDF, nClust, outPaths, noEnd = False, perplexity = 500):
def plotResults(outPaths, tn, lneCnt, curveLst):
    '''
    Draw a 2x2 debug figure comparing the deterministic pseudotime
    with the slingshot output and save it to disk.

    Panels:
      [0][0] cells coloured by 'ExpTime' (row-wise minimum of PseudoTime.csv)
      [0][1] cells coloured by 'Original' (index of the PseudoTime.csv column
             in which the cell has a non-null value)
      [1][0] cells coloured by each SlingshotPT.csv column, with the
             curves from curveLst drawn on top in black
      [1][1] cells coloured by 'kMeans'

    Files read (relative to outPaths): SlingshotPT.csv, PseudoTime.csv.
    Files written (relative to outPaths): Updated_rd.tsv, SlingshotOutput.png.

    :param outPaths: Directory prefix (string) holding the input files;
        the output files are written to the same location.
    :param tn: DataFrame with at least the columns 'dim1', 'dim2' and
        'kMeans'. It is modified in place: one column per SlingshotPT.csv
        column, plus 'Original' and 'ExpTime', are added before it is
        written to Updated_rd.tsv.
        NOTE(review): the index is presumably the cell ID so that it lines
        up with the index of the two CSV files -- confirm against caller.
    :param lneCnt: Number of entries of curveLst to consume; entries are
        read in pairs (line, line + 1).
    :param curveLst: Sequence in which entry ``line + 1`` is used as the
        x values and entry ``line`` as the y values of one curve.
    '''
    # Imported here rather than at module level, so matplotlib and seaborn
    # are only loaded when this function is actually called.
    import matplotlib
    # NOTE(review): TkAgg needs a Tk installation and a display; since the
    # figure is only written to a PNG, "Agg" may be sufficient -- confirm
    # before changing.
    matplotlib.use("TkAgg")

    import matplotlib.pyplot as plt
    import seaborn as sns

    f, axes = plt.subplots(2, 2, figsize=(7.5, 7.5))

    try:
        # Plot slingshot pseudotime
        # and original clusters
        ssPT = pd.read_csv(outPaths + "/SlingshotPT.csv",
                           header=0, index_col=0)
        for colName in ssPT.columns:
            # Select cells belonging to each pseudotime trajectory
            index = ssPT[colName].index[ssPT[colName].notnull()]
            tn.loc[index, colName] = ssPT.loc[index, colName]

            sns.scatterplot(x='dim1', y='dim2',
                            data=tn.loc[index], hue=colName,
                            palette="viridis",
                            ax=axes[1][0])

        # curveLst is consumed two entries at a time: (y, x) per curve.
        for line in range(0, lneCnt, 2):
            sns.lineplot(x=curveLst[line + 1], y=curveLst[line],
                         color="k", ax=axes[1][0])

        sns.scatterplot(x='dim1', y='dim2',
                        data=tn, hue='kMeans',
                        palette="Set1",
                        ax=axes[1][1])

        # Plot deterministic pseudotime
        # and original clusters
        detPT = pd.read_csv(outPaths + "/PseudoTime.csv",
                            header=0, index_col=0)
        for idx, colName in enumerate(detPT.columns):
            # Select cells belonging to each pseudotime trajectory
            index = detPT[colName].index[detPT[colName].notnull()]
            tn.loc[index, 'Original'] = idx

        tn['ExpTime'] = detPT.min(axis='columns')
        tn.to_csv(outPaths + "/Updated_rd.tsv", sep='\t')

        sns.scatterplot(x='dim1', y='dim2',
                        data=tn, hue='ExpTime',
                        palette="viridis",
                        ax=axes[0][0])

        sns.scatterplot(x='dim1', y='dim2',
                        data=tn, hue='Original',
                        palette="Set1",
                        ax=axes[0][1])

        panels = (
            (axes[0][0], 'Experiment Time'),
            (axes[0][1], 'Original Trajectories'),
            (axes[1][0], 'Slingshot Pseudotime'),
            (axes[1][1], 'kMeans Clustering'),
        )
        for ax, title in panels:
            # get_legend() returns None when nothing created a legend on
            # this axes (e.g. no trajectory columns were plotted).
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()
            ax.title.set_text(title)

        f.tight_layout()

        f.savefig(outPaths + "/SlingshotOutput.png")
    finally:
        # Release the figure so it does not stay registered with pyplot.
        plt.close(f)


def computeSSPT(ExpDF, ptDF, nClust, outPaths, noEnd = False, perplexity = 500, plotting=False):
'''
Compute PseudoTime using 'slingshot'.
Needs the input GenexCells expression data frame.
Expand Down Expand Up @@ -120,11 +194,6 @@ def computeSSPT(ExpDF, ptDF, nClust, outPaths, noEnd = False, perplexity = 500):
print(cmdToRun)
os.system(cmdToRun)

# os.system("cp temp/PseudoTime.csv "+outPaths+"/SlingshotPT.csv")
# os.system("cp temp/curves.csv "+outPaths+"/curves.csv")
# os.system("cp temp/rd.tsv "+outPaths+"/rd.tsv")
# os.system("cp temp/cl.tsv "+outPaths+"/cl.tsv")

# Do this only for the first file
tn = pd.read_csv(outPaths+"/rd.tsv",
header = 0, index_col =None, sep='\t')
Expand All @@ -143,75 +212,9 @@ def computeSSPT(ExpDF, ptDF, nClust, outPaths, noEnd = False, perplexity = 500):
tn.columns = ['CellID','dim1','dim2','kMeans']
tn.index = tn.CellID

f, axes = plt.subplots(2, 2, figsize=(7.5, 7.5))

# Plot slingshot pseudotime
# and original clusters
detPT = pd.read_csv(outPaths+"/SlingshotPT.csv",
header = 0, index_col = 0)
print()
colNames = detPT.columns
for colName in colNames:
# Select cells belonging to each pseudotime trajectory
index = detPT[colName].index[detPT[colName].notnull()]
tn.loc[index,colName] = detPT.loc[index,colName]


sns.scatterplot(x='dim1',y='dim2',
data = tn.loc[index], hue = colName,
palette = "viridis",
ax = axes[1][0])
plt.legend([])

for line in range(0, lneCnt, 2):
sns.lineplot(x= curveLst[line+1],y=curveLst[line],
color = "k", ax = axes[1][0])

sns.scatterplot(x='dim1',y='dim2',
data = tn, hue = 'kMeans',
palette = "Set1",
ax = axes[1][1])

# Plot deterministic pseduotime
# and original clusters
detPT = pd.read_csv(outPaths+"/PseudoTime.csv",
header = 0, index_col = 0)
colNames = detPT.columns
for idx in range(len(colNames)):
# Select cells belonging to each pseudotime trajectory
colName = colNames[idx]
index = detPT[colName].index[detPT[colName].notnull()]
tn.loc[index,'Original'] = int(idx)

tn['ExpTime'] = detPT.min(axis='columns')

sns.scatterplot(x='dim1',y='dim2',
data = tn, hue = 'ExpTime',
palette = "viridis",
ax = axes[0][0])

#tn['Original'] = tn['Original'].astype('category')
sns.scatterplot(x='dim1',y='dim2',
data = tn, hue = 'Original',
palette = "Set1",
ax = axes[0][1])

axes[0][0].get_legend().remove()
axes[0][0].title.set_text('Experiment Time')
axes[0][1].get_legend().remove()
axes[0][1].title.set_text('Original Trajectories')

axes[1][0].get_legend().remove()
axes[1][0].title.set_text('Slingshot Pseudotime')

axes[1][1].get_legend().remove()
axes[1][1].title.set_text('kMeans Clustering')

f.tight_layout()

tn.to_csv(outPaths+"/Updated_rd.tsv",
sep='\t')
plt.savefig(outPaths+"/SlingshotOutput.png")
if plotting:
plotResults(outPaths, tn, lneCnt, curveLst)

os.system("rm -rf temp/")

def main(args):
Expand All @@ -222,7 +225,7 @@ def main(args):

# Compute PseudoTime using slingshot
# TODO: Add other methods
computeSSPT(ExprDF, ptDF, opts.nClusters, opts.outPrefix, opts.noEnd, opts.perplexity)
computeSSPT(ExprDF, ptDF, opts.nClusters, opts.outPrefix, opts.noEnd, opts.perplexity, opts.enableDebugPlots)

if __name__ == "__main__":
main(sys.argv)

0 comments on commit 039e843

Please sign in to comment.