@@ -123,6 +123,18 @@ def _deploy_pathways_proxy_server(
123123 _logger .info ("Successfully deployed Pathways proxy." )
124124
125125
126+ def _restore_env_var (key : str , original_value : str | None ) -> None :
127+ """Restores an environment variable to its original value or unsets it."""
128+ if original_value is None :
129+ _logger .info ("Unsetting environment variable: %s" , key )
130+ os .environ .pop (key , None )
131+ else :
132+ _logger .info (
133+ "Restoring environment variable '%s' to '%s'" , key , original_value
134+ )
135+ os .environ [key ] = original_value
136+
137+
126138class _ISCPathways :
127139 """Class for managing TPUs for interactive supercomputing.
128140
@@ -163,6 +175,10 @@ def __init__(
163175 self ._proxy_port = None
164176 self .proxy_server_image = proxy_server_image
165177 self .proxy_options = proxy_options or ProxyOptions ()
178+ self ._old_jax_platforms = None
179+ self ._old_jax_backend_target = None
180+ self ._old_jax_platforms_config = None
181+ self ._old_jax_backend_target_config = None
166182
167183 def __repr__ (self ):
168184 return (
@@ -176,6 +192,15 @@ def __repr__(self):
176192
177193 def __enter__ (self ):
178194 """Enters the context manager, ensuring cluster exists."""
195+ self ._old_jax_platforms = os .environ .get (_JAX_PLATFORMS_KEY )
196+ self ._old_jax_backend_target = os .environ .get (_JAX_BACKEND_TARGET_KEY )
197+ self ._old_jax_platforms_config = getattr (
198+ jax .config , _JAX_PLATFORMS_KEY , None
199+ )
200+ self ._old_jax_backend_target_config = getattr (
201+ jax .config , _JAX_BACKEND_TARGET_KEY , None
202+ )
203+
179204 try :
180205 _deploy_pathways_proxy_server (
181206 pathways_service = self .pathways_service ,
@@ -199,11 +224,17 @@ def __enter__(self):
199224 )
200225
201226 # Update the JAX backend to use the proxy.
227+ os .environ [_JAX_PLATFORMS_KEY ] = _JAX_PLATFORM_PROXY
228+ os .environ [
229+ _JAX_BACKEND_TARGET_KEY
230+ ] = f"{ _JAX_BACKEND_TARGET_HOSTNAME } :{ self ._proxy_port } "
231+
202232 jax .config .update (_JAX_PLATFORMS_KEY , _JAX_PLATFORM_PROXY )
203233 jax .config .update (
204234 _JAX_BACKEND_TARGET_KEY ,
205235 f"{ _JAX_BACKEND_TARGET_HOSTNAME } :{ self ._proxy_port } " ,
206236 )
237+
207238 pathwaysutils .initialize ()
208239 _logger .info (
209240 "Interactive supercomputing proxy client ready for cluster '%s'." ,
@@ -221,7 +252,7 @@ def __exit__(self, exc_type, exc_value, traceback):
221252 _logger .info ("Exiting ISCPathways context." )
222253 self ._cleanup ()
223254
224- def _cleanup (self ):
255+ def _cleanup (self ) -> None :
225256 """Cleans up resources created by the ISCPathways context."""
226257 # 1. Clear JAX caches and run garbage collection.
227258 _logger .info ("Starting Pathways proxy cleanup." )
@@ -248,6 +279,16 @@ def _cleanup(self):
248279 gke_utils .delete_gke_job (self ._proxy_job_name )
249280 _logger .info ("Pathways proxy GKE job deletion complete." )
250281
282+ # 4. Restore JAX variables.
283+ _logger .info ("Restoring JAX env and config variables..." )
284+ _restore_env_var (_JAX_PLATFORMS_KEY , self ._old_jax_platforms )
285+ _restore_env_var (_JAX_BACKEND_TARGET_KEY , self ._old_jax_backend_target )
286+ jax .config .update (_JAX_PLATFORMS_KEY , self ._old_jax_platforms_config )
287+ jax .config .update (
288+ _JAX_BACKEND_TARGET_KEY , self ._old_jax_backend_target_config
289+ )
290+ _logger .info ("JAX variables restored." )
291+
251292
252293@contextlib .contextmanager
253294def connect (
0 commit comments