Skip to content
This repository was archived by the owner on Mar 4, 2026. It is now read-only.

Commit b9ea514

Browse files
ehsannasMarkDuckworthgcf-owl-bot[bot]
authored
feat: Sum and Average aggregations (#1873)
* feat: Sum and Average Aggregations. (#1831) * WIP WIP. * also add isEqual. * lint. * add aggregate() API in firestore.d.ts. * Use NodeJS.Dict * Update isEqual. * Add isEqual unit test. * better api. * more tests. * Add more tests. * Remove test code that targets emulator. * rename avg() to average(). * rename avg to average. * Fix lint errors. * Clean up. * Address code review comments. * Expose aggregate type and field publicly. * Fix the way assert was imported. * lint. * backport test updates. * feat: Add long-alias support for aggregations. (#1844) * feat: Add long-alias support for aggregations. * Lint fix and fix unit tests. * update assertion message. * Unhide APIs and enable tests (#1869) * Enable tests. * Remove @internal annotation. * Address comments. * Address comments. * Removing AggregateField.field from new api. * Fix failing test assertion with REST transport. * Fix test query. * Add more tests with cursors. * prettier. * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --------- Co-authored-by: Mark Duckworth <[email protected]> Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent f043d1c commit b9ea514

8 files changed

Lines changed: 2221 additions & 21 deletions

File tree

dev/src/aggregate.ts

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
import * as firestore from '@google-cloud/firestore';
19+
20+
import {FieldPath} from './path';
21+
import {google} from '../protos/firestore_v1_proto_api';
22+
23+
import IAggregation = google.firestore.v1.StructuredAggregationQuery.IAggregation;
24+
import * as assert from 'assert';
25+
26+
/**
27+
* Concrete implementation of the Aggregate type.
28+
*/
29+
export class Aggregate {
30+
constructor(
31+
readonly alias: string,
32+
readonly aggregateType: AggregateType,
33+
readonly fieldPath?: string | FieldPath
34+
) {}
35+
36+
/**
37+
* Converts this object to the proto representation of an Aggregate.
38+
* @internal
39+
*/
40+
toProto(): IAggregation {
41+
const proto: IAggregation = {};
42+
if (this.aggregateType === 'count') {
43+
proto.count = {};
44+
} else if (this.aggregateType === 'sum') {
45+
assert(
46+
this.fieldPath !== undefined,
47+
'Missing field path for sum aggregation.'
48+
);
49+
proto.sum = {
50+
field: {
51+
fieldPath: FieldPath.fromArgument(this.fieldPath!).formattedName,
52+
},
53+
};
54+
} else if (this.aggregateType === 'avg') {
55+
assert(
56+
this.fieldPath !== undefined,
57+
'Missing field path for average aggregation.'
58+
);
59+
proto.avg = {
60+
field: {
61+
fieldPath: FieldPath.fromArgument(this.fieldPath!).formattedName,
62+
},
63+
};
64+
} else {
65+
throw new Error(`Aggregate type ${this.aggregateType} unimplemented.`);
66+
}
67+
proto.alias = this.alias;
68+
return proto;
69+
}
70+
}
71+
72+
/**
73+
* Represents an aggregation that can be performed by Firestore.
74+
*/
75+
export class AggregateField<T> implements firestore.AggregateField<T> {
76+
/** A type string to uniquely identify instances of this class. */
77+
readonly type = 'AggregateField';
78+
79+
/**
80+
* The field on which the aggregation is performed.
81+
* @internal
82+
**/
83+
public readonly _field?: string | FieldPath;
84+
85+
/**
86+
* Create a new AggregateField<T>
87+
* @param aggregateType Specifies the type of aggregation operation to perform.
88+
* @param field Optionally specifies the field that is aggregated.
89+
* @internal
90+
*/
91+
private constructor(
92+
public readonly aggregateType: AggregateType,
93+
field?: string | FieldPath
94+
) {
95+
this._field = field;
96+
}
97+
98+
/**
99+
* Compares this object with the given object for equality.
100+
*
101+
* This object is considered "equal" to the other object if and only if
102+
* `other` performs the same kind of aggregation on the same field (if any).
103+
*
104+
* @param other The object to compare to this object for equality.
105+
* @return `true` if this object is "equal" to the given object, as
106+
* defined above, or `false` otherwise.
107+
*/
108+
isEqual(other: AggregateField<T>): boolean {
109+
return (
110+
other instanceof AggregateField &&
111+
this.aggregateType === other.aggregateType &&
112+
((this._field === undefined && other._field === undefined) ||
113+
(this._field !== undefined &&
114+
other._field !== undefined &&
115+
FieldPath.fromArgument(this._field).isEqual(
116+
FieldPath.fromArgument(other._field)
117+
)))
118+
);
119+
}
120+
121+
/**
122+
* Create an AggregateField object that can be used to compute the count of
123+
* documents in the result set of a query.
124+
*/
125+
static count(): AggregateField<number> {
126+
return new AggregateField<number>('count');
127+
}
128+
129+
/**
130+
* Create an AggregateField object that can be used to compute the average of
131+
* a specified field over a range of documents in the result set of a query.
132+
* @param field Specifies the field to average across the result set.
133+
*/
134+
static average(field: string | FieldPath): AggregateField<number | null> {
135+
return new AggregateField<number | null>('avg', field);
136+
}
137+
138+
/**
139+
* Create an AggregateField object that can be used to compute the sum of
140+
* a specified field over a range of documents in the result set of a query.
141+
* @param field Specifies the field to sum across the result set.
142+
*/
143+
static sum(field: string | FieldPath): AggregateField<number> {
144+
return new AggregateField<number>('sum', field);
145+
}
146+
}
147+
148+
/**
149+
* A type whose property values are all `AggregateField` objects.
150+
*/
151+
export interface AggregateSpec {
152+
[field: string]: AggregateFieldType;
153+
}
154+
155+
/**
156+
* The union of all `AggregateField` types that are supported by Firestore.
157+
*/
158+
export type AggregateFieldType =
159+
| ReturnType<typeof AggregateField.count>
160+
| ReturnType<typeof AggregateField.sum>
161+
| ReturnType<typeof AggregateField.average>;
162+
163+
/**
164+
* Union type representing the aggregate type to be performed.
165+
*/
166+
export type AggregateType = 'count' | 'avg' | 'sum';

dev/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ export {GeoPoint} from './geo-point';
105105
export {CollectionGroup};
106106
export {QueryPartition} from './query-partition';
107107
export {setLogFunction} from './logger';
108+
export {AggregateField, Aggregate} from './aggregate';
108109

109110
const libVersion = require('../../package.json').version;
110111
setLibVersion(libVersion);

dev/src/reference.ts

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616

1717
import * as firestore from '@google-cloud/firestore';
18+
import * as assert from 'assert';
1819
import {Duplex, Readable, Transform} from 'stream';
1920
import * as deepEqual from 'fast-deep-equal';
2021
import {GoogleError} from 'google-gax';
@@ -44,6 +45,7 @@ import {
4445
autoId,
4546
Deferred,
4647
isPermanentRpcError,
48+
mapToArray,
4749
requestTag,
4850
wrapError,
4951
} from './util';
@@ -58,6 +60,7 @@ import {DocumentWatch, QueryWatch} from './watch';
5860
import {validateDocumentData, WriteBatch, WriteResult} from './write-batch';
5961
import api = protos.google.firestore.v1;
6062
import {CompositeFilter, Filter, UnaryFilter} from './filter';
63+
import {AggregateField, Aggregate, AggregateSpec} from './aggregate';
6164

6265
/**
6366
* The direction of a `Query.orderBy()` clause is specified as 'desc' or 'asc'
@@ -1848,7 +1851,47 @@ export class Query<
18481851
AppModelType,
18491852
DbModelType
18501853
> {
1851-
return new AggregateQuery(this, {count: {}});
1854+
return this.aggregate({
1855+
count: AggregateField.count(),
1856+
});
1857+
}
1858+
1859+
/**
1860+
* Returns a query that can perform the given aggregations.
1861+
*
1862+
* The returned query, when executed, calculates the specified aggregations
1863+
* over the documents in the result set of this query, without actually
1864+
* downloading the documents.
1865+
*
1866+
* Using the returned query to perform aggregations is efficient because only
1867+
* the final aggregation values, not the documents' data, is downloaded. The
1868+
* returned query can even perform aggregations of the documents if the result set
1869+
* would be prohibitively large to download entirely (e.g. thousands of documents).
1870+
*
1871+
* @param aggregateSpec An `AggregateSpec` object that specifies the aggregates
1872+
* to perform over the result set. The AggregateSpec specifies aliases for each
1873+
* aggregate, which can be used to retrieve the aggregate result.
1874+
* @example
1875+
* ```typescript
1876+
* const aggregateQuery = col.aggregate(query, {
1877+
* countOfDocs: count(),
1878+
* totalHours: sum('hours'),
1879+
* averageScore: average('score')
1880+
* });
1881+
*
1882+
* const aggregateSnapshot = await aggregateQuery.get();
1883+
* const countOfDocs: number = aggregateSnapshot.data().countOfDocs;
1884+
* const totalHours: number = aggregateSnapshot.data().totalHours;
1885+
* const averageScore: number | null = aggregateSnapshot.data().averageScore;
1886+
* ```
1887+
*/
1888+
aggregate<T extends firestore.AggregateSpec>(
1889+
aggregateSpec: T
1890+
): AggregateQuery<T, AppModelType, DbModelType> {
1891+
return new AggregateQuery<T, AppModelType, DbModelType>(
1892+
this,
1893+
aggregateSpec
1894+
);
18521895
}
18531896

18541897
/**
@@ -3163,12 +3206,15 @@ export class CollectionReference<
31633206
* A query that calculates aggregations over an underlying query.
31643207
*/
31653208
export class AggregateQuery<
3166-
AggregateSpecType extends firestore.AggregateSpec,
3209+
AggregateSpecType extends AggregateSpec,
31673210
AppModelType = firestore.DocumentData,
31683211
DbModelType extends firestore.DocumentData = firestore.DocumentData,
31693212
> implements
31703213
firestore.AggregateQuery<AggregateSpecType, AppModelType, DbModelType>
31713214
{
3215+
private readonly clientAliasToServerAliasMap: Record<string, string> = {};
3216+
private readonly serverAliasToClientAliasMap: Record<string, string> = {};
3217+
31723218
/**
31733219
* @private
31743220
* @internal
@@ -3181,7 +3227,19 @@ export class AggregateQuery<
31813227
// eslint-disable-next-line @typescript-eslint/no-explicit-any
31823228
private readonly _query: Query<AppModelType, DbModelType>,
31833229
private readonly _aggregates: AggregateSpecType
3184-
) {}
3230+
) {
3231+
// Client-side aliases may be too long and exceed the 1500-byte string size limit.
3232+
// Such long strings do not need to be transferred over the wire either.
3233+
// The client maps the user's alias to a short form alias and send that to the server.
3234+
let aggregationNum = 0;
3235+
for (const clientAlias in this._aggregates) {
3236+
if (Object.prototype.hasOwnProperty.call(this._aggregates, clientAlias)) {
3237+
const serverAlias = `aggregate_${aggregationNum++}`;
3238+
this.clientAliasToServerAliasMap[clientAlias] = serverAlias;
3239+
this.serverAliasToClientAliasMap[serverAlias] = clientAlias;
3240+
}
3241+
}
3242+
}
31853243

31863244
/** The query whose aggregations will be calculated by this object. */
31873245
get query(): Query<AppModelType, DbModelType> {
@@ -3323,12 +3381,17 @@ export class AggregateQuery<
33233381
if (fields) {
33243382
const serializer = this._query.firestore._serializer!;
33253383
for (const prop of Object.keys(fields)) {
3326-
if (this._aggregates[prop] === undefined) {
3384+
const alias = this.serverAliasToClientAliasMap[prop];
3385+
assert(
3386+
alias !== null && alias !== undefined,
3387+
`'${prop}' not present in server-client alias mapping.`
3388+
);
3389+
if (this._aggregates[alias] === undefined) {
33273390
throw new Error(
33283391
`Unexpected alias [${prop}] in result aggregate result`
33293392
);
33303393
}
3331-
data[prop] = serializer.decodeValue(fields[prop]);
3394+
data[alias] = serializer.decodeValue(fields[prop]);
33323395
}
33333396
}
33343397
return data;
@@ -3344,18 +3407,22 @@ export class AggregateQuery<
33443407
*/
33453408
toProto(transactionId?: Uint8Array): api.IRunAggregationQueryRequest {
33463409
const queryProto = this._query.toProto();
3347-
//TODO(tomandersen) inspect _query to build request - this is just hard
3348-
// coded count right now.
33493410
const runQueryRequest: api.IRunAggregationQueryRequest = {
33503411
parent: queryProto.parent,
33513412
structuredAggregationQuery: {
33523413
structuredQuery: queryProto.structuredQuery,
3353-
aggregations: [
3354-
{
3355-
alias: 'count',
3356-
count: {},
3357-
},
3358-
],
3414+
aggregations: mapToArray(this._aggregates, (aggregate, clientAlias) => {
3415+
const serverAlias = this.clientAliasToServerAliasMap[clientAlias];
3416+
assert(
3417+
serverAlias !== null && serverAlias !== undefined,
3418+
`'${clientAlias}' not present in client-server alias mapping.`
3419+
);
3420+
return new Aggregate(
3421+
serverAlias,
3422+
aggregate.aggregateType,
3423+
aggregate._field
3424+
).toProto();
3425+
}),
33593426
},
33603427
};
33613428

dev/src/util.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {randomBytes} from 'crypto';
2020
import type {CallSettings, ClientConfig, GoogleError} from 'google-gax';
2121
import type {BackoffSettings} from 'google-gax/build/src/gax';
2222
import * as gapicConfig from './v1/firestore_client_config.json';
23+
import Dict = NodeJS.Dict;
2324

2425
/**
2526
* A Promise implementation that supports deferred resolution.
@@ -246,3 +247,23 @@ export function tryGetPreferRestEnvironmentVariable(): boolean | undefined {
246247
return undefined;
247248
}
248249
}
250+
251+
/**
252+
* Returns an array of values that are calculated by performing the given `fn`
253+
* on all keys in the given `obj` dictionary.
254+
*
255+
* @private
256+
* @internal
257+
*/
258+
export function mapToArray<V, R>(
259+
obj: Dict<V>,
260+
fn: (element: V, key: string, obj: Dict<V>) => R
261+
): R[] {
262+
const result: R[] = [];
263+
for (const key in obj) {
264+
if (Object.prototype.hasOwnProperty.call(obj, key)) {
265+
result.push(fn(obj[key]!, key, obj));
266+
}
267+
}
268+
return result;
269+
}

0 commit comments

Comments
 (0)