diff --git a/lib/ex_webrtc/dtls_transport.ex b/lib/ex_webrtc/dtls_transport.ex index 2fa1148e..76454ee5 100644 --- a/lib/ex_webrtc/dtls_transport.ex +++ b/lib/ex_webrtc/dtls_transport.ex @@ -7,8 +7,7 @@ defmodule ExWebRTC.DTLSTransport do require Logger - alias ExICE.ICEAgent - alias ExWebRTC.Utils + alias ExWebRTC.{DefaultICETransport, ICETransport, Utils} @type dtls_transport() :: GenServer.server() @@ -31,15 +30,21 @@ defmodule ExWebRTC.DTLSTransport do @type dtls_state() :: :new | :connecting | :connected | :closed | :failed @doc false - @spec start_link(ExICE.ICEAgent.opts(), GenServer.server()) :: GenServer.on_start() - def start_link(ice_config, ice_module \\ ICEAgent) do - GenServer.start_link(__MODULE__, [ice_config, ice_module, self()]) + @spec start_link(ICETransport.t(), Keyword.t()) :: GenServer.on_start() + def start_link(ice_transport \\ DefaultICETransport, ice_config) do + behaviour = ice_transport.__info__(:attributes)[:behaviour] || [] + + unless ICETransport in behaviour do + raise "DTLSTransport requires ice_transport to implement ExWebRTC.ICETransport beahviour." + end + + GenServer.start_link(__MODULE__, [ice_transport, ice_config, self()]) end @doc false - @spec get_ice_agent(dtls_transport()) :: GenServer.server() - def get_ice_agent(dtls_transport) do - GenServer.call(dtls_transport, :get_ice_agent) + @spec get_ice_transport(dtls_transport()) :: {module(), pid()} + def get_ice_transport(dtls_transport) do + GenServer.call(dtls_transport, :get_ice_transport) end @doc false @@ -68,15 +73,16 @@ defmodule ExWebRTC.DTLSTransport do end @impl true - def init([ice_config, ice_module, owner]) do + def init([ice_transport, ice_config, owner]) do {pkey, cert} = ExDTLS.generate_key_cert() fingerprint = ExDTLS.get_cert_fingerprint(cert) - {:ok, ice_agent} = ice_module.start_link(:controlled, ice_config) + {:ok, ice_pid} = ice_transport.start_link(:controlled, ice_config) state = %{ owner: owner, - ice_agent: ice_agent, + ice_transport: ice_transport, + ice_pid: ice_pid, ice_state: nil, buffered_packets: nil, cert: cert, @@ -97,8 +103,8 @@ defmodule ExWebRTC.DTLSTransport do end @impl true - def handle_call(:get_ice_agent, _from, state) do - {:reply, state.ice_agent, state} + def handle_call(:get_ice_transport, _from, state) do + {:reply, {state.ice_transport, state.ice_pid}, state} end @impl true @@ -134,7 +140,7 @@ defmodule ExWebRTC.DTLSTransport do def handle_cast({:send_rtp, data}, %{dtls_state: :connected, ice_state: ice_state} = state) when ice_state in [:connected, :completed] do case ExLibSRTP.protect(state.out_srtp, data) do - {:ok, protected} -> ICEAgent.send_data(state.ice_agent, protected) + {:ok, protected} -> state.ice_transport.send_data(state.ice_pid, protected) {:error, reason} -> Logger.error("Unable to protect RTP: #{inspect(reason)}") end @@ -160,7 +166,7 @@ defmodule ExWebRTC.DTLSTransport do ) do case ExDTLS.handle_timeout(state.dtls) do {:retransmit, packets, timeout} when ice_state in [:connected, :completed] -> - ICEAgent.send_data(state.ice_agent, packets) + state.ice_transport.send_data(state.ice_pid, packets) Process.send_after(self(), :dtls_timeout, timeout) {:retransmit, ^buffered_packets, timeout} -> @@ -199,7 +205,7 @@ defmodule ExWebRTC.DTLSTransport do # TODO: handle {:connection_closed, _} case ExDTLS.handle_data(state.dtls, data) do {:handshake_packets, packets, timeout} when state.ice_state in [:connected, :completed] -> - :ok = ICEAgent.send_data(state.ice_agent, packets) + :ok = state.ice_transport.send_data(state.ice_pid, packets) Process.send_after(self(), :dtls_timeout, timeout) update_dtls_state(state, :connecting) @@ -215,7 +221,7 @@ defmodule ExWebRTC.DTLSTransport do {:handshake_finished, lkm, rkm, profile, packets} -> Logger.debug("DTLS handshake finished") - ICEAgent.send_data(state.ice_agent, packets) + state.ice_transport.send_data(state.ice_pid, packets) peer_fingerprint = state.dtls @@ -275,7 +281,7 @@ defmodule ExWebRTC.DTLSTransport do if state.mode == :active do {packets, timeout} = ExDTLS.do_handshake(state.dtls) Process.send_after(self(), :dtls_timeout, timeout) - :ok = ICEAgent.send_data(state.ice_agent, packets) + :ok = state.ice_transport.send_data(state.ice_pid, packets) update_dtls_state(state, :connecting) else state @@ -286,7 +292,7 @@ defmodule ExWebRTC.DTLSTransport do when new_ice_state in [:connected, :completed] do if state.buffered_packets do Logger.debug("Sending buffered DTLS packets") - :ok = ICEAgent.send_data(state.ice_agent, state.buffered_packets) + :ok = state.ice_transport.send_data(state.ice_pid, state.buffered_packets) %{state | ice_state: new_ice_state, buffered_packets: nil} else state diff --git a/lib/ex_webrtc/ice_transport.ex b/lib/ex_webrtc/ice_transport.ex new file mode 100644 index 00000000..a8416de4 --- /dev/null +++ b/lib/ex_webrtc/ice_transport.ex @@ -0,0 +1,40 @@ +defmodule ExWebRTC.ICETransport do + @moduledoc false + + # module implementing this behaviour + @type t() :: module() + + @callback start_link(ExICE.ICEAgent.role(), Keyword.t()) :: {:ok, pid()} + @callback add_remote_candidate(pid(), candidate :: String.t()) :: :ok + @callback end_of_candidates(pid()) :: :ok + @callback gather_candidates(pid()) :: :ok + @callback get_local_credentials(pid()) :: {:ok, ufrag :: binary(), pwd :: binary()} + @callback restart(pid()) :: :ok + @callback send_data(pid(), binary()) :: :ok + @callback set_remote_credentials(pid(), ufrag :: binary(), pwd :: binary()) :: :ok +end + +defmodule ExWebRTC.DefaultICETransport do + @moduledoc false + + @behaviour ExWebRTC.ICETransport + + alias ExICE.ICEAgent + + @impl true + defdelegate add_remote_candidate(pid, candidate), to: ICEAgent + @impl true + defdelegate end_of_candidates(pid), to: ICEAgent + @impl true + defdelegate gather_candidates(pid), to: ICEAgent + @impl true + defdelegate get_local_credentials(pid), to: ICEAgent + @impl true + defdelegate restart(pid), to: ICEAgent + @impl true + defdelegate send_data(pid, data), to: ICEAgent + @impl true + defdelegate set_remote_credentials(pid, ufrag, pwd), to: ICEAgent + @impl true + defdelegate start_link(role, opts), to: ICEAgent +end diff --git a/lib/ex_webrtc/peer_connection.ex b/lib/ex_webrtc/peer_connection.ex index d1ad1a5d..2d6b3283 100644 --- a/lib/ex_webrtc/peer_connection.ex +++ b/lib/ex_webrtc/peer_connection.ex @@ -8,7 +8,6 @@ defmodule ExWebRTC.PeerConnection do require Logger alias __MODULE__.{Configuration, Demuxer} - alias ExICE.ICEAgent alias ExWebRTC.{ DTLSTransport, @@ -121,7 +120,7 @@ defmodule ExWebRTC.PeerConnection do def init({owner, config}) do ice_config = [stun_servers: config.ice_servers] {:ok, dtls_transport} = DTLSTransport.start_link(ice_config) - ice_agent = DTLSTransport.get_ice_agent(dtls_transport) + {ice_transport, ice_pid} = DTLSTransport.get_ice_transport(dtls_transport) state = %{ owner: owner, @@ -130,7 +129,8 @@ defmodule ExWebRTC.PeerConnection do pending_local_desc: nil, current_remote_desc: nil, pending_remote_desc: nil, - ice_agent: ice_agent, + ice_transport: ice_transport, + ice_pid: ice_pid, dtls_transport: dtls_transport, demuxer: %Demuxer{}, transceivers: [], @@ -159,13 +159,14 @@ defmodule ExWebRTC.PeerConnection do # TODO: handle subsequent offers if Keyword.get(options, :ice_restart, false) do - :ok = ICEAgent.restart(state.ice_agent) + :ok = state.ice_transport.restart(state.ice_pid) end next_mid = find_next_mid(state) transceivers = assign_mids(state.transceivers, next_mid) - {:ok, ice_ufrag, ice_pwd} = ICEAgent.get_local_credentials(state.ice_agent) + {:ok, ice_ufrag, ice_pwd} = + state.ice_transport.get_local_credentials(state.ice_pid) offer = %ExSDP{ExSDP.new() | timing: %ExSDP.Timing{start_time: 0, stop_time: 0}} @@ -219,7 +220,8 @@ defmodule ExWebRTC.PeerConnection do def handle_call({:create_answer, _options}, _from, state) do {:offer, remote_offer} = state.pending_remote_desc - {:ok, ice_ufrag, ice_pwd} = ICEAgent.get_local_credentials(state.ice_agent) + {:ok, ice_ufrag, ice_pwd} = + state.ice_transport.get_local_credentials(state.ice_pid) answer = %ExSDP{ExSDP.new() | timing: %ExSDP.Timing{start_time: 0, stop_time: 0}} @@ -315,7 +317,7 @@ defmodule ExWebRTC.PeerConnection do @impl true def handle_call({:add_ice_candidate, candidate}, _from, state) do with "candidate:" <> attr <- candidate.candidate do - ICEAgent.add_remote_candidate(state.ice_agent, attr) + state.ice_transport.add_remote_candidate(state.ice_pid, attr) end {:reply, :ok, state} @@ -496,8 +498,10 @@ defmodule ExWebRTC.PeerConnection do {:ok, {:fingerprint, {:sha256, peer_fingerprint}}} <- SDPUtils.get_cert_fingerprint(sdp), {:ok, new_transceivers} <- update_remote_transceivers(state.transceivers, sdp, state.config) do - :ok = ICEAgent.set_remote_credentials(state.ice_agent, ice_ufrag, ice_pwd) - :ok = ICEAgent.gather_candidates(state.ice_agent) + :ok = + state.ice_transport.set_remote_credentials(state.ice_pid, ice_ufrag, ice_pwd) + + :ok = state.ice_transport.gather_candidates(state.ice_pid) # TODO: this needs a look diff --git a/test/ex_webrtc/dtls_transport_test.exs b/test/ex_webrtc/dtls_transport_test.exs index 75f25251..8dba166d 100644 --- a/test/ex_webrtc/dtls_transport_test.exs +++ b/test/ex_webrtc/dtls_transport_test.exs @@ -9,20 +9,36 @@ defmodule ExWebRTC.DTLSTransportTest do |> ExDTLS.get_cert_fingerprint() |> Utils.hex_dump() - defmodule FakeICEAgent do + defmodule MockICETransport do + @behaviour ExWebRTC.ICETransport + use GenServer - def start_link(_mode, config) do - GenServer.start_link(__MODULE__, {self(), config}) - end + @impl true + def start_link(_mode, config), do: GenServer.start_link(__MODULE__, {self(), config}) - def send_data(ice_agent, data) do - GenServer.cast(ice_agent, {:send_data, data}) - end + @impl true + def send_data(ice_pid, data), do: GenServer.cast(ice_pid, {:send_data, data}) - def send_dtls(ice_agent, data) do - GenServer.cast(ice_agent, {:send_dtls, data}) - end + @impl true + def add_remote_candidate(ice_pid, _candidate), do: ice_pid + + @impl true + def end_of_candidates(ice_pid), do: ice_pid + + @impl true + def gather_candidates(ice_pid), do: ice_pid + + @impl true + def get_local_credentials(_state), do: {:ok, "testufrag", "testpwd"} + + @impl true + def restart(ice_pid), do: ice_pid + + @impl true + def set_remote_credentials(ice_pid, _ufrag, _pwd), do: ice_pid + + def send_dtls(ice_pid, data), do: GenServer.cast(ice_pid, {:send_dtls, data}) @impl true def init({dtls, tester: tester}), @@ -42,21 +58,20 @@ defmodule ExWebRTC.DTLSTransportTest do end setup do - assert {:ok, dtls} = DTLSTransport.start_link([tester: self()], FakeICEAgent) + assert {:ok, dtls} = DTLSTransport.start_link(MockICETransport, tester: self()) assert_receive {:dtls_transport, ^dtls, {:state_change, :new}} - ice = DTLSTransport.get_ice_agent(dtls) - assert is_pid(ice) + {ice_transport, ice_pid} = DTLSTransport.get_ice_transport(dtls) - %{dtls: dtls, ice: ice} + %{dtls: dtls, ice_transport: ice_transport, ice_pid: ice_pid} end - test "forwards non-data ICE messages", %{ice: ice} do + test "forwards non-data ICE messages", %{ice_transport: ice_transport, ice_pid: ice_pid} do message = :connected - FakeICEAgent.send_dtls(ice, message) + ice_transport.send_dtls(ice_pid, message) assert_receive {:ex_ice, _from, ^message} - FakeICEAgent.send_dtls(ice, {:data, <<1, 2, 3>>}) + ice_transport.send_dtls(ice_pid, {:data, <<1, 2, 3>>}) refute_receive {:ex_ice, _from, _msg} end @@ -71,27 +86,39 @@ defmodule ExWebRTC.DTLSTransportTest do assert {:error, :already_started} = DTLSTransport.start_dtls(dtls, :passive, @fingerprint) end - test "initiates DTLS handshake when in active mode", %{dtls: dtls, ice: ice} do + test "initiates DTLS handshake when in active mode", %{ + dtls: dtls, + ice_transport: ice_transport, + ice_pid: ice_pid + } do :ok = DTLSTransport.start_dtls(dtls, :active, @fingerprint) - FakeICEAgent.send_dtls(ice, {:connection_state_change, :connected}) + ice_transport.send_dtls(ice_pid, {:connection_state_change, :connected}) assert_receive {:fake_ice, packets} assert is_binary(packets) end - test "won't initiate DTLS handshake when in passive mode", %{dtls: dtls, ice: ice} do + test "won't initiate DTLS handshake when in passive mode", %{ + dtls: dtls, + ice_transport: ice_transport, + ice_pid: ice_pid + } do :ok = DTLSTransport.start_dtls(dtls, :passive, @fingerprint) - FakeICEAgent.send_dtls(ice, {:connection_state_change, :connected}) + ice_transport.send_dtls(ice_pid, {:connection_state_change, :connected}) refute_receive({:fake_ice, _msg}) end - test "will retransmit after initiating handshake", %{dtls: dtls, ice: ice} do + test "will retransmit after initiating handshake", %{ + dtls: dtls, + ice_transport: ice_transport, + ice_pid: ice_pid + } do :ok = DTLSTransport.start_dtls(dtls, :active, @fingerprint) - FakeICEAgent.send_dtls(ice, {:connection_state_change, :connected}) + ice_transport.send_dtls(ice_pid, {:connection_state_change, :connected}) assert_receive {:fake_ice, _packets} @@ -99,32 +126,44 @@ defmodule ExWebRTC.DTLSTransportTest do 1000 + ExUnit.configuration()[:assert_receive_timeout] end - test "will buffer packets and send when connected", %{dtls: dtls, ice: ice} do + test "will buffer packets and send when connected", %{ + dtls: dtls, + ice_transport: ice_transport, + ice_pid: ice_pid + } do :ok = DTLSTransport.start_dtls(dtls, :passive, @fingerprint) remote_dtls = ExDTLS.init(mode: :client, dtls_srtp: true) {packets, _timeout} = ExDTLS.do_handshake(remote_dtls) - FakeICEAgent.send_dtls(ice, {:data, packets}) + ice_transport.send_dtls(ice_pid, {:data, packets}) refute_receive {:fake_ice, _packets} - FakeICEAgent.send_dtls(ice, {:connection_state_change, :connected}) + ice_transport.send_dtls(ice_pid, {:connection_state_change, :connected}) assert_receive {:fake_ice, packets} assert is_binary(packets) end - test "finishes handshake in active mode", %{dtls: dtls, ice: ice} do + test "finishes handshake in active mode", %{ + dtls: dtls, + ice_transport: ice_transport, + ice_pid: ice_pid + } do :ok = DTLSTransport.start_dtls(dtls, :active, @fingerprint) remote_dtls = ExDTLS.init(mode: :server, dtls_srtp: true) - FakeICEAgent.send_dtls(ice, {:connection_state_change, :connected}) + ice_transport.send_dtls(ice_pid, {:connection_state_change, :connected}) - assert :ok = check_handshake(dtls, ice, remote_dtls) + assert :ok = check_handshake(dtls, ice_transport, ice_pid, remote_dtls) assert_receive {:dtls_transport, ^dtls, {:state_change, :connecting}} assert_receive {:dtls_transport, ^dtls, {:state_change, :connected}} end - test "finishes handshake in passive mode", %{dtls: dtls, ice: ice} do + test "finishes handshake in passive mode", %{ + dtls: dtls, + ice_transport: ice_transport, + ice_pid: ice_pid + } do remote_dtls = ExDTLS.init(mode: :client, dtls_srtp: true) remote_fingerprint = @@ -136,25 +175,25 @@ defmodule ExWebRTC.DTLSTransportTest do :ok = DTLSTransport.start_dtls(dtls, :passive, remote_fingerprint) {packets, _timeout} = ExDTLS.do_handshake(remote_dtls) - FakeICEAgent.send_dtls(ice, {:connection_state_change, :connected}) + ice_transport.send_dtls(ice_pid, {:connection_state_change, :connected}) - FakeICEAgent.send_dtls(ice, {:data, packets}) + ice_transport.send_dtls(ice_pid, {:data, packets}) - assert :ok == check_handshake(dtls, ice, remote_dtls) + assert :ok == check_handshake(dtls, ice_transport, ice_pid, remote_dtls) assert_receive {:dtls_transport, ^dtls, {:state_change, :connecting}} assert_receive {:dtls_transport, ^dtls, {:state_change, :connected}} end - defp check_handshake(dtls, ice, remote_dtls) do + defp check_handshake(dtls, ice_transport, ice_pid, remote_dtls) do assert_receive {:fake_ice, packets} case ExDTLS.handle_data(remote_dtls, packets) do {:handshake_packets, packets, _timeout} -> - FakeICEAgent.send_dtls(ice, {:data, packets}) - check_handshake(dtls, ice, remote_dtls) + ice_transport.send_dtls(ice_pid, {:data, packets}) + check_handshake(dtls, ice_transport, ice_pid, remote_dtls) {:handshake_finished, _, _, _, packets} -> - FakeICEAgent.send_dtls(ice, {:data, packets}) + ice_transport.send_dtls(ice_pid, {:data, packets}) :ok {:handshake_finished, _, _, _} -> diff --git a/test/test_helper.exs b/test/test_helper.exs index bcd608a3..7db342e3 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -1 +1 @@ -ExUnit.start(capture_log: true, assert_receive_timeout: 300) +ExUnit.start(capture_log: true, assert_receive_timeout: 400)