Skip to content

Commit

Permalink
fix: extract stage functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Sep 5, 2024
1 parent 680b38a commit 8911021
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 175 deletions.
168 changes: 168 additions & 0 deletions jaxadi/_stages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from typing import List, Any
from ._ops import OP_JAX_DICT, OP_JAX_VALUE_DICT
from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function


class Stage:
def __init__(self):
self.output_idx: List[int] = []
self.work_idx: List[int] = []
self.ops: List[Operation] = []

def codegen(self):
if self.ops[0].op == OP_OUTPUT:
return self.ops[0].codegen()

values = "["
outputs = "jnp.array(["
for op in self.ops:
if values[-1] != "[":
values += ", "
if outputs[-1] != "[":
outputs += ", "
values += f"{op.value}"
outputs += str(op.output_idx)
values += "]"
outputs += "])"

return f"\n work = work.at[{outputs}].set({values})"


class Operation:
def __init__(self):
self.op: int = None
self.value: str = ""
self.work_idx: List[int] = []
self.output_idx: Any = None

def codegen(self):
return f"\n work = work.at[{self.output_idx}].set({self.value})"


class OutputOperation(Operation):
def __init__(self):
self.exact_idx1: Any = None
self.exact_idx2: Any = None

super().__init__()

def codegen(self):
return f"\n outputs[{self.output_idx}] = outputs[{self.output_idx}].at[{self.exact_idx1}, {self.exact_idx2}].set({self.value})"


def stage_generator(func: Function) -> str:
# Get information about Casadi function
n_instr = func.n_instructions()
n_out = func.n_out() # number of outputs in the function
n_in = func.n_in() # number of outputs in the function

# Get the shapes of input and output
out_shapes = [func.size_out(i) for i in range(n_out)]
in_shapes = [func.size_in(i) for i in range(n_in)]

input_idx = [func.instruction_input(i) for i in range(n_instr)]
output_idx = [func.instruction_output(i) for i in range(n_instr)]
operations = [func.instruction_id(i) for i in range(n_instr)]
const_instr = [func.instruction_constant(i) for i in range(n_instr)]

