diff --git a/src/Primitives/Primitives/src/Extensions/EnumerableExtensions.cs b/src/Primitives/Primitives/src/Extensions/EnumerableExtensions.cs
index 80aa7f9..fed404b 100644
--- a/src/Primitives/Primitives/src/Extensions/EnumerableExtensions.cs
+++ b/src/Primitives/Primitives/src/Extensions/EnumerableExtensions.cs
@@ -3,7 +3,6 @@
using System;
using System.Collections;
using System.Collections.Generic;
-using System.Collections.ObjectModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
@@ -15,7 +14,7 @@ public static partial class EnumerableExtensions
///
/// The first element if found
///
- /// True if element contains an item, otherwise false
+ /// True if contains an item, otherwise false
///
public static bool TryGetFirstValue(this IEnumerable source, [MaybeNullWhen(false)] out T value)
{
@@ -30,11 +29,11 @@ public static bool TryGetFirstValue(this IEnumerable source, [MaybeNullWhe
return false;
}
#else
- if (source is ICollection {Count: 0} or ICollection {Count: 0})
- {
- value = default;
- return false;
- }
+ if (source is ICollection {Count: 0} or ICollection {Count: 0})
+ {
+ value = default;
+ return false;
+ }
#endif
if (source is IList list)
@@ -59,44 +58,67 @@ public static bool TryGetFirstValue(this IEnumerable source, [MaybeNullWhe
///
/// Casts a sequence to an , otherwise enumerates the sequence.
///
- ///
+ ///
///
///
- public static ICollection AsCollection(this IEnumerable enumerable)
+ public static ICollection AsCollection(this IEnumerable source)
{
- if (enumerable is ICollection collection)
+ if (source is null)
+ throw new ArgumentNullException(nameof(source));
+
+ if (source is ICollection collection)
return collection;
- return enumerable.ToList();
+ return source.ToArray();
}
///
/// Casts a sequence to a , otherwise enumerates the sequence.
///
- ///
+ ///
///
///
- public static IReadOnlyList AsReadOnlyList(this IEnumerable enumerable)
+ public static IReadOnlyList AsReadOnlyList(this IEnumerable source)
{
- if (enumerable is IReadOnlyList collection)
+ if (source is null)
+ throw new ArgumentNullException(nameof(source));
+
+ if (source is IReadOnlyList collection)
return collection;
- return new ReadOnlyCollection(enumerable.ToList());
+ return source.ToArray();
+ }
+
+ ///
+ /// Casts a sequence to a , otherwise enumerates the sequence.
+ ///
+ ///
+ ///
+ ///
+ public static IReadOnlyCollection AsReadOnlyCollection(this IEnumerable source)
+ {
+ if (source is null)
+ throw new ArgumentNullException(nameof(source));
+
+ if (source is IReadOnlyCollection collection)
+ return collection;
+
+ return source.ToArray();
}
///
/// Enumerates a sequence and returns only non null values
///
- ///
+ ///
///
///
///
- public static IEnumerable WhereNotNull(this IEnumerable enumerable) where T : class
+ public static IEnumerable WhereNotNull(this IEnumerable source)
{
- if (enumerable is null)
- throw new ArgumentNullException(nameof(enumerable));
+ if (source is null)
+ throw new ArgumentNullException(nameof(source));
- foreach (var t in enumerable)
+ foreach (var t in source)
{
if (t is not null)
yield return t;
@@ -106,27 +128,26 @@ public static IEnumerable WhereNotNull(this IEnumerable enumerable) wh
///
/// Enumerates a sequence and returns only non null values
///
- ///
+ ///
///
///
///
///
///
public static IEnumerable WhereNotNull(
- this IEnumerable enumerable,
+ this IEnumerable source,
Func selector)
- where TResult : class
{
- if (enumerable is null)
- throw new ArgumentNullException(nameof(enumerable));
+ if (source is null)
+ throw new ArgumentNullException(nameof(source));
if (selector is null)
throw new ArgumentNullException(nameof(selector));
- foreach (var t in enumerable.Select(selector))
+ foreach (var t in source.Select(selector))
{
if (t is not null)
yield return t;
}
}
-}
\ No newline at end of file
+}
diff --git a/src/Primitives/Primitives/test/Extensions/EnumerableExtensionTests.cs b/src/Primitives/Primitives/test/Extensions/EnumerableExtensionTests.cs
index fdabecb..f297276 100644
--- a/src/Primitives/Primitives/test/Extensions/EnumerableExtensionTests.cs
+++ b/src/Primitives/Primitives/test/Extensions/EnumerableExtensionTests.cs
@@ -18,7 +18,7 @@ public void TryGetFirstValue_Array()
public void TryGetFirstValue_Enumerable()
{
// Where() to force an Enumerable
- IEnumerable list = Enumerable.Range(1, 100).Where(x => true);
+ var list = Enumerable.Range(1, 100).Where(x => true);
Assert.True(list.TryGetFirstValue(out var value));
Assert.Equal(1, value);
@@ -36,8 +36,33 @@ public void TryGetFirstValue_Empty()
public void TryGetFirstValue_EmptyEnumerable()
{
// Where() to force an Enumerable
- IEnumerable list = Enumerable.Range(1, 100).Where(x => false);
+ var list = Enumerable.Range(1, 100).Where(x => false);
Assert.False(list.TryGetFirstValue(out _));
}
+
+ [Fact]
+ public void WhereNotNull_ValueTypeListContainsNull_ReturnsNoNullValues()
+ {
+ int?[] list = [1, null, 2, null, 3];
+
+ var result = list.WhereNotNull();
+
+ Assert.Equal(result, [1, 2, 3]);
+ }
+
+ [Fact]
+ public void WhereNotNull_SelectorClassTypeListContainsNull_ReturnsNoNullValues()
+ {
+ Test[] list = [new("1"), new(null), new("2"), new(null)];
+
+ var result = list.WhereNotNull(x => x.Value);
+
+ Assert.Equal(result, ["1", "2"]);
+ }
+
+ private class Test(string? value)
+ {
+ public string? Value { get; } = value;
+ }
}