Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random sample option for queries #123

Merged
merged 14 commits into from
Sep 19, 2023
8 changes: 8 additions & 0 deletions materializationengine/blueprints/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def unhandled_exception(e):
location="args",
help="whether to only return the count of a query",
)
query_parser.add_argument(
"random_sample",
type=inputs.positive,
default=None,
required=False,
location="args",
help="How many samples to randomly get using tablesample on annotation tables, useful for visualization of large tables does not work as a random sample of query",
)


def check_aligned_volume(aligned_volume):
Expand Down
49 changes: 47 additions & 2 deletions materializationengine/blueprints/client/api2.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,39 @@ def unhandled_exception(e):
)


def _get_float(value):
try:
return float(value)
except (TypeError, ValueError):
raise ValueError(f"{value} is not a valid float")


class float_range(object):
    """Restrict input to a float in a range (inclusive).

    Instances are used as the ``type=`` callable of a ``reqparse``
    argument (mirroring ``flask_restx.inputs.int_range``, but for
    floats).

    Args:
        low: inclusive lower bound.
        high: inclusive upper bound.
        argument: argument name interpolated into the error message.
    """

    def __init__(self, low, high, argument="argument"):
        self.low = low
        self.high = high
        self.argument = argument

    def __call__(self, value):
        """Validate *value*; return it as a float or raise ValueError."""
        # Coerce first so the bounds check always compares floats.
        value = _get_float(value)
        if value < self.low or value > self.high:
            msg = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}"
            raise ValueError(
                msg.format(arg=self.argument, val=value, lo=self.low, hi=self.high)
            )
        return value

    @property
    def __schema__(self):
        # Swagger/OpenAPI schema for the argument. BUGFIX: floats are
        # JSON Schema "number", not "integer" (the latter was copied
        # from the int-range validator and misdocumented the API).
        return {
            "type": "number",
            "minimum": self.low,
            "maximum": self.high,
        }


