forked from quic-go/quic-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
packet_unpacker.go
188 lines (171 loc) · 5.88 KB
/
packet_unpacker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package quic
import (
"bytes"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// headerDecryptor is the part of an opener that is needed to undo header
// protection. In this file both the long header and the short header openers
// are passed wherever a headerDecryptor is expected.
type headerDecryptor interface {
// DecryptHeader receives a sample taken from the packet, a pointer to the
// packet's first byte and the bytes that hold the packet number.
// NOTE(review): it has no return value, so it presumably rewrites firstByte
// and pnBytes in place - confirm against the implementations.
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
// unpackedPacket is what Unpack returns for a packet that was unpacked
// without error.
type unpackedPacket struct {
packetNumber protocol.PacketNumber // the decoded packet number
// hdr is the extended header that was parsed for this packet.
hdr *wire.ExtendedHeader
// encryptionLevel is the level that Unpack selected for this packet
// (Initial, Handshake, 0-RTT or 1-RTT).
encryptionLevel protocol.EncryptionLevel
// data is the payload as returned by the opener's Open call.
data []byte
}
// The packetUnpacker unpacks QUIC packets.
type packetUnpacker struct {
// cs hands out the openers for the Initial, Handshake, 0-RTT and 1-RTT
// encryption levels.
cs handshake.CryptoSetup
// largestRcvdPacketNumber is the highest packet number of any packet that
// Unpack has returned so far. It is the reference value passed to
// protocol.DecodePacketNumber when the next header is unpacked.
largestRcvdPacketNumber protocol.PacketNumber
// version is forwarded to the extended header parser.
version protocol.VersionNumber
}
// Compile-time check that *packetUnpacker can be used as an unpacker.
var _ unpacker = &packetUnpacker{}
// newPacketUnpacker returns an unpacker that obtains its openers from cs and
// parses extended headers for the given version. The largest received packet
// number starts out at its zero value.
func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker {
	u := new(packetUnpacker)
	u.cs = cs
	u.version = version
	return u
}
// Unpack picks the opener that matches the packet type in hdr, unpacks the
// header and payload contained in data, and returns the result together with
// the encryption level that was used.
//
// Initial, Handshake and 0-RTT packets are handled as long header packets.
// Every other type is treated as a short header (1-RTT) packet, unless hdr
// says it is a long header, in which case an "unknown packet type" error is
// returned. An error from requesting the opener or from unpacking is returned
// as is, with a nil packet.
//
// rcvTime is only used for short header packets.
func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) {
	var (
		level   protocol.EncryptionLevel
		header  *wire.ExtendedHeader
		payload []byte
		err     error
	)

	// openLong takes the result of requesting a long header opener and, if
	// that succeeded, unpacks the packet with it.
	openLong := func(opener handshake.LongHeaderOpener, err error) error {
		if err != nil {
			return err
		}
		header, payload, err = u.unpackLongHeaderPacket(opener, hdr, data)
		return err
	}

	switch hdr.Type {
	case protocol.PacketTypeInitial:
		level = protocol.EncryptionInitial
		err = openLong(u.cs.GetInitialOpener())
	case protocol.PacketTypeHandshake:
		level = protocol.EncryptionHandshake
		err = openLong(u.cs.GetHandshakeOpener())
	case protocol.PacketType0RTT:
		level = protocol.Encryption0RTT
		err = openLong(u.cs.Get0RTTOpener())
	default:
		if hdr.IsLongHeader {
			return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
		}
		level = protocol.Encryption1RTT
		opener, openerErr := u.cs.Get1RTTOpener()
		if openerErr != nil {
			return nil, openerErr
		}
		header, payload, err = u.unpackShortHeaderPacket(opener, hdr, rcvTime, data)
	}
	if err != nil {
		return nil, err
	}

	// Only do this after decrypting, so we are sure the packet is not attacker-controlled
	pn := header.PacketNumber
	u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, pn)

	return &unpackedPacket{
		hdr:             header,
		packetNumber:    pn,
		encryptionLevel: level,
		data:            payload,
	}, nil
}
// unpackLongHeaderPacket unpacks the header of a long header packet and then
// opens its payload with opener. The payload is decrypted in place, on top of
// the ciphertext, with the unpacked header bytes as associated data.
//
// A header whose only problem is wire.ErrInvalidReservedBits is not rejected
// right away: the payload is opened first and the error is reported afterwards.
// Any other header error is wrapped and returned before decryption is tried.
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
	header, hdrErr := u.unpackHeader(opener, hdr, data)
	switch hdrErr {
	case nil, wire.ErrInvalidReservedBits:
		// Keep going even if the reserved bits are wrong. Bailing out early
		// would create a timing side-channel that might tell an attacker
		// something about the header encryption.
	default:
		return nil, nil, fmt.Errorf("error parsing extended header: %s", hdrErr)
	}

	n := header.ParsedLen()
	associatedData, ciphertext := data[:n], data[n:]
	payload, openErr := opener.Open(ciphertext[:0], ciphertext, header.PacketNumber, associatedData)
	if openErr != nil {
		return nil, nil, openErr
	}
	// Decryption succeeded, so a deferred reserved bits error can be reported now.
	if hdrErr != nil {
		return nil, nil, hdrErr
	}
	return header, payload, nil
}
// unpackShortHeaderPacket unpacks the header of a short header packet and then
// opens its payload with opener, passing along the receive time, the decoded
// packet number and the key phase from the header. The payload is decrypted in
// place, with the unpacked header bytes as associated data.
//
// A header whose only problem is wire.ErrInvalidReservedBits is still
// decrypted; that error is returned only once opening the payload succeeded.
// Any other header error is returned unchanged without attempting decryption.
func (u *packetUnpacker) unpackShortHeaderPacket(
	opener handshake.ShortHeaderOpener,
	hdr *wire.Header,
	rcvTime time.Time,
	data []byte,
) (*wire.ExtendedHeader, []byte, error) {
	header, hdrErr := u.unpackHeader(opener, hdr, data)
	// Wrong reserved bits must not cut unpacking short: returning early would
	// open a timing side-channel that might leak information about the header
	// encryption to an attacker.
	fatal := hdrErr != nil && hdrErr != wire.ErrInvalidReservedBits
	if fatal {
		return nil, nil, hdrErr
	}

	n := header.ParsedLen()
	associatedData, ciphertext := data[:n], data[n:]
	payload, openErr := opener.Open(
		ciphertext[:0],
		ciphertext,
		rcvTime,
		header.PacketNumber,
		header.KeyPhase,
		associatedData,
	)
	if openErr != nil {
		return nil, nil, openErr
	}
	// The payload authenticated, so the deferred header error may surface now.
	if hdrErr != nil {
		return nil, nil, hdrErr
	}
	return header, payload, nil
}
// unpackHeader unpacks the extended header for this unpacker's version and
// replaces the packet number read from the wire with the value decoded
// relative to the largest packet number received so far.
//
// wire.ErrInvalidReservedBits is passed through together with the header, so
// the caller decides when to report it. Any other error yields a nil header.
func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
	header, err := unpackHeader(hd, hdr, data, u.version)
	switch err {
	case nil, wire.ErrInvalidReservedBits:
		// A header is available in both cases.
	default:
		return nil, err
	}

	wirePN := header.PacketNumber
	header.PacketNumber = protocol.DecodePacketNumber(header.PacketNumberLen, u.largestRcvdPacketNumber, wirePN)
	return header, err
}
// unpackHeader decrypts the protected parts of a packet header and parses the
// extended header from data.
//
// hdr is the already parsed invariant header; its ParsedLen marks where the
// packet number starts. Because the packet number length is not known up
// front, 4 bytes are assumed: the 16 bytes that follow them are passed to hd
// as the sample. data must therefore contain at least 4+16 bytes after the
// header, otherwise an error is returned before hd is called.
//
// If parsing reports wire.ErrInvalidReservedBits, the parsed header is
// returned together with that error. Any other parse error is returned with
// a nil header.
// NOTE(review): in that case the slices handed to DecryptHeader are not
// restored - confirm that callers drop the packet.
func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) {
// The reader is created over data itself (no copy), so ParseExtended in
// step 3 reads whatever DecryptHeader left in the slice.
r := bytes.NewReader(data)
hdrLen := hdr.ParsedLen()
// 4 bytes of (maximum length) packet number plus a 16 byte sample.
if protocol.ByteCount(len(data)) < hdrLen+4+16 {
//nolint:stylecheck
return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen)
}
// The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it.
// 1. save a copy of the 4 bytes
origPNBytes := make([]byte, 4)
copy(origPNBytes, data[hdrLen:hdrLen+4])
// 2. decrypt the header, assuming a 4 byte packet number
hd.DecryptHeader(
data[hdrLen+4:hdrLen+4+16],
&data[0],
data[hdrLen:hdrLen+4],
)
// 3. parse the header (and learn the actual length of the packet number)
extHdr, parseErr := hdr.ParseExtended(r, version)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, parseErr
}
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
// The bytes between the end of the real packet number and hdrLen+4 belong
// to the payload; put back their saved original values.
copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
}
return extHdr, parseErr
}