From a741ec957dfe78caaa9a64eeb8be5860935d9e9c Mon Sep 17 00:00:00 2001 From: Stefano Scafiti Date: Thu, 30 May 2024 10:04:51 +0200 Subject: [PATCH] Log request information as transaction metadata Signed-off-by: Stefano Scafiti --- cmd/immudb/command/init.go | 1 + cmd/immudb/command/parse_options.go | 4 +- embedded/store/immustore.go | 3 +- embedded/store/ongoing_tx.go | 2 +- embedded/store/tx_metadata.go | 4 + pkg/api/schema/metadata.go | 101 +++++++++++++++++++++ pkg/api/schema/metadata_test.go | 58 ++++++++++++ pkg/auth/passwords.go | 6 +- pkg/auth/serverinterceptors.go | 6 +- pkg/auth/user.go | 30 +++--- pkg/database/all_ops.go | 1 - pkg/database/database.go | 67 +++++++++++--- pkg/database/sql.go | 10 ++ pkg/integration/client_test.go | 32 +++++++ pkg/pgsql/server/initialize_session.go | 1 + pkg/pgsql/server/options.go | 6 ++ pkg/pgsql/server/query_machine.go | 14 ++- pkg/pgsql/server/server.go | 23 ++--- pkg/pgsql/server/session.go | 68 ++++++++++---- pkg/server/options.go | 7 ++ pkg/server/request_metadata_interceptor.go | 80 ++++++++++++++++ pkg/server/server.go | 6 +- pkg/server/servertest/server.go | 2 + 23 files changed, 466 insertions(+), 66 deletions(-) create mode 100644 pkg/api/schema/metadata.go create mode 100644 pkg/api/schema/metadata_test.go create mode 100644 pkg/server/request_metadata_interceptor.go diff --git a/cmd/immudb/command/init.go b/cmd/immudb/command/init.go index 92db3eaa8d..542e01764c 100644 --- a/cmd/immudb/command/init.go +++ b/cmd/immudb/command/init.go @@ -90,6 +90,7 @@ func (cl *Commandline) setupFlags(cmd *cobra.Command, options *server.Options) { cmd.Flags().MarkHidden("sessions-guard-check-interval") cmd.Flags().Bool("grpc-reflection", options.GRPCReflectionServerEnabled, "GRPC reflection server enabled") cmd.Flags().Bool("swaggerui", options.SwaggerUIEnabled, "Swagger UI enabled") + cmd.Flags().Bool("log-request-metadata", options.LogRequestMetadata, "log request information in transaction metadata") flagNameMapping := map[string]string{ "replication-enabled": "replication-is-replica", diff --git a/cmd/immudb/command/parse_options.go b/cmd/immudb/command/parse_options.go index b4132f23bf..328b901d82 100644 --- a/cmd/immudb/command/parse_options.go +++ b/cmd/immudb/command/parse_options.go @@ -84,6 +84,7 @@ func parseOptions() (options *server.Options, err error) { grpcReflectionServerEnabled := viper.GetBool("grpc-reflection") swaggerUIEnabled := viper.GetBool("swaggerui") + logRequestMetadata := viper.GetBool("log-request-metadata") s3Storage := viper.GetBool("s3-storage") s3RoleEnabled := viper.GetBool("s3-role-enabled") @@ -153,7 +154,8 @@ func parseOptions() (options *server.Options, err error) { WithPProf(pprof). WithLogFormat(logFormat). WithSwaggerUIEnabled(swaggerUIEnabled). - WithGRPCReflectionServerEnabled(grpcReflectionServerEnabled) + WithGRPCReflectionServerEnabled(grpcReflectionServerEnabled). + WithLogRequestMetadata(logRequestMetadata) return options, nil } diff --git a/embedded/store/immustore.go b/embedded/store/immustore.go index 4cc70461f6..e7ad6a04c3 100644 --- a/embedded/store/immustore.go +++ b/embedded/store/immustore.go @@ -1549,7 +1549,8 @@ func (s *ImmuStore) precommit(ctx context.Context, otx *OngoingTx, hdr *TxHeader return nil, fmt.Errorf("%w: transaction does not validate against header", err) } - if len(otx.entries) == 0 && otx.metadata.IsEmpty() { + // extra metadata are specified by the client and thus they are only allowed when entries is non empty + if len(otx.entries) == 0 && (otx.metadata.IsEmpty() || otx.metadata.HasExtraOnly()) { return nil, ErrNoEntriesProvided } diff --git a/embedded/store/ongoing_tx.go b/embedded/store/ongoing_tx.go index 2ec7205708..b4a1b6d041 100644 --- a/embedded/store/ongoing_tx.go +++ b/embedded/store/ongoing_tx.go @@ -178,7 +178,7 @@ func (tx *OngoingTx) IsReadOnly() bool { func (tx *OngoingTx) WithMetadata(md *TxMetadata) *OngoingTx { tx.metadata = md - return nil + return tx } func (tx *OngoingTx) Timestamp() time.Time { diff --git a/embedded/store/tx_metadata.go b/embedded/store/tx_metadata.go index 0f21ec9b4a..e144e9bd63 100644 --- a/embedded/store/tx_metadata.go +++ b/embedded/store/tx_metadata.go @@ -131,6 +131,10 @@ func (md *TxMetadata) IsEmpty() bool { return md == nil || len(md.attributes) == 0 } +func (md *TxMetadata) HasExtraOnly() bool { + return len(md.attributes) == 1 && md.Extra() != nil +} + func (md *TxMetadata) Equal(amd *TxMetadata) bool { if amd == nil || md == nil { return false diff --git a/pkg/api/schema/metadata.go b/pkg/api/schema/metadata.go new file mode 100644 index 0000000000..1757ad5891 --- /dev/null +++ b/pkg/api/schema/metadata.go @@ -0,0 +1,101 @@ +package schema + +import ( + "context" + "errors" +) + +const maxMetadataLen = 256 + +var ( + ErrEmptyMetadataKey = errors.New("metadata key cannot be empty") + ErrEmptyMetadataValue = errors.New("metadata value cannot be empty") + ErrMetadataTooLarge = errors.New("metadata exceeds maximum size") + ErrCorruptedMetadata = errors.New("corrupted metadata") +) + +const ( + UserRequestMetadataKey = "usr" + IpRequestMetadataKey = "ip" +) + +type Metadata map[string]string + +func (m Metadata) Marshal() ([]byte, error) { + if err := m.validate(); err != nil { + return nil, err + } + + var data [maxMetadataLen]byte + + off := 0 + for k, v := range m { + data[off] = byte(len(k) - 1) + data[off+1] = byte(len(v) - 1) + + off += 2 + copy(data[off:], []byte(k)) + off += len(k) + + copy(data[off:], []byte(v)) + off += len(v) + } + return data[:off], nil +} + +func (m Metadata) validate() error { + size := 0 + for k, v := range m { + if len(k) == 0 { + return ErrEmptyMetadataKey + } + + if len(v) == 0 { + return ErrEmptyMetadataValue + } + + size += len(k) + len(v) + 2 + + if size > maxMetadataLen { + return ErrMetadataTooLarge + } + } + return nil +} + +func (m Metadata) Unmarshal(data []byte) error { + off := 0 + for off <= len(data)-2 { + keySize := int(data[off]) + 1 + valueSize := int(data[off+1]) + 1 + + off += 2 + + if off+keySize+valueSize > len(data) { + return ErrCorruptedMetadata + } + + m[string(data[off:off+keySize])] = string(data[off+keySize : off+keySize+valueSize]) + + off += keySize + valueSize + } + + if off != len(data) { + return ErrCorruptedMetadata + } + return nil +} + +type metadataKey struct{} + +func ContextWithMetadata(ctx context.Context, md Metadata) context.Context { + return context.WithValue(ctx, metadataKey{}, md) +} + +func MetadataFromContext(ctx context.Context) Metadata { + md, ok := ctx.Value(metadataKey{}).(Metadata) + if !ok { + return nil + } + return md +} diff --git a/pkg/api/schema/metadata_test.go b/pkg/api/schema/metadata_test.go new file mode 100644 index 0000000000..74ac47bd0a --- /dev/null +++ b/pkg/api/schema/metadata_test.go @@ -0,0 +1,58 @@ +package schema + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMetadataMarshalUnmarshal(t *testing.T) { + meta := Metadata{ + "user": "default", + "ip": "127.0.0.1:8080", + } + + data, err := meta.Marshal() + require.NoError(t, err) + + t.Run("valid metadata", func(t *testing.T) { + unmarshalled := Metadata{} + err := unmarshalled.Unmarshal(data) + require.NoError(t, err) + require.Equal(t, meta, unmarshalled) + }) + + t.Run("corrupted metadata", func(t *testing.T) { + unmarshalled := Metadata{} + err := unmarshalled.Unmarshal(data[:len(data)/2]) + require.ErrorIs(t, err, ErrCorruptedMetadata) + }) + + t.Run("empty metadata", func(t *testing.T) { + m := Metadata{} + data, err := m.Marshal() + require.NoError(t, err) + require.Empty(t, data) + + unmarshalled := Metadata{} + err = unmarshalled.Unmarshal([]byte{}) + require.NoError(t, err) + require.Empty(t, unmarshalled) + }) + + t.Run("invalid metadata", func(t *testing.T) { + x := make([]byte, 256) + + m := Metadata{"x": string(x)} + _, err := m.Marshal() + require.ErrorIs(t, err, ErrMetadataTooLarge) + + m = Metadata{"": "v"} + _, err = m.Marshal() + require.ErrorIs(t, err, ErrEmptyMetadataKey) + + m = Metadata{"k": ""} + _, err = m.Marshal() + require.ErrorIs(t, err, ErrEmptyMetadataValue) + }) +} diff --git a/pkg/auth/passwords.go b/pkg/auth/passwords.go index e64bb29c06..c84de9a48d 100644 --- a/pkg/auth/passwords.go +++ b/pkg/auth/passwords.go @@ -40,8 +40,10 @@ func ComparePasswords(hashedPassword []byte, plainPassword []byte) error { return bcrypt.CompareHashAndPassword(hashedPassword, plainPassword) } -const minPasswordLen = 8 -const maxPasswordLen = 32 +const ( + minPasswordLen = 8 + maxPasswordLen = 32 +) // PasswordRequirementsMsg message used to inform the user about password strength requirements var PasswordRequirementsMsg = fmt.Sprintf( diff --git a/pkg/auth/serverinterceptors.go b/pkg/auth/serverinterceptors.go index 396c0014db..6fc6ecfff7 100644 --- a/pkg/auth/serverinterceptors.go +++ b/pkg/auth/serverinterceptors.go @@ -73,9 +73,9 @@ func ServerUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.Una } var localAddress = map[string]struct{}{ - "127.0.0.1": struct{}{}, - "localhost": struct{}{}, - "bufconn": struct{}{}, + "127.0.0.1": {}, + "localhost": {}, + "bufconn": {}, } func isLocalClient(ctx context.Context) bool { diff --git a/pkg/auth/user.go b/pkg/auth/user.go index 5d5545fee5..57e80a0b69 100644 --- a/pkg/auth/user.go +++ b/pkg/auth/user.go @@ -39,11 +39,13 @@ type User struct { CreatedAt time.Time `json:"createdat"` //time in which this user is created/updated } -// SysAdminUsername the system admin username -var SysAdminUsername = "immudb" +var ( + // SysAdminUsername the system admin username + SysAdminUsername = "immudb" -// SysAdminPassword the admin password (can be default or from command flags, config or env var) -var SysAdminPassword = SysAdminUsername + // SysAdminPassword the admin password (can be default or from command flags, config or env var) + SysAdminPassword = SysAdminUsername +) // SetPassword Hashes and salts the password and assigns it to hashedPassword of User func (u *User) SetPassword(plainPassword []byte) ([]byte, error) { @@ -63,10 +65,16 @@ func (u *User) ComparePasswords(plainPassword []byte) error { return ComparePasswords(u.HashedPassword, plainPassword) } -// IsValidUsername is a regexp function used to check username requirements -var IsValidUsername = regexp.MustCompile(`^[a-zA-Z0-9_]+$`).MatchString +const maxUsernameLen = 63 + +var usernameRegex = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) + +// IsValidUsername is a function used to check username requirements +func IsValidUsername(user string) bool { + return len(user) <= maxUsernameLen && usernameRegex.MatchString(user) +} -//HasPermission checks if user has such permission for this database +// HasPermission checks if user has such permission for this database func (u *User) HasPermission(database string, permission uint32) bool { for _, val := range u.Permissions { if (val.Database == database) && @@ -77,7 +85,7 @@ func (u *User) HasPermission(database string, permission uint32) bool { return false } -//HasAtLeastOnePermission checks if user has this permission for at least one database +// HasAtLeastOnePermission checks if user has this permission for at least one database func (u *User) HasAtLeastOnePermission(permission uint32) bool { for _, val := range u.Permissions { if val.Permission == permission { @@ -87,7 +95,7 @@ func (u *User) HasAtLeastOnePermission(permission uint32) bool { return false } -//WhichPermission returns the permission that this user has on this database +// WhichPermission returns the permission that this user has on this database func (u *User) WhichPermission(database string) uint32 { if u.IsSysAdmin { return PermissionSysAdmin @@ -100,7 +108,7 @@ func (u *User) WhichPermission(database string) uint32 { return PermissionNone } -//RevokePermission revoke database permission from user +// RevokePermission revoke database permission from user func (u *User) RevokePermission(database string) bool { for i, val := range u.Permissions { if val.Database == database { @@ -112,7 +120,7 @@ func (u *User) RevokePermission(database string) bool { return false } -//GrantPermission add permission to database +// GrantPermission add permission to database func (u *User) GrantPermission(database string, permission uint32) bool { //first remove any previous permission for this db u.RevokePermission(database) diff --git a/pkg/database/all_ops.go b/pkg/database/all_ops.go index 584f1ce95c..eb5278ec1e 100644 --- a/pkg/database/all_ops.go +++ b/pkg/database/all_ops.go @@ -62,7 +62,6 @@ func (d *db) ExecAll(ctx context.Context, req *schema.ExecAllRequest) (*schema.T kmap := make(map[[sha256.Size]byte]bool) for i, op := range req.Operations { - e := &store.EntrySpec{} switch x := op.Operation.(type) { diff --git a/pkg/database/database.go b/pkg/database/database.go index 0f266455a7..3c6fcdc0ee 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -36,18 +36,22 @@ import ( "github.com/codenotary/immudb/pkg/api/schema" ) -const MaxKeyResolutionLimit = 1 -const MaxKeyScanLimit = 2500 - -var ErrKeyResolutionLimitReached = errors.New("key resolution limit reached. It may be due to cyclic references") -var ErrResultSizeLimitExceeded = errors.New("result size limit exceeded") -var ErrResultSizeLimitReached = errors.New("result size limit reached") -var ErrIllegalArguments = store.ErrIllegalArguments -var ErrIllegalState = store.ErrIllegalState -var ErrIsReplica = errors.New("database is read-only because it's a replica") -var ErrNotReplica = errors.New("database is NOT a replica") -var ErrReplicaDivergedFromPrimary = errors.New("replica diverged from primary") -var ErrInvalidRevision = errors.New("invalid key revision number") +const ( + MaxKeyResolutionLimit = 1 + MaxKeyScanLimit = 2500 +) + +var ( + ErrKeyResolutionLimitReached = errors.New("key resolution limit reached. It may be due to cyclic references") + ErrResultSizeLimitExceeded = errors.New("result size limit exceeded") + ErrResultSizeLimitReached = errors.New("result size limit reached") + ErrIllegalArguments = store.ErrIllegalArguments + ErrIllegalState = store.ErrIllegalState + ErrIsReplica = errors.New("database is read-only because it's a replica") + ErrNotReplica = errors.New("database is NOT a replica") + ErrReplicaDivergedFromPrimary = errors.New("replica diverged from primary") + ErrInvalidRevision = errors.New("invalid key revision number") +) type DB interface { GetName() string @@ -388,7 +392,7 @@ func (d *db) set(ctx context.Context, req *schema.SetRequest) (*schema.TxHeader, return nil, ErrIllegalArguments } - tx, err := d.st.NewWriteOnlyTx(ctx) + tx, err := d.newWriteOnlyTx(ctx) if err != nil { return nil, err } @@ -421,7 +425,6 @@ func (d *db) set(ctx context.Context, req *schema.SetRequest) (*schema.TxHeader, } for i := range req.Preconditions { - c, err := PreconditionFromProto(req.Preconditions[i]) if err != nil { return nil, err @@ -447,6 +450,40 @@ func (d *db) set(ctx context.Context, req *schema.SetRequest) (*schema.TxHeader, return schema.TxHeaderToProto(hdr), nil } +func (d *db) newWriteOnlyTx(ctx context.Context) (*store.OngoingTx, error) { + tx, err := d.st.NewWriteOnlyTx(ctx) + if err != nil { + return nil, err + } + return d.txWithMetadata(ctx, tx) +} + +func (d *db) newTx(ctx context.Context, opts *store.TxOptions) (*store.OngoingTx, error) { + tx, err := d.st.NewTx(ctx, opts) + if err != nil { + return nil, err + } + return d.txWithMetadata(ctx, tx) +} + +func (d *db) txWithMetadata(ctx context.Context, tx *store.OngoingTx) (*store.OngoingTx, error) { + meta := schema.MetadataFromContext(ctx) + if len(meta) > 0 { + txmd := store.NewTxMetadata() + + data, err := meta.Marshal() + if err != nil { + return nil, err + } + + if err := txmd.WithExtra(data); err != nil { + return nil, err + } + return tx.WithMetadata(txmd), nil + } + return tx, nil +} + func checkKeyRequest(req *schema.KeyRequest) error { if req == nil { return fmt.Errorf( @@ -853,7 +890,7 @@ func (d *db) Delete(ctx context.Context, req *schema.DeleteKeysRequest) (*schema }) } - tx, err := d.st.NewTx(ctx, opts) + tx, err := d.newTx(ctx, opts) if err != nil { return nil, err } diff --git a/pkg/database/sql.go b/pkg/database/sql.go index 77f8cf73d7..ab1f973197 100644 --- a/pkg/database/sql.go +++ b/pkg/database/sql.go @@ -304,6 +304,16 @@ func (d *db) NewSQLTx(ctx context.Context, opts *sql.TxOptions) (tx *sql.SQLTx, }() go func() { + md := schema.MetadataFromContext(ctx) + if len(md) > 0 { + data, err := md.Marshal() + if err != nil { + errChan <- err + return + } + opts = opts.WithExtra(data) + } + tx, err = d.sqlEngine.NewTx(txCtx, opts) if err != nil { errChan <- err diff --git a/pkg/integration/client_test.go b/pkg/integration/client_test.go index cd5e5dd5ae..ab01efbb5a 100644 --- a/pkg/integration/client_test.go +++ b/pkg/integration/client_test.go @@ -69,6 +69,7 @@ func setupTestServerAndClient(t *testing.T) (*servertest.BufconnServer, ic.ImmuC WithWebServer(true). WithDir(filepath.Join(t.TempDir(), "data")). WithAuth(true). + WithLogRequestMetadata(true). WithSigningKey("./../../test/signer/ec1.key"), ) @@ -1650,3 +1651,34 @@ func (ts TokenServiceMock) WithHds(hds homedir.HomedirService) tokenservice.Toke func (ts TokenServiceMock) WithTokenFileName(tfn string) tokenservice.TokenService { return ts } + +func TestServerLogRequestMetadata(t *testing.T) { + _, client, ctx := setupTestServerAndClient(t) + + requireMetadataPresent := func(hdr *schema.TxHeader) { + txmd := schema.Metadata{} + err := txmd.Unmarshal(hdr.Metadata.Extra) + require.NoError(t, err) + + require.Equal(t, schema.Metadata{schema.UserRequestMetadataKey: auth.SysAdminUsername, schema.IpRequestMetadataKey: "bufconn"}, txmd) + } + + hdr, err := client.Set(ctx, []byte("test"), []byte("test")) + require.NoError(t, err) + + requireMetadataPresent(hdr) + + hdr1, err := client.VerifiedSet(ctx, []byte("test"), []byte("test")) + require.NoError(t, err) + + requireMetadataPresent(hdr1) + require.NoError(t, err) + + _, err = client.SQLExec(ctx, "CREATE TABLE mytable (id INTEGER, PRIMARY KEY id)", nil) + require.NoError(t, err) + + tx, err := client.TxByID(ctx, 3) + require.NoError(t, err) + + requireMetadataPresent(tx.Header) +} diff --git a/pkg/pgsql/server/initialize_session.go b/pkg/pgsql/server/initialize_session.go index 8b0dff5346..8f32763214 100644 --- a/pkg/pgsql/server/initialize_session.go +++ b/pkg/pgsql/server/initialize_session.go @@ -154,6 +154,7 @@ func (s *session) HandleStartup(ctx context.Context) (err error) { if !ok || user == "" { return pserr.ErrUsernameNotprovided } + s.user = user db, ok := s.connParams["database"] if !ok { diff --git a/pkg/pgsql/server/options.go b/pkg/pgsql/server/options.go index 76aa0f93ea..f01a32af87 100644 --- a/pkg/pgsql/server/options.go +++ b/pkg/pgsql/server/options.go @@ -55,6 +55,12 @@ func TLSConfig(tlsConfig *tls.Config) Option { } } +func LogRequestMetadata(enabled bool) Option { + return func(args *pgsrv) { + args.logRequestMetadata = enabled + } +} + func DatabaseList(dbList database.DatabaseList) Option { return func(args *pgsrv) { args.dbList = dbList diff --git a/pkg/pgsql/server/query_machine.go b/pkg/pgsql/server/query_machine.go index 9272247b54..c716124e7c 100644 --- a/pkg/pgsql/server/query_machine.go +++ b/pkg/pgsql/server/query_machine.go @@ -303,7 +303,12 @@ func (s *session) fetchAndWriteResults(statements string, parameters []*schema.N } func (s *session) query(st *sql.SelectStmt, parameters []*schema.NamedParam, resultColumnFormatCodes []int16, skipRowDesc bool) error { - reader, err := s.db.SQLQueryPrepared(s.ctx, s.tx, st, schema.NamedParamsFromProto(parameters)) + tx, err := s.sqlTx() + if err != nil { + return err + } + + reader, err := s.db.SQLQueryPrepared(s.ctx, tx, st, schema.NamedParamsFromProto(parameters)) if err != nil { return err } @@ -332,7 +337,12 @@ func (s *session) exec(st sql.SQLStmt, namedParams []*schema.NamedParam, resultC params[p.Name] = schema.RawValue(p.Value) } - ntx, _, err := s.db.SQLExecPrepared(s.ctx, s.tx, []sql.SQLStmt{st}, params) + tx, err := s.sqlTx() + if err != nil { + return err + } + + ntx, _, err := s.db.SQLExecPrepared(s.ctx, tx, []sql.SQLStmt{st}, params) s.tx = ntx return err diff --git a/pkg/pgsql/server/server.go b/pkg/pgsql/server/server.go index 3b6f7c216d..2087ca20b6 100644 --- a/pkg/pgsql/server/server.go +++ b/pkg/pgsql/server/server.go @@ -31,16 +31,17 @@ import ( ) type pgsrv struct { - m sync.RWMutex - running bool - maxConnections int - tlsConfig *tls.Config - logger logger.Logger - host string - port int - immudbPort int - dbList database.DatabaseList - listener net.Listener + m sync.RWMutex + running bool + maxConnections int + tlsConfig *tls.Config + logger logger.Logger + logRequestMetadata bool + host string + port int + immudbPort int + dbList database.DatabaseList + listener net.Listener } type PGSQLServer interface { @@ -104,7 +105,7 @@ func (s *pgsrv) Serve() (err error) { } func (s *pgsrv) newSession(conn net.Conn) Session { - return newSession(conn, s.host, s.immudbPort, s.logger, s.tlsConfig, s.dbList) + return newSession(conn, s.host, s.immudbPort, s.logger, s.tlsConfig, s.logRequestMetadata, s.dbList) } func (s *pgsrv) Stop() (err error) { diff --git a/pkg/pgsql/server/session.go b/pkg/pgsql/server/session.go index d35d1fd23a..a76278064f 100644 --- a/pkg/pgsql/server/session.go +++ b/pkg/pgsql/server/session.go @@ -19,11 +19,13 @@ package server import ( "context" "crypto/tls" + "strings" "net" "github.com/codenotary/immudb/embedded/logger" "github.com/codenotary/immudb/embedded/sql" + "github.com/codenotary/immudb/pkg/api/schema" "github.com/codenotary/immudb/pkg/client" "github.com/codenotary/immudb/pkg/database" "github.com/codenotary/immudb/pkg/pgsql/errors" @@ -32,18 +34,21 @@ import ( ) type session struct { - immudbHost string - immudbPort int - tlsConfig *tls.Config - log logger.Logger + immudbHost string + immudbPort int + tlsConfig *tls.Config + log logger.Logger + logRequestMetadata bool dbList database.DatabaseList client client.ImmuClient - ctx context.Context - db database.DB - tx *sql.SQLTx + ctx context.Context + user string + ipAddr string + db database.DB + tx *sql.SQLTx mr MessageReader @@ -62,18 +67,32 @@ type Session interface { Close() error } -func newSession(c net.Conn, immudbHost string, immudbPort int, - log logger.Logger, tlsConfig *tls.Config, dbList database.DatabaseList) *session { +func newSession( + c net.Conn, + immudbHost string, + immudbPort int, + log logger.Logger, + tlsConfig *tls.Config, + logRequestMetadata bool, + dbList database.DatabaseList, +) *session { + addr := c.RemoteAddr().String() + i := strings.Index(addr, ":") + if i >= 0 { + addr = addr[:i] + } return &session{ - immudbHost: immudbHost, - immudbPort: immudbPort, - tlsConfig: tlsConfig, - log: log, - dbList: dbList, - mr: NewMessageReader(c), - statements: make(map[string]*statement), - portals: make(map[string]*portal), + immudbHost: immudbHost, + immudbPort: immudbPort, + tlsConfig: tlsConfig, + log: log, + logRequestMetadata: logRequestMetadata, + dbList: dbList, + ipAddr: addr, + mr: NewMessageReader(c), + statements: make(map[string]*statement), + portals: make(map[string]*portal), } } @@ -140,3 +159,18 @@ func (s *session) writeMessage(msg []byte) (int, error) { return s.mr.Write(msg) } + +func (s *session) sqlTx() (*sql.SQLTx, error) { + if s.tx != nil || !s.logRequestMetadata { + return s.tx, nil + } + + md := schema.Metadata{ + schema.UserRequestMetadataKey: s.user, + schema.IpRequestMetadataKey: s.ipAddr, + } + + // create transaction explicitly to inject request metadata + ctx := schema.ContextWithMetadata(s.ctx, md) + return s.db.NewSQLTx(ctx, sql.DefaultTxOptions()) +} diff --git a/pkg/server/options.go b/pkg/server/options.go index 91018e7115..f2fa5fc9e2 100644 --- a/pkg/server/options.go +++ b/pkg/server/options.go @@ -76,6 +76,7 @@ type Options struct { LogFormat string GRPCReflectionServerEnabled bool SwaggerUIEnabled bool + LogRequestMetadata bool } type RemoteStorageOptions struct { @@ -145,6 +146,7 @@ func DefaultOptions() *Options { PProf: false, GRPCReflectionServerEnabled: true, SwaggerUIEnabled: true, + LogRequestMetadata: false, } } @@ -496,6 +498,11 @@ func (o *Options) WithSwaggerUIEnabled(enabled bool) *Options { return o } +func (o *Options) WithLogRequestMetadata(enabled bool) *Options { + o.LogRequestMetadata = enabled + return o +} + // RemoteStorageOptions func (opts *RemoteStorageOptions) WithS3Storage(S3Storage bool) *RemoteStorageOptions { diff --git a/pkg/server/request_metadata_interceptor.go b/pkg/server/request_metadata_interceptor.go new file mode 100644 index 0000000000..e9daa6e409 --- /dev/null +++ b/pkg/server/request_metadata_interceptor.go @@ -0,0 +1,80 @@ +package server + +import ( + "context" + "strings" + + "github.com/codenotary/immudb/pkg/api/schema" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +func (s *ImmuServer) InjectRequestMetadataUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if !s.Options.LogRequestMetadata { + return handler(ctx, req) + } + return handler(s.withRequestMetadata(ctx), req) +} + +func (s *ImmuServer) InjectRequestMetadataStreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + ctx := ss.Context() + if !s.Options.LogRequestMetadata { + return handler(srv, ss) + } + return handler(srv, &serverStreamWithContext{ServerStream: ss, ctx: s.withRequestMetadata(ctx)}) +} + +type serverStreamWithContext struct { + grpc.ServerStream + ctx context.Context +} + +func (s *serverStreamWithContext) Context() context.Context { + return s.ctx +} + +func (s *ImmuServer) withRequestMetadata(ctx context.Context) context.Context { + if !s.Options.LogRequestMetadata { + return ctx + } + + _, user, err := s.getLoggedInUserdataFromCtx(ctx) + if err != nil { + return ctx + } + + md := schema.Metadata{ + schema.UserRequestMetadataKey: user.Username, + } + + ip := ipAddrFromContext(ctx) + if len(ip) > 0 { + md[schema.IpRequestMetadataKey] = ip + } + return schema.ContextWithMetadata(ctx, md) +} + +func ipAddrFromContext(ctx context.Context) string { + md, ok := metadata.FromIncomingContext(ctx) + if ok { + // check for the headers forwarded by GRPC-gateway + if xffValues, ok := md["x-forwarded-for"]; ok && len(xffValues) > 0 { + return xffValues[0] + } else if xriValues, ok := md["x-real-ip"]; ok && len(xriValues) > 0 { + return xriValues[0] + } + } + + p, ok := peer.FromContext(ctx) + if !ok { + return "" + } + + addr := p.Addr.Network() + i := strings.Index(addr, ":") + if i < 0 { + return addr + } + return addr[:i] +} diff --git a/pkg/server/server.go b/pkg/server/server.go index c319739d5c..c91483b7d4 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -225,6 +225,7 @@ func (s *ImmuServer) Initialize() error { grpc_prometheus.UnaryServerInterceptor, auth.ServerUnaryInterceptor, s.SessionAuthInterceptor, + s.InjectRequestMetadataUnaryInterceptor, } sss := []grpc.StreamServerInterceptor{ ErrorMapperStream, // converts errors in gRPC ones. Need to be the first @@ -232,7 +233,9 @@ func (s *ImmuServer) Initialize() error { uuidContext.UUIDStreamContextSetter, grpc_prometheus.StreamServerInterceptor, auth.ServerStreamInterceptor, + s.InjectRequestMetadataStreamInterceptor, } + grpcSrvOpts = append( grpcSrvOpts, grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(uis...)), @@ -258,6 +261,7 @@ func (s *ImmuServer) Initialize() error { pgsqlsrv.TLSConfig(s.Options.TLSConfig), pgsqlsrv.Logger(s.Logger), pgsqlsrv.DatabaseList(s.dbList), + pgsqlsrv.LogRequestMetadata(s.Options.LogRequestMetadata), ) if err = s.PgsqlSrv.Initialize(); err != nil { @@ -313,7 +317,7 @@ func (s *ImmuServer) Start() (err error) { if s.Options.WebServer { if err := s.setUpWebServer(context.Background()); err != nil { - log.Fatal(fmt.Sprintf("failed to setup web API/console server: %v", err)) + log.Fatalf("failed to setup web API/console server: %v", err) } defer func() { if err := s.webServer.Close(); err != nil { diff --git a/pkg/server/servertest/server.go b/pkg/server/servertest/server.go index 34fa0ebb7f..3a8f4ed11b 100644 --- a/pkg/server/servertest/server.go +++ b/pkg/server/servertest/server.go @@ -89,12 +89,14 @@ func (bs *BufconnServer) setupGrpcServer() { uuidContext.UUIDContextSetter, auth.ServerUnaryInterceptor, bs.immuServer.SessionAuthInterceptor, + bs.immuServer.InjectRequestMetadataUnaryInterceptor, )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( server.ErrorMapperStream, bs.immuServer.KeepALiveSessionStreamInterceptor, uuidContext.UUIDStreamContextSetter, auth.ServerStreamInterceptor, + bs.immuServer.InjectRequestMetadataStreamInterceptor, )), ) }