Skip to content

Commit 697fd31

Browse files
committed
fix: Adds partial support for arrays
1 parent 1a13eb2 commit 697fd31

File tree

7 files changed

+217
-17
lines changed

7 files changed

+217
-17
lines changed

core/2.4/src/main/scala/org/apache/spark/sql/KotlinWrappers.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,21 @@ class KStructField(val getterName: String, val delegate: StructField) extends St
182182
override def getComment(): Option[String] = delegate.getComment()
183183

184184
override def toDDL: String = delegate.toDDL
185+
186+
override def productElement(n: Int): Any = delegate.productElement(n)
187+
188+
override def productArity: Int = delegate.productArity
189+
190+
override def productIterator: Iterator[Any] = delegate.productIterator
191+
192+
override def productPrefix: String = delegate.productPrefix
193+
194+
override def canEqual(that: Any): Boolean = delegate.canEqual(that)
195+
196+
override val dataType: DataType = delegate.dataType
197+
override val metadata: Metadata = delegate.metadata
198+
override val nullable: Boolean = delegate.nullable
199+
override val name: String = delegate.name
185200
}
186201

187202
object helpme {

core/2.4/src/main/scala/org/apache/spark/sql/catalyst/KotlinReflection.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ object KotlinReflection {
258258
returnNullable = false)
259259

260260

261-
case c if c.isArray && predefinedDt.isEmpty =>
261+
case c if c.isArray =>
262262
val elementType = c.getComponentType
263263
val primitiveMethod = elementType match {
264264
case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
@@ -271,16 +271,19 @@ object KotlinReflection {
271271
case _ => None
272272
}
273273

274+
val maybeType = predefinedDt.filter(_.dt.isInstanceOf[ArrayType]).map(_.dt.asInstanceOf[ArrayType].elementType)
274275
primitiveMethod.map { method =>
275276
Invoke(getPath, method, ObjectType(c))
276277
}.getOrElse {
277278
Invoke(
278279
MapObjects(
279-
p => deserializerFor(typeToken.getComponentType, Some(p)),
280+
p => {deserializerFor(typeToken.getComponentType, Some(p), maybeType.filter(_.isInstanceOf[ComplexWrapper]).map(_.asInstanceOf[ComplexWrapper]))},
280281
getPath,
281-
inferDataType(elementType)._1),
282+
maybeType.filter(_.isInstanceOf[ComplexWrapper]).map(_.asInstanceOf[ComplexWrapper].dt).getOrElse(inferDataType(elementType)._1)
283+
),
282284
"array",
283-
ObjectType(c))
285+
ObjectType(c)
286+
)
284287
}
285288

286289
case c if listType.isAssignableFrom(typeToken) && predefinedDt.isEmpty =>
@@ -343,7 +346,6 @@ object KotlinReflection {
343346
val fieldName = field.asInstanceOf[KStructField].delegate.name
344347
val newPath = addToPath(fieldName)
345348
deserializerFor(TypeToken.of(fieldCls), Some(newPath), Some(dataType).filter(_.isInstanceOf[ComplexWrapper]))
346-
// val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
347349

348350
})
349351
val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)

core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,14 @@ object KotlinReflection extends KotlinReflection {
281281
createDeserializerForScalaBigInt(path)
282282

283283
case t if isSubtype(t, localTypeOf[Array[_]]) =>
284-
val TypeRef(_, _, Seq(elementType)) = t
285-
val Schema(dataType, elementNullable) = schemaFor(elementType)
284+
var TypeRef(_, _, Seq(elementType)) = t
285+
if (predefinedDt.isDefined && !elementType.dealias.typeSymbol.isClass)
286+
elementType = getType(predefinedDt.get.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType.asInstanceOf[DataTypeWithClass].cls)
287+
val Schema(dataType, elementNullable) = predefinedDt.map(it => {
288+
val elementInfo = it.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType.asInstanceOf[DataTypeWithClass]
289+
Schema(elementInfo.dt, elementInfo.nullable)
290+
})
291+
.getOrElse(schemaFor(elementType))
286292
val className = getClassNameFromType(elementType)
287293
val newTypePath = walkedTypePath.recordArray(className)
288294

@@ -293,7 +299,7 @@ object KotlinReflection extends KotlinReflection {
293299
dataType,
294300
nullable = elementNullable,
295301
newTypePath,
296-
(casted, typePath) => deserializerFor(elementType, casted, typePath))
302+
(casted, typePath) => deserializerFor(elementType, casted, typePath, predefinedDt.map(_.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType).filter(_.isInstanceOf[ComplexWrapper]).map(_.asInstanceOf[ComplexWrapper])))
297303
}
298304

