Skip to content

Commit

Permalink
Remove Buffer argument when possible in helper functions
Browse files Browse the repository at this point in the history
Ref. eng/recordflux/RecordFlux#1802
  • Loading branch information
kanigsson committed Oct 21, 2024
1 parent bd21c3d commit abc1dc6
Show file tree
Hide file tree
Showing 34 changed files with 920 additions and 823 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Rejection of variable declarations with type `Opaque` (eng/recordflux/RecordFlux#633)
- Fatal error caused by variable in case expression (eng/recordflux/RecordFlux#1800)

### Changed

- Remove unused `Buffer` arguments in generated code (eng/recordflux/RecordFlux#1802)

## [0.24.0] - 2024-09-12

### Added
Expand Down
2 changes: 1 addition & 1 deletion examples/apps/spdm_responder/build_lib.gpr
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ project Build_Lib is
package Prove is
for Proof_Switches ("Ada") use Defaults.Proof_Switches;
for Proof_Switches ("responder.adb") use ("--prover=z3,cvc5", "--steps=64000", "--memlimit=6000", "--timeout=600");
for Proof_Switches ("rflx-spdm_responder-session-fsm.adb") use ("--timeout=120");
for Proof_Switches ("rflx-spdm_responder-session-fsm.adb") use ("--timeout=240");
end Prove;

end Build_Lib;
23 changes: 2 additions & 21 deletions rflx/generator/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def invalid_successors_invariant() -> Expr:
Variable("First"),
Variable("Verified_Last"),
Variable("Written_Last"),
Variable("Buffer"),
*([Variable("Buffer")] if message.has_aggregate_dependent_condition() else []),
*[Variable(p.identifier) for p in message.parameter_types],
],
),
Expand Down Expand Up @@ -779,7 +779,7 @@ def field_condition_call( # noqa: PLR0913
Variable(context),
Variable(package * field.affixed_name),
*([value] if has_scalar_value_dependent_condition(message) else []),
*([aggregate] if has_aggregate_dependent_condition(message) else []),
*([aggregate] if message.has_aggregate_dependent_condition() else []),
*([size] if has_size_dependent_condition(message, field) else []),
],
)
Expand Down Expand Up @@ -832,25 +832,6 @@ def has_scalar_value_dependent_condition(message: model.Message) -> bool:
)


def has_aggregate_dependent_condition(
message: model.Message,
field: model.Field | None = None,
) -> bool:
links = message.outgoing(field) if field else message.structure
fields = [field] if field else message.fields
return any(
r
for l in links
for r in l.condition.findall(lambda x: isinstance(x, (expr.Equal, expr.NotEqual)))
if isinstance(r, (expr.Equal, expr.NotEqual))
and r.findall(lambda x: isinstance(x, expr.Aggregate))
and any(
r.left == expr.Variable(f.identifier) or r.right == expr.Variable(f.identifier)
for f in fields
)
)


