From 26fa4f7b1f3e57999ecf8acb59dbcdc8f5e216a9 Mon Sep 17 00:00:00 2001 From: Dominik Schwabe Date: Sun, 18 Aug 2024 21:45:04 +0200 Subject: [PATCH] feat(live): add support for using asyncio.Task as an alternative to threading.Thread to handle live updates --- CHANGELOG.md | 1 + CONTRIBUTORS.md | 1 + rich/live.py | 82 +++++++++++++++++++++++++++++++++++++++++----- tests/test_live.py | 24 ++++++++++++-- 4 files changed, 98 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e1a042249..727a2987f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Adds support for using `asyncio.Task` as an alternative to `threading.Thread` to handle live updates - Adds a `case_sensitive` parameter to `prompt.Prompt`. This determines if the response is treated as case-sensitive. Defaults to `True`. diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index edacc5885..72e1d4f19 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -64,6 +64,7 @@ The following people have contributed to the development of Rich: - [Luca Salvarani](https://github.com/LukeSavefrogs) - [Paul Sanders](https://github.com/sanders41) - [Tim Savage](https://github.com/timsavage) +- [Dominik Schwabe](https://github.com/dominik-schwabe) - [Anthony Shaw](https://github.com/tonybaloney) - [Nicolas Simonds](https://github.com/0xDEC0DE) - [Aaron Stephens](https://github.com/aaronst) diff --git a/rich/live.py b/rich/live.py index f0529a781..9a350f948 100644 --- a/rich/live.py +++ b/rich/live.py @@ -1,7 +1,10 @@ +import asyncio import sys from threading import Event, RLock, Thread from types import TracebackType -from typing import IO, Any, Callable, List, Optional, TextIO, Type, cast +from typing import IO, Any, Callable, List, Literal, Optional, TextIO, Type, cast + +from typing_extensions import Protocol from . import get_console from .console import Console, ConsoleRenderable, RenderableType, RenderHook @@ -13,7 +16,20 @@ from .text import Text -class _RefreshThread(Thread): +class _Refresher(Protocol): + "Declares functionality that a refresher needs to implement" + + def __init__(self, live: "Live", refresh_per_second: float) -> None: + ... + + def stop(self) -> None: + ... + + def start(self) -> None: + ... + + +class _ThreadRefresher(Thread, _Refresher): """A thread that calls refresh() at regular intervals.""" def __init__(self, live: "Live", refresh_per_second: float) -> None: @@ -32,6 +48,45 @@ def run(self) -> None: self.live.refresh() +class _AsyncioTaskRefresher(_Refresher): + """A wrapper around asyncio.Task that calls refresh() at regular intervals.""" + + def __init__(self, live: "Live", refresh_per_second: float) -> None: + self.live = live + self.refresh_per_second = refresh_per_second + self.task: Optional[asyncio.Task[Any]] = None + + def stop(self) -> None: + if self.task is not None: + self.task.cancel() + self.task = None + + def start(self) -> None: + if self.task is None: + self.task = asyncio.create_task(self.run()) + + async def run(self) -> None: + while True: + await asyncio.sleep(1 / self.refresh_per_second) + self.live.refresh() + + +def _resolve_refresher( + refresh_method: Literal["thread", "asyncio_task", "auto"], +) -> type[_Refresher]: + if refresh_method == "thread": + return _ThreadRefresher + elif refresh_method == "asyncio_task": + return _AsyncioTaskRefresher + else: + try: + asyncio.get_running_loop() + except RuntimeError: + return _ThreadRefresher + else: + return _AsyncioTaskRefresher + + class Live(JupyterMixin, RenderHook): """Renders an auto-updating live display of any given renderable. @@ -46,6 +101,10 @@ class Live(JupyterMixin, RenderHook): redirect_stderr (bool, optional): Enable redirection of stderr. Defaults to True. vertical_overflow (VerticalOverflowMethod, optional): How to handle renderable when it is too tall for the console. Defaults to "ellipsis". get_renderable (Callable[[], RenderableType], optional): Optional callable to get renderable. Defaults to None. + refresh_method (str): The method that be used to handle the live updates, either ``"thread"`` + to use a ``threading.Thread`` or ``"asyncio_task"`` to use an ``asnycio.Task``. Defaults to + ``"auto"``, which uses an ``asnycio.Task`` if there exists a running event loop and a + ``threading.Thread`` otherwise. """ def __init__( @@ -61,6 +120,7 @@ def __init__( redirect_stderr: bool = True, vertical_overflow: VerticalOverflowMethod = "ellipsis", get_renderable: Optional[Callable[[], RenderableType]] = None, + refresh_method: Literal["thread", "asyncio_task", "auto"] = "auto", ) -> None: assert refresh_per_second > 0, "refresh_per_second must be > 0" self._renderable = renderable @@ -79,7 +139,8 @@ def __init__( self._started: bool = False self.transient = True if screen else transient - self._refresh_thread: Optional[_RefreshThread] = None + self._refresh_method = refresh_method + self._refresher: Optional[_Refresher] = None self.refresh_per_second = refresh_per_second self.vertical_overflow = vertical_overflow @@ -106,6 +167,10 @@ def start(self, refresh: bool = False) -> None: Args: refresh (bool, optional): Also refresh. Defaults to False. + refresh_method (str): The method that be used to handle the live updates, either ``"thread"`` + to use a ``threading.Thread`` or ``"asyncio_task"`` to use an ``asnycio.Task``. Defaults to + ``"auto"``, which uses an ``asnycio.Task`` if there exists a running event loop and a + ``threading.Thread`` otherwise. """ with self._lock: if self._started: @@ -128,8 +193,9 @@ def start(self, refresh: bool = False) -> None: self.stop() raise if self.auto_refresh: - self._refresh_thread = _RefreshThread(self, self.refresh_per_second) - self._refresh_thread.start() + refresher_class = _resolve_refresher(self._refresh_method) + self._refresher = refresher_class(self, self.refresh_per_second) + self._refresher.start() def stop(self) -> None: """Stop live rendering display.""" @@ -139,9 +205,9 @@ def stop(self) -> None: self.console.clear_live() self._started = False - if self.auto_refresh and self._refresh_thread is not None: - self._refresh_thread.stop() - self._refresh_thread = None + if self.auto_refresh and self._refresher is not None: + self._refresher.stop() + self._refresher = None # allow it to fully render on the last even if overflow self.vertical_overflow = "visible" with self.console: diff --git a/tests/test_live.py b/tests/test_live.py index f037e4b8b..98935de38 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -1,13 +1,33 @@ # encoding=utf-8 +import asyncio import time -from typing import Optional +from typing import Literal, Optional # import pytest from rich.console import Console -from rich.live import Live +from rich.live import Live, _AsyncioTaskRefresher, _ThreadRefresher from rich.text import Text +def test_refresher() -> None: + def get_refresher(refresh_method: Literal["thread", "asyncio_task", "auto"]): + with Live("", refresh_method=refresh_method) as live: + return live._refresher + + async def async_get_refresher( + refresh_method: Literal["thread", "asyncio_task", "auto"], + ): + return get_refresher(refresh_method) + + assert isinstance(get_refresher("thread"), _ThreadRefresher) + assert isinstance(asyncio.run(async_get_refresher("thread")), _ThreadRefresher) + assert isinstance( + asyncio.run(async_get_refresher("asyncio_task")), _AsyncioTaskRefresher + ) + assert isinstance(get_refresher("auto"), _ThreadRefresher) + assert isinstance(asyncio.run(async_get_refresher("auto")), _AsyncioTaskRefresher) + + def create_capture_console( *, width: int = 60, height: int = 80, force_terminal: Optional[bool] = True ) -> Console: