Skip to content

Commit

Permalink
tasks: add registry of tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
kpsherva committed Aug 23, 2024
1 parent 31c48e4 commit b668c7c
Show file tree
Hide file tree
Showing 12 changed files with 334 additions and 53 deletions.
25 changes: 22 additions & 3 deletions invenio_jobs/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
# under the terms of the MIT License; see LICENSE file for more details.

"""Jobs extension."""

import importlib_metadata
from celery import current_app as current_celery_app
from flask import current_app
from invenio_i18n import gettext as _

from . import config
from .models import Task
from .registry import JobsRegistry
from .resources import (
JobsResource,
JobsResourceConfig,
Expand All @@ -40,11 +41,13 @@ def __init__(self, app=None):
if app:
self.init_app(app)

def init_app(self, app):
def init_app(self, app, entry_point_group="invenio_jobs.jobs"):
"""Flask application initialization."""
self.init_config(app)
self.init_services(app)
self.init_resource(app)
self.entry_point_group = entry_point_group
self.registry = JobsRegistry()
app.extensions["invenio-jobs"] = self

def init_config(self, app):
Expand All @@ -69,6 +72,13 @@ def init_resource(self, app):
TasksResourceConfig.build(app), self.tasks_service
)

def load_entry_point_group(self):
    """Lazily yield the loaded object of every entry point in the group.

    The group name is read from ``self.entry_point_group``.
    """
    group_name = self.entry_point_group
    # Collected into a set first: duplicates collapse and the iteration
    # order is therefore unspecified.
    discovered = set(importlib_metadata.entry_points(group=group_name))
    for candidate in discovered:
        yield candidate.load()

@property
def queues(self):
"""Return the queues."""
Expand All @@ -85,13 +95,22 @@ def default_queue(self):
@property
def tasks(self):
"""Return the tasks."""
return Task.all()
# backwards compatibility
return self.jobs

@property
def jobs(self):
    """Return whatever the extension registry reports as registered jobs."""
    registered = self.registry.all_registered_jobs()
    return registered



def finalize_app(app):
"""Finalize app."""
rr_ext = app.extensions["invenio-records-resources"]
ext = app.extensions["invenio-jobs"]
for ep in ext.load_entry_point_group():
ext.registry.register(ep)

# services
rr_ext.registry.register(ext.service, service_id="jobs")
Expand Down
39 changes: 39 additions & 0 deletions invenio_jobs/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Invenio-Jobs is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
from functools import partial


class RegisteredTask:
    """Base class describing a task that can be registered as a job.

    Concrete task classes are usually produced with :meth:`factory`, which
    fills in the class-level attributes declared below.
    """

    # Schema describing the task arguments; ``None`` when the task has none.
    arguments_schema = None
    # The task to run (presumably a Celery task - confirm against callers).
    task = None
    # Identifier of the task.
    id = None
    # Human readable title and description.
    title = None
    description = None

    @classmethod
    def factory(
        cls,
        job_cls_name,
        arguments_schema,
        id_,
        task,
        description,
        title,
        attrs=None,
    ):
        """Create a new task class derived from ``cls``.

        :param job_cls_name: Name given to the generated class.
        :param arguments_schema: Value stored as ``arguments_schema``.
        :param id_: Value stored as ``id``.
        :param task: Value stored as ``task``.
        :param description: Value stored as ``description``.
        :param title: Value stored as ``title``.
        :param attrs: Optional mapping of extra class attributes.
        :returns: The newly created class (not an instance).
        :raises TypeError: If ``attrs`` repeats one of the explicit
            attribute names (``id``, ``task``, ...).
        """
        namespace = dict(
            id=id_,
            arguments_schema=arguments_schema,
            task=task,
            description=description,
            title=title,
            **(attrs or {}),
        )
        # Derive from ``cls`` rather than a hard-coded ``RegisteredTask`` so
        # that a subclass calling ``factory`` keeps its own overrides.
        return type(job_cls_name, (cls,), namespace)

    @classmethod
    def build_task_arguments(cls, job_obj, since=None, custom_args=None, **kwargs):
        """Build the arguments passed to the task.

        The base implementation ignores ``job_obj``, ``since`` and any extra
        keyword arguments; subclasses may override it to use them.

        :returns: ``custom_args`` when it is truthy, otherwise an empty dict.
        """
        return custom_args or {}
80 changes: 53 additions & 27 deletions invenio_jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from sqlalchemy_utils import Timestamp
from sqlalchemy_utils.types import ChoiceType, JSONType, UUIDType
from werkzeug.utils import cached_property
from invenio_jobs.proxies import current_jobs


from .utils import eval_tpl_str, walk_values

Expand All @@ -34,6 +36,11 @@
)


