Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
omit =
*_tmp.py
zetta_utils/log.py
zetta_utils/builder/preload/try_load.py
zetta_utils/cli/task_mgmt.py
zetta_utils/task_management/subtask_structure.py
zetta_utils/task_management/segment.py
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -313,4 +313,7 @@ warn_unused_ignores = true
known_third_party = "wandb"
profile = "black"
skip = ["specs"]
skip_glob = ["**/__init__.py"]
skip_glob = [
"**/__init__.py",
"zetta_utils/builder/preload/*.py"
]
8 changes: 7 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from docker.errors import DockerException

import docker
from zetta_utils import constants
from zetta_utils import constants, setup_environment

constants.RUN_DATABASE = None

Expand All @@ -15,6 +15,12 @@ def pytest_addoption(parser):
parser.addoption("--run-integration", default=False, help="Run integration tests")


@pytest.fixture(scope="session", autouse=True)
def _setup_forkserver():
    """Initialize forkserver with preloaded modules for test session.

    scope="session" + autouse=True: runs exactly once, before any test,
    without tests having to request it. Loads the "all" module set so the
    forkserver has every zetta_utils submodule preloaded.
    """
    setup_environment("all")


@pytest.fixture(scope="session")
def datastore_emulator():
"""Ensure that the DataStore service is up and responsive."""
Expand Down
30 changes: 17 additions & 13 deletions tests/unit/mazepa/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
from ..helpers import DummyException


def _slow_task_fn():
time.sleep(0.3)


def _raising_task_fn():
    """Module-level task fn that always fails with DummyException."""
    exc = DummyException()
    raise exc


def test_make_taskable_operation_cls() -> None:
@taskable_operation_cls(operation_name="OpDummyClass1")
@attrs.mutable
Expand Down Expand Up @@ -56,24 +64,20 @@ def dummy_task_fn():


def test_task_runtime_limit() -> None:
    """A task whose fn outruns runtime_limit_sec must yield a MazepaTimeoutError outcome."""
    # pebble with forkserver/spawn context does not support nested functions,
    # so the slow fn is defined at module level (_slow_task_fn).
    slow_task_op = taskable_operation(runtime_limit_sec=0.1)(_slow_task_fn)
    assert isinstance(slow_task_op, TaskableOperation)
    task = slow_task_op.make_task()
    assert isinstance(task, Task)
    outcome = task(debug=False)
    # Timeout is reported via the outcome object, not raised.
    assert isinstance(outcome.exception, MazepaTimeoutError)


def test_task_no_handle_exc() -> None:
    """With handle_exceptions=False the task fn's exception propagates to the caller."""
    # pebble with forkserver/spawn context does not support nested functions,
    # so the raising fn is defined at module level (_raising_task_fn).
    raising_task_op = taskable_operation(runtime_limit_sec=0.1)(_raising_task_fn)
    assert isinstance(raising_task_op, TaskableOperation)
    task = raising_task_op.make_task()
    assert isinstance(task, Task)
    # Assert the specific exception type, not bare Exception.
    with pytest.raises(DummyException):
        task(debug=False, handle_exceptions=False)
149 changes: 77 additions & 72 deletions zetta_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,35 @@
# pylint: disable=unused-import, import-outside-toplevel, broad-exception-caught, import-error
"""Zetta AI Computational Connectomics Toolkit."""
import multiprocessing
import os
import sys
import threading
import time
import warnings
import multiprocessing
from typing import Literal

from .log import get_logger

# Set global multiprocessing threshold
# Set global multiprocessing context and threshold
MULTIPROCESSING_CONTEXT = "forkserver"
MULTIPROCESSING_NUM_TASKS_THRESHOLD = 128

# Forkserver initialization
LoadMode = Literal["all", "inference", "training", "try"]

_PRELOAD_MODULES: dict[LoadMode, str] = {
"all": "zetta_utils.builder.preload.all",
"inference": "zetta_utils.builder.preload.inference",
"training": "zetta_utils.builder.preload.training",
"try": "zetta_utils.builder.preload.try_load",
}

