From bbc2bfeb9e79ab1b446d7368f996b9b443c61cb2 Mon Sep 17 00:00:00 2001 From: "Remi GASCOU (Podalirius)" <79218792+p0dalirius@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:03:58 +0100 Subject: [PATCH] Added APOptions []int in client.GSSAPIBindRequest(...) and client.InitSecContext(...), fixes #536 --- gssapi/client.go | 55 ++++++++++++++++++++++++++++++++++++++++++-- v3/bind.go | 18 +++++++++------ v3/gssapi/client.go | 56 ++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 117 insertions(+), 12 deletions(-) diff --git a/gssapi/client.go b/gssapi/client.go index d6c6dbd4..e7018849 100644 --- a/gssapi/client.go +++ b/gssapi/client.go @@ -1,10 +1,15 @@ package gssapi import ( + "bytes" + "encoding/binary" + "encoding/hex" + "errors" "fmt" "github.com/jcmturner/gokrb5/v8/client" "github.com/jcmturner/gokrb5/v8/config" + "github.com/jcmturner/gokrb5/v8/iana/flags" "github.com/jcmturner/gokrb5/v8/keytab" "github.com/jcmturner/gokrb5/v8/types" @@ -110,7 +115,7 @@ func (client *Client) InitSecContext(target string, input []byte) ([]byte, bool, } client.ekey = ekey - token, err := spnego.NewKRB5TokenAPREQ(client.Client, tkt, ekey, gssapiFlags, []int{}) + token, err := spnego.NewKRB5TokenAPREQ(client.Client, tkt, ekey, gssapiFlags, []int{flags.APOptionMutualRequired}) if err != nil { return nil, false, err } @@ -160,7 +165,7 @@ func (client *Client) InitSecContext(target string, input []byte) ([]byte, bool, // See RFC 4752 section 3.1. func (client *Client) NegotiateSaslAuth(input []byte, authzid string) ([]byte, error) { token := &gssapi.WrapToken{} - err := token.Unmarshal(input, true) + err := UnmarshalWrapToken(token, input, true) if err != nil { return nil, err } @@ -212,3 +217,49 @@ func (client *Client) NegotiateSaslAuth(input []byte, authzid string) ([]byte, e return output, nil } + +func getGssWrapTokenId() *[2]byte { + return &[2]byte{0x05, 0x04} +} + +func UnmarshalWrapToken(wt *gssapi.WrapToken, b []byte, expectFromAcceptor bool) error { + // Check if we can read a whole header + if len(b) < 16 { + return errors.New("bytes shorter than header length") + } + // Is the Token ID correct? + if !bytes.Equal(getGssWrapTokenId()[:], b[0:2]) { + return fmt.Errorf("wrong Token ID. Expected %s, was %s", + hex.EncodeToString(getGssWrapTokenId()[:]), + hex.EncodeToString(b[0:2])) + } + // Check the acceptor flag + flags := b[2] + isFromAcceptor := flags&0x01 == 1 + if isFromAcceptor && !expectFromAcceptor { + return errors.New("unexpected acceptor flag is set: not expecting a token from the acceptor") + } + if !isFromAcceptor && expectFromAcceptor { + return errors.New("expected acceptor flag is not set: expecting a token from the acceptor, not the initiator") + } + // Check the filler byte + if b[3] != gssapi.FillerByte { + return fmt.Errorf("unexpected filler byte: expecting 0xFF, was %s ", hex.EncodeToString(b[3:4])) + } + checksumL := binary.BigEndian.Uint16(b[4:6]) + // Sanity check on the checksum length + if int(checksumL) > len(b)-gssapi.HdrLen { + return fmt.Errorf("inconsistent checksum length: %d bytes to parse, checksum length is %d", len(b), checksumL) + } + + payloadStart := 16 + checksumL + + wt.Flags = flags + wt.EC = checksumL + wt.RRC = binary.BigEndian.Uint16(b[6:8]) + wt.SndSeqNum = binary.BigEndian.Uint64(b[8:16]) + wt.CheckSum = b[16:payloadStart] + wt.Payload = b[payloadStart:] + + return nil +} diff --git a/v3/bind.go b/v3/bind.go index a37f8e2c..b7011bda 100644 --- a/v3/bind.go +++ b/v3/bind.go @@ -576,7 +576,7 @@ type GSSAPIClient interface { // reply token is received from the server, passing the reply token // to InitSecContext via the token parameters. // See RFC 4752 section 3.1. - InitSecContext(target string, token []byte) (outputToken []byte, needContinue bool, err error) + InitSecContext(target string, token []byte, APOptions []int) (outputToken []byte, needContinue bool, err error) // NegotiateSaslAuth performs the last step of the Sasl handshake. // It takes a token, which, when unwrapped, describes the servers supported // security layers (first octet) and maximum receive buffer (remaining @@ -606,14 +606,18 @@ type GSSAPIBindRequest struct { // GSSAPIBind performs the GSSAPI SASL bind using the provided GSSAPI client. func (l *Conn) GSSAPIBind(client GSSAPIClient, servicePrincipal, authzid string) error { - return l.GSSAPIBindRequest(client, &GSSAPIBindRequest{ - ServicePrincipalName: servicePrincipal, - AuthZID: authzid, - }) + return l.GSSAPIBindRequest( + client, + &GSSAPIBindRequest{ + ServicePrincipalName: servicePrincipal, + AuthZID: authzid, + }, + []int{}, + ) } // GSSAPIBindRequest performs the GSSAPI SASL bind using the provided GSSAPI client. -func (l *Conn) GSSAPIBindRequest(client GSSAPIClient, req *GSSAPIBindRequest) error { +func (l *Conn) GSSAPIBindRequest(client GSSAPIClient, req *GSSAPIBindRequest, APOptions []int) error { //nolint:errcheck defer client.DeleteSecContext() @@ -624,7 +628,7 @@ func (l *Conn) GSSAPIBindRequest(client GSSAPIClient, req *GSSAPIBindRequest) er for { if needInit { // Establish secure context between client and server. - reqToken, needInit, err = client.InitSecContext(req.ServicePrincipalName, recvToken) + reqToken, needInit, err = client.InitSecContext(req.ServicePrincipalName, recvToken, APOptions) if err != nil { return err } diff --git a/v3/gssapi/client.go b/v3/gssapi/client.go index d6c6dbd4..2ae2c006 100644 --- a/v3/gssapi/client.go +++ b/v3/gssapi/client.go @@ -1,6 +1,10 @@ package gssapi import ( + "bytes" + "encoding/binary" + "encoding/hex" + "errors" "fmt" "github.com/jcmturner/gokrb5/v8/client" @@ -99,7 +103,7 @@ func (client *Client) DeleteSecContext() error { // InitSecContext initiates the establishment of a security context for // GSS-API between the client and server. // See RFC 4752 section 3.1. -func (client *Client) InitSecContext(target string, input []byte) ([]byte, bool, error) { +func (client *Client) InitSecContext(target string, input []byte, APOptions []int) ([]byte, bool, error) { gssapiFlags := []int{gssapi.ContextFlagInteg, gssapi.ContextFlagConf, gssapi.ContextFlagMutual} switch input { @@ -110,7 +114,7 @@ func (client *Client) InitSecContext(target string, input []byte) ([]byte, bool, } client.ekey = ekey - token, err := spnego.NewKRB5TokenAPREQ(client.Client, tkt, ekey, gssapiFlags, []int{}) + token, err := spnego.NewKRB5TokenAPREQ(client.Client, tkt, ekey, gssapiFlags, APOptions) if err != nil { return nil, false, err } @@ -160,7 +164,7 @@ func (client *Client) InitSecContext(target string, input []byte) ([]byte, bool, // See RFC 4752 section 3.1. func (client *Client) NegotiateSaslAuth(input []byte, authzid string) ([]byte, error) { token := &gssapi.WrapToken{} - err := token.Unmarshal(input, true) + err := UnmarshalWrapToken(token, input, true) if err != nil { return nil, err } @@ -212,3 +216,49 @@ func (client *Client) NegotiateSaslAuth(input []byte, authzid string) ([]byte, e return output, nil } + +func getGssWrapTokenId() *[2]byte { + return &[2]byte{0x05, 0x04} +} + +func UnmarshalWrapToken(wt *gssapi.WrapToken, b []byte, expectFromAcceptor bool) error { + // Check if we can read a whole header + if len(b) < 16 { + return errors.New("bytes shorter than header length") + } + // Is the Token ID correct? + if !bytes.Equal(getGssWrapTokenId()[:], b[0:2]) { + return fmt.Errorf("wrong Token ID. Expected %s, was %s", + hex.EncodeToString(getGssWrapTokenId()[:]), + hex.EncodeToString(b[0:2])) + } + // Check the acceptor flag + flags := b[2] + isFromAcceptor := flags&0x01 == 1 + if isFromAcceptor && !expectFromAcceptor { + return errors.New("unexpected acceptor flag is set: not expecting a token from the acceptor") + } + if !isFromAcceptor && expectFromAcceptor { + return errors.New("expected acceptor flag is not set: expecting a token from the acceptor, not the initiator") + } + // Check the filler byte + if b[3] != gssapi.FillerByte { + return fmt.Errorf("unexpected filler byte: expecting 0xFF, was %s ", hex.EncodeToString(b[3:4])) + } + checksumL := binary.BigEndian.Uint16(b[4:6]) + // Sanity check on the checksum length + if int(checksumL) > len(b)-gssapi.HdrLen { + return fmt.Errorf("inconsistent checksum length: %d bytes to parse, checksum length is %d", len(b), checksumL) + } + + payloadStart := 16 + checksumL + + wt.Flags = flags + wt.EC = checksumL + wt.RRC = binary.BigEndian.Uint16(b[6:8]) + wt.SndSeqNum = binary.BigEndian.Uint64(b[8:16]) + wt.CheckSum = b[16:payloadStart] + wt.Payload = b[payloadStart:] + + return nil +}