
Commit f78b682

feat: Adds broadcast function
1 parent a20dc09 commit f78b682

File tree

5 files changed: +94 -2 lines changed
  • examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples
  • kotlin-spark-api
    • 2.4/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api
    • 3.0/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api

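At its core, the commit adds one reified broadcast extension on SparkContext, duplicated in the 2.4 and 3.0 API modules, together with a runnable example and a test per module. A sketch of the addition (copied from the diffs below) plus a hypothetical call site:

// Added in both ApiV1.kt files: the ClassTag Spark needs is derived from the
// library's Kotlin encoder, so callers no longer have to supply it themselves.
inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> =
    broadcast(value, encoder<T>().clsTag())

// Hypothetical call site (the new example file below is the runnable version):
// val bc = spark.sparkContext.broadcast(SomeClass(a = intArrayOf(5, 6), b = 3))
// bc.value  // read the cached copy on the executors inside map / mapPartitions
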
examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples (new example file)

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+package org.jetbrains.kotlinx.spark.examples
+
+import org.jetbrains.kotlinx.spark.api.broadcast
+import org.jetbrains.kotlinx.spark.api.map
+import org.jetbrains.kotlinx.spark.api.sparkContext
+import org.jetbrains.kotlinx.spark.api.withSpark
+import java.io.Serializable
+
+// (data) class must be Serializable to be broadcast
+data class SomeClass(val a: IntArray, val b: Int) : Serializable
+
+fun main() = withSpark {
+    val broadcastVariable = spark.sparkContext.broadcast(SomeClass(a = intArrayOf(5, 6), b = 3))
+    val result = listOf(1, 2, 3, 4, 5)
+        .toDS()
+        .map {
+            val receivedBroadcast = broadcastVariable.value
+            it + receivedBroadcast.a.first()
+        }
+        .collectAsList()
+
+    println(result)
+}
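Since the broadcast value's a.first() is 5, every dataset element is shifted by 5 and the example prints [6, 7, 8, 9, 10]. A hypothetical check that could be appended to main (not part of the commit):

// Each input element plus a.first() from the broadcast value:
check(result == listOf(6, 7, 8, 9, 10))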

kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 12 additions & 0 deletions
@@ -23,6 +23,7 @@ package org.jetbrains.kotlinx.spark.api

 import org.apache.spark.SparkContext
 import org.apache.spark.api.java.function.*
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.*
 import org.apache.spark.sql.Encoders.*
 import org.apache.spark.sql.catalyst.JavaTypeInference
@@ -63,6 +64,17 @@ val ENCODERS = mapOf<KClass<*>, Encoder<*>>(
     ByteArray::class to BINARY()
 )

+/**
+ * Broadcast a read-only variable to the cluster, returning a
+ * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
+ * The variable will be sent to each cluster only once.
+ *
+ * @param value value to broadcast to the Spark nodes
+ * @return `Broadcast` object, a read-only variable cached on each machine
+ */
+inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = broadcast(value, encoder<T>().clsTag())
+
+
 /**
  * Utility method to create dataset from list
  */
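Because the type parameter is reified, no ClassTag ever appears at the call site. A minimal usage sketch, assuming a withSpark block as in the example file and that the library's encoder can handle the broadcast type (a List<Int> here, mirroring the list broadcast in the tests):

withSpark {
    val cached = spark.sparkContext.broadcast(listOf(10, 20, 30))
    val shifted = listOf(1, 2, 3)
        .toDS()
        .map { it + cached.value.first() }  // executors read the cached broadcast copy
        .collectAsList()                    // [11, 12, 13]
    println(shifted)
}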

kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 22 additions & 0 deletions
@@ -21,6 +21,7 @@ import ch.tutteli.atrium.api.fluent.en_GB.*
 import ch.tutteli.atrium.domain.builders.migration.asExpect
 import ch.tutteli.atrium.verbs.expect
 import io.kotest.core.spec.style.ShouldSpec
+import java.io.Serializable
 import java.time.LocalDate

 class ApiTest : ShouldSpec({
@@ -136,7 +137,25 @@ class ApiTest : ShouldSpec({
                expect(result).asExpect().contains.inOrder.only.values(2, 3, 4, 5)

            }
+            @OptIn(ExperimentalStdlibApi::class)
+            should("broadcast variables") {
+                val largeList = (1..15).map { SomeClass(a = (it..15).toList().toIntArray(), b = it) }
+                val broadcast = spark.sparkContext.broadcast(largeList)
+
+                val result: List<Int> = listOf(1, 2, 3, 4, 5)
+                    .toDS()
+                    .mapPartitions { iterator ->
+                        val receivedBroadcast = broadcast.value
+                        buildList {
+                            iterator.forEach {
+                                this.add(it + receivedBroadcast[it].b)
+                            }
+                        }.iterator()
+                    }
+                    .collectAsList()

+                expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11)
+            }
        }
    }
 })
@@ -161,3 +180,6 @@ data class Test<Z>(val id: Long, val data: Array<Pair<Z, Int>>) {
        return result
    }
 }
+
+// (data) class must be Serializable to be broadcast
+data class SomeClass(val a: IntArray, val b: Int) : Serializable
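The @OptIn(ExperimentalStdlibApi::class) annotation is only needed for buildList, which was still experimental in the Kotlin version in use. A sketch of the same partition mapping without it, assuming the broadcast value and test scope above (not part of the commit):

val result: List<Int> = listOf(1, 2, 3, 4, 5)
    .toDS()
    .mapPartitions { iterator ->
        val receivedBroadcast = broadcast.value      // read once per partition
        iterator.asSequence()
            .map { it + receivedBroadcast[it].b }
            .iterator()
    }
    .collectAsList()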

kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 14 additions & 2 deletions
@@ -23,6 +23,7 @@ package org.jetbrains.kotlinx.spark.api

 import org.apache.spark.SparkContext
 import org.apache.spark.api.java.function.*
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.*
 import org.apache.spark.sql.Encoders.*
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -59,6 +60,17 @@ val ENCODERS = mapOf<KClass<*>, Encoder<*>>(
     ByteArray::class to BINARY()
 )

+
+/**
+ * Broadcast a read-only variable to the cluster, returning a
+ * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
+ * The variable will be sent to each cluster only once.
+ *
+ * @param value value to broadcast to the Spark nodes
+ * @return `Broadcast` object, a read-only variable cached on each machine
+ */
+inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = broadcast(value, encoder<T>().clsTag())
+
 /**
  * Utility method to create dataset from list
  */
@@ -271,7 +283,7 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
     }.toMap())
     return when {
         klass.isSubclassOf(Iterable::class) || klass.java.isArray -> {
-            val listParam = if(klass.java.isArray){
+            val listParam = if (klass.java.isArray) {
                 when (klass) {
                     IntArray::class -> typeOf<Int>()
                     LongArray::class -> typeOf<Long>()
@@ -282,7 +294,7 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
                     ByteArray::class -> typeOf<Byte>()
                     else -> types.getValue(klass.typeParameters[0].name)
                 }
-            }else types.getValue(klass.typeParameters[0].name)
+            } else types.getValue(klass.typeParameters[0].name)
             KComplexTypeWrapper(
                 DataTypes.createArrayType(schema(listParam, types), listParam.isMarkedNullable),
                 klass.java,
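Apart from the broadcast addition, the remaining 3.0 hunks only fix whitespace in schema(). For the extension itself, spelling the call out by hand shows what the reified type parameter saves; a sketch, assuming the library's public encoder<T>() helper, the SomeClass type from the example, and a SparkSession named spark are in scope:

import org.apache.spark.broadcast.Broadcast
import scala.reflect.ClassTag

// What the extension does under the hood: obtain a scala.reflect.ClassTag from the
// Kotlin-aware encoder and pass it to the underlying broadcast(value, classTag) method.
val tag: ClassTag<SomeClass> = encoder<SomeClass>().clsTag()
val byHand: Broadcast<SomeClass> =
    spark.sparkContext.broadcast(SomeClass(a = intArrayOf(5, 6), b = 3), tag)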

kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 23 additions & 0 deletions
@@ -21,6 +21,7 @@ import ch.tutteli.atrium.api.fluent.en_GB.*
 import ch.tutteli.atrium.domain.builders.migration.asExpect
 import ch.tutteli.atrium.verbs.expect
 import io.kotest.core.spec.style.ShouldSpec
+import java.io.Serializable
 import java.time.LocalDate

 class ApiTest : ShouldSpec({
@@ -149,8 +150,30 @@ class ApiTest : ShouldSpec({
                expect(result).asExpect().contains.inOrder.only.values(2, 3, 4, 5)

            }
+            @OptIn(ExperimentalStdlibApi::class)
+            should("broadcast variables") {
+                val largeList = (1..15).map { SomeClass(a = (it..15).toList().toIntArray(), b = it) }
+                val broadcast = spark.sparkContext.broadcast(largeList)
+
+                val result: List<Int> = listOf(1, 2, 3, 4, 5)
+                    .toDS()
+                    .mapPartitions { iterator ->
+                        val receivedBroadcast = broadcast.value
+                        buildList {
+                            iterator.forEach {
+                                this.add(it + receivedBroadcast[it].b)
+                            }
+                        }.iterator()
+                    }
+                    .collectAsList()
+
+                expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11)
+            }
        }
    }
 })

 data class LonLat(val lon: Double, val lat: Double)
+
+// (data) class must be Serializable to be broadcast
+data class SomeClass(val a: IntArray, val b: Int) : Serializable
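The expected values in both broadcast tests follow directly from the zero-indexed broadcast list: element it is paired with largeList[it].b, which equals it + 1.

// Worked check of the expected test output:
// it = 1 -> 1 + largeList[1].b = 1 + 2 = 3
// it = 2 -> 2 + 3 = 5
// it = 3 -> 3 + 4 = 7
// it = 4 -> 4 + 5 = 9
// it = 5 -> 5 + 6 = 11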
