Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 9 additions & 20 deletions etc/junit4-missing-features.txt
Original file line number Diff line number Diff line change
@@ -1,24 +1,5 @@
[ai generated overview of junit4 features]

8. Timeouts
- Standard JUnit @Test(timeout=N) is honoured
- @Timeout(millis=N) annotation provides an explicit alternative
- Termination sequence: Thread.interrupt() → Thread.stop() → zombie
detection; all attempts are logged with stack traces

9. Thread-leak detection
- Threads that escape a test's ThreadGroup boundary are killed and cause
a test failure
- Encourages explicit Thread.join() before a test method returns

10. Lingering threads and advanced thread-leak control
- @ThreadLeakLingering(linger=N) waits up to N ms for stray threads to
finish naturally (useful for Executor pools or other uncontrolled threads)
- Additional annotations for fine-grained policy:
@ThreadLeakScope – suite vs. test scope
@ThreadLeakAction – warn vs. fail
@ThreadLeakZombies – ignore vs. fail on zombie threads

11. Nightly / scaled tests
- @Nightly marks a test that only runs when nightly mode is active
(-Dtests.nightly=true)
Expand All @@ -38,4 +19,12 @@

- predictably shuffled test execution order
- blowing up test reps using tests.iters
-

[to check/ add tests of]

