
nukolas/fteikpy-jax

 
 


fteikpy


fteikpy is a Python library that computes accurate first-arrival traveltimes in 2D and 3D heterogeneous isotropic velocity models. The algorithm properly handles the curvature of wavefronts close to the source, which can be placed anywhere in the model, including between grid points.
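For reference, first-arrival traveltimes T solve the eikonal equation below (standard form; the notation is ours, not taken from the fteikpy docs). The homogeneous solution shows why the source neighbourhood is delicate: wavefronts are spheres of radius vT, whose curvature 1/(vT) grows without bound as x approaches the source x_s, so grid-based solvers need special care there.

```latex
\lVert \nabla T(\mathbf{x}) \rVert = \frac{1}{v(\mathbf{x})}, \qquad T(\mathbf{x}_s) = 0,
\qquad
T(\mathbf{x}) = \frac{\lVert \mathbf{x} - \mathbf{x}_s \rVert}{v} \ \text{ for constant } v .
```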

The code is based on FTeik, implemented in Python and compiled just-in-time with numba.

Figure: computation of traveltimes and ray tracing on the smoothed Marmousi velocity model.

Features

Forward modeling:

  • Compute traveltimes in 2D and 3D Cartesian grids, optionally with a different grid spacing in each of the Z, X and Y directions,
  • Compute traveltime gradients at runtime or a posteriori,
  • A posteriori 2D and 3D ray-tracing.
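To illustrate what an a posteriori gradient on an anisotropically spaced grid involves, here is a minimal sketch that does not use fteikpy at all: it samples the analytic traveltime of a homogeneous medium on nodes with different dz, dx, dy, differentiates it with NumPy's `np.gradient`, and checks the eikonal property |∇T|·v ≈ 1 away from the source. The setup values are hypothetical, and fteikpy's own gradient scheme (at runtime or a posteriori) may differ from plain central differences.

```python
import numpy as np

# Hypothetical setup, not fteikpy's API: a homogeneous medium sampled on a
# node grid with a different spacing along each axis.
v = 2000.0                          # velocity (m/s)
dz, dx, dy = 5.0, 10.0, 20.0        # grid spacings (m)
nz, nx, ny = 41, 21, 11             # number of nodes per axis
src = np.array([12.5, 37.5, 55.0])  # source (z, x, y), between nodes

Z, X, Y = np.meshgrid(
    np.arange(nz) * dz, np.arange(nx) * dx, np.arange(ny) * dy, indexing="ij"
)
dist = np.sqrt((Z - src[0]) ** 2 + (X - src[1]) ** 2 + (Y - src[2]) ** 2)
tt = dist / v                       # analytic first-arrival traveltime

# A posteriori gradient: central differences in the interior, using the
# spacing of each axis (one-sided differences on the boundary faces).
dtdz, dtdx, dtdy = np.gradient(tt, dz, dx, dy)
grad_norm = np.sqrt(dtdz**2 + dtdx**2 + dtdy**2)

# Eikonal check: |grad T| * v should be close to 1 away from the source,
# where the wavefront curvature is small compared with the grid spacing.
inner = (slice(1, -1),) * 3
far = dist[inner] > 100.0
rel_err = np.abs(grad_norm[inner][far] * v - 1.0)
print(f"max relative error: {rel_err.max():.4f}")
```

The same gradient field is what a ray tracer would follow backwards (along -∇T) from a receiver to the source.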

Parallel:

  • Traveltime grids for different sources are seamlessly computed in parallel,
  • Raypaths from a given source to different locations are also evaluated in parallel.
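Per-source parallelism works because each source defines an independent eikonal problem. fteikpy handles this internally; the sketch below only illustrates the independence, using a standard-library thread pool and a hypothetical analytic stand-in (`homogeneous_traveltime`) in place of a real solve.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def homogeneous_traveltime(source, nodes, velocity):
    """Hypothetical stand-in for one Eikonal solve: analytic traveltime
    from `source` to every node of a homogeneous model."""
    return np.linalg.norm(nodes - source, axis=-1) / velocity


# Node coordinates (z, x) of a small 2D grid with unit spacing.
z, x = np.meshgrid(np.arange(11.0), np.arange(21.0), indexing="ij")
nodes = np.stack([z, x], axis=-1)  # shape (11, 21, 2)
sources = np.array([[0.0, 0.0], [5.5, 10.25], [10.0, 20.0]])

# Sources do not interact, so their grids can be computed concurrently.
with ThreadPoolExecutor() as pool:
    grids = list(
        pool.map(lambda s: homogeneous_traveltime(s, nodes, 1500.0), sources)
    )

tt = np.stack(grids)  # shape (n_sources, nz, nx)
```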

Installation

The recommended way to install fteikpy and all its dependencies is through the Python Package Index:

pip install fteikpy --user

Alternatively, clone the repository (or download and extract an archive of it), then run from the package root:

pip install . --user

To test the integrity of the installed package, check out this repository and run:

pytest

Documentation

Refer to the online documentation for detailed description of the API and examples.

JAX backend and misfit gradients

An optional JAX backend is provided for differentiable forward modeling and reverse-mode autodiff of a simple L2 traveltime misfit. If jax is installed, you can compute gradients with respect to the velocity model:

import jax
import jax.numpy as jnp
from fteikpy import misfit_l2_jax, grad_misfit_l2_jax

# Velocity model (nz, nx, ny), grid spacings, sources and receivers
velocity = jnp.ones((32, 32, 32)) * 3000.0  # m/s
gridsize = (10.0, 10.0, 10.0)               # (dz, dx, dy)
sources = jnp.array([[50.0, 50.0, 50.0]])   # one source (z, x, y)
receivers = jnp.array([[200.0, 200.0, 200.0]])
t_obs = jnp.array([[0.0]])  # placeholder observed traveltimes, one per source-receiver pair

# Compute scalar misfit and gradient wrt velocity
J = misfit_l2_jax(velocity, gridsize, sources, receivers, t_obs)
dJ_dv = grad_misfit_l2_jax(velocity, gridsize, sources, receivers, t_obs)
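The idea behind the misfit gradient can be seen on a toy forward model where the answer is known by hand. The sketch below is not fteikpy's solver: it assumes straight rays in a homogeneous medium (T = d / v) and the convention J = 0.5 Σ (t_pred − t_obs)², then checks `jax.grad` against the hand-derived dJ/dv = Σ r·(−d / v²). The helper names are hypothetical, and `misfit_l2_jax` may use a different normalization.

```python
import jax
import jax.numpy as jnp


def toy_traveltime(v, sources, receivers):
    # Straight-ray traveltimes in a homogeneous medium, shape (n_src, n_rec).
    diff = sources[:, None, :] - receivers[None, :, :]
    return jnp.linalg.norm(diff, axis=-1) / v


def toy_misfit_l2(v, sources, receivers, t_obs):
    # Assumed convention: J = 0.5 * sum((t_pred - t_obs)^2).
    residual = toy_traveltime(v, sources, receivers) - t_obs
    return 0.5 * jnp.sum(residual**2)


sources = jnp.array([[50.0, 50.0, 50.0]])
receivers = jnp.array([[200.0, 200.0, 200.0], [50.0, 250.0, 50.0]])
t_obs = jnp.array([[0.08, 0.07]])
v = 3000.0

# Reverse-mode autodiff of the scalar misfit with respect to velocity.
grad_v = jax.grad(toy_misfit_l2)(v, sources, receivers, t_obs)

# Hand-derived reference: dJ/dv = sum(r * dT/dv) with dT/dv = -d / v^2.
d = jnp.linalg.norm(sources[:, None, :] - receivers[None, :, :], axis=-1)
r = d / v - t_obs
grad_ref = jnp.sum(r * (-d / v**2))
```

With a real eikonal solver the chain rule runs through every sweep instead of a closed form, which is why the solver itself must be written in differentiable JAX primitives.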

Notes

  • The JAX solver reimplements the fast-sweeping Eikonal update in pure jax.numpy and jax.lax to enable JIT and autodiff. The result grid shape is (nz+1, nx+1, ny+1) to keep parity with the original backend.
  • Both the traveltime solver and misfit helper dispatch to jax.jit-compiled kernels; expect a compilation hit on the first call and cached performance afterwards.
  • The misfit helper assumes, for simplicity, that all sources share the same receiver set. Per-source receiver sets could be supported in the same way.
  • The min operators in the solver are not differentiable where two candidate values tie (kinks). The misfit is differentiable almost everywhere, but exactly at a kink the value JAX returns is a convention rather than a true gradient.
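The kink behaviour can be observed on a single min operator, the kind of upwind selection that fast-sweeping updates are built from (the exact update in fteikpy's JAX backend is not reproduced here). Away from a tie the gradient is unambiguous; at a tie JAX still returns numbers, and whichever convention it uses, the two partial derivatives sum to 1.

```python
import jax
import jax.numpy as jnp


def upwind(a, b):
    # Two-candidate upwind choice: the result depends only on the smaller value.
    return jnp.minimum(a, b)


grad_fn = jax.grad(upwind, argnums=(0, 1))

ga, gb = grad_fn(1.0, 2.0)  # smooth region: d/da = 1, d/db = 0
ta, tb = grad_fn(1.5, 1.5)  # exactly at the kink: a convention, not a unique gradient
print(ga, gb)
print(ta, tb)
```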

Development with uv

This project now uses uv for dependency management. Common commands:

  • Install dev deps: uv sync --dev
  • Run tests: uv run -q pytest
  • Lint check: uv run -q bash -lc 'black --check fteikpy && isort --check fteikpy && docformatter -c -r fteikpy'
  • Format: uv run -q bash -lc 'black -t py38 fteikpy && isort fteikpy && docformatter -r -i --blank --wrap-summaries 88 --wrap-descriptions 88 --pre-summary-newline fteikpy'

The documentation can also be built locally with Sphinx via uv:

uv sync --dev --extra doc
uv run -q sphinx-build -b html doc/source doc/build

Usage

The following example computes the traveltime grid in a 3D homogeneous velocity model:

import numpy as np
from fteikpy import Eikonal3D

# Velocity model
velocity_model = np.ones((8, 8, 8))
dz, dx, dy = 1.0, 1.0, 1.0

# Solve Eikonal at source
eik = Eikonal3D(velocity_model, gridsize=(dz, dx, dy))
tt = eik.solve((0.0, 0.0, 0.0))

# Get traveltime at specific grid point
t1 = tt[0, 1, 2]

# Or get traveltime at any point in the grid
t2 = tt(np.random.rand(3) * 7.0)
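How `tt(point)` interpolates is not documented here. As an illustration only, the sketch below uses plain trilinear interpolation (hypothetical helper `trilinear`) on a node grid of shape (nz+1, nx+1, ny+1), matching the 8×8×8 model above, with the analytic traveltime of a unit-velocity medium and a source at the origin. A linear scheme like this is least accurate close to the source, where the traveltime field is most curved.

```python
import numpy as np


def trilinear(grid, spacing, point):
    """Hypothetical helper: trilinear interpolation of a node-based grid
    (origin at 0) at an arbitrary point inside it."""
    idx = np.asarray(point, dtype=float) / np.asarray(spacing, dtype=float)
    # Lower corner of the enclosing cell, clipped so the upper face is valid.
    i0 = np.clip(np.floor(idx).astype(int), 0, np.array(grid.shape) - 2)
    w = idx - i0  # fractional offsets in [0, 1]
    value = 0.0
    for corner in np.ndindex(2, 2, 2):  # the 8 nodes of the enclosing cell
        c = np.array(corner)
        weight = np.prod(np.where(c == 1, w, 1.0 - w))
        value += weight * grid[tuple(i0 + c)]
    return value


spacing = (1.0, 1.0, 1.0)
n = 9  # 8 cells per axis -> 9 nodes
z, x, y = np.meshgrid(*(np.arange(n) * h for h in spacing), indexing="ij")
tt = np.sqrt(z**2 + x**2 + y**2)  # source at the origin, v = 1

t_node = trilinear(tt, spacing, (0.0, 1.0, 2.0))  # reproduces tt[0, 1, 2]
p = np.random.rand(3) * 7.0
t_any = trilinear(tt, spacing, p)
```

Because distance is convex, this interpolant never underestimates the exact homogeneous traveltime.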

Contributing

Please refer to the Contributing Guidelines to see how you can help. This project is released with a Code of Conduct which you agree to abide by when contributing.

About

Accurate Eikonal solver for Python with JAX autodiff