def has_size_dependent_condition(
message: model.Message,
field: model.Field | None = None,
Expand Down
86 changes: 69 additions & 17 deletions rflx/generator/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,11 @@ def create_valid_predecessors_invariant_function(
Parameter(["First"], const.TYPES_BIT_INDEX),
Parameter(["Verified_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Written_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Buffer"], const.TYPES_BYTES_PTR),
*(
[Parameter(["Buffer"], const.TYPES_BYTES_PTR)]
if message.has_aggregate_dependent_condition()
else []
),
*common.message_parameters(message),
],
)
Expand Down Expand Up @@ -484,7 +488,11 @@ def create_valid_next_internal_function(
Parameter(["First"], const.TYPES_BIT_INDEX),
Parameter(["Verified_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Written_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Buffer"], const.TYPES_BYTES_PTR),
*(
[Parameter(["Buffer"], const.TYPES_BYTES_PTR)]
if message.has_aggregate_dependent_condition()
else []
),
*common.message_parameters(message),
Parameter(["Fld"], "Field"),
],
Expand Down Expand Up @@ -541,7 +549,11 @@ def valid_next_expr(fld: Field) -> Expr:
Variable("First"),
Variable("Verified_Last"),
Variable("Written_Last"),
Variable("Buffer"),
*(
[Variable("Buffer")]
if message.has_aggregate_dependent_condition()
else []
),
*param_args,
],
),
Expand All @@ -564,7 +576,11 @@ def create_field_size_internal_function(message: Message, prefix: str) -> UnitPa
Parameter(["First"], const.TYPES_BIT_INDEX),
Parameter(["Verified_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Written_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Buffer"], const.TYPES_BYTES_PTR),
*(
[Parameter(["Buffer"], const.TYPES_BYTES_PTR)]
if message.has_aggregate_dependent_condition()
else []
),
*common.message_parameters(message),
Parameter(["Fld"], "Field"),
],
Expand Down Expand Up @@ -610,7 +626,11 @@ def create_field_size_internal_function(message: Message, prefix: str) -> UnitPa
Variable("First"),
Variable("Verified_Last"),
Variable("Written_Last"),
Variable("Buffer"),
*(
[Variable("Buffer")]
if message.has_aggregate_dependent_condition()
else []
),
*param_args,
],
),
Expand All @@ -621,7 +641,11 @@ def create_field_size_internal_function(message: Message, prefix: str) -> UnitPa
Variable("First"),
Variable("Verified_Last"),
Variable("Written_Last"),
Variable("Buffer"),
*(
[Variable("Buffer")]
if message.has_aggregate_dependent_condition()
else []
),
*param_args,
Variable("Fld"),
],
Expand All @@ -645,7 +669,7 @@ def recursive_call(fld: Field) -> expr.Expr:
expr.Variable("First"),
expr.Variable("Verified_Last"),
expr.Variable("Written_Last"),
expr.Variable("Buffer"),
*([expr.Variable("Buffer")] if message.has_aggregate_dependent_condition() else []),
*[expr.Variable(param.name) for param in message.parameter_types],
],
)
Expand All @@ -659,7 +683,7 @@ def field_size_internal_call(fld: expr.Variable) -> expr.Expr:
expr.Variable("First"),
expr.Variable("Verified_Last"),
expr.Variable("Written_Last"),
expr.Variable("Buffer"),
*([expr.Variable("Buffer")] if message.has_aggregate_dependent_condition() else []),
*[expr.Variable(param.name) for param in message.parameter_types],
fld,
],
Expand Down Expand Up @@ -736,7 +760,11 @@ def precond(fld: str) -> Precondition:
Variable("First"),
Variable("Verified_Last"),
Variable("Written_Last"),
Variable("Buffer"),
*(
[Variable("Buffer")]
if message.has_aggregate_dependent_condition()
else []
),
*param_args,
],
),
Expand All @@ -747,7 +775,11 @@ def precond(fld: str) -> Precondition:
Variable("First"),
Variable("Verified_Last"),
Variable("Written_Last"),
Variable("Buffer"),
*(
[Variable("Buffer")]
if message.has_aggregate_dependent_condition()
else []
),
*param_args,
Variable(fld),
],
Expand All @@ -765,7 +797,11 @@ def fld_first_func(fld: Field) -> ExpressionFunctionDeclaration:
Parameter(["First"], const.TYPES_BIT_INDEX),
Parameter(["Verified_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Written_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Buffer"], const.TYPES_BYTES_PTR),
*(
[Parameter(["Buffer"], const.TYPES_BYTES_PTR)]
if message.has_aggregate_dependent_condition()
else []
),
*common.message_parameters(message),
],
),
Expand All @@ -781,7 +817,11 @@ def fld_first_func(fld: Field) -> ExpressionFunctionDeclaration:
Parameter(["First"], const.TYPES_BIT_INDEX),
Parameter(["Verified_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Written_Last"], const.TYPES_BIT_LENGTH),
Parameter(["Buffer"], const.TYPES_BYTES_PTR),
*(
[Parameter(["Buffer"], const.TYPES_BYTES_PTR)]
if message.has_aggregate_dependent_condition()
else []
),
*common.message_parameters(message),
Parameter(["Fld"], "Field"),
],
Expand Down Expand Up @@ -2004,7 +2044,11 @@ def create_field_size_function(
Variable("Ctx.First"),
Variable("Ctx.Verified_Last"),
Variable("Ctx.Written_Last"),
Variable("Ctx.Buffer"),
*(
[Variable("Ctx.Buffer")]
if message.has_aggregate_dependent_condition()
else []
),
*[Selected(Variable("Ctx"), fld.name) for fld in message.parameter_types],
Variable("Fld"),
],
Expand Down Expand Up @@ -2057,7 +2101,11 @@ def create_field_first_function(prefix: str, message: Message) -> UnitPart:
Variable("Ctx.First"),
Variable("Ctx.Verified_Last"),
Variable("Ctx.Written_Last"),
Variable("Ctx.Buffer"),
*(
[Variable("Ctx.Buffer")]
if message.has_aggregate_dependent_condition()
else []
),
*[Selected(Variable("Ctx"), fld.name) for fld in message.parameter_types],
Variable("Fld"),
],
Expand Down Expand Up @@ -2192,7 +2240,7 @@ def condition(field: Field, message: Message) -> Expr:
elif isinstance(
message.field_types[field],
Composite,
) and common.has_aggregate_dependent_condition(message, field):
) and message.has_aggregate_dependent_condition(field):
c = c.substituted(
lambda x: expr.Variable("Agg") if x == expr.Variable(field.name) else x,
)
Expand All @@ -2208,7 +2256,7 @@ def condition(field: Field, message: Message) -> Expr:
),
*(
[Parameter(["Agg"], const.TYPES_BYTES)]
if common.has_aggregate_dependent_condition(message)
if message.has_aggregate_dependent_condition()
else []
),
*(
Expand Down Expand Up @@ -2528,7 +2576,11 @@ def create_valid_next_function(message: Message) -> UnitPart:
Variable("Ctx.First"),
Variable("Ctx.Verified_Last"),
Variable("Ctx.Written_Last"),
Variable("Ctx.Buffer"),
*(
[Variable("Ctx.Buffer")]
if message.has_aggregate_dependent_condition()
else []
),
*[Selected(Variable("Ctx"), fld.name) for fld in message.parameter_types],
Variable("Fld"),
],
Expand Down
4 changes: 2 additions & 2 deletions rflx/generator/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def create_get_function(
return UnitPart()

comparison_to_aggregate = any(
(isinstance(t, Composite) and common.has_aggregate_dependent_condition(message, f))
(isinstance(t, Composite) and message.has_aggregate_dependent_condition(f))
for f, t in message.field_types.items()
)

Expand Down Expand Up @@ -271,7 +271,7 @@ def create_verify_procedure(
),
),
]
if common.has_aggregate_dependent_condition(message)
if message.has_aggregate_dependent_condition()
else []
),
*(
Expand Down
18 changes: 18 additions & 0 deletions rflx/model/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,24 @@ def has_fixed_size(self) -> bool:
def has_implicit_size(self) -> bool:
return any(l.has_implicit_size for l in self.structure)

def has_aggregate_dependent_condition(
self,
field: Field | None = None,
) -> bool:
links = self.outgoing(field) if field else self.structure
fields = [field] if field else self.fields
return any(
r
for l in links
for r in l.condition.findall(lambda x: isinstance(x, (expr.Equal, expr.NotEqual)))
if isinstance(r, (expr.Equal, expr.NotEqual))
and r.findall(lambda x: isinstance(x, expr.Aggregate))
and any(
r.left == expr.Variable(f.identifier) or r.right == expr.Variable(f.identifier)
for f in fields
)
)

@property
def is_definite(self) -> bool:
"""
Expand Down
Loading

0 comments on commit abc1dc6

Please sign in to comment.