- is the seed stack trace frame injected for leaked threads + randomized testing ext?
- can we enforce the order of extensions (randomized testing > leaked threads)
- how are jupiter timeouts working together with leaked threads ext.?
- maybe bring back thread leak zombies annotation (if we can't cleanly terminate leaked threads, ignore all remaining tests).
- maybe move some of the implementation details to a non-exposed package?
- regenerate the javadocs with public API only.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package com.carrotsearch.randomizedtesting.jupiter;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.function.Predicate;
import org.junit.jupiter.api.extension.ExtendWith;

/**
* Detects threads started within the annotated test class that are still alive after the configured
* scope ends.
*
* <p>Only functional in sequential (same-thread) execution mode. Emits a warning and skips
* detection if tests run concurrently.
*/
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@ExtendWith(DetectThreadLeaksExtension.class)
@Inherited
public @interface DetectThreadLeaks {
/** Scope at which thread leak detection is performed. */
Scope scope() default Scope.SUITE;

enum Scope {
/** Disable thread leak detection entirely. */
NONE,
/** Check for leaked threads once after all tests in the class complete. */
SUITE,
/** Check for leaked threads after each individual test method. */
TEST
}

/**
* Milliseconds to wait for leaked threads to self-terminate before declaring a failure. If all
* leaked threads terminate within this window, the test passes. Default is 0 (no lingering).
*
* <p>Place this annotation on the same class or method as {@link DetectThreadLeaks}. A
* method-level annotation takes precedence over a class-level one.
*/
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Inherited
@interface LingerTime {
int millis();
}

/**
* Excludes threads matched by any of the given {@link Predicate} classes from leak detection. A
* thread is excluded when at least one predicate returns {@code true} for it.
*
* <p>Annotations are collected hierarchically from the class and its superclasses, and the
* filters from all levels are combined.
*
* @see SystemThreadFilter
*/
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@interface ExcludeThreads {
Class<? extends Predicate<Thread>>[] value() default {SystemThreadFilter.class};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
package com.carrotsearch.randomizedtesting.jupiter;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.parallel.ExecutionMode;

/** JUnit Jupiter extension implementing {@link DetectThreadLeaks}. */
public class DetectThreadLeaksExtension
implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback {

private static final Logger LOGGER = Logger.getLogger(DetectThreadLeaksExtension.class.getName());
private static final ExtensionContext.Namespace EXTENSION_NAMESPACE =
ExtensionContext.Namespace.create(DetectThreadLeaksExtension.class);
private static final String THREAD_SNAPSHOT_KEY = "snapshot";
private static final String CONCURRENT_KEY = "concurrent";
private static final String UNCAUGHT_EXCEPTION_HANDLER_KEY = "uncaught-exception-handler";

/** Total time budget to join interrupted threads before giving up. */
private static final Duration INTERRUPT_JOIN_MS = Duration.ofSeconds(3);

@Override
public void beforeAll(ExtensionContext context) {
if (scope(context) == DetectThreadLeaks.Scope.NONE) {
return;
}

if (context.getExecutionMode() != ExecutionMode.SAME_THREAD) {
LOGGER.warning(
"Thread leak detection is disabled: tests in ["
+ context.getDisplayName()
+ "] run in concurrent execution mode.");
context.getStore(EXTENSION_NAMESPACE).put(CONCURRENT_KEY, Boolean.TRUE);
return;
}

var store = context.getStore(EXTENSION_NAMESPACE);
var filter = buildFilter(context);
store.put(UNCAUGHT_EXCEPTION_HANDLER_KEY, installUncaughtExceptionHandler());
store.put(THREAD_SNAPSHOT_KEY, liveThreads(filter));
}

@Override
public void beforeEach(ExtensionContext context) {
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) {
return;
}

var store = context.getStore(EXTENSION_NAMESPACE);
var filter = buildFilter(context);
store.put(THREAD_SNAPSHOT_KEY, liveThreads(filter));
}

@Override
public void afterEach(ExtensionContext context) {
if (isConcurrentMode(context) || scope(context) != DetectThreadLeaks.Scope.TEST) {
return;
}

var store = context.getStore(EXTENSION_NAMESPACE);
var handler = store.get(UNCAUGHT_EXCEPTION_HANDLER_KEY, UncaughtExceptionsHandler.class);
try {
checkLeaks(
store,
"test [" + context.getDisplayName() + "]",
linger(context),
buildFilter(context),
handler);
} finally {
if (handler != null) handler.restore();
}
}

@Override
public void afterAll(ExtensionContext context) {
if (isConcurrentMode(context) || scope(context) == DetectThreadLeaks.Scope.NONE) {
return;
}

var store = context.getStore(EXTENSION_NAMESPACE);
var handler = store.get(UNCAUGHT_EXCEPTION_HANDLER_KEY, UncaughtExceptionsHandler.class);
try {
checkLeaks(
store,
"suite [" + context.getDisplayName() + "]",
linger(context),
buildFilter(context),
handler);
} finally {
if (handler != null) handler.restore();
}
}

private static UncaughtExceptionsHandler installUncaughtExceptionHandler() {
var handler = new UncaughtExceptionsHandler(Thread.getDefaultUncaughtExceptionHandler());
Thread.setDefaultUncaughtExceptionHandler(handler);
return handler;
}

private static DetectThreadLeaks.Scope scope(ExtensionContext context) {
return context.getRequiredTestClass().getAnnotation(DetectThreadLeaks.class).scope();
}

private static int linger(ExtensionContext context) {
var methodAnn =
context
.getTestMethod()
.map(m -> m.getAnnotation(DetectThreadLeaks.LingerTime.class))
.orElse(null);
if (methodAnn != null) return methodAnn.millis();

var classAnn = context.getRequiredTestClass().getAnnotation(DetectThreadLeaks.LingerTime.class);
return classAnn == null ? 0 : classAnn.millis();
}

@DetectThreadLeaks.ExcludeThreads()
private static class AnnotationDefaultsSource {}

/**
* Collects {@link DetectThreadLeaks.ExcludeThreads} filter classes from the entire hierarchy
* (method to class to superclasses) and returns a combined predicate that excludes a thread when
* any filter matches it.
*/
private static Predicate<Thread> buildFilter(ExtensionContext context) {
List<DetectThreadLeaks.ExcludeThreads> excludeThreads = new ArrayList<>();

for (Class<?> cls = context.getRequiredTestClass(); cls != null; cls = cls.getSuperclass()) {
var ann = cls.getAnnotation(DetectThreadLeaks.ExcludeThreads.class);
if (ann != null) {
excludeThreads.add(ann);
}
}

if (excludeThreads.isEmpty()) {
excludeThreads.add(
AnnotationDefaultsSource.class.getAnnotation(DetectThreadLeaks.ExcludeThreads.class));
}

var filterClasses = new LinkedHashSet<Predicate<Thread>>();
for (var ann : excludeThreads) {
for (var cls : ann.value()) {
try {
filterClasses.add(cls.getDeclaredConstructor().newInstance());
} catch (ReflectiveOperationException e) {
throw new RuntimeException("Cannot instantiate thread filter: " + cls.getName(), e);
}
}
}

if (filterClasses.isEmpty()) {
return t -> false;
} else {
return t -> filterClasses.stream().anyMatch(p -> p.test(t));
}
}

private static boolean isConcurrentMode(ExtensionContext context) {
return context
.getParent()
.map(
p ->
Boolean.TRUE.equals(
p.getStore(EXTENSION_NAMESPACE).get(CONCURRENT_KEY, Boolean.class)))
.orElse(false);
}

private static void checkLeaks(
ExtensionContext.Store store,
String description,
int lingerMs,
Predicate<Thread> filter,
UncaughtExceptionsHandler handler) {
var snapshot = store.get(THREAD_SNAPSHOT_KEY, HashSet.class);
AssertionError leakError = null;

if (snapshot != null) {
var leaked = leakedSince(snapshot, filter);

// Linger: poll until threads self-terminate or the window expires.
if (!leaked.isEmpty() && lingerMs > 0) {
long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(lingerMs);
while (!leaked.isEmpty() && System.nanoTime() < deadline) {
try {
long remainingMs = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime());
Thread.sleep(Math.max(1L, Math.min(100L, remainingMs)));
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
}
leaked = leakedSince(snapshot, filter);
}
}

if (!leaked.isEmpty()) {
// Suppress uncaught exception reporting during the interrupt/join phase to avoid
// capturing expected InterruptedException-related exceptions from cleaned-up threads.
if (handler != null) {
handler.stopReporting();
}

try {
// Send an interrupt to all threads.
leaked.keySet().forEach(Thread::interrupt);

// Wait for all those threads.
long joinDeadline = System.nanoTime() + INTERRUPT_JOIN_MS.toNanos();
for (Thread t : leaked.keySet()) {
long remaining = TimeUnit.NANOSECONDS.toMillis(joinDeadline - System.nanoTime());
if (remaining <= 0) {
break;
}

try {
t.join(remaining);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
}
}
} finally {
if (handler != null) {
handler.resumeReporting();
}
}

var sb = new StringBuilder(leaked.size() + " thread(s) leaked from " + description + ":");
int cnt = 1;
for (var entry : leaked.entrySet()) {
sb.append(String.format("%n %2d) %s", cnt++, Threads.threadName(entry.getKey())));
for (var ste : entry.getValue()) {
sb.append(String.format("%n at %s", ste));
}
}
leakError = new AssertionError(sb.toString());
}
}

// Collect uncaught exceptions regardless of whether threads leaked.
List<UncaughtExceptionsHandler.UncaughtException> uncaught =
handler != null ? handler.getAndClear() : List.of();

if (leakError == null && uncaught.isEmpty()) return;

// Combine: leak error first (if any), uncaught exceptions after; all but the first
// are attached as suppressed on the thrown error.
var errors = new ArrayList<AssertionError>();
if (leakError != null) errors.add(leakError);
for (var ue : uncaught) {
errors.add(
new AssertionError("Uncaught exception in thread [" + ue.threadName() + "]", ue.error()));
}
var first = errors.get(0);
errors.subList(1, errors.size()).forEach(first::addSuppressed);
throw first;
}

private static Map<Thread, StackTraceElement[]> leakedSince(
HashSet<?> snapshot, Predicate<Thread> filter) {
var current = liveThreadsWithStacks(filter);
current.keySet().removeAll(snapshot);
return current;
}

private static HashSet<Thread> liveThreads(Predicate<Thread> filter) {
return new HashSet<>(liveThreadsWithStacks(filter).keySet());
}

private static Map<Thread, StackTraceElement[]> liveThreadsWithStacks(Predicate<Thread> filter) {
return Thread.getAllStackTraces().entrySet().stream()
.filter(e -> e.getKey().isAlive())
.filter(e -> !filter.test(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
}
Loading
Loading