From 8478e6815a06039a7023730f9e38d71e3cf22b50 Mon Sep 17 00:00:00 2001 From: firest Date: Fri, 12 Jul 2024 23:11:12 +0800 Subject: [PATCH] fix: refactor the udp proxy --- include/esockd_proxy.hrl | 1 + src/esockd_udp.erl | 10 ++ src/udp_proxy/esockd_udp_proxy.erl | 129 +++++++++++------- src/udp_proxy/esockd_udp_proxy_connection.erl | 9 +- src/udp_proxy/esockd_udp_proxy_db.erl | 60 +++----- 5 files changed, 115 insertions(+), 94 deletions(-) diff --git a/include/esockd_proxy.hrl b/include/esockd_proxy.hrl index 4926957..d0a910f 100644 --- a/include/esockd_proxy.hrl +++ b/include/esockd_proxy.hrl @@ -48,6 +48,7 @@ -type get_connection_id_result() :: %% send decoded packet {ok, connection_id(), connection_packet(), connection_state()} + | {error, binary()} | invalid. -type connection_options() :: #{ diff --git a/src/esockd_udp.erl b/src/esockd_udp.erl index fa96bb0..9fd2b3d 100644 --- a/src/esockd_udp.erl +++ b/src/esockd_udp.erl @@ -53,6 +53,8 @@ , code_change/3 ]). +-export([proxy_request/1]). + -type(maybe(T) :: undefined | T). -record(state, { @@ -106,6 +108,10 @@ count_peers(Pid) -> -spec(stop(pid()) -> ok). stop(Pid) -> gen_server:stop(Pid). +proxy_request(Fun) -> + Parent = gen:get_parent(), + gen_server:call(Parent, {?FUNCTION_NAME, Fun}, infinity). + %%-------------------------------------------------------------------- %% GET/SET APIs %%-------------------------------------------------------------------- @@ -233,6 +239,10 @@ handle_call(which_children, _From, State = #state{peers = Peers, mfa = {Mod, _Fu {reply, [{undefined, Pid, worker, [Mod]} || Pid <- maps:keys(Peers), is_pid(Pid), erlang:is_process_alive(Pid)], State}; +handle_call({proxy_request, Fun}, _From, State) -> + Result = Fun(), + {reply, Result, State}; + handle_call(Req, _From, State) -> ?ERROR_MSG("Unexpected call: ~p", [Req]), {reply, ignore, State}. diff --git a/src/udp_proxy/esockd_udp_proxy.erl b/src/udp_proxy/esockd_udp_proxy.erl index 48bf15f..f493362 100644 --- a/src/udp_proxy/esockd_udp_proxy.erl +++ b/src/udp_proxy/esockd_udp_proxy.erl @@ -21,7 +21,7 @@ -include("include/esockd_proxy.hrl"). %% API --export([start_link/3, send/2, close/1]). +-export([start_link/3, send/2, close/1, takeover/2]). %% gen_server callbacks -export([ @@ -50,6 +50,8 @@ connection_mod := connection_module(), connection_id := connection_id() | undefined, connection_state := connection_state(), + connection_pid := pid() | undefined, + connection_ref := reference() | undefined, connection_options := connection_options(), %% last source's connection active time last_time := pos_integer(), @@ -76,6 +78,10 @@ close(ProxyId) -> ok end. +takeover(ProxyId, CId) -> + _ = gen_server:cast(ProxyId, {?FUNCTION_NAME, CId}), + ok. + %%-------------------------------------------------------------------- %%- gen_server callbacks %%-------------------------------------------------------------------- @@ -88,7 +94,9 @@ init([Transport, Peer, #{esockd_proxy_opts := Opts} = COpts]) -> connection_mod => Mod, connection_options => COpts, connection_state => esockd_udp_proxy_connection:initialize(Mod, COpts), - connection_id => undefined + connection_id => undefined, + connection_pid => undefined, + connection_ref => undefined }). handle_call(close, _From, State) -> @@ -105,14 +113,18 @@ handle_cast({send, Data}, #{transport := Transport, peer := Peer} = State) -> ?ERROR_MSG("Send failed, Reason: ~0p", [Reason]), {stop, {sock_error, Reason}, State} end; +handle_cast({takeover, CId}, #{connection_id := CId} = State) -> + {stop, {shutdown, takeover}, State}; +handle_cast({takeover, _CId}, State) -> + {noreply, State}; handle_cast(Request, State) -> ?ERROR_MSG("Unexpected cast: ~p", [Request]), {noreply, State}. handle_info({datagram, _SockPid, Data}, State) -> - {noreply, handle_incoming(Data, State)}; + handle_incoming(Data, State); handle_info({ssl, _Socket, Data}, State) -> - {noreply, handle_incoming(Data, State)}; + handle_incoming(Data, State); handle_info({heartbeat, Span}, #{last_time := LastTime} = State) -> Now = ?NOW, case Now - LastTime > Span of @@ -127,7 +139,7 @@ handle_info({ssl_error, _Sock, Reason}, State) -> handle_info({ssl_closed, _Sock}, State) -> {stop, ssl_closed, socket_exit(State)}; handle_info( - {'DOWN', _, process, _, _Reason}, + {'DOWN', _, process, _Pid, _Reason}, State ) -> {stop, {shutdown, connection_closed}, State}; @@ -143,6 +155,8 @@ terminate(Reason, #{transport := Transport} = State) -> false; connection_closed -> false; + takeover -> + false; _ -> true end, @@ -151,7 +165,7 @@ terminate(Reason, #{transport := Transport} = State) -> %%-------------------------------------------------------------------- %%- Internal functions %%-------------------------------------------------------------------- --spec handle_incoming(socket_packet(), state()) -> state(). +-spec handle_incoming(socket_packet(), state()) -> _. handle_incoming( Data, #{transport := Transport, peer := Peer, connection_mod := Mod, connection_state := CState} = @@ -161,11 +175,17 @@ handle_incoming( case esockd_udp_proxy_connection:get_connection_id(Mod, Transport, Peer, CState, Data) of {ok, CId, Packet, CState2} -> dispatch(Mod, CId, Data, Packet, State2#{connection_state := CState2}); + {error, Reply} -> + ?ERROR_MSG("Can't get connection id, Transport:~0p, Peer:~0p, Mod:~0p", [ + Transport, Peer, Mod + ]), + _ = send(Transport, Peer, Reply), + {stop, {shutdown, no_clientid}, State2}; invalid -> ?ERROR_MSG("Can't get connection id, Transport:~0p, Peer:~0p, Mod:~0p", [ Transport, Peer, Mod ]), - State2 + {stop, {shutdown, no_clientid}, State2} end. -spec dispatch( @@ -174,8 +194,7 @@ handle_incoming( connection_id(), connection_packet(), state() -) -> - state(). +) -> _. dispatch( Mod, CId, @@ -183,43 +202,54 @@ dispatch( Packet, #{ transport := Transport, - peer := Peer, - connection_state := CState, - connection_options := Opts + connection_state := CState } = State ) -> - case lookup(Mod, Transport, Peer, CId, Opts) of + case lookup(CId, State) of {ok, Pid} -> + Result = attach(CId, State, Pid), esockd_udp_proxy_connection:dispatch( Mod, Pid, CState, {Transport, Data, Packet} ), - attach(CId, State); + {noreply, Result}; {error, Reason} -> ?ERROR_MSG("Dispatch failed, Reason:~0p", [Reason]), - State + {noreply, State} end. --spec attach(connection_id(), state()) -> state(). -attach(CId, #{connection_mod := Mod, connection_id := undefined} = State) -> +-spec attach(connection_id(), state(), pid()) -> state(). +attach(CId, #{connection_mod := Mod, connection_id := undefined} = State, Pid) -> esockd_udp_proxy_db:attach(Mod, CId), - State#{connection_id := CId}; -attach(CId, #{connection_id := OldId} = State) when CId =/= OldId -> - State2 = detach(State), - attach(CId, State2); -attach(_CId, State) -> + Ref = erlang:monitor(process, Pid), + State#{connection_id := CId, connection_pid := Pid, connection_ref := Ref}; +attach(CId, #{connection_id := OldId} = State, Pid) when CId =/= OldId -> + State2 = detach(State, false), + attach(CId, State2, Pid); +attach(_CId, State, _Pid) -> State. --spec detach(state()) -> state(). detach(State) -> detach(State, true). --spec detach(state(), boolean()) -> state(). +-spec detach(state()) -> state(). detach(#{connection_id := undefined} = State, _Clear) -> State; -detach(#{connection_id := CId, connection_mod := Mod, connection_state := CState} = State, Clear) -> - case esockd_udp_proxy_db:detach(Mod, CId) of - {Clear, Pid} -> +detach( + #{ + connection_id := CId, + connection_pid := Pid, + connection_ref := Ref, + connection_mod := Mod, + connection_state := CState + } = State, + Clear +) -> + erlang:demonitor(Ref), + + Result = esockd_udp_proxy_db:detach(Mod, CId), + case Clear andalso Result of + true -> case erlang:is_process_alive(Pid) of true -> esockd_udp_proxy_connection:close(Mod, Pid, CState); @@ -229,7 +259,7 @@ detach(#{connection_id := CId, connection_mod := Mod, connection_state := CState _ -> ok end, - State#{connection_id := undefined}. + State#{connection_id := undefined, connection_pid := undefined, connection_ref := undefined}. -spec socket_exit(state()) -> state(). socket_exit(State) -> @@ -240,28 +270,27 @@ heartbeat(Span) -> erlang:send_after(timer:seconds(Span), self(), {?FUNCTION_NAME, Span}), ok. --spec lookup( - connection_module(), - proxy_transport(), - peer(), - connection_id(), - connection_options() -) -> {ok, pid()} | {error, Reason :: term()}. -lookup(Mod, Transport, Peer, CId, Opts) -> - case esockd_udp_proxy_db:lookup(Mod, CId) of - {ok, _} = Ok -> - Ok; - undefined -> - case esockd_udp_proxy_connection:create(Mod, Transport, Peer, Opts) of - {ok, Pid} -> - esockd_udp_proxy_db:insert(Mod, CId, Pid), - _ = erlang:monitor(process, Pid), - {ok, Pid}; - ignore -> - {error, ignore}; - Error -> - Error - end +-spec lookup(connection_id(), state()) -> {ok, pid()} | {error, Reason :: term()}. +lookup(_CId, #{connection_pid := Pid}) when is_pid(Pid) -> + {ok, Pid}; +lookup(CId, #{ + connection_pid := undefined, + connection_mod := Mod, + transport := Transport, + peer := Peer, + connection_options := Opts +}) -> + %% TODO: use proc_lib:start_link to instead of this call + Fun = fun() -> + esockd_udp_proxy_connection:find_or_create(Mod, CId, Transport, Peer, Opts) + end, + case esockd_udp:proxy_request(Fun) of + {ok, Pid} -> + {ok, Pid}; + ignore -> + {error, ignore}; + Error -> + Error end. -spec send(proxy_transport(), peer(), binary()) -> _. diff --git a/src/udp_proxy/esockd_udp_proxy_connection.erl b/src/udp_proxy/esockd_udp_proxy_connection.erl index e53956b..67d31e8 100644 --- a/src/udp_proxy/esockd_udp_proxy_connection.erl +++ b/src/udp_proxy/esockd_udp_proxy_connection.erl @@ -20,7 +20,7 @@ -export([ initialize/2, - create/4, + find_or_create/5, get_connection_id/5, dispatch/4, close/3 @@ -34,7 +34,8 @@ -callback initialize(connection_options()) -> connection_state(). %% Create new connection --callback create(proxy_transport(), peer(), connection_options()) -> gen_server:start_ret(). +-callback find_or_create(connection_id(), proxy_transport(), peer(), connection_options()) -> + gen_server:start_ret(). %% Find routing information -callback get_connection_id( @@ -54,8 +55,8 @@ initialize(Mod, Opts) -> Mod:initialize(Opts). -create(Mod, Transport, Peer, Opts) -> - Mod:create(Transport, Peer, Opts). +find_or_create(Mod, CId, Transport, Peer, Opts) -> + Mod:find_or_create(CId, Transport, Peer, Opts). get_connection_id(Mod, Transport, Peer, State, Data) -> Mod:get_connection_id(Transport, Peer, State, Data). diff --git a/src/udp_proxy/esockd_udp_proxy_db.erl b/src/udp_proxy/esockd_udp_proxy_db.erl index d8a5dba..157b77e 100644 --- a/src/udp_proxy/esockd_udp_proxy_db.erl +++ b/src/udp_proxy/esockd_udp_proxy_db.erl @@ -23,10 +23,8 @@ %% API -export([ start_link/0, - insert/3, attach/2, - detach/2, - lookup/2 + detach/2 ]). %% gen_server callbacks @@ -42,14 +40,11 @@ -record(connection, { id :: ?ID(connection_module(), connection_id()), - %% the connection pid - pid :: pid(), %% Reference Counter - count :: non_neg_integer() + proxy :: pid() }). -define(TAB, esockd_udp_proxy_db). --define(MINIMUM_VAL, -2147483647). %%-------------------------------------------------------------------- %%- API @@ -58,42 +53,27 @@ start_link() -> gen_server:start_link({local, ?MODULE}, ?MODULE, [], []). --spec insert(connection_module(), connection_id(), pid()) -> boolean(). -insert(Mod, CId, Pid) -> - ets:insert_new(?TAB, #connection{ - id = ?ID(Mod, CId), - pid = Pid, - count = 0 - }). - --spec attach(connection_module(), connection_id()) -> integer(). +-spec attach(connection_module(), connection_id()) -> true. attach(Mod, CId) -> - ets:update_counter(?TAB, ?ID(Mod, CId), {#connection.count, 1}). - --spec detach(connection_module(), connection_id()) -> {Clear :: true, connection_state()} | false. + ID = ?ID(Mod, CId), + case ets:lookup(?TAB, ID) of + [] -> + ok; + [#connection{proxy = ProxyId}] -> + esockd_udp_proxy:takeover(ProxyId, CId) + end, + ets:insert(?TAB, #connection{id = ID, proxy = self()}). + +-spec detach(connection_module(), connection_id()) -> boolean(). detach(Mod, CId) -> - Id = ?ID(Mod, CId), - RC = ets:update_counter(?TAB, Id, {#connection.count, -1, 0, ?MINIMUM_VAL}), - if - RC < 0 -> - case ets:lookup(?TAB, Id) of - [#connection{pid = Pid}] -> - ets:delete(?TAB, Id), - {true, Pid}; - _ -> - false - end; - true -> - false - end. - --spec lookup(connection_module(), connection_id()) -> {ok, pid()} | undefined. -lookup(Mod, CId) -> - case ets:lookup(?TAB, ?ID(Mod, CId)) of - [#connection{pid = Pid}] -> - {ok, Pid}; + ProxyId = self(), + ID = ?ID(Mod, CId), + case ets:lookup(?TAB, ID) of + [#connection{proxy = ProxyId}] -> + ets:delete(?TAB, ID), + true; _ -> - undefined + false end. %%--------------------------------------------------------------------