Skip to content

Commit 4e6e6b7

Browse files
guptaakacopybara-github
authored andcommitted
Update environment variables for JAX backend
PiperOrigin-RevId: 886971839
1 parent a57c2a0 commit 4e6e6b7

1 file changed

Lines changed: 42 additions & 1 deletion

File tree

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ def _deploy_pathways_proxy_server(
123123
_logger.info("Successfully deployed Pathways proxy.")
124124

125125

126+
def _restore_env_var(key: str, original_value: str | None) -> None:
127+
"""Restores an environment variable to its original value or unsets it."""
128+
if original_value is None:
129+
_logger.info("Unsetting environment variable: %s", key)
130+
os.environ.pop(key, None)
131+
else:
132+
_logger.info(
133+
"Restoring environment variable '%s' to '%s'", key, original_value
134+
)
135+
os.environ[key] = original_value
136+
137+
126138
class _ISCPathways:
127139
"""Class for managing TPUs for interactive supercomputing.
128140
@@ -163,6 +175,10 @@ def __init__(
163175
self._proxy_port = None
164176
self.proxy_server_image = proxy_server_image
165177
self.proxy_options = proxy_options or ProxyOptions()
178+
self._old_jax_platforms = None
179+
self._old_jax_backend_target = None
180+
self._old_jax_platforms_config = None
181+
self._old_jax_backend_target_config = None
166182

167183
def __repr__(self):
168184
return (
@@ -176,6 +192,15 @@ def __repr__(self):
176192

177193
def __enter__(self):
178194
"""Enters the context manager, ensuring cluster exists."""
195+
self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY)
196+
self._old_jax_backend_target = os.environ.get(_JAX_BACKEND_TARGET_KEY)
197+
self._old_jax_platforms_config = getattr(
198+
jax.config, _JAX_PLATFORMS_KEY, None
199+
)
200+
self._old_jax_backend_target_config = getattr(
201+
jax.config, _JAX_BACKEND_TARGET_KEY, None
202+
)
203+
179204
try:
180205
_deploy_pathways_proxy_server(
181206
pathways_service=self.pathways_service,
@@ -199,11 +224,17 @@ def __enter__(self):
199224
)
200225

201226
# Update the JAX backend to use the proxy.
227+
os.environ[_JAX_PLATFORMS_KEY] = _JAX_PLATFORM_PROXY
228+
os.environ[
229+
_JAX_BACKEND_TARGET_KEY
230+
] = f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}"
231+
202232
jax.config.update(_JAX_PLATFORMS_KEY, _JAX_PLATFORM_PROXY)
203233
jax.config.update(
204234
_JAX_BACKEND_TARGET_KEY,
205235
f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}",
206236
)
237+
207238
pathwaysutils.initialize()
208239
_logger.info(
209240
"Interactive supercomputing proxy client ready for cluster '%s'.",
@@ -221,7 +252,7 @@ def __exit__(self, exc_type, exc_value, traceback):
221252
_logger.info("Exiting ISCPathways context.")
222253
self._cleanup()
223254

224-
def _cleanup(self):
255+
def _cleanup(self) -> None:
225256
"""Cleans up resources created by the ISCPathways context."""
226257
# 1. Clear JAX caches and run garbage collection.
227258
_logger.info("Starting Pathways proxy cleanup.")
@@ -248,6 +279,16 @@ def _cleanup(self):
248279
gke_utils.delete_gke_job(self._proxy_job_name)
249280
_logger.info("Pathways proxy GKE job deletion complete.")
250281

282+
# 4. Restore JAX variables.
283+
_logger.info("Restoring JAX env and config variables...")
284+
_restore_env_var(_JAX_PLATFORMS_KEY, self._old_jax_platforms)
285+
_restore_env_var(_JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target)
286+
jax.config.update(_JAX_PLATFORMS_KEY, self._old_jax_platforms_config)
287+
jax.config.update(
288+
_JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target_config
289+
)
290+
_logger.info("JAX variables restored.")
291+
251292

252293
@contextlib.contextmanager
253294
def connect(

0 commit comments

Comments
 (0)