forked from vamsikrishna1902/IntentPredictionEval
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathConcurrentSessions.py
111 lines (104 loc) · 4.96 KB
/
ConcurrentSessions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import sys
import os
import time, argparse
import QueryParser as qp
import ParseConfigFile as parseConfig
from ParseConfigFile import getConfig
import random
import TupleIntent as ti
def countQueries(inputFile):
    """Map each session index in *inputFile* to the number of queries it holds.

    Each line is expected to look like ``Session <idx>;q1;q2;...;``: the first
    ';'-separated token names the session and the line ends with a semicolon.
    Session indices are neither sequential nor complete (e.g. sessions 15 and
    16 may not exist), so the result is a dict rather than a list.

    Args:
        inputFile: path of the UTF-8 session file.

    Returns:
        dict mapping session index (int) -> query count (int).
    """
    counts = {}
    with open(inputFile, encoding='utf-8') as handle:
        for rawLine in handle:
            header = rawLine.strip().split(";")[0]
            # header reads "Session <idx>"; the index is the second space-separated word
            idx = int(header.split(" ")[1])
            # every query is terminated by ';' and the session-name header adds one more
            counts[idx] = rawLine.count(";") - 1
    return counts
def retrieveQueryFromFile(inputFile, coveredSessQueries, sessIndex):
    """Return the next not-yet-emitted query of session *sessIndex*.

    Args:
        inputFile: path of the session file, one ``Session <idx>;q1;q2;...;``
            line per session (read as UTF-8, like countQueries).
        coveredSessQueries: dict mapping session index -> number of that
            session's queries already emitted; a missing key means none.
        sessIndex: index of the session whose next query is wanted.

    Returns:
        ``(sessQuery, queryIndex)``: *queryIndex* is the 1-based position of the
        query within the session; *sessQuery* is the text before the first
        ``~`` with runs of whitespace collapsed to single spaces (it may be ""
        if that token is empty).

    Raises:
        ValueError: if no line of *inputFile* belongs to *sessIndex*.
    """
    # Query indices start at 1: the count covered so far is the index of the last
    # emitted query, and token 0 is the session name, so the next query is count + 1.
    queryIndex = coveredSessQueries.get(sessIndex, 0) + 1
    # Same encoding as countQueries so both passes decode the file identically.
    with open(inputFile, encoding='utf-8') as f:
        for line in f:
            sessTokens = line.strip().split(";")
            if int(sessTokens[0].split(" ")[1]) != sessIndex:
                continue
            sessQuery = sessTokens[queryIndex].split("~")[0]
            sessQuery = ' '.join(sessQuery.split())  # eliminate extra spaces within the SQL query
            return (sessQuery, queryIndex)
    # No matching session: fail loudly instead of returning None, which callers cannot unpack.
    raise ValueError("Session " + str(sessIndex) + " not found in " + str(inputFile))
def _loadSessionTokens(inputFile):
    """Read *inputFile* once; map each session index to its ';'-split tokens.

    Only the first line seen for a given session index is kept, matching the
    first-match scan of retrieveQueryFromFile.
    """
    sessTokensDict = {}
    with open(inputFile, encoding='utf-8') as f:
        for line in f:
            sessTokens = line.strip().split(";")
            sessIndex = int(sessTokens[0].split(" ")[1])
            sessTokensDict.setdefault(sessIndex, sessTokens)
    return sessTokensDict


def createConcurrentSessions(inputFile, outputFile):
    """Interleave the queries of all sessions in random order into *outputFile*.

    Repeatedly picks a random session that still has queries left and emits
    its next query, so each session's internal query order is preserved while
    sessions are shuffled against each other. Each output line has the form
    ``Session <idx>, Query <q>;<sql>`` and is written via ti.appendToFile.
    A session is dropped once its counted queries are exhausted or an empty
    query is encountered. Any existing *outputFile* is deleted first.

    Args:
        inputFile: session file as read by countQueries.
        outputFile: path of the concurrent-session file to (re)create.
    """
    sessionQueryCountDict = countQueries(inputFile)
    # Parse the input once: fetching each query by re-scanning the whole file
    # made this loop quadratic in the input size.
    sessTokensDict = _loadSessionTokens(inputFile)
    try:
        os.remove(outputFile)
    except OSError:
        pass
    keyList = list(sessionQueryCountDict.keys())  # a copy: removing from it leaves the dict intact
    coveredSessQueries = {}  # key is sessionID and value is the query count covered
    while keyList:
        sessIndex = random.choice(keyList)
        if sessIndex in coveredSessQueries and coveredSessQueries[sessIndex] >= sessionQueryCountDict[sessIndex]:
            # all counted queries of this session have been emitted
            keyList.remove(sessIndex)
            continue
        sessTokens = sessTokensDict[sessIndex]
        # token 0 is the session name, so the next query sits at (covered count + 1)
        queryIndex = coveredSessQueries.get(sessIndex, 0) + 1
        if queryIndex >= len(sessTokens):
            # malformed line with fewer tokens than counted; treat as exhausted
            keyList.remove(sessIndex)
            continue
        sessQuery = ' '.join(sessTokens[queryIndex].split("~")[0].split())  # collapse extra spaces
        if sessQuery == "":
            keyList.remove(sessIndex)
            continue
        coveredSessQueries[sessIndex] = queryIndex
        output_str = "Session " + str(sessIndex) + ", Query " + str(queryIndex) + ";" + sessQuery
        ti.appendToFile(outputFile, output_str)
        print("appended Session " + str(sessIndex) + ", Query " + str(queryIndex))
