Sanitize user-defined names that are passed into the API endpoints #2

Merged
12 commits merged on Jun 6, 2024
80 changes: 80 additions & 0 deletions backend/src/api/common.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import hashlib
import re
from typing import Annotated

@@ -112,6 +113,85 @@ def validate_blob_container_name(container_name: str):
)


def sanitize_name(name: str | None) -> str | None:
"""
Sanitize a human-readable name by converting it to a SHA256 hash, then truncate
to 128 bit length to ensure it is within the 63 character limit imposed by Azure Storage.

The sanitized name will be used to identify container names in both Azure Storage and CosmosDB.

Args:
-----
name (str)
The name to be sanitized.

Returns: str
The sanitized name.
"""
if not name:
return None
name = name.encode()
name_hash = hashlib.sha256(name)
truncated_hash = name_hash.digest()[:16] # get the first 16 bytes (128 bits)
return truncated_hash.hex()
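

# A minimal usage sketch (illustrative, not part of this diff). It relies only
# on behavior visible above: the digest is deterministic and hex-encoded, and
# falsy inputs pass through as None.
#
#   assert sanitize_name("my-data") == sanitize_name("my-data")  # deterministic
#   assert len(sanitize_name("my-data")) == 32  # 16 bytes -> 32 hex characters
#   assert sanitize_name("") is None and sanitize_name(None) is None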


def retrieve_original_blob_container_name(sanitized_name: str) -> str | None:
"""
Retrieve the original human-readable container name of a sanitized blob name.

Args:
-----
sanitized_name (str)
The sanitized name to be converted back to the original name.

Returns: str
The original human-readable name.
"""
try:
container_store_client = (
azure_storage_client_manager.get_cosmos_container_client(
database_name="graphrag", container_name="container-store"
)
)
for item in container_store_client.read_all_items():
if item["id"] == sanitized_name:
return item["human_readable_name"]
except Exception:
raise HTTPException(
status_code=500, detail="Error retrieving original blob name."
)
return None


def retrieve_original_entity_config_name(sanitized_name: str) -> str | None:
"""
Retrieve the original human-readable entity config name of a sanitized entity config name.

Args:
-----
sanitized_name (str)
The sanitized name to be converted back to the original name.

Returns: str
The original human-readable name.
"""
try:
container_store_client = (
azure_storage_client_manager.get_cosmos_container_client(
database_name="graphrag", container_name="entities"
)
)
for item in container_store_client.read_all_items():
if item["id"] == sanitized_name:
return item["human_readable_name"]
except Exception:
raise HTTPException(
status_code=500, detail="Error retrieving original entity config name."
)
return None
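
# Both lookup helpers above expect CosmosDB documents of the shape written by
# the upsert_item call in data.py; field values here are illustrative:
#
#   {
#       "id": "<32-char hex digest from sanitize_name>",
#       "human_readable_name": "my-data",
#       "type": "data",
#   }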


async def verify_subscription_key_exist(
Ocp_Apim_Subscription_Key: Annotated[str, Header()],
):
41 changes: 32 additions & 9 deletions backend/src/api/data.py
@@ -21,6 +21,7 @@
)
from src.api.common import (
delete_blob_container,
sanitize_name,
validate_blob_container_name,
verify_subscription_key_exist,
)
@@ -60,10 +61,13 @@ async def get_all_data_storage_containers():
)
for item in container_store_client.read_all_items():
if item["type"] == "data":
items.append(item["id"])
items.append(item["human_readable_name"])
except Exception as e:
reporter = ReporterSingleton().get_instance()
reporter.on_error(f"Error getting all data containers: {str(e)}")
reporter.on_error("Error getting list of blob containers.\nDetails: " + str(e))
raise HTTPException(
status_code=500, detail="Error getting list of blob containers."
)
return StorageNameList(storage_name=items)
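
# Effect of the change above (illustrative values): the listing endpoint now
# returns the stored human-readable names instead of the opaque sanitized ids:
#
#   StorageNameList(storage_name=["my-data", "quarterly-reports"])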


@@ -136,11 +140,20 @@ async def upload_files(
HTTPException: If the container name is invalid or if any error occurs during the upload process.
"""
reporter = ReporterSingleton().get_instance()
sanitized_storage_name = sanitize_name(storage_name)
# ensure container name follows Azure Blob Storage naming conventions
validate_blob_container_name(storage_name)
try:
validate_blob_container_name(sanitized_storage_name)
except ValueError:
raise HTTPException(
status_code=500,
detail=f"Invalid blob container name: '{storage_name}'. Please try a different name.",
)
try:
blob_service_client = BlobServiceClientSingletonAsync.get_instance()
container_client = blob_service_client.get_container_client(storage_name)
container_client = blob_service_client.get_container_client(
sanitized_storage_name
)
if not await container_client.exists():
await container_client.create_container()

@@ -166,7 +179,8 @@
)
container_store_client.upsert_item(
{
"id": storage_name,
"id": sanitized_storage_name,
"human_readable_name": storage_name,
"type": "data",
}
)
@@ -175,7 +189,10 @@
reporter.on_error(
"Error uploading files.", details={"ErrorDetails": str(e), "files": files}
)
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail=f"Error uploading files to container '{storage_name}'.",
)


@data_route.delete(
@@ -188,21 +205,27 @@ async def delete_files(storage_name: str):
"""
Delete a specified data storage container.
"""
sanitized_storage_name = sanitize_name(storage_name)
try:
# delete container in Azure Storage
delete_blob_container(storage_name)
delete_blob_container(sanitized_storage_name)
# update container-store in cosmosDB
container_store_client = (
azure_storage_client_manager.get_cosmos_container_client(
database_name="graphrag", container_name="container-store"
)
)
container_store_client.delete_item(storage_name, storage_name)
container_store_client.delete_item(
item=sanitized_storage_name,
partition_key=sanitized_storage_name,
)
except Exception as e:
reporter = ReporterSingleton().get_instance()
reporter.on_error(
f"Error deleting container {storage_name}.",
details={"ErrorDetails": str(e), "Container": storage_name},
)
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500, detail=f"Error deleting container '{storage_name}'."
)
return BaseResponse(status="Success")
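
# Round-trip sketch (hypothetical values): the name mapping is written on
# upload and reversed on read, so clients only ever deal in original names.
#
#   sanitized = sanitize_name("my-data")  # e.g. "3f9c..." (32 hex characters)
#   retrieve_original_blob_container_name(sanitized)  # -> "my-data", or None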
12 changes: 7 additions & 5 deletions backend/src/api/experimental.py
@@ -23,6 +23,7 @@
)

from src.api.common import (
sanitize_name,
validate_index_file_exist,
verify_subscription_key_exist,
)
@@ -66,8 +67,9 @@ async def global_search_streaming(request: GraphRequest):
index_names = [request.index_name]
else:
index_names = request.index_name
sanitized_index_names = [sanitize_name(name) for name in index_names]

for index_name in index_names:
for index_name in sanitized_index_names:
if not _is_index_complete(index_name):
raise HTTPException(
status_code=500,
@@ -77,15 +79,15 @@
ENTITY_TABLE = "output/create_final_nodes.parquet"
COMMUNITY_REPORT_TABLE = "output/create_final_community_reports.parquet"

for index_name in index_names:
for index_name in sanitized_index_names:
validate_index_file_exist(index_name, COMMUNITY_REPORT_TABLE)
validate_index_file_exist(index_name, ENTITY_TABLE)

# current investigations show that community level 1 is the most useful for global search
COMMUNITY_LEVEL = 1
try:
report_dfs = []
for index_name in index_names:
for index_name in sanitized_index_names:
entity_table_path = f"abfs://{index_name}/{ENTITY_TABLE}"
community_report_table_path = (
f"abfs://{index_name}/{COMMUNITY_REPORT_TABLE}"
@@ -101,12 +103,12 @@
if len(report_df["community_id"]) > 0:
max_id = report_df["community_id"].astype(int).max()
report_df["title"] = [
index_names[0] + "<sep>" + i + "<sep>" + t
sanitized_index_names[0] + "<sep>" + i + "<sep>" + t
for i, t in zip(report_df["community_id"], report_df["title"])
]
for idx, df in enumerate(report_dfs[1:]):
df["title"] = [
index_names[idx + 1] + "<sep>" + i + "<sep>" + t
sanitized_index_names[idx + 1] + "<sep>" + i + "<sep>" + t
for i, t in zip(df["community_id"], df["title"])
]
df["community_id"] = [str(int(i) + max_id + 1) for i in df["community_id"]]
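# Illustrative outcome (values hypothetical): after the loops above, each
# report title embeds the sanitized index it came from plus its community id,
#
#   "a1b2c3d4...<sep>42<sep>AI Research Community"
#
# so a global-search result can be attributed back to the index that produced it.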
21 changes: 15 additions & 6 deletions backend/src/api/graph.py
@@ -14,6 +14,7 @@

from src.api.azure_clients import BlobServiceClientSingleton
from src.api.common import (
sanitize_name,
validate_index_file_exist,
verify_subscription_key_exist,
)
@@ -39,12 +40,13 @@
)
async def retrieve_graphml_file(index_name: str):
# validate index_name and graphml file existence
sanitized_index_name = sanitize_name(index_name)
graphml_filename = "summarized_graph.graphml"
blob_filepath = f"output/{graphml_filename}" # expected file location of the graph based on the workflow
validate_index_file_exist(index_name, blob_filepath)
validate_index_file_exist(sanitized_index_name, blob_filepath)
try:
blob_client = blob_service_client.get_blob_client(
container=index_name, blob=blob_filepath
container=sanitized_index_name, blob=blob_filepath
)
blob_stream = blob_client.download_blob().chunks()
return StreamingResponse(
@@ -55,7 +57,10 @@ async def retrieve_graphml_file(index_name: str):
except Exception as e:
reporter = ReporterSingleton().get_instance()
reporter.on_error(f"Could not retrieve graphml file: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail=f"Could not retrieve graphml file for index '{index_name}'.",
)


@graph_route.get(
@@ -66,15 +71,19 @@ async def retrieve_graphml_file(index_name: str):
)
async def retrieve_graph_stats(index_name: str):
# validate index_name and knowledge graph file existence
sanitized_index_name = sanitize_name(index_name)
graph_file = "output/summarized_graph.graphml" # expected filename of the graph based on the indexing workflow
validate_index_file_exist(index_name, graph_file)
validate_index_file_exist(sanitized_index_name, graph_file)
try:
storage_client = blob_service_client.get_container_client(index_name)
storage_client = blob_service_client.get_container_client(sanitized_index_name)
blob_data = storage_client.download_blob(graph_file).readall()
bytes_io = BytesIO(blob_data)
g = nx.read_graphml(bytes_io)
return GraphDataResponse(nodes=len(g.nodes), edges=len(g.edges))
except Exception as e:
reporter = ReporterSingleton().get_instance()
reporter.on_error(f"Could not retrieve graph data file: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail=f"Could not retrieve graph statistics for index '{index_name}'.",
)
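
# Response sketch for the stats endpoint above (numbers hypothetical; the exact
# route path is truncated in this diff and assumed to take the index name):
#
#   GET /graph/stats/<index_name>  ->  GraphDataResponse(nodes=1234, edges=5678)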