Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(firestore): Adding vector search #10548

Merged
merged 12 commits into from
Jul 22, 2024
22 changes: 22 additions & 0 deletions firestore/docref.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,28 @@
}
}

bhshkh marked this conversation as resolved.
Show resolved Hide resolved
// VectorType represpresents a vector
type VectorType interface {
bhshkh marked this conversation as resolved.
Show resolved Hide resolved
isVectorType()
toProtoValue() (*pb.Value, bool, error)
}

// Vector represents a vector in the form of a float64 array
type Vector []float64
bhshkh marked this conversation as resolved.
Show resolved Hide resolved

func (_ Vector) isVectorType() {}

Check failure on line 650 in firestore/docref.go

View workflow job for this annotation

GitHub Actions / vet

receiver name should not be an underscore, omit the name if it is unused

func (vector Vector) toProtoValue() (*pb.Value, bool, error) {
bhshkh marked this conversation as resolved.
Show resolved Hide resolved
bhshkh marked this conversation as resolved.
Show resolved Hide resolved
if vector == nil {
return nullValue, false, nil
}

vectorMap := map[string]interface{}{}
vectorMap["__type__"] = "__vector__"
vectorMap["value"] = []float64(vector)
return mapToProtoValue(reflect.ValueOf(vectorMap))
}

// An Update describes an update to a value referred to by a path.
// An Update should have either a non-empty Path or a non-empty FieldPath,
// but not both.
Expand Down
170 changes: 108 additions & 62 deletions firestore/from_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,221 +32,267 @@ func setFromProtoValue(x interface{}, vproto *pb.Value, c *Client) error {
return setReflectFromProtoValue(v.Elem(), vproto, c)
}

