diff --git a/.bumpversion.cfg b/.bumpversion.cfg index b6b4de269..a0775c7ee 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 2.21.1 +current_version = 3.1.7 commit = True tag = True diff --git a/.coveragerc b/.coveragerc index a38e1c392..d351f3e7e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,6 +5,13 @@ source = pychunkedgraph omit = *test* *benchmarking/* + pychunkedgraph/debug/* + pychunkedgraph/export/* + pychunkedgraph/jobs/* + pychunkedgraph/logging/* + pychunkedgraph/repair/* + pychunkedgraph/meshing/* + pychunkedgraph/app/* [report] # Regexes for lines to exclude from consideration diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 899f0431f..b64b1175d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -4,19 +4,49 @@ on: push: branches: - "main" + - "pcgv3" pull_request: branches: - "main" + - "pcgv3" jobs: unit-tests: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - - name: Build image and run tests + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build image + uses: docker/build-push-action@v6 + with: + context: . + load: true + tags: seunglab/pychunkedgraph:${{ github.sha }} + cache-from: type=gha + cache-to: type=gha,mode=max + + - name: Run tests with coverage run: | - docker build --tag seunglab/pychunkedgraph:$GITHUB_SHA . - docker run --rm seunglab/pychunkedgraph:$GITHUB_SHA /bin/sh -c "pytest --cov-config .coveragerc --cov=pychunkedgraph ./pychunkedgraph/tests && codecov" + docker run --name pcg-tests seunglab/pychunkedgraph:${{ github.sha }} \ + /bin/sh -c "pytest --cov-config .coveragerc --cov=pychunkedgraph --cov-report=xml:/app/coverage.xml ./pychunkedgraph/tests" + + - name: Copy coverage report from container + if: always() + run: docker cp pcg-tests:/app/coverage.xml ./coverage.xml + + - name: Upload coverage to Codecov + if: always() + uses: codecov/codecov-action@v5 + with: + files: ./coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} + slug: CAVEconnectome/PyChunkedGraph + fail_ci_if_error: true + - name: Cleanup + if: always() + run: docker rm pcg-tests || true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6ee89f6c6..80123fa18 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -45,7 +45,7 @@ jobs: - name: Install Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Install bumpversion run: pip install bumpversion - name: Bump version with bumpversion diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..70ceaed90 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,5 @@ +repos: + - repo: https://github.com/psf/black + rev: 26.1.0 + hooks: + - id: black diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index a5e33242d..000000000 --- a/.travis.yml +++ /dev/null @@ -1,60 +0,0 @@ -sudo: true -services: - docker - -env: - global: - - CLOUDSDK_CORE_DISABLE_PROMPTS=1 - -stages: - - test - - name: merge-deploy -python: 3.6 -notifications: - email: - on_success: change - on_failure: always - -jobs: - include: - - stage: test - name: "Running Tests" - language: minimal - before_script: - # request codecov to detect CI environment to pass through to docker - - ci_env=`bash <(curl -s https://codecov.io/env)` - - script: - - openssl aes-256-cbc -K $encrypted_506e835c2891_key -iv $encrypted_506e835c2891_iv -in key.json.enc -out key.json -d - - curl https://sdk.cloud.google.com | bash > /dev/null - - source "$HOME/google-cloud-sdk/path.bash.inc" - - gcloud auth activate-service-account --key-file=key.json - - gcloud auth configure-docker - - docker build --tag seunglab/pychunkedgraph:$TRAVIS_BRANCH . || travis_terminate 1 - - docker run $ci_env --rm seunglab/pychunkedgraph:$TRAVIS_BRANCH /bin/sh -c "tox -v -- --cov-config .coveragerc --cov=pychunkedgraph && codecov" - - - stage: merge-deploy - name: "version bump and merge into master" - language: python - install: - - pip install bumpversion - - before_script: - - "git clone https://gist.github.com/2c04596a45ccac57fe8dde0718ad58ee.git /tmp/travis-automerge" - - "chmod a+x /tmp/travis-automerge/auto_merge_travis_with_bumpversion.sh" - - script: - - "BRANCHES_TO_MERGE_REGEX='develop' BRANCH_TO_MERGE_INTO=master /tmp/travis-automerge/auto_merge_travis_with_bumpversion.sh" - - - stage: merge-deploy - name: "deploy to pypi" - language: python - install: - - pip install twine - - before_script: - - "git clone https://gist.github.com/cf9b261f26a1bf3fae6b59e7047f007a.git /tmp/travis-autodist" - - "chmod a+x /tmp/travis-autodist/pypi_dist.sh" - - script: - - "BRANCHES_TO_DIST='develop' /tmp/travis-autodist/pypi_dist.sh" diff --git a/Dockerfile b/Dockerfile index 2b7eeb151..968f93043 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,11 +1,66 @@ -FROM caveconnectome/pychunkedgraph:base_042124 +# syntax=docker/dockerfile:1 +ARG PYTHON_VERSION=3.12 +ARG BASE_IMAGE=tiangolo/uwsgi-nginx-flask:python${PYTHON_VERSION} + + +###################################################### +# Stage 1: Conda environment +###################################################### +FROM ${BASE_IMAGE} AS conda-deps +ENV PATH="/root/miniconda3/bin:${PATH}" + +RUN apt-get update && apt-get install build-essential wget -y \ + && wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ + && bash Miniconda3-latest-Linux-x86_64.sh -b \ + && rm Miniconda3-latest-Linux-x86_64.sh \ + && conda config --add channels conda-forge \ + && conda update -y --override-channels -c conda-forge conda \ + && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \ + && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \ + && conda install -y --override-channels -c conda-forge conda-pack + +COPY requirements.yml requirements.txt requirements-dev.txt ./ + +RUN --mount=type=cache,target=/opt/conda/pkgs \ + conda env create -n pcg -f requirements.yml + +RUN conda-pack -n pcg --ignore-missing-files -o /tmp/env.tar \ + && mkdir -p /app/venv && cd /app/venv \ + && tar xf /tmp/env.tar && rm /tmp/env.tar \ + && /app/venv/bin/conda-unpack + + +###################################################### +# Stage 2: Bigtable emulator +###################################################### +FROM golang:bullseye AS bigtable-emulator +ARG GOOGLE_CLOUD_GO_VERSION=bigtable/v1.19.0 +RUN --mount=type=cache,target=/go/pkg/mod \ + --mount=type=cache,target=/root/.cache/go-build \ + git clone --depth=1 --branch="$GOOGLE_CLOUD_GO_VERSION" \ + https://github.com/googleapis/google-cloud-go.git /usr/src \ + && cd /usr/src/bigtable && go install -v ./cmd/emulator + + +###################################################### +# Stage 3: Production +###################################################### +FROM ${BASE_IMAGE} ENV VIRTUAL_ENV=/app/venv ENV PATH="$VIRTUAL_ENV/bin:$PATH" +COPY --from=conda-deps /app/venv /app/venv +COPY --from=bigtable-emulator /go/bin/emulator /app/venv/bin/cbtemulator COPY override/gcloud /app/venv/bin/gcloud COPY override/timeout.conf /etc/nginx/conf.d/timeout.conf COPY override/supervisord.conf /etc/supervisor/conf.d/supervisord.conf +RUN pip install --no-cache-dir --no-deps --force-reinstall zstandard>=0.23.0 \ + && mkdir -p /home/nginx/.cloudvolume/secrets \ + && chown -R nginx /home/nginx \ + && usermod -d /home/nginx -s /bin/bash nginx COPY requirements.txt . -RUN pip install --upgrade -r requirements.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --upgrade -r requirements.txt + COPY . /app diff --git a/README.md b/README.md index ef888b3c6..ac1c67161 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # PyChunkedGraph -[![Build Status](https://travis-ci.org/seung-lab/PyChunkedGraph.svg?branch=master)](https://travis-ci.org/seung-lab/PyChunkedGraph) -[![codecov](https://codecov.io/gh/seung-lab/PyChunkedGraph/branch/master/graph/badge.svg)](https://codecov.io/gh/seung-lab/PyChunkedGraph) +[![Tests](https://github.com/CAVEconnectome/PyChunkedGraph/actions/workflows/main.yml/badge.svg)](https://github.com/CAVEconnectome/PyChunkedGraph/actions/workflows/main.yml) +[![codecov](https://codecov.io/gh/CAVEconnectome/PyChunkedGraph/branch/main/graph/badge.svg)](https://codecov.io/gh/CAVEconnectome/PyChunkedGraph) The PyChunkedGraph is a proofreading and segmentation data management backend powering FlyWire and other proofreading platforms. It builds on an initial agglomeration of supervoxels and facilitates fast and parallel editing of connected components in the agglomeration graph by many users. diff --git a/base.Dockerfile b/base.Dockerfile deleted file mode 100644 index b5123e137..000000000 --- a/base.Dockerfile +++ /dev/null @@ -1,70 +0,0 @@ -ARG PYTHON_VERSION=3.11 -ARG BASE_IMAGE=tiangolo/uwsgi-nginx-flask:python${PYTHON_VERSION} - - -###################################################### -# Build Image - PCG dependencies -###################################################### -FROM ${BASE_IMAGE} AS pcg-build -ENV PATH="/root/miniconda3/bin:${PATH}" -ENV CONDA_ENV="pychunkedgraph" - -# Setup Miniconda -RUN apt-get update && apt-get install build-essential wget -y -RUN wget \ - https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ - && mkdir /root/.conda \ - && bash Miniconda3-latest-Linux-x86_64.sh -b \ - && rm -f Miniconda3-latest-Linux-x86_64.sh \ - && conda update conda - -# Install PCG dependencies - especially graph-tool -# Note: uwsgi has trouble with pip and python3.11, so adding this with conda, too -COPY requirements.txt . -COPY requirements.yml . -COPY requirements-dev.txt . -RUN conda env create -n ${CONDA_ENV} -f requirements.yml - -# Shrink conda environment into portable non-conda env -RUN conda install conda-pack -c conda-forge - -RUN conda-pack -n ${CONDA_ENV} --ignore-missing-files -o /tmp/env.tar \ - && mkdir -p /app/venv \ - && cd /app/venv \ - && tar xf /tmp/env.tar \ - && rm /tmp/env.tar -RUN /app/venv/bin/conda-unpack - - -###################################################### -# Build Image - Bigtable Emulator (without Google SDK) -###################################################### -FROM golang:bullseye as bigtable-emulator-build -RUN mkdir -p /usr/src -WORKDIR /usr/src -ENV GOOGLE_CLOUD_GO_VERSION bigtable/v1.19.0 -RUN apt-get update && apt-get install git -y -RUN git clone --depth=1 --branch="$GOOGLE_CLOUD_GO_VERSION" https://github.com/googleapis/google-cloud-go.git . \ - && cd bigtable \ - && go install -v ./cmd/emulator - - -###################################################### -# Production Image -###################################################### -FROM ${BASE_IMAGE} -ENV VIRTUAL_ENV=/app/venv -ENV PATH="$VIRTUAL_ENV/bin:$PATH" - -COPY --from=pcg-build /app/venv /app/venv -COPY --from=bigtable-emulator-build /go/bin/emulator /app/venv/bin/cbtemulator -COPY override/gcloud /app/venv/bin/gcloud -COPY override/timeout.conf /etc/nginx/conf.d/timeout.conf -COPY override/supervisord.conf /etc/supervisor/conf.d/supervisord.conf -# Hack to get zstandard from PyPI - remove if conda-forge linked lib issue is resolved -RUN pip install --no-cache-dir --no-deps --force-reinstall zstandard==0.21.0 -COPY . /app - -RUN mkdir -p /home/nginx/.cloudvolume/secrets \ - && chown -R nginx /home/nginx \ - && usermod -d /home/nginx -s /bin/bash nginx diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000..92e9570d2 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,17 @@ +codecov: + require_ci_to_pass: true + +coverage: + status: + project: + default: + target: auto + threshold: 1% + patch: + default: + target: 1% + +comment: + layout: "reach,diff,flags,files" + behavior: default + require_changes: false diff --git a/docs/Readme.md b/docs/Readme.md index 45799326e..c05ad6979 100644 --- a/docs/Readme.md +++ b/docs/Readme.md @@ -10,7 +10,7 @@ pip install -r requirements.txt ## Multiprocessing -Check out [multiprocessing.md](https://github.com/seung-lab/PyChunkedGraph/blob/master/src/pychunkedgraph/multiprocessing.md) for how to use the multiprocessing tools implemented for the ChunkedGraph +Check out [multiprocessing.md](https://github.com/CAVEconnectome/PyChunkedGraph/blob/master/src/pychunkedgraph/multiprocessing.md) for how to use the multiprocessing tools implemented for the ChunkedGraph ## Credentials @@ -30,7 +30,7 @@ The current version of the ChunkedGraph contains supervoxels from `gs://nkem/bas ### Building the graph -[buildgraph.md](https://github.com/seung-lab/PyChunkedGraph/blob/master/src/pychunkedgraph/buildgraph.md) explains how to build a graph from scratch. +[buildgraph.md](https://github.com/CAVEconnectome/PyChunkedGraph/blob/master/src/pychunkedgraph/buildgraph.md) explains how to build a graph from scratch. ### Initialization diff --git a/docs/edges.md b/docs/edges.md index 9dc15a98b..ccda4205b 100644 --- a/docs/edges.md +++ b/docs/edges.md @@ -2,7 +2,7 @@ PyChunkedgraph uses protobuf for serialization and zstandard for compression. -Edges and connected components per chunk are stored using the protobuf definitions in [`pychunkedgraph.io.protobuf`](https://github.com/seung-lab/PyChunkedGraph/pychunkedgraph/io/protobuf/chunkEdges.proto). +Edges and connected components per chunk are stored using the protobuf definitions in [`pychunkedgraph.io.protobuf`](https://github.com/CAVEconnectome/PyChunkedGraph/pychunkedgraph/io/protobuf/chunkEdges.proto). This format is a result of performance tests. It provided the best tradeoff between deserialzation speed and storage size. diff --git a/docs/edges_and_components.md b/docs/edges_and_components.md index 9dc15a98b..ccda4205b 100644 --- a/docs/edges_and_components.md +++ b/docs/edges_and_components.md @@ -2,7 +2,7 @@ PyChunkedgraph uses protobuf for serialization and zstandard for compression. -Edges and connected components per chunk are stored using the protobuf definitions in [`pychunkedgraph.io.protobuf`](https://github.com/seung-lab/PyChunkedGraph/pychunkedgraph/io/protobuf/chunkEdges.proto). +Edges and connected components per chunk are stored using the protobuf definitions in [`pychunkedgraph.io.protobuf`](https://github.com/CAVEconnectome/PyChunkedGraph/pychunkedgraph/io/protobuf/chunkEdges.proto). This format is a result of performance tests. It provided the best tradeoff between deserialzation speed and storage size. diff --git a/docs/segmentation_preprocessing.md b/docs/segmentation_preprocessing.md index 3fb1bf59b..028419a3a 100644 --- a/docs/segmentation_preprocessing.md +++ b/docs/segmentation_preprocessing.md @@ -32,10 +32,10 @@ There are three types of edges: 2. `cross_chunk`: edges between parts of "the same" supervoxel in the unchunked segmentation that has been split across chunk boundary 3. `between_chunk`: edges between supervoxels across chunks -Every pair of touching supervoxels has an edge between them. All edges are stored using [protobuf](https://github.com/seung-lab/PyChunkedGraph/blob/pcgv2/pychunkedgraph/io/protobuf/chunkEdges.proto). During ingest only edges of type 2. and 3. are copied into BigTable, whereas edges of type 1. are always read from storage to reduce cost. Similar to the supervoxel segmentation, we recommed storing these on GCloud in the same zone the ChunkedGraph server will be deployed in to reduce latency. +Every pair of touching supervoxels has an edge between them. All edges are stored using [protobuf](https://github.com/CAVEconnectome/PyChunkedGraph/blob/pcgv2/pychunkedgraph/io/protobuf/chunkEdges.proto). During ingest only edges of type 2. and 3. are copied into BigTable, whereas edges of type 1. are always read from storage to reduce cost. Similar to the supervoxel segmentation, we recommed storing these on GCloud in the same zone the ChunkedGraph server will be deployed in to reduce latency. To denote which edges form a connected component within a chunk, a component mapping needs to be created. This mapping is only used during ingest. -More details on how to create these protobuf files can be found [here](https://github.com/seung-lab/PyChunkedGraph/blob/pcgv2/docs/storage.md). +More details on how to create these protobuf files can be found [here](https://github.com/CAVEconnectome/PyChunkedGraph/blob/pcgv2/docs/storage.md). diff --git a/pychunkedgraph/__init__.py b/pychunkedgraph/__init__.py index e615ea2b7..28c0d26dc 100644 --- a/pychunkedgraph/__init__.py +++ b/pychunkedgraph/__init__.py @@ -1 +1,63 @@ -__version__ = "2.21.1" +__version__ = "3.1.7" + +import sys +import warnings +import logging as stdlib_logging # Use alias to avoid conflict with pychunkedgraph.logging + +# Suppress annoying warning from python_jsonschema_objects dependency +warnings.filterwarnings( + "ignore", message="Schema id not specified", module="python_jsonschema_objects" +) + +# Export logging levels for convenience +DEBUG = stdlib_logging.DEBUG +INFO = stdlib_logging.INFO +WARNING = stdlib_logging.WARNING +ERROR = stdlib_logging.ERROR + +# Set up library-level logger with NullHandler (Python logging best practice) +stdlib_logging.getLogger(__name__).addHandler(stdlib_logging.NullHandler()) + + +def configure_logging(level=stdlib_logging.INFO, format_str=None, stream=None): + """ + Configure logging for pychunkedgraph. Call this to enable log output. + + Works in Jupyter notebooks and scripts. + + Args: + level: Logging level (default: INFO). Use pychunkedgraph.DEBUG, .INFO, .WARNING, .ERROR + format_str: Custom format string (optional) + stream: Output stream (default: sys.stdout for Jupyter compatibility) + + Example: + import pychunkedgraph + pychunkedgraph.configure_logging() # Enable INFO level logging + pychunkedgraph.configure_logging(pychunkedgraph.DEBUG) # Enable DEBUG level + """ + if format_str is None: + format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + if stream is None: + stream = sys.stdout + + # Get root logger for pychunkedgraph + logger = stdlib_logging.getLogger(__name__) + logger.setLevel(level) + + # Remove existing handlers and add fresh StreamHandler + # This allows reconfiguring with different levels/formats + for h in logger.handlers[:]: + if isinstance(h, stdlib_logging.StreamHandler) and not isinstance( + h, stdlib_logging.NullHandler + ): + logger.removeHandler(h) + + handler = stdlib_logging.StreamHandler(stream) + handler.setLevel(level) + handler.setFormatter(stdlib_logging.Formatter(format_str)) + logger.addHandler(handler) + + return logger + + +configure_logging() diff --git a/pychunkedgraph/app/__init__.py b/pychunkedgraph/app/__init__.py index 3e938628b..262849258 100644 --- a/pychunkedgraph/app/__init__.py +++ b/pychunkedgraph/app/__init__.py @@ -105,6 +105,8 @@ def configure_app(app): with app.app_context(): from ..ingest.rq_cli import init_rq_cmds from ..ingest.cli import init_ingest_cmds + from ..ingest.cli_upgrade import init_upgrade_cmds init_rq_cmds(app) init_ingest_cmds(app) + init_upgrade_cmds(app) diff --git a/pychunkedgraph/app/common.py b/pychunkedgraph/app/common.py index 237e11fc0..7562762de 100644 --- a/pychunkedgraph/app/common.py +++ b/pychunkedgraph/app/common.py @@ -4,7 +4,7 @@ import json import time import traceback -from datetime import datetime +from datetime import datetime, timezone from cloudvolume import compression from google.api_core.exceptions import GoogleAPIError @@ -50,7 +50,7 @@ def _log_request(response_time): def before_request(): current_app.request_start_time = time.time() - current_app.request_start_date = datetime.utcnow() + current_app.request_start_date = datetime.now(timezone.utc) try: current_app.user_id = g.auth_user["id"] except (AttributeError, KeyError): diff --git a/pychunkedgraph/app/meshing/common.py b/pychunkedgraph/app/meshing/common.py index 8f1a0c20a..10306543a 100644 --- a/pychunkedgraph/app/meshing/common.py +++ b/pychunkedgraph/app/meshing/common.py @@ -4,8 +4,6 @@ import threading import numpy as np -import redis -from rq import Queue, Connection, Retry from flask import Response, current_app, jsonify, make_response, request from pychunkedgraph import __version__ @@ -145,37 +143,15 @@ def _check_post_options(cg, resp, data, seg_ids): def handle_remesh(table_id): current_app.request_type = "remesh_enque" current_app.table_id = table_id - is_priority = request.args.get("priority", True, type=str2bool) - is_redisjob = request.args.get("use_redis", False, type=str2bool) - new_lvl2_ids = json.loads(request.data)["new_lvl2_ids"] - - if is_redisjob: - with Connection(redis.from_url(current_app.config["REDIS_URL"])): - - if is_priority: - retry = Retry(max=3, interval=[1, 10, 60]) - queue_name = "mesh-chunks" - else: - retry = Retry(max=3, interval=[60, 60, 60]) - queue_name = "mesh-chunks-low-priority" - q = Queue(queue_name, retry=retry, default_timeout=1200) - task = q.enqueue(meshing_tasks.remeshing, table_id, new_lvl2_ids) - - response_object = {"status": "success", "data": {"task_id": task.get_id()}} - - return jsonify(response_object), 202 - else: - new_lvl2_ids = np.array(new_lvl2_ids, dtype=np.uint64) - cg = app_utils.get_cg(table_id) - - if len(new_lvl2_ids) > 0: - t = threading.Thread( - target=_remeshing, args=(cg.get_serialized_info(), new_lvl2_ids) - ) - t.start() - - return Response(status=202) + new_lvl2_ids = np.array(new_lvl2_ids, dtype=np.uint64) + cg = app_utils.get_cg(table_id) + if len(new_lvl2_ids) > 0: + t = threading.Thread( + target=_remeshing, args=(cg.get_serialized_info(), new_lvl2_ids) + ) + t.start() + return Response(status=202) def _remeshing(serialized_cg_info, lvl2_nodes): diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index 70642c9ce..9f4ada9d7 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -3,7 +3,7 @@ import json import os import time -from datetime import datetime +from datetime import datetime, timezone from functools import reduce from collections import deque, defaultdict @@ -229,7 +229,7 @@ def handle_find_minimal_covering_nodes(table_id, is_binary=True): node_queue[layer].clear() # Return the download list - download_list = np.concatenate([np.array(list(v)) for v in download_list.values()]) + download_list = np.concatenate([np.array(list(v), dtype=np.uint64) for v in download_list.values()]) return download_list @@ -383,6 +383,7 @@ def handle_merge(table_id, allow_same_segment_merge=False): source_coords=coords[:1], sink_coords=coords[1:], allow_same_segment_merge=allow_same_segment_merge, + do_sanity_check=True, ) except cg_exceptions.LockingError as e: @@ -450,6 +451,7 @@ def handle_split(table_id): source_coords=coords[node_idents == 0], sink_coords=coords[node_idents == 1], mincut=mincut, + do_sanity_check=True, ) except cg_exceptions.LockingError as e: raise cg_exceptions.InternalServerError(e) @@ -600,7 +602,7 @@ def all_user_operations( target_user_id = request.args.get("user_id", None) start_time = _parse_timestamp("start_time", 0, return_datetime=True) - end_time = _parse_timestamp("end_time", datetime.utcnow(), return_datetime=True) + end_time = _parse_timestamp("end_time", datetime.now(timezone.utc), return_datetime=True) # Call ChunkedGraph cg = app_utils.get_cg(table_id) @@ -610,7 +612,7 @@ def all_user_operations( valid_entry_ids = [] timestamp_list = [] - undone_ids = np.array([]) + undone_ids = np.array([], dtype=np.uint64) entry_ids = np.sort(list(log_rows.keys())) for entry_id in entry_ids: @@ -688,7 +690,7 @@ def handle_children(table_id, parent_id): if layer > 1: children = cg.get_children(parent_id) else: - children = np.array([]) + children = np.array([], dtype=np.uint64) return children @@ -791,8 +793,8 @@ def handle_subgraph(table_id, root_id, only_internal_edges=True): supervoxels = np.concatenate( [agg.supervoxels for agg in l2id_agglomeration_d.values()] ) - mask0 = np.in1d(edges.node_ids1, supervoxels) - mask1 = np.in1d(edges.node_ids2, supervoxels) + mask0 = np.isin(edges.node_ids1, supervoxels) + mask1 = np.isin(edges.node_ids2, supervoxels) edges = edges[mask0 & mask1] return edges diff --git a/pychunkedgraph/debug/cross_edge_test.py b/pychunkedgraph/debug/cross_edge_test.py deleted file mode 100644 index 25bacfa0b..000000000 --- a/pychunkedgraph/debug/cross_edge_test.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -from datetime import datetime -import numpy as np - -from pychunkedgraph.graph import chunkedgraph -from pychunkedgraph.graph import attributes - -#os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/home/svenmd/.cloudvolume/secrets/google-secret.json" - -layer = 2 -n_chunks = 1000 -n_segments_per_chunk = 200 -# timestamp = datetime.datetime.fromtimestamp(1588875769) -timestamp = datetime.utcnow() - -cg = chunkedgraph.ChunkedGraph(graph_id="pinky_nf_v2") - -np.random.seed(42) - -node_ids = [] -for _ in range(n_chunks): - c_x = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][0]) - c_y = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][1]) - c_z = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][2]) - - chunk_id = cg.get_chunk_id(layer=layer, x=c_x, y=c_y, z=c_z) - - max_segment_id = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) - - if max_segment_id < 10: - continue - - segment_ids = np.random.randint(1, max_segment_id, n_segments_per_chunk) - - for segment_id in segment_ids: - node_ids.append(cg.get_node_id(np.uint64(segment_id), np.uint64(chunk_id))) - -rows = cg.client.read_nodes(node_ids=node_ids, end_time=timestamp, - properties=attributes.Hierarchy.Parent) -valid_node_ids = [] -non_valid_node_ids = [] -for k in rows.keys(): - if len(rows[k]) > 0: - valid_node_ids.append(k) - else: - non_valid_node_ids.append(k) - -cc_edges = cg.get_atomic_cross_edges(valid_node_ids) -cc_ids = np.unique(np.concatenate([np.concatenate(list(v.values())) for v in list(cc_edges.values()) if len(v.values())])) - -roots = cg.get_roots(cc_ids) -root_dict = dict(zip(cc_ids, roots)) -root_dict_vec = np.vectorize(root_dict.get) - -for k in cc_edges: - if len(cc_edges[k]) == 0: - continue - local_ids = np.unique(np.concatenate(list(cc_edges[k].values()))) - - assert len(np.unique(root_dict_vec(local_ids))) \ No newline at end of file diff --git a/pychunkedgraph/debug/existence_test.py b/pychunkedgraph/debug/existence_test.py deleted file mode 100644 index 757d3d542..000000000 --- a/pychunkedgraph/debug/existence_test.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -from datetime import datetime -import numpy as np - -from pychunkedgraph.graph import chunkedgraph -from pychunkedgraph.graph import attributes - -#os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/home/svenmd/.cloudvolume/secrets/google-secret.json" - -layer = 2 -n_chunks = 100 -n_segments_per_chunk = 200 -# timestamp = datetime.datetime.fromtimestamp(1588875769) -timestamp = datetime.utcnow() - -cg = chunkedgraph.ChunkedGraph(graph_id="pinky_nf_v2") - -np.random.seed(42) - -node_ids = [] -for _ in range(n_chunks): - c_x = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][0]) - c_y = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][1]) - c_z = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][2]) - - chunk_id = cg.get_chunk_id(layer=layer, x=c_x, y=c_y, z=c_z) - - max_segment_id = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) - - if max_segment_id < 10: - continue - - segment_ids = np.random.randint(1, max_segment_id, n_segments_per_chunk) - - for segment_id in segment_ids: - node_ids.append(cg.get_node_id(np.uint64(segment_id), np.uint64(chunk_id))) - -rows = cg.client.read_nodes(node_ids=node_ids, end_time=timestamp, - properties=attributes.Hierarchy.Parent) -valid_node_ids = [] -non_valid_node_ids = [] -for k in rows.keys(): - if len(rows[k]) > 0: - valid_node_ids.append(k) - else: - non_valid_node_ids.append(k) - -roots = cg.get_roots(valid_node_ids, time_stamp=timestamp) - -roots = [] -try: - roots = cg.get_roots(valid_node_ids) - assert len(roots) == len(valid_node_ids) - print(f"ALL {len(roots)} have been successful!") -except: - print("At least one node failed. Checking nodes one by one now") - -if len(roots) != len(valid_node_ids): - log_dict = {} - success_dict = {} - for node_id in valid_node_ids: - try: - root = cg.get_root(node_id, time_stamp=timestamp) - print(f"Success: {node_id} from chunk {cg.get_chunk_id(node_id)}") - success_dict[node_id] = True - except Exception as e: - print(f"{node_id} from chunk {cg.get_chunk_id(node_id)} failed with {e}") - success_dict[node_id] = False - - t_id = node_id - - while t_id is not None: - last_working_chunk = cg.get_chunk_id(t_id) - t_id = cg.get_parent(t_id) - - print(f"Failed on layer {cg.get_chunk_layer(last_working_chunk)} in chunk {last_working_chunk}") - log_dict[node_id] = last_working_chunk - diff --git a/pychunkedgraph/debug/family_test.py b/pychunkedgraph/debug/family_test.py deleted file mode 100644 index 198351e74..000000000 --- a/pychunkedgraph/debug/family_test.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -from datetime import datetime -import numpy as np - -from pychunkedgraph.graph import chunkedgraph -from pychunkedgraph.graph import attributes - -# os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/home/svenmd/.cloudvolume/secrets/google-secret.json" - -layers = [2, 3, 4, 5, 6, 7] -n_chunks = 10 -n_segments_per_chunk = 200 -# timestamp = datetime.datetime.fromtimestamp(1588875769) -timestamp = datetime.utcnow() - -cg = chunkedgraph.ChunkedGraph(graph_id="pinky_nf_v2") - -np.random.seed(42) - -node_ids = [] - -for layer in layers: - for _ in range(n_chunks): - c_x = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][0]) - c_y = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][1]) - c_z = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][2]) - - chunk_id = cg.get_chunk_id(layer=layer, x=c_x, y=c_y, z=c_z) - - max_segment_id = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) - - if max_segment_id < 10: - continue - - segment_ids = np.random.randint(1, max_segment_id, n_segments_per_chunk) - - for segment_id in segment_ids: - node_ids.append(cg.get_node_id(np.uint64(segment_id), np.uint64(chunk_id))) - -rows = cg.client.read_nodes(node_ids=node_ids, end_time=timestamp, - properties=attributes.Hierarchy.Parent) -valid_node_ids = [] -non_valid_node_ids = [] -for k in rows.keys(): - if len(rows[k]) > 0: - valid_node_ids.append(k) - else: - non_valid_node_ids.append(k) - -parents = cg.get_parents(valid_node_ids, time_stamp=timestamp) -children_dict = cg.get_children(parents) - -for child, parent in zip(valid_node_ids, parents): - assert child in children_dict[parent] \ No newline at end of file diff --git a/pychunkedgraph/debug/profiler.py b/pychunkedgraph/debug/profiler.py new file mode 100644 index 000000000..b74ddac76 --- /dev/null +++ b/pychunkedgraph/debug/profiler.py @@ -0,0 +1,121 @@ +from typing import Dict +from typing import List +from typing import Tuple + +import os +import time +from collections import defaultdict +from contextlib import contextmanager + + +class HierarchicalProfiler: + """ + Hierarchical profiler for detailed timing breakdowns. + Tracks timing at multiple levels and prints a breakdown at the end. + """ + + def __init__(self, enabled: bool = True): + self.enabled = enabled + self.timings: Dict[str, List[float]] = defaultdict(list) + self.call_counts: Dict[str, int] = defaultdict(int) + self.stack: List[Tuple[str, float]] = [] + self.current_path: List[str] = [] + + @contextmanager + def profile(self, name: str): + """Context manager for profiling a code block.""" + if not self.enabled: + yield + return + + full_path = ".".join(self.current_path + [name]) + self.current_path.append(name) + start_time = time.perf_counter() + + try: + yield + finally: + elapsed = time.perf_counter() - start_time + self.timings[full_path].append(elapsed) + self.call_counts[full_path] += 1 + self.current_path.pop() + + def print_report(self, operation_id=None): + """Print a detailed timing breakdown.""" + if not self.enabled or not self.timings: + return + + print("\n" + "=" * 80) + print( + f"PROFILER REPORT{f' (operation_id={operation_id})' if operation_id else ''}" + ) + print("=" * 80) + + # Group by depth level + by_depth: Dict[int, List[Tuple[str, float, int]]] = defaultdict(list) + for path, times in self.timings.items(): + depth = path.count(".") + total_time = sum(times) + count = self.call_counts[path] + by_depth[depth].append((path, total_time, count)) + + # Sort each level by total time + for depth in sorted(by_depth.keys()): + items = sorted(by_depth[depth], key=lambda x: -x[1]) + for path, total_time, count in items: + indent = " " * depth + avg_time = total_time / count if count > 0 else 0 + if count > 1: + print( + f"{indent}{path}: {total_time*1000:.2f}ms total " + f"({count} calls, {avg_time*1000:.2f}ms avg)" + ) + else: + print(f"{indent}{path}: {total_time*1000:.2f}ms") + + # Print summary + print("-" * 80) + top_level_total = sum( + sum(times) for path, times in self.timings.items() if "." not in path + ) + print(f"Total top-level time: {top_level_total*1000:.2f}ms") + + # Print top 10 slowest operations + print("\nTop 10 slowest operations:") + all_ops = [ + (path, sum(times), self.call_counts[path]) + for path, times in self.timings.items() + ] + all_ops.sort(key=lambda x: -x[1]) + for i, (path, total_time, count) in enumerate(all_ops[:10]): + pct = (total_time / top_level_total * 100) if top_level_total > 0 else 0 + print(f" {i+1}. {path}: {total_time*1000:.2f}ms ({pct:.1f}%)") + + print("=" * 80 + "\n") + + def reset(self): + """Reset all timing data.""" + self.timings.clear() + self.call_counts.clear() + self.stack.clear() + self.current_path.clear() + + +# Global profiler instance - enable via environment variable +PROFILER_ENABLED = os.environ.get("PCG_PROFILER_ENABLED", "0") == "1" +_profiler: HierarchicalProfiler = None + + +def get_profiler() -> HierarchicalProfiler: + """Get or create the global profiler instance.""" + global _profiler + if _profiler is None: + _profiler = HierarchicalProfiler(enabled=PROFILER_ENABLED) + return _profiler + + +def reset_profiler(): + """Reset the global profiler.""" + global _profiler + if _profiler is not None: + _profiler.reset() diff --git a/pychunkedgraph/debug/utils.py b/pychunkedgraph/debug/utils.py index 179f50aef..ad12103b2 100644 --- a/pychunkedgraph/debug/utils.py +++ b/pychunkedgraph/debug/utils.py @@ -1,7 +1,8 @@ +# pylint: disable=invalid-name, missing-docstring, bare-except, unidiomatic-typecheck + import numpy as np -from ..graph import ChunkedGraph -from ..graph.utils.basetypes import NODE_ID +from pychunkedgraph.graph.meta import ChunkedGraphMeta, GraphConfig def print_attrs(d): @@ -16,28 +17,59 @@ def print_attrs(d): print(v) -def print_node( - cg: ChunkedGraph, - node: NODE_ID, - indent: int = 0, - stop_layer: int = 2, -) -> None: +def print_node(cg, node: np.uint64, indent: int = 0, stop_layer: int = 2) -> None: children = cg.get_children(node) print(f"{' ' * indent}{node}[{len(children)}]") if cg.get_chunk_layer(node) <= stop_layer: return for child in children: - print_node(cg, child, indent=indent + 1, stop_layer=stop_layer) - - -def get_l2children(cg: ChunkedGraph, node: NODE_ID) -> np.ndarray: - nodes = np.array([node], dtype=NODE_ID) - layers = cg.get_chunk_layers(nodes) - assert np.all(layers > 2), "nodes must be at layers > 2" - l2children = [] - while nodes.size: - children = cg.get_children(nodes, flatten=True) - layers = cg.get_chunk_layers(children) - l2children.append(children[layers == 2]) - nodes = children[layers > 2] - return np.concatenate(l2children) + print_node(cg, child, indent=indent + 4, stop_layer=stop_layer) + + +def sanity_check(cg, new_roots, operation_id): + """ + Check for duplicates in hierarchy, useful for debugging. + """ + # print(f"{len(new_roots)} new ids from {operation_id}") + l2c_d = {} + for new_root in new_roots: + l2c_d[new_root] = cg.get_l2children([new_root]) + success = True + for k, v in l2c_d.items(): + success = success and (len(v) == np.unique(v).size) + # print(f"{k}: {np.unique(v).size}, {len(v)}") + if not success: + raise RuntimeError(f"{operation_id}: some ids are not valid.") + + +def sanity_check_single(cg, node, operation_id): + v = cg.get_l2children([node]) + msg = f"invalid node {node}:" + msg += f" found {len(v)} l2 ids, must be {np.unique(v).size}" + assert np.unique(v).size == len(v), f"{msg}, from {operation_id}." + return v + + +def update_graph_id(cg, new_graph_id:str): + old_gc = cg.meta.graph_config._asdict() + old_gc["ID"] = new_graph_id + new_gc = GraphConfig(**old_gc) + new_meta = ChunkedGraphMeta(new_gc, cg.meta.data_source, cg.meta.custom_data) + cg.update_meta(new_meta, overwrite=True) + + +def get_random_l1_ids(cg, n_chunks=100, n_per_chunk=10, seed=None): + """Generate random layer 1 IDs from different chunks.""" + if seed: + np.random.seed(seed) + bounds = cg.meta.layer_chunk_bounds[2] + ids = [] + for _ in range(n_chunks): + cx, cy, cz = [np.random.randint(0, b) for b in bounds] + chunk_id = cg.get_chunk_id(layer=2, x=cx, y=cy, z=cz) + max_seg = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) + if max_seg < 2: + continue + for seg in np.random.randint(1, max_seg + 1, n_per_chunk): + ids.append(cg.get_node_id(np.uint64(seg), np.uint64(chunk_id))) + return np.array(ids, dtype=np.uint64) diff --git a/pychunkedgraph/graph/attributes.py b/pychunkedgraph/graph/attributes.py index 3e48d204a..6b7a277f0 100644 --- a/pychunkedgraph/graph/attributes.py +++ b/pychunkedgraph/graph/attributes.py @@ -1,6 +1,9 @@ +# pylint: disable=invalid-name, missing-docstring, protected-access, raise-missing-from + # TODO design to use these attributes across different clients # `family_id` is specific to bigtable +from enum import Enum from typing import NamedTuple from .utils import serializers @@ -101,7 +104,7 @@ class Connectivity: serializer=serializers.NumPyArray(dtype=basetypes.EDGE_AREA), ) - CrossChunkEdge = _AttributeArray( + AtomicCrossChunkEdge = _AttributeArray( pattern=b"atomic_cross_edges_%d", family_id="3", serializer=serializers.NumPyArray( @@ -109,12 +112,26 @@ class Connectivity: ), ) - FakeEdges = _Attribute( + CrossChunkEdge = _AttributeArray( + pattern=b"cross_edges_%d", + family_id="4", + serializer=serializers.NumPyArray( + dtype=basetypes.NODE_ID, shape=(-1, 2), compression_level=22 + ), + ) + + FakeEdgesCF3 = _Attribute( key=b"fake_edges", family_id="3", serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)), ) + FakeEdges = _Attribute( + key=b"fake_edges", + family_id="4", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)), + ) + class Hierarchy: Child = _Attribute( @@ -143,6 +160,12 @@ class Hierarchy: serializer=serializers.NumPyValue(dtype=basetypes.NODE_ID), ) + # track when nodes became stale, required for migration + # will be eventually deleted by GC rule for column family_id 3. + StaleTimeStamp = _Attribute( + key=b"stale_ts", family_id="3", serializer=serializers.Pickle() + ) + class GraphMeta: key = b"meta" @@ -157,8 +180,6 @@ class GraphVersion: class OperationLogs: key = b"ioperations" - from enum import Enum - class StatusCodes(Enum): SUCCESS = 0 # all is well, new changes persisted CREATED = 1 # log record created in storage diff --git a/pychunkedgraph/graph/cache.py b/pychunkedgraph/graph/cache.py index f60b6ca92..011f4099e 100644 --- a/pychunkedgraph/graph/cache.py +++ b/pychunkedgraph/graph/cache.py @@ -1,6 +1,9 @@ +# pylint: disable=invalid-name, missing-docstring, import-outside-toplevel """ Cache nodes, parents, children and cross edges. """ +import traceback +from collections import defaultdict as defaultd from sys import maxsize from datetime import datetime @@ -30,28 +33,84 @@ def __init__(self, cg): self._parent_vec = np.vectorize(self.parent, otypes=[np.uint64]) self._children_vec = np.vectorize(self.children, otypes=[np.ndarray]) - self._atomic_cross_edges_vec = np.vectorize( - self.atomic_cross_edges, otypes=[dict] + self._cross_chunk_edges_vec = np.vectorize( + self.cross_chunk_edges, otypes=[dict] ) # no limit because we don't want to lose new IDs self.parents_cache = LRUCache(maxsize=maxsize) self.children_cache = LRUCache(maxsize=maxsize) - self.atomic_cx_edges_cache = LRUCache(maxsize=maxsize) + self.cross_chunk_edges_cache = LRUCache(maxsize=maxsize) + + self.new_ids = set() + + # Stats tracking for cache hits/misses + self.stats = { + "parents": {"hits": 0, "misses": 0, "calls": 0}, + "children": {"hits": 0, "misses": 0, "calls": 0}, + "cross_chunk_edges": {"hits": 0, "misses": 0, "calls": 0}, + } + # Track where calls/misses come from + self.sources = defaultd(lambda: defaultd(lambda: {"calls": 0, "misses": 0})) + + def _get_caller(self, skip_frames=2): + """Get caller info (filename:line:function).""" + stack = traceback.extract_stack() + # Skip frames: _get_caller, the cache method, and go to actual caller + if len(stack) > skip_frames: + frame = stack[-(skip_frames + 1)] + return f"{frame.filename.split('/')[-1]}:{frame.lineno}:{frame.name}" + return "unknown" + + def _record_call(self, cache_type, misses=0): + """Record a call and its source.""" + caller = self._get_caller(skip_frames=3) + self.sources[cache_type][caller]["calls"] += 1 + self.sources[cache_type][caller]["misses"] += misses def __len__(self): return ( len(self.parents_cache) + len(self.children_cache) - + len(self.atomic_cx_edges_cache) + + len(self.cross_chunk_edges_cache) ) def clear(self): self.parents_cache.clear() self.children_cache.clear() - self.atomic_cx_edges_cache.clear() + self.cross_chunk_edges_cache.clear() + + def get_stats(self): + """Return stats with hit rates calculated.""" + result = {} + for name, s in self.stats.items(): + total = s["hits"] + s["misses"] + hit_rate = s["hits"] / total if total > 0 else 0 + result[name] = { + **s, + "total": total, + "hit_rate": f"{hit_rate:.1%}", + "sources": dict(self.sources[name]), + } + return result + + def reset_stats(self): + for s in self.stats.values(): + s["hits"] = 0 + s["misses"] = 0 + s["calls"] = 0 + self.sources.clear() def parent(self, node_id: np.uint64, *, time_stamp: datetime = None): + self.stats["parents"]["calls"] += 1 + is_cached = node_id in self.parents_cache + miss_count = 0 if is_cached else 1 + if is_cached: + self.stats["parents"]["hits"] += 1 + else: + self.stats["parents"]["misses"] += 1 + self._record_call("parents", misses=miss_count) + @cached(cache=self.parents_cache, key=lambda node_id: node_id) def parent_decorated(node_id): return self._cg.get_parent(node_id, raw_only=True, time_stamp=time_stamp) @@ -59,6 +118,15 @@ def parent_decorated(node_id): return parent_decorated(node_id) def children(self, node_id): + self.stats["children"]["calls"] += 1 + is_cached = node_id in self.children_cache + miss_count = 0 if is_cached else 1 + if is_cached: + self.stats["children"]["hits"] += 1 + else: + self.stats["children"]["misses"] += 1 + self._record_call("children", misses=miss_count) + @cached(cache=self.children_cache, key=lambda node_id: node_id) def children_decorated(node_id): children = self._cg.get_children(node_id, raw_only=True) @@ -67,33 +135,66 @@ def children_decorated(node_id): return children_decorated(node_id) - def atomic_cross_edges(self, node_id): - @cached(cache=self.atomic_cx_edges_cache, key=lambda node_id: node_id) - def atomic_cross_edges_decorated(node_id): - edges = self._cg.get_atomic_cross_edges( - np.array([node_id], dtype=NODE_ID), raw_only=True + def cross_chunk_edges(self, node_id, *, time_stamp: datetime = None): + self.stats["cross_chunk_edges"]["calls"] += 1 + is_cached = node_id in self.cross_chunk_edges_cache + miss_count = 0 if is_cached else 1 + if is_cached: + self.stats["cross_chunk_edges"]["hits"] += 1 + else: + self.stats["cross_chunk_edges"]["misses"] += 1 + self._record_call("cross_chunk_edges", misses=miss_count) + + @cached(cache=self.cross_chunk_edges_cache, key=lambda node_id: node_id) + def cross_edges_decorated(node_id): + edges = self._cg.get_cross_chunk_edges( + np.array([node_id], dtype=NODE_ID), raw_only=True, time_stamp=time_stamp ) return edges[node_id] - return atomic_cross_edges_decorated(node_id) + return cross_edges_decorated(node_id) - def parents_multiple(self, node_ids: np.ndarray, *, time_stamp: datetime = None): + def parents_multiple( + self, + node_ids: np.ndarray, + *, + time_stamp: datetime = None, + fail_to_zero: bool = False, + ): + node_ids = np.asarray(node_ids, dtype=NODE_ID) if not node_ids.size: return node_ids - mask = np.in1d(node_ids, np.fromiter(self.parents_cache.keys(), dtype=NODE_ID)) + self.stats["parents"]["calls"] += 1 + mask = np.isin(node_ids, np.fromiter(self.parents_cache.keys(), dtype=NODE_ID)) + hits = int(np.sum(mask)) + misses = len(node_ids) - hits + self.stats["parents"]["hits"] += hits + self.stats["parents"]["misses"] += misses + self._record_call("parents", misses=misses) parents = node_ids.copy() parents[mask] = self._parent_vec(node_ids[mask]) parents[~mask] = self._cg.get_parents( - node_ids[~mask], raw_only=True, time_stamp=time_stamp + node_ids[~mask], + raw_only=True, + time_stamp=time_stamp, + fail_to_zero=fail_to_zero, ) + mask = mask | (parents == 0) update(self.parents_cache, node_ids[~mask], parents[~mask]) return parents def children_multiple(self, node_ids: np.ndarray, *, flatten=False): result = {} + node_ids = np.asarray(node_ids, dtype=NODE_ID) if not node_ids.size: return result - mask = np.in1d(node_ids, np.fromiter(self.children_cache.keys(), dtype=NODE_ID)) + self.stats["children"]["calls"] += 1 + mask = np.isin(node_ids, np.fromiter(self.children_cache.keys(), dtype=NODE_ID)) + hits = int(np.sum(mask)) + misses = len(node_ids) - hits + self.stats["children"]["hits"] += hits + self.stats["children"]["misses"] += misses + self._record_call("children", misses=misses) cached_children_ = self._children_vec(node_ids[mask]) result.update({id_: c_ for id_, c_ in zip(node_ids[mask], cached_children_)}) result.update(self._cg.get_children(node_ids[~mask], raw_only=True)) @@ -104,20 +205,33 @@ def children_multiple(self, node_ids: np.ndarray, *, flatten=False): return np.concatenate([*result.values()]) return result - def atomic_cross_edges_multiple(self, node_ids: np.ndarray): + def cross_chunk_edges_multiple( + self, node_ids: np.ndarray, *, time_stamp: datetime = None + ): result = {} + node_ids = np.asarray(node_ids, dtype=NODE_ID) if not node_ids.size: return result - mask = np.in1d( - node_ids, np.fromiter(self.atomic_cx_edges_cache.keys(), dtype=NODE_ID) + self.stats["cross_chunk_edges"]["calls"] += 1 + mask = np.isin( + node_ids, np.fromiter(self.cross_chunk_edges_cache.keys(), dtype=NODE_ID) ) - cached_edges_ = self._atomic_cross_edges_vec(node_ids[mask]) + hits = int(np.sum(mask)) + misses = len(node_ids) - hits + self.stats["cross_chunk_edges"]["hits"] += hits + self.stats["cross_chunk_edges"]["misses"] += misses + self._record_call("cross_chunk_edges", misses=misses) + cached_edges_ = self._cross_chunk_edges_vec(node_ids[mask]) result.update( {id_: edges_ for id_, edges_ in zip(node_ids[mask], cached_edges_)} ) - result.update(self._cg.get_atomic_cross_edges(node_ids[~mask], raw_only=True)) + result.update( + self._cg.get_cross_chunk_edges( + node_ids[~mask], raw_only=True, time_stamp=time_stamp + ) + ) update( - self.atomic_cx_edges_cache, + self.cross_chunk_edges_cache, node_ids[~mask], [result[k] for k in node_ids[~mask]], ) diff --git a/pychunkedgraph/graph/chunkedgraph.py b/pychunkedgraph/graph/chunkedgraph.py index 210bff50b..38a408e92 100644 --- a/pychunkedgraph/graph/chunkedgraph.py +++ b/pychunkedgraph/graph/chunkedgraph.py @@ -1,8 +1,10 @@ -# pylint: disable=invalid-name, missing-docstring, too-many-lines, import-outside-toplevel +# pylint: disable=invalid-name, missing-docstring, too-many-lines, import-outside-toplevel, unsupported-binary-operation import time import typing import datetime +from itertools import chain +from functools import reduce import numpy as np from pychunkedgraph import __version__ @@ -16,14 +18,17 @@ from .client import BackendClientInfo from .client import get_default_client_info from .cache import CacheService -from .meta import ChunkedGraphMeta +from .meta import ChunkedGraphMeta, GraphConfig from .utils import basetypes from .utils import id_helpers +from .utils import serializers from .utils import generic as misc_utils from .edges import Edges from .edges import utils as edge_utils from .chunks import utils as chunk_utils from .chunks import hierarchy as chunk_hierarchy +from .subgraph import get_subgraph_nodes +from .subgraph import get_subgraph_edges_and_leaves class ChunkedGraph: @@ -61,6 +66,16 @@ def __init__( self._cache_service = None self.mock_edges = None # hack for unit tests + # shim to update graph_id in meta for copied graphs + if graph_id != self.graph_id: + gc = self.meta.graph_config._asdict() + gc["ID"] = graph_id + new_meta = ChunkedGraphMeta( + GraphConfig(**gc), self.meta.data_source, self.meta.custom_data + ) + self.update_meta(new_meta, overwrite=True) + self._meta = new_meta + @property def meta(self) -> ChunkedGraphMeta: return self._meta @@ -74,7 +89,7 @@ def version(self) -> str: return self.client.read_graph_version() @property - def client(self) -> base.SimpleClient: + def client(self) -> BigTableClient: return self._client @property @@ -112,13 +127,15 @@ def range_read_chunk( """Read all nodes in a chunk.""" layer = self.get_chunk_layer(chunk_id) root_chunk = layer == self.meta.layer_count - max_node_id = self.id_client.get_max_node_id(chunk_id=chunk_id, root_chunk=root_chunk) + max_id = self.id_client.get_max_node_id( + chunk_id=chunk_id, root_chunk=root_chunk + ) if layer == 1: - max_node_id = chunk_id | self.get_segment_id_limit(chunk_id) # pylint: disable=unsupported-binary-operation + max_id = chunk_id | self.get_segment_id_limit(chunk_id) return self.client.read_nodes( start_id=self.get_node_id(np.uint64(0), chunk_id=chunk_id), - end_id=max_node_id, + end_id=max_id, end_id_inclusive=True, properties=properties, end_time=time_stamp, @@ -197,7 +214,7 @@ def get_parents( end_time=time_stamp, end_time_inclusive=True, ) - if not parent_rows: + if not parent_rows and not fail_to_zero: return types.empty_1d parents = [] @@ -209,6 +226,7 @@ def get_parents( if fail_to_zero: parents.append(0) else: + exc.add_note(f"timestamp: {time_stamp}") raise KeyError from exc parents = np.array(parents, dtype=basetypes.NODE_ID) else: @@ -223,7 +241,9 @@ def get_parents( else: raise KeyError from exc return parents - return self.cache.parents_multiple(node_ids, time_stamp=time_stamp) + return self.cache.parents_multiple( + node_ids, time_stamp=time_stamp, fail_to_zero=fail_to_zero + ) def get_parent( self, @@ -283,97 +303,79 @@ def _get_children_multiple( node_ids=node_ids, properties=attributes.Hierarchy.Child ) return { - x: node_children_d[x][0].value - if x in node_children_d - else types.empty_1d.copy() + x: ( + node_children_d[x][0].value + if x in node_children_d + else types.empty_1d.copy() + ) for x in node_ids } return self.cache.children_multiple(node_ids) - def get_atomic_cross_edges( - self, l2_ids: typing.Iterable, *, raw_only=False - ) -> typing.Dict[np.uint64, typing.Dict[int, typing.Iterable]]: - """Returns cross edges for level 2 IDs.""" + def get_atomic_cross_edges(self, l2_ids: typing.Iterable) -> typing.Dict: + """ + Returns atomic cross edges for level 2 IDs. + A dict of the form `{l2id: {layer: atomic_cross_edges}}`. + """ + node_edges_d_d = self.client.read_nodes( + node_ids=l2_ids, + properties=[ + attributes.Connectivity.AtomicCrossChunkEdge[l] + for l in range(2, max(3, self.meta.layer_count)) + ], + ) + result = {} + for id_ in l2_ids: + try: + result[id_] = { + prop.index: val[0].value.copy() + for prop, val in node_edges_d_d[id_].items() + } + except KeyError: + result[id_] = {} + return result + + def get_cross_chunk_edges( + self, + node_ids: typing.Iterable, + *, + raw_only=False, + all_layers=True, + time_stamp: typing.Optional[datetime.datetime] = None, + ) -> typing.Dict: + """ + Returns cross edges for `node_ids`. + A dict of the form `{node_id: {layer: cross_edges}}`. + """ + time_stamp = misc_utils.get_valid_timestamp(time_stamp) if raw_only or not self.cache: + result = {} + node_ids = np.array(node_ids, dtype=basetypes.NODE_ID) + if node_ids.size == 0: + return result + layers = range(2, max(3, self.meta.layer_count)) + attrs = [attributes.Connectivity.CrossChunkEdge[l] for l in layers] node_edges_d_d = self.client.read_nodes( - node_ids=l2_ids, - properties=[ - attributes.Connectivity.CrossChunkEdge[l] - for l in range(2, max(3, self.meta.layer_count)) - ], + node_ids=node_ids, + properties=attrs, + end_time=time_stamp, + end_time_inclusive=True, ) - result = {} - for id_ in l2_ids: + layers = self.get_chunk_layers(node_ids) + valid_layer = lambda x, y: x >= y + if not all_layers: + valid_layer = lambda x, y: x == y + for layer, id_ in zip(layers, node_ids): try: result[id_] = { prop.index: val[0].value.copy() for prop, val in node_edges_d_d[id_].items() + if valid_layer(prop.index, layer) } except KeyError: result[id_] = {} return result - return self.cache.atomic_cross_edges_multiple(l2_ids) - - def get_cross_chunk_edges( - self, node_ids: typing.Iterable, uplift=True, all_layers=False - ) -> typing.Dict[np.uint64, typing.Dict[int, typing.Iterable]]: - """ - Cross chunk edges for `node_id` at `node_layer`. - The edges are between node IDs at the `node_layer`, not atomic cross edges. - Returns dict {layer_id: cross_edges} - The first layer (>= `node_layer`) with atleast one cross chunk edge. - For current use-cases, other layers are not relevant. - - For performance, only children that lie along chunk boundary are considered. - Cross edges that belong to inner level 2 IDs are subsumed within the chunk. - This is because cross edges are stored only in level 2 IDs. - """ - result = {} - node_ids = np.array(node_ids, dtype=basetypes.NODE_ID) - if not node_ids.size: - return result - - node_l2ids_d = {} - layers_ = self.get_chunk_layers(node_ids) - for l in set(layers_): - node_l2ids_d.update(self._get_bounding_l2_children(node_ids[layers_ == l])) - l2_edges_d_d = self.get_atomic_cross_edges( - np.concatenate(list(node_l2ids_d.values())) - ) - for node_id in node_ids: - l2_edges_ds = [l2_edges_d_d[l2_id] for l2_id in node_l2ids_d[node_id]] - if all_layers: - result[node_id] = edge_utils.concatenate_cross_edge_dicts(l2_edges_ds) - else: - result[node_id] = self._get_min_layer_cross_edges( - node_id, l2_edges_ds, uplift=uplift - ) - return result - - def _get_min_layer_cross_edges( - self, - node_id: basetypes.NODE_ID, - l2id_atomic_cross_edges_ds: typing.Iterable, - uplift=True, - ) -> typing.Dict[int, typing.Iterable]: - """ - Find edges at relevant min_layer >= node_layer. - `l2id_atomic_cross_edges_ds` is a list of atomic cross edges of - level 2 IDs that are descendants of `node_id`. - """ - min_layer, edges = edge_utils.filter_min_layer_cross_edges_multiple( - self.meta, l2id_atomic_cross_edges_ds, self.get_chunk_layer(node_id) - ) - if self.get_chunk_layer(node_id) < min_layer: - # cross edges irrelevant - return {self.get_chunk_layer(node_id): types.empty_2d} - if not uplift: - return {min_layer: edges} - node_root_id = node_id - node_root_id = self.get_root(node_id, stop_layer=min_layer, ceil=False) - edges[:, 0] = node_root_id - edges[:, 1] = self.get_roots(edges[:, 1], stop_layer=min_layer, ceil=False) - return {min_layer: np.unique(edges, axis=0) if edges.size else types.empty_2d} + return self.cache.cross_chunk_edges_multiple(node_ids, time_stamp=time_stamp) def get_roots( self, @@ -384,6 +386,7 @@ def get_roots( stop_layer: int = None, ceil: bool = True, fail_to_zero: bool = False, + raw_only=False, n_tries: int = 1, ) -> typing.Union[np.ndarray, typing.Dict[int, np.ndarray]]: """ @@ -407,7 +410,10 @@ def get_roots( filtered_ids = parent_ids[layer_mask] unique_ids, inverse = np.unique(filtered_ids, return_inverse=True) temp_ids = self.get_parents( - unique_ids, time_stamp=time_stamp, fail_to_zero=fail_to_zero + unique_ids, + time_stamp=time_stamp, + fail_to_zero=fail_to_zero, + raw_only=raw_only, ) if not temp_ids.size: break @@ -462,6 +468,7 @@ def get_root( get_all_parents: bool = False, stop_layer: int = None, ceil: bool = True, + raw_only: bool = False, n_tries: int = 1, ) -> typing.Union[typing.List[np.uint64], np.uint64]: """Takes a node id and returns the associated agglomeration ids.""" @@ -479,7 +486,9 @@ def get_root( for _ in range(n_tries): parent_id = node_id for _ in range(self.get_chunk_layer(node_id), int(stop_layer + 1)): - temp_parent_id = self.get_parent(parent_id, time_stamp=time_stamp) + temp_parent_id = self.get_parent( + parent_id, time_stamp=time_stamp, raw_only=raw_only + ) if temp_parent_id is None: break else: @@ -499,7 +508,7 @@ def get_root( else: time.sleep(0.5) - if self.get_chunk_layer(parent_id) < stop_layer: + if ceil and self.get_chunk_layer(parent_id) < stop_layer: raise exceptions.ChunkedGraphError( f"Cannot find root id {node_id}, {stop_layer}, {time_stamp}" ) @@ -546,22 +555,52 @@ def get_all_parents_dict( ) return dict(zip(self.get_chunk_layers(parent_ids), parent_ids)) + def get_all_parents_dict_multiple(self, node_ids, *, time_stamp=None): + """Batch fetch all parent hierarchies layer by layer.""" + result = {node: {} for node in node_ids} + nodes = np.array(node_ids, dtype=basetypes.NODE_ID) + layers_map = {} + child_parent_map = {} + + while nodes.size > 0: + parents = self.get_parents(nodes, time_stamp=time_stamp) + parent_layers = self.get_chunk_layers(parents) + for node, parent, layer in zip(nodes, parents, parent_layers): + layers_map[parent] = layer + child_parent_map[node] = parent + nodes = parents[parent_layers < self.meta.layer_count] + + for node in node_ids: + current = node + node_result = {} + while True: + try: + parent = child_parent_map[current] + except KeyError: + break + parent_layer = layers_map[parent] + node_result[parent_layer] = parent + current = parent + result[node] = node_result + return result + def get_subgraph( self, node_id_or_ids: typing.Union[np.uint64, typing.Iterable], bbox: typing.Optional[typing.Sequence[typing.Sequence[int]]] = None, bbox_is_coordinate: bool = False, - return_layers: typing.List = [2], + return_layers: typing.List = None, nodes_only: bool = False, edges_only: bool = False, leaves_only: bool = False, return_flattened: bool = False, - ) -> typing.Tuple[typing.Dict, typing.Dict, Edges]: + ) -> typing.Tuple[typing.Dict, typing.Tuple[Edges]]: """ Generic subgraph method. """ - from .subgraph import get_subgraph_nodes - from .subgraph import get_subgraph_edges_and_leaves + + if return_layers is None: + return_layers = [2] if nodes_only: return get_subgraph_nodes( @@ -581,7 +620,7 @@ def get_subgraph_nodes( node_id_or_ids: typing.Union[np.uint64, typing.Iterable], bbox: typing.Optional[typing.Sequence[typing.Sequence[int]]] = None, bbox_is_coordinate: bool = False, - return_layers: typing.List = [2], + return_layers: typing.List = None, serializable: bool = False, return_flattened: bool = False, ) -> typing.Tuple[typing.Dict, typing.Dict, Edges]: @@ -589,7 +628,8 @@ def get_subgraph_nodes( Get the children of `node_ids` that are at each of return_layers within the specified bounding box. """ - from .subgraph import get_subgraph_nodes + if return_layers is None: + return_layers = [2] return get_subgraph_nodes( self, @@ -610,8 +650,6 @@ def get_subgraph_edges( """ Get the atomic edges of the `node_ids` within the specified bounding box. """ - from .subgraph import get_subgraph_edges_and_leaves - return get_subgraph_edges_and_leaves( self, node_id_or_ids, bbox, bbox_is_coordinate, True, False ) @@ -625,8 +663,6 @@ def get_subgraph_leaves( """ Get the supervoxels of the `node_ids` within the specified bounding box. """ - from .subgraph import get_subgraph_edges_and_leaves - return get_subgraph_edges_and_leaves( self, node_id_or_ids, bbox, bbox_is_coordinate, False, True ) @@ -644,20 +680,37 @@ def get_fake_edges( ) for id_, val in fake_edges_d.items(): edges = np.concatenate( - [np.array(e.value, dtype=basetypes.NODE_ID) for e in val] + [np.asarray(e.value, dtype=basetypes.NODE_ID) for e in val] ) - result[id_] = Edges(edges[:, 0], edges[:, 1], fake_edges=True) + result[id_] = Edges(edges[:, 0], edges[:, 1]) return result + def copy_fake_edges(self, chunk_id: np.uint64) -> None: + _edges = self.client.read_node( + node_id=chunk_id, + properties=attributes.Connectivity.FakeEdgesCF3, + end_time_inclusive=True, + fake_edges=True, + ) + mutations = [] + _id = serializers.serialize_uint64(chunk_id, fake_edges=True) + for e in _edges: + val_dict = {attributes.Connectivity.FakeEdges: e.value} + row = self.client.mutate_row(_id, val_dict, time_stamp=e.timestamp) + mutations.append(row) + self.client.write(mutations) + def get_l2_agglomerations( - self, level2_ids: np.ndarray, edges_only: bool = False - ) -> typing.Tuple[typing.Dict[int, types.Agglomeration], np.ndarray]: + self, + level2_ids: np.ndarray, + edges_only: bool = False, + active: bool = False, + time_stamp: typing.Optional[datetime.datetime] = None, + ) -> typing.Tuple[typing.Dict[int, types.Agglomeration], typing.Tuple[Edges]]: """ Children of Level 2 Node IDs and edges. Edges are read from cloud storage. """ - from itertools import chain - from functools import reduce from .misc import get_agglomerations chunk_ids = np.unique(self.get_chunk_ids_from_node_ids(level2_ids)) @@ -674,6 +727,8 @@ def get_l2_agglomerations( chain(edges_d.values(), fake_edges.values()), Edges([], []), ) + if self.mock_edges is not None: + all_chunk_edges += self.mock_edges if edges_only: if self.mock_edges is not None: @@ -681,20 +736,26 @@ def get_l2_agglomerations( else: all_chunk_edges = all_chunk_edges.get_pairs() supervoxels = self.get_children(level2_ids, flatten=True) - mask0 = np.in1d(all_chunk_edges[:, 0], supervoxels) - mask1 = np.in1d(all_chunk_edges[:, 1], supervoxels) + mask0 = np.isin(all_chunk_edges[:, 0], supervoxels) + mask1 = np.isin(all_chunk_edges[:, 1], supervoxels) return all_chunk_edges[mask0 & mask1] l2id_children_d = self.get_children(level2_ids) sv_parent_d = {} for l2id in l2id_children_d: svs = l2id_children_d[l2id] + for sv in svs: + if sv in sv_parent_d: + raise ValueError("Found conflicting parents.") sv_parent_d.update(dict(zip(svs.tolist(), [l2id] * len(svs)))) + if active: + all_chunk_edges = edge_utils.filter_inactive_cross_edges( + self, all_chunk_edges, time_stamp=time_stamp + ) + in_edges, out_edges, cross_edges = edge_utils.categorize_edges_v2( - self.meta, - all_chunk_edges, - sv_parent_d + self.meta, all_chunk_edges, sv_parent_d ) agglomeration_d = get_agglomerations( @@ -702,13 +763,15 @@ def get_l2_agglomerations( ) return ( agglomeration_d, - (self.mock_edges,) - if self.mock_edges is not None - else (in_edges, out_edges, cross_edges), + ( + (self.mock_edges,) + if self.mock_edges is not None + else (in_edges, out_edges, cross_edges) + ), ) def get_node_timestamps( - self, node_ids: typing.Sequence[np.uint64], return_numpy=True + self, node_ids: typing.Sequence[np.uint64], return_numpy=True, normalize=False ) -> typing.Iterable: """ The timestamp of the children column can be assumed @@ -722,22 +785,31 @@ def get_node_timestamps( if return_numpy: return np.array([], dtype=np.datetime64) return [] + result = [] + earliest_ts = self.get_earliest_timestamp() + for n in node_ids: + try: + ts = children[n][0].timestamp + except KeyError: + ts = datetime.datetime.now(datetime.timezone.utc) + if normalize: + ts = earliest_ts if ts < earliest_ts else ts + result.append(ts) if return_numpy: - return np.array( - [children[x][0].timestamp for x in node_ids], dtype=np.datetime64 - ) - return [children[x][0].timestamp for x in node_ids] + return np.array(result, dtype=np.datetime64) + return result # OPERATIONS def add_edges( self, user_id: str, - atomic_edges: typing.Sequence[np.uint64], + atomic_edges: typing.Sequence[typing.Sequence[np.uint64]], *, affinities: typing.Sequence[np.float32] = None, source_coords: typing.Sequence[int] = None, sink_coords: typing.Sequence[int] = None, allow_same_segment_merge: typing.Optional[bool] = False, + do_sanity_check: typing.Optional[bool] = True, ) -> operation.GraphEditOperation.Result: """ Adds an edge to the chunkedgraph @@ -754,6 +826,7 @@ def add_edges( source_coords=source_coords, sink_coords=sink_coords, allow_same_segment_merge=allow_same_segment_merge, + do_sanity_check=do_sanity_check, ).execute() def remove_edges( @@ -769,6 +842,7 @@ def remove_edges( path_augment: bool = True, disallow_isolating_cut: bool = True, bb_offset: typing.Tuple[int, int, int] = (240, 240, 24), + do_sanity_check: typing.Optional[bool] = True, ) -> operation.GraphEditOperation.Result: """ Removes edges - either directly or after applying a mincut @@ -793,6 +867,7 @@ def remove_edges( bbox_offset=bb_offset, path_augment=path_augment, disallow_isolating_cut=disallow_isolating_cut, + do_sanity_check=do_sanity_check, ).execute() if not atomic_edges: @@ -842,82 +917,7 @@ def redo_operation( multicut_as_split=True, ).execute() - # PRIVATE - - def _get_bounding_chunk_ids( - self, - parent_chunk_ids: typing.Iterable, - unique: bool = False, - ) -> typing.Dict: - """ - Returns bounding chunk IDs at layers < parent_layer for all chunk IDs. - Dict[parent_chunk_id] = np.array(bounding_chunk_ids) - """ - parent_chunk_coords = self.get_chunk_coordinates_multiple(parent_chunk_ids) - parents_layer = self.get_chunk_layer(parent_chunk_ids[0]) - chunk_id_bchunk_ids_d = {} - for i, chunk_id in enumerate(parent_chunk_ids): - if chunk_id in chunk_id_bchunk_ids_d: - # `parent_chunk_ids` can have duplicates - # avoid redundant calculations - continue - parent_coord = parent_chunk_coords[i] - chunk_ids = [types.empty_1d] - for child_layer in range(2, parents_layer): - bcoords = chunk_utils.get_bounding_children_chunks( - self.meta, - parents_layer, - parent_coord, - child_layer, - return_unique=False, - ) - bchunks_ids = chunk_utils.get_chunk_ids_from_coords( - self.meta, child_layer, bcoords - ) - chunk_ids.append(bchunks_ids) - chunk_ids = np.concatenate(chunk_ids) - if unique: - chunk_ids = np.unique(chunk_ids) - chunk_id_bchunk_ids_d[chunk_id] = chunk_ids - return chunk_id_bchunk_ids_d - - def _get_bounding_l2_children(self, parents: typing.Iterable) -> typing.Dict: - parent_chunk_ids = self.get_chunk_ids_from_node_ids(parents) - chunk_id_bchunk_ids_d = self._get_bounding_chunk_ids( - parent_chunk_ids, unique=len(parents) >= 200 - ) - - parent_descendants_d = { - _id: np.array([_id], dtype=basetypes.NODE_ID) for _id in parents - } - descendants_all = np.concatenate(list(parent_descendants_d.values())) - descendants_layers = self.get_chunk_layers(descendants_all) - layer_mask = descendants_layers > 2 - descendants_all = descendants_all[layer_mask] - - while descendants_all.size: - descendant_children_d = self.get_children(descendants_all) - for i, parent_id in enumerate(parents): - _descendants = parent_descendants_d[parent_id] - _layers = self.get_chunk_layers(_descendants) - _l2mask = _layers == 2 - descendants = [_descendants[_l2mask]] - for child in _descendants[~_l2mask]: - descendants.append(descendant_children_d[child]) - descendants = np.concatenate(descendants) - chunk_ids = self.get_chunk_ids_from_node_ids(descendants) - bchunk_ids = chunk_id_bchunk_ids_d[parent_chunk_ids[i]] - bounding_descendants = descendants[np.in1d(chunk_ids, bchunk_ids)] - parent_descendants_d[parent_id] = bounding_descendants - - descendants_all = np.concatenate(list(parent_descendants_d.values())) - descendants_layers = self.get_chunk_layers(descendants_all) - layer_mask = descendants_layers > 2 - descendants_all = descendants_all[layer_mask] - return parent_descendants_d - # HELPERS / WRAPPERS - def is_root(self, node_id: basetypes.NODE_ID) -> bool: return self.get_chunk_layer(node_id) == self.meta.layer_count @@ -955,9 +955,9 @@ def get_chunk_coordinates(self, node_or_chunk_id: basetypes.NODE_ID): return chunk_utils.get_chunk_coordinates(self.meta, node_or_chunk_id) def get_chunk_coordinates_multiple(self, node_or_chunk_ids: typing.Sequence): - node_or_chunk_ids = np.array(node_or_chunk_ids, dtype=basetypes.NODE_ID) + node_or_chunk_ids = np.asarray(node_or_chunk_ids, dtype=basetypes.NODE_ID) layers = self.get_chunk_layers(node_or_chunk_ids) - assert np.all(layers == layers[0]), "All IDs must have the same layer." + assert len(layers) == 0 or np.all(layers == layers[0]), "must be same layer." return chunk_utils.get_chunk_coordinates_multiple(self.meta, node_or_chunk_ids) def get_chunk_id( @@ -987,6 +987,11 @@ def get_parent_chunk_id( self.meta, node_or_chunk_id, parent_layer ) + def get_parent_chunk_id_multiple(self, node_or_chunk_ids: typing.Sequence): + return chunk_hierarchy.get_parent_chunk_id_multiple( + self.meta, node_or_chunk_ids + ) + def get_parent_chunk_ids(self, node_or_chunk_id: basetypes.NODE_ID): return chunk_hierarchy.get_parent_chunk_ids(self.meta, node_or_chunk_id) @@ -1017,6 +1022,71 @@ def get_earliest_timestamp(self): from datetime import timedelta for op_id in range(100): - _, timestamp = self.client.read_log_entry(op_id) + _log, timestamp = self.client.read_log_entry(op_id) if timestamp is not None: return timestamp - timedelta(milliseconds=500) + if _log: + return self.client._read_byte_row(serializers.serialize_uint64(op_id))[ + attributes.OperationLogs.Status + ][-1].timestamp + + def get_operation_ids(self, node_ids: typing.Sequence): + response = self.client.read_nodes(node_ids=node_ids) + result = {} + for node in node_ids: + try: + operations = response[node][attributes.OperationLogs.OperationID] + result[node] = [(x.value, x.timestamp) for x in operations] + except KeyError: + ... + return result + + def get_single_leaf_multiple(self, node_ids): + """Returns the first supervoxel found for each node_id.""" + result = {} + node_ids_copy = np.copy(node_ids) + children = np.copy(node_ids) + children_d = self.get_children(node_ids) + while True: + children = [children_d[k][0] for k in children] + children = np.array(children, dtype=basetypes.NODE_ID) + mask = self.get_chunk_layers(children) == 1 + result.update( + [(node, sv) for node, sv in zip(node_ids[mask], children[mask])] + ) + node_ids = node_ids[~mask] + children = children[~mask] + if children.size == 0: + break + children_d = self.get_children(children) + return np.array([result[k] for k in node_ids_copy], dtype=basetypes.NODE_ID) + + def get_chunk_layers_and_coordinates(self, node_or_chunk_ids: typing.Sequence): + """ + Helper function that wraps get chunk layer and coordinates for nodes at any layer. + """ + node_or_chunk_ids = np.array(node_or_chunk_ids, dtype=basetypes.NODE_ID) + layers = self.get_chunk_layers(node_or_chunk_ids) + chunk_coords = np.zeros(shape=(len(node_or_chunk_ids), 3), dtype=int) + for _layer in np.unique(layers): + mask = layers == _layer + _nodes = node_or_chunk_ids[mask] + chunk_coords[mask] = chunk_utils.get_chunk_coordinates_multiple( + self.meta, _nodes + ) + return layers, chunk_coords + + def get_l2children(self, node_ids) -> np.ndarray: + """ + Get L2 children of all node_ids, returns a flat array. + """ + node_ids = np.asarray(node_ids, dtype=basetypes.NODE_ID) + layers = self.get_chunk_layers(node_ids) + assert np.all(layers >= 2), "nodes must be at layers >= 2" + l2children = [types.empty_1d] + while node_ids.size: + children = self.get_children(node_ids, flatten=True) + layers = self.get_chunk_layers(children) + l2children.append(children[layers == 2]) + node_ids = children[layers > 2] + return np.concatenate(l2children) diff --git a/pychunkedgraph/graph/chunks/atomic.py b/pychunkedgraph/graph/chunks/atomic.py index e3de065ff..ec0109c69 100644 --- a/pychunkedgraph/graph/chunks/atomic.py +++ b/pychunkedgraph/graph/chunks/atomic.py @@ -1,3 +1,5 @@ +# pylint: disable=invalid-name, missing-docstring + from typing import List from typing import Sequence from itertools import product @@ -6,8 +8,6 @@ from .utils import get_bounding_children_chunks from ..meta import ChunkedGraphMeta -from ..utils.generic import get_valid_timestamp -from ..utils import basetypes def get_touching_atomic_chunks( @@ -27,7 +27,7 @@ def get_touching_atomic_chunks( chunk_offset = chunk_coords * atomic_chunk_count mid = (atomic_chunk_count // 2) - 1 - # TODO (akhileshh) convert this for loop to numpy + # TODO (akhileshh) convert this for loop to numpy; # relevant chunks along touching planes at center for axis_1, axis_2 in product(*[range(atomic_chunk_count)] * 2): # x-y plane @@ -62,4 +62,6 @@ def get_bounding_atomic_chunks( chunkedgraph_meta: ChunkedGraphMeta, layer: int, chunk_coords: Sequence[int] ) -> List: """Atomic chunk coordinates along the boundary of a chunk""" - return get_bounding_children_chunks(chunkedgraph_meta, layer, chunk_coords, 2) + return get_bounding_children_chunks( + chunkedgraph_meta, layer, tuple(chunk_coords), 2 + ) diff --git a/pychunkedgraph/graph/chunks/hierarchy.py b/pychunkedgraph/graph/chunks/hierarchy.py index 32d6029ee..5ff7823fe 100644 --- a/pychunkedgraph/graph/chunks/hierarchy.py +++ b/pychunkedgraph/graph/chunks/hierarchy.py @@ -37,17 +37,17 @@ def get_children_chunk_ids( layer = utils.get_chunk_layer(meta, node_or_chunk_id) if layer == 1: - return np.array([]) + return np.array([], dtype=np.uint64) elif layer == 2: return np.array([utils.get_chunk_id(meta, layer=layer, x=x, y=y, z=z)]) else: children_coords = get_children_chunk_coords(meta, layer, (x, y, z)) children_chunk_ids = [] - for (x, y, z) in children_coords: + for x, y, z in children_coords: children_chunk_ids.append( utils.get_chunk_id(meta, layer=layer - 1, x=x, y=y, z=z) ) - return np.array(children_chunk_ids) + return np.array(children_chunk_ids, dtype=np.uint64) def get_parent_chunk_id( @@ -62,6 +62,19 @@ def get_parent_chunk_id( return utils.get_chunk_id(meta, layer=parent_layer, x=x, y=y, z=z) +def get_parent_chunk_id_multiple( + meta: ChunkedGraphMeta, node_or_chunk_ids: np.ndarray +) -> np.ndarray: + """Parent chunk IDs for multiple nodes. Assumes nodes at same layer.""" + + node_layers = utils.get_chunk_layers(meta, node_or_chunk_ids) + assert np.unique(node_layers).size == 1, np.unique(node_layers) + parent_layer = node_layers[0] + 1 + coords = utils.get_chunk_coordinates_multiple(meta, node_or_chunk_ids) + coords = coords // meta.graph_config.FANOUT + return utils.get_chunk_ids_from_coords(meta, layer=parent_layer, coords=coords) + + def get_parent_chunk_ids( meta: ChunkedGraphMeta, node_or_chunk_id: np.uint64 ) -> np.ndarray: diff --git a/pychunkedgraph/graph/chunks/utils.py b/pychunkedgraph/graph/chunks/utils.py index dc895bde4..5b6d0ae78 100644 --- a/pychunkedgraph/graph/chunks/utils.py +++ b/pychunkedgraph/graph/chunks/utils.py @@ -1,13 +1,17 @@ # pylint: disable=invalid-name, missing-docstring -from typing import List from typing import Union from typing import Optional from typing import Sequence +from typing import Tuple from typing import Iterable +from copy import copy +from functools import lru_cache + import numpy as np + def get_chunks_boundary(voxel_boundary, chunk_size) -> np.ndarray: """returns number of chunks in each dimension""" return np.ceil((voxel_boundary / chunk_size)).astype(int) @@ -43,7 +47,7 @@ def normalize_bounding_box( def get_chunk_layer(meta, node_or_chunk_id: np.uint64) -> int: - """ Extract Layer from Node ID or Chunk ID """ + """Extract Layer from Node ID or Chunk ID""" return int(int(node_or_chunk_id) >> 64 - meta.graph_config.LAYER_ID_BITS) @@ -75,9 +79,9 @@ def get_chunk_coordinates(meta, node_or_chunk_id: np.uint64) -> np.ndarray: y_offset = x_offset - bits_per_dim z_offset = y_offset - bits_per_dim - x = int(node_or_chunk_id) >> x_offset & 2 ** bits_per_dim - 1 - y = int(node_or_chunk_id) >> y_offset & 2 ** bits_per_dim - 1 - z = int(node_or_chunk_id) >> z_offset & 2 ** bits_per_dim - 1 + x = int(node_or_chunk_id) >> x_offset & 2**bits_per_dim - 1 + y = int(node_or_chunk_id) >> y_offset & 2**bits_per_dim - 1 + z = int(node_or_chunk_id) >> z_offset & 2**bits_per_dim - 1 return np.array([x, y, z]) @@ -86,8 +90,8 @@ def get_chunk_coordinates_multiple(meta, ids: np.ndarray) -> np.ndarray: Array version of get_chunk_coordinates. Assumes all given IDs are in same layer. """ - if not len(ids): - return np.array([]) + if len(ids) == 0: + return np.array([], dtype=int).reshape(0, 3) layer = get_chunk_layer(meta, ids[0]) bits_per_dim = meta.bitmasks[layer] @@ -96,9 +100,9 @@ def get_chunk_coordinates_multiple(meta, ids: np.ndarray) -> np.ndarray: z_offset = y_offset - bits_per_dim ids = np.array(ids, dtype=int) - X = ids >> x_offset & 2 ** bits_per_dim - 1 - Y = ids >> y_offset & 2 ** bits_per_dim - 1 - Z = ids >> z_offset & 2 ** bits_per_dim - 1 + X = ids >> x_offset & 2**bits_per_dim - 1 + Y = ids >> y_offset & 2**bits_per_dim - 1 + Z = ids >> z_offset & 2**bits_per_dim - 1 return np.column_stack((X, Y, Z)) @@ -125,6 +129,7 @@ def get_chunk_id( def get_chunk_ids_from_coords(meta, layer: int, coords: np.ndarray): + layer = int(layer) result = np.zeros(len(coords), dtype=np.uint64) s_bits_per_dim = meta.bitmasks[layer] @@ -142,14 +147,15 @@ def get_chunk_ids_from_coords(meta, layer: int, coords: np.ndarray): def get_chunk_ids_from_node_ids(meta, ids: Iterable[np.uint64]) -> np.ndarray: - """ Extract Chunk IDs from Node IDs""" + """Extract Chunk IDs from Node IDs""" if len(ids) == 0: return np.array([], dtype=np.uint64) bits_per_dims = np.array([meta.bitmasks[l] for l in get_chunk_layers(meta, ids)]) offsets = 64 - meta.graph_config.LAYER_ID_BITS - 3 * bits_per_dims - cids1 = np.array((np.array(ids, dtype=int) >> offsets) << offsets, dtype=np.uint64) + ids = np.array(ids, dtype=int) + cids1 = np.array((ids >> offsets) << offsets, dtype=np.uint64) # cids2 = np.vectorize(get_chunk_id)(meta, ids) # assert np.all(cids1 == cids2) return cids1 @@ -164,7 +170,7 @@ def _compute_chunk_id( ) -> np.uint64: s_bits_per_dim = meta.bitmasks[layer] if not ( - x < 2 ** s_bits_per_dim and y < 2 ** s_bits_per_dim and z < 2 ** s_bits_per_dim + x < 2**s_bits_per_dim and y < 2**s_bits_per_dim and z < 2**s_bits_per_dim ): raise ValueError( f"Coordinate is out of range \ @@ -208,8 +214,9 @@ def _get_chunk_coordinates_from_vol_coordinates( return coords.astype(int) +@lru_cache() def get_bounding_children_chunks( - cg_meta, layer: int, chunk_coords: Sequence[int], children_layer, return_unique=True + cg_meta, layer: int, chunk_coords: Tuple[int], children_layer, return_unique=True ) -> np.ndarray: """Children chunk coordinates at given layer, along the boundary of a chunk""" chunk_coords = np.array(chunk_coords, dtype=int) @@ -233,3 +240,47 @@ def get_bounding_children_chunks( if return_unique: return np.unique(result, axis=0) if result.size else result return result + + +@lru_cache() +def get_l2chunkids_along_boundary(cg_meta, mlayer: int, coord_a, coord_b, padding: int = 0): + """ + Gets L2 Chunk IDs along opposing faces for larger chunks. + If padding is enabled, more faces of L2 chunks are padded on both sides. + This is necessary to find fake edges that can span more than 2 L2 chunks. + """ + bounds_a = get_bounding_children_chunks(cg_meta, mlayer, tuple(coord_a), 2) + bounds_b = get_bounding_children_chunks(cg_meta, mlayer, tuple(coord_b), 2) + + coord_a, coord_b = np.array(coord_a, dtype=int), np.array(coord_b, dtype=int) + direction = coord_a - coord_b + major_axis = np.argmax(np.abs(direction)) + + l2chunk_count = 2 ** (mlayer - 2) + max_coord = coord_a if direction[major_axis] > 0 else coord_b + + skip = abs(direction[major_axis]) - 1 + l2_skip = skip * l2chunk_count + + mid = max_coord[major_axis] * l2chunk_count + face_a = mid if direction[major_axis] > 0 else (mid - l2_skip - 1) + face_b = mid if direction[major_axis] < 0 else (mid - l2_skip - 1) + + l2chunks_a = [bounds_a[bounds_a[:, major_axis] == face_a]] + l2chunks_b = [bounds_b[bounds_b[:, major_axis] == face_b]] + + step_a, step_b = (1, -1) if direction[major_axis] > 0 else (-1, 1) + for _ in range(padding): + _l2_chunks_a = copy(l2chunks_a[-1]) + _l2_chunks_b = copy(l2chunks_b[-1]) + _l2_chunks_a[:, major_axis] += step_a + _l2_chunks_b[:, major_axis] += step_b + l2chunks_a.append(_l2_chunks_a) + l2chunks_b.append(_l2_chunks_b) + + l2chunks_a = np.concatenate(l2chunks_a) + l2chunks_b = np.concatenate(l2chunks_b) + + l2chunk_ids_a = get_chunk_ids_from_coords(cg_meta, 2, l2chunks_a) + l2chunk_ids_b = get_chunk_ids_from_coords(cg_meta, 2, l2chunks_b) + return l2chunk_ids_a, l2chunk_ids_b diff --git a/pychunkedgraph/graph/client/base.py b/pychunkedgraph/graph/client/base.py index a66602a6a..953734670 100644 --- a/pychunkedgraph/graph/client/base.py +++ b/pychunkedgraph/graph/client/base.py @@ -13,7 +13,7 @@ def create_graph(self) -> None: """Initialize the graph and store associated meta.""" @abstractmethod - def add_graph_version(self, version): + def add_graph_version(self, version: str, overwrite: bool = False): """Add a version to the graph.""" @abstractmethod diff --git a/pychunkedgraph/graph/client/bigtable/client.py b/pychunkedgraph/graph/client/bigtable/client.py index 5b86826bd..260d985ab 100644 --- a/pychunkedgraph/graph/client/bigtable/client.py +++ b/pychunkedgraph/graph/client/bigtable/client.py @@ -1,11 +1,12 @@ -# pylint: disable=invalid-name, missing-docstring, import-outside-toplevel, line-too-long, protected-access, arguments-differ, arguments-renamed, logging-fstring-interpolation +# pylint: disable=invalid-name, missing-docstring, import-outside-toplevel, line-too-long, protected-access, arguments-differ, arguments-renamed, logging-fstring-interpolation, too-many-arguments import sys import time import typing import logging -import datetime from datetime import datetime +from datetime import timedelta +from concurrent.futures import ThreadPoolExecutor, as_completed import numpy as np from multiwrapper import multiprocessing_utils as mu @@ -15,11 +16,12 @@ from google.api_core.exceptions import Aborted from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable +from google.cloud.bigtable.column_family import MaxAgeGCRule +from google.cloud.bigtable.column_family import MaxVersionsGCRule from google.cloud.bigtable.table import Table from google.cloud.bigtable.row_set import RowSet -from google.cloud.bigtable.row_data import PartialRowData +from google.cloud.bigtable.row_data import DEFAULT_RETRY_READ_ROWS, PartialRowData from google.cloud.bigtable.row_filters import RowFilter -from google.cloud.bigtable.column_family import MaxVersionsGCRule from . import utils from . import BigTableConfig @@ -71,6 +73,18 @@ def __init__( self._version = None self._max_row_key_count = config.MAX_ROW_KEY_COUNT + def _create_column_families(self): + f = self._table.column_family("0") + f.create() + f = self._table.column_family("1", gc_rule=MaxVersionsGCRule(1)) + f.create() + f = self._table.column_family("2") + f.create() + f = self._table.column_family("3", gc_rule=MaxAgeGCRule(timedelta(days=365))) + f.create() + f = self._table.column_family("4") + f.create() + @property def graph_meta(self): return self._graph_meta @@ -84,8 +98,9 @@ def create_graph(self, meta: ChunkedGraphMeta, version: str) -> None: self.add_graph_version(version) self.update_graph_meta(meta) - def add_graph_version(self, version: str): - assert self.read_graph_version() is None, "Graph has already been versioned." + def add_graph_version(self, version: str, overwrite: bool = False): + if not overwrite: + assert self.read_graph_version() is None, self.read_graph_version() self._version = version row = self.mutate_row( attributes.GraphVersion.key, @@ -137,6 +152,7 @@ def read_nodes( end_time=None, end_time_inclusive: bool = False, fake_edges: bool = False, + attr_keys: bool = True, ): """ Read nodes and their properties. @@ -147,26 +163,40 @@ def read_nodes( # when all IDs in a block are within a range node_ids = np.sort(node_ids) rows = self._read_byte_rows( - start_key=serialize_uint64(start_id, fake_edges=fake_edges) - if start_id is not None - else None, - end_key=serialize_uint64(end_id, fake_edges=fake_edges) - if end_id is not None - else None, + start_key=( + serialize_uint64(start_id, fake_edges=fake_edges) + if start_id is not None + else None + ), + end_key=( + serialize_uint64(end_id, fake_edges=fake_edges) + if end_id is not None + else None + ), end_key_inclusive=end_id_inclusive, row_keys=( - serialize_uint64(node_id, fake_edges=fake_edges) for node_id in node_ids - ) - if node_ids is not None - else None, + ( + serialize_uint64(node_id, fake_edges=fake_edges) + for node_id in node_ids + ) + if node_ids is not None + else None + ), columns=properties, start_time=start_time, end_time=end_time, end_time_inclusive=end_time_inclusive, user_id=user_id, ) + if attr_keys: + return { + deserialize_uint64(row_key, fake_edges=fake_edges): data + for (row_key, data) in rows.items() + } return { - deserialize_uint64(row_key, fake_edges=fake_edges): data + deserialize_uint64(row_key, fake_edges=fake_edges): { + k.key: v for k, v in data.items() + } for (row_key, data) in rows.items() } @@ -424,11 +454,9 @@ def lock_roots( max_tries: int = 1, waittime_s: float = 0.5, ) -> typing.Tuple[bool, typing.Iterable]: - """Attempts to lock multiple nodes with same operation id""" + """Attempts to lock multiple nodes with same operation id in parallel""" i_try = 0 while i_try < max_tries: - lock_acquired = False - # Collect latest root ids new_root_ids: typing.List[np.uint64] = [] for root_id in root_ids: future_root_ids = future_root_ids_d[root_id] @@ -437,18 +465,36 @@ def lock_roots( else: new_root_ids.extend(future_root_ids) - # Attempt to lock all latest root ids + lock_results = {} root_ids = np.unique(new_root_ids) - for root_id in root_ids: - lock_acquired = self.lock_root(root_id, operation_id) - # Roll back locks if one root cannot be locked - if not lock_acquired: - for id_ in root_ids: - self.unlock_root(id_, operation_id) - break - - if lock_acquired: + max_workers = min(8, max(1, len(root_ids))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_root = { + executor.submit(self.lock_root, root_id, operation_id): root_id + for root_id in root_ids + } + for future in as_completed(future_to_root): + root_id = future_to_root[future] + try: + lock_results[root_id] = future.result() + except Exception as e: + self.logger.error(f"Failed to lock root {root_id}: {e}") + lock_results[root_id] = False + + all_locked = all(lock_results.values()) + if all_locked: return True, root_ids + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + unlock_futures = [ + executor.submit(self.unlock_root, root_id, operation_id) + for root_id in root_ids + ] + for future in as_completed(unlock_futures): + try: + future.result() + except Exception as e: + self.logger.error(f"Failed to unlock root: {e}") time.sleep(waittime_s) i_try += 1 self.logger.debug(f"Try {i_try}") @@ -459,9 +505,8 @@ def lock_roots_indefinitely( root_ids: typing.Sequence[np.uint64], operation_id: np.uint64, future_root_ids_d: typing.Dict, - ) -> typing.Tuple[bool, typing.Iterable]: + ) -> typing.Tuple[bool, typing.Iterable, typing.Iterable]: """Attempts to indefinitely lock multiple nodes with same operation id""" - lock_acquired = False # Collect latest root ids new_root_ids: typing.List[np.uint64] = [] for _id in root_ids: @@ -471,21 +516,45 @@ def lock_roots_indefinitely( else: new_root_ids.extend(future_root_ids) - # Attempt to lock all latest root ids - failed_to_lock_id = None root_ids = np.unique(new_root_ids) - for _id in root_ids: - self.logger.debug(f"operation {operation_id} root_id {_id}") - lock_acquired = self.lock_root_indefinitely(_id, operation_id) - # Roll back locks if one root cannot be locked - if not lock_acquired: - failed_to_lock_id = _id - for id_ in root_ids: - self.unlock_indefinitely_locked_root(id_, operation_id) - break - if lock_acquired: - return True, root_ids, failed_to_lock_id - return False, root_ids, failed_to_lock_id + lock_results = {} + max_workers = min(8, max(1, len(root_ids))) + failed_to_lock = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_root = { + executor.submit( + self.lock_root_indefinitely, root_id, operation_id + ): root_id + for root_id in root_ids + } + for future in as_completed(future_to_root): + root_id = future_to_root[future] + try: + lock_results[root_id] = future.result() + if lock_results[root_id] is False: + failed_to_lock.append(root_id) + except Exception as e: + self.logger.error(f"Failed to lock root {root_id}: {e}") + lock_results[root_id] = False + failed_to_lock.append(root_id) + + all_locked = all(lock_results.values()) + if all_locked: + return True, root_ids, failed_to_lock + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + unlock_futures = [ + executor.submit( + self.unlock_indefinitely_locked_root, root_id, operation_id + ) + for root_id in root_ids + ] + for future in as_completed(unlock_futures): + try: + future.result() + except Exception as e: + self.logger.error(f"Failed to unlock root: {e}") + return False, root_ids, failed_to_lock def unlock_root(self, root_id: np.uint64, operation_id: np.uint64): """Unlocks root node that is locked with operation_id.""" @@ -532,10 +601,22 @@ def renew_lock(self, root_id: np.uint64, operation_id: np.uint64) -> bool: def renew_locks(self, root_ids: np.uint64, operation_id: np.uint64) -> bool: """Renews existing root node locks with operation_id to extend time.""" - for root_id in root_ids: - if not self.renew_lock(root_id, operation_id): - self.logger.warning(f"renew_lock failed - {root_id}") - return False + max_workers = min(8, max(1, len(root_ids))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(self.renew_lock, root_id, operation_id): root_id + for root_id in root_ids + } + for future in as_completed(futures): + root_id = futures[future] + try: + result = future.result() + if not result: + self.logger.warning(f"renew_lock failed - {root_id}") + return False + except Exception as e: + self.logger.error(f"Exception during renew_lock({root_id}): {e}") + return False return True def get_lock_timestamp( @@ -557,15 +638,31 @@ def get_consolidated_lock_timestamp( operation_ids: typing.Sequence[np.uint64], ) -> typing.Union[datetime, None]: """Minimum of multiple lock timestamps.""" - time_stamps = [] - for root_id, operation_id in zip(root_ids, operation_ids): - time_stamp = self.get_lock_timestamp(root_id, operation_id) - if time_stamp is None: - return None - time_stamps.append(time_stamp) - if len(time_stamps) == 0: + if len(root_ids) == 0: return None - return np.min(time_stamps) + max_workers = min(8, max(1, len(root_ids))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(self.get_lock_timestamp, root_id, op_id): ( + root_id, + op_id, + ) + for root_id, op_id in zip(root_ids, operation_ids) + } + timestamps = [] + for future in as_completed(futures): + root_id, op_id = futures[future] + try: + ts = future.result() + if ts is None: + return None + timestamps.append(ts) + except Exception as exc: + self.logger.warning(f"({root_id}, {op_id}): {exc}") + return None + if not timestamps: + return None + return np.min(timestamps) # IDs def create_node_ids( @@ -628,16 +725,6 @@ def get_compatible_timestamp( return utils.get_google_compatible_time_stamp(time_stamp, round_up=round_up) # PRIVATE METHODS - def _create_column_families(self): - f = self._table.column_family("0") - f.create() - f = self._table.column_family("1", gc_rule=MaxVersionsGCRule(1)) - f.create() - f = self._table.column_family("2") - f.create() - f = self._table.column_family("3") - f.create() - def _get_ids_range(self, key: bytes, size: int) -> typing.Tuple: """Returns a range (min, max) of IDs for a given `key`.""" column = attributes.Concurrency.Counter @@ -816,7 +903,8 @@ def _execute_read_thread(self, args: typing.Tuple[Table, RowSet, RowFilter]): # Check for everything falsy, because Bigtable considers even empty # lists of row_keys as no upper/lower bound! return {} - range_read = table.read_rows(row_set=row_set, filter_=row_filter) + retry = DEFAULT_RETRY_READ_ROWS.with_timeout(600) + range_read = table.read_rows(row_set=row_set, filter_=row_filter, retry=retry) res = {v.row_key: utils.partial_row_data_to_column_dict(v) for v in range_read} return res diff --git a/pychunkedgraph/graph/client/bigtable/utils.py b/pychunkedgraph/graph/client/bigtable/utils.py index 2d30eeb32..3f14e125d 100644 --- a/pychunkedgraph/graph/client/bigtable/utils.py +++ b/pychunkedgraph/graph/client/bigtable/utils.py @@ -4,6 +4,7 @@ from typing import Optional from datetime import datetime from datetime import timedelta +from datetime import timezone import numpy as np from google.cloud.bigtable.row_data import PartialRowData @@ -146,7 +147,7 @@ def get_time_range_and_column_filter( def get_root_lock_filter( lock_column, lock_expiry, indefinite_lock_column ) -> ConditionalRowFilter: - time_cutoff = datetime.utcnow() - lock_expiry + time_cutoff = datetime.now(timezone.utc) - lock_expiry # Comply to resolution of BigTables TimeRange time_cutoff -= timedelta(microseconds=time_cutoff.microsecond % 1000) time_filter = TimestampRangeFilter(TimestampRange(start=time_cutoff)) @@ -256,7 +257,7 @@ def get_renew_lock_filter( def get_unlock_root_filter(lock_column, lock_expiry, operation_id) -> RowFilterChain: - time_cutoff = datetime.utcnow() - lock_expiry + time_cutoff = datetime.now(timezone.utc) - lock_expiry # Comply to resolution of BigTables TimeRange time_cutoff -= timedelta(microseconds=time_cutoff.microsecond % 1000) time_filter = TimestampRangeFilter(TimestampRange(start=time_cutoff)) diff --git a/pychunkedgraph/graph/connectivity/search.py b/pychunkedgraph/graph/connectivity/search.py deleted file mode 100644 index bd3faf227..000000000 --- a/pychunkedgraph/graph/connectivity/search.py +++ /dev/null @@ -1,47 +0,0 @@ -import random -from typing import List - -import numpy as np -from graph_tool.search import bfs_search -from graph_tool.search import BFSVisitor -from graph_tool.search import StopSearch - -from ..utils.basetypes import NODE_ID - - -class TargetVisitor(BFSVisitor): - def __init__(self, target, reachable): - self.target = target - self.reachable = reachable - - def discover_vertex(self, u): - if u == self.target: - self.reachable[u] = 1 - raise StopSearch - - -def check_reachability(g, sv1s: np.ndarray, sv2s: np.ndarray, original_ids: np.ndarray) -> np.ndarray: - """ - g: graph tool Graph instance with ids 0 to N-1 where N = vertex count - original_ids: sorted ChunkedGraph supervoxel ids - (to identify corresponding ids in graph tool) - for each pair (sv1, sv2) check if a path exists (BFS) - """ - # mapping from original ids to graph tool ids - original_ids_d = { - sv_id: index for sv_id, index in zip(original_ids, range(len(original_ids))) - } - reachable = g.new_vertex_property("int", val=0) - - def _check_reachability(source, target): - bfs_search(g, source, TargetVisitor(target, reachable)) - return reachable[target] - - return np.array( - [ - _check_reachability(original_ids_d[source], original_ids_d[target]) - for source, target in zip(sv1s, sv2s) - ], - dtype=bool, - ) - diff --git a/pychunkedgraph/graph/cutting.py b/pychunkedgraph/graph/cutting.py index 8b1583871..a2fca8023 100644 --- a/pychunkedgraph/graph/cutting.py +++ b/pychunkedgraph/graph/cutting.py @@ -62,7 +62,7 @@ def merge_cross_chunk_edges_graph_tool( if len(mapping) > 0: mapping = np.concatenate(mapping) u_nodes = np.unique(edges) - u_unmapped_nodes = u_nodes[~np.in1d(u_nodes, mapping)] + u_unmapped_nodes = u_nodes[~np.isin(u_nodes, mapping)] unmapped_mapping = np.concatenate( [u_unmapped_nodes.reshape(-1, 1), u_unmapped_nodes.reshape(-1, 1)], axis=1 ) @@ -189,9 +189,9 @@ def _build_gt_graph(self, edges, affs): ) = flatgraph.build_gt_graph(comb_edges, comb_affs, make_directed=True) self.source_graph_ids = np.where( - np.in1d(self.unique_supervoxel_ids, self.sources) + np.isin(self.unique_supervoxel_ids, self.sources) )[0] - self.sink_graph_ids = np.where(np.in1d(self.unique_supervoxel_ids, self.sinks))[ + self.sink_graph_ids = np.where(np.isin(self.unique_supervoxel_ids, self.sinks))[ 0 ] @@ -398,7 +398,7 @@ def _remap_cut_edge_set(self, cut_edge_set): remapped_cutset_flattened_view = remapped_cutset.view(dtype="u8,u8") edges_flattened_view = self.cg_edges.view(dtype="u8,u8") - cutset_mask = np.in1d(remapped_cutset_flattened_view, edges_flattened_view) + cutset_mask = np.isin(remapped_cutset_flattened_view, edges_flattened_view).ravel() return remapped_cutset[cutset_mask] @@ -432,8 +432,8 @@ def _get_split_preview_connected_components(self, cut_edge_set): max_sinks = 0 i = 0 for cc in ccs_test_post_cut: - num_sources = np.count_nonzero(np.in1d(self.source_graph_ids, cc)) - num_sinks = np.count_nonzero(np.in1d(self.sink_graph_ids, cc)) + num_sources = np.count_nonzero(np.isin(self.source_graph_ids, cc)) + num_sinks = np.count_nonzero(np.isin(self.sink_graph_ids, cc)) if num_sources > max_sources: max_sources = num_sources max_source_index = i @@ -486,13 +486,15 @@ def _filter_graph_connected_components(self): # If connected component contains no sources or no sinks, # remove its nodes from the mincut computation if not ( - np.any(np.in1d(self.source_graph_ids, cc)) - and np.any(np.in1d(self.sink_graph_ids, cc)) + np.any(np.isin(self.source_graph_ids, cc)) + and np.any(np.isin(self.sink_graph_ids, cc)) ): for node_id in cc: removed[node_id] = True - self.weighted_graph.set_vertex_filter(removed, inverted=True) + keep = self.weighted_graph.new_vertex_property("bool") + keep.a = ~removed.a.astype(bool) + self.weighted_graph.set_vertex_filter(keep) pruned_graph = graph_tool.Graph(self.weighted_graph, prune=True) # Test that there is only one connected component left ccs = flatgraph.connected_components(pruned_graph) @@ -525,13 +527,13 @@ def _gt_mincut_sanity_check(self, partition): np.array(np.where(partition.a == i_cc)[0], dtype=int) ] - if np.any(np.in1d(self.sources, cc_list)): - assert np.all(np.in1d(self.sources, cc_list)) - assert ~np.any(np.in1d(self.sinks, cc_list)) + if np.any(np.isin(self.sources, cc_list)): + assert np.all(np.isin(self.sources, cc_list)) + assert ~np.any(np.isin(self.sinks, cc_list)) - if np.any(np.in1d(self.sinks, cc_list)): - assert np.all(np.in1d(self.sinks, cc_list)) - assert ~np.any(np.in1d(self.sources, cc_list)) + if np.any(np.isin(self.sinks, cc_list)): + assert np.all(np.isin(self.sinks, cc_list)) + assert ~np.any(np.isin(self.sources, cc_list)) def _sink_and_source_connectivity_sanity_check(self, cut_edge_set): """ @@ -547,7 +549,8 @@ def _sink_and_source_connectivity_sanity_check(self, cut_edge_set): for edge_to_remove in parallel_edges: self.edges_to_remove[edge_to_remove] = True - self.weighted_graph.set_edge_filter(self.edges_to_remove, True) + self.edges_to_remove.a = ~self.edges_to_remove.a.astype(bool) + self.weighted_graph.set_edge_filter(self.edges_to_remove) ccs_test_post_cut = flatgraph.connected_components(self.weighted_graph) # Make sure sinks and sources are among each other and not in different sets @@ -555,9 +558,9 @@ def _sink_and_source_connectivity_sanity_check(self, cut_edge_set): illegal_split = False try: for cc in ccs_test_post_cut: - if np.any(np.in1d(self.source_graph_ids, cc)): - assert np.all(np.in1d(self.source_graph_ids, cc)) - assert ~np.any(np.in1d(self.sink_graph_ids, cc)) + if np.any(np.isin(self.source_graph_ids, cc)): + assert np.all(np.isin(self.source_graph_ids, cc)) + assert ~np.any(np.isin(self.sink_graph_ids, cc)) if ( len(self.source_path_vertices) == len(cc) and self.disallow_isolating_cut @@ -565,9 +568,9 @@ def _sink_and_source_connectivity_sanity_check(self, cut_edge_set): if not self.partition_edges_within_label(cc): raise IsolatingCutException("Source") - if np.any(np.in1d(self.sink_graph_ids, cc)): - assert np.all(np.in1d(self.sink_graph_ids, cc)) - assert ~np.any(np.in1d(self.source_graph_ids, cc)) + if np.any(np.isin(self.sink_graph_ids, cc)): + assert np.all(np.isin(self.sink_graph_ids, cc)) + assert ~np.any(np.isin(self.source_graph_ids, cc)) if ( len(self.sink_path_vertices) == len(cc) and self.disallow_isolating_cut @@ -664,8 +667,8 @@ def run_split_preview( supervoxels = np.concatenate( [agg.supervoxels for agg in l2id_agglomeration_d.values()] ) - mask0 = np.in1d(edges.node_ids1, supervoxels) - mask1 = np.in1d(edges.node_ids2, supervoxels) + mask0 = np.isin(edges.node_ids1, supervoxels) + mask1 = np.isin(edges.node_ids2, supervoxels) edges = edges[mask0 & mask1] edges_to_remove, illegal_split = run_multicut( edges, diff --git a/pychunkedgraph/graph/edges/__init__.py b/pychunkedgraph/graph/edges/__init__.py index b0e488d05..80bc57d4a 100644 --- a/pychunkedgraph/graph/edges/__init__.py +++ b/pychunkedgraph/graph/edges/__init__.py @@ -2,104 +2,12 @@ Classes and types for edges """ -from typing import Optional -from collections import namedtuple - -import numpy as np - -from ..utils import basetypes - - -_edge_type_fileds = ("in_chunk", "between_chunk", "cross_chunk") -_edge_type_defaults = ("in", "between", "cross") - -EdgeTypes = namedtuple("EdgeTypes", _edge_type_fileds, defaults=_edge_type_defaults) -EDGE_TYPES = EdgeTypes() - -DEFAULT_AFFINITY = np.finfo(np.float32).tiny -DEFAULT_AREA = np.finfo(np.float32).tiny - - -class Edges: - def __init__( - self, - node_ids1: np.ndarray, - node_ids2: np.ndarray, - *, - affinities: Optional[np.ndarray] = None, - areas: Optional[np.ndarray] = None, - fake_edges=False, - ): - self.node_ids1 = np.array(node_ids1, dtype=basetypes.NODE_ID, copy=False) - self.node_ids2 = np.array(node_ids2, dtype=basetypes.NODE_ID, copy=False) - assert self.node_ids1.size == self.node_ids2.size - - self._as_pairs = None - self._fake_edges = fake_edges - - if affinities is not None and len(affinities) > 0: - self._affinities = np.array(affinities, dtype=basetypes.EDGE_AFFINITY, copy=False) - assert self.node_ids1.size == self._affinities.size - else: - self._affinities = np.full(len(self.node_ids1), DEFAULT_AFFINITY) - - if areas is not None and len(areas) > 0: - self._areas = np.array(areas, dtype=basetypes.EDGE_AREA, copy=False) - assert self.node_ids1.size == self._areas.size - else: - self._areas = np.full(len(self.node_ids1), DEFAULT_AREA) - - @property - def affinities(self) -> np.ndarray: - return self._affinities - - @affinities.setter - def affinities(self, affinities): - self._affinities = affinities - - @property - def areas(self) -> np.ndarray: - return self._areas - - @areas.setter - def areas(self, areas): - self._areas = areas - - def __add__(self, other): - """add two Edges instances""" - node_ids1 = np.concatenate([self.node_ids1, other.node_ids1]) - node_ids2 = np.concatenate([self.node_ids2, other.node_ids2]) - affinities = np.concatenate([self.affinities, other.affinities]) - areas = np.concatenate([self.areas, other.areas]) - return Edges(node_ids1, node_ids2, affinities=affinities, areas=areas) - - def __iadd__(self, other): - self.node_ids1 = np.concatenate([self.node_ids1, other.node_ids1]) - self.node_ids2 = np.concatenate([self.node_ids2, other.node_ids2]) - self.affinities = np.concatenate([self.affinities, other.affinities]) - self.areas = np.concatenate([self.areas, other.areas]) - return self - - def __len__(self): - return self.node_ids1.size - - def __getitem__(self, key): - """`key` must be a boolean numpy array.""" - try: - return Edges( - self.node_ids1[key], - self.node_ids2[key], - affinities=self.affinities[key], - areas=self.areas[key], - ) - except Exception as err: - raise (err) - - def get_pairs(self) -> np.ndarray: - """ - return numpy array of edge pairs [[sv1, sv2] ... ] - """ - if not self._as_pairs is None: - return self._as_pairs - self._as_pairs = np.column_stack((self.node_ids1, self.node_ids2)) - return self._as_pairs +from .definitions import EDGE_TYPES, Edges +from .ocdbt import put_edges, get_edges + +from .stale import ( + get_new_nodes, + get_stale_nodes, + get_latest_edges, + get_latest_edges_wrapper, +) diff --git a/pychunkedgraph/graph/edges/definitions.py b/pychunkedgraph/graph/edges/definitions.py new file mode 100644 index 000000000..26a14dd82 --- /dev/null +++ b/pychunkedgraph/graph/edges/definitions.py @@ -0,0 +1,111 @@ +""" +Edge data structures and type definitions. +""" + +from collections import namedtuple +from typing import Optional + +import numpy as np + +from ..utils import basetypes + + +_edge_type_fileds = ("in_chunk", "between_chunk", "cross_chunk") +_edge_type_defaults = ("in", "between", "cross") + +EdgeTypes = namedtuple("EdgeTypes", _edge_type_fileds, defaults=_edge_type_defaults) +EDGE_TYPES = EdgeTypes() + +DEFAULT_AFFINITY = np.finfo(np.float32).tiny +DEFAULT_AREA = np.finfo(np.float32).tiny +ADJACENCY_DTYPE = np.dtype( + [ + ("node", basetypes.NODE_ID), + ("aff", basetypes.EDGE_AFFINITY), + ("area", basetypes.EDGE_AREA), + ] +) +ZSTD_EDGE_COMPRESSION = 17 + + +class Edges: + def __init__( + self, + node_ids1: np.ndarray, + node_ids2: np.ndarray, + *, + affinities: Optional[np.ndarray] = None, + areas: Optional[np.ndarray] = None, + ): + self.node_ids1 = np.array(node_ids1, dtype=basetypes.NODE_ID) + self.node_ids2 = np.array(node_ids2, dtype=basetypes.NODE_ID) + assert self.node_ids1.size == self.node_ids2.size + + self._as_pairs = None + + if affinities is not None and len(affinities) > 0: + self._affinities = np.array(affinities, dtype=basetypes.EDGE_AFFINITY) + assert self.node_ids1.size == self._affinities.size + else: + self._affinities = np.full(len(self.node_ids1), DEFAULT_AFFINITY) + + if areas is not None and len(areas) > 0: + self._areas = np.array(areas, dtype=basetypes.EDGE_AREA) + assert self.node_ids1.size == self._areas.size + else: + self._areas = np.full(len(self.node_ids1), DEFAULT_AREA) + + @property + def affinities(self) -> np.ndarray: + return self._affinities + + @affinities.setter + def affinities(self, affinities): + self._affinities = affinities + + @property + def areas(self) -> np.ndarray: + return self._areas + + @areas.setter + def areas(self, areas): + self._areas = areas + + def __add__(self, other): + """add two Edges instances""" + node_ids1 = np.concatenate([self.node_ids1, other.node_ids1]) + node_ids2 = np.concatenate([self.node_ids2, other.node_ids2]) + affinities = np.concatenate([self.affinities, other.affinities]) + areas = np.concatenate([self.areas, other.areas]) + return Edges(node_ids1, node_ids2, affinities=affinities, areas=areas) + + def __iadd__(self, other): + self.node_ids1 = np.concatenate([self.node_ids1, other.node_ids1]) + self.node_ids2 = np.concatenate([self.node_ids2, other.node_ids2]) + self.affinities = np.concatenate([self.affinities, other.affinities]) + self.areas = np.concatenate([self.areas, other.areas]) + return self + + def __len__(self): + return self.node_ids1.size + + def __getitem__(self, key): + """`key` must be a boolean numpy array.""" + try: + return Edges( + self.node_ids1[key], + self.node_ids2[key], + affinities=self.affinities[key], + areas=self.areas[key], + ) + except Exception as err: + raise (err) + + def get_pairs(self) -> np.ndarray: + """ + return numpy array of edge pairs [[sv1, sv2] ... ] + """ + if not self._as_pairs is None: + return self._as_pairs + self._as_pairs = np.column_stack((self.node_ids1, self.node_ids2)) + return self._as_pairs diff --git a/pychunkedgraph/graph/edges/ocdbt.py b/pychunkedgraph/graph/edges/ocdbt.py new file mode 100644 index 000000000..99fa1ba68 --- /dev/null +++ b/pychunkedgraph/graph/edges/ocdbt.py @@ -0,0 +1,87 @@ +""" +OCDBT storage I/O for edges. +""" + +from os import environ + +import numpy as np +import tensorstore as ts +import zstandard as zstd +from graph_tool import Graph + +from ..utils import basetypes +from .definitions import ADJACENCY_DTYPE, ZSTD_EDGE_COMPRESSION, Edges + + +def put_edges(destination: str, nodes: np.ndarray, edges: Edges) -> None: + graph_ids, _edges = np.unique(edges.get_pairs(), return_inverse=True) + graph_ids_reverse = {n: i for i, n in enumerate(graph_ids)} + _edges = _edges.reshape(-1, 2) + + graph = Graph(directed=False) + graph.add_edge_list(_edges) + e_aff = graph.new_edge_property("double", vals=edges.affinities) + e_area = graph.new_edge_property("int", vals=edges.areas) + cctx = zstd.ZstdCompressor(level=ZSTD_EDGE_COMPRESSION) + ocdbt_host = environ["OCDBT_COORDINATOR_HOST"] + ocdbt_port = environ["OCDBT_COORDINATOR_PORT"] + + spec = { + "driver": "ocdbt", + "base": destination, + "coordinator": {"address": f"{ocdbt_host}:{ocdbt_port}"}, + } + dataset = ts.KvStore.open(spec).result() + with ts.Transaction() as txn: + for _node in nodes: + node = graph_ids_reverse[_node] + neighbors = graph.get_all_neighbors(node) + adjacency_list = np.zeros(neighbors.size, dtype=ADJACENCY_DTYPE) + adjacency_list["node"] = graph_ids[neighbors] + adjacency_list["aff"] = [e_aff[(node, neighbor)] for neighbor in neighbors] + adjacency_list["area"] = [ + e_area[(node, neighbor)] for neighbor in neighbors + ] + dataset.with_transaction(txn)[str(graph_ids[node])] = cctx.compress( + adjacency_list.tobytes() + ) + + +def get_edges(source: str, nodes: np.ndarray) -> Edges: + spec = {"driver": "ocdbt", "base": source} + dataset = ts.KvStore.open(spec).result() + zdc = zstd.ZstdDecompressor() + + read_futures = [dataset.read(str(n)) for n in nodes] + read_results = [rf.result() for rf in read_futures] + compressed = [rr.value for rr in read_results] + + try: + n_threads = int(environ.get("ZSTD_THREADS", 1)) + except ValueError: + n_threads = 1 + + decompressed = [] + try: + decompressed = zdc.multi_decompress_to_buffer(compressed, threads=n_threads) + except ValueError: + for content in compressed: + decompressed.append(zdc.decompressobj().decompress(content)) + + node_ids1 = [np.empty(0, dtype=basetypes.NODE_ID)] + node_ids2 = [np.empty(0, dtype=basetypes.NODE_ID)] + affinities = [np.empty(0, dtype=basetypes.EDGE_AFFINITY)] + areas = [np.empty(0, dtype=basetypes.EDGE_AREA)] + for n, content in zip(nodes, compressed): + adjacency_list = np.frombuffer(content, dtype=ADJACENCY_DTYPE) + node_ids1.append([n] * adjacency_list.size) + node_ids2.append(adjacency_list["node"]) + affinities.append(adjacency_list["aff"]) + areas.append(adjacency_list["area"]) + + return Edges( + np.concatenate(node_ids1), + np.concatenate(node_ids2), + affinities=np.concatenate(affinities), + areas=np.concatenate(areas), + ) diff --git a/pychunkedgraph/graph/edges/stale.py b/pychunkedgraph/graph/edges/stale.py new file mode 100644 index 000000000..e09dbac35 --- /dev/null +++ b/pychunkedgraph/graph/edges/stale.py @@ -0,0 +1,511 @@ +""" +Stale node detection and edge update logic. +""" + +import datetime +import logging +from os import environ +from typing import Iterable + +import numpy as np +from cachetools import LRUCache + +from pychunkedgraph.graph import types +from pychunkedgraph.graph.chunks.utils import get_l2chunkids_along_boundary + +from ..utils import basetypes +from ..utils.generic import get_parents_at_timestamp + + +PARENTS_CACHE: LRUCache = None +CHILDREN_CACHE: LRUCache = None + + +def get_new_nodes( + cg, nodes: np.ndarray, layer: int, parent_ts: datetime.datetime = None +): + unique_nodes, inverse = np.unique(nodes, return_inverse=True) + node_root_map = {n: n for n in unique_nodes} + lookup = np.ones(len(unique_nodes), dtype=unique_nodes.dtype) + while np.any(lookup): + roots = np.fromiter(node_root_map.values(), dtype=basetypes.NODE_ID) + roots = cg.get_parents(roots, time_stamp=parent_ts, fail_to_zero=True) + layers = cg.get_chunk_layers(roots) + lookup[layers > layer] = 0 + lookup[roots == 0] = 0 + + layer_mask = layers <= layer + non_zero_mask = roots != 0 + mask = layer_mask & non_zero_mask + for node, root in zip(unique_nodes[mask], roots[mask]): + node_root_map[node] = root + + unique_results = np.fromiter(node_root_map.values(), dtype=basetypes.NODE_ID) + return unique_results[inverse] + + +def get_stale_nodes( + cg, nodes: Iterable[basetypes.NODE_ID], parent_ts: datetime.datetime = None +): + """ + Checks to see if given nodes are stale. + This is done by getting a supervoxel of a node and checking + if it has a new parent at the same layer as the node. + """ + nodes = np.unique(np.array(nodes, dtype=basetypes.NODE_ID)) + new_ids = set() if cg.cache is None else cg.cache.new_ids + nodes = nodes[~np.isin(nodes, new_ids)] + supervoxels = cg.get_single_leaf_multiple(nodes) + # nodes can be at different layers due to skip connections + node_layers = cg.get_chunk_layers(nodes) + stale_nodes = [types.empty_1d] + for layer in np.unique(node_layers): + _mask = node_layers == layer + layer_nodes = nodes[_mask] + _nodes = get_new_nodes(cg, supervoxels[_mask], layer, parent_ts) + stale_mask = layer_nodes != _nodes + stale_nodes.append(layer_nodes[stale_mask]) + return np.concatenate(stale_nodes) + + +class LatestEdgesFinder: + """ + For each of stale_edges [[`node`, `partner`]], get their L2 edge equivalent. + Then get supervoxels of those L2 IDs and get parent(s) at `node` level. + These parents would be the new identities for the stale `partner`. + """ + + def __init__( + self, + cg, + stale_edges: Iterable, + edge_layers: Iterable, + parent_ts: datetime.datetime = None, + ): + self.cg = cg + self.stale_edges = stale_edges + self.edge_layers = edge_layers + self.parent_ts = parent_ts + + _nodes = np.unique(stale_edges) + self.nodes_ts_map = dict( + zip( + _nodes, + cg.get_node_timestamps(_nodes, return_numpy=False, normalize=True), + ) + ) + layers, coords = cg.get_chunk_layers_and_coordinates(_nodes) + self.layers_d = dict(zip(_nodes, layers)) + self.coords_d = dict(zip(_nodes, coords)) + + def _get_children_from_cache(self, nodes): + children = [] + non_cached = [] + for node in nodes: + try: + v = CHILDREN_CACHE[node] + children.append(v) + except KeyError: + non_cached.append(node) + + children_map = self.cg.get_children(non_cached) + for k, v in children_map.items(): + CHILDREN_CACHE[k] = v + children.append(v) + return np.concatenate(children) + + def _get_normalized_coords(self, node_a, node_b) -> tuple: + max_layer = self.layers_d[node_a] + coord_a, coord_b = self.coords_d[node_a], self.coords_d[node_b] + if self.layers_d[node_a] != self.layers_d[node_b]: + # normalize if nodes are not from the same layer + max_layer = max(self.layers_d[node_a], self.layers_d[node_b]) + chunk_a = self.cg.get_parent_chunk_id(node_a, parent_layer=max_layer) + chunk_b = self.cg.get_parent_chunk_id(node_b, parent_layer=max_layer) + coord_a, coord_b = self.cg.get_chunk_coordinates_multiple( + [chunk_a, chunk_b] + ) + return max_layer, tuple(coord_a), tuple(coord_b) + + def _get_filtered_l2ids(self, node_a, node_b, padding: int): + """ + Finds L2 IDs along opposing faces for given nodes. + Filterting is done by first finding L2 chunks along these faces. + Then get their parent chunks iteratively. + Then filter children iteratively using these chunks. + """ + chunks_map = {} + + def _filter(node): + result = [] + children = np.array([node], dtype=basetypes.NODE_ID) + while True: + chunk_ids = self.cg.get_chunk_ids_from_node_ids(children) + mask = np.isin(chunk_ids, chunks_map[node]) + children = children[mask] + + mask = self.cg.get_chunk_layers(children) == 2 + result.append(children[mask]) + + mask = self.cg.get_chunk_layers(children) > 2 + if children[mask].size == 0: + break + if PARENTS_CACHE is None: + children = self.cg.get_children(children[mask], flatten=True) + else: + children = self._get_children_from_cache(children[mask]) + return np.concatenate(result) + + mlayer, coord_a, coord_b = self._get_normalized_coords(node_a, node_b) + chunks_a, chunks_b = get_l2chunkids_along_boundary( + self.cg.meta, mlayer, coord_a, coord_b, padding + ) + + chunks_map[node_a] = [[self.cg.get_chunk_id(node_a)]] + chunks_map[node_b] = [[self.cg.get_chunk_id(node_b)]] + _layer = 2 + while _layer < mlayer: + chunks_map[node_a].append(chunks_a) + chunks_map[node_b].append(chunks_b) + chunks_a = np.unique(self.cg.get_parent_chunk_id_multiple(chunks_a)) + chunks_b = np.unique(self.cg.get_parent_chunk_id_multiple(chunks_b)) + _layer += 1 + chunks_map[node_a] = np.concatenate(chunks_map[node_a]) + chunks_map[node_b] = np.concatenate(chunks_map[node_b]) + return int(mlayer), _filter(node_a), _filter(node_b) + + def _populate_parents_cache(self, children: np.ndarray): + global PARENTS_CACHE + + not_cached = [] + for child in children: + try: + # reset lru index, these will be needed soon + _ = PARENTS_CACHE[child] + except KeyError: + not_cached.append(child) + + all_parents = self.cg.get_parents(not_cached, current=False) + for child, parents in zip(not_cached, all_parents): + PARENTS_CACHE[child] = {} + for parent, ts in parents: + PARENTS_CACHE[child][ts] = parent + + def _get_hierarchy(self, nodes, layer): + _hierarchy = [nodes] + for _a in nodes: + _hierarchy.append( + self.cg.get_root( + _a, + time_stamp=self.parent_ts, + stop_layer=layer, + get_all_parents=True, + ceil=False, + raw_only=True, + ) + ) + _children = self.cg.get_children(_a, raw_only=True) + _children_layers = self.cg.get_chunk_layers(_children) + _hierarchy.append(_children[_children_layers == 2]) + _children = _children[_children_layers > 2] + while _children.size: + _hierarchy.append(_children) + _children = self.cg.get_children( + _children, flatten=True, raw_only=True + ) + _children_layers = self.cg.get_chunk_layers(_children) + _hierarchy.append(_children[_children_layers == 2]) + _children = _children[_children_layers > 2] + return np.concatenate(_hierarchy) + + def _check_cross_edges_from_a(self, node_b, nodes_a, layer, parent_ts): + """ + Checks to match cross edges from partners_a + to hierarchy of potential node from partner b. + """ + if len(nodes_a) == 0: + return False + + _hierarchy_b = self.cg.get_root( + node_b, + time_stamp=parent_ts, + stop_layer=layer, + get_all_parents=True, + ceil=False, + raw_only=True, + ) + _hierarchy_b = np.append(_hierarchy_b, node_b) + _cx_edges_d_from_a = self.cg.get_cross_chunk_edges( + nodes_a, time_stamp=parent_ts + ) + for _edges_d_from_a in _cx_edges_d_from_a.values(): + _edges_from_a = _edges_d_from_a.get(layer, types.empty_2d) + nodes_b_from_a = _edges_from_a[:, 1] + hierarchy_b_from_a = self._get_hierarchy(nodes_b_from_a, layer) + _mask = np.isin(hierarchy_b_from_a, _hierarchy_b) + if np.any(_mask): + return True + return False + + def _check_hierarchy_a_from_b(self, parents_a, nodes_a_from_b, layer, parent_ts): + """ + Checks for overlap between hierarchy of a, + and hierarchy of a identified from partners of b. + """ + if len(nodes_a_from_b) == 0: + return False + + _hierarchy_a = [parents_a] + for _a in parents_a: + _hierarchy_a.append( + self.cg.get_root( + _a, + time_stamp=parent_ts, + stop_layer=layer, + get_all_parents=True, + ceil=False, + raw_only=True, + ) + ) + hierarchy_a = np.concatenate(_hierarchy_a) + hierarchy_a_from_b = self._get_hierarchy(nodes_a_from_b, layer) + return np.any(np.isin(hierarchy_a_from_b, hierarchy_a)) + + def _get_parents_b(self, edges, parent_ts, layer, fallback: bool = False): + """ + Attempts to find new partner side nodes. + Gets new partners at parent_ts using supervoxels, at `parent_ts`. + Searches for new partners that may have any edges to `edges[:,0]`. + """ + if PARENTS_CACHE is None: + # this cache is set only during migration + # also, fallback is not applicable if no migration + children_b = self.cg.get_children(edges[:, 1], flatten=True) + parents_b = np.unique( + self.cg.get_parents(children_b, time_stamp=parent_ts) + ) + fallback = False + else: + children_b = self._get_children_from_cache(edges[:, 1]) + self._populate_parents_cache(children_b) + _parents_b, missing = get_parents_at_timestamp( + children_b, PARENTS_CACHE, time_stamp=parent_ts, unique=True + ) + # handle cache miss cases + _parents_b_missed = np.unique( + self.cg.get_parents(missing, time_stamp=parent_ts) + ) + parents_b = np.concatenate([_parents_b, _parents_b_missed]) + + parents_a = np.unique(edges[:, 0]) + stale_a = get_stale_nodes(self.cg, parents_a, parent_ts=parent_ts) + if stale_a.size == parents_a.size or fallback: + # this is applicable only for v2 to v3 migration + # handle cases when source nodes in `edges[:,0]` are stale + atomic_edges_d = self.cg.get_atomic_cross_edges(stale_a) + partners = [types.empty_1d] + for _edges_d in atomic_edges_d.values(): + _edges = _edges_d.get(layer, types.empty_2d) + partners.append(_edges[:, 1]) + partners = np.concatenate(partners) + return np.unique(self.cg.get_parents(partners, time_stamp=parent_ts)) + + _cx_edges_d = self.cg.get_cross_chunk_edges(parents_b, time_stamp=parent_ts) + _parents_b = [] + for _node, _edges_d in _cx_edges_d.items(): + _edges = _edges_d.get(layer, types.empty_2d) + if self._check_cross_edges_from_a( + _node, _edges[:, 1], layer, parent_ts + ): + _parents_b.append(_node) + elif self._check_hierarchy_a_from_b( + parents_a, _edges[:, 1], layer, parent_ts + ): + _parents_b.append(_node) + else: + _new_ids = list(self.cg.cache.new_ids) + if np.any(np.isin(_new_ids, parents_a)): + _parents_b.append(_node) + return np.array(_parents_b, dtype=basetypes.NODE_ID) + + def _get_parents_b_with_chunk_mask( + self, + l2ids_b: np.ndarray, + nodes_b_from_a: np.ndarray, + max_ts: datetime.datetime, + edge, + ): + chunks_old = self.cg.get_chunk_ids_from_node_ids(l2ids_b) + chunks_new = self.cg.get_chunk_ids_from_node_ids(nodes_b_from_a) + chunk_mask = np.isin(chunks_new, chunks_old) + nodes_b_from_a = nodes_b_from_a[chunk_mask] + _stale_nodes = get_stale_nodes(self.cg, nodes_b_from_a, parent_ts=max_ts) + assert _stale_nodes.size == 0, f"{edge}, {_stale_nodes}, {max_ts}" + return nodes_b_from_a + + def _get_cx_edges(self, l2ids_a, max_node_ts, edge_layer, raw_only: bool = True): + _edges_d = self.cg.get_cross_chunk_edges( + node_ids=l2ids_a, time_stamp=max_node_ts, raw_only=raw_only + ) + _edges = [] + for v in _edges_d.values(): + if edge_layer in v: + _edges.append(v[edge_layer]) + return np.concatenate(_edges) + + def _get_dilated_edges(self, edges): + layers_b = self.cg.get_chunk_layers(edges[:, 1]) + _mask = layers_b == 2 + _l2_edges = [edges[_mask]] + for _edge in edges[~_mask]: + _node_a, _node_b = _edge + _nodes_b = self.cg.get_l2children([_node_b]) + _l2_edges.append( + np.array( + [[_node_a, _b] for _b in _nodes_b], dtype=basetypes.NODE_ID + ) + ) + return np.unique(np.concatenate(_l2_edges), axis=0) + + def _get_new_edge( + self, edge, edge_layer, parent_ts, padding, fallback: bool = False + ): + """ + Attempts to find new edge(s) for the stale `edge`. + * Find L2 IDs on opposite sides of the face in L2 chunks along the face. + * Find new edges between them (before the given timestamp). + * If none found, expand search by adding another layer of L2 chunks. + """ + node_a, node_b = edge + mlayer, l2ids_a, l2ids_b = self._get_filtered_l2ids( + node_a, node_b, padding=padding + ) + if l2ids_a.size == 0 or l2ids_b.size == 0: + return types.empty_2d.copy() + + max_ts = max(self.nodes_ts_map[node_a], self.nodes_ts_map[node_b]) + is_l2_edge = node_a in l2ids_a and node_b in l2ids_b + if is_l2_edge and (l2ids_a.size == 1 and l2ids_b.size == 1): + _edges = np.array([edge], dtype=basetypes.NODE_ID) + else: + try: + _edges = self._get_cx_edges(l2ids_a, max_ts, edge_layer) + except ValueError: + _edges = self._get_cx_edges( + l2ids_a, max_ts, edge_layer, raw_only=False + ) + except ValueError: + return types.empty_2d.copy() + + mask = np.isin(_edges[:, 1], l2ids_b) + if np.any(mask): + parents_b = self._get_parents_b(_edges[mask], parent_ts, edge_layer) + else: + # partner nodes likely lifted, dilate and retry + _edges = self._get_dilated_edges(_edges) + mask = np.isin(_edges[:, 1], l2ids_b) + if np.any(mask): + parents_b = self._get_parents_b( + _edges[mask], parent_ts, edge_layer + ) + else: + # if none of `l2ids_b` were found in edges, `l2ids_a` already have new edges + # so get the new identities of `l2ids_b` by using chunk mask + try: + parents_b = self._get_parents_b_with_chunk_mask( + l2ids_b, _edges[:, 1], max_ts, edge + ) + except AssertionError: + parents_b = [] + if fallback: + parents_b = self._get_parents_b( + _edges, parent_ts, edge_layer, True + ) + + parents_b = np.unique( + get_new_nodes(self.cg, parents_b, mlayer, parent_ts) + ) + parents_a = np.array([node_a] * parents_b.size, dtype=basetypes.NODE_ID) + return np.column_stack((parents_a, parents_b)) + + def run(self): + result = [types.empty_2d] + for edge_layer, _edge in zip(self.edge_layers, self.stale_edges): + max_chebyshev_distance = int(environ.get("MAX_CHEBYSHEV_DISTANCE", 3)) + for pad in range(0, max_chebyshev_distance + 1): + fallback = pad == max_chebyshev_distance + _new_edges = self._get_new_edge( + _edge, + edge_layer, + self.parent_ts, + padding=pad, + fallback=fallback, + ) + if _new_edges.size: + break + logging.info(f"{_edge}, expanding search with padding {pad+1}.") + assert ( + _new_edges.size + ), f"No new edge found {_edge}; {edge_layer}, {self.parent_ts}" + result.append(_new_edges) + return np.concatenate(result) + +def get_latest_edges( + cg, + stale_edges: Iterable, + edge_layers: Iterable, + parent_ts: datetime.datetime = None, +) -> np.ndarray: + """ + For each of stale_edges [[`node`, `partner`]], get their L2 edge equivalent. + Then get supervoxels of those L2 IDs and get parent(s) at `node` level. + These parents would be the new identities for the stale `partner`. + """ + return LatestEdgesFinder(cg, stale_edges, edge_layers, parent_ts).run() + + +def get_latest_edges_wrapper( + cg, cx_edges_d: dict, parent_ts: datetime.datetime = None +) -> tuple[dict, np.ndarray]: + """ + Helper function to filter stale edges and replace with latest edges. + Filters out edges with nodes stale in source, edges[:,0], at given timestamp. + """ + nodes = [types.empty_1d] + new_cx_edges_d = {0: types.empty_2d} + + all_edges = np.concatenate(list(cx_edges_d.values())) + all_edge_nodes = np.unique(all_edges) + all_stale_nodes = get_stale_nodes(cg, all_edge_nodes, parent_ts=parent_ts) + if all_stale_nodes.size == 0: + return cx_edges_d, all_edge_nodes + + for layer, _cx_edges in cx_edges_d.items(): + if _cx_edges.size == 0: + continue + + _new_cx_edges = [types.empty_2d] + _edge_layers = np.array([layer] * len(_cx_edges), dtype=int) + + stale_source_mask = np.isin(_cx_edges[:, 0], all_stale_nodes) + _new_cx_edges.append(_cx_edges[stale_source_mask]) + + _cx_edges = _cx_edges[~stale_source_mask] + _edge_layers = _edge_layers[~stale_source_mask] + stale_destination_mask = np.isin(_cx_edges[:, 1], all_stale_nodes) + _new_cx_edges.append(_cx_edges[~stale_destination_mask]) + + if np.any(stale_destination_mask): + stale_edges = _cx_edges[stale_destination_mask] + stale_edge_layers = _edge_layers[stale_destination_mask] + latest_edges = get_latest_edges( + cg, + stale_edges, + stale_edge_layers, + parent_ts=parent_ts, + ) + logging.debug(f"{stale_edges} -> {latest_edges}; {parent_ts}") + _new_cx_edges.append(latest_edges) + new_cx_edges_d[layer] = np.concatenate(_new_cx_edges) + nodes.append(np.unique(new_cx_edges_d[layer])) + return new_cx_edges_d, np.concatenate(nodes) diff --git a/pychunkedgraph/graph/edges/utils.py b/pychunkedgraph/graph/edges/utils.py index 034ca6ebc..f79debf94 100644 --- a/pychunkedgraph/graph/edges/utils.py +++ b/pychunkedgraph/graph/edges/utils.py @@ -8,16 +8,18 @@ from typing import Tuple from typing import Iterable from typing import Optional +from collections import defaultdict +from functools import reduce import fastremap import numpy as np from . import Edges from . import EDGE_TYPES -from ..types import empty_2d from ..utils import basetypes from ..chunks import utils as chunk_utils from ..meta import ChunkedGraphMeta +from ...utils.general import in2d def concatenate_chunk_edges(chunk_edge_dicts: Iterable) -> Dict: @@ -45,18 +47,21 @@ def concatenate_chunk_edges(chunk_edge_dicts: Iterable) -> Dict: return edges_dict -def concatenate_cross_edge_dicts(edges_ds: Iterable[Dict]) -> Dict: +def concatenate_cross_edge_dicts( + edges_ds: Iterable[Dict], unique: bool = False +) -> Dict: """Combines cross chunk edge dicts of form {layer id : edge list}.""" - from collections import defaultdict - result_d = defaultdict(list) - for edges_d in edges_ds: for layer, edges in edges_d.items(): result_d[layer].append(edges) for layer, edge_lists in result_d.items(): - result_d[layer] = np.concatenate(edge_lists) + edge_lists = [np.asarray(e, dtype=basetypes.NODE_ID) for e in edge_lists] + edges = np.concatenate(edge_lists) + if unique: + edges = np.unique(edges, axis=0) + result_d[layer] = edges return result_d @@ -131,7 +136,7 @@ def categorize_edges_v2( def get_cross_chunk_edges_layer(meta: ChunkedGraphMeta, cross_edges: Iterable): - """Computes the layer in which a cross chunk edge becomes relevant. + """Computes the layer in which an atomic cross chunk edge becomes relevant. I.e. if a cross chunk edge links two nodes in layer 4 this function returns 3. :param cross_edges: n x 2 array @@ -152,40 +157,7 @@ def get_cross_chunk_edges_layer(meta: ChunkedGraphMeta, cross_edges: Iterable): return cross_chunk_edge_layers -def filter_min_layer_cross_edges( - meta: ChunkedGraphMeta, cross_edges_d: Dict, node_layer: int = 2 -) -> Tuple[int, Iterable]: - """ - Given a dict of cross chunk edges {layer: edges} - Return the first layer with cross edges. - """ - for layer in range(node_layer, meta.layer_count): - edges_ = cross_edges_d.get(layer, empty_2d) - if edges_.size: - return (layer, edges_) - return (meta.layer_count, edges_) - - -def filter_min_layer_cross_edges_multiple( - meta: ChunkedGraphMeta, l2id_atomic_cross_edges_ds: Iterable, node_layer: int = 2 -) -> Tuple[int, Iterable]: - """ - Given a list of dicts of cross chunk edges [{layer: edges}] - Return the first layer with cross edges. - """ - min_layer = meta.layer_count - for edges_d in l2id_atomic_cross_edges_ds: - layer_, _ = filter_min_layer_cross_edges(meta, edges_d, node_layer=node_layer) - min_layer = min(min_layer, layer_) - edges = [empty_2d] - for edges_d in l2id_atomic_cross_edges_ds: - edges.append(edges_d.get(min_layer, empty_2d)) - return min_layer, np.concatenate(edges) - - def get_edges_status(cg, edges: Iterable, time_stamp: Optional[float] = None): - from ...utils.general import in2d - coords0 = chunk_utils.get_chunk_coordinates_multiple(cg.meta, edges[:, 0]) coords1 = chunk_utils.get_chunk_coordinates_multiple(cg.meta, edges[:, 1]) @@ -214,3 +186,20 @@ def get_edges_status(cg, edges: Iterable, time_stamp: Optional[float] = None): active_status.extend(mask) active_status = np.array(active_status, dtype=bool) return existence_status, active_status + + +def filter_inactive_cross_edges( + cg, all_chunk_edges: Edges, time_stamp: Optional[float] = None +): + result = [] + layers = cg.get_cross_chunk_edges_layer(all_chunk_edges.get_pairs()) + for layer in np.unique(layers): + layer_mask = layers == layer + parent_layer = layer + 1 + layer_edges = all_chunk_edges[layer_mask] + n1, n2 = layer_edges.node_ids1, layer_edges.node_ids2 + parents1 = cg.get_roots(n1, stop_layer=parent_layer, time_stamp=time_stamp) + parents2 = cg.get_roots(n2, stop_layer=parent_layer, time_stamp=time_stamp) + mask = parents1 == parents2 + result.append(layer_edges[mask]) + return reduce(lambda x, y: x + y, result, Edges([], [])) diff --git a/pychunkedgraph/graph/edits.py b/pychunkedgraph/graph/edits.py index be2eee1c6..25f31dd02 100644 --- a/pychunkedgraph/graph/edits.py +++ b/pychunkedgraph/graph/edits.py @@ -1,46 +1,70 @@ # pylint: disable=invalid-name, missing-docstring, too-many-locals, c-extension-no-member -import datetime +import datetime, logging, random from typing import Dict from typing import List from typing import Tuple from typing import Iterable +from typing import Set from collections import defaultdict -import numpy as np import fastremap +import numpy as np + +from pychunkedgraph.debug.profiler import HierarchicalProfiler, get_profiler from . import types from . import attributes from . import cache as cache_utils +from .edges import get_latest_edges_wrapper, get_new_nodes from .edges.utils import concatenate_cross_edge_dicts from .edges.utils import merge_cross_edge_dicts from .utils import basetypes from .utils import flatgraph from .utils.serializers import serialize_uint64 -from ..logging.log_db import TimeIt from ..utils.general import in2d +from ..debug.utils import sanity_check, sanity_check_single + +logger = logging.getLogger(__name__) def _init_old_hierarchy(cg, l2ids: np.ndarray, parent_ts: datetime.datetime = None): - new_old_id_d = defaultdict(set) - old_new_id_d = defaultdict(set) + """ + Populates old hierarcy from child to root and also gets children of intermediate nodes. + These will be needed later and cached in cg.cache used during an edit. + """ + all_parents = [] old_hierarchy_d = {id_: {2: id_} for id_ in l2ids} + node_layer_parent_map = cg.get_all_parents_dict_multiple( + l2ids, time_stamp=parent_ts + ) for id_ in l2ids: - layer_parent_d = cg.get_all_parents_dict(id_, time_stamp=parent_ts) + layer_parent_d = node_layer_parent_map[id_] old_hierarchy_d[id_].update(layer_parent_d) for parent in layer_parent_d.values(): + all_parents.append(parent) old_hierarchy_d[parent] = old_hierarchy_d[id_] - return new_old_id_d, old_new_id_d, old_hierarchy_d + children = cg.get_children(all_parents, flatten=True) + _ = cg.get_parents(children, time_stamp=parent_ts) + return old_hierarchy_d + + +def flip_ids(id_map, node_ids): + """ + returns old or new ids according to the map + """ + ids = [np.asarray(list(id_map[id_]), dtype=basetypes.NODE_ID) for id_ in node_ids] + ids.append(types.empty_1d) # concatenate needs at least one array + return np.concatenate(ids).astype(basetypes.NODE_ID) def _analyze_affected_edges( cg, atomic_edges: Iterable[np.ndarray], parent_ts: datetime.datetime = None ) -> Tuple[Iterable, Dict]: """ - Determine if atomic edges are within the chunk. - If not, they are cross edges between two L2 IDs in adjacent chunks. - Returns edges between L2 IDs and atomic cross edges. + Returns l2 edges within chunk and self edges for nodes in cross chunk edges. + + Also returns new cross edges dicts for nodes crossing chunk boundary. """ supervoxels = np.unique(atomic_edges) parents = cg.get_parents(supervoxels, time_stamp=parent_ts) @@ -51,23 +75,29 @@ def _analyze_affected_edges( for edge_ in atomic_edges[edge_layers == 1] ] - # cross chunk edges - atomic_cross_edges_d = defaultdict(lambda: defaultdict(list)) + cross_edges_d = defaultdict(lambda: defaultdict(list)) for layer in range(2, cg.meta.layer_count): layer_edges = atomic_edges[edge_layers == layer] if not layer_edges.size: continue for edge in layer_edges: - parent_1 = sv_parent_d[edge[0]] - parent_2 = sv_parent_d[edge[1]] - atomic_cross_edges_d[parent_1][layer].append(edge) - atomic_cross_edges_d[parent_2][layer].append(edge[::-1]) - parent_edges.extend([[parent_1, parent_1], [parent_2, parent_2]]) - return (parent_edges, atomic_cross_edges_d) - - -def _get_relevant_components(edges: np.ndarray, supervoxels: np.ndarray) -> Tuple: - edges = np.concatenate([edges, np.vstack([supervoxels, supervoxels]).T]) + parent0 = sv_parent_d[edge[0]] + parent1 = sv_parent_d[edge[1]] + cross_edges_d[parent0][layer].append([parent0, parent1]) + cross_edges_d[parent1][layer].append([parent1, parent0]) + parent_edges.extend([[parent0, parent0], [parent1, parent1]]) + # Convert inner Python lists to typed numpy arrays to avoid + # dtype promotion issues when concatenated with uint64 arrays. + for node_id in cross_edges_d: + for layer in cross_edges_d[node_id]: + cross_edges_d[node_id][layer] = np.array( + cross_edges_d[node_id][layer], dtype=basetypes.NODE_ID + ).reshape(-1, 2) + return parent_edges, cross_edges_d + + +def _get_relevant_components(edges: np.ndarray, svs: np.ndarray) -> Tuple: + edges = np.concatenate([edges, np.vstack([svs, svs]).T]).astype(basetypes.NODE_ID) graph, _, _, graph_ids = flatgraph.build_gt_graph(edges, make_directed=True) ccs = flatgraph.connected_components(graph) relevant_ccs = [] @@ -75,7 +105,7 @@ def _get_relevant_components(edges: np.ndarray, supervoxels: np.ndarray) -> Tupl # when merging, there must be only two components for cc_idx in ccs: cc = graph_ids[cc_idx] - if np.any(np.in1d(supervoxels, cc)): + if np.any(np.isin(svs, cc)): relevant_ccs.append(cc) assert len(relevant_ccs) == 2, "must be 2 components" return relevant_ccs @@ -89,9 +119,7 @@ def merge_preprocess( parent_ts: datetime.datetime = None, ) -> np.ndarray: """ - Determine if a fake edge needs to be added. - Get subgraph within the bounding box - Add fake edge if there are no inactive edges between two components. + Check and return inactive edges in the subgraph. """ edge_layers = cg.get_cross_chunk_edges_layer(subgraph_edges) active_edges = [types.empty_2d] @@ -108,19 +136,20 @@ def merge_preprocess( active_edges.append(active) inactive_edges.append(inactive) - relevant_ccs = _get_relevant_components(np.concatenate(active_edges), supervoxels) - inactive = np.concatenate(inactive_edges) + active_edges = np.concatenate(active_edges).astype(basetypes.NODE_ID) + inactive_edges = np.concatenate(inactive_edges).astype(basetypes.NODE_ID) + relevant_ccs = _get_relevant_components(active_edges, supervoxels) _inactive = [types.empty_2d] # source to sink edges - source_mask = np.in1d(inactive[:, 0], relevant_ccs[0]) - sink_mask = np.in1d(inactive[:, 1], relevant_ccs[1]) - _inactive.append(inactive[source_mask & sink_mask]) + source_mask = np.isin(inactive_edges[:, 0], relevant_ccs[0]) + sink_mask = np.isin(inactive_edges[:, 1], relevant_ccs[1]) + _inactive.append(inactive_edges[source_mask & sink_mask]) # sink to source edges - sink_mask = np.in1d(inactive[:, 1], relevant_ccs[0]) - source_mask = np.in1d(inactive[:, 0], relevant_ccs[1]) - _inactive.append(inactive[source_mask & sink_mask]) - _inactive = np.concatenate(_inactive) + sink_mask = np.isin(inactive_edges[:, 1], relevant_ccs[0]) + source_mask = np.isin(inactive_edges[:, 0], relevant_ccs[1]) + _inactive.append(inactive_edges[source_mask & sink_mask]) + _inactive = np.concatenate(_inactive).astype(basetypes.NODE_ID) return np.unique(_inactive, axis=0) if _inactive.size else types.empty_2d @@ -142,11 +171,11 @@ def check_fake_edges( ) ) assert len(roots) == 2, "edges must be from 2 roots" - print("found inactive", len(inactive_edges)) return inactive_edges, [] rows = [] supervoxels = atomic_edges.ravel() + # fake edges are stored with l2 chunks chunk_ids = cg.get_chunk_ids_from_node_ids( cg.get_parents(supervoxels, time_stamp=parent_ts) ) @@ -177,7 +206,6 @@ def check_fake_edges( time_stamp=time_stamp, ) ) - print("no inactive", len(atomic_edges)) return atomic_edges, rows @@ -189,90 +217,141 @@ def add_edges( time_stamp: datetime.datetime = None, parent_ts: datetime.datetime = None, allow_same_segment_merge=False, + stitch_mode: bool = False, + do_sanity_check: bool = True, ): - edges, l2_atomic_cross_edges_d = _analyze_affected_edges( + edges, l2_cross_edges_d = _analyze_affected_edges( cg, atomic_edges, parent_ts=parent_ts ) l2ids = np.unique(edges) - if not allow_same_segment_merge: - assert ( - np.unique(cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts)).size - == 2 - ), "L2 IDs must belong to different roots." - new_old_id_d, old_new_id_d, old_hierarchy_d = _init_old_hierarchy( - cg, l2ids, parent_ts=parent_ts - ) + if not allow_same_segment_merge and not stitch_mode: + roots = cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts) + assert np.unique(roots).size >= 2, "L2 IDs must belong to different roots." + + new_old_id_d = defaultdict(set) + old_new_id_d = defaultdict(set) + old_hierarchy_d = _init_old_hierarchy(cg, l2ids, parent_ts=parent_ts) atomic_children_d = cg.get_children(l2ids) - atomic_cross_edges_d = merge_cross_edge_dicts( - cg.get_atomic_cross_edges(l2ids), l2_atomic_cross_edges_d + cross_edges_d = merge_cross_edge_dicts( + cg.get_cross_chunk_edges(l2ids, time_stamp=parent_ts), l2_cross_edges_d ) - graph, _, _, graph_ids = flatgraph.build_gt_graph(edges, make_directed=True) components = flatgraph.connected_components(graph) + + chunk_count_map = defaultdict(int) + for cc_indices in components: + l2ids_ = graph_ids[cc_indices] + chunk = cg.get_chunk_id(l2ids_[0]) + chunk_count_map[chunk] += 1 + + chunk_ids = list(chunk_count_map.keys()) + random.shuffle(chunk_ids) + chunk_new_ids_map = {} + for chunk_id in chunk_ids: + new_ids = cg.id_client.create_node_ids(chunk_id, size=chunk_count_map[chunk_id]) + chunk_new_ids_map[chunk_id] = list(new_ids) + new_l2_ids = [] for cc_indices in components: l2ids_ = graph_ids[cc_indices] - new_id = cg.id_client.create_node_id(cg.get_chunk_id(l2ids_[0])) - cg.cache.children_cache[new_id] = np.concatenate( - [atomic_children_d[l2id] for l2id in l2ids_] - ) - cg.cache.atomic_cx_edges_cache[new_id] = concatenate_cross_edge_dicts( - [atomic_cross_edges_d[l2id] for l2id in l2ids_] - ) - cache_utils.update( - cg.cache.parents_cache, cg.cache.children_cache[new_id], new_id - ) + new_id = chunk_new_ids_map[cg.get_chunk_id(l2ids_[0])].pop() new_l2_ids.append(new_id) new_old_id_d[new_id].update(l2ids_) for id_ in l2ids_: old_new_id_d[id_].add(new_id) - create_parents = CreateParentNodes( - cg, - new_l2_ids=new_l2_ids, - old_hierarchy_d=old_hierarchy_d, - new_old_id_d=new_old_id_d, - old_new_id_d=old_new_id_d, - operation_id=operation_id, - time_stamp=time_stamp, - parent_ts=parent_ts, - ) + # update cache + # map parent to new merged children and vice versa + merged_children = [atomic_children_d[l2id] for l2id in l2ids_] + merged_children = np.concatenate(merged_children).astype(basetypes.NODE_ID) + cg.cache.children_cache[new_id] = merged_children + cache_utils.update(cg.cache.parents_cache, merged_children, new_id) - new_roots = create_parents.run() - new_entries = create_parents.create_new_entries() - return new_roots, new_l2_ids, new_entries + # update cross chunk edges by replacing old_ids with new + # this can be done only after all new IDs have been created + for new_id, cc_indices in zip(new_l2_ids, components): + l2ids_ = graph_ids[cc_indices] + new_cx_edges_d = {} + cx_edges = [cross_edges_d[l2id] for l2id in l2ids_] + cx_edges_d = concatenate_cross_edge_dicts(cx_edges, unique=True) + temp_map = {k: next(iter(v)) for k, v in old_new_id_d.items()} + for layer, edges in cx_edges_d.items(): + edges = fastremap.remap(edges, temp_map, preserve_missing_labels=True) + new_cx_edges_d[layer] = edges + assert np.all(edges[:, 0] == new_id) + cg.cache.cross_chunk_edges_cache[new_id] = new_cx_edges_d + + profiler = get_profiler() + profiler.reset() + with profiler.profile("run"): + create_parents = CreateParentNodes( + cg, + new_l2_ids=new_l2_ids, + old_hierarchy_d=old_hierarchy_d, + new_old_id_d=new_old_id_d, + old_new_id_d=old_new_id_d, + operation_id=operation_id, + time_stamp=time_stamp, + parent_ts=parent_ts, + stitch_mode=stitch_mode, + do_sanity_check=do_sanity_check, + profiler=profiler, + ) + new_roots = create_parents.run() + if do_sanity_check: + sanity_check(cg, new_roots, operation_id) + create_parents.create_new_entries() + profiler.print_report(operation_id) + return new_roots, new_l2_ids, create_parents.new_entries -def _process_l2_agglomeration( + +def _split_l2_agglomeration( + cg, + operation_id: int, agg: types.Agglomeration, removed_edges: np.ndarray, - atomic_cross_edges_d: Dict[int, np.ndarray], + parent_ts: datetime.datetime = None, ): """ - For a given L2 id, remove given edges - and calculate new connected components. + For a given L2 id, remove given edges; calculate new connected components. """ chunk_edges = agg.in_edges.get_pairs() - cross_edges = np.concatenate([types.empty_2d, *atomic_cross_edges_d.values()]) chunk_edges = chunk_edges[~in2d(chunk_edges, removed_edges)] - cross_edges = cross_edges[~in2d(cross_edges, removed_edges)] - isolated_ids = agg.supervoxels[~np.in1d(agg.supervoxels, chunk_edges)] + cross_edges = agg.cross_edges.get_pairs() + # we must avoid the cache to read roots to get segment state before edit began + parents = cg.get_parents(cross_edges[:, 0], time_stamp=parent_ts, raw_only=True) + + # if there are cross edges, there must be a single parent. + # if there aren't any, there must be no parents. XOR these 2 conditions. + err = f"got cross edges from more than one l2 node; op {operation_id}" + assert (np.unique(parents).size == 1) != (cross_edges.size == 0), err + + if cross_edges.size: + # inactive edges must be filtered out + root = cg.get_root(parents[0], time_stamp=parent_ts, raw_only=True) + neighbor_roots = cg.get_roots( + cross_edges[:, 1], raw_only=True, time_stamp=parent_ts + ) + active_mask = neighbor_roots == root + cross_edges = cross_edges[active_mask] + cross_edges = cross_edges[~in2d(cross_edges, removed_edges)] + isolated_ids = agg.supervoxels[~np.isin(agg.supervoxels, chunk_edges)] isolated_edges = np.column_stack((isolated_ids, isolated_ids)) - graph, _, _, graph_ids = flatgraph.build_gt_graph( - np.concatenate([chunk_edges, isolated_edges]), make_directed=True - ) + _edges = np.concatenate([chunk_edges, isolated_edges]).astype(basetypes.NODE_ID) + graph, _, _, graph_ids = flatgraph.build_gt_graph(_edges, make_directed=True) return flatgraph.connected_components(graph), graph_ids, cross_edges def _filter_component_cross_edges( - cc_ids: np.ndarray, cross_edges: np.ndarray, cross_edge_layers: np.ndarray + component_ids: np.ndarray, cross_edges: np.ndarray, cross_edge_layers: np.ndarray ) -> Dict[int, np.ndarray]: """ Filters cross edges for a connected component `cc_ids` from `cross_edges` of the complete chunk. """ - mask = np.in1d(cross_edges[:, 0], cc_ids) + mask = np.isin(cross_edges[:, 0], component_ids) cross_edges_ = cross_edges[mask] cross_edge_layers_ = cross_edge_layers[mask] edges_d = {} @@ -288,45 +367,59 @@ def remove_edges( cg, *, atomic_edges: Iterable[np.ndarray], - l2id_agglomeration_d: Dict, - operation_id: basetypes.OPERATION_ID = None, + operation_id: basetypes.OPERATION_ID = None, # type: ignore time_stamp: datetime.datetime = None, parent_ts: datetime.datetime = None, + do_sanity_check: bool = True, ): edges, _ = _analyze_affected_edges(cg, atomic_edges, parent_ts=parent_ts) l2ids = np.unique(edges) - assert ( - np.unique(cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts)).size - == 1 - ), "L2 IDs must belong to same root." - new_old_id_d, old_new_id_d, old_hierarchy_d = _init_old_hierarchy( - cg, l2ids, parent_ts=parent_ts + roots = cg.get_roots(l2ids, assert_roots=True, time_stamp=parent_ts) + assert np.unique(roots).size == 1, "L2 IDs must belong to same root." + + l2id_agglomeration_d, _ = cg.get_l2_agglomerations( + l2ids, active=True, time_stamp=parent_ts ) - l2id_chunk_id_d = dict(zip(l2ids.tolist(), cg.get_chunk_ids_from_node_ids(l2ids))) - atomic_cross_edges_d = cg.get_atomic_cross_edges(l2ids) + new_old_id_d = defaultdict(set) + old_new_id_d = defaultdict(set) + old_hierarchy_d = _init_old_hierarchy(cg, l2ids, parent_ts=parent_ts) + chunk_id_map = dict(zip(l2ids.tolist(), cg.get_chunk_ids_from_node_ids(l2ids))) - removed_edges = np.concatenate([atomic_edges, atomic_edges[:, ::-1]], axis=0) + removed_edges = [atomic_edges, atomic_edges[:, ::-1]] + removed_edges = np.concatenate(removed_edges, axis=0).astype(basetypes.NODE_ID) new_l2_ids = [] for id_ in l2ids: - l2_agg = l2id_agglomeration_d[id_] - ccs, graph_ids, cross_edges = _process_l2_agglomeration( - l2_agg, removed_edges, atomic_cross_edges_d[id_] + agg = l2id_agglomeration_d[id_] + ccs, graph_ids, cross_edges = _split_l2_agglomeration( + cg, operation_id, agg, removed_edges, parent_ts ) - # calculated here to avoid repeat computation in loop + new_parents = cg.id_client.create_node_ids(chunk_id_map[agg.node_id], len(ccs)) + cross_edge_layers = cg.get_cross_chunk_edges_layer(cross_edges) - new_parent_ids = cg.id_client.create_node_ids( - l2id_chunk_id_d[l2_agg.node_id], len(ccs) - ) for i_cc, cc in enumerate(ccs): - new_id = new_parent_ids[i_cc] - cg.cache.children_cache[new_id] = graph_ids[cc] - cg.cache.atomic_cx_edges_cache[new_id] = _filter_component_cross_edges( - graph_ids[cc], cross_edges, cross_edge_layers - ) - cache_utils.update(cg.cache.parents_cache, graph_ids[cc], new_id) + new_id = new_parents[i_cc] new_l2_ids.append(new_id) new_old_id_d[new_id].add(id_) old_new_id_d[id_].add(new_id) + cg.cache.children_cache[new_id] = graph_ids[cc] + cache_utils.update(cg.cache.parents_cache, graph_ids[cc], new_id) + cg.cache.cross_chunk_edges_cache[new_id] = _filter_component_cross_edges( + graph_ids[cc], cross_edges, cross_edge_layers + ) + + cx_edges_d = cg.get_cross_chunk_edges(new_l2_ids, time_stamp=parent_ts) + for new_id in new_l2_ids: + new_cx_edges_d = cx_edges_d.get(new_id, {}) + for layer, edges in new_cx_edges_d.items(): + svs = np.unique(edges) + parents = cg.get_parents(svs, time_stamp=parent_ts) + temp_map = dict(zip(svs, parents)) + + edges = fastremap.remap(edges, temp_map, preserve_missing_labels=True) + edges = np.unique(edges, axis=0) + new_cx_edges_d[layer] = edges + assert np.all(edges[:, 0] == new_id) + cg.cache.cross_chunk_edges_cache[new_id] = new_cx_edges_d create_parents = CreateParentNodes( cg, @@ -337,10 +430,154 @@ def remove_edges( operation_id=operation_id, time_stamp=time_stamp, parent_ts=parent_ts, + do_sanity_check=do_sanity_check, ) new_roots = create_parents.run() - new_entries = create_parents.create_new_entries() - return new_roots, new_l2_ids, new_entries + + if do_sanity_check: + sanity_check(cg, new_roots, operation_id) + create_parents.create_new_entries() + return new_roots, new_l2_ids, create_parents.new_entries + + +def _get_descendants_batch(cg, node_ids): + """Get all descendants at layers >= 2 for multiple node_ids. + Batches get_children calls by level to reduce IO. + Returns dict {node_id: np.ndarray of descendants}. + """ + if not node_ids: + return {} + results = {nid: [] for nid in node_ids} + # expand_map: {node_to_expand: root_node_id} + expand_map = {nid: nid for nid in node_ids} + + while expand_map: + next_expand = {} + children_d = cg.get_children(list(expand_map.keys())) + for parent, root in expand_map.items(): + children = children_d[parent] + layers = cg.get_chunk_layers(children) + mask = layers >= 2 + results[root].extend(children[mask]) + for c in children[layers > 2]: + next_expand[c] = root + expand_map = next_expand + return { + nid: np.array(desc, dtype=basetypes.NODE_ID) for nid, desc in results.items() + } + + +def _get_counterparts( + cg, node_id: int, cx_edges_d: dict +) -> Tuple[List[int], Dict[int, int]]: + """ + Extract counterparts and their corresponding layers from cross chunk edges. + Returns (counterparts list, counterpart_layers dict). + """ + node_layer = cg.get_chunk_layer(node_id) + counterparts = [] + counterpart_layers = {} + for layer in range(node_layer, cg.meta.layer_count): + layer_edges = cx_edges_d.get(layer, types.empty_2d) + if layer_edges.size == 0: + continue + counterparts.extend(layer_edges[:, 1]) + layers_d = dict(zip(layer_edges[:, 1], [layer] * len(layer_edges[:, 1]))) + counterpart_layers.update(layers_d) + return counterparts, counterpart_layers + + +def _update_neighbor_cx_edges_single( + cg, + new_id: int, + node_map: dict, + counterpart_layers: dict, + all_counterparts_cx_edges_d: dict, + descendants_d: dict, +) -> dict: + """ + For each new_id, update cross chunk edges of its counterparts. + Some of them maybe updated multiple times so we need to collect them first + and then write to storage to consolidate the mutations. + Returns updated counterparts. + """ + node_layer = cg.get_chunk_layer(new_id) + counterparts = list(counterpart_layers.keys()) + cp_cx_edges_d = {cp: all_counterparts_cx_edges_d.get(cp, {}) for cp in counterparts} + updated_counterparts = {} + for counterpart, edges_d in cp_cx_edges_d.items(): + val_dict = {} + counterpart_layer = counterpart_layers[counterpart] + for layer in range(node_layer, cg.meta.layer_count): + edges = edges_d.get(layer, types.empty_2d) + if edges.size == 0: + continue + assert np.all(edges[:, 0] == counterpart) + edges = fastremap.remap(edges, node_map, preserve_missing_labels=True) + if layer == counterpart_layer: + flip_edge = np.array([counterpart, new_id], dtype=basetypes.NODE_ID) + edges = np.concatenate([edges, [flip_edge]]).astype(basetypes.NODE_ID) + descendants = descendants_d[new_id] + mask = np.isin(edges[:, 1], descendants) + if np.any(mask): + masked_edges = edges[mask] + masked_edges[:, 1] = new_id + edges[mask] = masked_edges + edges = np.unique(edges, axis=0) + edges_d[layer] = edges + val_dict[attributes.Connectivity.CrossChunkEdge[layer]] = edges + if not val_dict: + continue + cg.cache.cross_chunk_edges_cache[counterpart] = edges_d + updated_counterparts[counterpart] = val_dict + return updated_counterparts + + +def _update_neighbor_cx_edges( + cg, + new_ids: List[int], + new_old_id: dict, + old_new_id, + *, + time_stamp, + parent_ts, +) -> List: + """ + For each new_id, get counterparts and update its cross chunk edges. + Some of them maybe updated multiple times so we need to collect them first + and then write to storage to consolidate the mutations. + Returns mutations to updated counterparts/partner nodes. + """ + updated_counterparts = {} + newid_cx_edges_d = cg.get_cross_chunk_edges(new_ids, time_stamp=parent_ts) + node_map = {} + for k, v in old_new_id.items(): + if len(v) == 1: + node_map[k] = next(iter(v)) + + all_cps = set() + newid_counterpart_info = {} + for _id in new_ids: + counterparts, cp_layers = _get_counterparts(cg, _id, newid_cx_edges_d[_id]) + all_cps.update(counterparts) + newid_counterpart_info[_id] = cp_layers + + all_cx_edges_d = cg.get_cross_chunk_edges(list(all_cps), time_stamp=parent_ts) + descendants_d = _get_descendants_batch(cg, new_ids) + for new_id in new_ids: + m = {old_id: new_id for old_id in flip_ids(new_old_id, [new_id])} + node_map.update(m) + cp_layers = newid_counterpart_info[new_id] + result = _update_neighbor_cx_edges_single( + cg, new_id, node_map, cp_layers, all_cx_edges_d, descendants_d + ) + updated_counterparts.update(result) + updated_entries = [] + for node, val_dict in updated_counterparts.items(): + rowkey = serialize_uint64(node) + row = cg.client.mutate_row(rowkey, val_dict, time_stamp=time_stamp) + updated_entries.append(row) + return updated_entries class CreateParentNodes: @@ -349,32 +586,39 @@ def __init__( cg, *, new_l2_ids: Iterable, - operation_id: basetypes.OPERATION_ID, + operation_id: basetypes.OPERATION_ID, # type: ignore time_stamp: datetime.datetime, - new_old_id_d: Dict[np.uint64, Iterable[np.uint64]] = None, - old_new_id_d: Dict[np.uint64, Iterable[np.uint64]] = None, + new_old_id_d: Dict[np.uint64, Set[np.uint64]] = None, + old_new_id_d: Dict[np.uint64, Set[np.uint64]] = None, old_hierarchy_d: Dict[np.uint64, Dict[int, np.uint64]] = None, parent_ts: datetime.datetime = None, + stitch_mode: bool = False, + do_sanity_check: bool = True, + profiler: HierarchicalProfiler = None, ): self.cg = cg + self.new_entries = [] self._new_l2_ids = new_l2_ids self._old_hierarchy_d = old_hierarchy_d self._new_old_id_d = new_old_id_d self._old_new_id_d = old_new_id_d - self._new_ids_d = defaultdict(list) # new IDs in each layer - self._cross_edges_d = {} - self._operation_id = operation_id + self._new_ids_d = defaultdict(list) + self._opid = operation_id self._time_stamp = time_stamp - self._last_successful_ts = parent_ts + self._last_ts = parent_ts + self.stitch_mode = stitch_mode + self.do_sanity_check = do_sanity_check + self._profiler = profiler if profiler else get_profiler() def _update_id_lineage( self, - parent: basetypes.NODE_ID, + parent: basetypes.NODE_ID, # type: ignore children: np.ndarray, layer: int, parent_layer: int, ): - mask = np.in1d(children, self._new_ids_d[layer]) + # update newly created children; mask others + mask = np.isin(children, self._new_ids_d[layer]) for child_id in children[mask]: child_old_ids = self._new_old_id_d[child_id] for id_ in child_old_ids: @@ -382,90 +626,148 @@ def _update_id_lineage( self._new_old_id_d[parent].add(old_id) self._old_new_id_d[old_id].add(parent) - def _get_old_ids(self, new_ids): - old_ids = [ - np.array(list(self._new_old_id_d[id_]), dtype=basetypes.NODE_ID) - for id_ in new_ids - ] - return np.concatenate(old_ids) - - def _map_sv_to_parent(self, node_ids, layer, node_map=None): - sv_parent_d = {} - sv_cross_edges = [types.empty_2d] - if node_map is None: - node_map = {} + def _get_connected_components(self, node_ids: np.ndarray, layer: int): + cross_edges_d = self.cg.get_cross_chunk_edges( + node_ids, time_stamp=self._last_ts + ) + cx_edges = [types.empty_2d] for id_ in node_ids: - id_eff = node_map.get(id_, id_) - edges_ = self._cross_edges_d[id_].get(layer, types.empty_2d) - sv_parent_d.update(dict(zip(edges_[:, 0], [id_eff] * len(edges_)))) - sv_cross_edges.append(edges_) - return sv_parent_d, np.concatenate(sv_cross_edges) - - def _get_connected_components( - self, node_ids: np.ndarray, layer: int, lower_layer_ids: np.ndarray - ): - _node_ids = np.concatenate([node_ids, lower_layer_ids]) - cached = np.fromiter(self._cross_edges_d.keys(), dtype=basetypes.NODE_ID) - not_cached = _node_ids[~np.in1d(_node_ids, cached)] - - with TimeIt( - f"get_cross_chunk_edges.{layer}", - self.cg.graph_id, - self._operation_id, - ): - self._cross_edges_d.update( - self.cg.get_cross_chunk_edges(not_cached, all_layers=True) - ) - - sv_parent_d, sv_cross_edges = self._map_sv_to_parent(node_ids, layer) - get_sv_parents = np.vectorize(sv_parent_d.get, otypes=[np.uint64]) - try: - cross_edges = get_sv_parents(sv_cross_edges) - except TypeError: # NoneType error - # if there is a missing parent, try including lower layer ids - # this can happen due to skip connections - - # we want to map all these lower IDs to the current layer - lower_layer_to_layer = self.cg.get_roots( - lower_layer_ids, stop_layer=layer, ceil=False - ) - node_map = {k: v for k, v in zip(lower_layer_ids, lower_layer_to_layer)} - sv_parent_d, sv_cross_edges = self._map_sv_to_parent( - _node_ids, layer, node_map=node_map - ) - get_sv_parents = np.vectorize(sv_parent_d.get, otypes=[np.uint64]) - cross_edges = get_sv_parents(sv_cross_edges) + edges_ = cross_edges_d[id_].get(layer, types.empty_2d) + cx_edges.append(edges_) - cross_edges = np.concatenate([cross_edges, np.vstack([node_ids, node_ids]).T]) - graph, _, _, graph_ids = flatgraph.build_gt_graph( - cross_edges, make_directed=True - ) - return flatgraph.connected_components(graph), graph_ids + cx_edges = [*cx_edges, np.vstack([node_ids, node_ids]).T] + cx_edges = np.concatenate(cx_edges).astype(basetypes.NODE_ID) + graph, _, _, graph_ids = flatgraph.build_gt_graph(cx_edges, make_directed=True) + components = flatgraph.connected_components(graph) + return components, graph_ids def _get_layer_node_ids( self, new_ids: np.ndarray, layer: int ) -> Tuple[np.ndarray, np.ndarray]: # get old identities of new IDs - old_ids = self._get_old_ids(new_ids) + old_ids = flip_ids(self._new_old_id_d, new_ids) # get their parents, then children of those parents - node_ids = self.cg.get_children( - np.unique( - self.cg.get_parents(old_ids, time_stamp=self._last_successful_ts) - ), - flatten=True, - ) + old_parents = self.cg.get_parents(old_ids, time_stamp=self._last_ts) + siblings = self.cg.get_children(np.unique(old_parents), flatten=True) # replace old identities with new IDs - mask = np.in1d(node_ids, old_ids) - node_ids = np.concatenate( - [ - np.array(list(self._old_new_id_d[id_]), dtype=basetypes.NODE_ID) - for id_ in node_ids[mask] - ] - + [node_ids[~mask], new_ids] - ) + mask = np.isin(siblings, old_ids) + node_ids = [flip_ids(self._old_new_id_d, old_ids), siblings[~mask], new_ids] + node_ids = np.concatenate(node_ids).astype(basetypes.NODE_ID) node_ids = np.unique(node_ids) layer_mask = self.cg.get_chunk_layers(node_ids) == layer - return node_ids[layer_mask], node_ids[~layer_mask] + return node_ids[layer_mask] + + def _update_cross_edge_cache_batched(self, new_ids: list): + """ + Batch update cross chunk edges in cache for all new IDs at a layer. + """ + updated_entries = [] + if not new_ids: + return updated_entries + + parent_layer = self.cg.get_chunk_layer(new_ids[0]) + if parent_layer == 2: + # L2 cross edges have already been updated + return updated_entries + + all_children_d = self.cg.get_children(new_ids) + all_children = np.concatenate(list(all_children_d.values())) + all_cx_edges_raw = self.cg.get_cross_chunk_edges( + all_children, time_stamp=self._last_ts + ) + combined_cx_edges = concatenate_cross_edge_dicts(all_cx_edges_raw.values()) + with self._profiler.profile("latest"): + updated_cx_edges, edge_nodes = get_latest_edges_wrapper( + self.cg, combined_cx_edges, parent_ts=self._last_ts + ) + + # update cache with resolved stale edges + val_ds = defaultdict(dict) + children_cx_edges = defaultdict(dict) + for lyr in range(2, self.cg.meta.layer_count): + edges = updated_cx_edges.get(lyr, types.empty_2d) + if len(edges) == 0: + continue + children, inverse = np.unique(edges[:, 0], return_inverse=True) + masks = inverse == np.arange(len(children))[:, None] + for child, mask in zip(children, masks): + children_cx_edges[child][lyr] = edges[mask] + val_ds[child][attributes.Connectivity.CrossChunkEdge[lyr]] = edges[mask] + + for c, cx_edges_map in children_cx_edges.items(): + self.cg.cache.cross_chunk_edges_cache[c] = cx_edges_map + rowkey = serialize_uint64(c) + row = self.cg.client.mutate_row(rowkey, val_ds[c], time_stamp=self._last_ts) + updated_entries.append(row) + + # Distribute results back to each parent's cache + # Key insight: edges[:, 0] are children, map them to their parent + edge_parents = get_new_nodes(self.cg, edge_nodes, parent_layer, self._last_ts) + edge_parents_d = dict(zip(edge_nodes, edge_parents)) + for new_id in new_ids: + children_set = set(all_children_d[new_id]) + parent_cx_edges_d = {} + for layer in range(parent_layer, self.cg.meta.layer_count): + edges = updated_cx_edges.get(layer, types.empty_2d) + if len(edges) == 0: + continue + # Filter to edges whose source is one of this parent's children + mask = np.isin(edges[:, 0], list(children_set)) + if not np.any(mask): + continue + + pedges = edges[mask].copy() + pedges = fastremap.remap( + pedges, edge_parents_d, preserve_missing_labels=True + ) + parent_cx_edges_d[layer] = np.unique(pedges, axis=0) + assert np.all( + pedges[:, 0] == new_id + ), f"OP {self._opid}: mismatch {new_id} != {np.unique(pedges[:, 0])}" + self.cg.cache.cross_chunk_edges_cache[new_id] = parent_cx_edges_d + return updated_entries + + def _get_new_ids(self, chunk_id, count, is_root): + batch_size = count + new_ids = [] + while len(new_ids) < count: + candidate_ids = self.cg.id_client.create_node_ids( + chunk_id, batch_size, root_chunk=is_root + ) + existing = self.cg.client.read_nodes(node_ids=candidate_ids) + non_existing = set(candidate_ids) - existing.keys() + new_ids.extend(non_existing) + batch_size = min(batch_size * 2, 2**16) + return new_ids[:count] + + def _get_new_parents(self, layer, ccs, graph_ids) -> tuple[dict, dict]: + cc_layer_chunk_map = {} + size_map = defaultdict(int) + for i, cc_idx in enumerate(ccs): + parent_layer = layer + 1 # must be reset for each connected component + cc_ids = graph_ids[cc_idx] + if len(cc_ids) == 1: + # skip connection + parent_layer = self.cg.meta.layer_count + cx_edges_d = self.cg.get_cross_chunk_edges( + [cc_ids[0]], time_stamp=self._last_ts + ) + for l in range(layer + 1, self.cg.meta.layer_count): + if len(cx_edges_d[cc_ids[0]].get(l, types.empty_2d)) > 0: + parent_layer = l + break + chunk_id = self.cg.get_parent_chunk_id(cc_ids[0], parent_layer) + cc_layer_chunk_map[i] = (parent_layer, chunk_id) + size_map[chunk_id] += 1 + + chunk_ids = list(size_map.keys()) + random.shuffle(chunk_ids) + chunk_new_ids_map = {} + layers = self.cg.get_chunk_layers(chunk_ids) + for c, l in zip(chunk_ids, layers): + is_root = l == self.cg.meta.layer_count + chunk_new_ids_map[c] = self._get_new_ids(c, size_map[c], is_root) + return chunk_new_ids_map, cc_layer_chunk_map def _create_new_parents(self, layer: int): """ @@ -478,33 +780,35 @@ def _create_new_parents(self, layer: int): update parent old IDs """ new_ids = self._new_ids_d[layer] - layer_node_ids, lower_layer_ids = self._get_layer_node_ids(new_ids, layer) - components, graph_ids = self._get_connected_components( - layer_node_ids, layer, lower_layer_ids - ) - for cc_indices in components: - parent_layer = layer + 1 - cc_ids = graph_ids[cc_indices] - if len(cc_ids) == 1: - # skip connection - parent_layer = self.cg.meta.layer_count - for l in range(layer + 1, self.cg.meta.layer_count): - if len(self._cross_edges_d[cc_ids[0]].get(l, types.empty_2d)) > 0: - parent_layer = l - break + layer_node_ids = self._get_layer_node_ids(new_ids, layer) + ccs, _ids = self._get_connected_components(layer_node_ids, layer) + new_parents_map, cc_layer_chunk_map = self._get_new_parents(layer, ccs, _ids) + + for i, cc_indices in enumerate(ccs): + cc_ids = _ids[cc_indices] + parent_layer, chunk_id = cc_layer_chunk_map[i] + parent = new_parents_map[chunk_id].pop() + + self._new_ids_d[parent_layer].append(parent) + self._update_id_lineage(parent, cc_ids, layer, parent_layer) + self.cg.cache.children_cache[parent] = cc_ids + cache_utils.update(self.cg.cache.parents_cache, cc_ids, parent) + if not self.do_sanity_check: + continue - parent_id = self.cg.id_client.create_node_id( - self.cg.get_parent_chunk_id(cc_ids[0], parent_layer), - root_chunk=parent_layer == self.cg.meta.layer_count, - ) - self._new_ids_d[parent_layer].append(parent_id) - self.cg.cache.children_cache[parent_id] = cc_ids - cache_utils.update( - self.cg.cache.parents_cache, - cc_ids, - parent_id, - ) - self._update_id_lineage(parent_id, cc_ids, layer, parent_layer) + try: + sanity_check_single(self.cg, parent, self._opid) + except AssertionError: + pairs = [ + (a, b) for idx, a in enumerate(cc_ids) for b in cc_ids[idx + 1 :] + ] + for c1, c2 in pairs: + l2c1 = self.cg.get_l2children([c1]) + l2c2 = self.cg.get_l2children([c2]) + if np.intersect1d(l2c1, l2c2).size: + c = np.intersect1d(l2c1, l2c2) + msg = f"{self._opid}: {layer} {c1} {c2} common children {c}" + raise ValueError(msg) def run(self) -> Iterable: """ @@ -513,30 +817,45 @@ def run(self) -> Iterable: """ self._new_ids_d[2] = self._new_l2_ids for layer in range(2, self.cg.meta.layer_count): - if len(self._new_ids_d[layer]) == 0: + new_nodes = self._new_ids_d[layer] + if len(new_nodes) == 0: continue - with TimeIt( - f"create_new_parents_layer.{layer}", - self.cg.graph_id, - self._operation_id, - ): + self.cg.cache.new_ids.update(new_nodes) + # all new IDs in this layer have been created + # update their cross chunk edges and their neighbors' + with self._profiler.profile(f"l{layer}_update_cx_cache"): + entries = self._update_cross_edge_cache_batched(new_nodes) + self.new_entries.extend(entries) + + with self._profiler.profile(f"l{layer}_update_neighbor_cx"): + entries = _update_neighbor_cx_edges( + self.cg, + new_nodes, + self._new_old_id_d, + self._old_new_id_d, + time_stamp=self._time_stamp, + parent_ts=self._last_ts, + ) + self.new_entries.extend(entries) + with self._profiler.profile(f"l{layer}_create_new_parents"): self._create_new_parents(layer) return self._new_ids_d[self.cg.meta.layer_count] def _update_root_id_lineage(self): - new_root_ids = self._new_ids_d[self.cg.meta.layer_count] - former_root_ids = self._get_old_ids(new_root_ids) - former_root_ids = np.unique(former_root_ids) - assert ( - len(former_root_ids) < 2 or len(new_root_ids) < 2 - ), "Something went wrong." - rows = [] - for new_root_id in new_root_ids: + if self.stitch_mode: + return + new_roots = self._new_ids_d[self.cg.meta.layer_count] + former_roots = flip_ids(self._new_old_id_d, new_roots) + former_roots = np.unique(former_roots) + + err = f"new roots are inconsistent; op {self._opid}" + assert len(former_roots) < 2 or len(new_roots) < 2, err + for new_root_id in new_roots: val_dict = { - attributes.Hierarchy.FormerParent: np.array(former_root_ids), - attributes.OperationLogs.OperationID: self._operation_id, + attributes.Hierarchy.FormerParent: former_roots, + attributes.OperationLogs.OperationID: self._opid, } - rows.append( + self.new_entries.append( self.cg.client.mutate_row( serialize_uint64(new_root_id), val_dict, @@ -544,44 +863,62 @@ def _update_root_id_lineage(self): ) ) - for former_root_id in former_root_ids: + for former_root_id in former_roots: val_dict = { - attributes.Hierarchy.NewParent: np.array(new_root_ids), - attributes.OperationLogs.OperationID: self._operation_id, + attributes.Hierarchy.NewParent: np.array( + new_roots, dtype=basetypes.NODE_ID + ), + attributes.OperationLogs.OperationID: self._opid, } - rows.append( + self.new_entries.append( self.cg.client.mutate_row( serialize_uint64(former_root_id), val_dict, time_stamp=self._time_stamp, ) ) - return rows - def _get_atomic_cross_edges_val_dict(self): - new_ids = np.array(self._new_ids_d[2], dtype=basetypes.NODE_ID) + def _get_cross_edges_val_dicts(self): val_dicts = {} - atomic_cross_edges_d = self.cg.get_atomic_cross_edges(new_ids) - for id_ in new_ids: - val_dict = {} - for layer, edges in atomic_cross_edges_d[id_].items(): - val_dict[attributes.Connectivity.CrossChunkEdge[layer]] = edges - val_dicts[id_] = val_dict + for layer in range(2, self.cg.meta.layer_count): + new_ids = np.array(self._new_ids_d[layer], dtype=basetypes.NODE_ID) + cross_edges_d = self.cg.get_cross_chunk_edges( + new_ids, time_stamp=self._last_ts + ) + for id_ in new_ids: + val_dict = {} + for layer, edges in cross_edges_d[id_].items(): + val_dict[attributes.Connectivity.CrossChunkEdge[layer]] = edges + val_dicts[id_] = val_dict return val_dicts def create_new_entries(self) -> List: - rows = [] - val_dicts = self._get_atomic_cross_edges_val_dict() - for layer in range(2, self.cg.meta.layer_count + 1): + max_layer = self.cg.meta.layer_count + val_dicts = self._get_cross_edges_val_dicts() + for layer in range(2, max_layer + 1): new_ids = self._new_ids_d[layer] for id_ in new_ids: + if self.do_sanity_check: + root_layer = self.cg.get_chunk_layer(self.cg.get_root(id_)) + assert root_layer == max_layer, (id_, self.cg.get_root(id_)) + + if layer < max_layer: + try: + _parent = self.cg.get_parent(id_) + _children = self.cg.get_children(_parent) + assert id_ in _children, (layer, id_, _parent, _children) + except TypeError as e: + logger.error(id_, _parent, self.cg.get_root(id_)) + raise TypeError from e + val_dict = val_dicts.get(id_, {}) children = self.cg.get_children(id_) + err = f"parent layer less than children; op {self._opid}" assert np.max( self.cg.get_chunk_layers(children) - ) < self.cg.get_chunk_layer(id_), "Parent layer less than children." + ) < self.cg.get_chunk_layer(id_), err val_dict[attributes.Hierarchy.Child] = children - rows.append( + self.new_entries.append( self.cg.client.mutate_row( serialize_uint64(id_), val_dict, @@ -589,11 +926,11 @@ def create_new_entries(self) -> List: ) ) for child_id in children: - rows.append( + self.new_entries.append( self.cg.client.mutate_row( serialize_uint64(child_id), {attributes.Hierarchy.Parent: id_}, time_stamp=self._time_stamp, ) ) - return rows + self._update_root_id_lineage() + self._update_root_id_lineage() diff --git a/pychunkedgraph/graph/lineage.py b/pychunkedgraph/graph/lineage.py index 6876ec563..70d112f97 100644 --- a/pychunkedgraph/graph/lineage.py +++ b/pychunkedgraph/graph/lineage.py @@ -4,7 +4,7 @@ from typing import Union from typing import Optional from typing import Iterable -from datetime import datetime +from datetime import datetime, timezone from collections import defaultdict import numpy as np @@ -174,7 +174,7 @@ def lineage_graph( future_ids = np.array(node_ids, dtype=NODE_ID) timestamp_past = float(0) if timestamp_past is None else timestamp_past.timestamp() timestamp_future = ( - datetime.utcnow().timestamp() + datetime.now(timezone.utc).timestamp() if timestamp_future is None else timestamp_future.timestamp() ) diff --git a/pychunkedgraph/graph/locks.py b/pychunkedgraph/graph/locks.py index b3a3a0eb7..f7406922f 100644 --- a/pychunkedgraph/graph/locks.py +++ b/pychunkedgraph/graph/locks.py @@ -1,13 +1,17 @@ +from concurrent.futures import ThreadPoolExecutor, as_completed +import logging from typing import Union from typing import Sequence from collections import defaultdict +import networkx as nx import numpy as np from . import exceptions from .types import empty_1d -from .lineage import get_future_root_ids +from .lineage import lineage_graph +logger = logging.getLogger(__name__) class RootLock: """Attempts to lock the requested root IDs using a unique operation ID. @@ -22,6 +26,7 @@ class RootLock: "lock_acquired", "operation_id", "privileged_mode", + "future_root_ids_d", ] # FIXME: `locked_root_ids` is only required and exposed because `cg.client.lock_roots` # currently might lock different (more recent) root IDs than requested. @@ -44,25 +49,30 @@ def __init__( # caused by failed writes. Must be used with `operation_id`, # meaning only existing failed operations can be run this way. self.privileged_mode = privileged_mode + self.future_root_ids_d = defaultdict(lambda: empty_1d) def __enter__(self): - if self.privileged_mode: - assert self.operation_id is not None, "Please provide operation ID." - from warnings import warn - - warn("Warning: Privileged mode without acquiring lock.") - return self if not self.operation_id: self.operation_id = self.cg.id_client.create_operation_id() - future_root_ids_d = defaultdict(lambda: empty_1d) + if self.privileged_mode: + return self + + nodes_ts = self.cg.get_node_timestamps(self.root_ids, return_numpy=0) + min_ts = min(nodes_ts) + lgraph = lineage_graph(self.cg, self.root_ids, timestamp_past=min_ts) + self.future_root_ids_d = defaultdict(lambda: empty_1d) for id_ in self.root_ids: - future_root_ids_d[id_] = get_future_root_ids(self.cg, id_) + node_descendants = nx.descendants(lgraph, id_) + node_descendants = np.unique( + np.array(list(node_descendants), dtype=np.uint64) + ) + self.future_root_ids_d[id_] = node_descendants self.lock_acquired, self.locked_root_ids = self.cg.client.lock_roots( root_ids=self.root_ids, operation_id=self.operation_id, - future_root_ids_d=future_root_ids_d, + future_root_ids_d=self.future_root_ids_d, max_tries=7, ) if not self.lock_acquired: @@ -71,8 +81,19 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): if self.lock_acquired: - for locked_root_id in self.locked_root_ids: - self.cg.client.unlock_root(locked_root_id, self.operation_id) + max_workers = min(8, max(1, len(self.locked_root_ids))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + unlock_futures = [ + executor.submit( + self.cg.client.unlock_root, root_id, self.operation_id + ) + for root_id in self.locked_root_ids + ] + for future in as_completed(unlock_futures): + try: + future.result() + except Exception as e: + logger.warning(f"Failed to unlock root: {e}") class IndefiniteRootLock: @@ -87,7 +108,14 @@ class IndefiniteRootLock: or when it has already been locked indefinitely. """ - __slots__ = ["cg", "root_ids", "acquired", "operation_id", "privileged_mode"] + __slots__ = [ + "cg", + "root_ids", + "acquired", + "operation_id", + "privileged_mode", + "future_root_ids_d", + ] def __init__( self, @@ -95,6 +123,7 @@ def __init__( operation_id: np.uint64, root_ids: Union[np.uint64, Sequence[np.uint64]], privileged_mode: bool = False, + future_root_ids_d=None, ) -> None: self.cg = cg self.operation_id = operation_id @@ -104,31 +133,49 @@ def __init__( # This is intended to be used in extremely rare cases to fix errors # caused by failed writes. self.privileged_mode = privileged_mode + self.future_root_ids_d = future_root_ids_d def __enter__(self): if self.privileged_mode: - from warnings import warn - - warn("Warning: Privileged mode without acquiring indefinite lock.") return self if not self.cg.client.renew_locks(self.root_ids, self.operation_id): raise exceptions.LockingError("Could not renew locks before writing.") - future_root_ids_d = defaultdict(lambda: empty_1d) - for id_ in self.root_ids: - future_root_ids_d[id_] = get_future_root_ids(self.cg, id_) + if self.future_root_ids_d is None: + nodes_ts = self.cg.get_node_timestamps(self.root_ids, return_numpy=0) + min_ts = min(nodes_ts) + lgraph = lineage_graph(self.cg, self.root_ids, timestamp_past=min_ts) + self.future_root_ids_d = defaultdict(lambda: empty_1d) + for id_ in self.root_ids: + node_descendants = nx.descendants(lgraph, id_) + node_descendants = np.unique( + np.array(list(node_descendants), dtype=np.uint64) + ) + self.future_root_ids_d[id_] = node_descendants + self.acquired, self.root_ids, failed = self.cg.client.lock_roots_indefinitely( root_ids=self.root_ids, operation_id=self.operation_id, - future_root_ids_d=future_root_ids_d, + future_root_ids_d=self.future_root_ids_d, ) if not self.acquired: - raise exceptions.LockingError(f"{failed} has been locked indefinitely.") + raise exceptions.LockingError(f"{failed} have been locked indefinitely.") return self def __exit__(self, exception_type, exception_value, traceback): if self.acquired: - for locked_root_id in self.root_ids: - self.cg.client.unlock_indefinitely_locked_root( - locked_root_id, self.operation_id - ) + max_workers = min(8, max(1, len(self.root_ids))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + unlock_futures = [ + executor.submit( + self.cg.client.unlock_indefinitely_locked_root, + root_id, + self.operation_id, + ) + for root_id in self.root_ids + ] + for future in as_completed(unlock_futures): + try: + future.result() + except Exception as e: + logger.warning(f"Failed to unlock root: {e}") diff --git a/pychunkedgraph/graph/misc.py b/pychunkedgraph/graph/misc.py index b33e8a6fd..faaa7fb29 100644 --- a/pychunkedgraph/graph/misc.py +++ b/pychunkedgraph/graph/misc.py @@ -8,7 +8,6 @@ import fastremap import numpy as np -from multiwrapper import multiprocessing_utils as mu from . import ChunkedGraph from . import attributes @@ -51,22 +50,6 @@ def _read_delta_root_rows( return new_root_ids, expired_root_ids -def _read_root_rows_thread(args) -> list: - start_seg_id, end_seg_id, serialized_cg_info, time_stamp = args - cg = ChunkedGraph(**serialized_cg_info) - start_id = cg.get_node_id(segment_id=start_seg_id, chunk_id=cg.root_chunk_id) - end_id = cg.get_node_id(segment_id=end_seg_id, chunk_id=cg.root_chunk_id) - rows = cg.client.read_nodes( - start_id=start_id, - end_id=end_id, - end_id_inclusive=False, - end_time=time_stamp, - end_time_inclusive=True, - ) - root_ids = [k for (k, v) in rows.items() if attributes.Hierarchy.NewParent not in v] - return root_ids - - def get_proofread_root_ids( cg: ChunkedGraph, start_time: Optional[datetime.datetime] = None, @@ -94,43 +77,12 @@ def get_proofread_root_ids( def get_latest_roots( - cg, time_stamp: Optional[datetime.datetime] = None, n_threads: int = 1 + cg: ChunkedGraph, time_stamp: Optional[datetime.datetime] = None, n_threads: int = 1 ) -> Sequence[np.uint64]: - # Create filters: time and id range - max_seg_id = cg.get_max_seg_id(cg.root_chunk_id) + 1 - n_blocks = 1 if n_threads == 1 else int(np.min([n_threads * 3 + 1, max_seg_id])) - seg_id_blocks = np.linspace(1, max_seg_id, n_blocks + 1, dtype=np.uint64) - cg_serialized_info = cg.get_serialized_info() - if n_threads > 1: - del cg_serialized_info["credentials"] - - multi_args = [] - for i_id_block in range(0, len(seg_id_blocks) - 1): - multi_args.append( - [ - seg_id_blocks[i_id_block], - seg_id_blocks[i_id_block + 1], - cg_serialized_info, - time_stamp, - ] - ) - - if n_threads == 1: - results = mu.multiprocess_func( - _read_root_rows_thread, - multi_args, - n_threads=n_threads, - verbose=False, - debug=n_threads == 1, - ) - else: - results = mu.multisubprocess_func( - _read_root_rows_thread, multi_args, n_threads=n_threads - ) - root_ids = [] - for result in results: - root_ids.extend(result) - return np.array(root_ids, dtype=np.uint64) + root_chunk = cg.get_chunk_id(layer=cg.meta.layer_count, x=0, y=0, z=0) + rr = cg.range_read_chunk(root_chunk, time_stamp=time_stamp) + roots = [k for k, v in rr.items() if attributes.Hierarchy.NewParent not in v] + return np.array(roots, dtype=np.uint64) def get_delta_roots( @@ -190,7 +142,7 @@ def get_contact_sites( ) # Build area lookup dictionary - cs_svs = edges[~np.in1d(edges, sv_ids).reshape(-1, 2)] + cs_svs = edges[~np.isin(edges, sv_ids)] area_dict = collections.defaultdict(int) for area, sv_id in zip(areas, cs_svs): @@ -202,7 +154,6 @@ def get_contact_sites( # Load edges of these cs_svs edges_cs_svs_rows = cg.client.read_nodes( node_ids=u_cs_svs, - # columns=[attributes.Connectivity.Partner, attributes.Connectivity.Connected], ) pre_cs_edges = [] for ri in edges_cs_svs_rows.items(): @@ -214,7 +165,7 @@ def get_contact_sites( cs_dict = collections.defaultdict(list) for cc in ccs: cc_sv_ids = unique_ids[cc] - cc_sv_ids = cc_sv_ids[np.in1d(cc_sv_ids, u_cs_svs)] + cc_sv_ids = cc_sv_ids[np.isin(cc_sv_ids, u_cs_svs)] cs_areas = area_dict_vec(cc_sv_ids) partner_root_id = ( int(cg.get_root(cc_sv_ids[0], time_stamp=time_stamp)) diff --git a/pychunkedgraph/graph/operation.py b/pychunkedgraph/graph/operation.py index d0d0e172a..8ff29b476 100644 --- a/pychunkedgraph/graph/operation.py +++ b/pychunkedgraph/graph/operation.py @@ -1,5 +1,6 @@ -# pylint: disable=invalid-name, missing-docstring, too-many-lines, protected-access +# pylint: disable=invalid-name, missing-docstring, too-many-lines, protected-access, broad-exception-raised +import logging from abc import ABC, abstractmethod from collections import namedtuple from datetime import datetime @@ -16,6 +17,8 @@ import numpy as np from google.cloud import bigtable +logger = logging.getLogger(__name__) + from . import locks from . import edits from . import types @@ -28,7 +31,7 @@ from .cutting import run_multicut from .exceptions import PreconditionError from .exceptions import PostconditionError -from .utils.generic import get_bounding_box as get_bbox +from .utils.generic import get_bounding_box as get_bbox, get_valid_timestamp from ..logging.log_db import TimeIt @@ -44,6 +47,7 @@ class GraphEditOperation(ABC): "sink_coords", "parent_ts", "privileged_mode", + "do_sanity_check", ] Result = namedtuple("Result", ["operation_id", "new_root_ids", "new_lvl2_ids"]) @@ -428,6 +432,8 @@ def execute( lock.locked_root_ids, np.array([lock.operation_id] * len(lock.locked_root_ids)), ) + if timestamp is None: + timestamp = get_valid_timestamp(timestamp) log_record_before_edit = self._create_log_record( operation_id=lock.operation_id, @@ -457,6 +463,9 @@ def execute( except PostconditionError as err: self.cg.cache = None raise PostconditionError(err) from err + except (AssertionError, RuntimeError) as err: + self.cg.cache = None + raise RuntimeError(err) from err except Exception as err: # unknown exception, update log record with error self.cg.cache = None @@ -469,7 +478,7 @@ def execute( exception=repr(err), ) self.cg.client.write([log_record_error]) - raise Exception(err) + raise Exception(err) from err with TimeIt(f"{op_type}.write", self.cg.graph_id, lock.operation_id): result = self._write( @@ -500,6 +509,7 @@ def _write(self, lock, timestamp, new_root_ids, new_lvl2_ids, affected_records): lock.operation_id, lock.locked_root_ids, privileged_mode=lock.privileged_mode, + future_root_ids_d=lock.future_root_ids_d, ): # indefinite lock for writing, if a node instance or pod dies during this # the roots must stay locked indefinitely to prevent further corruption. @@ -552,6 +562,7 @@ class MergeOperation(GraphEditOperation): "affinities", "bbox_offset", "allow_same_segment_merge", + "do_sanity_check", ] def __init__( @@ -565,6 +576,7 @@ def __init__( bbox_offset: Tuple[int, int, int] = (240, 240, 24), affinities: Optional[Sequence[np.float32]] = None, allow_same_segment_merge: Optional[bool] = False, + do_sanity_check: Optional[bool] = True, ) -> None: super().__init__( cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords @@ -572,6 +584,7 @@ def __init__( self.added_edges = np.atleast_2d(added_edges).astype(basetypes.NODE_ID) self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES) self.allow_same_segment_merge = allow_same_segment_merge + self.do_sanity_check = do_sanity_check self.affinities = None if affinities is not None: @@ -612,13 +625,16 @@ def _apply( edges_only=True, ) - with TimeIt("preprocess", self.cg.graph_id, operation_id): - inactive_edges = edits.merge_preprocess( - self.cg, - subgraph_edges=edges, - supervoxels=self.added_edges.ravel(), - parent_ts=self.parent_ts, - ) + if self.allow_same_segment_merge: + inactive_edges = types.empty_2d + else: + with TimeIt("preprocess", self.cg.graph_id, operation_id): + inactive_edges = edits.merge_preprocess( + self.cg, + subgraph_edges=edges, + supervoxels=self.added_edges.ravel(), + parent_ts=self.parent_ts, + ) atomic_edges, fake_edge_rows = edits.check_fake_edges( self.cg, @@ -634,6 +650,8 @@ def _apply( operation_id=operation_id, time_stamp=timestamp, parent_ts=self.parent_ts, + allow_same_segment_merge=self.allow_same_segment_merge, + do_sanity_check=self.do_sanity_check, ) return new_roots, new_l2_ids, fake_edge_rows + new_entries @@ -692,7 +710,7 @@ class SplitOperation(GraphEditOperation): :type sink_coords: Optional[Sequence[Sequence[int]]], optional """ - __slots__ = ["removed_edges", "bbox_offset"] + __slots__ = ["removed_edges", "bbox_offset", "do_sanity_check"] def __init__( self, @@ -703,12 +721,14 @@ def __init__( source_coords: Optional[Sequence[Sequence[int]]] = None, sink_coords: Optional[Sequence[Sequence[int]]] = None, bbox_offset: Tuple[int] = (240, 240, 24), + do_sanity_check: Optional[bool] = True, ) -> None: super().__init__( cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords ) self.removed_edges = np.atleast_2d(removed_edges).astype(basetypes.NODE_ID) self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES) + self.do_sanity_check = do_sanity_check if np.any(np.equal(self.removed_edges[:, 0], self.removed_edges[:, 1])): raise PreconditionError("Requested split contains at least 1 self-loop.") @@ -744,20 +764,14 @@ def _apply( ): raise PreconditionError("Supervoxels must belong to the same object.") - with TimeIt("subgraph", self.cg.graph_id, operation_id): - l2id_agglomeration_d, _ = self.cg.get_l2_agglomerations( - self.cg.get_parents( - self.removed_edges.ravel(), time_stamp=self.parent_ts - ), - ) with TimeIt("remove_edges", self.cg.graph_id, operation_id): return edits.remove_edges( self.cg, operation_id=operation_id, atomic_edges=self.removed_edges, - l2id_agglomeration_d=l2id_agglomeration_d, time_stamp=timestamp, parent_ts=self.parent_ts, + do_sanity_check=self.do_sanity_check, ) def _create_log_record( @@ -826,6 +840,7 @@ class MulticutOperation(GraphEditOperation): "bbox_offset", "path_augment", "disallow_isolating_cut", + "do_sanity_check", ] def __init__( @@ -841,6 +856,7 @@ def __init__( removed_edges: Sequence[Sequence[np.uint64]] = types.empty_2d, path_augment: bool = True, disallow_isolating_cut: bool = True, + do_sanity_check: Optional[bool] = True, ) -> None: super().__init__( cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords @@ -851,7 +867,8 @@ def __init__( self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES) self.path_augment = path_augment self.disallow_isolating_cut = disallow_isolating_cut - if np.any(np.in1d(self.sink_ids, self.source_ids)): + self.do_sanity_check = do_sanity_check + if np.any(np.isin(self.sink_ids, self.source_ids)): raise PreconditionError( "Supervoxels exist in both sink and source, " "try placing the points further apart." @@ -892,16 +909,16 @@ def _apply( self.cg.meta.split_bounding_offset, ) with TimeIt("get_subgraph", self.cg.graph_id, operation_id): - l2id_agglomeration_d, edges = self.cg.get_subgraph( + l2id_agglomeration_d, edges_tuple = self.cg.get_subgraph( root_ids.pop(), bbox=bbox, bbox_is_coordinate=True ) - edges = reduce(lambda x, y: x + y, edges, Edges([], [])) + edges = reduce(lambda x, y: x + y, edges_tuple, Edges([], [])) supervoxels = np.concatenate( [agg.supervoxels for agg in l2id_agglomeration_d.values()] ) - mask0 = np.in1d(edges.node_ids1, supervoxels) - mask1 = np.in1d(edges.node_ids2, supervoxels) + mask0 = np.isin(edges.node_ids1, supervoxels) + mask1 = np.isin(edges.node_ids2, supervoxels) edges = edges[mask0 & mask1] if len(edges) == 0: raise PreconditionError("No local edges found.") @@ -922,9 +939,9 @@ def _apply( self.cg, operation_id=operation_id, atomic_edges=self.removed_edges, - l2id_agglomeration_d=l2id_agglomeration_d, time_stamp=timestamp, parent_ts=self.parent_ts, + do_sanity_check=self.do_sanity_check, ) def _create_log_record( diff --git a/pychunkedgraph/graph/segmenthistory.py b/pychunkedgraph/graph/segmenthistory.py index 30f42d15b..bc4422490 100644 --- a/pychunkedgraph/graph/segmenthistory.py +++ b/pychunkedgraph/graph/segmenthistory.py @@ -1,5 +1,5 @@ import collections -from datetime import datetime +from datetime import datetime, timezone from typing import Iterable import numpy as np @@ -31,7 +31,7 @@ def __init__( if timestamp_past is not None: self.timestamp_past = timestamp_past - self.timestamp_future = datetime.utcnow() + self.timestamp_future = datetime.now(timezone.utc) if timestamp_future is None: self.timestamp_future = timestamp_future @@ -328,7 +328,7 @@ def past_future_id_mapping(self, root_id=None): past_id_mapping = {} future_id_mapping = {} for root_id in root_ids: - ancestors = np.array(list(nx_ancestors(self.lineage_graph, root_id))) + ancestors = np.array(list(nx_ancestors(self.lineage_graph, root_id)), dtype=np.uint64) if len(ancestors) == 0: past_id_mapping[int(root_id)] = [root_id] else: diff --git a/pychunkedgraph/graph/subgraph.py b/pychunkedgraph/graph/subgraph.py index ab2593175..1538b3cc2 100644 --- a/pychunkedgraph/graph/subgraph.py +++ b/pychunkedgraph/graph/subgraph.py @@ -1,3 +1,5 @@ +# pylint: disable=invalid-name, missing-docstring, import-outside-toplevel + from typing import List from typing import Dict from typing import Tuple @@ -30,9 +32,7 @@ def __init__(self, meta, node_ids, return_layers, serializable): # "Frontier" of nodes that cg.get_children will be called on self.cur_nodes = np.array(list(node_ids), dtype=np.uint64) # Mapping of current frontier to self.node_ids - self.cur_nodes_to_original_nodes = dict( - zip(self.cur_nodes, self.cur_nodes) - ) + self.cur_nodes_to_original_nodes = dict(zip(self.cur_nodes, self.cur_nodes)) self.stop_layer = max(1, min(return_layers)) self.create_initial_node_to_subgraph() @@ -107,13 +107,11 @@ def flatten_subgraph(self): for node_id in self.node_ids: for return_layer in self.return_layers: node_key = self.get_dict_key(node_id) - children_at_layer = self.node_to_subgraph[node_key][ - return_layer - ] + children_at_layer = self.node_to_subgraph[node_key][return_layer] if len(children_at_layer) > 0: - self.node_to_subgraph[node_key][ - return_layer - ] = np.concatenate(children_at_layer) + self.node_to_subgraph[node_key][return_layer] = np.concatenate( + children_at_layer + ) else: self.node_to_subgraph[node_key][return_layer] = empty_1d @@ -123,10 +121,12 @@ def get_subgraph_nodes( node_id_or_ids: Union[np.uint64, Iterable], bbox: Optional[Sequence[Sequence[int]]] = None, bbox_is_coordinate: bool = False, - return_layers: List = [2], + return_layers: List = None, serializable: bool = False, - return_flattened: bool = False + return_flattened: bool = False, ) -> Tuple[Dict, Dict, Edges]: + if return_layers is None: + return_layers = [2] single = False node_ids = node_id_or_ids bbox = normalize_bounding_box(cg.meta, bbox, bbox_is_coordinate) @@ -139,7 +139,7 @@ def get_subgraph_nodes( bounding_box=bbox, return_layers=return_layers, serializable=serializable, - return_flattened=return_flattened + return_flattened=return_flattened, ) if single: if serializable: @@ -155,7 +155,7 @@ def get_subgraph_edges_and_leaves( bbox_is_coordinate: bool = False, edges_only: bool = False, leaves_only: bool = False, -) -> Tuple[Dict, Dict, Edges]: +) -> Tuple[Dict, Tuple[Edges]]: """Get the edges and/or leaves of the specified node_ids within the specified bounding box.""" from .types import empty_1d @@ -183,7 +183,7 @@ def _get_subgraph_multiple_nodes( bounding_box: Optional[Sequence[Sequence[int]]], return_layers: Sequence[int], serializable: bool = False, - return_flattened: bool = False + return_flattened: bool = False, ): from collections import ChainMap from multiwrapper.multiprocessing_utils import n_cpus @@ -223,9 +223,7 @@ def _get_subgraph_multiple_nodes_threaded( subgraph = SubgraphProgress(cg.meta, node_ids, return_layers, serializable) while not subgraph.done_processing(): - this_n_threads = min( - [int(len(subgraph.cur_nodes) // 50000) + 1, n_cpus] - ) + this_n_threads = min([int(len(subgraph.cur_nodes) // 50000) + 1, n_cpus]) cur_nodes_child_maps = multithread_func( _get_subgraph_multiple_nodes_threaded, np.array_split(subgraph.cur_nodes, this_n_threads), @@ -239,8 +237,6 @@ def _get_subgraph_multiple_nodes_threaded( for node_id in node_ids: subgraph.node_to_subgraph[ _get_dict_key(node_id) - ] = subgraph.node_to_subgraph[_get_dict_key(node_id)][ - return_layers[0] - ] + ] = subgraph.node_to_subgraph[_get_dict_key(node_id)][return_layers[0]] - return subgraph.node_to_subgraph \ No newline at end of file + return subgraph.node_to_subgraph diff --git a/pychunkedgraph/graph/types.py b/pychunkedgraph/graph/types.py index 9a551f35c..1f35e5f6b 100644 --- a/pychunkedgraph/graph/types.py +++ b/pychunkedgraph/graph/types.py @@ -1,5 +1,4 @@ -from typing import Dict -from typing import Iterable +# pylint: disable=invalid-name, missing-docstring from collections import namedtuple import numpy as np diff --git a/pychunkedgraph/graph/utils/basetypes.py b/pychunkedgraph/graph/utils/basetypes.py index e55324e6a..c6b0b1974 100644 --- a/pychunkedgraph/graph/utils/basetypes.py +++ b/pychunkedgraph/graph/utils/basetypes.py @@ -1,16 +1,16 @@ import numpy as np -CHUNK_ID = SEGMENT_ID = NODE_ID = OPERATION_ID = np.dtype('uint64').newbyteorder('L') -EDGE_AFFINITY = np.dtype('float32').newbyteorder('L') -EDGE_AREA = np.dtype('uint64').newbyteorder('L') +CHUNK_ID = SEGMENT_ID = NODE_ID = OPERATION_ID = np.dtype("uint64").newbyteorder("L") +EDGE_AFFINITY = np.dtype("float32").newbyteorder("L") +EDGE_AREA = np.dtype("uint64").newbyteorder("L") -COUNTER = np.dtype('int64').newbyteorder('B') +COUNTER = np.dtype("int64").newbyteorder("B") -COORDINATES = np.dtype('int64').newbyteorder('L') -CHUNKSIZE = np.dtype('uint64').newbyteorder('L') -FANOUT = np.dtype('uint64').newbyteorder('L') -LAYERCOUNT = np.dtype('uint64').newbyteorder('L') -SPATIALBITS = np.dtype('uint64').newbyteorder('L') -ROOTCOUNTERBITS = np.dtype('uint64').newbyteorder('L') -SKIPCONNECTIONS = np.dtype('uint64').newbyteorder('L') \ No newline at end of file +COORDINATES = np.dtype("int64").newbyteorder("L") +CHUNKSIZE = np.dtype("uint64").newbyteorder("L") +FANOUT = np.dtype("uint64").newbyteorder("L") +LAYERCOUNT = np.dtype("uint64").newbyteorder("L") +SPATIALBITS = np.dtype("uint64").newbyteorder("L") +ROOTCOUNTERBITS = np.dtype("uint64").newbyteorder("L") +SKIPCONNECTIONS = np.dtype("uint64").newbyteorder("L") diff --git a/pychunkedgraph/graph/utils/flatgraph.py b/pychunkedgraph/graph/utils/flatgraph.py index df469d728..d9504f104 100644 --- a/pychunkedgraph/graph/utils/flatgraph.py +++ b/pychunkedgraph/graph/utils/flatgraph.py @@ -1,8 +1,11 @@ +# pylint: disable=invalid-name, missing-docstring, c-extension-no-member + +from itertools import combinations, chain + import fastremap import numpy as np -from itertools import combinations, chain from graph_tool import Graph, GraphView -from graph_tool import topology, search +from graph_tool import topology def build_gt_graph( @@ -88,7 +91,10 @@ def team_paths_all_to_all(graph, capacity, team_vertex_ids): def neighboring_edges(graph, vertex_id): - """Returns vertex and edge lists of a seed vertex, in the same format as team_paths_all_to_all.""" + """ + Returns vertex and edge lists of a seed vertex, + in the same format as team_paths_all_to_all. + """ add_v = [] add_e = [] v0 = graph.vertex(vertex_id) @@ -106,7 +112,7 @@ def intersect_nodes(paths_v_s, paths_v_y): def harmonic_mean_paths(x): - return np.power(np.product(x), 1 / len(x)) + return np.power(np.prod(x), 1 / len(x)) def compute_filtered_paths( @@ -124,7 +130,8 @@ def compute_filtered_paths( gfilt, capacity, team_vertex_ids ) - # graph-tool will invalidate the vertex and edge properties if I don't rebase them on the main graph + # graph-tool will invalidate the vertex and + # edge properties if I don't rebase them on the main graph # before tearing down the GraphView new_paths_e = [] for pth in paths_e: diff --git a/pychunkedgraph/graph/utils/generic.py b/pychunkedgraph/graph/utils/generic.py index 9a2b6f979..696a03801 100644 --- a/pychunkedgraph/graph/utils/generic.py +++ b/pychunkedgraph/graph/utils/generic.py @@ -3,7 +3,7 @@ TODO categorize properly """ - +import bisect import datetime from typing import Dict from typing import Iterable @@ -98,7 +98,7 @@ def time_min(): def get_valid_timestamp(timestamp): if timestamp is None: - timestamp = datetime.datetime.utcnow() + timestamp = datetime.datetime.now(datetime.timezone.utc) if timestamp.tzinfo is None: timestamp = pytz.UTC.localize(timestamp) # Comply to resolution of BigTables TimeRange @@ -173,9 +173,7 @@ def mask_nodes_by_bounding_box( adapt_layers = layers - 2 adapt_layers[adapt_layers < 0] = 0 fanout = meta.graph_config.FANOUT - bounding_box_layer = ( - bounding_box[None] / (fanout ** adapt_layers)[:, None, None] - ) + bounding_box_layer = bounding_box[None] / (fanout**adapt_layers)[:, None, None] bound_check = np.array( [ np.all(chunk_coordinates < bounding_box_layer[:, 1], axis=1), @@ -183,4 +181,25 @@ def mask_nodes_by_bounding_box( ] ).T - return np.all(bound_check, axis=1) \ No newline at end of file + return np.all(bound_check, axis=1) + + +def get_parents_at_timestamp(nodes, parents_ts_map, time_stamp, unique: bool = False): + """ + Search for the first parent with ts <= `time_stamp`. + `parents_ts_map[node]` is a map of ts:parent with sorted timestamps (desc). + """ + skipped_nodes = [] + parents = set() if unique else [] + for node in nodes: + try: + ts_parent_map = parents_ts_map[node] + ts_list = list(ts_parent_map.keys()) + asc_ts_list = ts_list[::-1] + idx = bisect.bisect_right(asc_ts_list, time_stamp) + ts = asc_ts_list[idx - 1] + parent = ts_parent_map[ts] + parents.add(parent) if unique else parents.append(parent) + except KeyError: + skipped_nodes.append(node) + return list(parents), skipped_nodes diff --git a/pychunkedgraph/graph/utils/id_helpers.py b/pychunkedgraph/graph/utils/id_helpers.py index aa486ac84..2a245f79c 100644 --- a/pychunkedgraph/graph/utils/id_helpers.py +++ b/pychunkedgraph/graph/utils/id_helpers.py @@ -89,7 +89,7 @@ def get_atomic_id_from_coord( # sort by frequency and discard those ids that have been checked # previously sorted_atomic_ids = atomic_ids[np.argsort(atomic_id_count)] - sorted_atomic_ids = sorted_atomic_ids[~np.in1d(sorted_atomic_ids, checked)] + sorted_atomic_ids = sorted_atomic_ids[~np.isin(sorted_atomic_ids, checked)] # For each candidate id check whether its root id corresponds to the # given root id diff --git a/pychunkedgraph/graph/utils/serializers.py b/pychunkedgraph/graph/utils/serializers.py index 09c0f63b0..a09094b33 100644 --- a/pychunkedgraph/graph/utils/serializers.py +++ b/pychunkedgraph/graph/utils/serializers.py @@ -41,7 +41,9 @@ def _deserialize(val, dtype, shape=None, order=None): def __init__(self, dtype, shape=None, order=None, compression_level=None): super().__init__( - serializer=lambda x: x.newbyteorder(dtype.byteorder).tobytes(), + serializer=lambda x: np.asarray(x) + .view(x.dtype.newbyteorder(dtype.byteorder)) + .tobytes(), deserializer=lambda x: NumPyArray._deserialize( x, dtype, shape=shape, order=order ), @@ -53,7 +55,9 @@ def __init__(self, dtype, shape=None, order=None, compression_level=None): class NumPyValue(_Serializer): def __init__(self, dtype): super().__init__( - serializer=lambda x: x.newbyteorder(dtype.byteorder).tobytes(), + serializer=lambda x: np.asarray(x) + .view(np.dtype(type(x)).newbyteorder(dtype.byteorder)) + .tobytes(), deserializer=lambda x: np.frombuffer(x, dtype=dtype)[0], basetype=dtype.type, ) @@ -96,7 +100,7 @@ def __init__(self): def pad_node_id(node_id: np.uint64) -> str: - """ Pad node id to 20 digits + """Pad node id to 20 digits :param node_id: int :return: str @@ -105,7 +109,7 @@ def pad_node_id(node_id: np.uint64) -> str: def serialize_uint64(node_id: np.uint64, counter=False, fake_edges=False) -> bytes: - """ Serializes an id to be ingested by a bigtable table row + """Serializes an id to be ingested by a bigtable table row :param node_id: int :return: str @@ -118,7 +122,7 @@ def serialize_uint64(node_id: np.uint64, counter=False, fake_edges=False) -> byt def serialize_uint64s_to_regex(node_ids: Iterable[np.uint64]) -> bytes: - """ Serializes an id to be ingested by a bigtable table row + """Serializes an id to be ingested by a bigtable table row :param node_id: int :return: str @@ -128,7 +132,7 @@ def serialize_uint64s_to_regex(node_ids: Iterable[np.uint64]) -> bytes: def deserialize_uint64(node_id: bytes, fake_edges=False) -> np.uint64: - """ De-serializes a node id from a BigTable row + """De-serializes a node id from a BigTable row :param node_id: bytes :return: np.uint64 @@ -139,7 +143,7 @@ def deserialize_uint64(node_id: bytes, fake_edges=False) -> np.uint64: def serialize_key(key: str) -> bytes: - """ Serializes a key to be ingested by a bigtable table row + """Serializes a key to be ingested by a bigtable table row :param key: str :return: bytes @@ -148,7 +152,7 @@ def serialize_key(key: str) -> bytes: def deserialize_key(key: bytes) -> str: - """ Deserializes a row key + """Deserializes a row key :param key: bytes :return: str diff --git a/pychunkedgraph/ingest/__init__.py b/pychunkedgraph/ingest/__init__.py index b3d832d5e..55c10ca5f 100644 --- a/pychunkedgraph/ingest/__init__.py +++ b/pychunkedgraph/ingest/__init__.py @@ -1,32 +1,16 @@ +import logging from collections import namedtuple - -_cluster_ingest_config_fields = ( - "ATOMIC_Q_NAME", - "ATOMIC_Q_LIMIT", - "ATOMIC_Q_INTERVAL", -) -_cluster_ingest_defaults = ( - "l2", - 100000, - 120, -) -ClusterIngestConfig = namedtuple( - "ClusterIngestConfig", - _cluster_ingest_config_fields, - defaults=_cluster_ingest_defaults, -) - +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) _ingestconfig_fields = ( - "CLUSTER", # cluster config "AGGLOMERATION", "WATERSHED", "USE_RAW_EDGES", "USE_RAW_COMPONENTS", "TEST_RUN", ) -_ingestconfig_defaults = (None, None, None, False, False, False) +_ingestconfig_defaults = (None, None, False, False, False) IngestConfig = namedtuple( "IngestConfig", _ingestconfig_fields, defaults=_ingestconfig_defaults ) diff --git a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index 7668e8f24..c50525ec6 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -1,24 +1,32 @@ +# pylint: disable=invalid-name, missing-function-docstring, unspecified-encoding + """ cli for running ingest """ -from os import environ -from time import sleep +import logging import click import yaml from flask.cli import AppGroup -from rq import Queue +from .cluster import create_atomic_chunk, create_parent_chunk, enqueue_l2_tasks from .manager import IngestionManager -from .utils import bootstrap -from .cluster import randomize_grid_points +from .utils import ( + bootstrap, + chunk_id_str, + print_completion_rate, + print_status, + queue_layer_helper, + job_type_guard, +) +from .simple_tests import run_all +from .create.parent_layer import add_parent_chunk from ..graph.chunkedgraph import ChunkedGraph -from ..utils.redis import get_redis_connection -from ..utils.redis import keys as r_keys -from ..utils.general import chunked +from ..utils.redis import get_redis_connection, keys as r_keys -ingest_cli = AppGroup("ingest") +group_name = "ingest" +ingest_cli = AppGroup(group_name) def init_ingest_cmds(app): @@ -26,6 +34,8 @@ def init_ingest_cmds(app): @ingest_cli.command("flush_redis") +@click.confirmation_option(prompt="Are you sure you want to flush redis?") +@job_type_guard(group_name) def flush_redis(): """FLush redis db.""" redis = get_redis_connection() @@ -35,9 +45,10 @@ def flush_redis(): @ingest_cli.command("graph") @click.argument("graph_id", type=str) @click.argument("dataset", type=click.Path(exists=True)) -@click.option("--raw", is_flag=True) -@click.option("--test", is_flag=True) -@click.option("--retry", is_flag=True) +@click.option("--raw", is_flag=True, help="Read edges from agglomeration output.") +@click.option("--test", is_flag=True, help="Test 8 chunks at the center of dataset.") +@click.option("--retry", is_flag=True, help="Rerun without creating a new table.") +@job_type_guard(group_name) def ingest_graph( graph_id: str, dataset: click.Path, raw: bool, test: bool, retry: bool ): @@ -45,27 +56,28 @@ def ingest_graph( Main ingest command. Takes ingest config from a yaml file and queues atomic tasks. """ - from .cluster import enqueue_atomic_tasks - + redis = get_redis_connection() + redis.set(r_keys.JOB_TYPE, group_name) with open(dataset, "r") as stream: config = yaml.safe_load(stream) - meta, ingest_config, client_info = bootstrap( - graph_id, - config=config, - raw=raw, - test_run=test, - ) + if test: + logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.DEBUG) + + meta, ingest_config, client_info = bootstrap(graph_id, config, raw, test) cg = ChunkedGraph(meta=meta, client_info=client_info) if not retry: cg.create() - enqueue_atomic_tasks(IngestionManager(ingest_config, meta)) + + imanager = IngestionManager(ingest_config, meta) + enqueue_l2_tasks(imanager, create_atomic_chunk) @ingest_cli.command("imanager") @click.argument("graph_id", type=str) @click.argument("dataset", type=click.Path(exists=True)) @click.option("--raw", is_flag=True) +@job_type_guard(group_name) def pickle_imanager(graph_id: str, dataset: click.Path, raw: bool): """ Load ingest config into redis server. @@ -79,96 +91,51 @@ def pickle_imanager(graph_id: str, dataset: click.Path, raw: bool): meta, ingest_config, _ = bootstrap(graph_id, config=config, raw=raw) imanager = IngestionManager(ingest_config, meta) - imanager.redis + imanager.redis.set(r_keys.JOB_TYPE, group_name) @ingest_cli.command("layer") @click.argument("parent_layer", type=int) +@job_type_guard(group_name) def queue_layer(parent_layer): """ Queue all chunk tasks at a given layer. Must be used when all the chunks at `parent_layer - 1` have completed. """ - from itertools import product - import numpy as np - from .cluster import create_parent_chunk - from .utils import chunk_id_str - assert parent_layer > 2, "This command is for layers 3 and above." redis = get_redis_connection() imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - - if parent_layer == imanager.cg_meta.layer_count: - chunk_coords = [(0, 0, 0)] - else: - bounds = imanager.cg_meta.layer_chunk_bounds[parent_layer] - chunk_coords = randomize_grid_points(*bounds) - - def get_chunks_not_done(coords: list) -> list: - """check for set membership in redis in batches""" - coords_strs = ["_".join(map(str, coord)) for coord in coords] - try: - completed = imanager.redis.smismember(f"{parent_layer}c", coords_strs) - except Exception: - return coords - return [coord for coord, c in zip(coords, completed) if not c] - - batch_size = int(environ.get("JOB_BATCH_SIZE", 10000)) - batches = chunked(chunk_coords, batch_size) - q = imanager.get_task_queue(f"l{parent_layer}") - - for batch in batches: - _coords = get_chunks_not_done(batch) - # buffer for optimal use of redis memory - if len(q) > int(environ.get("QUEUE_SIZE", 100000)): - interval = int(environ.get("QUEUE_INTERVAL", 300)) - sleep(interval) - - job_datas = [] - for chunk_coord in _coords: - job_datas.append( - Queue.prepare_data( - create_parent_chunk, - args=(parent_layer, chunk_coord), - result_ttl=0, - job_id=chunk_id_str(parent_layer, chunk_coord), - timeout=f"{int(parent_layer * parent_layer)}m", - ) - ) - q.enqueue_many(job_datas) + queue_layer_helper(parent_layer, imanager, create_parent_chunk) @ingest_cli.command("status") +@job_type_guard(group_name) def ingest_status(): """Print ingest status to console by layer.""" redis = get_redis_connection() - imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - layers = range(2, imanager.cg_meta.layer_count + 1) - for layer, layer_count in zip(layers, imanager.cg_meta.layer_chunk_counts): - completed = redis.scard(f"{layer}c") - print(f"{layer}\t: {completed} / {layer_count}") + try: + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + print_status(imanager, redis) + except TypeError as err: + print(f"\nNo current `{group_name}` job found in redis: {err}") @ingest_cli.command("chunk") @click.argument("queue", type=str) @click.argument("chunk_info", nargs=4, type=int) +@job_type_guard(group_name) def ingest_chunk(queue: str, chunk_info): """Manually queue chunk when a job is stuck for whatever reason.""" - from .cluster import _create_atomic_chunk - from .cluster import create_parent_chunk - from .utils import chunk_id_str - redis = get_redis_connection() imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - layer = chunk_info[0] - coords = chunk_info[1:] - queue = imanager.get_task_queue(queue) + layer, coords = chunk_info[0], chunk_info[1:] + + func = create_parent_chunk + args = (layer, coords) if layer == 2: - func = _create_atomic_chunk + func = create_atomic_chunk args = (coords,) - else: - func = create_parent_chunk - args = (layer, coords) + queue = imanager.get_task_queue(queue) queue.enqueue( func, job_id=chunk_id_str(layer, coords), @@ -182,13 +149,31 @@ def ingest_chunk(queue: str, chunk_info): @click.argument("graph_id", type=str) @click.argument("chunk_info", nargs=4, type=int) @click.option("--n_threads", type=int, default=1) +@job_type_guard(group_name) def ingest_chunk_local(graph_id: str, chunk_info, n_threads: int): """Manually ingest a chunk on a local machine.""" - from .create.abstract_layers import add_layer - from .cluster import _create_atomic_chunk - - if chunk_info[0] == 2: - _create_atomic_chunk(chunk_info[1:]) + layer, coords = chunk_info[0], chunk_info[1:] + if layer == 2: + create_atomic_chunk(coords) else: cg = ChunkedGraph(graph_id=graph_id) - add_layer(cg, chunk_info[0], chunk_info[1:], n_threads=n_threads) + add_parent_chunk(cg, layer, coords, n_threads=n_threads) + cg = ChunkedGraph(graph_id=graph_id) + add_parent_chunk(cg, layer, coords, n_threads=n_threads) + + +@ingest_cli.command("rate") +@click.argument("layer", type=int) +@click.option("--span", default=10, help="Time span to calculate rate.") +@job_type_guard(group_name) +def rate(layer: int, span: int): + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + print_completion_rate(imanager, layer, span=span) + + +@ingest_cli.command("run_tests") +@click.argument("graph_id", type=str) +@job_type_guard(group_name) +def run_tests(graph_id): + run_all(ChunkedGraph(graph_id=graph_id)) diff --git a/pychunkedgraph/ingest/cli_upgrade.py b/pychunkedgraph/ingest/cli_upgrade.py new file mode 100644 index 000000000..3c4e6f7f8 --- /dev/null +++ b/pychunkedgraph/ingest/cli_upgrade.py @@ -0,0 +1,157 @@ +# pylint: disable=invalid-name, missing-function-docstring, unspecified-encoding + +""" +cli for running upgrade +""" + +import logging +from time import sleep + +import click +import tensorstore as ts +from flask.cli import AppGroup +from pychunkedgraph import __version__ +from pychunkedgraph.graph.meta import GraphConfig + +from . import IngestConfig +from .cluster import ( + convert_to_ocdbt, + enqueue_l2_tasks, + upgrade_atomic_chunk, + upgrade_parent_chunk, +) +from .manager import IngestionManager +from .utils import ( + chunk_id_str, + print_completion_rate, + print_status, + queue_layer_helper, + start_ocdbt_server, + job_type_guard, +) +from ..graph.chunkedgraph import ChunkedGraph, ChunkedGraphMeta +from ..utils.redis import get_redis_connection +from ..utils.redis import keys as r_keys + +group_name = "upgrade" +upgrade_cli = AppGroup(group_name) + + +def init_upgrade_cmds(app): + app.cli.add_command(upgrade_cli) + + +@upgrade_cli.command("flush_redis") +@click.confirmation_option(prompt="Are you sure you want to flush redis?") +@job_type_guard(group_name) +def flush_redis(): + """FLush redis db.""" + redis = get_redis_connection() + redis.flushdb() + + +@upgrade_cli.command("graph") +@click.argument("graph_id", type=str) +@click.option("--test", is_flag=True, help="Test 8 chunks at the center of dataset.") +@click.option("--ocdbt", is_flag=True, help="Store edges using ts ocdbt kv store.") +@job_type_guard(group_name) +def upgrade_graph(graph_id: str, test: bool, ocdbt: bool): + """ + Main upgrade command. Queues atomic tasks. + """ + redis = get_redis_connection() + redis.set(r_keys.JOB_TYPE, group_name) + ingest_config = IngestConfig(TEST_RUN=test) + cg = ChunkedGraph(graph_id=graph_id) + cg.client.add_graph_version(__version__, overwrite=True) + + if graph_id != cg.graph_id: + gc = cg.meta.graph_config._asdict() + gc["ID"] = graph_id + new_meta = ChunkedGraphMeta( + GraphConfig(**gc), cg.meta.data_source, cg.meta.custom_data + ) + cg.update_meta(new_meta, overwrite=True) + cg = ChunkedGraph(graph_id=graph_id) + + try: + # create new column family for cross chunk edges + f = cg.client._table.column_family("4") + f.create() + except Exception: + ... + + imanager = IngestionManager(ingest_config, cg.meta) + server = ts.ocdbt.DistributedCoordinatorServer() + if ocdbt: + start_ocdbt_server(imanager, server) + + fn = convert_to_ocdbt if ocdbt else upgrade_atomic_chunk + enqueue_l2_tasks(imanager, fn) + + if ocdbt: + logging.info("All tasks queued. Keep this alive for ocdbt coordinator server.") + while True: + sleep(60) + + +@upgrade_cli.command("layer") +@click.argument("parent_layer", type=int) +@click.option("--splits", default=0, help="Split chunks into multiple tasks.") +@job_type_guard(group_name) +def queue_layer(parent_layer:int, splits:int=0): + """ + Queue all chunk tasks at a given layer. + Must be used when all the chunks at `parent_layer - 1` have completed. + """ + assert parent_layer > 2, "This command is for layers 3 and above." + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + queue_layer_helper(parent_layer, imanager, upgrade_parent_chunk, splits=splits) + + +@upgrade_cli.command("status") +@job_type_guard(group_name) +def upgrade_status(): + """Print upgrade status to console.""" + redis = get_redis_connection() + try: + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + print_status(imanager, redis, upgrade=True) + except TypeError as err: + print(f"\nNo current `{group_name}` job found in redis: {err}") + + +@upgrade_cli.command("chunk") +@click.argument("queue", type=str) +@click.argument("chunk_info", nargs=4, type=int) +@job_type_guard(group_name) +def upgrade_chunk(queue: str, chunk_info): + """Manually queue chunk when a job is stuck for whatever reason.""" + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + layer, coords = chunk_info[0], chunk_info[1:] + + func = upgrade_parent_chunk + args = (layer, coords) + if layer == 2: + func = upgrade_atomic_chunk + args = (coords,) + queue = imanager.get_task_queue(queue) + queue.enqueue( + func, + job_id=chunk_id_str(layer, coords), + job_timeout=f"{int(layer * layer)}m", + result_ttl=0, + args=args, + ) + + +@upgrade_cli.command("rate") +@click.argument("layer", type=int) +@click.option("--span", default=10, help="Time span to calculate rate.") +@job_type_guard(group_name) +def rate(layer: int, span: int): + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + print_completion_rate(imanager, layer, span=span) diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py index cf9417024..219cae07b 100644 --- a/pychunkedgraph/ingest/cluster.py +++ b/pychunkedgraph/ingest/cluster.py @@ -1,104 +1,51 @@ +# pylint: disable=invalid-name, missing-function-docstring, import-outside-toplevel + """ -Ingest / create chunkedgraph with workers. +Ingest / create chunkedgraph with workers on a cluster. """ -from typing import Sequence, Tuple +import logging +from os import environ +from time import sleep +from typing import Callable, Dict, Iterable, Tuple, Sequence import numpy as np +from rq import Queue as RQueue, Retry + -from .utils import chunk_id_str +from .utils import chunk_id_str, get_chunks_not_done, randomize_grid_points from .manager import IngestionManager -from .common import get_atomic_chunk_data -from .ran_agglomeration import get_active_edges -from .create.atomic_layer import add_atomic_edges -from .create.abstract_layers import add_layer -from ..graph.meta import ChunkedGraphMeta +from .ran_agglomeration import ( + get_active_edges, + read_raw_edge_data, + read_raw_agglomeration_data, +) +from .create.atomic_layer import add_atomic_chunk +from .create.parent_layer import add_parent_chunk +from .upgrade.atomic_layer import update_chunk as update_atomic_chunk +from .upgrade.parent_layer import update_chunk as update_parent_chunk +from ..graph.edges import EDGE_TYPES, Edges, put_edges +from ..graph import ChunkedGraph, ChunkedGraphMeta from ..graph.chunks.hierarchy import get_children_chunk_coords -from ..utils.redis import keys as r_keys -from ..utils.redis import get_redis_connection - - -def _post_task_completion(imanager: IngestionManager, layer: int, coords: np.ndarray): - from os import environ - +from ..graph.utils.basetypes import NODE_ID +from ..io.edges import get_chunk_edges +from ..io.components import get_chunk_components +from ..utils.redis import keys as r_keys, get_redis_connection +from ..utils.general import chunked + + +def _post_task_completion( + imanager: IngestionManager, + layer: int, + coords: np.ndarray, + split:int=None +): chunk_str = "_".join(map(str, coords)) + if split is not None: + chunk_str += f"_{split}" # mark chunk as completed - "c" imanager.redis.sadd(f"{layer}c", chunk_str) - - if environ.get("DO_NOT_AUTOQUEUE_PARENT_CHUNKS", None) is not None: - return - - parent_layer = layer + 1 - if parent_layer > imanager.cg_meta.layer_count: - return - - parent_coords = np.array(coords, int) // imanager.cg_meta.graph_config.FANOUT - parent_id_str = chunk_id_str(parent_layer, parent_coords) - imanager.redis.sadd(parent_id_str, chunk_str) - - parent_chunk_str = "_".join(map(str, parent_coords)) - if not imanager.redis.hget(parent_layer, parent_chunk_str): - # cache children chunk count - # checked by tracker worker to enqueue parent chunk - children_count = len( - get_children_chunk_coords(imanager.cg_meta, parent_layer, parent_coords) - ) - imanager.redis.hset(parent_layer, parent_chunk_str, children_count) - - tracker_queue = imanager.get_task_queue(f"t{layer}") - tracker_queue.enqueue( - enqueue_parent_task, - job_id=f"t{layer}_{chunk_str}", - job_timeout=f"30s", - result_ttl=0, - args=( - parent_layer, - parent_coords, - ), - ) - - -def enqueue_parent_task( - parent_layer: int, - parent_coords: Sequence[int], -): - redis = get_redis_connection() - imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - parent_id_str = chunk_id_str(parent_layer, parent_coords) - parent_chunk_str = "_".join(map(str, parent_coords)) - - children_done = redis.scard(parent_id_str) - # if zero then this key was deleted and parent already queued. - if children_done == 0: - print("parent already queued.") - return - - # if the previous layer is complete - # no need to check children progress for each parent chunk - child_layer = parent_layer - 1 - child_layer_done = redis.scard(f"{child_layer}c") - child_layer_count = imanager.cg_meta.layer_chunk_counts[child_layer - 2] - child_layer_finished = child_layer_done == child_layer_count - - if not child_layer_finished: - children_count = int(redis.hget(parent_layer, parent_chunk_str).decode("utf-8")) - if children_done != children_count: - print("children not done.") - return - - queue = imanager.get_task_queue(f"l{parent_layer}") - queue.enqueue( - create_parent_chunk, - job_id=parent_id_str, - job_timeout=f"{int(parent_layer * parent_layer)}m", - result_ttl=0, - args=( - parent_layer, - parent_coords, - ), - ) - redis.hdel(parent_layer, parent_chunk_str) - redis.delete(parent_id_str) + logging.info(f"{chunk_str} marked as complete") def create_parent_chunk( @@ -107,7 +54,7 @@ def create_parent_chunk( ) -> None: redis = get_redis_connection() imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - add_layer( + add_parent_chunk( imanager.cg, parent_layer, parent_coords, @@ -120,76 +67,174 @@ def create_parent_chunk( _post_task_completion(imanager, parent_layer, parent_coords) -def randomize_grid_points(X: int, Y: int, Z: int) -> Tuple[int, int, int]: - indices = np.arange(X * Y * Z) - np.random.shuffle(indices) - for index in indices: - yield np.unravel_index(index, (X, Y, Z)) +def upgrade_parent_chunk( + parent_layer: int, + parent_coords: Sequence[int], + split:int=None, + splits:int=None +) -> None: + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + update_parent_chunk(imanager.cg, parent_coords, layer=parent_layer, split=split, splits=splits) + _post_task_completion(imanager, parent_layer, parent_coords, split=split) + + +def _get_atomic_chunk_data( + imanager: IngestionManager, coord: Sequence[int] +) -> Tuple[Dict, Dict]: + """ + Helper to read either raw data or processed data + If reading from raw data, save it as processed data + """ + chunk_edges = ( + read_raw_edge_data(imanager, coord) + if imanager.config.USE_RAW_EDGES + else get_chunk_edges(imanager.cg_meta.data_source.EDGES, [coord]) + ) + _check_edges_direction(chunk_edges, imanager.cg, coord) -def enqueue_atomic_tasks(imanager: IngestionManager): - from os import environ - from time import sleep - from rq import Queue as RQueue + mapping = ( + read_raw_agglomeration_data(imanager, coord) + if imanager.config.USE_RAW_COMPONENTS + else get_chunk_components(imanager.cg_meta.data_source.COMPONENTS, coord) + ) + return chunk_edges, mapping - chunk_coords = _get_test_chunks(imanager.cg.meta) - chunk_count = len(chunk_coords) - if not imanager.config.TEST_RUN: - atomic_chunk_bounds = imanager.cg_meta.layer_chunk_bounds[2] - chunk_coords = randomize_grid_points(*atomic_chunk_bounds) - chunk_count = imanager.cg_meta.layer_chunk_counts[0] - print(f"total chunk count: {chunk_count}, queuing...") - batch_size = int(environ.get("L2JOB_BATCH_SIZE", 1000)) +def _check_edges_direction( + chunk_edges: dict, cg: ChunkedGraph, coord: Sequence[int] +) -> None: + """ + For between and cross chunk edges: + Checks and flips edges such that nodes1 are always within a chunk and nodes2 outside the chunk. + Where nodes1 = edges[:,0] and nodes2 = edges[:,1]. + """ + x, y, z = coord + chunk_id = cg.get_chunk_id(layer=1, x=x, y=y, z=z) + for edge_type in [EDGE_TYPES.between_chunk, EDGE_TYPES.cross_chunk]: + edges = chunk_edges[edge_type] + chunk_ids = cg.get_chunk_ids_from_node_ids(edges.node_ids1) + mask = chunk_ids == chunk_id + assert np.all(mask), "all IDs must belong to same chunk" + + +def create_atomic_chunk(coords: Sequence[int]): + """Creates single atomic chunk""" + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + coords = np.array(list(coords), dtype=int) - job_datas = [] - for chunk_coord in chunk_coords: - q = imanager.get_task_queue(imanager.config.CLUSTER.ATOMIC_Q_NAME) - # buffer for optimal use of redis memory - if len(q) > imanager.config.CLUSTER.ATOMIC_Q_LIMIT: - print(f"Sleeping {imanager.config.CLUSTER.ATOMIC_Q_INTERVAL}s...") - sleep(imanager.config.CLUSTER.ATOMIC_Q_INTERVAL) - - x, y, z = chunk_coord - chunk_str = f"{x}_{y}_{z}" - if imanager.redis.sismember("2c", chunk_str): - # already done, skip - continue - job_datas.append( - RQueue.prepare_data( - _create_atomic_chunk, - args=(chunk_coord,), - timeout=environ.get("L2JOB_TIMEOUT", "3m"), - result_ttl=0, - job_id=chunk_id_str(2, chunk_coord), - ) - ) - if len(job_datas) % batch_size == 0: - q.enqueue_many(job_datas) - job_datas = [] - q.enqueue_many(job_datas) + chunk_edges_all, mapping = _get_atomic_chunk_data(imanager, coords) + chunk_edges_active, isolated_ids = get_active_edges(chunk_edges_all, mapping) + add_atomic_chunk(imanager.cg, coords, chunk_edges_active, isolated=isolated_ids) + + for k, v in chunk_edges_all.items(): + logging.debug(f"{k}: {len(v)}") + for k, v in chunk_edges_active.items(): + logging.debug(f"active_{k}: {len(v)}") + _post_task_completion(imanager, 2, coords) -def _create_atomic_chunk(coords: Sequence[int]): - """Creates single atomic chunk""" +def upgrade_atomic_chunk(coords: Sequence[int]): + """Upgrades single atomic chunk""" redis = get_redis_connection() imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) coords = np.array(list(coords), dtype=int) - chunk_edges_all, mapping = get_atomic_chunk_data(imanager, coords) - chunk_edges_active, isolated_ids = get_active_edges(chunk_edges_all, mapping) - add_atomic_edges(imanager.cg, coords, chunk_edges_active, isolated=isolated_ids) - if imanager.config.TEST_RUN: - # print for debugging - for k, v in chunk_edges_all.items(): - print(k, len(v)) - for k, v in chunk_edges_active.items(): - print(f"active_{k}", len(v)) + update_atomic_chunk(imanager.cg, coords) + _post_task_completion(imanager, 2, coords) + + +def convert_to_ocdbt(coords: Sequence[int]): + """ + Convert edges stored per chunk to ajacency list in the tensorstore ocdbt kv store. + """ + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + coords = np.array(list(coords), dtype=int) + chunk_edges_all, mapping = _get_atomic_chunk_data(imanager, coords) + + node_ids1 = [] + node_ids2 = [] + affinities = [] + areas = [] + for edges in chunk_edges_all.values(): + node_ids1.extend(edges.node_ids1) + node_ids2.extend(edges.node_ids2) + affinities.extend(edges.affinities) + areas.extend(edges.areas) + + edges = Edges(node_ids1, node_ids2, affinities=affinities, areas=areas) + nodes = np.concatenate( + [edges.node_ids1, edges.node_ids2, np.fromiter(mapping.keys(), dtype=NODE_ID)] + ) + nodes = np.unique(nodes) + + chunk_id = imanager.cg.get_chunk_id(layer=1, x=coords[0], y=coords[1], z=coords[2]) + chunk_ids = imanager.cg.get_chunk_ids_from_node_ids(nodes) + + host = imanager.redis.get("OCDBT_COORDINATOR_HOST").decode() + port = imanager.redis.get("OCDBT_COORDINATOR_PORT").decode() + environ["OCDBT_COORDINATOR_HOST"] = host + environ["OCDBT_COORDINATOR_PORT"] = port + logging.info(f"OCDBT Coordinator address {host}:{port}") + + put_edges( + f"{imanager.cg.meta.data_source.EDGES}/ocdbt", + nodes[chunk_ids == chunk_id], + edges, + ) _post_task_completion(imanager, 2, coords) def _get_test_chunks(meta: ChunkedGraphMeta): - """Chunks at center of the dataset most likely not to be empty""" + """Chunks at the center most likely not to be empty""" parent_coords = np.array(meta.layer_chunk_bounds[3]) // 2 return get_children_chunk_coords(meta, 3, parent_coords) - # f = lambda r1, r2, r3: np.array(np.meshgrid(r1, r2, r3), dtype=int).T.reshape(-1, 3) - # return f((x, x + 1), (y, y + 1), (z, z + 1)) + + +def _queue_tasks(imanager: IngestionManager, chunk_fn: Callable, coords: Iterable): + queue_name = "l2" + q = imanager.get_task_queue(queue_name) + batch_size = int(environ.get("JOB_BATCH_SIZE", 10000)) + batches = chunked(coords, batch_size) + retry = int(environ.get("RETRY_COUNT", 0)) + failure_ttl = int(environ.get("FAILURE_TTL", 300)) + for batch in batches: + _coords = get_chunks_not_done(imanager, 2, batch) + # buffer for optimal use of redis memory + if len(q) > int(environ.get("QUEUE_SIZE", 1000000)): + interval = int(environ.get("QUEUE_INTERVAL", 300)) + logging.info(f"Queue full; sleeping {interval}s...") + sleep(interval) + + job_datas = [] + for chunk_coord in _coords: + job_datas.append( + RQueue.prepare_data( + chunk_fn, + args=(chunk_coord,), + timeout=environ.get("L2JOB_TIMEOUT", "3m"), + result_ttl=0, + job_id=chunk_id_str(2, chunk_coord), + retry=Retry(retry) if retry > 1 else None, + description="", + failure_ttl=failure_ttl + ) + ) + q.enqueue_many(job_datas) + logging.info(f"Queued {len(job_datas)} chunks.") + + +def enqueue_l2_tasks(imanager: IngestionManager, chunk_fn: Callable): + """ + `chunk_fn`: function to process a given layer 2 chunk. + """ + chunk_coords = _get_test_chunks(imanager.cg.meta) + chunk_count = len(chunk_coords) + if not imanager.config.TEST_RUN: + atomic_chunk_bounds = imanager.cg_meta.layer_chunk_bounds[2] + chunk_coords = randomize_grid_points(*atomic_chunk_bounds) + chunk_count = imanager.cg_meta.layer_chunk_counts[0] + logging.info(f"Chunk count: {chunk_count}, queuing...") + _queue_tasks(imanager, chunk_fn, chunk_coords) diff --git a/pychunkedgraph/ingest/common.py b/pychunkedgraph/ingest/common.py deleted file mode 100644 index dccf58602..000000000 --- a/pychunkedgraph/ingest/common.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import Dict -from typing import Tuple -from typing import Sequence - -from .manager import IngestionManager -from .ran_agglomeration import read_raw_edge_data -from .ran_agglomeration import read_raw_agglomeration_data -from ..graph import ChunkedGraph -from ..io.edges import get_chunk_edges -from ..io.components import get_chunk_components - - -def get_atomic_chunk_data( - imanager: IngestionManager, coord: Sequence[int] -) -> Tuple[Dict, Dict]: - """ - Helper to read either raw data or processed data - If reading from raw data, save it as processed data - """ - chunk_edges = ( - read_raw_edge_data(imanager, coord) - if imanager.config.USE_RAW_EDGES - else get_chunk_edges(imanager.cg_meta.data_source.EDGES, [coord]) - ) - - _check_edges_direction(chunk_edges, imanager.cg, coord) - - mapping = ( - read_raw_agglomeration_data(imanager, coord) - if imanager.config.USE_RAW_COMPONENTS - else get_chunk_components(imanager.cg_meta.data_source.COMPONENTS, coord) - ) - return chunk_edges, mapping - - -def _check_edges_direction( - chunk_edges: dict, cg: ChunkedGraph, coord: Sequence[int] -) -> None: - """ - For between and cross chunk edges: - Checks and flips edges such that nodes1 are always within a chunk and nodes2 outside the chunk. - Where nodes1 = edges[:,0] and nodes2 = edges[:,1]. - """ - import numpy as np - from ..graph.edges import Edges - from ..graph.edges import EDGE_TYPES - - x, y, z = coord - chunk_id = cg.get_chunk_id(layer=1, x=x, y=y, z=z) - for edge_type in [EDGE_TYPES.between_chunk, EDGE_TYPES.cross_chunk]: - edges = chunk_edges[edge_type] - e1 = edges.node_ids1 - e2 = edges.node_ids2 - - e2_chunk_ids = cg.get_chunk_ids_from_node_ids(e2) - mask = e2_chunk_ids == chunk_id - e1[mask], e2[mask] = e2[mask], e1[mask] - - e1_chunk_ids = cg.get_chunk_ids_from_node_ids(e1) - mask = e1_chunk_ids == chunk_id - assert np.all(mask), "all IDs must belong to same chunk" diff --git a/pychunkedgraph/ingest/create/abstract_layers.py b/pychunkedgraph/ingest/create/abstract_layers.py deleted file mode 100644 index 529a6846f..000000000 --- a/pychunkedgraph/ingest/create/abstract_layers.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -Functions for creating parents in level 3 and above -""" - -import time -import math -import datetime -import multiprocessing as mp -from collections import defaultdict -from typing import Optional -from typing import Sequence -from typing import List - -import numpy as np -from multiwrapper import multiprocessing_utils as mu - -from ...graph import types -from ...graph import attributes -from ...utils.general import chunked -from ...graph.utils import flatgraph -from ...graph.utils import basetypes -from ...graph.utils import serializers -from ...graph.chunkedgraph import ChunkedGraph -from ...graph.utils.generic import get_valid_timestamp -from ...graph.utils.generic import filter_failed_node_ids -from ...graph.chunks.hierarchy import get_children_chunk_coords -from ...graph.connectivity.cross_edges import get_children_chunk_cross_edges -from ...graph.connectivity.cross_edges import get_chunk_nodes_cross_edge_layer - - -def add_layer( - cg: ChunkedGraph, - layer_id: int, - parent_coords: Sequence[int], - children_coords: Sequence[Sequence[int]] = np.array([]), - *, - time_stamp: Optional[datetime.datetime] = None, - n_threads: int = 4, -) -> None: - if not children_coords.size: - children_coords = get_children_chunk_coords(cg.meta, layer_id, parent_coords) - children_ids = _read_children_chunks(cg, layer_id, children_coords, n_threads > 1) - edge_ids = get_children_chunk_cross_edges( - cg, layer_id, parent_coords, use_threads=n_threads > 1 - ) - - print("children_coords", children_coords.size, layer_id, parent_coords) - print( - "n e", len(children_ids), len(edge_ids), layer_id, parent_coords, - ) - - node_layers = cg.get_chunk_layers(children_ids) - edge_layers = cg.get_chunk_layers(np.unique(edge_ids)) - assert np.all(node_layers < layer_id), "invalid node layers" - assert np.all(edge_layers < layer_id), "invalid edge layers" - # Extract connected components - # isolated_node_mask = ~np.in1d(children_ids, np.unique(edge_ids)) - # add_node_ids = children_ids[isolated_node_mask].squeeze() - add_edge_ids = np.vstack([children_ids, children_ids]).T - - edge_ids = list(edge_ids) - edge_ids.extend(add_edge_ids) - graph, _, _, graph_ids = flatgraph.build_gt_graph(edge_ids, make_directed=True) - ccs = flatgraph.connected_components(graph) - print("ccs", len(ccs)) - _write_connected_components( - cg, - layer_id, - parent_coords, - ccs, - graph_ids, - get_valid_timestamp(time_stamp), - n_threads > 1, - ) - return f"{layer_id}_{'_'.join(map(str, parent_coords))}" - - -def _read_children_chunks( - cg: ChunkedGraph, layer_id, children_coords, use_threads=True -): - if not use_threads: - children_ids = [types.empty_1d] - for child_coord in children_coords: - children_ids.append(_read_chunk([], cg, layer_id - 1, child_coord)) - return np.concatenate(children_ids) - - print("_read_children_chunks") - with mp.Manager() as manager: - children_ids_shared = manager.list() - multi_args = [] - for child_coord in children_coords: - multi_args.append( - ( - children_ids_shared, - cg.get_serialized_info(), - layer_id - 1, - child_coord, - ) - ) - mu.multiprocess_func( - _read_chunk_helper, - multi_args, - n_threads=min(len(multi_args), mp.cpu_count()), - ) - print("_read_children_chunks done") - return np.concatenate(children_ids_shared) - - -def _read_chunk_helper(args): - children_ids_shared, cg_info, layer_id, chunk_coord = args - cg = ChunkedGraph(**cg_info) - _read_chunk(children_ids_shared, cg, layer_id, chunk_coord) - - -def _read_chunk(children_ids_shared, cg: ChunkedGraph, layer_id: int, chunk_coord): - print(f"_read_chunk {layer_id}, {chunk_coord}") - x, y, z = chunk_coord - range_read = cg.range_read_chunk( - cg.get_chunk_id(layer=layer_id, x=x, y=y, z=z), - properties=attributes.Hierarchy.Child, - ) - row_ids = [] - max_children_ids = [] - for row_id, row_data in range_read.items(): - row_ids.append(row_id) - max_children_ids.append(np.max(row_data[0].value)) - row_ids = np.array(row_ids, dtype=basetypes.NODE_ID) - segment_ids = np.array([cg.get_segment_id(r_id) for r_id in row_ids]) - - row_ids = filter_failed_node_ids(row_ids, segment_ids, max_children_ids) - children_ids_shared.append(row_ids) - print(f"_read_chunk {layer_id}, {chunk_coord} done {len(row_ids)}") - return row_ids - - -def _write_connected_components( - cg: ChunkedGraph, - layer_id: int, - parent_coords, - ccs, - graph_ids, - time_stamp, - use_threads=True, -) -> None: - if not ccs: - return - - node_layer_d_shared = {} - if layer_id < cg.meta.layer_count: - print("getting node_layer_d_shared") - node_layer_d_shared = get_chunk_nodes_cross_edge_layer( - cg, layer_id, parent_coords, use_threads=use_threads - ) - - print("node_layer_d_shared", len(node_layer_d_shared)) - - ccs_with_node_ids = [] - for cc in ccs: - ccs_with_node_ids.append(graph_ids[cc]) - - if not use_threads: - _write( - cg, - layer_id, - parent_coords, - ccs_with_node_ids, - node_layer_d_shared, - time_stamp, - use_threads=use_threads, - ) - return - - task_size = int(math.ceil(len(ccs_with_node_ids) / mp.cpu_count() / 10)) - chunked_ccs = chunked(ccs_with_node_ids, task_size) - cg_info = cg.get_serialized_info() - multi_args = [] - for ccs in chunked_ccs: - multi_args.append( - (cg_info, layer_id, parent_coords, ccs, node_layer_d_shared, time_stamp) - ) - mu.multiprocess_func( - _write_components_helper, - multi_args, - n_threads=min(len(multi_args), mp.cpu_count()), - ) - - -def _write_components_helper(args): - print("running _write_components_helper") - cg_info, layer_id, parent_coords, ccs, node_layer_d_shared, time_stamp = args - cg = ChunkedGraph(**cg_info) - _write(cg, layer_id, parent_coords, ccs, node_layer_d_shared, time_stamp) - - -def _write( - cg, layer_id, parent_coords, ccs, node_layer_d_shared, time_stamp, use_threads=True -): - parent_layer_ids = range(layer_id, cg.meta.layer_count + 1) - cc_connections = {l: [] for l in parent_layer_ids} - for node_ids in ccs: - layer = layer_id - if len(node_ids) == 1: - layer = node_layer_d_shared.get(node_ids[0], cg.meta.layer_count) - cc_connections[layer].append(node_ids) - - rows = [] - x, y, z = parent_coords - parent_chunk_id = cg.get_chunk_id(layer=layer_id, x=x, y=y, z=z) - parent_chunk_id_dict = cg.get_parent_chunk_id_dict(parent_chunk_id) - - # Iterate through layers - for parent_layer_id in parent_layer_ids: - if len(cc_connections[parent_layer_id]) == 0: - continue - - parent_chunk_id = parent_chunk_id_dict[parent_layer_id] - reserved_parent_ids = cg.id_client.create_node_ids( - parent_chunk_id, - size=len(cc_connections[parent_layer_id]), - root_chunk=parent_layer_id == cg.meta.layer_count and use_threads, - ) - - for i_cc, node_ids in enumerate(cc_connections[parent_layer_id]): - parent_id = reserved_parent_ids[i_cc] - for node_id in node_ids: - rows.append( - cg.client.mutate_row( - serializers.serialize_uint64(node_id), - {attributes.Hierarchy.Parent: parent_id}, - time_stamp=time_stamp, - ) - ) - - rows.append( - cg.client.mutate_row( - serializers.serialize_uint64(parent_id), - {attributes.Hierarchy.Child: node_ids}, - time_stamp=time_stamp, - ) - ) - - if len(rows) > 100000: - cg.client.write(rows) - print("wrote rows", len(rows), layer_id, parent_coords) - rows = [] - cg.client.write(rows) - print("wrote rows", len(rows), layer_id, parent_coords) diff --git a/pychunkedgraph/ingest/create/atomic_layer.py b/pychunkedgraph/ingest/create/atomic_layer.py index 4fa1f1688..0a7aae728 100644 --- a/pychunkedgraph/ingest/create/atomic_layer.py +++ b/pychunkedgraph/ingest/create/atomic_layer.py @@ -1,14 +1,14 @@ +# pylint: disable=invalid-name, missing-function-docstring, import-outside-toplevel + """ Functions for creating atomic nodes and their level 2 abstract parents """ import datetime from typing import Dict -from typing import List from typing import Optional from typing import Sequence -import pytz import numpy as np from ...graph import attributes @@ -23,9 +23,9 @@ from ...graph.utils.flatgraph import connected_components -def add_atomic_edges( +def add_atomic_chunk( cg: ChunkedGraph, - chunk_coord: np.ndarray, + coords: Sequence[int], chunk_edges_d: Dict[str, Edges], isolated: Sequence[int], time_stamp: Optional[datetime.datetime] = None, @@ -40,9 +40,7 @@ def add_atomic_edges( graph, _, _, unique_ids = build_gt_graph(chunk_edge_ids, make_directed=True) ccs = connected_components(graph) - parent_chunk_id = cg.get_chunk_id( - layer=2, x=chunk_coord[0], y=chunk_coord[1], z=chunk_coord[2] - ) + parent_chunk_id = cg.get_chunk_id(layer=2, x=coords[0], y=coords[1], z=coords[2]) parent_ids = cg.id_client.create_node_ids(parent_chunk_id, size=len(ccs)) sparse_indices, remapping = _get_remapping(chunk_edges_d) @@ -101,7 +99,13 @@ def _get_remapping(chunk_edges_d: dict): def _process_component( - cg, chunk_edges_d, parent_id, node_ids, sparse_indices, remapping, time_stamp, + cg, + chunk_edges_d, + parent_id, + node_ids, + sparse_indices, + remapping, + time_stamp, ): nodes = [] chunk_out_edges = [] # out = between + cross @@ -120,7 +124,7 @@ def _process_component( for cc_layer in u_cce_layers: layer_out_edges = chunk_out_edges[cce_layers == cc_layer] if layer_out_edges.size: - col = attributes.Connectivity.CrossChunkEdge[cc_layer] + col = attributes.Connectivity.AtomicCrossChunkEdge[cc_layer] val_dict[col] = layer_out_edges r_key = serializers.serialize_uint64(parent_id) diff --git a/pychunkedgraph/graph/connectivity/cross_edges.py b/pychunkedgraph/ingest/create/cross_edges.py similarity index 61% rename from pychunkedgraph/graph/connectivity/cross_edges.py rename to pychunkedgraph/ingest/create/cross_edges.py index 8aa52a9f1..9581838af 100644 --- a/pychunkedgraph/graph/connectivity/cross_edges.py +++ b/pychunkedgraph/ingest/create/cross_edges.py @@ -1,43 +1,38 @@ -import time +# pylint: disable=invalid-name, missing-docstring + import math import multiprocessing as mp from collections import defaultdict -from typing import Optional from typing import Sequence -from typing import List from typing import Dict import numpy as np from multiwrapper.multiprocessing_utils import multiprocess_func -from .. import attributes -from ..types import empty_2d -from ..utils import basetypes -from ..utils import serializers -from ..chunkedgraph import ChunkedGraph -from ..utils.generic import get_valid_timestamp -from ..utils.generic import filter_failed_node_ids -from ..chunks.atomic import get_touching_atomic_chunks -from ..chunks.atomic import get_bounding_atomic_chunks +from ...graph import attributes +from ...graph.types import empty_2d +from ...graph.utils import basetypes +from ...graph.chunkedgraph import ChunkedGraph +from ...graph.utils.generic import filter_failed_node_ids +from ...graph.chunks.atomic import get_touching_atomic_chunks +from ...graph.chunks.atomic import get_bounding_atomic_chunks from ...utils.general import chunked def get_children_chunk_cross_edges( - cg, layer, chunk_coord, *, use_threads=True + cg: ChunkedGraph, layer, chunk_coord, *, use_threads=True ) -> np.ndarray: """ Cross edges that connect children chunks. - The edges are between node IDs in the given layer (not atomic). + The edges are between node IDs in the given layer. """ atomic_chunks = get_touching_atomic_chunks(cg.meta, layer, chunk_coord) - if not len(atomic_chunks): + if len(atomic_chunks) == 0: return [] - print(f"touching atomic chunk count {len(atomic_chunks)}") if not use_threads: return _get_children_chunk_cross_edges(cg, atomic_chunks, layer - 1) - print("get_children_chunk_cross_edges, atomic chunks", len(atomic_chunks)) with mp.Manager() as manager: edge_ids_shared = manager.list() edge_ids_shared.append(empty_2d) @@ -68,10 +63,12 @@ def _get_children_chunk_cross_edges_helper(args) -> None: edge_ids_shared.append(_get_children_chunk_cross_edges(cg, atomic_chunks, layer)) -def _get_children_chunk_cross_edges(cg, atomic_chunks, layer) -> None: - print( - f"_get_children_chunk_cross_edges {layer} atomic_chunks count {len(atomic_chunks)}" - ) +def _get_children_chunk_cross_edges(cg: ChunkedGraph, atomic_chunks, layer) -> np.ndarray: + """ + Non parallelized version + Cross edges that connect children chunks. + The edges are between node IDs in the given layer (not atomic). + """ cross_edges = [empty_2d] for layer2_chunk in atomic_chunks: edges = _read_atomic_chunk_cross_edges(cg, layer2_chunk, layer) @@ -80,18 +77,21 @@ def _get_children_chunk_cross_edges(cg, atomic_chunks, layer) -> None: cross_edges = np.concatenate(cross_edges) if not cross_edges.size: return empty_2d - print(f"getting roots at stop_layer {layer} {cross_edges.shape}") + cross_edges[:, 0] = cg.get_roots(cross_edges[:, 0], stop_layer=layer, ceil=False) cross_edges[:, 1] = cg.get_roots(cross_edges[:, 1], stop_layer=layer, ceil=False) result = np.unique(cross_edges, axis=0) if cross_edges.size else empty_2d - print(f"_get_children_chunk_cross_edges done {result.shape}") return result def _read_atomic_chunk_cross_edges( - cg, chunk_coord: Sequence[int], cross_edge_layer: int + cg: ChunkedGraph, chunk_coord: Sequence[int], cross_edge_layer: int ) -> np.ndarray: - cross_edge_col = attributes.Connectivity.CrossChunkEdge[cross_edge_layer] + """ + Returns cross edges between l2 nodes in current chunk and + l1 supervoxels from neighbor chunks. + """ + cross_edge_col = attributes.Connectivity.AtomicCrossChunkEdge[cross_edge_layer] range_read, l2ids = _read_atomic_chunk(cg, chunk_coord, [cross_edge_layer]) parent_neighboring_chunk_supervoxels_d = defaultdict(list) @@ -102,8 +102,7 @@ def _read_atomic_chunk_cross_edges( parent_neighboring_chunk_supervoxels_d[l2id] = edges[:, 1] cross_edges = [empty_2d] - for l2id in parent_neighboring_chunk_supervoxels_d: - nebor_svs = parent_neighboring_chunk_supervoxels_d[l2id] + for l2id, nebor_svs in parent_neighboring_chunk_supervoxels_d.items(): chunk_parent_ids = np.array([l2id] * len(nebor_svs), dtype=basetypes.NODE_ID) cross_edges.append(np.vstack([chunk_parent_ids, nebor_svs]).T) cross_edges = np.concatenate(cross_edges) @@ -111,35 +110,31 @@ def _read_atomic_chunk_cross_edges( def get_chunk_nodes_cross_edge_layer( - cg, layer: int, chunk_coord: Sequence[int], use_threads=True + cg: ChunkedGraph, layer: int, chunk_coord: Sequence[int], use_threads=True ) -> Dict: """ gets nodes in a chunk that are part of cross chunk edges return_type dict {node_id: layer} the lowest layer (>= current layer) at which a node_id is part of a cross edge """ - print("get_bounding_atomic_chunks") atomic_chunks = get_bounding_atomic_chunks(cg.meta, layer, chunk_coord) - print("get_bounding_atomic_chunks complete") - if not len(atomic_chunks): + if len(atomic_chunks) == 0: return {} if not use_threads: return _get_chunk_nodes_cross_edge_layer(cg, atomic_chunks, layer) - print("divide tasks") cg_info = cg.get_serialized_info() manager = mp.Manager() - ids_l_shared = manager.list() - layers_l_shared = manager.list() + node_ids_shared = manager.list() + node_layers_shared = manager.list() task_size = int(math.ceil(len(atomic_chunks) / mp.cpu_count() / 10)) chunked_l2chunk_list = chunked(atomic_chunks, task_size) multi_args = [] for atomic_chunks in chunked_l2chunk_list: multi_args.append( - (ids_l_shared, layers_l_shared, cg_info, atomic_chunks, layer) + (node_ids_shared, node_layers_shared, cg_info, atomic_chunks, layer) ) - print("divide tasks complete") multiprocess_func( _get_chunk_nodes_cross_edge_layer_helper, @@ -148,24 +143,29 @@ def get_chunk_nodes_cross_edge_layer( ) node_layer_d_shared = manager.dict() - _find_min_layer(node_layer_d_shared, ids_l_shared, layers_l_shared) - print("_find_min_layer complete") + _find_min_layer(node_layer_d_shared, node_ids_shared, node_layers_shared) return node_layer_d_shared def _get_chunk_nodes_cross_edge_layer_helper(args): - ids_l_shared, layers_l_shared, cg_info, atomic_chunks, layer = args + node_ids_shared, node_layers_shared, cg_info, atomic_chunks, layer = args cg = ChunkedGraph(**cg_info) node_layer_d = _get_chunk_nodes_cross_edge_layer(cg, atomic_chunks, layer) - ids_l_shared.append(np.fromiter(node_layer_d.keys(), dtype=basetypes.NODE_ID)) - layers_l_shared.append(np.fromiter(node_layer_d.values(), dtype=np.uint8)) + node_ids_shared.append(np.fromiter(node_layer_d.keys(), dtype=basetypes.NODE_ID)) + node_layers_shared.append(np.fromiter(node_layer_d.values(), dtype=np.uint8)) -def _get_chunk_nodes_cross_edge_layer(cg, atomic_chunks, layer): +def _get_chunk_nodes_cross_edge_layer(cg: ChunkedGraph, atomic_chunks, layer): + """ + Non parallelized version + gets nodes in a chunk that are part of cross chunk edges + return_type dict {node_id: layer} + the lowest layer (>= current layer) at which a node_id is part of a cross edge + """ atomic_node_layer_d = {} for atomic_chunk in atomic_chunks: chunk_node_layer_d = _read_atomic_chunk_cross_edge_nodes( - cg, atomic_chunk, range(layer, cg.meta.layer_count + 1) + cg, atomic_chunk, layer ) atomic_node_layer_d.update(chunk_node_layer_d) @@ -179,32 +179,57 @@ def _get_chunk_nodes_cross_edge_layer(cg, atomic_chunks, layer): return node_layer_d -def _read_atomic_chunk_cross_edge_nodes(cg, chunk_coord, cross_edge_layers): +def _read_atomic_chunk_cross_edge_nodes(cg: ChunkedGraph, chunk_coord, layer): + """ + the lowest layer at which an l2 node is part of a cross edge + """ node_layer_d = {} - range_read, l2ids = _read_atomic_chunk(cg, chunk_coord, cross_edge_layers) + relevant_layers = range(layer, cg.meta.layer_count) + range_read, l2ids = _read_atomic_chunk(cg, chunk_coord, relevant_layers) for l2id in l2ids: - for layer in cross_edge_layers: - if attributes.Connectivity.CrossChunkEdge[layer] in range_read[l2id]: + for layer in relevant_layers: + if attributes.Connectivity.AtomicCrossChunkEdge[layer] in range_read[l2id]: node_layer_d[l2id] = layer break return node_layer_d -def _find_min_layer(node_layer_d_shared, ids_l_shared, layers_l_shared): - node_ids = np.concatenate(ids_l_shared) - layers = np.concatenate(layers_l_shared) +def _find_min_layer(node_layer_d_shared, node_ids_shared, node_layers_shared): + """ + `node_layer_d_shared`: DictProxy + + `node_ids_shared`: ListProxy + + `node_layers_shared`: ListProxy + + Due to parallelization, there will be multiple values for min_layer of a node. + We need to find the global min_layer after all multiprocesses return. + For eg: + At some indices p and q, there will be a node_id x + i.e. `node_ids_shared[p] == node_ids_shared[q]` + + and node_layers_shared[p] != node_layers_shared[q] + so we need: + `node_layer_d_shared[x] = min(node_layers_shared[p], node_layers_shared[q])` + """ + node_ids = np.concatenate(node_ids_shared) + layers = np.concatenate(node_layers_shared) for i, node_id in enumerate(node_ids): layer = node_layer_d_shared.get(node_id, layers[i]) node_layer_d_shared[node_id] = min(layer, layers[i]) -def _read_atomic_chunk(cg, chunk_coord, layers): +def _read_atomic_chunk(cg: ChunkedGraph, chunk_coord, layers): + """ + read entire atomic chunk; all nodes and their relevant cross edges + filter out invalid nodes generated by failed tasks + """ x, y, z = chunk_coord child_col = attributes.Hierarchy.Child range_read = cg.range_read_chunk( cg.get_chunk_id(layer=2, x=x, y=y, z=z), properties=[child_col] - + [attributes.Connectivity.CrossChunkEdge[l] for l in layers], + + [attributes.Connectivity.AtomicCrossChunkEdge[l] for l in layers], ) row_ids = [] diff --git a/pychunkedgraph/ingest/create/parent_layer.py b/pychunkedgraph/ingest/create/parent_layer.py new file mode 100644 index 000000000..90b24d26a --- /dev/null +++ b/pychunkedgraph/ingest/create/parent_layer.py @@ -0,0 +1,247 @@ +# pylint: disable=invalid-name, missing-docstring, import-outside-toplevel, c-extension-no-member + +""" +Functions for creating parents in level 3 and above +""" + +import math +import datetime +import multiprocessing as mp +from typing import Optional +from typing import Sequence + +import fastremap +import numpy as np +from multiwrapper import multiprocessing_utils as mu + +from ...graph import types +from ...graph import attributes +from ...utils.general import chunked +from ...graph.utils import flatgraph +from ...graph.utils import basetypes +from ...graph.utils import serializers +from ...graph.chunkedgraph import ChunkedGraph +from ...graph.edges.utils import concatenate_cross_edge_dicts +from ...graph.utils.generic import get_valid_timestamp +from ...graph.utils.generic import filter_failed_node_ids +from ...graph.chunks.hierarchy import get_children_chunk_coords +from .cross_edges import get_children_chunk_cross_edges +from .cross_edges import get_chunk_nodes_cross_edge_layer + + +def add_parent_chunk( + cg: ChunkedGraph, + layer_id: int, + coords: Sequence[int], + children_coords: Sequence[Sequence[int]] = np.array([]), + *, + time_stamp: Optional[datetime.datetime] = None, + n_threads: int = 4, +) -> None: + if not children_coords.size: + children_coords = get_children_chunk_coords(cg.meta, layer_id, coords) + children_ids = _read_children_chunks(cg, layer_id, children_coords, n_threads > 1) + cx_edges = get_children_chunk_cross_edges( + cg, layer_id, coords, use_threads=n_threads > 1 + ) + + node_layers = cg.get_chunk_layers(children_ids) + edge_layers = cg.get_chunk_layers(np.unique(cx_edges)) + assert np.all(node_layers < layer_id), "invalid node layers" + assert np.all(edge_layers < layer_id), "invalid edge layers" + + cx_edges = list(cx_edges) + cx_edges.extend(np.vstack([children_ids, children_ids]).T) # add self-edges + graph, _, _, graph_ids = flatgraph.build_gt_graph(cx_edges, make_directed=True) + raw_ccs = flatgraph.connected_components(graph) # connected components with indices + connected_components = [graph_ids[cc] for cc in raw_ccs] + + _write_connected_components( + cg, + layer_id, + coords, + connected_components, + get_valid_timestamp(time_stamp), + n_threads > 1, + ) + + +def _read_children_chunks( + cg: ChunkedGraph, layer_id, children_coords, use_threads=True +): + if not use_threads: + children_ids = [types.empty_1d] + for child_coord in children_coords: + children_ids.append(_read_chunk([], cg, layer_id - 1, child_coord)) + return np.concatenate(children_ids) + + with mp.Manager() as manager: + children_ids_shared = manager.list() + multi_args = [] + for child_coord in children_coords: + multi_args.append( + ( + children_ids_shared, + cg.get_serialized_info(), + layer_id - 1, + child_coord, + ) + ) + mu.multiprocess_func( + _read_chunk_helper, + multi_args, + n_threads=min(len(multi_args), mp.cpu_count()), + ) + return np.concatenate(children_ids_shared) + + +def _read_chunk_helper(args): + children_ids_shared, cg_info, layer_id, chunk_coord = args + cg = ChunkedGraph(**cg_info) + _read_chunk(children_ids_shared, cg, layer_id, chunk_coord) + + +def _read_chunk(children_ids_shared, cg: ChunkedGraph, layer_id: int, chunk_coord): + x, y, z = chunk_coord + range_read = cg.range_read_chunk( + cg.get_chunk_id(layer=layer_id, x=x, y=y, z=z), + properties=attributes.Hierarchy.Child, + ) + row_ids = [] + max_children_ids = [] + for row_id, row_data in range_read.items(): + row_ids.append(row_id) + max_children_ids.append(np.max(row_data[0].value)) + row_ids = np.array(row_ids, dtype=basetypes.NODE_ID) + segment_ids = np.array([cg.get_segment_id(r_id) for r_id in row_ids]) + + row_ids = filter_failed_node_ids(row_ids, segment_ids, max_children_ids) + children_ids_shared.append(row_ids) + return row_ids + + +def _write_connected_components( + cg, layer, pcoords, components, time_stamp, use_threads=True +): + if len(components) == 0: + return + + node_layer_d = {} + if layer < cg.meta.layer_count: + node_layer_d = get_chunk_nodes_cross_edge_layer(cg, layer, pcoords, use_threads) + + if not use_threads: + _write(cg, layer, pcoords, components, node_layer_d, time_stamp, use_threads) + return + + task_size = int(math.ceil(len(components) / mp.cpu_count() / 10)) + chunked_ccs = chunked(components, task_size) + cg_info = cg.get_serialized_info() + multi_args = [] + for ccs in chunked_ccs: + args = (cg_info, layer, pcoords, ccs, node_layer_d, time_stamp) + multi_args.append(args) + mu.multiprocess_func( + _write_components_helper, + multi_args, + n_threads=min(len(multi_args), mp.cpu_count()), + ) + + +def _write_components_helper(args): + cg_info, layer, pcoords, ccs, node_layer_d, time_stamp = args + cg = ChunkedGraph(**cg_info) + _write(cg, layer, pcoords, ccs, node_layer_d, time_stamp) + + +def _children_rows( + cg: ChunkedGraph, parent_id, children: Sequence, cx_edges_d: dict, time_stamp +): + """ + Update children rows to point to the parent_id, collect cached children + cross chunk edges to lift and update parent cross chunk edges. + Returns list of mutations to children and list of children cross edges. + """ + rows = [] + children_cx_edges = [] + children_layers = cg.get_chunk_layers(children) + for child, node_layer in zip(children, children_layers): + node_layer = cg.get_chunk_layer(child) + row_id = serializers.serialize_uint64(child) + val_dict = {attributes.Hierarchy.Parent: parent_id} + node_cx_edges_d = cx_edges_d.get(child, {}) + if not node_cx_edges_d: + rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp)) + continue + for layer in range(node_layer, cg.meta.layer_count): + if not layer in node_cx_edges_d: + continue + layer_edges = node_cx_edges_d[layer] + nodes = np.unique(layer_edges) + parents = cg.get_roots(nodes, stop_layer=node_layer, ceil=False) + edge_parents_d = dict(zip(nodes, parents)) + layer_edges = fastremap.remap( + layer_edges, edge_parents_d, preserve_missing_labels=True + ) + layer_edges = np.unique(layer_edges, axis=0) + col = attributes.Connectivity.CrossChunkEdge[layer] + val_dict[col] = layer_edges + node_cx_edges_d[layer] = layer_edges + children_cx_edges.append(node_cx_edges_d) + rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp)) + return rows, children_cx_edges + + +def _write( + cg: ChunkedGraph, + layer_id, + parent_coords, + components, + node_layer_d, + ts, + use_threads=True, +): + parent_layers = range(layer_id, cg.meta.layer_count + 1) + cc_connections = {l: [] for l in parent_layers} + for node_ids in components: + layer = layer_id + if len(node_ids) == 1: + layer = node_layer_d.get(node_ids[0], cg.meta.layer_count) + cc_connections[layer].append(node_ids) + + rows = [] + x, y, z = parent_coords + parent_chunk_id = cg.get_chunk_id(layer=layer_id, x=x, y=y, z=z) + parent_chunk_id_dict = cg.get_parent_chunk_id_dict(parent_chunk_id) + for parent_layer in parent_layers: + if len(cc_connections[parent_layer]) == 0: + continue + parent_chunk_id = parent_chunk_id_dict[parent_layer] + reserved_parent_ids = cg.id_client.create_node_ids( + parent_chunk_id, + size=len(cc_connections[parent_layer]), + root_chunk=parent_layer == cg.meta.layer_count and use_threads, + ) + for i_cc, children in enumerate(cc_connections[parent_layer]): + parent = reserved_parent_ids[i_cc] + if layer_id == 3: + # when layer 3 is being processed, children chunks are at layer 2 + # layer 2 chunks at this time will only have atomic cross edges + cx_edges_d = cg.get_atomic_cross_edges(children) + else: + cx_edges_d = cg.get_cross_chunk_edges(children, raw_only=True) + _rows, cx_edges = _children_rows(cg, parent, children, cx_edges_d, ts) + rows.extend(_rows) + row_id = serializers.serialize_uint64(parent) + val_dict = {attributes.Hierarchy.Child: children} + parent_cx_edges_d = concatenate_cross_edge_dicts(cx_edges, unique=True) + for layer in range(parent_layer, cg.meta.layer_count): + if not layer in parent_cx_edges_d: + continue + col = attributes.Connectivity.CrossChunkEdge[layer] + val_dict[col] = parent_cx_edges_d[layer] + rows.append(cg.client.mutate_row(row_id, val_dict, ts)) + if len(rows) > 100000: + cg.client.write(rows) + rows = [] + cg.client.write(rows) diff --git a/pychunkedgraph/ingest/manager.py b/pychunkedgraph/ingest/manager.py index f5f870810..55e7d253f 100644 --- a/pychunkedgraph/ingest/manager.py +++ b/pychunkedgraph/ingest/manager.py @@ -1,3 +1,5 @@ +# pylint: disable=invalid-name, missing-docstring + import pickle from . import IngestConfig @@ -15,7 +17,9 @@ def __init__(self, config: IngestConfig, chunkedgraph_meta: ChunkedGraphMeta): self._cg = None self._redis = None self._task_queues = {} - self.redis # initiate and cache info + + # initiate redis and cache info + self.redis # pylint: disable=pointless-statement @property def config(self): diff --git a/pychunkedgraph/ingest/ran_agglomeration.py b/pychunkedgraph/ingest/ran_agglomeration.py index 7c4af51f7..a0ca42d54 100644 --- a/pychunkedgraph/ingest/ran_agglomeration.py +++ b/pychunkedgraph/ingest/ran_agglomeration.py @@ -5,10 +5,7 @@ from collections import defaultdict from itertools import product -from typing import Dict -from typing import Iterable -from typing import Tuple -from typing import Union +from typing import Dict, Iterable, Tuple, Union from binascii import crc32 @@ -23,8 +20,7 @@ from ..io.edges import put_chunk_edges from ..io.components import put_chunk_components from ..graph.utils import basetypes -from ..graph.edges import Edges -from ..graph.edges import EDGE_TYPES +from ..graph.edges import EDGE_TYPES, Edges from ..graph.types import empty_2d from ..graph.chunks.utils import get_chunk_id diff --git a/pychunkedgraph/ingest/rq_cli.py b/pychunkedgraph/ingest/rq_cli.py index 27b9c865d..6a1a4882d 100644 --- a/pychunkedgraph/ingest/rq_cli.py +++ b/pychunkedgraph/ingest/rq_cli.py @@ -1,20 +1,18 @@ +# pylint: disable=invalid-name, missing-function-docstring + """ cli for redis jobs """ -import os import sys import click from redis import Redis from rq import Queue -from rq import Worker -from rq.worker import WorkerStatus from rq.job import Job from rq.exceptions import InvalidJobOperationError from rq.exceptions import NoSuchJobError from rq.registry import StartedJobRegistry from rq.registry import FailedJobRegistry -from flask import current_app from flask.cli import AppGroup from ..utils.redis import REDIS_HOST @@ -27,23 +25,6 @@ connection = Redis(host=REDIS_HOST, port=REDIS_PORT, db=0, password=REDIS_PASSWORD) -@rq_cli.command("status") -@click.argument("queues", nargs=-1, type=str) -@click.option("--show-busy", is_flag=True) -def get_status(queues, show_busy): - print("NOTE: Use --show-busy to display count of non idle workers\n") - for queue in queues: - q = Queue(queue, connection=connection) - print(f"Queue name \t: {queue}") - print(f"Jobs queued \t: {len(q)}") - print(f"Workers total \t: {Worker.count(queue=q)}") - if show_busy: - workers = Worker.all(queue=q) - count = sum([worker.get_state() == WorkerStatus.BUSY for worker in workers]) - print(f"Workers busy \t: {count}") - print(f"Jobs failed \t: {q.failed_job_registry.count}\n") - - @rq_cli.command("failed") @click.argument("queue", type=str) @click.argument("job_ids", nargs=-1) @@ -129,9 +110,14 @@ def clean_start_registry(queue): def clear_failed_registry(queue): failed_job_registry = FailedJobRegistry(queue, connection=connection) job_ids = failed_job_registry.get_job_ids() + count = 0 for job_id in job_ids: - failed_job_registry.remove(job_id, delete_job=True) - print(f"Deleted {len(job_ids)} jobs from the failed job registry.") + try: + failed_job_registry.remove(job_id, delete_job=True) + count += 1 + except Exception: + ... + print(f"Deleted {count} jobs from the failed job registry.") def init_rq_cmds(app): diff --git a/pychunkedgraph/ingest/simple_tests.py b/pychunkedgraph/ingest/simple_tests.py new file mode 100644 index 000000000..48a49f922 --- /dev/null +++ b/pychunkedgraph/ingest/simple_tests.py @@ -0,0 +1,177 @@ +# pylint: disable=invalid-name, missing-function-docstring, broad-exception-caught + +""" +Some sanity tests to ensure chunkedgraph was created properly. +""" + +from datetime import datetime, timezone +import numpy as np + +from pychunkedgraph.graph import attributes, ChunkedGraph + + +def family(cg: ChunkedGraph): + np.random.seed(42) + n_chunks = 100 + n_segments_per_chunk = 200 + timestamp = datetime.now(timezone.utc) + + node_ids = [] + for layer in range(2, cg.meta.layer_count - 1): + for _ in range(n_chunks): + c_x = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][0]) + c_y = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][1]) + c_z = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][2]) + chunk_id = cg.get_chunk_id(layer=layer, x=c_x, y=c_y, z=c_z) + max_segment_id = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) + if max_segment_id < 10: + continue + + segment_ids = np.random.randint(1, max_segment_id, n_segments_per_chunk) + for segment_id in segment_ids: + node_ids.append( + cg.get_node_id(np.uint64(segment_id), np.uint64(chunk_id)) + ) + + rows = cg.client.read_nodes( + node_ids=node_ids, end_time=timestamp, properties=attributes.Hierarchy.Parent + ) + valid_node_ids = [] + non_valid_node_ids = [] + for k in rows.keys(): + if len(rows[k]) > 0: + valid_node_ids.append(k) + else: + non_valid_node_ids.append(k) + + parents = cg.get_parents(valid_node_ids, time_stamp=timestamp) + children_dict = cg.get_children(parents) + for child, parent in zip(valid_node_ids, parents): + assert child in children_dict[parent] + print("success") + + +def existence(cg: ChunkedGraph): + np.random.seed(42) + layer = 2 + n_chunks = 100 + n_segments_per_chunk = 200 + timestamp = datetime.now(timezone.utc) + node_ids = [] + for _ in range(n_chunks): + c_x = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][0]) + c_y = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][1]) + c_z = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][2]) + chunk_id = cg.get_chunk_id(layer=layer, x=c_x, y=c_y, z=c_z) + max_segment_id = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) + if max_segment_id < 10: + continue + + segment_ids = np.random.randint(1, max_segment_id, n_segments_per_chunk) + for segment_id in segment_ids: + node_ids.append(cg.get_node_id(np.uint64(segment_id), np.uint64(chunk_id))) + + rows = cg.client.read_nodes( + node_ids=node_ids, end_time=timestamp, properties=attributes.Hierarchy.Parent + ) + valid_node_ids = [] + non_valid_node_ids = [] + for k in rows.keys(): + if len(rows[k]) > 0: + valid_node_ids.append(k) + else: + non_valid_node_ids.append(k) + + roots = [] + try: + roots = cg.get_roots(valid_node_ids) + assert len(roots) == len(valid_node_ids) + print("success") + except Exception as e: + print(f"Something went wrong: {e}") + print("At least one node failed. Checking nodes one by one:") + + if len(roots) != len(valid_node_ids): + log_dict = {} + success_dict = {} + for node_id in valid_node_ids: + try: + _ = cg.get_root(node_id, time_stamp=timestamp) + print(f"Success: {node_id} from chunk {cg.get_chunk_id(node_id)}") + success_dict[node_id] = True + except Exception as e: + print(f"{node_id} - chunk {cg.get_chunk_id(node_id)} failed: {e}") + success_dict[node_id] = False + t_id = node_id + while t_id is not None: + last_working_chunk = cg.get_chunk_id(t_id) + t_id = cg.get_parent(t_id) + + layer = cg.get_chunk_layer(last_working_chunk) + print(f"Failed on layer {layer} in chunk {last_working_chunk}") + log_dict[node_id] = last_working_chunk + + +def cross_edges(cg: ChunkedGraph): + np.random.seed(42) + layer = 2 + n_chunks = 10 + n_segments_per_chunk = 200 + timestamp = datetime.now(timezone.utc) + node_ids = [] + for _ in range(n_chunks): + c_x = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][0]) + c_y = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][1]) + c_z = np.random.randint(0, cg.meta.layer_chunk_bounds[layer][2]) + chunk_id = cg.get_chunk_id(layer=layer, x=c_x, y=c_y, z=c_z) + max_segment_id = cg.get_segment_id(cg.id_client.get_max_node_id(chunk_id)) + if max_segment_id < 10: + continue + + segment_ids = np.random.randint(1, max_segment_id, n_segments_per_chunk) + for segment_id in segment_ids: + node_ids.append(cg.get_node_id(np.uint64(segment_id), np.uint64(chunk_id))) + + rows = cg.client.read_nodes( + node_ids=node_ids, end_time=timestamp, properties=attributes.Hierarchy.Parent + ) + valid_node_ids = [] + non_valid_node_ids = [] + for k in rows.keys(): + if len(rows[k]) > 0: + valid_node_ids.append(k) + else: + non_valid_node_ids.append(k) + + cc_edges = cg.get_atomic_cross_edges(valid_node_ids) + cc_ids = np.unique( + np.concatenate( + [ + np.concatenate(list(v.values())) + for v in list(cc_edges.values()) + if len(v.values()) + ] + ) + ) + + roots = cg.get_roots(cc_ids) + root_dict = dict(zip(cc_ids, roots)) + root_dict_vec = np.vectorize(root_dict.get) + + for k in cc_edges: + if len(cc_edges[k]) == 0: + continue + local_ids = np.unique(np.concatenate(list(cc_edges[k].values()))) + assert len(np.unique(root_dict_vec(local_ids))) + print("success") + + +def run_all(cg: ChunkedGraph): + print("Running family tests:") + family(cg) + + print("\nRunning existence tests:") + existence(cg) + + print("\nRunning cross_edges tests:") + cross_edges(cg) diff --git a/pychunkedgraph/ingest/upgrade/__init__.py b/pychunkedgraph/ingest/upgrade/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pychunkedgraph/ingest/upgrade/atomic_layer.py b/pychunkedgraph/ingest/upgrade/atomic_layer.py new file mode 100644 index 000000000..43270081b --- /dev/null +++ b/pychunkedgraph/ingest/upgrade/atomic_layer.py @@ -0,0 +1,158 @@ +# pylint: disable=invalid-name, missing-docstring, c-extension-no-member + +from collections import defaultdict +from datetime import datetime, timedelta, timezone +import logging, time, os +from copy import copy + +import fastremap +import numpy as np +from pychunkedgraph.graph import ChunkedGraph, types +from pychunkedgraph.graph.attributes import Connectivity, Hierarchy +from pychunkedgraph.graph.utils import serializers +from pychunkedgraph.graph.utils.generic import get_parents_at_timestamp + +from .utils import fix_corrupt_nodes, get_end_timestamps, get_parent_timestamps + +CHILDREN = {} + + +def update_cross_edges( + cg: ChunkedGraph, + node, + cx_edges_d: dict, + node_ts, + node_end_ts, + timestamps_map: defaultdict[int, set], + parents_ts_map: defaultdict[int, dict], +) -> list: + """ + Helper function to update a single L2 ID. + Returns a list of mutations with given timestamps. + """ + rows = [] + edges = np.concatenate(list(cx_edges_d.values())) + partners = np.unique(edges[:, 1]) + + timestamps = copy(timestamps_map[node]) + for partner in partners: + timestamps.update(timestamps_map[partner]) + + node_end_ts = node_end_ts or datetime.now(timezone.utc) + for ts in sorted(timestamps): + if ts < node_ts: + continue + if ts > node_end_ts: + break + + val_dict = {} + parents, _ = get_parents_at_timestamp(partners, parents_ts_map, ts) + edge_parents_d = dict(zip(partners, parents)) + for layer, layer_edges in cx_edges_d.items(): + layer_edges = fastremap.remap( + layer_edges, edge_parents_d, preserve_missing_labels=True + ) + layer_edges[:, 0] = node + layer_edges = np.unique(layer_edges, axis=0) + col = Connectivity.CrossChunkEdge[layer] + val_dict[col] = layer_edges + row_id = serializers.serialize_uint64(node) + rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=ts)) + return rows + + +def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list: + start = time.time() + if children_map is None: + children_map = CHILDREN + end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, children_map, layer=2) + + cx_edges_d = cg.get_atomic_cross_edges(nodes) + all_cx_edges = [types.empty_2d] + for _cx_edges_d in cx_edges_d.values(): + if _cx_edges_d: + all_cx_edges.append(np.concatenate(list(_cx_edges_d.values()))) + all_partners = np.unique(np.concatenate(all_cx_edges)[:, 1]) + timestamps_d = get_parent_timestamps(cg, np.concatenate([nodes, all_partners])) + + parents_ts_map = defaultdict(dict) + all_parents = cg.get_parents(all_partners, current=False) + for partner, parents in zip(all_partners, all_parents): + for parent, ts in parents: + parents_ts_map[partner][ts] = parent + logging.info(f"update_nodes init {len(nodes)}: {time.time() - start}") + + rows = [] + skipped = [] + for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps): + is_stale = end_ts is not None + _cx_edges_d = cx_edges_d.get(node, {}) + if is_stale: + end_ts -= timedelta(milliseconds=1) + row_id = serializers.serialize_uint64(node) + val_dict = {Hierarchy.StaleTimeStamp: 0} + rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=end_ts)) + + if not _cx_edges_d: + skipped.append(node) + continue + + _rows = update_cross_edges( + cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d, parents_ts_map + ) + rows.extend(_rows) + parents = cg.get_roots(skipped) + layers = cg.get_chunk_layers(parents) + assert np.all(layers == cg.meta.layer_count) + return rows + + +def update_chunk(cg: ChunkedGraph, chunk_coords: list[int]): + """ + Iterate over all L2 IDs in a chunk and update their cross chunk edges, + within the periods they were valid/active. + """ + global CHILDREN + + start = time.time() + x, y, z = chunk_coords + chunk_id = cg.get_chunk_id(layer=2, x=x, y=y, z=z) + rr = cg.range_read_chunk(chunk_id) + + nodes = [] + nodes_ts = [] + try: + earliest_ts = os.environ["EARLIEST_TS"] + earliest_ts = datetime.fromisoformat(earliest_ts) + except KeyError: + earliest_ts = cg.get_earliest_timestamp() + + corrupt_nodes = [] + for k, v in rr.items(): + try: + CHILDREN[k] = v[Hierarchy.Child][0].value + ts = v[Hierarchy.Child][0].timestamp + _ = v[Hierarchy.Parent] + nodes.append(k) + nodes_ts.append(earliest_ts if ts < earliest_ts else ts) + except KeyError: + # ignore invalid nodes from failed ingest tasks, w/o parent column entry + # retain invalid nodes from edits to fix the hierarchy + if ts > earliest_ts: + corrupt_nodes.append(k) + + clean_task = os.environ.get("CLEAN_CHUNKS", "false") == "clean" + if clean_task: + logging.info(f"found {len(corrupt_nodes)} corrupt nodes {corrupt_nodes[:3]}...") + fix_corrupt_nodes(cg, corrupt_nodes, CHILDREN) + return + + cg.copy_fake_edges(chunk_id) + if len(nodes) == 0: + return + + logging.info(f"processing {len(nodes)} nodes.") + assert len(CHILDREN) > 0, (nodes, CHILDREN) + rows = update_nodes(cg, nodes, nodes_ts) + cg.client.write(rows) + logging.info(f"mutations: {len(rows)}, time: {time.time() - start}") diff --git a/pychunkedgraph/ingest/upgrade/parent_layer.py b/pychunkedgraph/ingest/upgrade/parent_layer.py new file mode 100644 index 000000000..436aca49c --- /dev/null +++ b/pychunkedgraph/ingest/upgrade/parent_layer.py @@ -0,0 +1,273 @@ +# pylint: disable=invalid-name, missing-docstring, c-extension-no-member + +from math import ceil +import bisect, logging, random, time, os, gc +import multiprocessing as mp +from collections import defaultdict +from datetime import datetime, timezone + +import fastremap +import numpy as np +from tqdm import tqdm +from cachetools import LRUCache + +from pychunkedgraph.graph import ChunkedGraph +from pychunkedgraph.graph.edges import stale, get_latest_edges_wrapper +from pychunkedgraph.graph.attributes import Connectivity, Hierarchy +from pychunkedgraph.graph.utils import serializers, basetypes +from pychunkedgraph.graph.types import empty_2d +from pychunkedgraph.utils.general import chunked + +from .utils import fix_corrupt_nodes, get_end_timestamps, get_parent_timestamps + + +CHILDREN = {} +CX_EDGES = {} +CG: ChunkedGraph = None +PARENT_CACHE_LIMIT = int(os.environ.get("PARENT_CACHE_LIMIT", 256)) * 1024 + + +def _populate_nodes_and_children( + cg: ChunkedGraph, chunk_id: np.uint64, nodes: list = None +) -> dict: + global CHILDREN + if nodes: + children_map = cg.get_children(nodes) + for k, v in children_map.items(): + if len(v): + CHILDREN[k] = v + return + response = cg.range_read_chunk(chunk_id, properties=Hierarchy.Child) + for k, v in response.items(): + CHILDREN[k] = v[0].value + + +def _get_cx_edges_at_timestamp(node, response, ts): + result = defaultdict(list) + for child in CHILDREN[node]: + if child not in response: + continue + for key, cells in response[child].items(): + # cells are sorted in descending order of timestamps + asc_ts = [c.timestamp for c in reversed(cells)] + k = bisect.bisect_right(asc_ts, ts) - 1 + if k >= 0: + idx = len(cells) - 1 - k + try: + result[key.index].append(cells[idx].value) + except IndexError as e: + logging.error(f"{k}, {idx}, {len(cells)}, {asc_ts}") + raise IndexError from e + for layer, edges in result.items(): + result[layer] = np.concatenate(edges) + return result + + +def _populate_cx_edges_with_timestamps( + cg: ChunkedGraph, layer: int, nodes: list, nodes_ts: list +): + """ + Collect timestamps of edits from children, since we use the same timestamp + for all IDs involved in an edit, we can use the timestamps of + when cross edges of children were updated. + """ + clean_task = os.environ.get("CLEAN_CHUNKS", "false") == "clean" + # this data is not needed for clean tasks + if clean_task: + return + + start = time.time() + global CX_EDGES + attrs = [Connectivity.CrossChunkEdge[l] for l in range(layer, cg.meta.layer_count)] + all_children = np.concatenate(list(CHILDREN.values())) + response = cg.client.read_nodes(node_ids=all_children, properties=attrs) + timestamps_d = get_parent_timestamps(cg, nodes) + end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, CHILDREN, layer=layer) + logging.info(f"_populate_cx_edges_with_timestamps init: {time.time() - start}") + + start = time.time() + partners_map = {} + for node, node_ts in zip(nodes, nodes_ts): + CX_EDGES[node] = {} + cx_edges_d_node_ts = _get_cx_edges_at_timestamp(node, response, node_ts) + edges = np.concatenate([empty_2d] + list(cx_edges_d_node_ts.values())) + partners_map[node] = edges[:, 1] + CX_EDGES[node][node_ts] = cx_edges_d_node_ts + + partners = np.unique(np.concatenate([*partners_map.values()])) + partner_parent_ts_d = get_parent_timestamps(cg, partners) + logging.info(f"get partners timestamps init: {time.time() - start}") + + rows = [] + for node, node_ts, node_end_ts in zip(nodes, nodes_ts, end_timestamps): + timestamps = timestamps_d[node] + for partner in partners_map[node]: + timestamps.update(partner_parent_ts_d[partner]) + + is_stale = node_end_ts is not None + node_end_ts = node_end_ts or datetime.now(timezone.utc) + for ts in sorted(timestamps): + if ts > node_end_ts: + break + CX_EDGES[node][ts] = _get_cx_edges_at_timestamp(node, response, ts) + + if is_stale: + row_id = serializers.serialize_uint64(node) + val_dict = {Hierarchy.StaleTimeStamp: 0} + rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=node_end_ts)) + cg.client.write(rows) + + +def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts) -> list: + """ + Helper function to update a single ID. + Returns a list of mutations with timestamps. + """ + rows = [] + row_id = serializers.serialize_uint64(node) + for ts, edges_d in CX_EDGES[node].items(): + if ts < node_ts: + continue + edges_d, _nodes = get_latest_edges_wrapper(cg, edges_d, parent_ts=ts) + if _nodes.size == 0: + continue + + parents = cg.get_roots(_nodes, time_stamp=ts, stop_layer=layer, ceil=False) + edge_parents_d = dict(zip(_nodes, parents)) + val_dict = {} + for _layer, layer_edges in edges_d.items(): + layer_edges = fastremap.remap( + layer_edges, edge_parents_d, preserve_missing_labels=True + ) + layer_edges[:, 0] = node + layer_edges = np.unique(layer_edges, axis=0) + col = Connectivity.CrossChunkEdge[_layer] + val_dict[col] = layer_edges + rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=ts)) + return rows + + +def _update_cross_edges_helper(args): + global CG + stale.PARENTS_CACHE = LRUCache(PARENT_CACHE_LIMIT) + stale.CHILDREN_CACHE = LRUCache(1 * 1024) + clean_task = os.environ.get("CLEAN_CHUNKS", "false") == "clean" + cg_info, layer, nodes, nodes_ts = args + + if CG is None: + CG = ChunkedGraph(**cg_info) + cg = CG + parents = cg.get_parents(nodes, fail_to_zero=True) + + tasks = [] + corrupt_nodes = [] + earliest_ts = None + if clean_task: + try: + earliest_ts = os.environ["EARLIEST_TS"] + earliest_ts = datetime.fromisoformat(earliest_ts) + except KeyError: + earliest_ts = cg.get_earliest_timestamp() + + for node, parent, node_ts in zip(nodes, parents, nodes_ts): + if parent == 0: + # ignore invalid nodes from failed ingest tasks, w/o parent column entry + # retain invalid nodes from edits to fix the hierarchy + if clean_task and node_ts > earliest_ts: + corrupt_nodes.append(node) + else: + tasks.append((cg, layer, node, node_ts)) + + if clean_task: + logging.info(f"found {len(corrupt_nodes)} corrupt nodes {corrupt_nodes[:3]}...") + fix_corrupt_nodes(cg, corrupt_nodes, CHILDREN) + return + + rows = [] + for task in tasks: + rows.extend(update_cross_edges(*task)) + stale.PARENTS_CACHE.clear() + stale.CHILDREN_CACHE.clear() + cg.client.write(rows) + gc.collect() + + +def _get_split_nodes( + cg: ChunkedGraph, chunk_id: basetypes.CHUNK_ID, split: int, splits: int +): + max_id = cg.client.get_max_node_id(chunk_id) + total = max_id - chunk_id + split_size = int(ceil(total / splits)) + start = int(chunk_id + np.uint64(split * split_size)) + end = int(start + split_size) + return range(int(start), int(end)) + + +def update_chunk( + cg: ChunkedGraph, + chunk_coords: list[int], + layer: int, + nodes: list = None, + split: int = None, + splits: int = None, +): + """ + Iterate over all layer IDs in a chunk and update their cross chunk edges. + """ + debug = nodes is not None + start = time.time() + x, y, z = chunk_coords + chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z) + + if splits is not None: + nodes = _get_split_nodes(cg, chunk_id, split, splits) + + _populate_nodes_and_children(cg, chunk_id, nodes=nodes) + logging.info(f"_populate_nodes_and_children: {time.time() - start}") + nodes = list(CHILDREN.keys()) + if len(nodes) == 0: + return + + logging.info(f"processing {len(nodes)} nodes.") + random.shuffle(nodes) + start = time.time() + nodes_ts = cg.get_node_timestamps(nodes, return_numpy=False, normalize=True) + logging.info(f"get_node_timestamps: {time.time() - start}") + + start = time.time() + _populate_cx_edges_with_timestamps(cg, layer, nodes, nodes_ts) + logging.info(f"_populate_cx_edges_with_timestamps: {time.time() - start}") + + if debug: + rows = [] + stale.PARENTS_CACHE = LRUCache(PARENT_CACHE_LIMIT) + stale.CHILDREN_CACHE = LRUCache(1 * 1024) + logging.info(f"processing {len(nodes)} nodes with 1 worker.") + for node, node_ts in zip(nodes, nodes_ts): + rows.extend(update_cross_edges(cg, layer, node, node_ts)) + stale.PARENTS_CACHE.clear() + stale.CHILDREN_CACHE.clear() + logging.info(f"total elaspsed time: {time.time() - start}") + return + + task_size = int(os.environ.get("TASK_SIZE", 1)) + chunked_nodes = chunked(nodes, task_size) + chunked_nodes_ts = chunked(nodes_ts, task_size) + cg_info = cg.get_serialized_info() + + tasks = [] + for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts): + args = (cg_info, layer, chunk, ts_chunk) + tasks.append(args) + + process_multiplier = int(os.environ.get("PROCESS_MULTIPLIER", 5)) + processes = min(mp.cpu_count() * process_multiplier, len(tasks)) + logging.info(f"processing {len(nodes)} nodes with {processes} workers.") + with mp.Pool(processes) as pool: + _ = list( + tqdm( + pool.imap_unordered(_update_cross_edges_helper, tasks), + total=len(tasks), + ) + ) + logging.info(f"total elaspsed time: {time.time() - start}") diff --git a/pychunkedgraph/ingest/upgrade/utils.py b/pychunkedgraph/ingest/upgrade/utils.py new file mode 100644 index 000000000..0410245c3 --- /dev/null +++ b/pychunkedgraph/ingest/upgrade/utils.py @@ -0,0 +1,137 @@ +# pylint: disable=invalid-name, missing-docstring + +from collections import defaultdict +from datetime import datetime, timedelta + +import numpy as np +from pychunkedgraph.graph import ChunkedGraph +from pychunkedgraph.graph.attributes import Hierarchy +from pychunkedgraph.graph.utils import serializers +from google.cloud.bigtable.row_filters import TimestampRange + + +def exists_as_parent(cg: ChunkedGraph, parent, nodes) -> bool: + """ + Check if a given l2 parent is in the history of given nodes. + """ + response = cg.client.read_nodes(node_ids=nodes, properties=Hierarchy.Parent) + parents = set() + for cells in response.values(): + parents.update([cell.value for cell in cells]) + return parent in parents + + +def get_edit_timestamps(cg: ChunkedGraph, edges_d, start_ts, end_ts) -> list: + """ + Timestamps of when post-side nodes were involved in an edit. + Post-side - nodes in the neighbor chunk. + This is required because we need to update edges from both sides. + """ + cx_edges = np.concatenate(list(edges_d.values())) + timestamps = get_parent_timestamps( + cg, cx_edges[:, 1], start_time=start_ts, end_time=end_ts + ) + timestamps.add(start_ts) + return sorted(timestamps) + + +def _get_end_timestamps_helper(cg: ChunkedGraph, nodes: list) -> defaultdict[int, set]: + result = defaultdict(set) + response = cg.client.read_nodes(node_ids=nodes, properties=Hierarchy.StaleTimeStamp) + for k, v in response.items(): + result[k].add(v[0].timestamp) + return result + + +def get_end_timestamps( + cg: ChunkedGraph, nodes: list, nodes_ts: datetime, children_map: dict, layer: int +): + """ + Gets the last timestamp for each node at which to update its cross edges. + For layer 2: + Get parent timestamps for all children of a node. + The first timestamp > node_timestamp among these is the last timestamp. + This is the timestamp at which one of node's children + got a new parent that superseded the current node. + These are cached in database. + For all nodes in each layer > 2: + Pick the earliest child node_end_ts > node_ts and cache in database. + """ + result = [] + children = np.concatenate([*children_map.values()]) + if layer == 2: + timestamps_d = get_parent_timestamps(cg, children) + else: + timestamps_d = _get_end_timestamps_helper(cg, children) + + for node, node_ts in zip(nodes, nodes_ts): + node_children = children_map[node] + _children_timestamps = [] + for k in node_children: + if k in timestamps_d: + _children_timestamps.append(timestamps_d[k]) + _timestamps = set().union(*_children_timestamps) + _timestamps.add(node_ts) + try: + _timestamps = sorted(_timestamps) + _index = np.searchsorted(_timestamps, node_ts) + end_ts = _timestamps[_index + 1] + except IndexError: + # this node has not been edited, but might have it edges updated + end_ts = None + result.append(end_ts) + return result + + +def get_parent_timestamps( + cg: ChunkedGraph, nodes, start_time=None, end_time=None +) -> defaultdict[int, set]: + """ + Timestamps of when the given nodes were edited. + """ + earliest_ts = cg.get_earliest_timestamp() + response = cg.client.read_nodes( + node_ids=nodes, + properties=[Hierarchy.Parent], + start_time=start_time, + end_time=end_time, + end_time_inclusive=False, + ) + + result = defaultdict(set) + for k, v in response.items(): + for cell in v[Hierarchy.Parent]: + ts = cell.timestamp + result[k].add(earliest_ts if ts < earliest_ts else ts) + return result + + +def fix_corrupt_nodes(cg: ChunkedGraph, nodes: list, children_d: dict): + """ + For each node: delete it from parent column of its children. + Then deletes the node itself, effectively erasing it from hierarchy. + """ + table = cg.client._table + batcher = table.mutations_batcher(flush_count=500) + for node in nodes: + children = children_d[node] + _map = cg.client.read_nodes(node_ids=children, properties=Hierarchy.Parent) + + for child, parent_cells in _map.items(): + row = table.direct_row(serializers.serialize_uint64(child)) + for cell in parent_cells: + if cell.value == node: + start = cell.timestamp + end = start + timedelta(microseconds=1) + row.delete_cell( + column_family_id=Hierarchy.Parent.family_id, + column=Hierarchy.Parent.key, + time_range=TimestampRange(start=start, end=end), + ) + batcher.mutate(row) + + row = table.direct_row(serializers.serialize_uint64(node)) + row.delete() + batcher.mutate(row) + + batcher.flush() diff --git a/pychunkedgraph/ingest/utils.py b/pychunkedgraph/ingest/utils.py index fa7ef7a3c..83d2716d8 100644 --- a/pychunkedgraph/ingest/utils.py +++ b/pychunkedgraph/ingest/utils.py @@ -1,14 +1,25 @@ -from typing import Tuple +# pylint: disable=invalid-name, missing-docstring +import logging +import functools +import math, random, sys +from os import environ +from time import sleep +from typing import Any, Generator, Tuple -from . import ClusterIngestConfig -from . import IngestConfig -from ..graph.meta import ChunkedGraphMeta -from ..graph.meta import DataSource -from ..graph.meta import GraphConfig +import numpy as np +import tensorstore as ts +from rq import Queue, Retry, Worker +from rq.worker import WorkerStatus +from . import IngestConfig +from .manager import IngestionManager +from ..graph.meta import ChunkedGraphMeta, DataSource, GraphConfig from ..graph.client import BackendClientInfo from ..graph.client.bigtable import BigTableConfig +from ..utils.general import chunked +from ..utils.redis import get_redis_connection +from ..utils.redis import keys as r_keys chunk_id_str = lambda layer, coords: f"{layer}_{'_'.join(map(str, coords))}" @@ -16,14 +27,12 @@ def bootstrap( graph_id: str, config: dict, - overwrite: bool = False, raw: bool = False, test_run: bool = False, ) -> Tuple[ChunkedGraphMeta, IngestConfig, BackendClientInfo]: """Parse config loaded from a yaml file.""" ingest_config = IngestConfig( **config.get("ingest_config", {}), - CLUSTER=ClusterIngestConfig(), USE_RAW_EDGES=raw, USE_RAW_COMPONENTS=raw, TEST_RUN=test_run, @@ -33,7 +42,7 @@ def bootstrap( graph_config = GraphConfig( ID=f"{graph_id}", - OVERWRITE=overwrite, + OVERWRITE=False, **config["graph_config"], ) data_source = DataSource(**config["data_source"]) @@ -42,6 +51,10 @@ def bootstrap( return (meta, ingest_config, client_info) +def move_up(lines: int = 1): + sys.stdout.write(f"\033[{lines}A") + + def postprocess_edge_data(im, edge_dict): data_version = im.cg_meta.data_source.DATA_VERSION if data_version == 2: @@ -72,4 +85,196 @@ def postprocess_edge_data(im, edge_dict): return new_edge_dict else: - raise Exception(f"Unknown data_version: {data_version}") + raise ValueError(f"Unknown data_version: {data_version}") + + +def start_ocdbt_server(imanager: IngestionManager, server: Any): + spec = {"driver": "ocdbt", "base": f"{imanager.cg.meta.data_source.EDGES}/ocdbt"} + spec["coordinator"] = {"address": f"localhost:{server.port}"} + ts.KvStore.open(spec).result() + imanager.redis.set("OCDBT_COORDINATOR_PORT", str(server.port)) + ocdbt_host = environ.get("MY_POD_IP", "localhost") + imanager.redis.set("OCDBT_COORDINATOR_HOST", ocdbt_host) + logging.info(f"OCDBT Coordinator address {ocdbt_host}:{server.port}") + + +def randomize_grid_points(X: int, Y: int, Z: int) -> Generator[int, int, int]: + indices = np.arange(X * Y * Z) + np.random.shuffle(indices) + for index in indices: + yield np.unravel_index(index, (X, Y, Z)) + + +def get_chunks_not_done( + imanager: IngestionManager, layer: int, coords: list, splits: int = 0 +) -> list: + """check for set membership in redis in batches""" + coords_strs = [] + if splits > 0: + split_coords = [] + for coord in coords: + for split in range(splits): + jid = "_".join(map(str, coord)) + f"_{split}" + coords_strs.append(jid) + split_coords.append((coord, split)) + else: + coords_strs = ["_".join(map(str, coord)) for coord in coords] + try: + completed = imanager.redis.smismember(f"{layer}c", coords_strs) + except Exception: + return split_coords if splits > 0 else coords + + if splits > 0: + return [coord for coord, c in zip(split_coords, completed) if not c] + return [coord for coord, c in zip(coords, completed) if not c] + + +def print_completion_rate(imanager: IngestionManager, layer: int, span: int = 30): + rate = 0.0 + while True: + counts = [] + print(f"{rate} chunks per second.") + for _ in range(span + 1): + counts.append(imanager.redis.scard(f"{layer}c")) + sleep(1) + rate = np.diff(counts).sum() / span + move_up() + + +def print_status(imanager: IngestionManager, redis, upgrade: bool = False): + """ + Helper to print status to console. + If `upgrade=True`, status does not include the root layer, + since there is no need to update cross edges for root ids. + """ + layers = range(2, imanager.cg_meta.layer_count + 1) + if upgrade: + layers = range(2, imanager.cg_meta.layer_count) + + def _refresh_status(): + pipeline = redis.pipeline() + pipeline.get(r_keys.JOB_TYPE) + worker_busy = ["-"] * len(layers) + for layer in layers: + pipeline.scard(f"{layer}c") + queue = Queue(f"l{layer}", connection=redis) + pipeline.llen(queue.key) + pipeline.zcard(queue.failed_job_registry.key) + + results = pipeline.execute() + job_type = "not_available" + if results[0] is not None: + job_type = results[0].decode() + completed = [] + queued = [] + failed = [] + for i in range(1, len(results), 3): + result = results[i : i + 3] + completed.append(result[0]) + queued.append(result[1]) + failed.append(result[2]) + return job_type, completed, queued, failed, worker_busy + + job_type, completed, queued, failed, worker_busy = _refresh_status() + + layer_counts = imanager.cg_meta.layer_chunk_counts + header = ( + f"\njob_type: \t{job_type}" + f"\nversion: \t{imanager.cg.version}" + f"\ngraph_id: \t{imanager.cg.graph_id}" + f"\nchunk_size: \t{imanager.cg.meta.graph_config.CHUNK_SIZE}" + "\n\nlayer status:" + ) + print(header) + while True: + for layer, done, count in zip(layers, completed, layer_counts): + print( + f"{layer}\t| {done:9} / {count} \t| {math.floor((done/count)*100):6}%" + ) + + print("\n\nqueue status:") + for layer, q, f, wb in zip(layers, queued, failed, worker_busy): + print(f"l{layer}\t| queued: {q:<10} failed: {f:<10} busy: {wb}") + + sleep(1) + _, completed, queued, failed, worker_busy = _refresh_status() + move_up(lines=2 * len(layers) + 3) + + +def queue_layer_helper( + parent_layer: int, imanager: IngestionManager, fn, splits: int = 0 +): + if parent_layer == imanager.cg_meta.layer_count: + chunk_coords = [(0, 0, 0)] + else: + bounds = imanager.cg_meta.layer_chunk_bounds[parent_layer] + chunk_coords = randomize_grid_points(*bounds) + + q = imanager.get_task_queue(f"l{parent_layer}") + batch_size = int(environ.get("JOB_BATCH_SIZE", 10000)) + timeout_scale = int(environ.get("TIMEOUT_SCALE_FACTOR", 1)) + batches = chunked(chunk_coords, batch_size) + failure_ttl = int(environ.get("FAILURE_TTL", 300)) + for batch in batches: + _coords = get_chunks_not_done(imanager, parent_layer, batch, splits=splits) + # buffer for optimal use of redis memory + if len(q) > int(environ.get("QUEUE_SIZE", 100000)): + interval = int(environ.get("QUEUE_INTERVAL", 300)) + logging.info(f"Queue full; sleeping {interval}s...") + sleep(interval) + + job_datas = [] + retry = int(environ.get("RETRY_COUNT", 0)) + for chunk_coord in _coords: + if splits > 0: + coord, split = chunk_coord + jid = chunk_id_str(parent_layer, coord) + f"_{split}" + job_datas.append( + Queue.prepare_data( + fn, + args=(parent_layer, coord, split, splits), + result_ttl=0, + job_id=jid, + timeout=f"{timeout_scale * int(parent_layer * parent_layer)}m", + retry=Retry(retry) if retry > 1 else None, + description="", + failure_ttl=failure_ttl, + ) + ) + else: + job_datas.append( + Queue.prepare_data( + fn, + args=(parent_layer, chunk_coord), + result_ttl=0, + job_id=chunk_id_str(parent_layer, chunk_coord), + timeout=f"{timeout_scale * int(parent_layer * parent_layer)}m", + retry=Retry(retry) if retry > 1 else None, + description="", + failure_ttl=failure_ttl, + ) + ) + q.enqueue_many(job_datas) + logging.info(f"Queued {len(job_datas)} chunks.") + + +def job_type_guard(job_type: str): + def decorator_job_type_guard(func): + @functools.wraps(func) + def wrapper_job_type_guard(*args, **kwargs): + redis = get_redis_connection() + current_type = redis.get(r_keys.JOB_TYPE) + if current_type is not None: + current_type = current_type.decode() + msg = ( + f"Currently running `{current_type}`. You're attempting to run `{job_type}`." + f"\nRun `[flask] {current_type} flush_redis` to clear the current job and restart." + ) + if current_type != job_type: + print(f"\n*WARNING*\n{msg}") + exit(1) + return func(*args, **kwargs) + + return wrapper_job_type_guard + + return decorator_job_type_guard diff --git a/pychunkedgraph/logging/log_db.py b/pychunkedgraph/logging/log_db.py index 89680500a..4a4244022 100644 --- a/pychunkedgraph/logging/log_db.py +++ b/pychunkedgraph/logging/log_db.py @@ -4,7 +4,7 @@ import threading import time import queue -from datetime import datetime +from datetime import datetime, timezone from google.api_core.exceptions import GoogleAPIError from datastoreflex import DatastoreFlex @@ -109,7 +109,7 @@ def __init__(self, name: str, graph_id: str, operation_id=-1, **kwargs): self.names.append(name) self._start = None self._graph_id = graph_id - self._ts = datetime.utcnow() + self._ts = datetime.now(timezone.utc) self._kwargs = kwargs if operation_id != -1: self.operation_id = operation_id diff --git a/pychunkedgraph/meshing/manifest/utils.py b/pychunkedgraph/meshing/manifest/utils.py index 67e600653..90963570c 100644 --- a/pychunkedgraph/meshing/manifest/utils.py +++ b/pychunkedgraph/meshing/manifest/utils.py @@ -40,7 +40,7 @@ def _get_children(cg, node_ids: Sequence[np.uint64], children_cache: Dict): if len(node_ids) == 0: return empty_1d.copy() node_ids = np.array(node_ids, dtype=NODE_ID) - mask = np.in1d(node_ids, np.fromiter(children_cache.keys(), dtype=NODE_ID)) + mask = np.isin(node_ids, np.fromiter(children_cache.keys(), dtype=NODE_ID)) children_d = cg.get_children(node_ids[~mask]) children_cache.update(children_d) diff --git a/pychunkedgraph/meshing/mesh_analysis.py b/pychunkedgraph/meshing/mesh_analysis.py index 97bb28f5b..abdf95957 100644 --- a/pychunkedgraph/meshing/mesh_analysis.py +++ b/pychunkedgraph/meshing/mesh_analysis.py @@ -16,10 +16,10 @@ def compute_centroid_with_chunk_boundary(cg, vertices, l2_id, last_l2_id): a path, return the center point of the mesh on the chunk boundary separating the two ids, and the center point of the entire mesh. :param cg: ChunkedGraph object - :param vertices: [[np.float]] + :param vertices: [[np.float64]] :param l2_id: np.uint64 :param last_l2_id: np.uint64 or None - :return: [np.float] + :return: [np.float64] """ centroid_by_range = compute_centroid_by_range(vertices) if last_l2_id is None: diff --git a/pychunkedgraph/meshing/mesh_io.py b/pychunkedgraph/meshing/mesh_io.py index 40c02bba0..1cf1fed66 100644 --- a/pychunkedgraph/meshing/mesh_io.py +++ b/pychunkedgraph/meshing/mesh_io.py @@ -168,8 +168,8 @@ def load_obj(self): faces.append(face) self._faces = np.array(faces, dtype=int) - 1 - self._vertices = np.array(vertices, dtype=np.float) - self._normals = np.array(normals, dtype=np.float) + self._vertices = np.array(vertices, dtype=np.float64) + self._normals = np.array(normals, dtype=np.float64) def load_h5(self): with h5py.File(self.filename, "r") as f: diff --git a/pychunkedgraph/meshing/meshengine.py b/pychunkedgraph/meshing/meshengine.py index 615e6cdb6..e852dfa3a 100644 --- a/pychunkedgraph/meshing/meshengine.py +++ b/pychunkedgraph/meshing/meshengine.py @@ -126,14 +126,14 @@ def mesh_single_layer(self, layer, bounding_box=None, block_factor=2, block_bounding_box_cg /= 2 ** np.max([0, layer - 2]) block_bounding_box_cg = np.ceil(block_bounding_box_cg) - n_jobs = np.product(block_bounding_box_cg[1] - + n_jobs = np.prod(block_bounding_box_cg[1] - block_bounding_box_cg[0]) / \ block_factor ** 2 < n_threads while n_jobs < n_threads and block_factor > 1: block_factor -= 1 - n_jobs = np.product(block_bounding_box_cg[1] - + n_jobs = np.prod(block_bounding_box_cg[1] - block_bounding_box_cg[0]) / \ block_factor ** 2 < n_threads diff --git a/pychunkedgraph/meshing/meshgen.py b/pychunkedgraph/meshing/meshgen.py index a8da89b1f..80f75bffd 100644 --- a/pychunkedgraph/meshing/meshgen.py +++ b/pychunkedgraph/meshing/meshgen.py @@ -1,1367 +1,46 @@ -# pylint: disable=invalid-name, missing-docstring, too-many-lines, wrong-import-order, import-outside-toplevel, no-member, c-extension-no-member - -from typing import Sequence -import os -import numpy as np -import time -import collections -from functools import lru_cache -import datetime -import pytz -from scipy import ndimage - -from multiwrapper import multiprocessing_utils as mu -from cloudfiles import CloudFiles -from cloudvolume import CloudVolume -from cloudvolume.datasource.precomputed.sharding import ShardingSpecification -import DracoPy -import zmesh -import fastremap - -from pychunkedgraph.graph.chunkedgraph import ChunkedGraph # noqa -from pychunkedgraph.graph import attributes # noqa -from pychunkedgraph.meshing import meshgen_utils # noqa -from pychunkedgraph.meshing.manifest.cache import ManifestCache - - -UTC = pytz.UTC - -# Change below to true if debugging and want to see results in stdout -PRINT_FOR_DEBUGGING = False -# Change below to false if debugging and do not need to write to cloud (warning: do not deploy w/ below set to false) -WRITING_TO_CLOUD = True - -REDIS_HOST = os.environ.get("REDIS_SERVICE_HOST", "localhost") -REDIS_PORT = os.environ.get("REDIS_SERVICE_PORT", "6379") -REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "dev") -REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0" - - -def decode_draco_mesh_buffer(fragment): - try: - mesh_object = DracoPy.decode_buffer_to_mesh(fragment) - vertices = np.array(mesh_object.points) - faces = np.array(mesh_object.faces) - except ValueError as exc: - raise ValueError("Not a valid draco mesh") from exc - - num_vertices = len(vertices) - - # For now, just return this dict until we figure out - # how exactly to deal with Draco's lossiness/duplicate vertices - return { - "num_vertices": num_vertices, - "vertices": vertices, - "faces": faces, - "encoding_options": mesh_object.encoding_options, - "encoding_type": "draco", - } - - -def remap_seg_using_unsafe_dict(seg, unsafe_dict): - for unsafe_root_id in unsafe_dict.keys(): - bin_seg = seg == unsafe_root_id - - if np.sum(bin_seg) == 0: - continue - - cc_seg, n_cc = ndimage.label(bin_seg) - for i_cc in range(1, n_cc + 1): - bin_cc_seg = cc_seg == i_cc - - overlaps = [] - overlaps.extend(np.unique(seg[-2, :, :][bin_cc_seg[-1, :, :]])) - overlaps.extend(np.unique(seg[:, -2, :][bin_cc_seg[:, -1, :]])) - overlaps.extend(np.unique(seg[:, :, -2][bin_cc_seg[:, :, -1]])) - overlaps = np.unique(overlaps) - - linked_l2_ids = overlaps[np.in1d(overlaps, unsafe_dict[unsafe_root_id])] - - if len(linked_l2_ids) == 0: - seg[bin_cc_seg] = 0 - else: - seg[bin_cc_seg] = linked_l2_ids[0] - - return seg - - -def get_remapped_segmentation( - cg, chunk_id, mip=2, overlap_vx=1, time_stamp=None, n_threads=1 -): - """Downloads + remaps ws segmentation + resolve unclear cases - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param mip: int - :param overlap_vx: int - :param time_stamp: - :return: remapped segmentation - """ - assert mip >= cg.meta.cv.mip - - sv_remapping, unsafe_dict = get_lx_overlapping_remappings( - cg, chunk_id, time_stamp=time_stamp, n_threads=n_threads - ) - - ws_seg = meshgen_utils.get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx) - seg = fastremap.mask_except(ws_seg, list(sv_remapping.keys()), in_place=False) - fastremap.remap(seg, sv_remapping, preserve_missing_labels=True, in_place=True) - - return remap_seg_using_unsafe_dict(seg, unsafe_dict) - - -def get_remapped_seg_for_lvl2_nodes( - cg, - chunk_id: np.uint64, - lvl2_nodes: Sequence[np.uint64], - mip: int = 2, - overlap_vx: int = 1, - time_stamp=None, - n_threads: int = 1, -): - """Downloads + remaps ws segmentation + resolve unclear cases, - filter out all but specified lvl2_nodes - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param mip: int - :param overlap_vx: int - :param time_stamp: - :return: remapped segmentation - """ - seg = meshgen_utils.get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx) - sv_of_lvl2_nodes = cg.get_children(lvl2_nodes) - - # Check which of the lvl2_nodes meet the chunk boundary - node_ids_on_the_border = [] - remapping = {} - for node, sv_list in sv_of_lvl2_nodes.items(): - node_on_the_border = False - for sv_id in sv_list: - remapping[sv_id] = node - # If a node_id is on the chunk_boundary, we must check - # the overlap region to see if the meshes' end will be open or closed - if (not node_on_the_border) and ( - np.isin(sv_id, seg[-2, :, :]) - or np.isin(sv_id, seg[:, -2, :]) - or np.isin(sv_id, seg[:, :, -2]) - ): - node_on_the_border = True - node_ids_on_the_border.append(node) - - node_ids_on_the_border = np.array(node_ids_on_the_border) - if len(node_ids_on_the_border) > 0: - overlap_region = np.concatenate( - (seg[:, :, -1], seg[:, -1, :], seg[-1, :, :]), axis=None - ) - overlap_sv_ids = np.unique(overlap_region) - if overlap_sv_ids[0] == 0: - overlap_sv_ids = overlap_sv_ids[1:] - # Get the remappings for the supervoxels in the overlap region - sv_remapping, unsafe_dict = get_lx_overlapping_remappings_for_nodes_and_svs( - cg, chunk_id, node_ids_on_the_border, overlap_sv_ids, time_stamp, n_threads - ) - sv_remapping.update(remapping) - fastremap.mask_except(seg, list(sv_remapping.keys()), in_place=True) - fastremap.remap(seg, sv_remapping, preserve_missing_labels=True, in_place=True) - # For some supervoxel, they could map to multiple l2 nodes in the chunk, - # so we must perform a connected component analysis - # to see which l2 node they are adjacent to - return remap_seg_using_unsafe_dict(seg, unsafe_dict) - else: - # If no nodes in our subset meet the chunk boundary - # we can simply retrieve the sv of the nodes in the subset - fastremap.mask_except(seg, list(remapping.keys()), in_place=True) - fastremap.remap(seg, remapping, preserve_missing_labels=True, in_place=True) - - return seg - - -@lru_cache(maxsize=None) -def get_higher_to_lower_remapping(cg, chunk_id, time_stamp): - """Retrieves lx node id to sv id mappping - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param time_stamp: datetime object - :return: dictionary - """ - - def _lower_remaps(ks): - return np.concatenate([lower_remaps[k] for k in ks]) - - assert cg.get_chunk_layer(chunk_id) >= 2 - assert cg.get_chunk_layer(chunk_id) <= cg.meta.layer_count - - print(f"\n{chunk_id} ----------------\n") - - lower_remaps = {} - if cg.get_chunk_layer(chunk_id) > 2: - for lower_chunk_id in cg.get_chunk_child_ids(chunk_id): - # TODO speedup - lower_remaps.update( - get_higher_to_lower_remapping(cg, lower_chunk_id, time_stamp=time_stamp) - ) - - rr_chunk = cg.range_read_chunk( - chunk_id=chunk_id, properties=attributes.Hierarchy.Child, time_stamp=time_stamp - ) - - # This for-loop ensures that only the latest lx_ids are considered - # The order by id guarantees the time order (only true for same neurons - # but that is the case here). - lx_remapping = {} - all_lower_ids = set() - for k in sorted(rr_chunk.keys(), reverse=True): - this_child_ids = rr_chunk[k][0].value - if this_child_ids[0] in all_lower_ids: - continue - - all_lower_ids = all_lower_ids.union(set(list(this_child_ids))) - - if cg.get_chunk_layer(chunk_id) > 2: - try: - lx_remapping[k] = _lower_remaps(this_child_ids) - except KeyError: - # KeyErrors indicate that this id is deprecated given the - # time_stamp - continue - else: - lx_remapping[k] = this_child_ids - - return lx_remapping - - -@lru_cache(maxsize=None) -def get_root_lx_remapping(cg, chunk_id, stop_layer, time_stamp, n_threads=1): - """Retrieves root to l2 node id mapping - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param stop_layer: int - :param time_stamp: datetime object - :return: multiples - """ - - def _get_root_ids(args): - start_id, end_id = args - root_ids[start_id:end_id] = cg.get_roots( - lx_ids[start_id:end_id], - stop_layer=stop_layer, - fail_to_zero=True, - ) - - lx_id_remap = get_higher_to_lower_remapping(cg, chunk_id, time_stamp=time_stamp) - - lx_ids = np.array(list(lx_id_remap.keys())) - - root_ids = np.zeros(len(lx_ids), dtype=np.uint64) - n_jobs = np.min([n_threads, len(lx_ids)]) - multi_args = [] - start_ids = np.linspace(0, len(lx_ids), n_jobs + 1).astype(int) - for i_block in range(n_jobs): - multi_args.append([start_ids[i_block], start_ids[i_block + 1]]) - - if n_jobs > 0: - mu.multithread_func(_get_root_ids, multi_args, n_threads=n_threads) - - return lx_ids, np.array(root_ids), lx_id_remap - - -def calculate_stop_layer(cg, chunk_id): - chunk_coords = cg.get_chunk_coordinates(chunk_id) - chunk_layer = cg.get_chunk_layer(chunk_id) - - neigh_chunk_ids = [] - neigh_parent_chunk_ids = [] - - # Collect neighboring chunks and their parent chunk ids - # We only need to know about the parent chunk ids to figure the lowest - # common chunk - # Notice that the first neigh_chunk_id is equal to `chunk_id`. - for x in range(chunk_coords[0], chunk_coords[0] + 2): - for y in range(chunk_coords[1], chunk_coords[1] + 2): - for z in range(chunk_coords[2], chunk_coords[2] + 2): - # Chunk id - try: - neigh_chunk_id = cg.get_chunk_id(x=x, y=y, z=z, layer=chunk_layer) - # Get parent chunk ids - parent_chunk_ids = cg.get_parent_chunk_ids(neigh_chunk_id) - neigh_chunk_ids.append(neigh_chunk_id) - neigh_parent_chunk_ids.append(parent_chunk_ids) - except: - # cg.get_parent_chunk_id can fail if neigh_chunk_id is outside the dataset - # (only happens when cg.meta.bitmasks[chunk_layer+1] == log(max(x,y,z)), - # so only for specific datasets in which the # of chunks in the widest dimension - # just happens to be a power of two) - pass - - # Find lowest common chunk - neigh_parent_chunk_ids = np.array(neigh_parent_chunk_ids) - layer_agreement = np.all( - (neigh_parent_chunk_ids - neigh_parent_chunk_ids[0]) == 0, axis=0 - ) - stop_layer = np.where(layer_agreement)[0][0] + chunk_layer - - return stop_layer, neigh_chunk_ids - - -# @lru_cache(maxsize=None) -def get_lx_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1): - """Retrieves sv id to layer mapping for chunk with overlap in positive - direction (one chunk) - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param time_stamp: datetime object - :return: multiples - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - stop_layer, neigh_chunk_ids = calculate_stop_layer(cg, chunk_id) - print(f"Stop layer: {stop_layer}") - - # Find the parent in the lowest common chunk for each l2 id. These parent - # ids are referred to as root ids even though they are not necessarily the - # root id. - neigh_lx_ids = [] - neigh_lx_id_remap = {} - neigh_root_ids = [] - - safe_lx_ids = [] - unsafe_lx_ids = [] - unsafe_root_ids = [] - - # This loop is the main bottleneck - for neigh_chunk_id in neigh_chunk_ids: - print(f"Neigh: {neigh_chunk_id} --------------") - - lx_ids, root_ids, lx_id_remap = get_root_lx_remapping( - cg, neigh_chunk_id, stop_layer, time_stamp=time_stamp, n_threads=n_threads - ) - neigh_lx_ids.extend(lx_ids) - neigh_lx_id_remap.update(lx_id_remap) - neigh_root_ids.extend(root_ids) - - if neigh_chunk_id == chunk_id: - # The first neigh_chunk_id is the one we are interested in. All lx - # ids that share no root id with any other lx id are "safe", meaning - # that we can easily obtain the complete remapping (including - # overlap) for these. All other ones have to be resolved using the - # segmentation. - _, u_idx, c_root_ids = np.unique( - neigh_root_ids, return_counts=True, return_index=True - ) - - safe_lx_ids = lx_ids[u_idx[c_root_ids == 1]] - unsafe_lx_ids = lx_ids[~np.in1d(lx_ids, safe_lx_ids)] - unsafe_root_ids = np.unique(root_ids[u_idx[c_root_ids != 1]]) - - lx_root_dict = dict(zip(neigh_lx_ids, neigh_root_ids)) - root_lx_dict = collections.defaultdict(list) - - # Future sv id -> lx mapping - sv_ids = [] - lx_ids_flat = [] - - for i_root_id in range(len(neigh_root_ids)): - root_lx_dict[neigh_root_ids[i_root_id]].append(neigh_lx_ids[i_root_id]) - - # Do safe ones first - for lx_id in safe_lx_ids: - root_id = lx_root_dict[lx_id] - for neigh_lx_id in root_lx_dict[root_id]: - lx_sv_ids = neigh_lx_id_remap[neigh_lx_id] - sv_ids.extend(lx_sv_ids) - lx_ids_flat.extend([lx_id] * len(neigh_lx_id_remap[neigh_lx_id])) - - # For the unsafe ones we can only do the in chunk svs - # But we will map the out of chunk svs to the root id and store the - # hierarchical information in a dictionary - for lx_id in unsafe_lx_ids: - sv_ids.extend(neigh_lx_id_remap[lx_id]) - lx_ids_flat.extend([lx_id] * len(neigh_lx_id_remap[lx_id])) - - unsafe_dict = collections.defaultdict(list) - for root_id in unsafe_root_ids: - if np.sum(~np.in1d(root_lx_dict[root_id], unsafe_lx_ids)) == 0: - continue - - for neigh_lx_id in root_lx_dict[root_id]: - unsafe_dict[root_id].append(neigh_lx_id) - - if neigh_lx_id in unsafe_lx_ids: - continue - - sv_ids.extend(neigh_lx_id_remap[neigh_lx_id]) - lx_ids_flat.extend([root_id] * len(neigh_lx_id_remap[neigh_lx_id])) - - # Combine the lists for a (chunk-) global remapping - sv_remapping = dict(zip(sv_ids, lx_ids_flat)) - - return sv_remapping, unsafe_dict - - -def get_root_remapping_for_nodes_and_svs( - cg, chunk_id, node_ids, sv_ids, stop_layer, time_stamp, n_threads=1 -): - """Retrieves root to node id mapping for specified node ids and supervoxel ids - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param node_ids: [np.uint64] - :param stop_layer: int - :param time_stamp: datetime object - :return: multiples - """ - - def _get_root_ids(args): - start_id, end_id = args - - root_ids[start_id:end_id] = cg.get_roots( - combined_ids[start_id:end_id], - stop_layer=stop_layer, - time_stamp=time_stamp, - fail_to_zero=True, - ) - - rr = cg.range_read_chunk( - chunk_id=chunk_id, properties=attributes.Hierarchy.Child, time_stamp=time_stamp - ) - chunk_sv_ids = np.unique(np.concatenate([id[0].value for id in rr.values()])) - chunk_l2_ids = np.unique(cg.get_parents(chunk_sv_ids, time_stamp=time_stamp)) - combined_ids = np.concatenate((node_ids, sv_ids, chunk_l2_ids)) - - root_ids = np.zeros(len(combined_ids), dtype=np.uint64) - n_jobs = np.min([n_threads, len(combined_ids)]) - multi_args = [] - start_ids = np.linspace(0, len(combined_ids), n_jobs + 1).astype(int) - for i_block in range(n_jobs): - multi_args.append([start_ids[i_block], start_ids[i_block + 1]]) - - if n_jobs > 0: - mu.multithread_func(_get_root_ids, multi_args, n_threads=n_threads) - - sv_ids_index = len(node_ids) - chunk_ids_index = len(node_ids) + len(sv_ids) - - return ( - root_ids[0:sv_ids_index], - root_ids[sv_ids_index:chunk_ids_index], - root_ids[chunk_ids_index:], - ) - - -def get_lx_overlapping_remappings_for_nodes_and_svs( - cg, - chunk_id: np.uint64, - node_ids: Sequence[np.uint64], - sv_ids: Sequence[np.uint64], - time_stamp=None, - n_threads: int = 1, -): - """Retrieves sv id to layer mapping for chunk with overlap in positive - direction (one chunk) - - :param cg: chunkedgraph object - :param chunk_id: np.uint64 - :param node_ids: list of np.uint64 - :param sv_ids: list of np.uint64 - :param time_stamp: datetime object - :param n_threads: int - :return: multiples - """ - if time_stamp is None: - time_stamp = datetime.datetime.utcnow() - if time_stamp.tzinfo is None: - time_stamp = UTC.localize(time_stamp) - - stop_layer, _ = calculate_stop_layer(cg, chunk_id) - print(f"Stop layer: {stop_layer}") - - # Find the parent in the lowest common chunk for each node id and sv id. These parent - # ids are referred to as root ids even though they are not necessarily the - # root id. - node_root_ids, sv_root_ids, chunks_root_ids = get_root_remapping_for_nodes_and_svs( - cg, chunk_id, node_ids, sv_ids, stop_layer, time_stamp, n_threads - ) - - u_root_ids, u_idx, c_root_ids = np.unique( - chunks_root_ids, return_counts=True, return_index=True - ) - - # All l2 ids that share no root id with any other l2 id in the chunk are "safe", meaning - # that we can easily obtain the complete remapping (including - # overlap) for these. All other ones have to be resolved using the - # segmentation. - - root_sorted_idx = np.argsort(u_root_ids) - node_sorted_index = np.searchsorted(u_root_ids[root_sorted_idx], node_root_ids) - node_root_counts = c_root_ids[root_sorted_idx][node_sorted_index] - unsafe_root_ids = node_root_ids[np.where(node_root_counts > 1)] - safe_node_ids = node_ids[~np.isin(node_root_ids, unsafe_root_ids)] - - node_to_root_dict = dict(zip(node_ids, node_root_ids)) - - # Future sv id -> lx mapping - sv_ids_to_remap = [] - node_ids_flat = [] - - # Do safe ones first - for node_id in safe_node_ids: - root_id = node_to_root_dict[node_id] - sv_ids_to_add = sv_ids[np.where(sv_root_ids == root_id)] - if len(sv_ids_to_add) > 0: - sv_ids_to_remap.extend(sv_ids_to_add) - node_ids_flat.extend([node_id] * len(sv_ids_to_add)) - - # For the unsafe roots, we will map the out of chunk svs to the root id and store the - # hierarchical information in a dictionary - unsafe_dict = collections.defaultdict(list) - for root_id in unsafe_root_ids: - sv_ids_to_add = sv_ids[np.where(sv_root_ids == root_id)] - if len(sv_ids_to_add) > 0: - relevant_node_ids = node_ids[np.where(node_root_ids == root_id)] - if len(relevant_node_ids) > 0: - unsafe_dict[root_id].extend(relevant_node_ids) - sv_ids_to_remap.extend(sv_ids_to_add) - node_ids_flat.extend([root_id] * len(sv_ids_to_add)) - - # Combine the lists for a (chunk-) global remapping - sv_remapping = dict(zip(sv_ids_to_remap, node_ids_flat)) - - return sv_remapping, unsafe_dict - - -def get_meshing_necessities_from_graph(cg, chunk_id: np.uint64, mip: int): - """Given a chunkedgraph, chunk_id, and mip level, return the voxel dimensions of the chunk to be meshed (mesh_block_shape) - and the chunk origin in the dataset in nm. - - :param cg: chunkedgraph instance - :param chunk_id: uint64 - :param mip: int - """ - layer = cg.get_chunk_layer(chunk_id) - cx, cy, cz = cg.get_chunk_coordinates(chunk_id) - mesh_block_shape = meshgen_utils.get_mesh_block_shape_for_mip(cg, layer, mip) - voxel_resolution = cg.meta.cv.mip_resolution(mip) - chunk_offset = ( - (cx, cy, cz) * mesh_block_shape + cg.meta.cv.mip_voxel_offset(mip) - ) * voxel_resolution - return layer, mesh_block_shape, chunk_offset - - -def calculate_quantization_bits_and_range( - min_quantization_range, max_draco_bin_size, draco_quantization_bits=None -): - if draco_quantization_bits is None: - draco_quantization_bits = np.ceil( - np.log2(min_quantization_range / max_draco_bin_size + 1) - ) - num_draco_bins = 2**draco_quantization_bits - 1 - draco_bin_size = np.ceil(min_quantization_range / num_draco_bins) - draco_quantization_range = draco_bin_size * num_draco_bins - if draco_quantization_range < min_quantization_range + draco_bin_size: - if draco_bin_size == max_draco_bin_size: - return calculate_quantization_bits_and_range( - min_quantization_range, max_draco_bin_size, draco_quantization_bits + 1 - ) - else: - draco_bin_size = draco_bin_size + 1 - draco_quantization_range = draco_quantization_range + num_draco_bins - return draco_quantization_bits, draco_quantization_range, draco_bin_size - - -def get_draco_encoding_settings_for_chunk( - cg, chunk_id: np.uint64, mip: int = 2, high_padding: int = 1 -): - """Calculate the proper draco encoding settings for a chunk to ensure proper stitching is possible - on the layer above. For details about how and why we do this, please see the meshing Readme - - :param cg: chunkedgraph instance - :param chunk_id: uint64 - :param mip: int - :param high_padding: int - """ - _, mesh_block_shape, chunk_offset = get_meshing_necessities_from_graph( - cg, chunk_id, mip - ) - segmentation_resolution = cg.meta.cv.mip_resolution(mip) - min_quantization_range = max( - (mesh_block_shape + high_padding) * segmentation_resolution - ) - max_draco_bin_size = np.floor(min(segmentation_resolution) / np.sqrt(2)) - ( - draco_quantization_bits, - draco_quantization_range, - draco_bin_size, - ) = calculate_quantization_bits_and_range( - min_quantization_range, max_draco_bin_size - ) - draco_quantization_origin = chunk_offset - (chunk_offset % draco_bin_size) - return { - "quantization_bits": draco_quantization_bits, - "compression_level": 1, - "quantization_range": draco_quantization_range, - "quantization_origin": draco_quantization_origin, - "create_metadata": True, - } - - -def get_next_layer_draco_encoding_settings( - cg, prev_layer_encoding_settings, next_layer_chunk_id, mip -): - old_draco_bin_size = prev_layer_encoding_settings["quantization_range"] // ( - 2 ** prev_layer_encoding_settings["quantization_bits"] - 1 - ) - _, mesh_block_shape, chunk_offset = get_meshing_necessities_from_graph( - cg, next_layer_chunk_id, mip - ) - segmentation_resolution = cg.meta.cv.mip_resolution(mip) - min_quantization_range = ( - max(mesh_block_shape * segmentation_resolution) + 2 * old_draco_bin_size - ) - max_draco_bin_size = np.floor(min(segmentation_resolution) / np.sqrt(2)) - ( - draco_quantization_bits, - draco_quantization_range, - draco_bin_size, - ) = calculate_quantization_bits_and_range( - min_quantization_range, max_draco_bin_size - ) - draco_quantization_origin = ( - chunk_offset - - old_draco_bin_size - - ((chunk_offset - old_draco_bin_size) % draco_bin_size) - ) - return { - "quantization_bits": draco_quantization_bits, - "compression_level": 1, - "quantization_range": draco_quantization_range, - "quantization_origin": draco_quantization_origin, - "create_metadata": True, - } - - -def transform_draco_vertices(mesh, encoding_settings): - vertices = np.reshape(mesh["vertices"], (mesh["num_vertices"] * 3,)) - max_quantized_value = 2 ** encoding_settings["quantization_bits"] - 1 - draco_bin_size = encoding_settings["quantization_range"] / max_quantized_value - assert np.equal(np.mod(draco_bin_size, 1), 0) - assert np.equal(np.mod(encoding_settings["quantization_range"], 1), 0) - assert np.equal(np.mod(encoding_settings["quantization_origin"], 1), 0).all() - for coord in range(3): - vertices[coord::3] -= encoding_settings["quantization_origin"][coord] - vertices /= draco_bin_size - vertices += 0.5 - np.floor(vertices, out=vertices) - vertices *= draco_bin_size - for coord in range(3): - vertices[coord::3] += encoding_settings["quantization_origin"][coord] - - -def transform_draco_fragment_and_return_encoding_options( - cg, fragment, layer, mip, chunk_id -): - fragment_encoding_options = fragment["mesh"]["encoding_options"] - if fragment_encoding_options is None: - raise ValueError("Draco fragment has no encoding options") - cur_encoding_settings = { - "quantization_range": fragment_encoding_options.quantization_range, - "quantization_bits": fragment_encoding_options.quantization_bits, - } - node_id = fragment["node_id"] - parent_chunk_ids = cg.get_parent_chunk_ids(node_id) - fragment_layer = cg.get_chunk_layer(node_id) - if fragment_layer >= layer: - raise ValueError( - f"Node {node_id} somehow has greater or equal layer than chunk {chunk_id}" - ) - assert len(parent_chunk_ids) > layer - fragment_layer - for next_layer in range(fragment_layer + 1, layer + 1): - next_layer_chunk_id = parent_chunk_ids[next_layer - fragment_layer] - next_encoding_settings = get_next_layer_draco_encoding_settings( - cg, cur_encoding_settings, next_layer_chunk_id, mip - ) - if next_layer < layer: - transform_draco_vertices(fragment["mesh"], next_encoding_settings) - cur_encoding_settings = next_encoding_settings - return cur_encoding_settings - - -def merge_draco_meshes_across_boundaries( - cg, fragments, chunk_id, mip, high_padding, return_zmesh_object=False -): - """ - Merge a list of draco mesh fragments, removing duplicate vertices that lie - on the chunk boundary where the meshes meet. - """ - vertexct = np.zeros(len(fragments) + 1, np.uint32) - vertexct[1:] = np.cumsum([x["mesh"]["num_vertices"] for x in fragments]) - vertices = np.concatenate([x["mesh"]["vertices"] for x in fragments]) - faces = np.concatenate( - [mesh["mesh"]["faces"] + vertexct[i] for i, mesh in enumerate(fragments)] - ) - del fragments - - if vertexct[-1] > 0: - chunk_coords = cg.get_chunk_coordinates(chunk_id) - coords_bottom_corner_child_chunk = chunk_coords * 2 + 1 - child_chunk_id = cg.get_chunk_id( - None, cg.get_chunk_layer(chunk_id) - 1, *coords_bottom_corner_child_chunk - ) - _, _, child_chunk_offset = get_meshing_necessities_from_graph( - cg, child_chunk_id, mip - ) - # Get the draco encoding settings for the - # child chunk in the "bottom corner" of the chunk_id chunk - draco_encoding_settings_smaller_chunk = get_draco_encoding_settings_for_chunk( - cg, child_chunk_id, mip=mip, high_padding=high_padding - ) - draco_bin_size = draco_encoding_settings_smaller_chunk["quantization_range"] / ( - 2 ** draco_encoding_settings_smaller_chunk["quantization_bits"] - 1 - ) - # Calculate which draco bin the child chunk's boundaries - # were placed into (for each x,y,z of boundary) - chunk_boundary_bin_index = np.floor( - ( - child_chunk_offset - - draco_encoding_settings_smaller_chunk["quantization_origin"] - ) - / draco_bin_size - + np.float32(0.5) - ) - # Now we can determine where the three planes of the quantized chunk boundary are - quantized_chunk_boundary = ( - draco_encoding_settings_smaller_chunk["quantization_origin"] - + chunk_boundary_bin_index * draco_bin_size - ) - # Separate the vertices that are on the quantized chunk boundary from those that aren't - are_chunk_aligned = (vertices == quantized_chunk_boundary).any(axis=1) - vertices = np.hstack((vertices, np.arange(vertexct[-1])[:, np.newaxis])) - chunk_aligned = vertices[are_chunk_aligned] - not_chunk_aligned = vertices[~are_chunk_aligned] - del vertices - del are_chunk_aligned - faces_remapping = {} - # Those that are not simply pass through (simple remap) - if len(not_chunk_aligned) > 0: - not_chunk_aligned_remap = dict( - zip( - not_chunk_aligned[:, 3].astype(np.uint32), - np.arange(len(not_chunk_aligned), dtype=np.uint32), - ) - ) - faces_remapping.update(not_chunk_aligned_remap) - # Those that are on the boundary we remove duplicates - if len(chunk_aligned) > 0: - unique_chunk_aligned, inverse_to_chunk_aligned = np.unique( - chunk_aligned[:, 0:3], return_inverse=True, axis=0 - ) - chunk_aligned_remap = dict( - zip( - chunk_aligned[:, 3].astype(np.uint32), - np.uint32(len(not_chunk_aligned)) - + inverse_to_chunk_aligned.astype(np.uint32), - ) - ) - faces_remapping.update(chunk_aligned_remap) - vertices = np.concatenate((not_chunk_aligned[:, 0:3], unique_chunk_aligned)) - else: - vertices = not_chunk_aligned[:, 0:3] - # Remap the faces to their new vertex indices - fastremap.remap(faces, faces_remapping, in_place=True) - - if return_zmesh_object: - return zmesh.Mesh(vertices[:, 0:3], faces.reshape(-1, 3), None) - - return { - "num_vertices": np.uint32(len(vertices)), - "vertices": vertices[:, 0:3].reshape(-1), - "faces": faces, - } - - -def black_out_dust_from_segmentation(seg, dust_threshold): - """Black out (set to 0) IDs in segmentation not on the segmentation - border that have less voxels than dust_threshold - - :param seg: 3D segmentation (usually uint64) - :param dust_threshold: int - :return: - """ - seg_ids, voxel_count = np.unique(seg, return_counts=True) - boundary = np.concatenate( - ( - seg[-2, :, :], - seg[-1, :, :], - seg[:, -2, :], - seg[:, -1, :], - seg[:, :, -2], - seg[:, :, -1], - ), - axis=None, - ) - seg_ids_on_boundary = np.unique(boundary) - dust_segids = [ - sid - for sid, ct in zip(seg_ids, voxel_count) - if ct < int(dust_threshold) and np.isin(sid, seg_ids_on_boundary, invert=True) - ] - seg = fastremap.mask(seg, dust_segids, in_place=True) - - -def _get_timestamp_from_node_ids(cg, node_ids): - timestamps = cg.get_node_timestamps(node_ids, return_numpy=False) - return max(timestamps) + datetime.timedelta(milliseconds=1) - - -def remeshing( - cg, - l2_node_ids: Sequence[np.uint64], - cv_sharded_mesh_dir: str, - cv_unsharded_mesh_path: str, - stop_layer: int = None, - mip: int = 2, - max_err: int = 40, - time_stamp: datetime.datetime or None = None, -): - """Given a chunkedgraph, a list of level 2 nodes, - perform remeshing and stitching up the node hierarchy (or up to the stop_layer) - - :param cg: chunkedgraph instance - :param l2_node_ids: list of uint64 - :param stop_layer: int - :param cv_path: str - :param cv_mesh_dir: str - :param mip: int - :param max_err: int - :return: - """ - l2_chunk_dict = collections.defaultdict(set) - # Find the chunk_ids of the l2_node_ids - - def add_nodes_to_l2_chunk_dict(ids): - for node_id in ids: - chunk_id = cg.get_chunk_id(node_id) - l2_chunk_dict[chunk_id].add(node_id) - - add_nodes_to_l2_chunk_dict(l2_node_ids) - for chunk_id, node_ids in l2_chunk_dict.items(): - if PRINT_FOR_DEBUGGING: - print("remeshing", chunk_id, node_ids) - try: - l2_time_stamp = _get_timestamp_from_node_ids(cg, node_ids) - except ValueError: - # ignore bad/invalid messages - return - # Remesh the l2_node_ids - chunk_initial_mesh_task( - None, - chunk_id, - mip=mip, - node_id_subset=node_ids, - cg=cg, - cv_unsharded_mesh_path=cv_unsharded_mesh_path, - max_err=max_err, - sharded=False, - time_stamp=l2_time_stamp, - ) - chunk_dicts = [] - max_layer = stop_layer or cg._n_layers - for layer in range(3, max_layer + 1): - chunk_dicts.append(collections.defaultdict(set)) - cur_chunk_dict = l2_chunk_dict - # Find the parents of each l2_node_id up to the stop_layer, - # as well as their associated chunk_ids - for layer in range(3, max_layer + 1): - for _, node_ids in cur_chunk_dict.items(): - parent_nodes = cg.get_parents(node_ids, time_stamp=time_stamp) - for parent_node in parent_nodes: - chunk_layer = cg.get_chunk_layer(parent_node) - index_in_dict_array = chunk_layer - 3 - if index_in_dict_array < len(chunk_dicts): - chunk_id = cg.get_chunk_id(parent_node) - chunk_dicts[index_in_dict_array][chunk_id].add(parent_node) - cur_chunk_dict = chunk_dicts[layer - 3] - for chunk_dict in chunk_dicts: - for chunk_id, node_ids in chunk_dict.items(): - if PRINT_FOR_DEBUGGING: - print("remeshing", chunk_id, node_ids) - # Stitch the meshes of the parents we found in the previous loop - chunk_stitch_remeshing_task( - None, - chunk_id, - mip=mip, - fragment_batch_size=40, - node_id_subset=node_ids, - cg=cg, - cv_sharded_mesh_dir=cv_sharded_mesh_dir, - cv_unsharded_mesh_path=cv_unsharded_mesh_path, - ) - - -def chunk_initial_mesh_task( - cg_name, - chunk_id, - cv_unsharded_mesh_path, - mip=2, - max_err=40, - lod=0, - encoding="draco", - time_stamp=None, - dust_threshold=None, - return_frag_count=False, - node_id_subset=None, - cg=None, - sharded=False, - cache=True, -): - if cg is None: - cg = ChunkedGraph(graph_id=cg_name) - result = [] - cache_string = "public" if cache else "no-cache" - - layer, _, chunk_offset = get_meshing_necessities_from_graph(cg, chunk_id, mip) - cx, cy, cz = cg.get_chunk_coordinates(chunk_id) - high_padding = 1 - assert layer == 2 - assert mip >= cg.meta.cv.mip - - if sharded: - cv = CloudVolume( - f"graphene://https://localhost/segmentation/table/dummy", - info=meshgen_utils.get_json_info(cg), - ) - sharding_info = cv.mesh.meta.info["sharding"]["2"] - sharding_spec = ShardingSpecification.from_dict(sharding_info) - merged_meshes = {} - mesh_dst = os.path.join( - cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(layer) - ) - else: - mesh_dst = cv_unsharded_mesh_path - - result.append((chunk_id, layer, cx, cy, cz)) - print( - "Retrieving remap table for chunk %s -- (%s, %s, %s, %s)" - % (chunk_id, layer, cx, cy, cz) - ) - mesher = zmesh.Mesher(cg.meta.cv.mip_resolution(mip)) - draco_encoding_settings = get_draco_encoding_settings_for_chunk( - cg, chunk_id, mip, high_padding - ) - if node_id_subset is None: - seg = get_remapped_segmentation( - cg, chunk_id, mip, overlap_vx=high_padding, time_stamp=time_stamp - ) - else: - seg = get_remapped_seg_for_lvl2_nodes( - cg, - chunk_id, - node_id_subset, - mip=mip, - overlap_vx=high_padding, - time_stamp=time_stamp, - ) - if dust_threshold: - black_out_dust_from_segmentation(seg, dust_threshold) - if return_frag_count: - return np.unique(seg).shape[0] - mesher.mesh(seg) - del seg - cf = CloudFiles(mesh_dst) - if PRINT_FOR_DEBUGGING: - print("cv path", mesh_dst) - print("num ids", len(mesher.ids())) - result.append(len(mesher.ids())) - for obj_id in mesher.ids(): - mesh = mesher.get(obj_id, reduction_factor=100, max_error=max_err) - mesher.erase(obj_id) - mesh.vertices[:] += chunk_offset - if encoding == "draco": - try: - file_contents = DracoPy.encode_mesh_to_buffer( - mesh.vertices.flatten("C"), - mesh.faces.flatten("C"), - **draco_encoding_settings, - ) - except: - result.append( - f"{obj_id} failed: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces" - ) - continue - compress = False - else: - file_contents = mesh.to_precomputed() - compress = True - if WRITING_TO_CLOUD: - if sharded: - merged_meshes[int(obj_id)] = file_contents - else: - cf.put( - path=f"{meshgen_utils.get_mesh_name(cg, obj_id)}", - content=file_contents, - compress=compress, - cache_control=cache_string, - ) - if sharded and WRITING_TO_CLOUD: - shard_binary = sharding_spec.synthesize_shard(merged_meshes) - shard_filename = cv.mesh.readers[layer].get_filename(chunk_id) - cf.put( - shard_filename, - shard_binary, - content_type="application/octet-stream", - compress=False, - cache_control=cache_string, - ) - if PRINT_FOR_DEBUGGING: - print(", ".join(str(x) for x in result)) - return result - - -def get_multi_child_nodes(cg, chunk_id, node_id_subset=None, chunk_bbox_string=False): - if node_id_subset is None: - range_read = cg.range_read_chunk( - chunk_id, properties=attributes.Hierarchy.Child - ) - else: - range_read = cg.client.read_nodes( - node_ids=node_id_subset, properties=attributes.Hierarchy.Child - ) - - node_ids = np.array(list(range_read.keys())) - node_rows = np.array(list(range_read.values())) - child_fragments = np.array( - [ - fragment.value - for child_fragments_for_node in node_rows - for fragment in child_fragments_for_node - ], dtype=object - ) - # Filter out node ids that do not have roots (caused by failed ingest tasks) - root_ids = cg.get_roots(node_ids, fail_to_zero=True) - # Only keep nodes with more than one child - multi_child_mask = np.array( - [len(fragments) > 1 for fragments in child_fragments], dtype=bool - ) - root_id_mask = np.array([root_id != 0 for root_id in root_ids], dtype=bool) - multi_child_node_ids = node_ids[multi_child_mask & root_id_mask] - multi_child_children_ids = child_fragments[multi_child_mask & root_id_mask] - # Store how many children each node has, because we will retrieve all children at once - multi_child_num_children = [len(children) for children in multi_child_children_ids] - child_fragments_flat = np.array( - [ - frag - for children_of_node in multi_child_children_ids - for frag in children_of_node - ] - ) - multi_child_descendants = meshgen_utils.get_downstream_multi_child_nodes( - cg, child_fragments_flat - ) - start_index = 0 - multi_child_nodes = {} - for i in range(len(multi_child_node_ids)): - end_index = start_index + multi_child_num_children[i] - descendents_for_current_node = multi_child_descendants[start_index:end_index] - node_id = multi_child_node_ids[i] - if chunk_bbox_string: - multi_child_nodes[ - f"{node_id}:0:{meshgen_utils.get_chunk_bbox_str(cg, node_id)}" - ] = [ - f"{c}:0:{meshgen_utils.get_chunk_bbox_str(cg, c)}" - for c in descendents_for_current_node - ] - else: - multi_child_nodes[multi_child_node_ids[i]] = descendents_for_current_node - start_index = end_index - - return multi_child_nodes, multi_child_descendants - - -def chunk_stitch_remeshing_task( - cg_name, - chunk_id, - cv_sharded_mesh_dir, - cv_unsharded_mesh_path, - mip=2, - lod=0, - fragment_batch_size=None, - node_id_subset=None, - cg=None, - high_padding=1, -): - """ - For each node with more than one child, create a new fragment by - merging the mesh fragments of the children. - """ - if cg is None: - cg = ChunkedGraph(graph_id=cg_name) - cx, cy, cz = cg.get_chunk_coordinates(chunk_id) - layer = cg.get_chunk_layer(chunk_id) - result = [] - - assert layer > 2 - - print( - "Retrieving children for chunk %s -- (%s, %s, %s, %s)" - % (chunk_id, layer, cx, cy, cz) - ) - - multi_child_nodes, _ = get_multi_child_nodes(cg, chunk_id, node_id_subset, False) - print(f"{len(multi_child_nodes)} nodes with more than one child") - result.append((chunk_id, len(multi_child_nodes))) - if not multi_child_nodes: - print("Nothing to do", cx, cy, cz) - return ", ".join(str(x) for x in result) - - cv = CloudVolume( - f"graphene://https://localhost/segmentation/table/dummy", - mesh_dir=cv_sharded_mesh_dir, - info=meshgen_utils.get_json_info(cg), - ) - - fragments_in_batch_processed = 0 - batches_processed = 0 - num_fragments_processed = 0 - fragment_to_fetch = [ - fragment - for child_fragments in multi_child_nodes.values() - for fragment in child_fragments - ] - cf = CloudFiles(cv_unsharded_mesh_path) - if fragment_batch_size is None: - fragment_map = cv.mesh.get_meshes_on_bypass( - fragment_to_fetch, allow_missing=True - ) - else: - fragment_map = cv.mesh.get_meshes_on_bypass( - fragment_to_fetch[0:fragment_batch_size], allow_missing=True - ) - i = 0 - fragments_d = {} - for new_fragment_id, fragment_ids_to_fetch in multi_child_nodes.items(): - i += 1 - if i % max(1, len(multi_child_nodes) // 10) == 0: - print(f"{i}/{len(multi_child_nodes)}") - - old_fragments = [] - missing_fragments = False - for fragment_id in fragment_ids_to_fetch: - if fragment_batch_size is not None: - fragments_in_batch_processed += 1 - if fragments_in_batch_processed > fragment_batch_size: - fragments_in_batch_processed = 1 - batches_processed += 1 - num_fragments_processed = batches_processed * fragment_batch_size - fragment_map = cv.mesh.get_meshes_on_bypass( - fragment_to_fetch[ - num_fragments_processed : num_fragments_processed - + fragment_batch_size - ], - allow_missing=True, - ) - if fragment_id in fragment_map: - old_frag = fragment_map[fragment_id] - new_old_frag = { - "num_vertices": len(old_frag.vertices), - "vertices": old_frag.vertices, - "faces": old_frag.faces.reshape(-1), - "encoding_options": old_frag.encoding_options, - "encoding_type": "draco", - } - wrapper_object = { - "mesh": new_old_frag, - "node_id": np.uint64(old_frag.segid), - } - old_fragments.append(wrapper_object) - elif cg.get_chunk_layer(np.uint64(fragment_id)) > 2: - missing_fragments = True - result.append(f"{fragment_id} missing for {new_fragment_id}") - - if len(old_fragments) == 0 or missing_fragments: - result.append(f"No meshes for {new_fragment_id}") - continue - - draco_encoding_options = None - for old_fragment in old_fragments: - if draco_encoding_options is None: - draco_encoding_options = ( - transform_draco_fragment_and_return_encoding_options( - cg, old_fragment, layer, mip, chunk_id - ) - ) - else: - transform_draco_fragment_and_return_encoding_options( - cg, old_fragment, layer, mip, chunk_id - ) - - new_fragment = merge_draco_meshes_across_boundaries( - cg, old_fragments, chunk_id, mip, high_padding - ) - - try: - new_fragment_b = DracoPy.encode_mesh_to_buffer( - new_fragment["vertices"], - new_fragment["faces"], - **draco_encoding_options, - ) - except: - result.append( - f'Bad mesh created for {new_fragment_id}: {len(new_fragment["vertices"])} ' - f'vertices, {len(new_fragment["faces"])} faces' - ) - continue - - if WRITING_TO_CLOUD: - fragment_name = meshgen_utils.get_chunk_bbox_str(cg, new_fragment_id) - fragment_name = f"{new_fragment_id}:0:{fragment_name}" - fragments_d[new_fragment_id] = fragment_name - cf.put( - fragment_name, - new_fragment_b, - content_type="application/octet-stream", - compress=False, - cache_control="public", - ) - - manifest_cache = ManifestCache(cg.graph_id, initial=False) - manifest_cache.set_fragments(fragments_d) - - if PRINT_FOR_DEBUGGING: - print(", ".join(str(x) for x in result)) - return ", ".join(str(x) for x in result) - - -def chunk_initial_sharded_stitching_task( - cg_name, chunk_id, mip, cg=None, high_padding=1, cache=True -): - start_existence_check_time = time.time() - if cg is None: - cg = ChunkedGraph(graph_id=cg_name) - - cache_string = "public" if cache else "no-cache" - - layer = cg.get_chunk_layer(chunk_id) - multi_child_nodes, multi_child_descendants = get_multi_child_nodes(cg, chunk_id) - - chunk_to_id_dict = collections.defaultdict(list) - for child_node in multi_child_descendants: - cur_chunk_id = int(cg.get_chunk_id(child_node)) - chunk_to_id_dict[cur_chunk_id].append(child_node) - - cv = CloudVolume( - f"graphene://https://localhost/segmentation/table/dummy", - info=meshgen_utils.get_json_info(cg), - ) - shard_filenames = [] - shard_to_chunk_id = {} - for cur_chunk_id in chunk_to_id_dict: - shard_id = cv.meta.decode_chunk_position_number(cur_chunk_id) - shard_filename = ( - str(cg.get_chunk_layer(cur_chunk_id)) + "/" + str(shard_id) + "-0.shard" - ) - shard_to_chunk_id[shard_filename] = cur_chunk_id - shard_filenames.append(shard_filename) - mesh_dict = {} - - cf = CloudFiles(os.path.join(cv.cloudpath, cv.mesh.meta.mesh_path, "initial")) - files_contents = cf.get(shard_filenames) - for i in range(len(files_contents)): - cur_chunk_id = shard_to_chunk_id[files_contents[i]["path"]] - cur_layer = cg.get_chunk_layer(cur_chunk_id) - if files_contents[i]["content"] is not None: - disassembled_shard = cv.mesh.readers[cur_layer].disassemble_shard( - files_contents[i]["content"] - ) - nodes_in_chunk = chunk_to_id_dict[int(cur_chunk_id)] - for node_in_chunk in nodes_in_chunk: - node_in_chunk_int = int(node_in_chunk) - if node_in_chunk_int in disassembled_shard: - mesh_dict[node_in_chunk_int] = disassembled_shard[node_in_chunk] - del files_contents - - number_frags_proc = 0 - sharding_info = cv.mesh.meta.info["sharding"][str(layer)] - sharding_spec = ShardingSpecification.from_dict(sharding_info) - merged_meshes = {} - biggest_frag = 0 - biggest_frag_vx_ct = 0 - bad_meshes = [] - for new_fragment_id in multi_child_nodes: - fragment_ids_to_fetch = multi_child_nodes[new_fragment_id] - old_fragments = [] - for frag_to_fetch in fragment_ids_to_fetch: - try: - old_fragments.append( - { - "mesh": decode_draco_mesh_buffer(mesh_dict[int(frag_to_fetch)]), - "node_id": np.uint64(frag_to_fetch), - } - ) - except KeyError: - pass - if len(old_fragments) > 0: - draco_encoding_options = None - for old_fragment in old_fragments: - if draco_encoding_options is None: - draco_encoding_options = ( - transform_draco_fragment_and_return_encoding_options( - cg, old_fragment, layer, mip, chunk_id - ) - ) - else: - transform_draco_fragment_and_return_encoding_options( - cg, old_fragment, layer, mip, chunk_id - ) - - new_fragment = merge_draco_meshes_across_boundaries( - cg, old_fragments, chunk_id, mip, high_padding - ) - - if len(new_fragment["vertices"]) > biggest_frag_vx_ct: - biggest_frag = new_fragment_id - biggest_frag_vx_ct = len(new_fragment["vertices"]) - - try: - new_fragment_b = DracoPy.encode_mesh_to_buffer( - new_fragment["vertices"], - new_fragment["faces"], - **draco_encoding_options, - ) - merged_meshes[int(new_fragment_id)] = new_fragment_b - except: - print(f"failed to merge {new_fragment_id}") - bad_meshes.append(new_fragment_id) - pass - number_frags_proc = number_frags_proc + 1 - if number_frags_proc % 1000 == 0: - print(f"number frag proc = {number_frags_proc}") - del mesh_dict - shard_binary = sharding_spec.synthesize_shard(merged_meshes) - shard_filename = cv.mesh.readers[layer].get_filename(chunk_id) - cf = CloudFiles( - os.path.join(cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(layer)) - ) - cf.put( - shard_filename, - shard_binary, - content_type="application/octet-stream", - compress=False, - cache_control=cache_string, - ) - total_time = time.time() - start_existence_check_time - - ret = { - "chunk_id": chunk_id, - "total_time": total_time, - "biggest_frag": biggest_frag, - "biggest_frag_vx_ct": biggest_frag_vx_ct, - "number_frag": number_frags_proc, - "bad meshes": bad_meshes, - } - return ret +# pylint: disable=invalid-name, missing-docstring, unused-wildcard-import, wildcard-import +# +# Backward-compatible re-export facade. +# All functionality has been split into: +# meshgen_utils.py — shared Draco, segmentation, merge utilities and constants +# meshgen_initial.py — initial ingest mesh pipeline +# meshgen_remesh.py — dynamic remeshing pipeline + +from pychunkedgraph.meshing.meshgen_utils import ( # noqa: F401 + UTC, + PRINT_FOR_DEBUGGING, + WRITING_TO_CLOUD, + REDIS_HOST, + REDIS_PORT, + REDIS_PASSWORD, + REDIS_URL, + decode_draco_mesh_buffer, + remap_seg_using_unsafe_dict, + calculate_stop_layer, + get_meshing_necessities_from_graph, + calculate_quantization_bits_and_range, + get_draco_encoding_settings_for_chunk, + get_next_layer_draco_encoding_settings, + transform_draco_vertices, + transform_draco_fragment_and_return_encoding_options, + merge_draco_meshes_across_boundaries, + black_out_dust_from_segmentation, + get_multi_child_nodes, +) + +from pychunkedgraph.meshing.meshgen_initial import ( # noqa: F401 + get_remapped_segmentation, + get_higher_to_lower_remapping, + get_root_lx_remapping, + get_lx_overlapping_remappings, + chunk_initial_mesh_task, + chunk_initial_sharded_stitching_task, +) + +from pychunkedgraph.meshing.meshgen_remesh import ( # noqa: F401 + get_remapped_seg_for_lvl2_nodes, + get_root_remapping_for_nodes_and_svs, + get_lx_overlapping_remappings_for_nodes_and_svs, + remeshing, + chunk_stitch_remeshing_task, +) diff --git a/pychunkedgraph/meshing/meshgen_initial.py b/pychunkedgraph/meshing/meshgen_initial.py new file mode 100644 index 000000000..6089cb4a4 --- /dev/null +++ b/pychunkedgraph/meshing/meshgen_initial.py @@ -0,0 +1,597 @@ +# pylint: disable=invalid-name, missing-docstring, too-many-lines, wrong-import-order, import-outside-toplevel, no-member, c-extension-no-member + +import os +import collections +import datetime +import time +from functools import lru_cache +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed + +import numpy as np +from cloudfiles import CloudFiles +from cloudvolume import CloudVolume +from cloudvolume.datasource.precomputed.sharding import ShardingSpecification +import DracoPy +import zmesh +import fastremap + +from pychunkedgraph.graph.chunkedgraph import ChunkedGraph # noqa +from pychunkedgraph.graph import attributes # noqa +from pychunkedgraph.meshing import meshgen_utils +from pychunkedgraph.meshing.meshgen_utils import ( + UTC, + PRINT_FOR_DEBUGGING, + WRITING_TO_CLOUD, + remap_seg_using_unsafe_dict, + calculate_stop_layer, + get_meshing_necessities_from_graph, + get_draco_encoding_settings_for_chunk, + black_out_dust_from_segmentation, + get_multi_child_nodes, + decode_draco_mesh_buffer, + transform_draco_fragment_and_return_encoding_options, + merge_draco_meshes_across_boundaries, +) + + +def get_remapped_segmentation( + cg, chunk_id, mip=2, overlap_vx=1, time_stamp=None, n_threads=1 +): + """Downloads + remaps ws segmentation + resolve unclear cases + + :param cg: chunkedgraph object + :param chunk_id: np.uint64 + :param mip: int + :param overlap_vx: int + :param time_stamp: + :return: remapped segmentation + """ + assert mip >= cg.meta.cv.mip + + sv_remapping, unsafe_dict = get_lx_overlapping_remappings( + cg, chunk_id, time_stamp=time_stamp, n_threads=n_threads + ) + + ws_seg = meshgen_utils.get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx) + seg = fastremap.mask_except(ws_seg, list(sv_remapping.keys()), in_place=False) + fastremap.remap(seg, sv_remapping, preserve_missing_labels=True, in_place=True) + + return remap_seg_using_unsafe_dict(seg, unsafe_dict) + + +@lru_cache(maxsize=None) +def get_higher_to_lower_remapping(cg, chunk_id, time_stamp): + """Retrieves lx node id to sv id mappping + + :param cg: chunkedgraph object + :param chunk_id: np.uint64 + :param time_stamp: datetime object + :return: dictionary + """ + + def _lower_remaps(ks): + return np.concatenate([lower_remaps[k] for k in ks]) + + assert cg.get_chunk_layer(chunk_id) >= 2 + assert cg.get_chunk_layer(chunk_id) <= cg.meta.layer_count + + print(f"\n{chunk_id} ----------------\n") + + lower_remaps = {} + if cg.get_chunk_layer(chunk_id) > 2: + for lower_chunk_id in cg.get_chunk_child_ids(chunk_id): + # TODO speedup + lower_remaps.update( + get_higher_to_lower_remapping(cg, lower_chunk_id, time_stamp=time_stamp) + ) + + rr_chunk = cg.range_read_chunk( + chunk_id=chunk_id, properties=attributes.Hierarchy.Child, time_stamp=time_stamp + ) + + # This for-loop ensures that only the latest lx_ids are considered + # The order by id guarantees the time order (only true for same neurons + # but that is the case here). + lx_remapping = {} + all_lower_ids = set() + for k in sorted(rr_chunk.keys(), reverse=True): + this_child_ids = rr_chunk[k][0].value + if this_child_ids[0] in all_lower_ids: + continue + + all_lower_ids = all_lower_ids.union(set(list(this_child_ids))) + + if cg.get_chunk_layer(chunk_id) > 2: + try: + lx_remapping[k] = _lower_remaps(this_child_ids) + except KeyError: + # KeyErrors indicate that this id is deprecated given the + # time_stamp + continue + else: + lx_remapping[k] = this_child_ids + + return lx_remapping + + +@lru_cache(maxsize=None) +def get_root_lx_remapping(cg, chunk_id, stop_layer, time_stamp, n_threads=1): + """Retrieves root to l2 node id mapping + + :param cg: chunkedgraph object + :param chunk_id: np.uint64 + :param stop_layer: int + :param time_stamp: datetime object + :return: multiples + """ + + def _get_root_ids(args): + start_id, end_id = args + root_ids[start_id:end_id] = cg.get_roots( + lx_ids[start_id:end_id], + stop_layer=stop_layer, + fail_to_zero=True, + ) + + lx_id_remap = get_higher_to_lower_remapping(cg, chunk_id, time_stamp=time_stamp) + + lx_ids = np.array(list(lx_id_remap.keys())) + + root_ids = np.zeros(len(lx_ids), dtype=np.uint64) + n_jobs = np.min([n_threads, len(lx_ids)]) + multi_args = [] + start_ids = np.linspace(0, len(lx_ids), n_jobs + 1).astype(int) + for i_block in range(n_jobs): + multi_args.append([start_ids[i_block], start_ids[i_block + 1]]) + + if n_jobs > 0: + with ThreadPoolExecutor(max_workers=n_threads) as executor: + list(executor.map(_get_root_ids, multi_args)) + + return lx_ids, np.array(root_ids), lx_id_remap + + +# @lru_cache(maxsize=None) +def get_lx_overlapping_remappings(cg, chunk_id, time_stamp=None, n_threads=1): + """Retrieves sv id to layer mapping for chunk with overlap in positive + direction (one chunk) + + :param cg: chunkedgraph object + :param chunk_id: np.uint64 + :param time_stamp: datetime object + :return: multiples + """ + if time_stamp is None: + time_stamp = datetime.datetime.now(datetime.timezone.utc) + if time_stamp.tzinfo is None: + time_stamp = UTC.localize(time_stamp) + + stop_layer, neigh_chunk_ids = calculate_stop_layer(cg, chunk_id) + print(f"Stop layer: {stop_layer}") + + # Find the parent in the lowest common chunk for each l2 id. These parent + # ids are referred to as root ids even though they are not necessarily the + # root id. + neigh_lx_ids = [] + neigh_lx_id_remap = {} + neigh_root_ids = [] + + safe_lx_ids = [] + unsafe_lx_ids = [] + unsafe_root_ids = [] + + # Parallelize the main bottleneck: fetching root mappings for neighbor chunks + with ThreadPoolExecutor() as executor: + future_to_chunk = { + executor.submit( + get_root_lx_remapping, + cg, + nid, + stop_layer, + time_stamp=time_stamp, + n_threads=n_threads, + ): nid + for nid in neigh_chunk_ids + } + results = {} + for future in as_completed(future_to_chunk): + nid = future_to_chunk[future] + results[nid] = future.result() + + for neigh_chunk_id in neigh_chunk_ids: + print(f"Neigh: {neigh_chunk_id} --------------") + + lx_ids, root_ids, lx_id_remap = results[neigh_chunk_id] + neigh_lx_ids.extend(lx_ids) + neigh_lx_id_remap.update(lx_id_remap) + neigh_root_ids.extend(root_ids) + + if neigh_chunk_id == chunk_id: + # The first neigh_chunk_id is the one we are interested in. All lx + # ids that share no root id with any other lx id are "safe", meaning + # that we can easily obtain the complete remapping (including + # overlap) for these. All other ones have to be resolved using the + # segmentation. + _, u_idx, c_root_ids = np.unique( + neigh_root_ids, return_counts=True, return_index=True + ) + + safe_lx_ids = lx_ids[u_idx[c_root_ids == 1]] + unsafe_lx_ids = lx_ids[~np.isin(lx_ids, safe_lx_ids)] + unsafe_root_ids = np.unique(root_ids[u_idx[c_root_ids != 1]]) + + lx_root_dict = dict(zip(neigh_lx_ids, neigh_root_ids)) + root_lx_dict = collections.defaultdict(list) + + # Future sv id -> lx mapping + sv_ids = [] + lx_ids_flat = [] + + for i_root_id in range(len(neigh_root_ids)): + root_lx_dict[neigh_root_ids[i_root_id]].append(neigh_lx_ids[i_root_id]) + + # Do safe ones first + for lx_id in safe_lx_ids: + root_id = lx_root_dict[lx_id] + for neigh_lx_id in root_lx_dict[root_id]: + lx_sv_ids = neigh_lx_id_remap[neigh_lx_id] + sv_ids.extend(lx_sv_ids) + lx_ids_flat.extend([lx_id] * len(neigh_lx_id_remap[neigh_lx_id])) + + # For the unsafe ones we can only do the in chunk svs + # But we will map the out of chunk svs to the root id and store the + # hierarchical information in a dictionary + for lx_id in unsafe_lx_ids: + sv_ids.extend(neigh_lx_id_remap[lx_id]) + lx_ids_flat.extend([lx_id] * len(neigh_lx_id_remap[lx_id])) + + unsafe_dict = collections.defaultdict(list) + for root_id in unsafe_root_ids: + if np.sum(~np.isin(root_lx_dict[root_id], unsafe_lx_ids)) == 0: + continue + + for neigh_lx_id in root_lx_dict[root_id]: + unsafe_dict[root_id].append(neigh_lx_id) + + if neigh_lx_id in unsafe_lx_ids: + continue + + sv_ids.extend(neigh_lx_id_remap[neigh_lx_id]) + lx_ids_flat.extend([root_id] * len(neigh_lx_id_remap[neigh_lx_id])) + + # Combine the lists for a (chunk-) global remapping + sv_remapping = dict(zip(sv_ids, lx_ids_flat)) + + return sv_remapping, unsafe_dict + + +def chunk_initial_mesh_task( + cg_name, + chunk_id, + cv_unsharded_mesh_path, + mip=2, + max_err=40, + lod=0, + encoding="draco", + time_stamp=None, + dust_threshold=None, + return_frag_count=False, + node_id_subset=None, + cg=None, + sharded=False, + cache=True, +): + if cg is None: + cg = ChunkedGraph(graph_id=cg_name) + result = [] + cache_string = "public" if cache else "no-cache" + + layer, _, chunk_offset = get_meshing_necessities_from_graph(cg, chunk_id, mip) + cx, cy, cz = cg.get_chunk_coordinates(chunk_id) + high_padding = 1 + assert layer == 2 + assert mip >= cg.meta.cv.mip + + if sharded: + cv = CloudVolume( + f"graphene://https://localhost/segmentation/table/dummy", + info=meshgen_utils.get_json_info(cg), + ) + sharding_info = cv.mesh.meta.info["sharding"]["2"] + sharding_spec = ShardingSpecification.from_dict(sharding_info) + merged_meshes = {} + mesh_dst = os.path.join( + cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(layer) + ) + else: + mesh_dst = cv_unsharded_mesh_path + + result.append((chunk_id, layer, cx, cy, cz)) + print( + "Retrieving remap table for chunk %s -- (%s, %s, %s, %s)" + % (chunk_id, layer, cx, cy, cz) + ) + mesher = zmesh.Mesher(cg.meta.cv.mip_resolution(mip)) + draco_encoding_settings = get_draco_encoding_settings_for_chunk( + cg, chunk_id, mip, high_padding + ) + if node_id_subset is None: + seg = get_remapped_segmentation( + cg, chunk_id, mip, overlap_vx=high_padding, time_stamp=time_stamp + ) + else: + # Import here to avoid circular import at module level + from pychunkedgraph.meshing.meshgen_remesh import ( + get_remapped_seg_for_lvl2_nodes, + ) + + seg = get_remapped_seg_for_lvl2_nodes( + cg, + chunk_id, + node_id_subset, + mip=mip, + overlap_vx=high_padding, + time_stamp=time_stamp, + ) + if dust_threshold: + black_out_dust_from_segmentation(seg, dust_threshold) + if return_frag_count: + return np.unique(seg).shape[0] + mesher.mesh(seg) + del seg + + if PRINT_FOR_DEBUGGING: + print("cv path", mesh_dst) + print("num ids", len(mesher.ids())) + result.append(len(mesher.ids())) + + # Extract all meshes sequentially (zmesh Mesher is not thread-safe) + meshes = [] + for obj_id in mesher.ids(): + mesh = mesher.get(obj_id, reduction_factor=100, max_error=max_err) + mesher.erase(obj_id) + mesh.vertices[:] += chunk_offset + meshes.append((obj_id, mesh)) + del mesher + + # Encode + upload in parallel + def _encode_and_upload(args): + obj_id, mesh = args + if encoding == "draco": + try: + file_contents = DracoPy.encode_mesh_to_buffer( + mesh.vertices.flatten("C"), + mesh.faces.flatten("C"), + **draco_encoding_settings, + ) + except: + return ( + "error", + f"{obj_id} failed: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces", + ) + compress = False + else: + file_contents = mesh.to_precomputed() + compress = True + if WRITING_TO_CLOUD: + if sharded: + return ("shard", int(obj_id), file_contents) + else: + thread_cf = CloudFiles(mesh_dst) + thread_cf.put( + path=f"{meshgen_utils.get_mesh_name(cg, obj_id)}", + content=file_contents, + compress=compress, + cache_control=cache_string, + ) + return None + + with ThreadPoolExecutor() as executor: + for r in executor.map(_encode_and_upload, meshes): + if r is None: + continue + if r[0] == "error": + result.append(r[1]) + elif r[0] == "shard": + merged_meshes[r[1]] = r[2] + + if sharded and WRITING_TO_CLOUD: + shard_binary = sharding_spec.synthesize_shard(merged_meshes) + shard_filename = cv.mesh.readers[layer].get_filename(chunk_id) + cf = CloudFiles(mesh_dst) + cf.put( + shard_filename, + shard_binary, + content_type="application/octet-stream", + compress=False, + cache_control=cache_string, + ) + if PRINT_FOR_DEBUGGING: + print(", ".join(str(x) for x in result)) + return result + + +# --------------------------------------------------------------------------- +# Process-pool worker for chunk_initial_sharded_stitching_task +# --------------------------------------------------------------------------- + +_worker_cg = None + + +def _init_cg_worker(graph_id): + global _worker_cg + _worker_cg = ChunkedGraph(graph_id=graph_id) + + +def _process_sharded_fragment( + new_fragment_id, fragment_ids, frag_mesh_data, layer, mip, chunk_id, high_padding +): + cg = _worker_cg + old_fragments = [] + for frag_id in fragment_ids: + frag_data = frag_mesh_data.get(int(frag_id)) + if frag_data is not None: + try: + old_fragments.append( + { + "mesh": decode_draco_mesh_buffer(frag_data), + "node_id": np.uint64(frag_id), + } + ) + except (KeyError, ValueError): + pass + + if not old_fragments: + return (new_fragment_id, None, 0, None) + + draco_encoding_options = None + for old_fragment in old_fragments: + if draco_encoding_options is None: + draco_encoding_options = ( + transform_draco_fragment_and_return_encoding_options( + cg, old_fragment, layer, mip, chunk_id + ) + ) + else: + transform_draco_fragment_and_return_encoding_options( + cg, old_fragment, layer, mip, chunk_id + ) + + new_fragment = merge_draco_meshes_across_boundaries( + cg, old_fragments, chunk_id, mip, high_padding + ) + + vertex_count = len(new_fragment["vertices"]) + + try: + new_fragment_b = DracoPy.encode_mesh_to_buffer( + new_fragment["vertices"], + new_fragment["faces"], + **draco_encoding_options, + ) + return (new_fragment_id, new_fragment_b, vertex_count, None) + except Exception: + return ( + new_fragment_id, + None, + vertex_count, + f"failed to merge {new_fragment_id}", + ) + + +def chunk_initial_sharded_stitching_task( + cg_name, chunk_id, mip, cg=None, high_padding=1, cache=True +): + start_existence_check_time = time.time() + if cg is None: + cg = ChunkedGraph(graph_id=cg_name) + + cache_string = "public" if cache else "no-cache" + + layer = cg.get_chunk_layer(chunk_id) + multi_child_nodes, multi_child_descendants = get_multi_child_nodes(cg, chunk_id) + + chunk_to_id_dict = collections.defaultdict(list) + for child_node in multi_child_descendants: + cur_chunk_id = int(cg.get_chunk_id(child_node)) + chunk_to_id_dict[cur_chunk_id].append(child_node) + + cv = CloudVolume( + f"graphene://https://localhost/segmentation/table/dummy", + info=meshgen_utils.get_json_info(cg), + ) + shard_filenames = [] + shard_to_chunk_id = {} + for cur_chunk_id in chunk_to_id_dict: + shard_id = cv.meta.decode_chunk_position_number(cur_chunk_id) + shard_filename = ( + str(cg.get_chunk_layer(cur_chunk_id)) + "/" + str(shard_id) + "-0.shard" + ) + shard_to_chunk_id[shard_filename] = cur_chunk_id + shard_filenames.append(shard_filename) + mesh_dict = {} + + cf = CloudFiles(os.path.join(cv.cloudpath, cv.mesh.meta.mesh_path, "initial")) + files_contents = cf.get(shard_filenames) + for i in range(len(files_contents)): + cur_chunk_id = shard_to_chunk_id[files_contents[i]["path"]] + cur_layer = cg.get_chunk_layer(cur_chunk_id) + if files_contents[i]["content"] is not None: + disassembled_shard = cv.mesh.readers[cur_layer].disassemble_shard( + files_contents[i]["content"] + ) + nodes_in_chunk = chunk_to_id_dict[int(cur_chunk_id)] + for node_in_chunk in nodes_in_chunk: + node_in_chunk_int = int(node_in_chunk) + if node_in_chunk_int in disassembled_shard: + mesh_dict[node_in_chunk_int] = disassembled_shard[node_in_chunk] + del files_contents + + sharding_info = cv.mesh.meta.info["sharding"][str(layer)] + sharding_spec = ShardingSpecification.from_dict(sharding_info) + merged_meshes = {} + biggest_frag = 0 + biggest_frag_vx_ct = 0 + bad_meshes = [] + number_frags_proc = 0 + + # Process fragments in parallel using multiprocessing + with ProcessPoolExecutor( + initializer=_init_cg_worker, initargs=(cg.graph_id,) + ) as executor: + futures = {} + for new_fragment_id, fragment_ids in multi_child_nodes.items(): + frag_mesh_data = { + int(f): mesh_dict[int(f)] for f in fragment_ids if int(f) in mesh_dict + } + futures[ + executor.submit( + _process_sharded_fragment, + new_fragment_id, + fragment_ids, + frag_mesh_data, + layer, + mip, + chunk_id, + high_padding, + ) + ] = new_fragment_id + + for future in as_completed(futures): + frag_id, encoded_bytes, vertex_count, error = future.result() + if encoded_bytes is not None: + merged_meshes[int(frag_id)] = encoded_bytes + number_frags_proc += 1 + if vertex_count > biggest_frag_vx_ct: + biggest_frag = frag_id + biggest_frag_vx_ct = vertex_count + if number_frags_proc % 1000 == 0: + print(f"number frag proc = {number_frags_proc}") + elif error is not None: + print(error) + bad_meshes.append(frag_id) + + del mesh_dict + shard_binary = sharding_spec.synthesize_shard(merged_meshes) + shard_filename = cv.mesh.readers[layer].get_filename(chunk_id) + cf = CloudFiles( + os.path.join(cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(layer)) + ) + cf.put( + shard_filename, + shard_binary, + content_type="application/octet-stream", + compress=False, + cache_control=cache_string, + ) + total_time = time.time() - start_existence_check_time + + ret = { + "chunk_id": chunk_id, + "total_time": total_time, + "biggest_frag": biggest_frag, + "biggest_frag_vx_ct": biggest_frag_vx_ct, + "number_frag": number_frags_proc, + "bad meshes": bad_meshes, + } + return ret diff --git a/pychunkedgraph/meshing/meshgen_remesh.py b/pychunkedgraph/meshing/meshgen_remesh.py new file mode 100644 index 000000000..52bc11639 --- /dev/null +++ b/pychunkedgraph/meshing/meshgen_remesh.py @@ -0,0 +1,465 @@ +# pylint: disable=invalid-name, missing-docstring, too-many-lines, wrong-import-order, import-outside-toplevel, no-member, c-extension-no-member + +from typing import Sequence +import collections +import datetime +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +from cloudfiles import CloudFiles +from cloudvolume import CloudVolume +import DracoPy +import fastremap + +from pychunkedgraph.graph.chunkedgraph import ChunkedGraph # noqa +from pychunkedgraph.graph import attributes # noqa +from pychunkedgraph.meshing import meshgen_utils +from pychunkedgraph.meshing.meshgen_utils import ( + UTC, + PRINT_FOR_DEBUGGING, + WRITING_TO_CLOUD, + remap_seg_using_unsafe_dict, + calculate_stop_layer, + get_multi_child_nodes, + transform_draco_fragment_and_return_encoding_options, + merge_draco_meshes_across_boundaries, +) +from pychunkedgraph.meshing.meshgen_initial import chunk_initial_mesh_task +from pychunkedgraph.meshing.manifest.cache import ManifestCache + + +def get_remapped_seg_for_lvl2_nodes( + cg, + chunk_id: np.uint64, + lvl2_nodes: Sequence[np.uint64], + mip: int = 2, + overlap_vx: int = 1, + time_stamp=None, + n_threads: int = 1, +): + """Downloads + remaps ws segmentation + resolve unclear cases, + filter out all but specified lvl2_nodes + + :param cg: chunkedgraph object + :param chunk_id: np.uint64 + :param mip: int + :param overlap_vx: int + :param time_stamp: + :return: remapped segmentation + """ + seg = meshgen_utils.get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx) + sv_of_lvl2_nodes = cg.get_children(lvl2_nodes) + + # Check which of the lvl2_nodes meet the chunk boundary + node_ids_on_the_border = [] + remapping = {} + for node, sv_list in sv_of_lvl2_nodes.items(): + node_on_the_border = False + for sv_id in sv_list: + remapping[sv_id] = node + # If a node_id is on the chunk_boundary, we must check + # the overlap region to see if the meshes' end will be open or closed + if (not node_on_the_border) and ( + np.isin(sv_id, seg[-2, :, :]) + or np.isin(sv_id, seg[:, -2, :]) + or np.isin(sv_id, seg[:, :, -2]) + ): + node_on_the_border = True + node_ids_on_the_border.append(node) + + node_ids_on_the_border = np.array(node_ids_on_the_border) + if len(node_ids_on_the_border) > 0: + overlap_region = np.concatenate( + (seg[:, :, -1], seg[:, -1, :], seg[-1, :, :]), axis=None + ) + overlap_sv_ids = np.unique(overlap_region) + if overlap_sv_ids[0] == 0: + overlap_sv_ids = overlap_sv_ids[1:] + # Get the remappings for the supervoxels in the overlap region + sv_remapping, unsafe_dict = get_lx_overlapping_remappings_for_nodes_and_svs( + cg, chunk_id, node_ids_on_the_border, overlap_sv_ids, time_stamp, n_threads + ) + sv_remapping.update(remapping) + fastremap.mask_except(seg, list(sv_remapping.keys()), in_place=True) + fastremap.remap(seg, sv_remapping, preserve_missing_labels=True, in_place=True) + # For some supervoxel, they could map to multiple l2 nodes in the chunk, + # so we must perform a connected component analysis + # to see which l2 node they are adjacent to + return remap_seg_using_unsafe_dict(seg, unsafe_dict) + else: + # If no nodes in our subset meet the chunk boundary + # we can simply retrieve the sv of the nodes in the subset + fastremap.mask_except(seg, list(remapping.keys()), in_place=True) + fastremap.remap(seg, remapping, preserve_missing_labels=True, in_place=True) + + return seg + + +def get_root_remapping_for_nodes_and_svs( + cg, chunk_id, node_ids, sv_ids, stop_layer, time_stamp, n_threads=1 +): + """Retrieves root to node id mapping for specified node ids and supervoxel ids + + :param cg: chunkedgraph object + :param chunk_id: np.uint64 + :param node_ids: [np.uint64] + :param stop_layer: int + :param time_stamp: datetime object + :return: multiples + """ + + def _get_root_ids(args): + start_id, end_id = args + + root_ids[start_id:end_id] = cg.get_roots( + combined_ids[start_id:end_id], + stop_layer=stop_layer, + time_stamp=time_stamp, + fail_to_zero=True, + ) + + rr = cg.range_read_chunk( + chunk_id=chunk_id, properties=attributes.Hierarchy.Child, time_stamp=time_stamp + ) + chunk_sv_ids = np.unique(np.concatenate([id[0].value for id in rr.values()])) + chunk_l2_ids = np.unique(cg.get_parents(chunk_sv_ids, time_stamp=time_stamp)) + combined_ids = np.concatenate((node_ids, sv_ids, chunk_l2_ids)) + + root_ids = np.zeros(len(combined_ids), dtype=np.uint64) + n_jobs = np.min([n_threads, len(combined_ids)]) + multi_args = [] + start_ids = np.linspace(0, len(combined_ids), n_jobs + 1).astype(int) + for i_block in range(n_jobs): + multi_args.append([start_ids[i_block], start_ids[i_block + 1]]) + + if n_jobs > 0: + with ThreadPoolExecutor(max_workers=n_threads) as executor: + list(executor.map(_get_root_ids, multi_args)) + + sv_ids_index = len(node_ids) + chunk_ids_index = len(node_ids) + len(sv_ids) + + return ( + root_ids[0:sv_ids_index], + root_ids[sv_ids_index:chunk_ids_index], + root_ids[chunk_ids_index:], + ) + + +def get_lx_overlapping_remappings_for_nodes_and_svs( + cg, + chunk_id: np.uint64, + node_ids: Sequence[np.uint64], + sv_ids: Sequence[np.uint64], + time_stamp=None, + n_threads: int = 1, +): + """Retrieves sv id to layer mapping for chunk with overlap in positive + direction (one chunk) + + :param cg: chunkedgraph object + :param chunk_id: np.uint64 + :param node_ids: list of np.uint64 + :param sv_ids: list of np.uint64 + :param time_stamp: datetime object + :param n_threads: int + :return: multiples + """ + if time_stamp is None: + time_stamp = datetime.datetime.now(datetime.timezone.utc) + if time_stamp.tzinfo is None: + time_stamp = UTC.localize(time_stamp) + + stop_layer, _ = calculate_stop_layer(cg, chunk_id) + print(f"Stop layer: {stop_layer}") + + # Find the parent in the lowest common chunk for each node id and sv id. These parent + # ids are referred to as root ids even though they are not necessarily the + # root id. + node_root_ids, sv_root_ids, chunks_root_ids = get_root_remapping_for_nodes_and_svs( + cg, chunk_id, node_ids, sv_ids, stop_layer, time_stamp, n_threads + ) + + u_root_ids, u_idx, c_root_ids = np.unique( + chunks_root_ids, return_counts=True, return_index=True + ) + + # All l2 ids that share no root id with any other l2 id in the chunk are "safe", meaning + # that we can easily obtain the complete remapping (including + # overlap) for these. All other ones have to be resolved using the + # segmentation. + + root_sorted_idx = np.argsort(u_root_ids) + node_sorted_index = np.searchsorted(u_root_ids[root_sorted_idx], node_root_ids) + node_root_counts = c_root_ids[root_sorted_idx][node_sorted_index] + unsafe_root_ids = node_root_ids[np.where(node_root_counts > 1)] + safe_node_ids = node_ids[~np.isin(node_root_ids, unsafe_root_ids)] + + node_to_root_dict = dict(zip(node_ids, node_root_ids)) + + # Future sv id -> lx mapping + sv_ids_to_remap = [] + node_ids_flat = [] + + # Do safe ones first + for node_id in safe_node_ids: + root_id = node_to_root_dict[node_id] + sv_ids_to_add = sv_ids[np.where(sv_root_ids == root_id)] + if len(sv_ids_to_add) > 0: + sv_ids_to_remap.extend(sv_ids_to_add) + node_ids_flat.extend([node_id] * len(sv_ids_to_add)) + + # For the unsafe roots, we will map the out of chunk svs to the root id and store the + # hierarchical information in a dictionary + unsafe_dict = collections.defaultdict(list) + for root_id in unsafe_root_ids: + sv_ids_to_add = sv_ids[np.where(sv_root_ids == root_id)] + if len(sv_ids_to_add) > 0: + relevant_node_ids = node_ids[np.where(node_root_ids == root_id)] + if len(relevant_node_ids) > 0: + unsafe_dict[root_id].extend(relevant_node_ids) + sv_ids_to_remap.extend(sv_ids_to_add) + node_ids_flat.extend([root_id] * len(sv_ids_to_add)) + + # Combine the lists for a (chunk-) global remapping + sv_remapping = dict(zip(sv_ids_to_remap, node_ids_flat)) + + return sv_remapping, unsafe_dict + + +def _get_timestamp_from_node_ids(cg, node_ids): + timestamps = cg.get_node_timestamps(node_ids, return_numpy=False) + return max(timestamps) + datetime.timedelta(milliseconds=1) + + +def remeshing( + cg, + l2_node_ids: Sequence[np.uint64], + cv_sharded_mesh_dir: str, + cv_unsharded_mesh_path: str, + stop_layer: int = None, + mip: int = 2, + max_err: int = 40, + time_stamp: datetime.datetime or None = None, +): + """Given a chunkedgraph, a list of level 2 nodes, + perform remeshing and stitching up the node hierarchy (or up to the stop_layer) + + :param cg: chunkedgraph instance + :param l2_node_ids: list of uint64 + :param stop_layer: int + :param cv_path: str + :param cv_mesh_dir: str + :param mip: int + :param max_err: int + :return: + """ + l2_chunk_dict = collections.defaultdict(set) + # Find the chunk_ids of the l2_node_ids + + def add_nodes_to_l2_chunk_dict(ids): + for node_id in ids: + chunk_id = cg.get_chunk_id(node_id) + l2_chunk_dict[chunk_id].add(node_id) + + add_nodes_to_l2_chunk_dict(l2_node_ids) + for chunk_id, node_ids in l2_chunk_dict.items(): + if PRINT_FOR_DEBUGGING: + print("remeshing", chunk_id, node_ids) + try: + l2_time_stamp = _get_timestamp_from_node_ids(cg, node_ids) + except ValueError: + # ignore bad/invalid messages + return + # Remesh the l2_node_ids + chunk_initial_mesh_task( + None, + chunk_id, + mip=mip, + node_id_subset=node_ids, + cg=cg, + cv_unsharded_mesh_path=cv_unsharded_mesh_path, + max_err=max_err, + sharded=False, + time_stamp=l2_time_stamp, + ) + chunk_dicts = [] + max_layer = stop_layer or cg._n_layers + for layer in range(3, max_layer + 1): + chunk_dicts.append(collections.defaultdict(set)) + cur_chunk_dict = l2_chunk_dict + # Find the parents of each l2_node_id up to the stop_layer, + # as well as their associated chunk_ids + for layer in range(3, max_layer + 1): + for _, node_ids in cur_chunk_dict.items(): + parent_nodes = cg.get_parents(node_ids, time_stamp=time_stamp) + for parent_node in parent_nodes: + chunk_layer = cg.get_chunk_layer(parent_node) + index_in_dict_array = chunk_layer - 3 + if index_in_dict_array < len(chunk_dicts): + chunk_id = cg.get_chunk_id(parent_node) + chunk_dicts[index_in_dict_array][chunk_id].add(parent_node) + cur_chunk_dict = chunk_dicts[layer - 3] + for chunk_dict in chunk_dicts: + for chunk_id, node_ids in chunk_dict.items(): + if PRINT_FOR_DEBUGGING: + print("remeshing", chunk_id, node_ids) + # Stitch the meshes of the parents we found in the previous loop + chunk_stitch_remeshing_task( + None, + chunk_id, + mip=mip, + fragment_batch_size=40, + node_id_subset=node_ids, + cg=cg, + cv_sharded_mesh_dir=cv_sharded_mesh_dir, + cv_unsharded_mesh_path=cv_unsharded_mesh_path, + ) + + +def chunk_stitch_remeshing_task( + cg_name, + chunk_id, + cv_sharded_mesh_dir, + cv_unsharded_mesh_path, + mip=2, + lod=0, + fragment_batch_size=None, + node_id_subset=None, + cg=None, + high_padding=1, +): + """ + For each node with more than one child, create a new fragment by + merging the mesh fragments of the children. + """ + if cg is None: + cg = ChunkedGraph(graph_id=cg_name) + cx, cy, cz = cg.get_chunk_coordinates(chunk_id) + layer = cg.get_chunk_layer(chunk_id) + result = [] + + assert layer > 2 + + print( + "Retrieving children for chunk %s -- (%s, %s, %s, %s)" + % (chunk_id, layer, cx, cy, cz) + ) + + multi_child_nodes, _ = get_multi_child_nodes(cg, chunk_id, node_id_subset, False) + print(f"{len(multi_child_nodes)} nodes with more than one child") + result.append((chunk_id, len(multi_child_nodes))) + if not multi_child_nodes: + print("Nothing to do", cx, cy, cz) + return ", ".join(str(x) for x in result) + + cv = CloudVolume( + f"graphene://https://localhost/segmentation/table/dummy", + mesh_dir=cv_sharded_mesh_dir, + info=meshgen_utils.get_json_info(cg), + ) + + fragment_to_fetch = [ + fragment + for child_fragments in multi_child_nodes.values() + for fragment in child_fragments + ] + + # Fetch all fragments upfront for parallel processing + fragment_map = cv.mesh.get_meshes_on_bypass(fragment_to_fetch, allow_missing=True) + + # Process each node's fragments in parallel + def _process_stitch_node(item): + new_fragment_id, fragment_ids_to_fetch = item + + old_fragments = [] + missing_fragments = False + for fragment_id in fragment_ids_to_fetch: + if fragment_id in fragment_map: + old_frag = fragment_map[fragment_id] + new_old_frag = { + "num_vertices": len(old_frag.vertices), + "vertices": old_frag.vertices, + "faces": old_frag.faces.reshape(-1), + "encoding_options": old_frag.encoding_options, + "encoding_type": "draco", + } + wrapper_object = { + "mesh": new_old_frag, + "node_id": np.uint64(old_frag.segid), + } + old_fragments.append(wrapper_object) + elif cg.get_chunk_layer(np.uint64(fragment_id)) > 2: + missing_fragments = True + return ( + new_fragment_id, + None, + None, + f"{fragment_id} missing for {new_fragment_id}", + ) + + if len(old_fragments) == 0 or missing_fragments: + return (new_fragment_id, None, None, f"No meshes for {new_fragment_id}") + + draco_encoding_options = None + for old_fragment in old_fragments: + if draco_encoding_options is None: + draco_encoding_options = ( + transform_draco_fragment_and_return_encoding_options( + cg, old_fragment, layer, mip, chunk_id + ) + ) + else: + transform_draco_fragment_and_return_encoding_options( + cg, old_fragment, layer, mip, chunk_id + ) + + new_fragment = merge_draco_meshes_across_boundaries( + cg, old_fragments, chunk_id, mip, high_padding + ) + + try: + new_fragment_b = DracoPy.encode_mesh_to_buffer( + new_fragment["vertices"], + new_fragment["faces"], + **draco_encoding_options, + ) + except: + return ( + new_fragment_id, + None, + None, + f'Bad mesh created for {new_fragment_id}: {len(new_fragment["vertices"])} ' + f'vertices, {len(new_fragment["faces"])} faces', + ) + + fragment_name = None + if WRITING_TO_CLOUD: + fragment_name = meshgen_utils.get_chunk_bbox_str(cg, new_fragment_id) + fragment_name = f"{new_fragment_id}:0:{fragment_name}" + thread_cf = CloudFiles(cv_unsharded_mesh_path) + thread_cf.put( + fragment_name, + new_fragment_b, + content_type="application/octet-stream", + compress=False, + cache_control="public", + ) + + return (new_fragment_id, fragment_name, new_fragment_b, None) + + fragments_d = {} + with ThreadPoolExecutor() as executor: + for r in executor.map(_process_stitch_node, multi_child_nodes.items()): + new_fragment_id, fragment_name, _, error = r + if error is not None: + result.append(error) + elif fragment_name is not None: + fragments_d[new_fragment_id] = fragment_name + + manifest_cache = ManifestCache(cg.graph_id, initial=False) + manifest_cache.set_fragments(fragments_d) + + if PRINT_FOR_DEBUGGING: + print(", ".join(str(x) for x in result)) + return ", ".join(str(x) for x in result) diff --git a/pychunkedgraph/meshing/meshgen_utils.py b/pychunkedgraph/meshing/meshgen_utils.py index 711c09322..6e18867e7 100644 --- a/pychunkedgraph/meshing/meshgen_utils.py +++ b/pychunkedgraph/meshing/meshgen_utils.py @@ -1,3 +1,4 @@ +import os import re import multiprocessing as mp from time import time @@ -8,13 +9,30 @@ from functools import lru_cache import numpy as np +import pytz +from scipy import ndimage from cloudvolume import CloudVolume from cloudvolume.lib import Vec -from multiwrapper import multiprocessing_utils as mu +import DracoPy +import zmesh +import fastremap from pychunkedgraph.graph.utils.basetypes import NODE_ID # noqa +from pychunkedgraph.graph import attributes # noqa from ..graph.types import empty_1d +UTC = pytz.UTC + +# Change below to true if debugging and want to see results in stdout +PRINT_FOR_DEBUGGING = False +# Change below to false if debugging and do not need to write to cloud (warning: do not deploy w/ below set to false) +WRITING_TO_CLOUD = True + +REDIS_HOST = os.environ.get("REDIS_SERVICE_HOST", "localhost") +REDIS_PORT = os.environ.get("REDIS_SERVICE_PORT", "6379") +REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "dev") +REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0" + def str_to_slice(slice_str: str): match = re.match(r"(\d+)-(\d+)_(\d+)-(\d+)_(\d+)-(\d+)", slice_str) @@ -155,7 +173,7 @@ def get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx=1): mip_diff = mip - cg.meta.cv.mip mip_chunk_size = np.array(cg.meta.graph_config.CHUNK_SIZE, dtype=int) / np.array( - [2 ** mip_diff, 2 ** mip_diff, 1] + [2**mip_diff, 2**mip_diff, 1] ) mip_chunk_size = mip_chunk_size.astype(int) @@ -177,3 +195,422 @@ def get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx=1): ].squeeze() return ws_seg + + +def decode_draco_mesh_buffer(fragment): + try: + mesh_object = DracoPy.decode_buffer_to_mesh(fragment) + vertices = np.array(mesh_object.points) + faces = np.array(mesh_object.faces) + except ValueError as exc: + raise ValueError("Not a valid draco mesh") from exc + + num_vertices = len(vertices) + + # For now, just return this dict until we figure out + # how exactly to deal with Draco's lossiness/duplicate vertices + return { + "num_vertices": num_vertices, + "vertices": vertices, + "faces": faces, + "encoding_options": mesh_object.encoding_options, + "encoding_type": "draco", + } + + +def remap_seg_using_unsafe_dict(seg, unsafe_dict): + for unsafe_root_id in unsafe_dict.keys(): + bin_seg = seg == unsafe_root_id + + if np.sum(bin_seg) == 0: + continue + + cc_seg, n_cc = ndimage.label(bin_seg) + for i_cc in range(1, n_cc + 1): + bin_cc_seg = cc_seg == i_cc + + overlaps = [] + overlaps.extend(np.unique(seg[-2, :, :][bin_cc_seg[-1, :, :]])) + overlaps.extend(np.unique(seg[:, -2, :][bin_cc_seg[:, -1, :]])) + overlaps.extend(np.unique(seg[:, :, -2][bin_cc_seg[:, :, -1]])) + overlaps = np.unique(overlaps) + + linked_l2_ids = overlaps[np.isin(overlaps, unsafe_dict[unsafe_root_id])] + + if len(linked_l2_ids) == 0: + seg[bin_cc_seg] = 0 + else: + seg[bin_cc_seg] = linked_l2_ids[0] + + return seg + + +def calculate_stop_layer(cg, chunk_id): + chunk_coords = cg.get_chunk_coordinates(chunk_id) + chunk_layer = cg.get_chunk_layer(chunk_id) + + neigh_chunk_ids = [] + neigh_parent_chunk_ids = [] + + # Collect neighboring chunks and their parent chunk ids + # We only need to know about the parent chunk ids to figure the lowest + # common chunk + # Notice that the first neigh_chunk_id is equal to `chunk_id`. + for x in range(chunk_coords[0], chunk_coords[0] + 2): + for y in range(chunk_coords[1], chunk_coords[1] + 2): + for z in range(chunk_coords[2], chunk_coords[2] + 2): + # Chunk id + try: + neigh_chunk_id = cg.get_chunk_id(x=x, y=y, z=z, layer=chunk_layer) + # Get parent chunk ids + parent_chunk_ids = cg.get_parent_chunk_ids(neigh_chunk_id) + neigh_chunk_ids.append(neigh_chunk_id) + neigh_parent_chunk_ids.append(parent_chunk_ids) + except: + # cg.get_parent_chunk_id can fail if neigh_chunk_id is outside the dataset + # (only happens when cg.meta.bitmasks[chunk_layer+1] == log(max(x,y,z)), + # so only for specific datasets in which the # of chunks in the widest dimension + # just happens to be a power of two) + pass + + # Find lowest common chunk + neigh_parent_chunk_ids = np.array(neigh_parent_chunk_ids) + layer_agreement = np.all( + (neigh_parent_chunk_ids - neigh_parent_chunk_ids[0]) == 0, axis=0 + ) + stop_layer = np.where(layer_agreement)[0][0] + chunk_layer + + return stop_layer, neigh_chunk_ids + + +def get_meshing_necessities_from_graph(cg, chunk_id: np.uint64, mip: int): + """Given a chunkedgraph, chunk_id, and mip level, return the voxel dimensions of the chunk to be meshed (mesh_block_shape) + and the chunk origin in the dataset in nm. + + :param cg: chunkedgraph instance + :param chunk_id: uint64 + :param mip: int + """ + layer = cg.get_chunk_layer(chunk_id) + cx, cy, cz = cg.get_chunk_coordinates(chunk_id) + mesh_block_shape = get_mesh_block_shape_for_mip(cg, layer, mip) + voxel_resolution = cg.meta.cv.mip_resolution(mip) + chunk_offset = ( + (cx, cy, cz) * mesh_block_shape + cg.meta.cv.mip_voxel_offset(mip) + ) * voxel_resolution + return layer, mesh_block_shape, chunk_offset + + +def calculate_quantization_bits_and_range( + min_quantization_range, max_draco_bin_size, draco_quantization_bits=None +): + if draco_quantization_bits is None: + draco_quantization_bits = np.ceil( + np.log2(min_quantization_range / max_draco_bin_size + 1) + ) + num_draco_bins = 2**draco_quantization_bits - 1 + draco_bin_size = np.ceil(min_quantization_range / num_draco_bins) + draco_quantization_range = draco_bin_size * num_draco_bins + if draco_quantization_range < min_quantization_range + draco_bin_size: + if draco_bin_size == max_draco_bin_size: + return calculate_quantization_bits_and_range( + min_quantization_range, max_draco_bin_size, draco_quantization_bits + 1 + ) + else: + draco_bin_size = draco_bin_size + 1 + draco_quantization_range = draco_quantization_range + num_draco_bins + return draco_quantization_bits, draco_quantization_range, draco_bin_size + + +def get_draco_encoding_settings_for_chunk( + cg, chunk_id: np.uint64, mip: int = 2, high_padding: int = 1 +): + """Calculate the proper draco encoding settings for a chunk to ensure proper stitching is possible + on the layer above. For details about how and why we do this, please see the meshing Readme + + :param cg: chunkedgraph instance + :param chunk_id: uint64 + :param mip: int + :param high_padding: int + """ + _, mesh_block_shape, chunk_offset = get_meshing_necessities_from_graph( + cg, chunk_id, mip + ) + segmentation_resolution = cg.meta.cv.mip_resolution(mip) + min_quantization_range = max( + (mesh_block_shape + high_padding) * segmentation_resolution + ) + max_draco_bin_size = np.floor(min(segmentation_resolution) / np.sqrt(2)) + ( + draco_quantization_bits, + draco_quantization_range, + draco_bin_size, + ) = calculate_quantization_bits_and_range( + min_quantization_range, max_draco_bin_size + ) + draco_quantization_origin = chunk_offset - (chunk_offset % draco_bin_size) + return { + "quantization_bits": draco_quantization_bits, + "compression_level": 1, + "quantization_range": draco_quantization_range, + "quantization_origin": draco_quantization_origin, + "create_metadata": True, + } + + +def get_next_layer_draco_encoding_settings( + cg, prev_layer_encoding_settings, next_layer_chunk_id, mip +): + old_draco_bin_size = prev_layer_encoding_settings["quantization_range"] // ( + 2 ** prev_layer_encoding_settings["quantization_bits"] - 1 + ) + _, mesh_block_shape, chunk_offset = get_meshing_necessities_from_graph( + cg, next_layer_chunk_id, mip + ) + segmentation_resolution = cg.meta.cv.mip_resolution(mip) + min_quantization_range = ( + max(mesh_block_shape * segmentation_resolution) + 2 * old_draco_bin_size + ) + max_draco_bin_size = np.floor(min(segmentation_resolution) / np.sqrt(2)) + ( + draco_quantization_bits, + draco_quantization_range, + draco_bin_size, + ) = calculate_quantization_bits_and_range( + min_quantization_range, max_draco_bin_size + ) + draco_quantization_origin = ( + chunk_offset + - old_draco_bin_size + - ((chunk_offset - old_draco_bin_size) % draco_bin_size) + ) + return { + "quantization_bits": draco_quantization_bits, + "compression_level": 1, + "quantization_range": draco_quantization_range, + "quantization_origin": draco_quantization_origin, + "create_metadata": True, + } + + +def transform_draco_vertices(mesh, encoding_settings): + vertices = np.reshape(mesh["vertices"], (mesh["num_vertices"] * 3,)) + max_quantized_value = 2 ** encoding_settings["quantization_bits"] - 1 + draco_bin_size = encoding_settings["quantization_range"] / max_quantized_value + assert np.equal(np.mod(draco_bin_size, 1), 0) + assert np.equal(np.mod(encoding_settings["quantization_range"], 1), 0) + assert np.equal(np.mod(encoding_settings["quantization_origin"], 1), 0).all() + for coord in range(3): + vertices[coord::3] -= encoding_settings["quantization_origin"][coord] + vertices /= draco_bin_size + vertices += 0.5 + np.floor(vertices, out=vertices) + vertices *= draco_bin_size + for coord in range(3): + vertices[coord::3] += encoding_settings["quantization_origin"][coord] + + +def transform_draco_fragment_and_return_encoding_options( + cg, fragment, layer, mip, chunk_id +): + fragment_encoding_options = fragment["mesh"]["encoding_options"] + if fragment_encoding_options is None: + raise ValueError("Draco fragment has no encoding options") + cur_encoding_settings = { + "quantization_range": fragment_encoding_options.quantization_range, + "quantization_bits": fragment_encoding_options.quantization_bits, + } + node_id = fragment["node_id"] + parent_chunk_ids = cg.get_parent_chunk_ids(node_id) + fragment_layer = cg.get_chunk_layer(node_id) + if fragment_layer >= layer: + raise ValueError( + f"Node {node_id} somehow has greater or equal layer than chunk {chunk_id}" + ) + assert len(parent_chunk_ids) > layer - fragment_layer + for next_layer in range(fragment_layer + 1, layer + 1): + next_layer_chunk_id = parent_chunk_ids[next_layer - fragment_layer] + next_encoding_settings = get_next_layer_draco_encoding_settings( + cg, cur_encoding_settings, next_layer_chunk_id, mip + ) + if next_layer < layer: + transform_draco_vertices(fragment["mesh"], next_encoding_settings) + cur_encoding_settings = next_encoding_settings + return cur_encoding_settings + + +def merge_draco_meshes_across_boundaries( + cg, fragments, chunk_id, mip, high_padding, return_zmesh_object=False +): + """ + Merge a list of draco mesh fragments, removing duplicate vertices that lie + on the chunk boundary where the meshes meet. + """ + vertexct = np.zeros(len(fragments) + 1, np.uint32) + vertexct[1:] = np.cumsum([x["mesh"]["num_vertices"] for x in fragments]) + vertices = np.concatenate([x["mesh"]["vertices"] for x in fragments]) + faces = np.concatenate( + [mesh["mesh"]["faces"] + vertexct[i] for i, mesh in enumerate(fragments)] + ) + del fragments + + if vertexct[-1] > 0: + chunk_coords = cg.get_chunk_coordinates(chunk_id) + coords_bottom_corner_child_chunk = chunk_coords * 2 + 1 + child_chunk_id = cg.get_chunk_id( + None, cg.get_chunk_layer(chunk_id) - 1, *coords_bottom_corner_child_chunk + ) + _, _, child_chunk_offset = get_meshing_necessities_from_graph( + cg, child_chunk_id, mip + ) + # Get the draco encoding settings for the + # child chunk in the "bottom corner" of the chunk_id chunk + draco_encoding_settings_smaller_chunk = get_draco_encoding_settings_for_chunk( + cg, child_chunk_id, mip=mip, high_padding=high_padding + ) + draco_bin_size = draco_encoding_settings_smaller_chunk["quantization_range"] / ( + 2 ** draco_encoding_settings_smaller_chunk["quantization_bits"] - 1 + ) + # Calculate which draco bin the child chunk's boundaries + # were placed into (for each x,y,z of boundary) + chunk_boundary_bin_index = np.floor( + ( + child_chunk_offset + - draco_encoding_settings_smaller_chunk["quantization_origin"] + ) + / draco_bin_size + + np.float32(0.5) + ) + # Now we can determine where the three planes of the quantized chunk boundary are + quantized_chunk_boundary = ( + draco_encoding_settings_smaller_chunk["quantization_origin"] + + chunk_boundary_bin_index * draco_bin_size + ) + # Separate the vertices that are on the quantized chunk boundary from those that aren't + are_chunk_aligned = (vertices == quantized_chunk_boundary).any(axis=1) + vertices = np.hstack((vertices, np.arange(vertexct[-1])[:, np.newaxis])) + chunk_aligned = vertices[are_chunk_aligned] + not_chunk_aligned = vertices[~are_chunk_aligned] + del vertices + del are_chunk_aligned + faces_remapping = {} + # Those that are not simply pass through (simple remap) + if len(not_chunk_aligned) > 0: + not_chunk_aligned_remap = dict( + zip( + not_chunk_aligned[:, 3].astype(np.uint32), + np.arange(len(not_chunk_aligned), dtype=np.uint32), + ) + ) + faces_remapping.update(not_chunk_aligned_remap) + # Those that are on the boundary we remove duplicates + if len(chunk_aligned) > 0: + unique_chunk_aligned, inverse_to_chunk_aligned = np.unique( + chunk_aligned[:, 0:3], return_inverse=True, axis=0 + ) + chunk_aligned_remap = dict( + zip( + chunk_aligned[:, 3].astype(np.uint32), + np.uint32(len(not_chunk_aligned)) + + inverse_to_chunk_aligned.astype(np.uint32), + ) + ) + faces_remapping.update(chunk_aligned_remap) + vertices = np.concatenate((not_chunk_aligned[:, 0:3], unique_chunk_aligned)) + else: + vertices = not_chunk_aligned[:, 0:3] + # Remap the faces to their new vertex indices + fastremap.remap(faces, faces_remapping, in_place=True) + + if return_zmesh_object: + return zmesh.Mesh(vertices[:, 0:3], faces.reshape(-1, 3), None) + + return { + "num_vertices": np.uint32(len(vertices)), + "vertices": vertices[:, 0:3].reshape(-1), + "faces": faces, + } + + +def black_out_dust_from_segmentation(seg, dust_threshold): + """Black out (set to 0) IDs in segmentation not on the segmentation + border that have less voxels than dust_threshold + + :param seg: 3D segmentation (usually uint64) + :param dust_threshold: int + :return: + """ + seg_ids, voxel_count = np.unique(seg, return_counts=True) + boundary = np.concatenate( + ( + seg[-2, :, :], + seg[-1, :, :], + seg[:, -2, :], + seg[:, -1, :], + seg[:, :, -2], + seg[:, :, -1], + ), + axis=None, + ) + seg_ids_on_boundary = np.unique(boundary) + below_threshold = voxel_count < int(dust_threshold) + not_on_boundary = ~np.isin(seg_ids, seg_ids_on_boundary) + dust_segids = seg_ids[below_threshold & not_on_boundary] + seg = fastremap.mask(seg, dust_segids, in_place=True) + + +def get_multi_child_nodes(cg, chunk_id, node_id_subset=None, chunk_bbox_string=False): + if node_id_subset is None: + range_read = cg.range_read_chunk( + chunk_id, properties=attributes.Hierarchy.Child + ) + else: + range_read = cg.client.read_nodes( + node_ids=node_id_subset, properties=attributes.Hierarchy.Child + ) + + node_ids = np.array(list(range_read.keys())) + node_rows = np.array(list(range_read.values())) + child_fragments = np.array( + [ + fragment.value + for child_fragments_for_node in node_rows + for fragment in child_fragments_for_node + ], + dtype=object, + ) + # Filter out node ids that do not have roots (caused by failed ingest tasks) + root_ids = cg.get_roots(node_ids, fail_to_zero=True) + # Only keep nodes with more than one child + multi_child_mask = np.array( + [len(fragments) > 1 for fragments in child_fragments], dtype=bool + ) + root_id_mask = np.array([root_id != 0 for root_id in root_ids], dtype=bool) + multi_child_node_ids = node_ids[multi_child_mask & root_id_mask] + multi_child_children_ids = child_fragments[multi_child_mask & root_id_mask] + # Store how many children each node has, because we will retrieve all children at once + multi_child_num_children = [len(children) for children in multi_child_children_ids] + child_fragments_flat = np.array( + [ + frag + for children_of_node in multi_child_children_ids + for frag in children_of_node + ] + ) + multi_child_descendants = get_downstream_multi_child_nodes(cg, child_fragments_flat) + start_index = 0 + multi_child_nodes = {} + for i in range(len(multi_child_node_ids)): + end_index = start_index + multi_child_num_children[i] + descendents_for_current_node = multi_child_descendants[start_index:end_index] + node_id = multi_child_node_ids[i] + if chunk_bbox_string: + multi_child_nodes[f"{node_id}:0:{get_chunk_bbox_str(cg, node_id)}"] = [ + f"{c}:0:{get_chunk_bbox_str(cg, c)}" + for c in descendents_for_current_node + ] + else: + multi_child_nodes[multi_child_node_ids[i]] = descendents_for_current_node + start_index = end_index + + return multi_child_nodes, multi_child_descendants diff --git a/pychunkedgraph/repair/edits.py b/pychunkedgraph/repair/edits.py index cb403a380..849b17e08 100644 --- a/pychunkedgraph/repair/edits.py +++ b/pychunkedgraph/repair/edits.py @@ -56,8 +56,6 @@ def repair_operation( op_ids_to_retry.append(locked_op) print(f"{node_id} indefinitely locked by op {locked_op}") print(f"total to retry: {len(op_ids_to_retry)}") - - logs = cg.client.read_log_entries(op_ids_to_retry) - for op_id, log in logs.items(): + for op_id in op_ids_to_retry: print(f"repairing {op_id}") - repair_operation(cg, log, op_id) + repair_operation(cg, op_id) diff --git a/pychunkedgraph/repair/fake_edges.py b/pychunkedgraph/repair/fake_edges.py deleted file mode 100644 index b58b93fb9..000000000 --- a/pychunkedgraph/repair/fake_edges.py +++ /dev/null @@ -1,78 +0,0 @@ -# pylint: disable=protected-access,missing-function-docstring,invalid-name,wrong-import-position - -""" -Replay merge operations to check if fake edges need to be added. -""" - -from datetime import datetime -from datetime import timedelta -from os import environ -from typing import Optional - -environ["BIGTABLE_PROJECT"] = "<>" -environ["BIGTABLE_INSTANCE"] = "<>" -environ["GOOGLE_APPLICATION_CREDENTIALS"] = "" - -from pychunkedgraph.graph import edits -from pychunkedgraph.graph import ChunkedGraph -from pychunkedgraph.graph.operation import GraphEditOperation -from pychunkedgraph.graph.operation import MergeOperation -from pychunkedgraph.graph.utils.generic import get_bounding_box as get_bbox - - -def _add_fake_edges(cg: ChunkedGraph, operation_id: int, operation_log: dict) -> bool: - operation = GraphEditOperation.from_operation_id( - cg, operation_id, multicut_as_split=False - ) - - if not isinstance(operation, MergeOperation): - return False - - ts = operation_log["timestamp"] - parent_ts = ts - timedelta(seconds=0.1) - override_ts = (ts + timedelta(microseconds=(ts.microsecond % 1000) + 10),) - - root_ids = set( - cg.get_roots( - operation.added_edges.ravel(), assert_roots=True, time_stamp=parent_ts - ) - ) - - bbox = get_bbox( - operation.source_coords, operation.sink_coords, operation.bbox_offset - ) - edges = cg.get_subgraph( - root_ids, - bbox=bbox, - bbox_is_coordinate=True, - edges_only=True, - ) - - inactive_edges = edits.merge_preprocess( - cg, - subgraph_edges=edges, - supervoxels=operation.added_edges.ravel(), - parent_ts=parent_ts, - ) - - _, fake_edge_rows = edits.check_fake_edges( - cg, - atomic_edges=operation.added_edges, - inactive_edges=inactive_edges, - time_stamp=override_ts, - parent_ts=parent_ts, - ) - - cg.client.write(fake_edge_rows) - return len(fake_edge_rows) > 0 - - -def add_fake_edges( - graph_id: str, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, -): - cg = ChunkedGraph(graph_id=graph_id) - logs = cg.client.read_log_entries(start_time=start_time, end_time=end_time) - for _id, _log in logs.items(): - _add_fake_edges(cg, _id, _log) diff --git a/pychunkedgraph/tests/conftest.py b/pychunkedgraph/tests/conftest.py new file mode 100644 index 000000000..a502ba505 --- /dev/null +++ b/pychunkedgraph/tests/conftest.py @@ -0,0 +1,283 @@ +import atexit +import os +import signal +import subprocess +from functools import partial +from datetime import timedelta + +import pytest + +# Skip the old monolithic test file if it still exists (e.g., during branch transitions) +collect_ignore = ["test_uncategorized.py"] +import numpy as np +from google.auth import credentials +from google.cloud import bigtable + +from ..ingest.utils import bootstrap +from ..graph.edges import Edges +from ..graph.chunkedgraph import ChunkedGraph +from ..ingest.create.parent_layer import add_parent_chunk + +from .helpers import ( + CloudVolumeMock, + create_chunk, + to_label, + get_layer_chunk_bounds, +) + +_emulator_proc = None +_emulator_cleaned = False + + +def _cleanup_emulator(): + global _emulator_cleaned + if _emulator_cleaned or _emulator_proc is None: + return + _emulator_cleaned = True + try: + pgid = os.getpgid(_emulator_proc.pid) + os.killpg(pgid, signal.SIGTERM) + try: + _emulator_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + os.killpg(pgid, signal.SIGKILL) + _emulator_proc.wait(timeout=5) + except (ProcessLookupError, OSError, ChildProcessError): + pass + # Hard kill cbtemulator in case it survived the process group signal + subprocess.run(["pkill", "-9", "cbtemulator"], stderr=subprocess.DEVNULL) + + +def setup_emulator_env(): + bt_env_init = subprocess.run( + ["gcloud", "beta", "emulators", "bigtable", "env-init"], stdout=subprocess.PIPE + ) + os.environ["BIGTABLE_EMULATOR_HOST"] = ( + bt_env_init.stdout.decode("utf-8").strip().split("=")[-1] + ) + + c = bigtable.Client( + project="IGNORE_ENVIRONMENT_PROJECT", + credentials=credentials.AnonymousCredentials(), + admin=True, + ) + t = c.instance("emulated_instance").table("emulated_table") + + try: + t.create() + return True + except Exception as err: + print("Bigtable Emulator not yet ready: %s" % err) + return False + + +@pytest.fixture(scope="session", autouse=True) +def bigtable_emulator(request): + global _emulator_proc, _emulator_cleaned + from time import sleep + + _emulator_cleaned = False + + # Kill any leftover emulator processes from previous runs + subprocess.run(["pkill", "-9", "cbtemulator"], stderr=subprocess.DEVNULL) + + # Start Emulator + _emulator_proc = subprocess.Popen( + [ + "gcloud", + "beta", + "emulators", + "bigtable", + "start", + "--host-port=localhost:8539", + ], + preexec_fn=os.setsid, + stdout=subprocess.PIPE, + ) + + # Register atexit handler as safety net for abnormal exits + atexit.register(_cleanup_emulator) + + # Wait for Emulator to start up + print("Waiting for BigTables Emulator to start up...", end="") + retries = 5 + while retries > 0: + if setup_emulator_env() is True: + break + else: + retries -= 1 + sleep(5) + + if retries == 0: + print( + "\nCouldn't start Bigtable Emulator. Make sure it is installed correctly." + ) + _cleanup_emulator() + exit(1) + + request.addfinalizer(_cleanup_emulator) + + +@pytest.fixture(scope="function") +def gen_graph(request): + def _cgraph(request, n_layers=10, atomic_chunk_bounds: np.ndarray = np.array([])): + config = { + "data_source": { + "EDGES": "gs://chunked-graph/minnie65_0/edges", + "COMPONENTS": "gs://chunked-graph/minnie65_0/components", + "WATERSHED": "gs://microns-seunglab/minnie65/ws_minnie65_0", + }, + "graph_config": { + "CHUNK_SIZE": [512, 512, 64], + "FANOUT": 2, + "SPATIAL_BITS": 10, + "ID_PREFIX": "", + "ROOT_LOCK_EXPIRY": timedelta(seconds=5), + }, + "backend_client": { + "TYPE": "bigtable", + "CONFIG": { + "ADMIN": True, + "READ_ONLY": False, + "PROJECT": "IGNORE_ENVIRONMENT_PROJECT", + "INSTANCE": "emulated_instance", + "CREDENTIALS": credentials.AnonymousCredentials(), + "MAX_ROW_KEY_COUNT": 1000, + }, + }, + "ingest_config": {}, + } + + meta, _, client_info = bootstrap("test", config=config) + graph = ChunkedGraph(graph_id="test", meta=meta, client_info=client_info) + graph.mock_edges = Edges([], []) + graph.meta._ws_cv = CloudVolumeMock() + graph.meta.layer_count = n_layers + graph.meta.layer_chunk_bounds = get_layer_chunk_bounds( + n_layers, atomic_chunk_bounds=atomic_chunk_bounds + ) + + graph.create() + + # setup Chunked Graph - Finalizer + def fin(): + graph.client._table.delete() + + request.addfinalizer(fin) + return graph + + return partial(_cgraph, request) + + +@pytest.fixture(scope="function") +def gen_graph_with_edges(request, tmp_path): + """Like gen_graph but with real edge/component I/O via local filesystem (file:// protocol).""" + + def _cgraph(request, n_layers=10, atomic_chunk_bounds: np.ndarray = np.array([])): + edges_dir = f"file://{tmp_path}/edges" + components_dir = f"file://{tmp_path}/components" + config = { + "data_source": { + "EDGES": edges_dir, + "COMPONENTS": components_dir, + "WATERSHED": "gs://microns-seunglab/minnie65/ws_minnie65_0", + }, + "graph_config": { + "CHUNK_SIZE": [512, 512, 64], + "FANOUT": 2, + "SPATIAL_BITS": 10, + "ID_PREFIX": "", + "ROOT_LOCK_EXPIRY": timedelta(seconds=5), + }, + "backend_client": { + "TYPE": "bigtable", + "CONFIG": { + "ADMIN": True, + "READ_ONLY": False, + "PROJECT": "IGNORE_ENVIRONMENT_PROJECT", + "INSTANCE": "emulated_instance", + "CREDENTIALS": credentials.AnonymousCredentials(), + "MAX_ROW_KEY_COUNT": 1000, + }, + }, + "ingest_config": {}, + } + + meta, _, client_info = bootstrap("test", config=config) + graph = ChunkedGraph(graph_id="test", meta=meta, client_info=client_info) + # No mock_edges - use real I/O via file:// protocol + graph.meta._ws_cv = CloudVolumeMock() + graph.meta.layer_count = n_layers + graph.meta.layer_chunk_bounds = get_layer_chunk_bounds( + n_layers, atomic_chunk_bounds=atomic_chunk_bounds + ) + + graph.create() + + def fin(): + graph.client._table.delete() + + request.addfinalizer(fin) + return graph + + return partial(_cgraph, request) + + +@pytest.fixture(scope="function") +def gen_graph_simplequerytest(request, gen_graph): + """ + ┌─────┬─────┬─────┐ + │ A¹ │ B¹ │ C¹ │ + │ 1 │ 3━2━┿━━4 │ + │ │ │ │ + └─────┴─────┴─────┘ + """ + from math import inf + + graph = gen_graph(n_layers=4) + + # Chunk A + create_chunk(graph, vertices=[to_label(graph, 1, 0, 0, 0, 0)], edges=[]) + + # Chunk B + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 1)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 1), 0.5), + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 2, 0, 0, 0), inf), + ], + ) + + # Chunk C + create_chunk( + graph, + vertices=[to_label(graph, 1, 2, 0, 0, 0)], + edges=[(to_label(graph, 1, 2, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf)], + ) + + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + return graph + + +@pytest.fixture(scope="session") +def sv_data(): + test_data_dir = "pychunkedgraph/tests/data" + edges_file = f"{test_data_dir}/sv_edges.npy" + sv_edges = np.load(edges_file) + + source_file = f"{test_data_dir}/sv_sources.npy" + sv_sources = np.load(source_file) + + sinks_file = f"{test_data_dir}/sv_sinks.npy" + sv_sinks = np.load(sinks_file) + + affinity_file = f"{test_data_dir}/sv_affinity.npy" + sv_affinity = np.load(affinity_file) + + area_file = f"{test_data_dir}/sv_area.npy" + sv_area = np.load(area_file) + yield (sv_edges, sv_sources, sv_sinks, sv_affinity, sv_area) diff --git a/pychunkedgraph/tests/helpers.py b/pychunkedgraph/tests/helpers.py index de5314422..335b44fd0 100644 --- a/pychunkedgraph/tests/helpers.py +++ b/pychunkedgraph/tests/helpers.py @@ -1,25 +1,11 @@ -import os -import subprocess -from math import inf -from time import sleep -from signal import SIGTERM from functools import reduce -from functools import partial -from datetime import timedelta - -import pytest import numpy as np -from google.auth import credentials -from google.cloud import bigtable -from ..ingest.utils import bootstrap -from ..ingest.create.atomic_layer import add_atomic_edges from ..graph.edges import Edges from ..graph.edges import EDGE_TYPES from ..graph.utils import basetypes -from ..graph.chunkedgraph import ChunkedGraph -from ..ingest.create.abstract_layers import add_layer +from ..ingest.create.atomic_layer import add_atomic_chunk class CloudVolumeBounds(object): @@ -43,162 +29,6 @@ def __init__(self): self.bounds = CloudVolumeBounds() -def setup_emulator_env(): - bt_env_init = subprocess.run( - ["gcloud", "beta", "emulators", "bigtable", "env-init"], stdout=subprocess.PIPE - ) - os.environ["BIGTABLE_EMULATOR_HOST"] = ( - bt_env_init.stdout.decode("utf-8").strip().split("=")[-1] - ) - - c = bigtable.Client( - project="IGNORE_ENVIRONMENT_PROJECT", - credentials=credentials.AnonymousCredentials(), - admin=True, - ) - t = c.instance("emulated_instance").table("emulated_table") - - try: - t.create() - return True - except Exception as err: - print("Bigtable Emulator not yet ready: %s" % err) - return False - - -@pytest.fixture(scope="session", autouse=True) -def bigtable_emulator(request): - # Start Emulator - bigtable_emulator = subprocess.Popen( - [ - "gcloud", - "beta", - "emulators", - "bigtable", - "start", - "--host-port=localhost:8539", - ], - preexec_fn=os.setsid, - stdout=subprocess.PIPE, - ) - - # Wait for Emulator to start up - print("Waiting for BigTables Emulator to start up...", end="") - retries = 5 - while retries > 0: - if setup_emulator_env() is True: - break - else: - retries -= 1 - sleep(5) - - if retries == 0: - print( - "\nCouldn't start Bigtable Emulator. Make sure it is installed correctly." - ) - exit(1) - - # Setup Emulator-Finalizer - def fin(): - os.killpg(os.getpgid(bigtable_emulator.pid), SIGTERM) - bigtable_emulator.wait() - - request.addfinalizer(fin) - - -@pytest.fixture(scope="function") -def gen_graph(request): - def _cgraph(request, n_layers=10, atomic_chunk_bounds: np.ndarray = np.array([])): - config = { - "data_source": { - "EDGES": "gs://chunked-graph/minnie65_0/edges", - "COMPONENTS": "gs://chunked-graph/minnie65_0/components", - "WATERSHED": "gs://microns-seunglab/minnie65/ws_minnie65_0", - }, - "graph_config": { - "CHUNK_SIZE": [512, 512, 64], - "FANOUT": 2, - "SPATIAL_BITS": 10, - "ID_PREFIX": "", - "ROOT_LOCK_EXPIRY": timedelta(seconds=5) - }, - "backend_client": { - "TYPE": "bigtable", - "CONFIG": { - "ADMIN": True, - "READ_ONLY": False, - "PROJECT": "IGNORE_ENVIRONMENT_PROJECT", - "INSTANCE": "emulated_instance", - "CREDENTIALS": credentials.AnonymousCredentials(), - "MAX_ROW_KEY_COUNT": 1000 - }, - }, - "ingest_config": {}, - } - - meta, _, client_info = bootstrap("test", config=config) - graph = ChunkedGraph(graph_id="test", meta=meta, - client_info=client_info) - graph.mock_edges = Edges([], []) - graph.meta._ws_cv = CloudVolumeMock() - graph.meta.layer_count = n_layers - graph.meta.layer_chunk_bounds = get_layer_chunk_bounds( - n_layers, atomic_chunk_bounds=atomic_chunk_bounds - ) - - graph.create() - - # setup Chunked Graph - Finalizer - def fin(): - graph.client._table.delete() - - request.addfinalizer(fin) - return graph - - return partial(_cgraph, request) - - -@pytest.fixture(scope="function") -def gen_graph_simplequerytest(request, gen_graph): - """ - ┌─────┬─────┬─────┐ - │ A¹ │ B¹ │ C¹ │ - │ 1 │ 3━2━┿━━4 │ - │ │ │ │ - └─────┴─────┴─────┘ - """ - - graph = gen_graph(n_layers=4) - - # Chunk A - create_chunk(graph, vertices=[to_label(graph, 1, 0, 0, 0, 0)], edges=[]) - - # Chunk B - create_chunk( - graph, - vertices=[to_label(graph, 1, 1, 0, 0, 0), - to_label(graph, 1, 1, 0, 0, 1)], - edges=[ - (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 1), 0.5), - (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 2, 0, 0, 0), inf), - ], - ) - - # Chunk C - create_chunk( - graph, - vertices=[to_label(graph, 1, 2, 0, 0, 0)], - edges=[(to_label(graph, 1, 2, 0, 0, 0), - to_label(graph, 1, 1, 0, 0, 0), inf)], - ) - - add_layer(graph, 3, [0, 0, 0], n_threads=1) - add_layer(graph, 3, [1, 0, 0], n_threads=1) - add_layer(graph, 4, [0, 0, 0], n_threads=1) - - return graph - - def create_chunk(cg, vertices=None, edges=None, timestamp=None): """ Helper function to add vertices and edges to the chunkedgraph - no safety checks! @@ -206,8 +36,7 @@ def create_chunk(cg, vertices=None, edges=None, timestamp=None): edges = edges if edges else [] vertices = vertices if vertices else [] vertices = np.unique(np.array(vertices, dtype=np.uint64)) - edges = [(np.uint64(v1), np.uint64(v2), np.float32(aff)) - for v1, v2, aff in edges] + edges = [(np.uint64(v1), np.uint64(v2), np.float32(aff)) for v1, v2, aff in edges] isolated_ids = [ x for x in vertices @@ -230,8 +59,7 @@ def create_chunk(cg, vertices=None, edges=None, timestamp=None): chunk_id = None if len(chunk_edges_active[EDGE_TYPES.in_chunk]): - chunk_id = cg.get_chunk_id( - chunk_edges_active[EDGE_TYPES.in_chunk].node_ids1[0]) + chunk_id = cg.get_chunk_id(chunk_edges_active[EDGE_TYPES.in_chunk].node_ids1[0]) elif len(vertices): chunk_id = cg.get_chunk_id(vertices[0]) @@ -257,11 +85,12 @@ def create_chunk(cg, vertices=None, edges=None, timestamp=None): cg.mock_edges += all_edges isolated_ids = np.array(isolated_ids, dtype=np.uint64) - add_atomic_edges( + add_atomic_chunk( cg, cg.get_chunk_coordinates(chunk_id), chunk_edges_active, isolated=isolated_ids, + time_stamp=timestamp, ) @@ -280,23 +109,3 @@ def get_layer_chunk_bounds( layer_bounds = atomic_chunk_bounds / (2 ** (layer - 2)) layer_bounds_d[layer] = np.ceil(layer_bounds).astype(int) return layer_bounds_d - - -@pytest.fixture(scope='session') -def sv_data(): - test_data_dir = 'pychunkedgraph/tests/data' - edges_file = f'{test_data_dir}/sv_edges.npy' - sv_edges = np.load(edges_file) - - source_file = f'{test_data_dir}/sv_sources.npy' - sv_sources = np.load(source_file) - - sinks_file = f'{test_data_dir}/sv_sinks.npy' - sv_sinks = np.load(sinks_file) - - affinity_file = f'{test_data_dir}/sv_affinity.npy' - sv_affinity = np.load(affinity_file) - - area_file = f'{test_data_dir}/sv_area.npy' - sv_area = np.load(area_file) - yield (sv_edges, sv_sources, sv_sinks, sv_affinity, sv_area) diff --git a/pychunkedgraph/tests/test_analysis_pathing.py b/pychunkedgraph/tests/test_analysis_pathing.py new file mode 100644 index 000000000..872158c6e --- /dev/null +++ b/pychunkedgraph/tests/test_analysis_pathing.py @@ -0,0 +1,558 @@ +"""Tests for pychunkedgraph.graph.analysis.pathing""" + +from datetime import datetime, timedelta, UTC +from math import inf +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from pychunkedgraph.graph.analysis.pathing import ( + get_first_shared_parent, + get_children_at_layer, + get_lvl2_edge_list, + find_l2_shortest_path, + compute_rough_coordinate_path, +) + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestGetFirstSharedParent: + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + def test_same_root(self, gen_graph): + graph = self._build_graph(gen_graph) + sv0 = to_label(graph, 1, 0, 0, 0, 0) + sv1 = to_label(graph, 1, 1, 0, 0, 0) + parent = get_first_shared_parent(graph, sv0, sv1) + assert parent is not None + # The shared parent should be an ancestor of both SVs + root = graph.get_root(sv0) + # Verify the shared parent is on the path to root + assert graph.get_root(parent) == root + + def test_different_roots_returns_none(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + # Create two disconnected chunks + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + sv0 = to_label(graph, 1, 0, 0, 0, 0) + sv1 = to_label(graph, 1, 1, 0, 0, 0) + parent = get_first_shared_parent(graph, sv0, sv1) + assert parent is None + + +class TestGetChildrenAtLayer: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + children = get_children_at_layer(graph, root, 2) + assert len(children) > 0 + for child in children: + assert graph.get_chunk_layer(child) == 2 + + def test_allow_lower_layers(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + children = get_children_at_layer(graph, root, 2, allow_lower_layers=True) + assert len(children) > 0 + + +class TestGetLvl2EdgeList: + def _build_3chunk_graph(self, gen_graph): + """Build a graph with 3 chunks A(0,0,0), B(1,0,0), C(2,0,0) connected by cross-chunk edges. + + A:sv0 -- B:sv0 -- C:sv0 + """ + graph = gen_graph(n_layers=4) + + # Chunk A: sv0 connected to B:sv0 via cross-chunk edge + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + ) + + # Chunk B: sv0 connected to A:sv0 and C:sv0 via cross-chunk edges + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 2, 0, 0, 0), inf), + ], + ) + + # Chunk C: sv0 connected to B:sv0 via cross-chunk edge + create_chunk( + graph, + vertices=[to_label(graph, 1, 2, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 2, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + ) + + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + return graph + + def test_basic(self, gen_graph): + """get_lvl2_edge_list should return edges between L2 IDs for a connected root.""" + graph = self._build_3chunk_graph(gen_graph) + + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + edges = get_lvl2_edge_list(graph, root) + + # There should be at least 2 edges: A_l2--B_l2 and B_l2--C_l2 + assert edges.shape[0] >= 2 + assert edges.shape[1] == 2 + + # All edge IDs should be L2 nodes (layer 2) + for edge in edges: + for node_id in edge: + assert graph.get_chunk_layer(node_id) == 2 + + def test_single_chunk_no_cross_edges(self, gen_graph): + """A single isolated chunk should produce no L2 edges.""" + graph = gen_graph(n_layers=4) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + edges = get_lvl2_edge_list(graph, root) + + assert edges.shape[0] == 0 + + +class TestFindL2ShortestPath: + def _build_3chunk_graph(self, gen_graph): + """Build a graph with 3 chunks A(0,0,0), B(1,0,0), C(2,0,0) connected linearly. + + A:sv0 -- B:sv0 -- C:sv0 + """ + graph = gen_graph(n_layers=4) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + ) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 2, 0, 0, 0), inf), + ], + ) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 2, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 2, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + ) + + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + return graph + + def test_path_between_endpoints(self, gen_graph): + """find_l2_shortest_path should return a path from source to target L2 IDs.""" + graph = self._build_3chunk_graph(gen_graph) + + # Get L2 parents of the supervoxels + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_c = to_label(graph, 1, 2, 0, 0, 0) + l2_a = graph.get_parent(sv_a) + l2_c = graph.get_parent(sv_c) + + path = find_l2_shortest_path(graph, l2_a, l2_c) + + assert path is not None + assert len(path) == 3 # A_l2 -> B_l2 -> C_l2 + # Path should start at source and end at target + assert path[0] == l2_a + assert path[-1] == l2_c + # All nodes in path should be layer 2 + for node_id in path: + assert graph.get_chunk_layer(node_id) == 2 + + def test_adjacent_l2_ids(self, gen_graph): + """find_l2_shortest_path between directly connected L2 IDs should return length 2 path.""" + graph = self._build_3chunk_graph(gen_graph) + + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_b = to_label(graph, 1, 1, 0, 0, 0) + l2_a = graph.get_parent(sv_a) + l2_b = graph.get_parent(sv_b) + + path = find_l2_shortest_path(graph, l2_a, l2_b) + + assert path is not None + assert len(path) == 2 + assert path[0] == l2_a + assert path[-1] == l2_b + + def test_disconnected_returns_none(self, gen_graph): + """find_l2_shortest_path should return None when L2 IDs belong to different roots.""" + graph = gen_graph(n_layers=4) + + # Create two disconnected chunks + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[], + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_b = to_label(graph, 1, 1, 0, 0, 0) + l2_a = graph.get_parent(sv_a) + l2_b = graph.get_parent(sv_b) + + path = find_l2_shortest_path(graph, l2_a, l2_b) + assert path is None + + +class TestGetChildrenAtLayerEdgeCases: + """Test get_children_at_layer with various edge cases.""" + + def test_children_at_layer_2_with_multiple_svs(self, gen_graph): + """Query children at layer 2 when root has multiple SVs in same chunk.""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + to_label(graph, 1, 0, 0, 0, 2), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 1), to_label(graph, 1, 0, 0, 0, 2), 0.5), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + children = get_children_at_layer(graph, root, 2) + assert len(children) > 0 + for child in children: + assert graph.get_chunk_layer(child) == 2 + + def test_children_at_intermediate_layer(self, gen_graph): + """Query children at layer 3 from root at layer 4.""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + children = get_children_at_layer(graph, root, 3) + assert len(children) > 0 + for child in children: + assert graph.get_chunk_layer(child) == 3 + + def test_children_allow_lower_layers_with_cross_chunk(self, gen_graph): + """Query with allow_lower_layers=True should include layer<=target.""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + # Ask for layer 3 with allow_lower_layers=True + children = get_children_at_layer(graph, root, 3, allow_lower_layers=True) + assert len(children) > 0 + for child in children: + assert graph.get_chunk_layer(child) <= 3 + + def test_children_at_layer_from_l2_node(self, gen_graph): + """Querying children at layer 2 from a layer 2 node should return the node itself + or its layer-2 children (which is itself).""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + sv = to_label(graph, 1, 0, 0, 0, 0) + l2 = graph.get_parent(sv) + # From l2, get children at layer 2 (with allow_lower=True since + # the children of an L2 node are SVs at layer 1) + children = get_children_at_layer(graph, l2, 2, allow_lower_layers=True) + assert len(children) > 0 + + +class TestGetLvl2EdgeListWithBbox: + """Test get_lvl2_edge_list with a bounding box parameter.""" + + def _build_3chunk_graph(self, gen_graph): + """Build a graph with 3 chunks connected linearly.""" + graph = gen_graph(n_layers=4) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 2, 0, 0, 0), inf), + ], + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 2, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 2, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + ) + + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + return graph + + def test_lvl2_edge_list_with_bbox(self, gen_graph): + """get_lvl2_edge_list with a bbox should return edges within the bbox.""" + graph = self._build_3chunk_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + + # Use a large bbox that encompasses everything + bbox = np.array([[0, 0, 0], [2048, 2048, 256]]) + edges = get_lvl2_edge_list(graph, root, bbox=bbox) + + # Should have edges + assert edges.shape[1] == 2 + # All IDs should be L2 nodes + for edge in edges: + for node_id in edge: + assert graph.get_chunk_layer(node_id) == 2 + + +class TestFindL2ShortestPathEdgeCases: + """Test find_l2_shortest_path with additional edge cases.""" + + def test_path_through_chain(self, gen_graph): + """find_l2_shortest_path through a 4-chunk chain should return correct length.""" + graph = gen_graph(n_layers=4) + + # Build a 4-chunk chain: A(0,0,0)--B(1,0,0)--C(2,0,0)--D(3,0,0) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 2, 0, 0, 0), inf), + ], + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 2, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 2, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + (to_label(graph, 1, 2, 0, 0, 0), to_label(graph, 1, 3, 0, 0, 0), inf), + ], + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 3, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 3, 0, 0, 0), to_label(graph, 1, 2, 0, 0, 0), inf), + ], + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_d = to_label(graph, 1, 3, 0, 0, 0) + l2_a = graph.get_parent(sv_a) + l2_d = graph.get_parent(sv_d) + + path = find_l2_shortest_path(graph, l2_a, l2_d) + assert path is not None + assert len(path) == 4 # A_l2 -> B_l2 -> C_l2 -> D_l2 + assert path[0] == l2_a + assert path[-1] == l2_d + + +class TestComputeRoughCoordinatePath: + """Test compute_rough_coordinate_path returns proper coordinates.""" + + def test_basic_coordinate_path(self, gen_graph): + """compute_rough_coordinate_path should return a list of float32 3D coordinates.""" + graph = gen_graph(n_layers=4) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_b = to_label(graph, 1, 1, 0, 0, 0) + l2_a = graph.get_parent(sv_a) + l2_b = graph.get_parent(sv_b) + + path = find_l2_shortest_path(graph, l2_a, l2_b) + assert path is not None + + # Mock cv methods that CloudVolumeMock doesn't have + mock_cv = MagicMock() + mock_cv.mip_voxel_offset = MagicMock(return_value=np.array([0, 0, 0])) + mock_cv.mip_resolution = MagicMock(return_value=np.array([1, 1, 1])) + graph.meta._ws_cv = mock_cv + + coordinate_path = compute_rough_coordinate_path(graph, path) + assert len(coordinate_path) == len(path) + for coord in coordinate_path: + assert isinstance(coord, np.ndarray) + assert coord.dtype == np.float32 + assert len(coord) == 3 diff --git a/pychunkedgraph/tests/test_attributes.py b/pychunkedgraph/tests/test_attributes.py new file mode 100644 index 000000000..e630353d7 --- /dev/null +++ b/pychunkedgraph/tests/test_attributes.py @@ -0,0 +1,88 @@ +"""Tests for pychunkedgraph.graph.attributes""" + +import numpy as np +import pytest + +from pychunkedgraph.graph.attributes import ( + _Attribute, + _AttributeArray, + Concurrency, + Connectivity, + Hierarchy, + GraphMeta, + GraphVersion, + OperationLogs, + from_key, +) +from pychunkedgraph.graph.utils import basetypes + + +class TestAttribute: + def test_serialize_deserialize_numpy(self): + attr = Hierarchy.Child + arr = np.array([1, 2, 3], dtype=basetypes.NODE_ID) + data = attr.serialize(arr) + result = attr.deserialize(data) + np.testing.assert_array_equal(result, arr) + + def test_serialize_deserialize_string(self): + attr = OperationLogs.UserID + data = attr.serialize("test_user") + assert attr.deserialize(data) == "test_user" + + def test_basetype(self): + assert Hierarchy.Child.basetype == basetypes.NODE_ID.type + assert OperationLogs.UserID.basetype == str + + def test_index(self): + attr = Connectivity.CrossChunkEdge[5] + assert attr.index == 5 + + def test_family_id(self): + assert Hierarchy.Child.family_id == "0" + assert Concurrency.Counter.family_id == "1" + assert OperationLogs.UserID.family_id == "2" + + +class TestAttributeArray: + def test_getitem(self): + attr = Connectivity.AtomicCrossChunkEdge[3] + assert isinstance(attr, _Attribute) + assert attr.key == b"atomic_cross_edges_3" + + def test_pattern(self): + assert Connectivity.CrossChunkEdge.pattern == b"cross_edges_%d" + + def test_serialize_deserialize(self): + arr = np.array([[1, 2], [3, 4]], dtype=basetypes.NODE_ID) + data = Connectivity.CrossChunkEdge.serialize(arr) + result = Connectivity.CrossChunkEdge.deserialize(data) + np.testing.assert_array_equal(result, arr) + + +class TestFromKey: + def test_valid_key(self): + result = from_key("0", b"children") + assert result is Hierarchy.Child + + def test_invalid_key_raises(self): + with pytest.raises(KeyError, match="Unknown key"): + from_key("99", b"nonexistent") + + +class TestOperationLogs: + def test_all_returns_list(self): + result = OperationLogs.all() + assert isinstance(result, list) + assert len(result) == 16 + assert OperationLogs.OperationID in result + assert OperationLogs.UserID in result + assert OperationLogs.RootID in result + assert OperationLogs.AddedEdge in result + + def test_status_codes(self): + assert OperationLogs.StatusCodes.SUCCESS.value == 0 + assert OperationLogs.StatusCodes.CREATED.value == 1 + assert OperationLogs.StatusCodes.EXCEPTION.value == 2 + assert OperationLogs.StatusCodes.WRITE_STARTED.value == 3 + assert OperationLogs.StatusCodes.WRITE_FAILED.value == 4 diff --git a/pychunkedgraph/tests/test_cache.py b/pychunkedgraph/tests/test_cache.py new file mode 100644 index 000000000..aadffcd3e --- /dev/null +++ b/pychunkedgraph/tests/test_cache.py @@ -0,0 +1,152 @@ +"""Tests for pychunkedgraph.graph.cache""" + +from datetime import datetime, timedelta, UTC + +import numpy as np +import pytest + +from pychunkedgraph.graph.cache import CacheService, update + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestUpdate: + def test_one_to_one(self): + cache = {} + update(cache, [1, 2, 3], [10, 20, 30]) + assert cache == {1: 10, 2: 20, 3: 30} + + def test_many_to_one(self): + cache = {} + update(cache, [1, 2, 3], 99) + assert cache == {1: 99, 2: 99, 3: 99} + + +class TestCacheService: + def _build_simple_graph(self, gen_graph): + """Build a simple 2-chunk graph with 2 SVs per chunk.""" + from math import inf + + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + def test_init(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + assert len(cache) == 0 + + def test_len(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + cache.parent(sv) + assert len(cache) >= 1 + + def test_clear(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + cache.parent(sv) + cache.clear() + assert len(cache) == 0 + + def test_parent_miss_then_hit(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + + # First call is a miss + parent1 = cache.parent(sv) + assert cache.stats["parents"]["misses"] == 1 + + # Second call is a hit + parent2 = cache.parent(sv) + assert cache.stats["parents"]["hits"] == 1 + assert parent1 == parent2 + + def test_children_backfills_parent(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + children = cache.children(root) + assert len(children) > 0 + # Children should be backfilled as parents + for child in children: + assert child in cache.parents_cache + + def test_get_stats(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + cache.parent(sv) + cache.parent(sv) + stats = cache.get_stats() + assert "parents" in stats + assert stats["parents"]["total"] == 2 + assert "hit_rate" in stats["parents"] + + def test_reset_stats(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + cache.parent(sv) + cache.reset_stats() + assert cache.stats["parents"]["hits"] == 0 + assert cache.stats["parents"]["misses"] == 0 + + def test_parents_multiple_empty(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + result = cache.parents_multiple(np.array([], dtype=np.uint64)) + assert len(result) == 0 + + def test_parents_multiple(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + svs = np.array( + [ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + dtype=np.uint64, + ) + result = cache.parents_multiple(svs) + assert len(result) == 2 + + def test_children_multiple(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = cache.children_multiple(np.array([root], dtype=np.uint64)) + assert root in result + + def test_children_multiple_flatten(self, gen_graph): + graph = self._build_simple_graph(gen_graph) + cache = CacheService(graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = cache.children_multiple( + np.array([root], dtype=np.uint64), flatten=True + ) + assert isinstance(result, np.ndarray) diff --git a/pychunkedgraph/tests/test_chunkedgraph_extended.py b/pychunkedgraph/tests/test_chunkedgraph_extended.py new file mode 100644 index 000000000..dd398f27e --- /dev/null +++ b/pychunkedgraph/tests/test_chunkedgraph_extended.py @@ -0,0 +1,1591 @@ +"""Tests for pychunkedgraph.graph.chunkedgraph - extended coverage""" + +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk +from ..graph.operation import GraphEditOperation, MergeOperation, SplitOperation +from ..graph.exceptions import PreconditionError + + +class TestChunkedGraphExtended: + def _build_graph(self, gen_graph): + """Build a simple multi-chunk graph.""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + # Chunk A: sv 0, 1 connected + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + + # Chunk B: sv 0 connected cross-chunk to A + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + def test_is_root_true(self, gen_graph): + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + root = graph.get_root(sv) + assert graph.is_root(root) + + def test_is_root_false(self, gen_graph): + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + assert not graph.is_root(sv) + + def test_get_parents_raw_only(self, gen_graph): + graph = self._build_graph(gen_graph) + svs = np.array( + [ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + dtype=np.uint64, + ) + parents = graph.get_parents(svs, raw_only=True) + assert len(parents) == 2 + # Parents should be L2 IDs + for p in parents: + assert graph.get_chunk_layer(p) == 2 + + def test_get_parents_fail_to_zero(self, gen_graph): + graph = self._build_graph(gen_graph) + # Non-existent ID should return 0 with fail_to_zero + bad_id = np.uint64(99999999) + result = graph.get_parents( + np.array([bad_id], dtype=np.uint64), fail_to_zero=True + ) + assert result[0] == 0 + + def test_get_children_flatten(self, gen_graph): + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + children = graph.get_children([root], flatten=True) + assert isinstance(children, np.ndarray) + assert len(children) > 0 + + def test_is_latest_roots(self, gen_graph): + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = graph.is_latest_roots(np.array([root], dtype=np.uint64)) + assert result[0] + + def test_get_node_timestamps(self, gen_graph): + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + root = graph.get_root(sv) + ts = graph.get_node_timestamps(np.array([root]), return_numpy=False) + assert len(ts) == 1 + + def test_get_earliest_timestamp(self, gen_graph): + graph = self._build_graph(gen_graph) + ts = graph.get_earliest_timestamp() + # May return None if no operation logs exist; test the method runs + assert ts is None or isinstance(ts, datetime) + + def test_get_l2children(self, gen_graph): + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + l2_children = graph.get_l2children(np.array([root], dtype=np.uint64)) + assert len(l2_children) > 0 + for child in l2_children: + assert graph.get_chunk_layer(child) == 2 + + # --- helpers for edit-based tests --- + + def _build_and_merge(self, gen_graph): + """Build a single-chunk graph with two disconnected SVs and merge them.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[], + timestamp=fake_ts, + ) + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + return graph, result.new_root_ids[0], result + + @pytest.mark.timeout(30) + def test_get_operation_ids(self, gen_graph): + """After a merge, get_operation_ids on the new root should return at least one operation.""" + graph, new_root, result = self._build_and_merge(gen_graph) + op_ids = graph.get_operation_ids([new_root]) + assert new_root in op_ids + assert len(op_ids[new_root]) >= 1 + # Each entry is (operation_id_value, timestamp) + op_id_val, ts = op_ids[new_root][0] + assert op_id_val == result.operation_id + + @pytest.mark.timeout(30) + def test_get_single_leaf_multiple(self, gen_graph): + """get_single_leaf_multiple for an L2 node should return an L1 supervoxel.""" + graph, new_root, _ = self._build_and_merge(gen_graph) + # The new_root in n_layers=2 is actually L2 + assert graph.get_chunk_layer(new_root) == 2 + leaves = graph.get_single_leaf_multiple(np.array([new_root], dtype=np.uint64)) + assert len(leaves) == 1 + assert graph.get_chunk_layer(leaves[0]) == 1 + # The returned leaf should be one of our two SVs + sv0 = to_label(graph, 1, 0, 0, 0, 0) + sv1 = to_label(graph, 1, 0, 0, 0, 1) + assert leaves[0] in [sv0, sv1] + + @pytest.mark.timeout(30) + def test_get_atomic_cross_edges(self, gen_graph): + """get_atomic_cross_edges for an L2 node with cross-chunk connections.""" + graph = self._build_graph(gen_graph) + sv_a0 = to_label(graph, 1, 0, 0, 0, 0) + + # Get the L2 parent of sv_a0 + parent = graph.get_parents(np.array([sv_a0], dtype=np.uint64), raw_only=True)[0] + assert graph.get_chunk_layer(parent) == 2 + + result = graph.get_atomic_cross_edges([parent]) + assert parent in result + # Should have at least one layer of cross edges + assert isinstance(result[parent], dict) + + @pytest.mark.timeout(30) + def test_get_cross_chunk_edges_raw(self, gen_graph): + """get_cross_chunk_edges with raw_only=True should return cross edges.""" + graph = self._build_graph(gen_graph) + sv_a0 = to_label(graph, 1, 0, 0, 0, 0) + + # Get the L2 parent + parent = graph.get_parents(np.array([sv_a0], dtype=np.uint64), raw_only=True)[0] + assert graph.get_chunk_layer(parent) == 2 + + result = graph.get_cross_chunk_edges([parent], raw_only=True) + assert parent in result + assert isinstance(result[parent], dict) + + @pytest.mark.timeout(30) + def test_get_parents_not_current(self, gen_graph): + """get_parents with current=False should return list of (parent, timestamp) tuples.""" + graph, new_root, _ = self._build_and_merge(gen_graph) + sv0 = to_label(graph, 1, 0, 0, 0, 0) + + # current=False returns list of lists of (value, timestamp) pairs + parents = graph.get_parents( + np.array([sv0], dtype=np.uint64), raw_only=True, current=False + ) + assert len(parents) == 1 + # Each element is a list of (parent_value, timestamp) tuples + assert isinstance(parents[0], list) + assert len(parents[0]) >= 1 + parent_val, parent_ts = parents[0][0] + assert parent_val != 0 + assert isinstance(parent_ts, datetime) + + +class TestFromLogRecord: + """Test GraphEditOperation.from_log_record with real merge/split logs.""" + + def _build_two_sv_graph(self, gen_graph): + """Build a 2-layer graph with two disconnected SVs.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[], + timestamp=fake_ts, + ) + return graph + + @pytest.mark.timeout(30) + def test_merge_from_log(self, gen_graph): + """After a merge, from_log_record should return a MergeOperation.""" + graph = self._build_two_sv_graph(gen_graph) + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + log, ts = graph.client.read_log_entry(result.operation_id) + op = GraphEditOperation.from_log_record(graph, log) + assert isinstance(op, MergeOperation) + + @pytest.mark.timeout(30) + def test_split_from_log(self, gen_graph): + """After a split, from_log_record should return a SplitOperation.""" + graph = self._build_two_sv_graph(gen_graph) + # First merge so the SVs belong to the same root + graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + # Now split them + split_result = graph.remove_edges( + "TestUser", + source_ids=to_label(graph, 1, 0, 0, 0, 0), + sink_ids=to_label(graph, 1, 0, 0, 0, 1), + mincut=False, + ) + log, ts = graph.client.read_log_entry(split_result.operation_id) + op = GraphEditOperation.from_log_record(graph, log) + assert isinstance(op, SplitOperation) + + +class TestCheckIds: + """Test ID validation in MergeOperation.""" + + def _build_two_sv_graph(self, gen_graph): + """Build a 2-layer graph with two disconnected SVs.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[], + timestamp=fake_ts, + ) + return graph + + @pytest.mark.timeout(30) + def test_source_equals_sink_raises(self, gen_graph): + """MergeOperation with source==sink should raise PreconditionError (self-loop).""" + graph = self._build_two_sv_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + with pytest.raises(PreconditionError): + graph.add_edges( + "TestUser", + [sv, sv], + affinities=[0.3], + ) + + @pytest.mark.timeout(30) + def test_nonexistent_supervoxel_raises(self, gen_graph): + """Using a supervoxel ID that doesn't exist should raise an error.""" + graph = self._build_two_sv_graph(gen_graph) + sv_real = to_label(graph, 1, 0, 0, 0, 0) + # Use a layer-2 ID as a fake "supervoxel", which fails the layer check + sv_fake = to_label(graph, 2, 0, 0, 0, 99) + with pytest.raises(Exception): + graph.add_edges( + "TestUser", + [sv_real, sv_fake], + affinities=[0.3], + ) + + +class TestGetRootsExtended: + """Tests for get_roots with stop_layer and ceil parameters (lines 380-461).""" + + def _build_cross_chunk(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph, fake_ts + + @pytest.mark.timeout(30) + def test_get_roots_with_stop_layer(self, gen_graph): + """get_roots with stop_layer should return IDs at that layer.""" + graph, _ = self._build_cross_chunk(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + # Stop at layer 3 instead of going to root (layer 4) + result = graph.get_roots(np.array([sv], dtype=np.uint64), stop_layer=3) + assert len(result) == 1 + assert graph.get_chunk_layer(result[0]) == 3 + + @pytest.mark.timeout(30) + def test_get_roots_with_stop_layer_and_ceil_false(self, gen_graph): + """get_roots with stop_layer and ceil=False should not exceed stop_layer.""" + graph, _ = self._build_cross_chunk(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + result = graph.get_roots( + np.array([sv], dtype=np.uint64), stop_layer=3, ceil=False + ) + assert len(result) == 1 + assert graph.get_chunk_layer(result[0]) <= 3 + + @pytest.mark.timeout(30) + def test_get_roots_multiple_svs(self, gen_graph): + """get_roots with multiple SVs should return root for each.""" + graph, _ = self._build_cross_chunk(gen_graph) + svs = np.array( + [ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + to_label(graph, 1, 1, 0, 0, 0), + ], + dtype=np.uint64, + ) + roots = graph.get_roots(svs) + assert len(roots) == 3 + # All should reach the top layer + for r in roots: + assert graph.get_chunk_layer(r) == 4 + + @pytest.mark.timeout(30) + def test_get_roots_already_at_stop_layer(self, gen_graph): + """get_roots for a node already at stop_layer should return it unchanged.""" + graph, _ = self._build_cross_chunk(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + root = graph.get_root(sv) + # root is at layer 4; asking for stop_layer=4 should return it + result = graph.get_roots(np.array([root], dtype=np.uint64), stop_layer=4) + assert result[0] == root + + @pytest.mark.timeout(30) + def test_get_roots_fail_to_zero(self, gen_graph): + """get_roots with a zero ID and fail_to_zero should keep it as zero.""" + graph, _ = self._build_cross_chunk(gen_graph) + result = graph.get_roots(np.array([0], dtype=np.uint64), fail_to_zero=True) + assert result[0] == 0 + + @pytest.mark.timeout(30) + def test_get_root_stop_layer_ceil_false(self, gen_graph): + """get_root (singular) with stop_layer and ceil=False.""" + graph, _ = self._build_cross_chunk(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + result = graph.get_root(sv, stop_layer=3, ceil=False) + assert graph.get_chunk_layer(result) <= 3 + + +class TestGetChildrenExtended: + """Tests for get_children with flatten=True and edge cases (lines 271-296).""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_children_flatten_multiple(self, gen_graph): + """get_children with multiple node IDs and flatten=True returns flat array.""" + graph = self._build_graph(gen_graph) + sv_a0 = to_label(graph, 1, 0, 0, 0, 0) + sv_b0 = to_label(graph, 1, 1, 0, 0, 0) + + parent_a = graph.get_parents(np.array([sv_a0], dtype=np.uint64), raw_only=True)[ + 0 + ] + parent_b = graph.get_parents(np.array([sv_b0], dtype=np.uint64), raw_only=True)[ + 0 + ] + + children = graph.get_children([parent_a, parent_b], flatten=True) + assert isinstance(children, np.ndarray) + # Should contain at least sv_a0, sv_a1, sv_b0 + assert len(children) >= 3 + + @pytest.mark.timeout(30) + def test_get_children_flatten_empty(self, gen_graph): + """get_children with flatten=True on empty list returns empty array.""" + graph = self._build_graph(gen_graph) + children = graph.get_children([], flatten=True) + assert isinstance(children, np.ndarray) + assert len(children) == 0 + + @pytest.mark.timeout(30) + def test_get_children_dict(self, gen_graph): + """get_children without flatten returns a dict.""" + graph = self._build_graph(gen_graph) + sv_a0 = to_label(graph, 1, 0, 0, 0, 0) + parent = graph.get_parents(np.array([sv_a0], dtype=np.uint64), raw_only=True)[0] + children_d = graph.get_children([parent]) + assert isinstance(children_d, dict) + assert parent in children_d + + @pytest.mark.timeout(30) + def test_get_children_scalar(self, gen_graph): + """get_children with a scalar node_id returns an array.""" + graph = self._build_graph(gen_graph) + sv_a0 = to_label(graph, 1, 0, 0, 0, 0) + parent = graph.get_parents(np.array([sv_a0], dtype=np.uint64), raw_only=True)[0] + children = graph.get_children(parent, raw_only=True) + assert isinstance(children, np.ndarray) + assert len(children) >= 1 + + +class TestIsLatestRootsExtended: + """Tests for is_latest_roots (lines 524-544).""" + + def _build_and_merge(self, gen_graph): + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[], + timestamp=fake_ts, + ) + # Get the initial roots + root0 = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + root1 = graph.get_root(to_label(graph, 1, 0, 0, 0, 1)) + + # Merge + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + new_root = result.new_root_ids[0] + return graph, root0, root1, new_root + + @pytest.mark.timeout(30) + def test_is_latest_roots_after_merge(self, gen_graph): + """After a merge, old roots should not be latest, new root should be.""" + graph, root0, root1, new_root = self._build_and_merge(gen_graph) + result = graph.is_latest_roots( + np.array([root0, root1, new_root], dtype=np.uint64) + ) + # Old roots are superseded + assert not result[0] + assert not result[1] + # New root is latest + assert result[2] + + @pytest.mark.timeout(30) + def test_is_latest_roots_empty(self, gen_graph): + """is_latest_roots with nonexistent IDs should return all False.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + result = graph.is_latest_roots(np.array([99999999], dtype=np.uint64)) + assert not result[0] + + +class TestGetNodeTimestampsExtended: + """Tests for get_node_timestamps (lines 773-800).""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_node_timestamps_return_numpy(self, gen_graph): + """get_node_timestamps with return_numpy=True should return numpy array.""" + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + ts = graph.get_node_timestamps( + np.array([root], dtype=np.uint64), return_numpy=True + ) + assert isinstance(ts, np.ndarray) + assert len(ts) == 1 + + @pytest.mark.timeout(30) + def test_get_node_timestamps_return_list(self, gen_graph): + """get_node_timestamps with return_numpy=False should return a list.""" + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + ts = graph.get_node_timestamps( + np.array([root], dtype=np.uint64), return_numpy=False + ) + assert isinstance(ts, list) + assert len(ts) == 1 + + @pytest.mark.timeout(30) + def test_get_node_timestamps_empty(self, gen_graph): + """get_node_timestamps with nonexistent nodes should handle gracefully.""" + graph = self._build_graph(gen_graph) + ts = graph.get_node_timestamps( + np.array([np.uint64(99999999)], dtype=np.uint64), return_numpy=True + ) + # Should either return empty or return a fallback timestamp + assert isinstance(ts, np.ndarray) + + @pytest.mark.timeout(30) + def test_get_node_timestamps_empty_return_list(self, gen_graph): + """get_node_timestamps with empty dict result and return_numpy=False.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + # Don't create any chunks; asking for timestamps on nonexistent nodes + ts = graph.get_node_timestamps( + np.array([np.uint64(99999999)], dtype=np.uint64), return_numpy=False + ) + assert isinstance(ts, list) + assert len(ts) == 0 + + +class TestGetOperationIdsExtended: + """Tests for get_operation_ids (lines 1033-1042).""" + + @pytest.mark.timeout(30) + def test_get_operation_ids_no_ops(self, gen_graph): + """get_operation_ids on a node with no operations returns empty dict.""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = graph.get_operation_ids([root]) + # No operations => root may not be in result, or have empty list + if root in result: + assert isinstance(result[root], list) + + +class TestGetSingleLeafMultipleExtended: + """Tests for get_single_leaf_multiple (lines 1044-1062).""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_single_leaf_from_root(self, gen_graph): + """get_single_leaf_multiple from a root (layer 4) should drill down to layer 1.""" + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + assert graph.get_chunk_layer(root) == 4 + leaves = graph.get_single_leaf_multiple(np.array([root], dtype=np.uint64)) + assert len(leaves) == 1 + assert graph.get_chunk_layer(leaves[0]) == 1 + + @pytest.mark.timeout(30) + def test_get_single_leaf_from_l2(self, gen_graph): + """get_single_leaf_multiple from L2 node should return one of its SV children.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + parent = graph.get_parents(np.array([sv], dtype=np.uint64), raw_only=True)[0] + assert graph.get_chunk_layer(parent) == 2 + leaves = graph.get_single_leaf_multiple(np.array([parent], dtype=np.uint64)) + assert len(leaves) == 1 + assert graph.get_chunk_layer(leaves[0]) == 1 + + @pytest.mark.timeout(30) + def test_get_single_leaf_multiple_nodes(self, gen_graph): + """get_single_leaf_multiple with multiple node IDs should return one leaf each.""" + graph = self._build_graph(gen_graph) + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_b = to_label(graph, 1, 1, 0, 0, 0) + parent_a = graph.get_parents(np.array([sv_a], dtype=np.uint64), raw_only=True)[ + 0 + ] + parent_b = graph.get_parents(np.array([sv_b], dtype=np.uint64), raw_only=True)[ + 0 + ] + leaves = graph.get_single_leaf_multiple( + np.array([parent_a, parent_b], dtype=np.uint64) + ) + assert len(leaves) == 2 + for leaf in leaves: + assert graph.get_chunk_layer(leaf) == 1 + + +class TestGetL2ChildrenExtended: + """Tests for get_l2children (lines 1079-1092).""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_l2children_from_root(self, gen_graph): + """get_l2children from a root should return all L2 children.""" + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + l2_children = graph.get_l2children(np.array([root], dtype=np.uint64)) + assert isinstance(l2_children, np.ndarray) + assert len(l2_children) >= 2 + for child in l2_children: + assert graph.get_chunk_layer(child) == 2 + + @pytest.mark.timeout(30) + def test_get_l2children_from_l3(self, gen_graph): + """get_l2children from an L3 node should return L2 children.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + # Get L3 parent + root = graph.get_root(sv, stop_layer=3) + assert graph.get_chunk_layer(root) == 3 + l2_children = graph.get_l2children(np.array([root], dtype=np.uint64)) + assert len(l2_children) >= 1 + for child in l2_children: + assert graph.get_chunk_layer(child) == 2 + + @pytest.mark.timeout(30) + def test_get_l2children_from_l2(self, gen_graph): + """get_l2children from an L2 node drills down to its children, + which are L1 - so no L2 children are found; result is empty.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + parent = graph.get_parents(np.array([sv], dtype=np.uint64), raw_only=True)[0] + assert graph.get_chunk_layer(parent) == 2 + l2_children = graph.get_l2children(np.array([parent], dtype=np.uint64)) + # L2 nodes only have L1 (SV) children, so no L2 descendants found + assert isinstance(l2_children, np.ndarray) + assert len(l2_children) == 0 + + +class TestGetChunkLayersExtended: + """Tests for get_chunk_layers and related helpers (line 951-952, 946).""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_chunk_layers_multiple(self, gen_graph): + """get_chunk_layers for nodes at different layers.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + parent_l2 = graph.get_parents(np.array([sv], dtype=np.uint64), raw_only=True)[0] + root = graph.get_root(sv) + layers = graph.get_chunk_layers( + np.array([sv, parent_l2, root], dtype=np.uint64) + ) + assert layers[0] == 1 + assert layers[1] == 2 + assert layers[2] == 4 + + @pytest.mark.timeout(30) + def test_get_segment_id_limit(self, gen_graph): + """get_segment_id_limit should return a valid limit.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + limit = graph.get_segment_id_limit(sv) + assert limit > 0 + + @pytest.mark.timeout(30) + def test_get_chunk_coordinates(self, gen_graph): + """get_chunk_coordinates should return the chunk coordinates of a node.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + coords = graph.get_chunk_coordinates(sv) + assert len(coords) == 3 + np.testing.assert_array_equal(coords, [0, 0, 0]) + + @pytest.mark.timeout(30) + def test_get_chunk_layers_and_coordinates(self, gen_graph): + """get_chunk_layers_and_coordinates returns layers and coords together.""" + graph = self._build_graph(gen_graph) + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_b = to_label(graph, 1, 1, 0, 0, 0) + layers, coords = graph.get_chunk_layers_and_coordinates( + np.array([sv_a, sv_b], dtype=np.uint64) + ) + assert len(layers) == 2 + assert layers[0] == 1 + assert layers[1] == 1 + assert coords.shape == (2, 3) + + +class TestGetAtomicCrossEdgesExtended: + """Tests for get_atomic_cross_edges (lines 315-336).""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_atomic_cross_edges_multiple_l2(self, gen_graph): + """get_atomic_cross_edges with multiple L2 IDs.""" + graph = self._build_graph(gen_graph) + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_b = to_label(graph, 1, 1, 0, 0, 0) + parent_a = graph.get_parents(np.array([sv_a], dtype=np.uint64), raw_only=True)[ + 0 + ] + parent_b = graph.get_parents(np.array([sv_b], dtype=np.uint64), raw_only=True)[ + 0 + ] + result = graph.get_atomic_cross_edges([parent_a, parent_b]) + assert parent_a in result + assert parent_b in result + # At least one should have cross edges + has_edges = any(len(v) > 0 for v in result.values()) + assert has_edges + + @pytest.mark.timeout(30) + def test_get_atomic_cross_edges_no_cross(self, gen_graph): + """get_atomic_cross_edges for an L2 node with no cross edges.""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + sv = to_label(graph, 1, 0, 0, 0, 0) + parent = graph.get_parents(np.array([sv], dtype=np.uint64), raw_only=True)[0] + result = graph.get_atomic_cross_edges([parent]) + assert parent in result + assert isinstance(result[parent], dict) + # No cross edges + assert len(result[parent]) == 0 + + +class TestGetAllParentsDictExtended: + """Tests for get_all_parents_dict and get_all_parents_dict_multiple.""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_all_parents_dict(self, gen_graph): + """get_all_parents_dict returns a dict mapping layer -> parent.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + d = graph.get_all_parents_dict(sv) + assert isinstance(d, dict) + # Should have entries for layers 2, 3, 4 + assert 2 in d + assert 4 in d + + @pytest.mark.timeout(30) + def test_get_all_parents_dict_multiple(self, gen_graph): + """get_all_parents_dict_multiple for multiple SVs.""" + graph = self._build_graph(gen_graph) + sv0 = to_label(graph, 1, 0, 0, 0, 0) + sv1 = to_label(graph, 1, 0, 0, 0, 1) + result = graph.get_all_parents_dict_multiple( + np.array([sv0, sv1], dtype=np.uint64) + ) + assert sv0 in result + assert sv1 in result + # Both should have parents at layer 2 + assert 2 in result[sv0] + assert 2 in result[sv1] + + +class TestMiscMethods: + """Tests for misc ChunkedGraph methods.""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_serialized_info(self, gen_graph): + """get_serialized_info should return a dict with graph_id.""" + graph = self._build_graph(gen_graph) + info = graph.get_serialized_info() + assert isinstance(info, dict) + assert "graph_id" in info + + @pytest.mark.timeout(30) + def test_get_chunk_id(self, gen_graph): + """get_chunk_id should return a valid chunk id.""" + graph = self._build_graph(gen_graph) + chunk_id = graph.get_chunk_id(layer=2, x=0, y=0, z=0) + assert chunk_id > 0 + assert graph.get_chunk_layer(chunk_id) == 2 + + @pytest.mark.timeout(30) + def test_get_node_id(self, gen_graph): + """get_node_id should construct node IDs correctly.""" + graph = self._build_graph(gen_graph) + node_id = graph.get_node_id(np.uint64(1), layer=1, x=0, y=0, z=0) + assert node_id > 0 + assert graph.get_chunk_layer(node_id) == 1 + + @pytest.mark.timeout(30) + def test_get_segment_id(self, gen_graph): + """get_segment_id should extract segment id from node id.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 5) + seg_id = graph.get_segment_id(sv) + assert seg_id == 5 + + @pytest.mark.timeout(30) + def test_get_parent_chunk_id(self, gen_graph): + """get_parent_chunk_id should return the chunk id of the parent layer.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + parent_chunk = graph.get_parent_chunk_id(sv) + assert graph.get_chunk_layer(parent_chunk) == 2 + + @pytest.mark.timeout(30) + def test_get_children_chunk_ids(self, gen_graph): + """get_children_chunk_ids should return chunk IDs one layer below.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + root = graph.get_root(sv) + # root is at layer 4; children chunks should be at layer 3 + children_chunks = graph.get_children_chunk_ids(root) + for cc in children_chunks: + assert graph.get_chunk_layer(cc) == 3 + + @pytest.mark.timeout(30) + def test_get_cross_chunk_edges_empty(self, gen_graph): + """get_cross_chunk_edges with empty node_ids should return empty dict.""" + graph = self._build_graph(gen_graph) + result = graph.get_cross_chunk_edges([], raw_only=True) + assert isinstance(result, dict) + assert len(result) == 0 + + +class TestIsLatestRootsAfterMerge: + """Test is_latest_roots after a merge operation (lines 524-539, 689-701).""" + + def _build_and_merge(self, gen_graph): + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[], + timestamp=fake_ts, + ) + old_root0 = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + old_root1 = graph.get_root(to_label(graph, 1, 0, 0, 0, 1)) + + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + new_root = result.new_root_ids[0] + return graph, old_root0, old_root1, new_root + + @pytest.mark.timeout(30) + def test_is_latest_roots_after_merge(self, gen_graph): + """After merge, old roots are not latest; new root is latest.""" + graph, old_root0, old_root1, new_root = self._build_and_merge(gen_graph) + result = graph.is_latest_roots( + np.array([old_root0, old_root1, new_root], dtype=np.uint64) + ) + assert not result[0], "Old root0 should not be latest after merge" + assert not result[1], "Old root1 should not be latest after merge" + assert result[2], "New root should be latest after merge" + + +class TestGetSubgraphNodesOnly: + """Test get_subgraph with nodes_only=True (lines 602-613).""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_subgraph_nodes_only(self, gen_graph): + """get_subgraph with nodes_only=True should return layer->node_ids dict.""" + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = graph.get_subgraph(root, nodes_only=True) + # Result should be a dict with layer 2 by default + assert isinstance(result, dict) + assert 2 in result + l2_nodes = result[2] + assert len(l2_nodes) >= 2 + for node in l2_nodes: + assert graph.get_chunk_layer(node) == 2 + + @pytest.mark.timeout(30) + def test_get_subgraph_nodes_only_multiple_layers(self, gen_graph): + """get_subgraph with nodes_only=True and return_layers=[2,3].""" + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = graph.get_subgraph(root, nodes_only=True, return_layers=[2, 3]) + assert isinstance(result, dict) + # Should have entries for layer 2 and/or 3 + assert 2 in result or 3 in result + + +class TestGetSubgraphEdgesOnly: + """Test get_subgraph with edges_only=True.""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_subgraph_edges_only(self, gen_graph): + """get_subgraph with edges_only=True should return edges.""" + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = graph.get_subgraph(root, edges_only=True) + # edges_only returns Edges from get_l2_agglomerations + # It should be a tuple of Edges or similar iterable + assert result is not None + + +# =========================================================================== +# is_latest_roots after merge -- detailed tests (lines 689-701) +# =========================================================================== + + +class TestIsLatestRootsDetailed: + """Detailed tests for is_latest_roots checking old roots are not latest after merge.""" + + def _build_and_merge(self, gen_graph): + """Build graph with two disconnected SVs, merge them, return old and new roots.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[], + timestamp=fake_ts, + ) + old_root0 = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + old_root1 = graph.get_root(to_label(graph, 1, 0, 0, 0, 1)) + + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + new_root = result.new_root_ids[0] + return graph, old_root0, old_root1, new_root + + @pytest.mark.timeout(30) + def test_is_latest_roots_correct(self, gen_graph): + """After merge, old roots should be flagged as not latest, new root as latest.""" + graph, old_root0, old_root1, new_root = self._build_and_merge(gen_graph) + result = graph.is_latest_roots( + np.array([old_root0, old_root1, new_root], dtype=np.uint64) + ) + assert not result[0], "Old root0 should not be latest after merge" + assert not result[1], "Old root1 should not be latest after merge" + assert result[2], "New root should be latest after merge" + + @pytest.mark.timeout(30) + def test_is_latest_roots_single_old_root(self, gen_graph): + """Check a single old root is not latest after merge.""" + graph, old_root0, _, _ = self._build_and_merge(gen_graph) + result = graph.is_latest_roots(np.array([old_root0], dtype=np.uint64)) + assert not result[0] + + @pytest.mark.timeout(30) + def test_is_latest_roots_single_new_root(self, gen_graph): + """Check a single new root is latest after merge.""" + graph, _, _, new_root = self._build_and_merge(gen_graph) + result = graph.is_latest_roots(np.array([new_root], dtype=np.uint64)) + assert result[0] + + +# =========================================================================== +# get_chunk_coordinates_multiple (lines 958-961) +# =========================================================================== + + +class TestGetChunkCoordinatesMultiple: + """Tests for get_chunk_coordinates_multiple with same/different layer assertions.""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_same_layer(self, gen_graph): + """get_chunk_coordinates_multiple with L2 node IDs should return correct coordinates.""" + graph = self._build_graph(gen_graph) + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_b = to_label(graph, 1, 1, 0, 0, 0) + # Get L2 parents + parent_a = graph.get_parents(np.array([sv_a], dtype=np.uint64), raw_only=True)[ + 0 + ] + parent_b = graph.get_parents(np.array([sv_b], dtype=np.uint64), raw_only=True)[ + 0 + ] + assert graph.get_chunk_layer(parent_a) == 2 + assert graph.get_chunk_layer(parent_b) == 2 + + coords = graph.get_chunk_coordinates_multiple( + np.array([parent_a, parent_b], dtype=np.uint64) + ) + assert coords.shape == (2, 3) + # parent_a is in chunk (0,0,0), parent_b is in chunk (1,0,0) + np.testing.assert_array_equal(coords[0], [0, 0, 0]) + np.testing.assert_array_equal(coords[1], [1, 0, 0]) + + @pytest.mark.timeout(30) + def test_different_layers_raises(self, gen_graph): + """get_chunk_coordinates_multiple with nodes at different layers should raise.""" + graph = self._build_graph(gen_graph) + sv_a = to_label(graph, 1, 0, 0, 0, 0) + parent_l2 = graph.get_parents(np.array([sv_a], dtype=np.uint64), raw_only=True)[ + 0 + ] + root = graph.get_root(sv_a) + + assert graph.get_chunk_layer(parent_l2) == 2 + assert graph.get_chunk_layer(root) == 4 + + with pytest.raises(AssertionError, match="must be same layer"): + graph.get_chunk_coordinates_multiple( + np.array([parent_l2, root], dtype=np.uint64) + ) + + @pytest.mark.timeout(30) + def test_empty_array(self, gen_graph): + """get_chunk_coordinates_multiple with empty array should return empty result.""" + graph = self._build_graph(gen_graph) + coords = graph.get_chunk_coordinates_multiple(np.array([], dtype=np.uint64)) + assert len(coords) == 0 + + +# =========================================================================== +# get_parent_chunk_id_multiple and get_parent_chunk_ids (lines 991, 996) +# =========================================================================== + + +class TestParentChunkIdMethods: + """Tests for get_parent_chunk_id_multiple and get_parent_chunk_ids.""" + + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + @pytest.mark.timeout(30) + def test_get_parent_chunk_id_multiple(self, gen_graph): + """get_parent_chunk_id_multiple should return parent chunk IDs for all nodes.""" + graph = self._build_graph(gen_graph) + sv_a = to_label(graph, 1, 0, 0, 0, 0) + sv_b = to_label(graph, 1, 1, 0, 0, 0) + # Get L2 parents + parent_a = graph.get_parents(np.array([sv_a], dtype=np.uint64), raw_only=True)[ + 0 + ] + parent_b = graph.get_parents(np.array([sv_b], dtype=np.uint64), raw_only=True)[ + 0 + ] + + parent_chunks = graph.get_parent_chunk_id_multiple( + np.array([parent_a, parent_b], dtype=np.uint64) + ) + assert len(parent_chunks) == 2 + for pc in parent_chunks: + assert graph.get_chunk_layer(pc) == 3 + + @pytest.mark.timeout(30) + def test_get_parent_chunk_ids(self, gen_graph): + """get_parent_chunk_ids should return all parent chunk IDs up the hierarchy.""" + graph = self._build_graph(gen_graph) + sv = to_label(graph, 1, 0, 0, 0, 0) + parent_chunk_ids = graph.get_parent_chunk_ids(sv) + # Should have parent chunk IDs for layers 2, 3, 4 + assert len(parent_chunk_ids) >= 2 + layers = [graph.get_chunk_layer(pc) for pc in parent_chunk_ids] + # Layers should be ascending (from layer 2 up) + for i in range(len(layers) - 1): + assert layers[i] < layers[i + 1] + + +# =========================================================================== +# read_chunk_edges (lines 1005-1007) +# =========================================================================== + + +class TestReadChunkEdges: + """Tests for read_chunk_edges method.""" + + @pytest.mark.timeout(30) + def test_read_chunk_edges_returns_dict(self, gen_graph): + """read_chunk_edges should return a dict (possibly empty for gs:// edges source).""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + sv = to_label(graph, 1, 0, 0, 0, 0) + parent = graph.get_parents(np.array([sv], dtype=np.uint64), raw_only=True)[0] + # read_chunk_edges uses io.edges.get_chunk_edges which reads from GCS/file. + # With gs:// edges source and no actual files, it should raise or return empty. + try: + result = graph.read_chunk_edges(np.array([parent], dtype=np.uint64)) + assert isinstance(result, dict) + except Exception: + # Expected: GCS access will fail in test env + pass + + +# =========================================================================== +# get_proofread_root_ids (lines 1017-1019) +# =========================================================================== + + +class TestGetProofreadRootIds: + """Tests for get_proofread_root_ids method.""" + + @pytest.mark.timeout(30) + def test_get_proofread_root_ids_no_ops(self, gen_graph): + """get_proofread_root_ids with no operations should return empty arrays.""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + old_roots, new_roots = graph.get_proofread_root_ids() + assert len(old_roots) == 0 + assert len(new_roots) == 0 + + @pytest.mark.timeout(30) + def test_get_proofread_root_ids_after_merge(self, gen_graph): + """get_proofread_root_ids after a merge should return the old and new roots.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[], + timestamp=fake_ts, + ) + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + old_roots, new_roots = graph.get_proofread_root_ids() + assert len(new_roots) >= 1 + assert result.new_root_ids[0] in new_roots + + +# =========================================================================== +# remove_edges via shim path (line 876) -- source_ids/sink_ids without atomic_edges +# =========================================================================== + + +class TestRemoveEdgesShim: + """Test remove_edges with source_ids and sink_ids but no atomic_edges (shim path).""" + + def _build_connected_graph(self, gen_graph): + """Build a 2-layer graph with two connected SVs.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=fake_ts, + ) + return graph + + @pytest.mark.timeout(30) + def test_remove_edges_with_source_sink_ids(self, gen_graph): + """Call remove_edges with source_ids/sink_ids (no atomic_edges) -- exercises shim.""" + graph = self._build_connected_graph(gen_graph) + sv0 = to_label(graph, 1, 0, 0, 0, 0) + sv1 = to_label(graph, 1, 0, 0, 0, 1) + + # Verify they share a root before split + assert graph.get_root(sv0) == graph.get_root(sv1) + + # Use source_ids/sink_ids (the shim path) instead of atomic_edges + result = graph.remove_edges( + "TestUser", + source_ids=sv0, + sink_ids=sv1, + mincut=False, + ) + assert result.new_root_ids is not None + assert len(result.new_root_ids) == 2 + + # After split, they should have different roots + assert graph.get_root(sv0) != graph.get_root(sv1) + + @pytest.mark.timeout(30) + def test_remove_edges_shim_mismatched_lengths(self, gen_graph): + """Shim path with mismatched source_ids/sink_ids lengths should raise.""" + graph = self._build_connected_graph(gen_graph) + sv0 = to_label(graph, 1, 0, 0, 0, 0) + sv1 = to_label(graph, 1, 0, 0, 0, 1) + + with pytest.raises(PreconditionError, match="same number"): + graph.remove_edges( + "TestUser", + source_ids=[sv0, sv0], + sink_ids=[sv1], + mincut=False, + ) + + +# =========================================================================== +# get_earliest_timestamp -- detailed test (bigtable/client.py coverage) +# =========================================================================== + + +class TestEarliestTimestamp: + """Tests for get_earliest_timestamp after operations exist.""" + + @pytest.mark.timeout(30) + def test_get_earliest_timestamp_after_merge(self, gen_graph): + """After creating a graph and performing a merge, get_earliest_timestamp should return a valid datetime.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + ], + edges=[], + timestamp=fake_ts, + ) + # Perform a merge to generate operation logs + graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + ts = graph.get_earliest_timestamp() + assert ts is not None + assert isinstance(ts, datetime) + + @pytest.mark.timeout(30) + def test_get_earliest_timestamp_no_ops(self, gen_graph): + """On a fresh graph with no operations, get_earliest_timestamp should return None.""" + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + ts = graph.get_earliest_timestamp() + # No operation logs, so should be None + assert ts is None or isinstance(ts, datetime) diff --git a/pychunkedgraph/tests/test_chunks_hierarchy.py b/pychunkedgraph/tests/test_chunks_hierarchy.py new file mode 100644 index 000000000..40841997d --- /dev/null +++ b/pychunkedgraph/tests/test_chunks_hierarchy.py @@ -0,0 +1,87 @@ +"""Tests for pychunkedgraph.graph.chunks.hierarchy""" + +import numpy as np + +from pychunkedgraph.graph.chunks import hierarchy +from pychunkedgraph.graph.chunks import utils as chunk_utils + +from .helpers import to_label + + +class TestGetChildrenChunkCoords: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=5) + coords = hierarchy.get_children_chunk_coords(graph.meta, 3, [0, 0, 0]) + # Layer 3 chunk at [0,0,0] has fanout=2 children: 2^3 = 8 max + assert len(coords) > 0 + assert coords.shape[1] == 3 + + +class TestGetChildrenChunkIds: + def test_layer_1_returns_empty(self, gen_graph): + graph = gen_graph(n_layers=4) + node_id = to_label(graph, 1, 0, 0, 0, 1) + result = hierarchy.get_children_chunk_ids(graph.meta, node_id) + assert len(result) == 0 + + def test_layer_2_returns_self(self, gen_graph): + graph = gen_graph(n_layers=4) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=2, x=0, y=0, z=0) + result = hierarchy.get_children_chunk_ids(graph.meta, chunk_id) + assert len(result) == 1 + + def test_layer_3(self, gen_graph): + graph = gen_graph(n_layers=5) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=3, x=0, y=0, z=0) + result = hierarchy.get_children_chunk_ids(graph.meta, chunk_id) + assert len(result) > 0 + + +class TestGetParentChunkId: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=5) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=2, x=0, y=0, z=0) + parent_id = hierarchy.get_parent_chunk_id(graph.meta, chunk_id, 3) + assert chunk_utils.get_chunk_layer(graph.meta, parent_id) == 3 + + def test_parent_coords(self, gen_graph): + graph = gen_graph(n_layers=5) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=2, x=2, y=3, z=1) + parent_id = hierarchy.get_parent_chunk_id(graph.meta, chunk_id, 3) + coords = chunk_utils.get_chunk_coordinates(graph.meta, parent_id) + # With fanout=2, coords should be floor(original / 2) + np.testing.assert_array_equal(coords, [1, 1, 0]) + + +class TestGetParentChunkIdMultiple: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=5) + ids = np.array( + [ + chunk_utils.get_chunk_id(graph.meta, layer=2, x=0, y=0, z=0), + chunk_utils.get_chunk_id(graph.meta, layer=2, x=1, y=0, z=0), + ], + dtype=np.uint64, + ) + result = hierarchy.get_parent_chunk_id_multiple(graph.meta, ids) + assert len(result) == 2 + for pid in result: + assert chunk_utils.get_chunk_layer(graph.meta, pid) == 3 + + +class TestGetParentChunkIds: + def test_returns_chain(self, gen_graph): + graph = gen_graph(n_layers=5) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=2, x=0, y=0, z=0) + result = hierarchy.get_parent_chunk_ids(graph.meta, chunk_id) + # Should include chunk_id + parents up to layer_count + assert len(result) >= 2 + + +class TestGetParentChunkIdDict: + def test_returns_dict(self, gen_graph): + graph = gen_graph(n_layers=5) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=2, x=0, y=0, z=0) + result = hierarchy.get_parent_chunk_id_dict(graph.meta, chunk_id) + assert isinstance(result, dict) + assert 2 in result diff --git a/pychunkedgraph/tests/test_chunks_utils.py b/pychunkedgraph/tests/test_chunks_utils.py new file mode 100644 index 000000000..e5830f80d --- /dev/null +++ b/pychunkedgraph/tests/test_chunks_utils.py @@ -0,0 +1,133 @@ +"""Tests for pychunkedgraph.graph.chunks.utils""" + +import numpy as np +import pytest + +from pychunkedgraph.graph.chunks import utils as chunk_utils + + +class TestGetChunkLayer: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + from .helpers import to_label + + node_id = to_label(graph, 1, 0, 0, 0, 1) + assert chunk_utils.get_chunk_layer(graph.meta, node_id) == 1 + + def test_higher_layer(self, gen_graph): + graph = gen_graph(n_layers=4) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=3, x=0, y=0, z=0) + assert chunk_utils.get_chunk_layer(graph.meta, chunk_id) == 3 + + +class TestGetChunkLayers: + def test_empty(self, gen_graph): + graph = gen_graph(n_layers=4) + result = chunk_utils.get_chunk_layers(graph.meta, []) + assert len(result) == 0 + + def test_multiple(self, gen_graph): + graph = gen_graph(n_layers=4) + from .helpers import to_label + + ids = [ + to_label(graph, 1, 0, 0, 0, 1), + to_label(graph, 1, 1, 0, 0, 2), + ] + layers = chunk_utils.get_chunk_layers(graph.meta, ids) + np.testing.assert_array_equal(layers, [1, 1]) + + +class TestGetChunkCoordinates: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=2, x=1, y=2, z=3) + coords = chunk_utils.get_chunk_coordinates(graph.meta, chunk_id) + np.testing.assert_array_equal(coords, [1, 2, 3]) + + +class TestGetChunkCoordinatesMultiple: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + ids = [ + chunk_utils.get_chunk_id(graph.meta, layer=2, x=0, y=0, z=0), + chunk_utils.get_chunk_id(graph.meta, layer=2, x=1, y=2, z=3), + ] + coords = chunk_utils.get_chunk_coordinates_multiple(graph.meta, ids) + np.testing.assert_array_equal(coords[0], [0, 0, 0]) + np.testing.assert_array_equal(coords[1], [1, 2, 3]) + + def test_empty(self, gen_graph): + graph = gen_graph(n_layers=4) + result = chunk_utils.get_chunk_coordinates_multiple(graph.meta, []) + assert result.shape == (0, 3) + + +class TestGetChunkId: + def test_from_node_id(self, gen_graph): + graph = gen_graph(n_layers=4) + from .helpers import to_label + + node_id = to_label(graph, 1, 2, 3, 1, 5) + chunk_id = chunk_utils.get_chunk_id(graph.meta, node_id=node_id) + coords = chunk_utils.get_chunk_coordinates(graph.meta, chunk_id) + np.testing.assert_array_equal(coords, [2, 3, 1]) + + def test_from_components(self, gen_graph): + graph = gen_graph(n_layers=4) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=2, x=1, y=2, z=3) + assert chunk_utils.get_chunk_layer(graph.meta, chunk_id) == 2 + coords = chunk_utils.get_chunk_coordinates(graph.meta, chunk_id) + np.testing.assert_array_equal(coords, [1, 2, 3]) + + +class TestComputeChunkIdOutOfRange: + def test_raises(self, gen_graph): + graph = gen_graph(n_layers=4) + with pytest.raises(ValueError, match="out of range"): + chunk_utils._compute_chunk_id(graph.meta, layer=2, x=9999, y=0, z=0) + + +class TestGetChunkIdsFromCoords: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + coords = np.array([[0, 0, 0], [1, 0, 0]]) + result = chunk_utils.get_chunk_ids_from_coords(graph.meta, 2, coords) + assert len(result) == 2 + for cid in result: + assert chunk_utils.get_chunk_layer(graph.meta, cid) == 2 + + +class TestGetChunkIdsFromNodeIds: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + from .helpers import to_label + + ids = np.array( + [ + to_label(graph, 1, 0, 0, 0, 1), + to_label(graph, 1, 1, 0, 0, 2), + ], + dtype=np.uint64, + ) + result = chunk_utils.get_chunk_ids_from_node_ids(graph.meta, ids) + assert len(result) == 2 + + def test_empty(self, gen_graph): + graph = gen_graph(n_layers=4) + result = chunk_utils.get_chunk_ids_from_node_ids(graph.meta, []) + assert len(result) == 0 + + +class TestNormalizeBoundingBox: + def test_none(self, gen_graph): + graph = gen_graph(n_layers=4) + assert chunk_utils.normalize_bounding_box(graph.meta, None, False) is None + + +class TestGetBoundingChildrenChunks: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=5) + result = chunk_utils.get_bounding_children_chunks(graph.meta, 3, (0, 0, 0), 2) + assert len(result) > 0 + assert result.shape[1] == 3 diff --git a/pychunkedgraph/tests/test_connectivity.py b/pychunkedgraph/tests/test_connectivity.py new file mode 100644 index 000000000..c27d99b6b --- /dev/null +++ b/pychunkedgraph/tests/test_connectivity.py @@ -0,0 +1,119 @@ +"""Tests for pychunkedgraph.graph.connectivity.nodes""" + +import numpy as np + +from pychunkedgraph.graph.types import Agglomeration +from pychunkedgraph.graph.connectivity.nodes import edge_exists + + +def _make_agg(node_id, supervoxels, out_edges): + """Helper to create an Agglomeration with the fields needed by edge_exists.""" + return Agglomeration( + node_id=np.uint64(node_id), + supervoxels=np.array(supervoxels, dtype=np.uint64), + in_edges=np.empty((0, 2), dtype=np.uint64), + out_edges=np.array(out_edges, dtype=np.uint64).reshape(-1, 2), + cross_edges=np.empty((0, 2), dtype=np.uint64), + ) + + +class TestEdgeExists: + def test_edge_exists_true(self): + """Two agglomerations with edges pointing to each other's supervoxels.""" + # agg1 owns supervoxels [10, 11], agg2 owns supervoxels [20, 21]. + # agg1 has an out_edge from sv 10 -> sv 20 (which belongs to agg2) + # agg2 has an out_edge from sv 20 -> sv 10 (which belongs to agg1) + agg1 = _make_agg( + node_id=1, + supervoxels=[10, 11], + out_edges=[[10, 20]], + ) + agg2 = _make_agg( + node_id=2, + supervoxels=[20, 21], + out_edges=[[20, 10]], + ) + assert edge_exists([agg1, agg2]) is True + + def test_edge_exists_true_one_direction(self): + """Edge exists is True even if only one direction has a cross-reference.""" + # agg1 out_edge target (sv 20) belongs to agg2 -> True on the first condition + agg1 = _make_agg( + node_id=1, + supervoxels=[10, 11], + out_edges=[[10, 20]], + ) + agg2 = _make_agg( + node_id=2, + supervoxels=[20, 21], + out_edges=[[20, 30]], # target 30 is not in agg1 + ) + # For this to work, sv 30 must be in the supervoxel_parent_d. + # Since 30 is not in either agglomeration's supervoxels, a KeyError + # would occur when checking supervoxel_parent_d[t2]. + # The function iterates zip(targets1, targets2), checking t1 first. + # If t1 matches, it returns True before checking t2. + # So agg1.out_edges target=20 (belongs to agg2) triggers True. + # BUT: zip pairs them, and both t1 and t2 are checked. + # Actually, the condition uses OR: if t1 belongs to agg2 OR t2 belongs to agg1. + # However, supervoxel_parent_d[t2] will KeyError if t2=30 is not in the dict. + # Let's fix: put sv 30 in a third agg, or just make the targets safe. + # Instead, let's set up so that sv 30 doesn't cause a problem: + # We need all targets to be in the supervoxel_parent_d. + # Add sv 30 to agg2's supervoxels. + agg2_fixed = _make_agg( + node_id=2, + supervoxels=[20, 21, 30], + out_edges=[[20, 30]], # target 30 belongs to agg2 itself (not agg1) + ) + assert edge_exists([agg1, agg2_fixed]) is True + + def test_edge_exists_false(self): + """Two agglomerations with no cross-references between them.""" + # agg1 out_edge targets sv 11 (its own supervoxel), + # agg2 out_edge targets sv 21 (its own supervoxel). + # Neither target belongs to the other agglomeration. + agg1 = _make_agg( + node_id=1, + supervoxels=[10, 11], + out_edges=[[10, 11]], + ) + agg2 = _make_agg( + node_id=2, + supervoxels=[20, 21], + out_edges=[[20, 21]], + ) + assert edge_exists([agg1, agg2]) is False + + def test_edge_exists_single_agg(self): + """Single agglomeration returns False (no combinations to iterate).""" + agg = _make_agg( + node_id=1, + supervoxels=[10, 11], + out_edges=[[10, 11]], + ) + assert edge_exists([agg]) is False + + def test_edge_exists_empty_list(self): + """Empty list of agglomerations returns False.""" + assert edge_exists([]) is False + + def test_edge_exists_three_agglomerations(self): + """Three agglomerations where only two have a cross-reference.""" + agg1 = _make_agg( + node_id=1, + supervoxels=[10, 11], + out_edges=[[10, 20]], + ) + agg2 = _make_agg( + node_id=2, + supervoxels=[20, 21], + out_edges=[[20, 10]], + ) + agg3 = _make_agg( + node_id=3, + supervoxels=[30, 31], + out_edges=[[30, 31]], + ) + # The combination (agg1, agg2) has cross-references, so True. + assert edge_exists([agg1, agg2, agg3]) is True diff --git a/pychunkedgraph/tests/test_cutting.py b/pychunkedgraph/tests/test_cutting.py new file mode 100644 index 000000000..40a1842d6 --- /dev/null +++ b/pychunkedgraph/tests/test_cutting.py @@ -0,0 +1,1418 @@ +"""Tests for pychunkedgraph.graph.cutting""" + +import numpy as np +import pytest + +from pychunkedgraph.graph.cutting import ( + IsolatingCutException, + LocalMincutGraph, + merge_cross_chunk_edges_graph_tool, + run_multicut, +) +from pychunkedgraph.graph.edges import Edges +from pychunkedgraph.graph.exceptions import PostconditionError, PreconditionError + + +class TestIsolatingCutException: + def test_is_exception_subclass(self): + """IsolatingCutException is a proper Exception subclass.""" + assert issubclass(IsolatingCutException, Exception) + + def test_can_be_raised_and_caught(self): + with pytest.raises(IsolatingCutException): + raise IsolatingCutException("Source") + + def test_message_preserved(self): + exc = IsolatingCutException("Sink") + assert str(exc) == "Sink" + + +class TestMergeCrossChunkEdgesGraphTool: + def test_merge_cross_chunk_edges_basic(self): + """Cross-chunk edges (inf affinity) cause their endpoints to be merged. + + Edges: + 1--2 (aff=0.5, regular) + 2--3 (aff=inf, cross-chunk -> merge 2 and 3) + 3--4 (aff=0.3, regular) + + After merging, node 3 is remapped to node 2 (min of {2,3}). + The cross-chunk edge (2--3) is removed from the output. + The remaining edges become: + 1--2 (aff=0.5) + 2--4 (aff=0.3) [was 3--4, but 3 is now remapped to 2] + """ + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.5, np.inf, 0.3], dtype=np.float32) + + mapped_edges, mapped_affs, mapping, complete_mapping, remapping = ( + merge_cross_chunk_edges_graph_tool(edges, affs) + ) + + # Cross-chunk edge is removed; 2 output edges remain + assert mapped_edges.shape[0] == 2 + assert mapped_affs.shape[0] == 2 + + # Affinities of the non-cross-chunk edges are preserved + np.testing.assert_array_almost_equal( + np.sort(mapped_affs), np.array([0.3, 0.5], dtype=np.float32) + ) + + # The mapping should show that 2 and 3 map to the same representative (min=2) + assert len(remapping) == 1 + rep_node = list(remapping.keys())[0] + assert rep_node == 2 + merged_nodes = set(remapping[rep_node]) + assert 2 in merged_nodes + assert 3 in merged_nodes + + # All unique nodes appear in complete_mapping + all_mapped_from = set(complete_mapping[:, 0]) + assert {1, 2, 3, 4}.issubset(all_mapped_from) + + def test_merge_cross_chunk_edges_no_cross_chunk(self): + """When all affinities are finite, no merging occurs. + + All edges are returned as-is (no cross-chunk edges to remove). + """ + edges = np.array([[10, 20], [20, 30], [30, 40]], dtype=np.uint64) + affs = np.array([0.5, 0.8, 0.3], dtype=np.float32) + + mapped_edges, mapped_affs, mapping, complete_mapping, remapping = ( + merge_cross_chunk_edges_graph_tool(edges, affs) + ) + + # No edges removed + assert mapped_edges.shape[0] == 3 + assert mapped_affs.shape[0] == 3 + + # No remapping occurred + assert len(remapping) == 0 + + # Affinities are unchanged + np.testing.assert_array_almost_equal(mapped_affs, affs) + + # All nodes map to themselves in complete_mapping + for row in complete_mapping: + assert row[0] == row[1] + + def test_merge_cross_chunk_edges_all_cross_chunk(self): + """When all edges are cross-chunk, all edges are removed from output.""" + edges = np.array([[1, 2], [2, 3]], dtype=np.uint64) + affs = np.array([np.inf, np.inf], dtype=np.float32) + + mapped_edges, mapped_affs, mapping, complete_mapping, remapping = ( + merge_cross_chunk_edges_graph_tool(edges, affs) + ) + + # All edges were cross-chunk, so no mapped edges remain + assert mapped_edges.shape[0] == 0 + assert mapped_affs.shape[0] == 0 + + def test_merge_cross_chunk_edges_multiple_components(self): + """Multiple separate cross-chunk merges in a single call. + + Edges: + 1--2 (inf) -> merge into {1,2}, rep=1 + 3--4 (inf) -> merge into {3,4}, rep=3 + 1--3 (0.7) -> becomes 1--3 after remapping + """ + edges = np.array([[1, 2], [3, 4], [1, 3]], dtype=np.uint64) + affs = np.array([np.inf, np.inf, 0.7], dtype=np.float32) + + mapped_edges, mapped_affs, mapping, complete_mapping, remapping = ( + merge_cross_chunk_edges_graph_tool(edges, affs) + ) + + # Only 1 non-cross-chunk edge remains + assert mapped_edges.shape[0] == 1 + assert mapped_affs.shape[0] == 1 + np.testing.assert_array_almost_equal(mapped_affs, [0.7]) + + # Two remapping groups + assert len(remapping) == 2 + + +class TestLocalMincutGraph: + """Tests for LocalMincutGraph initialization and mincut computation.""" + + def test_init_basic(self): + """Create a simple 4-node line graph with a weak middle edge. + + Graph: 1 --0.9-- 2 --0.1-- 3 --0.9-- 4 + Sources: [1], Sinks: [4] + + The graph should initialize successfully and have the expected + source/sink graph ids set. + """ + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.9, 0.1, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=True, + disallow_isolating_cut=True, + ) + assert graph.weighted_graph is not None + assert graph.unique_supervoxel_ids is not None + assert len(graph.source_graph_ids) == 1 + assert len(graph.sink_graph_ids) == 1 + # Sources and sinks should be mapped correctly + assert np.array_equal(graph.sources, sources) + assert np.array_equal(graph.sinks, sinks) + + def test_init_with_cross_chunk_edges(self): + """Initialization with a mix of regular and cross-chunk edges. + + Graph: 1 --0.5-- 2 --inf-- 3 --0.5-- 4 + The inf edge merges 2 and 3 into one node. + Sources: [1], Sinks: [4] + """ + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.5, np.inf, 0.5], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=True, + ) + # After merging cross chunk edges 2 and 3, we should have fewer unique ids + assert graph.weighted_graph is not None + assert len(graph.cross_chunk_edge_remapping) == 1 + + def test_init_only_cross_chunk_raises(self): + """All inf affinities should raise PostconditionError. + + When every edge is a cross-chunk edge, all edges are removed after + merging, leaving an empty graph. This should raise PostconditionError. + """ + edges = np.array([[1, 2], [2, 3]], dtype=np.uint64) + affs = np.array([np.inf, np.inf], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([3], dtype=np.uint64) + + with pytest.raises(PostconditionError, match="cross chunk edges"): + LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + ) + + def test_compute_mincut_direct(self): + """Compute mincut with path_augment=False on a simple 2-node graph. + + Graph: 1 --0.5-- 2 + Sources: [1], Sinks: [2] + + The only possible cut is the single edge between 1 and 2. + """ + edges = np.array([[1, 2]], dtype=np.uint64) + affs = np.array([0.5], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([2], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + ) + result = graph.compute_mincut() + # The mincut should return edges to cut + assert len(result) > 0 + # The returned edges should contain the edge (1,2) or (2,1) + result_set = set(map(tuple, result)) + assert (1, 2) in result_set or (2, 1) in result_set + + def test_compute_mincut_path_augmented(self): + """Compute mincut with path_augment=True (default) on a line graph. + + Graph: 1 --0.9-- 2 --0.1-- 3 --0.9-- 4 + Sources: [1], Sinks: [4] + + The weakest edge is 2--3, so the mincut should cut there. + """ + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.9, 0.1, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=True, + disallow_isolating_cut=False, + ) + result = graph.compute_mincut() + assert len(result) > 0 + # The cut should include the weak edge (2,3) or (3,2) + result_set = set(map(tuple, result)) + assert (2, 3) in result_set or (3, 2) in result_set + # The strong edges should NOT be in the cut + assert (1, 2) not in result_set + assert (3, 4) not in result_set + + def test_compute_mincut_line_graph_cuts_weakest(self): + """Line graph with clear weakest edge - mincut should cut it. + + Graph: 10 --0.8-- 20 --0.01-- 30 --0.8-- 40 + Sources: [10], Sinks: [40] + """ + edges = np.array([[10, 20], [20, 30], [30, 40]], dtype=np.uint64) + affs = np.array([0.8, 0.01, 0.8], dtype=np.float32) + sources = np.array([10], dtype=np.uint64) + sinks = np.array([40], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + ) + result = graph.compute_mincut() + assert len(result) > 0 + result_set = set(map(tuple, result)) + assert (20, 30) in result_set or (30, 20) in result_set + + def test_compute_mincut_split_preview(self): + """Compute mincut with split_preview=True returns connected components. + + Graph: 1 --0.9-- 2 --0.1-- 3 --0.9-- 4 + Sources: [1], Sinks: [4] + """ + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.9, 0.1, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + result = graph.compute_mincut() + # split_preview returns (supervoxel_ccs, illegal_split) + supervoxel_ccs, illegal_split = result + assert isinstance(supervoxel_ccs, list) + assert len(supervoxel_ccs) >= 2 + assert isinstance(illegal_split, bool) + # First component should contain source(s), second should contain sink(s) + assert 1 in supervoxel_ccs[0] or 2 in supervoxel_ccs[0] + assert 4 in supervoxel_ccs[1] or 3 in supervoxel_ccs[1] + + +class TestRunMulticut: + """Tests for the run_multicut function.""" + + def test_basic_split(self): + """Two groups connected by a weak edge -- mincut should cut that edge. + + Graph: 1 --0.9-- 2 --0.05-- 3 --0.9-- 4 + Sources: [1], Sinks: [4] + """ + node_ids1 = np.array([1, 2, 3], dtype=np.uint64) + node_ids2 = np.array([2, 3, 4], dtype=np.uint64) + affinities = np.array([0.9, 0.05, 0.9], dtype=np.float32) + + edges = Edges(node_ids1, node_ids2, affinities=affinities) + result = run_multicut( + edges, + source_ids=[np.uint64(1)], + sink_ids=[np.uint64(4)], + path_augment=True, + disallow_isolating_cut=False, + ) + assert len(result) > 0 + result_set = set(map(tuple, result)) + assert (2, 3) in result_set or (3, 2) in result_set + + def test_basic_split_direct(self): + """Same as test_basic_split but with path_augment=False.""" + node_ids1 = np.array([1, 2, 3], dtype=np.uint64) + node_ids2 = np.array([2, 3, 4], dtype=np.uint64) + affinities = np.array([0.9, 0.05, 0.9], dtype=np.float32) + + edges = Edges(node_ids1, node_ids2, affinities=affinities) + result = run_multicut( + edges, + source_ids=[np.uint64(1)], + sink_ids=[np.uint64(4)], + path_augment=False, + disallow_isolating_cut=False, + ) + assert len(result) > 0 + result_set = set(map(tuple, result)) + assert (2, 3) in result_set or (3, 2) in result_set + + def test_no_edges_raises(self): + """Graph with only cross-chunk edges raises PostconditionError. + + When all edges have infinite affinity, the local graph is empty after + merging cross-chunk edges, and LocalMincutGraph raises PostconditionError. + """ + node_ids1 = np.array([1, 2], dtype=np.uint64) + node_ids2 = np.array([2, 3], dtype=np.uint64) + affinities = np.array([np.inf, np.inf], dtype=np.float32) + + edges = Edges(node_ids1, node_ids2, affinities=affinities) + with pytest.raises(PostconditionError, match="cross chunk edges"): + run_multicut( + edges, + source_ids=[np.uint64(1)], + sink_ids=[np.uint64(3)], + ) + + def test_split_preview_mode(self): + """run_multicut with split_preview=True returns (ccs, illegal_split).""" + node_ids1 = np.array([1, 2, 3], dtype=np.uint64) + node_ids2 = np.array([2, 3, 4], dtype=np.uint64) + affinities = np.array([0.9, 0.05, 0.9], dtype=np.float32) + + edges = Edges(node_ids1, node_ids2, affinities=affinities) + result = run_multicut( + edges, + source_ids=[np.uint64(1)], + sink_ids=[np.uint64(4)], + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + supervoxel_ccs, illegal_split = result + assert isinstance(supervoxel_ccs, list) + assert len(supervoxel_ccs) >= 2 + assert isinstance(illegal_split, bool) + + +class TestMergeCrossChunkEdgesOverlap: + """Test edge cases in merge_cross_chunk_edges_graph_tool.""" + + def test_duplicate_cross_chunk_edges(self): + """Duplicate cross-chunk edges should still merge correctly.""" + edges = np.array([[1, 2], [1, 2], [2, 3]], dtype=np.uint64) + affs = np.array([np.inf, np.inf, 0.5], dtype=np.float32) + + mapped_edges, mapped_affs, mapping, complete_mapping, remapping = ( + merge_cross_chunk_edges_graph_tool(edges, affs) + ) + + # Only the finite-affinity edge should remain + assert mapped_edges.shape[0] == 1 + assert mapped_affs[0] == pytest.approx(0.5) + + def test_self_loop_after_merge(self): + """When merging creates a self-loop, it should be present but with correct count.""" + # 1-2 inf, 1-2 finite -> after merge, 1-1 (self-loop) is created + edges = np.array([[1, 2], [1, 2]], dtype=np.uint64) + affs = np.array([np.inf, 0.5], dtype=np.float32) + + mapped_edges, mapped_affs, mapping, complete_mapping, remapping = ( + merge_cross_chunk_edges_graph_tool(edges, affs) + ) + + # One non-inf edge remains, but both endpoints map to same node + assert mapped_edges.shape[0] == 1 + assert mapped_edges[0][0] == mapped_edges[0][1] + + def test_chain_of_cross_chunk_edges(self): + """A chain of cross-chunk edges: 1-2(inf), 2-3(inf), 3-4(inf). + All should merge into one component.""" + edges = np.array([[1, 2], [2, 3], [3, 4], [1, 5]], dtype=np.uint64) + affs = np.array([np.inf, np.inf, np.inf, 0.7], dtype=np.float32) + + mapped_edges, mapped_affs, mapping, complete_mapping, remapping = ( + merge_cross_chunk_edges_graph_tool(edges, affs) + ) + + # Only 1 non-cross edge remains + assert mapped_edges.shape[0] == 1 + assert mapped_affs[0] == pytest.approx(0.7) + # All of 1,2,3,4 should be in one remapping group + assert len(remapping) == 1 + rep = list(remapping.keys())[0] + assert rep == 1 # min of {1,2,3,4} + assert set(remapping[rep]) == {1, 2, 3, 4} + + +class TestRemapCutEdgeSet: + """Test _remap_cut_edge_set handles cross-chunk remapping correctly.""" + + def test_remap_with_cross_chunk_remapping(self): + """When cross-chunk edge remapping is present, cut edges should expand to all + mapped supervoxels.""" + # Graph: 1 --0.5-- 2 --inf-- 3 --0.5-- 4 + # Nodes 2 and 3 merge -> rep=2, remapping[2]=[2,3] + # Source: [1], Sink: [4] + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.5, np.inf, 0.5], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + ) + + # The cross_chunk_edge_remapping should exist + assert len(graph.cross_chunk_edge_remapping) == 1 + result = graph.compute_mincut() + + # Result should contain edges from the original edge set + result_set = set(map(tuple, result)) + # At least one of the original edges should appear + assert len(result_set) > 0 + # All returned edges should be from the original cg_edges + for edge in result: + assert tuple(edge) in {(1, 2), (2, 1), (2, 3), (3, 2), (3, 4), (4, 3)} + + def test_remap_no_cross_chunk(self): + """Without cross-chunk edges, remap should just return original supervoxel ids.""" + edges = np.array([[1, 2], [2, 3]], dtype=np.uint64) + affs = np.array([0.9, 0.1], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([3], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + ) + + assert len(graph.cross_chunk_edge_remapping) == 0 + result = graph.compute_mincut() + result_set = set(map(tuple, result)) + # The weak edge 2-3 should be cut + assert (2, 3) in result_set or (3, 2) in result_set + + +class TestSplitPreviewConnectedComponents: + """Test _get_split_preview_connected_components orders CCs correctly.""" + + def test_source_first_sink_second(self): + """split_preview should return sources in ccs[0] and sinks in ccs[1].""" + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.9, 0.01, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + + supervoxel_ccs, illegal_split = graph.compute_mincut() + assert isinstance(supervoxel_ccs, list) + assert len(supervoxel_ccs) >= 2 + # First CC should contain source supervoxels + assert 1 in supervoxel_ccs[0] + # Second CC should contain sink supervoxels + assert 4 in supervoxel_ccs[1] + assert isinstance(illegal_split, bool) + assert not illegal_split + + def test_multiple_sources_and_sinks(self): + """With multiple sources and sinks, each group stays in its own CC.""" + # 1-2-3-4-5-6, cut between 3-4 + edges = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]], dtype=np.uint64) + affs = np.array([0.9, 0.9, 0.01, 0.9, 0.9], dtype=np.float32) + sources = np.array([1, 2], dtype=np.uint64) + sinks = np.array([5, 6], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + + supervoxel_ccs, illegal_split = graph.compute_mincut() + # Both sources should be in ccs[0] + assert 1 in supervoxel_ccs[0] + assert 2 in supervoxel_ccs[0] + # Both sinks should be in ccs[1] + assert 5 in supervoxel_ccs[1] + assert 6 in supervoxel_ccs[1] + + def test_split_preview_with_cross_chunk(self): + """split_preview with cross-chunk edges should expand remapped nodes in CCs.""" + # 1 --0.5-- 2 --inf-- 3 --0.01-- 4 + # Nodes 2,3 merge. Cut between merged(2,3) and 4. + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.5, np.inf, 0.01], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + + supervoxel_ccs, illegal_split = graph.compute_mincut() + assert len(supervoxel_ccs) >= 2 + # After expanding cross-chunk remapping, source CC should contain 1, 2, 3 + all_source_svs = set(supervoxel_ccs[0]) + assert 1 in all_source_svs + # 2 and 3 were merged and should appear in same CC as source + assert 2 in all_source_svs or 3 in all_source_svs + # Sink CC should contain 4 + assert 4 in set(supervoxel_ccs[1]) + + +class TestSanityCheck: + """Test _sink_and_source_connectivity_sanity_check edge cases.""" + + def test_split_preview_illegal_split_flag(self): + """In split_preview mode, when sanity check would normally raise, + illegal_split should be True rather than raising an error.""" + # Create a graph where the cut might produce an unusual partition. + edges = np.array([[1, 2], [2, 3], [1, 3]], dtype=np.uint64) + affs = np.array([0.01, 0.01, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([3], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + + result = graph.compute_mincut() + supervoxel_ccs, illegal_split = result + # Should return valid result without raising + assert isinstance(supervoxel_ccs, list) + assert isinstance(illegal_split, bool) + + def test_non_preview_postcondition_error_on_empty_cut(self): + """run_multicut raises PostconditionError when mincut produces empty cut set.""" + # When all edges are cross-chunk, PostconditionError is raised + node_ids1 = np.array([1, 2], dtype=np.uint64) + node_ids2 = np.array([2, 3], dtype=np.uint64) + affinities = np.array([np.inf, np.inf], dtype=np.float32) + + edges_obj = Edges(node_ids1, node_ids2, affinities=affinities) + with pytest.raises(PostconditionError): + run_multicut( + edges_obj, + source_ids=[np.uint64(1)], + sink_ids=[np.uint64(3)], + ) + + +class TestRunMulticutSplitPreview: + """Test run_multicut in split_preview mode returns correct structure.""" + + def test_split_preview_returns_ccs_and_flag(self): + """run_multicut with split_preview=True should return (ccs, illegal_split).""" + node_ids1 = np.array([1, 2, 3], dtype=np.uint64) + node_ids2 = np.array([2, 3, 4], dtype=np.uint64) + affinities = np.array([0.9, 0.01, 0.9], dtype=np.float32) + + edges_obj = Edges(node_ids1, node_ids2, affinities=affinities) + result = run_multicut( + edges_obj, + source_ids=[np.uint64(1)], + sink_ids=[np.uint64(4)], + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + + supervoxel_ccs, illegal_split = result + assert isinstance(supervoxel_ccs, list) + assert len(supervoxel_ccs) >= 2 + assert isinstance(illegal_split, bool) + + # Source side CC + assert 1 in supervoxel_ccs[0] + # Sink side CC + assert 4 in supervoxel_ccs[1] + + def test_split_preview_with_path_augment(self): + """run_multicut with split_preview=True and path_augment=True.""" + node_ids1 = np.array([1, 2, 3, 4], dtype=np.uint64) + node_ids2 = np.array([2, 3, 4, 5], dtype=np.uint64) + affinities = np.array([0.9, 0.9, 0.01, 0.9], dtype=np.float32) + + edges_obj = Edges(node_ids1, node_ids2, affinities=affinities) + result = run_multicut( + edges_obj, + source_ids=[np.uint64(1)], + sink_ids=[np.uint64(5)], + split_preview=True, + path_augment=True, + disallow_isolating_cut=False, + ) + + supervoxel_ccs, illegal_split = result + assert len(supervoxel_ccs) >= 2 + # Source side + assert 1 in supervoxel_ccs[0] + # Sink side + assert 5 in supervoxel_ccs[1] + + def test_split_preview_larger_graph(self): + """split_preview on a larger graph with a clear cut point.""" + # Two clusters connected by a single weak edge + # Cluster A: 1-2, 1-3, 2-3 (all strong) + # Cluster B: 4-5, 4-6, 5-6 (all strong) + # Bridge: 3-4 (weak) + node_ids1 = np.array([1, 1, 2, 4, 4, 5, 3], dtype=np.uint64) + node_ids2 = np.array([2, 3, 3, 5, 6, 6, 4], dtype=np.uint64) + affinities = np.array([0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.01], dtype=np.float32) + + edges_obj = Edges(node_ids1, node_ids2, affinities=affinities) + result = run_multicut( + edges_obj, + source_ids=[np.uint64(1)], + sink_ids=[np.uint64(6)], + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + + supervoxel_ccs, illegal_split = result + source_cc = set(supervoxel_ccs[0]) + sink_cc = set(supervoxel_ccs[1]) + # Source cluster + assert {1, 2, 3}.issubset(source_cc) + # Sink cluster + assert {4, 5, 6}.issubset(sink_cc) + assert not illegal_split + + +class TestLocalMincutGraphWithLogger: + """Test that logging branches are exercised without errors.""" + + def test_init_with_logger(self): + """Passing a logger should not break initialization.""" + import logging + + logger = logging.getLogger("test_cutting_logger") + logger.setLevel(logging.DEBUG) + + edges = np.array([[1, 2], [2, 3]], dtype=np.uint64) + affs = np.array([0.9, 0.1], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([3], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + logger=logger, + ) + + assert graph.weighted_graph is not None + + def test_compute_mincut_with_logger(self): + """Compute mincut with a logger should produce debug messages.""" + import logging + + logger = logging.getLogger("test_cutting_mincut_logger") + logger.setLevel(logging.DEBUG) + + edges = np.array([[1, 2], [2, 3]], dtype=np.uint64) + affs = np.array([0.9, 0.1], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([3], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + logger=logger, + ) + + result = graph.compute_mincut() + assert len(result) > 0 + + +class TestFilterGraphConnectedComponents: + """Test edge cases in _filter_graph_connected_components.""" + + def test_disconnected_source_sink_raises(self): + """When sources and sinks are in different connected components, should raise.""" + # Two disconnected components: {1,2} and {3,4} + # Sources in one, sinks in other + edges = np.array([[1, 2], [3, 4]], dtype=np.uint64) + affs = np.array([0.5, 0.5], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + ) + + with pytest.raises(PreconditionError): + graph.compute_mincut() + + +class TestPartitionEdgesWithinLabel: + """Test the partition_edges_within_label method.""" + + def test_all_edges_within_labels(self): + """When all out-edges of a component go to labeled nodes, returns True.""" + # Simple triangle: 1-2-3-1, sources=[1,2], sinks=[3] + edges = np.array([[1, 2], [2, 3], [1, 3]], dtype=np.uint64) + affs = np.array([0.9, 0.9, 0.9], dtype=np.float32) + sources = np.array([1, 2], dtype=np.uint64) + sinks = np.array([3], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + ) + + # All nodes are labeled, so any CC should return True + result = graph.partition_edges_within_label(graph.source_graph_ids) + assert isinstance(result, bool) + + def test_edges_outside_labels_returns_false(self): + """When a node has edges to an unlabeled node, returns False.""" + # 1 --0.9-- 2 --0.9-- 3 --0.9-- 4 + # sources=[1], sinks=[4], so nodes 2 and 3 are unlabeled + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.9, 0.9, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + ) + + # The source node 1 has edges to node 2 which is not a label node + result = graph.partition_edges_within_label(graph.source_graph_ids) + assert result is False + + +class TestAugmentMincutCapacityOverlap: + """Test path augmentation when source and sink paths overlap.""" + + def test_overlapping_paths_resolved(self): + """Graph with overlapping shortest paths between sources and sinks. + + Graph topology: + 1--2--3--4--5 + | | + 6--7--8 + + Sources: [1, 6], Sinks: [5, 8] + Paths from 1->5 and 6->8 overlap at nodes 2, 3, 4. + The path augmentation should resolve this overlap. + """ + edges = np.array( + [ + [1, 2], + [2, 3], + [3, 4], + [4, 5], + [2, 6], + [6, 7], + [7, 8], + [8, 4], + ], + dtype=np.uint64, + ) + affs = np.array([0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9], dtype=np.float32) + sources = np.array([1, 6], dtype=np.uint64) + sinks = np.array([5, 8], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=True, + disallow_isolating_cut=False, + ) + # The graph should initialize and compute the augmented capacity + # without errors, even with overlapping paths + result = graph.compute_mincut() + assert len(result) > 0 + + def test_overlapping_paths_with_weak_bridge(self): + """Graph with overlapping paths and a clear weak bridge to cut. + + Graph: + 1--2--3--4--5 + | | + 6--7--8 + + Edge 3-4 is weak (0.01), all others strong (0.9). + Sources: [1, 6], Sinks: [5, 8] + """ + edges = np.array( + [ + [1, 2], + [2, 3], + [3, 4], + [4, 5], + [2, 6], + [6, 7], + [7, 8], + [8, 4], + ], + dtype=np.uint64, + ) + affs = np.array([0.9, 0.9, 0.01, 0.9, 0.9, 0.9, 0.9, 0.9], dtype=np.float32) + sources = np.array([1, 6], dtype=np.uint64) + sinks = np.array([5, 8], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=True, + disallow_isolating_cut=False, + ) + result = graph.compute_mincut() + assert len(result) > 0 + result_set = set(map(tuple, result)) + # The weak edge 3-4 should be among the cut edges + assert (3, 4) in result_set or (4, 3) in result_set + + def test_path_augment_multiple_sources_sinks_no_overlap(self): + """Multiple sources and sinks where paths do not overlap. + + Graph: + 1--2--3--4 + | + 5--6 + + Sources: [1], Sinks: [4] + """ + edges = np.array( + [[1, 2], [2, 3], [3, 4], [3, 5], [5, 6]], + dtype=np.uint64, + ) + affs = np.array([0.9, 0.01, 0.9, 0.9, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=True, + disallow_isolating_cut=False, + ) + result = graph.compute_mincut() + assert len(result) > 0 + result_set = set(map(tuple, result)) + assert (2, 3) in result_set or (3, 2) in result_set + + +class TestSplitPreviewMultipleCCs: + """Test _get_split_preview_connected_components with more than 2 components.""" + + def test_three_components(self): + """A graph that splits into 3 components after cut. + + Graph: 1--2--3--4--5 with weak links at 2-3 and 3-4. + After cutting both weak links, we get 3 components: + {1,2}, {3}, {4,5} + """ + edges = np.array( + [[1, 2], [2, 3], [3, 4], [4, 5]], + dtype=np.uint64, + ) + affs = np.array([0.9, 0.01, 0.01, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([5], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + supervoxel_ccs, illegal_split = graph.compute_mincut() + assert isinstance(supervoxel_ccs, list) + assert len(supervoxel_ccs) >= 2 + # Source should be in first CC + assert 1 in supervoxel_ccs[0] + # Sink should be in second CC + assert 5 in supervoxel_ccs[1] + assert isinstance(illegal_split, bool) + + def test_split_preview_preserves_all_nodes(self): + """All nodes should appear across the CCs.""" + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.9, 0.01, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + supervoxel_ccs, _ = graph.compute_mincut() + all_nodes = set() + for cc in supervoxel_ccs: + all_nodes.update(set(cc)) + # All original nodes should appear in some CC + assert {1, 2, 3, 4}.issubset(all_nodes) + + +class TestRunSplitPreview: + """Test the module-level run_split_preview function. + + Note: The full run_split_preview requires a ChunkedGraph instance, + so we test through run_multicut with split_preview=True which exercises + the same _get_split_preview_connected_components code path. + """ + + def test_basic_split_preview(self): + """run_multicut with split_preview should return CCs and a flag.""" + edges_sv = Edges( + np.array([1, 2, 3, 4], dtype=np.uint64), + np.array([2, 3, 4, 5], dtype=np.uint64), + affinities=np.array([0.9, 0.1, 0.9, 0.9], dtype=np.float32), + areas=np.array([1, 1, 1, 1], dtype=np.float32), + ) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([5], dtype=np.uint64) + ccs, illegal_split = run_multicut( + edges_sv, + sources, + sinks, + split_preview=True, + disallow_isolating_cut=False, + ) + assert isinstance(ccs, list) + assert isinstance(illegal_split, bool) + assert len(ccs) >= 2 + + def test_split_preview_with_areas(self): + """Split preview with areas provided.""" + edges_sv = Edges( + np.array([10, 20, 30], dtype=np.uint64), + np.array([20, 30, 40], dtype=np.uint64), + affinities=np.array([0.9, 0.01, 0.9], dtype=np.float32), + areas=np.array([100, 5, 100], dtype=np.float32), + ) + sources = np.array([10], dtype=np.uint64) + sinks = np.array([40], dtype=np.uint64) + ccs, illegal_split = run_multicut( + edges_sv, + sources, + sinks, + split_preview=True, + path_augment=False, + disallow_isolating_cut=False, + ) + assert isinstance(ccs, list) + assert len(ccs) >= 2 + # Source side should contain 10 + assert 10 in ccs[0] + # Sink side should contain 40 + assert 40 in ccs[1] + + def test_split_preview_path_augment(self): + """Split preview with path_augment=True.""" + edges_sv = Edges( + np.array([1, 2, 3, 4, 5], dtype=np.uint64), + np.array([2, 3, 4, 5, 6], dtype=np.uint64), + affinities=np.array([0.9, 0.9, 0.01, 0.9, 0.9], dtype=np.float32), + ) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([6], dtype=np.uint64) + ccs, illegal_split = run_multicut( + edges_sv, + sources, + sinks, + split_preview=True, + path_augment=True, + disallow_isolating_cut=False, + ) + assert isinstance(ccs, list) + assert len(ccs) >= 2 + assert 1 in ccs[0] + assert 6 in ccs[1] + assert not illegal_split + + +class TestFilterGraphCCsWithLogger: + """Test _filter_graph_connected_components logs a warning when sources + and sinks are in different connected components.""" + + def test_disconnected_with_logger_raises(self): + """Disconnected graph with logger should log warning and raise.""" + import logging + + logger = logging.getLogger("test_filter_cc_logger") + logger.setLevel(logging.DEBUG) + + edges = np.array([[1, 2], [3, 4]], dtype=np.uint64) + affs = np.array([0.5, 0.5], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + logger=logger, + ) + + with pytest.raises( + PreconditionError, match="Sinks and sources are not connected" + ): + graph.compute_mincut() + + +class TestGtMincutSanityCheck: + """Test the _gt_mincut_sanity_check debug method.""" + + def test_sanity_check_valid_partition(self): + """A valid partition should pass the sanity check without error.""" + import graph_tool + import graph_tool.flow + + edges = np.array([[1, 2], [2, 3]], dtype=np.uint64) + affs = np.array([0.9, 0.1], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([3], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=False, + ) + + # Manually compute partition to test the sanity check + graph._filter_graph_connected_components() + src = graph.weighted_graph.vertex(graph.source_graph_ids[0]) + tgt = graph.weighted_graph.vertex(graph.sink_graph_ids[0]) + residuals = graph_tool.flow.push_relabel_max_flow( + graph.weighted_graph, src, tgt, graph.capacities + ) + partition = graph_tool.flow.min_st_cut( + graph.weighted_graph, src, graph.capacities, residuals + ) + # This should not raise any assertion error + graph._gt_mincut_sanity_check(partition) + + +class TestIsolatingCutPath: + """Test the IsolatingCutException path in _sink_and_source_connectivity_sanity_check.""" + + def test_isolating_cut_raises_precondition_error(self): + """When mincut isolates exactly the labeled points and they have edges + to non-label nodes, PreconditionError is raised. + + Graph: 1 --0.01-- 2 --0.9-- 3 --0.9-- 4 + Sources: [1], Sinks: [4] + disallow_isolating_cut=True + + The mincut cuts edge 1-2 (weakest). After cut, source CC = {1}. + source_path_vertices = source_graph_ids = {1} (path_augment=False). + len(source_path_vertices) == len(cc) == 1. + In the raw graph, node 1 has neighbor 2 which is NOT a label node. + partition_edges_within_label returns False -> IsolatingCutException -> PreconditionError. + """ + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.01, 0.9, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=False, + disallow_isolating_cut=True, + ) + # This should raise PreconditionError about isolating cut + with pytest.raises(PreconditionError, match="cut off only the labeled"): + graph.compute_mincut() + + def test_isolating_cut_split_preview_returns_illegal(self): + """In split_preview mode, isolating cut should set illegal_split=True.""" + edges = np.array([[1, 2], [2, 3], [3, 4]], dtype=np.uint64) + affs = np.array([0.01, 0.9, 0.9], dtype=np.float32) + sources = np.array([1], dtype=np.uint64) + sinks = np.array([4], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=True, + path_augment=False, + disallow_isolating_cut=True, + ) + supervoxel_ccs, illegal_split = graph.compute_mincut() + assert isinstance(supervoxel_ccs, list) + assert illegal_split is True + + +class TestRerunPathsWithoutOverlap: + """Test that the rerun_paths_without_overlap code path is exercised + when source and sink shortest paths overlap and removing overlap + breaks connectedness.""" + + def test_forced_overlap_resolution(self): + """Create graph where source/sink paths overlap, forcing rerun. + + Graph: + 1--2--3 + | | | + 4--5--6 + + Sources: [1, 4], Sinks: [3, 6] + Paths from 1->3 and 4->6 both go through 2 and 5, causing overlap. + The path augmentation should resolve the overlap via rerun_paths_without_overlap. + """ + edges = np.array( + [ + [1, 2], + [2, 3], + [4, 5], + [5, 6], + [1, 4], + [2, 5], + [3, 6], + ], + dtype=np.uint64, + ) + affs = np.array([0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float32) + sources = np.array([1, 4], dtype=np.uint64) + sinks = np.array([3, 6], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=True, + disallow_isolating_cut=False, + ) + result = graph.compute_mincut() + assert len(result) > 0 + + def test_forced_overlap_resolution_asymmetric(self): + """Asymmetric graph where one team wins overlap by harmonic mean. + + Graph: + 1--2--3--4 + | | | | + 5--6--7--8 + + Sources: [1, 5], Sinks: [4, 8] + Paths overlap at intermediate nodes 2,3,6,7. + The path augmentation should resolve the overlap. + """ + edges = np.array( + [ + [1, 2], + [2, 3], + [3, 4], + [5, 6], + [6, 7], + [7, 8], + [1, 5], + [2, 6], + [3, 7], + [4, 8], + ], + dtype=np.uint64, + ) + affs = np.array( + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + dtype=np.float32, + ) + sources = np.array([1, 5], dtype=np.uint64) + sinks = np.array([4, 8], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=True, + disallow_isolating_cut=False, + ) + result = graph.compute_mincut() + assert len(result) > 0 + + def test_overlap_resolution_with_clear_cut(self): + """Graph with overlap at a bottleneck, but weak bridge for the cut. + + Graph: + 1--2--3--4--5 + | | + 6--7--8 + + Sources: [1, 6], Sinks: [5, 8] + Edge 3-4 is very weak (0.01), all others strong. + Overlap is forced at node 2 or 4. + """ + edges = np.array( + [ + [1, 2], + [2, 3], + [3, 4], + [4, 5], + [2, 6], + [6, 7], + [7, 8], + [8, 4], + ], + dtype=np.uint64, + ) + # Make the bridge edge very weak + affs = np.array([0.9, 0.9, 0.01, 0.9, 0.9, 0.9, 0.9, 0.9], dtype=np.float32) + sources = np.array([1, 6], dtype=np.uint64) + sinks = np.array([5, 8], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=False, + path_augment=True, + disallow_isolating_cut=False, + ) + result = graph.compute_mincut() + assert len(result) > 0 + result_set = set(map(tuple, result)) + # The weak edge 3-4 should be among the cut edges + assert (3, 4) in result_set or (4, 3) in result_set + + def test_overlap_with_split_preview(self): + """Split preview mode with overlapping paths should produce valid CCs.""" + edges = np.array( + [ + [1, 2], + [2, 3], + [4, 5], + [5, 6], + [1, 4], + [2, 5], + [3, 6], + ], + dtype=np.uint64, + ) + affs = np.array([0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float32) + sources = np.array([1, 4], dtype=np.uint64) + sinks = np.array([3, 6], dtype=np.uint64) + + graph = LocalMincutGraph( + edges, + affs, + sources, + sinks, + split_preview=True, + path_augment=True, + disallow_isolating_cut=False, + ) + supervoxel_ccs, illegal_split = graph.compute_mincut() + assert isinstance(supervoxel_ccs, list) + assert len(supervoxel_ccs) >= 2 + assert isinstance(illegal_split, bool) diff --git a/pychunkedgraph/tests/test_edges_definitions.py b/pychunkedgraph/tests/test_edges_definitions.py new file mode 100644 index 000000000..e1ab45288 --- /dev/null +++ b/pychunkedgraph/tests/test_edges_definitions.py @@ -0,0 +1,105 @@ +"""Tests for pychunkedgraph.graph.edges.definitions""" + +import pytest +import numpy as np + +from pychunkedgraph.graph.edges.definitions import ( + Edges, + EDGE_TYPES, + DEFAULT_AFFINITY, + DEFAULT_AREA, +) +from pychunkedgraph.graph.utils import basetypes + + +class TestEdgeTypes: + def test_fields(self): + assert EDGE_TYPES.in_chunk == "in" + assert EDGE_TYPES.between_chunk == "between" + assert EDGE_TYPES.cross_chunk == "cross" + + +class TestEdges: + def test_creation_defaults(self): + ids1 = np.array([1, 2], dtype=basetypes.NODE_ID) + ids2 = np.array([3, 4], dtype=basetypes.NODE_ID) + e = Edges(ids1, ids2) + np.testing.assert_array_equal(e.node_ids1, ids1) + np.testing.assert_array_equal(e.node_ids2, ids2) + assert np.all(e.affinities == DEFAULT_AFFINITY) + assert np.all(e.areas == DEFAULT_AREA) + + def test_creation_explicit(self): + ids1 = np.array([1, 2], dtype=basetypes.NODE_ID) + ids2 = np.array([3, 4], dtype=basetypes.NODE_ID) + affs = np.array([0.5, 0.9], dtype=basetypes.EDGE_AFFINITY) + areas = np.array([10.0, 20.0], dtype=basetypes.EDGE_AREA) + e = Edges(ids1, ids2, affinities=affs, areas=areas) + np.testing.assert_array_almost_equal(e.affinities, affs) + np.testing.assert_array_almost_equal(e.areas, areas) + + def test_creation_empty(self): + e = Edges([], []) + assert len(e) == 0 + pairs = e.get_pairs() + assert pairs.shape[0] == 0 + + def test_len(self): + e = Edges([1, 2, 3], [4, 5, 6]) + assert len(e) == 3 + + def test_add(self): + e1 = Edges([1], [2], affinities=[0.5], areas=[10.0]) + e2 = Edges([3], [4], affinities=[0.9], areas=[20.0]) + e3 = e1 + e2 + assert len(e3) == 2 + np.testing.assert_array_equal(e3.node_ids1, [1, 3]) + np.testing.assert_array_equal(e3.node_ids2, [2, 4]) + + def test_iadd(self): + e1 = Edges([1], [2]) + e2 = Edges([3], [4]) + e1 += e2 + assert len(e1) == 2 + np.testing.assert_array_equal(e1.node_ids1, [1, 3]) + + def test_getitem_boolean(self): + e = Edges([1, 2, 3], [4, 5, 6], affinities=[0.1, 0.5, 0.9], areas=[1, 2, 3]) + mask = np.array([True, False, True]) + filtered = e[mask] + assert len(filtered) == 2 + np.testing.assert_array_equal(filtered.node_ids1, [1, 3]) + + def test_getitem_error(self): + e = Edges([1, 2], [3, 4]) + with pytest.raises(Exception): + e["invalid_key"] + + def test_get_pairs(self): + e = Edges([1, 2], [3, 4]) + pairs = e.get_pairs() + assert pairs.shape == (2, 2) + np.testing.assert_array_equal(pairs[:, 0], [1, 2]) + np.testing.assert_array_equal(pairs[:, 1], [3, 4]) + + def test_get_pairs_caching(self): + e = Edges([1, 2], [3, 4]) + p1 = e.get_pairs() + p2 = e.get_pairs() + assert p1 is p2 + + def test_size_mismatch_raises(self): + with pytest.raises(AssertionError): + Edges([1, 2], [3]) + + def test_affinities_setter(self): + e = Edges([1], [2]) + new_affs = np.array([0.99], dtype=basetypes.EDGE_AFFINITY) + e.affinities = new_affs + np.testing.assert_array_almost_equal(e.affinities, new_affs) + + def test_areas_setter(self): + e = Edges([1], [2]) + new_areas = np.array([42.0], dtype=basetypes.EDGE_AREA) + e.areas = new_areas + np.testing.assert_array_almost_equal(e.areas, new_areas) diff --git a/pychunkedgraph/tests/test_edges_utils.py b/pychunkedgraph/tests/test_edges_utils.py new file mode 100644 index 000000000..775823870 --- /dev/null +++ b/pychunkedgraph/tests/test_edges_utils.py @@ -0,0 +1,96 @@ +"""Tests for pychunkedgraph.graph.edges.utils""" + +import numpy as np + +from pychunkedgraph.graph.edges import Edges, EDGE_TYPES +from pychunkedgraph.graph.edges.utils import ( + concatenate_chunk_edges, + concatenate_cross_edge_dicts, + merge_cross_edge_dicts, + get_cross_chunk_edges_layer, +) +from pychunkedgraph.graph.utils import basetypes + +from .helpers import to_label + + +class TestConcatenateChunkEdges: + def test_basic(self): + d1 = { + EDGE_TYPES.in_chunk: Edges([1, 2], [3, 4]), + EDGE_TYPES.between_chunk: Edges([5], [6]), + EDGE_TYPES.cross_chunk: Edges([], []), + } + d2 = { + EDGE_TYPES.in_chunk: Edges([7], [8]), + EDGE_TYPES.between_chunk: Edges([], []), + EDGE_TYPES.cross_chunk: Edges([9], [10]), + } + result = concatenate_chunk_edges([d1, d2]) + assert len(result[EDGE_TYPES.in_chunk]) == 3 + assert len(result[EDGE_TYPES.between_chunk]) == 1 + assert len(result[EDGE_TYPES.cross_chunk]) == 1 + + def test_empty(self): + result = concatenate_chunk_edges([]) + for edge_type in EDGE_TYPES: + assert len(result[edge_type]) == 0 + + +class TestConcatenateCrossEdgeDicts: + def test_no_unique(self): + d1 = {3: np.array([[1, 2]], dtype=basetypes.NODE_ID)} + d2 = {3: np.array([[1, 2], [3, 4]], dtype=basetypes.NODE_ID)} + result = concatenate_cross_edge_dicts([d1, d2], unique=False) + assert len(result[3]) == 3 # duplicates kept + + def test_unique(self): + d1 = {3: np.array([[1, 2]], dtype=basetypes.NODE_ID)} + d2 = {3: np.array([[1, 2], [3, 4]], dtype=basetypes.NODE_ID)} + result = concatenate_cross_edge_dicts([d1, d2], unique=True) + assert len(result[3]) == 2 # duplicates removed + + def test_different_layers(self): + d1 = {3: np.array([[1, 2]], dtype=basetypes.NODE_ID)} + d2 = {4: np.array([[5, 6]], dtype=basetypes.NODE_ID)} + result = concatenate_cross_edge_dicts([d1, d2]) + assert 3 in result + assert 4 in result + + +class TestMergeCrossEdgeDicts: + def test_basic(self): + d1 = { + np.uint64(100): {3: np.array([[1, 2]], dtype=basetypes.NODE_ID)}, + } + d2 = { + np.uint64(100): {3: np.array([[3, 4]], dtype=basetypes.NODE_ID)}, + np.uint64(200): {4: np.array([[5, 6]], dtype=basetypes.NODE_ID)}, + } + result = merge_cross_edge_dicts(d1, d2) + assert np.uint64(100) in result + assert np.uint64(200) in result + assert len(result[np.uint64(100)][3]) == 2 + + +class TestGetCrossChunkEdgesLayer: + def test_empty(self, gen_graph): + graph = gen_graph(n_layers=4) + result = get_cross_chunk_edges_layer(graph.meta, []) + assert len(result) == 0 + + def test_same_chunk(self, gen_graph): + graph = gen_graph(n_layers=4) + sv1 = to_label(graph, 1, 0, 0, 0, 1) + sv2 = to_label(graph, 1, 0, 0, 0, 2) + edges = np.array([[sv1, sv2]], dtype=basetypes.NODE_ID) + result = get_cross_chunk_edges_layer(graph.meta, edges) + assert result[0] == 1 # same chunk -> layer 1 + + def test_adjacent_chunks(self, gen_graph): + graph = gen_graph(n_layers=4) + sv1 = to_label(graph, 1, 0, 0, 0, 1) + sv2 = to_label(graph, 1, 1, 0, 0, 1) + edges = np.array([[sv1, sv2]], dtype=basetypes.NODE_ID) + result = get_cross_chunk_edges_layer(graph.meta, edges) + assert result[0] >= 2 # different chunks -> higher layer diff --git a/pychunkedgraph/tests/test_edits_extended.py b/pychunkedgraph/tests/test_edits_extended.py new file mode 100644 index 000000000..bc1227de7 --- /dev/null +++ b/pychunkedgraph/tests/test_edits_extended.py @@ -0,0 +1,55 @@ +"""Tests for pychunkedgraph.graph.edits - extended coverage""" + +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from pychunkedgraph.graph.edits import flip_ids +from pychunkedgraph.graph.utils import basetypes + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestFlipIds: + def test_basic(self): + id_map = { + np.uint64(1): {np.uint64(10), np.uint64(11)}, + np.uint64(2): {np.uint64(20)}, + } + result = flip_ids(id_map, [np.uint64(1), np.uint64(2)]) + assert np.uint64(10) in result + assert np.uint64(11) in result + assert np.uint64(20) in result + + def test_empty(self): + id_map = {} + result = flip_ids(id_map, []) + assert len(result) == 0 + + +class TestInitOldHierarchy: + def test_basic(self, gen_graph): + from pychunkedgraph.graph.edits import _init_old_hierarchy + + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + sv = to_label(graph, 1, 0, 0, 0, 0) + l2_parent = graph.get_parent(sv) + result = _init_old_hierarchy(graph, np.array([l2_parent], dtype=np.uint64)) + assert l2_parent in result + assert 2 in result[l2_parent] diff --git a/pychunkedgraph/tests/test_exceptions.py b/pychunkedgraph/tests/test_exceptions.py new file mode 100644 index 000000000..2c054bfb0 --- /dev/null +++ b/pychunkedgraph/tests/test_exceptions.py @@ -0,0 +1,70 @@ +"""Tests for pychunkedgraph.graph.exceptions""" + +import pytest +from http.client import BAD_REQUEST, UNAUTHORIZED, FORBIDDEN, CONFLICT +from http.client import INTERNAL_SERVER_ERROR, GATEWAY_TIMEOUT + +from pychunkedgraph.graph.exceptions import ( + ChunkedGraphError, + LockingError, + PreconditionError, + PostconditionError, + ChunkedGraphAPIError, + ClientError, + BadRequest, + Unauthorized, + Forbidden, + Conflict, + ServerError, + InternalServerError, + GatewayTimeout, +) + + +class TestExceptionHierarchy: + def test_base_error(self): + with pytest.raises(ChunkedGraphError): + raise ChunkedGraphError("test") + + def test_locking_error_inherits(self): + assert issubclass(LockingError, ChunkedGraphError) + with pytest.raises(ChunkedGraphError): + raise LockingError("locked") + + def test_precondition_error(self): + assert issubclass(PreconditionError, ChunkedGraphError) + + def test_postcondition_error(self): + assert issubclass(PostconditionError, ChunkedGraphError) + + def test_api_error_str(self): + err = ChunkedGraphAPIError("test message") + assert err.message == "test message" + assert err.status_code is None + assert "[None]: test message" == str(err) + + def test_client_error_inherits(self): + assert issubclass(ClientError, ChunkedGraphAPIError) + + def test_bad_request(self): + err = BadRequest("bad") + assert err.status_code == BAD_REQUEST + assert issubclass(BadRequest, ClientError) + + def test_unauthorized(self): + assert Unauthorized.status_code == UNAUTHORIZED + + def test_forbidden(self): + assert Forbidden.status_code == FORBIDDEN + + def test_conflict(self): + assert Conflict.status_code == CONFLICT + + def test_server_error_inherits(self): + assert issubclass(ServerError, ChunkedGraphAPIError) + + def test_internal_server_error(self): + assert InternalServerError.status_code == INTERNAL_SERVER_ERROR + + def test_gateway_timeout(self): + assert GatewayTimeout.status_code == GATEWAY_TIMEOUT diff --git a/pychunkedgraph/tests/test_graph_build.py b/pychunkedgraph/tests/test_graph_build.py new file mode 100644 index 000000000..23ffebe0f --- /dev/null +++ b/pychunkedgraph/tests/test_graph_build.py @@ -0,0 +1,420 @@ +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph import attributes +from ..graph.utils import basetypes +from ..graph.utils.serializers import serialize_uint64 +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestGraphBuild: + @pytest.mark.timeout(30) + def test_build_single_node(self, gen_graph): + """ + Create graph with single RG node 1 in chunk A + ┌─────┐ + │ A¹ │ + │ 1 │ + │ │ + └─────┘ + """ + + cg = gen_graph(n_layers=2) + # Add Chunk A + create_chunk(cg, vertices=[to_label(cg, 1, 0, 0, 0, 0)]) + + res = cg.client._table.read_rows() + res.consume_all() + + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) + + # Check for the one Level 2 node that should have been created. + assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows + atomic_cross_edge_d = cg.get_atomic_cross_edges( + np.array([to_label(cg, 2, 0, 0, 0, 1)], dtype=basetypes.NODE_ID) + ) + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] + children = attr.deserialize(row[attr.key][0].value) + + for aces in atomic_cross_edge_d.values(): + assert len(aces) == 0 + + assert len(children) == 1 and children[0] == to_label(cg, 1, 0, 0, 0, 0) + # Make sure there are not any more entries in the table + # include counters, meta and version rows + assert len(res.rows) == 1 + 1 + 1 + 1 + 1 + + @pytest.mark.timeout(30) + def test_build_single_edge(self, gen_graph): + """ + Create graph with edge between RG supervoxels 1 and 2 (same chunk) + ┌─────┐ + │ A¹ │ + │ 1━2 │ + │ │ + └─────┘ + """ + + cg = gen_graph(n_layers=2) + + # Add Chunk A + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], + ) + + res = cg.client._table.read_rows() + res.consume_all() + + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) + + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 1)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 1)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) + + # Check for the one Level 2 node that should have been created. + assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows + + atomic_cross_edge_d = cg.get_atomic_cross_edges( + np.array([to_label(cg, 2, 0, 0, 0, 1)], dtype=basetypes.NODE_ID) + ) + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] + children = attr.deserialize(row[attr.key][0].value) + + for aces in atomic_cross_edge_d.values(): + assert len(aces) == 0 + assert ( + len(children) == 2 + and to_label(cg, 1, 0, 0, 0, 0) in children + and to_label(cg, 1, 0, 0, 0, 1) in children + ) + + # Make sure there are not any more entries in the table + # include counters, meta and version rows + assert len(res.rows) == 2 + 1 + 1 + 1 + 1 + + @pytest.mark.timeout(30) + def test_build_single_across_edge(self, gen_graph): + """ + Create graph with edge between RG supervoxels 1 and 2 (neighboring chunks) + ┌─────┌─────┐ + │ A¹ │ B¹ │ + │ 1━━┿━━2 │ + │ │ │ + └─────┴─────┘ + """ + + atomic_chunk_bounds = np.array([2, 1, 1]) + cg = gen_graph(n_layers=3, atomic_chunk_bounds=atomic_chunk_bounds) + + # Chunk A + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], + ) + + # Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], + ) + + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + res = cg.client._table.read_rows() + res.consume_all() + + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) + + assert serialize_uint64(to_label(cg, 1, 1, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + assert parent == to_label(cg, 2, 1, 0, 0, 1) + + # Check for the two Level 2 nodes that should have been created. Since Level 2 has the same + # dimensions as Level 1, we also expect them to be in different chunks + # to_label(cg, 2, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows + atomic_cross_edge_d = cg.get_atomic_cross_edges( + np.array([to_label(cg, 2, 0, 0, 0, 1)], dtype=basetypes.NODE_ID) + ) + atomic_cross_edge_d = atomic_cross_edge_d[ + np.uint64(to_label(cg, 2, 0, 0, 0, 1)) + ] + + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] + children = attr.deserialize(row[attr.key][0].value) + + test_ace = np.array( + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], + dtype=np.uint64, + ) + assert len(atomic_cross_edge_d[2]) == 1 + assert test_ace in atomic_cross_edge_d[2] + assert len(children) == 1 and to_label(cg, 1, 0, 0, 0, 0) in children + + # to_label(cg, 2, 1, 0, 0, 1) + assert serialize_uint64(to_label(cg, 2, 1, 0, 0, 1)) in res.rows + atomic_cross_edge_d = cg.get_atomic_cross_edges( + np.array([to_label(cg, 2, 1, 0, 0, 1)], dtype=basetypes.NODE_ID) + ) + atomic_cross_edge_d = atomic_cross_edge_d[ + np.uint64(to_label(cg, 2, 1, 0, 0, 1)) + ] + + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 2, 1, 0, 0, 1))].cells["0"] + children = attr.deserialize(row[attr.key][0].value) + + test_ace = np.array( + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + dtype=np.uint64, + ) + assert len(atomic_cross_edge_d[2]) == 1 + assert test_ace in atomic_cross_edge_d[2] + assert len(children) == 1 and to_label(cg, 1, 1, 0, 0, 0) in children + + # Check for the one Level 3 node that should have been created. This one combines the two + # connected components of Level 2 + # to_label(cg, 3, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 3, 0, 0, 0, 1)) in res.rows + + attr = attributes.Hierarchy.Child + row = res.rows[serialize_uint64(to_label(cg, 3, 0, 0, 0, 1))].cells["0"] + children = attr.deserialize(row[attr.key][0].value) + assert ( + len(children) == 2 + and to_label(cg, 2, 0, 0, 0, 1) in children + and to_label(cg, 2, 1, 0, 0, 1) in children + ) + + # Make sure there are not any more entries in the table + # include counters, meta and version rows + assert len(res.rows) == 2 + 2 + 1 + 3 + 1 + 1 + + @pytest.mark.timeout(30) + def test_build_single_edge_and_single_across_edge(self, gen_graph): + """ + Create graph with edge between RG supervoxels 1 and 2 (same chunk) + and edge between RG supervoxels 1 and 3 (neighboring chunks) + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 2━1━┿━━3 │ + │ │ │ + └─────┴─────┘ + """ + + cg = gen_graph(n_layers=3) + + # Chunk A + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), + ], + ) + + # Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], + ) + + add_parent_chunk(cg, 3, np.array([0, 0, 0]), n_threads=1) + res = cg.client._table.read_rows() + res.consume_all() + + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) + + # to_label(cg, 1, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 1)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 1)) + assert parent == to_label(cg, 2, 0, 0, 0, 1) + + # to_label(cg, 1, 1, 0, 0, 0) + assert serialize_uint64(to_label(cg, 1, 1, 0, 0, 0)) in res.rows + parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + assert parent == to_label(cg, 2, 1, 0, 0, 1) + + # Check for the two Level 2 nodes that should have been created. Since Level 2 has the same + # dimensions as Level 1, we also expect them to be in different chunks + # to_label(cg, 2, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows + row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] + atomic_cross_edge_d = cg.get_atomic_cross_edges([to_label(cg, 2, 0, 0, 0, 1)]) + atomic_cross_edge_d = atomic_cross_edge_d[ + np.uint64(to_label(cg, 2, 0, 0, 0, 1)) + ] + column = attributes.Hierarchy.Child + children = column.deserialize(row[column.key][0].value) + + test_ace = np.array( + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], + dtype=np.uint64, + ) + assert len(atomic_cross_edge_d[2]) == 1 + assert test_ace in atomic_cross_edge_d[2] + assert ( + len(children) == 2 + and to_label(cg, 1, 0, 0, 0, 0) in children + and to_label(cg, 1, 0, 0, 0, 1) in children + ) + + # to_label(cg, 2, 1, 0, 0, 1) + assert serialize_uint64(to_label(cg, 2, 1, 0, 0, 1)) in res.rows + row = res.rows[serialize_uint64(to_label(cg, 2, 1, 0, 0, 1))].cells["0"] + atomic_cross_edge_d = cg.get_atomic_cross_edges([to_label(cg, 2, 1, 0, 0, 1)]) + atomic_cross_edge_d = atomic_cross_edge_d[ + np.uint64(to_label(cg, 2, 1, 0, 0, 1)) + ] + children = column.deserialize(row[column.key][0].value) + + test_ace = np.array( + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + dtype=np.uint64, + ) + assert len(atomic_cross_edge_d[2]) == 1 + assert test_ace in atomic_cross_edge_d[2] + assert len(children) == 1 and to_label(cg, 1, 1, 0, 0, 0) in children + + # Check for the one Level 3 node that should have been created. This one combines the two + # connected components of Level 2 + # to_label(cg, 3, 0, 0, 0, 1) + assert serialize_uint64(to_label(cg, 3, 0, 0, 0, 1)) in res.rows + row = res.rows[serialize_uint64(to_label(cg, 3, 0, 0, 0, 1))].cells["0"] + column = attributes.Hierarchy.Child + children = column.deserialize(row[column.key][0].value) + + assert ( + len(children) == 2 + and to_label(cg, 2, 0, 0, 0, 1) in children + and to_label(cg, 2, 1, 0, 0, 1) in children + ) + + # Make sure there are not any more entries in the table + # include counters, meta and version rows + assert len(res.rows) == 3 + 2 + 1 + 3 + 1 + 1 + + @pytest.mark.timeout(120) + def test_build_big_graph(self, gen_graph): + """ + Create graph with RG nodes 1 and 2 in opposite corners of the largest possible dataset + ┌─────┐ ┌─────┐ + │ A¹ │ ... │ Z¹ │ + │ 1 │ │ 2 │ + │ │ │ │ + └─────┘ └─────┘ + """ + + atomic_chunk_bounds = np.array([8, 8, 8]) + cg = gen_graph(n_layers=5, atomic_chunk_bounds=atomic_chunk_bounds) + + # Preparation: Build Chunk A + create_chunk(cg, vertices=[to_label(cg, 1, 0, 0, 0, 0)], edges=[]) + + # Preparation: Build Chunk Z + create_chunk(cg, vertices=[to_label(cg, 1, 7, 7, 7, 0)], edges=[]) + + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(cg, 3, [3, 3, 3], n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], n_threads=1) + add_parent_chunk(cg, 5, [0, 0, 0], n_threads=1) + + res = cg.client._table.read_rows() + res.consume_all() + + assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows + assert serialize_uint64(to_label(cg, 1, 7, 7, 7, 0)) in res.rows + assert serialize_uint64(to_label(cg, 5, 0, 0, 0, 1)) in res.rows + assert serialize_uint64(to_label(cg, 5, 0, 0, 0, 2)) in res.rows + + @pytest.mark.timeout(30) + def test_double_chunk_creation(self, gen_graph): + """ + No connection between 1, 2 and 3 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 3 │ + │ 2 │ │ + └─────┴─────┘ + """ + + atomic_chunk_bounds = np.array([4, 4, 4]) + cg = gen_graph(n_layers=4, atomic_chunk_bounds=atomic_chunk_bounds) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + add_parent_chunk( + cg, + 4, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + assert len(cg.range_read_chunk(cg.get_chunk_id(layer=2, x=0, y=0, z=0))) == 2 + assert len(cg.range_read_chunk(cg.get_chunk_id(layer=2, x=1, y=0, z=0))) == 1 + assert len(cg.range_read_chunk(cg.get_chunk_id(layer=3, x=0, y=0, z=0))) == 0 + assert len(cg.range_read_chunk(cg.get_chunk_id(layer=4, x=0, y=0, z=0))) == 6 + + assert cg.get_chunk_layer(cg.get_root(to_label(cg, 1, 0, 0, 0, 1))) == 4 + assert cg.get_chunk_layer(cg.get_root(to_label(cg, 1, 0, 0, 0, 2))) == 4 + assert cg.get_chunk_layer(cg.get_root(to_label(cg, 1, 1, 0, 0, 1))) == 4 + + root_seg_ids = [ + cg.get_segment_id(cg.get_root(to_label(cg, 1, 0, 0, 0, 1))), + cg.get_segment_id(cg.get_root(to_label(cg, 1, 0, 0, 0, 2))), + cg.get_segment_id(cg.get_root(to_label(cg, 1, 1, 0, 0, 1))), + ] + + assert 4 in root_seg_ids + assert 5 in root_seg_ids + assert 6 in root_seg_ids diff --git a/pychunkedgraph/tests/test_graph_queries.py b/pychunkedgraph/tests/test_graph_queries.py new file mode 100644 index 000000000..9845b121e --- /dev/null +++ b/pychunkedgraph/tests/test_graph_queries.py @@ -0,0 +1,222 @@ +from math import inf + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label + + +class TestGraphSimpleQueries: + """ + ┌─────┬─────┬─────┐ L X Y Z S L X Y Z S L X Y Z S L X Y Z S + │ A¹ │ B¹ │ C¹ │ 1: 1 0 0 0 0 ─── 2 0 0 0 1 ───────────────── 4 0 0 0 1 + │ 1 │ 3━2━┿━━4 │ 2: 1 1 0 0 0 ─┬─ 2 1 0 0 1 ─── 3 0 0 0 1 ─┬─ 4 0 0 0 2 + │ │ │ │ 3: 1 1 0 0 1 ─┘ │ + └─────┴─────┴─────┘ 4: 1 2 0 0 0 ─── 2 2 0 0 1 ─── 3 1 0 0 1 ─┘ + """ + + @pytest.mark.timeout(30) + def test_get_parent_and_children(self, gen_graph_simplequerytest): + cg = gen_graph_simplequerytest + + children10000 = cg.get_children(to_label(cg, 1, 0, 0, 0, 0)) + children11000 = cg.get_children(to_label(cg, 1, 1, 0, 0, 0)) + children11001 = cg.get_children(to_label(cg, 1, 1, 0, 0, 1)) + children12000 = cg.get_children(to_label(cg, 1, 2, 0, 0, 0)) + + parent10000 = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + parent11000 = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + parent11001 = cg.get_parent(to_label(cg, 1, 1, 0, 0, 1)) + parent12000 = cg.get_parent(to_label(cg, 1, 2, 0, 0, 0)) + + children20001 = cg.get_children(to_label(cg, 2, 0, 0, 0, 1)) + children21001 = cg.get_children(to_label(cg, 2, 1, 0, 0, 1)) + children22001 = cg.get_children(to_label(cg, 2, 2, 0, 0, 1)) + + parent20001 = cg.get_parent(to_label(cg, 2, 0, 0, 0, 1)) + parent21001 = cg.get_parent(to_label(cg, 2, 1, 0, 0, 1)) + parent22001 = cg.get_parent(to_label(cg, 2, 2, 0, 0, 1)) + + children30001 = cg.get_children(to_label(cg, 3, 0, 0, 0, 1)) + children31001 = cg.get_children(to_label(cg, 3, 1, 0, 0, 1)) + + parent30001 = cg.get_parent(to_label(cg, 3, 0, 0, 0, 1)) + parent31001 = cg.get_parent(to_label(cg, 3, 1, 0, 0, 1)) + + children40001 = cg.get_children(to_label(cg, 4, 0, 0, 0, 1)) + children40002 = cg.get_children(to_label(cg, 4, 0, 0, 0, 2)) + + parent40001 = cg.get_parent(to_label(cg, 4, 0, 0, 0, 1)) + parent40002 = cg.get_parent(to_label(cg, 4, 0, 0, 0, 2)) + + # (non-existing) Children of L1 + assert np.array_equal(children10000, []) is True + assert np.array_equal(children11000, []) is True + assert np.array_equal(children11001, []) is True + assert np.array_equal(children12000, []) is True + + # Parent of L1 + assert parent10000 == to_label(cg, 2, 0, 0, 0, 1) + assert parent11000 == to_label(cg, 2, 1, 0, 0, 1) + assert parent11001 == to_label(cg, 2, 1, 0, 0, 1) + assert parent12000 == to_label(cg, 2, 2, 0, 0, 1) + + # Children of L2 + assert len(children20001) == 1 and to_label(cg, 1, 0, 0, 0, 0) in children20001 + assert ( + len(children21001) == 2 + and to_label(cg, 1, 1, 0, 0, 0) in children21001 + and to_label(cg, 1, 1, 0, 0, 1) in children21001 + ) + assert len(children22001) == 1 and to_label(cg, 1, 2, 0, 0, 0) in children22001 + + # Parent of L2 + assert parent20001 == to_label(cg, 4, 0, 0, 0, 1) + assert parent21001 == to_label(cg, 3, 0, 0, 0, 1) + assert parent22001 == to_label(cg, 3, 1, 0, 0, 1) + + # Children of L3 + assert len(children30001) == 1 and len(children31001) == 1 + assert to_label(cg, 2, 1, 0, 0, 1) in children30001 + assert to_label(cg, 2, 2, 0, 0, 1) in children31001 + + # Parent of L3 + assert parent30001 == parent31001 + assert ( + parent30001 == to_label(cg, 4, 0, 0, 0, 1) + and parent20001 == to_label(cg, 4, 0, 0, 0, 2) + ) or ( + parent30001 == to_label(cg, 4, 0, 0, 0, 2) + and parent20001 == to_label(cg, 4, 0, 0, 0, 1) + ) + + # Children of L4 + assert parent10000 in children40001 + assert parent21001 in children40002 and parent22001 in children40002 + + # (non-existing) Parent of L4 + assert parent40001 is None + assert parent40002 is None + + children2_separate = cg.get_children( + [ + to_label(cg, 2, 0, 0, 0, 1), + to_label(cg, 2, 1, 0, 0, 1), + to_label(cg, 2, 2, 0, 0, 1), + ] + ) + assert len(children2_separate) == 3 + assert to_label(cg, 2, 0, 0, 0, 1) in children2_separate and np.all( + np.isin(children2_separate[to_label(cg, 2, 0, 0, 0, 1)], children20001) + ) + assert to_label(cg, 2, 1, 0, 0, 1) in children2_separate and np.all( + np.isin(children2_separate[to_label(cg, 2, 1, 0, 0, 1)], children21001) + ) + assert to_label(cg, 2, 2, 0, 0, 1) in children2_separate and np.all( + np.isin(children2_separate[to_label(cg, 2, 2, 0, 0, 1)], children22001) + ) + + children2_combined = cg.get_children( + [ + to_label(cg, 2, 0, 0, 0, 1), + to_label(cg, 2, 1, 0, 0, 1), + to_label(cg, 2, 2, 0, 0, 1), + ], + flatten=True, + ) + assert ( + len(children2_combined) == 4 + and np.all(np.isin(children20001, children2_combined)) + and np.all(np.isin(children21001, children2_combined)) + and np.all(np.isin(children22001, children2_combined)) + ) + + @pytest.mark.timeout(30) + def test_get_root(self, gen_graph_simplequerytest): + cg = gen_graph_simplequerytest + root10000 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + root11000 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) + root11001 = cg.get_root(to_label(cg, 1, 1, 0, 0, 1)) + root12000 = cg.get_root(to_label(cg, 1, 2, 0, 0, 0)) + + with pytest.raises(Exception): + cg.get_root(0) + + assert ( + root10000 == to_label(cg, 4, 0, 0, 0, 1) + and root11000 == root11001 == root12000 == to_label(cg, 4, 0, 0, 0, 2) + ) or ( + root10000 == to_label(cg, 4, 0, 0, 0, 2) + and root11000 == root11001 == root12000 == to_label(cg, 4, 0, 0, 0, 1) + ) + + @pytest.mark.timeout(30) + def test_get_subgraph_nodes(self, gen_graph_simplequerytest): + cg = gen_graph_simplequerytest + root1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + root2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) + + lvl1_nodes_1 = cg.get_subgraph([root1], leaves_only=True) + lvl1_nodes_2 = cg.get_subgraph([root2], leaves_only=True) + assert len(lvl1_nodes_1) == 1 + assert len(lvl1_nodes_2) == 3 + assert to_label(cg, 1, 0, 0, 0, 0) in lvl1_nodes_1 + assert to_label(cg, 1, 1, 0, 0, 0) in lvl1_nodes_2 + assert to_label(cg, 1, 1, 0, 0, 1) in lvl1_nodes_2 + assert to_label(cg, 1, 2, 0, 0, 0) in lvl1_nodes_2 + + lvl2_parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + lvl1_nodes = cg.get_subgraph([lvl2_parent], leaves_only=True) + assert len(lvl1_nodes) == 2 + assert to_label(cg, 1, 1, 0, 0, 0) in lvl1_nodes + assert to_label(cg, 1, 1, 0, 0, 1) in lvl1_nodes + + @pytest.mark.timeout(30) + def test_get_subgraph_edges(self, gen_graph_simplequerytest): + cg = gen_graph_simplequerytest + root1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + root2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) + + edges = cg.get_subgraph([root1], edges_only=True) + assert len(edges) == 0 + + edges = cg.get_subgraph([root2], edges_only=True) + assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)] in edges or [ + to_label(cg, 1, 1, 0, 0, 1), + to_label(cg, 1, 1, 0, 0, 0), + ] in edges + + assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0)] in edges or [ + to_label(cg, 1, 2, 0, 0, 0), + to_label(cg, 1, 1, 0, 0, 0), + ] in edges + + lvl2_parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + edges = cg.get_subgraph([lvl2_parent], edges_only=True) + assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)] in edges or [ + to_label(cg, 1, 1, 0, 0, 1), + to_label(cg, 1, 1, 0, 0, 0), + ] in edges + + assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0)] in edges or [ + to_label(cg, 1, 2, 0, 0, 0), + to_label(cg, 1, 1, 0, 0, 0), + ] in edges + + assert len(edges) == 1 + + @pytest.mark.timeout(30) + def test_get_subgraph_nodes_bb(self, gen_graph_simplequerytest): + cg = gen_graph_simplequerytest + bb = np.array([[1, 0, 0], [2, 1, 1]], dtype=int) + bb_coord = bb * cg.meta.graph_config.CHUNK_SIZE + childs_1 = cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 1, 0, 0, 1))], bbox=bb, leaves_only=True + ) + childs_2 = cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 1, 0, 0, 1))], + bbox=bb_coord, + bbox_is_coordinate=True, + leaves_only=True, + ) + assert np.all(~(np.sort(childs_1) - np.sort(childs_2))) diff --git a/pychunkedgraph/tests/test_history.py b/pychunkedgraph/tests/test_history.py new file mode 100644 index 000000000..0f0e2fa16 --- /dev/null +++ b/pychunkedgraph/tests/test_history.py @@ -0,0 +1,135 @@ +from datetime import datetime, timedelta, UTC + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph import ChunkedGraph +from ..graph.lineage import lineage_graph, get_root_id_history +from ..graph.misc import get_delta_roots +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestGraphHistory: + """These test inadvertantly also test merge and split operations""" + + @pytest.mark.timeout(120) + def test_cut_merge_history(self, gen_graph): + cg: ChunkedGraph = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + first_root = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + assert first_root == cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) + timestamp_before_split = datetime.now(UTC) + split_roots = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ).new_root_ids + assert len(split_roots) == 2 + g = lineage_graph(cg, split_roots[0]) + assert g.size() == 1 + g = lineage_graph(cg, split_roots) + assert g.size() == 2 + + timestamp_after_split = datetime.now(UTC) + merge_roots = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], + affinities=0.4, + ).new_root_ids + assert len(merge_roots) == 1 + merge_root = merge_roots[0] + timestamp_after_merge = datetime.now(UTC) + + g = lineage_graph(cg, merge_roots) + assert g.size() == 4 + assert ( + len( + get_root_id_history( + cg, + first_root, + time_stamp_past=datetime.min, + time_stamp_future=datetime.max, + ) + ) + == 4 + ) + assert ( + len( + get_root_id_history( + cg, + split_roots[0], + time_stamp_past=datetime.min, + time_stamp_future=datetime.max, + ) + ) + == 3 + ) + assert ( + len( + get_root_id_history( + cg, + split_roots[1], + time_stamp_past=datetime.min, + time_stamp_future=datetime.max, + ) + ) + == 3 + ) + assert ( + len( + get_root_id_history( + cg, + merge_root, + time_stamp_past=datetime.min, + time_stamp_future=datetime.max, + ) + ) + == 4 + ) + + new_roots, old_roots = get_delta_roots( + cg, timestamp_before_split, timestamp_after_split + ) + assert len(old_roots) == 1 + assert old_roots[0] == first_root + assert len(new_roots) == 2 + assert np.all(np.isin(new_roots, split_roots)) + + new_roots2, old_roots2 = get_delta_roots( + cg, timestamp_after_split, timestamp_after_merge + ) + assert len(new_roots2) == 1 + assert new_roots2[0] == merge_root + assert len(old_roots2) == 2 + assert np.all(np.isin(old_roots2, split_roots)) + + new_roots3, old_roots3 = get_delta_roots( + cg, timestamp_before_split, timestamp_after_merge + ) + assert len(new_roots3) == 1 + assert new_roots3[0] == merge_root + assert len(old_roots3) == 1 + assert old_roots3[0] == first_root diff --git a/pychunkedgraph/tests/test_ingest_atomic_layer.py b/pychunkedgraph/tests/test_ingest_atomic_layer.py new file mode 100644 index 000000000..c55318c8f --- /dev/null +++ b/pychunkedgraph/tests/test_ingest_atomic_layer.py @@ -0,0 +1,66 @@ +"""Tests for pychunkedgraph.ingest.create.atomic_layer""" + +from datetime import datetime, timedelta, UTC + +import numpy as np +import pytest + +from pychunkedgraph.ingest.create.atomic_layer import ( + _get_chunk_nodes_and_edges, + _get_remapping, +) +from pychunkedgraph.graph.edges import Edges, EDGE_TYPES +from pychunkedgraph.graph.utils import basetypes + + +class TestGetChunkNodesAndEdges: + def test_basic(self): + chunk_edges_d = { + EDGE_TYPES.in_chunk: Edges( + np.array([1, 2], dtype=basetypes.NODE_ID), + np.array([3, 4], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.between_chunk: Edges( + np.array([1], dtype=basetypes.NODE_ID), + np.array([5], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.cross_chunk: Edges([], []), + } + isolated = np.array([10], dtype=np.uint64) + node_ids, edge_ids = _get_chunk_nodes_and_edges(chunk_edges_d, isolated) + assert 10 in node_ids + assert 1 in node_ids + assert 3 in node_ids + assert len(edge_ids) > 0 + + def test_isolated_only(self): + chunk_edges_d = { + EDGE_TYPES.in_chunk: Edges([], []), + EDGE_TYPES.between_chunk: Edges([], []), + EDGE_TYPES.cross_chunk: Edges([], []), + } + isolated = np.array([10, 20], dtype=np.uint64) + node_ids, edge_ids = _get_chunk_nodes_and_edges(chunk_edges_d, isolated) + assert 10 in node_ids + assert 20 in node_ids + + +class TestGetRemapping: + def test_basic(self): + chunk_edges_d = { + EDGE_TYPES.in_chunk: Edges( + np.array([1, 2], dtype=basetypes.NODE_ID), + np.array([3, 4], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.between_chunk: Edges( + np.array([1], dtype=basetypes.NODE_ID), + np.array([5], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.cross_chunk: Edges( + np.array([2], dtype=basetypes.NODE_ID), + np.array([6], dtype=basetypes.NODE_ID), + ), + } + sparse_indices, remapping = _get_remapping(chunk_edges_d) + assert EDGE_TYPES.between_chunk in remapping + assert EDGE_TYPES.cross_chunk in remapping diff --git a/pychunkedgraph/tests/test_ingest_config.py b/pychunkedgraph/tests/test_ingest_config.py new file mode 100644 index 000000000..f068f5da1 --- /dev/null +++ b/pychunkedgraph/tests/test_ingest_config.py @@ -0,0 +1,27 @@ +"""Tests for pychunkedgraph.ingest IngestConfig""" + +from pychunkedgraph.ingest import IngestConfig + + +class TestIngestConfig: + def test_defaults(self): + config = IngestConfig() + assert config.AGGLOMERATION is None + assert config.WATERSHED is None + assert config.USE_RAW_EDGES is False + assert config.USE_RAW_COMPONENTS is False + assert config.TEST_RUN is False + + def test_custom_values(self): + config = IngestConfig( + AGGLOMERATION="gs://bucket/agg", + WATERSHED="gs://bucket/ws", + USE_RAW_EDGES=True, + USE_RAW_COMPONENTS=True, + TEST_RUN=True, + ) + assert config.AGGLOMERATION == "gs://bucket/agg" + assert config.WATERSHED == "gs://bucket/ws" + assert config.USE_RAW_EDGES is True + assert config.USE_RAW_COMPONENTS is True + assert config.TEST_RUN is True diff --git a/pychunkedgraph/tests/test_ingest_cross_edges.py b/pychunkedgraph/tests/test_ingest_cross_edges.py new file mode 100644 index 000000000..1084fb4a9 --- /dev/null +++ b/pychunkedgraph/tests/test_ingest_cross_edges.py @@ -0,0 +1,368 @@ +"""Tests for pychunkedgraph.ingest.create.cross_edges""" + +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from pychunkedgraph.ingest.create.cross_edges import ( + _find_min_layer, + get_children_chunk_cross_edges, + get_chunk_nodes_cross_edge_layer, + _get_chunk_nodes_cross_edge_layer_helper, +) +from pychunkedgraph.graph.utils import basetypes + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestFindMinLayer: + """Pure unit tests for _find_min_layer helper.""" + + def test_single_batch(self): + """One array of node_ids and layers results in correct min layers.""" + node_layer_d = {} + node_ids_shared = [np.array([10, 20, 30], dtype=basetypes.NODE_ID)] + node_layers_shared = [np.array([3, 5, 4], dtype=np.uint8)] + + _find_min_layer(node_layer_d, node_ids_shared, node_layers_shared) + + assert node_layer_d[10] == 3 + assert node_layer_d[20] == 5 + assert node_layer_d[30] == 4 + assert len(node_layer_d) == 3 + + def test_multiple_batches_min_wins(self): + """Two batches with the same node_id but different layers; smallest layer wins.""" + node_layer_d = {} + node_ids_shared = [ + np.array([10, 20], dtype=basetypes.NODE_ID), + np.array([20, 30], dtype=basetypes.NODE_ID), + ] + node_layers_shared = [ + np.array([5, 7], dtype=np.uint8), + np.array([3, 4], dtype=np.uint8), + ] + + _find_min_layer(node_layer_d, node_ids_shared, node_layers_shared) + + assert node_layer_d[10] == 5 + # node 20 appears in both batches with layers 7 and 3; min is 3 + assert node_layer_d[20] == 3 + assert node_layer_d[30] == 4 + + def test_empty_batches(self): + """Empty arrays produce an empty dict.""" + node_layer_d = {} + node_ids_shared = [np.array([], dtype=basetypes.NODE_ID)] + node_layers_shared = [np.array([], dtype=np.uint8)] + + _find_min_layer(node_layer_d, node_ids_shared, node_layers_shared) + + assert len(node_layer_d) == 0 + + +class TestGetChildrenChunkCrossEdges: + """Integration tests for get_children_chunk_cross_edges using gen_graph.""" + + def test_no_cross_edges(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + + result = get_children_chunk_cross_edges(graph, 3, [0, 0, 0], use_threads=False) + # Should return empty or no cross edges + assert len(result) == 0 or result.size == 0 + + def test_with_cross_edges(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + + result = get_children_chunk_cross_edges(graph, 3, [0, 0, 0], use_threads=False) + assert len(result) > 0 + + @pytest.mark.timeout(30) + def test_no_atomic_chunks_returns_empty(self, gen_graph): + """When the chunk coordinate is out of bounds, get_touching_atomic_chunks + returns empty and the function returns early with an empty list.""" + cg = gen_graph(n_layers=3, atomic_chunk_bounds=np.array([1, 1, 1])) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + + # chunk_coord [1,0,0] is out of bounds for atomic_chunk_bounds=[1,1,1] + # so get_touching_atomic_chunks returns empty, triggering early return + result = get_children_chunk_cross_edges( + cg, layer=3, chunk_coord=[1, 0, 0], use_threads=False + ) + assert len(result) == 0 + + @pytest.mark.timeout(30) + def test_basic_cross_edges(self, gen_graph): + """A 4-layer graph with cross-chunk connected SVs returns cross edges + when called with use_threads=False.""" + cg = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + # Chunk A (0,0,0): sv 0 connected cross-chunk to chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + + # Chunk B (1,0,0): sv 0 connected cross-chunk to chunk A + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[ + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + + # Build parent layer so L3 nodes exist + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + + # Layer 3, chunk [0,0,0] should have cross edges connecting children chunks + result = get_children_chunk_cross_edges( + cg, layer=3, chunk_coord=[0, 0, 0], use_threads=False + ) + result = np.array(result) + assert result.size > 0 + assert result.ndim == 2 + assert result.shape[1] == 2 + + +class TestGetChildrenChunkCrossEdgesAdditional: + """Additional tests for get_children_chunk_cross_edges (serial path).""" + + @pytest.mark.timeout(30) + def test_multiple_cross_edges(self, gen_graph): + """Multiple SVs with cross-chunk edges should all be found.""" + cg = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + # Chunk A: two SVs, each cross-chunk connected + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 1), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf), + (to_label(cg, 1, 1, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 1), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + + result = get_children_chunk_cross_edges( + cg, layer=3, chunk_coord=[0, 0, 0], use_threads=False + ) + result = np.array(result) + assert result.size > 0 + assert result.ndim == 2 + + @pytest.mark.timeout(30) + def test_cross_edges_layer4(self, gen_graph): + """Cross edges that span L3 chunk boundaries should appear at layer 4. + The SVs must be on the touching face between L3 children: + L4 [0,0,0] has L3 children [0,0,0] (x=0,1) and [1,0,0] (x=2,3). + Touching face is at L2 x=1 and x=2.""" + cg = gen_graph(n_layers=5) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + # SV at L1 [1,0,0] - on the right boundary of L3 [0,0,0] + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[ + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + # SV at L1 [2,0,0] - on the left boundary of L3 [1,0,0] + create_chunk( + cg, + vertices=[to_label(cg, 1, 2, 0, 0, 0)], + edges=[ + (to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(cg, 3, [1, 0, 0], n_threads=1) + + # At layer 4, chunk [0,0,0] should find cross edges at the L3 boundary + result = get_children_chunk_cross_edges( + cg, layer=4, chunk_coord=[0, 0, 0], use_threads=False + ) + result = np.array(result) + assert result.size > 0 + assert result.ndim == 2 + + +class TestGetChunkNodesCrossEdgeLayer: + """Tests for get_chunk_nodes_cross_edge_layer (lines 112-147).""" + + @pytest.mark.timeout(60) + def test_no_threads_with_cross_edges(self, gen_graph): + """use_threads=False should return dict mapping node_id to layer. + Cross edge between [0,0,0] and [2,0,0] has layer 3. + get_bounding_atomic_chunks(meta, 3, [0,0,0]) returns L2 boundary + chunks of L3 [0,0,0], which includes L2 at x=0 with AtomicCrossChunkEdge[3]. + """ + cg = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + # SV at L1 [0,0,0] with cross edge to [2,0,0] (layer-3 cross edge) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + # SV at L1 [2,0,0] + create_chunk( + cg, + vertices=[to_label(cg, 1, 2, 0, 0, 0)], + edges=[ + (to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(cg, 3, [1, 0, 0], n_threads=1) + + result = get_chunk_nodes_cross_edge_layer( + cg, layer=3, chunk_coord=[0, 0, 0], use_threads=False + ) + assert isinstance(result, dict) + assert len(result) > 0 + for node_id, layer in result.items(): + assert layer >= 3 + + @pytest.mark.timeout(60) + def test_no_threads_empty_chunk(self, gen_graph): + """use_threads=False with out-of-bounds chunk should return empty dict.""" + cg = gen_graph(n_layers=3, atomic_chunk_bounds=np.array([1, 1, 1])) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + + # Out of bounds chunk coord + result = get_chunk_nodes_cross_edge_layer( + cg, layer=3, chunk_coord=[1, 0, 0], use_threads=False + ) + assert isinstance(result, dict) + assert len(result) == 0 + + @pytest.mark.timeout(60) + def test_no_cross_edges_returns_empty(self, gen_graph): + """When chunks have no cross edges at the relevant layers, result is empty.""" + cg = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + + result = get_chunk_nodes_cross_edge_layer( + cg, layer=3, chunk_coord=[0, 0, 0], use_threads=False + ) + assert isinstance(result, dict) + assert len(result) == 0 + + +class TestFindMinLayerExtended: + """Additional tests for _find_min_layer with edge cases.""" + + def test_single_node_multiple_batches(self): + """Same node_id across multiple batches; lowest layer wins.""" + node_layer_d = {} + node_ids_shared = [ + np.array([100], dtype=basetypes.NODE_ID), + np.array([100], dtype=basetypes.NODE_ID), + np.array([100], dtype=basetypes.NODE_ID), + ] + node_layers_shared = [ + np.array([8], dtype=np.uint8), + np.array([3], dtype=np.uint8), + np.array([5], dtype=np.uint8), + ] + + _find_min_layer(node_layer_d, node_ids_shared, node_layers_shared) + assert node_layer_d[100] == 3 + + def test_no_overlap(self): + """All unique node_ids across batches should just pass through.""" + node_layer_d = {} + node_ids_shared = [ + np.array([1, 2], dtype=basetypes.NODE_ID), + np.array([3, 4], dtype=basetypes.NODE_ID), + ] + node_layers_shared = [ + np.array([5, 6], dtype=np.uint8), + np.array([7, 8], dtype=np.uint8), + ] + + _find_min_layer(node_layer_d, node_ids_shared, node_layers_shared) + assert node_layer_d[1] == 5 + assert node_layer_d[2] == 6 + assert node_layer_d[3] == 7 + assert node_layer_d[4] == 8 diff --git a/pychunkedgraph/tests/test_ingest_manager.py b/pychunkedgraph/tests/test_ingest_manager.py new file mode 100644 index 000000000..1c2032081 --- /dev/null +++ b/pychunkedgraph/tests/test_ingest_manager.py @@ -0,0 +1,131 @@ +"""Tests for pychunkedgraph.ingest.manager""" + +import pickle +import pytest +from unittest.mock import MagicMock, patch + +from pychunkedgraph.ingest import IngestConfig +from pychunkedgraph.graph.meta import ChunkedGraphMeta, GraphConfig, DataSource + + +def _make_config_and_meta(): + config = IngestConfig() + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64], ID="test") + ds = DataSource( + EDGES="gs://test/edges", + COMPONENTS="gs://test/comp", + WATERSHED="gs://test/ws", + DATA_VERSION=2, + ) + meta = ChunkedGraphMeta(gc, ds) + meta._layer_count = 4 + meta._bitmasks = {1: 10, 2: 10, 3: 1, 4: 1} + return config, meta + + +def _make_manager(): + """Create an IngestionManager with mocked redis connection.""" + config, meta = _make_config_and_meta() + with patch("pychunkedgraph.ingest.manager.get_redis_connection") as mock_redis_conn: + mock_redis = MagicMock() + mock_redis_conn.return_value = mock_redis + from pychunkedgraph.ingest.manager import IngestionManager + + manager = IngestionManager(config=config, chunkedgraph_meta=meta) + return manager, config, meta, mock_redis + + +class TestIngestionManagerSerialization: + def test_serialized_dict(self): + config, meta = _make_config_and_meta() + # Test the serialized dict path without needing Redis + params = {"config": config, "chunkedgraph_meta": meta} + assert "config" in params + assert "chunkedgraph_meta" in params + assert params["config"] == config + + def test_serialized_pickle_roundtrip(self): + config, meta = _make_config_and_meta() + params = {"config": config, "chunkedgraph_meta": meta} + serialized = pickle.dumps(params) + restored = pickle.loads(serialized) + assert restored["config"] == config + assert restored["chunkedgraph_meta"].graph_config.ID == "test" + + +class TestSerializedDict: + def test_serialized_returns_dict_with_correct_keys(self): + """serialized() returns a dict with config and chunkedgraph_meta keys.""" + manager, config, meta, _ = _make_manager() + result = manager.serialized() + assert isinstance(result, dict) + assert "config" in result + assert "chunkedgraph_meta" in result + assert result["config"] is config + assert result["chunkedgraph_meta"] is meta + + +class TestSerializedPickleRoundtrip: + def test_serialized_pickled_roundtrips(self): + """serialized(pickled=True) produces bytes that pickle-load back correctly.""" + manager, config, meta, _ = _make_manager() + pickled = manager.serialized(pickled=True) + assert isinstance(pickled, bytes) + loaded = pickle.loads(pickled) + assert isinstance(loaded, dict) + assert loaded["config"] == config + assert isinstance(loaded["chunkedgraph_meta"], ChunkedGraphMeta) + assert loaded["chunkedgraph_meta"].graph_config == meta.graph_config + assert loaded["chunkedgraph_meta"].data_source == meta.data_source + + +class TestConfigProperty: + def test_config_property_returns_injected_config(self): + """config property returns the IngestConfig passed to __init__.""" + manager, config, _, _ = _make_manager() + assert manager.config is config + + +class TestCgMetaProperty: + def test_cg_meta_property_returns_injected_meta(self): + """cg_meta property returns the ChunkedGraphMeta passed to __init__.""" + manager, _, meta, _ = _make_manager() + assert manager.cg_meta is meta + + +class TestGetTaskQueueCaching: + def test_get_task_queue_returns_cached_on_second_call(self): + """Calling get_task_queue twice with the same name returns the same cached object.""" + manager, _, _, _ = _make_manager() + with patch("pychunkedgraph.ingest.manager.get_rq_queue") as mock_get_rq: + mock_queue = MagicMock() + mock_get_rq.return_value = mock_queue + + q1 = manager.get_task_queue("test_queue") + q2 = manager.get_task_queue("test_queue") + + assert q1 is q2 + mock_get_rq.assert_called_once_with("test_queue") + + +class TestRedisPropertyCaching: + def test_redis_returns_cached_connection(self): + """redis property returns cached value on second access; get_redis_connection not called again.""" + config, meta = _make_config_and_meta() + with patch( + "pychunkedgraph.ingest.manager.get_redis_connection" + ) as mock_redis_conn: + mock_redis = MagicMock() + mock_redis_conn.return_value = mock_redis + from pychunkedgraph.ingest.manager import IngestionManager + + manager = IngestionManager(config=config, chunkedgraph_meta=meta) + call_count_after_init = mock_redis_conn.call_count + + r1 = manager.redis + r2 = manager.redis + + # No additional calls to get_redis_connection after init + assert mock_redis_conn.call_count == call_count_after_init + assert r1 is r2 + assert r1 is mock_redis diff --git a/pychunkedgraph/tests/test_ingest_parent_layer.py b/pychunkedgraph/tests/test_ingest_parent_layer.py new file mode 100644 index 000000000..2e46a5e67 --- /dev/null +++ b/pychunkedgraph/tests/test_ingest_parent_layer.py @@ -0,0 +1,63 @@ +"""Tests for pychunkedgraph.ingest.create.parent_layer""" + +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestAddParentChunk: + def test_single_thread(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=fake_ts, + ) + + # Should not raise + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + + # Verify parent was created + sv = to_label(graph, 1, 0, 0, 0, 0) + parent = graph.get_parent(sv) + assert parent is not None + assert graph.get_chunk_layer(parent) == 2 + + def test_multi_chunk(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + # Both SVs should share a root + root0 = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + root1 = graph.get_root(to_label(graph, 1, 1, 0, 0, 0)) + assert root0 == root1 diff --git a/pychunkedgraph/tests/test_ingest_ran_agglomeration.py b/pychunkedgraph/tests/test_ingest_ran_agglomeration.py new file mode 100644 index 000000000..9d02fd306 --- /dev/null +++ b/pychunkedgraph/tests/test_ingest_ran_agglomeration.py @@ -0,0 +1,1100 @@ +"""Tests for pychunkedgraph.ingest.ran_agglomeration - selected unit tests""" + +from binascii import crc32 + +import numpy as np +import pytest + +from pychunkedgraph.ingest.ran_agglomeration import ( + _crc_check, + _get_cont_chunk_coords, + define_active_edges, + get_active_edges, +) +from pychunkedgraph.graph.edges import Edges, EDGE_TYPES +from pychunkedgraph.graph.utils import basetypes + + +class TestCrcCheck: + def test_valid(self): + payload = b"test data here" + crc = np.array([crc32(payload)], dtype=np.uint32).tobytes() + full = payload + crc + _crc_check(full) # should not raise + + def test_invalid(self): + payload = b"test data here" + bad_crc = np.array([12345], dtype=np.uint32).tobytes() + full = payload + bad_crc + with pytest.raises(AssertionError): + _crc_check(full) + + +class TestDefineActiveEdges: + def test_basic(self): + edges = { + EDGE_TYPES.in_chunk: Edges( + np.array([1, 2], dtype=basetypes.NODE_ID), + np.array([3, 4], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.between_chunk: Edges([], []), + EDGE_TYPES.cross_chunk: Edges([], []), + } + # Both sv1 and sv2 map to same agg ID -> active + mapping = {1: 0, 2: 0, 3: 0, 4: 0} + active, isolated = define_active_edges(edges, mapping) + assert np.all(active[EDGE_TYPES.in_chunk]) + + def test_unmapped_edges(self): + edges = { + EDGE_TYPES.in_chunk: Edges( + np.array([1], dtype=basetypes.NODE_ID), + np.array([2], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.between_chunk: Edges([], []), + EDGE_TYPES.cross_chunk: Edges([], []), + } + # sv1 not in mapping -> isolated + mapping = {2: 0} + active, isolated = define_active_edges(edges, mapping) + assert not active[EDGE_TYPES.in_chunk][0] + assert 1 in isolated + + +class TestGetActiveEdges: + def test_basic(self): + edges = { + EDGE_TYPES.in_chunk: Edges( + np.array([1, 2], dtype=basetypes.NODE_ID), + np.array([3, 4], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.between_chunk: Edges([], []), + EDGE_TYPES.cross_chunk: Edges([], []), + } + mapping = {1: 0, 2: 0, 3: 0, 4: 0} + chunk_edges, pseudo_isolated = get_active_edges(edges, mapping) + for et in EDGE_TYPES: + assert et in chunk_edges + assert len(pseudo_isolated) > 0 + + +class TestGetContChunkCoords: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + + class FakeIM: + cg_meta = graph.meta + + coord_a = np.array([1, 0, 0]) + coord_b = np.array([0, 0, 0]) + result = _get_cont_chunk_coords(FakeIM(), coord_a, coord_b) + assert isinstance(result, list) + + def test_returns_only_valid_coords(self, gen_graph): + """All returned coords should not be out of bounds.""" + graph = gen_graph(n_layers=4) + + class FakeIM: + cg_meta = graph.meta + + coord_a = np.array([1, 0, 0]) + coord_b = np.array([0, 0, 0]) + result = _get_cont_chunk_coords(FakeIM(), coord_a, coord_b) + for coord in result: + assert not graph.meta.is_out_of_bounds(coord) + + def test_symmetric_direction(self, gen_graph): + """Swapping coord_a and coord_b should yield the same set of neighboring coords.""" + graph = gen_graph(n_layers=4) + + class FakeIM: + cg_meta = graph.meta + + coord_a = np.array([1, 0, 0]) + coord_b = np.array([0, 0, 0]) + result_ab = _get_cont_chunk_coords(FakeIM(), coord_a, coord_b) + result_ba = _get_cont_chunk_coords(FakeIM(), coord_b, coord_a) + + # Convert to sets of tuples for comparison + set_ab = {tuple(c) for c in result_ab} + set_ba = {tuple(c) for c in result_ba} + assert set_ab == set_ba + + def test_non_adjacent_raises(self, gen_graph): + """Non-adjacent chunks (differing in more than one dim) should raise AssertionError.""" + graph = gen_graph(n_layers=4) + + class FakeIM: + cg_meta = graph.meta + + coord_a = np.array([1, 1, 0]) + coord_b = np.array([0, 0, 0]) + with pytest.raises(AssertionError): + _get_cont_chunk_coords(FakeIM(), coord_a, coord_b) + + def test_y_dim_adjacency(self, gen_graph): + """Test adjacency along y dimension.""" + graph = gen_graph(n_layers=4) + + class FakeIM: + cg_meta = graph.meta + + coord_a = np.array([0, 1, 0]) + coord_b = np.array([0, 0, 0]) + result = _get_cont_chunk_coords(FakeIM(), coord_a, coord_b) + assert isinstance(result, list) + # All returned coords should differ from chunk_coord_l along y + for coord in result: + assert not graph.meta.is_out_of_bounds(coord) + + +class TestParseEdgePayloads: + def test_empty_payloads(self): + """Empty list of payloads should return empty result.""" + from pychunkedgraph.ingest.ran_agglomeration import _parse_edge_payloads + + result = _parse_edge_payloads( + [], edge_dtype=[("sv1", np.uint64), ("sv2", np.uint64)] + ) + assert result == [] + + def test_none_content_skipped(self): + """Payloads with None content should be skipped.""" + from pychunkedgraph.ingest.ran_agglomeration import _parse_edge_payloads + + payloads = [{"content": None}] + result = _parse_edge_payloads( + payloads, edge_dtype=[("sv1", np.uint64), ("sv2", np.uint64)] + ) + assert result == [] + + def test_valid_payload(self): + """A valid payload with correct CRC should be parsed.""" + from pychunkedgraph.ingest.ran_agglomeration import _parse_edge_payloads + + dtype = [("sv1", np.uint64), ("sv2", np.uint64)] + data = np.array([(1, 2), (3, 4)], dtype=dtype) + raw = data.tobytes() + crc_val = np.array([crc32(raw)], dtype=np.uint32).tobytes() + content = raw + crc_val + + payloads = [{"content": content}] + result = _parse_edge_payloads(payloads, edge_dtype=dtype) + assert len(result) == 1 + assert len(result[0]) == 2 + assert result[0][0]["sv1"] == 1 + assert result[0][1]["sv2"] == 4 + + def test_bad_crc_raises(self): + """Payload with bad CRC should raise AssertionError.""" + from pychunkedgraph.ingest.ran_agglomeration import _parse_edge_payloads + + dtype = [("sv1", np.uint64), ("sv2", np.uint64)] + data = np.array([(1, 2)], dtype=dtype) + raw = data.tobytes() + bad_crc = np.array([99999], dtype=np.uint32).tobytes() + content = raw + bad_crc + + payloads = [{"content": content}] + with pytest.raises(AssertionError): + _parse_edge_payloads(payloads, edge_dtype=dtype) + + +class TestDefineActiveEdgesExtended: + def test_both_unmapped(self): + """When both endpoints are unmapped, edge should be inactive and both isolated.""" + edges = { + EDGE_TYPES.in_chunk: Edges( + np.array([10, 20], dtype=basetypes.NODE_ID), + np.array([30, 40], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.between_chunk: Edges([], []), + EDGE_TYPES.cross_chunk: Edges([], []), + } + mapping = {} # No IDs in mapping + active, isolated = define_active_edges(edges, mapping) + # All edges should be inactive + assert not np.any(active[EDGE_TYPES.in_chunk]) + # All unmapped IDs should appear in isolated + for sv_id in [10, 20, 30, 40]: + assert sv_id in isolated + + def test_different_agg_ids(self): + """Edges where sv1 and sv2 map to different agg IDs should be inactive.""" + edges = { + EDGE_TYPES.in_chunk: Edges( + np.array([1], dtype=basetypes.NODE_ID), + np.array([2], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.between_chunk: Edges([], []), + EDGE_TYPES.cross_chunk: Edges([], []), + } + mapping = {1: 100, 2: 200} # Different agg IDs + active, isolated = define_active_edges(edges, mapping) + assert not active[EDGE_TYPES.in_chunk][0] + + def test_empty_edges(self): + """Empty edge arrays should produce empty active arrays.""" + edges = { + EDGE_TYPES.in_chunk: Edges( + np.array([], dtype=basetypes.NODE_ID), + np.array([], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.between_chunk: Edges([], []), + EDGE_TYPES.cross_chunk: Edges([], []), + } + mapping = {} + active, isolated = define_active_edges(edges, mapping) + assert len(active[EDGE_TYPES.in_chunk]) == 0 + + def test_between_chunk_edges_active(self): + """Between-chunk edges should also be classified.""" + edges = { + EDGE_TYPES.in_chunk: Edges([], []), + EDGE_TYPES.between_chunk: Edges( + np.array([1, 2], dtype=basetypes.NODE_ID), + np.array([3, 4], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.cross_chunk: Edges([], []), + } + # 1->3 same agg, 2->4 different agg + mapping = {1: 0, 3: 0, 2: 1, 4: 2} + active, isolated = define_active_edges(edges, mapping) + assert active[EDGE_TYPES.between_chunk][0] # same agg + assert not active[EDGE_TYPES.between_chunk][1] # different agg + + +class TestGetActiveEdgesExtended: + def test_cross_chunk_always_active(self): + """Cross-chunk edges should always be kept active regardless of mapping.""" + edges = { + EDGE_TYPES.in_chunk: Edges([], []), + EDGE_TYPES.between_chunk: Edges([], []), + EDGE_TYPES.cross_chunk: Edges( + np.array([1, 2], dtype=basetypes.NODE_ID), + np.array([3, 4], dtype=basetypes.NODE_ID), + affinities=np.array([float("inf"), float("inf")]), + areas=np.array([1.0, 1.0]), + ), + } + mapping = {} # Empty mapping - but cross_chunk should still be active + chunk_edges, pseudo_isolated = get_active_edges(edges, mapping) + assert len(chunk_edges[EDGE_TYPES.cross_chunk].node_ids1) == 2 + + def test_pseudo_isolated_includes_all_node_ids(self): + """pseudo_isolated should include all node_ids from all edge types.""" + edges = { + EDGE_TYPES.in_chunk: Edges( + np.array([1], dtype=basetypes.NODE_ID), + np.array([2], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.between_chunk: Edges( + np.array([3], dtype=basetypes.NODE_ID), + np.array([4], dtype=basetypes.NODE_ID), + ), + EDGE_TYPES.cross_chunk: Edges( + np.array([5], dtype=basetypes.NODE_ID), + np.array([6], dtype=basetypes.NODE_ID), + affinities=np.array([float("inf")]), + areas=np.array([1.0]), + ), + } + mapping = {1: 0, 2: 0, 3: 0, 4: 0} + chunk_edges, pseudo_isolated = get_active_edges(edges, mapping) + # Should include node_ids1 from all types and node_ids2 from in_chunk + for sv_id in [1, 2, 3, 5]: + assert sv_id in pseudo_isolated + + +class TestGetIndex: + """Tests for _get_index which reads sharded index data from CloudFiles.""" + + def test_inchunk_index(self): + """Test _get_index with inchunk_or_agg=True uses single-u8 chunkid dtype.""" + from unittest.mock import MagicMock + + from pychunkedgraph.ingest.ran_agglomeration import ( + CRC_LEN, + HEADER_LEN, + VERSION_LEN, + _get_index, + ) + + # Create fake index data with inchunk dtype + dt = np.dtype([("chunkid", "u8"), ("offset", "u8"), ("size", "u8")]) + index_entries = np.array([(100, 20, 50)], dtype=dt) + index_bytes = index_entries.tobytes() + index_crc = np.array([crc32(index_bytes)], dtype=np.uint32).tobytes() + index_with_crc = index_bytes + index_crc + + # Build a fake header: version (4 bytes) + idx_offset (8 bytes) + idx_length (8 bytes) = 20 bytes + idx_offset = np.uint64(HEADER_LEN) + idx_length = np.uint64(len(index_with_crc)) + version = b"SO01" + header_content = ( + version + + np.array([idx_offset], dtype=np.uint64).tobytes() + + np.array([idx_length], dtype=np.uint64).tobytes() + ) + + cf = MagicMock() + # First call returns headers, second call returns index data + cf.get.side_effect = [ + [{"path": "test.data", "content": header_content}], + [{"path": "test.data", "content": index_with_crc}], + ] + + result = _get_index(cf, ["test.data"], inchunk_or_agg=True) + assert "test.data" in result + assert result["test.data"][0]["chunkid"] == 100 + assert result["test.data"][0]["offset"] == 20 + assert result["test.data"][0]["size"] == 50 + + def test_between_chunk_index(self): + """Test _get_index with inchunk_or_agg=False uses 2-u8 chunkid dtype.""" + from unittest.mock import MagicMock + + from pychunkedgraph.ingest.ran_agglomeration import ( + CRC_LEN, + HEADER_LEN, + VERSION_LEN, + _get_index, + ) + + # Between-chunk index uses ("chunkid", "2u8") -> two uint64 values + dt = np.dtype([("chunkid", "2u8"), ("offset", "u8"), ("size", "u8")]) + index_entries = np.array([((200, 300), 40, 60)], dtype=dt) + index_bytes = index_entries.tobytes() + index_crc = np.array([crc32(index_bytes)], dtype=np.uint32).tobytes() + index_with_crc = index_bytes + index_crc + + idx_offset = np.uint64(HEADER_LEN) + idx_length = np.uint64(len(index_with_crc)) + version = b"SO01" + header_content = ( + version + + np.array([idx_offset], dtype=np.uint64).tobytes() + + np.array([idx_length], dtype=np.uint64).tobytes() + ) + + cf = MagicMock() + cf.get.side_effect = [ + [{"path": "between.data", "content": header_content}], + [{"path": "between.data", "content": index_with_crc}], + ] + + result = _get_index(cf, ["between.data"], inchunk_or_agg=False) + assert "between.data" in result + assert result["between.data"][0]["chunkid"][0] == 200 + assert result["between.data"][0]["chunkid"][1] == 300 + assert result["between.data"][0]["offset"] == 40 + assert result["between.data"][0]["size"] == 60 + + def test_none_content_skipped(self): + """When header content is None, that file should be skipped in the index.""" + from unittest.mock import MagicMock + + from pychunkedgraph.ingest.ran_agglomeration import _get_index + + cf = MagicMock() + # Header returns None content for the file + cf.get.side_effect = [ + [{"path": "missing.data", "content": None}], + [], # No index_infos to fetch + ] + + result = _get_index(cf, ["missing.data"], inchunk_or_agg=True) + assert result == {} + + def test_multiple_files(self): + """Test _get_index with multiple filenames, one valid and one missing.""" + from unittest.mock import MagicMock + + from pychunkedgraph.ingest.ran_agglomeration import ( + HEADER_LEN, + _get_index, + ) + + dt = np.dtype([("chunkid", "u8"), ("offset", "u8"), ("size", "u8")]) + index_entries = np.array([(500, 100, 200)], dtype=dt) + index_bytes = index_entries.tobytes() + index_crc = np.array([crc32(index_bytes)], dtype=np.uint32).tobytes() + index_with_crc = index_bytes + index_crc + + idx_offset = np.uint64(HEADER_LEN) + idx_length = np.uint64(len(index_with_crc)) + version = b"SO01" + header_content = ( + version + + np.array([idx_offset], dtype=np.uint64).tobytes() + + np.array([idx_length], dtype=np.uint64).tobytes() + ) + + cf = MagicMock() + cf.get.side_effect = [ + [ + {"path": "valid.data", "content": header_content}, + {"path": "invalid.data", "content": None}, + ], + [{"path": "valid.data", "content": index_with_crc}], + ] + + result = _get_index(cf, ["valid.data", "invalid.data"], inchunk_or_agg=True) + assert "valid.data" in result + assert "invalid.data" not in result + + def test_multiple_index_entries(self): + """Test _get_index with multiple entries in a single file index.""" + from unittest.mock import MagicMock + + from pychunkedgraph.ingest.ran_agglomeration import ( + HEADER_LEN, + _get_index, + ) + + dt = np.dtype([("chunkid", "u8"), ("offset", "u8"), ("size", "u8")]) + index_entries = np.array( + [(100, 20, 50), (200, 70, 80), (300, 150, 30)], dtype=dt + ) + index_bytes = index_entries.tobytes() + index_crc = np.array([crc32(index_bytes)], dtype=np.uint32).tobytes() + index_with_crc = index_bytes + index_crc + + idx_offset = np.uint64(HEADER_LEN) + idx_length = np.uint64(len(index_with_crc)) + version = b"SO01" + header_content = ( + version + + np.array([idx_offset], dtype=np.uint64).tobytes() + + np.array([idx_length], dtype=np.uint64).tobytes() + ) + + cf = MagicMock() + cf.get.side_effect = [ + [{"path": "multi.data", "content": header_content}], + [{"path": "multi.data", "content": index_with_crc}], + ] + + result = _get_index(cf, ["multi.data"], inchunk_or_agg=True) + assert "multi.data" in result + assert len(result["multi.data"]) == 3 + assert result["multi.data"][0]["chunkid"] == 100 + assert result["multi.data"][1]["chunkid"] == 200 + assert result["multi.data"][2]["chunkid"] == 300 + + +class TestReadInChunkFiles: + """Tests for _read_in_chunk_files which reads edge data for a specific chunk.""" + + def test_basic_read(self): + """Mock CloudFiles to test full read flow for in-chunk files.""" + from unittest.mock import MagicMock, patch + + from pychunkedgraph.ingest.ran_agglomeration import ( + HEADER_LEN, + _read_in_chunk_files, + ) + + chunk_id = np.uint64(100) + edge_dtype = [("sv1", np.uint64), ("sv2", np.uint64)] + + # Build index: one entry for our chunk_id + dt = np.dtype([("chunkid", "u8"), ("offset", "u8"), ("size", "u8")]) + edge_data = np.array([(10, 20)], dtype=edge_dtype) + edge_bytes = edge_data.tobytes() + edge_crc = np.array([crc32(edge_bytes)], dtype=np.uint32).tobytes() + edge_payload = edge_bytes + edge_crc + + data_offset = np.uint64(HEADER_LEN) + data_size = np.uint64(len(edge_payload)) + + index_entries = np.array([(chunk_id, data_offset, data_size)], dtype=dt) + index_bytes = index_entries.tobytes() + index_crc = np.array([crc32(index_bytes)], dtype=np.uint32).tobytes() + index_with_crc = index_bytes + index_crc + + idx_offset = np.uint64(data_offset + data_size) + idx_length = np.uint64(len(index_with_crc)) + version = b"SO01" + header_content = ( + version + + np.array([idx_offset], dtype=np.uint64).tobytes() + + np.array([idx_length], dtype=np.uint64).tobytes() + ) + + mock_cf = MagicMock() + mock_cf.get.side_effect = [ + # 1st call: headers + [{"path": "in_chunk_0_0_0_0.data", "content": header_content}], + # 2nd call: index data + [{"path": "in_chunk_0_0_0_0.data", "content": index_with_crc}], + # 3rd call: edge payloads + [{"path": "in_chunk_0_0_0_0.data", "content": edge_payload}], + ] + + with patch( + "pychunkedgraph.ingest.ran_agglomeration.CloudFiles", + return_value=mock_cf, + ): + result = _read_in_chunk_files( + chunk_id, "gs://fake/path", ["in_chunk_0_0_0_0.data"], edge_dtype + ) + + assert len(result) == 1 + assert result[0][0]["sv1"] == 10 + assert result[0][0]["sv2"] == 20 + + def test_no_matching_chunk(self): + """When the index has no entry matching the requested chunk_id, no payloads are fetched.""" + from unittest.mock import MagicMock, patch + + from pychunkedgraph.ingest.ran_agglomeration import ( + HEADER_LEN, + _read_in_chunk_files, + ) + + chunk_id = np.uint64(999) + edge_dtype = [("sv1", np.uint64), ("sv2", np.uint64)] + + # Index entry for a *different* chunk_id + dt = np.dtype([("chunkid", "u8"), ("offset", "u8"), ("size", "u8")]) + index_entries = np.array([(100, 20, 50)], dtype=dt) + index_bytes = index_entries.tobytes() + index_crc = np.array([crc32(index_bytes)], dtype=np.uint32).tobytes() + index_with_crc = index_bytes + index_crc + + idx_offset = np.uint64(HEADER_LEN) + idx_length = np.uint64(len(index_with_crc)) + version = b"SO01" + header_content = ( + version + + np.array([idx_offset], dtype=np.uint64).tobytes() + + np.array([idx_length], dtype=np.uint64).tobytes() + ) + + mock_cf = MagicMock() + mock_cf.get.side_effect = [ + [{"path": "in_chunk_0_0_0_0.data", "content": header_content}], + [{"path": "in_chunk_0_0_0_0.data", "content": index_with_crc}], + [], # No payloads fetched + ] + + with patch( + "pychunkedgraph.ingest.ran_agglomeration.CloudFiles", + return_value=mock_cf, + ): + result = _read_in_chunk_files( + chunk_id, "gs://fake/path", ["in_chunk_0_0_0_0.data"], edge_dtype + ) + + assert result == [] + + +class TestReadBetweenOrFakeChunkFiles: + """Tests for _read_between_or_fake_chunk_files which reads between-chunk edge data.""" + + def _make_between_index_and_header(self, entries_list): + """Helper to create between-chunk index data and header. + + entries_list: list of (chunkid0, chunkid1, offset, size) tuples + """ + from pychunkedgraph.ingest.ran_agglomeration import HEADER_LEN + + dt = np.dtype([("chunkid", "2u8"), ("offset", "u8"), ("size", "u8")]) + index_entries = np.array( + [((c0, c1), off, sz) for c0, c1, off, sz in entries_list], dtype=dt + ) + index_bytes = index_entries.tobytes() + index_crc = np.array([crc32(index_bytes)], dtype=np.uint32).tobytes() + index_with_crc = index_bytes + index_crc + + idx_offset = np.uint64(HEADER_LEN) + idx_length = np.uint64(len(index_with_crc)) + version = b"SO01" + header_content = ( + version + + np.array([idx_offset], dtype=np.uint64).tobytes() + + np.array([idx_length], dtype=np.uint64).tobytes() + ) + return header_content, index_with_crc + + def test_basic_between_chunk_read(self): + """Test reading between-chunk files with matching chunk pair.""" + from unittest.mock import MagicMock, patch + + from pychunkedgraph.ingest.ran_agglomeration import ( + HEADER_LEN, + _read_between_or_fake_chunk_files, + ) + + chunk_id = np.uint64(100) + adjacent_id = np.uint64(200) + edge_dtype = [("sv1", np.uint64), ("sv2", np.uint64)] + + # Create edge payload + edge_data = np.array([(10, 20)], dtype=edge_dtype) + edge_bytes = edge_data.tobytes() + edge_crc = np.array([crc32(edge_bytes)], dtype=np.uint32).tobytes() + edge_payload = edge_bytes + edge_crc + + data_offset = np.uint64(HEADER_LEN) + data_size = np.uint64(len(edge_payload)) + + header_content, index_with_crc = self._make_between_index_and_header( + [(100, 200, int(data_offset), int(data_size))] + ) + + mock_cf = MagicMock() + mock_cf.get.side_effect = [ + # headers + [{"path": "between.data", "content": header_content}], + # index data + [{"path": "between.data", "content": index_with_crc}], + # chunk_finfos payloads (forward direction) + [{"path": "between.data", "content": edge_payload}], + # adj_chunk_finfos payloads (reverse direction) - empty + [], + ] + + with patch( + "pychunkedgraph.ingest.ran_agglomeration.CloudFiles", + return_value=mock_cf, + ): + result = _read_between_or_fake_chunk_files( + chunk_id, adjacent_id, "gs://fake/path", ["between.data"], edge_dtype + ) + + assert len(result) == 1 + assert result[0][0]["sv1"] == 10 + assert result[0][0]["sv2"] == 20 + + def test_reverse_direction(self): + """Test reading from the adjacent->chunk direction (swapped columns in result dtype).""" + from unittest.mock import MagicMock, patch + + from pychunkedgraph.ingest.ran_agglomeration import ( + HEADER_LEN, + _read_between_or_fake_chunk_files, + ) + + chunk_id = np.uint64(100) + adjacent_id = np.uint64(200) + edge_dtype = [("sv1", np.uint64), ("sv2", np.uint64)] + + # Edge payload for the *reverse* direction (adjacent_id, chunk_id) + # When reading reverse direction, the dtype columns are swapped: (sv2, sv1) + rev_edge_dtype = [("sv2", np.uint64), ("sv1", np.uint64)] + edge_data = np.array([(30, 40)], dtype=rev_edge_dtype) + edge_bytes = edge_data.tobytes() + edge_crc = np.array([crc32(edge_bytes)], dtype=np.uint32).tobytes() + edge_payload = edge_bytes + edge_crc + + data_offset = np.uint64(HEADER_LEN) + data_size = np.uint64(len(edge_payload)) + + # Index entry: (adjacent_id, chunk_id) => reverse direction + header_content, index_with_crc = self._make_between_index_and_header( + [(200, 100, int(data_offset), int(data_size))] + ) + + mock_cf = MagicMock() + mock_cf.get.side_effect = [ + # headers + [{"path": "between.data", "content": header_content}], + # index + [{"path": "between.data", "content": index_with_crc}], + # chunk_finfos (forward) - empty + [], + # adj_chunk_finfos (reverse) + [{"path": "between.data", "content": edge_payload}], + ] + + with patch( + "pychunkedgraph.ingest.ran_agglomeration.CloudFiles", + return_value=mock_cf, + ): + result = _read_between_or_fake_chunk_files( + chunk_id, adjacent_id, "gs://fake/path", ["between.data"], edge_dtype + ) + + # Result comes from adj_result which used the swapped dtype + assert len(result) == 1 + assert result[0][0]["sv2"] == 30 + assert result[0][0]["sv1"] == 40 + + def test_no_matching_pairs(self): + """When no chunk pair matches, should return empty list.""" + from unittest.mock import MagicMock, patch + + from pychunkedgraph.ingest.ran_agglomeration import ( + HEADER_LEN, + _read_between_or_fake_chunk_files, + ) + + chunk_id = np.uint64(100) + adjacent_id = np.uint64(200) + edge_dtype = [("sv1", np.uint64), ("sv2", np.uint64)] + + # Index entry for a totally different pair + header_content, index_with_crc = self._make_between_index_and_header( + [(999, 888, 20, 50)] + ) + + mock_cf = MagicMock() + mock_cf.get.side_effect = [ + [{"path": "between.data", "content": header_content}], + [{"path": "between.data", "content": index_with_crc}], + [], # No forward payloads + [], # No reverse payloads + ] + + with patch( + "pychunkedgraph.ingest.ran_agglomeration.CloudFiles", + return_value=mock_cf, + ): + result = _read_between_or_fake_chunk_files( + chunk_id, adjacent_id, "gs://fake/path", ["between.data"], edge_dtype + ) + + assert result == [] + + +class TestReadAggFiles: + """Tests for _read_agg_files which reads agglomeration remap data.""" + + def test_basic_agg_read(self): + """Test reading agglomeration files returns edge list.""" + from unittest.mock import MagicMock, patch + + from pychunkedgraph.ingest.ran_agglomeration import ( + CRC_LEN, + HEADER_LEN, + _read_agg_files, + ) + + chunk_id = np.uint64(42) + + # Index entry for our chunk + dt = np.dtype([("chunkid", "u8"), ("offset", "u8"), ("size", "u8")]) + + # Build edge data: pairs of node IDs + edges = np.array([[10, 20], [30, 40]], dtype=basetypes.NODE_ID) + edge_bytes = edges.tobytes() + edge_crc = np.array([crc32(edge_bytes)], dtype=np.uint32).tobytes() + edge_payload = edge_bytes + edge_crc + + data_offset = np.uint64(HEADER_LEN) + data_size = np.uint64(len(edge_payload)) + + index_entries = np.array([(chunk_id, data_offset, data_size)], dtype=dt) + index_bytes = index_entries.tobytes() + index_crc = np.array([crc32(index_bytes)], dtype=np.uint32).tobytes() + index_with_crc = index_bytes + index_crc + + idx_offset = np.uint64(data_offset + data_size) + idx_length = np.uint64(len(index_with_crc)) + version = b"SO01" + header_content = ( + version + + np.array([idx_offset], dtype=np.uint64).tobytes() + + np.array([idx_length], dtype=np.uint64).tobytes() + ) + + mock_cf = MagicMock() + mock_cf.get.side_effect = [ + # headers + [{"path": "done_0_0_0_0.data", "content": header_content}], + # index + [{"path": "done_0_0_0_0.data", "content": index_with_crc}], + # payloads + [{"path": "done_0_0_0_0.data", "content": edge_payload}], + ] + + with patch( + "pychunkedgraph.ingest.ran_agglomeration.CloudFiles", + return_value=mock_cf, + ): + result = _read_agg_files( + ["done_0_0_0_0.data"], [chunk_id], "gs://fake/remap/" + ) + + # Result is a list starting with empty_2d, plus our edge data + assert len(result) >= 2 # empty_2d + our edges + # The last element should be our 2x2 edge array + combined = np.concatenate(result) + assert combined.shape[1] == 2 + assert len(combined) == 2 + + def test_missing_file_skipped(self): + """When a filename is not in files_index (KeyError), it should be skipped.""" + from unittest.mock import MagicMock, patch + + from pychunkedgraph.ingest.ran_agglomeration import ( + HEADER_LEN, + _read_agg_files, + ) + + # No valid headers -> empty index + mock_cf = MagicMock() + mock_cf.get.side_effect = [ + # headers: all None + [{"path": "done_0_0_0_0.data", "content": None}], + [], # empty index_infos + [], # empty payloads + ] + + with patch( + "pychunkedgraph.ingest.ran_agglomeration.CloudFiles", + return_value=mock_cf, + ): + result = _read_agg_files( + ["done_0_0_0_0.data"], [np.uint64(42)], "gs://fake/remap/" + ) + + # Should only contain the initial empty_2d + assert len(result) == 1 + assert result[0].shape == (0, 2) + + def test_none_payload_skipped(self): + """When a payload content is None, it should be skipped.""" + from unittest.mock import MagicMock, patch + + from pychunkedgraph.ingest.ran_agglomeration import ( + HEADER_LEN, + _read_agg_files, + ) + + chunk_id = np.uint64(42) + dt = np.dtype([("chunkid", "u8"), ("offset", "u8"), ("size", "u8")]) + index_entries = np.array([(chunk_id, 20, 50)], dtype=dt) + index_bytes = index_entries.tobytes() + index_crc = np.array([crc32(index_bytes)], dtype=np.uint32).tobytes() + index_with_crc = index_bytes + index_crc + + idx_offset = np.uint64(HEADER_LEN) + idx_length = np.uint64(len(index_with_crc)) + version = b"SO01" + header_content = ( + version + + np.array([idx_offset], dtype=np.uint64).tobytes() + + np.array([idx_length], dtype=np.uint64).tobytes() + ) + + mock_cf = MagicMock() + mock_cf.get.side_effect = [ + [{"path": "done_0_0_0_0.data", "content": header_content}], + [{"path": "done_0_0_0_0.data", "content": index_with_crc}], + [{"path": "done_0_0_0_0.data", "content": None}], # None payload + ] + + with patch( + "pychunkedgraph.ingest.ran_agglomeration.CloudFiles", + return_value=mock_cf, + ): + result = _read_agg_files( + ["done_0_0_0_0.data"], [chunk_id], "gs://fake/remap/" + ) + + # Should only contain the initial empty_2d (None content was skipped) + assert len(result) == 1 + assert result[0].shape == (0, 2) + + +class TestReadRawEdgeData: + """Tests for read_raw_edge_data which orchestrates edge collection and writing.""" + + from unittest.mock import patch, MagicMock + + @patch("pychunkedgraph.ingest.ran_agglomeration._collect_edge_data") + @patch("pychunkedgraph.ingest.ran_agglomeration.postprocess_edge_data") + @patch("pychunkedgraph.ingest.ran_agglomeration.put_chunk_edges") + def test_basic(self, mock_put, mock_postprocess, mock_collect): + from unittest.mock import MagicMock + + from pychunkedgraph.ingest.ran_agglomeration import read_raw_edge_data + + # Setup mock return values + edge_dict = {} + for et in EDGE_TYPES: + edge_dict[et] = { + "sv1": np.array([1, 2], dtype=np.uint64), + "sv2": np.array([3, 4], dtype=np.uint64), + "aff": np.array([0.5, 0.6]), + "area": np.array([10, 20]), + } + # cross_chunk doesn't have aff/area in the read path (they get inf/ones) + edge_dict[EDGE_TYPES.cross_chunk] = { + "sv1": np.array([5], dtype=np.uint64), + "sv2": np.array([6], dtype=np.uint64), + } + mock_collect.return_value = edge_dict + mock_postprocess.return_value = edge_dict + + imanager = MagicMock() + imanager.cg_meta.data_source.EDGES = "gs://fake/edges" + + result = read_raw_edge_data(imanager, [0, 0, 0]) + assert EDGE_TYPES.in_chunk in result + assert EDGE_TYPES.between_chunk in result + assert EDGE_TYPES.cross_chunk in result + # in_chunk should have 2 edges + assert len(result[EDGE_TYPES.in_chunk].node_ids1) == 2 + # cross_chunk should have 1 edge with inf affinity + assert len(result[EDGE_TYPES.cross_chunk].node_ids1) == 1 + assert np.isinf(result[EDGE_TYPES.cross_chunk].affinities[0]) + # put_chunk_edges should have been called since there are edges + mock_put.assert_called_once() + + @patch("pychunkedgraph.ingest.ran_agglomeration._collect_edge_data") + @patch("pychunkedgraph.ingest.ran_agglomeration.postprocess_edge_data") + @patch("pychunkedgraph.ingest.ran_agglomeration.put_chunk_edges") + def test_no_edges(self, mock_put, mock_postprocess, mock_collect): + """When all edge types are empty, put_chunk_edges should not be called.""" + from unittest.mock import MagicMock + + from pychunkedgraph.ingest.ran_agglomeration import read_raw_edge_data + + edge_dict = {et: {} for et in EDGE_TYPES} + mock_collect.return_value = edge_dict + mock_postprocess.return_value = edge_dict + + imanager = MagicMock() + imanager.cg_meta.data_source.EDGES = "gs://fake/edges" + + result = read_raw_edge_data(imanager, [0, 0, 0]) + # All edge types should be empty Edges objects + for et in EDGE_TYPES: + assert len(result[et].node_ids1) == 0 + mock_put.assert_not_called() + + @patch("pychunkedgraph.ingest.ran_agglomeration._collect_edge_data") + @patch("pychunkedgraph.ingest.ran_agglomeration.postprocess_edge_data") + @patch("pychunkedgraph.ingest.ran_agglomeration.put_chunk_edges") + def test_edges_but_no_storage_path(self, mock_put, mock_postprocess, mock_collect): + """When EDGES path is empty/falsy, put_chunk_edges should not be called.""" + from unittest.mock import MagicMock + + from pychunkedgraph.ingest.ran_agglomeration import read_raw_edge_data + + edge_dict = {} + for et in EDGE_TYPES: + edge_dict[et] = { + "sv1": np.array([1], dtype=np.uint64), + "sv2": np.array([2], dtype=np.uint64), + "aff": np.array([0.5]), + "area": np.array([10]), + } + mock_collect.return_value = edge_dict + mock_postprocess.return_value = edge_dict + + imanager = MagicMock() + imanager.cg_meta.data_source.EDGES = "" # empty string = falsy + + result = read_raw_edge_data(imanager, [0, 0, 0]) + assert EDGE_TYPES.in_chunk in result + mock_put.assert_not_called() + + @patch("pychunkedgraph.ingest.ran_agglomeration._collect_edge_data") + @patch("pychunkedgraph.ingest.ran_agglomeration.postprocess_edge_data") + @patch("pychunkedgraph.ingest.ran_agglomeration.put_chunk_edges") + def test_partial_edges(self, mock_put, mock_postprocess, mock_collect): + """Only in_chunk has edges, others empty.""" + from unittest.mock import MagicMock + + from pychunkedgraph.ingest.ran_agglomeration import read_raw_edge_data + + edge_dict = { + EDGE_TYPES.in_chunk: { + "sv1": np.array([1, 2], dtype=np.uint64), + "sv2": np.array([3, 4], dtype=np.uint64), + "aff": np.array([0.5, 0.6]), + "area": np.array([10, 20]), + }, + EDGE_TYPES.between_chunk: {}, + EDGE_TYPES.cross_chunk: {}, + } + mock_collect.return_value = edge_dict + mock_postprocess.return_value = edge_dict + + imanager = MagicMock() + imanager.cg_meta.data_source.EDGES = "gs://fake/edges" + + result = read_raw_edge_data(imanager, [0, 0, 0]) + assert len(result[EDGE_TYPES.in_chunk].node_ids1) == 2 + assert len(result[EDGE_TYPES.between_chunk].node_ids1) == 0 + assert len(result[EDGE_TYPES.cross_chunk].node_ids1) == 0 + # Should still write because in_chunk has edges + mock_put.assert_called_once() + + +class TestReadRawAgglomerationData: + """Tests for read_raw_agglomeration_data which reads agg remap files.""" + + from unittest.mock import patch, MagicMock + + @patch("pychunkedgraph.ingest.ran_agglomeration._read_agg_files") + @patch("pychunkedgraph.ingest.ran_agglomeration.put_chunk_components") + def test_basic(self, mock_put_components, mock_read_agg, gen_graph): + from pychunkedgraph.ingest.ran_agglomeration import read_raw_agglomeration_data + from unittest.mock import MagicMock + + graph = gen_graph(n_layers=4) + imanager = MagicMock() + imanager.cg_meta = graph.meta + imanager.config.AGGLOMERATION = "gs://fake/agg" + + # Return edge pairs that form connected components + mock_read_agg.return_value = [np.array([[1, 2], [2, 3]], dtype=np.uint64)] + + mapping = read_raw_agglomeration_data(imanager, np.array([0, 0, 0])) + assert isinstance(mapping, dict) + # 1, 2, 3 should all map to the same component + assert mapping[1] == mapping[2] == mapping[3] + # put_chunk_components should have been called + mock_put_components.assert_called_once() + + @patch("pychunkedgraph.ingest.ran_agglomeration._read_agg_files") + @patch("pychunkedgraph.ingest.ran_agglomeration.put_chunk_components") + def test_multiple_components(self, mock_put_components, mock_read_agg, gen_graph): + from pychunkedgraph.ingest.ran_agglomeration import read_raw_agglomeration_data + from unittest.mock import MagicMock + + graph = gen_graph(n_layers=4) + imanager = MagicMock() + imanager.cg_meta = graph.meta + imanager.config.AGGLOMERATION = "gs://fake/agg" + + # Two separate components: {1,2} and {3,4} + mock_read_agg.return_value = [np.array([[1, 2], [3, 4]], dtype=np.uint64)] + + mapping = read_raw_agglomeration_data(imanager, np.array([0, 0, 0])) + assert isinstance(mapping, dict) + assert mapping[1] == mapping[2] + assert mapping[3] == mapping[4] + # The two components should have different IDs + assert mapping[1] != mapping[3] + + @patch("pychunkedgraph.ingest.ran_agglomeration._read_agg_files") + @patch("pychunkedgraph.ingest.ran_agglomeration.put_chunk_components") + def test_no_components_path(self, mock_put_components, mock_read_agg, gen_graph): + """When COMPONENTS is None (falsy), put_chunk_components should not be called.""" + from pychunkedgraph.ingest.ran_agglomeration import read_raw_agglomeration_data + from unittest.mock import MagicMock + + graph = gen_graph(n_layers=4) + # Replace the data_source with one that has COMPONENTS=None + original_ds = graph.meta.data_source + graph.meta._data_source = original_ds._replace(COMPONENTS=None) + + imanager = MagicMock() + imanager.cg_meta = graph.meta + imanager.config.AGGLOMERATION = "gs://fake/agg" + + mock_read_agg.return_value = [np.array([[1, 2]], dtype=np.uint64)] + + mapping = read_raw_agglomeration_data(imanager, np.array([0, 0, 0])) + assert isinstance(mapping, dict) + mock_put_components.assert_not_called() + + # Restore original data_source + graph.meta._data_source = original_ds diff --git a/pychunkedgraph/tests/test_ingest_utils.py b/pychunkedgraph/tests/test_ingest_utils.py new file mode 100644 index 000000000..4c5bdf0af --- /dev/null +++ b/pychunkedgraph/tests/test_ingest_utils.py @@ -0,0 +1,492 @@ +"""Tests for pychunkedgraph.ingest.utils""" + +import io +import sys +import numpy as np +import pytest +from unittest.mock import MagicMock, patch + +from pychunkedgraph.ingest.utils import ( + bootstrap, + chunk_id_str, + get_chunks_not_done, + job_type_guard, + move_up, + postprocess_edge_data, + randomize_grid_points, +) + + +class TestBootstrap: + def test_from_config(self): + from google.auth import credentials + + config = { + "data_source": { + "EDGES": "gs://test/edges", + "COMPONENTS": "gs://test/components", + "WATERSHED": "gs://test/ws", + }, + "graph_config": { + "CHUNK_SIZE": [64, 64, 64], + "FANOUT": 2, + "SPATIAL_BITS": 10, + }, + "backend_client": { + "TYPE": "bigtable", + "CONFIG": { + "ADMIN": True, + "READ_ONLY": False, + "PROJECT": "test-project", + "INSTANCE": "test-instance", + "CREDENTIALS": credentials.AnonymousCredentials(), + }, + }, + "ingest_config": {}, + } + meta, ingest_config, client_info = bootstrap("test_graph", config=config) + assert meta.graph_config.ID == "test_graph" + assert meta.graph_config.FANOUT == 2 + assert ingest_config.USE_RAW_EDGES is False + + +class TestPostprocessEdgeData: + def test_v2_passthrough(self): + class FakeMeta: + class data_source: + DATA_VERSION = 2 + + resolution = np.array([1, 1, 1]) + + class FakeIM: + cg_meta = FakeMeta() + + edge_dict = {"test": {"sv1": [1], "sv2": [2], "aff": [0.5], "area": [10]}} + result = postprocess_edge_data(FakeIM(), edge_dict) + assert result == edge_dict + + def test_v3(self): + class FakeMeta: + class data_source: + DATA_VERSION = 3 + + resolution = np.array([4, 4, 40]) + + class FakeIM: + cg_meta = FakeMeta() + + edge_dict = { + "test": { + "sv1": np.array([1]), + "sv2": np.array([2]), + "aff_x": np.array([0.1]), + "aff_y": np.array([0.2]), + "aff_z": np.array([0.3]), + "area_x": np.array([10]), + "area_y": np.array([20]), + "area_z": np.array([30]), + } + } + result = postprocess_edge_data(FakeIM(), edge_dict) + assert "aff" in result["test"] + assert "area" in result["test"] + # aff = 0.1*4 + 0.2*4 + 0.3*40 = 0.4 + 0.8 + 12 = 13.2 + np.testing.assert_almost_equal(result["test"]["aff"][0], 13.2) + + def test_empty_data(self): + class FakeMeta: + class data_source: + DATA_VERSION = 3 + + resolution = np.array([1, 1, 1]) + + class FakeIM: + cg_meta = FakeMeta() + + edge_dict = {"test": {}} + result = postprocess_edge_data(FakeIM(), edge_dict) + assert result["test"] == {} + + +class TestRandomizeGridPoints: + def test_basic(self): + points = list(randomize_grid_points(2, 2, 2)) + assert len(points) == 8 + # All coordinates should be valid + for x, y, z in points: + assert 0 <= x < 2 + assert 0 <= y < 2 + assert 0 <= z < 2 + + def test_covers_all(self): + points = list(randomize_grid_points(3, 2, 1)) + assert len(points) == 6 + coords = {(x, y, z) for x, y, z in points} + assert len(coords) == 6 + + +class TestPostprocessEdgeDataUnknownVersion: + def test_version5_raises(self): + """Version 5 is not supported and should raise ValueError.""" + + class FakeMeta: + class data_source: + DATA_VERSION = 5 + + resolution = np.array([1, 1, 1]) + + class FakeIM: + cg_meta = FakeMeta() + + edge_dict = {"test": {"sv1": [1], "sv2": [2]}} + with pytest.raises(ValueError, match="Unknown data_version"): + postprocess_edge_data(FakeIM(), edge_dict) + + +class TestPostprocessEdgeDataV4SameAsV3: + def test_v4_same_code_path(self): + """Version 4 should use the same processing logic as v3 (combine xyz components).""" + + class FakeMetaV3: + class data_source: + DATA_VERSION = 3 + + resolution = np.array([2, 2, 20]) + + class FakeMetaV4: + class data_source: + DATA_VERSION = 4 + + resolution = np.array([2, 2, 20]) + + class FakeIMv3: + cg_meta = FakeMetaV3() + + class FakeIMv4: + cg_meta = FakeMetaV4() + + edge_dict_v3 = { + "test": { + "sv1": np.array([10]), + "sv2": np.array([20]), + "aff_x": np.array([0.1]), + "aff_y": np.array([0.2]), + "aff_z": np.array([0.3]), + "area_x": np.array([5]), + "area_y": np.array([6]), + "area_z": np.array([7]), + } + } + edge_dict_v4 = { + "test": { + "sv1": np.array([10]), + "sv2": np.array([20]), + "aff_x": np.array([0.1]), + "aff_y": np.array([0.2]), + "aff_z": np.array([0.3]), + "area_x": np.array([5]), + "area_y": np.array([6]), + "area_z": np.array([7]), + } + } + + result_v3 = postprocess_edge_data(FakeIMv3(), edge_dict_v3) + result_v4 = postprocess_edge_data(FakeIMv4(), edge_dict_v4) + + # Both versions should produce the same combined aff and area values + np.testing.assert_array_almost_equal( + result_v3["test"]["aff"], result_v4["test"]["aff"] + ) + np.testing.assert_array_almost_equal( + result_v3["test"]["area"], result_v4["test"]["area"] + ) + np.testing.assert_array_equal( + result_v3["test"]["sv1"], result_v4["test"]["sv1"] + ) + np.testing.assert_array_equal( + result_v3["test"]["sv2"], result_v4["test"]["sv2"] + ) + + +class TestChunkIdStr: + def test_basic(self): + result = chunk_id_str(3, [1, 2, 3]) + assert result == "3_1_2_3" + + def test_layer_zero(self): + result = chunk_id_str(0, [0, 0, 0]) + assert result == "0_0_0_0" + + def test_tuple_coords(self): + result = chunk_id_str(5, (10, 20, 30)) + assert result == "5_10_20_30" + + def test_single_coord(self): + result = chunk_id_str(2, [7]) + assert result == "2_7" + + +class TestMoveUp: + def test_writes_escape_code_to_stdout(self): + """move_up() writes the ANSI escape code for cursor-up to stdout.""" + captured = io.StringIO() + old_stdout = sys.stdout + sys.stdout = captured + try: + move_up(3) + finally: + sys.stdout = old_stdout + assert captured.getvalue() == "\033[3A" + + def test_default_one_line(self): + """move_up() with no argument moves up 1 line.""" + captured = io.StringIO() + old_stdout = sys.stdout + sys.stdout = captured + try: + move_up() + finally: + sys.stdout = old_stdout + assert captured.getvalue() == "\033[1A" + + +class TestGetChunksNotDone: + def _make_mock_imanager(self): + imanager = MagicMock() + imanager.redis = MagicMock() + return imanager + + def test_all_completed_returns_empty(self): + """When all coords are completed in redis, returns empty list.""" + imanager = self._make_mock_imanager() + coords = [[0, 0, 0], [1, 0, 0], [0, 1, 0]] + # All marked as completed (1 = member of the set) + imanager.redis.smismember.return_value = [1, 1, 1] + result = get_chunks_not_done(imanager, layer=2, coords=coords) + assert result == [] + + def test_some_not_completed_returns_those(self): + """When some coords are not completed, returns those coords.""" + imanager = self._make_mock_imanager() + coords = [[0, 0, 0], [1, 0, 0], [0, 1, 0]] + # First is completed, second and third are not + imanager.redis.smismember.return_value = [1, 0, 0] + result = get_chunks_not_done(imanager, layer=2, coords=coords) + assert result == [[1, 0, 0], [0, 1, 0]] + + def test_redis_exception_returns_all_coords(self): + """When redis raises an exception, returns all coords as fallback.""" + imanager = self._make_mock_imanager() + coords = [[0, 0, 0], [1, 0, 0]] + imanager.redis.smismember.side_effect = Exception("Redis down") + result = get_chunks_not_done(imanager, layer=2, coords=coords) + assert result == coords + + +class TestJobTypeGuard: + @patch("pychunkedgraph.ingest.utils.get_redis_connection") + def test_same_job_type_runs_normally(self, mock_get_redis): + """When current job_type matches, decorated function runs normally.""" + mock_redis = MagicMock() + mock_redis.get.return_value = b"ingest" + mock_get_redis.return_value = mock_redis + + @job_type_guard("ingest") + def my_func(): + return "success" + + assert my_func() == "success" + + @patch("pychunkedgraph.ingest.utils.get_redis_connection") + def test_different_job_type_calls_exit(self, mock_get_redis): + """When current job_type differs, exit(1) is called.""" + mock_redis = MagicMock() + mock_redis.get.return_value = b"upgrade" + mock_get_redis.return_value = mock_redis + + @job_type_guard("ingest") + def my_func(): + return "success" + + with pytest.raises(SystemExit) as exc_info: + my_func() + assert exc_info.value.code == 1 + + @patch("pychunkedgraph.ingest.utils.get_redis_connection") + def test_no_current_type_runs_normally(self, mock_get_redis): + """When no current job_type is set in redis, decorated function runs normally.""" + mock_redis = MagicMock() + mock_redis.get.return_value = None + mock_get_redis.return_value = mock_redis + + @job_type_guard("ingest") + def my_func(): + return "success" + + assert my_func() == "success" + + +# ===================================================================== +# Additional pure unit tests +# ===================================================================== +from pychunkedgraph.ingest.utils import start_ocdbt_server + + +class TestGetChunksNotDoneWithSplits: + """Test get_chunks_not_done with splits > 0.""" + + def _make_mock_imanager(self): + imanager = MagicMock() + imanager.redis = MagicMock() + return imanager + + def test_get_chunks_not_done_with_splits(self): + """When splits > 0, should expand coords with split indices.""" + imanager = self._make_mock_imanager() + coords = [[0, 0, 0], [1, 0, 0]] + splits = 2 + # With 2 coords and 2 splits, we get 4 entries: + # (0,0,0) split 0, (0,0,0) split 1, (1,0,0) split 0, (1,0,0) split 1 + # All completed + imanager.redis.smismember.return_value = [1, 1, 1, 1] + result = get_chunks_not_done(imanager, layer=2, coords=coords, splits=splits) + assert result == [] + + def test_get_chunks_not_done_with_splits_some_incomplete(self): + """When splits > 0 and some are not done, return the incomplete (coord, split) tuples.""" + imanager = self._make_mock_imanager() + coords = [[0, 0, 0], [1, 0, 0]] + splits = 2 + # 4 entries, only first is completed + imanager.redis.smismember.return_value = [1, 0, 1, 0] + result = get_chunks_not_done(imanager, layer=2, coords=coords, splits=splits) + # Should return the (coord, split) tuples that are not done + assert len(result) == 2 + assert result[0] == ([0, 0, 0], 1) + assert result[1] == ([1, 0, 0], 1) + + def test_get_chunks_not_done_splits_redis_error(self): + """When redis raises with splits > 0, should return split_coords as fallback.""" + imanager = self._make_mock_imanager() + coords = [[0, 0, 0]] + splits = 2 + imanager.redis.smismember.side_effect = Exception("Redis down") + result = get_chunks_not_done(imanager, layer=2, coords=coords, splits=splits) + # Should return all (coord, split) tuples + assert len(result) == 2 + assert result[0] == ([0, 0, 0], 0) + assert result[1] == ([0, 0, 0], 1) + + def test_get_chunks_not_done_splits_coord_str_format(self): + """With splits, redis keys should include the split index.""" + imanager = self._make_mock_imanager() + coords = [[2, 3, 4]] + splits = 1 + imanager.redis.smismember.return_value = [0] + get_chunks_not_done(imanager, layer=3, coords=coords, splits=splits) + # Check the coords_strs passed to smismember + call_args = imanager.redis.smismember.call_args + assert call_args[0][0] == "3c" + assert call_args[0][1] == ["2_3_4_0"] + + +class TestStartOcdbtServer: + """Test start_ocdbt_server function.""" + + @patch("pychunkedgraph.ingest.utils.ts") + @patch.dict("os.environ", {"MY_POD_IP": "10.0.0.1"}) + def test_start_ocdbt_server(self, mock_ts): + """start_ocdbt_server should open a KvStore and set redis keys.""" + imanager = MagicMock() + imanager.cg.meta.data_source.EDGES = "gs://bucket/edges" + mock_redis = MagicMock() + imanager.redis = mock_redis + + server = MagicMock() + server.port = 12345 + + mock_kv_future = MagicMock() + mock_ts.KvStore.open.return_value = mock_kv_future + + start_ocdbt_server(imanager, server) + + # Verify tensorstore was called with the right spec + call_args = mock_ts.KvStore.open.call_args[0][0] + assert call_args["driver"] == "ocdbt" + assert "gs://bucket/edges/ocdbt" in call_args["base"] + assert call_args["coordinator"]["address"] == "localhost:12345" + mock_kv_future.result.assert_called_once() + + # Verify redis keys were set + mock_redis.set.assert_any_call("OCDBT_COORDINATOR_PORT", "12345") + mock_redis.set.assert_any_call("OCDBT_COORDINATOR_HOST", "10.0.0.1") + + @patch("pychunkedgraph.ingest.utils.ts") + @patch.dict("os.environ", {}, clear=True) + def test_start_ocdbt_server_default_host(self, mock_ts): + """When MY_POD_IP is not set, should default to localhost.""" + imanager = MagicMock() + imanager.cg.meta.data_source.EDGES = "gs://bucket/edges" + mock_redis = MagicMock() + imanager.redis = mock_redis + + server = MagicMock() + server.port = 9999 + + mock_kv_future = MagicMock() + mock_ts.KvStore.open.return_value = mock_kv_future + + start_ocdbt_server(imanager, server) + + mock_redis.set.assert_any_call("OCDBT_COORDINATOR_HOST", "localhost") + + +class TestPostprocessEdgeDataNoneValues: + """Test postprocess_edge_data when edge_dict values are None.""" + + def test_postprocess_edge_data_none_values(self): + """When edge_dict[k] is None, the key should be in result with empty dict.""" + + class FakeMeta: + class data_source: + DATA_VERSION = 3 + + resolution = np.array([4, 4, 40]) + + class FakeIM: + cg_meta = FakeMeta() + + edge_dict = {"test_key": None} + result = postprocess_edge_data(FakeIM(), edge_dict) + assert "test_key" in result + assert result["test_key"] == {} + + def test_postprocess_edge_data_v4_none_values(self): + """Version 4 with None values should also produce empty dict.""" + + class FakeMeta: + class data_source: + DATA_VERSION = 4 + + resolution = np.array([4, 4, 40]) + + class FakeIM: + cg_meta = FakeMeta() + + edge_dict = { + "a": None, + "b": { + "sv1": np.array([1]), + "sv2": np.array([2]), + "aff_x": np.array([0.1]), + "aff_y": np.array([0.2]), + "aff_z": np.array([0.3]), + "area_x": np.array([10]), + "area_y": np.array([20]), + "area_z": np.array([30]), + }, + } + result = postprocess_edge_data(FakeIM(), edge_dict) + assert result["a"] == {} + assert "aff" in result["b"] + assert "area" in result["b"] diff --git a/pychunkedgraph/tests/test_io_components.py b/pychunkedgraph/tests/test_io_components.py new file mode 100644 index 000000000..63ac5abaa --- /dev/null +++ b/pychunkedgraph/tests/test_io_components.py @@ -0,0 +1,57 @@ +"""Tests for pychunkedgraph.io.components using file:// protocol""" + +import numpy as np +import pytest + +from pychunkedgraph.io.components import ( + serialize, + deserialize, + put_chunk_components, + get_chunk_components, +) +from pychunkedgraph.graph.utils import basetypes + + +class TestSerializeDeserialize: + def test_roundtrip(self): + components = [ + {np.uint64(1), np.uint64(2), np.uint64(3)}, + {np.uint64(4), np.uint64(5)}, + ] + proto = serialize(components) + result = deserialize(proto) + # Each supervoxel should map to its component index + assert result[np.uint64(1)] == result[np.uint64(2)] == result[np.uint64(3)] + assert result[np.uint64(4)] == result[np.uint64(5)] + assert result[np.uint64(1)] != result[np.uint64(4)] + + def test_empty_components(self): + # serialize([]) raises ValueError because np.concatenate + # is called on an empty list; this matches production behavior + # where empty components are never serialized + with pytest.raises(ValueError): + serialize([]) + + +class TestPutGetChunkComponents: + def test_roundtrip_via_filesystem(self, tmp_path): + components_dir = f"file://{tmp_path}" + chunk_coord = np.array([1, 2, 3]) + + components = [ + {np.uint64(10), np.uint64(20)}, + {np.uint64(30)}, + ] + put_chunk_components(components_dir, components, chunk_coord) + result = get_chunk_components(components_dir, chunk_coord) + + assert np.uint64(10) in result + assert np.uint64(20) in result + assert np.uint64(30) in result + assert result[np.uint64(10)] == result[np.uint64(20)] + assert result[np.uint64(10)] != result[np.uint64(30)] + + def test_missing_file_returns_empty(self, tmp_path): + components_dir = f"file://{tmp_path}" + result = get_chunk_components(components_dir, np.array([99, 99, 99])) + assert result == {} diff --git a/pychunkedgraph/tests/test_io_edges.py b/pychunkedgraph/tests/test_io_edges.py new file mode 100644 index 000000000..2111bbc6b --- /dev/null +++ b/pychunkedgraph/tests/test_io_edges.py @@ -0,0 +1,79 @@ +"""Tests for pychunkedgraph.io.edges using file:// protocol""" + +import numpy as np +import pytest + +from pychunkedgraph.io.edges import ( + serialize, + deserialize, + get_chunk_edges, + put_chunk_edges, + _parse_edges, +) +from pychunkedgraph.graph.edges import Edges, EDGE_TYPES +from pychunkedgraph.graph.utils import basetypes + + +class TestSerializeDeserialize: + def test_roundtrip(self): + ids1 = np.array([1, 2, 3], dtype=basetypes.NODE_ID) + ids2 = np.array([4, 5, 6], dtype=basetypes.NODE_ID) + affs = np.array([0.5, 0.6, 0.7], dtype=basetypes.EDGE_AFFINITY) + areas = np.array([10, 20, 30], dtype=basetypes.EDGE_AREA) + edges = Edges(ids1, ids2, affinities=affs, areas=areas) + + proto = serialize(edges) + result = deserialize(proto) + np.testing.assert_array_equal(result.node_ids1, ids1) + np.testing.assert_array_equal(result.node_ids2, ids2) + np.testing.assert_array_almost_equal(result.affinities, affs) + np.testing.assert_array_almost_equal(result.areas, areas) + + def test_empty_edges(self): + edges = Edges([], []) + proto = serialize(edges) + result = deserialize(proto) + assert len(result) == 0 + + +class TestParseEdges: + def test_empty_list(self): + result = _parse_edges([]) + assert result == [] + + +class TestPutGetChunkEdges: + def test_roundtrip_via_filesystem(self, tmp_path): + edges_dir = f"file://{tmp_path}" + chunk_coord = np.array([0, 0, 0]) + + edges_d = { + EDGE_TYPES.in_chunk: Edges( + [1, 2], + [3, 4], + affinities=[0.5, 0.6], + areas=[10, 20], + ), + EDGE_TYPES.between_chunk: Edges( + [5], + [6], + affinities=[0.7], + areas=[30], + ), + EDGE_TYPES.cross_chunk: Edges([], []), + } + + put_chunk_edges(edges_dir, chunk_coord, edges_d, compression_level=3) + result = get_chunk_edges(edges_dir, [chunk_coord]) + + assert EDGE_TYPES.in_chunk in result + assert EDGE_TYPES.between_chunk in result + assert EDGE_TYPES.cross_chunk in result + assert len(result[EDGE_TYPES.in_chunk]) == 2 + assert len(result[EDGE_TYPES.between_chunk]) == 1 + + def test_missing_file_returns_empty(self, tmp_path): + edges_dir = f"file://{tmp_path}" + result = get_chunk_edges(edges_dir, [np.array([99, 99, 99])]) + for edge_type in EDGE_TYPES: + assert len(result[edge_type]) == 0 diff --git a/pychunkedgraph/tests/test_lineage.py b/pychunkedgraph/tests/test_lineage.py new file mode 100644 index 000000000..118393e8e --- /dev/null +++ b/pychunkedgraph/tests/test_lineage.py @@ -0,0 +1,458 @@ +"""Tests for pychunkedgraph.graph.lineage""" + +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest +from networkx import DiGraph + +from pychunkedgraph.graph.lineage import ( + get_latest_root_id, + get_future_root_ids, + get_past_root_ids, + get_root_id_history, + lineage_graph, + get_previous_root_ids, + _get_node_properties, +) +from pychunkedgraph.graph import attributes + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestLineage: + def _build_and_merge(self, gen_graph): + """Build a graph with 2 isolated SVs, then merge them.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + + old_root_0 = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + old_root_1 = graph.get_root(to_label(graph, 1, 0, 0, 0, 1)) + + # Merge + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + new_root = result.new_root_ids[0] + return graph, old_root_0, old_root_1, new_root + + def test_get_latest_root_id_current(self, gen_graph): + graph, _, _, new_root = self._build_and_merge(gen_graph) + latest = get_latest_root_id(graph, new_root) + assert new_root in latest + + def test_get_latest_root_id_after_edit(self, gen_graph): + graph, old_root_0, _, new_root = self._build_and_merge(gen_graph) + latest = get_latest_root_id(graph, old_root_0) + assert new_root in latest + + def test_get_future_root_ids(self, gen_graph): + graph, old_root_0, _, new_root = self._build_and_merge(gen_graph) + future = get_future_root_ids(graph, old_root_0) + assert new_root in future + + def test_get_past_root_ids(self, gen_graph): + graph, old_root_0, old_root_1, new_root = self._build_and_merge(gen_graph) + past = get_past_root_ids(graph, new_root) + assert old_root_0 in past or old_root_1 in past + + def test_get_root_id_history(self, gen_graph): + graph, old_root_0, _, new_root = self._build_and_merge(gen_graph) + history = get_root_id_history(graph, old_root_0) + assert len(history) >= 2 + assert old_root_0 in history + assert new_root in history + + def test_lineage_graph(self, gen_graph): + """lineage_graph should return a DiGraph with nodes for old and new roots.""" + graph, old_root_0, old_root_1, new_root = self._build_and_merge(gen_graph) + lg = lineage_graph(graph, [new_root]) + assert isinstance(lg, DiGraph) + # The lineage graph should contain the new root + assert new_root in lg.nodes + # Should have at least 2 nodes (old root(s) + new root) + assert len(lg.nodes) >= 2 + # Should have edges connecting old roots to the new root + assert lg.number_of_edges() > 0 + + def test_lineage_graph_with_timestamps(self, gen_graph): + """lineage_graph should respect timestamp boundaries.""" + graph, old_root_0, _, new_root = self._build_and_merge(gen_graph) + # Build lineage graph with a past timestamp that includes the merge + past = datetime.now(UTC) - timedelta(days=20) + future = datetime.now(UTC) + timedelta(days=1) + lg = lineage_graph( + graph, [new_root], timestamp_past=past, timestamp_future=future + ) + assert isinstance(lg, DiGraph) + assert new_root in lg.nodes + + def test_lineage_graph_single_node_id(self, gen_graph): + """lineage_graph should accept a single integer node_id.""" + graph, _, _, new_root = self._build_and_merge(gen_graph) + lg = lineage_graph(graph, int(new_root)) + assert isinstance(lg, DiGraph) + assert new_root in lg.nodes + + def test_get_previous_root_ids(self, gen_graph): + """After a merge, get_previous_root_ids of the new root should include the old roots.""" + graph, old_root_0, old_root_1, new_root = self._build_and_merge(gen_graph) + result = get_previous_root_ids(graph, [new_root]) + assert isinstance(result, dict) + assert new_root in result + previous = result[new_root] + # The previous roots of the merged node should include the old roots + assert old_root_0 in previous or old_root_1 in previous + + def test_get_node_properties(self, gen_graph): + """_get_node_properties should extract timestamp and operation_id from a node entry.""" + graph, old_root_0, _, new_root = self._build_and_merge(gen_graph) + # Read the new root node with all properties + node_entry = graph.client.read_node(new_root) + assert node_entry is not None + + # _get_node_properties expects a dict with at least Hierarchy.Child + props = _get_node_properties(node_entry) + assert isinstance(props, dict) + # Should have a 'timestamp' key with a float value (epoch seconds) + assert "timestamp" in props + assert isinstance(props["timestamp"], float) + assert props["timestamp"] > 0 + + def test_get_node_properties_with_operation_id(self, gen_graph): + """Nodes created by edits should have an operation_id in their properties.""" + graph, old_root_0, _, new_root = self._build_and_merge(gen_graph) + # The old root should have NewParent and OperationID set after the merge + node_entry = graph.client.read_node(old_root_0) + props = _get_node_properties(node_entry) + assert "timestamp" in props + # Old roots involved in an edit should have operation_id + if attributes.OperationLogs.OperationID in node_entry: + assert "operation_id" in props + + +class TestGetFutureRootIdsLatest: + """Test get_future_root_ids with different time_stamp values.""" + + def _build_graph_with_two_merges(self, gen_graph): + """Build a graph with 3 SVs, do 2 merges: + First merge SV0+SV1 -> root_A + Then merge root_A+SV2 -> root_B + """ + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + from .helpers import create_chunk, to_label + + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + to_label(graph, 1, 0, 0, 0, 2), + ], + edges=[], + timestamp=fake_ts, + ) + + old_root_0 = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + old_root_1 = graph.get_root(to_label(graph, 1, 0, 0, 0, 1)) + old_root_2 = graph.get_root(to_label(graph, 1, 0, 0, 0, 2)) + + # First merge: SV0 + SV1 + result1 = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + mid_root = result1.new_root_ids[0] + + # Second merge: merged root + SV2 + result2 = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 2)], + affinities=[0.3], + ) + final_root = result2.new_root_ids[0] + + return graph, old_root_0, old_root_1, old_root_2, mid_root, final_root + + def test_future_root_ids_finds_chain(self, gen_graph): + """get_future_root_ids from original root should find mid and final roots.""" + graph, old_root_0, _, _, mid_root, final_root = ( + self._build_graph_with_two_merges(gen_graph) + ) + future = get_future_root_ids(graph, old_root_0) + # Should find at least the mid root and final root + assert len(future) >= 1 + # The final root should be reachable + assert mid_root in future or final_root in future + + def test_future_root_ids_with_past_timestamp(self, gen_graph): + """Using a very old timestamp should find nothing (no future roots before that time).""" + graph, old_root_0, _, _, _, _ = self._build_graph_with_two_merges(gen_graph) + very_old = datetime.now(UTC) - timedelta(days=20) + future = get_future_root_ids(graph, old_root_0, time_stamp=very_old) + # With a very old timestamp, no future roots should be found since + # all edits happened after that time + assert len(future) == 0 + + def test_future_root_ids_current_root_returns_empty(self, gen_graph): + """For the latest root, get_future_root_ids should return empty.""" + graph, _, _, _, _, final_root = self._build_graph_with_two_merges(gen_graph) + future = get_future_root_ids(graph, final_root) + assert len(future) == 0 + + +class TestGetPastRootIdsTimestamps: + """Test get_past_root_ids with different time_stamp values.""" + + def _build_and_merge(self, gen_graph): + """Build a graph with 2 isolated SVs, then merge them.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + from .helpers import create_chunk, to_label + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + + old_root_0 = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + old_root_1 = graph.get_root(to_label(graph, 1, 0, 0, 0, 1)) + + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + new_root = result.new_root_ids[0] + return graph, old_root_0, old_root_1, new_root + + def test_past_root_ids_of_merged_root(self, gen_graph): + """get_past_root_ids of the merged root should find old roots.""" + graph, old_root_0, old_root_1, new_root = self._build_and_merge(gen_graph) + past = get_past_root_ids(graph, new_root) + assert old_root_0 in past or old_root_1 in past + + def test_past_root_ids_with_future_timestamp(self, gen_graph): + """Using a far-future timestamp should find nothing (no past roots after that time).""" + graph, _, _, new_root = self._build_and_merge(gen_graph) + far_future = datetime.now(UTC) + timedelta(days=365) + past = get_past_root_ids(graph, new_root, time_stamp=far_future) + # With a far-future timestamp, the condition row_time_stamp > time_stamp + # will be False, so no past roots should be found + assert len(past) == 0 + + def test_past_root_ids_original_root_empty(self, gen_graph): + """An original root with no prior edits should have no past root ids.""" + graph, old_root_0, _, _ = self._build_and_merge(gen_graph) + past = get_past_root_ids(graph, old_root_0) + # The original root has no former parents, so past should be empty + assert len(past) == 0 + + +class TestGetRootIdHistory: + """Test get_root_id_history returns full history.""" + + def _build_and_merge(self, gen_graph): + """Build a graph with 2 isolated SVs, then merge them.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + from .helpers import create_chunk, to_label + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + + old_root_0 = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + old_root_1 = graph.get_root(to_label(graph, 1, 0, 0, 0, 1)) + + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + new_root = result.new_root_ids[0] + return graph, old_root_0, old_root_1, new_root + + def test_history_after_merge(self, gen_graph): + """After merge, get_root_id_history should contain past and current root.""" + graph, old_root_0, _, new_root = self._build_and_merge(gen_graph) + history = get_root_id_history(graph, old_root_0) + assert isinstance(history, np.ndarray) + # Should contain the queried root itself + assert old_root_0 in history + # Should contain the new root + assert new_root in history + assert len(history) >= 2 + + def test_history_from_new_root(self, gen_graph): + """get_root_id_history from the new root should include old roots.""" + graph, old_root_0, old_root_1, new_root = self._build_and_merge(gen_graph) + history = get_root_id_history(graph, new_root) + assert isinstance(history, np.ndarray) + assert new_root in history + # At least one old root should appear in the history + assert old_root_0 in history or old_root_1 in history + + def test_history_with_timestamps(self, gen_graph): + """get_root_id_history with restrictive timestamps may limit results.""" + graph, old_root_0, _, new_root = self._build_and_merge(gen_graph) + # Very narrow time window: only current root + far_future = datetime.now(UTC) + timedelta(days=365) + very_old = datetime.now(UTC) - timedelta(days=365) + history = get_root_id_history( + graph, + new_root, + time_stamp_past=far_future, + time_stamp_future=very_old, + ) + assert isinstance(history, np.ndarray) + # At minimum, the queried root itself should be in the history + assert new_root in history + + +class TestGetRootIdHistoryDetailed: + """Detailed tests for get_root_id_history covering all branches.""" + + def _build_graph_with_two_merges(self, gen_graph): + """Build a graph with 3 SVs, do 2 merges: + First merge SV0+SV1 -> root_A + Then merge root_A+SV2 -> root_B + """ + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[ + to_label(graph, 1, 0, 0, 0, 0), + to_label(graph, 1, 0, 0, 0, 1), + to_label(graph, 1, 0, 0, 0, 2), + ], + edges=[], + timestamp=fake_ts, + ) + + old_root_0 = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + old_root_1 = graph.get_root(to_label(graph, 1, 0, 0, 0, 1)) + old_root_2 = graph.get_root(to_label(graph, 1, 0, 0, 0, 2)) + + # First merge: SV0 + SV1 + result1 = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + mid_root = result1.new_root_ids[0] + + # Second merge: merged root + SV2 + result2 = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 2)], + affinities=[0.3], + ) + final_root = result2.new_root_ids[0] + + return graph, old_root_0, old_root_1, old_root_2, mid_root, final_root + + def test_history_contains_all_roots_from_old(self, gen_graph): + """get_root_id_history from original root should contain all related roots.""" + graph, old_root_0, _, _, mid_root, final_root = ( + self._build_graph_with_two_merges(gen_graph) + ) + history = get_root_id_history(graph, old_root_0) + assert isinstance(history, np.ndarray) + # Should contain the queried root itself + assert old_root_0 in history + # Should contain mid_root (first merge) + assert mid_root in history + # Should contain final_root (second merge) + assert final_root in history + + def test_history_from_mid_root(self, gen_graph): + """get_root_id_history from mid root should include both past and future.""" + graph, old_root_0, old_root_1, _, mid_root, final_root = ( + self._build_graph_with_two_merges(gen_graph) + ) + history = get_root_id_history(graph, mid_root) + assert isinstance(history, np.ndarray) + assert mid_root in history + # Should include past roots + assert old_root_0 in history or old_root_1 in history + # Should include future root + assert final_root in history + + def test_history_from_final_root(self, gen_graph): + """get_root_id_history from final root should include all past roots.""" + graph, old_root_0, old_root_1, old_root_2, mid_root, final_root = ( + self._build_graph_with_two_merges(gen_graph) + ) + history = get_root_id_history(graph, final_root) + assert isinstance(history, np.ndarray) + assert final_root in history + # Should include the mid root + assert mid_root in history + # Should include at least one of the original roots + assert old_root_0 in history or old_root_1 in history or old_root_2 in history + + def test_history_with_narrow_past_timestamp(self, gen_graph): + """get_root_id_history with a very recent past timestamp excludes old roots.""" + graph, old_root_0, _, _, mid_root, final_root = ( + self._build_graph_with_two_merges(gen_graph) + ) + # Use a very recent past timestamp to exclude past roots + recent = datetime.now(UTC) + timedelta(days=365) + history = get_root_id_history( + graph, + mid_root, + time_stamp_past=recent, + ) + assert isinstance(history, np.ndarray) + # Should contain the root itself + assert mid_root in history + # Should still contain future roots (timestamp_future defaults to max) + assert final_root in history + + def test_history_with_narrow_future_timestamp(self, gen_graph): + """get_root_id_history with a very old future timestamp excludes future roots.""" + graph, old_root_0, _, _, mid_root, final_root = ( + self._build_graph_with_two_merges(gen_graph) + ) + # Use a very old future timestamp to exclude future roots + very_old = datetime.now(UTC) - timedelta(days=365) + history = get_root_id_history( + graph, + mid_root, + time_stamp_future=very_old, + ) + assert isinstance(history, np.ndarray) + # Should contain the root itself + assert mid_root in history + # Should contain past roots (timestamp_past defaults to min) + assert old_root_0 in history diff --git a/pychunkedgraph/tests/test_locks.py b/pychunkedgraph/tests/test_locks.py new file mode 100644 index 000000000..a0f7161cd --- /dev/null +++ b/pychunkedgraph/tests/test_locks.py @@ -0,0 +1,752 @@ +from time import sleep +from datetime import datetime, timedelta, UTC + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph.lineage import get_future_root_ids +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestGraphLocks: + @pytest.mark.timeout(30) + def test_lock_unlock(self, gen_graph): + """ + No connection between 1, 2 and 3 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 3 │ + │ 2 │ │ + └─────┴─────┘ + + (1) Try lock (opid = 1) + (2) Try lock (opid = 2) + (3) Try unlock (opid = 1) + (4) Try lock (opid = 2) + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + operation_id_2 = cg.id_client.create_operation_id() + assert not cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + assert cg.client.unlock_root(root_id=root_id, operation_id=operation_id_1) + + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + @pytest.mark.timeout(30) + def test_lock_expiration(self, gen_graph): + """ + No connection between 1, 2 and 3 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 3 │ + │ 2 │ │ + └─────┴─────┘ + + (1) Try lock (opid = 1) + (2) Try lock (opid = 2) + (3) Try lock (opid = 2) with retries + """ + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + operation_id_2 = cg.id_client.create_operation_id() + assert not cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + sleep(cg.meta.graph_config.ROOT_LOCK_EXPIRY.total_seconds()) + + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + max_tries=10, + waittime_s=0.5, + )[0] + + @pytest.mark.timeout(30) + def test_lock_renew(self, gen_graph): + """ + No connection between 1, 2 and 3 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 3 │ + │ 2 │ │ + └─────┴─────┘ + + (1) Try lock (opid = 1) + (2) Try lock (opid = 2) + (3) Try lock (opid = 2) with retries + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + assert cg.client.renew_locks(root_ids=[root_id], operation_id=operation_id_1) + + @pytest.mark.timeout(30) + def test_lock_merge_lock_old_id(self, gen_graph): + """ + No connection between 1, 2 and 3 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 3 │ + │ 2 │ │ + └─────┴─────┘ + + (1) Merge (includes lock opid 1) + (2) Try lock opid 2 --> should be successful and return new root id + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + + new_root_ids = cg.add_edges( + "Chuck Norris", + [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + affinities=1.0, + ).new_root_ids + + assert new_root_ids is not None + + operation_id_2 = cg.id_client.create_operation_id() + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + success, new_root_id = cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + max_tries=10, + waittime_s=0.5, + ) + + assert success + assert new_root_ids[0] == new_root_id + + @pytest.mark.timeout(30) + def test_indefinite_lock(self, gen_graph): + """ + No connection between 1, 2 and 3 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 3 │ + │ 2 │ │ + └─────┴─────┘ + + (1) Try indefinite lock (opid = 1), get indefinite lock + (2) Try normal lock (opid = 2), doesn't get the normal lock + (3) Try unlock indefinite lock (opid = 1), should unlock indefinite lock + (4) Try lock (opid = 2), should get the normal lock + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + assert cg.client.lock_roots_indefinitely( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + operation_id_2 = cg.id_client.create_operation_id() + assert not cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + assert cg.client.unlock_indefinitely_locked_root( + root_id=root_id, operation_id=operation_id_1 + ) + + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + @pytest.mark.timeout(30) + def test_indefinite_lock_with_normal_lock_expiration(self, gen_graph): + """ + No connection between 1, 2 and 3 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 3 │ + │ 2 │ │ + └─────┴─────┘ + + (1) Try normal lock (opid = 1), get normal lock + (2) Try indefinite lock (opid = 1), get indefinite lock + (3) Wait until normal lock expires + (4) Try normal lock (opid = 2), doesn't get the normal lock + (5) Try unlock indefinite lock (opid = 1), should unlock indefinite lock + (6) Try lock (opid = 2), should get the normal lock + """ + + # 1. TODO renew lock test when getting indefinite lock + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + operation_id_1 = cg.id_client.create_operation_id() + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) + + future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} + + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + assert cg.client.lock_roots_indefinitely( + root_ids=[root_id], + operation_id=operation_id_1, + future_root_ids_d=future_root_ids_d, + )[0] + + sleep(cg.meta.graph_config.ROOT_LOCK_EXPIRY.total_seconds()) + + operation_id_2 = cg.id_client.create_operation_id() + assert not cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + assert cg.client.unlock_indefinitely_locked_root( + root_id=root_id, operation_id=operation_id_1 + ) + + assert cg.client.lock_roots( + root_ids=[root_id], + operation_id=operation_id_2, + future_root_ids_d=future_root_ids_d, + )[0] + + +# ===================================================================== +# Pure unit tests (no BigTable emulator needed) +# ===================================================================== +from unittest.mock import MagicMock, patch +from collections import defaultdict +import networkx as nx + +from ..graph.locks import RootLock, IndefiniteRootLock +from ..graph.exceptions import LockingError + + +def _make_mock_cg(): + """Create a mock ChunkedGraph object with the methods needed by locks.""" + cg = MagicMock() + cg.id_client.create_operation_id.return_value = np.uint64(42) + cg.client.lock_roots.return_value = (True, [np.uint64(100)]) + cg.client.unlock_root.return_value = None + cg.client.renew_locks.return_value = True + cg.client.lock_roots_indefinitely.return_value = ( + True, + [np.uint64(100)], + [], + ) + cg.client.unlock_indefinitely_locked_root.return_value = None + cg.get_node_timestamps.return_value = [MagicMock()] + return cg + + +class TestRootLockPrivilegedMode: + def test_rootlock_privileged_mode(self): + """privileged_mode=True should skip locking entirely and return self.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + op_id = np.uint64(999) + + lock = RootLock(cg, root_ids, operation_id=op_id, privileged_mode=True) + result = lock.__enter__() + + assert result is lock + assert lock.lock_acquired is False + cg.client.lock_roots.assert_not_called() + + def test_rootlock_privileged_mode_exit_no_unlock(self): + """When privileged and lock was never acquired, __exit__ should not unlock.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + op_id = np.uint64(999) + + lock = RootLock(cg, root_ids, operation_id=op_id, privileged_mode=True) + lock.__enter__() + lock.__exit__(None, None, None) + + cg.client.unlock_root.assert_not_called() + + +class TestRootLockCreatesOperationId: + def test_rootlock_creates_operation_id(self): + """When operation_id is None, __enter__ should create one via cg.id_client.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + + mock_graph = nx.DiGraph() + mock_graph.add_node(np.uint64(100)) + + with patch("pychunkedgraph.graph.locks.lineage_graph", return_value=mock_graph): + lock = RootLock(cg, root_ids, operation_id=None) + lock.__enter__() + + cg.id_client.create_operation_id.assert_called_once() + assert lock.operation_id == np.uint64(42) + + +class TestRootLockAcquired: + def test_rootlock_lock_acquired(self): + """When lock_roots returns (True, [...]), lock_acquired should be True.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + locked = [np.uint64(100), np.uint64(101)] + cg.client.lock_roots.return_value = (True, locked) + + mock_graph = nx.DiGraph() + mock_graph.add_node(np.uint64(100)) + + with patch("pychunkedgraph.graph.locks.lineage_graph", return_value=mock_graph): + lock = RootLock(cg, root_ids, operation_id=np.uint64(10)) + result = lock.__enter__() + + assert lock.lock_acquired is True + assert lock.locked_root_ids == locked + assert result is lock + + +class TestRootLockFailed: + def test_rootlock_lock_failed(self): + """When lock_roots returns (False, []), should raise LockingError.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + cg.client.lock_roots.return_value = (False, []) + + mock_graph = nx.DiGraph() + mock_graph.add_node(np.uint64(100)) + + with patch("pychunkedgraph.graph.locks.lineage_graph", return_value=mock_graph): + lock = RootLock(cg, root_ids, operation_id=np.uint64(10)) + with pytest.raises(LockingError, match="Could not acquire root lock"): + lock.__enter__() + + +class TestRootLockExitUnlocks: + def test_rootlock_exit_unlocks(self): + """When lock_acquired=True, __exit__ should call unlock_root for each locked_root_id.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + locked = [np.uint64(100), np.uint64(101)] + cg.client.lock_roots.return_value = (True, locked) + + mock_graph = nx.DiGraph() + mock_graph.add_node(np.uint64(100)) + + with patch("pychunkedgraph.graph.locks.lineage_graph", return_value=mock_graph): + lock = RootLock(cg, root_ids, operation_id=np.uint64(10)) + lock.__enter__() + + lock.__exit__(None, None, None) + + assert cg.client.unlock_root.call_count == 2 + actual_calls = cg.client.unlock_root.call_args_list + called_root_ids = {c[0][0] for c in actual_calls} + assert called_root_ids == {np.uint64(100), np.uint64(101)} + for c in actual_calls: + assert c[0][1] == np.uint64(10) + + def test_rootlock_exit_no_unlock_when_not_acquired(self): + """When lock_acquired=False, __exit__ should not call unlock_root.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + + lock = RootLock(cg, root_ids, operation_id=np.uint64(10)) + lock.__exit__(None, None, None) + + cg.client.unlock_root.assert_not_called() + + def test_rootlock_exit_handles_unlock_exception(self): + """When unlock_root raises, __exit__ should log warning and not re-raise.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + locked = [np.uint64(100)] + cg.client.lock_roots.return_value = (True, locked) + cg.client.unlock_root.side_effect = RuntimeError("unlock failed") + + mock_graph = nx.DiGraph() + mock_graph.add_node(np.uint64(100)) + + with patch("pychunkedgraph.graph.locks.lineage_graph", return_value=mock_graph): + lock = RootLock(cg, root_ids, operation_id=np.uint64(10)) + lock.__enter__() + + # Should not raise even though unlock_root raises + lock.__exit__(None, None, None) + + +class TestIndefiniteRootLockPrivilegedMode: + def test_indefiniterootlock_privileged_mode(self): + """privileged_mode=True should skip locking and return self.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + op_id = np.uint64(999) + + lock = IndefiniteRootLock(cg, op_id, root_ids, privileged_mode=True) + result = lock.__enter__() + + assert result is lock + assert lock.acquired is False + cg.client.renew_locks.assert_not_called() + cg.client.lock_roots_indefinitely.assert_not_called() + + +class TestIndefiniteRootLockRenewFails: + def test_indefiniterootlock_renew_fails(self): + """When renew_locks returns False, should raise LockingError.""" + cg = _make_mock_cg() + cg.client.renew_locks.return_value = False + root_ids = np.array([np.uint64(100)]) + op_id = np.uint64(10) + + lock = IndefiniteRootLock( + cg, op_id, root_ids, future_root_ids_d=defaultdict(list) + ) + with pytest.raises(LockingError, match="Could not renew locks"): + lock.__enter__() + + +class TestIndefiniteRootLockSuccess: + def test_indefiniterootlock_lock_success(self): + """When lock_roots_indefinitely returns (True, [...], []), acquired should be True.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + locked = [np.uint64(100)] + cg.client.lock_roots_indefinitely.return_value = (True, locked, []) + + lock = IndefiniteRootLock( + cg, + np.uint64(10), + root_ids, + future_root_ids_d=defaultdict(list), + ) + result = lock.__enter__() + + assert lock.acquired is True + assert result is lock + assert list(lock.root_ids) == locked + + +class TestIndefiniteRootLockFail: + def test_indefiniterootlock_lock_fail(self): + """When lock_roots_indefinitely returns (False, [], [...]), should raise LockingError.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + failed = [np.uint64(100)] + cg.client.lock_roots_indefinitely.return_value = (False, [], failed) + + lock = IndefiniteRootLock( + cg, + np.uint64(10), + root_ids, + future_root_ids_d=defaultdict(list), + ) + with pytest.raises(LockingError, match="have been locked indefinitely"): + lock.__enter__() + + +class TestIndefiniteRootLockExitUnlocks: + def test_indefiniterootlock_exit_unlocks(self): + """When acquired=True, __exit__ should call unlock_indefinitely_locked_root.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100), np.uint64(101)]) + cg.client.lock_roots_indefinitely.return_value = ( + True, + [np.uint64(100), np.uint64(101)], + [], + ) + + lock = IndefiniteRootLock( + cg, + np.uint64(10), + root_ids, + future_root_ids_d=defaultdict(list), + ) + lock.__enter__() + lock.__exit__(None, None, None) + + assert cg.client.unlock_indefinitely_locked_root.call_count == 2 + actual_calls = cg.client.unlock_indefinitely_locked_root.call_args_list + called_root_ids = {c[0][0] for c in actual_calls} + assert called_root_ids == {np.uint64(100), np.uint64(101)} + for c in actual_calls: + assert c[0][1] == np.uint64(10) + + def test_indefiniterootlock_exit_no_unlock_when_not_acquired(self): + """When acquired=False, __exit__ should not unlock.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + lock = IndefiniteRootLock(cg, np.uint64(10), root_ids) + lock.__exit__(None, None, None) + cg.client.unlock_indefinitely_locked_root.assert_not_called() + + def test_indefiniterootlock_exit_handles_exception(self): + """When unlock_indefinitely_locked_root raises, should not re-raise.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + cg.client.lock_roots_indefinitely.return_value = ( + True, + [np.uint64(100)], + [], + ) + cg.client.unlock_indefinitely_locked_root.side_effect = RuntimeError("fail") + + lock = IndefiniteRootLock( + cg, + np.uint64(10), + root_ids, + future_root_ids_d=defaultdict(list), + ) + lock.__enter__() + # Should not raise + lock.__exit__(None, None, None) + + +class TestIndefiniteRootLockComputesFutureRootIds: + def test_indefiniterootlock_computes_future_root_ids(self): + """When future_root_ids_d is None, should compute from lineage_graph.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + cg.client.lock_roots_indefinitely.return_value = ( + True, + [np.uint64(100)], + [], + ) + + mock_lgraph = nx.DiGraph() + mock_lgraph.add_edge(np.uint64(100), np.uint64(200)) + mock_lgraph.add_edge(np.uint64(100), np.uint64(300)) + + with patch( + "pychunkedgraph.graph.locks.lineage_graph", return_value=mock_lgraph + ): + lock = IndefiniteRootLock( + cg, + np.uint64(10), + root_ids, + future_root_ids_d=None, + ) + lock.__enter__() + + assert lock.future_root_ids_d is not None + descendants = lock.future_root_ids_d[np.uint64(100)] + assert set(descendants) == {np.uint64(200), np.uint64(300)} + + +class TestRootLockContextManager: + def test_rootlock_as_context_manager(self): + """Test using RootLock with the `with` statement.""" + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100)]) + locked = [np.uint64(100)] + cg.client.lock_roots.return_value = (True, locked) + + mock_graph = nx.DiGraph() + mock_graph.add_node(np.uint64(100)) + + with patch("pychunkedgraph.graph.locks.lineage_graph", return_value=mock_graph): + with RootLock(cg, root_ids, operation_id=np.uint64(10)) as lock: + assert lock.lock_acquired is True + + cg.client.unlock_root.assert_called_once() diff --git a/pychunkedgraph/tests/test_merge.py b/pychunkedgraph/tests/test_merge.py new file mode 100644 index 000000000..9c6a3148c --- /dev/null +++ b/pychunkedgraph/tests/test_merge.py @@ -0,0 +1,710 @@ +from datetime import datetime, timedelta, UTC +from math import inf +from warnings import warn + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph import ChunkedGraph +from ..graph.utils.serializers import serialize_uint64 +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestGraphMerge: + @pytest.mark.timeout(30) + def test_merge_pair_same_chunk(self, gen_graph): + """ + Add edge between existing RG supervoxels 1 and 2 (same chunk) + Expected: Same (new) parent for RG 1 and 2 on Layer two + ┌─────┐ ┌─────┐ + │ A¹ │ │ A¹ │ + │ 1 2 │ => │ 1━2 │ + │ │ │ │ + └─────┘ └─────┘ + """ + + atomic_chunk_bounds = np.array([1, 1, 1]) + cg = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + + # Merge + new_root_ids = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 0)], + affinities=[0.3], + ).new_root_ids + + assert len(new_root_ids) == 1 + new_root_id = new_root_ids[0] + + # Check + assert cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id + assert cg.get_parent(to_label(cg, 1, 0, 0, 0, 1)) == new_root_id + leaves = np.unique(cg.get_subgraph([new_root_id], leaves_only=True)) + assert len(leaves) == 2 + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 0, 0, 0, 1) in leaves + + @pytest.mark.timeout(30) + def test_merge_pair_neighboring_chunks(self, gen_graph): + """ + Add edge between existing RG supervoxels 1 and 2 (neighboring chunks) + ┌─────┬─────┐ ┌─────┬─────┐ + │ A¹ │ B¹ │ │ A¹ │ B¹ │ + │ 1 │ 2 │ => │ 1━━┿━━2 │ + │ │ │ │ │ │ + └─────┴─────┘ └─────┴─────┘ + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + # Merge + new_root_ids = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + affinities=0.3, + ).new_root_ids + + assert len(new_root_ids) == 1 + new_root_id = new_root_ids[0] + + # Check + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id + assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == new_root_id + leaves = np.unique(cg.get_subgraph([new_root_id], leaves_only=True)) + assert len(leaves) == 2 + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 1, 0, 0, 0) in leaves + + @pytest.mark.timeout(120) + def test_merge_pair_disconnected_chunks(self, gen_graph): + """ + Add edge between existing RG supervoxels 1 and 2 (disconnected chunks) + ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ + │ A¹ │ ... │ Z¹ │ │ A¹ │ ... │ Z¹ │ + │ 1 │ │ 2 │ => │ 1━━┿━━━━━┿━━2 │ + │ │ │ │ │ │ │ │ + └─────┘ └─────┘ └─────┘ └─────┘ + """ + + cg = gen_graph(n_layers=5) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk Z + create_chunk( + cg, + vertices=[to_label(cg, 1, 7, 7, 7, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1, + ) + add_parent_chunk( + cg, 3, [3, 3, 3], time_stamp=fake_timestamp, n_threads=1, + ) + add_parent_chunk( + cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1, + ) + add_parent_chunk( + cg, 5, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1, + ) + + # Merge + result = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 7, 7, 7, 0), to_label(cg, 1, 0, 0, 0, 0)], + affinities=[0.3], + ) + new_root_ids, lvl2_node_ids = result.new_root_ids, result.new_lvl2_ids + + u_layers = np.unique(cg.get_chunk_layers(lvl2_node_ids)) + assert len(u_layers) == 1 + assert u_layers[0] == 2 + + assert len(new_root_ids) == 1 + new_root_id = new_root_ids[0] + + # Check + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id + assert cg.get_root(to_label(cg, 1, 7, 7, 7, 0)) == new_root_id + leaves = np.unique(cg.get_subgraph(new_root_id, leaves_only=True)) + assert len(leaves) == 2 + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 7, 7, 7, 0) in leaves + + @pytest.mark.timeout(30) + def test_merge_pair_already_connected(self, gen_graph): + """ + Add edge between already connected RG supervoxels 1 and 2 (same chunk). + Expected: No change + ┌─────┐ ┌─────┐ + │ A¹ │ │ A¹ │ + │ 1━2 │ => │ 1━2 │ + │ │ │ │ + └─────┘ └─────┘ + """ + + cg = gen_graph(n_layers=2) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], + timestamp=fake_timestamp, + ) + + res_old = cg.client._table.read_rows() + res_old.consume_all() + + # Merge + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 0)], + ) + res_new = cg.client._table.read_rows() + res_new.consume_all() + res_new.rows.pop(b'ioperations', None) + res_new.rows.pop(b'00000000000000000001', None) + + # Check + if res_old.rows != res_new.rows: + warn( + "Rows were modified when merging a pair of already connected supervoxels. " + "While probably not an error, it is an unnecessary operation." + ) + + @pytest.mark.timeout(30) + def test_merge_triple_chain_to_full_circle_same_chunk(self, gen_graph): + """ + Add edge between indirectly connected RG supervoxels 1 and 2 (same chunk) + ┌─────┐ ┌─────┐ + │ A¹ │ │ A¹ │ + │ 1 2 │ => │ 1━2 │ + │ ┗3┛ │ │ ┗3┛ │ + └─────┘ └─────┘ + """ + + cg = gen_graph(n_layers=2) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[ + to_label(cg, 1, 0, 0, 0, 0), + to_label(cg, 1, 0, 0, 0, 1), + to_label(cg, 1, 0, 0, 0, 2), + ], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 2), 0.5), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2), 0.5), + ], + timestamp=fake_timestamp, + ) + + # Merge + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 0)], + affinities=0.3, + ).new_root_ids + + @pytest.mark.timeout(30) + def test_merge_triple_chain_to_full_circle_neighboring_chunks(self, gen_graph): + """ + Add edge between indirectly connected RG supervoxels 1 and 2 (neighboring chunks) + ┌─────┬─────┐ ┌─────┬─────┐ + │ A¹ │ B¹ │ │ A¹ │ B¹ │ + │ 1 │ 2 │ => │ 1━━┿━━2 │ + │ ┗3━┿━━┛ │ │ ┗3━┿━━┛ │ + └─────┴─────┘ └─────┴─────┘ + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), inf)], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1, + ) + + # Merge + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + affinities=1.0, + ).new_root_ids + + @pytest.mark.timeout(120) + def test_merge_triple_chain_to_full_circle_disconnected_chunks(self, gen_graph): + """ + Add edge between indirectly connected RG supervoxels 1 and 2 (disconnected chunks) + ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ + │ A¹ │ ... │ Z¹ │ │ A¹ │ ... │ Z¹ │ + │ 1 │ │ 2 │ => │ 1━━┿━━━━━┿━━2 │ + │ ┗3━┿━━━━━┿━━┛ │ │ ┗3━┿━━━━━┿━━┛ │ + └─────┘ └─────┘ └─────┘ └─────┘ + """ + + cg = gen_graph(n_layers=5) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 7, 7, 7, 0), inf), + ], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 7, 7, 7, 0)], + edges=[ + (to_label(cg, 1, 7, 7, 7, 0), to_label(cg, 1, 0, 0, 0, 1), inf) + ], + timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 3, [3, 3, 3], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [1, 1, 1], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 5, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # Merge + new_root_ids = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 7, 7, 7, 0), to_label(cg, 1, 0, 0, 0, 0)], + affinities=1.0, + ).new_root_ids + + assert len(new_root_ids) == 1 + new_root_id = new_root_ids[0] + + # Check + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_id + assert cg.get_root(to_label(cg, 1, 7, 7, 7, 0)) == new_root_id + leaves = np.unique(cg.get_subgraph(new_root_id, leaves_only=True)) + assert len(leaves) == 3 + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 0, 0, 0, 1) in leaves + assert to_label(cg, 1, 7, 7, 7, 0) in leaves + + @pytest.mark.timeout(30) + def test_merge_same_node(self, gen_graph): + """ + Try to add loop edge between RG supervoxel 1 and itself + ┌─────┐ + │ A¹ │ + │ 1 │ => Reject + │ │ + └─────┘ + """ + + cg = gen_graph(n_layers=2) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + res_old = cg.client._table.read_rows() + res_old.consume_all() + + # Merge + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], + ) + + res_new = cg.client._table.read_rows() + res_new.consume_all() + + assert res_new.rows == res_old.rows + + @pytest.mark.timeout(30) + def test_merge_pair_abstract_nodes(self, gen_graph): + """ + Try to add edge between RG supervoxel 1 and abstract node "2" + => Reject + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, vertices=[to_label(cg, 1, 0, 0, 0, 0)], edges=[], timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, vertices=[to_label(cg, 1, 1, 0, 0, 0)], edges=[], timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + res_old = cg.client._table.read_rows() + res_old.consume_all() + + # Merge + with pytest.raises(Exception): + cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 2, 1, 0, 0, 1)], + ) + + res_new = cg.client._table.read_rows() + res_new.consume_all() + + assert res_new.rows == res_old.rows + + @pytest.mark.timeout(30) + def test_diagonal_connections(self, gen_graph): + """ + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 2 1━┿━━3 │ + │ / │ │ + ┌─────┬─────┐ + │ | │ │ + │ 4━━┿━━5 │ + │ C¹ │ D¹ │ + └─────┴─────┘ + """ + + cg = gen_graph(n_layers=3) + + # Chunk A + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf), + ], + ) + + # Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], + ) + + # Chunk C + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 1, 0, 0)], + edges=[ + (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 1, 1, 0, 0), inf), + (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf), + ], + ) + + # Chunk D + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 1, 0, 0)], + edges=[(to_label(cg, 1, 1, 1, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf)], + ) + + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + + rr = cg.range_read_chunk(chunk_id=cg.get_chunk_id(layer=3, x=0, y=0, z=0)) + root_ids_t0 = list(rr.keys()) + + assert len(root_ids_t0) == 2 + + child_ids = [] + for root_id in root_ids_t0: + child_ids.extend(cg.get_subgraph(root_id, leaves_only=True)) + + new_roots = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + affinities=[0.5], + ).new_root_ids + + root_ids = [] + for child_id in child_ids: + root_ids.append(cg.get_root(child_id)) + + assert len(np.unique(root_ids)) == 1 + + root_id = root_ids[0] + assert root_id == new_roots[0] + + @pytest.mark.timeout(240) + def test_cross_edges(self, gen_graph): + cg = gen_graph(n_layers=5) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 0), inf), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), inf), + ], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), inf), + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1), inf), + ], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk C + create_chunk( + cg, vertices=[to_label(cg, 1, 2, 0, 0, 0)], timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 3, [1, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 5, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + new_roots = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0)], + affinities=0.9, + ).new_root_ids + + assert len(new_roots) == 1 + + +class TestGraphMergeSkipConnections: + """Tests for skip connection behavior during merge operations.""" + + @pytest.mark.timeout(120) + def test_merge_creates_skip_connection(self, gen_graph): + """ + Merge two isolated nodes in a 5-layer graph. After merge, each + component that has no sibling at its layer should get a skip-connection + parent at a higher layer. + + ┌─────┐ ┌─────┐ + │ A¹ │ │ Z¹ │ + │ 1 │ │ 2 │ + └─────┘ └─────┘ + After merge: 1 and 2 are connected, hierarchy should skip + intermediate empty layers. + """ + cg = gen_graph(n_layers=5) + + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 7, 7, 7, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 3, [3, 3, 3], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 5, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # Before merge: verify both nodes have root at layer 5 + root1_pre = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + root2_pre = cg.get_root(to_label(cg, 1, 7, 7, 7, 0)) + assert root1_pre != root2_pre + assert cg.get_chunk_layer(root1_pre) == 5 + assert cg.get_chunk_layer(root2_pre) == 5 + + # Merge + result = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 7, 7, 7, 0)], + affinities=[0.5], + ) + new_root_ids = result.new_root_ids + assert len(new_root_ids) == 1 + + # After merge: single root, both supervoxels reachable + new_root = new_root_ids[0] + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root + assert cg.get_root(to_label(cg, 1, 7, 7, 7, 0)) == new_root + assert cg.get_chunk_layer(new_root) == 5 + + @pytest.mark.timeout(120) + def test_merge_multi_layer_hierarchy_correctness(self, gen_graph): + """ + After a merge across chunks, verify the full parent chain from + each supervoxel to root is valid — every node has a parent at + a higher layer, and the root is reachable. + """ + cg = gen_graph(n_layers=5) + + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 7, 7, 7, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 3, [3, 3, 3], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 5, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + result = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 7, 7, 7, 0)], + affinities=[0.5], + ) + + # Verify parent chain for both supervoxels + for sv in [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 7, 7, 7, 0)]: + parents = cg.get_root(sv, get_all_parents=True) + # Each parent should be at a strictly higher layer + prev_layer = 1 + for p in parents: + layer = cg.get_chunk_layer(p) + assert layer > prev_layer, ( + f"Parent chain not monotonically increasing: {prev_layer} -> {layer}" + ) + prev_layer = layer + # Last parent should be the root + assert parents[-1] == result.new_root_ids[0] + + @pytest.mark.timeout(30) + def test_merge_no_skip_when_siblings_exist(self, gen_graph): + """ + When two nodes in neighboring chunks are merged, they should NOT + create a skip connection — the parent should be at layer+1 since + they are siblings in the same parent chunk. + + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 2 │ + └─────┴─────┘ + """ + cg = gen_graph(n_layers=3) + + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # Merge + result = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], + affinities=[0.5], + ) + + new_root = result.new_root_ids[0] + # Root should be at layer 3 (the top layer), since the two L2 nodes + # are siblings at layer 3 + assert cg.get_chunk_layer(new_root) == 3 diff --git a/pychunkedgraph/tests/test_merge_split.py b/pychunkedgraph/tests/test_merge_split.py new file mode 100644 index 000000000..45e67a483 --- /dev/null +++ b/pychunkedgraph/tests/test_merge_split.py @@ -0,0 +1,74 @@ +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph import types + + +class TestGraphMergeSplit: + @pytest.mark.timeout(240) + def test_multiple_cuts_and_splits(self, gen_graph_simplequerytest): + cg = gen_graph_simplequerytest + + rr = cg.range_read_chunk(chunk_id=cg.get_chunk_id(layer=4, x=0, y=0, z=0)) + root_ids_t0 = list(rr.keys()) + child_ids = [types.empty_1d] + for root_id in root_ids_t0: + child_ids.append(cg.get_subgraph([root_id], leaves_only=True)) + child_ids = np.concatenate(child_ids) + + for i in range(10): + new_roots = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], + affinities=0.9, + ).new_root_ids + assert len(new_roots) == 1, new_roots + assert len(cg.get_subgraph([new_roots[0]], leaves_only=True)) == 4 + + root_ids = cg.get_roots(child_ids, assert_roots=True) + u_root_ids = np.unique(root_ids) + assert len(u_root_ids) == 1, u_root_ids + + new_roots = cg.remove_edges( + "John Doe", + source_ids=to_label(cg, 1, 1, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 1), + mincut=False, + ).new_root_ids + assert len(new_roots) == 2, new_roots + + root_ids = cg.get_roots(child_ids, assert_roots=True) + u_root_ids = np.unique(root_ids) + these_child_ids = [] + for root_id in u_root_ids: + these_child_ids.extend(cg.get_subgraph([root_id], leaves_only=True)) + + assert len(these_child_ids) == 4 + assert len(u_root_ids) == 2, u_root_ids + + new_roots = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 1), + mincut=False, + ).new_root_ids + assert len(new_roots) == 2, new_roots + + root_ids = cg.get_roots(child_ids, assert_roots=True) + u_root_ids = np.unique(root_ids) + assert len(u_root_ids) == 3, u_root_ids + + new_roots = cg.add_edges( + "Jane Doe", + [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], + affinities=0.9, + ).new_root_ids + assert len(new_roots) == 1, new_roots + + root_ids = cg.get_roots(child_ids, assert_roots=True) + u_root_ids = np.unique(root_ids) + assert len(u_root_ids) == 2, u_root_ids diff --git a/pychunkedgraph/tests/test_meta.py b/pychunkedgraph/tests/test_meta.py new file mode 100644 index 000000000..f94b7d792 --- /dev/null +++ b/pychunkedgraph/tests/test_meta.py @@ -0,0 +1,609 @@ +"""Tests for pychunkedgraph.graph.meta""" + +import pickle + +import numpy as np +import pytest + +from pychunkedgraph.graph.meta import ChunkedGraphMeta, GraphConfig, DataSource + + +class TestChunkedGraphMeta: + def test_init(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + assert meta.graph_config is not None + assert meta.data_source is not None + + def test_graph_config_properties(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + assert meta.graph_config.FANOUT == 2 + assert meta.graph_config.SPATIAL_BITS == 10 + assert meta.graph_config.LAYER_ID_BITS == 8 + + def test_layer_count_setter(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + meta.layer_count = 6 + assert meta.layer_count == 6 + assert meta.bitmasks is not None + assert 1 in meta.bitmasks + + def test_bitmasks(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + bm = meta.bitmasks + assert isinstance(bm, dict) + assert 1 in bm + assert 2 in bm + + def test_read_only_default(self, gen_graph): + graph = gen_graph(n_layers=4) + assert graph.meta.READ_ONLY is False + + def test_is_out_of_bounds(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + assert meta.is_out_of_bounds(np.array([-1, 0, 0])) + assert not meta.is_out_of_bounds(np.array([0, 0, 0])) + + def test_pickle_roundtrip(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + state = meta.__getstate__() + new_meta = ChunkedGraphMeta.__new__(ChunkedGraphMeta) + new_meta.__setstate__(state) + assert new_meta.graph_config == meta.graph_config + assert new_meta.data_source == meta.data_source + + def test_split_bounding_offset_default(self, gen_graph): + graph = gen_graph(n_layers=4) + assert graph.meta.split_bounding_offset == (240, 240, 24) + + +class TestEdgeDtype: + def test_v2(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=2) + meta = ChunkedGraphMeta(gc, ds) + # Manually set bitmasks/layer_count to avoid CloudVolume access + meta._layer_count = 4 + meta._bitmasks = {1: 10, 2: 10, 3: 1, 4: 1} + dt = meta.edge_dtype + names = [d[0] for d in dt] + assert "sv1" in names + assert "aff" in names + assert "area" in names + + def test_v3(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=3) + meta = ChunkedGraphMeta(gc, ds) + meta._layer_count = 4 + meta._bitmasks = {1: 10, 2: 10, 3: 1, 4: 1} + dt = meta.edge_dtype + names = [d[0] for d in dt] + assert "aff_x" in names + assert "area_x" in names + + def test_v4(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + meta._layer_count = 4 + meta._bitmasks = {1: 10, 2: 10, 3: 1, 4: 1} + dt = meta.edge_dtype + # v4 uses float32 for affinities + for name, dtype in dt: + if name.startswith("aff"): + assert dtype == np.float32 + + +class TestDataSourceDefaults: + def test_defaults(self): + ds = DataSource() + assert ds.EDGES is None + assert ds.COMPONENTS is None + assert ds.WATERSHED is None + assert ds.DATA_VERSION is None + assert ds.CV_MIP == 0 + + +class TestGraphConfigDefaults: + def test_defaults(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + assert gc.FANOUT == 2 + assert gc.LAYER_ID_BITS == 8 + assert gc.SPATIAL_BITS == 10 + assert gc.OVERWRITE is False + assert gc.ROOT_COUNTERS == 8 + + +class TestResolutionProperty: + def test_resolution_returns_numpy_array(self, gen_graph): + """meta.resolution should delegate to ws_cv.resolution and return a numpy array.""" + graph = gen_graph(n_layers=4) + meta = graph.meta + res = meta.resolution + assert isinstance(res, np.ndarray) + # The mock CloudVolumeMock sets resolution to [1, 1, 1] + np.testing.assert_array_equal(res, np.array([1, 1, 1])) + + def test_resolution_dtype(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + res = meta.resolution + # Should be numeric + assert np.issubdtype(res.dtype, np.integer) or np.issubdtype( + res.dtype, np.floating + ) + + +class TestLayerChunkCounts: + def test_layer_chunk_counts_length(self, gen_graph): + """layer_chunk_counts should return a list with one entry per layer from 2..layer_count-1, plus [1] for root.""" + graph = gen_graph(n_layers=4) + meta = graph.meta + counts = meta.layer_chunk_counts + # layers 2, 3 contribute entries, plus the trailing [1] for root + # layer_count=4, so range(2, 4) => layers 2, 3 => 2 entries + [1] = 3 + assert isinstance(counts, list) + assert ( + len(counts) == meta.layer_count - 2 + 1 + ) # -2 for range start, +1 for root + # The last entry should always be 1 (root layer) + assert counts[-1] == 1 + + def test_layer_chunk_counts_values(self, gen_graph): + """Each count should be the product of chunk bounds for that layer.""" + graph = gen_graph(n_layers=4) + meta = graph.meta + counts = meta.layer_chunk_counts + for i, layer in enumerate(range(2, meta.layer_count)): + expected = np.prod(meta.layer_chunk_bounds[layer]) + assert counts[i] == expected + + def test_layer_chunk_counts_n_layers_5(self, gen_graph): + graph = gen_graph(n_layers=5) + meta = graph.meta + counts = meta.layer_chunk_counts + # n_layers=5 => layers 2,3,4 + root => 4 entries + assert len(counts) == 4 + assert counts[-1] == 1 + + +class TestLayerChunkBoundsSetter: + def test_setter_overrides_bounds(self, gen_graph): + """Setting layer_chunk_bounds should override the computed value.""" + graph = gen_graph(n_layers=4) + meta = graph.meta + + custom_bounds = { + 2: np.array([10, 10, 10]), + 3: np.array([5, 5, 5]), + } + meta.layer_chunk_bounds = custom_bounds + assert meta.layer_chunk_bounds is custom_bounds + np.testing.assert_array_equal( + meta.layer_chunk_bounds[2], np.array([10, 10, 10]) + ) + np.testing.assert_array_equal(meta.layer_chunk_bounds[3], np.array([5, 5, 5])) + + def test_setter_with_none_clears(self, gen_graph): + """Setting layer_chunk_bounds to None should clear cached value (next access recomputes).""" + graph = gen_graph(n_layers=4) + meta = graph.meta + # Access to populate the cache + _ = meta.layer_chunk_bounds + meta.layer_chunk_bounds = None + # After clearing, the internal _layer_bounds_d is None + assert meta._layer_bounds_d is None + + +class TestEdgeDtypeUnknownVersion: + """Test that an unknown DATA_VERSION raises Exception in edge_dtype.""" + + def test_unknown_version_raises(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=999) + meta = ChunkedGraphMeta(gc, ds) + meta._layer_count = 4 + meta._bitmasks = {1: 10, 2: 10, 3: 1, 4: 1} + with pytest.raises(Exception): + _ = meta.edge_dtype + + def test_none_version_raises(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=None) + meta = ChunkedGraphMeta(gc, ds) + meta._layer_count = 4 + meta._bitmasks = {1: 10, 2: 10, 3: 1, 4: 1} + with pytest.raises(Exception): + _ = meta.edge_dtype + + def test_version_1_raises(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=1) + meta = ChunkedGraphMeta(gc, ds) + meta._layer_count = 4 + meta._bitmasks = {1: 10, 2: 10, 3: 1, 4: 1} + with pytest.raises(Exception): + _ = meta.edge_dtype + + +class TestGetNewArgs: + """Test __getnewargs__ returns (graph_config, data_source).""" + + def test_getnewargs_returns_tuple(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + result = meta.__getnewargs__() + assert isinstance(result, tuple) + assert len(result) == 2 + + def test_getnewargs_contains_config_and_source(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64], FANOUT=2) + ds = DataSource(DATA_VERSION=3, CV_MIP=1) + meta = ChunkedGraphMeta(gc, ds) + result = meta.__getnewargs__() + assert result[0] is gc + assert result[1] is ds + assert result[0].CHUNK_SIZE == [64, 64, 64] + assert result[1].DATA_VERSION == 3 + + def test_getnewargs_with_gen_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + result = meta.__getnewargs__() + assert result[0] == meta.graph_config + assert result[1] == meta.data_source + + +class TestCustomData: + """Test custom_data including READ_ONLY=True and mesh dir.""" + + def test_read_only_true(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"READ_ONLY": True}) + assert meta.READ_ONLY is True + + def test_read_only_false_explicit(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"READ_ONLY": False}) + assert meta.READ_ONLY is False + + def test_read_only_default_no_custom_data(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + assert meta.READ_ONLY is False + + def test_mesh_dir_in_custom_data(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta( + gc, ds, custom_data={"mesh": {"dir": "gs://bucket/mesh"}} + ) + assert meta.custom_data["mesh"]["dir"] == "gs://bucket/mesh" + + def test_split_bounding_offset_custom(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + meta = ChunkedGraphMeta( + gc, ds, custom_data={"split_bounding_offset": (100, 100, 10)} + ) + assert meta.split_bounding_offset == (100, 100, 10) + + def test_custom_data_preserved_through_getstate(self): + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(DATA_VERSION=4) + custom = {"READ_ONLY": True, "mesh": {"dir": "gs://bucket/mesh"}} + meta = ChunkedGraphMeta(gc, ds, custom_data=custom) + state = meta.__getstate__() + assert state["custom_data"] == custom + + +class TestCvAlias: + """Test that cv property returns the same object as ws_cv.""" + + def test_cv_returns_same_as_ws_cv(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + assert meta.cv is meta.ws_cv + + def test_cv_is_not_none(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + assert meta.cv is not None + + +class TestStr: + """Test __str__ returns a non-empty string with expected sections.""" + + def _add_info_to_mock(self, meta): + """Add an info dict to the CloudVolumeMock so dataset_info works.""" + meta._ws_cv.info = {"scales": [{"resolution": [1, 1, 1]}]} + + def test_str_not_empty(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + self._add_info_to_mock(meta) + result = str(meta) + assert len(result) > 0 + + def test_str_contains_sections(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + self._add_info_to_mock(meta) + result = str(meta) + assert "GRAPH_CONFIG" in result + assert "DATA_SOURCE" in result + assert "CUSTOM_DATA" in result + assert "BITMASKS" in result + assert "VOXEL_BOUNDS" in result + assert "VOXEL_COUNTS" in result + assert "LAYER_CHUNK_BOUNDS" in result + assert "LAYER_CHUNK_COUNTS" in result + assert "DATASET_INFO" in result + + def test_str_is_string_type(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + self._add_info_to_mock(meta) + result = str(meta) + assert isinstance(result, str) + + +class TestDatasetInfo: + """Test dataset_info returns dict with expected keys.""" + + def _add_info_to_mock(self, meta): + """Add an info dict to the CloudVolumeMock so dataset_info works.""" + meta._ws_cv.info = {"scales": [{"resolution": [1, 1, 1]}]} + + def test_dataset_info_is_dict(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + self._add_info_to_mock(meta) + info = meta.dataset_info + assert isinstance(info, dict) + + def test_dataset_info_has_expected_keys(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + self._add_info_to_mock(meta) + info = meta.dataset_info + assert "chunks_start_at_voxel_offset" in info + assert info["chunks_start_at_voxel_offset"] is True + assert "data_dir" in info + assert "graph" in info + + def test_dataset_info_graph_section(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + self._add_info_to_mock(meta) + info = meta.dataset_info + graph_info = info["graph"] + assert "chunk_size" in graph_info + assert "n_bits_for_layer_id" in graph_info + assert "cv_mip" in graph_info + assert "n_layers" in graph_info + assert "spatial_bit_masks" in graph_info + assert graph_info["n_layers"] == meta.layer_count + + def test_dataset_info_with_mesh_dir(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + self._add_info_to_mock(meta) + meta._custom_data = {"mesh": {"dir": "gs://bucket/mesh"}} + info = meta.dataset_info + assert "mesh" in info + assert info["mesh"] == "gs://bucket/mesh" + + def test_dataset_info_without_mesh_dir(self, gen_graph): + graph = gen_graph(n_layers=4) + meta = graph.meta + self._add_info_to_mock(meta) + meta._custom_data = {} + info = meta.dataset_info + assert "mesh" not in info + + +# ===================================================================== +# Pure unit tests (no BigTable emulator needed) - mock CloudVolume & Redis +# ===================================================================== +import json +from unittest.mock import MagicMock, patch, PropertyMock + + +class TestWsCvRedisCached: + """Test ws_cv property with Redis caching.""" + + @patch("pychunkedgraph.graph.meta.CloudVolume") + @patch("pychunkedgraph.graph.meta.get_redis_connection") + def test_ws_cv_redis_cached(self, mock_get_redis, mock_cv_cls): + """When redis has cached info, ws_cv uses cached CloudVolume.""" + gc = GraphConfig(ID="test_graph", CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + + cached_info = {"scales": [{"resolution": [8, 8, 40]}]} + mock_redis = MagicMock() + mock_redis.get.return_value = json.dumps(cached_info) + mock_get_redis.return_value = mock_redis + + mock_cv_instance = MagicMock() + mock_cv_cls.return_value = mock_cv_instance + + result = meta.ws_cv + + assert result is mock_cv_instance + mock_cv_cls.assert_called_once_with("gs://bucket/ws", info=cached_info) + + @patch("pychunkedgraph.graph.meta.CloudVolume") + @patch("pychunkedgraph.graph.meta.get_redis_connection") + def test_ws_cv_redis_failure_fallback(self, mock_get_redis, mock_cv_cls): + """When redis raises, ws_cv falls back to direct CloudVolume.""" + gc = GraphConfig(ID="test_graph", CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + + mock_get_redis.side_effect = Exception("Redis connection failed") + + mock_cv_instance = MagicMock() + mock_cv_instance.info = {"scales": []} + mock_cv_cls.return_value = mock_cv_instance + + result = meta.ws_cv + + assert result is mock_cv_instance + # Should have been called without info kwarg (fallback) + mock_cv_cls.assert_called_with("gs://bucket/ws") + + @patch("pychunkedgraph.graph.meta.CloudVolume") + @patch("pychunkedgraph.graph.meta.get_redis_connection") + def test_ws_cv_caches_to_redis(self, mock_get_redis, mock_cv_cls): + """When redis is available but cache miss, ws_cv caches info to redis.""" + gc = GraphConfig(ID="test_graph", CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + + mock_redis = MagicMock() + # Make redis.get raise to simulate cache miss on json.loads + mock_redis.get.return_value = None # This will make json.loads fail + mock_get_redis.return_value = mock_redis + + mock_cv_instance = MagicMock() + mock_cv_instance.info = {"scales": [{"resolution": [8, 8, 40]}]} + mock_cv_cls.return_value = mock_cv_instance + + result = meta.ws_cv + + assert result is mock_cv_instance + # The fallback CloudVolume call (no info= kwarg) + mock_cv_cls.assert_called_with("gs://bucket/ws") + # Should try to cache in redis + mock_redis.set.assert_called_once() + + @patch("pychunkedgraph.graph.meta.CloudVolume") + @patch("pychunkedgraph.graph.meta.get_redis_connection") + def test_ws_cv_returns_cached_instance(self, mock_get_redis, mock_cv_cls): + """Once ws_cv has been set, subsequent calls return the cached instance.""" + gc = GraphConfig(ID="test_graph", CHUNK_SIZE=[64, 64, 64]) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + + # Pre-set the cached ws_cv + mock_cv = MagicMock() + meta._ws_cv = mock_cv + + result = meta.ws_cv + assert result is mock_cv + # Should not try to create a new CloudVolume + mock_cv_cls.assert_not_called() + + +class TestLayerCountComputed: + """Test layer_count property computation from CloudVolume bounds.""" + + def test_layer_count_computed_from_cv(self): + """layer_count should be computed from ws_cv.bounds.""" + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64], FANOUT=2, SPATIAL_BITS=10) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + + # Create a mock ws_cv with bounds + mock_cv = MagicMock() + # bounds.to_list() returns [x_min, y_min, z_min, x_max, y_max, z_max] + # With a 256x256x256 volume and 64x64x64 chunks: 4 chunks per dim + # log_2(4) = 2, +2 = 4 layers + mock_cv.bounds.to_list.return_value = [0, 0, 0, 256, 256, 256] + meta._ws_cv = mock_cv + + count = meta.layer_count + assert isinstance(count, int) + assert count >= 3 # at least 3 layers for any reasonable volume + + def test_layer_count_cached_after_first_access(self): + """After layer_count is computed, it should be cached.""" + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64], FANOUT=2, SPATIAL_BITS=10) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + + meta._layer_count = 5 + + assert meta.layer_count == 5 + + +class TestBitmasksLazy: + """Test bitmasks property lazy computation.""" + + def test_bitmasks_lazy_computed(self): + """bitmasks should be computed lazily from layer_count.""" + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64], FANOUT=2, SPATIAL_BITS=10) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + + # Set layer_count directly to avoid needing ws_cv for layer_count + meta._layer_count = 5 + + bm = meta.bitmasks + assert isinstance(bm, dict) + assert 1 in bm + assert 2 in bm + + def test_bitmasks_cached_after_first_access(self): + """Once computed, bitmasks should be cached.""" + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64], FANOUT=2, SPATIAL_BITS=10) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + meta._layer_count = 5 + + bm1 = meta.bitmasks + bm2 = meta.bitmasks + assert bm1 is bm2 + + +class TestLayerChunkBoundsComputed: + """Test layer_chunk_bounds property computation.""" + + def test_layer_chunk_bounds_computed(self): + """layer_chunk_bounds should be computed from voxel_counts and chunk_size.""" + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64], FANOUT=2, SPATIAL_BITS=10) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + + mock_cv = MagicMock() + mock_cv.bounds.to_list.return_value = [0, 0, 0, 256, 256, 256] + meta._ws_cv = mock_cv + # layer_count needs to be set to avoid recursive calls + meta._layer_count = 4 + + bounds = meta.layer_chunk_bounds + assert isinstance(bounds, dict) + # For layer_count=4, should have entries for layers 2 and 3 + assert 2 in bounds + assert 3 in bounds + # With 256/64=4 chunks, layer 2 should have 4 chunks per dim + np.testing.assert_array_equal(bounds[2], np.array([4, 4, 4])) + # layer 3: 4/2 = 2 chunks per dim + np.testing.assert_array_equal(bounds[3], np.array([2, 2, 2])) + + def test_layer_chunk_bounds_cached(self): + """After first access, layer_chunk_bounds should be cached.""" + gc = GraphConfig(CHUNK_SIZE=[64, 64, 64], FANOUT=2, SPATIAL_BITS=10) + ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds) + + mock_cv = MagicMock() + mock_cv.bounds.to_list.return_value = [0, 0, 0, 256, 256, 256] + meta._ws_cv = mock_cv + meta._layer_count = 4 + + bounds1 = meta.layer_chunk_bounds + bounds2 = meta.layer_chunk_bounds + assert bounds1 is bounds2 diff --git a/pychunkedgraph/tests/test_mincut.py b/pychunkedgraph/tests/test_mincut.py new file mode 100644 index 000000000..6208c444a --- /dev/null +++ b/pychunkedgraph/tests/test_mincut.py @@ -0,0 +1,317 @@ +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph import exceptions +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestGraphMinCut: + # TODO: Ideally, those tests should focus only on mincut retrieving the correct edges. + # The edge removal part should be tested exhaustively in TestGraphSplit + @pytest.mark.timeout(30) + def test_cut_regular_link(self, gen_graph): + """ + Regular link between 1 and 2 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1━━┿━━2 │ + │ │ │ + └─────┴─────┘ + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + # Mincut + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + source_coords=[0, 0, 0], + sink_coords=[ + 2 * cg.meta.graph_config.CHUNK_SIZE[0], + 2 * cg.meta.graph_config.CHUNK_SIZE[1], + cg.meta.graph_config.CHUNK_SIZE[2], + ], + mincut=True, + disallow_isolating_cut=True, + ).new_root_ids + + # verify new state + assert len(new_root_ids) == 2 + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( + to_label(cg, 1, 1, 0, 0, 0) + ) + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True + ) + ) + assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 1, 0, 0, 0))], leaves_only=True + ) + ) + assert len(leaves) == 1 and to_label(cg, 1, 1, 0, 0, 0) in leaves + + @pytest.mark.timeout(30) + def test_cut_no_link(self, gen_graph): + """ + No connection between 1 and 2 + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1 │ 2 │ + │ │ │ + └─────┴─────┘ + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + res_old = cg.client._table.read_rows() + res_old.consume_all() + + # Mincut + with pytest.raises(exceptions.PreconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + source_coords=[0, 0, 0], + sink_coords=[ + 2 * cg.meta.graph_config.CHUNK_SIZE[0], + 2 * cg.meta.graph_config.CHUNK_SIZE[1], + cg.meta.graph_config.CHUNK_SIZE[2], + ], + mincut=True, + ) + + res_new = cg.client._table.read_rows() + res_new.consume_all() + + assert res_new.rows == res_old.rows + + @pytest.mark.timeout(30) + def test_cut_old_link(self, gen_graph): + """ + Link between 1 and 2 got removed previously (aff = 0.0) + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1┅┅╎┅┅2 │ + │ │ │ + └─────┴─────┘ + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + cg.remove_edges( + "John Doe", + source_ids=to_label(cg, 1, 1, 0, 0, 0), + sink_ids=to_label(cg, 1, 0, 0, 0, 0), + mincut=False, + ) + + res_old = cg.client._table.read_rows() + res_old.consume_all() + + # Mincut + with pytest.raises(exceptions.PreconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + source_coords=[0, 0, 0], + sink_coords=[ + 2 * cg.meta.graph_config.CHUNK_SIZE[0], + 2 * cg.meta.graph_config.CHUNK_SIZE[1], + cg.meta.graph_config.CHUNK_SIZE[2], + ], + mincut=True, + ) + + res_new = cg.client._table.read_rows() + res_new.consume_all() + + assert res_new.rows == res_old.rows + + @pytest.mark.timeout(30) + def test_cut_indivisible_link(self, gen_graph): + """ + Sink: 1, Source: 2 + Link between 1 and 2 is set to `inf` and must not be cut. + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1══╪══2 │ + │ │ │ + └─────┴─────┘ + """ + + cg = gen_graph(n_layers=3) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], + timestamp=fake_timestamp, + ) + + # Preparation: Build Chunk B + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], + timestamp=fake_timestamp, + ) + + add_parent_chunk( + cg, + 3, + [0, 0, 0], + time_stamp=fake_timestamp, + n_threads=1, + ) + + original_parents_1 = cg.get_root( + to_label(cg, 1, 0, 0, 0, 0), get_all_parents=True + ) + original_parents_2 = cg.get_root( + to_label(cg, 1, 1, 0, 0, 0), get_all_parents=True + ) + + # Mincut + with pytest.raises(exceptions.PostconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + source_coords=[0, 0, 0], + sink_coords=[ + 2 * cg.meta.graph_config.CHUNK_SIZE[0], + 2 * cg.meta.graph_config.CHUNK_SIZE[1], + cg.meta.graph_config.CHUNK_SIZE[2], + ], + mincut=True, + ) + + new_parents_1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0), get_all_parents=True) + new_parents_2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0), get_all_parents=True) + + assert np.all(np.array(original_parents_1) == np.array(new_parents_1)) + assert np.all(np.array(original_parents_2) == np.array(new_parents_2)) + + @pytest.mark.timeout(30) + def test_mincut_disrespects_sources_or_sinks(self, gen_graph): + """ + When the mincut separates sources or sinks, an error should be thrown. + Although the mincut is setup to never cut an edge between two sources or + two sinks, this can happen when an edge along the only path between two + sources or two sinks is cut. + """ + cg = gen_graph(n_layers=2) + + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[ + to_label(cg, 1, 0, 0, 0, 0), + to_label(cg, 1, 0, 0, 0, 1), + to_label(cg, 1, 0, 0, 0, 2), + to_label(cg, 1, 0, 0, 0, 3), + ], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 2), 2), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2), 3), + (to_label(cg, 1, 0, 0, 0, 2), to_label(cg, 1, 0, 0, 0, 3), 10), + ], + timestamp=fake_timestamp, + ) + + # Mincut + with pytest.raises(exceptions.PreconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + sink_ids=[to_label(cg, 1, 0, 0, 0, 3)], + source_coords=[[0, 0, 0], [10, 0, 0]], + sink_coords=[[5, 5, 0]], + mincut=True, + ) diff --git a/pychunkedgraph/tests/test_misc.py b/pychunkedgraph/tests/test_misc.py new file mode 100644 index 000000000..0181934c2 --- /dev/null +++ b/pychunkedgraph/tests/test_misc.py @@ -0,0 +1,293 @@ +"""Tests for pychunkedgraph.graph.misc""" + +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from pychunkedgraph.graph.misc import ( + get_latest_roots, + get_delta_roots, + get_proofread_root_ids, + get_agglomerations, + get_activated_edges, +) +from pychunkedgraph.graph.edges import Edges +from pychunkedgraph.graph.types import Agglomeration + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestGetLatestRoots: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + roots = get_latest_roots(graph) + assert len(roots) >= 1 + + def test_with_timestamp(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + roots_before = get_latest_roots(graph, fake_ts - timedelta(days=1)) + roots_after = get_latest_roots(graph) + # Before creation, there should be no roots + assert len(roots_before) == 0 + assert len(roots_after) >= 1 + + +class TestGetDeltaRoots: + def test_basic(self, gen_graph): + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + + before_merge = datetime.now(UTC) + + graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + + new_roots, expired_roots = get_delta_roots(graph, before_merge) + assert len(new_roots) >= 1 + + +class TestGetProofreadRootIds: + def test_after_merge(self, gen_graph): + """After a merge, get_proofread_root_ids should return old and new root IDs.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + sv0 = to_label(graph, 1, 0, 0, 0, 0) + sv1 = to_label(graph, 1, 0, 0, 0, 1) + + create_chunk( + graph, + vertices=[sv0, sv1], + edges=[], + timestamp=fake_ts, + ) + + before_merge = datetime.now(UTC) + + # Both SVs should have separate roots before merge + old_root0 = graph.get_root(sv0) + old_root1 = graph.get_root(sv1) + assert old_root0 != old_root1 + + # Perform a merge + graph.add_edges( + "TestUser", + [sv0, sv1], + affinities=[0.3], + ) + + # After merge, the two SVs share a new root + new_root = graph.get_root(sv0) + assert new_root == graph.get_root(sv1) + + old_roots, new_roots = get_proofread_root_ids(graph, start_time=before_merge) + + # The new root from the merge should appear in new_roots + assert new_root in new_roots + # The old roots that were merged should appear in old_roots + old_roots_set = set(old_roots.tolist()) + assert old_root0 in old_roots_set or old_root1 in old_roots_set + + def test_empty_when_no_operations(self, gen_graph): + """When no operations occurred, both arrays should be empty.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + + # Query a time range in the future where no operations exist + future = datetime.now(UTC) + timedelta(days=1) + old_roots, new_roots = get_proofread_root_ids(graph, start_time=future) + + assert len(old_roots) == 0 + assert len(new_roots) == 0 + + +class TestGetAgglomerations: + def test_single_l2id(self): + """Test get_agglomerations with a single L2 ID and its supervoxels.""" + l2id = np.uint64(100) + sv1 = np.uint64(1) + sv2 = np.uint64(2) + sv3 = np.uint64(3) + + l2id_children_d = {l2id: np.array([sv1, sv2, sv3], dtype=np.uint64)} + + # sv_parent_d maps supervoxel -> parent l2id + sv_parent_d = {sv1: l2id, sv2: l2id, sv3: l2id} + + # in_edges: edges within the agglomeration (sv1-sv2, sv2-sv3) + in_edges = Edges( + np.array([sv1, sv2], dtype=np.uint64), + np.array([sv2, sv3], dtype=np.uint64), + ) + + # ot_edges: edges to other agglomerations (empty here) + ot_edges = Edges(np.array([], dtype=np.uint64), np.array([], dtype=np.uint64)) + + # cx_edges: cross-chunk edges (empty here) + cx_edges = Edges(np.array([], dtype=np.uint64), np.array([], dtype=np.uint64)) + + result = get_agglomerations( + l2id_children_d, in_edges, ot_edges, cx_edges, sv_parent_d + ) + + assert l2id in result + agg = result[l2id] + assert isinstance(agg, Agglomeration) + assert agg.node_id == l2id + np.testing.assert_array_equal( + agg.supervoxels, np.array([sv1, sv2, sv3], dtype=np.uint64) + ) + # The in_edges should contain both edges (sv1-sv2, sv2-sv3) since both have node_ids1 mapping to l2id + assert len(agg.in_edges) == 2 + assert len(agg.out_edges) == 0 + assert len(agg.cross_edges) == 0 + + def test_multiple_l2ids(self): + """Test get_agglomerations partitions edges correctly across multiple L2 IDs.""" + l2id_a = np.uint64(100) + l2id_b = np.uint64(200) + + sv_a1 = np.uint64(1) + sv_a2 = np.uint64(2) + sv_b1 = np.uint64(3) + sv_b2 = np.uint64(4) + + l2id_children_d = { + l2id_a: np.array([sv_a1, sv_a2], dtype=np.uint64), + l2id_b: np.array([sv_b1, sv_b2], dtype=np.uint64), + } + + sv_parent_d = {sv_a1: l2id_a, sv_a2: l2id_a, sv_b1: l2id_b, sv_b2: l2id_b} + + # in_edges: internal edges for each agglomeration + in_edges = Edges( + np.array([sv_a1, sv_b1], dtype=np.uint64), + np.array([sv_a2, sv_b2], dtype=np.uint64), + ) + + # ot_edges: edge from sv_a2 to sv_b1 (between agglomerations) + ot_edges = Edges( + np.array([sv_a2, sv_b1], dtype=np.uint64), + np.array([sv_b1, sv_a2], dtype=np.uint64), + ) + + # cx_edges: empty + cx_edges = Edges(np.array([], dtype=np.uint64), np.array([], dtype=np.uint64)) + + result = get_agglomerations( + l2id_children_d, in_edges, ot_edges, cx_edges, sv_parent_d + ) + + assert len(result) == 2 + assert l2id_a in result + assert l2id_b in result + + agg_a = result[l2id_a] + agg_b = result[l2id_b] + + # Each agglomeration should have exactly 1 internal edge + assert len(agg_a.in_edges) == 1 + assert len(agg_b.in_edges) == 1 + + # Each agglomeration should have exactly 1 out_edge + assert len(agg_a.out_edges) == 1 + assert len(agg_b.out_edges) == 1 + + def test_empty_edges(self): + """Test get_agglomerations with an L2 ID that has no edges at all.""" + l2id = np.uint64(50) + sv = np.uint64(10) + + l2id_children_d = {l2id: np.array([sv], dtype=np.uint64)} + sv_parent_d = {sv: l2id} + + in_edges = Edges(np.array([], dtype=np.uint64), np.array([], dtype=np.uint64)) + ot_edges = Edges(np.array([], dtype=np.uint64), np.array([], dtype=np.uint64)) + cx_edges = Edges(np.array([], dtype=np.uint64), np.array([], dtype=np.uint64)) + + result = get_agglomerations( + l2id_children_d, in_edges, ot_edges, cx_edges, sv_parent_d + ) + + assert l2id in result + agg = result[l2id] + assert agg.node_id == l2id + assert len(agg.in_edges) == 0 + assert len(agg.out_edges) == 0 + assert len(agg.cross_edges) == 0 + + +class TestGetActivatedEdges: + @pytest.mark.timeout(30) + def test_returns_numpy_array_after_merge(self, gen_graph): + """After merging two isolated SVs, get_activated_edges returns a numpy array.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + sv0 = to_label(graph, 1, 0, 0, 0, 0) + sv1 = to_label(graph, 1, 0, 0, 0, 1) + + create_chunk( + graph, + vertices=[sv0, sv1], + edges=[], + timestamp=fake_ts, + ) + + # Merge the two isolated supervoxels + result = graph.add_edges( + "TestUser", + [sv0, sv1], + affinities=[0.3], + ) + + activated = get_activated_edges(graph, result.operation_id) + assert isinstance(activated, np.ndarray) diff --git a/pychunkedgraph/tests/test_multicut.py b/pychunkedgraph/tests/test_multicut.py new file mode 100644 index 000000000..078a74f9e --- /dev/null +++ b/pychunkedgraph/tests/test_multicut.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest + +from ..graph.edges import Edges +from ..graph import exceptions +from ..graph.cutting import run_multicut + + +class TestGraphMultiCut: + @pytest.mark.timeout(30) + def test_cut_multi_tree(self, gen_graph): + """ + Multicut on a graph with multiple sources and sinks and parallel paths. + Sources: [1, 2], Sinks: [5, 6] + Graph: + 1━━3━━5 + ┃ ┃ + 2━━4━━6 + The multicut should find edges to separate {1,2} from {5,6}. + """ + node_ids1 = np.array([1, 2, 3, 4, 3, 1], dtype=np.uint64) + node_ids2 = np.array([3, 4, 5, 6, 4, 2], dtype=np.uint64) + affinities = np.array([0.5, 0.5, 0.5, 0.5, 0.8, 0.9], dtype=np.float32) + edges = Edges(node_ids1, node_ids2, affinities=affinities) + source_ids = np.array([1, 2], dtype=np.uint64) + sink_ids = np.array([5, 6], dtype=np.uint64) + + cut_edges = run_multicut( + edges, source_ids, sink_ids, path_augment=False, disallow_isolating_cut=False + ) + assert cut_edges.shape[0] > 0 + + # Verify the cut actually separates sources from sinks + cut_set = set(map(tuple, cut_edges.tolist())) + remaining = set() + for i in range(len(node_ids1)): + e = (int(node_ids1[i]), int(node_ids2[i])) + if e not in cut_set and (e[1], e[0]) not in cut_set: + remaining.add(e) + + # BFS from sources through remaining edges + reachable = set(source_ids.tolist()) + changed = True + while changed: + changed = False + for a, b in remaining: + if a in reachable and b not in reachable: + reachable.add(b) + changed = True + if b in reachable and a not in reachable: + reachable.add(a) + changed = True + # Sinks should not be reachable from sources + for s in sink_ids: + assert int(s) not in reachable + + @pytest.mark.timeout(30) + def test_path_augmented_multicut(self, sv_data): + sv_edges, sv_sources, sv_sinks, sv_affinity, sv_area = sv_data + edges = Edges( + sv_edges[:, 0], sv_edges[:, 1], affinities=sv_affinity, areas=sv_area + ) + cut_edges_aug = run_multicut(edges, sv_sources, sv_sinks, path_augment=True) + assert cut_edges_aug.shape[0] == 350 + + with pytest.raises(exceptions.PreconditionError): + run_multicut(edges, sv_sources, sv_sinks, path_augment=False) diff --git a/pychunkedgraph/tests/test_node_conversion.py b/pychunkedgraph/tests/test_node_conversion.py new file mode 100644 index 000000000..68ca2810f --- /dev/null +++ b/pychunkedgraph/tests/test_node_conversion.py @@ -0,0 +1,85 @@ +import numpy as np +import pytest + +from .helpers import to_label +from ..graph.utils.serializers import serialize_uint64 +from ..graph.utils.serializers import deserialize_uint64 + + +class TestGraphNodeConversion: + @pytest.mark.timeout(30) + def test_compute_bitmasks(self, gen_graph): + cg = gen_graph(n_layers=10) + # Verify bitmasks for layer and spatial bits + node_id = cg.get_node_id(np.uint64(1), layer=2, x=0, y=0, z=0) + assert cg.get_chunk_layer(node_id) == 2 + assert cg.get_segment_id(node_id) == 1 + + # Different layers should produce different bitmask regions + for layer in range(2, 10): + nid = cg.get_node_id(np.uint64(1), layer=layer, x=0, y=0, z=0) + assert cg.get_chunk_layer(nid) == layer + + @pytest.mark.timeout(30) + def test_node_conversion(self, gen_graph): + cg = gen_graph(n_layers=10) + + node_id = cg.get_node_id(np.uint64(4), layer=2, x=3, y=1, z=0) + assert cg.get_chunk_layer(node_id) == 2 + assert np.all(cg.get_chunk_coordinates(node_id) == np.array([3, 1, 0])) + + chunk_id = cg.get_chunk_id(layer=2, x=3, y=1, z=0) + assert cg.get_chunk_layer(chunk_id) == 2 + assert np.all(cg.get_chunk_coordinates(chunk_id) == np.array([3, 1, 0])) + + assert cg.get_chunk_id(node_id=node_id) == chunk_id + assert cg.get_node_id(np.uint64(4), chunk_id=chunk_id) == node_id + + @pytest.mark.timeout(30) + def test_node_id_adjacency(self, gen_graph): + cg = gen_graph(n_layers=10) + + assert cg.get_node_id(np.uint64(0), layer=2, x=3, y=1, z=0) + np.uint64( + 1 + ) == cg.get_node_id(np.uint64(1), layer=2, x=3, y=1, z=0) + + assert cg.get_node_id( + np.uint64(2**53 - 2), layer=10, x=0, y=0, z=0 + ) + np.uint64(1) == cg.get_node_id( + np.uint64(2**53 - 1), layer=10, x=0, y=0, z=0 + ) + + @pytest.mark.timeout(30) + def test_serialize_node_id(self, gen_graph): + cg = gen_graph(n_layers=10) + + assert serialize_uint64( + cg.get_node_id(np.uint64(0), layer=2, x=3, y=1, z=0) + ) < serialize_uint64(cg.get_node_id(np.uint64(1), layer=2, x=3, y=1, z=0)) + + assert serialize_uint64( + cg.get_node_id(np.uint64(2**53 - 2), layer=10, x=0, y=0, z=0) + ) < serialize_uint64( + cg.get_node_id(np.uint64(2**53 - 1), layer=10, x=0, y=0, z=0) + ) + + @pytest.mark.timeout(30) + def test_deserialize_node_id(self, gen_graph): + cg = gen_graph(n_layers=10) + node_id = cg.get_node_id(np.uint64(4), layer=2, x=3, y=1, z=0) + serialized = serialize_uint64(node_id) + assert deserialize_uint64(serialized) == node_id + + @pytest.mark.timeout(30) + def test_serialization_roundtrip(self, gen_graph): + cg = gen_graph(n_layers=10) + # Test various node IDs across layers and positions + for layer in [2, 5, 10]: + for seg_id in [0, 1, 42, 2**16]: + node_id = cg.get_node_id(np.uint64(seg_id), layer=layer, x=0, y=0, z=0) + assert deserialize_uint64(serialize_uint64(node_id)) == node_id + + @pytest.mark.timeout(30) + def test_serialize_valid_label_id(self): + label = np.uint64(0x01FF031234556789) + assert deserialize_uint64(serialize_uint64(label)) == label diff --git a/pychunkedgraph/tests/test_operation.py b/pychunkedgraph/tests/test_operation.py index ff7cb65bd..626efbf7e 100644 --- a/pychunkedgraph/tests/test_operation.py +++ b/pychunkedgraph/tests/test_operation.py @@ -1,261 +1,960 @@ -# from collections import namedtuple - -# import numpy as np -# import pytest - -# from ..graph.operation import ( -# GraphEditOperation, -# MergeOperation, -# MulticutOperation, -# RedoOperation, -# SplitOperation, -# UndoOperation, -# ) -# from ..graph import attributes - - -# class FakeLogRecords: -# Record = namedtuple("graph_op", ("id", "record")) - -# _records = [ -# { # 0: Merge with coordinates -# attributes.OperationLogs.AddedEdge: np.array([[1, 2]], dtype=np.uint64), -# attributes.OperationLogs.SinkCoordinate: np.array([[1, 2, 3]]), -# attributes.OperationLogs.SourceCoordinate: np.array([[4, 5, 6]]), -# attributes.OperationLogs.UserID: "42", -# }, -# { # 1: Multicut with coordinates -# attributes.OperationLogs.BoundingBoxOffset: np.array([240, 240, 24]), -# attributes.OperationLogs.RemovedEdge: np.array( -# [[1, 3], [4, 1], [1, 5]], dtype=np.uint64 -# ), -# attributes.OperationLogs.SinkCoordinate: np.array([[1, 2, 3]]), -# attributes.OperationLogs.SinkID: np.array([1], dtype=np.uint64), -# attributes.OperationLogs.SourceCoordinate: np.array([[4, 5, 6]]), -# attributes.OperationLogs.SourceID: np.array([2], dtype=np.uint64), -# attributes.OperationLogs.UserID: "42", -# }, -# { # 2: Split with coordinates -# attributes.OperationLogs.RemovedEdge: np.array( -# [[1, 3], [4, 1], [1, 5]], dtype=np.uint64 -# ), -# attributes.OperationLogs.SinkCoordinate: np.array([[1, 2, 3]]), -# attributes.OperationLogs.SinkID: np.array([1], dtype=np.uint64), -# attributes.OperationLogs.SourceCoordinate: np.array([[4, 5, 6]]), -# attributes.OperationLogs.SourceID: np.array([2], dtype=np.uint64), -# attributes.OperationLogs.UserID: "42", -# }, -# { # 3: Undo of records[0] -# attributes.OperationLogs.UndoOperationID: np.uint64(0), -# attributes.OperationLogs.UserID: "42", -# }, -# { # 4: Redo of records[0] -# attributes.OperationLogs.RedoOperationID: np.uint64(0), -# attributes.OperationLogs.UserID: "42", -# }, -# {attributes.OperationLogs.UserID: "42",}, # 5: Unknown record -# ] - -# MERGE = Record(id=np.uint64(0), record=_records[0]) -# MULTICUT = Record(id=np.uint64(1), record=_records[1]) -# SPLIT = Record(id=np.uint64(2), record=_records[2]) -# UNDO = Record(id=np.uint64(3), record=_records[3]) -# REDO = Record(id=np.uint64(4), record=_records[4]) -# UNKNOWN = Record(id=np.uint64(5), record=_records[5]) - -# @classmethod -# def get(cls, idx: int): -# try: -# return cls._records[idx] -# except IndexError as err: -# raise KeyError(err) # Bigtable would throw KeyError instead - - -# @pytest.fixture(scope="function") -# def cg(mocker): -# graph = mocker.MagicMock() -# graph.get_chunk_layer = mocker.MagicMock(return_value=1) -# graph.read_log_row = mocker.MagicMock(side_effect=FakeLogRecords.get) -# return graph - - -# def test_read_from_log_merge(mocker, cg): -# """MergeOperation should be correctly identified by an existing AddedEdge column. -# Coordinates are optional.""" -# graph_operation = GraphEditOperation.from_log_record( -# cg, FakeLogRecords.MERGE.record -# ) -# assert isinstance(graph_operation, MergeOperation) - - -# def test_read_from_log_multicut(mocker, cg): -# """MulticutOperation should be correctly identified by a Sink/Source ID and -# BoundingBoxOffset column. Unless requested as SplitOperation...""" -# graph_operation = GraphEditOperation.from_log_record( -# cg, FakeLogRecords.MULTICUT.record, multicut_as_split=False -# ) -# assert isinstance(graph_operation, MulticutOperation) - -# graph_operation = GraphEditOperation.from_log_record( -# cg, FakeLogRecords.MULTICUT.record, multicut_as_split=True -# ) -# assert isinstance(graph_operation, SplitOperation) - - -# def test_read_from_log_split(mocker, cg): -# """SplitOperation should be correctly identified by the lack of a -# BoundingBoxOffset column.""" -# graph_operation = GraphEditOperation.from_log_record( -# cg, FakeLogRecords.SPLIT.record -# ) -# assert isinstance(graph_operation, SplitOperation) - - -# def test_read_from_log_undo(mocker, cg): -# """UndoOperation should be correctly identified by the UndoOperationID.""" -# graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.UNDO.record) -# assert isinstance(graph_operation, UndoOperation) - - -# def test_read_from_log_redo(mocker, cg): -# """RedoOperation should be correctly identified by the RedoOperationID.""" -# graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.REDO.record) -# assert isinstance(graph_operation, RedoOperation) - - -# def test_read_from_log_undo_undo(mocker, cg): -# """Undo[Undo[Merge]] -> Redo[Merge]""" -# fake_log_record = { -# attributes.OperationLogs.UndoOperationID: np.uint64(FakeLogRecords.UNDO.id), -# attributes.OperationLogs.UserID: "42", -# } - -# graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) -# assert isinstance(graph_operation, RedoOperation) -# assert isinstance(graph_operation.superseded_operation, MergeOperation) - - -# def test_read_from_log_undo_redo(mocker, cg): -# """Undo[Redo[Merge]] -> Undo[Merge]""" -# fake_log_record = { -# attributes.OperationLogs.UndoOperationID: np.uint64(FakeLogRecords.REDO.id), -# attributes.OperationLogs.UserID: "42", -# } - -# graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) -# assert isinstance(graph_operation, UndoOperation) -# assert isinstance(graph_operation.inverse_superseded_operation, SplitOperation) - - -# def test_read_from_log_redo_undo(mocker, cg): -# """Redo[Undo[Merge]] -> Undo[Merge]""" -# fake_log_record = { -# attributes.OperationLogs.RedoOperationID: np.uint64(FakeLogRecords.UNDO.id), -# attributes.OperationLogs.UserID: "42", -# } - -# graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) -# assert isinstance(graph_operation, UndoOperation) -# assert isinstance(graph_operation.inverse_superseded_operation, SplitOperation) - - -# def test_read_from_log_redo_redo(mocker, cg): -# """Redo[Redo[Merge]] -> Redo[Merge]""" -# fake_log_record = { -# attributes.OperationLogs.RedoOperationID: np.uint64(FakeLogRecords.REDO.id), -# attributes.OperationLogs.UserID: "42", -# } - -# graph_operation = GraphEditOperation.from_log_record(cg, fake_log_record) -# assert isinstance(graph_operation, RedoOperation) -# assert isinstance(graph_operation.superseded_operation, MergeOperation) - - -# def test_invert_merge(mocker, cg): -# """Inverse of Merge is a Split""" -# graph_operation = GraphEditOperation.from_log_record( -# cg, FakeLogRecords.MERGE.record -# ) -# inverted_graph_operation = graph_operation.invert() -# assert isinstance(inverted_graph_operation, SplitOperation) -# assert np.all( -# np.equal(graph_operation.added_edges, inverted_graph_operation.removed_edges) -# ) - - -# @pytest.mark.skip( -# reason="Can't test right now - would require recalculting the Multicut" -# ) -# def test_invert_multicut(mocker, cg): -# """Inverse of a Multicut is a Merge""" - - -# def test_invert_split(mocker, cg): -# """Inverse of Split is a Merge""" -# graph_operation = GraphEditOperation.from_log_record( -# cg, FakeLogRecords.SPLIT.record -# ) -# inverted_graph_operation = graph_operation.invert() -# assert isinstance(inverted_graph_operation, MergeOperation) -# assert np.all( -# np.equal(graph_operation.removed_edges, inverted_graph_operation.added_edges) -# ) - - -# def test_invert_undo(mocker, cg): -# """Inverse of Undo[x] is Redo[x]""" -# graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.UNDO.record) -# inverted_graph_operation = graph_operation.invert() -# assert isinstance(inverted_graph_operation, RedoOperation) -# assert ( -# graph_operation.superseded_operation_id -# == inverted_graph_operation.superseded_operation_id -# ) - - -# def test_invert_redo(mocker, cg): -# """Inverse of Redo[x] is Undo[x]""" -# graph_operation = GraphEditOperation.from_log_record(cg, FakeLogRecords.REDO.record) -# inverted_graph_operation = graph_operation.invert() -# assert ( -# graph_operation.superseded_operation_id -# == inverted_graph_operation.superseded_operation_id -# ) - - -# def test_undo_redo_chain_fails(mocker, cg): -# """Prevent creation of Undo/Redo chains""" -# with pytest.raises(ValueError): -# UndoOperation( -# cg, -# user_id="DAU", -# superseded_operation_id=FakeLogRecords.UNDO.id, -# multicut_as_split=False, -# ) -# with pytest.raises(ValueError): -# UndoOperation( -# cg, -# user_id="DAU", -# superseded_operation_id=FakeLogRecords.REDO.id, -# multicut_as_split=False, -# ) -# with pytest.raises(ValueError): -# RedoOperation( -# cg, -# user_id="DAU", -# superseded_operation_id=FakeLogRecords.UNDO.id, -# multicut_as_split=False, -# ) -# with pytest.raises(ValueError): -# UndoOperation( -# cg, -# user_id="DAU", -# superseded_operation_id=FakeLogRecords.REDO.id, -# multicut_as_split=False, -# ) - - -# def test_unknown_log_record_fails(cg, mocker): -# """TypeError when encountering unknown log row""" -# with pytest.raises(TypeError): -# GraphEditOperation.from_log_record(cg, FakeLogRecords.UNKNOWN.record) +"""Integration tests for GraphEditOperation and its subclasses. + +Tests operation type identification from log records, operation inversion, +undo/redo chain resolution, ID validation, and execute error handling +-- all using real graph operations through the BigTable emulator. +""" + +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph import attributes +from ..graph.operation import ( + GraphEditOperation, + MergeOperation, + MulticutOperation, + SplitOperation, + RedoOperation, + UndoOperation, +) +from ..graph.exceptions import PreconditionError, PostconditionError +from ..ingest.create.parent_layer import add_parent_chunk + + +def _build_two_sv_disconnected(gen_graph): + """2-layer graph, two disconnected SVs in the same chunk.""" + cg = gen_graph(n_layers=2, atomic_chunk_bounds=np.array([1, 1, 1])) + ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[], + timestamp=ts, + ) + return cg, ts + + +def _build_two_sv_connected(gen_graph): + """2-layer graph, two connected SVs in the same chunk.""" + cg = gen_graph(n_layers=2, atomic_chunk_bounds=np.array([1, 1, 1])) + ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=ts, + ) + return cg, ts + + +def _build_cross_chunk(gen_graph): + """4-layer graph with cross-chunk edges suitable for MulticutOperation.""" + cg = gen_graph(n_layers=4) + ts = datetime.now(UTC) - timedelta(days=10) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + create_chunk( + cg, + vertices=[sv0, sv1], + edges=[ + (sv0, sv1, 0.5), + (sv0, to_label(cg, 1, 1, 0, 0, 0), inf), + ], + timestamp=ts, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), sv0, inf)], + timestamp=ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(cg, 3, [1, 0, 0], n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], n_threads=1) + return cg, ts, sv0, sv1 + + +# =========================================================================== +# Existing tests (log record types, inversion, undo/redo chain) +# =========================================================================== +class TestOperationFromLogRecord: + """Test that GraphEditOperation.from_log_record correctly identifies operation types.""" + + @pytest.fixture() + def merged_graph(self, gen_graph): + """Build a simple 2-chunk graph and perform a merge, returning (cg, operation_id).""" + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # Split first to get two separate roots + split_result = cg.remove_edges( + "test_user", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + + # Now merge them back + merge_result = cg.add_edges( + "test_user", + atomic_edges=[[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)]], + source_coords=[0, 0, 0], + sink_coords=[0, 0, 0], + ) + return cg, merge_result.operation_id, split_result.operation_id + + @pytest.mark.timeout(30) + def test_merge_log_record_type(self, merged_graph): + """MergeOperation should be correctly identified from a real merge log record.""" + cg, merge_op_id, _ = merged_graph + log_record, _ = cg.client.read_log_entry(merge_op_id) + op_type = GraphEditOperation.get_log_record_type(log_record) + assert op_type is MergeOperation + + @pytest.mark.timeout(30) + def test_split_log_record_type(self, merged_graph): + """SplitOperation should be correctly identified from a real split log record.""" + cg, _, split_op_id = merged_graph + log_record, _ = cg.client.read_log_entry(split_op_id) + op_type = GraphEditOperation.get_log_record_type(log_record) + assert op_type is SplitOperation + + @pytest.mark.timeout(30) + def test_merge_from_log_record(self, merged_graph): + """from_log_record should return a MergeOperation for a real merge log.""" + cg, merge_op_id, _ = merged_graph + log_record, _ = cg.client.read_log_entry(merge_op_id) + graph_op = GraphEditOperation.from_log_record(cg, log_record) + assert isinstance(graph_op, MergeOperation) + + @pytest.mark.timeout(30) + def test_split_from_log_record(self, merged_graph): + """from_log_record should return a SplitOperation for a real split log.""" + cg, _, split_op_id = merged_graph + log_record, _ = cg.client.read_log_entry(split_op_id) + graph_op = GraphEditOperation.from_log_record(cg, log_record) + assert isinstance(graph_op, SplitOperation) + + @pytest.mark.timeout(30) + def test_unknown_log_record_fails(self, gen_graph): + """TypeError when encountering a log record with no recognizable operation columns.""" + cg = gen_graph(n_layers=3) + fake_record = {attributes.OperationLogs.UserID: "test_user"} + with pytest.raises(TypeError): + GraphEditOperation.from_log_record(cg, fake_record) + + +class TestOperationInversion: + """Test that operation inversion produces the correct inverse type and edges.""" + + @pytest.fixture() + def split_and_merge_ops(self, gen_graph): + """Build graph, split, merge -- return (cg, merge_op_id, split_op_id).""" + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + split_result = cg.remove_edges( + "test_user", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + merge_result = cg.add_edges( + "test_user", + atomic_edges=[[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)]], + source_coords=[0, 0, 0], + sink_coords=[0, 0, 0], + ) + return cg, merge_result.operation_id, split_result.operation_id + + @pytest.mark.timeout(30) + def test_invert_merge_produces_split(self, split_and_merge_ops): + """Inverse of a MergeOperation should be a SplitOperation with matching edges.""" + cg, merge_op_id, _ = split_and_merge_ops + log_record, _ = cg.client.read_log_entry(merge_op_id) + merge_op = GraphEditOperation.from_log_record(cg, log_record) + inverted = merge_op.invert() + assert isinstance(inverted, SplitOperation) + assert np.all(np.equal(merge_op.added_edges, inverted.removed_edges)) + + @pytest.mark.timeout(30) + def test_invert_split_produces_merge(self, split_and_merge_ops): + """Inverse of a SplitOperation should be a MergeOperation with matching edges.""" + cg, _, split_op_id = split_and_merge_ops + log_record, _ = cg.client.read_log_entry(split_op_id) + split_op = GraphEditOperation.from_log_record(cg, log_record) + inverted = split_op.invert() + assert isinstance(inverted, MergeOperation) + assert np.all(np.equal(split_op.removed_edges, inverted.added_edges)) + + +class TestUndoRedoChainResolution: + """Test undo/redo chain resolution through real graph operations.""" + + @pytest.fixture() + def graph_with_undo(self, gen_graph): + """Build graph, perform split, then undo -- return (cg, split_op_id, undo_result).""" + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # Split + split_result = cg.remove_edges( + "test_user", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + # Undo the split (= merge) + undo_result = cg.undo_operation("test_user", split_result.operation_id) + return cg, split_result.operation_id, undo_result + + @pytest.mark.timeout(30) + def test_undo_log_record_type(self, graph_with_undo): + """Undo operation log record should be identified as UndoOperation.""" + cg, _, undo_result = graph_with_undo + log_record, _ = cg.client.read_log_entry(undo_result.operation_id) + op_type = GraphEditOperation.get_log_record_type(log_record) + assert op_type is UndoOperation + + @pytest.mark.timeout(30) + def test_undo_from_log_resolves_correctly(self, graph_with_undo): + """from_log_record on an undo record should resolve the chain to an UndoOperation.""" + cg, split_op_id, undo_result = graph_with_undo + log_record, _ = cg.client.read_log_entry(undo_result.operation_id) + resolved_op = GraphEditOperation.from_log_record(cg, log_record) + # Undo of a split -> UndoOperation whose inverse is a MergeOperation + assert isinstance(resolved_op, UndoOperation) + + @pytest.mark.timeout(30) + def test_redo_after_undo(self, graph_with_undo): + """Redo of the original split (after undo) should produce a RedoOperation log.""" + cg, split_op_id, undo_result = graph_with_undo + + # Redo the original split (which was undone) + redo_result = cg.redo_operation("test_user", split_op_id) + assert redo_result.operation_id is not None + redo_log, _ = cg.client.read_log_entry(redo_result.operation_id) + resolved_op = GraphEditOperation.from_log_record(cg, redo_log) + assert isinstance(resolved_op, RedoOperation) + + @pytest.mark.timeout(30) + def test_undo_redo_chain_prevention(self, graph_with_undo): + """Direct UndoOperation/RedoOperation on undo/redo targets should raise ValueError.""" + cg, _, undo_result = graph_with_undo + + # Direct UndoOperation on an undo record should fail + with pytest.raises(ValueError): + UndoOperation( + cg, + user_id="test_user", + superseded_operation_id=undo_result.operation_id, + multicut_as_split=True, + ) + + # Direct RedoOperation on an undo record should also fail + with pytest.raises(ValueError): + RedoOperation( + cg, + user_id="test_user", + superseded_operation_id=undo_result.operation_id, + multicut_as_split=True, + ) + + +# =========================================================================== +# NEW: Multicut log record type identification (lines 151-153) +# =========================================================================== +class TestGetLogRecordTypeMulticut: + """Synthetic tests for MulticutOperation identification in get_log_record_type.""" + + def test_bbox_only_is_multicut(self): + """BoundingBoxOffset with no RemovedEdge -> MulticutOperation (line 152-153).""" + log = {attributes.OperationLogs.BoundingBoxOffset: np.array([10, 10, 10])} + assert GraphEditOperation.get_log_record_type(log) is MulticutOperation + + def test_removed_edge_with_bbox_multicut_as_split_true(self): + """RemovedEdge + BoundingBoxOffset + multicut_as_split=True -> SplitOperation (line 150).""" + log = { + attributes.OperationLogs.RemovedEdge: np.array([[1, 2]], dtype=np.uint64), + attributes.OperationLogs.BoundingBoxOffset: np.array([10, 10, 10]), + } + assert ( + GraphEditOperation.get_log_record_type(log, multicut_as_split=True) + is SplitOperation + ) + + def test_removed_edge_with_bbox_multicut_as_split_false(self): + """RemovedEdge + BoundingBoxOffset + multicut_as_split=False -> MulticutOperation (line 151).""" + log = { + attributes.OperationLogs.RemovedEdge: np.array([[1, 2]], dtype=np.uint64), + attributes.OperationLogs.BoundingBoxOffset: np.array([10, 10, 10]), + } + assert ( + GraphEditOperation.get_log_record_type(log, multicut_as_split=False) + is MulticutOperation + ) + + def test_undo_log_record(self): + """UndoOperationID in log -> UndoOperation.""" + log = {attributes.OperationLogs.UndoOperationID: np.uint64(42)} + assert GraphEditOperation.get_log_record_type(log) is UndoOperation + + def test_redo_log_record(self): + """RedoOperationID in log -> RedoOperation.""" + log = {attributes.OperationLogs.RedoOperationID: np.uint64(42)} + assert GraphEditOperation.get_log_record_type(log) is RedoOperation + + def test_empty_log_raises_type_error(self): + """Empty log record should raise TypeError (line 154).""" + with pytest.raises(TypeError, match="Could not determine"): + GraphEditOperation.get_log_record_type({}) + + +# =========================================================================== +# NEW: from_log_record MulticutOperation path (lines 235-251) +# =========================================================================== +class TestFromLogRecordMulticutPath: + """Test from_log_record for the MulticutOperation path with multicut_as_split=False.""" + + @pytest.mark.timeout(60) + def test_multicut_from_log_record(self, gen_graph): + """A multicut operation's log, read back with multicut_as_split=False, + should be reconstructed as MulticutOperation (lines 235-249).""" + cg, _, sv0, sv1 = _build_cross_chunk(gen_graph) + source_coords = [[0, 0, 0]] + sink_coords = [[512, 0, 0]] + try: + mc_result = cg.remove_edges( + "test_user", + source_ids=sv0, + sink_ids=sv1, + source_coords=source_coords, + sink_coords=sink_coords, + mincut=True, + ) + except (PreconditionError, PostconditionError): + pytest.skip("Multicut not feasible in this small test graph") + + log, _ = cg.client.read_log_entry(mc_result.operation_id) + op = GraphEditOperation.from_log_record(cg, log, multicut_as_split=False) + assert isinstance(op, MulticutOperation) + + # With default multicut_as_split=True -> SplitOperation + op2 = GraphEditOperation.from_log_record(cg, log, multicut_as_split=True) + assert isinstance(op2, SplitOperation) + + +# =========================================================================== +# NEW: from_operation_id (lines 278-281) +# =========================================================================== +class TestFromOperationId: + """Test GraphEditOperation.from_operation_id round-trip.""" + + @pytest.mark.timeout(30) + def test_from_operation_id_merge(self, gen_graph): + """from_operation_id should reconstruct a MergeOperation.""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + result = cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + op = GraphEditOperation.from_operation_id(cg, result.operation_id) + assert isinstance(op, MergeOperation) + # privileged_mode defaults to False + assert op.privileged_mode is False + + @pytest.mark.timeout(30) + def test_from_operation_id_privileged(self, gen_graph): + """from_operation_id with privileged_mode=True should propagate the flag (line 280).""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + result = cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + op = GraphEditOperation.from_operation_id( + cg, result.operation_id, privileged_mode=True + ) + assert op.privileged_mode is True + + @pytest.mark.timeout(30) + def test_from_operation_id_split(self, gen_graph): + """from_operation_id should reconstruct a SplitOperation.""" + cg, _ = _build_two_sv_connected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + result = cg.remove_edges( + "test_user", source_ids=sv0, sink_ids=sv1, mincut=False + ) + op = GraphEditOperation.from_operation_id(cg, result.operation_id) + assert isinstance(op, SplitOperation) + + +# =========================================================================== +# NEW: MulticutOperation.invert() (line 974-981) +# =========================================================================== +class TestMulticutInversion: + """Test MulticutOperation.invert() returns a MergeOperation.""" + + @pytest.mark.timeout(30) + def test_multicut_invert(self, gen_graph): + """MulticutOperation.invert() -> MergeOperation with removed_edges as added_edges.""" + cg, _, sv0, sv1 = _build_cross_chunk(gen_graph) + mc_op = MulticutOperation( + cg, + user_id="test_user", + source_ids=[sv0], + sink_ids=[sv1], + source_coords=[[0, 0, 0]], + sink_coords=[[512, 0, 0]], + bbox_offset=[240, 240, 24], + removed_edges=np.array([[sv0, sv1]], dtype=np.uint64), + ) + inverted = mc_op.invert() + assert isinstance(inverted, MergeOperation) + np.testing.assert_array_equal(inverted.added_edges, mc_op.removed_edges) + + +# =========================================================================== +# NEW: ID validation -- self-loops and overlapping IDs (lines 593-596, 732-733, 871-875) +# =========================================================================== +class TestIDValidation: + """Test PreconditionError on self-loops and overlapping IDs.""" + + @pytest.mark.timeout(30) + def test_merge_self_loop_raises(self, gen_graph): + """added_edges where source == sink should raise PreconditionError (line 596).""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + with pytest.raises(PreconditionError, match="self-loop"): + MergeOperation( + cg, + user_id="test_user", + added_edges=[[sv0, sv0]], + source_coords=None, + sink_coords=None, + ) + + @pytest.mark.timeout(30) + def test_split_self_loop_raises(self, gen_graph): + """removed_edges where source == sink should raise PreconditionError (line 733).""" + cg, _ = _build_two_sv_connected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + with pytest.raises(PreconditionError, match="self-loop"): + SplitOperation( + cg, + user_id="test_user", + removed_edges=[[sv0, sv0]], + source_coords=None, + sink_coords=None, + ) + + @pytest.mark.timeout(30) + def test_multicut_overlapping_ids_raises(self, gen_graph): + """source_ids overlapping sink_ids should raise PreconditionError (line 872).""" + cg, _, sv0, sv1 = _build_cross_chunk(gen_graph) + with pytest.raises(PreconditionError, match="both sink and source"): + MulticutOperation( + cg, + user_id="test_user", + source_ids=[sv0, sv1], + sink_ids=[sv1], + source_coords=[[0, 0, 0], [1, 0, 0]], + sink_coords=[[1, 0, 0]], + bbox_offset=[240, 240, 24], + ) + + +# =========================================================================== +# NEW: Empty coords / affinities normalization (lines 82, 86, 593) +# =========================================================================== +class TestEmptyCoordsAffinities: + """Empty source/sink coords and affinities should be normalized to None.""" + + @pytest.mark.timeout(30) + def test_empty_source_coords_becomes_none(self, gen_graph): + """source_coords with size 0 should be stored as None (line 82).""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + op = MergeOperation( + cg, + user_id="test_user", + added_edges=[[sv0, sv1]], + source_coords=np.array([], dtype=np.int64).reshape(0, 3), + sink_coords=np.array([], dtype=np.int64).reshape(0, 3), + ) + assert op.source_coords is None + assert op.sink_coords is None + + @pytest.mark.timeout(30) + def test_empty_affinities_becomes_none(self, gen_graph): + """affinities with size 0 should be stored as None (line 593).""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + op = MergeOperation( + cg, + user_id="test_user", + added_edges=[[sv0, sv1]], + source_coords=None, + sink_coords=None, + affinities=np.array([], dtype=np.float32), + ) + assert op.affinities is None + + +# =========================================================================== +# NEW: Merge / Split preconditions via execute (lines 618, 765) +# =========================================================================== +class TestEditPreconditions: + """Test precondition errors raised during _apply.""" + + @pytest.mark.timeout(30) + def test_merge_same_segment_raises(self, gen_graph): + """Merging SVs already in the same root raises PreconditionError (line 618).""" + cg, _ = _build_two_sv_connected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + with pytest.raises(PreconditionError, match="different objects"): + cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + + @pytest.mark.timeout(30) + def test_split_different_roots_raises(self, gen_graph): + """Splitting SVs from different roots raises PreconditionError (line 765).""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + with pytest.raises(PreconditionError, match="same object"): + cg.remove_edges("test_user", source_ids=sv0, sink_ids=sv1, mincut=False) + + +# =========================================================================== +# NEW: Undo / Redo via actual operations (lines 1160-1175, 1245-1259, etc.) +# =========================================================================== +class TestUndoRedoExecute: + """End-to-end undo/redo tests that verify graph state after execute.""" + + def _build_connected_cross_chunk(self, gen_graph): + """Build a 3-layer graph with between-chunk edge -- suitable for split+undo.""" + cg = gen_graph(n_layers=3) + ts = datetime.now(UTC) - timedelta(days=10) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 1, 0, 0, 0) + create_chunk( + cg, + vertices=[sv0], + edges=[(sv0, sv1, 0.5)], + timestamp=ts, + ) + create_chunk( + cg, + vertices=[sv1], + edges=[(sv1, sv0, 0.5)], + timestamp=ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=ts, n_threads=1) + return cg, sv0, sv1 + + @pytest.mark.timeout(60) + def test_undo_split_restores_root(self, gen_graph): + """After split + undo, the SVs should share a root again (lines 1160-1175).""" + cg, sv0, sv1 = self._build_connected_cross_chunk(gen_graph) + assert cg.get_root(sv0) == cg.get_root(sv1) + + split_result = cg.remove_edges( + "test_user", source_ids=sv0, sink_ids=sv1, mincut=False + ) + assert cg.get_root(sv0) != cg.get_root(sv1) + + undo_result = cg.undo_operation("test_user", split_result.operation_id) + assert cg.get_root(sv0) == cg.get_root(sv1) + + @pytest.mark.timeout(60) + def test_redo_split_after_undo(self, gen_graph): + """After split + undo, redo the split directly (lines 1036-1043, 1094-1106).""" + cg, sv0, sv1 = self._build_connected_cross_chunk(gen_graph) + + split_result = cg.remove_edges( + "test_user", source_ids=sv0, sink_ids=sv1, mincut=False + ) + assert cg.get_root(sv0) != cg.get_root(sv1) + + undo_result = cg.undo_operation("test_user", split_result.operation_id) + assert cg.get_root(sv0) == cg.get_root(sv1) + + # Redo the original split directly + redo_result = cg.redo_operation("test_user", split_result.operation_id) + # The redo should succeed and re-apply the split + assert redo_result.operation_id is not None + + @pytest.mark.timeout(60) + def test_undo_of_undo_resolves_to_redo(self, gen_graph): + """Undoing an undo should resolve to a RedoOperation (lines 102-108).""" + cg, sv0, sv1 = self._build_connected_cross_chunk(gen_graph) + split_result = cg.remove_edges( + "test_user", source_ids=sv0, sink_ids=sv1, mincut=False + ) + undo_result = cg.undo_operation("test_user", split_result.operation_id) + + op = GraphEditOperation.undo_operation( + cg, user_id="test_user", operation_id=undo_result.operation_id + ) + assert isinstance(op, RedoOperation) + + +# =========================================================================== +# NEW: UndoOperation / RedoOperation .invert() (lines 1087, 1228) +# =========================================================================== +class TestUndoRedoInvert: + """Test invert() on UndoOperation and RedoOperation.""" + + @pytest.mark.timeout(60) + def test_undo_invert_is_redo(self, gen_graph): + """UndoOperation.invert() -> RedoOperation (line 1228).""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + merge_result = cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + + undo_op = GraphEditOperation.undo_operation( + cg, user_id="test_user", operation_id=merge_result.operation_id + ) + assert isinstance(undo_op, UndoOperation) + inverted = undo_op.invert() + assert isinstance(inverted, RedoOperation) + assert inverted.superseded_operation_id == undo_op.superseded_operation_id + + @pytest.mark.timeout(60) + def test_redo_invert_is_undo(self, gen_graph): + """RedoOperation.invert() -> UndoOperation (line 1087).""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + merge_result = cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + + redo_op = GraphEditOperation.redo_operation( + cg, user_id="test_user", operation_id=merge_result.operation_id + ) + assert isinstance(redo_op, RedoOperation) + inverted = redo_op.invert() + assert isinstance(inverted, UndoOperation) + assert inverted.superseded_operation_id == redo_op.superseded_operation_id + + +# =========================================================================== +# NEW: UndoOperation / RedoOperation edge attributes (lines 1040-1043, 1172-1175) +# =========================================================================== +class TestUndoRedoEdgeAttributes: + """Verify that undo/redo operations carry the correct edge attributes.""" + + @pytest.mark.timeout(60) + def test_undo_merge_has_removed_edges(self, gen_graph): + """Undoing a merge -> inverse is SplitOp -> undo should have removed_edges (line 1175).""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + merge_result = cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + + undo_op = GraphEditOperation.undo_operation( + cg, user_id="test_user", operation_id=merge_result.operation_id + ) + assert hasattr(undo_op, "removed_edges") + assert undo_op.removed_edges.shape[1] == 2 + + @pytest.mark.timeout(60) + def test_undo_split_has_added_edges(self, gen_graph): + """Undoing a split -> inverse is MergeOp -> undo should have added_edges (line 1173).""" + cg, _ = _build_two_sv_connected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + split_result = cg.remove_edges( + "test_user", source_ids=sv0, sink_ids=sv1, mincut=False + ) + + undo_op = GraphEditOperation.undo_operation( + cg, user_id="test_user", operation_id=split_result.operation_id + ) + assert hasattr(undo_op, "added_edges") + assert undo_op.added_edges.shape[1] == 2 + + @pytest.mark.timeout(60) + def test_redo_merge_has_added_edges(self, gen_graph): + """RedoOperation for a merge should have added_edges (line 1040-1041).""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + merge_result = cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + + redo_op = GraphEditOperation.redo_operation( + cg, user_id="test_user", operation_id=merge_result.operation_id + ) + assert isinstance(redo_op, RedoOperation) + assert hasattr(redo_op, "added_edges") + assert redo_op.added_edges.shape[1] == 2 + + @pytest.mark.timeout(60) + def test_redo_split_has_removed_edges(self, gen_graph): + """RedoOperation for a split should have removed_edges (line 1042-1043).""" + cg, _ = _build_two_sv_connected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + split_result = cg.remove_edges( + "test_user", source_ids=sv0, sink_ids=sv1, mincut=False + ) + + redo_op = GraphEditOperation.redo_operation( + cg, user_id="test_user", operation_id=split_result.operation_id + ) + assert isinstance(redo_op, RedoOperation) + assert hasattr(redo_op, "removed_edges") + assert redo_op.removed_edges.shape[1] == 2 + + +# =========================================================================== +# NEW: Undo / Redo log record type from actual operations +# =========================================================================== +class TestUndoRedoLogRecordTypes: + """Verify that actual undo/redo operations produce correct log record types.""" + + def _build_and_split(self, gen_graph): + """Build a cross-chunk graph and split it -- suitable for undo/redo.""" + cg = gen_graph(n_layers=3) + ts = datetime.now(UTC) - timedelta(days=10) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 1, 0, 0, 0) + create_chunk( + cg, + vertices=[sv0], + edges=[(sv0, sv1, 0.5)], + timestamp=ts, + ) + create_chunk( + cg, + vertices=[sv1], + edges=[(sv1, sv0, 0.5)], + timestamp=ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=ts, n_threads=1) + split_result = cg.remove_edges( + "test_user", source_ids=sv0, sink_ids=sv1, mincut=False + ) + return cg, sv0, sv1, split_result + + @pytest.mark.timeout(60) + def test_undo_log_type(self, gen_graph): + """Undo operation log should be identified as UndoOperation.""" + cg, sv0, sv1, split_result = self._build_and_split(gen_graph) + undo_result = cg.undo_operation("test_user", split_result.operation_id) + + log, _ = cg.client.read_log_entry(undo_result.operation_id) + assert GraphEditOperation.get_log_record_type(log) is UndoOperation + + @pytest.mark.timeout(60) + def test_redo_log_type(self, gen_graph): + """Redo operation log should be identified as RedoOperation.""" + cg, sv0, sv1, split_result = self._build_and_split(gen_graph) + undo_result = cg.undo_operation("test_user", split_result.operation_id) + + # Redo the split that was just undone + redo_result = cg.redo_operation("test_user", split_result.operation_id) + assert redo_result.operation_id is not None + + log, _ = cg.client.read_log_entry(redo_result.operation_id) + assert GraphEditOperation.get_log_record_type(log) is RedoOperation + + +# =========================================================================== +# NEW: execute() error handling -- PreconditionError clears cache (lines 436, 460-462) +# =========================================================================== +class TestExecuteErrorHandling: + """Test that execute() clears cache on PreconditionError/PostconditionError.""" + + @pytest.mark.timeout(30) + def test_execute_precondition_error_clears_cache(self, gen_graph): + """Trigger PreconditionError during merge (same-segment merge) and verify cache is cleared.""" + cg, _ = _build_two_sv_connected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + + # Merging already-connected SVs raises PreconditionError + with pytest.raises(PreconditionError, match="different objects"): + cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + + # After the error, the graph cache should have been cleared (set to None) + assert cg.cache is None + + @pytest.mark.timeout(30) + def test_execute_postcondition_error_clears_cache(self, gen_graph): + """PostconditionError during execute should also clear cache (lines 463-465).""" + from unittest.mock import patch + + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + + # Mock _apply to raise PostconditionError + with patch.object( + MergeOperation, + "_apply", + side_effect=PostconditionError("test postcondition error"), + ): + with pytest.raises(PostconditionError, match="test postcondition error"): + cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + + # Cache should have been cleared + assert cg.cache is None + + @pytest.mark.timeout(30) + def test_execute_assertion_error_clears_cache(self, gen_graph): + """AssertionError/RuntimeError during execute should also clear cache (lines 466-468).""" + from unittest.mock import patch + + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + + # Mock _apply to raise RuntimeError + with patch.object( + MergeOperation, "_apply", side_effect=RuntimeError("test runtime error") + ): + with pytest.raises(RuntimeError, match="test runtime error"): + cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + + assert cg.cache is None + + +# =========================================================================== +# NEW: UndoOperation.execute() edge validation (lines 1245-1267) +# =========================================================================== +class TestUndoEdgeValidation: + """Test UndoOperation.execute() edge validation logic.""" + + def _build_connected_cross_chunk(self, gen_graph): + """Build a 3-layer graph with between-chunk edge suitable for split+undo.""" + cg = gen_graph(n_layers=3) + ts = datetime.now(UTC) - timedelta(days=10) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 1, 0, 0, 0) + create_chunk( + cg, + vertices=[sv0], + edges=[(sv0, sv1, 0.5)], + timestamp=ts, + ) + create_chunk( + cg, + vertices=[sv1], + edges=[(sv1, sv0, 0.5)], + timestamp=ts, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=ts, n_threads=1) + return cg, sv0, sv1 + + @pytest.mark.timeout(60) + def test_undo_split_restores_edges(self, gen_graph): + """After undo of a split, edges should be active again.""" + cg, sv0, sv1 = self._build_connected_cross_chunk(gen_graph) + + # Verify initially connected + assert cg.get_root(sv0) == cg.get_root(sv1) + + # Split + split_result = cg.remove_edges( + "test_user", source_ids=sv0, sink_ids=sv1, mincut=False + ) + assert cg.get_root(sv0) != cg.get_root(sv1) + + # Undo the split + undo_result = cg.undo_operation("test_user", split_result.operation_id) + assert undo_result.operation_id is not None + + # Edges should be active again -- the SVs share a root + assert cg.get_root(sv0) == cg.get_root(sv1) + + @pytest.mark.timeout(60) + def test_undo_merge_via_undo_operation_class(self, gen_graph): + """UndoOperation on a merge constructs with inverse being SplitOperation.""" + cg, _ = _build_two_sv_disconnected(gen_graph) + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + + # Merge + merge_result = cg.add_edges("test_user", [sv0, sv1], affinities=[0.3]) + assert cg.get_root(sv0) == cg.get_root(sv1) + + # Build the UndoOperation manually to inspect its structure + undo_op = GraphEditOperation.undo_operation( + cg, user_id="test_user", operation_id=merge_result.operation_id + ) + assert isinstance(undo_op, UndoOperation) + # The inverse of a merge is a split, so removed_edges should be set + assert hasattr(undo_op, "removed_edges") + assert undo_op.removed_edges.shape[1] == 2 + + @pytest.mark.timeout(60) + def test_undo_noop_when_split_already_undone(self, gen_graph): + """UndoOperation.execute() with edges already active returns early (lines 1253-1258).""" + cg, sv0, sv1 = self._build_connected_cross_chunk(gen_graph) + + # Split + split_result = cg.remove_edges( + "test_user", source_ids=sv0, sink_ids=sv1, mincut=False + ) + + # First undo + undo_result1 = cg.undo_operation("test_user", split_result.operation_id) + assert cg.get_root(sv0) == cg.get_root(sv1) + + # Second undo of the same split -- the inverse is a MergeOp + # and since those edges are already active, it should return early + # (lines 1253-1258: if np.all(a): early return with empty Result) + undo_result2 = cg.undo_operation("test_user", split_result.operation_id) + # The early return path returns a Result with operation_id=None and empty arrays + assert undo_result2.operation_id is None + assert len(undo_result2.new_root_ids) == 0 diff --git a/pychunkedgraph/tests/test_root_lock.py b/pychunkedgraph/tests/test_root_lock.py index a5ef7d4d2..1228c8ae9 100644 --- a/pychunkedgraph/tests/test_root_lock.py +++ b/pychunkedgraph/tests/test_root_lock.py @@ -1,104 +1,85 @@ -# from unittest.mock import DEFAULT - -# import numpy as np -# import pytest - -# from ..graph import exceptions -# from ..graph.locks import RootLock - -# G_UINT64 = np.uint64(2 ** 63) - - -# def big_uint64(): -# """Return incremental uint64 values larger than a signed int64""" -# global G_UINT64 -# if G_UINT64 == np.uint64(2 ** 64 - 1): -# G_UINT64 = np.uint64(2 ** 63) -# G_UINT64 = G_UINT64 + np.uint64(1) -# return G_UINT64 - - -# class RootLockTracker: -# def __init__(self): -# self.active_locks = dict() - -# def add_locks(self, root_ids, operation_id, **kwargs): -# if operation_id not in self.active_locks: -# self.active_locks[operation_id] = set(root_ids) -# else: -# self.active_locks[operation_id].update(root_ids) -# return DEFAULT - -# def remove_lock(self, root_id, operation_id, **kwargs): -# if operation_id in self.active_locks: -# self.active_locks[operation_id].discard(root_id) -# return DEFAULT - - -# @pytest.fixture() -# def root_lock_tracker(): -# return RootLockTracker() - - -# def test_successful_lock_acquisition(mocker, root_lock_tracker): -# """Ensure that root locks got released after successful -# root lock acquisition + *successful* graph operation""" -# fake_operation_id = big_uint64() -# fake_locked_root_ids = np.array((big_uint64(), big_uint64())) - -# cg = mocker.MagicMock() -# cg.id_client.create_operation_id = mocker.MagicMock(return_value=fake_operation_id) -# cg.client.lock_roots = mocker.MagicMock( -# return_value=(True, fake_locked_root_ids), -# side_effect=root_lock_tracker.add_locks, -# ) -# cg.client.unlock_root = mocker.MagicMock( -# return_value=True, side_effect=root_lock_tracker.remove_lock -# ) - -# with RootLock(cg, fake_locked_root_ids): -# assert fake_operation_id in root_lock_tracker.active_locks -# assert not root_lock_tracker.active_locks[fake_operation_id].difference( -# fake_locked_root_ids -# ) - -# assert not root_lock_tracker.active_locks[fake_operation_id] - - -# def test_failed_lock_acquisition(mocker): -# """Ensure that LockingError is raised when lock acquisition failed""" -# fake_operation_id = big_uint64() -# fake_locked_root_ids = np.array((big_uint64(), big_uint64())) - -# cg = mocker.MagicMock() -# cg.id_client.create_operation_id = mocker.MagicMock(return_value=fake_operation_id) -# cg.client.lock_roots = mocker.MagicMock( -# return_value=(False, fake_locked_root_ids), side_effect=None -# ) - -# with pytest.raises(exceptions.LockingError): -# with RootLock(cg, fake_locked_root_ids): -# pass - - -# def test_failed_graph_operation(mocker, root_lock_tracker): -# """Ensure that root locks got released after successful -# root lock acquisition + *unsuccessful* graph operation""" -# fake_operation_id = big_uint64() -# fake_locked_root_ids = np.array((big_uint64(), big_uint64())) - -# cg = mocker.MagicMock() -# cg.id_client.create_operation_id = mocker.MagicMock(return_value=fake_operation_id) -# cg.client.lock_roots = mocker.MagicMock( -# return_value=(True, fake_locked_root_ids), -# side_effect=root_lock_tracker.add_locks, -# ) -# cg.client.unlock_root = mocker.MagicMock( -# return_value=True, side_effect=root_lock_tracker.remove_lock -# ) - -# with pytest.raises(exceptions.PreconditionError): -# with RootLock(cg, fake_locked_root_ids): -# raise exceptions.PreconditionError("Something went wrong") - -# assert not root_lock_tracker.active_locks[fake_operation_id] +"""Integration tests for RootLock using real graph operations through the BigTable emulator. + +Tests lock acquisition, release, and behavior on operation failure. +""" + +from datetime import datetime, timedelta, UTC + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph import exceptions +from ..graph.locks import RootLock +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestRootLock: + @pytest.fixture() + def simple_graph(self, gen_graph): + """Build a 2-chunk graph with a single edge, return (cg, root_id).""" + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + return cg, root_id + + @pytest.mark.timeout(30) + def test_successful_lock_and_release(self, simple_graph): + """Lock acquired successfully inside context, released after exit.""" + cg, root_id = simple_graph + + with RootLock(cg, np.array([root_id])) as lock: + assert lock.lock_acquired + assert len(lock.locked_root_ids) > 0 + + # After exiting the context, the lock should be released. + # Verify by acquiring the same lock again — if it wasn't released, this would fail. + with RootLock(cg, np.array([root_id])) as lock2: + assert lock2.lock_acquired + + @pytest.mark.timeout(30) + def test_lock_released_on_exception(self, simple_graph): + """Lock should be released even when an exception occurs inside the context.""" + cg, root_id = simple_graph + + with pytest.raises(exceptions.PreconditionError): + with RootLock(cg, np.array([root_id])) as lock: + assert lock.lock_acquired + raise exceptions.PreconditionError("Simulated failure") + + # Lock should still be released — acquiring again should succeed + with RootLock(cg, np.array([root_id])) as lock2: + assert lock2.lock_acquired + + @pytest.mark.timeout(30) + def test_operation_with_lock_succeeds(self, simple_graph): + """A real graph operation (split) should succeed while holding the lock.""" + cg, root_id = simple_graph + + # Use the high-level API which acquires locks internally + result = cg.remove_edges( + "test_user", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + assert len(result.new_root_ids) == 2 + + # After operation, locks should be released — verify we can re-acquire + new_root = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + with RootLock(cg, np.array([new_root])) as lock: + assert lock.lock_acquired diff --git a/pychunkedgraph/tests/test_segmenthistory.py b/pychunkedgraph/tests/test_segmenthistory.py new file mode 100644 index 000000000..0ccb2ab55 --- /dev/null +++ b/pychunkedgraph/tests/test_segmenthistory.py @@ -0,0 +1,627 @@ +"""Tests for pychunkedgraph.graph.segmenthistory""" + +from datetime import datetime, timedelta, UTC + +import numpy as np +import pytest +from pandas import DataFrame + +from pychunkedgraph.graph.segmenthistory import ( + SegmentHistory, + LogEntry, + get_all_log_entries, +) + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestSegmentHistory: + def _build_and_merge(self, gen_graph): + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + new_root = result.new_root_ids[0] + return graph, new_root + + def test_init(self, gen_graph): + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + assert len(sh.root_ids) == 1 + + def test_lineage_graph(self, gen_graph): + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + lg = sh.lineage_graph + assert len(lg.nodes) > 0 + + def test_operation_ids(self, gen_graph): + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + ops = sh.operation_ids + assert len(ops) > 0 + + def test_past_operation_ids(self, gen_graph): + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + past_ops = sh.past_operation_ids(root_id=new_root) + assert isinstance(past_ops, np.ndarray) + + def test_collect_edited_sv_ids(self, gen_graph): + """After a merge, collect_edited_sv_ids should return supervoxel IDs.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + sv_ids = sh.collect_edited_sv_ids() + assert isinstance(sv_ids, np.ndarray) + assert sv_ids.dtype == np.uint64 + # The merge involved 2 supervoxels, so at least some IDs should appear + assert len(sv_ids) > 0 + + def test_collect_edited_sv_ids_with_root(self, gen_graph): + """collect_edited_sv_ids with an explicit root_id should also work.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + sv_ids = sh.collect_edited_sv_ids(root_id=new_root) + assert isinstance(sv_ids, np.ndarray) + assert len(sv_ids) > 0 + + def test_root_id_operation_id_dict(self, gen_graph): + """root_id_operation_id_dict maps each root_id in the lineage to its operation_id.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + d = sh.root_id_operation_id_dict + assert isinstance(d, dict) + # Should contain at least the new root + assert new_root in d + # Values should be integer operation IDs (including 0 for non-edit nodes) + for root_id, op_id in d.items(): + assert isinstance(root_id, (int, np.integer)) + assert isinstance(op_id, (int, np.integer)) + + def test_root_id_timestamp_dict(self, gen_graph): + """root_id_timestamp_dict maps each root_id to a timestamp.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + d = sh.root_id_timestamp_dict + assert isinstance(d, dict) + assert new_root in d + # Timestamps should be numeric (epoch seconds) or 0 for defaults + for root_id, ts in d.items(): + assert isinstance(ts, (int, float, np.integer, np.floating)) + + def test_last_edit_timestamp(self, gen_graph): + """last_edit_timestamp should return the timestamp for the given root.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + ts = sh.last_edit_timestamp(root_id=new_root) + # Should be a numeric timestamp (float epoch) or default value + assert isinstance(ts, (int, float, np.integer, np.floating)) + + def test_log_entry_api(self, gen_graph): + """After a merge, retrieve a log entry and verify its properties.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + op_ids = sh.operation_ids + # Filter out operation_id 0 (default for nodes without operations) + op_ids = op_ids[op_ids != 0] + assert len(op_ids) > 0, "Expected at least one real operation ID" + + entry = sh.log_entry(op_ids[0]) + assert isinstance(entry, LogEntry) + + # is_merge should be True since we performed a merge + assert entry.is_merge is True + + # user_id should be the user we passed to add_edges + assert entry.user_id == "TestUser" + + # log_type should be "merge" + assert entry.log_type == "merge" + + # edges_failsafe should return an array of SV IDs + ef = entry.edges_failsafe + assert isinstance(ef, np.ndarray) + assert len(ef) > 0 + + # __str__ should return a non-empty string + s = str(entry) + assert isinstance(s, str) + assert len(s) > 0 + + # __iter__ should yield attributes (user_id, log_type, root_ids, timestamp) + items = list(entry) + assert len(items) == 4 + + def test_tabular_changelogs(self, gen_graph): + """After a merge, tabular_changelogs should produce a DataFrame per root.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + changelogs = sh.tabular_changelogs + assert isinstance(changelogs, dict) + assert new_root in changelogs + + df = changelogs[new_root] + assert isinstance(df, DataFrame) + + # Verify expected columns are present + expected_columns = { + "operation_id", + "timestamp", + "user_id", + "before_root_ids", + "after_root_ids", + "is_merge", + "in_neuron", + "is_relevant", + } + assert expected_columns.issubset(set(df.columns)) + + # Should have at least one row (the merge we performed) + assert len(df) > 0 + + def test_tabular_changelog_single_root(self, gen_graph): + """tabular_changelog() with a single root should return the DataFrame directly.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + df = sh.tabular_changelog() + assert isinstance(df, DataFrame) + assert len(df) > 0 + + def test_operation_id_root_id_dict(self, gen_graph): + """operation_id_root_id_dict should be the inverse of root_id_operation_id_dict.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + d = sh.operation_id_root_id_dict + assert isinstance(d, dict) + # Each value should be a list of root IDs + for op_id, root_ids in d.items(): + assert isinstance(root_ids, list) + assert len(root_ids) > 0 + + def test_tabular_changelogs_filtered(self, gen_graph): + """After merge, tabular_changelogs_filtered returns dict with DataFrames + that have 'in_neuron' and 'is_relevant' columns.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + filtered = sh.tabular_changelogs_filtered + assert isinstance(filtered, dict) + assert new_root in filtered + df = filtered[new_root] + assert isinstance(df, DataFrame) + # The filtered method calls tabular_changelog(filtered=True) which + # drops "in_neuron" and "is_relevant" columns after filtering + assert "in_neuron" not in df.columns + assert "is_relevant" not in df.columns + + def test_tabular_changelog_with_explicit_root(self, gen_graph): + """tabular_changelog(root_id=new_root) should work same as without.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + df_implicit = sh.tabular_changelog() + df_explicit = sh.tabular_changelog(root_id=new_root) + assert isinstance(df_explicit, DataFrame) + assert len(df_explicit) == len(df_implicit) + # Same columns + assert set(df_explicit.columns) == set(df_implicit.columns) + + def test_change_log_summary(self, gen_graph): + """change_log_summary should return n_splits, n_mergers, user_info, etc.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + summary = sh.change_log_summary(root_id=new_root) + assert isinstance(summary, dict) + assert "n_splits" in summary + assert "n_mergers" in summary + assert "user_info" in summary + assert "operations_ids" in summary + assert "past_ids" in summary + assert summary["n_mergers"] >= 1 + + def test_past_future_id_mapping(self, gen_graph): + """past_future_id_mapping should return two dicts mapping past<->future root IDs.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + past_map, future_map = sh.past_future_id_mapping(root_id=new_root) + assert isinstance(past_map, dict) + assert isinstance(future_map, dict) + # The new_root should appear in past_map + assert int(new_root) in past_map + + +class TestLogEntryUnit: + """Pure unit tests for LogEntry class (no emulator needed).""" + + def test_merge_entry(self): + from pychunkedgraph.graph.attributes import OperationLogs + + row = { + OperationLogs.AddedEdge: np.array([[1, 2]], dtype=np.uint64), + OperationLogs.UserID: "alice", + OperationLogs.RootID: np.array([100], dtype=np.uint64), + OperationLogs.SourceID: np.array([1], dtype=np.uint64), + OperationLogs.SinkID: np.array([2], dtype=np.uint64), + OperationLogs.SourceCoordinate: np.array([0, 0, 0]), + OperationLogs.SinkCoordinate: np.array([1, 1, 1]), + } + ts = datetime.now(UTC) + entry = LogEntry(row, timestamp=ts) + assert entry.is_merge is True + assert entry.log_type == "merge" + assert entry.user_id == "alice" + assert entry.timestamp == ts + np.testing.assert_array_equal(entry.root_ids, np.array([100], dtype=np.uint64)) + np.testing.assert_array_equal( + entry.added_edges, np.array([[1, 2]], dtype=np.uint64) + ) + coords = entry.coordinates + assert coords.shape == (2, 3) + ef = entry.edges_failsafe + assert len(ef) > 0 + + def test_split_entry(self): + from pychunkedgraph.graph.attributes import OperationLogs + + row = { + OperationLogs.RemovedEdge: np.array([[3, 4]], dtype=np.uint64), + OperationLogs.UserID: "bob", + OperationLogs.RootID: np.array([200, 201], dtype=np.uint64), + OperationLogs.SourceID: np.array([3], dtype=np.uint64), + OperationLogs.SinkID: np.array([4], dtype=np.uint64), + OperationLogs.SourceCoordinate: np.array([0, 0, 0]), + OperationLogs.SinkCoordinate: np.array([1, 1, 1]), + } + ts = datetime.now(UTC) + entry = LogEntry(row, timestamp=ts) + assert entry.is_merge is False + assert entry.log_type == "split" + assert entry.user_id == "bob" + np.testing.assert_array_equal( + entry.removed_edges, np.array([[3, 4]], dtype=np.uint64) + ) + assert len(str(entry)) > 0 + assert len(list(entry)) == 4 + + def test_added_edges_on_split_raises(self): + from pychunkedgraph.graph.attributes import OperationLogs + + row = { + OperationLogs.RemovedEdge: np.array([[3, 4]], dtype=np.uint64), + OperationLogs.UserID: "bob", + OperationLogs.RootID: np.array([200], dtype=np.uint64), + } + entry = LogEntry(row, timestamp=datetime.now(UTC)) + with pytest.raises(AssertionError, match="Not a merge"): + entry.added_edges + + def test_removed_edges_on_merge_raises(self): + from pychunkedgraph.graph.attributes import OperationLogs + + row = { + OperationLogs.AddedEdge: np.array([[1, 2]], dtype=np.uint64), + OperationLogs.UserID: "alice", + OperationLogs.RootID: np.array([100], dtype=np.uint64), + } + entry = LogEntry(row, timestamp=datetime.now(UTC)) + with pytest.raises(AssertionError, match="Not a split"): + entry.removed_edges + + +class TestGetAllLogEntries: + def test_empty_graph(self, gen_graph): + """Create graph with no operations. get_all_log_entries should return empty list.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + # Create a chunk with vertices but perform no edits + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + entries = get_all_log_entries(graph) + assert isinstance(entries, list) + assert len(entries) == 0 + + def test_basic(self, gen_graph): + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + # get_all_log_entries iterates range(get_max_operation_id()) which + # may not include the actual operation ID; verify it doesn't crash + entries = get_all_log_entries(graph) + assert isinstance(entries, list) + # If entries exist, verify LogEntry API works + for entry in entries: + assert entry.log_type in ("merge", "split") + assert str(entry) + for _ in entry: + pass + + +class TestMergeLog: + """Tests for SegmentHistory.merge_log() method (lines 245-268).""" + + def _build_and_merge(self, gen_graph): + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + source_coords=[0, 0, 0], + sink_coords=[1, 1, 1], + ) + new_root = result.new_root_ids[0] + return graph, new_root + + def test_merge_log_with_root(self, gen_graph): + """merge_log(root_id=...) should return merge_edges and merge_edge_coords.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + result = sh.merge_log(root_id=new_root) + assert isinstance(result, dict) + assert "merge_edges" in result + assert "merge_edge_coords" in result + # We performed one merge, so there should be one entry + assert len(result["merge_edges"]) >= 1 + assert len(result["merge_edge_coords"]) >= 1 + + def test_merge_log_without_root(self, gen_graph): + """merge_log() without root_id should iterate over all root_ids.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + result = sh.merge_log() + assert isinstance(result, dict) + assert "merge_edges" in result + assert "merge_edge_coords" in result + + def test_merge_log_correct_for_wrong_coord_type_false(self, gen_graph): + """merge_log with correct_for_wrong_coord_type=False should skip coord hack.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + result = sh.merge_log(root_id=new_root, correct_for_wrong_coord_type=False) + assert isinstance(result, dict) + assert "merge_edges" in result + assert len(result["merge_edges"]) >= 1 + + +class TestPastOperationIdsExtended: + """Tests for SegmentHistory.past_operation_ids() (lines 270-292).""" + + def _build_and_merge(self, gen_graph): + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + new_root = result.new_root_ids[0] + return graph, new_root + + def test_past_operation_ids_without_root(self, gen_graph): + """past_operation_ids() without root_id iterates all root_ids.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + ops = sh.past_operation_ids() + assert isinstance(ops, np.ndarray) + # Should have at least the merge operation + assert len(ops) >= 1 + # 0 should not appear in result + assert 0 not in ops + + def test_past_operation_ids_with_root(self, gen_graph): + """past_operation_ids(root_id=...) should return operations for that root.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + ops = sh.past_operation_ids(root_id=new_root) + assert isinstance(ops, np.ndarray) + assert len(ops) >= 1 + + +class TestPastFutureIdMappingExtended: + """More thorough tests for past_future_id_mapping (lines 315-368).""" + + def _build_and_merge(self, gen_graph): + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + ) + new_root = result.new_root_ids[0] + return graph, new_root + + def test_past_future_id_mapping_without_root(self, gen_graph): + """past_future_id_mapping() without root_id iterates all root_ids.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + past_map, future_map = sh.past_future_id_mapping() + assert isinstance(past_map, dict) + assert isinstance(future_map, dict) + assert int(new_root) in past_map + + def test_past_future_id_mapping_values(self, gen_graph): + """Verify past_map values are arrays of past root IDs.""" + graph, new_root = self._build_and_merge(gen_graph) + sh = SegmentHistory(graph, new_root) + past_map, future_map = sh.past_future_id_mapping(root_id=new_root) + # past_map[int(new_root)] should point back to the original roots + past_ids = past_map[int(new_root)] + assert len(past_ids) >= 1 + # future_map should have entries for the past IDs + for past_id in past_ids: + if past_id in future_map: + assert future_map[past_id] is not None + + +class TestMergeSplitHistory: + """Tests involving merge followed by split to cover more branches.""" + + def _build_merge_and_split(self, gen_graph): + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_ts, + ) + # Merge + merge_result = graph.add_edges( + "TestUser", + [to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + affinities=[0.3], + source_coords=[0, 0, 0], + sink_coords=[1, 1, 1], + ) + merge_root = merge_result.new_root_ids[0] + + # Split + split_result = graph.remove_edges( + "TestUser", + source_ids=to_label(graph, 1, 0, 0, 0, 0), + sink_ids=to_label(graph, 1, 0, 0, 0, 1), + mincut=False, + ) + split_roots = split_result.new_root_ids + return graph, merge_root, split_roots + + def test_change_log_summary_with_split(self, gen_graph): + """change_log_summary after merge+split should show both operations.""" + graph, merge_root, split_roots = self._build_merge_and_split(gen_graph) + # Use the first split root as the segment history root + root = split_roots[0] + sh = SegmentHistory(graph, root) + summary = sh.change_log_summary(root_id=root) + assert isinstance(summary, dict) + assert "n_splits" in summary + assert "n_mergers" in summary + # There was at least a merge and a split in the history + total_ops = summary["n_splits"] + summary["n_mergers"] + assert total_ops >= 1 + + def test_past_operation_ids_after_split(self, gen_graph): + """past_operation_ids should include both merge and split operations.""" + graph, merge_root, split_roots = self._build_merge_and_split(gen_graph) + root = split_roots[0] + sh = SegmentHistory(graph, root) + ops = sh.past_operation_ids(root_id=root) + assert isinstance(ops, np.ndarray) + # Should include at least 2 operations (merge + split) + assert len(ops) >= 2 + + def test_merge_log_after_split(self, gen_graph): + """merge_log after split should still find the original merge.""" + graph, merge_root, split_roots = self._build_merge_and_split(gen_graph) + root = split_roots[0] + sh = SegmentHistory(graph, root) + result = sh.merge_log(root_id=root) + assert isinstance(result, dict) + # The original merge should still be in the history + assert len(result["merge_edges"]) >= 1 + + def test_tabular_changelog_after_split(self, gen_graph): + """tabular_changelog after merge+split should have multiple rows.""" + graph, merge_root, split_roots = self._build_merge_and_split(gen_graph) + root = split_roots[0] + sh = SegmentHistory(graph, root) + df = sh.tabular_changelog(root_id=root) + assert isinstance(df, DataFrame) + # Should have at least 2 rows (merge + split) + assert len(df) >= 2 + + def test_past_future_id_mapping_after_split(self, gen_graph): + """past_future_id_mapping after merge+split should track the lineage.""" + graph, merge_root, split_roots = self._build_merge_and_split(gen_graph) + root = split_roots[0] + sh = SegmentHistory(graph, root) + past_map, future_map = sh.past_future_id_mapping(root_id=root) + assert isinstance(past_map, dict) + assert isinstance(future_map, dict) + + def test_collect_edited_sv_ids_no_edits(self, gen_graph): + """collect_edited_sv_ids returns empty array when no edits exist for a root.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + sh = SegmentHistory(graph, root) + sv_ids = sh.collect_edited_sv_ids(root_id=root) + assert isinstance(sv_ids, np.ndarray) + assert sv_ids.dtype == np.uint64 + assert len(sv_ids) == 0 + + def test_change_log_summary_no_operations(self, gen_graph): + """change_log_summary with no operations should show zero splits/merges.""" + atomic_chunk_bounds = np.array([1, 1, 1]) + graph = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + sh = SegmentHistory(graph, root) + summary = sh.change_log_summary(root_id=root) + assert isinstance(summary, dict) + assert summary["n_splits"] == 0 + assert summary["n_mergers"] == 0 + assert len(summary["past_ids"]) == 0 diff --git a/pychunkedgraph/tests/test_serializers.py b/pychunkedgraph/tests/test_serializers.py new file mode 100644 index 000000000..59f1ed8c3 --- /dev/null +++ b/pychunkedgraph/tests/test_serializers.py @@ -0,0 +1,143 @@ +"""Tests for pychunkedgraph.graph.utils.serializers""" + +import numpy as np + +from pychunkedgraph.graph.utils.serializers import ( + _Serializer, + NumPyArray, + NumPyValue, + String, + JSON, + Pickle, + UInt64String, + pad_node_id, + serialize_uint64, + deserialize_uint64, + serialize_uint64s_to_regex, + serialize_key, + deserialize_key, +) +from pychunkedgraph.graph.utils import basetypes + + +class TestNumPyArray: + def test_roundtrip(self): + s = NumPyArray(dtype=basetypes.NODE_ID) + arr = np.array([1, 2, 3], dtype=basetypes.NODE_ID) + data = s.serialize(arr) + result = s.deserialize(data) + np.testing.assert_array_equal(result, arr) + + def test_with_shape(self): + s = NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)) + arr = np.array([[1, 2], [3, 4]], dtype=basetypes.NODE_ID) + data = s.serialize(arr) + result = s.deserialize(data) + assert result.shape == (2, 2) + np.testing.assert_array_equal(result, arr) + + def test_with_compression(self): + s = NumPyArray(dtype=basetypes.NODE_ID, compression_level=3) + arr = np.array([1, 2, 3, 4, 5], dtype=basetypes.NODE_ID) + data = s.serialize(arr) + result = s.deserialize(data) + np.testing.assert_array_equal(result, arr) + + def test_basetype(self): + s = NumPyArray(dtype=basetypes.NODE_ID) + assert s.basetype == basetypes.NODE_ID.type + + +class TestNumPyValue: + def test_roundtrip(self): + s = NumPyValue(dtype=basetypes.NODE_ID) + val = np.uint64(42) + data = s.serialize(val) + result = s.deserialize(data) + assert result == val + + +class TestString: + def test_roundtrip(self): + s = String() + data = s.serialize("hello") + assert s.deserialize(data) == "hello" + + +class TestJSON: + def test_roundtrip(self): + s = JSON() + obj = {"key": "value", "nested": [1, 2, 3]} + data = s.serialize(obj) + assert s.deserialize(data) == obj + + +class TestPickle: + def test_roundtrip(self): + s = Pickle() + obj = {"complex": [1, 2], "nested": {"a": True}} + data = s.serialize(obj) + assert s.deserialize(data) == obj + + +class TestUInt64String: + def test_roundtrip(self): + s = UInt64String() + val = np.uint64(12345) + data = s.serialize(val) + result = s.deserialize(data) + assert result == val + + +class TestPadNodeId: + def test_padding(self): + result = pad_node_id(np.uint64(42)) + assert len(result) == 20 + assert result == "00000000000000000042" + + def test_large_id(self): + result = pad_node_id(np.uint64(12345678901234567890)) + assert len(result) == 20 + + +class TestSerializeUint64: + def test_default(self): + result = serialize_uint64(np.uint64(42)) + assert isinstance(result, bytes) + assert b"00000000000000000042" in result + + def test_counter(self): + result = serialize_uint64(np.uint64(42), counter=True) + assert result.startswith(b"i") + + def test_fake_edges(self): + result = serialize_uint64(np.uint64(42), fake_edges=True) + assert result.startswith(b"f") + + +class TestDeserializeUint64: + def test_default(self): + serialized = serialize_uint64(np.uint64(42)) + result = deserialize_uint64(serialized) + assert result == np.uint64(42) + + def test_fake_edges(self): + serialized = serialize_uint64(np.uint64(42), fake_edges=True) + result = deserialize_uint64(serialized, fake_edges=True) + assert result == np.uint64(42) + + +class TestSerializeUint64sToRegex: + def test_multiple_ids(self): + ids = [np.uint64(1), np.uint64(2)] + result = serialize_uint64s_to_regex(ids) + assert isinstance(result, bytes) + assert b"|" in result + + +class TestSerializeKey: + def test_roundtrip(self): + key = "test_key_123" + serialized = serialize_key(key) + assert isinstance(serialized, bytes) + assert deserialize_key(serialized) == key diff --git a/pychunkedgraph/tests/test_split.py b/pychunkedgraph/tests/test_split.py new file mode 100644 index 000000000..6b814268a --- /dev/null +++ b/pychunkedgraph/tests/test_split.py @@ -0,0 +1,696 @@ +from datetime import datetime, timedelta, UTC +from math import inf +from warnings import warn + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph import ChunkedGraph +from ..graph import exceptions +from ..graph.misc import get_latest_roots +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestGraphSplit: + @pytest.mark.timeout(30) + def test_split_pair_same_chunk(self, gen_graph): + """ + Remove edge between existing RG supervoxels 1 and 2 (same chunk) + ┌─────┐ ┌─────┐ + │ A¹ │ │ A¹ │ + │ 1━2 │ => │ 1 2 │ + │ │ │ │ + └─────┘ └─────┘ + """ + + cg: ChunkedGraph = gen_graph(n_layers=2) + + # Preparation: Build Chunk A + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], + timestamp=fake_timestamp, + ) + + # Split + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 1), + sink_ids=to_label(cg, 1, 0, 0, 0, 0), + mincut=False, + ).new_root_ids + + # verify new state + assert len(new_root_ids) == 2 + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( + to_label(cg, 1, 0, 0, 0, 1) + ) + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True + ) + ) + assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 0, 0, 0, 1))], leaves_only=True + ) + ) + assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 1) in leaves + + # verify old state + cg.cache = None + assert cg.get_root( + to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp + ) == cg.get_root(to_label(cg, 1, 0, 0, 0, 1), time_stamp=fake_timestamp) + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp)], + leaves_only=True, + ) + ) + assert len(leaves) == 2 + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 0, 0, 0, 1) in leaves + + assert len(get_latest_roots(cg)) == 2 + assert len(get_latest_roots(cg, fake_timestamp)) == 1 + + def test_split_nonexisting_edge(self, gen_graph): + """ + ┌─────┐ ┌─────┐ + │ A¹ │ │ A¹ │ + │ 1━2 │ => │ 1━2 │ + │ | │ │ | │ + │ 3 │ │ 3 │ + └─────┘ └─────┘ + """ + cg = gen_graph(n_layers=2) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 0, 0, 0, 2), to_label(cg, 1, 0, 0, 0, 1), 0.5), + ], + timestamp=fake_timestamp, + ) + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 0, 0, 0, 2), + mincut=False, + ).new_root_ids + assert len(new_root_ids) == 1 + + @pytest.mark.timeout(30) + def test_split_pair_neighboring_chunks(self, gen_graph): + """ + Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) + ┌─────┬─────┐ ┌─────┬─────┐ + │ A¹ │ B¹ │ │ A¹ │ B¹ │ + │ 1━━┿━━2 │ => │ 1 │ 2 │ + │ │ │ │ │ │ + └─────┴─────┘ └─────┴─────┘ + """ + cg: ChunkedGraph = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 1.0)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 1.0)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 1, 0, 0, 0), + sink_ids=to_label(cg, 1, 0, 0, 0, 0), + mincut=False, + ).new_root_ids + + # verify new state + assert len(new_root_ids) == 2 + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( + to_label(cg, 1, 1, 0, 0, 0) + ) + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True + ) + ) + assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 1, 0, 0, 0))], leaves_only=True + ) + ) + assert len(leaves) == 1 and to_label(cg, 1, 1, 0, 0, 0) in leaves + + # verify old state + assert cg.get_root( + to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp + ) == cg.get_root(to_label(cg, 1, 1, 0, 0, 0), time_stamp=fake_timestamp) + leaves = np.unique( + cg.get_subgraph( + [cg.get_root(to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp)], + leaves_only=True, + ) + ) + assert len(leaves) == 2 + assert to_label(cg, 1, 0, 0, 0, 0) in leaves + assert to_label(cg, 1, 1, 0, 0, 0) in leaves + assert len(get_latest_roots(cg)) == 2 + assert len(get_latest_roots(cg, fake_timestamp)) == 1 + + @pytest.mark.timeout(30) + def test_split_verify_cross_chunk_edges(self, gen_graph): + """ + ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ + | │ A¹ │ B¹ │ | │ A¹ │ B¹ │ + | │ 1━━┿━━3 │ => | │ 1━━┿━━3 │ + | │ | │ │ | │ │ │ + | │ 2 │ │ | │ 2 │ │ + └─────┴─────┴─────┘ └─────┴─────┴─────┘ + """ + cg: ChunkedGraph = gen_graph(n_layers=4) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0), inf), + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1), 0.5), + ], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 2, 0, 0, 0)], + edges=[(to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], + timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 3, [1, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( + to_label(cg, 1, 1, 0, 0, 1) + ) + assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( + to_label(cg, 1, 2, 0, 0, 0) + ) + + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 1, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 1), + mincut=False, + ).new_root_ids + + assert len(new_root_ids) == 2 + + svs2 = cg.get_subgraph([new_root_ids[0]], leaves_only=True) + svs1 = cg.get_subgraph([new_root_ids[1]], leaves_only=True) + len_set = {1, 2} + assert len(svs1) in len_set + len_set.remove(len(svs1)) + assert len(svs2) in len_set + + # verify new state + assert len(new_root_ids) == 2 + assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) != cg.get_root( + to_label(cg, 1, 1, 0, 0, 1) + ) + assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( + to_label(cg, 1, 2, 0, 0, 0) + ) + + assert len(get_latest_roots(cg)) == 2 + assert len(get_latest_roots(cg, fake_timestamp)) == 1 + + @pytest.mark.timeout(30) + def test_split_verify_loop(self, gen_graph): + """ + ┌─────┬────────┬─────┐ ┌─────┬────────┬─────┐ + | │ A¹ │ B¹ │ | │ A¹ │ B¹ │ + | │ 4━━1━━┿━━5 │ => | │ 4 1━━┿━━5 │ + | │ / │ | │ | │ │ | │ + | │ 3 2━━┿━━6 │ | │ 3 2━━┿━━6 │ + └─────┴────────┴─────┘ └─────┴────────┴─────┘ + """ + cg: ChunkedGraph = gen_graph(n_layers=4) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[ + to_label(cg, 1, 1, 0, 0, 0), + to_label(cg, 1, 1, 0, 0, 1), + to_label(cg, 1, 1, 0, 0, 2), + to_label(cg, 1, 1, 0, 0, 3), + ], + edges=[ + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0), inf), + (to_label(cg, 1, 1, 0, 0, 1), to_label(cg, 1, 2, 0, 0, 1), inf), + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 2), 0.5), + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 3), 0.5), + ], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), + (to_label(cg, 1, 2, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 1), inf), + (to_label(cg, 1, 2, 0, 0, 1), to_label(cg, 1, 2, 0, 0, 0), 0.5), + ], + timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 3, [1, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( + to_label(cg, 1, 1, 0, 0, 1) + ) + assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( + to_label(cg, 1, 2, 0, 0, 0) + ) + + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 1, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 2), + mincut=False, + ).new_root_ids + assert len(new_root_ids) == 2 + + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 1, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 3), + mincut=False, + ).new_root_ids + assert len(new_root_ids) == 2 + + assert len(get_latest_roots(cg)) == 3 + assert len(get_latest_roots(cg, fake_timestamp)) == 1 + + @pytest.mark.timeout(30) + def test_split_pair_already_disconnected(self, gen_graph): + """ + Try to remove edge between already disconnected RG supervoxels 1 and 2 (same chunk). + ┌─────┐ ┌─────┐ + │ A¹ │ │ A¹ │ + │ 1 2 │ => │ 1 2 │ + │ │ │ │ + └─────┘ └─────┘ + """ + cg: ChunkedGraph = gen_graph(n_layers=2) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[], + timestamp=fake_timestamp, + ) + res_old = cg.client._table.read_rows() + res_old.consume_all() + + with pytest.raises(exceptions.PreconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 1), + sink_ids=to_label(cg, 1, 0, 0, 0, 0), + mincut=False, + ) + + res_new = cg.client._table.read_rows() + res_new.consume_all() + + if res_old.rows != res_new.rows: + warn( + "Rows were modified when splitting a pair of already disconnected supervoxels." + "While probably not an error, it is an unnecessary operation." + ) + + @pytest.mark.timeout(30) + def test_split_full_circle_to_triple_chain_same_chunk(self, gen_graph): + """ + Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (same chunk) + ┌─────┐ ┌─────┐ + │ A¹ │ │ A¹ │ + │ 1━2 │ => │ 1 2 │ + │ ┗3┛ │ │ ┗3┛ │ + └─────┘ └─────┘ + """ + cg: ChunkedGraph = gen_graph(n_layers=2) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[ + to_label(cg, 1, 0, 0, 0, 0), + to_label(cg, 1, 0, 0, 0, 1), + to_label(cg, 1, 0, 0, 0, 2), + ], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 2), 0.5), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2), 0.5), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.3), + ], + timestamp=fake_timestamp, + ) + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 1), + sink_ids=to_label(cg, 1, 0, 0, 0, 0), + mincut=False, + ).new_root_ids + + # verify new state + assert len(new_root_ids) == 1 + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_ids[0] + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_ids[0] + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 2)) == new_root_ids[0] + leaves = np.unique(cg.get_subgraph([new_root_ids[0]], leaves_only=True)) + assert len(leaves) == 3 + + # verify old state + old_root_id = cg.get_root( + to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp + ) + assert new_root_ids[0] != old_root_id + assert len(get_latest_roots(cg)) == 1 + assert len(get_latest_roots(cg, fake_timestamp)) == 1 + + @pytest.mark.timeout(30) + def test_split_full_circle_to_triple_chain_neighboring_chunks(self, gen_graph): + """ + Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection + ┌─────┬─────┐ ┌─────┬─────┐ + │ A¹ │ B¹ │ │ A¹ │ B¹ │ + │ 1━━┿━━2 │ => │ 1 │ 2 │ + │ ┗3━┿━━┛ │ │ ┗3━┿━━┛ │ + └─────┴─────┘ └─────┴─────┘ + """ + cg: ChunkedGraph = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 0), 0.5), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.3), + ], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[ + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.3), + ], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + new_root_ids = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 1, 0, 0, 0), + sink_ids=to_label(cg, 1, 0, 0, 0, 0), + mincut=False, + ).new_root_ids + + # verify new state + assert len(new_root_ids) == 1 + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_ids[0] + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_ids[0] + assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == new_root_ids[0] + leaves = np.unique(cg.get_subgraph([new_root_ids[0]], leaves_only=True)) + assert len(leaves) == 3 + + # verify old state + old_root_id = cg.get_root( + to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp + ) + assert new_root_ids[0] != old_root_id + assert len(get_latest_roots(cg)) == 1 + assert len(get_latest_roots(cg, fake_timestamp)) == 1 + + @pytest.mark.timeout(30) + def test_split_same_node(self, gen_graph): + """ + Try to remove (non-existing) edge between RG supervoxel 1 and itself + ┌─────┐ + │ A¹ │ + │ 1 │ => Reject + │ │ + └─────┘ + """ + cg: ChunkedGraph = gen_graph(n_layers=2) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + res_old = cg.client._table.read_rows() + res_old.consume_all() + with pytest.raises(exceptions.PreconditionError): + cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 0, 0, 0, 0), + mincut=False, + ) + + res_new = cg.client._table.read_rows() + res_new.consume_all() + assert res_new.rows == res_old.rows + + @pytest.mark.timeout(30) + def test_split_pair_abstract_nodes(self, gen_graph): + """ + Try to remove (non-existing) edge between RG supervoxel 1 and abstract node "2" + => Reject + """ + + cg: ChunkedGraph = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, vertices=[to_label(cg, 1, 0, 0, 0, 0)], edges=[], timestamp=fake_timestamp, + ) + create_chunk( + cg, vertices=[to_label(cg, 1, 1, 0, 0, 0)], edges=[], timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + res_old = cg.client._table.read_rows() + res_old.consume_all() + with pytest.raises((exceptions.PreconditionError, AssertionError)): + cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 2, 1, 0, 0, 1), + mincut=False, + ) + + res_new = cg.client._table.read_rows() + res_new.consume_all() + assert res_new.rows == res_old.rows + + @pytest.mark.timeout(30) + def test_diagonal_connections(self, gen_graph): + """ + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 2━1━┿━━3 │ + │ / │ │ + ┌─────┬─────┐ + │ | │ │ + │ 4━━┿━━5 │ + │ C¹ │ D¹ │ + └─────┴─────┘ + """ + cg: ChunkedGraph = gen_graph(n_layers=3) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], + edges=[ + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), + (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf), + ], + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 1, 0, 0)], + edges=[ + (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 1, 1, 0, 0), inf), + (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf), + ], + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 1, 0, 0)], + edges=[(to_label(cg, 1, 1, 1, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf)], + ) + add_parent_chunk(cg, 3, [0, 0, 0], n_threads=1) + + rr = cg.range_read_chunk(chunk_id=cg.get_chunk_id(layer=3, x=0, y=0, z=0)) + root_ids_t0 = list(rr.keys()) + assert len(root_ids_t0) == 1 + + new_roots = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 0, 0, 0, 1), + mincut=False, + ).new_root_ids + + assert len(new_roots) == 2 + assert cg.get_root(to_label(cg, 1, 1, 1, 0, 0)) == cg.get_root( + to_label(cg, 1, 0, 1, 0, 0) + ) + assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == cg.get_root( + to_label(cg, 1, 0, 0, 0, 0) + ) + + +class TestGraphSplitSkipConnections: + """Tests for skip connection behavior during split operations.""" + + @pytest.mark.timeout(120) + def test_split_multi_layer_hierarchy_correctness(self, gen_graph): + """ + After a split, verify the full parent chain from each supervoxel + to its new root is valid. + + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1━━┿━━2 │ => split => two separate roots + │ │ │ + └─────┴─────┘ + """ + cg = gen_graph(n_layers=4) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + result = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + assert len(result.new_root_ids) == 2 + + # Verify parent chain for both supervoxels + for sv in [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)]: + parents = cg.get_root(sv, get_all_parents=True) + prev_layer = 1 + for p in parents: + layer = cg.get_chunk_layer(p) + assert layer > prev_layer, ( + f"Parent chain not monotonically increasing: {prev_layer} -> {layer}" + ) + prev_layer = layer + # Last parent should be one of the new roots + assert parents[-1] in result.new_root_ids + + @pytest.mark.timeout(120) + def test_split_creates_isolated_components_with_skip_connections(self, gen_graph): + """ + After splitting a 3-node chain in a multi-layer graph, the isolated + node should still have a valid root. + + ┌─────┬─────┬─────┐ + │ A¹ │ B¹ │ C¹ │ + │ 1━━┿━━2━━┿━━3 │ => split 1-2 => 1 becomes isolated, 2-3 stay connected + └─────┴─────┴─────┘ + """ + cg = gen_graph(n_layers=4) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[ + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5), + (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0), inf), + ], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 2, 0, 0, 0)], + edges=[(to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 3, [1, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # All three should share a root before split + root_pre = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + assert root_pre == cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) + assert root_pre == cg.get_root(to_label(cg, 1, 2, 0, 0, 0)) + + # Split 1 from 2 + result = cg.remove_edges( + "Jane Doe", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + assert len(result.new_root_ids) == 2 + + # Node 1 should be isolated, nodes 2 and 3 should share a root + root1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + root2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) + root3 = cg.get_root(to_label(cg, 1, 2, 0, 0, 0)) + assert root1 != root2 + assert root2 == root3 + + # Both roots should be valid + assert root1 in result.new_root_ids + assert root2 in result.new_root_ids diff --git a/pychunkedgraph/tests/test_stale_edges.py b/pychunkedgraph/tests/test_stale_edges.py new file mode 100644 index 000000000..bf160bdcc --- /dev/null +++ b/pychunkedgraph/tests/test_stale_edges.py @@ -0,0 +1,439 @@ +"""Integration tests for stale edge detection and resolution. + +Tests get_stale_nodes() and get_new_nodes() from stale.py using real graph +operations through the BigTable emulator. +""" + +from datetime import datetime, timedelta, UTC + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..graph.edges.stale import get_stale_nodes, get_new_nodes +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestStaleEdges: + @pytest.mark.timeout(30) + def test_stale_nodes_detected_after_split(self, gen_graph): + """ + After a split, the old L2 parent IDs become stale. + get_stale_nodes should identify them. + + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1━━┿━━2 │ + │ │ │ + └─────┴─────┘ + """ + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # Get old parents before edit + old_root = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + + # Split + cg.remove_edges( + "test_user", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + + # The old root should now be stale + stale = get_stale_nodes(cg, [old_root]) + assert old_root in stale + + @pytest.mark.timeout(30) + def test_no_stale_nodes_for_current_ids(self, gen_graph): + """ + Current (post-edit) node IDs should not be flagged as stale. + + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1━━┿━━2 │ + │ │ │ + └─────┴─────┘ + """ + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # Split + cg.remove_edges( + "test_user", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + + # Current roots should not be stale + new_root_1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + new_root_2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) + stale = get_stale_nodes(cg, [new_root_1, new_root_2]) + assert new_root_1 not in stale + assert new_root_2 not in stale + + @pytest.mark.timeout(30) + def test_get_new_nodes_resolves_to_correct_layer(self, gen_graph): + """ + get_new_nodes should follow the parent chain from a supervoxel + to the correct layer and return the current node at that layer. + + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1━━┿━━2 │ + │ │ │ + └─────┴─────┘ + """ + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # Get L2 parent of SV 1 before edit + sv1 = to_label(cg, 1, 0, 0, 0, 0) + old_l2_parent = cg.get_parent(sv1) + + # Split + cg.remove_edges( + "test_user", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + + # get_new_nodes should resolve SV to its current L2 parent + new_l2 = get_new_nodes(cg, np.array([sv1], dtype=np.uint64), layer=2) + current_l2_parent = cg.get_parent(sv1) + assert new_l2[0] == current_l2_parent + + @pytest.mark.timeout(30) + def test_no_stale_nodes_in_unaffected_region(self, gen_graph): + """ + Nodes not involved in an edit should not be flagged as stale. + + ┌─────┬─────┬─────┐ + │ A¹ │ B¹ │ C¹ │ + │ 1━━┿━━2 │ 3 │ + │ │ │ │ + └─────┴─────┴─────┘ + """ + cg = gen_graph(n_layers=4) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + # Chunk C - isolated node, not connected to A or B + create_chunk( + cg, + vertices=[to_label(cg, 1, 2, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 3, [1, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + # Get the isolated node's root before edit + isolated_root = cg.get_root(to_label(cg, 1, 2, 0, 0, 0)) + + # Split nodes 1 and 2 + cg.remove_edges( + "test_user", + source_ids=to_label(cg, 1, 0, 0, 0, 0), + sink_ids=to_label(cg, 1, 1, 0, 0, 0), + mincut=False, + ) + + # The isolated root should not be stale — it was unaffected + stale = get_stale_nodes(cg, [isolated_root]) + assert isolated_root not in stale + + @pytest.mark.timeout(30) + def test_get_new_nodes_returns_self_for_non_stale(self, gen_graph): + """ + For freshly created nodes with no edits, get_new_nodes should return + the nodes themselves (identity mapping). + + ┌─────┐ + │ A¹ │ + │ 1 │ + │ │ + └─────┘ + """ + cg = gen_graph(n_layers=4) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + sv = to_label(cg, 1, 0, 0, 0, 0) + l2_parent = cg.get_parent(sv) + + # get_new_nodes at layer 2 should return the same L2 parent + result = get_new_nodes(cg, np.array([sv], dtype=np.uint64), layer=2) + assert result[0] == l2_parent + + @pytest.mark.timeout(30) + def test_get_stale_nodes_empty_for_fresh_graph(self, gen_graph): + """ + In a freshly built graph with no edits, no nodes should be stale. + + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1━━┿━━2 │ + │ │ │ + └─────┴─────┘ + """ + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + root = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) + l2_0 = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) + l2_1 = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) + + # No edits have been performed, so all nodes should be non-stale + stale = get_stale_nodes(cg, [root, l2_0, l2_1]) + assert len(stale) == 0 + + @pytest.mark.timeout(30) + def test_get_new_nodes_multiple_svs(self, gen_graph): + """ + get_new_nodes with multiple supervoxels should return an array + of the same length, each mapped to its current L2 parent. + + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1━━┿━━2 │ + │ │ │ + └─────┴─────┘ + """ + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + sv1 = to_label(cg, 1, 0, 0, 0, 0) + sv2 = to_label(cg, 1, 1, 0, 0, 0) + svs = np.array([sv1, sv2], dtype=np.uint64) + + result = get_new_nodes(cg, svs, layer=2) + assert result.shape == (2,) + # Each SV should map to its L2 parent + assert result[0] == cg.get_parent(sv1) + assert result[1] == cg.get_parent(sv2) + + @pytest.mark.timeout(30) + def test_get_new_nodes_with_duplicate_svs(self, gen_graph): + """ + get_new_nodes should handle duplicate SVs correctly, + returning the same result for duplicate inputs. + + ┌─────┐ + │ A¹ │ + │ 1 │ + │ │ + └─────┘ + """ + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + sv = to_label(cg, 1, 0, 0, 0, 0) + svs = np.array([sv, sv, sv], dtype=np.uint64) + + result = get_new_nodes(cg, svs, layer=2) + assert result.shape == (3,) + # All should map to the same L2 parent + expected = cg.get_parent(sv) + assert np.all(result == expected) + + @pytest.mark.timeout(30) + def test_get_stale_nodes_with_l2_ids_after_merge(self, gen_graph): + """ + After a merge, the old L2 IDs should become stale. + + ┌─────┐ + │ A¹ │ + │ 1 2 │ (isolated, then merged) + │ │ + └─────┘ + """ + atomic_chunk_bounds = np.array([1, 1, 1]) + cg = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + sv0 = to_label(cg, 1, 0, 0, 0, 0) + sv1 = to_label(cg, 1, 0, 0, 0, 1) + + create_chunk( + cg, + vertices=[sv0, sv1], + edges=[], + timestamp=fake_timestamp, + ) + + # Get L2 parents before merge (each SV has its own L2 parent) + old_l2_0 = cg.get_parent(sv0) + old_l2_1 = cg.get_parent(sv1) + + # Merge + cg.add_edges( + "test_user", + [sv0, sv1], + affinities=[0.3], + ) + + # Old L2 parents should now be stale + stale = get_stale_nodes(cg, [old_l2_0, old_l2_1]) + assert old_l2_0 in stale or old_l2_1 in stale + + @pytest.mark.timeout(30) + def test_get_stale_nodes_returns_numpy_array(self, gen_graph): + """ + get_stale_nodes should always return a numpy ndarray, even when + no nodes are stale. + + ┌─────┐ + │ A¹ │ + │ 1 │ + │ │ + └─────┘ + """ + atomic_chunk_bounds = np.array([1, 1, 1]) + cg = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + sv0 = to_label(cg, 1, 0, 0, 0, 0) + create_chunk( + cg, + vertices=[sv0], + edges=[], + timestamp=fake_timestamp, + ) + + root = cg.get_root(sv0) + stale = get_stale_nodes(cg, [root]) + assert isinstance(stale, np.ndarray) + + @pytest.mark.timeout(30) + def test_get_new_nodes_at_root_layer(self, gen_graph): + """ + get_new_nodes called with layer=root_layer should return the root node. + + ┌─────┐ + │ A¹ │ + │ 1 │ + │ │ + └─────┘ + """ + cg = gen_graph(n_layers=4) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + sv = to_label(cg, 1, 0, 0, 0, 0) + create_chunk( + cg, + vertices=[sv], + edges=[], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + add_parent_chunk(cg, 4, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + + root = cg.get_root(sv) + root_layer = cg.get_chunk_layer(root) + + result = get_new_nodes(cg, np.array([sv], dtype=np.uint64), layer=root_layer) + assert result.shape == (1,) + assert result[0] == root diff --git a/pychunkedgraph/tests/test_subgraph.py b/pychunkedgraph/tests/test_subgraph.py new file mode 100644 index 000000000..e9ca7cd66 --- /dev/null +++ b/pychunkedgraph/tests/test_subgraph.py @@ -0,0 +1,112 @@ +"""Tests for pychunkedgraph.graph.subgraph""" + +from datetime import datetime, timedelta, UTC +from math import inf + +import numpy as np +import pytest + +from pychunkedgraph.graph.subgraph import SubgraphProgress, get_subgraph_nodes + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestSubgraphProgress: + def test_init(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + progress = SubgraphProgress( + graph.meta, + node_ids=[root], + return_layers=[2], + serializable=False, + ) + assert not progress.done_processing() + + def test_serializable_keys(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0)], + edges=[], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + progress = SubgraphProgress( + graph.meta, + node_ids=[root], + return_layers=[2], + serializable=True, + ) + # Keys should be strings when serializable=True + key = progress.get_dict_key(root) + assert isinstance(key, str) + + +class TestGetSubgraphNodes: + def _build_graph(self, gen_graph): + graph = gen_graph(n_layers=4) + fake_ts = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + graph, + vertices=[to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1)], + edges=[ + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 1), 0.5), + (to_label(graph, 1, 0, 0, 0, 0), to_label(graph, 1, 1, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + create_chunk( + graph, + vertices=[to_label(graph, 1, 1, 0, 0, 0)], + edges=[ + (to_label(graph, 1, 1, 0, 0, 0), to_label(graph, 1, 0, 0, 0, 0), inf), + ], + timestamp=fake_ts, + ) + add_parent_chunk(graph, 3, [0, 0, 0], n_threads=1) + add_parent_chunk(graph, 4, [0, 0, 0], n_threads=1) + return graph + + def test_single_node(self, gen_graph): + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = get_subgraph_nodes(graph, root) + assert isinstance(result, dict) + assert 2 in result + + def test_return_flattened(self, gen_graph): + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = get_subgraph_nodes(graph, root, return_flattened=True) + assert isinstance(result, np.ndarray) + assert len(result) > 0 + + def test_multiple_nodes(self, gen_graph): + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = get_subgraph_nodes(graph, [root]) + assert root in result + + def test_serializable(self, gen_graph): + graph = self._build_graph(gen_graph) + root = graph.get_root(to_label(graph, 1, 0, 0, 0, 0)) + result = get_subgraph_nodes(graph, root, serializable=True) + # Keys should be layer ints, values should be arrays + assert isinstance(result, dict) diff --git a/pychunkedgraph/tests/test_types.py b/pychunkedgraph/tests/test_types.py new file mode 100644 index 000000000..ed6f5212b --- /dev/null +++ b/pychunkedgraph/tests/test_types.py @@ -0,0 +1,33 @@ +"""Tests for pychunkedgraph.graph.types""" + +import numpy as np + +from pychunkedgraph.graph.types import empty_1d, empty_2d, Agglomeration +from pychunkedgraph.graph.utils import basetypes + + +class TestEmptyArrays: + def test_empty_1d_shape_and_dtype(self): + assert empty_1d.shape == (0,) + assert empty_1d.dtype == basetypes.NODE_ID + + def test_empty_2d_shape_and_dtype(self): + assert empty_2d.shape == (0, 2) + assert empty_2d.dtype == basetypes.NODE_ID + + +class TestAgglomeration: + def test_defaults(self): + agg = Agglomeration(node_id=np.uint64(1)) + assert agg.node_id == np.uint64(1) + assert agg.supervoxels.shape == (0,) + assert agg.in_edges.shape == (0, 2) + assert agg.out_edges.shape == (0, 2) + assert agg.cross_edges.shape == (0, 2) + assert agg.cross_edges_d == {} + + def test_custom_fields(self): + svs = np.array([10, 20], dtype=basetypes.NODE_ID) + agg = Agglomeration(node_id=np.uint64(5), supervoxels=svs) + assert agg.node_id == np.uint64(5) + np.testing.assert_array_equal(agg.supervoxels, svs) diff --git a/pychunkedgraph/tests/test_uncategorized.py b/pychunkedgraph/tests/test_uncategorized.py deleted file mode 100644 index 93c41158d..000000000 --- a/pychunkedgraph/tests/test_uncategorized.py +++ /dev/null @@ -1,3544 +0,0 @@ -import collections -import os -import subprocess -import sys -from time import sleep -from datetime import datetime, timedelta -from functools import partial -from math import inf -from signal import SIGTERM -from unittest import mock -from warnings import warn - -import numpy as np -import pytest -from google.auth import credentials -from google.cloud import bigtable -from grpc._channel import _Rendezvous - -from .helpers import ( - bigtable_emulator, - create_chunk, - gen_graph, - gen_graph_simplequerytest, - to_label, - sv_data, -) -from ..graph import types -from ..graph import attributes -from ..graph import exceptions -from ..graph import chunkedgraph -from ..graph.edges import Edges -from ..graph.utils import basetypes -from ..graph.misc import get_delta_roots -from ..graph.cutting import run_multicut -from ..graph.lineage import get_root_id_history -from ..graph.lineage import get_future_root_ids -from ..graph.utils.serializers import serialize_uint64 -from ..graph.utils.serializers import deserialize_uint64 -from ..ingest.create.abstract_layers import add_layer - - -class TestGraphNodeConversion: - @pytest.mark.timeout(30) - def test_compute_bitmasks(self): - pass - - @pytest.mark.timeout(30) - def test_node_conversion(self, gen_graph): - cg = gen_graph(n_layers=10) - - node_id = cg.get_node_id(np.uint64(4), layer=2, x=3, y=1, z=0) - assert cg.get_chunk_layer(node_id) == 2 - assert np.all(cg.get_chunk_coordinates(node_id) == np.array([3, 1, 0])) - - chunk_id = cg.get_chunk_id(layer=2, x=3, y=1, z=0) - assert cg.get_chunk_layer(chunk_id) == 2 - assert np.all(cg.get_chunk_coordinates(chunk_id) == np.array([3, 1, 0])) - - assert cg.get_chunk_id(node_id=node_id) == chunk_id - assert cg.get_node_id(np.uint64(4), chunk_id=chunk_id) == node_id - - @pytest.mark.timeout(30) - def test_node_id_adjacency(self, gen_graph): - cg = gen_graph(n_layers=10) - - assert cg.get_node_id(np.uint64(0), layer=2, x=3, y=1, z=0) + np.uint64( - 1 - ) == cg.get_node_id(np.uint64(1), layer=2, x=3, y=1, z=0) - - assert cg.get_node_id( - np.uint64(2 ** 53 - 2), layer=10, x=0, y=0, z=0 - ) + np.uint64(1) == cg.get_node_id( - np.uint64(2 ** 53 - 1), layer=10, x=0, y=0, z=0 - ) - - @pytest.mark.timeout(30) - def test_serialize_node_id(self, gen_graph): - cg = gen_graph(n_layers=10) - - assert serialize_uint64( - cg.get_node_id(np.uint64(0), layer=2, x=3, y=1, z=0) - ) < serialize_uint64(cg.get_node_id(np.uint64(1), layer=2, x=3, y=1, z=0)) - - assert serialize_uint64( - cg.get_node_id(np.uint64(2 ** 53 - 2), layer=10, x=0, y=0, z=0) - ) < serialize_uint64( - cg.get_node_id(np.uint64(2 ** 53 - 1), layer=10, x=0, y=0, z=0) - ) - - @pytest.mark.timeout(30) - def test_deserialize_node_id(self): - pass - - @pytest.mark.timeout(30) - def test_serialization_roundtrip(self): - pass - - @pytest.mark.timeout(30) - def test_serialize_valid_label_id(self): - label = np.uint64(0x01FF031234556789) - assert deserialize_uint64(serialize_uint64(label)) == label - - -class TestGraphBuild: - @pytest.mark.timeout(30) - def test_build_single_node(self, gen_graph): - """ - Create graph with single RG node 1 in chunk A - ┌─────┐ - │ A¹ │ - │ 1 │ - │ │ - └─────┘ - """ - - cg = gen_graph(n_layers=2) - # Add Chunk A - create_chunk(cg, vertices=[to_label(cg, 1, 0, 0, 0, 0)]) - - res = cg.client._table.read_rows() - res.consume_all() - - assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows - parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) - assert parent == to_label(cg, 2, 0, 0, 0, 1) - - # Check for the one Level 2 node that should have been created. - assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows - atomic_cross_edge_d = cg.get_atomic_cross_edges( - np.array([to_label(cg, 2, 0, 0, 0, 1)], dtype=basetypes.NODE_ID) - ) - attr = attributes.Hierarchy.Child - row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] - children = attr.deserialize(row[attr.key][0].value) - - for aces in atomic_cross_edge_d.values(): - assert len(aces) == 0 - - assert len(children) == 1 and children[0] == to_label(cg, 1, 0, 0, 0, 0) - # Make sure there are not any more entries in the table - # include counters, meta and version rows - assert len(res.rows) == 1 + 1 + 1 + 1 + 1 - - @pytest.mark.timeout(30) - def test_build_single_edge(self, gen_graph): - """ - Create graph with edge between RG supervoxels 1 and 2 (same chunk) - ┌─────┐ - │ A¹ │ - │ 1━2 │ - │ │ - └─────┘ - """ - - cg = gen_graph(n_layers=2) - - # Add Chunk A - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], - edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], - ) - - res = cg.client._table.read_rows() - res.consume_all() - - assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows - parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) - assert parent == to_label(cg, 2, 0, 0, 0, 1) - - assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 1)) in res.rows - parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 1)) - assert parent == to_label(cg, 2, 0, 0, 0, 1) - - # Check for the one Level 2 node that should have been created. - assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows - - atomic_cross_edge_d = cg.get_atomic_cross_edges( - np.array([to_label(cg, 2, 0, 0, 0, 1)], dtype=basetypes.NODE_ID) - ) - attr = attributes.Hierarchy.Child - row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] - children = attr.deserialize(row[attr.key][0].value) - - for aces in atomic_cross_edge_d.values(): - assert len(aces) == 0 - assert ( - len(children) == 2 - and to_label(cg, 1, 0, 0, 0, 0) in children - and to_label(cg, 1, 0, 0, 0, 1) in children - ) - - # Make sure there are not any more entries in the table - # include counters, meta and version rows - assert len(res.rows) == 2 + 1 + 1 + 1 + 1 - - @pytest.mark.timeout(30) - def test_build_single_across_edge(self, gen_graph): - """ - Create graph with edge between RG supervoxels 1 and 2 (neighboring chunks) - ┌─────┌─────┐ - │ A¹ │ B¹ │ - │ 1━━┿━━2 │ - │ │ │ - └─────┴─────┘ - """ - - atomic_chunk_bounds = np.array([2, 1, 1]) - cg = gen_graph(n_layers=3, atomic_chunk_bounds=atomic_chunk_bounds) - - # Chunk A - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], - ) - - # Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], - ) - - add_layer(cg, 3, [0, 0, 0], n_threads=1) - res = cg.client._table.read_rows() - res.consume_all() - - assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows - parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) - assert parent == to_label(cg, 2, 0, 0, 0, 1) - - assert serialize_uint64(to_label(cg, 1, 1, 0, 0, 0)) in res.rows - parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) - assert parent == to_label(cg, 2, 1, 0, 0, 1) - - # Check for the two Level 2 nodes that should have been created. Since Level 2 has the same - # dimensions as Level 1, we also expect them to be in different chunks - # to_label(cg, 2, 0, 0, 0, 1) - assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows - atomic_cross_edge_d = cg.get_atomic_cross_edges( - np.array([to_label(cg, 2, 0, 0, 0, 1)], dtype=basetypes.NODE_ID) - ) - atomic_cross_edge_d = atomic_cross_edge_d[ - np.uint64(to_label(cg, 2, 0, 0, 0, 1)) - ] - - attr = attributes.Hierarchy.Child - row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] - children = attr.deserialize(row[attr.key][0].value) - - test_ace = np.array( - [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], - dtype=np.uint64, - ) - assert len(atomic_cross_edge_d[2]) == 1 - assert test_ace in atomic_cross_edge_d[2] - assert len(children) == 1 and to_label(cg, 1, 0, 0, 0, 0) in children - - # to_label(cg, 2, 1, 0, 0, 1) - assert serialize_uint64(to_label(cg, 2, 1, 0, 0, 1)) in res.rows - atomic_cross_edge_d = cg.get_atomic_cross_edges( - np.array([to_label(cg, 2, 1, 0, 0, 1)], dtype=basetypes.NODE_ID) - ) - atomic_cross_edge_d = atomic_cross_edge_d[ - np.uint64(to_label(cg, 2, 1, 0, 0, 1)) - ] - - attr = attributes.Hierarchy.Child - row = res.rows[serialize_uint64(to_label(cg, 2, 1, 0, 0, 1))].cells["0"] - children = attr.deserialize(row[attr.key][0].value) - - test_ace = np.array( - [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], - dtype=np.uint64, - ) - assert len(atomic_cross_edge_d[2]) == 1 - assert test_ace in atomic_cross_edge_d[2] - assert len(children) == 1 and to_label(cg, 1, 1, 0, 0, 0) in children - - # Check for the one Level 3 node that should have been created. This one combines the two - # connected components of Level 2 - # to_label(cg, 3, 0, 0, 0, 1) - assert serialize_uint64(to_label(cg, 3, 0, 0, 0, 1)) in res.rows - - attr = attributes.Hierarchy.Child - row = res.rows[serialize_uint64(to_label(cg, 3, 0, 0, 0, 1))].cells["0"] - children = attr.deserialize(row[attr.key][0].value) - assert ( - len(children) == 2 - and to_label(cg, 2, 0, 0, 0, 1) in children - and to_label(cg, 2, 1, 0, 0, 1) in children - ) - - # Make sure there are not any more entries in the table - # include counters, meta and version rows - assert len(res.rows) == 2 + 2 + 1 + 3 + 1 + 1 - - @pytest.mark.timeout(30) - def test_build_single_edge_and_single_across_edge(self, gen_graph): - """ - Create graph with edge between RG supervoxels 1 and 2 (same chunk) - and edge between RG supervoxels 1 and 3 (neighboring chunks) - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 2━1━┿━━3 │ - │ │ │ - └─────┴─────┘ - """ - - cg = gen_graph(n_layers=3) - - # Chunk A - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], - edges=[ - (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), - (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), - ], - ) - - # Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], - ) - - add_layer(cg, 3, np.array([0, 0, 0]), n_threads=1) - res = cg.client._table.read_rows() - res.consume_all() - - assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows - parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) - assert parent == to_label(cg, 2, 0, 0, 0, 1) - - # to_label(cg, 1, 0, 0, 0, 1) - assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 1)) in res.rows - parent = cg.get_parent(to_label(cg, 1, 0, 0, 0, 1)) - assert parent == to_label(cg, 2, 0, 0, 0, 1) - - # to_label(cg, 1, 1, 0, 0, 0) - assert serialize_uint64(to_label(cg, 1, 1, 0, 0, 0)) in res.rows - parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) - assert parent == to_label(cg, 2, 1, 0, 0, 1) - - # Check for the two Level 2 nodes that should have been created. Since Level 2 has the same - # dimensions as Level 1, we also expect them to be in different chunks - # to_label(cg, 2, 0, 0, 0, 1) - assert serialize_uint64(to_label(cg, 2, 0, 0, 0, 1)) in res.rows - row = res.rows[serialize_uint64(to_label(cg, 2, 0, 0, 0, 1))].cells["0"] - atomic_cross_edge_d = cg.get_atomic_cross_edges([to_label(cg, 2, 0, 0, 0, 1)]) - atomic_cross_edge_d = atomic_cross_edge_d[ - np.uint64(to_label(cg, 2, 0, 0, 0, 1)) - ] - column = attributes.Hierarchy.Child - children = column.deserialize(row[column.key][0].value) - - test_ace = np.array( - [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], - dtype=np.uint64, - ) - assert len(atomic_cross_edge_d[2]) == 1 - assert test_ace in atomic_cross_edge_d[2] - assert ( - len(children) == 2 - and to_label(cg, 1, 0, 0, 0, 0) in children - and to_label(cg, 1, 0, 0, 0, 1) in children - ) - - # to_label(cg, 2, 1, 0, 0, 1) - assert serialize_uint64(to_label(cg, 2, 1, 0, 0, 1)) in res.rows - row = res.rows[serialize_uint64(to_label(cg, 2, 1, 0, 0, 1))].cells["0"] - atomic_cross_edge_d = cg.get_atomic_cross_edges([to_label(cg, 2, 1, 0, 0, 1)]) - atomic_cross_edge_d = atomic_cross_edge_d[ - np.uint64(to_label(cg, 2, 1, 0, 0, 1)) - ] - children = column.deserialize(row[column.key][0].value) - - test_ace = np.array( - [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], - dtype=np.uint64, - ) - assert len(atomic_cross_edge_d[2]) == 1 - assert test_ace in atomic_cross_edge_d[2] - assert len(children) == 1 and to_label(cg, 1, 1, 0, 0, 0) in children - - # Check for the one Level 3 node that should have been created. This one combines the two - # connected components of Level 2 - # to_label(cg, 3, 0, 0, 0, 1) - assert serialize_uint64(to_label(cg, 3, 0, 0, 0, 1)) in res.rows - row = res.rows[serialize_uint64(to_label(cg, 3, 0, 0, 0, 1))].cells["0"] - column = attributes.Hierarchy.Child - children = column.deserialize(row[column.key][0].value) - - assert ( - len(children) == 2 - and to_label(cg, 2, 0, 0, 0, 1) in children - and to_label(cg, 2, 1, 0, 0, 1) in children - ) - - # Make sure there are not any more entries in the table - # include counters, meta and version rows - assert len(res.rows) == 3 + 2 + 1 + 3 + 1 + 1 - - @pytest.mark.timeout(120) - def test_build_big_graph(self, gen_graph): - """ - Create graph with RG nodes 1 and 2 in opposite corners of the largest possible dataset - ┌─────┐ ┌─────┐ - │ A¹ │ ... │ Z¹ │ - │ 1 │ │ 2 │ - │ │ │ │ - └─────┘ └─────┘ - """ - - atomic_chunk_bounds = np.array([8, 8, 8]) - cg = gen_graph(n_layers=5, atomic_chunk_bounds=atomic_chunk_bounds) - - # Preparation: Build Chunk A - create_chunk(cg, vertices=[to_label(cg, 1, 0, 0, 0, 0)], edges=[]) - - # Preparation: Build Chunk Z - create_chunk(cg, vertices=[to_label(cg, 1, 7, 7, 7, 0)], edges=[]) - - add_layer(cg, 3, [0, 0, 0], n_threads=1) - add_layer(cg, 3, [3, 3, 3], n_threads=1) - add_layer(cg, 4, [0, 0, 0], n_threads=1) - add_layer(cg, 5, [0, 0, 0], n_threads=1) - - res = cg.client._table.read_rows() - res.consume_all() - - assert serialize_uint64(to_label(cg, 1, 0, 0, 0, 0)) in res.rows - assert serialize_uint64(to_label(cg, 1, 7, 7, 7, 0)) in res.rows - assert serialize_uint64(to_label(cg, 5, 0, 0, 0, 1)) in res.rows - assert serialize_uint64(to_label(cg, 5, 0, 0, 0, 2)) in res.rows - - @pytest.mark.timeout(30) - def test_double_chunk_creation(self, gen_graph): - """ - No connection between 1, 2 and 3 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1 │ 3 │ - │ 2 │ │ - └─────┴─────┘ - """ - - atomic_chunk_bounds = np.array([4, 4, 4]) - cg = gen_graph(n_layers=4, atomic_chunk_bounds=atomic_chunk_bounds) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 4, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - assert len(cg.range_read_chunk(cg.get_chunk_id(layer=2, x=0, y=0, z=0))) == 2 - assert len(cg.range_read_chunk(cg.get_chunk_id(layer=2, x=1, y=0, z=0))) == 1 - assert len(cg.range_read_chunk(cg.get_chunk_id(layer=3, x=0, y=0, z=0))) == 0 - assert len(cg.range_read_chunk(cg.get_chunk_id(layer=4, x=0, y=0, z=0))) == 6 - - assert cg.get_chunk_layer(cg.get_root(to_label(cg, 1, 0, 0, 0, 1))) == 4 - assert cg.get_chunk_layer(cg.get_root(to_label(cg, 1, 0, 0, 0, 2))) == 4 - assert cg.get_chunk_layer(cg.get_root(to_label(cg, 1, 1, 0, 0, 1))) == 4 - - root_seg_ids = [ - cg.get_segment_id(cg.get_root(to_label(cg, 1, 0, 0, 0, 1))), - cg.get_segment_id(cg.get_root(to_label(cg, 1, 0, 0, 0, 2))), - cg.get_segment_id(cg.get_root(to_label(cg, 1, 1, 0, 0, 1))), - ] - - assert 4 in root_seg_ids - assert 5 in root_seg_ids - assert 6 in root_seg_ids - - -class TestGraphSimpleQueries: - """ - ┌─────┬─────┬─────┐ L X Y Z S L X Y Z S L X Y Z S L X Y Z S - │ A¹ │ B¹ │ C¹ │ 1: 1 0 0 0 0 ─── 2 0 0 0 1 ───────────────── 4 0 0 0 1 - │ 1 │ 3━2━┿━━4 │ 2: 1 1 0 0 0 ─┬─ 2 1 0 0 1 ─── 3 0 0 0 1 ─┬─ 4 0 0 0 2 - │ │ │ │ 3: 1 1 0 0 1 ─┘ │ - └─────┴─────┴─────┘ 4: 1 2 0 0 0 ─── 2 2 0 0 1 ─── 3 1 0 0 1 ─┘ - """ - - @pytest.mark.timeout(30) - def test_get_parent_and_children(self, gen_graph_simplequerytest): - cg = gen_graph_simplequerytest - - children10000 = cg.get_children(to_label(cg, 1, 0, 0, 0, 0)) - children11000 = cg.get_children(to_label(cg, 1, 1, 0, 0, 0)) - children11001 = cg.get_children(to_label(cg, 1, 1, 0, 0, 1)) - children12000 = cg.get_children(to_label(cg, 1, 2, 0, 0, 0)) - - parent10000 = cg.get_parent( - to_label(cg, 1, 0, 0, 0, 0), - ) - parent11000 = cg.get_parent( - to_label(cg, 1, 1, 0, 0, 0), - ) - parent11001 = cg.get_parent( - to_label(cg, 1, 1, 0, 0, 1), - ) - parent12000 = cg.get_parent( - to_label(cg, 1, 2, 0, 0, 0), - ) - - children20001 = cg.get_children(to_label(cg, 2, 0, 0, 0, 1)) - children21001 = cg.get_children(to_label(cg, 2, 1, 0, 0, 1)) - children22001 = cg.get_children(to_label(cg, 2, 2, 0, 0, 1)) - - parent20001 = cg.get_parent( - to_label(cg, 2, 0, 0, 0, 1), - ) - parent21001 = cg.get_parent( - to_label(cg, 2, 1, 0, 0, 1), - ) - parent22001 = cg.get_parent( - to_label(cg, 2, 2, 0, 0, 1), - ) - - children30001 = cg.get_children(to_label(cg, 3, 0, 0, 0, 1)) - # children30002 = cg.get_children(to_label(cg, 3, 0, 0, 0, 2)) - children31001 = cg.get_children(to_label(cg, 3, 1, 0, 0, 1)) - - parent30001 = cg.get_parent( - to_label(cg, 3, 0, 0, 0, 1), - ) - # parent30002 = cg.get_parent(to_label(cg, 3, 0, 0, 0, 2), ) - parent31001 = cg.get_parent( - to_label(cg, 3, 1, 0, 0, 1), - ) - - children40001 = cg.get_children(to_label(cg, 4, 0, 0, 0, 1)) - children40002 = cg.get_children(to_label(cg, 4, 0, 0, 0, 2)) - - parent40001 = cg.get_parent( - to_label(cg, 4, 0, 0, 0, 1), - ) - parent40002 = cg.get_parent( - to_label(cg, 4, 0, 0, 0, 2), - ) - - # (non-existing) Children of L1 - assert np.array_equal(children10000, []) is True - assert np.array_equal(children11000, []) is True - assert np.array_equal(children11001, []) is True - assert np.array_equal(children12000, []) is True - - # Parent of L1 - assert parent10000 == to_label(cg, 2, 0, 0, 0, 1) - assert parent11000 == to_label(cg, 2, 1, 0, 0, 1) - assert parent11001 == to_label(cg, 2, 1, 0, 0, 1) - assert parent12000 == to_label(cg, 2, 2, 0, 0, 1) - - # Children of L2 - assert len(children20001) == 1 and to_label(cg, 1, 0, 0, 0, 0) in children20001 - assert ( - len(children21001) == 2 - and to_label(cg, 1, 1, 0, 0, 0) in children21001 - and to_label(cg, 1, 1, 0, 0, 1) in children21001 - ) - assert len(children22001) == 1 and to_label(cg, 1, 2, 0, 0, 0) in children22001 - - # Parent of L2 - assert parent20001 == to_label(cg, 4, 0, 0, 0, 1) - assert parent21001 == to_label(cg, 3, 0, 0, 0, 1) - assert parent22001 == to_label(cg, 3, 1, 0, 0, 1) - - # Children of L3 - assert len(children30001) == 1 and len(children31001) == 1 - assert to_label(cg, 2, 1, 0, 0, 1) in children30001 - assert to_label(cg, 2, 2, 0, 0, 1) in children31001 - - # Parent of L3 - assert parent30001 == parent31001 - assert ( - parent30001 == to_label(cg, 4, 0, 0, 0, 1) - and parent20001 == to_label(cg, 4, 0, 0, 0, 2) - ) or ( - parent30001 == to_label(cg, 4, 0, 0, 0, 2) - and parent20001 == to_label(cg, 4, 0, 0, 0, 1) - ) - - # Children of L4 - assert parent10000 in children40001 - assert parent21001 in children40002 and parent22001 in children40002 - - # (non-existing) Parent of L4 - assert parent40001 is None - assert parent40002 is None - - children2_separate = cg.get_children( - [ - to_label(cg, 2, 0, 0, 0, 1), - to_label(cg, 2, 1, 0, 0, 1), - to_label(cg, 2, 2, 0, 0, 1), - ] - ) - assert len(children2_separate) == 3 - assert to_label(cg, 2, 0, 0, 0, 1) in children2_separate and np.all( - np.isin(children2_separate[to_label(cg, 2, 0, 0, 0, 1)], children20001) - ) - assert to_label(cg, 2, 1, 0, 0, 1) in children2_separate and np.all( - np.isin(children2_separate[to_label(cg, 2, 1, 0, 0, 1)], children21001) - ) - assert to_label(cg, 2, 2, 0, 0, 1) in children2_separate and np.all( - np.isin(children2_separate[to_label(cg, 2, 2, 0, 0, 1)], children22001) - ) - - children2_combined = cg.get_children( - [ - to_label(cg, 2, 0, 0, 0, 1), - to_label(cg, 2, 1, 0, 0, 1), - to_label(cg, 2, 2, 0, 0, 1), - ], - flatten=True, - ) - assert ( - len(children2_combined) == 4 - and np.all(np.isin(children20001, children2_combined)) - and np.all(np.isin(children21001, children2_combined)) - and np.all(np.isin(children22001, children2_combined)) - ) - - @pytest.mark.timeout(30) - def test_get_root(self, gen_graph_simplequerytest): - cg = gen_graph_simplequerytest - root10000 = cg.get_root( - to_label(cg, 1, 0, 0, 0, 0), - ) - root11000 = cg.get_root( - to_label(cg, 1, 1, 0, 0, 0), - ) - root11001 = cg.get_root( - to_label(cg, 1, 1, 0, 0, 1), - ) - root12000 = cg.get_root( - to_label(cg, 1, 2, 0, 0, 0), - ) - - with pytest.raises(Exception): - cg.get_root(0) - - assert ( - root10000 == to_label(cg, 4, 0, 0, 0, 1) - and root11000 == root11001 == root12000 == to_label(cg, 4, 0, 0, 0, 2) - ) or ( - root10000 == to_label(cg, 4, 0, 0, 0, 2) - and root11000 == root11001 == root12000 == to_label(cg, 4, 0, 0, 0, 1) - ) - - @pytest.mark.timeout(30) - def test_get_subgraph_nodes(self, gen_graph_simplequerytest): - cg = gen_graph_simplequerytest - root1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) - root2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) - - lvl1_nodes_1 = cg.get_subgraph([root1], leaves_only=True) - lvl1_nodes_2 = cg.get_subgraph([root2], leaves_only=True) - assert len(lvl1_nodes_1) == 1 - assert len(lvl1_nodes_2) == 3 - assert to_label(cg, 1, 0, 0, 0, 0) in lvl1_nodes_1 - assert to_label(cg, 1, 1, 0, 0, 0) in lvl1_nodes_2 - assert to_label(cg, 1, 1, 0, 0, 1) in lvl1_nodes_2 - assert to_label(cg, 1, 2, 0, 0, 0) in lvl1_nodes_2 - - lvl2_parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) - lvl1_nodes = cg.get_subgraph([lvl2_parent], leaves_only=True) - assert len(lvl1_nodes) == 2 - assert to_label(cg, 1, 1, 0, 0, 0) in lvl1_nodes - assert to_label(cg, 1, 1, 0, 0, 1) in lvl1_nodes - - @pytest.mark.timeout(30) - def test_get_subgraph_edges(self, gen_graph_simplequerytest): - cg = gen_graph_simplequerytest - root1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) - root2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) - - edges = cg.get_subgraph([root1], edges_only=True) - assert len(edges) == 0 - - edges = cg.get_subgraph([root2], edges_only=True) - assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)] in edges or [ - to_label(cg, 1, 1, 0, 0, 1), - to_label(cg, 1, 1, 0, 0, 0), - ] in edges - - assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0)] in edges or [ - to_label(cg, 1, 2, 0, 0, 0), - to_label(cg, 1, 1, 0, 0, 0), - ] in edges - - lvl2_parent = cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) - edges = cg.get_subgraph([lvl2_parent], edges_only=True) - assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)] in edges or [ - to_label(cg, 1, 1, 0, 0, 1), - to_label(cg, 1, 1, 0, 0, 0), - ] in edges - - assert [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0)] in edges or [ - to_label(cg, 1, 2, 0, 0, 0), - to_label(cg, 1, 1, 0, 0, 0), - ] in edges - - assert len(edges) == 1 - - @pytest.mark.timeout(30) - def test_get_subgraph_nodes_bb(self, gen_graph_simplequerytest): - cg = gen_graph_simplequerytest - bb = np.array([[1, 0, 0], [2, 1, 1]], dtype=int) - bb_coord = bb * cg.meta.graph_config.CHUNK_SIZE - childs_1 = cg.get_subgraph( - [cg.get_root(to_label(cg, 1, 1, 0, 0, 1))], bbox=bb, leaves_only=True - ) - childs_2 = cg.get_subgraph( - [cg.get_root(to_label(cg, 1, 1, 0, 0, 1))], - bbox=bb_coord, - bbox_is_coordinate=True, - leaves_only=True, - ) - assert np.all(~(np.sort(childs_1) - np.sort(childs_2))) - - -class TestGraphMerge: - @pytest.mark.timeout(30) - def test_merge_pair_same_chunk(self, gen_graph): - """ - Add edge between existing RG supervoxels 1 and 2 (same chunk) - Expected: Same (new) parent for RG 1 and 2 on Layer two - ┌─────┐ ┌─────┐ - │ A¹ │ │ A¹ │ - │ 1 2 │ => │ 1━2 │ - │ │ │ │ - └─────┘ └─────┘ - """ - - atomic_chunk_bounds = np.array([1, 1, 1]) - cg = gen_graph(n_layers=2, atomic_chunk_bounds=atomic_chunk_bounds) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp, - ) - - # Merge - new_root_ids = cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 0)], - affinities=[0.3], - ).new_root_ids - - assert len(new_root_ids) == 1 - new_root_id = new_root_ids[0] - - # Check - assert cg.get_parent(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id - assert cg.get_parent(to_label(cg, 1, 0, 0, 0, 1)) == new_root_id - leaves = np.unique(cg.get_subgraph([new_root_id], leaves_only=True)) - assert len(leaves) == 2 - assert to_label(cg, 1, 0, 0, 0, 0) in leaves - assert to_label(cg, 1, 0, 0, 0, 1) in leaves - - @pytest.mark.timeout(30) - def test_merge_pair_neighboring_chunks(self, gen_graph): - """ - Add edge between existing RG supervoxels 1 and 2 (neighboring chunks) - ┌─────┬─────┐ ┌─────┬─────┐ - │ A¹ │ B¹ │ │ A¹ │ B¹ │ - │ 1 │ 2 │ => │ 1━━┿━━2 │ - │ │ │ │ │ │ - └─────┴─────┘ └─────┴─────┘ - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - # Merge - new_root_ids = cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], - affinities=0.3, - ).new_root_ids - - assert len(new_root_ids) == 1 - new_root_id = new_root_ids[0] - - # Check - assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id - assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == new_root_id - leaves = np.unique(cg.get_subgraph([new_root_id], leaves_only=True)) - assert len(leaves) == 2 - assert to_label(cg, 1, 0, 0, 0, 0) in leaves - assert to_label(cg, 1, 1, 0, 0, 0) in leaves - - @pytest.mark.timeout(120) - def test_merge_pair_disconnected_chunks(self, gen_graph): - """ - Add edge between existing RG supervoxels 1 and 2 (disconnected chunks) - ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ - │ A¹ │ ... │ Z¹ │ │ A¹ │ ... │ Z¹ │ - │ 1 │ │ 2 │ => │ 1━━┿━━━━━┿━━2 │ - │ │ │ │ │ │ │ │ - └─────┘ └─────┘ └─────┘ └─────┘ - """ - - cg = gen_graph(n_layers=5) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk Z - create_chunk( - cg, - vertices=[to_label(cg, 1, 7, 7, 7, 0)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 3, - [3, 3, 3], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 4, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 5, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - # Merge - result = cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 7, 7, 7, 0), to_label(cg, 1, 0, 0, 0, 0)], - affinities=[0.3], - ) - new_root_ids, lvl2_node_ids = result.new_root_ids, result.new_lvl2_ids - print(f"lvl2_node_ids: {lvl2_node_ids}") - - u_layers = np.unique(cg.get_chunk_layers(lvl2_node_ids)) - assert len(u_layers) == 1 - assert u_layers[0] == 2 - - assert len(new_root_ids) == 1 - new_root_id = new_root_ids[0] - - # Check - assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id - assert cg.get_root(to_label(cg, 1, 7, 7, 7, 0)) == new_root_id - leaves = np.unique(cg.get_subgraph(new_root_id, leaves_only=True)) - assert len(leaves) == 2 - assert to_label(cg, 1, 0, 0, 0, 0) in leaves - assert to_label(cg, 1, 7, 7, 7, 0) in leaves - - @pytest.mark.timeout(30) - def test_merge_pair_already_connected(self, gen_graph): - """ - Add edge between already connected RG supervoxels 1 and 2 (same chunk). - Expected: No change, i.e. same parent (to_label(cg, 2, 0, 0, 0, 1)), affinity (0.5) and timestamp as before - ┌─────┐ ┌─────┐ - │ A¹ │ │ A¹ │ - │ 1━2 │ => │ 1━2 │ - │ │ │ │ - └─────┘ └─────┘ - """ - - cg = gen_graph(n_layers=2) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], - edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], - timestamp=fake_timestamp, - ) - - res_old = cg.client._table.read_rows() - res_old.consume_all() - - # Merge - with pytest.raises(Exception): - cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 0)], - ) - res_new = cg.client._table.read_rows() - res_new.consume_all() - - # Check - if res_old.rows != res_new.rows: - warn( - "Rows were modified when merging a pair of already connected supervoxels. " - "While probably not an error, it is an unnecessary operation." - ) - - @pytest.mark.timeout(30) - def test_merge_triple_chain_to_full_circle_same_chunk(self, gen_graph): - """ - Add edge between indirectly connected RG supervoxels 1 and 2 (same chunk) - ┌─────┐ ┌─────┐ - │ A¹ │ │ A¹ │ - │ 1 2 │ => │ 1━2 │ - │ ┗3┛ │ │ ┗3┛ │ - └─────┘ └─────┘ - """ - - cg = gen_graph(n_layers=2) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[ - to_label(cg, 1, 0, 0, 0, 0), - to_label(cg, 1, 0, 0, 0, 1), - to_label(cg, 1, 0, 0, 0, 2), - ], - edges=[ - (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 2), 0.5), - (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2), 0.5), - ], - timestamp=fake_timestamp, - ) - - # Merge - with pytest.raises(Exception): - cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 0)], - affinities=0.3, - ).new_root_ids - - @pytest.mark.timeout(30) - def test_merge_triple_chain_to_full_circle_neighboring_chunks(self, gen_graph): - """ - Add edge between indirectly connected RG supervoxels 1 and 2 (neighboring chunks) - ┌─────┬─────┐ ┌─────┬─────┐ - │ A¹ │ B¹ │ │ A¹ │ B¹ │ - │ 1 │ 2 │ => │ 1━━┿━━2 │ - │ ┗3━┿━━┛ │ │ ┗3━┿━━┛ │ - └─────┴─────┘ └─────┴─────┘ - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], - edges=[ - (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), - (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 0), inf), - ], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), inf)], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - # Merge - with pytest.raises(Exception): - cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], - affinities=1.0, - ).new_root_ids - - @pytest.mark.timeout(120) - def test_merge_triple_chain_to_full_circle_disconnected_chunks(self, gen_graph): - """ - Add edge between indirectly connected RG supervoxels 1 and 2 (disconnected chunks) - ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ - │ A¹ │ ... │ Z¹ │ │ A¹ │ ... │ Z¹ │ - │ 1 │ │ 2 │ => │ 1━━┿━━━━━┿━━2 │ - │ ┗3━┿━━━━━┿━━┛ │ │ ┗3━┿━━━━━┿━━┛ │ - └─────┘ └─────┘ └─────┘ └─────┘ - """ - - cg = gen_graph(n_layers=5) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], - edges=[ - (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), - ( - to_label(cg, 1, 0, 0, 0, 1), - to_label(cg, 1, 7, 7, 7, 0), - inf, - ), - ], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 7, 7, 7, 0)], - edges=[ - ( - to_label(cg, 1, 7, 7, 7, 0), - to_label(cg, 1, 0, 0, 0, 1), - inf, - ) - ], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 3, - [3, 3, 3], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 4, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 4, - [1, 1, 1], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 5, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - # Merge - new_root_ids = cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 7, 7, 7, 0), to_label(cg, 1, 0, 0, 0, 0)], - affinities=1.0, - ).new_root_ids - - assert len(new_root_ids) == 1 - new_root_id = new_root_ids[0] - - # Check - assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_id - assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_id - assert cg.get_root(to_label(cg, 1, 7, 7, 7, 0)) == new_root_id - leaves = np.unique(cg.get_subgraph(new_root_id, leaves_only=True)) - assert len(leaves) == 3 - assert to_label(cg, 1, 0, 0, 0, 0) in leaves - assert to_label(cg, 1, 0, 0, 0, 1) in leaves - assert to_label(cg, 1, 7, 7, 7, 0) in leaves - - @pytest.mark.timeout(30) - def test_merge_same_node(self, gen_graph): - """ - Try to add loop edge between RG supervoxel 1 and itself - ┌─────┐ - │ A¹ │ - │ 1 │ => Reject - │ │ - └─────┘ - """ - - cg = gen_graph(n_layers=2) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp, - ) - - res_old = cg.client._table.read_rows() - res_old.consume_all() - - # Merge - with pytest.raises(Exception): - cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0)], - ) - - res_new = cg.client._table.read_rows() - res_new.consume_all() - - assert res_new.rows == res_old.rows - - @pytest.mark.timeout(30) - def test_merge_pair_abstract_nodes(self, gen_graph): - """ - Try to add edge between RG supervoxel 1 and abstract node "2" - ┌─────┐ - │ B² │ - │ "2" │ - │ │ - └─────┘ - ┌─────┐ => Reject - │ A¹ │ - │ 1 │ - │ │ - └─────┘ - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - res_old = cg.client._table.read_rows() - res_old.consume_all() - - # Merge - with pytest.raises(Exception): - cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 2, 1, 0, 0, 1)], - ) - - res_new = cg.client._table.read_rows() - res_new.consume_all() - - assert res_new.rows == res_old.rows - - @pytest.mark.timeout(30) - def test_diagonal_connections(self, gen_graph): - """ - Create graph with edge between RG supervoxels 1 and 2 (same chunk) - and edge between RG supervoxels 1 and 3 (neighboring chunks) - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 2 1━┿━━3 │ - │ / │ │ - ┌─────┬─────┐ - │ | │ │ - │ 4━━┿━━5 │ - │ C¹ │ D¹ │ - └─────┴─────┘ - """ - - cg = gen_graph(n_layers=3) - - # Chunk A - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], - edges=[ - (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), - (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf), - ], - ) - - # Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], - ) - - # Chunk C - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 1, 0, 0)], - edges=[ - (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 1, 1, 0, 0), inf), - (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf), - ], - ) - - # Chunk D - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 1, 0, 0)], - edges=[(to_label(cg, 1, 1, 1, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf)], - ) - - add_layer( - cg, - 3, - [0, 0, 0], - n_threads=1, - ) - - rr = cg.range_read_chunk(chunk_id=cg.get_chunk_id(layer=3, x=0, y=0, z=0)) - root_ids_t0 = list(rr.keys()) - - assert len(root_ids_t0) == 2 - - child_ids = [] - for root_id in root_ids_t0: - child_ids.extend(cg.get_subgraph(root_id, leaves_only=True)) - - new_roots = cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], - affinities=[0.5], - ).new_root_ids - - root_ids = [] - for child_id in child_ids: - root_ids.append(cg.get_root(child_id)) - - assert len(np.unique(root_ids)) == 1 - - root_id = root_ids[0] - assert root_id == new_roots[0] - - @pytest.mark.timeout(240) - def test_cross_edges(self, gen_graph): - """""" - - cg = gen_graph(n_layers=5) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[ - to_label(cg, 1, 0, 0, 0, 0), - to_label(cg, 1, 0, 0, 0, 1), - ], - edges=[ - ( - to_label(cg, 1, 0, 0, 0, 1), - to_label(cg, 1, 1, 0, 0, 0), - inf, - ), - ( - to_label(cg, 1, 0, 0, 0, 0), - to_label(cg, 1, 0, 0, 0, 1), - inf, - ), - ], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[ - to_label(cg, 1, 1, 0, 0, 0), - to_label(cg, 1, 1, 0, 0, 1), - ], - edges=[ - ( - to_label(cg, 1, 1, 0, 0, 0), - to_label(cg, 1, 0, 0, 0, 1), - inf, - ), - ( - to_label(cg, 1, 1, 0, 0, 0), - to_label(cg, 1, 1, 0, 0, 1), - inf, - ), - ], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk C - create_chunk( - cg, - vertices=[ - to_label(cg, 1, 2, 0, 0, 0), - ], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 3, - [1, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 4, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - add_layer( - cg, - 5, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - new_roots = cg.add_edges( - "Jane Doe", - [ - to_label(cg, 1, 1, 0, 0, 0), - to_label(cg, 1, 2, 0, 0, 0), - ], - affinities=0.9, - ).new_root_ids - - assert len(new_roots) == 1 - - -class TestGraphMergeSplit: - @pytest.mark.timeout(240) - def test_multiple_cuts_and_splits(self, gen_graph_simplequerytest): - """ - ┌─────┬─────┬─────┐ L X Y Z S L X Y Z S L X Y Z S L X Y Z S - │ A¹ │ B¹ │ C¹ │ 1: 1 0 0 0 0 ─── 2 0 0 0 1 ───────────────── 4 0 0 0 1 - │ 1 │ 3━2━┿━━4 │ 2: 1 1 0 0 0 ─┬─ 2 1 0 0 1 ─── 3 0 0 0 1 ─┬─ 4 0 0 0 2 - │ │ │ │ 3: 1 1 0 0 1 ─┘ │ - └─────┴─────┴─────┘ 4: 1 2 0 0 0 ─── 2 2 0 0 1 ─── 3 1 0 0 1 ─┘ - """ - cg = gen_graph_simplequerytest - - rr = cg.range_read_chunk(chunk_id=cg.get_chunk_id(layer=4, x=0, y=0, z=0)) - root_ids_t0 = list(rr.keys()) - child_ids = [types.empty_1d] - for root_id in root_ids_t0: - child_ids.append(cg.get_subgraph([root_id], leaves_only=True)) - child_ids = np.concatenate(child_ids) - - for i in range(10): - - print(f"\n\nITERATION {i}/10") - print("\n\nMERGE 1 & 3\n\n") - new_roots = cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], - affinities=0.9, - ).new_root_ids - assert len(new_roots) == 1 - assert len(cg.get_subgraph([new_roots[0]], leaves_only=True)) == 4 - - root_ids = [] - for child_id in child_ids: - root_ids.append(cg.get_root(child_id)) - - u_root_ids = np.unique(root_ids) - assert len(u_root_ids) == 1 - - # ------------------------------------------------------------------ - new_roots = cg.remove_edges( - "John Doe", - source_ids=to_label(cg, 1, 1, 0, 0, 0), - sink_ids=to_label(cg, 1, 1, 0, 0, 1), - mincut=False, - ).new_root_ids - - assert len(np.unique(new_roots)) == 2 - - root_ids = [] - for child_id in child_ids: - root_ids.append(cg.get_root(child_id)) - - u_root_ids = np.unique(root_ids) - these_child_ids = [] - for root_id in u_root_ids: - these_child_ids.extend(cg.get_subgraph([root_id], leaves_only=True)) - - assert len(these_child_ids) == 4 - assert len(u_root_ids) == 2 - - # ------------------------------------------------------------------ - - new_roots = cg.remove_edges( - "Jane Doe", - source_ids=to_label(cg, 1, 0, 0, 0, 0), - sink_ids=to_label(cg, 1, 1, 0, 0, 1), - mincut=False, - ).new_root_ids - assert len(new_roots) == 2 - - root_ids = [] - for child_id in child_ids: - root_ids.append(cg.get_root(child_id)) - - u_root_ids = np.unique(root_ids) - assert len(u_root_ids) == 3 - - # ------------------------------------------------------------------ - - print(f"\n\nITERATION {i}/10") - print("\n\nMERGE 2 & 3\n\n") - - new_roots = cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], - affinities=0.9, - ).new_root_ids - assert len(new_roots) == 1 - - root_ids = [] - for child_id in child_ids: - root_ids.append(cg.get_root(child_id)) - - u_root_ids = np.unique(root_ids) - assert len(u_root_ids) == 2 - - # for root_id in root_ids: - # cross_edge_dict_layers = graph_tests.root_cross_edge_test( - # root_id, cg=cg - # ) # dict: layer -> cross_edge_dict - # n_cross_edges_layer = collections.defaultdict(list) - - # for child_layer in cross_edge_dict_layers.keys(): - # for layer in cross_edge_dict_layers[child_layer].keys(): - # n_cross_edges_layer[layer].append( - # len(cross_edge_dict_layers[child_layer][layer]) - # ) - - # for layer in n_cross_edges_layer.keys(): - # assert len(np.unique(n_cross_edges_layer[layer])) == 1 - - -class TestGraphMinCut: - # TODO: Ideally, those tests should focus only on mincut retrieving the correct edges. - # The edge removal part should be tested exhaustively in TestGraphSplit - @pytest.mark.timeout(30) - def test_cut_regular_link(self, gen_graph): - """ - Regular link between 1 and 2 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1━━┿━━2 │ - │ │ │ - └─────┴─────┘ - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - # Mincut - new_root_ids = cg.remove_edges( - "Jane Doe", - source_ids=to_label(cg, 1, 0, 0, 0, 0), - sink_ids=to_label(cg, 1, 1, 0, 0, 0), - source_coords=[0, 0, 0], - sink_coords=[ - 2 * cg.meta.graph_config.CHUNK_SIZE[0], - 2 * cg.meta.graph_config.CHUNK_SIZE[1], - cg.meta.graph_config.CHUNK_SIZE[2], - ], - mincut=True, - disallow_isolating_cut=True, - ).new_root_ids - - # Check New State - assert len(new_root_ids) == 2 - assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( - to_label(cg, 1, 1, 0, 0, 0) - ) - leaves = np.unique( - cg.get_subgraph( - [cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True - ) - ) - assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves - leaves = np.unique( - cg.get_subgraph( - [cg.get_root(to_label(cg, 1, 1, 0, 0, 0))], leaves_only=True - ) - ) - assert len(leaves) == 1 and to_label(cg, 1, 1, 0, 0, 0) in leaves - - @pytest.mark.timeout(30) - def test_cut_no_link(self, gen_graph): - """ - No connection between 1 and 2 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1 │ 2 │ - │ │ │ - └─────┴─────┘ - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - res_old = cg.client._table.read_rows() - res_old.consume_all() - - # Mincut - with pytest.raises(exceptions.PreconditionError): - cg.remove_edges( - "Jane Doe", - source_ids=to_label(cg, 1, 0, 0, 0, 0), - sink_ids=to_label(cg, 1, 1, 0, 0, 0), - source_coords=[0, 0, 0], - sink_coords=[ - 2 * cg.meta.graph_config.CHUNK_SIZE[0], - 2 * cg.meta.graph_config.CHUNK_SIZE[1], - cg.meta.graph_config.CHUNK_SIZE[2], - ], - mincut=True, - ) - - res_new = cg.client._table.read_rows() - res_new.consume_all() - - assert res_new.rows == res_old.rows - - @pytest.mark.timeout(30) - def test_cut_old_link(self, gen_graph): - """ - Link between 1 and 2 got removed previously (aff = 0.0) - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1┅┅╎┅┅2 │ - │ │ │ - └─────┴─────┘ - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - cg.remove_edges( - "John Doe", - source_ids=to_label(cg, 1, 1, 0, 0, 0), - sink_ids=to_label(cg, 1, 0, 0, 0, 0), - mincut=False, - ) - - res_old = cg.client._table.read_rows() - res_old.consume_all() - - # Mincut - with pytest.raises(exceptions.PreconditionError): - cg.remove_edges( - "Jane Doe", - source_ids=to_label(cg, 1, 0, 0, 0, 0), - sink_ids=to_label(cg, 1, 1, 0, 0, 0), - source_coords=[0, 0, 0], - sink_coords=[ - 2 * cg.meta.graph_config.CHUNK_SIZE[0], - 2 * cg.meta.graph_config.CHUNK_SIZE[1], - cg.meta.graph_config.CHUNK_SIZE[2], - ], - mincut=True, - ) - - res_new = cg.client._table.read_rows() - res_new.consume_all() - - assert res_new.rows == res_old.rows - - @pytest.mark.timeout(30) - def test_cut_indivisible_link(self, gen_graph): - """ - Sink: 1, Source: 2 - Link between 1 and 2 is set to `inf` and must not be cut. - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1══╪══2 │ - │ │ │ - └─────┴─────┘ - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - original_parents_1 = cg.get_root( - to_label(cg, 1, 0, 0, 0, 0), get_all_parents=True - ) - original_parents_2 = cg.get_root( - to_label(cg, 1, 1, 0, 0, 0), get_all_parents=True - ) - - # Mincut - with pytest.raises(exceptions.PostconditionError): - cg.remove_edges( - "Jane Doe", - source_ids=to_label(cg, 1, 0, 0, 0, 0), - sink_ids=to_label(cg, 1, 1, 0, 0, 0), - source_coords=[0, 0, 0], - sink_coords=[ - 2 * cg.meta.graph_config.CHUNK_SIZE[0], - 2 * cg.meta.graph_config.CHUNK_SIZE[1], - cg.meta.graph_config.CHUNK_SIZE[2], - ], - mincut=True, - ) - - new_parents_1 = cg.get_root(to_label(cg, 1, 0, 0, 0, 0), get_all_parents=True) - new_parents_2 = cg.get_root(to_label(cg, 1, 1, 0, 0, 0), get_all_parents=True) - - assert np.all(np.array(original_parents_1) == np.array(new_parents_1)) - assert np.all(np.array(original_parents_2) == np.array(new_parents_2)) - - @pytest.mark.timeout(30) - def test_mincut_disrespects_sources_or_sinks(self, gen_graph): - """ - When the mincut separates sources or sinks, an error should be thrown. - Although the mincut is setup to never cut an edge between two sources or - two sinks, this can happen when an edge along the only path between two - sources or two sinks is cut. - """ - cg = gen_graph(n_layers=2) - - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[ - to_label(cg, 1, 0, 0, 0, 0), - to_label(cg, 1, 0, 0, 0, 1), - to_label(cg, 1, 0, 0, 0, 2), - to_label(cg, 1, 0, 0, 0, 3), - ], - edges=[ - (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 2), 2), - (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2), 3), - (to_label(cg, 1, 0, 0, 0, 2), to_label(cg, 1, 0, 0, 0, 3), 10), - ], - timestamp=fake_timestamp, - ) - - # Mincut - with pytest.raises(exceptions.PreconditionError): - cg.remove_edges( - "Jane Doe", - source_ids=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], - sink_ids=[to_label(cg, 1, 0, 0, 0, 3)], - source_coords=[[0, 0, 0], [10, 0, 0]], - sink_coords=[[5, 5, 0]], - mincut=True, - ) - - -class TestGraphMultiCut: - @pytest.mark.timeout(30) - def test_cut_multi_tree(self, gen_graph): - pass - - @pytest.mark.timeout(30) - def test_path_augmented_multicut(self, sv_data): - sv_edges, sv_sources, sv_sinks, sv_affinity, sv_area = sv_data - edges = Edges( - sv_edges[:, 0], sv_edges[:, 1], affinities=sv_affinity, areas=sv_area - ) - - cut_edges_aug = run_multicut(edges, sv_sources, sv_sinks, path_augment=True) - assert cut_edges_aug.shape[0] == 350 - - with pytest.raises(exceptions.PreconditionError): - run_multicut(edges, sv_sources, sv_sinks, path_augment=False) - pass - - -class TestGraphHistory: - """These test inadvertantly also test merge and split operations""" - - @pytest.mark.timeout(120) - def test_cut_merge_history(self, gen_graph): - """ - Regular link between 1 and 2 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1━━┿━━2 │ - │ │ │ - └─────┴─────┘ - (1) Split 1 and 2 - (2) Merge 1 and 2 - """ - from ..graph.lineage import lineage_graph - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 0)], - edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 0)], - edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - first_root = cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) - assert first_root == cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) - timestamp_before_split = datetime.utcnow() - split_roots = cg.remove_edges( - "Jane Doe", - source_ids=to_label(cg, 1, 0, 0, 0, 0), - sink_ids=to_label(cg, 1, 1, 0, 0, 0), - mincut=False, - ).new_root_ids - assert len(split_roots) == 2 - g = lineage_graph(cg, split_roots[0]) - assert g.size() == 1 - g = lineage_graph(cg, split_roots) - assert g.size() == 2 - - timestamp_after_split = datetime.utcnow() - merge_roots = cg.add_edges( - "Jane Doe", - [to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0)], - affinities=0.4, - ).new_root_ids - assert len(merge_roots) == 1 - merge_root = merge_roots[0] - timestamp_after_merge = datetime.utcnow() - - g = lineage_graph(cg, merge_roots) - assert g.size() == 4 - assert ( - len( - get_root_id_history( - cg, - first_root, - time_stamp_past=datetime.min, - time_stamp_future=datetime.max, - ) - ) - == 4 - ) - assert ( - len( - get_root_id_history( - cg, - split_roots[0], - time_stamp_past=datetime.min, - time_stamp_future=datetime.max, - ) - ) - == 3 - ) - assert ( - len( - get_root_id_history( - cg, - split_roots[1], - time_stamp_past=datetime.min, - time_stamp_future=datetime.max, - ) - ) - == 3 - ) - assert ( - len( - get_root_id_history( - cg, - merge_root, - time_stamp_past=datetime.min, - time_stamp_future=datetime.max, - ) - ) - == 4 - ) - - new_roots, old_roots = get_delta_roots( - cg, timestamp_before_split, timestamp_after_split - ) - assert len(old_roots) == 1 - assert old_roots[0] == first_root - assert len(new_roots) == 2 - assert np.all(np.isin(new_roots, split_roots)) - - new_roots2, old_roots2 = get_delta_roots( - cg, timestamp_after_split, timestamp_after_merge - ) - assert len(new_roots2) == 1 - assert new_roots2[0] == merge_root - assert len(old_roots2) == 2 - assert np.all(np.isin(old_roots2, split_roots)) - - new_roots3, old_roots3 = get_delta_roots( - cg, timestamp_before_split, timestamp_after_merge - ) - assert len(new_roots3) == 1 - assert new_roots3[0] == merge_root - assert len(old_roots3) == 1 - assert old_roots3[0] == first_root - - -class TestGraphLocks: - @pytest.mark.timeout(30) - def test_lock_unlock(self, gen_graph): - """ - No connection between 1, 2 and 3 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1 │ 3 │ - │ 2 │ │ - └─────┴─────┘ - - (1) Try lock (opid = 1) - (2) Try lock (opid = 2) - (3) Try unlock (opid = 1) - (4) Try lock (opid = 2) - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - operation_id_1 = cg.id_client.create_operation_id() - root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) - - future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} - assert cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_1, - future_root_ids_d=future_root_ids_d, - )[0] - - operation_id_2 = cg.id_client.create_operation_id() - assert not cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_2, - future_root_ids_d=future_root_ids_d, - )[0] - - assert cg.client.unlock_root(root_id=root_id, operation_id=operation_id_1) - - assert cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_2, - future_root_ids_d=future_root_ids_d, - )[0] - - @pytest.mark.timeout(30) - def test_lock_expiration(self, gen_graph): - """ - No connection between 1, 2 and 3 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1 │ 3 │ - │ 2 │ │ - └─────┴─────┘ - - (1) Try lock (opid = 1) - (2) Try lock (opid = 2) - (3) Try lock (opid = 2) with retries - """ - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - operation_id_1 = cg.id_client.create_operation_id() - root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) - future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} - assert cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_1, - future_root_ids_d=future_root_ids_d, - )[0] - - operation_id_2 = cg.id_client.create_operation_id() - assert not cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_2, - future_root_ids_d=future_root_ids_d, - )[0] - - sleep(cg.meta.graph_config.ROOT_LOCK_EXPIRY.total_seconds()) - - assert cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_2, - future_root_ids_d=future_root_ids_d, - max_tries=10, - waittime_s=0.5, - )[0] - - @pytest.mark.timeout(30) - def test_lock_renew(self, gen_graph): - """ - No connection between 1, 2 and 3 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1 │ 3 │ - │ 2 │ │ - └─────┴─────┘ - - (1) Try lock (opid = 1) - (2) Try lock (opid = 2) - (3) Try lock (opid = 2) with retries - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - operation_id_1 = cg.id_client.create_operation_id() - root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) - future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} - assert cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_1, - future_root_ids_d=future_root_ids_d, - )[0] - - assert cg.client.renew_locks(root_ids=[root_id], operation_id=operation_id_1) - - @pytest.mark.timeout(30) - def test_lock_merge_lock_old_id(self, gen_graph): - """ - No connection between 1, 2 and 3 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1 │ 3 │ - │ 2 │ │ - └─────┴─────┘ - - (1) Merge (includes lock opid 1) - (2) Try lock opid 2 --> should be successful and return new root id - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) - - new_root_ids = cg.add_edges( - "Chuck Norris", - [to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], - affinities=1.0, - ).new_root_ids - - assert new_root_ids is not None - - operation_id_2 = cg.id_client.create_operation_id() - future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} - success, new_root_id = cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_2, - future_root_ids_d=future_root_ids_d, - max_tries=10, - waittime_s=0.5, - ) - - assert success - assert new_root_ids[0] == new_root_id - - @pytest.mark.timeout(30) - def test_indefinite_lock(self, gen_graph): - """ - No connection between 1, 2 and 3 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1 │ 3 │ - │ 2 │ │ - └─────┴─────┘ - - (1) Try indefinite lock (opid = 1), get indefinite lock - (2) Try normal lock (opid = 2), doesn't get the normal lock - (3) Try unlock indefinite lock (opid = 1), should unlock indefinite lock - (4) Try lock (opid = 2), should get the normal lock - """ - - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - operation_id_1 = cg.id_client.create_operation_id() - root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) - - future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} - assert cg.client.lock_roots_indefinitely( - root_ids=[root_id], - operation_id=operation_id_1, - future_root_ids_d=future_root_ids_d, - )[0] - - operation_id_2 = cg.id_client.create_operation_id() - assert not cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_2, - future_root_ids_d=future_root_ids_d, - )[0] - - assert cg.client.unlock_indefinitely_locked_root( - root_id=root_id, operation_id=operation_id_1 - ) - - assert cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_2, - future_root_ids_d=future_root_ids_d, - )[0] - - @pytest.mark.timeout(30) - def test_indefinite_lock_with_normal_lock_expiration(self, gen_graph): - """ - No connection between 1, 2 and 3 - ┌─────┬─────┐ - │ A¹ │ B¹ │ - │ 1 │ 3 │ - │ 2 │ │ - └─────┴─────┘ - - (1) Try normal lock (opid = 1), get normal lock - (2) Try indefinite lock (opid = 1), get indefinite lock - (3) Wait until normal lock expires - (4) Try normal lock (opid = 2), doesn't get the normal lock - (5) Try unlock indefinite lock (opid = 1), should unlock indefinite lock - (6) Try lock (opid = 2), should get the normal lock - """ - - # 1. TODO renew lock test when getting indefinite lock - cg = gen_graph(n_layers=3) - - # Preparation: Build Chunk A - fake_timestamp = datetime.utcnow() - timedelta(days=10) - create_chunk( - cg, - vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], - edges=[], - timestamp=fake_timestamp, - ) - - # Preparation: Build Chunk B - create_chunk( - cg, - vertices=[to_label(cg, 1, 1, 0, 0, 1)], - edges=[], - timestamp=fake_timestamp, - ) - - add_layer( - cg, - 3, - [0, 0, 0], - time_stamp=fake_timestamp, - n_threads=1, - ) - - operation_id_1 = cg.id_client.create_operation_id() - root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) - - future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} - - assert cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_1, - future_root_ids_d=future_root_ids_d, - )[0] - - assert cg.client.lock_roots_indefinitely( - root_ids=[root_id], - operation_id=operation_id_1, - future_root_ids_d=future_root_ids_d, - )[0] - - sleep(cg.meta.graph_config.ROOT_LOCK_EXPIRY.total_seconds()) - - operation_id_2 = cg.id_client.create_operation_id() - assert not cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_2, - future_root_ids_d=future_root_ids_d, - )[0] - - assert cg.client.unlock_indefinitely_locked_root( - root_id=root_id, operation_id=operation_id_1 - ) - - assert cg.client.lock_roots( - root_ids=[root_id], - operation_id=operation_id_2, - future_root_ids_d=future_root_ids_d, - )[0] - - # TODO fixme: this scenario can't be tested like this - # @pytest.mark.timeout(30) - # def test_normal_lock_expiration(self, gen_graph): - # """ - # No connection between 1, 2 and 3 - # ┌─────┬─────┐ - # │ A¹ │ B¹ │ - # │ 1 │ 3 │ - # │ 2 │ │ - # └─────┴─────┘ - - # (1) Try normal lock (opid = 1), get normal lock - # (2) Wait until normal lock expires - # (3) Try indefinite lock (opid = 1), doesn't get the indefinite lock - # """ - - # cg = gen_graph(n_layers=3) - - # # Preparation: Build Chunk A - # fake_timestamp = datetime.utcnow() - timedelta(days=10) - # create_chunk( - # cg, - # vertices=[to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2)], - # edges=[], - # timestamp=fake_timestamp, - # ) - - # # Preparation: Build Chunk B - # create_chunk( - # cg, - # vertices=[to_label(cg, 1, 1, 0, 0, 1)], - # edges=[], - # timestamp=fake_timestamp, - # ) - - # add_layer( - # cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1, - # ) - - # operation_id_1 = cg.id_client.create_operation_id() - # root_id = cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) - - # future_root_ids_d = {root_id: get_future_root_ids(cg, root_id)} - - # assert cg.client.lock_roots( - # root_ids=[root_id], - # operation_id=operation_id_1, - # future_root_ids_d=future_root_ids_d, - # )[0] - - # sleep(cg.meta.graph_config.ROOT_LOCK_EXPIRY.total_seconds()+1) - - # assert not cg.client.lock_roots_indefinitely( - # root_ids=[root_id], - # operation_id=operation_id_1, - # future_root_ids_d=future_root_ids_d, - # )[0] - - -# class MockChunkedGraph: -# """ -# Dummy class to mock partial functionality of the ChunkedGraph for use in unit tests. -# Feel free to add more functions as need be. Can pass in alternative member functions into constructor. -# """ - -# def __init__( -# self, get_chunk_coordinates=None, get_chunk_layer=None, get_chunk_id=None -# ): -# if get_chunk_coordinates is not None: -# self.get_chunk_coordinates = get_chunk_coordinates -# if get_chunk_layer is not None: -# self.get_chunk_layer = get_chunk_layer -# if get_chunk_id is not None: -# self.get_chunk_id = get_chunk_id - -# def get_chunk_coordinates(self, chunk_id): # pylint: disable=method-hidden -# return np.array([0, 0, 0]) - -# def get_chunk_layer(self, chunk_id): # pylint: disable=method-hidden -# return 2 - -# def get_chunk_id(self, *args): # pylint: disable=method-hidden -# return 0 - - -# class TestGraphSplit: -# @pytest.mark.timeout(30) -# def test_split_pair_same_chunk(self, gen_graph): -# """ -# Remove edge between existing RG supervoxels 1 and 2 (same chunk) -# Expected: Different (new) parents for RG 1 and 2 on Layer two -# ┌─────┐ ┌─────┐ -# │ A¹ │ │ A¹ │ -# │ 1━2 │ => │ 1 2 │ -# │ │ │ │ -# └─────┘ └─────┘ -# """ - -# cg = gen_graph(n_layers=2) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], -# edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], -# timestamp=fake_timestamp, -# ) - -# # Split -# new_root_ids = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 0, 0, 0, 1), -# sink_ids=to_label(cg, 1, 0, 0, 0, 0), -# mincut=False, -# ).new_root_ids - -# # Check New State -# assert len(new_root_ids) == 2 -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( -# to_label(cg, 1, 0, 0, 0, 1) -# ) -# leaves = np.unique( -# cg.get_subgraph([cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True) -# ) -# assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves -# leaves = np.unique( -# cg.get_subgraph([cg.get_root(to_label(cg, 1, 0, 0, 0, 1))], leaves_only=True) -# ) -# assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 1) in leaves - -# # Check Old State still accessible -# assert cg.get_root( -# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp -# ) == cg.get_root(to_label(cg, 1, 0, 0, 0, 1), time_stamp=fake_timestamp) -# leaves = np.unique( -# cg.get_subgraph( -# [cg.get_root(to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp)], -# leaves_only=True, -# ) -# ) -# assert len(leaves) == 2 -# assert to_label(cg, 1, 0, 0, 0, 0) in leaves -# assert to_label(cg, 1, 0, 0, 0, 1) in leaves - -# # assert len(cg.get_latest_roots()) == 2 -# # assert len(cg.get_latest_roots(fake_timestamp)) == 1 - -# def test_split_nonexisting_edge(self, gen_graph): -# """ -# Remove edge between existing RG supervoxels 1 and 2 (same chunk) -# Expected: Different (new) parents for RG 1 and 2 on Layer two -# ┌─────┐ ┌─────┐ -# │ A¹ │ │ A¹ │ -# │ 1━2 │ => │ 1━2 │ -# │ | │ │ | │ -# │ 3 │ │ 3 │ -# └─────┘ └─────┘ -# """ - -# cg = gen_graph(n_layers=2) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], -# edges=[ -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), -# (to_label(cg, 1, 0, 0, 0, 2), to_label(cg, 1, 0, 0, 0, 1), 0.5), -# ], -# timestamp=fake_timestamp, -# ) - -# # Split -# new_root_ids = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 0, 0, 0, 0), -# sink_ids=to_label(cg, 1, 0, 0, 0, 2), -# mincut=False, -# ).new_root_ids - -# assert len(new_root_ids) == 1 - -# @pytest.mark.timeout(30) -# def test_split_pair_neighboring_chunks(self, gen_graph): -# """ -# Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) -# ┌─────┬─────┐ ┌─────┬─────┐ -# │ A¹ │ B¹ │ │ A¹ │ B¹ │ -# │ 1━━┿━━2 │ => │ 1 │ 2 │ -# │ │ │ │ │ │ -# └─────┴─────┘ └─────┴─────┘ -# """ - -# cg = gen_graph(n_layers=3) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0)], -# edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 1.0)], -# timestamp=fake_timestamp, -# ) - -# # Preparation: Build Chunk B -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 1, 0, 0, 0)], -# edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 1.0)], -# timestamp=fake_timestamp, -# ) - -# add_layer( -# cg, -# 3, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) - -# # Split -# new_root_ids = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 1, 0, 0, 0), -# sink_ids=to_label(cg, 1, 0, 0, 0, 0), -# mincut=False, -# ).new_root_ids - -# # Check New State -# assert len(new_root_ids) == 2 -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( -# to_label(cg, 1, 1, 0, 0, 0) -# ) -# leaves = np.unique( -# cg.get_subgraph([cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True) -# ) -# assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves -# leaves = np.unique( -# cg.get_subgraph([cg.get_root(to_label(cg, 1, 1, 0, 0, 0))], leaves_only=True) -# ) -# assert len(leaves) == 1 and to_label(cg, 1, 1, 0, 0, 0) in leaves - -# # Check Old State still accessible -# assert cg.get_root( -# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp -# ) == cg.get_root(to_label(cg, 1, 1, 0, 0, 0), time_stamp=fake_timestamp) -# leaves = np.unique( -# cg.get_subgraph( -# [cg.get_root(to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp)], -# leaves_only=True, -# ) -# ) -# assert len(leaves) == 2 -# assert to_label(cg, 1, 0, 0, 0, 0) in leaves -# assert to_label(cg, 1, 1, 0, 0, 0) in leaves - -# assert len(cg.get_latest_roots()) == 2 -# assert len(cg.get_latest_roots(fake_timestamp)) == 1 - -# @pytest.mark.timeout(30) -# def test_split_verify_cross_chunk_edges(self, gen_graph): -# """ -# Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) -# ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ -# | │ A¹ │ B¹ │ | │ A¹ │ B¹ │ -# | │ 1━━┿━━3 │ => | │ 1━━┿━━3 │ -# | │ | │ │ | │ │ │ -# | │ 2 │ │ | │ 2 │ │ -# └─────┴─────┴─────┘ └─────┴─────┴─────┘ -# """ - -# cg = gen_graph(n_layers=4) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1)], -# edges=[ -# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0), inf), -# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 1), 0.5), -# ], -# timestamp=fake_timestamp, -# ) - -# # Preparation: Build Chunk B -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 2, 0, 0, 0)], -# edges=[(to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf)], -# timestamp=fake_timestamp, -# ) - -# add_layer( -# cg, -# 3, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 3, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 4, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) - -# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( -# to_label(cg, 1, 1, 0, 0, 1) -# ) -# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( -# to_label(cg, 1, 2, 0, 0, 0) -# ) - -# # Split -# new_root_ids = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 1, 0, 0, 0), -# sink_ids=to_label(cg, 1, 1, 0, 0, 1), -# mincut=False, -# ).new_root_ids - -# assert len(new_root_ids) == 2 - -# svs2 = cg.get_subgraph([new_root_ids[0]], leaves_only=True) -# svs1 = cg.get_subgraph([new_root_ids[1]], leaves_only=True) -# len_set = {1, 2} -# assert len(svs1) in len_set -# len_set.remove(len(svs1)) -# assert len(svs2) in len_set - -# # Check New State -# assert len(new_root_ids) == 2 -# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) != cg.get_root( -# to_label(cg, 1, 1, 0, 0, 1) -# ) -# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( -# to_label(cg, 1, 2, 0, 0, 0) -# ) - -# cc_dict = cg.get_atomic_cross_edges( -# cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) -# ) -# assert len(cc_dict[3]) == 1 -# assert cc_dict[3][0][0] == to_label(cg, 1, 1, 0, 0, 0) -# assert cc_dict[3][0][1] == to_label(cg, 1, 2, 0, 0, 0) - -# assert len(cg.get_latest_roots()) == 2 -# assert len(cg.get_latest_roots(fake_timestamp)) == 1 - -# @pytest.mark.timeout(30) -# def test_split_verify_loop(self, gen_graph): -# """ -# Remove edge between existing RG supervoxels 1 and 2 (neighboring chunks) -# ┌─────┬────────┬─────┐ ┌─────┬────────┬─────┐ -# | │ A¹ │ B¹ │ | │ A¹ │ B¹ │ -# | │ 4━━1━━┿━━5 │ => | │ 4 1━━┿━━5 │ -# | │ / │ | │ | │ │ | │ -# | │ 3 2━━┿━━6 │ | │ 3 2━━┿━━6 │ -# └─────┴────────┴─────┘ └─────┴────────┴─────┘ -# """ - -# cg = gen_graph(n_layers=4) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[ -# to_label(cg, 1, 1, 0, 0, 0), -# to_label(cg, 1, 1, 0, 0, 1), -# to_label(cg, 1, 1, 0, 0, 2), -# to_label(cg, 1, 1, 0, 0, 3), -# ], -# edges=[ -# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 0), inf), -# (to_label(cg, 1, 1, 0, 0, 1), to_label(cg, 1, 2, 0, 0, 1), inf), -# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 2), 0.5), -# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 3), 0.5), -# ], -# timestamp=fake_timestamp, -# ) - -# # Preparation: Build Chunk B -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 2, 0, 0, 1)], -# edges=[ -# (to_label(cg, 1, 2, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), -# (to_label(cg, 1, 2, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 1), inf), -# (to_label(cg, 1, 2, 0, 0, 1), to_label(cg, 1, 2, 0, 0, 0), 0.5), -# ], -# timestamp=fake_timestamp, -# ) - -# add_layer( -# cg, -# 3, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 3, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 4, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) - -# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( -# to_label(cg, 1, 1, 0, 0, 1) -# ) -# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == cg.get_root( -# to_label(cg, 1, 2, 0, 0, 0) -# ) - -# # Split -# new_root_ids = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 1, 0, 0, 0), -# sink_ids=to_label(cg, 1, 1, 0, 0, 2), -# mincut=False, -# ).new_root_ids - -# assert len(new_root_ids) == 2 - -# new_root_ids = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 1, 0, 0, 0), -# sink_ids=to_label(cg, 1, 1, 0, 0, 3), -# mincut=False, -# ).new_root_ids - -# assert len(new_root_ids) == 2 - -# cc_dict = cg.get_atomic_cross_edges( -# cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) -# ) -# assert len(cc_dict[3]) == 1 -# cc_dict = cg.get_atomic_cross_edges( -# cg.get_parent(to_label(cg, 1, 1, 0, 0, 0)) -# ) -# assert len(cc_dict[3]) == 1 - -# assert len(cg.get_latest_roots()) == 3 -# assert len(cg.get_latest_roots(fake_timestamp)) == 1 - -# @pytest.mark.timeout(30) -# def test_split_pair_disconnected_chunks(self, gen_graph): -# """ -# Remove edge between existing RG supervoxels 1 and 2 (disconnected chunks) -# ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ -# │ A¹ │ ... │ Z¹ │ │ A¹ │ ... │ Z¹ │ -# │ 1━━┿━━━━━┿━━2 │ => │ 1 │ │ 2 │ -# │ │ │ │ │ │ │ │ -# └─────┘ └─────┘ └─────┘ └─────┘ -# """ - -# cg = gen_graph(n_layers=9) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0)], -# edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 7, 7, 7, 0), 1.0,)], -# timestamp=fake_timestamp, -# ) - -# # Preparation: Build Chunk Z -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 7, 7, 7, 0)], -# edges=[(to_label(cg, 1, 7, 7, 7, 0), to_label(cg, 1, 0, 0, 0, 0), 1.0,)], -# timestamp=fake_timestamp, -# ) - -# add_layer( -# cg, -# 3, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 3, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 4, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 4, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 5, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 5, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 6, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 6, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 7, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 7, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 8, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 8, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# 9, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) - -# # Split -# new_roots = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 7, 7, 7, 0), -# sink_ids=to_label(cg, 1, 0, 0, 0, 0), -# mincut=False, -# ).new_root_ids - -# # Check New State -# assert len(new_roots) == 2 -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) != cg.get_root( -# to_label(cg, 1, 7, 7, 7, 0) -# ) -# leaves = np.unique( -# cg.get_subgraph([cg.get_root(to_label(cg, 1, 0, 0, 0, 0))], leaves_only=True) -# ) -# assert len(leaves) == 1 and to_label(cg, 1, 0, 0, 0, 0) in leaves -# leaves = np.unique( -# cg.get_subgraph([cg.get_root(to_label(cg, 1, 7, 7, 7, 0))], leaves_only=True) -# ) -# assert len(leaves) == 1 and to_label(cg, 1, 7, 7, 7, 0) in leaves - -# # Check Old State still accessible -# assert cg.get_root( -# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp -# ) == cg.get_root(to_label(cg, 1, 7, 7, 7, 0), time_stamp=fake_timestamp) -# leaves = np.unique( -# cg.get_subgraph( -# [cg.get_root(to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp)], -# leaves_only=True, -# ) -# ) -# assert len(leaves) == 2 -# assert to_label(cg, 1, 0, 0, 0, 0) in leaves -# assert to_label(cg, 1, 7, 7, 7, 0) in leaves - -# @pytest.mark.timeout(30) -# def test_split_pair_already_disconnected(self, gen_graph): -# """ -# Try to remove edge between already disconnected RG supervoxels 1 and 2 (same chunk). -# Expected: No change, no error -# ┌─────┐ ┌─────┐ -# │ A¹ │ │ A¹ │ -# │ 1 2 │ => │ 1 2 │ -# │ │ │ │ -# └─────┘ └─────┘ -# """ - -# cg = gen_graph(n_layers=2) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], -# edges=[], -# timestamp=fake_timestamp, -# ) - -# res_old = cg.client._table.read_rows() -# res_old.consume_all() - -# # Split -# with pytest.raises(exceptions.PreconditionError): -# cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 0, 0, 0, 1), -# sink_ids=to_label(cg, 1, 0, 0, 0, 0), -# mincut=False, -# ) - -# res_new = cg.client._table.read_rows() -# res_new.consume_all() - -# # Check -# if res_old.rows != res_new.rows: -# warn( -# "Rows were modified when splitting a pair of already disconnected supervoxels. " -# "While probably not an error, it is an unnecessary operation." -# ) - -# @pytest.mark.timeout(30) -# def test_split_full_circle_to_triple_chain_same_chunk(self, gen_graph): -# """ -# Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (same chunk) -# ┌─────┐ ┌─────┐ -# │ A¹ │ │ A¹ │ -# │ 1━2 │ => │ 1 2 │ -# │ ┗3┛ │ │ ┗3┛ │ -# └─────┘ └─────┘ -# """ - -# cg = gen_graph(n_layers=2) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[ -# to_label(cg, 1, 0, 0, 0, 0), -# to_label(cg, 1, 0, 0, 0, 1), -# to_label(cg, 1, 0, 0, 0, 2), -# ], -# edges=[ -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 2), 0.5), -# (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 0, 0, 0, 2), 0.5), -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.3), -# ], -# timestamp=fake_timestamp, -# ) - -# # Split -# new_root_ids = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 0, 0, 0, 1), -# sink_ids=to_label(cg, 1, 0, 0, 0, 0), -# mincut=False, -# ).new_root_ids - -# # Check New State -# assert len(new_root_ids) == 1 -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_ids[0] -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_ids[0] -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 2)) == new_root_ids[0] -# leaves = np.unique(cg.get_subgraph([new_root_ids[0]], leaves_only=True)) -# assert len(leaves) == 3 -# assert to_label(cg, 1, 0, 0, 0, 0) in leaves -# assert to_label(cg, 1, 0, 0, 0, 1) in leaves -# assert to_label(cg, 1, 0, 0, 0, 2) in leaves - -# # Check Old State still accessible -# old_root_id = cg.get_root( -# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp -# ) -# assert new_root_ids[0] != old_root_id - -# # assert len(cg.get_latest_roots()) == 1 -# # assert len(cg.get_latest_roots(fake_timestamp)) == 1 - -# @pytest.mark.timeout(30) -# def test_split_full_circle_to_triple_chain_neighboring_chunks(self, gen_graph): -# """ -# Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (neighboring chunks) -# ┌─────┬─────┐ ┌─────┬─────┐ -# │ A¹ │ B¹ │ │ A¹ │ B¹ │ -# │ 1━━┿━━2 │ => │ 1 │ 2 │ -# │ ┗3━┿━━┛ │ │ ┗3━┿━━┛ │ -# └─────┴─────┘ └─────┴─────┘ -# """ - -# cg = gen_graph(n_layers=3) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], -# edges=[ -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), -# (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, 1, 0, 0, 0), 0.5), -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.3), -# ], -# timestamp=fake_timestamp, -# ) - -# # Preparation: Build Chunk B -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 1, 0, 0, 0)], -# edges=[ -# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), -# (to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.3), -# ], -# timestamp=fake_timestamp, -# ) - -# add_layer( -# cg, -# 3, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) - -# # Split -# new_root_ids = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 1, 0, 0, 0), -# sink_ids=to_label(cg, 1, 0, 0, 0, 0), -# mincut=False, -# ).new_root_ids - -# # Check New State -# assert len(new_root_ids) == 1 -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_ids[0] -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_ids[0] -# assert cg.get_root(to_label(cg, 1, 1, 0, 0, 0)) == new_root_ids[0] -# leaves = np.unique(cg.get_subgraph([new_root_ids[0]], leaves_only=True)) -# assert len(leaves) == 3 -# assert to_label(cg, 1, 0, 0, 0, 0) in leaves -# assert to_label(cg, 1, 0, 0, 0, 1) in leaves -# assert to_label(cg, 1, 1, 0, 0, 0) in leaves - -# # Check Old State still accessible -# old_root_id = cg.get_root( -# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp -# ) -# assert new_root_ids[0] != old_root_id - -# assert len(cg.get_latest_roots()) == 1 -# assert len(cg.get_latest_roots(fake_timestamp)) == 1 - -# @pytest.mark.timeout(30) -# def test_split_full_circle_to_triple_chain_disconnected_chunks(self, gen_graph): -# """ -# Remove direct edge between RG supervoxels 1 and 2, but leave indirect connection (disconnected chunks) -# ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ -# │ A¹ │ ... │ Z¹ │ │ A¹ │ ... │ Z¹ │ -# │ 1━━┿━━━━━┿━━2 │ => │ 1 │ │ 2 │ -# │ ┗3━┿━━━━━┿━━┛ │ │ ┗3━┿━━━━━┿━━┛ │ -# └─────┘ └─────┘ └─────┘ └─────┘ -# """ - -# cg = gen_graph(n_layers=9) - -# loc = 2 - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], -# edges=[ -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), -# (to_label(cg, 1, 0, 0, 0, 1), to_label(cg, 1, loc, loc, loc, 0), 0.5,), -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, loc, loc, loc, 0), 0.3,), -# ], -# timestamp=fake_timestamp, -# ) - -# # Preparation: Build Chunk Z -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, loc, loc, loc, 0)], -# edges=[ -# (to_label(cg, 1, loc, loc, loc, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5,), -# (to_label(cg, 1, loc, loc, loc, 0), to_label(cg, 1, 0, 0, 0, 0), 0.3,), -# ], -# timestamp=fake_timestamp, -# ) - -# for i_layer in range(3, 10): -# if loc // 2 ** (i_layer - 3) == 1: -# add_layer( -# cg, -# i_layer, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# elif loc // 2 ** (i_layer - 3) == 0: -# add_layer( -# cg, -# i_layer, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# else: -# add_layer( -# cg, -# i_layer, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) -# add_layer( -# cg, -# i_layer, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) - -# assert ( -# cg.get_root(to_label(cg, 1, loc, loc, loc, 0)) -# == cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) -# == cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) -# ) - -# # Split -# new_root_ids = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, loc, loc, loc, 0), -# sink_ids=to_label(cg, 1, 0, 0, 0, 0), -# mincut=False, -# ).new_root_ids - -# # Check New State -# assert len(new_root_ids) == 1 -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == new_root_ids[0] -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 1)) == new_root_ids[0] -# assert cg.get_root(to_label(cg, 1, loc, loc, loc, 0)) == new_root_ids[0] -# leaves = np.unique(cg.get_subgraph([new_root_ids[0]], leaves_only=True)) -# assert len(leaves) == 3 -# assert to_label(cg, 1, 0, 0, 0, 0) in leaves -# assert to_label(cg, 1, 0, 0, 0, 1) in leaves -# assert to_label(cg, 1, loc, loc, loc, 0) in leaves - -# # Check Old State still accessible -# old_root_id = cg.get_root( -# to_label(cg, 1, 0, 0, 0, 0), time_stamp=fake_timestamp -# ) -# assert new_root_ids[0] != old_root_id - -# assert len(cg.get_latest_roots()) == 1 -# assert len(cg.get_latest_roots(fake_timestamp)) == 1 - -# @pytest.mark.timeout(30) -# def test_split_same_node(self, gen_graph): -# """ -# Try to remove (non-existing) edge between RG supervoxel 1 and itself -# ┌─────┐ -# │ A¹ │ -# │ 1 │ => Reject -# │ │ -# └─────┘ -# """ - -# cg = gen_graph(n_layers=2) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0)], -# edges=[], -# timestamp=fake_timestamp, -# ) - -# res_old = cg.client._table.read_rows() -# res_old.consume_all() - -# # Split -# with pytest.raises(exceptions.PreconditionError): -# cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 0, 0, 0, 0), -# sink_ids=to_label(cg, 1, 0, 0, 0, 0), -# mincut=False, -# ) - -# res_new = cg.client._table.read_rows() -# res_new.consume_all() - -# assert res_new.rows == res_old.rows - -# @pytest.mark.timeout(30) -# def test_split_pair_abstract_nodes(self, gen_graph): -# """ -# Try to remove (non-existing) edge between RG supervoxel 1 and abstract node "2" -# ┌─────┐ -# │ B² │ -# │ "2" │ -# │ │ -# └─────┘ -# ┌─────┐ => Reject -# │ A¹ │ -# │ 1 │ -# │ │ -# └─────┘ -# """ - -# cg = gen_graph(n_layers=3) - -# # Preparation: Build Chunk A -# fake_timestamp = datetime.utcnow() - timedelta(days=10) -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0)], -# edges=[], -# timestamp=fake_timestamp, -# ) - -# # Preparation: Build Chunk B -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 1, 0, 0, 0)], -# edges=[], -# timestamp=fake_timestamp, -# ) - -# add_layer( -# cg, -# 3, -# [0, 0, 0], -# -# time_stamp=fake_timestamp, -# n_threads=1, -# ) - -# res_old = cg.client._table.read_rows() -# res_old.consume_all() - -# # Split -# with pytest.raises(exceptions.PreconditionError): -# cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 0, 0, 0, 0), -# sink_ids=to_label(cg, 2, 1, 0, 0, 1), -# mincut=False, -# ) - -# res_new = cg.client._table.read_rows() -# res_new.consume_all() - -# assert res_new.rows == res_old.rows - -# @pytest.mark.timeout(30) -# def test_diagonal_connections(self, gen_graph): -# """ -# Create graph with edge between RG supervoxels 1 and 2 (same chunk) -# and edge between RG supervoxels 1 and 3 (neighboring chunks) -# ┌─────┬─────┐ -# │ A¹ │ B¹ │ -# │ 2━1━┿━━3 │ -# │ / │ │ -# ┌─────┬─────┐ -# │ | │ │ -# │ 4━━┿━━5 │ -# │ C¹ │ D¹ │ -# └─────┴─────┘ -# """ - -# cg = gen_graph(n_layers=3) - -# # Chunk A -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], -# edges=[ -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5), -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), inf), -# (to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf), -# ], -# ) - -# # Chunk B -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 1, 0, 0, 0)], -# edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], -# ) - -# # Chunk C -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 0, 1, 0, 0)], -# edges=[ -# (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 1, 1, 0, 0), inf), -# (to_label(cg, 1, 0, 1, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf), -# ], -# ) - -# # Chunk D -# create_chunk( -# cg, -# vertices=[to_label(cg, 1, 1, 1, 0, 0)], -# edges=[(to_label(cg, 1, 1, 1, 0, 0), to_label(cg, 1, 0, 1, 0, 0), inf)], -# ) - -# add_layer( -# cg, 3, [0, 0, 0], n_threads=1, -# ) - -# rr = cg.range_read_chunk(chunk_id=cg.get_chunk_id(layer=3, x=0, y=0, z=0)) -# root_ids_t0 = list(rr.keys()) - -# assert len(root_ids_t0) == 1 - -# child_ids = [] -# for root_id in root_ids_t0: -# child_ids.extend([cg.get_subgraph([root_id])], leaves_only=True) - -# new_roots = cg.remove_edges( -# "Jane Doe", -# source_ids=to_label(cg, 1, 0, 0, 0, 0), -# sink_ids=to_label(cg, 1, 0, 0, 0, 1), -# mincut=False, -# ).new_root_ids - -# assert len(new_roots) == 2 -# assert cg.get_root(to_label(cg, 1, 1, 1, 0, 0)) == cg.get_root( -# to_label(cg, 1, 0, 1, 0, 0) -# ) -# assert cg.get_root(to_label(cg, 1, 0, 0, 0, 0)) == cg.get_root( -# to_label(cg, 1, 0, 0, 0, 0) -# ) diff --git a/pychunkedgraph/tests/test_undo_redo.py b/pychunkedgraph/tests/test_undo_redo.py new file mode 100644 index 000000000..a49f01fe0 --- /dev/null +++ b/pychunkedgraph/tests/test_undo_redo.py @@ -0,0 +1,120 @@ +"""Integration tests for undo/redo operations through the full graph. + +Tests that undo and redo correctly restore graph state using real graph +operations through the BigTable emulator. +""" + +from datetime import datetime, timedelta, UTC + +import numpy as np +import pytest + +from .helpers import create_chunk, to_label +from ..ingest.create.parent_layer import add_parent_chunk + + +class TestUndoRedo: + @pytest.fixture() + def two_chunk_graph(self, gen_graph): + """ + Build a 2-chunk graph with edge between SVs 1 and 2. + ┌─────┬─────┐ + │ A¹ │ B¹ │ + │ 1━━┿━━2 │ + │ │ │ + └─────┴─────┘ + """ + cg = gen_graph(n_layers=3) + fake_timestamp = datetime.now(UTC) - timedelta(days=10) + + create_chunk( + cg, + vertices=[to_label(cg, 1, 0, 0, 0, 0)], + edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 1, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + create_chunk( + cg, + vertices=[to_label(cg, 1, 1, 0, 0, 0)], + edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), 0.5)], + timestamp=fake_timestamp, + ) + add_parent_chunk(cg, 3, [0, 0, 0], time_stamp=fake_timestamp, n_threads=1) + return cg + + @pytest.mark.timeout(30) + def test_undo_split_restores_merged_root(self, two_chunk_graph): + """Split two nodes, undo — nodes should share a common root again.""" + cg = two_chunk_graph + sv1 = to_label(cg, 1, 0, 0, 0, 0) + sv2 = to_label(cg, 1, 1, 0, 0, 0) + + # Initially, both SVs share a root + assert cg.get_root(sv1) == cg.get_root(sv2) + + # Split + split_result = cg.remove_edges( + "test_user", source_ids=sv1, sink_ids=sv2, mincut=False + ) + assert len(split_result.new_root_ids) == 2 + assert cg.get_root(sv1) != cg.get_root(sv2) + + # Undo the split + cg.undo_operation("test_user", split_result.operation_id) + + # After undo, both SVs should share a root again + assert cg.get_root(sv1) == cg.get_root(sv2) + + @pytest.mark.timeout(30) + def test_redo_restores_operation_result(self, two_chunk_graph): + """Split, undo, redo the original split — state should match the post-split state.""" + cg = two_chunk_graph + sv1 = to_label(cg, 1, 0, 0, 0, 0) + sv2 = to_label(cg, 1, 1, 0, 0, 0) + + # Split + split_result = cg.remove_edges( + "test_user", source_ids=sv1, sink_ids=sv2, mincut=False + ) + assert cg.get_root(sv1) != cg.get_root(sv2) + + # Undo (merges back) + cg.undo_operation("test_user", split_result.operation_id) + assert cg.get_root(sv1) == cg.get_root(sv2) + + # Redo the original split operation (re-applies the split) + cg.redo_operation("test_user", split_result.operation_id) + + # After redo, nodes should be split again + assert cg.get_root(sv1) != cg.get_root(sv2) + + @pytest.mark.timeout(30) + def test_undo_preserves_subgraph_leaves(self, two_chunk_graph): + """After undo, subgraph leaves should match the pre-operation state.""" + cg = two_chunk_graph + sv1 = to_label(cg, 1, 0, 0, 0, 0) + sv2 = to_label(cg, 1, 1, 0, 0, 0) + + # Get initial leaf set + initial_root = cg.get_root(sv1) + initial_leaves = set( + np.unique(cg.get_subgraph([initial_root], leaves_only=True)) + ) + assert sv1 in initial_leaves + assert sv2 in initial_leaves + + # Split + split_result = cg.remove_edges( + "test_user", source_ids=sv1, sink_ids=sv2, mincut=False + ) + + # Undo + cg.undo_operation("test_user", split_result.operation_id) + + # After undo, the root's subgraph should contain both SVs again + restored_root = cg.get_root(sv1) + restored_leaves = set( + np.unique(cg.get_subgraph([restored_root], leaves_only=True)) + ) + assert sv1 in restored_leaves + assert sv2 in restored_leaves diff --git a/pychunkedgraph/tests/test_utils_flatgraph.py b/pychunkedgraph/tests/test_utils_flatgraph.py new file mode 100644 index 000000000..a46ebe3c2 --- /dev/null +++ b/pychunkedgraph/tests/test_utils_flatgraph.py @@ -0,0 +1,260 @@ +"""Tests for pychunkedgraph.graph.utils.flatgraph""" + +import numpy as np + +from pychunkedgraph.graph.utils.flatgraph import ( + build_gt_graph, + connected_components, + remap_ids_from_graph, + neighboring_edges, + harmonic_mean_paths, + remove_overlapping_edges, + check_connectedness, + adjust_affinities, + flatten_edge_list, + team_paths_all_to_all, +) + + +class TestBuildGtGraph: + def test_directed(self): + edges = np.array([[0, 1], [1, 2]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, is_directed=True) + assert graph.is_directed() + assert graph.num_vertices() == 3 + assert graph.num_edges() == 2 + assert cap is None + + def test_undirected(self): + edges = np.array([[0, 1], [1, 2]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, is_directed=False) + assert not graph.is_directed() + assert graph.num_vertices() == 3 + + def test_with_weights(self): + edges = np.array([[0, 1], [1, 2]], dtype=np.uint64) + weights = np.array([0.5, 0.9]) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, weights=weights) + assert cap is not None + + def test_make_directed(self): + edges = np.array([[0, 1]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, make_directed=True) + assert graph.is_directed() + # make_directed doubles edges (forward + reverse) + assert graph.num_edges() == 2 + + def test_unique_ids_remapping(self): + # Non-contiguous node IDs + edges = np.array([[100, 200], [200, 300]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges) + np.testing.assert_array_equal(unique_ids, [100, 200, 300]) + + +class TestConnectedComponents: + def test_two_components(self): + edges = np.array([[0, 1], [2, 3]], dtype=np.uint64) + graph, _, _, _ = build_gt_graph(edges, is_directed=False) + ccs = connected_components(graph) + assert len(ccs) == 2 + + def test_single_component(self): + edges = np.array([[0, 1], [1, 2]], dtype=np.uint64) + graph, _, _, _ = build_gt_graph(edges, is_directed=False) + ccs = connected_components(graph) + assert len(ccs) == 1 + + +class TestRemapIdsFromGraph: + def test_basic(self): + unique_ids = np.array([100, 200, 300], dtype=np.uint64) + graph_ids = np.array([0, 2]) + result = remap_ids_from_graph(graph_ids, unique_ids) + np.testing.assert_array_equal(result, [100, 300]) + + +class TestNeighboringEdges: + def test_basic(self): + """Build graph 0-1-2 (undirected), neighboring_edges(graph, 1) returns neighbors of vertex 1.""" + edges = np.array([[0, 1], [1, 2]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, is_directed=False) + add_v, add_e, weights = neighboring_edges(graph, 1) + # Should return one list of vertices and one list of edges + assert len(add_v) == 1 + assert len(add_e) == 1 + # Vertex 1 has two neighbors (0 and 2) in undirected graph + neighbor_ids = sorted([int(v) for v in add_v[0]]) + assert len(neighbor_ids) == 2 + assert 0 in neighbor_ids + assert 2 in neighbor_ids + # Should return edges corresponding to those neighbors + assert len(add_e[0]) == 2 + # Weights is always [1] + assert weights == [1] + + def test_isolated_vertex(self): + """A vertex with no out-neighbors returns empty lists.""" + # Build a directed graph: 0->1. Vertex 1 has no out-neighbors. + edges = np.array([[0, 1]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, is_directed=True) + add_v, add_e, weights = neighboring_edges(graph, 1) + assert len(add_v) == 1 + assert len(add_v[0]) == 0 + assert len(add_e) == 1 + assert len(add_e[0]) == 0 + + +class TestHarmonicMeanPaths: + def test_two_values(self): + """harmonic_mean_paths([4, 16]) should return geometric mean = 8.0""" + result = harmonic_mean_paths([4, 16]) + assert result == 8.0 + + def test_single_value(self): + """harmonic_mean_paths([9]) should return 9.0""" + result = harmonic_mean_paths([9]) + assert result == 9.0 + + +class TestRemoveOverlappingEdges: + def test_no_overlap(self): + """Two path sets with no shared vertices return the same edges, do_check=False.""" + # Build two separate graphs: 0-1 and 2-3 + edges = np.array([[0, 1], [2, 3]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, is_directed=False) + # Paths for "team s": vertex 0 and vertex 1 with edge 0-1 + v0 = graph.vertex(0) + v1 = graph.vertex(1) + e01 = graph.edge(0, 1) + paths_v_s = [[v0, v1]] + paths_e_s = [[e01]] + + # Paths for "team y": vertex 2 and vertex 3 with edge 2-3 + v2 = graph.vertex(2) + v3 = graph.vertex(3) + e23 = graph.edge(2, 3) + paths_v_y = [[v2, v3]] + paths_e_y = [[e23]] + + out_s, out_y, do_check = remove_overlapping_edges( + paths_v_s, paths_e_s, paths_v_y, paths_e_y + ) + # No overlap, so do_check is False + assert do_check is False + # Original edges returned unchanged + assert out_s == paths_e_s + assert out_y == paths_e_y + + def test_with_overlap(self): + """Paths sharing some vertices cause overlapping edges to be removed, do_check=True.""" + # Build a linear graph: 0-1-2-3 (undirected) + edges = np.array([[0, 1], [1, 2], [2, 3]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, is_directed=False) + # Team s path: 0-1-2 (shares vertex 1 and 2) + v0 = graph.vertex(0) + v1 = graph.vertex(1) + v2 = graph.vertex(2) + v3 = graph.vertex(3) + e01 = graph.edge(0, 1) + e12 = graph.edge(1, 2) + e23 = graph.edge(2, 3) + + paths_v_s = [[v0, v1, v2]] + paths_e_s = [[e01, e12]] + + # Team y path: 1-2-3 (shares vertex 1 and 2) + paths_v_y = [[v1, v2, v3]] + paths_e_y = [[e12, e23]] + + out_s, out_y, do_check = remove_overlapping_edges( + paths_v_s, paths_e_s, paths_v_y, paths_e_y + ) + assert do_check is True + # Overlapping vertices are 1 and 2 + # Edges touching vertices 1 or 2 should be removed + # All edges in both paths touch vertex 1 or 2, so both should be empty + assert len(out_s[0]) == 0 + assert len(out_y[0]) == 0 + + +class TestCheckConnectedness: + def test_connected(self): + """A connected set of edges returns True.""" + # Build a connected graph: 0-1-2 + edges = np.array([[0, 1], [1, 2]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, is_directed=False) + v0 = graph.vertex(0) + v1 = graph.vertex(1) + v2 = graph.vertex(2) + e01 = graph.edge(0, 1) + e12 = graph.edge(1, 2) + + vertices = [[v0, v1, v2]] + edge_list = [[e01, e12]] + + assert check_connectedness(vertices, edge_list, expected_number=1) is True + + def test_disconnected(self): + """A disconnected set returns False (more than expected_number components).""" + # Build a graph with two disconnected components: 0-1 and 2-3 + edges = np.array([[0, 1], [2, 3]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, is_directed=False) + v0 = graph.vertex(0) + v1 = graph.vertex(1) + v2 = graph.vertex(2) + v3 = graph.vertex(3) + e01 = graph.edge(0, 1) + e23 = graph.edge(2, 3) + + # Include all vertices but edges that form two components + vertices = [[v0, v1, v2, v3]] + edge_list = [[e01, e23]] + + # Expecting 1 component but there are 2, so should return False + assert check_connectedness(vertices, edge_list, expected_number=1) is False + + +class TestAdjustAffinities: + def test_basic(self): + """Build a graph with known capacities, adjust a subset, verify capacities changed.""" + edges = np.array([[0, 1], [1, 2]], dtype=np.uint64) + weights = np.array([0.5, 0.8]) + graph, cap, g_edges, unique_ids = build_gt_graph( + edges, weights=weights, make_directed=True + ) + assert cap is not None + + # Get the edge 0->1 and adjust its affinity + e01 = graph.edge(0, 1) + original_cap_01 = cap[e01] + assert original_cap_01 == 0.5 + + paths_e = [[e01]] + new_cap = adjust_affinities(graph, cap, paths_e, value=999.0) + + # The original capacity should be unchanged (adjust_affinities copies) + assert cap[e01] == 0.5 + # The new capacity for the adjusted edge should be 999.0 + assert new_cap[e01] == 999.0 + # The reverse edge should also be adjusted + e10 = graph.edge(1, 0) + assert new_cap[e10] == 999.0 + # Edge 1->2 should be unchanged + e12 = graph.edge(1, 2) + assert new_cap[e12] == 0.8 + + +class TestFlattenEdgeList: + def test_basic(self): + """Flatten a list of graph-tool edges to unique vertex indices.""" + edges = np.array([[0, 1], [1, 2], [2, 3]], dtype=np.uint64) + graph, cap, g_edges, unique_ids = build_gt_graph(edges, is_directed=False) + e01 = graph.edge(0, 1) + e12 = graph.edge(1, 2) + e23 = graph.edge(2, 3) + + paths_e = [[e01, e12], [e23]] + result = flatten_edge_list(paths_e) + # Should contain unique vertex indices from all edges + assert isinstance(result, np.ndarray) + assert set(result.tolist()) == {0, 1, 2, 3} diff --git a/pychunkedgraph/tests/test_utils_generic.py b/pychunkedgraph/tests/test_utils_generic.py new file mode 100644 index 000000000..b6c51ea31 --- /dev/null +++ b/pychunkedgraph/tests/test_utils_generic.py @@ -0,0 +1,175 @@ +"""Tests for pychunkedgraph.graph.utils.generic""" + +import datetime + +import numpy as np +import pytz +import pytest + +from pychunkedgraph.graph.utils.generic import ( + compute_indices_pandas, + log_n, + compute_bitmasks, + get_max_time, + get_min_time, + time_min, + get_valid_timestamp, + get_bounding_box, + filter_failed_node_ids, + _get_google_compatible_time_stamp, + mask_nodes_by_bounding_box, + get_parents_at_timestamp, +) + + +class TestLogN: + def test_base2(self): + assert log_n(8, 2) == pytest.approx(3.0) + + def test_base10(self): + assert log_n(1000, 10) == pytest.approx(3.0) + + def test_other_base(self): + assert log_n(27, 3) == pytest.approx(3.0) + + def test_array_input(self): + result = log_n(np.array([4, 8, 16]), 2) + np.testing.assert_array_almost_equal(result, [2.0, 3.0, 4.0]) + + +class TestComputeBitmasks: + def test_basic(self): + bm = compute_bitmasks(4) + assert 1 in bm + assert 2 in bm + assert 3 in bm + assert 4 in bm + + def test_layer_1_equals_layer_2(self): + bm = compute_bitmasks(5) + assert bm[1] == bm[2] + + def test_insufficient_bits_raises(self): + with pytest.raises(ValueError, match="not enough"): + compute_bitmasks(4, s_bits_atomic_layer=0) + + +class TestTimeFunctions: + def test_get_max_time(self): + t = get_max_time() + assert isinstance(t, datetime.datetime) + assert t.year == 9999 + + def test_get_min_time(self): + t = get_min_time() + assert isinstance(t, datetime.datetime) + assert t.year == 2000 + + def test_time_min(self): + assert time_min() == get_min_time() + + +class TestGetValidTimestamp: + def test_none_returns_utc_now(self): + before = datetime.datetime.now(datetime.timezone.utc) + result = get_valid_timestamp(None) + after = datetime.datetime.now(datetime.timezone.utc) + assert result.tzinfo is not None + # get_valid_timestamp rounds down to millisecond precision, + # so result may be slightly before `before` + tolerance = datetime.timedelta(milliseconds=1) + assert before - tolerance <= result <= after + + def test_naive_gets_localized(self): + naive = datetime.datetime(2023, 6, 15, 12, 0, 0) + result = get_valid_timestamp(naive) + assert result.tzinfo is not None + + def test_aware_passthrough(self): + aware = datetime.datetime(2023, 6, 15, 12, 0, 0, tzinfo=pytz.UTC) + result = get_valid_timestamp(aware) + assert result.tzinfo is not None + + +class TestGoogleCompatibleTimestamp: + def test_round_down(self): + ts = datetime.datetime(2023, 6, 15, 12, 0, 0, 1500) + result = _get_google_compatible_time_stamp(ts, round_up=False) + assert result.microsecond % 1000 == 0 + assert result.microsecond == 1000 + + def test_round_up(self): + ts = datetime.datetime(2023, 6, 15, 12, 0, 0, 1500) + result = _get_google_compatible_time_stamp(ts, round_up=True) + assert result.microsecond % 1000 == 0 + assert result.microsecond == 2000 + + def test_exact_no_change(self): + ts = datetime.datetime(2023, 6, 15, 12, 0, 0, 3000) + result = _get_google_compatible_time_stamp(ts) + assert result == ts + + +class TestGetBoundingBox: + def test_normal(self): + source = np.array([[10, 20, 30]]) + sink = np.array([[50, 60, 70]]) + bbox = get_bounding_box(source, sink, bb_offset=(5, 5, 5)) + np.testing.assert_array_equal(bbox[0], [5, 15, 25]) + np.testing.assert_array_equal(bbox[1], [55, 65, 75]) + + def test_none_coords(self): + assert get_bounding_box(None, [[1, 2, 3]]) is None + assert get_bounding_box([[1, 2, 3]], None) is None + + +class TestFilterFailedNodeIds: + def test_basic(self): + row_ids = np.array([10, 20, 30, 40], dtype=np.uint64) + segment_ids = np.array([4, 3, 2, 1], dtype=np.uint64) + max_children_ids = np.array([100, 100, 200, 200]) + result = filter_failed_node_ids(row_ids, segment_ids, max_children_ids) + # Only the first occurrence of each max_children_id (by descending segment_id) survives + assert len(result) == 2 + + +class TestMaskNodesByBoundingBox: + def test_none_bbox(self): + nodes = np.array([1, 2, 3], dtype=np.uint64) + result = mask_nodes_by_bounding_box(None, nodes, bounding_box=None) + assert np.all(result) + + +class TestGetParentsAtTimestamp: + def test_normal_lookup(self): + ts1 = datetime.datetime(2023, 1, 1) + ts2 = datetime.datetime(2023, 6, 1) + ts_map = { + 10: {ts2: 100, ts1: 50}, + } + parents, skipped = get_parents_at_timestamp([10], ts_map, ts2) + assert 100 in parents + assert len(skipped) == 0 + + def test_missing_key(self): + parents, skipped = get_parents_at_timestamp([99], {}, datetime.datetime.now()) + assert len(parents) == 0 + assert 99 in skipped + + def test_unique(self): + ts = datetime.datetime(2023, 6, 1) + ts_map = { + 10: {ts: 100}, + 20: {ts: 100}, + } + parents, _ = get_parents_at_timestamp([10, 20], ts_map, ts, unique=True) + assert len(parents) == 1 + + +class TestComputeIndicesPandas: + def test_basic(self): + data = np.array([1, 2, 1, 2, 3]) + result = compute_indices_pandas(data) + assert 1 in result.index + assert 2 in result.index + assert 3 in result.index diff --git a/pychunkedgraph/tests/test_utils_id_helpers.py b/pychunkedgraph/tests/test_utils_id_helpers.py new file mode 100644 index 000000000..df8349962 --- /dev/null +++ b/pychunkedgraph/tests/test_utils_id_helpers.py @@ -0,0 +1,232 @@ +"""Tests for pychunkedgraph.graph.utils.id_helpers""" + +from unittest.mock import MagicMock + +import numpy as np + +from pychunkedgraph.graph.utils import id_helpers +from pychunkedgraph.graph.chunks import utils as chunk_utils + +from .helpers import to_label + + +class TestGetSegmentIdLimit: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + node_id = to_label(graph, 1, 0, 0, 0, 1) + limit = id_helpers.get_segment_id_limit(graph.meta, node_id) + assert limit > 0 + assert isinstance(limit, np.uint64) + + +class TestGetSegmentId: + def test_basic(self, gen_graph): + graph = gen_graph(n_layers=4) + node_id = to_label(graph, 1, 0, 0, 0, 42) + seg_id = id_helpers.get_segment_id(graph.meta, node_id) + assert seg_id == 42 + + +class TestGetNodeId: + def test_from_chunk_id(self, gen_graph): + graph = gen_graph(n_layers=4) + chunk_id = chunk_utils.get_chunk_id(graph.meta, layer=1, x=0, y=0, z=0) + node_id = id_helpers.get_node_id( + graph.meta, segment_id=np.uint64(5), chunk_id=chunk_id + ) + assert id_helpers.get_segment_id(graph.meta, node_id) == 5 + assert chunk_utils.get_chunk_layer(graph.meta, node_id) == 1 + + def test_from_components(self, gen_graph): + graph = gen_graph(n_layers=4) + node_id = id_helpers.get_node_id( + graph.meta, segment_id=np.uint64(7), layer=2, x=1, y=2, z=3 + ) + assert id_helpers.get_segment_id(graph.meta, node_id) == 7 + assert chunk_utils.get_chunk_layer(graph.meta, node_id) == 2 + coords = chunk_utils.get_chunk_coordinates(graph.meta, node_id) + np.testing.assert_array_equal(coords, [1, 2, 3]) + + +class TestGetAtomicIdFromCoord: + def test_exact_hit(self): + """When the voxel at (x, y, z) contains an atomic ID whose root matches, return it.""" + meta = MagicMock() + meta.data_source.CV_MIP = 0 + # meta.cv[x_l:x_h, y_l:y_h, z_l:z_h] returns an array block. + # For i_try=0: x_l = x - (-1)^2 = x-1, but clamped to 0 if negative; + # x_h = x + 1 + (-1)^2 = x+2. With x=0: x_l=0, x_h=2, etc. + # Simplest: put target atomic_id=42 everywhere in a small block. + meta.cv.__getitem__ = MagicMock(return_value=np.array([[[42]]])) + + root_id = np.uint64(100) + + def fake_get_root(node_id, time_stamp=None): + if node_id == 42: + return root_id + return root_id # same root for all + + result = id_helpers.get_atomic_id_from_coord( + meta, fake_get_root, 0, 0, 0, np.uint64(42), n_tries=1 + ) + assert result == np.uint64(42) + + def test_returns_none_when_no_match(self): + """When no candidate atomic ID shares the same root, return None.""" + meta = MagicMock() + meta.data_source.CV_MIP = 0 + # Return only zeros (background) from cloudvolume + meta.cv.__getitem__ = MagicMock(return_value=np.array([[[0]]])) + + root_id = np.uint64(100) + + def fake_get_root(node_id, time_stamp=None): + return root_id + + result = id_helpers.get_atomic_id_from_coord( + meta, fake_get_root, 5, 5, 5, np.uint64(999), n_tries=1 + ) + # Only candidate is 0, which is skipped, so result should be None + assert result is None + + def test_mip_scaling(self): + """Coordinates should be scaled by CV_MIP for x and y but not z.""" + meta = MagicMock() + meta.data_source.CV_MIP = 2 # scale factor of 4 for x,y + + call_args = [] + + def capture_getitem(self_mock, key): + call_args.append(key) + return np.array([[[7]]]) + + meta.cv.__getitem__ = capture_getitem + + root_id = np.uint64(200) + + def fake_get_root(node_id, time_stamp=None): + return root_id + + result = id_helpers.get_atomic_id_from_coord( + meta, fake_get_root, 8, 12, 3, np.uint64(7), n_tries=1 + ) + assert result == np.uint64(7) + # Verify that the function was called (coordinates are scaled) + assert len(call_args) >= 1 + + def test_retry_expands_search(self): + """With multiple tries, the search area should expand to find a matching ID.""" + meta = MagicMock() + meta.data_source.CV_MIP = 0 + + target_root = np.uint64(500) + wrong_root = np.uint64(999) + call_count = [0] + + def expanding_getitem(self_mock, key): + call_count[0] += 1 + if call_count[0] == 1: + # First try returns a non-matching ID + return np.array([[[10]]]) + else: + # Second try returns the matching ID + return np.array([[[10, 42]], [[10, 42]]]) + + meta.cv.__getitem__ = expanding_getitem + + def fake_get_root(node_id, time_stamp=None): + if node_id == 42: + return target_root + return wrong_root + + # parent_id=42 -> root=500; candidates: try1 has only 10 (root=999), try2 has 42 (root=500) + result = id_helpers.get_atomic_id_from_coord( + meta, fake_get_root, 5, 5, 5, np.uint64(42), n_tries=3 + ) + assert result == np.uint64(42) + assert call_count[0] >= 2 + + +class TestGetAtomicIdsFromCoords: + def test_layer1_returns_parent_id(self): + """When parent_id is already layer 1, return parent_id for all coordinates.""" + meta = MagicMock() + meta.data_source.CV_MIP = 0 + meta.resolution = np.array([1, 1, 1]) + + parent_id = np.uint64(42) + coordinates = np.array([[10, 20, 30], [40, 50, 60]]) + + def fake_get_roots( + node_ids, time_stamp=None, stop_layer=None, fail_to_zero=False + ): + return np.array([parent_id] * len(node_ids), dtype=np.uint64) + + result = id_helpers.get_atomic_ids_from_coords( + meta, + coordinates=coordinates, + parent_id=parent_id, + parent_id_layer=1, + parent_ts=None, + get_roots=fake_get_roots, + ) + + np.testing.assert_array_equal(result, [parent_id, parent_id]) + + def test_higher_layer_with_mock_cv(self): + """Test with a mocked CloudVolume that returns a known segmentation block.""" + meta = MagicMock() + meta.data_source.CV_MIP = 0 + meta.resolution = np.array([8, 8, 40]) + + parent_id = np.uint64(100) + sv1 = np.uint64(10) + sv2 = np.uint64(20) + + # Create a small segmentation volume (the CV mock) + # Coordinates: two points at [5, 5, 5] and [6, 5, 5] + coordinates = np.array([[5, 5, 5], [6, 5, 5]]) + max_dist_nm = 150 + max_dist_vx = np.ceil(max_dist_nm / np.array([8, 8, 40])).astype(np.int32) + + # Build a segmentation block big enough for the bounding box + bbox_min = np.min(coordinates, axis=0) - max_dist_vx + bbox_max = np.max(coordinates, axis=0) + max_dist_vx + 1 + shape = bbox_max - bbox_min + + seg_block = np.zeros(tuple(shape), dtype=np.uint64) + # Place sv1 at relative position of coordinate [5,5,5] + rel1 = coordinates[0] - bbox_min + seg_block[rel1[0], rel1[1], rel1[2]] = sv1 + # Place sv2 at relative position of coordinate [6,5,5] + rel2 = coordinates[1] - bbox_min + seg_block[rel2[0], rel2[1], rel2[2]] = sv2 + + meta.cv.__getitem__ = MagicMock(return_value=seg_block) + + def fake_get_roots( + node_ids, time_stamp=None, stop_layer=None, fail_to_zero=False + ): + # Map sv1 and sv2 to parent_id, everything else to 0 + result = [] + for nid in node_ids: + if nid == sv1 or nid == sv2: + result.append(parent_id) + else: + result.append(np.uint64(0)) + return np.array(result, dtype=np.uint64) + + result = id_helpers.get_atomic_ids_from_coords( + meta, + coordinates=coordinates, + parent_id=parent_id, + parent_id_layer=2, + parent_ts=None, + get_roots=fake_get_roots, + ) + + assert result is not None + assert len(result) == 2 + # Each coordinate should map to one of our supervoxels + assert np.uint64(result[0]) == sv1 + assert np.uint64(result[1]) == sv2 diff --git a/pychunkedgraph/utils/general.py b/pychunkedgraph/utils/general.py index 71e24eab0..533395f47 100644 --- a/pychunkedgraph/utils/general.py +++ b/pychunkedgraph/utils/general.py @@ -1,9 +1,20 @@ """ generic helper funtions """ + from typing import Sequence from itertools import islice +try: + from itertools import batched +except ImportError: + # Python < 3.12 fallback + def batched(iterable, n): + it = iter(iterable) + while batch := tuple(islice(it, n)): + yield batch + + import numpy as np @@ -24,18 +35,13 @@ def reverse_dictionary(dictionary): def chunked(l: Sequence, n: int): - """ - Yield successive n-sized chunks from l. - NOTE: Use itertools.batched from python 3.12 - """ + """Yield successive n-sized chunks from l.""" if n < 1: n = len(l) - it = iter(l) - while batch := tuple(islice(it, n)): - yield batch + yield from batched(l, n) def in2d(arr1: np.ndarray, arr2: np.ndarray) -> np.ndarray: arr1_view = arr1.view(dtype="u8,u8").reshape(arr1.shape[0]) arr2_view = arr2.view(dtype="u8,u8").reshape(arr2.shape[0]) - return np.in1d(arr1_view, arr2_view) + return np.isin(arr1_view, arr2_view) diff --git a/pychunkedgraph/utils/redis.py b/pychunkedgraph/utils/redis.py index 420a849f1..82921f030 100644 --- a/pychunkedgraph/utils/redis.py +++ b/pychunkedgraph/utils/redis.py @@ -18,18 +18,20 @@ ) REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "") REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0" +CONNECTION = redis.Redis.from_url(REDIS_URL, socket_timeout=60) -keys_fields = ("INGESTION_MANAGER",) -keys_defaults = ("pcg:imanager",) +keys_fields = ("INGESTION_MANAGER", "JOB_TYPE") +keys_defaults = ("pcg:imanager", "pcg:job_type") Keys = namedtuple("keys", keys_fields, defaults=keys_defaults) keys = Keys() def get_redis_connection(redis_url=REDIS_URL): - return redis.Redis.from_url(redis_url) + if redis_url == REDIS_URL: + return CONNECTION + return redis.Redis.from_url(redis_url, socket_timeout=60) def get_rq_queue(queue): - connection = redis.Redis.from_url(REDIS_URL) - return Queue(queue, connection=connection) + return Queue(queue, connection=CONNECTION) diff --git a/requirements-dev.txt b/requirements-dev.txt index 9b1a97928..1b24f9ecb 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,8 @@ pylint black +pre-commit pyopenssl jupyter -codecov ipython pytest pytest-cov diff --git a/requirements.in b/requirements.in index 63e0b3472..4bd56780b 100644 --- a/requirements.in +++ b/requirements.in @@ -5,29 +5,30 @@ grpcio>=1.36.1 numpy pandas networkx>=2.1 -google-cloud-bigtable>=0.33.0 +google-cloud-bigtable>=2.0.0 google-cloud-datastore>=1.8 flask flask_cors python-json-logger -redis -rq<2 +redis>7 +rq>2 pyyaml cachetools werkzeug +tensorstore # PyPI only: -cloud-files>=4.21.1 -cloud-volume>=8.26.0 +cloud-files>=6.0.0 +cloud-volume>=12.0.0 multiwrapper middle-auth-client>=3.11.0 zmesh>=1.7.0 fastremap>=1.14.0 -task-queue>=2.13.0 +task-queue>=2.14.0 messagingclient dracopy>=1.3.0 datastoreflex>=0.5.0 -zstandard==0.21.0 +zstandard>=0.23.0 # Conda only - use requirements.yml (or install manually): # graph-tool \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 5a2f18adc..5005893d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,119 +1,120 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile --output-file=requirements.txt requirements.in # -attrs==23.1.0 +attrs==25.4.0 # via # jsonschema # referencing -blinker==1.6.2 +blinker==1.9.0 # via flask -boto3==1.28.52 +boto3==1.42.53 # via # cloud-files # cloud-volume # task-queue -botocore==1.31.52 +botocore==1.42.53 # via # boto3 # s3transfer -brotli==1.1.0 +brotli==1.2.0 # via # cloud-files # urllib3 -cachetools==5.3.1 +cachetools==7.0.1 # via # -r requirements.in - # google-auth # middle-auth-client -certifi==2023.7.22 +certifi==2026.1.4 # via requests +cffi==2.0.0 + # via cryptography chardet==5.2.0 # via # cloud-files # cloud-volume -charset-normalizer==3.2.0 +charset-normalizer==3.4.4 # via requests -click==8.1.7 +click==8.3.1 # via # -r requirements.in # cloud-files # compressed-segmentation - # compresso # flask + # microviewer # rq # task-queue -cloud-files==4.21.1 +cloud-files==6.2.1 # via # -r requirements.in # cloud-volume # datastoreflex -cloud-volume==8.26.0 +cloud-volume==12.10.0 # via -r requirements.in -compressed-segmentation==2.2.1 - # via cloud-volume -compresso==3.2.1 +compressed-segmentation==2.3.2 # via cloud-volume -crackle-codec==0.7.0 - # via cloud-volume -crc32c==2.3.post0 +crc32c==2.8 # via cloud-files +croniter==6.0.0 + # via rq +cryptography==46.0.5 + # via google-auth datastoreflex==0.5.0 # via -r requirements.in -deflate==0.4.0 +deflate==0.8.1 # via cloud-files -dill==0.3.7 +dill==0.4.1 # via # multiprocess # pathos -dracopy==1.3.0 +dracopy==1.7.0 # via # -r requirements.in # cloud-volume -fasteners==0.19 +fasteners==0.20 # via cloud-files -fastremap==1.14.0 +fastremap==1.17.7 # via # -r requirements.in # cloud-volume - # crackle-codec -flask==2.3.3 + # osteoid +flask==3.1.3 # via # -r requirements.in # flask-cors # middle-auth-client -flask-cors==4.0.0 +flask-cors==6.0.2 # via -r requirements.in -fpzip==1.2.2 - # via cloud-volume -furl==2.1.3 +furl==2.1.4 # via middle-auth-client -gevent==23.9.1 +gevent==25.9.1 # via # cloud-files # cloud-volume # task-queue -google-api-core[grpc]==2.11.1 +google-api-core[grpc]==2.30.0 # via - # google-api-core # google-cloud-bigtable # google-cloud-core # google-cloud-datastore # google-cloud-pubsub # google-cloud-storage -google-auth==2.23.0 +google-auth==2.48.0 # via # cloud-files # cloud-volume # google-api-core + # google-cloud-bigtable # google-cloud-core + # google-cloud-datastore + # google-cloud-pubsub # google-cloud-storage # task-queue -google-cloud-bigtable==2.21.0 +google-cloud-bigtable==2.35.0 # via -r requirements.in -google-cloud-core==2.3.3 +google-cloud-core==2.5.0 # via # cloud-files # cloud-volume @@ -121,135 +122,158 @@ google-cloud-core==2.3.3 # google-cloud-datastore # google-cloud-storage # task-queue -google-cloud-datastore==2.18.0 +google-cloud-datastore==2.23.0 # via # -r requirements.in # datastoreflex -google-cloud-pubsub==2.18.4 +google-cloud-pubsub==2.35.0 # via messagingclient -google-cloud-storage==2.11.0 +google-cloud-storage==3.9.0 # via # cloud-files # cloud-volume -google-crc32c==1.5.0 +google-crc32c==1.8.0 # via # cloud-files + # google-cloud-bigtable + # google-cloud-storage # google-resumable-media -google-resumable-media==2.6.0 +google-resumable-media==2.8.0 # via google-cloud-storage -googleapis-common-protos[grpc]==1.60.0 +googleapis-common-protos[grpc]==1.72.0 # via # google-api-core # grpc-google-iam-v1 # grpcio-status -greenlet==3.0.0rc3 +greenlet==3.3.1 # via gevent -grpc-google-iam-v1==0.12.6 +grpc-google-iam-v1==0.14.3 # via # google-cloud-bigtable # google-cloud-pubsub -grpcio==1.58.0 +grpcio==1.78.0 # via # -r requirements.in # google-api-core + # google-cloud-datastore # google-cloud-pubsub # googleapis-common-protos # grpc-google-iam-v1 # grpcio-status -grpcio-status==1.58.0 +grpcio-status==1.78.0 # via # google-api-core # google-cloud-pubsub -idna==3.4 +idna==3.11 # via requests +importlib-metadata==8.7.1 + # via opentelemetry-api inflection==0.5.1 # via python-jsonschema-objects -iniconfig==2.0.0 +iniconfig==2.3.0 # via pytest -itsdangerous==2.1.2 +intervaltree==3.2.1 + # via cloud-files +itsdangerous==2.2.0 # via flask -jinja2==3.1.3 +jinja2==3.1.6 # via flask -jmespath==1.0.1 +jmespath==1.1.0 # via # boto3 # botocore -json5==0.9.14 +json5==0.13.0 # via cloud-volume -jsonschema==4.19.1 +jsonschema==4.26.0 # via # cloud-volume # python-jsonschema-objects -jsonschema-specifications==2023.7.1 +jsonschema-specifications==2025.9.1 # via jsonschema -markdown==3.4.4 +markdown==3.10.2 # via python-jsonschema-objects -markupsafe==2.1.3 +markupsafe==3.0.3 # via + # flask # jinja2 # werkzeug -messagingclient==0.1.3 +messagingclient==0.3.0 # via -r requirements.in -middle-auth-client==3.16.1 +microviewer==1.20.0 + # via cloud-volume +middle-auth-client==3.19.2 # via -r requirements.in -multiprocess==0.70.15 +ml-dtypes==0.5.4 + # via tensorstore +multiprocess==0.70.19 # via pathos multiwrapper==0.1.1 # via -r requirements.in -networkx==3.1 +networkx==3.6.1 # via # -r requirements.in # cloud-volume -numpy==1.26.0 + # osteoid +numpy==2.4.2 # via # -r requirements.in + # cloud-files # cloud-volume # compressed-segmentation - # compresso - # crackle-codec # fastremap - # fpzip # messagingclient + # microviewer + # ml-dtypes # multiwrapper + # osteoid # pandas - # pyspng-seunglab # simplejpeg # task-queue - # zfpc + # tensorstore # zmesh -orderedmultidict==1.0.1 +opentelemetry-api==1.39.1 + # via + # google-cloud-pubsub + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.39.1 + # via google-cloud-pubsub +opentelemetry-semantic-conventions==0.60b1 + # via opentelemetry-sdk +orderedmultidict==1.0.2 # via furl -orjson==3.9.7 +orjson==3.11.7 # via # cloud-files # task-queue -packaging==23.1 +osteoid==0.6.0 + # via cloud-volume +packaging==26.0 # via pytest -pandas==2.1.1 +pandas==3.0.1 # via -r requirements.in -pathos==0.3.1 +pathos==0.3.5 # via # cloud-files # cloud-volume # task-queue -pbr==5.11.1 +pbr==7.0.3 # via task-queue -pillow==10.0.1 - # via cloud-volume -pluggy==1.3.0 +pluggy==1.6.0 # via pytest -posix-ipc==1.1.1 +posix-ipc==1.3.2 # via cloud-volume -pox==0.3.3 +pox==0.3.7 # via pathos -ppft==1.7.6.7 +ppft==1.7.8 # via pathos -proto-plus==1.22.3 +proto-plus==1.27.1 # via + # google-api-core # google-cloud-bigtable # google-cloud-datastore # google-cloud-pubsub -protobuf==4.24.3 +protobuf==6.33.5 # via # -r requirements.in # cloud-files @@ -262,44 +286,47 @@ protobuf==4.24.3 # grpc-google-iam-v1 # grpcio-status # proto-plus -psutil==5.9.5 +psutil==7.2.2 # via cloud-volume -pyasn1==0.5.0 +pyasn1==0.6.2 # via # pyasn1-modules # rsa -pyasn1-modules==0.3.0 +pyasn1-modules==0.4.2 # via google-auth -pybind11==2.11.1 - # via crackle-codec -pysimdjson==5.0.2 - # via cloud-volume -pyspng-seunglab==1.1.0 +pybind11==3.0.2 + # via osteoid +pycparser==3.0 + # via cffi +pygments==2.19.2 + # via pytest +pysimdjson==7.0.2 # via cloud-volume -pytest==7.4.2 +pytest==9.0.2 # via compressed-segmentation -python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via # botocore # cloud-volume + # croniter # pandas -python-json-logger==2.0.7 +python-json-logger==4.0.0 # via -r requirements.in -python-jsonschema-objects==0.5.0 +python-jsonschema-objects==0.5.7 # via cloud-volume -pytz==2023.3.post1 - # via pandas -pyyaml==6.0.1 +pytz==2025.2 + # via croniter +pyyaml==6.0.3 # via -r requirements.in -redis==5.0.0 +redis==7.2.0 # via # -r requirements.in # rq -referencing==0.30.2 +referencing==0.37.0 # via # jsonschema # jsonschema-specifications -requests==2.31.0 +requests==2.32.5 # via # -r requirements.in # cloud-files @@ -308,64 +335,69 @@ requests==2.31.0 # google-cloud-storage # middle-auth-client # task-queue -rpds-py==0.10.3 +rpds-py==0.30.0 # via # jsonschema # referencing -rq==1.15.1 +rq==2.6.1 # via -r requirements.in -rsa==4.9 +rsa==4.9.1 # via # cloud-files # google-auth -s3transfer==0.6.2 +s3transfer==0.16.0 # via boto3 -simplejpeg==1.7.2 +simplejpeg==1.9.0 # via cloud-volume -six==1.16.0 +six==1.17.0 # via # cloud-files - # cloud-volume # furl # orderedmultidict # python-dateutil - # python-jsonschema-objects -task-queue==2.13.0 +sortedcontainers==2.4.0 + # via intervaltree +task-queue==2.14.3 # via -r requirements.in -tenacity==8.2.3 +tenacity==9.1.4 # via # cloud-files # cloud-volume # task-queue -tqdm==4.66.1 +tensorstore==0.1.81 + # via -r requirements.in +tqdm==4.67.3 # via # cloud-files # cloud-volume # task-queue -tzdata==2023.3 - # via pandas -urllib3[brotli]==1.26.16 +typing-extensions==4.15.0 + # via + # grpcio + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions + # referencing +urllib3[brotli]==2.6.3 # via # botocore # cloud-files # cloud-volume - # google-auth # requests -werkzeug==2.3.8 +werkzeug==3.1.6 # via # -r requirements.in # flask -zfpc==0.1.2 - # via cloud-volume -zfpy==1.0.0 - # via zfpc -zmesh==1.7.0 + # flask-cors +zipp==3.23.0 + # via importlib-metadata +zmesh==1.10.0 # via -r requirements.in -zope-event==5.0 +zope-event==6.1 # via gevent -zope-interface==6.0 +zope-interface==8.2 # via gevent -zstandard==0.21.0 +zstandard==0.25.0 # via # -r requirements.in # cloud-files diff --git a/requirements.yml b/requirements.yml index 0bfa5b227..9b8bc536e 100644 --- a/requirements.yml +++ b/requirements.yml @@ -2,12 +2,12 @@ name: pychunkedgraph channels: - conda-forge dependencies: - - python==3.11.4 + - python==3.12.8 - pip - tox - - uwsgi==2.0.21 - - graph-tool-base==2.58 - - zstandard==0.19.0 # ugly hack to force PyPi install 0.21.0 + - numpy + - uwsgi + - graph-tool-base==2.98 - pip: - -r requirements.txt - -r requirements-dev.txt \ No newline at end of file diff --git a/setup.py b/setup.py index e71fcab1b..077fb23df 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ def find_version(*file_paths): description="Proofreading backend for Neuroglancer", long_description=long_description, long_description_content_type="text/markdown", - url="https://github.com/seung-lab/PyChunkedGraph", + url="https://github.com/CAVEconnectome/PyChunkedGraph", packages=find_packages(), install_requires=required, dependency_links=dependency_links, diff --git a/tox.ini b/tox.ini index 5398564e6..bb15fef19 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py311 +envlist = py312 requires = tox-conda [testenv] diff --git a/tracker.py b/tracker.py deleted file mode 100644 index d2ae63cb3..000000000 --- a/tracker.py +++ /dev/null @@ -1,22 +0,0 @@ -import sys -from rq import Connection, Worker - -# Preload libraries from pychunkedgraph.ingest.cluster -from typing import Sequence, Tuple - -import numpy as np - -from pychunkedgraph.ingest.utils import chunk_id_str -from pychunkedgraph.ingest.manager import IngestionManager -from pychunkedgraph.ingest.common import get_atomic_chunk_data -from pychunkedgraph.ingest.ran_agglomeration import get_active_edges -from pychunkedgraph.ingest.create.atomic_layer import add_atomic_edges -from pychunkedgraph.ingest.create.abstract_layers import add_layer -from pychunkedgraph.graph.meta import ChunkedGraphMeta -from pychunkedgraph.graph.chunks.hierarchy import get_children_chunk_coords -from pychunkedgraph.utils.redis import keys as r_keys -from pychunkedgraph.utils.redis import get_redis_connection - -qs = sys.argv[1:] -w = Worker(qs, connection=get_redis_connection()) -w.work() \ No newline at end of file