Skip to content

Commit 70673ef

Browse files
committed
fix: Adds support for List<Map> and List<List> and related structures
Fixes #107
1 parent 686c60c commit 70673ef

File tree

4 files changed

+114
-42
lines changed
  • core/3.0/src/main/scala/org/apache/spark/sql
  • kotlin-spark-api
    • 2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api
    • 3.0/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api

4 files changed

+114
-42
lines changed

core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -592,7 +592,8 @@ object KotlinReflection extends KotlinReflection {
592592

593593
def toCatalystArray(input: Expression, elementType: `Type`, predefinedDt: Option[DataTypeWithClass] = None): Expression = {
594594
predefinedDt.map(_.dt).getOrElse(dataTypeFor(elementType)) match {
595-
case dt: StructType =>
595+
596+
case dt@(MapType(_, _, _) | ArrayType(_, _) | StructType(_)) =>
596597
val clsName = getClassNameFromType(elementType)
597598
val newPath = walkedTypePath.recordArray(clsName)
598599
createSerializerForMapObjects(input, ObjectType(predefinedDt.get.cls),

kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -517,6 +517,31 @@ class ApiTest : ShouldSpec({
517517
first.someOtherArray shouldBe arrayOf(SomeOtherEnum.C, SomeOtherEnum.D)
518518
first.enumMap shouldBe mapOf(SomeEnum.A to SomeOtherEnum.C)
519519
}
520+
should("work with lists of maps") {
521+
val result = dsOf(
522+
listOf(mapOf("a" to "b", "x" to "y")),
523+
listOf(mapOf("a" to "b", "x" to "y")),
524+
listOf(mapOf("a" to "b", "x" to "y"))
525+
)
526+
.showDS()
527+
.map { it.last() }
528+
.map { it["x"] }
529+
.filterNotNull()
530+
.distinct()
531+
.collectAsList()
532+
expect(result).contains.inOrder.only.value("y")
533+
}
534+
should("work with lists of lists") {
535+
val result = dsOf(
536+
listOf(listOf(1, 2, 3)),
537+
listOf(listOf(1, 2, 3)),
538+
listOf(listOf(1, 2, 3))
539+
)
540+
.map { it.last() }
541+
.map { it.first() }
542+
.reduceK { a, b -> a + b }
543+
expect(result).toBe(3)
544+
}
520545
}
521546
}
522547
})

kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 34 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@ import kotlin.invoke
7777
import kotlin.reflect.*
7878
import kotlin.reflect.full.findAnnotation
7979
import kotlin.reflect.full.isSubclassOf
80+
import kotlin.reflect.full.isSubtypeOf
8081
import kotlin.reflect.full.primaryConstructor
8182
import kotlin.to
8283

@@ -122,9 +123,11 @@ inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
122123
* @return `Broadcast` object, a read-only variable cached on each machine
123124
* @see broadcast
124125
*/
125-
@Deprecated("You can now use `spark.broadcast()` instead.",
126+
@Deprecated(
127+
"You can now use `spark.broadcast()` instead.",
126128
ReplaceWith("spark.broadcast(value)"),
127-
DeprecationLevel.WARNING)
129+
DeprecationLevel.WARNING
130+
)
128131
inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = try {
129132
broadcast(value, encoder<T>().clsTag())
130133
} catch (e: ClassNotFoundException) {
@@ -177,10 +180,14 @@ private fun isSupportedClass(cls: KClass<*>): Boolean = cls.isData
177180

178181
private fun <T> kotlinClassEncoder(schema: DataType, kClass: KClass<*>): Encoder<T> {
179182
return ExpressionEncoder(
180-
if (schema is DataTypeWithClass) KotlinReflection.serializerFor(kClass.java,
181-
schema) else KotlinReflection.serializerForType(KotlinReflection.getType(kClass.java)),
182-
if (schema is DataTypeWithClass) KotlinReflection.deserializerFor(kClass.java,
183-
schema) else KotlinReflection.deserializerForType(KotlinReflection.getType(kClass.java)),
183+
if (schema is DataTypeWithClass) KotlinReflection.serializerFor(
184+
kClass.java,
185+
schema
186+
) else KotlinReflection.serializerForType(KotlinReflection.getType(kClass.java)),
187+
if (schema is DataTypeWithClass) KotlinReflection.deserializerFor(
188+
kClass.java,
189+
schema
190+
) else KotlinReflection.deserializerForType(KotlinReflection.getType(kClass.java)),
184191
ClassTag.apply(kClass.java)
185192
)
186193
}
@@ -200,7 +207,8 @@ inline fun <T, reified R> Dataset<T>.groupByKey(noinline func: (T) -> R): KeyVal
200207
inline fun <T, reified R> Dataset<T>.mapPartitions(noinline func: (Iterator<T>) -> Iterator<R>): Dataset<R> =
201208
mapPartitions(func, encoder<R>())
202209

203-
fun <T> Dataset<T>.filterNotNull() = filter { it != null }
210+
@Suppress("UNCHECKED_CAST")
211+
fun <T : Any> Dataset<T?>.filterNotNull(): Dataset<T> = filter { it != null } as Dataset<T>
204212

205213
inline fun <KEY, VALUE, reified R> KeyValueGroupedDataset<KEY, VALUE>.mapValues(noinline func: (VALUE) -> R): KeyValueGroupedDataset<KEY, R> =
206214
mapValues(MapFunction(func), encoder<R>())
@@ -848,9 +856,11 @@ inline fun <reified T, reified U1, reified U2, reified U3, reified U4, reified U
848856
@OptIn(ExperimentalStdlibApi::class)
849857
fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
850858
val primitiveSchema = knownDataTypes[type.classifier]
851-
if (primitiveSchema != null) return KSimpleTypeWrapper(primitiveSchema,
859+
if (primitiveSchema != null) return KSimpleTypeWrapper(
860+
primitiveSchema,
852861
(type.classifier!! as KClass<*>).java,
853-
type.isMarkedNullable)
862+
type.isMarkedNullable
863+
)
854864
val klass = type.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $type")
855865
val args = type.arguments
856866

@@ -901,15 +911,21 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
901911
.filter { it.findAnnotation<Transient>() == null }
902912
.map {
903913
val projectedType = types[it.type.toString()] ?: it.type
904-
val propertyDescriptor = PropertyDescriptor(it.name,
914+
val propertyDescriptor = PropertyDescriptor(
915+
it.name,
905916
klass.java,
906917
"is" + it.name?.replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.getDefault()) else it.toString() },
907-
null)
908-
KStructField(propertyDescriptor.readMethod.name,
909-
StructField(it.name,
918+
null
919+
)
920+
KStructField(
921+
propertyDescriptor.readMethod.name,
922+
StructField(
923+
it.name,
910924
schema(projectedType, types),
911925
projectedType.isMarkedNullable,
912-
Metadata.empty()))
926+
Metadata.empty()
927+
)
928+
)
913929
}
914930
.toTypedArray()
915931
)
@@ -923,8 +939,10 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
923939
val structType = DataTypes.createStructType(
924940
params.map { (fieldName, fieldType) ->
925941
val dataType = schema(fieldType, types)
926-
KStructField(fieldName,
927-
StructField(fieldName, dataType, fieldType.isMarkedNullable, Metadata.empty()))
942+
KStructField(
943+
fieldName,
944+
StructField(fieldName, dataType, fieldType.isMarkedNullable, Metadata.empty())
945+
)
928946
}.toTypedArray()
929947
)
930948

kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 53 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -66,11 +66,13 @@ class ApiTest : ShouldSpec({
6666
val result = dsOf(1, 2, 3, 4, 5)
6767
.map { it to (it + 2) }
6868
.withCached {
69-
expect(collectAsList()).contains.inAnyOrder.only.values(1 to 3,
69+
expect(collectAsList()).contains.inAnyOrder.only.values(
70+
1 to 3,
7071
2 to 4,
7172
3 to 5,
7273
4 to 6,
73-
5 to 7)
74+
5 to 7
75+
)
7476

7577
val next = filter { it.first % 2 == 0 }
7678
expect(next.collectAsList()).contains.inAnyOrder.only.values(2 to 4, 4 to 6)
@@ -109,8 +111,12 @@ class ApiTest : ShouldSpec({
109111
.map { MovieExpanded(it.id, it.genres.split("|").toList()) }
110112
.filter { it.genres.contains("Comedy") }
111113
.collectAsList()
112-
expect(comedies).contains.inAnyOrder.only.values(MovieExpanded(1,
113-
listOf("Comedy", "Romance")))
114+
expect(comedies).contains.inAnyOrder.only.values(
115+
MovieExpanded(
116+
1,
117+
listOf("Comedy", "Romance")
118+
)
119+
)
114120
}
115121
should("handle strings converted to arrays") {
116122
data class Movie(val id: Long, val genres: String)
@@ -133,29 +139,15 @@ class ApiTest : ShouldSpec({
133139
.map { MovieExpanded(it.id, it.genres.split("|").toTypedArray()) }
134140
.filter { it.genres.contains("Comedy") }
135141
.collectAsList()
136-
expect(comedies).contains.inAnyOrder.only.values(MovieExpanded(1,
137-
arrayOf("Comedy", "Romance")))
142+
expect(comedies).contains.inAnyOrder.only.values(
143+
MovieExpanded(
144+
1,
145+
arrayOf("Comedy", "Romance")
146+
)
147+
)
138148
}
139149
should("handle arrays of generics") {
140-
data class Test<Z>(val id: Long, val data: Array<Pair<Z, Int>>) {
141-
override fun equals(other: Any?): Boolean {
142-
if (this === other) return true
143-
if (javaClass != other?.javaClass) return false
144-
145-
other as Test<*>
146-
147-
if (id != other.id) return false
148-
if (!data.contentEquals(other.data)) return false
149-
150-
return true
151-
}
152-
153-
override fun hashCode(): Int {
154-
var result = id.hashCode()
155-
result = 31 * result + data.contentHashCode()
156-
return result
157-
}
158-
}
150+
data class Test<Z>(val id: Long, val data: Array<Pair<Z, Int>>)
159151

160152
val result = listOf(Test(1, arrayOf(5.1 to 6, 6.1 to 7)))
161153
.toDS()
@@ -164,6 +156,16 @@ class ApiTest : ShouldSpec({
164156
.collectAsList()
165157
expect(result).contains.inOrder.only.values(5.1 to 6)
166158
}
159+
should("handle lists of generics") {
160+
data class Test<Z>(val id: Long, val data: List<Pair<Z, Int>>)
161+
162+
val result = listOf(Test(1, listOf(5.1 to 6, 6.1 to 7)))
163+
.toDS()
164+
.map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
165+
.map { it.second }
166+
.collectAsList()
167+
expect(result).contains.inOrder.only.values(5.1 to 6)
168+
}
167169
should("!handle primitive arrays") {
168170
val result = listOf(arrayOf(1, 2, 3, 4))
169171
.toDS()
@@ -558,6 +560,32 @@ class ApiTest : ShouldSpec({
558560
first.someOtherArray shouldBe arrayOf(SomeOtherEnum.C, SomeOtherEnum.D)
559561
first.enumMap shouldBe mapOf(SomeEnum.A to SomeOtherEnum.C)
560562
}
563+
should("work with lists of maps") {
564+
val result = dsOf(
565+
listOf(mapOf("a" to "b", "x" to "y")),
566+
listOf(mapOf("a" to "b", "x" to "y")),
567+
listOf(mapOf("a" to "b", "x" to "y"))
568+
)
569+
.showDS()
570+
.map { it.last() }
571+
.map { it["x"] }
572+
.filterNotNull()
573+
.distinct()
574+
.collectAsList()
575+
expect(result).contains.inOrder.only.value("y")
576+
}
577+
should("work with lists of lists") {
578+
val result = dsOf(
579+
listOf(listOf(1, 2, 3)),
580+
listOf(listOf(1, 2, 3)),
581+
listOf(listOf(1, 2, 3))
582+
)
583+
.map { it.last() }
584+
.map { it.first() }
585+
.reduceK { a, b -> a + b }
586+
expect(result).toBe(3)
587+
}
588+
561589
}
562590
}
563591
})

0 commit comments

Comments (0)