Skip to content

Commit

Permalink
Merge pull request #2245 from Shaikh-Ubaid/wasm_string_cmp
Browse files Browse the repository at this point in the history
WASM: Support string comparison
  • Loading branch information
Shaikh-Ubaid authored Aug 2, 2023
2 parents 9aefff4 + 45668c2 commit 7f2048e
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 10 deletions.
2 changes: 1 addition & 1 deletion integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ RUN(NAME test_vars_01 LABELS cpython llvm)
RUN(NAME test_version LABELS cpython llvm)
RUN(NAME logical_binop1 LABELS cpython llvm)
RUN(NAME vec_01 LABELS cpython llvm c NOFAST)
RUN(NAME test_str_comparison LABELS cpython llvm c)
RUN(NAME test_str_comparison LABELS cpython llvm c wasm)
RUN(NAME test_bit_length LABELS cpython llvm c)
RUN(NAME str_to_list_cast LABELS cpython llvm c)
RUN(NAME cast_01 LABELS cpython llvm c)
Expand Down
14 changes: 8 additions & 6 deletions integration_tests/test_str_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ def f():
assert s1 <= s2
assert s1 >= s2
s1 = "abcde"
assert s1 >= s2
assert s1 >= s2
assert s1 > s2
s1 = "abc"
assert s1 < s2
assert s1 < s2
assert s1 <= s2
s1 = "Abcd"
s2 = "abcd"
assert s1 < s2
assert s1 < s2
s1 = "orange"
s2 = "apple"
assert s1 >= s2
assert s1 > s2
assert s1 >= s2
assert s1 > s2
s1 = "albatross"
s2 = "albany"
assert s1 >= s2
Expand All @@ -28,9 +28,11 @@ def f():
assert s1 < s2
assert s1 != s2
s1 = "Zebra"
s2 = "ant"
s2 = "ant"
assert s1 <= s2
assert s1 < s2
assert s1 != s2

print("Ok")

f()
136 changes: 133 additions & 3 deletions src/libasr/codegen/asr_to_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ enum RT_FUNCS {
abs_c64 = 10,
equal_c32 = 11,
equal_c64 = 12,
NO_OF_RT_FUNCS = 13,
string_cmp = 13,
NO_OF_RT_FUNCS = 14,
};

enum GLOBAL_VAR {
Expand Down Expand Up @@ -552,6 +553,93 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
});
}

void emit_string_cmp() {
using namespace wasm;
m_wa.define_func({i32, i32}, {i32}, {i32, i32, i32, i32, i32, i32}, "string_cmp", [&](){
/*
local 0 (param 0): string 1 (s1)
local 1 (param 1): string 2 (s2)
local 2: len(s1)
local 3: len(s2)
local 4: min(len(s1), len(s2))
local 5: loop variable
local 6: temp variable to store s1[i] - s2[i]
local 7: return variable
*/

m_wa.emit_local_get(0);
m_wa.emit_i32_load(mem_align::b8, 4);
m_wa.emit_local_set(2);

m_wa.emit_local_get(1);
m_wa.emit_i32_load(mem_align::b8, 4);
m_wa.emit_local_set(3);

m_wa.emit_if_else([&](){
m_wa.emit_local_get(2);
m_wa.emit_local_get(3);
m_wa.emit_i32_le_s();
}, [&](){
m_wa.emit_local_get(2);
m_wa.emit_local_set(4);
}, [&](){
m_wa.emit_local_get(3);
m_wa.emit_local_set(4);
});

m_wa.emit_i32_const(0);
m_wa.emit_local_set(5);

m_wa.emit_loop([&](){
m_wa.emit_local_get(5);
m_wa.emit_local_get(4);
m_wa.emit_i32_lt_s();
}, [&](){
m_wa.emit_local_get(0);
m_wa.emit_local_get(5);
m_wa.emit_i32_add();
m_wa.emit_i32_load8_u(mem_align::b8, 8);

m_wa.emit_local_get(1);
m_wa.emit_local_get(5);
m_wa.emit_i32_add();
m_wa.emit_i32_load8_u(mem_align::b8, 8);

m_wa.emit_i32_sub();
m_wa.emit_local_set(6);

m_wa.emit_local_get(6);
m_wa.emit_i32_const(0);
m_wa.emit_i32_ne();

// branch to end of if, if char diff not equal to 0
m_wa.emit_br_if(m_wa.nest_lvl - m_wa.cur_loop_nest_lvl - 2U);

m_wa.emit_local_get(5);
m_wa.emit_i32_const(1);
m_wa.emit_i32_add();
m_wa.emit_local_set(5);
});

m_wa.emit_if_else([&](){
m_wa.emit_local_get(5);
m_wa.emit_local_get(4);
m_wa.emit_i32_lt_s();
}, [&](){
m_wa.emit_local_get(6);
m_wa.emit_local_set(7);
}, [&](){
m_wa.emit_local_get(2);
m_wa.emit_local_get(3);
m_wa.emit_i32_sub();
m_wa.emit_local_set(7);
});

m_wa.emit_local_get(7);
m_wa.emit_return();
});
}

void declare_global_var(ASR::Variable_t* v) {
if (v->m_type->type == ASR::ttypeType::TypeParameter) {
// Ignore type variables
Expand Down Expand Up @@ -688,6 +776,7 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
m_rt_funcs_map[abs_c64] = &ASRToWASMVisitor::emit_complex_abs_64;
m_rt_funcs_map[equal_c32] = &ASRToWASMVisitor::emit_complex_equal_32;
m_rt_funcs_map[equal_c64] = &ASRToWASMVisitor::emit_complex_equal_64;
m_rt_funcs_map[string_cmp] = &ASRToWASMVisitor::emit_string_cmp;

{
// Pre-declare all functions first, then generate code
Expand Down Expand Up @@ -1915,6 +2004,47 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
}
}

void handle_string_compare(const ASR::StringCompare_t &x) {
if (x.m_value) {
visit_expr(*x.m_value);
return;
}
INCLUDE_RUNTIME_FUNC(string_cmp);
this->visit_expr(*x.m_left);
this->visit_expr(*x.m_right);
m_wa.emit_call(m_rt_func_used_idx[string_cmp]);
m_wa.emit_i32_const(0);
switch (x.m_op) {
case (ASR::cmpopType::Eq): {
m_wa.emit_i32_eq();
break;
}
case (ASR::cmpopType::Gt): {
m_wa.emit_i32_gt_s();
break;
}
case (ASR::cmpopType::GtE): {
m_wa.emit_i32_ge_s();
break;
}
case (ASR::cmpopType::Lt): {
m_wa.emit_i32_lt_s();
break;
}
case (ASR::cmpopType::LtE): {
m_wa.emit_i32_le_s();
break;
}
case (ASR::cmpopType::NotEq): {
m_wa.emit_i32_ne();
break;
}
default:
throw CodeGenError(
"handle_string_compare: ICE: Unknown string comparison operator");
}
}

void visit_IntegerCompare(const ASR::IntegerCompare_t &x) {
handle_integer_compare(x);
}
Expand All @@ -1931,8 +2061,8 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
handle_integer_compare(x);
}

void visit_StringCompare(const ASR::StringCompare_t & /*x*/) {
throw CodeGenError("String Types not yet supported");
void visit_StringCompare(const ASR::StringCompare_t &x) {
handle_string_compare(x);
}

void visit_StringLen(const ASR::StringLen_t & x) {
Expand Down

0 comments on commit 7f2048e

Please sign in to comment.