diff --git a/.changeset/add-string-agg.md b/.changeset/add-string-agg.md new file mode 100644 index 000000000..899d38cf9 --- /dev/null +++ b/.changeset/add-string-agg.md @@ -0,0 +1,6 @@ +--- +'@tanstack/db-ivm': patch +'@tanstack/db': patch +--- + +Add `stringAgg` aggregate function for concatenating string values within groups. Supports configurable separators and ordering with efficient incremental maintenance via binary search and fast-path text splicing for head/tail changes. diff --git a/docs/guides/live-queries.md b/docs/guides/live-queries.md index 440d48834..40934f883 100644 --- a/docs/guides/live-queries.md +++ b/docs/guides/live-queries.md @@ -1122,7 +1122,7 @@ const userStats = createCollection(liveQueryCollectionOptions({ Use various aggregate functions to summarize your data: ```ts -import { count, sum, avg, min, max } from '@tanstack/db' +import { count, sum, avg, min, max, stringAgg } from '@tanstack/db' const orderStats = createCollection(liveQueryCollectionOptions({ query: (q) => @@ -1136,6 +1136,7 @@ const orderStats = createCollection(liveQueryCollectionOptions({ avgOrderValue: avg(order.amount), minOrder: min(order.amount), maxOrder: max(order.amount), + statusTimeline: stringAgg(order.status, ' -> ', order.createdAt), })) })) ``` @@ -2254,6 +2255,15 @@ min(user.salary) max(order.amount) ``` +#### `stringAgg(value)`, `stringAgg(value, orderBy)`, `stringAgg(value, separator)`, `stringAgg(value, separator, orderBy)` +Concatenate string values within each group. When `orderBy` is omitted, TanStack DB falls back to the source row key for deterministic ordering: +```ts +stringAgg(delta.text) // Deterministic fallback order by row key +stringAgg(delta.text, delta.createdAt) // Ordered by createdAt +stringAgg(delta.text, ' ') // Custom separator with fallback order by row key +stringAgg(delta.text, ' ', delta.seq) // Ordered by seq with separator +``` + ### Function Composition Functions can be composed and chained: diff --git a/packages/db-ivm/src/operators/groupBy.ts b/packages/db-ivm/src/operators/groupBy.ts index 9c2fe1e35..bca32a9ca 100644 --- a/packages/db-ivm/src/operators/groupBy.ts +++ b/packages/db-ivm/src/operators/groupBy.ts @@ -1,31 +1,62 @@ -import { serializeValue } from '../utils.js' +import { binarySearch, compareKeys, serializeValue } from '../utils.js' import { map } from './map.js' import { reduce } from './reduce.js' import type { IStreamBuilder, KeyValue } from '../types.js' type GroupKey = Record -type BasicAggregateFunction = { +type BasicAggregateFunction = { preMap: (data: T) => V - reduce: (values: Array<[V, number]>) => V - postMap?: (result: V) => R + reduce: (values: Array<[V, number]>, groupKey: string) => Reduced + postMap?: (result: Reduced) => R + cleanup?: (groupKey: string) => void } type PipedAggregateFunction = { pipe: (stream: IStreamBuilder) => IStreamBuilder> } -type AggregateFunction = - | BasicAggregateFunction +type AggregateFunction = + | BasicAggregateFunction | PipedAggregateFunction type ExtractAggregateReturnType = - A extends AggregateFunction ? R : never + A extends AggregateFunction ? R : never type AggregatesReturnType = { [K in keyof A]: ExtractAggregateReturnType } +type StringAggOrderable = + | string + | number + | bigint + | boolean + | Date + | null + | undefined + +type StringAggValue = + { + rowKey?: string | number + value: string | null | undefined + orderBy: TOrderBy + } + +type StringAggEntry = + { + rowKey: string | number + value: string + orderBy: TOrderBy + } + +type StringAggState = + { + entriesByKey: Map> + orderedEntries: Array> + text: string + } + function isPipedAggregateFunction( aggregate: AggregateFunction, ): aggregate is PipedAggregateFunction { @@ -40,43 +71,37 @@ function isPipedAggregateFunction( export function groupBy< T, K extends GroupKey, - A extends Record>, + A extends Record>, >(keyExtractor: (data: T) => K, aggregates: A = {} as A) { type ResultType = K & AggregatesReturnType const basicAggregates = Object.fromEntries( Object.entries(aggregates).filter( - ([_, aggregate]) => !isPipedAggregateFunction(aggregate), + ([, aggregate]) => !isPipedAggregateFunction(aggregate), ), - ) as Record> + ) as Record> // @ts-expect-error - TODO: we don't use this yet, but we will // eslint-disable-next-line @typescript-eslint/no-unused-vars const pipedAggregates = Object.fromEntries( - Object.entries(aggregates).filter(([_, aggregate]) => + Object.entries(aggregates).filter(([, aggregate]) => isPipedAggregateFunction(aggregate), ), - ) as Record> + ) as Record> return ( stream: IStreamBuilder, ): IStreamBuilder> => { - // Special key to store the original key object - const KEY_SENTINEL = `__original_key__` + const keySentinel = `__original_key__` - // First map to extract keys and pre-aggregate values const withKeysAndValues = stream.pipe( map((data) => { const key = keyExtractor(data) const keyString = serializeValue(key) + const values: Record = { + [keySentinel]: key, + } - // Create values object with pre-aggregated values - const values: Record = {} - - // Store the original key object - values[KEY_SENTINEL] = key - - // Add pre-aggregated values for (const [name, aggregate] of Object.entries(basicAggregates)) { values[name] = aggregate.preMap(data) } @@ -85,60 +110,48 @@ export function groupBy< }), ) - // Then reduce to compute aggregates const reduced = withKeysAndValues.pipe( - reduce((values) => { - // Calculate total multiplicity to check if the group should exist + reduce((values, keyString) => { let totalMultiplicity = 0 - for (const [_, multiplicity] of values) { + for (const [, multiplicity] of values) { totalMultiplicity += multiplicity } - // If total multiplicity is 0 or negative, the group should be removed completely if (totalMultiplicity <= 0) { + for (const aggregate of Object.values(basicAggregates)) { + aggregate.cleanup?.(keyString) + } return [] } const result: Record = {} + result[keySentinel] = values[0]?.[0]?.[keySentinel] - // Get the original key from first value in group - const originalKey = values[0]?.[0]?.[KEY_SENTINEL] - result[KEY_SENTINEL] = originalKey - - // Apply each aggregate function for (const [name, aggregate] of Object.entries(basicAggregates)) { const preValues = values.map( - ([v, m]) => [v[name], m] as [any, number], + ([value, multiplicity]) => + [value[name], multiplicity] as [unknown, number], ) - result[name] = aggregate.reduce(preValues) + result[name] = aggregate.reduce(preValues, keyString) } return [[result, 1]] }), ) - // Finally map to extract the key and include all values return reduced.pipe( map(([keyString, values]) => { - // Extract the original key - const key = values[KEY_SENTINEL] as K - - // Create intermediate result with key values and aggregate results + const key = values[keySentinel] as K const result: Record = {} - // Add key properties to result Object.assign(result, key) - // Apply postMap if provided for (const [name, aggregate] of Object.entries(basicAggregates)) { - if (aggregate.postMap) { - result[name] = aggregate.postMap(values[name]) - } else { - result[name] = values[name] - } + result[name] = aggregate.postMap + ? aggregate.postMap(values[name]) + : values[name] } - // Return with the string key instead of the object return [keyString, result] as KeyValue }), ) @@ -167,7 +180,7 @@ export function sum( * Creates a count aggregate function */ export function count( - valueExtractor: (value: T) => any = (v) => v, + valueExtractor: (value: T) => unknown = (v) => v, ): AggregateFunction { return { // Count only not-null values (the `== null` comparison gives true for both null and undefined) @@ -211,6 +224,254 @@ export function avg( } } +function compareStringAggOrderValues( + a: StringAggOrderable, + b: StringAggOrderable, +): number { + if (a == null && b == null) return 0 + if (a == null) return -1 + if (b == null) return 1 + + const normalizedA = a instanceof Date ? a.getTime() : a + const normalizedB = b instanceof Date ? b.getTime() : b + + if (normalizedA < normalizedB) return -1 + if (normalizedA > normalizedB) return 1 + return 0 +} + +function compareStringAggEntries( + left: StringAggEntry, + right: StringAggEntry, +): number { + const orderComparison = compareStringAggOrderValues( + left.orderBy, + right.orderBy, + ) + if (orderComparison !== 0) { + return orderComparison + } + return compareKeys(left.rowKey, right.rowKey) +} + +function buildStringAggText( + state: StringAggState, + separator: string, +): void { + state.text = state.orderedEntries.map((entry) => entry.value).join(separator) +} + +function removeStringAggEntry( + state: StringAggState, + entry: StringAggEntry, + separator: string, +): boolean { + const index = binarySearch( + state.orderedEntries, + entry, + compareStringAggEntries, + ) + if (state.orderedEntries[index]?.rowKey !== entry.rowKey) { + throw new Error( + `stringAgg internal state desynchronized: entry missing from orderedEntries`, + ) + } + + const entryCount = state.orderedEntries.length + state.orderedEntries.splice(index, 1) + + if (entryCount === 1) { + state.text = `` + return false + } + + if (index === entryCount - 1) { + const suffixLength = separator.length + entry.value.length + state.text = state.text.slice(0, state.text.length - suffixLength) + return false + } + + if (index === 0) { + state.text = state.text.slice(entry.value.length + separator.length) + return false + } + + return true +} + +function insertStringAggEntry( + state: StringAggState, + entry: StringAggEntry, + separator: string, +): boolean { + const index = binarySearch( + state.orderedEntries, + entry, + compareStringAggEntries, + ) + const entryCount = state.orderedEntries.length + state.orderedEntries.splice(index, 0, entry) + + if (entryCount === 0) { + state.text = entry.value + return false + } + + if (index === entryCount) { + state.text = `${state.text}${separator}${entry.value}` + return false + } + + if (index === 0) { + state.text = `${entry.value}${separator}${state.text}` + return false + } + + return true +} + +function fallbackStringAggReduce( + values: Array<[StringAggValue, number]>, + separator: string, +): string { + const orderedEntries: Array> = [] + + for (const [entry, multiplicity] of values) { + if (multiplicity <= 0 || entry.value == null) { + continue + } + + for (let i = 0; i < multiplicity; i++) { + orderedEntries.push({ + // Fallback path has no stable row identity, so reuse the string value as + // a deterministic tie-breaker when orderBy values collide. + rowKey: entry.value, + value: entry.value, + orderBy: entry.orderBy, + }) + } + } + + orderedEntries.sort(compareStringAggEntries) + + return orderedEntries.map((entry) => entry.value).join(separator) +} + +/** + * Creates a string aggregation function that concatenates string values ordered + * by orderByExtractor and then rowKeyExtractor. + * When rowKeyExtractor is omitted, ties fall back to the string value itself. + * @param valueExtractor Function to extract the string value from each data entry + * @param separator Separator inserted between aggregated values + * @param orderByExtractor Function to extract the ordering value for deterministic concatenation + * @param rowKeyExtractor Optional stable row identity used to break orderBy ties deterministically + */ +export function stringAgg( + valueExtractor: (value: T) => string | null | undefined = (v) => + v as unknown as string, + separator: string = ``, + orderByExtractor: (value: T) => TOrderBy = () => + undefined as unknown as TOrderBy, + rowKeyExtractor?: (value: T) => string | number, +): AggregateFunction, string> { + const groupStates = new Map>() + + const preMap = (data: T): StringAggValue => ({ + rowKey: rowKeyExtractor?.(data), + value: valueExtractor(data), + orderBy: orderByExtractor(data), + }) + + if (!rowKeyExtractor) { + return { + preMap, + reduce: (values) => fallbackStringAggReduce(values, separator), + } + } + + return { + preMap, + reduce: (values, groupKey) => { + let state = groupStates.get(groupKey) + if (!state) { + state = { + entriesByKey: new Map(), + orderedEntries: [], + text: ``, + } + groupStates.set(groupKey, state) + } + + const nextEntriesByKey = new Map< + string | number, + StringAggEntry + >() + + for (const [entry, multiplicity] of values) { + if (entry.rowKey == null || multiplicity <= 0 || entry.value == null) { + continue + } + + nextEntriesByKey.set(entry.rowKey, { + rowKey: entry.rowKey, + value: entry.value, + orderBy: entry.orderBy, + }) + } + + const touchedRowKeys = new Set([ + ...state.entriesByKey.keys(), + ...nextEntriesByKey.keys(), + ]) + + let textDirty = false + + for (const rowKey of touchedRowKeys) { + const previousEntry = state.entriesByKey.get(rowKey) + const nextEntry = nextEntriesByKey.get(rowKey) + + if ( + previousEntry && + nextEntry && + previousEntry.value === nextEntry.value && + compareStringAggEntries(previousEntry, nextEntry) === 0 + ) { + continue + } + + if (previousEntry) { + const removedNeedsRebuild = removeStringAggEntry( + state, + previousEntry, + separator, + ) + textDirty = textDirty || removedNeedsRebuild + state.entriesByKey.delete(rowKey) + } + + if (nextEntry) { + const insertedNeedsRebuild = insertStringAggEntry( + state, + nextEntry, + separator, + ) + textDirty = textDirty || insertedNeedsRebuild + state.entriesByKey.set(rowKey, nextEntry) + } + } + + if (textDirty) { + buildStringAggText(state, separator) + } + + return state.text + }, + cleanup: (groupKey) => { + groupStates.delete(groupKey) + }, + } +} + type CanMinMax = number | Date | bigint | string /** @@ -233,7 +494,7 @@ export function min( preMap: (data: T) => extractor(data), reduce: (values) => { let minValue: V | undefined - for (const [value, _multiplicity] of values) { + for (const [value] of values) { if (!minValue || (value && value < minValue)) { minValue = value } @@ -263,7 +524,7 @@ export function max( preMap: (data: T) => extractor(data), reduce: (values) => { let maxValue: V | undefined - for (const [value, _multiplicity] of values) { + for (const [value] of values) { if (!maxValue || (value && value > maxValue)) { maxValue = value } @@ -284,38 +545,30 @@ export function median( return { preMap: (data: T) => [valueExtractor(data)], reduce: (values: Array<[Array, number]>) => { - // Flatten all values, taking multiplicity into account const allValues: Array = [] for (const [valueArray, multiplicity] of values) { for (const value of valueArray) { - // Add each value multiple times based on multiplicity for (let i = 0; i < multiplicity; i++) { allValues.push(value) } } } - // Return empty array if no values if (allValues.length === 0) { return [] } - // Sort values allValues.sort((a, b) => a - b) - return allValues }, postMap: (result: Array) => { if (result.length === 0) return 0 const mid = Math.floor(result.length / 2) - - // If even number of values, average the two middle values if (result.length % 2 === 0) { return (result[mid - 1]! + result[mid]!) / 2 } - // If odd number of values, return the middle value return result[mid]! }, } @@ -337,7 +590,6 @@ export function mode( return frequencyMap }, reduce: (values: Array<[Map, number]>) => { - // Combine all frequency maps const combinedMap = new Map() for (const [frequencyMap, multiplicity] of values) { @@ -371,6 +623,7 @@ export const groupByOperators = { sum, count, avg, + stringAgg, min, max, median, diff --git a/packages/db-ivm/src/operators/reduce.ts b/packages/db-ivm/src/operators/reduce.ts index 3a8690e01..af255d226 100644 --- a/packages/db-ivm/src/operators/reduce.ts +++ b/packages/db-ivm/src/operators/reduce.ts @@ -11,13 +11,13 @@ import type { IStreamBuilder, KeyValue } from '../types.js' export class ReduceOperator extends UnaryOperator<[K, V1], [K, V2]> { #index = new Index() #indexOut = new Index() - #f: (values: Array<[V1, number]>) => Array<[V2, number]> + #f: (values: Array<[V1, number]>, key: K) => Array<[V2, number]> constructor( id: number, inputA: DifferenceStreamReader<[K, V1]>, output: DifferenceStreamWriter<[K, V2]>, - f: (values: Array<[V1, number]>) => Array<[V2, number]>, + f: (values: Array<[V1, number]>, key: K) => Array<[V2, number]>, ) { super(id, inputA, output) this.#f = f @@ -39,7 +39,7 @@ export class ReduceOperator extends UnaryOperator<[K, V1], [K, V2]> { for (const key of keysTodo) { const curr = this.#index.get(key) const currOut = this.#indexOut.get(key) - const out = this.#f(curr) + const out = this.#f(curr, key) // Create maps for current and previous outputs using values directly as keys const newOutputMap = new Map() @@ -105,7 +105,7 @@ export function reduce< V1Type extends T extends KeyValue ? V : never, R, T, ->(f: (values: Array<[V1Type, number]>) => Array<[R, number]>) { +>(f: (values: Array<[V1Type, number]>, key: KType) => Array<[R, number]>) { return (stream: IStreamBuilder): IStreamBuilder> => { const output = new StreamBuilder>( stream.graph, diff --git a/packages/db-ivm/tests/operators/groupBy.test.ts b/packages/db-ivm/tests/operators/groupBy.test.ts index fbe50fb33..32a9e59ab 100644 --- a/packages/db-ivm/tests/operators/groupBy.test.ts +++ b/packages/db-ivm/tests/operators/groupBy.test.ts @@ -9,6 +9,7 @@ import { median, min, mode, + stringAgg, sum, } from '../../src/operators/groupBy.js' import { output } from '../../src/operators/index.js' @@ -519,6 +520,413 @@ describe(`Operators`, () => { expect(latestMessage.getInner()).toEqual(expectedDeleteResult) }) + test(`with stringAgg ordered by a sequence field`, () => { + const graph = new D2() + const input = graph.newInput<{ + id: string + responseId: string + seq: number + text: string | null + }>() + let latestMessage: any = null + + input.pipe( + groupBy((data) => ({ responseId: data.responseId }), { + message: stringAgg( + (data) => data.text, + ``, + (data) => data.seq, + (data) => data.id, + ), + }), + output((message) => { + latestMessage = message + }), + ) + + graph.finalize() + + input.sendData( + new MultiSet([ + [{ id: `a-2`, responseId: `a`, seq: 2, text: `world` }, 1], + [{ id: `a-1`, responseId: `a`, seq: 1, text: `Hello ` }, 1], + [{ id: `a-3`, responseId: `a`, seq: 3, text: null }, 1], + [{ id: `b-1`, responseId: `b`, seq: 1, text: `Bye` }, 1], + [{ id: `b-2`, responseId: `b`, seq: 2, text: ` now` }, 1], + ]), + ) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `Hello world`, + }, + ], + 1, + ], + [ + [ + `{"responseId":"b"}`, + { + responseId: `b`, + message: `Bye now`, + }, + ], + 1, + ], + ]) + + input.sendData( + new MultiSet([ + [{ id: `a-2`, responseId: `a`, seq: 2, text: `world` }, -1], + [{ id: `a-2`, responseId: `a`, seq: 2, text: `there` }, 1], + ]), + ) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `Hello world`, + }, + ], + -1, + ], + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `Hello there`, + }, + ], + 1, + ], + ]) + + input.sendData( + new MultiSet([ + [{ id: `a-4`, responseId: `a`, seq: 4, text: `!` }, 1], + [{ id: `b-0`, responseId: `b`, seq: 0, text: `Start: ` }, 1], + ]), + ) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `Hello there`, + }, + ], + -1, + ], + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `Hello there!`, + }, + ], + 1, + ], + [ + [ + `{"responseId":"b"}`, + { + responseId: `b`, + message: `Bye now`, + }, + ], + -1, + ], + [ + [ + `{"responseId":"b"}`, + { + responseId: `b`, + message: `Start: Bye now`, + }, + ], + 1, + ], + ]) + + input.sendData( + new MultiSet([ + [{ id: `a-2`, responseId: `a`, seq: 2, text: `there` }, -1], + ]), + ) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `Hello there!`, + }, + ], + -1, + ], + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `Hello !`, + }, + ], + 1, + ], + ]) + }) + + test(`stringAgg handles large out-of-order inserts in a single batch`, () => { + const graph = new D2() + const input = graph.newInput<{ + id: string + responseId: string + seq: number + text: string + }>() + let latestMessage: any = null + + input.pipe( + groupBy((data) => ({ responseId: data.responseId }), { + message: stringAgg( + (data) => data.text, + ``, + (data) => data.seq, + (data) => data.id, + ), + }), + output((message) => { + latestMessage = message + }), + ) + + graph.finalize() + + const rows = [ + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9, + ].map((seq) => ({ + id: `delta-${seq}`, + responseId: `a`, + seq, + text: String.fromCharCode(96 + seq), + })) + + input.sendData(new MultiSet(rows.map((row) => [row, 1]))) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `abcdefghijklmnopqrst`, + }, + ], + 1, + ], + ]) + }) + + test(`stringAgg fallback path is deterministic without a row key extractor`, () => { + const graph = new D2() + const input = graph.newInput<{ + responseId: string + seq: number + text: string + }>() + let latestMessage: any = null + + input.pipe( + groupBy((data) => ({ responseId: data.responseId }), { + message: stringAgg( + (data) => data.text, + ` | `, + (data) => data.seq, + ), + }), + output((message) => { + latestMessage = message + }), + ) + + graph.finalize() + + input.sendData( + new MultiSet([ + [{ responseId: `a`, seq: 1, text: `beta` }, 1], + [{ responseId: `a`, seq: 1, text: `alpha` }, 1], + [{ responseId: `a`, seq: 2, text: `gamma` }, 1], + ]), + ) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `alpha | beta | gamma`, + }, + ], + 1, + ], + ]) + }) + + test(`stringAgg cleans up per-group state after group deletion and re-creation`, () => { + const graph = new D2() + const input = graph.newInput<{ + id: string + responseId: string + seq: number + text: string + }>() + let latestMessage: any = null + + input.pipe( + groupBy((data) => ({ responseId: data.responseId }), { + message: stringAgg( + (data) => data.text, + ``, + (data) => data.seq, + (data) => data.id, + ), + }), + output((message) => { + latestMessage = message + }), + ) + + graph.finalize() + + input.sendData( + new MultiSet([ + [{ id: `a-1`, responseId: `a`, seq: 1, text: `A` }, 1], + [{ id: `a-2`, responseId: `a`, seq: 2, text: `B` }, 1], + ]), + ) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `AB`, + }, + ], + 1, + ], + ]) + + input.sendData( + new MultiSet([ + [{ id: `a-1`, responseId: `a`, seq: 1, text: `A` }, -1], + [{ id: `a-2`, responseId: `a`, seq: 2, text: `B` }, -1], + ]), + ) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `AB`, + }, + ], + -1, + ], + ]) + + input.sendData( + new MultiSet([[{ id: `a-3`, responseId: `a`, seq: 1, text: `C` }, 1]]), + ) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `C`, + }, + ], + 1, + ], + ]) + }) + + test(`stringAgg preserves values that already contain the separator`, () => { + const graph = new D2() + const input = graph.newInput<{ + id: string + responseId: string + seq: number + text: string + }>() + let latestMessage: any = null + + input.pipe( + groupBy((data) => ({ responseId: data.responseId }), { + message: stringAgg( + (data) => data.text, + ` -> `, + (data) => data.seq, + (data) => data.id, + ), + }), + output((message) => { + latestMessage = message + }), + ) + + graph.finalize() + + input.sendData( + new MultiSet([ + [{ id: `a-1`, responseId: `a`, seq: 1, text: `A -> B` }, 1], + [{ id: `a-2`, responseId: `a`, seq: 2, text: `C` }, 1], + ]), + ) + graph.run() + + expect(latestMessage.getInner()).toEqual([ + [ + [ + `{"responseId":"a"}`, + { + responseId: `a`, + message: `A -> B -> C`, + }, + ], + 1, + ], + ]) + }) + test(`with min and max aggregates`, () => { const graph = new D2() const input = graph.newInput<{ diff --git a/packages/db/src/query/builder/functions.ts b/packages/db/src/query/builder/functions.ts index 82b806192..3cbcc4584 100644 --- a/packages/db/src/query/builder/functions.ts +++ b/packages/db/src/query/builder/functions.ts @@ -25,6 +25,13 @@ type StringLike = | null | undefined +type OrderByScalar = number | bigint | boolean | Date | null | undefined +type OrderByLike = + | RefLeaf + | RefProxy + | BasicExpression + | OrderByScalar + type ComparisonOperand = | RefProxy | RefLeaf @@ -59,6 +66,13 @@ type AggregateReturnType = : Aggregate : Aggregate +type StringAggregateReturnType = + ExtractType extends infer U + ? U extends string | undefined | null + ? Aggregate + : Aggregate + : Aggregate + // Helper type to determine string function return type based on input nullability type StringFunctionReturnType = ExtractType extends infer U @@ -325,6 +339,41 @@ export function max(arg: T): AggregateReturnType { return new Aggregate(`max`, [toExpression(arg)]) as AggregateReturnType } +export function stringAgg( + arg: T, +): StringAggregateReturnType +export function stringAgg( + arg: T, + separator: string, +): StringAggregateReturnType +export function stringAgg( + arg: T, + orderBy: TOrderBy, +): StringAggregateReturnType +export function stringAgg( + arg: T, + separator: string, + orderBy: TOrderBy, +): StringAggregateReturnType +export function stringAgg( + arg: T, + separatorOrOrderBy?: string | OrderByLike, + orderBy?: OrderByLike, +): StringAggregateReturnType { + const args: Array = [toExpression(arg) as BasicExpression] + + if (typeof separatorOrOrderBy === `string`) { + args.push(toExpression(separatorOrOrderBy) as BasicExpression) + if (orderBy !== undefined) { + args.push(toExpression(orderBy) as BasicExpression) + } + } else if (separatorOrOrderBy !== undefined) { + args.push(toExpression(separatorOrOrderBy) as BasicExpression) + } + + return new Aggregate(`stringAgg`, args) as StringAggregateReturnType +} + /** * List of comparison function names that can be used with indexes */ @@ -374,6 +423,7 @@ export const operators = [ `sum`, `min`, `max`, + `stringAgg`, ] as const export type OperatorName = (typeof operators)[number] diff --git a/packages/db/src/query/compiler/group-by.ts b/packages/db/src/query/compiler/group-by.ts index 48661ba96..ff72636cf 100644 --- a/packages/db/src/query/compiler/group-by.ts +++ b/packages/db/src/query/compiler/group-by.ts @@ -59,7 +59,7 @@ function getRowVirtualMetadata(row: NamespacedRow): RowVirtualMetadata { } } -const { sum, count, avg, min, max } = groupByOperators +const { sum, count, avg, stringAgg, min, max } = groupByOperators /** * Interface for caching the mapping between GROUP BY expressions and SELECT expressions @@ -535,6 +535,14 @@ function getAggregateFunction(aggExpr: Aggregate) { return compiledExpr(namespacedRow) } + const stringValueExtractor = ([, namespacedRow]: [string, NamespacedRow]) => { + const value = compiledExpr(namespacedRow) + if (value == null) { + return value + } + return typeof value === `string` ? value : String(value) + } + // Return the appropriate aggregate function switch (aggExpr.name.toLowerCase()) { case `sum`: @@ -543,6 +551,53 @@ function getAggregateFunction(aggExpr: Aggregate) { return count(rawValueExtractor) case `avg`: return avg(valueExtractor) + case `stringagg`: { + // `stringAgg(value, orderBy)` and `stringAgg(value, separator, orderBy)` + // share the second argument slot. Treat literal strings as separators; + // otherwise the second argument is the orderBy expression. + const separatorOrOrderByExpr = aggExpr.args[1] + const explicitOrderByExpr = aggExpr.args[2] + + const separator = + separatorOrOrderByExpr?.type === `val` && + typeof separatorOrOrderByExpr.value === `string` + ? separatorOrOrderByExpr.value + : `` + + const orderByExpr = + explicitOrderByExpr ?? + (separatorOrOrderByExpr?.type === `val` + ? undefined + : separatorOrOrderByExpr) + + const compiledOrderByExpr = orderByExpr + ? compileExpression(orderByExpr) + : undefined + + const orderByExtractor = orderByExpr + ? ([, namespacedRow]: [string, NamespacedRow]) => { + const value = compiledOrderByExpr!(namespacedRow) + if ( + value == null || + typeof value === `string` || + typeof value === `number` || + typeof value === `bigint` || + typeof value === `boolean` || + value instanceof Date + ) { + return value + } + return String(value) + } + : ([key]: [string, NamespacedRow]) => key + + return stringAgg( + stringValueExtractor, + separator, + orderByExtractor, + ([key]: [string, NamespacedRow]) => key, + ) + } case `min`: return min(valueExtractorForMinMax) case `max`: diff --git a/packages/db/src/query/index.ts b/packages/db/src/query/index.ts index 889202e5a..97e36ed2a 100644 --- a/packages/db/src/query/index.ts +++ b/packages/db/src/query/index.ts @@ -60,6 +60,7 @@ export { sum, min, max, + stringAgg, // Includes helpers toArray, } from './builder/functions.js' diff --git a/packages/db/tests/query/builder/functions.test.ts b/packages/db/tests/query/builder/functions.test.ts index fb6cb4f35..72bdd043b 100644 --- a/packages/db/tests/query/builder/functions.test.ts +++ b/packages/db/tests/query/builder/functions.test.ts @@ -21,6 +21,7 @@ import { min, not, or, + stringAgg, sum, upper, } from '../../../src/query/builder/functions.js' @@ -274,6 +275,39 @@ describe(`QueryBuilder Functions`, () => { expect((select.min_salary as any).name).toBe(`min`) expect((select.max_salary as any).name).toBe(`max`) }) + + it(`stringAgg function works with separator and orderBy`, () => { + const query = new Query() + .from({ employees: employeesCollection }) + .groupBy(({ employees }) => employees.department_id) + .select(({ employees }) => ({ + department_id: employees.department_id, + employee_names: stringAgg(employees.name, `, `, employees.id), + })) + + const builtQuery = getQueryIR(query) + const select = builtQuery.select! + expect((select.employee_names as any).name).toBe(`stringAgg`) + expect((select.employee_names as any).args).toHaveLength(3) + }) + + it(`stringAgg function supports orderBy-only and separator-only overloads`, () => { + const query = new Query() + .from({ employees: employeesCollection }) + .groupBy(({ employees }) => employees.department_id) + .select(({ employees }) => ({ + department_id: employees.department_id, + ordered_names: stringAgg(employees.name, employees.id), + spaced_names: stringAgg(employees.name, `, `), + })) + + const builtQuery = getQueryIR(query) + const select = builtQuery.select! + expect((select.ordered_names as any).name).toBe(`stringAgg`) + expect((select.ordered_names as any).args).toHaveLength(2) + expect((select.spaced_names as any).name).toBe(`stringAgg`) + expect((select.spaced_names as any).args).toHaveLength(2) + }) }) describe(`Math functions`, () => { diff --git a/packages/db/tests/query/group-by.test-d.ts b/packages/db/tests/query/group-by.test-d.ts index 583b404d0..79c2c68fc 100644 --- a/packages/db/tests/query/group-by.test-d.ts +++ b/packages/db/tests/query/group-by.test-d.ts @@ -13,6 +13,7 @@ import { max, min, or, + stringAgg, sum, } from '../../src/query/builder/functions.js' import type { OutputWithVirtual } from '../utils.js' @@ -132,6 +133,28 @@ describe(`Query GROUP BY Types`, () => { >() }) + test(`group by customer_id with stringAgg return type`, () => { + const customerStatuses = createLiveQueryCollection({ + query: (q) => + q + .from({ orders: ordersCollection }) + .groupBy(({ orders }) => orders.customer_id) + .select(({ orders }) => ({ + customer_id: orders.customer_id, + statuses: stringAgg(orders.status, `, `, orders.date_instance), + })), + }) + + const customer1 = customerStatuses.get(1) + expectTypeOf(customer1).toMatchTypeOf< + | OutputWithVirtual<{ + customer_id: number + statuses: string + }> + | undefined + >() + }) + test(`group by product_category return type`, () => { const categorySummary = createLiveQueryCollection({ query: (q) => diff --git a/packages/db/tests/query/group-by.test.ts b/packages/db/tests/query/group-by.test.ts index 0e247b131..459702a4a 100644 --- a/packages/db/tests/query/group-by.test.ts +++ b/packages/db/tests/query/group-by.test.ts @@ -17,6 +17,7 @@ import { min, not, or, + stringAgg, sum, } from '../../src/query/builder/functions.js' @@ -59,6 +60,14 @@ type Order = { } } +type TextDelta = { + id: number + response_id: number + created_at: Date + seq: number + text: string | null +} + // Sample order data const sampleOrders: Array = [ { @@ -210,6 +219,44 @@ const sampleOrders: Array = [ }, ] +const sampleTextDeltas: Array = [ + { + id: 1, + response_id: 1, + created_at: new Date(`2023-03-01T00:00:02.000Z`), + seq: 2, + text: `world`, + }, + { + id: 2, + response_id: 1, + created_at: new Date(`2023-03-01T00:00:01.000Z`), + seq: 1, + text: `Hello`, + }, + { + id: 3, + response_id: 1, + created_at: new Date(`2023-03-01T00:00:03.000Z`), + seq: 3, + text: null, + }, + { + id: 4, + response_id: 2, + created_at: new Date(`2023-03-02T00:00:01.000Z`), + seq: 1, + text: `Bye`, + }, + { + id: 5, + response_id: 2, + created_at: new Date(`2023-03-02T00:00:02.000Z`), + seq: 2, + text: `now`, + }, +] + function createOrdersCollection(autoIndex: `off` | `eager` = `eager`) { return createCollection( mockSyncCollectionOptions({ @@ -221,6 +268,17 @@ function createOrdersCollection(autoIndex: `off` | `eager` = `eager`) { ) } +function createTextDeltaCollection(autoIndex: `off` | `eager` = `eager`) { + return createCollection( + mockSyncCollectionOptions({ + id: `test-text-deltas`, + getKey: (delta) => delta.id, + initialData: sampleTextDeltas, + autoIndex, + }), + ) +} + function createGroupByTests(autoIndex: `off` | `eager`): void { describe(`with autoIndex ${autoIndex}`, () => { describe(`Single Column Grouping`, () => { @@ -447,6 +505,33 @@ function createGroupByTests(autoIndex: `off` | `eager`): void { expect(customer3?.first_category).toBe(`books`) expect(customer3?.last_category).toBe(`electronics`) }) + + test(`stringAgg concatenates grouped text in explicit order`, () => { + const textDeltaCollection = createTextDeltaCollection(autoIndex) + + const responses = createLiveQueryCollection({ + startSync: true, + query: (q) => + q + .from({ delta: textDeltaCollection }) + .groupBy(({ delta }) => delta.response_id) + .select(({ delta }) => ({ + response_id: delta.response_id, + text_by_created_at: stringAgg(delta.text, delta.created_at), + text_by_seq: stringAgg(delta.text, ` `, delta.seq), + })), + }) + + const response1 = responses.get(1) + expect(response1?.response_id).toBe(1) + expect(response1?.text_by_created_at).toBe(`Helloworld`) + expect(response1?.text_by_seq).toBe(`Hello world`) + + const response2 = responses.get(2) + expect(response2?.response_id).toBe(2) + expect(response2?.text_by_created_at).toBe(`Byenow`) + expect(response2?.text_by_seq).toBe(`Bye now`) + }) }) describe(`Multiple Column Grouping`, () => {