Skip to content

Commit

Permalink
chore(embedded/sql): Add support for core pg_catalog tables (pg_class…
Browse files Browse the repository at this point in the history
…, pg_namespace, pg_roles)

Signed-off-by: Stefano Scafiti <[email protected]>
  • Loading branch information
ostafen committed Dec 10, 2024
1 parent b7ff0e6 commit 8d9971c
Show file tree
Hide file tree
Showing 23 changed files with 1,136 additions and 331 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
*.out
coverage.txt

# Output of goyacc
embedded/sql/y.output

# Editor
.vscode
.idea
Expand Down Expand Up @@ -50,4 +53,4 @@ token_admin

swagger/dist
swagger/swaggerembedded
webconsole/webconsoleembedded
webconsole/webconsoleembedded
24 changes: 24 additions & 0 deletions embedded/sql/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ type Engine struct {
lazyIndexConstraintValidation bool
parseTxMetadata func([]byte) (map[string]interface{}, error)
multidbHandler MultiDBHandler
tableResolves map[string]TableResolver
}

type MultiDBHandler interface {
Expand All @@ -134,6 +135,11 @@ type MultiDBHandler interface {
ExecPreparedStmts(ctx context.Context, opts *TxOptions, stmts []SQLStmt, params map[string]interface{}) (ntx *SQLTx, committedTxs []*SQLTx, err error)
}

type TableResolver interface {
Table() string
Resolve(ctx context.Context, tx *SQLTx, alias string) (RowReader, error)
}

type User interface {
Username() string
Permission() Permission
Expand Down Expand Up @@ -176,6 +182,10 @@ func NewEngine(st *store.ImmuStore, opts *Options) (*Engine, error) {
return nil, err
}

for _, r := range opts.tableResolvers {
e.registerTableResolver(r.Table(), r)
}

// TODO: find a better way to handle parsing errors
yyErrorVerbose = true

Expand Down Expand Up @@ -728,3 +738,17 @@ func (e *Engine) GetStore() *store.ImmuStore {
func (e *Engine) GetPrefix() []byte {
return e.prefix
}

func (e *Engine) TableResolveFor(tableName string) TableResolver {
if e.tableResolves == nil {
return nil
}
return e.tableResolves[tableName]
}

func (e *Engine) registerTableResolver(tableName string, r TableResolver) {
if e.tableResolves == nil {
e.tableResolves = make(map[string]TableResolver)
}
e.tableResolves[tableName] = r
}
4 changes: 4 additions & 0 deletions embedded/sql/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4616,6 +4616,10 @@ func TestOrderBy(t *testing.T) {
directions: []int{1, -1},
positionalRefs: []int{4, 5},
},
{
exps: []string{"weight/(height*height)"},
directions: []int{1},
},
}

runTest := func(t *testing.T, test *test, expectedTempFiles int) []*Row {
Expand Down
177 changes: 134 additions & 43 deletions embedded/sql/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,45 @@ import (
)

const (
LengthFnCall string = "LENGTH"
SubstringFnCall string = "SUBSTRING"
ConcatFnCall string = "CONCAT"
LowerFnCall string = "LOWER"
UpperFnCall string = "UPPER"
TrimFnCall string = "TRIM"
NowFnCall string = "NOW"
UUIDFnCall string = "RANDOM_UUID"
DatabasesFnCall string = "DATABASES"
TablesFnCall string = "TABLES"
TableFnCall string = "TABLE"
UsersFnCall string = "USERS"
ColumnsFnCall string = "COLUMNS"
IndexesFnCall string = "INDEXES"
GrantsFnCall string = "GRANTS"
JSONTypeOfFnCall string = "JSON_TYPEOF"
LengthFnCall string = "LENGTH"
SubstringFnCall string = "SUBSTRING"
ConcatFnCall string = "CONCAT"
LowerFnCall string = "LOWER"
UpperFnCall string = "UPPER"
TrimFnCall string = "TRIM"
NowFnCall string = "NOW"
UUIDFnCall string = "RANDOM_UUID"
DatabasesFnCall string = "DATABASES"
TablesFnCall string = "TABLES"
TableFnCall string = "TABLE"
UsersFnCall string = "USERS"
ColumnsFnCall string = "COLUMNS"
IndexesFnCall string = "INDEXES"
GrantsFnCall string = "GRANTS"
JSONTypeOfFnCall string = "JSON_TYPEOF"
PGGetUserByIDFnCall string = "PG_GET_USERBYID"
PgTableIsVisible string = "PG_TABLE_IS_VISIBLE"
PgShobjDescription string = "SHOBJ_DESCRIPTION"
)

var builtinFunctions = map[string]Function{
LengthFnCall: &LengthFn{},
SubstringFnCall: &SubstringFn{},
ConcatFnCall: &ConcatFn{},
LowerFnCall: &LowerUpperFnc{},
UpperFnCall: &LowerUpperFnc{isUpper: true},
TrimFnCall: &TrimFnc{},
NowFnCall: &NowFn{},
UUIDFnCall: &UUIDFn{},
JSONTypeOfFnCall: &JsonTypeOfFn{},
LengthFnCall: &LengthFn{},
SubstringFnCall: &SubstringFn{},
ConcatFnCall: &ConcatFn{},
LowerFnCall: &LowerUpperFnc{},
UpperFnCall: &LowerUpperFnc{isUpper: true},
TrimFnCall: &TrimFnc{},
NowFnCall: &NowFn{},
UUIDFnCall: &UUIDFn{},
JSONTypeOfFnCall: &JsonTypeOfFn{},
PGGetUserByIDFnCall: &pgGetUserByIDFunc{},
PgTableIsVisible: &pgTableIsVisible{},
PgShobjDescription: &pgShobjDescription{},
}

type Function interface {
requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error
inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error)
RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error
InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error)
Apply(tx *SQLTx, params []TypedValue) (TypedValue, error)
}

Expand All @@ -67,11 +73,11 @@ type Function interface {

type LengthFn struct{}

func (f *LengthFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
func (f *LengthFn) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return IntegerType, nil
}

func (f *LengthFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
func (f *LengthFn) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != IntegerType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
}
Expand All @@ -98,11 +104,11 @@ func (f *LengthFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) {

type ConcatFn struct{}

func (f *ConcatFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
func (f *ConcatFn) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return VarcharType, nil
}

func (f *ConcatFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
func (f *ConcatFn) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != VarcharType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
}
Expand Down Expand Up @@ -131,11 +137,11 @@ func (f *ConcatFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) {
type SubstringFn struct {
}

func (f *SubstringFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
func (f *SubstringFn) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return VarcharType, nil
}

func (f *SubstringFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
func (f *SubstringFn) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != VarcharType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
}
Expand Down Expand Up @@ -180,11 +186,11 @@ type LowerUpperFnc struct {
isUpper bool
}

func (f *LowerUpperFnc) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
func (f *LowerUpperFnc) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return VarcharType, nil
}

func (f *LowerUpperFnc) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
func (f *LowerUpperFnc) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != VarcharType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
}
Expand Down Expand Up @@ -226,11 +232,11 @@ func (f *LowerUpperFnc) name() string {
type TrimFnc struct {
}

func (f *TrimFnc) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
func (f *TrimFnc) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return VarcharType, nil
}

func (f *TrimFnc) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
func (f *TrimFnc) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != VarcharType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
}
Expand Down Expand Up @@ -261,11 +267,11 @@ func (f *TrimFnc) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) {

type NowFn struct{}

func (f *NowFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
func (f *NowFn) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return TimestampType, nil
}

func (f *NowFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
func (f *NowFn) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != TimestampType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t)
}
Expand All @@ -285,11 +291,11 @@ func (f *NowFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) {

type JsonTypeOfFn struct{}

func (f *JsonTypeOfFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
func (f *JsonTypeOfFn) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return VarcharType, nil
}

func (f *JsonTypeOfFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
func (f *JsonTypeOfFn) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != VarcharType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
}
Expand Down Expand Up @@ -319,11 +325,11 @@ func (f *JsonTypeOfFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error)

type UUIDFn struct{}

func (f *UUIDFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
func (f *UUIDFn) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return UUIDType, nil
}

func (f *UUIDFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
func (f *UUIDFn) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != UUIDType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t)
}
Expand All @@ -336,3 +342,88 @@ func (f *UUIDFn) Apply(_ *SQLTx, params []TypedValue) (TypedValue, error) {
}
return &UUID{val: uuid.New()}, nil
}

