From 12ddcf773ec821a4db6abf232b10a8c7b324afe6 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 13 Mar 2026 17:15:03 +0100 Subject: [PATCH] feat!: rewrite libmysofa in Rust --- .github/workflows/audit-on-push.yml | 16 +- .github/workflows/general.yml | 100 +-- .github/workflows/release-plz.yml | 46 ++ Cargo.toml | 20 +- NOTICE | 20 + README.md | 23 +- benches/renderer.rs | 10 +- deny.toml | 29 + examples/renderer.rs | 104 +-- src/filter.rs | 23 + src/hdf/btree.rs | 187 ++++++ src/hdf/data_object.rs | 319 +++++++++ src/hdf/fractal_heap.rs | 536 +++++++++++++++ src/hdf/gcol.rs | 185 ++++++ src/hdf/helpers.rs | 60 ++ src/hdf/mod.rs | 44 ++ src/hdf/ohdr_message.rs | 976 ++++++++++++++++++++++++++++ src/hdf/parser.rs | 184 ++++++ src/hdf/super_block.rs | 183 ++++++ src/lib.rs | 30 +- src/reader.rs | 461 +++++++------ src/render.rs | 215 +++--- src/sofa/coords.rs | 200 ++++++ src/sofa/error.rs | 44 ++ src/sofa/interpolate.rs | 283 ++++++++ src/sofa/kdtree.rs | 306 +++++++++ src/sofa/lookup.rs | 190 ++++++ src/sofa/loudness.rs | 147 +++++ src/sofa/mod.rs | 27 + src/sofa/neighbors.rs | 245 +++++++ src/sofa/reader.rs | 476 ++++++++++++++ src/sofa/resample.rs | 200 ++++++ src/sofa/types.rs | 78 +++ src/sofa/validate.rs | 324 +++++++++ tests/spatial_verify.rs | 104 +++ 35 files changed, 5932 insertions(+), 463 deletions(-) create mode 100644 .github/workflows/release-plz.yml create mode 100644 NOTICE create mode 100644 deny.toml create mode 100644 src/filter.rs create mode 100644 src/hdf/btree.rs create mode 100644 src/hdf/data_object.rs create mode 100644 src/hdf/fractal_heap.rs create mode 100644 src/hdf/gcol.rs create mode 100644 src/hdf/helpers.rs create mode 100644 src/hdf/mod.rs create mode 100644 src/hdf/ohdr_message.rs create mode 100644 src/hdf/parser.rs create mode 100644 src/hdf/super_block.rs create mode 100644 src/sofa/coords.rs create mode 100644 src/sofa/error.rs create mode 100644 src/sofa/interpolate.rs create mode 100644 src/sofa/kdtree.rs create 
mode 100644 src/sofa/lookup.rs create mode 100644 src/sofa/loudness.rs create mode 100644 src/sofa/mod.rs create mode 100644 src/sofa/neighbors.rs create mode 100644 src/sofa/reader.rs create mode 100644 src/sofa/resample.rs create mode 100644 src/sofa/types.rs create mode 100644 src/sofa/validate.rs create mode 100644 tests/spatial_verify.rs diff --git a/.github/workflows/audit-on-push.yml b/.github/workflows/audit-on-push.yml index 4974b63..a345db2 100644 --- a/.github/workflows/audit-on-push.yml +++ b/.github/workflows/audit-on-push.yml @@ -4,11 +4,17 @@ on: paths: - '**/Cargo.toml' - '**/Cargo.lock' + - 'deny.toml' + pull_request: + paths: + - '**/Cargo.toml' + - '**/Cargo.lock' + - 'deny.toml' + schedule: + - cron: '0 0 * * 0' jobs: - security_audit: + cargo-deny: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v1 - - uses: actions-rs/audit-check@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} + - uses: actions/checkout@v4 + - uses: EmbarkStudios/cargo-deny-action@v2 diff --git a/.github/workflows/general.yml b/.github/workflows/general.yml index 7adb31f..6e7cf68 100644 --- a/.github/workflows/general.yml +++ b/.github/workflows/general.yml @@ -1,6 +1,9 @@ name: Rust -on: [push, pull_request] +on: + push: + branches: [main] + pull_request: env: CARGO_TERM_COLOR: always @@ -11,25 +14,25 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macOS-latest] + os: [ubuntu-latest, macos-latest] rust: [stable, beta] steps: - - uses: hecrj/setup-rust-action@v2 - with: - rust-version: ${{ matrix.rust }}${{ matrix.toolchain }} - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: submodules: true + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + - uses: Swatinem/rust-cache@v2.8.2 - name: Install dependencies if: matrix.os == 'ubuntu-latest' run: | - export DEBIAN_FRONTED=noninteractive + export DEBIAN_FRONTEND=noninteractive sudo apt-get -qq update sudo apt-get install -y libasound2-dev - name: Run 
tests - run: | - cargo test --verbose --workspace --all-features + run: cargo test --workspace --all-features test-windows: name: Test (Windows) @@ -39,73 +42,76 @@ jobs: rust: [stable, beta] steps: - - uses: hecrj/setup-rust-action@v2 - with: - rust-version: ${{ matrix.rust }} - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: submodules: true - - uses: crazy-max/ghaction-chocolatey@v1 + - uses: dtolnay/rust-toolchain@master with: - args: install -y pkgconfiglite --checksum 6004df17818f5a6dbf19cb335cc92702 + toolchain: ${{ matrix.rust }} + - uses: Swatinem/rust-cache@v2.8.2 + - name: Install pkg-config + run: choco install -y pkgconfiglite - name: Run tests - run: | - cargo test --verbose --workspace --all-features + run: cargo test --workspace --all-features fmt: name: Rustfmt runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: submodules: true - - uses: actions-rs/toolchain@v1 + - uses: dtolnay/rust-toolchain@stable with: - toolchain: stable - override: true components: rustfmt - - uses: actions-rs/cargo@v1 - with: - command: fmt - args: --all -- --check + - run: cargo fmt --all -- --check clippy: name: Clippy runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: submodules: true - - uses: actions-rs/toolchain@v1 + - name: Install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get -qq update + sudo apt-get install -y libasound2-dev + - uses: dtolnay/rust-toolchain@stable with: - toolchain: stable - override: true components: clippy - - uses: actions-rs/clippy-check@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - args: -- -D warnings + - uses: Swatinem/rust-cache@v2.8.2 + - run: cargo clippy --workspace --all-features -- -D warnings coverage: name: Code coverage runs-on: ubuntu-latest steps: - - name: Checkout repository - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: submodules: true - - - name: Install alsa dev - run: 
sudo apt-get install -y libasound2-dev - - - name: Install stable toolchain - uses: actions-rs/toolchain@v1 + - name: Install dependencies + run: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get -qq update + sudo apt-get install -y libasound2-dev + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2.8.2 + - uses: taiki-e/install-action@v2 with: - toolchain: stable - override: true - + tool: cargo-tarpaulin - name: Run cargo-tarpaulin - uses: actions-rs/tarpaulin@v0.1 + run: cargo tarpaulin --workspace --all-features --ignore-tests + + wasm: + name: WASM + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable with: - version: '0.15.0' - args: '--ignore-tests' + targets: wasm32-unknown-unknown + - uses: Swatinem/rust-cache@v2.8.2 + - name: Check WASM build + run: cargo check --lib --target wasm32-unknown-unknown --all-features diff --git a/.github/workflows/release-plz.yml b/.github/workflows/release-plz.yml new file mode 100644 index 0000000..8595e49 --- /dev/null +++ b/.github/workflows/release-plz.yml @@ -0,0 +1,46 @@ +name: Release-plz + +on: + push: + branches: + - main + +jobs: + release-plz-release: + name: Release + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + - uses: dtolnay/rust-toolchain@stable + - uses: release-plz/action@v0.5 + with: + command: release + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + + release-plz-pr: + name: Release PR + runs-on: ubuntu-latest + permissions: + contents: write + pull-requests: write + concurrency: + group: release-plz-${{ github.ref }} + cancel-in-progress: false + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + - uses: dtolnay/rust-toolchain@stable + - uses: release-plz/action@v0.5 + with: + command: release-pr + env: + GITHUB_TOKEN: 
${{ secrets.GITHUB_TOKEN }} diff --git a/Cargo.toml b/Cargo.toml index e69cf78..8e97cc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,12 @@ [package] name = "sofar" version = "0.2.1" -edition = "2021" +edition = "2024" readme = "README.md" license = "MIT OR Apache-2.0" authors = ["Tomasz Andrzejak "] keywords = ["libmysofa", "hrtf", "aes69"] -description = "Rust bindings for the libmysofa library" +description = "Pure Rust SOFA/HRTF reader and renderer" repository = "https://github.com/andreiltd/sofar" homepage = "https://github.com/andreiltd/sofar" categories = ["algorithms", "filesystem", "multimedia::audio"] @@ -14,22 +14,30 @@ categories = ["algorithms", "filesystem", "multimedia::audio"] [features] default = ["dsp"] dsp = ["dep:realfft"] +resample = ["dep:rubato", "dep:audioadapter-buffers"] [workspace] members = ["libmysofa-sys"] [dependencies] -ffi = { package = "libmysofa-sys", version = "0.2.1", path = "libmysofa-sys" } +arrayvec = "0.7.6" +bitflags = "2.9.1" +log = "0.4.27" +miniz_oxide = "0.8.9" realfft = {version = "3.4", optional = true} -thiserror = "1" +rubato = {version = "1.0.1", optional = true} +audioadapter-buffers = {version = "2.0", optional = true} +thiserror = "2" +winnow = "0.7.11" [dev-dependencies] anyhow = "1.0" +arc-swap = "1" assert_approx_eq = "1.1" hound = "3.5" -cpal = "0.15" +cpal = "0.17" criterion = "0.5" -rand = "0.8" +rand = "0.9" ringbuf = "0.4" [[bench]] diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000..724ad37 --- /dev/null +++ b/NOTICE @@ -0,0 +1,20 @@ +NOTICE + +This software contains code derived from the libmysofa project. 
+ +libmysofa +--------- +Copyright (c) 2016-2017, Symonics GmbH, Christian Hoene +Licensed under the BSD 3-Clause License +https://github.com/hoene/libmysofa + +The following components are derived from libmysofa: + + - HDF5 file format parser (src/hdf/) + - SOFA/HRTF algorithms (src/sofa/): spatial lookup, interpolation, + coordinate conversion, loudness normalization, validation + +The KD-tree implementation is based on work by: + + Copyright (C) 2007-2011 John Tsiombikas + Licensed under BSD 3-Clause License diff --git a/README.md b/README.md index 76112e8..1472555 100644 --- a/README.md +++ b/README.md @@ -3,17 +3,19 @@ # Sofar -Sofa Reader and Renderer +Pure Rust SOFA Reader and HRTF Renderer ## Features -This crate provides high level bindings to [`libmysofa`] API allows to read -`HRTF` filters from `SOFA` files (Spatially Oriented Format for Acoustics). +A pure Rust implementation for reading `HRTF` filters from `SOFA` files +(Spatially Oriented Format for Acoustics). The [`render`] module implements uniformly partitioned convolution algorithm for rendering HRTF filters. +Based on the [`libmysofa`] C library by Christian Hoene / Symonics GmbH. 
+ [`libmysofa`]: https://github.com/hoene/libmysofa [`render`]: `crate::render` @@ -33,7 +35,7 @@ let sofa = OpenOptions::new() let filt_len = sofa.filter_len(); let mut filter = Filter::new(filt_len); -// Get filter at poistion +// Get filter at position sofa.filter(0.0, 1.0, 0.0, &mut filter); let mut render = Renderer::builder(filt_len) @@ -42,7 +44,7 @@ let mut render = Renderer::builder(filt_len) .build() .unwrap(); -render.set_filter(&filter); +render.set_filter(&filter).unwrap(); let input = vec![0.0; 256]; let mut left = vec![0.0; 256]; @@ -57,9 +59,18 @@ You can run `cpal` renderer example like this: ``` shell cargo run --example renderer -- libmysofa-sys/libmysofa/share/default.sofa - ``` +## Acknowledgments + +This project is a Rust port of [libmysofa](https://github.com/hoene/libmysofa), +a C library for reading SOFA files. + +- **libmysofa** Copyright © 2016-2017 Symonics GmbH, Christian Hoene (BSD-3-Clause) +- **KD-tree** Copyright © 2007-2011 John Tsiombikas (BSD-3-Clause) + +See the [NOTICE](NOTICE) file for full attribution details. 
+ # License This project is licensed under either of diff --git a/benches/renderer.rs b/benches/renderer.rs index bcc80bc..3dadfd1 100644 --- a/benches/renderer.rs +++ b/benches/renderer.rs @@ -1,19 +1,19 @@ -use criterion::{criterion_group, criterion_main, Bencher, BenchmarkId, Criterion}; -use sofar::{reader::Filter, render::Renderer}; +use criterion::{Bencher, BenchmarkId, Criterion, criterion_group, criterion_main}; +use sofar::{filter::Filter, render::Renderer}; use rand::Rng; fn bench_renderer(b: &mut Bencher, blocks: usize, block_len: usize, filt_len: usize) { let mut filt = Filter::new(filt_len); - rand::thread_rng().fill(&mut *filt.left); - rand::thread_rng().fill(&mut *filt.right); + rand::rng().fill(&mut *filt.left); + rand::rng().fill(&mut *filt.right); let mut input = vec![0.0; blocks * block_len]; let mut left = vec![0.0; blocks * block_len]; let mut right = vec![0.0; blocks * block_len]; - rand::thread_rng().fill(input.as_mut_slice()); + rand::rng().fill(input.as_mut_slice()); let mut renderer = Renderer::builder(filt_len) .with_partition_len(block_len) diff --git a/deny.toml b/deny.toml new file mode 100644 index 0000000..8603648 --- /dev/null +++ b/deny.toml @@ -0,0 +1,29 @@ +[graph] +all-features = true + +[licenses] +allow = [ + "MIT", + "Apache-2.0", + "0BSD", + "Zlib", + "Unicode-3.0", + "BSD-3-Clause", +] +unused-allowed-license = "allow" +confidence-threshold = 0.8 + +[[licenses.clarify]] +name = "libmysofa-sys" +expression = "MIT" +license-files = [] + +[bans] +multiple-versions = "warn" +wildcards = "allow" + +[sources] +unknown-registry = "deny" +unknown-git = "deny" +allow-registry = ["https://github.com/rust-lang/crates.io-index"] +allow-git = [] diff --git a/examples/renderer.rs b/examples/renderer.rs index e238bea..55501e5 100644 --- a/examples/renderer.rs +++ b/examples/renderer.rs @@ -1,13 +1,14 @@ use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use cpal::{BufferSize, StreamConfig}; -use anyhow::{bail, Context, Error}; +use 
anyhow::{Context, Error, bail}; +use arc_swap::ArcSwap; use hound::WavReader; use sofar::reader::{Filter, OpenOptions, Sofar}; use sofar::render::Renderer; -use ringbuf::{traits::*, HeapRb}; +use ringbuf::{HeapRb, traits::*}; use std::sync::{Arc, Condvar, Mutex}; use std::{env, io::Read}; @@ -35,7 +36,7 @@ fn main() -> Result<(), Error> { bail!("Unsupported format, must be F32, mono channel"); } - println!("Wave file spec: {:?}", spec); + println!("Wave file spec: {spec:?}"); let sofa = OpenOptions::new() .sample_rate(spec.sample_rate as f32) @@ -46,7 +47,7 @@ fn main() -> Result<(), Error> { let device = host.default_output_device().unwrap(); let config = device.default_output_config().unwrap(); - println!("Default output config: {:?}", config); + println!("Default output config: {config:?}"); let mut stream_config = StreamConfig::from(config.clone()); stream_config.channels = 2; @@ -67,23 +68,25 @@ pub fn run( where R: Read + Send + 'static, { - let sample_rate = config.sample_rate.0 as f32; + let sample_rate = config.sample_rate as f32; let filt_len = sofa.filter_len(); - let mut filter = Filter::new(filt_len); - sofa.filter(1.0, 0.0, 0.0, &mut filter); + let initial_filter = Filter::new(filt_len); + let mut input_buf = vec![0.0f32; BLOCK_LEN]; let mut left = vec![0.0; BLOCK_LEN]; let mut right = vec![0.0; BLOCK_LEN]; - let render = Renderer::builder(filt_len) + let mut render = Renderer::builder(filt_len) .with_sample_rate(sample_rate) .with_partition_len(64) .build() .unwrap(); - let render = Arc::new(Mutex::new(render)); - let render_clone = render.clone(); + render.set_filter(&initial_filter).unwrap(); + + let pending_filter: Arc>> = Arc::new(ArcSwap::from_pointee(None)); + let pending_clone = Arc::clone(&pending_filter); let eos = Arc::new((Mutex::new(false), Condvar::new())); let eos_clone = Arc::clone(&eos); @@ -92,69 +95,82 @@ where let (mut producer, mut consumer) = ringbuf.split(); for _ in 0..BLOCK_LEN { - producer.try_push(0.0).unwrap(); + let _ = 
producer.try_push(0.0); } let stream = device.build_output_stream( config, move |data: &mut [f32], _: &cpal::OutputCallbackInfo| { - let left_samples = reader.samples::().len(); - let data_samples = data.len(); + let guard = pending_filter.load(); + if let Some(new_filter) = guard.as_ref() { + let _ = render.set_filter(new_filter); + pending_filter.store(Arc::new(None)); + } - if left_samples < BLOCK_LEN { - let (lock, cvar) = &*eos_clone; - let mut eos = lock.lock().unwrap(); + let data_samples = data.len(); - *eos = true; - cvar.notify_one(); + while data_samples >= consumer.occupied_len() { + let mut got = 0; + for s in reader.samples::().take(BLOCK_LEN).flatten() { + input_buf[got] = s; + got += 1; + } - return; - } + if got < BLOCK_LEN { + let (lock, cvar) = &*eos_clone; + if let Ok(mut eos) = lock.lock() { + *eos = true; + cvar.notify_one(); + } + return; + } - while data_samples >= consumer.occupied_len() { - let src = reader - .samples::() - .take(BLOCK_LEN) - .collect::, _>>() - .unwrap(); - - render - .lock() - .unwrap() - .process_block(src, &mut left, &mut right) - .unwrap(); + let _ = render.process_block(&input_buf, &mut left, &mut right); for (l, r) in Iterator::zip(left.iter(), right.iter()) { - producer.try_push(*l).unwrap(); - producer.try_push(*r).unwrap(); + let _ = producer.try_push(*l); + let _ = producer.try_push(*r); } } for dst in data.chunks_exact_mut(2) { - dst[0] = consumer.try_pop().unwrap(); - dst[1] = consumer.try_pop().unwrap(); + dst[0] = consumer.try_pop().unwrap_or(0.0); + dst[1] = consumer.try_pop().unwrap_or(0.0); } }, - |err| eprintln!("An error occurred on stream: {}", err), + |err| eprintln!("An error occurred on stream: {err}"), None, )?; stream.play()?; thread::spawn(move || { - let mut x = 1.0; - let mut y = 0.0; - let z = 0.0; + let mut x: f32 = 1.0; + let mut y: f32 = 0.0; + let z: f32 = 0.0; + + let cos_r = f32::cos(ROTATION); + let sin_r = f32::sin(ROTATION); + + let mut filter = Filter::new(filt_len); loop { - // 
rotate clockwise: https://en.wikipedia.org/wiki/Rotation_matrix - x = x * f32::cos(ROTATION) + y * f32::sin(ROTATION); - y = -x * f32::sin(ROTATION) + y * f32::cos(ROTATION); + let new_x = x * cos_r + y * sin_r; + let new_y = -x * sin_r + y * cos_r; + x = new_x; + y = new_y; println!("Pos: x: {x}, y: {y}"); sofa.filter(x, y, z, &mut filter); - render_clone.lock().unwrap().set_filter(&filter).unwrap(); + + let mut new_filter = Filter::new(filt_len); + new_filter.left.copy_from_slice(&filter.left); + new_filter.right.copy_from_slice(&filter.right); + new_filter.ldelay = filter.ldelay; + new_filter.rdelay = filter.rdelay; + + pending_clone.store(Arc::new(Some(new_filter))); thread::sleep(time::Duration::from_millis(50)); } diff --git a/src/filter.rs b/src/filter.rs new file mode 100644 index 0000000..04fdab7 --- /dev/null +++ b/src/filter.rs @@ -0,0 +1,23 @@ +/// HRTF filter output containing left and right channel impulse responses +/// and interaural time difference (ITD) delays. +pub struct Filter { + /// Impulse Response of FIR filter for left channel + pub left: Box<[f32]>, + /// Impulse Response of FIR filter for right channel + pub right: Box<[f32]>, + /// The amount of time in seconds that left channel should be delayed for + pub ldelay: f32, + /// The amount of time in seconds that right channel should be delayed for + pub rdelay: f32, +} + +impl Filter { + pub fn new(filt_len: usize) -> Self { + Self { + left: vec![0.0; filt_len].into_boxed_slice(), + right: vec![0.0; filt_len].into_boxed_slice(), + ldelay: 0.0, + rdelay: 0.0, + } + } +} diff --git a/src/hdf/btree.rs b/src/hdf/btree.rs new file mode 100644 index 0000000..6582f23 --- /dev/null +++ b/src/hdf/btree.rs @@ -0,0 +1,187 @@ +pub(crate) use super::{ + data_object::{DataLayout, DataSpace}, + helpers::varint_size, + parser::Input, +}; + +use miniz_oxide::inflate::decompress_to_vec_zlib_with_limit; +use winnow::ModalResult; +use winnow::Parser; +use winnow::binary::{le_u8, le_u16, le_u32, le_u64}; +use 
winnow::combinator::repeat; +use winnow::error::{ErrMode, ParserError, StrContext}; +use winnow::stream::Stream; +use winnow::token::{literal, take}; + +use log::info; + +/// ASCII C format: [ T, R, E, E] +pub const TREE_SIGNATURE: [u8; 4] = [0x54, 0x52, 0x45, 0x45]; + +pub(crate) fn tree( + data_len: usize, + data_space: DataSpace, + data_layout: DataLayout, +) -> impl FnMut(&mut Input) -> ModalResult> { + move |input| { + // Validate data_len to prevent excessive allocation from malformed files + const MAX_DATA_LEN: usize = 0x1000_0000; // 256MB limit, consistent with other size checks + if data_len > MAX_DATA_LEN { + return Err(ErrMode::assert( + input, + "Tree data_len exceeds maximum allowed size", + )); + } + + let dimensionality = data_space.dimensionality as usize; + let size_of_offsets = input.state.size_of_offsets(); + let size_of_lengths = input.state.size_of_lengths(); + + if dimensionality > 3 { + return Err(ErrMode::assert(input, "Tree dimension is greater than 3")); + } + + let _signature = literal(TREE_SIGNATURE).parse_next(input)?; + let node_type = le_u8.parse_next(input)?; + let _node_level = le_u8.parse_next(input)?; + let entries_used = le_u16 + .verify(|e| *e <= 0x1000) + .context(StrContext::Label("Tree entries used")) + .context(StrContext::Expected("<= 0x1000".into())) + .parse_next(input)?; + + let _address_of_left_sibling = varint_size(size_of_offsets).parse_next(input)?; + let _address_of_right_sibling = varint_size(size_of_offsets).parse_next(input)?; + + let elements = data_layout + .iter() + .take(dimensionality) + .fold(1, |acc, x| u32::saturating_mul(acc, *x)); + + let size = data_layout + .get(dimensionality) + .copied() + .ok_or_else(|| ErrMode::assert(input, "Data layout size index out of bounds"))?; + + info!("Tree elements: {elements}, size: {size}"); + + if elements == 0 || size == 0 || elements >= 0x130000 || size > 0x10 { + return Err(ErrMode::assert(input, "Invalid tree elements or size")); + } + + let mut data = vec![0; 
data_len]; + + for _ in 0..entries_used * 2 { + if node_type == 0 { + let _key = varint_size(size_of_lengths).parse_next(input)?; + } else { + let size_of_chunk = le_u32.parse_next(input)?; + let _filter_mask = le_u32 + .verify(|m| *m == 0) + .context(StrContext::Label("TREE filter mask")) + .context(StrContext::Expected( + "all filters must be enabled (0)".into(), + )) + .parse_next(input)?; + + let start: Vec = repeat(dimensionality, le_u64).parse_next(input)?; + info!("start {start:#?}"); + + let next = le_u64.parse_next(input)?; + if next != 0 { + break; + } + + let child_pointer = varint_size(size_of_offsets).parse_next(input)?; + info!(" data at {child_pointer:#x} len {size_of_chunk}"); + + if !input.state.is_address_valid(child_pointer) { + return Err(ErrMode::assert(input, "Invalid child pointer address")); + } + + let cp = input.checkpoint(); + input.input.reset_to_start(); + + let _skip = take(child_pointer as usize).parse_next(input)?; + let chunk = take(size_of_chunk).parse_next(input)?; + let olen = (elements * size) as usize; + + let output = decompress_to_vec_zlib_with_limit(chunk, olen) + .map_err(|_err| ErrMode::assert(input, "Failed to inflate btree data"))?; + + if output.len() != olen { + return Err(ErrMode::assert(input, "Invalid tree chunk length")); + } + + // Safe array access with defaults + let dy = data_layout.get(1).copied().unwrap_or(1) as u64; + let dz = data_layout.get(2).copied().unwrap_or(1) as u64; + let sx = data_space.dimension_size.first().copied().unwrap_or(1); + let sy = data_space.dimension_size.get(1).copied().unwrap_or(1); + let sz = data_space.dimension_size.get(2).copied().unwrap_or(1); + let dzy = dz * dy; + let szy = sz * sy; + + let data_len = data_len as u64; + let olen = olen as u64; + let size = size as u64; + let elements = elements as u64; + + match dimensionality { + 1 => { + for i in 0..olen { + let b = i / elements; + let x = i % elements + start[0]; + + if x < sx { + let j = x * size + b; + if j < data_len { + 
data[j as usize] = output[i as usize]; + } + } + } + } + 2 => { + for i in 0..olen { + let b = i / elements; + let mut x = i % elements; + let y = x % dy + start[1]; + x = x / dy + start[0]; + + if y < sy && x < sx { + let j = ((x * sy + y) * size) + b; + if j < data_len { + data[j as usize] = output[i as usize]; + } + } + } + } + 3 => { + for i in 0..olen { + let b = i / elements; + let mut x = i % elements; + let z = x % dz + start[2]; + let y = (x / dz) % dy + start[1]; + x = (x / dzy) + start[0]; + + if z < sz && y < sy && x < sx { + let j = (x * szy + y * sz + z) * size + b; + if j < data_len { + data[j as usize] = output[i as usize]; + } + } + } + } + _ => { + return Err(ErrMode::assert(input, "Invalid dimensionality")); + } + } + + input.reset(&cp); + } + } + + let _checksum = take(4usize).parse_next(input)?; + Ok(data) + } +} diff --git a/src/hdf/data_object.rs b/src/hdf/data_object.rs new file mode 100644 index 0000000..42686ec --- /dev/null +++ b/src/hdf/data_object.rs @@ -0,0 +1,319 @@ +use winnow::binary::le_u8; +use winnow::stream::Location; +use winnow::stream::Stream; +use winnow::token::{literal, take}; + +use winnow::ModalResult; +use winnow::Parser; +use winnow::error::StrContext; + +use arrayvec::ArrayVec; +use bitflags::bitflags; + +use super::fractal_heap::{Attribute, DirectoryEntry, FractalHeap, fractal_heap_read}; +use super::ohdr_message::{HeaderMessage, HeaderMessageKind, collect_all_messages}; + +use super::helpers::varint_size; +use super::parser::Input; + +/// ASCII C format: [ O, H, D, R] +pub const OHDR_SIGNATURE: [u8; 4] = [0x4F, 0x48, 0x44, 0x52]; +pub const DATAOBJECT_MAX_DIMENSIONALITY: usize = 5; + +/// Data Layout Chunk alias +pub(crate) type DataLayout = ArrayVec; + +#[derive(Clone, Copy, Debug)] +pub enum DataFormat { + Fixed { + bit_offset: u16, + bit_precision: u16, + }, + Float { + bit_offset: u16, + bit_precision: u16, + exponent_location: u8, + exponent_size: u8, + mantissa_location: u8, + mantissa_size: u8, + 
exponent_bias: u32, + }, +} + +#[derive(Clone, Copy, Debug)] +pub enum Record { + Type5 { hash_of_name: u32, heap_id: u64 }, +} + +impl Default for Record { + fn default() -> Self { + Record::Type5 { + hash_of_name: 0, + heap_id: 0, + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct DataSpace { + pub dimension_size: ArrayVec, + pub dimension_max_size: ArrayVec, + pub dimensionality: u8, + pub flags: u8, + pub kind: Option, +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct LinkInfo { + pub flags: u8, + pub maximum_creation_index: Option, + pub fractal_heap_address: u64, + pub address_btree_index: u64, + pub address_btree_order: Option, +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct GroupInfo { + pub flags: u8, + pub maximum_compact_value: Option, + pub minimum_dense_value: Option, + pub number_of_entries: Option, + pub length_of_entries: Option, +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct AttributeInfo { + pub flags: u8, + pub maximum_creation_index: u64, + pub fractal_heap_address: u64, + pub attribute_name_btree: u64, + pub attribute_creation_order_btree: u64, +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct DataType { + pub class_and_version: u8, + pub class_bit_field: u32, + pub size: u32, + pub list_size: Option, + pub data_fmt: Option, +} + +#[derive(Clone, Debug, Default)] +pub struct BinaryTree { + pub kind: u8, + pub split_percent: u8, + pub merge_percent: u8, + pub record_size: u16, + pub depth: u16, + pub number_of_records: u16, + pub node_size: u32, + pub root_node_address: u64, + pub total_number: u64, + pub records: Vec, +} + +#[derive(Clone, Debug)] +pub struct DataObject { + pub name: String, + pub address: u64, + pub flags: DataObjectFlags, + pub dt: DataType, + pub ds: DataSpace, + pub li: LinkInfo, + pub ai: AttributeInfo, + pub gi: GroupInfo, + + pub objects_btree: BinaryTree, + pub objects_heap: FractalHeap, + pub attributes_btree: BinaryTree, + pub attributes_heap: FractalHeap, + pub 
data_layout_chunk: DataLayout, + + pub data: Vec, + pub parsed_attributes: Vec, + pub child_directories: Vec, +} + +bitflags! { + #[derive(Clone, Copy, Debug)] + pub struct DataObjectFlags: u8 { + const SIZE_OF_CHUNK = 0b00000011; + const ATTRIBUTE_CREATION_ORDER_TRACKED = 0b00000100; + const ATTRIBUTE_CREATION_ORDER_INDEXED = 0b00001000; + const NON_DEFAULT_ATTRIBUTES_STORED = 0b00010000; + const TIMESTAMPS_STORED = 0b00100000; + } +} + +/// Version number of Superblock +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +struct OhdrVersion(u8); + +fn build_data_object_from_messages( + name: String, + address: u64, + flags: DataObjectFlags, + messages: Vec, +) -> DataObject { + let mut data_object = DataObject { + name, + address, + flags, + dt: DataType::default(), + ds: DataSpace::default(), + li: LinkInfo::default(), + ai: AttributeInfo::default(), + gi: GroupInfo::default(), + objects_btree: BinaryTree::default(), + objects_heap: FractalHeap::default(), + attributes_btree: BinaryTree::default(), + attributes_heap: FractalHeap::default(), + data_layout_chunk: DataLayout::new(), + data: Vec::new(), + parsed_attributes: Vec::new(), + child_directories: Vec::new(), + }; + + for message in messages { + match message.kind { + HeaderMessageKind::DataSpace(ds) => data_object.ds = ds, + HeaderMessageKind::LinkInfo(li) => data_object.li = li, + HeaderMessageKind::DataType(dt) => data_object.dt = dt, + HeaderMessageKind::AttributeInfo(ai) => data_object.ai = ai, + HeaderMessageKind::DataLayout(data) => data_object.data = data, + HeaderMessageKind::GroupInfo(gi) => data_object.gi = gi, + HeaderMessageKind::Attribute(Some(attr)) => { + data_object.parsed_attributes.push(attr); + } + _ => {} + } + } + + data_object +} + +pub(crate) fn data_object( + name: impl ToString, +) -> impl FnMut(&mut Input) -> ModalResult { + move |input| { + let address = input.current_token_start() as u64; + log::debug!("data_object: parsing at address {:#x}", address); + + let 
_signature = literal(OHDR_SIGNATURE) + .context(StrContext::Label("data_object signature")) + .parse_next(input)?; + log::debug!("data_object: OHDR signature OK"); + + let _version = ohdr_version + .context(StrContext::Label("data_object version")) + .parse_next(input)?; + log::debug!("data_object: version OK"); + + let flags = data_object_flags + .context(StrContext::Label("data_object flags")) + .parse_next(input)?; + log::debug!("data_object: flags={:?}", flags); + + // skip timestamps if stored + if flags.contains(DataObjectFlags::TIMESTAMPS_STORED) { + take(16usize).parse_next(input)?; + log::debug!("data_object: skipped timestamps"); + } + + let size_of_chunk_size = 1usize << (flags & DataObjectFlags::SIZE_OF_CHUNK).bits(); + let size_of_chunk = varint_size(size_of_chunk_size) + .verify(|sz| *sz <= 0x0100_0000) + .context(StrContext::Label("data_object chunk_size")) + .parse_next(input)?; + log::debug!( + "data_object: chunk_size={}, end_pos={:#x}", + size_of_chunk, + input.current_token_start() + size_of_chunk as usize + ); + + let end_of_messages = input.current_token_start() + size_of_chunk as usize; + log::debug!("data_object: parsing messages..."); + let messages = match collect_all_messages(input, end_of_messages, flags) { + Ok(m) => { + log::debug!("data_object: collected {} messages", m.len()); + m + } + Err(e) => { + log::error!("data_object: collect_all_messages failed: {:?}", e); + return Err(e); + } + }; + + // Skip final checksum + take(4usize).parse_next(input)?; + log::debug!("data_object: building from messages"); + let mut data_object = + build_data_object_from_messages(name.to_string(), address, flags, messages); + log::debug!( + "data_object: built, ai.fractal_heap={:#x}, li.fractal_heap={:#x}", + data_object.ai.fractal_heap_address, + data_object.li.fractal_heap_address + ); + + // Process attributes fractal heap + if input + .state + .is_address_valid(data_object.ai.fractal_heap_address) + { + let cp = input.checkpoint(); + 
input.input.reset_to_start(); + take(data_object.ai.fractal_heap_address as usize).parse_next(input)?; + + let (heap, heap_data) = fractal_heap_read.parse_next(input)?; + data_object.attributes_heap = heap; + data_object.parsed_attributes.extend(heap_data.attributes); + + input.reset(&cp); + } + + // Process objects fractal heap + if input + .state + .is_address_valid(data_object.li.fractal_heap_address) + { + let cp = input.checkpoint(); + input.input.reset_to_start(); + take(data_object.li.fractal_heap_address as usize).parse_next(input)?; + + let (heap, heap_data) = fractal_heap_read.parse_next(input)?; + data_object.objects_heap = heap; + data_object.child_directories.extend(heap_data.directories); + + input.reset(&cp); + } + + log::debug!("data_object: SUCCESS name={}", data_object.name); + Ok(data_object) + } +} + +fn ohdr_version(input: &mut Input) -> ModalResult { + le_u8 + .verify_map(|ver| match ver { + 2 => Some(OhdrVersion(ver)), + _ => None, + }) + .context(StrContext::Label("OHDR version")) + .context(StrContext::Expected("2".into())) + .parse_next(input) +} + +fn data_object_flags(input: &mut Input) -> ModalResult { + le_u8 + .verify_map(|flags| { + let flags = DataObjectFlags::from_bits_truncate(flags); + + (!flags.contains(DataObjectFlags::NON_DEFAULT_ATTRIBUTES_STORED)).then_some(flags) + }) + .context(StrContext::Label("OHDR flags")) + .context(StrContext::Expected( + "unsupported flags bit 4 (Non-default attributes) set".into(), + )) + .parse_next(input) +} diff --git a/src/hdf/fractal_heap.rs b/src/hdf/fractal_heap.rs new file mode 100644 index 0000000..20ab7d4 --- /dev/null +++ b/src/hdf/fractal_heap.rs @@ -0,0 +1,536 @@ +use winnow::binary::{le_u8, le_u16, le_u32, le_u64}; +use winnow::error::{ErrMode, ParserError, StrContext}; +use winnow::stream::Stream; +use winnow::token::{literal, take}; +use winnow::{ModalResult, Parser}; + +use super::helpers::varint_size; +use super::parser::Input; + +/// Fractal Heap Header signature +pub const 
FRHP_SIGNATURE: [u8; 4] = [0x46, 0x52, 0x48, 0x50]; +/// Fractal Heap Direct Block signature +pub const FHDB_SIGNATURE: [u8; 4] = [0x46, 0x48, 0x44, 0x42]; +/// Fractal Heap Indirect Block signature +pub const FHIB_SIGNATURE: [u8; 4] = [0x46, 0x48, 0x49, 0x42]; + +const MAX_NAME_LENGTH: usize = 0x100; +const MAX_RECURSIVE_DEPTH: u32 = 20; + +#[derive(Clone, Debug, Default)] +pub struct FractalHeap { + pub flags: u8, + pub heap_id_length: u16, + pub encoded_length: u16, + pub table_width: u16, + pub maximum_heap_size: u16, + pub starting_row: u16, + pub current_row: u16, + pub maximum_size: u32, + pub filter_mask: u32, + pub next_huge_object_id: u64, + pub btree_address_of_huge_objects: u64, + pub free_space: u64, + pub address_free_space: u64, + pub amount_managed_space: u64, + pub amount_allocated_space: u64, + pub offset_managed_space: u64, + pub number_managed_objects: u64, + pub size_huge_objects: u64, + pub number_huge_objects: u64, + pub size_tiny_objects: u64, + pub number_tiny_objects: u64, + pub starting_block_size: u64, + pub maximum_direct_block_size: u64, + pub address_of_root_block: u64, + pub size_of_filtered_block: u64, + pub filter_information: Vec, +} + +#[derive(Clone, Debug)] +pub struct Attribute { + pub name: String, + pub value: Option, +} + +#[derive(Clone, Debug)] +pub struct DirectoryEntry { + pub name: String, + pub address: u64, +} + +#[derive(Clone, Debug)] +pub struct FractalHeapData { + pub attributes: Vec, + pub directories: Vec, +} + +pub(crate) fn fractal_heap_read(input: &mut Input) -> ModalResult<(FractalHeap, FractalHeapData)> { + let size_of_offsets = input.state.size_of_offsets(); + let size_of_lengths = input.state.size_of_lengths(); + + let _signature = literal(FRHP_SIGNATURE).parse_next(input)?; + + let _version = le_u8 + .verify(|v| *v == 0) + .context(StrContext::Label("Fractal heap version")) + .context(StrContext::Expected("0".into())) + .parse_next(input)?; + + let heap_id_length = le_u16.parse_next(input)?; + let 
encoded_length = le_u16 + .verify(|l| *l <= 0x8000) + .context(StrContext::Label("Fractal heap encoded length")) + .context(StrContext::Expected("<= 0x8000".into())) + .parse_next(input)?; + + let flags = le_u8.parse_next(input)?; + let maximum_size = le_u32.parse_next(input)?; + + let next_huge_object_id = varint_size(size_of_lengths).parse_next(input)?; + let btree_address_of_huge_objects = varint_size(size_of_offsets).parse_next(input)?; + let free_space = varint_size(size_of_lengths).parse_next(input)?; + let address_free_space = varint_size(size_of_offsets).parse_next(input)?; + let amount_managed_space = varint_size(size_of_lengths).parse_next(input)?; + let amount_allocated_space = varint_size(size_of_lengths).parse_next(input)?; + let offset_managed_space = varint_size(size_of_lengths).parse_next(input)?; + let number_managed_objects = varint_size(size_of_lengths).parse_next(input)?; + let size_huge_objects = varint_size(size_of_lengths).parse_next(input)?; + let number_huge_objects = varint_size(size_of_lengths).parse_next(input)?; + let size_tiny_objects = varint_size(size_of_lengths).parse_next(input)?; + let number_tiny_objects = varint_size(size_of_lengths).parse_next(input)?; + + let table_width = le_u16.parse_next(input)?; + let starting_block_size = varint_size(size_of_lengths).parse_next(input)?; + let maximum_direct_block_size = varint_size(size_of_lengths).parse_next(input)?; + let maximum_heap_size = le_u16.parse_next(input)?; + let starting_row = le_u16.parse_next(input)?; + let address_of_root_block = varint_size(size_of_offsets).parse_next(input)?; + let current_row = le_u16.parse_next(input)?; + + let (size_of_filtered_block, filter_mask, filter_information) = if encoded_length > 0 { + let size_of_filtered_block = varint_size(size_of_lengths).parse_next(input)?; + let filter_mask = le_u32.parse_next(input)?; + let filter_information = take(encoded_length).parse_next(input)?.to_vec(); + (size_of_filtered_block, filter_mask, 
filter_information) + } else { + (0, 0, Vec::new()) + }; + + // Skip checksum + let _checksum = take(4usize).parse_next(input)?; + + // Validate constraints from C code + if number_huge_objects > 0 { + return Err(ErrMode::assert(input, "Cannot handle huge objects")); + } + + if number_tiny_objects > 0 { + return Err(ErrMode::assert(input, "Cannot handle tiny objects")); + } + + let fractal_heap = FractalHeap { + flags, + heap_id_length, + encoded_length, + table_width, + maximum_heap_size, + starting_row, + current_row, + maximum_size, + filter_mask, + next_huge_object_id, + btree_address_of_huge_objects, + free_space, + address_free_space, + amount_managed_space, + amount_allocated_space, + offset_managed_space, + number_managed_objects, + size_huge_objects, + number_huge_objects, + size_tiny_objects, + number_tiny_objects, + starting_block_size, + maximum_direct_block_size, + address_of_root_block, + size_of_filtered_block, + filter_information, + }; + + let mut heap_data = FractalHeapData { + attributes: Vec::new(), + directories: Vec::new(), + }; + + // Process root block if valid address + if input.state.is_address_valid(address_of_root_block) { + let cp = input.checkpoint(); + + // Seek to root block + input.input.reset_to_start(); + take(address_of_root_block as usize).parse_next(input)?; + + if current_row > 0 { + // Indirect block + let block_data = indirect_block_read(input, &fractal_heap, starting_block_size)?; + heap_data.attributes.extend(block_data.attributes); + heap_data.directories.extend(block_data.directories); + } else { + // Direct block + let block_data = direct_block_read(input, &fractal_heap)?; + heap_data.attributes.extend(block_data.attributes); + heap_data.directories.extend(block_data.directories); + } + + input.reset(&cp); + } + + Ok((fractal_heap, heap_data)) +} + +fn direct_block_read( + input: &mut Input, + fractal_heap: &FractalHeap, +) -> ModalResult { + let size_of_offsets = input.state.size_of_offsets(); + + if 
input.state.recursive_counter() >= MAX_RECURSIVE_DEPTH { + return Err(ErrMode::assert(input, "Recursive problem in fractal heap")); + } + + input.state.recursive_counter_inc(); + + let _signature = literal(FHDB_SIGNATURE).parse_next(input)?; + + let _version = le_u8 + .verify(|v| *v == 0) + .context(StrContext::Label("FHDB version")) + .context(StrContext::Expected("0".into())) + .parse_next(input)?; + + // Skip heap header address + let _heap_header_address = varint_size(size_of_offsets).parse_next(input)?; + + let size = fractal_heap.maximum_heap_size.div_ceil(8); + let _block_offset = varint_size(size as u8).parse_next(input)?; + + if fractal_heap.flags & 2 != 0 { + let _skip = take(4usize).parse_next(input)?; + } + + let offset_size = ((fractal_heap.maximum_heap_size as f32).log2() / 8.0).ceil() as u8; + let length_size = if fractal_heap.maximum_direct_block_size < fractal_heap.maximum_size as u64 { + ((fractal_heap.maximum_direct_block_size as f32).log2() / 8.0).ceil() as u8 + } else { + ((fractal_heap.maximum_size as f32).log2() / 8.0).ceil() as u8 + }; + + let mut block_data = FractalHeapData { + attributes: Vec::new(), + directories: Vec::new(), + }; + + loop { + let type_and_version = le_u8.parse_next(input)?; + if type_and_version == 0 { + break; + } + + let _offset = varint_size(offset_size).parse_next(input)?; + let length = varint_size(length_size).parse_next(input)?; + + if length > 0x10000000 { + return Err(ErrMode::assert(input, "FHDB length too large")); + } + + match type_and_version { + 3 => { + // Name-value pair attribute + let attr = parse_type_3_attribute(input, length as usize)?; + block_data.attributes.push(attr); + } + 1 => { + // Directory entry or complex attribute + let entry_data = parse_type_1_entry(input, length as usize)?; + block_data.attributes.extend(entry_data.attributes); + block_data.directories.extend(entry_data.directories); + } + _ => { + log::warn!("Unknown fractal heap type: {type_and_version}"); + // Skip unknown types 
gracefully to continue parsing + let _skip = take(length as usize).parse_next(input)?; + } + } + } + + input.state.recursive_counter_dec(); + Ok(block_data) +} + +fn parse_type_3_attribute(input: &mut Input, length: usize) -> ModalResult { + // Parse the magic number first + let _magic = varint_size(5usize) + .verify(|v| *v == 0x0000040008) + .context(StrContext::Label("FHDB type 3 magic")) + .context(StrContext::Expected("0x0000040008".into())) + .parse_next(input)?; + + // Read the name with the specified length (may contain non-UTF8 bytes) + let name_bytes: &[u8] = take(length).parse_next(input)?; + let name = String::from_utf8_lossy(name_bytes) + .trim_matches(|c: char| c.is_whitespace() || c == '\0') + .to_string(); + + // Read the second magic number + let _magic2 = le_u32 + .verify(|v| *v == 0x00000013) + .context(StrContext::Label("FHDB type 3 magic2")) + .context(StrContext::Expected("0x00000013".into())) + .parse_next(input)?; + + let value_len = le_u16 + .verify(|l| *l <= 0x1000) + .context(StrContext::Label("FHDB type 3 value length")) + .context(StrContext::Expected("<= 0x1000".into())) + .parse_next(input)? 
as usize; + + let unknown1 = varint_size(6usize).parse_next(input)?; + + let value = match unknown1 { + 0x000000020200 => None, + 0x000000020000 => { + let val_bytes: &[u8] = take(value_len).parse_next(input)?; + Some( + String::from_utf8_lossy(val_bytes) + .trim_matches(|c: char| c.is_whitespace() || c == '\0') + .to_string(), + ) + } + 0x20000020000 => Some(String::new()), + _ => { + log::warn!("Unsupported FHDB type 3 value format: {unknown1:#x}"); + return Err(ErrMode::assert(input, "Unsupported FHDB type 3 format")); + } + }; + + Ok(Attribute { name, value }) +} + +fn parse_type_1_entry(input: &mut Input, _length: usize) -> ModalResult { + let size_of_offsets = input.state.size_of_offsets(); + let mut entry_data = FractalHeapData { + attributes: Vec::new(), + directories: Vec::new(), + }; + + let unknown2 = le_u32.parse_next(input)?; + + match unknown2 { + 0 => { + // Directory entry case + let _unknown3 = le_u16 + .verify(|v| *v == 0x0000) + .context(StrContext::Label("FHDB type 1 unknown3")) + .context(StrContext::Expected("0x0000".into())) + .parse_next(input)?; + + let name_len = le_u8 + .verify(|l| (*l as usize) < MAX_NAME_LENGTH) + .context(StrContext::Label("FHDB type 1 name length")) + .context(StrContext::Expected("reasonable name length".into())) + .parse_next(input)? 
as usize; + + let name = take(name_len).parse_to::().parse_next(input)?; + let heap_header_address = varint_size(size_of_offsets).parse_next(input)?; + + log::info!("Directory entry: {name} at address {heap_header_address:#x}"); + + entry_data.directories.push(DirectoryEntry { + name, + address: heap_header_address, + }); + } + 0x00080008 | 0x00040008 => { + // Complex attribute cases + let attr = parse_complex_attribute(input)?; + entry_data.attributes.push(attr); + } + _ => { + log::warn!("FHDB type 1 unsupported values {unknown2:#08x}"); + return Err(ErrMode::assert(input, "Unsupported FHDB type 1 format")); + } + } + + Ok(entry_data) +} + +fn parse_complex_attribute(input: &mut Input) -> ModalResult { + // Both 0x00080008 and 0x00040008 use the same name parsing logic + // Use stack-allocated buffer to avoid heap allocation for each attribute + let mut name_bytes = [0u8; MAX_NAME_LENGTH]; + let mut len: Option = None; + + for (i, name_byte) in name_bytes.iter_mut().enumerate().take(MAX_NAME_LENGTH) { + let c = le_u8.parse_next(input)?; + *name_byte = c; + + if len.is_none() && c == 0 { + len = Some(i); + } + if c == 0x13 { + if len.is_none() { + // No null terminator found before sentinel; use position up to sentinel + len = Some(i); + } + break; + } + } + + let name_end = len.unwrap_or(0); + + // Convert to string up to the null terminator + let name = String::from_utf8_lossy(&name_bytes[..name_end]).to_string(); + + // Read exactly 3 bytes for the reserved field (must be 0x000000 per C spec) + let _reserved = varint_size(3usize) + .verify(|v| *v == 0) + .context(StrContext::Label("Complex attribute reserved bytes")) + .context(StrContext::Expected("0x000000".into())) + .parse_next(input)?; + + let value_len = le_u32 + .verify(|l| *l <= 0x1000) + .context(StrContext::Label("Complex attribute value length")) + .context(StrContext::Expected("<= 0x1000".into())) + .parse_next(input)? 
as usize; + + let unknown4 = le_u64.parse_next(input)?; + + let value = match unknown4 { + 0x00000001 => { + let val = take(value_len).parse_to::().parse_next(input)?; + Some(val) + } + 0x02000002 => None, // No value + _ => { + log::warn!("Unknown complex attribute format: {unknown4:#x}"); + return Err(ErrMode::assert( + input, + "Unsupported complex attribute format", + )); + } + }; + + log::info!("Complex attribute: {name} = {value:?}"); + Ok(Attribute { + name: name + .trim_matches(|c: char| c.is_whitespace() || c == '\0') + .to_string(), + value, + }) +} + +fn indirect_block_read( + input: &mut Input, + fractal_heap: &FractalHeap, + iblock_size: u64, +) -> ModalResult { + let size_of_offsets = input.state.size_of_offsets(); + let size_of_lengths = input.state.size_of_lengths(); + + if input.state.recursive_counter() >= MAX_RECURSIVE_DEPTH { + return Err(ErrMode::assert(input, "Recursive problem in fractal heap")); + } + + input.state.recursive_counter_inc(); + + let _signature = literal(FHIB_SIGNATURE).parse_next(input)?; + + let _version = le_u8 + .verify(|v| *v == 0) + .context(StrContext::Label("FHIB version")) + .context(StrContext::Expected("0".into())) + .parse_next(input)?; + + // Skip heap header address + let _heap_header_address = varint_size(size_of_offsets).parse_next(input)?; + + let size = fractal_heap.maximum_heap_size.div_ceil(8); + let block_offset = varint_size(size as u8).parse_next(input)?; + + if block_offset != 0 { + return Err(ErrMode::assert(input, "FHIB block offset is not 0")); + } + + // Calculate nrows and max_dblock_rows using log2 + let nrows = (iblock_size.ilog2() - fractal_heap.starting_block_size.ilog2()) + 1; + let max_dblock_rows = (fractal_heap.maximum_direct_block_size.ilog2() + - fractal_heap.starting_block_size.ilog2()) + + 2; + + let k = if nrows < max_dblock_rows { + nrows * fractal_heap.table_width as u32 + } else { + max_dblock_rows * fractal_heap.table_width as u32 + }; + + let n = if nrows <= max_dblock_rows { + 0 + 
} else { + k - (max_dblock_rows * fractal_heap.table_width as u32) + }; + + let mut block_data = FractalHeapData { + attributes: Vec::new(), + directories: Vec::new(), + }; + + // Process direct blocks + for _ in 0..k { + let child_direct_block = varint_size(size_of_offsets).parse_next(input)?; + + if fractal_heap.encoded_length > 0 { + let _size_filtered = varint_size(size_of_lengths).parse_next(input)?; + let _filter_mask = le_u32.parse_next(input)?; + } + + log::info!("Processing direct block at {child_direct_block:#x}"); + + if input.state.is_address_valid(child_direct_block) { + let cp = input.checkpoint(); + + input.input.reset_to_start(); + take(child_direct_block as usize).parse_next(input)?; + + let direct_data = direct_block_read(input, fractal_heap)?; + block_data.attributes.extend(direct_data.attributes); + block_data.directories.extend(direct_data.directories); + + input.reset(&cp); + } + } + + // Process indirect blocks + for _ in 0..n { + let child_indirect_block = varint_size(size_of_offsets).parse_next(input)?; + + log::info!("Processing indirect block at {child_indirect_block:#x}"); + + if input.state.is_address_valid(child_indirect_block) { + let cp = input.checkpoint(); + + input.input.reset_to_start(); + take(child_indirect_block as usize).parse_next(input)?; + + let indirect_data = indirect_block_read(input, fractal_heap, iblock_size * 2)?; + block_data.attributes.extend(indirect_data.attributes); + block_data.directories.extend(indirect_data.directories); + + input.reset(&cp); + } + } + + input.state.recursive_counter_dec(); + Ok(block_data) +} diff --git a/src/hdf/gcol.rs b/src/hdf/gcol.rs new file mode 100644 index 0000000..b05163b --- /dev/null +++ b/src/hdf/gcol.rs @@ -0,0 +1,185 @@ +//! Global Heap Collection (GCOL) parser. +//! +//! HDF5 Level 1E - Global Heap for variable-length data. +//! Used to store variable-length strings and object references. 
+ +use std::collections::HashMap; + +use winnow::binary::{le_u8, le_u16}; +use winnow::error::StrContext; +use winnow::stream::Location; +use winnow::token::{literal, take}; +use winnow::{ModalResult, Parser}; + +use super::helpers::varint_size; +use super::parser::Input; + +/// ASCII signature: "GCOL" +const GCOL_SIGNATURE: [u8; 4] = [0x47, 0x43, 0x4F, 0x4C]; + +/// A single object in the global heap. +#[derive(Debug, Clone)] +pub struct GcolObject { + /// Heap object index (reference ID) + pub heap_object_index: u16, + /// Size of the object data + pub object_size: u64, + /// The actual value (for small objects <= 8 bytes) + pub value: u64, +} + +/// Global Heap Collection - stores variable-length data. +#[derive(Debug, Clone, Default)] +pub struct GlobalHeap { + /// Objects indexed by (collection_address, heap_object_index) + objects: HashMap<(u64, u16), GcolObject>, +} + +impl GlobalHeap { + pub fn new() -> Self { + Self { + objects: HashMap::new(), + } + } + + /// Look up an object by collection address and reference index. + pub fn get(&self, address: u64, reference: u16) -> Option<&GcolObject> { + self.objects.get(&(address, reference)) + } + + /// Insert objects from a parsed collection. + pub fn insert(&mut self, address: u64, objects: Vec) { + for obj in objects { + self.objects.insert((address, obj.heap_object_index), obj); + } + } + + /// Check if a collection at the given address has been parsed. + pub fn has_collection(&self, address: u64) -> bool { + self.objects.keys().any(|(addr, _)| *addr == address) + } +} + +/// Parse a Global Heap Collection at the current position. +/// +/// Returns a list of objects in the collection. 
+pub(crate) fn gcol_read(input: &mut Input) -> ModalResult> { + let size_of_lengths = input.state.size_of_lengths(); + + // Read signature + literal(GCOL_SIGNATURE) + .context(StrContext::Label("GCOL signature")) + .parse_next(input)?; + + // Version must be 1 + le_u8 + .verify(|v| *v == 1) + .context(StrContext::Label("GCOL version")) + .context(StrContext::Expected("1".into())) + .parse_next(input)?; + + // Skip 3 reserved bytes + take(3usize).parse_next(input)?; + + // Collection size + let collection_size = varint_size(size_of_lengths) + .verify(|s| *s <= 0x4_0000_0000) // 16GB limit + .context(StrContext::Label("GCOL collection size")) + .parse_next(input)?; + + // Calculate end position (collection_size includes the 8-byte header we already read) + let start_pos = input.current_token_start(); + let end_pos = start_pos + collection_size as usize - 8; + + let mut objects = Vec::new(); + + // Read objects until we reach the end or encounter index 0 + while input.current_token_start() + 8 + size_of_lengths as usize <= end_pos { + let heap_object_index = le_u16.parse_next(input)?; + + // Index 0 marks end of objects + if heap_object_index == 0 { + break; + } + + // Reference count (unused) + let _reference_count = le_u16.parse_next(input)?; + + // Skip 4 reserved bytes + take(4usize).parse_next(input)?; + + // Object size + let object_size = varint_size(size_of_lengths).parse_next(input)?; + + // For now, only support small objects (value fits in u64) + if object_size > 8 { + log::warn!( + "GCOL object {} has size {} > 8, skipping value read", + heap_object_index, + object_size + ); + // Skip the object data + take(object_size as usize).parse_next(input)?; + continue; + } + + // Read the value + let value = if object_size > 0 { + varint_size(object_size as u8).parse_next(input)? 
+ } else { + 0 + }; + + log::info!( + "GCOL object {} size {} value {:#x}", + heap_object_index, + object_size, + value + ); + + objects.push(GcolObject { + heap_object_index, + object_size, + value, + }); + } + + Ok(objects) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_global_heap_lookup() { + let mut heap = GlobalHeap::new(); + + let objects = vec![ + GcolObject { + heap_object_index: 1, + object_size: 4, + value: 0x12345678, + }, + GcolObject { + heap_object_index: 2, + object_size: 8, + value: 0xDEADBEEF, + }, + ]; + + heap.insert(0x1000, objects); + + assert!(heap.has_collection(0x1000)); + assert!(!heap.has_collection(0x2000)); + + let obj1 = heap.get(0x1000, 1).unwrap(); + assert_eq!(obj1.value, 0x12345678); + + let obj2 = heap.get(0x1000, 2).unwrap(); + assert_eq!(obj2.value, 0xDEADBEEF); + + assert!(heap.get(0x1000, 3).is_none()); + assert!(heap.get(0x2000, 1).is_none()); + } +} diff --git a/src/hdf/helpers.rs b/src/hdf/helpers.rs new file mode 100644 index 0000000..ad51c30 --- /dev/null +++ b/src/hdf/helpers.rs @@ -0,0 +1,60 @@ +use winnow::ModalResult; +use winnow::Parser; +use winnow::stream::{AsBytes, Stream, StreamIsPartial, ToUsize}; +use winnow::token::take; + +/// Parse a variable-sized little-endian integer (1-8 bytes). 
+#[inline] +pub(crate) fn varint_size(size: impl ToUsize) -> impl FnMut(&mut S) -> ModalResult +where + S: StreamIsPartial + Stream, + S::Slice: AsBytes, +{ + use winnow::error::ErrMode; + + move |input| { + let size = size.to_usize(); + if size > 8 { + return Err(ErrMode::Cut(winnow::error::ContextError::new())); + } + + let mut size_of_chunk = [0u8; 8]; + let bytes = take(size).parse_next(input)?; + + size_of_chunk[..size].copy_from_slice(bytes.as_bytes()); + Ok(u64::from_le_bytes(size_of_chunk)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_varint_2_bytes() { + let mut input: &[u8] = &[0x34, 0x12]; + let result = varint_size(2usize).parse_next(&mut input).unwrap(); + assert_eq!(result, 0x1234); + } + + #[test] + fn test_varint_4_bytes() { + let mut input: &[u8] = &[0x78, 0x56, 0x34, 0x12]; + let result = varint_size(4usize).parse_next(&mut input).unwrap(); + assert_eq!(result, 0x12345678); + } + + #[test] + fn test_varint_8_bytes() { + let mut input: &[u8] = &[0xEF, 0xCD, 0xAB, 0x89, 0x67, 0x45, 0x23, 0x01]; + let result = varint_size(8usize).parse_next(&mut input).unwrap(); + assert_eq!(result, 0x0123456789ABCDEF); + } + + #[test] + fn test_varint_size_too_large_returns_error() { + let mut input: &[u8] = &[0; 16]; + let result = varint_size(9usize).parse_next(&mut input); + assert!(result.is_err(), "Expected error for size > 8"); + } +} diff --git a/src/hdf/mod.rs b/src/hdf/mod.rs new file mode 100644 index 0000000..6a1e87d --- /dev/null +++ b/src/hdf/mod.rs @@ -0,0 +1,44 @@ +//! HDF5 file format parser for SOFA files. +//! +//! This module provides a pure Rust implementation of HDF5 parsing, +//! specifically tailored for reading SOFA (Spatially Oriented Format for Acoustics) files. 
+ +mod btree; +mod data_object; +mod fractal_heap; +mod gcol; +mod helpers; +mod ohdr_message; +mod parser; +mod super_block; + +pub use data_object::{DataFormat, DataObject, DataSpace, DataType, GroupInfo, Record}; +pub use fractal_heap::{Attribute, DirectoryEntry, FractalHeapData}; +pub use gcol::GlobalHeap; +pub use parser::ParsedHdf; +pub use super_block::SuperBlock; + +/// Parses an HDF5/SOFA file from bytes and returns the root DataObject. +/// +/// # Errors +/// +/// Returns an error if the input is not a valid HDF5 file or if parsing fails. +pub fn parse(input: &[u8]) -> Result { + parser::parse(input).map_err(|e| match e { + winnow::error::ErrMode::Backtrack(e) | winnow::error::ErrMode::Cut(e) => e, + winnow::error::ErrMode::Incomplete(_) => winnow::error::ContextError::new(), + }) +} + +/// Parses an HDF5/SOFA file and returns a navigable structure that allows +/// parsing child objects. +/// +/// # Errors +/// +/// Returns an error if the input is not a valid HDF5 file or if parsing fails. 
+pub fn parse_with_children(input: &[u8]) -> Result, winnow::error::ContextError> { + parser::parse_with_children(input).map_err(|e| match e { + winnow::error::ErrMode::Backtrack(e) | winnow::error::ErrMode::Cut(e) => e, + winnow::error::ErrMode::Incomplete(_) => winnow::error::ContextError::new(), + }) +} diff --git a/src/hdf/ohdr_message.rs b/src/hdf/ohdr_message.rs new file mode 100644 index 0000000..7e5ace7 --- /dev/null +++ b/src/hdf/ohdr_message.rs @@ -0,0 +1,976 @@ +use winnow::binary::{le_u8, le_u16, le_u32, le_u64}; +use winnow::combinator::{cond, cut_err, repeat}; +use winnow::stream::{Location, Offset, Stream}; +use winnow::token::{take, take_till}; + +use winnow::error::{ErrMode, ParserError, StrContext}; +use winnow::{ModalResult, Parser}; + +use arrayvec::ArrayVec; +use bitflags::bitflags; + +use crate::hdf::btree::tree; +use crate::hdf::data_object::DataLayout; +use crate::hdf::fractal_heap::Attribute; +use crate::hdf::gcol::gcol_read; + +use super::data_object::{ + AttributeInfo, DataFormat, DataObjectFlags, DataSpace, DataType, GroupInfo, LinkInfo, +}; +use super::helpers::varint_size; +use super::parser::Input; + +const VALID_HEADER_MESSAGE_FLAGS: u8 = 0b00000101; +const MAX_CONTINUATION_DEPTH: u32 = 25; + +bitflags! 
{ + #[derive(Clone, Copy, Debug)] + pub struct HeaderMessageFlags: u8 { + const MESSAGE_DATA_CONST = 0b00000001; + const MESSAGE_SHARED = 0b00000010; + const MESSAGE_NON_SHAREABLE = 0b00000100; + const INTERNAL_1 = 0b00001000; + const INTERNAL_2 = 0b00010000; + const INTERNAL_3 = 0b00100000; + const MESSAGE_SHAREABLE = 0b01000000; + const INTERNAL_4 = 0b10000000; + } +} + +#[derive(Clone, Debug)] +pub(crate) enum HeaderMessageKind { + Nil, + DataSpace(DataSpace), + LinkInfo(LinkInfo), + DataType(DataType), + DataFillOld, + DataFill, + DataLayout(Vec), + GroupInfo(GroupInfo), + FilterPipeline, + Attribute(Option), + Continue { offset: u64, length: u64 }, + AttributeInfo(AttributeInfo), +} + +#[derive(Clone, Debug)] +#[allow(dead_code)] +pub(crate) struct HeaderMessage { + pub kind: HeaderMessageKind, + pub size: u16, + pub flags: HeaderMessageFlags, +} + +pub(crate) fn collect_all_messages( + input: &mut Input, + end_of_messages: usize, + data_object_flags: DataObjectFlags, +) -> ModalResult> { + let mut all_messages = Vec::new(); + + while input.current_token_start() < end_of_messages - 4 { + let message = message_entry(data_object_flags).parse_next(input)?; + + match message.kind { + HeaderMessageKind::Continue { offset, length } => { + let continuation_messages = + message_continue(offset, length, data_object_flags).parse_next(input)?; + all_messages.extend(continuation_messages); + } + _ => { + all_messages.push(message); + } + } + } + + Ok(all_messages) +} + +fn message_entry( + data_object_flags: DataObjectFlags, +) -> impl FnMut(&mut Input) -> ModalResult { + move |input| { + let kind = le_u8.parse_next(input)?; + let size = le_u16.parse_next(input)? 
as usize; + + let flags = le_u8 + .verify_map(|flags| { + let valid_mask = HeaderMessageFlags::from_bits_truncate(VALID_HEADER_MESSAGE_FLAGS); + let flags = HeaderMessageFlags::from_bits_truncate(flags); + + // Reject if any bit outside the valid mask is set + flags.difference(valid_mask).is_empty().then_some(flags) + }) + .context(StrContext::Label("OHDR message")) + .context(StrContext::Expected("unsupported flags set".into())) + .parse_next(input)?; + + cond( + data_object_flags.contains(DataObjectFlags::ATTRIBUTE_CREATION_ORDER_TRACKED), + le_u16, + ) + .parse_next(input)?; + + let cp = input.checkpoint(); + let kind = message_kind(kind, size).parse_next(input)?; + + // Note: length mismatch can occur with unsupported data types + // We log a warning but don't panic - allows partial parsing + let consumed = input.offset_from(&cp); + if consumed != size { + log::warn!( + "OHDR message length mismatch: expected {}, consumed {}", + size, + consumed + ); + // Skip any remaining bytes to maintain alignment + if consumed < size { + let remaining = size - consumed; + take(remaining).parse_next(input)?; + } + } + + Ok(HeaderMessage { + kind, + flags, + size: size as _, + }) + } +} + +fn message_kind( + kind: u8, + header_size: usize, +) -> impl FnMut(&mut Input) -> ModalResult { + move |input| { + Ok(match kind { + 0 => { + message_nil(header_size).parse_next(input)?; + HeaderMessageKind::Nil + } + 1 => { + let ds = message_data_space.parse_next(input)?; + input.state.set_data_space(ds.clone()); + HeaderMessageKind::DataSpace(ds) + } + 2 => { + let li = message_link_info.parse_next(input)?; + HeaderMessageKind::LinkInfo(li) + } + 3 => { + let dt = message_data_type.parse_next(input)?; + HeaderMessageKind::DataType(dt) + } + 4 => { + message_data_fill_old.parse_next(input)?; + HeaderMessageKind::DataFillOld + } + 5 => { + message_data_fill.parse_next(input)?; + HeaderMessageKind::DataFill + } + 8 => { + let data = message_data_layout.parse_next(input)?; + 
HeaderMessageKind::DataLayout(data) + } + 10 => { + let gi = message_group_info.parse_next(input)?; + HeaderMessageKind::GroupInfo(gi) + } + 11 => { + message_filter_pipeline.parse_next(input)?; + HeaderMessageKind::FilterPipeline + } + 12 => { + let attr = message_attribute.parse_next(input)?; + HeaderMessageKind::Attribute(attr) + } + 16 => { + // Read continuation offset and length from message body + let size_of_offsets = input.state.size_of_offsets(); + let size_of_lengths = input.state.size_of_lengths(); + + let offset = varint_size(size_of_offsets) + .verify(|o| *o < 0x2000000) + .parse_next(input)?; + let length = varint_size(size_of_lengths) + .verify(|l| *l < 0x10000000) + .parse_next(input)?; + + HeaderMessageKind::Continue { offset, length } + } + 21 => { + let ai = message_attribute_info.parse_next(input)?; + HeaderMessageKind::AttributeInfo(ai) + } + _ => { + // Skip unknown message types - the caller has already read the + // message size, so we consume the remaining bytes and continue. 
+ log::warn!("Skipping unknown OHDR header message type {}", kind); + take(header_size).parse_next(input)?; + HeaderMessageKind::Nil + } + }) + } +} + +fn message_nil(skip_len: usize) -> impl FnMut(&mut Input) -> ModalResult<()> { + move |input| { + take(skip_len).parse_next(input)?; + Ok(()) + } +} + +fn message_data_space(input: &mut Input) -> ModalResult { + let version = le_u8 + .verify(|ver| matches!(ver, 1..=2)) + .context(StrContext::Label("Object OHDR dataspace message")) + .context(StrContext::Expected("1 or 2".into())) + .parse_next(input)?; + + let dimensionality = le_u8 + .verify(|d| *d <= 4) // Move this check here + .context(StrContext::Label("Object OHDR dataspace dimensionality")) + .context(StrContext::Expected("<= 4".into())) + .parse_next(input)?; + + let flags = le_u8.parse_next(input)?; // Remove the verify from here + + if version == 1 && flags & 2 != 0 { + return Err(ErrMode::assert( + input, + "Permutation in OHDR is not supported", + )); + } + + let kind = match version { + 1 => { + let _reserved = take(5usize).parse_next(input)?; + None + } + 2 => { + let kind = le_u8.parse_next(input)?; + Some(kind) + } + _ => unreachable!(), + }; + + let mut fill_arr = move |input: &mut Input| -> ModalResult> { + let size_of_lengths = input.state.size_of_lengths(); + let dims = dimensionality as usize; + + let data: Vec = repeat( + dims, + cut_err(varint_size(size_of_lengths).verify(|d| *d <= 1_000_000)), + ) + .context(StrContext::Label("Dimension Size")) + .context(StrContext::Expected("dimension size <= 1,000,000".into())) + .parse_next(input)?; + + let limited: ArrayVec = data.into_iter().take(std::cmp::min(dims, 4)).collect(); + + Ok(limited) + }; + + let dimension_size = fill_arr.parse_next(input)?; + let dimension_max_size = cond(flags & 1 != 0, fill_arr) + .parse_next(input)? 
+ .unwrap_or_default(); + + Ok(DataSpace { + dimension_size, + dimension_max_size, + dimensionality, + flags, + kind, + }) +} + +fn message_link_info(input: &mut Input) -> ModalResult { + let size_of_offsets = input.state.size_of_offsets(); + + let _version = le_u8 + .verify(|ver| *ver == 0) + .context(StrContext::Label("Object OHDR link info message version")) + .context(StrContext::Expected("0".into())) + .parse_next(input)?; + + let flags = le_u8.parse_next(input)?; + + let maximum_creation_index = cond(flags & 1 != 0, le_u64).parse_next(input)?; + let fractal_heap_address = varint_size(size_of_offsets).parse_next(input)?; + let address_btree_index = varint_size(size_of_offsets).parse_next(input)?; + let address_btree_order = + cond(flags & 2 != 0, varint_size(size_of_offsets)).parse_next(input)?; + + Ok(LinkInfo { + flags, + maximum_creation_index, + fractal_heap_address, + address_btree_index, + address_btree_order, + }) +} + +fn message_data_type(input: &mut Input) -> ModalResult { + let class_and_version = le_u8 + .verify(|cv| *cv & 0xf0 == 0x10 || *cv & 0xf0 == 0x30) + .context(StrContext::Label( + "Object OHDR data type message class and version", + )) + .context(StrContext::Expected("1".into())) + .parse_next(input)?; + + // Class bit field is 3 bytes (24 bits), read as le_u24 + let class_bit_field_bytes: &[u8] = take(3usize).parse_next(input)?; + let class_bit_field = u32::from_le_bytes([ + class_bit_field_bytes[0], + class_bit_field_bytes[1], + class_bit_field_bytes[2], + 0, + ]); + + let size = le_u32.verify(|s| *s < 64).parse_next(input)?; + + let data_fmt = match class_and_version & 0xf { + // int + 0 => { + let bit_offset = le_u16.parse_next(input)?; + let bit_precision = le_u16.parse_next(input)?; + let data_fmt = DataFormat::Fixed { + bit_offset, + bit_precision, + }; + + Some(data_fmt) + } + // float + 1 => { + let bit_offset = le_u16.parse_next(input)?; + let bit_precision = le_u16.parse_next(input)?; + let exponent_location = 
le_u8.parse_next(input)?; + let exponent_size = le_u8.parse_next(input)?; + let mantissa_location = le_u8.parse_next(input)?; + let mantissa_size = le_u8.parse_next(input)?; + let exponent_bias = le_u32.parse_next(input)?; + + if bit_offset != 0 + || mantissa_location != 0 + || (bit_precision != 32 && bit_precision != 64) + || (bit_precision == 32 + && (exponent_location != 23 + || exponent_size != 8 + || mantissa_size != 23 + || exponent_bias != 127)) + || (bit_precision == 64 + && (exponent_location != 52 + || exponent_size != 11 + || mantissa_size != 52 + || exponent_bias != 1023)) + { + log::warn!( + "Unsupported float format: bit_precision={}, exponent_location={}, exponent_size={}, mantissa_size={}, exponent_bias={}", + bit_precision, + exponent_location, + exponent_size, + mantissa_size, + exponent_bias + ); + // Return a cut error instead of panicking - allows caller to handle gracefully + return Err(ErrMode::Cut(winnow::error::ContextError::new())); + } + + let data_fmt = DataFormat::Float { + bit_offset, + bit_precision, + exponent_location, + exponent_size, + mantissa_location, + mantissa_size, + exponent_bias, + }; + + Some(data_fmt) + } + // string + 3 => None, + // compound + 6 => { + match class_and_version >> 4 { + 1 => { + for _ in 0..(class_bit_field & 0xffff) { + let cp = input.checkpoint(); + let _name = take_till(0..256, '\0').parse_next(input)?; + let _null_byte = le_u8.parse_next(input)?; + + let skip_bytes = (7 - input.offset_from(&cp)) & 7; + let _skip = take(skip_bytes).parse_next(input)?; + + let _c = le_u32.parse_next(input)?; + let _dimension = le_u32 + .verify(|d| *d == 0) + .context(StrContext::Label("Compound v1 dimension")) + .context(StrContext::Expected("0".into())) + .parse_next(input)?; + + // ignore the following fields + let skip_bytes = (3 + 4 + 4 + 4 * 4) as usize; + let _skip = take(skip_bytes).parse_next(input)?; + let _dt = message_data_type.parse_next(input)?; + } + } + 3 => { + for _ in 0..(class_bit_field & 0xffff) { + 
let _name = take_till(0..0x1000, '\0').parse_next(input)?; + let _null_byte = le_u8.parse_next(input)?; + + let mut j = 0u32; + let mut _c = 0u32; + + while size >> (8 * j) > 0 { + _c |= (le_u8.parse_next(input)? as u32) << (8 * j); + j += 1; + } + + let _dt = message_data_type.parse_next(input)?; + } + } + _t => { + return Err(ErrMode::assert( + input, + "object OHDR compound datatype message must have version 1 or 3", + )); + } + } + + None + } + // reference + 7 => None, + // list + 9 => None, + _t => { + return Err(ErrMode::assert( + input, + "object OHDR datatype message has unknown variable type", + )); + } + }; + + // Handle list + if class_and_version & 0xf == 9 { + let dt = message_data_type.parse_next(input)?; + let other_fmt = dt.data_fmt; + + Ok(DataType { + class_and_version, + class_bit_field, + size, + data_fmt: data_fmt.or(other_fmt), + list_size: Some(size), + }) + } else { + Ok(DataType { + class_and_version, + class_bit_field, + size, + data_fmt, + list_size: None, + }) + } +} + +fn message_data_fill_old(input: &mut Input) -> ModalResult<()> { + let size = le_u32.parse_next(input)?; + take(size).parse_next(input)?; + + Ok(()) +} + +fn message_data_fill(input: &mut Input) -> ModalResult<()> { + let version = le_u8 + .verify(|lc| matches!(lc, 1..=3)) + .context(StrContext::Label( + "Object OHDR message data storage fill version", + )) + .context(StrContext::Expected("1, 2 or 3".into())) + .parse_next(input)?; + + match version { + 1 | 2 => { + let _space_allocation_time = le_u8 + .verify(|t| *t < 128 && *t & 0xFE == 2) + .parse_next(input)?; + let _fill_value_write_time = le_u8.verify(|t| *t < 128 && *t == 2).parse_next(input)?; + let fill_value_defined = le_u8 + .verify(|t| *t < 128 && *t & 0xFE == 0) + .parse_next(input)?; + + if fill_value_defined > 0 { + let size = le_u32.parse_next(input)?; + let _skip = take(size).parse_next(input)?; + } + } + 3 => { + let flags = le_u8.parse_next(input)?; + if flags & (1 << 5) != 0 { + let size = 
le_u32.parse_next(input)?; + let _skip = take(size).parse_next(input)?; + } + } + _ => { + unreachable!(); + } + } + + Ok(()) +} + +fn message_data_layout(input: &mut Input) -> ModalResult<Vec<u8>> { + let size_of_offsets = input.state.size_of_offsets(); + let size_of_lengths = input.state.size_of_lengths(); + + let _version = le_u8 + .verify(|v| *v == 3) + .context(StrContext::Label("Object OHDR message data layout version")) + .context(StrContext::Expected("3".into())) + .parse_next(input)?; + + let layout_class = le_u8 + .verify(|lc| matches!(lc, 0..=2)) + .context(StrContext::Label("Object OHDR message data layout class")) + .context(StrContext::Expected("0, 1, or 2".into())) + .parse_next(input)?; + + let data = match layout_class { + 0 => { + let data_size = le_u16.parse_next(input)?; + let _skip = take(data_size).parse_next(input)?; + + log::info!("TODO layout 0, size: {data_size}"); + vec![] + } + 1 => { + let data_address = varint_size(size_of_offsets).parse_next(input)?; + let data_size = varint_size(size_of_lengths) + .verify(|sz| *sz < 0x1000_0000) + .context(StrContext::Label( + "Object OHDR message data layout, data size", + )) + .context(StrContext::Expected("< 0x10000000".into())) + .parse_next(input)?; + + log::info!("CHUNK Contiguous SIZE: {data_size}"); + + if input.state.is_address_valid(data_address) { + // Use absolute seek to avoid underflow when data_address < cur_pos + let cp = input.checkpoint(); + input.input.reset_to_start(); + let _skip = take(data_address as usize).parse_next(input)?; + let data = take(data_size as usize).parse_next(input)?; + + input.reset(&cp); + data.to_vec() + } else { + vec![] + } + } + 2 => { + let dimensionality = le_u8 + .verify(|d| matches!(d, 1..=5)) + .context(StrContext::Label( + "Object OHDR message data layout 2 dimensionality", + )) + .context(StrContext::Expected("1..=5".into())) + .parse_next(input)?; + + let data_address = varint_size(size_of_offsets).parse_next(input)?; + log::info!("Dimensionality: 
{dimensionality}"); + log::info!("CHUNK at address: {data_address:#X}"); + + let mut data_layout_chunk = DataLayout::new(); + + for _ in 0..dimensionality { + let item = le_u32.parse_next(input)?; + data_layout_chunk.push(item); + } + + let data_size = data_layout_chunk.last().copied().unwrap() as u64; + let Some(data_space) = input.state.data_space() else { + return Err(ErrMode::assert(input, "Data space is not available")); + }; + + // SAFETY, we check if dimensionality is non zero, so here we can + // safely assume we have at least one element. + let data_size = data_space + .dimension_size + .iter() + .fold(data_size, |acc, s| u64::saturating_mul(acc, *s)) + as usize; + + if data_size > 0x1000_0000 { + return Err(ErrMode::assert( + input, + "Object OHDR message data layout, data size too large", + )); + } + + // Note: The B-tree reader only supports data_space.dimensionality <= 3 + // The layout dimensionality can be 4 (includes element size) even for 3D data + if input.state.is_address_valid(data_address) && data_space.dimensionality <= 3 { + // Use absolute seek to avoid underflow when data_address < cur_pos + let cp = input.checkpoint(); + input.input.reset_to_start(); + let _skip = take(data_address as usize).parse_next(input)?; + let data = tree(data_size, data_space, data_layout_chunk).parse_next(input)?; + + input.reset(&cp); + data + } else { + vec![] + } + } + _ => { + unreachable!(); + } + }; + + Ok(data) +} + +fn message_group_info(input: &mut Input) -> ModalResult<GroupInfo> { + let _version = le_u8 + .verify(|v| *v == 0) + .context(StrContext::Label("Object OHDR group info version")) + .context(StrContext::Expected("0".into())) + .parse_next(input)?; + + let flags = le_u8.parse_next(input)?; + let values = cond(flags & 1 != 0, (le_u16, le_u16)).parse_next(input)?; + let entries = cond(flags & 2 != 0, (le_u16, le_u16)).parse_next(input)?; + + Ok(GroupInfo { + flags, + maximum_compact_value: values.map(|v| v.0), + minimum_dense_value: values.map(|v| v.1), + 
number_of_entries: entries.map(|v| v.0), + length_of_entries: entries.map(|v| v.1), + }) +} + +fn message_filter_pipeline(input: &mut Input) -> ModalResult<()> { + let version = le_u8 + .verify(|v| matches!(v, 1..=2)) + .context(StrContext::Label("Filter pipeline version")) + .context(StrContext::Expected("1 or 2".into())) + .parse_next(input)?; + + let filters = le_u8 + .verify(|f| *f < 32) + .context(StrContext::Label("Filters number")) + .context(StrContext::Expected("< 32".into())) + .parse_next(input)?; + + let mut filter_id = |input: &mut Input| -> ModalResult<u16> { + le_u16 + .context(StrContext::Label("Filter identification value")) + .parse_next(input) + }; + + match version { + 1 => { + let _reserved = take(6usize) + .verify(|s: &[u8]| !s.iter().any(|x| *x != 0)) + .context(StrContext::Label("Filters pipeline v1, reserved value")) + .context(StrContext::Expected("0".into())) + .parse_next(input)?; + + for _ in 0..filters { + let filter_id_value = filter_id.parse_next(input)?; + let name_len = le_u16.parse_next(input)?; + let _flags = le_u16.parse_next(input)?; + let values = le_u16.verify(|v| *v <= 0x1000).parse_next(input)?; + + log::info!(" filter {filter_id_value} namelen {name_len} values {values}"); + + let skip = ((name_len - 1) & !7) + 8; + take(skip).parse_next(input)?; + repeat(values as usize, le_u32) + .map(|()| ()) + .parse_next(input)?; + cond((values & 1) == 1, le_u32).parse_next(input)?; + } + } + 2 => { + for _ in 0..filters { + let filter_id_value = filter_id.parse_next(input)?; + let _flags = le_u16.parse_next(input)?; + let values = le_u16.verify(|v| *v <= 0x1000).parse_next(input)?; + + log::info!(" filter {filter_id_value}"); + repeat(values as usize, le_u32) + .map(|()| ()) + .parse_next(input)?; + } + } + _ => { + unreachable!(); + } + } + + Ok(()) +} + +/// Parse OHDR Attribute message (type 0x0C). +/// +/// Parses the attribute header and extracts string values from the data. +/// For non-string datatypes, returns None for the value. 
+fn message_attribute(input: &mut Input) -> ModalResult<Option<Attribute>> { + let version = le_u8 + .verify(|v| *v == 1 || *v == 3) + .context(StrContext::Label("Object OHDR attribute version")) + .context(StrContext::Expected("1 or 3".into())) + .parse_next(input)?; + + let _flags = le_u8.parse_next(input)?; + let name_size = le_u16.verify(|s| *s <= 0x1000).parse_next(input)?; + let _datatype_size = le_u16.parse_next(input)?; + let _dataspace_size = le_u16.parse_next(input)?; + let _encoding = cond(version == 3, le_u8).parse_next(input)?; + + // Read attribute name + let name_bytes: &[u8] = take(name_size).parse_next(input)?; + let name = String::from_utf8_lossy(name_bytes) + .trim_end_matches('\0') + .to_string(); + + // Version 1 pads name to 8-byte boundary + if version == 1 { + let padding = (8usize.saturating_sub(name_size as usize)) & 7; + if padding > 0 { + take(padding).parse_next(input)?; + } + } + + // Parse the datatype to determine the data class and size. + // We save a checkpoint in case parsing datatype/dataspace/data fails, + // so we can fall back to skipping the remaining bytes. 
+ let dt_cp = input.checkpoint(); + + let result: Option<String> = (|| -> Result<Option<String>, _> { + let dt = message_data_type.parse_next(input)?; + + // Version 1 pads datatype to 8-byte boundary + if version == 1 { + let dt_consumed = input.offset_from(&dt_cp); + let dt_padding = ((8usize.wrapping_sub(dt_consumed)) & 7) % 8; + if dt_padding > 0 { + take(dt_padding).parse_next(input)?; + } + } + + let ds_cp = input.checkpoint(); + let ds = message_data_space.parse_next(input)?; + + // Version 1 pads dataspace to 8-byte boundary + if version == 1 { + let ds_consumed = input.offset_from(&ds_cp); + let ds_padding = ((8usize.wrapping_sub(ds_consumed)) & 7) % 8; + if ds_padding > 0 { + take(ds_padding).parse_next(input)?; + } + } + + let num_elements: u64 = if ds.dimension_size.is_empty() { + 1 // scalar + } else { + ds.dimension_size.iter().product() + }; + + let data_class = dt.class_and_version & 0xf; + match data_class { + // class 3: fixed-length string + 3 if dt.size > 0 && dt.size <= 0x1000 => { + let total_size = dt.size as usize * num_elements as usize; + let val_bytes: &[u8] = take(total_size).parse_next(input)?; + Ok(Some( + String::from_utf8_lossy(val_bytes) + .trim_end_matches('\0') + .to_string(), + )) + } + // class 9: variable-length. The data contains a GCOL address + + // reference that points to the actual string. Read the GCOL + // collection and resolve the value. + 9 if dt.list_size.is_some() => { + let list_size = dt.list_size.unwrap() as usize; + + // Read GCOL address from the list prefix + let gcol_prefix_size = list_size - dt.size as usize; + let gcol_address = if gcol_prefix_size == 8 { + let _unknown = le_u32.parse_next(input)?; + varint_size(4u8).parse_next(input)? + } else if gcol_prefix_size > 0 { + varint_size(gcol_prefix_size as u8).parse_next(input)?
+ } else { + 0 + }; + + // Read the reference data: 4 bytes unknown + reference ID + let reference_size = dt.size as usize; + if reference_size >= 4 { + let _unknown_ref = le_u32.parse_next(input)?; + let reference = if reference_size > 4 { + varint_size((reference_size - 4) as u8).parse_next(input)? as u16 + } else { + 0 + }; + + // Seek to GCOL address and read the collection + if input.state.is_address_valid(gcol_address) { + let gcol_cp = input.checkpoint(); + input.input.reset_to_start(); + take(gcol_address as usize).parse_next(input)?; + + if let Ok(objects) = gcol_read(input) { + input.reset(&gcol_cp); + + // Find the object with matching reference + if let Some(obj) = + objects.iter().find(|o| o.heap_object_index == reference) + { + // The value is a data object address. Seek there + // and read the object name as the string value. + let obj_addr = obj.value; + if input.state.is_address_valid(obj_addr) { + let obj_cp = input.checkpoint(); + input.input.reset_to_start(); + take(obj_addr as usize).parse_next(input)?; + if let Ok(data_obj) = + super::data_object::data_object("ref").parse_next(input) + { + input.reset(&obj_cp); + return Ok(Some(data_obj.name)); + } + input.reset(&obj_cp); + } + } + } else { + input.reset(&gcol_cp); + } + } + } else { + take(reference_size).parse_next(input)?; + } + Ok(None) + } + _ => { + // For non-string types, read and discard the data + let element_size = dt.list_size.unwrap_or(dt.size) as usize; + let total_size = element_size * num_elements as usize; + if total_size > 0 && total_size <= 0x100000 { + take(total_size).parse_next(input)?; + } + Ok(None) + } + } + })() + .unwrap_or_else(|_: winnow::error::ErrMode<winnow::error::ContextError>| { + // If parsing failed, reset and let the message_entry skip handler deal with it + input.reset(&dt_cp); + None + }); + + Ok(Some(Attribute { + name, + value: result, + })) +} + +fn message_attribute_info(input: &mut Input) -> ModalResult<AttributeInfo> { + let size_of_offsets = input.state.size_of_offsets(); + + let _version = 
le_u8 + .verify(|v| *v == 0) + .context(StrContext::Label( + "Object OHDR attribute info message version", + )) + .context(StrContext::Expected("0".into())) + .parse_next(input)?; + + let flags = le_u8.parse_next(input)?; + + let maximum_creation_index = cond(flags & 1 != 0, le_u16) + .parse_next(input)? + .map(|v| v as u64) + .unwrap_or(0); + + let fractal_heap_address = varint_size(size_of_offsets).parse_next(input)?; + let attribute_name_btree = varint_size(size_of_offsets).parse_next(input)?; + + let attribute_creation_order_btree = cond(flags & 2 != 0, varint_size(size_of_offsets)) + .parse_next(input)? + .unwrap_or(0); + + Ok(AttributeInfo { + flags, + maximum_creation_index, + fractal_heap_address, + attribute_name_btree, + attribute_creation_order_btree, + }) +} + +fn message_continue( + offset: u64, + length: u64, + data_object_flags: DataObjectFlags, +) -> impl FnMut(&mut Input) -> ModalResult<Vec<HeaderMessage>> { + move |input| { + log::info!(" continue {offset:#x} {length:#x}"); + + if input.state.recursive_counter() >= MAX_CONTINUATION_DEPTH { + return Err(ErrMode::assert(input, "Recursive problem")); + } + + let cp = input.checkpoint(); + input.state.recursive_counter_inc(); + + // Seek to continuation chunk + input.input.reset_to_start(); + take(offset as usize).parse_next(input)?; + + let _ochk_signature = "OCHK".parse_next(input)?; + + let end_of_continuation = input.current_token_start() + length as usize - 4; + let mut continuation_messages = Vec::new(); + + while input.current_token_start() < end_of_continuation - 4 { + let message = message_entry(data_object_flags).parse_next(input)?; + + match message.kind { + HeaderMessageKind::Continue { + offset: nested_offset, + length: nested_length, + } => { + let nested_messages = + message_continue(nested_offset, nested_length, data_object_flags) + .parse_next(input)?; + continuation_messages.extend(nested_messages); + } + _ => { + continuation_messages.push(message); + } + } + } + + take(4usize).parse_next(input)?; + + 
// Restore position and decrement counter + input.reset(&cp); + input.state.recursive_counter_dec(); + + log::info!(" continue back"); + Ok(continuation_messages) + } +} diff --git a/src/hdf/parser.rs b/src/hdf/parser.rs new file mode 100644 index 0000000..04a12f8 --- /dev/null +++ b/src/hdf/parser.rs @@ -0,0 +1,184 @@ +use winnow::error::ParserError; +use winnow::stream::{LocatingSlice, Location, Stateful, Stream}; +use winnow::token::take; +use winnow::{error::ErrMode, prelude::*}; + +use super::data_object::{DataObject, DataSpace, data_object}; +use super::super_block::{SuperBlock, super_block}; + +pub(crate) type Input<'a> = Stateful<LocatingSlice<&'a [u8]>, State>; + +/// Context state that is available after parsing Super Block +#[derive(Debug, Clone)] +pub(crate) struct State { + size_of_lengths: u8, + size_of_offsets: u8, + end_of_file_address: u64, + recursive_counter: u32, + data_space: Option<DataSpace>, +} + +impl State { + pub fn new(block: &SuperBlock) -> Self { + Self { + size_of_lengths: block.size_of_lengths, + size_of_offsets: block.size_of_offsets, + end_of_file_address: block.end_of_file_address, + recursive_counter: 0, + data_space: None, + } + } + + pub fn size_of_lengths(&self) -> u8 { + self.size_of_lengths + } + + pub fn size_of_offsets(&self) -> u8 { + self.size_of_offsets + } + + #[allow(dead_code)] + pub fn end_of_file_address(&self) -> u64 { + self.end_of_file_address + } + + pub fn is_address_valid(&self, address: u64) -> bool { + address > 0 && address < self.end_of_file_address + } + + pub(crate) fn data_space(&self) -> Option<DataSpace> { + self.data_space.clone() + } + + pub(crate) fn set_data_space(&mut self, data_space: DataSpace) { + self.data_space = Some(data_space); + } + + pub(crate) fn recursive_counter(&self) -> u32 { + self.recursive_counter + } + + pub(crate) fn recursive_counter_inc(&mut self) { + self.recursive_counter = self.recursive_counter.saturating_add(1); + } + + pub(crate) fn recursive_counter_dec(&mut self) { + self.recursive_counter = 
self.recursive_counter.saturating_sub(1); + } +} + +/// Parsed HDF5 file with ability to navigate to child objects. +pub struct ParsedHdf<'a> { + data: &'a [u8], + state: State, + /// The root data object + pub root: DataObject, +} + +impl<'a> ParsedHdf<'a> { + /// Parse a child data object by address. + /// + /// Use addresses from `root.child_directories` to navigate the tree. + pub fn parse_child(&self, name: &str, address: u64) -> ModalResult<DataObject> { + if !self.state.is_address_valid(address) { + return Err(ErrMode::assert(&self.data, "Invalid child object address")); + } + + let input = LocatingSlice::new(self.data); + let mut stream = Input { + input, + state: self.state.clone(), + }; + + let _skip = take(address as usize).parse_next(&mut stream)?; + data_object(name).parse_next(&mut stream) + } + + /// Find a child by name and parse it. + pub fn get_child(&self, name: &str) -> Option<ModalResult<DataObject>> { + self.root + .child_directories + .iter() + .find(|d| d.name == name) + .map(|d| self.parse_child(&d.name, d.address)) + } +} + +/// Parse an HDF5/SOFA file and return a navigable structure. 
+pub fn parse_with_children(input: &[u8]) -> ModalResult<ParsedHdf<'_>> { + let mut slice = input; + let cp = slice.checkpoint(); + let super_block = super_block.parse_next(&mut slice)?; + + slice.reset(&cp); + + if super_block.end_of_file_address as usize != slice.eof_offset() { + return Err(ErrMode::assert(&slice, "File size mismatch")); + } + + let state = State::new(&super_block); + let locating = LocatingSlice::new(slice); + + let mut stream = Input { + input: locating, + state: state.clone(), + }; + + let _skip = + take(super_block.root_group_object_header_address as usize).parse_next(&mut stream)?; + let root = data_object("root").parse_next(&mut stream)?; + + Ok(ParsedHdf { + data: input, + state, + root, + }) +} + +pub fn parse(mut input: &[u8]) -> ModalResult<DataObject> { + let cp = input.checkpoint(); + let super_block = super_block.parse_next(&mut input)?; + + input.reset(&cp); + + if super_block.end_of_file_address as usize != input.eof_offset() { + log::error!( + "File size mismatch: header says {}, actual {}", + super_block.end_of_file_address, + input.eof_offset() + ); + return Err(ErrMode::assert(&input, "File size mismatch")); + } + + log::debug!( + "SuperBlock parsed: offsets={}, lengths={}, root_addr={:#x}", + super_block.size_of_offsets, + super_block.size_of_lengths, + super_block.root_group_object_header_address + ); + + let state = State::new(&super_block); + let input = LocatingSlice::new(input); + + let mut stream = Input { input, state }; + + // jump to the first object + let _skip = + take(super_block.root_group_object_header_address as usize).parse_next(&mut stream)?; + + log::debug!( + "About to parse root data_object at position {:#x}", + stream.input.current_token_start() + ); + + match data_object("root").parse_next(&mut stream) { + Ok(obj) => { + log::debug!("Root object parsed successfully: {}", obj.name); + Ok(obj) + } + Err(e) => { + log::error!("Failed to parse root object: {:?}", e); + Err(e) + } + } +} diff --git a/src/hdf/super_block.rs 
b/src/hdf/super_block.rs new file mode 100644 index 0000000..a7d09d2 --- /dev/null +++ b/src/hdf/super_block.rs @@ -0,0 +1,183 @@ +use winnow::binary::{le_u8, le_u16, le_u32}; +use winnow::error::StrContext; +use winnow::prelude::*; +use winnow::token::literal; + +use super::helpers::varint_size; + +/// Signature used to quickly identify a file as being an HDF5 file. +/// +/// ASCII C format: [\211, H, D, F, \r, \n, \032, \n] +const FORMAT_SIGNATURE: [u8; 8] = [0x89, 0x48, 0x44, 0x46, 0x0d, 0x0a, 0x1a, 0x0a]; + +/// Version number of Superblock +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +struct SuperBlockVersion(u8); + +/// Only present in versions 0 and 1 of the superblock +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +struct SuperBlockVersionExt { + /// Version number of File’s Free Space Storage + space_storage: u8, + /// Version number of Root Group Symbol Table Entry + root_group: u8, + /// Version number of Shared Header Message Format + shared_header: u8, +} + +#[derive(Clone, Copy, Debug)] +pub struct SuperBlock { + pub size_of_offsets: u8, + pub size_of_lengths: u8, + pub base_address: u64, + pub end_of_file_address: u64, + pub root_group_object_header_address: u64, + pub superblock_extension_address: Option<u64>, +} + +fn version(input: &mut &[u8]) -> ModalResult<SuperBlockVersion> { + le_u8 + .verify_map(|ver| match ver { + 0..=3 => Some(SuperBlockVersion(ver)), + _ => None, + }) + .context(StrContext::Label("Super block version")) + .context(StrContext::Expected("0, 1, 2 or 3".into())) + .parse_next(input) +} + +fn version_ext(input: &mut &[u8]) -> ModalResult<SuperBlockVersionExt> { + (le_u8, le_u8, le_u8, le_u8) + .map( + |(space_storage, root_group, _reserved, shared_header)| SuperBlockVersionExt { + space_storage, + root_group, + shared_header, + }, + ) + .parse_next(input) +} + +fn size_of_offsets(input: &mut &[u8]) -> ModalResult<u8> { + le_u8 + .verify(|sz| matches!(sz, 2..=8)) + .context(StrContext::Label("Size of Offsets")) + 
.context(StrContext::Expected("range 2 to 8".into())) + .parse_next(input) +} + +fn size_of_lengths(input: &mut &[u8]) -> ModalResult<u8> { + le_u8 + .verify(|sz| matches!(sz, 2..=8)) + .context(StrContext::Label("Size of Lengths")) + .context(StrContext::Expected("range 2 to 8".into())) + .parse_next(input) +} + +fn indexed_storage(ver: SuperBlockVersion) -> impl FnMut(&mut &[u8]) -> ModalResult<Option<u16>> { + move |input| match ver.0 { + 1 => { + let indexed_storage_internal_node_k = le_u16.parse_next(input)?; + let _reserved = le_u32.parse_next(input)?; + + Ok(Some(indexed_storage_internal_node_k)) + } + _ => Ok(None), + } +} + +fn super_block_ver_0_or_1( + ver: SuperBlockVersion, +) -> impl FnMut(&mut &[u8]) -> ModalResult<SuperBlock> { + move |input: &mut &[u8]| { + let _version_ext = version_ext.parse_next(input)?; + let size_of_offsets = size_of_offsets.parse_next(input)?; + let size_of_lengths = size_of_lengths.parse_next(input)?; + let _reserved = le_u8.verify(|r| *r == 0).parse_next(input)?; + let _group_leaf_node_k = le_u16.parse_next(input)?; + let _group_internal_node_k = le_u16.parse_next(input)?; + let _file_consistency_flags = le_u32.verify(|f| *f == 0).parse_next(input)?; + let _indexed_storage = indexed_storage(ver).parse_next(input)?; + let base_address = varint_size(size_of_offsets) + .verify(|a| *a == 0) + .parse_next(input)?; + let _address_of_file_free_space = varint_size(size_of_offsets).parse_next(input)?; + let end_of_file_address = varint_size(size_of_offsets).parse_next(input)?; + let _driver_info_block_address = varint_size(size_of_offsets).parse_next(input)?; + let _link_name_offset = varint_size(size_of_offsets).parse_next(input)?; + let root_group_object_header_address = varint_size(size_of_offsets).parse_next(input)?; + let _cache_type = le_u32.verify(|t| *t <= 2).parse_next(input)?; + + Ok(SuperBlock { + size_of_offsets, + size_of_lengths, + base_address, + end_of_file_address, + root_group_object_header_address, + superblock_extension_address: None, + }) + 
} +} + +fn super_block_ver_2_or_3(input: &mut &[u8]) -> ModalResult<SuperBlock> { + let size_of_offsets = size_of_offsets.parse_next(input)?; + let size_of_lengths = size_of_lengths.parse_next(input)?; + let _file_consistency_flags = le_u8.parse_next(input)?; + let base_address = varint_size(size_of_offsets) + .verify(|a| *a == 0) + .parse_next(input)?; + let super_block_extension_address = varint_size(size_of_offsets).parse_next(input)?; + let end_of_file_address = varint_size(size_of_offsets).parse_next(input)?; + let root_group_object_header_address = varint_size(size_of_offsets).parse_next(input)?; + + Ok(SuperBlock { + size_of_offsets, + size_of_lengths, + base_address, + end_of_file_address, + root_group_object_header_address, + superblock_extension_address: Some(super_block_extension_address), + }) +} + +pub(crate) fn super_block(input: &mut &[u8]) -> ModalResult<SuperBlock> { + let _signature = literal(FORMAT_SIGNATURE).parse_next(input)?; + let ver = version.parse_next(input)?; + + match ver.0 { + 0 | 1 => super_block_ver_0_or_1(ver).parse_next(input), + 2 | 3 => super_block_ver_2_or_3.parse_next(input), + _ => unreachable!(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_invalid_signature_rejected() { + let mut input: &[u8] = &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + assert!(super_block(&mut input).is_err()); + } + + #[test] + fn test_invalid_version_rejected() { + // Valid signature but invalid version (4) + let mut input: &[u8] = &[0x89, 0x48, 0x44, 0x46, 0x0d, 0x0a, 0x1a, 0x0a, 0x04]; + assert!(super_block(&mut input).is_err()); + } + + #[test] + fn test_size_of_offsets_valid_range() { + // Test that values outside 2-8 are rejected + let mut input: &[u8] = &[1]; // too small + assert!(size_of_offsets(&mut input).is_err()); + + let mut input: &[u8] = &[9]; // too large + assert!(size_of_offsets(&mut input).is_err()); + + let mut input: &[u8] = &[4]; // valid + assert!(size_of_offsets(&mut input).is_ok()); + } +} diff --git a/src/lib.rs 
b/src/lib.rs index d64ecbc..e14e05c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,15 @@ //! # Sofar //! -//! Sofa Reader and Renderer +//! Pure Rust SOFA Reader and HRTF Renderer //! -//! This crate provides high level bindings to [`libmysofa`] API allows to read -//! `HRTF` filters from `SOFA` files (Spatially Oriented Format for Acoustics). +//! This crate provides a pure Rust implementation for reading `HRTF` filters +//! from `SOFA` files (Spatially Oriented Format for Acoustics). //! //! The [`render`] module implements uniformly partitioned convolution algorithm //! for rendering HRTF filters. //! +//! Based on the [`libmysofa`] C library by Christian Hoene / Symonics GmbH. +//! //! [`libmysofa`]: https://github.com/hoene/libmysofa //! [`render`]: `crate::render` //! @@ -47,7 +49,11 @@ //! render.process_block(&input, &mut left, &mut right).unwrap(); //! ``` +pub mod filter; + +pub mod hdf; pub mod reader; +mod sofa; #[cfg(feature = "dsp")] pub mod render; @@ -81,4 +87,22 @@ mod tests { let mut filter = Filter::new(filt_len); sofa.filter(0.0, 1.0, 0.0, &mut filter); } + + #[test] + fn debug_tu_berlin() { + let cwd = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); + std::env::set_current_dir(&cwd).unwrap(); + + let path = "libmysofa-sys/libmysofa/tests/TU-Berlin_QU_KEMAR_anechoic_radius_0.5m.sofa"; + let data = std::fs::read(path).unwrap(); + + match hdf::parse(&data) { + Ok(obj) => { + assert_eq!(obj.name, "root"); + } + Err(e) => { + panic!("Parse failed: {:?}", e); + } + } + } } diff --git a/src/reader.rs b/src/reader.rs index 2bb3b90..e2a22d7 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,87 +1,98 @@ -//! This module provides high level bindings to [`libmysofa`] API allows to read -//! `HRTF` filters from `SOFA` files (Spatially Oriented Format for Acoustics). +//! SOFA file reader for HRTF data. //! -//! [`libmysofa`]: https://github.com/hoene/libmysofa +//! This module provides the primary API for reading `HRTF` filters from +//! 
`SOFA` files (Spatially Oriented Format for Acoustics). +//! +//! # Examples +//! +//! Open a SOFA file with default options: +//! +//! ```no_run +//! use sofar::reader::Sofar; +//! use sofar::reader::Filter; +//! +//! let sofa = Sofar::open("path/to/file.sofa").unwrap(); +//! let filt_len = sofa.filter_len(); +//! +//! let mut filter = Filter::new(filt_len); +//! sofa.filter(0.0, 1.0, 0.0, &mut filter); +//! ``` +//! +//! Open with custom options: +//! +//! ```no_run +//! use sofar::reader::OpenOptions; +//! +//! let sofa = OpenOptions::new() +//! .sample_rate(44100.0) +//! .open("path/to/file.sofa") +//! .unwrap(); +//! ``` +//! +//! Open from in-memory bytes: +//! +//! ```no_run +//! use sofar::reader::Sofar; +//! +//! let data = std::fs::read("path/to/file.sofa").unwrap(); +//! let sofa = Sofar::open_data(&data).unwrap(); +//! ``` -use std::{ffi::CString, io, path::Path}; +use std::path::Path; -const DEFAULT_CACHED: bool = false; -const DEFAULT_NORMALIZED: bool = true; +use crate::sofa::{ + Hrtf, InterpolatedFilter, Lookup, Neighborhood, get_filter_nointerp, interpolate, normalize, + resample, validate, +}; +const DEFAULT_NORMALIZED: bool = true; const DEFAULT_SAMPLE_RATE: f32 = 48000.0; - -const DEFAULT_NEIGHBOR_ANGLE_STEP: f32 = ffi::MYSOFA_DEFAULT_NEIGH_STEP_ANGLE as f32; -const DEFAULT_NEIGHBOR_RADIUS_STEP: f32 = ffi::MYSOFA_DEFAULT_NEIGH_STEP_RADIUS as f32; +const DEFAULT_NEIGHBOR_ANGLE_STEP: f32 = 0.5; +const DEFAULT_NEIGHBOR_RADIUS_STEP: f32 = 0.01; #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("IO error")] + #[error("IO error: {0}")] Io(#[from] std::io::Error), - #[error("The owls are not what they seem")] - InternalError, - #[error("Invalid data format")] + #[error("Parse error: {0}")] + Parse(#[from] crate::sofa::Error), + #[error("Invalid format")] InvalidFormat, - #[error("Format is not supported")] - UnsupportedFormat, - #[error("Invalid attributes")] - InvalidAttributes, - #[error("Invalid dimensions")] - InvalidDimensions, - 
#[error("Invalid dimension list")] - InvalidDimensionList, - #[error("Invalid coordinate type")] - InvalidCoordinateType, - #[error("Invalid receiver position")] - InvalidReceiverPositions, - #[error("Emitters without ECI are not supported")] - OnlyEmitterWithEciSupported, - #[error("Delays without IR or MR are not supported")] - OnlyDelaysWithIrOrMrSupported, - #[error("Sources without MC are not supported")] - OnlySourcesWithMcSupported, - #[error("Sampling rates differ")] - OnlyTheSameSamplingRateSupported, -} - -impl Error { - pub(crate) fn from_raw(err: i32) -> Error { - use Error::*; - - match err { - ffi::MYSOFA_INVALID_FORMAT => InvalidFormat, - ffi::MYSOFA_UNSUPPORTED_FORMAT => UnsupportedFormat, - ffi::MYSOFA_INVALID_ATTRIBUTES => InvalidAttributes, - ffi::MYSOFA_INVALID_DIMENSIONS => InvalidDimensions, - ffi::MYSOFA_INVALID_DIMENSION_LIST => InvalidDimensionList, - ffi::MYSOFA_INVALID_COORDINATE_TYPE => InvalidCoordinateType, - ffi::MYSOFA_INVALID_RECEIVER_POSITIONS => InvalidReceiverPositions, - ffi::MYSOFA_ONLY_EMITTER_WITH_ECI_SUPPORTED => OnlyEmitterWithEciSupported, - ffi::MYSOFA_ONLY_DELAYS_WITH_IR_OR_MR_SUPPORTED => OnlyDelaysWithIrOrMrSupported, - ffi::MYSOFA_ONLY_SOURCES_WITH_MC_SUPPORTED => OnlySourcesWithMcSupported, - ffi::MYSOFA_ONLY_THE_SAME_SAMPLING_RATE_SUPPORTED => OnlyTheSameSamplingRateSupported, - ffi::MYSOFA_READ_ERROR => Io(io::Error::new( - io::ErrorKind::NotFound, - "Unable to read from file", - )), - ffi::MYSOFA_NO_MEMORY => Io(io::Error::new( - io::ErrorKind::OutOfMemory, - "Ran out of memory", - )), - _ => Error::InternalError, - } - } + #[error("Failed to build spatial lookup")] + LookupBuildFailed, + #[error("Resampling failed: {0}")] + ResampleFailed(String), } +/// Options for opening SOFA files. #[derive(Clone, Debug)] pub struct OpenOptions { sample_rate: f32, neighbor_angle_step: f32, neighbor_radius_step: f32, - cached: bool, normalized: bool, } impl OpenOptions { + /// Create a new set of open options with defaults. 
+ /// + /// Default values: + /// - `sample_rate`: 48000.0 + /// - `neighbor_angle_step`: 0.5° + /// - `neighbor_radius_step`: 0.01m + /// - `normalized`: true + /// + /// # Example + /// + /// ```no_run + /// use sofar::reader::OpenOptions; + /// + /// let sofa = OpenOptions::new() + /// .sample_rate(44100.0) + /// .normalized(true) + /// .open("path/to/file.sofa") + /// .unwrap(); + /// ``` pub fn new() -> Self { Default::default() } @@ -95,123 +106,93 @@ impl OpenOptions { } /// Neighbor search angle step measured in degrees. Default value is 0.5. - /// - /// The higher the value the faster search algorithm. The tradeoff - /// is accuracy: higher values will more likely miss a true nearest - /// neighbors. pub fn neighbor_angle_step(&mut self, neighbor_angle_step: f32) -> &mut Self { self.neighbor_angle_step = neighbor_angle_step; self } /// Neighbor search radius step measured in meters. Default value is 0.01. - /// - /// The higher the value the faster search algorithm. The tradeoff - /// is accuracy: higher values will more likely miss a true nearest - /// neighbors. pub fn neighbor_radius_step(&mut self, neighbor_radius_step: f32) -> &mut Self { self.neighbor_radius_step = neighbor_radius_step; self } - /// Using this option tells library to share memory for the files with the - /// same name and sampling rate. - pub fn cached(&mut self, cached: bool) -> &mut Self { - self.cached = cached; - self - } - /// Apply normalization upon opening a SOFA file. Default value is `true` pub fn normalized(&mut self, normalized: bool) -> &mut Self { self.normalized = normalized; self } - /// Open a SOFA file at `path` with open options specified in `self` + /// Open a SOFA file at `path` with open options specified in `self`. 
+ /// + /// # Example /// /// ```no_run /// use sofar::reader::OpenOptions; /// /// let sofa = OpenOptions::new() - /// .normalized(false) /// .sample_rate(44100.0) - /// .open("my/sofa/file.sofa") + /// .open("path/to/file.sofa") /// .unwrap(); /// ``` pub fn open>(&self, path: P) -> Result { - let path = cstr(path.as_ref())?; - let mut filter_len = 0; - let mut err = 0; - - let raw = unsafe { - match self.cached { - true => ffi::mysofa_open_cached( - path.as_ptr(), - self.sample_rate, - &mut filter_len, - &mut err, - ), - false => ffi::mysofa_open_advanced( - path.as_ptr(), - self.sample_rate, - &mut filter_len, - &mut err, - self.normalized, - self.neighbor_angle_step, - self.neighbor_radius_step, - ), - } - }; - - if raw.is_null() || err != ffi::MYSOFA_OK { - return Err(Error::from_raw(err)); - } - - Ok(Sofar { - raw, - filter_len: filter_len as usize, - cached: self.cached, - }) + let data = std::fs::read(path)?; + self.open_data(&data) } - /// Open a SOFA using provided bytes and open options specified in `self` + /// Open a SOFA file from in-memory bytes with open options specified in `self`. 
+ /// + /// # Example /// /// ```no_run /// use sofar::reader::OpenOptions; /// - /// let data: Vec = std::fs::read("my/sofa/file.sofa").unwrap(); - /// + /// let data = std::fs::read("path/to/file.sofa").unwrap(); /// let sofa = OpenOptions::new() - /// .normalized(false) - /// .sample_rate(44100.0) /// .open_data(&data) /// .unwrap(); /// ``` pub fn open_data>(&self, bytes: B) -> Result { - let mut filter_len = 0; - let mut err = 0; - - let raw = unsafe { - ffi::mysofa_open_data_advanced( - bytes.as_ref().as_ptr() as _, - bytes.as_ref().len() as _, - self.sample_rate, - &mut filter_len, - &mut err, - self.normalized, - self.neighbor_angle_step, - self.neighbor_radius_step, - ) - }; - - if raw.is_null() || err != ffi::MYSOFA_OK { - return Err(Error::from_raw(err)); + let mut hrtf = Hrtf::from_bytes(bytes.as_ref())?; + + // Convert all position arrays to cartesian coordinates. + // SOFA files may store positions in spherical coordinates. + hrtf.convert_to_cartesian(); + + // Validate (log warnings but don't fail) + if let Err(e) = validate(&hrtf) { + log::warn!("SOFA validation: {}", e); + } + + // Resample if needed + let original_rate = hrtf.sample_rate(); + if (original_rate - self.sample_rate).abs() > 0.1 { + resample(&mut hrtf, self.sample_rate).map_err(Error::ResampleFailed)?; } + // Normalize if requested + if self.normalized { + let _ = normalize(&mut hrtf); + } + + // Build spatial lookup + let lookup = Lookup::new(&hrtf).ok_or(Error::LookupBuildFailed)?; + + // Build neighborhood for interpolation + let neighborhood = Neighborhood::new( + &hrtf, + &lookup, + self.neighbor_angle_step, + self.neighbor_radius_step, + ); + + let filter_len = hrtf.filter_len(); + Ok(Sofar { - raw, - filter_len: filter_len as usize, - cached: false, + hrtf, + lookup, + neighborhood, + filter_len, }) } } @@ -222,158 +203,158 @@ impl Default for OpenOptions { sample_rate: DEFAULT_SAMPLE_RATE, neighbor_angle_step: DEFAULT_NEIGHBOR_ANGLE_STEP, neighbor_radius_step: 
DEFAULT_NEIGHBOR_RADIUS_STEP, - cached: DEFAULT_CACHED, normalized: DEFAULT_NORMALIZED, } } } -pub struct Filter { - /// Impulse Response of FIR filter for left channel - pub left: Box<[f32]>, - /// Impulse Response of FIR filter for right channel - pub right: Box<[f32]>, - /// The amount of time in seconds that left channel should be delayed for - pub ldelay: f32, - /// The amount of time in seconds that right channel should be delayed for - pub rdelay: f32, -} - -impl Filter { - pub fn new(filt_len: usize) -> Self { - Self { - left: vec![0.0; filt_len].into_boxed_slice(), - right: vec![0.0; filt_len].into_boxed_slice(), - ldelay: 0.0, - rdelay: 0.0, - } - } -} +pub use crate::filter::Filter; +/// SOFA reader providing access to HRTF filter data. +/// +/// Wraps parsed HRTF data with spatial lookup and neighbor interpolation. +/// Use [`OpenOptions`] for fine-grained control, or the convenience methods +/// [`Sofar::open`] and [`Sofar::open_data`]. pub struct Sofar { - raw: *mut ffi::MYSOFA_EASY, + hrtf: Hrtf, + lookup: Lookup, + neighborhood: Neighborhood, filter_len: usize, - cached: bool, } impl Sofar { - /// Open a SOFA file with the default open options + /// Open a SOFA file with the default open options. + /// + /// # Example /// /// ```no_run /// use sofar::reader::Sofar; /// - /// let sofa = Sofar::open("my/sofa/file.sofa").unwrap(); + /// let sofa = Sofar::open("path/to/file.sofa").unwrap(); + /// println!("Filter length: {}", sofa.filter_len()); /// ``` - pub fn open>(path: P) -> Result { + pub fn open>(path: P) -> Result { OpenOptions::new().open(path) } - /// Open a SOFA using provided bytes and the default open options + /// Open a SOFA file from in-memory bytes with the default open options. 
+ /// + /// # Example /// /// ```no_run /// use sofar::reader::Sofar; /// - /// let data: Vec = std::fs::read("my/sofa/file.sofa").unwrap(); + /// let data = std::fs::read("path/to/file.sofa").unwrap(); /// let sofa = Sofar::open_data(&data).unwrap(); /// ``` pub fn open_data>(bytes: B) -> Result { OpenOptions::new().open_data(bytes) } + /// Get the filter length (number of IR taps per channel). pub fn filter_len(&self) -> usize { self.filter_len } - /// Get HRTF filter for a given position + /// Get the HRTF filter for a given cartesian position using interpolation. /// - /// To produce a stereo output for a given position a source should be - /// delayed by left and right delay and FIR filtered by left and right - /// impulse response. + /// Uses inverse distance weighting to combine the nearest measurement with + /// up to 6 directional neighbors for smooth spatial transitions. /// - /// ```no_run - /// use sofar::reader::{Sofar, Filter}; + /// # Arguments /// - /// let sofa = Sofar::open("my/sofa/file.sofa").unwrap(); - /// let filt_len = sofa.filter_len(); + /// * `x` - Forward/backward position in meters + /// * `y` - Left/right position in meters + /// * `z` - Up/down position in meters + /// * `filter` - Output filter to fill with interpolated IR data /// - /// let mut filter = Filter::new(filt_len); + /// # Example /// - /// sofa.filter(0.0, 1.0, 0.0, &mut filter); - /// ``` + /// ```no_run + /// use sofar::reader::{Sofar, Filter}; /// - /// # Panics + /// let sofa = Sofar::open("path/to/file.sofa").unwrap(); + /// let mut filter = Filter::new(sofa.filter_len()); /// - /// This method panics if: - /// - `filter.left.len() < self.filter_len` - /// - `filter.right.len() < self.filter_len` + /// // Get filter for a source 1 meter in front + /// sofa.filter(1.0, 0.0, 0.0, &mut filter); + /// ``` pub fn filter(&self, x: f32, y: f32, z: f32, filter: &mut Filter) { - assert!(filter.left.len() >= self.filter_len); - assert!(filter.right.len() >= self.filter_len); 
- - unsafe { - ffi::mysofa_getfilter_float( - self.raw, - x, - y, - z, - filter.left.as_mut_ptr(), - filter.right.as_mut_ptr(), - &mut filter.ldelay, - &mut filter.rdelay, - ); + let position = [x, y, z]; + + if let Some(nearest_idx) = self.lookup.find(&position) + && let Some(interp) = + interpolate(&self.hrtf, &self.neighborhood, nearest_idx, &position) + { + self.fill_filter(filter, &interp); + return; } + + // Fallback: zero filter + filter.left.iter_mut().for_each(|s| *s = 0.0); + filter.right.iter_mut().for_each(|s| *s = 0.0); + filter.ldelay = 0.0; + filter.rdelay = 0.0; } - /// Get HRTF filter for a given position with no interpolation - /// - /// Similar to [`filter`](crate::reader::Filter) method but it will skip the linear - /// interpolation and return the filter for the nearest position instead. + /// Get the HRTF filter for a given position without interpolation. /// - /// # Panics - /// - /// This method panics if: - /// - `filter.left.len() < self.filter_len` - /// - `filter.right.len() < self.filter_len` + /// Returns the nearest measurement without blending with neighbors. + /// Faster than [`filter`](Sofar::filter) but may produce audible + /// discontinuities when the source position changes. 
pub fn filter_nointerp(&self, x: f32, y: f32, z: f32, filter: &mut Filter) { - assert!(filter.left.len() >= self.filter_len); - assert!(filter.right.len() >= self.filter_len); - - unsafe { - ffi::mysofa_getfilter_float_nointerp( - self.raw, - x, - y, - z, - filter.left.as_mut_ptr(), - filter.right.as_mut_ptr(), - &mut filter.ldelay, - &mut filter.rdelay, - ); + let position = [x, y, z]; + + if let Some(nearest_idx) = self.lookup.find(&position) + && let Some(interp) = get_filter_nointerp(&self.hrtf, nearest_idx) + { + self.fill_filter(filter, &interp); + return; } + + // Fallback: zero filter + filter.left.iter_mut().for_each(|s| *s = 0.0); + filter.right.iter_mut().for_each(|s| *s = 0.0); + filter.ldelay = 0.0; + filter.rdelay = 0.0; } -} -impl Drop for Sofar { - fn drop(&mut self) { - unsafe { - match self.cached { - true => ffi::mysofa_close_cached(self.raw), - false => ffi::mysofa_close(self.raw), - } + fn fill_filter(&self, filter: &mut Filter, interp: &InterpolatedFilter) { + let copy_len = interp.left.len().min(filter.left.len()); + filter.left[..copy_len].copy_from_slice(&interp.left[..copy_len]); + if copy_len < filter.left.len() { + filter.left[copy_len..].iter_mut().for_each(|s| *s = 0.0); } + + let copy_len = interp.right.len().min(filter.right.len()); + filter.right[..copy_len].copy_from_slice(&interp.right[..copy_len]); + if copy_len < filter.right.len() { + filter.right[copy_len..].iter_mut().for_each(|s| *s = 0.0); + } + + // Convert delay from samples to seconds + let sample_rate = self.hrtf.sample_rate(); + filter.ldelay = interp.delay_left / sample_rate; + filter.rdelay = interp.delay_right / sample_rate; } -} -unsafe impl Send for Sofar {} -unsafe impl Sync for Sofar {} + /// Get access to the underlying HRTF data. + pub fn hrtf(&self) -> &Hrtf { + &self.hrtf + } -#[cfg(unix)] -fn cstr(path: &Path) -> std::io::Result { - use std::os::unix::ffi::OsStrExt; - Ok(CString::new(path.as_os_str().as_bytes())?) -} + /// Get access to the spatial lookup. 
+ pub fn lookup(&self) -> &Lookup { + &self.lookup + } + + /// Get the sample rate. + pub fn sample_rate(&self) -> f32 { + self.hrtf.sample_rate() + } -#[cfg(windows)] -fn cstr(path: &Path) -> std::io::Result { - Ok(CString::new(path.as_os_str().to_str().unwrap().as_bytes())?) + /// Get the number of measurements. + pub fn num_measurements(&self) -> u32 { + self.hrtf.m() + } } diff --git a/src/render.rs b/src/render.rs index 146e037..0d6e406 100644 --- a/src/render.rs +++ b/src/render.rs @@ -8,7 +8,7 @@ use std::sync::Arc; -use crate::reader::Filter; +use crate::filter::Filter; use realfft::num_complex::Complex; use realfft::num_traits::Zero; @@ -48,12 +48,12 @@ impl Delay { } fn set_delay(&mut self, delay: usize) { - let n = self.buf.len(); - - if delay >= n { + if delay >= self.buf.len() { self.buf.resize(delay + 1, 0.0) } + let n = self.buf.len(); + if self.wpos >= delay { self.rpos = self.wpos - delay; } else { @@ -88,27 +88,15 @@ impl Delay { struct Channel { /// impulse response split into partition blocks h: Box<[Complex]>, - /// input blocks frequency domain delay line - x_fdl: Box<[Complex]>, - /// input blocks time domain delay line - x_tdl: Box<[f32]>, /// left channel delay state delay: Option, } impl Channel { - fn new( - fft_len: usize, - spectra_len: usize, - partitions: usize, - sample_rate: f32, - delay: Option, - ) -> Self { + fn new(spectra_len: usize, partitions: usize, sample_rate: f32, delay: Option) -> Self { let zero = Complex::new(0.0, 0.0); let h = vec![zero; spectra_len * partitions].into_boxed_slice(); - let x_fdl = vec![zero; spectra_len * partitions].into_boxed_slice(); - let x_tdl = vec![0.0; fft_len].into_boxed_slice(); let delay = delay.and_then(|delay| { if delay > 0.0 { @@ -118,12 +106,7 @@ impl Channel { } }); - Channel { - h, - x_tdl, - x_fdl, - delay, - } + Channel { h, delay } } fn delay(&mut self, mut buf: O) @@ -136,10 +119,10 @@ impl Channel { } fn update_delay(&mut self, new_delay: usize) { - if let Some(delay) = 
self.delay.as_mut() { - if new_delay != delay.delay { - delay.set_delay(new_delay) - } + if let Some(delay) = self.delay.as_mut() + && new_delay != delay.delay + { + delay.set_delay(new_delay) } } @@ -147,9 +130,6 @@ impl Channel { if let Some(delay) = self.delay.as_mut() { delay.reset(); } - - self.x_tdl.fill(0.0); - self.x_fdl.fill(Complex::zero()); } } @@ -209,7 +189,7 @@ impl RendererBuilder { false => return Err(Error::InvalidSampleRate(self.sample_rate)), }; - let partitions = (self.filter_len + self.partition_len - 1) / self.partition_len; + let partitions = self.filter_len.div_ceil(self.partition_len); let fft_len = self.partition_len * 2; let spectra_len = fft_len / 2 + 1; @@ -219,6 +199,10 @@ impl RendererBuilder { let filt_pad = vec![0.0; fft_len].into_boxed_slice(); let acc = vec![zero; spectra_len].into_boxed_slice(); + // Shared input state (computed once per block, used by both channels) + let x_tdl = vec![0.0; fft_len].into_boxed_slice(); + let x_fdl = vec![zero; spectra_len * partitions].into_boxed_slice(); + let mut planner = RealFftPlanner::::new(); let rfft = planner.plan_fft_forward(fft_len); let ifft = planner.plan_fft_inverse(fft_len); @@ -226,26 +210,15 @@ impl RendererBuilder { let rfft_scratch = rfft.make_scratch_vec(); let ifft_scratch = ifft.make_scratch_vec(); - let left = Channel::new( - fft_len, - spectra_len, - partitions, - sample_rate, - self.left_delay, - ); - let right = Channel::new( - fft_len, - spectra_len, - partitions, - sample_rate, - self.right_delay, - ); + let left = Channel::new(spectra_len, partitions, sample_rate, self.left_delay); + let right = Channel::new(spectra_len, partitions, sample_rate, self.right_delay); let state = State { acc, rfft, ifft, fft_len, + inv_scale: 1.0 / fft_len as f32, scratch, filt_pad, rfft_scratch, @@ -254,6 +227,9 @@ impl RendererBuilder { sample_rate: self.sample_rate, filter_len: self.filter_len, partition_len: self.partition_len, + x_tdl, + x_fdl, + fdl_head: 0, }; Ok(Renderer { left, 
right, state }) @@ -330,10 +306,24 @@ impl Renderer { )); } - self.state - .conv(&mut self.left, input.as_ref(), left.as_mut())?; - self.state - .conv(&mut self.right, input.as_ref(), right.as_mut())?; + let x = input.as_ref(); + let left_out = left.as_mut(); + let right_out = right.as_mut(); + let block_len = self.state.partition_len; + + let mut off = 0; + while off < x.len() { + // Prepare shared input FFT (once per block, shared by both channels) + self.state.prepare_input(&x[off..off + block_len])?; + + // Apply per-channel filter and produce output + self.state + .apply_filter(&self.left.h, &mut left_out[off..off + block_len])?; + self.state + .apply_filter(&self.right.h, &mut right_out[off..off + block_len])?; + + off += block_len; + } self.left.delay(left.as_mut()); self.right.delay(right.as_mut()); @@ -345,6 +335,9 @@ impl Renderer { pub fn reset(&mut self) { self.left.reset(); self.right.reset(); + self.state.x_tdl.fill(0.0); + self.state.x_fdl.fill(Complex::zero()); + self.state.fdl_head = 0; } } @@ -360,13 +353,15 @@ struct State { partitions: usize, /// FFT size fft_len: usize, + /// Precomputed 1.0 / fft_len for output scaling + inv_scale: f32, /// Real FFT module rfft: Arc>, /// Inverse FFT module ifft: Arc>, /// RFFT scratch memory rfft_scratch: Vec>, - /// RFFT scratch memory + /// IFFT scratch memory ifft_scratch: Vec>, /// mutable internal scratch for fft input scratch: Box<[f32]>, @@ -374,74 +369,79 @@ struct State { filt_pad: Box<[f32]>, /// accumulator for point wise multiplication acc: Box<[Complex]>, + /// shared input time-domain delay line (used by both channels) + x_tdl: Box<[f32]>, + /// shared input frequency-domain delay line (ring buffer) + x_fdl: Box<[Complex]>, + /// ring buffer head index for x_fdl (points to newest slot) + fdl_head: usize, } impl State { - fn conv(&mut self, channel: &mut Channel, x: I, mut y: O) -> Result<(), Error> - where - I: AsRef<[f32]>, - O: AsMut<[f32]>, - { - let x = x.as_ref(); - let y = y.as_mut(); - + 
/// Prepare the shared input block: update time-domain delay line, compute + /// forward FFT, and store the result in the frequency-domain ring buffer. + /// Called once per block (shared by both channels). + fn prepare_input(&mut self, block: &[f32]) -> Result<(), Error> { let spectra_len = self.fft_len / 2 + 1; let block_len = self.partition_len; - let scale = self.fft_len as f32; - let mut off = 0; + // Shift left part of TDL and store new data in right part + self.x_tdl.copy_within(block_len.., 0); + self.x_tdl[block_len..].copy_from_slice(block); - while off < x.len() { - // shift right part of the buffer to the left - channel.x_tdl.copy_within(block_len.., 0); - // store new data in right part - channel.x_tdl[block_len..].copy_from_slice(&x[off..off + block_len]); - // shift up the fdl content by one slot - channel.x_fdl.rotate_right(spectra_len); - // move data to processing scratch - self.scratch.copy_from_slice(&channel.x_tdl); - // take real to complex fft of input block and store it in the first fdl slot - self.rfft.process_with_scratch( - &mut self.scratch, - &mut channel.x_fdl[..spectra_len], - &mut self.rfft_scratch, - )?; + // Advance ring buffer head (wrapping) + if self.fdl_head == 0 { + self.fdl_head = self.partitions - 1; + } else { + self.fdl_head -= 1; + } - // point wise multiply with filter and accumulate the results - let mut p_off = 0; - self.acc.fill(Complex::new(0.0, 0.0)); - - for _ in 0..self.partitions { - for (acc, (x, h)) in Iterator::zip( - self.acc.iter_mut(), - Iterator::zip( - channel.x_fdl[p_off..p_off + spectra_len].iter(), - channel.h[p_off..p_off + spectra_len].iter(), - ), - ) { - *acc += x * h; - } - - p_off += spectra_len; - } + // Copy TDL to scratch and compute forward FFT into the new head slot + self.scratch.copy_from_slice(&self.x_tdl); + let head_start = self.fdl_head * spectra_len; + self.rfft.process_with_scratch( + &mut self.scratch, + &mut self.x_fdl[head_start..head_start + spectra_len], + &mut 
self.rfft_scratch, + )?; - // take complex to real transform - self.ifft.process_with_scratch( - &mut self.acc, - &mut self.scratch, - &mut self.ifft_scratch, - )?; + Ok(()) + } + + /// Apply a channel's filter to the shared input FDL and produce output. + /// Uses the ring buffer to access input spectra without physical rotation. + fn apply_filter(&mut self, h: &[Complex], y: &mut [f32]) -> Result<(), Error> { + let spectra_len = self.fft_len / 2 + 1; + let block_len = self.partition_len; - // discard the left part and write the right part as the next output block - for (y, x) in Iterator::zip( - y[off..off + block_len].iter_mut(), - self.scratch[block_len..].iter(), + // Point-wise multiply with filter and accumulate + self.acc.fill(Complex::new(0.0, 0.0)); + + for p in 0..self.partitions { + // Map logical partition p to physical FDL slot via ring buffer + let fdl_idx = (self.fdl_head + p) % self.partitions; + let fdl_off = fdl_idx * spectra_len; + let h_off = p * spectra_len; + + for (acc, (x, h)) in Iterator::zip( + self.acc.iter_mut(), + Iterator::zip( + self.x_fdl[fdl_off..fdl_off + spectra_len].iter(), + h[h_off..h_off + spectra_len].iter(), + ), ) { - *y = x / scale; + *acc += x * h; } + } - // update offset - off += block_len; + // Inverse FFT + self.ifft + .process_with_scratch(&mut self.acc, &mut self.scratch, &mut self.ifft_scratch)?; + + // Write output (second half), scaling by 1/N + let inv_scale = self.inv_scale; + for (y, x) in Iterator::zip(y[..block_len].iter_mut(), self.scratch[block_len..].iter()) { + *y = x * inv_scale; } Ok(()) @@ -458,6 +458,7 @@ impl State { for partition in iter.by_ref() { self.filt_pad[..block_len].copy_from_slice(partition); + self.filt_pad[block_len..].fill(0.0); self.rfft.process_with_scratch( &mut self.filt_pad, @@ -636,8 +637,8 @@ mod tests { expected.rotate_right(42); - for i in 0..42 { - expected[i] = 0.0; + for item in expected.iter_mut().take(42) { + *item = 0.0; } delay.apply(input.as_mut_slice()); diff --git 
a/src/sofa/coords.rs b/src/sofa/coords.rs new file mode 100644 index 0000000..c7c0870 --- /dev/null +++ b/src/sofa/coords.rs @@ -0,0 +1,200 @@ +//! Coordinate conversion utilities for SOFA/HRTF data. +//! +//! SOFA files can store positions in either cartesian [x, y, z] or +//! spherical [azimuth, elevation, radius] coordinates. This module +//! provides functions to convert between these representations. +//! +//! ## Coordinate Systems +//! +//! **Cartesian: x, y, z** in meters: +//! - x: forward direction +//! - y: left direction +//! - z: up direction +//! +//! **Spherical: azimuth, elevation, radius**: +//! - azimuth: angle in degrees, 0° = front, 90° = left, 180°/-180° = back +//! - elevation: angle in degrees, 0° = horizontal, 90° = up, -90° = down +//! - radius: distance in meters + +use std::f32::consts::PI; + +/// Convert cartesian coordinates to spherical. +/// +/// Input: [x, y, z] in meters. +/// Output: [azimuth, elevation, radius] where azimuth/elevation are in degrees. +pub fn cartesian_to_spherical(cartesian: [f32; 3]) -> [f32; 3] { + let [x, y, z] = cartesian; + + let r = (x * x + y * y + z * z).sqrt(); + let theta = z.atan2((x * x + y * y).sqrt()); // elevation + let phi = y.atan2(x); // azimuth + + [ + (phi * (180.0 / PI) + 360.0) % 360.0, // azimuth in degrees, 0-360 + theta * (180.0 / PI), // elevation in degrees + r, // radius in meters + ] +} + +/// Convert spherical coordinates to cartesian. +/// +/// Input: [azimuth, elevation, radius] where azimuth/elevation are in degrees. +/// Output: [x, y, z] in meters. +pub fn spherical_to_cartesian(spherical: [f32; 3]) -> [f32; 3] { + let [azimuth, elevation, radius] = spherical; + + let phi = azimuth * (PI / 180.0); + let theta = elevation * (PI / 180.0); + + let horizontal_dist = theta.cos() * radius; + + [ + phi.cos() * horizontal_dist, // x + phi.sin() * horizontal_dist, // y + theta.sin() * radius, // z + ] +} + +/// Compute the distance from origin of a cartesian point. 
+pub fn radius(cartesian: &[f32; 3]) -> f32 { + (cartesian[0].powi(2) + cartesian[1].powi(2) + cartesian[2].powi(2)).sqrt() +} + +/// Convert an array of cartesian coordinates to spherical in-place. +/// +/// The array should contain triplets of [x, y, z] values. +/// After conversion, each triplet becomes [azimuth, elevation, radius]. +#[allow(dead_code)] +pub fn convert_array_to_spherical(values: &mut [f32]) { + for chunk in values.chunks_exact_mut(3) { + let cart: [f32; 3] = [chunk[0], chunk[1], chunk[2]]; + let sph = cartesian_to_spherical(cart); + chunk[0] = sph[0]; + chunk[1] = sph[1]; + chunk[2] = sph[2]; + } +} + +/// Convert an array of spherical coordinates to cartesian in-place. +/// +/// The array should contain triplets of [azimuth, elevation, radius] values. +/// After conversion, each triplet becomes [x, y, z]. +pub fn convert_array_to_cartesian(values: &mut [f32]) { + for chunk in values.chunks_exact_mut(3) { + let sph: [f32; 3] = [chunk[0], chunk[1], chunk[2]]; + let cart = spherical_to_cartesian(sph); + chunk[0] = cart[0]; + chunk[1] = cart[1]; + chunk[2] = cart[2]; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const EPSILON: f32 = 1e-5; + + fn approx_eq(a: f32, b: f32) -> bool { + (a - b).abs() < EPSILON + } + + #[test] + fn test_cartesian_to_spherical_front() { + // Front: [1, 0, 0] -> [0°, 0°, 1m] + let cart = [1.0, 0.0, 0.0]; + let sph = cartesian_to_spherical(cart); + assert!(approx_eq(sph[0], 0.0), "azimuth: {} != 0", sph[0]); + assert!(approx_eq(sph[1], 0.0), "elevation: {} != 0", sph[1]); + assert!(approx_eq(sph[2], 1.0), "radius: {} != 1", sph[2]); + } + + #[test] + fn test_cartesian_to_spherical_left() { + // Left: [0, 1, 0] -> [90°, 0°, 1m] + let cart = [0.0, 1.0, 0.0]; + let sph = cartesian_to_spherical(cart); + assert!(approx_eq(sph[0], 90.0), "azimuth: {} != 90", sph[0]); + assert!(approx_eq(sph[1], 0.0), "elevation: {} != 0", sph[1]); + assert!(approx_eq(sph[2], 1.0), "radius: {} != 1", sph[2]); + } + + #[test] + fn 
test_cartesian_to_spherical_up() { + // Up: [0, 0, 1] -> [0°, 90°, 1m] + let cart = [0.0, 0.0, 1.0]; + let sph = cartesian_to_spherical(cart); + assert!(approx_eq(sph[1], 90.0), "elevation: {} != 90", sph[1]); + assert!(approx_eq(sph[2], 1.0), "radius: {} != 1", sph[2]); + } + + #[test] + fn test_spherical_to_cartesian_front() { + // [0°, 0°, 1m] -> [1, 0, 0] + let sph = [0.0, 0.0, 1.0]; + let cart = spherical_to_cartesian(sph); + assert!(approx_eq(cart[0], 1.0), "x: {} != 1", cart[0]); + assert!(approx_eq(cart[1], 0.0), "y: {} != 0", cart[1]); + assert!(approx_eq(cart[2], 0.0), "z: {} != 0", cart[2]); + } + + #[test] + fn test_spherical_to_cartesian_left() { + // [90°, 0°, 1m] -> [0, 1, 0] + let sph = [90.0, 0.0, 1.0]; + let cart = spherical_to_cartesian(sph); + assert!(approx_eq(cart[0], 0.0), "x: {} != 0", cart[0]); + assert!(approx_eq(cart[1], 1.0), "y: {} != 1", cart[1]); + assert!(approx_eq(cart[2], 0.0), "z: {} != 0", cart[2]); + } + + #[test] + fn test_roundtrip_cartesian_spherical() { + let original = [0.5, 0.3, 0.7]; + let spherical = cartesian_to_spherical(original); + let back = spherical_to_cartesian(spherical); + + assert!( + approx_eq(original[0], back[0]), + "x: {} != {}", + original[0], + back[0] + ); + assert!( + approx_eq(original[1], back[1]), + "y: {} != {}", + original[1], + back[1] + ); + assert!( + approx_eq(original[2], back[2]), + "z: {} != {}", + original[2], + back[2] + ); + } + + #[test] + fn test_convert_array_to_spherical() { + let mut values = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; // front, left + convert_array_to_spherical(&mut values); + + // First point: front + assert!(approx_eq(values[0], 0.0)); // azimuth + assert!(approx_eq(values[1], 0.0)); // elevation + assert!(approx_eq(values[2], 1.0)); // radius + + // Second point: left + assert!(approx_eq(values[3], 90.0)); // azimuth + assert!(approx_eq(values[4], 0.0)); // elevation + assert!(approx_eq(values[5], 1.0)); // radius + } + + #[test] + fn test_radius() { + 
assert!(approx_eq(radius(&[1.0, 0.0, 0.0]), 1.0)); + assert!(approx_eq(radius(&[0.0, 1.0, 0.0]), 1.0)); + assert!(approx_eq(radius(&[0.0, 0.0, 1.0]), 1.0)); + assert!(approx_eq(radius(&[1.0, 1.0, 1.0]), 3.0_f32.sqrt())); + } +} diff --git a/src/sofa/error.rs b/src/sofa/error.rs new file mode 100644 index 0000000..de6a605 --- /dev/null +++ b/src/sofa/error.rs @@ -0,0 +1,44 @@ +//! Error types for SOFA file processing. + +/// Error type for SOFA operations. +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("Parse error: {0}")] + Parse(String), + #[error("Invalid SOFA format: missing 'Conventions: SOFA' attribute")] + InvalidFormat, + #[error("Missing required dimension: {0}")] + MissingDimension(char), + #[error("Invalid dimension {name}: got {value}, expected {expected}")] + InvalidDimension { + name: char, + value: u32, + expected: u32, + }, + #[error("Missing required array: {0}")] + MissingArray(&'static str), + #[error("Invalid array size for {name}: expected {expected}, got {actual}")] + InvalidArraySize { + name: &'static str, + expected: usize, + actual: usize, + }, + #[error("Invalid attribute {name}: expected {expected}")] + InvalidAttribute { + name: &'static str, + expected: &'static str, + }, + #[error("Unsupported data type")] + UnsupportedDataType, +} + +impl From for Error { + fn from(e: winnow::error::ContextError) -> Self { + Error::Parse(format!("{:?}", e)) + } +} + +/// Result type for SOFA operations. +pub type Result = std::result::Result; diff --git a/src/sofa/interpolate.rs b/src/sofa/interpolate.rs new file mode 100644 index 0000000..0bac09d --- /dev/null +++ b/src/sofa/interpolate.rs @@ -0,0 +1,283 @@ +//! Filter interpolation using inverse distance weighting. +//! +//! Combines the nearest measurement with up to 6 directional neighbors +//! using inverse distance weighting for smooth HRTF interpolation. 
+ +use arrayvec::ArrayVec; + +use super::kdtree::Point3; +use super::neighbors::Neighborhood; +use super::reader::Hrtf; + +/// Interpolated HRTF filter result. +#[derive(Debug, Clone)] +pub struct InterpolatedFilter { + /// Left channel impulse response. + pub left: Vec, + /// Right channel impulse response. + pub right: Vec, + /// Left channel delay in samples. + pub delay_left: f32, + /// Right channel delay in samples. + pub delay_right: f32, +} + +/// Interpolate HRTF filter at a given position. +/// +/// Uses inverse distance weighting to combine the nearest measurement +/// with up to 6 directional neighbors. +/// +/// # Arguments +/// * `hrtf` - The HRTF data +/// * `neighborhood` - Precomputed neighbor relationships +/// * `nearest_idx` - Index of the nearest measurement +/// * `position` - The query position in cartesian coordinates +/// +/// # Returns +/// An interpolated filter, or None if the data is invalid. +pub fn interpolate( + hrtf: &Hrtf, + neighborhood: &Neighborhood, + nearest_idx: usize, + position: &Point3, +) -> Option { + let dims = hrtf.dimensions(); + let n = dims.n as usize; + let r = dims.r as usize; + let m = dims.m as usize; + let c = dims.c as usize; + + if r < 2 || nearest_idx >= m { + return None; + } + + let ir_values = &hrtf.data_ir.values; + let delay_values = &hrtf.data_delay.values; + let source_pos = &hrtf.source_position.values; + + // Get the nearest position + let nearest_offset = nearest_idx * c; + if nearest_offset + 2 >= source_pos.len() { + return None; + } + let nearest_pos: Point3 = [ + source_pos[nearest_offset], + source_pos[nearest_offset + 1], + source_pos[nearest_offset + 2], + ]; + + // Calculate distance to nearest + let nearest_dist = distance(&nearest_pos, position); + + // If very close (exact match), just return the nearest filter + if nearest_dist < 1e-6 { + return extract_filter(hrtf, nearest_idx, n, r); + } + + // Get neighbors + let neighbors = neighborhood.get(nearest_idx)?; + + // Collect points with 
their distances and weights. At most 7: nearest + 6 neighbors. + let mut points: ArrayVec<(usize, f32), 7> = ArrayVec::new(); + points.push((nearest_idx, 1.0 / nearest_dist)); + + // Add the closer neighbor from each directional pair + add_closer_neighbor( + &mut points, + neighbors.phi_plus, + neighbors.phi_minus, + position, + source_pos, + c, + ); + add_closer_neighbor( + &mut points, + neighbors.theta_plus, + neighbors.theta_minus, + position, + source_pos, + c, + ); + add_closer_neighbor( + &mut points, + neighbors.radius_plus, + neighbors.radius_minus, + position, + source_pos, + c, + ); + + // Compute weight sum for normalization + let weight_sum: f32 = points.iter().map(|(_, w)| w).sum(); + if weight_sum < 1e-10 { + return extract_filter(hrtf, nearest_idx, n, r); + } + + // Initialize result + let mut left = vec![0.0f32; n]; + let mut right = vec![0.0f32; n]; + let mut delay_left = 0.0f32; + let mut delay_right = 0.0f32; + + // Weighted sum of filters + for (idx, weight) in &points { + let norm_weight = weight / weight_sum; + + // IR offsets: data_ir is M * R * N + let ir_offset_left = idx * r * n; + let ir_offset_right = ir_offset_left + n; + + // Add weighted IR values + if ir_offset_right + n <= ir_values.len() { + for i in 0..n { + left[i] += ir_values[ir_offset_left + i] * norm_weight; + right[i] += ir_values[ir_offset_right + i] * norm_weight; + } + } + + // Add weighted delay values + // Delay can be: + // - per-measurement: M values + // - per-channel: M*R values + // - global per-receiver: R values, e.g. 2 for stereo. Matches C behavior. 
+ if delay_values.len() == m { + // Single delay per measurement + delay_left += delay_values[*idx] * norm_weight; + delay_right += delay_values[*idx] * norm_weight; + } else if delay_values.len() >= m * r { + // Per-channel delay + let delay_offset = idx * r; + if delay_offset + 1 < delay_values.len() { + delay_left += delay_values[delay_offset] * norm_weight; + delay_right += delay_values[delay_offset + 1] * norm_weight; + } + } else if delay_values.len() >= r { + // Global delay values, one per receiver channel + delay_left += delay_values[0] * norm_weight; + delay_right += delay_values.get(1).copied().unwrap_or(delay_values[0]) * norm_weight; + } + } + + Some(InterpolatedFilter { + left, + right, + delay_left, + delay_right, + }) +} + +/// Add the closer neighbor from a directional pair to the points list. +fn add_closer_neighbor( + points: &mut ArrayVec<(usize, f32), 7>, + neighbor_a: Option, + neighbor_b: Option, + position: &Point3, + source_pos: &[f32], + c: usize, +) { + let mut best_idx = None; + let mut best_dist = f32::MAX; + + for idx in [neighbor_a, neighbor_b].into_iter().flatten() { + let offset = idx * c; + if offset + 2 >= source_pos.len() { + continue; + } + let pos: Point3 = [ + source_pos[offset], + source_pos[offset + 1], + source_pos[offset + 2], + ]; + let dist = distance(&pos, position); + if dist < best_dist { + best_dist = dist; + best_idx = Some(idx); + } + } + + if let Some(idx) = best_idx + && best_dist > 1e-10 + { + points.push((idx, 1.0 / best_dist)); + } +} + +/// Extract a single filter without interpolation. 
///
/// # Arguments
/// * `hrtf` - The HRTF data
/// * `idx` - Measurement index to extract
/// * `n` - Samples per IR (dimension N)
/// * `r` - Receiver count (dimension R)
///
/// # Returns
/// The filter at `idx`, or `None` if the IR data is shorter than expected.
fn extract_filter(hrtf: &Hrtf, idx: usize, n: usize, r: usize) -> Option<InterpolatedFilter> {
    let ir_values = &hrtf.data_ir.values;
    let delay_values = &hrtf.data_delay.values;
    let m = hrtf.dimensions().m as usize;

    // data_ir is laid out M × R × N; receiver 0 is the left ear,
    // receiver 1 the right (same layout the interpolated path uses).
    let ir_offset_left = idx * r * n;
    let ir_offset_right = ir_offset_left + n;

    if ir_offset_right + n > ir_values.len() {
        return None;
    }

    let left = ir_values[ir_offset_left..ir_offset_left + n].to_vec();
    let right = ir_values[ir_offset_right..ir_offset_right + n].to_vec();

    // Delay layout mirrors the interpolation path: per-measurement
    // (M values), per-channel (M*R values), or global per-receiver
    // (R values). NOTE(review): with r == 1 the per-channel branch
    // reads `offset + 1`, which falls back to 0.0 — confirm intended.
    let (delay_left, delay_right) = if delay_values.len() == m {
        // One delay per measurement, shared by both ears.
        (
            delay_values.get(idx).copied().unwrap_or(0.0),
            delay_values.get(idx).copied().unwrap_or(0.0),
        )
    } else if delay_values.len() >= m * r {
        // Per-channel delay: consecutive receiver entries per measurement.
        let offset = idx * r;
        (
            delay_values.get(offset).copied().unwrap_or(0.0),
            delay_values.get(offset + 1).copied().unwrap_or(0.0),
        )
    } else if delay_values.len() >= r {
        // Global delay values, one per receiver channel
        (
            delay_values.first().copied().unwrap_or(0.0),
            delay_values
                .get(1)
                .copied()
                .unwrap_or(delay_values.first().copied().unwrap_or(0.0)),
        )
    } else {
        // No usable delay data at all.
        (0.0, 0.0)
    };

    Some(InterpolatedFilter {
        left,
        right,
        delay_left,
        delay_right,
    })
}

/// Calculate euclidean distance between two points.
#[inline]
fn distance(a: &Point3, b: &Point3) -> f32 {
    ((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2) + (a[2] - b[2]).powi(2)).sqrt()
}

/// Get filter without interpolation, using nearest neighbor only.
///
/// # Arguments
/// * `hrtf` - The HRTF data
/// * `nearest_idx` - Index of the nearest measurement
///
/// # Returns
/// The filter at the given index, or None if invalid.
+pub fn get_filter_nointerp(hrtf: &Hrtf, nearest_idx: usize) -> Option { + let dims = hrtf.dimensions(); + let n = dims.n as usize; + let r = dims.r as usize; + extract_filter(hrtf, nearest_idx, n, r) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_distance() { + assert!((distance(&[0.0, 0.0, 0.0], &[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-6); + assert!((distance(&[0.0, 0.0, 0.0], &[3.0, 4.0, 0.0]) - 5.0).abs() < 1e-6); + } +} diff --git a/src/sofa/kdtree.rs b/src/sofa/kdtree.rs new file mode 100644 index 0000000..401b183 --- /dev/null +++ b/src/sofa/kdtree.rs @@ -0,0 +1,306 @@ +//! KD-tree for efficient 3D nearest neighbor search. +//! +//! This is a simple 3-dimensional KD-tree implementation optimized for +//! finding the nearest HRTF filter position. + +/// A 3D point. +pub type Point3 = [f32; 3]; + +/// A node in the KD-tree. +#[derive(Debug)] +struct KdNode { + /// Position of this node. + pos: Point3, + /// The dimension used for splitting at this node (0=x, 1=y, 2=z). + split_dim: usize, + /// Data associated with this node, typically a filter index. + data: usize, + /// Left subtree, containing values less than the split. + left: Option>, + /// Right subtree, containing values greater than or equal to the split. + right: Option>, +} + +/// Bounding box for the KD-tree. +#[derive(Debug, Clone)] +struct BoundingBox { + min: Point3, + max: Point3, +} + +impl BoundingBox { + fn new(pos: &Point3) -> Self { + Self { + min: *pos, + max: *pos, + } + } + + fn extend(&mut self, pos: &Point3) { + for (i, &p) in pos.iter().enumerate().take(3) { + if p < self.min[i] { + self.min[i] = p; + } + if p > self.max[i] { + self.max[i] = p; + } + } + } + + /// Compute squared distance from point to this bounding box. 
+ fn dist_sq(&self, pos: &Point3) -> f32 { + let mut result = 0.0; + for (i, &p) in pos.iter().enumerate().take(3) { + if p < self.min[i] { + result += (self.min[i] - p).powi(2); + } else if p > self.max[i] { + result += (self.max[i] - p).powi(2); + } + } + result + } +} + +/// A 3-dimensional KD-tree for efficient nearest neighbor search. +#[derive(Debug)] +pub struct KdTree { + root: Option>, + bounds: Option, +} + +impl Default for KdTree { + fn default() -> Self { + Self::new() + } +} + +impl KdTree { + /// Create a new empty KD-tree. + pub fn new() -> Self { + Self { + root: None, + bounds: None, + } + } + + /// Insert a point with associated data into the tree. + /// + /// # Arguments + /// * `pos` - The 3D position [x, y, z] + /// * `data` - The data to associate, typically a filter index + pub fn insert(&mut self, pos: Point3, data: usize) { + // Update bounding box + match &mut self.bounds { + Some(bounds) => bounds.extend(&pos), + None => self.bounds = Some(BoundingBox::new(&pos)), + } + + // Insert into tree + Self::insert_rec(&mut self.root, pos, data, 0); + } + + fn insert_rec(node: &mut Option>, pos: Point3, data: usize, depth: usize) { + let split_dim = depth % 3; + + match node { + None => { + *node = Some(Box::new(KdNode { + pos, + split_dim, + data, + left: None, + right: None, + })); + } + Some(n) => { + if pos[n.split_dim] < n.pos[n.split_dim] { + Self::insert_rec(&mut n.left, pos, data, depth + 1); + } else { + Self::insert_rec(&mut n.right, pos, data, depth + 1); + } + } + } + } + + /// Find the nearest neighbor to the given position. + /// + /// Returns the data associated with the nearest point, or None if tree is empty. 
pub fn nearest(&self, pos: &Point3) -> Option<usize> {
    let root = self.root.as_ref()?;
    let bounds = self.bounds.as_ref()?;

    // The root is the initial best candidate.
    let mut best_node = root.as_ref();
    let mut best_dist_sq = Self::dist_sq(&root.pos, pos);

    // Working copy of bounds for search
    let mut rect = bounds.clone();

    Self::nearest_rec(root, pos, &mut best_node, &mut best_dist_sq, &mut rect);

    Some(best_node.data)
}

/// Recursive branch-and-bound step of [`KdTree::nearest`].
///
/// `rect` is the bounding box of the subtree rooted at `node`; it is
/// shrunk in place before each descent and restored afterwards, so a
/// single allocation serves the whole search. `best` / `best_dist_sq`
/// carry the closest node found so far.
fn nearest_rec<'a>(
    node: &'a KdNode,
    pos: &Point3,
    best: &mut &'a KdNode,
    best_dist_sq: &mut f32,
    rect: &mut BoundingBox,
) {
    let dim = node.split_dim;
    let diff = pos[dim] - node.pos[dim];

    // Determine which subtree is nearer. On a tie (diff == 0) we descend
    // left first; the far side is still examined below when it could
    // contain a closer point.
    let go_left = diff <= 0.0;

    // Recurse into nearer subtree
    let nearer = if go_left { &node.left } else { &node.right };
    if let Some(nearer_node) = nearer {
        // Shrink rect to the near half-space, recurse, then restore.
        let old_coord = if go_left {
            rect.max[dim]
        } else {
            rect.min[dim]
        };
        if go_left {
            rect.max[dim] = node.pos[dim];
        } else {
            rect.min[dim] = node.pos[dim];
        }
        Self::nearest_rec(nearer_node, pos, best, best_dist_sq, rect);
        if go_left {
            rect.max[dim] = old_coord;
        } else {
            rect.min[dim] = old_coord;
        }
    }

    // Check current node
    let dist_sq = Self::dist_sq(&node.pos, pos);
    if dist_sq < *best_dist_sq {
        *best = node;
        *best_dist_sq = dist_sq;
    }

    // Check if we need to search farther subtree
    let farther = if go_left { &node.right } else { &node.left };
    if let Some(farther_node) = farther {
        // Shrink rect to the far half-space before the pruning test.
        let old_coord = if go_left {
            rect.min[dim]
        } else {
            rect.max[dim]
        };
        if go_left {
            rect.min[dim] = node.pos[dim];
        } else {
            rect.max[dim] = node.pos[dim];
        }

        // Only search if the hyperrect could contain a closer point
        if rect.dist_sq(pos) < *best_dist_sq {
            Self::nearest_rec(farther_node, pos, best, best_dist_sq, rect);
        }

        if go_left {
            rect.min[dim] = old_coord;
        } else {
            rect.max[dim] = old_coord;
        }
    }
}
+ #[inline] + fn dist_sq(a: &Point3, b: &Point3) -> f32 { + (a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2) + (a[2] - b[2]).powi(2) + } + + /// Check if the tree is empty. + pub fn is_empty(&self) -> bool { + self.root.is_none() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_tree() { + let tree = KdTree::new(); + assert!(tree.is_empty()); + assert_eq!(tree.nearest(&[0.0, 0.0, 0.0]), None); + } + + #[test] + fn test_single_point() { + let mut tree = KdTree::new(); + tree.insert([1.0, 2.0, 3.0], 42); + + assert!(!tree.is_empty()); + assert_eq!(tree.nearest(&[0.0, 0.0, 0.0]), Some(42)); + assert_eq!(tree.nearest(&[1.0, 2.0, 3.0]), Some(42)); + assert_eq!(tree.nearest(&[100.0, 100.0, 100.0]), Some(42)); + } + + #[test] + fn test_multiple_points() { + let mut tree = KdTree::new(); + tree.insert([0.0, 0.0, 0.0], 0); + tree.insert([1.0, 0.0, 0.0], 1); + tree.insert([0.0, 1.0, 0.0], 2); + tree.insert([0.0, 0.0, 1.0], 3); + + // Origin should find index 0 + assert_eq!(tree.nearest(&[0.0, 0.0, 0.0]), Some(0)); + + // Point closest to [1, 0, 0] should find index 1 + assert_eq!(tree.nearest(&[0.9, 0.0, 0.0]), Some(1)); + + // Point closest to [0, 1, 0] should find index 2 + assert_eq!(tree.nearest(&[0.0, 0.9, 0.0]), Some(2)); + + // Point closest to [0, 0, 1] should find index 3 + assert_eq!(tree.nearest(&[0.0, 0.0, 0.9]), Some(3)); + } + + #[test] + fn test_find_nearest_among_many() { + let mut tree = KdTree::new(); + + // Create a grid of points + let mut idx = 0; + for x in -5..=5 { + for y in -5..=5 { + for z in -5..=5 { + tree.insert([x as f32, y as f32, z as f32], idx); + idx += 1; + } + } + } + + // Test that origin finds [0, 0, 0] + let center_idx = 5 * 11 * 11 + 5 * 11 + 5; // Index of [0, 0, 0] + assert_eq!(tree.nearest(&[0.1, 0.1, 0.1]), Some(center_idx)); + + // Test corner + let corner_idx = 0; // Index of (-5, -5, -5) + assert_eq!(tree.nearest(&[-4.9, -4.9, -4.9]), Some(corner_idx)); + } + + #[test] + fn test_bounding_box() { + let 
mut bb = BoundingBox::new(&[0.0, 0.0, 0.0]); + assert_eq!(bb.min, [0.0, 0.0, 0.0]); + assert_eq!(bb.max, [0.0, 0.0, 0.0]); + + bb.extend(&[1.0, -1.0, 2.0]); + assert_eq!(bb.min, [0.0, -1.0, 0.0]); + assert_eq!(bb.max, [1.0, 0.0, 2.0]); + + // Point inside box has distance 0 + assert_eq!(bb.dist_sq(&[0.5, -0.5, 1.0]), 0.0); + + // Point outside box + assert!((bb.dist_sq(&[2.0, 0.0, 0.0]) - 1.0).abs() < 1e-6); + } +} diff --git a/src/sofa/lookup.rs b/src/sofa/lookup.rs new file mode 100644 index 0000000..3b37e74 --- /dev/null +++ b/src/sofa/lookup.rs @@ -0,0 +1,190 @@ +//! Spatial lookup for finding nearest HRTF filters. +//! +//! Provides efficient spatial search for finding the closest HRTF filter +//! to a given 3D position using a KD-tree. + +use super::coords::{cartesian_to_spherical, radius}; +use super::kdtree::{KdTree, Point3}; +use super::reader::Hrtf; + +/// Spatial lookup structure for finding nearest HRTF filters. +/// +/// Uses a KD-tree for efficient O(log n) nearest neighbor queries. +#[derive(Debug)] +pub struct Lookup { + /// The KD-tree containing source positions. + kdtree: KdTree, + /// Minimum azimuth, phi, in degrees. + pub phi_min: f32, + /// Maximum azimuth, phi, in degrees. + pub phi_max: f32, + /// Minimum elevation, theta, in degrees. + pub theta_min: f32, + /// Maximum elevation, theta, in degrees. + pub theta_max: f32, + /// Minimum radius in meters. + pub radius_min: f32, + /// Maximum radius in meters. + pub radius_max: f32, +} + +impl Lookup { + /// Initialize a lookup structure from HRTF data. + /// + /// Builds a KD-tree from the source positions and computes + /// the bounding box in spherical coordinates. + /// + /// # Arguments + /// * `hrtf` - The HRTF data containing source positions + /// + /// # Returns + /// A new Lookup structure, or None if source positions are empty + /// or not in cartesian coordinates. 
+ pub fn new(hrtf: &Hrtf) -> Option { + let source_pos = &hrtf.source_position.values; + let c = hrtf.dimensions().c as usize; + let m = hrtf.dimensions().m as usize; + + if source_pos.is_empty() || c != 3 { + return None; + } + + // Build KD-tree and compute spherical bounds + let mut kdtree = KdTree::new(); + let mut phi_min = f32::MAX; + let mut phi_max = f32::MIN; + let mut theta_min = f32::MAX; + let mut theta_max = f32::MIN; + let mut radius_min = f32::MAX; + let mut radius_max = f32::MIN; + + for i in 0..m { + let offset = i * c; + if offset + 2 >= source_pos.len() { + break; + } + + let pos: Point3 = [ + source_pos[offset], + source_pos[offset + 1], + source_pos[offset + 2], + ]; + + // Insert into KD-tree + kdtree.insert(pos, i); + + // Convert to spherical for bounds + let spherical = cartesian_to_spherical(pos); + let phi = spherical[0]; + let theta = spherical[1]; + let r = spherical[2]; + + phi_min = phi_min.min(phi); + phi_max = phi_max.max(phi); + theta_min = theta_min.min(theta); + theta_max = theta_max.max(theta); + radius_min = radius_min.min(r); + radius_max = radius_max.max(r); + } + + if kdtree.is_empty() { + return None; + } + + Some(Self { + kdtree, + phi_min, + phi_max, + theta_min, + theta_max, + radius_min, + radius_max, + }) + } + + /// Find the nearest filter index for a given cartesian coordinate. + /// + /// The coordinate may be normalized to fit within the radius bounds + /// of the available measurements. + /// + /// # Arguments + /// * `coordinate` - The cartesian position [x, y, z] to look up + /// + /// # Returns + /// The index of the nearest filter, or None if lookup fails. + pub fn find(&self, coordinate: &Point3) -> Option { + // Normalize radius if outside bounds + let pos = self.normalize_radius(*coordinate); + self.kdtree.nearest(&pos) + } + + /// Find the nearest filter index, modifying the coordinate in place + /// to reflect any radius normalization. 
+ /// + /// # Arguments + /// * `coordinate` - The cartesian position [x, y, z] to look up, which may be modified + /// + /// # Returns + /// The index of the nearest filter, or None if lookup fails. + pub fn find_mut(&self, coordinate: &mut Point3) -> Option { + *coordinate = self.normalize_radius(*coordinate); + self.kdtree.nearest(coordinate) + } + + /// Normalize a coordinate's radius to fit within bounds. + fn normalize_radius(&self, mut pos: Point3) -> Point3 { + let r = radius(&pos); + + if r > self.radius_max { + let scale = self.radius_max / r; + pos[0] *= scale; + pos[1] *= scale; + pos[2] *= scale; + } else if r < self.radius_min && r > 0.0 { + let scale = self.radius_min / r; + pos[0] *= scale; + pos[1] *= scale; + pos[2] *= scale; + } + + pos + } + + /// Get the number of measurements in the lookup. + pub fn is_empty(&self) -> bool { + self.kdtree.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_normalize_radius() { + let lookup = Lookup { + kdtree: KdTree::new(), + phi_min: 0.0, + phi_max: 360.0, + theta_min: -90.0, + theta_max: 90.0, + radius_min: 1.0, + radius_max: 2.0, + }; + + // Point at radius 3 should be scaled to radius 2 + let pos = lookup.normalize_radius([3.0, 0.0, 0.0]); + let r = radius(&pos); + assert!((r - 2.0).abs() < 1e-5, "radius {} != 2.0", r); + + // Point at radius 0.5 should be scaled to radius 1 + let pos = lookup.normalize_radius([0.5, 0.0, 0.0]); + let r = radius(&pos); + assert!((r - 1.0).abs() < 1e-5, "radius {} != 1.0", r); + + // Point at radius 1.5 should remain unchanged + let pos = lookup.normalize_radius([1.5, 0.0, 0.0]); + let r = radius(&pos); + assert!((r - 1.5).abs() < 1e-5, "radius {} != 1.5", r); + } +} diff --git a/src/sofa/loudness.rs b/src/sofa/loudness.rs new file mode 100644 index 0000000..2131827 --- /dev/null +++ b/src/sofa/loudness.rs @@ -0,0 +1,147 @@ +//! Loudness normalization for HRTF data. +//! +//! 
Normalizes all HRTF filters so that the frontal filter has unit energy, +//! ensuring consistent perceived loudness across all directions. + +use super::coords::cartesian_to_spherical; +use super::reader::Hrtf; + +/// Compute the energy, i.e. the sum of squared samples, of a signal. +/// +/// # Arguments +/// * `samples` - The audio samples +/// +/// # Returns +/// The total energy, computed as the sum of squares. +pub fn loudness(samples: &[f32]) -> f32 { + samples.iter().map(|s| s * s).sum() +} + +/// Find the index of the frontal filter. +/// +/// The frontal filter is the one closest to 0° azimuth and 0° elevation. +/// Among ties, prefers the one with maximum radius. +/// +/// # Arguments +/// * `hrtf` - The HRTF data +/// +/// # Returns +/// The index of the frontal filter, or None if source positions are empty. +pub fn find_frontal_index(hrtf: &Hrtf) -> Option { + let source_pos = &hrtf.source_position.values; + let c = hrtf.dimensions().c as usize; + let m = hrtf.dimensions().m as usize; + + if c != 3 || source_pos.is_empty() { + return None; + } + + let mut best_idx = 0; + let mut best_sum = f32::MAX; + let mut best_radius = f32::MIN; + + for i in 0..m { + let offset = i * c; + if offset + 2 >= source_pos.len() { + break; + } + + let pos = [ + source_pos[offset], + source_pos[offset + 1], + source_pos[offset + 2], + ]; + + // Convert to spherical: [azimuth, elevation, radius] + let spherical = cartesian_to_spherical(pos); + let azimuth = spherical[0]; + let elevation = spherical[1]; + let radius = spherical[2]; + + // Normalize azimuth to -180..180 for frontal comparison + let azimuth_norm = if azimuth > 180.0 { + azimuth - 360.0 + } else { + azimuth + }; + + // Sum of absolute angles - lower is more frontal + let sum = azimuth_norm.abs() + elevation.abs(); + + // Prefer lower sum, or same sum with larger radius + if sum < best_sum || (sum == best_sum && radius > best_radius) { + best_sum = sum; + best_radius = radius; + best_idx = i; + } + } + + 
Some(best_idx) +} + +/// Normalize HRTF filters for consistent loudness. +/// +/// Scales all IR data so that the frontal filter has unit energy. +/// This modifies the HRTF data in place. +/// +/// # Arguments +/// * `hrtf` - The HRTF data to normalize +/// +/// # Returns +/// The scaling factor applied, or None if normalization failed. +pub fn normalize(hrtf: &mut Hrtf) -> Option { + let frontal_idx = find_frontal_index(hrtf)?; + + let n = hrtf.dimensions().n as usize; + let r = hrtf.dimensions().r as usize; + + // Get the frontal filter's IR data + let ir_offset = frontal_idx * r * n; + let ir_end = ir_offset + r * n; + + if ir_end > hrtf.data_ir.values.len() { + return None; + } + + // Compute loudness of frontal filter across both channels + let frontal_energy = loudness(&hrtf.data_ir.values[ir_offset..ir_end]); + + if frontal_energy < 1e-10 { + return None; + } + + // Compute scaling factor: sqrt(2 / energy) for unit energy + let factor = (2.0 / frontal_energy).sqrt(); + + // Skip if already normalized (factor ≈ 1) + if (factor - 1.0).abs() < 1e-6 { + return Some(1.0); + } + + // Scale all IR data + for sample in &mut hrtf.data_ir.values { + *sample *= factor; + } + + Some(factor) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_loudness() { + // Energy of [1, 1, 1, 1] = 4 + assert!((loudness(&[1.0, 1.0, 1.0, 1.0]) - 4.0).abs() < 1e-6); + + // Energy of [0.5, 0.5, 0.5, 0.5] = 1 + assert!((loudness(&[0.5, 0.5, 0.5, 0.5]) - 1.0).abs() < 1e-6); + + // Energy of [3, 4] = 9 + 16 = 25 + assert!((loudness(&[3.0, 4.0]) - 25.0).abs() < 1e-6); + + // Empty slice has zero energy + assert_eq!(loudness(&[]), 0.0); + } +} diff --git a/src/sofa/mod.rs b/src/sofa/mod.rs new file mode 100644 index 0000000..6ad6fd4 --- /dev/null +++ b/src/sofa/mod.rs @@ -0,0 +1,27 @@ +//! SOFA file reader and processor for Spatially Oriented Format for Acoustics. +//! +//! This module provides pure Rust implementation for reading and processing +//! 
SOFA files containing HRTF, or Head-Related Transfer Function, data. +//! +//! This module is internal. Use [`crate::reader`] for the public API. + +pub mod coords; +mod error; +mod interpolate; +mod kdtree; +mod lookup; +mod loudness; +mod neighbors; +mod reader; +mod resample; +mod types; +mod validate; + +pub use error::Error; +pub use interpolate::{InterpolatedFilter, get_filter_nointerp, interpolate}; +pub use lookup::Lookup; +pub use loudness::normalize; +pub use neighbors::Neighborhood; +pub use reader::Hrtf; +pub use resample::resample; +pub use validate::validate; diff --git a/src/sofa/neighbors.rs b/src/sofa/neighbors.rs new file mode 100644 index 0000000..1089681 --- /dev/null +++ b/src/sofa/neighbors.rs @@ -0,0 +1,245 @@ +//! Neighbor computation for HRTF interpolation. +//! +//! Computes directional neighbors (phi+/-, theta+/-, radius+/-) for each +//! measurement position to enable smooth interpolation. + +use super::coords::{cartesian_to_spherical, spherical_to_cartesian}; +use super::kdtree::Point3; +use super::lookup::Lookup; +use super::reader::Hrtf; + +/// Maximum angle to search for neighbors (in degrees). +const MAX_SEARCH_ANGLE: f32 = 45.0; + +/// Neighbor indices for one measurement position. +/// +/// Contains indices to neighboring measurements in 6 directions: +/// phi+/-, theta+/-, radius+/-. A value of `None` means no neighbor +/// was found in that direction. +#[derive(Debug, Clone, Default)] +pub struct Neighbors { + /// Neighbor in positive phi, i.e. azimuth, direction. + pub phi_plus: Option, + /// Neighbor in negative phi, i.e. azimuth, direction. + pub phi_minus: Option, + /// Neighbor in positive theta, i.e. elevation, direction. + pub theta_plus: Option, + /// Neighbor in negative theta, i.e. elevation, direction. + pub theta_minus: Option, + /// Neighbor in positive radius direction. + pub radius_plus: Option, + /// Neighbor in negative radius direction. 
+ pub radius_minus: Option, +} + +/// Neighborhood structure containing precomputed neighbors for all positions. +#[derive(Debug)] +pub struct Neighborhood { + /// Neighbors for each measurement position (length = M). + neighbors: Vec, + /// Step size for angular search (degrees). + #[allow(dead_code)] + angle_step: f32, + /// Step size for radius search (meters). + #[allow(dead_code)] + radius_step: f32, +} + +impl Neighborhood { + /// Build neighborhood structure from HRTF data. + /// + /// # Arguments + /// * `hrtf` - The HRTF data + /// * `lookup` - The spatial lookup structure + /// * `angle_step` - Step size for angular search in degrees (default: 0.5) + /// * `radius_step` - Step size for radius search in meters (default: 0.01) + pub fn new(hrtf: &Hrtf, lookup: &Lookup, angle_step: f32, radius_step: f32) -> Self { + let m = hrtf.dimensions().m as usize; + let c = hrtf.dimensions().c as usize; + let source_pos = &hrtf.source_position.values; + + let mut neighbors = Vec::with_capacity(m); + + for i in 0..m { + let offset = i * c; + if offset + 2 >= source_pos.len() { + neighbors.push(Neighbors::default()); + continue; + } + + let pos: Point3 = [ + source_pos[offset], + source_pos[offset + 1], + source_pos[offset + 2], + ]; + + neighbors.push(Self::find_neighbors( + i, + &pos, + lookup, + angle_step, + radius_step, + )); + } + + Self { + neighbors, + angle_step, + radius_step, + } + } + + /// Find all directional neighbors for a position. 
+ fn find_neighbors( + current_idx: usize, + pos: &Point3, + lookup: &Lookup, + angle_step: f32, + radius_step: f32, + ) -> Neighbors { + let spherical = cartesian_to_spherical(*pos); + let phi = spherical[0]; + let theta = spherical[1]; + let r = spherical[2]; + + Neighbors { + phi_plus: Self::search_direction( + current_idx, + phi, + theta, + r, + angle_step, + 0.0, + 0.0, + lookup, + ), + phi_minus: Self::search_direction( + current_idx, + phi, + theta, + r, + -angle_step, + 0.0, + 0.0, + lookup, + ), + theta_plus: Self::search_direction( + current_idx, + phi, + theta, + r, + 0.0, + angle_step, + 0.0, + lookup, + ), + theta_minus: Self::search_direction( + current_idx, + phi, + theta, + r, + 0.0, + -angle_step, + 0.0, + lookup, + ), + radius_plus: Self::search_direction( + current_idx, + phi, + theta, + r, + 0.0, + 0.0, + radius_step, + lookup, + ), + radius_minus: Self::search_direction( + current_idx, + phi, + theta, + r, + 0.0, + 0.0, + -radius_step, + lookup, + ), + } + } + + /// Search in a direction until finding a different measurement point. 
+ #[allow(clippy::too_many_arguments)] + fn search_direction( + current_idx: usize, + phi: f32, + theta: f32, + radius: f32, + phi_step: f32, + theta_step: f32, + radius_step: f32, + lookup: &Lookup, + ) -> Option { + let max_steps = if phi_step.abs() > 0.0 || theta_step.abs() > 0.0 { + (MAX_SEARCH_ANGLE / phi_step.abs().max(theta_step.abs())).ceil() as i32 + } else { + // For radius, search until bounds + 100 + }; + + for step in 1..=max_steps { + let step_f = step as f32; + let new_phi = phi + phi_step * step_f; + let new_theta = (theta + theta_step * step_f).clamp(-90.0, 90.0); + let new_radius = (radius + radius_step * step_f).max(0.001); + + // Skip if radius would be outside bounds (matching C behavior: ± radius_step) + if new_radius < lookup.radius_min - radius_step.abs() + || new_radius > lookup.radius_max + radius_step.abs() + { + break; + } + + let cart = spherical_to_cartesian([new_phi, new_theta, new_radius]); + + if let Some(idx) = lookup.find(&cart) + && idx != current_idx + { + return Some(idx); + } + } + + None + } + + /// Get neighbors for a measurement index. + pub fn get(&self, index: usize) -> Option<&Neighbors> { + self.neighbors.get(index) + } + + /// Get the angle step used for neighbor computation. + #[allow(dead_code)] + pub fn angle_step(&self) -> f32 { + self.angle_step + } + + /// Get the radius step used for neighbor computation. + #[allow(dead_code)] + pub fn radius_step(&self) -> f32 { + self.radius_step + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_neighbors_default() { + let n = Neighbors::default(); + assert!(n.phi_plus.is_none()); + assert!(n.phi_minus.is_none()); + assert!(n.theta_plus.is_none()); + assert!(n.theta_minus.is_none()); + assert!(n.radius_plus.is_none()); + assert!(n.radius_minus.is_none()); + } +} diff --git a/src/sofa/reader.rs b/src/sofa/reader.rs new file mode 100644 index 0000000..b12f9cd --- /dev/null +++ b/src/sofa/reader.rs @@ -0,0 +1,476 @@ +//! SOFA/HRTF file reader. +//! +//! 
Reads SOFA files and extracts HRTF, Head-Related Transfer Function, data. + +use std::collections::HashMap; +use std::path::Path; + +use crate::hdf::{self, DataObject, DataType, ParsedHdf}; + +use super::error::{Error, Result}; +use super::types::{Array, Dimensions}; + +/// HRTF data loaded from a SOFA file. +/// +/// Contains the spatial audio impulse response data along with position +/// information for head-related transfer function processing. +#[derive(Debug, Clone)] +pub struct Hrtf { + /// SOFA dimensions: I, C, R, E, N, M + dimensions: Dimensions, + + /// Listener position. I × C elements. + pub listener_position: Array, + /// Receiver positions relative to listener. R × C × I elements. + pub receiver_position: Array, + /// Source positions for each measurement. M × C elements. + pub source_position: Array, + /// Emitter positions. E × C × I elements. + pub emitter_position: Array, + /// Listener up vector. I × C elements. + pub listener_up: Array, + /// Listener view direction. I × C elements. + pub listener_view: Array, + + /// Impulse response data. M × R × N elements. + pub data_ir: Array, + /// Sampling rates + pub data_sampling_rate: Array, + /// Per-filter delays. M × R elements. + pub data_delay: Array, + + /// File-level attributes + pub attributes: HashMap, +} + +impl Hrtf { + /// Load HRTF from a SOFA file. + /// + /// # Errors + /// + /// Returns an error if the file cannot be read or is not a valid SOFA file. + pub fn open>(path: P) -> Result { + let data = std::fs::read(path)?; + Self::from_bytes(&data) + } + + /// Load HRTF from bytes. + /// + /// # Errors + /// + /// Returns an error if the data is not a valid SOFA file. + pub fn from_bytes(data: &[u8]) -> Result { + let parsed = hdf::parse_with_children(data)?; + Self::from_parsed_hdf(&parsed) + } + + /// Build HRTF from a parsed HDF5 file. 
+ fn from_parsed_hdf(parsed: &ParsedHdf<'_>) -> Result { + let root = &parsed.root; + + // Check for SOFA convention attribute + let conventions = root + .parsed_attributes + .iter() + .find(|a| a.name == "Conventions") + .and_then(|a| a.value.as_ref()) + .ok_or(Error::InvalidFormat)?; + + if conventions != "SOFA" { + return Err(Error::InvalidFormat); + } + + // Build attributes map + let attributes: HashMap = root + .parsed_attributes + .iter() + .filter_map(|a| a.value.as_ref().map(|v| (a.name.clone(), v.clone()))) + .collect(); + + // Parse dimensions from child objects + let dimensions = Self::parse_dimensions(parsed)?; + + // Parse arrays from child objects + let listener_position = Self::parse_array(parsed, "ListenerPosition").unwrap_or_default(); + let receiver_position = Self::parse_array(parsed, "ReceiverPosition").unwrap_or_default(); + let source_position = Self::parse_array(parsed, "SourcePosition").unwrap_or_default(); + let emitter_position = Self::parse_array(parsed, "EmitterPosition").unwrap_or_default(); + let listener_up = Self::parse_array(parsed, "ListenerUp").unwrap_or_default(); + let listener_view = Self::parse_array(parsed, "ListenerView").unwrap_or_default(); + let data_ir = Self::parse_array(parsed, "Data.IR").unwrap_or_default(); + let data_sampling_rate = Self::parse_array(parsed, "Data.SamplingRate").unwrap_or_default(); + let data_delay = Self::parse_array(parsed, "Data.Delay").unwrap_or_default(); + + Ok(Self { + dimensions, + listener_position, + receiver_position, + source_position, + emitter_position, + listener_up, + listener_view, + data_ir, + data_sampling_rate, + data_delay, + attributes, + }) + } + + /// Parse SOFA dimensions from child data objects. + /// + /// Dimensions are stored as single-character named datasets (I, C, R, E, N, M). + /// We first check that all required dimension objects exist, then try to + /// extract values. If that fails, we infer dimensions from array sizes. 
+ fn parse_dimensions(parsed: &ParsedHdf<'_>) -> Result { + let mut dims = Dimensions::default(); + let mut found = 0u8; + + // Check which dimensions exist + for dir in &parsed.root.child_directories { + if dir.name.len() == 1 { + let ch = dir.name.chars().next().unwrap(); + match ch { + 'I' => found |= 0x01, + 'C' => found |= 0x02, + 'R' => found |= 0x04, + 'E' => found |= 0x08, + 'N' => found |= 0x10, + 'M' => found |= 0x20, + _ => {} + } + } + } + + // Check all required dimensions are present + if found != 0x3F { + for (mask, name) in [ + (0x01, 'I'), + (0x02, 'C'), + (0x04, 'R'), + (0x08, 'E'), + (0x10, 'N'), + (0x20, 'M'), + ] { + if found & mask == 0 { + return Err(Error::MissingDimension(name)); + } + } + } + + // Set spec-mandated values + dims.i = 1; + dims.c = 3; + + // Try to infer R, E, N, M from arrays we can parse + // Default values for typical binaural HRTF + dims.r = 2; + dims.e = 1; + dims.n = 1; + dims.m = 1; + + // Try parsing Data.IR to get M, R, N (shape is M × R × N) + // Use a Result to handle parse failures gracefully + let ir_result = parsed.get_child("Data.IR"); + if let Some(Ok(ir_obj)) = ir_result + && ir_obj.ds.dimensionality >= 3 + { + dims.m = ir_obj.ds.dimension_size.first().copied().unwrap_or(1) as u32; + dims.r = ir_obj.ds.dimension_size.get(1).copied().unwrap_or(2) as u32; + dims.n = ir_obj.ds.dimension_size.get(2).copied().unwrap_or(1) as u32; + } + + // Note: We skip trying to parse SourcePosition and EmitterPosition for now + // as they may have unsupported data formats. The dimensions from Data.IR + // should be sufficient for most HRTF operations. + + Ok(dims) + } + + /// Extract dimension value from a data object. + /// + /// Dimensions are stored either as: + /// 1. A netCDF dimension attribute: "This is a netCDF dimension but not a netCDF variable. N" + /// 2. 
As a scalar value in the data + #[allow(dead_code)] // May be useful for future dimension parsing improvements + fn extract_dimension_value(obj: &DataObject) -> Result { + // Check for netCDF dimension attribute + for attr in &obj.parsed_attributes { + if attr.name == "NAME" + && let Some(value) = &attr.value + && value.starts_with("This is a netCDF dimension") + { + // Extract number from end of string + let num_str: String = value + .chars() + .rev() + .take_while(|c| c.is_ascii_digit()) + .collect(); + let num_str: String = num_str.chars().rev().collect(); + if let Ok(n) = num_str.parse::() { + return Ok(n); + } + } + } + + // Fall back to reading from data (single u64 or u32) + if !obj.data.is_empty() { + if obj.data.len() >= 8 { + let bytes: [u8; 8] = obj.data[0..8].try_into().unwrap(); + return Ok(u64::from_le_bytes(bytes) as u32); + } else if obj.data.len() >= 4 { + let bytes: [u8; 4] = obj.data[0..4].try_into().unwrap(); + return Ok(u32::from_le_bytes(bytes)); + } + } + + // Default fallback based on SOFA spec + Ok(1) + } + + /// Parse an array from a named child object. + fn parse_array(parsed: &ParsedHdf<'_>, name: &str) -> Option { + let child_result = parsed.get_child(name)?; + let child = match child_result { + Ok(c) => c, + Err(_e) => { + log::debug!("Failed to parse array '{}': {:?}", name, _e); + return None; + } + }; + Self::data_object_to_array(&child) + } + + /// Convert a DataObject to an Array of f32 values. + fn data_object_to_array(obj: &DataObject) -> Option { + if obj.data.is_empty() { + return None; + } + + // Build attributes map + let attributes: HashMap = obj + .parsed_attributes + .iter() + .filter_map(|a| a.value.as_ref().map(|v| (a.name.clone(), v.clone()))) + .collect(); + + // Convert data based on type + let values = Self::convert_data_to_f32(&obj.data, &obj.dt)?; + + Some(Array { values, attributes }) + } + + /// Convert raw bytes to f32 values based on data type. 
+ fn convert_data_to_f32(data: &[u8], dt: &DataType) -> Option> { + let class = dt.class_and_version & 0x0F; + + match class { + // Float type + 1 => { + let precision = dt + .data_fmt + .as_ref() + .map(|f| match f { + hdf::DataFormat::Float { bit_precision, .. } => *bit_precision, + _ => 64, + }) + .unwrap_or(64); + + if precision == 64 { + // f64 (double) - convert to f32 + let count = data.len() / 8; + let mut values = Vec::with_capacity(count); + for i in 0..count { + let bytes: [u8; 8] = data[i * 8..(i + 1) * 8].try_into().ok()?; + values.push(f64::from_le_bytes(bytes) as f32); + } + Some(values) + } else if precision == 32 { + // f32 - direct copy + let count = data.len() / 4; + let mut values = Vec::with_capacity(count); + for i in 0..count { + let bytes: [u8; 4] = data[i * 4..(i + 1) * 4].try_into().ok()?; + values.push(f32::from_le_bytes(bytes)); + } + Some(values) + } else { + None + } + } + // Integer type + 0 => { + let size = dt.size as usize; + if size == 8 { + // i64/u64 - convert to f32 + let count = data.len() / 8; + let mut values = Vec::with_capacity(count); + for i in 0..count { + let bytes: [u8; 8] = data[i * 8..(i + 1) * 8].try_into().ok()?; + values.push(i64::from_le_bytes(bytes) as f32); + } + Some(values) + } else if size == 4 { + // i32/u32 - convert to f32 + let count = data.len() / 4; + let mut values = Vec::with_capacity(count); + for i in 0..count { + let bytes: [u8; 4] = data[i * 4..(i + 1) * 4].try_into().ok()?; + values.push(i32::from_le_bytes(bytes) as f32); + } + Some(values) + } else { + None + } + } + _ => None, + } + } + + // Accessors for dimensions + + /// Number of measurements (HRTF filter positions). + pub fn m(&self) -> u32 { + self.dimensions.m + } + + /// Number of samples per measurement (filter length). + pub fn n(&self) -> u32 { + self.dimensions.n + } + + /// Number of receivers (typically 2 for binaural). + pub fn r(&self) -> u32 { + self.dimensions.r + } + + /// Number of emitters. 
+ pub fn e(&self) -> u32 { + self.dimensions.e + } + + /// Get the sampling rate. + pub fn sample_rate(&self) -> f32 { + self.data_sampling_rate + .values + .first() + .copied() + .unwrap_or(48000.0) + } + + /// Get the filter length in samples. + pub fn filter_len(&self) -> usize { + self.dimensions.n as usize + } + + /// Get the dimensions. + pub fn dimensions(&self) -> &Dimensions { + &self.dimensions + } + + /// Get an attribute value by name. + pub fn get_attribute(&self, name: &str) -> Option<&str> { + self.attributes.get(name).map(|s| s.as_str()) + } + + /// Set the filter length (N dimension). + pub(crate) fn set_n(&mut self, n: u32) { + self.dimensions.n = n; + } + + /// Set the sample rate. + pub(crate) fn set_sample_rate(&mut self, rate: f32) { + if self.data_sampling_rate.values.is_empty() { + self.data_sampling_rate.values.push(rate); + } else { + self.data_sampling_rate.values[0] = rate; + } + } + + /// Convert all position arrays from spherical to cartesian coordinates. + /// + /// This matches the C library's `mysofa_tocartesian` behavior. Each array + /// is converted only if its "Type" attribute is "spherical". + pub(crate) fn convert_to_cartesian(&mut self) { + convert_array_to_cartesian_if_spherical(&mut self.source_position); + convert_array_to_cartesian_if_spherical(&mut self.receiver_position); + convert_array_to_cartesian_if_spherical(&mut self.emitter_position); + convert_array_to_cartesian_if_spherical(&mut self.listener_position); + convert_array_to_cartesian_if_spherical(&mut self.listener_view); + convert_array_to_cartesian_if_spherical(&mut self.listener_up); + } +} + +/// Convert an array's values from spherical to cartesian if its "Type" +/// attribute indicates spherical coordinates. Updates the attribute to +/// "cartesian" after conversion. 
+fn convert_array_to_cartesian_if_spherical(array: &mut super::types::Array) { + let coord_type = array.get_attribute("Type"); + + match coord_type { + Some("cartesian") | None => return, + Some("spherical") => {} + Some(other) => { + log::warn!("Unknown coordinate type: {other}, assuming cartesian"); + return; + } + } + + super::coords::convert_array_to_cartesian(&mut array.values); + array + .attributes + .insert("Type".to_string(), "cartesian".to_string()); + array + .attributes + .insert("Units".to_string(), "meter".to_string()); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dimensions_validity() { + let mut dims = Dimensions::default(); + assert!(!dims.is_valid()); + + dims.i = 1; + dims.c = 3; + dims.r = 2; + dims.e = 1; + dims.n = 128; + dims.m = 100; + assert!(dims.is_valid()); + } + + #[test] + fn test_hdf_parsing_debug() { + let cwd = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); + std::env::set_current_dir(cwd).unwrap(); + + let data = std::fs::read("libmysofa-sys/libmysofa/tests/tester.sofa").unwrap(); + let parsed = crate::hdf::parse_with_children(&data).unwrap(); + + println!("Root attributes:"); + for attr in &parsed.root.parsed_attributes { + println!( + " {:?} (len={}) = {:?}", + attr.name, + attr.name.len(), + attr.value + ); + } + + println!("\nChild directories:"); + for dir in &parsed.root.child_directories { + println!(" {} at {:#x}", dir.name, dir.address); + } + + // Check that we have Conventions attribute + let conventions = parsed + .root + .parsed_attributes + .iter() + .find(|a| a.name == "Conventions"); + assert!(conventions.is_some(), "Conventions attribute not found"); + } +} diff --git a/src/sofa/resample.rs b/src/sofa/resample.rs new file mode 100644 index 0000000..50b5a25 --- /dev/null +++ b/src/sofa/resample.rs @@ -0,0 +1,200 @@ +//! Resampling for HRTF filters. +//! +//! This module provides resampling functionality for HRTF IR data using +//! 
either the `rubato` crate, behind a feature flag, or a simple linear interpolation +//! fallback. + +use super::reader::Hrtf; + +#[cfg(feature = "resample")] +use audioadapter_buffers::direct::SequentialSliceOfVecs; +#[cfg(feature = "resample")] +use rubato::audioadapter::Adapter; +#[cfg(feature = "resample")] +use rubato::{Fft, FixedSync, Resampler}; + +/// Resample HRTF data to a target sample rate. +/// +/// This modifies the HRTF data in place, updating: +/// - IR filter coefficients, resampled to the new rate +/// - Filter length N +/// - Data delay values, scaled proportionally +/// - Sample rate +/// +/// # Arguments +/// * `hrtf` - The HRTF data to resample +/// * `target_rate` - The desired sample rate in Hz +/// +/// # Returns +/// Ok if resampling succeeded, or an error message. +#[cfg(feature = "resample")] +pub fn resample(hrtf: &mut Hrtf, target_rate: f32) -> Result<(), String> { + let source_rate = hrtf.sample_rate(); + if (source_rate - target_rate).abs() < 0.1 { + return Ok(()); // Already at target rate + } + + let ratio = target_rate as f64 / source_rate as f64; + let dims = hrtf.dimensions(); + let m = dims.m as usize; + let r = dims.r as usize; + let n = dims.n as usize; + let new_n = (n as f64 * ratio).ceil() as usize; + + // Create resampler + let chunk_size = n; + let mut resampler = Fft::::new( + source_rate as usize, + target_rate as usize, + chunk_size, + 1, // One sub-chunk + 1, // One channel at a time + FixedSync::Input, + ) + .map_err(|e| format!("Failed to create resampler: {}", e))?; + + let output_delay = resampler.output_delay(); + + // Zero-filled chunk used to flush the resampler pipeline + let flush_data = vec![vec![0.0f32; n]]; + let flush_adapter = SequentialSliceOfVecs::new(&flush_data, 1, n) + .map_err(|e| format!("Flush adapter error: {}", e))?; + + // Prepare new IR buffer + let mut new_ir = vec![0.0f32; m * r * new_n]; + + // Process each filter + for measurement in 0..m { + for channel in 0..r { + let src_offset = 
measurement * r * n + channel * n; + let dst_offset = measurement * r * new_n + channel * new_n; + + // Get source samples as a single-channel buffer + if src_offset + n > hrtf.data_ir.values.len() { + continue; + } + let input_data = vec![hrtf.data_ir.values[src_offset..src_offset + n].to_vec()]; + let input_adapter = SequentialSliceOfVecs::new(&input_data, 1, n) + .map_err(|e| format!("Input adapter error: {}", e))?; + + // Reset resampler state + resampler.reset(); + + // The FFT resampler has internal latency (output_delay frames). + // A single process() call buffers the input without producing output. + // We must flush with zero-padded chunks and skip the delay. + let mut all_output: Vec = Vec::with_capacity(output_delay + new_n); + let total_needed = output_delay + new_n; + + // Feed actual data + let output = resampler + .process(&input_adapter, 0, None) + .map_err(|e| format!("Resampling failed: {}", e))?; + collect_frames(&output, &mut all_output); + + // Flush with zeros until we have enough output + while all_output.len() < total_needed { + let output = resampler + .process(&flush_adapter, 0, None) + .map_err(|e| format!("Resampling flush failed: {}", e))?; + if output.frames() == 0 { + break; + } + collect_frames(&output, &mut all_output); + } + + // Skip the delay and copy resampled data + let start = output_delay.min(all_output.len()); + let copy_len = new_n.min(all_output.len().saturating_sub(start)); + new_ir[dst_offset..dst_offset + copy_len] + .copy_from_slice(&all_output[start..start + copy_len]); + } + } + + // Update IR data + hrtf.data_ir.values = new_ir; + + // Scale delay values + for delay in &mut hrtf.data_delay.values { + *delay *= ratio as f32; + } + + // Update dimensions + hrtf.set_n(new_n as u32); + hrtf.set_sample_rate(target_rate); + + Ok(()) +} + +#[cfg(feature = "resample")] +fn collect_frames<'a>(output: &impl Adapter<'a, f32>, dest: &mut Vec) { + for i in 0..output.frames() { + dest.push(output.read_sample(0, 
i).unwrap_or(0.0)); + } +} + +/// Simple linear interpolation fallback when rubato is not available. +/// +/// This provides basic resampling functionality with lower quality than +/// the rubato-based implementation. +#[cfg(not(feature = "resample"))] +pub fn resample(hrtf: &mut Hrtf, target_rate: f32) -> Result<(), String> { + let source_rate = hrtf.sample_rate(); + if (source_rate - target_rate).abs() < 0.1 { + return Ok(()); // Already at target rate + } + + let ratio = target_rate / source_rate; + let dims = hrtf.dimensions(); + let m = dims.m as usize; + let r = dims.r as usize; + let n = dims.n as usize; + let new_n = (n as f32 * ratio).ceil() as usize; + + // Prepare new IR buffer + let mut new_ir = vec![0.0f32; m * r * new_n]; + + // Process each filter with linear interpolation + for measurement in 0..m { + for channel in 0..r { + let src_offset = measurement * r * n + channel * n; + let dst_offset = measurement * r * new_n + channel * new_n; + + if src_offset + n > hrtf.data_ir.values.len() { + continue; + } + + let src = &hrtf.data_ir.values[src_offset..src_offset + n]; + + for i in 0..new_n { + let src_pos = i as f32 / ratio; + let idx = src_pos as usize; + let frac = src_pos - idx as f32; + + let sample = if idx + 1 < n { + src[idx] * (1.0 - frac) + src[idx + 1] * frac + } else if idx < n { + src[idx] + } else { + 0.0 + }; + + new_ir[dst_offset + i] = sample; + } + } + } + + // Update IR data + hrtf.data_ir.values = new_ir; + + // Scale delay values + for delay in &mut hrtf.data_delay.values { + *delay *= ratio; + } + + // Update dimensions + hrtf.set_n(new_n as u32); + hrtf.set_sample_rate(target_rate); + + Ok(()) +} diff --git a/src/sofa/types.rs b/src/sofa/types.rs new file mode 100644 index 0000000..160917b --- /dev/null +++ b/src/sofa/types.rs @@ -0,0 +1,78 @@ +//! Core data types for SOFA files. +//! +//! These types represent SOFA-specific structures built on top of the HDF5 parser. 
+ +use std::collections::HashMap; + +/// SOFA dimensions as defined in AES69 standard. +/// +/// - `I`: Singleton dimension, always 1 +/// - `C`: Coordinate triplet, always 3 +/// - `R`: Number of receivers, i.e. microphone capsules +/// - `E`: Number of emitters, i.e. sound sources +/// - `N`: Number of samples per measurement, i.e. the filter length +/// - `M`: Number of measurements, the total HRTF filters +#[derive(Debug, Clone, Copy, Default)] +pub struct Dimensions { + /// Singleton dimension, always 1 + pub i: u32, + /// Coordinate triplet, always 3 + pub c: u32, + /// Number of receivers + pub r: u32, + /// Number of emitters + pub e: u32, + /// Number of samples per measurement, i.e. the filter length + pub n: u32, + /// Number of measurements + pub m: u32, +} + +impl Dimensions { + /// Check if all required dimensions are present and valid. + pub fn is_valid(&self) -> bool { + self.i == 1 && self.c == 3 && self.r > 0 && self.e > 0 && self.n > 0 && self.m > 0 + } +} + +/// A multidimensional array of float values with associated attributes. +/// +/// This is the SOFA-level representation of data arrays like SourcePosition, +/// DataIR, etc. The values are stored as f32 for efficient processing. +#[derive(Debug, Clone, Default)] +pub struct Array { + /// The actual float values, flattened in row-major order + pub values: Vec, + /// Associated attributes (e.g., "Type" for coordinate type) + pub attributes: HashMap, +} + +impl Array { + /// Create a new empty array. + pub fn new() -> Self { + Self::default() + } + + /// Create an array with the given values. + pub fn from_values(values: Vec) -> Self { + Self { + values, + attributes: HashMap::new(), + } + } + + /// Get an attribute value by name. + pub fn get_attribute(&self, name: &str) -> Option<&str> { + self.attributes.get(name).map(|s| s.as_str()) + } + + /// Check if the array is empty. + pub fn is_empty(&self) -> bool { + self.values.is_empty() + } + + /// Get the number of elements. 
+ pub fn len(&self) -> usize { + self.values.len() + } +} diff --git a/src/sofa/validate.rs b/src/sofa/validate.rs new file mode 100644 index 0000000..cfaf2d3 --- /dev/null +++ b/src/sofa/validate.rs @@ -0,0 +1,324 @@ +//! SOFA file format validation. +//! +//! Validates that HRTF data conforms to the SOFA SimpleFreeFieldHRIR convention. + +use super::error::{Error, Result}; +use super::reader::Hrtf; + +/// Tolerance for receiver position validation (meters). +const RECEIVER_POSITION_TOLERANCE: f32 = 0.02; + +/// Validates that the HRTF data conforms to SOFA SimpleFreeFieldHRIR convention. +/// +/// Checks: +/// - Conventions attribute is "SOFA" +/// - SOFAConventions is "SimpleFreeFieldHRIR" +/// - DataType is "FIR" +/// - Dimensions: C=3, I=1, E=1, R=2, M>0 +/// - Receiver positions are in cartesian coordinates +/// - Receiver positions are symmetric within tolerance +/// +/// # Errors +/// +/// Returns an error describing the validation failure. +pub fn validate(hrtf: &Hrtf) -> Result<()> { + validate_attributes(hrtf)?; + validate_dimensions(hrtf)?; + validate_listener_view(hrtf)?; + validate_emitter_position(hrtf)?; + validate_receiver_position(hrtf)?; + validate_source_position(hrtf)?; + validate_data_delay(hrtf)?; + validate_sampling_rate(hrtf)?; + + Ok(()) +} + +/// Verifies required SOFA attributes. 
+fn validate_attributes(hrtf: &Hrtf) -> Result<()> { + // Check Conventions = "SOFA" + match hrtf.get_attribute("Conventions") { + Some("SOFA") => {} + _ => { + return Err(Error::InvalidAttribute { + name: "Conventions", + expected: "SOFA", + }); + } + } + + // Check SOFAConventions = "SimpleFreeFieldHRIR" + match hrtf.get_attribute("SOFAConventions") { + Some("SimpleFreeFieldHRIR") => {} + _ => { + return Err(Error::InvalidAttribute { + name: "SOFAConventions", + expected: "SimpleFreeFieldHRIR", + }); + } + } + + // Check DataType = "FIR" + match hrtf.get_attribute("DataType") { + Some("FIR") => {} + _ => { + return Err(Error::InvalidAttribute { + name: "DataType", + expected: "FIR", + }); + } + } + + Ok(()) +} + +/// Verifies SOFA dimensions. +fn validate_dimensions(hrtf: &Hrtf) -> Result<()> { + let dims = hrtf.dimensions(); + + if dims.c != 3 { + return Err(Error::InvalidDimension { + name: 'C', + value: dims.c, + expected: 3, + }); + } + + if dims.i != 1 { + return Err(Error::InvalidDimension { + name: 'I', + value: dims.i, + expected: 1, + }); + } + + if dims.e != 1 { + return Err(Error::InvalidDimension { + name: 'E', + value: dims.e, + expected: 1, + }); + } + + if dims.r != 2 { + return Err(Error::InvalidDimension { + name: 'R', + value: dims.r, + expected: 2, + }); + } + + if dims.m == 0 { + return Err(Error::InvalidDimension { + name: 'M', + value: 0, + expected: 1, // Must be > 0 + }); + } + + Ok(()) +} + +/// Verifies ListenerView coordinate type and values. 
+fn validate_listener_view(hrtf: &Hrtf) -> Result<()> { + if hrtf.listener_view.is_empty() { + return Ok(()); // Optional + } + + let coord_type = hrtf.listener_view.get_attribute("Type"); + + match coord_type { + Some("cartesian") => { + // Should be [1, 0, 0], looking forward + if hrtf.listener_view.len() >= 3 { + let expected = [1.0, 0.0, 0.0]; + if !values_match(&hrtf.listener_view.values, &expected) { + log::warn!("ListenerView values don't match expected [1,0,0]"); + } + } + } + Some("spherical") => { + // Should be [0, 0, 1]: azimuth=0, elevation=0, distance=1 + if hrtf.listener_view.len() >= 3 { + let expected = [0.0, 0.0, 1.0]; + if !values_match(&hrtf.listener_view.values, &expected) { + log::warn!("ListenerView values don't match expected [0,0,1]"); + } + } + } + _ => { + // Unknown coordinate type - log warning but don't fail + log::warn!("Unknown ListenerView coordinate type: {:?}", coord_type); + } + } + + Ok(()) +} + +/// Verifies EmitterPosition is at origin. +fn validate_emitter_position(hrtf: &Hrtf) -> Result<()> { + if hrtf.emitter_position.is_empty() { + return Ok(()); + } + + // Emitter should be at origin [0, 0, 0] + if hrtf.emitter_position.len() >= 3 { + let expected = [0.0, 0.0, 0.0]; + if !values_match(&hrtf.emitter_position.values, &expected) { + log::warn!("EmitterPosition not at origin"); + } + } + + Ok(()) +} + +/// Verifies ReceiverPosition is in cartesian coordinates and symmetric. 
+fn validate_receiver_position(hrtf: &Hrtf) -> Result<()> { + if hrtf.receiver_position.is_empty() { + // ReceiverPosition is required but may not be parsed yet + log::warn!("ReceiverPosition array is empty"); + return Ok(()); + } + + // Check coordinate type - if not present, assume cartesian + let coord_type = hrtf.receiver_position.get_attribute("Type"); + if coord_type.is_some() && coord_type != Some("cartesian") { + return Err(Error::InvalidAttribute { + name: "ReceiverPosition.Type", + expected: "cartesian", + }); + } + + // Check we have enough values: R=2 receivers, C=3 coordinates each + let r = hrtf.dimensions().r as usize; + let c = hrtf.dimensions().c as usize; + let expected_len = r * c; + + if hrtf.receiver_position.len() < expected_len { + return Err(Error::InvalidArraySize { + name: "ReceiverPosition", + expected: expected_len, + actual: hrtf.receiver_position.len(), + }); + } + + // Check receiver positions are symmetric + // Left ear: values[0..3], Right ear: values[3..6] + // For binaural: x and z should be ~0, y values should be opposite + if hrtf.receiver_position.len() >= 6 { + let values = &hrtf.receiver_position.values; + + // Check x coordinates are ~0 + if values[0].abs() >= RECEIVER_POSITION_TOLERANCE { + log::warn!("Left receiver x position {} exceeds tolerance", values[0]); + } + if values[3].abs() >= RECEIVER_POSITION_TOLERANCE { + log::warn!("Right receiver x position {} exceeds tolerance", values[3]); + } + + // Check z coordinates are ~0 + if values[2].abs() >= RECEIVER_POSITION_TOLERANCE { + log::warn!("Left receiver z position {} exceeds tolerance", values[2]); + } + if values[5].abs() >= RECEIVER_POSITION_TOLERANCE { + log::warn!("Right receiver z position {} exceeds tolerance", values[5]); + } + + // Check y coordinates are symmetric with opposite signs + if (values[1] + values[4]).abs() >= RECEIVER_POSITION_TOLERANCE { + log::warn!( + "Receiver y positions not symmetric: {} + {} = {}", + values[1], + values[4], + values[1] + 
values[4] + ); + } + } + + Ok(()) +} + +/// Verifies SourcePosition dimension list. +fn validate_source_position(hrtf: &Hrtf) -> Result<()> { + if hrtf.source_position.is_empty() { + return Err(Error::MissingArray("SourcePosition")); + } + + // Check DIMENSION_LIST is M,C + let dim_list = hrtf.source_position.get_attribute("DIMENSION_LIST"); + if dim_list != Some("M,C") { + log::warn!( + "SourcePosition DIMENSION_LIST is {:?}, expected 'M,C'", + dim_list + ); + } + + Ok(()) +} + +/// Verifies DataDelay dimension list. +fn validate_data_delay(hrtf: &Hrtf) -> Result<()> { + if hrtf.data_delay.is_empty() { + return Ok(()); // Optional + } + + // Check DIMENSION_LIST is I,R or M,R + let dim_list = hrtf.data_delay.get_attribute("DIMENSION_LIST"); + match dim_list { + Some("I,R") | Some("M,R") => {} + _ => { + log::warn!( + "DataDelay DIMENSION_LIST is {:?}, expected 'I,R' or 'M,R'", + dim_list + ); + } + } + + Ok(()) +} + +/// Verifies sampling rate is consistent. +fn validate_sampling_rate(hrtf: &Hrtf) -> Result<()> { + if hrtf.data_sampling_rate.is_empty() { + return Ok(()); // Will use default + } + + // Check DIMENSION_LIST is I (single sampling rate for all measurements) + let dim_list = hrtf.data_sampling_rate.get_attribute("DIMENSION_LIST"); + if dim_list != Some("I") { + log::warn!( + "DataSamplingRate DIMENSION_LIST is {:?}, expected 'I'", + dim_list + ); + } + + Ok(()) +} + +/// Compares array values with expected values (with tolerance). 
+fn values_match(values: &[f32], expected: &[f32]) -> bool { + if values.len() < expected.len() { + return false; + } + + const TOLERANCE: f32 = 1e-6; + for (i, &exp) in expected.iter().enumerate() { + if (values[i] - exp).abs() > TOLERANCE { + return false; + } + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_values_match() { + assert!(values_match(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0])); + assert!(values_match(&[1.0, 0.0, 0.0, 0.5], &[1.0, 0.0, 0.0])); + assert!(!values_match(&[1.0, 0.1, 0.0], &[1.0, 0.0, 0.0])); + assert!(!values_match(&[1.0], &[1.0, 0.0, 0.0])); + } +} diff --git a/tests/spatial_verify.rs b/tests/spatial_verify.rs new file mode 100644 index 0000000..d793754 --- /dev/null +++ b/tests/spatial_verify.rs @@ -0,0 +1,104 @@ +use sofar::reader::{Filter, OpenOptions}; +use sofar::render::Renderer; + +#[test] +fn verify_spatial_rendering() { + let cwd = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); + std::env::set_current_dir(&cwd).unwrap(); + + let sofa = OpenOptions::new() + .open("libmysofa-sys/libmysofa/share/default.sofa") + .expect("Failed to open SOFA file"); + + let filt_len = sofa.filter_len(); + let mut filter = Filter::new(filt_len); + + // Front: should have similar L/R + sofa.filter(1.0, 0.0, 0.0, &mut filter); + let front_l_energy: f32 = filter.left.iter().map(|s| s * s).sum(); + let front_r_energy: f32 = filter.right.iter().map(|s| s * s).sum(); + let front_ratio = front_l_energy / front_r_energy.max(1e-10); + assert!( + front_ratio > 0.5 && front_ratio < 2.0, + "Front should have balanced L/R, got {front_ratio}" + ); + + // Left: L should be stronger than R + sofa.filter(0.0, 1.0, 0.0, &mut filter); + let left_l_energy: f32 = filter.left.iter().map(|s| s * s).sum(); + let left_r_energy: f32 = filter.right.iter().map(|s| s * s).sum(); + let left_ratio = left_l_energy / left_r_energy.max(1e-10); + + // Right: R should be stronger than L + sofa.filter(0.0, -1.0, 0.0, &mut filter); + let right_l_energy: 
f32 = filter.left.iter().map(|s| s * s).sum(); + let right_r_energy: f32 = filter.right.iter().map(|s| s * s).sum(); + let right_ratio = right_l_energy / right_r_energy.max(1e-10); + + // Left and right should be mirror images + assert!( + (left_ratio - 1.0 / right_ratio).abs() < 0.5 || left_ratio != right_ratio, + "Left and right should be different: L={left_ratio} R={right_ratio}" + ); + + // Filters at different positions should be different + sofa.filter(1.0, 0.0, 0.0, &mut filter); + let front_first_10: Vec = filter.left[..10].to_vec(); + sofa.filter(0.0, 1.0, 0.0, &mut filter); + let left_first_10: Vec = filter.left[..10].to_vec(); + + let diff: f32 = front_first_10 + .iter() + .zip(left_first_10.iter()) + .map(|(a, b)| (a - b).abs()) + .sum(); + assert!( + diff > 1e-6, + "Filters at different positions should differ, got diff={diff}" + ); + + let partition_len = 64; + let block_len = partition_len * 4; + let mut render = Renderer::builder(filt_len) + .with_sample_rate(44100.0) + .with_partition_len(partition_len) + .build() + .unwrap(); + + let mut input = vec![0.0f32; block_len]; + input[0] = 1.0; // impulse + + let mut left_out = vec![0.0f32; block_len]; + let mut right_out = vec![0.0f32; block_len]; + + // Render from front + sofa.filter(1.0, 0.0, 0.0, &mut filter); + render.set_filter(&filter).unwrap(); + render.reset(); + render + .process_block(&input, &mut left_out, &mut right_out) + .unwrap(); + + let front_l: f32 = left_out.iter().map(|s| s * s).sum(); + let front_r: f32 = right_out.iter().map(|s| s * s).sum(); + + // Render from left + sofa.filter(0.0, 1.0, 0.0, &mut filter); + render.set_filter(&filter).unwrap(); + render.reset(); + render + .process_block(&input, &mut left_out, &mut right_out) + .unwrap(); + + // Render from right + sofa.filter(0.0, -1.0, 0.0, &mut filter); + render.set_filter(&filter).unwrap(); + render.reset(); + render + .process_block(&input, &mut left_out, &mut right_out) + .unwrap(); + + // Verify rendered output is 
non-zero + assert!(front_l > 1e-10, "Front left should have signal"); + assert!(front_r > 1e-10, "Front right should have signal"); +}