diff --git a/src/Microsoft.Data.Analysis/GroupBy.cs b/src/Microsoft.Data.Analysis/GroupBy.cs
index 357fa80a63..93aec05b49 100644
--- a/src/Microsoft.Data.Analysis/GroupBy.cs
+++ b/src/Microsoft.Data.Analysis/GroupBy.cs
@@ -9,6 +9,27 @@
namespace Microsoft.Data.Analysis
{
+ ///
+ /// A record to identify the row that is being aggregated that can be used to decide whether or not to include it in the aggregation.
+ ///
+ public record GroupByPredicateInput
+ {
+ ///
+ /// The name of the column that is being aggregated
+ ///
+ public string ColumnName { get; set; }
+
+ ///
+ /// The value from the GroupBy column that this group is grouped on
+ ///
+ public object GroupKey { get; set; }
+
+ ///
+ /// The value of this row within the column that is being aggregated
+ ///
+ public object RowValue { get; set; }
+ }
+
///
/// A GroupBy class that is typically the result of a DataFrame.GroupBy call.
/// It holds information to perform typical aggregation ops on it.
@@ -16,14 +37,31 @@ namespace Microsoft.Data.Analysis
public abstract class GroupBy
{
///
- /// Compute the number of non-null values in each group
+ /// Compute the number of non-null values in each group
///
+ /// The columns within which to compute the number of non-null values in each group. A default value includes all columns.
///
public abstract DataFrame Count(params string[] columnNames);
+ ///
+ /// Compute the number of values in each group that match a custom predicate
+ ///
+ /// A function that takes in the column name, group key, and row value and returns true to include that row in the group count or false to exclude it.
+ /// The columns within which to compute the number of values in each group that match the predicate. A default value includes all columns.
+ ///
+ public abstract DataFrame CountIf(Func predicate, params string[] columnNames);
+
+ ///
+ /// Compute the number of distinct non-null values in each group
+ ///
+ /// The columns within which to compute the number of distinct non-null values in each group. A default value includes all columns.
+ ///
+ public abstract DataFrame CountDistinct(params string[] columnNames);
+
///
/// Return the first value in each group
///
+ /// Names of the columns to aggregate
///
public abstract DataFrame First(params string[] columnNames);
@@ -140,6 +178,11 @@ private void EnumerateColumnsWithRows(GroupByColumnDelegate groupByColumnDelegat
}
public override DataFrame Count(params string[] columnNames)
+ {
+ return CountIf(input => input.RowValue != null, columnNames);
+ }
+
+ public override DataFrame CountIf(Func predicate, params string[] columnNames)
{
DataFrame ret = new DataFrame();
PrimitiveDataFrameColumn empty = new PrimitiveDataFrameColumn("Empty");
@@ -156,10 +199,19 @@ public override DataFrame Count(params string[] columnNames)
return;
DataFrameColumn column = _dataFrame.Columns[columnIndex];
long count = 0;
+ var groupByPredicateInput = new GroupByPredicateInput
+ {
+ ColumnName = column.Name,
+ GroupKey = firstColumn[rowIndex]
+ };
foreach (long row in rowEnumerable)
{
- if (column[row] != null)
+ groupByPredicateInput.RowValue = column[row];
+
+ if (predicate(groupByPredicateInput))
+ {
count++;
+ }
}
DataFrameColumn retColumn;
if (firstGroup)
@@ -182,6 +234,26 @@ public override DataFrame Count(params string[] columnNames)
return ret;
}
+ public override DataFrame CountDistinct(params string[] columnNames)
+ {
+ HashSet seenValues = [];
+
+ return CountIf(
+ input =>
+ {
+ if (input.RowValue == null || seenValues.Contains(input))
+ {
+ return false;
+ }
+
+ seenValues.Add(input);
+
+ return true;
+ },
+ columnNames
+ );
+ }
+
public override DataFrame First(params string[] columnNames)
{
DataFrame ret = new DataFrame();
diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
index 2d75caef72..6320d15aba 100644
--- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
+++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs
@@ -429,6 +429,50 @@ public void TestGroupBy()
Assert.Equal(2, firstDecimalColumn.Rows.Count);
Assert.Equal((decimal)0, firstDecimalColumn.Columns["Decimal"][0]);
Assert.Equal((decimal)1, firstDecimalColumn.Columns["Decimal"][1]);
+
+ var dfWithDuplicates = new DataFrame(
+ new Int32DataFrameColumn("Group", [1, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
+ new Int32DataFrameColumn("Int", [1, 2, 3, 4, null, 1, 1, 2, 3, 4]),
+ new DoubleDataFrameColumn("Double", [1, 2, 3, 4, null, 1, 1, 2, 3, 4]),
+ new StringDataFrameColumn("String", ["1", "2", "3", "4", null, "1", "1", "2", "3", "4"]),
+ new DateTimeDataFrameColumn("DateTime", [
+ new DateTime(2026, 1, 1, 0, 0, 0),
+ new DateTime(2026, 1, 1, 0, 0, 1),
+ new DateTime(2026, 1, 1, 0, 0, 2),
+ new DateTime(2026, 1, 1, 0, 0, 3),
+ null,
+ new DateTime(2026, 1, 1, 0, 0, 0),
+ new DateTime(2026, 1, 1, 0, 0, 0),
+ new DateTime(2026, 1, 1, 0, 0, 1),
+ new DateTime(2026, 1, 1, 0, 0, 2),
+ new DateTime(2026, 1, 1, 0, 0, 3)
+ ])
+ );
+
+ DataFrame countDistinct = dfWithDuplicates.GroupBy("Group").CountDistinct();
+ Assert.Equal(5, countDistinct.Columns.Count);
+ Assert.Equal(2, countDistinct.Rows.Count);
+
+ foreach (var columnName in countDistinct.Columns.Select(c => c.Name))
+ {
+ if (columnName == "Group")
+ {
+ continue;
+ }
+
+ var column = (PrimitiveDataFrameColumn)countDistinct[columnName];
+
+ for (int row = 0; row < countDistinct.Rows.Count; row++)
+ {
+ Assert.Equal(4, column[row]);
+ }
+ }
+
+ DataFrame countIf = dfWithDuplicates.GroupBy("Group").CountIf((GroupByPredicateInput input) => input.RowValue is int and < 3, "Int");
+ Assert.Equal(2, countIf.Columns.Count);
+ Assert.Equal(2, countIf.Rows.Count);
+ Assert.Equal(2L, countIf["Int"][0]);
+ Assert.Equal(3L, countIf["Int"][1]);
}
[Fact]