diff --git a/jaxadi/_expand.py b/jaxadi/_expand.py index 3c7b887..12df95d 100644 --- a/jaxadi/_expand.py +++ b/jaxadi/_expand.py @@ -1,5 +1,5 @@ from typing import List, Any, Dict -from ._ops import OP_JAX_VALUE_DICT +from ._ops import OP_JAX_EXPAND_VALUE_DICT as OP_JAX_VALUE_DICT from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function import re from multiprocessing import Pool, cpu_count @@ -141,7 +141,7 @@ def combine_outputs(stages: List[Stage]) -> str: rows = "[" + ", ".join(row_indices) + "]" columns = "[" + ", ".join(column_indices) + "]" values_str = ", ".join(values) - command = f" o[{output_idx}] = o[{output_idx}].at[({rows}, {columns})].set([{values_str}])" + command = f" outputs[{output_idx}] = outputs[{output_idx}].at[({rows}, {columns})].set([{values_str}])" commands.append(command) # Combine all the commands into a single string diff --git a/jaxadi/_graph.py b/jaxadi/_graph.py index 84584e2..dfc0146 100644 --- a/jaxadi/_graph.py +++ b/jaxadi/_graph.py @@ -4,13 +4,32 @@ compression/fusion if necessary/possible """ -from casadi import Function -from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function +import random +import re from collections import deque +from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function + from ._ops import OP_JAX_VALUE_DICT +def test_and_compress(s): + # Step 1: Check if the string has the desired form using a regex + pattern = re.compile(r"\[\s*work\[(\d+)\]\s*\*\s*work\[(\d+)\](?:\s*,\s*work\[(\d+)\]\s*\*\s*work\[(\d+)\])*\s*\]") + + if not pattern.fullmatch(s.strip()): + return s + + # Step 2.1: Extract the indices from the matches + matches = re.findall(r"work\[(\d+)\]\s*\*\s*work\[(\d+)\]", s) + left_indices = [int(m[0]) for m in matches] + right_indices = [int(m[1]) for m in matches] + + # Construct the compressed string + compressed_string = f"jnp.multiply(work[jnp.array({left_indices})], work[jnp.array({right_indices})])" + return compressed_string + + def sort_by_height(graph, antigraph, heights): nodes = [[] for i in range(max(heights) + 1)] for i, h in enumerate(heights): @@ -27,6 +46,8 @@ def codegen(graph, antigraph, heights, output_map, values): indices = [] assignment = "[" for node in layer: + if len(graph[node]) == 0 and not node in output_map: + continue if node in output_map: oo = output_map[node] if outputs.get(oo[0], None) is None: @@ -39,7 +60,11 @@ def codegen(graph, antigraph, heights, output_map, values): assignment += ", " assignment += values[node] indices += [node] - code += f" work = work.at[jnp.array({indices})].set({assignment}])\n" + if len(indices) == 0: + continue + assignment += "]" + # assignment = test_and_compress(assignment) + code += f" work = work.at[jnp.array({indices})].set({assignment})\n" for k, v in outputs.items(): code += f" outputs[{k}] = outputs[{k}].at[({v['rows']}, {v['cols']})].set([{', '.join(v['values'])}])\n" @@ -97,10 +122,20 @@ def create_graph(func: Function): workers[o_idx[0]] = i elif op == OP_OUTPUT: rows, cols = func.size_out(o_idx[0]) - row_number = o_idx[1] % rows # Compute row index for JAX - column_number = o_idx[1] // rows # Compute column index for JAX + row_number = o_idx[1] % rows + column_number = o_idx[1] // rows output_map[i] = (o_idx[0], row_number, column_number) values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) + + # Update the graph: add this output node as a child of its input (work node) + parent = workers[i_idx[0]] + graph[parent].append(i) + antigraph[i].append(parent) + # rows, cols = func.size_out(o_idx[0]) + # row_number = o_idx[1] % rows # Compute row index for JAX + # column_number = o_idx[1] // rows # Compute column index for JAX + # output_map[i] = (o_idx[0], row_number, column_number) + # values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) elif op == OP_SQ: values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) graph[workers[i_idx[0]]].append(i) @@ -125,8 +160,60 @@ def create_graph(func: Function): return graph, antigraph, output_map, values +def expand_graph(func, graph, antigraph, output_map, values): + heights = compute_heights(func, graph, antigraph) + sorted_nodes = sort_by_height(graph, antigraph, heights) + + # Calculate the average number of vertices per layer + total_vertices = sum(len(layer) for layer in sorted_nodes) + avg_vertices = total_vertices / len(sorted_nodes) + + new_graph = [[] for _ in range(len(graph))] + new_antigraph = [[] for _ in range(len(antigraph))] + + # Iterate over layers + for layer in sorted_nodes: + expand_layer = len(layer) < avg_vertices and not any(node in output_map for node in layer) + + if expand_layer: + # Expand nodes and update their values + for node in layer: + value_expr = values[node] + expanded_expr = re.sub(r"work\[(\d+)\]", lambda m: f"({values[int(m.group(1))]})", value_expr) + values[node] = expanded_expr + + # Recalculate dependencies for expanded nodes + for node in layer: + new_parents = set() + for parent in antigraph[node]: + new_parents.update(new_antigraph[parent]) # Use updated parents + + # Update new_antigraph and new_graph accordingly + new_antigraph[node] = list(new_parents) + + for new_parent in new_parents: + new_graph[new_parent].append(node) + + # Retain the original child relationships + for child in graph[node]: + new_graph[node].append(child) + new_antigraph[child].append(node) + else: + # Maintain existing connections for nodes without expansion + for node in layer: + for parent in antigraph[node]: + new_graph[parent].append(node) + new_antigraph[node].append(parent) + for child in graph[node]: + new_graph[node].append(child) + new_antigraph[child].append(node) + + return new_graph, new_antigraph, output_map, values + + def translate(func: Function, add_jit=False, add_import=False): graph, antigraph, output_map, values = create_graph(func) + graph, antigraph, output_map, values = expand_graph(func, graph, antigraph, output_map, values) heights = compute_heights(func, graph, antigraph) code = "" diff --git a/jaxadi/_ops.py b/jaxadi/_ops.py index 86cc970..26633ea 100644 --- a/jaxadi/_ops.py +++ b/jaxadi/_ops.py @@ -97,3 +97,52 @@ OP_INPUT: "inputs[{0}][{1}, {2}]", OP_OUTPUT: "work[{0}][0]", } +OP_JAX_EXPAND_VALUE_DICT = { + OP_ASSIGN: "{0}", + OP_ADD: "{0} + {1}", + OP_SUB: "{0} - {1}", + OP_MUL: "{0} * {1}", + OP_DIV: "{0} / {1}", + OP_NEG: "-{0}", + OP_EXP: "jnp.exp({0})", + OP_LOG: "jnp.log({0})", + OP_POW: "jnp.power({0}, {1})", + OP_CONSTPOW: "jnp.power({0}, {1})", + OP_SQRT: "jnp.sqrt({0})", + OP_SQ: "{0} * {0}", + OP_TWICE: "2 * {0}", + OP_SIN: "jnp.sin({0})", + OP_COS: "jnp.cos({0})", + OP_TAN: "jnp.tan({0})", + OP_ASIN: "jnp.arcsin({0})", + OP_ACOS: "jnp.arccos({0})", + OP_ATAN: "jnp.arctan({0})", + OP_LT: "{0} < {1}", + OP_LE: "{0} <= {1}", + OP_EQ: "{0} == {1}", + OP_NE: "{0} != {1}", + OP_NOT: "jnp.logical_not({0})", + OP_AND: "jnp.logical_and({0}, {1})", + OP_OR: "jnp.logical_or({0}, {1})", + OP_FLOOR: "jnp.floor({0})", + OP_CEIL: "jnp.ceil({0})", + OP_FMOD: "jnp.fmod({0}, {1})", + OP_FABS: "jnp.abs({0})", + OP_SIGN: "jnp.sign({0})", + OP_COPYSIGN: "jnp.copysign({0}, {1})", + OP_IF_ELSE_ZERO: "jnp.where({0} == 0, 0, {1})", + OP_ERF: "jax.scipy.special.erf({0})", + OP_FMIN: "jnp.minimum({0}, {1})", + OP_FMAX: "jnp.maximum({0}, {1})", + OP_INV: "1.0 / {0}", + OP_SINH: "jnp.sinh({0})", + OP_COSH: "jnp.cosh({0})", + OP_TANH: "jnp.tanh({0})", + OP_ASINH: "jnp.arcsinh({0})", + OP_ACOSH: "jnp.arccosh({0})", + OP_ATANH: "jnp.arctanh({0})", + OP_ATAN2: "jnp.arctan2({0}, {1})", + OP_CONST: "{0:.16f}", + OP_INPUT: "inputs[{0}][{1}, {2}]", + OP_OUTPUT: "{0}[0]", +}