Skip to content

Commit a2ebe14

Browse files
hczphnclaude
andauthored
Reduce peak memory during prove by releasing witness shared memory early (#204)
* Add unchecked device memory export, LogUp query count API, lightweight prove, and configurable CPU count - Add `export_device_memories_unchecked()` for exporting device memories without state assertion, enabling memory optimization workflows where context is dropped before proving - Add `prove_lightweight()` to ExpanderNoOverSubscribe, allowing prove without holding computation_graph or prover_setup references - Add `final_check_with_query_count()` to LogUpSingleKeyTable and LogUpRangeProofTable for hint-free logup verification with externally provided query counts - Support `ZKML_NUM_CPUS` env var to override physical CPU detection for MPI process count Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Reduce peak memory during prove by releasing witness shared memory early Add witness_ack shared memory signaling between client and server: - Client resets a 1-byte ack signal before writing witness - Server signals ack after reading witness into MPI shared memory - Client polls for ack, then immediately releases witness shared memory and calls malloc_trim to return memory to OS - Prove request runs concurrently via tokio async, so witness memory is freed while proving is in progress - Skip reading PCS setup from shared memory (return default) since the client does not need it after setup Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Address review feedback: platform-guard malloc_trim, reduce polling interval - Wrap malloc_trim calls with #[cfg(all(target_os = "linux", target_env = "gnu"))] to avoid linker errors on non-glibc platforms - Reduce witness_ack polling interval from 500ms to 10ms for faster response Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Address review: add polling timeout, revert debug label, remove redundant malloc_trim - Add 5-minute timeout to wait_for_witness_read_complete to prevent indefinite hang if the server crashes - Revert timer label from "new setup" back to "setup" - Remove duplicate malloc_trim inside spawn_blocking (shared memory is mmap-managed, not glibc heap) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: restore verifier setup from shared memory to fix verify panic The previous optimization skipped reading PCS setup from shared memory and returned empty defaults, which caused verify to panic on v_keys lookup (unwrap on None). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8f953a2 commit a2ebe14

3 files changed

Lines changed: 93 additions & 2 deletions

File tree

expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,11 @@ where
140140

141141
setup_timer.stop();
142142

143-
SharedMemoryEngine::read_pcs_setup_from_shared_memory()
143+
// Prover setup not needed on client side (server does the proving).
144+
// Verifier setup is required for verification, so read it from shared memory.
145+
let (_prover_setup, verifier_setup) =
146+
SharedMemoryEngine::read_pcs_setup_from_shared_memory::<C::FieldConfig, C::PCSConfig>();
147+
(ExpanderProverSetup::default(), verifier_setup)
144148
}
145149

146150
pub fn client_send_witness_and_prove<C, ECCConfig>(
@@ -152,8 +156,39 @@ where
152156
{
153157
let timer = Timer::new("prove", true);
154158

159+
// Reset ack signal, then write witness
160+
SharedMemoryEngine::reset_witness_ack();
155161
SharedMemoryEngine::write_witness_to_shared_memory::<C::FieldConfig>(device_memories);
156-
wait_async(ClientHttpHelper::request_prove());
162+
163+
#[cfg(all(target_os = "linux", target_env = "gnu"))]
164+
{
165+
extern "C" {
166+
fn malloc_trim(pad: usize) -> i32;
167+
}
168+
unsafe {
169+
malloc_trim(0);
170+
}
171+
}
172+
173+
// Async: send prove request + poll for witness ack to release shared memory early
174+
let rt = tokio::runtime::Runtime::new().unwrap();
175+
rt.block_on(async {
176+
let prove_handle = tokio::spawn(async {
177+
ClientHttpHelper::request_prove().await;
178+
});
179+
180+
// Poll witness_ack; once server confirms read, release witness shared memory
181+
tokio::task::spawn_blocking(|| {
182+
SharedMemoryEngine::wait_for_witness_read_complete();
183+
unsafe {
184+
super::shared_memory_utils::SHARED_MEMORY.witness = None;
185+
}
186+
})
187+
.await
188+
.expect("Witness cleanup task failed");
189+
190+
prove_handle.await.expect("Prove task failed");
191+
});
157192

158193
let proof = SharedMemoryEngine::read_proof_from_shared_memory();
159194

expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ where
149149
let mut witness_win = state.wt_shared_memory_win.lock().await;
150150
S::setup_shared_witness(&state.global_mpi_config, &mut witness, &mut witness_win);
151151

152+
// Signal client: witness has been read, shared memory can be released
153+
SharedMemoryEngine::signal_witness_read_complete();
154+
152155
let prover_setup_guard = state.prover_setup.lock().await;
153156
let computation_graph = state.computation_graph.lock().await;
154157

expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@ pub struct SharedMemory {
1818
pub pcs_setup: Option<Shmem>,
1919
pub witness: Option<Shmem>,
2020
pub proof: Option<Shmem>,
21+
/// 1-byte signal: 0 = witness not read, 1 = server finished reading witness
22+
pub witness_ack: Option<Shmem>,
2123
}
2224

2325
pub static mut SHARED_MEMORY: SharedMemory = SharedMemory {
2426
pcs_setup: None,
2527
witness: None,
2628
proof: None,
29+
witness_ack: None,
2730
};
2831

2932
pub struct SharedMemoryEngine {}
@@ -106,6 +109,56 @@ impl SharedMemoryEngine {
106109
Self::read_object_from_shared_memory("pcs_setup", 0)
107110
}
108111

112+
/// Client: reset witness_ack to 0 (call before writing witness)
113+
pub fn reset_witness_ack() {
114+
unsafe {
115+
Self::allocate_shared_memory_if_necessary(
116+
&mut SHARED_MEMORY.witness_ack,
117+
"witness_ack",
118+
1,
119+
);
120+
let ptr = SHARED_MEMORY.witness_ack.as_mut().unwrap().as_ptr();
121+
std::ptr::write_volatile(ptr, 0u8);
122+
}
123+
}
124+
125+
/// Server: set witness_ack to 1 (call after reading witness)
126+
pub fn signal_witness_read_complete() {
127+
let shmem = ShmemConf::new()
128+
.flink("witness_ack")
129+
.open()
130+
.expect("Failed to open witness_ack shared memory");
131+
unsafe {
132+
std::ptr::write_volatile(shmem.as_ptr(), 1u8);
133+
}
134+
}
135+
136+
/// Client: poll until witness_ack becomes 1, with a timeout to avoid hanging
137+
/// if the server crashes.
138+
pub fn wait_for_witness_read_complete() {
139+
const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
140+
let start = std::time::Instant::now();
141+
unsafe {
142+
let ptr = SHARED_MEMORY
143+
.witness_ack
144+
.as_ref()
145+
.expect("witness_ack not initialized, call reset_witness_ack first")
146+
.as_ptr() as *const u8;
147+
loop {
148+
if std::ptr::read_volatile(ptr) != 0 {
149+
break;
150+
}
151+
if start.elapsed() > TIMEOUT {
152+
panic!(
153+
"Timed out waiting for server to read witness ({}s)",
154+
TIMEOUT.as_secs()
155+
);
156+
}
157+
std::thread::sleep(std::time::Duration::from_millis(10));
158+
}
159+
}
160+
}
161+
109162
pub fn write_witness_to_shared_memory<F: FieldEngine>(values: Vec<Vec<F::SimdCircuitField>>) {
110163
let total_size = std::mem::size_of::<usize>()
111164
+ values

0 commit comments

Comments
 (0)