diff --git a/protocol/czar/mux.go b/protocol/czar/mux.go index 12001c7..92a9d65 100644 --- a/protocol/czar/mux.go +++ b/protocol/czar/mux.go @@ -18,7 +18,6 @@ type Stream struct { rch chan []byte rer *Err wer *Err - wmu sync.Mutex zo0 sync.Once zo1 sync.Once } @@ -28,9 +27,10 @@ func (s *Stream) Close() error { s.rer.Put(io.ErrClosedPipe) s.wer.Put(io.ErrClosedPipe) s.zo0.Do(func() { - s.wmu.Lock() - s.mux.Write(0, []byte{s.idx, 0x02, 0x00, 0x00}) - s.wmu.Unlock() + s.mux.pri.H(func() error { + s.mux.con.Write([]byte{s.idx, 0x02, 0x00, 0x00}) + return nil + }) }) return nil } @@ -40,9 +40,10 @@ func (s *Stream) Esolc() error { s.rer.Put(io.EOF) s.wer.Put(io.ErrClosedPipe) s.zo0.Do(func() { - s.wmu.Lock() - s.mux.Write(0, []byte{s.idx, 0x02, 0x01, 0x00}) - s.wmu.Unlock() + s.mux.pri.H(func() error { + s.mux.con.Write([]byte{s.idx, 0x02, 0x01, 0x00}) + return nil + }) }) s.zo1.Do(func() { s.idp.Put(s.idx) @@ -98,18 +99,20 @@ func (s *Stream) Write(p []byte) (int, error) { binary.BigEndian.PutUint16(b[2:4], uint16(l)) copy(b[4:], p[:l]) p = p[l:] - s.wmu.Lock() - if err := s.wer.Get(); err != nil { - s.wmu.Unlock() - return n, err - } - _, err := s.mux.Write(1, b[:4+l]) + err := s.mux.pri.M(func() error { + if err := s.wer.Get(); err != nil { + return err + } + _, err := s.mux.con.Write(b[:4+l]) + if err != nil { + s.wer.Put(err) + return err + } + return nil + }) if err != nil { - s.wer.Put(err) - s.wmu.Unlock() return n, err } - s.wmu.Unlock() n += l } } @@ -124,7 +127,6 @@ func NewStream(idx uint8, mux *Mux) *Stream { rch: make(chan []byte, 32), rer: NewErr(), wer: NewErr(), - wmu: sync.Mutex{}, zo0: sync.Once{}, zo1: sync.Once{}, } @@ -144,10 +146,9 @@ type Mux struct { ach chan *Stream con net.Conn idp *Sip + pri *Priority rer *Err usb []*Stream - wm0 sync.Mutex - wm1 sync.Mutex } // Accept is used to block until the next available stream is ready to be accepted. @@ -163,17 +164,23 @@ func (m *Mux) Close() error { // Open is used to create a new stream as a net.Conn. func (m *Mux) Open() (*Stream, error) { - idx, err := m.idp.Get() + var ( + err error + idx uint8 + stm *Stream + ) + idx, err = m.idp.Get() if err != nil { return nil, err } - cnt, err := m.Write(0, []byte{idx, 0x00, 0x00, 0x00}) + err = m.pri.H(func() error { + return doa.Err(m.con.Write([]byte{idx, 0x00, 0x00, 0x00})) + }) if err != nil { m.idp.Put(idx) return nil, err } - doa.Doa(cnt == 4) - stm := NewStream(idx, m) + stm = NewStream(idx, m) stm.idp = m.idp m.usb[idx] = stm return stm, nil @@ -242,30 +249,15 @@ func (m *Mux) Recv() { close(m.ach) } -// Write writes data to the connection. The code implements a simple priority write using two locks. -func (m *Mux) Write(priority int, b []byte) (int, error) { - if priority >= 1 { - m.wm1.Lock() - defer m.wm1.Unlock() - } - if priority >= 0 { - m.wm0.Lock() - defer m.wm0.Unlock() - } - n, err := m.con.Write(b) - return n, err -} - // NewMux returns a new Mux. func NewMux(conn net.Conn) *Mux { mux := &Mux{ ach: make(chan *Stream), con: conn, idp: nil, + pri: NewPriority(), rer: NewErr(), usb: make([]*Stream, 256), - wm0: sync.Mutex{}, - wm1: sync.Mutex{}, } return mux } diff --git a/protocol/czar/priority.go b/protocol/czar/priority.go new file mode 100644 index 0000000..ed776ce --- /dev/null +++ b/protocol/czar/priority.go @@ -0,0 +1,48 @@ +package czar + +import ( + "sync" +) + +// Priority implement a lock with three priorities. +type Priority struct { + l sync.Mutex + m sync.Mutex + h sync.Mutex +} + +// H executes function f with 0 priority. +func (p *Priority) H(f func() error) error { + p.h.Lock() + defer p.h.Unlock() + return f() +} + +// H executes function f with 1 priority. +func (p *Priority) M(f func() error) error { + p.m.Lock() + defer p.m.Unlock() + p.h.Lock() + defer p.h.Unlock() + return f() +} + +// H executes function f with 2 priority. +func (p *Priority) L(f func() error) error { + p.l.Lock() + defer p.l.Unlock() + p.m.Lock() + defer p.m.Unlock() + p.h.Lock() + defer p.h.Unlock() + return f() +} + +// NewPriority returns a new Priority. +func NewPriority() *Priority { + return &Priority{ + l: sync.Mutex{}, + m: sync.Mutex{}, + h: sync.Mutex{}, + } +} diff --git a/protocol/czar/priority_test.go b/protocol/czar/priority_test.go new file mode 100644 index 0000000..4821fb3 --- /dev/null +++ b/protocol/czar/priority_test.go @@ -0,0 +1,18 @@ +package czar + +import ( + "testing" +) + +func TestPriority(t *testing.T) { + pri := NewPriority() + pri.H(func() error { + return nil + }) + pri.M(func() error { + return nil + }) + pri.L(func() error { + return nil + }) +}