-
Notifications
You must be signed in to change notification settings - Fork 241
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add fortran interpreter, missing tests
- Loading branch information
1 parent
08dda59
commit b95f10b
Showing
16 changed files
with
423 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from contextlib import contextmanager | ||
|
||
from m2cgen.interpreters.code_generator import CodeTemplate, ImperativeCodeGenerator | ||
|
||
|
||
class FortranCodeGenerator(ImperativeCodeGenerator): | ||
tpl_num_value = CodeTemplate("{value}") | ||
tpl_infix_expression = CodeTemplate("{left} {op} {right}") | ||
tpl_return_statement_vec = CodeTemplate("{func_name}(:) = {value}") | ||
tpl_return_statement_single = CodeTemplate("{func_name} = {value}") | ||
tpl_array_index_access = CodeTemplate("{array_name}({index})") | ||
tpl_if_statement = CodeTemplate("if ({if_def}) then") | ||
tpl_else_statement = CodeTemplate("else") | ||
tpl_var_assignment = CodeTemplate("{var_name} = {value}") | ||
tpl_scalar_var_declare = CodeTemplate("double precision :: {var_name}") | ||
tpl_vector_var_declare = CodeTemplate("double precision, dimension({size}) :: {var_name}") | ||
|
||
tpl_block_termination = CodeTemplate("end if") | ||
|
||
def add_return_statement(self, value, func_name, output_size): | ||
if output_size > 1: | ||
tpl = self.tpl_return_statement_vec | ||
else: | ||
tpl = self.tpl_return_statement_single | ||
|
||
self.add_code_line(tpl(value=value, func_name=func_name)) | ||
|
||
def _declaration(self, var_name, size): | ||
if size > 1: | ||
tpl = self.tpl_vector_var_declare | ||
else: | ||
tpl = self.tpl_scalar_var_declare | ||
|
||
return tpl(var_name=var_name, size=size) | ||
|
||
def add_function_def(self, name, args, output_size): | ||
function_def = f"function {name}({', '.join(args)})" | ||
self.add_code_line(function_def) | ||
self.increase_indent() | ||
self.add_code_line(self._declaration(var_name=name, size=output_size)) | ||
self.add_code_lines([self.tpl_vector_var_declare(var_name=arg, size=":") for arg in args]) | ||
|
||
def add_function_end(self, name): | ||
self.add_code_line("return") | ||
self.decrease_indent() | ||
self.add_code_line(f"end function {name}") | ||
|
||
def add_var_declaration(self, size): | ||
# We use implicit declerations for the local variables | ||
return self.get_var_name() | ||
|
||
@contextmanager | ||
def function_definition(self, name, args, output_size): | ||
self.add_function_def(name, args, output_size) | ||
yield | ||
self.add_function_end(name) | ||
|
||
def vector_init(self, values): | ||
return f"(/ {', '.join(values)} /)" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from pathlib import Path | ||
|
||
from m2cgen.ast import BinNumOpType | ||
from m2cgen.interpreters.interpreter import ImperativeToCodeInterpreter | ||
from m2cgen.interpreters.mixins import BinExpressionDepthTrackingMixin, LinearAlgebraMixin, PowExprFunctionMixin | ||
from m2cgen.interpreters.fortran.code_generator import FortranCodeGenerator | ||
from m2cgen.interpreters.utils import get_file_content | ||
|
||
|
||
class FortranInterpreter(ImperativeToCodeInterpreter, | ||
PowExprFunctionMixin, | ||
BinExpressionDepthTrackingMixin, | ||
LinearAlgebraMixin): | ||
# needs to be tested. | ||
bin_depth_threshold = 55 | ||
|
||
supported_bin_vector_ops = { | ||
BinNumOpType.ADD: "add_vectors", | ||
} | ||
|
||
supported_bin_vector_num_ops = { | ||
BinNumOpType.MUL: "mul_vector_number", | ||
} | ||
|
||
abs_function_name = "ABS" | ||
atan_function_name = "ATAN" | ||
exponent_function_name = "EXP" | ||
logarithm_function_name = "LOG" | ||
log1p_function_name = "LOG1P" | ||
sigmoid_function_name = "SIGMOID" | ||
softmax_function_name = "SOFTMAX" | ||
sqrt_function_name = "SQRT" | ||
tanh_function_name = "TANH" | ||
|
||
pow_operator = "**" | ||
|
||
with_sigmoid_expr = False | ||
with_softmax_expr = False | ||
with_log1p_expr = False | ||
|
||
def __init__(self, indent=4, function_name="score", *args, **kwargs): | ||
self.function_name = function_name | ||
|
||
cg = FortranCodeGenerator(indent=indent) | ||
super().__init__(cg, *args, **kwargs) | ||
|
||
def interpret(self, expr): | ||
self._cg.reset_state() | ||
self._reset_reused_expr_cache() | ||
|
||
with self._cg.function_definition( | ||
name=self.function_name, | ||
args=[self._feature_array_name], | ||
output_size=expr.output_size, | ||
): | ||
last_result = self._do_interpret(expr) | ||
self._cg.add_return_statement(last_result, self.function_name, expr.output_size) | ||
|
||
current_dir = Path(__file__).absolute().parent | ||
|
||
if self.with_linear_algebra \ | ||
or self.with_softmax_expr \ | ||
or self.with_sigmoid_expr \ | ||
or self.with_log1p_expr: | ||
self._cg.add_code_line("contains") | ||
|
||
if self.with_linear_algebra: | ||
filename = current_dir / "linear_algebra.f90" | ||
self._add_contain_statement(filename) | ||
|
||
if self.with_softmax_expr: | ||
filename = current_dir / "softmax.f90" | ||
self._add_contain_statement(filename) | ||
|
||
if self.with_sigmoid_expr: | ||
filename = current_dir / "sigmoid.f90" | ||
self._add_contain_statement(filename) | ||
|
||
if self.with_log1p_expr: | ||
filename = current_dir / "log1p.f90" | ||
self._add_contain_statement(filename) | ||
|
||
return self._cg.finalize_and_get_generated_code() | ||
|
||
def _add_contain_statement(self, filename): | ||
self._cg.increase_indent() | ||
self._cg.add_code_lines(get_file_content(filename)) | ||
self._cg.decrease_indent() | ||
|
||
def interpret_abs_expr(self, expr, **kwargs): | ||
nested_result = self._do_interpret(expr.expr, **kwargs) | ||
return self._cg.function_invocation( | ||
self.abs_function_name, nested_result) | ||
|
||
def interpret_log1p_expr(self, expr, **kwargs): | ||
self.with_log1p_expr = True | ||
return super().interpret_softmax_expr(expr, **kwargs) | ||
|
||
def interpret_softmax_expr(self, expr, **kwargs): | ||
self.with_softmax_expr = True | ||
return super().interpret_softmax_expr(expr, **kwargs) | ||
|
||
def interpret_sigmoid_expr(self, expr, **kwargs): | ||
self.with_sigmoid_expr = True | ||
return super().interpret_sigmoid_expr(expr, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
function add_vectors(v1, v2) result(res) | ||
implicit none | ||
double precision, dimension(:), intent(in) :: v1, v2 | ||
double precision, dimension(size(v1)) :: res | ||
integer :: i | ||
|
||
do i = 1, size(v1) | ||
res(i) = v1(i) + v2(i) | ||
end do | ||
|
||
end function add_vectors | ||
|
||
function mul_vector_number(v1, num) result(res) | ||
implicit none | ||
double precision, dimension(:), intent(in) :: v1 | ||
double precision, intent(in) :: num | ||
double precision, dimension(size(v1)) :: res | ||
integer :: i | ||
|
||
do i = 1, size(v1) | ||
res(i) = v1(i) * num | ||
end do | ||
|
||
end function mul_vector_number |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
function ChebyshevBroucke(x, coeffs) result(result) | ||
implicit none | ||
double precision, intent(in) :: x | ||
double precision, intent(in) :: coeffs(:) | ||
double precision :: b0, b1, b2, x2, result | ||
integer :: i | ||
b2 = 0.0d0 | ||
b1 = 0.0d0 | ||
b0 = 0.0d0 | ||
x2 = x * 2.0d0 | ||
do i = size(coeffs, 1), 1, -1 | ||
b2 = b1 | ||
b1 = b0 | ||
b0 = x2 * b1 - b2 + coeffs(i) | ||
end do | ||
result = (b0 - b2) * 0.5d0 | ||
end function ChebyshevBroucke | ||
|
||
function Log1p(x) result(result) | ||
implicit none | ||
double precision, intent(in) :: x | ||
double precision :: res, xAbs | ||
double precision, parameter :: eps = 2.220446049250313d-16 | ||
double precision, parameter :: coeff(21) = (/ 0.10378693562743769800686267719098d1, & | ||
-0.13364301504908918098766041553133d0, & | ||
0.19408249135520563357926199374750d-1, & | ||
-0.30107551127535777690376537776592d-2, & | ||
0.48694614797154850090456366509137d-3, & | ||
-0.81054881893175356066809943008622d-4, & | ||
0.13778847799559524782938251496059d-4, & | ||
-0.23802210894358970251369992914935d-5, & | ||
0.41640416213865183476391859901989d-6, & | ||
-0.73595828378075994984266837031998d-7, & | ||
0.13117611876241674949152294345011d-7, & | ||
-0.23546709317742425136696092330175d-8, & | ||
0.42522773276034997775638052962567d-9, & | ||
-0.77190894134840796826108107493300d-10, & | ||
0.14075746481359069909215356472191d-10, & | ||
-0.25769072058024680627537078627584d-11, & | ||
0.47342406666294421849154395005938d-12, & | ||
-0.87249012674742641745301263292675d-13, & | ||
0.16124614902740551465739833119115d-13, & | ||
-0.29875652015665773006710792416815d-14, & | ||
0.55480701209082887983041321697279d-15, & | ||
-0.10324619158271569595141333961932d-15 /) | ||
|
||
if (x == 0.0d0) then | ||
result = 0.0d0 | ||
return | ||
end if | ||
if (x == -1.0d0) then | ||
result = -huge(1.0d0) | ||
return | ||
end if | ||
if (x < -1.0) then | ||
result = 0.0d0 / 0.0d0 | ||
return | ||
end if | ||
|
||
xAbs = abs(x) | ||
if (xAbs < 0.5 * eps) then | ||
result = x | ||
return | ||
end if | ||
|
||
if ((x > 0.0 .and. x < 1.0e-8) .or. (x > -1.0e-9 .and. x < 0.0)) then | ||
result = x * (1.0 - x * 0.5) | ||
return | ||
end if | ||
|
||
if (xAbs < 0.375) then | ||
result = x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeff)) | ||
return | ||
end if | ||
|
||
result = log(1.0 + x) | ||
|
||
end function Log1p | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
function sigmoid(x) result(res) | ||
implicit none | ||
double precision, intent(in) :: x | ||
double precision :: z | ||
|
||
if (x < 0.0d0) then | ||
z = exp(x) | ||
res = z / (1.0d0 + z) | ||
else | ||
res = 1.0d0 / (1.0d0 + exp(-x)) | ||
end if | ||
|
||
end function sigmoid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
function softmax(x) result(res) | ||
implicit none | ||
double precision, dimension(:), intent(in) :: x | ||
double precision, dimension(size(x)) :: res | ||
double precision :: max_val, sum_val | ||
integer :: i | ||
|
||
! Find maximum value in x | ||
max_val = x(1) | ||
do i = 2, size(x) | ||
if (x(i) > max_val) then | ||
max_val = x(i) | ||
end if | ||
end do | ||
|
||
! Compute softmax values | ||
sum_val = 0.0d0 | ||
do i = 1, size(x) | ||
res(i) = exp(x(i) - max_val) | ||
sum_val = sum_val + res(i) | ||
end do | ||
res = res / sum_val | ||
|
||
end function softmax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.