Skip to content

Commit

Permalink
Add model check
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardobl committed Apr 20, 2024
1 parent eb37bbb commit 371c283
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/OpenAgentsNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def postRun(self):
self._disksByUrl = {}
self._diskByName = {}

def canRun(self,job):
return True

def preRun(self):
pass

Expand Down Expand Up @@ -242,6 +245,8 @@ def executePendingJob(self ):
t=time.time()
try:
client = self.getClient() # Reconnect client for each job
if not runner.canRun(job):
continue
client.acceptJob(rpc_pb2.RpcAcceptJob(jobId=job.id))
wasAccepted = True
self.log("Job started on node "+self.nodeName, job.id)
Expand Down
7 changes: 7 additions & 0 deletions src/events/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
"description": "The number of tokens to overlap between each chunk",
"name": "Overlap"
},
"model":{
"type": "string",
"value": "text",
"description": "Specify which model to use. Empty for any",
"name": "Model"
},
"documents": {
"type": "array",
"description": "The documents to generate embeddings from",
Expand Down Expand Up @@ -73,6 +79,7 @@
["param", "max-tokens", "{{in.max_tokens}}"],
["param", "overlap", "{{in.overlap}}"],
["param", "quantize", "{{in.quantize}}"],
["param", "model", "{{in.model}}"],
{{#in.documents}}
["i", "{{data}}", "{{data_type}}", "", "{{marker}}"],
{{/in.documents}}
Expand Down
7 changes: 7 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ def quantize(self, embeddings):
binary_embeddings = quantize_embeddings(embeddings, precision="binary")
return binary_embeddings

def canRun(self,job):
def getParamValue(key,default=None):
param = [x for x in job.param if x.key == key]
return param[0].value[0] if len(param) > 0 else default
model = getParamValue("model", self.modelName)
return model == self.modelName

def run(self,job):
def getParamValue(key,default=None):
param = [x for x in job.param if x.key == key]
Expand Down

0 comments on commit 371c283

Please sign in to comment.