Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Parameterized type #52

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
package extensions

import (
"errors"
"fmt"
"reflect"
"strings"

substraitgo "github.com/substrait-io/substrait-go"
"github.com/substrait-io/substrait-go/types"
"github.com/substrait-io/substrait-go/types/parser"
)

Expand Down Expand Up @@ -57,6 +59,7 @@

type Argument interface {
toTypeString() string
ArgType() (types.Type, error)
anshuldata marked this conversation as resolved.
Show resolved Hide resolved
}

type EnumArg struct {
Expand All @@ -69,6 +72,10 @@
return "req"
}

func (EnumArg) ArgType() (types.Type, error) {
return nil, errors.New("unimplemented")

Check warning on line 76 in extensions/simple_extension.go

View check run for this annotation

Codecov / codecov/patch

extensions/simple_extension.go#L75-L76

Added lines #L75 - L76 were not covered by tests
}

type ValueArg struct {
Name string `yaml:",omitempty"`
Description string `yaml:",omitempty"`
Expand All @@ -80,6 +87,10 @@
return v.Value.Expr.(*parser.Type).ShortType()
}

func (v ValueArg) ArgType() (types.Type, error) {
return v.Value.Expr.(*parser.Type).Type()

Check warning on line 91 in extensions/simple_extension.go

View check run for this annotation

Codecov / codecov/patch

extensions/simple_extension.go#L90-L91

Added lines #L90 - L91 were not covered by tests
}

