Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flytekit] Add custom agent template in Pyflyte Init #51

Merged
merged 9 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions basic-custom-agent/cookiecutter.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"project_name": "Basic custom agent"
}
10 changes: 10 additions & 0 deletions basic-custom-agent/{{cookiecutter.project_name}}/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
FROM python:3.10-slim-bookworm

MAINTAINER Flyte Team <[email protected]>
LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

# additional dependencies for running in k8s
RUN pip install prometheus-client grpcio-health-checking
# flytekit will autoload the agent if package is installed.
RUN pip install flytekitplugins-bigquery
CMD pyflyte serve agent --port 8000
39 changes: 39 additions & 0 deletions basic-custom-agent/{{cookiecutter.project_name}}/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# flyte-custom-agent-template
How to write your custom agent and build it with a Dockerfile.

## Concepts
1. flytekit will load plugin [here](https://github.com/flyteorg/flytekit/blob/ff2d0da686c82266db4dbf764a009896cf062349/flytekit/__init__.py#L322-L323),
so you must add your plugin to `entry_points` in [setup.py](https://github.com/Future-Outlier/flyte-custom-agent-template/blob/main/flytekit-bigquery/setup.py#L39).
2. Agent registration is triggered by loading the plugin. For example,
BigQuery's agent registration is triggered [here](https://github.com/Future-Outlier/flyte-custom-agent/blob/main/flytekit-bigquery/flytekitplugins/bigquery/agent.py#L97)

## Build your custom agent
1. Following the folder structure in this repo, you can build your custom agent.
2. Build your own custom agent ([learn more](https://docs.flyte.org/en/latest/user_guide/flyte_agents/developing_agents.html))

> In the following command, `localhost:30000` is the Docker registry that ships with the Flyte demo cluster. Use it or replace it with a registry where you have push permissions.

```bash
docker buildx build --platform linux/amd64 -t localhost:30000/flyteagent:custom-bigquery -f Dockerfile .
```

3. Test the image:
```bash
docker run -it localhost:30000/flyteagent:custom-bigquery
```

4. Check the logs (sensor is created by flytekit, bigquery is created by the custom agent)
```
(dev) future@outlier ~ % docker run -it localhost:30000/flyteagent:custom-bigquery

WARNING: The requested image's platform (linux/amd64) does not match the detected host platform (linux/arm64/v8) and no specific platform was requested
🚀 Starting the agent service...
Starting up the server to expose the prometheus metrics...
Agent Metadata
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┓
┃ Agent Name ┃ Support Task Types ┃ Is Sync ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━┩
│ Sensor │ sensor (v0) │ False │
│ Bigquery Agent │ bigquery_query_job_task (v0) │ False │
└────────────────┴───────────────────────────────┴─────────┘
```
47 changes: 47 additions & 0 deletions basic-custom-agent/{{cookiecutter.project_name}}/docker_build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash

set -e

# SET the REGISTRY here, where the docker container should be pushed
REGISTRY=""

# SET the appname here
PROJECT_NAME="{{ cookiecutter.project_name }}"

while getopts a:r:v:h flag
do
case "${flag}" in
a) PROJECT_NAME=${OPTARG};;
r) REGISTRY=${OPTARG};;
v) VERSION=${OPTARG};;
h) echo "Usage: ${0} [-h|[-p <project_name>][-r <registry_name>][-v <version>]]"
echo " h: help (this message)"
echo " p: PROJECT_NAME for your workflows. Defaults to '{{ cookiecutter.project_name }}'."
echo " r: REGISTRY name where the docker container should be pushed. Defaults to none - localhost"
echo " v: VERSION of the build. Defaults to using the current git head SHA"
exit 1;;
*) echo "Usage: ${0} [-h|[-a <project_name>][-r <registry_name>][-v <version>]]"
exit 1;;
esac
done

# If you are using git, then this will automatically use the git head as the
# version
if [ -z "${VERSION}" ]; then
echo "No version set, using git commit head sha as the version"
VERSION=$(git rev-parse HEAD)
fi

TAG=${PROJECT_NAME}:${VERSION}
if [ -z "${REGISTRY}" ]; then
echo "No registry set, creating tag ${TAG}"
else
TAG="${REGISTRY}/${TAG}"
echo "Registry set: creating tag ${TAG}"
fi

# Should be run in the folder that has Dockerfile
docker buildx build --platform linux/amd64 -t "${TAG}" -f Dockerfile .
docker push "${TAG}"

echo "Docker image built with tag ${TAG}. You can use this image to run pyflyte package."
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
.. currentmodule:: flytekitplugins.bigquery

This package contains things that are useful when extending Flytekit.

.. autosummary::
:template: custom.rst
:toctree: generated/