stages = []
for k in range(n_instr):
op = operations[k]
o_idx = output_idx[k]
i_idx = input_idx[k]
operation = Operation()
operation.op = op
if op == OP_CONST:
operation.output_idx = o_idx[0]
operation.value = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(const_instr[k]) + "])"
# codegen += OP_JAX_DICT[op].format(o_idx[0], const_instr[k])
elif op == OP_INPUT:
this_shape = in_shapes[i_idx[0]]
rows, cols = this_shape # Get the shape of the output
row_number = i_idx[1] % rows # Compute row index for JAX
column_number = i_idx[1] // rows # Compute column index for JAX
operation.output_idx = o_idx[0]
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number)
elif op == OP_OUTPUT:
operation = OutputOperation()
operation.op = op
rows, cols = out_shapes[o_idx[0]]
row_number = o_idx[1] % rows # Compute row index for JAX
column_number = o_idx[1] // rows # Compute column index for JAX
operation.exact_idx1 = row_number
operation.exact_idx2 = column_number
operation.output_idx = o_idx[0]
operation.work_idx.append(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
elif op == OP_SQ:
operation.output_idx = o_idx[0]
operation.work_idx.append(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
elif OP_JAX_DICT[op].count("}") == 3:
operation.output_idx = o_idx[0]
operation.work_idx.extend([i_idx[0], i_idx[1]])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], i_idx[1])
elif OP_JAX_DICT[op].count("}") == 2:
operation.output_idx = o_idx[0]
operation.work_idx.append(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
else:
raise Exception("Unknown CasADi operation: " + str(op))

stage = Stage()
stage.output_idx.append(operation.output_idx)
stage.work_idx.extend(operation.work_idx)
stage.ops.append(operation)
stages.append(stage)

return stages


def update_stage(stage: Stage):
stage.output_idx = []
stage.work_idx = []

for op in stage.ops:
stage.output_idx.append(op.output_idx)
stage.work_idx.extend(op.work_idx)


def squeze_stages(stages: List[Stage]) -> List[Stage]:
for i, stage in enumerate(stages):
stage_ops = []
for j, op in enumerate(stage.ops):
current_stage = stage
for k, new_stage in enumerate(reversed(stages[0:i])):
if op.op == OP_OUTPUT:
break
if len(new_stage.ops) == 0:
continue
if new_stage.ops[0].op == OP_OUTPUT:
if op.output_idx in new_stage.output_idx:
break
elif op.output_idx in new_stage.work_idx:
break
else:
continue
if op.output_idx in new_stage.output_idx:
break
if set(op.work_idx).intersection(set(new_stage.output_idx)):
break
if op.output_idx in new_stage.work_idx:
current_stage = new_stage
break
current_stage = new_stage
if current_stage == stage:
stage_ops.append(op)
else:
current_stage.ops.append(op)

update_stage(current_stage)
stage.ops = stage_ops
update_stage(stage)

new_stages = []
for stage in stages:
if len(stage.ops) != 0:
new_stages.append(stage)
return new_stages
179 changes: 4 additions & 175 deletions jaxadi/_translate.py
Original file line number Diff line number Diff line change
@@ -1,182 +1,10 @@
from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function

from ._ops import OP_JAX_DICT, OP_JAX_VALUE_DICT
from typing import List, Any


class Stage:
def __init__(self):
self.output_idx: List[int] = []
self.work_idx: List[int] = []
self.ops: List[Operation] = []

def codegen(self):
if self.ops[0].op == OP_OUTPUT:
return self.ops[0].codegen()

values = "["
outputs = "jnp.array(["
for op in self.ops:
if values[-1] != "[":
values += ", "
if outputs[-1] != "[":
outputs += ", "
values += f"{op.value}"
outputs += str(op.output_idx)
values += "]"
outputs += "])"

return f"\n work = work.at[{outputs}].set({values})"


class Operation:
def __init__(self):
self.op: int = None
self.value: str = ""
self.work_idx: List[int] = []
self.output_idx: Any = None

def codegen(self):
return f"\n work = work.at[{self.output_idx}].set({self.value})"


class OutputOperation(Operation):
def __init__(self):
self.exact_idx1: Any = None
self.exact_idx2: Any = None

super().__init__()

def codegen(self):
return f"\n outputs[{self.output_idx}] = outputs[{self.output_idx}].at[{self.exact_idx1}, {self.exact_idx2}].set({self.value})"


def stage_generator(func: Function) -> str:
# Get information about Casadi function
n_instr = func.n_instructions()
n_out = func.n_out() # number of outputs in the function
n_in = func.n_in() # number of outputs in the function

# Get the shapes of input and output
out_shapes = [func.size_out(i) for i in range(n_out)]
in_shapes = [func.size_in(i) for i in range(n_in)]

input_idx = [func.instruction_input(i) for i in range(n_instr)]
output_idx = [func.instruction_output(i) for i in range(n_instr)]
operations = [func.instruction_id(i) for i in range(n_instr)]
const_instr = [func.instruction_constant(i) for i in range(n_instr)]

stages = []
for k in range(n_instr):
op = operations[k]
o_idx = output_idx[k]
i_idx = input_idx[k]
operation = Operation()
operation.op = op
if op == OP_CONST:
operation.output_idx = o_idx[0]
operation.value = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(const_instr[k]) + "])"
# codegen += OP_JAX_DICT[op].format(o_idx[0], const_instr[k])
elif op == OP_INPUT:
this_shape = in_shapes[i_idx[0]]
rows, cols = this_shape # Get the shape of the output
row_number = i_idx[1] % rows # Compute row index for JAX
column_number = i_idx[1] // rows # Compute column index for JAX
operation.output_idx = o_idx[0]
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number)
elif op == OP_OUTPUT:
operation = OutputOperation()
operation.op = op
rows, cols = out_shapes[o_idx[0]]
row_number = o_idx[1] % rows # Compute row index for JAX
column_number = o_idx[1] // rows # Compute column index for JAX
operation.exact_idx1 = row_number
operation.exact_idx2 = column_number
operation.output_idx = o_idx[0]
operation.work_idx.append(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
elif op == OP_SQ:
operation.output_idx = o_idx[0]
operation.work_idx.append(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
elif OP_JAX_DICT[op].count("}") == 3:
operation.output_idx = o_idx[0]
operation.work_idx.extend([i_idx[0], i_idx[1]])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], i_idx[1])
elif OP_JAX_DICT[op].count("}") == 2:
operation.output_idx = o_idx[0]
operation.work_idx.append(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
else:
raise Exception("Unknown CasADi operation: " + str(op))

stage = Stage()
stage.output_idx.append(operation.output_idx)
stage.work_idx.extend(operation.work_idx)
stage.ops.append(operation)
stages.append(stage)

return stages


def update_stage(stage: Stage):
stage.output_idx = []
stage.work_idx = []

for op in stage.ops:
stage.output_idx.append(op.output_idx)
stage.work_idx.extend(op.work_idx)


def squeze_stages(stages: List[Stage]) -> List[Stage]:
for i, stage in enumerate(stages):
stage_ops = []
for j, op in enumerate(stage.ops):
current_stage = stage
for k, new_stage in enumerate(reversed(stages[0:i])):
if op.op == OP_OUTPUT:
break
if len(new_stage.ops) == 0:
continue
if new_stage.ops[0].op == OP_OUTPUT:
if op.output_idx in new_stage.output_idx:
break
elif op.output_idx in new_stage.work_idx:
break
else:
continue
if op.output_idx in new_stage.output_idx:
break
if set(op.work_idx).intersection(set(new_stage.output_idx)):
break
if op.output_idx in new_stage.work_idx:
current_stage = new_stage
break
current_stage = new_stage
if current_stage == stage:
stage_ops.append(op)
else:
current_stage.ops.append(op)

update_stage(current_stage)
stage.ops = stage_ops
update_stage(stage)

new_stages = []
for stage in stages:
if len(stage.ops) != 0:
new_stages.append(stage)
return new_stages


def compress_stages(stages: List[Stage]) -> List[Stage]:
squezed = squeze_stages(stages)
return squezed
from ._stages import stage_generator, squeze_stages


def translate(func: Function, add_jit=False, add_import=False) -> str:
stages = stage_generator(func)
stages = compress_stages(stages)
stages = squeze_stages(stages)
# Get information about Casadi function
n_out = func.n_out() # number of outputs in the function

Expand Down Expand Up @@ -248,7 +76,8 @@ def legacy_translate(func: Function, add_jit=False, add_import=False) -> str:
codegen += OP_JAX_DICT[op].format(o_idx[0], i_idx[0], row_number, column_number)
elif op == OP_OUTPUT:
# Fix for OP_OUTPUT
rows, cols = out_shapes[o_idx[0]] # Get the shape of the output matrix
# Get the shape of the output matrix
rows, cols = out_shapes[o_idx[0]]
# Adjust the calculation to switch from CasADi's column-major to JAX's row-major
row_number = o_idx[1] % rows # Compute row index for JAX
column_number = o_idx[1] // rows # Compute column index for JAX
Expand Down

0 comments on commit 8911021

Please sign in to comment.