// pg functions

type pgGetUserByIDFunc struct{}

func (f *pgGetUserByIDFunc) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != VarcharType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t)
}
return nil
}

func (f *pgGetUserByIDFunc) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return VarcharType, nil
}

func (f *pgGetUserByIDFunc) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) {
if len(params) != 1 {
return nil, fmt.Errorf("%w: '%s' function expects %d arguments but %d were provided", ErrIllegalArguments, JSONTypeOfFnCall, 1, len(params))
}

if params[0].RawValue() != int64(0) {
return nil, fmt.Errorf("user not found")
}

users, err := tx.ListUsers(tx.tx.Context())
if err != nil {
return nil, err
}

idx := findSysAdmin(users)
if idx < 0 {
return nil, fmt.Errorf("admin not found")
}
return NewVarchar(users[idx].Username()), nil
}

func findSysAdmin(users []User) int {
for i, u := range users {
if u.Permission() == PermissionSysAdmin {
return i
}
}
return -1
}

type pgTableIsVisible struct{}

func (f *pgTableIsVisible) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != BooleanType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, BooleanType, t)
}
return nil
}

func (f *pgTableIsVisible) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return BooleanType, nil
}

func (f *pgTableIsVisible) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) {
if len(params) != 1 {
return nil, fmt.Errorf("%w: '%s' function expects %d arguments but %d were provided", ErrIllegalArguments, JSONTypeOfFnCall, 1, len(params))
}
return NewBool(true), nil
}

type pgShobjDescription struct{}

func (f *pgShobjDescription) RequiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error {
if t != VarcharType {
return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t)
}
return nil
}

func (f *pgShobjDescription) InferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) {
return VarcharType, nil
}

func (f *pgShobjDescription) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) {
if len(params) != 2 {
return nil, fmt.Errorf("%w: '%s' function expects %d arguments but %d were provided", ErrIllegalArguments, PgShobjDescription, 2, len(params))
}
return NewVarchar(""), nil
}
Loading

0 comments on commit 8d9971c

Please sign in to comment.