Skip to content

Commit 3f9dcb4

Browse files
committed
KTNB-1205: Extract WidgetManager interface
1 parent 5b2e3f9 commit 3f9dcb4

File tree

3 files changed

+198
-184
lines changed

3 files changed

+198
-184
lines changed

integrations/widgets/widgets-api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/widget/WidgetJupyterIntegration.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class WidgetJupyterIntegration : JupyterIntegration() {
1414
importPackage<IntSliderWidget>()
1515

1616
var myLastClassLoader = WidgetJupyterIntegration::class.java.classLoader
17-
val widgetManager = WidgetManager(notebook.commManager) { myLastClassLoader }
17+
val widgetManager = WidgetManagerImpl(notebook.commManager) { myLastClassLoader }
1818
myWidgetManager = widgetManager
1919

2020
onLoaded {
Lines changed: 7 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -1,190 +1,14 @@
11
package org.jetbrains.kotlinx.jupyter.widget
22

3-
import kotlinx.serialization.json.Json
4-
import kotlinx.serialization.json.buildJsonObject
5-
import kotlinx.serialization.json.decodeFromJsonElement
6-
import kotlinx.serialization.json.encodeToJsonElement
7-
import kotlinx.serialization.json.jsonObject
8-
import kotlinx.serialization.json.jsonPrimitive
9-
import kotlinx.serialization.json.put
103
import org.jetbrains.kotlinx.jupyter.api.DisplayResult
11-
import org.jetbrains.kotlinx.jupyter.api.MimeTypedResultEx
12-
import org.jetbrains.kotlinx.jupyter.api.MimeTypes
13-
import org.jetbrains.kotlinx.jupyter.api.libraries.Comm
14-
import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
15-
import org.jetbrains.kotlinx.jupyter.widget.model.DEFAULT_MAJOR_VERSION
16-
import org.jetbrains.kotlinx.jupyter.widget.model.DEFAULT_MINOR_VERSION
17-
import org.jetbrains.kotlinx.jupyter.widget.model.DefaultWidgetModel
184
import org.jetbrains.kotlinx.jupyter.widget.model.WidgetFactoryRegistry
195
import org.jetbrains.kotlinx.jupyter.widget.model.WidgetModel
20-
import org.jetbrains.kotlinx.jupyter.widget.model.versionConstraintRegex
21-
import org.jetbrains.kotlinx.jupyter.widget.protocol.CustomMessage
22-
import org.jetbrains.kotlinx.jupyter.widget.protocol.RequestStateMessage
23-
import org.jetbrains.kotlinx.jupyter.widget.protocol.RequestStatesMessage
24-
import org.jetbrains.kotlinx.jupyter.widget.protocol.UpdateStatesMessage
25-
import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetMessage
26-
import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetOpenMessage
27-
import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetStateMessage
28-
import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetUpdateMessage
29-
import org.jetbrains.kotlinx.jupyter.widget.protocol.getWireMessage
30-
import org.jetbrains.kotlinx.jupyter.widget.protocol.toPatch
316

32-
private val widgetOpenMetadataJson =
33-
buildJsonObject {
34-
put("version", "${DEFAULT_MAJOR_VERSION}.${DEFAULT_MINOR_VERSION}")
35-
}
7+
public interface WidgetManager {
8+
public val factoryRegistry: WidgetFactoryRegistry
369

37-
public class WidgetManager(
38-
private val commManager: CommManager,
39-
private val classLoaderProvider: () -> ClassLoader,
40-
) {
41-
private val widgetTarget = "jupyter.widget"
42-
private val widgetControlTarget = "jupyter.widget.control"
43-
private val widgets = mutableMapOf<String, WidgetModel>()
44-
private val widgetIdByWidget = mutableMapOf<WidgetModel, String>()
45-
46-
public val factoryRegistry: WidgetFactoryRegistry = WidgetFactoryRegistry()
47-
48-
init {
49-
commManager.registerCommTarget(widgetControlTarget) { comm, _, _, _ ->
50-
comm.onMessage { msg, _, _ ->
51-
when (Json.decodeFromJsonElement<WidgetMessage>(msg)) {
52-
is RequestStatesMessage -> {
53-
val fullStates =
54-
widgets.mapValues { (id, widget) ->
55-
widget.getFullState()
56-
}
57-
58-
val wireMessage = getWireMessage(fullStates)
59-
val message = UpdateStatesMessage(wireMessage.state, wireMessage.bufferPaths)
60-
61-
val data = Json.encodeToJsonElement<WidgetMessage>(message).jsonObject
62-
comm.send(data, null, emptyList())
63-
}
64-
65-
else -> {}
66-
}
67-
}
68-
}
69-
70-
commManager.registerCommTarget(widgetTarget) { comm, data, _, buffers ->
71-
val openMessage = Json.decodeFromJsonElement<WidgetOpenMessage>(data)
72-
val modelName = openMessage.state["_model_name"]?.jsonPrimitive?.content!!
73-
val widgetFactory = factoryRegistry.loadWidgetFactory(modelName, classLoaderProvider())
74-
75-
val widget = widgetFactory.create(this)
76-
val patch = openMessage.toPatch(buffers)
77-
widget.applyPatch(patch)
78-
79-
initializeWidget(comm, widget)
80-
}
81-
}
82-
83-
public fun getWidget(modelId: String): WidgetModel? = widgets[modelId]
84-
85-
public fun getWidgetId(widget: WidgetModel): String? = widgetIdByWidget[widget]
86-
87-
public fun registerWidget(widget: WidgetModel) {
88-
if (widgetIdByWidget[widget] != null) return
89-
90-
val fullState = widget.getFullState()
91-
val wireMessage = getWireMessage(fullState)
92-
93-
val comm =
94-
commManager.openComm(
95-
widgetTarget,
96-
Json
97-
.encodeToJsonElement(
98-
WidgetOpenMessage(
99-
wireMessage.state,
100-
wireMessage.bufferPaths,
101-
),
102-
).jsonObject,
103-
widgetOpenMetadataJson,
104-
wireMessage.buffers,
105-
)
106-
107-
initializeWidget(comm, widget)
108-
}
109-
110-
public fun renderWidget(widget: WidgetModel): DisplayResult =
111-
MimeTypedResultEx(
112-
buildJsonObject {
113-
val modelId = widgetIdByWidget[widget] ?: error("Widget is not registered")
114-
var versionMajor = DEFAULT_MAJOR_VERSION
115-
var versionMinor = DEFAULT_MINOR_VERSION
116-
var modelName: String? = null
117-
if (widget is DefaultWidgetModel) {
118-
modelName = widget.modelName
119-
val version = widget.modelModuleVersion
120-
val matchResult = versionConstraintRegex.find(version)
121-
if (matchResult != null) {
122-
versionMajor = matchResult.groupValues[1].toInt()
123-
versionMinor = matchResult.groupValues[2].toInt()
124-
}
125-
}
126-
if (modelName != null) {
127-
put(MimeTypes.HTML, "$modelName(id=$modelId)")
128-
}
129-
put(
130-
"application/vnd.jupyter.widget-view+json",
131-
buildJsonObject {
132-
put("version_major", versionMajor)
133-
put("version_minor", versionMinor)
134-
put("model_id", modelId)
135-
},
136-
)
137-
},
138-
null,
139-
)
140-
141-
private fun initializeWidget(
142-
comm: Comm,
143-
widget: WidgetModel,
144-
) {
145-
val modelId = comm.id
146-
widgetIdByWidget[widget] = modelId
147-
widgets[modelId] = widget
148-
149-
// Reflect kernel-side changes on the frontend
150-
widget.addChangeListener { patch ->
151-
val wireMessage = getWireMessage(patch)
152-
val data =
153-
Json
154-
.encodeToJsonElement<WidgetMessage>(
155-
WidgetUpdateMessage(
156-
wireMessage.state,
157-
wireMessage.bufferPaths,
158-
),
159-
).jsonObject
160-
comm.send(data, null, wireMessage.buffers)
161-
}
162-
163-
// Reflect frontend-side changes on kernel
164-
comm.onMessage { msg, _, buffers ->
165-
when (val message = Json.decodeFromJsonElement<WidgetMessage>(msg)) {
166-
is WidgetStateMessage -> {
167-
widget.applyPatch(message.toPatch(buffers))
168-
}
169-
170-
is RequestStateMessage -> {
171-
val fullState = widget.getFullState()
172-
val wireMessage = getWireMessage(fullState)
173-
val data =
174-
Json
175-
.encodeToJsonElement<WidgetMessage>(
176-
WidgetUpdateMessage(
177-
wireMessage.state,
178-
wireMessage.bufferPaths,
179-
),
180-
).jsonObject
181-
comm.send(data, null, wireMessage.buffers)
182-
}
183-
184-
is CustomMessage -> {}
185-
186-
else -> {}
187-
}
188-
}
189-
}
190-
}
10+
public fun getWidget(modelId: String): WidgetModel?
11+
public fun getWidgetId(widget: WidgetModel): String?
12+
public fun registerWidget(widget: WidgetModel)
13+
public fun renderWidget(widget: WidgetModel): DisplayResult
14+
}
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
package org.jetbrains.kotlinx.jupyter.widget
2+
3+
import kotlinx.serialization.json.Json
4+
import kotlinx.serialization.json.buildJsonObject
5+
import kotlinx.serialization.json.decodeFromJsonElement
6+
import kotlinx.serialization.json.encodeToJsonElement
7+
import kotlinx.serialization.json.jsonObject
8+
import kotlinx.serialization.json.jsonPrimitive
9+
import kotlinx.serialization.json.put
10+
import org.jetbrains.kotlinx.jupyter.api.DisplayResult
11+
import org.jetbrains.kotlinx.jupyter.api.MimeTypedResultEx
12+
import org.jetbrains.kotlinx.jupyter.api.MimeTypes
13+
import org.jetbrains.kotlinx.jupyter.api.libraries.Comm
14+
import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
15+
import org.jetbrains.kotlinx.jupyter.widget.model.DEFAULT_MAJOR_VERSION
16+
import org.jetbrains.kotlinx.jupyter.widget.model.DEFAULT_MINOR_VERSION
17+
import org.jetbrains.kotlinx.jupyter.widget.model.DefaultWidgetModel
18+
import org.jetbrains.kotlinx.jupyter.widget.model.WidgetFactoryRegistry
19+
import org.jetbrains.kotlinx.jupyter.widget.model.WidgetModel
20+
import org.jetbrains.kotlinx.jupyter.widget.model.versionConstraintRegex
21+
import org.jetbrains.kotlinx.jupyter.widget.protocol.CustomMessage
22+
import org.jetbrains.kotlinx.jupyter.widget.protocol.RequestStateMessage
23+
import org.jetbrains.kotlinx.jupyter.widget.protocol.RequestStatesMessage
24+
import org.jetbrains.kotlinx.jupyter.widget.protocol.UpdateStatesMessage
25+
import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetMessage
26+
import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetOpenMessage
27+
import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetStateMessage
28+
import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetUpdateMessage
29+
import org.jetbrains.kotlinx.jupyter.widget.protocol.getWireMessage
30+
import org.jetbrains.kotlinx.jupyter.widget.protocol.toPatch
31+
32+
private val widgetOpenMetadataJson =
33+
buildJsonObject {
34+
put("version", "${DEFAULT_MAJOR_VERSION}.${DEFAULT_MINOR_VERSION}")
35+
}
36+
37+
public class WidgetManagerImpl(
38+
private val commManager: CommManager,
39+
private val classLoaderProvider: () -> ClassLoader,
40+
) : WidgetManager {
41+
private val widgetTarget = "jupyter.widget"
42+
private val widgetControlTarget = "jupyter.widget.control"
43+
private val widgets = mutableMapOf<String, WidgetModel>()
44+
private val widgetIdByWidget = mutableMapOf<WidgetModel, String>()
45+
46+
override val factoryRegistry: WidgetFactoryRegistry = WidgetFactoryRegistry()
47+
48+
init {
49+
commManager.registerCommTarget(widgetControlTarget) { comm, _, _, _ ->
50+
comm.onMessage { msg, _, _ ->
51+
when (Json.decodeFromJsonElement<WidgetMessage>(msg)) {
52+
is RequestStatesMessage -> {
53+
val fullStates =
54+
widgets.mapValues { (id, widget) ->
55+
widget.getFullState()
56+
}
57+
58+
val wireMessage = getWireMessage(fullStates)
59+
val message = UpdateStatesMessage(wireMessage.state, wireMessage.bufferPaths)
60+
61+
val data = Json.encodeToJsonElement<WidgetMessage>(message).jsonObject
62+
comm.send(data, null, emptyList())
63+
}
64+
65+
else -> {}
66+
}
67+
}
68+
}
69+
70+
commManager.registerCommTarget(widgetTarget) { comm, data, _, buffers ->
71+
val openMessage = Json.decodeFromJsonElement<WidgetOpenMessage>(data)
72+
val modelName = openMessage.state["_model_name"]?.jsonPrimitive?.content!!
73+
val widgetFactory = factoryRegistry.loadWidgetFactory(modelName, classLoaderProvider())
74+
75+
val widget = widgetFactory.create(this)
76+
val patch = openMessage.toPatch(buffers)
77+
widget.applyPatch(patch)
78+
79+
initializeWidget(comm, widget)
80+
}
81+
}
82+
83+
override fun getWidget(modelId: String): WidgetModel? = widgets[modelId]
84+
85+
override fun getWidgetId(widget: WidgetModel): String? = widgetIdByWidget[widget]
86+
87+
override fun registerWidget(widget: WidgetModel) {
88+
if (getWidgetId(widget) != null) return
89+
90+
val fullState = widget.getFullState()
91+
val wireMessage = getWireMessage(fullState)
92+
93+
val comm =
94+
commManager.openComm(
95+
widgetTarget,
96+
Json
97+
.encodeToJsonElement(
98+
WidgetOpenMessage(
99+
wireMessage.state,
100+
wireMessage.bufferPaths,
101+
),
102+
).jsonObject,
103+
widgetOpenMetadataJson,
104+
wireMessage.buffers,
105+
)
106+
107+
initializeWidget(comm, widget)
108+
}
109+
110+
override fun renderWidget(widget: WidgetModel): DisplayResult =
111+
MimeTypedResultEx(
112+
buildJsonObject {
113+
val modelId = getWidgetId(widget) ?: error("Widget is not registered")
114+
var versionMajor = DEFAULT_MAJOR_VERSION
115+
var versionMinor = DEFAULT_MINOR_VERSION
116+
var modelName: String? = null
117+
if (widget is DefaultWidgetModel) {
118+
modelName = widget.modelName
119+
val version = widget.modelModuleVersion
120+
val matchResult = versionConstraintRegex.find(version)
121+
if (matchResult != null) {
122+
versionMajor = matchResult.groupValues[1].toInt()
123+
versionMinor = matchResult.groupValues[2].toInt()
124+
}
125+
}
126+
if (modelName != null) {
127+
put(MimeTypes.HTML, "$modelName(id=$modelId)")
128+
}
129+
put(
130+
"application/vnd.jupyter.widget-view+json",
131+
buildJsonObject {
132+
put("version_major", versionMajor)
133+
put("version_minor", versionMinor)
134+
put("model_id", modelId)
135+
},
136+
)
137+
},
138+
null,
139+
)
140+
141+
private fun initializeWidget(
142+
comm: Comm,
143+
widget: WidgetModel,
144+
) {
145+
val modelId = comm.id
146+
widgetIdByWidget[widget] = modelId
147+
widgets[modelId] = widget
148+
149+
// Reflect kernel-side changes on the frontend
150+
widget.addChangeListener { patch ->
151+
val wireMessage = getWireMessage(patch)
152+
val data =
153+
Json
154+
.encodeToJsonElement<WidgetMessage>(
155+
WidgetUpdateMessage(
156+
wireMessage.state,
157+
wireMessage.bufferPaths,
158+
),
159+
).jsonObject
160+
comm.send(data, null, wireMessage.buffers)
161+
}
162+
163+
// Reflect frontend-side changes on kernel
164+
comm.onMessage { msg, _, buffers ->
165+
when (val message = Json.decodeFromJsonElement<WidgetMessage>(msg)) {
166+
is WidgetStateMessage -> {
167+
widget.applyPatch(message.toPatch(buffers))
168+
}
169+
170+
is RequestStateMessage -> {
171+
val fullState = widget.getFullState()
172+
val wireMessage = getWireMessage(fullState)
173+
val data =
174+
Json
175+
.encodeToJsonElement<WidgetMessage>(
176+
WidgetUpdateMessage(
177+
wireMessage.state,
178+
wireMessage.bufferPaths,
179+
),
180+
).jsonObject
181+
comm.send(data, null, wireMessage.buffers)
182+
}
183+
184+
is CustomMessage -> {}
185+
186+
else -> {}
187+
}
188+
}
189+
}
190+
}

0 commit comments

Comments
 (0)