-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconn.go
351 lines (325 loc) · 8.08 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
package noiseconn
import (
"encoding/binary"
"errors"
"io"
"net"
"sync"
"github.com/flynn/noise"
"github.com/zeebo/errs"
)
const (
	// HeaderByte is the value every frame header starts with. readMsg
	// rejects frames whose first byte differs, and frame writes it over the
	// top byte of the big-endian length, leaving 3 bytes for the length.
	HeaderByte = 0x80

	// flushLimit is the number of buffered, framed bytes after which Write
	// hands its batch to the underlying connection instead of accumulating
	// further frames.
	flushLimit = 640 * 1024
)
// MessageInspector is a callback that gets informed about unparsed
// Noise messages. It receives the remote address of the connection and the
// raw message bytes. A non-nil error is returned (wrapped) from the Read or
// Write call that processed the message.
type MessageInspector func(addr net.Addr, message []byte) error

// Options holds optional settings accepted by NewConnWithOptions. The zero
// value enables no optional behavior.
type Options struct {
	// ResponderFirstMessageValidator, if set, is called with the first Noise
	// message (unparsed) that a responder receives. It is not consulted for
	// initiators or for any later message. It can help with analyzing
	// message replay, with debouncing messages that are deliberately sent
	// twice (see github.com/jtolio/noiseconn/debounce), and with similar
	// checks, but it is not safe to rely on for replay-attack prevention.
	ResponderFirstMessageValidator MessageInspector
}
// Conn is a net.Conn that implements a framed Noise protocol on top of the
// underlying net.Conn provided in NewConn. Conn allows for 0-RTT protocols,
// in the sense that bytes given to Write will be added to handshake
// payloads.
//
// Read and Write should not be called concurrently until
// HandshakeComplete() is true.
type Conn struct {
	net.Conn

	// hsMu is held by Read and Write while the handshake is in progress, and
	// by HandshakeComplete/HandshakeHash; it guards hs and hh.
	hsMu sync.Mutex
	// readBarrier is waited on by an initiator's Read. It is released by
	// Close, by hsCreate, and when the handshake completes.
	readBarrier barrier

	// hs is the in-progress handshake; set to nil once it completes.
	hs *noise.HandshakeState
	// hh is the handshake's channel binding, captured on completion.
	hh []byte
	// initiator mirrors config.Initiator.
	initiator bool
	// hsResponsibility is true when this side must send the next handshake
	// message.
	hsResponsibility bool

	// readMsgBuf holds the body of the most recently read frame.
	readMsgBuf []byte
	// writeMsgBuf accumulates framed outgoing messages before they are
	// written to the underlying connection.
	writeMsgBuf []byte
	// readBuf holds decrypted plaintext not yet returned by Read.
	readBuf []byte

	// send and recv are the transport cipher states; nil until the handshake
	// completes.
	send, recv *noise.CipherState

	// rfmValidate is Options.ResponderFirstMessageValidator; it is cleared
	// once it has been consulted or can no longer apply.
	rfmValidate MessageInspector
}

