11package org.jetbrains.kotlinx.jupyter.widget
22
3- import kotlinx.serialization.json.Json
4- import kotlinx.serialization.json.buildJsonObject
5- import kotlinx.serialization.json.decodeFromJsonElement
6- import kotlinx.serialization.json.encodeToJsonElement
7- import kotlinx.serialization.json.jsonObject
8- import kotlinx.serialization.json.jsonPrimitive
9- import kotlinx.serialization.json.put
103import org.jetbrains.kotlinx.jupyter.api.DisplayResult
11- import org.jetbrains.kotlinx.jupyter.api.MimeTypedResultEx
12- import org.jetbrains.kotlinx.jupyter.api.MimeTypes
13- import org.jetbrains.kotlinx.jupyter.api.libraries.Comm
14- import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
15- import org.jetbrains.kotlinx.jupyter.widget.model.DEFAULT_MAJOR_VERSION
16- import org.jetbrains.kotlinx.jupyter.widget.model.DEFAULT_MINOR_VERSION
17- import org.jetbrains.kotlinx.jupyter.widget.model.DefaultWidgetModel
184import org.jetbrains.kotlinx.jupyter.widget.model.WidgetFactoryRegistry
195import org.jetbrains.kotlinx.jupyter.widget.model.WidgetModel
20- import org.jetbrains.kotlinx.jupyter.widget.model.versionConstraintRegex
21- import org.jetbrains.kotlinx.jupyter.widget.protocol.CustomMessage
22- import org.jetbrains.kotlinx.jupyter.widget.protocol.RequestStateMessage
23- import org.jetbrains.kotlinx.jupyter.widget.protocol.RequestStatesMessage
24- import org.jetbrains.kotlinx.jupyter.widget.protocol.UpdateStatesMessage
25- import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetMessage
26- import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetOpenMessage
27- import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetStateMessage
28- import org.jetbrains.kotlinx.jupyter.widget.protocol.WidgetUpdateMessage
29- import org.jetbrains.kotlinx.jupyter.widget.protocol.getWireMessage
30- import org.jetbrains.kotlinx.jupyter.widget.protocol.toPatch
316
32- private val widgetOpenMetadataJson =
33- buildJsonObject {
34- put(" version" , " ${DEFAULT_MAJOR_VERSION } .${DEFAULT_MINOR_VERSION } " )
35- }
7+ public interface WidgetManager {
8+ public val factoryRegistry: WidgetFactoryRegistry
369
37- public class WidgetManager (
38- private val commManager : CommManager ,
39- private val classLoaderProvider : () -> ClassLoader ,
40- ) {
41- private val widgetTarget = " jupyter.widget"
42- private val widgetControlTarget = " jupyter.widget.control"
43- private val widgets = mutableMapOf<String , WidgetModel >()
44- private val widgetIdByWidget = mutableMapOf<WidgetModel , String >()
45-
46- public val factoryRegistry: WidgetFactoryRegistry = WidgetFactoryRegistry ()
47-
48- init {
49- commManager.registerCommTarget(widgetControlTarget) { comm, _, _, _ ->
50- comm.onMessage { msg, _, _ ->
51- when (Json .decodeFromJsonElement<WidgetMessage >(msg)) {
52- is RequestStatesMessage -> {
53- val fullStates =
54- widgets.mapValues { (id, widget) ->
55- widget.getFullState()
56- }
57-
58- val wireMessage = getWireMessage(fullStates)
59- val message = UpdateStatesMessage (wireMessage.state, wireMessage.bufferPaths)
60-
61- val data = Json .encodeToJsonElement<WidgetMessage >(message).jsonObject
62- comm.send(data, null , emptyList())
63- }
64-
65- else -> {}
66- }
67- }
68- }
69-
70- commManager.registerCommTarget(widgetTarget) { comm, data, _, buffers ->
71- val openMessage = Json .decodeFromJsonElement<WidgetOpenMessage >(data)
72- val modelName = openMessage.state[" _model_name" ]?.jsonPrimitive?.content!!
73- val widgetFactory = factoryRegistry.loadWidgetFactory(modelName, classLoaderProvider())
74-
75- val widget = widgetFactory.create(this )
76- val patch = openMessage.toPatch(buffers)
77- widget.applyPatch(patch)
78-
79- initializeWidget(comm, widget)
80- }
81- }
82-
83- public fun getWidget (modelId : String ): WidgetModel ? = widgets[modelId]
84-
85- public fun getWidgetId (widget : WidgetModel ): String? = widgetIdByWidget[widget]
86-
87- public fun registerWidget (widget : WidgetModel ) {
88- if (widgetIdByWidget[widget] != null ) return
89-
90- val fullState = widget.getFullState()
91- val wireMessage = getWireMessage(fullState)
92-
93- val comm =
94- commManager.openComm(
95- widgetTarget,
96- Json
97- .encodeToJsonElement(
98- WidgetOpenMessage (
99- wireMessage.state,
100- wireMessage.bufferPaths,
101- ),
102- ).jsonObject,
103- widgetOpenMetadataJson,
104- wireMessage.buffers,
105- )
106-
107- initializeWidget(comm, widget)
108- }
109-
110- public fun renderWidget (widget : WidgetModel ): DisplayResult =
111- MimeTypedResultEx (
112- buildJsonObject {
113- val modelId = widgetIdByWidget[widget] ? : error(" Widget is not registered" )
114- var versionMajor = DEFAULT_MAJOR_VERSION
115- var versionMinor = DEFAULT_MINOR_VERSION
116- var modelName: String? = null
117- if (widget is DefaultWidgetModel ) {
118- modelName = widget.modelName
119- val version = widget.modelModuleVersion
120- val matchResult = versionConstraintRegex.find(version)
121- if (matchResult != null ) {
122- versionMajor = matchResult.groupValues[1 ].toInt()
123- versionMinor = matchResult.groupValues[2 ].toInt()
124- }
125- }
126- if (modelName != null ) {
127- put(MimeTypes .HTML , " $modelName (id=$modelId )" )
128- }
129- put(
130- " application/vnd.jupyter.widget-view+json" ,
131- buildJsonObject {
132- put(" version_major" , versionMajor)
133- put(" version_minor" , versionMinor)
134- put(" model_id" , modelId)
135- },
136- )
137- },
138- null ,
139- )
140-
141- private fun initializeWidget (
142- comm : Comm ,
143- widget : WidgetModel ,
144- ) {
145- val modelId = comm.id
146- widgetIdByWidget[widget] = modelId
147- widgets[modelId] = widget
148-
149- // Reflect kernel-side changes on the frontend
150- widget.addChangeListener { patch ->
151- val wireMessage = getWireMessage(patch)
152- val data =
153- Json
154- .encodeToJsonElement<WidgetMessage >(
155- WidgetUpdateMessage (
156- wireMessage.state,
157- wireMessage.bufferPaths,
158- ),
159- ).jsonObject
160- comm.send(data, null , wireMessage.buffers)
161- }
162-
163- // Reflect frontend-side changes on kernel
164- comm.onMessage { msg, _, buffers ->
165- when (val message = Json .decodeFromJsonElement<WidgetMessage >(msg)) {
166- is WidgetStateMessage -> {
167- widget.applyPatch(message.toPatch(buffers))
168- }
169-
170- is RequestStateMessage -> {
171- val fullState = widget.getFullState()
172- val wireMessage = getWireMessage(fullState)
173- val data =
174- Json
175- .encodeToJsonElement<WidgetMessage >(
176- WidgetUpdateMessage (
177- wireMessage.state,
178- wireMessage.bufferPaths,
179- ),
180- ).jsonObject
181- comm.send(data, null , wireMessage.buffers)
182- }
183-
184- is CustomMessage -> {}
185-
186- else -> {}
187- }
188- }
189- }
190- }
10+ public fun getWidget (modelId : String ): WidgetModel ?
11+ public fun getWidgetId (widget : WidgetModel ): String?
12+ public fun registerWidget (widget : WidgetModel )
13+ public fun renderWidget (widget : WidgetModel ): DisplayResult
14+ }
0 commit comments