diff --git a/wire-golden-files/src/main/kotlin/squareup/wire/sealedoneof/SealedOneOfs.kt b/wire-golden-files/src/main/kotlin/squareup/wire/sealedoneof/SealedOneOfs.kt index 1cc7199656..f5238e5f7a 100644 --- a/wire-golden-files/src/main/kotlin/squareup/wire/sealedoneof/SealedOneOfs.kt +++ b/wire-golden-files/src/main/kotlin/squareup/wire/sealedoneof/SealedOneOfs.kt @@ -129,11 +129,11 @@ public class SealedOneOfs private constructor( public companion object { @JvmField public val ADAPTER: ProtoAdapter = object : ProtoAdapter( - FieldEncoding.LENGTH_DELIMITED, - SealedOneOfs::class, - "type.googleapis.com/squareup.wire.sealedoneof.SealedOneOfs", - PROTO_2, - null, + FieldEncoding.LENGTH_DELIMITED, + SealedOneOfs::class, + "type.googleapis.com/squareup.wire.sealedoneof.SealedOneOfs", + PROTO_2, + null, "squareup/wire/sealed_oneof.proto" ) { override fun encodedSize(`value`: SealedOneOfs): Int { @@ -195,3 +195,12 @@ public class SealedOneOfs private constructor( public inline fun build(body: Builder.() -> Unit): SealedOneOfs = Builder().apply(body).build() } } + +public val SealedOneOfs.Value.first_value: SealedMessage? + get() = (this as? SealedOneOfs.Value.FirstValue)?.value + +public val SealedOneOfs.Value.second_value: String? + get() = (this as? SealedOneOfs.Value.SecondValue)?.value + +public val SealedOneOfs.Value.value_: Int? + get() = (this as? SealedOneOfs.Value.Value)?.value diff --git a/wire-kotlin-generator/api/wire-kotlin-generator.api b/wire-kotlin-generator/api/wire-kotlin-generator.api index 8a4e86b6a7..06e41aeb4f 100644 --- a/wire-kotlin-generator/api/wire-kotlin-generator.api +++ b/wire-kotlin-generator/api/wire-kotlin-generator.api @@ -10,6 +10,7 @@ public final class com/squareup/wire/kotlin/KotlinGenerator { public static final field Companion Lcom/squareup/wire/kotlin/KotlinGenerator$Companion; public synthetic fun (Lcom/squareup/wire/schema/Schema;Ljava/util/Map;Ljava/util/Map;Lcom/squareup/wire/schema/Profile;ZZZZLcom/squareup/wire/kotlin/RpcCallStyle;Lcom/squareup/wire/kotlin/RpcRole;ILjava/lang/String;ZZLcom/squareup/wire/kotlin/EnumMode;Lcom/squareup/wire/kotlin/OneofMode;ZZZZLkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun generateOptionType (Lcom/squareup/wire/schema/Extend;Lcom/squareup/wire/schema/Field;)Lcom/squareup/kotlinpoet/TypeSpec; + public final fun generateSealedOneOfAccessors (Lcom/squareup/wire/schema/Type;)Ljava/util/List; public final fun generateServiceTypeSpecs (Lcom/squareup/wire/schema/Service;Lcom/squareup/wire/schema/Rpc;)Ljava/util/Map; public static synthetic fun generateServiceTypeSpecs$default (Lcom/squareup/wire/kotlin/KotlinGenerator;Lcom/squareup/wire/schema/Service;Lcom/squareup/wire/schema/Rpc;ILjava/lang/Object;)Ljava/util/Map; public final fun generateType (Lcom/squareup/wire/schema/Type;)Lcom/squareup/kotlinpoet/TypeSpec; diff --git a/wire-kotlin-generator/src/main/java/com/squareup/wire/kotlin/KotlinGenerator.kt b/wire-kotlin-generator/src/main/java/com/squareup/wire/kotlin/KotlinGenerator.kt index fc786c3d2b..ca09842914 100644 --- a/wire-kotlin-generator/src/main/java/com/squareup/wire/kotlin/KotlinGenerator.kt +++ b/wire-kotlin-generator/src/main/java/com/squareup/wire/kotlin/KotlinGenerator.kt @@ -338,6 +338,16 @@ class KotlinGenerator private constructor( } } + fun generateSealedOneOfAccessors(type: Type): List { + if (oneofMode != OneofMode.SEALED_CLASS) return listOf() + + return when (type) { + is MessageType -> type.sealedOneOfAccessors() + is EnclosingType -> type.nestedTypes.flatMap(::generateSealedOneOfAccessors) + else -> listOf() + } + } + /** * Generates all [TypeSpec]s for the given [Service]. * @@ -3063,6 +3073,53 @@ class KotlinGenerator private constructor( else -> boxedOneOfClassName(oneOf.name) } + /** + * Generates top-level extension properties that read values from sealed oneof instances. + * + * Example: + * ``` + * public val PaymentMethodChoice.Method.card_id: String? + * get() = (this as? PaymentMethodChoice.Method.CardId)?.value + * + * public val PaymentMethodChoice.Method.bank_account: BankAccount? + * get() = (this as? PaymentMethodChoice.Method.BankAccount)?.value + * ``` + */ + private fun MessageType.sealedOneOfAccessors(): List { + val accessors = mutableListOf() + val className = typeName as ClassName + val nameAllocator = nameAllocator(this) + + for (oneOf in sealedOneOfs()) { + val sealedClassName = className.nestedClass(nameAllocator[boxedOneOfClassName(oneOf.name)]) + val subclassNameAllocator = sealedSubclassNameAllocator(oneOf) + val accessorNameAllocator = NameAllocator(preallocateKeywords = !escapeKotlinKeywords) + + for (field in oneOf.fields) { + val fieldName = if (field.name == field.type!!.simpleName || hasEponymousType(schema, field)) { + accessorNameAllocator.newName(legacyQualifiedFieldName(field), field) + } else { + accessorNameAllocator.newName(field.name, field) + } + val subclassName = sealedClassName.nestedClass(subclassNameAllocator[field]) + accessors += PropertySpec.builder(fieldName, field.type!!.typeName.copy(nullable = true)) + .receiver(sealedClassName) + .getter( + FunSpec.getterBuilder() + .addStatement("return (this as? %T)?.value", subclassName) + .build(), + ) + .build() + } + } + + for (nestedType in nestedTypes) { + accessors += generateSealedOneOfAccessors(nestedType) + } + + return accessors + } + /** * Generates a sealed class for a oneof. * @@ -3314,18 +3371,18 @@ class KotlinGenerator private constructor( private fun MessageType.flatOneOfs(): List = when (oneofMode) { OneofMode.FLAT -> oneOfs.filter { it.fields.size < boxOneOfsMinSize } - else -> emptyList() + else -> listOf() } private fun MessageType.boxOneOfs(): List = when (oneofMode) { OneofMode.FLAT -> oneOfs.filter { it.fields.size >= boxOneOfsMinSize } OneofMode.BOXED -> oneOfs - OneofMode.SEALED_CLASS -> emptyList() + OneofMode.SEALED_CLASS -> listOf() } private fun MessageType.sealedOneOfs(): List = when (oneofMode) { OneofMode.SEALED_CLASS -> oneOfs - else -> emptyList() + else -> listOf() } private fun sealedSubclassNameAllocator(oneOf: OneOf): NameAllocator = sealedSubclassNameAllocatorStore.getOrPut(oneOf) { @@ -3545,7 +3602,7 @@ class KotlinGenerator private constructor( AnnotationTarget.FIELD, ) METHOD_OPTIONS -> listOf(AnnotationTarget.FUNCTION) - else -> emptyList() + else -> listOf() } internal fun String.sanitizeKdoc(): String = this diff --git a/wire-kotlin-generator/src/main/java/com/squareup/wire/kotlin/KotlinSchemaHandler.kt b/wire-kotlin-generator/src/main/java/com/squareup/wire/kotlin/KotlinSchemaHandler.kt index c4fb7c40df..eb841f4682 100644 --- a/wire-kotlin-generator/src/main/java/com/squareup/wire/kotlin/KotlinSchemaHandler.kt +++ b/wire-kotlin-generator/src/main/java/com/squareup/wire/kotlin/KotlinSchemaHandler.kt @@ -19,6 +19,7 @@ import com.squareup.kotlinpoet.AnnotationSpec import com.squareup.kotlinpoet.AnnotationSpec.UseSiteTarget.FILE import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.FileSpec +import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec import com.squareup.wire.schema.Extend import com.squareup.wire.schema.Field @@ -143,12 +144,20 @@ class KotlinSchemaHandler( if (KotlinGenerator.builtInType(type.type)) return null val typeSpec = kotlinGenerator.generateType(type) + val topLevelProperties = kotlinGenerator.generateSealedOneOfAccessors(type) val className = kotlinGenerator.generatedTypeName(type) - return write(className, typeSpec, type.type, type.location, context) + return write( + className, + typeSpec, + type.type, + type.location, + context, + topLevelProperties = topLevelProperties, + ) } override fun handle(service: Service, context: Context): List { - if (rpcRole === RpcRole.NONE) return emptyList() + if (rpcRole === RpcRole.NONE) return listOf() val generatedPaths = mutableListOf() @@ -183,6 +192,7 @@ class KotlinSchemaHandler( source: Any, location: Location, context: Context, + topLevelProperties: List = listOf(), ): Path { val modulePath = context.outDirectory val kotlinFile = FileSpec.builder(name.packageName, name.simpleName) @@ -198,6 +208,11 @@ class KotlinSchemaHandler( .build(), ) .addType(typeSpec) + .apply { + for (propertySpec in topLevelProperties) { + addProperty(propertySpec) + } + } .build() val filePath = modulePath / kotlinFile.packageName.replace(".", "/") / @@ -205,7 +220,7 @@ class KotlinSchemaHandler( context.logger.artifactHandled( modulePath, - "${kotlinFile.packageName}.${(kotlinFile.members.first() as TypeSpec).name}", + "${kotlinFile.packageName}.${name.simpleName}", "Kotlin", ) try { diff --git a/wire-kotlin-generator/src/test/java/com/squareup/wire/kotlin/KotlinGeneratorTest.kt b/wire-kotlin-generator/src/test/java/com/squareup/wire/kotlin/KotlinGeneratorTest.kt index 906b1c994d..31ca2a0e47 100644 --- a/wire-kotlin-generator/src/test/java/com/squareup/wire/kotlin/KotlinGeneratorTest.kt +++ b/wire-kotlin-generator/src/test/java/com/squareup/wire/kotlin/KotlinGeneratorTest.kt @@ -2246,6 +2246,34 @@ class KotlinGeneratorTest { ) } + @Test fun sealedOneofGeneratesFieldAccessors() { + val schema = buildSchema { + add( + "message.proto".toPath(), + """ + |syntax = "proto2"; + |message PaymentMethodChoice { + | oneof method { + | string card_id = 1; + | int32 value = 2; + | } + |} + """.trimMargin(), + ) + } + val code = KotlinWithProfilesGenerator(schema) + .generateKotlin("PaymentMethodChoice", oneofMode = OneofMode.SEALED_CLASS) + assertThat(code).contains( + """ + |public val PaymentMethodChoice.Method.card_id: String? + | get() = (this as? PaymentMethodChoice.Method.CardId)?.value + | + |public val PaymentMethodChoice.Method.value_: Int? + | get() = (this as? PaymentMethodChoice.Method.Value)?.value + """.trimMargin(), + ) + } + @Test fun sealedOneofGeneratesCamelCaseClassName() { val schema = buildSchema { add( diff --git a/wire-kotlin-generator/src/test/java/com/squareup/wire/kotlin/KotlinWithProfilesGenerator.kt b/wire-kotlin-generator/src/test/java/com/squareup/wire/kotlin/KotlinWithProfilesGenerator.kt index 6f67d9db71..171d8d8a88 100644 --- a/wire-kotlin-generator/src/test/java/com/squareup/wire/kotlin/KotlinWithProfilesGenerator.kt +++ b/wire-kotlin-generator/src/test/java/com/squareup/wire/kotlin/KotlinWithProfilesGenerator.kt @@ -70,6 +70,11 @@ internal class KotlinWithProfilesGenerator(private val schema: Schema) { val typeSpec = kotlinGenerator.generateType(type) val packageName = kotlinGenerator.generatedTypeName(type).packageName val fileSpec = FileSpec.builder(packageName, "_") + .apply { + for (propertySpec in kotlinGenerator.generateSealedOneOfAccessors(type)) { + addProperty(propertySpec) + } + } .addType(typeSpec) .addImport("com.squareup.wire.kotlin", "decodeMessage") .build()