diff --git a/materializationengine/blueprints/client/api2.py b/materializationengine/blueprints/client/api2.py index e19cadcd..af49876c 100644 --- a/materializationengine/blueprints/client/api2.py +++ b/materializationengine/blueprints/client/api2.py @@ -1,8 +1,9 @@ +import json import pytz from dynamicannotationdb.models import AnalysisTable, AnalysisVersion from cachetools import TTLCache, cached, LRUCache -from flask import abort, request, current_app, g +from flask import Response, abort, request, current_app, g from flask_accepts import accepts from flask_restx import Namespace, Resource, inputs, reqparse from middle_auth_client import ( @@ -16,6 +17,7 @@ from typing import List import werkzeug from sqlalchemy.sql.sqltypes import String, Integer, Float, DateTime, Boolean, Numeric + from geoalchemy2.types import Geometry import nglui from materializationengine.blueprints.client.datastack import validate_datastack @@ -57,7 +59,7 @@ from materializationengine.info_client import get_aligned_volumes, get_datastack_info from materializationengine.schemas import AnalysisTableSchema, AnalysisVersionSchema from materializationengine.blueprints.client.utils import update_notice_text_warnings - +from materializationengine.blueprints.client.utils import after_request __version__ = "4.28.0" @@ -264,10 +266,10 @@ def get_closest_versions(datastack_name: str, timestamp: datetime.datetime): def check_column_for_root_id(col): - if type(col) == "str": + if isinstance(col, str): if col.endswith("root_id"): abort(400, "we are not presently supporting joins on root_ids") - elif type(col) == list: + elif isinstance(col, list): for c in col: if c.endwith("root_id"): abort(400, "we are not presently supporting joins on root ids") @@ -932,9 +934,14 @@ def process_fields(df, fields, column_names, tags, bool_tags, numerical): continue if isinstance(field, mm_fields.String): + if df[col].isnull().all(): + continue + # check that this column is not all nulls tags.append(col) print(f"tag col: {col}") elif isinstance(field, mm_fields.Boolean): + if df[col].isnull().all(): + continue df[col] = df[col].astype(bool) bool_tags.append(col) print(f"bool tag col: {col}") @@ -953,6 +960,45 @@ def process_fields(df, fields, column_names, tags, bool_tags, numerical): print(f"numerical col: {col}") +def process_view_columns(df, model, column_names, tags, bool_tags, numerical): + for table_column_name, table_column in model.columns.items(): + col = column_names[model.name][table_column.key] + if ( + table_column_name.endswith("_supervoxel_id") + or table_column_name.endswith("_root_id") + or table_column_name == "id" + or table_column_name == "valid" + or table_column_name == "target_id" + ): + continue + + if isinstance(table_column.type, String): + if df[col].isnull().all(): + continue + # check that this column is not all nulls + tags.append(col) + print(f"tag col: {col}") + elif isinstance(table_column.type, Boolean): + if df[col].isnull().all(): + continue + df[col] = df[col].astype(bool) + bool_tags.append(col) + print(f"bool tag col: {col}") + elif isinstance(table_column.type, PostGISField): + # if all the values are NaNs skip this column + if df[col + "_x"].isnull().all(): + continue + numerical.append(col + "_x") + numerical.append(col + "_y") + numerical.append(col + "_z") + print(f"numerical cols: {col}_(x,y,z)") + elif isinstance(table_column.type, (Numeric, Integer, Float)): + if df[col].isnull().all(): + continue + numerical.append(col) + print(f"numerical col: {col}") + + def preprocess_dataframe(df, table_name, aligned_volume_name, column_names): db = dynamic_annotation_cache.get_db(aligned_volume_name) # check if this is a reference table @@ -1002,6 +1048,60 @@ def preprocess_dataframe(df, table_name, aligned_volume_name, column_names): unique_vals = {} for tag in tags: unique_vals[tag] = df[tag].unique() + unique_vals[tag] = unique_vals[tag][~pd.isnull(unique_vals[tag])] + + # find all the duplicate values across columns + vals, counts = np.unique( + np.concatenate([v for v in unique_vals.values()]), return_counts=True + ) + duplicates = vals[counts > 1] + + # iterate through the tags and replace any duplicate + # values in the dataframe with a unique value, + # based on preprending the column name + for tag in tags: + for dup in duplicates: + if dup in unique_vals[tag]: + df[tag] = df[tag].replace(dup, f"{tag}:{dup}") + + return df, tags, bool_tags, numerical, root_id_col + + +def preprocess_view_dataframe(df, view_name, db_name, column_names): + db = dynamic_annotation_cache.get_db(db_name) + # check if this is a reference table + view_table = db.database.get_view_table(view_name) + + # find the first column that ends with _root_id using next + try: + root_id_col = next( + (col for col in df.columns if col.endswith("_root_id")), None + ) + except StopIteration: + raise ValueError("No root_id column found in dataframe") + + # pick only the first row with each root_id + # df = df.drop_duplicates(subset=[root_id_col]) + # drop any row with root_id =0 + df = df[df[root_id_col] != 0] + + # iterate through the columns and put them into + # categories of 'tags' for strings, 'numerical' for numbers + + tags = [] + numerical = [] + bool_tags = [] + + process_view_columns(df, view_table, column_names, tags, bool_tags, numerical) + + # Look across the tag columns and make sure that there are no + # duplicate string values across distinct columns + unique_vals = {} + for tag in tags: + unique_vals[tag] = df[tag].unique() + # remove nan values from unique values + unique_vals[tag] = unique_vals[tag][~pd.isnull(unique_vals[tag])] + # find all the duplicate values across columns vals, counts = np.unique( np.concatenate([v for v in unique_vals.values()]), return_counts=True @@ -1117,7 +1217,9 @@ def get( number_cols=numerical, label_col="id", ) - return seg_prop.to_dict(), 200 + dfjson = json.dumps(seg_prop.to_dict(), cls=current_app.json_encoder) + response = Response(dfjson, status=200, mimetype="application/json") + return after_request(response) @client_bp.route("/datastack//table//info") @@ -1181,7 +1283,9 @@ def get( number_cols=numerical, label_col="id", ) - return seg_prop.to_dict(), 200 + dfjson = json.dumps(seg_prop.to_dict(), cls=current_app.json_encoder) + response = Response(dfjson, status=200, mimetype="application/json") + return after_request(response) @client_bp.expect(query_parser) @@ -1566,6 +1670,149 @@ def get( return md +def assemble_view_dataframe(datastack_name, version, view_name, data, args): + """ + Assemble a dataframe from a view + Args: + datastack_name (str): datastack name + version (int): version number + view_name (str): view name + data (dict): query data + args (dict): query arguments + Returns: + pd.DataFrame: dataframe + list: column names + list: warnings + """ + + aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name) + + if version == 0: + mat_db_name = f"{aligned_volume_name}" + else: + mat_db_name = f"{datastack_name}__mat{version}" + + # check_read_permission(db, table_name) + + max_limit = current_app.config.get("QUERY_LIMIT_SIZE", 200000) + + limit = data.get("limit", max_limit) + if limit > max_limit: + limit = max_limit + + get_count = args.get("count", False) + if get_count: + limit = None + + mat_db = dynamic_annotation_cache.get_db(mat_db_name) + md = mat_db.database.get_view_metadata(datastack_name, view_name) + + if not data.get("desired_resolution", None): + des_res = [ + md["voxel_resolution_x"], + md["voxel_resolution_y"], + md["voxel_resolution_z"], + ] + data["desired_resolution"] = des_res + + qm = QueryManager( + mat_db_name, + segmentation_source=pcg_table_name, + split_mode=False, + limit=limit, + offset=data.get("offset", 0), + get_count=get_count, + ) + qm.add_view(datastack_name, view_name) + qm.apply_filter(data.get("filter_in_dict", None), qm.apply_isin_filter) + qm.apply_filter(data.get("filter_out_dict", None), qm.apply_notequal_filter) + qm.apply_filter(data.get("filter_equal_dict", None), qm.apply_equal_filter) + qm.apply_filter(data.get("filter_spatial_dict", None), qm.apply_spatial_filter) + qm.apply_filter(data.get("filter_regex_dict", None), qm.apply_regex_filter) + select_columns = data.get("select_columns", None) + if select_columns: + for column in select_columns: + qm.select_column(view_name, column) + else: + qm.select_all_columns(view_name) + + df, column_names = qm.execute_query(desired_resolution=data["desired_resolution"]) + df.drop(columns=["deleted", "superceded"], inplace=True, errors="ignore") + warnings = [] + current_app.logger.info("query: {}".format(data)) + current_app.logger.info("args: {}".format(args)) + user_id = str(g.auth_user["id"]) + current_app.logger.info(f"user_id: {user_id}") + + if len(df) == limit: + warnings.append(f'201 - "Limited query to {limit} rows') + warnings = update_notice_text_warnings(md, warnings, view_name) + + return df, column_names, warnings + + +@client_bp.route( + "/datastack//version//view//info" +) +class MatViewSegmentInfo(Resource): + method_decorators = [ + cached(TTLCache(maxsize=256, ttl=60 * 60 * 24)), + validate_datastack, + limit_by_category("query"), + auth_requires_permission("view", table_arg="datastack_name"), + reset_auth, + ] + + def get( + self, + datastack_name: str, + version: int, + view_name: str, + target_datastack: str = None, + target_version: int = None, + ): + """endpoint for getting a segment properties object for a view + + Args: + datastack_name (str): datastack name + version (int): version number + view_name (str): view name + + Returns: + json: a precomputed json object with the segment info for this view + """ + aligned_volume_name, pcg_table_name = get_relevant_datastack_info( + datastack_name + ) + + if version == 0: + mat_db_name = f"{aligned_volume_name}" + else: + mat_db_name = f"{datastack_name}__mat{version}" + + df, column_names, warnings = assemble_view_dataframe( + datastack_name, version, view_name, {}, {} + ) + + df, tags, bool_tags, numerical, root_id_col = preprocess_view_dataframe( + df, view_name, mat_db_name, column_names + ) + + seg_prop = nglui.segmentprops.SegmentProperties.from_dataframe( + df, + id_col=root_id_col, + tag_value_cols=tags, + tag_bool_cols=bool_tags, + number_cols=numerical, + label_col=df.columns[0], + ) + # use the current_app encoder to encode the seg_prop.to_dict() + # to ensure that the json is serialized correctly + dfjson = json.dumps(seg_prop.to_dict(), cls=current_app.json_encoder) + response = Response(dfjson, status=200, mimetype="application/json") + return after_request(response) + + @client_bp.expect(query_parser) @client_bp.route( "/datastack//version//views//query" @@ -1634,74 +1881,9 @@ def post( """ args = query_parser.parse_args() data = request.parsed_obj - # db = validate_table_args([table_name], target_datastack, target_version) - aligned_volume_name, pcg_table_name = get_relevant_datastack_info( - datastack_name - ) - session = sqlalchemy_cache.get(aligned_volume_name) - - # check_read_permission(db, table_name) - - max_limit = current_app.config.get("QUERY_LIMIT_SIZE", 200000) - - limit = data.get("limit", max_limit) - if limit > max_limit: - limit = max_limit - - get_count = args.get("count", False) - if get_count: - limit = None - - if version == 0: - mat_db_name = f"{aligned_volume_name}" - else: - mat_db_name = f"{datastack_name}__mat{version}" - - mat_db = dynamic_annotation_cache.get_db(mat_db_name) - md = mat_db.database.get_view_metadata(datastack_name, view_name) - - if not data.get("desired_resolution", None): - des_res = [ - md["voxel_resolution_x"], - md["voxel_resolution_y"], - md["voxel_resolution_z"], - ] - data["desired_resolution"] = des_res - - qm = QueryManager( - mat_db_name, - segmentation_source=pcg_table_name, - split_mode=False, - limit=limit, - offset=data.get("offset", 0), - get_count=get_count, - ) - qm.add_view(datastack_name, view_name) - qm.apply_filter(data.get("filter_in_dict", None), qm.apply_isin_filter) - qm.apply_filter(data.get("filter_out_dict", None), qm.apply_notequal_filter) - qm.apply_filter(data.get("filter_equal_dict", None), qm.apply_equal_filter) - qm.apply_filter(data.get("filter_spatial_dict", None), qm.apply_spatial_filter) - qm.apply_filter(data.get("filter_regex_dict", None), qm.apply_regex_filter) - select_columns = data.get("select_columns", None) - if select_columns: - for column in select_columns: - qm.select_column(view_name, column) - else: - qm.select_all_columns(view_name) - - df, column_names = qm.execute_query( - desired_resolution=data["desired_resolution"] + df, column_names, warnings = assemble_view_dataframe( + datastack_name, version, view_name, data, args ) - df.drop(columns=["deleted", "superceded"], inplace=True, errors="ignore") - warnings = [] - current_app.logger.info("query: {}".format(data)) - current_app.logger.info("args: {}".format(args)) - user_id = str(g.auth_user["id"]) - current_app.logger.info(f"user_id: {user_id}") - - if len(df) == limit: - warnings.append(f'201 - "Limited query to {limit} rows') - warnings = update_notice_text_warnings(md, warnings, view_name) return create_query_response( df, warnings=warnings, diff --git a/materializationengine/schemas.py b/materializationengine/schemas.py index c42e4342..55969790 100644 --- a/materializationengine/schemas.py +++ b/materializationengine/schemas.py @@ -1,4 +1,9 @@ -from dynamicannotationdb.models import AnalysisTable, AnalysisVersion, VersionErrorTable +from dynamicannotationdb.models import ( + AnalysisTable, + AnalysisVersion, + VersionErrorTable, + AnalysisView, +) from flask_marshmallow import Marshmallow from marshmallow import fields, ValidationError, Schema from marshmallow_sqlalchemy import SQLAlchemyAutoSchema @@ -18,6 +23,14 @@ class Meta: load_instance = True +class AnalysisViewSchema(SQLAlchemyAutoSchema): + class Meta: + model = AnalysisView + load_instance = True + fields = ("id", "table_name", "description") + ordered = True + + class VersionErrorTableSchema(SQLAlchemyAutoSchema): class Meta: model = VersionErrorTable diff --git a/materializationengine/views.py b/materializationengine/views.py index 6e1a5ab0..593c53fb 100644 --- a/materializationengine/views.py +++ b/materializationengine/views.py @@ -13,6 +13,7 @@ VersionErrorTable, AnnoMetadata, MaterializedMetadata, + AnalysisView, ) from dynamicannotationdb.schema import DynamicSchemaClient from flask import ( @@ -27,6 +28,7 @@ from middle_auth_client import auth_required, auth_requires_permission from sqlalchemy import and_, func, or_ from sqlalchemy.sql import text +from materializationengine.blueprints.client.schemas import AnalysisViewSchema from materializationengine.celery_init import celery from celery.result import AsyncResult from materializationengine.blueprints.reset_auth import reset_auth @@ -43,6 +45,7 @@ from materializationengine.schemas import ( AnalysisTableSchema, AnalysisVersionSchema, + AnalysisViewSchema, VersionErrorTableSchema, ) from materializationengine.utils import check_read_permission @@ -142,12 +145,13 @@ def make_df_with_links_to_id( df = pd.DataFrame(data=schema.dump(objects, many=True)) if urlkwargs is None: urlkwargs = {} - df[col] = df.apply( - lambda x: "{}".format( - url_for(url, id=getattr(x, col_value), **urlkwargs), x[col] - ), - axis=1, - ) + if url is not None: + df[col] = df.apply( + lambda x: "{}".format( + url_for(url, id=getattr(x, col_value), **urlkwargs), x[col] + ), + axis=1, + ) return df @@ -243,16 +247,25 @@ def version_error(datastack_name: str, id: int): ) -def make_seg_prop_ng_link(datastack_name, table_name, version, client): +def make_seg_prop_ng_link(datastack_name, table_name, version, client, is_view=False): seg_layer = client.info.segmentation_source(format_for="neuroglancer") seg_layer.replace("graphene://https://", "graphene://middleauth+https://") - seginfo_url = url_for( - "api.Materialization Client2_mat_table_segment_info", - datastack_name=datastack_name, - version=version, - table_name=table_name, - _external=True, - ) + if is_view: + seginfo_url = url_for( + "api.Materialization Client2_mat_view_segment_info", + datastack_name=datastack_name, + version=version, + view_name=table_name, + _external=True, + ) + else: + seginfo_url = url_for( + "api.Materialization Client2_mat_table_segment_info", + datastack_name=datastack_name, + version=version, + table_name=table_name, + _external=True, + ) seg_info_source = f"precomputed://middleauth+{seginfo_url}".format( seginfo_url=seginfo_url @@ -329,11 +342,34 @@ def version_view(datastack_name: str, id: int): escape=False, classes=classes, index=False, justify="left", border=0 ) + mat_session = sqlalchemy_cache.get(f"{datastack_name}__mat{version.version}") + + views = mat_session.query(AnalysisView).all() + + views_df = make_df_with_links_to_id( + objects=views, + schema=AnalysisViewSchema(many=True), + url=None, + col=None, + col_value=None, + datastack_name=datastack_name, + ) + views_df["ng_link"] = views_df.apply( + lambda x: f"seg prop link", + axis=1, + ) + classes = ["table table-borderless"] + with pd.option_context("display.max_colwidth", -1): + output_view_html = views_df.to_html( + escape=False, classes=classes, index=False, justify="left", border=0 + ) + return render_template( "version.html", datastack=datastack_name, analysisversion=version, table=output_html, + view_table=output_view_html, version=__version__, ) diff --git a/templates/version.html b/templates/version.html index f3e45cc6..f065ed28 100644 --- a/templates/version.html +++ b/templates/version.html @@ -1,11 +1,13 @@ {% extends 'base.html' %} {% block header %} -

{% block title %}{{datastack}}{% endblock %}

+

{% block title %}{{datastack}}{% endblock %}

{% endblock %} {% block content %}

tables

{{table|safe}} +

views

+{{view_table|safe}} {% endblock %} \ No newline at end of file