[kernel][mobile] add mobile-optimized CPU backend #1609
Signed-off-by: AlpinDale <alpindale@gmail.com>
Code Review
This pull request introduces a mobile-optimized CPU backend targeting Android, leveraging ARM NEON for performance. The changes are extensive, adding numerous C++ files with optimized kernels for various operations like attention, matrix multiplication, and activations, along with the necessary CMake build system modifications. While the effort to optimize for mobile is commendable, the review has identified several critical issues. These include correctness bugs such as potential race conditions due to thread-unsafe static variables, incorrect arithmetic leading to wrong results, and linker errors from function name mismatches. There are also significant performance issues stemming from inefficient memory access patterns and repeated quantization/dequantization operations within tight loops. Addressing these critical and high-severity issues is essential for the stability and performance of the new mobile backend.
int8_t* o_vec =
    O_base + q_pos * o_seq_stride + q_head_idx * head_dim;
std::fill(o_vec, o_vec + head_dim, 0);

for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) {
  const float attn_weight = scores_row[kv_pos];
  if (attn_weight == 0.0f) continue;

  const int8_t* v_vec =
      V_base + kv_pos * kv_seq_stride + kv_head_idx * head_dim;

  for (size_t dim = 0; dim < head_dim; ++dim) {
    float weighted_val_fp32 =
        attn_weight * static_cast<float>(v_vec[dim]) * v_scale;
    float current_fp32 =
        static_cast<float>(o_vec[dim]) * output_scale;
    float result_fp32 = current_fp32 + weighted_val_fp32;

    int32_t quantized_result =
        static_cast<int32_t>(result_fp32 / output_scale +
                             (result_fp32 >= 0 ? 0.5f : -0.5f));
    quantized_result =
        std::max(-128, std::min(127, quantized_result));
    o_vec[dim] = static_cast<int8_t>(quantized_result);
  }
}
The accumulation of attention-weighted values into the output vector o_vec is inefficient and can lead to precision loss. Inside the kv_pos loop, the o_vec (which is int8_t) is repeatedly dequantized to float, updated, and then requantized. This process is computationally expensive and introduces quantization errors at each step.
A more efficient and accurate approach is to use a temporary floating-point accumulator for the output vector. This accumulator can sum up all the weighted values in full precision. The final quantization to int8_t should only happen once, after the loop over kv_pos is complete.
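To make the precision loss concrete, here is a minimal standalone sketch that compares the two accumulation patterns on a single output element. The helper names (quantize_i8, accumulate_requantizing, accumulate_fp32) are hypothetical and are not part of the PR; the rounding is round-half-away-from-zero, which matches the +0.5f / -0.5f truncation used in the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: round to nearest and clamp to the int8 range.
inline int8_t quantize_i8(float value, float scale) {
  assert(scale > 0.0f);
  const long rounded = std::lround(value / scale);
  return static_cast<int8_t>(std::max(-128L, std::min(127L, rounded)));
}

// Pattern from the diff: dequantize, add one term, requantize on every step.
inline int8_t accumulate_requantizing(const std::vector<float>& weights,
                                      const std::vector<int8_t>& values,
                                      float v_scale, float output_scale) {
  assert(weights.size() == values.size());
  int8_t acc = 0;
  for (size_t i = 0; i < weights.size(); ++i) {
    const float current = static_cast<float>(acc) * output_scale;
    const float term = weights[i] * static_cast<float>(values[i]) * v_scale;
    acc = quantize_i8(current + term, output_scale);
  }
  return acc;
}

// Suggested pattern: accumulate in fp32 and quantize exactly once.
inline int8_t accumulate_fp32(const std::vector<float>& weights,
                              const std::vector<int8_t>& values,
                              float v_scale, float output_scale) {
  assert(weights.size() == values.size());
  float acc = 0.0f;
  for (size_t i = 0; i < weights.size(); ++i) {
    acc += weights[i] * static_cast<float>(values[i]) * v_scale;
  }
  return quantize_i8(acc, output_scale);
}
```

With ten terms of 0.4 each and both scales set to 1.0, the requantizing version rounds every partial sum back to 0 and never leaves it, while the fp32 accumulator reaches 4.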
std::vector<float> o_vec_fp32(head_dim, 0.0f);
for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) {
  const float attn_weight = scores_row[kv_pos];
  if (attn_weight == 0.0f) continue;
  const int8_t* v_vec =
      V_base + kv_pos * kv_seq_stride + kv_head_idx * head_dim;
  for (size_t dim = 0; dim < head_dim; ++dim) {
    o_vec_fp32[dim] +=
        attn_weight * static_cast<float>(v_vec[dim]) * v_scale;
  }
}
int8_t* o_vec =
    O_base + q_pos * o_seq_stride + q_head_idx * head_dim;
for (size_t dim = 0; dim < head_dim; ++dim) {
  float val = o_vec_fp32[dim];
  int32_t quantized_result =
      static_cast<int32_t>(val / output_scale +
                           (val >= 0 ? 0.5f : -0.5f));
  quantized_result =
      std::max(-128, std::min(127, quantized_result));
  o_vec[dim] = static_cast<int8_t>(quantized_result);
}

for (; i < tile_end; i += TILE_SIZE) {
  for (size_t u = 0; u < UNROLL_FACTOR; u++) {
    int8x8_t input_i8 = vld1_s8(&input_row[i + u * SIMD_WIDTH]);
    int16x4_t input_i16 = vget_low_s16(vmovl_s8(input_i8));
    int32x4_t input_i32 = vmovl_s16(input_i16);
    float32x4_t input_f32 =
        vmulq_f32(vcvtq_f32_s32(input_i32), input_scale_vec);
    sum_squares_vec[u] =
        vfmaq_f32(sum_squares_vec[u], input_f32, input_f32);
  }
}

const size_t simd_end = (dims >= SIMD_WIDTH) ? dims - SIMD_WIDTH + 1 : 0;
for (; i < simd_end; i += SIMD_WIDTH) {
  int8x8_t input_i8 = vld1_s8(&input_row[i]);
  int16x4_t input_i16 = vget_low_s16(vmovl_s8(input_i8));
  int32x4_t input_i32 = vmovl_s16(input_i16);
  float32x4_t input_f32 =
      vmulq_f32(vcvtq_f32_s32(input_i32), input_scale_vec);
  sum_squares_vec[0] = vfmaq_f32(sum_squares_vec[0], input_f32, input_f32);
}
The vectorization logic in rms_norm_i8_f32 is inefficient and potentially unsafe due to overlapping memory loads. The loop iterates with a step of TILE_SIZE (16), and the inner loop with UNROLL_FACTOR (4) and SIMD_WIDTH (4) results in addresses i, i+4, i+8, i+12. At each of these addresses, vld1_s8 loads 8 bytes. This means, for example, that the second load at i+4 re-loads 4 bytes that were already loaded at i. Furthermore, from each 8-byte load, only the first 4 bytes are processed (vget_low_s16(vmovl_s8(input_i8))), discarding half of the loaded data. Because every load reads 4 bytes beyond the elements it uses, the last load of a row can also read up to 4 bytes past the end of the input.
This should be corrected to avoid overlapping loads and to process all loaded data, which will significantly improve performance and ensure correctness. A similar issue exists in the second part of the function where the normalization is applied.
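One way to picture the non-overlapping access pattern is the portable scalar model below. It is only a sketch under stated assumptions: sum_squares_i8 is a hypothetical helper, and each 8-element block stands in for one vld1_s8 whose low and high halves (vget_low_s16 and vget_high_s16 of the widened vector) would both be consumed in a NEON version. The loop advances by the full block width and leaves the remainder to a scalar tail, so no byte is read twice and nothing is read past the end of the row.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: sum of squares of a dequantized int8 row.
inline float sum_squares_i8(const int8_t* input_row, size_t dims,
                            float input_scale) {
  assert(input_row != nullptr || dims == 0);
  constexpr size_t kBlock = 8;  // elements covered by one 8-byte load
  float lanes[kBlock] = {};     // one partial sum per lane
  size_t i = 0;

  // Main loop: step by the full block width and use all 8 loaded elements.
  for (; i + kBlock <= dims; i += kBlock) {
    for (size_t lane = 0; lane < kBlock; ++lane) {
      const float x = static_cast<float>(input_row[i + lane]) * input_scale;
      lanes[lane] += x * x;
    }
  }

  // Horizontal reduction of the per-lane partial sums.
  float sum = 0.0f;
  for (float partial : lanes) sum += partial;

  // Scalar tail for the last dims % 8 elements; never reads out of bounds.
  for (; i < dims; ++i) {
    const float x = static_cast<float>(input_row[i]) * input_scale;
    sum += x * x;
  }
  return sum;
}
```

The same structure applies to the normalization pass: the block width of the loop should equal the number of elements actually consumed per load.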
for (size_t pos = 0; pos < seq_len; ++pos) {
  const float pos_float = static_cast<float>(pos);
  for (size_t i = 0; i < half_dim; ++i) {
    const float freq = 1.0f / powf(theta, (2.0f * i) / head_dim);
The calculation of freq mixes float and size_t operands. Note that (2.0f * i) / head_dim does not actually perform integer division: 2.0f * i is evaluated first and converts i to float, and dividing that float by head_dim converts head_dim to float as well. The implicit size_t-to-float conversions are still easy to misread and can trigger conversion warnings, so casting both variables to float explicitly makes the floating-point division unambiguous.
- const float freq = 1.0f / powf(theta, (2.0f * i) / head_dim);
+ const float freq = 1.0f / powf(theta, (2.0f * static_cast<float>(i)) / static_cast<float>(head_dim));
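As a minimal sketch of the frequency computation with explicit casts, the block below also includes a deliberately all-integer variant to show what truncation in the exponent would look like. Both function names are hypothetical and not taken from the PR.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Inverse frequency for rotary pair i: theta^(-2i / head_dim),
// with every integer operand converted to float explicitly.
inline float rope_inv_freq(float theta, size_t i, size_t head_dim) {
  assert(head_dim > 0);
  const float exponent =
      (2.0f * static_cast<float>(i)) / static_cast<float>(head_dim);
  return 1.0f / std::pow(theta, exponent);
}

// For contrast only: an exponent computed entirely in integers.
// (2 * i) / head_dim truncates to 0 for every i < head_dim / 2,
// so the result is always 1.0f in that range.
inline float rope_inv_freq_truncated(float theta, size_t i, size_t head_dim) {
  assert(head_dim > 0);
  const size_t exponent = (2 * i) / head_dim;  // integer division
  return 1.0f / std::pow(theta, static_cast<float>(exponent));
}
```

For theta = 10000, i = 16 and head_dim = 64, the explicit-cast version gives an exponent of 0.5 and a frequency of about 0.01, whereas the all-integer variant collapses to 1.0.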
      for (size_t k_block = K_aligned; k_block < K;
           k_block += DOT_GRANULARITY) {
        size_t remaining =
            std::min(static_cast<size_t>(DOT_GRANULARITY), K - k_block);

        for (int m = 0; m < TILE_M; ++m) {
          size_t row = row_block + m;
          if (row >= M) continue;

          for (int n = 0; n < TILE_N; ++n) {
            size_t col = col_block + n;
            if (col >= N) continue;

            int32_t dot_product = 0;
            for (size_t k = 0; k < remaining; ++k) {
              dot_product +=
                  static_cast<int32_t>(a[row * K + k_block + k]) *
                  static_cast<int32_t>(b_transposed[col * K + k_block + k]);
            }

            int32x4_t dot_vec = vdupq_n_s32(dot_product);
            accumulators[m][n] = vaddq_s32(accumulators[m][n], dot_vec);
          }
        }
      }

      for (int m = 0; m < TILE_M; ++m) {
        size_t row = row_block + m;
        if (row >= M) continue;
        for (int n = 0; n < TILE_N; ++n) {
          size_t col = col_block + n;
          if (col >= N) continue;
          int32_t sum = vaddvq_s32(accumulators[m][n]);
          c[row * N + col] = sum;
        }
      }
    }
  }
}
The logic for handling the remainder of the K dimension is incorrect. The scalar dot_product is broadcast to a vector and added to the vector accumulator accumulators[m][n]. When vaddvq_s32 is called later to get the final sum, the dot_product is effectively added four times, leading to incorrect matrix multiplication results.
The correct approach is to perform the horizontal sum of the vector accumulator first to get a scalar sum, and then add the scalar-computed remainder to it.
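The over-counting can be reproduced without NEON by modelling the four-lane accumulator as a plain array. This is an illustrative sketch only: Lanes stands in for int32x4_t, horizontal_sum models vaddvq_s32, and the two reduce functions are hypothetical names for the pattern in the diff and the suggested fix.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <numeric>

using Lanes = std::array<int32_t, 4>;  // stand-in for int32x4_t

// Models vaddvq_s32: adds the four lanes into one scalar.
inline int32_t horizontal_sum(const Lanes& v) {
  return std::accumulate(v.begin(), v.end(), int32_t{0});
}

// Pattern from the diff: the scalar tail is broadcast to every lane
// (vdupq_n_s32 + vaddq_s32) and only then reduced, so it is counted 4 times.
inline int32_t reduce_with_broadcast_tail(Lanes acc, int32_t tail) {
  for (int32_t& lane : acc) lane += tail;
  return horizontal_sum(acc);
}

// Suggested pattern: reduce the vector first, then add the tail once.
inline int32_t reduce_then_add_tail(const Lanes& acc, int32_t tail) {
  return horizontal_sum(acc) + tail;
}
```

With lanes {1, 2, 3, 4} and a tail of 5, the broadcast version returns 30 while the correct result is 15; the difference is exactly three extra copies of the tail.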
for (int m = 0; m < TILE_M; ++m) {
  size_t row = row_block + m;
  if (row >= M) continue;
  for (int n = 0; n < TILE_N; ++n) {
    size_t col = col_block + n;
    if (col >= N) continue;
    int32_t sum = vaddvq_s32(accumulators[m][n]);
    for (size_t k = K_aligned; k < K; ++k) {
      sum += static_cast<int32_t>(a[row * K + k]) *
             static_cast<int32_t>(b_transposed[col * K + k]);
    }
    c[row * N + col] = sum;
  }
}

static std::vector<uint32_t> token_history;
static const size_t MAX_HISTORY = 128;
static const float REPETITION_PENALTY = 1.1f;
The use of a static variable token_history makes the sample_f16 function not thread-safe. If this function is called from multiple threads concurrently (e.g., for different sequences in a batch), they will all share and modify the same token_history vector, leading to race conditions and incorrect behavior.
To fix this, token_history should be managed externally and passed into the function as an argument. The related constants MAX_HISTORY and REPETITION_PENALTY should also be passed as arguments if they need to be configurable per call.
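A rough sketch of what externally managed state could look like follows. SamplerState, record_token and apply_repetition_penalty are hypothetical names, and the penalty rule (divide positive logits, multiply negative ones) is a common convention assumed here rather than something confirmed by the PR. A token that occurs several times in the history is penalized once per occurrence in this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <vector>

// Hypothetical per-sequence sampler state, owned by the caller.
struct SamplerState {
  std::deque<uint32_t> token_history;
  size_t max_history = 128;         // default mirrors MAX_HISTORY in the diff
  float repetition_penalty = 1.1f;  // default mirrors REPETITION_PENALTY
};

// Appends a token and drops the oldest entries beyond max_history.
inline void record_token(SamplerState& state, uint32_t token) {
  state.token_history.push_back(token);
  while (state.token_history.size() > state.max_history) {
    state.token_history.pop_front();
  }
}

// Penalizes logits of tokens in this sequence's history only.
// No shared mutable state is touched, so concurrent calls with
// distinct SamplerState objects do not race with each other.
inline void apply_repetition_penalty(std::vector<float>& logits,
                                     const SamplerState& state) {
  assert(state.repetition_penalty > 0.0f);
  for (uint32_t token : state.token_history) {
    if (token >= logits.size()) continue;  // ignore out-of-vocabulary ids
    float& logit = logits[token];
    logit = (logit > 0.0f) ? logit / state.repetition_penalty
                           : logit * state.repetition_penalty;
  }
}
```

Each sequence in a batch would then carry its own SamplerState, and the sampling function would take it by reference instead of reading a function-local static.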
if (col_block + 0 < num_cols) {
  if (row_block_end - row_block >= 4) {
    vst1q_f32(&destination[(col_block + 0) * num_rows + row_block],
              col0);
  } else {
    float temp[4];
    vst1q_f32(temp, col0);
    for (size_t i = 0; i < row_block_end - row_block; ++i) {
      destination[(col_block + 0) * num_rows + row_block + i] =
          temp[i];
    }
  }
}
if (col_block + 1 < num_cols) {
  if (row_block_end - row_block >= 4) {
    vst1q_f32(&destination[(col_block + 1) * num_rows + row_block],
              col1);
  } else {
    float temp[4];
    vst1q_f32(temp, col1);
    for (size_t i = 0; i < row_block_end - row_block; ++i) {
      destination[(col_block + 1) * num_rows + row_block + i] =
          temp[i];
    }
  }
}
if (col_block + 2 < num_cols) {
  if (row_block_end - row_block >= 4) {
    vst1q_f32(&destination[(col_block + 2) * num_rows + row_block],
              col2);
  } else {
    float temp[4];
    vst1q_f32(temp, col2);
    for (size_t i = 0; i < row_block_end - row_block; ++i) {
      destination[(col_block + 2) * num_rows + row_block + i] =
          temp[i];
    }
  }
}
if (col_block + 3 < num_cols) {
  if (row_block_end - row_block >= 4) {
    vst1q_f32(&destination[(col_block + 3) * num_rows + row_block],
              col3);
  } else {
    float temp[4];
    vst1q_f32(temp, col3);
    for (size_t i = 0; i < row_block_end - row_block; ++i) {
      destination[(col_block + 3) * num_rows + row_block + i] =
          temp[i];
    }
  }
}
The code for storing the transposed columns (col0, col1, col2, col3) is highly repetitive. This duplication makes the code harder to read and maintain. You can refactor this section by using a helper lambda or a small loop to handle the storage of each column, which would significantly reduce code duplication and improve clarity.
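The refactoring could look roughly like the portable sketch below. It assumes the four column vectors have already been written to a 4x4 float array, which stands in for col0 through col3 (float32x4_t in the PR); store_column and store_columns are hypothetical helper names. In a NEON build the full-width case would keep using vst1q_f32, with the copy below serving only the partial-row case.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Stores up to four rows of one transposed column into column-major
// `destination`. `column` stands in for a float32x4_t spilled to memory.
inline void store_column(float* destination, size_t num_rows, size_t col,
                         size_t row_block, size_t row_block_end,
                         const float (&column)[4]) {
  assert(row_block <= row_block_end);
  const size_t count = std::min<size_t>(4, row_block_end - row_block);
  std::copy(column, column + count, destination + col * num_rows + row_block);
}

// Replaces the four copy-pasted branches with one loop over the columns.
inline void store_columns(float* destination, size_t num_rows, size_t num_cols,
                          size_t col_block, size_t row_block,
                          size_t row_block_end, const float (&columns)[4][4]) {
  for (size_t c = 0; c < 4; ++c) {
    if (col_block + c >= num_cols) break;  // past the last valid column
    store_column(destination, num_rows, col_block + c, row_block,
                 row_block_end, columns[c]);
  }
}
```

Bounds handling now lives in one place, so a change to the partial-row logic no longer has to be repeated four times.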
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: AlpinDale <alpindale@gmail.com>
WIP, only targeting Android for now.