From e4a23ef01d7720d177553e4e13d462feeba7454f Mon Sep 17 00:00:00 2001 From: Pariterre Date: Wed, 24 Jul 2024 15:58:56 -0400 Subject: [PATCH 01/17] Started to add the capability to connect to a plot server --- bioptim/gui/online_callback.py | 161 +++++++++++++++++++++++++++++---- bioptim/gui/plot.py | 4 +- 2 files changed, 147 insertions(+), 18 deletions(-) diff --git a/bioptim/gui/online_callback.py b/bioptim/gui/online_callback.py index d09b33494..7f9827b80 100644 --- a/bioptim/gui/online_callback.py +++ b/bioptim/gui/online_callback.py @@ -1,4 +1,7 @@ +from abc import ABC, abstractmethod +from enum import Enum import multiprocessing as mp +import socket from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity from matplotlib import pyplot as plt @@ -6,7 +9,23 @@ from .plot import PlotOcp -class OnlineCallback(Callback): +class OnlineCallbackType(Enum): + """ + The type of callback + + Attributes + ---------- + MULTIPROCESS: int + Using multiprocessing + SERVER: int + Using a server to communicate with the client + """ + + MULTIPROCESS = 0 + SERVER = 1 + + +class OnlineCallbackAbstract(Callback, ABC): """ CasADi interface of Ipopt callbacks @@ -18,12 +37,6 @@ class OnlineCallback(Callback): The number of optimization variables ng: int The number of constraints - queue: mp.Queue - The multiprocessing queue - plotter: ProcessPlotter - The callback for plotting for the multiprocessing - plot_process: mp.Process - The multiprocessing placeholder Methods ------- @@ -63,20 +76,16 @@ def __init__(self, ocp, opts: dict = None, show_options: dict = None): from ..interfaces.ipopt_interface import IpoptInterface interface = IpoptInterface(ocp) - all_g, all_g_bounds = interface.dispatch_bounds() + all_g, _ = interface.dispatch_bounds() self.ng = all_g.shape[0] - v = interface.ocp.variables_vector - self.construct("AnimateCallback", opts) - self.queue = mp.Queue() - self.plotter = self.ProcessPlotter(self.ocp) - self.plot_process = mp.Process(target=self.plotter, args=(self.queue, show_options), daemon=True) - self.plot_process.start() - + @abstractmethod def close(self): - self.plot_process.kill() + """ + Close the callback + """ @staticmethod def get_n_in() -> int: @@ -155,6 +164,7 @@ def get_sparsity_in(self, i: int) -> tuple: else: return Sparsity(0, 0) + @abstractmethod def eval(self, arg: list | tuple) -> list: """ Send the current data to the plotter @@ -168,6 +178,34 @@ def eval(self, arg: list | tuple) -> list: ------- A list of error index """ + + +class OnlineCallback(OnlineCallbackAbstract): + """ + Multiprocessing implementation of the online callback + + Attributes + ---------- + queue: mp.Queue + The multiprocessing queue + plotter: ProcessPlotter + The callback for plotting for the multiprocessing + plot_process: mp.Process + The multiprocessing placeholder + """ + + def __init__(self, ocp, opts: dict = None, show_options: dict = None): + super(OnlineCallback, self).__init__(ocp, opts, show_options) + + self.queue = mp.Queue() + self.plotter = self.ProcessPlotter(self.ocp) + self.plot_process = mp.Process(target=self.plotter, args=(self.queue, show_options), daemon=True) + self.plot_process.start() + + def close(self): + self.plot_process.kill() + + def eval(self, arg: list | tuple) -> list: send = self.queue.put args_dict = {} for i, s in enumerate(nlpsol_out()): @@ -204,7 +242,7 @@ def __init__(self, ocp): self.ocp = ocp - def __call__(self, pipe: mp.Queue, show_options: dict): + def __call__(self, pipe: mp.Queue, show_options: dict | None): """ Parameters ---------- @@ -239,3 +277,92 @@ def callback(self) -> bool: for i, fig in enumerate(self.plot.all_figures): fig.canvas.draw() return True + + +class OnlineCallbackServer: + class _ServerMessages(Enum): + INITIATE_CONNEXION = 0 + NEW_DATA = 1 + CLOSE_CONNEXION = 2 + + def __init__(self): + # Define the host and port + self._host = "localhost" + self._port = 3050 + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._initialize_connexion() + + self._plotter: PlotOcp = None + + def _initialize_connexion(self): + # Start listening to the server + self._socket.bind((self._host, self._port)) + self._socket.listen(5) + print(f"Server started on {self._host}:{self._port}") + + while True: + client_socket, addr = self._socket.accept() + print(f"Connection from {addr}") + + # Receive the actual data + data = b"" + while True: + chunk = client_socket.recv(1024) + if not chunk: + break + data += chunk + data_as_list = data.decode().split("\n") + + if data_as_list[0] == OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION: + print(f"Received from client: {data_as_list[1]}") + + response = "Hello from server!" + client_socket.send(response.encode()) + # TODO Get the OCP and show_options from the client + # ocp = data_as_list[1] + # show_options = data_as_list[2] + # self._initialize_plotter(ocp, show_options=show_options) + elif data_as_list[0] == OnlineCallbackServer._ServerMessages.NEW_DATA: + print(f"Received from client: {data_as_list[1]}") + + response = "Hello from server!" + client_socket.send(response.encode()) + # self._plotter.update_data(data_as_list[1]) + elif data_as_list[0] == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION: + print("Closing the server") + client_socket.close() + continue + else: + print("Unknown message received") + continue + + def _initialize_plotter(self, ocp, **show_options): + self._plotter = PlotOcp(ocp, **show_options) + + +class OnlineCallbackTcp(OnlineCallbackAbstract, OnlineCallbackServer): + def __init__(self, ocp, opts: dict = None, show_options: dict = None): + super(OnlineCallbackAbstract, self).__init__(ocp, opts, show_options) + super(OnlineCallbackServer, self).__init__() + + def _initialize_connexion(self): + # Start the client + try: + self._socket.connect((self._host, self._port)) + except ConnectionError: + print("Could not connect to the server, make sure it is running") + print(f"Connected to {self._host}:{self._port}") + + message = f"{OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION}\nHello from client!" + self._socket.send(message.encode()) + data = self._socket.recv(1024).decode() + print(f"Received from server: {data}") + + self.close() + + def close(self): + self._socket.send(f"{OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION}\nGoodbye from client!".encode()) + self._socket.close() + + def eval(self, arg: list | tuple) -> list: + self._socket.send(f"{OnlineCallbackServer._ServerMessages.NEW_DATA}\n{arg}".encode()) diff --git a/bioptim/gui/plot.py b/bioptim/gui/plot.py index d831bd2c5..fbb93b2d6 100644 --- a/bioptim/gui/plot.py +++ b/bioptim/gui/plot.py @@ -782,7 +782,7 @@ def update_data( def _compute_y_from_plot_func( self, custom_plot: CustomPlot, phase_idx, time_stepwise, dt, x_decision, x_stepwise, u, p, a, d - ) -> list[np.ndarray | list, ...]: + ) -> list[np.ndarray | list]: """ Compute the y data from the plot function @@ -1031,3 +1031,5 @@ def _generate_windows_size(nb: int) -> tuple: n_rows = int(round(np.sqrt(nb))) return n_rows + 1 if n_rows * n_rows < nb else n_rows, n_rows + + From a82292c097e20b887dec5fbd8f612a800f291d5d Mon Sep 17 00:00:00 2001 From: Pariterre Date: Thu, 25 Jul 2024 16:30:23 -0400 Subject: [PATCH 02/17] Started to add the capabilty to run matplotlib from a server --- bioptim/__init__.py | 3 + bioptim/gui/online_callback.py | 211 +++++--- bioptim/gui/plot.py | 504 +++++++++++++++++- bioptim/interfaces/interface_utils.py | 34 +- bioptim/misc/enums.py | 16 + .../optimization/optimal_control_program.py | 1 + bioptim/optimization/optimization_vector.py | 5 +- resources/bioptim_plotting_server.py | 9 + 8 files changed, 701 insertions(+), 82 deletions(-) create mode 100644 resources/bioptim_plotting_server.py diff --git a/bioptim/__init__.py b/bioptim/__init__.py index b4f2b2f4c..61c30a665 100644 --- a/bioptim/__init__.py +++ b/bioptim/__init__.py @@ -208,6 +208,7 @@ MagnitudeType, MultiCyclicCycleSolutions, PhaseDynamics, + ShowOnlineType, ) from .misc.mapping import BiMappingList, BiMapping, Mapping, NodeMapping, NodeMappingList, SelectionMapping, Dependency from .optimization.multi_start import MultiStart @@ -229,3 +230,5 @@ from .optimization.stochastic_optimal_control_program import StochasticOptimalControlProgram from .optimization.problem_type import SocpType from .misc.casadi_expand import lt, le, gt, ge, if_else, if_else_zero + +from .gui.online_callback import OnlineCallbackServer diff --git a/bioptim/gui/online_callback.py b/bioptim/gui/online_callback.py index 7f9827b80..45cb669c0 100644 --- a/bioptim/gui/online_callback.py +++ b/bioptim/gui/online_callback.py @@ -1,28 +1,18 @@ from abc import ABC, abstractmethod from enum import Enum +import json +import logging import multiprocessing as mp import socket +import struct -from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity -from matplotlib import pyplot as plt - -from .plot import PlotOcp - - -class OnlineCallbackType(Enum): - """ - The type of callback - Attributes - ---------- - MULTIPROCESS: int - Using multiprocessing - SERVER: int - Using a server to communicate with the client - """ +from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity, DM +from matplotlib import pyplot as plt +import numpy as np - MULTIPROCESS = 0 - SERVER = 1 +from .plot import PlotOcp, OcpSerializable +from ..optimization.optimization_vector import OptimizationVectorHelper class OnlineCallbackAbstract(Callback, ABC): @@ -180,7 +170,7 @@ def eval(self, arg: list | tuple) -> list: """ -class OnlineCallback(OnlineCallbackAbstract): +class OnlineCallbackMultiprocess(OnlineCallbackAbstract): """ Multiprocessing implementation of the online callback @@ -195,7 +185,7 @@ class OnlineCallback(OnlineCallbackAbstract): """ def __init__(self, ocp, opts: dict = None, show_options: dict = None): - super(OnlineCallback, self).__init__(ocp, opts, show_options) + super(OnlineCallbackMultiprocess, self).__init__(ocp, opts, show_options) self.queue = mp.Queue() self.plotter = self.ProcessPlotter(self.ocp) @@ -255,7 +245,10 @@ def __call__(self, pipe: mp.Queue, show_options: dict | None): if show_options is None: show_options = {} self.pipe = pipe - self.plot = PlotOcp(self.ocp, **show_options) + + dummy_phase_times = OptimizationVectorHelper.extract_step_times(self.ocp) + self.plot = PlotOcp(self.ocp, dummy_phase_times=dummy_phase_times, **show_options) + timer = self.plot.all_figures[0].canvas.new_timer(interval=10) timer.add_callback(self.callback) timer.start() @@ -285,84 +278,164 @@ class _ServerMessages(Enum): NEW_DATA = 1 CLOSE_CONNEXION = 2 - def __init__(self): + def _prepare_logger(self): + name = "OnlineCallbackServer" + console_handler = logging.StreamHandler() + formatter = logging.Formatter( + "{asctime} - {name}:{levelname} - {message}", + style="{", + datefmt="%Y-%m-%d %H:%M", + ) + console_handler.setFormatter(formatter) + + self._logger = logging.getLogger(name) + self._logger.addHandler(console_handler) + self._logger.setLevel(logging.INFO) + + def __init__(self, host: str = "localhost", port: int = 3050): + self._prepare_logger() + # Define the host and port - self._host = "localhost" - self._port = 3050 + self._host = host + self._port = port self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._initialize_connexion() - self._plotter: PlotOcp = None - def _initialize_connexion(self): + def run(self): # Start listening to the server self._socket.bind((self._host, self._port)) - self._socket.listen(5) - print(f"Server started on {self._host}:{self._port}") + self._socket.listen(1) + self._logger.debug(f"Server started on {self._host}:{self._port}") while True: + self._logger.info("Waiting for a new connexion") client_socket, addr = self._socket.accept() - print(f"Connection from {addr}") + self._handle_client(client_socket, addr) + def _handle_client(self, client_socket: socket.socket, addr: tuple): + self._logger.info(f"Connection from {addr}") + while True: # Receive the actual data - data = b"" - while True: - chunk = client_socket.recv(1024) - if not chunk: - break - data += chunk + try: + data = client_socket.recv(1024) + except: + self._logger.warning("Error while receiving data from client, closing connexion") + return + data_as_list = data.decode().split("\n") + self._logger.debug(f"Received from client: {data_as_list}") - if data_as_list[0] == OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION: - print(f"Received from client: {data_as_list[1]}") + if not data: + self._logger.info("The client closed the connexion") + return - response = "Hello from server!" - client_socket.send(response.encode()) - # TODO Get the OCP and show_options from the client - # ocp = data_as_list[1] - # show_options = data_as_list[2] - # self._initialize_plotter(ocp, show_options=show_options) - elif data_as_list[0] == OnlineCallbackServer._ServerMessages.NEW_DATA: - print(f"Received from client: {data_as_list[1]}") + try: + message_type = OnlineCallbackServer._ServerMessages(int(data_as_list[0])) + except ValueError: + self._logger.warning("Unknown message type received") + continue - response = "Hello from server!" - client_socket.send(response.encode()) + if message_type == OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION: + self._logger.debug(f"Received hand shake from client, len of OCP: {data_as_list[1]}") + ocp_len = data_as_list[1] + try: + ocp_data = client_socket.recv(int(ocp_len)) + except: + self._logger.warning("Error while receiving OCP data from client, closing connexion") + return + + data_json = json.loads(ocp_data) + + try: + dummy_time_vector = [] + for phase_times in data_json["dummy_phase_times"]: + dummy_time_vector.append([DM(v) for v in phase_times]) + del data_json["dummy_phase_times"] + except: + self._logger.warning("Error while extracting dummy time vector from OCP data, closing connexion") + return + + try: + ocp = OcpSerializable.deserialize(data_json) + except: + self._logger.warning("Error while deserializing OCP data from client, closing connexion") + return + + show_options = {} + self._plotter = PlotOcp(ocp, dummy_phase_times=dummy_time_vector, **show_options) + self._plotter.show() + continue + + elif message_type == OnlineCallbackServer._ServerMessages.NEW_DATA: + n_bytes = [int(d) for d in data_as_list[1][1:-1].split(",")] + n_points = [int(d / 8) for d in n_bytes] + all_data = [] + for n_byte, n_point in zip(n_bytes, n_points): + data = client_socket.recv(n_byte) + data_tp = struct.unpack("d" * n_point, data) + all_data.append(DM(data_tp)) + + self._logger.debug(f"Received new data from client") # self._plotter.update_data(data_as_list[1]) - elif data_as_list[0] == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION: - print("Closing the server") - client_socket.close() continue + + elif message_type == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION: + self._logger.info("Received close connexion from client") + client_socket.close() + return else: - print("Unknown message received") + self._logger.warning("Unknown message received") continue - def _initialize_plotter(self, ocp, **show_options): - self._plotter = PlotOcp(ocp, **show_options) +class OnlineCallbackTcp(OnlineCallbackAbstract): + def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str = "localhost", port: int = 3050): + super().__init__(ocp, opts, show_options) -class OnlineCallbackTcp(OnlineCallbackAbstract, OnlineCallbackServer): - def __init__(self, ocp, opts: dict = None, show_options: dict = None): - super(OnlineCallbackAbstract, self).__init__(ocp, opts, show_options) - super(OnlineCallbackServer, self).__init__() + self._host = host + self._port = port + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._initialize_connexion() def _initialize_connexion(self): # Start the client try: self._socket.connect((self._host, self._port)) except ConnectionError: - print("Could not connect to the server, make sure it is running") - print(f"Connected to {self._host}:{self._port}") - - message = f"{OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION}\nHello from client!" - self._socket.send(message.encode()) - data = self._socket.recv(1024).decode() - print(f"Received from server: {data}") - - self.close() + raise RuntimeError( + "Could not connect to the plotter server, make sure it is running " + "by calling 'OnlineCallbackServer().start()' on another python instance)" + ) + + ocp_plot = OcpSerializable.from_ocp(self.ocp).serialize() + ocp_plot["dummy_phase_times"] = [] + for phase_times in OptimizationVectorHelper.extract_step_times(self.ocp): + ocp_plot["dummy_phase_times"].append([np.array(v)[:, 0].tolist() for v in phase_times]) + serialized_ocp = json.dumps(ocp_plot).encode() + self._socket.send( + f"{OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION.value}\n{len(serialized_ocp)}".encode() + ) + + # TODO ADD SHOW OPTIONS to the send + self._socket.send(serialized_ocp) def close(self): - self._socket.send(f"{OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION}\nGoodbye from client!".encode()) + self._socket.send( + f"{OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION.value}\nGoodbye from client!".encode() + ) self._socket.close() def eval(self, arg: list | tuple) -> list: - self._socket.send(f"{OnlineCallbackServer._ServerMessages.NEW_DATA}\n{arg}".encode()) + arg_as_bytes = [] + for a in arg: + to_pack = np.array(a).T.tolist() + if len(to_pack) == 1: + to_pack = to_pack[0] + arg_as_bytes.append(struct.pack("d" * len(to_pack), *to_pack)) + + self._socket.send( + f"{OnlineCallbackServer._ServerMessages.NEW_DATA.value}\n{[len(a) for a in arg_as_bytes]}".encode() + ) + for a in arg_as_bytes: + self._socket.sendall(a) + return [0] diff --git a/bioptim/gui/plot.py b/bioptim/gui/plot.py index fbb93b2d6..f3c7f4feb 100644 --- a/bioptim/gui/plot.py +++ b/bioptim/gui/plot.py @@ -130,6 +130,501 @@ def __init__( self.all_variables_in_one_subplot = all_variables_in_one_subplot +class MappingSerializable: + map_idx: list[int] + oppose: list[int] + + def __init__(self, map_idx: list, oppose: list): + self.map_idx = map_idx + self.oppose = oppose + + def map(self, obj): + from ..misc.mapping import Mapping + + return Mapping.map(self, obj) + + @classmethod + def from_mapping(cls, mapping): + return cls( + map_idx=mapping.map_idx, + oppose=mapping.oppose, + ) + + def serialize(self): + return { + "map_idx": self.map_idx, + "oppose": self.oppose, + } + + @classmethod + def deserialize(cls, data): + return cls( + map_idx=data["map_idx"], + oppose=data["oppose"], + ) + + +class BiMappingSerializable: + to_first: MappingSerializable + to_second: MappingSerializable + + def __init__(self, to_first: MappingSerializable, to_second: MappingSerializable): + self.to_first = to_first + self.to_second = to_second + + @classmethod + def from_bimapping(cls, bimapping): + return cls( + to_first=MappingSerializable.from_mapping(bimapping.to_first), + to_second=MappingSerializable.from_mapping(bimapping.to_second), + ) + + def serialize(self): + return { + "to_first": self.to_first.serialize(), + "to_second": self.to_second.serialize(), + } + + @classmethod + def deserialize(cls, data): + return cls( + to_first=MappingSerializable.deserialize(data["to_first"]), + to_second=MappingSerializable.deserialize(data["to_second"]), + ) + + +class BoundsSerializable: + min: np.ndarray | DM + max: np.ndarray | DM + + def __init__(self, min: np.ndarray | DM, max: np.ndarray | DM): + self.min = min + self.max = max + + @classmethod + def from_bounds(cls, bounds): + return cls( + min=np.array(bounds.min), + max=np.array(bounds.max), + ) + + def serialize(self): + return { + "min": self.min.tolist(), + "max": self.max.tolist(), + } + + @classmethod + def deserialize(cls, data): + return cls( + min=np.array(data["min"]), + max=np.array(data["max"]), + ) + + +class CustomPlotSerializable: + function: Callable + type: PlotType + phase_mappings: BiMappingSerializable + legend: tuple | list + combine_to: str + color: str + linestyle: str + ylim: tuple | list + bounds: BoundsSerializable + node_idx: list | slice | range + label: list + compute_derivative: bool + integration_rule: QuadratureRule + parameters: dict[str, Any] + all_variables_in_one_subplot: bool + + def __init__( + self, + function: Callable, + plot_type: PlotType, + phase_mappings: BiMapping, + legend: tuple | list, + combine_to: str, + color: str, + linestyle: str, + ylim: tuple | list, + bounds: BoundsSerializable, + node_idx: list | slice | range, + label: list, + compute_derivative: bool, + integration_rule: QuadratureRule, + parameters: dict[str, Any], + all_variables_in_one_subplot: bool, + ): + self.function = None # TODO function + self.type = plot_type + self.phase_mappings = phase_mappings + self.legend = legend + self.combine_to = combine_to + self.color = color + self.linestyle = linestyle + self.ylim = ylim + self.bounds = bounds + self.node_idx = node_idx + self.label = label + self.compute_derivative = compute_derivative + self.integration_rule = integration_rule + self.parameters = None # TODO {key: value for key, value in parameters.items()} + self.all_variables_in_one_subplot = all_variables_in_one_subplot + + @classmethod + def from_custom_plot(cls, custom_plot: CustomPlot): + return cls( + function=custom_plot.function, + plot_type=custom_plot.type, + phase_mappings=BiMappingSerializable.from_bimapping(custom_plot.phase_mappings), + legend=custom_plot.legend, + combine_to=custom_plot.combine_to, + color=custom_plot.color, + linestyle=custom_plot.linestyle, + ylim=custom_plot.ylim, + bounds=BoundsSerializable.from_bounds(custom_plot.bounds), + node_idx=custom_plot.node_idx, + label=custom_plot.label, + compute_derivative=custom_plot.compute_derivative, + integration_rule=custom_plot.integration_rule, + parameters=custom_plot.parameters, + all_variables_in_one_subplot=custom_plot.all_variables_in_one_subplot, + ) + + def serialize(self): + return { + "function": self.function, + "type": self.type.value, + "phase_mappings": self.phase_mappings.serialize(), + "legend": self.legend, + "combine_to": self.combine_to, + "color": self.color, + "linestyle": self.linestyle, + "ylim": self.ylim, + "bounds": self.bounds.serialize(), + "node_idx": self.node_idx, + "label": self.label, + "compute_derivative": self.compute_derivative, + "integration_rule": self.integration_rule.value, + "parameters": self.parameters, + "all_variables_in_one_subplot": self.all_variables_in_one_subplot, + } + + @classmethod + def deserialize(cls, data): + return cls( + function=data["function"], + plot_type=PlotType(data["type"]), + phase_mappings=BiMappingSerializable.deserialize(data["phase_mappings"]), + legend=data["legend"], + combine_to=data["combine_to"], + color=data["color"], + linestyle=data["linestyle"], + ylim=data["ylim"], + bounds=data["bounds"], + node_idx=data["node_idx"], + label=data["label"], + compute_derivative=data["compute_derivative"], + integration_rule=QuadratureRule(data["integration_rule"]), + parameters=data["parameters"], + all_variables_in_one_subplot=data["all_variables_in_one_subplot"], + ) + + +class OptimizationVariableContainerSerializable: + node_index: int + shape: tuple[int, int] + + def __init__(self, node_index: int, shape: tuple[int, int], len: int): + self.node_index = node_index + self.shape = shape + self._len = len + + def __len__(self): + return self._len + + @classmethod + def from_container(cls, ovc): + from ..optimization.optimization_variable import OptimizationVariableContainer + + ovc: OptimizationVariableContainer = ovc + + return cls( + node_index=ovc.node_index, + shape=ovc.shape, + len=len(ovc), + ) + + def serialize(self): + return { + "node_index": self.node_index, + "shape": self.shape, + "len": self._len, + } + + @classmethod + def deserialize(cls, data): + return cls( + node_index=data["node_index"], + shape=data["shape"], + len=data["len"], + ) + + +class OdeSolverSerializable: + polynomial_degree: int + type: OdeSolver + + def __init__(self, polynomial_degree: int, type: OdeSolver): + self.polynomial_degree = polynomial_degree + self.type = type + + @classmethod + def from_ode_solver(cls, ode_solver): + from ..dynamics.ode_solver import OdeSolver + + ode_solver: OdeSolver = ode_solver + + return cls( + polynomial_degree=5, + type="ode", + ) + + def serialize(self): + return { + "polynomial_degree": self.polynomial_degree, + "type": self.type, + } + + @classmethod + def deserialize(cls, data): + return cls( + polynomial_degree=data["polynomial_degree"], + type=data["type"], + ) + + +class NlpSerializable: + ns: int + phase_idx: int + + n_states_nodes: int + states: OptimizationVariableContainerSerializable + states_dot: OptimizationVariableContainerSerializable + controls: OptimizationVariableContainerSerializable + algebraic_states: OptimizationVariableContainerSerializable + parameters: OptimizationVariableContainerSerializable + numerical_timeseries: OptimizationVariableContainerSerializable + + ode_solver: OdeSolverSerializable + plot: dict[str, CustomPlotSerializable] + + def __init__( + self, + ns: int, + phase_idx: int, + n_states_nodes: int, + states: OptimizationVariableContainerSerializable, + states_dot: OptimizationVariableContainerSerializable, + controls: OptimizationVariableContainerSerializable, + algebraic_states: OptimizationVariableContainerSerializable, + parameters: OptimizationVariableContainerSerializable, + numerical_timeseries: OptimizationVariableContainerSerializable, + ode_solver: OdeSolverSerializable, + plot: dict[str, CustomPlotSerializable], + ): + self.ns = ns + self.phase_idx = phase_idx + self.n_states_nodes = n_states_nodes + self.states = states + self.states_dot = states_dot + self.controls = controls + self.algebraic_states = algebraic_states + self.parameters = parameters + self.numerical_timeseries = numerical_timeseries + self.ode_solver = ode_solver + self.plot = plot + + @classmethod + def from_nlp(cls, nlp): + from ..optimization.non_linear_program import NonLinearProgram + + nlp: NonLinearProgram = nlp + + return cls( + ns=nlp.ns, + phase_idx=nlp.phase_idx, + n_states_nodes=nlp.n_states_nodes, + states=OptimizationVariableContainerSerializable.from_container(nlp.states), + states_dot=OptimizationVariableContainerSerializable.from_container(nlp.states_dot), + controls=OptimizationVariableContainerSerializable.from_container(nlp.controls), + algebraic_states=OptimizationVariableContainerSerializable.from_container(nlp.algebraic_states), + parameters=OptimizationVariableContainerSerializable.from_container(nlp.parameters), + numerical_timeseries=OptimizationVariableContainerSerializable.from_container(nlp.numerical_timeseries), + ode_solver=OdeSolverSerializable.from_ode_solver(nlp.ode_solver), + plot={key: CustomPlotSerializable.from_custom_plot(nlp.plot[key]) for key in nlp.plot}, + ) + + def serialize(self): + return { + "ns": self.ns, + "phase_idx": self.phase_idx, + "n_states_nodes": self.n_states_nodes, + "states": self.states.serialize(), + "states_dot": self.states_dot.serialize(), + "controls": self.controls.serialize(), + "algebraic_states": self.algebraic_states.serialize(), + "parameters": self.parameters.serialize(), + "numerical_timeseries": self.numerical_timeseries.serialize(), + "ode_solver": self.ode_solver.serialize(), + "plot": {key: plot.serialize() for key, plot in self.plot.items()}, + } + + @classmethod + def deserialize(cls, data): + return cls( + ns=data["ns"], + phase_idx=data["phase_idx"], + n_states_nodes=data["n_states_nodes"], + states=OptimizationVariableContainerSerializable.deserialize(data["states"]), + states_dot=OptimizationVariableContainerSerializable.deserialize(data["states_dot"]), + controls=OptimizationVariableContainerSerializable.deserialize(data["controls"]), + algebraic_states=OptimizationVariableContainerSerializable.deserialize(data["algebraic_states"]), + parameters=OptimizationVariableContainerSerializable.deserialize(data["parameters"]), + numerical_timeseries=OptimizationVariableContainerSerializable.deserialize(data["numerical_timeseries"]), + ode_solver=OdeSolverSerializable.deserialize(data["ode_solver"]), + plot={key: CustomPlotSerializable.deserialize(plot) for key, plot in data["plot"].items()}, + ) + + +class SaveIterationsInfoSerializable: + path_to_results: str + result_file_name: str | list[str] + nb_iter_save: int + current_iter: int + f_list: list[int] + + def __init__( + self, path_to_results: str, result_file_name: str, nb_iter_save: int, current_iter: int, f_list: list[int] + ): + self.path_to_results = path_to_results + self.result_file_name = result_file_name + self.nb_iter_save = nb_iter_save + self.current_iter = current_iter + self.f_list = f_list + + @classmethod + def from_save_iterations_info(cls, save_iterations_info): + from .ipopt_output_plot import SaveIterationsInfo + + save_iterations_info: SaveIterationsInfo = save_iterations_info + + if save_iterations_info is None: + return None + + return cls( + path_to_results=save_iterations_info.path_to_results, + result_file_name=save_iterations_info.result_file_name, + nb_iter_save=save_iterations_info.nb_iter_save, + current_iter=save_iterations_info.current_iter, + f_list=save_iterations_info.f_list, + ) + + def serialize(self): + return { + "path_to_results": self.path_to_results, + "result_file_name": self.result_file_name, + "nb_iter_save": self.nb_iter_save, + "current_iter": self.current_iter, + "f_list": self.f_list, + } + + @classmethod + def deserialize(cls, data): + return cls( + path_to_results=data["path_to_results"], + result_file_name=data["result_file_name"], + nb_iter_save=data["nb_iter_save"], + current_iter=data["current_iter"], + f_list=data["f_list"], + ) + + +class OcpSerializable: + n_phases: int + nlp: list[NlpSerializable] + + time_phase_mapping: BiMappingSerializable + + plot_ipopt_outputs: bool + plot_check_conditioning: bool + save_ipopt_iterations_info: SaveIterationsInfoSerializable + + def __init__( + self, + n_phases: int, + nlp: list[NlpSerializable], + time_phase_mapping: BiMappingSerializable, + plot_ipopt_outputs: bool, + plot_check_conditioning: bool, + save_ipopt_iterations_info: SaveIterationsInfoSerializable, + ): + self.n_phases = n_phases + self.nlp = nlp + + self.time_phase_mapping = time_phase_mapping + + self.plot_ipopt_outputs = plot_ipopt_outputs + self.plot_check_conditioning = plot_check_conditioning + self.save_ipopt_iterations_info = save_ipopt_iterations_info + + @classmethod + def from_ocp(cls, ocp): + from ..optimization.optimal_control_program import OptimalControlProgram + + ocp: OptimalControlProgram = ocp + + return cls( + n_phases=ocp.n_phases, + nlp=[NlpSerializable.from_nlp(nlp) for nlp in ocp.nlp], + time_phase_mapping=BiMappingSerializable.from_bimapping(ocp.time_phase_mapping), + plot_ipopt_outputs=ocp.plot_ipopt_outputs, + plot_check_conditioning=ocp.plot_check_conditioning, + save_ipopt_iterations_info=SaveIterationsInfoSerializable.from_save_iterations_info( + ocp.save_ipopt_iterations_info + ), + ) + + def serialize(self): + return { + "n_phases": self.n_phases, + "nlp": [nlp.serialize() for nlp in self.nlp], + "time_phase_mapping": self.time_phase_mapping.serialize(), + "plot_ipopt_outputs": self.plot_ipopt_outputs, + "plot_check_conditioning": self.plot_check_conditioning, + "save_ipopt_iterations_info": ( + None if self.save_ipopt_iterations_info is None else self.save_ipopt_iterations_info.serialize() + ), + } + + @classmethod + def deserialize(cls, data): + return cls( + n_phases=data["n_phases"], + nlp=[NlpSerializable.deserialize(nlp) for nlp in data["nlp"]], + time_phase_mapping=BiMappingSerializable.deserialize(data["time_phase_mapping"]), + plot_ipopt_outputs=data["plot_ipopt_outputs"], + plot_check_conditioning=data["plot_check_conditioning"], + save_ipopt_iterations_info=( + None + if data["save_ipopt_iterations_info"] is None + else SaveIterationsInfoSerializable.deserialize(data["save_ipopt_iterations_info"]) + ), + ) + + class PlotOcp: """ Attributes @@ -205,11 +700,12 @@ class PlotOcp: def __init__( self, - ocp, + ocp: OcpSerializable, automatically_organize: bool = True, show_bounds: bool = False, shooting_type: Shooting = Shooting.MULTIPLE, integrator: SolutionIntegrator = SolutionIntegrator.OCP, + dummy_phase_times: list[list[float]] = None, ): """ Prepares the figures during the simulation @@ -226,7 +722,8 @@ def __init__( The type of integration method integrator: SolutionIntegrator Use the ode defined by OCP or use a separate integrator provided by scipy - + dummy_phase_times: list[list[float]] + The time of each phase """ self.ocp = ocp self.plot_options = { @@ -247,7 +744,6 @@ def __init__( self.integrator = integrator # Emulate the time from Solution.time, this is just to give the size anyway - dummy_phase_times = OptimizationVectorHelper.extract_step_times(ocp, DM(np.ones(ocp.n_phases))) self._update_time_vector(dummy_phase_times) self.axes = {} @@ -1031,5 +1527,3 @@ def _generate_windows_size(nb: int) -> tuple: n_rows = int(round(np.sqrt(nb))) return n_rows + 1 if n_rows * n_rows < nb else n_rows, n_rows - - diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index 824cf7587..32f7574b4 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -6,10 +6,10 @@ from casadi import horzcat, vertcat, sum1, sum2, nlpsol, SX, MX, reshape from bioptim.optimization.solution.solution import Solution -from ..gui.online_callback import OnlineCallback +from ..gui.online_callback import OnlineCallbackMultiprocess, OnlineCallbackTcp from ..limits.path_conditions import Bounds from ..limits.penalty_helpers import PenaltyHelpers -from ..misc.enums import InterpolationType +from ..misc.enums import InterpolationType, ShowOnlineType from ..optimization.non_linear_program import NonLinearProgram @@ -24,10 +24,32 @@ def generic_online_optim(interface, ocp, show_options: dict = None): show_options: dict The options to pass to PlotOcp """ - - if platform != "linux": - raise RuntimeError("Online graphics are only available on Linux") - interface.options_common["iteration_callback"] = OnlineCallback(ocp, show_options=show_options) + show_type = ShowOnlineType.MULTIPROCESS + if "type" in show_options: + show_type = show_options["type"] + del show_options["type"] + + if show_type == ShowOnlineType.MULTIPROCESS: + if platform == "win32": + raise RuntimeError( + "Online ShowOnlineType.MULTIPROCESS is not supported on Windows. " + "You can add show_options={'type': ShowOnlineType.TCP} to the Solver declaration" + ) + interface.options_common["iteration_callback"] = OnlineCallbackMultiprocess(ocp, show_options=show_options) + elif show_type == ShowOnlineType.TCP: + host = None + if "host" in show_options: + host = show_options["host"] + del show_options["host"] + port = None + if "port" in show_options: + port = show_options["port"] + del show_options["port"] + interface.options_common["iteration_callback"] = OnlineCallbackTcp( + ocp, show_options=show_options, host=host, port=port + ) + else: + raise NotImplementedError(f"show_options['type']={show_type} is not implemented yet") def generic_solve(interface, expand_during_shake_tree=False) -> dict: diff --git a/bioptim/misc/enums.py b/bioptim/misc/enums.py index 785abb3fc..302f75a49 100644 --- a/bioptim/misc/enums.py +++ b/bioptim/misc/enums.py @@ -93,6 +93,22 @@ class PlotType(Enum): POINT = 3 # Point plot +class ShowOnlineType(Enum): + """ + The type of callback + + Attributes + ---------- + MULTIPROCESS: int + Using multiprocessing + SERVER: int + Using a server to communicate with the client + """ + + MULTIPROCESS = 0 + TCP = 1 + + class ControlType(Enum): """ Selection of valid controls diff --git a/bioptim/optimization/optimal_control_program.py b/bioptim/optimization/optimal_control_program.py index 81e6b540f..42595cf33 100644 --- a/bioptim/optimization/optimal_control_program.py +++ b/bioptim/optimization/optimal_control_program.py @@ -1363,6 +1363,7 @@ def prepare_plots( show_bounds=show_bounds, shooting_type=shooting_type, integrator=integrator, + dummy_phase_times=OptimizationVectorHelper.extract_step_times(self.ocp), ) def check_conditioning(self): diff --git a/bioptim/optimization/optimization_vector.py b/bioptim/optimization/optimization_vector.py index cd2a1a7e7..86c122a19 100644 --- a/bioptim/optimization/optimization_vector.py +++ b/bioptim/optimization/optimization_vector.py @@ -347,7 +347,7 @@ def extract_phase_dt(ocp, data: np.ndarray | DM) -> list: return list(out[:, 0]) @staticmethod - def extract_step_times(ocp, data: np.ndarray | DM) -> list: + def extract_step_times(ocp, data: np.ndarray | DM | None = None) -> list: """ Get the phase time. If time is optimized, the MX/SX values are replaced by their actual optimized time @@ -356,13 +356,14 @@ def extract_step_times(ocp, data: np.ndarray | DM) -> list: ocp: OptimalControlProgram A reference to the ocp data: np.ndarray | DM - The solution in a vector + The solution in a vector, if no data is provided, dummy data is used (it can be useful getting the dimensions) Returns ------- The phase time """ + data = DM(np.ones(ocp.n_phases)) if data is None else data phase_dt = OptimizationVectorHelper.extract_phase_dt(ocp, data) # Starts at zero diff --git a/resources/bioptim_plotting_server.py b/resources/bioptim_plotting_server.py new file mode 100644 index 000000000..a9108b923 --- /dev/null +++ b/resources/bioptim_plotting_server.py @@ -0,0 +1,9 @@ +from bioptim import OnlineCallbackServer + + +def main(): + OnlineCallbackServer().run() + + +if __name__ == "__main__": + main() From 4b6dbe13012cc6ffbf467daa5f18d83207c0f07a Mon Sep 17 00:00:00 2001 From: Pariterre Date: Mon, 29 Jul 2024 08:08:09 -0400 Subject: [PATCH 03/17] Reverteed some changes, working version for launching the plots --- bioptim/gui/online_callback.py | 146 +++++++++++------- bioptim/interfaces/interface_utils.py | 5 +- .../optimization/optimal_control_program.py | 2 +- bioptim/optimization/optimization_vector.py | 3 +- 4 files changed, 92 insertions(+), 64 deletions(-) diff --git a/bioptim/gui/online_callback.py b/bioptim/gui/online_callback.py index 45cb669c0..cb8b3f59a 100644 --- a/bioptim/gui/online_callback.py +++ b/bioptim/gui/online_callback.py @@ -5,6 +5,7 @@ import multiprocessing as mp import socket import struct +from typing import Callable from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity, DM @@ -230,7 +231,8 @@ def __init__(self, ocp): A reference to the ocp to show """ - self.ocp = ocp + self.ocp: OcpSerializable = ocp + self._plotter: PlotOcp = None def __call__(self, pipe: mp.Queue, show_options: dict | None): """ @@ -246,15 +248,14 @@ def __call__(self, pipe: mp.Queue, show_options: dict | None): show_options = {} self.pipe = pipe - dummy_phase_times = OptimizationVectorHelper.extract_step_times(self.ocp) - self.plot = PlotOcp(self.ocp, dummy_phase_times=dummy_phase_times, **show_options) - - timer = self.plot.all_figures[0].canvas.new_timer(interval=10) - timer.add_callback(self.callback) + dummy_phase_times = OptimizationVectorHelper.extract_step_times(self.ocp, DM(np.ones(self.ocp.n_phases))) + self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_phase_times, **show_options) + timer = self._plotter.all_figures[0].canvas.new_timer(interval=10) + timer.add_callback(self.plot_update) timer.start() plt.show() - def callback(self) -> bool: + def plot_update(self) -> bool: """ The callback to update the graphs @@ -265,13 +266,17 @@ def callback(self) -> bool: while not self.pipe.empty(): args = self.pipe.get() - self.plot.update_data(args) + self._plotter.update_data(args) - for i, fig in enumerate(self.plot.all_figures): + for i, fig in enumerate(self._plotter.all_figures): fig.canvas.draw() return True +_default_host = "localhost" +_default_port = 3050 + + class OnlineCallbackServer: class _ServerMessages(Enum): INITIATE_CONNEXION = 0 @@ -290,14 +295,14 @@ def _prepare_logger(self): self._logger = logging.getLogger(name) self._logger.addHandler(console_handler) - self._logger.setLevel(logging.INFO) + self._logger.setLevel(logging.DEBUG) - def __init__(self, host: str = "localhost", port: int = 3050): + def __init__(self, host: str = None, port: int = None): self._prepare_logger() # Define the host and port - self._host = host - self._port = port + self._host = host if host else _default_host + self._port = port if port else _default_port self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._plotter: PlotOcp = None @@ -305,18 +310,24 @@ def run(self): # Start listening to the server self._socket.bind((self._host, self._port)) self._socket.listen(1) - self._logger.debug(f"Server started on {self._host}:{self._port}") + self._logger.info(f"Server started on {self._host}:{self._port}") - while True: - self._logger.info("Waiting for a new connexion") - client_socket, addr = self._socket.accept() - self._handle_client(client_socket, addr) + try: + while True: + self._logger.info("Waiting for a new connexion") + client_socket, addr = self._socket.accept() + self._logger.info(f"Connection from {addr}") + self._handle_client(client_socket, addr) + except Exception as e: + self._logger.error(f"Error while running the server: {e}") + finally: + self._socket.close() def _handle_client(self, client_socket: socket.socket, addr: tuple): - self._logger.info(f"Connection from {addr}") while True: # Receive the actual data try: + self._logger.debug("Waiting for data from client") data = client_socket.recv(1024) except: self._logger.warning("Error while receiving data from client, closing connexion") @@ -327,6 +338,7 @@ def _handle_client(self, client_socket: socket.socket, addr: tuple): if not data: self._logger.info("The client closed the connexion") + plt.close() return try: @@ -336,64 +348,78 @@ def _handle_client(self, client_socket: socket.socket, addr: tuple): continue if message_type == OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION: - self._logger.debug(f"Received hand shake from client, len of OCP: {data_as_list[1]}") - ocp_len = data_as_list[1] - try: - ocp_data = client_socket.recv(int(ocp_len)) - except: - self._logger.warning("Error while receiving OCP data from client, closing connexion") - return - - data_json = json.loads(ocp_data) - - try: - dummy_time_vector = [] - for phase_times in data_json["dummy_phase_times"]: - dummy_time_vector.append([DM(v) for v in phase_times]) - del data_json["dummy_phase_times"] - except: - self._logger.warning("Error while extracting dummy time vector from OCP data, closing connexion") - return + self._initiate_connexion(client_socket, data_as_list) + continue + elif message_type == OnlineCallbackServer._ServerMessages.NEW_DATA: try: - ocp = OcpSerializable.deserialize(data_json) + self._update_data(client_socket, data_as_list) except: - self._logger.warning("Error while deserializing OCP data from client, closing connexion") + self._logger.warning("Error while updating data from client, closing connexion") + plt.close() + client_socket.close() return - - show_options = {} - self._plotter = PlotOcp(ocp, dummy_phase_times=dummy_time_vector, **show_options) - self._plotter.show() - continue - - elif message_type == OnlineCallbackServer._ServerMessages.NEW_DATA: - n_bytes = [int(d) for d in data_as_list[1][1:-1].split(",")] - n_points = [int(d / 8) for d in n_bytes] - all_data = [] - for n_byte, n_point in zip(n_bytes, n_points): - data = client_socket.recv(n_byte) - data_tp = struct.unpack("d" * n_point, data) - all_data.append(DM(data_tp)) - - self._logger.debug(f"Received new data from client") - # self._plotter.update_data(data_as_list[1]) continue elif message_type == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION: self._logger.info("Received close connexion from client") client_socket.close() + plt.close() return else: self._logger.warning("Unknown message received") continue + def _initiate_connexion(self, client_socket: socket.socket, data_as_list: list): + self._logger.debug(f"Received hand shake from client, len of OCP: {data_as_list[1]}") + ocp_len = data_as_list[1] + try: + ocp_data = client_socket.recv(int(ocp_len)) + except: + self._logger.warning("Error while receiving OCP data from client, closing connexion") + return + + data_json = json.loads(ocp_data) + + try: + dummy_time_vector = [] + for phase_times in data_json["dummy_phase_times"]: + dummy_time_vector.append([DM(v) for v in phase_times]) + del data_json["dummy_phase_times"] + except: + self._logger.warning("Error while extracting dummy time vector from OCP data, closing connexion") + return + + try: + self.ocp = OcpSerializable.deserialize(data_json) + except: + self._logger.warning("Error while deserializing OCP data from client, closing connexion") + return + + show_options = {} + self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) + plt.ion() + plt.draw() # TODO HERE! + + def _update_data(self, client_socket: socket.socket, data_as_list: list): + n_bytes = [int(d) for d in data_as_list[1][1:-1].split(",")] + n_points = [int(d / 8) for d in n_bytes] + all_data = [] + for n_byte, n_point in zip(n_bytes, n_points): + data = client_socket.recv(n_byte) + data_tp = struct.unpack("d" * n_point, data) + all_data.append(DM(data_tp)) + + self._logger.debug(f"Received new data from client") + # self._plotter.update_data(data_as_list[1]) + class OnlineCallbackTcp(OnlineCallbackAbstract): - def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str = "localhost", port: int = 3050): + def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str = None, port: int = None): super().__init__(ocp, opts, show_options) - self._host = host - self._port = port + self._host = host if host else _default_host + self._port = port if port else _default_port self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._initialize_connexion() @@ -409,7 +435,7 @@ def _initialize_connexion(self): ocp_plot = OcpSerializable.from_ocp(self.ocp).serialize() ocp_plot["dummy_phase_times"] = [] - for phase_times in OptimizationVectorHelper.extract_step_times(self.ocp): + for phase_times in OptimizationVectorHelper.extract_step_times(self.ocp, DM(np.ones(self.ocp.n_phases))): ocp_plot["dummy_phase_times"].append([np.array(v)[:, 0].tolist() for v in phase_times]) serialized_ocp = json.dumps(ocp_plot).encode() self._socket.send( diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index 32f7574b4..46bdd33c0 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -13,7 +13,7 @@ from ..optimization.non_linear_program import NonLinearProgram -def generic_online_optim(interface, ocp, show_options: dict = None): +def generic_online_optim(interface, ocp, show_options: dict | None = None): """ Declare the online callback to update the graphs while optimizing @@ -24,6 +24,9 @@ def generic_online_optim(interface, ocp, show_options: dict = None): show_options: dict The options to pass to PlotOcp """ + if show_options is None: + show_options = {} + show_type = ShowOnlineType.MULTIPROCESS if "type" in show_options: show_type = show_options["type"] diff --git a/bioptim/optimization/optimal_control_program.py b/bioptim/optimization/optimal_control_program.py index 42595cf33..19a44477a 100644 --- a/bioptim/optimization/optimal_control_program.py +++ b/bioptim/optimization/optimal_control_program.py @@ -1363,7 +1363,7 @@ def prepare_plots( show_bounds=show_bounds, shooting_type=shooting_type, integrator=integrator, - dummy_phase_times=OptimizationVectorHelper.extract_step_times(self.ocp), + dummy_phase_times=OptimizationVectorHelper.extract_step_times(self.ocp, casadi.DM(np.ones(self.ocp.n_phases))), ) def check_conditioning(self): diff --git a/bioptim/optimization/optimization_vector.py b/bioptim/optimization/optimization_vector.py index 86c122a19..d109be9a7 100644 --- a/bioptim/optimization/optimization_vector.py +++ b/bioptim/optimization/optimization_vector.py @@ -347,7 +347,7 @@ def extract_phase_dt(ocp, data: np.ndarray | DM) -> list: return list(out[:, 0]) @staticmethod - def extract_step_times(ocp, data: np.ndarray | DM | None = None) -> list: + def extract_step_times(ocp, data: np.ndarray | DM) -> list: """ Get the phase time. If time is optimized, the MX/SX values are replaced by their actual optimized time @@ -363,7 +363,6 @@ def extract_step_times(ocp, data: np.ndarray | DM | None = None) -> list: The phase time """ - data = DM(np.ones(ocp.n_phases)) if data is None else data phase_dt = OptimizationVectorHelper.extract_phase_dt(ocp, data) # Starts at zero From c52b3772a795f2cddac84e331bdd43c691e291f3 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Mon, 29 Jul 2024 17:52:03 -0400 Subject: [PATCH 04/17] Made penalties work Server GUI --- bioptim/gui/online_callback.py | 310 +++++--- bioptim/gui/plot.py | 686 +++--------------- bioptim/gui/serializable_class.py | 623 ++++++++++++++++ bioptim/interfaces/interface_utils.py | 11 + .../optimization/optimal_control_program.py | 2 +- bioptim/optimization/solution/solution.py | 2 +- 6 files changed, 972 insertions(+), 662 deletions(-) create mode 100644 bioptim/gui/serializable_class.py diff --git a/bioptim/gui/online_callback.py b/bioptim/gui/online_callback.py index cb8b3f59a..5b2c077ab 100644 --- a/bioptim/gui/online_callback.py +++ b/bioptim/gui/online_callback.py @@ -5,8 +5,7 @@ import multiprocessing as mp import socket import struct -from typing import Callable - +import threading from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity, DM from matplotlib import pyplot as plt @@ -41,7 +40,7 @@ class OnlineCallbackAbstract(Callback, ABC): Get the name of the output variable get_sparsity_in(self, i: int) -> tuple[int] Get the sparsity of a specific variable - eval(self, arg: list | tuple) -> list[int] + eval(self, arg: list | tuple, force: bool = False) -> list[int] Send the current data to the plotter """ @@ -156,7 +155,7 @@ def get_sparsity_in(self, i: int) -> tuple: return Sparsity(0, 0) @abstractmethod - def eval(self, arg: list | tuple) -> list: + def eval(self, arg: list | tuple, force: bool = False) -> list: """ Send the current data to the plotter @@ -196,7 +195,7 @@ def __init__(self, ocp, opts: dict = None, show_options: dict = None): def close(self): self.plot_process.kill() - def eval(self, arg: list | tuple) -> list: + def eval(self, arg: list | tuple, force: bool = False) -> list: send = self.queue.put args_dict = {} for i, s in enumerate(nlpsol_out()): @@ -266,7 +265,8 @@ def plot_update(self) -> bool: while not self.pipe.empty(): args = self.pipe.get() - self._plotter.update_data(args) + data = self._plotter.parse_data(**args) + self._plotter.update_data(**data, **args) for i, fig in enumerate(self._plotter.all_figures): fig.canvas.draw() @@ -282,6 +282,9 @@ class _ServerMessages(Enum): INITIATE_CONNEXION = 0 NEW_DATA = 1 CLOSE_CONNEXION = 2 + EMPTY = 3 + TOO_SOON = 4 + UNKNOWN = 5 def _prepare_logger(self): name = "OnlineCallbackServer" @@ -299,6 +302,8 @@ def _prepare_logger(self): def __init__(self, host: str = None, port: int = None): self._prepare_logger() + self._get_data_interval = 1.0 + self._update_plot_interval = 0.01 # Define the host and port self._host = host if host else _default_host @@ -317,71 +322,59 @@ def run(self): self._logger.info("Waiting for a new connexion") client_socket, addr = self._socket.accept() self._logger.info(f"Connection from {addr}") - self._handle_client(client_socket, addr) + self._wait_for_new_connexion(client_socket) except Exception as e: self._logger.error(f"Error while running the server: {e}") finally: self._socket.close() - def _handle_client(self, client_socket: socket.socket, addr: tuple): - while True: - # Receive the actual data - try: - self._logger.debug("Waiting for data from client") - data = client_socket.recv(1024) - except: - self._logger.warning("Error while receiving data from client, closing connexion") - return - - data_as_list = data.decode().split("\n") - self._logger.debug(f"Received from client: {data_as_list}") - - if not data: - self._logger.info("The client closed the connexion") - plt.close() - return - - try: - message_type = OnlineCallbackServer._ServerMessages(int(data_as_list[0])) - except ValueError: - self._logger.warning("Unknown message type received") - continue - - if message_type == OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION: - self._initiate_connexion(client_socket, data_as_list) - continue - - elif message_type == OnlineCallbackServer._ServerMessages.NEW_DATA: - try: - self._update_data(client_socket, data_as_list) - except: - self._logger.warning("Error while updating data from client, closing connexion") - plt.close() - client_socket.close() - return - continue - - elif message_type == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION: - self._logger.info("Received close connexion from client") - client_socket.close() - plt.close() - return - else: - self._logger.warning("Unknown message received") - continue - - def _initiate_connexion(self, client_socket: socket.socket, data_as_list: list): - self._logger.debug(f"Received hand shake from client, len of OCP: {data_as_list[1]}") - ocp_len = data_as_list[1] + def _wait_for_data(self, client_socket: socket.socket): + # Receive the actual data try: - ocp_data = client_socket.recv(int(ocp_len)) + self._logger.debug("Waiting for data from client") + data = client_socket.recv(1024) + if not data: + return OnlineCallbackServer._ServerMessages.EMPTY, None except: - self._logger.warning("Error while receiving OCP data from client, closing connexion") - return - - data_json = json.loads(ocp_data) + self._logger.warning("Client closed connexion") + client_socket.close() + return OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION, None + data_as_list = data.decode().split("\n") + try: + message_type = OnlineCallbackServer._ServerMessages(int(data_as_list[0])) + len_all_data = [int(len_data) for len_data in data_as_list[1][1:-1].split(",")] + # Sends confirmation and waits for the next message + client_socket.send("OK".encode()) + self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)") + data_out = [] + for len_data in len_all_data: + data_out.append(client_socket.recv(len_data)) + client_socket.send("OK".encode()) + except ValueError: + self._logger.warning("Unknown message type received") + message_type = OnlineCallbackServer._ServerMessages.UNKNOWN + # Sends failure + client_socket.send("NOK".encode()) + data_out = [] + + if message_type == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION: + self._logger.info("Received close connexion from client") + client_socket.close() + plt.close() + return OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION, None + + return message_type, data_out + + def _wait_for_new_connexion(self, client_socket: socket.socket): + message_type, data = self._wait_for_data(client_socket=client_socket) + if message_type == OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION: + self._logger.debug(f"Received hand shake from client") + self._initialize_plotter(client_socket, data) + + def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list): try: + data_json = json.loads(ocp_raw[0]) dummy_time_vector = [] for phase_times in data_json["dummy_phase_times"]: dummy_time_vector.append([DM(v) for v in phase_times]) @@ -393,25 +386,105 @@ def _initiate_connexion(self, client_socket: socket.socket, data_as_list: list): try: self.ocp = OcpSerializable.deserialize(data_json) except: + client_socket.send("FAILED".encode()) self._logger.warning("Error while deserializing OCP data from client, closing connexion") return show_options = {} self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) - plt.ion() - plt.draw() # TODO HERE! - - def _update_data(self, client_socket: socket.socket, data_as_list: list): - n_bytes = [int(d) for d in data_as_list[1][1:-1].split(",")] - n_points = [int(d / 8) for d in n_bytes] - all_data = [] - for n_byte, n_point in zip(n_bytes, n_points): - data = client_socket.recv(n_byte) - data_tp = struct.unpack("d" * n_point, data) - all_data.append(DM(data_tp)) + + # Send the confirmation to the client + client_socket.send("PLOT_READY".encode()) + + # Start the callbacks + threading.Timer(self._get_data_interval, self._wait_for_new_data, (client_socket,)).start() + threading.Timer(self._update_plot_interval, self._redraw).start() + plt.show() + + def _redraw(self): + self._logger.debug("Updating plot") + for _, fig in enumerate(self._plotter.all_figures): + fig.canvas.draw() + + if [plt.fignum_exists(fig.number) for fig in self._plotter.all_figures].count(True) > 0: + threading.Timer(self._update_plot_interval, self._redraw).start() + else: + self._logger.info("All figures have been closed, stop updating the plots") + + def _wait_for_new_data(self, client_socket: socket.socket) -> bool: + """ + The callback to update the graphs + + Returns + ------- + True if everything went well + """ + self._logger.debug(f"Waiting for new data from client") + client_socket.send("READY_FOR_NEXT_DATA".encode()) + + should_continue = False + message_type, data = self._wait_for_data(client_socket=client_socket) + if message_type == OnlineCallbackServer._ServerMessages.NEW_DATA: + try: + self._update_data(data) + should_continue = True + except: + self._logger.warning("Error while updating data from client, closing connexion") + plt.close() + client_socket.close() + elif ( + message_type == OnlineCallbackServer._ServerMessages.EMPTY + or message_type == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION + ): + self._logger.debug("Received empty data from client (end of stream), closing connexion") + + if should_continue: + timer_get_data = threading.Timer(self._get_data_interval, self._wait_for_new_data, (client_socket,)) + timer_get_data.start() + + def _update_data(self, data_raw: list): + header = [int(v) for v in data_raw[0].decode().split(",")] + + data = data_raw[1] + all_data = np.array(struct.unpack("d" * (len(data) // 8), data)) + + header_cmp = 0 + all_data_cmp = 0 + xdata = [] + n_phases = header[header_cmp] + header_cmp += 1 + for _ in range(n_phases): + n_nodes = header[header_cmp] + header_cmp += 1 + x_phases = [] + for _ in range(n_nodes): + n_steps = header[header_cmp] + header_cmp += 1 + + x_phases.append(all_data[all_data_cmp : all_data_cmp + n_steps]) + all_data_cmp += n_steps + xdata.append(x_phases) + + ydata = [] + n_variables = header[header_cmp] + header_cmp += 1 + for _ in range(n_variables): + n_nodes = header[header_cmp] + header_cmp += 1 + if n_nodes == 0: + n_nodes = 1 + + y_variables = [] + for _ in range(n_nodes): + n_steps = header[header_cmp] + header_cmp += 1 + + y_variables.append(all_data[all_data_cmp : all_data_cmp + n_steps]) + all_data_cmp += n_steps + ydata.append(y_variables) self._logger.debug(f"Received new data from client") - # self._plotter.update_data(data_as_list[1]) + self._plotter.update_data(xdata, ydata) class OnlineCallbackTcp(OnlineCallbackAbstract): @@ -421,9 +494,21 @@ def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str self._host = host if host else _default_host self._port = port if port else _default_port self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._initialize_connexion() - def _initialize_connexion(self): + if self.ocp.plot_ipopt_outputs: + raise NotImplementedError("The online callback with TCP does not support the plot_ipopt_outputs option") + if self.ocp.save_ipopt_iterations_info: + raise NotImplementedError( + "The online callback with TCP does not support the save_ipopt_iterations_info option" + ) + if self.ocp.plot_check_conditioning: + raise NotImplementedError( + "The online callback with TCP does not support the plot_check_conditioning option" + ) + + self._initialize_connexion(**show_options) + + def _initialize_connexion(self, **show_options): # Start the client try: self._socket.connect((self._host, self._port)) @@ -434,16 +519,32 @@ def _initialize_connexion(self): ) ocp_plot = OcpSerializable.from_ocp(self.ocp).serialize() + dummy_phase_times = OptimizationVectorHelper.extract_step_times(self.ocp, DM(np.ones(self.ocp.n_phases))) ocp_plot["dummy_phase_times"] = [] - for phase_times in OptimizationVectorHelper.extract_step_times(self.ocp, DM(np.ones(self.ocp.n_phases))): + for phase_times in dummy_phase_times: ocp_plot["dummy_phase_times"].append([np.array(v)[:, 0].tolist() for v in phase_times]) serialized_ocp = json.dumps(ocp_plot).encode() + + # Sends message type and dimensions self._socket.send( - f"{OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION.value}\n{len(serialized_ocp)}".encode() + f"{OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp)]}".encode() ) + if self._socket.recv(1024).decode() != "OK": + raise RuntimeError("The server did not acknowledge the connexion") # TODO ADD SHOW OPTIONS to the send self._socket.send(serialized_ocp) + if self._socket.recv(1024).decode() != "OK": + raise RuntimeError("The server did not acknowledge the connexion") + + # Wait for the server to be ready + data = self._socket.recv(1024).decode().split("\n") + if data[0] != "PLOT_READY": + raise RuntimeError("The server did not acknowledge the OCP data, this should not happen, please report") + + self._plotter = PlotOcp( + self.ocp, only_initialize_variables=True, dummy_phase_times=dummy_phase_times, **show_options + ) def close(self): self._socket.send( @@ -451,7 +552,7 @@ def close(self): ) self._socket.close() - def eval(self, arg: list | tuple) -> list: + def eval(self, arg: list | tuple, force: bool = False) -> list: arg_as_bytes = [] for a in arg: to_pack = np.array(a).T.tolist() @@ -459,9 +560,54 @@ def eval(self, arg: list | tuple) -> list: to_pack = to_pack[0] arg_as_bytes.append(struct.pack("d" * len(to_pack), *to_pack)) + if not force: + self._socket.setblocking(False) + + try: + data = self._socket.recv(1024).decode() + if data != "READY_FOR_NEXT_DATA": + return [0] + except BlockingIOError: + # This is to prevent the solving to be blocked by the server if it is not ready to update the plots + return [0] + finally: + self._socket.setblocking(True) + + args_dict = {} + for i, s in enumerate(nlpsol_out()): + args_dict[s] = arg[i] + xdata_raw, ydata_raw = self._plotter.parse_data(**args_dict) + + header = f"{len(xdata_raw)}" + data_serialized = b"" + for x_nodes in xdata_raw: + header += f",{len(x_nodes)}" + for x_steps in x_nodes: + header += f",{x_steps.shape[0]}" + x_steps_tp = np.array(x_steps)[:, 0].tolist() + data_serialized += struct.pack("d" * len(x_steps_tp), *x_steps_tp) + + header += f",{len(ydata_raw)}" + for y_nodes_variable in ydata_raw: + if isinstance(y_nodes_variable, np.ndarray): + header += f",0" + y_nodes_variable = [y_nodes_variable] + else: + header += f",{len(y_nodes_variable)}" + + for y_steps in y_nodes_variable: + header += f",{y_steps.shape[0]}" + y_steps_tp = y_steps.tolist() + data_serialized += struct.pack("d" * len(y_steps_tp), *y_steps_tp) + self._socket.send( - f"{OnlineCallbackServer._ServerMessages.NEW_DATA.value}\n{[len(a) for a in arg_as_bytes]}".encode() + f"{OnlineCallbackServer._ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode() ) - for a in arg_as_bytes: - self._socket.sendall(a) + if self._socket.recv(1024).decode() != "OK": + raise RuntimeError("The server did not acknowledge the data") + + for to_send in [header.encode(), data_serialized]: + self._socket.send(to_send) + if self._socket.recv(1024).decode() != "OK": + raise RuntimeError("The server did not acknowledge the data") return [0] diff --git a/bioptim/gui/plot.py b/bioptim/gui/plot.py index f3c7f4feb..f04e07377 100644 --- a/bioptim/gui/plot.py +++ b/bioptim/gui/plot.py @@ -6,15 +6,16 @@ from matplotlib import pyplot as plt, lines from matplotlib.ticker import StrMethodFormatter +from .serializable_class import OcpSerializable from ..dynamics.ode_solver import OdeSolver from ..limits.path_conditions import Bounds from ..limits.penalty_helpers import PenaltyHelpers from ..misc.enums import PlotType, Shooting, SolutionIntegrator, QuadratureRule, InterpolationType from ..misc.mapping import Mapping, BiMapping -from ..optimization.optimization_vector import OptimizationVectorHelper from ..optimization.solution.solution import Solution from ..optimization.solution.solution_data import SolutionMerge + DEFAULT_COLORS = { PlotType.PLOT: "tab:green", PlotType.INTEGRATED: "tab:brown", @@ -130,501 +131,6 @@ def __init__( self.all_variables_in_one_subplot = all_variables_in_one_subplot -class MappingSerializable: - map_idx: list[int] - oppose: list[int] - - def __init__(self, map_idx: list, oppose: list): - self.map_idx = map_idx - self.oppose = oppose - - def map(self, obj): - from ..misc.mapping import Mapping - - return Mapping.map(self, obj) - - @classmethod - def from_mapping(cls, mapping): - return cls( - map_idx=mapping.map_idx, - oppose=mapping.oppose, - ) - - def serialize(self): - return { - "map_idx": self.map_idx, - "oppose": self.oppose, - } - - @classmethod - def deserialize(cls, data): - return cls( - map_idx=data["map_idx"], - oppose=data["oppose"], - ) - - -class BiMappingSerializable: - to_first: MappingSerializable - to_second: MappingSerializable - - def __init__(self, to_first: MappingSerializable, to_second: MappingSerializable): - self.to_first = to_first - self.to_second = to_second - - @classmethod - def from_bimapping(cls, bimapping): - return cls( - to_first=MappingSerializable.from_mapping(bimapping.to_first), - to_second=MappingSerializable.from_mapping(bimapping.to_second), - ) - - def serialize(self): - return { - "to_first": self.to_first.serialize(), - "to_second": self.to_second.serialize(), - } - - @classmethod - def deserialize(cls, data): - return cls( - to_first=MappingSerializable.deserialize(data["to_first"]), - to_second=MappingSerializable.deserialize(data["to_second"]), - ) - - -class BoundsSerializable: - min: np.ndarray | DM - max: np.ndarray | DM - - def __init__(self, min: np.ndarray | DM, max: np.ndarray | DM): - self.min = min - self.max = max - - @classmethod - def from_bounds(cls, bounds): - return cls( - min=np.array(bounds.min), - max=np.array(bounds.max), - ) - - def serialize(self): - return { - "min": self.min.tolist(), - "max": self.max.tolist(), - } - - @classmethod - def deserialize(cls, data): - return cls( - min=np.array(data["min"]), - max=np.array(data["max"]), - ) - - -class CustomPlotSerializable: - function: Callable - type: PlotType - phase_mappings: BiMappingSerializable - legend: tuple | list - combine_to: str - color: str - linestyle: str - ylim: tuple | list - bounds: BoundsSerializable - node_idx: list | slice | range - label: list - compute_derivative: bool - integration_rule: QuadratureRule - parameters: dict[str, Any] - all_variables_in_one_subplot: bool - - def __init__( - self, - function: Callable, - plot_type: PlotType, - phase_mappings: BiMapping, - legend: tuple | list, - combine_to: str, - color: str, - linestyle: str, - ylim: tuple | list, - bounds: BoundsSerializable, - node_idx: list | slice | range, - label: list, - compute_derivative: bool, - integration_rule: QuadratureRule, - parameters: dict[str, Any], - all_variables_in_one_subplot: bool, - ): - self.function = None # TODO function - self.type = plot_type - self.phase_mappings = phase_mappings - self.legend = legend - self.combine_to = combine_to - self.color = color - self.linestyle = linestyle - self.ylim = ylim - self.bounds = bounds - self.node_idx = node_idx - self.label = label - self.compute_derivative = compute_derivative - self.integration_rule = integration_rule - self.parameters = None # TODO {key: value for key, value in parameters.items()} - self.all_variables_in_one_subplot = all_variables_in_one_subplot - - @classmethod - def from_custom_plot(cls, custom_plot: CustomPlot): - return cls( - function=custom_plot.function, - plot_type=custom_plot.type, - phase_mappings=BiMappingSerializable.from_bimapping(custom_plot.phase_mappings), - legend=custom_plot.legend, - combine_to=custom_plot.combine_to, - color=custom_plot.color, - linestyle=custom_plot.linestyle, - ylim=custom_plot.ylim, - bounds=BoundsSerializable.from_bounds(custom_plot.bounds), - node_idx=custom_plot.node_idx, - label=custom_plot.label, - compute_derivative=custom_plot.compute_derivative, - integration_rule=custom_plot.integration_rule, - parameters=custom_plot.parameters, - all_variables_in_one_subplot=custom_plot.all_variables_in_one_subplot, - ) - - def serialize(self): - return { - "function": self.function, - "type": self.type.value, - "phase_mappings": self.phase_mappings.serialize(), - "legend": self.legend, - "combine_to": self.combine_to, - "color": self.color, - "linestyle": self.linestyle, - "ylim": self.ylim, - "bounds": self.bounds.serialize(), - "node_idx": self.node_idx, - "label": self.label, - "compute_derivative": self.compute_derivative, - "integration_rule": self.integration_rule.value, - "parameters": self.parameters, - "all_variables_in_one_subplot": self.all_variables_in_one_subplot, - } - - @classmethod - def deserialize(cls, data): - return cls( - function=data["function"], - plot_type=PlotType(data["type"]), - phase_mappings=BiMappingSerializable.deserialize(data["phase_mappings"]), - legend=data["legend"], - combine_to=data["combine_to"], - color=data["color"], - linestyle=data["linestyle"], - ylim=data["ylim"], - bounds=data["bounds"], - node_idx=data["node_idx"], - label=data["label"], - compute_derivative=data["compute_derivative"], - integration_rule=QuadratureRule(data["integration_rule"]), - parameters=data["parameters"], - all_variables_in_one_subplot=data["all_variables_in_one_subplot"], - ) - - -class OptimizationVariableContainerSerializable: - node_index: int - shape: tuple[int, int] - - def __init__(self, node_index: int, shape: tuple[int, int], len: int): - self.node_index = node_index - self.shape = shape - self._len = len - - def __len__(self): - return self._len - - @classmethod - def from_container(cls, ovc): - from ..optimization.optimization_variable import OptimizationVariableContainer - - ovc: OptimizationVariableContainer = ovc - - return cls( - node_index=ovc.node_index, - shape=ovc.shape, - len=len(ovc), - ) - - def serialize(self): - return { - "node_index": self.node_index, - "shape": self.shape, - "len": self._len, - } - - @classmethod - def deserialize(cls, data): - return cls( - node_index=data["node_index"], - shape=data["shape"], - len=data["len"], - ) - - -class OdeSolverSerializable: - polynomial_degree: int - type: OdeSolver - - def __init__(self, polynomial_degree: int, type: OdeSolver): - self.polynomial_degree = polynomial_degree - self.type = type - - @classmethod - def from_ode_solver(cls, ode_solver): - from ..dynamics.ode_solver import OdeSolver - - ode_solver: OdeSolver = ode_solver - - return cls( - polynomial_degree=5, - type="ode", - ) - - def serialize(self): - return { - "polynomial_degree": self.polynomial_degree, - "type": self.type, - } - - @classmethod - def deserialize(cls, data): - return cls( - polynomial_degree=data["polynomial_degree"], - type=data["type"], - ) - - -class NlpSerializable: - ns: int - phase_idx: int - - n_states_nodes: int - states: OptimizationVariableContainerSerializable - states_dot: OptimizationVariableContainerSerializable - controls: OptimizationVariableContainerSerializable - algebraic_states: OptimizationVariableContainerSerializable - parameters: OptimizationVariableContainerSerializable - numerical_timeseries: OptimizationVariableContainerSerializable - - ode_solver: OdeSolverSerializable - plot: dict[str, CustomPlotSerializable] - - def __init__( - self, - ns: int, - phase_idx: int, - n_states_nodes: int, - states: OptimizationVariableContainerSerializable, - states_dot: OptimizationVariableContainerSerializable, - controls: OptimizationVariableContainerSerializable, - algebraic_states: OptimizationVariableContainerSerializable, - parameters: OptimizationVariableContainerSerializable, - numerical_timeseries: OptimizationVariableContainerSerializable, - ode_solver: OdeSolverSerializable, - plot: dict[str, CustomPlotSerializable], - ): - self.ns = ns - self.phase_idx = phase_idx - self.n_states_nodes = n_states_nodes - self.states = states - self.states_dot = states_dot - self.controls = controls - self.algebraic_states = algebraic_states - self.parameters = parameters - self.numerical_timeseries = numerical_timeseries - self.ode_solver = ode_solver - self.plot = plot - - @classmethod - def from_nlp(cls, nlp): - from ..optimization.non_linear_program import NonLinearProgram - - nlp: NonLinearProgram = nlp - - return cls( - ns=nlp.ns, - phase_idx=nlp.phase_idx, - n_states_nodes=nlp.n_states_nodes, - states=OptimizationVariableContainerSerializable.from_container(nlp.states), - states_dot=OptimizationVariableContainerSerializable.from_container(nlp.states_dot), - controls=OptimizationVariableContainerSerializable.from_container(nlp.controls), - algebraic_states=OptimizationVariableContainerSerializable.from_container(nlp.algebraic_states), - parameters=OptimizationVariableContainerSerializable.from_container(nlp.parameters), - numerical_timeseries=OptimizationVariableContainerSerializable.from_container(nlp.numerical_timeseries), - ode_solver=OdeSolverSerializable.from_ode_solver(nlp.ode_solver), - plot={key: CustomPlotSerializable.from_custom_plot(nlp.plot[key]) for key in nlp.plot}, - ) - - def serialize(self): - return { - "ns": self.ns, - "phase_idx": self.phase_idx, - "n_states_nodes": self.n_states_nodes, - "states": self.states.serialize(), - "states_dot": self.states_dot.serialize(), - "controls": self.controls.serialize(), - "algebraic_states": self.algebraic_states.serialize(), - "parameters": self.parameters.serialize(), - "numerical_timeseries": self.numerical_timeseries.serialize(), - "ode_solver": self.ode_solver.serialize(), - "plot": {key: plot.serialize() for key, plot in self.plot.items()}, - } - - @classmethod - def deserialize(cls, data): - return cls( - ns=data["ns"], - phase_idx=data["phase_idx"], - n_states_nodes=data["n_states_nodes"], - states=OptimizationVariableContainerSerializable.deserialize(data["states"]), - states_dot=OptimizationVariableContainerSerializable.deserialize(data["states_dot"]), - controls=OptimizationVariableContainerSerializable.deserialize(data["controls"]), - algebraic_states=OptimizationVariableContainerSerializable.deserialize(data["algebraic_states"]), - parameters=OptimizationVariableContainerSerializable.deserialize(data["parameters"]), - numerical_timeseries=OptimizationVariableContainerSerializable.deserialize(data["numerical_timeseries"]), - ode_solver=OdeSolverSerializable.deserialize(data["ode_solver"]), - plot={key: CustomPlotSerializable.deserialize(plot) for key, plot in data["plot"].items()}, - ) - - -class SaveIterationsInfoSerializable: - path_to_results: str - result_file_name: str | list[str] - nb_iter_save: int - current_iter: int - f_list: list[int] - - def __init__( - self, path_to_results: str, result_file_name: str, nb_iter_save: int, current_iter: int, f_list: list[int] - ): - self.path_to_results = path_to_results - self.result_file_name = result_file_name - self.nb_iter_save = nb_iter_save - self.current_iter = current_iter - self.f_list = f_list - - @classmethod - def from_save_iterations_info(cls, save_iterations_info): - from .ipopt_output_plot import SaveIterationsInfo - - save_iterations_info: SaveIterationsInfo = save_iterations_info - - if save_iterations_info is None: - return None - - return cls( - path_to_results=save_iterations_info.path_to_results, - result_file_name=save_iterations_info.result_file_name, - nb_iter_save=save_iterations_info.nb_iter_save, - current_iter=save_iterations_info.current_iter, - f_list=save_iterations_info.f_list, - ) - - def serialize(self): - return { - "path_to_results": self.path_to_results, - "result_file_name": self.result_file_name, - "nb_iter_save": self.nb_iter_save, - "current_iter": self.current_iter, - "f_list": self.f_list, - } - - @classmethod - def deserialize(cls, data): - return cls( - path_to_results=data["path_to_results"], - result_file_name=data["result_file_name"], - nb_iter_save=data["nb_iter_save"], - current_iter=data["current_iter"], - f_list=data["f_list"], - ) - - -class OcpSerializable: - n_phases: int - nlp: list[NlpSerializable] - - time_phase_mapping: BiMappingSerializable - - plot_ipopt_outputs: bool - plot_check_conditioning: bool - save_ipopt_iterations_info: SaveIterationsInfoSerializable - - def __init__( - self, - n_phases: int, - nlp: list[NlpSerializable], - time_phase_mapping: BiMappingSerializable, - plot_ipopt_outputs: bool, - plot_check_conditioning: bool, - save_ipopt_iterations_info: SaveIterationsInfoSerializable, - ): - self.n_phases = n_phases - self.nlp = nlp - - self.time_phase_mapping = time_phase_mapping - - self.plot_ipopt_outputs = plot_ipopt_outputs - self.plot_check_conditioning = plot_check_conditioning - self.save_ipopt_iterations_info = save_ipopt_iterations_info - - @classmethod - def from_ocp(cls, ocp): - from ..optimization.optimal_control_program import OptimalControlProgram - - ocp: OptimalControlProgram = ocp - - return cls( - n_phases=ocp.n_phases, - nlp=[NlpSerializable.from_nlp(nlp) for nlp in ocp.nlp], - time_phase_mapping=BiMappingSerializable.from_bimapping(ocp.time_phase_mapping), - plot_ipopt_outputs=ocp.plot_ipopt_outputs, - plot_check_conditioning=ocp.plot_check_conditioning, - save_ipopt_iterations_info=SaveIterationsInfoSerializable.from_save_iterations_info( - ocp.save_ipopt_iterations_info - ), - ) - - def serialize(self): - return { - "n_phases": self.n_phases, - "nlp": [nlp.serialize() for nlp in self.nlp], - "time_phase_mapping": self.time_phase_mapping.serialize(), - "plot_ipopt_outputs": self.plot_ipopt_outputs, - "plot_check_conditioning": self.plot_check_conditioning, - "save_ipopt_iterations_info": ( - None if self.save_ipopt_iterations_info is None else self.save_ipopt_iterations_info.serialize() - ), - } - - @classmethod - def deserialize(cls, data): - return cls( - n_phases=data["n_phases"], - nlp=[NlpSerializable.deserialize(nlp) for nlp in data["nlp"]], - time_phase_mapping=BiMappingSerializable.deserialize(data["time_phase_mapping"]), - plot_ipopt_outputs=data["plot_ipopt_outputs"], - plot_check_conditioning=data["plot_check_conditioning"], - save_ipopt_iterations_info=( - None - if data["save_ipopt_iterations_info"] is None - else SaveIterationsInfoSerializable.deserialize(data["save_ipopt_iterations_info"]) - ), - ) - - class PlotOcp: """ Attributes @@ -688,8 +194,6 @@ class PlotOcp: Update ydata from the variable a solution structure __update_xdata(self) Update of the time axes in plots - _append_to_ydata(self, data: list) - Parse the data list to create a single list of all ydata that will fit the plots vector __update_axes(self) Update the plotted data from ydata __compute_ylim(min_val: np.ndarray | DM, max_val: np.ndarray | DM, factor: float) -> tuple: @@ -706,6 +210,7 @@ def __init__( shooting_type: Shooting = Shooting.MULTIPLE, integrator: SolutionIntegrator = SolutionIntegrator.OCP, dummy_phase_times: list[list[float]] = None, + only_initialize_variables: bool = False, ): """ Prepares the figures during the simulation @@ -724,6 +229,9 @@ def __init__( Use the ode defined by OCP or use a separate integrator provided by scipy dummy_phase_times: list[list[float]] The time of each phase + only_initialize_variables: bool + If the plots should be initialized but not shown (this is useful for the online plot which must be declared + on the server side and on the client side) """ self.ocp = ocp self.plot_options = { @@ -736,7 +244,6 @@ def __init__( "vertical_lines": {"color": "k", "linestyle": "--", "linewidth": 1.2}, } - self.ydata = [] self.n_nodes = 0 self.t = [] @@ -758,31 +265,33 @@ def __init__( self.top_margin: int | None = None self.height_step: int | None = None self.width_step: int | None = None - self._organize_windows(len(self.ocp.nlp[0].states) + len(self.ocp.nlp[0].controls)) + if not only_initialize_variables: + self._organize_windows(len(self.ocp.nlp[0].states) + len(self.ocp.nlp[0].controls)) self.custom_plots = {} self.variable_sizes = [] self.show_bounds = show_bounds - self.__create_plots() + self._create_plots(only_initialize_variables) self.shooting_type = shooting_type - horz = 0 - vert = 1 if len(self.all_figures) < self.n_vertical_windows * self.n_horizontal_windows else 0 - for i, fig in enumerate(self.all_figures): - if self.automatically_organize: - try: - fig.canvas.manager.window.move( - int(vert * self.width_step), int(self.top_margin + horz * self.height_step) - ) - vert += 1 - if vert >= self.n_vertical_windows: - horz += 1 - vert = 0 - except AttributeError: - pass - fig.canvas.draw() - if self.plot_options["general_options"]["use_tight_layout"]: - fig.tight_layout() + if not only_initialize_variables: + horz = 0 + vert = 1 if len(self.all_figures) < self.n_vertical_windows * self.n_horizontal_windows else 0 + for i, fig in enumerate(self.all_figures): + if self.automatically_organize: + try: + fig.canvas.manager.window.move( + int(vert * self.width_step), int(self.top_margin + horz * self.height_step) + ) + vert += 1 + if vert >= self.n_vertical_windows: + horz += 1 + vert = 0 + except AttributeError: + pass + fig.canvas.draw() + if self.plot_options["general_options"]["use_tight_layout"]: + fig.tight_layout() if self.ocp.plot_ipopt_outputs: from ..gui.ipopt_output_plot import create_ipopt_output_plot @@ -809,9 +318,15 @@ def _update_time_vector(self, phase_times): self.t_integrated.append(time) self.t.append(np.linspace(float(time[0][0]), float(time[-1][-1]), nlp.n_states_nodes)) - def __create_plots(self): + def _create_plots(self, only_initialize_variables: bool): """ Setup the plots + + Parameters + ---------- + only_initialize_variables: bool + If the plots should be initialized but not shown (this is useful for the online plot which must be declared + on the server side and on the client side) """ def legend_without_duplicate_labels(ax): @@ -905,43 +420,45 @@ def legend_without_duplicate_labels(ax): for var_idx, variable in enumerate(self.variable_sizes[i]): y_range_var_idx = all_keys_across_phases.index(variable) - if nlp.plot[variable].combine_to: - self.axes[variable] = self.axes[nlp.plot[variable].combine_to] - axes = self.axes[variable][1] - elif i > 0 and variable in self.axes: - axes = self.axes[variable][1] - else: - nb_subplots = max( - [ - ( - max( - len(nlp.plot[variable].phase_mappings.to_first.map_idx), - max(nlp.plot[variable].phase_mappings.to_first.map_idx) + 1, - ) - if variable in nlp.plot - else 0 - ) - for nlp in self.ocp.nlp - ] - ) - # TODO: get rid of all_variables_in_one_subplot by fixing the mapping appropriately - if not nlp.plot[variable].all_variables_in_one_subplot: - n_cols, n_rows = PlotOcp._generate_windows_size(nb_subplots) + if not only_initialize_variables: + if nlp.plot[variable].combine_to: + self.axes[variable] = self.axes[nlp.plot[variable].combine_to] + axes = self.axes[variable][1] + elif i > 0 and variable in self.axes: + axes = self.axes[variable][1] else: - n_cols = 1 - n_rows = 1 - axes = self.__add_new_axis(variable, nb_subplots, n_rows, n_cols) - self.axes[variable] = [nlp.plot[variable], axes] + nb_subplots = max( + [ + ( + max( + len(nlp.plot[variable].phase_mappings.to_first.map_idx), + max(nlp.plot[variable].phase_mappings.to_first.map_idx) + 1, + ) + if variable in nlp.plot + else 0 + ) + for nlp in self.ocp.nlp + ] + ) + + # TODO: get rid of all_variables_in_one_subplot by fixing the mapping appropriately + if not nlp.plot[variable].all_variables_in_one_subplot: + n_cols, n_rows = PlotOcp._generate_windows_size(nb_subplots) + else: + n_cols = 1 + n_rows = 1 + axes = self._add_new_axis(variable, nb_subplots, n_rows, n_cols) + self.axes[variable] = [nlp.plot[variable], axes] - if not y_min_all[y_range_var_idx]: - y_min_all[y_range_var_idx] = [np.inf] * nb_subplots - y_max_all[y_range_var_idx] = [-np.inf] * nb_subplots + if not y_min_all[y_range_var_idx]: + y_min_all[y_range_var_idx] = [np.inf] * nb_subplots + y_max_all[y_range_var_idx] = [-np.inf] * nb_subplots if variable not in self.custom_plots: self.custom_plots[variable] = [ nlp_tp.plot[variable] if variable in nlp_tp.plot else None for nlp_tp in self.ocp.nlp ] - if not self.custom_plots[variable][i]: + if not self.custom_plots[variable][i] or only_initialize_variables: continue mapping_to_first_index = nlp.plot[variable].phase_mappings.to_first.map_idx @@ -1113,7 +630,7 @@ def legend_without_duplicate_labels(ax): [ax.step(self.t[i], bounds_max, where="post", **self.plot_options["bounds"]), i] ) - def __add_new_axis(self, variable: str, nb: int, n_rows: int, n_cols: int): + def _add_new_axis(self, variable: str, nb: int, n_rows: int, n_cols: int): """ Add a new axis to the axes pool @@ -1186,22 +703,24 @@ def show(): plt.show() - def update_data( - self, - args: dict, - ): + def parse_data(self, **args) -> tuple[list, list]: """ - Update ydata from the variable a solution structure + Parse the data to be plotted, the return of this method can be passed to update_data to update the plots Parameters ---------- - v: np.ndarray - The data to parse + ocp: OptimalControlProgram + A reference to the full ocp + variable_sizes: list[int] + The size of all variables. This is the reference to the PlotOcp.variable_sizes (which can't be accessed + from this static method) + custom_plots: dict + The dictionary of all the CustomPlot. This is the reference to the PlotOcp.custom_plots (which can't be + accessed from this static method) """ - from ..interfaces.interface_utils import get_numerical_timeseries - self.ydata = [] + ydata = [] sol = Solution.from_vector(self.ocp, args["x"]) data_states_decision = sol.decision_states(scaled=True, to_merge=SolutionMerge.KEYS) @@ -1222,7 +741,7 @@ def update_data( if self.ocp.n_phases == 1: time_stepwise = [time_stepwise] phases_dt = sol.phases_dt - self._update_xdata(time_stepwise) + xdata = time_stepwise for nlp in self.ocp.nlp: @@ -1257,9 +776,33 @@ def update_data( mapped_y_data = [] for i in nlp.plot[key].phase_mappings.to_first.map_idx: mapped_y_data.append(y_data[i]) - self._append_to_ydata(mapped_y_data) + for y in mapped_y_data: + ydata.append(y) - self.__update_axes() + return xdata, ydata + + def update_data( + self, + xdata: dict, + ydata: list, + **args: dict, + ): + """ + Update ydata from the variable a solution structure + + Parameters + ---------- + xdata: dict + The time vector + ydata: list + The actual current data to be plotted + args: dict + The same args as the parse_data method (that is so ipopt outputs can be plotted, this should be done properly + in the future, when ready, remove this parameter) + """ + + self._update_xdata(xdata) + self._update_ydata(ydata) if self.ocp.plot_ipopt_outputs: from ..gui.ipopt_output_plot import update_ipopt_output_plot @@ -1419,27 +962,14 @@ def _update_xdata(self, phase_times): for i, time in enumerate(intersections_time): self.plots_vertical_lines[p * n + i].set_xdata([float(time), float(time)]) - def _append_to_ydata(self, data: list | np.ndarray): - """ - Parse the data list to create a single list of all ydata that will fit the plots vector - - Parameters - ---------- - data: list - The data list to copy - """ - - for y in data: - self.ydata.append(y) - - def __update_axes(self): + def _update_ydata(self, ydata): """ Update the plotted data from ydata """ - assert len(self.plots) == len(self.ydata) + assert len(self.plots) == len(ydata) for i, plot in enumerate(self.plots): - y = self.ydata[i] + y = ydata[i] if y is None: # Jump the plots which are empty y = (np.nan,) * len(plot[2]) diff --git a/bioptim/gui/serializable_class.py b/bioptim/gui/serializable_class.py new file mode 100644 index 000000000..202b37b54 --- /dev/null +++ b/bioptim/gui/serializable_class.py @@ -0,0 +1,623 @@ +from typing import Any, Callable + +from casadi import DM, Function +import numpy as np + +from ..dynamics.ode_solver import OdeSolver +from ..limits.penalty_option import PenaltyOption +from ..misc.mapping import BiMapping +from ..misc.enums import PlotType, QuadratureRule + + +class CasadiFunctionSerializable: + _size_in: dict[str, int] + + def __init__(self, size_in: dict[str, int]): + self._size_in = size_in + + @classmethod + def from_casadi_function(cls, casadi_function): + casadi_function: Function = casadi_function + + return cls( + size_in={ + "x": casadi_function.size_in("x"), + "u": casadi_function.size_in("u"), + "p": casadi_function.size_in("p"), + "a": casadi_function.size_in("a"), + "d": casadi_function.size_in("d"), + } + ) + + def serialize(self): + return { + "size_in": self._size_in, + } + + @classmethod + def deserialize(cls, data): + return cls( + size_in=data["size_in"], + ) + + def size_in(self, key: str) -> int: + return self._size_in[key] + + +class PenaltySerializable: + function: list[CasadiFunctionSerializable | None] + + def __init__(self, function: list[CasadiFunctionSerializable | None]): + self.function = function + + @classmethod + def from_penalty(cls, penalty): + penalty: PenaltyOption = penalty + + function = [] + for f in penalty.function: + function.append(None if f is None else CasadiFunctionSerializable.from_casadi_function(f)) + return cls( + function=function, + ) + + def serialize(self): + return { + "function": [None if f is None else f.serialize() for f in self.function], + } + + @classmethod + def deserialize(cls, data): + return cls( + function=[None if f is None else CasadiFunctionSerializable.deserialize(f) for f in data["function"]], + ) + + +class MappingSerializable: + map_idx: list[int] + oppose: list[int] + + def __init__(self, map_idx: list, oppose: list): + self.map_idx = map_idx + self.oppose = oppose + + def map(self, obj): + from ..misc.mapping import Mapping + + return Mapping.map(self, obj) + + @classmethod + def from_mapping(cls, mapping): + return cls( + map_idx=mapping.map_idx, + oppose=mapping.oppose, + ) + + def serialize(self): + return { + "map_idx": self.map_idx, + "oppose": self.oppose, + } + + @classmethod + def deserialize(cls, data): + return cls( + map_idx=data["map_idx"], + oppose=data["oppose"], + ) + + +class BiMappingSerializable: + to_first: MappingSerializable + to_second: MappingSerializable + + def __init__(self, to_first: MappingSerializable, to_second: MappingSerializable): + self.to_first = to_first + self.to_second = to_second + + @classmethod + def from_bimapping(cls, bimapping): + return cls( + to_first=MappingSerializable.from_mapping(bimapping.to_first), + to_second=MappingSerializable.from_mapping(bimapping.to_second), + ) + + def serialize(self): + return { + "to_first": self.to_first.serialize(), + "to_second": self.to_second.serialize(), + } + + @classmethod + def deserialize(cls, data): + return cls( + to_first=MappingSerializable.deserialize(data["to_first"]), + to_second=MappingSerializable.deserialize(data["to_second"]), + ) + + +class BoundsSerializable: + min: np.ndarray | DM + max: np.ndarray | DM + + def __init__(self, min: np.ndarray | DM, max: np.ndarray | DM): + self.min = min + self.max = max + + @classmethod + def from_bounds(cls, bounds): + return cls( + min=np.array(bounds.min), + max=np.array(bounds.max), + ) + + def serialize(self): + return { + "min": self.min.tolist(), + "max": self.max.tolist(), + } + + @classmethod + def deserialize(cls, data): + return cls( + min=np.array(data["min"]), + max=np.array(data["max"]), + ) + + +class CustomPlotSerializable: + _function: Callable + type: PlotType + phase_mappings: BiMappingSerializable + legend: tuple | list + combine_to: str + color: str + linestyle: str + ylim: tuple | list + bounds: BoundsSerializable + node_idx: list | slice | range + label: list + compute_derivative: bool + integration_rule: QuadratureRule + parameters: dict[str, Any] + all_variables_in_one_subplot: bool + + def __init__( + self, + function: Callable, + plot_type: PlotType, + phase_mappings: BiMapping, + legend: tuple | list, + combine_to: str, + color: str, + linestyle: str, + ylim: tuple | list, + bounds: BoundsSerializable, + node_idx: list | slice | range, + label: list, + compute_derivative: bool, + integration_rule: QuadratureRule, + parameters: dict[str, Any], + all_variables_in_one_subplot: bool, + ): + self._function = function + self.type = plot_type + self.phase_mappings = phase_mappings + self.legend = legend + self.combine_to = combine_to + self.color = color + self.linestyle = linestyle + self.ylim = ylim + self.bounds = bounds + self.node_idx = node_idx + self.label = label + self.compute_derivative = compute_derivative + self.integration_rule = integration_rule + self.parameters = parameters + self.all_variables_in_one_subplot = all_variables_in_one_subplot + + @classmethod + def from_custom_plot(cls, custom_plot): + from .plot import CustomPlot + + custom_plot: CustomPlot = custom_plot + + _function = None + parameters = {} + for key in custom_plot.parameters.keys(): + if key == "penalty": + # This is a hack to emulate what PlotOcp._create_plots needs while not being able to actually serialize + # the function + parameters[key] = PenaltySerializable.from_penalty(custom_plot.parameters[key]) + + penalty = custom_plot.parameters[key] + + casadi_function = penalty.function[0] if penalty.function[0] is not None else penalty.function[-1] + size_x = casadi_function.size_in("x")[0] + size_dt = casadi_function.size_in("dt")[0] + size_u = casadi_function.size_in("u")[0] + size_p = casadi_function.size_in("p")[0] + size_a = casadi_function.size_in("a")[0] + size_d = casadi_function.size_in("d")[0] + _function = custom_plot.function( + 0, # t0 + np.zeros(size_dt), # phases_dt + custom_plot.node_idx[0], # node_idx + np.zeros((size_x, 1)), # states + np.zeros((size_u, 1)), # controls + np.zeros((size_p, 1)), # parameters + np.zeros((size_a, 1)), # algebraic_states + np.zeros((size_d, 1)), # numerical_timeseries + **custom_plot.parameters, # parameters + ) + + else: + raise NotImplementedError(f"Parameter {key} is not implemented in the serialization") + + return cls( + function=_function, + plot_type=custom_plot.type, + phase_mappings=( + None + if custom_plot.phase_mappings is None + else BiMappingSerializable.from_bimapping(custom_plot.phase_mappings) + ), + legend=custom_plot.legend, + combine_to=custom_plot.combine_to, + color=custom_plot.color, + linestyle=custom_plot.linestyle, + ylim=custom_plot.ylim, + bounds=None if custom_plot.bounds is None else BoundsSerializable.from_bounds(custom_plot.bounds), + node_idx=custom_plot.node_idx, + label=custom_plot.label, + compute_derivative=custom_plot.compute_derivative, + integration_rule=custom_plot.integration_rule, + parameters=parameters, + all_variables_in_one_subplot=custom_plot.all_variables_in_one_subplot, + ) + + def serialize(self): + return { + "function": None if self._function is None else np.array(self._function)[:, 0].tolist(), + "type": self.type.value, + "phase_mappings": None if self.phase_mappings is None else self.phase_mappings.serialize(), + "legend": self.legend, + "combine_to": self.combine_to, + "color": self.color, + "linestyle": self.linestyle, + "ylim": self.ylim, + "bounds": None if self.bounds is None else self.bounds.serialize(), + "node_idx": self.node_idx, + "label": self.label, + "compute_derivative": self.compute_derivative, + "integration_rule": self.integration_rule.value, + "parameters": {key: param.serialize() for key, param in self.parameters.items()}, + "all_variables_in_one_subplot": self.all_variables_in_one_subplot, + } + + @classmethod + def deserialize(cls, data): + + parameters = {} + for key in data["parameters"].keys(): + if key == "penalty": + parameters[key] = PenaltySerializable.deserialize(data["parameters"][key]) + else: + raise NotImplementedError(f"Parameter {key} is not implemented in the serialization") + + return cls( + function=None if data["function"] is None else DM(data["function"]), + plot_type=PlotType(data["type"]), + phase_mappings=( + None if data["phase_mappings"] is None else BiMappingSerializable.deserialize(data["phase_mappings"]) + ), + legend=data["legend"], + combine_to=data["combine_to"], + color=data["color"], + linestyle=data["linestyle"], + ylim=data["ylim"], + bounds=None if data["bounds"] is None else BoundsSerializable.deserialize(data["bounds"]), + node_idx=data["node_idx"], + label=data["label"], + compute_derivative=data["compute_derivative"], + integration_rule=QuadratureRule(data["integration_rule"]), + parameters=parameters, + all_variables_in_one_subplot=data["all_variables_in_one_subplot"], + ) + + def function(self, *args, **kwargs): + # This should not be called to get actual values, as it is evaluated at 0. This is solely to get the size of + # the function + return self._function + + +class OptimizationVariableContainerSerializable: + node_index: int + shape: tuple[int, int] + + def __init__(self, node_index: int, shape: tuple[int, int], len: int): + self.node_index = node_index + self.shape = shape + self._len = len + + def __len__(self): + return self._len + + @classmethod + def from_container(cls, ovc): + from ..optimization.optimization_variable import OptimizationVariableContainer + + ovc: OptimizationVariableContainer = ovc + + return cls( + node_index=ovc.node_index, + shape=ovc.shape, + len=len(ovc), + ) + + def serialize(self): + return { + "node_index": self.node_index, + "shape": self.shape, + "len": self._len, + } + + @classmethod + def deserialize(cls, data): + return cls( + node_index=data["node_index"], + shape=data["shape"], + len=data["len"], + ) + + +class OdeSolverSerializable: + polynomial_degree: int + type: OdeSolver + + def __init__(self, polynomial_degree: int, type: OdeSolver): + self.polynomial_degree = polynomial_degree + self.type = type + + @classmethod + def from_ode_solver(cls, ode_solver): + from ..dynamics.ode_solver import OdeSolver + + ode_solver: OdeSolver = ode_solver + + return cls( + polynomial_degree=5, + type="ode", + ) + + def serialize(self): + return { + "polynomial_degree": self.polynomial_degree, + "type": self.type, + } + + @classmethod + def deserialize(cls, data): + return cls( + polynomial_degree=data["polynomial_degree"], + type=data["type"], + ) + + +class NlpSerializable: + ns: int + phase_idx: int + + n_states_nodes: int + states: OptimizationVariableContainerSerializable + states_dot: OptimizationVariableContainerSerializable + controls: OptimizationVariableContainerSerializable + algebraic_states: OptimizationVariableContainerSerializable + parameters: OptimizationVariableContainerSerializable + numerical_timeseries: OptimizationVariableContainerSerializable + + ode_solver: OdeSolverSerializable + plot: dict[str, CustomPlotSerializable] + + def __init__( + self, + ns: int, + phase_idx: int, + n_states_nodes: int, + states: OptimizationVariableContainerSerializable, + states_dot: OptimizationVariableContainerSerializable, + controls: OptimizationVariableContainerSerializable, + algebraic_states: OptimizationVariableContainerSerializable, + parameters: OptimizationVariableContainerSerializable, + numerical_timeseries: OptimizationVariableContainerSerializable, + ode_solver: OdeSolverSerializable, + plot: dict[str, CustomPlotSerializable], + ): + self.ns = ns + self.phase_idx = phase_idx + self.n_states_nodes = n_states_nodes + self.states = states + self.states_dot = states_dot + self.controls = controls + self.algebraic_states = algebraic_states + self.parameters = parameters + self.numerical_timeseries = numerical_timeseries + self.ode_solver = ode_solver + self.plot = plot + + @classmethod + def from_nlp(cls, nlp): + from ..optimization.non_linear_program import NonLinearProgram + + nlp: NonLinearProgram = nlp + + return cls( + ns=nlp.ns, + phase_idx=nlp.phase_idx, + n_states_nodes=nlp.n_states_nodes, + states=OptimizationVariableContainerSerializable.from_container(nlp.states), + states_dot=OptimizationVariableContainerSerializable.from_container(nlp.states_dot), + controls=OptimizationVariableContainerSerializable.from_container(nlp.controls), + algebraic_states=OptimizationVariableContainerSerializable.from_container(nlp.algebraic_states), + parameters=OptimizationVariableContainerSerializable.from_container(nlp.parameters), + numerical_timeseries=OptimizationVariableContainerSerializable.from_container(nlp.numerical_timeseries), + ode_solver=OdeSolverSerializable.from_ode_solver(nlp.ode_solver), + plot={key: CustomPlotSerializable.from_custom_plot(nlp.plot[key]) for key in nlp.plot}, + ) + + def serialize(self): + return { + "ns": self.ns, + "phase_idx": self.phase_idx, + "n_states_nodes": self.n_states_nodes, + "states": self.states.serialize(), + "states_dot": self.states_dot.serialize(), + "controls": self.controls.serialize(), + "algebraic_states": self.algebraic_states.serialize(), + "parameters": self.parameters.serialize(), + "numerical_timeseries": self.numerical_timeseries.serialize(), + "ode_solver": self.ode_solver.serialize(), + "plot": {key: plot.serialize() for key, plot in self.plot.items()}, + } + + @classmethod + def deserialize(cls, data): + return cls( + ns=data["ns"], + phase_idx=data["phase_idx"], + n_states_nodes=data["n_states_nodes"], + states=OptimizationVariableContainerSerializable.deserialize(data["states"]), + states_dot=OptimizationVariableContainerSerializable.deserialize(data["states_dot"]), + controls=OptimizationVariableContainerSerializable.deserialize(data["controls"]), + algebraic_states=OptimizationVariableContainerSerializable.deserialize(data["algebraic_states"]), + parameters=OptimizationVariableContainerSerializable.deserialize(data["parameters"]), + numerical_timeseries=OptimizationVariableContainerSerializable.deserialize(data["numerical_timeseries"]), + ode_solver=OdeSolverSerializable.deserialize(data["ode_solver"]), + plot={key: CustomPlotSerializable.deserialize(plot) for key, plot in data["plot"].items()}, + ) + + +class SaveIterationsInfoSerializable: + path_to_results: str + result_file_name: str | list[str] + nb_iter_save: int + current_iter: int + f_list: list[int] + + def __init__( + self, path_to_results: str, result_file_name: str, nb_iter_save: int, current_iter: int, f_list: list[int] + ): + self.path_to_results = path_to_results + self.result_file_name = result_file_name + self.nb_iter_save = nb_iter_save + self.current_iter = current_iter + self.f_list = f_list + + @classmethod + def from_save_iterations_info(cls, save_iterations_info): + from .ipopt_output_plot import SaveIterationsInfo + + save_iterations_info: SaveIterationsInfo = save_iterations_info + + if save_iterations_info is None: + return None + + return cls( + path_to_results=save_iterations_info.path_to_results, + result_file_name=save_iterations_info.result_file_name, + nb_iter_save=save_iterations_info.nb_iter_save, + current_iter=save_iterations_info.current_iter, + f_list=save_iterations_info.f_list, + ) + + def serialize(self): + return { + "path_to_results": self.path_to_results, + "result_file_name": self.result_file_name, + "nb_iter_save": self.nb_iter_save, + "current_iter": self.current_iter, + "f_list": self.f_list, + } + + @classmethod + def deserialize(cls, data): + return cls( + path_to_results=data["path_to_results"], + result_file_name=data["result_file_name"], + nb_iter_save=data["nb_iter_save"], + current_iter=data["current_iter"], + f_list=data["f_list"], + ) + + +class OcpSerializable: + n_phases: int + nlp: list[NlpSerializable] + + time_phase_mapping: BiMappingSerializable + + plot_ipopt_outputs: bool + plot_check_conditioning: bool + save_ipopt_iterations_info: SaveIterationsInfoSerializable + + def __init__( + self, + n_phases: int, + nlp: list[NlpSerializable], + time_phase_mapping: BiMappingSerializable, + plot_ipopt_outputs: bool, + plot_check_conditioning: bool, + save_ipopt_iterations_info: SaveIterationsInfoSerializable, + ): + self.n_phases = n_phases + self.nlp = nlp + + self.time_phase_mapping = time_phase_mapping + + self.plot_ipopt_outputs = plot_ipopt_outputs + self.plot_check_conditioning = plot_check_conditioning + self.save_ipopt_iterations_info = save_ipopt_iterations_info + + @classmethod + def from_ocp(cls, ocp): + from ..optimization.optimal_control_program import OptimalControlProgram + + ocp: OptimalControlProgram = ocp + + return cls( + n_phases=ocp.n_phases, + nlp=[NlpSerializable.from_nlp(nlp) for nlp in ocp.nlp], + time_phase_mapping=BiMappingSerializable.from_bimapping(ocp.time_phase_mapping), + plot_ipopt_outputs=ocp.plot_ipopt_outputs, + plot_check_conditioning=ocp.plot_check_conditioning, + save_ipopt_iterations_info=SaveIterationsInfoSerializable.from_save_iterations_info( + ocp.save_ipopt_iterations_info + ), + ) + + def serialize(self): + return { + "n_phases": self.n_phases, + "nlp": [nlp.serialize() for nlp in self.nlp], + "time_phase_mapping": self.time_phase_mapping.serialize(), + "plot_ipopt_outputs": self.plot_ipopt_outputs, + "plot_check_conditioning": self.plot_check_conditioning, + "save_ipopt_iterations_info": ( + None if self.save_ipopt_iterations_info is None else self.save_ipopt_iterations_info.serialize() + ), + } + + @classmethod + def deserialize(cls, data): + return cls( + n_phases=data["n_phases"], + nlp=[NlpSerializable.deserialize(nlp) for nlp in data["nlp"]], + time_phase_mapping=BiMappingSerializable.deserialize(data["time_phase_mapping"]), + plot_ipopt_outputs=data["plot_ipopt_outputs"], + plot_check_conditioning=data["plot_check_conditioning"], + save_ipopt_iterations_info=( + None + if data["save_ipopt_iterations_info"] is None + else SaveIterationsInfoSerializable.deserialize(data["save_ipopt_iterations_info"]) + ), + ) diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index 46bdd33c0..f8ff19844 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -126,6 +126,17 @@ def generic_solve(interface, expand_during_shake_tree=False) -> dict: interface.out["sol"]["status"] = int(not interface.ocp_solver.stats()["success"]) interface.out["sol"]["solver"] = interface.solver_name + # Make sure the graphs are showing the last iteration + if interface.opts.show_online_optim: + to_eval = [ + interface.out["sol"]["x"], + interface.out["sol"]["f"], + interface.out["sol"]["g"], + interface.out["sol"]["lam_x"], + interface.out["sol"]["lam_g"], + interface.out["sol"]["lam_p"], + ] + interface.options_common["iteration_callback"].eval(to_eval, force=True) return interface.out diff --git a/bioptim/optimization/optimal_control_program.py b/bioptim/optimization/optimal_control_program.py index 19a44477a..cc5468e2d 100644 --- a/bioptim/optimization/optimal_control_program.py +++ b/bioptim/optimization/optimal_control_program.py @@ -1363,7 +1363,7 @@ def prepare_plots( show_bounds=show_bounds, shooting_type=shooting_type, integrator=integrator, - dummy_phase_times=OptimizationVectorHelper.extract_step_times(self.ocp, casadi.DM(np.ones(self.ocp.n_phases))), + dummy_phase_times=OptimizationVectorHelper.extract_step_times(self, casadi.DM(np.ones(self.n_phases))), ) def check_conditioning(self): diff --git a/bioptim/optimization/solution/solution.py b/bioptim/optimization/solution/solution.py index 9a65dbdd6..45485efdd 100644 --- a/bioptim/optimization/solution/solution.py +++ b/bioptim/optimization/solution/solution.py @@ -1206,7 +1206,7 @@ def graphs( """ plot_ocp = self.ocp.prepare_plots(automatically_organize, show_bounds, shooting_type, integrator) - plot_ocp.update_data({"x": self.vector}) + plot_ocp.update_data(*plot_ocp.parse_data(**{"x": self.vector})) if save_name: if save_name.endswith(".png"): save_name = save_name[:-4] From 720187f16a8ddc080aa5023ddaec85f47221d579 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Tue, 30 Jul 2024 09:54:11 -0400 Subject: [PATCH 05/17] More reactive online graphs with multiprocess --- bioptim/gui/online_callback.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/bioptim/gui/online_callback.py b/bioptim/gui/online_callback.py index 5b2c077ab..c8a880de0 100644 --- a/bioptim/gui/online_callback.py +++ b/bioptim/gui/online_callback.py @@ -196,6 +196,10 @@ def close(self): self.plot_process.kill() def eval(self, arg: list | tuple, force: bool = False) -> list: + # Dequeuing the data by removing previous not useful data + while not self.queue.empty(): + self.queue.get() + send = self.queue.put args_dict = {} for i, s in enumerate(nlpsol_out()): @@ -230,8 +234,9 @@ def __init__(self, ocp): A reference to the ocp to show """ - self.ocp: OcpSerializable = ocp + self._ocp: OcpSerializable = ocp self._plotter: PlotOcp = None + self._update_time = 0.001 def __call__(self, pipe: mp.Queue, show_options: dict | None): """ @@ -245,13 +250,11 @@ def __call__(self, pipe: mp.Queue, show_options: dict | None): if show_options is None: show_options = {} - self.pipe = pipe + self._pipe = pipe - dummy_phase_times = OptimizationVectorHelper.extract_step_times(self.ocp, DM(np.ones(self.ocp.n_phases))) - self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_phase_times, **show_options) - timer = self._plotter.all_figures[0].canvas.new_timer(interval=10) - timer.add_callback(self.plot_update) - timer.start() + dummy_phase_times = OptimizationVectorHelper.extract_step_times(self._ocp, DM(np.ones(self._ocp.n_phases))) + self._plotter = PlotOcp(self._ocp, dummy_phase_times=dummy_phase_times, **show_options) + threading.Timer(self._update_time, self.plot_update).start() plt.show() def plot_update(self) -> bool: @@ -263,13 +266,20 @@ def plot_update(self) -> bool: True if everything went well """ - while not self.pipe.empty(): - args = self.pipe.get() - data = self._plotter.parse_data(**args) - self._plotter.update_data(**data, **args) + args = {} + while not self._pipe.empty(): + args = self._pipe.get() + + if args: + self._plotter.update_data(*self._plotter.parse_data(**args), **args) - for i, fig in enumerate(self._plotter.all_figures): + # We want to redraw here to actually consume a bit of time, otherwise it goes to fast and pipe remains empty + for fig in self._plotter.all_figures: fig.canvas.draw() + if [plt.fignum_exists(fig.number) for fig in self._plotter.all_figures].count(True) > 0: + # If there are still figures, we keep updating + threading.Timer(self._update_time, self.plot_update).start() + return True From 2b5dae7250ed2e73f3c8bdde92a5cad321408890 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Tue, 30 Jul 2024 10:51:00 -0400 Subject: [PATCH 06/17] Slightly faster online graph using Server --- bioptim/gui/online_callback.py | 52 +++++++++++++++------------------- 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/bioptim/gui/online_callback.py b/bioptim/gui/online_callback.py index c8a880de0..20ffe4b34 100644 --- a/bioptim/gui/online_callback.py +++ b/bioptim/gui/online_callback.py @@ -198,13 +198,12 @@ def close(self): def eval(self, arg: list | tuple, force: bool = False) -> list: # Dequeuing the data by removing previous not useful data while not self.queue.empty(): - self.queue.get() + self.queue.get_nowait() - send = self.queue.put args_dict = {} for i, s in enumerate(nlpsol_out()): args_dict[s] = arg[i] - send(args_dict) + self.queue.put_nowait(args_dict) return [0] class ProcessPlotter(object): @@ -338,7 +337,7 @@ def run(self): finally: self._socket.close() - def _wait_for_data(self, client_socket: socket.socket): + def _wait_for_data(self, client_socket: socket.socket, send_confirmation: bool = True): # Receive the actual data try: self._logger.debug("Waiting for data from client") @@ -355,17 +354,22 @@ def _wait_for_data(self, client_socket: socket.socket): message_type = OnlineCallbackServer._ServerMessages(int(data_as_list[0])) len_all_data = [int(len_data) for len_data in data_as_list[1][1:-1].split(",")] # Sends confirmation and waits for the next message - client_socket.send("OK".encode()) + if send_confirmation: + client_socket.sendall("OK".encode()) self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)") data_out = [] for len_data in len_all_data: data_out.append(client_socket.recv(len_data)) - client_socket.send("OK".encode()) + if len(data_out[-1]) != len_data: + data_out[-1] += client_socket.recv(len_data - len(data_out[-1])) + if send_confirmation: + client_socket.sendall("OK".encode()) except ValueError: self._logger.warning("Unknown message type received") message_type = OnlineCallbackServer._ServerMessages.UNKNOWN # Sends failure - client_socket.send("NOK".encode()) + if send_confirmation: + client_socket.sendall("NOK".encode()) data_out = [] if message_type == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION: @@ -396,7 +400,7 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list): try: self.ocp = OcpSerializable.deserialize(data_json) except: - client_socket.send("FAILED".encode()) + client_socket.sendall("FAILED".encode()) self._logger.warning("Error while deserializing OCP data from client, closing connexion") return @@ -404,7 +408,7 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list): self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) # Send the confirmation to the client - client_socket.send("PLOT_READY".encode()) + client_socket.sendall("PLOT_READY".encode()) # Start the callbacks threading.Timer(self._get_data_interval, self._wait_for_new_data, (client_socket,)).start() @@ -430,10 +434,10 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> bool: True if everything went well """ self._logger.debug(f"Waiting for new data from client") - client_socket.send("READY_FOR_NEXT_DATA".encode()) + client_socket.sendall("READY_FOR_NEXT_DATA".encode()) should_continue = False - message_type, data = self._wait_for_data(client_socket=client_socket) + message_type, data = self._wait_for_data(client_socket=client_socket, send_confirmation=False) if message_type == OnlineCallbackServer._ServerMessages.NEW_DATA: try: self._update_data(data) @@ -536,14 +540,14 @@ def _initialize_connexion(self, **show_options): serialized_ocp = json.dumps(ocp_plot).encode() # Sends message type and dimensions - self._socket.send( + self._socket.sendall( f"{OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp)]}".encode() ) if self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") # TODO ADD SHOW OPTIONS to the send - self._socket.send(serialized_ocp) + self._socket.sendall(serialized_ocp) if self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") @@ -557,19 +561,12 @@ def _initialize_connexion(self, **show_options): ) def close(self): - self._socket.send( + self._socket.sendall( f"{OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION.value}\nGoodbye from client!".encode() ) self._socket.close() def eval(self, arg: list | tuple, force: bool = False) -> list: - arg_as_bytes = [] - for a in arg: - to_pack = np.array(a).T.tolist() - if len(to_pack) == 1: - to_pack = to_pack[0] - arg_as_bytes.append(struct.pack("d" * len(to_pack), *to_pack)) - if not force: self._socket.setblocking(False) @@ -610,14 +607,11 @@ def eval(self, arg: list | tuple, force: bool = False) -> list: y_steps_tp = y_steps.tolist() data_serialized += struct.pack("d" * len(y_steps_tp), *y_steps_tp) - self._socket.send( + self._socket.sendall( f"{OnlineCallbackServer._ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode() ) - if self._socket.recv(1024).decode() != "OK": - raise RuntimeError("The server did not acknowledge the data") - - for to_send in [header.encode(), data_serialized]: - self._socket.send(to_send) - if self._socket.recv(1024).decode() != "OK": - raise RuntimeError("The server did not acknowledge the data") + # If send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) + self._socket.sendall(header.encode()) + self._socket.sendall(data_serialized) + # Again, if send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) return [0] From e9f085fb333e349dfd7c99e57cba8aea93e00bd6 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Tue, 30 Jul 2024 10:51:33 -0400 Subject: [PATCH 07/17] Renamed TCP for SERVER --- bioptim/interfaces/interface_utils.py | 2 +- bioptim/misc/enums.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index f8ff19844..4db48e342 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -39,7 +39,7 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): "You can add show_options={'type': ShowOnlineType.TCP} to the Solver declaration" ) interface.options_common["iteration_callback"] = OnlineCallbackMultiprocess(ocp, show_options=show_options) - elif show_type == ShowOnlineType.TCP: + elif show_type == ShowOnlineType.SERVER: host = None if "host" in show_options: host = show_options["host"] diff --git a/bioptim/misc/enums.py b/bioptim/misc/enums.py index 302f75a49..e1fde3cb0 100644 --- a/bioptim/misc/enums.py +++ b/bioptim/misc/enums.py @@ -106,7 +106,7 @@ class ShowOnlineType(Enum): """ MULTIPROCESS = 0 - TCP = 1 + SERVER = 1 class ControlType(Enum): From a7a91c36299afa9cfc954c91699c9ae6235ac697 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Wed, 31 Jul 2024 17:45:22 -0400 Subject: [PATCH 08/17] Adding show_options --- bioptim/gui/online_callback.py | 20 ++++++++++++++++++-- bioptim/gui/serializable_class.py | 28 ++++++++++++++-------------- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/bioptim/gui/online_callback.py b/bioptim/gui/online_callback.py index 20ffe4b34..c800e02f8 100644 --- a/bioptim/gui/online_callback.py +++ b/bioptim/gui/online_callback.py @@ -286,6 +286,14 @@ def plot_update(self) -> bool: _default_port = 3050 +def _serialize_show_options(show_options: dict) -> bytes: + return json.dumps(show_options).encode() + + +def _deserialize_show_options(show_options: bytes) -> dict: + return json.loads(show_options.decode()) + + class OnlineCallbackServer: class _ServerMessages(Enum): INITIATE_CONNEXION = 0 @@ -404,7 +412,12 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list): self._logger.warning("Error while deserializing OCP data from client, closing connexion") return - show_options = {} + try: + show_options = _deserialize_show_options(ocp_raw[1]) + except: + self._logger.warning("Error while extracting show options, closing connexion") + return + self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) # Send the confirmation to the client @@ -539,15 +552,18 @@ def _initialize_connexion(self, **show_options): ocp_plot["dummy_phase_times"].append([np.array(v)[:, 0].tolist() for v in phase_times]) serialized_ocp = json.dumps(ocp_plot).encode() + serialized_show_options = _serialize_show_options(show_options) + # Sends message type and dimensions self._socket.sendall( - f"{OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp)]}".encode() + f"{OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".encode() ) if self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") # TODO ADD SHOW OPTIONS to the send self._socket.sendall(serialized_ocp) + self._socket.sendall(serialized_show_options) if self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") diff --git a/bioptim/gui/serializable_class.py b/bioptim/gui/serializable_class.py index 202b37b54..a57470dab 100644 --- a/bioptim/gui/serializable_class.py +++ b/bioptim/gui/serializable_class.py @@ -5,8 +5,9 @@ from ..dynamics.ode_solver import OdeSolver from ..limits.penalty_option import PenaltyOption +from ..limits.path_conditions import Bounds from ..misc.mapping import BiMapping -from ..misc.enums import PlotType, QuadratureRule +from ..misc.enums import PlotType, QuadratureRule, InterpolationType class CasadiFunctionSerializable: @@ -137,33 +138,32 @@ def deserialize(cls, data): class BoundsSerializable: - min: np.ndarray | DM - max: np.ndarray | DM + bounds: Bounds - def __init__(self, min: np.ndarray | DM, max: np.ndarray | DM): - self.min = min - self.max = max + def __init__(self, bounds: Bounds): + self.bounds = bounds @classmethod def from_bounds(cls, bounds): - return cls( - min=np.array(bounds.min), - max=np.array(bounds.max), - ) + return cls(bounds=bounds) def serialize(self): return { - "min": self.min.tolist(), - "max": self.max.tolist(), + "min": self.bounds.min(), + "max": self.bounds.max(), + "type": self.bounds.type, + "slice_list": self.bounds.slice_list, } @classmethod def deserialize(cls, data): return cls( - min=np.array(data["min"]), - max=np.array(data["max"]), + type=Bounds(min_bound=data["min"], max_bound=data["max"], type=data["type"], slice_list=data["slice_list"]), ) + def check_and_adjust_dimensions(self, n_elements: int, n_nodes: int): + pass + class CustomPlotSerializable: _function: Callable From 014ce953c930f56cb470cd0cc1214b8bd810592a Mon Sep 17 00:00:00 2001 From: Pariterre Date: Thu, 1 Aug 2024 08:41:14 -0400 Subject: [PATCH 09/17] Typo --- bioptim/gui/online_callback.py | 2 +- bioptim/interfaces/interface_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bioptim/gui/online_callback.py b/bioptim/gui/online_callback.py index c800e02f8..32b59cd46 100644 --- a/bioptim/gui/online_callback.py +++ b/bioptim/gui/online_callback.py @@ -514,7 +514,7 @@ def _update_data(self, data_raw: list): self._plotter.update_data(xdata, ydata) -class OnlineCallbackTcp(OnlineCallbackAbstract): +class OnlineCallbackServer(OnlineCallbackAbstract): def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str = None, port: int = None): super().__init__(ocp, opts, show_options) diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index 4db48e342..37456127a 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -6,7 +6,7 @@ from casadi import horzcat, vertcat, sum1, sum2, nlpsol, SX, MX, reshape from bioptim.optimization.solution.solution import Solution -from ..gui.online_callback import OnlineCallbackMultiprocess, OnlineCallbackTcp +from ..gui.online_callback import OnlineCallbackMultiprocess, OnlineCallbackServer from ..limits.path_conditions import Bounds from ..limits.penalty_helpers import PenaltyHelpers from ..misc.enums import InterpolationType, ShowOnlineType @@ -48,7 +48,7 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): if "port" in show_options: port = show_options["port"] del show_options["port"] - interface.options_common["iteration_callback"] = OnlineCallbackTcp( + interface.options_common["iteration_callback"] = OnlineCallbackServer( ocp, show_options=show_options, host=host, port=port ) else: From 3a6cd7bc32b27c9166f799d61c9449038fa0a80e Mon Sep 17 00:00:00 2001 From: Pariterre Date: Thu, 1 Aug 2024 09:25:30 -0400 Subject: [PATCH 10/17] Finalized show_bounds and separated online_callbacks into smaller files --- bioptim/__init__.py | 2 +- bioptim/gui/online_callback_abstract.py | 170 ++++++++++ bioptim/gui/online_callback_multiprocess.py | 122 +++++++ ..._callback.py => online_callback_server.py} | 320 ++---------------- bioptim/gui/serializable_class.py | 46 ++- bioptim/interfaces/interface_utils.py | 3 +- resources/bioptim_plotting_server.py | 4 +- 7 files changed, 357 insertions(+), 310 deletions(-) create mode 100644 bioptim/gui/online_callback_abstract.py create mode 100644 bioptim/gui/online_callback_multiprocess.py rename bioptim/gui/{online_callback.py => online_callback_server.py} (59%) diff --git a/bioptim/__init__.py b/bioptim/__init__.py index 61c30a665..b29898c5d 100644 --- a/bioptim/__init__.py +++ b/bioptim/__init__.py @@ -231,4 +231,4 @@ from .optimization.problem_type import SocpType from .misc.casadi_expand import lt, le, gt, ge, if_else, if_else_zero -from .gui.online_callback import OnlineCallbackServer +from .gui.online_callback_server import OnlineCallbackServerBackend diff --git a/bioptim/gui/online_callback_abstract.py b/bioptim/gui/online_callback_abstract.py new file mode 100644 index 000000000..a8a01e06f --- /dev/null +++ b/bioptim/gui/online_callback_abstract.py @@ -0,0 +1,170 @@ +from abc import ABC, abstractmethod +from enum import Enum +import json +import logging +import multiprocessing as mp +import socket +import struct +import threading + +from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity, DM +from matplotlib import pyplot as plt +import numpy as np + +from .plot import PlotOcp, OcpSerializable +from ..optimization.optimization_vector import OptimizationVectorHelper + + +class OnlineCallbackAbstract(Callback, ABC): + """ + CasADi interface of Ipopt callbacks + + Attributes + ---------- + ocp: OptimalControlProgram + A reference to the ocp to show + nx: int + The number of optimization variables + ng: int + The number of constraints + + Methods + ------- + get_n_in() -> int + Get the number of variables in + get_n_out() -> int + Get the number of variables out + get_name_in(i: int) -> int + Get the name of a variable + get_name_out(_) -> str + Get the name of the output variable + get_sparsity_in(self, i: int) -> tuple[int] + Get the sparsity of a specific variable + eval(self, arg: list | tuple, force: bool = False) -> list[int] + Send the current data to the plotter + """ + + def __init__(self, ocp, opts: dict = None, show_options: dict = None): + """ + Parameters + ---------- + ocp: OptimalControlProgram + A reference to the ocp to show + opts: dict + Option to AnimateCallback method of CasADi + show_options: dict + The options to pass to PlotOcp + """ + if opts is None: + opts = {} + + Callback.__init__(self) + self.ocp = ocp + self.nx = self.ocp.variables_vector.shape[0] + + # There must be an option to add an if here + from ..interfaces.ipopt_interface import IpoptInterface + + interface = IpoptInterface(ocp) + all_g, _ = interface.dispatch_bounds() + self.ng = all_g.shape[0] + + self.construct("AnimateCallback", opts) + + @abstractmethod + def close(self): + """ + Close the callback + """ + + @staticmethod + def get_n_in() -> int: + """ + Get the number of variables in + + Returns + ------- + The number of variables in + """ + + return nlpsol_n_out() + + @staticmethod + def get_n_out() -> int: + """ + Get the number of variables out + + Returns + ------- + The number of variables out + """ + + return 1 + + @staticmethod + def get_name_in(i: int) -> int: + """ + Get the name of a variable + + Parameters + ---------- + i: int + The index of the variable + + Returns + ------- + The name of the variable + """ + + return nlpsol_out(i) + + @staticmethod + def get_name_out(_) -> str: + """ + Get the name of the output variable + + Returns + ------- + The name of the output variable + """ + + return "ret" + + def get_sparsity_in(self, i: int) -> tuple: + """ + Get the sparsity of a specific variable + + Parameters + ---------- + i: int + The index of the variable + + Returns + ------- + The sparsity of the variable + """ + + n = nlpsol_out(i) + if n == "f": + return Sparsity.scalar() + elif n in ("x", "lam_x"): + return Sparsity.dense(self.nx) + elif n in ("g", "lam_g"): + return Sparsity.dense(self.ng) + else: + return Sparsity(0, 0) + + @abstractmethod + def eval(self, arg: list | tuple, force: bool = False) -> list: + """ + Send the current data to the plotter + + Parameters + ---------- + arg: list | tuple + The data to send + + Returns + ------- + A list of error index + """ diff --git a/bioptim/gui/online_callback_multiprocess.py b/bioptim/gui/online_callback_multiprocess.py new file mode 100644 index 000000000..d71210f39 --- /dev/null +++ b/bioptim/gui/online_callback_multiprocess.py @@ -0,0 +1,122 @@ +import multiprocessing as mp +import threading + +from casadi import nlpsol_out, DM +from matplotlib import pyplot as plt +import numpy as np + +from .plot import PlotOcp, OcpSerializable +from ..optimization.optimization_vector import OptimizationVectorHelper +from .online_callback_abstract import OnlineCallbackAbstract + + +class OnlineCallbackMultiprocess(OnlineCallbackAbstract): + """ + Multiprocessing implementation of the online callback + + Attributes + ---------- + queue: mp.Queue + The multiprocessing queue + plotter: ProcessPlotter + The callback for plotting for the multiprocessing + plot_process: mp.Process + The multiprocessing placeholder + """ + + def __init__(self, ocp, opts: dict = None, show_options: dict = None): + super(OnlineCallbackMultiprocess, self).__init__(ocp, opts, show_options) + + self.queue = mp.Queue() + self.plotter = self.ProcessPlotter(self.ocp) + self.plot_process = mp.Process(target=self.plotter, args=(self.queue, show_options), daemon=True) + self.plot_process.start() + + def close(self): + self.plot_process.kill() + + def eval(self, arg: list | tuple, force: bool = False) -> list: + # Dequeuing the data by removing previous not useful data + while not self.queue.empty(): + self.queue.get_nowait() + + args_dict = {} + for i, s in enumerate(nlpsol_out()): + args_dict[s] = arg[i] + self.queue.put_nowait(args_dict) + return [0] + + class ProcessPlotter(object): + """ + The plotter that interface PlotOcp and the multiprocessing + + Attributes + ---------- + ocp: OptimalControlProgram + A reference to the ocp to show + pipe: mp.Queue + The multiprocessing queue to evaluate + plot: PlotOcp + The handler on all the figures + + Methods + ------- + callback(self) -> bool + The callback to update the graphs + """ + + def __init__(self, ocp): + """ + Parameters + ---------- + ocp: OptimalControlProgram + A reference to the ocp to show + """ + + self._ocp: OcpSerializable = ocp + self._plotter: PlotOcp = None + self._update_time = 0.001 + + def __call__(self, pipe: mp.Queue, show_options: dict | None): + """ + Parameters + ---------- + pipe: mp.Queue + The multiprocessing queue to evaluate + show_options: dict + The option to pass to PlotOcp + """ + + if show_options is None: + show_options = {} + self._pipe = pipe + + dummy_phase_times = OptimizationVectorHelper.extract_step_times(self._ocp, DM(np.ones(self._ocp.n_phases))) + self._plotter = PlotOcp(self._ocp, dummy_phase_times=dummy_phase_times, **show_options) + threading.Timer(self._update_time, self.plot_update).start() + plt.show() + + def plot_update(self) -> bool: + """ + The callback to update the graphs + + Returns + ------- + True if everything went well + """ + + args = {} + while not self._pipe.empty(): + args = self._pipe.get() + + if args: + self._plotter.update_data(*self._plotter.parse_data(**args), **args) + + # We want to redraw here to actually consume a bit of time, otherwise it goes to fast and pipe remains empty + for fig in self._plotter.all_figures: + fig.canvas.draw() + if [plt.fignum_exists(fig.number) for fig in self._plotter.all_figures].count(True) > 0: + # If there are still figures, we keep updating + threading.Timer(self._update_time, self.plot_update).start() + + return True diff --git a/bioptim/gui/online_callback.py b/bioptim/gui/online_callback_server.py similarity index 59% rename from bioptim/gui/online_callback.py rename to bioptim/gui/online_callback_server.py index 32b59cd46..7e8e9b276 100644 --- a/bioptim/gui/online_callback.py +++ b/bioptim/gui/online_callback_server.py @@ -1,287 +1,19 @@ -from abc import ABC, abstractmethod from enum import Enum import json import logging -import multiprocessing as mp import socket import struct import threading -from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity, DM +from casadi import nlpsol_out, DM from matplotlib import pyplot as plt import numpy as np +from .online_callback_abstract import OnlineCallbackAbstract from .plot import PlotOcp, OcpSerializable from ..optimization.optimization_vector import OptimizationVectorHelper -class OnlineCallbackAbstract(Callback, ABC): - """ - CasADi interface of Ipopt callbacks - - Attributes - ---------- - ocp: OptimalControlProgram - A reference to the ocp to show - nx: int - The number of optimization variables - ng: int - The number of constraints - - Methods - ------- - get_n_in() -> int - Get the number of variables in - get_n_out() -> int - Get the number of variables out - get_name_in(i: int) -> int - Get the name of a variable - get_name_out(_) -> str - Get the name of the output variable - get_sparsity_in(self, i: int) -> tuple[int] - Get the sparsity of a specific variable - eval(self, arg: list | tuple, force: bool = False) -> list[int] - Send the current data to the plotter - """ - - def __init__(self, ocp, opts: dict = None, show_options: dict = None): - """ - Parameters - ---------- - ocp: OptimalControlProgram - A reference to the ocp to show - opts: dict - Option to AnimateCallback method of CasADi - show_options: dict - The options to pass to PlotOcp - """ - if opts is None: - opts = {} - - Callback.__init__(self) - self.ocp = ocp - self.nx = self.ocp.variables_vector.shape[0] - - # There must be an option to add an if here - from ..interfaces.ipopt_interface import IpoptInterface - - interface = IpoptInterface(ocp) - all_g, _ = interface.dispatch_bounds() - self.ng = all_g.shape[0] - - self.construct("AnimateCallback", opts) - - @abstractmethod - def close(self): - """ - Close the callback - """ - - @staticmethod - def get_n_in() -> int: - """ - Get the number of variables in - - Returns - ------- - The number of variables in - """ - - return nlpsol_n_out() - - @staticmethod - def get_n_out() -> int: - """ - Get the number of variables out - - Returns - ------- - The number of variables out - """ - - return 1 - - @staticmethod - def get_name_in(i: int) -> int: - """ - Get the name of a variable - - Parameters - ---------- - i: int - The index of the variable - - Returns - ------- - The name of the variable - """ - - return nlpsol_out(i) - - @staticmethod - def get_name_out(_) -> str: - """ - Get the name of the output variable - - Returns - ------- - The name of the output variable - """ - - return "ret" - - def get_sparsity_in(self, i: int) -> tuple: - """ - Get the sparsity of a specific variable - - Parameters - ---------- - i: int - The index of the variable - - Returns - ------- - The sparsity of the variable - """ - - n = nlpsol_out(i) - if n == "f": - return Sparsity.scalar() - elif n in ("x", "lam_x"): - return Sparsity.dense(self.nx) - elif n in ("g", "lam_g"): - return Sparsity.dense(self.ng) - else: - return Sparsity(0, 0) - - @abstractmethod - def eval(self, arg: list | tuple, force: bool = False) -> list: - """ - Send the current data to the plotter - - Parameters - ---------- - arg: list | tuple - The data to send - - Returns - ------- - A list of error index - """ - - -class OnlineCallbackMultiprocess(OnlineCallbackAbstract): - """ - Multiprocessing implementation of the online callback - - Attributes - ---------- - queue: mp.Queue - The multiprocessing queue - plotter: ProcessPlotter - The callback for plotting for the multiprocessing - plot_process: mp.Process - The multiprocessing placeholder - """ - - def __init__(self, ocp, opts: dict = None, show_options: dict = None): - super(OnlineCallbackMultiprocess, self).__init__(ocp, opts, show_options) - - self.queue = mp.Queue() - self.plotter = self.ProcessPlotter(self.ocp) - self.plot_process = mp.Process(target=self.plotter, args=(self.queue, show_options), daemon=True) - self.plot_process.start() - - def close(self): - self.plot_process.kill() - - def eval(self, arg: list | tuple, force: bool = False) -> list: - # Dequeuing the data by removing previous not useful data - while not self.queue.empty(): - self.queue.get_nowait() - - args_dict = {} - for i, s in enumerate(nlpsol_out()): - args_dict[s] = arg[i] - self.queue.put_nowait(args_dict) - return [0] - - class ProcessPlotter(object): - """ - The plotter that interface PlotOcp and the multiprocessing - - Attributes - ---------- - ocp: OptimalControlProgram - A reference to the ocp to show - pipe: mp.Queue - The multiprocessing queue to evaluate - plot: PlotOcp - The handler on all the figures - - Methods - ------- - callback(self) -> bool - The callback to update the graphs - """ - - def __init__(self, ocp): - """ - Parameters - ---------- - ocp: OptimalControlProgram - A reference to the ocp to show - """ - - self._ocp: OcpSerializable = ocp - self._plotter: PlotOcp = None - self._update_time = 0.001 - - def __call__(self, pipe: mp.Queue, show_options: dict | None): - """ - Parameters - ---------- - pipe: mp.Queue - The multiprocessing queue to evaluate - show_options: dict - The option to pass to PlotOcp - """ - - if show_options is None: - show_options = {} - self._pipe = pipe - - dummy_phase_times = OptimizationVectorHelper.extract_step_times(self._ocp, DM(np.ones(self._ocp.n_phases))) - self._plotter = PlotOcp(self._ocp, dummy_phase_times=dummy_phase_times, **show_options) - threading.Timer(self._update_time, self.plot_update).start() - plt.show() - - def plot_update(self) -> bool: - """ - The callback to update the graphs - - Returns - ------- - True if everything went well - """ - - args = {} - while not self._pipe.empty(): - args = self._pipe.get() - - if args: - self._plotter.update_data(*self._plotter.parse_data(**args), **args) - - # We want to redraw here to actually consume a bit of time, otherwise it goes to fast and pipe remains empty - for fig in self._plotter.all_figures: - fig.canvas.draw() - if [plt.fignum_exists(fig.number) for fig in self._plotter.all_figures].count(True) > 0: - # If there are still figures, we keep updating - threading.Timer(self._update_time, self.plot_update).start() - - return True - - _default_host = "localhost" _default_port = 3050 @@ -294,15 +26,16 @@ def _deserialize_show_options(show_options: bytes) -> dict: return json.loads(show_options.decode()) -class OnlineCallbackServer: - class _ServerMessages(Enum): - INITIATE_CONNEXION = 0 - NEW_DATA = 1 - CLOSE_CONNEXION = 2 - EMPTY = 3 - TOO_SOON = 4 - UNKNOWN = 5 +class _ServerMessages(Enum): + INITIATE_CONNEXION = 0 + NEW_DATA = 1 + CLOSE_CONNEXION = 2 + EMPTY = 3 + TOO_SOON = 4 + UNKNOWN = 5 + +class OnlineCallbackServerBackend: def _prepare_logger(self): name = "OnlineCallbackServer" console_handler = logging.StreamHandler() @@ -351,15 +84,15 @@ def _wait_for_data(self, client_socket: socket.socket, send_confirmation: bool = self._logger.debug("Waiting for data from client") data = client_socket.recv(1024) if not data: - return OnlineCallbackServer._ServerMessages.EMPTY, None + return _ServerMessages.EMPTY, None except: self._logger.warning("Client closed connexion") client_socket.close() - return OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION, None + return _ServerMessages.CLOSE_CONNEXION, None data_as_list = data.decode().split("\n") try: - message_type = OnlineCallbackServer._ServerMessages(int(data_as_list[0])) + message_type = _ServerMessages(int(data_as_list[0])) len_all_data = [int(len_data) for len_data in data_as_list[1][1:-1].split(",")] # Sends confirmation and waits for the next message if send_confirmation: @@ -374,23 +107,23 @@ def _wait_for_data(self, client_socket: socket.socket, send_confirmation: bool = client_socket.sendall("OK".encode()) except ValueError: self._logger.warning("Unknown message type received") - message_type = OnlineCallbackServer._ServerMessages.UNKNOWN + message_type = _ServerMessages.UNKNOWN # Sends failure if send_confirmation: client_socket.sendall("NOK".encode()) data_out = [] - if message_type == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION: + if message_type == _ServerMessages.CLOSE_CONNEXION: self._logger.info("Received close connexion from client") client_socket.close() plt.close() - return OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION, None + return _ServerMessages.CLOSE_CONNEXION, None return message_type, data_out def _wait_for_new_connexion(self, client_socket: socket.socket): message_type, data = self._wait_for_data(client_socket=client_socket) - if message_type == OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION: + if message_type == _ServerMessages.INITIATE_CONNEXION: self._logger.debug(f"Received hand shake from client") self._initialize_plotter(client_socket, data) @@ -451,7 +184,7 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> bool: should_continue = False message_type, data = self._wait_for_data(client_socket=client_socket, send_confirmation=False) - if message_type == OnlineCallbackServer._ServerMessages.NEW_DATA: + if message_type == _ServerMessages.NEW_DATA: try: self._update_data(data) should_continue = True @@ -459,10 +192,7 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> bool: self._logger.warning("Error while updating data from client, closing connexion") plt.close() client_socket.close() - elif ( - message_type == OnlineCallbackServer._ServerMessages.EMPTY - or message_type == OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION - ): + elif message_type == _ServerMessages.EMPTY or message_type == _ServerMessages.CLOSE_CONNEXION: self._logger.debug("Received empty data from client (end of stream), closing connexion") if should_continue: @@ -556,7 +286,7 @@ def _initialize_connexion(self, **show_options): # Sends message type and dimensions self._socket.sendall( - f"{OnlineCallbackServer._ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".encode() + f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".encode() ) if self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") @@ -577,9 +307,7 @@ def _initialize_connexion(self, **show_options): ) def close(self): - self._socket.sendall( - f"{OnlineCallbackServer._ServerMessages.CLOSE_CONNEXION.value}\nGoodbye from client!".encode() - ) + self._socket.sendall(f"{_ServerMessages.CLOSE_CONNEXION.value}\nGoodbye from client!".encode()) self._socket.close() def eval(self, arg: list | tuple, force: bool = False) -> list: @@ -623,9 +351,7 @@ def eval(self, arg: list | tuple, force: bool = False) -> list: y_steps_tp = y_steps.tolist() data_serialized += struct.pack("d" * len(y_steps_tp), *y_steps_tp) - self._socket.sendall( - f"{OnlineCallbackServer._ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode() - ) + self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode()) # If send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) self._socket.sendall(header.encode()) self._socket.sendall(data_serialized) diff --git a/bioptim/gui/serializable_class.py b/bioptim/gui/serializable_class.py index a57470dab..408b134f8 100644 --- a/bioptim/gui/serializable_class.py +++ b/bioptim/gui/serializable_class.py @@ -138,31 +138,59 @@ def deserialize(cls, data): class BoundsSerializable: - bounds: Bounds + _bounds: Bounds def __init__(self, bounds: Bounds): - self.bounds = bounds + self._bounds = bounds @classmethod def from_bounds(cls, bounds): return cls(bounds=bounds) def serialize(self): + slice_list = self._bounds.min.slice_list # min and max have the same slice_list + slice_list_type = type(slice_list).__name__ + if isinstance(self._bounds.min.slice_list, slice): + slice_list = [slice_list.start, slice_list.stop, slice_list.step] + return { - "min": self.bounds.min(), - "max": self.bounds.max(), - "type": self.bounds.type, - "slice_list": self.bounds.slice_list, + "key": self._bounds.key, + "min": np.array(self._bounds.min).tolist(), + "max": np.array(self._bounds.max).tolist(), + "type": self._bounds.type.value, + "slice_list_type": slice_list_type, + "slice_list": slice_list, } @classmethod def deserialize(cls, data): return cls( - type=Bounds(min_bound=data["min"], max_bound=data["max"], type=data["type"], slice_list=data["slice_list"]), + bounds=Bounds( + key=data["key"], + min_bound=data["min"], + max_bound=data["max"], + interpolation=InterpolationType(data["type"]), + slice_list=( + slice(data["slice_list"][0], data["slice_list"][1], data["slice_list"][2]) + if data["slice_list_type"] == "slice" + else data["slice_list"] + ), + ), ) - def check_and_adjust_dimensions(self, n_elements: int, n_nodes: int): - pass + def check_and_adjust_dimensions(self, n_elements: int, n_shooting: int): + self._bounds.check_and_adjust_dimensions(n_elements, n_shooting) + + def type(self): + return self._bounds.type + + @property + def min(self): + return self._bounds.min + + @property + def max(self): + return self._bounds.max class CustomPlotSerializable: diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index 37456127a..ee82bc50a 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -6,7 +6,8 @@ from casadi import horzcat, vertcat, sum1, sum2, nlpsol, SX, MX, reshape from bioptim.optimization.solution.solution import Solution -from ..gui.online_callback import OnlineCallbackMultiprocess, OnlineCallbackServer +from ..gui.online_callback_multiprocess import OnlineCallbackMultiprocess +from ..gui.online_callback_server import OnlineCallbackServer from ..limits.path_conditions import Bounds from ..limits.penalty_helpers import PenaltyHelpers from ..misc.enums import InterpolationType, ShowOnlineType diff --git a/resources/bioptim_plotting_server.py b/resources/bioptim_plotting_server.py index a9108b923..1d752495d 100644 --- a/resources/bioptim_plotting_server.py +++ b/resources/bioptim_plotting_server.py @@ -1,8 +1,8 @@ -from bioptim import OnlineCallbackServer +from bioptim import OnlineCallbackServerBackend def main(): - OnlineCallbackServer().run() + OnlineCallbackServerBackend().run() if __name__ == "__main__": From 3c0974c0257280b36215c97e252ec7a36329fe7b Mon Sep 17 00:00:00 2001 From: Pariterre Date: Thu, 1 Aug 2024 09:35:59 -0400 Subject: [PATCH 11/17] Robustified if client quits during optimization --- bioptim/gui/online_callback_server.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 7e8e9b276..f0c163d23 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -48,7 +48,7 @@ def _prepare_logger(self): self._logger = logging.getLogger(name) self._logger.addHandler(console_handler) - self._logger.setLevel(logging.DEBUG) + self._logger.setLevel(logging.INFO) def __init__(self, host: str = None, port: int = None): self._prepare_logger() @@ -180,7 +180,11 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> bool: True if everything went well """ self._logger.debug(f"Waiting for new data from client") - client_socket.sendall("READY_FOR_NEXT_DATA".encode()) + try: + client_socket.sendall("READY_FOR_NEXT_DATA".encode()) + except: + self._logger.warning("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") + return should_continue = False message_type, data = self._wait_for_data(client_socket=client_socket, send_confirmation=False) @@ -190,8 +194,9 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> bool: should_continue = True except: self._logger.warning("Error while updating data from client, closing connexion") - plt.close() client_socket.close() + return + elif message_type == _ServerMessages.EMPTY or message_type == _ServerMessages.CLOSE_CONNEXION: self._logger.debug("Received empty data from client (end of stream), closing connexion") From 4f4173c58456f5fb7ea4928846fd58b03aa12371 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Thu, 1 Aug 2024 11:15:40 -0400 Subject: [PATCH 12/17] Added automatic multiprocess for the plotter and docstrings --- README.md | 26 ++- bioptim/__init__.py | 2 +- bioptim/examples/getting_started/pendulum.py | 12 +- bioptim/gui/online_callback_server.py | 218 ++++++++++++++++--- bioptim/interfaces/interface_utils.py | 18 +- resources/bioptim_plotting_server.py | 9 - resources/plotting_server.py | 19 ++ 7 files changed, 258 insertions(+), 46 deletions(-) delete mode 100644 resources/bioptim_plotting_server.py create mode 100644 resources/plotting_server.py diff --git a/README.md b/README.md index 3bb69d4a1..857fcdbe0 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,8 @@ As a tour guide that uses this binder, you can watch the `bioptim` workshop that - [OptimalControlProgram](#class-optimalcontrolprogram) - [NonLinearProgram](#class-nonlinearprogram) + - [VariationalOptimalControlProgram](#class-variationaloptimalcontrolprogram) + - [PlottingServer](#class-plottingserver) @@ -167,6 +169,7 @@ As a tour guide that uses this binder, you can watch the `bioptim` workshop that - [Solver](#enum-solver) - [ControlType](#enum-controltype) - [PlotType](#enum-plottype) + - [ShowOnlineType](#enum-showonlinetype) - [InterpolationType](#enum-interpolationtype) - [Shooting](#enum-shooting) - [CostType](#enum-costtype) @@ -788,7 +791,15 @@ The `Solver` class can be used to select the nonlinear solver to solve the ocp: Note that options can be passed to the solver parameter. One can refer to their respective solver's documentation to know which options exist. The `show_online_optim` parameter can be set to `True` so the graphs nicely update during the optimization. -It is expected to slow down the optimization a bit. +Please note that `ShowOnlineType.MULTIPROCESS` is not available on Windows. To see how to run the server on Windows, please refer to the `getting_started/pendulum.py` example. +It is expected to slow down the optimization a bit. +`show_options` can be also passed as a dict to the plotter to customize the plotter's behavior. +The following keys are special options: + - `type`: the type of plotter to use (default is `ShowOnlineType.MULTIPROCESS`) + - If `type` is `ShowOnlineType.SERVER`, then these additional options are available: + - `as_multiprocess`: if the server should be run as a multiprocess (default is `True`). If `True`, a server is automatically started in a new process. If `False`, a server must be started manually by instantiating an `PlottingServer` class. + - `host`: the host to use (default is `localhost`), it must match the host used in the `PlottingServer` class if `as_multiprocess` is `False` + - `port`: the port to use (default is `5030`), it must match the port used in the `PlottingServer` class if `as_multiprocess` is `False` Finally, one can save and load previously optimized values by using ```python @@ -831,6 +842,12 @@ instead of `x_init` and `x_bounds`. the OCP, you can access them with `sol.parameters["qdot_start"]` and `sol.parameters["qdot_end"]` at the end of the optimization. +### Class: PlottingServer +If one wants to use the `ShowOnlineType.SERVER` plotter, one can instantiate this class to start a server. +This is not mandatory as if `as_multiprocess` is set to `True` in the `show_options` dict [default behavior], this server is started automatically. +The advantage of starting the server manually is that one can plot online graphs on a remote machine. +An example of such a server is provided in `resources/plotting_server.py`. + ## The model Bioptim is designed to work with any model, as long as it inherits from the class `bioptim.Model`. Models built with `biorbd` are already compatible with `bioptim`. @@ -1668,6 +1685,13 @@ INTEGRATED: Plot that links the points within an interval but is discrete betwee STEP: Step plot, constant over an interval. POINT: Point plot. +### Enum: ShowOnlineType +The type of online plotter to use. + +The accepted values are: +MULTIPROCESS: The online plotter is in a separate process. +SERVER: The online plotter is in a separate server. + ### Enum: InterpolationType Defines wow a time-dependent variable is interpolated. It is mainly used for phases time span. diff --git a/bioptim/__init__.py b/bioptim/__init__.py index b29898c5d..76092bd87 100644 --- a/bioptim/__init__.py +++ b/bioptim/__init__.py @@ -231,4 +231,4 @@ from .optimization.problem_type import SocpType from .misc.casadi_expand import lt, le, gt, ge, if_else, if_else_zero -from .gui.online_callback_server import OnlineCallbackServerBackend +from .gui.online_callback_server import PlottingServer diff --git a/bioptim/examples/getting_started/pendulum.py b/bioptim/examples/getting_started/pendulum.py index 3c6aa34a9..8325b9da7 100644 --- a/bioptim/examples/getting_started/pendulum.py +++ b/bioptim/examples/getting_started/pendulum.py @@ -26,6 +26,7 @@ BiorbdModel, ControlType, PhaseDynamics, + ShowOnlineType, ) @@ -149,8 +150,15 @@ def main(): # --- Print ocp structure --- # ocp.print(to_console=False, to_graph=False) - # --- Solve the ocp. Please note that online graphics only works with the Linux operating system --- # - sol = ocp.solve(Solver.IPOPT(show_online_optim=platform.system() == "Linux")) + # --- Solve the ocp --- # + if platform.system() == "Windows": + # The default online type (ShowOnlineType.MULTIPROCESS) is not available on Windows + # Since `as_multiprocess` default value is True, it will automatically starts the server in the background + solver = Solver.IPOPT(show_online_optim=True, show_options={"type": ShowOnlineType.SERVER}) + else: + solver = Solver.IPOPT(show_online_optim=True) + + sol = ocp.solve(solver) # --- Show the results graph --- # sol.print_cost() diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index f0c163d23..ee8214f7e 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -3,6 +3,7 @@ import logging import socket import struct +import time import threading from casadi import nlpsol_out, DM @@ -26,6 +27,18 @@ def _deserialize_show_options(show_options: bytes) -> dict: return json.loads(show_options.decode()) +def _start_as_multiprocess_internal(*args, **kwargs): + """ + Starts the server (necessary for multiprocessing), this method should not be called directly, apart from + run_as_multiprocess + + Parameters + ---------- + same as PlottingServer + """ + PlottingServer(*args, **kwargs) + + class _ServerMessages(Enum): INITIATE_CONNEXION = 0 NEW_DATA = 1 @@ -35,9 +48,37 @@ class _ServerMessages(Enum): UNKNOWN = 5 -class OnlineCallbackServerBackend: - def _prepare_logger(self): - name = "OnlineCallbackServer" +class PlottingServer: + def __init__(self, host: str = None, port: int = None): + """ + Initializes the server + + Parameters + ---------- + host: str + The host to listen to, by default "localhost" + port: int + The port to listen to, by default 3050 + """ + + self._prepare_logger() + self._get_data_interval = 1.0 + self._update_plot_interval = 0.01 + + # Define the host and port + self._host = host if host else _default_host + self._port = port if port else _default_port + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._plotter: PlotOcp = None + + self._run() + + def _prepare_logger(self) -> None: + """ + Prepares the logger + """ + + name = "PlottingServer" console_handler = logging.StreamHandler() formatter = logging.Formatter( "{asctime} - {name}:{levelname} - {message}", @@ -50,18 +91,24 @@ def _prepare_logger(self): self._logger.addHandler(console_handler) self._logger.setLevel(logging.INFO) - def __init__(self, host: str = None, port: int = None): - self._prepare_logger() - self._get_data_interval = 1.0 - self._update_plot_interval = 0.01 + @staticmethod + def as_multiprocess(*args, **kwargs) -> None: + """ + Starts the server in a new process, this method can be called directly by the user - # Define the host and port - self._host = host if host else _default_host - self._port = port if port else _default_port - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._plotter: PlotOcp = None + Parameters + ---------- + same as PlottingServer + """ + from multiprocessing import Process + + thread = Process(target=_start_as_multiprocess_internal, args=args, kwargs=kwargs) + thread.start() - def run(self): + def _run(self) -> None: + """ + Starts the server, this method can be called directly by the user to start a plot server + """ # Start listening to the server self._socket.bind((self._host, self._port)) self._socket.listen(1) @@ -78,7 +125,23 @@ def run(self): finally: self._socket.close() - def _wait_for_data(self, client_socket: socket.socket, send_confirmation: bool = True): + def _wait_for_data(self, client_socket: socket.socket, send_confirmation: bool) -> tuple[_ServerMessages, list]: + """ + Waits for data from the client + + Parameters + ---------- + client_socket: socket.socket + The client socket + send_confirmation: bool + If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will + not send anything. This is part of the communication protocol + + Returns + ------- + The message type and the data + """ + # Receive the actual data try: self._logger.debug("Waiting for data from client") @@ -121,13 +184,33 @@ def _wait_for_data(self, client_socket: socket.socket, send_confirmation: bool = return message_type, data_out - def _wait_for_new_connexion(self, client_socket: socket.socket): - message_type, data = self._wait_for_data(client_socket=client_socket) + def _wait_for_new_connexion(self, client_socket: socket.socket) -> None: + """ + Waits for a new connexion + + Parameters + ---------- + client_socket: socket.socket + The client socket + """ + + message_type, data = self._wait_for_data(client_socket=client_socket, send_confirmation=True) if message_type == _ServerMessages.INITIATE_CONNEXION: self._logger.debug(f"Received hand shake from client") self._initialize_plotter(client_socket, data) - def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list): + def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> None: + """ + Initializes the plotter + + Parameters + ---------- + client_socket: socket.socket + The client socket + ocp_raw: list + The serialized raw data from the client + """ + try: data_json = json.loads(ocp_raw[0]) dummy_time_vector = [] @@ -161,7 +244,11 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list): threading.Timer(self._update_plot_interval, self._redraw).start() plt.show() - def _redraw(self): + def _redraw(self) -> None: + """ + Redraws the plot, this method is called periodically as long as at least one figure is open + """ + self._logger.debug("Updating plot") for _, fig in enumerate(self._plotter.all_figures): fig.canvas.draw() @@ -171,14 +258,18 @@ def _redraw(self): else: self._logger.info("All figures have been closed, stop updating the plots") - def _wait_for_new_data(self, client_socket: socket.socket) -> bool: + def _wait_for_new_data(self, client_socket: socket.socket) -> None: """ - The callback to update the graphs - - Returns - ------- - True if everything went well + Waits for new data from the client, sends a "READY_FOR_NEXT_DATA" message to the client to signal that the server + is ready to receive new data. If the client sends new data, the server will update the plot, if client disconnects + the connexion will be closed + + Parameters + ---------- + client_socket: socket.socket + The client socket """ + self._logger.debug(f"Waiting for new data from client") try: client_socket.sendall("READY_FOR_NEXT_DATA".encode()) @@ -204,7 +295,16 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> bool: timer_get_data = threading.Timer(self._get_data_interval, self._wait_for_new_data, (client_socket,)) timer_get_data.start() - def _update_data(self, data_raw: list): + def _update_data(self, data_raw: list) -> None: + """ + Updates the data to plot based on the client data + + Parameters + ---------- + data_raw: list + The raw data from the client + """ + header = [int(v) for v in data_raw[0].decode().split(",")] data = data_raw[1] @@ -251,6 +351,25 @@ def _update_data(self, data_raw: list): class OnlineCallbackServer(OnlineCallbackAbstract): def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str = None, port: int = None): + """ + Initializes the client. This is not supposed to be called directly by the user, but by the solver. During the + initialization, we need to perform some tasks that are not possible to do in server side. Then the results of + these initialization are passed to the server + + Parameters + ---------- + ocp: OptimalControlProgram + The ocp + opts: dict + The options for the solver + show_options: dict + The options for the plot + host: str + The host to connect to, by default "localhost" + port: int + The port to connect to, by default 3050 + """ + super().__init__(ocp, opts, show_options) self._host = host if host else _default_host @@ -270,15 +389,32 @@ def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str self._initialize_connexion(**show_options) - def _initialize_connexion(self, **show_options): + def _initialize_connexion(self, retries: int = 0, **show_options) -> None: + """ + Initializes the connexion to the server + + Parameters + ---------- + retries: int + The number of retries to connect to the server (retry 5 times with 1s sleep between each retry, then raises + an error if it still cannot connect) + show_options: dict + The options to pass to PlotOcp + """ + # Start the client try: self._socket.connect((self._host, self._port)) except ConnectionError: - raise RuntimeError( - "Could not connect to the plotter server, make sure it is running " - "by calling 'OnlineCallbackServer().start()' on another python instance)" - ) + if retries > 5: + raise RuntimeError( + "Could not connect to the plotter server, make sure it is running by calling 'PlottingServer()' on " + "another python instance or allowing for automatic start of the server by calling " + "'PlottingServer.as_multiprocess()' in the main script" + ) + else: + time.sleep(1) + return self._initialize_connexion(retries + 1, **show_options) ocp_plot = OcpSerializable.from_ocp(self.ocp).serialize() dummy_phase_times = OptimizationVectorHelper.extract_step_times(self.ocp, DM(np.ones(self.ocp.n_phases))) @@ -311,11 +447,31 @@ def _initialize_connexion(self, **show_options): self.ocp, only_initialize_variables=True, dummy_phase_times=dummy_phase_times, **show_options ) - def close(self): + def close(self) -> None: + """ + Closes the connexion + """ + self._socket.sendall(f"{_ServerMessages.CLOSE_CONNEXION.value}\nGoodbye from client!".encode()) self._socket.close() def eval(self, arg: list | tuple, force: bool = False) -> list: + """ + Sends the current data to the plotter, this method is automatically called by the solver + + Parameters + ---------- + arg: list | tuple + The current data + force: bool + If True, the client will block until the server is ready to receive new data. This is useful at the end of + the optimization to make sure the data are plot (and not discarded) + + Returns + ------- + A mandatory [0] to respect the CasADi callback signature + """ + if not force: self._socket.setblocking(False) diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index ee82bc50a..5ec68e39b 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -7,7 +7,7 @@ from bioptim.optimization.solution.solution import Solution from ..gui.online_callback_multiprocess import OnlineCallbackMultiprocess -from ..gui.online_callback_server import OnlineCallbackServer +from ..gui.online_callback_server import PlottingServer, OnlineCallbackServer from ..limits.path_conditions import Bounds from ..limits.penalty_helpers import PenaltyHelpers from ..misc.enums import InterpolationType, ShowOnlineType @@ -23,7 +23,12 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): ocp: OptimalControlProgram A reference to the current OptimalControlProgram show_options: dict - The options to pass to PlotOcp + The options to pass to PlotOcp, special options are: + - type: ShowOnlineType.MULTIPROCESS or ShowOnlineType.SERVER + - host: The host to connect to (only for ShowOnlineType.SERVER) + - port: The port to connect to (only for ShowOnlineType.SERVER) + - as_multiprocess: If the server should run as a multiprocess (only for ShowOnlineType.SERVER), if True, + a server is automatically started, if False, the user must start the server manually """ if show_options is None: show_options = {} @@ -45,10 +50,19 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): if "host" in show_options: host = show_options["host"] del show_options["host"] + port = None if "port" in show_options: port = show_options["port"] del show_options["port"] + + as_multiprocess = True + if "as_multiprocess" in show_options: + as_multiprocess = show_options["as_multiprocess"] + del show_options["as_multiprocess"] + if as_multiprocess: + PlottingServer.as_multiprocess(host=host, port=port) + interface.options_common["iteration_callback"] = OnlineCallbackServer( ocp, show_options=show_options, host=host, port=port ) diff --git a/resources/bioptim_plotting_server.py b/resources/bioptim_plotting_server.py deleted file mode 100644 index 1d752495d..000000000 --- a/resources/bioptim_plotting_server.py +++ /dev/null @@ -1,9 +0,0 @@ -from bioptim import OnlineCallbackServerBackend - - -def main(): - OnlineCallbackServerBackend().run() - - -if __name__ == "__main__": - main() diff --git a/resources/plotting_server.py b/resources/plotting_server.py new file mode 100644 index 000000000..6981b0b5d --- /dev/null +++ b/resources/plotting_server.py @@ -0,0 +1,19 @@ +""" +This file is an example of how to run a bioptim Online plotting server. That said, this is usually not the way to run a +bioptim server as it is easier to run it as an automatic multiprocess (default). This is achieved by setting +`show_options={"type": ShowOnlineType.SERVER, "as_multiprocess": True}` in the solver options. +If set to False, then the plotting server is mandatory. + +Since the server runs usings sockets, it is possible to run the server on a different machine than the one running the +optimization. This is useful when the optimization is run on a cluster and the plotting server is run on a local machine. +""" + +from bioptim import PlottingServer + + +def main(): + PlottingServer() + + +if __name__ == "__main__": + main() From 95588a97e26a599a8734f62c4bcbe914de9004ec Mon Sep 17 00:00:00 2001 From: Pariterre Date: Thu, 1 Aug 2024 17:26:02 -0400 Subject: [PATCH 13/17] Answered iPuch comments --- README.md | 26 ++++--- bioptim/__init__.py | 2 +- bioptim/examples/getting_started/pendulum.py | 14 +--- bioptim/gui/online_callback_abstract.py | 20 ++---- bioptim/gui/online_callback_multiprocess.py | 26 +++++-- .../online_callback_multiprocess_server.py | 29 ++++++++ bioptim/gui/online_callback_server.py | 71 ++++++++++++------- bioptim/gui/plot.py | 43 +++++------ bioptim/gui/serializable_class.py | 3 +- bioptim/interfaces/interface_utils.py | 29 ++++---- bioptim/interfaces/ipopt_options.py | 26 +++++-- bioptim/interfaces/sqp_options.py | 22 +++++- bioptim/misc/enums.py | 24 ++++--- .../optimization/optimal_control_program.py | 2 +- bioptim/optimization/optimization_vector.py | 2 +- .../receding_horizon_optimization.py | 6 +- resources/plotting_server.py | 2 +- tests/shard1/test_biorbd_model_holonomic.py | 2 +- tests/shard1/test_controltype_none.py | 4 +- tests/shard1/test_custom_model.py | 2 +- tests/shard2/test_global_sqp.py | 2 +- tests/shard3/test_get_time_solution.py | 2 +- ...t_global_torque_driven_with_contact_ocp.py | 2 +- tests/shard4/test_solution.py | 10 +-- tests/shard4/test_solver_options.py | 21 +++++- .../test_variational_integrator_examples.py | 4 +- .../test_global_stochastic_collocation.py | 6 +- ...st_global_stochastic_except_collocation.py | 6 +- 28 files changed, 256 insertions(+), 152 deletions(-) create mode 100644 bioptim/gui/online_callback_multiprocess_server.py diff --git a/README.md b/README.md index 857fcdbe0..50dc4c0b9 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ As a tour guide that uses this binder, you can watch the `bioptim` workshop that - [Solver](#enum-solver) - [ControlType](#enum-controltype) - [PlotType](#enum-plottype) - - [ShowOnlineType](#enum-showonlinetype) + - [OnlineOptim](#enum-onlineoptim) - [InterpolationType](#enum-interpolationtype) - [Shooting](#enum-shooting) - [CostType](#enum-costtype) @@ -790,16 +790,17 @@ The `Solver` class can be used to select the nonlinear solver to solve the ocp: Note that options can be passed to the solver parameter. One can refer to their respective solver's documentation to know which options exist. -The `show_online_optim` parameter can be set to `True` so the graphs nicely update during the optimization. -Please note that `ShowOnlineType.MULTIPROCESS` is not available on Windows. To see how to run the server on Windows, please refer to the `getting_started/pendulum.py` example. +The `show_online_optim` parameter can be set to `True` so the graphs nicely update during the optimization with the default values. +One can also directly declare `online_optim` as an `OnlineOptim` parameter to customize the behavior of the plotter. +Note that `show_online_optim` and `online_optim` are mutually exclusive. +Please also note that `OnlineOptim.MULTIPROCESS` is not available on Windows and only none of them are available on Macos. +To see how to run the server on Windows, please refer to the `getting_started/pendulum.py` example. It is expected to slow down the optimization a bit. `show_options` can be also passed as a dict to the plotter to customize the plotter's behavior. -The following keys are special options: - - `type`: the type of plotter to use (default is `ShowOnlineType.MULTIPROCESS`) - - If `type` is `ShowOnlineType.SERVER`, then these additional options are available: - - `as_multiprocess`: if the server should be run as a multiprocess (default is `True`). If `True`, a server is automatically started in a new process. If `False`, a server must be started manually by instantiating an `PlottingServer` class. - - `host`: the host to use (default is `localhost`), it must match the host used in the `PlottingServer` class if `as_multiprocess` is `False` - - `port`: the port to use (default is `5030`), it must match the port used in the `PlottingServer` class if `as_multiprocess` is `False` +If `online_optim` is set to `SERVER`, then a server must be started manually by instantiating an `PlottingServer` class (see `ressources/plotting_server.py`). +The following keys are additional options when using `OnlineOptim.SERVER` and `OnlineOptim.MULTIPROCESS_SERVER`: + - `host`: the host to use (default is `localhost`) + - `port`: the port to use (default is `5030`) Finally, one can save and load previously optimized values by using ```python @@ -843,7 +844,7 @@ the OCP, you can access them with `sol.parameters["qdot_start"]` and `sol.parame optimization. ### Class: PlottingServer -If one wants to use the `ShowOnlineType.SERVER` plotter, one can instantiate this class to start a server. +If one wants to use the `OnlineOptim.SERVER` plotter, one can instantiate this class to start a server. This is not mandatory as if `as_multiprocess` is set to `True` in the `show_options` dict [default behavior], this server is started automatically. The advantage of starting the server manually is that one can plot online graphs on a remote machine. An example of such a server is provided in `resources/plotting_server.py`. @@ -1685,12 +1686,15 @@ INTEGRATED: Plot that links the points within an interval but is discrete betwee STEP: Step plot, constant over an interval. POINT: Point plot. -### Enum: ShowOnlineType +### Enum: OnlineOptim The type of online plotter to use. The accepted values are: +NONE: No online plotter. +DEFAULT: Use the default online plotter depending on the OS (MULTIPROCESS on Linux, MULTIPROCESS_SERVER on Windows and NONE on MacOS). MULTIPROCESS: The online plotter is in a separate process. SERVER: The online plotter is in a separate server. +MULTIPROCESS_SERVER: The online plotter using the server automatically setup on a separate process. ### Enum: InterpolationType Defines wow a time-dependent variable is interpolated. diff --git a/bioptim/__init__.py b/bioptim/__init__.py index 76092bd87..d07d8da3c 100644 --- a/bioptim/__init__.py +++ b/bioptim/__init__.py @@ -208,7 +208,7 @@ MagnitudeType, MultiCyclicCycleSolutions, PhaseDynamics, - ShowOnlineType, + OnlineOptim, ) from .misc.mapping import BiMappingList, BiMapping, Mapping, NodeMapping, NodeMappingList, SelectionMapping, Dependency from .optimization.multi_start import MultiStart diff --git a/bioptim/examples/getting_started/pendulum.py b/bioptim/examples/getting_started/pendulum.py index 8325b9da7..7ce9ed46f 100644 --- a/bioptim/examples/getting_started/pendulum.py +++ b/bioptim/examples/getting_started/pendulum.py @@ -9,8 +9,6 @@ appreciate it). Finally, once it finished optimizing, it animates the model using the optimal solution """ -import platform - from bioptim import ( OptimalControlProgram, DynamicsFcn, @@ -26,7 +24,7 @@ BiorbdModel, ControlType, PhaseDynamics, - ShowOnlineType, + OnlineOptim, ) @@ -151,14 +149,8 @@ def main(): ocp.print(to_console=False, to_graph=False) # --- Solve the ocp --- # - if platform.system() == "Windows": - # The default online type (ShowOnlineType.MULTIPROCESS) is not available on Windows - # Since `as_multiprocess` default value is True, it will automatically starts the server in the background - solver = Solver.IPOPT(show_online_optim=True, show_options={"type": ShowOnlineType.SERVER}) - else: - solver = Solver.IPOPT(show_online_optim=True) - - sol = ocp.solve(solver) + # Default is OnlineOptim.MULTIPROCESS on Linux, OnlineOptim.MULTIPROCESS_SERVER on Windows and OnlineOptim.NONE on MacOS + sol = ocp.solve(Solver.IPOPT(show_online_optim=OnlineOptim.DEFAULT)) # --- Show the results graph --- # sol.print_cost() diff --git a/bioptim/gui/online_callback_abstract.py b/bioptim/gui/online_callback_abstract.py index a8a01e06f..678da5567 100644 --- a/bioptim/gui/online_callback_abstract.py +++ b/bioptim/gui/online_callback_abstract.py @@ -1,18 +1,6 @@ from abc import ABC, abstractmethod -from enum import Enum -import json -import logging -import multiprocessing as mp -import socket -import struct -import threading -from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity, DM -from matplotlib import pyplot as plt -import numpy as np - -from .plot import PlotOcp, OcpSerializable -from ..optimization.optimization_vector import OptimizationVectorHelper +from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity class OnlineCallbackAbstract(Callback, ABC): @@ -155,7 +143,7 @@ def get_sparsity_in(self, i: int) -> tuple: return Sparsity(0, 0) @abstractmethod - def eval(self, arg: list | tuple, force: bool = False) -> list: + def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: """ Send the current data to the plotter @@ -164,6 +152,10 @@ def eval(self, arg: list | tuple, force: bool = False) -> list: arg: list | tuple The data to send + enforce: bool + If True, the client will block until the server is ready to receive new data. This is useful at the end of + the optimization to make sure the data are plot (and not discarded) + Returns ------- A list of error index diff --git a/bioptim/gui/online_callback_multiprocess.py b/bioptim/gui/online_callback_multiprocess.py index d71210f39..935538725 100644 --- a/bioptim/gui/online_callback_multiprocess.py +++ b/bioptim/gui/online_callback_multiprocess.py @@ -35,7 +35,7 @@ def __init__(self, ocp, opts: dict = None, show_options: dict = None): def close(self): self.plot_process.kill() - def eval(self, arg: list | tuple, force: bool = False) -> list: + def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: # Dequeuing the data by removing previous not useful data while not self.queue.empty(): self.queue.get_nowait() @@ -52,12 +52,12 @@ class ProcessPlotter(object): Attributes ---------- - ocp: OptimalControlProgram + _ocp: OptimalControlProgram A reference to the ocp to show - pipe: mp.Queue - The multiprocessing queue to evaluate - plot: PlotOcp - The handler on all the figures + _plotter: PlotOcp + The plotter + _update_time: float + The time between each update Methods ------- @@ -96,6 +96,18 @@ def __call__(self, pipe: mp.Queue, show_options: dict | None): threading.Timer(self._update_time, self.plot_update).start() plt.show() + @property + def has_at_least_one_active_figure(self) -> bool: + """ + If at least one figure is active + + Returns + ------- + If at least one figure is active + """ + + return [plt.fignum_exists(fig.number) for fig in self._plotter.all_figures].count(True) > 0 + def plot_update(self) -> bool: """ The callback to update the graphs @@ -115,7 +127,7 @@ def plot_update(self) -> bool: # We want to redraw here to actually consume a bit of time, otherwise it goes to fast and pipe remains empty for fig in self._plotter.all_figures: fig.canvas.draw() - if [plt.fignum_exists(fig.number) for fig in self._plotter.all_figures].count(True) > 0: + if self.has_at_least_one_active_figure: # If there are still figures, we keep updating threading.Timer(self._update_time, self.plot_update).start() diff --git a/bioptim/gui/online_callback_multiprocess_server.py b/bioptim/gui/online_callback_multiprocess_server.py new file mode 100644 index 000000000..d352652e4 --- /dev/null +++ b/bioptim/gui/online_callback_multiprocess_server.py @@ -0,0 +1,29 @@ +from multiprocessing import Process + +from .online_callback_server import PlottingServer + + +def _start_as_multiprocess_internal(*args, **kwargs): + """ + Starts the server (necessary for multiprocessing), this method should not be called directly, apart from + run_as_multiprocess + + Parameters + ---------- + same as PlottingServer + """ + PlottingServer(*args, **kwargs) + + +class PlottingMultiprocessServer(PlottingServer): + def __init__(self, *args, **kwargs): + """ + Starts the server in a new process + + Parameters + ---------- + Same as PlottingServer + """ + + process = Process(target=_start_as_multiprocess_internal, args=args, kwargs=kwargs) + process.start() diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index ee8214f7e..35edbb02a 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import IntEnum, auto import json import logging import socket @@ -15,8 +15,8 @@ from ..optimization.optimization_vector import OptimizationVectorHelper -_default_host = "localhost" -_default_port = 3050 +_DEFAULT_HOST = "localhost" +_DEFAULT_PORT = 3050 def _serialize_show_options(show_options: dict) -> bytes: @@ -39,13 +39,13 @@ def _start_as_multiprocess_internal(*args, **kwargs): PlottingServer(*args, **kwargs) -class _ServerMessages(Enum): - INITIATE_CONNEXION = 0 - NEW_DATA = 1 - CLOSE_CONNEXION = 2 - EMPTY = 3 - TOO_SOON = 4 - UNKNOWN = 5 +class _ServerMessages(IntEnum): + INITIATE_CONNEXION = auto() + NEW_DATA = auto() + CLOSE_CONNEXION = auto() + EMPTY = auto() + TOO_SOON = auto() + UNKNOWN = auto() class PlottingServer: @@ -64,10 +64,11 @@ def __init__(self, host: str = None, port: int = None): self._prepare_logger() self._get_data_interval = 1.0 self._update_plot_interval = 0.01 + self._is_drawing = False # Define the host and port - self._host = host if host else _default_host - self._port = port if port else _default_port + self._host = host if host else _DEFAULT_HOST + self._port = port if port else _DEFAULT_PORT self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._plotter: PlotOcp = None @@ -107,7 +108,7 @@ def as_multiprocess(*args, **kwargs) -> None: def _run(self) -> None: """ - Starts the server, this method can be called directly by the user to start a plot server + Starts the server, this method is blocking """ # Start listening to the server self._socket.bind((self._host, self._port)) @@ -244,16 +245,31 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No threading.Timer(self._update_plot_interval, self._redraw).start() plt.show() + @property + def has_at_least_one_active_figure(self) -> bool: + """ + If at least one figure is active + + Returns + ------- + If at least one figure is active + """ + + return [plt.fignum_exists(fig.number) for fig in self._plotter.all_figures].count(True) > 0 + def _redraw(self) -> None: """ Redraws the plot, this method is called periodically as long as at least one figure is open """ self._logger.debug("Updating plot") + self._is_drawing = True for _, fig in enumerate(self._plotter.all_figures): fig.canvas.draw() + fig.canvas.flush_events() + self._is_drawing = False - if [plt.fignum_exists(fig.number) for fig in self._plotter.all_figures].count(True) > 0: + if self.has_at_least_one_active_figure: threading.Timer(self._update_plot_interval, self._redraw).start() else: self._logger.info("All figures have been closed, stop updating the plots") @@ -272,6 +288,9 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> None: self._logger.debug(f"Waiting for new data from client") try: + if self._is_drawing: + # Give it some time + time.sleep(self._update_plot_interval) client_socket.sendall("READY_FOR_NEXT_DATA".encode()) except: self._logger.warning("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") @@ -288,26 +307,26 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> None: client_socket.close() return - elif message_type == _ServerMessages.EMPTY or message_type == _ServerMessages.CLOSE_CONNEXION: + elif message_type in (_ServerMessages.EMPTY, _ServerMessages.CLOSE_CONNEXION): self._logger.debug("Received empty data from client (end of stream), closing connexion") if should_continue: timer_get_data = threading.Timer(self._get_data_interval, self._wait_for_new_data, (client_socket,)) timer_get_data.start() - def _update_data(self, data_raw: list) -> None: + def _update_data(self, serialized_raw_data: list) -> None: """ - Updates the data to plot based on the client data + This method parses the data from the client Parameters ---------- - data_raw: list - The raw data from the client + serialized_raw_data: list + The serialized raw data from the client, see `xydata_encoding` below """ - header = [int(v) for v in data_raw[0].decode().split(",")] + header = [int(v) for v in serialized_raw_data[0].decode().split(",")] - data = data_raw[1] + data = serialized_raw_data[1] all_data = np.array(struct.unpack("d" * (len(data) // 8), data)) header_cmp = 0 @@ -372,8 +391,8 @@ def __init__(self, ocp, opts: dict = None, show_options: dict = None, host: str super().__init__(ocp, opts, show_options) - self._host = host if host else _default_host - self._port = port if port else _default_port + self._host = host if host else _DEFAULT_HOST + self._port = port if port else _DEFAULT_PORT self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if self.ocp.plot_ipopt_outputs: @@ -455,7 +474,7 @@ def close(self) -> None: self._socket.sendall(f"{_ServerMessages.CLOSE_CONNEXION.value}\nGoodbye from client!".encode()) self._socket.close() - def eval(self, arg: list | tuple, force: bool = False) -> list: + def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: """ Sends the current data to the plotter, this method is automatically called by the solver @@ -463,7 +482,7 @@ def eval(self, arg: list | tuple, force: bool = False) -> list: ---------- arg: list | tuple The current data - force: bool + enforce: bool If True, the client will block until the server is ready to receive new data. This is useful at the end of the optimization to make sure the data are plot (and not discarded) @@ -472,7 +491,7 @@ def eval(self, arg: list | tuple, force: bool = False) -> list: A mandatory [0] to respect the CasADi callback signature """ - if not force: + if not enforce: self._socket.setblocking(False) try: diff --git a/bioptim/gui/plot.py b/bioptim/gui/plot.py index f04e07377..f8736e899 100644 --- a/bioptim/gui/plot.py +++ b/bioptim/gui/plot.py @@ -275,23 +275,7 @@ def __init__( self.shooting_type = shooting_type if not only_initialize_variables: - horz = 0 - vert = 1 if len(self.all_figures) < self.n_vertical_windows * self.n_horizontal_windows else 0 - for i, fig in enumerate(self.all_figures): - if self.automatically_organize: - try: - fig.canvas.manager.window.move( - int(vert * self.width_step), int(self.top_margin + horz * self.height_step) - ) - vert += 1 - if vert >= self.n_vertical_windows: - horz += 1 - vert = 0 - except AttributeError: - pass - fig.canvas.draw() - if self.plot_options["general_options"]["use_tight_layout"]: - fig.tight_layout() + self._spread_figures_on_screen() if self.ocp.plot_ipopt_outputs: from ..gui.ipopt_output_plot import create_ipopt_output_plot @@ -688,6 +672,25 @@ def _organize_windows(self, n_windows: int): self.height_step = (height - self.top_margin) / self.n_horizontal_windows self.width_step = width / self.n_vertical_windows + def _spread_figures_on_screen(self): + horz = 0 + vert = 1 if len(self.all_figures) < self.n_vertical_windows * self.n_horizontal_windows else 0 + for i, fig in enumerate(self.all_figures): + if self.automatically_organize: + try: + fig.canvas.manager.window.move( + int(vert * self.width_step), int(self.top_margin + horz * self.height_step) + ) + vert += 1 + if vert >= self.n_vertical_windows: + horz += 1 + vert = 0 + except AttributeError: + pass + fig.canvas.draw() + if self.plot_options["general_options"]["use_tight_layout"]: + fig.tight_layout() + def find_phases_intersections(self): """ Finds the intersection between the phases @@ -783,16 +786,16 @@ def parse_data(self, **args) -> tuple[list, list]: def update_data( self, - xdata: dict, + xdata: list, ydata: list, **args: dict, ): """ - Update ydata from the variable a solution structure + Update xdata and ydata. The input are the output of the parse_data method Parameters ---------- - xdata: dict + xdata: list The time vector ydata: list The actual current data to be plotted diff --git a/bioptim/gui/serializable_class.py b/bioptim/gui/serializable_class.py index 408b134f8..c148ab1d8 100644 --- a/bioptim/gui/serializable_class.py +++ b/bioptim/gui/serializable_class.py @@ -400,6 +400,7 @@ def deserialize(cls, data): class OdeSolverSerializable: + # TODO There are probably more parameters to serialize here, if the GUI fails, this is probably the reason polynomial_degree: int type: OdeSolver @@ -414,7 +415,7 @@ def from_ode_solver(cls, ode_solver): ode_solver: OdeSolver = ode_solver return cls( - polynomial_degree=5, + polynomial_degree=ode_solver.polynomial_degree, type="ode", ) diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index 5ec68e39b..fcb42111c 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -10,7 +10,7 @@ from ..gui.online_callback_server import PlottingServer, OnlineCallbackServer from ..limits.path_conditions import Bounds from ..limits.penalty_helpers import PenaltyHelpers -from ..misc.enums import InterpolationType, ShowOnlineType +from ..misc.enums import InterpolationType, OnlineOptim from ..optimization.non_linear_program import NonLinearProgram @@ -23,29 +23,27 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): ocp: OptimalControlProgram A reference to the current OptimalControlProgram show_options: dict - The options to pass to PlotOcp, special options are: - - type: ShowOnlineType.MULTIPROCESS or ShowOnlineType.SERVER - - host: The host to connect to (only for ShowOnlineType.SERVER) - - port: The port to connect to (only for ShowOnlineType.SERVER) - - as_multiprocess: If the server should run as a multiprocess (only for ShowOnlineType.SERVER), if True, - a server is automatically started, if False, the user must start the server manually + The options to pass to PlotOcp, if online_optim is OnlineOptim.SERVER or OnlineOptim.MULTIPROCESS_SERVER there are + additional options: + - host: The host to connect to (only for OnlineOptim.SERVER) + - port: The port to connect to (only for OnlineOptim.SERVER) """ if show_options is None: show_options = {} - show_type = ShowOnlineType.MULTIPROCESS + show_type = OnlineOptim.MULTIPROCESS if "type" in show_options: show_type = show_options["type"] del show_options["type"] - if show_type == ShowOnlineType.MULTIPROCESS: - if platform == "win32": + if show_type == OnlineOptim.MULTIPROCESS: + if platform != "linux": raise RuntimeError( - "Online ShowOnlineType.MULTIPROCESS is not supported on Windows. " - "You can add show_options={'type': ShowOnlineType.TCP} to the Solver declaration" + "Online OnlineOptim.MULTIPROCESS is not supported on Windows or MacOS. " + "You can use online_optim=OnlineOptim.MULTIPROCESS_SERVER to the Solver declaration on Windows though" ) interface.options_common["iteration_callback"] = OnlineCallbackMultiprocess(ocp, show_options=show_options) - elif show_type == ShowOnlineType.SERVER: + elif show_type == OnlineOptim.SERVER: host = None if "host" in show_options: host = show_options["host"] @@ -56,6 +54,7 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): port = show_options["port"] del show_options["port"] + # TODO HERE! as_multiprocess = True if "as_multiprocess" in show_options: as_multiprocess = show_options["as_multiprocess"] @@ -96,7 +95,7 @@ def generic_solve(interface, expand_during_shake_tree=False) -> dict: all_g, all_g_bounds = interface.dispatch_bounds() all_g = _shake_tree_for_penalties(interface.ocp, all_g, v, v_bounds, expand_during_shake_tree) - if interface.opts.show_online_optim: + if interface.opts.online_optim is not OnlineOptim.NONE: interface.online_optim(interface.ocp, interface.opts.show_options) # Thread here on (f and all_g) instead of individually for each function? @@ -151,7 +150,7 @@ def generic_solve(interface, expand_during_shake_tree=False) -> dict: interface.out["sol"]["lam_g"], interface.out["sol"]["lam_p"], ] - interface.options_common["iteration_callback"].eval(to_eval, force=True) + interface.options_common["iteration_callback"].eval(to_eval, enforce=True) return interface.out diff --git a/bioptim/interfaces/ipopt_options.py b/bioptim/interfaces/ipopt_options.py index 4a4fcbef3..de4fc5f7a 100644 --- a/bioptim/interfaces/ipopt_options.py +++ b/bioptim/interfaces/ipopt_options.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from ..misc.enums import SolverType +from ..misc.enums import SolverType, OnlineOptim from .abstract_options import GenericSolver @@ -11,8 +11,13 @@ class IPOPT(GenericSolver): Attributes ---------- - show_online_optim: bool - If the plot should be shown while optimizing. It will slow down the optimization a bit + show_online_optim: bool | None + If the plot should be shown while optimizing. If set to True, it will the default online_optim. online_optim + and show_online_optim cannot be simultaneous set + online_optim: OnlineOptim + The type of online plot to show. If set to None (default), then no plot will be shown. If set to DEFAULT, it + will use the fastest method for your OS (multiprocessing on Linux and multiprocessing_server on Windows). + In all cases, it will slow down the optimization a bit. show_options: dict The graphs option to pass to PlotOcp _tol: float @@ -68,7 +73,8 @@ class IPOPT(GenericSolver): """ type: SolverType = SolverType.IPOPT - show_online_optim: bool = False + show_online_optim: bool | None = None + online_optim: OnlineOptim = OnlineOptim.NONE show_options: dict = None _tol: float = 1e-6 # default in ipopt 1e-8 _dual_inf_tol: float = 1.0 @@ -96,6 +102,16 @@ class IPOPT(GenericSolver): _c_compile: bool = False _check_derivatives_for_naninf: str = "no" # "yes" + def __attrs_post_init__(self): + if self.show_online_optim and self.online_optim != OnlineOptim.NONE: + raise ValueError("show_online_optim and online_optim cannot be simultaneous set") + + if self.show_online_optim is not None: + if self.show_online_optim: + self.online_optim = OnlineOptim.DEFAULT + else: + self.online_optim = OnlineOptim.NONE + @property def tol(self): return self._tol @@ -323,7 +339,7 @@ def set_option_unsafe(self, val, name): def as_dict(self, solver): solver_options = self.__dict__ options = {} - non_python_options = ["_c_compile", "type", "show_online_optim", "show_options"] + non_python_options = ["_c_compile", "type", "show_online_optim", "online_optim", "show_options"] for key in solver_options: if key not in non_python_options: ipopt_key = "ipopt." + key[1:] diff --git a/bioptim/interfaces/sqp_options.py b/bioptim/interfaces/sqp_options.py index ffc5ab8b5..c5f8ac5f4 100644 --- a/bioptim/interfaces/sqp_options.py +++ b/bioptim/interfaces/sqp_options.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from ..misc.enums import SolverType +from ..misc.enums import SolverType, OnlineOptim from .abstract_options import GenericSolver @@ -11,6 +11,13 @@ class SQP_METHOD(GenericSolver): Methods ------- + show_online_optim: bool | None + If the plot should be shown while optimizing. If set to True, it will the default online_optim. online_optim + and show_online_optim cannot be simultaneous set + online_optim: OnlineOptim + The type of online plot to show. If set to None (default), then no plot will be shown. If set to DEFAULT, it + will use the fastest method for your OS (multiprocessing on Linux and multiprocessing_server on Windows). + In all cases, it will slow down the optimization a bit. set_beta(beta: float): Line-search parameter, restoration factor of stepsize set_c1(c1: float): @@ -66,7 +73,8 @@ class SQP_METHOD(GenericSolver): """ type: SolverType = SolverType.SQP - show_online_optim: bool = False + show_online_optim: bool | None = None + online_optim: OnlineOptim = OnlineOptim.NONE show_options: dict = None c_compile = False _beta: float = 0.8 @@ -82,6 +90,16 @@ class SQP_METHOD(GenericSolver): _tol_du: float = 1e-6 _tol_pr: float = 1e-6 + def __attrs_post_init__(self): + if self.show_online_optim and self.online_optim != OnlineOptim.NONE: + raise ValueError("show_online_optim and online_optim cannot be simultaneous set") + + if self.show_online_optim is not None: + if self.show_online_optim: + self.online_optim = OnlineOptim.DEFAULT + else: + self.online_optim = OnlineOptim.NONE + @property def beta(self): return self._beta diff --git a/bioptim/misc/enums.py b/bioptim/misc/enums.py index e1fde3cb0..90c72da22 100644 --- a/bioptim/misc/enums.py +++ b/bioptim/misc/enums.py @@ -1,4 +1,4 @@ -from enum import Enum, IntEnum +from enum import Enum, IntEnum, auto class PhaseDynamics(Enum): @@ -93,20 +93,24 @@ class PlotType(Enum): POINT = 3 # Point plot -class ShowOnlineType(Enum): +class OnlineOptim(Enum): """ The type of callback Attributes ---------- - MULTIPROCESS: int - Using multiprocessing - SERVER: int - Using a server to communicate with the client - """ - - MULTIPROCESS = 0 - SERVER = 1 + NONE: No online plotting + DEFAULT: Default online plotting (MULTIPROCESS on Linux, MULTIPROCESS_SERVER on Windows and NONE on MacOS) + MULTIPROCESS: Multiprocess online plotting + SERVER: Server online plotting + MULTIPROCESS_SERVER: Multiprocess server online plotting + """ + + NONE = auto() + DEFAULT = auto() + MULTIPROCESS = auto() + SERVER = auto() + MULTIPROCESS_SERVER = auto() class ControlType(Enum): diff --git a/bioptim/optimization/optimal_control_program.py b/bioptim/optimization/optimal_control_program.py index cc5468e2d..b72acb03d 100644 --- a/bioptim/optimization/optimal_control_program.py +++ b/bioptim/optimization/optimal_control_program.py @@ -121,7 +121,7 @@ class OptimalControlProgram: prepare_plots(self, automatically_organize: bool, show_bounds: bool, shooting_type: Shooting) -> PlotOCP Create all the plots associated with the OCP - solve(self, solver: Solver, show_online_optim: bool, solver_options: dict) -> Solution + solve(self, solver: Solver) -> Solution Call the solver to actually solve the ocp _define_time(self, phase_time: float | tuple, objective_functions: ObjectiveList, constraints: ConstraintList) Declare the phase_time vector in v. If objective_functions or constraints defined a time optimization, diff --git a/bioptim/optimization/optimization_vector.py b/bioptim/optimization/optimization_vector.py index d109be9a7..cd2a1a7e7 100644 --- a/bioptim/optimization/optimization_vector.py +++ b/bioptim/optimization/optimization_vector.py @@ -356,7 +356,7 @@ def extract_step_times(ocp, data: np.ndarray | DM) -> list: ocp: OptimalControlProgram A reference to the ocp data: np.ndarray | DM - The solution in a vector, if no data is provided, dummy data is used (it can be useful getting the dimensions) + The solution in a vector Returns ------- diff --git a/bioptim/optimization/receding_horizon_optimization.py b/bioptim/optimization/receding_horizon_optimization.py index a404d7bc0..2087f903a 100644 --- a/bioptim/optimization/receding_horizon_optimization.py +++ b/bioptim/optimization/receding_horizon_optimization.py @@ -12,7 +12,7 @@ from ..limits.constraints import ConstraintFcn, ConstraintList from ..limits.objective_functions import ObjectiveFcn, ObjectiveList from ..limits.path_conditions import InitialGuessList -from ..misc.enums import SolverType, InterpolationType, MultiCyclicCycleSolutions, ControlType +from ..misc.enums import SolverType, InterpolationType, MultiCyclicCycleSolutions, ControlType, OnlineOptim from ..interfaces import Solver from ..interfaces.abstract_options import GenericSolver from ..models.protocols.biomodel import BioModel @@ -26,7 +26,7 @@ class RecedingHorizonOptimization(OptimalControlProgram): Methods ------- - solve(self, solver: Solver, show_online_optim: bool, solver_options: dict) -> Solution + solve(self, solver: Solver) -> Solution Call the solver to actually solve the ocp """ @@ -174,7 +174,7 @@ def solve( f"Only {solver_current.get_tolerance_keys()} can be modified." ) if solver_current.type == SolverType.IPOPT: - solver_current.show_online_optim = False + solver_current.online_optim = OnlineOptim.NONE warm_start = None total_time += sol.real_time_to_optimize diff --git a/resources/plotting_server.py b/resources/plotting_server.py index 6981b0b5d..3ee6eb2da 100644 --- a/resources/plotting_server.py +++ b/resources/plotting_server.py @@ -1,7 +1,7 @@ """ This file is an example of how to run a bioptim Online plotting server. That said, this is usually not the way to run a bioptim server as it is easier to run it as an automatic multiprocess (default). This is achieved by setting -`show_options={"type": ShowOnlineType.SERVER, "as_multiprocess": True}` in the solver options. +`show_options={"type": OnlineOptim.SERVER, "as_multiprocess": True}` in the solver options. If set to False, then the plotting server is mandatory. Since the server runs usings sockets, it is possible to run the server on a different machine than the one running the diff --git a/tests/shard1/test_biorbd_model_holonomic.py b/tests/shard1/test_biorbd_model_holonomic.py index f0a93b766..a6d672faf 100644 --- a/tests/shard1/test_biorbd_model_holonomic.py +++ b/tests/shard1/test_biorbd_model_holonomic.py @@ -183,7 +183,7 @@ def test_example_two_pendulums(): ) # --- Solve the ocp --- # - sol = ocp.solve(Solver.IPOPT(show_online_optim=False)) + sol = ocp.solve(Solver.IPOPT()) states = sol.decision_states(to_merge=SolutionMerge.NODES) npt.assert_almost_equal( diff --git a/tests/shard1/test_controltype_none.py b/tests/shard1/test_controltype_none.py index 4fa8bedd3..0afc6a40d 100644 --- a/tests/shard1/test_controltype_none.py +++ b/tests/shard1/test_controltype_none.py @@ -253,9 +253,7 @@ def test_main_control_type_none(use_sx, phase_dynamics): ) # --- Solve the program --- # - sol = ocp.solve( - Solver.IPOPT(show_online_optim=False), - ) + sol = ocp.solve(Solver.IPOPT()) # Check objective function value f = np.array(sol.cost) diff --git a/tests/shard1/test_custom_model.py b/tests/shard1/test_custom_model.py index bf2590f0c..c0fa9f672 100644 --- a/tests/shard1/test_custom_model.py +++ b/tests/shard1/test_custom_model.py @@ -28,7 +28,7 @@ def test_custom_model(phase_dynamics): npt.assert_almost_equal(ocp.nlp[0].model.mass, 1) assert ocp.nlp[0].model.name_dof == ["rotx"] - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(2) sol = ocp.solve(solver=solver) diff --git a/tests/shard2/test_global_sqp.py b/tests/shard2/test_global_sqp.py index cc1956fbf..b1bba2cc6 100644 --- a/tests/shard2/test_global_sqp.py +++ b/tests/shard2/test_global_sqp.py @@ -23,7 +23,7 @@ def test_pendulum(phase_dynamics): expand_dynamics=True, ) - solver = Solver.SQP_METHOD(show_online_optim=False) + solver = Solver.SQP_METHOD() solver.set_tol_du(1e-1) solver.set_tol_pr(1e-1) solver.set_max_iter_ls(1) diff --git a/tests/shard3/test_get_time_solution.py b/tests/shard3/test_get_time_solution.py index 163f6fc58..abb46c06c 100644 --- a/tests/shard3/test_get_time_solution.py +++ b/tests/shard3/test_get_time_solution.py @@ -62,7 +62,7 @@ def _get_solution( return None ocp = ocp_module.prepare_ocp(**prepare_args) - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(0) solver.set_print_level(0) diff --git a/tests/shard3/test_global_torque_driven_with_contact_ocp.py b/tests/shard3/test_global_torque_driven_with_contact_ocp.py index 8118fa52e..00cae2b89 100644 --- a/tests/shard3/test_global_torque_driven_with_contact_ocp.py +++ b/tests/shard3/test_global_torque_driven_with_contact_ocp.py @@ -161,7 +161,7 @@ def test_maximize_predicted_height_CoM_rigidbody_dynamics(rigidbody_dynamics, ph phase_dynamics=phase_dynamics, expand_dynamics=True, ) - sol_opt = Solver.IPOPT(show_online_optim=False) + sol_opt = Solver.IPOPT() sol_opt.set_maximum_iterations(1) sol = ocp.solve(sol_opt) diff --git a/tests/shard4/test_solution.py b/tests/shard4/test_solution.py index 6b5c8965a..b618861d2 100644 --- a/tests/shard4/test_solution.py +++ b/tests/shard4/test_solution.py @@ -22,7 +22,7 @@ def test_time(ode_solver, phase_dynamics): phase_dynamics=phase_dynamics, expand_dynamics=True, ) - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(0) solver.set_print_level(0) @@ -58,7 +58,7 @@ def test_time_multiphase(ode_solver, phase_dynamics, continuous): expand_dynamics=True, ) - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(0) solver.set_print_level(0) @@ -136,7 +136,7 @@ def test_generate_stepwise_time(ode_solver, merge_phase, phase_dynamics, continu expand_dynamics=True, ) - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(0) solver.set_print_level(0) @@ -232,7 +232,7 @@ def test_generate_decision_time(ode_solver, merge_phase, phase_dynamics, continu expand_dynamics=True, ) - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(0) solver.set_print_level(0) @@ -346,7 +346,7 @@ def test_generate_integrate(ode_solver, merge_phase, shooting_type, integrator, expand_dynamics=True, ) - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(100) solver.set_print_level(0) sol = ocp.solve(solver=solver) diff --git a/tests/shard4/test_solver_options.py b/tests/shard4/test_solver_options.py index b477faffc..04eb5364a 100644 --- a/tests/shard4/test_solver_options.py +++ b/tests/shard4/test_solver_options.py @@ -1,5 +1,7 @@ +import pytest + from bioptim import Solver -from bioptim.misc.enums import SolverType +from bioptim.misc.enums import SolverType, OnlineOptim class FakeSolver: @@ -13,7 +15,8 @@ def __init__( def test_ipopt_solver_options(): solver = Solver.IPOPT() assert solver.type == SolverType.IPOPT - assert solver.show_online_optim is False + assert solver.show_online_optim is None + assert solver.online_optim is OnlineOptim.NONE assert solver.show_options is None assert solver.tol == 1e-6 assert solver.dual_inf_tol == 1.0 @@ -125,7 +128,21 @@ def test_ipopt_solver_options(): assert not "_c_compile" in solver_dict assert not "type" in solver_dict assert not "show_online_optim" in solver_dict + assert not "online_optim" in solver_dict assert not "show_options" in solver_dict solver.set_nlp_scaling_method("gradient-fiesta") assert solver.nlp_scaling_method == "gradient-fiesta" + + +def test_ipopt_solver_options_wrong(): + + solver = Solver.IPOPT() + solver.show_online_optim = True + with pytest.raises(ValueError, match="show_online_optim and online_optim cannot be simultaneous set"): + solver.online_optim = OnlineOptim.SERVER + + solver.show_online_optim = None + solver.online_optim = OnlineOptim.SERVER + with pytest.raises(ValueError, match="show_online_optim and online_optim cannot be simultaneous set"): + solver.show_online_optim = True diff --git a/tests/shard4/test_variational_integrator_examples.py b/tests/shard4/test_variational_integrator_examples.py index c2e20c5cd..593c81db4 100644 --- a/tests/shard4/test_variational_integrator_examples.py +++ b/tests/shard4/test_variational_integrator_examples.py @@ -25,7 +25,7 @@ def test_variational_pendulum(use_sx): ) # --- Solve the ocp --- # - sol = ocp.solve(Solver.IPOPT(show_online_optim=False)) + sol = ocp.solve(Solver.IPOPT()) states = sol.decision_states(to_merge=SolutionMerge.NODES) controls = sol.decision_controls(to_merge=SolutionMerge.NODES) @@ -59,7 +59,7 @@ def test_variational_pendulum_with_holonomic_constraints(use_sx): ) # --- Solve the ocp --- # - sol = ocp.solve(Solver.IPOPT(show_online_optim=False)) + sol = ocp.solve(Solver.IPOPT()) states = sol.decision_states(to_merge=SolutionMerge.NODES) controls = sol.decision_controls(to_merge=SolutionMerge.NODES) diff --git a/tests/shard5/test_global_stochastic_collocation.py b/tests/shard5/test_global_stochastic_collocation.py index 762339ee2..2f2f890d2 100644 --- a/tests/shard5/test_global_stochastic_collocation.py +++ b/tests/shard5/test_global_stochastic_collocation.py @@ -38,7 +38,7 @@ def test_arm_reaching_torque_driven_collocations(use_sx: bool): ) # Solver parameters - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_nlp_scaling_method("none") sol = ocp.solve(solver) @@ -107,7 +107,7 @@ def test_arm_reaching_torque_driven_collocations(use_sx: bool): ) # Solver parameters - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_nlp_scaling_method("none") solver.set_maximum_iterations(0) solver.set_bound_frac(1e-8) @@ -423,7 +423,7 @@ def test_obstacle_avoidance_direct_collocation(use_sx: bool): ) # Solver parameters - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(4) sol = ocp.solve(solver) diff --git a/tests/shard6/test_global_stochastic_except_collocation.py b/tests/shard6/test_global_stochastic_except_collocation.py index 634a6c6fa..98c3fae01 100644 --- a/tests/shard6/test_global_stochastic_except_collocation.py +++ b/tests/shard6/test_global_stochastic_except_collocation.py @@ -56,7 +56,7 @@ def test_arm_reaching_muscle_driven(use_sx): # ocp.print(to_console=True, to_graph=False) #TODO: check to adjust the print method # Solver parameters - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(4) solver.set_nlp_scaling_method("none") @@ -291,7 +291,7 @@ def test_arm_reaching_torque_driven_explicit(use_sx): ) # Solver parameters - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(4) solver.set_nlp_scaling_method("none") @@ -444,7 +444,7 @@ def test_arm_reaching_torque_driven_implicit(with_cholesky, with_scaling, use_sx ) # Solver parameters - solver = Solver.IPOPT(show_online_optim=False) + solver = Solver.IPOPT() solver.set_maximum_iterations(4) solver.set_nlp_scaling_method("none") From 6aef860ba590baf1f24ac9b96a502d3e4f8e9334 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Fri, 2 Aug 2024 10:03:14 -0400 Subject: [PATCH 14/17] Finished porting OnlineOptim.MULTIPROCESS_SERVER to a dedicated class --- .../online_callback_multiprocess_server.py | 15 +- bioptim/gui/online_callback_server.py | 305 +++++++++++------- bioptim/gui/serializable_class.py | 11 +- bioptim/interfaces/interface_utils.py | 29 +- bioptim/interfaces/sqp_options.py | 2 +- 5 files changed, 228 insertions(+), 134 deletions(-) diff --git a/bioptim/gui/online_callback_multiprocess_server.py b/bioptim/gui/online_callback_multiprocess_server.py index d352652e4..9793fef94 100644 --- a/bioptim/gui/online_callback_multiprocess_server.py +++ b/bioptim/gui/online_callback_multiprocess_server.py @@ -1,9 +1,9 @@ from multiprocessing import Process -from .online_callback_server import PlottingServer +from .online_callback_server import PlottingServer, OnlineCallbackServer -def _start_as_multiprocess_internal(*args, **kwargs): +def _start_as_multiprocess_internal(**kwargs): """ Starts the server (necessary for multiprocessing), this method should not be called directly, apart from run_as_multiprocess @@ -12,10 +12,10 @@ def _start_as_multiprocess_internal(*args, **kwargs): ---------- same as PlottingServer """ - PlottingServer(*args, **kwargs) + PlottingServer(**kwargs) -class PlottingMultiprocessServer(PlottingServer): +class PlottingMultiprocessServer(OnlineCallbackServer): def __init__(self, *args, **kwargs): """ Starts the server in a new process @@ -24,6 +24,9 @@ def __init__(self, *args, **kwargs): ---------- Same as PlottingServer """ - - process = Process(target=_start_as_multiprocess_internal, args=args, kwargs=kwargs) + host = kwargs["host"] if "host" in kwargs else None + port = kwargs["port"] if "port" in kwargs else None + process = Process(target=_start_as_multiprocess_internal, kwargs={"host": host, "port": port}) process.start() + + super(PlottingMultiprocessServer, self).__init__(*args, **kwargs) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 35edbb02a..375bd7730 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -92,20 +92,6 @@ def _prepare_logger(self) -> None: self._logger.addHandler(console_handler) self._logger.setLevel(logging.INFO) - @staticmethod - def as_multiprocess(*args, **kwargs) -> None: - """ - Starts the server in a new process, this method can be called directly by the user - - Parameters - ---------- - same as PlottingServer - """ - from multiprocessing import Process - - thread = Process(target=_start_as_multiprocess_internal, args=args, kwargs=kwargs) - thread.start() - def _run(self) -> None: """ Starts the server, this method is blocking @@ -126,10 +112,51 @@ def _run(self) -> None: finally: self._socket.close() - def _wait_for_data(self, client_socket: socket.socket, send_confirmation: bool) -> tuple[_ServerMessages, list]: + def _wait_for_new_connexion(self, client_socket: socket.socket) -> None: + """ + Waits for a new connexion + + Parameters + ---------- + client_socket: socket.socket + The client socket + """ + + message_type, data = self._recv_data(client_socket=client_socket, send_confirmation=True) + if message_type == _ServerMessages.INITIATE_CONNEXION: + self._logger.debug(f"Received hand shake from client") + self._initialize_plotter(client_socket, data) + + def _recv_data(self, client_socket: socket.socket, send_confirmation: bool) -> tuple[_ServerMessages, list]: """ Waits for data from the client + Parameters + ---------- + client_socket: socket.socket + The client socket + send_confirmation: bool + If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will + not send anything. This is part of the communication protocol + + Returns + ------- + The message type and the data + """ + self._logger.debug("Waiting for data from client") + message_type, data_len = self._recv_message_type_and_data_len(client_socket, send_confirmation) + if data_len is None: + return message_type, None + + data = self._recv_serialize_data(client_socket, send_confirmation, data_len) + return message_type, data + + def _recv_message_type_and_data_len( + self, client_socket: socket.socket, send_confirmation: bool + ) -> tuple[_ServerMessages, list]: + """ + Waits for data len from the client (first part of the protocol) + Parameters ---------- client_socket: socket.socket @@ -145,7 +172,6 @@ def _wait_for_data(self, client_socket: socket.socket, send_confirmation: bool) # Receive the actual data try: - self._logger.debug("Waiting for data from client") data = client_socket.recv(1024) if not data: return _ServerMessages.EMPTY, None @@ -157,48 +183,71 @@ def _wait_for_data(self, client_socket: socket.socket, send_confirmation: bool) data_as_list = data.decode().split("\n") try: message_type = _ServerMessages(int(data_as_list[0])) - len_all_data = [int(len_data) for len_data in data_as_list[1][1:-1].split(",")] - # Sends confirmation and waits for the next message - if send_confirmation: - client_socket.sendall("OK".encode()) - self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)") - data_out = [] - for len_data in len_all_data: - data_out.append(client_socket.recv(len_data)) - if len(data_out[-1]) != len_data: - data_out[-1] += client_socket.recv(len_data - len(data_out[-1])) - if send_confirmation: - client_socket.sendall("OK".encode()) except ValueError: self._logger.warning("Unknown message type received") - message_type = _ServerMessages.UNKNOWN # Sends failure if send_confirmation: client_socket.sendall("NOK".encode()) - data_out = [] + return _ServerMessages.UNKNOWN, None if message_type == _ServerMessages.CLOSE_CONNEXION: self._logger.info("Received close connexion from client") client_socket.close() - plt.close() return _ServerMessages.CLOSE_CONNEXION, None - return message_type, data_out + try: + len_all_data = [int(len_data) for len_data in data_as_list[1][1:-1].split(",")] + except ValueError: + self._logger.warning("Length of data could not be extracted") + # Sends failure + if send_confirmation: + client_socket.sendall("NOK".encode()) + return _ServerMessages.UNKNOWN, None - def _wait_for_new_connexion(self, client_socket: socket.socket) -> None: + # If we are here, everything went well, so send confirmation + self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)") + if send_confirmation: + client_socket.sendall("OK".encode()) + + return message_type, len_all_data + + def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: bool, len_all_data: list) -> tuple: """ - Waits for a new connexion + Receives the data from the client (second part of the protocol) Parameters ---------- client_socket: socket.socket The client socket + send_confirmation: bool + If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will + not send anything. This is part of the communication protocol + len_all_data: list + The length of the data to receive + + Returns + ------- + The unparsed serialized data """ - message_type, data = self._wait_for_data(client_socket=client_socket, send_confirmation=True) - if message_type == _ServerMessages.INITIATE_CONNEXION: - self._logger.debug(f"Received hand shake from client") - self._initialize_plotter(client_socket, data) + data_out = [] + try: + for len_data in len_all_data: + data_out.append(client_socket.recv(len_data)) + if len(data_out[-1]) != len_data: + data_out[-1] += client_socket.recv(len_data - len(data_out[-1])) + except: + self._logger.warning("Unknown message type received") + # Sends failure + if send_confirmation: + client_socket.sendall("NOK".encode()) + return None + + # If we are here, everything went well, so send confirmation + if send_confirmation: + client_socket.sendall("OK".encode()) + + return data_out def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> None: """ @@ -241,7 +290,7 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No client_socket.sendall("PLOT_READY".encode()) # Start the callbacks - threading.Timer(self._get_data_interval, self._wait_for_new_data, (client_socket,)).start() + threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start() threading.Timer(self._update_plot_interval, self._redraw).start() plt.show() @@ -274,7 +323,7 @@ def _redraw(self) -> None: else: self._logger.info("All figures have been closed, stop updating the plots") - def _wait_for_new_data(self, client_socket: socket.socket) -> None: + def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ Waits for new data from the client, sends a "READY_FOR_NEXT_DATA" message to the client to signal that the server is ready to receive new data. If the client sends new data, the server will update the plot, if client disconnects @@ -297,10 +346,10 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> None: return should_continue = False - message_type, data = self._wait_for_data(client_socket=client_socket, send_confirmation=False) + message_type, data = self._recv_data(client_socket=client_socket, send_confirmation=False) if message_type == _ServerMessages.NEW_DATA: try: - self._update_data(data) + self._update_plot(data) should_continue = True except: self._logger.warning("Error while updating data from client, closing connexion") @@ -311,10 +360,10 @@ def _wait_for_new_data(self, client_socket: socket.socket) -> None: self._logger.debug("Received empty data from client (end of stream), closing connexion") if should_continue: - timer_get_data = threading.Timer(self._get_data_interval, self._wait_for_new_data, (client_socket,)) + timer_get_data = threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)) timer_get_data.start() - def _update_data(self, serialized_raw_data: list) -> None: + def _update_plot(self, serialized_raw_data: list) -> None: """ This method parses the data from the client @@ -323,48 +372,8 @@ def _update_data(self, serialized_raw_data: list) -> None: serialized_raw_data: list The serialized raw data from the client, see `xydata_encoding` below """ - - header = [int(v) for v in serialized_raw_data[0].decode().split(",")] - - data = serialized_raw_data[1] - all_data = np.array(struct.unpack("d" * (len(data) // 8), data)) - - header_cmp = 0 - all_data_cmp = 0 - xdata = [] - n_phases = header[header_cmp] - header_cmp += 1 - for _ in range(n_phases): - n_nodes = header[header_cmp] - header_cmp += 1 - x_phases = [] - for _ in range(n_nodes): - n_steps = header[header_cmp] - header_cmp += 1 - - x_phases.append(all_data[all_data_cmp : all_data_cmp + n_steps]) - all_data_cmp += n_steps - xdata.append(x_phases) - - ydata = [] - n_variables = header[header_cmp] - header_cmp += 1 - for _ in range(n_variables): - n_nodes = header[header_cmp] - header_cmp += 1 - if n_nodes == 0: - n_nodes = 1 - - y_variables = [] - for _ in range(n_nodes): - n_steps = header[header_cmp] - header_cmp += 1 - - y_variables.append(all_data[all_data_cmp : all_data_cmp + n_steps]) - all_data_cmp += n_steps - ydata.append(y_variables) - self._logger.debug(f"Received new data from client") + xdata, ydata = _deserialize_xydata(serialized_raw_data) self._plotter.update_data(xdata, ydata) @@ -507,29 +516,8 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: args_dict = {} for i, s in enumerate(nlpsol_out()): args_dict[s] = arg[i] - xdata_raw, ydata_raw = self._plotter.parse_data(**args_dict) - - header = f"{len(xdata_raw)}" - data_serialized = b"" - for x_nodes in xdata_raw: - header += f",{len(x_nodes)}" - for x_steps in x_nodes: - header += f",{x_steps.shape[0]}" - x_steps_tp = np.array(x_steps)[:, 0].tolist() - data_serialized += struct.pack("d" * len(x_steps_tp), *x_steps_tp) - - header += f",{len(ydata_raw)}" - for y_nodes_variable in ydata_raw: - if isinstance(y_nodes_variable, np.ndarray): - header += f",0" - y_nodes_variable = [y_nodes_variable] - else: - header += f",{len(y_nodes_variable)}" - - for y_steps in y_nodes_variable: - header += f",{y_steps.shape[0]}" - y_steps_tp = y_steps.tolist() - data_serialized += struct.pack("d" * len(y_steps_tp), *y_steps_tp) + xdata, ydata = self._plotter.parse_data(**args_dict) + header, data_serialized = _serialize_xydata(xdata, ydata) self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode()) # If send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) @@ -537,3 +525,104 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.sendall(data_serialized) # Again, if send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) return [0] + + +def _serialize_xydata(xdata: list, ydata: list) -> tuple: + """ + Serialize the data to send to the server, it will be deserialized by `_deserialize_xydata` + + Parameters + ---------- + xdata: list + The X data to serialize from PlotOcp.parse_data + ydata: list + The Y data to serialize from PlotOcp.parse_data + + Returns + ------- + The serialized data as expected by the server (header, serialized_data) + """ + + header = f"{len(xdata)}" + data_serialized = b"" + for x_nodes in xdata: + header += f",{len(x_nodes)}" + for x_steps in x_nodes: + header += f",{x_steps.shape[0]}" + x_steps_tp = np.array(x_steps)[:, 0].tolist() + data_serialized += struct.pack("d" * len(x_steps_tp), *x_steps_tp) + + header += f",{len(ydata)}" + for y_nodes_variable in ydata: + if isinstance(y_nodes_variable, np.ndarray): + header += f",0" + y_nodes_variable = [y_nodes_variable] + else: + header += f",{len(y_nodes_variable)}" + + for y_steps in y_nodes_variable: + header += f",{y_steps.shape[0]}" + y_steps_tp = y_steps.tolist() + data_serialized += struct.pack("d" * len(y_steps_tp), *y_steps_tp) + + return header, data_serialized + + +def _deserialize_xydata(serialized_raw_data: list) -> tuple: + """ + Deserialize the data from the client, based on the serialization used in _serialize_xydata` + + Parameters + ---------- + serialized_raw_data: list + The serialized raw data from the client + + Returns + ------- + The deserialized data as expected by PlotOcp.update_data + """ + + # Header is made of ints comma separated from the first line + header = [int(v) for v in serialized_raw_data[0].decode().split(",")] + + # Data is made of doubles (d) from the second line, the length of which is 8 bytes each + data = serialized_raw_data[1] + all_data = np.array(struct.unpack("d" * (len(data) // 8), data)) + + # Based on the header, we can now parse the data, assuming the number of phases, nodes and steps from the header + header_cmp = 0 + all_data_cmp = 0 + xdata = [] + n_phases = header[header_cmp] # Number of phases + header_cmp += 1 + for _ in range(n_phases): + n_nodes = header[header_cmp] # Number of nodes in the phase + header_cmp += 1 + x_phases = [] + for _ in range(n_nodes): + n_steps = header[header_cmp] # Number of steps in the node + header_cmp += 1 + + x_phases.append(all_data[all_data_cmp : all_data_cmp + n_steps]) # The X data of the node + all_data_cmp += n_steps + xdata.append(x_phases) + + ydata = [] + n_variables = header[header_cmp] # Number of variables (states, controls, etc.) + header_cmp += 1 + for _ in range(n_variables): + n_nodes = header[header_cmp] # Number of nodes for the variable + header_cmp += 1 + if n_nodes == 0: + n_nodes = 1 + + y_variables = [] + for _ in range(n_nodes): + n_steps = header[header_cmp] # Number of steps in the node for the variable + header_cmp += 1 + + y_variables.append(all_data[all_data_cmp : all_data_cmp + n_steps]) # The Y data of the node + all_data_cmp += n_steps + ydata.append(y_variables) + + return xdata, ydata diff --git a/bioptim/gui/serializable_class.py b/bioptim/gui/serializable_class.py index c148ab1d8..f53391500 100644 --- a/bioptim/gui/serializable_class.py +++ b/bioptim/gui/serializable_class.py @@ -401,11 +401,13 @@ def deserialize(cls, data): class OdeSolverSerializable: # TODO There are probably more parameters to serialize here, if the GUI fails, this is probably the reason - polynomial_degree: int + polynomial_degree: int | None + n_integration_steps: int | None type: OdeSolver - def __init__(self, polynomial_degree: int, type: OdeSolver): + def __init__(self, polynomial_degree: int | None, n_integration_steps: int | None, type: OdeSolver): self.polynomial_degree = polynomial_degree + self.n_integration_steps = n_integration_steps self.type = type @classmethod @@ -415,13 +417,15 @@ def from_ode_solver(cls, ode_solver): ode_solver: OdeSolver = ode_solver return cls( - polynomial_degree=ode_solver.polynomial_degree, + polynomial_degree=ode_solver.polynomial_degree if hasattr(ode_solver, "polynomial_degree") else None, + n_integration_steps=ode_solver.n_integration_steps if hasattr(ode_solver, "n_integration_steps") else None, type="ode", ) def serialize(self): return { "polynomial_degree": self.polynomial_degree, + "n_integration_steps": self.n_integration_steps, "type": self.type, } @@ -429,6 +433,7 @@ def serialize(self): def deserialize(cls, data): return cls( polynomial_degree=data["polynomial_degree"], + n_integration_steps=data["n_integration_steps"], type=data["type"], ) diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index fcb42111c..62b492abf 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -7,7 +7,8 @@ from bioptim.optimization.solution.solution import Solution from ..gui.online_callback_multiprocess import OnlineCallbackMultiprocess -from ..gui.online_callback_server import PlottingServer, OnlineCallbackServer +from ..gui.online_callback_server import PlottingServer +from ..gui.online_callback_multiprocess_server import PlottingMultiprocessServer from ..limits.path_conditions import Bounds from ..limits.penalty_helpers import PenaltyHelpers from ..misc.enums import InterpolationType, OnlineOptim @@ -31,19 +32,16 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): if show_options is None: show_options = {} - show_type = OnlineOptim.MULTIPROCESS - if "type" in show_options: - show_type = show_options["type"] - del show_options["type"] + online_optim: OnlineOptim = interface.opts.online_optim - if show_type == OnlineOptim.MULTIPROCESS: + if interface.opts.online_optim == OnlineOptim.MULTIPROCESS: if platform != "linux": raise RuntimeError( "Online OnlineOptim.MULTIPROCESS is not supported on Windows or MacOS. " "You can use online_optim=OnlineOptim.MULTIPROCESS_SERVER to the Solver declaration on Windows though" ) interface.options_common["iteration_callback"] = OnlineCallbackMultiprocess(ocp, show_options=show_options) - elif show_type == OnlineOptim.SERVER: + elif online_optim in (OnlineOptim.SERVER, OnlineOptim.MULTIPROCESS_SERVER): host = None if "host" in show_options: host = show_options["host"] @@ -54,19 +52,18 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): port = show_options["port"] del show_options["port"] - # TODO HERE! - as_multiprocess = True - if "as_multiprocess" in show_options: - as_multiprocess = show_options["as_multiprocess"] - del show_options["as_multiprocess"] - if as_multiprocess: - PlottingServer.as_multiprocess(host=host, port=port) + if online_optim == OnlineOptim.SERVER: + class_to_instantiate = PlottingServer + elif online_optim == OnlineOptim.MULTIPROCESS_SERVER: + class_to_instantiate = PlottingMultiprocessServer + else: + raise NotImplementedError(f"show_options['type']={online_optim} is not implemented yet") - interface.options_common["iteration_callback"] = OnlineCallbackServer( + interface.options_common["iteration_callback"] = class_to_instantiate( ocp, show_options=show_options, host=host, port=port ) else: - raise NotImplementedError(f"show_options['type']={show_type} is not implemented yet") + raise NotImplementedError(f"show_options['type']={online_optim} is not implemented yet") def generic_solve(interface, expand_during_shake_tree=False) -> dict: diff --git a/bioptim/interfaces/sqp_options.py b/bioptim/interfaces/sqp_options.py index c5f8ac5f4..ea551f304 100644 --- a/bioptim/interfaces/sqp_options.py +++ b/bioptim/interfaces/sqp_options.py @@ -245,7 +245,7 @@ def set_constraint_tolerance(self, tol: float): def as_dict(self, solver): solver_options = self.__dict__ options = {} - non_python_options = ["type", "show_online_optim", "show_options"] + non_python_options = ["type", "show_online_optim", "online_optim", "show_options"] for key in solver_options: if key not in non_python_options: sqp_key = key[1:] From 23d20a89fb86c451e20e890d85c57793e21c3170 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Fri, 2 Aug 2024 10:32:31 -0400 Subject: [PATCH 15/17] Fixed last issues when combining show_optim and online_optim --- bioptim/examples/getting_started/pendulum.py | 2 +- .../online_callback_multiprocess_server.py | 4 +-- bioptim/interfaces/interface_utils.py | 26 ++++++++++++++----- bioptim/interfaces/ipopt_options.py | 12 +-------- bioptim/interfaces/sqp_options.py | 12 +-------- bioptim/misc/enums.py | 1 - .../receding_horizon_optimization.py | 2 +- tests/shard4/test_solver_options.py | 15 +---------- 8 files changed, 26 insertions(+), 48 deletions(-) diff --git a/bioptim/examples/getting_started/pendulum.py b/bioptim/examples/getting_started/pendulum.py index 7ce9ed46f..55fa6a782 100644 --- a/bioptim/examples/getting_started/pendulum.py +++ b/bioptim/examples/getting_started/pendulum.py @@ -149,7 +149,7 @@ def main(): ocp.print(to_console=False, to_graph=False) # --- Solve the ocp --- # - # Default is OnlineOptim.MULTIPROCESS on Linux, OnlineOptim.MULTIPROCESS_SERVER on Windows and OnlineOptim.NONE on MacOS + # Default is OnlineOptim.MULTIPROCESS on Linux, OnlineOptim.MULTIPROCESS_SERVER on Windows and None on MacOS sol = ocp.solve(Solver.IPOPT(show_online_optim=OnlineOptim.DEFAULT)) # --- Show the results graph --- # diff --git a/bioptim/gui/online_callback_multiprocess_server.py b/bioptim/gui/online_callback_multiprocess_server.py index 9793fef94..72ff50209 100644 --- a/bioptim/gui/online_callback_multiprocess_server.py +++ b/bioptim/gui/online_callback_multiprocess_server.py @@ -15,7 +15,7 @@ def _start_as_multiprocess_internal(**kwargs): PlottingServer(**kwargs) -class PlottingMultiprocessServer(OnlineCallbackServer): +class OnlineCallbackMultiprocessServer(OnlineCallbackServer): def __init__(self, *args, **kwargs): """ Starts the server in a new process @@ -29,4 +29,4 @@ def __init__(self, *args, **kwargs): process = Process(target=_start_as_multiprocess_internal, kwargs={"host": host, "port": port}) process.start() - super(PlottingMultiprocessServer, self).__init__(*args, **kwargs) + super(OnlineCallbackMultiprocessServer, self).__init__(*args, **kwargs) diff --git a/bioptim/interfaces/interface_utils.py b/bioptim/interfaces/interface_utils.py index 62b492abf..a0aa07afa 100644 --- a/bioptim/interfaces/interface_utils.py +++ b/bioptim/interfaces/interface_utils.py @@ -7,8 +7,8 @@ from bioptim.optimization.solution.solution import Solution from ..gui.online_callback_multiprocess import OnlineCallbackMultiprocess -from ..gui.online_callback_server import PlottingServer -from ..gui.online_callback_multiprocess_server import PlottingMultiprocessServer +from ..gui.online_callback_server import OnlineCallbackServer +from ..gui.online_callback_multiprocess_server import OnlineCallbackMultiprocessServer from ..limits.path_conditions import Bounds from ..limits.penalty_helpers import PenaltyHelpers from ..misc.enums import InterpolationType, OnlineOptim @@ -33,8 +33,15 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): show_options = {} online_optim: OnlineOptim = interface.opts.online_optim + if online_optim == OnlineOptim.DEFAULT: + if platform == "linux": + online_optim = OnlineOptim.MULTIPROCESS + elif platform == "win32": + online_optim = OnlineOptim.MULTIPROCESS_SERVER + else: + online_optim = None - if interface.opts.online_optim == OnlineOptim.MULTIPROCESS: + if online_optim == OnlineOptim.MULTIPROCESS: if platform != "linux": raise RuntimeError( "Online OnlineOptim.MULTIPROCESS is not supported on Windows or MacOS. " @@ -53,9 +60,9 @@ def generic_online_optim(interface, ocp, show_options: dict | None = None): del show_options["port"] if online_optim == OnlineOptim.SERVER: - class_to_instantiate = PlottingServer + class_to_instantiate = OnlineCallbackServer elif online_optim == OnlineOptim.MULTIPROCESS_SERVER: - class_to_instantiate = PlottingMultiprocessServer + class_to_instantiate = OnlineCallbackMultiprocessServer else: raise NotImplementedError(f"show_options['type']={online_optim} is not implemented yet") @@ -92,7 +99,12 @@ def generic_solve(interface, expand_during_shake_tree=False) -> dict: all_g, all_g_bounds = interface.dispatch_bounds() all_g = _shake_tree_for_penalties(interface.ocp, all_g, v, v_bounds, expand_during_shake_tree) - if interface.opts.online_optim is not OnlineOptim.NONE: + if interface.opts.show_online_optim is not None: + if interface.opts.online_optim is not None: + raise ValueError("show_online_optim and online_optim cannot be simultaneous set") + interface.opts.online_optim = OnlineOptim.DEFAULT if interface.opts.show_online_optim else None + + if interface.opts.online_optim is not None: interface.online_optim(interface.ocp, interface.opts.show_options) # Thread here on (f and all_g) instead of individually for each function? @@ -138,7 +150,7 @@ def generic_solve(interface, expand_during_shake_tree=False) -> dict: interface.out["sol"]["solver"] = interface.solver_name # Make sure the graphs are showing the last iteration - if interface.opts.show_online_optim: + if "iteration_callback" in interface.options_common: to_eval = [ interface.out["sol"]["x"], interface.out["sol"]["f"], diff --git a/bioptim/interfaces/ipopt_options.py b/bioptim/interfaces/ipopt_options.py index de4fc5f7a..caef56f81 100644 --- a/bioptim/interfaces/ipopt_options.py +++ b/bioptim/interfaces/ipopt_options.py @@ -74,7 +74,7 @@ class IPOPT(GenericSolver): type: SolverType = SolverType.IPOPT show_online_optim: bool | None = None - online_optim: OnlineOptim = OnlineOptim.NONE + online_optim: OnlineOptim | None = None show_options: dict = None _tol: float = 1e-6 # default in ipopt 1e-8 _dual_inf_tol: float = 1.0 @@ -102,16 +102,6 @@ class IPOPT(GenericSolver): _c_compile: bool = False _check_derivatives_for_naninf: str = "no" # "yes" - def __attrs_post_init__(self): - if self.show_online_optim and self.online_optim != OnlineOptim.NONE: - raise ValueError("show_online_optim and online_optim cannot be simultaneous set") - - if self.show_online_optim is not None: - if self.show_online_optim: - self.online_optim = OnlineOptim.DEFAULT - else: - self.online_optim = OnlineOptim.NONE - @property def tol(self): return self._tol diff --git a/bioptim/interfaces/sqp_options.py b/bioptim/interfaces/sqp_options.py index ea551f304..3238c59d4 100644 --- a/bioptim/interfaces/sqp_options.py +++ b/bioptim/interfaces/sqp_options.py @@ -74,7 +74,7 @@ class SQP_METHOD(GenericSolver): type: SolverType = SolverType.SQP show_online_optim: bool | None = None - online_optim: OnlineOptim = OnlineOptim.NONE + online_optim: OnlineOptim | None = None show_options: dict = None c_compile = False _beta: float = 0.8 @@ -90,16 +90,6 @@ class SQP_METHOD(GenericSolver): _tol_du: float = 1e-6 _tol_pr: float = 1e-6 - def __attrs_post_init__(self): - if self.show_online_optim and self.online_optim != OnlineOptim.NONE: - raise ValueError("show_online_optim and online_optim cannot be simultaneous set") - - if self.show_online_optim is not None: - if self.show_online_optim: - self.online_optim = OnlineOptim.DEFAULT - else: - self.online_optim = OnlineOptim.NONE - @property def beta(self): return self._beta diff --git a/bioptim/misc/enums.py b/bioptim/misc/enums.py index 90c72da22..eaa396af7 100644 --- a/bioptim/misc/enums.py +++ b/bioptim/misc/enums.py @@ -106,7 +106,6 @@ class OnlineOptim(Enum): MULTIPROCESS_SERVER: Multiprocess server online plotting """ - NONE = auto() DEFAULT = auto() MULTIPROCESS = auto() SERVER = auto() diff --git a/bioptim/optimization/receding_horizon_optimization.py b/bioptim/optimization/receding_horizon_optimization.py index 2087f903a..3c578170c 100644 --- a/bioptim/optimization/receding_horizon_optimization.py +++ b/bioptim/optimization/receding_horizon_optimization.py @@ -174,7 +174,7 @@ def solve( f"Only {solver_current.get_tolerance_keys()} can be modified." ) if solver_current.type == SolverType.IPOPT: - solver_current.online_optim = OnlineOptim.NONE + solver_current.online_optim = None warm_start = None total_time += sol.real_time_to_optimize diff --git a/tests/shard4/test_solver_options.py b/tests/shard4/test_solver_options.py index 04eb5364a..8eec96736 100644 --- a/tests/shard4/test_solver_options.py +++ b/tests/shard4/test_solver_options.py @@ -16,7 +16,7 @@ def test_ipopt_solver_options(): solver = Solver.IPOPT() assert solver.type == SolverType.IPOPT assert solver.show_online_optim is None - assert solver.online_optim is OnlineOptim.NONE + assert solver.online_optim is None assert solver.show_options is None assert solver.tol == 1e-6 assert solver.dual_inf_tol == 1.0 @@ -133,16 +133,3 @@ def test_ipopt_solver_options(): solver.set_nlp_scaling_method("gradient-fiesta") assert solver.nlp_scaling_method == "gradient-fiesta" - - -def test_ipopt_solver_options_wrong(): - - solver = Solver.IPOPT() - solver.show_online_optim = True - with pytest.raises(ValueError, match="show_online_optim and online_optim cannot be simultaneous set"): - solver.online_optim = OnlineOptim.SERVER - - solver.show_online_optim = None - solver.online_optim = OnlineOptim.SERVER - with pytest.raises(ValueError, match="show_online_optim and online_optim cannot be simultaneous set"): - solver.show_online_optim = True From e4788f920f43e6de695c4b8581ddfc0a1d03de91 Mon Sep 17 00:00:00 2001 From: Pariterre Date: Fri, 2 Aug 2024 11:12:29 -0400 Subject: [PATCH 16/17] Added test for serialization and deserialization of online server --- bioptim/gui/online_callback_server.py | 3 +- tests/shard4/test_solver_options.py | 4 +-- tests/shard5/test_plot_server.py | 42 +++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 tests/shard5/test_plot_server.py diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 375bd7730..322bd0341 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -521,7 +521,7 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode()) # If send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) - self._socket.sendall(header.encode()) + self._socket.sendall(header) self._socket.sendall(data_serialized) # Again, if send_confirmation is True, we should wait for the server to acknowledge the data here (sends OK) return [0] @@ -565,6 +565,7 @@ def _serialize_xydata(xdata: list, ydata: list) -> tuple: y_steps_tp = y_steps.tolist() data_serialized += struct.pack("d" * len(y_steps_tp), *y_steps_tp) + header = header.encode() return header, data_serialized diff --git a/tests/shard4/test_solver_options.py b/tests/shard4/test_solver_options.py index 8eec96736..71d499036 100644 --- a/tests/shard4/test_solver_options.py +++ b/tests/shard4/test_solver_options.py @@ -1,7 +1,5 @@ -import pytest - from bioptim import Solver -from bioptim.misc.enums import SolverType, OnlineOptim +from bioptim.misc.enums import SolverType class FakeSolver: diff --git a/tests/shard5/test_plot_server.py b/tests/shard5/test_plot_server.py new file mode 100644 index 000000000..af774b260 --- /dev/null +++ b/tests/shard5/test_plot_server.py @@ -0,0 +1,42 @@ +import os + +from bioptim.gui.online_callback_server import _serialize_xydata, _deserialize_xydata +from bioptim.gui.plot import PlotOcp +from bioptim.optimization.optimization_vector import OptimizationVectorHelper +from casadi import DM, Function +import numpy as np + + +def test_serialize_deserialize(): + # Prepare a set of data to serialize and deserialize + from bioptim.examples.getting_started import pendulum as ocp_module + + bioptim_folder = os.path.dirname(ocp_module.__file__) + + ocp = ocp_module.prepare_ocp( + biorbd_model_path=bioptim_folder + "/models/pendulum.bioMod", + final_time=1, + n_shooting=40, + ) + + dummy_phase_times = OptimizationVectorHelper.extract_step_times(ocp, DM(np.ones(ocp.n_phases))) + plotter = PlotOcp(ocp, dummy_phase_times=dummy_phase_times, show_bounds=True, only_initialize_variables=True) + + np.random.seed(42) + xdata, ydata = plotter.parse_data(**{"x": np.random.rand(ocp.variables_vector.shape[0])[:, None]}) + + # Serialize and deserialize the data + serialized_data = _serialize_xydata(xdata, ydata) + deserialized_xdata, deserialized_ydata = _deserialize_xydata(serialized_data) + + # Compare the outputs + for x_phase, deserialized_x_phase in zip(xdata, deserialized_xdata): + for x_node, deserialized_x_node in zip(x_phase, deserialized_x_phase): + assert np.allclose(x_node, DM(deserialized_x_node)) + + for y_variable, deserialized_y_variable in zip(ydata, deserialized_ydata): + if isinstance(y_variable, np.ndarray): + assert np.allclose(y_variable, deserialized_y_variable[0], equal_nan=True) + else: + for y_phase, deserialized_y_phase in zip(y_variable, deserialized_y_variable): + assert np.allclose(y_phase, deserialized_y_phase) From b1fb3f57b82a197e9f307ccc902eba4b71e4f73e Mon Sep 17 00:00:00 2001 From: Pariterre Date: Mon, 5 Aug 2024 09:35:02 -0400 Subject: [PATCH 17/17] Removed a useless comment --- bioptim/gui/online_callback_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 322bd0341..62804bb7d 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -460,7 +460,6 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: if self._socket.recv(1024).decode() != "OK": raise RuntimeError("The server did not acknowledge the connexion") - # TODO ADD SHOW OPTIONS to the send self._socket.sendall(serialized_ocp) self._socket.sendall(serialized_show_options) if self._socket.recv(1024).decode() != "OK":