Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.NoMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability

Expand Down Expand Up @@ -42,6 +43,7 @@ public class Aggregator<in Value : Any, out Return : Any?>(
public val inputHandler: AggregatorInputHandler<Value, Return>,
public val multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
public val name: String,
public val statisticsParameters: Map<String, ParameterValue?>,
) : AggregatorInputHandler<Value, Return> by inputHandler,
AggregatorMultipleColumnsHandler<Value, Return> by multipleColumnsHandler,
AggregatorAggregationHandler<Value, Return> by aggregationHandler {
Expand Down Expand Up @@ -75,13 +77,30 @@ public class Aggregator<in Value : Any, out Return : Any?>(
aggregationHandler: AggregatorAggregationHandler<Value, Return>,
inputHandler: AggregatorInputHandler<Value, Return>,
multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
statisticsParameters: Map<String, ParameterValue?>,
): AggregatorProvider<Value, Return> =
AggregatorProvider { name ->
Aggregator(
aggregationHandler = aggregationHandler,
inputHandler = inputHandler,
multipleColumnsHandler = multipleColumnsHandler,
name = name,
statisticsParameters = statisticsParameters,
)
}

internal operator fun <Value : Any, Return : Any?> invoke(
aggregationHandler: AggregatorAggregationHandler<Value, Return>,
inputHandler: AggregatorInputHandler<Value, Return>,
multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
): AggregatorProvider<Value, Return> =
AggregatorProvider { name ->
Aggregator(
aggregationHandler = aggregationHandler,
inputHandler = inputHandler,
multipleColumnsHandler = multipleColumnsHandler,
name = name,
emptyMap(),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.asSequence
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue
import org.jetbrains.kotlinx.dataframe.impl.columns.StatisticResult
import org.jetbrains.kotlinx.dataframe.impl.columns.ValueColumnInternal
import kotlin.reflect.KType

/**
Expand All @@ -26,13 +29,34 @@ public interface AggregatorAggregationHandler<in Value : Any, out Return : Any?>

/**
* Aggregates the data in the given column and computes a single resulting value.
* Calls [aggregateSequence].
* Calls [aggregateSequence]. It tries to exploit a cache for statistics which is proper of
* [ValueColumnInternal]
*/
public fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
aggregateSequence(
public fun aggregateSingleColumn(column: DataColumn<Value?>): Return {
if (column is ValueColumnInternal<*>) {
// cache check, cache is dynamically created
val aggregator = this.aggregator ?: throw IllegalStateException("Aggregator is required")
val desiredStatisticNotConsideringParameters = column.statistics.getOrPut(aggregator.name) {
mutableMapOf<Map<String, ParameterValue?>, StatisticResult>()
}
// can't compare maps whose Values are Any? -> ParameterValue instead
val desiredStatistic = desiredStatisticNotConsideringParameters[aggregator.statisticsParameters]
// if desiredStatistic is null, statistic was never calculated
if (desiredStatistic != null) {
return desiredStatistic.value as Return
}
val statistic = aggregateSequence(
values = column.asSequence(),
valueType = column.type().toValueType(),
)
desiredStatisticNotConsideringParameters[aggregator.statisticsParameters] = StatisticResult(statistic)
return aggregateSingleColumn(column)
}
return aggregateSequence(
values = column.asSequence(),
valueType = column.type().toValueType(),
)
}

/**
* Function that can give the return type of [aggregateSequence] as [KType], given the type of the input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.inputHandlers.NumberInputHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.FlatteningMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColumnsHandlers.TwoStepMultipleColumnsHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.ParameterValue
import org.jetbrains.kotlinx.dataframe.math.indexOfMax
import org.jetbrains.kotlinx.dataframe.math.indexOfMedian
import org.jetbrains.kotlinx.dataframe.math.indexOfMin
Expand Down Expand Up @@ -35,10 +36,12 @@ public object Aggregators {
getReturnType: CalculateReturnType,
indexOfResult: IndexOfResult<Value>,
stepOneSelector: Selector<Value, Return>,
statisticsParameters: Map<String, ParameterValue?>,
) = Aggregator(
aggregationHandler = SelectingAggregationHandler(stepOneSelector, indexOfResult, getReturnType),
inputHandler = AnyInputHandler(),
multipleColumnsHandler = TwoStepMultipleColumnsHandler(),
statisticsParameters = statisticsParameters,
)

private fun <Value : Any, Return : Any?> flattenHybridForAny(
Expand Down Expand Up @@ -117,8 +120,9 @@ public object Aggregators {
by withOneOption { skipNaN: Boolean ->
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
getReturnType = minTypeConversion,
stepOneSelector = { type -> minOrNull(type, skipNaN) },
indexOfResult = { type -> indexOfMin(type, skipNaN) },
stepOneSelector = { type -> minOrNull(type, skipNaN) },
statisticsParameters = mapOf<String, ParameterValue?>(Pair("skipNaN", ParameterValue(skipNaN))),
)
}

Expand All @@ -132,6 +136,7 @@ public object Aggregators {
getReturnType = maxTypeConversion,
stepOneSelector = { type -> maxOrNull(type, skipNaN) },
indexOfResult = { type -> indexOfMax(type, skipNaN) },
statisticsParameters = mapOf<String, ParameterValue?>(Pair("skipNaN", ParameterValue(skipNaN))),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,41 @@ import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability

@JvmInline
internal value class StatisticResult(val value: Any?)

public class ParameterValue(public val parameter: Any?) {

override fun equals(other: Any?): Boolean {
val otherAsParameterValue = other as ParameterValue?
val that = otherAsParameterValue?.parameter
if (parameter is Boolean && that is Boolean) {
return this.parameter == that
}
return super.equals(other)
}

override fun hashCode(): Int {
if (parameter is Boolean?) {
return this.parameter.hashCode()
}
return super.hashCode()
}
}

internal interface ValueColumnInternal<T> : ValueColumn<T> {
val statistics: MutableMap<String, MutableMap<Map<String, ParameterValue?>, StatisticResult>>
}

internal open class ValueColumnImpl<T>(
values: List<T>,
name: String,
type: KType,
val defaultValue: T? = null,
distinct: Lazy<Set<T>>? = null,
) : DataColumnImpl<T>(values, name, type, distinct),
ValueColumn<T> {
ValueColumn<T>,
ValueColumnInternal<T> {

override fun distinct() = ValueColumnImpl(toSet().toList(), name, type, defaultValue, distinct)

Expand Down Expand Up @@ -48,10 +75,13 @@ internal open class ValueColumnImpl<T>(
override fun defaultValue() = defaultValue

override fun forceResolve() = ResolvingValueColumn(this)

override val statistics = mutableMapOf<String, MutableMap<Map<String, ParameterValue?>, StatisticResult>>()
}

internal class ResolvingValueColumn<T>(override val source: ValueColumn<T>) :
ValueColumn<T> by source,
ValueColumnInternal<T>,
ForceResolvedColumn<T> {

override fun resolve(context: ColumnResolutionContext) = super<ValueColumn>.resolve(context)
Expand All @@ -70,4 +100,6 @@ internal class ResolvingValueColumn<T>(override val source: ValueColumn<T>) :
override fun equals(other: Any?) = source.checkEquals(other)

override fun hashCode(): Int = source.hashCode()

override val statistics = mutableMapOf<String, MutableMap<Map<String, ParameterValue?>, StatisticResult>>()
}
Loading