type TypeArg struct {
Name string `yaml:",omitempty"`
Description string `yaml:",omitempty"`
Expand All @@ -88,6 +99,10 @@

func (TypeArg) toTypeString() string { return "type" }

func (TypeArg) ArgType() (types.Type, error) {
return nil, errors.New("unimplemented")

Check warning on line 103 in extensions/simple_extension.go

View check run for this annotation

Codecov / codecov/patch

extensions/simple_extension.go#L102-L103

Added lines #L102 - L103 were not covered by tests
}

type ArgumentList []Argument

func (a *ArgumentList) UnmarshalYAML(fn func(interface{}) error) error {
Expand Down
58 changes: 58 additions & 0 deletions types/any_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package types
anshuldata marked this conversation as resolved.
Show resolved Hide resolved

import (
"fmt"

"github.com/substrait-io/substrait-go/proto"
)

// AnyType to represent AnyType, this type is to indicate "any" type of argument
// This type is not used in function invocation. It is only used in function definition
type AnyType struct {
Name string
Nullability Nullability
}

func (AnyType) isRootRef() {}
func (m AnyType) WithNullability(nullability Nullability) Type {
m.Nullability = nullability
return m

Check warning on line 19 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L16-L19

Added lines #L16 - L19 were not covered by tests
}
func (m AnyType) GetType() Type { return m }

Check warning on line 21 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L21

Added line #L21 was not covered by tests
func (m AnyType) GetNullability() Nullability {
return m.Nullability
}
func (AnyType) GetTypeVariationReference() uint32 {
panic("not allowed")

Check warning on line 26 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L25-L26

Added lines #L25 - L26 were not covered by tests
}
func (AnyType) Equals(rhs Type) bool {

Check warning on line 28 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L28

Added line #L28 was not covered by tests
// equal to every other type
return true

Check warning on line 30 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L30

Added line #L30 was not covered by tests
}

func (AnyType) ToProtoFuncArg() *proto.FunctionArgument {
panic("not allowed")

Check warning on line 34 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L33-L34

Added lines #L33 - L34 were not covered by tests
}

func (AnyType) ToProto() *proto.Type {
panic("not allowed")

Check warning on line 38 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L37-L38

Added lines #L37 - L38 were not covered by tests
}
anshuldata marked this conversation as resolved.
Show resolved Hide resolved

func (t AnyType) ShortString() string { return t.Name }
func (t AnyType) String() string {
return fmt.Sprintf("%s%s", t.Name, strNullable(t))
}

// Below methods are for parser Def interface

func (AnyType) Optional() bool {
panic("not allowed")

Check warning on line 49 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L48-L49

Added lines #L48 - L49 were not covered by tests
}
anshuldata marked this conversation as resolved.
Show resolved Hide resolved

func (m AnyType) ShortType() string {
return "any"
}

func (m AnyType) Type() (Type, error) {
return m, nil

Check warning on line 57 in types/any_type.go

View check run for this annotation

Codecov / codecov/patch

types/any_type.go#L56-L57

Added lines #L56 - L57 were not covered by tests
}
33 changes: 33 additions & 0 deletions types/any_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package types_test
anshuldata marked this conversation as resolved.
Show resolved Hide resolved

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/types"
)

func TestAnyType(t *testing.T) {
for _, td := range []struct {
testName string
argName string
nullability types.Nullability
expectedString string
}{
{"any", "any", types.NullabilityNullable, "any?"},
{"anyrequired", "any", types.NullabilityRequired, "any"},
{"anyOtherName", "any1", types.NullabilityNullable, "any1?"},
{"T name", "T", types.NullabilityNullable, "T?"},
} {
t.Run(td.testName, func(t *testing.T) {
arg := &types.AnyType{
Name: td.argName,
Nullability: td.nullability,
}
require.Equal(t, td.expectedString, arg.String())
require.Equal(t, td.nullability, arg.GetNullability())
require.Equal(t, td.argName, arg.ShortString())
require.Equal(t, "any", arg.ShortType())
})
}
}
58 changes: 58 additions & 0 deletions types/parameterized_decimal_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package types

import (
"fmt"

"github.com/substrait-io/substrait-go/proto"
)

// ParameterizedDecimalType is a decimal type with precision and scale parameters of string type
// example: Decimal(P,S). Kindly note concrete types Decimal(10, 2) are not represented by this type
// Concrete type is represented by DecimalType
type ParameterizedDecimalType struct {
anshuldata marked this conversation as resolved.
Show resolved Hide resolved
Nullability Nullability
TypeVariationRef uint32
Precision IntegerParam
Scale IntegerParam
}

func (ParameterizedDecimalType) isRootRef() {}
func (m ParameterizedDecimalType) WithNullability(n Nullability) Type {
m.Nullability = n
return m

Check warning on line 22 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L19-L22

Added lines #L19 - L22 were not covered by tests
}

func (m ParameterizedDecimalType) GetType() Type { return m }

Check warning on line 25 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L25

Added line #L25 was not covered by tests
func (m ParameterizedDecimalType) GetNullability() Nullability { return m.Nullability }
func (m ParameterizedDecimalType) GetTypeVariationReference() uint32 {
return m.TypeVariationRef

Check warning on line 28 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L27-L28

Added lines #L27 - L28 were not covered by tests
}
func (m ParameterizedDecimalType) Equals(rhs Type) bool {
if o, ok := rhs.(ParameterizedDecimalType); ok {
return o == m
}
return false

Check warning on line 34 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L34

Added line #L34 was not covered by tests
}

func (ParameterizedDecimalType) ToProtoFuncArg() *proto.FunctionArgument {

Check warning on line 37 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L37

Added line #L37 was not covered by tests
// parameterized type are never on wire so to proto is not supported
panic("not supported")

Check warning on line 39 in types/parameterized_decimal_type.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_decimal_type.go#L39

Added line #L39 was not covered by tests
}
anshuldata marked this conversation as resolved.
Show resolved Hide resolved

func (m ParameterizedDecimalType) ShortString() string {
t := DecimalType{}
return t.ShortString()
}

func (m ParameterizedDecimalType) String() string {
return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString())
}

func (m ParameterizedDecimalType) ParameterString() string {
return fmt.Sprintf("<%s,%s>", m.Precision.String(), m.Scale.String())
}

func (m ParameterizedDecimalType) BaseString() string {
t := DecimalType{}
return t.BaseString()
}
112 changes: 112 additions & 0 deletions types/parameterized_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package types