299305
val arrayData = UnresolvedMapObjects(mapFunction, path)
@@ -513,6 +519,12 @@ object KotlinReflection extends KotlinReflection {
513519

514520
def toCatalystArray(input: Expression, elementType: `Type`, predefinedDt: Option[DataTypeWithClass] = None): Expression = {
515521
predefinedDt.map(_.dt).getOrElse(dataTypeFor(elementType)) match {
522+
case dt:StructType =>
523+
val clsName = getClassNameFromType(elementType)
524+
val newPath = walkedTypePath.recordArray(clsName)
525+
createSerializerForMapObjects(input, ObjectType(predefinedDt.get.cls),
526+
serializerFor(_, elementType, newPath, seenTypeSet, predefinedDt))
527+
516528
case dt: ObjectType =>
517529
val clsName = getClassNameFromType(elementType)
518530
val newPath = walkedTypePath.recordArray(clsName)
@@ -528,8 +540,15 @@ object KotlinReflection extends KotlinReflection {
528540
createSerializerForGenericArray(input, dt, nullable = predefinedDt.map(_.nullable).getOrElse(schemaFor(elementType).nullable))
529541
}
530542

543+
case _: StringType =>
544+
val clsName = getClassNameFromType(typeOf[String])
545+
val newPath = walkedTypePath.recordArray(clsName)
546+
createSerializerForMapObjects(input, ObjectType(classOf[String]),
547+
serializerFor(_, elementType, newPath, seenTypeSet))
548+
549+
531550
case dt =>
532-
createSerializerForGenericArray(input, dt, nullable = schemaFor(elementType).nullable)
551+
createSerializerForGenericArray(input, dt, nullable = predefinedDt.map(_.nullable).getOrElse(schemaFor(elementType).nullable))
533552
}
534553
}
535554

@@ -552,7 +571,7 @@ object KotlinReflection extends KotlinReflection {
552571
val TypeRef(_, _, Seq(elementType)) = t
553572
toCatalystArray(inputObject, elementType)
554573

555-
case t if isSubtype(t, localTypeOf[Array[_]]) =>
574+
case t if isSubtype(t, localTypeOf[Array[_]]) && predefinedDt.isEmpty =>
556575
val TypeRef(_, _, Seq(elementType)) = t
557576
toCatalystArray(inputObject, elementType)
558577

kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions.Expression
3232
import org.apache.spark.sql.types.*
3333
import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
3434
import scala.collection.Seq
35-
import scala.reflect.ClassTag
3635
import scala.reflect.`ClassTag$`
3736
import java.beans.PropertyDescriptor
3837
import java.math.BigDecimal
@@ -106,6 +105,7 @@ fun <T> generateEncoder(type: KType, cls: KClass<*>): Encoder<T> {
106105
private fun isSupportedClass(cls: KClass<*>): Boolean = cls.isData
107106
|| cls.isSubclassOf(Map::class)
108107
|| cls.isSubclassOf(Iterable::class)
108+
|| cls.java.isArray
109109

110110
@Suppress("UNCHECKED_CAST")
111111
private fun <T> kotlinClassEncoder(schema: DataType, kClass: KClass<*>): Encoder<T> {
@@ -269,6 +269,7 @@ inline fun <reified R> Dataset<*>.toArray(): Array<R> = to<R>().collect() as Arr
269269
*/
270270
fun <T> Dataset<T>.showDS(numRows: Int = 20, truncate: Boolean = true) = apply { show(numRows, truncate) }
271271

272+
@OptIn(ExperimentalStdlibApi::class)
272273
fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
273274
val primitiveSchema = knownDataTypes[type.classifier]
274275
if (primitiveSchema != null) return KSimpleTypeWrapper(primitiveSchema, (type.classifier!! as KClass<*>).java, type.isMarkedNullable)
@@ -279,8 +280,19 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
279280
it.first.name to it.second.type!!
280281
}.toMap())
281282
return when {
282-
klass.isSubclassOf(Iterable::class) -> {
283-
val listParam = types.getValue(klass.typeParameters[0].name)
283+
klass.isSubclassOf(Iterable::class) || klass.java.isArray -> {
284+
val listParam = if(klass.java.isArray){
285+
when (klass) {
286+
IntArray::class -> typeOf<Int>()
287+
LongArray::class -> typeOf<Long>()
288+
FloatArray::class -> typeOf<Float>()
289+
DoubleArray::class -> typeOf<Double>()
290+
BooleanArray::class -> typeOf<Boolean>()
291+
ShortArray::class -> typeOf<Short>()
292+
ByteArray::class -> typeOf<Byte>()
293+
else -> types.getValue(klass.typeParameters[0].name)
294+
}
295+
}else types.getValue(klass.typeParameters[0].name)
284296
KComplexTypeWrapper(
285297
DataTypes.createArrayType(schema(listParam, types), listParam.isMarkedNullable),
286298
klass.java,

kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import java.time.LocalDate
2525

2626
class ApiTest : ShouldSpec({
2727
context("integration tests") {
28-
withSpark(master = "local", props = mapOf("spark.sql.codegen.comments" to true)) {
28+
withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
2929
should("collect data classes with doubles correctly") {
3030
val ll1 = LonLat(1.0, 2.0)
3131
val ll2 = LonLat(3.0, 4.0)
@@ -85,9 +85,79 @@ class ApiTest : ShouldSpec({
8585
.only
8686
.values(5, 6, 7, 8, 7, 8, 9)
8787
}
88+
should("handle strings converted to lists") {
89+
data class Movie(val id: Long, val genres: String)
90+
data class MovieExpanded(val id: Long, val genres: List<String>)
91+
92+
val comedies = listOf(Movie(1, "Comedy|Romance"), Movie(2, "Horror|Action")).toDS()
93+
.map { MovieExpanded(it.id, it.genres.split("|").toList()) }
94+
.filter { it.genres.contains("Comedy") }
95+
.collectAsList()
96+
expect(comedies).asExpect().contains.inAnyOrder.only.values(MovieExpanded(1, listOf("Comedy", "Romance")))
97+
}
98+
should("handle strings converted to arrays") {
99+
data class Movie(val id: Long, val genres: String)
100+
data class MovieExpanded(val id: Long, val genres: Array<String>) {
101+
override fun equals(other: Any?): Boolean {
102+
if (this === other) return true
103+
if (javaClass != other?.javaClass) return false
104+
other as MovieExpanded
105+
return if (id != other.id) false else genres.contentEquals(other.genres)
106+
}
107+
108+
override fun hashCode(): Int {
109+
var result = id.hashCode()
110+
result = 31 * result + genres.contentHashCode()
111+
return result
112+
}
113+
}
114+
115+
val comedies = listOf(Movie(1, "Comedy|Romance"), Movie(2, "Horror|Action")).toDS()
116+
.map { MovieExpanded(it.id, it.genres.split("|").toTypedArray()) }
117+
.filter { it.genres.contains("Comedy") }
118+
.collectAsList()
119+
expect(comedies).asExpect().contains.inAnyOrder.only.values(MovieExpanded(1, arrayOf("Comedy", "Romance")))
120+
}
121+
should("!handle arrays of generics") {
122+
123+
val result = listOf(Test(1, arrayOf(5.1 to 6, 6.1 to 7)))
124+
.toDS()
125+
.map { it.id to it.data.first { liEl -> liEl.first < 6 } }
126+
.map { it.second }
127+
.collectAsList()
128+
expect(result).asExpect().contains.inOrder.only.values(5.1 to 6)
129+
}
130+
should("handle primitive arrays") {
131+
val result = listOf(arrayOf(1, 2, 3, 4))
132+
.toDS()
133+
.map { it.map { ai -> ai + 1 } }
134+
.collectAsList()
135+
.flatten()
136+
expect(result).asExpect().contains.inOrder.only.values(2, 3, 4, 5)
137+
138+
}
88139

89140
}
90141
}
91142
})
92143

93144
data class LonLat(val lon: Double, val lat: Double)
145+
data class Test<Z>(val id: Long, val data: Array<Pair<Z, Int>>) {
146+
override fun equals(other: Any?): Boolean {
147+
if (this === other) return true
148+
if (javaClass != other?.javaClass) return false
149+
150+
other as Test<*>
151+
152+
if (id != other.id) return false
153+
if (!data.contentEquals(other.data)) return false
154+
155+
return true
156+
}
157+
158+
override fun hashCode(): Int {
159+
var result = id.hashCode()
160+
result = 31 * result + data.contentHashCode()
161+
return result
162+
}
163+
}

kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ fun <T> generateEncoder(type: KType, cls: KClass<*>): Encoder<T> {
100100
private fun isSupportedClass(cls: KClass<*>): Boolean = cls.isData
101101
|| cls.isSubclassOf(Map::class)
102102
|| cls.isSubclassOf(Iterable::class)
103+
|| cls.java.isArray
103104

104105
private fun <T> kotlinClassEncoder(schema: DataType, kClass: KClass<*>): Encoder<T> {
105106
return ExpressionEncoder(
@@ -258,6 +259,7 @@ inline fun <reified R> Dataset<*>.toArray(): Array<R> = to<R>().collect() as Arr
258259
*/
259260
fun <T> Dataset<T>.showDS(numRows: Int = 20, truncate: Boolean = true) = apply { show(numRows, truncate) }
260261

262+
@OptIn(ExperimentalStdlibApi::class)
261263
fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
262264
val primitiveSchema = knownDataTypes[type.classifier]
263265
if (primitiveSchema != null) return KSimpleTypeWrapper(primitiveSchema, (type.classifier!! as KClass<*>).java, type.isMarkedNullable)
@@ -268,8 +270,19 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
268270
it.first.name to it.second.type!!
269271
}.toMap())
270272
return when {
271-
klass.isSubclassOf(Iterable::class) -> {
272-
val listParam = types.getValue(klass.typeParameters[0].name)
273+
klass.isSubclassOf(Iterable::class) || klass.java.isArray -> {
274+
val listParam = if(klass.java.isArray){
275+
when (klass) {
276+
IntArray::class -> typeOf<Int>()
277+
LongArray::class -> typeOf<Long>()
278+
FloatArray::class -> typeOf<Float>()
279+
DoubleArray::class -> typeOf<Double>()
280+
BooleanArray::class -> typeOf<Boolean>()
281+
ShortArray::class -> typeOf<Short>()
282+
ByteArray::class -> typeOf<Byte>()
283+
else -> types.getValue(klass.typeParameters[0].name)
284+
}
285+
}else types.getValue(klass.typeParameters[0].name)
273286
KComplexTypeWrapper(
274287
DataTypes.createArrayType(schema(listParam, types), listParam.isMarkedNullable),
275288
klass.java,

kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import java.time.LocalDate
2525

2626
class ApiTest : ShouldSpec({
2727
context("integration tests") {
28-
withSpark {
28+
withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
2929
should("collect data classes with doubles correctly") {
3030
val ll1 = LonLat(1.0, 2.0)
3131
val ll2 = LonLat(3.0, 4.0)
@@ -79,7 +79,76 @@ class ApiTest : ShouldSpec({
7979
.collectAsList()
8080
expect(result).asExpect().contains.inAnyOrder.only.values(5, 6, 7, 8, 7, 8, 9)
8181
}
82+
should("handle strings converted to lists") {
83+
data class Movie(val id: Long, val genres: String)
84+
data class MovieExpanded(val id: Long, val genres: List<String>)
8285

86+
val comedies = listOf(Movie(1, "Comedy|Romance"), Movie(2, "Horror|Action")).toDS()
87+
.map { MovieExpanded(it.id, it.genres.split("|").toList()) }
88+
.filter { it.genres.contains("Comedy") }
89+
.collectAsList()
90+
expect(comedies).asExpect().contains.inAnyOrder.only.values(MovieExpanded(1, listOf("Comedy", "Romance")))
91+
}
92+
should("handle strings converted to arrays") {
93+
data class Movie(val id: Long, val genres: String)
94+
data class MovieExpanded(val id: Long, val genres: Array<String>) {
95+
override fun equals(other: Any?): Boolean {
96+
if (this === other) return true
97+
if (javaClass != other?.javaClass) return false
98+
other as MovieExpanded
99+
return if (id != other.id) false else genres.contentEquals(other.genres)
100+
}
101+
102+
override fun hashCode(): Int {
103+
var result = id.hashCode()
104+
result = 31 * result + genres.contentHashCode()
105+
return result
106+
}
107+
}
108+
109+
val comedies = listOf(Movie(1, "Comedy|Romance"), Movie(2, "Horror|Action")).toDS()
110+
.map { MovieExpanded(it.id, it.genres.split("|").toTypedArray()) }
111+
.filter { it.genres.contains("Comedy") }
112+
.collectAsList()
113+
expect(comedies).asExpect().contains.inAnyOrder.only.values(MovieExpanded(1, arrayOf("Comedy", "Romance")))
114+
}
115+
should("handle arrays of generics") {
116+
data class Test<Z>(val id: Long, val data: Array<Pair<Z, Int>>) {
117+
override fun equals(other: Any?): Boolean {
118+
if (this === other) return true
119+
if (javaClass != other?.javaClass) return false
120+
121+
other as Test<*>
122+
123+
if (id != other.id) return false
124+
if (!data.contentEquals(other.data)) return false
125+
126+
return true
127+
}
128+
129+
override fun hashCode(): Int {
130+
var result = id.hashCode()
131+
result = 31 * result + data.contentHashCode()
132+
return result
133+
}
134+
}
135+
136+
val result = listOf(Test(1, arrayOf(5.1 to 6, 6.1 to 7)))
137+
.toDS()
138+
.map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
139+
.map { it.second }
140+
.collectAsList()
141+
expect(result).asExpect().contains.inOrder.only.values(5.1 to 6)
142+
}
143+
should("!handle primitive arrays") {
144+
val result = listOf(arrayOf(1, 2, 3, 4))
145+
.toDS()
146+
.map { it.map { ai -> ai + 1 } }
147+
.collectAsList()
148+
.flatten()
149+
expect(result).asExpect().contains.inOrder.only.values(2, 3, 4, 5)
150+
151+
}
83152
}
84153
}
85154
})

0 commit comments

Comments
 (0)