diff --git a/src/CommonLib/AdaptiveTimeout.cs b/src/CommonLib/AdaptiveTimeout.cs index 409266c7..6fcd1969 100644 --- a/src/CommonLib/AdaptiveTimeout.cs +++ b/src/CommonLib/AdaptiveTimeout.cs @@ -62,14 +62,15 @@ public void ClearSamples() { /// /// /// + /// A method that is used to observe the latency of the request. /// Returns a Fail result if a task runs longer than its budgeted time. - public async Task> ExecuteWithTimeout(Func func, CancellationToken parentToken = default) { + public async Task> ExecuteWithTimeout(Func func, CancellationToken parentToken = default, Action latencyObservation = null) { DateTime startTime = default; var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => { startTime = DateTime.Now; // for ordinal tracking; see use in TimeSpikeSafetyValve return func(timeoutToken); - }), parentToken); + }, latencyObservation), parentToken); TimeSpikeSafetyValve(result.IsSuccess, startTime); return result; } @@ -84,14 +85,15 @@ public async Task> ExecuteWithTimeout(Func fu /// /// /// + /// A method that is used to observe the latency of the request. /// Returns a Fail result if a task runs longer than its budgeted time. - public async Task ExecuteWithTimeout(Action func, CancellationToken parentToken = default) { + public async Task ExecuteWithTimeout(Action func, CancellationToken parentToken = default, Action latencyObservation = null) { DateTime startTime = default; var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => { startTime = DateTime.Now; // for ordinal tracking; see use in TimeSpikeSafetyValve func(timeoutToken); - }), parentToken); + }, latencyObservation), parentToken); TimeSpikeSafetyValve(result.IsSuccess, startTime); return result; } @@ -107,14 +109,15 @@ public async Task ExecuteWithTimeout(Action func, Can /// /// /// + /// A method that is used to observe the latency of the request. /// Returns a Fail result if a task runs longer than its budgeted time. - public async Task> ExecuteWithTimeout(Func> func, CancellationToken parentToken = default) { + public async Task> ExecuteWithTimeout(Func> func, CancellationToken parentToken = default, Action latencyObservation = null) { DateTime startTime = default; var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => { startTime = DateTime.Now; // for ordinal tracking; see use in TimeSpikeSafetyValve return func(timeoutToken); - }), parentToken); + }, latencyObservation), parentToken); TimeSpikeSafetyValve(result.IsSuccess, startTime); return result; } @@ -129,14 +132,15 @@ public async Task> ExecuteWithTimeout(Func /// /// + /// A method that is used to observe the latency of the request. /// Returns a Fail result if a task runs longer than its budgeted time. - public async Task ExecuteWithTimeout(Func func, CancellationToken parentToken = default) { + public async Task ExecuteWithTimeout(Func func, CancellationToken parentToken = default, Action latencyObservation = null) { DateTime startTime = default; var result = await Timeout.ExecuteWithTimeout(GetAdaptiveTimeout(), (timeoutToken) => _sampler.SampleExecutionTime(() => { startTime = DateTime.Now; // for ordinal tracking; see use in TimeSpikeSafetyValve return func(timeoutToken); - }), parentToken); + }, latencyObservation), parentToken); TimeSpikeSafetyValve(result.IsSuccess, startTime); return result; } diff --git a/src/CommonLib/ExecutionTimeSampler.cs b/src/CommonLib/ExecutionTimeSampler.cs index 74e0e4bd..a569f03c 100644 --- a/src/CommonLib/ExecutionTimeSampler.cs +++ b/src/CommonLib/ExecutionTimeSampler.cs @@ -43,35 +43,39 @@ public double StandardDeviation() { public double Average() => _samples.Average(); - public async Task SampleExecutionTime(Func> func) { + public async Task SampleExecutionTime(Func> func, Action latencyObservation = null) { var stopwatch = Stopwatch.StartNew(); var result = await func.Invoke(); stopwatch.Stop(); + latencyObservation?.Invoke(stopwatch.ElapsedMilliseconds); AddTimeSample(stopwatch.Elapsed); return result; } - public async Task SampleExecutionTime(Func func) { + public async Task SampleExecutionTime(Func func, Action latencyObservation = null) { var stopwatch = Stopwatch.StartNew(); await func.Invoke(); stopwatch.Stop(); + latencyObservation?.Invoke(stopwatch.ElapsedMilliseconds); AddTimeSample(stopwatch.Elapsed); } - public T SampleExecutionTime(Func func) { + public T SampleExecutionTime(Func func, Action latencyObservation = null) { var stopwatch = Stopwatch.StartNew(); var result = func.Invoke(); stopwatch.Stop(); + latencyObservation?.Invoke(stopwatch.ElapsedMilliseconds); AddTimeSample(stopwatch.Elapsed); return result; } - public void SampleExecutionTime(Action func) { + public void SampleExecutionTime(Action func, Action latencyObservation = null) { var stopwatch = Stopwatch.StartNew(); func.Invoke(); stopwatch.Stop(); + latencyObservation?.Invoke(stopwatch.ElapsedMilliseconds); AddTimeSample(stopwatch.Elapsed); } diff --git a/src/CommonLib/Interfaces/ILabelValuesCache.cs b/src/CommonLib/Interfaces/ILabelValuesCache.cs new file mode 100644 index 00000000..f6d5e817 --- /dev/null +++ b/src/CommonLib/Interfaces/ILabelValuesCache.cs @@ -0,0 +1,5 @@ +namespace SharpHoundCommonLib.Interfaces; + +public interface ILabelValuesCache { + string[] Intern(string[] values); +} \ No newline at end of file diff --git a/src/CommonLib/Interfaces/IMetricFactory.cs b/src/CommonLib/Interfaces/IMetricFactory.cs new file mode 100644 index 00000000..a36b31cb --- /dev/null +++ b/src/CommonLib/Interfaces/IMetricFactory.cs @@ -0,0 +1,5 @@ +namespace SharpHoundCommonLib.Interfaces; + +public interface IMetricFactory { + IMetricRouter CreateMetricRouter(); +} \ No newline at end of file diff --git a/src/CommonLib/Interfaces/IMetricRegistry.cs b/src/CommonLib/Interfaces/IMetricRegistry.cs new file mode 100644 index 00000000..c14a60fc --- /dev/null +++ b/src/CommonLib/Interfaces/IMetricRegistry.cs @@ -0,0 +1,9 @@ +using System.Collections.Generic; +using SharpHoundCommonLib.Models; + +namespace SharpHoundCommonLib.Interfaces; + +public interface IMetricRegistry { + bool TryRegister(MetricDefinition definition, out int definitionId); + IReadOnlyList Definitions { get; } +} \ No newline at end of file diff --git a/src/CommonLib/Interfaces/IMetricRouter.cs b/src/CommonLib/Interfaces/IMetricRouter.cs new file mode 100644 index 00000000..c4489874 --- /dev/null +++ b/src/CommonLib/Interfaces/IMetricRouter.cs @@ -0,0 +1,8 @@ +using SharpHoundCommonLib.Models; + +namespace SharpHoundCommonLib.Interfaces; + +public interface IMetricRouter { + void Observe(int definitionId, double value, LabelValues labelValues); + void Flush(); +} \ No newline at end of file diff --git a/src/CommonLib/Interfaces/IMetricSink.cs b/src/CommonLib/Interfaces/IMetricSink.cs new file mode 100644 index 00000000..e7421a56 --- /dev/null +++ b/src/CommonLib/Interfaces/IMetricSink.cs @@ -0,0 +1,8 @@ +using SharpHoundCommonLib.Models; + +namespace SharpHoundCommonLib.Interfaces; + +public interface IMetricSink { + void Observe(in MetricObservation.DoubleMetricObservation observation); + void Flush(); +} diff --git a/src/CommonLib/Interfaces/IMetricWriter.cs b/src/CommonLib/Interfaces/IMetricWriter.cs new file mode 100644 index 00000000..8c4c65cb --- /dev/null +++ b/src/CommonLib/Interfaces/IMetricWriter.cs @@ -0,0 +1,17 @@ +using System; +using System.Text; +using SharpHoundCommonLib.Models; +using SharpHoundCommonLib.Services; + +namespace SharpHoundCommonLib.Interfaces; + +public interface IMetricWriter { + void StringBuilderAppendMetric( + StringBuilder builder, + MetricDefinition definition, + LabelValues labelValues, + MetricAggregator aggregator, + DateTimeOffset timestamp, + string timestampOutputString = "yyyy-MM-dd HH:mm:ss.fff" + ); +} \ No newline at end of file diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs index 8e373e5c..5a094e86 100644 --- a/src/CommonLib/LdapConnectionPool.cs +++ b/src/CommonLib/LdapConnectionPool.cs @@ -11,8 +11,11 @@ using Microsoft.Extensions.Logging; using SharpHoundCommonLib.Enums; using SharpHoundCommonLib.Exceptions; +using SharpHoundCommonLib.Interfaces; using SharpHoundCommonLib.LDAPQueries; +using SharpHoundCommonLib.Models; using SharpHoundCommonLib.Processors; +using SharpHoundCommonLib.Static; using SharpHoundRPC.NetAPINative; using SharpHoundRPC.PortScanner; @@ -36,12 +39,15 @@ internal class LdapConnectionPool : IDisposable { private const int BackoffDelayMultiplier = 2; private const int MaxRetries = 3; private static readonly ConcurrentDictionary DCInfoCache = new(); + + // Metrics + private readonly IMetricRouter _metric; // Tracks domains we know we've determined we shouldn't try to connect to private static readonly ConcurrentHashSet _excludedDomains = new(); public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig config, - IPortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { + IPortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null, IMetricRouter metric = null) { _connections = new ConcurrentBag(); _globalCatalogConnection = new ConcurrentBag(); //TODO: Re-enable this once we track down the semaphore deadlock @@ -56,6 +62,7 @@ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig c _poolIdentifier = poolIdentifier; _ldapConfig = config; _log = log ?? Logging.LogProvider.CreateLogger("LdapConnectionPool"); + _metric = metric ?? Metrics.Factory.CreateMetricRouter(); _portScanner = scanner ?? new PortScanner(); _nativeMethods = nativeMethods ?? new NativeMethods(); _queryAdaptiveTimeout = new AdaptiveTimeout(maxTimeout: TimeSpan.FromMinutes(2), Logging.LogProvider.CreateLogger("LdapQuery"), useAdaptiveTimeout: false); @@ -72,6 +79,9 @@ public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig c return await GetConnectionAsync(); } + + private void LatencyObservation(double latency) => _metric.Observe(LdapMetricDefinitions.RequestLatency, latency, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); public async IAsyncEnumerable> Query(LdapQueryParameters queryParameters, [EnumeratorCancellation] CancellationToken cancellationToken = new()) { @@ -114,11 +124,15 @@ public async IAsyncEnumerable> Query(LdapQueryParam querySuccess = true; } else if (queryRetryCount == MaxRetries) { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); tempResult = LdapResult.Fail($"Failed to get a response after {MaxRetries} attempts", queryParameters); } else { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); queryRetryCount++; } } @@ -134,6 +148,8 @@ public async IAsyncEnumerable> Query(LdapQueryParam * Release our connection in a faulted state since the connection is defunct. Attempt to get a new connection to any server in the domain * since non-paged queries do not require same server connections */ + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); queryRetryCount++; _log.LogDebug("Query - Attempting to recover from ServerDown for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), queryRetryCount); @@ -167,6 +183,8 @@ public async IAsyncEnumerable> Query(LdapQueryParam * If we get a busy error, we want to do an exponential backoff, but maintain the current connection * The expectation is that given enough time, the server should stop being busy and service our query appropriately */ + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); busyRetryCount++; _log.LogDebug("Query - Executing busy backoff for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), busyRetryCount); @@ -177,6 +195,8 @@ public async IAsyncEnumerable> Query(LdapQueryParam /* * Treat a timeout as a busy error */ + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); busyRetryCount++; _log.LogDebug("Query - Timeout: Executing busy backoff for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), busyRetryCount); @@ -187,6 +207,8 @@ public async IAsyncEnumerable> Query(LdapQueryParam /* * This is our fallback catch. If our retry counts have been exhausted this will trigger and break us out of our loop */ + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); tempResult = LdapResult.Fail( $"Query - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", queryParameters); @@ -195,6 +217,8 @@ public async IAsyncEnumerable> Query(LdapQueryParam /* * Generic exception handling for unforeseen circumstances */ + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); tempResult = LdapResult.Fail($"Query - Caught unrecoverable exception: {e.Message}", queryParameters); @@ -280,11 +304,15 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery queryRetryCount = 0; } else if (queryRetryCount == MaxRetries) { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); tempResult = LdapResult.Fail( $"PagedQuery - Failed to get a response after {MaxRetries} attempts", queryParameters); } else { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); queryRetryCount++; } } @@ -299,6 +327,8 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery * Release our connection in a faulted state since the connection is defunct. * Paged queries require a connection to be made to the same server which we started the paged query on */ + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); if (serverName == null) { _log.LogError( "PagedQuery - Received server down exception without a known servername. Unable to generate new connection\n{Info}", @@ -338,6 +368,8 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery * If we get a busy error, we want to do an exponential backoff, but maintain the current connection * The expectation is that given enough time, the server should stop being busy and service our query appropriately */ + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); busyRetryCount++; _log.LogDebug("PagedQuery - Executing busy backoff for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), busyRetryCount); @@ -348,6 +380,8 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery /* * Treat a timeout as a busy error */ + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); busyRetryCount++; _log.LogDebug("PagedQuery - Timeout: Executing busy backoff for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), busyRetryCount); @@ -355,11 +389,15 @@ public async IAsyncEnumerable> PagedQuery(LdapQuery await Task.Delay(backoffDelay, cancellationToken); } catch (LdapException le) { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); tempResult = LdapResult.Fail( $"PagedQuery - Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", queryParameters, le.ErrorCode); } catch (Exception e) { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); tempResult = LdapResult.Fail($"PagedQuery - Caught unrecoverable exception: {e.Message}", queryParameters); @@ -499,6 +537,8 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish response = await SendRequestWithTimeout(connectionWrapper.Connection, searchRequest, _rangedRetrievalAdaptiveTimeout); } catch (LdapException le) when (le.ErrorCode == (int)ResultCode.Busy && busyRetryCount < MaxRetries) { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); busyRetryCount++; _log.LogDebug("RangedRetrieval - Executing busy backoff for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), busyRetryCount); @@ -509,6 +549,8 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish /* * Treat a timeout as a busy error */ + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); busyRetryCount++; _log.LogDebug("RangedRetrieval - Timeout: Executing busy backoff for query {Info} (Attempt {Count})", queryParameters.GetQueryInfo(), busyRetryCount); @@ -517,6 +559,8 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish } catch (LdapException le) when (le.ErrorCode == (int)LdapErrorCodes.ServerDown && queryRetryCount < MaxRetries) { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); queryRetryCount++; _log.LogDebug( "RangedRetrieval - Attempting to recover from ServerDown for query {Info} (Attempt {Count})", @@ -548,11 +592,15 @@ public async IAsyncEnumerable> RangedRetrieval(string distinguish } } catch (LdapException le) { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); tempResult = LdapResult.Fail( $"Caught unrecoverable ldap exception: {le.Message} (ServerMessage: {le.ServerErrorMessage}) (ErrorCode: {le.ErrorCode})", queryParameters, le.ErrorCode); } catch (Exception e) { + _metric.Observe(LdapMetricDefinitions.FailedRequests, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); tempResult = LdapResult.Fail($"Caught unrecoverable exception: {e.Message}", queryParameters); } @@ -1047,11 +1095,24 @@ private SearchRequest CreateSearchRequest(string distinguishedName, string ldapF } private async Task SendRequestWithTimeout(LdapConnection connection, SearchRequest request, AdaptiveTimeout adaptiveTimeout) { + // Prerequest metrics + var concurrentRequests = LdapMetrics.IncrementInFlight(); + _metric.Observe(LdapMetricDefinitions.ConcurrentRequests, concurrentRequests, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); + // Add padding to account for network latency and processing overhead const int TimeoutPaddingSeconds = 3; var timeout = adaptiveTimeout.GetAdaptiveTimeout(); var timeoutWithPadding = timeout + TimeSpan.FromSeconds(TimeoutPaddingSeconds); - var result = await adaptiveTimeout.ExecuteWithTimeout((_) => connection.SendRequestAsync(request, timeoutWithPadding)); + var result = await adaptiveTimeout.ExecuteWithTimeout((_) => connection.SendRequestAsync(request, timeoutWithPadding), latencyObservation: LatencyObservation); + + // Postrequest metrics + concurrentRequests = LdapMetrics.DecrementInFlight(); + _metric.Observe(LdapMetricDefinitions.ConcurrentRequests, concurrentRequests, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); + _metric.Observe(LdapMetricDefinitions.RequestsTotal, 1, + new LabelValues([nameof(LdapConnectionPool), _poolIdentifier])); + if (result.IsSuccess) return (SearchResponse)result.Value; else diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs index 14612da1..03cfdc90 100644 --- a/src/CommonLib/LdapUtils.cs +++ b/src/CommonLib/LdapUtils.cs @@ -15,9 +15,12 @@ using Microsoft.Extensions.Logging; using SharpHoundCommonLib.DirectoryObjects; using SharpHoundCommonLib.Enums; +using SharpHoundCommonLib.Interfaces; using SharpHoundCommonLib.LDAPQueries; +using SharpHoundCommonLib.Models; using SharpHoundCommonLib.OutputTypes; using SharpHoundCommonLib.Processors; +using SharpHoundCommonLib.Static; using SharpHoundRPC.NetAPINative; using SharpHoundRPC.PortScanner; using Domain = System.DirectoryServices.ActiveDirectory.Domain; @@ -47,6 +50,9 @@ private readonly ConcurrentDictionary private readonly ConcurrentDictionary _distinguishedNameCache = new(StringComparer.OrdinalIgnoreCase); + // Metrics + private readonly IMetricRouter _metric; + private readonly ILogger _log; private readonly IPortScanner _portScanner; private readonly NativeMethods _nativeMethods; @@ -77,13 +83,15 @@ public LdapUtils() { _nativeMethods = new NativeMethods(); _portScanner = new PortScanner(); _log = Logging.LogProvider.CreateLogger("LDAPUtils"); + _metric = Metrics.Factory.CreateMetricRouter(); _connectionPool = new ConnectionPoolManager(_ldapConfig, _log); } - public LdapUtils(NativeMethods nativeMethods = null, PortScanner scanner = null, ILogger log = null) { + public LdapUtils(NativeMethods nativeMethods = null, PortScanner scanner = null, ILogger log = null, IMetricRouter metric = null) { _nativeMethods = nativeMethods ?? new NativeMethods(); _portScanner = scanner ?? new PortScanner(); _log = log ?? Logging.LogProvider.CreateLogger("LDAPUtils"); + _metric = metric ?? Metrics.Factory.CreateMetricRouter(); _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); } @@ -126,6 +134,7 @@ public IAsyncEnumerable> PagedQuery(LdapQueryParame var result = await LookupSidType(identifier, objectDomain); if (!result.Success) { _unresolvablePrincipals.Add(identifier); + _metric.Observe(LdapMetricDefinitions.UnresolvablePrincipals, 1, new LabelValues([nameof(LdapUtils)])); } return (result.Success, new TypedPrincipal(identifier, result.Type)); @@ -134,6 +143,7 @@ public IAsyncEnumerable> PagedQuery(LdapQueryParame var (success, type) = await LookupGuidType(identifier, objectDomain); if (!success) { _unresolvablePrincipals.Add(identifier); + _metric.Observe(LdapMetricDefinitions.UnresolvablePrincipals, 1, new LabelValues([nameof(LdapUtils)])); } return (success, new TypedPrincipal(identifier, type)); @@ -965,6 +975,7 @@ public async Task IsDomainController(string computerObjectId, string domai } catch { _unresolvablePrincipals.Add(distinguishedName); + _metric.Observe(LdapMetricDefinitions.UnresolvablePrincipals, 1, new LabelValues([nameof(LdapUtils)])); return (false, default); } } @@ -1129,6 +1140,9 @@ public void ResetUtils() { _domainControllers = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase); _connectionPool?.Dispose(); _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); + + // Metrics + LdapMetrics.ResetInFlight(); } private IDirectoryObject CreateDirectoryEntry(string path) { diff --git a/src/CommonLib/Models/FileMetricSinkOptions.cs b/src/CommonLib/Models/FileMetricSinkOptions.cs new file mode 100644 index 00000000..9db1c668 --- /dev/null +++ b/src/CommonLib/Models/FileMetricSinkOptions.cs @@ -0,0 +1,9 @@ +using System; + +namespace SharpHoundCommonLib.Models; + +public sealed class FileMetricSinkOptions { + public TimeSpan FlushInterval { get; set; } = TimeSpan.FromSeconds(10); + public string TimestampFormat { get; set; } = "yyyy-MM-dd HH:mm:ss.fff"; + public bool FlushWriter { get; set; } = true; +} \ No newline at end of file diff --git a/src/CommonLib/Models/IsExternalInit.cs b/src/CommonLib/Models/IsExternalInit.cs new file mode 100644 index 00000000..c0627212 --- /dev/null +++ b/src/CommonLib/Models/IsExternalInit.cs @@ -0,0 +1,11 @@ +using System.ComponentModel; + +// This class while it looks unused, is used to be able to build the records for a .Net Framework Target. +// see: https://stackoverflow.com/questions/64749385/predefined-type-system-runtime-compilerservices-isexternalinit-is-not-defined +// see: https://developercommunity.visualstudio.com/t/error-cs0518-predefined-type-systemruntimecompiler/1244809 + +namespace System.Runtime.CompilerServices +{ + [EditorBrowsable(EditorBrowsableState.Never)] + internal class IsExternalInit{} +} \ No newline at end of file diff --git a/src/CommonLib/Models/MetricDefinition.cs b/src/CommonLib/Models/MetricDefinition.cs new file mode 100644 index 00000000..31520cdc --- /dev/null +++ b/src/CommonLib/Models/MetricDefinition.cs @@ -0,0 +1,62 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace SharpHoundCommonLib.Models; + +public readonly record struct LabelValues(string[] Values) { + public string ToDisplayString(IReadOnlyList labelNames, string additionalName = null, string additionalValue = null) { + if (labelNames.Count == 0) + return string.Empty; + + if (labelNames.Count != Values.Length) + return $"{{Improper Observation Labels, LabelNamesCount: {labelNames.Count}, LabelValuesCount: {Values.Length}}}"; + + var sb = new StringBuilder(); + sb.Append('{'); + for (var i = 0; i < labelNames.Count; i++) { + if (i > 0) + sb.Append(','); + + sb.Append(labelNames[i]) + .Append("=\"") + .Append(Values[i]) + .Append('"'); + } + + if (!string.IsNullOrEmpty(additionalName) && !string.IsNullOrEmpty(additionalValue)) { + sb.Append(',').Append(additionalName).Append("=\"").Append(additionalValue).Append('"'); + } + + sb.Append('}'); + return sb.ToString(); + } +}; + +public abstract record MetricDefinition( + string Name, + IReadOnlyList LabelNames); + +public sealed record CounterDefinition(string Name, IReadOnlyList LabelNames) : MetricDefinition(Name, LabelNames); +public sealed record GaugeDefinition(string Name, IReadOnlyList LabelNames) : MetricDefinition(Name, LabelNames); + +public sealed record CumulativeHistogramDefinition(string Name, double[] InitBuckets, IReadOnlyList LabelNames) : MetricDefinition(Name, LabelNames) { + public double[] Buckets { get; } = NormalizeBuckets(InitBuckets); + + private static double[] NormalizeBuckets(double[] buckets) { + if (buckets is null || buckets.Length == 0) + throw new ArgumentException("Histogram buckets cannot be empty"); + + var copy = (double[])buckets.Clone(); + Array.Sort(copy); + + for (var i = 1; i < copy.Length; i++) { + if (copy[i] <= copy[i - 1]) + throw new ArgumentException("Histogram buckets must be strictly increasing"); + } + + return copy; + } +}; + +// Currently Native Histograms are not supported \ No newline at end of file diff --git a/src/CommonLib/Models/MetricObservation.cs b/src/CommonLib/Models/MetricObservation.cs new file mode 100644 index 00000000..7c499920 --- /dev/null +++ b/src/CommonLib/Models/MetricObservation.cs @@ -0,0 +1,10 @@ +namespace SharpHoundCommonLib.Models; + +public abstract record MetricObservation { + private MetricObservation() { } + + public readonly record struct DoubleMetricObservation( + int DefinitionId, + double Value, + string[] LabelsValues); +} diff --git a/src/CommonLib/Services/DefaultLabelValuesCache.cs b/src/CommonLib/Services/DefaultLabelValuesCache.cs new file mode 100644 index 00000000..cbb407a5 --- /dev/null +++ b/src/CommonLib/Services/DefaultLabelValuesCache.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using SharpHoundCommonLib.Interfaces; + +namespace SharpHoundCommonLib.Services; + +public sealed class DefaultLabelValuesCache : ILabelValuesCache { + internal readonly Dictionary _cache = new(); + + private readonly object _lock = new(); + private const char Separator = '\u001F'; // ascii unit separator + + public string[] Intern(string[] values) { + if (values == null || values.Length == 0) { + return []; + } + + var key = MakeKey(values); + + lock (_lock) { + if (_cache.TryGetValue(key, out var existing)) + return existing; + + var copy = new string[values.Length]; + Array.Copy(values, copy, values.Length); + _cache[key] = copy; + return copy; + } + } + + internal static string MakeKey(string[] values) { + return values.Length == 1 ? values[0] : string.Join(Separator.ToString(), values); + } + + +} \ No newline at end of file diff --git a/src/CommonLib/Services/FileMetricSink.cs b/src/CommonLib/Services/FileMetricSink.cs new file mode 100644 index 00000000..73bc1f08 --- /dev/null +++ b/src/CommonLib/Services/FileMetricSink.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using SharpHoundCommonLib.Interfaces; +using SharpHoundCommonLib.Models; + +namespace SharpHoundCommonLib.Services; + +public sealed class FileMetricSink( + IReadOnlyList definitions, + TextWriter textWriter, + IMetricWriter metricWriter, + FileMetricSinkOptions options = null) + : IMetricSink, IDisposable { + private readonly TextWriter _textWriter = textWriter; + private readonly IMetricWriter _metricWriter = metricWriter; + private readonly FileMetricSinkOptions _options = options ?? new FileMetricSinkOptions(); + + + // metric state, using a lock rather than a concurrent dictionary protects both the dictionary, + // and the aggregators state. + private readonly MetricDefinition[] _definitions = definitions.ToArray(); + private readonly Dictionary<(int, string[]), MetricAggregator> _states = new(); + private readonly object _lock = new(); + + public FileMetricSink( + IReadOnlyList definitions, + string filePath, + IMetricWriter metricWriter, + FileMetricSinkOptions options = null) + : this( + definitions, + new StreamWriter( + File.Open(filePath, FileMode.Create, FileAccess.Write, FileShare.Read)), + metricWriter, + options) {} + + public void Observe(in MetricObservation.DoubleMetricObservation observation) { + var key = (observation.DefinitionId, observation.LabelsValues); + + lock (_lock) { + if (!_states.TryGetValue(key, out var aggregator)) { + aggregator = MetricAggregatorExtensions.Create(_definitions[observation.DefinitionId]); + _states[key] = aggregator; + } + + aggregator.Observe(observation.Value); + } + } + + private int EstimateSize() => _states.Count * 128; + + public void Flush() { + string output; + lock (_lock) { + var sb = new StringBuilder(EstimateSize()); + + var timestamp = DateTimeOffset.Now; + sb.Append("Metric Flush: ") + .Append(timestamp.ToString(_options.TimestampFormat)) + .AppendLine(); + sb.Append('=', 40).AppendLine(); + + // Must use this deconstruction for .Net Version + foreach (var kvp in _states) { + var definitionId = kvp.Key.Item1; + var labelValues = kvp.Key.Item2; + var aggregator = kvp.Value; + var definition = _definitions[definitionId]; + + _metricWriter.StringBuilderAppendMetric( + sb, + definition, + new LabelValues(labelValues), + aggregator, + timestamp); + } + + sb.Append('=', 40).AppendLine().AppendLine().AppendLine().AppendLine().AppendLine(); + output = sb.ToString(); + } + + _textWriter.Write(output); + + if (_options.FlushWriter) + _textWriter.Flush(); + } + + public void Dispose() { + Flush(); + _textWriter.Dispose(); + } +} \ No newline at end of file diff --git a/src/CommonLib/Services/MetricAggregator.cs b/src/CommonLib/Services/MetricAggregator.cs new file mode 100644 index 00000000..e6d491f9 --- /dev/null +++ b/src/CommonLib/Services/MetricAggregator.cs @@ -0,0 +1,67 @@ +using System; +using System.Threading; +using SharpHoundCommonLib.Interfaces; +using SharpHoundCommonLib.Models; + +namespace SharpHoundCommonLib.Services; + +public static class MetricAggregatorExtensions { + public static MetricAggregator Create(MetricDefinition definition) => + definition switch { + CounterDefinition => new CounterAggregator(), + GaugeDefinition => new GaugeAggregator(), + CumulativeHistogramDefinition ch => new CumulativeHistogramAggregator(ch.Buckets), + _ => throw new ArgumentOutOfRangeException(nameof(definition), + $"Unknown metric type {definition.GetType().Name}") + }; +} + +public abstract class MetricAggregator { + public abstract void Observe(double value); + public abstract object Snapshot(); +} + +public sealed class CounterAggregator : MetricAggregator { + private long _value; + + public override void Observe(double value) => Interlocked.Add(ref _value, (long)value); + public override object Snapshot() => _value; +} + +public sealed class GaugeAggregator : MetricAggregator { + private double _value; + + public override void Observe(double value) => _value = value; + public override object Snapshot() => _value; +} + +public record struct HistogramSnapshot(double[] Bounds, long[] Counts, long TotalCount, double Sum); + +public sealed class CumulativeHistogramAggregator(double[] bounds) : MetricAggregator { + private readonly long[] _bucketCounts = new long[bounds.Length + 1]; // Includes the Inf+ bucket + private long _count; + private double _sum; + private readonly object _lock = new(); + + public override void Observe(double value) { + // this along with the following line, finds the correct bucket the value should be placed in. + // If the value is defined as a specific bucket, binary search returns it. If it is not found, + // it returns the compliment of the position it should be at. with a simple check we can undo + // that compliment if it is what is found. + var idx = Array.BinarySearch(bounds, value); + if (idx < 0) idx = ~idx; + lock (_lock) { + _bucketCounts[idx]++; + _count++; + _sum += value; + } + } + + public override object Snapshot() => SnapshotHistogram(); + + public HistogramSnapshot SnapshotHistogram() { + lock (_lock) { + return new HistogramSnapshot(bounds, (long[])_bucketCounts.Clone(), _count, _sum); + } + } +} \ No newline at end of file diff --git a/src/CommonLib/Services/MetricFactory.cs b/src/CommonLib/Services/MetricFactory.cs new file mode 100644 index 00000000..b448668c --- /dev/null +++ b/src/CommonLib/Services/MetricFactory.cs @@ -0,0 +1,15 @@ +using SharpHoundCommonLib.Interfaces; + +namespace SharpHoundCommonLib.Services; + +public sealed class MetricFactory(IMetricRouter router) : IMetricFactory { + private readonly IMetricRouter _router = router; + + public IMetricRouter CreateMetricRouter() => _router; +} + +public sealed class NoOpMetricFactory : IMetricFactory { + public static readonly NoOpMetricFactory Instance = new(); + private NoOpMetricFactory() { } + public IMetricRouter CreateMetricRouter() => NoOpMetricRouter.Instance; +} diff --git a/src/CommonLib/Services/MetricRegistry.cs b/src/CommonLib/Services/MetricRegistry.cs new file mode 100644 index 00000000..fd154e79 --- /dev/null +++ b/src/CommonLib/Services/MetricRegistry.cs @@ -0,0 +1,29 @@ +using System.Collections.Generic; +using SharpHoundCommonLib.Interfaces; +using SharpHoundCommonLib.Models; +using SharpHoundCommonLib.Static; + +namespace SharpHoundCommonLib.Services; + +public sealed class MetricRegistry : IMetricRegistry { + private readonly List _metrics = []; + private readonly Dictionary _nameToId = new(); + private bool _sealed; + + public IReadOnlyList Definitions => _metrics; + + public bool TryRegister(MetricDefinition definition, out int definitionId) { + definitionId = MetricId.InvalidId; + if (_sealed) return false; + + if (_nameToId.TryGetValue(definition.Name, out definitionId)) + return true; + + definitionId = _metrics.Count; + _metrics.Add(definition); + _nameToId[definition.Name] = definitionId; + return true; + } + + public void Seal() => _sealed = true; +} diff --git a/src/CommonLib/Services/MetricRouter.cs b/src/CommonLib/Services/MetricRouter.cs new file mode 100644 index 00000000..86ba7d0a --- /dev/null +++ b/src/CommonLib/Services/MetricRouter.cs @@ -0,0 +1,49 @@ +using System.Collections.Generic; +using System.Linq; +using SharpHoundCommonLib.Interfaces; +using SharpHoundCommonLib.Models; + +namespace SharpHoundCommonLib.Services; + +public sealed class MetricRouter( + IReadOnlyList definitions, + IEnumerable sinks, + ILabelValuesCache labelCache) : IMetricRouter { + private readonly int _definitionCount = definitions.Count; + private readonly IMetricSink[] _sinks = sinks.ToArray(); + private readonly ILabelValuesCache _labelCache = labelCache; + + // TODO MC: See if this boosts runtime, may need more metrics to see an appreciable difference. + // In JIT Complication, can remove some of the overhead of calling + // [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Observe(int definitionId, double value, LabelValues labelValues) { + // check to see if metric is registered, handles negative values and IDs greater than those registered. + if ((uint)definitionId >= (uint)_definitionCount) + return; + + var interned = _labelCache.Intern(labelValues.Values); + + var obs = new MetricObservation.DoubleMetricObservation(definitionId, value, interned); + + foreach (var sink in _sinks) + sink.Observe(obs); + } + + public void Flush() { + foreach (var sink in _sinks) + sink.Flush(); + } +} + +public sealed class NoOpMetricRouter : IMetricRouter { + public static readonly NoOpMetricRouter Instance = new(); + private NoOpMetricRouter() { } + + public void Observe(int definitionId, double value, LabelValues labelValues) { + // intentionally empty + } + + public void Flush() { + // intentionally empty + } +} diff --git a/src/CommonLib/Services/MetricWriter.cs b/src/CommonLib/Services/MetricWriter.cs new file mode 100644 index 00000000..b909a94b --- /dev/null +++ b/src/CommonLib/Services/MetricWriter.cs @@ -0,0 +1,89 @@ +using System; +using System.Globalization; +using System.Text; +using SharpHoundCommonLib.Interfaces; +using SharpHoundCommonLib.Models; + +namespace SharpHoundCommonLib.Services; + +public class MetricWriter : IMetricWriter { + public void StringBuilderAppendMetric(StringBuilder builder, MetricDefinition definition, LabelValues labelValues, + MetricAggregator aggregator, DateTimeOffset timestamp, string timestampOutputString = "yyyy-MM-dd HH:mm:ss.fff") { + var labelText = labelValues.ToDisplayString(definition.LabelNames); + if (aggregator is CumulativeHistogramAggregator cha) { + CumulativeHistogramAppend(builder, definition, labelValues, cha, timestamp, timestampOutputString); + } else { + DefaultAppend(builder, definition, labelValues.ToDisplayString(definition.LabelNames), aggregator, timestamp, timestampOutputString); + } + } + + private static void CumulativeHistogramAppend( + StringBuilder builder, + MetricDefinition definition, + LabelValues labelValues, + CumulativeHistogramAggregator aggregator, + DateTimeOffset timestamp, + string timestampOutputString) { + long cumulativeValue = 0; + var defaultLabelText = labelValues.ToDisplayString(definition.LabelNames); + + var snapshot = aggregator.SnapshotHistogram(); + + for (var i = 0; i < snapshot.Bounds.Length; i++) { + cumulativeValue += snapshot.Counts[i]; + + if (labelValues.Values.Length > 0) { + builder.AppendFormat("{0} {1}{2} = {3}\n", + timestamp.ToString(timestampOutputString), + definition.Name + "_bucket", + labelValues.ToDisplayString(definition.LabelNames, "le", snapshot.Bounds[i].ToString(CultureInfo.InvariantCulture)), + cumulativeValue); + } else { + builder.AppendFormat("{0} {1}{2}{{le=\"{3}\"}} = {4}\n", + timestamp.ToString(timestampOutputString), + definition.Name + "_bucket", + defaultLabelText, + snapshot.Bounds[i], + cumulativeValue); + } + } + + if (labelValues.Values.Length > 0) { + builder.AppendFormat("{0} {1}{2} = {3}\n", + timestamp.ToString(timestampOutputString), + definition.Name + "_bucket", + labelValues.ToDisplayString(definition.LabelNames, "le", "+Inf"), + snapshot.TotalCount); + + } else { + builder.AppendFormat("{0} {1}{2}{{le=\"+Inf\"}} = {3}\n", + timestamp.ToString(timestampOutputString), + definition.Name + "_bucket", + defaultLabelText, + snapshot.TotalCount); + } + + builder.AppendFormat("{0} {1}{2} = {3}\n", + timestamp.ToString(timestampOutputString), + definition.Name + "_sum", + defaultLabelText, + snapshot.Sum); + + builder.AppendFormat("{0} {1}{2} = {3}\n", + timestamp.ToString(timestampOutputString), + definition.Name + "_count", + defaultLabelText, + snapshot.TotalCount); + } + + + private static void DefaultAppend( + StringBuilder builder, + MetricDefinition definition, + string labelText, + MetricAggregator aggregator, + DateTimeOffset timestamp, + string timestampOutputString) => + builder.AppendFormat("{0} {1}{2} = {{{3}}}\n", timestamp.ToString(timestampOutputString), + definition.Name, labelText, aggregator.Snapshot()); +} \ No newline at end of file diff --git a/src/CommonLib/Services/MetricsFlushTimer.cs b/src/CommonLib/Services/MetricsFlushTimer.cs new file mode 100644 index 00000000..a70c07cc --- /dev/null +++ b/src/CommonLib/Services/MetricsFlushTimer.cs @@ -0,0 +1,33 @@ +using System; +using System.Threading; + +namespace SharpHoundCommonLib.Services; + +public class MetricsFlushTimer : IDisposable { + private readonly Action _flush; + private readonly Timer _timer; + + + public MetricsFlushTimer( + Action flush, + TimeSpan interval) { + _flush = flush; + _timer = new Timer( + _ => FlushSafe(), + null, + interval, + interval); + } + + private void FlushSafe() { + try { + _flush(); + } catch { + // catch all exception and do not kill the process + } + } + + public void Dispose() { + _timer.Dispose(); + } +} \ No newline at end of file diff --git a/src/CommonLib/Static/DefaultMetricRegistry.cs b/src/CommonLib/Static/DefaultMetricRegistry.cs new file mode 100644 index 00000000..d2cc8821 --- /dev/null +++ b/src/CommonLib/Static/DefaultMetricRegistry.cs @@ -0,0 +1,40 @@ +using SharpHoundCommonLib.Interfaces; +using SharpHoundCommonLib.Models; + +namespace SharpHoundCommonLib.Static; + +public static class DefaultMetricRegistry { + public static void RegisterDefaultMetrics(this IMetricRegistry registry) { + // LDAP Metrics + registry.TryRegister( + new CounterDefinition( + Name: "ldap_total_requests", + LabelNames: ["location", "identifier"]), + out LdapMetricDefinitions.RequestsTotal); + + registry.TryRegister( + new CounterDefinition( + Name: "ldap_failed_requests", + LabelNames: ["location", "identifier"]), + out LdapMetricDefinitions.FailedRequests); + + registry.TryRegister( + new GaugeDefinition( + Name: "ldap_concurrent_requests", + LabelNames: ["location", "identifier"]), + out LdapMetricDefinitions.ConcurrentRequests); + + registry.TryRegister( + new CumulativeHistogramDefinition( + Name: "ldap_request_duration_milliseconds", + InitBuckets: [100, 250, 500, 1000, 2500, 5000], + LabelNames: ["location", "identifier"]), + out LdapMetricDefinitions.RequestLatency); + + registry.TryRegister( + new CounterDefinition( + Name: "ldap_total_unresolvable_principals", + LabelNames: ["location"]), + out LdapMetricDefinitions.UnresolvablePrincipals); + } +} \ No newline at end of file diff --git a/src/CommonLib/Static/Metrics.cs b/src/CommonLib/Static/Metrics.cs new file mode 100644 index 00000000..a0a75d54 --- /dev/null +++ b/src/CommonLib/Static/Metrics.cs @@ -0,0 +1,37 @@ +using System.Threading; +using SharpHoundCommonLib.Interfaces; +using SharpHoundCommonLib.Services; + +namespace SharpHoundCommonLib.Static; + +public static class Metrics { + private static IMetricFactory _factory = NoOpMetricFactory.Instance; + + public static IMetricFactory Factory { + get => _factory; + set => _factory = value ?? NoOpMetricFactory.Instance; + } +} + +public static class LdapMetrics { + private static int _inFlightRequests; + + public static int InFlightRequest => _inFlightRequests; + + public static int IncrementInFlight() => Interlocked.Increment(ref _inFlightRequests); + public static int DecrementInFlight() => Interlocked.Decrement(ref _inFlightRequests); + public static void ResetInFlight() => Interlocked.Exchange(ref _inFlightRequests, 0); +} + + +public static class MetricId { + public const int InvalidId = -1; +} + +public static class LdapMetricDefinitions { + public static int RequestLatency = MetricId.InvalidId; + public static int ConcurrentRequests = MetricId.InvalidId; + public static int RequestsTotal = MetricId.InvalidId; + public static int FailedRequests = MetricId.InvalidId; + public static int UnresolvablePrincipals = MetricId.InvalidId; +} \ No newline at end of file diff --git a/test/unit/AdaptiveTimeoutTest.cs b/test/unit/AdaptiveTimeoutTest.cs index c105c23d..044ce758 100644 --- a/test/unit/AdaptiveTimeoutTest.cs +++ b/test/unit/AdaptiveTimeoutTest.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading.Tasks; using SharpHoundCommonLib; @@ -17,62 +18,116 @@ public AdaptiveTimeoutTest(ITestOutputHelper testOutputHelper) { [Fact] public async Task AdaptiveTimeout_GetAdaptiveTimeout_NotEnoughSamplesAsync() { + var observedLatency= -50.0; + var maxTimeout = TimeSpan.FromSeconds(1); var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), 10, 1000, 3); - await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50), latencyObservation: LatencyObservation); var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout(); Assert.Equal(maxTimeout, adaptiveTimeoutResult); + + Assert.InRange(observedLatency, 0.0, 100); + return; + + void LatencyObservation(double latency) { + observedLatency = latency; + } } [Fact] public async Task AdaptiveTimeout_GetAdaptiveTimeout_AdaptiveDisabled() { + var observedLatency1= -50.0; + var observedLatency2= -50.0; + var observedLatency3= -50.0; + var maxTimeout = TimeSpan.FromSeconds(1); var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), 10, 1000, 3, false); - await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); - await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); - await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50), latencyObservation: LatencyObservation1); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50), latencyObservation: LatencyObservation2); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50), latencyObservation: LatencyObservation3); var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout(); Assert.Equal(maxTimeout, adaptiveTimeoutResult); + Assert.InRange(observedLatency1, 0.0, 100); + Assert.InRange(observedLatency2, 0.0, 100); + Assert.InRange(observedLatency3, 0.0, 100); + return; + + + void LatencyObservation1(double latency) { + observedLatency1 = latency; + } + void LatencyObservation2(double latency) { + observedLatency2 = latency; + } + void LatencyObservation3(double latency) { + observedLatency3 = latency; + } } [Fact] public async Task AdaptiveTimeout_GetAdaptiveTimeout() { + var observedLatency1= -50.0; + var observedLatency2= -50.0; + var observedLatency3= -50.0; var maxTimeout = TimeSpan.FromSeconds(1); var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), 10, 1000, 3); - await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(40)); - await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50)); - await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(60)); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(40), latencyObservation: LatencyObservation1); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(50), latencyObservation: LatencyObservation2); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(60), latencyObservation: LatencyObservation3); var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout(); Assert.True(adaptiveTimeoutResult < maxTimeout); + Assert.InRange(observedLatency1, 0.0, 150); + Assert.InRange(observedLatency2, 0.0, 160); + Assert.InRange(observedLatency3, 0.0, 170); + return; + + void LatencyObservation1(double latency) { + observedLatency1 = latency; + } + void LatencyObservation2(double latency) { + observedLatency2 = latency; + } + void LatencyObservation3(double latency) { + observedLatency3 = latency; + } } [Fact] public async Task AdaptiveTimeout_GetAdaptiveTimeout_TimeSpikeSafetyValve() { + var observations = new ConcurrentBag(); var tasks = new List(); var maxTimeout = TimeSpan.FromSeconds(1); var numSamples = 30; var adaptiveTimeout = new AdaptiveTimeout(maxTimeout, new TestLogger(_testOutputHelper, Microsoft.Extensions.Logging.LogLevel.Trace), numSamples, 1000, 10); for (int i = 0; i < numSamples; i++) - tasks.Add(adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(10))); + tasks.Add(adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(10), latencyObservation: LatencyObservation)); await Task.WhenAll(tasks); for (int i = 0; i < 3; i++) - await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(500)); + await adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(500), latencyObservation: LatencyObservation); var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout(); Assert.Equal(maxTimeout, adaptiveTimeoutResult); + foreach (var t in observations) { + Assert.InRange(t, 0.0, 1000.1); + } + return; + + void LatencyObservation(double latency) => observations.Add(latency); } [Fact] public async Task AdaptiveTimeout_GetAdaptiveTimeout_TimeSpikeSafetyValve_IgnoreHiccup() { + var completedObservations = new ConcurrentBag(); + var timeoutObservations = new ConcurrentBag(); var tasks = new List(); var maxTimeout = TimeSpan.FromSeconds(1); var numSamples = 5; @@ -80,23 +135,33 @@ public async Task AdaptiveTimeout_GetAdaptiveTimeout_TimeSpikeSafetyValve_Ignore // Prepare our successful samples for (int i = 0; i < numSamples; i++) - tasks.Add(adaptiveTimeout.ExecuteWithTimeout((_) => Task.CompletedTask)); + tasks.Add(adaptiveTimeout.ExecuteWithTimeout((_) => Task.CompletedTask, latencyObservation: LatencyCompletedObservation)); await Task.WhenAll(tasks); // Add some timeout tasks that will resolve last for (int i = 0; i < 5; i++) - tasks.Add(adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(2000))); + tasks.Add(adaptiveTimeout.ExecuteWithTimeout(async (_) => await Task.Delay(2000), latencyObservation: LatencyTimeoutObservation)); // These tasks are added later but will resolve first for (int i = 0; i < 4; i++) - tasks.Add(adaptiveTimeout.ExecuteWithTimeout((_) => Task.CompletedTask)); + tasks.Add(adaptiveTimeout.ExecuteWithTimeout((_) => Task.CompletedTask, latencyObservation: LatencyCompletedObservation)); await Task.WhenAll(tasks); var adaptiveTimeoutResult = adaptiveTimeout.GetAdaptiveTimeout(); // So our time spike safety valve should ignore the hiccup, since later tasks have resolved // by the time the safety valve has triggered by the timeout tasks Assert.True(adaptiveTimeoutResult < maxTimeout); + foreach (var t in completedObservations) { + Assert.InRange(t, 0.0, 50.0); + } + foreach (var t in timeoutObservations) { + Assert.InRange(t, 0.0, 1000.1); + } + return; + + void LatencyCompletedObservation(double latency) => completedObservations.Add(latency); + void LatencyTimeoutObservation(double latency) => timeoutObservations.Add(latency); } [Fact] diff --git a/test/unit/CommonLibHelperTests.cs b/test/unit/CommonLibHelperTests.cs index 0cad80e1..f2e4c0c1 100644 --- a/test/unit/CommonLibHelperTests.cs +++ b/test/unit/CommonLibHelperTests.cs @@ -302,7 +302,7 @@ public void DomainNameToDistinguishedName_DotsBecomeDcComponents() Assert.Equal("DC=test,DC=local", result); } - [Theory] + [WindowsOnlyTheory] [InlineData("S-1-5-32-544", "\\01\\02\\00\\00\\00\\00\\00\\05\\20\\00\\00\\00\\20\\02\\00\\00")] public void ConvertSidToHexSid_ValidSid_MatchesSecurityIdentifierBinaryForm(string sid, string expectedHexSid) { @@ -322,7 +322,7 @@ static string BuildExpectedHexSid(string sid) } } - [Fact] + [WindowsOnlyFact] public void ConvertSidToHexSid_InvalidSid_Throws() { Assert.ThrowsAny(() => Helpers.ConvertSidToHexSid("NOT-A-SID")); diff --git a/test/unit/DefaultLabelValuesCacheTests.cs b/test/unit/DefaultLabelValuesCacheTests.cs new file mode 100644 index 00000000..d8634d6d --- /dev/null +++ b/test/unit/DefaultLabelValuesCacheTests.cs @@ -0,0 +1,90 @@ +using System.Collections.Concurrent; +using System.Linq; +using System.Threading.Tasks; +using SharpHoundCommonLib.Services; +using Xunit; + +namespace CommonLibTest; + +public class DefaultLabelValuesCacheTests { + + [Theory] + [InlineData(new[] {"value"}, "value")] + [InlineData(new string[] {}, "")] + [InlineData(new[] {"value1", "value2"}, "value1\u001Fvalue2")] + [InlineData(new[] {"value1", "value2", "value3"}, "value1\u001Fvalue2\u001Fvalue3")] + public void MakeKey_Returns_Proper_Key(string[] labelValues, string expectedKey) { + // act + var key = DefaultLabelValuesCache.MakeKey(labelValues); + + // assert + Assert.Equal(expectedKey, key); + } + + [Fact] + public void Intern_Retrieves_Existing_LabelValues() { + // setup + string[] values1 = ["value1", "value2"]; + string[] values2 = ["value1", "value2"]; + var cache = new DefaultLabelValuesCache(); + + // act + cache.Intern(values1); + cache.Intern(values2); + + // assert + Assert.NotEmpty(cache._cache); + Assert.Single(cache._cache); + } + + [Fact] + public void Empty_Intern_Returns_Empty_Array() { + // setup + var cache = new DefaultLabelValuesCache(); + + // act + var ret = cache.Intern([]); + + // assert + Assert.Empty(ret); + } + + [Fact] + public void LabelValuesCache_ReturnsSameReference_UnderConcurrency() + { + // setup + var cache = new DefaultLabelValuesCache(); + const int threadCount = 16; + const int iterationsPerThread = 10_000; + var results = new ConcurrentBag(); + var tasks = new Task[threadCount]; + + for (var t = 0; t < threadCount; t++) + { + tasks[t] = Task.Run(() => + { + for (var i = 0; i < iterationsPerThread; i++) + { + var labels = cache.Intern(["GET", "200"]); + results.Add(labels); + } + }); + } + + // act + Task.WaitAll(tasks); + + // assert + // Take the first reference + var first = results.First(); + + // Assert all references are identical + foreach (var arr in results) + { + Assert.True( + object.ReferenceEquals(first, arr), + "Different label array instances were returned"); + } + } + +} \ No newline at end of file diff --git a/test/unit/FileMetricSinkTests.cs b/test/unit/FileMetricSinkTests.cs new file mode 100644 index 00000000..7b27e0a7 --- /dev/null +++ b/test/unit/FileMetricSinkTests.cs @@ -0,0 +1,143 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Moq; +using SharpHoundCommonLib.Interfaces; +using SharpHoundCommonLib.Models; +using SharpHoundCommonLib.Services; +using Xunit; + +namespace CommonLibTest; + +public class SimpleMetricWriter : IMetricWriter { + public void StringBuilderAppendMetric(StringBuilder builder, MetricDefinition definition, LabelValues labelValues, + MetricAggregator aggregator, DateTimeOffset timestamp, string timestampOutputString = "yyyy-MM-dd HH:mm:ss.fff") => + builder.AppendFormat( + "DefinitionType: {0}, DefinitionName: {1}, AggregatorType: {2}, AggregatorSnapshotType: {3}\n", + definition.GetType(), definition.Name, aggregator.GetType(), aggregator.Snapshot().GetType()); +} + +public class FileMetricSinkTests { + [Theory] + [MemberData(nameof(FileMetricSinkTestData.FlushStringCases), MemberType = typeof(FileMetricSinkTestData))] + public void FileMetricSink_Returns_Expected_Flush_String( + MetricDefinition[] definitions, + MetricObservation.DoubleMetricObservation[] observations, + string[] expectedOutputs, + string[] unexpectedOutputs) { + // setup + var sinkOptions = new FileMetricSinkOptions { + FlushWriter = true, + }; + var textWriter = new StringWriter(); + var metricWriter = new SimpleMetricWriter(); + var sink = new FileMetricSink(definitions, textWriter, metricWriter, sinkOptions); + + // act + foreach (var observation in observations) { + sink.Observe(observation); + } + sink.Flush(); + var output = textWriter.ToString(); + + // assert + foreach (var expectedOutput in expectedOutputs) { + Assert.Contains(expectedOutput, output); + } + + foreach (var unexpectedOutput in unexpectedOutputs) { + Assert.DoesNotContain(unexpectedOutput, output); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void FileMetricSink_Does_Not_Flush_Writer_With_AutoFlush_False(bool autoFlush) { + // setup + var sinkOptions = new FileMetricSinkOptions { + FlushWriter = autoFlush, + }; + var writerMoq = new Mock(MockBehavior.Strict); + writerMoq.Setup(w => w.Write(It.IsAny())).Verifiable(); + writerMoq.Setup(w => w.Flush()).Verifiable(); + var metricWriter = new SimpleMetricWriter(); + MetricDefinition[] definitions = [new CounterDefinition("counter_definition", ["name"])]; + var observation = new MetricObservation.DoubleMetricObservation(0, 1, ["value"]); + var sink = new FileMetricSink(definitions, writerMoq.Object, metricWriter, sinkOptions); + + // act + sink.Observe(observation); + sink.Flush(); + + // assert + writerMoq.Verify(w => w.Write(It.IsAny()), Times.Once); + if (autoFlush) + writerMoq.Verify(w => w.Flush(), Times.Once); + else + writerMoq.Verify(w => w.Flush(), Times.Never); + } +} + +public static class FileMetricSinkTestData { + public static IEnumerable FlushStringCases => [ + // Observations are flushed + [ + new MetricDefinition[] { + new CounterDefinition("counter_definition", ["value"]), + new GaugeDefinition("gauge_definition", ["value"]), + }, + new[] { + new MetricObservation.DoubleMetricObservation(0, 1, []), + new MetricObservation.DoubleMetricObservation(1, 1, []), + }, + new[] { + "Metric Flush: ", + "========================================", + "DefinitionType: SharpHoundCommonLib.Models.CounterDefinition, DefinitionName: counter_definition, AggregatorType: SharpHoundCommonLib.Services.CounterAggregator, AggregatorSnapshotType: System.Int64\n", + "DefinitionType: SharpHoundCommonLib.Models.GaugeDefinition, DefinitionName: gauge_definition, AggregatorType: SharpHoundCommonLib.Services.GaugeAggregator, AggregatorSnapshotType: System.Double\n", + }, + Array.Empty(), + ], + // Unobserved Metrics are not flushed + [ + new MetricDefinition[] { + new CounterDefinition("counter_definition", ["value"]), + new GaugeDefinition("gauge_definition", ["value"]), + }, + new[] { + new MetricObservation.DoubleMetricObservation(0, 1, []), + }, + new[] { + "Metric Flush: ", + "========================================", + "DefinitionType: SharpHoundCommonLib.Models.CounterDefinition, DefinitionName: counter_definition, AggregatorType: SharpHoundCommonLib.Services.CounterAggregator, AggregatorSnapshotType: System.Int64\n", + }, + new[] { + "DefinitionType: SharpHoundCommonLib.Models.GaugeDefinition, DefinitionName: gauge_definition, AggregatorType: SharpHoundCommonLib.Services.GaugeAggregator, AggregatorSnapshotType: System.Double\n", + }, + ], + // Cumulative Histogram Returns HistogramSnapshot + [ + new MetricDefinition[] { + new CounterDefinition("counter_definition", ["value"]), + new GaugeDefinition("gauge_definition", ["value"]), + new CumulativeHistogramDefinition("cumulative_histogram_definition", [1, 2, 3], ["value"]), + }, + new[] { + new MetricObservation.DoubleMetricObservation(0, 1, []), + new MetricObservation.DoubleMetricObservation(1, 1, []), + new MetricObservation.DoubleMetricObservation(2, 1, []), + }, + new[] { + "Metric Flush: ", + "========================================", + "DefinitionType: SharpHoundCommonLib.Models.CounterDefinition, DefinitionName: counter_definition, AggregatorType: SharpHoundCommonLib.Services.CounterAggregator, AggregatorSnapshotType: System.Int64\n", + "DefinitionType: SharpHoundCommonLib.Models.GaugeDefinition, DefinitionName: gauge_definition, AggregatorType: SharpHoundCommonLib.Services.GaugeAggregator, AggregatorSnapshotType: System.Double\n", + "DefinitionType: SharpHoundCommonLib.Models.CumulativeHistogramDefinition, DefinitionName: cumulative_histogram_definition, AggregatorType: SharpHoundCommonLib.Services.CumulativeHistogramAggregator, AggregatorSnapshotType: SharpHoundCommonLib.Services.HistogramSnapshot\n", + }, + Array.Empty(), + ], + ]; +} \ No newline at end of file diff --git a/test/unit/MetricAggregatorTests.cs b/test/unit/MetricAggregatorTests.cs new file mode 100644 index 00000000..3f77c505 --- /dev/null +++ b/test/unit/MetricAggregatorTests.cs @@ -0,0 +1,138 @@ +using System; +using System.Collections.Generic; +using System.Text; +using SharpHoundCommonLib.Models; +using SharpHoundCommonLib.Services; +using Xunit; +using Xunit.Abstractions; + +namespace CommonLibTest; + + +public class MetricAggregatorTests(ITestOutputHelper output) { + + [Theory] + [MemberData(nameof(MetricAggregatorTestData.CreateTestData), MemberType = typeof(MetricAggregatorTestData))] + public void MetricAggregatorExtensions_Create_Creates_Proper_Aggregator(MetricDefinition definition, + MetricAggregator expectedAggregator) { + // setup + // act + var aggregator = MetricAggregatorExtensions.Create(definition); + + // assert + Assert.IsType(expectedAggregator.GetType(), aggregator); + } + + [Fact] + public void MetricAggregatorExtensions_Creates_Throws_Exception_For_Unimplemented_MetricDefinition() { + // setup + var newMetricDefinition = new UnimplementedMetricDefinition("unimplemented", ["value1"]); + + // act and assert + Assert.Throws(() => MetricAggregatorExtensions.Create(newMetricDefinition)); + } + + [Theory] + [MemberData(nameof(MetricAggregatorTestData.ObserveAndSnapshotTests), + MemberType = typeof(MetricAggregatorTestData))] + public void MetricAggregator_Observe_and_Snapshot_Tests(MetricAggregator aggregator, + double[] observations, object expectedSnapshot) { + // setup + foreach (var observation in observations) { + aggregator.Observe(observation); + } + + // act + var snapshot = aggregator.Snapshot(); + + // assert + if (expectedSnapshot is HistogramSnapshot ehs && snapshot is HistogramSnapshot ahs) { + Assert.Equal(ehs.TotalCount, ahs.TotalCount); + Assert.Equal(ehs.Sum, ahs.Sum); + Assert.Equal(ehs.Bounds, ahs.Bounds); + Assert.Equal(ehs.Counts, ahs.Counts); + } else { + Assert.Equal(expectedSnapshot, snapshot); + } + + } + + private string snapShotArrays(double[] bounds, long[] counts) { + var builder = new StringBuilder(); + builder.Append("bounds: [ "); + Iterate(builder, bounds); + builder.Append(" ], counts: [ "); + Iterate(builder, counts); + builder.Append(" ]"); + return builder.ToString(); + + + void Iterate(StringBuilder sb, T[] os) { + var first = true; + + for (var i = 0; i < os.Length; i++) { + if (!first) + builder.Append(", "); + + builder.Append(os[i]); + first = false; + } + } + } + + + private record UnimplementedMetricDefinition(string Name, IReadOnlyList LabelNames) : MetricDefinition(Name, LabelNames) {} + +} + +public static class MetricAggregatorTestData { + public static IEnumerable CreateTestData => [ + [ + new CounterDefinition("counter_name", ["value"]), + new CounterAggregator(), + ], + [ + new GaugeDefinition("gauge_name", ["value"]), + new GaugeAggregator(), + ], + [ + new CumulativeHistogramDefinition("cumulative_histogram_name", [1, 2, 3], ["value"]), + new CumulativeHistogramAggregator([1, 2, 3]) + ], + ]; + + public static IEnumerable ObserveAndSnapshotTests => [ + [ + new CounterAggregator(), + new[] { + 1.0, + 2.0, + 1.0, + 4.0, + }, + 8L + ], + [ + new GaugeAggregator(), + new[] { + 1.0, + 2.0, + 1.0, + + }, + 1.0 + ], + [ + new CumulativeHistogramAggregator([1, 2, 3, 4]), + new[] { + 1.0, + 1.0, + 3.0, + 3.0, + 2.0, + }, + // Ensure Aggregation does not happen on observation or snapshot + new HistogramSnapshot([1, 2, 3, 4], [2, 1, 2, 0, 0], 5, 10) + ], + ]; +} \ No newline at end of file diff --git a/test/unit/MetricDefinitionTests.cs b/test/unit/MetricDefinitionTests.cs new file mode 100644 index 00000000..49ba2d51 --- /dev/null +++ b/test/unit/MetricDefinitionTests.cs @@ -0,0 +1,127 @@ +using System; +using SharpHoundCommonLib.Models; +using Xunit; + +namespace CommonLibTest; + +public class MetricDefinitionTests { + + [Fact] + public void LabelValues_EmptyLabelNames_Returns_Empty() { + // setup + var labelValues = new LabelValues(["value1", "value2"]); + string[] labelNames = []; + + // act + var output = labelValues.ToDisplayString(labelNames); + + // assert + Assert.Empty(output); + } + + [Fact] + public void LabelValues_MoreLabelNames_Returns_Error() { + // setup + var labelValues = new LabelValues(["value1", "value2"]); + string[] labelNames = ["value1", "value2", "value3"]; + + // act + var output = labelValues.ToDisplayString(labelNames); + + // assert + Assert.Equal($"{{Improper Observation Labels, LabelNamesCount: {labelNames.Length}, LabelValuesCount: {labelValues.Values.Length}}}", output); + } + + [Fact] + public void LabelValues_MoreLabelValues_Returns_Error() { + // setup + var labelValues = new LabelValues(["value1", "value2", "value3"]); + string[] labelNames = ["value1", "value2"]; + + // act + var output = labelValues.ToDisplayString(labelNames); + + // assert + Assert.Equal($"{{Improper Observation Labels, LabelNamesCount: {labelNames.Length}, LabelValuesCount: {labelValues.Values.Length}}}", output); + } + + [Fact] + public void LabelValues_ToDisplayString() { + // setup + var labelValues = new LabelValues(["value1", "value2", "value3"]); + string[] labelNames = ["name1", "name2", "name3"]; + + // act + var output = labelValues.ToDisplayString(labelNames); + + // assert + Assert.Equal("{name1=\"value1\",name2=\"value2\",name3=\"value3\"}", output); + } + + [Fact] + public void LabelValues_ToDisplayString_Additional_Values_Requires_Both() { + // setup + var labelValues = new LabelValues(["value1", "value2", "value3"]); + string[] labelNames = ["name1", "name2", "name3"]; + + // act + var output1 = labelValues.ToDisplayString(labelNames, string.Empty); + var output2 = labelValues.ToDisplayString(labelNames, ""); + var output3 = labelValues.ToDisplayString(labelNames, "additional_name"); + var output4 = labelValues.ToDisplayString(labelNames, additionalValue: string.Empty); + var output5 = labelValues.ToDisplayString(labelNames, additionalValue: ""); + var output6 = labelValues.ToDisplayString(labelNames, additionalValue: "additional_value"); + var output7 = labelValues.ToDisplayString(labelNames, "additional_name", "additional_value"); + + // assert + Assert.Equal("{name1=\"value1\",name2=\"value2\",name3=\"value3\"}", output1); + Assert.Equal("{name1=\"value1\",name2=\"value2\",name3=\"value3\"}", output2); + Assert.Equal("{name1=\"value1\",name2=\"value2\",name3=\"value3\"}", output3); + Assert.Equal("{name1=\"value1\",name2=\"value2\",name3=\"value3\"}", output4); + Assert.Equal("{name1=\"value1\",name2=\"value2\",name3=\"value3\"}", output5); + Assert.Equal("{name1=\"value1\",name2=\"value2\",name3=\"value3\"}", output6); + Assert.Equal("{name1=\"value1\",name2=\"value2\",name3=\"value3\",additional_name=\"additional_value\"}", output7); + } + + [Fact] + public void Definitions_Properly_Assign_Values() { + // setup + const string name = "definitionName"; + string[] labelNames = ["name1", "name2", "name3" ]; + double[] buckets = [0, 1, 2, 3]; + + // act + var counter = new CounterDefinition(name, labelNames); + var gauge = new GaugeDefinition(name, labelNames); + var histogram = new CumulativeHistogramDefinition(name, buckets, labelNames); + + // assert + Assert.Equal(name, counter.Name); + Assert.Equal(name, gauge.Name); + Assert.Equal(name, histogram.Name); + Assert.Equal(labelNames, counter.LabelNames); + Assert.Equal(labelNames, gauge.LabelNames); + Assert.Equal(labelNames, histogram.LabelNames); + Assert.Equal(buckets.Length, histogram.Buckets.Length); + for (var i = 0; i < buckets.Length; ++i) { + Assert.Equal(buckets[i], histogram.Buckets[i]); + } + + } + + [Fact] + public void CumulativeHistogramDefinition_NormalizesBuckets() { + // setup + double[] initBuckets = [5, 4, 3, 2, 1]; + + // act + var definition = new CumulativeHistogramDefinition("name", initBuckets, []); + Array.Sort(initBuckets); + + // assert + Assert.Equal(initBuckets.Length, definition.Buckets.Length); + for (var i = 0; i < definition.Buckets.Length; ++i) { + Assert.Equal(initBuckets[i], definition.Buckets[i]); + } + } +} \ No newline at end of file diff --git a/test/unit/MetricRegistryTests.cs b/test/unit/MetricRegistryTests.cs new file mode 100644 index 00000000..4ab3630f --- /dev/null +++ b/test/unit/MetricRegistryTests.cs @@ -0,0 +1,65 @@ +using SharpHoundCommonLib.Models; +using SharpHoundCommonLib.Services; +using SharpHoundCommonLib.Static; +using Xunit; + +namespace CommonLibTest; + +public class MetricRegistryTests { + [Fact] + public void TryRegister_Returns_definitionID_if_Success() { + // setup + var registry = new MetricRegistry(); + var counter = new CounterDefinition("counter_name", ["value"]); + var gauge = new GaugeDefinition("gauge_name", ["value"]); + + // act + var registered1 = registry.TryRegister(counter, out var counterDefinitionId); + var registered2 = registry.TryRegister(gauge, out var gaugeDefinitionId); + + // assert + Assert.True(registered1); + Assert.True(registered2); + Assert.Equal(0, counterDefinitionId); + Assert.Equal(1, gaugeDefinitionId); + } + + [Fact] + public void TryRegister_Gets_Preregistered_Definition_by_Name() { + // setup + var registry = new MetricRegistry(); + var counter1 = new CounterDefinition("counter_name", ["value"]); + var counter2 = new CounterDefinition("counter_name", ["value"]); + + // act + var registered1 = registry.TryRegister(counter1, out var counterDefinitionId1); + var registered2 = registry.TryRegister(counter2, out var counterDefinitionId2); + + // assert + Assert.True(registered1); + Assert.True(registered2); + Assert.Equal(0, counterDefinitionId1); + Assert.Equal(counterDefinitionId1, counterDefinitionId2); + Assert.Single(registry.Definitions); + } + + [Fact] + public void TryRegister_After_Sealing_Returns_false_and_InvalidId() { + // setup + var registry = new MetricRegistry(); + var counter = new CounterDefinition("counter_name", ["value"]); + var gauge = new GaugeDefinition("gauge_name", ["value"]); + + // act + var registered1 = registry.TryRegister(counter, out var counterDefinitionId); + registry.Seal(); + var registered2 = registry.TryRegister(gauge, out var gaugeDefinitionId); + + // assert + Assert.True(registered1); + Assert.False(registered2); + Assert.Equal(0, counterDefinitionId); + Assert.Equal(MetricId.InvalidId, gaugeDefinitionId); + Assert.Single(registry.Definitions); + } +} \ No newline at end of file diff --git a/test/unit/MetricRouterTests.cs b/test/unit/MetricRouterTests.cs new file mode 100644 index 00000000..2b1bcb67 --- /dev/null +++ b/test/unit/MetricRouterTests.cs @@ -0,0 +1,86 @@ +using Moq; +using SharpHoundCommonLib.Interfaces; +using SharpHoundCommonLib.Models; +using SharpHoundCommonLib.Services; +using Xunit; + +namespace CommonLibTest; + +public class MetricRouterTests { + + [Theory] + [InlineData(-1)] + [InlineData(5)] + [InlineData(2)] + public void InvalidIds_Are_Not_Cached_or_Observed(int definitionId) { + // setup + MetricDefinition[] definitions = [ + new CounterDefinition("counter_name_1", ["name"]), + new CounterDefinition("counter_name_2", ["name"]), + ]; + var labelCacheMoq = new Mock(); + var sinkMoq = new Mock(); + var router = new MetricRouter( + definitions, + [sinkMoq.Object], + labelCacheMoq.Object + ); + + // act + router.Observe(definitionId, 1.0, new LabelValues(["value"])); + + // assert + sinkMoq.Verify(s => s.Observe(in It.Ref.IsAny), Times.Never()); + labelCacheMoq.Verify(c => c.Intern(It.IsAny()), Times.Never()); + } + + [Fact] + public void LabelValues_Are_Interned_And_Each_Sink_Is_Observed() { + // setup + MetricDefinition[] definitions = [ + new CounterDefinition("counter_name_1", ["name"]), + new CounterDefinition("counter_name_2", ["name"]), + ]; + var labelCacheMoq = new Mock(); + var sinkMoq1 = new Mock(); + var sinkMoq2 = new Mock(); + var router = new MetricRouter( + definitions, + [sinkMoq1.Object, sinkMoq2.Object], + labelCacheMoq.Object + ); + + // act + router.Observe(0, 1.0, new LabelValues(["value"])); + + // assert + labelCacheMoq.Verify(c => c.Intern(It.IsAny()), Times.Once()); + sinkMoq1.Verify(s => s.Observe(in It.Ref.IsAny), Times.Once()); + sinkMoq2.Verify(s => s.Observe(in It.Ref.IsAny), Times.Once()); + } + + [Fact] + public void Flush_Calls_Flush_on_Each_Sink() { + // setup + MetricDefinition[] definitions = [ + ]; + var labelCacheMoq = new Mock(); + var sinkMoq1 = new Mock(); + var sinkMoq2 = new Mock(); + var router = new MetricRouter( + definitions, + [sinkMoq1.Object, sinkMoq2.Object], + labelCacheMoq.Object + ); + + // act + router.Flush(); + + // assert + sinkMoq1.Verify(s => s.Flush(), Times.Once()); + sinkMoq2.Verify(s => s.Flush(), Times.Once()); + sinkMoq1.Verify(s => s.Observe(in It.Ref.IsAny), Times.Never()); + sinkMoq2.Verify(s => s.Observe(in It.Ref.IsAny), Times.Never()); + labelCacheMoq.Verify(c => c.Intern(It.IsAny()), Times.Never()); + } +} \ No newline at end of file diff --git a/test/unit/Utils.cs b/test/unit/Utils.cs index 10e9d90a..3ae8fd08 100644 --- a/test/unit/Utils.cs +++ b/test/unit/Utils.cs @@ -90,4 +90,10 @@ public WindowsOnlyFact() if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) Skip = "Ignore on non-Windows platforms"; } } + + public sealed class WindowsOnlyTheory : TheoryAttribute { + public WindowsOnlyTheory() { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) Skip = "Ignore on non-Windows platforms"; + } + } } \ No newline at end of file