diff --git a/src/Primitives/Primitives/src/Extensions/EnumerableExtensions.cs b/src/Primitives/Primitives/src/Extensions/EnumerableExtensions.cs index 80aa7f9..fed404b 100644 --- a/src/Primitives/Primitives/src/Extensions/EnumerableExtensions.cs +++ b/src/Primitives/Primitives/src/Extensions/EnumerableExtensions.cs @@ -3,7 +3,6 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Collections.ObjectModel; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -15,7 +14,7 @@ public static partial class EnumerableExtensions /// /// The first element if found /// - /// True if element contains an item, otherwise false + /// True if contains an item, otherwise false /// public static bool TryGetFirstValue(this IEnumerable source, [MaybeNullWhen(false)] out T value) { @@ -30,11 +29,11 @@ public static bool TryGetFirstValue(this IEnumerable source, [MaybeNullWhe return false; } #else - if (source is ICollection {Count: 0} or ICollection {Count: 0}) - { - value = default; - return false; - } + if (source is ICollection {Count: 0} or ICollection {Count: 0}) + { + value = default; + return false; + } #endif if (source is IList list) @@ -59,44 +58,67 @@ public static bool TryGetFirstValue(this IEnumerable source, [MaybeNullWhe /// /// Casts a sequence to an , otherwise enumerates the sequence. /// - /// + /// /// /// - public static ICollection AsCollection(this IEnumerable enumerable) + public static ICollection AsCollection(this IEnumerable source) { - if (enumerable is ICollection collection) + if (source is null) + throw new ArgumentNullException(nameof(source)); + + if (source is ICollection collection) return collection; - return enumerable.ToList(); + return source.ToArray(); } /// /// Casts a sequence to a , otherwise enumerates the sequence. /// - /// + /// /// /// - public static IReadOnlyList AsReadOnlyList(this IEnumerable enumerable) + public static IReadOnlyList AsReadOnlyList(this IEnumerable source) { - if (enumerable is IReadOnlyList collection) + if (source is null) + throw new ArgumentNullException(nameof(source)); + + if (source is IReadOnlyList collection) return collection; - return new ReadOnlyCollection(enumerable.ToList()); + return source.ToArray(); + } + + /// + /// Casts a sequence to a , otherwise enumerates the sequence. + /// + /// + /// + /// + public static IReadOnlyCollection AsReadOnlyCollection(this IEnumerable source) + { + if (source is null) + throw new ArgumentNullException(nameof(source)); + + if (source is IReadOnlyCollection collection) + return collection; + + return source.ToArray(); } /// /// Enumerates a sequence and returns only non null values /// - /// + /// /// /// /// - public static IEnumerable WhereNotNull(this IEnumerable enumerable) where T : class + public static IEnumerable WhereNotNull(this IEnumerable source) { - if (enumerable is null) - throw new ArgumentNullException(nameof(enumerable)); + if (source is null) + throw new ArgumentNullException(nameof(source)); - foreach (var t in enumerable) + foreach (var t in source) { if (t is not null) yield return t; @@ -106,27 +128,26 @@ public static IEnumerable WhereNotNull(this IEnumerable enumerable) wh /// /// Enumerates a sequence and returns only non null values /// - /// + /// /// /// /// /// /// public static IEnumerable WhereNotNull( - this IEnumerable enumerable, + this IEnumerable source, Func selector) - where TResult : class { - if (enumerable is null) - throw new ArgumentNullException(nameof(enumerable)); + if (source is null) + throw new ArgumentNullException(nameof(source)); if (selector is null) throw new ArgumentNullException(nameof(selector)); - foreach (var t in enumerable.Select(selector)) + foreach (var t in source.Select(selector)) { if (t is not null) yield return t; } } -} \ No newline at end of file +} diff --git a/src/Primitives/Primitives/test/Extensions/EnumerableExtensionTests.cs b/src/Primitives/Primitives/test/Extensions/EnumerableExtensionTests.cs index fdabecb..f297276 100644 --- a/src/Primitives/Primitives/test/Extensions/EnumerableExtensionTests.cs +++ b/src/Primitives/Primitives/test/Extensions/EnumerableExtensionTests.cs @@ -18,7 +18,7 @@ public void TryGetFirstValue_Array() public void TryGetFirstValue_Enumerable() { // Where() to force an Enumerable - IEnumerable list = Enumerable.Range(1, 100).Where(x => true); + var list = Enumerable.Range(1, 100).Where(x => true); Assert.True(list.TryGetFirstValue(out var value)); Assert.Equal(1, value); @@ -36,8 +36,33 @@ public void TryGetFirstValue_Empty() public void TryGetFirstValue_EmptyEnumerable() { // Where() to force an Enumerable - IEnumerable list = Enumerable.Range(1, 100).Where(x => false); + var list = Enumerable.Range(1, 100).Where(x => false); Assert.False(list.TryGetFirstValue(out _)); } + + [Fact] + public void WhereNotNull_ValueTypeListContainsNull_ReturnsNoNullValues() + { + int?[] list = [1, null, 2, null, 3]; + + var result = list.WhereNotNull(); + + Assert.Equal(result, [1, 2, 3]); + } + + [Fact] + public void WhereNotNull_SelectorClassTypeListContainsNull_ReturnsNoNullValues() + { + Test[] list = [new("1"), new(null), new("2"), new(null)]; + + var result = list.WhereNotNull(x => x.Value); + + Assert.Equal(result, ["1", "2"]); + } + + private class Test(string? value) + { + public string? Value { get; } = value; + } }