def readTestSessIDs(inputSeqFile, configDict):
    """Collect the session IDs that appear in the test portion of *inputSeqFile*.

    Lines whose 0-based index is >= ``configDict['RNN_SUSTENANCE_TRAIN_LIMIT']``
    are treated as test lines. From each one, the text before the first ';' and
    then before the first ',' is taken as the session ID (presumably lines look
    like ``Session 12, Query 3;...`` -- confirm against the producer).

    Args:
        inputSeqFile: path of the sequential (sustenance) session file.
        configDict: config mapping; only RNN_SUSTENANCE_TRAIN_LIMIT is read.

    Returns:
        set of session-ID strings. This is best effort: on a read, decode, or
        config error the error is printed and the IDs gathered so far are
        returned.
    """
    sessIDs = set()
    try:
        trainLimit = int(configDict['RNN_SUSTENANCE_TRAIN_LIMIT'])
        with open(inputSeqFile, encoding='utf-8') as f:
            for lineIndex, line in enumerate(f):
                if lineIndex >= trainLimit:
                    sessIDs.add(line.strip().split(";")[0].split(",")[0])
    except (OSError, KeyError, ValueError) as err:
        # narrow handler: a bare except also hid KeyboardInterrupt and real bugs
        print("error1: " + repr(err))
    return sessIDs
def convertSeqToConcFile(configDict):
    """Split the concurrent-session file into a train file and a test file.

    A line of the concurrent file is routed to the test file when its session
    ID (text before the first ';' and then the first ',') is among the IDs that
    readTestSessIDs returns for the sequential sustenance file. Every other
    line goes to the train file. Both output files are deleted first because
    lines are added with ti.appendToFile.

    NOTE(review): all four paths are hard-coded to one MINC cluster run instead
    of being read from *configDict*; configDict is only forwarded to
    readTestSessIDs.

    Args:
        configDict: parsed config mapping (see readTestSessIDs).
    """
    baseDir = 'data/MINC/InputOutput/ClusterRuns/NovelTables-203087-8936Sess-307KModified/'
    inputConcFile = getConfig(baseDir + 'MincBitFragmentIntentSessions_Singularity')
    inputSeqFile = getConfig(baseDir + 'MincBitFragmentIntentSessions_Sustenance_0.8')
    concTrainFile = getConfig(baseDir + 'MincBitFragmentIntentSessions_ConcTrain_Sustenance_0.8')
    concTestFile = getConfig(baseDir + 'MincBitFragmentIntentSessions_ConcTest_Sustenance_0.8')
    testSessIDs = readTestSessIDs(inputSeqFile, configDict)
    for staleFile in (concTrainFile, concTestFile):
        try:
            os.remove(staleFile)
        except OSError:
            pass  # nothing to delete
    try:
        with open(inputConcFile, encoding='utf-8') as f:
            for line in f:
                strippedLine = line.strip()
                curSessID = strippedLine.split(";")[0].split(",")[0]
                targetFile = concTestFile if curSessID in testSessIDs else concTrainFile
                ti.appendToFile(targetFile, strippedLine)
    except (OSError, ValueError) as err:
        # report the cause instead of a bare except that hides every failure
        print("error2: " + repr(err))
if __name__ == "__main__":
    # Command-line entry point: read the config named by -config and build the
    # randomly interleaved concurrent-session file from the session file it names.
    argParser = argparse.ArgumentParser()
    argParser.add_argument("-config", help="Config parameters file", type=str, required=True)
    cliArgs = argParser.parse_args()
    configDict = parseConfig.parseConfigFile(cliArgs.config)
    # Alternative (disabled) step: split the concurrent file into train/test parts.
    # convertSeqToConcFile(configDict)
    sessionsPath = getConfig(configDict['QUERYSESSIONS'])
    concurrentPath = getConfig(configDict['CONCURRENT_QUERY_SESSIONS'])
    createConcurrentSessions(sessionsPath, concurrentPath)
    print("Completed concurrent session order creation")