class AttrDict(dict):
    """Dictionary whose items are also reachable as attributes."""

    def __init__(self, *args, **kwargs):
        """Initialise like ``dict`` and alias the instance namespace to it."""
        super().__init__(*args, **kwargs)
        # Using the dict itself as ``__dict__`` makes ``d.key`` and
        # ``d["key"]`` share the same storage.
        self.__dict__ = self

def _dump_dict(model):
    """Return a dict of the model's mapped column attributes, keyed by name."""
    column_attrs = sa.inspect(model).mapper.column_attrs
    dumped = {}
    for column in column_attrs:
        dumped[column.key] = getattr(model, column.key)
    return dumped
Expand All @@ -49,13 +56,27 @@ class Job(db.Model, Timestamp):

task = db.Column(db.String(255))
default_queue = db.Column(db.String(64))
default_args = db.Column(JSON, default=lambda: dict(), nullable=True)
# default_args = db.Column(JSON, default=lambda: dict(), nullable=True)
schedule = db.Column(JSON, nullable=True)

@property
def last_run(self):
"""Last run of the job."""
return self.runs.order_by(Run.created.desc()).first()
_run = self.runs.order_by(Run.created.desc()).first()
return _run if _run else {}

@property
def last_runs(self):
    """Return the most recently created run for each run status.

    The result maps the lower-cased name of every ``RunStatusEnum`` member
    to the newest run with that status, or to an empty dict when the job
    has no run in that status.
    """
    latest_by_status = {}
    for status in RunStatusEnum:
        newest = (
            self.runs.filter_by(status=status)
            .order_by(Run.created.desc())
            .first()
        )
        latest_by_status[status.name.lower()] = newest if newest else {}
    return latest_by_status

@property
def default_args(self):
    """Return the arguments built for this job by its registered task."""
    registered_task = Task.get(self.task)
    return registered_task.build_task_arguments(job_obj=self)

@property
def parsed_schedule(self):
Expand Down Expand Up @@ -136,24 +157,33 @@ def generate_args(cls, job):
We allow a templating mechanism to generate the args for the run. It's important
that the Jinja template context only includes "safe" values, i.e. no DB model
classes or Python objects or functions. Otherwise we risk that users could
execute arbitrary code, or perform harfmul DB operations (e.g. delete rows).
classes or Python objects or functions. Otherwise, we risk that users could
execute arbitrary code, or perform harmful DB operations (e.g. delete rows).
"""
args = deepcopy(job.default_args)
ctx = {"job": job.dump()}

# ctx = {"job": job.dump()}
# Add last runs
last_runs = {}
for status in RunStatusEnum:
run = job.runs.filter_by(status=status).order_by(cls.created.desc()).first()
last_runs[status.name.lower()] = run.dump() if run else None
ctx["last_runs"] = last_runs
ctx["last_run"] = job.last_run.dump() if job.last_run else None
walk_values(args, lambda val: eval_tpl_str(val, ctx))
# last_runs = {}
# for status in RunStatusEnum:
# run = job.runs.filter_by(status=status).order_by(cls.created.desc()).first()
# last_runs[status.name.lower()] = run.dump() if run else None
# ctx["last_runs"] = last_runs
# ctx["last_run"] = job.last_run.dump() if job.last_run else None
import json
args = json.dumps(args, indent=4, sort_keys=True, default=str)
args = json.loads(args)
# walk_values(args, lambda val: eval_tpl_str(val, ctx))
return args

def dump(self):
"""Dump the run as a dictionary."""
return _dump_dict(self)
dict_run = _dump_dict(self)
from invenio_jobs.services.schema import RegisteredTaskArgumentsSchema
serialized_args = RegisteredTaskArgumentsSchema().load({"args": dict_run["args"]})

dict_run["args"] = serialized_args
return dict_run


class Task:
Expand All @@ -177,21 +207,17 @@ def description(self):
return ""
return self._obj.__doc__.split("\n")[0]

@cached_property
def parameters(self):
"""Return the task's parameters."""
# TODO: Make this result more user friendly or enhance with type information
return signature(self._obj).parameters
# @cached_property
# def parameters(self):
# """Return the task's parameters."""
# TODO: Make this result more user friendly or enhance with type information
# return signature(self._obj).parameters

@classmethod
def all(cls):
"""Return all tasks."""
if getattr(cls, "_all_tasks", None) is None:
# Cache results
cls._all_tasks = {
k: cls(task)
for k, task in current_celery_app.tasks.items()
# Filter outer Celery internal tasks
if not k.startswith("celery.")
}
return cls._all_tasks
return current_jobs.jobs

