Skip to content

Commit

Permalink
Apply some of mikea's review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hoodmane committed Dec 18, 2024
1 parent 28d804c commit cf7021d
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 87 deletions.
4 changes: 2 additions & 2 deletions deps/rust/cargo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ PACKAGES = {
"pico-args": crate.spec(version = "0"),
"proc-macro2": crate.spec(version = "1"),
"quote": crate.spec(version = "1"),
"ruff_python_ast": crate.spec(git = "https://github.com/astral-sh/ruff.git", tag = "0.7.0"),
"ruff_python_parser": crate.spec(git = "https://github.com/astral-sh/ruff.git", tag = "0.7.0"),
"ruff_python_ast": crate.spec(git = "https://github.com/astral-sh/ruff.git", tag = "0"),
"ruff_python_parser": crate.spec(git = "https://github.com/astral-sh/ruff.git", tag = "0"),
"serde_json": crate.spec(version = "1"),
"serde": crate.spec(version = "1", features = ["derive"]),
"syn": crate.spec(version = "2"),
Expand Down
8 changes: 8 additions & 0 deletions src/rust/cxx-integration/cxx-bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ inline kj::ArrayPtr<const char> fromRust(const ::rust::Str& str) {
return kj::ArrayPtr<const char>(str.data(), str.size());
}

inline kj::Array<kj::String> fromRust(::rust::Vec<::rust::String> vec) {
auto res = kj::heapArrayBuilder<kj::String>(vec.size());
for (auto& entry: vec) {
res.add(kj::str(entry.c_str()));
}
return res.finish();
}

struct Rust {
template <typename T>
static ::rust::Slice<const T> from(const kj::ArrayPtr<T>* arr) {
Expand Down
1 change: 1 addition & 0 deletions src/rust/python-parser/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ kj_test(
src = "import_parsing.c++",
deps = [
":python-parser",
"//src/rust/cxx-integration",
],
)
111 changes: 42 additions & 69 deletions src/rust/python-parser/import_parsing.c++
Original file line number Diff line number Diff line change
Expand Up @@ -3,125 +3,101 @@
// https://opensource.org/licenses/Apache-2.0

#include <workerd/rust/python-parser/lib.rs.h>
#include "workerd/rust/cxx-integration/lib.rs.h"

#include <kj/test.h>

using ::edgeworker::rust::python_parser::get_imports;

kj::Array<kj::String> parseImports(kj::Array<kj::String> cpp_modules) {
::rust::Vec<::rust::String> rust_modules;
rust_modules.reserve(cpp_modules.size());
kj::Array<kj::String> parseImports(kj::ArrayPtr<kj::StringPtr> cpp_modules) {
auto rust_modules = kj::heapArrayBuilder<::rust::Str const>(cpp_modules.size());
for (auto& entry: cpp_modules) {
rust_modules.push_back(entry.cStr());
rust_modules.add(entry.cStr());
}
auto rust_result = get_imports(rust_modules);
auto cpp_result = kj::heapArrayBuilder<kj::String>(rust_result.size());
for (auto& entry: rust_result) {
cpp_result.add(kj::str(entry.c_str()));
}
return cpp_result.finish();
::rust::Slice<::rust::Str const> rust_slice(rust_modules.begin(), rust_modules.size());
auto rust_result = get_imports(rust_slice);
return workerd::fromRust(rust_result);
}

namespace workerd::api {
namespace {

KJ_TEST("basic `import` tests") {
auto files = kj::heapArrayBuilder<kj::String>(2);
files.add(kj::str("import a\nimport z"));
files.add(kj::str("import b"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("import a\nimport z"_kj, "import b"_kj));
KJ_REQUIRE(result.size() == 3);
KJ_REQUIRE(result[0] == "a");
KJ_REQUIRE(result[1] == "b");
KJ_REQUIRE(result[2] == "z");
}

KJ_TEST("supports whitespace") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str("import a\nimport \\\n\tz"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("import a\nimport \\\n\tz"_kj));
KJ_REQUIRE(result.size() == 2);
KJ_REQUIRE(result[0] == "a");
KJ_REQUIRE(result[1] == "z");
}

KJ_TEST("supports windows newlines") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str("import a\r\nimport \\\r\n\tz"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("import a\r\nimport \\\r\n\tz"_kj));
KJ_REQUIRE(result.size() == 2);
KJ_REQUIRE(result[0] == "a");
KJ_REQUIRE(result[1] == "z");
}

KJ_TEST("basic `from` test") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str("from x import a,b\nfrom z import y"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("from x import a,b\nfrom z import y"_kj));
KJ_REQUIRE(result.size() == 2);
KJ_REQUIRE(result[0] == "x");
KJ_REQUIRE(result[1] == "z");
}

KJ_TEST("ignores indented blocks") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str("import a\nif True:\n import x\nimport y"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("import a\nif True:\n import x\nimport y"_kj));
KJ_REQUIRE(result.size() == 2);
KJ_REQUIRE(result[0] == "a");
KJ_REQUIRE(result[1] == "y");
}

KJ_TEST("supports nested imports") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str("import a.b\nimport z.x.y.i"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("import a.b\nimport z.x.y.i"_kj));
KJ_REQUIRE(result.size() == 2);
KJ_REQUIRE(result[0] == "a.b");
KJ_REQUIRE(result[1] == "z.x.y.i");
}

KJ_TEST("nested `from` test") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str("from x.y.z import a,b\nfrom z import y"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("from x.y.z import a,b\nfrom z import y"_kj));
KJ_REQUIRE(result.size() == 2);
KJ_REQUIRE(result[0] == "x.y.z");
KJ_REQUIRE(result[1] == "z");
}

KJ_TEST("ignores trailing period") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str("import a.b.\nimport z.x.y.i."));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("import a.b.\nimport z.x.y.i."_kj));
KJ_REQUIRE(result.size() == 0);
}

