Skip to content

Commit

Permalink
wip: testing output
Browse files Browse the repository at this point in the history
  • Loading branch information
philtweir committed Feb 22, 2024
1 parent 3a6a5a0 commit ffc6f32
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 18 deletions.
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ crate-type = ["cdylib"]
pyo3 = { version = "0.18.3", features = ["extension-module"] }
serde = { version = "1.0.133", features = ["derive"] }
serde_json = "1.0.74"
berlin-core = "0.2.2"
berlin-core = { path = "../berlin-rs" }

# Logging
tracing = "0.1.29"
Expand Down
39 changes: 36 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::{Arc, Mutex};
use berlin_core::ustr::Ustr;
use pyo3::exceptions::{PyAttributeError, PyKeyError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::PyList;
use pyo3::types::{PyList, PyTuple};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
};
Expand All @@ -14,7 +14,10 @@ use berlin_core::location::{subdiv_key, CsvLocode, Location};
use berlin_core::locations_db::{
parse_data_blocks, parse_data_files, parse_data_list, LocationsDb,
};
use berlin_core::search::SearchTerm;
use berlin_core::search::{SearchTerm, Score};

// We will cap scores to this number
pub const MAXIMUM_SCORE: i32 = 1000;

#[pyclass]
struct LocationsDbProxy {
Expand All @@ -24,6 +27,7 @@ struct LocationsDbProxy {
#[pyclass(name = "Location")]
struct LocationProxy {
_loc: Location,
_score: Option<Score>,
_db: Arc<Mutex<LocationsDb>>,
}

Expand All @@ -34,6 +38,7 @@ impl LocationsDbProxy {
Some(loc) => Python::with_gil(|_py| {
Ok(LocationProxy {
_loc: loc,
_score: None,
_db: self._db.clone(),
})
}),
Expand Down Expand Up @@ -97,10 +102,11 @@ impl LocationsDbProxy {
let db = self._db.lock().unwrap();
db.search(&st)
.into_iter()
.map(|(key, _score)| {
.map(|(key, score)| {
let loc = db.all.get(&key).cloned().expect("loc should be in db");
LocationProxy {
_loc: loc,
_score: Some(score),
_db: self._db.clone(),
}
})
Expand Down Expand Up @@ -135,6 +141,30 @@ impl LocationProxy {
Ok(val.unwrap())
}

fn get_score(&self) -> Result<i32, PyErr> {
match self._score {
Some(score) => {
Ok(match i32::try_from(score.score) {
Ok(_score) => i32::max(MAXIMUM_SCORE, _score),
_ => MAXIMUM_SCORE
})
},
None => Err(PyAttributeError::new_err(format!["No string offset attached to this location object"]))
}
}

fn get_offset(&self) -> PyResult<Py<PyTuple>> {
match self._score {
Some(score) => {
let offset_tuple = Python::with_gil(|_py| {
PyTuple::new(_py, [score.offset.start, score.offset.end]).into()
});
Ok(offset_tuple)
},
None => Err(PyAttributeError::new_err(format!["No string offset attached to this location object"]))
}
}

fn get_names(&self) -> PyResult<Py<PyAny>> {
let val: Result<_, PyAttributeError> = Python::with_gil(|py| {
let names: &PyList =
Expand Down Expand Up @@ -176,6 +206,7 @@ impl LocationProxy {
let loc = db.retrieve(key).unwrap();
LocationProxy {
_loc: loc,
_score: None,
_db: self._db.clone(),
}
})
Expand All @@ -192,6 +223,7 @@ impl LocationProxy {
let loc = db.retrieve(key).unwrap();
Ok(LocationProxy {
_loc: loc,
_score: None,
_db: self._db.clone(),
})
}),
Expand All @@ -213,6 +245,7 @@ impl LocationProxy {
let loc = db.retrieve(&key).unwrap();
Ok(Some(LocationProxy {
_loc: loc,
_score: None,
_db: self._db.clone(),
}))
}
Expand Down
29 changes: 17 additions & 12 deletions tests/test_berlin.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from berlin import Location

def test_search_with_state(db):
query = "Abercorn"
state = "GB"
limit = 2
lev_distance = 2

result = db.query(query, limit, lev_distance, state=state)
assert len(result) == 1
loc = result[0]
assert loc.key == "UN-LOCODE-gb:abc"
assert loc.encoding == "UN-LOCODE"
assert loc.id == "gb:abc"
assert list(loc.words) == ["abercarn"]
for query in ("Dentists in Abercarn", "Dental Abercarn"):
query = "Abercorn"
state = "GB"
limit = 2
lev_distance = 2

result = db.query(query, limit, lev_distance, state=state)
assert len(result) == 1
loc = result[0]
assert loc.key == "UN-LOCODE-gb:abc"
assert loc.encoding == "UN-LOCODE"
assert loc.id == "gb:abc"
assert list(loc.words) == ["abercarn"]

assert isinstance(loc.get_score(), int)
print(loc.get_offset())
assert isinstance(loc.get_offset(), tuple)

def test_retrieve(db):
loc = db.retrieve("UN-LOCODE-gb:abc")
Expand Down

0 comments on commit ffc6f32

Please sign in to comment.