Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 48 additions & 27 deletions src/Primitives/Primitives/src/Extensions/EnumerableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;

Expand All @@ -15,7 +14,7 @@ public static partial class EnumerableExtensions
/// <param name="source"></param>
/// <param name="value">The first element if found</param>
/// <typeparam name="T"></typeparam>
/// <returns>True if element contains an item, otherwise false</returns>
/// <returns>True if <paramref name="source"/> contains an item, otherwise false</returns>
/// <exception cref="ArgumentNullException"></exception>
public static bool TryGetFirstValue<T>(this IEnumerable<T> source, [MaybeNullWhen(false)] out T value)
{
Expand All @@ -30,11 +29,11 @@ public static bool TryGetFirstValue<T>(this IEnumerable<T> source, [MaybeNullWhe
return false;
}
#else
if (source is ICollection<T> {Count: 0} or ICollection {Count: 0})
{
value = default;
return false;
}
if (source is ICollection<T> {Count: 0} or ICollection {Count: 0})
{
value = default;
return false;
}
#endif

if (source is IList<T> list)
Expand All @@ -59,44 +58,67 @@ public static bool TryGetFirstValue<T>(this IEnumerable<T> source, [MaybeNullWhe
/// <summary>
/// Casts a sequence to an <see cref="ICollection{T}"/>, otherwise enumerates the sequence.
/// </summary>
/// <param name="enumerable"></param>
/// <param name="source"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public static ICollection<T> AsCollection<T>(this IEnumerable<T> enumerable)
public static ICollection<T> AsCollection<T>(this IEnumerable<T> source)
{
if (enumerable is ICollection<T> collection)
if (source is null)
throw new ArgumentNullException(nameof(source));

if (source is ICollection<T> collection)
return collection;

return enumerable.ToList();
return source.ToArray();
}

/// <summary>
/// Casts a sequence to a <see cref="IReadOnlyList{T}"/>, otherwise enumerates the sequence.
/// </summary>
/// <param name="enumerable"></param>
/// <param name="source"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public static IReadOnlyList<T> AsReadOnlyList<T>(this IEnumerable<T> enumerable)
public static IReadOnlyList<T> AsReadOnlyList<T>(this IEnumerable<T> source)
{
if (enumerable is IReadOnlyList<T> collection)
if (source is null)
throw new ArgumentNullException(nameof(source));

if (source is IReadOnlyList<T> collection)
return collection;

return new ReadOnlyCollection<T>(enumerable.ToList());
return source.ToArray();
}

/// <summary>
/// Casts a sequence to a <see cref="IReadOnlyCollection{T}" />, otherwise enumerates the sequence.
/// </summary>
/// <param name="source"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public static IReadOnlyCollection<T> AsReadOnlyCollection<T>(this IEnumerable<T> source)
{
if (source is null)
throw new ArgumentNullException(nameof(source));

if (source is IReadOnlyCollection<T> collection)
return collection;

return source.ToArray();
}

/// <summary>
/// Enumerates a sequence and returns only non null values
/// </summary>
/// <param name="enumerable"></param>
/// <param name="source"></param>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
public static IEnumerable<T> WhereNotNull<T>(this IEnumerable<T?> enumerable) where T : class
public static IEnumerable<T> WhereNotNull<T>(this IEnumerable<T?> source)
{
if (enumerable is null)
throw new ArgumentNullException(nameof(enumerable));
if (source is null)
throw new ArgumentNullException(nameof(source));

foreach (var t in enumerable)
foreach (var t in source)
{
if (t is not null)
yield return t;
Expand All @@ -106,27 +128,26 @@ public static IEnumerable<T> WhereNotNull<T>(this IEnumerable<T?> enumerable) wh
/// <summary>
/// Enumerates a sequence and returns only non null values
/// </summary>
/// <param name="enumerable"></param>
/// <param name="source"></param>
/// <param name="selector"></param>
/// <typeparam name="T"></typeparam>
/// <typeparam name="TResult"></typeparam>
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
public static IEnumerable<TResult> WhereNotNull<T, TResult>(
this IEnumerable<T> enumerable,
this IEnumerable<T> source,
Func<T, TResult?> selector)
where TResult : class
{
if (enumerable is null)
throw new ArgumentNullException(nameof(enumerable));
if (source is null)
throw new ArgumentNullException(nameof(source));

if (selector is null)
throw new ArgumentNullException(nameof(selector));

foreach (var t in enumerable.Select(selector))
foreach (var t in source.Select(selector))
{
if (t is not null)
yield return t;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public void TryGetFirstValue_Array()
public void TryGetFirstValue_Enumerable()
{
// Where() to force an Enumerable
IEnumerable<int> list = Enumerable.Range(1, 100).Where(x => true);
var list = Enumerable.Range(1, 100).Where(x => true);

Assert.True(list.TryGetFirstValue(out var value));
Assert.Equal(1, value);
Expand All @@ -36,8 +36,33 @@ public void TryGetFirstValue_Empty()
public void TryGetFirstValue_EmptyEnumerable()
{
// Where() to force an Enumerable
IEnumerable<int> list = Enumerable.Range(1, 100).Where(x => false);
var list = Enumerable.Range(1, 100).Where(x => false);

Assert.False(list.TryGetFirstValue(out _));
}

[Fact]
public void WhereNotNull_ValueTypeListContainsNull_ReturnsNoNullValues()
{
int?[] list = [1, null, 2, null, 3];

var result = list.WhereNotNull();

Assert.Equal(result, [1, 2, 3]);
}

[Fact]
public void WhereNotNull_SelectorClassTypeListContainsNull_ReturnsNoNullValues()
{
Test[] list = [new("1"), new(null), new("2"), new(null)];

var result = list.WhereNotNull(x => x.Value);

Assert.Equal(result, ["1", "2"]);
}

private class Test(string? value)
{
public string? Value { get; } = value;
}
}
Loading