Skip to content

Commit

Permalink
fixing random sample implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
fcollman committed Aug 14, 2023
1 parent a873852 commit 837b42a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 7 deletions.
8 changes: 8 additions & 0 deletions materializationengine/blueprints/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def unhandled_exception(e):
location="args",
help="whether to only return the count of a query",
)
query_parser.add_argument(
"random_sample",
type=inputs.positive,
default=None,
required=False,
location="args",
help="How many samples to randomly get using tablesample on annotation tables, useful for visualization of large tables does not work as a random sample of query",
)


def check_aligned_volume(aligned_volume):
Expand Down
27 changes: 24 additions & 3 deletions materializationengine/blueprints/client/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
collect_crud_columns,
)
from materializationengine.database import dynamic_annotation_cache, sqlalchemy_cache
from materializationengine.models import MaterializedMetadata
from materializationengine.utils import check_read_permission
from materializationengine.info_client import (
get_relevant_datastack_info,
Expand Down Expand Up @@ -228,6 +229,16 @@ def handle_simple_query(
else:
data["desired_resolution"] = None

random_sample = args.get("random_sample", None)
if random_sample is not None:
session = sqlalchemy_cache.get(mat_db_name)
mat_row_count = (
session.query(MaterializedMetadata.row_count)
.filter(MaterializedMetadata.table_name == table_name)
.scalar()
)
random_sample = (100.0 * random_sample) / mat_row_count

qm = QueryManager(
mat_db_name,
segmentation_source=pcg_table_name,
Expand All @@ -236,9 +247,9 @@ def handle_simple_query(
limit=limit,
offset=data.get("offset", 0),
get_count=get_count,
random_sample=args.get("random_sample", None),
random_sample=random_sample,
)
qm.add_table(table_name)
qm.add_table(table_name, random_sample=True)
qm.apply_filter(data.get("filter_in_dict", None), qm.apply_isin_filter)
qm.apply_filter(data.get("filter_out_dict", None), qm.apply_notequal_filter)
qm.apply_filter(data.get("filter_equal_dict", None), qm.apply_equal_filter)
Expand Down Expand Up @@ -335,6 +346,16 @@ def handle_complex_query(
else:
suffixes = data.get("suffix_map")

random_sample = args.get("random_sample", None)
if random_sample is not None:
session = sqlalchemy_cache.get(db_name)
mat_row_count = (
session.query(MaterializedMetadata.row_count)
.filter(MaterializedMetadata.table_name == data["tables"][0][0])
.scalar()
)
random_sample = (100.0 * random_sample) / mat_row_count

qm = QueryManager(
db_name,
segmentation_source=pcg_table_name,
Expand All @@ -344,7 +365,7 @@ def handle_complex_query(
limit=limit,
offset=data.get("offset", 0),
get_count=False,
random_sample=args.get("random_sample",None)
random_sample=random_sample,
)
if convert_desired_resolution:
if not data.get("desired_resolution", None):
Expand Down
14 changes: 10 additions & 4 deletions materializationengine/blueprints/client/query_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def add_view(self, datastack_name, view_name):

self._voxel_resolutions[view_name] = vox_res

def add_table(self, table_name):
def add_table(self, table_name, random_sample=False):
if table_name not in self._tables:
self._tables.add(table_name)
if self._split_mode:
Expand All @@ -169,7 +169,7 @@ def add_table(self, table_name):
seg_columns = [
c for c in segmodel.__table__.columns if c.key != "id"
]
if self._random_sample:
if random_sample and self._random_sample:
annmodel_alias1 = aliased(
annmodel, tablesample(annmodel, self._random_sample)
)
Expand All @@ -186,7 +186,13 @@ def add_table(self, table_name):
# self._models[segmodel.__tablename__] = segmodel_alias

else:
self._models[table_name] = annmodel
if random_sample and self._random_sample:
annmodel_alias1 = aliased(
annmodel, tablesample(annmodel, self._random_sample)
)
else:
annmodel_alias1 = annmodel
self._models[table_name] = annmodel_alias1
else:
model = self._get_flat_model(table_name)
if self._random_sample:
Expand All @@ -203,7 +209,7 @@ def _find_relevant_model(self, table_name, column_name):

def join_tables(self, table1, column1, table2, column2, isouter=False):

self.add_table(table1)
self.add_table(table1, random_sample=True)
self.add_table(table2)

model1 = self._models[table1]
Expand Down

0 comments on commit 837b42a

Please sign in to comment.