Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: configure multiple host key options #225

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 6 additions & 28 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@ import (
"sync"
"time"

"github.com/moov-io/base/log"
"github.com/moov-io/go-sftp/pkg/sshx"

"github.com/go-kit/kit/metrics/prometheus"
"github.com/moov-io/base/log"
"github.com/pkg/sftp"
stdprometheus "github.com/prometheus/client_golang/prometheus"

"golang.org/x/crypto/ssh"
)

Expand All @@ -39,25 +36,6 @@ var (
}, []string{"hostname"})
)

type ClientConfig struct {
Hostname string
Username string
Password string

Timeout time.Duration
MaxConnections int
PacketSize int

HostPublicKey string

// ClientPrivateKey must be a base64 encoded string
ClientPrivateKey string
ClientPrivateKeyPassword string // not base64 encoded

SkipChmodAfterUpload bool
SkipDirectoryCreation bool
}

type Client interface {
Ping() error
Close() error
Expand Down Expand Up @@ -186,7 +164,7 @@ func (c *client) clearConnectionOnError(err error) error {
var (
hostKeyCallbackOnce sync.Once
hostKeyCallback = func(logger log.Logger) {
msg := "sftp: WARNING!!! Insecure default of skipping SFTP host key validation. Please set HostPublicKey"
msg := "sftp: WARNING!!! Insecure default of skipping SFTP host key validation. Please set HostPublicKey(s)"
if logger != nil {
logger.Warn().Log(msg)
}
Expand All @@ -200,12 +178,12 @@ func sftpConnect(logger log.Logger, cfg ClientConfig) (*ssh.Client, io.WriteClos
}
conf.SetDefaults()

if cfg.HostPublicKey != "" {
pubKey, err := sshx.ReadPubKey([]byte(cfg.HostPublicKey))
if hostKeys := cfg.HostKeys(); len(hostKeys) > 0 {
callback, err := NewMultiKeyCallback(hostKeys)
if err != nil {
return nil, nil, nil, fmt.Errorf("problem parsing ssh public key: %w", err)
return nil, nil, nil, err
}
conf.HostKeyCallback = ssh.FixedHostKey(pubKey)
conf.HostKeyCallback = callback
} else {
hostKeyCallbackOnce.Do(func() {
hostKeyCallback(logger)
Expand Down
51 changes: 51 additions & 0 deletions client_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package go_sftp

import "time"

type ClientConfig struct {
Hostname string
Username string
Password string

Timeout time.Duration
MaxConnections int
PacketSize int

// HostPublicKey configures an SSH public key to validate the remote server's host key.
// If provided, this key will be merged into HostPublicKeys.
// Deprecated: Use HostPublicKeys instead.
HostPublicKey string

// HostPublicKeys configures multiple SSH public keys to validate the remote server's host key.
// Any key provided in HostPublicKey will be appended to this list.
HostPublicKeys []string

// ClientPrivateKey must be a base64 encoded string
ClientPrivateKey string
ClientPrivateKeyPassword string // not base64 encoded

SkipChmodAfterUpload bool
SkipDirectoryCreation bool
}

// HostKeys returns the list of configured public keys to use for host key verification.
func (cfg ClientConfig) HostKeys() []string {
if cfg.HostPublicKey != "" {
cfg.HostPublicKeys = append(cfg.HostPublicKeys, cfg.HostPublicKey)
}

return dedupe(cfg.HostPublicKeys)
}

func dedupe[T comparable](vals []T) []T {
seen := make(map[T]struct{})
var out []T
for i := range vals {
if _, ok := seen[vals[i]]; ok {
continue
}
seen[vals[i]] = struct{}{}
out = append(out, vals[i])
}
return out
}
58 changes: 58 additions & 0 deletions client_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package go_sftp_test

import (
"testing"

sftp "github.com/moov-io/go-sftp"
"github.com/stretchr/testify/require"
)

func TestClientConfig_HostKeys(t *testing.T) {
tests := []struct {
name string
cfg sftp.ClientConfig
want []string
}{
{
name: "no host keys",
cfg: sftp.ClientConfig{},
want: nil,
},
{
name: "only HostPublicKey",
cfg: sftp.ClientConfig{
HostPublicKey: "public-key",
},
want: []string{"public-key"},
},
{
name: "only HostPublicKeys",
cfg: sftp.ClientConfig{
HostPublicKeys: []string{
"public-key-1",
"public-key-2",
},
},
want: []string{"public-key-1", "public-key-2"},
},
{
name: "combined and unique",
cfg: sftp.ClientConfig{
HostPublicKey: "public-key",
HostPublicKeys: []string{
"public-key",
"public-key-1",
"public-key-1",
},
},
want: []string{"public-key", "public-key-1"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.cfg.HostKeys()
require.Equal(t, tt.want, got)
})
}
}
36 changes: 36 additions & 0 deletions hostkeys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package go_sftp

import (
"bytes"
"fmt"
"net"

"github.com/moov-io/go-sftp/pkg/sshx"
"golang.org/x/crypto/ssh"
)

type MultiKeyCallback struct {
hostKeys []ssh.PublicKey
}

func NewMultiKeyCallback(keys []string) (ssh.HostKeyCallback, error) {
m := &MultiKeyCallback{}
for i := range keys {
pubKey, err := sshx.ReadPubKey([]byte(keys[i]))
if err != nil {
return nil, fmt.Errorf("sftp: reading host key at index %d: %w", i, err)
}
m.hostKeys = append(m.hostKeys, pubKey)
}
return m.check, nil
}

// check is an ssh.HostKeyCallback based on ssh.FixedHostKey, running the equality check against each configured key.
func (m *MultiKeyCallback) check(_ string, _ net.Addr, key ssh.PublicKey) error {
for _, mKey := range m.hostKeys {
if bytes.Equal(key.Marshal(), mKey.Marshal()) {
return nil
}
}
return fmt.Errorf("sftp: no matching host keys")
}
67 changes: 67 additions & 0 deletions hostkeys_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package go_sftp_test

import (
"testing"

sftp "github.com/moov-io/go-sftp"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)

const (
rsaKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQD1MU4KKe56DW+cnEomhmk0JMp5dS5LUDvrNM8cRE8i/JxPRsEbrHsta7/1Bj6jutAVTvHVSDrCZ5c+TIXlhSGQEfbjlXMiu9vP4vewdFTfm1xUdryv8MO5+Tas0HlbO9h92aV/SBpBxMLCIBVM9U+zKxmskxR1QMQZ7tzRGMnYMhQD74V6ANnwndDAlWspF+LcaUaDQqjeMDTv86q+ki4uDID5dwvx4eX11exfT+LwCvTMpCKhPJawA7QwnXNVvSEu/4p9EkNKr1xNIoiJdIwOnWrX8kAmlVkwL1cKCQF7wOfneYjKxJUMKwKtPZ9qtMmeOlhO7pLxhbtjcwvfIg69"
ecdsaKey = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEQFGqHGgr0e0jyq2ojt1TJgsFdLrn9w6iYXn1oWvuiOQgVAUL/6vrwQQ7ncbqM7/ZOaonx3C2Kr2IZHIXRmVXc="
ed25519Key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPZ3WQItO2r2wfGrjedz9LGwlLFgIUM6GbIpBKvaxiSz"

mismatchKey = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINnH6Geq7YNlClxNhCMN0IVt1f0XsPyMYqlW5htNYLpy"
)

func TestMultiKeyCallback_Check(t *testing.T) {
tests := []struct {
name string
key string
wantErr bool
}{
{
name: "host key mismatch",
key: mismatchKey,
wantErr: true,
},
{
name: "rsa match",
key: rsaKey,
wantErr: false,
},
{
name: "ecdsa match",
key: "example.io " + ecdsaKey,
wantErr: false,
},
{
name: "ed25519 match",
key: ed25519Key,
wantErr: false,
},
}

callback, err := sftp.NewMultiKeyCallback([]string{
rsaKey,
ecdsaKey,
ed25519Key,
})
require.NoError(t, err)

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hostKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(tt.key))
require.NoError(t, err)

err = callback("", nil, hostKey)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
Loading