Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of str.partition() with tuples #1619

Merged
merged 14 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions integration_tests/test_str_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,50 @@ def endswith():
suffix = "apple"
assert s.endswith(suffix) == False

def partition():

# Note: Both string or seperator cannot be empty
# Case 1: When string is constant and seperator is also constant
assert " ".partition(" ") == (""," "," ")
assert "apple mango".partition(" ") == ("apple"," ","mango")
assert "applemango".partition("afdnjkfsn") == ("applemango","","")
assert "applemango".partition("an") == ("applem", "an", "go")
assert "applemango".partition("mango") == ("apple", "mango", "")
assert "applemango".partition("applemango") == ("", "applemango", "")
assert "applemango".partition("ppleman") == ("a", "ppleman", "go")
assert "applemango".partition("pplt") == ("applemango", "", "")

# Case 2: When string is constant and seperator is variable
seperator: str
seperator = " "
assert " ".partition(seperator) == (""," "," ")
seperator = " "
assert "apple mango".partition(seperator) == ("apple"," ","mango")
seperator = "5:30 "
assert " rendezvous 5:30 ".partition(seperator) == (" rendezvous ", "5:30 ", "")
seperator = "^&"
assert "@#$%^&*()#!".partition(seperator) == ("@#$%", "^&", "*()#!")
seperator = "daddada "
assert " rendezvous 5:30 ".partition(seperator) == (" rendezvous 5:30 ", "", "")
seperator = "longer than string"
assert "two plus".partition(seperator) == ("two plus", "", "")

# Case 3: When string is variable and seperator is either constant or variable
s: str
s = "tomorrow"
assert s.partition("apple") == ("tomorrow", "", "")
assert s.partition("rr") == ("tomo", "rr", "ow")
assert s.partition(seperator) == ("tomorrow", "", "")

s = "rendezvous 5"
assert s.partition(" ") == ("rendezvous", " ", "5")
assert s.partition("5") == ("rendezvous ", "5", "")
assert s.partition(s) == ("", "rendezvous 5", "")
seperator = "vous "
assert s.partition(seperator) == ("rendez", "vous ", "5")
seperator = "apple"
assert s.partition(seperator) == ("rendezvous 5", "", "")

def check():
capitalize()
lower()
Expand All @@ -140,5 +184,6 @@ def check():
find()
startswith()
endswith()
partition()

check()
127 changes: 127 additions & 0 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5747,6 +5747,41 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
// Push string and substring argument on top of Vector (or Function Arguments Stack basically)
fn_args.push_back(al, str);
fn_args.push_back(al, suffix);
} else if (attr_name == "partition") {

/*
str.partition(seperator) ---->

Split the string at the first occurrence of sep, and return a 3-tuple containing the part
before the separator, the separator itself, and the part after the separator.
If the separator is not found, return a 3-tuple containing the string itself, followed
by two empty strings.
*/

if(args.size() != 1) {
throw SemanticError("str.partition() takes one argument",
loc);
}

ASR::expr_t *arg_seperator = args[0].m_value;
ASR::ttype_t *arg_seperator_type = ASRUtils::expr_type(arg_seperator);
if (!ASRUtils::is_character(*arg_seperator_type)) {
throw SemanticError("str.partition() takes one argument of type: str",
loc);
}

harshsingh-24 marked this conversation as resolved.
Show resolved Hide resolved
fn_call_name = "_lpython_str_partition";

ASR::call_arg_t str;
str.loc = loc;
str.m_value = s_var;
ASR::call_arg_t seperator;
seperator.loc = loc;
seperator.m_value = args[0].m_value;

fn_args.push_back(al, str);
fn_args.push_back(al, seperator);

} else {
throw SemanticError("String method not implemented: " + attr_name,
loc);
Expand Down Expand Up @@ -5795,6 +5830,75 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
return res;
}

