Skip to content

Commit

Permalink
add fortran interpreter, missing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AaronDavidSchneider committed Feb 28, 2023
1 parent 08dda59 commit b95f10b
Show file tree
Hide file tree
Showing 16 changed files with 423 additions and 0 deletions.
2 changes: 2 additions & 0 deletions m2cgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
export_to_dart,
export_to_elixir,
export_to_f_sharp,
export_to_fortran,
export_to_go,
export_to_haskell,
export_to_java,
Expand Down Expand Up @@ -34,6 +35,7 @@
export_to_haskell,
export_to_ruby,
export_to_f_sharp,
export_to_fortran,
export_to_rust,
export_to_elixir,
]
Expand Down
1 change: 1 addition & 0 deletions m2cgen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"ruby": (m2cgen.export_to_ruby, ["indent", "function_name"]),
"f_sharp": (m2cgen.export_to_f_sharp, ["indent", "function_name"]),
"rust": (m2cgen.export_to_rust, ["indent", "function_name"]),
"fortran": (m2cgen.export_to_fortran, ["indent", "function_name"]),
"elixir": (m2cgen.export_to_elixir, ["module_name", "indent", "function_name"]),
}

Expand Down
24 changes: 24 additions & 0 deletions m2cgen/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,30 @@ def export_to_ruby(model, indent=4, function_name="score"):
return _export(model, interpreter)


def export_to_fortran(model, indent=4, function_name="score"):
"""
Generates a Fortran code representation of the given model.
Parameters
----------
model : object
The model object that should be transpiled into code.
indent : int, optional
The size of indents in the generated code.
function_name : string, optional
Name of the function in the generated code.
Returns
-------
code : string
"""
interpreter = interpreters.FortranInterpreter(
indent=indent,
function_name=function_name
)
return _export(model, interpreter)