// setReflectFromProtoValue sets v from a Firestore Value.
// v must be a settable value.
func setReflectFromProtoValue(v reflect.Value, vproto *pb.Value, c *Client) error {
// setReflectFromProtoValue sets vDest from a Firestore Value.
// vDest must be a settable value.
func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Client) error {
typeErr := func() error {
return fmt.Errorf("firestore: cannot set type %s to %s", v.Type(), typeString(vproto))
return fmt.Errorf("firestore: cannot set type %s to %s", vDest.Type(), typeString(vprotoSrc))
}

val := vproto.ValueType
typeErrWithArgs := func(destType string) error {
return fmt.Errorf("firestore: cannot set type %s to %s", destType, typeString(vprotoSrc))
}

valTypeSrc := vprotoSrc.ValueType
// A Null value sets anything nullable to nil, and has no effect
// on anything else.
if _, ok := val.(*pb.Value_NullValue); ok {
switch v.Kind() {
if _, ok := valTypeSrc.(*pb.Value_NullValue); ok {
switch vDest.Kind() {
case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice:
v.Set(reflect.Zero(v.Type()))
vDest.Set(reflect.Zero(vDest.Type()))
}
return nil
}

// Handle special types first.
switch v.Type() {
switch vDest.Type() {
case typeOfByteSlice:
x, ok := val.(*pb.Value_BytesValue)
x, ok := valTypeSrc.(*pb.Value_BytesValue)
if !ok {
return typeErr()
}
v.SetBytes(x.BytesValue)
vDest.SetBytes(x.BytesValue)
return nil

case typeOfGoTime:
x, ok := val.(*pb.Value_TimestampValue)
x, ok := valTypeSrc.(*pb.Value_TimestampValue)
if !ok {
return typeErr()
}
if err := x.TimestampValue.CheckValid(); err != nil {
return err
}
v.Set(reflect.ValueOf(x.TimestampValue.AsTime()))
vDest.Set(reflect.ValueOf(x.TimestampValue.AsTime()))
return nil

case typeOfProtoTimestamp:
x, ok := val.(*pb.Value_TimestampValue)
x, ok := valTypeSrc.(*pb.Value_TimestampValue)
if !ok {
return typeErr()
}
v.Set(reflect.ValueOf(x.TimestampValue))
vDest.Set(reflect.ValueOf(x.TimestampValue))
return nil

case typeOfLatLng:
x, ok := val.(*pb.Value_GeoPointValue)
x, ok := valTypeSrc.(*pb.Value_GeoPointValue)
if !ok {
return typeErr()
}
v.Set(reflect.ValueOf(x.GeoPointValue))
vDest.Set(reflect.ValueOf(x.GeoPointValue))
return nil

case typeOfDocumentRef:
x, ok := val.(*pb.Value_ReferenceValue)
x, ok := valTypeSrc.(*pb.Value_ReferenceValue)
if !ok {
return typeErr()
}
dr, err := pathToDoc(x.ReferenceValue, c)
if err != nil {
return err
}
v.Set(reflect.ValueOf(dr))
vDest.Set(reflect.ValueOf(dr))
return nil

case typeOfVector:
jba marked this conversation as resolved.
Show resolved Hide resolved
/*
Vector is stored as:
{
"__type__": "__vector__",
"value": []float64{},
}
but needs to be returned as firestore.Vector to the user
*/

// Convert Firestore proto map from Go map
vectorMapDest := map[string]interface{}{}
vectorMapDestVal := reflect.ValueOf(vectorMapDest)
x, ok := valTypeSrc.(*pb.Value_MapValue)
if !ok {
// Vector not stored as map in Firestore
return typeErrWithArgs("Vector")
}
err := populateMap(vectorMapDestVal, x.MapValue.Fields, c)
if err != nil {
return err
}

// Convert value at "value" key to array of floats
anyArr, isInterfaceArr := vectorMapDest["value"].([]interface{})
if !isInterfaceArr {
// value at "value" key is not an array
return typeErrWithArgs("Vector")
}
floats := []float64{}
for _, v := range anyArr {
// Convert each element of []interface{} to float64
floatVal, isFloat := v.(float64)
if isFloat {
floats = append(floats, floatVal)
}
}

// Set Vector in destination
vDest.Set(reflect.ValueOf(floats))
return nil
}

switch v.Kind() {
switch vDest.Kind() {
case reflect.Bool:
x, ok := val.(*pb.Value_BooleanValue)
x, ok := valTypeSrc.(*pb.Value_BooleanValue)
if !ok {
return typeErr()
}
v.SetBool(x.BooleanValue)
vDest.SetBool(x.BooleanValue)

case reflect.String:
x, ok := val.(*pb.Value_StringValue)
x, ok := valTypeSrc.(*pb.Value_StringValue)
if !ok {
return typeErr()
}
v.SetString(x.StringValue)
vDest.SetString(x.StringValue)

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var i int64
switch x := val.(type) {
switch x := valTypeSrc.(type) {
case *pb.Value_IntegerValue:
i = x.IntegerValue
case *pb.Value_DoubleValue:
f := x.DoubleValue
i = int64(f)
if float64(i) != f {
return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type())
return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type())
}
default:
return typeErr()
}
if v.OverflowInt(i) {
return overflowErr(v, i)
if vDest.OverflowInt(i) {
return overflowErr(vDest, i)
}
v.SetInt(i)
vDest.SetInt(i)

case reflect.Uint8, reflect.Uint16, reflect.Uint32:
var u uint64
switch x := val.(type) {
switch x := valTypeSrc.(type) {
case *pb.Value_IntegerValue:
u = uint64(x.IntegerValue)
case *pb.Value_DoubleValue:
f := x.DoubleValue
u = uint64(f)
if float64(u) != f {
return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type())
return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type())
}
default:
return typeErr()
}
if v.OverflowUint(u) {
return overflowErr(v, u)
if vDest.OverflowUint(u) {
return overflowErr(vDest, u)
}
v.SetUint(u)
vDest.SetUint(u)

case reflect.Float32, reflect.Float64:
var f float64
switch x := val.(type) {
switch x := valTypeSrc.(type) {
case *pb.Value_DoubleValue:
f = x.DoubleValue
case *pb.Value_IntegerValue:
f = float64(x.IntegerValue)
if int64(f) != x.IntegerValue {
return overflowErr(v, x.IntegerValue)
return overflowErr(vDest, x.IntegerValue)
}
default:
return typeErr()
}
if v.OverflowFloat(f) {
return overflowErr(v, f)
if vDest.OverflowFloat(f) {
return overflowErr(vDest, f)
}
v.SetFloat(f)
vDest.SetFloat(f)

case reflect.Slice:
x, ok := val.(*pb.Value_ArrayValue)
x, ok := valTypeSrc.(*pb.Value_ArrayValue)
if !ok {
return typeErr()
}
vals := x.ArrayValue.Values
vlen := v.Len()
vlen := vDest.Len()
xlen := len(vals)
// Make a slice of the right size, avoiding allocation if possible.
switch {
case vlen < xlen:
v.Set(reflect.MakeSlice(v.Type(), xlen, xlen))
vDest.Set(reflect.MakeSlice(vDest.Type(), xlen, xlen))
case vlen > xlen:
v.SetLen(xlen)
vDest.SetLen(xlen)
}
return populateRepeated(v, vals, xlen, c)
return populateRepeated(vDest, vals, xlen, c)

case reflect.Array:
x, ok := val.(*pb.Value_ArrayValue)
x, ok := valTypeSrc.(*pb.Value_ArrayValue)
if !ok {
return typeErr()
}
vals := x.ArrayValue.Values
xlen := len(vals)
vlen := v.Len()
vlen := vDest.Len()
minlen := vlen
// Set extra elements to their zero value.
if vlen > xlen {
z := reflect.Zero(v.Type().Elem())
z := reflect.Zero(vDest.Type().Elem())
for i := xlen; i < vlen; i++ {
v.Index(i).Set(z)
vDest.Index(i).Set(z)
}
minlen = xlen
}
return populateRepeated(v, vals, minlen, c)
return populateRepeated(vDest, vals, minlen, c)

case reflect.Map:
x, ok := val.(*pb.Value_MapValue)
x, ok := valTypeSrc.(*pb.Value_MapValue)
if !ok {
return typeErr()
}
return populateMap(v, x.MapValue.Fields, c)
return populateMap(vDest, x.MapValue.Fields, c)

case reflect.Ptr:
// If the pointer is nil, set it to a zero value.
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
if vDest.IsNil() {
vDest.Set(reflect.New(vDest.Type().Elem()))
}
return setReflectFromProtoValue(v.Elem(), vproto, c)
return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c)

case reflect.Struct:
x, ok := val.(*pb.Value_MapValue)
x, ok := valTypeSrc.(*pb.Value_MapValue)
if !ok {
return typeErr()
}
return populateStruct(v, x.MapValue.Fields, c)
return populateStruct(vDest, x.MapValue.Fields, c)

case reflect.Interface:
if v.NumMethod() == 0 { // empty interface
if vDest.NumMethod() == 0 { // empty interface
// If v holds a pointer, set the pointer.
if !v.IsNil() && v.Elem().Kind() == reflect.Ptr {
return setReflectFromProtoValue(v.Elem(), vproto, c)
if !vDest.IsNil() && vDest.Elem().Kind() == reflect.Ptr {
return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c)
}
// Otherwise, create a fresh value.
x, err := createFromProtoValue(vproto, c)
x, err := createFromProtoValue(vprotoSrc, c)
if err != nil {
return err
}
v.Set(reflect.ValueOf(x))
vDest.Set(reflect.ValueOf(x))
return nil
}
// Any other kind of interface is an error.
fallthrough

default:
return fmt.Errorf("firestore: cannot set type %s", v.Type())
return fmt.Errorf("firestore: cannot set type %s", vDest.Type())
}
return nil
}
Expand Down
Loading
Loading