Skip to content

Commit

Permalink
Merge pull request #2802 from advikkabra/empty-membership
Browse files Browse the repository at this point in the history
Prevent the hashing function being called when capacity is zero
  • Loading branch information
czgdp1807 authored Aug 18, 2024
2 parents ba2dff6 + a010b04 commit a104c30
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
12 changes: 12 additions & 0 deletions integration_tests/test_membership_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ def test_int_dict():
i = 4
assert (i in a)

a = {}
assert (1 not in a)

def test_str_dict():
a: dict[str, str] = {'a':'1', 'b':'2', 'c':'3'}
i: str
Expand All @@ -14,6 +17,9 @@ def test_str_dict():
i = 'c'
assert (i in a)

a = {}
assert ('a' not in a)

def test_int_set():
a: set[i32] = {1, 2, 3, 4}
i: i32
Expand All @@ -22,6 +28,9 @@ def test_int_set():
i = 4
assert (i in a)

a = set()
assert (1 not in a)

def test_str_set():
a: set[str] = {'a', 'b', 'c', 'e', 'f'}
i: str
Expand All @@ -30,6 +39,9 @@ def test_str_set():
i = 'c'
assert (i in a)

a = set()
assert ('a' not in a)

test_int_dict()
test_str_dict()
test_int_set()
Expand Down
28 changes: 22 additions & 6 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1726,9 +1726,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
ptr_loads = ptr_loads_copy;
llvm::Value *capacity = LLVM::CreateLoad(*builder,
llvm_utils->dict_api->get_pointer_to_capacity(right));
llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module);

tmp = llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true);
get_builder0();
llvm::AllocaInst *res = builder0.CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
llvm_utils->create_if_else(builder->CreateICmpEQ(
capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))),
[&]() {
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)), res);
}, [&]() {
llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module);
LLVM::CreateStore(*builder, llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true), res);
});
tmp = LLVM::CreateLoad(*builder, res);
}

void visit_SetContains(const ASR::SetContains_t &x) {
Expand All @@ -1748,9 +1756,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
ptr_loads = ptr_loads_copy;
llvm::Value *capacity = LLVM::CreateLoad(*builder,
llvm_utils->set_api->get_pointer_to_capacity(right));
llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module);

tmp = llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true);
get_builder0();
llvm::AllocaInst *res = builder0.CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
llvm_utils->create_if_else(builder->CreateICmpEQ(
capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))),
[&]() {
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)), res);
}, [&]() {
llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module);
LLVM::CreateStore(*builder, llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true), res);
});
tmp = LLVM::CreateLoad(*builder, res);
}

void visit_DictLen(const ASR::DictLen_t& x) {
Expand Down

0 comments on commit a104c30

Please sign in to comment.