Skip to content

Commit 70acfca

Browse files
author
hczphn
committed
fix chunk_size when solve witness
1 parent 168dcff commit 70acfca

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

expander_compiler/src/zkcuda/context.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -750,9 +750,10 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
750750
let mut output_chunk_sizes: Vec<Option<usize>> =
751751
vec![None; kernel_primitive.io_specs().len()];
752752
let mut any_shape = None;
753-
for ((input, ir_inputs), chunk_size) in kernel_call
753+
for (((input, &ib), ir_inputs), chunk_size) in kernel_call
754754
.input_handles
755755
.iter()
756+
.zip(kernel_call.is_broadcast.iter())
756757
.zip(ir_inputs_all.iter_mut())
757758
.zip(input_chunk_sizes.iter_mut())
758759
{
@@ -766,8 +767,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
766767
let values = handle
767768
.shape_history
768769
.permute_vec(&self.device_memories[handle.id].values);
769-
let kernel_shape = handle.shape_history.shape();
770-
*chunk_size = Some(kernel_shape.iter().product());
770+
*chunk_size = Some(values.len() * ib / kernel_call.num_parallel);
771771
*ir_inputs = values;
772772
}
773773
for (((output, &ib), ir_inputs), chunk_size) in kernel_call

0 commit comments

Comments
 (0)