BigQueryConfig
BigQueryTask
BigQueryAgent
"""

from .agent import BigQueryAgent
from .task import BigQueryConfig, BigQueryTask
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import datetime
from dataclasses import dataclass
from typing import Dict, Optional

from flyteidl.core.execution_pb2 import TaskExecution, TaskLog
from google.cloud import bigquery

from flytekit import FlyteContextManager, StructuredDataset, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

pythonTypeToBigQueryType: Dict[type, str] = {
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_sizes
list: "ARRAY",
bool: "BOOL",
bytes: "BYTES",
datetime.datetime: "DATETIME",
float: "FLOAT64",
int: "INT64",
str: "STRING",
}


@dataclass
class BigQueryMetadata(ResourceMeta):
job_id: str
project: str
location: str


class BigQueryAgent(AsyncAgentBase):
name = "Bigquery Agent"

def __init__(self):
super().__init__(task_type_name="bigquery_query_job_task", metadata_type=BigQueryMetadata)

def create(
self,
task_template: TaskTemplate,
inputs: Optional[LiteralMap] = None,
**kwargs,
) -> BigQueryMetadata:
job_config = None
if inputs:
ctx = FlyteContextManager.current_context()
python_interface_inputs = {
name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items()
}
native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)
logger.info(f"Create BigQuery job config with inputs: {native_inputs}")
job_config = bigquery.QueryJobConfig(
query_parameters=[
bigquery.ScalarQueryParameter(name, pythonTypeToBigQueryType[python_interface_inputs[name]], val)
for name, val in native_inputs.items()
]
)

custom = task_template.custom
project = custom["ProjectID"]
location = custom["Location"]
client = bigquery.Client(project=project, location=location)
query_job = client.query(task_template.sql.statement, job_config=job_config)

return BigQueryMetadata(job_id=str(query_job.job_id), location=location, project=project)

def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource:
client = bigquery.Client()
log_link = TaskLog(
uri=f"https://console.cloud.google.com/bigquery?project={resource_meta.project}&j=bq:{resource_meta.location}:{resource_meta.job_id}&page=queryresults",
name="BigQuery Console",
)

job = client.get_job(resource_meta.job_id, resource_meta.project, resource_meta.location)
if job.errors:
logger.error("failed to run BigQuery job with error:", job.errors.__str__())
return Resource(phase=TaskExecution.FAILED, message=job.errors.__str__(), log_links=[log_link])

cur_phase = convert_to_flyte_phase(str(job.state))
res = None

if cur_phase == TaskExecution.SUCCEEDED:
dst = job.destination
if dst:
output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}"
res = {"results": StructuredDataset(uri=output_location)}

return Resource(phase=cur_phase, message=str(job.state), log_links=[log_link], outputs=res)

def delete(self, resource_meta: BigQueryMetadata, **kwargs):
client = bigquery.Client()
client.cancel_job(resource_meta.job_id, resource_meta.project, resource_meta.location)


AgentRegistry.register(BigQueryAgent())
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit import lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.extend import SQLTask
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.models import task as _task_model
from flytekit.types.structured import StructuredDataset

bigquery = lazy_module("google.cloud.bigquery")


@dataclass
class BigQueryConfig(object):
"""
BigQueryConfig should be used to configure a BigQuery Task.
"""

ProjectID: str
Location: Optional[str] = None
QueryJobConfig: Optional[bigquery.QueryJobConfig] = None


class BigQueryTask(AsyncAgentExecutorMixin, SQLTask[BigQueryConfig]):
"""
This is the simplest form of a BigQuery Task, that can be used even for tasks that do not produce any output.
"""

# This task is executed using the BigQuery handler in the backend.
# https://github.com/flyteorg/flyteplugins/blob/43623826fb189fa64dc4cb53e7025b517d911f22/go/tasks/plugins/webapi/bigquery/plugin.go#L34
_TASK_TYPE = "bigquery_query_job_task"

def __init__(
self,
name: str,
query_template: str,
task_config: BigQueryConfig,
inputs: Optional[Dict[str, Type]] = None,
output_structured_dataset_type: Optional[Type[StructuredDataset]] = None,
**kwargs,
):
"""
To be used to query BigQuery Tables.

:param name: Name of this task, should be unique in the project
:param query_template: The actual query to run. We use Flyte's Golang templating format for Query templating. Refer to the templating documentation
:param task_config: BigQueryConfig object
:param inputs: Name and type of inputs specified as an ordered dictionary
:param output_structured_dataset_type: If some data is produced by this query, then you can specify the output StructuredDataset type
:param kwargs: All other args required by Parent type - SQLTask
"""
outputs = None
if output_structured_dataset_type is not None:
outputs = {
"results": output_structured_dataset_type,
}
super().__init__(
name=name,
task_config=task_config,
query_template=query_template,
inputs=inputs,
outputs=outputs,
task_type=self._TASK_TYPE,
**kwargs,
)
self._output_structured_dataset_type = output_structured_dataset_type

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
config = {
"Location": self.task_config.Location,
"ProjectID": self.task_config.ProjectID,
}
if self.task_config.QueryJobConfig is not None:
config.update(self.task_config.QueryJobConfig.to_api_repr()["query"])
s = Struct()
s.update(config)
return json_format.MessageToDict(s)

def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]:
sql = _task_model.Sql(statement=self.query_template, dialect=_task_model.Sql.Dialect.ANSI)
return sql
40 changes: 40 additions & 0 deletions basic-custom-agent/{{cookiecutter.project_name}}/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from setuptools import setup

PLUGIN_NAME = "bigquery"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = [
"flytekit>1.10.7",
"google-cloud-bigquery>=3.21.0",
"google-cloud-bigquery-storage>=2.25.0",
"flyteidl>1.10.7",
]

__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="flyteorg",
author_email="[email protected]",
description="This package holds the Bigquery plugins for flytekit",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.9",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ PROJECT_NAME="{{ cookiecutter.project_name }}"
while getopts a:r:v:h flag
do
case "${flag}" in
p) PROJECT_NAME=${OPTARG};;
a) PROJECT_NAME=${OPTARG};;
r) REGISTRY=${OPTARG};;
v) VERSION=${OPTARG};;
h) echo "Usage: ${0} [-h|[-p <project_name>][-r <registry_name>][-v <version>]]"
Expand Down
Loading