diff --git a/lib/srv/db/common/role/role.go b/lib/srv/db/common/role/role.go index 149fd620d0953..013cb641381ab 100644 --- a/lib/srv/db/common/role/role.go +++ b/lib/srv/db/common/role/role.go @@ -103,6 +103,10 @@ func databaseNameMatcher(dbProtocol, database string) *services.DatabaseNameMatc defaults.ProtocolOpenSearch, // DynamoDB integration doesn't support schema access control. defaults.ProtocolDynamoDB, + // Snowflake integration doesn't support schema access control. + defaults.ProtocolSnowflake, + // Oracle integration doesn't support schema access control. + defaults.ProtocolOracle, // Clickhouse Database Access doesn't support schema access control defaults.ProtocolClickHouse, defaults.ProtocolClickHouseHTTP: diff --git a/lib/srv/db/snowflake_test.go b/lib/srv/db/snowflake_test.go index 7984fa6cd95b6..319ada2c90e53 100644 --- a/lib/srv/db/snowflake_test.go +++ b/lib/srv/db/snowflake_test.go @@ -112,14 +112,13 @@ func TestAccessSnowflake(t *testing.T) { err: "HTTP: 401", }, { - desc: "no access to databases", + desc: "database name access is not enforced", user: "alice", role: "admin", allowDbNames: []string{}, allowDbUsers: []string{types.Wildcard}, dbName: "snowflake", dbUser: "snowflake", - err: "HTTP: 401", }, { desc: "no access to users", diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index 9332c5505867d..617a99f80a570 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -860,8 +860,8 @@ func (d *databaseInfo) checkAndSetDefaults(cf *CLIConf, tc *client.TeleportClien // ensure the route protocol matches the db. d.Protocol = db.GetProtocol() - needDBUser := d.Username == "" && role.RequireDatabaseUserMatcher(d.Protocol) - needDBName := d.Database == "" && role.RequireDatabaseNameMatcher(d.Protocol) + needDBUser := d.Username == "" && isDatabaseUserRequired(d.Protocol) + needDBName := d.Database == "" && isDatabaseNameRequired(d.Protocol) if !needDBUser && !needDBName { return nil } @@ -1148,6 +1148,26 @@ func getDefaultDBUser(db types.Database, checker services.AccessChecker) (string return "", trace.BadParameter(errMsg) } +// isDatabaseUserRequired returns whether the --db-user flag is required for +// the db protocol. +func isDatabaseUserRequired(protocol string) bool { + return role.RequireDatabaseUserMatcher(protocol) +} + +// isDatabaseNameRequired returns whether the --db-name flag is required for +// the db protocol. +func isDatabaseNameRequired(protocol string) bool { + if role.RequireDatabaseNameMatcher(protocol) { + return true + } + switch protocol { + case defaults.ProtocolOracle: + // Always require database name for the Oracle protocol. + return true + } + return false +} + // getDefaultDBName enumerates the allowed database names for a given database // and selects one if it is the only non-wildcard database name allowed. // Returns an error if there are no allowed database names or more than one. @@ -1441,8 +1461,8 @@ func formatDatabaseConnectCommand(clusterFlag string, active tlsca.RouteToDataba // formatDatabaseConnectArgs generates the arguments for "tsh db connect" command. func formatDatabaseConnectArgs(clusterFlag string, active tlsca.RouteToDatabase) (flags []string) { // figure out if we need --db-user and --db-name - needUser := role.RequireDatabaseUserMatcher(active.Protocol) - needDatabase := role.RequireDatabaseNameMatcher(active.Protocol) + needUser := isDatabaseUserRequired(active.Protocol) + needDatabase := isDatabaseNameRequired(active.Protocol) if clusterFlag != "" { flags = append(flags, fmt.Sprintf("--cluster=%s", clusterFlag))