Skip to content

Commit

Permalink
Cleaned up code: replaced the repeated inline storage_options dictionaries with a single shared module-level definition in source.py and query.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
dorbaker committed Jun 21, 2024
1 parent 00d88e3 commit 058da6c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 71 deletions.
49 changes: 13 additions & 36 deletions backend/src/api/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,13 @@
TEXT_UNITS_TABLE = "output/create_base_text_units.parquet"
DOCUMENTS_TABLE = "output/create_base_documents.parquet"
# Derive the account name and host from the blob endpoint URL, e.g.
# "https://<account>.blob.core.windows.net" -> name "<account>",
# host "<account>.blob.core.windows.net".
storage_account_blob_url = os.environ["STORAGE_ACCOUNT_BLOB_URL"]
storage_account_name = storage_account_blob_url.split("//")[1].split(".")[0]
storage_account_host = storage_account_blob_url.split("//")[1]
# Shared fsspec options for every pd.read_parquet("abfs://...") call in this
# module, so the connection details are defined in exactly one place.
storage_options = {
    "account_name": storage_account_name,
    "account_host": storage_account_host,
    "credential": DefaultAzureCredential(),
}


@source_route.get(
Expand All @@ -55,11 +60,7 @@ async def get_report_info(index_name: str, report_id: str):
try:
report_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{COMMUNITY_REPORT_TABLE}",
storage_options={
"account_name": storage_account_name,
"account_host": storage_account_host,
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
row = report_table[report_table.community == report_id]
return ReportResponse(text=row["full_content"].values[0])
Expand All @@ -86,19 +87,11 @@ async def get_chunk_info(index_name: str, text_unit_id: str):
try:
text_unit_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{TEXT_UNITS_TABLE}",
storage_options={
"account_name": storage_account_name,
"account_host": storage_account_host,
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
docs = pd.read_parquet(
f"abfs://{sanitized_index_name}/{DOCUMENTS_TABLE}",
storage_options={
"account_name": storage_account_name,
"account_host": storage_account_host,
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
links = {
el["id"]: el["title"]
Expand Down Expand Up @@ -135,11 +128,7 @@ async def get_entity_info(index_name: str, entity_id: int):
try:
entity_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{ENTITY_EMBEDDING_TABLE}",
storage_options={
"account_name": storage_account_name,
"account_host": storage_account_host,
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
row = entity_table[entity_table.human_readable_id == entity_id]
return EntityResponse(
Expand Down Expand Up @@ -176,11 +165,7 @@ async def get_claim_info(index_name: str, claim_id: int):
try:
claims_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{COVARIATES_TABLE}",
storage_options={
"account_name": storage_account_name,
"account_host": storage_account_host,
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
claims_table.human_readable_id = claims_table.human_readable_id.astype(
float
Expand Down Expand Up @@ -219,19 +204,11 @@ async def get_relationship_info(index_name: str, relationship_id: int):
try:
relationship_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{RELATIONSHIPS_TABLE}",
storage_options={
"account_name": storage_account_name,
"account_host": storage_account_host,
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
entity_table = pd.read_parquet(
f"abfs://{sanitized_index_name}/{ENTITY_EMBEDDING_TABLE}",
storage_options={
"account_name": storage_account_name,
"account_host": storage_account_host,
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
row = relationship_table[
relationship_table.human_readable_id == str(relationship_id)
Expand Down
47 changes: 12 additions & 35 deletions backend/src/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

from src.api.azure_clients import BlobServiceClientSingleton

# Shared fsspec options for the pd.read_parquet("abfs://...") calls below.
# Account name and endpoint come from the singleton blob service client; the
# host is the netloc portion of its URL.
_blob_service_url = BlobServiceClientSingleton.get_instance().url
storage_options = {
    "account_name": BlobServiceClientSingleton.get_storage_account_name(),
    "account_host": _blob_service_url.split("//")[1],
    "credential": DefaultAzureCredential(),
}

def get_entities(
entity_table_path: str,
Expand All @@ -21,19 +26,11 @@ def get_entities(
) -> pd.DataFrame:
entity_df = pd.read_parquet(
entity_table_path,
storage_options={
"account_name": "null",#BlobServiceClientSingleton.get_storage_account_name(),
"account_host": BlobServiceClientSingleton.get_instance().url.split("//")[1],
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
entity_embedding_df = pd.read_parquet(
entity_embedding_table_path,
storage_options={
"account_name": "null",#BlobServiceClientSingleton.get_storage_account_name(),
"account_host": BlobServiceClientSingleton.get_instance().url.split("//")[1],
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
return pd.DataFrame(
read_indexer_entities(entity_df, entity_embedding_df, community_level)
Expand All @@ -45,54 +42,34 @@ def get_reports(
) -> pd.DataFrame:
entity_df = pd.read_parquet(
entity_table_path,
storage_options={
"account_name": "null",#BlobServiceClientSingleton.get_storage_account_name(),
"account_host": BlobServiceClientSingleton.get_instance().url.split("//")[1],
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
report_df = pd.read_parquet(
community_report_table_path,
storage_options={
"account_name": "null",#BlobServiceClientSingleton.get_storage_account_name(),
"account_host": BlobServiceClientSingleton.get_instance().url.split("//")[1],
"credential": DefaultAzureCredential(),
},
storage_options=storage_options,
)
return pd.DataFrame(read_indexer_reports(report_df, entity_df, community_level))


def get_relationships(relationships_table_path: str) -> pd.DataFrame:
    """Load the relationships parquet table and convert it to indexer records.

    Args:
        relationships_table_path: abfs:// path to the relationships parquet file.

    Returns:
        DataFrame built from read_indexer_relationships over the raw table.
    """
    relationship_df = pd.read_parquet(
        relationships_table_path,
        # Module-level shared options carry the Azure account/credential info.
        storage_options=storage_options,
    )
    return pd.DataFrame(read_indexer_relationships(relationship_df))


def get_covariates(covariate_table_path: str) -> pd.DataFrame:
    """Load the covariates (claims) parquet table and convert it to indexer records.

    Args:
        covariate_table_path: abfs:// path to the covariates parquet file.

    Returns:
        DataFrame built from read_indexer_covariates over the raw table.
    """
    covariate_df = pd.read_parquet(
        covariate_table_path,
        # Module-level shared options carry the Azure account/credential info.
        storage_options=storage_options,
    )
    return pd.DataFrame(read_indexer_covariates(covariate_df))


def get_text_units(text_unit_table_path: str) -> pd.DataFrame:
    """Load the text-units parquet table and convert it to indexer records.

    Args:
        text_unit_table_path: abfs:// path to the text units parquet file.

    Returns:
        DataFrame built from read_indexer_text_units over the raw table.
    """
    text_unit_df = pd.read_parquet(
        text_unit_table_path,
        # Module-level shared options carry the Azure account/credential info.
        storage_options=storage_options,
    )
    return pd.DataFrame(read_indexer_text_units(text_unit_df))

0 comments on commit 058da6c

Please sign in to comment.