Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(live): add support for using asyncio.Task as an alternative to threading.Thread to handle live updates #3457

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Adds support for using `asyncio.Task` as an alternative to `threading.Thread` to handle live updates
- Adds a `case_sensitive` parameter to `prompt.Prompt`. This determines if the
response is treated as case-sensitive. Defaults to `True`.

Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ The following people have contributed to the development of Rich:
- [Luca Salvarani](https://github.com/LukeSavefrogs)
- [Paul Sanders](https://github.com/sanders41)
- [Tim Savage](https://github.com/timsavage)
- [Dominik Schwabe](https://github.com/dominik-schwabe)
- [Anthony Shaw](https://github.com/tonybaloney)
- [Nicolas Simonds](https://github.com/0xDEC0DE)
- [Aaron Stephens](https://github.com/aaronst)
Expand Down
82 changes: 74 additions & 8 deletions rich/live.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
import sys
from threading import Event, RLock, Thread
from types import TracebackType
from typing import IO, Any, Callable, List, Optional, TextIO, Type, cast
from typing import IO, Any, Callable, List, Literal, Optional, TextIO, Type, cast

from typing_extensions import Protocol

from . import get_console
from .console import Console, ConsoleRenderable, RenderableType, RenderHook
Expand All @@ -13,7 +16,20 @@
from .text import Text


class _RefreshThread(Thread):
class _Refresher(Protocol):
"Declares functionality that a refresher needs to implement"

def __init__(self, live: "Live", refresh_per_second: float) -> None:
...

def stop(self) -> None:
...

def start(self) -> None:
...


class _ThreadRefresher(Thread, _Refresher):
"""A thread that calls refresh() at regular intervals."""

def __init__(self, live: "Live", refresh_per_second: float) -> None:
Expand All @@ -32,6 +48,45 @@ def run(self) -> None:
self.live.refresh()


class _AsyncioTaskRefresher(_Refresher):
"""A wrapper around asyncio.Task that calls refresh() at regular intervals."""

def __init__(self, live: "Live", refresh_per_second: float) -> None:
self.live = live
self.refresh_per_second = refresh_per_second
self.task: Optional[asyncio.Task[Any]] = None

def stop(self) -> None:
if self.task is not None:
self.task.cancel()
self.task = None

def start(self) -> None:
if self.task is None:
self.task = asyncio.create_task(self.run())

async def run(self) -> None:
while True:
await asyncio.sleep(1 / self.refresh_per_second)
self.live.refresh()


def _resolve_refresher(
refresh_method: Literal["thread", "asyncio_task", "auto"],
) -> type[_Refresher]:
if refresh_method == "thread":
return _ThreadRefresher
elif refresh_method == "asyncio_task":
return _AsyncioTaskRefresher
else:
try:
asyncio.get_running_loop()
except RuntimeError:
return _ThreadRefresher
else:
return _AsyncioTaskRefresher


class Live(JupyterMixin, RenderHook):
"""Renders an auto-updating live display of any given renderable.

Expand All @@ -46,6 +101,10 @@ class Live(JupyterMixin, RenderHook):
redirect_stderr (bool, optional): Enable redirection of stderr. Defaults to True.
vertical_overflow (VerticalOverflowMethod, optional): How to handle renderable when it is too tall for the console. Defaults to "ellipsis".
get_renderable (Callable[[], RenderableType], optional): Optional callable to get renderable. Defaults to None.
refresh_method (str): The method that be used to handle the live updates, either ``"thread"``
to use a ``threading.Thread`` or ``"asyncio_task"`` to use an ``asnycio.Task``. Defaults to
``"auto"``, which uses an ``asnycio.Task`` if there exists a running event loop and a
``threading.Thread`` otherwise.
"""

def __init__(
Expand All @@ -61,6 +120,7 @@ def __init__(
redirect_stderr: bool = True,
vertical_overflow: VerticalOverflowMethod = "ellipsis",
get_renderable: Optional[Callable[[], RenderableType]] = None,
refresh_method: Literal["thread", "asyncio_task", "auto"] = "auto",
) -> None:
assert refresh_per_second > 0, "refresh_per_second must be > 0"
self._renderable = renderable
Expand All @@ -79,7 +139,8 @@ def __init__(
self._started: bool = False
self.transient = True if screen else transient

self._refresh_thread: Optional[_RefreshThread] = None
self._refresh_method = refresh_method
self._refresher: Optional[_Refresher] = None
self.refresh_per_second = refresh_per_second

self.vertical_overflow = vertical_overflow
Expand All @@ -106,6 +167,10 @@ def start(self, refresh: bool = False) -> None:

Args:
refresh (bool, optional): Also refresh. Defaults to False.
refresh_method (str): The method that be used to handle the live updates, either ``"thread"``
to use a ``threading.Thread`` or ``"asyncio_task"`` to use an ``asnycio.Task``. Defaults to
``"auto"``, which uses an ``asnycio.Task`` if there exists a running event loop and a
``threading.Thread`` otherwise.
"""
with self._lock:
if self._started:
Expand All @@ -128,8 +193,9 @@ def start(self, refresh: bool = False) -> None:
self.stop()
raise
if self.auto_refresh:
self._refresh_thread = _RefreshThread(self, self.refresh_per_second)
self._refresh_thread.start()
refresher_class = _resolve_refresher(self._refresh_method)
self._refresher = refresher_class(self, self.refresh_per_second)
self._refresher.start()

def stop(self) -> None:
"""Stop live rendering display."""
Expand All @@ -139,9 +205,9 @@ def stop(self) -> None:
self.console.clear_live()
self._started = False

if self.auto_refresh and self._refresh_thread is not None:
self._refresh_thread.stop()
self._refresh_thread = None
if self.auto_refresh and self._refresher is not None:
self._refresher.stop()
self._refresher = None
# allow it to fully render on the last even if overflow
self.vertical_overflow = "visible"
with self.console:
Expand Down
24 changes: 22 additions & 2 deletions tests/test_live.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
# encoding=utf-8
import asyncio
import time
from typing import Optional
from typing import Literal, Optional

# import pytest
from rich.console import Console
from rich.live import Live
from rich.live import Live, _AsyncioTaskRefresher, _ThreadRefresher
from rich.text import Text


def test_refresher() -> None:
def get_refresher(refresh_method: Literal["thread", "asyncio_task", "auto"]):
with Live("", refresh_method=refresh_method) as live:
return live._refresher

async def async_get_refresher(
refresh_method: Literal["thread", "asyncio_task", "auto"],
):
return get_refresher(refresh_method)

assert isinstance(get_refresher("thread"), _ThreadRefresher)
assert isinstance(asyncio.run(async_get_refresher("thread")), _ThreadRefresher)
assert isinstance(
asyncio.run(async_get_refresher("asyncio_task")), _AsyncioTaskRefresher
)
assert isinstance(get_refresher("auto"), _ThreadRefresher)
assert isinstance(asyncio.run(async_get_refresher("auto")), _AsyncioTaskRefresher)


def create_capture_console(
*, width: int = 60, height: int = 80, force_terminal: Optional[bool] = True
) -> Console:
Expand Down