Commit: Sanitize user-defined names that are passed into the API endpoints (#2)
* add name sanitization
* update pytest version
* update error messages and fix local search bug
* update notebook
* improve error handling

---------

Co-authored-by: Christine Caggiano <[email protected]>
jgbradley1 and Christine Caggiano authored Jun 6, 2024
1 parent 268c6ad commit 4ff1367
Showing 13 changed files with 607 additions and 433 deletions.
80 changes: 80 additions & 0 deletions backend/src/api/common.py
@@ -1,6 +1,7 @@
***REMOVED***
***REMOVED***

+import hashlib
import re
from typing import Annotated

@@ -112,6 +113,85 @@ def validate_blob_container_name(container_name: str):
***REMOVED***


+def sanitize_name(name: str | None) -> str | None:
+    """
+    Sanitize a human-readable name by converting it to a SHA256 hash and
+    truncating the digest to 128 bits, which keeps the result within the
+    63-character limit imposed by Azure Storage.
+
+    The sanitized name is used to identify container names in both Azure Storage
+    and CosmosDB.
+
+    Args:
+    -----
+    name (str)
+        The name to be sanitized.
+
+    Returns: str
+        The sanitized name.
+    """
+    if not name:
+        return None
+    name = name.encode()
+    name_hash = hashlib.sha256(name)
+    truncated_hash = name_hash.digest()[:16]  # the first 16 bytes (128 bits)
+    return truncated_hash.hex()
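For context on what this new helper produces, here is a standalone sketch (not part of the diff) of the reconstructed behavior above: the truncated-SHA256 hex digest is deterministic and always 32 characters long, comfortably inside Azure's 63-character container-name limit.

import hashlib

def demo_sanitize(name: str) -> str | None:
    # mirrors sanitize_name above; demo-only helper name
    if not name:
        return None
    return hashlib.sha256(name.encode()).digest()[:16].hex()

print(demo_sanitize("My Research Data (v2)"))       # 32-character hex string
print(len(demo_sanitize("My Research Data (v2)")))  # 32
print(demo_sanitize("My Research Data (v2)")
      == demo_sanitize("My Research Data (v2)"))    # True: same name, same id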


+def retrieve_original_blob_container_name(sanitized_name: str) -> str | None:
+    """
+    Retrieve the original human-readable container name of a sanitized blob name.
+
+    Args:
+    -----
+    sanitized_name (str)
+        The sanitized name to be converted back to the original name.
+
+    Returns: str
+        The original human-readable name.
+    """
+    try:
+        container_store_client = (
+            azure_storage_client_manager.get_cosmos_container_client(
+                database_name="graphrag", container_name="container-store"
+            )
+        )
+        for item in container_store_client.read_all_items():
+            if item["id"] == sanitized_name:
+                return item["human_readable_name"]
+    except Exception:
+        raise HTTPException(
+            status_code=500, detail="Error retrieving original blob name."
+        )
+    return None


+def retrieve_original_entity_config_name(sanitized_name: str) -> str | None:
+    """
+    Retrieve the original human-readable entity config name of a sanitized entity config name.
+
+    Args:
+    -----
+    sanitized_name (str)
+        The sanitized name to be converted back to the original name.
+
+    Returns: str
+        The original human-readable name.
+    """
+    try:
+        container_store_client = (
+            azure_storage_client_manager.get_cosmos_container_client(
+                database_name="graphrag", container_name="entities"
+            )
+        )
+        for item in container_store_client.read_all_items():
+            if item["id"] == sanitized_name:
+                return item["human_readable_name"]
+    except Exception:
+        raise HTTPException(
+            status_code=500, detail="Error retrieving original entity config name."
+        )
+    return None
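Taken together, sanitize_name and the two retrieve_original_* helpers implement a one-way hash plus a lookup table: the hash cannot be reversed, so the original name is recovered by scanning the stored id-to-name mapping. A minimal in-memory sketch of that round trip (a plain dict stands in for the CosmosDB container-store; all names here are hypothetical):

import hashlib

def sanitize_name(name):
    return hashlib.sha256(name.encode()).digest()[:16].hex() if name else None

container_store = {}  # stand-in for the CosmosDB "container-store" container

def register(human_readable_name):
    sanitized = sanitize_name(human_readable_name)
    container_store[sanitized] = {
        "id": sanitized,
        "human_readable_name": human_readable_name,
    }
    return sanitized

def retrieve_original(sanitized_name):
    item = container_store.get(sanitized_name)
    return item["human_readable_name"] if item else None

sid = register("quarterly-reports")
assert retrieve_original(sid) == "quarterly-reports"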


async def verify_subscription_key_exist(
Ocp_Apim_Subscription_Key: Annotated[str, Header()],
):
41 changes: 32 additions & 9 deletions backend/src/api/data.py
@@ -21,6 +21,7 @@
)
from src.api.common import (
    delete_blob_container,
+    sanitize_name,
    validate_blob_container_name,
    verify_subscription_key_exist,
)
@@ -60,10 +61,13 @@ async def get_all_data_storage_containers():
    try:
        for item in container_store_client.read_all_items():
            if item["type"] == "data":
-                items.append(item["id"])
+                items.append(item["human_readable_name"])
    except Exception as e:
        reporter = ReporterSingleton().get_instance()
-        reporter.on_error(f"Error getting all data containers: {str(e)}")
+        reporter.on_error("Error getting list of blob containers.\nDetails: " + str(e))
+        raise HTTPException(
+            status_code=500, detail="Error getting list of blob containers."
+        )
    return StorageNameList(storage_name=items)


@@ -136,11 +140,20 @@ async def upload_files(
        HTTPException: If the container name is invalid or if any error occurs during the upload process.
    """
    reporter = ReporterSingleton().get_instance()
+    sanitized_storage_name = sanitize_name(storage_name)
    # ensure container name follows Azure Blob Storage naming conventions
-    validate_blob_container_name(storage_name)
+    try:
+        validate_blob_container_name(sanitized_storage_name)
+    except ValueError:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Invalid blob container name: '{storage_name}'. Please try a different name.",
+        )
    try:
        blob_service_client = BlobServiceClientSingletonAsync.get_instance()
-        container_client = blob_service_client.get_container_client(storage_name)
+        container_client = blob_service_client.get_container_client(
+            sanitized_storage_name
+        )
        if not await container_client.exists():
            await container_client.create_container()

@@ -166,7 +179,8 @@
***REMOVED***
        container_store_client.upsert_item(
            {
-                "id": storage_name,
+                "id": sanitized_storage_name,
+                "human_readable_name": storage_name,
                "type": "data",
            }
        )
@@ -175,7 +189,10 @@
        reporter.on_error(
            "Error uploading files.", details={"ErrorDetails": str(e), "files": files}
        )
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error uploading files to container '{storage_name}'.",
+        )


@data_route.delete(
@@ -188,21 +205,27 @@ async def delete_files(storage_name: str):
    """
    Delete a specified data storage container.
    """
+    sanitized_storage_name = sanitize_name(storage_name)
    try:
        # delete container in Azure Storage
-        delete_blob_container(storage_name)
+        delete_blob_container(sanitized_storage_name)
        # update container-store in cosmosDB
        container_store_client = (
            azure_storage_client_manager.get_cosmos_container_client(
                database_name="graphrag", container_name="container-store"
            )
        )
-        container_store_client.delete_item(storage_name, storage_name)
+        container_store_client.delete_item(
+            item=sanitized_storage_name,
+            partition_key=sanitized_storage_name,
+        )
    except Exception as e:
        reporter = ReporterSingleton().get_instance()
        reporter.on_error(
            f"Error deleting container {storage_name}.",
            details={"ErrorDetails": str(e), "Container": storage_name},
        )
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(
+            status_code=500, detail=f"Error deleting container '{storage_name}'."
+        )
    return BaseResponse(status="Success")
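Across these endpoints the change follows one pattern: sanitize the user-supplied name once at the top of the handler, address Azure Storage and CosmosDB by the sanitized id, and keep the original name only for user-facing messages. A condensed sketch of that flow (handle and do_storage_work are hypothetical names; sanitize_name is the helper from common.py above):

from fastapi import HTTPException

def do_storage_work(container_id: str) -> None:
    ...  # hypothetical stand-in for the blob/Cosmos operations

def handle(storage_name: str):
    sanitized = sanitize_name(storage_name)  # internal id for all storage operations
    try:
        do_storage_work(sanitized)
    except Exception:
        # surface the human-readable name, never the internal hash or raw exception
        raise HTTPException(
            status_code=500, detail=f"Error processing container '{storage_name}'."
        )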
12 changes: 7 additions & 5 deletions backend/src/api/experimental.py
@@ -23,6 +23,7 @@
)

from src.api.common import (
+    sanitize_name,
    validate_index_file_exist,
    verify_subscription_key_exist,
)
@@ -66,8 +67,9 @@ async def global_search_streaming(request: GraphRequest):
        index_names = [request.index_name]
    else:
        index_names = request.index_name
+    sanitized_index_names = [sanitize_name(name) for name in index_names]

-    for index_name in index_names:
+    for index_name in sanitized_index_names:
        if not _is_index_complete(index_name):
            raise HTTPException(
                status_code=500,
@@ -77,15 +79,15 @@
    ENTITY_TABLE = "output/create_final_nodes.parquet"
    COMMUNITY_REPORT_TABLE = "output/create_final_community_reports.parquet"

-    for index_name in index_names:
+    for index_name in sanitized_index_names:
        validate_index_file_exist(index_name, COMMUNITY_REPORT_TABLE)
        validate_index_file_exist(index_name, ENTITY_TABLE)

    # current investigations show that community level 1 is the most useful for global search
    COMMUNITY_LEVEL = 1
    try:
        report_dfs = []
-        for index_name in index_names:
+        for index_name in sanitized_index_names:
            entity_table_path = f"abfs://{index_name}/{ENTITY_TABLE}"
            community_report_table_path = (
                f"abfs://{index_name}/{COMMUNITY_REPORT_TABLE}"
@@ -101,12 +103,12 @@
        if len(report_df["community_id"]) > 0:
            max_id = report_df["community_id"].astype(int).max()
        report_df["title"] = [
-            index_names[0] + "<sep>" + i + "<sep>" + t
+            sanitized_index_names[0] + "<sep>" + i + "<sep>" + t
            for i, t in zip(report_df["community_id"], report_df["title"])
        ]
        for idx, df in enumerate(report_dfs[1:]):
            df["title"] = [
-                index_names[idx + 1] + "<sep>" + i + "<sep>" + t
+                sanitized_index_names[idx + 1] + "<sep>" + i + "<sep>" + t
                for i, t in zip(df["community_id"], df["title"])
            ]
            df["community_id"] = [str(int(i) + max_id + 1) for i in df["community_id"]]
21 changes: 15 additions & 6 deletions backend/src/api/graph.py
@@ -14,6 +14,7 @@

from src.api.azure_clients import BlobServiceClientSingleton
from src.api.common import (
+    sanitize_name,
    validate_index_file_exist,
    verify_subscription_key_exist,
)
@@ -39,12 +40,13 @@
)
async def retrieve_graphml_file(index_name: str):
    # validate index_name and graphml file existence
+    sanitized_index_name = sanitize_name(index_name)
    graphml_filename = "summarized_graph.graphml"
    blob_filepath = f"output/{graphml_filename}"  # expected file location of the graph based on the workflow
-    validate_index_file_exist(index_name, blob_filepath)
+    validate_index_file_exist(sanitized_index_name, blob_filepath)
    try:
        blob_client = blob_service_client.get_blob_client(
-            container=index_name, blob=blob_filepath
+            container=sanitized_index_name, blob=blob_filepath
        )
        blob_stream = blob_client.download_blob().chunks()
        return StreamingResponse(
@@ -55,7 +57,10 @@
    except Exception as e:
        reporter = ReporterSingleton().get_instance()
        reporter.on_error(f"Could not retrieve graphml file: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(
+            status_code=500,
+            detail=f"Could not retrieve graphml file for index '{index_name}'.",
+        )


@graph_route.get(
@@ -66,15 +71,19 @@
)
async def retrieve_graph_stats(index_name: str):
    # validate index_name and knowledge graph file existence
+    sanitized_index_name = sanitize_name(index_name)
    graph_file = "output/summarized_graph.graphml"  # expected filename of the graph based on the indexing workflow
-    validate_index_file_exist(index_name, graph_file)
+    validate_index_file_exist(sanitized_index_name, graph_file)
    try:
-        storage_client = blob_service_client.get_container_client(index_name)
+        storage_client = blob_service_client.get_container_client(sanitized_index_name)
        blob_data = storage_client.download_blob(graph_file).readall()
        bytes_io = BytesIO(blob_data)
        g = nx.read_graphml(bytes_io)
        return GraphDataResponse(nodes=len(g.nodes), edges=len(g.edges))
    except Exception as e:
        reporter = ReporterSingleton().get_instance()
        reporter.on_error(f"Could not retrieve graph data file: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(
+            status_code=500,
+            detail=f"Could not retrieve graph statistics for index '{index_name}'.",
+        )
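As a sanity check on the stats logic, round-tripping a GraphML document through bytes with networkx and counting nodes and edges works the same way outside the endpoint (toy graph, no Azure involved):

from io import BytesIO
import networkx as nx

g = nx.Graph()
g.add_edges_from([("a", "b"), ("b", "c")])

buf = BytesIO()
nx.write_graphml(g, buf)  # serialize, as the indexing workflow would
buf.seek(0)

g2 = nx.read_graphml(buf)  # deserialize, as retrieve_graph_stats does
print(len(g2.nodes), len(g2.edges))  # 3 2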