Skip to content

Commit

Permalink
add async generator return value
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-dixon committed Oct 13, 2024
1 parent dcd58c5 commit fcabdea
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 24 deletions.
8 changes: 3 additions & 5 deletions Doc/library/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -485,18 +485,16 @@ or :class:`Iterator[YieldType] <collections.abc.Iterator>`::
yield start
start += 1

Async generators are handled in a similar fashion, but don't
expect a ``ReturnType`` type argument
(:class:`AsyncGenerator[YieldType, SendType] <collections.abc.AsyncGenerator>`).
The ``SendType`` argument defaults to :const:`!None`, so the following definitions
Async generators are handled in a similar fashion.
The ``SendType`` and ```ReturnType``` parameters default to :const:`!None`, so the following definitions
are equivalent::

async def infinite_stream(start: int) -> AsyncGenerator[int]:
while True:
yield start
start = await increment(start)

async def infinite_stream(start: int) -> AsyncGenerator[int, None]:
async def infinite_stream(start: int) -> AsyncGenerator[int, None, None]:
while True:
yield start
start = await increment(start)
Expand Down
5 changes: 3 additions & 2 deletions Doc/reference/simple_stmts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,9 @@ becomes the :attr:`StopIteration.value` attribute.

In an asynchronous generator function, an empty :keyword:`return` statement

Check warning on line 515 in Doc/reference/simple_stmts.rst

View workflow job for this annotation

GitHub Actions / Docs / Docs

py:attr reference target not found: StopAsyncIteration.value [ref.attr]
indicates that the asynchronous generator is done and will cause
:exc:`StopAsyncIteration` to be raised. A non-empty :keyword:`!return`
statement is a syntax error in an asynchronous generator function.
:exc:`StopAsyncIteration` to be raised. The returned
value (if any) is used as an argument to construct :exc:`StopAsyncIteration` and
becomes the :attr:`StopAsyncIteration.value` attribute.

.. _yield:

Expand Down
6 changes: 6 additions & 0 deletions Include/cpython/pyerrors.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ typedef struct {
PyObject *value;
} PyStopIterationObject;

typedef struct {
PyException_HEAD
PyObject *value;
} PyStopAsyncIterationObject;


