-
Notifications
You must be signed in to change notification settings - Fork 0
/
sessions.go
315 lines (262 loc) · 7.33 KB
/
sessions.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
package possessions
import (
"context"
"encoding/json"
"net/http"
"time"
"github.com/pkg/errors"
)
// Session exposes read access to the string values held in a session.
type Session interface {
// Get looks up key and returns the associated string value; hasKey reports
// whether the key was present.
Get(key string) (value string, hasKey bool)
}
// session is the package's concrete Session implementation: an identifier
// plus the key/value pairs that belong to it.
type session struct {
	// ID is the identifier of this session.
	ID string
	// Values maps each session key to its string value.
	Values map[string]string
}

// Get returns the value stored under key and whether the key exists in
// Values. A missing key yields "" and false.
func (s session) Get(key string) (value string, hasKey bool) {
	value, hasKey = s.Values[key]
	return value, hasKey
}
// Storer provides methods to retrieve, add and delete sessions.
// Every method takes a context.Context as its first argument.
type Storer interface {
// All returns all keys in the store
All(ctx context.Context) (keys []string, err error)
// Get returns the value stored under key.
// NOTE(review): presumably a missing key is reported through an error for
// which IsNoSessionError returns true — confirm against the implementations.
Get(ctx context.Context, key string) (value string, err error)
// Set stores value under key.
Set(ctx context.Context, key, value string) error
// Del removes key from the store.
Del(ctx context.Context, key string) error
// ResetExpiry resets the expiry of key.
// NOTE(review): the exact expiry semantics are up to the implementation —
// TODO confirm.
ResetExpiry(ctx context.Context, key string) error
}
// EventKind identifies the kind of mutation an Event applies to a session.
type EventKind int
// The possible EventKind values, numbered sequentially by iota starting at 0
// for EventSet.
const (
// EventSet sets a key-value pair
EventSet EventKind = iota
// EventDel removes a key
EventDel
// EventDelAll means you should delete EVERY key-value pair from
// the client state - though a whitelist of keys that should not be deleted
// will be set in Keys
EventDelAll
// EventRefresh should refresh the TTL if any on the session
EventRefresh
// EventDelClientState deletes the client state
EventDelClientState
)
// Event represents an operation on a session. Which fields are meaningful
// depends on Kind.
type Event struct {
// Kind selects the operation this event describes.
Kind EventKind
// Key is the session key the event applies to; the helpers in this file
// fill it in for EventSet and EventDel events.
Key string
// Val is the value to store; filled in for EventSet events.
Val string
// Keys is the whitelist of keys that must not be deleted; filled in for
// EventDelAll events.
Keys []string
}
// Overseer of session cookies
type Overseer interface {
// ReadState should return a map like structure allowing it to look up
// any values in the current session, or any cookie in the request
ReadState(*http.Request) (Session, error)
// WriteState receives a Session together with the list of Events to
// apply. It can sometimes be called with a nil Session in the event that
// no Session was read in beforehand.
// NOTE(review): the previous wording spoke of "ClientState" and
// "LoadClientState"; these are assumed to correspond to Session and
// ReadState — confirm.
WriteState(context.Context, http.ResponseWriter, Session, []Event) error
}
// timer interface is used to mock the test harness for disk and memory storers.
// It is the type of the first value returned by timerTestHarness, whose
// default implementation hands back a *time.Timer.
type timer interface {
// Stop has the same signature as (*time.Timer).Stop.
Stop() bool
// Reset has the same signature as (*time.Timer).Reset.
Reset(time.Duration) bool
}
// Marker interfaces used by IsNoSessionError and IsNoMapKeyError to recognise
// the corresponding error without depending on its concrete type.
type (
	noSessionInterface interface {
		NoSession()
	}

	noMapKeyInterface interface {
		NoMapKey()
	}
)

// Concrete error values for a missing session and a missing session map key.
type (
	errNoSession struct{}
	errNoMapKey  struct{}
)

// NoSession marks errNoSession as a noSessionInterface.
func (errNoSession) NoSession() {}

// Error implements the error interface.
func (errNoSession) Error() string { return "session does not exist" }

// NoMapKey marks errNoMapKey as a noMapKeyInterface.
func (errNoMapKey) NoMapKey() {}

