Skip to content

Commit

Permalink
escape
Browse files Browse the repository at this point in the history
  • Loading branch information
Koncopd committed Dec 9, 2024
1 parent 2b9d4fc commit fc6417f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
33 changes: 26 additions & 7 deletions lamin_utils/_search.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pandas import DataFrame, Series


# needed to filter out everything that doesn't contain `string`
def _contains(col: Series, string: str, case_sensitive: bool, fields_convert: dict):
if col.name not in fields_convert:
return [False] * len(col)
Expand All @@ -14,26 +16,41 @@ def _contains(col: Series, string: str, case_sensitive: bool, fields_convert: di
return col.str.contains(string, case=case_sensitive)


def _ranks(col: Series, string: str, case_sensitive: bool, fields_convert: dict):
# apply ranking based on rules
# `string` - search query
# `string_regex` - regex-escaped search query; it is passed in by the caller
# so the escaping is computed once rather than on every call
def _ranks(
col: Series,
string: str,
string_regex: str,
case_sensitive: bool,
fields_convert: dict,
):
if col.name not in fields_convert:
return [0] * len(col)
if fields_convert[col.name]:
col = col.astype(str)
exact_rank = col.str.fullmatch(string, case=case_sensitive) * 200
synonym_rank = (
col.str.match(rf"(?:^|.*\|){string}(?:\|.*|$)", case=case_sensitive) * 200
col.str.match(rf"(?:^|.*\|){string_regex}(?:\|.*|$)", case=case_sensitive) * 200
)
sub_rank = (
col.str.match(
rf"(?:^|.*[ \|\.,;:]){string}(?:[ \|\.,;:].*|$)", case=case_sensitive
rf"(?:^|.*[ \|\.,;:]){string_regex}(?:[ \|\.,;:].*|$)", case=case_sensitive
)
* 10
)
startswith_rank = (
col.str.match(rf"(?:^|.*\|){string}[^ ]*(?:\|.*|$)", case=case_sensitive) * 8
col.str.match(rf"(?:^|.*\|){string_regex}[^ ]*(?:\|.*|$)", case=case_sensitive)
* 8
)
right_rank = (
col.str.match(rf"(?:^|.*[ \|]){string_regex}.*", case=case_sensitive) * 2
)
left_rank = (
col.str.match(rf".*{string_regex}(?:$|[ \|\.,;:].*)", case=case_sensitive) * 2
)
right_rank = col.str.match(rf"(?:^|.*[ \|]){string}.*", case=case_sensitive) * 2
left_rank = col.str.match(rf".*{string}(?:$|[ \|\.,;:].*)", case=case_sensitive) * 2
contains_rank = col.str.contains(string, case=case_sensitive).astype("int32")
return (
exact_rank
Expand Down Expand Up @@ -97,7 +114,9 @@ def search(
if len(df_contains) == 0:
return df_contains

ranks = lambda col: _ranks(col, string, case_sensitive, fields_convert)
ranks = lambda col: _ranks(
col, string, re.escape(string), case_sensitive, fields_convert
)
rank = df_contains.apply(ranks).sum(axis=1)

if _show_rank:
Expand Down
17 changes: 17 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def df():
"description": "A Specialized Cardiac Myocyte In The Sinoatrial And Atrioventricular Nodes. The Cell Is Slender And Fusiform Confined To The Nodal Center, Circumferentially Arranged Around The Nodal Artery.",
"children": ["CL:1000409", "CL:1000410"],
},
{
"ontology_id": "",
"name": "cat[*_*]",
"synonyms": "",
"description": "",
"children": [],
},
]
return pd.DataFrame.from_records(records)

Expand Down Expand Up @@ -101,3 +108,13 @@ def test_search_case_sensitive(df):
def test_search_empty_df():
    """An empty three-column frame searched with an empty query has shape (0, 3)."""
    empty_frame = pd.DataFrame(columns=["a", "b", "c"])
    result = search(empty_frame, string="")
    n_rows, n_cols = result.shape
    assert (n_rows, n_cols) == (0, 3)


def test_escape_string(df):
    """Queries containing regex metacharacters match the literal text."""
    # `df` is taken as an argument, like the other tests in this file; a bare
    # `df` in the body would refer to the module-level `df` function itself
    # rather than the DataFrame it returns.

    # An unbalanced `[` is not a valid regular expression on its own.
    res = search(df=df, string="cat[")
    assert len(res) == 1
    assert res.iloc[0]["name"] == "cat[*_*]"

    # A leading `*` is a quantifier with nothing to repeat.
    res = search(df=df, string="*_*")
    assert len(res) == 1
    assert res.iloc[0]["name"] == "cat[*_*]"

0 comments on commit fc6417f

Please sign in to comment.