diff --git a/application/stream.go b/application/stream.go index e6a730c..9940afa 100644 --- a/application/stream.go +++ b/application/stream.go @@ -66,6 +66,7 @@ type stream struct { mtx sync.RWMutex streamOK bool closeOnce *gsync.Once + finiOnce *gsync.Once // app layer messages // raw cache @@ -84,9 +85,6 @@ type stream struct { // io writeInCh chan packet.Packet // for multiple message types - - // close channel - closeCh chan struct{} } func newStream(end *End, cn conn.Conn, dg multiplexer.Dialogue, opts *opts) *stream { @@ -103,15 +101,16 @@ func newStream(end *End, cn conn.Conn, dg multiplexer.Dialogue, opts *opts) *str remoteRPCs: make(map[string]struct{}), streamOK: true, closeOnce: new(gsync.Once), + finiOnce: new(gsync.Once), messageCh: make(chan *packet.MessagePacket, 32), streamCh: make(chan *packet.StreamPacket, 32), failedCh: make(chan packet.Packet), dlReadChList: list.New(), dlWriteChList: list.New(), writeInCh: make(chan packet.Packet, 32), - closeCh: make(chan struct{}), } go sm.handlePkt() + go sm.readPkt() return sm } @@ -145,48 +144,56 @@ func (sm *stream) Side() geminio.Side { // main handle logic func (sm *stream) handlePkt() { - readInCh := sm.dg.ReadC() writeInCh := sm.writeInCh for { select { - case pkt, ok := <-readInCh: + case pkt, ok := <-writeInCh: if !ok { + // BUG! shoud never be here. goto FINI } - sm.log.Tracef("stream read in packet, clientID: %d, dialogueID: %d, packetID: %d, packetType: %s", + sm.log.Tracef("stream write in packet, clientID: %d, dialogueID: %d, packetID: %d, packetType: %s", sm.cn.ClientID(), sm.dg.DialogueID(), pkt.ID(), pkt.Type().String()) - ret := sm.handleIn(pkt) + ret := sm.handleOut(pkt) switch ret { case iodefine.IOSuccess: continue case iodefine.IODiscard: - sm.log.Infof("stream read in packet but buffer full and discard, clientID: %d, dialogueID: %d, packetID: %d, packetType: %s", - sm.cn.ClientID(), sm.dg.DialogueID(), pkt.ID(), pkt.Type().String()) + sm.shub.Error(pkt.ID(), io.ErrShortBuffer) case iodefine.IOErr: goto FINI } + } + } +FINI: + sm.finiOnce.Do(sm.fini) +} - case pkt, ok := <-writeInCh: +func (sm *stream) readPkt() { + readInCh := sm.dg.ReadC() + for { + select { + case pkt, ok := <-readInCh: if !ok { - // BUG! shoud never be here. goto FINI } - sm.log.Tracef("stream write in packet, clientID: %d, dialogueID: %d, packetID: %d, packetType: %s", + sm.log.Tracef("stream read in packet, clientID: %d, dialogueID: %d, packetID: %d, packetType: %s", sm.cn.ClientID(), sm.dg.DialogueID(), pkt.ID(), pkt.Type().String()) - ret := sm.handleOut(pkt) + ret := sm.handleIn(pkt) switch ret { case iodefine.IOSuccess: continue case iodefine.IODiscard: - sm.shub.Error(pkt.ID(), io.ErrShortBuffer) + sm.log.Infof("stream read in packet but buffer full and discard, clientID: %d, dialogueID: %d, packetID: %d, packetType: %s", + sm.cn.ClientID(), sm.dg.DialogueID(), pkt.ID(), pkt.Type().String()) case iodefine.IOErr: goto FINI } } } FINI: - sm.fini() + sm.finiOnce.Do(sm.fini) } func (sm *stream) handleIn(pkt packet.Packet) iodefine.IORet { @@ -233,13 +240,16 @@ func (sm *stream) handleOut(pkt packet.Packet) iodefine.IORet { func (sm *stream) handleInMessagePacket(pkt *packet.MessagePacket) iodefine.IORet { sm.log.Tracef("read message packet, clientID: %d, dialogueID: %d, packetID: %d, packetType: %s", sm.cn.ClientID(), sm.dg.DialogueID(), pkt.ID(), pkt.Type().String()) - // we don't want block here. + // we don't want block here, and also we don't want discard immediately. select { case sm.messageCh <- pkt: - default: - pkt := sm.pf.NewMessageAckPacketWithSessionID(sm.dg.DialogueID(), pkt.ID(), iodefine.ErrIOBufferFull) - sm.dg.WriteWait(pkt) - return iodefine.IODiscard + /* + // TODO optimize it + default: + pkt := sm.pf.NewMessageAckPacketWithSessionID(sm.dg.DialogueID(), pkt.ID(), iodefine.ErrIOBufferFull) + sm.dg.WriteWait(pkt) + return iodefine.IODiscard + */ } return iodefine.IOSuccess } @@ -546,19 +556,18 @@ func (sm *stream) fini() { sm.cn.ClientID(), sm.dg.DialogueID()) sm.mtx.Lock() + defer sm.mtx.Unlock() // collect shub, and all syncs will be close notified - sm.shub.Close() - sm.shub = nil sm.streamOK = false close(sm.writeInCh) - sm.mtx.Unlock() - for range sm.writeInCh { - // TODO we should care about msg in writeInCh buffer, it may contains message, request... + for pkt := range sm.writeInCh { + sm.shub.Error(pkt.ID(), io.ErrClosedPipe) } - // collect channels sm.writeInCh = nil + sm.shub.Close() + sm.shub = nil // the outside should care about message and stream channel status close(sm.messageCh) @@ -570,9 +579,6 @@ func (sm *stream) fini() { } sm.tmr = nil - // collect close - close(sm.closeCh) - if sm.dg.DialogueID() == 1 { // the master stream sm.end.fini() diff --git a/client/end_options.go b/client/end_options.go index 8729ed2..0653972 100644 --- a/client/end_options.go +++ b/client/end_options.go @@ -114,10 +114,10 @@ func MergeEndOptions(opts ...*EndOptions) *EndOptions { if opt.LocalMethods != nil { eo.LocalMethods = opt.LocalMethods } - if opt.ReadBufferSize != -1 { + if opt.ReadBufferSize > 0 { eo.ReadBufferSize = opt.ReadBufferSize } - if opt.WriteBufferSize != -1 { + if opt.WriteBufferSize > 0 { eo.WriteBufferSize = opt.WriteBufferSize } } diff --git a/client/end_retry_options.go b/client/end_retry_options.go deleted file mode 100644 index befa672..0000000 --- a/client/end_retry_options.go +++ /dev/null @@ -1,73 +0,0 @@ -package client - -import ( - "github.com/jumboframes/armorigo/log" - "github.com/singchia/geminio/packet" - "github.com/singchia/geminio/pkg/id" -) - -type RetryEndOptions struct { - *EndOptions -} - -func NewRetryEndOptions() *RetryEndOptions { - return &RetryEndOptions{ - EndOptions: &EndOptions{}, - } -} - -func MergeRetryEndOptions(opts ...*RetryEndOptions) *RetryEndOptions { - eo := &RetryEndOptions{ - EndOptions: &EndOptions{ - ReadBufferSize: -1, - WriteBufferSize: -1, - }, - } - for _, opt := range opts { - if opt == nil { - continue - } - eo.RemoteMethodCheck = opt.RemoteMethodCheck - if opt.Timer != nil { - eo.Timer = opt.Timer - eo.TimerOwner = opt.TimerOwner - } - if opt.PacketFactory != nil { - eo.PacketFactory = opt.PacketFactory - } - if opt.Log != nil { - eo.Log = opt.Log - } - if opt.Delegate != nil { - eo.Delegate = opt.Delegate - } - if opt.Meta != nil { - eo.Meta = opt.Meta - } - if opt.ClientID != nil { - eo.ClientID = opt.ClientID - } - if opt.RemoteMethods != nil { - eo.RemoteMethods = opt.RemoteMethods - } - if opt.LocalMethods != nil { - eo.LocalMethods = opt.LocalMethods - } - if opt.ReadBufferSize != -1 { - eo.ReadBufferSize = opt.ReadBufferSize - } - if opt.WriteBufferSize != -1 { - eo.WriteBufferSize = opt.WriteBufferSize - } - } - return eo -} - -func initRetryEndOptions(eo *RetryEndOptions) { - if eo.Log == nil { - eo.Log = log.DefaultLog - } - if eo.PacketFactory == nil { - eo.PacketFactory = packet.NewPacketFactory(id.NewIDCounter(id.Odd)) - } -} diff --git a/conn/conn_base.go b/conn/conn_base.go index 3a74864..8f4ff4e 100644 --- a/conn/conn_base.go +++ b/conn/conn_base.go @@ -116,6 +116,7 @@ func (bc *baseConn) writePkt() { record := !packet.ConnLayer(pkt) err = bc.dowritePkt(pkt, record) if err != nil { + // write to net Conn error, we should close the layer return } } diff --git a/conn/conn_client.go b/conn/conn_client.go index 494d2e0..8bd5560 100644 --- a/conn/conn_client.go +++ b/conn/conn_client.go @@ -74,10 +74,10 @@ func OptionClientConnClientID(clientID uint64) ClientConnOption { func OptionClientConnBufferSize(read, write int) ClientConnOption { return func(cc *ClientConn) error { - if read != -1 { + if read > 0 { cc.readOutSize = read } - if write != -1 { + if write > 0 { cc.writeInSize = write } return nil @@ -105,14 +105,14 @@ func newClientConn(netconn net.Conn, opts ...ClientConnOption) (*ClientConn, err heartbeat: packet.Heartbeat20, meta: []byte{}, }, - netconn: netconn, - fsm: yafsm.NewFSM(), - side: geminio.InitiatorSide, - connOK: true, - readInCh: make(chan packet.Packet, 16), - writeOutCh: make(chan packet.Packet, 16), - readOutCh: make(chan packet.Packet, 16), - writeInCh: make(chan packet.Packet, 16), + netconn: netconn, + fsm: yafsm.NewFSM(), + side: geminio.InitiatorSide, + connOK: true, + readInSize: 32, + writeOutSize: 32, + readOutSize: 32, + writeInSize: 32, }, //finiOnce: new(sync.Once), closeOnce: new(sync.Once), @@ -125,6 +125,11 @@ func newClientConn(netconn net.Conn, opts ...ClientConnOption) (*ClientConn, err return nil, err } } + // io size + cc.readInCh = make(chan packet.Packet, cc.readInSize) + cc.writeOutCh = make(chan packet.Packet, cc.writeOutSize) + cc.readOutCh = make(chan packet.Packet, cc.readOutSize) + cc.writeInCh = make(chan packet.Packet, cc.writeInSize) // timer if !cc.tmrOutside { cc.tmr = timer.NewTimer() diff --git a/conn/conn_server.go b/conn/conn_server.go index 941f68e..c1b9474 100644 --- a/conn/conn_server.go +++ b/conn/conn_server.go @@ -64,13 +64,11 @@ func OptionServerConnFailedPacket(ch chan packet.Packet) ServerConnOption { func OptionServerConnBufferSize(read, write int) ServerConnOption { return func(sc *ServerConn) { - if read != -1 { + if read > 0 { sc.readOutSize = read - sc.readInSize = read } - if write != -1 { + if write > 0 { sc.writeInSize = write - sc.writeOutSize = write } } } @@ -275,8 +273,6 @@ func (sc *ServerConn) handleOut(pkt packet.Packet) iodefine.IORet { return sc.handleOutDisConnPacket(realPkt) case *packet.DisConnAckPacket: return sc.handleOutDisConnAckPacket(realPkt) - case *packet.HeartbeatAckPacket: - return sc.handleOutHeartbeatAckPacket(realPkt) default: return sc.handleOutDataPacket(pkt) } @@ -350,10 +346,12 @@ func (sc *ServerConn) handleInHeartbeatPacket(pkt *packet.HeartbeatPacket) iodef sc.hbTick = sc.tmr.Add(time.Duration(sc.heartbeat)*2*time.Second, timer.WithHandler(sc.waitHBTimeout)) retPkt := sc.pf.NewHeartbeatAckPacket(pkt.PacketID) - sc.writeInCh <- retPkt + sc.writeOutCh <- retPkt if sc.dlgt != nil { sc.dlgt.Heartbeat(sc) } + sc.log.Debugf("send heartbeat ack succeed, clientID: %d, PacketID: %d, packetType: %s", + sc.clientID, pkt.ID(), pkt.Type().String()) return iodefine.IOSuccess } @@ -397,13 +395,6 @@ func (sc *ServerConn) handleOutConnAckPacket(pkt *packet.ConnAckPacket) iodefine return iodefine.IOSuccess } -func (sc *ServerConn) handleOutHeartbeatAckPacket(pkt *packet.HeartbeatAckPacket) iodefine.IORet { - sc.writeOutCh <- pkt - sc.log.Debugf("send heartbeat ack succeed, clientID: %d, PacketID: %d, packetType: %s", - sc.clientID, pkt.ID(), pkt.Type().String()) - return iodefine.IOSuccess -} - func (sc *ServerConn) Close() { sc.closeOnce.Do(func() { sc.connMtx.RLock() diff --git a/multiplexer/dialogue.go b/multiplexer/dialogue.go index e6cec86..440ce0e 100644 --- a/multiplexer/dialogue.go +++ b/multiplexer/dialogue.go @@ -1,7 +1,6 @@ package multiplexer import ( - "fmt" "io" "sync" "time" @@ -127,10 +126,10 @@ func OptionDialogueNegotiatingID(negotiatingID uint64, dialogueIDPeersCall bool) func OptionDialogueBufferSize(read, write int) DialogueOption { return func(dg *dialogue) { - if read != -1 { + if read > 0 { dg.readOutSize = read } - if write != -1 { + if write > 0 { dg.writeInSize = write } } @@ -214,8 +213,11 @@ func (dg *dialogue) Write(pkt packet.Packet) error { pkt.(packet.SessionAbove).SetSessionID(dg.dialogueID) select { case dg.writeInCh <- pkt: - default: - return fmt.Errorf("%s, len: %d", io.ErrShortBuffer, len(dg.writeInCh)) + /* + // TODO optimize it + default: + return fmt.Errorf("%s, len: %d", io.ErrShortBuffer, len(dg.writeInCh)) + */ } return nil } diff --git a/multiplexer/dialogue_mgr.go b/multiplexer/dialogue_mgr.go index 46b04d2..677ae3c 100644 --- a/multiplexer/dialogue_mgr.go +++ b/multiplexer/dialogue_mgr.go @@ -121,10 +121,10 @@ func OptionTimer(tmr timer.Timer) MultiplexerOption { func OptionBufferSize(read, write int) MultiplexerOption { return func(opts *multiplexerOpts) { - if read != -1 { + if read > 0 { opts.readBufferSize = read } - if write != -1 { + if write > 0 { opts.writeBufferSize = write } } diff --git a/server/end_options.go b/server/end_options.go index ad308b5..6b6be1a 100644 --- a/server/end_options.go +++ b/server/end_options.go @@ -116,10 +116,10 @@ func MergeEndOptions(opts ...*EndOptions) *EndOptions { if opt.ClosedStreamFunc != nil { eo.ClosedStreamFunc = opt.ClosedStreamFunc } - if opt.ReadBufferSize != -1 { + if opt.ReadBufferSize > 0 { eo.ReadBufferSize = opt.ReadBufferSize } - if opt.WriteBufferSize != -1 { + if opt.WriteBufferSize > 0 { eo.WriteBufferSize = opt.WriteBufferSize } }