Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Jan 26, 2022
1 parent 3f96e58 commit 106ddf9
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 59 deletions.
108 changes: 54 additions & 54 deletions tests/assemblers/test_boosting_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,18 @@ def test_regression():
expected = ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.725),
ast.FeatureRef(8),
ast.NumVal(1.0000000180025095e-35),
ast.CompOpType.GT),
ast.NumVal(22.030283219508686),
ast.NumVal(23.27840740210207)),
ast.NumVal(156.64462853604854),
ast.NumVal(148.40956590509697)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.8375),
ast.FeatureRef(2),
ast.NumVal(0.00780560282464346),
ast.CompOpType.GT),
ast.NumVal(1.2777791671888081),
ast.NumVal(-0.2686772850549309)),
ast.NumVal(4.996373375352607),
ast.NumVal(-3.1063596100284814)),
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)
Expand All @@ -93,18 +93,18 @@ def test_regression_random_forest():
ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.605),
ast.FeatureRef(2),
ast.NumVal(0.00780560282464346),
ast.CompOpType.GT),
ast.NumVal(17.398543657369768),
ast.NumVal(29.851408659650296)),
ast.NumVal(210.27118647591766),
ast.NumVal(120.45454548930705)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.888),
ast.CompOpType.GT),
ast.NumVal(37.2235298136268),
ast.NumVal(19.948122884684025)),
ast.FeatureRef(2),
ast.NumVal(-0.007822672246629598),
ast.CompOpType.LTE),
ast.NumVal(114.24161077349474),
ast.NumVal(194.84868424576604)),
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)
Expand Down Expand Up @@ -159,18 +159,18 @@ def test_simple_sigmoid_output_transform():
ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(19.23),
ast.CompOpType.GT),
ast.NumVal(4.002437528537838),
ast.NumVal(4.090096709787509)),
ast.FeatureRef(8),
ast.NumVal(-0.0028501970360456344),
ast.CompOpType.LTE),
ast.NumVal(5.8325360677435345),
ast.NumVal(5.891973988308211)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(14.895),
ast.CompOpType.GT),
ast.NumVal(-0.0417499606641773),
ast.NumVal(0.02069953712454655)),
ast.FeatureRef(8),
ast.NumVal(-0.005612778088288765),
ast.CompOpType.LTE),
ast.NumVal(-0.027170480653266372),
ast.NumVal(0.026423953384869338)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
Expand All @@ -188,18 +188,18 @@ def test_log1p_exp_output_transform():
ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(19.23),
ast.CompOpType.GT),
ast.NumVal(0.6622623010380544),
ast.NumVal(0.6684065452877841)),
ast.FeatureRef(8),
ast.NumVal(-0.0028501970360456344),
ast.CompOpType.LTE),
ast.NumVal(0.693713164308067),
ast.NumVal(0.694435273176687)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(15.145),
ast.CompOpType.GT),
ast.NumVal(0.1404975120475147),
ast.NumVal(0.14535916856709272)),
ast.FeatureRef(8),
ast.NumVal(-0.005612778088288765),
ast.CompOpType.LTE),
ast.NumVal(0.14830023030115363),
ast.NumVal(0.14902176200722345)),
ast.BinNumOpType.ADD)))

assert utils.cmp_exprs(actual, expected)
Expand All @@ -216,18 +216,18 @@ def test_maybe_sqr_output_transform():
ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.725),
ast.FeatureRef(8),
ast.NumVal(1.0000000180025095e-35),
ast.CompOpType.GT),
ast.NumVal(4.569350528717041),
ast.NumVal(4.663526439666748)),
ast.NumVal(12.094032478332519),
ast.NumVal(11.671793556213379)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(11.655),
ast.CompOpType.GT),
ast.NumVal(-0.04462450027465819),
ast.NumVal(0.033305134773254384)),
ast.FeatureRef(8),
ast.NumVal(-0.00468258384360457),
ast.CompOpType.LTE),
ast.NumVal(-0.18738342285156248),
ast.NumVal(0.19059675216674812)),
ast.BinNumOpType.ADD),
to_reuse=True)

Expand All @@ -250,18 +250,18 @@ def test_exp_output_transform():
ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.725),
ast.FeatureRef(8),
ast.NumVal(1.0000000180025095e-35),
ast.CompOpType.GT),
ast.NumVal(3.1043985065105892),
ast.NumVal(3.1318783133960197)),
ast.NumVal(5.040167360736721),
ast.NumVal(5.013324518244505)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.8375),
ast.FeatureRef(2),
ast.NumVal(0.00780560282464346),
ast.CompOpType.GT),
ast.NumVal(0.028409619436010138),
ast.NumVal(-0.0060740730485278754)),
ast.NumVal(0.016475080997255653),
ast.NumVal(-0.010346335106608635)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def test_generate_code(pickled_model):

verify_python_model_is_expected(
generated_code,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
expected_output=-44.40540274041321)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
expected_output=11089.941259597403)


def test_function_name(pickled_model):
Expand Down Expand Up @@ -151,5 +151,5 @@ def test_unsupported_args_are_ignored(pickled_model):

verify_python_model_is_expected(
generated_code,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
expected_output=-44.40540274041321)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
expected_output=11089.941259597403)
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def verify_python_model_is_expected(model_code, input, expected_output):

context = {}
exec(code, context)
print(context["result"])

assert np.isclose(context["result"], expected_output)


Expand Down

0 comments on commit 106ddf9

Please sign in to comment.