# Set start method to `forkserver` if not set elsewhere
# If not set here, `get_start_method` will set the default
# to `fork` w/o allow_none and cause issues with dependencies.
if multiprocessing.get_start_method(allow_none=True) is None:
multiprocessing.set_start_method("forkserver")
multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT)


if "sphinx" not in sys.modules: # pragma: no cover
import pdbp # noqa
Expand All @@ -36,88 +51,78 @@
warnings.filterwarnings("ignore", category=DeprecationWarning)


def _load_core_modules():
"""Load core modules that were previously imported at package level."""
from . import log, typing, parsing, builder, common, constants
from . import geometry, distributions, layer, ng
def load_all_modules():  # pragma: no cover
    """Import the 'all' preload group (inference + training + task management)."""
    import zetta_utils.builder.preload.all

# Add builder module suppression now that it's loaded
log.add_supress_traceback_module(builder)

def try_load_train_inference():  # pragma: no cover
    """Import the 'try' preload variant (zetta_utils.builder.preload.try_load).

    NOTE(review): presumably a best-effort load that logs failures instead of
    raising, as the pre-refactor version did — confirm against try_load.py.
    """
    import zetta_utils.builder.preload.try_load

def load_all_modules():
_load_core_modules()
load_inference_modules()
load_training_modules()
from . import task_management

def load_submodules():  # pragma: no cover
    """Import the ``zetta_utils.internal`` subpackage for its import side effects."""
    from . import internal

def try_load_train_inference(): # pragma: no cover
try:
_load_core_modules()
except Exception as e: # pylint: disable=broad-exception-caught
logger.exception(e)

try:
load_inference_modules()
def load_inference_modules():  # pragma: no cover
    """Import the inference preload group (zetta_utils.builder.preload.inference)."""
    import zetta_utils.builder.preload.inference

except Exception as e: # pylint: disable=broad-exception-caught
logger.exception(e)

try:
load_training_modules()
except Exception as e: # pylint: disable=broad-exception-caught
logger.exception(e)
def load_training_modules():  # pragma: no cover
    """Import the training preload group (zetta_utils.builder.preload.training)."""
    import zetta_utils.builder.preload.training

try:
from . import mazepa_addons
except Exception as e: # pylint: disable=broad-exception-caught
logger.exception(e)

def _noop() -> None:
pass

def load_submodules(): # pragma: no cover
from . import internal

def get_mp_context() -> multiprocessing.context.BaseContext:
    """Return the multiprocessing context for the configured start method."""
    ctx = multiprocessing.get_context(MULTIPROCESSING_CONTEXT)
    return ctx

def load_inference_modules():
_load_core_modules()
from . import (
augmentations,
convnet,
mazepa,
mazepa_layer_processing,
tensor_ops,
tensor_typing,
tensor_mapping,
)
from .layer import volumetric
from .layer.volumetric import cloudvol
from .message_queues import sqs

from . import mazepa_addons
from . import message_queues
from . import cloud_management

load_submodules()