def export_to_f_sharp(model, indent=4, function_name="score"):
"""
Generates a F# code representation of the given model.
Expand Down
2 changes: 2 additions & 0 deletions m2cgen/interpreters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from m2cgen.interpreters.dart.interpreter import DartInterpreter
from m2cgen.interpreters.elixir.interpreter import ElixirInterpreter
from m2cgen.interpreters.f_sharp.interpreter import FSharpInterpreter
from m2cgen.interpreters.fortran.interpreter import FortranInterpreter
from m2cgen.interpreters.go.interpreter import GoInterpreter
from m2cgen.interpreters.haskell.interpreter import HaskellInterpreter
from m2cgen.interpreters.java.interpreter import JavaInterpreter
Expand Down Expand Up @@ -30,6 +31,7 @@
HaskellInterpreter,
RubyInterpreter,
FSharpInterpreter,
FortranInterpreter,
RustInterpreter,
ElixirInterpreter,
]
Empty file.
59 changes: 59 additions & 0 deletions m2cgen/interpreters/fortran/code_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from contextlib import contextmanager

from m2cgen.interpreters.code_generator import CodeTemplate, ImperativeCodeGenerator


class FortranCodeGenerator(ImperativeCodeGenerator):
tpl_num_value = CodeTemplate("{value}")
tpl_infix_expression = CodeTemplate("{left} {op} {right}")
tpl_return_statement_vec = CodeTemplate("{func_name}(:) = {value}")
tpl_return_statement_single = CodeTemplate("{func_name} = {value}")
tpl_array_index_access = CodeTemplate("{array_name}({index})")
tpl_if_statement = CodeTemplate("if ({if_def}) then")
tpl_else_statement = CodeTemplate("else")
tpl_var_assignment = CodeTemplate("{var_name} = {value}")
tpl_scalar_var_declare = CodeTemplate("double precision :: {var_name}")
tpl_vector_var_declare = CodeTemplate("double precision, dimension({size}) :: {var_name}")

tpl_block_termination = CodeTemplate("end if")

def add_return_statement(self, value, func_name, output_size):
if output_size > 1:
tpl = self.tpl_return_statement_vec
else:
tpl = self.tpl_return_statement_single

self.add_code_line(tpl(value=value, func_name=func_name))

def _declaration(self, var_name, size):
if size > 1:
tpl = self.tpl_vector_var_declare
else:
tpl = self.tpl_scalar_var_declare

return tpl(var_name=var_name, size=size)

def add_function_def(self, name, args, output_size):
function_def = f"function {name}({', '.join(args)})"
self.add_code_line(function_def)
self.increase_indent()
self.add_code_line(self._declaration(var_name=name, size=output_size))
self.add_code_lines([self.tpl_vector_var_declare(var_name=arg, size=":") for arg in args])

def add_function_end(self, name):
self.add_code_line("return")
self.decrease_indent()
self.add_code_line(f"end function {name}")

def add_var_declaration(self, size):
# We use implicit declerations for the local variables
return self.get_var_name()

@contextmanager
def function_definition(self, name, args, output_size):
self.add_function_def(name, args, output_size)
yield
self.add_function_end(name)

def vector_init(self, values):
return f"(/ {', '.join(values)} /)"
105 changes: 105 additions & 0 deletions m2cgen/interpreters/fortran/interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from pathlib import Path

from m2cgen.ast import BinNumOpType
from m2cgen.interpreters.interpreter import ImperativeToCodeInterpreter
from m2cgen.interpreters.mixins import BinExpressionDepthTrackingMixin, LinearAlgebraMixin, PowExprFunctionMixin
from m2cgen.interpreters.fortran.code_generator import FortranCodeGenerator
from m2cgen.interpreters.utils import get_file_content


class FortranInterpreter(ImperativeToCodeInterpreter,
PowExprFunctionMixin,
BinExpressionDepthTrackingMixin,
LinearAlgebraMixin):
# needs to be tested.
bin_depth_threshold = 55

supported_bin_vector_ops = {
BinNumOpType.ADD: "add_vectors",
}

supported_bin_vector_num_ops = {
BinNumOpType.MUL: "mul_vector_number",
}

abs_function_name = "ABS"
atan_function_name = "ATAN"
exponent_function_name = "EXP"
logarithm_function_name = "LOG"
log1p_function_name = "LOG1P"
sigmoid_function_name = "SIGMOID"
softmax_function_name = "SOFTMAX"
sqrt_function_name = "SQRT"
tanh_function_name = "TANH"

pow_operator = "**"

with_sigmoid_expr = False
with_softmax_expr = False
with_log1p_expr = False

def __init__(self, indent=4, function_name="score", *args, **kwargs):
self.function_name = function_name

cg = FortranCodeGenerator(indent=indent)
super().__init__(cg, *args, **kwargs)

def interpret(self, expr):
self._cg.reset_state()
self._reset_reused_expr_cache()

with self._cg.function_definition(
name=self.function_name,
args=[self._feature_array_name],
output_size=expr.output_size,
):
last_result = self._do_interpret(expr)
self._cg.add_return_statement(last_result, self.function_name, expr.output_size)

current_dir = Path(__file__).absolute().parent

if self.with_linear_algebra \
or self.with_softmax_expr \
or self.with_sigmoid_expr \
or self.with_log1p_expr:
self._cg.add_code_line("contains")

if self.with_linear_algebra:
filename = current_dir / "linear_algebra.f90"
self._add_contain_statement(filename)

if self.with_softmax_expr:
filename = current_dir / "softmax.f90"
self._add_contain_statement(filename)

if self.with_sigmoid_expr:
filename = current_dir / "sigmoid.f90"
self._add_contain_statement(filename)

if self.with_log1p_expr:
filename = current_dir / "log1p.f90"
self._add_contain_statement(filename)

return self._cg.finalize_and_get_generated_code()

def _add_contain_statement(self, filename):
self._cg.increase_indent()
self._cg.add_code_lines(get_file_content(filename))
self._cg.decrease_indent()

def interpret_abs_expr(self, expr, **kwargs):
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.abs_function_name, nested_result)

def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_softmax_expr(expr, **kwargs)

def interpret_softmax_expr(self, expr, **kwargs):
self.with_softmax_expr = True
return super().interpret_softmax_expr(expr, **kwargs)

def interpret_sigmoid_expr(self, expr, **kwargs):
self.with_sigmoid_expr = True
return super().interpret_sigmoid_expr(expr, **kwargs)
24 changes: 24 additions & 0 deletions m2cgen/interpreters/fortran/linear_algebra.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
function add_vectors(v1, v2) result(res)
implicit none
double precision, dimension(:), intent(in) :: v1, v2
double precision, dimension(size(v1)) :: res
integer :: i

do i = 1, size(v1)
res(i) = v1(i) + v2(i)
end do

end function add_vectors

function mul_vector_number(v1, num) result(res)
implicit none
double precision, dimension(:), intent(in) :: v1
double precision, intent(in) :: num
double precision, dimension(size(v1)) :: res
integer :: i

do i = 1, size(v1)
res(i) = v1(i) * num
end do

end function mul_vector_number
79 changes: 79 additions & 0 deletions m2cgen/interpreters/fortran/log1p.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
function ChebyshevBroucke(x, coeffs) result(result)
implicit none
double precision, intent(in) :: x
double precision, intent(in) :: coeffs(:)
double precision :: b0, b1, b2, x2, result
integer :: i
b2 = 0.0d0
b1 = 0.0d0
b0 = 0.0d0
x2 = x * 2.0d0
do i = size(coeffs, 1), 1, -1
b2 = b1
b1 = b0
b0 = x2 * b1 - b2 + coeffs(i)
end do
result = (b0 - b2) * 0.5d0
end function ChebyshevBroucke

function Log1p(x) result(result)
implicit none
double precision, intent(in) :: x
double precision :: res, xAbs
double precision, parameter :: eps = 2.220446049250313d-16
double precision, parameter :: coeff(21) = (/ 0.10378693562743769800686267719098d1, &
-0.13364301504908918098766041553133d0, &
0.19408249135520563357926199374750d-1, &
-0.30107551127535777690376537776592d-2, &
0.48694614797154850090456366509137d-3, &
-0.81054881893175356066809943008622d-4, &
0.13778847799559524782938251496059d-4, &
-0.23802210894358970251369992914935d-5, &
0.41640416213865183476391859901989d-6, &
-0.73595828378075994984266837031998d-7, &
0.13117611876241674949152294345011d-7, &
-0.23546709317742425136696092330175d-8, &
0.42522773276034997775638052962567d-9, &
-0.77190894134840796826108107493300d-10, &
0.14075746481359069909215356472191d-10, &
-0.25769072058024680627537078627584d-11, &
0.47342406666294421849154395005938d-12, &
-0.87249012674742641745301263292675d-13, &
0.16124614902740551465739833119115d-13, &
-0.29875652015665773006710792416815d-14, &
0.55480701209082887983041321697279d-15, &
-0.10324619158271569595141333961932d-15 /)

if (x == 0.0d0) then
result = 0.0d0
return
end if
if (x == -1.0d0) then
result = -huge(1.0d0)
return
end if
if (x < -1.0) then
result = 0.0d0 / 0.0d0
return
end if

xAbs = abs(x)
if (xAbs < 0.5 * eps) then
result = x
return
end if

if ((x > 0.0 .and. x < 1.0e-8) .or. (x > -1.0e-9 .and. x < 0.0)) then
result = x * (1.0 - x * 0.5)
return
end if

if (xAbs < 0.375) then
result = x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeff))
return
end if

result = log(1.0 + x)

end function Log1p

13 changes: 13 additions & 0 deletions m2cgen/interpreters/fortran/sigmoid.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function sigmoid(x) result(res)
implicit none
double precision, intent(in) :: x
double precision :: z

if (x < 0.0d0) then
z = exp(x)
res = z / (1.0d0 + z)
else
res = 1.0d0 / (1.0d0 + exp(-x))
end if

end function sigmoid
24 changes: 24 additions & 0 deletions m2cgen/interpreters/fortran/softmax.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
function softmax(x) result(res)
implicit none
double precision, dimension(:), intent(in) :: x
double precision, dimension(size(x)) :: res
double precision :: max_val, sum_val
integer :: i

! Find maximum value in x
max_val = x(1)
do i = 2, size(x)
if (x(i) > max_val) then
max_val = x(i)
end if
end do

! Compute softmax values
sum_val = 0.0d0
do i = 1, size(x)
res(i) = exp(x(i) - max_val)
sum_val = sum_val + res(i)
end do
res = res / sum_val

end function softmax
2 changes: 2 additions & 0 deletions tests/e2e/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tests.e2e.executors.dart import DartExecutor
from tests.e2e.executors.elixir import ElixirExecutor
from tests.e2e.executors.f_sharp import FSharpExecutor
from tests.e2e.executors.fortran import FortranExecutor
from tests.e2e.executors.go import GoExecutor
from tests.e2e.executors.haskell import HaskellExecutor
from tests.e2e.executors.java import JavaExecutor
Expand All @@ -21,6 +22,7 @@
CExecutor,
GoExecutor,
JavascriptExecutor,
FortranExecutor,
VisualBasicExecutor,
CSharpExecutor,
PowershellExecutor,
Expand Down
Loading

0 comments on commit b95f10b

Please sign in to comment.