-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.go
201 lines (168 loc) · 4.22 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
package jsonrpc
import (
"context"
"encoding/json"
"strings"
"sync"
"sync/atomic"
"github.com/41north/async.go"
"github.com/juju/errors"
gonanoid "github.com/matoous/go-nanoid"
log "github.com/sirupsen/logrus"
)
// idGen produces the id handed to Request.EnsureId when SendAsync sends a
// request. gonanoid.MustID follows the Must- convention and presumably panics
// if id generation fails — NOTE(review): confirm against go-nanoid docs.
var idGen = func() string {
	return gonanoid.MustID(20)
}

// ErrClosed reports that the client (or its connection) has been closed. It is
// delivered through response futures once the client is closed, returned by a
// second call to Close, and matched against read errors by the read loop to
// detect that the connection went away.
var ErrClosed = errors.ConstError("connection has been closed")
// ResponseFuture is what SendAsync returns: a future whose value is a Result
// holding either the peer's *Response or an error.
type ResponseFuture = async.Future[async.Result[*Response]]

// RequestHandler is called by the read loop with each message from the peer
// that carries a "method" member.
type RequestHandler = func(req Request)

// CloseHandler is called when the client closes. Its argument is the error
// that triggered the close, which may be nil.
type CloseHandler = func(err error)
// Client is a JSON-RPC client that talks to a peer over a connection obtained
// from a Dialer.
type Client interface {
	// Connect opens the underlying connection and begins reading from it.
	Connect() error
	// Close shuts the client down, failing any requests still awaiting a reply.
	Close() error

	// Send sends req and blocks until its response has been copied into resp.
	Send(req Request, resp *Response) error
	// SendContext is like Send but stops waiting when ctx is done.
	SendContext(ctx context.Context, req Request, resp *Response) error
	// SendAsync sends req and returns a future for its response without blocking.
	SendAsync(req Request) ResponseFuture

	// SetRequestHandler installs the callback for requests initiated by the peer.
	SetRequestHandler(handler RequestHandler)
	// SetCloseHandler installs the callback invoked when the client closes.
	SetCloseHandler(handler CloseHandler)
}
// client is the Client implementation returned by NewClient.
type client struct {
	// dialer opens the connection when Connect is called.
	dialer Dialer
	// conn is the connection returned by dialer.Dial; unset until Connect succeeds.
	conn Connection
	// inFlight maps string(request id) to the ResponseFuture awaiting that id's response.
	inFlight sync.Map
	// log is the per-connection logger created in Connect.
	log *log.Entry
	// closed flips to true (once) in Close; the read loop and SendAsync consult it.
	closed atomic.Bool
	// reqHandler receives requests from the peer decoded by the read loop.
	reqHandler RequestHandler
	// closeError is the error that made the read loop close the client; passed to closeHandler.
	// NOTE(review): written by the read goroutine without synchronization.
	closeError error
	// closeHandler, when non-nil, is invoked by the first successful Close.
	closeHandler CloseHandler
}
// NewClient returns a Client that will open its connection through dialer.
// Nothing is dialled until Connect is called.
func NewClient(dialer Dialer) Client {
	c := new(client)
	c.dialer = dialer
	return c
}
// Connect dials a connection through the client's dialer, resets the
// in-flight request table, creates the per-connection logger and starts the
// read loop in its own goroutine. A dial error is returned unchanged.
//
// NOTE(review): a second Connect overwrites conn without closing the previous
// one, and the closed flag is never reset, so Connect after Close starts a read
// loop that exits immediately. The "connectionId" field is still a placeholder.
func (c *client) Connect() error {
	conn, err := c.dialer.Dial()
	if err != nil {
		return err
	}

	c.conn = conn
	c.inFlight = sync.Map{}
	c.log = log.WithField("connectionId", "tbd")

	// The reader loops until the client is marked closed.
	go c.readMessages()
	return nil
}
// SetRequestHandler registers the callback the read loop uses for messages
// from the peer that carry a "method" member (requests and notifications).
// The field is read by the reader goroutine without locking, so set it before
// Connect — NOTE(review): not enforced.
func (c *client) SetRequestHandler(handler RequestHandler) {
	c.reqHandler = handler
}

