Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Regression of outdated vfolder GQL resolver #3047

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3047.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix regression of outdated `vfolder` GQL resolver.
21 changes: 21 additions & 0 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,27 @@ async def batch_result_in_session(
return [*objs_per_key.values()]


async def batch_result_in_scalar_stream(
    graph_ctx: GraphQueryContext,
    db_sess: SASession,
    query: sa.sql.Select,
    obj_type: type[T_SQLBasedGQLObject],
    key_list: Iterable[T_Key],
    key_getter: Callable[[Row], T_Key],
) -> Sequence[Optional[T_SQLBasedGQLObject]]:
    """
    A batched query adaptor for (key -> item) resolving patterns.

    Streams the scalar results of *query* through the async session and maps
    each row back to its key via *key_getter*.  The returned sequence follows
    the order of *key_list*; keys without a matching row resolve to ``None``.
    """
    # Pre-seed every requested key with None so the output length and order
    # always mirror key_list, even when some keys match no row.
    objs_per_key: dict[T_Key, Optional[T_SQLBasedGQLObject]] = {
        key: None for key in key_list
    }
    async for row in await db_sess.stream_scalars(query):
        objs_per_key[key_getter(row)] = obj_type.from_row(graph_ctx, row)
    return list(objs_per_key.values())


async def batch_multiresult_in_session(
graph_ctx: GraphQueryContext,
db_sess: SASession,
Expand Down
8 changes: 4 additions & 4 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1874,16 +1874,16 @@ async def resolve_vfolder(
user_id: Optional[uuid.UUID] = None,
) -> Optional[VirtualFolder]:
graph_ctx: GraphQueryContext = info.context
user_role = graph_ctx.user["role"]
vfolder_id = uuid.UUID(id)
loader = graph_ctx.dataloader_manager.get_loader(
graph_ctx,
"VirtualFolder.by_id",
Copy link
Member Author

@jopemachine jopemachine Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to use get_loader_by_func and VirtualFolder.batch_load_by_id here, but then I'm getting a type error since the return type of VirtualFolder.batch_load_by_id differs from what get_loader_by_func expects.

user_uuid=user_id,
user_role=user_role,
domain_name=domain_name,
group_id=group_id,
user_id=user_id,
filter=None,
)
return await loader.load(id)
return await loader.load(vfolder_id)

@staticmethod
@scoped_query(autofill_user=False, user_key="user_id")
Expand Down
116 changes: 72 additions & 44 deletions src/ai/backend/manager/models/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
List,
NamedTuple,
Optional,
Self,
Sequence,
TypeAlias,
cast,
Expand All @@ -35,7 +36,7 @@
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, relationship, selectinload
from sqlalchemy.orm import joinedload, load_only, relationship, selectinload

from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.types import (
Expand Down Expand Up @@ -75,6 +76,7 @@
QuotaScopeIDType,
StrEnumType,
batch_multiresult,
batch_result_in_scalar_stream,
metadata,
)
from .group import GroupRow
Expand Down Expand Up @@ -1391,40 +1393,67 @@ class Meta:
status = graphene.String()

@classmethod
def from_row(cls, ctx: GraphQueryContext, row: Row | VFolderRow) -> Optional[VirtualFolder]:
if row is None:
return None

def _get_field(name: str) -> Any:
try:
return row[name]
except sa.exc.NoSuchColumnError:
def from_row(cls, ctx: GraphQueryContext, row: Row | VFolderRow | None) -> Optional[Self]:
match row:
case None:
return None

return cls(
id=row["id"],
host=row["host"],
quota_scope_id=row["quota_scope_id"],
name=row["name"],
user=row["user"],
user_email=_get_field("users_email"),
group=row["group"],
group_name=_get_field("groups_name"),
creator=row["creator"],
domain_name=row["domain_name"],
unmanaged_path=row["unmanaged_path"],
usage_mode=row["usage_mode"],
permission=row["permission"],
ownership_type=row["ownership_type"],
max_files=row["max_files"],
max_size=row["max_size"], # in MiB
created_at=row["created_at"],
last_used=row["last_used"],
# num_attached=row['num_attached'],
cloneable=row["cloneable"],
status=row["status"],
cur_size=row["cur_size"],
)
case VFolderRow():
return cls(
id=row.id,
host=row.host,
quota_scope_id=row.quota_scope_id,
name=row.name,
user=row.user,
user_email=row.user_row.email if row.user_row is not None else None,
group=row.group,
group_name=row.group_row.name if row.group_row is not None else None,
creator=row.creator,
domain_name=row.domain_name,
unmanaged_path=row.unmanaged_path,
usage_mode=row.usage_mode,
permission=row.permission,
ownership_type=row.ownership_type,
max_files=row.max_files,
max_size=row.max_size, # in MiB
created_at=row.created_at,
last_used=row.last_used,
cloneable=row.cloneable,
status=row.status,
cur_size=row.cur_size,
)
case Row():

def _get_field(name: str) -> Any:
try:
return row[name]
except (KeyError, sa.exc.NoSuchColumnError):
return None

return cls(
id=row["id"],
host=row["host"],
quota_scope_id=row["quota_scope_id"],
name=row["name"],
user=row["user"],
user_email=_get_field("users_email"),
group=row["group"],
group_name=_get_field("groups_name"),
creator=row["creator"],
domain_name=row["domain_name"],
unmanaged_path=row["unmanaged_path"],
usage_mode=row["usage_mode"],
permission=row["permission"],
ownership_type=row["ownership_type"],
max_files=row["max_files"],
max_size=row["max_size"], # in MiB
created_at=row["created_at"],
last_used=row["last_used"],
# num_attached=row['num_attached'],
cloneable=row["cloneable"],
status=row["status"],
cur_size=row["cur_size"],
)
raise ValueError(f"Type not allowed to parse (t:{type(row)})")

@classmethod
def from_orm_row(cls, row: VFolderRow) -> VirtualFolder:
Expand Down Expand Up @@ -1593,20 +1622,19 @@ async def load_slice(
async def batch_load_by_id(
cls,
graph_ctx: GraphQueryContext,
ids: list[str],
ids: list[uuid.UUID],
*,
domain_name: str | None = None,
group_id: uuid.UUID | None = None,
user_id: uuid.UUID | None = None,
filter: str | None = None,
) -> Sequence[Sequence[VirtualFolder]]:
domain_name: Optional[str] = None,
group_id: Optional[uuid.UUID] = None,
user_id: Optional[uuid.UUID] = None,
filter: Optional[str] = None,
) -> Sequence[Optional[VirtualFolder]]:
from .user import UserRow

j = sa.join(VFolderRow, UserRow, VFolderRow.user == UserRow.uuid)
query = (
sa.select(VFolderRow)
.select_from(j)
.where(VFolderRow.id.in_(ids))
.options(joinedload(VFolderRow.user_row), joinedload(VFolderRow.group_row))
.order_by(sa.desc(VFolderRow.created_at))
)
if user_id is not None:
Expand All @@ -1619,13 +1647,13 @@ async def batch_load_by_id(
qfparser = QueryFilterParser(cls._queryfilter_fieldspec)
query = qfparser.append_filter(query, filter)
async with graph_ctx.db.begin_readonly_session() as db_sess:
return await batch_multiresult(
return await batch_result_in_scalar_stream(
graph_ctx,
db_sess,
query,
cls,
ids,
lambda row: row["user"],
Copy link
Member Author

@jopemachine jopemachine Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since user_id is provided as a separate argument, the id here should be the vfolder's id.

lambda row: row.id,
)

@classmethod
Expand Down
Loading