var _ net.Conn = (*Conn)(nil)
// NewConn wraps an existing net.Conn with encryption provided by
// noise.Config. It is equivalent to NewConnWithOptions with zero Options.
func NewConn(conn net.Conn, config noise.Config) (*Conn, error) {
	var defaults Options
	return NewConnWithOptions(conn, config, defaults)
}
// NewConnWithOptions wraps an existing net.Conn with encryption provided by
// noise.Config and options provided by Options.
//
// No I/O happens here: the handshake is driven by the first Read or Write
// calls. config.Initiator decides whether this side is responsible for the
// first handshake message. An error is returned only when
// noise.NewHandshakeState rejects config.
func NewConnWithOptions(conn net.Conn, config noise.Config, opts Options) (*Conn, error) {
	hs, err := noise.NewHandshakeState(config)
	if err != nil {
		return nil, errs.Wrap(err)
	}
	return &Conn{
		Conn:             conn,
		hs:               hs,
		initiator:        config.Initiator,
		hsResponsibility: config.Initiator,
		rfmValidate:      opts.ResponderFirstMessageValidator,
	}, nil
}
// Close releases the read barrier (which an initiator's Read waits on)
// and then closes the underlying connection, returning its error.
func (c *Conn) Close() error {
	c.readBarrier.Release()
	underlying := c.Conn
	return underlying.Close()
}
// setCipherStates installs the cipher states returned by a handshake step.
// cs1 carries initiator-to-responder traffic and cs2 the reverse, so the
// send/recv assignment depends on this side's role. Once a send state is
// present, the handshake is finished: the read barrier is released, the
// channel binding is saved in hh, and hs is dropped.
func (c *Conn) setCipherStates(cs1, cs2 *noise.CipherState) {
	send, recv := cs2, cs1
	if c.initiator {
		send, recv = cs1, cs2
	}
	c.send, c.recv = send, recv

	if c.send == nil {
		// Handshake still in progress.
		return
	}
	c.readBarrier.Release()
	c.hh = c.hs.ChannelBinding()
	c.hs = nil
}
// hsRead reads one framed handshake message from the peer and feeds it to
// the handshake state. Any payload is appended to c.readBuf for Read to
// return. Afterwards it is this side's turn to write. On the first message
// received, the optional first-message validator is invoked (then
// discarded), and its error, if any, is returned wrapped.
func (c *Conn) hsRead() (err error) {
	if c.readMsgBuf, err = c.readMsg(c.readMsgBuf[:0]); err != nil {
		return err
	}

	var first, second *noise.CipherState
	c.readBuf, first, second, err = c.hs.ReadMessage(c.readBuf, c.readMsgBuf)
	if err != nil {
		return errs.Wrap(err)
	}
	c.setCipherStates(first, second)
	c.hsResponsibility = true

	validate := c.rfmValidate
	if validate == nil {
		return nil
	}
	verr := validate(c.Conn.RemoteAddr(), c.readMsgBuf)
	c.rfmValidate = nil
	return errs.Wrap(verr)
}
// Read reads decrypted application data into b.
//
// While the handshake is in progress, Read drives it forward under hsMu.
// When it is this side's turn, it sends a handshake message with an empty
// payload, then reads the peer's next handshake message. Any payload
// carried by a handshake message is returned to the caller. An
// initiator's Read first calls readBarrier.Wait; the barrier is released
// once this side has written a handshake message, the handshake completes,
// or the Conn is closed.
//
// After the handshake, Read consumes whole transport messages until one
// with a non-empty payload arrives. Plaintext that does not fit in b is
// kept in c.readBuf and returned by subsequent calls.
func (c *Conn) Read(b []byte) (n int, err error) {
	if c.initiator {
		c.readBarrier.Wait()
	}

	c.hsMu.Lock()
	locked := true
	unlocker := func() {
		if locked {
			locked = false
			c.hsMu.Unlock()
		}
	}
	if c.hs == nil {
		// Handshake already finished; the transport phase runs unlocked.
		unlocker()
	} else {
		defer unlocker()
	}

	// handleBuffered moves as much buffered plaintext as fits into b and
	// reports whether there was any buffered plaintext at all.
	handleBuffered := func() bool {
		if len(c.readBuf) == 0 {
			return false
		}
		n = copy(b, c.readBuf)
		copy(c.readBuf, c.readBuf[n:])
		c.readBuf = c.readBuf[:len(c.readBuf)-n]
		return true
	}

	if handleBuffered() {
		return n, nil
	}

	for c.hs != nil {
		if c.hsResponsibility {
			// Our turn to speak: send a handshake message with no payload.
			c.writeMsgBuf, err = c.hsCreate(c.writeMsgBuf[:0], nil)
			if err != nil {
				return 0, err
			}
			_, err = c.Conn.Write(c.writeMsgBuf)
			if err != nil {
				return 0, errs.Wrap(err)
			}
			if c.hs == nil {
				break
			}
		}
		err = c.hsRead()
		if err != nil {
			return 0, err
		}
		if handleBuffered() {
			return n, nil
		}
	}
	unlocker()

	for {
		c.readMsgBuf, err = c.readMsg(c.readMsgBuf[:0])
		if err != nil {
			return 0, err
		}
		if len(c.readMsgBuf) <= len(b) {
			// AEAD decryption never yields more plaintext than ciphertext, so
			// when the ciphertext fits in b the plaintext does too: decrypt
			// straight into the caller's buffer and skip the extra copy.
			out, err := c.recv.Decrypt(b[:0], nil, c.readMsgBuf)
			if err != nil {
				return 0, errs.Wrap(err)
			}
			if len(out) > len(b) {
				// Defensive only. Never panic on peer-controlled input:
				// return what fits and keep the rest for the next Read.
				n = copy(b, out)
				c.readBuf = append(c.readBuf, out[n:]...)
				return n, nil
			}
			if len(out) > 0 {
				return len(out), nil
			}
			// Empty payload; wait for the next message.
			continue
		}
		c.readBuf, err = c.recv.Decrypt(c.readBuf, nil, c.readMsgBuf)
		if err != nil {
			return 0, errs.Wrap(err)
		}
		if handleBuffered() {
			return n, nil
		}
	}
}
// readMsg reads one framed message from the underlying connection.
//
// A frame is a 4-byte header followed by the body. The header's first byte
// must be HeaderByte, and the remaining 3 bytes hold the body length as a
// big-endian integer. The body is read into the spare capacity after
// len(b), reusing it when large enough; callers pass a buffer truncated to
// [:0]. The returned slice holds ONLY the message body, not the prior
// contents of b.
//
// If the connection ends partway through a body, the error wraps
// io.ErrUnexpectedEOF. On any error the returned slice is nil.
func (c *Conn) readMsg(b []byte) ([]byte, error) {
	// TODO(jt): make sure these reads are through bufio somewhere in the stack
	// appropriate.
	var hdr [4]byte
	if _, err := io.ReadFull(c.Conn, hdr[:]); err != nil {
		return nil, errs.Wrap(err)
	}
	if hdr[0] != HeaderByte {
		// TODO(jt): close conn?
		return nil, errs.New("unknown message header")
	}
	// Clear the marker byte so the header decodes as a 24-bit length.
	hdr[0] = 0
	size := int(binary.BigEndian.Uint32(hdr[:]))

	msg := append(b[len(b):], make([]byte, size)...)
	if _, err := io.ReadFull(c.Conn, msg); err != nil {
		// The header promised a body, so even a clean EOF here is a
		// truncated frame.
		if errors.Is(err, io.EOF) {
			return nil, errs.Wrap(io.ErrUnexpectedEOF)
		}
		return nil, errs.Wrap(err)
	}
	return msg, nil
}
// frame writes the 4-byte frame header for body b into header[:4]. The
// first byte is HeaderByte and the next three are len(b), big-endian.
// Bodies of 1<<24 bytes or more cannot be encoded and are rejected.
func (c *Conn) frame(header, b []byte) error {
	size := len(b)
	if size >= 1<<24 {
		return errs.New("message too large: %d", size)
	}
	hdr := header[:4]
	binary.BigEndian.PutUint32(hdr, uint32(size))
	// Overwrite the (necessarily zero) top length byte with the marker.
	hdr[0] = HeaderByte
	return nil
}
// hsCreate appends the next framed handshake message, carrying payload, to
// out and returns the extended slice. Afterwards it is the peer's turn, the
// read barrier is released, and any cipher states produced by this step
// are installed.
func (c *Conn) hsCreate(out, payload []byte) (_ []byte, err error) {
	start := len(out)
	// Reserve room for the frame header; frame fills it in once the message
	// length is known.
	out = append(out, make([]byte, 4)...)

	var first, second *noise.CipherState
	out, first, second, err = c.hs.WriteMessage(out, payload)
	if err != nil {
		return nil, errs.Wrap(err)
	}

	// The first-message validator only applies to responders, not
	// initiators. A responder has already read (and cleared it) before it
	// ever writes, so clearing here only affects initiators.
	c.rfmValidate = nil

	c.setCipherStates(first, second)
	c.hsResponsibility = false
	c.readBarrier.Release()

	return out, c.frame(out[start:], out[start+4:])
}
// Write encrypts b and sends it to the peer.
//
// If a Noise handshake is still occurring (or has yet to occur), the data
// provided to Write will be included in handshake payloads. When it is the
// peer's turn, Write first reads the peer's handshake message. Note that
// even if the Noise configuration allows for 0-RTT, the request will only
// be 0-RTT if the request is 65535 bytes or smaller.
//
// After the handshake, b is split into chunks of at most noise.MaxMsgLen
// bytes, and each chunk is encrypted and framed separately. Frames are
// batched and handed to the underlying connection whenever more than
// flushLimit bytes are pending, and once more at the end. The returned n
// counts only plaintext bytes whose frames were successfully passed to the
// underlying Write.
func (c *Conn) Write(b []byte) (n int, err error) {
	c.hsMu.Lock()
	locked := true
	unlocker := func() {
		if locked {
			locked = false
			c.hsMu.Unlock()
		}
	}
	if c.hs == nil {
		unlocker()
	} else {
		defer unlocker()
	}

	for c.hs != nil && len(b) > 0 {
		if !c.hsResponsibility {
			err = c.hsRead()
			if err != nil {
				return n, err
			}
		}
		if c.hs != nil {
			l := min(noise.MaxMsgLen, len(b))
			c.writeMsgBuf, err = c.hsCreate(c.writeMsgBuf[:0], b[:l])
			if err != nil {
				return n, err
			}
			_, err = c.Conn.Write(c.writeMsgBuf)
			if err != nil {
				return n, errs.Wrap(err)
			}
			n += l
			b = b[l:]
		}
	}
	unlocker()

	// pending counts plaintext bytes already encrypted into c.writeMsgBuf but
	// not yet handed to the underlying connection. They are added to n only
	// after a successful flush, so n never overstates what was written.
	pending := 0
	flush := func() error {
		if _, err := c.Conn.Write(c.writeMsgBuf); err != nil {
			return errs.Wrap(err)
		}
		n += pending
		pending = 0
		c.writeMsgBuf = c.writeMsgBuf[:0]
		return nil
	}

	c.writeMsgBuf = c.writeMsgBuf[:0]
	for len(b) > 0 {
		outlen := len(c.writeMsgBuf)
		l := min(noise.MaxMsgLen, len(b))
		// Reserve a 4-byte header slot, then append the ciphertext after it.
		c.writeMsgBuf, err = c.send.Encrypt(append(c.writeMsgBuf, make([]byte, 4)...), nil, b[:l])
		if err != nil {
			return n, errs.Wrap(err)
		}
		if err = c.frame(c.writeMsgBuf[outlen:], c.writeMsgBuf[outlen+4:]); err != nil {
			return n, err
		}
		pending += l
		b = b[l:]
		if len(c.writeMsgBuf) > flushLimit {
			if err = flush(); err != nil {
				return n, err
			}
		}
	}
	if len(c.writeMsgBuf) > 0 {
		if err = flush(); err != nil {
			return n, err
		}
	}
	return n, nil
}
// HandshakeComplete reports whether the Noise handshake has finished, that
// is, whether the handshake state has been discarded.
func (c *Conn) HandshakeComplete() bool {
	c.hsMu.Lock()
	done := c.hs == nil
	c.hsMu.Unlock()
	return done
}
// HandshakeHash returns the hash generated by the handshake which can be
// used for channel identification and channel binding. It is nil until the
// handshake is completed. The returned slice is the Conn's own copy and
// should not be modified.
func (c *Conn) HandshakeHash() []byte {
	c.hsMu.Lock()
	hash := c.hh
	c.hsMu.Unlock()
	return hash
}
// min returns the smaller of a and b, preferring a on ties.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}