KJ_TEST("ignores relative import") {
// This is where we diverge from the old AST-based approach. It would have returned `y` in the
// input below.
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str("import .a.b\nimport ..z.x\nfrom .y import x"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("import .a.b\nimport ..z.x\nfrom .y import x"_kj));
KJ_REQUIRE(result.size() == 0);
}

KJ_TEST("supports commas") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str("import a,b"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr("import a,b"_kj));
KJ_REQUIRE(result.size() == 2);
KJ_REQUIRE(result[0] == "a");
KJ_REQUIRE(result[1] == "b");
}

KJ_TEST("supports backslash") {
auto files = kj::heapArrayBuilder<kj::String>(4);
files.add(kj::str("import a\\\n,b"));
files.add(kj::str("import\\\n q,w"));
files.add(kj::str("from \\\nx import y"));
files.add(kj::str("from \\\n c import y"));
auto result = parseImports(files.finish());
auto result = parseImports(kj::arr(
"import a\\\n,b"_kj,
"import\\\n q,w"_kj,
"from \\\nx import y"_kj,
"from \\\n c import y"_kj
));
KJ_REQUIRE(result.size() == 6);
KJ_REQUIRE(result[0] == "a");
KJ_REQUIRE(result[1] == "b");
Expand All @@ -132,32 +108,31 @@ KJ_TEST("supports backslash") {
}

KJ_TEST("multiline-strings ignored") {
auto files = kj::heapArrayBuilder<kj::String>(4);
files.add(kj::str(R"SCRIPT(
auto files = kj::arr(R"SCRIPT(
FOO="""
import x
from y import z
"""
)SCRIPT"));
files.add(kj::str(R"SCRIPT(
)SCRIPT"_kj,
R"SCRIPT(
FOO='''
import f
from g import z
'''
)SCRIPT"));
files.add(kj::str(R"SCRIPT(FOO = "\
)SCRIPT"_kj,
R"SCRIPT(FOO = "\
import b \
")SCRIPT"));
files.add(kj::str("FOO=\"\"\" \n", R"SCRIPT(import x
")SCRIPT"_kj,
"FOO=\"\"\" \n"_kj,
R"SCRIPT(import x
from y import z
""")SCRIPT"));
auto result = parseImports(files.finish());
""")SCRIPT"_kj);
auto result = parseImports(files);
KJ_REQUIRE(result.size() == 0);
}

KJ_TEST("multiline-strings with imports in-between") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str(
auto files = kj::arr(
R"SCRIPT(FOO="""
import x
from y import z
Expand All @@ -167,29 +142,27 @@ import w
BAR="""
import e
"""
from t import u)SCRIPT"));
auto result = parseImports(files.finish());
from t import u)SCRIPT"_kj);
auto result = parseImports(files);
KJ_REQUIRE(result.size() == 3);
KJ_REQUIRE(result[0] == "q");
KJ_REQUIRE(result[1] == "t");
KJ_REQUIRE(result[2] == "w");
}

KJ_TEST("import after string literal") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str(R"SCRIPT(import a
"import b")SCRIPT"));
auto result = parseImports(files.finish());
auto files = kj::arr(R"SCRIPT(import a
"import b")SCRIPT"_kj);
auto result = parseImports(files);
KJ_REQUIRE(result.size() == 1);
KJ_REQUIRE(result[0] == "a");
}

KJ_TEST("langchain import") {
auto files = kj::heapArrayBuilder<kj::String>(1);
files.add(kj::str(R"SCRIPT(from js import Response, console, URL
auto files = kj::arr(R"SCRIPT(from js import Response, console, URL
from langchain.chat_models import ChatOpenAI
import openai)SCRIPT"));
auto result = parseImports(files.finish());
import openai)SCRIPT"_kj);
auto result = parseImports(files);
KJ_REQUIRE(result.size() == 3);
KJ_REQUIRE(result[0] == "js");
KJ_REQUIRE(result[1] == "langchain.chat_models");
Expand Down
33 changes: 17 additions & 16 deletions src/rust/python-parser/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,31 @@ use ruff_python_parser::parse_module;
mod ffi {

extern "Rust" {
fn get_imports(sources: &Vec<String>) -> Vec<String>;
fn get_imports(sources: &[&str]) -> Vec<String>;
}
}

#[must_use]
pub fn get_imports(sources: &Vec<String>) -> Vec<String> {
pub fn get_imports(sources: &[&str]) -> Vec<String> {
let mut names: HashSet<String> = HashSet::new();
for src in sources {
// Just skip it if it doesn't parse.
if let Ok(module) = parse_module(src) {
for stmt in &module.syntax().body {
match stmt {
Stmt::Import(s) => {
names.extend(s.names.iter().map(|x| x.name.id.as_str().into()));
}
Stmt::ImportFrom(StmtImportFrom {
module: Some(module),
level: 0,
..
}) => {
names.insert(module.id.as_str().into());
}
_ => {}
let Ok(module) = parse_module(src) else {
continue;
};
for stmt in &module.syntax().body {
match stmt {
Stmt::Import(s) => {
names.extend(s.names.iter().map(|x| x.name.id.as_str().to_owned()));
}
Stmt::ImportFrom(StmtImportFrom {
module: Some(module),
level: 0,
..
}) => {
names.insert(module.id.as_str().to_owned());
}
_ => {}
}
}
}
Expand Down

0 comments on commit cf7021d

Please sign in to comment.