From d9e2ba6f82b9f95ba0b1c46db71abc2843132bff Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Thu, 18 Jul 2024 10:53:52 -0400 Subject: [PATCH 01/24] Stub VectorValue --- .../FirebaseFirestore/FIRVectorValue.h | 15 ++++ Firestore/Source/API/FIRFieldValue.mm | 5 ++ .../Source/API/FIRVectorValue+Internal.h | 29 +++++++ Firestore/Source/API/FIRVectorValue.m | 83 +++++++++++++++++++ .../Public/FirebaseFirestore/FIRFieldValue.h | 9 ++ .../Public/FirebaseFirestore/FIRVectorValue.h | 29 +++++++ .../Source/SwiftAPI/FieldValue+Swift.swift | 44 ++++++++++ .../Source/SwiftAPI/VectorValue+Swift.swift | 30 +++++++ 8 files changed, 244 insertions(+) create mode 100644 FirebaseFirestoreInternal/FirebaseFirestore/FIRVectorValue.h create mode 100644 Firestore/Source/API/FIRVectorValue+Internal.h create mode 100644 Firestore/Source/API/FIRVectorValue.m create mode 100644 Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h create mode 100644 Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift create mode 100644 Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift diff --git a/FirebaseFirestoreInternal/FirebaseFirestore/FIRVectorValue.h b/FirebaseFirestoreInternal/FirebaseFirestore/FIRVectorValue.h new file mode 100644 index 00000000000..ac85c462d20 --- /dev/null +++ b/FirebaseFirestoreInternal/FirebaseFirestore/FIRVectorValue.h @@ -0,0 +1,15 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import diff --git a/Firestore/Source/API/FIRFieldValue.mm b/Firestore/Source/API/FIRFieldValue.mm index dccf735175c..a38cff88a11 100644 --- a/Firestore/Source/API/FIRFieldValue.mm +++ b/Firestore/Source/API/FIRFieldValue.mm @@ -15,6 +15,7 @@ */ #import "Firestore/Source/API/FIRFieldValue+Internal.h" +#import "Firestore/Source/API/FIRVectorValue+Internal.h" NS_ASSUME_NONNULL_BEGIN @@ -176,6 +177,10 @@ + (instancetype)fieldValueForIntegerIncrement:(int64_t)l { return [[FSTNumericIncrementFieldValue alloc] initWithOperand:@(l)]; } ++ (nonnull FIRVectorValue *)vectorFromNSNumbers:(nonnull NSArray *)values { + return [[FIRVectorValue alloc] initWithNSNumbers: values]; +} + @end NS_ASSUME_NONNULL_END diff --git a/Firestore/Source/API/FIRVectorValue+Internal.h b/Firestore/Source/API/FIRVectorValue+Internal.h new file mode 100644 index 00000000000..8774da83cb7 --- /dev/null +++ b/Firestore/Source/API/FIRVectorValue+Internal.h @@ -0,0 +1,29 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#import "FIRVectorValue.h" + + +NS_ASSUME_NONNULL_BEGIN + +@interface FIRVectorValue (Internal) +- (instancetype)init NS_UNAVAILABLE; +- (instancetype)initWithNSNumbers: (NSArray *)values; +- (NSArray *)toNSArray; +@end + + +NS_ASSUME_NONNULL_END diff --git a/Firestore/Source/API/FIRVectorValue.m b/Firestore/Source/API/FIRVectorValue.m new file mode 100644 index 00000000000..ec63aecf43b --- /dev/null +++ b/Firestore/Source/API/FIRVectorValue.m @@ -0,0 +1,83 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#import + +#include + +#import "Firestore/Source/API/FIRVectorValue+Internal.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface FIRVectorValue () { + /** Internal vector representation */ + std::vector _internalValue; +} + +@end + +@implementation FIRVectorValue + +- (instancetype)initWithNSNumbers: (NSArray *)values { + if (self = [super init]) { + std::vector converted; + converted.reserve(values.count); + for (NSNumber *value in values) { + converted.emplace_back([value doubleValue]); + } + + _internalValue = std::move(converted); + } + return self; +} + +- (nonnull NSArray *)toNSArray { + size_t length = _internalValue.size(); + NSMutableArray *outArray = [[NSMutableArray alloc] initWithCapacity:length]; + for (size_t i = 0; i < length; i++) { + [outArray addObject:[[NSNumber alloc] initWithDouble:self->_internalValue.at(i)]]; + } + + return outArray; +} + +- (BOOL)isEqual:(nullable id)object { + if (self == object) { + return YES; + } + + if (![object isKindOfClass:[FIRVectorValue class]]) { + return NO; + } + + FIRVectorValue *otherVector = ((FIRVectorValue *)object); + + if (self->_internalValue.size() != otherVector->_internalValue.size()) { + return NO; + } + + for (size_t i = 0; i < self->_internalValue.size(); i++) { + if (self->_internalValue[i] != otherVector->_internalValue[i]) + return NO; + } + + return YES; +} + +@end + +NS_ASSUME_NONNULL_END diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h index 269735570d9..37f1a72d31c 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h @@ -17,6 +17,7 @@ #import NS_ASSUME_NONNULL_BEGIN +@class FIRVectorValue; /** * Sentinel values that can be used when writing document fields with `setData()` or `updateData()`. @@ -90,6 +91,14 @@ NS_SWIFT_NAME(FieldValue) */ + (instancetype)fieldValueForIntegerIncrement:(int64_t)l NS_SWIFT_NAME(increment(_:)); +/** + * Creates a new `VectorValue` constructed with a copy of the given array of NSNumbers. + * + * @param values Create a `VectorValue` instance with a copy of this array of NSNumbers. + * @return A new `VectorValue` constructed with a copy of the given array of NSNumbers. + */ ++ (FIRVectorValue*)vectorFromNSNumbers:(NSArray *)values NS_SWIFT_NAME(vector(fromNSNumbers:)); + @end NS_ASSUME_NONNULL_END diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h new file mode 100644 index 00000000000..8ed768e0f23 --- /dev/null +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h @@ -0,0 +1,29 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#import + +NS_ASSUME_NONNULL_BEGIN + +NS_SWIFT_NAME(VectorValue) +@interface FIRVectorValue : NSObject +- (instancetype)init __attribute__((unavailable("FIRVectorValue cannot be created directly."))); + +- (NSArray *)toNSArray NS_SWIFT_NAME(toNSArray()); +@end + + +NS_ASSUME_NONNULL_END diff --git a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift new file mode 100644 index 00000000000..6d43c810e2a --- /dev/null +++ b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift @@ -0,0 +1,44 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if SWIFT_PACKAGE + @_exported import FirebaseFirestoreInternalWrapper +#else + @_exported import FirebaseFirestoreInternal +#endif // SWIFT_PACKAGE + +public extension FieldValue { + + /// Creates a new `VectorValue` constructed with a copy of the given array of Doubles. + /// - Parameter values: An array of Doubles. + /// - Returns: A new `VectorValue` constructed with a copy of the given array of Doubles. + static func vector(_ values: [Double]) -> VectorValue { + let array = values.map { double in + return NSNumber(value: double) + } + return FieldValue.vector(fromNSNumbers: array) + } + + /// Creates a new `VectorValue` constructed with a copy of the given array of Floats. + /// - Parameter values: An array of Floats. + /// - Returns: A new `VectorValue` constructed with a copy of the given array of Floats. + static func vector(_ values: [Float]) -> VectorValue { + let array = values.map { float in + return NSNumber(value: float) + } + return FieldValue.vector(fromNSNumbers: array) + } +} diff --git a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift new file mode 100644 index 00000000000..6e43cf628a4 --- /dev/null +++ b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift @@ -0,0 +1,30 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if SWIFT_PACKAGE + @_exported import FirebaseFirestoreInternalWrapper +#else + @_exported import FirebaseFirestoreInternal +#endif // SWIFT_PACKAGE + +public extension VectorValue { + + /// Returns a raw number array representation of the vector. + /// - Returns: An array of Double values representing the vector. + func toArray() -> [Double] { + return self.toNSArray().map { Double(truncating: $0) } + } +} From 3ad498a8fedef27f6a9e3a329a740a7b407dd3f0 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Mon, 22 Jul 2024 11:21:17 -0600 Subject: [PATCH 02/24] Porting VectorValue implementation. --- Firestore/Source/API/FSTUserDataReader.mm | 40 ++- Firestore/Source/API/FSTUserDataWriter.mm | 75 +++-- Firestore/core/src/core/target.cc | 4 +- Firestore/core/src/model/value_util.cc | 266 +++++++++++++----- Firestore/core/src/model/value_util.h | 57 +++- .../core/test/unit/model/value_util_test.cc | 58 ++-- 6 files changed, 360 insertions(+), 140 deletions(-) diff --git a/Firestore/Source/API/FSTUserDataReader.mm b/Firestore/Source/API/FSTUserDataReader.mm index dbc5c976839..b175131c05b 100644 --- a/Firestore/Source/API/FSTUserDataReader.mm +++ b/Firestore/Source/API/FSTUserDataReader.mm @@ -24,6 +24,7 @@ #import "FIRGeoPoint.h" #import "FIRTimestamp.h" +#import "FIRVectorValue.h" #import "Firestore/Source/API/FIRDocumentReference+Internal.h" #import "Firestore/Source/API/FIRFieldPath+Internal.h" @@ -340,6 +341,41 @@ - (ParsedUpdateData)parsedUpdateData:(id)input { return std::move(result); } +- (Message)parseVectorValue:(FIRVectorValue *)vectorValue context:(ParseContext &&)context { + __block Message result; + result->which_value_type = google_firestore_v1_Value_map_value_tag; + result->map_value = {}; + + result->map_value.fields_count = 2; + result->map_value.fields = nanopb::MakeArray(2); + + result->map_value.fields[0].key = nanopb::MakeBytesArray(MakeString(@"__type__")); + result->map_value.fields[0].value = *[self encodeStringValue:MakeString(@"__vector__")].release(); + + NSArray *vectorArray = [vectorValue toNSArray]; + + __block Message arrayMessage; + arrayMessage->which_value_type = google_firestore_v1_Value_array_value_tag; + arrayMessage->array_value.values_count = CheckedSize([vectorArray count]); + arrayMessage->array_value.values = + nanopb::MakeArray(arrayMessage->array_value.values_count); + + [vectorArray enumerateObjectsUsingBlock:^(id entry, NSUInteger idx, BOOL *) { + if (![entry isKindOfClass:[NSNumber class]]) { + ThrowInvalidArgument("VectorValues must only contain numeric values.", + context.FieldDescription()); + } + + // Vector values must always use Double encoding + result->array_value.values[idx] = *[self encodeDouble:[entry doubleValue]].release(); + }]; + + result->map_value.fields[1].key = nanopb::MakeBytesArray(MakeString(@"value")); + result->map_value.fields[1].value = *arrayMessage.release(); + + return std::move(result); +} + - (Message)parseArray:(NSArray *)array context:(ParseContext &&)context { __block Message result; @@ -528,7 +564,9 @@ - (void)parseSentinelFieldValue:(FIRFieldValue *)fieldValue context:(ParseContex _databaseID.database_id(), context.FieldDescription()); } return [self encodeReference:_databaseID key:reference.key]; - + } else if ([input isKindOfClass:[FIRVectorValue class]]) { + FIRVectorValue *vector = input; + return [self parseVectorValue:vector context:std::move(context)]; } else { ThrowInvalidArgument("Unsupported type: %s%s", NSStringFromClass([input class]), context.FieldDescription()); diff --git a/Firestore/Source/API/FSTUserDataWriter.mm b/Firestore/Source/API/FSTUserDataWriter.mm index c561efe1353..7d3e87b7c46 100644 --- a/Firestore/Source/API/FSTUserDataWriter.mm +++ b/Firestore/Source/API/FSTUserDataWriter.mm @@ -22,6 +22,7 @@ #include "Firestore/Protos/nanopb/google/firestore/v1/document.nanopb.h" #include "Firestore/Source/API/FIRDocumentReference+Internal.h" #include "Firestore/Source/API/converters.h" +#include "Firestore/Source/API/FIRFieldValue+Internal.h" #include "Firestore/core/include/firebase/firestore/geo_point.h" #include "Firestore/core/include/firebase/firestore/timestamp.h" #include "Firestore/core/src/api/firestore.h" @@ -79,36 +80,38 @@ - (instancetype)initWithFirestore:(std::shared_ptr)firestore } - (id)convertedValue:(const google_firestore_v1_Value &)value { - switch (GetTypeOrder(value)) { - case TypeOrder::kMap: - return [self convertedObject:value.map_value]; - case TypeOrder::kArray: - return [self convertedArray:value.array_value]; - case TypeOrder::kReference: - return [self convertedReference:value]; - case TypeOrder::kTimestamp: - return [self convertedTimestamp:value.timestamp_value]; - case TypeOrder::kServerTimestamp: - return [self convertedServerTimestamp:value]; - case TypeOrder::kNull: - return [NSNull null]; - case TypeOrder::kBoolean: - return value.boolean_value ? @YES : @NO; - case TypeOrder::kNumber: - return value.which_value_type == google_firestore_v1_Value_integer_value_tag - ? @(value.integer_value) - : @(value.double_value); - case TypeOrder::kString: - return MakeNSString(MakeStringView(value.string_value)); - case TypeOrder::kBlob: - return MakeNSData(value.bytes_value); - case TypeOrder::kGeoPoint: - return MakeFIRGeoPoint( - GeoPoint(value.geo_point_value.latitude, value.geo_point_value.longitude)); - case TypeOrder::kMaxValue: - // It is not possible for users to construct a kMaxValue manually. - break; - } + switch (GetTypeOrder(value)) { + case TypeOrder::kMap: + return [self convertedObject:value.map_value]; + case TypeOrder::kArray: + return [self convertedArray:value.array_value]; + case TypeOrder::kReference: + return [self convertedReference:value]; + case TypeOrder::kTimestamp: + return [self convertedTimestamp:value.timestamp_value]; + case TypeOrder::kServerTimestamp: + return [self convertedServerTimestamp:value]; + case TypeOrder::kNull: + return [NSNull null]; + case TypeOrder::kBoolean: + return value.boolean_value ? @YES : @NO; + case TypeOrder::kNumber: + return value.which_value_type == google_firestore_v1_Value_integer_value_tag + ? @(value.integer_value) + : @(value.double_value); + case TypeOrder::kString: + return MakeNSString(MakeStringView(value.string_value)); + case TypeOrder::kBlob: + return MakeNSData(value.bytes_value); + case TypeOrder::kGeoPoint: + return MakeFIRGeoPoint( + GeoPoint(value.geo_point_value.latitude, value.geo_point_value.longitude)); + case TypeOrder::kVector: + return [self convertedVector:value.map_value]; + case TypeOrder::kMaxValue: + // It is not possible for users to construct a kMaxValue manually. + break; + } UNREACHABLE(); } @@ -123,6 +126,18 @@ - (id)convertedValue:(const google_firestore_v1_Value &)value { return result; } +- (FIRVectorValue *)convertedVector:(const google_firestore_v1_MapValue &)mapValue { + for (pb_size_t i = 0; i < mapValue.fields_count; ++i) { + absl::string_view key = MakeStringView(mapValue.fields[i].key); + const google_firestore_v1_Value &value = mapValue.fields[i].value; + if ((0 == key.compare(absl::string_view("value"))) + && value.which_value_type == google_firestore_v1_Value_array_value_tag) { + return [FIRFieldValue vectorFromNSNumbers:[self convertedArray:value.array_value]]; + } + } + return [FIRFieldValue vectorFromNSNumbers:@[]]; +} + - (NSArray *)convertedArray:(const google_firestore_v1_ArrayValue &)arrayValue { NSMutableArray *result = [NSMutableArray arrayWithCapacity:arrayValue.values_count]; for (pb_size_t i = 0; i < arrayValue.values_count; ++i) { diff --git a/Firestore/core/src/core/target.cc b/Firestore/core/src/core/target.cc index 76b9c625f57..54a9e127728 100644 --- a/Firestore/core/src/core/target.cc +++ b/Firestore/core/src/core/target.cc @@ -220,7 +220,7 @@ Target::IndexBoundValue Target::GetAscendingBound( case FieldFilter::Operator::LessThan: case FieldFilter::Operator::LessThanOrEqual: filter_value = - model::GetLowerBound(field_filter.value().which_value_type); + model::GetLowerBound(field_filter.value()); break; case FieldFilter::Operator::Equal: case FieldFilter::Operator::In: @@ -285,7 +285,7 @@ Target::IndexBoundValue Target::GetDescendingBound( case FieldFilter::Operator::GreaterThanOrEqual: case FieldFilter::Operator::GreaterThan: filter_value = - model::GetUpperBound(field_filter.value().which_value_type); + model::GetUpperBound(field_filter.value()); filter_inclusive = false; break; case FieldFilter::Operator::Equal: diff --git a/Firestore/core/src/model/value_util.cc b/Firestore/core/src/model/value_util.cc index 61c4a8c865f..bc7038b4bd6 100644 --- a/Firestore/core/src/model/value_util.cc +++ b/Firestore/core/src/model/value_util.cc @@ -44,16 +44,26 @@ namespace { pb_bytes_array_s* kMinimumReferenceValue = nanopb::MakeBytesArray("projects//databases//documents/"); -/** The field type of a maximum proto value. */ -const char* kRawMaxValueFieldKey = "__type__"; -pb_bytes_array_s* kMaxValueFieldKey = - nanopb::MakeBytesArray(kRawMaxValueFieldKey); +/** The field type of a special object type. */ +const char* kRawTypeValueFieldKey = "__type__"; +pb_bytes_array_s* kTypeValueFieldKey = + nanopb::MakeBytesArray(kRawTypeValueFieldKey); /** The field value of a maximum proto value. */ const char* kRawMaxValueFieldValue = "__max__"; pb_bytes_array_s* kMaxValueFieldValue = nanopb::MakeBytesArray(kRawMaxValueFieldValue); +/** The type of a VectorValue proto. */ +const char* kRawVectorTypeFieldValue = "__vector__"; +pb_bytes_array_s* kVectorTypeFieldValue = + nanopb::MakeBytesArray(kRawVectorTypeFieldValue); + +/** The value key of a VectorValue proto. */ +const char* kRawVectorValueFieldKey = "value"; +pb_bytes_array_s* kVectorValueFieldKey = + nanopb::MakeBytesArray(kRawVectorValueFieldKey); + } // namespace using nanopb::Message; @@ -85,15 +95,17 @@ TypeOrder GetTypeOrder(const google_firestore_v1_Value& value) { case google_firestore_v1_Value_geo_point_value_tag: return TypeOrder::kGeoPoint; - - case google_firestore_v1_Value_array_value_tag: - return TypeOrder::kArray; + + case google_firestore_v1_Value_array_value_tag: + return TypeOrder::kArray; case google_firestore_v1_Value_map_value_tag: { if (IsServerTimestamp(value)) { return TypeOrder::kServerTimestamp; } else if (IsMaxValue(value)) { return TypeOrder::kMaxValue; + } else if (IsVectorValue(value)) { + return TypeOrder::kVector; } return TypeOrder::kMap; } @@ -253,6 +265,24 @@ ComparisonResult CompareMaps(const google_firestore_v1_MapValue& left, return util::Compare(left_map->fields_count, right_map->fields_count); } +ComparisonResult CompareVectors(const google_firestore_v1_Value& left, + const google_firestore_v1_Value& right) { + HARD_ASSERT(IsVectorValue(left) && IsVectorValue(right), "Cannot compare non-vector values as vectors."); + + int64_t leftIndex = IndexOfKey(left.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey); + int64_t rightIndex = IndexOfKey(right.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey); + + google_firestore_v1_Value leftArray = left.map_value.fields[leftIndex].value; + google_firestore_v1_Value rightArray = right.map_value.fields[rightIndex].value; + + ComparisonResult lengthCompare = util::Compare(leftArray.array_value.values_count, rightArray.array_value.values_count); + if (lengthCompare != ComparisonResult::Same) { + return lengthCompare; + } + + return CompareArrays(leftArray, rightArray); +} + ComparisonResult Compare(const google_firestore_v1_Value& left, const google_firestore_v1_Value& right) { TypeOrder left_type = GetTypeOrder(left); @@ -293,9 +323,12 @@ ComparisonResult Compare(const google_firestore_v1_Value& left, case TypeOrder::kArray: return CompareArrays(left, right); - - case TypeOrder::kMap: - return CompareMaps(left.map_value, right.map_value); + + case TypeOrder::kMap: + return CompareMaps(left.map_value, right.map_value); + + case TypeOrder::kVector: + return CompareMaps(left.map_value, right.map_value); case TypeOrder::kMaxValue: return util::ComparisonResult::Same; @@ -425,6 +458,7 @@ bool Equals(const google_firestore_v1_Value& lhs, case TypeOrder::kArray: return ArrayEquals(lhs.array_value, rhs.array_value); + case TypeOrder::kVector: case TypeOrder::kMap: return MapValueEquals(lhs.map_value, rhs.map_value); @@ -539,106 +573,85 @@ std::string CanonicalId(const google_firestore_v1_ArrayValue& value) { return CanonifyArray(value); } -google_firestore_v1_Value GetLowerBound(pb_size_t value_tag) { - switch (value_tag) { +google_firestore_v1_Value GetLowerBound(const google_firestore_v1_Value& value) { + switch (value.which_value_type) { case google_firestore_v1_Value_null_value_tag: return NullValue(); case google_firestore_v1_Value_boolean_value_tag: { - google_firestore_v1_Value value; - value.which_value_type = value_tag; - value.boolean_value = false; - return value; + return MinBoolean(); } case google_firestore_v1_Value_integer_value_tag: case google_firestore_v1_Value_double_value_tag: { - return NaNValue(); + return MinNumber(); } case google_firestore_v1_Value_timestamp_value_tag: { - google_firestore_v1_Value value; - value.which_value_type = value_tag; - value.timestamp_value.seconds = std::numeric_limits::min(); - value.timestamp_value.nanos = 0; - return value; + return MinTimestamp(); } case google_firestore_v1_Value_string_value_tag: { - google_firestore_v1_Value value; - value.which_value_type = value_tag; - value.string_value = nullptr; - return value; + return MinString(); } case google_firestore_v1_Value_bytes_value_tag: { - google_firestore_v1_Value value; - value.which_value_type = value_tag; - value.bytes_value = nullptr; - return value; + return MinBytes(); } case google_firestore_v1_Value_reference_value_tag: { - google_firestore_v1_Value result; - result.which_value_type = google_firestore_v1_Value_reference_value_tag; - result.reference_value = kMinimumReferenceValue; - return result; + return MinReference(); } case google_firestore_v1_Value_geo_point_value_tag: { - google_firestore_v1_Value value; - value.which_value_type = value_tag; - value.geo_point_value.latitude = -90.0; - value.geo_point_value.longitude = -180.0; - return value; + return MinGeoPoint(); } case google_firestore_v1_Value_array_value_tag: { - google_firestore_v1_Value value; - value.which_value_type = value_tag; - value.array_value.values = nullptr; - value.array_value.values_count = 0; - return value; + return MinArray(); } case google_firestore_v1_Value_map_value_tag: { - google_firestore_v1_Value value; - value.which_value_type = value_tag; - value.map_value.fields = nullptr; - value.map_value.fields_count = 0; - return value; + if (IsVectorValue(value)) { + return MinVector(); + } + + return MinMap(); } default: - HARD_FAIL("Invalid type value: %s", value_tag); + HARD_FAIL("Invalid type value: %s", value.which_value_type); } } -google_firestore_v1_Value GetUpperBound(pb_size_t value_tag) { - switch (value_tag) { +google_firestore_v1_Value GetUpperBound(const google_firestore_v1_Value& value) { + switch (value.which_value_type) { case google_firestore_v1_Value_null_value_tag: - return GetLowerBound(google_protobuf_BoolValue_value_tag); + return MinBoolean(); case google_firestore_v1_Value_boolean_value_tag: - return GetLowerBound(google_firestore_v1_Value_integer_value_tag); + return MinNumber(); case google_firestore_v1_Value_integer_value_tag: case google_firestore_v1_Value_double_value_tag: - return GetLowerBound(google_firestore_v1_Value_timestamp_value_tag); + return MinTimestamp(); case google_firestore_v1_Value_timestamp_value_tag: - return GetLowerBound(google_firestore_v1_Value_string_value_tag); + return MinString(); case google_firestore_v1_Value_string_value_tag: - return GetLowerBound(google_firestore_v1_Value_bytes_value_tag); + return MinBytes(); case google_firestore_v1_Value_bytes_value_tag: - return GetLowerBound(google_firestore_v1_Value_reference_value_tag); + return MinReference(); case google_firestore_v1_Value_reference_value_tag: - return GetLowerBound(google_firestore_v1_Value_geo_point_value_tag); + return MinGeoPoint(); case google_firestore_v1_Value_geo_point_value_tag: - return GetLowerBound(google_firestore_v1_Value_array_value_tag); + return MinArray(); case google_firestore_v1_Value_array_value_tag: - return GetLowerBound(google_firestore_v1_Value_map_value_tag); + return MinVector(); case google_firestore_v1_Value_map_value_tag: + if (IsVectorValue(value)) { + return MinMap(); + } return MaxValue(); default: - HARD_FAIL("Invalid type value: %s", value_tag); + HARD_FAIL("Invalid type value: %s", value.which_value_type); } } @@ -693,7 +706,7 @@ google_firestore_v1_Value MaxValue() { "google_firestore_v1_MapValue_FieldsEntry should be " "trivially-destructible; otherwise, it should use NoDestructor below."); static google_firestore_v1_MapValue_FieldsEntry field_entry; - field_entry.key = kMaxValueFieldKey; + field_entry.key = kTypeValueFieldKey; field_entry.value = value; google_firestore_v1_MapValue map_value; @@ -718,9 +731,9 @@ bool IsMaxValue(const google_firestore_v1_Value& value) { // Comparing the pointer address, then actual content if addresses are // different. - if (value.map_value.fields[0].key != kMaxValueFieldKey && + if (value.map_value.fields[0].key != kTypeValueFieldKey && nanopb::MakeStringView(value.map_value.fields[0].key) != - kRawMaxValueFieldKey) { + kRawTypeValueFieldKey) { return false; } @@ -736,6 +749,57 @@ bool IsMaxValue(const google_firestore_v1_Value& value) { kRawMaxValueFieldValue; } +int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, const char* kRawTypeValueFieldKey, + pb_bytes_array_s* kTypeValueFieldKey) { + for (pb_size_t i = 0; i < mapValue.fields_count; i++) { + if (mapValue.fields[0].key != kTypeValueFieldKey && + nanopb::MakeStringView(mapValue.fields[0].key) != + kRawTypeValueFieldKey) { + return i; + } + } + + return -1; +} + +bool IsVectorValue(const google_firestore_v1_Value& value) { + if (value.which_value_type != google_firestore_v1_Value_map_value_tag) { + return false; + } + + if (value.map_value.fields_count < 2) { + return false; + } + + int64_t typeFieldIndex = -1; + if ((typeFieldIndex = IndexOfKey(value.map_value, kRawTypeValueFieldKey, kTypeValueFieldKey)) < 0) { + return false; + } + + if (value.map_value.fields[typeFieldIndex].value.which_value_type != + google_firestore_v1_Value_string_value_tag) { + return false; + } + + // Comparing the pointer address, then actual content if addresses are + // different. + return value.map_value.fields[typeFieldIndex].value.string_value == kVectorTypeFieldValue || + nanopb::MakeStringView(value.map_value.fields[typeFieldIndex].value.string_value) == + kRawVectorTypeFieldValue; + + int64_t valueFieldIndex = -1; + if ((valueFieldIndex = IndexOfKey(value.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey)) < 0) { + return false; + } + + if (value.map_value.fields[valueFieldIndex].value.which_value_type != + google_firestore_v1_Value_map_value_tag) { + return false; + } + + return true; +} + google_firestore_v1_Value NaNValue() { google_firestore_v1_Value nan_value; nan_value.which_value_type = google_firestore_v1_Value_double_value_tag; @@ -748,6 +812,78 @@ bool IsNaNValue(const google_firestore_v1_Value& value) { std::isnan(value.double_value); } +google_firestore_v1_Value MinBoolean() { + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_boolean_value_tag; + lowerBound.boolean_value = false; + return lowerBound; +} + +google_firestore_v1_Value MinNumber() { + return NaNValue(); +} + +google_firestore_v1_Value MinTimestamp() { + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_timestamp_value_tag; + lowerBound.timestamp_value.seconds = std::numeric_limits::min(); + lowerBound.timestamp_value.nanos = 0; + return lowerBound; +} + +google_firestore_v1_Value MinString() { + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_string_value_tag; + lowerBound.string_value = nullptr; + return lowerBound; +} + +google_firestore_v1_Value MinBytes() { + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_bytes_value_tag; + lowerBound.bytes_value = nullptr; + return lowerBound; +} + +google_firestore_v1_Value MinReference() { + google_firestore_v1_Value result; + result.which_value_type = google_firestore_v1_Value_reference_value_tag; + result.reference_value = kMinimumReferenceValue; + return result; +} + +google_firestore_v1_Value MinGeoPoint() { + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_geo_point_value_tag; + lowerBound.geo_point_value.latitude = -90.0; + lowerBound.geo_point_value.longitude = -180.0; + return lowerBound; +} + +google_firestore_v1_Value MinArray() { + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_array_value_tag; + lowerBound.array_value.values = nullptr; + lowerBound.array_value.values_count = 0; + return lowerBound; +} + +google_firestore_v1_Value MinVector() { + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_map_value_tag; + lowerBound.map_value.fields = nullptr; + lowerBound.map_value.fields_count = 0; + return lowerBound; +} + +google_firestore_v1_Value MinMap() { + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_map_value_tag; + lowerBound.map_value.fields = nullptr; + lowerBound.map_value.fields_count = 0; + return lowerBound; +} + Message RefValue( const model::DatabaseId& database_id, const model::DocumentKey& document_key) { diff --git a/Firestore/core/src/model/value_util.h b/Firestore/core/src/model/value_util.h index 91e26a21ebb..7ed6d076ac4 100644 --- a/Firestore/core/src/model/value_util.h +++ b/Firestore/core/src/model/value_util.h @@ -42,18 +42,19 @@ class DatabaseId; * ordering, but modified to support server timestamps. */ enum class TypeOrder { - kNull = 0, - kBoolean = 1, - kNumber = 2, - kTimestamp = 3, - kServerTimestamp = 4, - kString = 5, - kBlob = 6, - kReference = 7, - kGeoPoint = 8, - kArray = 9, - kMap = 10, - kMaxValue = 11 + kNull = 0, + kBoolean = 1, + kNumber = 2, + kTimestamp = 3, + kServerTimestamp = 4, + kString = 5, + kBlob = 6, + kReference = 7, + kGeoPoint = 8, + kArray = 9, + kVector = 10, + kMap = 11, + kMaxValue = 12 }; /** Returns the backend's type order of the given Value type. */ @@ -94,7 +95,7 @@ std::string CanonicalId(const google_firestore_v1_Value& value); * The returned value might point to heap allocated memory that is owned by * this function. To take ownership of this memory, call `DeepClone`. */ -google_firestore_v1_Value GetLowerBound(pb_size_t value_tag); +google_firestore_v1_Value GetLowerBound(const google_firestore_v1_Value& value); /** * Returns the largest value for the given value type (exclusive). @@ -102,7 +103,7 @@ google_firestore_v1_Value GetLowerBound(pb_size_t value_tag); * The returned value might point to heap allocated memory that is owned by * this function. To take ownership of this memory, call `DeepClone`. */ -google_firestore_v1_Value GetUpperBound(pb_size_t value_tag); +google_firestore_v1_Value GetUpperBound(const google_firestore_v1_Value& value); /** * Generates the canonical ID for the provided array value (as used in Target @@ -155,6 +156,14 @@ google_firestore_v1_Value MaxValue(); */ bool IsMaxValue(const google_firestore_v1_Value& value); +/** + * Returns `true` if `value` represents a VectorValue.. + */ +bool IsVectorValue(const google_firestore_v1_Value& value); + +int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, const char* kRawTypeValueFieldKey, + pb_bytes_array_s* kTypeValueFieldKey); + /** * Returns `NaN` in its Protobuf representation. * @@ -166,6 +175,26 @@ google_firestore_v1_Value NaNValue(); /** Returns `true` if `value` is `NaN` in its Protobuf representation. */ bool IsNaNValue(const google_firestore_v1_Value& value); +google_firestore_v1_Value MinBoolean(); + +google_firestore_v1_Value MinNumber(); + +google_firestore_v1_Value MinTimestamp(); + +google_firestore_v1_Value MinString(); + +google_firestore_v1_Value MinBytes(); + +google_firestore_v1_Value MinReference(); + +google_firestore_v1_Value MinGeoPoint(); + +google_firestore_v1_Value MinArray(); + +google_firestore_v1_Value MinVector(); + +google_firestore_v1_Value MinMap(); + /** * Returns a Protobuf reference value representing the given location. * diff --git a/Firestore/core/test/unit/model/value_util_test.cc b/Firestore/core/test/unit/model/value_util_test.cc index d4db43dfe20..b2499eb6f9a 100644 --- a/Firestore/core/test/unit/model/value_util_test.cc +++ b/Firestore/core/test/unit/model/value_util_test.cc @@ -272,7 +272,7 @@ TEST_F(ValueUtilTest, StrictOrdering) { // numbers Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_integer_value_tag))); + DeepClone(MinNumber())); Add(comparison_groups, -1e20); Add(comparison_groups, std::numeric_limits::min()); Add(comparison_groups, -0.1); @@ -286,7 +286,7 @@ TEST_F(ValueUtilTest, StrictOrdering) { // dates Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_timestamp_value_tag))); + DeepClone(MinTimestamp())); Add(comparison_groups, kTimestamp1); Add(comparison_groups, kTimestamp2); @@ -317,7 +317,7 @@ TEST_F(ValueUtilTest, StrictOrdering) { // resource names Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_reference_value_tag))); + DeepClone(MinReference())); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c1/doc1"))); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c1/doc2"))); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c10/doc1"))); @@ -341,7 +341,7 @@ TEST_F(ValueUtilTest, StrictOrdering) { // arrays Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_array_value_tag))); + DeepClone(MinArray())); Add(comparison_groups, Array("bar")); Add(comparison_groups, Array("foo", 1)); Add(comparison_groups, Array("foo", 2)); @@ -349,14 +349,14 @@ TEST_F(ValueUtilTest, StrictOrdering) { // objects Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_map_value_tag))); + DeepClone(MinMap())); Add(comparison_groups, Map("bar", 0)); Add(comparison_groups, Map("bar", 0, "foo", 1)); Add(comparison_groups, Map("foo", 1)); Add(comparison_groups, Map("foo", 2)); Add(comparison_groups, Map("foo", "0")); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_map_value_tag))); + DeepClone(MaxValue())); for (size_t i = 0; i < comparison_groups.size(); ++i) { for (size_t j = i; j < comparison_groups.size(); ++j) { @@ -378,24 +378,24 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { // null first Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_null_value_tag))); + DeepClone(NullValue())); Add(comparison_groups, nullptr); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_null_value_tag))); + DeepClone(MinBoolean())); // booleans Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_boolean_value_tag))); + DeepClone(MinBoolean())); Add(comparison_groups, false); Add(comparison_groups, true); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_boolean_value_tag))); + DeepClone(MinNumber())); // numbers Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_integer_value_tag))); + DeepClone(MinNumber())); Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_double_value_tag))); + DeepClone(MinNumber())); Add(comparison_groups, -1e20); Add(comparison_groups, std::numeric_limits::min()); Add(comparison_groups, -0.1); @@ -407,13 +407,13 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { Add(comparison_groups, std::numeric_limits::max()); Add(comparison_groups, 1e20); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_integer_value_tag))); + DeepClone(MinTimestamp())); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_double_value_tag))); + DeepClone(MinTimestamp())); // dates Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_timestamp_value_tag))); + DeepClone(MinTimestamp())); Add(comparison_groups, kTimestamp1); Add(comparison_groups, kTimestamp2); @@ -422,11 +422,11 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { Add(comparison_groups, EncodeServerTimestamp(kTimestamp1, absl::nullopt)); Add(comparison_groups, EncodeServerTimestamp(kTimestamp2, absl::nullopt)); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_timestamp_value_tag))); + DeepClone(MinString())); // strings Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_string_value_tag))); + DeepClone(MinString())); Add(comparison_groups, ""); Add(comparison_groups, "\001\ud7ff\ue000\uffff"); Add(comparison_groups, "(╯°□°)╯︵ ┻━┻"); @@ -439,22 +439,22 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { // latin small letter e with acute accent + latin small letter a Add(comparison_groups, "\u00e9a"); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_string_value_tag))); + DeepClone(MinBytes())); // blobs Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_bytes_value_tag))); + DeepClone(MinBytes())); Add(comparison_groups, BlobValue()); Add(comparison_groups, BlobValue(0)); Add(comparison_groups, BlobValue(0, 1, 2, 3, 4)); Add(comparison_groups, BlobValue(0, 1, 2, 4, 3)); Add(comparison_groups, BlobValue(255)); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_bytes_value_tag))); + DeepClone(MinReference())); // resource names Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_reference_value_tag))); + DeepClone(MinReference())); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c1/doc1"))); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c1/doc2"))); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c10/doc1"))); @@ -462,11 +462,11 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { Add(comparison_groups, RefValue(DbId("p1/d2"), Key("c1/doc1"))); Add(comparison_groups, RefValue(DbId("p2/d1"), Key("c1/doc1"))); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_reference_value_tag))); + DeepClone(MinGeoPoint())); // geo points Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_geo_point_value_tag))); + DeepClone(MinGeoPoint())); Add(comparison_groups, GeoPoint(-90, -180)); Add(comparison_groups, GeoPoint(-90, 0)); Add(comparison_groups, GeoPoint(-90, 180)); @@ -480,28 +480,30 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { Add(comparison_groups, GeoPoint(90, 0)); Add(comparison_groups, GeoPoint(90, 180)); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_geo_point_value_tag))); + DeepClone(MinArray())); // arrays Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_array_value_tag))); + DeepClone(MinArray())); Add(comparison_groups, Array("bar")); Add(comparison_groups, Array("foo", 1)); Add(comparison_groups, Array("foo", 2)); Add(comparison_groups, Array("foo", "0")); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_array_value_tag))); + DeepClone(MinVector())); + // TODO vector + // objects Add(comparison_groups, - DeepClone(GetLowerBound(google_firestore_v1_Value_map_value_tag))); + DeepClone(MinMap())); Add(comparison_groups, Map("bar", 0)); Add(comparison_groups, Map("bar", 0, "foo", 1)); Add(comparison_groups, Map("foo", 1)); Add(comparison_groups, Map("foo", 2)); Add(comparison_groups, Map("foo", "0")); Add(comparison_groups, - DeepClone(GetUpperBound(google_firestore_v1_Value_map_value_tag))); + DeepClone(MaxValue())); for (size_t i = 0; i < comparison_groups.size(); ++i) { for (size_t j = i; j < comparison_groups.size(); ++j) { From 0725ad0c0e2e08ac7912c0d74d064df725f07800 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Wed, 24 Jul 2024 18:44:45 -0600 Subject: [PATCH 03/24] VectorValue tests and fixes. --- .../Firestore.xcodeproj/project.pbxproj | 14 +- .../xcschemes/Firestore_Tests_iOS.xcscheme | 26 ++- Firestore/Example/GoogleService-Info.plist | 30 ++++ .../Source/API/FIRVectorValue+Internal.h | 5 +- Firestore/Source/API/FIRVectorValue.m | 6 +- Firestore/Source/API/FSTUserDataReader.mm | 6 +- .../Public/FirebaseFirestore/FIRVectorValue.h | 9 +- .../Codable/CodablePassThroughTypes.swift | 3 +- .../Source/Codable/VectorValue+Codable.swift | 61 +++++++ .../Source/SwiftAPI/VectorValue+Swift.swift | 11 +- .../Integration/CodableIntegrationTests.swift | 25 ++- .../SnapshotListenerSourceTests.swift | 53 ++++++ .../Integration/VectorIntegrationTests.swift | 162 ++++++++++++++++++ .../src/index/firestore_index_value_writer.cc | 26 +++ Firestore/core/src/model/value_util.cc | 38 ++-- Firestore/core/src/model/value_util.h | 20 +++ .../unit/local/leveldb_index_manager_test.cc | 40 +++++ .../core/test/unit/model/value_util_test.cc | 28 ++- .../core/test/unit/remote/serializer_test.cc | 18 ++ Firestore/core/test/unit/testutil/testutil.h | 5 + 20 files changed, 537 insertions(+), 49 deletions(-) create mode 100644 Firestore/Example/GoogleService-Info.plist create mode 100644 Firestore/Swift/Source/Codable/VectorValue+Codable.swift create mode 100644 Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift diff --git a/Firestore/Example/Firestore.xcodeproj/project.pbxproj b/Firestore/Example/Firestore.xcodeproj/project.pbxproj index a6c521d5c07..db7bf9cb644 100644 --- a/Firestore/Example/Firestore.xcodeproj/project.pbxproj +++ b/Firestore/Example/Firestore.xcodeproj/project.pbxproj @@ -1541,6 +1541,9 @@ EF79998EBE4C72B97AB1880E /* value_util_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 40F9D09063A07F710811A84F /* value_util_test.cc */; }; EF8C005DC4BEA6256D1DBC6F /* user_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = CCC9BD953F121B9E29F9AA42 /* user_test.cc */; }; EFD682178A87513A5F1AEFD9 /* memory_query_engine_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 8EF6A33BC2D84233C355F1D0 /* memory_query_engine_test.cc */; }; + EFF22EAA2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EFF22EA92C5060A4009A369B /* VectorIntegrationTests.swift */; }; + EFF22EAB2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EFF22EA92C5060A4009A369B /* VectorIntegrationTests.swift */; }; + EFF22EAC2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EFF22EA92C5060A4009A369B /* VectorIntegrationTests.swift */; }; F05B277F16BDE6A47FE0F943 /* local_serializer_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = F8043813A5D16963EC02B182 /* local_serializer_test.cc */; }; F08DA55D31E44CB5B9170CCE /* limbo_spec_test.json in Resources */ = {isa = PBXBuildFile; fileRef = 54DA129E1F315EE100DD57A1 /* limbo_spec_test.json */; }; F091532DEE529255FB008E25 /* snapshot_version_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = ABA495B9202B7E79008A7851 /* snapshot_version_test.cc */; }; @@ -1693,7 +1696,7 @@ 132E32997D781B896672D30A /* reference_set_test.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = reference_set_test.cc; sourceTree = ""; }; 166CE73C03AB4366AAC5201C /* leveldb_index_manager_test.cc */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.cpp; path = leveldb_index_manager_test.cc; sourceTree = ""; }; 1A7D48A017ECB54FD381D126 /* Validation_BloomFilterTest_MD5_5000_1_membership_test_result.json */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.json; name = Validation_BloomFilterTest_MD5_5000_1_membership_test_result.json; path = bloom_filter_golden_test_data/Validation_BloomFilterTest_MD5_5000_1_membership_test_result.json; sourceTree = ""; }; - 1A8141230C7E3986EACEF0B6 /* thread_safe_memoizer_test.cc */ = {isa = PBXFileReference; includeInIndex = 1; path = thread_safe_memoizer_test.cc; sourceTree = ""; }; + 1A8141230C7E3986EACEF0B6 /* thread_safe_memoizer_test.cc */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.cpp; path = thread_safe_memoizer_test.cc; sourceTree = ""; }; 1B342370EAE3AA02393E33EB /* cc_compilation_test.cc */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.cpp; name = cc_compilation_test.cc; path = api/cc_compilation_test.cc; sourceTree = ""; }; 1B9F95EC29FAD3F100EEC075 /* FIRAggregateQueryUnitTests.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = FIRAggregateQueryUnitTests.mm; sourceTree = ""; }; 1C01D8CE367C56BB2624E299 /* index.pb.h */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.c.h; name = index.pb.h; path = admin/index.pb.h; sourceTree = ""; }; @@ -1750,7 +1753,7 @@ 4BD051DBE754950FEAC7A446 /* Validation_BloomFilterTest_MD5_500_01_bloom_filter_proto.json */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.json; name = Validation_BloomFilterTest_MD5_500_01_bloom_filter_proto.json; path = bloom_filter_golden_test_data/Validation_BloomFilterTest_MD5_500_01_bloom_filter_proto.json; sourceTree = ""; }; 4C73C0CC6F62A90D8573F383 /* string_apple_benchmark.mm */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.objcpp; path = string_apple_benchmark.mm; sourceTree = ""; }; 4D65F6E69993611D47DC8E7C /* SnapshotListenerSourceTests.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = SnapshotListenerSourceTests.swift; sourceTree = ""; }; - 4D9E51DA7A275D8B1CAEAEB2 /* listen_source_spec_test.json */ = {isa = PBXFileReference; includeInIndex = 1; path = listen_source_spec_test.json; sourceTree = ""; }; + 4D9E51DA7A275D8B1CAEAEB2 /* listen_source_spec_test.json */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.json; path = listen_source_spec_test.json; sourceTree = ""; }; 4F5B96F3ABCD2CA901DB1CD4 /* bundle_builder.cc */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.cpp; path = bundle_builder.cc; sourceTree = ""; }; 526D755F65AC676234F57125 /* target_test.cc */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.cpp; path = target_test.cc; sourceTree = ""; }; 52756B7624904C36FBB56000 /* fake_target_metadata_provider.h */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.c.h; path = fake_target_metadata_provider.h; sourceTree = ""; }; @@ -1904,7 +1907,7 @@ 62E54B832A9E910A003347C8 /* IndexingTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = IndexingTests.swift; sourceTree = ""; }; 63136A2371C0C013EC7A540C /* target_index_matcher_test.cc */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.cpp; path = target_index_matcher_test.cc; sourceTree = ""; }; 64AA92CFA356A2360F3C5646 /* filesystem_testing.h */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.c.h; path = filesystem_testing.h; sourceTree = ""; }; - 65AF0AB593C3AD81A1F1A57E /* FIRCompositeIndexQueryTests.mm */ = {isa = PBXFileReference; includeInIndex = 1; path = FIRCompositeIndexQueryTests.mm; sourceTree = ""; }; + 65AF0AB593C3AD81A1F1A57E /* FIRCompositeIndexQueryTests.mm */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.objcpp; path = FIRCompositeIndexQueryTests.mm; sourceTree = ""; }; 67786C62C76A740AEDBD8CD3 /* FSTTestingHooks.h */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.c.h; path = FSTTestingHooks.h; sourceTree = ""; }; 69E6C311558EC77729A16CF1 /* Pods-Firestore_Example_iOS-Firestore_SwiftTests_iOS.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Firestore_Example_iOS-Firestore_SwiftTests_iOS.debug.xcconfig"; path = "Pods/Target Support Files/Pods-Firestore_Example_iOS-Firestore_SwiftTests_iOS/Pods-Firestore_Example_iOS-Firestore_SwiftTests_iOS.debug.xcconfig"; sourceTree = ""; }; 6A7A30A2DB3367E08939E789 /* bloom_filter.pb.h */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.c.h; path = bloom_filter.pb.h; sourceTree = ""; }; @@ -2090,6 +2093,7 @@ EF6C285029E462A200A7D4F1 /* FIRAggregateTests.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = FIRAggregateTests.mm; sourceTree = ""; }; EF6C286C29E6D22200A7D4F1 /* AggregationIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AggregationIntegrationTests.swift; sourceTree = ""; }; EF83ACD5E1E9F25845A9ACED /* leveldb_migrations_test.cc */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.cpp; path = leveldb_migrations_test.cc; sourceTree = ""; }; + EFF22EA92C5060A4009A369B /* VectorIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VectorIntegrationTests.swift; sourceTree = ""; }; F02F734F272C3C70D1307076 /* filter_test.cc */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.cpp; path = filter_test.cc; sourceTree = ""; }; F119BDDF2F06B3C0883B8297 /* firebase_app_check_credentials_provider_test.mm */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.cpp.objcpp; name = firebase_app_check_credentials_provider_test.mm; path = credentials/firebase_app_check_credentials_provider_test.mm; sourceTree = ""; }; F354C0FE92645B56A6C6FD44 /* Pods-Firestore_IntegrationTests_iOS.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Firestore_IntegrationTests_iOS.release.xcconfig"; path = "Pods/Target Support Files/Pods-Firestore_IntegrationTests_iOS/Pods-Firestore_IntegrationTests_iOS.release.xcconfig"; sourceTree = ""; }; @@ -2233,6 +2237,7 @@ 62E54B832A9E910A003347C8 /* IndexingTests.swift */, 621D620928F9CE7400D2FA26 /* QueryIntegrationTests.swift */, 4D65F6E69993611D47DC8E7C /* SnapshotListenerSourceTests.swift */, + EFF22EA92C5060A4009A369B /* VectorIntegrationTests.swift */, ); path = Integration; sourceTree = ""; @@ -4638,6 +4643,7 @@ AECCD9663BB3DC52199F954A /* executor_std_test.cc in Sources */, 18F644E6AA98E6D6F3F1F809 /* executor_test.cc in Sources */, 6938575C8B5E6FE0D562547A /* exponential_backoff_test.cc in Sources */, + EFF22EAC2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 258B372CF33B7E7984BBA659 /* fake_target_metadata_provider.cc in Sources */, F8BD2F61EFA35C2D5120D9EB /* field_index_test.cc in Sources */, F272A8C41D2353700A11D1FB /* field_mask_test.cc in Sources */, @@ -4879,6 +4885,7 @@ 17DFF30CF61D87883986E8B6 /* executor_std_test.cc in Sources */, 814724DE70EFC3DDF439CD78 /* executor_test.cc in Sources */, BD6CC8614970A3D7D2CF0D49 /* exponential_backoff_test.cc in Sources */, + EFF22EAB2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 4D2655C5675D83205C3749DC /* fake_target_metadata_provider.cc in Sources */, 50C852E08626CFA7DC889EEA /* field_index_test.cc in Sources */, A1563EFEB021936D3FFE07E3 /* field_mask_test.cc in Sources */, @@ -5367,6 +5374,7 @@ 125B1048ECB755C2106802EB /* executor_std_test.cc in Sources */, DABB9FB61B1733F985CBF713 /* executor_test.cc in Sources */, 7BCF050BA04537B0E7D44730 /* exponential_backoff_test.cc in Sources */, + EFF22EAA2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, BA1C5EAE87393D8E60F5AE6D /* fake_target_metadata_provider.cc in Sources */, 84285C3F63D916A4786724A8 /* field_index_test.cc in Sources */, 6A40835DB2C02B9F07C02E88 /* field_mask_test.cc in Sources */, diff --git a/Firestore/Example/Firestore.xcodeproj/xcshareddata/xcschemes/Firestore_Tests_iOS.xcscheme b/Firestore/Example/Firestore.xcodeproj/xcshareddata/xcschemes/Firestore_Tests_iOS.xcscheme index b4ac9bed6dd..98b55265c5e 100644 --- a/Firestore/Example/Firestore.xcodeproj/xcshareddata/xcschemes/Firestore_Tests_iOS.xcscheme +++ b/Firestore/Example/Firestore.xcodeproj/xcshareddata/xcschemes/Firestore_Tests_iOS.xcscheme @@ -26,8 +26,17 @@ buildConfiguration = "Debug" selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB" selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB" - enableASanStackUseAfterReturn = "YES" - shouldUseLaunchSchemeArgsEnv = "YES"> + shouldUseLaunchSchemeArgsEnv = "YES" + enableASanStackUseAfterReturn = "YES"> + + + + @@ -40,17 +49,6 @@ - - - - - - - - + + + + API_KEY + AIzaSyCPOtsd7zCEEnc90uXunosAcu1-EKtWYYw + GCM_SENDER_ID + 37794738117 + PLIST_VERSION + 1 + BUNDLE_ID + com.google.firebase.FirestoreSample + PROJECT_ID + markduckworth-firestore-tests + STORAGE_BUCKET + markduckworth-firestore-tests.appspot.com + IS_ADS_ENABLED + + IS_ANALYTICS_ENABLED + + IS_APPINVITE_ENABLED + + IS_GCM_ENABLED + + IS_SIGNIN_ENABLED + + GOOGLE_APP_ID + 1:37794738117:ios:f1861da1590da1e46125ac + + \ No newline at end of file diff --git a/Firestore/Source/API/FIRVectorValue+Internal.h b/Firestore/Source/API/FIRVectorValue+Internal.h index 8774da83cb7..23a6243f566 100644 --- a/Firestore/Source/API/FIRVectorValue+Internal.h +++ b/Firestore/Source/API/FIRVectorValue+Internal.h @@ -20,9 +20,8 @@ NS_ASSUME_NONNULL_BEGIN @interface FIRVectorValue (Internal) -- (instancetype)init NS_UNAVAILABLE; -- (instancetype)initWithNSNumbers: (NSArray *)values; -- (NSArray *)toNSArray; +// - (instancetype)init NS_UNAVAILABLE; +// - (NSArray *)toNSArray; @end diff --git a/Firestore/Source/API/FIRVectorValue.m b/Firestore/Source/API/FIRVectorValue.m index ec63aecf43b..0b371774953 100644 --- a/Firestore/Source/API/FIRVectorValue.m +++ b/Firestore/Source/API/FIRVectorValue.m @@ -32,11 +32,11 @@ @interface FIRVectorValue () { @implementation FIRVectorValue -- (instancetype)initWithNSNumbers: (NSArray *)values { +- (instancetype)initWithNSNumbers: (NSArray *)data { if (self = [super init]) { std::vector converted; - converted.reserve(values.count); - for (NSNumber *value in values) { + converted.reserve(data.count); + for (NSNumber *value in data) { converted.emplace_back([value doubleValue]); } diff --git a/Firestore/Source/API/FSTUserDataReader.mm b/Firestore/Source/API/FSTUserDataReader.mm index b175131c05b..115a00c9332 100644 --- a/Firestore/Source/API/FSTUserDataReader.mm +++ b/Firestore/Source/API/FSTUserDataReader.mm @@ -349,7 +349,7 @@ - (ParsedUpdateData)parsedUpdateData:(id)input { result->map_value.fields_count = 2; result->map_value.fields = nanopb::MakeArray(2); - result->map_value.fields[0].key = nanopb::MakeBytesArray(MakeString(@"__type__")); + result->map_value.fields[0].key = nanopb::CopyBytesArray(model::kTypeValueFieldKey); result->map_value.fields[0].value = *[self encodeStringValue:MakeString(@"__vector__")].release(); NSArray *vectorArray = [vectorValue toNSArray]; @@ -367,10 +367,10 @@ - (ParsedUpdateData)parsedUpdateData:(id)input { } // Vector values must always use Double encoding - result->array_value.values[idx] = *[self encodeDouble:[entry doubleValue]].release(); + arrayMessage->array_value.values[idx] = *[self encodeDouble:[entry doubleValue]].release(); }]; - result->map_value.fields[1].key = nanopb::MakeBytesArray(MakeString(@"value")); + result->map_value.fields[1].key = nanopb::CopyBytesArray(model::kVectorValueFieldKey); result->map_value.fields[1].value = *arrayMessage.release(); return std::move(result); diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h index 8ed768e0f23..746965fea4d 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h @@ -20,9 +20,14 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(VectorValue) @interface FIRVectorValue : NSObject -- (instancetype)init __attribute__((unavailable("FIRVectorValue cannot be created directly."))); -- (NSArray *)toNSArray NS_SWIFT_NAME(toNSArray()); +/** :nodoc: */ +- (instancetype)init NS_UNAVAILABLE; + +// Public initializer is required to support Codable +- (instancetype)initWithNSNumbers: (NSArray *)data NS_REFINED_FOR_SWIFT; + +- (NSArray *)toNSArray NS_REFINED_FOR_SWIFT; @end diff --git a/Firestore/Swift/Source/Codable/CodablePassThroughTypes.swift b/Firestore/Swift/Source/Codable/CodablePassThroughTypes.swift index 3640533ef0f..e9f7b87c0bf 100644 --- a/Firestore/Swift/Source/Codable/CodablePassThroughTypes.swift +++ b/Firestore/Swift/Source/Codable/CodablePassThroughTypes.swift @@ -29,6 +29,7 @@ struct FirestorePassthroughTypes: StructureCodingPassthroughTypeResolver { t is GeoPoint || t is Timestamp || t is FieldValue || - t is DocumentReference + t is DocumentReference || + t is VectorValue } } diff --git a/Firestore/Swift/Source/Codable/VectorValue+Codable.swift b/Firestore/Swift/Source/Codable/VectorValue+Codable.swift new file mode 100644 index 00000000000..6f8b532c846 --- /dev/null +++ b/Firestore/Swift/Source/Codable/VectorValue+Codable.swift @@ -0,0 +1,61 @@ +/* + * Copyright 2024 Google + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if SWIFT_PACKAGE + @_exported import FirebaseFirestoreInternalWrapper +#else + @_exported import FirebaseFirestoreInternal +#endif // SWIFT_PACKAGE + +/** + * A protocol describing the encodable properties of a VectorValue. + */ +private protocol CodableVectorValue: Codable { + var data: [Double] { get } + + init(__nsNumbers: [NSNumber]) +} + +/** The keys in a Timestamp. Must match the properties of CodableTimestamp. */ +private enum VectorValueKeys: String, CodingKey { + case data +} + +/** + * An extension of VectorValue that implements the behavior of the Codable protocol. + * + * Note: this is implemented manually here because the Swift compiler can't synthesize these methods + * when declaring an extension to conform to Codable. + */ +extension CodableVectorValue { + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: VectorValueKeys.self) + let data = try container.decode([Double].self, forKey: .data) + + let array = data.map { double in + return NSNumber(value: double) + } + self.init(__nsNumbers: array) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: VectorValueKeys.self) + try container.encode(data, forKey: .data) + } +} + +/** Extends VectorValue to conform to Codable. */ +extension VectorValue: CodableVectorValue{} diff --git a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift index 6e43cf628a4..147e1349209 100644 --- a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift @@ -21,10 +21,17 @@ #endif // SWIFT_PACKAGE public extension VectorValue { + convenience init(_ doubles: [Double]) { + let array = doubles.map { float in + return NSNumber(value: float) + } + + self.init(__nsNumbers: array) + } /// Returns a raw number array representation of the vector. /// - Returns: An array of Double values representing the vector. - func toArray() -> [Double] { - return self.toNSArray().map { Double(truncating: $0) } + var data: [Double] { + return self.__toNSArray().map { Double(truncating: $0) } } } diff --git a/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift b/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift index 7aa6ae527b9..01634dfcc21 100644 --- a/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift @@ -82,13 +82,15 @@ class CodableIntegrationTests: FSTIntegrationTestCase { var ts: Timestamp var geoPoint: GeoPoint var docRef: DocumentReference + var vector: VectorValue } let docToWrite = documentRef() let model = Model(name: "test", age: 42, ts: Timestamp(seconds: 987_654_321, nanoseconds: 0), geoPoint: GeoPoint(latitude: 45, longitude: 54), - docRef: docToWrite) + docRef: docToWrite, + vector: FieldValue.vector([0.7, 0.6])) for flavor in allFlavors { try setData(from: model, forDocument: docToWrite, withFlavor: flavor) @@ -183,6 +185,27 @@ class CodableIntegrationTests: FSTIntegrationTestCase { XCTAssertEqual(data["intValue"] as! Int, 3, "Failed with flavor \(flavor)") } } + + func testVectorValue() throws { + struct Model: Codable { + var name: String + var embedding: VectorValue + } + let model = Model( + name: "name", + embedding: FieldValue.vector([0.1, 0.3, 0.4]) + ) + + let docToWrite = documentRef() + + for flavor in allFlavors { + try setData(from: model, forDocument: docToWrite, withFlavor: flavor) + + let data = try readDocument(forRef: docToWrite).data(as: Model.self) + + XCTAssertEqual(data.embedding, VectorValue([0.1, 0.3, 0.4]), "Failed with flavor \(flavor)") + } + } func testDataBlob() throws { struct Model: Encodable { diff --git a/Firestore/Swift/Tests/Integration/SnapshotListenerSourceTests.swift b/Firestore/Swift/Tests/Integration/SnapshotListenerSourceTests.swift index 7b5c7812a20..0a6bebf7091 100644 --- a/Firestore/Swift/Tests/Integration/SnapshotListenerSourceTests.swift +++ b/Firestore/Swift/Tests/Integration/SnapshotListenerSourceTests.swift @@ -673,4 +673,57 @@ class SnapshotListenerSourceTests: FSTIntegrationTestCase { defaultRegistration.remove() cacheRegistration.remove() } + + + func testListenToDocumentsWithVectors() throws { + let collection = collectionRef() + let doc = collection.document() + + let registration = collection.whereField("purpose", isEqualTo: "vector tests").addSnapshotListener(eventAccumulator.valueEventHandler) + + var querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, true) + + doc.setData([ + "purpose": "vector tests", + "vector0": FieldValue.vector([0.0]), + "vector1": FieldValue.vector([1, 2, 3.99]) + ]) + + querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, false) + XCTAssertEqual(querySnap.documents[0].data()["vector0"] as! VectorValue, FieldValue.vector([0.0])) + XCTAssertEqual(querySnap.documents[0].data()["vector1"] as! VectorValue, FieldValue.vector([1, 2, 3.99])) + + doc.setData([ + "purpose": "vector tests", + "vector0": FieldValue.vector([0.0]), + "vector1": FieldValue.vector([1, 2, 3.99]), + "vector2": FieldValue.vector([0.0, 0, 0]) + ]) + + querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, false) + XCTAssertEqual(querySnap.documents[0].data()["vector0"] as! VectorValue, FieldValue.vector([0.0])) + XCTAssertEqual(querySnap.documents[0].data()["vector1"] as! VectorValue, FieldValue.vector([1, 2, 3.99])) + XCTAssertEqual(querySnap.documents[0].data()["vector2"] as! VectorValue, FieldValue.vector([0.0, 0, 0])) + + doc.updateData([ + "vector3": FieldValue.vector([-1, -200, -999.0]) + ]) + + querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, false) + XCTAssertEqual(querySnap.documents[0].data()["vector0"] as! VectorValue, FieldValue.vector([0.0])) + XCTAssertEqual(querySnap.documents[0].data()["vector1"] as! VectorValue, FieldValue.vector([1, 2, 3.99])) + XCTAssertEqual(querySnap.documents[0].data()["vector2"] as! VectorValue, FieldValue.vector([0.0, 0, 0])) + XCTAssertEqual(querySnap.documents[0].data()["vector3"] as! VectorValue, FieldValue.vector([-1, -200, -999.0])) + + doc.delete() + querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, true) + + eventAccumulator.assertNoAdditionalEvents() + registration.remove() + } } diff --git a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift new file mode 100644 index 00000000000..d4c3bd29bf3 --- /dev/null +++ b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift @@ -0,0 +1,162 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import FirebaseFirestore +import Foundation +import Combine + +@available(iOS 15, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *) +class VectorIntegrationTests: FSTIntegrationTestCase { + func testWriteAndReadVectorEmbeddings() async throws { + let collection = collectionRef() + + let ref = try await collection.addDocument(data: [ + "vector0": FieldValue.vector([0.0]), + "vector1": FieldValue.vector([1, 2, 3.99])]) + + try await ref.setData([ + "vector0": FieldValue.vector([0.0]), + "vector1": FieldValue.vector([1, 2, 3.99]), + "vector2": FieldValue.vector([0, 0, 0] as Array)]) + + try await ref.updateData([ + "vector3": FieldValue.vector([-1, -200, -999] as Array)]) + + let snapshot = try await ref.getDocument(); + XCTAssertEqual(snapshot.get("vector0") as? VectorValue, FieldValue.vector([0.0])) + XCTAssertEqual(snapshot.get("vector1") as? VectorValue, FieldValue.vector([1, 2, 3.99])) + XCTAssertEqual(snapshot.get("vector2") as? VectorValue, FieldValue.vector([0, 0, 0] as Array)) + XCTAssertEqual(snapshot.get("vector3") as? VectorValue, FieldValue.vector([-1, -200, -999] as Array)) + } + + func testSdkOrdersVectorFieldSameWayAsBackend() async throws { + let collection = collectionRef() + + var docsInOrder: [[String:Any]] = [ + ["embedding": [1, 2, 3, 4, 5, 6]], + ["embedding": [100] ], + ["embedding": FieldValue.vector([Double.infinity * -1]) ], + ["embedding": FieldValue.vector([-100.0]) ], + ["embedding": FieldValue.vector([100.0]) ], + ["embedding": FieldValue.vector([Double.infinity]) ], + ["embedding": FieldValue.vector([1, 2.0]) ], + ["embedding": FieldValue.vector([2, 2.0]) ], + ["embedding": FieldValue.vector([1, 2, 3.0]) ], + ["embedding": FieldValue.vector([1, 2, 3, 4.0]) ], + ["embedding": FieldValue.vector([1, 2, 3, 4, 5.0]) ], + ["embedding": FieldValue.vector([1, 2, 100, 4, 4.0]) ], + ["embedding": FieldValue.vector([100, 2, 3, 4, 5.0]) ], + ["embedding": [ "HELLO": "WORLD" ] ], + ["embedding": [ "hello": "world" ] ] + ] + + var docs: [[String:Any]] = []; + for data in docsInOrder { + let docRef = try await collection.addDocument(data: data) + docs.append(["id": docRef.documentID, "value": data]) + } + + // We validate that the SDK orders the vector field the same way as the backend + // by comparing the sort order of vector fields from getDocsFromServer and + // onSnapshot. onSnapshot will return sort order of the SDK, + // and getDocsFromServer will return sort order of the backend. + + let orderedQuery = collection.order(by: "embedding") + + + let watchSnapshot = try await Future() { promise in + orderedQuery.addSnapshotListener { snapshot, error in + if let error { + promise(Result.failure(error)) + } + if let snapshot { + promise(Result.success(snapshot)) + } + } + }.value + + let getSnapshot = try await orderedQuery.getDocuments(source: .server) + + // Compare the snapshot (including sort order) of a snapshot + // from Query.onSnapshot() to an actual snapshot from Query.get() + XCTAssertEqual(watchSnapshot.count, getSnapshot.count) + for i in (0.. #include @@ -46,6 +47,7 @@ enum IndexType { kReference = 37, kGeopoint = 45, kArray = 50, + kVector = 53, kMap = 55, kReferenceSegment = 60, // A terminator that indicates that a truncatable value was not truncated. @@ -105,6 +107,27 @@ void WriteIndexArray(const google_firestore_v1_ArrayValue& array_index_value, } } +void WriteIndexVector(const google_firestore_v1_MapValue& map_index_value, + DirectionalIndexByteEncoder* encoder) { + WriteValueTypeLabel(encoder, IndexType::kVector); + + int64_t valueIndex = model::IndexOfKey(map_index_value, model::kRawVectorValueFieldKey, model::kVectorValueFieldKey); + + if (valueIndex < 0 || map_index_value.fields[valueIndex].value.which_value_type != google_firestore_v1_Value_array_value_tag) { + return WriteIndexArray(model::MinArray().array_value, encoder); + } + + auto value = map_index_value.fields[valueIndex].value; + + // Vectors sort first by length + WriteValueTypeLabel(encoder, IndexType::kNumber); + encoder->WriteLong(value.array_value.values_count); + + // Vectors then sort by position value + WriteIndexString(model::kVectorValueFieldKey, encoder); + WriteIndexValueAux(value, encoder); +} + void WriteIndexMap(google_firestore_v1_MapValue map_index_value, DirectionalIndexByteEncoder* encoder) { WriteValueTypeLabel(encoder, IndexType::kMap); @@ -183,6 +206,9 @@ void WriteIndexValueAux(const google_firestore_v1_Value& index_value, if (model::IsMaxValue(index_value)) { WriteValueTypeLabel(encoder, std::numeric_limits::max()); break; + } else if(model::IsVectorValue(index_value)) { + WriteIndexVector(index_value.map_value, encoder); + break; } WriteIndexMap(index_value.map_value, encoder); WriteTruncationMarker(encoder); diff --git a/Firestore/core/src/model/value_util.cc b/Firestore/core/src/model/value_util.cc index bc7038b4bd6..9dace80b301 100644 --- a/Firestore/core/src/model/value_util.cc +++ b/Firestore/core/src/model/value_util.cc @@ -38,7 +38,9 @@ namespace firebase { namespace firestore { namespace model { -namespace { + +using nanopb::Message; +using util::ComparisonResult; /** The smallest reference value. */ pb_bytes_array_s* kMinimumReferenceValue = @@ -64,11 +66,6 @@ const char* kRawVectorValueFieldKey = "value"; pb_bytes_array_s* kVectorValueFieldKey = nanopb::MakeBytesArray(kRawVectorValueFieldKey); -} // namespace - -using nanopb::Message; -using util::ComparisonResult; - TypeOrder GetTypeOrder(const google_firestore_v1_Value& value) { switch (value.which_value_type) { case google_firestore_v1_Value_null_value_tag: @@ -328,7 +325,7 @@ ComparisonResult Compare(const google_firestore_v1_Value& left, return CompareMaps(left.map_value, right.map_value); case TypeOrder::kVector: - return CompareMaps(left.map_value, right.map_value); + return CompareVectors(left, right); case TypeOrder::kMaxValue: return util::ComparisonResult::Same; @@ -752,8 +749,8 @@ bool IsMaxValue(const google_firestore_v1_Value& value) { int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, const char* kRawTypeValueFieldKey, pb_bytes_array_s* kTypeValueFieldKey) { for (pb_size_t i = 0; i < mapValue.fields_count; i++) { - if (mapValue.fields[0].key != kTypeValueFieldKey && - nanopb::MakeStringView(mapValue.fields[0].key) != + if (mapValue.fields[i].key == kTypeValueFieldKey || + nanopb::MakeStringView(mapValue.fields[i].key) == kRawTypeValueFieldKey) { return i; } @@ -869,10 +866,29 @@ google_firestore_v1_Value MinArray() { } google_firestore_v1_Value MinVector() { + google_firestore_v1_Value typeValue; + typeValue.which_value_type = google_firestore_v1_Value_string_value_tag; + typeValue.string_value = kVectorTypeFieldValue; + + google_firestore_v1_MapValue_FieldsEntry *field_entries = nanopb::MakeArray(2); + field_entries[0].key = kTypeValueFieldKey; + field_entries[0].value = typeValue; + + google_firestore_v1_Value arrayValue; + arrayValue.which_value_type = google_firestore_v1_Value_array_value_tag; + arrayValue.array_value.values = nullptr; + arrayValue.array_value.values_count = 0; + field_entries[1].key = kVectorValueFieldKey; + field_entries[1].value = arrayValue; + + google_firestore_v1_MapValue map_value; + map_value.fields_count = 2; + map_value.fields = field_entries; + google_firestore_v1_Value lowerBound; lowerBound.which_value_type = google_firestore_v1_Value_map_value_tag; - lowerBound.map_value.fields = nullptr; - lowerBound.map_value.fields_count = 0; + lowerBound.map_value = map_value; + return lowerBound; } diff --git a/Firestore/core/src/model/value_util.h b/Firestore/core/src/model/value_util.h index 7ed6d076ac4..44a37f20067 100644 --- a/Firestore/core/src/model/value_util.h +++ b/Firestore/core/src/model/value_util.h @@ -23,6 +23,7 @@ #include "Firestore/Protos/nanopb/google/firestore/v1/document.nanopb.h" #include "Firestore/core/src/nanopb/message.h" +#include "Firestore/core/src/nanopb/nanopb_util.h" #include "absl/types/optional.h" namespace firebase { @@ -37,6 +38,25 @@ namespace model { class DocumentKey; class DatabaseId; +/** The smallest reference value. */ +extern pb_bytes_array_s* kMinimumReferenceValue; + +/** The field type of a special object type. */ +extern const char* kRawTypeValueFieldKey; +extern pb_bytes_array_s* kTypeValueFieldKey; + +/** The field value of a maximum proto value. */ +extern const char* kRawMaxValueFieldValue; +extern pb_bytes_array_s* kMaxValueFieldValue; + +/** The type of a VectorValue proto. */ +extern const char* kRawVectorTypeFieldValue; +extern pb_bytes_array_s* kVectorTypeFieldValue; + +/** The value key of a VectorValue proto. */ +extern const char* kRawVectorValueFieldKey; +extern pb_bytes_array_s* kVectorValueFieldKey; + /** * The order of types in Firestore. This order is based on the backend's * ordering, but modified to support server timestamps. diff --git a/Firestore/core/test/unit/local/leveldb_index_manager_test.cc b/Firestore/core/test/unit/local/leveldb_index_manager_test.cc index 20445421c22..175f681e512 100644 --- a/Firestore/core/test/unit/local/leveldb_index_manager_test.cc +++ b/Firestore/core/test/unit/local/leveldb_index_manager_test.cc @@ -50,6 +50,7 @@ using testutil::OrderBy; using testutil::OrFilters; using testutil::Query; using testutil::Version; +using testutil::VectorType; std::unique_ptr PersistenceFactory() { return LevelDbPersistenceForTesting(); @@ -929,6 +930,45 @@ TEST_F(LevelDbIndexManagerTest, IndexEntriesAreUpdatedWithDeletedDoc) { }); } +TEST_F(LevelDbIndexManagerTest, IndexVectorValueFields) { + persistence_->Run("TestIndexVectorValueFields", [&]() { + index_manager_->Start(); + index_manager_->AddFieldIndex(MakeFieldIndex("coll", "embedding", model::Segment::kAscending)); + + AddDoc("coll/arr1", Map("embedding", Array(1.0, 2.0, 3.0))); + AddDoc("coll/map2", Map("embedding", Map())); + AddDoc("coll/doc3", Map("embedding", VectorType(4.0, 5.0, 6.0))); + AddDoc("coll/doc4", Map("embedding", VectorType(5.0))); + + auto query = Query("coll").AddingOrderBy(OrderBy("embedding")); + { + SCOPED_TRACE("no filter"); + VerifyResults(query, {"coll/arr1", "coll/doc4", "coll/doc3", "coll/map2"}); + } + + query = Query("coll").AddingOrderBy(OrderBy("embedding")) + .AddingFilter(Filter("embedding", "==", VectorType(4.0, 5.0, 6.0))); + { + SCOPED_TRACE("vector<4.0, 5.0, 6.0>"); + VerifyResults(query, {"coll/doc3"}); + } + + query = Query("coll").AddingOrderBy(OrderBy("embedding")) + .AddingFilter(Filter("embedding", ">", VectorType(4.0, 5.0, 6.0))); + { + SCOPED_TRACE("> vector<4.0, 5.0, 6.0>"); + VerifyResults(query, {}); + } + + query = Query("coll").AddingOrderBy(OrderBy("embedding")) + .AddingFilter(Filter("embedding", ">", VectorType(4.0))); + { + SCOPED_TRACE("> vector<4.0>"); + VerifyResults(query, {"coll/doc4", "coll/doc3"}); + } + }); +} + TEST_F(LevelDbIndexManagerTest, AdvancedQueries) { // This test compares local query results with those received from the Java // Server SDK. diff --git a/Firestore/core/test/unit/model/value_util_test.cc b/Firestore/core/test/unit/model/value_util_test.cc index b2499eb6f9a..465d20d0e02 100644 --- a/Firestore/core/test/unit/model/value_util_test.cc +++ b/Firestore/core/test/unit/model/value_util_test.cc @@ -99,6 +99,9 @@ class ValueUtilTest : public ::testing::Test { ComparisonResult expected_result) { for (pb_size_t i = 0; i < left->values_count; ++i) { for (pb_size_t j = 0; j < right->values_count; ++j) { + if (expected_result != Compare(left->values[i], right->values[j])) { + std::cout << "here" << std::endl; + } EXPECT_EQ(expected_result, Compare(left->values[i], right->values[j])) << "Order check failed for '" << CanonicalId(left->values[i]) << "' and '" << CanonicalId(right->values[j]) << "' (expected " @@ -243,6 +246,7 @@ TEST_F(ValueUtilTest, Equality) { Add(equals_group, Array("foo", "bar"), Array("foo", "bar")); Add(equals_group, Array("foo", "bar", "baz")); Add(equals_group, Array("foo")); + Add(equals_group, Map("__type__", "__vector__", "value", Array()), DeepClone(MinVector())); Add(equals_group, Map("bar", 1, "foo", 2), Map("bar", 1, "foo", 2)); Add(equals_group, Map("bar", 2, "foo", 1)); Add(equals_group, Map("bar", 1)); @@ -346,6 +350,13 @@ TEST_F(ValueUtilTest, StrictOrdering) { Add(comparison_groups, Array("foo", 1)); Add(comparison_groups, Array("foo", 2)); Add(comparison_groups, Array("foo", "0")); + + // vectors + Add(comparison_groups, + DeepClone(MinVector())); + Add(comparison_groups, Map("__type__", "__vector__", "value", Array(100))); + Add(comparison_groups, Map("__type__", "__vector__", "value", Array(1.0, 2.0, 3.0))); + Add(comparison_groups, Map("__type__", "__vector__", "value", Array(1.0, 3.0, 2.0))); // objects Add(comparison_groups, @@ -491,8 +502,13 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { Add(comparison_groups, Array("foo", "0")); Add(comparison_groups, DeepClone(MinVector())); - - // TODO vector + + // vectors + Add(comparison_groups, + DeepClone(MinVector())); + Add(comparison_groups, Map("__type__", "__vector__", "value", Array(100))); + Add(comparison_groups, Map("__type__", "__vector__", "value", Array(1.0, 2.0, 3.0))); + Add(comparison_groups, Map("__type__", "__vector__", "value", Array(1.0, 3.0, 2.0))); // objects Add(comparison_groups, @@ -517,7 +533,7 @@ TEST_F(ValueUtilTest, CanonicalId) { VerifyCanonicalId(Value(true), "true"); VerifyCanonicalId(Value(false), "false"); VerifyCanonicalId(Value(1), "1"); - VerifyCanonicalId(Value(1.0), "1.0"); + VerifyCanonicalId(Value(1.0), "1.0"); VerifyCanonicalId(Value(Timestamp(30, 1000)), "time(30,1000)"); VerifyCanonicalId(Value("a"), "a"); VerifyCanonicalId(Value(std::string("a\0b", 3)), std::string("a\0b", 3)); @@ -526,8 +542,10 @@ TEST_F(ValueUtilTest, CanonicalId) { VerifyCanonicalId(Value(GeoPoint(30, 60)), "geo(30.0,60.0)"); VerifyCanonicalId(Value(Array(1, 2, 3)), "[1,2,3]"); VerifyCanonicalId(Map("a", 1, "b", 2, "c", "3"), "{a:1,b:2,c:3}"); - VerifyCanonicalId(Map("a", Array("b", Map("c", GeoPoint(30, 60)))), - "{a:[b,{c:geo(30.0,60.0)}]}"); + VerifyCanonicalId(Map("a", Array("b", Map("c", GeoPoint(30, 60)))), + "{a:[b,{c:geo(30.0,60.0)}]}"); + VerifyCanonicalId(Map("__type__", "__vector__", "value", Array(1.0, 1.0, -2.0, 3.14)), + "{__type__:__vector__,value:[1.0,1.0,-2.0,3.1]}"); } TEST_F(ValueUtilTest, DeepClone) { diff --git a/Firestore/core/test/unit/remote/serializer_test.cc b/Firestore/core/test/unit/remote/serializer_test.cc index bea30617594..79fc79e99d7 100644 --- a/Firestore/core/test/unit/remote/serializer_test.cc +++ b/Firestore/core/test/unit/remote/serializer_test.cc @@ -821,6 +821,24 @@ TEST_F(SerializerTest, EncodesNestedObjects) { ExpectRoundTrip(model, proto, TypeOrder::kMap); } +TEST_F(SerializerTest, EncodesVectorValue) { + Message model = Map( + "__type__", "__vector__", "value", Array(1.0, 2.0, 3.0)); + + v1::Value array_proto; + *array_proto.mutable_array_value()->add_values() = ValueProto(1.0); + *array_proto.mutable_array_value()->add_values() = ValueProto(2.0); + *array_proto.mutable_array_value()->add_values() = ValueProto(3.0); + + v1::Value proto; + google::protobuf::Map* fields = + proto.mutable_map_value()->mutable_fields(); + (*fields)["__type__"] = ValueProto("__vector__"); + (*fields)["value"] = array_proto; + + ExpectRoundTrip(model, proto, TypeOrder::kVector); +} + TEST_F(SerializerTest, EncodesFieldValuesWithRepeatedEntries) { // Technically, serialized Value protos can contain multiple values. (The last // one "wins".) However, well-behaved proto emitters (such as libprotobuf) diff --git a/Firestore/core/test/unit/testutil/testutil.h b/Firestore/core/test/unit/testutil/testutil.h index 50845c21d6c..259e93eb6c2 100644 --- a/Firestore/core/test/unit/testutil/testutil.h +++ b/Firestore/core/test/unit/testutil/testutil.h @@ -288,6 +288,11 @@ nanopb::Message Map(Args... key_value_pairs) { return details::MakeMap(std::move(key_value_pairs)...); } +template +nanopb::Message VectorType(Args&&... values) { + return Map("__type__", "__vector__", "value", details::MakeArray(std::move(values)...)); +} + model::DocumentKey Key(absl::string_view path); model::FieldPath Field(absl::string_view field); From 1d691774d206856aaf790bafa645c3999d19ef27 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Wed, 24 Jul 2024 18:47:25 -0600 Subject: [PATCH 04/24] Cleanup --- Firestore/Example/GoogleService-Info.plist | 30 ---------------------- 1 file changed, 30 deletions(-) delete mode 100644 Firestore/Example/GoogleService-Info.plist diff --git a/Firestore/Example/GoogleService-Info.plist b/Firestore/Example/GoogleService-Info.plist deleted file mode 100644 index c1ff82badbe..00000000000 --- a/Firestore/Example/GoogleService-Info.plist +++ /dev/null @@ -1,30 +0,0 @@ - - - - - API_KEY - AIzaSyCPOtsd7zCEEnc90uXunosAcu1-EKtWYYw - GCM_SENDER_ID - 37794738117 - PLIST_VERSION - 1 - BUNDLE_ID - com.google.firebase.FirestoreSample - PROJECT_ID - markduckworth-firestore-tests - STORAGE_BUCKET - markduckworth-firestore-tests.appspot.com - IS_ADS_ENABLED - - IS_ANALYTICS_ENABLED - - IS_APPINVITE_ENABLED - - IS_GCM_ENABLED - - IS_SIGNIN_ENABLED - - GOOGLE_APP_ID - 1:37794738117:ios:f1861da1590da1e46125ac - - \ No newline at end of file From 12637475b44d888d88886129fa03ace22ba1ead9 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Thu, 25 Jul 2024 14:40:17 -0600 Subject: [PATCH 05/24] cleanup --- .../Public/FirebaseFirestore/FIRFieldValue.h | 2 +- .../Source/SwiftAPI/FieldValue+Swift.swift | 16 ++--- .../Source/SwiftAPI/VectorValue+Swift.swift | 4 +- .../Integration/CodableIntegrationTests.swift | 2 +- .../Integration/VectorIntegrationTests.swift | 61 +++++++++++++++++-- 5 files changed, 69 insertions(+), 16 deletions(-) diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h index 37f1a72d31c..23842b30edf 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h @@ -97,7 +97,7 @@ NS_SWIFT_NAME(FieldValue) * @param values Create a `VectorValue` instance with a copy of this array of NSNumbers. * @return A new `VectorValue` constructed with a copy of the given array of NSNumbers. */ -+ (FIRVectorValue*)vectorFromNSNumbers:(NSArray *)values NS_SWIFT_NAME(vector(fromNSNumbers:)); ++ (FIRVectorValue*)vectorFromNSNumbers:(NSArray *)values NS_REFINED_FOR_SWIFT; @end diff --git a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift index 6d43c810e2a..60ccbd39250 100644 --- a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift @@ -23,22 +23,22 @@ public extension FieldValue { /// Creates a new `VectorValue` constructed with a copy of the given array of Doubles. - /// - Parameter values: An array of Doubles. + /// - Parameter data: An array of Doubles. /// - Returns: A new `VectorValue` constructed with a copy of the given array of Doubles. - static func vector(_ values: [Double]) -> VectorValue { - let array = values.map { double in + static func vector(_ data: [Double]) -> VectorValue { + let array = data.map { double in return NSNumber(value: double) } - return FieldValue.vector(fromNSNumbers: array) + return FieldValue.__vector(from: array) } /// Creates a new `VectorValue` constructed with a copy of the given array of Floats. - /// - Parameter values: An array of Floats. + /// - Parameter data: An array of Floats. /// - Returns: A new `VectorValue` constructed with a copy of the given array of Floats. - static func vector(_ values: [Float]) -> VectorValue { - let array = values.map { float in + static func vector(_ data: [Float]) -> VectorValue { + let array = data.map { float in return NSNumber(value: float) } - return FieldValue.vector(fromNSNumbers: array) + return FieldValue.__vector(from: array) } } diff --git a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift index 147e1349209..b9d4d92e886 100644 --- a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift @@ -21,8 +21,8 @@ #endif // SWIFT_PACKAGE public extension VectorValue { - convenience init(_ doubles: [Double]) { - let array = doubles.map { float in + convenience init(_ data: [Double]) { + let array = data.map { float in return NSNumber(value: float) } diff --git a/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift b/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift index 01634dfcc21..ee25019e466 100644 --- a/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift @@ -193,7 +193,7 @@ class CodableIntegrationTests: FSTIntegrationTestCase { } let model = Model( name: "name", - embedding: FieldValue.vector([0.1, 0.3, 0.4]) + embedding: VectorValue([0.1, 0.3, 0.4]) ) let docToWrite = documentRef() diff --git a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift index d4c3bd29bf3..9e4642f1182 100644 --- a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift @@ -138,8 +138,6 @@ class VectorIntegrationTests: FSTIntegrationTestCase { checkOnlineAndOfflineQuery(collection.order(by: "embedding").whereField("embedding", isGreaterThan: FieldValue.vector([1, 2, 100, 4, 4.0])), matchesResult: Array(docIds[12...12])) } - - func testQueryVectorValueWrittenByCodable() async throws { let collection = collectionRef() @@ -154,9 +152,64 @@ class VectorIntegrationTests: FSTIntegrationTestCase { try collection.document().setData(from: model) - let querySnap = try await collection.whereField("embedding", isEqualTo: FieldValue.vector([0.1, 0.3, 0.4])).getDocuments() + let querySnap: QuerySnapshot = try await collection.whereField("embedding", isEqualTo: FieldValue.vector([0.1, 0.3, 0.4])).getDocuments() XCTAssertEqual(1, querySnap.count) - XCTAssertEqual(try querySnap.documents[0].data(as: Model.self).embedding, VectorValue([0.1, 0.3, 0.4])) + + let returnedModel: Model = try querySnap.documents[0].data(as: Model.self) + XCTAssertEqual(returnedModel.embedding, VectorValue([0.1, 0.3, 0.4])) + + let vectorData: [Double] = returnedModel.embedding.data; + XCTAssertEqual(vectorData, [0.1, 0.3, 0.4]) + } + + func testQueryVectorValueWrittenByCodableClass() async throws { + let collection = collectionRef() + + struct Model: Codable { + var name: String + var embedding: VectorValue + } + + struct ModelWithDistance: Codable { + var name: String + var embedding: VectorValue + var distance: Double + } + +struct WithDistance: Decodable { + var distance: Double + var data: T + + private enum CodingKeys: String, CodingKey { + case distance + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + distance = try container.decode(Double.self, forKey: .distance) + data = try T(from: decoder) + } +} + + let model = ModelWithDistance( + name: "name", + embedding: FieldValue.vector([0.1, 0.3, 0.4]), + distance: 0.2 + ) + + try collection.document().setData(from: model) + + let querySnap: QuerySnapshot = try await collection.getDocuments() + + XCTAssertEqual(1, querySnap.count) + + let returnedModel: WithDistance = + try querySnap.documents[0].data(as: WithDistance.self) + XCTAssertEqual(returnedModel.data.embedding, VectorValue([0.1, 0.3, 0.4])) + XCTAssertEqual(returnedModel.distance, 0.2) + + let vectorData: [Double] = returnedModel.data.embedding.data; + XCTAssertEqual(vectorData, [0.1, 0.3, 0.4]) } } From 707c9f8df4bfb2d37f500b00258c0dac1ce632e8 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Thu, 25 Jul 2024 14:51:49 -0600 Subject: [PATCH 06/24] Formatting --- Firestore/Source/API/FIRFieldValue.mm | 2 +- .../Source/API/FIRVectorValue+Internal.h | 2 - Firestore/Source/API/FIRVectorValue.m | 73 ++-- Firestore/Source/API/FSTUserDataReader.mm | 71 ++-- Firestore/Source/API/FSTUserDataWriter.mm | 82 ++-- .../Public/FirebaseFirestore/FIRFieldValue.h | 2 +- .../Public/FirebaseFirestore/FIRVectorValue.h | 3 +- .../Source/Codable/VectorValue+Codable.swift | 26 +- .../Source/SwiftAPI/FieldValue+Swift.swift | 35 +- .../Source/SwiftAPI/VectorValue+Swift.swift | 24 +- .../Integration/CodableIntegrationTests.swift | 36 +- .../SnapshotListenerSourceTests.swift | 133 +++--- .../Integration/VectorIntegrationTests.swift | 387 ++++++++++-------- Firestore/core/src/core/target.cc | 6 +- .../src/index/firestore_index_value_writer.cc | 48 ++- Firestore/core/src/model/value_util.cc | 267 ++++++------ Firestore/core/src/model/value_util.h | 31 +- .../unit/local/leveldb_index_manager_test.cc | 77 ++-- .../core/test/unit/model/value_util_test.cc | 132 +++--- .../core/test/unit/remote/serializer_test.cc | 4 +- Firestore/core/test/unit/testutil/testutil.h | 3 +- 21 files changed, 747 insertions(+), 697 deletions(-) diff --git a/Firestore/Source/API/FIRFieldValue.mm b/Firestore/Source/API/FIRFieldValue.mm index a38cff88a11..4b59e78c998 100644 --- a/Firestore/Source/API/FIRFieldValue.mm +++ b/Firestore/Source/API/FIRFieldValue.mm @@ -178,7 +178,7 @@ + (instancetype)fieldValueForIntegerIncrement:(int64_t)l { } + (nonnull FIRVectorValue *)vectorFromNSNumbers:(nonnull NSArray *)values { - return [[FIRVectorValue alloc] initWithNSNumbers: values]; + return [[FIRVectorValue alloc] initWithNSNumbers:values]; } @end diff --git a/Firestore/Source/API/FIRVectorValue+Internal.h b/Firestore/Source/API/FIRVectorValue+Internal.h index 23a6243f566..5192eceb78e 100644 --- a/Firestore/Source/API/FIRVectorValue+Internal.h +++ b/Firestore/Source/API/FIRVectorValue+Internal.h @@ -16,7 +16,6 @@ #import "FIRVectorValue.h" - NS_ASSUME_NONNULL_BEGIN @interface FIRVectorValue (Internal) @@ -24,5 +23,4 @@ NS_ASSUME_NONNULL_BEGIN // - (NSArray *)toNSArray; @end - NS_ASSUME_NONNULL_END diff --git a/Firestore/Source/API/FIRVectorValue.m b/Firestore/Source/API/FIRVectorValue.m index 0b371774953..417a52efda4 100644 --- a/Firestore/Source/API/FIRVectorValue.m +++ b/Firestore/Source/API/FIRVectorValue.m @@ -14,7 +14,6 @@ * limitations under the License. */ - #import #include @@ -32,50 +31,50 @@ @interface FIRVectorValue () { @implementation FIRVectorValue -- (instancetype)initWithNSNumbers: (NSArray *)data { - if (self = [super init]) { - std::vector converted; - converted.reserve(data.count); - for (NSNumber *value in data) { - converted.emplace_back([value doubleValue]); - } - - _internalValue = std::move(converted); +- (instancetype)initWithNSNumbers:(NSArray *)data { + if (self = [super init]) { + std::vector converted; + converted.reserve(data.count); + for (NSNumber *value in data) { + converted.emplace_back([value doubleValue]); } - return self; + + _internalValue = std::move(converted); + } + return self; } - (nonnull NSArray *)toNSArray { - size_t length = _internalValue.size(); - NSMutableArray *outArray = [[NSMutableArray alloc] initWithCapacity:length]; - for (size_t i = 0; i < length; i++) { - [outArray addObject:[[NSNumber alloc] initWithDouble:self->_internalValue.at(i)]]; - } - - return outArray; + size_t length = _internalValue.size(); + NSMutableArray *outArray = + [[NSMutableArray alloc] initWithCapacity:length]; + for (size_t i = 0; i < length; i++) { + [outArray addObject:[[NSNumber alloc] initWithDouble:self->_internalValue.at(i)]]; + } + + return outArray; } - (BOOL)isEqual:(nullable id)object { - if (self == object) { - return YES; - } - - if (![object isKindOfClass:[FIRVectorValue class]]) { - return NO; - } - - FIRVectorValue *otherVector = ((FIRVectorValue *)object); - - if (self->_internalValue.size() != otherVector->_internalValue.size()) { - return NO; - } - - for (size_t i = 0; i < self->_internalValue.size(); i++) { - if (self->_internalValue[i] != otherVector->_internalValue[i]) - return NO; - } - + if (self == object) { return YES; + } + + if (![object isKindOfClass:[FIRVectorValue class]]) { + return NO; + } + + FIRVectorValue *otherVector = ((FIRVectorValue *)object); + + if (self->_internalValue.size() != otherVector->_internalValue.size()) { + return NO; + } + + for (size_t i = 0; i < self->_internalValue.size(); i++) { + if (self->_internalValue[i] != otherVector->_internalValue[i]) return NO; + } + + return YES; } @end diff --git a/Firestore/Source/API/FSTUserDataReader.mm b/Firestore/Source/API/FSTUserDataReader.mm index 115a00c9332..3ea08e1e847 100644 --- a/Firestore/Source/API/FSTUserDataReader.mm +++ b/Firestore/Source/API/FSTUserDataReader.mm @@ -341,39 +341,40 @@ - (ParsedUpdateData)parsedUpdateData:(id)input { return std::move(result); } -- (Message)parseVectorValue:(FIRVectorValue *)vectorValue context:(ParseContext &&)context { - __block Message result; - result->which_value_type = google_firestore_v1_Value_map_value_tag; - result->map_value = {}; - - result->map_value.fields_count = 2; - result->map_value.fields = nanopb::MakeArray(2); - - result->map_value.fields[0].key = nanopb::CopyBytesArray(model::kTypeValueFieldKey); - result->map_value.fields[0].value = *[self encodeStringValue:MakeString(@"__vector__")].release(); - - NSArray *vectorArray = [vectorValue toNSArray]; - - __block Message arrayMessage; - arrayMessage->which_value_type = google_firestore_v1_Value_array_value_tag; - arrayMessage->array_value.values_count = CheckedSize([vectorArray count]); - arrayMessage->array_value.values = - nanopb::MakeArray(arrayMessage->array_value.values_count); - - [vectorArray enumerateObjectsUsingBlock:^(id entry, NSUInteger idx, BOOL *) { - if (![entry isKindOfClass:[NSNumber class]]) { - ThrowInvalidArgument("VectorValues must only contain numeric values.", - context.FieldDescription()); - } - - // Vector values must always use Double encoding - arrayMessage->array_value.values[idx] = *[self encodeDouble:[entry doubleValue]].release(); - }]; - - result->map_value.fields[1].key = nanopb::CopyBytesArray(model::kVectorValueFieldKey); - result->map_value.fields[1].value = *arrayMessage.release(); - - return std::move(result); +- (Message)parseVectorValue:(FIRVectorValue *)vectorValue + context:(ParseContext &&)context { + __block Message result; + result->which_value_type = google_firestore_v1_Value_map_value_tag; + result->map_value = {}; + + result->map_value.fields_count = 2; + result->map_value.fields = nanopb::MakeArray(2); + + result->map_value.fields[0].key = nanopb::CopyBytesArray(model::kTypeValueFieldKey); + result->map_value.fields[0].value = *[self encodeStringValue:MakeString(@"__vector__")].release(); + + NSArray *vectorArray = [vectorValue toNSArray]; + + __block Message arrayMessage; + arrayMessage->which_value_type = google_firestore_v1_Value_array_value_tag; + arrayMessage->array_value.values_count = CheckedSize([vectorArray count]); + arrayMessage->array_value.values = + nanopb::MakeArray(arrayMessage->array_value.values_count); + + [vectorArray enumerateObjectsUsingBlock:^(id entry, NSUInteger idx, BOOL *) { + if (![entry isKindOfClass:[NSNumber class]]) { + ThrowInvalidArgument("VectorValues must only contain numeric values.", + context.FieldDescription()); + } + + // Vector values must always use Double encoding + arrayMessage->array_value.values[idx] = *[self encodeDouble:[entry doubleValue]].release(); + }]; + + result->map_value.fields[1].key = nanopb::CopyBytesArray(model::kVectorValueFieldKey); + result->map_value.fields[1].value = *arrayMessage.release(); + + return std::move(result); } - (Message)parseArray:(NSArray *)array @@ -565,8 +566,8 @@ - (void)parseSentinelFieldValue:(FIRFieldValue *)fieldValue context:(ParseContex } return [self encodeReference:_databaseID key:reference.key]; } else if ([input isKindOfClass:[FIRVectorValue class]]) { - FIRVectorValue *vector = input; - return [self parseVectorValue:vector context:std::move(context)]; + FIRVectorValue *vector = input; + return [self parseVectorValue:vector context:std::move(context)]; } else { ThrowInvalidArgument("Unsupported type: %s%s", NSStringFromClass([input class]), context.FieldDescription()); diff --git a/Firestore/Source/API/FSTUserDataWriter.mm b/Firestore/Source/API/FSTUserDataWriter.mm index 7d3e87b7c46..e6f2cd358eb 100644 --- a/Firestore/Source/API/FSTUserDataWriter.mm +++ b/Firestore/Source/API/FSTUserDataWriter.mm @@ -21,8 +21,8 @@ #include "Firestore/Protos/nanopb/google/firestore/v1/document.nanopb.h" #include "Firestore/Source/API/FIRDocumentReference+Internal.h" -#include "Firestore/Source/API/converters.h" #include "Firestore/Source/API/FIRFieldValue+Internal.h" +#include "Firestore/Source/API/converters.h" #include "Firestore/core/include/firebase/firestore/geo_point.h" #include "Firestore/core/include/firebase/firestore/timestamp.h" #include "Firestore/core/src/api/firestore.h" @@ -80,38 +80,38 @@ - (instancetype)initWithFirestore:(std::shared_ptr)firestore } - (id)convertedValue:(const google_firestore_v1_Value &)value { - switch (GetTypeOrder(value)) { - case TypeOrder::kMap: - return [self convertedObject:value.map_value]; - case TypeOrder::kArray: - return [self convertedArray:value.array_value]; - case TypeOrder::kReference: - return [self convertedReference:value]; - case TypeOrder::kTimestamp: - return [self convertedTimestamp:value.timestamp_value]; - case TypeOrder::kServerTimestamp: - return [self convertedServerTimestamp:value]; - case TypeOrder::kNull: - return [NSNull null]; - case TypeOrder::kBoolean: - return value.boolean_value ? @YES : @NO; - case TypeOrder::kNumber: - return value.which_value_type == google_firestore_v1_Value_integer_value_tag - ? @(value.integer_value) - : @(value.double_value); - case TypeOrder::kString: - return MakeNSString(MakeStringView(value.string_value)); - case TypeOrder::kBlob: - return MakeNSData(value.bytes_value); - case TypeOrder::kGeoPoint: - return MakeFIRGeoPoint( - GeoPoint(value.geo_point_value.latitude, value.geo_point_value.longitude)); - case TypeOrder::kVector: - return [self convertedVector:value.map_value]; - case TypeOrder::kMaxValue: - // It is not possible for users to construct a kMaxValue manually. - break; - } + switch (GetTypeOrder(value)) { + case TypeOrder::kMap: + return [self convertedObject:value.map_value]; + case TypeOrder::kArray: + return [self convertedArray:value.array_value]; + case TypeOrder::kReference: + return [self convertedReference:value]; + case TypeOrder::kTimestamp: + return [self convertedTimestamp:value.timestamp_value]; + case TypeOrder::kServerTimestamp: + return [self convertedServerTimestamp:value]; + case TypeOrder::kNull: + return [NSNull null]; + case TypeOrder::kBoolean: + return value.boolean_value ? @YES : @NO; + case TypeOrder::kNumber: + return value.which_value_type == google_firestore_v1_Value_integer_value_tag + ? @(value.integer_value) + : @(value.double_value); + case TypeOrder::kString: + return MakeNSString(MakeStringView(value.string_value)); + case TypeOrder::kBlob: + return MakeNSData(value.bytes_value); + case TypeOrder::kGeoPoint: + return MakeFIRGeoPoint( + GeoPoint(value.geo_point_value.latitude, value.geo_point_value.longitude)); + case TypeOrder::kVector: + return [self convertedVector:value.map_value]; + case TypeOrder::kMaxValue: + // It is not possible for users to construct a kMaxValue manually. + break; + } UNREACHABLE(); } @@ -127,15 +127,15 @@ - (id)convertedValue:(const google_firestore_v1_Value &)value { } - (FIRVectorValue *)convertedVector:(const google_firestore_v1_MapValue &)mapValue { - for (pb_size_t i = 0; i < mapValue.fields_count; ++i) { - absl::string_view key = MakeStringView(mapValue.fields[i].key); - const google_firestore_v1_Value &value = mapValue.fields[i].value; - if ((0 == key.compare(absl::string_view("value"))) - && value.which_value_type == google_firestore_v1_Value_array_value_tag) { - return [FIRFieldValue vectorFromNSNumbers:[self convertedArray:value.array_value]]; - } + for (pb_size_t i = 0; i < mapValue.fields_count; ++i) { + absl::string_view key = MakeStringView(mapValue.fields[i].key); + const google_firestore_v1_Value &value = mapValue.fields[i].value; + if ((0 == key.compare(absl::string_view("value"))) && + value.which_value_type == google_firestore_v1_Value_array_value_tag) { + return [FIRFieldValue vectorFromNSNumbers:[self convertedArray:value.array_value]]; } - return [FIRFieldValue vectorFromNSNumbers:@[]]; + } + return [FIRFieldValue vectorFromNSNumbers:@[]]; } - (NSArray *)convertedArray:(const google_firestore_v1_ArrayValue &)arrayValue { diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h index 23842b30edf..45ef2d8d24e 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h @@ -97,7 +97,7 @@ NS_SWIFT_NAME(FieldValue) * @param values Create a `VectorValue` instance with a copy of this array of NSNumbers. * @return A new `VectorValue` constructed with a copy of the given array of NSNumbers. */ -+ (FIRVectorValue*)vectorFromNSNumbers:(NSArray *)values NS_REFINED_FOR_SWIFT; ++ (FIRVectorValue *)vectorFromNSNumbers:(NSArray *)values NS_REFINED_FOR_SWIFT; @end diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h index 746965fea4d..af82d447f45 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h @@ -25,10 +25,9 @@ NS_SWIFT_NAME(VectorValue) - (instancetype)init NS_UNAVAILABLE; // Public initializer is required to support Codable -- (instancetype)initWithNSNumbers: (NSArray *)data NS_REFINED_FOR_SWIFT; +- (instancetype)initWithNSNumbers:(NSArray *)data NS_REFINED_FOR_SWIFT; - (NSArray *)toNSArray NS_REFINED_FOR_SWIFT; @end - NS_ASSUME_NONNULL_END diff --git a/Firestore/Swift/Source/Codable/VectorValue+Codable.swift b/Firestore/Swift/Source/Codable/VectorValue+Codable.swift index 6f8b532c846..7ed48ddc6a2 100644 --- a/Firestore/Swift/Source/Codable/VectorValue+Codable.swift +++ b/Firestore/Swift/Source/Codable/VectorValue+Codable.swift @@ -41,21 +41,21 @@ private enum VectorValueKeys: String, CodingKey { * when declaring an extension to conform to Codable. */ extension CodableVectorValue { - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: VectorValueKeys.self) - let data = try container.decode([Double].self, forKey: .data) - - let array = data.map { double in - return NSNumber(value: double) - } - self.init(__nsNumbers: array) - } + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: VectorValueKeys.self) + let data = try container.decode([Double].self, forKey: .data) - public func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: VectorValueKeys.self) - try container.encode(data, forKey: .data) + let array = data.map { double in + NSNumber(value: double) } + self.init(__nsNumbers: array) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: VectorValueKeys.self) + try container.encode(data, forKey: .data) + } } /** Extends VectorValue to conform to Codable. */ -extension VectorValue: CodableVectorValue{} +extension VectorValue: CodableVectorValue {} diff --git a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift index 60ccbd39250..35d59bc6209 100644 --- a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift @@ -21,24 +21,23 @@ #endif // SWIFT_PACKAGE public extension FieldValue { - - /// Creates a new `VectorValue` constructed with a copy of the given array of Doubles. - /// - Parameter data: An array of Doubles. - /// - Returns: A new `VectorValue` constructed with a copy of the given array of Doubles. - static func vector(_ data: [Double]) -> VectorValue { - let array = data.map { double in - return NSNumber(value: double) - } - return FieldValue.__vector(from: array) + /// Creates a new `VectorValue` constructed with a copy of the given array of Doubles. + /// - Parameter data: An array of Doubles. + /// - Returns: A new `VectorValue` constructed with a copy of the given array of Doubles. + static func vector(_ data: [Double]) -> VectorValue { + let array = data.map { double in + NSNumber(value: double) } - - /// Creates a new `VectorValue` constructed with a copy of the given array of Floats. - /// - Parameter data: An array of Floats. - /// - Returns: A new `VectorValue` constructed with a copy of the given array of Floats. - static func vector(_ data: [Float]) -> VectorValue { - let array = data.map { float in - return NSNumber(value: float) - } - return FieldValue.__vector(from: array) + return FieldValue.__vector(from: array) + } + + /// Creates a new `VectorValue` constructed with a copy of the given array of Floats. + /// - Parameter data: An array of Floats. + /// - Returns: A new `VectorValue` constructed with a copy of the given array of Floats. + static func vector(_ data: [Float]) -> VectorValue { + let array = data.map { float in + NSNumber(value: float) } + return FieldValue.__vector(from: array) + } } diff --git a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift index b9d4d92e886..590c6644a08 100644 --- a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift @@ -21,17 +21,17 @@ #endif // SWIFT_PACKAGE public extension VectorValue { - convenience init(_ data: [Double]) { - let array = data.map { float in - return NSNumber(value: float) - } - - self.init(__nsNumbers: array) - } - - /// Returns a raw number array representation of the vector. - /// - Returns: An array of Double values representing the vector. - var data: [Double] { - return self.__toNSArray().map { Double(truncating: $0) } + convenience init(_ data: [Double]) { + let array = data.map { float in + NSNumber(value: float) } + + self.init(__nsNumbers: array) + } + + /// Returns a raw number array representation of the vector. + /// - Returns: An array of Double values representing the vector. + var data: [Double] { + return __toNSArray().map { Double(truncating: $0) } + } } diff --git a/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift b/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift index ee25019e466..06c68ae7405 100644 --- a/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/CodableIntegrationTests.swift @@ -185,27 +185,31 @@ class CodableIntegrationTests: FSTIntegrationTestCase { XCTAssertEqual(data["intValue"] as! Int, 3, "Failed with flavor \(flavor)") } } - - func testVectorValue() throws { - struct Model: Codable { - var name: String - var embedding: VectorValue - } - let model = Model( - name: "name", - embedding: VectorValue([0.1, 0.3, 0.4]) - ) - let docToWrite = documentRef() + func testVectorValue() throws { + struct Model: Codable { + var name: String + var embedding: VectorValue + } + let model = Model( + name: "name", + embedding: VectorValue([0.1, 0.3, 0.4]) + ) + + let docToWrite = documentRef() - for flavor in allFlavors { - try setData(from: model, forDocument: docToWrite, withFlavor: flavor) + for flavor in allFlavors { + try setData(from: model, forDocument: docToWrite, withFlavor: flavor) - let data = try readDocument(forRef: docToWrite).data(as: Model.self) + let data = try readDocument(forRef: docToWrite).data(as: Model.self) - XCTAssertEqual(data.embedding, VectorValue([0.1, 0.3, 0.4]), "Failed with flavor \(flavor)") - } + XCTAssertEqual( + data.embedding, + VectorValue([0.1, 0.3, 0.4]), + "Failed with flavor \(flavor)" + ) } + } func testDataBlob() throws { struct Model: Encodable { diff --git a/Firestore/Swift/Tests/Integration/SnapshotListenerSourceTests.swift b/Firestore/Swift/Tests/Integration/SnapshotListenerSourceTests.swift index 0a6bebf7091..61b4da23530 100644 --- a/Firestore/Swift/Tests/Integration/SnapshotListenerSourceTests.swift +++ b/Firestore/Swift/Tests/Integration/SnapshotListenerSourceTests.swift @@ -673,57 +673,84 @@ class SnapshotListenerSourceTests: FSTIntegrationTestCase { defaultRegistration.remove() cacheRegistration.remove() } - - - func testListenToDocumentsWithVectors() throws { - let collection = collectionRef() - let doc = collection.document() - - let registration = collection.whereField("purpose", isEqualTo: "vector tests").addSnapshotListener(eventAccumulator.valueEventHandler) - - var querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot - XCTAssertEqual(querySnap.isEmpty, true) - - doc.setData([ - "purpose": "vector tests", - "vector0": FieldValue.vector([0.0]), - "vector1": FieldValue.vector([1, 2, 3.99]) - ]) - - querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot - XCTAssertEqual(querySnap.isEmpty, false) - XCTAssertEqual(querySnap.documents[0].data()["vector0"] as! VectorValue, FieldValue.vector([0.0])) - XCTAssertEqual(querySnap.documents[0].data()["vector1"] as! VectorValue, FieldValue.vector([1, 2, 3.99])) - - doc.setData([ - "purpose": "vector tests", - "vector0": FieldValue.vector([0.0]), - "vector1": FieldValue.vector([1, 2, 3.99]), - "vector2": FieldValue.vector([0.0, 0, 0]) - ]) - - querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot - XCTAssertEqual(querySnap.isEmpty, false) - XCTAssertEqual(querySnap.documents[0].data()["vector0"] as! VectorValue, FieldValue.vector([0.0])) - XCTAssertEqual(querySnap.documents[0].data()["vector1"] as! VectorValue, FieldValue.vector([1, 2, 3.99])) - XCTAssertEqual(querySnap.documents[0].data()["vector2"] as! VectorValue, FieldValue.vector([0.0, 0, 0])) - - doc.updateData([ - "vector3": FieldValue.vector([-1, -200, -999.0]) - ]) - - querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot - XCTAssertEqual(querySnap.isEmpty, false) - XCTAssertEqual(querySnap.documents[0].data()["vector0"] as! VectorValue, FieldValue.vector([0.0])) - XCTAssertEqual(querySnap.documents[0].data()["vector1"] as! VectorValue, FieldValue.vector([1, 2, 3.99])) - XCTAssertEqual(querySnap.documents[0].data()["vector2"] as! VectorValue, FieldValue.vector([0.0, 0, 0])) - XCTAssertEqual(querySnap.documents[0].data()["vector3"] as! VectorValue, FieldValue.vector([-1, -200, -999.0])) - - doc.delete() - querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot - XCTAssertEqual(querySnap.isEmpty, true) - - eventAccumulator.assertNoAdditionalEvents() - registration.remove() - } + + func testListenToDocumentsWithVectors() throws { + let collection = collectionRef() + let doc = collection.document() + + let registration = collection.whereField("purpose", isEqualTo: "vector tests") + .addSnapshotListener(eventAccumulator.valueEventHandler) + + var querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, true) + + doc.setData([ + "purpose": "vector tests", + "vector0": FieldValue.vector([0.0]), + "vector1": FieldValue.vector([1, 2, 3.99]), + ]) + + querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, false) + XCTAssertEqual( + querySnap.documents[0].data()["vector0"] as! VectorValue, + FieldValue.vector([0.0]) + ) + XCTAssertEqual( + querySnap.documents[0].data()["vector1"] as! VectorValue, + FieldValue.vector([1, 2, 3.99]) + ) + + doc.setData([ + "purpose": "vector tests", + "vector0": FieldValue.vector([0.0]), + "vector1": FieldValue.vector([1, 2, 3.99]), + "vector2": FieldValue.vector([0.0, 0, 0]), + ]) + + querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, false) + XCTAssertEqual( + querySnap.documents[0].data()["vector0"] as! VectorValue, + FieldValue.vector([0.0]) + ) + XCTAssertEqual( + querySnap.documents[0].data()["vector1"] as! VectorValue, + FieldValue.vector([1, 2, 3.99]) + ) + XCTAssertEqual( + querySnap.documents[0].data()["vector2"] as! VectorValue, + FieldValue.vector([0.0, 0, 0]) + ) + + doc.updateData([ + "vector3": FieldValue.vector([-1, -200, -999.0]), + ]) + + querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, false) + XCTAssertEqual( + querySnap.documents[0].data()["vector0"] as! VectorValue, + FieldValue.vector([0.0]) + ) + XCTAssertEqual( + querySnap.documents[0].data()["vector1"] as! VectorValue, + FieldValue.vector([1, 2, 3.99]) + ) + XCTAssertEqual( + querySnap.documents[0].data()["vector2"] as! VectorValue, + FieldValue.vector([0.0, 0, 0]) + ) + XCTAssertEqual( + querySnap.documents[0].data()["vector3"] as! VectorValue, + FieldValue.vector([-1, -200, -999.0]) + ) + + doc.delete() + querySnap = eventAccumulator.awaitEvent(withName: "snapshot") as! QuerySnapshot + XCTAssertEqual(querySnap.isEmpty, true) + + eventAccumulator.assertNoAdditionalEvents() + registration.remove() + } } diff --git a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift index 9e4642f1182..4d50d89c6e1 100644 --- a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift @@ -14,202 +14,225 @@ * limitations under the License. */ +import Combine import FirebaseFirestore import Foundation -import Combine +// iOS 15 required for test implementation, not vector feature @available(iOS 15, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *) class VectorIntegrationTests: FSTIntegrationTestCase { - func testWriteAndReadVectorEmbeddings() async throws { - let collection = collectionRef() - - let ref = try await collection.addDocument(data: [ - "vector0": FieldValue.vector([0.0]), - "vector1": FieldValue.vector([1, 2, 3.99])]) - - try await ref.setData([ - "vector0": FieldValue.vector([0.0]), - "vector1": FieldValue.vector([1, 2, 3.99]), - "vector2": FieldValue.vector([0, 0, 0] as Array)]) - - try await ref.updateData([ - "vector3": FieldValue.vector([-1, -200, -999] as Array)]) - - let snapshot = try await ref.getDocument(); - XCTAssertEqual(snapshot.get("vector0") as? VectorValue, FieldValue.vector([0.0])) - XCTAssertEqual(snapshot.get("vector1") as? VectorValue, FieldValue.vector([1, 2, 3.99])) - XCTAssertEqual(snapshot.get("vector2") as? VectorValue, FieldValue.vector([0, 0, 0] as Array)) - XCTAssertEqual(snapshot.get("vector3") as? VectorValue, FieldValue.vector([-1, -200, -999] as Array)) + func testWriteAndReadVectorEmbeddings() async throws { + let collection = collectionRef() + + let ref = try await collection.addDocument(data: [ + "vector0": FieldValue.vector([0.0]), + "vector1": FieldValue.vector([1, 2, 3.99]), + ]) + + try await ref.setData([ + "vector0": FieldValue.vector([0.0]), + "vector1": FieldValue.vector([1, 2, 3.99]), + "vector2": FieldValue.vector([0, 0, 0] as [Double]), + ]) + + try await ref.updateData([ + "vector3": FieldValue.vector([-1, -200, -999] as [Double]), + ]) + + let snapshot = try await ref.getDocument() + XCTAssertEqual(snapshot.get("vector0") as? VectorValue, FieldValue.vector([0.0])) + XCTAssertEqual(snapshot.get("vector1") as? VectorValue, FieldValue.vector([1, 2, 3.99])) + XCTAssertEqual( + snapshot.get("vector2") as? VectorValue, + FieldValue.vector([0, 0, 0] as [Double]) + ) + XCTAssertEqual( + snapshot.get("vector3") as? VectorValue, + FieldValue.vector([-1, -200, -999] as [Double]) + ) + } + + func testSdkOrdersVectorFieldSameWayAsBackend() async throws { + let collection = collectionRef() + + var docsInOrder: [[String: Any]] = [ + ["embedding": [1, 2, 3, 4, 5, 6]], + ["embedding": [100]], + ["embedding": FieldValue.vector([Double.infinity * -1])], + ["embedding": FieldValue.vector([-100.0])], + ["embedding": FieldValue.vector([100.0])], + ["embedding": FieldValue.vector([Double.infinity])], + ["embedding": FieldValue.vector([1, 2.0])], + ["embedding": FieldValue.vector([2, 2.0])], + ["embedding": FieldValue.vector([1, 2, 3.0])], + ["embedding": FieldValue.vector([1, 2, 3, 4.0])], + ["embedding": FieldValue.vector([1, 2, 3, 4, 5.0])], + ["embedding": FieldValue.vector([1, 2, 100, 4, 4.0])], + ["embedding": FieldValue.vector([100, 2, 3, 4, 5.0])], + ["embedding": ["HELLO": "WORLD"]], + ["embedding": ["hello": "world"]], + ] + + var docs: [[String: Any]] = [] + for data in docsInOrder { + let docRef = try await collection.addDocument(data: data) + docs.append(["id": docRef.documentID, "value": data]) } - - func testSdkOrdersVectorFieldSameWayAsBackend() async throws { - let collection = collectionRef() - - var docsInOrder: [[String:Any]] = [ - ["embedding": [1, 2, 3, 4, 5, 6]], - ["embedding": [100] ], - ["embedding": FieldValue.vector([Double.infinity * -1]) ], - ["embedding": FieldValue.vector([-100.0]) ], - ["embedding": FieldValue.vector([100.0]) ], - ["embedding": FieldValue.vector([Double.infinity]) ], - ["embedding": FieldValue.vector([1, 2.0]) ], - ["embedding": FieldValue.vector([2, 2.0]) ], - ["embedding": FieldValue.vector([1, 2, 3.0]) ], - ["embedding": FieldValue.vector([1, 2, 3, 4.0]) ], - ["embedding": FieldValue.vector([1, 2, 3, 4, 5.0]) ], - ["embedding": FieldValue.vector([1, 2, 100, 4, 4.0]) ], - ["embedding": FieldValue.vector([100, 2, 3, 4, 5.0]) ], - ["embedding": [ "HELLO": "WORLD" ] ], - ["embedding": [ "hello": "world" ] ] - ] - - var docs: [[String:Any]] = []; - for data in docsInOrder { - let docRef = try await collection.addDocument(data: data) - docs.append(["id": docRef.documentID, "value": data]) - } - - // We validate that the SDK orders the vector field the same way as the backend - // by comparing the sort order of vector fields from getDocsFromServer and - // onSnapshot. onSnapshot will return sort order of the SDK, - // and getDocsFromServer will return sort order of the backend. - - let orderedQuery = collection.order(by: "embedding") - - - let watchSnapshot = try await Future() { promise in - orderedQuery.addSnapshotListener { snapshot, error in - if let error { - promise(Result.failure(error)) - } - if let snapshot { - promise(Result.success(snapshot)) - } - } - }.value - - let getSnapshot = try await orderedQuery.getDocuments(source: .server) - - // Compare the snapshot (including sort order) of a snapshot - // from Query.onSnapshot() to an actual snapshot from Query.get() - XCTAssertEqual(watchSnapshot.count, getSnapshot.count) - for i in (0..() { promise in + orderedQuery.addSnapshotListener { snapshot, error in + if let error { + promise(Result.failure(error)) } - - // Compare the snapshot (including sort order) of a snapshot - // from Query.onSnapshot() to the expected sort order from - // the backend. - XCTAssertEqual(watchSnapshot.count, docs.count) - for i in (0..: Decodable { - var distance: Double - var data: T - - private enum CodingKeys: String, CodingKey { - case distance + + checkOnlineAndOfflineQuery(collection.order(by: "embedding"), matchesResult: docIds) + checkOnlineAndOfflineQuery( + collection.order(by: "embedding") + .whereField("embedding", isLessThan: FieldValue.vector([1, 2, 100, 4, 4.0])), + matchesResult: Array(docIds[2 ... 10]) + ) + checkOnlineAndOfflineQuery( + collection.order(by: "embedding") + .whereField("embedding", isGreaterThan: FieldValue.vector([1, 2, 100, 4, 4.0])), + matchesResult: Array(docIds[12 ... 12]) + ) + } + + func testQueryVectorValueWrittenByCodable() async throws { + let collection = collectionRef() + + struct Model: Codable { + var name: String + var embedding: VectorValue + } + let model = Model( + name: "name", + embedding: FieldValue.vector([0.1, 0.3, 0.4]) + ) + + try collection.document().setData(from: model) + + let querySnap: QuerySnapshot = try await collection.whereField( + "embedding", + isEqualTo: FieldValue.vector([0.1, 0.3, 0.4]) + ).getDocuments() + + XCTAssertEqual(1, querySnap.count) + + let returnedModel: Model = try querySnap.documents[0].data(as: Model.self) + XCTAssertEqual(returnedModel.embedding, VectorValue([0.1, 0.3, 0.4])) + + let vectorData: [Double] = returnedModel.embedding.data + XCTAssertEqual(vectorData, [0.1, 0.3, 0.4]) + } + + func testQueryVectorValueWrittenByCodableClass() async throws { + let collection = collectionRef() + + struct Model: Codable { + var name: String + var embedding: VectorValue + } + + struct ModelWithDistance: Codable { + var name: String + var embedding: VectorValue + var distance: Double } - - init(from decoder: Decoder) throws { + + struct WithDistance: Decodable { + var distance: Double + var data: T + + private enum CodingKeys: String, CodingKey { + case distance + } + + init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: CodingKeys.self) distance = try container.decode(Double.self, forKey: .distance) data = try T(from: decoder) + } } -} - - let model = ModelWithDistance( - name: "name", - embedding: FieldValue.vector([0.1, 0.3, 0.4]), - distance: 0.2 - ) - - try collection.document().setData(from: model) - - let querySnap: QuerySnapshot = try await collection.getDocuments() - - XCTAssertEqual(1, querySnap.count) - - let returnedModel: WithDistance = - try querySnap.documents[0].data(as: WithDistance.self) - XCTAssertEqual(returnedModel.data.embedding, VectorValue([0.1, 0.3, 0.4])) - XCTAssertEqual(returnedModel.distance, 0.2) - - let vectorData: [Double] = returnedModel.data.embedding.data; - XCTAssertEqual(vectorData, [0.1, 0.3, 0.4]) - } + + let model = ModelWithDistance( + name: "name", + embedding: FieldValue.vector([0.1, 0.3, 0.4]), + distance: 0.2 + ) + + try collection.document().setData(from: model) + + let querySnap: QuerySnapshot = try await collection.getDocuments() + + XCTAssertEqual(1, querySnap.count) + + let returnedModel: WithDistance = + try querySnap.documents[0].data(as: WithDistance.self) + XCTAssertEqual(returnedModel.data.embedding, VectorValue([0.1, 0.3, 0.4])) + XCTAssertEqual(returnedModel.distance, 0.2) + + let vectorData: [Double] = returnedModel.data.embedding.data + XCTAssertEqual(vectorData, [0.1, 0.3, 0.4]) + } } diff --git a/Firestore/core/src/core/target.cc b/Firestore/core/src/core/target.cc index 54a9e127728..3002a955da1 100644 --- a/Firestore/core/src/core/target.cc +++ b/Firestore/core/src/core/target.cc @@ -219,8 +219,7 @@ Target::IndexBoundValue Target::GetAscendingBound( switch (field_filter.op()) { case FieldFilter::Operator::LessThan: case FieldFilter::Operator::LessThanOrEqual: - filter_value = - model::GetLowerBound(field_filter.value()); + filter_value = model::GetLowerBound(field_filter.value()); break; case FieldFilter::Operator::Equal: case FieldFilter::Operator::In: @@ -284,8 +283,7 @@ Target::IndexBoundValue Target::GetDescendingBound( switch (field_filter.op()) { case FieldFilter::Operator::GreaterThanOrEqual: case FieldFilter::Operator::GreaterThan: - filter_value = - model::GetUpperBound(field_filter.value()); + filter_value = model::GetUpperBound(field_filter.value()); filter_inclusive = false; break; case FieldFilter::Operator::Equal: diff --git a/Firestore/core/src/index/firestore_index_value_writer.cc b/Firestore/core/src/index/firestore_index_value_writer.cc index 444a06ecb26..24a261a3d84 100644 --- a/Firestore/core/src/index/firestore_index_value_writer.cc +++ b/Firestore/core/src/index/firestore_index_value_writer.cc @@ -47,7 +47,7 @@ enum IndexType { kReference = 37, kGeopoint = 45, kArray = 50, - kVector = 53, + kVector = 53, kMap = 55, kReferenceSegment = 60, // A terminator that indicates that a truncatable value was not truncated. @@ -108,24 +108,28 @@ void WriteIndexArray(const google_firestore_v1_ArrayValue& array_index_value, } void WriteIndexVector(const google_firestore_v1_MapValue& map_index_value, - DirectionalIndexByteEncoder* encoder) { - WriteValueTypeLabel(encoder, IndexType::kVector); - - int64_t valueIndex = model::IndexOfKey(map_index_value, model::kRawVectorValueFieldKey, model::kVectorValueFieldKey); - - if (valueIndex < 0 || map_index_value.fields[valueIndex].value.which_value_type != google_firestore_v1_Value_array_value_tag) { - return WriteIndexArray(model::MinArray().array_value, encoder); - } - - auto value = map_index_value.fields[valueIndex].value; - - // Vectors sort first by length - WriteValueTypeLabel(encoder, IndexType::kNumber); - encoder->WriteLong(value.array_value.values_count); - - // Vectors then sort by position value - WriteIndexString(model::kVectorValueFieldKey, encoder); - WriteIndexValueAux(value, encoder); + DirectionalIndexByteEncoder* encoder) { + WriteValueTypeLabel(encoder, IndexType::kVector); + + int64_t valueIndex = + model::IndexOfKey(map_index_value, model::kRawVectorValueFieldKey, + model::kVectorValueFieldKey); + + if (valueIndex < 0 || + map_index_value.fields[valueIndex].value.which_value_type != + google_firestore_v1_Value_array_value_tag) { + return WriteIndexArray(model::MinArray().array_value, encoder); + } + + auto value = map_index_value.fields[valueIndex].value; + + // Vectors sort first by length + WriteValueTypeLabel(encoder, IndexType::kNumber); + encoder->WriteLong(value.array_value.values_count); + + // Vectors then sort by position value + WriteIndexString(model::kVectorValueFieldKey, encoder); + WriteIndexValueAux(value, encoder); } void WriteIndexMap(google_firestore_v1_MapValue map_index_value, @@ -206,9 +210,9 @@ void WriteIndexValueAux(const google_firestore_v1_Value& index_value, if (model::IsMaxValue(index_value)) { WriteValueTypeLabel(encoder, std::numeric_limits::max()); break; - } else if(model::IsVectorValue(index_value)) { - WriteIndexVector(index_value.map_value, encoder); - break; + } else if (model::IsVectorValue(index_value)) { + WriteIndexVector(index_value.map_value, encoder); + break; } WriteIndexMap(index_value.map_value, encoder); WriteTruncationMarker(encoder); diff --git a/Firestore/core/src/model/value_util.cc b/Firestore/core/src/model/value_util.cc index 9dace80b301..2f3ce990449 100644 --- a/Firestore/core/src/model/value_util.cc +++ b/Firestore/core/src/model/value_util.cc @@ -92,9 +92,9 @@ TypeOrder GetTypeOrder(const google_firestore_v1_Value& value) { case google_firestore_v1_Value_geo_point_value_tag: return TypeOrder::kGeoPoint; - - case google_firestore_v1_Value_array_value_tag: - return TypeOrder::kArray; + + case google_firestore_v1_Value_array_value_tag: + return TypeOrder::kArray; case google_firestore_v1_Value_map_value_tag: { if (IsServerTimestamp(value)) { @@ -102,7 +102,7 @@ TypeOrder GetTypeOrder(const google_firestore_v1_Value& value) { } else if (IsMaxValue(value)) { return TypeOrder::kMaxValue; } else if (IsVectorValue(value)) { - return TypeOrder::kVector; + return TypeOrder::kVector; } return TypeOrder::kMap; } @@ -263,21 +263,26 @@ ComparisonResult CompareMaps(const google_firestore_v1_MapValue& left, } ComparisonResult CompareVectors(const google_firestore_v1_Value& left, - const google_firestore_v1_Value& right) { - HARD_ASSERT(IsVectorValue(left) && IsVectorValue(right), "Cannot compare non-vector values as vectors."); - - int64_t leftIndex = IndexOfKey(left.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey); - int64_t rightIndex = IndexOfKey(right.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey); - - google_firestore_v1_Value leftArray = left.map_value.fields[leftIndex].value; - google_firestore_v1_Value rightArray = right.map_value.fields[rightIndex].value; - - ComparisonResult lengthCompare = util::Compare(leftArray.array_value.values_count, rightArray.array_value.values_count); - if (lengthCompare != ComparisonResult::Same) { - return lengthCompare; - } + const google_firestore_v1_Value& right) { + HARD_ASSERT(IsVectorValue(left) && IsVectorValue(right), + "Cannot compare non-vector values as vectors."); + + int64_t leftIndex = + IndexOfKey(left.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey); + int64_t rightIndex = IndexOfKey(right.map_value, kRawVectorValueFieldKey, + kVectorValueFieldKey); + + google_firestore_v1_Value leftArray = left.map_value.fields[leftIndex].value; + google_firestore_v1_Value rightArray = + right.map_value.fields[rightIndex].value; + + ComparisonResult lengthCompare = util::Compare( + leftArray.array_value.values_count, rightArray.array_value.values_count); + if (lengthCompare != ComparisonResult::Same) { + return lengthCompare; + } - return CompareArrays(leftArray, rightArray); + return CompareArrays(leftArray, rightArray); } ComparisonResult Compare(const google_firestore_v1_Value& left, @@ -320,12 +325,12 @@ ComparisonResult Compare(const google_firestore_v1_Value& left, case TypeOrder::kArray: return CompareArrays(left, right); - - case TypeOrder::kMap: - return CompareMaps(left.map_value, right.map_value); - - case TypeOrder::kVector: - return CompareVectors(left, right); + + case TypeOrder::kMap: + return CompareMaps(left.map_value, right.map_value); + + case TypeOrder::kVector: + return CompareVectors(left, right); case TypeOrder::kMaxValue: return util::ComparisonResult::Same; @@ -455,7 +460,7 @@ bool Equals(const google_firestore_v1_Value& lhs, case TypeOrder::kArray: return ArrayEquals(lhs.array_value, rhs.array_value); - case TypeOrder::kVector: + case TypeOrder::kVector: case TypeOrder::kMap: return MapValueEquals(lhs.map_value, rhs.map_value); @@ -570,50 +575,51 @@ std::string CanonicalId(const google_firestore_v1_ArrayValue& value) { return CanonifyArray(value); } -google_firestore_v1_Value GetLowerBound(const google_firestore_v1_Value& value) { +google_firestore_v1_Value GetLowerBound( + const google_firestore_v1_Value& value) { switch (value.which_value_type) { case google_firestore_v1_Value_null_value_tag: return NullValue(); case google_firestore_v1_Value_boolean_value_tag: { - return MinBoolean(); + return MinBoolean(); } case google_firestore_v1_Value_integer_value_tag: case google_firestore_v1_Value_double_value_tag: { - return MinNumber(); + return MinNumber(); } case google_firestore_v1_Value_timestamp_value_tag: { - return MinTimestamp(); + return MinTimestamp(); } case google_firestore_v1_Value_string_value_tag: { - return MinString(); + return MinString(); } case google_firestore_v1_Value_bytes_value_tag: { - return MinBytes(); + return MinBytes(); } case google_firestore_v1_Value_reference_value_tag: { - return MinReference(); + return MinReference(); } case google_firestore_v1_Value_geo_point_value_tag: { - return MinGeoPoint(); + return MinGeoPoint(); } case google_firestore_v1_Value_array_value_tag: { - return MinArray(); + return MinArray(); } case google_firestore_v1_Value_map_value_tag: { - if (IsVectorValue(value)) { - return MinVector(); - } - - return MinMap(); + if (IsVectorValue(value)) { + return MinVector(); + } + + return MinMap(); } default: @@ -621,31 +627,32 @@ google_firestore_v1_Value GetLowerBound(const google_firestore_v1_Value& value) } } -google_firestore_v1_Value GetUpperBound(const google_firestore_v1_Value& value) { +google_firestore_v1_Value GetUpperBound( + const google_firestore_v1_Value& value) { switch (value.which_value_type) { case google_firestore_v1_Value_null_value_tag: - return MinBoolean(); + return MinBoolean(); case google_firestore_v1_Value_boolean_value_tag: - return MinNumber(); + return MinNumber(); case google_firestore_v1_Value_integer_value_tag: case google_firestore_v1_Value_double_value_tag: - return MinTimestamp(); + return MinTimestamp(); case google_firestore_v1_Value_timestamp_value_tag: - return MinString(); + return MinString(); case google_firestore_v1_Value_string_value_tag: - return MinBytes(); + return MinBytes(); case google_firestore_v1_Value_bytes_value_tag: - return MinReference(); + return MinReference(); case google_firestore_v1_Value_reference_value_tag: - return MinGeoPoint(); + return MinGeoPoint(); case google_firestore_v1_Value_geo_point_value_tag: - return MinArray(); + return MinArray(); case google_firestore_v1_Value_array_value_tag: - return MinVector(); + return MinVector(); case google_firestore_v1_Value_map_value_tag: - if (IsVectorValue(value)) { - return MinMap(); - } + if (IsVectorValue(value)) { + return MinMap(); + } return MaxValue(); default: HARD_FAIL("Invalid type value: %s", value.which_value_type); @@ -746,55 +753,60 @@ bool IsMaxValue(const google_firestore_v1_Value& value) { kRawMaxValueFieldValue; } -int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, const char* kRawTypeValueFieldKey, +int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, + const char* kRawTypeValueFieldKey, pb_bytes_array_s* kTypeValueFieldKey) { - for (pb_size_t i = 0; i < mapValue.fields_count; i++) { - if (mapValue.fields[i].key == kTypeValueFieldKey || - nanopb::MakeStringView(mapValue.fields[i].key) == - kRawTypeValueFieldKey) { - return i; - } + for (pb_size_t i = 0; i < mapValue.fields_count; i++) { + if (mapValue.fields[i].key == kTypeValueFieldKey || + nanopb::MakeStringView(mapValue.fields[i].key) == + kRawTypeValueFieldKey) { + return i; } - - return -1; + } + + return -1; } bool IsVectorValue(const google_firestore_v1_Value& value) { - if (value.which_value_type != google_firestore_v1_Value_map_value_tag) { - return false; - } - - if (value.map_value.fields_count < 2) { - return false; - } - - int64_t typeFieldIndex = -1; - if ((typeFieldIndex = IndexOfKey(value.map_value, kRawTypeValueFieldKey, kTypeValueFieldKey)) < 0) { - return false; - } - - if (value.map_value.fields[typeFieldIndex].value.which_value_type != - google_firestore_v1_Value_string_value_tag) { - return false; - } - - // Comparing the pointer address, then actual content if addresses are - // different. - return value.map_value.fields[typeFieldIndex].value.string_value == kVectorTypeFieldValue || - nanopb::MakeStringView(value.map_value.fields[typeFieldIndex].value.string_value) == - kRawVectorTypeFieldValue; - - int64_t valueFieldIndex = -1; - if ((valueFieldIndex = IndexOfKey(value.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey)) < 0) { - return false; - } - - if (value.map_value.fields[valueFieldIndex].value.which_value_type != - google_firestore_v1_Value_map_value_tag) { - return false; - } - - return true; + if (value.which_value_type != google_firestore_v1_Value_map_value_tag) { + return false; + } + + if (value.map_value.fields_count < 2) { + return false; + } + + int64_t typeFieldIndex = -1; + if ((typeFieldIndex = IndexOfKey(value.map_value, kRawTypeValueFieldKey, + kTypeValueFieldKey)) < 0) { + return false; + } + + if (value.map_value.fields[typeFieldIndex].value.which_value_type != + google_firestore_v1_Value_string_value_tag) { + return false; + } + + // Comparing the pointer address, then actual content if addresses are + // different. + return value.map_value.fields[typeFieldIndex].value.string_value == + kVectorTypeFieldValue || + nanopb::MakeStringView( + value.map_value.fields[typeFieldIndex].value.string_value) == + kRawVectorTypeFieldValue; + + int64_t valueFieldIndex = -1; + if ((valueFieldIndex = IndexOfKey(value.map_value, kRawVectorValueFieldKey, + kVectorValueFieldKey)) < 0) { + return false; + } + + if (value.map_value.fields[valueFieldIndex].value.which_value_type != + google_firestore_v1_Value_map_value_tag) { + return false; + } + + return true; } google_firestore_v1_Value NaNValue() { @@ -866,38 +878,39 @@ google_firestore_v1_Value MinArray() { } google_firestore_v1_Value MinVector() { - google_firestore_v1_Value typeValue; - typeValue.which_value_type = google_firestore_v1_Value_string_value_tag; - typeValue.string_value = kVectorTypeFieldValue; - - google_firestore_v1_MapValue_FieldsEntry *field_entries = nanopb::MakeArray(2); - field_entries[0].key = kTypeValueFieldKey; - field_entries[0].value = typeValue; - - google_firestore_v1_Value arrayValue; - arrayValue.which_value_type = google_firestore_v1_Value_array_value_tag; - arrayValue.array_value.values = nullptr; - arrayValue.array_value.values_count = 0; - field_entries[1].key = kVectorValueFieldKey; - field_entries[1].value = arrayValue; - - google_firestore_v1_MapValue map_value; - map_value.fields_count = 2; - map_value.fields = field_entries; - - google_firestore_v1_Value lowerBound; - lowerBound.which_value_type = google_firestore_v1_Value_map_value_tag; - lowerBound.map_value = map_value; - - return lowerBound; -} - + google_firestore_v1_Value typeValue; + typeValue.which_value_type = google_firestore_v1_Value_string_value_tag; + typeValue.string_value = kVectorTypeFieldValue; + + google_firestore_v1_MapValue_FieldsEntry* field_entries = + nanopb::MakeArray(2); + field_entries[0].key = kTypeValueFieldKey; + field_entries[0].value = typeValue; + + google_firestore_v1_Value arrayValue; + arrayValue.which_value_type = google_firestore_v1_Value_array_value_tag; + arrayValue.array_value.values = nullptr; + arrayValue.array_value.values_count = 0; + field_entries[1].key = kVectorValueFieldKey; + field_entries[1].value = arrayValue; + + google_firestore_v1_MapValue map_value; + map_value.fields_count = 2; + map_value.fields = field_entries; + + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_map_value_tag; + lowerBound.map_value = map_value; + + return lowerBound; +} + google_firestore_v1_Value MinMap() { - google_firestore_v1_Value lowerBound; - lowerBound.which_value_type = google_firestore_v1_Value_map_value_tag; - lowerBound.map_value.fields = nullptr; - lowerBound.map_value.fields_count = 0; - return lowerBound; + google_firestore_v1_Value lowerBound; + lowerBound.which_value_type = google_firestore_v1_Value_map_value_tag; + lowerBound.map_value.fields = nullptr; + lowerBound.map_value.fields_count = 0; + return lowerBound; } Message RefValue( diff --git a/Firestore/core/src/model/value_util.h b/Firestore/core/src/model/value_util.h index 44a37f20067..49ff85f3d55 100644 --- a/Firestore/core/src/model/value_util.h +++ b/Firestore/core/src/model/value_util.h @@ -62,19 +62,19 @@ extern pb_bytes_array_s* kVectorValueFieldKey; * ordering, but modified to support server timestamps. */ enum class TypeOrder { - kNull = 0, - kBoolean = 1, - kNumber = 2, - kTimestamp = 3, - kServerTimestamp = 4, - kString = 5, - kBlob = 6, - kReference = 7, - kGeoPoint = 8, - kArray = 9, - kVector = 10, - kMap = 11, - kMaxValue = 12 + kNull = 0, + kBoolean = 1, + kNumber = 2, + kTimestamp = 3, + kServerTimestamp = 4, + kString = 5, + kBlob = 6, + kReference = 7, + kGeoPoint = 8, + kArray = 9, + kVector = 10, + kMap = 11, + kMaxValue = 12 }; /** Returns the backend's type order of the given Value type. */ @@ -181,7 +181,8 @@ bool IsMaxValue(const google_firestore_v1_Value& value); */ bool IsVectorValue(const google_firestore_v1_Value& value); -int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, const char* kRawTypeValueFieldKey, +int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, + const char* kRawTypeValueFieldKey, pb_bytes_array_s* kTypeValueFieldKey); /** @@ -212,7 +213,7 @@ google_firestore_v1_Value MinGeoPoint(); google_firestore_v1_Value MinArray(); google_firestore_v1_Value MinVector(); - + google_firestore_v1_Value MinMap(); /** diff --git a/Firestore/core/test/unit/local/leveldb_index_manager_test.cc b/Firestore/core/test/unit/local/leveldb_index_manager_test.cc index 175f681e512..1f2c8de44d2 100644 --- a/Firestore/core/test/unit/local/leveldb_index_manager_test.cc +++ b/Firestore/core/test/unit/local/leveldb_index_manager_test.cc @@ -49,8 +49,8 @@ using testutil::Map; using testutil::OrderBy; using testutil::OrFilters; using testutil::Query; -using testutil::Version; using testutil::VectorType; +using testutil::Version; std::unique_ptr PersistenceFactory() { return LevelDbPersistenceForTesting(); @@ -931,42 +931,49 @@ TEST_F(LevelDbIndexManagerTest, IndexEntriesAreUpdatedWithDeletedDoc) { } TEST_F(LevelDbIndexManagerTest, IndexVectorValueFields) { - persistence_->Run("TestIndexVectorValueFields", [&]() { - index_manager_->Start(); - index_manager_->AddFieldIndex(MakeFieldIndex("coll", "embedding", model::Segment::kAscending)); - - AddDoc("coll/arr1", Map("embedding", Array(1.0, 2.0, 3.0))); - AddDoc("coll/map2", Map("embedding", Map())); - AddDoc("coll/doc3", Map("embedding", VectorType(4.0, 5.0, 6.0))); - AddDoc("coll/doc4", Map("embedding", VectorType(5.0))); - - auto query = Query("coll").AddingOrderBy(OrderBy("embedding")); - { - SCOPED_TRACE("no filter"); - VerifyResults(query, {"coll/arr1", "coll/doc4", "coll/doc3", "coll/map2"}); - } - - query = Query("coll").AddingOrderBy(OrderBy("embedding")) + persistence_->Run("TestIndexVectorValueFields", [&]() { + index_manager_->Start(); + index_manager_->AddFieldIndex( + MakeFieldIndex("coll", "embedding", model::Segment::kAscending)); + + AddDoc("coll/arr1", Map("embedding", Array(1.0, 2.0, 3.0))); + AddDoc("coll/map2", Map("embedding", Map())); + AddDoc("coll/doc3", Map("embedding", VectorType(4.0, 5.0, 6.0))); + AddDoc("coll/doc4", Map("embedding", VectorType(5.0))); + + auto query = Query("coll").AddingOrderBy(OrderBy("embedding")); + { + SCOPED_TRACE("no filter"); + VerifyResults(query, + {"coll/arr1", "coll/doc4", "coll/doc3", "coll/map2"}); + } + + query = + Query("coll") + .AddingOrderBy(OrderBy("embedding")) .AddingFilter(Filter("embedding", "==", VectorType(4.0, 5.0, 6.0))); - { - SCOPED_TRACE("vector<4.0, 5.0, 6.0>"); - VerifyResults(query, {"coll/doc3"}); - } - - query = Query("coll").AddingOrderBy(OrderBy("embedding")) + { + SCOPED_TRACE("vector<4.0, 5.0, 6.0>"); + VerifyResults(query, {"coll/doc3"}); + } + + query = + Query("coll") + .AddingOrderBy(OrderBy("embedding")) .AddingFilter(Filter("embedding", ">", VectorType(4.0, 5.0, 6.0))); - { - SCOPED_TRACE("> vector<4.0, 5.0, 6.0>"); - VerifyResults(query, {}); - } - - query = Query("coll").AddingOrderBy(OrderBy("embedding")) - .AddingFilter(Filter("embedding", ">", VectorType(4.0))); - { - SCOPED_TRACE("> vector<4.0>"); - VerifyResults(query, {"coll/doc4", "coll/doc3"}); - } - }); + { + SCOPED_TRACE("> vector<4.0, 5.0, 6.0>"); + VerifyResults(query, {}); + } + + query = Query("coll") + .AddingOrderBy(OrderBy("embedding")) + .AddingFilter(Filter("embedding", ">", VectorType(4.0))); + { + SCOPED_TRACE("> vector<4.0>"); + VerifyResults(query, {"coll/doc4", "coll/doc3"}); + } + }); } TEST_F(LevelDbIndexManagerTest, AdvancedQueries) { diff --git a/Firestore/core/test/unit/model/value_util_test.cc b/Firestore/core/test/unit/model/value_util_test.cc index 465d20d0e02..c6d2479929c 100644 --- a/Firestore/core/test/unit/model/value_util_test.cc +++ b/Firestore/core/test/unit/model/value_util_test.cc @@ -99,9 +99,9 @@ class ValueUtilTest : public ::testing::Test { ComparisonResult expected_result) { for (pb_size_t i = 0; i < left->values_count; ++i) { for (pb_size_t j = 0; j < right->values_count; ++j) { - if (expected_result != Compare(left->values[i], right->values[j])) { - std::cout << "here" << std::endl; - } + if (expected_result != Compare(left->values[i], right->values[j])) { + std::cout << "here" << std::endl; + } EXPECT_EQ(expected_result, Compare(left->values[i], right->values[j])) << "Order check failed for '" << CanonicalId(left->values[i]) << "' and '" << CanonicalId(right->values[j]) << "' (expected " @@ -246,7 +246,8 @@ TEST_F(ValueUtilTest, Equality) { Add(equals_group, Array("foo", "bar"), Array("foo", "bar")); Add(equals_group, Array("foo", "bar", "baz")); Add(equals_group, Array("foo")); - Add(equals_group, Map("__type__", "__vector__", "value", Array()), DeepClone(MinVector())); + Add(equals_group, Map("__type__", "__vector__", "value", Array()), + DeepClone(MinVector())); Add(equals_group, Map("bar", 1, "foo", 2), Map("bar", 1, "foo", 2)); Add(equals_group, Map("bar", 2, "foo", 1)); Add(equals_group, Map("bar", 1)); @@ -275,8 +276,7 @@ TEST_F(ValueUtilTest, StrictOrdering) { Add(comparison_groups, true); // numbers - Add(comparison_groups, - DeepClone(MinNumber())); + Add(comparison_groups, DeepClone(MinNumber())); Add(comparison_groups, -1e20); Add(comparison_groups, std::numeric_limits::min()); Add(comparison_groups, -0.1); @@ -289,8 +289,7 @@ TEST_F(ValueUtilTest, StrictOrdering) { Add(comparison_groups, 1e20); // dates - Add(comparison_groups, - DeepClone(MinTimestamp())); + Add(comparison_groups, DeepClone(MinTimestamp())); Add(comparison_groups, kTimestamp1); Add(comparison_groups, kTimestamp2); @@ -320,8 +319,7 @@ TEST_F(ValueUtilTest, StrictOrdering) { Add(comparison_groups, BlobValue(255)); // resource names - Add(comparison_groups, - DeepClone(MinReference())); + Add(comparison_groups, DeepClone(MinReference())); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c1/doc1"))); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c1/doc2"))); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c10/doc1"))); @@ -344,30 +342,28 @@ TEST_F(ValueUtilTest, StrictOrdering) { Add(comparison_groups, GeoPoint(90, 180)); // arrays - Add(comparison_groups, - DeepClone(MinArray())); + Add(comparison_groups, DeepClone(MinArray())); Add(comparison_groups, Array("bar")); Add(comparison_groups, Array("foo", 1)); Add(comparison_groups, Array("foo", 2)); Add(comparison_groups, Array("foo", "0")); - + // vectors - Add(comparison_groups, - DeepClone(MinVector())); - Add(comparison_groups, Map("__type__", "__vector__", "value", Array(100))); - Add(comparison_groups, Map("__type__", "__vector__", "value", Array(1.0, 2.0, 3.0))); - Add(comparison_groups, Map("__type__", "__vector__", "value", Array(1.0, 3.0, 2.0))); + Add(comparison_groups, DeepClone(MinVector())); + Add(comparison_groups, Map("__type__", "__vector__", "value", Array(100))); + Add(comparison_groups, + Map("__type__", "__vector__", "value", Array(1.0, 2.0, 3.0))); + Add(comparison_groups, + Map("__type__", "__vector__", "value", Array(1.0, 3.0, 2.0))); // objects - Add(comparison_groups, - DeepClone(MinMap())); + Add(comparison_groups, DeepClone(MinMap())); Add(comparison_groups, Map("bar", 0)); Add(comparison_groups, Map("bar", 0, "foo", 1)); Add(comparison_groups, Map("foo", 1)); Add(comparison_groups, Map("foo", 2)); Add(comparison_groups, Map("foo", "0")); - Add(comparison_groups, - DeepClone(MaxValue())); + Add(comparison_groups, DeepClone(MaxValue())); for (size_t i = 0; i < comparison_groups.size(); ++i) { for (size_t j = i; j < comparison_groups.size(); ++j) { @@ -388,25 +384,19 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { std::vector> comparison_groups; // null first - Add(comparison_groups, - DeepClone(NullValue())); + Add(comparison_groups, DeepClone(NullValue())); Add(comparison_groups, nullptr); - Add(comparison_groups, - DeepClone(MinBoolean())); + Add(comparison_groups, DeepClone(MinBoolean())); // booleans - Add(comparison_groups, - DeepClone(MinBoolean())); + Add(comparison_groups, DeepClone(MinBoolean())); Add(comparison_groups, false); Add(comparison_groups, true); - Add(comparison_groups, - DeepClone(MinNumber())); + Add(comparison_groups, DeepClone(MinNumber())); // numbers - Add(comparison_groups, - DeepClone(MinNumber())); - Add(comparison_groups, - DeepClone(MinNumber())); + Add(comparison_groups, DeepClone(MinNumber())); + Add(comparison_groups, DeepClone(MinNumber())); Add(comparison_groups, -1e20); Add(comparison_groups, std::numeric_limits::min()); Add(comparison_groups, -0.1); @@ -417,14 +407,11 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { Add(comparison_groups, 1.0, 1L); Add(comparison_groups, std::numeric_limits::max()); Add(comparison_groups, 1e20); - Add(comparison_groups, - DeepClone(MinTimestamp())); - Add(comparison_groups, - DeepClone(MinTimestamp())); + Add(comparison_groups, DeepClone(MinTimestamp())); + Add(comparison_groups, DeepClone(MinTimestamp())); // dates - Add(comparison_groups, - DeepClone(MinTimestamp())); + Add(comparison_groups, DeepClone(MinTimestamp())); Add(comparison_groups, kTimestamp1); Add(comparison_groups, kTimestamp2); @@ -432,12 +419,10 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { // NOTE: server timestamps can't be parsed with . Add(comparison_groups, EncodeServerTimestamp(kTimestamp1, absl::nullopt)); Add(comparison_groups, EncodeServerTimestamp(kTimestamp2, absl::nullopt)); - Add(comparison_groups, - DeepClone(MinString())); + Add(comparison_groups, DeepClone(MinString())); // strings - Add(comparison_groups, - DeepClone(MinString())); + Add(comparison_groups, DeepClone(MinString())); Add(comparison_groups, ""); Add(comparison_groups, "\001\ud7ff\ue000\uffff"); Add(comparison_groups, "(╯°□°)╯︵ ┻━┻"); @@ -449,35 +434,29 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { Add(comparison_groups, "æ"); // latin small letter e with acute accent + latin small letter a Add(comparison_groups, "\u00e9a"); - Add(comparison_groups, - DeepClone(MinBytes())); + Add(comparison_groups, DeepClone(MinBytes())); // blobs - Add(comparison_groups, - DeepClone(MinBytes())); + Add(comparison_groups, DeepClone(MinBytes())); Add(comparison_groups, BlobValue()); Add(comparison_groups, BlobValue(0)); Add(comparison_groups, BlobValue(0, 1, 2, 3, 4)); Add(comparison_groups, BlobValue(0, 1, 2, 4, 3)); Add(comparison_groups, BlobValue(255)); - Add(comparison_groups, - DeepClone(MinReference())); + Add(comparison_groups, DeepClone(MinReference())); // resource names - Add(comparison_groups, - DeepClone(MinReference())); + Add(comparison_groups, DeepClone(MinReference())); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c1/doc1"))); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c1/doc2"))); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c10/doc1"))); Add(comparison_groups, RefValue(DbId("p1/d1"), Key("c2/doc1"))); Add(comparison_groups, RefValue(DbId("p1/d2"), Key("c1/doc1"))); Add(comparison_groups, RefValue(DbId("p2/d1"), Key("c1/doc1"))); - Add(comparison_groups, - DeepClone(MinGeoPoint())); + Add(comparison_groups, DeepClone(MinGeoPoint())); // geo points - Add(comparison_groups, - DeepClone(MinGeoPoint())); + Add(comparison_groups, DeepClone(MinGeoPoint())); Add(comparison_groups, GeoPoint(-90, -180)); Add(comparison_groups, GeoPoint(-90, 0)); Add(comparison_groups, GeoPoint(-90, 180)); @@ -490,36 +469,32 @@ TEST_F(ValueUtilTest, RelaxedOrdering) { Add(comparison_groups, GeoPoint(90, -180)); Add(comparison_groups, GeoPoint(90, 0)); Add(comparison_groups, GeoPoint(90, 180)); - Add(comparison_groups, - DeepClone(MinArray())); + Add(comparison_groups, DeepClone(MinArray())); // arrays - Add(comparison_groups, - DeepClone(MinArray())); + Add(comparison_groups, DeepClone(MinArray())); Add(comparison_groups, Array("bar")); Add(comparison_groups, Array("foo", 1)); Add(comparison_groups, Array("foo", 2)); Add(comparison_groups, Array("foo", "0")); + Add(comparison_groups, DeepClone(MinVector())); + + // vectors + Add(comparison_groups, DeepClone(MinVector())); + Add(comparison_groups, Map("__type__", "__vector__", "value", Array(100))); Add(comparison_groups, - DeepClone(MinVector())); - - // vectors - Add(comparison_groups, - DeepClone(MinVector())); - Add(comparison_groups, Map("__type__", "__vector__", "value", Array(100))); - Add(comparison_groups, Map("__type__", "__vector__", "value", Array(1.0, 2.0, 3.0))); - Add(comparison_groups, Map("__type__", "__vector__", "value", Array(1.0, 3.0, 2.0))); - - // objects + Map("__type__", "__vector__", "value", Array(1.0, 2.0, 3.0))); Add(comparison_groups, - DeepClone(MinMap())); + Map("__type__", "__vector__", "value", Array(1.0, 3.0, 2.0))); + + // objects + Add(comparison_groups, DeepClone(MinMap())); Add(comparison_groups, Map("bar", 0)); Add(comparison_groups, Map("bar", 0, "foo", 1)); Add(comparison_groups, Map("foo", 1)); Add(comparison_groups, Map("foo", 2)); Add(comparison_groups, Map("foo", "0")); - Add(comparison_groups, - DeepClone(MaxValue())); + Add(comparison_groups, DeepClone(MaxValue())); for (size_t i = 0; i < comparison_groups.size(); ++i) { for (size_t j = i; j < comparison_groups.size(); ++j) { @@ -533,7 +508,7 @@ TEST_F(ValueUtilTest, CanonicalId) { VerifyCanonicalId(Value(true), "true"); VerifyCanonicalId(Value(false), "false"); VerifyCanonicalId(Value(1), "1"); - VerifyCanonicalId(Value(1.0), "1.0"); + VerifyCanonicalId(Value(1.0), "1.0"); VerifyCanonicalId(Value(Timestamp(30, 1000)), "time(30,1000)"); VerifyCanonicalId(Value("a"), "a"); VerifyCanonicalId(Value(std::string("a\0b", 3)), std::string("a\0b", 3)); @@ -542,10 +517,11 @@ TEST_F(ValueUtilTest, CanonicalId) { VerifyCanonicalId(Value(GeoPoint(30, 60)), "geo(30.0,60.0)"); VerifyCanonicalId(Value(Array(1, 2, 3)), "[1,2,3]"); VerifyCanonicalId(Map("a", 1, "b", 2, "c", "3"), "{a:1,b:2,c:3}"); - VerifyCanonicalId(Map("a", Array("b", Map("c", GeoPoint(30, 60)))), - "{a:[b,{c:geo(30.0,60.0)}]}"); - VerifyCanonicalId(Map("__type__", "__vector__", "value", Array(1.0, 1.0, -2.0, 3.14)), - "{__type__:__vector__,value:[1.0,1.0,-2.0,3.1]}"); + VerifyCanonicalId(Map("a", Array("b", Map("c", GeoPoint(30, 60)))), + "{a:[b,{c:geo(30.0,60.0)}]}"); + VerifyCanonicalId( + Map("__type__", "__vector__", "value", Array(1.0, 1.0, -2.0, 3.14)), + "{__type__:__vector__,value:[1.0,1.0,-2.0,3.1]}"); } TEST_F(ValueUtilTest, DeepClone) { diff --git a/Firestore/core/test/unit/remote/serializer_test.cc b/Firestore/core/test/unit/remote/serializer_test.cc index 79fc79e99d7..0de665273ee 100644 --- a/Firestore/core/test/unit/remote/serializer_test.cc +++ b/Firestore/core/test/unit/remote/serializer_test.cc @@ -822,8 +822,8 @@ TEST_F(SerializerTest, EncodesNestedObjects) { } TEST_F(SerializerTest, EncodesVectorValue) { - Message model = Map( - "__type__", "__vector__", "value", Array(1.0, 2.0, 3.0)); + Message model = + Map("__type__", "__vector__", "value", Array(1.0, 2.0, 3.0)); v1::Value array_proto; *array_proto.mutable_array_value()->add_values() = ValueProto(1.0); diff --git a/Firestore/core/test/unit/testutil/testutil.h b/Firestore/core/test/unit/testutil/testutil.h index 259e93eb6c2..234ef3d5d12 100644 --- a/Firestore/core/test/unit/testutil/testutil.h +++ b/Firestore/core/test/unit/testutil/testutil.h @@ -290,7 +290,8 @@ nanopb::Message Map(Args... key_value_pairs) { template nanopb::Message VectorType(Args&&... values) { - return Map("__type__", "__vector__", "value", details::MakeArray(std::move(values)...)); + return Map("__type__", "__vector__", "value", + details::MakeArray(std::move(values)...)); } model::DocumentKey Key(absl::string_view path); From a0c596e6106b543863492eb6ba108598be2e417b Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:16:01 -0600 Subject: [PATCH 07/24] PR feedback on codable. --- Firestore/Source/API/FIRFieldValue.mm | 4 +-- .../Source/API/FIRVectorValue+Internal.h | 26 ------------------- .../{FIRVectorValue.m => FIRVectorValue.mm} | 26 +++++++++---------- Firestore/Source/API/FSTUserDataReader.mm | 2 +- .../Public/FirebaseFirestore/FIRVectorValue.h | 5 ++-- .../Source/Codable/VectorValue+Codable.swift | 12 ++++----- .../Source/SwiftAPI/VectorValue+Swift.swift | 6 ++--- .../Integration/VectorIntegrationTests.swift | 4 +-- 8 files changed, 30 insertions(+), 55 deletions(-) delete mode 100644 Firestore/Source/API/FIRVectorValue+Internal.h rename Firestore/Source/API/{FIRVectorValue.m => FIRVectorValue.mm} (76%) diff --git a/Firestore/Source/API/FIRFieldValue.mm b/Firestore/Source/API/FIRFieldValue.mm index 4b59e78c998..56583d11717 100644 --- a/Firestore/Source/API/FIRFieldValue.mm +++ b/Firestore/Source/API/FIRFieldValue.mm @@ -15,7 +15,7 @@ */ #import "Firestore/Source/API/FIRFieldValue+Internal.h" -#import "Firestore/Source/API/FIRVectorValue+Internal.h" +#import "Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h" NS_ASSUME_NONNULL_BEGIN @@ -178,7 +178,7 @@ + (instancetype)fieldValueForIntegerIncrement:(int64_t)l { } + (nonnull FIRVectorValue *)vectorFromNSNumbers:(nonnull NSArray *)values { - return [[FIRVectorValue alloc] initWithNSNumbers:values]; + return [[FIRVectorValue alloc] initWithArray:values]; } @end diff --git a/Firestore/Source/API/FIRVectorValue+Internal.h b/Firestore/Source/API/FIRVectorValue+Internal.h deleted file mode 100644 index 5192eceb78e..00000000000 --- a/Firestore/Source/API/FIRVectorValue+Internal.h +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#import "FIRVectorValue.h" - -NS_ASSUME_NONNULL_BEGIN - -@interface FIRVectorValue (Internal) -// - (instancetype)init NS_UNAVAILABLE; -// - (NSArray *)toNSArray; -@end - -NS_ASSUME_NONNULL_END diff --git a/Firestore/Source/API/FIRVectorValue.m b/Firestore/Source/API/FIRVectorValue.mm similarity index 76% rename from Firestore/Source/API/FIRVectorValue.m rename to Firestore/Source/API/FIRVectorValue.mm index 417a52efda4..a1214dade37 100644 --- a/Firestore/Source/API/FIRVectorValue.m +++ b/Firestore/Source/API/FIRVectorValue.mm @@ -18,7 +18,7 @@ #include -#import "Firestore/Source/API/FIRVectorValue+Internal.h" +#include "Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h" NS_ASSUME_NONNULL_BEGIN @@ -31,7 +31,18 @@ @interface FIRVectorValue () { @implementation FIRVectorValue -- (instancetype)initWithNSNumbers:(NSArray *)data { +- (NSArray *) array { + size_t length = _internalValue.size(); + NSMutableArray *outArray = + [[NSMutableArray alloc] initWithCapacity:length]; + for (size_t i = 0; i < length; i++) { + [outArray addObject:[[NSNumber alloc] initWithDouble:self->_internalValue.at(i)]]; + } + + return outArray; +} + +- (instancetype)initWithArray:(NSArray *)data { if (self = [super init]) { std::vector converted; converted.reserve(data.count); @@ -44,17 +55,6 @@ - (instancetype)initWithNSNumbers:(NSArray *)data { return self; } -- (nonnull NSArray *)toNSArray { - size_t length = _internalValue.size(); - NSMutableArray *outArray = - [[NSMutableArray alloc] initWithCapacity:length]; - for (size_t i = 0; i < length; i++) { - [outArray addObject:[[NSNumber alloc] initWithDouble:self->_internalValue.at(i)]]; - } - - return outArray; -} - - (BOOL)isEqual:(nullable id)object { if (self == object) { return YES; diff --git a/Firestore/Source/API/FSTUserDataReader.mm b/Firestore/Source/API/FSTUserDataReader.mm index 3ea08e1e847..db5d1eea093 100644 --- a/Firestore/Source/API/FSTUserDataReader.mm +++ b/Firestore/Source/API/FSTUserDataReader.mm @@ -353,7 +353,7 @@ - (ParsedUpdateData)parsedUpdateData:(id)input { result->map_value.fields[0].key = nanopb::CopyBytesArray(model::kTypeValueFieldKey); result->map_value.fields[0].value = *[self encodeStringValue:MakeString(@"__vector__")].release(); - NSArray *vectorArray = [vectorValue toNSArray]; + NSArray *vectorArray = vectorValue.array; __block Message arrayMessage; arrayMessage->which_value_type = google_firestore_v1_Value_array_value_tag; diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h index af82d447f45..28924e3b84b 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h @@ -21,13 +21,14 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(VectorValue) @interface FIRVectorValue : NSObject +@property(readonly) NSArray *array NS_REFINED_FOR_SWIFT; + /** :nodoc: */ - (instancetype)init NS_UNAVAILABLE; // Public initializer is required to support Codable -- (instancetype)initWithNSNumbers:(NSArray *)data NS_REFINED_FOR_SWIFT; +- (instancetype)initWithArray:(NSArray *)array NS_REFINED_FOR_SWIFT; -- (NSArray *)toNSArray NS_REFINED_FOR_SWIFT; @end NS_ASSUME_NONNULL_END diff --git a/Firestore/Swift/Source/Codable/VectorValue+Codable.swift b/Firestore/Swift/Source/Codable/VectorValue+Codable.swift index 7ed48ddc6a2..d10fbc593bf 100644 --- a/Firestore/Swift/Source/Codable/VectorValue+Codable.swift +++ b/Firestore/Swift/Source/Codable/VectorValue+Codable.swift @@ -24,14 +24,14 @@ * A protocol describing the encodable properties of a VectorValue. */ private protocol CodableVectorValue: Codable { - var data: [Double] { get } + var array: [Double] { get } - init(__nsNumbers: [NSNumber]) + init(__array: [NSNumber]) } /** The keys in a Timestamp. Must match the properties of CodableTimestamp. */ private enum VectorValueKeys: String, CodingKey { - case data + case array } /** @@ -43,17 +43,17 @@ private enum VectorValueKeys: String, CodingKey { extension CodableVectorValue { public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: VectorValueKeys.self) - let data = try container.decode([Double].self, forKey: .data) + let data = try container.decode([Double].self, forKey: .array) let array = data.map { double in NSNumber(value: double) } - self.init(__nsNumbers: array) + self.init(__array: array) } public func encode(to encoder: Encoder) throws { var container = encoder.container(keyedBy: VectorValueKeys.self) - try container.encode(data, forKey: .data) + try container.encode(array, forKey: .array) } } diff --git a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift index 590c6644a08..ef0424f250a 100644 --- a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift @@ -26,12 +26,12 @@ public extension VectorValue { NSNumber(value: float) } - self.init(__nsNumbers: array) + self.init(__array: array) } /// Returns a raw number array representation of the vector. /// - Returns: An array of Double values representing the vector. - var data: [Double] { - return __toNSArray().map { Double(truncating: $0) } + var array: [Double] { + return __array.map { Double(truncating: $0) } } } diff --git a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift index 4d50d89c6e1..1edec71d211 100644 --- a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift @@ -182,7 +182,7 @@ class VectorIntegrationTests: FSTIntegrationTestCase { let returnedModel: Model = try querySnap.documents[0].data(as: Model.self) XCTAssertEqual(returnedModel.embedding, VectorValue([0.1, 0.3, 0.4])) - let vectorData: [Double] = returnedModel.embedding.data + let vectorData: [Double] = returnedModel.embedding.array XCTAssertEqual(vectorData, [0.1, 0.3, 0.4]) } @@ -232,7 +232,7 @@ class VectorIntegrationTests: FSTIntegrationTestCase { XCTAssertEqual(returnedModel.data.embedding, VectorValue([0.1, 0.3, 0.4])) XCTAssertEqual(returnedModel.distance, 0.2) - let vectorData: [Double] = returnedModel.data.embedding.data + let vectorData: [Double] = returnedModel.data.embedding.array XCTAssertEqual(vectorData, [0.1, 0.3, 0.4]) } } From ca625e9e58744f77e8bbae3deff6936634d924f1 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:16:43 -0600 Subject: [PATCH 08/24] formatting --- Firestore/Source/API/FIRVectorValue.mm | 16 ++++++++-------- .../Source/SwiftAPI/VectorValue+Swift.swift | 2 +- .../Integration/VectorIntegrationTests.swift | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Firestore/Source/API/FIRVectorValue.mm b/Firestore/Source/API/FIRVectorValue.mm index a1214dade37..8c6acd86d68 100644 --- a/Firestore/Source/API/FIRVectorValue.mm +++ b/Firestore/Source/API/FIRVectorValue.mm @@ -31,15 +31,15 @@ @interface FIRVectorValue () { @implementation FIRVectorValue -- (NSArray *) array { - size_t length = _internalValue.size(); - NSMutableArray *outArray = - [[NSMutableArray alloc] initWithCapacity:length]; - for (size_t i = 0; i < length; i++) { - [outArray addObject:[[NSNumber alloc] initWithDouble:self->_internalValue.at(i)]]; - } +- (NSArray *)array { + size_t length = _internalValue.size(); + NSMutableArray *outArray = + [[NSMutableArray alloc] initWithCapacity:length]; + for (size_t i = 0; i < length; i++) { + [outArray addObject:[[NSNumber alloc] initWithDouble:self->_internalValue.at(i)]]; + } - return outArray; + return outArray; } - (instancetype)initWithArray:(NSArray *)data { diff --git a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift index ef0424f250a..3a177aab53c 100644 --- a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift @@ -32,6 +32,6 @@ public extension VectorValue { /// Returns a raw number array representation of the vector. /// - Returns: An array of Double values representing the vector. var array: [Double] { - return __array.map { Double(truncating: $0) } + return __array.map { Double(truncating: $0) } } } diff --git a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift index 1edec71d211..0cb5c2a63de 100644 --- a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift @@ -182,7 +182,7 @@ class VectorIntegrationTests: FSTIntegrationTestCase { let returnedModel: Model = try querySnap.documents[0].data(as: Model.self) XCTAssertEqual(returnedModel.embedding, VectorValue([0.1, 0.3, 0.4])) - let vectorData: [Double] = returnedModel.embedding.array + let vectorData: [Double] = returnedModel.embedding.array XCTAssertEqual(vectorData, [0.1, 0.3, 0.4]) } @@ -232,7 +232,7 @@ class VectorIntegrationTests: FSTIntegrationTestCase { XCTAssertEqual(returnedModel.data.embedding, VectorValue([0.1, 0.3, 0.4])) XCTAssertEqual(returnedModel.distance, 0.2) - let vectorData: [Double] = returnedModel.data.embedding.array + let vectorData: [Double] = returnedModel.data.embedding.array XCTAssertEqual(vectorData, [0.1, 0.3, 0.4]) } } From 190ea8c8ff888df32d13de3d0d2dd091a643a50f Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:52:21 -0600 Subject: [PATCH 09/24] PR comments c++ --- .../src/index/firestore_index_value_writer.cc | 8 +-- Firestore/core/src/model/value_util.cc | 53 +++++++++++-------- Firestore/core/src/model/value_util.h | 13 +++-- 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/Firestore/core/src/index/firestore_index_value_writer.cc b/Firestore/core/src/index/firestore_index_value_writer.cc index 24a261a3d84..e44e3118bed 100644 --- a/Firestore/core/src/index/firestore_index_value_writer.cc +++ b/Firestore/core/src/index/firestore_index_value_writer.cc @@ -111,17 +111,17 @@ void WriteIndexVector(const google_firestore_v1_MapValue& map_index_value, DirectionalIndexByteEncoder* encoder) { WriteValueTypeLabel(encoder, IndexType::kVector); - int64_t valueIndex = + absl::optional valueIndex = model::IndexOfKey(map_index_value, model::kRawVectorValueFieldKey, model::kVectorValueFieldKey); - if (valueIndex < 0 || - map_index_value.fields[valueIndex].value.which_value_type != + if (!valueIndex.has_value() || + map_index_value.fields[valueIndex.value()].value.which_value_type != google_firestore_v1_Value_array_value_tag) { return WriteIndexArray(model::MinArray().array_value, encoder); } - auto value = map_index_value.fields[valueIndex].value; + auto value = map_index_value.fields[valueIndex.value()].value; // Vectors sort first by length WriteValueTypeLabel(encoder, IndexType::kNumber); diff --git a/Firestore/core/src/model/value_util.cc b/Firestore/core/src/model/value_util.cc index 2f3ce990449..dfbdc4caaee 100644 --- a/Firestore/core/src/model/value_util.cc +++ b/Firestore/core/src/model/value_util.cc @@ -267,14 +267,18 @@ ComparisonResult CompareVectors(const google_firestore_v1_Value& left, HARD_ASSERT(IsVectorValue(left) && IsVectorValue(right), "Cannot compare non-vector values as vectors."); - int64_t leftIndex = + absl::optional leftIndex = IndexOfKey(left.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey); - int64_t rightIndex = IndexOfKey(right.map_value, kRawVectorValueFieldKey, - kVectorValueFieldKey); + absl::optional rightIndex = IndexOfKey( + right.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey); - google_firestore_v1_Value leftArray = left.map_value.fields[leftIndex].value; + HARD_ASSERT(leftIndex.has_value() && rightIndex.has_value(), + "Unexpected occurence of vector without `value` field."); + + google_firestore_v1_Value leftArray = + left.map_value.fields[leftIndex.value()].value; google_firestore_v1_Value rightArray = - right.map_value.fields[rightIndex].value; + right.map_value.fields[rightIndex.value()].value; ComparisonResult lengthCompare = util::Compare( leftArray.array_value.values_count, rightArray.array_value.values_count); @@ -753,9 +757,10 @@ bool IsMaxValue(const google_firestore_v1_Value& value) { kRawMaxValueFieldValue; } -int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, - const char* kRawTypeValueFieldKey, - pb_bytes_array_s* kTypeValueFieldKey) { +absl::optional IndexOfKey( + const google_firestore_v1_MapValue& mapValue, + const char* kRawTypeValueFieldKey, + pb_bytes_array_s* kTypeValueFieldKey) { for (pb_size_t i = 0; i < mapValue.fields_count; i++) { if (mapValue.fields[i].key == kTypeValueFieldKey || nanopb::MakeStringView(mapValue.fields[i].key) == @@ -764,7 +769,7 @@ int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, } } - return -1; + return absl::nullopt; } bool IsVectorValue(const google_firestore_v1_Value& value) { @@ -776,32 +781,34 @@ bool IsVectorValue(const google_firestore_v1_Value& value) { return false; } - int64_t typeFieldIndex = -1; - if ((typeFieldIndex = IndexOfKey(value.map_value, kRawTypeValueFieldKey, - kTypeValueFieldKey)) < 0) { + absl::optional typeFieldIndex = + IndexOfKey(value.map_value, kRawTypeValueFieldKey, kTypeValueFieldKey); + if (!typeFieldIndex.has_value()) { return false; } - if (value.map_value.fields[typeFieldIndex].value.which_value_type != + if (value.map_value.fields[typeFieldIndex.value()].value.which_value_type != google_firestore_v1_Value_string_value_tag) { return false; } // Comparing the pointer address, then actual content if addresses are // different. - return value.map_value.fields[typeFieldIndex].value.string_value == - kVectorTypeFieldValue || - nanopb::MakeStringView( - value.map_value.fields[typeFieldIndex].value.string_value) == - kRawVectorTypeFieldValue; - - int64_t valueFieldIndex = -1; - if ((valueFieldIndex = IndexOfKey(value.map_value, kRawVectorValueFieldKey, - kVectorValueFieldKey)) < 0) { + if (value.map_value.fields[typeFieldIndex.value()].value.string_value != + kVectorTypeFieldValue && + nanopb::MakeStringView( + value.map_value.fields[typeFieldIndex.value()].value.string_value) != + kRawVectorTypeFieldValue) { + return false; + } + + absl::optional valueFieldIndex = IndexOfKey( + value.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey); + if (!valueFieldIndex.has_value()) { return false; } - if (value.map_value.fields[valueFieldIndex].value.which_value_type != + if (value.map_value.fields[valueFieldIndex.value()].value.which_value_type != google_firestore_v1_Value_map_value_tag) { return false; } diff --git a/Firestore/core/src/model/value_util.h b/Firestore/core/src/model/value_util.h index 49ff85f3d55..4def5b81b82 100644 --- a/Firestore/core/src/model/value_util.h +++ b/Firestore/core/src/model/value_util.h @@ -181,9 +181,16 @@ bool IsMaxValue(const google_firestore_v1_Value& value); */ bool IsVectorValue(const google_firestore_v1_Value& value); -int64_t IndexOfKey(const google_firestore_v1_MapValue& mapValue, - const char* kRawTypeValueFieldKey, - pb_bytes_array_s* kTypeValueFieldKey); +/** + * Returns the index of the specified key (`kRawTypeValueFieldKey`) in the + * map (`mapValue`). `kTypeValueFieldKey` is an alternative representation + * of the key specified in `kRawTypeValueFieldKey`. + * If the key is not found, then `-1` is returned. + */ +absl::optional IndexOfKey( + const google_firestore_v1_MapValue& mapValue, + const char* kRawTypeValueFieldKey, + pb_bytes_array_s* kTypeValueFieldKey); /** * Returns `NaN` in its Protobuf representation. From 5f0b01bcd69f391c4bba879b0c9e26e763474b4e Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Mon, 29 Jul 2024 15:49:41 -0600 Subject: [PATCH 10/24] Fixes based on PR feedback. --- Firestore/Source/API/FIRVectorValue.mm | 6 ++-- .../Source/SwiftAPI/VectorValue+Swift.swift | 6 ++-- .../Integration/VectorIntegrationTests.swift | 33 +++++++++++++++++-- Firestore/core/src/model/value_util.cc | 28 +++++++++++----- 4 files changed, 56 insertions(+), 17 deletions(-) diff --git a/Firestore/Source/API/FIRVectorValue.mm b/Firestore/Source/API/FIRVectorValue.mm index 8c6acd86d68..ac7373a2f75 100644 --- a/Firestore/Source/API/FIRVectorValue.mm +++ b/Firestore/Source/API/FIRVectorValue.mm @@ -42,11 +42,11 @@ @implementation FIRVectorValue return outArray; } -- (instancetype)initWithArray:(NSArray *)data { +- (instancetype)initWithArray:(NSArray *)array { if (self = [super init]) { std::vector converted; - converted.reserve(data.count); - for (NSNumber *value in data) { + converted.reserve(array.count); + for (NSNumber *value in array) { converted.emplace_back([value doubleValue]); } diff --git a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift index 3a177aab53c..dffb35eb811 100644 --- a/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/VectorValue+Swift.swift @@ -21,12 +21,12 @@ #endif // SWIFT_PACKAGE public extension VectorValue { - convenience init(_ data: [Double]) { - let array = data.map { float in + convenience init(_ array: [Double]) { + let nsNumbers = array.map { float in NSNumber(value: float) } - self.init(__array: array) + self.init(__array: nsNumbers) } /// Returns a raw number array representation of the vector. diff --git a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift index 0cb5c2a63de..a24de93144f 100644 --- a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift @@ -146,6 +146,35 @@ class VectorIntegrationTests: FSTIntegrationTestCase { } checkOnlineAndOfflineQuery(collection.order(by: "embedding"), matchesResult: docIds) + } + + func testSdkFiltersVectorFieldSameWayOnlineAndOffline() async throws { + let collection = collectionRef() + + let docsInOrder: [[String: Any]] = [ + ["embedding": [1, 2, 3, 4, 5, 6]], + ["embedding": [100]], + ["embedding": FieldValue.vector([Double.infinity * -1])], + ["embedding": FieldValue.vector([-100.0])], + ["embedding": FieldValue.vector([100.0])], + ["embedding": FieldValue.vector([Double.infinity])], + ["embedding": FieldValue.vector([1, 2.0])], + ["embedding": FieldValue.vector([2, 2.0])], + ["embedding": FieldValue.vector([1, 2, 3.0])], + ["embedding": FieldValue.vector([1, 2, 3, 4.0])], + ["embedding": FieldValue.vector([1, 2, 3, 4, 5.0])], + ["embedding": FieldValue.vector([1, 2, 100, 4, 4.0])], + ["embedding": FieldValue.vector([100, 2, 3, 4, 5.0])], + ["embedding": ["HELLO": "WORLD"]], + ["embedding": ["hello": "world"]], + ] + + var docIds: [String] = [] + for data in docsInOrder { + let docRef = try await collection.addDocument(data: data) + docIds.append(docRef.documentID) + } + checkOnlineAndOfflineQuery( collection.order(by: "embedding") .whereField("embedding", isLessThan: FieldValue.vector([1, 2, 100, 4, 4.0])), @@ -153,8 +182,8 @@ class VectorIntegrationTests: FSTIntegrationTestCase { ) checkOnlineAndOfflineQuery( collection.order(by: "embedding") - .whereField("embedding", isGreaterThan: FieldValue.vector([1, 2, 100, 4, 4.0])), - matchesResult: Array(docIds[12 ... 12]) + .whereField("embedding", isGreaterThanOrEqualTo: FieldValue.vector([1, 2, 100, 4, 4.0])), + matchesResult: Array(docIds[11 ... 12]) ) } diff --git a/Firestore/core/src/model/value_util.cc b/Firestore/core/src/model/value_util.cc index dfbdc4caaee..f363d2d7090 100644 --- a/Firestore/core/src/model/value_util.cc +++ b/Firestore/core/src/model/value_util.cc @@ -272,16 +272,26 @@ ComparisonResult CompareVectors(const google_firestore_v1_Value& left, absl::optional rightIndex = IndexOfKey( right.map_value, kRawVectorValueFieldKey, kVectorValueFieldKey); - HARD_ASSERT(leftIndex.has_value() && rightIndex.has_value(), - "Unexpected occurence of vector without `value` field."); + pb_size_t leftArrayLength = 0; + google_firestore_v1_Value leftArray; + if (leftIndex.has_value()) { + leftArray = left.map_value.fields[leftIndex.value()].value; + leftArrayLength = leftArray.array_value.values_count; + } + + pb_size_t rightArrayLength = 0; + google_firestore_v1_Value rightArray; + if (leftIndex.has_value()) { + rightArray = right.map_value.fields[rightIndex.value()].value; + rightArrayLength = rightArray.array_value.values_count; + } - google_firestore_v1_Value leftArray = - left.map_value.fields[leftIndex.value()].value; - google_firestore_v1_Value rightArray = - right.map_value.fields[rightIndex.value()].value; + if (leftArrayLength == 0 && rightArrayLength == 0) { + return ComparisonResult::Same; + } - ComparisonResult lengthCompare = util::Compare( - leftArray.array_value.values_count, rightArray.array_value.values_count); + ComparisonResult lengthCompare = + util::Compare(leftArrayLength, rightArrayLength); if (lengthCompare != ComparisonResult::Same) { return lengthCompare; } @@ -809,7 +819,7 @@ bool IsVectorValue(const google_firestore_v1_Value& value) { } if (value.map_value.fields[valueFieldIndex.value()].value.which_value_type != - google_firestore_v1_Value_map_value_tag) { + google_firestore_v1_Value_array_value_tag) { return false; } From e0fadeae620633f49232759f21292931568f3f77 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:38:13 -0600 Subject: [PATCH 11/24] API reference doc comments. --- .../Public/FirebaseFirestore/FIRVectorValue.h | 10 +++++++++- .../Swift/Source/SwiftAPI/FieldValue+Swift.swift | 16 ++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h index 28924e3b84b..2468379937c 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorValue.h @@ -18,15 +18,23 @@ NS_ASSUME_NONNULL_BEGIN +/** + * Represent a vector type in Firestore documents. + * Create an instance with `@link `FieldValue.vector(...)`. + */ NS_SWIFT_NAME(VectorValue) @interface FIRVectorValue : NSObject +/** Returns a copy of the raw number array that represents the vector. */ @property(readonly) NSArray *array NS_REFINED_FOR_SWIFT; /** :nodoc: */ - (instancetype)init NS_UNAVAILABLE; -// Public initializer is required to support Codable +/** + * Creates a `VectorValue` constructed with a copy of the given array of NSNumbrers. + * @param array An array of NSNumbers that represents a vector. + */ - (instancetype)initWithArray:(NSArray *)array NS_REFINED_FOR_SWIFT; @end diff --git a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift index 35d59bc6209..dbfbb1f9665 100644 --- a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift @@ -22,22 +22,22 @@ public extension FieldValue { /// Creates a new `VectorValue` constructed with a copy of the given array of Doubles. - /// - Parameter data: An array of Doubles. + /// - Parameter array: An array of Doubles. /// - Returns: A new `VectorValue` constructed with a copy of the given array of Doubles. - static func vector(_ data: [Double]) -> VectorValue { - let array = data.map { double in + static func vector(_ array: [Double]) -> VectorValue { + let nsNumbers = array.map { double in NSNumber(value: double) } - return FieldValue.__vector(from: array) + return FieldValue.__vector(from: nsNumbers) } /// Creates a new `VectorValue` constructed with a copy of the given array of Floats. - /// - Parameter data: An array of Floats. + /// - Parameter array: An array of Floats. /// - Returns: A new `VectorValue` constructed with a copy of the given array of Floats. - static func vector(_ data: [Float]) -> VectorValue { - let array = data.map { float in + static func vector(_ array: [Float]) -> VectorValue { + let nsNumbers = array.map { float in NSNumber(value: float) } - return FieldValue.__vector(from: array) + return FieldValue.__vector(from: nsNumbers) } } From ece38eb0733a791185c0ca220c9c2e832d6aa4fe Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:53:27 -0600 Subject: [PATCH 12/24] Fix copyright header --- Firestore/Swift/Source/Codable/VectorValue+Codable.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Firestore/Swift/Source/Codable/VectorValue+Codable.swift b/Firestore/Swift/Source/Codable/VectorValue+Codable.swift index d10fbc593bf..57d544af3b7 100644 --- a/Firestore/Swift/Source/Codable/VectorValue+Codable.swift +++ b/Firestore/Swift/Source/Codable/VectorValue+Codable.swift @@ -1,5 +1,5 @@ /* - * Copyright 2024 Google + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From c1abeb2c7e25c143341a9c780b01f0e96fcc3ae6 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:58:38 -0600 Subject: [PATCH 13/24] Update CHANGELOG.md --- Firestore/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Firestore/CHANGELOG.md b/Firestore/CHANGELOG.md index fbfc61ba933..0ca9d518dc8 100644 --- a/Firestore/CHANGELOG.md +++ b/Firestore/CHANGELOG.md @@ -1,3 +1,6 @@ +# 11.1.0 +- [feature] Add VectorValue type support. + # 11.0.0 - [removed] **Breaking change**: The deprecated `FirebaseFirestoreSwift` module has been removed. See From 7f2aa72e44323b45ef40762d5c3978c37d691aa9 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Mon, 29 Jul 2024 17:04:11 -0600 Subject: [PATCH 14/24] Header ordering --- Firestore/core/src/index/firestore_index_value_writer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Firestore/core/src/index/firestore_index_value_writer.cc b/Firestore/core/src/index/firestore_index_value_writer.cc index e44e3118bed..4587844b930 100644 --- a/Firestore/core/src/index/firestore_index_value_writer.cc +++ b/Firestore/core/src/index/firestore_index_value_writer.cc @@ -15,13 +15,13 @@ */ #include "Firestore/core/src/index/firestore_index_value_writer.h" -#include "Firestore/core/src/model/value_util.h" #include #include #include #include "Firestore/core/src/model/resource_path.h" +#include "Firestore/core/src/model/value_util.h" #include "Firestore/core/src/nanopb/nanopb_util.h" namespace firebase { From 9fc9c7bf0b6b5828ee2a532d7b973be0b5baf1fc Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:38:27 -0600 Subject: [PATCH 15/24] Update Firestore/CHANGELOG.md Co-authored-by: Nick Cooke <36927374+ncooke3@users.noreply.github.com> --- Firestore/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Firestore/CHANGELOG.md b/Firestore/CHANGELOG.md index 0ca9d518dc8..6a4973c7b02 100644 --- a/Firestore/CHANGELOG.md +++ b/Firestore/CHANGELOG.md @@ -1,5 +1,5 @@ # 11.1.0 -- [feature] Add VectorValue type support. +- [feature] Add `VectorValue` type support. # 11.0.0 - [removed] **Breaking change**: The deprecated `FirebaseFirestoreSwift` module From d8123e5dbc8448c8526552b7b132731465df1ac9 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Fri, 2 Aug 2024 10:43:28 -0600 Subject: [PATCH 16/24] VQ exploration --- .../FirebaseFirestore/FIRFindNearestOptions.h | 15 +++++++ .../FIRFirestoreDistanceMeasure.h | 15 +++++++ .../FirebaseFirestore/FIRFindNearestOptions.h | 41 ++++++++++++++++++ .../FIRFirestoreDistanceMeasure.h | 14 +++++++ .../FIRFirestoreVectorSource.h | 13 ++++++ .../Public/FirebaseFirestore/FIRQuery.h | 30 +++++++++++++ .../Public/FirebaseFirestore/FIRVectorQuery.h | 42 +++++++++++++++++++ .../FIRVectorQuerySnapshot.h | 34 +++++++++++++++ 8 files changed, 204 insertions(+) create mode 100644 FirebaseFirestoreInternal/FirebaseFirestore/FIRFindNearestOptions.h create mode 100644 FirebaseFirestoreInternal/FirebaseFirestore/FIRFirestoreDistanceMeasure.h create mode 100644 Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h create mode 100644 Firestore/Source/Public/FirebaseFirestore/FIRFirestoreDistanceMeasure.h create mode 100644 Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h create mode 100644 Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h create mode 100644 Firestore/Source/Public/FirebaseFirestore/FIRVectorQuerySnapshot.h diff --git a/FirebaseFirestoreInternal/FirebaseFirestore/FIRFindNearestOptions.h b/FirebaseFirestoreInternal/FirebaseFirestore/FIRFindNearestOptions.h new file mode 100644 index 00000000000..c671ef0828a --- /dev/null +++ b/FirebaseFirestoreInternal/FirebaseFirestore/FIRFindNearestOptions.h @@ -0,0 +1,15 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import diff --git a/FirebaseFirestoreInternal/FirebaseFirestore/FIRFirestoreDistanceMeasure.h b/FirebaseFirestoreInternal/FirebaseFirestore/FIRFirestoreDistanceMeasure.h new file mode 100644 index 00000000000..564cd3e538a --- /dev/null +++ b/FirebaseFirestoreInternal/FirebaseFirestore/FIRFirestoreDistanceMeasure.h @@ -0,0 +1,15 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h b/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h new file mode 100644 index 00000000000..8119506ff61 --- /dev/null +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h @@ -0,0 +1,41 @@ +// +// FIRFindNearestOptions.h +// FirebaseFirestoreInternal +// +// Created by Mark Duckworth on 7/25/24. +// +#import + +#import "FIRFieldPath.h" + +@class FIRAggregateQuery; +@class FIRAggregateField; +@class FIRFieldPath; +@class FIRFirestore; +@class FIRFilter; +@class FIRQuerySnapshot; +@class FIRDocumentSnapshot; +@class FIRVectorQuery; + +NS_ASSUME_NONNULL_BEGIN + +NS_SWIFT_NAME(FindNearestOptions) +@interface FIRFindNearestOptions : NSObject +@property (nonatomic, readonly) FIRFieldPath *distanceResultFieldPath; +@property (nonatomic, readonly) NSNumber *distanceThreshold; + +- (nonnull instancetype)init NS_DESIGNATED_INITIALIZER; + +- (nonnull FIRFindNearestOptions *)optionsWithDistanceResultField: + (NSString *)distanceResultField; + +- (nonnull FIRFindNearestOptions *)optionsWithDistanceResultFieldPath: + (FIRFieldPath *)distanceResultFieldPath; + +- (nonnull FIRFindNearestOptions *)optionsWithDistanceThreshold: + (NSNumber *)distanceThreshold; + +@end + + +NS_ASSUME_NONNULL_END diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreDistanceMeasure.h b/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreDistanceMeasure.h new file mode 100644 index 00000000000..82761da81c9 --- /dev/null +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreDistanceMeasure.h @@ -0,0 +1,14 @@ +// +// FIRFirestoreDistanceMeasure.h +// FirebaseFirestoreInternal +// +// Created by Mark Duckworth on 7/25/24. +// + +#import + +typedef NS_ENUM(NSUInteger, FIRFirestoreDistanceMeasure) { + FIRFirestoreDistanceMeasureCosine, + FIRFirestoreDistanceMeasureEuclidean, + FIRFirestoreDistanceMeasureDotProduct +} NS_SWIFT_NAME(FirestoreDistanceMeasure); diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h b/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h new file mode 100644 index 00000000000..cf1cb3f4b03 --- /dev/null +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h @@ -0,0 +1,13 @@ +// +// FIRFirestoreVectorSource.h +// FirebaseFirestoreInternal +// +// Created by Mark Duckworth on 7/25/24. +// + + +#import + +typedef NS_ENUM(NSUInteger, FIRFirestoreVectorSource) { + FIRFirestoreVectorSourceServer, +} NS_SWIFT_NAME(FirestoreVectorSource); diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h b/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h index c75952876a2..42fe2f3bf20 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h @@ -19,6 +19,8 @@ #import "FIRFirestoreSource.h" #import "FIRListenerRegistration.h" #import "FIRSnapshotListenOptions.h" +#import "FIRFirestoreDistanceMeasure.h" +#import "FIRFindNearestOptions.h" @class FIRAggregateQuery; @class FIRAggregateField; @@ -27,6 +29,8 @@ @class FIRFilter; @class FIRQuerySnapshot; @class FIRDocumentSnapshot; +@class FIRVectorQuery; +@class FIRVectorValue; NS_ASSUME_NONNULL_BEGIN @@ -49,6 +53,32 @@ NS_SWIFT_NAME(Query) /** The `Firestore` instance that created this query (useful for performing transactions, etc.). */ @property(nonatomic, strong, readonly) FIRFirestore *firestore; +- (nonnull FIRVectorQuery *)findNearestWithField:(nonnull NSString *)field + queryVector:(nonnull NSArray *)queryVector + limit:(int64_t)limit + distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure + NS_SWIFT_NAME(findNearest(field:queryVector:limit:distanceMeasure:)); + +- (nonnull FIRVectorQuery *)findNearestWithFieldPath:(nonnull FIRFieldPath *)fieldPath + queryVectorValue:(nonnull FIRVectorValue *)queryVectorValue + limit:(int64_t)limit + distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure + NS_SWIFT_NAME(findNearest(fieldPath:queryVectorValue:limit:distanceMeasure:)); + +- (nonnull FIRVectorQuery *)findNearestWithField:(nonnull NSString *)field + queryVector:(nonnull NSArray *)queryVector + limit:(int64_t)limit + distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure + options:(nonnull FIRFindNearestOptions *)options + NS_SWIFT_NAME(findNearest(field:queryVector:limit:distanceMeasure:options:)); + +- (nonnull FIRVectorQuery *)findNearestWithFieldPath:(nonnull FIRFieldPath *)fieldPath + queryVectorValue:(nonnull FIRVectorValue *)queryVectorValue + limit:(int64_t)limit + distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure + options:(nonnull FIRFindNearestOptions *)options + NS_SWIFT_NAME(findNearest(fieldPath:queryVectorValue:limit:distanceMeasure:options:)); + #pragma mark - Retrieving Data /** * Reads the documents matching this query. diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h new file mode 100644 index 00000000000..afafe55d28d --- /dev/null +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h @@ -0,0 +1,42 @@ +// +// FIRVectorQuery.h +// FirebaseFirestoreInternal +// +// Created by Mark Duckworth on 7/25/24. +// +#import + +#import "FIRFirestoreVectorSource.h" +#import "FIRVectorQuerySnapshot.h" + +@class FIRAggregateQuery; +@class FIRAggregateField; +@class FIRFieldPath; +@class FIRFirestore; +@class FIRFilter; +@class FIRVectorQuerySnapshot; +@class FIRDocumentSnapshot; +@class FIRVectorQuery; + +NS_ASSUME_NONNULL_BEGIN + +NS_SWIFT_NAME(VectorQuery) +@interface FIRVectorQuery : NSObject + +@property(nonatomic, strong, readonly) FIRQuery *query; + +/** + * Executes this query. + * + * @param source The source from which to acquire the VectorQuery results. + * @param completion a block to execute once the results have been successfully read. + * snapshot will be `nil` only if error is `non-nil`. + */ +- (void)getDocumentsWithSource:(FIRFirestoreVectorSource)source + completion:(void (^)(FIRVectorQuerySnapshot *_Nullable snapshot, + NSError *_Nullable error))completion + NS_SWIFT_NAME(getDocuments(source:completion:)); + +@end + +NS_ASSUME_NONNULL_END diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuerySnapshot.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuerySnapshot.h new file mode 100644 index 00000000000..7463ce10771 --- /dev/null +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuerySnapshot.h @@ -0,0 +1,34 @@ +// +// FIRVectorQuerySnapshot.h +// FirebaseFirestoreInternal +// +// Created by Mark Duckworth on 7/25/24. +// +#import + +NS_ASSUME_NONNULL_BEGIN + +@class FIRVectorQuery; +@class FIRAggregateQuery; +@class FIRAggregateField; +@class FIRFieldPath; +@class FIRFirestore; +@class FIRFilter; +@class FIRQuerySnapshot; +@class FIRDocumentSnapshot; + +NS_SWIFT_NAME(VectorQuerySnapshot) +@interface FIRVectorQuerySnapshot : NSObject +@property(nonatomic, strong, readonly) FIRVectorQuery *query; +@property(nonatomic, strong, readonly) FIRSnapshotMetadata *metadata; +@property(nonatomic, readonly, getter=isEmpty) BOOL empty; +@property(nonatomic, readonly) NSInteger count; +@property(nonatomic, strong, readonly) NSArray *documents; +@property(nonatomic, strong, readonly) NSArray *documentChanges; +- (NSArray *)documentChangesWithIncludeMetadataChanges: + (BOOL)includeMetadataChanges NS_SWIFT_NAME(documentChanges(includeMetadataChanges:)); + +@end + + +NS_ASSUME_NONNULL_END From 9b88f0317e32b11c4de28e9c450ae55fa1cb07b0 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Tue, 6 Aug 2024 12:09:33 -0600 Subject: [PATCH 17/24] Peer review updates --- .../Firestore.xcodeproj/project.pbxproj | 6 ++-- Firestore/Source/API/FIRFieldValue.mm | 4 +-- Firestore/Source/API/FSTUserDataWriter.mm | 4 +-- .../Public/FirebaseFirestore/FIRFieldValue.h | 4 +-- .../FirebaseFirestore/FIRFindNearestOptions.h | 11 +++---- .../FIRFirestoreVectorSource.h | 1 - .../Public/FirebaseFirestore/FIRQuery.h | 32 +++++++++---------- .../Public/FirebaseFirestore/FIRVectorQuery.h | 4 +-- .../FIRVectorQuerySnapshot.h | 1 - .../Source/SwiftAPI/FieldValue+Swift.swift | 4 +-- 10 files changed, 33 insertions(+), 38 deletions(-) diff --git a/Firestore/Example/Firestore.xcodeproj/project.pbxproj b/Firestore/Example/Firestore.xcodeproj/project.pbxproj index 63f7e4c37f8..7478ab8332d 100644 --- a/Firestore/Example/Firestore.xcodeproj/project.pbxproj +++ b/Firestore/Example/Firestore.xcodeproj/project.pbxproj @@ -4586,6 +4586,7 @@ 62E54B862A9E910B003347C8 /* IndexingTests.swift in Sources */, 621D620C28F9CE7400D2FA26 /* QueryIntegrationTests.swift in Sources */, 1CFBD4563960D8A20C4679A3 /* SnapshotListenerSourceTests.swift in Sources */, + EFF22EAC2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 4D42E5C756229C08560DD731 /* XCTestCase+Await.mm in Sources */, 09BE8C01EC33D1FD82262D5D /* aggregate_query_test.cc in Sources */, 0EC3921AE220410F7394729B /* aggregation_result.pb.cc in Sources */, @@ -4632,7 +4633,6 @@ AECCD9663BB3DC52199F954A /* executor_std_test.cc in Sources */, 18F644E6AA98E6D6F3F1F809 /* executor_test.cc in Sources */, 6938575C8B5E6FE0D562547A /* exponential_backoff_test.cc in Sources */, - EFF22EAC2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 258B372CF33B7E7984BBA659 /* fake_target_metadata_provider.cc in Sources */, F8BD2F61EFA35C2D5120D9EB /* field_index_test.cc in Sources */, F272A8C41D2353700A11D1FB /* field_mask_test.cc in Sources */, @@ -4827,6 +4827,7 @@ 62E54B852A9E910B003347C8 /* IndexingTests.swift in Sources */, 621D620B28F9CE7400D2FA26 /* QueryIntegrationTests.swift in Sources */, A0BC30D482B0ABD1A3A24CDC /* SnapshotListenerSourceTests.swift in Sources */, + EFF22EAB2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 736C4E82689F1CA1859C4A3F /* XCTestCase+Await.mm in Sources */, 412BE974741729A6683C386F /* aggregate_query_test.cc in Sources */, DF983A9C1FBF758AF3AF110D /* aggregation_result.pb.cc in Sources */, @@ -4873,7 +4874,6 @@ 17DFF30CF61D87883986E8B6 /* executor_std_test.cc in Sources */, 814724DE70EFC3DDF439CD78 /* executor_test.cc in Sources */, BD6CC8614970A3D7D2CF0D49 /* exponential_backoff_test.cc in Sources */, - EFF22EAB2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 4D2655C5675D83205C3749DC /* fake_target_metadata_provider.cc in Sources */, 50C852E08626CFA7DC889EEA /* field_index_test.cc in Sources */, A1563EFEB021936D3FFE07E3 /* field_mask_test.cc in Sources */, @@ -5314,6 +5314,7 @@ 62E54B842A9E910B003347C8 /* IndexingTests.swift in Sources */, 621D620A28F9CE7400D2FA26 /* QueryIntegrationTests.swift in Sources */, B00F8D1819EE20C45B660940 /* SnapshotListenerSourceTests.swift in Sources */, + EFF22EAA2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 5492E0442021457E00B64F25 /* XCTestCase+Await.mm in Sources */, B04E4FE20930384DF3A402F9 /* aggregate_query_test.cc in Sources */, 1A3D8028303B45FCBB21CAD3 /* aggregation_result.pb.cc in Sources */, @@ -5360,7 +5361,6 @@ 125B1048ECB755C2106802EB /* executor_std_test.cc in Sources */, DABB9FB61B1733F985CBF713 /* executor_test.cc in Sources */, 7BCF050BA04537B0E7D44730 /* exponential_backoff_test.cc in Sources */, - EFF22EAA2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, BA1C5EAE87393D8E60F5AE6D /* fake_target_metadata_provider.cc in Sources */, 84285C3F63D916A4786724A8 /* field_index_test.cc in Sources */, 6A40835DB2C02B9F07C02E88 /* field_mask_test.cc in Sources */, diff --git a/Firestore/Source/API/FIRFieldValue.mm b/Firestore/Source/API/FIRFieldValue.mm index 56583d11717..23c5060a8ee 100644 --- a/Firestore/Source/API/FIRFieldValue.mm +++ b/Firestore/Source/API/FIRFieldValue.mm @@ -177,8 +177,8 @@ + (instancetype)fieldValueForIntegerIncrement:(int64_t)l { return [[FSTNumericIncrementFieldValue alloc] initWithOperand:@(l)]; } -+ (nonnull FIRVectorValue *)vectorFromNSNumbers:(nonnull NSArray *)values { - return [[FIRVectorValue alloc] initWithArray:values]; ++ (nonnull FIRVectorValue *)vectorWithArray:(nonnull NSArray *)array { + return [[FIRVectorValue alloc] initWithArray:array]; } @end diff --git a/Firestore/Source/API/FSTUserDataWriter.mm b/Firestore/Source/API/FSTUserDataWriter.mm index e6f2cd358eb..1e170531782 100644 --- a/Firestore/Source/API/FSTUserDataWriter.mm +++ b/Firestore/Source/API/FSTUserDataWriter.mm @@ -132,10 +132,10 @@ - (FIRVectorValue *)convertedVector:(const google_firestore_v1_MapValue &)mapVal const google_firestore_v1_Value &value = mapValue.fields[i].value; if ((0 == key.compare(absl::string_view("value"))) && value.which_value_type == google_firestore_v1_Value_array_value_tag) { - return [FIRFieldValue vectorFromNSNumbers:[self convertedArray:value.array_value]]; + return [FIRFieldValue vectorWithArray:[self convertedArray:value.array_value]]; } } - return [FIRFieldValue vectorFromNSNumbers:@[]]; + return [FIRFieldValue vectorWithArray:@[]]; } - (NSArray *)convertedArray:(const google_firestore_v1_ArrayValue &)arrayValue { diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h index 45ef2d8d24e..8add3dec1fa 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h @@ -94,10 +94,10 @@ NS_SWIFT_NAME(FieldValue) /** * Creates a new `VectorValue` constructed with a copy of the given array of NSNumbers. * - * @param values Create a `VectorValue` instance with a copy of this array of NSNumbers. + * @param array Create a `VectorValue` instance with a copy of this array of NSNumbers. * @return A new `VectorValue` constructed with a copy of the given array of NSNumbers. */ -+ (FIRVectorValue *)vectorFromNSNumbers:(NSArray *)values NS_REFINED_FOR_SWIFT; ++ (FIRVectorValue *)vectorWithArray:(NSArray *)array NS_REFINED_FOR_SWIFT; @end diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h b/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h index 8119506ff61..d8d3598dca4 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h @@ -21,21 +21,18 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(FindNearestOptions) @interface FIRFindNearestOptions : NSObject -@property (nonatomic, readonly) FIRFieldPath *distanceResultFieldPath; -@property (nonatomic, readonly) NSNumber *distanceThreshold; +@property(nonatomic, readonly) FIRFieldPath *distanceResultFieldPath; +@property(nonatomic, readonly) NSNumber *distanceThreshold; - (nonnull instancetype)init NS_DESIGNATED_INITIALIZER; -- (nonnull FIRFindNearestOptions *)optionsWithDistanceResultField: - (NSString *)distanceResultField; +- (nonnull FIRFindNearestOptions *)optionsWithDistanceResultField:(NSString *)distanceResultField; - (nonnull FIRFindNearestOptions *)optionsWithDistanceResultFieldPath: (FIRFieldPath *)distanceResultFieldPath; -- (nonnull FIRFindNearestOptions *)optionsWithDistanceThreshold: - (NSNumber *)distanceThreshold; +- (nonnull FIRFindNearestOptions *)optionsWithDistanceThreshold:(NSNumber *)distanceThreshold; @end - NS_ASSUME_NONNULL_END diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h b/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h index cf1cb3f4b03..fa5eab63251 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h @@ -5,7 +5,6 @@ // Created by Mark Duckworth on 7/25/24. // - #import typedef NS_ENUM(NSUInteger, FIRFirestoreVectorSource) { diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h b/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h index 42fe2f3bf20..c4cbad2c018 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h @@ -16,11 +16,11 @@ #import +#import "FIRFindNearestOptions.h" +#import "FIRFirestoreDistanceMeasure.h" #import "FIRFirestoreSource.h" #import "FIRListenerRegistration.h" #import "FIRSnapshotListenOptions.h" -#import "FIRFirestoreDistanceMeasure.h" -#import "FIRFindNearestOptions.h" @class FIRAggregateQuery; @class FIRAggregateField; @@ -54,29 +54,29 @@ NS_SWIFT_NAME(Query) @property(nonatomic, strong, readonly) FIRFirestore *firestore; - (nonnull FIRVectorQuery *)findNearestWithField:(nonnull NSString *)field - queryVector:(nonnull NSArray *)queryVector - limit:(int64_t)limit - distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure + queryVector:(nonnull NSArray *)queryVector + limit:(int64_t)limit + distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure NS_SWIFT_NAME(findNearest(field:queryVector:limit:distanceMeasure:)); - (nonnull FIRVectorQuery *)findNearestWithFieldPath:(nonnull FIRFieldPath *)fieldPath - queryVectorValue:(nonnull FIRVectorValue *)queryVectorValue - limit:(int64_t)limit - distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure + queryVectorValue:(nonnull FIRVectorValue *)queryVectorValue + limit:(int64_t)limit + distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure NS_SWIFT_NAME(findNearest(fieldPath:queryVectorValue:limit:distanceMeasure:)); - (nonnull FIRVectorQuery *)findNearestWithField:(nonnull NSString *)field - queryVector:(nonnull NSArray *)queryVector - limit:(int64_t)limit - distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure - options:(nonnull FIRFindNearestOptions *)options + queryVector:(nonnull NSArray *)queryVector + limit:(int64_t)limit + distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure + options:(nonnull FIRFindNearestOptions *)options NS_SWIFT_NAME(findNearest(field:queryVector:limit:distanceMeasure:options:)); - (nonnull FIRVectorQuery *)findNearestWithFieldPath:(nonnull FIRFieldPath *)fieldPath - queryVectorValue:(nonnull FIRVectorValue *)queryVectorValue - limit:(int64_t)limit - distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure - options:(nonnull FIRFindNearestOptions *)options + queryVectorValue:(nonnull FIRVectorValue *)queryVectorValue + limit:(int64_t)limit + distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure + options:(nonnull FIRFindNearestOptions *)options NS_SWIFT_NAME(findNearest(fieldPath:queryVectorValue:limit:distanceMeasure:options:)); #pragma mark - Retrieving Data diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h index afafe55d28d..7af9bae7da0 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h @@ -33,8 +33,8 @@ NS_SWIFT_NAME(VectorQuery) * snapshot will be `nil` only if error is `non-nil`. */ - (void)getDocumentsWithSource:(FIRFirestoreVectorSource)source - completion:(void (^)(FIRVectorQuerySnapshot *_Nullable snapshot, - NSError *_Nullable error))completion + completion:(void (^)(FIRVectorQuerySnapshot *_Nullable snapshot, + NSError *_Nullable error))completion NS_SWIFT_NAME(getDocuments(source:completion:)); @end diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuerySnapshot.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuerySnapshot.h index 7463ce10771..7ecdfcc1662 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuerySnapshot.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuerySnapshot.h @@ -30,5 +30,4 @@ NS_SWIFT_NAME(VectorQuerySnapshot) @end - NS_ASSUME_NONNULL_END diff --git a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift index dbfbb1f9665..ccab6238267 100644 --- a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift @@ -28,7 +28,7 @@ public extension FieldValue { let nsNumbers = array.map { double in NSNumber(value: double) } - return FieldValue.__vector(from: nsNumbers) + return FieldValue.__vector(with: nsNumbers) } /// Creates a new `VectorValue` constructed with a copy of the given array of Floats. @@ -38,6 +38,6 @@ public extension FieldValue { let nsNumbers = array.map { float in NSNumber(value: float) } - return FieldValue.__vector(from: nsNumbers) + return FieldValue.__vector(with: nsNumbers) } } From 7b4ff0847eb8e634e1203de7df7c402b02bf299b Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Tue, 6 Aug 2024 14:58:58 -0600 Subject: [PATCH 18/24] Message rename to match API design --- Firestore/Example/Firestore.xcodeproj/project.pbxproj | 6 +++--- Firestore/Source/API/FIRFieldValue.mm | 4 ++-- Firestore/Source/API/FSTUserDataWriter.mm | 4 ++-- Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h | 4 ++-- Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Firestore/Example/Firestore.xcodeproj/project.pbxproj b/Firestore/Example/Firestore.xcodeproj/project.pbxproj index 63f7e4c37f8..7478ab8332d 100644 --- a/Firestore/Example/Firestore.xcodeproj/project.pbxproj +++ b/Firestore/Example/Firestore.xcodeproj/project.pbxproj @@ -4586,6 +4586,7 @@ 62E54B862A9E910B003347C8 /* IndexingTests.swift in Sources */, 621D620C28F9CE7400D2FA26 /* QueryIntegrationTests.swift in Sources */, 1CFBD4563960D8A20C4679A3 /* SnapshotListenerSourceTests.swift in Sources */, + EFF22EAC2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 4D42E5C756229C08560DD731 /* XCTestCase+Await.mm in Sources */, 09BE8C01EC33D1FD82262D5D /* aggregate_query_test.cc in Sources */, 0EC3921AE220410F7394729B /* aggregation_result.pb.cc in Sources */, @@ -4632,7 +4633,6 @@ AECCD9663BB3DC52199F954A /* executor_std_test.cc in Sources */, 18F644E6AA98E6D6F3F1F809 /* executor_test.cc in Sources */, 6938575C8B5E6FE0D562547A /* exponential_backoff_test.cc in Sources */, - EFF22EAC2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 258B372CF33B7E7984BBA659 /* fake_target_metadata_provider.cc in Sources */, F8BD2F61EFA35C2D5120D9EB /* field_index_test.cc in Sources */, F272A8C41D2353700A11D1FB /* field_mask_test.cc in Sources */, @@ -4827,6 +4827,7 @@ 62E54B852A9E910B003347C8 /* IndexingTests.swift in Sources */, 621D620B28F9CE7400D2FA26 /* QueryIntegrationTests.swift in Sources */, A0BC30D482B0ABD1A3A24CDC /* SnapshotListenerSourceTests.swift in Sources */, + EFF22EAB2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 736C4E82689F1CA1859C4A3F /* XCTestCase+Await.mm in Sources */, 412BE974741729A6683C386F /* aggregate_query_test.cc in Sources */, DF983A9C1FBF758AF3AF110D /* aggregation_result.pb.cc in Sources */, @@ -4873,7 +4874,6 @@ 17DFF30CF61D87883986E8B6 /* executor_std_test.cc in Sources */, 814724DE70EFC3DDF439CD78 /* executor_test.cc in Sources */, BD6CC8614970A3D7D2CF0D49 /* exponential_backoff_test.cc in Sources */, - EFF22EAB2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 4D2655C5675D83205C3749DC /* fake_target_metadata_provider.cc in Sources */, 50C852E08626CFA7DC889EEA /* field_index_test.cc in Sources */, A1563EFEB021936D3FFE07E3 /* field_mask_test.cc in Sources */, @@ -5314,6 +5314,7 @@ 62E54B842A9E910B003347C8 /* IndexingTests.swift in Sources */, 621D620A28F9CE7400D2FA26 /* QueryIntegrationTests.swift in Sources */, B00F8D1819EE20C45B660940 /* SnapshotListenerSourceTests.swift in Sources */, + EFF22EAA2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, 5492E0442021457E00B64F25 /* XCTestCase+Await.mm in Sources */, B04E4FE20930384DF3A402F9 /* aggregate_query_test.cc in Sources */, 1A3D8028303B45FCBB21CAD3 /* aggregation_result.pb.cc in Sources */, @@ -5360,7 +5361,6 @@ 125B1048ECB755C2106802EB /* executor_std_test.cc in Sources */, DABB9FB61B1733F985CBF713 /* executor_test.cc in Sources */, 7BCF050BA04537B0E7D44730 /* exponential_backoff_test.cc in Sources */, - EFF22EAA2C5060A4009A369B /* VectorIntegrationTests.swift in Sources */, BA1C5EAE87393D8E60F5AE6D /* fake_target_metadata_provider.cc in Sources */, 84285C3F63D916A4786724A8 /* field_index_test.cc in Sources */, 6A40835DB2C02B9F07C02E88 /* field_mask_test.cc in Sources */, diff --git a/Firestore/Source/API/FIRFieldValue.mm b/Firestore/Source/API/FIRFieldValue.mm index 56583d11717..23c5060a8ee 100644 --- a/Firestore/Source/API/FIRFieldValue.mm +++ b/Firestore/Source/API/FIRFieldValue.mm @@ -177,8 +177,8 @@ + (instancetype)fieldValueForIntegerIncrement:(int64_t)l { return [[FSTNumericIncrementFieldValue alloc] initWithOperand:@(l)]; } -+ (nonnull FIRVectorValue *)vectorFromNSNumbers:(nonnull NSArray *)values { - return [[FIRVectorValue alloc] initWithArray:values]; ++ (nonnull FIRVectorValue *)vectorWithArray:(nonnull NSArray *)array { + return [[FIRVectorValue alloc] initWithArray:array]; } @end diff --git a/Firestore/Source/API/FSTUserDataWriter.mm b/Firestore/Source/API/FSTUserDataWriter.mm index e6f2cd358eb..1e170531782 100644 --- a/Firestore/Source/API/FSTUserDataWriter.mm +++ b/Firestore/Source/API/FSTUserDataWriter.mm @@ -132,10 +132,10 @@ - (FIRVectorValue *)convertedVector:(const google_firestore_v1_MapValue &)mapVal const google_firestore_v1_Value &value = mapValue.fields[i].value; if ((0 == key.compare(absl::string_view("value"))) && value.which_value_type == google_firestore_v1_Value_array_value_tag) { - return [FIRFieldValue vectorFromNSNumbers:[self convertedArray:value.array_value]]; + return [FIRFieldValue vectorWithArray:[self convertedArray:value.array_value]]; } } - return [FIRFieldValue vectorFromNSNumbers:@[]]; + return [FIRFieldValue vectorWithArray:@[]]; } - (NSArray *)convertedArray:(const google_firestore_v1_ArrayValue &)arrayValue { diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h index 45ef2d8d24e..8add3dec1fa 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFieldValue.h @@ -94,10 +94,10 @@ NS_SWIFT_NAME(FieldValue) /** * Creates a new `VectorValue` constructed with a copy of the given array of NSNumbers. * - * @param values Create a `VectorValue` instance with a copy of this array of NSNumbers. + * @param array Create a `VectorValue` instance with a copy of this array of NSNumbers. * @return A new `VectorValue` constructed with a copy of the given array of NSNumbers. */ -+ (FIRVectorValue *)vectorFromNSNumbers:(NSArray *)values NS_REFINED_FOR_SWIFT; ++ (FIRVectorValue *)vectorWithArray:(NSArray *)array NS_REFINED_FOR_SWIFT; @end diff --git a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift index dbfbb1f9665..ccab6238267 100644 --- a/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift +++ b/Firestore/Swift/Source/SwiftAPI/FieldValue+Swift.swift @@ -28,7 +28,7 @@ public extension FieldValue { let nsNumbers = array.map { double in NSNumber(value: double) } - return FieldValue.__vector(from: nsNumbers) + return FieldValue.__vector(with: nsNumbers) } /// Creates a new `VectorValue` constructed with a copy of the given array of Floats. @@ -38,6 +38,6 @@ public extension FieldValue { let nsNumbers = array.map { float in NSNumber(value: float) } - return FieldValue.__vector(from: nsNumbers) + return FieldValue.__vector(with: nsNumbers) } } From 7c91869f60883301b3dd9ad350590fed831a4f4d Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Tue, 6 Aug 2024 15:02:11 -0600 Subject: [PATCH 19/24] Fix comment --- Firestore/core/src/model/value_util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Firestore/core/src/model/value_util.h b/Firestore/core/src/model/value_util.h index 4def5b81b82..708b71ccd16 100644 --- a/Firestore/core/src/model/value_util.h +++ b/Firestore/core/src/model/value_util.h @@ -185,7 +185,7 @@ bool IsVectorValue(const google_firestore_v1_Value& value); * Returns the index of the specified key (`kRawTypeValueFieldKey`) in the * map (`mapValue`). `kTypeValueFieldKey` is an alternative representation * of the key specified in `kRawTypeValueFieldKey`. - * If the key is not found, then `-1` is returned. + * If the key is not found, then `absl::nullopt` is returned. */ absl::optional IndexOfKey( const google_firestore_v1_MapValue& mapValue, From 070136c9404a018dc5382f50f09e6bd1d10a7356 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Tue, 6 Aug 2024 18:32:28 -0600 Subject: [PATCH 20/24] Fix CI issues --- .../Swift/Tests/Integration/VectorIntegrationTests.swift | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift index a24de93144f..069a71a6541 100644 --- a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift @@ -19,7 +19,7 @@ import FirebaseFirestore import Foundation // iOS 15 required for test implementation, not vector feature -@available(iOS 15, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *) +@available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *) class VectorIntegrationTests: FSTIntegrationTestCase { func testWriteAndReadVectorEmbeddings() async throws { let collection = collectionRef() @@ -52,10 +52,11 @@ class VectorIntegrationTests: FSTIntegrationTestCase { ) } + @available(iOS 15, tvOS 15, macOS 12.0, macCatalyst 13, watchOS 7, *) func testSdkOrdersVectorFieldSameWayAsBackend() async throws { let collection = collectionRef() - var docsInOrder: [[String: Any]] = [ + let docsInOrder: [[String: Any]] = [ ["embedding": [1, 2, 3, 4, 5, 6]], ["embedding": [100]], ["embedding": FieldValue.vector([Double.infinity * -1])], From d749fc7a6beb2cd8e2f56ba8ea43d7c80a1f43bb Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Wed, 7 Aug 2024 12:45:47 -0600 Subject: [PATCH 21/24] Update run_firestore_emulator.sh to 1.19.7 --- scripts/run_firestore_emulator.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/run_firestore_emulator.sh b/scripts/run_firestore_emulator.sh index dac3f81ad7c..7401009c44d 100755 --- a/scripts/run_firestore_emulator.sh +++ b/scripts/run_firestore_emulator.sh @@ -25,7 +25,7 @@ if [[ ! -z "${JAVA_HOME_11_X64:-}" ]]; then export JAVA_HOME=$JAVA_HOME_11_X64 fi -VERSION='1.18.2' +VERSION='1.19.7' FILENAME="cloud-firestore-emulator-v${VERSION}.jar" URL="https://storage.googleapis.com/firebase-preview-drop/emulator/${FILENAME}" From 0fda84e87af6a5628e4ad8f9a0f493a00f73b9c6 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:11:25 -0600 Subject: [PATCH 22/24] Code cleanup --- Firestore/Source/API/FIRVectorValue.mm | 36 +++----------------------- 1 file changed, 3 insertions(+), 33 deletions(-) diff --git a/Firestore/Source/API/FIRVectorValue.mm b/Firestore/Source/API/FIRVectorValue.mm index ac7373a2f75..d4f25b39ac3 100644 --- a/Firestore/Source/API/FIRVectorValue.mm +++ b/Firestore/Source/API/FIRVectorValue.mm @@ -22,35 +22,13 @@ NS_ASSUME_NONNULL_BEGIN -@interface FIRVectorValue () { - /** Internal vector representation */ - std::vector _internalValue; -} - -@end - @implementation FIRVectorValue -- (NSArray *)array { - size_t length = _internalValue.size(); - NSMutableArray *outArray = - [[NSMutableArray alloc] initWithCapacity:length]; - for (size_t i = 0; i < length; i++) { - [outArray addObject:[[NSNumber alloc] initWithDouble:self->_internalValue.at(i)]]; - } - - return outArray; -} +@synthesize array = _internalValue; - (instancetype)initWithArray:(NSArray *)array { if (self = [super init]) { - std::vector converted; - converted.reserve(array.count); - for (NSNumber *value in array) { - converted.emplace_back([value doubleValue]); - } - - _internalValue = std::move(converted); + _internalValue = [NSArray arrayWithArray:array]; } return self; } @@ -66,15 +44,7 @@ - (BOOL)isEqual:(nullable id)object { FIRVectorValue *otherVector = ((FIRVectorValue *)object); - if (self->_internalValue.size() != otherVector->_internalValue.size()) { - return NO; - } - - for (size_t i = 0; i < self->_internalValue.size(); i++) { - if (self->_internalValue[i] != otherVector->_internalValue[i]) return NO; - } - - return YES; + return [self.array isEqualToArray:otherVector.array]; } @end From 07a6409f1efda930fe418e096f6f5b91c2211e3e Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Thu, 8 Aug 2024 16:05:21 -0600 Subject: [PATCH 23/24] Refined Swift API --- ...DistanceMeasure.h => FIRDistanceMeasure.h} | 2 +- .../FirebaseFirestore/FIRVectorSource.h | 15 +++++++ ...DistanceMeasure.h => FIRDistanceMeasure.h} | 8 ++-- .../Public/FirebaseFirestore/FIRFieldPath.h | 3 +- .../FirebaseFirestore/FIRFindNearestOptions.h | 7 ++-- .../Public/FirebaseFirestore/FIRQuery.h | 25 ++---------- .../Public/FirebaseFirestore/FIRVectorQuery.h | 7 ++-- ...estoreVectorSource.h => FIRVectorSource.h} | 6 +-- .../SwiftAPI/FIeldPath+Expressible.swift | 38 ++++++++++++++++++ .../SwiftAPI/FindNearestOptions+Swift.swift | 31 +++++++++++++++ .../Swift/Source/SwiftAPI/Query+Swift.swift | 39 +++++++++++++++++++ .../Source/SwiftAPI/VectorQuery+Swift.swift | 27 +++++++++++++ .../Integration/VectorIntegrationTests.swift | 13 +++++++ 13 files changed, 182 insertions(+), 39 deletions(-) rename FirebaseFirestoreInternal/FirebaseFirestore/{FIRFirestoreDistanceMeasure.h => FIRDistanceMeasure.h} (89%) create mode 100644 FirebaseFirestoreInternal/FirebaseFirestore/FIRVectorSource.h rename Firestore/Source/Public/FirebaseFirestore/{FIRFirestoreDistanceMeasure.h => FIRDistanceMeasure.h} (52%) rename Firestore/Source/Public/FirebaseFirestore/{FIRFirestoreVectorSource.h => FIRVectorSource.h} (53%) create mode 100644 Firestore/Swift/Source/SwiftAPI/FIeldPath+Expressible.swift create mode 100644 Firestore/Swift/Source/SwiftAPI/FindNearestOptions+Swift.swift create mode 100644 Firestore/Swift/Source/SwiftAPI/Query+Swift.swift create mode 100644 Firestore/Swift/Source/SwiftAPI/VectorQuery+Swift.swift diff --git a/FirebaseFirestoreInternal/FirebaseFirestore/FIRFirestoreDistanceMeasure.h b/FirebaseFirestoreInternal/FirebaseFirestore/FIRDistanceMeasure.h similarity index 89% rename from FirebaseFirestoreInternal/FirebaseFirestore/FIRFirestoreDistanceMeasure.h rename to FirebaseFirestoreInternal/FirebaseFirestore/FIRDistanceMeasure.h index 564cd3e538a..e2cd4ba65d5 100644 --- a/FirebaseFirestoreInternal/FirebaseFirestore/FIRFirestoreDistanceMeasure.h +++ b/FirebaseFirestoreInternal/FirebaseFirestore/FIRDistanceMeasure.h @@ -12,4 +12,4 @@ // See the License for the specific language governing permissions and // limitations under the License. -#import +#import diff --git a/FirebaseFirestoreInternal/FirebaseFirestore/FIRVectorSource.h b/FirebaseFirestoreInternal/FirebaseFirestore/FIRVectorSource.h new file mode 100644 index 00000000000..ac85c462d20 --- /dev/null +++ b/FirebaseFirestoreInternal/FirebaseFirestore/FIRVectorSource.h @@ -0,0 +1,15 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreDistanceMeasure.h b/Firestore/Source/Public/FirebaseFirestore/FIRDistanceMeasure.h similarity index 52% rename from Firestore/Source/Public/FirebaseFirestore/FIRFirestoreDistanceMeasure.h rename to Firestore/Source/Public/FirebaseFirestore/FIRDistanceMeasure.h index 82761da81c9..e35173a4dcb 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreDistanceMeasure.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRDistanceMeasure.h @@ -7,8 +7,8 @@ #import -typedef NS_ENUM(NSUInteger, FIRFirestoreDistanceMeasure) { - FIRFirestoreDistanceMeasureCosine, - FIRFirestoreDistanceMeasureEuclidean, - FIRFirestoreDistanceMeasureDotProduct +typedef NS_ENUM(NSUInteger, FIRDistanceMeasure) { + FIRDistanceMeasureCosine, + FIRDistanceMeasureEuclidean, + FIRDistanceMeasureDotProduct } NS_SWIFT_NAME(FirestoreDistanceMeasure); diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFieldPath.h b/Firestore/Source/Public/FirebaseFirestore/FIRFieldPath.h index 9f64fbdc99d..781c3e19b6b 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFieldPath.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFieldPath.h @@ -36,7 +36,8 @@ NS_SWIFT_NAME(FieldPath) * @param fieldNames A list of field names. * @return A `FieldPath` that points to a field location in a document. */ -- (instancetype)initWithFields:(NSArray *)fieldNames NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithFields:(NSArray *)fieldNames + NS_SWIFT_NAME(init(_:)) NS_DESIGNATED_INITIALIZER; /** * A special sentinel `FieldPath` to refer to the ID of a document. It can be used in queries to diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h b/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h index d8d3598dca4..2ccd4b8f6d7 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRFindNearestOptions.h @@ -22,16 +22,15 @@ NS_ASSUME_NONNULL_BEGIN NS_SWIFT_NAME(FindNearestOptions) @interface FIRFindNearestOptions : NSObject @property(nonatomic, readonly) FIRFieldPath *distanceResultFieldPath; -@property(nonatomic, readonly) NSNumber *distanceThreshold; +@property(nonatomic, readonly) NSNumber *distanceThreshold NS_REFINED_FOR_SWIFT; - (nonnull instancetype)init NS_DESIGNATED_INITIALIZER; -- (nonnull FIRFindNearestOptions *)optionsWithDistanceResultField:(NSString *)distanceResultField; - - (nonnull FIRFindNearestOptions *)optionsWithDistanceResultFieldPath: (FIRFieldPath *)distanceResultFieldPath; -- (nonnull FIRFindNearestOptions *)optionsWithDistanceThreshold:(NSNumber *)distanceThreshold; +- (nonnull FIRFindNearestOptions *)optionsWithDistanceThreshold:(NSNumber *)distanceThreshold + NS_REFINED_FOR_SWIFT; @end diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h b/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h index c4cbad2c018..04912f9ab7c 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRQuery.h @@ -16,8 +16,8 @@ #import +#import "FIRDistanceMeasure.h" #import "FIRFindNearestOptions.h" -#import "FIRFirestoreDistanceMeasure.h" #import "FIRFirestoreSource.h" #import "FIRListenerRegistration.h" #import "FIRSnapshotListenOptions.h" @@ -53,31 +53,12 @@ NS_SWIFT_NAME(Query) /** The `Firestore` instance that created this query (useful for performing transactions, etc.). */ @property(nonatomic, strong, readonly) FIRFirestore *firestore; -- (nonnull FIRVectorQuery *)findNearestWithField:(nonnull NSString *)field - queryVector:(nonnull NSArray *)queryVector - limit:(int64_t)limit - distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure - NS_SWIFT_NAME(findNearest(field:queryVector:limit:distanceMeasure:)); - -- (nonnull FIRVectorQuery *)findNearestWithFieldPath:(nonnull FIRFieldPath *)fieldPath - queryVectorValue:(nonnull FIRVectorValue *)queryVectorValue - limit:(int64_t)limit - distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure - NS_SWIFT_NAME(findNearest(fieldPath:queryVectorValue:limit:distanceMeasure:)); - -- (nonnull FIRVectorQuery *)findNearestWithField:(nonnull NSString *)field - queryVector:(nonnull NSArray *)queryVector - limit:(int64_t)limit - distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure - options:(nonnull FIRFindNearestOptions *)options - NS_SWIFT_NAME(findNearest(field:queryVector:limit:distanceMeasure:options:)); - - (nonnull FIRVectorQuery *)findNearestWithFieldPath:(nonnull FIRFieldPath *)fieldPath queryVectorValue:(nonnull FIRVectorValue *)queryVectorValue limit:(int64_t)limit - distanceMeasure:(FIRFirestoreDistanceMeasure)distanceMeasure + distanceMeasure:(FIRDistanceMeasure)distanceMeasure options:(nonnull FIRFindNearestOptions *)options - NS_SWIFT_NAME(findNearest(fieldPath:queryVectorValue:limit:distanceMeasure:options:)); + NS_REFINED_FOR_SWIFT; #pragma mark - Retrieving Data /** diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h index 7af9bae7da0..caea1c002b7 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorQuery.h @@ -6,8 +6,8 @@ // #import -#import "FIRFirestoreVectorSource.h" #import "FIRVectorQuerySnapshot.h" +#import "FIRVectorSource.h" @class FIRAggregateQuery; @class FIRAggregateField; @@ -32,10 +32,9 @@ NS_SWIFT_NAME(VectorQuery) * @param completion a block to execute once the results have been successfully read. * snapshot will be `nil` only if error is `non-nil`. */ -- (void)getDocumentsWithSource:(FIRFirestoreVectorSource)source +- (void)getDocumentsWithSource:(FIRVectorSource)source completion:(void (^)(FIRVectorQuerySnapshot *_Nullable snapshot, - NSError *_Nullable error))completion - NS_SWIFT_NAME(getDocuments(source:completion:)); + NSError *_Nullable error))completion NS_REFINED_FOR_SWIFT; @end diff --git a/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h b/Firestore/Source/Public/FirebaseFirestore/FIRVectorSource.h similarity index 53% rename from Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h rename to Firestore/Source/Public/FirebaseFirestore/FIRVectorSource.h index fa5eab63251..5d4e995f10b 100644 --- a/Firestore/Source/Public/FirebaseFirestore/FIRFirestoreVectorSource.h +++ b/Firestore/Source/Public/FirebaseFirestore/FIRVectorSource.h @@ -7,6 +7,6 @@ #import -typedef NS_ENUM(NSUInteger, FIRFirestoreVectorSource) { - FIRFirestoreVectorSourceServer, -} NS_SWIFT_NAME(FirestoreVectorSource); +typedef NS_ENUM(NSUInteger, FIRVectorSource) { + FIRVectorSourceServer, +} NS_SWIFT_NAME(VectorSource); diff --git a/Firestore/Swift/Source/SwiftAPI/FIeldPath+Expressible.swift b/Firestore/Swift/Source/SwiftAPI/FIeldPath+Expressible.swift new file mode 100644 index 00000000000..35c5cb3d157 --- /dev/null +++ b/Firestore/Swift/Source/SwiftAPI/FIeldPath+Expressible.swift @@ -0,0 +1,38 @@ +// +// FIeldPath+Expressible.swift +// FirebaseFirestore +// +// Created by Mark Duckworth on 8/2/24. +// + +import Foundation + +#if SWIFT_PACKAGE + @_exported import FirebaseFirestoreInternalWrapper +#else + @_exported import FirebaseFirestoreInternal +#endif // SWIFT_PACKAGE +// extension FieldPath : ExpressibleByStringLiteral { +// public required convenience init(stringLiteral: String) { +// self.init([stringLiteral]) +// } +// } + +private protocol ExpressibleByStringLiteralFieldPath: ExpressibleByStringLiteral { + init(_: [String]) +} + +/** + * An extension of VectorValue that implements the behavior of the Codable protocol. + * + * Note: this is implemented manually here because the Swift compiler can't synthesize these methods + * when declaring an extension to conform to Codable. + */ +extension ExpressibleByStringLiteralFieldPath { + public init(stringLiteral: String) { + self.init([stringLiteral]) + } +} + +/** Extends VectorValue to conform to Codable. */ +extension FieldPath: ExpressibleByStringLiteralFieldPath {} diff --git a/Firestore/Swift/Source/SwiftAPI/FindNearestOptions+Swift.swift b/Firestore/Swift/Source/SwiftAPI/FindNearestOptions+Swift.swift new file mode 100644 index 00000000000..9c48d90ed81 --- /dev/null +++ b/Firestore/Swift/Source/SwiftAPI/FindNearestOptions+Swift.swift @@ -0,0 +1,31 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if SWIFT_PACKAGE + @_exported import FirebaseFirestoreInternalWrapper +#else + @_exported import FirebaseFirestoreInternal +#endif // SWIFT_PACKAGE + +public extension FindNearestOptions { + var distanceThreshold: Double { + return __distanceThreshold.doubleValue + } + + func withDistanceThreshold(_ distanceThreshold: Double) -> FindNearestOptions { + return __withDistanceThreshold(NSNumber(value: distanceThreshold)) + } +} diff --git a/Firestore/Swift/Source/SwiftAPI/Query+Swift.swift b/Firestore/Swift/Source/SwiftAPI/Query+Swift.swift new file mode 100644 index 00000000000..f54e3189110 --- /dev/null +++ b/Firestore/Swift/Source/SwiftAPI/Query+Swift.swift @@ -0,0 +1,39 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if SWIFT_PACKAGE + @_exported import FirebaseFirestoreInternalWrapper +#else + @_exported import FirebaseFirestoreInternal +#endif // SWIFT_PACKAGE + +public extension Query { + func findNearest(fieldPath: FieldPath, + queryVector: VectorValue, + limit: Int64, + distanceMeasure: FirestoreDistanceMeasure, + options: FindNearestOptions = FindNearestOptions()) -> VectorQuery { + fatalError("not implemented") + } + + func findNearest(fieldPath: FieldPath, + queryVector: [Double], + limit: Int64, + distanceMeasure: FirestoreDistanceMeasure, + options: FindNearestOptions = FindNearestOptions()) -> VectorQuery { + fatalError("not implemented") + } +} diff --git a/Firestore/Swift/Source/SwiftAPI/VectorQuery+Swift.swift b/Firestore/Swift/Source/SwiftAPI/VectorQuery+Swift.swift new file mode 100644 index 00000000000..558c05bea9a --- /dev/null +++ b/Firestore/Swift/Source/SwiftAPI/VectorQuery+Swift.swift @@ -0,0 +1,27 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if SWIFT_PACKAGE + @_exported import FirebaseFirestoreInternalWrapper +#else + @_exported import FirebaseFirestoreInternal +#endif // SWIFT_PACKAGE + +public extension VectorQuery { + func getDocuments(source: VectorSource) async throws -> VectorQuerySnapshot { + fatalError("not implemented") + } +} diff --git a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift index 069a71a6541..0db7bce41c0 100644 --- a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift @@ -21,6 +21,19 @@ import Foundation // iOS 15 required for test implementation, not vector feature @available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *) class VectorIntegrationTests: FSTIntegrationTestCase { + func exampleFindNearest() async throws { + let collection = collectionRef() + + let vectorQuery = collection.findNearest( + fieldPath: "embedding", + queryVector: [1.0, 2.0, 3.0], + limit: 10, + distanceMeasure: FirestoreDistanceMeasure.cosine + ) + + try await vectorQuery.getDocuments(source: VectorSource.server) + } + func testWriteAndReadVectorEmbeddings() async throws { let collection = collectionRef() From d9eea7a58c2edf8fa0e5ad09aa83f6ff8a2e1400 Mon Sep 17 00:00:00 2001 From: Mark Duckworth <1124037+MarkDuckworth@users.noreply.github.com> Date: Thu, 8 Aug 2024 16:10:22 -0600 Subject: [PATCH 24/24] another example --- .../Integration/VectorIntegrationTests.swift | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift index 0db7bce41c0..22740e8058d 100644 --- a/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift +++ b/Firestore/Swift/Tests/Integration/VectorIntegrationTests.swift @@ -34,6 +34,21 @@ class VectorIntegrationTests: FSTIntegrationTestCase { try await vectorQuery.getDocuments(source: VectorSource.server) } + func exampleFindNearestWithOptions() async throws { + let collection = collectionRef() + + let vectorQuery = collection.findNearest( + fieldPath: "embedding", + queryVector: [1.0, 2.0, 3.0], + limit: 10, + distanceMeasure: FirestoreDistanceMeasure.cosine, + options: FindNearestOptions().withDistanceResultFieldPath("distance") + .withDistanceThreshold(0.5) + ) + + try await vectorQuery.getDocuments(source: VectorSource.server) + } + func testWriteAndReadVectorEmbeddings() async throws { let collection = collectionRef()