diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 00000000..8dca3717 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,3 @@ +{ + "recommendations": ["charliermarsh.ruff", "ms-python.mypy-type-checker"] +} diff --git a/.vscode/settings.json b/.vscode/settings.json index 583caae4..8b89edda 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,6 @@ { + "python.analysis.typeCheckingMode": "basic", + "python.analysis.autoImportCompletions": true, "python.analysis.diagnosticSeverityOverrides": { "reportMissingImports": "none", "reportUnusedImport": "information" @@ -6,10 +8,12 @@ "[python]": { "editor.formatOnSave": true, "editor.formatOnSaveMode": "file", + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.autoIndent": "full", "editor.codeActionsOnSave": { - "source.organizeImports": "explicit" - }, - "editor.defaultFormatter": "charliermarsh.ruff" + "source.fixAll.ruff": "always", + "source.organizeImports.ruff": "explicit" + } }, "[cpp]": { "editor.formatOnSave": true @@ -119,7 +123,5 @@ }, "C_Cpp.errorSquiggles": "disabled", // nanobind 周りでエラーが消えないので全部消す - "editor.tabSize": 2, - "python.analysis.typeCheckingMode": "off", - "cSpell.words": ["dtype", "imshow", "samplerate", "sounddevice"] + "editor.tabSize": 2 } diff --git a/examples/.env.template b/examples/.env.template new file mode 100644 index 00000000..2f5a43a3 --- /dev/null +++ b/examples/.env.template @@ -0,0 +1,13 @@ +# コマンドライン引数 もしくは 環境変数での指定が必須なパラメーター +# SORA_SIGNALING_URLS カンマ区切りで複数指定可能 +SORA_SIGNALING_URLS=wss://1.example.com./signaling,wss://2.example.com/signaling +SORA_CHANNEL_ID=sora +SORA_METADATA='{"access_token": "secret"}' + +# オプション設定 +# SORA_VIDEO_CODEC_TYPE=vp9 +# SORA_VIDEO_BIT_RATE=500 +# SORA_CAMERA_ID=0 +# SORA_VIDEO_WIDTH=640 +# SORA_VIDEO_HEIGHT=480 +# SORA_MESSAGING_LABEL=#sora-devtools \ No newline at end of file diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 00000000..929da49b --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1,7 @@ +.venv +.ruff_cache +.mypy_cache +*.pyc +__pycache__ +.env* +!.env.template \ No newline at end of file diff --git a/examples/.python-version b/examples/.python-version new file mode 100644 index 00000000..9ad6380c --- /dev/null +++ b/examples/.python-version @@ -0,0 +1 @@ +3.8.18 diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..9b7d6160 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,59 @@ +# Python Sora SDK サンプル集 + +## About Shiguredo's open source software + +We will not respond to PRs or issues that have not been discussed on Discord. Also, Discord is only available in Japanese. + +Please read https://github.com/shiguredo/oss/blob/master/README.en.md before use. + +## 時雨堂のオープンソースソフトウェアについて + +利用前に https://github.com/shiguredo/oss をお読みください。 + +## サンプルコードの実行方法 + +[Rye](https://github.com/mitsuhiko/rye) というパッケージマネージャーを利用しています。 + +Linux と macOS の場合は `curl -sSf https://rye-up.com/get | bash` でインストール可能です。 +Windows は https://rye-up.com/ の Installation Instructions を確認してください。 + +### 依存パッケージのビルド + +```console +$ rye sync +``` + +### サンプルコードの実行 + +```console +$ rye run media_recvonly --signaling-urls wss://1.example.com/signaling wss://2.example.com/signaling --channel-id sora +``` + +### コマンドラインの代わりに環境変数を利用する + +```console +$ cp .env.template .env +# .env に必要な変数を設定してください。 +$ rye run media_recvonly +``` + +## ライセンス + +Apache License 2.0 + +``` +Copyright 2023-2024, tnoho (Original Author) +Copyright 2023-2024, Shiguredo Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +``` diff --git a/examples/pyproject.toml b/examples/pyproject.toml new file mode 100644 index 00000000..2f6e120e --- /dev/null +++ b/examples/pyproject.toml @@ -0,0 +1,41 @@ +[project] +name = "sora-sdk-samples" +version = "2024.1.0" +description = "Sora Python SDK Samples" +authors = [{ name = "Shiguredo Inc." }] +dependencies = [ + "opencv-python~=4.9.0.80", + "opencv-python-headless~=4.9.0.80", + "sounddevice~=0.4.6", + "sora-sdk>=2024.1.0", + "mediapipe~=0.10.1", + "python-dotenv>=1.0.1", +] +readme = "README.md" +requires-python = ">= 3.8" + +[project.scripts] +media_sendonly = "media.sendonly:sendonly" +media_recvonly = "media.recvonly:recvonly" +messaging_sendrecv = "messaging.sendrecv:sendrecv" +messaging_sendonly = "messaging.sendonly:sendonly" +messaging_recvonly = "messaging.recvonly:recvonly" +hideface_sender = "ml.hideface_sender:hideface_sender" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.rye] +managed = true +dev-dependencies = ["ruff>=0.3.0", "mypy>=1.8.0"] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/media", "src/messaging", "src/ml"] + +[tool.ruff] +line-length = 100 +indent-width = 4 diff --git a/examples/requirements-dev.lock b/examples/requirements-dev.lock new file mode 100644 index 00000000..0e18ec18 --- /dev/null +++ b/examples/requirements-dev.lock @@ -0,0 +1,76 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false + +-e file:. +absl-py==1.4.0 + # via mediapipe +attrs==23.1.0 + # via mediapipe +cffi==1.15.1 + # via sounddevice +contourpy==1.1.0 + # via matplotlib +cycler==0.11.0 + # via matplotlib +flatbuffers==23.5.26 + # via mediapipe +fonttools==4.40.0 + # via matplotlib +importlib-resources==5.12.0 + # via matplotlib +kiwisolver==1.4.4 + # via matplotlib +matplotlib==3.7.2 + # via mediapipe +mediapipe==0.10.1 + # via sora-sdk-samples +mypy==1.8.0 +mypy-extensions==1.0.0 + # via mypy +numpy==1.24.4 + # via contourpy + # via matplotlib + # via mediapipe + # via opencv-contrib-python + # via opencv-python + # via opencv-python-headless +opencv-contrib-python==4.8.0.74 + # via mediapipe +opencv-python==4.9.0.80 + # via sora-sdk-samples +opencv-python-headless==4.9.0.80 + # via sora-sdk-samples +packaging==23.1 + # via matplotlib +pillow==10.0.0 + # via matplotlib +protobuf==3.20.3 + # via mediapipe +pycparser==2.21 + # via cffi +pyparsing==3.0.9 + # via matplotlib +python-dateutil==2.8.2 + # via matplotlib +python-dotenv==1.0.1 + # via sora-sdk-samples +ruff==0.3.0 +six==1.16.0 + # via python-dateutil +sora-sdk==2024.1.0 + # via sora-sdk-samples +sounddevice==0.4.6 + # via mediapipe + # via sora-sdk-samples +tomli==2.0.1 + # via mypy +typing-extensions==4.10.0 + # via mypy +zipp==3.17.0 + # via importlib-resources diff --git a/examples/requirements.lock b/examples/requirements.lock new file mode 100644 index 00000000..806fcc4b --- /dev/null +++ b/examples/requirements.lock @@ -0,0 +1,68 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false + +-e file:. +absl-py==1.4.0 + # via mediapipe +attrs==23.1.0 + # via mediapipe +cffi==1.15.1 + # via sounddevice +contourpy==1.1.0 + # via matplotlib +cycler==0.11.0 + # via matplotlib +flatbuffers==23.5.26 + # via mediapipe +fonttools==4.40.0 + # via matplotlib +importlib-resources==5.12.0 + # via matplotlib +kiwisolver==1.4.4 + # via matplotlib +matplotlib==3.7.2 + # via mediapipe +mediapipe==0.10.1 + # via sora-sdk-samples +numpy==1.24.4 + # via contourpy + # via matplotlib + # via mediapipe + # via opencv-contrib-python + # via opencv-python + # via opencv-python-headless +opencv-contrib-python==4.8.0.74 + # via mediapipe +opencv-python==4.9.0.80 + # via sora-sdk-samples +opencv-python-headless==4.9.0.80 + # via sora-sdk-samples +packaging==23.1 + # via matplotlib +pillow==10.0.0 + # via matplotlib +protobuf==3.20.3 + # via mediapipe +pycparser==2.21 + # via cffi +pyparsing==3.0.9 + # via matplotlib +python-dateutil==2.8.2 + # via matplotlib +python-dotenv==1.0.1 + # via sora-sdk-samples +six==1.16.0 + # via python-dateutil +sora-sdk==2024.1.0 + # via sora-sdk-samples +sounddevice==0.4.6 + # via mediapipe + # via sora-sdk-samples +zipp==3.17.0 + # via importlib-resources diff --git a/examples/src/media/__init__.py b/examples/src/media/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/src/media/recvonly.py b/examples/src/media/recvonly.py new file mode 100644 index 00000000..8f21cdf1 --- /dev/null +++ b/examples/src/media/recvonly.py @@ -0,0 +1,188 @@ +import argparse +import json +import os +import queue +from threading import Event +from typing import Any, Dict, List, Optional + +import cv2 +import sounddevice +from dotenv import load_dotenv +from numpy import ndarray +from sora_sdk import ( + Sora, + SoraAudioSink, + SoraConnection, + SoraMediaTrack, + SoraSignalingErrorCode, + SoraVideoFrame, + SoraVideoSink, +) + + +class Recvonly: + def __init__( + self, + # python 3.8 まで対応なので list[str] ではなく List[str] にする + signaling_urls: List[str], + channel_id: str, + metadata: Optional[Dict[str, Any]], + openh264: Optional[str], + output_frequency: int = 16000, + output_channels: int = 1, + ): + self._output_frequency = output_frequency + self._output_channels = output_channels + + self._sora: Sora = Sora(openh264=openh264) + self._connection: SoraConnection = self._sora.create_connection( + signaling_urls=signaling_urls, + role="recvonly", + channel_id=channel_id, + metadata=metadata, + ) + self._connection_id = "" + self._connected = Event() + self._closed = False + self._default_connection_timeout_s = 10.0 + + self._audio_sink: Optional[SoraAudioSink] = None + self._video_sink: Optional[SoraVideoSink] = None + + # SoraVideoFrame を格納するキュー + self._q_out: queue.Queue = queue.Queue() + + self._connection.on_set_offer = self._on_set_offer + self._connection.on_notify = self._on_notify + self._connection.on_disconnect = self._on_disconnect + self._connection.on_track = self._on_track + + def connect(self): + self._connection.connect() + + assert self._connected.wait( + timeout=self._default_connection_timeout_s + ), "接続に失敗しました" + + def disconnect(self): + self._connection.disconnect() + + def _on_set_offer(self, raw_message: str): + message: Dict[str, Any] = json.loads(raw_message) + if message["type"] == "offer": + # "type": "offer" に入ってくる自分の connection_id を保存する + self._connection_id = message["connection_id"] + + def _on_notify(self, raw_message: str): + message: Dict[str, Any] = json.loads(raw_message) + # "type": "notify" の "connection.created" で通知される connection_id が + # 自分の connection_id と一致する場合に接続完了とする + if ( + message["type"] == "notify" + and message["event_type"] == "connection.created" + and message["connection_id"] == self._connection_id + ): + print("Sora に接続しました") + self._connected.set() + + def _on_disconnect(self, error_code: SoraSignalingErrorCode, message: str): + print(f"Sora から切断されました: error_code='{error_code}' message='{message}'") + self._connected.clear() + self._closed = True + + def _on_video_frame(self, frame: SoraVideoFrame): + # キューに SoraVideoFrame を入れる + self._q_out.put(frame) + + def _on_track(self, track: SoraMediaTrack): + if track.kind == "audio": + self._audio_sink = SoraAudioSink(track, self._output_frequency, self._output_channels) + if track.kind == "video": + self._video_sink = SoraVideoSink(track) + self._video_sink.on_frame = self._on_video_frame + + def _callback(self, outdata: ndarray, frames: int, time, status: sounddevice.CallbackFlags): + if self._audio_sink is not None: + success, data = self._audio_sink.read(frames) + if success: + if data.shape[0] != frames: + print("音声データが十分ではありません", data.shape, frames) + outdata[:] = data + else: + print("音声データを取得できません") + + def run(self): + # サウンドデバイスのOutputStreamを使って音声出力を設定 + with sounddevice.OutputStream( + channels=self._output_channels, + callback=self._callback, + samplerate=self._output_frequency, + dtype="int16", + ): + self.connect() + try: + while self._connected.is_set(): + # Windows 環境の場合 timeout を入れておかないと Queue.get() で + # ブロックしたときに脱出方法がなくなる。 + try: + # キューから SoraVideoFrame を取り出す + frame = self._q_out.get(timeout=1) + except queue.Empty: + continue + # 画像を表示する + cv2.imshow("frame", frame.data()) + # これは削除してよさそう + if cv2.waitKey(1) & 0xFF == ord("q"): + break + except KeyboardInterrupt: + pass + finally: + self.disconnect() + + # すべてのウィンドウを破棄 + cv2.destroyAllWindows() + + +def recvonly(): + # .env ファイル読み込み + load_dotenv() + parser = argparse.ArgumentParser() + + # 必須引数 + default_signaling_urls = None + if urls := os.getenv("SORA_SIGNALING_URLS"): + # SORA_SIGNALING_URLS 環境変数はカンマ区切りで複数指定可能 + default_signaling_urls = urls.split(",") + parser.add_argument( + "--signaling-urls", + default=default_signaling_urls, + type=str, + nargs="+", + required=not default_signaling_urls, + help="シグナリング URL", + ) + default_channel_id = os.getenv("SORA_CHANNEL_ID") + parser.add_argument( + "--channel-id", + default=default_channel_id, + required=not default_channel_id, + help="チャネルID", + ) + + # オプション引数 + parser.add_argument("--metadata", default=os.getenv("SORA_METADATA"), help="メタデータ JSON") + parser.add_argument( + "--openh264", type=str, default=None, help="OpenH264 の共有ライブラリへのパス" + ) + args = parser.parse_args() + + metadata = None + if args.metadata: + metadata = json.loads(args.metadata) + + recvonly = Recvonly(args.signaling_urls, args.channel_id, metadata, args.openh264) + recvonly.run() + + +if __name__ == "__main__": + recvonly() diff --git a/examples/src/media/sendonly.py b/examples/src/media/sendonly.py new file mode 100644 index 00000000..6d34e5ec --- /dev/null +++ b/examples/src/media/sendonly.py @@ -0,0 +1,208 @@ +import argparse +import json +import os +from threading import Event +from typing import Any, Dict, List, Optional + +import cv2 +import sounddevice +from dotenv import load_dotenv +from numpy import ndarray +from sora_sdk import Sora, SoraConnection, SoraSignalingErrorCode + + +class SendOnly: + def __init__( + self, + # python 3.8 まで対応なので list[str] ではなく List[str] にする + signaling_urls: List[str], + channel_id: str, + metadata: Optional[Dict[str, Any]], + camera_id: int, + video_codec_type: str, + video_bit_rate: int, + video_width: Optional[int], + video_height: Optional[int], + openh264: Optional[str], + audio_channels: int = 1, + audio_sample_rate: int = 16000, + ): + self.audio_channels = audio_channels + self.audio_sample_rate = audio_sample_rate + + self._sora: Sora = Sora(openh264=openh264) + + self._audio_source = self._sora.create_audio_source( + self.audio_channels, self.audio_sample_rate + ) + self._video_source = self._sora.create_video_source() + + self._connection: SoraConnection = self._sora.create_connection( + signaling_urls=signaling_urls, + role="sendonly", + channel_id=channel_id, + metadata=metadata, + video_codec_type=video_codec_type, + video_bit_rate=video_bit_rate, + audio_source=self._audio_source, + video_source=self._video_source, + ) + self._connection_id = "" + self._connected = Event() + self._closed = False + self._default_connection_timeout_s = 10.0 + + self._connection.on_set_offer = self._on_set_offer + self._connection.on_notify = self._on_notify + self._connection.on_disconnect = self._on_disconnect + + self._video_capture = cv2.VideoCapture(camera_id) + if video_width is not None: + self._video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, video_width) + if video_height is not None: + self._video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, video_height) + + def connect(self): + self._connection.connect() + + assert self._connected.wait( + timeout=self._default_connection_timeout_s + ), "接続がタイムアウトしました" + + def disconnect(self): + self._connection.disconnect() + + def _on_notify(self, raw_message: str): + message: Dict[str, Any] = json.loads(raw_message) + # "type": "notify" の "connection.created" で通知される connection_id が + # 自分の connection_id と一致する場合に接続完了とする + if ( + message["type"] == "notify" + and message["event_type"] == "connection.created" + and message["connection_id"] == self._connection_id + ): + print("Sora に接続しました") + self._connected.set() + + def _on_set_offer(self, raw_message: str): + message: Dict[str, Any] = json.loads(raw_message) + if message["type"] == "offer": + # "type": "offer" に入ってくる自分の connection_id を保存する + self._connection_id = message["connection_id"] + + def _on_disconnect(self, error_code: SoraSignalingErrorCode, message: str): + print(f"Sora から切断されました: error_code='{error_code}' message='{message}'") + self._connected.clear() + self._closed = True + + def _callback(self, indata: ndarray, frames: int, time, status: sounddevice.CallbackFlags): + self._audio_source.on_data(indata) + + def run(self): + # 音声デバイスの入力を Sora に送信する設定 + with sounddevice.InputStream( + samplerate=self.audio_sample_rate, + channels=self.audio_channels, + dtype="int16", + callback=self._callback, + ): + self.connect() + try: + while self._connected.is_set(): + # 取得したフレームを Sora に送信する + success, frame = self._video_capture.read() + if not success: + continue + self._video_source.on_captured(frame) + except KeyboardInterrupt: + pass + finally: + self.disconnect() + self._video_capture.release() + + +def sendonly(): + # .env ファイルを読み込む + load_dotenv() + parser = argparse.ArgumentParser() + + # 必須引数 + default_signaling_urls = None + if urls := os.getenv("SORA_SIGNALING_URLS"): + # SORA_SIGNALING_URLS 環境変数はカンマ区切りで複数指定可能 + default_signaling_urls = urls.split(",") + parser.add_argument( + "--signaling-urls", + default=default_signaling_urls, + type=str, + nargs="+", + required=not default_signaling_urls, + help="シグナリング URL", + ) + default_channel_id = os.getenv("SORA_CHANNEL_ID") + parser.add_argument( + "--channel-id", + default=default_channel_id, + required=not default_channel_id, + help="チャネルID", + ) + + # オプション引数 + parser.add_argument( + "--video-codec-type", + # Sora のデフォルト値と合わせる + default=os.getenv("SORA_VIDEO_CODEC_TYPE", "VP9"), + help="映像コーデックの種類", + ) + parser.add_argument( + "--video-bit-rate", + type=int, + # Sora のデフォルト値と合わせる + default=int(os.getenv("SORA_VIDEO_BIT_RATE", "500")), + help="映像ビットレート", + ) + parser.add_argument("--metadata", default=os.getenv("SORA_METADATA"), help="メタデータ JSON") + parser.add_argument( + "--camera-id", + type=int, + default=int(os.getenv("SORA_CAMERA_ID", "0")), + help="cv2.VideoCapture() に渡すカメラ ID", + ) + parser.add_argument( + "--video-width", + type=int, + default=int(os.getenv("SORA_VIDEO_WIDTH", "640")), + help="入力カメラ映像の横幅のヒント", + ) + parser.add_argument( + "--video-height", + type=int, + default=int(os.getenv("SORA_VIDEO_HEIGHT", "360")), + help="入力カメラ映像の高さのヒント", + ) + parser.add_argument( + "--openh264", type=str, default=None, help="OpenH264 の共有ライブラリへのパス" + ) + args = parser.parse_args() + + # metadata は JSON 形式で指定するので一同 JSON 形式で読み込む + metadata = None + if args.metadata: + metadata = json.loads(args.metadata) + + sendonly = SendOnly( + args.signaling_urls, + args.channel_id, + metadata, + args.camera_id, + args.video_codec_type, + args.video_bit_rate, + args.video_width, + args.video_height, + args.openh264, + ) + sendonly.run() + + +if __name__ == "__main__": + sendonly() diff --git a/examples/src/messaging/__init__.py b/examples/src/messaging/__init__.py new file mode 100644 index 00000000..83d22eba --- /dev/null +++ b/examples/src/messaging/__init__.py @@ -0,0 +1,102 @@ +import json +import random +import time +from threading import Event +from typing import Any, Dict, List, Optional + +from sora_sdk import Sora, SoraConnection, SoraSignalingErrorCode + + +class Messaging: + def __init__( + self, + # python 3.8 まで対応なので list[str] ではなく List[str] にする + signaling_urls: List[str], + channel_id: str, + data_channels: List[Dict[str, Any]], + metadata: Optional[Dict[str, Any]], + ): + self._data_channels = data_channels + + self._sora = Sora() + self._connection: SoraConnection = self._sora.create_connection( + signaling_urls=signaling_urls, + role="sendrecv", + channel_id=channel_id, + metadata=metadata, + audio=False, + video=False, + data_channels=self._data_channels, + data_channel_signaling=True, + ) + self._connection_id: str = "" + + self._connected = Event() + self._closed = False + self._label = data_channels[0]["label"] + self._sendable_data_channels: set = set() + self._is_data_channel_ready = False + + self.sender_id = random.randint(1, 10000) + + self._connection.on_set_offer = self._on_set_offer + self._connection.on_notify = self._on_notify + self._connection.on_data_channel = self._on_data_channel + self._connection.on_message = self._on_message + self._connection.on_disconnect = self._on_disconnect + + @property + def closed(self): + return self._closed + + def connect(self): + self._connection.connect() + + assert self._connected.wait(10), "接続に失敗しました" + + def disconnect(self): + self._connection.disconnect() + + def send(self, data: bytes): + # on_data_channel() が呼ばれるまではデータチャネルの準備ができていないので待機 + while not self._is_data_channel_ready and not self._closed: + time.sleep(0.01) + + self._connection.send_data_channel(self._label, data) + + def _on_set_offer(self, raw_message: str): + message: Dict[str, Any] = json.loads(raw_message) + if message["type"] == "offer": + # "type": "offer" に入ってくる自分の connection_id を保存する + self._connection_id = message["connection_id"] + + def _on_notify(self, raw_message: str): + message: Dict[str, Any] = json.loads(raw_message) + # "type": "notify" の "connection.created" で通知される connection_id が + # 自分の connection_id と一致する場合に接続完了とする + if ( + message["type"] == "notify" + and message["event_type"] == "connection.created" + and message["connection_id"] == self._connection_id + ): + print("Sora に接続しました") + self._connected.set() + + def _on_disconnect(self, error_code: SoraSignalingErrorCode, message: str): + print(f"Sora から切断されました: error_code='{error_code}' message='{message}'") + self._connected.clear() + self._closed = True + + def _on_message(self, label: str, data: bytes): + print(f"メッセージを受信しました: label={label}, data={data.decode('utf-8')}") + + def _on_data_channel(self, label: str): + for data_channel in self._data_channels: + if data_channel["label"] != label: + continue + + if data_channel["direction"] in ["sendrecv", "sendonly"]: + self._sendable_data_channels.add(label) + # データチャネルの準備ができたのでフラグを立てる + self._is_data_channel_ready = True + break diff --git a/examples/src/messaging/recvonly.py b/examples/src/messaging/recvonly.py new file mode 100644 index 00000000..37041ace --- /dev/null +++ b/examples/src/messaging/recvonly.py @@ -0,0 +1,77 @@ +import argparse +import json +import os +import time + +from dotenv import load_dotenv + +from messaging import Messaging + + +def recvonly(): + # .env ファイルを読み込む + load_dotenv() + + parser = argparse.ArgumentParser() + + # 必須引数 + default_signaling_urls = None + if urls := os.getenv("SORA_SIGNALING_URLS"): + # カンマ区切りで複数指定可能 + default_signaling_urls = urls.split(",") + parser.add_argument( + "--signaling-urls", + default=default_signaling_urls, + type=str, + nargs="+", + required=not default_signaling_urls, + help="シグナリング URL", + ) + default_channel_id = os.getenv("SORA_CHANNEL_ID") + parser.add_argument( + "--channel-id", + default=default_channel_id, + required=not default_channel_id, + help="チャネルID", + ) + default_messaging_label = os.getenv("SORA_MESSAGING_LABEL") + parser.add_argument( + "--messaging-label", + default=default_messaging_label, + type=str, + nargs="+", + required=not default_messaging_label, + help="データチャネルのラベル名", + ) + + # オプション引数 + parser.add_argument("--metadata", default=os.getenv("SORA_METADATA"), help="メタデータ JSON") + args = parser.parse_args() + + metadata = {} + if args.metadata: + metadata = json.loads(args.metadata) + + data_channels = [{"label": args.messaging_label, "direction": "recvonly"}] + messaging_recvonly = Messaging( + args.signaling_urls, + args.channel_id, + data_channels, + metadata, + ) + + # Sora に接続する + messaging_recvonly.connect() + try: + # Ctrl+C が押される or 切断されるまでメッセージ受信を待機 + while not messaging_recvonly.closed: + time.sleep(0.01) + except KeyboardInterrupt: + pass + finally: + # Sora から切断する(すでに切断済みの場合には無視される) + messaging_recvonly.disconnect() + + +if __name__ == "__main__": + recvonly() diff --git a/examples/src/messaging/sendonly.py b/examples/src/messaging/sendonly.py new file mode 100644 index 00000000..08b27cb5 --- /dev/null +++ b/examples/src/messaging/sendonly.py @@ -0,0 +1,70 @@ +import argparse +import json +import os + +from dotenv import load_dotenv + +from messaging import Messaging + + +def sendonly(): + # .env ファイルを読み込む + load_dotenv() + + parser = argparse.ArgumentParser() + + # 必須引数 + default_signaling_urls = None + if urls := os.getenv("SORA_SIGNALING_URLS"): + # カンマ区切りで複数指定可能 + default_signaling_urls = urls.split(",") + parser.add_argument( + "--signaling-urls", + default=default_signaling_urls, + type=str, + nargs="+", + required=not default_signaling_urls, + help="シグナリング URL", + ) + default_channel_id = os.getenv("SORA_CHANNEL_ID") + parser.add_argument( + "--channel-id", + default=default_channel_id, + required=not default_channel_id, + help="チャネルID", + ) + default_messaging_label = os.getenv("SORA_MESSAGING_LABEL", "#example") + parser.add_argument( + "--messaging-label", + default=default_messaging_label, + required=not default_messaging_label, + help="データチャネルのラベル名", + ) + + # オプション引数 + parser.add_argument("--metadata", default=os.getenv("SORA_METADATA"), help="メタデータ JSON") + args = parser.parse_args() + + metadata = None + if args.metadata: + metadata = json.loads(args.metadata) + + # data_channels 組み立て + data_channels = [{"label": args.messaging_label, "direction": "sendonly"}] + messaging_sendonly = Messaging(args.signaling_urls, args.channel_id, data_channels, metadata) + + # Sora に接続する + messaging_sendonly.connect() + try: + while not messaging_sendonly.closed: + # input で入力された文字列を utf-8 でエンコードして送信 + message = input("Enter キーを押すと送信します: ") + messaging_sendonly.send(message.encode("utf-8")) + except KeyboardInterrupt: + pass + finally: + messaging_sendonly.disconnect() + + +if __name__ == "__main__": + sendonly() diff --git a/examples/src/messaging/sendrecv.py b/examples/src/messaging/sendrecv.py new file mode 100644 index 00000000..a24ae110 --- /dev/null +++ b/examples/src/messaging/sendrecv.py @@ -0,0 +1,70 @@ +import argparse +import json +import os + +from dotenv import load_dotenv + +from messaging import Messaging + + +def sendrecv(): + # .env ファイルを読み込む + load_dotenv() + + parser = argparse.ArgumentParser() + + # 必須引数 + default_signaling_urls = None + if urls := os.getenv("SORA_SIGNALING_URLS"): + # カンマ区切りで複数指定可能 + default_signaling_urls = urls.split(",") + parser.add_argument( + "--signaling-urls", + default=default_signaling_urls, + type=str, + nargs="+", + required=not default_signaling_urls, + help="シグナリング URL", + ) + default_channel_id = os.getenv("SORA_CHANNEL_ID") + parser.add_argument( + "--channel-id", + default=default_channel_id, + required=not default_channel_id, + help="チャネルID", + ) + default_messaging_label = os.getenv("SORA_MESSAGING_LABEL") + parser.add_argument( + "--messaging-label", + default=default_messaging_label, + type=str, + nargs="+", + required=not default_messaging_label, + help="データチャネルのラベル名", + ) + + # オプション引数 + parser.add_argument("--metadata", default=os.getenv("SORA_METADATA"), help="メタデータ JSON") + args = parser.parse_args() + + metadata = None + if args.metadata: + metadata = json.loads(args.metadata) + + data_channels = [{"label": args.messaging_label, "direction": "sendrecv"}] + messaging_sendrecv = Messaging(args.signaling_urls, args.channel_id, data_channels, metadata) + # Sora に接続する + messaging_sendrecv.connect() + try: + while not messaging_sendrecv.closed: + # input で入力された文字列を utf-8 でエンコードして送信 + message = input() + messaging_sendrecv.send(message.encode("utf-8")) + except KeyboardInterrupt: + pass + finally: + messaging_sendrecv.disconnect() + + +if __name__ == "__main__": + sendrecv() diff --git a/examples/src/ml/hideface_sender.py b/examples/src/ml/hideface_sender.py new file mode 100644 index 00000000..0dda65d4 --- /dev/null +++ b/examples/src/ml/hideface_sender.py @@ -0,0 +1,197 @@ +import argparse +import json +import math +import os +from pathlib import Path +from threading import Event +from typing import Any, Dict, List, Optional + +import cv2 +import mediapipe as mp +import numpy as np +from cv2.typing import MatLike +from dotenv import load_dotenv +from PIL import Image +from sora_sdk import Sora, SoraSignalingErrorCode, SoraVideoSource + + +class LogoStreamer: + def __init__( + self, + signaling_urls: List[str], + role: str, + channel_id: str, + metadata: Optional[Dict[str, Any]], + camera_id: int, + video_width: Optional[int], + video_height: Optional[int], + ): + self.mp_face_detection = mp.solutions.face_detection + + self._sora = Sora(openh264=None) + self._video_source: SoraVideoSource = self._sora.create_video_source() + self._connection = self._sora.create_connection( + signaling_urls=signaling_urls, + role=role, + channel_id=channel_id, + metadata=metadata, + video_codec_type=None, + video_bit_rate=500, + video_source=self._video_source, + ) + self._connection_id = "" + + self._connected = Event() + self._closed = False + self._default_connection_timeout_s = 10.0 + + self._connection.on_set_offer = self._on_set_offer + self._connection.on_notify = self._on_notify + self._connection.on_disconnect = self._on_disconnect + + self._video_capture = cv2.VideoCapture(camera_id) + if video_width is not None: + self._video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, video_width) + if video_height is not None: + self._video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, video_height) + + # ロゴを読み込む + self._logo = Image.open(Path(__file__).parent.joinpath("shiguremaru.png")) + + def connect(self): + self._connection.connect() + + assert self._connected.wait( + timeout=self._default_connection_timeout_s + ), "接続に失敗しました" + + def disconnect(self): + self._connection.disconnect() + + def _on_disconnect(self, error_code: SoraSignalingErrorCode, message: str): + print(f"Sora から切断されました: error_code='{error_code}' message='{message}'") + self._connected.clear() + self._closed = True + + def _on_set_offer(self, raw_message: str): + message = json.loads(raw_message) + if message["type"] == "offer": + self._connection_id = message["connection_id"] + + def _on_notify(self, raw_message: str): + message = json.loads(raw_message) + if ( + message["type"] == "notify" + and message["event_type"] == "connection.created" + and message["connection_id"] == self._connection_id + ): + print("Sora に接続しました") + self._connected.set() + + def run(self): + self.connect() + try: + # 顔検出を用意する + # TODO: face_detection の型を調べる + with self.mp_face_detection.FaceDetection( + model_selection=0, min_detection_confidence=0.5 + ) as face_detection: + angle = 0 + while self._connected.is_set() and self._video_capture.isOpened(): + # フレームを取得する + success, frame = self._video_capture.read() + if not success: + continue + angle = self.run_one_frame(face_detection, angle, frame) + except KeyboardInterrupt: + pass + finally: + self.disconnect() + self._video_capture.release() + + def run_one_frame(self, face_detection, angle: int, frame: MatLike): + # 高速化の呪文 + frame.flags.writeable = False + # mediapipe や PIL で処理できるように色の順序を変える + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # mediapipe で顔を検出する + results = face_detection.process(frame) + + frame_height, frame_width, _ = frame.shape + # PIL で処理できるように画像を変換する + pil_image = Image.fromarray(frame) + + # ロゴを回しておく + rotated_logo = self._logo.rotate(angle) + angle += 1 + if angle >= 360: + angle = 0 + if results.detections: + for detection in results.detections: + location = detection.location_data + if not location.HasField("relative_bounding_box"): + continue + bb = location.relative_bounding_box + + # 正規化されているので逆正規化を行う + w_px = math.floor(bb.width * frame_width) + h_px = math.floor(bb.height * frame_height) + x_px = min(math.floor(bb.xmin * frame_width), frame_width - 1) + y_px = min(math.floor(bb.ymin * frame_height), frame_height - 1) + + # 検出領域は顔に対して小さいため、顔全体が覆われるように検出領域を大きくする + fixed_w_px = math.floor(w_px * 1.6) + fixed_h_px = math.floor(h_px * 1.6) + # 大きくした分、座標がずれてしまうため顔の中心になるように座標を補正する + fixed_x_px = max(0, math.floor(x_px - (fixed_w_px - w_px) / 2)) + # 検出領域は顔であり頭が入っていないため、上寄りになるように座標を補正する + fixed_y_px = max(0, math.floor(y_px - (fixed_h_px - h_px))) + + # ロゴをリサイズする + resized_logo = rotated_logo.resize((fixed_w_px, fixed_h_px)) + pil_image.paste(resized_logo, (fixed_x_px, fixed_y_px), resized_logo) + + frame.flags.writeable = True + # PIL から numpy に画像を戻す + frame = np.array(pil_image) + # 色の順序をもとに戻す + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + # WebRTC に渡す + self._video_source.on_captured(frame) + return angle + + +def hideface_sender(): + # .env ファイルを読み込む + load_dotenv() + + # 必須引数 + signaling_urls = os.getenv("SORA_SIGNALING_URLS").split(",") + channel_id = os.getenv("SORA_CHANNEL_ID") + + # オプション引数 + metadata = None + raw_metadata = os.getenv("SORA_METADATA") + if raw_metadata is not None: + metadata = json.loads(raw_metadata) + + camera_id = int(os.getenv("SORA_CAMERA_ID", "0")) + video_width = int(os.getenv("SORA_VIDEO_WIDTH", "640")) + video_height = int(os.getenv("SORA_VIDEO_HEIGHT", "360")) + + streamer = LogoStreamer( + signaling_urls=signaling_urls, + role="sendonly", + channel_id=channel_id, + metadata=metadata, + camera_id=camera_id, + video_height=video_height, + video_width=video_width, + ) + streamer.run() + + +if __name__ == "__main__": + hideface_sender() diff --git a/examples/src/ml/shiguremaru.png b/examples/src/ml/shiguremaru.png new file mode 100644 index 00000000..1ac1e541 Binary files /dev/null and b/examples/src/ml/shiguremaru.png differ diff --git a/examples/sync.sh b/examples/sync.sh new file mode 100755 index 00000000..6c9c4d77 --- /dev/null +++ b/examples/sync.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# ローカルインストールしたキャッシュを削除した上で rye sync するスクリプト。 +# +# `rye add sora_sdk --path ` でローカルで書き換えた sora-python-sdk を +# 利用可能だが、キャッシュが残っていると一生更新されないため、キャッシュディレクトリを削除する。 +# +# 毎回どこのディレクトリを消せばいいか忘れてしまうので、このスクリプトで対応する。 + +set -ex + +case "`uname`" in + "Darwin" ) CACHE_DIR=~/Library/Caches/pip ;; + "Linux" ) CACHE_DIR=~/.cache/pip ;; + * ) exit 1 ;; +esac + +rm -rf $CACHE_DIR/wheels +rye sync