@classmethod
def get(cls, id_):
return cls(current_jobs.registry.get(id_))
49 changes: 49 additions & 0 deletions invenio_jobs/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Invenio-Jobs is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

class JobsRegistry:
    """In-memory mapping of job ids to registered job objects."""

    def __init__(self):
        """Start with an empty registry."""
        self._jobs = {}

    def register(self, job_instance, job_id=None):
        """Add ``job_instance`` to the registry.

        :param job_instance: Object to register.
        :param job_id: Key to register it under; defaults to
            ``job_instance.id`` when omitted.
        :raises RuntimeError: If the id is already taken.
        """
        job_id = job_instance.id if job_id is None else job_id
        if job_id in self._jobs:
            raise RuntimeError(
                f"Job with job id '{job_id}' is already registered."
            )
        self._jobs[job_id] = job_instance

    def get(self, job_id):
        """Return the job registered under ``job_id`` (``KeyError`` if absent)."""
        return self._jobs[job_id]

    def get_job_id(self, instance):
        """Return the id under which ``instance`` is registered.

        :raises KeyError: If no registered job compares equal to ``instance``.
        """
        for known_id, known_job in self._jobs.items():
            if instance == known_job:
                return known_id
        raise KeyError("Job not found in registry.")

    def all_registered_jobs(self):
        """Return the mapping of job ids to registered jobs.

        This is the registry's own dict, not a copy.
        """
        return self._jobs

    def all_arguments(self):
        """Return the ``arguments_schema`` of every registered job, in order."""
        return [job.arguments_schema for job in self._jobs.values()]

    def registered_schemas(self):
        """Return the truthy argument schemas keyed by ``<SchemaName>API``."""
        candidate_schemas = (job.arguments_schema for job in self._jobs.values())
        return {
            f"{schema.__name__}API": schema
            for schema in candidate_schemas
            if schema
        }
4 changes: 3 additions & 1 deletion invenio_jobs/resources/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ class TasksResourceConfig(ResourceConfig, ConfiguratorMixin):
# Blueprint configuration
blueprint_name = "tasks"
url_prefix = "/tasks"
routes = {"list": ""}
routes = {"list": "", "arguments": "/<registered_task_id>/args"}


# Request handling
request_search_args = SearchRequestArgsSchema
request_body_parsers = request_body_parsers
request_view_args = {"registered_task_id": ma.fields.String()}

# Response handling
response_handlers = response_handlers
Expand Down
9 changes: 9 additions & 0 deletions invenio_jobs/resources/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from flask import g
from flask_resources import Resource, resource_requestctx, response_handler, route
from invenio_administration.marshmallow_utils import jsonify_schema
from invenio_records_resources.resources.errors import ErrorHandlersMixin
from invenio_records_resources.resources.records.resource import (
request_data,
Expand All @@ -32,6 +33,7 @@ def create_url_rules(self):
routes = self.config.routes
url_rules = [
route("GET", routes["list"], self.search),
route("GET", routes["arguments"], self.read_arguments)
]

return url_rules
Expand All @@ -50,6 +52,13 @@ def search(self):
)
return hits.to_dict(), 200

@request_view_args
def read_arguments(self):
    """Return the arguments schema of a registered task.

    The task id comes from the ``registered_task_id`` view argument. The
    schema returned by the service is passed through ``jsonify_schema``;
    an empty dict is returned when the service yields no schema.
    """
    identity = g.identity
    task_id = resource_requestctx.view_args["registered_task_id"]
    schema = self.service.read_registered_task_arguments(identity, task_id)
    if not schema:
        return {}
    return jsonify_schema(schema)


class JobsResource(ErrorHandlersMixin, Resource):
"""Jobs resource."""
Expand Down
6 changes: 4 additions & 2 deletions invenio_jobs/services/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class TasksServiceConfig(ServiceConfig, ConfiguratorMixin):
record_cls = Task
search = TasksSearchOptions
schema = TaskSchema
argument_item_cls = results.Item

permission_policy_cls = FromConfig(
"JOBS_TASKS_PERMISSION_POLICY",
Expand Down Expand Up @@ -94,12 +95,13 @@ class JobsServiceConfig(ServiceConfig, ConfiguratorMixin):
default=JobPermissionPolicy,
)

result_item_cls = results.Item
result_list_cls = results.List
result_item_cls = results.JobItem
result_list_cls = results.JobList

links_item = {
"self": JobLink("{+api}/jobs/{id}"),
"runs": JobLink("{+api}/jobs/{id}/runs"),
"self_admin_html": JobLink("{+ui}/administration/jobs/{id}"),
}

links_search = pagination_links("{+api}/jobs{?args*}")
Expand Down
Loading

0 comments on commit b668c7c

Please sign in to comment.