Skip to content

Commit

Permalink
Merge pull request #468 from cloudflare/mitali/print-client-info
Browse files Browse the repository at this point in the history
Add a new field to Operation and log client info from connection
  • Loading branch information
mitalirawat authored Jul 5, 2024
2 parents 78e8c53 + 19e29ea commit 31a148e
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,5 @@ snapshot:
docker run --rm --privileged -v $(PWD):/go/tmp \
-v /var/run/docker.sock:/var/run/docker.sock \
-w /go/tmp \
ghcr.io/goreleaser/goreleaser-cross:latest --clean --snapshot --skip-publish
ghcr.io/goreleaser/goreleaser-cross:latest --clean --snapshot --skip publish

4 changes: 3 additions & 1 deletion protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,21 @@ type Operation struct {
ServerIP net.IP
SNI string
CertID string
ForwardingSvc int64
CustomFuncName string
JaegerSpan []byte
ReqContext []byte
}

func (o *Operation) String() string {
return fmt.Sprintf("[Opcode: %v, SKI: %v, Digest: %02x, Client IP: %s, Server IP: %s, SNI: %s]",
return fmt.Sprintf("[Opcode: %v, SKI: %v, Digest: %02x, Client IP: %s, Server IP: %s, SNI: %s, Forwarding Service: %v]",
o.Opcode,
o.SKI,
o.Digest,
o.ClientIP,
o.ServerIP,
o.SNI,
o.ForwardingSvc,
)
}

Expand Down
107 changes: 76 additions & 31 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"net"
"net/rpc"
"os"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -152,6 +153,12 @@ type Sealer interface {
Unseal(*protocol.Operation) ([]byte, error)
}

// ClientInfo has information on the client of the connection
type ClientInfo struct {
Name string
CertSerial string
}

// handler is associated with a connection and contains bookkeeping
// information used across goroutines. The channel tokens limits the
// concurrency: before reading a request a token is extracted, when
Expand All @@ -166,6 +173,7 @@ type handler struct {
conn net.Conn
timeout time.Duration
closed bool
c *ClientInfo
}

func (h *handler) close(err error) {
Expand Down Expand Up @@ -197,6 +205,12 @@ func (h *handler) handle(pkt *protocol.Packet, reqTime time.Time) {
} else {
resp = h.s.unlimitedDo(pkt, h.name)
}

if resp.op.ErrorVal() != protocol.ErrNone {
// log the client certificate information on the connection if the request failed so the caller is apparent
reqID, _ := getOperationRequestID(&pkt.Operation)
log.Errorf("operation from client %s client cert serial: %s errored. sni %s ski %s cert %s request-id %s", h.c.Name, h.c.CertSerial, resp.op.SNI, resp.op.SKI.String(), resp.op.CertID, reqID)
}
logRequestExecDuration(pkt.Operation.Opcode, start, resp.op.ErrorVal())
respPkt := protocol.Packet{
Header: protocol.Header{
Expand Down Expand Up @@ -289,32 +303,61 @@ func makeErrResponse(pkt *protocol.Packet, err protocol.Error) response {
func addOperationRequestID(op *protocol.Operation) string {
reqContext := make(map[string]interface{})
var reqID string
var gen bool

if len(op.ReqContext) > 0 {
if err := json.Unmarshal(op.ReqContext, &reqContext); err == nil {
if v, ok := reqContext["request_id"]; ok {
return v.(string)
} else {
gen = true
}
} else {
log.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext)
if decodeErr := json.Unmarshal(op.ReqContext, &reqContext); decodeErr != nil {
log.Error(fmt.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext))
return reqID
}
}

if v, ok := reqContext["request_id"]; ok {
return v.(string)
}

reqID = uuid.New().String()
reqContext["request_id"] = reqID
b, err := json.Marshal(reqContext)
if err == nil {
op.ReqContext = b
} else {
log.Errorf("error marshaling operation.ReqContext %v, ignoring error", reqContext)
reqID = ""
}
return reqID
}

func getOperationRequestID(op *protocol.Operation) (reqID string, err error) {
reqContext := make(map[string]interface{})
if len(op.ReqContext) == 0 {
return
}
if decodeErr := json.Unmarshal(op.ReqContext, &reqContext); decodeErr == nil {
if v, ok := reqContext["request_id"]; ok {
return v.(string), nil
}
} else {
err = fmt.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext)
log.Error(err)
return
}
return
}

if len(op.ReqContext) == 0 || gen {
reqID = uuid.New().String()
reqContext["request_id"] = reqID
b, err := json.Marshal(reqContext)
if err == nil {
op.ReqContext = b
func getClientInfoFromCerts(certs []*x509.Certificate) *ClientInfo {
cln := []string(nil)
srls := []string(nil)
for _, cert := range certs {
if cert.Subject.CommonName != "" {
cln = append(cln, cert.Subject.CommonName)
} else {
log.Errorf("error marshaling operation.ReqContext %v, ignoring error", reqContext)
reqID = ""
cln = append(cln, cert.DNSNames...)
}
srls = append(srls, cert.SerialNumber.String())
}
return reqID
name := strings.Join(cln, " , ")
serial := strings.Join(srls, " , ")
return &ClientInfo{Name: name, CertSerial: serial}
}

func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
Expand All @@ -328,7 +371,7 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
reqID := addOperationRequestID(&pkt.Operation)
span.SetTag("request_id", reqID)

log.Debugf("connection %s: limited=false opcode=%s id=%d sni=%s ip=%s ski=%v request-id=%s",
log.Debugf("connection %s: limited=false opcode= %s id=%d sni= %s ip= %s ski= %v request-id= %s",
connName,
pkt.Operation.Opcode,
pkt.Header.ID,
Expand Down Expand Up @@ -412,14 +455,14 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {

sig, err := key.Sign(rand.Reader, pkt.Operation.Payload, crypto.Hash(0))
if err != nil {
log.Errorf("Connection: %s: sni=%s ski=%v request-id=%s: Signing error: %v: request-id:%s:", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err, reqID)
log.Errorf("Connection: %s: sni= %s ski= %v request-id= %s: Signing error: %v", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
// This indicates that a remote keyserver is being used
var remoteConfigurationErr RemoteConfigurationErr
if errors.As(err, &remoteConfigurationErr) {
log.Errorf("Connection %v: sni=%s ski=%v request-id=%s: %s: Signing error: %v request-id:%s\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err, reqID)
log.Errorf("Connection %v: sni= %s ski= %v request-id= %s: %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err)
return makeErrResponse(pkt, protocol.ErrRemoteConfiguration)
} else {
log.Errorf("Connection %v: sni=%s ski=%v request-id=%s: %s: Signing error: %v request-id:%s\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err, reqID)
log.Errorf("Connection %v: sni= %s ski= %v request-id= %s: %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
}
Expand All @@ -430,23 +473,23 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
key, err := s.keys.Get(ctx, &pkt.Operation)
logKeyLoadDuration(loadStart)
if err != nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
log.Errorf("failed to load key with sni= %s ip= %s ski=%v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrInternal)
} else if key == nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
log.Errorf("failed to load key with sni= %s ip= %s ski= %v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
return makeErrResponse(pkt, protocol.ErrKeyNotFound)
}

if _, ok := key.Public().(*rsa.PublicKey); !ok {
log.Errorf("Connection %v: sni=%s request-id=%s: %s: Key is not RSA", connName, pkt.Operation.SNI, reqID, protocol.ErrCrypto)
log.Errorf("Connection %v: sni= %s request-id= %s: %s: Key is not RSA", connName, pkt.Operation.SNI, reqID, protocol.ErrCrypto)
return makeErrResponse(pkt, protocol.ErrCrypto)
}

if rsaKey, ok := key.(*rsa.PrivateKey); ok {
// Decrypt without removing padding; that's the client's responsibility.
ptxt, err := textbook_rsa.Decrypt(rsaKey, pkt.Operation.Payload)
if err != nil {
log.Errorf("connection %v: sni=%s ip=%s ski=%v request-id=%s: %v", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
log.Errorf("connection %v: sni= %s ip= %s ski= %v request-id= %s: %v", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
return makeRespondResponse(pkt, ptxt)
Expand Down Expand Up @@ -493,10 +536,10 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
key, err := s.keys.Get(ctx, &pkt.Operation)
logKeyLoadDuration(loadStart)
if err != nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
log.Errorf("failed to load key with sni= %s ip= %s ski= %v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrInternal)
} else if key == nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
log.Errorf("failed to load key with sni= %s ip= %s ski= %v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
return makeErrResponse(pkt, protocol.ErrKeyNotFound)
}

Expand Down Expand Up @@ -526,17 +569,17 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
}
if err != nil {
if attempts > 1 {
log.Debugf("Connection %v sni=%s ip=%s ski=%v request-id=%s : failed sign attempt: %s, %d attempt(s) left", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err, attempts-1)
log.Debugf("Connection %v sni= %s ip= %s ski= %v request-id= %s : failed sign attempt: %s, %d attempt(s) left", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err, attempts-1)
continue
} else {
tracing.LogError(span, err)
// This indicates that a remote keyserver is being used
var remoteConfigurationErr RemoteConfigurationErr
if errors.As(err, &remoteConfigurationErr) {
log.Errorf("Connection %v sni=%s ip=%s ski=%v request-id=%s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err)
log.Errorf("Connection %v sni= %s ip= %s ski= %v request-id= %s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err)
return makeErrResponse(pkt, protocol.ErrRemoteConfiguration)
} else {
log.Errorf("Connection %v sni=%s ip=%s ski=%v request-id=%s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
log.Errorf("Connection %v sni= %s ip= %s ski= %v request-id= %s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
}
Expand Down Expand Up @@ -656,6 +699,7 @@ func (s *Server) spawn(l net.Listener, c net.Conn) {
}
connState := tconn.ConnectionState()
certmetrics.Observe(certmetrics.CertSourceFromCerts(fmt.Sprintf("listener: %s", l.Addr().String()), connState.PeerCertificates)...)
cl := getClientInfoFromCerts(connState.PeerCertificates)
limited, err := s.config.isLimited(connState)
if err != nil {
log.Errorf("connection %v: could not determine if limited: %v", c.RemoteAddr(), err)
Expand Down Expand Up @@ -692,6 +736,7 @@ func (s *Server) spawn(l net.Listener, c net.Conn) {
conn: tconn,
listener: l,
timeout: timeout,
c: cl,
}
err = handler.loop()

Expand Down
1 change: 1 addition & 0 deletions tracing/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func SetOperationSpanTags(span opentracing.Span, op *protocol.Operation) {
"operation.sni": op.SNI,
"operation.certid": op.CertID,
"operation.customfuncname": op.CustomFuncName,
"operation.forwardingsvc": fmt.Sprintf("%d", op.ForwardingSvc),
}
for k, v := range tags {
span.SetTag(k, v)
Expand Down

0 comments on commit 31a148e

Please sign in to comment.