-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add back-end support for interventions-from-document gollm task (#4853)
Co-authored-by: dgauldie <[email protected]>
- Loading branch information
Showing
10 changed files
with
303 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import json | ||
import sys | ||
from gollm.entities import InterventionsFromDocument | ||
from gollm.openai.tool_utils import interventions_from_document | ||
|
||
from taskrunner import TaskRunnerInterface | ||
|
||
|
||
def cleanup(): | ||
pass | ||
|
||
|
||
def main(): | ||
exitCode = 0 | ||
try: | ||
taskrunner = TaskRunnerInterface(description="Extract interventions from paper CLI") | ||
taskrunner.on_cancellation(cleanup) | ||
|
||
input_dict = taskrunner.read_input_dict_with_timeout() | ||
|
||
taskrunner.log("Creating InterventionsFromDocument model from input") | ||
input_model = InterventionsFromDocument(**input_dict) | ||
amr = json.dumps(input_model.amr, separators=(",", ":")) | ||
|
||
taskrunner.log("Sending request to OpenAI API") | ||
response = interventions_from_document( | ||
research_paper=input_model.research_paper, amr=amr | ||
) | ||
taskrunner.log("Received response from OpenAI API") | ||
|
||
taskrunner.write_output_dict_with_timeout({"response": response}) | ||
|
||
except Exception as e: | ||
sys.stderr.write(f"Error: {str(e)}\n") | ||
sys.stderr.flush() | ||
exitCode = 1 | ||
|
||
taskrunner.log("Shutting down") | ||
taskrunner.shutdown() | ||
sys.exit(exitCode) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 2 additions & 0 deletions
2
...er/src/main/java/software/uncharted/terarium/hmiserver/models/dataservice/Identifier.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
package software.uncharted.terarium.hmiserver.models.dataservice; | ||
|
||
import java.io.Serial; | ||
import java.io.Serializable; | ||
import software.uncharted.terarium.hmiserver.annotations.TSModel; | ||
|
||
@TSModel | ||
public record Identifier(String curie, String name) implements Serializable { | ||
@Serial | ||
private static final long serialVersionUID = 302308407252037615L; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
100 changes: 100 additions & 0 deletions
100
.../uncharted/terarium/hmiserver/service/tasks/InterventionsFromDocumentResponseHandler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package software.uncharted.terarium.hmiserver.service.tasks; | ||
|
||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import com.fasterxml.jackson.databind.JsonNode; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import java.util.UUID; | ||
import lombok.Data; | ||
import lombok.RequiredArgsConstructor; | ||
import lombok.extern.slf4j.Slf4j; | ||
import org.springframework.stereotype.Component; | ||
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.Provenance; | ||
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceRelationType; | ||
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceType; | ||
import software.uncharted.terarium.hmiserver.models.simulationservice.interventions.InterventionPolicy; | ||
import software.uncharted.terarium.hmiserver.models.task.TaskResponse; | ||
import software.uncharted.terarium.hmiserver.service.data.DocumentAssetService; | ||
import software.uncharted.terarium.hmiserver.service.data.InterventionService; | ||
import software.uncharted.terarium.hmiserver.service.data.ProvenanceService; | ||
|
||
@Component | ||
@RequiredArgsConstructor | ||
@Slf4j | ||
public class InterventionsFromDocumentResponseHandler extends TaskResponseHandler { | ||
|
||
public static final String NAME = "gollm_task:interventions_from_document"; | ||
|
||
private final ObjectMapper objectMapper; | ||
private final InterventionService interventionService; | ||
private final ProvenanceService provenanceService; | ||
private final DocumentAssetService documentAssetService; | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
|
||
@Data | ||
public static class Input { | ||
|
||
@JsonProperty("research_paper") | ||
String researchPaper; | ||
|
||
@JsonProperty("amr") | ||
String amr; | ||
} | ||
|
||
@Data | ||
public static class Response { | ||
|
||
JsonNode response; | ||
} | ||
|
||
@Data | ||
public static class Properties { | ||
|
||
UUID projectId; | ||
UUID documentId; | ||
UUID modelId; | ||
UUID workflowId; | ||
UUID nodeId; | ||
} | ||
|
||
@Override | ||
public TaskResponse onSuccess(final TaskResponse resp) { | ||
try { | ||
final Properties props = resp.getAdditionalProperties(Properties.class); | ||
final Response interventionPolicies = objectMapper.readValue(resp.getOutput(), Response.class); | ||
|
||
// For each configuration, create a new model configuration | ||
for (final JsonNode policy : interventionPolicies.response.get("interventionPolicies")) { | ||
final InterventionPolicy ip = objectMapper.treeToValue(policy, InterventionPolicy.class); | ||
|
||
if (ip.getModelId() != props.modelId) { | ||
ip.setModelId(props.modelId); | ||
} | ||
|
||
final InterventionPolicy newPolicy = interventionService.createAsset( | ||
ip, | ||
props.projectId, | ||
ASSUME_WRITE_PERMISSION_ON_BEHALF_OF_USER | ||
); | ||
|
||
// add provenance | ||
provenanceService.createProvenance( | ||
new Provenance() | ||
.setLeft(newPolicy.getId()) | ||
.setLeftType(ProvenanceType.INTERVENTION_POLICY) | ||
.setRight(props.documentId) | ||
.setRightType(ProvenanceType.DOCUMENT) | ||
.setRelationType(ProvenanceRelationType.EXTRACTED_FROM) | ||
); | ||
} | ||
} catch (final Exception e) { | ||
log.error("Failed to extract intervention policy", e); | ||
throw new RuntimeException(e); | ||
} | ||
log.info("Intervention policy extracted successfully"); | ||
return resp; | ||
} | ||
} |
Oops, something went wrong.