diff --git a/tangos/live_calculation/builtin_functions/arithmetic.py b/tangos/live_calculation/builtin_functions/arithmetic.py index 8561ec68..72253b3f 100644 --- a/tangos/live_calculation/builtin_functions/arithmetic.py +++ b/tangos/live_calculation/builtin_functions/arithmetic.py @@ -44,6 +44,18 @@ def greater(halos, vals1, vals2): def less(halos, vals1, vals2): return arithmetic_binary_op(vals1, vals2, np.less) +@BuiltinFunction.register +def equal(halos, vals1, vals2): + return arithmetic_binary_op(vals1, vals2, np.equal) + +@BuiltinFunction.register +def greater_equal(halos, vals1, vals2): + return arithmetic_binary_op(vals1, vals2, np.greater_equal) + +@BuiltinFunction.register +def less_equal(halos, vals1, vals2): + return arithmetic_binary_op(vals1, vals2, np.less_equal) + @BuiltinFunction.register def logical_and(halos, vals1, vals2): return arithmetic_binary_op(vals1, vals2, np.logical_and) diff --git a/tangos/live_calculation/parser.py b/tangos/live_calculation/parser.py index 1ef38fca..7a3d7459 100644 --- a/tangos/live_calculation/parser.py +++ b/tangos/live_calculation/parser.py @@ -25,7 +25,10 @@ def pack_args(for_function): (">", "greater"), ("<", "less"), ("|", "logical_or"), - ("&", "logical_and")] + ("&", "logical_and"), + ("==", "equal"), + (">=", "greater_equal"), + ("<=", "less_equal")] UNARY_OPS = [("!", "logical_not")] diff --git a/tests/test_live_calculation.py b/tests/test_live_calculation.py index 9d4ec278..4a40a04b 100644 --- a/tests/test_live_calculation.py +++ b/tests/test_live_calculation.py @@ -185,6 +185,17 @@ def test_arithmetic(): assert h.calculate("at(1.0,dummy_property_1)*at(5.0,dummy_property_1)") ==\ h.calculate("at(1.0,dummy_property_1)") * h.calculate("at(5.0,dummy_property_1)") +def test_comparison(): + h = tangos.get_halo("sim/ts1/1") + assert h.calculate("1.0<2.0") + assert not h.calculate("1.0>2.0") + assert h.calculate("1.0==1.0") + assert h.calculate("1.0>=1.0") + assert h.calculate("1.0<=1.0") + assert h.calculate("1.0>=0.5") + assert not h.calculate("1.0<=0.5") + + def test_calculate_array(): h = tangos.get_halo("sim/ts1/1")