@@ -5,6 +5,7 @@ import com.android.build.gradle.internal.tasks.DeviceProviderInstrumentTestTask
55import org.gradle.api.Plugin
66import org.gradle.api.Project
77import org.gradle.api.Task
8+ import org.gradle.api.file.Directory
89import org.gradle.api.provider.ProviderFactory
910import java.io.File
1011
@@ -24,74 +25,93 @@ class AndroidSnaptestingPlugin : Plugin<Project> {
2425 ? : throw RuntimeException (" TestedExtension not found" )
2526
2627 val isRecordMode = project.properties[" android.testInstrumentationRunnerArguments.record" ] == " true"
27- val projectDir = project.projectDir
2828 val providerFactory: ProviderFactory = project.providers
2929
3030 deviceProviderInstrumentTestTasks.names.forEach { taskName ->
3131 val deviceProviderTask = project.tasks.named(
3232 taskName,
3333 DeviceProviderInstrumentTestTask ::class .java,
3434 ).get()
35- val capitalizedVariant = deviceProviderTask.variantName.capitalizeFirstLetter()
36-
37- @Suppress(" DEPRECATION" )
38- val testedVariant = extension.testVariants
39- .firstOrNull { it.name == deviceProviderTask.variantName }
40- ? : throw RuntimeException (" TestVariant not found for ${deviceProviderTask.variantName} " )
41- val applicationIdProvider = providerFactory.provider { testedVariant.applicationId }
42- val adbExecutablePath = extension.adbExecutable.absolutePath
43-
44- val goldenSnapshotsSourcePath = run {
45- val variantSourceFolder = deviceProviderTask
46- .variantName
47- .replace(" AndroidTest" , " " )
48- .capitalizeFirstLetter()
49- .let { " androidTest$it " }
50- " $projectDir /src/$variantSourceFolder /assets/android-snaptesting-golden-files"
51- }
35+ registerTasksForVariant(project, taskName, deviceProviderTask, extension, isRecordMode, providerFactory)
36+ }
37+ }
38+ }
5239
53- // Note: in Kotlin, doFirst/doLast lambdas receive the task as 'it', not 'this'.
54- deviceProviderTask.doFirst {
55- (it as DeviceProviderInstrumentTestTask )
56- .deviceFileManager(applicationIdProvider.get(), adbExecutablePath, providerFactory)
57- .clearAllSnapshots()
58- }
40+ @Suppress(" DEPRECATION" )
41+ private fun registerTasksForVariant (
42+ project : Project ,
43+ taskName : String ,
44+ deviceProviderTask : DeviceProviderInstrumentTestTask ,
45+ extension : TestedExtension ,
46+ isRecordMode : Boolean ,
47+ providerFactory : ProviderFactory ,
48+ ) {
49+ val capitalizedVariant = deviceProviderTask.variantName.capitalizeFirstLetter()
50+
51+ val testedVariant = extension.testVariants
52+ .firstOrNull { it.name == deviceProviderTask.variantName }
53+ ? : throw RuntimeException (" TestVariant not found for ${deviceProviderTask.variantName} " )
54+ val applicationIdProvider = providerFactory.provider { testedVariant.applicationId }
55+ val adbExecutablePath = extension.adbExecutable.absolutePath
56+
57+ val goldenSnapshotsSourcePath = run {
58+ val variantSourceFolder = deviceProviderTask
59+ .variantName
60+ .replace(" AndroidTest" , " " )
61+ .capitalizeFirstLetter()
62+ .let { " androidTest$it " }
63+ " ${project.projectDir} /src/$variantSourceFolder /assets/android-snaptesting-golden-files"
64+ }
5965
60- // Before task as dependency anchor for CI scripts.
61- val beforeTaskName = " androidSnaptestingBefore$capitalizedVariant "
62- project.tasks.register(beforeTaskName, Task ::class .java)
63- deviceProviderTask.dependsOn(beforeTaskName)
64-
65- // After task runs post-processing via finalizedBy, which guarantees
66- // execution even when the test task fails (needed to pull snapshot
67- // results and generate reports on failure).
68- val afterTaskName = " androidSnaptestingAfter$capitalizedVariant "
69- project.tasks.register(afterTaskName, Task ::class .java) { task ->
70- task.doLast {
71- deviceProviderTask.afterExecution(
72- applicationId = applicationIdProvider.get(),
73- adbExecutablePath = adbExecutablePath,
74- providerFactory = providerFactory,
75- isRecordMode = isRecordMode,
76- goldenSnapshotsSourcePath = goldenSnapshotsSourcePath,
77- )
78- }
79- }
80- deviceProviderTask.finalizedBy(afterTaskName)
66+ // Shared provider — used by both before and after tasks (config-cache safe: references task by name)
67+ val deviceProviderFactoryProvider = project.tasks.named(taskName, DeviceProviderInstrumentTestTask ::class .java)
68+ .map { it.deviceProviderFactory }
69+
70+ // Before task clears snapshots and serves as dependency anchor for CI scripts.
71+ val beforeTaskName = " androidSnaptestingBefore$capitalizedVariant "
72+ project.tasks.register(beforeTaskName, Task ::class .java) { task ->
73+ task.doFirst {
74+ DeviceFileManager (deviceProviderFactoryProvider.get(), applicationIdProvider.get(), adbExecutablePath, providerFactory)
75+ .clearAllSnapshots()
76+ }
77+ }
78+ deviceProviderTask.dependsOn(beforeTaskName)
79+
80+ // After task runs post-processing via finalizedBy, which guarantees
81+ // execution even when the test task fails (needed to pull snapshot
82+ // results and generate reports on failure).
83+ val afterTaskName = " androidSnaptestingAfter$capitalizedVariant "
84+ val reportsDirProvider = project.tasks.named(taskName, DeviceProviderInstrumentTestTask ::class .java)
85+ .flatMap { it.reportsDir }
86+
87+ project.tasks.register(afterTaskName, Task ::class .java) { task ->
88+ task.doLast {
89+ afterExecution(
90+ deviceProviderFactory = deviceProviderFactoryProvider.get(),
91+ reportsDir = reportsDirProvider.get(),
92+ applicationId = applicationIdProvider.get(),
93+ adbExecutablePath = adbExecutablePath,
94+ providerFactory = providerFactory,
95+ isRecordMode = isRecordMode,
96+ goldenSnapshotsSourcePath = goldenSnapshotsSourcePath,
97+ )
8198 }
8299 }
100+ deviceProviderTask.finalizedBy(afterTaskName)
83101 }
84102
85- private fun DeviceProviderInstrumentTestTask.afterExecution (
103+ private fun afterExecution (
104+ deviceProviderFactory : DeviceProviderInstrumentTestTask .DeviceProviderFactory ,
105+ reportsDir : Directory ,
86106 applicationId : String ,
87107 adbExecutablePath : String ,
88108 providerFactory : ProviderFactory ,
89109 isRecordMode : Boolean ,
90110 goldenSnapshotsSourcePath : String ,
91111 ) {
92- val deviceFileManager = deviceFileManager( applicationId, adbExecutablePath, providerFactory)
112+ val deviceFileManager = DeviceFileManager (deviceProviderFactory, applicationId, adbExecutablePath, providerFactory)
93113
94- val reportsFolder = reportsDir.get(). dir(" androidSnaptesting" )
114+ val reportsFolder = reportsDir.dir(" androidSnaptesting" )
95115 val recordedFolderFile = reportsFolder.dir(" recorded" ).asFile.apply {
96116 mkdirs()
97117 deviceFileManager.pullRecordedSnapshots(absolutePath)
0 commit comments