// SetCloseHandler registers the callback Close invokes with the error that
// caused the client to close (nil when Close was called directly).
func (c *client) SetCloseHandler(handler CloseHandler) {
	c.closeHandler = handler
}
// readMessages is the read loop started by Connect. It keeps reading messages
// until the client is marked closed. Messages with a top-level "method" member
// (JSON-RPC requests and notifications) go to the request handler; all others
// are decoded as responses and passed to onResponse.
//
// A read error that is, or wraps, ErrClosed is recorded in closeError and
// closes the client. Any other read error is logged and the loop moves on to
// the next read, without decoding the (invalid) payload.
func (c *client) readMessages() {
	for !c.closed.Load() {
		data, err := c.conn.Read()
		if err != nil {
			if errors.Is(err, ErrClosed) {
				// remember why we are closing so the close handler can be told
				c.closeError = err
				_ = c.Close()
				return
			}
			// NOTE(review): a connection that keeps failing without ever
			// reporting ErrClosed will keep this loop spinning.
			c.log.WithError(err).Error("read failure")
			continue
		}

		// Requests and notifications carry a "method" member; responses never
		// do. The substring test is only a cheap pre-filter: a response whose
		// result merely mentions "method" is rejected by checking the decoded
		// top-level member.
		isRequest := false
		if strings.Contains(string(data), `"method"`) {
			var probe struct {
				Method json.RawMessage `json:"method"`
			}
			if err := json.Unmarshal(data, &probe); err != nil {
				c.log.WithError(err).Error("unmarshal failure")
				continue
			}
			isRequest = len(probe.Method) > 0
		}

		if isRequest {
			var req Request
			if err := json.Unmarshal(data, &req); err != nil {
				c.log.WithError(err).Error("unmarshal failure")
				continue
			}
			if c.reqHandler == nil {
				// calling a nil handler would panic and kill the read loop
				c.log.Warn("request received but no request handler is set")
				continue
			}
			c.reqHandler(req)
			continue
		}

		var resp Response
		if err := json.Unmarshal(data, &resp); err != nil {
			c.log.WithError(err).Error("unmarshal failure")
			continue
		}
		c.onResponse(&resp)
	}
}
// onResponse removes the in-flight future registered under string(resp.Id)
// and resolves it with resp. A response whose id has no in-flight entry
// (unknown or already answered) is logged and dropped.
func (c *client) onResponse(resp *Response) {
	future, ok := c.inFlight.LoadAndDelete(string(resp.Id))
	if !ok {
		c.log.
			WithField("id", resp.Id).
			Warn("response received with unrecognised id")
		// Without this return, the type assertion below would panic on the
		// nil value and take down the read goroutine.
		return
	}
	future.(ResponseFuture).Set(async.NewResultValue[*Response](resp))
}
// Close marks the client closed, fails every in-flight request with
// ErrClosed, and invokes the close handler (if one is set) with closeError.
// closeError holds the read error when the read loop triggered the close, and
// is nil when Close was called directly.
//
// Only the first call does this work. Later calls return ErrClosed.
func (c *client) Close() error {
	if !c.closed.CompareAndSwap(false, true) {
		return ErrClosed
	}

	// Each entry is claimed with LoadAndDelete before being failed. onResponse
	// claims entries the same way, so a response racing with Close cannot
	// resolve a future that Close has already failed (or vice versa).
	c.inFlight.Range(func(key, _ any) bool {
		if future, loaded := c.inFlight.LoadAndDelete(key); loaded {
			future.(ResponseFuture).Set(async.NewResultErr[*Response](ErrClosed))
		}
		return true
	})

	// NOTE(review): the underlying c.conn is not closed here; if Connection
	// exposes a Close method it probably should be — confirm.
	if c.closeHandler != nil {
		c.closeHandler(c.closeError)
	}
	return nil
}
// Send is SendContext with a background context: it blocks until the response
// for req arrives (or the request fails) and copies it into resp.
func (c *client) Send(req Request, resp *Response) error {
	return c.SendContext(context.Background(), req, resp)
}

// SendContext sends req and waits for either its result or ctx to be done.
// When ctx finishes first, ctx.Err() is returned and resp is untouched. When
// the request fails, that error is returned. Otherwise the response's Id,
// Result, Error and Version are copied into resp.
func (c *client) SendContext(ctx context.Context, req Request, resp *Response) error {
	future := c.SendAsync(req)

	select {
	case outcome := <-future.Get():
		got, err := outcome.Unwrap()
		if err != nil {
			return err
		}
		// TODO can this copy be removed?
		resp.Id, resp.Result, resp.Error, resp.Version = got.Id, got.Result, got.Error, got.Version
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// SendAsync writes req to the connection and returns, without waiting, a
// future for its outcome.
//
// req.EnsureId is called with idGen first (presumably assigning an id only
// when req has none). The returned future is already resolved with an error
// when the id cannot be ensured, when the client is closed (ErrClosed), or
// when req cannot be marshalled to JSON. Otherwise the future is registered in
// inFlight under string(req.Id) and is resolved later by onResponse or Close,
// or here with the write error if sending fails.
func (c *client) SendAsync(req Request) ResponseFuture {
	future := async.NewFuture[async.Result[*Response]]()
	fail := func(err error) ResponseFuture {
		future.Set(async.NewResultErr[*Response](err))
		return future
	}

	if err := req.EnsureId(idGen); err != nil {
		return fail(err)
	}
	if c.closed.Load() {
		// short circuit
		return fail(ErrClosed)
	}
	payload, err := json.Marshal(req)
	if err != nil {
		return fail(errors.Annotate(err, "failed to marshal request to json"))
	}

	id := string(req.Id)
	c.inFlight.Store(id, future)

	// Close may have run between the closed check above and Store, in which
	// case its sweep of inFlight missed this entry and a caller waiting on the
	// future would block forever. Re-check and fail the entry ourselves; if it
	// is already gone, someone else has resolved it.
	if c.closed.Load() {
		if _, loaded := c.inFlight.LoadAndDelete(id); loaded {
			return fail(ErrClosed)
		}
		return future
	}

	if err := c.conn.Write(payload); err != nil {
		// Remove the entry so it does not linger in inFlight. If it has
		// already been removed, a response or Close got to it first.
		if _, loaded := c.inFlight.LoadAndDelete(id); loaded {
			future.Set(async.NewResultErr[*Response](err))
		}
	}
	return future
}