diff --git a/integration/db/db_integration_test.go b/integration/db/db_integration_test.go
index ec68c78c0ed96..bac905afdb65a 100644
--- a/integration/db/db_integration_test.go
+++ b/integration/db/db_integration_test.go
@@ -44,6 +44,7 @@ import (
"github.com/gravitational/teleport/lib/srv/db"
"github.com/gravitational/teleport/lib/srv/db/cassandra"
"github.com/gravitational/teleport/lib/srv/db/common"
+ dbconnect "github.com/gravitational/teleport/lib/srv/db/common/connect"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
"github.com/gravitational/teleport/lib/srv/db/mysql"
"github.com/gravitational/teleport/lib/srv/db/postgres"
@@ -607,7 +608,7 @@ func TestDatabaseAccessPostgresSeparateListenerTLSDisabled(t *testing.T) {
func init() {
// Override database agents shuffle behavior to ensure they're always
// tried in the same order during tests. Used for HA tests.
- db.SetShuffleFunc(db.ShuffleSort)
+ db.SetShuffleFunc(dbconnect.ShuffleSort)
}
// testHARootCluster verifies that proxy falls back to a healthy
diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go
index 37848ea294c8e..c7704a1b19de6 100644
--- a/lib/srv/db/access_test.go
+++ b/lib/srv/db/access_test.go
@@ -81,6 +81,7 @@ import (
"github.com/gravitational/teleport/lib/srv/db/clickhouse"
"github.com/gravitational/teleport/lib/srv/db/cloud"
"github.com/gravitational/teleport/lib/srv/db/common"
+ dbconnect "github.com/gravitational/teleport/lib/srv/db/common/connect"
"github.com/gravitational/teleport/lib/srv/db/dynamodb"
"github.com/gravitational/teleport/lib/srv/db/elasticsearch"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
@@ -2257,7 +2258,7 @@ func (c *testContext) Close() error {
func init() {
// Override database agents shuffle behavior to ensure they're always
// tried in the same order during tests. Used for HA tests.
- SetShuffleFunc(ShuffleSort)
+ SetShuffleFunc(dbconnect.ShuffleSort)
}
func setupTestContext(ctx context.Context, t testing.TB, withDatabases ...withDatabaseOption) *testContext {
diff --git a/lib/srv/db/common/connect/connect.go b/lib/srv/db/common/connect/connect.go
new file mode 100644
index 0000000000000..ee8252f78c392
--- /dev/null
+++ b/lib/srv/db/common/connect/connect.go
@@ -0,0 +1,352 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package connect
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "log/slog"
+ "math/rand/v2"
+ "net"
+ "sort"
+ "strings"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/client/proto"
+ apidefaults "github.com/gravitational/teleport/api/defaults"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/api/utils/keys"
+ "github.com/gravitational/teleport/lib/cryptosuites"
+ "github.com/gravitational/teleport/lib/reversetunnelclient"
+ "github.com/gravitational/teleport/lib/services"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+// DatabaseServersGetter is an interface for retrieving information about
+// database proxy servers within a specific namespace.
+type DatabaseServersGetter interface {
+ // GetDatabaseServers returns all registered database proxy servers.
+ GetDatabaseServers(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.DatabaseServer, error)
+}
+
+// GetDatabaseServersParams contains the parameters required to retrieve
+// database servers from a specific cluster.
+type GetDatabaseServersParams struct {
+ Logger *slog.Logger
+ // ClusterName is the cluster name to which the database belongs.
+ ClusterName string
+ // DatabaseServersGetter used to fetch the list of database servers.
+ DatabaseServersGetter DatabaseServersGetter
+ // Identity contains the identity information.
+ Identity tlsca.Identity
+}
+
+// GetDatabaseServers returns a list of database servers in a cluster that match
+// the routing information from the provided identity.
+func GetDatabaseServers(ctx context.Context, params GetDatabaseServersParams) ([]types.DatabaseServer, error) {
+ servers, err := params.DatabaseServersGetter.GetDatabaseServers(ctx, apidefaults.Namespace)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ params.Logger.DebugContext(ctx, "Available database servers.", "cluster", params.ClusterName, "servers", servers)
+
+ // Find out which database servers proxy the database a user is
+ // connecting to using routing information from identity.
+ var result []types.DatabaseServer
+ for _, server := range servers {
+ if server.GetDatabase().GetName() == params.Identity.RouteToDatabase.ServiceName {
+ result = append(result, server)
+ }
+ }
+
+ if len(result) != 0 {
+ return result, nil
+ }
+
+ return nil, trace.NotFound("database %q not found among registered databases in cluster %q",
+ params.Identity.RouteToDatabase.ServiceName,
+ params.Identity.RouteToCluster)
+}
+
+// DatabaseCertificateSigner defines an interface for signing database
+// Certificate Signing Requests (CSRs).
+type DatabaseCertificateSigner interface {
+ // SignDatabaseCSR generates a client certificate used by proxy when talking
+ // to a remote database service.
+ SignDatabaseCSR(ctx context.Context, req *proto.DatabaseCSRRequest) (*proto.DatabaseCSRResponse, error)
+}
+
+// AuthPreferenceGetter is an interface for retrieving the current configured
+// cluster auth preference.
+type AuthPreferenceGetter interface {
+ // GetAuthPreference returns the current cluster auth preference.
+ GetAuthPreference(context.Context) (types.AuthPreference, error)
+}
+
+// ServerTLSConfigParams contains the parameters required to configure
+// a TLS connection to a database server.
+type ServerTLSConfigParams struct {
+ // CertSigner is the interface used to sign certificate signing requests
+ // for establishing a secure TLS connection.
+ CertSigner DatabaseCertificateSigner
+ // AuthPreference provides the authentication preference configuration
+ // used to determine cryptographic settings for certificate generation.
+ AuthPreference AuthPreferenceGetter
+ // Server represents the database server for which the TLS configuration
+ // is being generated.
+ Server types.DatabaseServer
+ // Identity contains the identity information.
+ Identity tlsca.Identity
+}
+
+// GetServerTLSConfig returns TLS config used for establishing connection
+// to a remote database server over reverse tunnel.
+func GetServerTLSConfig(ctx context.Context, params ServerTLSConfigParams) (*tls.Config, error) {
+ privateKey, err := cryptosuites.GenerateKey(ctx,
+ cryptosuites.GetCurrentSuiteFromAuthPreference(params.AuthPreference),
+ cryptosuites.ProxyToDatabaseAgent)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ subject, err := params.Identity.Subject()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ csr, err := tlsca.GenerateCertificateRequestPEM(subject, privateKey)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ response, err := params.CertSigner.SignDatabaseCSR(ctx, &proto.DatabaseCSRRequest{
+ CSR: csr,
+ ClusterName: params.Identity.RouteToCluster,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ cert, err := keys.TLSCertificateForSigner(privateKey, response.Cert)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ pool := x509.NewCertPool()
+ for _, caCert := range response.CACerts {
+ ok := pool.AppendCertsFromPEM(caCert)
+ if !ok {
+ return nil, trace.BadParameter("failed to append CA certificate")
+ }
+ }
+
+ return &tls.Config{
+ ServerName: params.Server.GetHostname(),
+ Certificates: []tls.Certificate{cert},
+ RootCAs: pool,
+ }, nil
+}
+
+// ShuffleFunc defines a function that shuffles a list of database servers.
+type ShuffleFunc func([]types.DatabaseServer) []types.DatabaseServer
+
+// ShuffleSort is a ShuffleFunc that sorts database servers by name and host ID.
+// Used to provide predictable behavior in tests.
+func ShuffleSort(servers []types.DatabaseServer) []types.DatabaseServer {
+ sort.Sort(types.DatabaseServers(servers))
+ return servers
+}
+
+// ShuffleRandom is a ShuffleFunc that randomizes the order of database servers.
+// Used to provide load balancing behavior when proxying to multiple agents.
+func ShuffleRandom(servers []types.DatabaseServer) []types.DatabaseServer {
+ rand.Shuffle(len(servers), func(i, j int) {
+ servers[i], servers[j] = servers[j], servers[i]
+ })
+ return servers
+}
+
+type Dialer interface {
+ // Dial dials any address within the site network, in terminating
+ // mode it uses local instance of forwarding server to terminate
+ // and record the connection.
+ Dial(params reversetunnelclient.DialParams) (conn net.Conn, err error)
+}
+
+// ConnectParams contains parameters for connecting to the database server.
+type ConnectParams struct {
+ Logger *slog.Logger
+ // Identity contains the identity information.
+ Identity tlsca.Identity
+ // Servers is the list of database servers that can handle the connection.
+ Servers []types.DatabaseServer
+ // ShuffleFunc is a function used to shuffle the list of database servers.
+ ShuffleFunc ShuffleFunc
+ // Cluster is the cluster name to which the database belongs.
+ ClusterName string
+ // Cluster represents the cluster to which the database belongs.
+ Dialer Dialer
+ // CertSigner is used to sign certificates for authenticating with the
+ // database.
+ CertSigner DatabaseCertificateSigner
+ // AuthPreference provides the authentication preferences for the cluster.
+ AuthPreference AuthPreferenceGetter
+ // ClientSrcAddr is the source address of the client making the connection.
+ ClientSrcAddr net.Addr
+ // ClientDstAddr is the destination address of the client making the
+ // connection.
+ ClientDstAddr net.Addr
+}
+
+func (p *ConnectParams) CheckAndSetDefaults() error {
+ if p.Logger != nil {
+ p.Logger = slog.Default()
+ }
+
+ if p.Identity.RouteToDatabase.Empty() {
+ return trace.BadParameter("Identity must have RouteToDatabase information")
+ }
+
+ if p.ShuffleFunc == nil {
+ p.ShuffleFunc = ShuffleRandom
+ }
+
+ if p.ClusterName == "" {
+ return trace.BadParameter("missing ClusterName parameter")
+ }
+
+ if p.Dialer == nil {
+ return trace.BadParameter("missing Dialer parameter")
+ }
+
+ if p.CertSigner == nil {
+ return trace.BadParameter("missing CertSigner parameter")
+ }
+
+ if p.AuthPreference == nil {
+ return trace.BadParameter("missing AuthPreference parameter")
+ }
+
+ if p.ClientSrcAddr == nil {
+ return trace.BadParameter("missing ClientSrcAddr parameter")
+ }
+
+ if p.ClientDstAddr == nil {
+ return trace.BadParameter("missing ClientDstAddr parameter")
+ }
+
+ return nil
+}
+
+// ConnectStats contains statistics about the connection attempts.
+type ConnectStats interface {
+ // GetAttemptedServers retrieves the number of database servers that were
+ // attempted to connect to.
+ GetAttemptedServers() int
+ // GetDialAttempts retrieves the number of times a dial to a server was
+ // attempted.
+ GetDialAttempts() int
+ // GetDialFailures retrieves the number of times a dial to a server failed.
+ GetDialFailures() int
+}
+
+// Stats implements [ConnectStats].
+type Stats struct {
+ attemptedServers int
+ dialAttempts int
+ dialFailures int
+}
+
+// GetAttemptedServers implements [ConnectStats].
+func (s Stats) GetAttemptedServers() int {
+ return s.attemptedServers
+}
+
+// GetDialAttempts implements [ConnectStats].
+func (s Stats) GetDialAttempts() int {
+ return s.dialAttempts
+}
+
+// GetDialFailures implements ConnectStats.
+func (s Stats) GetDialFailures() int {
+ return s.dialFailures
+}
+
+// Connect connects to the database server running on a remote cluster
+// over reverse tunnel and upgrades this end of the connection to TLS so
+// the identity can be passed over it.
+func Connect(ctx context.Context, params ConnectParams) (net.Conn, ConnectStats, error) {
+ stats := Stats{}
+ if err := params.CheckAndSetDefaults(); err != nil {
+ return nil, stats, trace.Wrap(err)
+ }
+
+ // There may be multiple database servers proxying the same database. If
+ // we get a connection problem error trying to dial one of them, likely
+ // the database server is down so try the next one.
+ for _, server := range params.ShuffleFunc(params.Servers) {
+ stats.attemptedServers++
+ params.Logger.DebugContext(ctx, "Dialing to database service.", "server", server)
+ tlsConfig, err := GetServerTLSConfig(ctx, ServerTLSConfigParams{
+ AuthPreference: params.AuthPreference,
+ CertSigner: params.CertSigner,
+ Identity: params.Identity,
+ Server: server,
+ })
+ if err != nil {
+ return nil, stats, trace.Wrap(err)
+ }
+
+ stats.dialAttempts++
+ serviceConn, err := params.Dialer.Dial(reversetunnelclient.DialParams{
+ From: params.ClientSrcAddr,
+ To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnelclient.LocalNode},
+ OriginalClientDstAddr: params.ClientDstAddr,
+ ServerID: fmt.Sprintf("%v.%v", server.GetHostID(), params.ClusterName),
+ ConnType: types.DatabaseTunnel,
+ ProxyIDs: server.GetProxyIDs(),
+ })
+ if err != nil {
+ stats.dialFailures++
+ // If an agent is down, we'll retry on the next one (if available).
+ if isReverseTunnelDownError(err) {
+ params.Logger.WarnContext(ctx, "Failed to dial database service.", "server", server, "error", err)
+ continue
+ }
+ return nil, stats, trace.Wrap(err)
+ }
+ // Upgrade the connection so the client identity can be passed to the
+ // remote server during TLS handshake. On the remote side, the connection
+ // received from the reverse tunnel will be handled by tls.Server.
+ serviceConn = tls.Client(serviceConn, tlsConfig)
+ return serviceConn, stats, nil
+ }
+
+ return nil, stats, trace.BadParameter("failed to connect to any of the database servers")
+}
+
+// isReverseTunnelDownError returns true if the provided error indicates that
+// the reverse tunnel connection is down e.g. because the agent is down.
+func isReverseTunnelDownError(err error) bool {
+ return trace.IsConnectionProblem(err) ||
+ strings.Contains(err.Error(), reversetunnelclient.NoDatabaseTunnel)
+}
diff --git a/lib/srv/db/common/connect/connect_test.go b/lib/srv/db/common/connect/connect_test.go
new file mode 100644
index 0000000000000..e1b6fa60afd37
--- /dev/null
+++ b/lib/srv/db/common/connect/connect_test.go
@@ -0,0 +1,308 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package connect
+
+import (
+ "context"
+ "crypto/tls"
+ "net"
+ "net/netip"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/auth"
+ "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/reversetunnelclient"
+ "github.com/gravitational/teleport/lib/services"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+func TestGetDatabaseServers(t *testing.T) {
+ for name, tc := range map[string]struct {
+ identity tlsca.Identity
+ getter *databaseServersMock
+ expectErrorFunc require.ErrorAssertionFunc
+ expectedServersLen int
+ }{
+ "match": {
+ identity: identityWithDatabase("matched-db", "root", "alice", nil),
+ getter: newDatabaseServersWithServers("no-match", "matched-db", "another-db"),
+ expectErrorFunc: require.NoError,
+ expectedServersLen: 1,
+ },
+ "no match": {
+ identity: identityWithDatabase("no-match", "root", "alice", nil),
+ getter: newDatabaseServersWithServers("first", "second", "third"),
+ expectErrorFunc: func(tt require.TestingT, err error, i ...interface{}) {
+ require.Error(t, err)
+ require.True(t, trace.IsNotFound(err), "expected trace.NotFound error but got %T", err)
+ },
+ },
+ "get server error": {
+ identity: identityWithDatabase("no-match", "root", "alice", nil),
+ getter: newDatabaseServersWithErr(trace.Errorf("failure")),
+ expectErrorFunc: require.Error,
+ },
+ } {
+ t.Run(name, func(t *testing.T) {
+ servers, err := GetDatabaseServers(context.Background(), GetDatabaseServersParams{
+ Logger: utils.NewSlogLoggerForTests(),
+ ClusterName: "root",
+ DatabaseServersGetter: tc.getter,
+ Identity: tc.identity,
+ })
+ tc.expectErrorFunc(t, err)
+ require.Len(t, servers, tc.expectedServersLen)
+ })
+ }
+}
+
+func TestGetServerTLSConfig(t *testing.T) {
+ clusterName := "root"
+ authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
+ Clock: clockwork.NewFakeClockAt(time.Now()),
+ ClusterName: clusterName,
+ AuthPreferenceSpec: &types.AuthPreferenceSpecV2{
+ SignatureAlgorithmSuite: types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_BALANCED_V1,
+ },
+ Dir: t.TempDir(),
+ })
+ require.NoError(t, err)
+ t.Cleanup(func() { require.NoError(t, authServer.Close()) })
+
+ user, role, err := auth.CreateUserAndRole(authServer.AuthServer, "alice", []string{"db-access"}, nil)
+ require.NoError(t, err)
+
+ for name, tc := range map[string]struct {
+ server types.DatabaseServer
+ identity tlsca.Identity
+ expectErrorFunc require.ErrorAssertionFunc
+ expectTLSConfigFunc require.ValueAssertionFunc
+ }{
+ "generates the config": {
+ server: databaseServerWithName("db", "server1"),
+ identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}),
+ expectErrorFunc: require.NoError,
+ expectTLSConfigFunc: func(tt require.TestingT, tlsConfigI interface{}, _ ...interface{}) {
+ require.IsType(t, &tls.Config{}, tlsConfigI)
+ tlsConfig, _ := tlsConfigI.(*tls.Config)
+ require.Len(t, tlsConfig.Certificates, 1)
+
+ ca, err := tlsca.FromTLSCertificate(tlsConfig.Certificates[0])
+ require.NoError(t, err, "failed to extract CA from TLS certificate")
+
+ identity, err := tlsca.FromSubject(ca.Cert.Subject, ca.Cert.NotAfter)
+ require.NoError(t, err, "failed to convert certificate subject into tlsca.Identity")
+ require.Equal(t, clusterName, identity.TeleportCluster)
+ require.ElementsMatch(t, []string{teleport.UsageDatabaseOnly}, identity.Usage)
+ },
+ },
+ "failed to generate config due to missing information on identity": {
+ server: databaseServerWithName("db", "server1"),
+ identity: tlsca.Identity{},
+ expectErrorFunc: require.Error,
+ expectTLSConfigFunc: require.Nil,
+ },
+ } {
+ t.Run(name, func(t *testing.T) {
+ tlsConfig, err := GetServerTLSConfig(context.Background(), ServerTLSConfigParams{
+ CertSigner: authServer.AuthServer,
+ AuthPreference: authServer.AuthServer,
+ Server: tc.server,
+ Identity: tc.identity,
+ })
+ tc.expectErrorFunc(t, err)
+ tc.expectTLSConfigFunc(t, tlsConfig)
+ })
+ }
+}
+
+func TestConnect(t *testing.T) {
+ clusterName := "root"
+ authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
+ Clock: clockwork.NewFakeClockAt(time.Now()),
+ ClusterName: clusterName,
+ AuthPreferenceSpec: &types.AuthPreferenceSpecV2{
+ SignatureAlgorithmSuite: types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_BALANCED_V1,
+ },
+ Dir: t.TempDir(),
+ })
+ require.NoError(t, err)
+ t.Cleanup(func() { require.NoError(t, authServer.Close()) })
+
+ user, role, err := auth.CreateUserAndRole(authServer.AuthServer, "alice", []string{"db-access"}, nil)
+ require.NoError(t, err)
+
+ for name, tc := range map[string]struct {
+ identity tlsca.Identity
+ dialer *dialerMock
+ expectErrFunc require.ErrorAssertionFunc
+ expectedStats ConnectStats
+ }{
+ "connects": {
+ identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}),
+ dialer: newDialerMock(t, authServer.AuthServer, "db", []string{"server-1", "server-2"}, nil),
+ expectErrFunc: require.NoError,
+ expectedStats: Stats{attemptedServers: 1, dialAttempts: 1},
+ },
+ "connects but with dial failures": {
+ identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}),
+ // Given the shuffle function, the "server-1" will be attempted first (cause the initial failure).
+ dialer: newDialerMock(t, authServer.AuthServer, "db", []string{"server-2"}, []string{"server-1"}),
+ expectErrFunc: require.NoError,
+ expectedStats: Stats{attemptedServers: 2, dialAttempts: 2, dialFailures: 1},
+ },
+ "fails to connect": {
+ identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}),
+ dialer: newDialerMock(t, authServer.AuthServer, "db", nil, []string{"server-1"}),
+ expectErrFunc: require.Error,
+ expectedStats: Stats{attemptedServers: 1, dialAttempts: 1, dialFailures: 1},
+ },
+ "no servers": {
+ identity: identityWithDatabase("db", clusterName, user.GetName(), []string{role.GetName()}),
+ dialer: newDialerMock(t, authServer.AuthServer, "db", nil, nil),
+ expectErrFunc: require.Error,
+ expectedStats: Stats{},
+ },
+ } {
+ t.Run(name, func(t *testing.T) {
+ conn, stats, err := Connect(context.Background(), ConnectParams{
+ Logger: utils.NewSlogLoggerForTests(),
+ Identity: tc.identity,
+ Servers: tc.dialer.getServers(),
+ ShuffleFunc: ShuffleSort,
+ ClusterName: clusterName,
+ Dialer: tc.dialer,
+ CertSigner: authServer.AuthServer,
+ AuthPreference: authServer.AuthServer,
+ ClientSrcAddr: net.TCPAddrFromAddrPort(netip.MustParseAddrPort("0.0.0.0:3000")),
+ ClientDstAddr: net.TCPAddrFromAddrPort(netip.MustParseAddrPort("0.0.0.0:3000")),
+ })
+ tc.expectErrFunc(t, err)
+ require.Equal(t, tc.expectedStats, stats)
+ if conn != nil {
+ conn.Close()
+ }
+ })
+ }
+}
+
+func identityWithDatabase(name, clusterName, user string, roles []string) tlsca.Identity {
+ return tlsca.Identity{
+ RouteToCluster: clusterName,
+ TeleportCluster: clusterName,
+ Username: user,
+ Groups: roles,
+ RouteToDatabase: tlsca.RouteToDatabase{
+ ServiceName: name,
+ Protocol: defaults.ProtocolPostgres,
+ Username: "postgres",
+ Database: "postgres",
+ },
+ }
+}
+
+type databaseServersMock struct {
+ servers []types.DatabaseServer
+ err error
+}
+
+func databaseServerWithName(name, hostId string) types.DatabaseServer {
+ return &types.DatabaseServerV3{
+ Spec: types.DatabaseServerSpecV3{
+ Database: &types.DatabaseV3{
+ Metadata: types.Metadata{
+ Name: name,
+ },
+ },
+ HostID: hostId,
+ Hostname: name,
+ },
+ }
+}
+
+func newDatabaseServersWithServers(dbNames ...string) *databaseServersMock {
+ var servers []types.DatabaseServer
+ for _, name := range dbNames {
+ servers = append(servers, databaseServerWithName(name, uuid.New().String()))
+ }
+
+ return &databaseServersMock{servers: servers}
+}
+
+func newDatabaseServersWithErr(err error) *databaseServersMock {
+ return &databaseServersMock{err: err}
+}
+
+func (d *databaseServersMock) GetDatabaseServers(_ context.Context, _ string, _ ...services.MarshalOption) ([]types.DatabaseServer, error) {
+ return d.servers, d.err
+}
+
+func newDialerMock(t *testing.T, authServer *auth.Server, dbName string, availableServers []string, unavailableServers []string) *dialerMock {
+ m := &dialerMock{serverConfig: make(map[string]*tls.Config)}
+ for _, host := range availableServers {
+ serverIdentity, err := auth.NewServerIdentity(authServer, host, types.RoleDatabase)
+ require.NoError(t, err)
+ tlsConfig, err := serverIdentity.TLSConfig(nil)
+ require.NoError(t, err)
+
+ m.serverConfig[host] = tlsConfig
+ m.servers = append(m.servers, databaseServerWithName(dbName, host))
+ }
+
+ for _, host := range unavailableServers {
+ m.servers = append(m.servers, databaseServerWithName(dbName, host))
+ }
+
+ return m
+}
+
+type dialerMock struct {
+ servers []types.DatabaseServer
+ serverConfig map[string]*tls.Config
+}
+
+func (m *dialerMock) Dial(params reversetunnelclient.DialParams) (conn net.Conn, err error) {
+ hostID, _, _ := strings.Cut(params.ServerID, ".")
+ tlsConfig, ok := m.serverConfig[hostID]
+ if !ok {
+ return nil, trace.ConnectionProblem(nil, reversetunnelclient.NoDatabaseTunnel)
+ }
+
+ // Start a fake database server that only performs the TLS handshake.
+ clt, srv := net.Pipe()
+ go func() {
+ defer srv.Close()
+ conn := tls.Server(srv, tlsConfig)
+ _ = conn.Handshake()
+ }()
+
+ return clt, nil
+}
+
+func (m *dialerMock) getServers() []types.DatabaseServer {
+ return m.servers
+}
diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go
index 3a2bacd8610e1..885ab194c048b 100644
--- a/lib/srv/db/proxyserver.go
+++ b/lib/srv/db/proxyserver.go
@@ -21,16 +21,12 @@ package db
import (
"context"
"crypto/tls"
- "crypto/x509"
"errors"
- "fmt"
"io"
"log/slog"
"math/rand/v2"
"net"
- "sort"
"strconv"
- "strings"
"sync"
"time"
@@ -38,27 +34,23 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/gravitational/teleport"
- "github.com/gravitational/teleport/api/client/proto"
- apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
- "github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/authz"
- "github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/srv/db/common"
+ "github.com/gravitational/teleport/lib/srv/db/common/connect"
"github.com/gravitational/teleport/lib/srv/db/common/enterprise"
"github.com/gravitational/teleport/lib/srv/db/dbutils"
"github.com/gravitational/teleport/lib/srv/db/mysql"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/srv/db/sqlserver"
"github.com/gravitational/teleport/lib/srv/ingress"
- "github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
)
@@ -105,41 +97,22 @@ type ProxyServerConfig struct {
MySQLServerVersion string
}
-// ShuffleFunc defines a function that shuffles a list of database servers.
-type ShuffleFunc func([]types.DatabaseServer) []types.DatabaseServer
-
-// ShuffleRandom is a ShuffleFunc that randomizes the order of database servers.
-// Used to provide load balancing behavior when proxying to multiple agents.
-func ShuffleRandom(servers []types.DatabaseServer) []types.DatabaseServer {
- rand.Shuffle(len(servers), func(i, j int) {
- servers[i], servers[j] = servers[j], servers[i]
- })
- return servers
-}
-
-// ShuffleSort is a ShuffleFunc that sorts database servers by name and host ID.
-// Used to provide predictable behavior in tests.
-func ShuffleSort(servers []types.DatabaseServer) []types.DatabaseServer {
- sort.Sort(types.DatabaseServers(servers))
- return servers
-}
-
var (
// mu protects the shuffleFunc global access.
mu sync.RWMutex
// shuffleFunc provides shuffle behavior for multiple database agents.
- shuffleFunc ShuffleFunc = ShuffleRandom
+ shuffleFunc connect.ShuffleFunc = connect.ShuffleRandom
)
// SetShuffleFunc sets the shuffle behavior when proxying to multiple agents.
-func SetShuffleFunc(fn ShuffleFunc) {
+func SetShuffleFunc(fn connect.ShuffleFunc) {
mu.Lock()
defer mu.Unlock()
shuffleFunc = fn
}
// getShuffleFunc returns the configured function used to shuffle agents.
-func getShuffleFunc() ShuffleFunc {
+func getShuffleFunc() connect.ShuffleFunc {
mu.RLock()
defer mu.RUnlock()
return shuffleFunc
@@ -451,59 +424,38 @@ func (s *ProxyServer) Connect(ctx context.Context, proxyCtx *common.ProxyContext
} else {
labels = getLabelsFromDB(nil)
}
-
labels["available_db_servers"] = strconv.Itoa(len(proxyCtx.Servers))
-
defer observeLatency(connectionSetupTime.With(labels))()
- var attemptedServers int
+ var (
+ serviceConn net.Conn
+ stats connect.ConnectStats
+ err error
+ )
+
defer func() {
- dialAttemptedServers.With(labels).Observe(float64(attemptedServers))
+ dialAttemptedServers.With(labels).Observe(float64(stats.GetAttemptedServers()))
+ dialAttempts.With(labels).Add(float64(stats.GetDialAttempts()))
+ dialFailures.With(labels).Add(float64(stats.GetDialFailures()))
}()
- // There may be multiple database servers proxying the same database. If
- // we get a connection problem error trying to dial one of them, likely
- // the database server is down so try the next one.
- for _, server := range getShuffleFunc()(proxyCtx.Servers) {
- attemptedServers++
- s.log.DebugContext(ctx, "Dialing to database service.", "server", server)
- tlsConfig, err := s.getConfigForServer(ctx, proxyCtx.Identity, server)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- dialAttempts.With(labels).Inc()
- serviceConn, err := proxyCtx.Cluster.Dial(reversetunnelclient.DialParams{
- From: clientSrcAddr,
- To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnelclient.LocalNode},
- OriginalClientDstAddr: clientDstAddr,
- ServerID: fmt.Sprintf("%v.%v", server.GetHostID(), proxyCtx.Cluster.GetName()),
- ConnType: types.DatabaseTunnel,
- ProxyIDs: server.GetProxyIDs(),
- })
- if err != nil {
- dialFailures.With(labels).Inc()
- // If an agent is down, we'll retry on the next one (if available).
- if isReverseTunnelDownError(err) {
- s.log.WarnContext(ctx, "Failed to dial database service.", "server", server, "error", err)
- continue
- }
- return nil, trace.Wrap(err)
- }
- // Upgrade the connection so the client identity can be passed to the
- // remote server during TLS handshake. On the remote side, the connection
- // received from the reverse tunnel will be handled by tls.Server.
- serviceConn = tls.Client(serviceConn, tlsConfig)
- return serviceConn, nil
+ serviceConn, stats, err = connect.Connect(ctx, connect.ConnectParams{
+ Logger: s.log,
+ Identity: proxyCtx.Identity,
+ Servers: proxyCtx.Servers,
+ ShuffleFunc: getShuffleFunc(),
+ ClusterName: proxyCtx.Cluster.GetName(),
+ Dialer: proxyCtx.Cluster,
+ CertSigner: s.cfg.AuthClient,
+ AuthPreference: s.cfg.AccessPoint,
+ ClientSrcAddr: clientSrcAddr,
+ ClientDstAddr: clientDstAddr,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
}
- return nil, trace.BadParameter("failed to connect to any of the database servers")
-}
-// isReverseTunnelDownError returns true if the provided error indicates that
-// the reverse tunnel connection is down e.g. because the agent is down.
-func isReverseTunnelDownError(err error) bool {
- return trace.IsConnectionProblem(err) ||
- strings.Contains(err.Error(), reversetunnelclient.NoDatabaseTunnel)
+ return serviceConn, nil
}
// Proxy starts proxying all traffic received from database client between
@@ -567,94 +519,28 @@ func (s *ProxyServer) Authorize(ctx context.Context, tlsConn utils.TLSConn, para
if params.ClientIP != "" {
identity.LoginIP = params.ClientIP
}
- cluster, servers, err := s.getDatabaseServers(ctx, identity)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- return &common.ProxyContext{
- Identity: identity,
- Cluster: cluster,
- Servers: servers,
- AuthContext: authContext,
- }, nil
-}
-
-// getDatabaseServers finds database servers that proxy the database instance
-// encoded in the provided identity.
-func (s *ProxyServer) getDatabaseServers(ctx context.Context, identity tlsca.Identity) (reversetunnelclient.RemoteSite, []types.DatabaseServer, error) {
cluster, err := s.cfg.Tunnel.GetSite(identity.RouteToCluster)
- if err != nil {
- return nil, nil, trace.Wrap(err)
- }
- accessPoint, err := cluster.CachingAccessPoint()
- if err != nil {
- return nil, nil, trace.Wrap(err)
- }
- servers, err := accessPoint.GetDatabaseServers(ctx, apidefaults.Namespace)
- if err != nil {
- return nil, nil, trace.Wrap(err)
- }
- s.log.DebugContext(ctx, "Available database servers.", "cluster", cluster.GetName(), "servers", servers)
- // Find out which database servers proxy the database a user is
- // connecting to using routing information from identity.
- var result []types.DatabaseServer
- for _, server := range servers {
- if server.GetDatabase().GetName() == identity.RouteToDatabase.ServiceName {
- result = append(result, server)
- }
- }
- if len(result) != 0 {
- return cluster, result, nil
- }
- return nil, nil, trace.NotFound("database %q not found among registered databases in cluster %q",
- identity.RouteToDatabase.ServiceName,
- identity.RouteToCluster)
-}
-
-// getConfigForServer returns TLS config used for establishing connection
-// to a remote database server over reverse tunnel.
-func (s *ProxyServer) getConfigForServer(ctx context.Context, identity tlsca.Identity, server types.DatabaseServer) (*tls.Config, error) {
- defer observeLatency(tlsConfigTime.With(getLabelsFromDB(server.GetDatabase())))()
-
- privateKey, err := cryptosuites.GenerateKey(ctx,
- cryptosuites.GetCurrentSuiteFromAuthPreference(s.cfg.AccessPoint),
- cryptosuites.ProxyToDatabaseAgent)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- subject, err := identity.Subject()
if err != nil {
return nil, trace.Wrap(err)
}
- csr, err := tlsca.GenerateCertificateRequestPEM(subject, privateKey)
+ accessPoint, err := cluster.CachingAccessPoint()
if err != nil {
return nil, trace.Wrap(err)
}
-
- response, err := s.cfg.AuthClient.SignDatabaseCSR(ctx, &proto.DatabaseCSRRequest{
- CSR: csr,
- ClusterName: identity.RouteToCluster,
+ servers, err := connect.GetDatabaseServers(ctx, connect.GetDatabaseServersParams{
+ Logger: s.log,
+ ClusterName: cluster.GetName(),
+ DatabaseServersGetter: accessPoint,
+ Identity: identity,
})
if err != nil {
return nil, trace.Wrap(err)
}
-
- cert, err := keys.TLSCertificateForSigner(privateKey, response.Cert)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- pool := x509.NewCertPool()
- for _, caCert := range response.CACerts {
- ok := pool.AppendCertsFromPEM(caCert)
- if !ok {
- return nil, trace.BadParameter("failed to append CA certificate")
- }
- }
-
- return &tls.Config{
- ServerName: server.GetHostname(),
- Certificates: []tls.Certificate{cert},
- RootCAs: pool,
+ return &common.ProxyContext{
+ Identity: identity,
+ Cluster: cluster,
+ Servers: servers,
+ AuthContext: authContext,
}, nil
}