// Error implements the error interface.
func (errNoMapKey) Error() string { return "session map key does not exist" }
// IsNoSessionError reports whether err signals that the session does not
// exist. Either err itself or its errors.Cause must implement the NoSession
// marker method.
func IsNoSessionError(err error) bool {
	if _, direct := err.(noSessionInterface); direct {
		return true
	}
	// Not a direct match: look at the cause of a possibly wrapped error.
	_, viaCause := errors.Cause(err).(noSessionInterface)
	return viaCause
}
// IsNoMapKeyError reports whether err signals that a session map key does
// not exist. Either err itself or its errors.Cause must implement the
// NoMapKey marker method.
func IsNoMapKeyError(err error) bool {
	if _, direct := err.(noMapKeyInterface); direct {
		return true
	}
	// Not a direct match: look at the cause of a possibly wrapped error.
	_, viaCause := errors.Cause(err).(noMapKeyInterface)
	return viaCause
}
// timerTestHarness creates the timer for duration d and returns it together
// with its expiry channel. It is a package-level variable so that the disk
// and memory storer tests can substitute their own timer and channel and
// trigger cleans at will.
var timerTestHarness = func(d time.Duration) (timer, <-chan time.Time) {
	realTimer := time.NewTimer(d)
	return realTimer, realTimer.C
}
// validKey reports whether key has the textual layout of a UUIDv4 session
// key: 8-4-4-4-12 groups separated by dashes, where every other character is
// a digit or a lowercase hex letter (0-9, a-f). Uppercase hex letters are
// rejected.
// Example: a668b3bb-0cf1-4627-8cd4-7f62d09ebad6
func validKey(key string) bool {
	// 32 hex characters (16 bytes) plus 4 dashes.
	if len(key) != 36 {
		return false
	}

	for i := 0; i < len(key); i++ {
		c := key[i]
		switch i {
		case 8, 13, 18, 23:
			// 0 indexed dash positions
			if c != '-' {
				return false
			}
		default:
			isDigit := c >= '0' && c <= '9'
			isLowerHex := c >= 'a' && c <= 'f'
			if !isDigit && !isLowerHex {
				return false
			}
		}
	}

	return true
}
// Get returns the session string value stored under key, together with a
// bool reporting whether it was found. It delegates to get, which returns
// "", false when ctx carries no session and panics when the value stored in
// ctx is not a Session.
func Get(ctx context.Context, key string) (string, bool) {
return get(ctx, key)
}
// GetObj looks up the json encoded session string stored under key and
// decodes it into obj. When the key is absent the returned error satisfies
// IsNoMapKeyError; a decoding failure is returned wrapped with context.
func GetObj(ctx context.Context, key string, obj interface{}) error {
	raw, found := get(ctx, key)
	if !found {
		return errNoMapKey{}
	}

	if err := json.Unmarshal([]byte(raw), obj); err != nil {
		return errors.Wrap(err, "failed to unmarshal session key-value")
	}
	return nil
}
// Set queues an EventSet event for key/value on the possessions response
// writer found in w. It delegates to set, which panics (via
// getResponseWriter) when w neither is nor wraps that writer.
func Set(w http.ResponseWriter, key, value string) {
set(w, key, value)
}
// SetObj encodes obj as JSON and queues the resulting string to be set under
// key in the session. If encoding fails the error from json.Marshal is
// returned as-is and nothing is queued.
func SetObj(w http.ResponseWriter, key string, obj interface{}) error {
	encoded, err := json.Marshal(obj)
	if err == nil {
		set(w, key, string(encoded))
	}
	return err
}
// get reads key from the Session stored in ctx under CTXKeyPossessions{}.
// It returns "", false when ctx holds no value for that key, and panics when
// the stored value does not implement Session.
func get(ctx context.Context, key string) (string, bool) {
	switch sess := ctx.Value(CTXKeyPossessions{}).(type) {
	case nil:
		// Nothing cached in the context.
		return "", false
	case Session:
		return sess.Get(key)
	default:
		panic("cached session value does not conform to possesions.Session interface")
	}
}
func set(w http.ResponseWriter, key, value string) {
pw := getResponseWriter(w)
pw.events = append(pw.events, Event{
Kind: EventSet,
Key: key,
Val: value,
})
}
// Del a session key
func Del(w http.ResponseWriter, key string) {
pw := getResponseWriter(w)
pw.events = append(pw.events, Event{
Kind: EventDel,
Key: key,
})
}
// DelAll delete all keys except for a whitelist
func DelAll(w http.ResponseWriter, whitelist []string) {
pw := getResponseWriter(w)
pw.events = append(pw.events, Event{
Kind: EventDelAll,
Keys: whitelist,
})
}
// Refresh queues a refresh of the session's ttl by appending an EventRefresh
// event to the possessions response writer found in w.
func Refresh(w http.ResponseWriter) {
	target := getResponseWriter(w)
	target.events = append(target.events, Event{Kind: EventRefresh})
}
// AddFlash adds a flash message to the session. Typically read and removed
// on the next request. The message is stored like any other session value:
// the call simply delegates to Set.
func AddFlash(w http.ResponseWriter, key string, value string) {
Set(w, key, value)
}
// AddFlashObj adds a flash message to the session using an object that's
// marshalled into JSON. It delegates to SetObj and returns whatever error
// SetObj returns.
func AddFlashObj(w http.ResponseWriter, key string, obj interface{}) error {
return SetObj(w, key, obj)
}
// GetFlash returns the flash message stored under key and queues its
// deletion on the response writer. When the key is absent it returns
// "", false and queues nothing.
func GetFlash(w http.ResponseWriter, ctx context.Context, key string) (string, bool) {
	if flash, found := Get(ctx, key); found {
		// Flash messages are read once: queue the removal right away.
		Del(w, key)
		return flash, true
	}
	return "", false
}
// GetFlashObj reads the json-encoded flash message stored under key, queues
// its deletion on the response writer and decodes it into obj. When the key
// is absent the returned error satisfies IsNoMapKeyError. The deletion is
// queued before decoding, so it is queued even when decoding fails.
func GetFlashObj(w http.ResponseWriter, ctx context.Context, key string, obj interface{}) error {
	flash, found := Get(ctx, key)
	if !found {
		return errNoMapKey{}
	}

	Del(w, key)

	if err := json.Unmarshal([]byte(flash), obj); err != nil {
		return errors.Wrap(err, "failed to unmarshal flash key-value string")
	}
	return nil
}
// getResponseWriter unwraps w until it finds the *possesionsWriter. Each
// writer that is not a *possesionsWriter must implement
// UnderlyingResponseWriter so the search can continue with the writer it
// wraps; otherwise the function panics.
func getResponseWriter(w http.ResponseWriter) *possesionsWriter {
	for {
		switch current := w.(type) {
		case *possesionsWriter:
			return current
		case UnderlyingResponseWriter:
			// Peel off one wrapper and try again.
			w = current.UnderlyingResponseWriter()
		default:
			panic("http.ResponseWriter was not possessions.responseWriter no posssessions.UnderlyingResponseWriter")
		}
	}
}