def load_training_modules():
_load_core_modules()
from . import (
augmentations,
convnet,
mazepa,
tensor_ops,
tensor_typing,
training,
tensor_mapping,

def initialize_forkserver(load_mode: LoadMode = "all") -> None:
    """Initialize forkserver with preloaded modules for the given load mode.

    Args:
        load_mode: Key into ``_PRELOAD_MODULES`` selecting which preload
            module the forkserver imports before serving fork requests.
    """
    preload_module = _PRELOAD_MODULES[load_mode]
    logger.info(f"Configuring forkserver with preload module: {preload_module}")

    total_start = time.perf_counter()
    # Must be configured before the forkserver process is launched; the
    # preload module is imported once in the server process.
    multiprocessing.set_forkserver_preload([preload_module])
    ctx = get_mp_context()
    # Starting (and joining) a throwaway no-op process forces the forkserver
    # to launch now, paying the preload import cost up front.
    proc = ctx.Process(target=_noop)  # type: ignore[attr-defined]
    proc.start()
    proc.join()

    total_elapsed = time.perf_counter() - total_start
    logger.info(f"Forkserver initialized in {total_elapsed:.2f}s (mode: {load_mode})")


def setup_environment(load_mode: LoadMode = "all") -> None:
    """
    Initialize forkserver and load modules in parallel.

    This function:
    1. Starts forkserver initialization in a background thread
    2. Loads modules in the main process (runs in parallel with forkserver init)
    3. Waits for forkserver to be ready before returning

    Args:
        load_mode: Which modules to load ("all", "inference", "training", "try")
    """
    # Start forkserver init in background while main process loads modules
    forkserver_thread = threading.Thread(
        target=initialize_forkserver, args=(load_mode,), name="forkserver_init"
    )
    forkserver_thread.start()

    # Load modules in main process (runs in parallel with forkserver init)
    if load_mode == "all":
        load_all_modules()
    elif load_mode == "inference":  # pragma: no cover
        load_inference_modules()
    elif load_mode == "try":  # pragma: no cover
        try_load_train_inference()
    else:  # training  # pragma: no cover
        load_training_modules()

    # Wait for forkserver to be ready before proceeding
    forkserver_thread.join()
Empty file.
15 changes: 15 additions & 0 deletions zetta_utils/builder/preload/all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# pylint: disable=unused-import, wrong-import-position
"""Preload group 'all': inference + training + task management."""

import time

from zetta_utils import log

_start = time.perf_counter()

# Imports below are deliberately after the timer start so their cost is
# measured; order follows the narrower preload groups first.
from zetta_utils.builder.preload import inference
from zetta_utils.builder.preload import training
from zetta_utils import task_management

_elapsed = time.perf_counter() - _start
log.get_logger("zetta_utils").debug(f"Preload all modules: {_elapsed:.2f}s")
15 changes: 15 additions & 0 deletions zetta_utils/builder/preload/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# pylint: disable=unused-import, wrong-import-position
"""Core module imports - shared by all load modes."""

import time

_start = time.perf_counter()

# Timed imports: core zetta_utils subpackages every load mode depends on.
from zetta_utils import log, typing, parsing, builder, common, constants
from zetta_utils import geometry, distributions, layer, ng

# Add builder module suppression now that it's loaded
log.add_supress_traceback_module(builder)

_elapsed = time.perf_counter() - _start
log.get_logger("zetta_utils").debug(f"Preload core modules: {_elapsed:.2f}s")
33 changes: 33 additions & 0 deletions zetta_utils/builder/preload/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# pylint: disable=unused-import, wrong-import-position
"""Inference module imports."""

import time

from zetta_utils import log

_start = time.perf_counter()

# Import core first
from zetta_utils.builder.preload import core

# Inference-facing subpackages (timed).
from zetta_utils import (
    augmentations,
    convnet,
    mazepa,
    mazepa_layer_processing,
    tensor_ops,
    tensor_typing,
    tensor_mapping,
)
from zetta_utils.layer import volumetric
from zetta_utils.layer.volumetric import cloudvol
from zetta_utils.message_queues import sqs

from zetta_utils import mazepa_addons
from zetta_utils import message_queues
from zetta_utils import cloud_management

from zetta_utils import internal

_elapsed = time.perf_counter() - _start
log.get_logger("zetta_utils").debug(f"Preload inference modules: {_elapsed:.2f}s")
33 changes: 33 additions & 0 deletions zetta_utils/builder/preload/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# pylint: disable=unused-import, wrong-import-position
"""Training module imports."""

import time

from zetta_utils import log

_start = time.perf_counter()

# Import core first
from zetta_utils.builder.preload import core

# Training-facing subpackages (timed); differs from inference by including
# `training`/db_layer and omitting mazepa_layer_processing/sqs.
from zetta_utils import (
    augmentations,
    convnet,
    mazepa,
    tensor_ops,
    tensor_typing,
    training,
    tensor_mapping,
)
from zetta_utils.layer import volumetric, db_layer
from zetta_utils.layer.db_layer import datastore, firestore
from zetta_utils.layer.volumetric import cloudvol

from zetta_utils import mazepa_addons
from zetta_utils import message_queues
from zetta_utils import cloud_management

from zetta_utils import internal

_elapsed = time.perf_counter() - _start
log.get_logger("zetta_utils").debug(f"Preload training modules: {_elapsed:.2f}s")
Loading
Loading