5050
5151import java .io .IOException ;
5252import java .io .InputStream ;
53+ import java .io .UncheckedIOException ;
5354import java .util .ArrayList ;
5455import java .util .Collections ;
5556import java .util .Map ;
5657import java .util .Set ;
58+ import java .util .function .Consumer ;
5759
60+ import org .knime .core .checkpoint .PhasedInit ;
61+ import org .knime .core .checkpoint .PhasedInitSupport ;
5862import org .knime .core .node .NodeLogger ;
5963import org .knime .python3 .PythonGatewayCreationGate .PythonGatewayCreationGateListener ;
6064
@@ -79,6 +83,18 @@ public final class PythonGatewayTracker implements PythonGatewayCreationGateList
7983
8084 private PythonGatewayTracker () {
8185 m_openGateways = gatewaySet ();
86+ // Support CRaC (Coordinated Restore at Checkpoint) and close all connections prior to checkpointing
87+ PhasedInitSupport .registerOrActivate (new PhasedInit <RuntimeException >() {
88+ @ Override
89+ public void beforeCheckpoint () throws RuntimeException {
90+ try {
91+ clearForCheckpoint ();
92+ } catch (IOException ex ) {
93+ throw new UncheckedIOException (
94+ "Error when forcefully terminating Python processes before checkpointing" , ex );
95+ }
96+ }
97+ });
8298 }
8399
84100 /**
@@ -107,15 +123,30 @@ public void onPythonGatewayCreationGateClose() {
107123 }
108124
109125 void clear () throws IOException {
126+ clear (LOGGER ::error ,
127+ "Found running Python processes (%d). Aborting them to allow installation process. "
128+ + "If this leads to failures in node execution, "
129+ + "please restart those nodes once the installation has finished. Triggered from thread '%s'." );
130+ }
131+
132+ void clearForCheckpoint () throws IOException {
133+ clear (LOGGER ::info ,
134+ "Found running Python processes (%d). Aborting them prior to checkpointing. Triggered from thread '%s'." );
135+ }
136+
137+ /**
138+ * Closes all open gateways and logs a message using the provided consumer.
139+ *
140+ * @param logMessageConsumer consumer for logging messages
141+ * @param logMessage message format string with placeholders for gateway count and thread name
142+ * @throws IOException if an error occurs while closing the gateways
143+ */
144+ private void clear (final Consumer <String > logMessageConsumer , final String logMessage ) throws IOException {
110145 if (m_openGateways .isEmpty ()) {
111146 return ;
112147 }
113148
114- LOGGER .errorWithFormat (
115- "Found running Python processes (%d). Aborting them to allow installation process. "
116- + "If this leads to failures in node execution, "
117- + "please restart those nodes once the installation has finished. Triggered from thread '%s'." ,
118- m_openGateways .size (), Thread .currentThread ().getName ());
149+ logMessageConsumer .accept (String .format (logMessage , m_openGateways .size (), Thread .currentThread ().getName ()));
119150
120151 var exceptions = new ArrayList <Exception >();
121152 for (var gateway : m_openGateways ) {
0 commit comments