Skip to content

Commit

Permalink
Remove async Io from test
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 10, 2025
1 parent 9402b6a commit 4f99968
Show file tree
Hide file tree
Showing 14 changed files with 361 additions and 525 deletions.
2 changes: 2 additions & 0 deletions src/spdl/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ._composite import (
load_audio,
load_image,
load_image_batch,
load_image_batch_nvjpeg,
load_video,
sample_decode_video,
Expand Down Expand Up @@ -80,6 +81,7 @@
"load_audio",
"load_video",
"load_image",
"load_image_batch",
"load_image_batch_nvjpeg",
"sample_decode_video",
# DEMUXING
Expand Down
68 changes: 16 additions & 52 deletions src/spdl/io/_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
CUDAConfig,
DecodeConfig,
DemuxConfig,
ImageFrames,
ImagePackets,
VideoPackets,
)
Expand All @@ -32,6 +31,7 @@
"load_audio",
"load_video",
"load_image",
"load_image_batch",
"load_image_batch_nvjpeg",
"sample_decode_video",
]
Expand Down Expand Up @@ -404,7 +404,7 @@ def _decode(src, demux_config, decode_config, filter_desc):


@overload
async def async_load_image_batch(
def load_image_batch(
srcs: list[str | bytes],
*,
width: int | None,
Expand All @@ -421,7 +421,7 @@ async def async_load_image_batch(


@overload
async def async_load_image_batch(
def load_image_batch(
srcs: list[str | bytes],
*,
width: int | None,
Expand All @@ -437,7 +437,7 @@ async def async_load_image_batch(
) -> CUDABuffer: ...


async def async_load_image_batch(
def load_image_batch(
srcs: list[str | bytes],
*,
width,
Expand All @@ -458,32 +458,6 @@ async def async_load_image_batch(
and optionally, :py:func:`~spdl.io.transfer_buffer`, to produce
buffer object from source in one step.
It concurrently demuxes and decodes the input images, using
the :py:class:`~concurrent.futures.ThreadPoolExecutor` attached to
the running async event loop, fetched by :py:func:`~asyncio.get_running_loop`.
.. mermaid::
gantt
title Illustration of asynchronous batch image decoding timeline
dateFormat X
axisFormat %s
section Thread 1
Demux image 1 :demux1, 0, 3
Decode/resize image 1 :after demux1, 20
section Thread 2
Demux image 2 :demux2, 1, 5
Decode/resize image 2 :after demux2, 23
section Thread 3
Demux image 3 :demux3, 2, 5
Decode/resize image 3 :after demux3, 24
section Thread 4
Demux image 4 :demux4, 3, 8
Decode/resize image 4 :decode4, after demux4, 25
section Thread 5
Batch conversion :batch, after decode4, 30
Device Transfer :after batch, 33
Args:
srcs: List of source identifiers.
Expand All @@ -496,25 +470,25 @@ async def async_load_image_batch(
demux_config:
*Optional:* Demux configuration passed to
:py:func:`~spdl.io.async_demux_image`.
:py:func:`~spdl.io.demux_image`.
decode_config:
*Optional:* Decode configuration passed to
:py:func:`~spdl.io.async_decode_packets`.
:py:func:`~spdl.io.decode_packets`.
filter_desc:
*Optional:* Filter description passed to
:py:func:`~spdl.io.async_decode_packets`.
:py:func:`~spdl.io.decode_packets`.
device_config:
*Optional:* The CUDA device passed to
:py:func:`~spdl.io.async_transfer_buffer`.
:py:func:`~spdl.io.transfer_buffer`.
Providing this argument will move the resulting buffer to
the CUDA device.
storage:
*Optional:* The storage object passed to
:py:func:`~spdl.io.async_convert_frames`.
:py:func:`~spdl.io.convert_frames`.
strict:
*Optional:* If True, raise an error if any of the images failed to load.
Expand All @@ -528,13 +502,12 @@ async def async_load_image_batch(
... "sample1.jpg",
... "sample2.png",
... ]
>>> coro = async_load_image_batch(
>>> buffer = load_image_batch(
... srcs,
... scale_width=124,
... scale_height=96,
... pix_fmt="rgb24",
... )
>>> buffer = asyncio.run(coro)
>>> array = spdl.io.to_numpy(buffer)
>>> # An array with shape HWC==[2, 96, 124, 3]
>>>
Expand Down Expand Up @@ -563,34 +536,25 @@ async def async_load_image_batch(
pix_fmt=pix_fmt,
)

tasks = [
asyncio.create_task(
run_async(_decode, src, demux_config, decode_config, filter_desc)
)
for src in srcs
]

await asyncio.wait(tasks)

frames: list[ImageFrames] = []
for src, future in zip(srcs, tasks):
frames = []
for src in srcs:
try:
frms = future.result()
frame = _decode(src, demux_config, decode_config, filter_desc)
except Exception as err:
_LG.error(_get_err_msg(src, err))
else:
frames.append(frms)
frames.append(frame)

if strict and len(frames) != len(srcs):
raise RuntimeError("Failed to load some images.")

if not frames:
raise RuntimeError("Failed to load all the images.")

buffer = await _core.async_convert_frames(frames, storage=storage)
buffer = _core.convert_frames(frames, storage=storage)

if device_config is not None:
buffer = await _core.async_transfer_buffer(buffer, device_config=device_config)
buffer = _core.transfer_buffer(buffer, device_config=device_config)

return buffer

Expand Down
Loading

0 comments on commit 4f99968

Please sign in to comment.