Skip to content

Commit

Permalink
Template changes
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu committed Jan 24, 2024
1 parent 72d1a84 commit b29587a
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) ->
if function_name == PredefinedFunctions.TIME_RESOLUTION:
# context dependent; we assume the template contains the necessary definitions
return 'h'
#return 'NESTGPUTimeResolution'

if function_name == PredefinedFunctions.TIME_STEPS:
return '(int)round({!s}/NESTGPUTimeResolution)'
Expand Down
1 change: 0 additions & 1 deletion pynestml/codegeneration/resources_nest_gpu/directives

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include "{{ neuronName }}.h"
#include "spike_buffer.h"

{%- import 'directives/SetScalParamAndVar.jinja2' as set_scal_param_var with context %}

{%- if uses_analytic_solver %}
using namespace {{ neuronName }}_ns;

Expand Down Expand Up @@ -176,7 +178,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
// Parameters
{%- for variable_symbol in neuron.get_parameter_symbols() %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
SetScalParam(0, n_node, "{{ printer_no_origin.print(variable) }}", {{printer_no_origin.print(variable_symbol.get_declaring_expression())}}); // as {{variable_symbol.get_type_symbol().print_symbol()}}
SetScalParam(0, n_node, "{{ printer_no_origin.print(variable) }}", {{set_scal_param_var.SetScalParamAndVar(variable_symbol.get_declaring_expression())}}); // as {{variable_symbol.get_type_symbol().print_symbol()}}
{%- endfor %}

// Internal variables
Expand All @@ -194,7 +196,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
// State variables
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
SetScalVar(0, n_node, "{{ printer_no_origin.print(variable) }}", {{printer_no_origin.print(variable_symbol.get_declaring_expression())}});
SetScalVar(0, n_node, "{{ printer_no_origin.print(variable) }}", {{set_scal_param_var.SetScalParamAndVar(variable_symbol.get_declaring_expression())}});
{%- endfor %}
{%- endif %}

Expand All @@ -212,7 +214,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
port_input_port_step_ = 1;

{# TODO #}
{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");#}
{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay"); #}

return 0;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{#
Initialization of param or var if they have a declaring expression
@param expr ASTExpression declaring expression of the variable
#}
{%- macro SetScalParamAndVar(expr) -%}
{%- if utils.is_declaring_expression_parameter(expr) %}
*GetScalParam(0, n_node, "{{expr}}")
{%- elif utils.is_declaring_expression_state_varible(expr) %}
*GetScalVar(0, n_node, "{{expr}}")
{%- else %}
{{printer_no_origin.print(expr)}}
{%- endif %}
{%- endmacro -%}
18 changes: 18 additions & 0 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2234,4 +2234,22 @@ def adjusted_state_symbols(cls, neuron: ASTNeuron):
diff = list(set(neuron.get_state_symbols()) - set(extract_list))
return diff + extract_list
return neuron.get_state_symbols()


@classmethod
def is_declaring_expression_parameter(cls, expr: ASTExpression) -> bool:
if isinstance(expr, ASTSimpleExpression):
if expr.is_variable():
symbol = expr.get_scope().resolve_to_symbol(expr.get_variable().get_name(), SymbolKind.VARIABLE)
if symbol and symbol.is_parameters():
return True
return False

@classmethod
def is_declaring_expression_state_varible(cls, expr: ASTExpression) -> bool:
if isinstance(expr, ASTSimpleExpression):
if expr.is_variable():
symbol = expr.get_scope().resolve_to_symbol(expr.get_variable().get_name(), SymbolKind.VARIABLE)
if symbol and symbol.is_state():
return True
return False

0 comments on commit b29587a

Please sign in to comment.