ASR::expr_t* eval_partition(std::string &s_var, ASR::expr_t* arg_seperator,
const Location &loc, ASR::ttype_t *arg_seperator_type) {
/*
Invoked when Seperator argument is provided as a constant string
*/
ASR::StringConstant_t* seperator_constant = ASR::down_cast<ASR::StringConstant_t>(arg_seperator);
std::string seperator = seperator_constant->m_s;
if(seperator.size() == 0) {
throw SemanticError("empty separator", arg_seperator->base.loc);
}
/*
using KMP algorithm to find seperator inside string
res_tuple: stores the resulting 3-tuple expression --->
(if seperator exist) tuple: (left of seperator, seperator, right of seperator)
(if seperator does not exist) tuple: (string, "", "")
res_tuple_type: stores the type of each expression present in resulting 3-tuple
*/
int seperator_pos = KMP_string_match(s_var, seperator);
Vec<ASR::expr_t *> res_tuple;
Vec<ASR::ttype_t *> res_tuple_type;
res_tuple.reserve(al, 3);
res_tuple_type.reserve(al, 3);
std :: string first_res, second_res, third_res;
if(seperator_pos == -1) {
/* seperator does not exist */
first_res = s_var;
second_res = "";
third_res = "";
} else {
first_res = s_var.substr(0, seperator_pos);
second_res = seperator;
third_res = s_var.substr(seperator_pos + seperator.size());
}

res_tuple.push_back(al, ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, first_res), arg_seperator_type)));
res_tuple.push_back(al, ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, second_res), arg_seperator_type)));
res_tuple.push_back(al, ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, third_res), arg_seperator_type)));
res_tuple_type.push_back(al, arg_seperator_type);
res_tuple_type.push_back(al,arg_seperator_type);
res_tuple_type.push_back(al,arg_seperator_type);
ASR::ttype_t *tuple_type = ASRUtils::TYPE(ASR::make_Tuple_t(al, loc, res_tuple_type.p, res_tuple_type.n));
ASR::expr_t* value = ASRUtils::EXPR(ASR::make_TupleConstant_t(al, loc, res_tuple.p, res_tuple.size(), tuple_type));
return value;
}

void create_partition(const Location &loc, std::string &s_var, ASR::expr_t *arg_seperator,
ASR::ttype_t *arg_seperator_type) {

ASR::expr_t *value = nullptr;
if(ASRUtils::expr_value(arg_seperator)) {
value = eval_partition(s_var, arg_seperator, loc, arg_seperator_type);
}
ASR::symbol_t *fn_div = resolve_intrinsic_function(loc, "_lpython_str_partition");
Vec<ASR::call_arg_t> args;
args.reserve(al, 1);
ASR::call_arg_t str_arg;
str_arg.loc = loc;
ASR::ttype_t *str_type = ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, s_var.size(), nullptr, nullptr, 0));
str_arg.m_value = ASRUtils::EXPR(ASR::make_StringConstant_t(al, loc, s2c(al, s_var), str_type));
ASR::call_arg_t sub_arg;
sub_arg.loc = loc;
sub_arg.m_value = arg_seperator;
args.push_back(al, str_arg);
args.push_back(al, sub_arg);
tmp = make_call_helper(al, fn_div, current_scope, args, "_lpython_str_partition", loc);
ASR::down_cast2<ASR::FunctionCall_t>(tmp)->m_value = value;
return;
}

void handle_constant_string_attributes(std::string &s_var,
Vec<ASR::call_arg_t> &args, std::string attr_name, const Location &loc) {
if (attr_name == "capitalize") {
Expand Down Expand Up @@ -5999,6 +6103,29 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = make_call_helper(al, fn_div, current_scope, args, "_lpython_str_endswith", loc);
}
return;
} else if (attr_name == "partition") {
/*
str.partition(seperator) ---->
Split the string at the first occurrence of sep, and return a 3-tuple containing the part
before the separator, the separator itself, and the part after the separator.
If the separator is not found, return a 3-tuple containing the string itself, followed
by two empty strings.
*/
if (args.size() != 1) {
throw SemanticError("str.partition() takes one arguments",
loc);
}
ASR::expr_t *arg_seperator = args[0].m_value;
ASR::ttype_t *arg_seperator_type = ASRUtils::expr_type(arg_seperator);
if (!ASRUtils::is_character(*arg_seperator_type)) {
throw SemanticError("str.partition() takes one arguments of type: str",
arg_seperator->base.loc);
}
if(s_var.size() == 0) {
throw SemanticError("string to undergo partition cannot be empty",loc);
}
create_partition(loc, s_var, arg_seperator, arg_seperator_type);
return;
} else {
throw SemanticError("'str' object has no attribute '" + attr_name + "'",
loc);
Expand Down
3 changes: 2 additions & 1 deletion src/lpython/semantics/python_comptime_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ struct PythonIntrinsicProcedures {
{"_lpython_str_strip", {m_builtin, &not_implemented}},
{"_lpython_str_swapcase", {m_builtin, &not_implemented}},
{"_lpython_str_startswith", {m_builtin, &not_implemented}},
{"_lpython_str_endswith", {m_builtin, &not_implemented}}
{"_lpython_str_endswith", {m_builtin, &not_implemented}},
{"_lpython_str_partition", {m_builtin, &not_implemented}}
};
}

