From 4f1219d74483773e89dbd89729cd7fcd01313d8c Mon Sep 17 00:00:00 2001 From: Shaikh Ubaid Date: Wed, 2 Aug 2023 17:57:38 +0530 Subject: [PATCH 1/2] WASM: Support string comparison --- src/libasr/codegen/asr_to_wasm.cpp | 136 ++++++++++++++++++++++++++++- 1 file changed, 133 insertions(+), 3 deletions(-) diff --git a/src/libasr/codegen/asr_to_wasm.cpp b/src/libasr/codegen/asr_to_wasm.cpp index 93650918dc..9dc60dcbf7 100644 --- a/src/libasr/codegen/asr_to_wasm.cpp +++ b/src/libasr/codegen/asr_to_wasm.cpp @@ -77,7 +77,8 @@ enum RT_FUNCS { abs_c64 = 10, equal_c32 = 11, equal_c64 = 12, - NO_OF_RT_FUNCS = 13, + string_cmp = 13, + NO_OF_RT_FUNCS = 14, }; enum GLOBAL_VAR { @@ -552,6 +553,93 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { }); } + void emit_string_cmp() { + using namespace wasm; + m_wa.define_func({i32, i32}, {i32}, {i32, i32, i32, i32, i32, i32}, "string_cmp", [&](){ + /* + local 0 (param 0): string 1 (s1) + local 1 (param 1): string 2 (s2) + local 2: len(s1) + local 3: len(s2) + local 4: min(len(s1), len(s2)) + local 5: loop variable + local 6: temp variable to store s1[i] - s2[i] + local 7: return variable + */ + + m_wa.emit_local_get(0); + m_wa.emit_i32_load(mem_align::b8, 4); + m_wa.emit_local_set(2); + + m_wa.emit_local_get(1); + m_wa.emit_i32_load(mem_align::b8, 4); + m_wa.emit_local_set(3); + + m_wa.emit_if_else([&](){ + m_wa.emit_local_get(2); + m_wa.emit_local_get(3); + m_wa.emit_i32_le_s(); + }, [&](){ + m_wa.emit_local_get(2); + m_wa.emit_local_set(4); + }, [&](){ + m_wa.emit_local_get(3); + m_wa.emit_local_set(4); + }); + + m_wa.emit_i32_const(0); + m_wa.emit_local_set(5); + + m_wa.emit_loop([&](){ + m_wa.emit_local_get(5); + m_wa.emit_local_get(4); + m_wa.emit_i32_lt_s(); + }, [&](){ + m_wa.emit_local_get(0); + m_wa.emit_local_get(5); + m_wa.emit_i32_add(); + m_wa.emit_i32_load8_u(mem_align::b8, 8); + + m_wa.emit_local_get(1); + m_wa.emit_local_get(5); + m_wa.emit_i32_add(); + m_wa.emit_i32_load8_u(mem_align::b8, 8); + + m_wa.emit_i32_sub(); + m_wa.emit_local_set(6); + + m_wa.emit_local_get(6); + m_wa.emit_i32_const(0); + m_wa.emit_i32_ne(); + + // branch to end of if, if char diff not equal to 0 + m_wa.emit_br_if(m_wa.nest_lvl - m_wa.cur_loop_nest_lvl - 2U); + + m_wa.emit_local_get(5); + m_wa.emit_i32_const(1); + m_wa.emit_i32_add(); + m_wa.emit_local_set(5); + }); + + m_wa.emit_if_else([&](){ + m_wa.emit_local_get(5); + m_wa.emit_local_get(4); + m_wa.emit_i32_lt_s(); + }, [&](){ + m_wa.emit_local_get(6); + m_wa.emit_local_set(7); + }, [&](){ + m_wa.emit_local_get(2); + m_wa.emit_local_get(3); + m_wa.emit_i32_sub(); + m_wa.emit_local_set(7); + }); + + m_wa.emit_local_get(7); + m_wa.emit_return(); + }); + } + void declare_global_var(ASR::Variable_t* v) { if (v->m_type->type == ASR::ttypeType::TypeParameter) { // Ignore type variables @@ -688,6 +776,7 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { m_rt_funcs_map[abs_c64] = &ASRToWASMVisitor::emit_complex_abs_64; m_rt_funcs_map[equal_c32] = &ASRToWASMVisitor::emit_complex_equal_32; m_rt_funcs_map[equal_c64] = &ASRToWASMVisitor::emit_complex_equal_64; + m_rt_funcs_map[string_cmp] = &ASRToWASMVisitor::emit_string_cmp; { // Pre-declare all functions first, then generate code @@ -1915,6 +2004,47 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { } } + void handle_string_compare(const ASR::StringCompare_t &x) { + if (x.m_value) { + visit_expr(*x.m_value); + return; + } + INCLUDE_RUNTIME_FUNC(string_cmp); + this->visit_expr(*x.m_left); + this->visit_expr(*x.m_right); + m_wa.emit_call(m_rt_func_used_idx[string_cmp]); + m_wa.emit_i32_const(0); + switch (x.m_op) { + case (ASR::cmpopType::Eq): { + m_wa.emit_i32_eq(); + break; + } + case (ASR::cmpopType::Gt): { + m_wa.emit_i32_gt_s(); + break; + } + case (ASR::cmpopType::GtE): { + m_wa.emit_i32_ge_s(); + break; + } + case (ASR::cmpopType::Lt): { + m_wa.emit_i32_lt_s(); + break; + } + case (ASR::cmpopType::LtE): { + m_wa.emit_i32_le_s(); + break; + } + case (ASR::cmpopType::NotEq): { + m_wa.emit_i32_ne(); + break; + } + default: + throw CodeGenError( + "handle_string_compare: ICE: Unknown string comparison operator"); + } + } + void visit_IntegerCompare(const ASR::IntegerCompare_t &x) { handle_integer_compare(x); } @@ -1931,8 +2061,8 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { handle_integer_compare(x); } - void visit_StringCompare(const ASR::StringCompare_t & /*x*/) { - throw CodeGenError("String Types not yet supported"); + void visit_StringCompare(const ASR::StringCompare_t &x) { + handle_string_compare(x); } void visit_StringLen(const ASR::StringLen_t & x) { From 45668c2b34274acdb16e13b4ada312d0cc5b4486 Mon Sep 17 00:00:00 2001 From: Shaikh Ubaid Date: Wed, 2 Aug 2023 17:59:37 +0530 Subject: [PATCH 2/2] TEST: Enable string comparison test for wasm --- integration_tests/CMakeLists.txt | 2 +- integration_tests/test_str_comparison.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index eb0d9b457e..affa4bd57c 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -708,7 +708,7 @@ RUN(NAME test_vars_01 LABELS cpython llvm) RUN(NAME test_version LABELS cpython llvm) RUN(NAME logical_binop1 LABELS cpython llvm) RUN(NAME vec_01 LABELS cpython llvm c NOFAST) -RUN(NAME test_str_comparison LABELS cpython llvm c) +RUN(NAME test_str_comparison LABELS cpython llvm c wasm) RUN(NAME test_bit_length LABELS cpython llvm c) RUN(NAME str_to_list_cast LABELS cpython llvm c) RUN(NAME cast_01 LABELS cpython llvm c) diff --git a/integration_tests/test_str_comparison.py b/integration_tests/test_str_comparison.py index 4ac40aaa78..3108c224e0 100644 --- a/integration_tests/test_str_comparison.py +++ b/integration_tests/test_str_comparison.py @@ -5,18 +5,18 @@ def f(): assert s1 <= s2 assert s1 >= s2 s1 = "abcde" - assert s1 >= s2 + assert s1 >= s2 assert s1 > s2 s1 = "abc" - assert s1 < s2 + assert s1 < s2 assert s1 <= s2 s1 = "Abcd" s2 = "abcd" - assert s1 < s2 + assert s1 < s2 s1 = "orange" s2 = "apple" - assert s1 >= s2 - assert s1 > s2 + assert s1 >= s2 + assert s1 > s2 s1 = "albatross" s2 = "albany" assert s1 >= s2 @@ -28,9 +28,11 @@ def f(): assert s1 < s2 assert s1 != s2 s1 = "Zebra" - s2 = "ant" + s2 = "ant" assert s1 <= s2 assert s1 < s2 assert s1 != s2 + print("Ok") + f()