Skip to content

Commit

Permalink
Add ApiError into GraphQL schema (mlflow#12702)
Browse files Browse the repository at this point in the history
Signed-off-by: edwardfeng-db <[email protected]>
  • Loading branch information
edwardfeng-db authored Jul 18, 2024
1 parent dd4b2fd commit afd33b3
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions dev/proto_to_graphql/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self):
self.queries = set() # method_descriptor
self.mutations = set() # method_descriptor
self.inputs = [] # field_descriptor
self.outputs = set() # field_descriptor
self.types = [] # field_descriptor
self.enums = set() # enum_descriptor
self.method_names = set() # package_name_method_name
Expand Down
1 change: 1 addition & 0 deletions dev/proto_to_graphql/parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def process_method(method_descriptor, state):
state.queries.add(method_descriptor)
else:
state.mutations.add(method_descriptor)
state.outputs.add(method_descriptor.output_type)
populate_message_types(method_descriptor.input_type, state, True, set())
populate_message_types(method_descriptor.output_type, state, False, set())

Expand Down
4 changes: 4 additions & 0 deletions dev/proto_to_graphql/schema_autogeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def generate_schema(state):
schema_builder += "import graphene\n"
schema_builder += "import mlflow\n"
schema_builder += "from mlflow.server.graphql.graphql_custom_scalars import LongString\n"
schema_builder += "from mlflow.server.graphql.graphql_errors import ApiError\n"
schema_builder += "from mlflow.utils.proto_json_utils import parse_dict\n"
schema_builder += "\n"

Expand All @@ -103,6 +104,9 @@ def generate_schema(state):
graphene_type = get_graphene_type_for_field(field, False)
schema_builder += f"\n{INDENT}{camel_to_snake(field.name)} = {graphene_type}"

if type in state.outputs:
schema_builder += f"\n{INDENT}apiError = graphene.Field(ApiError)"

if len(type.fields) == 0:
schema_builder += f"\n{INDENT}{DUMMY_FIELD}"

Expand Down
8 changes: 8 additions & 0 deletions mlflow/server/graphql/autogenerated_graphql_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import graphene
import mlflow
from mlflow.server.graphql.graphql_custom_scalars import LongString
from mlflow.server.graphql.graphql_errors import ApiError
from mlflow.utils.proto_json_utils import parse_dict


Expand Down Expand Up @@ -51,6 +52,7 @@ class MlflowModelVersion(graphene.ObjectType):
class MlflowSearchModelVersionsResponse(graphene.ObjectType):
model_versions = graphene.List(graphene.NonNull(MlflowModelVersion))
next_page_token = graphene.String()
apiError = graphene.Field(ApiError)


class MlflowDatasetSummary(graphene.ObjectType):
Expand All @@ -62,6 +64,7 @@ class MlflowDatasetSummary(graphene.ObjectType):

class MlflowSearchDatasetsResponse(graphene.ObjectType):
dataset_summaries = graphene.List(graphene.NonNull(MlflowDatasetSummary))
apiError = graphene.Field(ApiError)


class MlflowMetricWithRunId(graphene.ObjectType):
Expand All @@ -74,6 +77,7 @@ class MlflowMetricWithRunId(graphene.ObjectType):

class MlflowGetMetricHistoryBulkIntervalResponse(graphene.ObjectType):
metrics = graphene.List(graphene.NonNull(MlflowMetricWithRunId))
apiError = graphene.Field(ApiError)


class MlflowFileInfo(graphene.ObjectType):
Expand All @@ -86,6 +90,7 @@ class MlflowListArtifactsResponse(graphene.ObjectType):
root_uri = graphene.String()
files = graphene.List(graphene.NonNull(MlflowFileInfo))
next_page_token = graphene.String()
apiError = graphene.Field(ApiError)


class MlflowDataset(graphene.ObjectType):
Expand Down Expand Up @@ -156,10 +161,12 @@ class MlflowRun(graphene.ObjectType):
class MlflowSearchRunsResponse(graphene.ObjectType):
runs = graphene.List(graphene.NonNull('mlflow.server.graphql.graphql_schema_extensions.MlflowRunExtension'))
next_page_token = graphene.String()
apiError = graphene.Field(ApiError)


class MlflowGetRunResponse(graphene.ObjectType):
run = graphene.Field('mlflow.server.graphql.graphql_schema_extensions.MlflowRunExtension')
apiError = graphene.Field(ApiError)


class MlflowExperimentTag(graphene.ObjectType):
Expand All @@ -179,6 +186,7 @@ class MlflowExperiment(graphene.ObjectType):

class MlflowGetExperimentResponse(graphene.ObjectType):
experiment = graphene.Field(MlflowExperiment)
apiError = graphene.Field(ApiError)


class MlflowSearchModelVersionsInput(graphene.InputObjectType):
Expand Down
15 changes: 15 additions & 0 deletions mlflow/server/graphql/graphql_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import graphene


class ErrorDetail(graphene.ObjectType):
# NOTE: This is not an exhaustive list, might need to add more things in the future if needed.
field = graphene.String()
message = graphene.String()


class ApiError(graphene.ObjectType):
code = graphene.String()
message = graphene.String()
help_url = graphene.String()
trace_id = graphene.String()
error_details = graphene.List(ErrorDetail)

0 comments on commit afd33b3

Please sign in to comment.