Skip to content

Commit

Permalink
Merge pull request #225 from moov-io/multiple-host-keys
Browse files Browse the repository at this point in the history
feat: configure multiple host key options
  • Loading branch information
adamdecaf authored Oct 22, 2024
2 parents 0d4e3fb + 91e0c97 commit f3906f9
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 28 deletions.
34 changes: 6 additions & 28 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@ import (
"sync"
"time"

"github.com/moov-io/base/log"
"github.com/moov-io/go-sftp/pkg/sshx"

"github.com/go-kit/kit/metrics/prometheus"
"github.com/moov-io/base/log"
"github.com/pkg/sftp"
stdprometheus "github.com/prometheus/client_golang/prometheus"

"golang.org/x/crypto/ssh"
)

Expand All @@ -39,25 +36,6 @@ var (
}, []string{"hostname"})
)

type ClientConfig struct {
Hostname string
Username string
Password string

Timeout time.Duration
MaxConnections int
PacketSize int

HostPublicKey string

// ClientPrivateKey must be a base64 encoded string
ClientPrivateKey string
ClientPrivateKeyPassword string // not base64 encoded

SkipChmodAfterUpload bool
SkipDirectoryCreation bool
}

type Client interface {
Ping() error
Close() error
Expand Down Expand Up @@ -186,7 +164,7 @@ func (c *client) clearConnectionOnError(err error) error {
var (
hostKeyCallbackOnce sync.Once
hostKeyCallback = func(logger log.Logger) {
msg := "sftp: WARNING!!! Insecure default of skipping SFTP host key validation. Please set HostPublicKey"
msg := "sftp: WARNING!!! Insecure default of skipping SFTP host key validation. Please set HostPublicKey(s)"
if logger != nil {
logger.Warn().Log(msg)
}
Expand All @@ -200,12 +178,12 @@ func sftpConnect(logger log.Logger, cfg ClientConfig) (*ssh.Client, io.WriteClos
}
conf.SetDefaults()

if cfg.HostPublicKey != "" {
pubKey, err := sshx.ReadPubKey([]byte(cfg.HostPublicKey))
if hostKeys := cfg.HostKeys(); len(hostKeys) > 0 {
callback, err := NewMultiKeyCallback(hostKeys)
if err != nil {
return nil, nil, nil, fmt.Errorf("problem parsing ssh public key: %w", err)
return nil, nil, nil, err
}
conf.HostKeyCallback = ssh.FixedHostKey(pubKey)
conf.HostKeyCallback = callback
} else {
hostKeyCallbackOnce.Do(func() {
hostKeyCallback(logger)
Expand Down
51 changes: 51 additions & 0 deletions client_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package go_sftp

import "time"

type ClientConfig struct {
Hostname string
Username string
Password string

Timeout time.Duration
MaxConnections int
PacketSize int

// HostPublicKey configures an SSH public key to validate the remote server's host key.
// If provided, this key will be merged into HostPublicKeys.
// Deprecated: Use HostPublicKeys instead.
HostPublicKey string

// HostPublicKeys configures multiple SSH public keys to validate the remote server's host key.
// Any key provided in HostPublicKey will be appended to this list.
HostPublicKeys []string

// ClientPrivateKey must be a base64 encoded string
ClientPrivateKey string
ClientPrivateKeyPassword string // not base64 encoded

SkipChmodAfterUpload bool
SkipDirectoryCreation bool
}

// HostKeys returns the list of configured public keys to use for host key verification.
func (cfg ClientConfig) HostKeys() []string {
if cfg.HostPublicKey != "" {
cfg.HostPublicKeys = append(cfg.HostPublicKeys, cfg.HostPublicKey)
}

return dedupe(cfg.HostPublicKeys)
}

func dedupe[T comparable](vals []T) []T {
seen := make(map[T]struct{})
var out []T
for i := range vals {
if _, ok := seen[vals[i]]; ok {
continue
}
seen[vals[i]] = struct{}{}
out = append(out, vals[i])
}
return out
}
58 changes: 58 additions & 0 deletions client_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package go_sftp_test

import (
"testing"

sftp "github.com/moov-io/go-sftp"
"github.com/stretchr/testify/require"
)

func TestClientConfig_HostKeys(t *testing.T) {
tests := []struct {
name string
cfg sftp.ClientConfig
want []string
}{
{
name: "no host keys",
cfg: sftp.ClientConfig{},
want: nil,
},
{
name: "only HostPublicKey",
cfg: sftp.ClientConfig{
HostPublicKey: "public-key",
},
want: []string{"public-key"},
},
{
name: "only HostPublicKeys",
cfg: sftp.ClientConfig{
HostPublicKeys: []string{
"public-key-1",
"public-key-2",
},
},
want: []string{"public-key-1", "public-key-2"},
},
{
name: "combined and unique",
cfg: sftp.ClientConfig{
HostPublicKey: "public-key",
HostPublicKeys: []string{
"public-key",
"public-key-1",
"public-key-1",
},
},
want: []string{"public-key", "public-key-1"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.cfg.HostKeys()
require.Equal(t, tt.want, got)
})
}
}
36 changes: 36 additions & 0 deletions hostkeys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package go_sftp

import (
"bytes"
"fmt"
"net"

"github.com/moov-io/go-sftp/pkg/sshx"
"golang.org/x/crypto/ssh"
)

type MultiKeyCallback struct {
hostKeys []ssh.PublicKey
}

func NewMultiKeyCallback(keys []string) (ssh.HostKeyCallback, error) {
m := &MultiKeyCallback{}
for i := range keys {
pubKey, err := sshx.ReadPubKey([]byte(keys[i]))
if err != nil {
return nil, fmt.Errorf("sftp: reading host key at index %d: %w", i, err)
}
m.hostKeys = append(m.hostKeys, pubKey)
}
return m.check, nil
}

// check is an ssh.HostKeyCallback based on ssh.FixedHostKey, running the equality check against each configured key.
func (m *MultiKeyCallback) check(_ string, _ net.Addr, key ssh.PublicKey) error {
for _, mKey := range m.hostKeys {
if bytes.Equal(key.Marshal(), mKey.Marshal()) {
return nil
}
}
return fmt.Errorf("sftp: no matching host keys")
}
67 changes: 67 additions & 0 deletions hostkeys_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package go_sftp_test

import (
"testing"

sftp "github.com/moov-io/go-sftp"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)

const (
rsaKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQD1MU4KKe56DW+cnEomhmk0JMp5dS5LUDvrNM8cRE8i/JxPRsEbrHsta7/1Bj6jutAVTvHVSDrCZ5c+TIXlhSGQEfbjlXMiu9vP4vewdFTfm1xUdryv8MO5+Tas0HlbO9h92aV/SBpBxMLCIBVM9U+zKxmskxR1QMQZ7tzRGMnYMhQD74V6ANnwndDAlWspF+LcaUaDQqjeMDTv86q+ki4uDID5dwvx4eX11exfT+LwCvTMpCKhPJawA7QwnXNVvSEu/4p9EkNKr1xNIoiJdIwOnWrX8kAmlVkwL1cKCQF7wOfneYjKxJUMKwKtPZ9qtMmeOlhO7pLxhbtjcwvfIg69"
ecdsaKey = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEQFGqHGgr0e0jyq2ojt1TJgsFdLrn9w6iYXn1oWvuiOQgVAUL/6vrwQQ7ncbqM7/ZOaonx3C2Kr2IZHIXRmVXc="
ed25519Key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPZ3WQItO2r2wfGrjedz9LGwlLFgIUM6GbIpBKvaxiSz"

mismatchKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINnH6Geq7YNlClxNhCMN0IVt1f0XsPyMYqlW5htNYLpy"
)

func TestMultiKeyCallback_Check(t *testing.T) {
tests := []struct {
name string
key string
wantErr bool
}{
{
name: "host key mismatch",
key: mismatchKey,
wantErr: true,
},
{
name: "rsa match",
key: rsaKey,
wantErr: false,
},
{
name: "ecdsa match",
key: "example.io " + ecdsaKey,
wantErr: false,
},
{
name: "ed25519 match",
key: ed25519Key,
wantErr: false,
},
}

callback, err := sftp.NewMultiKeyCallback([]string{
rsaKey,
ecdsaKey,
ed25519Key,
})
require.NoError(t, err)

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hostKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(tt.key))
require.NoError(t, err)

err = callback("", nil, hostKey)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

0 comments on commit f3906f9

Please sign in to comment.