Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functools wraps to inspect decorators #1171

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

rusheb
Copy link
Contributor

@rusheb rusheb commented Jan 22, 2025

This PR contains:

  • New features
  • Changes to dev-tools e.g. CI config / github tooling
  • Docs
  • Bug fixes
  • Code refactor

What is the current behavior? (You can also link to an open issue here)

It's not possible to get type hints for Inspect decorated tasks. Example:

from enum import StrEnum
from typing import get_type_hints

from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.solver._solver import Generate, solver
from inspect_ai.solver._task_state import TaskState


class EnumParam(StrEnum):
    A = "a"
    B = "b"

# ==================

def mytask(str_param: str, enum_param: EnumParam):
    return Task(dataset=[Sample(input="foo")])
decorated_task = task(mytask)

print(f"my_task:          {get_type_hints(mytask)}")
print(f"decorated_task:   {get_type_hints(decorated_task)}")

# ==================

def my_solver(str_param: str, enum_param: EnumParam):
    async def solve(state: TaskState, generate: Generate) -> TaskState:
        return state

    return solve
decorated_solver = solver(my_solver)

print()
print(f"my_solver:         {get_type_hints(my_solver)}")
print(f"decorated_solver: {get_type_hints(decorated_solver)}")

Output:

my_task:           {'str_param': <class 'str'>, 'enum_param': <enum 'EnumParam'>}
decorated_task:    {'w_args': typing.Any, 'w_kwargs': typing.Any, 'return': <class 'inspect_ai._eval.task.task.Task'>}

my_solver:         {'str_param': <class 'str'>, 'enum_param': <enum 'EnumParam'>}
decorated_solver: {'args': typing.Any, 'kwargs': dict[str, typing.Any], 'return': <class 'inspect_ai.solver._solver.Solver'>}

What is the new behavior?

Type hints are preserved. Behaviour of the above test script afterwards.

my_task:          {'str_param': <class 'str'>, 'enum_param': <enum 'EnumParam'>, 'return': <class 'inspect_ai._eval.task.task.Task'>}
decorated_task:   {'str_param': <class 'str'>, 'enum_param': <enum 'EnumParam'>, 'return': <class 'inspect_ai._eval.task.task.Task'>}

mys_olver:         {'str_param': <class 'str'>, 'enum_param': <enum 'EnumParam'>, 'return': <class 'inspect_ai.solver._solver.Solver'>}
decorated_solver: {'str_param': <class 'str'>, 'enum_param': <enum 'EnumParam'>, 'return': <class 'inspect_ai.solver._solver.Solver'>}

Does this PR introduce a breaking change? (What changes might users need to make in their application due to this PR?)

No

Other information:

@rusheb rusheb marked this pull request as draft January 22, 2025 12:35
@rusheb
Copy link
Contributor Author

rusheb commented Jan 22, 2025

Edit: The issue described in this comment is resolved as of 52183a8

Details

@jjallaire i'm not sure the best way to proceed here. My requirement is to have a way to access the task type hints so that I can dynamically instantiate tasks from config files.

The solution I've tried here seems to break some of your tests. For some reason the "return" property of the annotations dict gets deleted when I apply @wraps.

I also tried explicitly adding the return type to the __annotations__ property. This seemed to fix some of the tests but not all of them:

diff --git a/src/inspect_ai/_eval/registry.py b/src/inspect_ai/_eval/registry.py
index 4fb00d65..13d9a13a 100644
--- a/src/inspect_ai/_eval/registry.py
+++ b/src/inspect_ai/_eval/registry.py
@@ -156,6 +156,11 @@ def task(*args: Any, name: str | None = None, **attribs: Any) -> Any:
             # Return the task instance
             return task_instance
 
+        wrapper.__annotations__ = {
+            **task_type.__annotations__,
+            'return': Task
+        }
+
         # Register the task and return the wrapper
         return task_register(
             task=cast(TaskType, wrapper), name=task_name, attribs=attribs, params=params
k=cast(TaskType, wrapper), name=task_name, attribs=attribs, params=params
    )

I think there's 2 things we could do:

  1. if you know how to fix the failing tests and it's not too difficult then we would keep @wraps as it's a standard practice in python to use this with decorators
  2. alternatively we could just assign the wrapped function to a new variable e.g. __wrapped__ on the wrapper function, e.g.
diff --git a/src/inspect_ai/_eval/registry.py b/src/inspect_ai/_eval/registry.py
index 4fb00d65..cf1ded4e 100644
--- a/src/inspect_ai/_eval/registry.py
+++ b/src/inspect_ai/_eval/registry.py
@@ -156,6 +156,8 @@ def task(*args: Any, name: str | None = None, **attribs: Any) -> Any:
             # Return the task instance
             return task_instance
 
+        wrapper.__wrapped__ = task_type
+
         # Register the task and return the wrapper
         return task_register(
             task=cast(TaskType, wrapper), name=task_name, attribs=attribs, params=params

Then from my code I could call get_type_hints(task_fn.__wrapped__) instead of get_type_hints(task_fn) to get the annotations.

Let me know what you think is best.

@rusheb rusheb marked this pull request as ready for review January 22, 2025 13:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant