From cc16746e5681585e8bbcdb92bf9a8319867c33d0 Mon Sep 17 00:00:00 2001 From: Pavel Tomskikh Date: Wed, 28 Feb 2024 15:07:34 +0600 Subject: [PATCH] Update savant-rs to 0.2.13 (#673) * Switch to Savant-Rs 0.2 (#622) * #612 update savant-rs to 0.2.1 * #612 move message deserialization to zeromq_src * #612 fix always-on-sink * #612 move savant_rs serialization to zeromq_sink * #612 use savant_rs.zmq * #612 update savant-rs to 0.2.5 * #612 fix adapters * #612 fix receiving EOS from ZeroMQ * #612 fix adapters * #612 fix video files sink adapter * #612 update savant-rs to 0.2.9 * #612 don't save non-persistent attributes to VideoFrame * #612 fix zeromq_sink gst plugin * #612 fix adapters * #612 fix module * #612 fix configuring zmq * #612 measure performance on A4000 * #612 update savant-rs to 0.2.12 * #612 fix copying bboxes * #612 update savant-rs to 0.2.13 * #612 fix benchmark for panoptic_driving_perception * #612 measure performance * #612 fix typo * Update version to 0.2.10 --- adapters/ds/sinks/always_on_rtsp/__main__.py | 10 +- adapters/ds/sinks/always_on_rtsp/config.py | 20 +- .../ds/sinks/always_on_rtsp/input_pipeline.py | 41 +- .../ds/sinks/always_on_rtsp/zeromq_proxy.py | 26 +- .../python/savant_rs_serializer.py | 505 --------------- .../gst_plugins/python/video_files_sink.py | 355 ----------- .../gst/gst_plugins/python/zeromq_sink.py | 589 ++++++++++++++---- adapters/gst/sinks/video_files.py | 412 ++++++++++++ adapters/gst/sinks/video_files.sh | 41 +- adapters/gst/sources/ffmpeg.sh | 12 +- adapters/gst/sources/gige_cam.sh | 2 +- adapters/gst/sources/media_files.sh | 18 +- adapters/gst/sources/multi_stream.sh | 34 +- adapters/gst/sources/rtsp.sh | 12 +- adapters/gst/sources/usb_cam.sh | 2 +- adapters/gst/sources/video_loop.sh | 5 +- adapters/python/bridge/buffer.py | 92 ++- adapters/python/sinks/chunk_writer.py | 26 +- adapters/python/sinks/image_files.py | 56 +- adapters/python/sinks/kafka_redis.py | 6 +- adapters/python/sinks/metadata_json.py | 43 +- docker/Dockerfile.adapters-gstreamer | 1 - docs/performance.md | 19 + docs/source/savant_101/10_adapters.rst | 2 +- .../python/savant_rs_video_decode_bin.py | 16 +- gst_plugins/python/savant_rs_video_demux.py | 532 ++++------------ gst_plugins/python/zeromq_source_bin.py | 93 ++- gst_plugins/python/zeromq_src.py | 465 +++++++++++++- libs/gstsavantframemeta/CMakeLists.txt | 17 +- .../gstsavantframemeta/CMakeLists.txt | 19 +- .../gstsavantframemeta/src/savantrsprobes.cpp | 4 + requirements/savant-rs.txt | 2 +- samples/animegan/docker-compose.x86.yml | 2 +- .../panoptic_driving_perception/run_perf.sh | 4 +- savant/VERSION | 2 +- savant/api/builder.py | 31 +- savant/api/parser.py | 6 +- savant/client/image_source/image_source.py | 11 +- savant/client/runner/sink.py | 26 +- savant/client/runner/source.py | 154 ++--- savant/deepstream/buffer_processor.py | 6 +- savant/deepstream/meta/frame.py | 12 +- savant/deepstream/metadata.py | 19 +- savant/deepstream/pipeline.py | 1 + savant/deepstream/utils/pipeline.py | 1 + savant/gstreamer/event.py | 35 ++ savant/parameter_storage/__init__.py | 2 +- savant/utils/artist/artist_gpumat.py | 6 +- savant/utils/re_patterns.py | 1 - savant/utils/sink_factories.py | 92 ++- savant/utils/zeromq.py | 494 ++++----------- scripts/run_sink.py | 2 +- 52 files changed, 2085 insertions(+), 2299 deletions(-) delete mode 100644 adapters/gst/gst_plugins/python/savant_rs_serializer.py delete mode 100644 adapters/gst/gst_plugins/python/video_files_sink.py create mode 100755 adapters/gst/sinks/video_files.py create mode 100644 savant/gstreamer/event.py diff --git a/adapters/ds/sinks/always_on_rtsp/__main__.py b/adapters/ds/sinks/always_on_rtsp/__main__.py index 7213bd833..1eb24dc16 100644 --- a/adapters/ds/sinks/always_on_rtsp/__main__.py +++ b/adapters/ds/sinks/always_on_rtsp/__main__.py @@ -82,18 +82,18 @@ def main(): if not config.source_id: internal_socket = 'ipc:///tmp/ao-sink-internal-socket.ipc' - internal_zmq_endpoint = f'sub+connect:{internal_socket}' + zmq_reader_endpoint = f'sub+connect:{internal_socket}' zmq_proxy = ZeroMqProxy( input_socket=config.zmq_endpoint, input_socket_type=config.zmq_socket_type, input_bind=config.zmq_socket_bind, - output_socket=internal_socket, + output_socket=f'pub+bind:{internal_socket}', ) zmq_proxy.start() zmq_proxy_thread = Thread(target=zmq_proxy.run, daemon=True) zmq_proxy_thread.start() else: - internal_zmq_endpoint = config.zmq_endpoint + zmq_reader_endpoint = config.zmq_endpoint if config.dev_mode: mediamtx_process = Popen( @@ -114,7 +114,7 @@ def main(): config.source_id: run_ao_sink_process( config.source_id, config.rtsp_uri, - internal_zmq_endpoint, + zmq_reader_endpoint, ) } else: @@ -122,7 +122,7 @@ def main(): source_id: run_ao_sink_process( source_id, f'{config.rtsp_uri.rstrip("/")}/{source_id}', - internal_zmq_endpoint, + zmq_reader_endpoint, ) for source_id in config.source_ids } diff --git a/adapters/ds/sinks/always_on_rtsp/config.py b/adapters/ds/sinks/always_on_rtsp/config.py index e09736a39..005a5f45e 100644 --- a/adapters/ds/sinks/always_on_rtsp/config.py +++ b/adapters/ds/sinks/always_on_rtsp/config.py @@ -55,16 +55,16 @@ def __init__(self): self.fps_output = opt_config('FPS_OUTPUT', 'stdout') self.metadata_output = opt_config('METADATA_OUTPUT') - if self.metadata_output: - self.pipeline_stage_name = 'source' - self.video_pipeline: Optional[VideoPipeline] = VideoPipeline( - 'always-on-sink', - [(self.pipeline_stage_name, VideoPipelineStagePayloadType.Frame)], - VideoPipelineConfiguration(), - ) - else: - self.pipeline_stage_name = None - self.video_pipeline: Optional[VideoPipeline] = None + self.pipeline_source_stage_name = 'source' + self.pipeline_demux_stage_name = 'source-demux' + self.video_pipeline: Optional[VideoPipeline] = VideoPipeline( + 'always-on-sink', + [ + (self.pipeline_source_stage_name, VideoPipelineStagePayloadType.Frame), + (self.pipeline_demux_stage_name, VideoPipelineStagePayloadType.Frame), + ], + VideoPipelineConfiguration(), + ) self.framerate = opt_config('FRAMERATE', '30/1') self.sync = opt_config('SYNC_OUTPUT', False, strtobool) diff --git a/adapters/ds/sinks/always_on_rtsp/input_pipeline.py b/adapters/ds/sinks/always_on_rtsp/input_pipeline.py index 61df87dbb..6196754eb 100644 --- a/adapters/ds/sinks/always_on_rtsp/input_pipeline.py +++ b/adapters/ds/sinks/always_on_rtsp/input_pipeline.py @@ -34,6 +34,21 @@ def log_frame_metadata(pad: Gst.Pad, info: Gst.PadProbeInfo, config: Config): return Gst.PadProbeReturn.OK +def delete_frame_from_pipeline(pad: Gst.Pad, info: Gst.PadProbeInfo, config: Config): + buffer: Gst.Buffer = info.get_buffer() + savant_frame_meta = gst_buffer_get_savant_frame_meta(buffer) + if savant_frame_meta is None: + logger.warning( + 'Source %s. No Savant Frame Metadata found on buffer with PTS %s.', + config.source_id, + buffer.pts, + ) + return Gst.PadProbeReturn.PASS + + config.video_pipeline.delete(savant_frame_meta.idx) + return Gst.PadProbeReturn.OK + + def link_added_pad( element: Gst.Element, src_pad: Gst.Pad, @@ -61,6 +76,8 @@ def on_demuxer_pad_added( codec = CODEC_BY_CAPS_NAME[caps[0].get_name()] if config.metadata_output: src_pad.add_probe(Gst.PadProbeType.BUFFER, log_frame_metadata, config) + else: + src_pad.add_probe(Gst.PadProbeType.BUFFER, delete_frame_from_pipeline, config) if codec == Codec.RAW_RGBA: capsfilter = factory.create( @@ -93,25 +110,25 @@ def build_input_pipeline( factory: GstElementFactory, ): pipeline: Gst.Pipeline = Gst.Pipeline.new('input-pipeline') - savant_rs_video_demux_properties = { + zeromq_src_properties = { + 'source-id': config.source_id, + 'socket': config.zmq_endpoint, + 'socket-type': config.zmq_socket_type.name, + 'bind': config.zmq_socket_bind, 'max-width': config.max_allowed_resolution[0], 'max-height': config.max_allowed_resolution[1], + 'pipeline': config.video_pipeline, + 'pipeline-stage-name': config.pipeline_source_stage_name, + } + savant_rs_video_demux_properties = { + 'pipeline': config.video_pipeline, + 'pipeline-stage-name': config.pipeline_demux_stage_name, } - if config.pipeline_stage_name is not None: - savant_rs_video_demux_properties[ - 'pipeline-stage-name' - ] = config.pipeline_stage_name - savant_rs_video_demux_properties['pipeline'] = config.video_pipeline source_elements = [ PipelineElement( 'zeromq_src', - properties={ - 'source-id': config.source_id, - 'socket': config.zmq_endpoint, - 'socket-type': config.zmq_socket_type.name, - 'bind': config.zmq_socket_bind, - }, + properties=zeromq_src_properties, ), PipelineElement( 'savant_rs_video_demux', diff --git a/adapters/ds/sinks/always_on_rtsp/zeromq_proxy.py b/adapters/ds/sinks/always_on_rtsp/zeromq_proxy.py index d37738b33..2b57371a8 100644 --- a/adapters/ds/sinks/always_on_rtsp/zeromq_proxy.py +++ b/adapters/ds/sinks/always_on_rtsp/zeromq_proxy.py @@ -1,13 +1,13 @@ from typing import Optional -import zmq +from savant_rs.zmq import BlockingWriter, WriterConfigBuilder -from savant.utils.zeromq import Defaults, SenderSocketTypes, ZeroMQSource +from savant.utils.zeromq import ZeroMQSource class ZeroMqProxy: """A proxy that receives messages from a ZeroMQ socket and forwards them - to another PUB ZeroMQ socket. Needed for multi-stream Always-On-RTSP sink. + to another ZeroMQ socket. Needed for multi-stream Always-On-RTSP sink. """ def __init__( @@ -22,19 +22,21 @@ def __init__( socket_type=input_socket_type, bind=input_bind, ) - self.output_socket = output_socket - self.sender: Optional[zmq.Socket] = None - self.output_zmq_context: Optional[zmq.Context] = None + writer_config_builder = WriterConfigBuilder(output_socket) + self.writer_config = writer_config_builder.build() + self.sender: Optional[BlockingWriter] = None def start(self): - self.output_zmq_context = zmq.Context() - self.sender = self.output_zmq_context.socket(SenderSocketTypes.PUB.value) - self.sender.setsockopt(zmq.SNDHWM, Defaults.SEND_HWM) - self.sender.bind(self.output_socket) + self.sender = BlockingWriter(self.writer_config) + self.sender.start() self.source.start() def run(self): while True: - message = self.source.next_message_without_routing_id() + message = self.source.next_message() if message is not None: - self.sender.send_multipart(message) + self.sender.send_message( + bytes(message.topic).decode(), + message.message, + message.content, + ) diff --git a/adapters/gst/gst_plugins/python/savant_rs_serializer.py b/adapters/gst/gst_plugins/python/savant_rs_serializer.py deleted file mode 100644 index 28e1dd015..000000000 --- a/adapters/gst/gst_plugins/python/savant_rs_serializer.py +++ /dev/null @@ -1,505 +0,0 @@ -import json -from fractions import Fraction -from pathlib import Path -from typing import Any, Dict, List, NamedTuple, Optional, Tuple - -from savant_rs.primitives import ( - Attribute, - AttributeValue, - EndOfStream, - Shutdown, - VideoFrame, - VideoFrameContent, - VideoFrameTransformation, -) -from savant_rs.utils.serialization import Message, save_message_to_bytes -from splitstream import splitfile - -from savant.api.builder import add_objects_to_video_frame -from savant.api.constants import DEFAULT_FRAMERATE, DEFAULT_NAMESPACE, DEFAULT_TIME_BASE -from savant.api.enums import ExternalFrameType -from savant.gstreamer import GObject, Gst, GstBase -from savant.gstreamer.codecs import CODEC_BY_CAPS_NAME, Codec -from savant.gstreamer.utils import gst_buffer_from_list -from savant.utils.logging import LoggerMixin - -EMBEDDED_FRAME_TYPE = 'embedded' -DEFAULT_SOURCE_ID_PATTERN = 'source-%d' - - -class FrameParams(NamedTuple): - """Frame parameters.""" - - codec_name: str - width: int - height: int - framerate: str - - -def is_videoframe_metadata(metadata: Dict[str, Any]) -> bool: - """Check that metadata contained if metadata is a video frame metadata. .""" - if 'schema' in metadata and metadata['schema'] != 'VideoFrame': - return False - return True - - -class SavantRsSerializer(LoggerMixin, GstBase.BaseTransform): - """GStreamer plugin to serialize video stream to savant-rs message.""" - - GST_PLUGIN_NAME: str = 'savant_rs_serializer' - - __gstmetadata__ = ( - 'Serializes video stream to savant-rs messages', - 'Transform', - 'Serializes video stream to savant-rs message', - 'Pavel Tomskikh ', - ) - - __gsttemplates__ = ( - Gst.PadTemplate.new( - 'src', - Gst.PadDirection.SRC, - Gst.PadPresence.ALWAYS, - Gst.Caps.new_any(), - ), - Gst.PadTemplate.new( - 'sink', - Gst.PadDirection.SINK, - Gst.PadPresence.ALWAYS, - Gst.Caps.from_string(';'.join(x.value.caps_with_params for x in Codec)), - ), - ) - - __gproperties__ = { - 'source-id': ( - str, - 'Source ID', - 'Source ID, e.g. "camera1".', - None, - GObject.ParamFlags.READWRITE | Gst.PARAM_MUTABLE_READY, - ), - 'location': ( - str, - 'Source location', - 'Source location', - None, - GObject.ParamFlags.READWRITE, - ), - # TODO: make fraction - 'framerate': ( - str, - 'Default framerate', - 'Default framerate', - DEFAULT_FRAMERATE, - GObject.ParamFlags.READWRITE, - ), - 'eos-on-file-end': ( - bool, - 'Send EOS at the end of each file', - 'Send EOS at the end of each file', - True, - GObject.ParamFlags.READWRITE, - ), - 'eos-on-loop-end': ( - bool, - 'Send EOS on a loop end', - 'Send EOS on a loop end', - False, - GObject.ParamFlags.READWRITE, - ), - 'eos-on-frame-params-change': ( - bool, - 'Send EOS when frame parameters changed', - 'Send EOS when frame parameters changed', - True, - GObject.ParamFlags.READWRITE, - ), - 'read-metadata': ( - bool, - 'Read metadata', - 'Attempt to read the metadata of objects from the JSON file that has the identical name ' - 'as the source file with `json` extension, and then send it to the module.', - False, - GObject.ParamFlags.READWRITE, - ), - 'frame-type': ( - str, - 'Frame type.', - 'Frame type (allowed: ' - f'{", ".join([EMBEDDED_FRAME_TYPE] + [enum_member.value for enum_member in ExternalFrameType])})', - None, - GObject.ParamFlags.READWRITE, - ), - 'enable-multistream': ( - bool, - 'Enable multistream', - 'Enable multistream', - False, - GObject.ParamFlags.READWRITE | Gst.PARAM_MUTABLE_READY, - ), - 'source-id-pattern': ( - str, - 'Pattern for source ID', - 'Pattern for source ID when multistream is enabled. E.g. "source-%d".', - DEFAULT_SOURCE_ID_PATTERN, - GObject.ParamFlags.READWRITE | Gst.PARAM_MUTABLE_READY, - ), - 'number-of-streams': ( - int, - 'Number of streams', - 'Number of streams', - 1, # min - 1024, # max - 1, - GObject.ParamFlags.READWRITE | Gst.PARAM_MUTABLE_READY, - ), - 'shutdown-auth': ( - str, - 'Authentication key for Shutdown message.', - 'Authentication key for Shutdown message. When specified, a shutdown' - 'message will be sent at the end of the stream.', - None, - GObject.ParamFlags.READWRITE, - ), - } - - def __init__(self): - super().__init__() - # properties - self.source_id: Optional[str] = None - self.eos_on_file_end: bool = True - self.eos_on_loop_end: bool = False - self.eos_on_frame_params_change: bool = True - self.enable_multistream: bool = False - self.source_id_pattern: str = DEFAULT_SOURCE_ID_PATTERN - self.number_of_streams: int = 1 - self.shutdown_auth: Optional[str] = None - # will be set after caps negotiation - self.frame_params: Optional[FrameParams] = None - self.initial_size_transformation: Optional[VideoFrameTransformation] = None - self.last_frame_params: Optional[FrameParams] = None - self.location: Optional[Path] = None - self.last_location: Optional[Path] = None - self.new_loop: bool = False - self.default_framerate: str = DEFAULT_FRAMERATE - self.frame_type: Optional[ExternalFrameType] = ExternalFrameType.ZEROMQ - - self.source_ids_and_topics: List[Tuple[str, bytes]] = [] - self.stream_in_progress = False - self.read_metadata: bool = False - self.json_metadata = None - self.frame_num = 0 - - def do_set_caps( # pylint: disable=unused-argument - self, in_caps: Gst.Caps, out_caps: Gst.Caps - ): - """Checks caps after negotiations.""" - self.logger.info('Sink caps changed to %s', in_caps) - struct: Gst.Structure = in_caps.get_structure(0) - try: - codec = CODEC_BY_CAPS_NAME[struct.get_name()] - except KeyError: - self.logger.error('Not supported caps: %s', in_caps.to_string()) - return False - codec_name = codec.value.name - frame_width = struct.get_int('width').value - frame_height = struct.get_int('height').value - if struct.has_field('framerate'): - _, framerate_num, framerate_demon = struct.get_fraction('framerate') - framerate = f'{framerate_num}/{framerate_demon}' - else: - framerate = self.default_framerate - self.frame_params = FrameParams( - codec_name=codec_name, - width=frame_width, - height=frame_height, - framerate=framerate, - ) - self.initial_size_transformation = VideoFrameTransformation.initial_size( - frame_width, - frame_height, - ) - - return True - - def do_get_property(self, prop: GObject.GParamSpec): - """Gst plugin get property function. - - :param prop: structure that encapsulates the parameter info - """ - if prop.name == 'source-id': - return self.source_id - if prop.name == 'location': - return self.location - if prop.name == 'framerate': - return self.default_framerate - if prop.name == 'eos-on-file-end': - return self.eos_on_file_end - if prop.name == 'eos-on-loop-end': - return self.eos_on_loop_end - if prop.name == 'eos-on-frame-params-change': - return self.eos_on_frame_params_change - if prop.name == 'read-metadata': - return self.read_metadata - if prop.name == 'frame-type': - if self.frame_type is None: - return EMBEDDED_FRAME_TYPE - return self.frame_type.value - if prop.name == 'enable-multistream': - return self.enable_multistream - if prop.name == 'source-id-pattern': - return self.source_id_pattern - if prop.name == 'number-of-streams': - return self.number_of_streams - if prop.name == 'shutdown-auth': - return self.shutdown_auth - raise AttributeError(f'Unknown property {prop.name}.') - - def do_set_property(self, prop: GObject.GParamSpec, value: Any): - """Gst plugin set property function. - - :param prop: structure that encapsulates the parameter info - :param value: new value for parameter, type dependents on parameter - """ - if prop.name == 'source-id': - self.source_id = value - elif prop.name == 'location': - self.location = value - elif prop.name == 'framerate': - try: - Fraction(value) # validate - except (ZeroDivisionError, ValueError) as e: - raise AttributeError(f'Invalid property {prop.name}: {e}.') from e - self.default_framerate = value - elif prop.name == 'eos-on-file-end': - self.eos_on_file_end = value - elif prop.name == 'eos-on-loop-end': - self.eos_on_loop_end = value - elif prop.name == 'eos-on-frame-params-change': - self.eos_on_frame_params_change = value - elif prop.name == 'read-metadata': - self.read_metadata = value - elif prop.name == 'frame-type': - if value == EMBEDDED_FRAME_TYPE: - self.frame_type = None - else: - self.frame_type = ExternalFrameType(value) - elif prop.name == 'enable-multistream': - self.enable_multistream = value - elif prop.name == 'source-id-pattern': - self.source_id_pattern = value - elif prop.name == 'number-of-streams': - self.number_of_streams = value - elif prop.name == 'shutdown-auth': - self.shutdown_auth = value - else: - raise AttributeError(f'Unknown property {prop.name}.') - - def do_start(self): - if self.enable_multistream: - if self.source_id_pattern is None: - self.logger.error( - 'Source ID pattern is required when enable-multistream=true.' - ) - return False - try: - source_ids = [ - self.source_id_pattern % i for i in range(self.number_of_streams) - ] - except TypeError as e: - self.logger.error('Invalid source ID pattern: %s', e) - return False - if len(source_ids) != len(set(source_ids)): - self.logger.error( - 'Duplicate source IDs. Check source-id-pattern property.' - ) - return False - - else: - if self.source_id is None: - self.logger.error( - 'Source ID is required when enable-multistream=false.' - ) - return False - source_ids = [self.source_id] - - self.source_ids_and_topics = [ - (source_id, f'{source_id}/'.encode()) for source_id in source_ids - ] - - return True - - def do_prepare_output_buffer(self, in_buf: Gst.Buffer): - """Transform gst function.""" - - self.logger.debug( - 'Processing frame %s of size %s', in_buf.pts, in_buf.get_size() - ) - if self.stream_in_progress: - if ( - self.eos_on_file_end - and self.location != self.last_location - or self.eos_on_frame_params_change - and self.frame_params != self.last_frame_params - or self.eos_on_loop_end - and self.new_loop - ): - self.json_metadata = self.read_json_metadata_file( - self.location.parent / f'{self.location.stem}.json' - ) - self.frame_num = 0 - self.send_eos_message() - self.last_location = self.location - self.last_frame_params = self.frame_params - self.new_loop = False - - frame_mapinfo: Optional[Gst.MapInfo] = None - if self.frame_type is None: - result, frame_mapinfo = in_buf.map(Gst.MapFlags.READ) - assert result, 'Cannot read buffer.' - content = VideoFrameContent.internal(frame_mapinfo.data) - elif self.frame_type == ExternalFrameType.ZEROMQ: - content = VideoFrameContent.external(self.frame_type.value, None) - else: - self.logger.error('Unsupported frame type "%s".', self.frame_type.value) - return Gst.FlowReturn.ERROR - - frame = self.build_video_frame( - source_id=self.source_ids_and_topics[0][0], - pts=in_buf.pts, - dts=in_buf.dts if in_buf.dts != Gst.CLOCK_TIME_NONE else None, - duration=in_buf.duration - if in_buf.duration != Gst.CLOCK_TIME_NONE - else None, - content=content, - keyframe=not in_buf.has_flags(Gst.BufferFlags.DELTA_UNIT), - ) - - for i, (source_id, zmq_topic) in enumerate(self.source_ids_and_topics): - frame.source_id = source_id - message = Message.video_frame(frame) - data = save_message_to_bytes(message) - out_buf: Gst.Buffer = gst_buffer_from_list([zmq_topic, data]) - if self.frame_type is not None: - out_buf.append_memory(in_buf.get_memory_range(0, -1)) - out_buf.pts = in_buf.pts - out_buf.dts = in_buf.dts - out_buf.duration = in_buf.duration - if i < len(self.source_ids_and_topics) - 1: - self.srcpad.push(out_buf) - - if frame_mapinfo is not None: - in_buf.unmap(frame_mapinfo) - self.stream_in_progress = True - - return Gst.FlowReturn.OK, out_buf - - def do_sink_event(self, event: Gst.Event): - if event.type == Gst.EventType.EOS: - self.logger.info('Got End-Of-Stream event') - self.send_eos_message() - self.send_shutdown_message() - - elif event.type == Gst.EventType.TAG: - tag_list: Gst.TagList = event.parse_tag() - has_location, location = tag_list.get_string(Gst.TAG_LOCATION) - if has_location: - self.logger.info('Set location to %s', location) - self.location = Path(location) - self.new_loop = True - self.json_metadata = self.read_json_metadata_file( - self.location.parent / f'{self.location.stem}.json' - ) - self.frame_num = 0 - - # Cannot use `super()` since it is `self` - return GstBase.BaseTransform.do_sink_event(self, event) - - def read_json_metadata_file(self, location: Path): - json_metadata = None - if self.read_metadata: - if location.is_file(): - with open(location, 'r') as fp: - json_metadata = list( - map( - lambda x: x['metadata'], - filter( - is_videoframe_metadata, - map(json.loads, splitfile(fp, format='json')), - ), - ) - ) - else: - self.logger.warning('JSON file `%s` not found', location.absolute()) - return json_metadata - - def send_eos_message(self): - self.logger.info('Sending serialized EOS message') - for source_id, zmq_topic in self.source_ids_and_topics: - message = Message.end_of_stream(EndOfStream(source_id)) - data = save_message_to_bytes(message) - out_buf = gst_buffer_from_list([zmq_topic, data]) - self.srcpad.push(out_buf) - self.stream_in_progress = False - - def send_shutdown_message(self): - if self.shutdown_auth is not None: - self.logger.info('Sending serialized Shutdown message') - message = Message.shutdown(Shutdown(self.shutdown_auth)) - data = save_message_to_bytes(message) - out_buf = gst_buffer_from_list([self.source_ids_and_topics[0][1], data]) - self.srcpad.push(out_buf) - - def build_video_frame( - self, - source_id: str, - pts: int, - dts: Optional[int], - duration: Optional[int], - content: VideoFrameContent, - keyframe: bool, - ) -> VideoFrame: - if pts == Gst.CLOCK_TIME_NONE: - # TODO: support CLOCK_TIME_NONE in schema - pts = 0 - objects = None - if self.read_metadata and self.json_metadata: - frame_metadata = self.json_metadata[self.frame_num] - self.frame_num += 1 - objects = frame_metadata['objects'] - - video_frame = VideoFrame( - source_id=source_id, - framerate=self.frame_params.framerate, - width=self.frame_params.width, - height=self.frame_params.height, - codec=self.frame_params.codec_name, - content=content, - keyframe=keyframe, - pts=pts, - dts=dts, - duration=duration, - time_base=DEFAULT_TIME_BASE, - ) - video_frame.add_transformation(self.initial_size_transformation) - if objects: - add_objects_to_video_frame(video_frame, objects) - if self.location: - video_frame.set_attribute( - Attribute( - namespace=DEFAULT_NAMESPACE, - name='location', - values=[AttributeValue.string(str(self.location))], - ) - ) - - return video_frame - - -# register plugin -GObject.type_register(SavantRsSerializer) -__gstelementfactory__ = ( - SavantRsSerializer.GST_PLUGIN_NAME, - Gst.Rank.NONE, - SavantRsSerializer, -) diff --git a/adapters/gst/gst_plugins/python/video_files_sink.py b/adapters/gst/gst_plugins/python/video_files_sink.py deleted file mode 100644 index 017ad0ee2..000000000 --- a/adapters/gst/gst_plugins/python/video_files_sink.py +++ /dev/null @@ -1,355 +0,0 @@ -import os -from typing import Dict, Optional, Union - -from savant_rs.primitives import EndOfStream, VideoFrame - -from adapters.python.sinks.chunk_writer import ChunkWriter, CompositeChunkWriter -from adapters.python.sinks.metadata_json import MetadataJsonWriter, Patterns -from gst_plugins.python.savant_rs_video_demux_common import FrameParams, build_caps -from savant.api.enums import ExternalFrameType -from savant.api.parser import convert_ts -from savant.gstreamer import GLib, GObject, Gst, GstApp -from savant.gstreamer.codecs import Codec -from savant.gstreamer.utils import load_message_from_gst_buffer, on_pad_event -from savant.utils.logging import LoggerMixin - -DEFAULT_CHUNK_SIZE = 10000 - - -class VideoFilesWriter(ChunkWriter): - def __init__( - self, - base_location: str, - source_id: str, - chunk_size: int, - frame_params: FrameParams, - ): - self.base_location = base_location - self.source_id = source_id - self.appsrc: Optional[GstApp.AppSrc] = None - self.bin: Gst.Bin = Gst.Bin.new(f'sink_bin_{source_id}') - self.frame_params = frame_params - self.caps = build_caps(frame_params) - super().__init__(chunk_size) - - def _write_video_frame( - self, - frame: VideoFrame, - data: Optional[Union[bytes, Gst.Memory]], - frame_num: int, - ) -> bool: - if not data: - return True - - if isinstance(data, bytes): - frame_buf: Gst.Buffer = Gst.Buffer.new_wrapped(data) - else: - frame_buf: Gst.Buffer = Gst.Buffer.new() - frame_buf.append_memory(data) - - frame_buf.pts = convert_ts(frame.pts, frame.time_base) - frame_buf.dts = ( - convert_ts(frame.dts, frame.time_base) - if frame.dts is not None - else Gst.CLOCK_TIME_NONE - ) - frame_buf.duration = ( - convert_ts(frame.duration, frame.time_base) - if frame.duration is not None - else Gst.CLOCK_TIME_NONE - ) - self.logger.debug( - 'Sending frame with pts=%s to %s', - frame.pts, - self.appsrc.get_name(), - ) - - return self.appsrc.push_buffer(frame_buf) == Gst.FlowReturn.OK - - def _write_eos(self, eos: EndOfStream) -> bool: - return True - - def _open(self): - self.logger.debug( - 'Creating sink elements for chunk %s of source %s', - self.chunk_idx, - self.source_id, - ) - appsrc_name = f'appsrc_{self.source_id}_{self.chunk_idx}' - filesink_name = f'filesink_{self.source_id}_{self.chunk_idx}' - sink: Gst.Bin = Gst.parse_bin_from_description( - ' ! '.join( - [ - f'appsrc name={appsrc_name} emit-signals=false format=time', - 'queue', - 'adjust_timestamps', - self.frame_params.codec.value.parser, - 'qtmux fragment-duration=1000 fragment-mode=first-moov-then-finalise', - f'filesink name={filesink_name}', - ] - ), - False, - ) - sink.set_name(f'sink_bin_{self.source_id}_{self.chunk_idx}') - self.appsrc: GstApp.AppSrc = sink.get_by_name(appsrc_name) - self.appsrc.set_caps(self.caps) - - filesink: Gst.Element = sink.get_by_name(filesink_name) - os.makedirs(self.base_location, exist_ok=True) - dst_location = os.path.join(self.base_location, f'{self.chunk_idx:04}.mov') - self.logger.info( - 'Writing video from source %s to file %s', self.source_id, dst_location - ) - filesink.set_property('location', dst_location) - - self.bin.add(sink) - - filesink.get_static_pad('sink').add_probe( - Gst.PadProbeType.EVENT_DOWNSTREAM, - on_pad_event, - {Gst.EventType.EOS: self.on_sink_pad_eos}, - sink, - self.chunk_idx, - ) - sink.sync_state_with_parent() - self.logger.debug( - 'Sink elements for chunk %s of source %s created', - self.chunk_idx, - self.source_id, - ) - - def _close(self): - self.logger.debug( - 'Stopping and removing sink elements for chunk %s of source %s', - self.chunk_idx, - self.source_id, - ) - self.logger.debug('Sending EOS to %s', self.appsrc.get_name()) - self.appsrc.end_of_stream() - self.appsrc = None - - def on_sink_pad_eos( - self, pad: Gst.Pad, event: Gst.Event, sink: Gst.Element, chunk_idx: int - ): - self.logger.debug( - 'Got EOS from pad %s of %s', pad.get_name(), pad.parent.get_name() - ) - GLib.idle_add(self._remove_branch, sink, chunk_idx) - return Gst.PadProbeReturn.HANDLED - - def _remove_branch(self, sink: Gst.Element, chunk_idx: int): - self.logger.debug('Removing element %s', sink.get_name()) - sink.set_locked_state(True) - sink.set_state(Gst.State.NULL) - self.bin.remove(sink) - self.logger.debug( - 'Sink elements for chunk %s of source %s removed', chunk_idx, self.source_id - ) - - return False - - -class VideoFilesSink(LoggerMixin, Gst.Bin): - """Writes frames as video files.""" - - GST_PLUGIN_NAME = 'video_files_sink' - - __gstmetadata__ = ( - 'Video files sink', - 'Bin/Sink/File', - 'Writes frames as video files', - 'Pavel Tomskikh ', - ) - - __gsttemplates__ = ( - Gst.PadTemplate.new( - 'sink', - Gst.PadDirection.SINK, - Gst.PadPresence.ALWAYS, - Gst.Caps.new_any(), - ), - ) - - __gproperties__ = { - 'location': ( - GObject.TYPE_STRING, - 'Output directory location', - 'Location of the directory for output files', - None, - GObject.ParamFlags.READWRITE, - ), - 'chunk-size': ( - int, - 'Chunk size', - 'Chunk size in frames (0 to disable chunks).', - 0, - GObject.G_MAXINT, - DEFAULT_CHUNK_SIZE, - GObject.ParamFlags.READWRITE, - ), - } - - def __init__(self): - super().__init__() - self.writers: Dict[str, ChunkWriter] = {} - - self.location: str = None - self.chunk_size: int = DEFAULT_CHUNK_SIZE - - self.sink_pad: Gst.Pad = Gst.Pad.new_from_template( - Gst.PadTemplate.new( - 'sink', - Gst.PadDirection.SINK, - Gst.PadPresence.ALWAYS, - Gst.Caps.new_any(), - ), - 'sink', - ) - self.sink_pad.set_chain_function(self.handle_buffer) - assert self.add_pad(self.sink_pad), 'Failed to add sink pad.' - - def do_get_property(self, prop): - """Gst plugin get property function. - - :param prop: structure that encapsulates - the metadata required to specify parameters - """ - if prop.name == 'location': - return self.location - if prop.name == 'chunk-size': - return self.chunk_size - raise AttributeError(f'Unknown property {prop.name}.') - - def do_set_property(self, prop, value): - """Gst plugin set property function. - - :param prop: structure that encapsulates - the metadata required to specify parameters - :param value: new value for param, type dependents on param - """ - if prop.name == 'location': - self.location = value - elif prop.name == 'chunk-size': - self.chunk_size = value - else: - raise AttributeError(f'Unknown property {prop.name}.') - - def do_set_state(self, state: Gst.State): - self.logger.info('Changing state from %s to %s', self.current_state, state) - if self.current_state == Gst.State.NULL and state != Gst.State.NULL: - assert self.location is not None, '"location" property is required' - return Gst.Bin.do_set_state(self, state) - - def handle_buffer(self, sink_pad: Gst.Pad, buffer: Gst.Buffer) -> Gst.FlowReturn: - self.logger.debug( - 'Handling buffer of size %s with timestamp %s', - buffer.get_size(), - buffer.pts, - ) - - message = load_message_from_gst_buffer(buffer) - message.validate_seq_id() - # TODO: Pipeline message types might be extended beyond only VideoFrame - # Additional checks for audio/raw_tensors/etc. may be required - if message.is_video_frame(): - result = self.handle_video_frame(message.as_video_frame(), buffer) - elif message.is_end_of_stream(): - result = self.handle_eos(message.as_end_of_stream()) - else: - self.logger.debug('Unsupported message type for message %r', message) - result = Gst.FlowReturn.OK - - return result - - def handle_video_frame( - self, - frame: VideoFrame, - buffer: Gst.Buffer, - ) -> Gst.FlowReturn: - frame_params = FrameParams.from_video_frame(frame) - assert frame_params.codec in [ - Codec.H264, - Codec.HEVC, - Codec.JPEG, - Codec.PNG, - ], f'Unsupported codec {frame.codec}' - if frame.content.is_none(): - self.logger.debug( - 'Received frame %s from source %s is empty', - frame.pts, - frame.source_id, - ) - content = None - elif frame.content.is_internal(): - content = frame.content.get_data_as_bytes() - self.logger.debug( - 'Received frame %s from source %s, size: %s bytes', - frame.pts, - frame.source_id, - len(content), - ) - else: - frame_type = ExternalFrameType(frame.content.get_method()) - if frame_type != ExternalFrameType.ZEROMQ: - self.logger.error('Unsupported frame type "%s".', frame_type.value) - return Gst.FlowReturn.ERROR - if buffer.n_memory() < 2: - self.logger.error( - 'Buffer has %s regions of memory, expected at least 2.', - buffer.n_memory(), - ) - return Gst.FlowReturn.ERROR - - content = buffer.get_memory_range(1, -1) - self.logger.debug( - 'Received frame %s from source %s, size: %s bytes', - frame.pts, - frame.source_id, - content.size, - ) - - writer = self.writers.get(frame.source_id) - if writer is None: - base_location = os.path.join(self.location, frame.source_id) - json_filename_pattern = f'{Patterns.CHUNK_IDX}.json' - video_writer = VideoFilesWriter( - base_location, - frame.source_id, - self.chunk_size, - frame_params, - ) - writer = CompositeChunkWriter( - [ - video_writer, - MetadataJsonWriter( - os.path.join(base_location, json_filename_pattern), - self.chunk_size, - ), - ], - self.chunk_size, - ) - self.writers[frame.source_id] = writer - self.add(video_writer.bin) - video_writer.bin.sync_state_with_parent() - - if writer.write_video_frame(frame, content, frame.keyframe): - return Gst.FlowReturn.OK - - return Gst.FlowReturn.ERROR - - def handle_eos(self, eos: EndOfStream) -> Gst.FlowReturn: - self.logger.info('Received EOS from source %s.', eos.source_id) - writer = self.writers.get(eos.source_id) - if writer is not None: - writer.write_eos(eos) - writer.close() - return Gst.FlowReturn.OK - - -# register plugin -GObject.type_register(VideoFilesSink) -__gstelementfactory__ = ( - VideoFilesSink.GST_PLUGIN_NAME, - Gst.Rank.NONE, - VideoFilesSink, -) diff --git a/adapters/gst/gst_plugins/python/zeromq_sink.py b/adapters/gst/gst_plugins/python/zeromq_sink.py index 8f422527f..fdc9a7301 100644 --- a/adapters/gst/gst_plugins/python/zeromq_sink.py +++ b/adapters/gst/gst_plugins/python/zeromq_sink.py @@ -1,35 +1,70 @@ """ZeroMQ sink.""" import inspect -from typing import List, Union +import json +from fractions import Fraction +from pathlib import Path +from typing import Any, Dict, List, NamedTuple, Optional -import zmq +from savant_rs.primitives import ( + AttributeValue, + EndOfStream, + Shutdown, + VideoFrame, + VideoFrameContent, + VideoFrameTransformation, +) +from savant_rs.utils.serialization import Message +from savant_rs.zmq import ( + BlockingWriter, + WriterConfigBuilder, + WriterResultAck, + WriterResultSuccess, +) +from splitstream import splitfile from gst_plugins.python.zeromq_properties import ZEROMQ_PROPERTIES, socket_type_property +from savant.api.builder import add_objects_to_video_frame +from savant.api.constants import DEFAULT_FRAMERATE, DEFAULT_NAMESPACE, DEFAULT_TIME_BASE +from savant.api.enums import ExternalFrameType from savant.gstreamer import GObject, Gst, GstBase +from savant.gstreamer.codecs import CODEC_BY_CAPS_NAME, Codec from savant.gstreamer.utils import ( gst_post_library_settings_error, gst_post_stream_failed_error, + required_property, ) from savant.utils.logging import LoggerMixin -from savant.utils.zeromq import ( - END_OF_STREAM_MESSAGE, - Defaults, - SenderSocketTypes, - ZMQException, - parse_zmq_socket_uri, - receive_response, -) +from savant.utils.zeromq import Defaults, SenderSocketTypes, get_zmq_socket_uri_options + +EMBEDDED_FRAME_TYPE = 'embedded' +DEFAULT_SOURCE_ID_PATTERN = 'source-%d' + + +class FrameParams(NamedTuple): + """Frame parameters.""" + + codec_name: str + width: int + height: int + framerate: str + + +def is_videoframe_metadata(metadata: Dict[str, Any]) -> bool: + """Check that metadata contained if metadata is a video frame metadata. .""" + if 'schema' in metadata and metadata['schema'] != 'VideoFrame': + return False + return True class ZeroMQSink(LoggerMixin, GstBase.BaseSink): - """ZeroMQSink GstPlugin.""" + """Serializes video stream to savant-rs message and sends it to ZeroMQ socket.""" GST_PLUGIN_NAME = 'zeromq_sink' __gstmetadata__ = ( - 'ZeroMQ sink', - 'Source', - 'Writes binary messages to ZeroMQ', + 'Serializes video stream to savant-rs message and sends it to ZeroMQ socket.', + 'Sink', + 'Serializes video stream to savant-rs message and sends it to ZeroMQ socket.', 'Pavel Tomskikh ', ) @@ -37,7 +72,7 @@ class ZeroMQSink(LoggerMixin, GstBase.BaseSink): 'sink', Gst.PadDirection.SINK, Gst.PadPresence.ALWAYS, - Gst.Caps.new_any(), + Gst.Caps.from_string(';'.join(x.value.caps_with_params for x in Codec)), ) __gproperties__ = { @@ -61,38 +96,141 @@ class ZeroMQSink(LoggerMixin, GstBase.BaseSink): Defaults.SENDER_RECEIVE_TIMEOUT, GObject.ParamFlags.READWRITE, ), - 'req-receive-retries': ( + 'receive-retries': ( int, - 'Retries to receive confirmation message from REQ socket', - 'Retries to receive confirmation message from REQ socket', + 'Retries to receive confirmation message', + 'Retries to receive confirmation message', 1, GObject.G_MAXINT, - Defaults.REQ_RECEIVE_RETRIES, + Defaults.RECEIVE_RETRIES, + GObject.ParamFlags.READWRITE, + ), + 'source-id': ( + str, + 'Source ID', + 'Source ID, e.g. "camera1".', + None, + GObject.ParamFlags.READWRITE | Gst.PARAM_MUTABLE_READY, + ), + 'location': ( + str, + 'Source location', + 'Source location', + None, + GObject.ParamFlags.READWRITE, + ), + # TODO: make fraction + 'framerate': ( + str, + 'Default framerate', + 'Default framerate', + DEFAULT_FRAMERATE, + GObject.ParamFlags.READWRITE, + ), + 'eos-on-file-end': ( + bool, + 'Send EOS at the end of each file', + 'Send EOS at the end of each file', + True, + GObject.ParamFlags.READWRITE, + ), + 'eos-on-loop-end': ( + bool, + 'Send EOS on a loop end', + 'Send EOS on a loop end', + False, + GObject.ParamFlags.READWRITE, + ), + 'eos-on-frame-params-change': ( + bool, + 'Send EOS when frame parameters changed', + 'Send EOS when frame parameters changed', + True, + GObject.ParamFlags.READWRITE, + ), + 'read-metadata': ( + bool, + 'Read metadata', + 'Attempt to read the metadata of objects from the JSON file that has the identical name ' + 'as the source file with `json` extension, and then send it to the module.', + False, GObject.ParamFlags.READWRITE, ), - 'eos-confirmation-retries': ( + 'frame-type': ( + str, + 'Frame type.', + 'Frame type (allowed: ' + f'{", ".join([EMBEDDED_FRAME_TYPE] + [enum_member.value for enum_member in ExternalFrameType])})', + None, + GObject.ParamFlags.READWRITE, + ), + 'enable-multistream': ( + bool, + 'Enable multistream', + 'Enable multistream', + False, + GObject.ParamFlags.READWRITE | Gst.PARAM_MUTABLE_READY, + ), + 'source-id-pattern': ( + str, + 'Pattern for source ID', + 'Pattern for source ID when multistream is enabled. E.g. "source-%d".', + DEFAULT_SOURCE_ID_PATTERN, + GObject.ParamFlags.READWRITE | Gst.PARAM_MUTABLE_READY, + ), + 'number-of-streams': ( int, - 'Retries to receive EOS confirmation message', - 'Retries to receive EOS confirmation message', + 'Number of streams', + 'Number of streams', + 1, # min + 1024, # max 1, - GObject.G_MAXINT, - Defaults.EOS_CONFIRMATION_RETRIES, + GObject.ParamFlags.READWRITE | Gst.PARAM_MUTABLE_READY, + ), + 'shutdown-auth': ( + str, + 'Authentication key for Shutdown message.', + 'Authentication key for Shutdown message. When specified, a shutdown' + 'message will be sent at the end of the stream.', + None, GObject.ParamFlags.READWRITE, ), } def __init__(self): GstBase.BaseSink.__init__(self) + + # properties self.socket: str = None - self.socket_type: Union[str, SenderSocketTypes] = SenderSocketTypes.DEALER + self.socket_type: str = SenderSocketTypes.DEALER.name self.bind: bool = True - self.zmq_context: zmq.Context = None - self.sender: zmq.Socket = None - self.wait_response = False + self.source_id: Optional[str] = None + self.eos_on_file_end: bool = True + self.eos_on_loop_end: bool = False + self.eos_on_frame_params_change: bool = True + self.enable_multistream: bool = False + self.source_id_pattern: str = DEFAULT_SOURCE_ID_PATTERN + self.number_of_streams: int = 1 + self.shutdown_auth: Optional[str] = None self.send_hwm = Defaults.SEND_HWM + + # will be set after caps negotiation + self.frame_params: Optional[FrameParams] = None + self.initial_size_transformation: Optional[VideoFrameTransformation] = None + self.last_frame_params: Optional[FrameParams] = None + self.location: Optional[Path] = None + self.last_location: Optional[Path] = None + self.new_loop: bool = False + self.default_framerate: str = DEFAULT_FRAMERATE + + self.source_ids: List[str] = [] + self.stream_in_progress = False + self.read_metadata: bool = False + self.json_metadata = None + self.frame_num = 0 + self.writer: BlockingWriter = None self.receive_timeout = Defaults.SENDER_RECEIVE_TIMEOUT - self.req_receive_retries = Defaults.REQ_RECEIVE_RETRIES - self.eos_confirmation_retries = Defaults.EOS_CONFIRMATION_RETRIES + self.receive_retries = Defaults.RECEIVE_RETRIES self.set_sync(False) def do_get_property(self, prop): @@ -101,24 +239,43 @@ def do_get_property(self, prop): :param prop: structure that encapsulates the metadata required to specify parameters """ + if prop.name == 'socket': return self.socket if prop.name == 'socket-type': - return ( - self.socket_type.name - if isinstance(self.socket_type, SenderSocketTypes) - else self.socket_type - ) + return self.socket_type if prop.name == 'bind': return self.bind if prop.name == 'send-hwm': return self.send_hwm if prop.name == 'receive-timeout': return self.receive_timeout - if prop.name == 'req-receive-retries': - return self.req_receive_retries - if prop.name == 'eos-confirmation-retries': - return self.eos_confirmation_retries + if prop.name == 'receive-retries': + return self.receive_retries + + if prop.name == 'source-id': + return self.source_id + if prop.name == 'location': + return self.location + if prop.name == 'framerate': + return self.default_framerate + if prop.name == 'eos-on-file-end': + return self.eos_on_file_end + if prop.name == 'eos-on-loop-end': + return self.eos_on_loop_end + if prop.name == 'eos-on-frame-params-change': + return self.eos_on_frame_params_change + if prop.name == 'read-metadata': + return self.read_metadata + if prop.name == 'enable-multistream': + return self.enable_multistream + if prop.name == 'source-id-pattern': + return self.source_id_pattern + if prop.name == 'number-of-streams': + return self.number_of_streams + if prop.name == 'shutdown-auth': + return self.shutdown_auth + raise AttributeError(f'Unknown property {prop.name}.') def do_set_property(self, prop, value): @@ -128,6 +285,7 @@ def do_set_property(self, prop, value): the metadata required to specify parameters :param value: new value for param, type dependents on param """ + self.logger.debug('Setting property "%s" to "%s".', prop.name, value) if prop.name == 'socket': self.socket = value @@ -139,38 +297,127 @@ def do_set_property(self, prop, value): self.send_hwm = value elif prop.name == 'receive-timeout': self.receive_timeout = value - elif prop.name == 'req-receive-retries': - self.req_receive_retries = value - elif prop.name == 'eos-confirmation-retries': - self.eos_confirmation_retries = value + elif prop.name == 'receive-retries': + self.receive_retries = value + + elif prop.name == 'source-id': + self.source_id = value + elif prop.name == 'location': + self.location = value + elif prop.name == 'framerate': + try: + Fraction(value) # validate + except (ZeroDivisionError, ValueError) as e: + raise AttributeError(f'Invalid property {prop.name}: {e}.') from e + self.default_framerate = value + elif prop.name == 'eos-on-file-end': + self.eos_on_file_end = value + elif prop.name == 'eos-on-loop-end': + self.eos_on_loop_end = value + elif prop.name == 'eos-on-frame-params-change': + self.eos_on_frame_params_change = value + elif prop.name == 'read-metadata': + self.read_metadata = value + elif prop.name == 'enable-multistream': + self.enable_multistream = value + elif prop.name == 'source-id-pattern': + self.source_id_pattern = value + elif prop.name == 'number-of-streams': + self.number_of_streams = value + elif prop.name == 'shutdown-auth': + self.shutdown_auth = value else: raise AttributeError(f'Unknown property {prop.name}.') + def get_source_ids(self) -> Optional[List[str]]: + if self.enable_multistream: + if self.source_id_pattern is None: + self.logger.error( + 'Source ID pattern is required when enable-multistream=true.' + ) + return None + try: + source_ids = [ + self.source_id_pattern % i for i in range(self.number_of_streams) + ] + except TypeError as e: + self.logger.error('Invalid source ID pattern: %s', e) + return None + if len(source_ids) != len(set(source_ids)): + self.logger.error( + 'Duplicate source IDs. Check source-id-pattern property.' + ) + return None + + return source_ids + + if self.source_id is None: + self.logger.error('Source ID is required when enable-multistream=false.') + return None + return [self.source_id] + def do_start(self): - """Start source.""" + """Start sink.""" + + self.source_ids = self.get_source_ids() + if self.source_ids is None: + return False + try: - self.socket_type, self.bind, self.socket = parse_zmq_socket_uri( - uri=self.socket, - socket_type_name=self.socket_type, - socket_type_enum=SenderSocketTypes, - bind=self.bind, - ) - except ZMQException: - self.logger.exception('Element start error.') + required_property('socket', self.socket) + config_builder = WriterConfigBuilder(self.socket) + if not get_zmq_socket_uri_options(self.socket): + config_builder.with_socket_type( + SenderSocketTypes[self.socket_type].value + ) + config_builder.with_bind(self.bind) + config_builder.with_send_hwm(self.send_hwm) + config_builder.with_receive_timeout(self.receive_timeout) + config_builder.with_receive_retries(self.receive_retries) + config_builder.with_send_timeout(self.receive_timeout) + config_builder.with_send_retries(self.receive_retries) + self.writer = BlockingWriter(config_builder.build()) + self.writer.start() + + except Exception as exc: + error = f'Failed to start ZeroMQ sink with socket {self.socket}: {exc}.' + self.logger.exception(error, exc_info=True) frame = inspect.currentframe() - gst_post_library_settings_error(self, frame, __file__) + gst_post_library_settings_error(self, frame, __file__, error) # prevents pipeline from starting return False - self.wait_response = self.socket_type == SenderSocketTypes.REQ - self.zmq_context = zmq.Context() - self.sender = self.zmq_context.socket(self.socket_type.value) - self.sender.setsockopt(zmq.SNDHWM, self.send_hwm) - self.sender.setsockopt(zmq.RCVTIMEO, self.receive_timeout) - if self.bind: - self.sender.bind(self.socket) + return True + + def do_set_caps(self, caps: Gst.Caps): + """Checks caps after negotiations.""" + + self.logger.info('Sink caps changed to %s', caps) + struct: Gst.Structure = caps.get_structure(0) + try: + codec = CODEC_BY_CAPS_NAME[struct.get_name()] + except KeyError: + self.logger.error('Not supported caps: %s', caps.to_string()) + return False + codec_name = codec.value.name + frame_width = struct.get_int('width').value + frame_height = struct.get_int('height').value + if struct.has_field('framerate'): + _, framerate_num, framerate_demon = struct.get_fraction('framerate') + framerate = f'{framerate_num}/{framerate_demon}' else: - self.sender.connect(self.socket) + framerate = self.default_framerate + self.frame_params = FrameParams( + codec_name=codec_name, + width=frame_width, + height=frame_height, + framerate=framerate, + ) + self.initial_size_transformation = VideoFrameTransformation.initial_size( + frame_width, + frame_height, + ) + return True def do_render(self, buffer: Gst.Buffer): @@ -178,85 +425,171 @@ def do_render(self, buffer: Gst.Buffer): self.logger.debug( 'Processing frame %s of size %s', buffer.pts, buffer.get_size() ) - message: List[bytes] = [] - mapinfo_list: List[Gst.MapInfo] = [] - mapinfo: Gst.MapInfo - for i in range(2): - result, mapinfo = buffer.map_range(i, 1, Gst.MapFlags.READ) - assert result, 'Cannot read buffer.' - mapinfo_list.append(mapinfo) - message.append(mapinfo.data) - - if buffer.n_memory() > 2: - # TODO: Use Gst.Meta to check where to split buffer to ZeroMQ message parts - result, mapinfo = buffer.map_range(2, -1, Gst.MapFlags.READ) - assert result, 'Cannot read buffer.' - mapinfo_list.append(mapinfo) - message.append(mapinfo.data) - self.logger.debug( - 'Sending %s bytes to socket %s.', sum(len(x) for x in message), self.socket - ) - self.sender.send_multipart(message) - if self.wait_response: - try: - resp = receive_response(self.sender, self.req_receive_retries) - except zmq.Again: - error = ( - f"The REP socket hasn't responded in a configured timeframe " - f'{self.receive_timeout * self.req_receive_retries} ms.' - ) - self.logger.error(error) - frame = inspect.currentframe() - gst_post_stream_failed_error( - gst_element=self, - frame=frame, - file_path=__file__, - text=error, + if self.stream_in_progress: + if ( + self.eos_on_file_end + and self.location != self.last_location + or self.eos_on_frame_params_change + and self.frame_params != self.last_frame_params + or self.eos_on_loop_end + and self.new_loop + ): + self.json_metadata = self.read_json_metadata_file( + self.location.parent / f'{self.location.stem}.json' ) - return Gst.FlowReturn.ERROR + self.frame_num = 0 + self.send_eos() + self.last_location = self.location + self.last_frame_params = self.frame_params + self.new_loop = False + self.stream_in_progress = True - self.logger.debug( - 'Received %s bytes from socket %s.', len(resp), self.socket - ) - for mapinfo in mapinfo_list: - buffer.unmap(mapinfo) + content = buffer.extract_dup(0, buffer.get_size()) + base_frame = self.build_video_frame( + source_id=self.source_ids[0][0], + pts=buffer.pts, + dts=buffer.dts if buffer.dts != Gst.CLOCK_TIME_NONE else None, + duration=( + buffer.duration if buffer.duration != Gst.CLOCK_TIME_NONE else None + ), + keyframe=not buffer.has_flags(Gst.BufferFlags.DELTA_UNIT), + ) + + for source_id in self.source_ids: + frame = base_frame.copy() + frame.source_id = source_id + if not self.send_message_to_zmq(source_id, frame.to_message(), content): + return Gst.FlowReturn.ERROR return Gst.FlowReturn.OK def do_stop(self): """Stop source.""" - if self.socket_type == SenderSocketTypes.DEALER: - self.logger.info('Sending End-of-Stream message to socket %s', self.socket) - self.sender.send_multipart([END_OF_STREAM_MESSAGE]) - self.logger.info( - 'Waiting for End-of-Stream message confirmation from socket %s', - self.socket, + if self.shutdown_auth is not None: + self.logger.info('Sending serialized Shutdown message') + self.send_message_to_zmq( + self.source_ids[0], + Shutdown(self.shutdown_auth).to_message(), ) - try: - self.sender.recv() - except zmq.Again: - error = ( - f'Timeout exceeded when receiving End-of-Stream message ' - f'confirmation from socket {self.socket}' - ) - self.logger.error(error) - frame = inspect.currentframe() - gst_post_stream_failed_error( - gst_element=self, - frame=frame, - file_path=__file__, - text=error, - ) - return False - self.logger.info('Closing ZeroMQ socket') - self.sender.close() - self.logger.info('Terminating ZeroMQ context.') - self.zmq_context.term() - self.logger.info('ZeroMQ context terminated') + self.logger.info('Terminating ZeroMQ writer.') + self.writer.shutdown() + self.logger.info('ZeroMQ writer terminated') return True + def do_event(self, event: Gst.Event): + if event.type == Gst.EventType.EOS: + self.logger.info('Got End-Of-Stream event') + self.send_eos() + + elif event.type == Gst.EventType.TAG: + tag_list: Gst.TagList = event.parse_tag() + has_location, location = tag_list.get_string(Gst.TAG_LOCATION) + if has_location: + self.logger.info('Set location to %s', location) + self.location = Path(location) + self.new_loop = True + self.json_metadata = self.read_json_metadata_file( + self.location.parent / f'{self.location.stem}.json' + ) + self.frame_num = 0 + + # Cannot use `super()` since it is `self` + return GstBase.BaseSink.do_event(self, event) + + def build_video_frame( + self, + source_id: str, + pts: int, + dts: Optional[int], + duration: Optional[int], + keyframe: bool, + ) -> VideoFrame: + if pts == Gst.CLOCK_TIME_NONE: + pts = 0 + objects = None + if self.read_metadata and self.json_metadata: + frame_metadata = self.json_metadata[self.frame_num] + self.frame_num += 1 + objects = frame_metadata['objects'] + + video_frame = VideoFrame( + source_id=source_id, + framerate=self.frame_params.framerate, + width=self.frame_params.width, + height=self.frame_params.height, + codec=self.frame_params.codec_name, + content=VideoFrameContent.external(ExternalFrameType.ZEROMQ.value, None), + keyframe=keyframe, + pts=pts, + dts=dts, + duration=duration, + time_base=DEFAULT_TIME_BASE, + ) + video_frame.add_transformation(self.initial_size_transformation) + if objects: + add_objects_to_video_frame(video_frame, objects) + if self.location: + video_frame.set_persistent_attribute( + namespace=DEFAULT_NAMESPACE, + name='location', + values=[AttributeValue.string(str(self.location))], + ) + + return video_frame + + def read_json_metadata_file(self, location: Path): + json_metadata = None + if self.read_metadata: + if location.is_file(): + with open(location, 'r') as fp: + json_metadata = list( + map( + lambda x: x['metadata'], + filter( + is_videoframe_metadata, + map(json.loads, splitfile(fp, format='json')), + ), + ) + ) + else: + self.logger.warning('JSON file `%s` not found', location.absolute()) + return json_metadata + + def send_eos(self): + self.logger.info('Sending serialized EOS message') + for source_id in self.source_ids: + self.send_message_to_zmq( + source_id, + EndOfStream(source_id).to_message(), + ) + self.stream_in_progress = False + + def send_message_to_zmq( + self, + source_id: str, + message: Message, + content: bytes = b'', + ) -> bool: + try: + send_result = self.writer.send_message(source_id, message, content) + if isinstance(send_result, (WriterResultAck, WriterResultSuccess)): + return True + error = f'Failed to send message to ZeroMQ: {send_result}' + except Exception as exc: + error = f'Failed to send message to ZeroMQ: {exc}' + + self.logger.error(error) + frame = inspect.currentframe() + gst_post_stream_failed_error( + gst_element=self, + frame=frame, + file_path=__file__, + text=error, + ) + return False + # register plugin GObject.type_register(ZeroMQSink) diff --git a/adapters/gst/sinks/video_files.py b/adapters/gst/sinks/video_files.py new file mode 100755 index 000000000..4e7b62017 --- /dev/null +++ b/adapters/gst/sinks/video_files.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python3 +import os +import signal +import threading +from datetime import timedelta +from distutils.util import strtobool +from time import time +from typing import Dict, Optional + +from savant_rs.primitives import EndOfStream, VideoFrame + +from adapters.python.shared.config import opt_config +from adapters.python.sinks.chunk_writer import ChunkWriter, CompositeChunkWriter +from adapters.python.sinks.metadata_json import MetadataJsonWriter, Patterns +from gst_plugins.python.savant_rs_video_demux_common import FrameParams, build_caps +from savant.api.parser import convert_ts +from savant.gstreamer import GLib, Gst, GstApp +from savant.gstreamer.codecs import Codec +from savant.utils.logging import get_logger, init_logging +from savant.utils.zeromq import ZeroMQMessage, ZeroMQSource + +LOGGER_NAME = 'adapters.video_files_sink' +DEFAULT_CHUNK_SIZE = 10000 + + +# Modified version of savant.gstreamer.runner.GstPipelineRunner +# to avoid unnecessary dependencies (omegaconf, opencv, etc.) +class GstPipelineRunner: + """Manages running Gstreamer pipeline. + + :param pipeline: GstPipeline or Gst.Pipeline to run. + :param shutdown_timeout: Seconds to wait for pipeline shutdown. + """ + + def __init__( + self, + pipeline: Gst.Pipeline, + shutdown_timeout: int = 10, # make configurable + ): + self._logger = get_logger(f'{LOGGER_NAME}.{self.__class__.__name__}') + self._shutdown_timeout = shutdown_timeout + + # pipeline error storage + self._error: Optional[str] = None + + # running pipeline flag + self._is_running = False + + # pipeline execution start time, will be set on startup + self._start_time = 0.0 + + self._main_loop = GLib.MainLoop() + self._main_loop_thread = threading.Thread(target=self._main_loop_run) + + self._pipeline: Gst.Pipeline = pipeline + + def _main_loop_run(self): + try: + self._main_loop.run() + finally: + self.shutdown() + if self._error: + raise RuntimeError(self._error) + + def startup(self): + """Starts pipeline.""" + self._logger.info('Starting pipeline `%s`...', self._pipeline) + start_time = time() + + bus = self._pipeline.get_bus() + self._logger.debug('Adding signal watch and connecting callbacks...') + bus.add_signal_watch() + bus.connect('message::error', self.on_error) + bus.connect('message::eos', self.on_eos) + bus.connect('message::warning', self.on_warning) + bus.connect('message::state-changed', self.on_state_changed) + + self._logger.debug('Setting pipeline to READY...') + self._pipeline.set_state(Gst.State.READY) + + self._logger.debug('Setting pipeline to PLAYING...') + self._pipeline.set_state(Gst.State.PLAYING) + + self._logger.debug('Starting main loop thread...') + self._is_running = True + self._main_loop_thread.start() + + end_time = time() + exec_seconds = end_time - start_time + self._logger.info( + 'The pipeline is initialized and ready to process data. Initialization took %s.', + timedelta(seconds=exec_seconds), + ) + + self._start_time = end_time + + def shutdown(self): + """Stops pipeline.""" + self._logger.debug('shutdown() called.') + if not self._is_running: + self._logger.debug('The pipeline is shutting down already.') + return + + self._is_running = False + + if self._main_loop.is_running(): + self._logger.debug('Quitting main loop...') + self._main_loop.quit() + + pipeline_state_thread = threading.Thread( + target=self._pipeline.set_state, + args=(Gst.State.NULL,), + daemon=True, + ) + self._logger.debug('Setting pipeline to NULL...') + pipeline_state_thread.start() + try: + pipeline_state_thread.join(self._shutdown_timeout) + except RuntimeError: + self._logger.error('Failed to join thread.') + + exec_seconds = time() - self._start_time + self._logger.info( + 'The pipeline is about to stop. Operation took %s.', + timedelta(seconds=exec_seconds), + ) + if pipeline_state_thread.is_alive(): + self._logger.warning('Pipeline shutdown timeout exceeded.') + + def on_error( # pylint: disable=unused-argument + self, bus: Gst.Bus, message: Gst.Message + ): + """Error callback.""" + err, debug = message.parse_error() + # calling `raise` here causes the pipeline to hang, + # just save message and handle it later + self._error = self.build_error_message(message, err, debug) + self._logger.error(self._error) + self._error += f' Debug info: "{debug}".' + self.shutdown() + + def build_error_message(self, message: Gst.Message, err: GLib.GError, debug: str): + """Build error message.""" + return f'Received error "{err}" from {message.src.name}.' + + def on_eos( # pylint: disable=unused-argument + self, bus: Gst.Bus, message: Gst.Message + ): + """EOS callback.""" + self._logger.info('Received EOS from %s.', message.src.name) + self.shutdown() + + def on_warning( # pylint: disable=unused-argument + self, bus: Gst.Bus, message: Gst.Message + ): + """Warning callback.""" + warn, debug = message.parse_warning() + self._logger.warning('Received warning %s. %s', warn, debug) + + def on_state_changed( # pylint: disable=unused-argument + self, bus: Gst.Bus, msg: Gst.Message + ): + """Change state callback.""" + if not msg.src == self._pipeline: + # not from the pipeline, ignore + return + + old_state, new_state, _ = msg.parse_state_changed() + old_state_name = Gst.Element.state_get_name(old_state) + new_state_name = Gst.Element.state_get_name(new_state) + self._logger.debug( + 'Pipeline state changed from %s to %s.', old_state_name, new_state_name + ) + + if old_state != new_state and os.getenv('GST_DEBUG_DUMP_DOT_DIR'): + file_name = f'pipeline.{old_state_name}_{new_state_name}' + Gst.debug_bin_to_dot_file_with_ts( + self._pipeline, Gst.DebugGraphDetails.ALL, file_name + ) + + @property + def error(self) -> Optional[str]: + """Returns error message.""" + return self._error + + @property + def is_running(self) -> bool: + """Checks if the pipeline is running.""" + return self._is_running + + +class VideoFilesWriter(ChunkWriter): + def __init__( + self, + base_location: str, + source_id: str, + chunk_size: int, + frame_params: FrameParams, + ): + self.base_location = base_location + self.source_id = source_id + self.appsrc: Optional[GstApp.AppSrc] = None + self.pipeline: Optional[Gst.Pipeline] = None + self.runner: Optional[GstPipelineRunner] = None + self.frame_params = frame_params + self.caps = build_caps(frame_params) + super().__init__(chunk_size) + + def _write_video_frame( + self, + frame: VideoFrame, + data: Optional[bytes], + frame_num: int, + ) -> bool: + if not data: + return True + + frame_buf: Gst.Buffer = Gst.Buffer.new_wrapped(data) + frame_buf.pts = convert_ts(frame.pts, frame.time_base) + frame_buf.dts = ( + convert_ts(frame.dts, frame.time_base) + if frame.dts is not None + else Gst.CLOCK_TIME_NONE + ) + frame_buf.duration = ( + convert_ts(frame.duration, frame.time_base) + if frame.duration is not None + else Gst.CLOCK_TIME_NONE + ) + self.logger.debug( + 'Sending frame with pts=%s to %s', + frame.pts, + self.appsrc.get_name(), + ) + + return self.appsrc.push_buffer(frame_buf) == Gst.FlowReturn.OK + + def _write_eos(self, eos: EndOfStream) -> bool: + return self.appsrc.end_of_stream() == Gst.FlowReturn.OK + + def _open(self): + self.logger.debug( + 'Creating sink elements for chunk %s of source %s', + self.chunk_idx, + self.source_id, + ) + appsrc_name = 'appsrc' + filesink_name = 'filesink' + self.pipeline: Gst.Pipeline = Gst.parse_launch( + ' ! '.join( + [ + f'appsrc name={appsrc_name} emit-signals=false format=time', + 'queue', + 'adjust_timestamps', + self.frame_params.codec.value.parser, + 'qtmux fragment-duration=1000 fragment-mode=first-moov-then-finalise', + f'filesink name={filesink_name}', + ] + ), + ) + self.pipeline.set_name(f'video_chunk_{self.source_id}_{self.chunk_idx}') + self.appsrc: GstApp.AppSrc = self.pipeline.get_by_name(appsrc_name) + self.appsrc.set_caps(self.caps) + + filesink: Gst.Element = self.pipeline.get_by_name(filesink_name) + os.makedirs(self.base_location, exist_ok=True) + dst_location = os.path.join(self.base_location, f'{self.chunk_idx:04}.mov') + self.logger.info( + 'Writing video from source %s to file %s', self.source_id, dst_location + ) + filesink.set_property('location', dst_location) + self.logger.debug( + 'Gst pipeline for chunk %s of source %s has been created', + self.chunk_idx, + self.source_id, + ) + + self.runner = GstPipelineRunner(self.pipeline) + self.runner.startup() + self.logger.debug( + 'Gst pipeline for chunk %s of source %s has been started', + self.chunk_idx, + self.source_id, + ) + + def _close(self): + self.logger.debug( + 'Stopping and removing sink elements for chunk %s of source %s', + self.chunk_idx, + self.source_id, + ) + self.logger.debug('Sending EOS to %s', self.appsrc.get_name()) + self.appsrc.end_of_stream() + self.appsrc = None + + +class VideoFilesSink: + def __init__( + self, + location: str, + chunk_size: int, + ): + self.logger = get_logger(f'{LOGGER_NAME}.{self.__class__.__name__}') + self.location = location + self.chunk_size = chunk_size + self.writers: Dict[str, ChunkWriter] = {} + + def write(self, zmq_message: ZeroMQMessage): + message = zmq_message.message + message.validate_seq_id() + if message.is_video_frame(): + return self._write_video_frame( + message.as_video_frame(), + zmq_message.content, + ) + if message.is_end_of_stream(): + return self._write_eos(message.as_end_of_stream()) + self.logger.debug('Unsupported message type for message %r', message) + + def _write_video_frame( + self, video_frame: VideoFrame, content: Optional[bytes] + ) -> bool: + frame_params = FrameParams.from_video_frame(video_frame) + if frame_params.codec not in [Codec.H264, Codec.HEVC, Codec.JPEG, Codec.PNG]: + self.logger.error( + 'Frame %s/%s has unsupported codec %s', + video_frame.source_id, + video_frame.pts, + video_frame.codec, + ) + return False + + writer = self.writers.get(video_frame.source_id) + if writer is None: + base_location = os.path.join(self.location, video_frame.source_id) + json_filename_pattern = f'{Patterns.CHUNK_IDX}.json' + video_writer = VideoFilesWriter( + base_location, + video_frame.source_id, + self.chunk_size, + frame_params, + ) + writer = CompositeChunkWriter( + [ + video_writer, + MetadataJsonWriter( + os.path.join(base_location, json_filename_pattern), + self.chunk_size, + ), + ], + self.chunk_size, + ) + self.writers[video_frame.source_id] = writer + + return writer.write_video_frame(video_frame, content, video_frame.keyframe) + + def _write_eos(self, eos: EndOfStream): + self.logger.info('Received EOS from source %s.', eos.source_id) + writer = self.writers.get(eos.source_id) + if writer is None: + return False + writer.write_eos(eos) + writer.close() + return True + + def terminate(self): + for file_writer in self.writers.values(): + file_writer.close() + + +def main(): + init_logging() + # To gracefully shutdown the adapter on SIGTERM (raise KeyboardInterrupt) + signal.signal(signal.SIGTERM, signal.getsignal(signal.SIGINT)) + + logger = get_logger(LOGGER_NAME) + + dir_location = os.environ['DIR_LOCATION'] + zmq_endpoint = os.environ['ZMQ_ENDPOINT'] + zmq_socket_type = opt_config('ZMQ_TYPE', 'SUB') + zmq_bind = opt_config('ZMQ_BIND', False, strtobool) + chunk_size = opt_config('CHUNK_SIZE', DEFAULT_CHUNK_SIZE, int) + source_id = opt_config('SOURCE_ID') + source_id_prefix = opt_config('SOURCE_ID_PREFIX') + + # possible exceptions will cause app to crash and log error by default + # no need to handle exceptions here + source = ZeroMQSource( + zmq_endpoint, + zmq_socket_type, + zmq_bind, + source_id=source_id, + source_id_prefix=source_id_prefix, + ) + + video_sink = VideoFilesSink(dir_location, chunk_size) + Gst.init(None) + logger.info('Video files sink started') + + try: + source.start() + for zmq_message in source: + video_sink.write(zmq_message) + except KeyboardInterrupt: + logger.info('Interrupted') + finally: + source.terminate() + video_sink.terminate() + + +if __name__ == '__main__': + main() diff --git a/adapters/gst/sinks/video_files.sh b/adapters/gst/sinks/video_files.sh index d15420985..ab88f3c80 100755 --- a/adapters/gst/sinks/video_files.sh +++ b/adapters/gst/sinks/video_files.sh @@ -1,41 +1,4 @@ #!/bin/bash -required_env() { - if [[ -z "${!1}" ]]; then - echo "Environment variable ${1} not set" - exit 1 - fi -} - -required_env DIR_LOCATION -required_env ZMQ_ENDPOINT - -ZMQ_SOCKET_TYPE="${ZMQ_TYPE:="SUB"}" -ZMQ_SOCKET_BIND="${ZMQ_BIND:="false"}" -ZEROMQ_SRC_ARGS=( - socket="${ZMQ_ENDPOINT}" - socket-type="${ZMQ_SOCKET_TYPE}" - bind="${ZMQ_SOCKET_BIND}" -) -if [[ -n "${SOURCE_ID}" ]]; then - ZEROMQ_SRC_ARGS+=(source-id="${SOURCE_ID}") -fi -if [[ -n "${SOURCE_ID_PREFIX}" ]]; then - ZEROMQ_SRC_ARGS+=(source-id-prefix="${SOURCE_ID_PREFIX}") -fi - -CHUNK_SIZE="${CHUNK_SIZE:=10000}" - -handler() { - kill -s SIGINT "${child_pid}" - wait "${child_pid}" -} -trap handler SIGINT SIGTERM - -gst-launch-1.0 \ - zeromq_src "${ZEROMQ_SRC_ARGS[@]}" ! \ - video_files_sink location="${DIR_LOCATION}" chunk-size="${CHUNK_SIZE}" \ - & - -child_pid="$!" -wait "${child_pid}" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +python3 "${SCRIPT_DIR}/video_files.py" diff --git a/adapters/gst/sources/ffmpeg.sh b/adapters/gst/sources/ffmpeg.sh index 2801996b9..d6b6e4dec 100755 --- a/adapters/gst/sources/ffmpeg.sh +++ b/adapters/gst/sources/ffmpeg.sh @@ -31,6 +31,14 @@ fi BUFFER_LEN="${BUFFER_LEN:="50"}" FFMPEG_LOGLEVEL="${FFMPEG_LOGLEVEL:="info"}" USE_ABSOLUTE_TIMESTAMPS="${USE_ABSOLUTE_TIMESTAMPS:="false"}" +SINK_PROPERTIES=( + source-id="${SOURCE_ID}" + socket="${ZMQ_ENDPOINT}" + socket-type="${ZMQ_SOCKET_TYPE}" + bind="${ZMQ_SOCKET_BIND}" + sync="${SYNC_OUTPUT}" + ts-offset="${SYNC_DELAY}" +) FFMPEG_SRC=(ffmpeg_src uri="${URI}" queue-len="${BUFFER_LEN}" loglevel="${FFMPEG_LOGLEVEL}") if [[ -n "${FFMPEG_PARAMS}" ]]; then @@ -49,9 +57,7 @@ if [[ "${USE_ABSOLUTE_TIMESTAMPS,,}" == "true" ]]; then fi PIPELINE+=( fps_meter "${FPS_PERIOD}" output="${FPS_OUTPUT}" ! - savant_rs_serializer source-id="${SOURCE_ID}" ! - zeromq_sink socket="${ZMQ_ENDPOINT}" socket-type="${ZMQ_SOCKET_TYPE}" bind="${ZMQ_SOCKET_BIND}" - sync="${SYNC_OUTPUT}" ts-offset="${SYNC_DELAY}" + zeromq_sink "${SINK_PROPERTIES[@]}" ) handler() { diff --git a/adapters/gst/sources/gige_cam.sh b/adapters/gst/sources/gige_cam.sh index ea63ae069..b2707e071 100755 --- a/adapters/gst/sources/gige_cam.sh +++ b/adapters/gst/sources/gige_cam.sh @@ -54,6 +54,7 @@ if [[ -n "${FEATURES}" ]]; then ADDITIONAL_ARAVISSRC_ARGS+=("features=${FEATURES USE_ABSOLUTE_TIMESTAMPS="${USE_ABSOLUTE_TIMESTAMPS:="false"}" SINK_PROPERTIES=( + source-id="${SOURCE_ID}" socket="${ZMQ_ENDPOINT}" socket-type="${ZMQ_SOCKET_TYPE}" bind="${ZMQ_SOCKET_BIND}" @@ -95,7 +96,6 @@ if [[ "${USE_ABSOLUTE_TIMESTAMPS,,}" == "true" ]]; then fi PIPELINE+=( queue max-size-buffers=1 ! - savant_rs_serializer source-id="${SOURCE_ID}" ! fps_meter "${FPS_PERIOD}" output="${FPS_OUTPUT}" ! zeromq_sink "${SINK_PROPERTIES[@]}" ) diff --git a/adapters/gst/sources/media_files.sh b/adapters/gst/sources/media_files.sh index 2434c66ad..44d4b1d0e 100755 --- a/adapters/gst/sources/media_files.sh +++ b/adapters/gst/sources/media_files.sh @@ -36,20 +36,12 @@ fi SORT_BY_TIME="${SORT_BY_TIME:="false"}" READ_METADATA="${READ_METADATA:="false"}" -SAVANT_RS_SERIALIZER_OPTS=( +USE_ABSOLUTE_TIMESTAMPS="${USE_ABSOLUTE_TIMESTAMPS:="false"}" +SINK_PROPERTIES=( source-id="${SOURCE_ID}" read-metadata="${READ_METADATA}" eos-on-file-end="${EOS_ON_FILE_END}" eos-on-frame-params-change=true -) -if [[ -n "${SHUTDOWN_AUTH}" ]]; then - SAVANT_RS_SERIALIZER_OPTS+=( - shutdown-auth="${SHUTDOWN_AUTH}" - ) -fi - -USE_ABSOLUTE_TIMESTAMPS="${USE_ABSOLUTE_TIMESTAMPS:="false"}" -SINK_PROPERTIES=( socket="${ZMQ_ENDPOINT}" socket-type="${ZMQ_SOCKET_TYPE}" bind="${ZMQ_SOCKET_BIND}" @@ -60,6 +52,11 @@ if [[ -n "${RECEIVE_TIMEOUT_MSECS}" ]]; then else SINK_PROPERTIES+=("receive-timeout=5000") fi +if [[ -n "${SHUTDOWN_AUTH}" ]]; then + SINK_PROPERTIES+=( + shutdown-auth="${SHUTDOWN_AUTH}" + ) +fi PIPELINE=( media_files_src_bin location="${LOCATION}" file-type="${FILE_TYPE}" framerate="${FRAMERATE}" sort-by-time="${SORT_BY_TIME}" ! @@ -74,7 +71,6 @@ if [[ "${USE_ABSOLUTE_TIMESTAMPS,,}" == "true" ]]; then SINK_PROPERTIES+=(ts-offset="-${TS_OFFSET}") fi PIPELINE+=( - savant_rs_serializer "${SAVANT_RS_SERIALIZER_OPTS[@]}" ! zeromq_sink "${SINK_PROPERTIES[@]}" ) diff --git a/adapters/gst/sources/multi_stream.sh b/adapters/gst/sources/multi_stream.sh index 78c7ae3c7..81dc0cdb3 100755 --- a/adapters/gst/sources/multi_stream.sh +++ b/adapters/gst/sources/multi_stream.sh @@ -26,12 +26,6 @@ fi READ_METADATA="${READ_METADATA:="false"}" NUMBER_OF_STREAMS="${NUMBER_OF_STREAMS:=1}" -if [[ -n "${RECEIVE_TIMEOUT}" ]]; then - SENDER_RECEIVE_TIMEOUT="receive-timeout=${RECEIVE_TIMEOUT}" -else - SENDER_RECEIVE_TIMEOUT= -fi - MEDIA_FILES_SRC_BIN_OPTS=( location="${LOCATION}" file-type=video @@ -42,35 +36,32 @@ if [[ -n "${DOWNLOAD_PATH}" ]]; then ) fi -SAVANT_RS_SERIALIZER_OPTS=( + +USE_ABSOLUTE_TIMESTAMPS="${USE_ABSOLUTE_TIMESTAMPS:="false"}" +SINK_PROPERTIES=( source-id="${SOURCE_ID}" read-metadata="${READ_METADATA}" enable-multistream=true number-of-streams="${NUMBER_OF_STREAMS}" + socket="${ZMQ_ENDPOINT}" + socket-type="${ZMQ_SOCKET_TYPE}" + bind="${ZMQ_SOCKET_BIND}" + sync="${SYNC_OUTPUT}" ) +if [[ -n "${RECEIVE_TIMEOUT}" ]]; then + SINK_PROPERTIES+=("receive-timeout=${RECEIVE_TIMEOUT}") +fi if [[ -n "${SOURCE_ID_PATTERN}" ]]; then - SAVANT_RS_SERIALIZER_OPTS+=( + SINK_PROPERTIES+=( source-id-pattern="${SOURCE_ID_PATTERN}" ) fi if [[ -n "${SHUTDOWN_AUTH}" ]]; then - SAVANT_RS_SERIALIZER_OPTS+=( + SINK_PROPERTIES+=( shutdown-auth="${SHUTDOWN_AUTH}" ) fi - -USE_ABSOLUTE_TIMESTAMPS="${USE_ABSOLUTE_TIMESTAMPS:="false"}" -SINK_PROPERTIES=( - socket="${ZMQ_ENDPOINT}" - socket-type="${ZMQ_SOCKET_TYPE}" - bind="${ZMQ_SOCKET_BIND}" - sync="${SYNC_OUTPUT}" -) -if [[ -n "${RECEIVE_TIMEOUT}" ]]; then - SINK_PROPERTIES+=("receive-timeout=${RECEIVE_TIMEOUT}") -fi - PIPELINE=( media_files_src_bin "${MEDIA_FILES_SRC_BIN_OPTS[@]}" ! ) @@ -92,7 +83,6 @@ if [[ "${USE_ABSOLUTE_TIMESTAMPS,,}" == "true" ]]; then SINK_PROPERTIES+=(ts-offset="-${TS_OFFSET}") fi PIPELINE+=( - savant_rs_serializer "${SAVANT_RS_SERIALIZER_OPTS[@]}" ! zeromq_sink "${SINK_PROPERTIES[@]}" ) diff --git a/adapters/gst/sources/rtsp.sh b/adapters/gst/sources/rtsp.sh index c4a728652..29a824630 100755 --- a/adapters/gst/sources/rtsp.sh +++ b/adapters/gst/sources/rtsp.sh @@ -32,6 +32,14 @@ RTSP_TRANSPORT="${RTSP_TRANSPORT:="tcp"}" BUFFER_LEN="${BUFFER_LEN:="50"}" FFMPEG_LOGLEVEL="${FFMPEG_LOGLEVEL:="info"}" USE_ABSOLUTE_TIMESTAMPS="${USE_ABSOLUTE_TIMESTAMPS:="false"}" +SINK_PROPERTIES=( + source-id="${SOURCE_ID}" + socket="${ZMQ_ENDPOINT}" + socket-type="${ZMQ_SOCKET_TYPE}" + bind="${ZMQ_SOCKET_BIND}" + sync="${SYNC_OUTPUT}" + ts-offset="${SYNC_DELAY}" +) PIPELINE=( ffmpeg_src uri="${RTSP_URI}" params="rtsp_transport=${RTSP_TRANSPORT}" @@ -47,9 +55,7 @@ if [[ "${USE_ABSOLUTE_TIMESTAMPS,,}" == "true" ]]; then fi PIPELINE+=( fps_meter "${FPS_PERIOD}" output="${FPS_OUTPUT}" ! - savant_rs_serializer source-id="${SOURCE_ID}" ! - zeromq_sink socket="${ZMQ_ENDPOINT}" socket-type="${ZMQ_SOCKET_TYPE}" bind="${ZMQ_SOCKET_BIND}" - sync="${SYNC_OUTPUT}" ts-offset="${SYNC_DELAY}" + zeromq_sink "${SINK_PROPERTIES[@]}" ) handler() { diff --git a/adapters/gst/sources/usb_cam.sh b/adapters/gst/sources/usb_cam.sh index a7860699b..dc401e8d9 100755 --- a/adapters/gst/sources/usb_cam.sh +++ b/adapters/gst/sources/usb_cam.sh @@ -26,6 +26,7 @@ fi USE_ABSOLUTE_TIMESTAMPS="${USE_ABSOLUTE_TIMESTAMPS:="false"}" SINK_PROPERTIES=( + source-id="${SOURCE_ID}" socket="${ZMQ_ENDPOINT}" socket-type="${ZMQ_SOCKET_TYPE}" bind="${ZMQ_SOCKET_BIND}" @@ -46,7 +47,6 @@ if [[ "${USE_ABSOLUTE_TIMESTAMPS,,}" == "true" ]]; then fi PIPELINE+=( fps_meter "${FPS_PERIOD}" output="${FPS_OUTPUT}" ! - savant_rs_serializer source-id="${SOURCE_ID}" ! zeromq_sink "${SINK_PROPERTIES[@]}" ) diff --git a/adapters/gst/sources/video_loop.sh b/adapters/gst/sources/video_loop.sh index db99cb240..2772599b3 100755 --- a/adapters/gst/sources/video_loop.sh +++ b/adapters/gst/sources/video_loop.sh @@ -30,6 +30,9 @@ EOS_ON_LOOP_END="${EOS_ON_LOOP_END:="false"}" READ_METADATA="${READ_METADATA:="false"}" USE_ABSOLUTE_TIMESTAMPS="${USE_ABSOLUTE_TIMESTAMPS:="false"}" SINK_PROPERTIES=( + source-id="${SOURCE_ID}" + eos-on-loop-end="${EOS_ON_LOOP_END}" + read-metadata="${READ_METADATA}" socket="${ZMQ_ENDPOINT}" socket-type="${ZMQ_SOCKET_TYPE}" bind="${ZMQ_SOCKET_BIND}" @@ -52,8 +55,6 @@ if [[ -n "${LOSS_RATE}" ]]; then PIPELINE+=(identity drop-probability="${LOSS_RATE}" !) fi PIPELINE+=( - savant_rs_serializer source-id="${SOURCE_ID}" eos-on-loop-end="${EOS_ON_LOOP_END}" - read-metadata="${READ_METADATA}" ! zeromq_sink "${SINK_PROPERTIES[@]}" ) diff --git a/adapters/python/bridge/buffer.py b/adapters/python/bridge/buffer.py index 2dba1b666..df97f680c 100644 --- a/adapters/python/bridge/buffer.py +++ b/adapters/python/bridge/buffer.py @@ -2,9 +2,8 @@ import os import signal import time -from typing import Dict, List, Optional +from typing import Dict, Optional, Tuple -import zmq from rocksq.blocking import PersistentQueueWithCapacity from savant_rs.pipeline2 import ( VideoPipeline, @@ -12,19 +11,19 @@ VideoPipelineStagePayloadType, ) from savant_rs.primitives import VideoFrame, VideoFrameContent -from savant_rs.utils.serialization import Message, load_message_from_bytes +from savant_rs.utils.serialization import ( + Message, + load_message_from_bytes, + save_message_to_bytes, +) +from savant_rs.zmq import BlockingWriter, WriterConfigBuilder, WriterSocketType from adapters.python.shared.config import opt_config from adapters.shared.thread import BaseThreadWorker from savant.metrics import Counter, Gauge from savant.metrics.prometheus import BaseMetricsCollector, PrometheusMetricsExporter from savant.utils.logging import get_logger, init_logging -from savant.utils.zeromq import ( - Defaults, - SenderSocketTypes, - ZeroMQSource, - parse_zmq_socket_uri, -) +from savant.utils.zeromq import ZeroMQMessage, ZeroMQSource LOGGER_NAME = 'adapters.buffer' # For each message we need 3 slots: source ID, metadata, frame content @@ -99,33 +98,31 @@ def workload(self): self._zmq_source.start() while self.is_running: try: - message_parts = self._zmq_source.next_message_without_routing_id() - if message_parts is not None: + zmq_message = self._zmq_source.next_message() + if zmq_message is not None: self.logger.debug('Received message from the source ZeroMQ socket') - self.handle_next_message(message_parts) + self.handle_next_message(zmq_message) except Exception as e: self.logger.error('Failed to poll message: %s', e) self.is_running = False break + self._zmq_source.terminate() self.logger.info('Ingress was stopped') - def handle_next_message(self, message_parts: List[bytes]): + def handle_next_message(self, message: ZeroMQMessage): """Handle the next message from the source ZeroMQ socket.""" self._received_messages += 1 - while len(message_parts) < QUEUE_ITEM_SIZE: - message_parts.append(b'') - message: Message = load_message_from_bytes(message_parts[1]) - if message.is_video_frame(): - pushed = self.push_frame(message_parts) + if message.message.is_video_frame(): + pushed = self.push_frame(message) else: - pushed = self.push_service_message(message_parts) + pushed = self.push_service_message(message) if pushed: self._pushed_messages += 1 else: self._dropped_messages += 1 - def push_frame(self, message_parts: List[bytes]) -> bool: + def push_frame(self, message: ZeroMQMessage) -> bool: """Push frame to the buffer.""" buffer_size = self._queue.len @@ -141,15 +138,15 @@ def push_frame(self, message_parts: List[bytes]) -> bool: self.logger.debug('Buffer is full, dropping the frame') return False - self._queue.push(message_parts) + self._push_message(message) self.logger.debug('Pushed frame to the buffer') return True - def push_service_message(self, message_parts: List[bytes]) -> bool: + def push_service_message(self, message: ZeroMQMessage) -> bool: """Push service message to the buffer.""" try: - self._queue.push(message_parts) + self._push_message(message) except Exception as e: if e.args[0] != 'Failed to push item: Queue is full': raise @@ -159,6 +156,14 @@ def push_service_message(self, message_parts: List[bytes]) -> bool: self.logger.debug('Pushed message to the buffer') return True + def _push_message(self, message: ZeroMQMessage): + message_parts = [ + bytes(message.topic), + save_message_to_bytes(message.message), + message.content, + ] + self._queue.push(message_parts) + @property def received_messages(self) -> int: """Number of messages received from the source ZeroMQ socket.""" @@ -201,39 +206,31 @@ def __init__( content=VideoFrameContent.none(), ) - self._socket_type, self._bind, self._socket = parse_zmq_socket_uri( - uri=config.zmq_sink_endpoint, - socket_type_enum=SenderSocketTypes, - socket_type_name=None, - bind=None, - ) + config_builder = WriterConfigBuilder(config.zmq_sink_endpoint) + config = config_builder.build() assert ( - self._socket_type == SenderSocketTypes.DEALER + config.socket_type == WriterSocketType.Dealer ), 'Only DEALER socket type is supported for Egress' - self._zmq_context = zmq.Context() - self._sender: zmq.Socket = self._zmq_context.socket(self._socket_type.value) - self._sender.setsockopt(zmq.SNDHWM, Defaults.SEND_HWM) - self._sender.setsockopt(zmq.RCVTIMEO, Defaults.SENDER_RECEIVE_TIMEOUT) - if self._bind: - self._sender.bind(self._socket) - else: - self._sender.connect(self._socket) + + self._writer = BlockingWriter(config) def workload(self): self.logger.info('Starting Egress') + self._writer.start() while self.is_running: try: - message_parts = self.pop_next_message() - if message_parts is not None: + message = self.pop_next_message() + if message is not None: self.logger.debug('Sending message to the sink ZeroMQ socket') - self._sender.send_multipart(message_parts) + self._writer.send_message(*message) except Exception as e: - self.logger.error('Failed to poll message: %s', e) + self.logger.error('Failed to send message: %s', e) self.is_running = False break + self._writer.shutdown() self.logger.info('Egress was stopped') - def pop_next_message(self) -> Optional[List[bytes]]: + def pop_next_message(self) -> Optional[Tuple[str, Message, bytes]]: """Pop the next message from the buffer. When the buffer is empty, wait until it is not empty and then pop the message. @@ -247,14 +244,15 @@ def pop_next_message(self) -> Optional[List[bytes]]: time.sleep(self._idle_polling_period) return None - message_parts = self._queue.pop(QUEUE_ITEM_SIZE) - if not message_parts[-1]: - message_parts.pop() + topic, message, data = self._queue.pop(QUEUE_ITEM_SIZE) + topic = topic.decode() + message = load_message_from_bytes(message) + self._sent_messages += 1 frame_id = self._pipeline.add_frame('fps-meter', self._video_frame) self._pipeline.delete(frame_id) - return message_parts + return topic, message, data @property def sent_messages(self) -> int: diff --git a/adapters/python/sinks/chunk_writer.py b/adapters/python/sinks/chunk_writer.py index b2adf7f3d..c51c9a055 100644 --- a/adapters/python/sinks/chunk_writer.py +++ b/adapters/python/sinks/chunk_writer.py @@ -1,5 +1,5 @@ import math -from typing import List +from typing import List, Optional from savant_rs.primitives import EndOfStream, VideoFrame @@ -9,8 +9,8 @@ class ChunkWriter: """Writes data in chunks.""" - def __init__(self, chunk_size: int): - self.logger = get_logger(f'adapters.{self.__class__.__name__}') + def __init__(self, chunk_size: int, logger_prefix: str = __name__): + self.logger = get_logger(f'{logger_prefix}.{self.__class__.__name__}') self.chunk_size = chunk_size if chunk_size > 0: self.chunk_size_digits = int(math.log10(chunk_size)) + 1 @@ -23,7 +23,7 @@ def __init__(self, chunk_size: int): def write_video_frame( self, frame: VideoFrame, - data, + content: Optional[bytes], can_start_new_chunk: bool, ) -> bool: if can_start_new_chunk and 0 < self.chunk_size <= self.frames_in_chunk: @@ -31,7 +31,7 @@ def write_video_frame( if not self.opened: self.open() frame_num = self.frames_in_chunk - result = self._write_video_frame(frame, data, frame_num) + result = self._write_video_frame(frame, content, frame_num) self.frames_in_chunk += 1 return result @@ -70,7 +70,12 @@ def _close(self): def _flush(self): pass - def _write_video_frame(self, frame: VideoFrame, data, frame_num: int) -> bool: + def _write_video_frame( + self, + frame: VideoFrame, + content: Optional[bytes], + frame_num: int, + ) -> bool: pass def _write_eos(self, eos: EndOfStream) -> bool: @@ -94,9 +99,14 @@ def _flush(self): for writer in self.writers: writer.flush() - def _write_video_frame(self, frame: VideoFrame, data, frame_num: int) -> bool: + def _write_video_frame( + self, + frame: VideoFrame, + content: Optional[bytes], + frame_num: int, + ) -> bool: for writer in self.writers: - if not writer._write_video_frame(frame, data, frame_num): + if not writer._write_video_frame(frame, content, frame_num): return False return True diff --git a/adapters/python/sinks/image_files.py b/adapters/python/sinks/image_files.py index 52eb2945a..1dc887725 100755 --- a/adapters/python/sinks/image_files.py +++ b/adapters/python/sinks/image_files.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 import os +import signal import traceback from distutils.util import strtobool -from typing import Dict, List +from typing import Dict, Optional from savant_rs.primitives import EndOfStream, VideoFrame -from savant_rs.utils.serialization import Message, load_message_from_bytes from adapters.python.shared.config import opt_config from adapters.python.sinks.chunk_writer import ChunkWriter, CompositeChunkWriter @@ -16,7 +16,7 @@ ) from savant.api.enums import ExternalFrameType from savant.utils.logging import get_logger, init_logging -from savant.utils.zeromq import ZeroMQSource, build_topic_prefix +from savant.utils.zeromq import ZeroMQMessage, ZeroMQSource LOGGER_NAME = 'adapters.image_files_sink' DEFAULT_CHUNK_SIZE = 10000 @@ -26,18 +26,24 @@ class ImageFilesWriter(ChunkWriter): def __init__(self, base_location: str, chunk_size: int): self.base_location = base_location self.chunk_location = None - super().__init__(chunk_size) + super().__init__(chunk_size, logger_prefix=LOGGER_NAME) - def _write_video_frame(self, frame: VideoFrame, data, frame_num: int) -> bool: + def _write_video_frame( + self, + frame: VideoFrame, + content: Optional[bytes], + frame_num: int, + ) -> bool: if frame.content.is_external(): frame_type = ExternalFrameType(frame.content.get_method()) if frame_type != ExternalFrameType.ZEROMQ: self.logger.error('Unsupported frame type "%s".', frame_type.value) return False - if len(data) != 1: - self.logger.error('Data has %s parts, expected 1.', len(data)) + if not content: + self.logger.error( + 'Frame %s/%s has no content data', frame.source_id, frame.pts + ) return False - content = data[0] elif frame.content.is_internal(): content = frame.content.get_data_as_bytes() else: @@ -79,20 +85,25 @@ def __init__( chunk_size: int, skip_frames_without_objects: bool = False, ): - self.logger = get_logger(f'adapters.{self.__class__.__name__}') + self.logger = get_logger(f'{LOGGER_NAME}.{self.__class__.__name__}') self.location = location self.chunk_size = chunk_size self.skip_frames_without_objects = skip_frames_without_objects self.writers: Dict[str, ChunkWriter] = {} - def write(self, message: Message, data: List[bytes]): + def write(self, zmq_message: ZeroMQMessage): + message = zmq_message.message + message.validate_seq_id() if message.is_video_frame(): - return self._write_video_frame(message.as_video_frame(), data) + return self._write_video_frame( + message.as_video_frame(), + zmq_message.content, + ) elif message.is_end_of_stream(): return self._write_eos(message.as_end_of_stream()) self.logger.debug('Unsupported message type for message %r', message) - def _write_video_frame(self, video_frame: VideoFrame, data: List[bytes]) -> bool: + def _write_video_frame(self, video_frame: VideoFrame, content: bytes) -> bool: if self.skip_frames_without_objects and not frame_has_objects(video_frame): self.logger.debug( 'Frame %s from source %s does not have objects. Skipping it.', @@ -118,7 +129,7 @@ def _write_video_frame(self, video_frame: VideoFrame, data: List[bytes]) -> bool self.chunk_size, ) self.writers[video_frame.source_id] = writer - return writer.write_video_frame(video_frame, data, video_frame.keyframe) + return writer.write_video_frame(video_frame, content, video_frame.keyframe) def _write_eos(self, eos: EndOfStream): self.logger.info('Received EOS from source %s.', eos.source_id) @@ -136,6 +147,9 @@ def terminate(self): def main(): init_logging() + # To gracefully shutdown the adapter on SIGTERM (raise KeyboardInterrupt) + signal.signal(signal.SIGTERM, signal.getsignal(signal.SIGINT)) + logger = get_logger(LOGGER_NAME) dir_location = os.environ['DIR_LOCATION'] @@ -146,10 +160,8 @@ def main(): 'SKIP_FRAMES_WITHOUT_OBJECTS', False, strtobool ) chunk_size = opt_config('CHUNK_SIZE', DEFAULT_CHUNK_SIZE, int) - topic_prefix = build_topic_prefix( - source_id=opt_config('SOURCE_ID'), - source_id_prefix=opt_config('SOURCE_ID_PREFIX'), - ) + source_id = opt_config('SOURCE_ID') + source_id_prefix = opt_config('SOURCE_ID_PREFIX') # possible exceptions will cause app to crash and log error by default # no need to handle exceptions here @@ -157,7 +169,8 @@ def main(): zmq_endpoint, zmq_socket_type, zmq_bind, - topic_prefix=topic_prefix, + source_id=source_id, + source_id_prefix=source_id_prefix, ) image_sink = ImageFilesSink(dir_location, chunk_size, skip_frames_without_objects) @@ -165,13 +178,12 @@ def main(): try: source.start() - for message_bin, *data in source: - message = load_message_from_bytes(message_bin) - message.validate_seq_id() - image_sink.write(message, data) + for zmq_message in source: + image_sink.write(zmq_message) except KeyboardInterrupt: logger.info('Interrupted') finally: + source.terminate() image_sink.terminate() diff --git a/adapters/python/sinks/kafka_redis.py b/adapters/python/sinks/kafka_redis.py index 3875aaf4c..bb5be8695 100644 --- a/adapters/python/sinks/kafka_redis.py +++ b/adapters/python/sinks/kafka_redis.py @@ -11,7 +11,7 @@ VideoFrameContent, VideoFrameTranscodingMethod, ) -from savant_rs.utils.serialization import Message, save_message_to_bytes +from savant_rs.utils.serialization import save_message_to_bytes from adapters.python.shared.config import opt_config from adapters.python.shared.kafka_redis import ( @@ -216,13 +216,13 @@ async def process_message(self, result: SinkResult): frame_meta.content = await self.store_frame_content( frame_meta, result.frame_content ) - message = Message.video_frame(frame_meta) + message = frame_meta.to_message() self.count_frame() else: source_id = result.eos.source_id self._logger.debug('Received EOS for source %s', source_id) - message = Message.end_of_stream(result.eos) + message = result.eos.to_message() return message, source_id diff --git a/adapters/python/sinks/metadata_json.py b/adapters/python/sinks/metadata_json.py index bb224cc44..f31869a51 100755 --- a/adapters/python/sinks/metadata_json.py +++ b/adapters/python/sinks/metadata_json.py @@ -2,10 +2,12 @@ import json import os +import signal import traceback from distutils.util import strtobool -from typing import Any, Dict, Optional +from typing import Dict, Optional +from savant_rs.match_query import MatchQuery from savant_rs.primitives import ( Attribute, AttributeValue, @@ -13,15 +15,13 @@ EndOfStream, VideoFrame, ) -from savant_rs.utils.serialization import Message, load_message_from_bytes -from savant_rs.video_object_query import MatchQuery from adapters.python.shared.config import opt_config from adapters.python.sinks.chunk_writer import ChunkWriter from savant.api.constants import DEFAULT_NAMESPACE from savant.api.parser import parse_video_frame from savant.utils.logging import get_logger, init_logging -from savant.utils.zeromq import ZeroMQSource, build_topic_prefix +from savant.utils.zeromq import ZeroMQMessage, ZeroMQSource LOGGER_NAME = 'adapters.metadata_json_sink' @@ -35,9 +35,14 @@ class Patterns: class MetadataJsonWriter(ChunkWriter): def __init__(self, pattern: str, chunk_size: int): self.pattern = pattern - super().__init__(chunk_size) + super().__init__(chunk_size, logger_prefix=LOGGER_NAME) - def _write_video_frame(self, frame: VideoFrame, data: Any, frame_num: int) -> bool: + def _write_video_frame( + self, + frame: VideoFrame, + content: Optional[bytes], + frame_num: int, + ) -> bool: metadata = parse_video_frame(frame) metadata['schema'] = 'VideoFrame' return self._write_meta_to_file(metadata, frame_num) @@ -88,7 +93,7 @@ def __init__( skip_frames_without_objects: bool = True, chunk_size: int = 0, ): - self.logger = get_logger(f'adapters.{self.__class__.__name__}') + self.logger = get_logger(f'{LOGGER_NAME}.{self.__class__.__name__}') self.skip_frames_without_objects = skip_frames_without_objects self.chunk_size = chunk_size self.writers: Dict[str, MetadataJsonWriter] = {} @@ -104,7 +109,9 @@ def terminate(self): for file_writer in self.writers.values(): file_writer.close() - def write(self, message: Message): + def write(self, zmq_message: ZeroMQMessage): + message = zmq_message.message + message.validate_seq_id() if message.is_video_frame(): return self._write_video_frame(message.as_video_frame()) elif message.is_end_of_stream(): @@ -170,7 +177,11 @@ def get_tag_location(frame: VideoFrame): def main(): init_logging() + # To gracefully shutdown the adapter on SIGTERM (raise KeyboardInterrupt) + signal.signal(signal.SIGTERM, signal.getsignal(signal.SIGINT)) + logger = get_logger(LOGGER_NAME) + location = os.environ['LOCATION'] zmq_endpoint = os.environ['ZMQ_ENDPOINT'] zmq_socket_type = opt_config('ZMQ_TYPE', 'SUB') @@ -179,10 +190,8 @@ def main(): 'SKIP_FRAMES_WITHOUT_OBJECTS', False, strtobool ) chunk_size = opt_config('CHUNK_SIZE', 0, int) - topic_prefix = build_topic_prefix( - source_id=opt_config('SOURCE_ID'), - source_id_prefix=opt_config('SOURCE_ID_PREFIX'), - ) + source_id = opt_config('SOURCE_ID') + source_id_prefix = opt_config('SOURCE_ID_PREFIX') # possible exceptions will cause app to crash and log error by default # no need to handle exceptions here @@ -190,7 +199,8 @@ def main(): zmq_endpoint, zmq_socket_type, zmq_bind, - topic_prefix=topic_prefix, + source_id=source_id, + source_id_prefix=source_id_prefix, ) sink = MetadataJsonSink(location, skip_frames_without_objects, chunk_size) @@ -198,13 +208,12 @@ def main(): try: source.start() - for message_bin, *data in source: - message = load_message_from_bytes(message_bin) - message.validate_seq_id() - sink.write(message) + for zmq_message in source: + sink.write(zmq_message) except KeyboardInterrupt: logger.info('Interrupted') finally: + source.terminate() sink.terminate() diff --git a/docker/Dockerfile.adapters-gstreamer b/docker/Dockerfile.adapters-gstreamer index cd0e5c622..5ed1e88c6 100644 --- a/docker/Dockerfile.adapters-gstreamer +++ b/docker/Dockerfile.adapters-gstreamer @@ -104,7 +104,6 @@ COPY gst_plugins gst_plugins COPY adapters/shared adapters/shared COPY adapters/python adapters/python COPY adapters/gst adapters/gst -COPY gst_plugins/python/zeromq_src.py adapters/gst/gst_plugins/python/ COPY gst_plugins/python/logger.py adapters/gst/gst_plugins/python/ ENV GST_PLUGIN_PATH=$PROJECT_PATH/adapters/gst/gst_plugins \ LOGLEVEL=info diff --git a/docs/performance.md b/docs/performance.md index 7f916c9ed..dd56683b4 100644 --- a/docs/performance.md +++ b/docs/performance.md @@ -20,6 +20,7 @@ | [#443](https://github.com/insight-platform/Savant/issues/443) (queue length 10) | 161.77 | 37.95 | | | [#352](https://github.com/insight-platform/Savant/issues/352) | 371.42 | 53.67? | 70.58 | | [#550](https://github.com/insight-platform/Savant/issues/550) | 371.29 | | 71.79 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 376.48 | | 71.95 | ### conditional_video_processing @@ -48,6 +49,7 @@ | [#443](https://github.com/insight-platform/Savant/issues/443) | 130.15 | 28.72 | | | [#352](https://github.com/insight-platform/Savant/issues/352) | 224.22 | 36.45? | 48.92 | | [#550](https://github.com/insight-platform/Savant/issues/550) | 229.22 | | 49.77 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 229.51 | | 50.04 | ### intersection_traffic_meter (yolov8m) @@ -57,12 +59,14 @@ | [#443](https://github.com/insight-platform/Savant/issues/443) | 94.56 | 21.56 | | | [#352](https://github.com/insight-platform/Savant/issues/352) | 264.02 | 32.15? | 41.14 | | [#550](https://github.com/insight-platform/Savant/issues/550) | 268.50 | | 41.11 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 271.13 | | 41.11 | ### fisheye line crossing | Savant ver. | A4000 | Xavier NX | Orin Nano | |---------------------------------------------------------------|-------|-----------|-----------| | [#193](https://github.com/insight-platform/Savant/issues/193) | 86.6 | 23.9 | 33.7 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 86.6 | | 33.71 | ### license_plate_recognition @@ -72,6 +76,7 @@ | [#443](https://github.com/insight-platform/Savant/issues/443) | 92.73 | 25.32 | | | [#352](https://github.com/insight-platform/Savant/issues/352) | 270.99 | 35.90? | 42.24 | | [#550](https://github.com/insight-platform/Savant/issues/550) | 272.64 | | 41.86 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 272.78 | | 42.47 | ### nvidia_car_classification @@ -89,6 +94,7 @@ | [#443](https://github.com/insight-platform/Savant/issues/443) (queue length 10) | 149.89 | 41.79 | | | [#352](https://github.com/insight-platform/Savant/issues/352) | 475.19 | 64.33? | 133.48 | | [#550](https://github.com/insight-platform/Savant/issues/550) | 519.29 | | 144.69 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 530.02 | | 142.89 | ### opencv_cuda_bg_remover_mog2 @@ -106,6 +112,7 @@ | [#443](https://github.com/insight-platform/Savant/issues/443) (queue length 10) | 618.57 | 95.59 | | | [#352](https://github.com/insight-platform/Savant/issues/352) | 740.03 | 102.53? | 130.39 | | [#550](https://github.com/insight-platform/Savant/issues/550) | 749.55 | | 128.76 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 748.87 | | 126.41 | ### opencv_cuda_bg_remover_mog2 (multi-stream) @@ -115,6 +122,8 @@ | [#372](https://github.com/insight-platform/Savant/issues/372) (queue length 10) | 510.97 | 93.20 | | [#443](https://github.com/insight-platform/Savant/issues/443) | 595.70 | 89.13 | | [#443](https://github.com/insight-platform/Savant/issues/443) (queue length 10) | 598.71 | 87.69 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 752.63 | | +| [#612](https://github.com/insight-platform/Savant/issues/612) (queue length 10) | 748.87 | | ### peoplenet_detector @@ -132,12 +141,14 @@ | [#443](https://github.com/insight-platform/Savant/issues/443) (queue length 10) | 110.64 | 26.73 | | | [#352](https://github.com/insight-platform/Savant/issues/352) | 414.61 | 77.47? | 117.53 | | [#550](https://github.com/insight-platform/Savant/issues/550) | 416.32 | | 117.34 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 430.29 | | 117.03 | ### rtdetr | Savant ver. | A4000 | Xavier NX | Orin Nano | |---------------------------------------------------------------|--------|-----------|-----------| | [#558](https://github.com/insight-platform/Savant/issues/558) | 137.41 | | | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 134.47 | | | ### traffic_meter (yolov8m) @@ -155,6 +166,7 @@ | [#443](https://github.com/insight-platform/Savant/issues/443) (queue length 10) | 123.88 | 19.33 | | | [#352](https://github.com/insight-platform/Savant/issues/352) | 256.09 | 54.74? | 41.02 | | [#550](https://github.com/insight-platform/Savant/issues/550) | 255.94 | | 41.01 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 261.18 | | 41.03 | ### yolov8_seg @@ -171,3 +183,10 @@ Note: `yolov8_seg` always has a queue length of 10. | [#443](https://github.com/insight-platform/Savant/issues/443) | 71.28 | 21.95 | | | [#352](https://github.com/insight-platform/Savant/issues/352) | 85.69 | 24.73? | 35.96 | | [#550](https://github.com/insight-platform/Savant/issues/550) | 91.04 | | 36.01 | +| [#612](https://github.com/insight-platform/Savant/issues/612) | 91.65 | | 36.11 | + +### panoptic_driving_perception + +| Savant ver. | A4000 | Xavier NX | Orin Nano | +|---------------------------------------------------------------|-------|-----------|-----------| +| [#612](https://github.com/insight-platform/Savant/issues/612) | 62.78 | | 10.07 | diff --git a/docs/source/savant_101/10_adapters.rst b/docs/source/savant_101/10_adapters.rst index 79d1e717e..1e6227cd4 100644 --- a/docs/source/savant_101/10_adapters.rst +++ b/docs/source/savant_101/10_adapters.rst @@ -743,7 +743,7 @@ Running the adapter with Docker: .. code-block:: bash docker run --rm -it --name sink-meta-json \ - --entrypoint /opt/savant/adapters/gst/sinks/video_files.sh \ + --entrypoint /opt/savant/adapters/gst/sinks/video_files.py \ -e ZMQ_ENDPOINT=sub+connect:ipc:///tmp/zmq-sockets/output-video.ipc \ -e DIR_LOCATION=/path/to/output/%source_id-%src_filename \ -e SKIP_FRAMES_WITHOUT_OBJECTS=False \ diff --git a/gst_plugins/python/savant_rs_video_decode_bin.py b/gst_plugins/python/savant_rs_video_decode_bin.py index e6b2bd56f..eec92dfce 100644 --- a/gst_plugins/python/savant_rs_video_decode_bin.py +++ b/gst_plugins/python/savant_rs_video_decode_bin.py @@ -28,12 +28,6 @@ 'source-timeout', 'source-eviction-interval', 'max-parallel-streams', - 'shutdown-auth', - 'pass-through-mode', - 'ingress-module', - 'ingress-class', - 'ingress-kwargs', - 'ingress-dev-mode', ] } SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES = { @@ -58,10 +52,10 @@ 'VideoPipeline object from savant-rs.', GObject.ParamFlags.READWRITE, ), - 'pipeline-source-stage-name': ( + 'pipeline-demux-stage-name': ( str, - 'Name of the pipeline stage for source.', - 'Name of the pipeline stage for source.', + 'Name of the pipeline stage for demuxer.', + 'Name of the pipeline stage for demuxer.', None, GObject.ParamFlags.READWRITE, ), @@ -188,7 +182,7 @@ def do_get_property(self, prop): return self._max_parallel_streams if prop.name == 'pipeline': return self._video_pipeline - if prop.name == 'pipeline-source-stage-name': + if prop.name == 'pipeline-demux-stage-name': return self._demuxer.get_property('pipeline-stage-name') if prop.name == 'pipeline-decoder-stage-name': return self._pipeline_decoder_stage_name @@ -217,7 +211,7 @@ def do_set_property(self, prop, value): elif prop.name == 'pipeline': self._video_pipeline = value self._demuxer.set_property(prop.name, value) - elif prop.name == 'pipeline-source-stage-name': + elif prop.name == 'pipeline-demux-stage-name': self._demuxer.set_property('pipeline-stage-name', value) elif prop.name == 'pipeline-decoder-stage-name': self._pipeline_decoder_stage_name = value diff --git a/gst_plugins/python/savant_rs_video_demux.py b/gst_plugins/python/savant_rs_video_demux.py index 533c1eb44..5c8cfc26c 100644 --- a/gst_plugins/python/savant_rs_video_demux.py +++ b/gst_plugins/python/savant_rs_video_demux.py @@ -1,33 +1,25 @@ """SavantRsVideoDemux element.""" import inspect -import itertools import time from contextlib import contextmanager from dataclasses import dataclass, field from threading import Event, Lock, Thread from typing import Dict, NamedTuple, Optional, Union +from pygstsavantframemeta import gst_buffer_get_savant_frame_meta from savant_rs.pipeline2 import VideoPipeline -from savant_rs.primitives import ( - EndOfStream, - Shutdown, - VideoFrame, - VideoFrameContent, - VideoFrameTransformation, -) -from savant_rs.utils import PropagatedContext +from savant_rs.primitives import VideoFrame -from gst_plugins.python.pyfunc_common import handle_non_fatal_error, init_pyfunc from gst_plugins.python.savant_rs_video_demux_common import FrameParams, build_caps -from savant.api.enums import ExternalFrameType -from savant.api.parser import convert_ts -from savant.base.frame_filter import DefaultIngressFilter -from savant.base.pyfunc import PyFunc from savant.gstreamer import GObject, Gst from savant.gstreamer.codecs import Codec +from savant.gstreamer.event import parse_savant_eos_event from savant.gstreamer.utils import ( + RequiredPropertyError, + gst_post_library_settings_error, gst_post_stream_demux_error, - load_message_from_gst_buffer, + on_pad_event, + required_property, ) from savant.utils.logging import LoggerMixin @@ -84,71 +76,6 @@ None, GObject.ParamFlags.READWRITE, ), - 'shutdown-auth': ( - str, - 'Authentication key for Shutdown message.', - 'Authentication key for Shutdown message.', - None, - GObject.ParamFlags.READWRITE, - ), - 'max-width': ( - int, - 'Maximum allowable resolution width of the video stream', - 'Maximum allowable resolution width of the video stream', - 0, - GObject.G_MAXINT, - 0, - GObject.ParamFlags.READWRITE, - ), - 'max-height': ( - int, - 'Maximum allowable resolution height of the video stream', - 'Maximum allowable resolution height of the video stream', - 0, - GObject.G_MAXINT, - 0, - GObject.ParamFlags.READWRITE, - ), - 'pass-through-mode': ( - bool, - 'Run module in a pass-through mode.', - 'Run module in a pass-through mode. Store frame content in VideoFrame ' - 'object as an internal VideoFrameContent.', - False, - GObject.ParamFlags.READWRITE, - ), - 'ingress-module': ( - str, - 'Ingress filter python module.', - 'Name or path of the python module where ' - 'the ingress filter class code is located.', - None, - GObject.ParamFlags.READWRITE, - ), - 'ingress-class': ( - str, - 'Ingress filter python class name.', - 'Name of the python class that implements ingress filter.', - None, - GObject.ParamFlags.READWRITE, - ), - 'ingress-kwargs': ( - str, - 'Ingress filter init kwargs.', - 'Keyword arguments for ingress filter initialization.', - None, - GObject.ParamFlags.READWRITE, - ), - 'ingress-dev-mode': ( - bool, - 'Ingress filter dev mode flag.', - ( - 'Whether to monitor the ingress filter source file changes at runtime' - ' and reload the pyfunc objects when necessary.' - ), - False, - GObject.ParamFlags.READWRITE, - ), } @@ -173,32 +100,16 @@ def lock(self): class FrameInfo(NamedTuple): + idx: int params: FrameParams - pts: int - dts: int - duration: int video_frame: VideoFrame @staticmethod - def build(video_frame: VideoFrame) -> 'FrameInfo': + def build(idx: int, video_frame: VideoFrame) -> 'FrameInfo': params = FrameParams.from_video_frame(video_frame) - pts = convert_ts(video_frame.pts, video_frame.time_base) - dts = ( - convert_ts(video_frame.dts, video_frame.time_base) - if video_frame.dts is not None - else Gst.CLOCK_TIME_NONE - ) - duration = ( - convert_ts(video_frame.duration, video_frame.time_base) - if video_frame.duration is not None - else Gst.CLOCK_TIME_NONE - ) - return FrameInfo( + idx=idx, params=params, - pts=pts, - dts=dts, - duration=duration, video_frame=video_frame, ) @@ -245,22 +156,9 @@ def __init__(self): self.source_lock = Lock() self.is_running = False self.expiration_thread = Thread(target=self.eviction_job, daemon=True) - self.store_metadata = False self.max_parallel_streams: int = 0 self.video_pipeline: Optional[VideoPipeline] = None self.pipeline_stage_name: Optional[str] = None - self.shutdown_auth: Optional[str] = None - self.max_width: int = 0 - self.max_height: int = 0 - self.pass_through_mode = False - - self.ingress_module: Optional[str] = None - self.ingress_class_name: Optional[str] = None - self.ingress_kwargs: Optional[str] = None - self.ingress_dev_mode: bool = False - self.ingress_pyfunc: Optional[PyFunc] = None - - self._frame_idx_gen = itertools.count() self.sink_pad: Gst.Pad = Gst.Pad.new_from_template( Gst.PadTemplate.new( @@ -272,6 +170,14 @@ def __init__(self): 'sink', ) self.sink_pad.set_chain_function(self.handle_buffer) + self.sink_pad.add_probe( + Gst.PadProbeType.EVENT_DOWNSTREAM, + on_pad_event, + { + Gst.EventType.CUSTOM_DOWNSTREAM: self.on_savant_eos_event, + Gst.EventType.EOS: self.on_eos, + }, + ) assert self.add_pad(self.sink_pad), 'Failed to add sink pad.' def do_state_changed(self, old: Gst.State, new: Gst.State, pending: Gst.State): @@ -288,19 +194,14 @@ def do_state_changed(self, old: Gst.State, new: Gst.State, pending: Gst.State): self.expiration_thread.start() if old == Gst.State.NULL and new == Gst.State.READY: - if self.ingress_module and self.ingress_class_name: - self.ingress_pyfunc = init_pyfunc( - self, - self.logger, - self.ingress_module, - self.ingress_class_name, - self.ingress_kwargs, - self.ingress_dev_mode, - ) - else: - # for AO RTSP - self.logger.debug('Ingress filter is not set, using default one.') - self.ingress_pyfunc = DefaultIngressFilter() + try: + required_property('pipeline', self.video_pipeline) + required_property('pipeline-stage-name', self.pipeline_stage_name) + except RequiredPropertyError as exc: + self.logger.exception('Failed to start element: %s', exc, exc_info=True) + frame = inspect.currentframe() + gst_post_library_settings_error(self, frame, __file__, text=exc.args[0]) + return def do_get_property(self, prop): """Get property callback.""" @@ -316,22 +217,6 @@ def do_get_property(self, prop): return self.video_pipeline if prop.name == 'pipeline-stage-name': return self.pipeline_stage_name - if prop.name == 'shutdown-auth': - return self.shutdown_auth - if prop.name == 'max-width': - return self.max_width - if prop.name == 'max-height': - return self.max_height - if prop.name == 'pass-through-mode': - return self.pass_through_mode - if prop.name == 'ingress-module': - return self.ingress_module - if prop.name == 'ingress-class': - return self.ingress_class_name - if prop.name == 'ingress-kwargs': - return self.ingress_kwargs - if prop.name == 'ingress-dev-mode': - return self.ingress_dev_mode raise AttributeError(f'Unknown property {prop.name}') def do_set_property(self, prop, value): @@ -348,22 +233,6 @@ def do_set_property(self, prop, value): self.video_pipeline = value elif prop.name == 'pipeline-stage-name': self.pipeline_stage_name = value - elif prop.name == 'shutdown-auth': - self.shutdown_auth = value - elif prop.name == 'max-width': - self.max_width = value - elif prop.name == 'max-height': - self.max_height = value - elif prop.name == 'pass-through-mode': - self.pass_through_mode = value - elif prop.name == 'ingress-module': - self.ingress_module = value - elif prop.name == 'ingress-class': - self.ingress_class_name = value - elif prop.name == 'ingress-kwargs': - self.ingress_kwargs = value - elif prop.name == 'ingress-dev-mode': - self.ingress_dev_mode = value else: raise AttributeError(f'Unknown property {prop.name}') @@ -379,56 +248,42 @@ def handle_buffer( buffer.get_size(), buffer.pts, ) - if not self.is_running: - self.logger.info( - 'Demuxer is not running. Skipping buffer with timestamp %s.', + + savant_frame_meta = gst_buffer_get_savant_frame_meta(buffer) + if savant_frame_meta is None: + self.logger.warning( + 'No Savant Frame Metadata found on buffer with PTS %s, skipping.', buffer.pts, ) return Gst.FlowReturn.OK - message = load_message_from_gst_buffer(buffer) - message.validate_seq_id() - # TODO: Pipeline message types might be extended beyond only VideoFrame - # Additional checks for audio/raw_tensors/etc. may be required - if message.is_video_frame(): - result = self.handle_video_frame( - message.as_video_frame(), - message.span_context, - buffer, - ) - elif message.is_end_of_stream(): - result = self.handle_eos(message.as_end_of_stream()) - elif message.is_shutdown(): - result = self.handle_shutdown(message.as_shutdown()) - else: - self.logger.warning('Unsupported message type for message %r', message) - result = Gst.FlowReturn.OK - - return result - - def handle_video_frame( - self, - video_frame: VideoFrame, - span_context: PropagatedContext, - buffer: Gst.Buffer, - ) -> Gst.FlowReturn: - """Handle VideoFrame message.""" - - frame_info = FrameInfo.build(video_frame) + video_frame, _ = self.video_pipeline.get_independent_frame( + savant_frame_meta.idx + ) self.logger.debug( - 'Received frame %s/%s from source %s; frame %s a keyframe', - frame_info.pts, - frame_info.dts, + 'Handling frame %s/%s from source %s; frame %s a keyframe', + buffer.pts, + buffer.dts, video_frame.source_id, 'is' if video_frame.keyframe else 'is not', ) - - if not self._apply_ingress_filter(frame_info): + self.video_pipeline.move_as_is( + self.pipeline_stage_name, [savant_frame_meta.idx] + ) + if not self.is_running: + self.logger.info( + 'Demuxer is not running. Skipping buffer with timestamp %s.', + buffer.pts, + ) + self.video_pipeline.delete(savant_frame_meta.idx) return Gst.FlowReturn.OK + frame_info = FrameInfo.build(savant_frame_meta.idx, video_frame) + with self.source_lock: - res = self._get_source_info(frame_info) + res = self._get_source_info(frame_info, buffer) if not isinstance(res, SourceInfo): + self.video_pipeline.delete(frame_info.idx) return res source_info = res source_info.locked.set() @@ -439,15 +294,13 @@ def handle_video_frame( source_info.src_pad is not None and source_info.params != frame_info.params ): - if self.is_greater_than_max_resolution(frame_info.params): - self.send_eos(source_info) - self.update_frame_params(source_info, frame_info) + self.update_frame_params(source_info, frame_info.params) if source_info.src_pad is not None: - self.check_timestamps(source_info, frame_info) - if frame_info.pts != Gst.CLOCK_TIME_NONE: - source_info.last_pts = frame_info.pts - if frame_info.dts != Gst.CLOCK_TIME_NONE: - source_info.last_dts = frame_info.dts + self.check_timestamps(source_info, buffer) + if buffer.pts != Gst.CLOCK_TIME_NONE: + source_info.last_pts = buffer.pts + if buffer.dts != Gst.CLOCK_TIME_NONE: + source_info.last_dts = buffer.dts if source_info.src_pad is None: if video_frame.keyframe: self.add_source(video_frame.source_id, source_info) @@ -455,75 +308,35 @@ def handle_video_frame( self.logger.warning( 'Frame %s from source %s is not a keyframe, skipping it. ' 'Stream should start with a keyframe.', - frame_info.pts, + buffer.pts, video_frame.source_id, ) + self.video_pipeline.delete(savant_frame_meta.idx) return Gst.FlowReturn.OK - try: - frame_buf = self._build_frame_buffer(video_frame, buffer) - except ValueError as e: - self.is_running = False - error = ( - f'Failed to build buffer for video frame {frame_info.pts} ' - f'from source {video_frame.source_id}: {e}' - ) - self.logger.error(error) - frame = inspect.currentframe() - gst_post_stream_demux_error( - gst_element=self, - frame=frame, - file_path=__file__, - text=error, - ) - return Gst.FlowReturn.ERROR - - result = self._push_frame(source_info, frame_info, frame_buf, span_context) + self.logger.debug( + 'Pushing frame with IDX %s and PTS %s from source %s', + frame_info.idx, + buffer.pts, + video_frame.source_id, + ) + result: Gst.FlowReturn = source_info.src_pad.push(buffer) + if result != Gst.FlowReturn.OK: + self.video_pipeline.delete(savant_frame_meta.idx) self.logger.debug( - 'Frame from source %s with PTS %s was processed.', + 'Frame with PTS %s from source %s has been processed (%s).', + buffer.pts, video_frame.source_id, - frame_info.pts, + result, ) return result - def _apply_ingress_filter(self, frame_info: FrameInfo) -> bool: - try: - if not self.ingress_pyfunc(frame_info.video_frame): - self.logger.debug( - 'Frame %s from source %s didnt pass ingress filter, skipping it.', - frame_info.pts, - frame_info.video_frame.source_id, - ) - return False - - self.logger.debug( - 'Frame %s from source %s passed ingress filter.', - frame_info.pts, - frame_info.video_frame.source_id, - ) - - except Exception as exc: - handle_non_fatal_error( - self, - self.logger, - exc, - f'Error in ingress filter call {self.ingress_pyfunc}', - self.ingress_dev_mode, - ) - if frame_info.video_frame.content.is_none(): - self.logger.debug( - 'Frame %s from source %s has no content, skipping it.', - frame_info.pts, - frame_info.video_frame.source_id, - ) - return False - - return True - def _get_source_info( - self, frame_info: FrameInfo + self, + frame_info: FrameInfo, + buffer: Gst.Buffer, ) -> Union[SourceInfo, Gst.FlowReturn]: source_info: SourceInfo = self.sources.get(frame_info.video_frame.source_id) if source_info is None: @@ -546,14 +359,11 @@ def _get_source_info( ) return Gst.FlowReturn.ERROR - if self.is_greater_than_max_resolution(frame_info.params): - return Gst.FlowReturn.OK - if not frame_info.video_frame.keyframe: self.logger.warning( 'Frame %s from source %s is not a keyframe, skipping it. ' 'Stream should start with a keyframe.', - frame_info.pts, + buffer.pts, frame_info.video_frame.source_id, ) return Gst.FlowReturn.OK @@ -565,127 +375,40 @@ def _get_source_info( return source_info - def _build_frame_buffer( - self, - video_frame: VideoFrame, - buffer: Gst.Buffer, - ) -> Gst.Buffer: - if video_frame.content.is_internal(): - return Gst.Buffer.new_wrapped(video_frame.content.get_data_as_bytes()) - - frame_type = ExternalFrameType(video_frame.content.get_method()) - if frame_type != ExternalFrameType.ZEROMQ: - raise ValueError(f'Unsupported frame type "{frame_type.value}".') - if buffer.n_memory() < 2: - raise ValueError( - f'Buffer has {buffer.n_memory()} regions of memory, expected at least 2.' - ) - - frame_buf: Gst.Buffer = Gst.Buffer.new() - frame_buf.append_memory(buffer.get_memory_range(1, -1)) - - return frame_buf - - def _push_frame( + def on_savant_eos_event( self, - source_info: SourceInfo, - frame_info: FrameInfo, - frame_buf: Gst.Buffer, - span_context, - ) -> Gst.FlowReturn: - frame_idx = self._add_frame_to_pipeline(frame_info.video_frame, span_context) - if self.pass_through_mode and not frame_info.video_frame.content.is_internal(): - content = frame_buf.extract_dup(0, frame_buf.get_size()) - self.logger.debug( - 'Storing content of frame with IDX %s as an internal VideoFrameContent (%s) bytes.', - frame_idx, - len(content), - ) - frame_info.video_frame.content = VideoFrameContent.internal(content) - frame_buf.pts = frame_info.pts - frame_buf.dts = frame_info.dts - frame_buf.duration = frame_info.duration - self.add_frame_meta(frame_idx, frame_buf, frame_info.video_frame) - self.logger.debug( - 'Pushing frame with idx=%s and pts=%s', frame_idx, frame_info.pts - ) - - return source_info.src_pad.push(frame_buf) - - def _add_frame_to_pipeline( - self, - video_frame: VideoFrame, - span_context: PropagatedContext, - ) -> int: - """Add frame to the pipeline and return frame ID.""" - - if self.video_pipeline is not None: - if span_context.as_dict(): - frame_idx = self.video_pipeline.add_frame_with_telemetry( - self.pipeline_stage_name, - video_frame, - span_context.nested_span(self.video_pipeline.root_span_name), - ) - self.logger.debug( - 'Frame with PTS %s from source %s was added to the pipeline ' - 'with telemetry. Frame ID is %s.', - video_frame.pts, - video_frame.source_id, - frame_idx, - ) - else: - frame_idx = self.video_pipeline.add_frame( - self.pipeline_stage_name, - video_frame, - ) - self.logger.debug( - 'Frame with PTS %s from source %s was added to the pipeline. ' - 'Frame ID is %s.', - video_frame.pts, - video_frame.source_id, - frame_idx, - ) - else: - frame_idx = next(self._frame_idx_gen) - self.logger.debug( - 'Pipeline is not set, generated ID for frame with PTS %s from ' - 'source %s is %s.', - video_frame.pts, - video_frame.source_id, - frame_idx, - ) + sink_pad: Gst.Pad, + event: Gst.Event, + ) -> Gst.PadProbeReturn: + """Handle savant-eos event from a sink pad.""" - return frame_idx + self.logger.debug('Got CUSTOM_DOWNSTREAM event from %s', sink_pad.get_name()) + source_id = parse_savant_eos_event(event) + if source_id is None: + return Gst.PadProbeReturn.PASS - def handle_eos(self, eos: EndOfStream) -> Gst.FlowReturn: - """Handle EndOfStream message.""" - self.logger.info('Received EOS from source %s.', eos.source_id) + self.logger.debug('Got savant-eos event for source %s', source_id) with self.source_lock: - source_info: SourceInfo = self.sources.get(eos.source_id) + source_info: SourceInfo = self.sources.get(source_id) if source_info is None: - return Gst.FlowReturn.OK - source_info.locked.set() + return Gst.PadProbeReturn.DROP + source_info.timestamp = time.time() with source_info.lock(): if source_info.src_pad is not None: self.send_eos(source_info) with self.source_lock: - del self.sources[eos.source_id] + del self.sources[source_id] - return Gst.FlowReturn.OK + return Gst.PadProbeReturn.DROP - def handle_shutdown(self, shutdown: Shutdown) -> Gst.FlowReturn: - """Handle Shutdown message.""" - if self.shutdown_auth is None: - self.logger.debug('Ignoring shutdown message: shutting down in disabled.') - return Gst.FlowReturn.OK - if shutdown.auth != self.shutdown_auth: - self.logger.debug( - 'Ignoring shutdown message: incorrect authentication key.' - ) - return Gst.FlowReturn.OK + def on_eos(self, pad: Gst.Pad, event: Gst.Event) -> Gst.PadProbeReturn: + """Handle EOS event from a sink pad. + + Emit shutdown signal. + """ - self.logger.info('Received shutdown message.') + self.logger.info('Received EOS.') with self.source_lock: self.is_running = False for source_id, source_info in list(self.sources.items()): @@ -733,31 +456,27 @@ def add_source(self, source_id: str, source_info: SourceInfo): f'Created new src pad for source {source_id}: {source_info.src_pad.name}.' ) - def update_frame_params(self, source_info: SourceInfo, frame_info: FrameInfo): + def update_frame_params(self, source_info: SourceInfo, frame_params: FrameParams): """Handle changed frame parameters on a source.""" - if source_info.params != frame_info.params: + if source_info.params != frame_params: self.logger.info( 'Frame parameters on pad %s was changed from %s to %s', source_info.src_pad.get_name(), source_info.params, - frame_info.params, + frame_params, ) - source_info.params = frame_info.params + source_info.params = frame_params self.send_eos(source_info) return - caps = build_caps(frame_info.params) + caps = build_caps(frame_params) source_info.src_pad.push_event(Gst.Event.new_caps(caps)) self.logger.info( 'Caps on pad %s changed to %s', source_info.src_pad, caps.to_string() ) - def check_timestamps( - self, - source_info: SourceInfo, - frame_info: FrameInfo, - ): + def check_timestamps(self, source_info: SourceInfo, buffer: Gst.Buffer): """Check frame timestamps (PTS and DTS). When timestamps are not monotonous, send EOS to prevent decoder @@ -770,16 +489,16 @@ def check_timestamps( 'Timestamps on source %s updated. PTS: %s -> %s, DTS: %s -> %s', source_info.source_id, source_info.last_pts, - frame_info.pts, + buffer.pts, source_info.last_dts, - frame_info.dts, + buffer.dts, ) reset = False - if frame_info.dts != Gst.CLOCK_TIME_NONE: - reset = frame_info.dts < source_info.last_dts - if frame_info.pts != Gst.CLOCK_TIME_NONE: - if frame_info.dts == Gst.CLOCK_TIME_NONE: - reset = frame_info.pts < source_info.last_pts + if buffer.dts != Gst.CLOCK_TIME_NONE: + reset = buffer.dts < source_info.last_dts + if buffer.pts != Gst.CLOCK_TIME_NONE: + if buffer.dts == Gst.CLOCK_TIME_NONE: + reset = buffer.pts < source_info.last_pts if reset: self.logger.info( @@ -788,28 +507,6 @@ def check_timestamps( ) self.send_eos(source_info) - def is_greater_than_max_resolution(self, video_frame: FrameParams) -> bool: - """Check if the resolution of the incoming stream is greater than the - max allowed resolution. Return True if the resolution is greater than - the max allowed resolution, otherwise False. - """ - if self.max_width and self.max_height: - if ( - int(video_frame.width) > self.max_width - or int(video_frame.height) > self.max_height - ): - self.logger.warning( - f'The resolution of the incoming stream is ' - f'{video_frame.width}x{video_frame.height} and ' - f'treater than the allowed max ' - f'{self.max_width}x' - f'{self.max_height}' - f' resolutions. Terminate. You can override the max allowed ' - f"resolution with 'MAX_RESOLUTION' environment variable." - ) - return True - return False - def send_eos(self, source_info: SourceInfo): """Send EOS event to a src pad.""" self.logger.debug( @@ -852,19 +549,6 @@ def eviction_loop(self): ) time.sleep(self.source_eviction_interval) - def add_frame_meta(self, idx: int, frame_buf: Gst.Buffer, video_frame: VideoFrame): - """Store metadata of a frame.""" - if self.video_pipeline is not None: - from pygstsavantframemeta import gst_buffer_add_savant_frame_meta - - if not video_frame.transformations: - video_frame.add_transformation( - VideoFrameTransformation.initial_size( - video_frame.width, video_frame.height - ) - ) - gst_buffer_add_savant_frame_meta(frame_buf, idx) - # register plugin GObject.type_register(SavantRsVideoDemux) diff --git a/gst_plugins/python/zeromq_source_bin.py b/gst_plugins/python/zeromq_source_bin.py index 0286772d7..9a586c355 100644 --- a/gst_plugins/python/zeromq_source_bin.py +++ b/gst_plugins/python/zeromq_source_bin.py @@ -1,4 +1,8 @@ """ZeroMQ src bin.""" +from typing import Optional + +from savant_rs.pipeline2 import VideoPipeline + from gst_plugins.python.savant_rs_video_decode_bin import ( SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES, SAVANT_RS_VIDEO_DECODE_BIN_SRC_PAD_TEMPLATE, @@ -11,6 +15,53 @@ DEFAULT_INGRESS_QUEUE_LENGTH = 200 DEFAULT_INGRESS_QUEUE_SIZE = 10485760 +NESTED_ZEROMQ_SRC_PROPERTIES = { + k: v + for k, v in ZEROMQ_SRC_PROPERTIES.items() + if k not in ['pipeline', 'pipeline-stage-name'] +} +NESTED_SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES = { + k: v + for k, v in SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES.items() + if k not in ['pipeline'] +} + +ZEROMQ_SOURCE_BIN_PROPERTIES = { + **NESTED_ZEROMQ_SRC_PROPERTIES, + **NESTED_SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES, + 'ingress-queue-length': ( + int, + 'Length of the ingress queue in frames.', + 'Length of the ingress queue in frames (0 - no limit).', + 0, + GObject.G_MAXINT, + DEFAULT_INGRESS_QUEUE_LENGTH, + GObject.ParamFlags.READWRITE, + ), + 'ingress-queue-byte-size': ( + int, + 'Size of the ingress queue in bytes.', + 'Size of the ingress queue in bytes (0 - no limit).', + 0, + GObject.G_MAXINT, + DEFAULT_INGRESS_QUEUE_SIZE, + GObject.ParamFlags.READWRITE, + ), + 'pipeline': ( + object, + 'VideoPipeline object from savant-rs.', + 'VideoPipeline object from savant-rs.', + GObject.ParamFlags.READWRITE, + ), + 'pipeline-source-stage-name': ( + str, + 'Name of the pipeline stage for source.', + 'Name of the pipeline stage for source.', + None, + GObject.ParamFlags.READWRITE, + ), +} + class ZeroMQSourceBin(LoggerMixin, Gst.Bin): """Wrapper for "zeromq_src ! @@ -29,28 +80,7 @@ class ZeroMQSourceBin(LoggerMixin, Gst.Bin): __gsttemplates__ = SAVANT_RS_VIDEO_DECODE_BIN_SRC_PAD_TEMPLATE - __gproperties__ = { - **ZEROMQ_SRC_PROPERTIES, - **SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES, - 'ingress-queue-length': ( - int, - 'Length of the ingress queue in frames.', - 'Length of the ingress queue in frames (0 - no limit).', - 0, - GObject.G_MAXINT, - DEFAULT_INGRESS_QUEUE_LENGTH, - GObject.ParamFlags.READWRITE, - ), - 'ingress-queue-byte-size': ( - int, - 'Size of the ingress queue in bytes.', - 'Size of the ingress queue in bytes (0 - no limit).', - 0, - GObject.G_MAXINT, - DEFAULT_INGRESS_QUEUE_SIZE, - GObject.ParamFlags.READWRITE, - ), - } + __gproperties__ = ZEROMQ_SOURCE_BIN_PROPERTIES __gsignals__ = {'shutdown': (GObject.SignalFlags.RUN_LAST, None, ())} @@ -60,6 +90,7 @@ def __init__(self, *args, **kwargs): # properties self._ingress_queue_length: int = DEFAULT_INGRESS_QUEUE_LENGTH self._ingress_queue_byte_size: int = DEFAULT_INGRESS_QUEUE_SIZE + self._video_pipeline: Optional[VideoPipeline] = None self._source: Gst.Element = Gst.ElementFactory.make('zeromq_src') self.add(self._source) @@ -88,9 +119,13 @@ def do_get_property(self, prop): return self._ingress_queue_length if prop.name == 'ingress-queue-byte-size': return self._ingress_queue_byte_size - if prop.name in ZEROMQ_SRC_PROPERTIES: + if prop.name == 'pipeline': + return self._video_pipeline + if prop.name == 'pipeline-source-stage-name': + return self._source.get_property('pipeline-stage-name') + if prop.name in NESTED_ZEROMQ_SRC_PROPERTIES: return self._source.get_property(prop.name) - if prop.name in SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES: + if prop.name in NESTED_SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES: return self._decodebin.get_property(prop.name) raise AttributeError(f'Unknown property {prop.name}') @@ -107,9 +142,15 @@ def do_set_property(self, prop, value): elif prop.name == 'ingress-queue-byte-size': self._ingress_queue_byte_size = value self._queue.set_property('max-size-bytes', value) - elif prop.name in ZEROMQ_SRC_PROPERTIES: + elif prop.name == 'pipeline': + self._video_pipeline = value + self._source.set_property(prop.name, value) + self._decodebin.set_property(prop.name, value) + elif prop.name == 'pipeline-source-stage-name': + self._source.set_property('pipeline-stage-name', value) + elif prop.name in NESTED_ZEROMQ_SRC_PROPERTIES: self._source.set_property(prop.name, value) - elif prop.name in SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES: + elif prop.name in NESTED_SAVANT_RS_VIDEO_DECODE_BIN_PROPERTIES: self._decodebin.set_property(prop.name, value) else: raise AttributeError(f'Unknown property {prop.name}') diff --git a/gst_plugins/python/zeromq_src.py b/gst_plugins/python/zeromq_src.py index 204363644..3d9b62378 100644 --- a/gst_plugins/python/zeromq_src.py +++ b/gst_plugins/python/zeromq_src.py @@ -1,25 +1,41 @@ """ZeroMQ src.""" import inspect -from typing import Optional +from typing import Optional, Tuple -import zmq +from pygstsavantframemeta import gst_buffer_add_savant_frame_meta +from savant_rs.pipeline2 import VideoPipeline +from savant_rs.primitives import ( + EndOfStream, + Shutdown, + VideoFrame, + VideoFrameContent, + VideoFrameTransformation, +) +from savant_rs.utils import PropagatedContext +from gst_plugins.python.pyfunc_common import handle_non_fatal_error, init_pyfunc from gst_plugins.python.zeromq_properties import ZEROMQ_PROPERTIES, socket_type_property +from savant.api.enums import ExternalFrameType +from savant.api.parser import convert_ts +from savant.base.frame_filter import DefaultIngressFilter +from savant.base.pyfunc import PyFunc from savant.gstreamer import GObject, Gst, GstBase +from savant.gstreamer.event import build_savant_eos_event from savant.gstreamer.utils import ( - gst_buffer_from_list, gst_post_library_settings_error, gst_post_stream_failed_error, + required_property, ) from savant.utils.logging import LoggerMixin from savant.utils.zeromq import ( Defaults, ReceiverSocketTypes, + ZeroMQMessage, ZeroMQSource, - ZMQException, - build_topic_prefix, ) +HandlerResult = Optional[Tuple[Gst.FlowReturn, Optional[Gst.Buffer]]] + ZEROMQ_SRC_PROPERTIES = { **ZEROMQ_PROPERTIES, 'socket-type': socket_type_property(ReceiverSocketTypes), @@ -46,6 +62,84 @@ None, GObject.ParamFlags.READWRITE, ), + 'pipeline': ( + object, + 'VideoPipeline object from savant-rs.', + 'VideoPipeline object from savant-rs.', + GObject.ParamFlags.READWRITE, + ), + 'pipeline-stage-name': ( + str, + 'Name of the pipeline stage.', + 'Name of the pipeline stage.', + None, + GObject.ParamFlags.READWRITE, + ), + 'shutdown-auth': ( + str, + 'Authentication key for Shutdown message.', + 'Authentication key for Shutdown message.', + None, + GObject.ParamFlags.READWRITE, + ), + 'max-width': ( + int, + 'Maximum allowable resolution width of the video stream', + 'Maximum allowable resolution width of the video stream', + 0, + GObject.G_MAXINT, + 0, + GObject.ParamFlags.READWRITE, + ), + 'max-height': ( + int, + 'Maximum allowable resolution height of the video stream', + 'Maximum allowable resolution height of the video stream', + 0, + GObject.G_MAXINT, + 0, + GObject.ParamFlags.READWRITE, + ), + 'pass-through-mode': ( + bool, + 'Run module in a pass-through mode.', + 'Run module in a pass-through mode. Store frame content in VideoFrame ' + 'object as an internal VideoFrameContent.', + False, + GObject.ParamFlags.READWRITE, + ), + 'ingress-module': ( + str, + 'Ingress filter python module.', + 'Name or path of the python module where ' + 'the ingress filter class code is located.', + None, + GObject.ParamFlags.READWRITE, + ), + 'ingress-class': ( + str, + 'Ingress filter python class name.', + 'Name of the python class that implements ingress filter.', + None, + GObject.ParamFlags.READWRITE, + ), + 'ingress-kwargs': ( + str, + 'Ingress filter init kwargs.', + 'Keyword arguments for ingress filter initialization.', + None, + GObject.ParamFlags.READWRITE, + ), + 'ingress-dev-mode': ( + bool, + 'Ingress filter dev mode flag.', + ( + 'Whether to monitor the ingress filter source file changes at runtime' + ' and reload the pyfunc objects when necessary.' + ), + False, + GObject.ParamFlags.READWRITE, + ), } @@ -72,16 +166,28 @@ class ZeromqSrc(LoggerMixin, GstBase.BaseSrc): def __init__(self): GstBase.BaseSrc.__init__(self) + + # properties self.socket: str = None self.socket_type: str = ReceiverSocketTypes.ROUTER.name self.bind: bool = True - self.zmq_context: zmq.Context = None - self.context = None - self.receiver = None self.receive_timeout: int = Defaults.RECEIVE_TIMEOUT self.receive_hwm: int = Defaults.RECEIVE_HWM self.source_id: Optional[str] = None self.source_id_prefix: Optional[str] = None + self.video_pipeline: Optional[VideoPipeline] = None + self.pipeline_stage_name: Optional[str] = None + self.shutdown_auth: Optional[str] = None + self.max_width: int = 0 + self.max_height: int = 0 + self.pass_through_mode = False + + self.ingress_module: Optional[str] = None + self.ingress_class_name: Optional[str] = None + self.ingress_kwargs: Optional[str] = None + self.ingress_dev_mode: bool = False + self.ingress_pyfunc: Optional[PyFunc] = None + self.source: ZeroMQSource = None self.set_live(True) @@ -90,6 +196,7 @@ def do_get_property(self, prop): :param prop: property parameters """ + if prop.name == 'socket': return self.socket if prop.name == 'socket-type': @@ -102,8 +209,30 @@ def do_get_property(self, prop): return self.source_id if prop.name == 'source-id-prefix': return self.source_id_prefix - # if prop.name == 'zmq-context': - # return self.zmq_context + + if prop.name == 'pipeline': + return self.video_pipeline + if prop.name == 'pipeline-stage-name': + return self.pipeline_stage_name + if prop.name == 'shutdown-auth': + return self.shutdown_auth + if prop.name == 'pass-through-mode': + return self.pass_through_mode + + if prop.name == 'max-width': + return self.max_width + if prop.name == 'max-height': + return self.max_height + + if prop.name == 'ingress-module': + return self.ingress_module + if prop.name == 'ingress-class': + return self.ingress_class_name + if prop.name == 'ingress-kwargs': + return self.ingress_kwargs + if prop.name == 'ingress-dev-mode': + return self.ingress_dev_mode + raise AttributeError(f'Unknown property {prop.name}.') def do_set_property(self, prop, value): @@ -125,27 +254,67 @@ def do_set_property(self, prop, value): self.source_id = value elif prop.name == 'source-id-prefix': self.source_id_prefix = value - # elif prop.name == 'zmq-context': - # self.zmq_context = value + + elif prop.name == 'pipeline': + self.video_pipeline = value + elif prop.name == 'pipeline-stage-name': + self.pipeline_stage_name = value + elif prop.name == 'shutdown-auth': + self.shutdown_auth = value + elif prop.name == 'pass-through-mode': + self.pass_through_mode = value + + elif prop.name == 'max-width': + self.max_width = value + elif prop.name == 'max-height': + self.max_height = value + + elif prop.name == 'ingress-module': + self.ingress_module = value + elif prop.name == 'ingress-class': + self.ingress_class_name = value + elif prop.name == 'ingress-kwargs': + self.ingress_kwargs = value + elif prop.name == 'ingress-dev-mode': + self.ingress_dev_mode = value + else: raise AttributeError(f'Unknown property "{prop.name}".') def do_start(self): """Start source.""" - self.logger.debug('Called do_start().') - topic_prefix = build_topic_prefix(self.source_id, self.source_id_prefix) + self.logger.debug('Starting ZeroMQ source') try: + required_property('socket', self.socket) + required_property('pipeline', self.video_pipeline) + required_property('pipeline-stage-name', self.pipeline_stage_name) + + if self.ingress_module and self.ingress_class_name: + self.ingress_pyfunc = init_pyfunc( + self, + self.logger, + self.ingress_module, + self.ingress_class_name, + self.ingress_kwargs, + self.ingress_dev_mode, + ) + else: + # for AO RTSP + self.logger.debug('Ingress filter is not set, using default one.') + self.ingress_pyfunc = DefaultIngressFilter() + self.source = ZeroMQSource( socket=self.socket, socket_type=self.socket_type, bind=self.bind, receive_timeout=self.receive_timeout, receive_hwm=self.receive_hwm, - topic_prefix=topic_prefix, + source_id=self.source_id, + source_id_prefix=self.source_id_prefix, ) - except ZMQException: - error = f'Failed to create ZeroMQ source with socket {self.socket}.' + except Exception as exc: + error = f'Failed to start ZeroMQ source with socket {self.socket}: {exc}.' self.logger.exception(error, exc_info=True) frame = inspect.currentframe() gst_post_library_settings_error(self, frame, __file__, error) @@ -157,8 +326,8 @@ def do_start(self): def start_zero_mq_source(self): try: self.source.start() - except ZMQException: - error = f'Failed to start ZeroMQ source with socket {self.socket}.' + except Exception as exc: + error = f'Failed to start ZeroMQ source with socket {self.socket}: {exc}.' self.logger.exception(error, exc_info=True) frame = inspect.currentframe() gst_post_stream_failed_error( @@ -175,23 +344,240 @@ def start_zero_mq_source(self): def do_create(self, offset: int, size: int, buffer: Gst.Buffer = None): """Create gst buffer.""" - if not self.source.is_alive: + self.logger.debug('Creating next buffer') + + if not self.source.is_started: if not self.start_zero_mq_source(): return Gst.FlowReturn.ERROR - self.logger.debug('Receiving next message') - - message = None - while message is None: + result = None + while result is None: flow_return = self.wait_playing() if flow_return != Gst.FlowReturn.OK: self.logger.info('Returning %s', flow_return) return flow_return, None - message = self.source.next_message() - self.logger.debug('Received message of sizes %s', [len(x) for x in message]) - buffer = gst_buffer_from_list(message) + result = self.try_create() + + return result - return Gst.FlowReturn.OK, buffer + def try_create(self) -> HandlerResult: + zmq_message = self.source.next_message() + if zmq_message is None: + return + self.logger.debug('Received message from topic %s.', zmq_message.topic) + + return self.handle_message(zmq_message) + + def handle_message(self, zmq_message: ZeroMQMessage) -> HandlerResult: + message = zmq_message.message + message.validate_seq_id() + if message.is_video_frame(): + return self.handle_video_frame( + message.as_video_frame(), + message.span_context, + zmq_message.content, + ) + if message.is_end_of_stream(): + return self.handle_eos(message.as_end_of_stream()) + if message.is_shutdown(): + return self.handle_shutdown(message.as_shutdown()) + self.logger.warning('Unsupported message type for message %r', message) + + def handle_video_frame( + self, + video_frame: VideoFrame, + span_context: PropagatedContext, + external_content: bytes, + ) -> HandlerResult: + """Handle VideoFrame message.""" + + frame_pts = convert_ts(video_frame.pts, video_frame.time_base) + frame_dts = ( + convert_ts(video_frame.dts, video_frame.time_base) + if video_frame.dts is not None + else Gst.CLOCK_TIME_NONE + ) + frame_duration = ( + convert_ts(video_frame.duration, video_frame.time_base) + if video_frame.duration is not None + else Gst.CLOCK_TIME_NONE + ) + self.logger.debug( + 'Received frame %s/%s from source %s; frame %s a keyframe', + frame_pts, + frame_dts, + video_frame.source_id, + 'is' if video_frame.keyframe else 'is not', + ) + + if self.is_greater_than_max_resolution(video_frame): + return + + try: + if not self.ingress_pyfunc(video_frame): + self.logger.debug( + 'Frame %s from source %s didnt pass ingress filter, skipping it.', + frame_pts, + video_frame.source_id, + ) + return + + self.logger.debug( + 'Frame %s from source %s passed ingress filter.', + frame_pts, + video_frame.source_id, + ) + except Exception as exc: + handle_non_fatal_error( + self, + self.logger, + exc, + f'Error in ingress filter call {self.ingress_pyfunc}', + self.ingress_dev_mode, + ) + if video_frame.content.is_none(): + self.logger.debug( + 'Frame %s from source %s has no content, skipping it.', + frame_pts, + video_frame.source_id, + ) + return + + try: + frame_buf = self.build_frame_buffer(video_frame, external_content) + except ValueError as e: + error = ( + f'Failed to build buffer for video frame {frame_pts} ' + f'from source {video_frame.source_id}: {e}' + ) + self.logger.error(error) + frame = inspect.currentframe() + gst_post_stream_failed_error( + gst_element=self, + frame=frame, + file_path=__file__, + text=error, + ) + return Gst.FlowReturn.ERROR, None + + if frame_buf is None: + self.logger.debug( + 'Frame %s from source %s has no content, skipping it.', + frame_pts, + video_frame.source_id, + ) + return + + frame_idx = self.add_frame_to_pipeline(video_frame, span_context) + if self.pass_through_mode and not video_frame.content.is_internal(): + self.logger.debug( + 'Storing content of frame with IDX %s as an internal VideoFrameContent (%s) bytes.', + frame_idx, + len(external_content), + ) + video_frame.content = VideoFrameContent.internal(external_content) + frame_buf.pts = frame_pts + frame_buf.dts = frame_dts + frame_buf.duration = frame_duration + self.add_frame_meta(frame_idx, frame_buf, video_frame) + self.logger.debug( + 'Frame with PTS %s from source %s has been processed.', + frame_pts, + video_frame.source_id, + ) + + return Gst.FlowReturn.OK, frame_buf + + def build_frame_buffer( + self, + video_frame: VideoFrame, + external_content: bytes, + ) -> Optional[Gst.Buffer]: + if video_frame.content.is_none(): + return None + + if video_frame.content.is_internal(): + return Gst.Buffer.new_wrapped(video_frame.content.get_data_as_bytes()) + + frame_type = ExternalFrameType(video_frame.content.get_method()) + if frame_type != ExternalFrameType.ZEROMQ: + raise ValueError(f'Unsupported frame type "{frame_type.value}".') + + if not external_content: + return None + + return Gst.Buffer.new_wrapped(external_content) + + def add_frame_to_pipeline( + self, + video_frame: VideoFrame, + span_context: PropagatedContext, + ) -> int: + """Add frame to the pipeline and return frame ID.""" + + if span_context.as_dict(): + frame_idx = self.video_pipeline.add_frame_with_telemetry( + self.pipeline_stage_name, + video_frame, + span_context.nested_span(self.video_pipeline.root_span_name), + ) + self.logger.debug( + 'Frame with PTS %s from source %s was added to the pipeline ' + 'with telemetry. Frame ID is %s.', + video_frame.pts, + video_frame.source_id, + frame_idx, + ) + else: + frame_idx = self.video_pipeline.add_frame( + self.pipeline_stage_name, + video_frame, + ) + self.logger.debug( + 'Frame with PTS %s from source %s was added to the pipeline. ' + 'Frame ID is %s.', + video_frame.pts, + video_frame.source_id, + frame_idx, + ) + + return frame_idx + + def add_frame_meta(self, idx: int, frame_buf: Gst.Buffer, video_frame: VideoFrame): + """Store metadata of a frame.""" + + if not video_frame.transformations: + video_frame.add_transformation( + VideoFrameTransformation.initial_size( + video_frame.width, video_frame.height + ) + ) + gst_buffer_add_savant_frame_meta(frame_buf, idx) + + def handle_eos(self, eos: EndOfStream) -> HandlerResult: + """Handle EndOfStream message.""" + + self.logger.info('Received EOS from source %s.', eos.source_id) + savant_eos_event = build_savant_eos_event(eos.source_id) + if not self.srcpad.push_event(savant_eos_event): + self.logger.error('Failed to push savant-eos event to the pipeline') + return Gst.FlowReturn.ERROR, None + + def handle_shutdown(self, shutdown: Shutdown) -> HandlerResult: + """Handle Shutdown message.""" + + if self.shutdown_auth is None: + self.logger.debug('Ignoring shutdown message: shutting down in disabled.') + return + if shutdown.auth != self.shutdown_auth: + self.logger.debug( + 'Ignoring shutdown message: incorrect authentication key.' + ) + return + + self.logger.info('Received shutdown message: sending EOS.') + self.srcpad.push_event(Gst.Event.new_eos()) + return Gst.FlowReturn.EOS, None def do_stop(self): """Gst src stop callback.""" @@ -202,6 +588,29 @@ def do_is_seekable(self): """Check if the source can seek.""" return False + def is_greater_than_max_resolution(self, video_frame: VideoFrame) -> bool: + """Check if the resolution of the incoming stream is greater than the + max allowed resolution. Return True if the resolution is greater than + the max allowed resolution, otherwise False. + """ + + if self.max_width and self.max_height: + if ( + int(video_frame.width) > self.max_width + or int(video_frame.height) > self.max_height + ): + self.logger.warning( + f'The resolution of the incoming stream is ' + f'{video_frame.width}x{video_frame.height} and ' + f'greater than the allowed max ' + f'{self.max_width}x' + f'{self.max_height}' + f' resolutions. Terminate. You can override the max allowed ' + f"resolution with 'MAX_RESOLUTION' environment variable." + ) + return True + return False + # register plugin GObject.type_register(ZeromqSrc) diff --git a/libs/gstsavantframemeta/CMakeLists.txt b/libs/gstsavantframemeta/CMakeLists.txt index 9b1da9b9b..206124ded 100644 --- a/libs/gstsavantframemeta/CMakeLists.txt +++ b/libs/gstsavantframemeta/CMakeLists.txt @@ -31,8 +31,14 @@ if (NOT ${EXIT_CODE} EQUAL 0) "The \"savant_rs\" Python3 package is not installed. Please install it using the following command: \"pip3 install savant_rs\"." ) endif() -file(GLOB SAVANT_RS_LIB_FILES "${SAVANT_RS_LIB_DIR}/savant_rs.cpython*.so") -message(STATUS "Found library for savant-rs: ${SAVANT_RS_LIB_FILES}") +set(SAVANT_RS_CORE_LIB_DIR "${SAVANT_RS_LIB_DIR}.libs") + +file(GLOB SAVANT_RS_CORE_LIB_FILE "${SAVANT_RS_CORE_LIB_DIR}/libsavant_core-*.so") +message(STATUS "Found core library for savant-rs: ${SAVANT_RS_CORE_LIB_FILE}") +file(GLOB SAVANT_RS_CORE_PY_LIB_FILE "${SAVANT_RS_CORE_LIB_DIR}/libsavant_core_py-*.so") +message(STATUS "Found core-py library for savant-rs: ${SAVANT_RS_CORE_PY_LIB_FILE}") +file(GLOB SAVANT_RS_LIB_FILE "${SAVANT_RS_LIB_DIR}/savant_rs.cpython*.so") +message(STATUS "Found library for savant-rs: ${SAVANT_RS_LIB_FILE}") if(NOT DEFINED DeepStream_DIR) set(DeepStream_DIR /opt/nvidia/deepstream/deepstream) @@ -58,6 +64,7 @@ link_directories( ${GSTREAMER_LIBRARY_DIRS} ${CUDA_LIBRARY_DIRS} ${DeepStream_DIR}/lib + ${SAVANT_RS_CORE_LIB_DIR} ${SAVANT_RS_LIB_DIR} ) @@ -69,8 +76,12 @@ file (GLOB PYTHON_FILES "pygstsavantframemeta/*.cpp" "pygstsavantframemeta/*.h") source_group(TREE ${CMAKE_CURRENT_SOURCE_DIR} FILES ${SOURCE_FILES} ${HEADER_FILES} ${PYTHON_FILES} ) +add_library(savant_core UNKNOWN IMPORTED) +set_property(TARGET savant_core PROPERTY IMPORTED_LOCATION ${SAVANT_RS_CORE_LIB_FILE}) +add_library(savant_core_py UNKNOWN IMPORTED) +set_property(TARGET savant_core_py PROPERTY IMPORTED_LOCATION ${SAVANT_RS_CORE_PY_LIB_FILE}) add_library(savant_rs UNKNOWN IMPORTED) -set_property(TARGET savant_rs PROPERTY IMPORTED_LOCATION ${SAVANT_RS_LIB_FILES}) +set_property(TARGET savant_rs PROPERTY IMPORTED_LOCATION ${SAVANT_RS_LIB_FILE}) pybind11_add_module(${python_module_name} SHARED ${SOURCE_FILES} diff --git a/libs/gstsavantframemeta/gstsavantframemeta/CMakeLists.txt b/libs/gstsavantframemeta/gstsavantframemeta/CMakeLists.txt index 390d5dc3f..e1ccd72e8 100644 --- a/libs/gstsavantframemeta/gstsavantframemeta/CMakeLists.txt +++ b/libs/gstsavantframemeta/gstsavantframemeta/CMakeLists.txt @@ -22,8 +22,14 @@ if (NOT ${EXIT_CODE} EQUAL 0) "The \"savant_rs\" Python3 package is not installed. Please install it using the following command: \"pip3 install savant_rs\"." ) endif() -file(GLOB SAVANT_RS_LIB_FILES "${SAVANT_RS_LIB_DIR}/savant_rs.cpython*.so") -message(STATUS "Found library for savant-rs: ${SAVANT_RS_LIB_FILES}") +set(SAVANT_RS_CORE_LIB_DIR "${SAVANT_RS_LIB_DIR}.libs") + +file(GLOB SAVANT_RS_CORE_LIB_FILE "${SAVANT_RS_CORE_LIB_DIR}/libsavant_core-*.so") +message(STATUS "Found core library for savant-rs: ${SAVANT_RS_CORE_LIB_FILE}") +file(GLOB SAVANT_RS_CORE_PY_LIB_FILE "${SAVANT_RS_CORE_LIB_DIR}/libsavant_core_py-*.so") +message(STATUS "Found core-py library for savant-rs: ${SAVANT_RS_CORE_PY_LIB_FILE}") +file(GLOB SAVANT_RS_LIB_FILE "${SAVANT_RS_LIB_DIR}/savant_rs.cpython*.so") +message(STATUS "Found library for savant-rs: ${SAVANT_RS_LIB_FILE}") if(NOT DEFINED DeepStream_DIR) set(DeepStream_DIR /opt/nvidia/deepstream/deepstream) @@ -45,6 +51,7 @@ link_directories( ${GSTREAMER_LIBRARY_DIRS} ${CUDA_LIBRARY_DIRS} ${DeepStream_DIR}/lib + ${SAVANT_RS_CORE_LIB_DIR} ${SAVANT_RS_LIB_DIR} ) @@ -66,8 +73,12 @@ set(SOURCE_FILES source_group(TREE ${CMAKE_CURRENT_SOURCE_DIR} FILES ${SOURCE_FILES}) +add_library(savant_core UNKNOWN IMPORTED) +set_property(TARGET savant_core PROPERTY IMPORTED_LOCATION ${SAVANT_RS_CORE_LIB_FILE}) +add_library(savant_core_py UNKNOWN IMPORTED) +set_property(TARGET savant_core_py PROPERTY IMPORTED_LOCATION ${SAVANT_RS_CORE_PY_LIB_FILE}) add_library(savant_rs UNKNOWN IMPORTED) -set_property(TARGET savant_rs PROPERTY IMPORTED_LOCATION ${SAVANT_RS_LIB_FILES}) +set_property(TARGET savant_rs PROPERTY IMPORTED_LOCATION ${SAVANT_RS_LIB_FILE}) add_library(gstsavantframemeta SHARED ${SOURCE_FILES}) target_link_libraries( @@ -77,6 +88,8 @@ target_link_libraries( ${CUDA_LIBRARIES} nvds_meta nvdsgst_meta + savant_core + savant_core_py savant_rs ) diff --git a/libs/gstsavantframemeta/gstsavantframemeta/src/savantrsprobes.cpp b/libs/gstsavantframemeta/gstsavantframemeta/src/savantrsprobes.cpp index 3049d6235..877c51d18 100644 --- a/libs/gstsavantframemeta/gstsavantframemeta/src/savantrsprobes.cpp +++ b/libs/gstsavantframemeta/gstsavantframemeta/src/savantrsprobes.cpp @@ -8,6 +8,10 @@ #include "gstnvdsmeta.h" #include "nvdsmeta.h" + +typedef struct Arc_Vec_BorrowedVideoObject { +} Arc_Vec_BorrowedVideoObject; + extern "C" { #include "savant_rs.h" } diff --git a/requirements/savant-rs.txt b/requirements/savant-rs.txt index 7d796953d..a030e7ff0 100644 --- a/requirements/savant-rs.txt +++ b/requirements/savant-rs.txt @@ -1 +1 @@ -savant-rs==0.1.85 +savant-rs==0.2.13 diff --git a/samples/animegan/docker-compose.x86.yml b/samples/animegan/docker-compose.x86.yml index 56df360a9..1bfc88657 100644 --- a/samples/animegan/docker-compose.x86.yml +++ b/samples/animegan/docker-compose.x86.yml @@ -48,7 +48,7 @@ services: - ZMQ_ENDPOINT=sub+connect:ipc:///tmp/zmq-sockets/output-video.ipc - DIR_LOCATION=/data - SKIP_FRAMES_WITHOUT_OBJECTS=False - entrypoint: /opt/savant/adapters/gst/sinks/video_files.sh + entrypoint: /opt/savant/adapters/gst/sinks/video_files.py volumes: zmq_sockets: diff --git a/samples/panoptic_driving_perception/run_perf.sh b/samples/panoptic_driving_perception/run_perf.sh index e97af3a57..2c6eb818e 100755 --- a/samples/panoptic_driving_perception/run_perf.sh +++ b/samples/panoptic_driving_perception/run_perf.sh @@ -10,7 +10,7 @@ MODULE_CONFIG=samples/panoptic_driving_perception/module.yml DATA_LOCATION=data/panoptic_driving_perception.mp4 if [ "$(uname -m)" = "aarch64" ]; then - docker compose -f samples/panoptic_driving_perception/docker-compose.l4t.yml build module + docker buildx build --target savant_torch -f ./samples/panoptic_driving_perception/docker/Dockerfile.l4t -t panoptic_driving_perception-module ./samples/panoptic_driving_perception else docker compose -f samples/panoptic_driving_perception/docker-compose.x86.yml build module fi @@ -19,4 +19,4 @@ source samples/assets/run_perf_helper.sh set_source $DATA_LOCATION PERF_CONFIG="${MODULE_CONFIG%.*}_perf.yml" config_perf $MODULE_CONFIG $PERF_CONFIG "${YQ_ARGS[@]}" -./scripts/run_module.py -i panoptic_driving_perception-module $PERF_CONFIG \ No newline at end of file +./scripts/run_module.py -i panoptic_driving_perception-module $PERF_CONFIG diff --git a/savant/VERSION b/savant/VERSION index 080c68b21..4b488a79e 100644 --- a/savant/VERSION +++ b/savant/VERSION @@ -1,2 +1,2 @@ -SAVANT=0.2.9 +SAVANT=0.2.10 DEEPSTREAM=6.3 diff --git a/savant/api/builder.py b/savant/api/builder.py index a4755207d..6b2059f2c 100644 --- a/savant/api/builder.py +++ b/savant/api/builder.py @@ -77,7 +77,7 @@ def add_objects_to_video_frame( for obj_id, obj in enumerate(objects): video_object = build_video_object(obj_id, obj) frame.add_object(video_object, IdCollisionResolutionPolicy.Error) - track_id = video_object.get_track_id() + track_id = video_object.track_id if track_id is not None: parents[(video_object.namespace, video_object.label, track_id)] = obj_id @@ -98,8 +98,14 @@ def build_video_object(obj_id: int, obj: Dict[str, Any]): if attributes is not None: attributes = build_object_attributes(attributes) else: - attributes = {} + attributes = [] bbox = build_bbox(obj['bbox']) + track_id = obj['object_id'] + if track_id == UNTRACKED_OBJECT_ID: + track_id = None + track_box = None + else: + track_box = bbox video_object = VideoObject( id=obj_id, namespace=obj['model_name'], @@ -107,10 +113,9 @@ def build_video_object(obj_id: int, obj: Dict[str, Any]): detection_box=bbox, attributes=attributes, confidence=obj['confidence'], + track_id=track_id, + track_box=track_box, ) - track_id = obj['object_id'] - if track_id != UNTRACKED_OBJECT_ID: - video_object.set_track_info(track_id, bbox) return video_object @@ -126,8 +131,8 @@ def build_bbox(bbox: Dict[str, Any]): def build_object_attributes(attributes: List[Dict[str, Any]]): - return { - (attr['element_name'], attr['name']): Attribute( + return [ + Attribute( namespace=attr['element_name'], name=attr['name'], values=[ @@ -135,7 +140,7 @@ def build_object_attributes(attributes: List[Dict[str, Any]]): ], ) for attr in attributes - } + ] def build_attribute_value(value: Any, confidence: Optional[float] = None): @@ -157,10 +162,8 @@ def add_tags_to_video_frame( tags: Dict[str, Union[bool, int, float, str]], ): for name, value in tags.items(): - frame.set_attribute( - Attribute( - namespace=DEFAULT_NAMESPACE, - name=name, - values=[build_attribute_value(value)], - ) + frame.set_persistent_attribute( + namespace=DEFAULT_NAMESPACE, + name=name, + values=[build_attribute_value(value)], ) diff --git a/savant/api/parser.py b/savant/api/parser.py index 54a55f8ca..7a9e5f487 100644 --- a/savant/api/parser.py +++ b/savant/api/parser.py @@ -1,5 +1,6 @@ from typing import Tuple, Union +from savant_rs.match_query import MatchQuery from savant_rs.primitives import ( Attribute, AttributeValue, @@ -9,7 +10,6 @@ VideoObject, ) from savant_rs.primitives.geometry import BBox, RBBox -from savant_rs.video_object_query import MatchQuery from savant.api.constants import DEFAULT_TIME_BASE from savant.meta.constants import UNTRACKED_OBJECT_ID @@ -90,13 +90,13 @@ def parse_video_objects(frame: VideoFrame): child = objects[obj_id] child['parent_model_name'] = parent.namespace child['parent_label'] = parent.label - child['parent_object_id'] = parent.get_track_id() + child['parent_object_id'] = parent.track_id return list(objects.values()) def parse_video_object(obj: VideoObject): - track_id = obj.get_track_id() + track_id = obj.track_id if track_id is None: track_id = UNTRACKED_OBJECT_ID diff --git a/savant/client/image_source/image_source.py b/savant/client/image_source/image_source.py index 287c37ee6..34377aa3b 100644 --- a/savant/client/image_source/image_source.py +++ b/savant/client/image_source/image_source.py @@ -3,7 +3,6 @@ from typing import Any, BinaryIO, List, Optional, Tuple, TypeVar, Union from savant_rs.primitives import ( - Attribute, AttributeValue, VideoFrame, VideoFrameContent, @@ -120,12 +119,10 @@ def build_frame(self) -> Tuple[VideoFrame, bytes]: time_base=self._time_base, ) if isinstance(self._file, (str, PathLike)): - video_frame.set_attribute( - Attribute( - namespace=DEFAULT_NAMESPACE, - name='location', - values=[AttributeValue.string(str(self._file))], - ) + video_frame.set_persistent_attribute( + namespace=DEFAULT_NAMESPACE, + name='location', + values=[AttributeValue.string(str(self._file))], ) for update in self._updates: video_frame.update(update) diff --git a/savant/client/runner/sink.py b/savant/client/runner/sink.py index e2815ee75..6cbf91d2f 100644 --- a/savant/client/runner/sink.py +++ b/savant/client/runner/sink.py @@ -1,17 +1,16 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Optional +from typing import Optional from savant_rs.primitives import EndOfStream, VideoFrame, VideoFrameContent -from savant_rs.utils.serialization import Message, load_message_from_bytes from savant.client.log_provider import LogProvider from savant.client.runner import LogResult from savant.client.runner.healthcheck import HealthCheck from savant.healthcheck.status import ModuleStatus from savant.utils.logging import get_logger -from savant.utils.zeromq import AsyncZeroMQSource, Defaults, ZeroMQSource +from savant.utils.zeromq import AsyncZeroMQSource, Defaults, ZeroMQMessage, ZeroMQSource logger = get_logger(__name__) @@ -65,8 +64,8 @@ def _build_zeromq_source(self, socket: str, receive_timeout: int, receive_hwm: i def _receive_next_message(self) -> Optional[SinkResult]: pass - def _handle_message(self, message_parts: List[bytes]): - message: Message = load_message_from_bytes(message_parts[0]) + def _handle_message(self, zmq_message: ZeroMQMessage): + message = zmq_message.message message.validate_seq_id() trace_id: Optional[str] = message.span_context.as_dict().get('uber-trace-id') if trace_id is not None: @@ -78,9 +77,8 @@ def _handle_message(self, message_parts: List[bytes]): video_frame.source_id, video_frame.pts, ) - if len(message_parts) > 1: - content = message_parts[1] - else: + content = zmq_message.content + if not zmq_message.content: content = None if video_frame.content.is_internal(): content = video_frame.content.get_data_as_bytes() @@ -143,9 +141,9 @@ def __iter__(self): return self def _receive_next_message(self) -> Optional[SinkResult]: - message_parts = self._source.next_message() - if message_parts is not None: - return self._handle_message(message_parts) + message = self._source.next_message() + if message is not None: + return self._handle_message(message) class AsyncSinkRunner(BaseSinkRunner): @@ -186,6 +184,6 @@ def __aiter__(self): return self async def _receive_next_message(self) -> Optional[SinkResult]: - message_parts = await self._source.next_message() - if message_parts is not None: - return self._handle_message(message_parts) + message = await self._source.next_message() + if message is not None: + return self._handle_message(message) diff --git a/savant/client/runner/source.py b/savant/client/runner/source.py index 0198e54a1..01e36f9e5 100644 --- a/savant/client/runner/source.py +++ b/savant/client/runner/source.py @@ -1,7 +1,7 @@ +import asyncio from dataclasses import dataclass -from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union +from typing import AsyncIterable, Iterable, Optional, Set, Tuple, Union -import zmq from savant_rs.pipeline2 import ( VideoPipeline, VideoPipelineConfiguration, @@ -9,10 +9,12 @@ ) from savant_rs.primitives import EndOfStream, Shutdown, VideoFrame from savant_rs.utils import TelemetrySpan -from savant_rs.utils.serialization import ( - Message, - clear_source_seq_id, - save_message_to_bytes, +from savant_rs.utils.serialization import Message, clear_source_seq_id +from savant_rs.zmq import ( + BlockingWriter, + NonBlockingWriter, + WriterConfig, + WriterConfigBuilder, ) from savant.client.frame_source import FrameSource @@ -21,13 +23,7 @@ from savant.client.runner.healthcheck import HealthCheck from savant.healthcheck.status import ModuleStatus from savant.utils.logging import get_logger -from savant.utils.zeromq import ( - Defaults, - SenderSocketTypes, - async_receive_response, - parse_zmq_socket_uri, - receive_response, -) +from savant.utils.zeromq import Defaults logger = get_logger(__name__) @@ -50,6 +46,8 @@ class SourceResult(LogResult): class SourceRunner: """Sends messages to ZeroMQ socket.""" + _writer: BlockingWriter + def __init__( self, socket: str, @@ -78,23 +76,14 @@ def __init__( if module_health_check_url is not None else None ) - self._socket_type, self._bind, self._socket = parse_zmq_socket_uri( - uri=socket, - socket_type_enum=SenderSocketTypes, - socket_type_name=None, - bind=None, - ) + + config_builder = WriterConfigBuilder(socket) + config_builder.with_receive_timeout(receive_timeout) + config_builder.with_send_hwm(send_hwm) + config = config_builder.build() self._last_send_time = 0 - self._wait_response = self._socket_type == SenderSocketTypes.REQ - self._zmq_context = self._create_zmq_ctx() - self._sender = self._zmq_context.socket(self._socket_type.value) - self._sender.setsockopt(zmq.SNDHWM, self._send_hwm) - self._sender.setsockopt(zmq.RCVTIMEO, self._receive_timeout) - if self._bind: - self._sender.bind(self._socket) - else: - self._sender.connect(self._socket) + self._writer = self._build_zeromq_writer(config) self._pipeline_stage_name = 'savant-client' self._pipeline = VideoPipeline( @@ -104,6 +93,7 @@ def __init__( ) if self._telemetry_enabled: self._pipeline.sampling_period = 1 + self._writer.start() def __call__(self, source: Frame, send_eos: bool = True) -> SourceResult: """Send a single frame to ZeroMQ socket. @@ -128,15 +118,13 @@ def send(self, source: Frame, send_eos: bool = True) -> SourceResult: if self._health_check is not None: self._health_check.wait_module_is_ready() - zmq_topic, serialized_message, content, result = self._prepare_video_frame( - source - ) + zmq_topic, message, content, result = self._prepare_video_frame(source) logger.debug( 'Sending video frame %s/%s.', result.source_id, result.pts, ) - self._send_zmq_message([zmq_topic, serialized_message, content]) + self._send_zmq_message(zmq_topic, message, content) logger.debug('Sent video frame %s/%s.', result.source_id, result.pts) if send_eos: self.send_eos(result.source_id) @@ -180,8 +168,8 @@ def send_eos(self, source_id: str) -> SourceResult: if self._health_check is not None: self._health_check.wait_module_is_ready() - zmq_topic, serialized_message, result = self._prepare_eos(source_id) - self._send_zmq_message([zmq_topic, serialized_message]) + zmq_topic, message, result = self._prepare_eos(source_id) + self._send_zmq_message(zmq_topic, message) logger.debug('Sent EOS for source %s.', source_id) result.status = 'ok' clear_source_seq_id(source_id) @@ -198,32 +186,18 @@ def send_shutdown(self, source_id: str, auth: str) -> SourceResult: if self._health_check is not None: self._health_check.wait_module_is_ready() - zmq_topic, serialized_message, result = self._prepare_shutdown(source_id, auth) - self._send_zmq_message([zmq_topic, serialized_message]) + zmq_topic, message, result = self._prepare_shutdown(source_id, auth) + self._send_zmq_message(zmq_topic, message) logger.debug('Sent Shutdown message for source %s.', source_id) result.status = 'ok' return result - def _send_zmq_message(self, message: List[bytes]): - for retries_left in reversed(range(self._retries)): - try: - self._sender.send_multipart(message) - break - except Exception: - if retries_left == 0: - raise - logger.error( - 'Failed to send message to socket %s. %s retries left.', - self._socket, - retries_left, - exc_info=True, - ) - if self._wait_response: - receive_response(self._sender, self._retries) - - def _create_zmq_ctx(self): - return zmq.Context() + def _send_zmq_message(self, topic: str, message: Message, content: bytes = b''): + self._writer.send_message(topic, message, content) + + def _build_zeromq_writer(self, config: WriterConfig): + return BlockingWriter(config) def _prepare_video_frame(self, source: Frame): if isinstance(source, FrameSource): @@ -233,8 +207,7 @@ def _prepare_video_frame(self, source: Frame): video_frame, content = source logger.debug('Sending video frame from source %s.', video_frame.source_id) frame_id = self._pipeline.add_frame(self._pipeline_stage_name, video_frame) - zmq_topic = f'{video_frame.source_id}/'.encode() - message = Message.video_frame(video_frame) + message = video_frame.to_message() if self._telemetry_enabled: span: TelemetrySpan = self._pipeline.delete(frame_id)[frame_id] message.span_context = span.propagate() @@ -242,11 +215,10 @@ def _prepare_video_frame(self, source: Frame): del span else: trace_id = None - serialized_message = save_message_to_bytes(message) return ( - zmq_topic, - serialized_message, + video_frame.source_id, + message, content, SourceResult( source_id=video_frame.source_id, @@ -259,13 +231,11 @@ def _prepare_video_frame(self, source: Frame): def _prepare_eos(self, source_id: str): logger.debug('Sending EOS for source %s.', source_id) - zmq_topic = f'{source_id}/'.encode() - message = Message.end_of_stream(EndOfStream(source_id)) - serialized_message = save_message_to_bytes(message) + message = EndOfStream(source_id).to_message() return ( - zmq_topic, - serialized_message, + source_id, + message, SourceResult( source_id=source_id, pts=None, @@ -277,13 +247,11 @@ def _prepare_eos(self, source_id: str): def _prepare_shutdown(self, source_id: str, auth: str): logger.debug('Sending Shutdown message for source %s.', source_id) - zmq_topic = f'{source_id}/'.encode() - message = Message.shutdown(Shutdown(auth)) - serialized_message = save_message_to_bytes(message) + message = Shutdown(auth).to_message() return ( - zmq_topic, - serialized_message, + source_id, + message, SourceResult( source_id=source_id, pts=None, @@ -297,6 +265,8 @@ def _prepare_shutdown(self, source_id: str, auth: str): class AsyncSourceRunner(SourceRunner): """Sends messages to ZeroMQ socket asynchronously.""" + _writer: NonBlockingWriter + async def __call__(self, source: Frame, send_eos: bool = True) -> SourceResult: return await self.send(source, send_eos) @@ -304,15 +274,13 @@ async def send(self, source: Frame, send_eos: bool = True) -> SourceResult: if self._health_check is not None: self._health_check.wait_module_is_ready() - zmq_topic, serialized_message, content, result = self._prepare_video_frame( - source - ) + zmq_topic, message, content, result = self._prepare_video_frame(source) logger.debug( 'Sending video frame %s/%s.', result.source_id, result.pts, ) - await self._send_zmq_message([zmq_topic, serialized_message, content]) + await self._send_zmq_message(zmq_topic, message, content) logger.debug('Sent video frame %s/%s.', result.source_id, result.pts) if send_eos: await self.send_eos(result.source_id) @@ -341,8 +309,8 @@ async def send_eos(self, source_id: str) -> SourceResult: if self._health_check is not None: self._health_check.wait_module_is_ready() - zmq_topic, serialized_message, result = self._prepare_eos(source_id) - await self._send_zmq_message([zmq_topic, serialized_message]) + zmq_topic, message, result = self._prepare_eos(source_id) + await self._send_zmq_message(zmq_topic, message) logger.debug('Sent EOS for source %s.', source_id) result.status = 'ok' clear_source_seq_id(source_id) @@ -352,32 +320,24 @@ async def send_shutdown(self, source_id: str, auth: str) -> SourceResult: if self._health_check is not None: self._health_check.wait_module_is_ready() - zmq_topic, serialized_message, result = self._prepare_shutdown(source_id, auth) - await self._send_zmq_message([zmq_topic, serialized_message]) + zmq_topic, message, result = self._prepare_shutdown(source_id, auth) + await self._send_zmq_message(zmq_topic, message) logger.debug('Sent Shutdown message for source %s.', source_id) result.status = 'ok' return result - async def _send_zmq_message(self, message: List[bytes]): - for retries_left in reversed(range(self._retries)): - try: - await self._sender.send_multipart(message) - break - except Exception: - if retries_left == 0: - raise - logger.error( - 'Failed to send message to socket %s. %s retries left.', - self._socket, - retries_left, - exc_info=True, - ) - if self._wait_response: - await async_receive_response(self._sender, self._retries) - - def _create_zmq_ctx(self): - return zmq.asyncio.Context() + async def _send_zmq_message( + self, topic: str, message: Message, content: bytes = b'' + ): + while not self._writer.has_capacity(): + await asyncio.sleep(0.01) # TODO: make configurable + await asyncio.get_running_loop().run_in_executor( + None, self._writer.send_message, topic, message, content + ) + + def _build_zeromq_writer(self, config: WriterConfig): + return NonBlockingWriter(config, 10) # TODO: make configurable async def _send_iter_item(self, source: FrameAndEos, source_ids: Set[str]): if isinstance(source, EndOfStream): diff --git a/savant/deepstream/buffer_processor.py b/savant/deepstream/buffer_processor.py index f8b01f97c..c92a552a8 100644 --- a/savant/deepstream/buffer_processor.py +++ b/savant/deepstream/buffer_processor.py @@ -10,6 +10,7 @@ gst_buffer_get_savant_frame_meta, nvds_frame_meta_get_nvds_savant_frame_meta, ) +from savant_rs.match_query import MatchQuery from savant_rs.pipeline2 import VideoPipeline from savant_rs.primitives import ( VideoFrame, @@ -24,7 +25,6 @@ get_object_id, parse_compound_key, ) -from savant_rs.video_object_query import MatchQuery from savant.api.parser import parse_attribute_value from savant.base.input_preproc import ObjectsPreprocessing @@ -226,7 +226,7 @@ def _prepare_input_frame( bbox.left += self._frame_params.padding.left bbox.top += self._frame_params.padding.top - track_id = obj_meta.get_track_id() + track_id = obj_meta.track_id if track_id is None: track_id = UNTRACKED_OBJECT_ID # create nvds obj meta @@ -498,7 +498,7 @@ def _build_sink_video_frame( with video_frame_span.nested_span('prepare_output'): if self._pass_through_mode: if video_frame.content.is_internal(): - content = video_frame.content.get_data_as_bytes() + content = video_frame.content.get_data() self.logger.debug( 'Pass-through mode is enabled. ' 'Sending frame with IDX %s to sink without any changes. ' diff --git a/savant/deepstream/meta/frame.py b/savant/deepstream/meta/frame.py index 3678f4240..83f0256c8 100644 --- a/savant/deepstream/meta/frame.py +++ b/savant/deepstream/meta/frame.py @@ -3,7 +3,7 @@ from typing import Dict, Iterator, Optional, Union import pyds -from savant_rs.primitives import Attribute, VideoFrame +from savant_rs.primitives import VideoFrame from savant_rs.primitives.geometry import BBox from savant_rs.utils import TelemetrySpan @@ -134,12 +134,10 @@ def set_tag(self, name: str, value: Union[bool, int, float, str]): :param value: Tag value """ - self._video_frame.set_attribute( - Attribute( - namespace=DEFAULT_NAMESPACE, - name=name, - values=[build_attribute_value(value)], - ) + self._video_frame.set_persistent_attribute( + namespace=DEFAULT_NAMESPACE, + name=name, + values=[build_attribute_value(value)], ) @property diff --git a/savant/deepstream/metadata.py b/savant/deepstream/metadata.py index 92d985217..f3c10714f 100644 --- a/savant/deepstream/metadata.py +++ b/savant/deepstream/metadata.py @@ -64,16 +64,22 @@ def nvds_obj_meta_output_converter( object_id = nvds_get_obj_uid(nvds_frame_meta, nvds_obj_meta) + if nvds_obj_meta.object_id == UNTRACKED_OBJECT_ID: + track_id = None + track_box = None + else: + track_id = nvds_obj_meta.object_id + track_box = bbox video_object = VideoObject( id=object_id, namespace=model_name, label=label, detection_box=bbox, - attributes={}, + attributes=[], confidence=confidence, + track_id=track_id, + track_box=track_box, ) - if nvds_obj_meta.object_id != UNTRACKED_OBJECT_ID: - video_object.set_track_info(nvds_obj_meta.object_id, bbox) parent_id = None if ( @@ -89,10 +95,14 @@ def nvds_obj_meta_output_converter( return video_object, parent_id -def nvds_attr_meta_output_converter(attr_meta: AttributeMeta) -> Attribute: +def nvds_attr_meta_output_converter( + attr_meta: AttributeMeta, + is_persistent: bool = True, +) -> Attribute: """Convert attribute meta to savant-rs format. :param attr_meta: Attribute meta. + :param is_persistent: Whether attribute is persistent. :return: Attribute meta in savant-rs format. """ value = build_attribute_value(attr_meta.value, attr_meta.confidence) @@ -100,4 +110,5 @@ def nvds_attr_meta_output_converter(attr_meta: AttributeMeta) -> Attribute: namespace=attr_meta.element_name, name=attr_meta.name, values=[value], + is_persistent=is_persistent, ) diff --git a/savant/deepstream/pipeline.py b/savant/deepstream/pipeline.py index 31cab3cd8..390285c72 100644 --- a/savant/deepstream/pipeline.py +++ b/savant/deepstream/pipeline.py @@ -135,6 +135,7 @@ def __init__( { 'max-parallel-streams': self._max_parallel_streams, 'pipeline-source-stage-name': 'source', + 'pipeline-demux-stage-name': 'source-demuxer', 'pipeline-decoder-stage-name': 'decode', } ) diff --git a/savant/deepstream/utils/pipeline.py b/savant/deepstream/utils/pipeline.py index e71ef4a69..956a00774 100644 --- a/savant/deepstream/utils/pipeline.py +++ b/savant/deepstream/utils/pipeline.py @@ -113,6 +113,7 @@ def get_pipeline_element_stages( def build_pipeline_stages(element_stages: List[Union[str, List[str]]]): pipeline_stages = [ ('source', VideoPipelineStagePayloadType.Frame), + ('source-demuxer', VideoPipelineStagePayloadType.Frame), ('decode', VideoPipelineStagePayloadType.Frame), ('source-convert', VideoPipelineStagePayloadType.Frame), ('source-capsfilter', VideoPipelineStagePayloadType.Frame), diff --git a/savant/gstreamer/event.py b/savant/gstreamer/event.py new file mode 100644 index 000000000..24fa87d34 --- /dev/null +++ b/savant/gstreamer/event.py @@ -0,0 +1,35 @@ +from typing import Optional + +from savant.gstreamer import Gst + +SAVANT_EOS_EVENT_NAME = 'savant-eos' +SAVANT_EOS_EVENT_SOURCE_ID_PROPERTY = 'source-id' + + +def build_savant_eos_event(source_id: str): + """Build a savant-eos event. + + :param source_id: Source ID of the stream. + :returns: The savant-eos event. + """ + + structure: Gst.Structure = Gst.Structure.new_empty(SAVANT_EOS_EVENT_NAME) + structure.set_value(SAVANT_EOS_EVENT_SOURCE_ID_PROPERTY, source_id) + return Gst.Event.new_custom(Gst.EventType.CUSTOM_DOWNSTREAM, structure) + + +def parse_savant_eos_event(event: Gst.Event) -> Optional[str]: + """Parse a savant-eos event. + + :param event: The event to parse. + :returns: Source ID of the stream if the event is a savant-eos event, otherwise None. + """ + + if event.type != Gst.EventType.CUSTOM_DOWNSTREAM: + return None + + struct: Gst.Structure = event.get_structure() + if not struct.has_name(SAVANT_EOS_EVENT_NAME): + return None + + return struct.get_string(SAVANT_EOS_EVENT_SOURCE_ID_PROPERTY) diff --git a/savant/parameter_storage/__init__.py b/savant/parameter_storage/__init__.py index ac2684723..343462ee8 100644 --- a/savant/parameter_storage/__init__.py +++ b/savant/parameter_storage/__init__.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple from omegaconf import DictConfig, ListConfig, OmegaConf -from savant_rs.video_object_query import ( +from savant_rs.match_query import ( register_config_resolver, register_env_resolver, register_etcd_resolver, diff --git a/savant/utils/artist/artist_gpumat.py b/savant/utils/artist/artist_gpumat.py index 1f66d711f..0c36ba6e5 100644 --- a/savant/utils/artist/artist_gpumat.py +++ b/savant/utils/artist/artist_gpumat.py @@ -146,7 +146,7 @@ def add_bbox( return if isinstance(bbox, BBox): - left, top, right, bottom = bbox.visual_box( + left, top, right, bottom = bbox.get_visual_box( PaddingDraw(*padding), border_width, self.max_col, self.max_row ).as_ltrb_int() @@ -327,7 +327,7 @@ def blur( cv2.CV_8UC4, cv2.CV_8UC4, (radius, radius), sigma ) - left, top, width, height = bbox.visual_box( + left, top, width, height = bbox.get_visual_box( PaddingDraw(*padding), 0, self.max_col, self.max_row ).as_ltwh_int() @@ -348,7 +348,7 @@ def copy_frame_region( value in pixels, left, top, right, bottom. :return: GpuMat with the specified region. """ - left, top, width, height = bbox.visual_box( + left, top, width, height = bbox.get_visual_box( PaddingDraw(*padding), 0, self.max_col, self.max_row ).as_ltwh_int() roi_mat = cv2.cuda_GpuMat(self.frame, (left, top, width, height)) diff --git a/savant/utils/re_patterns.py b/savant/utils/re_patterns.py index 1693b20ae..2b0b4c4c2 100644 --- a/savant/utils/re_patterns.py +++ b/savant/utils/re_patterns.py @@ -1,4 +1,3 @@ import re socket_uri_pattern = re.compile('([a-z]+\\+[a-z]+:)?([a-z]+://.*)') -socket_options_pattern = re.compile('([a-z]+)\\+([a-z]+):') diff --git a/savant/utils/sink_factories.py b/savant/utils/sink_factories.py index 486b924cc..efbfcae7c 100644 --- a/savant/utils/sink_factories.py +++ b/savant/utils/sink_factories.py @@ -2,10 +2,14 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union -import zmq from savant_rs.primitives import EndOfStream, VideoFrame, VideoFrameContent from savant_rs.utils import PropagatedContext -from savant_rs.utils.serialization import Message, save_message_to_bytes +from savant_rs.zmq import ( + BlockingWriter, + WriterConfigBuilder, + WriterResultAck, + WriterResultSuccess, +) from savant.api.enums import ExternalFrameType from savant.api.parser import convert_ts @@ -13,14 +17,7 @@ from savant.config.schema import SinkElement from savant.utils.logging import get_logger from savant.utils.registry import Registry -from savant.utils.zeromq import ( - Defaults, - SenderSocketTypes, - create_ipc_socket_dirs, - ipc_socket_chmod, - parse_zmq_socket_uri, - receive_response, -) +from savant.utils.zeromq import Defaults, SenderSocketTypes, get_zmq_socket_uri_options logger = get_logger(__name__) @@ -158,7 +155,7 @@ def __init__( bind: bool = True, send_hwm: int = Defaults.SEND_HWM, receive_timeout: int = Defaults.SENDER_RECEIVE_TIMEOUT, - req_receive_retries: int = Defaults.REQ_RECEIVE_RETRIES, + req_receive_retries: int = Defaults.RECEIVE_RETRIES, set_ipc_socket_permissions: bool = True, ): super().__init__(sink_name, egress_pyfunc) @@ -168,44 +165,41 @@ def __init__( socket_type, bind, ) + socket_type = SenderSocketTypes[socket_type] self.receive_timeout = receive_timeout self.req_receive_retries = req_receive_retries self.set_ipc_socket_permissions = set_ipc_socket_permissions - # might raise exceptions - # will be handled in savant.entrypoint - self.socket_type, self.bind, self.socket = parse_zmq_socket_uri( - uri=socket, - socket_type_name=socket_type, - socket_type_enum=SenderSocketTypes, - bind=bind, - ) + self.socket = socket + if get_zmq_socket_uri_options(socket): + self.socket_type = None + self.bind = None + else: + self.socket_type = socket_type + self.bind = bind self.send_hwm = send_hwm - self.wait_response = self.socket_type == SenderSocketTypes.REQ def get_sink(self) -> SinkCallable: - context = zmq.Context() - output_zmq_socket = context.socket(self.socket_type.value) - output_zmq_socket.setsockopt(zmq.SNDHWM, self.send_hwm) - output_zmq_socket.setsockopt(zmq.RCVTIMEO, self.receive_timeout) - - create_ipc_socket_dirs(self.socket) - - if self.bind: - output_zmq_socket.bind(self.socket) - else: - output_zmq_socket.connect(self.socket) - if self.set_ipc_socket_permissions and self.bind: - ipc_socket_chmod(self.socket) + config_builder = WriterConfigBuilder(self.socket) + config_builder.with_send_hwm(self.send_hwm) + config_builder.with_receive_timeout(self.receive_timeout) + config_builder.with_receive_retries(self.req_receive_retries) + if self.socket_type is not None: + config_builder.with_socket_type(self.socket_type.value) + if self.bind is not None: + config_builder.with_bind(bool(self.bind)) # in case "bind" is "int" + writer = BlockingWriter(config_builder.build()) + writer.start() + + # if self.set_ipc_socket_permissions and self.bind: + # ipc_socket_chmod(self.socket) def send_message( msg: SinkMessage, **kwargs, ): - zmq_topic = f'{msg.source_id}/'.encode() - zmq_message = [zmq_topic] - + send_result = None if isinstance(msg, SinkVideoFrame): frame_pts = convert_ts(msg.video_frame.pts, msg.video_frame.time_base) if self.video_frame_filter(msg.video_frame, frame_pts): @@ -233,26 +227,28 @@ def send_message( ) msg.video_frame.content = VideoFrameContent.none() - message = Message.video_frame(msg.video_frame) + message = msg.video_frame.to_message() if msg.span_context is not None: message.span_context = msg.span_context - zmq_message.append(save_message_to_bytes(message)) - if msg.frame: - zmq_message.append(msg.frame) + send_result = writer.send_message( + msg.source_id, message, msg.frame or b'' + ) elif isinstance(msg, SinkEndOfStream): logger.debug( 'Sending EOS of source "%s" to ZeroMQ sink.', msg.source_id ) - message = Message.end_of_stream(msg.eos) - zmq_message.append(save_message_to_bytes(message)) + message = msg.eos.to_message() + send_result = writer.send_message(msg.source_id, message, b'') else: logger.warning('Unknown message type %s.', type(msg)) return - output_zmq_socket.send_multipart(zmq_message) - if self.wait_response: - resp = receive_response(output_zmq_socket, self.req_receive_retries) - logger.debug( - 'Received %s bytes from socket %s.', len(resp), self.socket + + if not ( + send_result is None + or isinstance(send_result, (WriterResultAck, WriterResultSuccess)) + ): + raise RuntimeError( + f'Failed to send message to ZeroMQ sink: {send_result}' ) return send_message @@ -370,7 +366,7 @@ def sink_factory(sink: Union[SinkElement, List[SinkElement]]) -> SinkCallable: if isinstance(sink, SinkElement): return SINK_REGISTRY.get(sink.element.lower())( sink.full_name, sink.egress_frame_filter, **sink.properties - ).get_sink() + ).get_sink sink_factories = [] for _sink in sink: diff --git a/savant/utils/zeromq.py b/savant/utils/zeromq.py index 70d8466cf..5d947b1e1 100644 --- a/savant/utils/zeromq.py +++ b/savant/utils/zeromq.py @@ -1,54 +1,50 @@ """ZeroMQ utilities.""" -import os +import asyncio from abc import ABC, abstractmethod from enum import Enum -from typing import List, Optional, Tuple, Type, Union -from urllib.parse import urlparse - -import zmq -import zmq.asyncio -from cachetools import LRUCache +from typing import List, NamedTuple, Optional, Union + +from savant_rs.utils.serialization import Message +from savant_rs.zmq import ( + BlockingReader, + NonBlockingReader, + ReaderConfig, + ReaderConfigBuilder, + ReaderResultMessage, + ReaderResultPrefixMismatch, + ReaderResultTimeout, + ReaderSocketType, + TopicPrefixSpec, + WriterSocketType, +) from savant.utils.logging import get_logger -from .re_patterns import socket_options_pattern, socket_uri_pattern +from .re_patterns import socket_uri_pattern logger = get_logger(__name__) -CONFIRMATION_MESSAGE = b'OK' -END_OF_STREAM_MESSAGE = b'EOS' - - -class ZMQException(Exception): - """Error in ZMQ-related code.""" - - -class ZMQSocketEndpointException(ZMQException): - """Error in ZMQ socket endpoint.""" - -class ZMQSocketTypeException(ZMQException): - """Error in ZMQ socket type.""" - - -class ZMQSocketUriParsingException(ZMQException): - """Error in ZMQ socket URI.""" +class ZeroMQMessage(NamedTuple): + topic: List[int] + message: Message + content: Optional[bytes] = None class ReceiverSocketTypes(Enum): """Receiver socket types.""" - SUB = zmq.SUB - REP = zmq.REP - ROUTER = zmq.ROUTER + SUB = ReaderSocketType.Sub + REP = ReaderSocketType.Rep + ROUTER = ReaderSocketType.Router class SenderSocketTypes(Enum): """Sender socket types.""" - PUB = zmq.PUB - REQ = zmq.REQ - DEALER = zmq.DEALER + PUB = WriterSocketType.Pub + REQ = WriterSocketType.Req + DEALER = WriterSocketType.Dealer class Defaults: @@ -56,38 +52,7 @@ class Defaults: SENDER_RECEIVE_TIMEOUT = 5000 RECEIVE_HWM = 50 SEND_HWM = 50 - REQ_RECEIVE_RETRIES = 3 - EOS_CONFIRMATION_RETRIES = 3 - - -def get_socket_endpoint(socket_endpoint: str): - if not isinstance(socket_endpoint, str): - raise ZMQSocketEndpointException( - f'Incorrect socket endpoint: "{socket_endpoint}":' - f'"{type(socket_endpoint)}" is not string.' - ) - return socket_endpoint - - -def get_socket_type( - socket_type_name: str, - socket_type_enum: Union[Type[ReceiverSocketTypes], Type[SenderSocketTypes]], -): - if not isinstance(socket_type_name, str): - raise ZMQSocketTypeException( - f'Incorrect socket_type_name: "{socket_type_name}":' - f'"{type(socket_type_name)}" is not string.' - ) - - socket_type_name = str.upper(socket_type_name) - - try: - return socket_type_enum[socket_type_name] - except KeyError as exc: - raise ZMQSocketTypeException( - f'Incorrect socket type: {socket_type_name} is not one of ' - f'{[socket_type.name for socket_type in socket_type_enum]}.' - ) from exc + RECEIVE_RETRIES = 3 class BaseZeroMQSource(ABC): @@ -98,11 +63,11 @@ class BaseZeroMQSource(ABC): :param bind: zmq socket mode (bind or connect) :param receive_timeout: receive timeout socket option :param receive_hwm: high watermark for inbound messages - :param topic_prefix: filter inbound messages by topic prefix + :param source_id: filter inbound messages by source ID + :param source_id_prefix: filter inbound messages by topic prefix """ - zmq_context: zmq.Context - receiver: zmq.Socket + receiver: Union[BlockingReader, NonBlockingReader] def __init__( self, @@ -111,8 +76,8 @@ def __init__( bind: bool = True, receive_timeout: int = Defaults.RECEIVE_TIMEOUT, receive_hwm: int = Defaults.RECEIVE_HWM, - topic_prefix: Optional[str] = None, - routing_ids_cache_size: int = 1000, + source_id: Optional[str] = None, + source_id_prefix: Optional[str] = None, set_ipc_socket_permissions: bool = True, ): logger.debug( @@ -122,387 +87,140 @@ def __init__( bind, ) - self.topic_prefix = topic_prefix.encode() if topic_prefix else b'' - self.receive_hwm = receive_hwm - self.set_ipc_socket_permissions = set_ipc_socket_permissions - - # might raise exceptions - # will be handled in ZeromqSrc element - # or image_files.py / metadata_json.py Python sinks - self.socket_type, self.bind, self.socket = parse_zmq_socket_uri( - uri=socket, - socket_type_name=socket_type, - socket_type_enum=ReceiverSocketTypes, - bind=bind, - ) + config_builder = ReaderConfigBuilder(socket) + socket_options = get_zmq_socket_uri_options(socket) + if socket_options: + bind = 'bind' in socket_options + else: + config_builder.with_socket_type(ReceiverSocketTypes[socket_type].value) + config_builder.with_bind(bool(bind)) # in case "bind" is "int" + if source_id: + config_builder.with_topic_prefix_spec(TopicPrefixSpec.source_id(source_id)) + elif source_id_prefix: + config_builder.with_topic_prefix_spec( + TopicPrefixSpec.prefix(source_id_prefix) + ) + config_builder.with_receive_hwm(receive_hwm) + config_builder.with_receive_timeout(receive_timeout) + if bind: + # IPC permissions can only be set for bind sockets. + config_builder.with_fix_ipc_permissions(set_ipc_socket_permissions) - self.receive_timeout = receive_timeout - self.routing_id_filter = RoutingIdFilter(routing_ids_cache_size) - self.is_alive = False - self._always_respond = self.socket_type == ReceiverSocketTypes.REP + self.reader = self._create_zmq_reader(config_builder.build()) def start(self): """Start ZeroMQ source.""" - if self.is_alive: + if self.is_started: logger.warning('ZeroMQ source is already started.') return - logger.info( - 'Starting ZMQ source: socket %s, type %s, bind %s.', - self.socket, - self.socket_type, - self.bind, - ) - - self.zmq_context = self._create_zmq_ctx() - self.receiver = self.zmq_context.socket(self.socket_type.value) - self.receiver.setsockopt(zmq.RCVHWM, self.receive_hwm) + logger.info('Starting ZMQ source.') + self.reader.start() - create_ipc_socket_dirs(self.socket) - - if self.bind: - self.receiver.bind(self.socket) - else: - self.receiver.connect(self.socket) - if self.socket_type == ReceiverSocketTypes.SUB: - self.receiver.setsockopt(zmq.SUBSCRIBE, self.topic_prefix) - self.receiver.setsockopt(zmq.RCVTIMEO, self.receive_timeout) - if self.set_ipc_socket_permissions and self.bind: - ipc_socket_chmod(self.socket) - self.is_alive = True + @property + def is_started(self): + return self.reader.is_started() @abstractmethod - def next_message_without_routing_id(self) -> Optional[List[bytes]]: - """Try to receive next message without routing ID but with topic.""" - - @abstractmethod - def next_message(self) -> Optional[List[bytes]]: + def next_message(self) -> Optional[ZeroMQMessage]: """Try to receive next message.""" pass + def _build_result(self, result: ReaderResultMessage) -> Optional[ZeroMQMessage]: + if isinstance(result, ReaderResultMessage): + return ZeroMQMessage( + result.topic, + result.message, + b''.join(result.data(i) for i in range(result.data_len())), + ) + elif isinstance(result, ReaderResultTimeout): + logger.debug('Timeout exceeded when receiving the next frame') + elif isinstance(result, ReaderResultPrefixMismatch): + logger.debug('Skipping message from topic %s', result.topic) + + return None + def terminate(self): """Finish and free zmq socket.""" - if not self.is_alive: + if not self.is_started: logger.warning('ZeroMQ source is not started.') return - self.is_alive = False - logger.info('Closing ZeroMQ socket') - self.receiver.close() - self.receiver = None logger.info('Terminating ZeroMQ context.') - self.zmq_context.term() - self.zmq_context = None + self.reader.shutdown() logger.info('ZeroMQ context terminated') @abstractmethod - def _create_zmq_ctx(self): + def _create_zmq_reader(self, config: ReaderConfig): pass class ZeroMQSource(BaseZeroMQSource): """ZeroMQ Source class.""" - def next_message_without_routing_id(self) -> Optional[List[bytes]]: - """Try to receive next message without routing ID but with topic.""" - - if not self.is_alive: - raise RuntimeError('ZeroMQ source is not started.') - - try: - message = self.receiver.recv_multipart() - except zmq.Again: - logger.debug('Timeout exceeded when receiving the next frame') - return - - if self.socket_type == ReceiverSocketTypes.ROUTER: - routing_id, *message = message - else: - routing_id = None + reader: BlockingReader - if message[0] == END_OF_STREAM_MESSAGE: - if routing_id: - self.receiver.send_multipart([routing_id, CONFIRMATION_MESSAGE]) - else: - self.receiver.send(CONFIRMATION_MESSAGE) - return - - if self._always_respond: - self.receiver.send(CONFIRMATION_MESSAGE) - - topic = message[0] - if len(message) < 2: - raise RuntimeError(f'ZeroMQ message from topic {topic} does not have data.') - - if self.topic_prefix and not topic.startswith(self.topic_prefix): - logger.debug( - 'Skipping message from topic %s, expected prefix %s', - topic, - self.topic_prefix, - ) - return - - if self.routing_id_filter.filter(routing_id, topic): - return message - - def next_message(self) -> Optional[List[bytes]]: + def next_message(self) -> Optional[ZeroMQMessage]: """Try to receive next message.""" - message = self.next_message_without_routing_id() - if message is not None: - return message[1:] + if not self.reader.is_started(): + raise RuntimeError('ZeroMQ source is not started.') + + result = self.reader.receive() + return self._build_result(result) def __iter__(self): return self - def __next__(self): + def __next__(self) -> ZeroMQMessage: message = None - while self.is_alive and message is None: + while self.reader.is_started() and message is None: message = self.next_message() if message is None: raise StopIteration return message - def _create_zmq_ctx(self): - return zmq.Context() + def _create_zmq_reader(self, config: ReaderConfig): + return BlockingReader(config) class AsyncZeroMQSource(ZeroMQSource): """Async ZeroMQ Source class.""" - zmq_context: zmq.asyncio.Context - receiver: zmq.asyncio.Socket - - def _create_zmq_ctx(self): - return zmq.asyncio.Context() + reader: NonBlockingReader - async def next_message_without_routing_id(self) -> Optional[List[bytes]]: - """Try to receive next message without routing ID but with topic.""" + def _create_zmq_reader(self, config: ReaderConfig): + return NonBlockingReader(config, 10) # TODO: make configurable - if not self.is_alive: - raise RuntimeError('ZeroMQ source is not started.') - - try: - message = await self.receiver.recv_multipart() - except zmq.Again: - logger.debug('Timeout exceeded when receiving the next frame') - return - - if self.socket_type == ReceiverSocketTypes.ROUTER: - routing_id, *message = message - else: - routing_id = None + async def _try_receive(self, loop): + return await loop.run_in_executor(None, self.reader.try_receive) - if message[0] == END_OF_STREAM_MESSAGE: - if routing_id: - await self.receiver.send_multipart([routing_id, CONFIRMATION_MESSAGE]) - else: - await self.receiver.send(CONFIRMATION_MESSAGE) - return - - if self._always_respond: - await self.receiver.send(CONFIRMATION_MESSAGE) - - topic = message[0] - if len(message) < 2: - raise RuntimeError(f'ZeroMQ message from topic {topic} does not have data.') - - if self.topic_prefix and not topic.startswith(self.topic_prefix): - logger.debug( - 'Skipping message from topic %s, expected prefix %s', - topic, - self.topic_prefix, - ) - return + async def next_message(self) -> Optional[ZeroMQMessage]: + """Try to receive next message.""" - if self.routing_id_filter.filter(routing_id, topic): - return message + if not self.reader.is_started(): + raise RuntimeError('ZeroMQ source is not started.') - async def next_message(self) -> Optional[List[bytes]]: - """Try to receive next message.""" + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, self.reader.try_receive) + while result is None: + await asyncio.sleep(0.01) # TODO: make configurable + result = await loop.run_in_executor(None, self.reader.try_receive) - message = await self.next_message_without_routing_id() - if message is not None: - return message[1:] + return self._build_result(result) def __aiter__(self): return self async def __anext__(self): message = None - while self.is_alive and message is None: + while self.reader.is_started() and message is None: message = await self.next_message() if message is None: raise StopIteration return message -class RoutingIdFilter: - """Cache for routing IDs to filter out old connections. - - Some ZeroMQ sockets have buffer on the receiver side (PUSH/PULL, DEALER/ROUTER). - ZeroMQ processes messages in round-robin manner. When a sender reconnects - with the same source ID ZeroMQ mixes up messages from old and new connections. - This causes decoder to fail and the module freezes. To avoid this we are - caching all routing IDs and ignoring messages from the old ones. - """ - - def __init__(self, cache_size: int): - self.routing_ids = {} - self.routing_ids_cache = LRUCache(cache_size) - - def filter(self, routing_id: Optional[bytes], topic: bytes): - """Decide whether we need to accept of ignore the message from that routing ID.""" - - if not routing_id: - return True - - if topic not in self.routing_ids: - self.routing_ids[topic] = routing_id - self.routing_ids_cache[(topic, routing_id)] = None - - elif self.routing_ids[topic] != routing_id: - if (topic, routing_id) in self.routing_ids_cache: - logger.debug( - 'Skipping message from topic %s: routing ID %s, expected %s.', - topic, - routing_id, - self.routing_ids[topic], - ) - return False - - else: - logger.debug( - 'Routing ID for topic %s changed from %s to %s.', - topic, - self.routing_ids[topic], - routing_id, - ) - self.routing_ids[topic] = routing_id - self.routing_ids_cache[(topic, routing_id)] = None - - return True - - -def build_topic_prefix( - source_id: Optional[str], - source_id_prefix: Optional[str], -) -> Optional[str]: - """Build topic prefix based on source ID or its prefix.""" - if source_id: - return f'{source_id}/' - elif source_id_prefix: - return source_id_prefix - - -def parse_zmq_socket_uri( - uri: str, - socket_type_name: Optional[str], - socket_type_enum: Union[Type[ReceiverSocketTypes], Type[SenderSocketTypes]], - bind: Optional[bool], -) -> Tuple[Union[ReceiverSocketTypes, SenderSocketTypes], bool, str]: - """Parse ZMQ socket URI. - - Socket type and binding flag can be embedded into URI or passed as separate arguments. - - URI schema: [+(bind|connect):]. - - Examples: - - ipc:///tmp/zmq-sockets/input-video.ipc - - dealer+connect:ipc:///tmp/zmq-sockets/input-video.ipc:source - - tcp://1.1.1.1:3333 - - pub+bind:tcp://1.1.1.1:3333:source - - :param uri: ZMQ socket URI. - :param socket_type_name: Name of a socket type. Ignored when specified in URI. - :param socket_type_enum: Enum for a socket type. - :param bind: Whether to bind or connect ZMQ socket. Ignored when in URI. - """ - - options, endpoint = socket_uri_pattern.fullmatch(uri).groups() - if options: - socket_type_name, bind_str = socket_options_pattern.fullmatch(options).groups() - if bind_str == 'bind': - bind = True - elif bind_str == 'connect': - bind = False - else: - raise ZMQSocketUriParsingException( - f'Incorrect socket bind options in socket URI {uri!r}' - ) - if socket_type_name is None: - raise ZMQSocketUriParsingException( - f'Socket type is not specified for URI {uri!r}' - ) - if bind is None: - raise ZMQSocketUriParsingException( - f'Socket binding flag is not specified for URI {uri!r}' - ) - - endpoint = get_socket_endpoint(endpoint) - socket_type = get_socket_type(socket_type_name, socket_type_enum) - - return socket_type, bind, endpoint - - -def receive_response(sender: zmq.Socket, retries: int): - """Receive response from sender socket. - - Retry until response is received. - """ - - while retries > 0: - try: - return sender.recv() - except zmq.Again: - retries -= 1 - logger.debug( - 'Timeout exceeded when receiving response (%s retries left)', - retries, - ) - if retries == 0: - raise - - -async def async_receive_response(sender: zmq.asyncio.Socket, retries: int): - """Receive response from async sender socket. - - Retry until response is received. - """ - - while retries > 0: - try: - return await sender.recv() - except zmq.Again: - retries -= 1 - logger.debug( - 'Timeout exceeded when receiving response (%s retries left)', - retries, - ) - if retries == 0: - raise - - -def ipc_socket_chmod(socket: str, permission: int = 0o777): - """Set permissions for IPC socket. - - Needed to make socket available for non-root users. - """ - - parsed = urlparse(socket) - if parsed.scheme == 'ipc': - logger.debug('Setting socket permissions to %o (%s).', permission, socket) - os.chmod(parsed.path, permission) - - -def create_ipc_socket_dirs(socket: str): - """Create parent directories for an IPC socket.""" - - parsed = urlparse(socket) - if parsed.scheme == 'ipc': - dir_name = os.path.dirname(parsed.path) - if not os.path.exists(dir_name): - logger.debug( - 'Making directories for ipc socket %s, path %s.', socket, dir_name - ) - os.makedirs(dir_name) +def get_zmq_socket_uri_options(uri: str) -> Optional[str]: + socket_options, _ = socket_uri_pattern.fullmatch(uri).groups() + return socket_options diff --git a/scripts/run_sink.py b/scripts/run_sink.py index efe564750..74190d8bc 100755 --- a/scripts/run_sink.py +++ b/scripts/run_sink.py @@ -303,7 +303,7 @@ def video_files_sink( cmd = build_docker_run_command( f'sink-video-files-{uuid.uuid4().hex}', zmq_endpoints=[in_endpoint], - entrypoint='/opt/savant/adapters/gst/sinks/video_files.sh', + entrypoint='/opt/savant/adapters/gst/sinks/video_files.py', envs=envs, volumes=[f'{location}:{location}'], docker_image=docker_image,