diff --git a/dev/proto_to_graphql/code_generator.py b/dev/proto_to_graphql/code_generator.py index b7604aa905bbd..f70789f597459 100644 --- a/dev/proto_to_graphql/code_generator.py +++ b/dev/proto_to_graphql/code_generator.py @@ -16,6 +16,7 @@ def __init__(self): self.queries = set() # method_descriptor self.mutations = set() # method_descriptor self.inputs = [] # field_descriptor + self.outputs = set() # field_descriptor self.types = [] # field_descriptor self.enums = set() # enum_descriptor self.method_names = set() # package_name_method_name diff --git a/dev/proto_to_graphql/parsing_utils.py b/dev/proto_to_graphql/parsing_utils.py index 24c7a03db4b93..7c06e5b58050f 100644 --- a/dev/proto_to_graphql/parsing_utils.py +++ b/dev/proto_to_graphql/parsing_utils.py @@ -26,6 +26,7 @@ def process_method(method_descriptor, state): state.queries.add(method_descriptor) else: state.mutations.add(method_descriptor) + state.outputs.add(method_descriptor.output_type) populate_message_types(method_descriptor.input_type, state, True, set()) populate_message_types(method_descriptor.output_type, state, False, set()) diff --git a/dev/proto_to_graphql/schema_autogeneration.py b/dev/proto_to_graphql/schema_autogeneration.py index 7a0522ceb19f4..a7ea2db20c6f3 100644 --- a/dev/proto_to_graphql/schema_autogeneration.py +++ b/dev/proto_to_graphql/schema_autogeneration.py @@ -84,6 +84,7 @@ def generate_schema(state): schema_builder += "import graphene\n" schema_builder += "import mlflow\n" schema_builder += "from mlflow.server.graphql.graphql_custom_scalars import LongString\n" + schema_builder += "from mlflow.server.graphql.graphql_errors import ApiError\n" schema_builder += "from mlflow.utils.proto_json_utils import parse_dict\n" schema_builder += "\n" @@ -103,6 +104,9 @@ def generate_schema(state): graphene_type = get_graphene_type_for_field(field, False) schema_builder += f"\n{INDENT}{camel_to_snake(field.name)} = {graphene_type}" + if type in state.outputs: + schema_builder += f"\n{INDENT}apiError = graphene.Field(ApiError)" + if len(type.fields) == 0: schema_builder += f"\n{INDENT}{DUMMY_FIELD}" diff --git a/mlflow/server/graphql/autogenerated_graphql_schema.py b/mlflow/server/graphql/autogenerated_graphql_schema.py index 3d1c91e1a214e..1ddfbd531b442 100644 --- a/mlflow/server/graphql/autogenerated_graphql_schema.py +++ b/mlflow/server/graphql/autogenerated_graphql_schema.py @@ -3,6 +3,7 @@ import graphene import mlflow from mlflow.server.graphql.graphql_custom_scalars import LongString +from mlflow.server.graphql.graphql_errors import ApiError from mlflow.utils.proto_json_utils import parse_dict @@ -51,6 +52,7 @@ class MlflowModelVersion(graphene.ObjectType): class MlflowSearchModelVersionsResponse(graphene.ObjectType): model_versions = graphene.List(graphene.NonNull(MlflowModelVersion)) next_page_token = graphene.String() + apiError = graphene.Field(ApiError) class MlflowDatasetSummary(graphene.ObjectType): @@ -62,6 +64,7 @@ class MlflowDatasetSummary(graphene.ObjectType): class MlflowSearchDatasetsResponse(graphene.ObjectType): dataset_summaries = graphene.List(graphene.NonNull(MlflowDatasetSummary)) + apiError = graphene.Field(ApiError) class MlflowMetricWithRunId(graphene.ObjectType): @@ -74,6 +77,7 @@ class MlflowMetricWithRunId(graphene.ObjectType): class MlflowGetMetricHistoryBulkIntervalResponse(graphene.ObjectType): metrics = graphene.List(graphene.NonNull(MlflowMetricWithRunId)) + apiError = graphene.Field(ApiError) class MlflowFileInfo(graphene.ObjectType): @@ -86,6 +90,7 @@ class MlflowListArtifactsResponse(graphene.ObjectType): root_uri = graphene.String() files = graphene.List(graphene.NonNull(MlflowFileInfo)) next_page_token = graphene.String() + apiError = graphene.Field(ApiError) class MlflowDataset(graphene.ObjectType): @@ -156,10 +161,12 @@ class MlflowRun(graphene.ObjectType): class MlflowSearchRunsResponse(graphene.ObjectType): runs = graphene.List(graphene.NonNull('mlflow.server.graphql.graphql_schema_extensions.MlflowRunExtension')) next_page_token = graphene.String() + apiError = graphene.Field(ApiError) class MlflowGetRunResponse(graphene.ObjectType): run = graphene.Field('mlflow.server.graphql.graphql_schema_extensions.MlflowRunExtension') + apiError = graphene.Field(ApiError) class MlflowExperimentTag(graphene.ObjectType): @@ -179,6 +186,7 @@ class MlflowExperiment(graphene.ObjectType): class MlflowGetExperimentResponse(graphene.ObjectType): experiment = graphene.Field(MlflowExperiment) + apiError = graphene.Field(ApiError) class MlflowSearchModelVersionsInput(graphene.InputObjectType): diff --git a/mlflow/server/graphql/graphql_errors.py b/mlflow/server/graphql/graphql_errors.py new file mode 100644 index 0000000000000..38938808dc030 --- /dev/null +++ b/mlflow/server/graphql/graphql_errors.py @@ -0,0 +1,15 @@ +import graphene + + +class ErrorDetail(graphene.ObjectType): + # NOTE: This is not an exhaustive list, might need to add more things in the future if needed. + field = graphene.String() + message = graphene.String() + + +class ApiError(graphene.ObjectType): + code = graphene.String() + message = graphene.String() + help_url = graphene.String() + trace_id = graphene.String() + error_details = graphene.List(ErrorDetail)