import (
"fmt"

"github.com/substrait-io/substrait-go/proto"
)

// IntegerParam represents a single integer parameter for a parameterized type
// Example: VARCHAR(L1) -> L1 is the integer parameter
type IntegerParam struct {
anshuldata marked this conversation as resolved.
Show resolved Hide resolved
anshuldata marked this conversation as resolved.
Show resolved Hide resolved
Name string
}

func (m IntegerParam) Equals(o IntegerParam) bool {
return m == o

Check warning on line 16 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L15-L16

Added lines #L15 - L16 were not covered by tests
}

func (p IntegerParam) ToProto() *proto.ParameterizedType_IntegerParameter {
panic("not implemented")

Check warning on line 20 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L19-L20

Added lines #L19 - L20 were not covered by tests
}
anshuldata marked this conversation as resolved.
Show resolved Hide resolved

func (m IntegerParam) String() string {
return m.Name
}

// ParameterizedTypeSingleIntegerParam This is a generic type to represent parameterized type with a single integer parameter
type ParameterizedTypeSingleIntegerParam[T VarCharType | FixedCharType | FixedBinaryType | PrecisionTimestampType | PrecisionTimestampTzType] struct {
anshuldata marked this conversation as resolved.
Show resolved Hide resolved
Nullability Nullability
TypeVariationRef uint32
IntegerOption IntegerParam
}

func (m ParameterizedTypeSingleIntegerParam[T]) WithIntegerOption(integerOption IntegerParam) ParameterizedSingleIntegerType {
m.IntegerOption = integerOption
return m
}

func (ParameterizedTypeSingleIntegerParam[T]) isRootRef() {}

Check warning on line 39 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L39

Added line #L39 was not covered by tests
func (m ParameterizedTypeSingleIntegerParam[T]) WithNullability(n Nullability) Type {
m.Nullability = n
return m
}

func (m ParameterizedTypeSingleIntegerParam[T]) GetType() Type { return m }

Check warning on line 45 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L45

Added line #L45 was not covered by tests
func (m ParameterizedTypeSingleIntegerParam[T]) GetNullability() Nullability { return m.Nullability }
func (m ParameterizedTypeSingleIntegerParam[T]) GetTypeVariationReference() uint32 {
return m.TypeVariationRef

Check warning on line 48 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L47-L48

Added lines #L47 - L48 were not covered by tests
}
func (m ParameterizedTypeSingleIntegerParam[T]) Equals(rhs Type) bool {
if o, ok := rhs.(ParameterizedTypeSingleIntegerParam[T]); ok {
return o == m
}
return false

Check warning on line 54 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L54

Added line #L54 was not covered by tests
}

func (ParameterizedTypeSingleIntegerParam[T]) ToProtoFuncArg() *proto.FunctionArgument {

Check warning on line 57 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L57

Added line #L57 was not covered by tests
// parameterized type are never on wire so to proto is not supported
panic("not supported")

Check warning on line 59 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L59

Added line #L59 was not covered by tests
}
anshuldata marked this conversation as resolved.
Show resolved Hide resolved

func (m ParameterizedTypeSingleIntegerParam[T]) ShortString() string {
switch any(m).(type) {
case ParameterizedVarCharType:
t := VarCharType{}
return t.ShortString()
case ParameterizedFixedCharType:
t := FixedCharType{}
return t.ShortString()
case ParameterizedFixedBinaryType:
t := FixedBinaryType{}
return t.ShortString()
case ParameterizedPrecisionTimestampType:
t := PrecisionTimestampType{}
return t.ShortString()
case ParameterizedPrecisionTimestampTzType:
t := PrecisionTimestampTzType{}
return t.ShortString()
default:
panic("unknown type")

Check warning on line 80 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L79-L80

Added lines #L79 - L80 were not covered by tests
}
}

func (m ParameterizedTypeSingleIntegerParam[T]) String() string {
return fmt.Sprintf("%s%s%s", m.BaseString(), strNullable(m), m.ParameterString())
}