typedef struct {
PyException_HEAD
PyObject *name;
Expand Down
40 changes: 34 additions & 6 deletions Lib/test/test_asyncgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,15 @@ def test_async_gen_syntax_03(self):
return 123
'''

with self.assertRaisesRegex(SyntaxError, 'return.*value.*async gen'):
exec(code, {}, {})
exec(code, {}, {})

def test_async_gen_syntax_04(self):
code = '''async def foo():
yield
return 123
'''

with self.assertRaisesRegex(SyntaxError, 'return.*value.*async gen'):
exec(code, {}, {})
exec(code, {}, {})

def test_async_gen_syntax_05(self):
code = '''async def foo():
Expand All @@ -127,8 +125,7 @@ def test_async_gen_syntax_05(self):
return 12
'''

with self.assertRaisesRegex(SyntaxError, 'return.*value.*async gen'):
exec(code, {}, {})
exec(code, {}, {})


class AsyncGenTest(unittest.TestCase):
Expand Down Expand Up @@ -2056,5 +2053,36 @@ async def agenfn():
del gen2
gc_collect() # does not warn unawaited

class TestAsyncGenReturn(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = None
asyncio.set_event_loop_policy(None)

def test_async_gen_return_value(self):
async def gen():
yield 1
yield 2
return 3

async def run():
g = gen()
res = []
while True:
try:
res.append(await g.__anext__())
except StopAsyncIteration as e:
res.append(e.value)
break

return res

res = self.loop.run_until_complete(run())
self.assertEqual(res, [1, 2, 3])

if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10220,7 +10220,7 @@ def test_special_attrs(self):
# Subscribed ABC classes
typing.AbstractSet[Any]: 'AbstractSet',
typing.AsyncContextManager[Any, Any]: 'AsyncContextManager',
typing.AsyncGenerator[Any, Any]: 'AsyncGenerator',
typing.AsyncGenerator[Any, Any, Any]: 'AsyncGenerator',
typing.AsyncIterable[Any]: 'AsyncIterable',
typing.AsyncIterator[Any]: 'AsyncIterator',
typing.Awaitable[Any]: 'Awaitable',
Expand Down
2 changes: 1 addition & 1 deletion Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2819,7 +2819,7 @@ class Other(Leaf): # Error reported by type checker
Counter = _alias(collections.Counter, 1)
ChainMap = _alias(collections.ChainMap, 2)
Generator = _alias(collections.abc.Generator, 3, defaults=(types.NoneType, types.NoneType))
AsyncGenerator = _alias(collections.abc.AsyncGenerator, 2, defaults=(types.NoneType,))
AsyncGenerator = _alias(collections.abc.AsyncGenerator, 3, defaults=(types.NoneType, types.NoneType))
Type = _alias(type, 1, inst=False, name='Type')
Type.__doc__ = \
"""Deprecated alias to builtins.type.
Expand Down
2 changes: 2 additions & 0 deletions Misc/ACKS
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ Daniel Diniz
Humberto Diogenes
Yves Dionne
Daniel Dittmar
Alex Dixon
Josip Djolonga
Walter Dörwald
Jaromir Dolecek
Expand Down Expand Up @@ -689,6 +690,7 @@ Yuyang Guo
Anuj Gupta
Om Gupta
Michael Guravage
William Guss
Lars Gustäbel
Thomas Güttler
Jonas H.
Expand Down
50 changes: 48 additions & 2 deletions Objects/exceptions.c
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,54 @@ SimpleExtendsException(PyExc_Exception, TypeError,
/*
* StopAsyncIteration extends Exception
*/
SimpleExtendsException(PyExc_Exception, StopAsyncIteration,
"Signal the end from iterator.__anext__().");
static PyMemberDef StopAsyncIteration_members[] = {
{"value", _Py_T_OBJECT, offsetof(PyStopAsyncIterationObject, value), 0,
PyDoc_STR("async generator return value")},
{NULL} /* Sentinel */
};

static int
StopAsyncIteration_init(PyStopAsyncIterationObject *self, PyObject *args, PyObject *kwds)
{
Py_ssize_t size = PyTuple_GET_SIZE(args);
PyObject *value;

if (BaseException_init((PyBaseExceptionObject *)self, args, kwds) == -1)
return -1;
Py_CLEAR(self->value);
if (size > 0)
value = PyTuple_GET_ITEM(args, 0);
else
value = Py_None;
self->value = Py_NewRef(value);
return 0;
}

static int
StopAsyncIteration_clear(PyStopAsyncIterationObject *self)
{
Py_CLEAR(self->value);
return BaseException_clear((PyBaseExceptionObject *)self);
}

static void
StopAsyncIteration_dealloc(PyStopAsyncIterationObject *self)
{
PyObject_GC_UnTrack(self);
StopAsyncIteration_clear(self);
Py_TYPE(self)->tp_free((PyObject *)self);
}

static int
StopAsyncIteration_traverse(PyStopAsyncIterationObject *self, visitproc visit, void *arg)
{
Py_VISIT(self->value);
return BaseException_traverse((PyBaseExceptionObject *)self, visit, arg);
}

ComplexExtendsException(PyExc_Exception, StopAsyncIteration, StopAsyncIteration,
0, 0, StopAsyncIteration_members, 0, 0,
"Signal the end from iterator.__anext__().");


/*
Expand Down
75 changes: 71 additions & 4 deletions Objects/genobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,70 @@ gen_dealloc(PyObject *self)
PyObject_GC_Del(gen);
}


/*
* Set StopAsyncIteration with specified value. Value can be arbitrary object
* or NULL.
*
* Returns 0 if StopAsyncIteration is set and -1 if any other exception is set.
*/
int
_PyGen_SetStopAsyncIterationValue(PyObject *value)
{
PyObject *e;

if (value == NULL ||
(!PyTuple_Check(value) && !PyExceptionInstance_Check(value)))
{
/* Delay exception instantiation if we can */
PyErr_SetObject(PyExc_StopAsyncIteration, value);
return 0;
}
/* Construct an exception instance manually with
* PyObject_CallOneArg and pass it to PyErr_SetObject.
*
* We do this to handle a situation when "value" is a tuple, in which
* case PyErr_SetObject would set the value of StopIteration to
* the first element of the tuple.
*
* (See PyErr_SetObject/_PyErr_CreateException code for details.)
*/
e = PyObject_CallOneArg(PyExc_StopAsyncIteration, value);
if (e == NULL) {
return -1;
}
PyErr_SetObject(PyExc_StopAsyncIteration, e);
Py_DECREF(e);
return 0;
}

/*
* If StopAsyncIteration exception is set, fetches its 'value'
* attribute if any, otherwise sets pvalue to None.
*
* Returns 0 if no exception or StopAsyncIteration is set.
* If any other exception is set, returns -1 and leaves
* pvalue unchanged.
*/

int
_PyGen_FetchStopAsyncIterationValue(PyObject **pvalue)
{
PyObject *value = NULL;
if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration)) {
PyObject *exc = PyErr_GetRaisedException();
value = Py_NewRef(((PyStopAsyncIterationObject *)exc)->value);
Py_DECREF(exc);
} else if (PyErr_Occurred()) {
return -1;
}
if (value == NULL) {
value = Py_NewRef(Py_None);
}
*pvalue = value;
return 0;
}

static PySendResult
gen_send_ex2(PyGenObject *gen, PyObject *arg, PyObject **presult,
int exc, int closing)
Expand Down Expand Up @@ -255,8 +319,7 @@ gen_send_ex2(PyGenObject *gen, PyObject *arg, PyObject **presult,
*presult = result;
return PYGEN_NEXT;
}
assert(result == Py_None || !PyAsyncGen_CheckExact(gen));
if (result == Py_None && !PyAsyncGen_CheckExact(gen) && !arg) {
if (result == Py_None && !arg) {
/* Return NULL if called by gen_iternext() */
Py_CLEAR(result);
}
Expand Down Expand Up @@ -286,8 +349,12 @@ gen_send_ex(PyGenObject *gen, PyObject *arg, int exc, int closing)
PyObject *result;
if (gen_send_ex2(gen, arg, &result, exc, closing) == PYGEN_RETURN) {
if (PyAsyncGen_CheckExact(gen)) {
assert(result == Py_None);
PyErr_SetNone(PyExc_StopAsyncIteration);
if (result == Py_None) {
PyErr_SetNone(PyExc_StopAsyncIteration);
}
else {
_PyGen_SetStopAsyncIterationValue(result);
}
}
else if (result == Py_None) {
PyErr_SetNone(PyExc_StopIteration);
Expand Down
3 changes: 0 additions & 3 deletions Python/codegen.c
Original file line number Diff line number Diff line change
Expand Up @@ -2060,9 +2060,6 @@ codegen_return(compiler *c, stmt_ty s)
if (!_PyST_IsFunctionLike(ste)) {
return _PyCompile_Error(c, loc, "'return' outside function");
}
if (s->v.Return.value != NULL && ste->ste_coroutine && ste->ste_generator) {
return _PyCompile_Error(c, loc, "'return' with value in async generator");
}

if (preserve_tos) {
VISIT(c, expr, s->v.Return.value);
Expand Down

0 comments on commit fcabdea

Please sign in to comment.