This repository has been archived by the owner on Jan 29, 2024. It is now read-only.

[BBS-126] Make bandit tests pass #177

Merged
merged 19 commits into from Dec 3, 2020
Changes from all commits
1 change: 1 addition & 0 deletions .bandit
@@ -0,0 +1 @@
skips: [B322]
2 changes: 1 addition & 1 deletion src/bbsearch/database/cord_19.py
@@ -37,7 +37,7 @@ def mark_bad_sentences(engine, sentences_table_name):

logger.info("Getting all sentences")
with engine.begin() as connection:
query = f"SELECT sentence_id, text FROM {sentences_table_name}"
query = f"SELECT sentence_id, text FROM {sentences_table_name}" # nosec
df_sentences = pd.read_sql(query, connection)

logger.info("Computing text lengths")
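For context: bandit's B608 check ("hardcoded_sql_expressions") flags string-built SQL like the query above, and the trailing # nosec marker tells bandit that the line was reviewed and the finding is accepted. A minimal sketch, with an illustrative table name:

table_name = "sentences"  # illustrative; the real value comes from trusted code, not user input

# Without "# nosec", bandit reports B608 (possible SQL injection through
# string-based query construction) on this line; the marker suppresses
# that single finding.
query = f"SELECT sentence_id, text FROM {table_name}"  # nosec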
2 changes: 1 addition & 1 deletion src/bbsearch/database/mining_cache.py
@@ -285,7 +285,7 @@ def _delete_rows(self):
DELETE
FROM {self.target_table}
WHERE mining_model = :mining_model
"""
""" # nosec
self.engine.execute(
sqlalchemy.sql.text(query),
mining_model=model_schema["model_path"],
5 changes: 3 additions & 2 deletions src/bbsearch/mining/attribute.py
@@ -580,7 +580,7 @@ def get_core_nlp_analysis(self, text):
response = requests.post(
self.core_nlp_url + request_params, data=request_data
)
assert response.status_code == 200
response.raise_for_status()
response_json = json.loads(response.text)
except requests.exceptions.RequestException:
warnings.warn("There was a problem contacting the CoreNLP server.")
@@ -801,7 +801,8 @@ def __init__(self, texts, attribute_extractor, ee_model):
"""
super().__init__()

assert len(texts) > 0
if not texts:
raise TypeError("texts must be a non-empty list.")
self.texts = texts

self.idx_slider = widgets.IntSlider(
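Both hunks above replace assert statements with explicit exceptions: bandit's B101 check flags asserts, and asserts are silently skipped when Python runs with the -O flag. A minimal sketch of the two patterns (function names and URL handling are illustrative):

import requests

def get_core_nlp_analysis_sketch(url: str, payload: bytes) -> dict:
    """Illustrative sketch: raise_for_status() replaces the status-code assert."""
    response = requests.post(url, data=payload)
    # Raises requests.exceptions.HTTPError for 4xx/5xx replies, instead of
    # relying on `assert response.status_code == 200`.
    response.raise_for_status()
    return response.json()

def check_texts(texts: list) -> None:
    """Illustrative sketch: explicit validation replaces `assert len(texts) > 0`."""
    if not texts:
        raise TypeError("texts must be a non-empty list.")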
18 changes: 13 additions & 5 deletions src/bbsearch/mining/entity.py
@@ -1,5 +1,6 @@
"""Classes and functions for entity extraction (aka named entity recognition)."""

import ast
import copy

import numpy as np
@@ -198,6 +199,8 @@ def to_jsonl(self, path, sort_by=None):

Parameters
----------
path : pathlib.Path
File where to save it.
sort_by : None or list
If None, then no sorting taking place. If ``list``, then the
names of columns along which to sort.
@@ -326,11 +329,16 @@ def row2raw(row):
):
raise KeyError()

value = (
eval(f"{value_type}({value_str})")
if value_type != "str"
else value_str
)
if value_type != "str":
try:
value = ast.literal_eval(value_str)
except ValueError as ve:
if str(ve).startswith("malformed node or string"):
raise NameError(str(ve)) from ve
else:
raise
else:
value = value_str

token_pattern = {attribute: value}
if op:
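A brief illustration of the ast.literal_eval behaviour that the new error handling above relies on:

import ast

# literal_eval parses Python literals without executing arbitrary code,
# unlike the eval() call it replaces.
print(ast.literal_eval("[1, 2, 3]"))   # [1, 2, 3]
print(ast.literal_eval("{'a': 1}"))    # {'a': 1}

# A bare name such as `foo` is not a literal: literal_eval raises
# ValueError("malformed node or string ..."), which the code above maps
# back to NameError to keep the previous eval()-style behaviour.
try:
    ast.literal_eval("foo")
except ValueError as error:
    print(error)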
6 changes: 5 additions & 1 deletion src/bbsearch/mining/eval.py
@@ -441,7 +441,11 @@ def ner_errors(
...
}
"""
assert len(iob_true) == len(iob_pred) == len(tokens)
if not (len(iob_true) == len(iob_pred) == len(tokens)):
raise ValueError(
f"Inputs iob_true (len={len(iob_true)}), iob_pred (len={len(iob_pred)}), "
f"tokens (len={len(tokens)}) should have equal length."
)
etypes = unique_etypes(iob_true)

etypes_map = etypes_map if etypes_map is not None else dict()
10 changes: 5 additions & 5 deletions src/bbsearch/mining/relation.py
@@ -137,29 +137,29 @@ def annotate(doc, sent, ent_1, ent_2, etype_symbols):
tokens = []
i = sent.start
while i < sent.end:
new_token = " " # hack to keep the punctuation nice
new_tkn = " " # hack to keep the punctuation nice

if ent_1.start == i:
start, end = ent_1.start, ent_1.end
new_token += (
new_tkn += (
etype_symbols[etype_1][0]
+ doc[start:end].text
+ etype_symbols[etype_1][1]
)

elif ent_2.start == i:
start, end = ent_2.start, ent_2.end
new_token += (
new_tkn += (
etype_symbols[etype_2][0]
+ doc[start:end].text
+ etype_symbols[etype_2][1]
)

else:
start, end = i, i + 1
new_token = doc[i].text if doc[i].is_punct else new_token + doc[i].text
new_tkn = doc[i].text if doc[i].is_punct else new_tkn + doc[i].text

tokens.append(new_token)
tokens.append(new_tkn)
i += end - start

return "".join(tokens).strip()
119 changes: 77 additions & 42 deletions src/bbsearch/sql.py
@@ -3,6 +3,7 @@

import numpy as np
import pandas as pd
import sqlalchemy.sql as sql


def get_titles(article_ids, engine):
@@ -23,13 +24,16 @@ def get_titles(article_ids, engine):
if len(article_ids) == 0:
return {}

query = f"""\
SELECT article_id, title
FROM articles
WHERE article_id IN ({",".join(map(str, article_ids))})
"""
query = sql.text(
"""SELECT article_id, title
FROM articles
WHERE article_id IN :article_ids
"""
)
query = query.bindparams(sql.bindparam("article_ids", expanding=True))

with engine.begin() as connection:
response = connection.execute(query).fetchall()
response = connection.execute(query, {"article_ids": article_ids}).fetchall()
titles = {article_id: title for article_id, title in response}

return titles
@@ -72,16 +76,21 @@ def retrieve_sentences_from_sentence_ids(sentence_ids, engine, keep_order=False)
Pandas DataFrame containing all sentences and their corresponding metadata:
article_id, sentence_id, section_name, text, paragraph_pos_in_article.
"""
sentence_ids_s = ", ".join(str(id_) for id_ in sentence_ids)
sentence_ids_s = sentence_ids_s or "NULL"
sql_query = f"""
SELECT article_id, sentence_id, section_name, text, paragraph_pos_in_article
FROM sentences
WHERE sentence_id IN ({sentence_ids_s})
"""
sql_query = sql.text(
"""
SELECT article_id, sentence_id, section_name, text, paragraph_pos_in_article
FROM sentences
WHERE sentence_id IN :sentence_ids
"""
)
sql_query = sql_query.bindparams(sql.bindparam("sentence_ids", expanding=True))

with engine.begin() as connection:
df_sentences = pd.read_sql(sql_query, connection)
df_sentences = pd.read_sql(
sql_query,
params={"sentence_ids": [int(id_) for id_ in sentence_ids]},
con=connection,
)

if keep_order:
# Remove sentence IDs that were not found, otherwise df.loc will fail.
@@ -112,19 +121,23 @@ def retrieve_paragraph_from_sentence_id(sentence_id, engine):
sentence_id. If None then the `sentence_id` was not found in the
sentences table.
"""
sql_query = f"""SELECT text
sql_query = sql.text(
"""SELECT text
FROM sentences
WHERE article_id =
(SELECT article_id
FROM sentences
WHERE sentence_id = {sentence_id})
WHERE sentence_id = :sentence_id )
AND paragraph_pos_in_article =
(SELECT paragraph_pos_in_article
FROM sentences
WHERE sentence_id = {sentence_id})
WHERE sentence_id = :sentence_id )
ORDER BY sentence_pos_in_paragraph ASC"""
)

all_sentences = pd.read_sql(sql_query, engine)["text"].to_list()
all_sentences = pd.read_sql(
sql_query, engine, params={"sentence_id": int(sentence_id)}
Contributor:
According to the docstring, sentence_id is already an int. Why is a casting to int added?

Contributor:
Same answer as in this comment.

)["text"].to_list()
if not all_sentences:
paragraph = None
else:
@@ -151,13 +164,22 @@ def retrieve_paragraph(article_id, paragraph_pos_in_article, engine):
pd.DataFrame with the paragraph and its metadata:
article_id, text, section_name, paragraph_pos_in_article.
"""
sql_query = f"""SELECT section_name, text
sql_query = sql.text(
"""SELECT section_name, text
FROM sentences
WHERE article_id = {article_id}
AND paragraph_pos_in_article = {paragraph_pos_in_article}
WHERE article_id = :article_id
AND paragraph_pos_in_article = :paragraph_pos_in_article
ORDER BY sentence_pos_in_paragraph ASC"""

sentences = pd.read_sql(sql_query, engine)
)

sentences = pd.read_sql(
sql_query,
engine,
params={
"article_id": int(article_id),
"paragraph_pos_in_article": int(paragraph_pos_in_article),
Comment on lines +179 to +180

Contributor:
According to the docstring, article_id and paragraph_pos_in_article are already of type int. Why is a casting to int added?

Contributor Author:
As Python doesn't have static typing, if someone passes any type other than int we could have a mess.
In particular, if we pass a np.int64 the SQL query breaks, which was indeed what was happening in our case.

Contributor:
I see. Was it spotted by bandit?

Anyway, should we consider... type annotations for all BBS then? ;) Or data validation frameworks for Python arguments?

Contributor Author:
For type annotations: I am not against, as said in the past ;) I am still a bit annoyed that one would have to write them for each function by hand; it would be cool if we could use the numpy docstrings we already have to auto-generate type annotations (if PyCharm can do it, there must be a way, I hope).

But afaik type annotations wouldn't have raised any exception here, right? They are just useful to help your IDE or a developer know which type is "expected".

For data validation frameworks: I do not know any, do you have one in mind in particular?
I feel Python's duck typing is a key feature, so unless it's really needed I would prefer avoiding type checking...

Contributor (pafonta, Dec 2, 2020):
@FrancescoCasalegno

> auto-generate type annotations

According to what I have found, there is tooling to convert docstring types into type annotations.

> type annotations wouldn't have raised any exception here right?

The IDE would have complained where the function with the body we discuss is used.
A static type checker like mypy, run by the CI, would have complained.
So, with this, the case where an exception would have been thrown would not be reached.

> For data validation frameworks: I do not know any, do you have one in mind in particular?

For validating inputs, I would have pydantic in mind.

> I feel Python's duck typing is a key feature, so unless it's really needed I would prefer avoiding type checking...

I agree. However, I would weigh it against two other points:

  • How well are exceptions managed?
  • How critical is it if the process stops?

This being said, I think that checking docstring types or Python type annotations would already let us be more comfortable with the runtime issues. I won't go for input validation.

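A minimal sketch of the failure mode discussed above and of the fix, assuming an in-memory SQLite engine (the exact exception depends on the DB driver):

import numpy as np
import pandas as pd
import sqlalchemy
import sqlalchemy.sql as sql

engine = sqlalchemy.create_engine("sqlite://")  # illustrative in-memory database
with engine.begin() as connection:
    connection.execute(sql.text("CREATE TABLE sentences (sentence_id INTEGER, text TEXT)"))
    connection.execute(sql.text("INSERT INTO sentences VALUES (1, 'A sentence.')"))

query = sql.text("SELECT text FROM sentences WHERE sentence_id = :sentence_id")
sentence_id = np.int64(1)  # e.g. coming out of a pandas DataFrame column

# Binding the np.int64 directly may fail (sqlite3 rejects it as an unsupported
# parameter type), hence the explicit int() cast added in this diff.
df = pd.read_sql(query, engine, params={"sentence_id": int(sentence_id)})
print(df)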
},
)
if sentences.empty:
paragraph = pd.DataFrame(
columns=["article_id", "text", "section_name", "paragraph_pos_in_article"]
@@ -199,10 +221,12 @@ def retrieve_article_metadata_from_article_id(article_id, engine):
'authors', 'journal', 'mag_id', 'who_covidence_id', 'arxiv_id',
'pdf_json_files', 'pmc_json_files', 'url', 's2_id'.
"""
sql_query = f"""SELECT *
sql_query = sql.text(
"""SELECT *
FROM articles
WHERE article_id = {article_id}"""
article = pd.read_sql(sql_query, engine)
WHERE article_id = :article_id"""
)
article = pd.read_sql(sql_query, engine, params={"article_id": int(article_id)})
return article


@@ -222,14 +246,17 @@ def retrieve_articles(article_ids, engine):
DataFrame containing the articles divided into paragraphs. The columns are
'article_id', 'paragraph_pos_in_article', 'text', 'section_name'.
"""
articles_str = ", ".join(str(id_) for id_ in article_ids)
sql_query = f"""SELECT *
article_ids = [int(id_) for id_ in article_ids]
sql_query = sql.text(
"""SELECT *
FROM sentences
WHERE article_id IN ({articles_str})
WHERE article_id IN :articles_ids
ORDER BY article_id ASC,
paragraph_pos_in_article ASC,
sentence_pos_in_paragraph ASC"""
all_sentences = pd.read_sql(sql_query, engine)
)
sql_query = sql_query.bindparams(sql.bindparam("articles_ids", expanding=True))
all_sentences = pd.read_sql(sql_query, engine, params={"articles_ids": article_ids})

groupby_var = all_sentences.groupby(by=["article_id", "paragraph_pos_in_article"])
paragraphs = groupby_var["text"].apply(lambda x: " ".join(x))
@@ -260,21 +287,27 @@ def retrieve_mining_cache(identifiers, model_names, engine):
result : pd.DataFrame
Selected rows of the `mining_cache` table.
"""
model_names = tuple(set(model_names))
if len(model_names) == 1:
model_names = f"('{model_names[0]}')"
model_names = list(set(model_names))
identifiers_arts = [int(a) for a, p in identifiers if p == -1]

identifiers_arts = tuple(a for a, p in identifiers if p == -1)
if len(identifiers_arts) == 1:
identifiers_arts = f"({identifiers_arts[0]})"
if identifiers_arts:
query_arts = f"""
query_arts = sql.text(
"""
SELECT *
FROM mining_cache
WHERE article_id IN {identifiers_arts} AND mining_model IN {model_names}
WHERE article_id IN :identifiers_arts AND mining_model IN :model_names
ORDER BY article_id, paragraph_pos_in_article, start_char
"""
df_arts = pd.read_sql(query_arts, con=engine)
)
query_arts = query_arts.bindparams(
sql.bindparam("identifiers_arts", expanding=True),
sql.bindparam("model_names", expanding=True),
)
df_arts = pd.read_sql(
query_arts,
con=engine,
params={"identifiers_arts": identifiers_arts, "model_names": model_names},
)
else:
df_arts = pd.DataFrame()

@@ -287,6 +320,8 @@ def retrieve_mining_cache(identifiers, model_names, engine):
# 3. If `len(identifiers_pars)` is too large, we may have a too long
# SQL statement which overflows the max length. So we break it down.

if len(model_names) == 1:
model_names = f"('{model_names[0]}')"
batch_size = 1000
dfs_pars = []
d, r = divmod(len(identifiers_pars), batch_size)
SELECT *
FROM mining_cache
WHERE (article_id = {a} AND paragraph_pos_in_article = {p})
"""
""" # nosec
for a, p in identifiers_pars[i * batch_size : (i + 1) * batch_size]
)
query_pars = f"""
SELECT *
FROM ({query_pars}) tt
WHERE tt.mining_model IN {model_names}
"""
""" # nosec
dfs_pars.append(pd.read_sql(query_pars, engine))
df_pars = pd.concat(dfs_pars)
df_pars = df_pars.sort_values(
@@ -553,7 +588,7 @@ def _build_query(self):
FROM articles
WHERE {" AND ".join(article_conditions)}
)
""".strip()
""".strip() # nosec
sentence_conditions.append(article_condition_query)

# Restricted sentence IDs
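For reference, a minimal sketch of the expanding bind-parameter pattern used throughout the sql.py changes above (table and data are illustrative):

import pandas as pd
import sqlalchemy
import sqlalchemy.sql as sql

engine = sqlalchemy.create_engine("sqlite://")  # illustrative in-memory database
with engine.begin() as connection:
    connection.execute(sql.text("CREATE TABLE articles (article_id INTEGER, title TEXT)"))
    connection.execute(sql.text("INSERT INTO articles VALUES (1, 'first'), (2, 'second')"))

# An expanding bindparam renders "IN (?, ?, ...)" with one placeholder per
# element at execution time, so the list is passed as data rather than
# interpolated into the SQL string (the pattern bandit's B608 check warns about).
query = sql.text("SELECT article_id, title FROM articles WHERE article_id IN :article_ids")
query = query.bindparams(sql.bindparam("article_ids", expanding=True))

df = pd.read_sql(query, engine, params={"article_ids": [1, 2]})
print(df)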
2 changes: 2 additions & 0 deletions tests/test_mining/test_eval.py
@@ -560,6 +560,8 @@ def test_ner_errors(ner_annotations, dataset, mode, errors_expected):
)
errors_expected = OrderedDict(errors_expected)
assert errors_out == errors_expected
with pytest.raises(ValueError):
ner_errors(iob_true, iob_pred[:-1], tokens)


def test_remove_punctuation(punctuation_annotations):
4 changes: 2 additions & 2 deletions tests/test_sql.py
@@ -74,7 +74,7 @@ def test_retrieve_sentence_from_sentence_ids(
]
)

@pytest.mark.parametrize("sentence_id", [1, 2, 3, 0, -100])
@pytest.mark.parametrize("sentence_id", [1, 2, 3, 0, -100, -1, np.int64(2)])
def test_retrieve_paragraph_from_sentence_id(
self, sentence_id, fake_sqlalchemy_engine
):
sentence_text = retrieve_sentences_from_sentence_ids(
sentence_ids=(sentence_id,), engine=fake_sqlalchemy_engine
)
if sentence_id == 0 or sentence_id == -100: # invalid sentence_id
if sentence_id in [0, -100, -1]: # invalid sentence_id
assert paragraph is None
else:
assert isinstance(paragraph, str)
4 changes: 2 additions & 2 deletions tox.ini
@@ -22,7 +22,7 @@ commands =
[testenv:lint]
skip_install = true
deps =
; bandit
bandit
black==20.8b1
flake8==3.8.4
isort==5.6.4
isort --profile black --check setup.py {[tox]source} tests
pydocstyle {[tox]source}
black -q --check setup.py {[tox]source} tests
; bandit -q -r {[tox]source}
bandit -c .bandit -q -r {[tox]source}

[testenv:format]
skip_install = true