Skip to content

Commit

Permalink
add back-end support for interventions-from-document gollm task (#4853)
Browse files Browse the repository at this point in the history
Co-authored-by: dgauldie <[email protected]>
  • Loading branch information
dgauldie and dgauldie authored Sep 24, 2024
1 parent 6e9b086 commit d2e1ad4
Show file tree
Hide file tree
Showing 10 changed files with 303 additions and 18 deletions.
1 change: 1 addition & 0 deletions packages/client/hmi-client/src/types/Types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,7 @@ export enum ProvenanceType {
Document = "Document",
Workflow = "Workflow",
Equation = "Equation",
InterventionPolicy = "InterventionPolicy",
}

export enum SimulationType {
Expand Down
44 changes: 44 additions & 0 deletions packages/gollm/tasks/interventions_from_document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json
import sys
from gollm.entities import InterventionsFromDocument
from gollm.openai.tool_utils import interventions_from_document

from taskrunner import TaskRunnerInterface


def cleanup():
pass


def main():
exitCode = 0
try:
taskrunner = TaskRunnerInterface(description="Extract interventions from paper CLI")
taskrunner.on_cancellation(cleanup)

input_dict = taskrunner.read_input_dict_with_timeout()

taskrunner.log("Creating InterventionsFromDocument model from input")
input_model = InterventionsFromDocument(**input_dict)
amr = json.dumps(input_model.amr, separators=(",", ":"))

taskrunner.log("Sending request to OpenAI API")
response = interventions_from_document(
research_paper=input_model.research_paper, amr=amr
)
taskrunner.log("Received response from OpenAI API")

taskrunner.write_output_dict_with_timeout({"response": response})

except Exception as e:
sys.stderr.write(f"Error: {str(e)}\n")
sys.stderr.flush()
exitCode = 1

taskrunner.log("Shutting down")
taskrunner.shutdown()
sys.exit(exitCode)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import jakarta.transaction.Transactional;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -370,7 +369,7 @@ public ResponseEntity<Dataset> createFromSimulationResult(
datasetService.updateAsset(dataset, projectId, permission);

// If this is a temporary asset, do not add to project.
if (addToProject == false) {
if (!addToProject) {
return ResponseEntity.status(HttpStatus.CREATED).body(dataset);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import software.uncharted.terarium.hmiserver.service.tasks.EquationsFromImageResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.GenerateResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.GenerateSummaryHandler;
import software.uncharted.terarium.hmiserver.service.tasks.InterventionsFromDocumentResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.ModelCardResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService.TaskMode;
Expand Down Expand Up @@ -448,6 +449,112 @@ public ResponseEntity<TaskResponse> createConfigureModelFromDatasetTask(
return ResponseEntity.ok().body(resp);
}

@PostMapping("/interventions-from-document")
@Secured(Roles.USER)
@Operation(summary = "Dispatch a `GoLLM interventions-from-document` task")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "Dispatched successfully",
content = @Content(
mediaType = "application/json",
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class)
)
),
@ApiResponse(
responseCode = "404",
description = "The provided model or document arguments are not found",
content = @Content
),
@ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content)
}
)
public ResponseEntity<TaskResponse> createInterventionsFromDocumentTask(
@RequestParam(name = "model-id", required = true) final UUID modelId,
@RequestParam(name = "document-id", required = true) final UUID documentId,
@RequestParam(name = "mode", required = false, defaultValue = "ASYNC") final TaskMode mode,
@RequestParam(name = "workflow-id", required = false) final UUID workflowId,
@RequestParam(name = "node-id", required = false) final UUID nodeId,
@RequestParam(name = "project-id", required = false) final UUID projectId
) {
final Schema.Permission permission = projectService.checkPermissionCanRead(
currentUserService.get().getId(),
projectId
);

// Grab the document
final Optional<DocumentAsset> document = documentAssetService.getAsset(documentId, permission);
if (document.isEmpty()) {
log.warn(String.format("Document %s not found", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.not-found"));
}

// make sure there is text in the document
if (document.get().getText() == null || document.get().getText().isEmpty()) {
log.warn(String.format("Document %s has no extracted text", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done"));
}

// Grab the model
final Optional<Model> model = modelService.getAsset(modelId, permission);
if (model.isEmpty()) {
log.warn(String.format("Model %s not found", modelId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("model.not-found"));
}

final InterventionsFromDocumentResponseHandler.Input input = new InterventionsFromDocumentResponseHandler.Input();
input.setResearchPaper(document.get().getText());

// stripping the metadata from the model before its sent since it can cause
// gollm to fail with massive inputs
model.get().setMetadata(null);
input.setAmr(model.get().serializeWithoutTerariumFieldsKeepId());

// Create the task
final TaskRequest req = new TaskRequest();
req.setType(TaskType.GOLLM);
req.setScript(InterventionsFromDocumentResponseHandler.NAME);
req.setUserId(currentUserService.get().getId());

try {
req.setInput(objectMapper.writeValueAsBytes(input));
} catch (final Exception e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}

req.setProjectId(projectId);

final InterventionsFromDocumentResponseHandler.Properties props =
new InterventionsFromDocumentResponseHandler.Properties();
props.setProjectId(projectId);
props.setDocumentId(documentId);
props.setModelId(modelId);
props.setWorkflowId(workflowId);
props.setNodeId(nodeId);
req.setAdditionalProperties(props);

final TaskResponse resp;
try {
resp = taskService.runTask(mode, req);
} catch (final JsonProcessingException e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.json-processing"));
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.gollm.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.gollm.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.gollm.execution-failure"));
}

return ResponseEntity.ok().body(resp);
}

@GetMapping("/compare-models")
@Secured(Roles.USER)
@Operation(summary = "Dispatch a `GoLLM Compare Models` task")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package software.uncharted.terarium.hmiserver.models.dataservice;

import java.io.Serial;
import java.io.Serializable;
import software.uncharted.terarium.hmiserver.annotations.TSModel;

@TSModel
public record Identifier(String curie, String name) implements Serializable {
@Serial
private static final long serialVersionUID = 302308407252037615L;
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ public enum ProvenanceType {
WORKFLOW("Workflow"),

@JsonAlias("Equation")
EQUATION("Equation");
EQUATION("Equation"),

@JsonAlias("InterventionPolicy")
INTERVENTION_POLICY("InterventionPolicy");

public final String type;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ public TaskResponse onSuccess(final TaskResponse resp) {
provenanceService.createProvenance(
new Provenance()
.setLeft(props.getDocumentId())
.setLeftType(ProvenanceType.DOCUMENT)
.setLeftType(ProvenanceType.EQUATION)
.setRight(props.getDocumentId())
.setRightType(ProvenanceType.EQUATION)
.setRightType(ProvenanceType.DOCUMENT)
.setRelationType(ProvenanceRelationType.EXTRACTED_FROM)
);
} catch (final Exception e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package software.uncharted.terarium.hmiserver.service.tasks;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.UUID;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.Provenance;
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceRelationType;
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceType;
import software.uncharted.terarium.hmiserver.models.simulationservice.interventions.InterventionPolicy;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
import software.uncharted.terarium.hmiserver.service.data.DocumentAssetService;
import software.uncharted.terarium.hmiserver.service.data.InterventionService;
import software.uncharted.terarium.hmiserver.service.data.ProvenanceService;

@Component
@RequiredArgsConstructor
@Slf4j
public class InterventionsFromDocumentResponseHandler extends TaskResponseHandler {

public static final String NAME = "gollm_task:interventions_from_document";

private final ObjectMapper objectMapper;
private final InterventionService interventionService;
private final ProvenanceService provenanceService;
private final DocumentAssetService documentAssetService;

@Override
public String getName() {
return NAME;
}

@Data
public static class Input {

@JsonProperty("research_paper")
String researchPaper;

@JsonProperty("amr")
String amr;
}

@Data
public static class Response {

JsonNode response;
}

@Data
public static class Properties {

UUID projectId;
UUID documentId;
UUID modelId;
UUID workflowId;
UUID nodeId;
}

@Override
public TaskResponse onSuccess(final TaskResponse resp) {
try {
final Properties props = resp.getAdditionalProperties(Properties.class);
final Response interventionPolicies = objectMapper.readValue(resp.getOutput(), Response.class);

// For each configuration, create a new model configuration
for (final JsonNode policy : interventionPolicies.response.get("interventionPolicies")) {
final InterventionPolicy ip = objectMapper.treeToValue(policy, InterventionPolicy.class);

if (ip.getModelId() != props.modelId) {
ip.setModelId(props.modelId);
}

final InterventionPolicy newPolicy = interventionService.createAsset(
ip,
props.projectId,
ASSUME_WRITE_PERMISSION_ON_BEHALF_OF_USER
);

// add provenance
provenanceService.createProvenance(
new Provenance()
.setLeft(newPolicy.getId())
.setLeftType(ProvenanceType.INTERVENTION_POLICY)
.setRight(props.documentId)
.setRightType(ProvenanceType.DOCUMENT)
.setRelationType(ProvenanceRelationType.EXTRACTED_FROM)
);
}
} catch (final Exception e) {
log.error("Failed to extract intervention policy", e);
throw new RuntimeException(e);
}
log.info("Intervention policy extracted successfully");
return resp;
}
}
Loading

0 comments on commit d2e1ad4

Please sign in to comment.