-
Notifications
You must be signed in to change notification settings - Fork 13
/
gostub.go
126 lines (110 loc) · 4.21 KB
/
gostub.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package gostub
import (
"fmt"
"reflect"
)
// Stub sets the variable that varToStub points to to stubVal, using a freshly
// created *Stubs that remembers the variable's original value, and returns that
// *Stubs so the caller can undo the change later with Reset.
//
// varToStub must be a pointer to the variable being stubbed, and stubVal should
// have a type that is assignable to that variable.
func Stub(varToStub interface{}, stubVal interface{}) *Stubs {
	stubs := New()
	return stubs.Stub(varToStub, stubVal)
}
// StubFunc replaces the function variable that funcVarToStub points to with a
// function that always returns stubVal. It uses a freshly created *Stubs and
// returns it, so the caller can restore the original function with Reset.
//
// funcVarToStub must be a pointer to a function variable. If the function has
// several return values, pass one stub value per return value; each must be
// assignable to the corresponding result type.
func StubFunc(funcVarToStub interface{}, stubVal ...interface{}) *Stubs {
	stubs := New()
	return stubs.StubFunc(funcVarToStub, stubVal...)
}
// envVal is a saved environment entry: a string value plus a boolean flag.
// It is the value type of Stubs.origEnv.
// NOTE(review): val/ok presumably mirror the (value, present) pair returned by
// os.LookupEnv, captured before the variable was first stubbed — confirm
// against the env-stubbing code, which is not visible in this file.
type envVal struct {
	val string
	ok  bool
}
// Stubs represents a set of stubbed variables that can be reset.
//
// Create values with New: the zero value has nil maps, and Stub writes into
// the stubs map, so stubbing through a zero Stubs would panic.
type Stubs struct {
	// stubs is a map from the variable pointer (being stubbed) to the original value.
	// Keys are reflect.ValueOf(ptr) for the pointer passed to Stub; ResetSingle
	// looks entries up by building the key the same way from the same pointer.
	stubs map[reflect.Value]reflect.Value

	// origEnv holds saved environment state keyed by string.
	// NOTE(review): it is populated and consumed outside this view (resetEnv);
	// keys are presumably environment variable names — confirm.
	origEnv map[string]envVal
}
// New returns an empty *Stubs with its bookkeeping maps allocated, ready to
// record stubbed variables. Undo everything it stubs by calling Reset.
func New() *Stubs {
	s := new(Stubs)
	s.stubs = map[reflect.Value]reflect.Value{}
	s.origEnv = map[string]envVal{}
	return s
}
// Stub replaces the value stored at varToStub with stubVal and returns s so
// calls can be chained.
//
// varToStub must be a non-nil pointer to the variable. stubVal should have a
// type that is assignable to the variable; a nil stubVal sets the variable to
// its type's zero value (the same treatment FuncReturning gives nil results).
//
// The variable's value is captured only the first time it is stubbed through
// s, so Reset and ResetSingle restore the value it had before any stubbing.
func (s *Stubs) Stub(varToStub interface{}, stubVal interface{}) *Stubs {
	v := reflect.ValueOf(varToStub)

	// Ensure varToStub is a pointer to the variable.
	if !v.IsValid() || v.Kind() != reflect.Ptr {
		panic("variable to stub is expected to be a pointer")
	}
	if v.IsNil() {
		panic("variable to stub is expected to be a non-nil pointer")
	}
	target := v.Elem()

	stub := reflect.ValueOf(stubVal)
	if !stub.IsValid() {
		// reflect.ValueOf(nil) is the invalid zero Value, which Set rejects;
		// substitute the zero value of the variable's type instead.
		stub = reflect.Zero(target.Type())
	}

	if _, ok := s.stubs[v]; !ok {
		// Store the original value if this is the first time varToStub is being
		// stubbed. Copy it into a fresh value of the variable's static type
		// instead of round-tripping through Interface(): for an interface-typed
		// variable holding nil, reflect.ValueOf(target.Interface()) would be the
		// invalid zero Value, and restoring it later would panic.
		orig := reflect.New(target.Type()).Elem()
		orig.Set(target)
		s.stubs[v] = orig
	}

	// *varToStub = stubVal
	target.Set(stub)
	return s
}
// StubFunc replaces the function variable that funcVarToStub points to with a
// function that ignores its arguments and returns stubVal, and returns s so
// calls can be chained.
//
// funcVarToStub must be a pointer to a function variable. Exactly one stub
// value must be supplied per return value of the function, and each must be
// assignable to the corresponding result type; a nil entry yields that
// result's zero value.
func (s *Stubs) StubFunc(funcVarToStub interface{}, stubVal ...interface{}) *Stubs {
	funcPtrType := reflect.TypeOf(funcVarToStub)
	// reflect.TypeOf(nil) returns a nil Type; check it first so a nil argument
	// gets the descriptive panic below rather than a nil-pointer dereference.
	if funcPtrType == nil ||
		funcPtrType.Kind() != reflect.Ptr ||
		funcPtrType.Elem().Kind() != reflect.Func {
		panic("func variable to stub must be a pointer to a function")
	}

	funcType := funcPtrType.Elem()
	if funcType.NumOut() != len(stubVal) {
		panic(fmt.Sprintf("func type has %v return values, but only %v stub values provided",
			funcType.NumOut(), len(stubVal)))
	}

	return s.Stub(funcVarToStub, FuncReturning(funcType, stubVal...).Interface())
}
// FuncReturning creates a new function with type funcType that returns results.
func FuncReturning(funcType reflect.Type, results ...interface{}) reflect.Value {
var resultValues []reflect.Value
for i, r := range results {
var retValue reflect.Value
if r == nil {
// We can't use reflect.ValueOf(nil), so we need to create the zero value.
retValue = reflect.Zero(funcType.Out(i))
} else {
// We cannot simply use reflect.ValueOf(r) as that does not work for
// interface types, as reflect.ValueOf receives the dynamic type, which
// is the underlying type. e.g. for an error, it may *errors.errorString.
// Instead, we make the return type's expected interface value using
// reflect.New, and set the data to the passed in value.
tempV := reflect.New(funcType.Out(i))
tempV.Elem().Set(reflect.ValueOf(r))
retValue = tempV.Elem()
}
resultValues = append(resultValues, retValue)
}
return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value {
return resultValues
})
}
// Reset puts every variable stubbed through s back to the value it held before
// it was first stubbed, then calls resetEnv for the environment state.
func (s *Stubs) Reset() {
	for ptr, orig := range s.stubs {
		target := ptr.Elem()
		target.Set(orig)
	}
	// resetEnv is defined outside this view; presumably it restores the
	// environment entries recorded in origEnv.
	s.resetEnv()
}
// ResetSingle puts only the variable that varToStub points to back to its
// original value, leaving every other stub in s in place. The entry stays
// recorded in s, so a later Reset restores the same original value again.
// It panics if varToStub has not been stubbed through s.
func (s *Stubs) ResetSingle(varToStub interface{}) {
	ptr := reflect.ValueOf(varToStub)
	if orig, found := s.stubs[ptr]; found {
		ptr.Elem().Set(orig)
		return
	}
	panic("cannot reset variable as it has not been stubbed yet")
}