Skip to content

Commit 81bfa0d

Browse files
feat: add support for the VectorValue type (#1716)
Implement VectorValue type support.
1 parent 4384970 commit 81bfa0d

File tree

11 files changed

+509
-6
lines changed

11 files changed

+509
-6
lines changed

google-cloud-firestore/src/main/java/com/google/cloud/firestore/CustomClassMapper.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ private static <T> Object serialize(T o, ErrorPath path) {
174174
|| o instanceof Blob
175175
|| o instanceof DocumentReference
176176
|| o instanceof FieldValue
177-
|| o instanceof Value) {
177+
|| o instanceof Value
178+
|| o instanceof VectorValue) {
178179
return o;
179180
} else if (o instanceof Instant) {
180181
Instant instant = (Instant) o;
@@ -243,6 +244,8 @@ private static <T> T deserializeToClass(Object o, Class<T> clazz, DeserializeCon
243244
return (T) convertBlob(o, context);
244245
} else if (GeoPoint.class.isAssignableFrom(clazz)) {
245246
return (T) convertGeoPoint(o, context);
247+
} else if (VectorValue.class.isAssignableFrom(clazz)) {
248+
return (T) convertVectorValue(o, context);
246249
} else if (DocumentReference.class.isAssignableFrom(clazz)) {
247250
return (T) convertDocumentReference(o, context);
248251
} else if (clazz.isArray()) {
@@ -596,6 +599,16 @@ private static GeoPoint convertGeoPoint(Object o, DeserializeContext context) {
596599
}
597600
}
598601

602+
private static VectorValue convertVectorValue(Object o, DeserializeContext context) {
603+
if (o instanceof VectorValue) {
604+
return (VectorValue) o;
605+
} else {
606+
throw deserializeError(
607+
context.errorPath,
608+
"Failed to convert value of type " + o.getClass().getName() + " to VectorValue");
609+
}
610+
}
611+
599612
private static DocumentReference convertDocumentReference(Object o, DeserializeContext context) {
600613
if (o instanceof DocumentReference) {
601614
return (DocumentReference) o;

google-cloud-firestore/src/main/java/com/google/cloud/firestore/DocumentSnapshot.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,18 @@ public GeoPoint getGeoPoint(@Nonnull String field) {
385385
return (GeoPoint) get(field);
386386
}
387387

388+
/**
389+
* Returns the value of the field as a VectorValue.
390+
*
391+
* @param field The path to the field.
392+
* @throws RuntimeException if the value is not a VectorValue.
393+
* @return The value of the field.
394+
*/
395+
@Nullable
396+
public VectorValue getVectorValue(@Nonnull String field) {
397+
return (VectorValue) get(field);
398+
}
399+
388400
/**
389401
* Gets the reference to the document.
390402
*

google-cloud-firestore/src/main/java/com/google/cloud/firestore/FieldValue.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,17 @@ public static FieldValue arrayRemove(@Nonnull Object... elements) {
319319
return new ArrayRemoveFieldValue(Arrays.asList(elements));
320320
}
321321

322+
/**
323+
* Creates a new {@link VectorValue} constructed with a copy of the given array of doubles.
324+
*
325+
* @param values Create a {@link VectorValue} instance with a copy of this array of doubles.
326+
* @return A new {@link VectorValue} constructed with a copy of the given array of doubles.
327+
*/
328+
@Nonnull
329+
public static VectorValue vector(@Nonnull double[] values) {
330+
return new VectorValue(values);
331+
}
332+
322333
/** Whether this FieldTransform should be included in the document mask. */
323334
abstract boolean includeInDocumentMask();
324335

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.firestore;
18+
19+
abstract class MapType {
20+
static final String RESERVED_MAP_KEY = "__type__";
21+
static final String RESERVED_MAP_KEY_VECTOR_VALUE = "__vector__";
22+
static final String VECTOR_MAP_VECTORS_KEY = "value";
23+
}

google-cloud-firestore/src/main/java/com/google/cloud/firestore/Order.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616

1717
package com.google.cloud.firestore;
1818

19+
import com.google.firestore.v1.MapValue;
1920
import com.google.firestore.v1.Value;
2021
import com.google.firestore.v1.Value.ValueTypeCase;
2122
import com.google.protobuf.ByteString;
23+
import java.util.Collections;
2224
import java.util.Comparator;
2325
import java.util.Iterator;
2426
import java.util.List;
@@ -40,6 +42,7 @@ enum TypeOrder implements Comparable<TypeOrder> {
4042
REF,
4143
GEO_POINT,
4244
ARRAY,
45+
VECTOR,
4346
OBJECT;
4447

4548
static TypeOrder fromValue(Value value) {
@@ -65,13 +68,24 @@ static TypeOrder fromValue(Value value) {
6568
case ARRAY_VALUE:
6669
return ARRAY;
6770
case MAP_VALUE:
68-
return OBJECT;
71+
return fromMapValue(value.getMapValue());
6972
default:
7073
throw new IllegalArgumentException("Could not detect value type for " + value);
7174
}
7275
}
7376
}
7477

78+
static TypeOrder fromMapValue(MapValue mapValue) {
79+
switch (UserDataConverter.detectMapRepresentation(mapValue)) {
80+
case VECTOR_VALUE:
81+
return TypeOrder.VECTOR;
82+
case UNKNOWN:
83+
case NONE:
84+
default:
85+
return TypeOrder.OBJECT;
86+
}
87+
}
88+
7589
static final Order INSTANCE = new Order();
7690

7791
private Order() {}
@@ -113,6 +127,8 @@ public int compare(@Nonnull Value left, @Nonnull Value right) {
113127
left.getArrayValue().getValuesList(), right.getArrayValue().getValuesList());
114128
case OBJECT:
115129
return compareObjects(left, right);
130+
case VECTOR:
131+
return compareVectors(left, right);
116132
default:
117133
throw new IllegalArgumentException("Cannot compare " + leftType);
118134
}
@@ -209,6 +225,30 @@ private int compareObjects(Value left, Value right) {
209225
return Boolean.compare(leftIterator.hasNext(), rightIterator.hasNext());
210226
}
211227

228+
private int compareVectors(Value left, Value right) {
229+
// The vector is a map, but only vector value is compared.
230+
Value leftValueField =
231+
left.getMapValue().getFieldsOrDefault(MapType.VECTOR_MAP_VECTORS_KEY, null);
232+
Value rightValueField =
233+
right.getMapValue().getFieldsOrDefault(MapType.VECTOR_MAP_VECTORS_KEY, null);
234+
235+
List<Value> leftArray =
236+
(leftValueField != null)
237+
? leftValueField.getArrayValue().getValuesList()
238+
: Collections.emptyList();
239+
List<Value> rightArray =
240+
(rightValueField != null)
241+
? rightValueField.getArrayValue().getValuesList()
242+
: Collections.emptyList();
243+
244+
Integer lengthCompare = Long.compare(leftArray.size(), rightArray.size());
245+
if (lengthCompare != 0) {
246+
return lengthCompare;
247+
}
248+
249+
return compareArrays(leftArray, rightArray);
250+
}
251+
212252
private int compareNumbers(Value left, Value right) {
213253
if (left.getValueTypeCase() == ValueTypeCase.DOUBLE_VALUE) {
214254
if (right.getValueTypeCase() == ValueTypeCase.DOUBLE_VALUE) {

google-cloud-firestore/src/main/java/com/google/cloud/firestore/UserDataConverter.java

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import com.google.common.base.Preconditions;
2121
import com.google.common.collect.Lists;
2222
import com.google.common.collect.Maps;
23+
import com.google.common.primitives.Doubles;
2324
import com.google.firestore.v1.ArrayValue;
2425
import com.google.firestore.v1.MapValue;
2526
import com.google.firestore.v1.Value;
@@ -32,10 +33,12 @@
3233
import java.util.List;
3334
import java.util.Map;
3435
import java.util.concurrent.TimeUnit;
36+
import java.util.logging.Logger;
3537
import javax.annotation.Nullable;
3638

3739
/** Converts user input into the Firestore Value representation. */
3840
class UserDataConverter {
41+
private static final Logger LOGGER = Logger.getLogger(UserDataConverter.class.getName());
3942

4043
/** Controls the behavior for field deletes. */
4144
interface EncodingOptions {
@@ -183,12 +186,34 @@ static Value encodeValue(
183186
// send the map.
184187
return null;
185188
}
189+
} else if (sanitizedObject instanceof VectorValue) {
190+
VectorValue vectorValue = (VectorValue) sanitizedObject;
191+
return Value.newBuilder().setMapValue(vectorValue.toProto()).build();
186192
}
187193

188194
throw FirestoreException.forInvalidArgument(
189195
"Cannot convert %s to Firestore Value", sanitizedObject);
190196
}
191197

198+
static MapValue encodeVector(double[] rawVector) {
199+
MapValue.Builder res = MapValue.newBuilder();
200+
201+
res.putFields(
202+
MapType.RESERVED_MAP_KEY,
203+
encodeValue(
204+
FieldPath.fromDotSeparatedString(MapType.RESERVED_MAP_KEY),
205+
MapType.RESERVED_MAP_KEY_VECTOR_VALUE,
206+
ARGUMENT));
207+
res.putFields(
208+
MapType.VECTOR_MAP_VECTORS_KEY,
209+
encodeValue(
210+
FieldPath.fromDotSeparatedString(MapType.RESERVED_MAP_KEY_VECTOR_VALUE),
211+
Doubles.asList(rawVector),
212+
ARGUMENT));
213+
214+
return res.build();
215+
}
216+
192217
static Object decodeValue(FirestoreRpcContext<?> rpcContext, Value v) {
193218
Value.ValueTypeCase typeCase = v.getValueTypeCase();
194219
switch (typeCase) {
@@ -220,18 +245,72 @@ static Object decodeValue(FirestoreRpcContext<?> rpcContext, Value v) {
220245
}
221246
return list;
222247
case MAP_VALUE:
248+
return decodeMap(rpcContext, v.getMapValue());
249+
default:
250+
throw FirestoreException.forInvalidArgument(
251+
String.format("Unknown Value Type: %s", typeCase));
252+
}
253+
}
254+
255+
static Object decodeMap(FirestoreRpcContext<?> rpcContext, MapValue mapValue) {
256+
MapRepresentation mapRepresentation = detectMapRepresentation(mapValue);
257+
Map<String, Value> inputMap = mapValue.getFieldsMap();
258+
switch (mapRepresentation) {
259+
case UNKNOWN:
260+
LOGGER.warning(
261+
"Parsing unknown map type as generic map. This map type may be supported in a newer SDK version.");
262+
case NONE:
223263
Map<String, Object> outputMap = new HashMap<>();
224-
Map<String, Value> inputMap = v.getMapValue().getFieldsMap();
225264
for (Map.Entry<String, Value> entry : inputMap.entrySet()) {
226265
outputMap.put(entry.getKey(), decodeValue(rpcContext, entry.getValue()));
227266
}
228267
return outputMap;
268+
case VECTOR_VALUE:
269+
double[] values =
270+
inputMap.get(MapType.VECTOR_MAP_VECTORS_KEY).getArrayValue().getValuesList().stream()
271+
.mapToDouble(val -> val.getDoubleValue())
272+
.toArray();
273+
return new VectorValue(values);
229274
default:
230275
throw FirestoreException.forInvalidArgument(
231-
String.format("Unknown Value Type: %s", typeCase));
276+
String.format("Unsupported MapRepresentation: %s", mapRepresentation));
232277
}
233278
}
234279

280+
/** Indicates the data type represented by a MapValue. */
281+
enum MapRepresentation {
282+
/** The MapValue represents an unknown data type. */
283+
UNKNOWN,
284+
/** The MapValue does not represent any special data type. */
285+
NONE,
286+
/** The MapValue represents a VectorValue. */
287+
VECTOR_VALUE
288+
}
289+
290+
static MapRepresentation detectMapRepresentation(MapValue mapValue) {
291+
Map<String, Value> fields = mapValue.getFieldsMap();
292+
if (!fields.containsKey(MapType.RESERVED_MAP_KEY)) {
293+
return MapRepresentation.NONE;
294+
}
295+
296+
Value typeValue = fields.get(MapType.RESERVED_MAP_KEY);
297+
if (typeValue.getValueTypeCase() != Value.ValueTypeCase.STRING_VALUE) {
298+
LOGGER.warning(
299+
"Unable to parse __type__ field of map. Unsupported value type: "
300+
+ typeValue.getValueTypeCase().toString());
301+
return MapRepresentation.UNKNOWN;
302+
}
303+
304+
String typeString = typeValue.getStringValue();
305+
306+
if (typeString.equals(MapType.RESERVED_MAP_KEY_VECTOR_VALUE)) {
307+
return MapRepresentation.VECTOR_VALUE;
308+
}
309+
310+
LOGGER.warning("Unsupported __type__ value for map: " + typeString);
311+
return MapRepresentation.UNKNOWN;
312+
}
313+
235314
static Object decodeGoogleProtobufValue(com.google.protobuf.Value v) {
236315
switch (v.getKindCase()) {
237316
case NULL_VALUE:
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.firestore;
18+
19+
import com.google.firestore.v1.MapValue;
20+
import java.io.Serializable;
21+
import java.util.Arrays;
22+
import javax.annotation.Nonnull;
23+
import javax.annotation.Nullable;
24+
25+
/**
26+
* Represents a vector in Firestore documents. Create an instance with {@link FieldValue#vector}.
27+
*/
28+
public final class VectorValue implements Serializable {
29+
private final double[] values;
30+
31+
VectorValue(@Nullable double[] values) {
32+
if (values == null) this.values = new double[] {};
33+
else this.values = values.clone();
34+
}
35+
36+
/**
37+
* Returns a representation of the vector as an array of doubles.
38+
*
39+
* @return A representation of the vector as an array of doubles
40+
*/
41+
@Nonnull
42+
public double[] toArray() {
43+
return this.values.clone();
44+
}
45+
46+
/**
47+
* Returns true if this VectorValue is equal to the provided object.
48+
*
49+
* @param obj The object to compare against.
50+
* @return Whether this VectorValue is equal to the provided object.
51+
*/
52+
@Override
53+
public boolean equals(Object obj) {
54+
if (this == obj) {
55+
return true;
56+
}
57+
if (obj == null || getClass() != obj.getClass()) {
58+
return false;
59+
}
60+
VectorValue otherArray = (VectorValue) obj;
61+
return Arrays.equals(this.values, otherArray.values);
62+
}
63+
64+
@Override
65+
public int hashCode() {
66+
return Arrays.hashCode(values);
67+
}
68+
69+
MapValue toProto() {
70+
return UserDataConverter.encodeVector(this.values);
71+
}
72+
}

google-cloud-firestore/src/test/java/com/google/cloud/firestore/DocumentReferenceTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,8 @@ public void extractFieldMaskFromMerge() throws Exception {
880880
"second.objectValue.foo",
881881
"second.timestampValue",
882882
"second.trueValue",
883-
"second.model.foo");
883+
"second.model.foo",
884+
"second.vectorValue");
884885

885886
CommitRequest expectedCommit = commit(set(nestedUpdate, updateMask));
886887
assertCommitEquals(expectedCommit, commitCapture.getValue());

0 commit comments

Comments
 (0)