Expand Down
18 changes: 18 additions & 0 deletions src/runtime/lpython_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,24 @@ def _lpython_str_endswith(s: str, suffix: str) -> bool:

return True

@overload
def _lpython_str_partition(s:str, sep: str) -> tuple[str, str, str]:
"""
Returns a 3-tuple splitted around seperator
"""
if(len(s) == 0):
harshsingh-24 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError('empty string cannot be partitioned')
harshsingh-24 marked this conversation as resolved.
Show resolved Hide resolved
if(len(sep) == 0):
harshsingh-24 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError('empty seperator')
harshsingh-24 marked this conversation as resolved.
Show resolved Hide resolved
res : tuple[str, str, str]
ind : i32
ind = _lpython_str_find(s, sep)
if(ind == -1):
harshsingh-24 marked this conversation as resolved.
Show resolved Hide resolved
res = (s, "", "")
else:
res = (s[0:ind], sep, s[ind+len(sep): len(s)])
return res


def list(s: str) -> list[str]:
l: list[str] = []
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-array_01_decl-39cf894.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-array_01_decl-39cf894.stdout",
"stdout_hash": "8a16e5822c8482c4852f99853f41a540874a26b728cfa7e9b673c5a4",
"stdout_hash": "3ab4cfa056d997fdbdcc663d3bb8007db20e5cd96d65b964b3585e5b",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-array_01_decl-39cf894.stdout

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/reference/asr-array_02_decl-e8f6874.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-array_02_decl-e8f6874.stdout",
"stdout_hash": "f12c18aa34d16d8cf0410f0b814b31772763b329843bb762d8ef9c7c",
"stdout_hash": "b5e35cbd9bebc76478eb79a404c87abf7bb7a44b0aee62c525386542",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-array_02_decl-e8f6874.stdout

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/reference/asr-bindc_02-bc1a7ea.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-bindc_02-bc1a7ea.stdout",
"stdout_hash": "6e529f9be21fccf158f130d6e19de9f1f082c54ec129a34a6474d027",
"stdout_hash": "a6d61c15dcefd11bf7f11a0eea824e9584afc245e97407914db7527c",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-bindc_02-bc1a7ea.stdout

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/reference/asr-cast-435c233.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-cast-435c233.stdout",
"stdout_hash": "70edccbfb356b0b020c44c491d975bd26b6629a825b207805c7bad99",
"stdout_hash": "b70656e1110deb358c9eea9b88c04b0755b95fc3ba7ce19ff9e2093d",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-cast-435c233.stdout
Original file line number Diff line number Diff line change
@@ -1 +1 @@
(TranslationUnit (SymbolTable 1 {_global_symbols: (Module (SymbolTable 109 {_lpython_main_program: (Function (SymbolTable 108 {}) _lpython_main_program (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [f] [] [(SubroutineCall 109 f () [] ())] () Public .false. .false.), f: (Function (SymbolTable 2 {list: (ExternalSymbol 2 list 4 list lpython_builtin [] list Private), s: (Variable 2 s [] Local () () Default (Character 1 -2 () []) Source Public Required .false.), x: (Variable 2 x [] Local () () Default (List (Character 1 -2 () [])) Source Public Required .false.), y: (Variable 2 y [] Local () () Default (List (Character 1 -2 () [])) Source Public Required .false.)}) f (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [list list list] [] [(= (Var 2 s) (StringConstant "lpython" (Character 1 7 () [])) ()) (= (Var 2 x) (FunctionCall 2 list () [((Var 2 s))] (List (Character 1 -2 () [])) () ()) ()) (= (Var 2 y) (ListConstant [(StringConstant "a" (Character 1 1 () [])) (StringConstant "b" (Character 1 1 () [])) (StringConstant "c" (Character 1 1 () []))] (List (Character 1 1 () []))) ()) (= (Var 2 x) (FunctionCall 2 list () [((Var 2 y))] (List (Character 1 -2 () [])) () ()) ()) (= (Var 2 x) (FunctionCall 2 list () [((StringConstant "lpython" (Character 1 7 () [])))] (List (Character 1 -2 () [])) (ListConstant [(StringConstant "l" (Character 1 1 () [])) (StringConstant "p" (Character 1 1 () [])) (StringConstant "y" (Character 1 1 () [])) (StringConstant "t" (Character 1 1 () [])) (StringConstant "h" (Character 1 1 () [])) (StringConstant "o" (Character 1 1 () [])) (StringConstant "n" (Character 1 1 () []))] (List (Character 1 1 () []))) ()) ())] () Public .false. .false.)}) _global_symbols [lpython_builtin] .false. .false.), lpython_builtin: (IntrinsicModule lpython_builtin), main_program: (Program (SymbolTable 107 {_lpython_main_program: (ExternalSymbol 107 _lpython_main_program 109 _lpython_main_program _global_symbols [] _lpython_main_program Public)}) main_program [_global_symbols] [(SubroutineCall 107 _lpython_main_program () [] ())])}) [])
(TranslationUnit (SymbolTable 1 {_global_symbols: (Module (SymbolTable 110 {_lpython_main_program: (Function (SymbolTable 109 {}) _lpython_main_program (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [f] [] [(SubroutineCall 110 f () [] ())] () Public .false. .false.), f: (Function (SymbolTable 2 {list: (ExternalSymbol 2 list 4 list lpython_builtin [] list Private), s: (Variable 2 s [] Local () () Default (Character 1 -2 () []) Source Public Required .false.), x: (Variable 2 x [] Local () () Default (List (Character 1 -2 () [])) Source Public Required .false.), y: (Variable 2 y [] Local () () Default (List (Character 1 -2 () [])) Source Public Required .false.)}) f (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [list list list] [] [(= (Var 2 s) (StringConstant "lpython" (Character 1 7 () [])) ()) (= (Var 2 x) (FunctionCall 2 list () [((Var 2 s))] (List (Character 1 -2 () [])) () ()) ()) (= (Var 2 y) (ListConstant [(StringConstant "a" (Character 1 1 () [])) (StringConstant "b" (Character 1 1 () [])) (StringConstant "c" (Character 1 1 () []))] (List (Character 1 1 () []))) ()) (= (Var 2 x) (FunctionCall 2 list () [((Var 2 y))] (List (Character 1 -2 () [])) () ()) ()) (= (Var 2 x) (FunctionCall 2 list () [((StringConstant "lpython" (Character 1 7 () [])))] (List (Character 1 -2 () [])) (ListConstant [(StringConstant "l" (Character 1 1 () [])) (StringConstant "p" (Character 1 1 () [])) (StringConstant "y" (Character 1 1 () [])) (StringConstant "t" (Character 1 1 () [])) (StringConstant "h" (Character 1 1 () [])) (StringConstant "o" (Character 1 1 () [])) (StringConstant "n" (Character 1 1 () []))] (List (Character 1 1 () []))) ()) ())] () Public .false. .false.)}) _global_symbols [lpython_builtin] .false. .false.), lpython_builtin: (IntrinsicModule lpython_builtin), main_program: (Program (SymbolTable 108 {_lpython_main_program: (ExternalSymbol 108 _lpython_main_program 110 _lpython_main_program _global_symbols [] _lpython_main_program Public)}) main_program [_global_symbols] [(SubroutineCall 108 _lpython_main_program () [] ())])}) [])
2 changes: 1 addition & 1 deletion tests/reference/asr-complex1-f26c460.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-complex1-f26c460.stdout",
"stdout_hash": "f062aa258b34f58a45dd65e8278772c13afcdb1e5be88bdf9a798d42",
"stdout_hash": "bb1714d801538a68149bf48a623a8991ce7d8fa755f1cf6194e8f743",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
Loading