func (m ParameterizedTypeSingleIntegerParam[T]) ParameterString() string {
return fmt.Sprintf("<%s>", m.IntegerOption.String())
}

func (m ParameterizedTypeSingleIntegerParam[T]) BaseString() string {
switch any(m).(type) {
case ParameterizedVarCharType:
t := VarCharType{}
return t.BaseString()
case ParameterizedFixedCharType:
t := FixedCharType{}
return t.BaseString()
case ParameterizedFixedBinaryType:
t := FixedBinaryType{}
return t.BaseString()
case ParameterizedPrecisionTimestampType:
t := PrecisionTimestampType{}
return t.BaseString()
case ParameterizedPrecisionTimestampTzType:
t := PrecisionTimestampTzType{}
return t.BaseString()
default:
panic("unknown type")

Check warning on line 110 in types/parameterized_types.go

View check run for this annotation

Codecov / codecov/patch

types/parameterized_types.go#L109-L110

Added lines #L109 - L110 were not covered by tests
}
}
66 changes: 66 additions & 0 deletions types/parameterized_types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package types_test

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/types"
)

func TestParameterizedVarCharType(t *testing.T) {
for _, td := range []struct {
name string
typ types.ParameterizedSingleIntegerType
nullability types.Nullability
integerOption types.IntegerParam
expectedString string
expectedBaseString string
expectedShortString string
}{
{"nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "varchar?<L1>", "varchar", "vchar"},
{"non nullable varchar", &types.ParameterizedVarCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "varchar<L1>", "varchar", "vchar"},
{"nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "char?<L1>", "char", "fchar"},
{"non nullable fixChar", &types.ParameterizedFixedCharType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "char<L1>", "char", "fchar"},
{"nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "fixedbinary?<L1>", "fixedbinary", "fbin"},
{"non nullable fixBinary", &types.ParameterizedFixedBinaryType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "fixedbinary<L1>", "fixedbinary", "fbin"},
{"nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp?<L1>", "precision_timestamp", "prets"},
{"non nullable precisionTimeStamp", &types.ParameterizedPrecisionTimestampType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp<L1>", "precision_timestamp", "prets"},
{"nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityNullable, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz?<L1>", "precision_timestamp_tz", "pretstz"},
{"non nullable precisionTimeStampTz", &types.ParameterizedPrecisionTimestampTzType{}, types.NullabilityRequired, types.IntegerParam{Name: "L1"}, "precision_timestamp_tz<L1>", "precision_timestamp_tz", "pretstz"},
} {
t.Run(td.name, func(t *testing.T) {
pt := td.typ.WithIntegerOption(td.integerOption).WithNullability(td.nullability)
require.Equal(t, td.expectedString, pt.String())
parameterizeType, ok := pt.(types.ParameterizedType)
require.True(t, ok)
require.Equal(t, td.expectedBaseString, parameterizeType.BaseString())
require.Equal(t, td.expectedShortString, pt.ShortString())
require.True(t, pt.Equals(pt))
})
}
}

func TestParameterizedDecimalType(t *testing.T) {
for _, td := range []struct {
name string
precision string
scale string
nullability types.Nullability
expectedString string
expectedBaseString string
expectedShortString string
}{
{"nullable decimal", "P", "S", types.NullabilityNullable, "decimal?<P,S>", "decimal", "dec"},
{"non nullable decimal", "P", "S", types.NullabilityRequired, "decimal<P,S>", "decimal", "dec"},
} {
t.Run(td.name, func(t *testing.T) {
precision := types.IntegerParam{Name: td.precision}
scale := types.IntegerParam{Name: td.scale}
pt := types.ParameterizedDecimalType{Precision: precision, Scale: scale, Nullability: td.nullability}
require.Equal(t, td.expectedString, pt.String())
require.Equal(t, td.expectedBaseString, pt.BaseString())
require.Equal(t, td.expectedShortString, pt.ShortString())
require.True(t, pt.Equals(pt))
})
}
}
Loading
Loading