Skip to content

Commit

Permalink
Manually import functions/classes (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Jan 9, 2025
1 parent 167c4fc commit 9402b6a
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 29 deletions.
151 changes: 126 additions & 25 deletions src/spdl/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,130 @@

# This has to happen before other sub modules are imporeted.
# Otherwise circular import would occur.
#
# I know, I should not use `*`. I don't want to either, but
# for creating annotation for types from C++ code, which might not be
# available at the runtime, while simultaneously pleasing all the linters
# (black, flake8 and pyre) and documentation tools, this seems like
# the simplest solution.
# This import is just for annotation, so please overlook this one.
from ._type_stub import * # noqa: F403 # isort: skip

from . import _composite, _config, _convert, _core, _preprocessing, _type_stub, _zip

_mods = [
_composite,
_config,
_convert,
_core,
_preprocessing,
_zip,
from ._type_stub import ( # isort: skip
CPUBuffer,
CUDABuffer,
Packets,
AudioPackets,
VideoPackets,
ImagePackets,
Frames,
AudioFrames,
VideoFrames,
ImageFrames,
DemuxConfig,
DecodeConfig,
EncodeConfig,
CUDAConfig,
CPUStorage,
)

from ._composite import (
load_audio,
load_image,
load_image_batch_nvjpeg,
load_video,
sample_decode_video,
)
from ._config import (
cpu_storage,
cuda_config,
decode_config,
demux_config,
encode_config,
)
from ._convert import (
to_jax,
to_numba,
to_numpy,
to_torch,
)
from ._core import (
convert_array,
convert_frames,
decode_image_nvjpeg,
decode_packets,
decode_packets_nvdec,
demux_audio,
demux_image,
demux_video,
Demuxer,
encode_image,
streaming_decode_packets,
transfer_buffer,
transfer_buffer_cpu,
)
from ._preprocessing import (
get_audio_filter_desc,
get_filter_desc,
get_video_filter_desc,
)
from ._zip import (
load_npz,
NpzFile,
)

__all__ = [
# HIGH LEVEL API
"load_audio",
"load_video",
"load_image",
"load_image_batch_nvjpeg",
"sample_decode_video",
# DEMUXING
"Demuxer",
"demux_audio",
"demux_video",
"demux_image",
"Packets",
"AudioPackets",
"VideoPackets",
"ImagePackets",
# DECODING
"decode_packets",
"decode_packets_nvdec",
"streaming_decode_packets",
"decode_image_nvjpeg",
"Frames",
"AudioFrames",
"VideoFrames",
"ImageFrames",
# PREPROCESSING
"get_audio_filter_desc",
"get_video_filter_desc",
"get_filter_desc",
# FRAME CONVERSION
"convert_array",
"convert_frames",
"CPUBuffer",
"CUDABuffer",
# DATA TRANSFER
"transfer_buffer",
"transfer_buffer_cpu",
# CAST
"to_numba",
"to_numpy",
"to_torch",
"to_jax",
# ENCODING
"encode_image",
# CONFIG
"demux_config",
"DemuxConfig",
"decode_config",
"DecodeConfig",
"encode_config",
"EncodeConfig",
"cuda_config",
"CUDAConfig",
"cpu_storage",
"CPUStorage",
# NUMPY
"NpzFile",
"load_npz",
]


__all__ = sorted(item for mod in [*_mods, _type_stub] for item in mod.__all__)


def __dir__():
return __all__

Expand Down Expand Up @@ -77,11 +177,12 @@ def __getattr__(name: str) -> Any:
)

if name in _deprecated_core:
from . import _core

return getattr(_core, name)
return getattr(_composite, name)

for mod in _mods:
if name in mod.__all__:
return getattr(mod, name)
from . import _composite

return getattr(_composite, name)

raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
16 changes: 12 additions & 4 deletions src/spdl/io/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,16 @@
UintArray = NDArray[np.uint8]
except ImportError:
UintArray = np.ndarray

try:
import torch

Tensor = torch.Tensor
except ImportError:
Tensor = object
else:
UintArray = object
Tensor = object


from spdl.io import (
Expand Down Expand Up @@ -123,7 +131,7 @@ class Demuxer:
demux_config (DemuxConfig): Custom I/O config.
"""

def __init__(self, src: str | Path | bytes | UintArray, **kwargs):
def __init__(self, src: str | Path | bytes | UintArray | Tensor, **kwargs):
if isinstance(src, Path):
src = str(src)
self._demuxer = _libspdl._demuxer(src, **kwargs)
Expand Down Expand Up @@ -185,7 +193,7 @@ async def __aexit__(self, exc_type, exc_value, exc_traceback) -> None:


def demux_audio(
src: str | bytes | UintArray,
src: str | bytes | UintArray | Tensor,
*,
timestamp: tuple[float, float] | None = None,
**kwargs,
Expand Down Expand Up @@ -215,7 +223,7 @@ async def async_demux_audio(


def demux_video(
src: str | bytes | UintArray,
src: str | bytes | UintArray | Tensor,
*,
timestamp: tuple[float, float] | None = None,
**kwargs,
Expand Down Expand Up @@ -245,7 +253,7 @@ async def async_demux_video(
return await run_async(demux_video, src, timestamp=timestamp, **kwargs)


def demux_image(src: str | bytes | UintArray, **kwargs) -> ImagePackets:
def demux_image(src: str | bytes | UintArray | Tensor, **kwargs) -> ImagePackets:
"""Demux image from the source.
Args:
Expand Down

0 comments on commit 9402b6a

Please sign in to comment.