From 8518c86e88147280887b9f68cfab88bca928eaea Mon Sep 17 00:00:00 2001 From: Tom Szendrey Date: Thu, 12 Sep 2024 10:21:01 -0400 Subject: [PATCH] Optimize utilize intermediate results (#116) * Add optimize hook, rename calibrate's * Correction to total_possible_iterations * removing hook's print * Add types to rabbitmq messages. * surround optimize hook in try * defining progress_hook in exception case. --- service/models/operations/calibrate.py | 4 +- .../models/operations/ensemble_calibrate.py | 4 +- service/models/operations/optimize.py | 20 ++++++++++ service/utils/rabbitmq.py | 38 ++++++++++++++++++- 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/service/models/operations/calibrate.py b/service/models/operations/calibrate.py index c07f834..bcdb7b9 100644 --- a/service/models/operations/calibrate.py +++ b/service/models/operations/calibrate.py @@ -13,7 +13,7 @@ fetch_and_convert_static_interventions, fetch_and_convert_dynamic_interventions, ) -from utils.rabbitmq import gen_rabbitmq_hook +from utils.rabbitmq import gen_calibrate_rabbitmq_hook from utils.tds import fetch_dataset, fetch_model @@ -77,7 +77,7 @@ def gen_pyciemss_args(self, job_id): # TODO: Test RabbitMQ try: - hook = gen_rabbitmq_hook(job_id) + hook = gen_calibrate_rabbitmq_hook(job_id) except (socket.gaierror, AMQPConnectionError): logging.warning( "%s: Failed to connect to RabbitMQ. Unable to log progress", job_id diff --git a/service/models/operations/ensemble_calibrate.py b/service/models/operations/ensemble_calibrate.py index 3ecf549..4e18363 100644 --- a/service/models/operations/ensemble_calibrate.py +++ b/service/models/operations/ensemble_calibrate.py @@ -11,7 +11,7 @@ from models.base import Dataset, OperationRequest, Timespan, ModelConfig from models.converters import convert_to_solution_mapping -from utils.rabbitmq import gen_rabbitmq_hook +from utils.rabbitmq import gen_calibrate_rabbitmq_hook from utils.tds import fetch_dataset, fetch_model @@ -52,7 +52,7 @@ def gen_pyciemss_args(self, job_id): dataset_path = fetch_dataset(self.dataset.dict(), job_id) try: - hook = gen_rabbitmq_hook(job_id) + hook = gen_calibrate_rabbitmq_hook(job_id) except (socket.gaierror, AMQPConnectionError): logging.warning( "%s: Failed to connect to RabbitMQ. Unable to log progress", job_id diff --git a/service/models/operations/optimize.py b/service/models/operations/optimize.py index 7edc414..b63695b 100644 --- a/service/models/operations/optimize.py +++ b/service/models/operations/optimize.py @@ -2,6 +2,11 @@ from typing import ClassVar, List, Optional from enum import Enum +from utils.rabbitmq import OptimizeHook +from pika.exceptions import AMQPConnectionError +import socket +import logging + import numpy as np import torch @@ -151,6 +156,20 @@ def gen_pyciemss_args(self, job_id): if step_size is not None and solver_method == "euler": solver_options["step_size"] = step_size + total_possible_iterations = ( + extra_options.get("maxiter") + 1 + ) * extra_options.get("maxfeval") + try: + progress_hook = OptimizeHook(job_id, total_possible_iterations) + except (socket.gaierror, AMQPConnectionError): + logging.warning( + "%s: Failed to connect to RabbitMQ. Unable to log progress", job_id + ) + + # Log progress hook when unable to connect - for testing purposes. + def progress_hook(current_results): + logging.info(f"Optimize current results: {current_results.tolist()}") + return { "model_path_or_json": amr_path, "logging_step_size": self.logging_step_size, @@ -173,6 +192,7 @@ def gen_pyciemss_args(self, job_id): "n_samples_ouu": n_samples_ouu, "solver_method": solver_method, "solver_options": solver_options, + "progress_hook": progress_hook, **extra_options, } diff --git a/service/utils/rabbitmq.py b/service/utils/rabbitmq.py index b8b0120..9a1667e 100644 --- a/service/utils/rabbitmq.py +++ b/service/utils/rabbitmq.py @@ -52,7 +52,7 @@ def callback(ch, method, properties, body): channel.start_consuming() -def gen_rabbitmq_hook(job_id): +def gen_calibrate_rabbitmq_hook(job_id): connection = pika.BlockingConnection( conn_config, ) @@ -63,8 +63,42 @@ def hook(progress, loss): exchange="", routing_key="simulation-status", body=json.dumps( - {"job_id": job_id, "progress": progress, "loss": str(loss)} + { + "job_id": job_id, + "type": "calibrate", + "progress": progress, + "loss": str(loss), + } ), ) return hook + + +class OptimizeHook: + def __init__(self, job_id, total_possible_iterations): + connection = pika.BlockingConnection( + conn_config, + ) + self.channel = connection.channel() + self.job_id = job_id + self.type = "optimize" + self.result = [] + self.step = 0 + self.total_possible_iterations = total_possible_iterations + + def __call__(self, current_results): + self.step += 1 + self.channel.basic_publish( + exchange="", + routing_key="simulation-status", + body=json.dumps( + { + "job_id": self.job_id, + "progress": self.step, + "type": self.type, + "current_results": current_results.tolist(), + "total_possible_iterations": self.total_possible_iterations, + } + ), + )