diff --git a/integration/db/db_integration_test.go b/integration/db/db_integration_test.go index ec68c78c0ed96..bac905afdb65a 100644 --- a/integration/db/db_integration_test.go +++ b/integration/db/db_integration_test.go @@ -44,6 +44,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db" "github.com/gravitational/teleport/lib/srv/db/cassandra" "github.com/gravitational/teleport/lib/srv/db/common" + dbconnect "github.com/gravitational/teleport/lib/srv/db/common/connect" "github.com/gravitational/teleport/lib/srv/db/mongodb" "github.com/gravitational/teleport/lib/srv/db/mysql" "github.com/gravitational/teleport/lib/srv/db/postgres" @@ -607,7 +608,7 @@ func TestDatabaseAccessPostgresSeparateListenerTLSDisabled(t *testing.T) { func init() { // Override database agents shuffle behavior to ensure they're always // tried in the same order during tests. Used for HA tests. - db.SetShuffleFunc(db.ShuffleSort) + db.SetShuffleFunc(dbconnect.ShuffleSort) } // testHARootCluster verifies that proxy falls back to a healthy diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 37848ea294c8e..c7704a1b19de6 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -81,6 +81,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/clickhouse" "github.com/gravitational/teleport/lib/srv/db/cloud" "github.com/gravitational/teleport/lib/srv/db/common" + dbconnect "github.com/gravitational/teleport/lib/srv/db/common/connect" "github.com/gravitational/teleport/lib/srv/db/dynamodb" "github.com/gravitational/teleport/lib/srv/db/elasticsearch" "github.com/gravitational/teleport/lib/srv/db/mongodb" @@ -2257,7 +2258,7 @@ func (c *testContext) Close() error { func init() { // Override database agents shuffle behavior to ensure they're always // tried in the same order during tests. Used for HA tests. - SetShuffleFunc(ShuffleSort) + SetShuffleFunc(dbconnect.ShuffleSort) } func setupTestContext(ctx context.Context, t testing.TB, withDatabases ...withDatabaseOption) *testContext { diff --git a/lib/srv/db/common/connect/connect.go b/lib/srv/db/common/connect/connect.go new file mode 100644 index 0000000000000..ee8252f78c392 --- /dev/null +++ b/lib/srv/db/common/connect/connect.go @@ -0,0 +1,352 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connect + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "log/slog" + "math/rand/v2" + "net" + "sort" + "strings" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client/proto" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/reversetunnelclient" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" +) + +// DatabaseServersGetter is an interface for retrieving information about +// database proxy servers within a specific namespace. +type DatabaseServersGetter interface { + // GetDatabaseServers returns all registered database proxy servers. + GetDatabaseServers(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.DatabaseServer, error) +} + +// GetDatabaseServersParams contains the parameters required to retrieve +// database servers from a specific cluster. +type GetDatabaseServersParams struct { + Logger *slog.Logger + // ClusterName is the cluster name to which the database belongs. + ClusterName string + // DatabaseServersGetter used to fetch the list of database servers. + DatabaseServersGetter DatabaseServersGetter + // Identity contains the identity information. + Identity tlsca.Identity +} + +// GetDatabaseServers returns a list of database servers in a cluster that match +// the routing information from the provided identity. +func GetDatabaseServers(ctx context.Context, params GetDatabaseServersParams) ([]types.DatabaseServer, error) { + servers, err := params.DatabaseServersGetter.GetDatabaseServers(ctx, apidefaults.Namespace) + if err != nil { + return nil, trace.Wrap(err) + } + + params.Logger.DebugContext(ctx, "Available database servers.", "cluster", params.ClusterName, "servers", servers) + + // Find out which database servers proxy the database a user is + // connecting to using routing information from identity. + var result []types.DatabaseServer + for _, server := range servers { + if server.GetDatabase().GetName() == params.Identity.RouteToDatabase.ServiceName { + result = append(result, server) + } + } + + if len(result) != 0 { + return result, nil + } + + return nil, trace.NotFound("database %q not found among registered databases in cluster %q", + params.Identity.RouteToDatabase.ServiceName, + params.Identity.RouteToCluster) +} + +// DatabaseCertificateSigner defines an interface for signing database +// Certificate Signing Requests (CSRs). +type DatabaseCertificateSigner interface { + // SignDatabaseCSR generates a client certificate used by proxy when talking + // to a remote database service. + SignDatabaseCSR(ctx context.Context, req *proto.DatabaseCSRRequest) (*proto.DatabaseCSRResponse, error) +} + +// AuthPreferenceGetter is an interface for retrieving the current configured +// cluster auth preference. +type AuthPreferenceGetter interface { + // GetAuthPreference returns the current cluster auth preference. + GetAuthPreference(context.Context) (types.AuthPreference, error) +} + +// ServerTLSConfigParams contains the parameters required to configure +// a TLS connection to a database server. +type ServerTLSConfigParams struct { + // CertSigner is the interface used to sign certificate signing requests + // for establishing a secure TLS connection. + CertSigner DatabaseCertificateSigner + // AuthPreference provides the authentication preference configuration + // used to determine cryptographic settings for certificate generation. + AuthPreference AuthPreferenceGetter + // Server represents the database server for which the TLS configuration + // is being generated. + Server types.DatabaseServer + // Identity contains the identity information. + Identity tlsca.Identity +} + +// GetServerTLSConfig returns TLS config used for establishing connection +// to a remote database server over reverse tunnel. +func GetServerTLSConfig(ctx context.Context, params ServerTLSConfigParams) (*tls.Config, error) { + privateKey, err := cryptosuites.GenerateKey(ctx, + cryptosuites.GetCurrentSuiteFromAuthPreference(params.AuthPreference), + cryptosuites.ProxyToDatabaseAgent) + if err != nil { + return nil, trace.Wrap(err) + } + + subject, err := params.Identity.Subject() + if err != nil { + return nil, trace.Wrap(err) + } + + csr, err := tlsca.GenerateCertificateRequestPEM(subject, privateKey) + if err != nil { + return nil, trace.Wrap(err) + } + + response, err := params.CertSigner.SignDatabaseCSR(ctx, &proto.DatabaseCSRRequest{ + CSR: csr, + ClusterName: params.Identity.RouteToCluster, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + cert, err := keys.TLSCertificateForSigner(privateKey, response.Cert) + if err != nil { + return nil, trace.Wrap(err) + } + + pool := x509.NewCertPool() + for _, caCert := range response.CACerts { + ok := pool.AppendCertsFromPEM(caCert) + if !ok { + return nil, trace.BadParameter("failed to append CA certificate") + } + } + + return &tls.Config{ + ServerName: params.Server.GetHostname(), + Certificates: []tls.Certificate{cert}, + RootCAs: pool, + }, nil +} + +// ShuffleFunc defines a function that shuffles a list of database servers. +type ShuffleFunc func([]types.DatabaseServer) []types.DatabaseServer + +// ShuffleSort is a ShuffleFunc that sorts database servers by name and host ID. +// Used to provide predictable behavior in tests. +func ShuffleSort(servers []types.DatabaseServer) []types.DatabaseServer { + sort.Sort(types.DatabaseServers(servers)) + return servers +} + +// ShuffleRandom is a ShuffleFunc that randomizes the order of database servers. +// Used to provide load balancing behavior when proxying to multiple agents. +func ShuffleRandom(servers []types.DatabaseServer) []types.DatabaseServer { + rand.Shuffle(len(servers), func(i, j int) { + servers[i], servers[j] = servers[j], servers[i] + }) + return servers +} + +type Dialer interface { + // Dial dials any address within the site network, in terminating + // mode it uses local instance of forwarding server to terminate + // and record the connection. + Dial(params reversetunnelclient.DialParams) (conn net.Conn, err error) +} + +// ConnectParams contains parameters for connecting to the database server. +type ConnectParams struct { + Logger *slog.Logger + // Identity contains the identity information. + Identity tlsca.Identity + // Servers is the list of database servers that can handle the connection. + Servers []types.DatabaseServer + // ShuffleFunc is a function used to shuffle the list of database servers. + ShuffleFunc ShuffleFunc + // Cluster is the cluster name to which the database belongs. + ClusterName string + // Cluster represents the cluster to which the database belongs. + Dialer Dialer + // CertSigner is used to sign certificates for authenticating with the + // database. + CertSigner DatabaseCertificateSigner + // AuthPreference provides the authentication preferences for the cluster. + AuthPreference AuthPreferenceGetter + // ClientSrcAddr is the source address of the client making the connection. + ClientSrcAddr net.Addr + // ClientDstAddr is the destination address of the client making the + // connection. + ClientDstAddr net.Addr +} + +func (p *ConnectParams) CheckAndSetDefaults() error { + if p.Logger != nil { + p.Logger = slog.Default() + } + + if p.Identity.RouteToDatabase.Empty() { + return trace.BadParameter("Identity must have RouteToDatabase information") + } + + if p.ShuffleFunc == nil { + p.ShuffleFunc = ShuffleRandom + } + + if p.ClusterName == "" { + return trace.BadParameter("missing ClusterName parameter") + } + + if p.Dialer == nil { + return trace.BadParameter("missing Dialer parameter") + } + + if p.CertSigner == nil { + return trace.BadParameter("missing CertSigner parameter") + } + + if p.AuthPreference == nil { + return trace.BadParameter("missing AuthPreference parameter") + } + + if p.ClientSrcAddr == nil { + return trace.BadParameter("missing ClientSrcAddr parameter") + } + + if p.ClientDstAddr == nil { + return trace.BadParameter("missing ClientDstAddr parameter") + } + + return nil +} + +// ConnectStats contains statistics about the connection attempts. +type ConnectStats interface { + // GetAttemptedServers retrieves the number of database servers that were + // attempted to connect to. + GetAttemptedServers() int + // GetDialAttempts retrieves the number of times a dial to a server was + // attempted. + GetDialAttempts() int + // GetDialFailures retrieves the number of times a dial to a server failed. + GetDialFailures() int +} + +// Stats implements [ConnectStats]. +type Stats struct { + attemptedServers int + dialAttempts int + dialFailures int +} + +// GetAttemptedServers implements [ConnectStats]. +func (s Stats) GetAttemptedServers() int { + return s.attemptedServers +} + +// GetDialAttempts implements [ConnectStats]. +func (s Stats) GetDialAttempts() int { + return s.dialAttempts +} + +// GetDialFailures implements ConnectStats. +func (s Stats) GetDialFailures() int { + return s.dialFailures +} + +// Connect connects to the database server running on a remote cluster +// over reverse tunnel and upgrades this end of the connection to TLS so +// the identity can be passed over it. +func Connect(ctx context.Context, params ConnectParams) (net.Conn, ConnectStats, error) { + stats := Stats{} + if err := params.CheckAndSetDefaults(); err != nil { + return nil, stats, trace.Wrap(err) + } + + // There may be multiple database servers proxying the same database. If + // we get a connection problem error trying to dial one of them, likely + // the database server is down so try the next one. + for _, server := range params.ShuffleFunc(params.Servers) { + stats.attemptedServers++ + params.Logger.DebugContext(ctx, "Dialing to database service.", "server", server) + tlsConfig, err := GetServerTLSConfig(ctx, ServerTLSConfigParams{ + AuthPreference: params.AuthPreference, + CertSigner: params.CertSigner, + Identity: params.Identity, + Server: server, + }) + if err != nil { + return nil, stats, trace.Wrap(err) + } + + stats.dialAttempts++ + serviceConn, err := params.Dialer.Dial(reversetunnelclient.DialParams{ + From: params.ClientSrcAddr, + To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnelclient.LocalNode}, + OriginalClientDstAddr: params.ClientDstAddr, + ServerID: fmt.Sprintf("%v.%v", server.GetHostID(), params.ClusterName), + ConnType: types.DatabaseTunnel, + ProxyIDs: server.GetProxyIDs(), + }) + if err != nil { + stats.dialFailures++ + // If an agent is down, we'll retry on the next one (if available). + if isReverseTunnelDownError(err) { + params.Logger.WarnContext(ctx, "Failed to dial database service.", "server", server, "error", err) + continue + } + return nil, stats, trace.Wrap(err) + } + // Upgrade the connection so the client identity can be passed to the + // remote server during TLS handshake. On the remote side, the connection + // received from the reverse tunnel will be handled by tls.Server. + serviceConn = tls.Client(serviceConn, tlsConfig) + return serviceConn, stats, nil + } + + return nil, stats, trace.BadParameter("failed to connect to any of the database servers") +} + +// isReverseTunnelDownError returns true if the provided error indicates that +// the reverse tunnel connection is down e.g. because the agent is down. +func isReverseTunnelDownError(err error) bool { + return trace.IsConnectionProblem(err) || + strings.Contains(err.Error(), reversetunnelclient.NoDatabaseTunnel) +} diff --git a/lib/srv/db/common/connect/connect_test.go b/lib/srv/db/common/connect/connect_test.go new file mode 100644 index 0000000000000..e1b6fa60afd37 --- /dev/null +++ b/lib/srv/db/common/connect/connect_test.go @@ -0,0 +1,308 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connect + +import ( + "context" + "crypto/tls" + "net" + "net/netip" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/reversetunnelclient" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" +) + +func TestGetDatabaseServers(t *testing.T) { + for name, tc := range map[string]struct { + identity tlsca.Identity + getter *databaseServersMock + expectErrorFunc require.ErrorAssertionFunc + expectedServersLen int + }{ + "match": { + identity: identityWithDatabase("matched-db", "root", "alice", nil), + getter: newDatabaseServersWithServers("no-match", "matched-db", "another-db"), + expectErrorFunc: require.NoError, + expectedServersLen: 1, + }, + "no match": { + identity: identityWithDatabase("no-match", "root", "alice", nil), + getter: newDatabaseServersWithServers("first", "second", "third"), + expectErrorFunc: func(tt require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.True(t, trace.IsNotFound(err), "expected trace.NotFound error but got %T", err) + }, + }, + "get server error": { + identity: identityWithDatabase("no-match", "root", "alice", nil), + getter: newDatabaseServersWithErr(trace.Errorf("failure")), + expectErrorFunc: require.Error, + }, + } { + t.Run(name, func(t *testing.T) { + servers, err := GetDatabaseServers(context.Background(), GetDatabaseServersParams{ + Logger: utils.NewSlogLoggerForTests(), + ClusterName: "root", + DatabaseServersGetter: tc.getter, + Identity: tc.identity, + }) + tc.expectErrorFunc(t, err) + require.Len(t, servers, tc.expectedServersLen) + }) + } +} + +func TestGetServerTLSConfig(t *testing.T) { + clusterName := "root" + authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ + Clock: clockwork.NewFakeClockAt(time.Now()), + ClusterName: clusterName, + AuthPreferenceSpec: &types.AuthPreferenceSpecV2{ + SignatureAlgorithmSuite: types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_BALANCED_V1, + }, + Dir: t.TempDir(), + }) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, authServer.Close()) }) + + user, role, err := auth.CreateUserAndRole(authServer.AuthServer, "alice", []string{"db-access"}, nil) + require.NoError(t, err) + + for name, tc := range map[string]struct { + server types.DatabaseServer + identity tlsca.Identity + expectErrorFunc require.ErrorAssertionFunc + expectTLSConfigFunc require.ValueAssertionFunc + }{ + "generates the config": { + server: databaseServerWithName("db", "server1"), + identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}), + expectErrorFunc: require.NoError, + expectTLSConfigFunc: func(tt require.TestingT, tlsConfigI interface{}, _ ...interface{}) { + require.IsType(t, &tls.Config{}, tlsConfigI) + tlsConfig, _ := tlsConfigI.(*tls.Config) + require.Len(t, tlsConfig.Certificates, 1) + + ca, err := tlsca.FromTLSCertificate(tlsConfig.Certificates[0]) + require.NoError(t, err, "failed to extract CA from TLS certificate") + + identity, err := tlsca.FromSubject(ca.Cert.Subject, ca.Cert.NotAfter) + require.NoError(t, err, "failed to convert certificate subject into tlsca.Identity") + require.Equal(t, clusterName, identity.TeleportCluster) + require.ElementsMatch(t, []string{teleport.UsageDatabaseOnly}, identity.Usage) + }, + }, + "failed to generate config due to missing information on identity": { + server: databaseServerWithName("db", "server1"), + identity: tlsca.Identity{}, + expectErrorFunc: require.Error, + expectTLSConfigFunc: require.Nil, + }, + } { + t.Run(name, func(t *testing.T) { + tlsConfig, err := GetServerTLSConfig(context.Background(), ServerTLSConfigParams{ + CertSigner: authServer.AuthServer, + AuthPreference: authServer.AuthServer, + Server: tc.server, + Identity: tc.identity, + }) + tc.expectErrorFunc(t, err) + tc.expectTLSConfigFunc(t, tlsConfig) + }) + } +} + +func TestConnect(t *testing.T) { + clusterName := "root" + authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ + Clock: clockwork.NewFakeClockAt(time.Now()), + ClusterName: clusterName, + AuthPreferenceSpec: &types.AuthPreferenceSpecV2{ + SignatureAlgorithmSuite: types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_BALANCED_V1, + }, + Dir: t.TempDir(), + }) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, authServer.Close()) }) + + user, role, err := auth.CreateUserAndRole(authServer.AuthServer, "alice", []string{"db-access"}, nil) + require.NoError(t, err) + + for name, tc := range map[string]struct { + identity tlsca.Identity + dialer *dialerMock + expectErrFunc require.ErrorAssertionFunc + expectedStats ConnectStats + }{ + "connects": { + identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}), + dialer: newDialerMock(t, authServer.AuthServer, "db", []string{"server-1", "server-2"}, nil), + expectErrFunc: require.NoError, + expectedStats: Stats{attemptedServers: 1, dialAttempts: 1}, + }, + "connects but with dial failures": { + identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}), + // Given the shuffle function, the "server-1" will be attempted first (cause the initial failure). + dialer: newDialerMock(t, authServer.AuthServer, "db", []string{"server-2"}, []string{"server-1"}), + expectErrFunc: require.NoError, + expectedStats: Stats{attemptedServers: 2, dialAttempts: 2, dialFailures: 1}, + }, + "fails to connect": { + identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}), + dialer: newDialerMock(t, authServer.AuthServer, "db", nil, []string{"server-1"}), + expectErrFunc: require.Error, + expectedStats: Stats{attemptedServers: 1, dialAttempts: 1, dialFailures: 1}, + }, + "no servers": { + identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}), + dialer: newDialerMock(t, authServer.AuthServer, "db", nil, nil), + expectErrFunc: require.Error, + expectedStats: Stats{}, + }, + } { + t.Run(name, func(t *testing.T) { + conn, stats, err := Connect(context.Background(), ConnectParams{ + Logger: utils.NewSlogLoggerForTests(), + Identity: tc.identity, + Servers: tc.dialer.getServers(), + ShuffleFunc: ShuffleSort, + ClusterName: clusterName, + Dialer: tc.dialer, + CertSigner: authServer.AuthServer, + AuthPreference: authServer.AuthServer, + ClientSrcAddr: net.TCPAddrFromAddrPort(netip.MustParseAddrPort("0.0.0.0:3000")), + ClientDstAddr: net.TCPAddrFromAddrPort(netip.MustParseAddrPort("0.0.0.0:3000")), + }) + tc.expectErrFunc(t, err) + require.Equal(t, tc.expectedStats, stats) + if conn != nil { + conn.Close() + } + }) + } +} + +func identityWithDatabase(name, clusterName, user string, roles []string) tlsca.Identity { + return tlsca.Identity{ + RouteToCluster: clusterName, + TeleportCluster: clusterName, + Username: user, + Groups: roles, + RouteToDatabase: tlsca.RouteToDatabase{ + ServiceName: name, + Protocol: defaults.ProtocolPostgres, + Username: "postgres", + Database: "postgres", + }, + } +} + +type databaseServersMock struct { + servers []types.DatabaseServer + err error +} + +func databaseServerWithName(name, hostId string) types.DatabaseServer { + return &types.DatabaseServerV3{ + Spec: types.DatabaseServerSpecV3{ + Database: &types.DatabaseV3{ + Metadata: types.Metadata{ + Name: name, + }, + }, + HostID: hostId, + Hostname: name, + }, + } +} + +func newDatabaseServersWithServers(dbNames ...string) *databaseServersMock { + var servers []types.DatabaseServer + for _, name := range dbNames { + servers = append(servers, databaseServerWithName(name, uuid.New().String())) + } + + return &databaseServersMock{servers: servers} +} + +func newDatabaseServersWithErr(err error) *databaseServersMock { + return &databaseServersMock{err: err} +} + +func (d *databaseServersMock) GetDatabaseServers(_ context.Context, _ string, _ ...services.MarshalOption) ([]types.DatabaseServer, error) { + return d.servers, d.err +} + +func newDialerMock(t *testing.T, authServer *auth.Server, dbName string, availableServers []string, unavailableServers []string) *dialerMock { + m := &dialerMock{serverConfig: make(map[string]*tls.Config)} + for _, host := range availableServers { + serverIdentity, err := auth.NewServerIdentity(authServer, host, types.RoleDatabase) + require.NoError(t, err) + tlsConfig, err := serverIdentity.TLSConfig(nil) + require.NoError(t, err) + + m.serverConfig[host] = tlsConfig + m.servers = append(m.servers, databaseServerWithName(dbName, host)) + } + + for _, host := range unavailableServers { + m.servers = append(m.servers, databaseServerWithName(dbName, host)) + } + + return m +} + +type dialerMock struct { + servers []types.DatabaseServer + serverConfig map[string]*tls.Config +} + +func (m *dialerMock) Dial(params reversetunnelclient.DialParams) (conn net.Conn, err error) { + hostID, _, _ := strings.Cut(params.ServerID, ".") + tlsConfig, ok := m.serverConfig[hostID] + if !ok { + return nil, trace.ConnectionProblem(nil, reversetunnelclient.NoDatabaseTunnel) + } + + // Start a fake database server that only performs the TLS handshake. + clt, srv := net.Pipe() + go func() { + defer srv.Close() + conn := tls.Server(srv, tlsConfig) + _ = conn.Handshake() + }() + + return clt, nil +} + +func (m *dialerMock) getServers() []types.DatabaseServer { + return m.servers +} diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index 3a2bacd8610e1..885ab194c048b 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -21,16 +21,12 @@ package db import ( "context" "crypto/tls" - "crypto/x509" "errors" - "fmt" "io" "log/slog" "math/rand/v2" "net" - "sort" "strconv" - "strings" "sync" "time" @@ -38,27 +34,23 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/client/proto" - apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" - "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/srv/db/common/connect" "github.com/gravitational/teleport/lib/srv/db/common/enterprise" "github.com/gravitational/teleport/lib/srv/db/dbutils" "github.com/gravitational/teleport/lib/srv/db/mysql" "github.com/gravitational/teleport/lib/srv/db/postgres" "github.com/gravitational/teleport/lib/srv/db/sqlserver" "github.com/gravitational/teleport/lib/srv/ingress" - "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -105,41 +97,22 @@ type ProxyServerConfig struct { MySQLServerVersion string } -// ShuffleFunc defines a function that shuffles a list of database servers. -type ShuffleFunc func([]types.DatabaseServer) []types.DatabaseServer - -// ShuffleRandom is a ShuffleFunc that randomizes the order of database servers. -// Used to provide load balancing behavior when proxying to multiple agents. -func ShuffleRandom(servers []types.DatabaseServer) []types.DatabaseServer { - rand.Shuffle(len(servers), func(i, j int) { - servers[i], servers[j] = servers[j], servers[i] - }) - return servers -} - -// ShuffleSort is a ShuffleFunc that sorts database servers by name and host ID. -// Used to provide predictable behavior in tests. -func ShuffleSort(servers []types.DatabaseServer) []types.DatabaseServer { - sort.Sort(types.DatabaseServers(servers)) - return servers -} - var ( // mu protects the shuffleFunc global access. mu sync.RWMutex // shuffleFunc provides shuffle behavior for multiple database agents. - shuffleFunc ShuffleFunc = ShuffleRandom + shuffleFunc connect.ShuffleFunc = connect.ShuffleRandom ) // SetShuffleFunc sets the shuffle behavior when proxying to multiple agents. -func SetShuffleFunc(fn ShuffleFunc) { +func SetShuffleFunc(fn connect.ShuffleFunc) { mu.Lock() defer mu.Unlock() shuffleFunc = fn } // getShuffleFunc returns the configured function used to shuffle agents. -func getShuffleFunc() ShuffleFunc { +func getShuffleFunc() connect.ShuffleFunc { mu.RLock() defer mu.RUnlock() return shuffleFunc @@ -451,59 +424,38 @@ func (s *ProxyServer) Connect(ctx context.Context, proxyCtx *common.ProxyContext } else { labels = getLabelsFromDB(nil) } - labels["available_db_servers"] = strconv.Itoa(len(proxyCtx.Servers)) - defer observeLatency(connectionSetupTime.With(labels))() - var attemptedServers int + var ( + serviceConn net.Conn + stats connect.ConnectStats + err error + ) + defer func() { - dialAttemptedServers.With(labels).Observe(float64(attemptedServers)) + dialAttemptedServers.With(labels).Observe(float64(stats.GetAttemptedServers())) + dialAttempts.With(labels).Add(float64(stats.GetDialAttempts())) + dialFailures.With(labels).Add(float64(stats.GetDialFailures())) }() - // There may be multiple database servers proxying the same database. If - // we get a connection problem error trying to dial one of them, likely - // the database server is down so try the next one. - for _, server := range getShuffleFunc()(proxyCtx.Servers) { - attemptedServers++ - s.log.DebugContext(ctx, "Dialing to database service.", "server", server) - tlsConfig, err := s.getConfigForServer(ctx, proxyCtx.Identity, server) - if err != nil { - return nil, trace.Wrap(err) - } - - dialAttempts.With(labels).Inc() - serviceConn, err := proxyCtx.Cluster.Dial(reversetunnelclient.DialParams{ - From: clientSrcAddr, - To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnelclient.LocalNode}, - OriginalClientDstAddr: clientDstAddr, - ServerID: fmt.Sprintf("%v.%v", server.GetHostID(), proxyCtx.Cluster.GetName()), - ConnType: types.DatabaseTunnel, - ProxyIDs: server.GetProxyIDs(), - }) - if err != nil { - dialFailures.With(labels).Inc() - // If an agent is down, we'll retry on the next one (if available). - if isReverseTunnelDownError(err) { - s.log.WarnContext(ctx, "Failed to dial database service.", "server", server, "error", err) - continue - } - return nil, trace.Wrap(err) - } - // Upgrade the connection so the client identity can be passed to the - // remote server during TLS handshake. On the remote side, the connection - // received from the reverse tunnel will be handled by tls.Server. - serviceConn = tls.Client(serviceConn, tlsConfig) - return serviceConn, nil + serviceConn, stats, err = connect.Connect(ctx, connect.ConnectParams{ + Logger: s.log, + Identity: proxyCtx.Identity, + Servers: proxyCtx.Servers, + ShuffleFunc: getShuffleFunc(), + ClusterName: proxyCtx.Cluster.GetName(), + Dialer: proxyCtx.Cluster, + CertSigner: s.cfg.AuthClient, + AuthPreference: s.cfg.AccessPoint, + ClientSrcAddr: clientSrcAddr, + ClientDstAddr: clientDstAddr, + }) + if err != nil { + return nil, trace.Wrap(err) } - return nil, trace.BadParameter("failed to connect to any of the database servers") -} -// isReverseTunnelDownError returns true if the provided error indicates that -// the reverse tunnel connection is down e.g. because the agent is down. -func isReverseTunnelDownError(err error) bool { - return trace.IsConnectionProblem(err) || - strings.Contains(err.Error(), reversetunnelclient.NoDatabaseTunnel) + return serviceConn, nil } // Proxy starts proxying all traffic received from database client between @@ -567,94 +519,28 @@ func (s *ProxyServer) Authorize(ctx context.Context, tlsConn utils.TLSConn, para if params.ClientIP != "" { identity.LoginIP = params.ClientIP } - cluster, servers, err := s.getDatabaseServers(ctx, identity) - if err != nil { - return nil, trace.Wrap(err) - } - return &common.ProxyContext{ - Identity: identity, - Cluster: cluster, - Servers: servers, - AuthContext: authContext, - }, nil -} - -// getDatabaseServers finds database servers that proxy the database instance -// encoded in the provided identity. -func (s *ProxyServer) getDatabaseServers(ctx context.Context, identity tlsca.Identity) (reversetunnelclient.RemoteSite, []types.DatabaseServer, error) { cluster, err := s.cfg.Tunnel.GetSite(identity.RouteToCluster) - if err != nil { - return nil, nil, trace.Wrap(err) - } - accessPoint, err := cluster.CachingAccessPoint() - if err != nil { - return nil, nil, trace.Wrap(err) - } - servers, err := accessPoint.GetDatabaseServers(ctx, apidefaults.Namespace) - if err != nil { - return nil, nil, trace.Wrap(err) - } - s.log.DebugContext(ctx, "Available database servers.", "cluster", cluster.GetName(), "servers", servers) - // Find out which database servers proxy the database a user is - // connecting to using routing information from identity. - var result []types.DatabaseServer - for _, server := range servers { - if server.GetDatabase().GetName() == identity.RouteToDatabase.ServiceName { - result = append(result, server) - } - } - if len(result) != 0 { - return cluster, result, nil - } - return nil, nil, trace.NotFound("database %q not found among registered databases in cluster %q", - identity.RouteToDatabase.ServiceName, - identity.RouteToCluster) -} - -// getConfigForServer returns TLS config used for establishing connection -// to a remote database server over reverse tunnel. -func (s *ProxyServer) getConfigForServer(ctx context.Context, identity tlsca.Identity, server types.DatabaseServer) (*tls.Config, error) { - defer observeLatency(tlsConfigTime.With(getLabelsFromDB(server.GetDatabase())))() - - privateKey, err := cryptosuites.GenerateKey(ctx, - cryptosuites.GetCurrentSuiteFromAuthPreference(s.cfg.AccessPoint), - cryptosuites.ProxyToDatabaseAgent) - if err != nil { - return nil, trace.Wrap(err) - } - subject, err := identity.Subject() if err != nil { return nil, trace.Wrap(err) } - csr, err := tlsca.GenerateCertificateRequestPEM(subject, privateKey) + accessPoint, err := cluster.CachingAccessPoint() if err != nil { return nil, trace.Wrap(err) } - - response, err := s.cfg.AuthClient.SignDatabaseCSR(ctx, &proto.DatabaseCSRRequest{ - CSR: csr, - ClusterName: identity.RouteToCluster, + servers, err := connect.GetDatabaseServers(ctx, connect.GetDatabaseServersParams{ + Logger: s.log, + ClusterName: cluster.GetName(), + DatabaseServersGetter: accessPoint, + Identity: identity, }) if err != nil { return nil, trace.Wrap(err) } - - cert, err := keys.TLSCertificateForSigner(privateKey, response.Cert) - if err != nil { - return nil, trace.Wrap(err) - } - pool := x509.NewCertPool() - for _, caCert := range response.CACerts { - ok := pool.AppendCertsFromPEM(caCert) - if !ok { - return nil, trace.BadParameter("failed to append CA certificate") - } - } - - return &tls.Config{ - ServerName: server.GetHostname(), - Certificates: []tls.Certificate{cert}, - RootCAs: pool, + return &common.ProxyContext{ + Identity: identity, + Cluster: cluster, + Servers: servers, + AuthContext: authContext, }, nil }