Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
Signed-off-by: Jeffrey Kinard <[email protected]>
  • Loading branch information
Polber committed Dec 24, 2024
1 parent 063890c commit e8bc920
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions sdks/python/apache_beam/yaml/yaml_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(
so that input data can be sent using an API request, and inferences can be
received as a response.
This Model Handler also required a `preprocess` function to be defined.
This Model Handler also requires a `preprocess` function to be defined.
Preprocessing and Postprocessing are described in more detail in the
RunInference docs:
https://beam.apache.org/releases/yamldoc/current/#runinference
Expand Down Expand Up @@ -249,6 +249,11 @@ def inference_output_type(self):
('model_id', Optional[str])])


def get_user_schema_fields(user_type):
return [(name, type(typ) if not isinstance(typ, type) else typ)
for (name, typ) in user_type._fields] if user_type else []


@beam.ptransform.ptransform_fn
def run_inference(
pcoll,
Expand Down Expand Up @@ -434,18 +439,17 @@ def fn(x: PredictionResult):
if missing_params:
raise ValueError(f'Missing parameters in model_handler: {missing_params}')
typ = model_handler['type']
model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None)
if model_handler_provider and issubclass(model_handler_provider,
type(ModelHandlerProvider)):
model_handler_provider.validate(model_handler['config'])
else:
model_handler_provider_type = ModelHandlerProvider.handler_types.get(
typ, None)
if not model_handler_provider_type:
raise NotImplementedError(f'Unknown model handler type: {typ}.')

model_handler_provider = ModelHandlerProvider.create_handler(model_handler)
model_handler_provider.validate(model_handler['config'])
user_type = RowTypeConstraint.from_user_type(pcoll.element_type.user_type)
schema = RowTypeConstraint.from_fields(
list(user_type._fields if user_type else []) +
[(inference_tag, model_handler_provider.inference_output_type())])
get_user_schema_fields(user_type) +
[(str(inference_tag), model_handler_provider.inference_output_type())])

return (
pcoll | RunInference(
Expand Down

0 comments on commit e8bc920

Please sign in to comment.