query_parser = reqparse.RequestParser()
query_parser.add_argument(
"return_pyarrow",
Expand All @@ -103,6 +136,14 @@ def unhandled_exception(e):
location="args",
help=("whether to convert dataframe to pyarrow ipc batch format"),
)
query_parser.add_argument(
"random_sample",
type=inputs.positive,
default=None,
required=False,
location="args",
help="How many samples to randomly get using tablesample on annotation tables, useful for visualization of large tables does not work as a random sample of query",
)
query_parser.add_argument(
"split_positions",
type=inputs.boolean,
Expand Down Expand Up @@ -211,7 +252,8 @@ def execute_materialized_query(
user_data: dict,
query_map: dict,
cg_client,
split_mode=False,
random_sample: int = None,
split_mode: bool = False,
) -> pd.DataFrame:
"""_summary_

Expand All @@ -233,13 +275,16 @@ def execute_materialized_query(
.filter(MaterializedMetadata.table_name == user_data["table"])
.scalar()
)
if random_sample is not None:
random_sample = (100.0*random_sample)/mat_row_count
if mat_row_count:
# setup a query manager
qm = QueryManager(
mat_db_name,
segmentation_source=pcg_table_name,
meta_db_name=aligned_volume,
split_mode=split_mode,
random_sample=random_sample,
)
qm.configure_query(user_data)
qm.apply_filter({user_data["table"]: {"valid": True}}, qm.apply_equal_filter)
Expand Down Expand Up @@ -962,7 +1007,6 @@ def post(self, datastack_name: str):
db = dynamic_annotation_cache.get_db(aligned_vol)
check_read_permission(db, user_data["table"])
allow_invalid_root_ids = args.get("allow_invalid_root_ids", False)

# TODO add table owner warnings
# if has_joins:
# abort(400, "we are not supporting joins yet")
Expand Down Expand Up @@ -1038,6 +1082,7 @@ def post(self, datastack_name: str):
modified_user_data,
query_map,
cg_client,
random_sample=args["random_sample"],
split_mode=not chosen_version.is_merged,
)

Expand Down
29 changes: 26 additions & 3 deletions materializationengine/blueprints/client/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
collect_crud_columns,
)
from materializationengine.database import dynamic_annotation_cache, sqlalchemy_cache
from materializationengine.models import MaterializedMetadata
from materializationengine.utils import check_read_permission
from materializationengine.info_client import (
get_relevant_datastack_info,
Expand Down Expand Up @@ -228,6 +229,16 @@ def handle_simple_query(
else:
data["desired_resolution"] = None

random_sample = args.get("random_sample", None)
if random_sample is not None:
session = sqlalchemy_cache.get(mat_db_name)
mat_row_count = (
session.query(MaterializedMetadata.row_count)
.filter(MaterializedMetadata.table_name == table_name)
.scalar()
)
random_sample = (100.0 * random_sample) / mat_row_count

qm = QueryManager(
mat_db_name,
segmentation_source=pcg_table_name,
Expand All @@ -236,8 +247,9 @@ def handle_simple_query(
limit=limit,
offset=data.get("offset", 0),
get_count=get_count,
random_sample=random_sample,
)
qm.add_table(table_name)
qm.add_table(table_name, random_sample=True)
qm.apply_filter(data.get("filter_in_dict", None), qm.apply_isin_filter)
qm.apply_filter(data.get("filter_out_dict", None), qm.apply_notequal_filter)
qm.apply_filter(data.get("filter_equal_dict", None), qm.apply_equal_filter)
Expand Down Expand Up @@ -267,7 +279,7 @@ def handle_simple_query(
column_names=column_names,
desired_resolution=data["desired_resolution"],
return_pyarrow=args["return_pyarrow"],
arrow_format=args['arrow_format']
arrow_format=args["arrow_format"],
)


Expand Down Expand Up @@ -334,6 +346,16 @@ def handle_complex_query(
else:
suffixes = data.get("suffix_map")

random_sample = args.get("random_sample", None)
if random_sample is not None:
session = sqlalchemy_cache.get(db_name)
mat_row_count = (
session.query(MaterializedMetadata.row_count)
.filter(MaterializedMetadata.table_name == data["tables"][0][0])
.scalar()
)
random_sample = (100.0 * random_sample) / mat_row_count

qm = QueryManager(
db_name,
segmentation_source=pcg_table_name,
Expand All @@ -343,6 +365,7 @@ def handle_complex_query(
limit=limit,
offset=data.get("offset", 0),
get_count=False,
random_sample=random_sample,
)
if convert_desired_resolution:
if not data.get("desired_resolution", None):
Expand Down Expand Up @@ -421,5 +444,5 @@ def handle_complex_query(
column_names=column_names,
desired_resolution=data["desired_resolution"],
return_pyarrow=args["return_pyarrow"],
arrow_format=args['arrow_format']
arrow_format=args["arrow_format"],
)
43 changes: 31 additions & 12 deletions materializationengine/blueprints/client/query_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlalchemy.orm import aliased
from sqlalchemy.sql.selectable import Alias
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.expression import tablesample
from sqlalchemy.ext.declarative.api import DeclarativeMeta
import datetime

Expand All @@ -31,7 +32,8 @@ def __init__(
offset: int = 0,
limit: int = DEFAULT_LIMIT,
get_count: bool = False,
split_mode_outer=False,
split_mode_outer: bool = False,
random_sample: float = None,
):
self._db = dynamic_annotation_cache.get_db(db_name)
if meta_db_name is None:
Expand All @@ -40,6 +42,8 @@ def __init__(
self._meta_db = dynamic_annotation_cache.get_db(meta_db_name)
self._segmentation_source = segmentation_source
self._split_mode = split_mode
self._random_sample = random_sample

self._split_mode_outer = split_mode_outer
self._split_models = {}
self._flat_models = {}
Expand Down Expand Up @@ -88,13 +92,15 @@ def _get_split_model(self, table_name):
if reference_table:
table_metadata = {"reference_table": reference_table}
ref_md = self._meta_db.database.get_table_metadata(reference_table)
_ = self._db.schema.get_split_models(reference_table,
ref_md["schema_type"],
self._segmentation_source,
table_metadata=None)
_ = self._db.schema.get_split_models(
reference_table,
ref_md["schema_type"],
self._segmentation_source,
table_metadata=None,
)
else:
table_metadata = None

annmodel, segmodel = self._db.schema.get_split_models(
table_name,
md["schema_type"],
Expand Down Expand Up @@ -151,7 +157,7 @@ def add_view(self, datastack_name, view_name):

self._voxel_resolutions[view_name] = vox_res

def add_table(self, table_name):
def add_table(self, table_name, random_sample=False):
if table_name not in self._tables:
self._tables.add(table_name)
if self._split_mode:
Expand All @@ -163,9 +169,15 @@ def add_table(self, table_name):
seg_columns = [
c for c in segmodel.__table__.columns if c.key != "id"
]
if random_sample and self._random_sample:
annmodel_alias1 = aliased(
annmodel, tablesample(annmodel, self._random_sample)
)
else:
annmodel_alias1 = annmodel
subquery = (
self._db.database.session.query(annmodel, *seg_columns)
.join(segmodel, annmodel.id == segmodel.id, isouter=True)
self._db.database.session.query(annmodel_alias1, *seg_columns)
.join(segmodel, annmodel_alias1.id == segmodel.id, isouter=True)
.subquery()
)
annmodel_alias = aliased(subquery, name=table_name, flat=True)
Expand All @@ -174,9 +186,17 @@ def add_table(self, table_name):
# self._models[segmodel.__tablename__] = segmodel_alias

else:
self._models[table_name] = annmodel
if random_sample and self._random_sample:
annmodel_alias1 = aliased(
annmodel, tablesample(annmodel, self._random_sample)
)
else:
annmodel_alias1 = annmodel
self._models[table_name] = annmodel_alias1
else:
model = self._get_flat_model(table_name)
if self._random_sample:
model = aliased(model, tablesample, self._random_sample)
self._models[table_name] = model

def _find_relevant_model(self, table_name, column_name):
Expand All @@ -189,7 +209,7 @@ def _find_relevant_model(self, table_name, column_name):

def join_tables(self, table1, column1, table2, column2, isouter=False):

self.add_table(table1)
self.add_table(table1, random_sample=True)
self.add_table(table2)

model1 = self._models[table1]
Expand Down Expand Up @@ -412,7 +432,6 @@ def _make_query(

if select_columns is not None:
query = query.with_entities(*select_columns)

if offset is not None:
query = query.offset(offset)
if limit is not None:
Expand Down
4 changes: 3 additions & 1 deletion materializationengine/blueprints/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def create_query_response(
)
context = pa.default_serialization_context()
serialized = context.serialize(df).to_buffer().to_pybytes()
response = Response(serialized, headers=headers, mimetype="x-application/pyarrow")
response = Response(
serialized, headers=headers, mimetype="x-application/pyarrow"
)
return after_request(response)
else:
dfjson = df.to_json(orient="records")
Expand Down
Loading