diff --git a/src/main/java/com/timgroup/statsd/StatsDAggregator.java b/src/main/java/com/timgroup/statsd/StatsDAggregator.java index 5a0adae3..9d27c00d 100644 --- a/src/main/java/com/timgroup/statsd/StatsDAggregator.java +++ b/src/main/java/com/timgroup/statsd/StatsDAggregator.java @@ -6,13 +6,17 @@ import java.util.Map; import java.util.Timer; import java.util.TimerTask; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; public class StatsDAggregator { public static int DEFAULT_FLUSH_INTERVAL = 2000; // 2s public static int DEFAULT_SHARDS = 4; // 4 partitions to reduce contention. protected final String AGGREGATOR_THREAD_NAME = "statsd-aggregator-thread"; + protected final ArrayList> aggregateMetrics; + private final Lock[] locks; protected final int shardGranularity; protected final long flushInterval; @@ -44,6 +48,7 @@ public StatsDAggregator( this.flushInterval = flushInterval; this.shardGranularity = shards; this.aggregateMetrics = new ArrayList<>(shards); + this.locks = new ReentrantLock[shards]; if (flushInterval > 0) { this.scheduler = new Timer(AGGREGATOR_THREAD_NAME, true); @@ -51,6 +56,7 @@ public StatsDAggregator( for (int i = 0; i < this.shardGranularity; i++) { this.aggregateMetrics.add(i, new HashMap()); + this.locks[i] = new ReentrantLock(); } } @@ -86,7 +92,8 @@ public boolean aggregateMessage(Message message) { int bucket = Math.abs(hash % this.shardGranularity); Map map = aggregateMetrics.get(bucket); - synchronized (map) { + locks[bucket].lock(); + try { // For now let's just put the message in the map Message msg = MapUtils.putIfAbsent(map, message); if (msg != null) { @@ -110,6 +117,8 @@ public boolean aggregateMessage(Message message) { } } } + } finally { + locks[bucket].unlock(); } return true; @@ -127,7 +136,8 @@ protected void flush() { for (int i = 0; i < shardGranularity; i++) { Map map = aggregateMetrics.get(i); - synchronized (map) { + locks[i].lock(); + try { Iterator> iter = map.entrySet().iterator(); while (iter.hasNext()) { Message msg = iter.next().getValue(); @@ -139,6 +149,8 @@ protected void flush() { iter.remove(); } + } finally { + locks[i].unlock(); } } } diff --git a/src/test/java/com/timgroup/statsd/StatsDAggregatorTest.java b/src/test/java/com/timgroup/statsd/StatsDAggregatorTest.java index ece7b4ba..92e34d0f 100644 --- a/src/test/java/com/timgroup/statsd/StatsDAggregatorTest.java +++ b/src/test/java/com/timgroup/statsd/StatsDAggregatorTest.java @@ -254,17 +254,15 @@ public int hashCode() { for (int i = 0; i < StatsDAggregator.DEFAULT_SHARDS; i++) { Map map = fakeProcessor.aggregator.aggregateMetrics.get(i); - synchronized (map) { - Iterator> iter = map.entrySet().iterator(); - int count = 0; - while (iter.hasNext()) { - count++; - iter.next(); - } - - // sharding should be balanced - assertEquals(iterations, count); + Iterator> iter = map.entrySet().iterator(); + int count = 0; + while (iter.hasNext()) { + count++; + iter.next(); } + + // sharding should be balanced + assertEquals(iterations, count); } }