
Feat: implement qwen3 probing service and tests#626

Open
yuerqiqi wants to merge 2 commits into UbiquitousLearning:main from yuerqiqi:my-feature-dev

Conversation

@yuerqiqi
Contributor

@yuerqiqi yuerqiqi commented Feb 2, 2026

This PR implements the probing service for the Qwen3 model to support model internal state analysis.

Key Changes:

New Features: Added examples/qwen3_service/main_probing.cpp as the main entry point for the probing service.

Model Definition: Implemented modeling_qwen3_probing_service.hpp to support the probing architecture.

Tests: Added test_accuracy.cpp and test_trivia_probing.cpp for validating model performance and probing logic.

Build System: Updated CMakeLists.txt to include the new probing targets.

Integration: Updated mllm-cli/cmd/mllm-server/main.go to register the new service.

Summary by CodeRabbit

  • New Features

    • Interactive probing session for Qwen3 models with streaming chat interface and thinking visualization
    • New accuracy test tool for QA sample evaluation and simple metrics (precision/recall/F1)
    • Trivia probing evaluation tool with per-layer metrics, timing, and JSON result export
    • Layer-wise probing framework with configurable thresholds and probe loading
  • Chores

    • Adjusted service startup sequencing for session lifecycle management

@coderabbitai
Contributor

coderabbitai bot commented Feb 2, 2026

📝 Walkthrough


Adds a Qwen3 probing framework (model instrumentation, probe loading, per-layer probe classifiers, streaming generation with prefill/decode checks and early-exit), three example/test executables (probing interactive, accuracy, trivia probing), and moves service Start/Stop to after session creation.

Changes

Cohort / File(s) Summary
Build Configuration
examples/qwen3_service/CMakeLists.txt
Added three new executables: mllm-qwen3-accuracy, mllm-qwen3-probing, mllm-qwen3-trivia-probing, linked to MllmRT and MllmCPUBackend.
Example Applications
examples/qwen3_service/main_probing.cpp, examples/qwen3_service/test_accuracy.cpp, examples/qwen3_service/test_trivia_probing.cpp
New interactive probing client (main_probing.cpp), accuracy evaluation tool (test_accuracy.cpp), and trivia probing evaluator with per-layer metrics and replay support (test_trivia_probing.cpp).
Probing Framework
mllm/models/qwen3/modeling_qwen3_probing_service.hpp
Large addition: ProbingArgs/ProbingContext, RoPE helpers, ProbeClassifier (scaler/PCA support), probing-enabled MLP/Attention/Decoder/Text modules, Qwen3ProbingForCausalLM (probe loading, forward with probe instrumentation), Qwen3ProbingSession and public probe result APIs.
Service Lifecycle
mllm-cli/cmd/mllm-server/main.go
Deferred StartService call moved to after all model sessions are created; StopService deferred accordingly. Minor reorder of lifecycle initialization.

Sequence Diagram(s)

sequenceDiagram
    participant App as Client App
    participant Session as Qwen3ProbingSession
    participant Model as Qwen3ProbingForCausalLM
    participant Probes as ProbeClassifier[]
    participant Service as Generation Service

    App->>Session: initialize(probing args)
    Session->>Model: create/load model
    App->>Session: loadProbes(dir)
    Session->>Model: loadProbesFromDirectory()
    Model->>Probes: instantiate/load probes
    Probes-->>Model: probes ready

    App->>Service: streamGeneration(request)
    Service->>Model: forward(prefill)
    Model->>Probes: collect prefill activations
    Probes->>Model: predict(prefill)
    alt prefill triggers stop
        Model-->>Service: early_exit
        Service-->>App: stop stream (hallucination)
    else continue
        Service->>Model: forward(decode)
        Model->>Probes: collect decode activations
        Probes->>Model: predict(decode)
        Service-->>App: stream tokens
    end
    Service-->>App: final response + probe results
sequenceDiagram
    participant Tester as Test App
    participant CSV as Data Source
    participant Session as Qwen3ProbingSession
    participant Model as Generation Engine
    participant Probes as ProbeClassifier[]

    Tester->>CSV: load samples
    CSV-->>Tester: samples
    Tester->>Session: init & loadProbes()
    Session->>Model: load probes
    loop per sample
        Tester->>Model: generate(question)
        Model-->>Tester: tokens + thinking markers
        Tester->>Tester: normalize & check answer
        Tester->>Probes: read probe activations
        Probes-->>Tester: per-layer scores
        Tester->>Tester: update stats (TP/FP/AUC)
    end
    Tester-->>Tester: compute and print metrics

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • chenghuaWang
  • yirongjie
  • liang1232018

Poem

🐰 I hopped through layers, sniffed each token's trail,

I nudged the probes where signals pale,
When thinking blooms I softly peep,
Catching ghosts before they leap,
Three new testers dance — the rabbit's tale. 🥕

🚥 Pre-merge checks | ✅ 1 passed | ❌ 2 failed

❌ Failed checks (2 warnings)

  • Description check ⚠️ Warning: The PR description is largely incomplete relative to the template requirements; it lacks structured sections such as a problem statement, solution overview, testing approach, and breaking changes. Resolution: provide a more complete description following the repository template, with sections "What problem does this solve?", "How does it solve it?", "Testing", and any breaking changes.
  • Docstring coverage ⚠️ Warning: Docstring coverage is 8.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

  • Title check ✅ Passed: The title "Feat: implement qwen3 probing service and tests" clearly and concisely summarizes the main change: implementing a probing service for Qwen3 with supporting tests and infrastructure.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
  • 🧪 Generate unit tests (beta)
    • Create PR with unit tests
    • Post copyable unit tests in a comment



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 12

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
mllm-cli/cmd/mllm-server/main.go (1)

1-17: ⚠️ Potential issue | 🔴 Critical

File extension .go is not acceptable for this repository.

The file mllm-cli/cmd/mllm-server/main.go uses the .go extension, which is not in the list of acceptable file extensions per coding guidelines (.c, .cc, .cpp, .cxx, .h, .hh, .hpp, .py, .pyi, .sh, .txt, .md, .yml, .yaml, .json, .toml). The encoding, line endings, control codes, and trailing whitespace checks all pass, but the file cannot be accepted with a .go extension.

🤖 Fix all issues with AI agents
In `@examples/qwen3_service/main_probing.cpp`:
- Around line 100-142: Wrap the JSON parsing and downstream accesses in robust
error handling: surround nlohmann::json::parse(resp) with try-catch to handle
malformed JSON from mllm::service::getResponse("chat-probing"), and after
parsing validate that j.contains("choices") && j["choices"].is_array() &&
!j["choices"].empty() before indexing; likewise check that
choice.contains("delta") && choice["delta"].contains("content") and that
content.is_string() before calling content.get<std::string>(); on any
parse/access error log or print a concise warning and continue the loop (or
break if appropriate) so history, assistant_content, thinking and the rest of
the loop won’t access invalid data.

In `@examples/qwen3_service/test_accuracy.cpp`:
- Around line 84-90: The normalize(std::string s) implementation uses plain char
with std::ispunct and std::tolower which is undefined for negative (non-ASCII)
values and duplicates logic from test_trivia_probing.cpp; fix by casting
characters to unsigned char before calling std::ispunct/std::tolower (e.g.,
std::ispunct(static_cast<unsigned char>(c))) and remove duplication by
extracting the normalized routine into a shared test utility function
(rename/move the implementation used in test_trivia_probing.cpp into a common
helper and update normalize() to call that helper or replace the current body
with the robust implementation).
- Line 100: The line that parses argv[4] into int using std::stoi when setting
limit is unsafe because std::stoi can throw; wrap the conversion in a safe parse
with error handling (e.g., try/catch for std::invalid_argument and
std::out_of_range around std::stoi, or replace with std::from_chars) and on
failure set limit to the default -1 and/or log an error; update the code that
defines limit so it uses argc/argv checks and the safe parse (referencing limit,
argc, argv, and std::stoi) to avoid unhandled exceptions.

In `@examples/qwen3_service/test_trivia_probing.cpp`:
- Around line 92-96: The code redundantly strips surrounding quotes from
variables q and a_str twice; remove the duplicate operations (the repeated
checks using q.size(), q.front(), q.back(), q.substr(...) and the same for
a_str) so that quote-trimming is performed only once for each variable (look for
the two identical blocks operating on q and a_str in test_trivia_probing.cpp and
delete the second block).
- Line 409: The print statement computes accuracy using
(float)model_correct_total / processed_count which can divide by zero if
processed_count is 0; update the logic around the variables processed_count and
model_correct_total so you only perform the division when processed_count > 0
(e.g., compute accuracy = processed_count > 0 ? (float)model_correct_total /
processed_count * 100.0 : 0.0 or print "N/A" when processed_count == 0) and then
use that safe accuracy value in the std::cout line that currently references
processed_count and model_correct_total.
- Around line 346-353: The current catch (...) in the session->streamGenerate
block swallows all errors silently; replace it with structured exception
handling that at minimum catches std::exception and logs the sample index and
e.what() (and optionally have a fallback catch(...) that logs an unknown/non-std
exception), then continue; ensure the log includes identifiers like
generated_text and tok_cnt context so you can correlate which sample failed
while preserving the existing continue behavior.

In `@mllm-cli/cmd/mllm-server/main.go`:
- Around line 68-72: The code currently calls log.Fatal when StartService fails
which bypasses deferred cleanup; instead, change the failure branch in main so
that if mllm.StartService(1) returns false you log the error (e.g., log.Printf
or log.Error), explicitly call the cleanup routines (call mllm.StopService() and
any session/context teardown you need) and then exit with a non‑zero status
(os.Exit(1)); reference StartService and StopService in the change so you
explicitly perform cleanup before exiting rather than using log.Fatal.

In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp`:
- Around line 751-754: The code is using throw
std::runtime_error("PROBING_INTERRUPT") in the probing path (after calling
callback(...) and setting stop_generating) which uses exceptions for normal
control flow; change this to a non-exception early-exit by returning a distinct
status (e.g., bool or enum) or propagating the stop_generating flag: remove the
throw, ensure the function containing the throw (and any callers) returns a
status indicating "interrupted" (or checks stop_generating) and short-circuits
further work, and update callers of that function to handle the new return
value/status instead of relying on catching PROBING_INTERRUPT; reference the
symbols callback, stop_generating and the "PROBING_INTERRUPT" concept when
making these changes so all control paths are updated consistently.
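The status-based early exit the item above asks for can be sketched as follows. This is a minimal illustration, not the repo's actual API: ProbeStatus, ProbeStep, run_probe_step, and generate_tokens are all hypothetical names standing in for the real callback/stop_generating machinery in modeling_qwen3_probing_service.hpp.

```cpp
#include <functional>

// Hypothetical status enum replacing throw std::runtime_error("PROBING_INTERRUPT").
enum class ProbeStatus { kContinue, kInterrupted };

struct ProbeStep {
  bool stop_generating = false;  // mirrors the stop_generating flag in the review
};

// One probing step: notify the listener and return a status instead of throwing.
ProbeStatus run_probe_step(ProbeStep& step, const std::function<void()>& callback,
                           bool hallucination_detected) {
  if (hallucination_detected) {
    callback();                        // still fire the callback as before
    step.stop_generating = true;
    return ProbeStatus::kInterrupted;  // early exit via return value, no exception
  }
  return ProbeStatus::kContinue;
}

// Callers branch on the status rather than catching PROBING_INTERRUPT.
int generate_tokens(ProbeStep& step, int max_tokens, int trigger_at_token) {
  int produced = 0;
  for (int t = 0; t < max_tokens; ++t) {
    bool detected = (t == trigger_at_token);
    if (run_probe_step(step, [] {}, detected) == ProbeStatus::kInterrupted) break;
    ++produced;
  }
  return produced;
}
```

The design point is that interruption is an expected outcome of probing, so it belongs in the return type rather than the exception path.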
- Around line 824-827: The loop in probing logic currently hardcodes layer 22
when inspecting candidate_key->activations; make the target layer configurable
by adding a field (e.g., hallucination_layer or hallucination_layers) to
ProbingArgs, defaulting to 22, use that field in the check inside the loop
(replace the literal 22 with ProbingArgs::hallucination_layer or the instance
field from where ProbingArgs is passed into the probing service), and validate
the value(s) when ProbingArgs is constructed or parsed so the
ProbingService/modeling_qwen3_probing_service.hpp code (the loop over
candidate_key->activations) supports different models/configs. Ensure call sites
that construct ProbingArgs are updated to supply or inherit the default.
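A minimal sketch of the configurable-layer change suggested above. Field names, the num_layers bound, and the validation policy are assumptions; the real ProbingArgs struct is in modeling_qwen3_probing_service.hpp and is not reproduced here.

```cpp
#include <stdexcept>

// Hypothetical slice of ProbingArgs with the hardcoded 22 lifted into a field.
struct ProbingArgsSketch {
  int num_layers = 28;           // assumed model depth; supplied per config
  int hallucination_layer = 22;  // default preserves the current behavior

  // Validate on construction/parsing so bad configs fail fast.
  void validate() const {
    if (hallucination_layer < 0 || hallucination_layer >= num_layers)
      throw std::invalid_argument("hallucination_layer out of range");
  }
};

// The loop over candidate_key->activations would compare against the field
// instead of the literal 22.
bool is_target_layer(const ProbingArgsSketch& args, int layer_idx) {
  return layer_idx == args.hallucination_layer;
}
```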
- Around line 475-487: The try/catch around parsing parsed_layer currently
swallows exceptions—replace the empty catch with handling that logs the parse
failure (e.g., using existing logging facility or std::cerr) and preserves a
sensible default for parsed_layer; also remove the dead if-block and implement
its intended behavior by setting use_pca = true (and optionally logging that PCA
is forced) when linear_in_dim != cfg.hidden_size so the code path is exercised;
update references to parsed_layer, use_scaler, use_pca, linear_in_dim and
cfg.hidden_size accordingly.
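One way the logged-parse plus forced-PCA behavior could look. This is a standalone sketch: ProbeLoadState and parse_probe_filename are invented names, and the logging target (std::cerr) is an assumption about what the repo's logging facility would be replaced with.

```cpp
#include <iostream>
#include <string>

// Hypothetical holder for the two fields the review discusses.
struct ProbeLoadState {
  int parsed_layer = -1;  // sensible default retained when parsing fails
  bool use_pca = false;
};

void parse_probe_filename(ProbeLoadState& st, const std::string& digits,
                          int linear_in_dim, int hidden_size) {
  try {
    st.parsed_layer = std::stoi(digits);
  } catch (const std::exception& e) {
    // Log instead of swallowing, and keep the default.
    std::cerr << "probe filename parse failed: " << e.what() << "\n";
    st.parsed_layer = -1;
  }
  // The formerly dead if-block, now actually exercised.
  if (linear_in_dim != hidden_size) {
    st.use_pca = true;
    std::cerr << "probe input dim " << linear_in_dim << " != hidden size "
              << hidden_size << ", forcing PCA\n";
  }
}
```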
- Around line 36-38: Remove the three `using namespace` directives from the
header to avoid leaking symbols into all translation units; instead fully
qualify types and symbols referenced in this header (e.g., replace unqualified
uses with mllm::Tensor, mllm::nn::Linear, mllm::models::qwen3::YourClass, etc.).
If local shorthand is needed only inside implementation, add scoped `using`
declarations in the corresponding .cpp or inside function bodies rather than the
header; update all declarations in modeling_qwen3_probing_service.hpp that
currently rely on those namespaces to their explicit qualified names.
- Around line 164-166: The code is directly casting logits.ptr<__fp16>()[0] when
checking logits.dtype() == mllm::kFloat16, which is non-portable; update the
branch that handles mllm::kFloat16 (the logits/type check and assignment to val)
to either convert logits to mllm::kFloat32 and read the first element (e.g., use
logits.to(mllm::kFloat32).ptr<float>()[0]) or wrap the __fp16 access with
appropriate platform guards (e.g., `#ifdef` __ARM_NEON) so the use of __fp16 is
only compiled on supported targets. Ensure you modify the same logits/kFloat16
branch and assign the converted float to val.
🧹 Nitpick comments (7)
mllm/models/qwen3/modeling_qwen3_probing_service.hpp (4)

40-68: Consider adding documentation for public API structs.

ProbingArgs and ProbingContext are public interfaces but lack documentation explaining:

  • Valid ranges for threshold values (e.g., should they be in [0,1]?)
  • Purpose and lifecycle of ProbingContext members
  • When to use reset() vs soft_reset()

327-331: Add braces around single-line conditional statements.

Static analysis flags missing braces which can lead to maintenance bugs when adding statements later.

Proposed fix
       if (probe_ctx->target_layers.empty())
-        layer_needed = true;
-      else if (probe_ctx->target_layers.count(layer_idx_))
-        layer_needed = true;
+      {
+        layer_needed = true;
+      }
+      else if (probe_ctx->target_layers.count(layer_idx_))
+      {
+        layer_needed = true;
+      }

741-742: Replace magic numbers with named constants.

The values 10000 (line 741, 811) and 100000 (line 835) appear to be limits for probe results storage. Extract these as named constants for clarity and maintainability.

Proposed fix
+  static constexpr size_t kMaxProbeResults = 10000;
+  static constexpr size_t kMaxDecodeProbeResults = 100000;
+
   // In the callback:
-  if (model_->last_probe_results_.size() < 10000) {
+  if (model_->last_probe_results_.size() < kMaxProbeResults) {

1051-1054: Consider marking this function [[nodiscard]].

The return value should not be discarded as it provides the probe results. This aligns with static analysis recommendation.

Proposed fix
- std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() const {
+ [[nodiscard]] std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() const {
examples/qwen3_service/main_probing.cpp (1)

19-27: Consider validating parsed layer numbers.

The function accepts any integers without validation. Negative or out-of-range layer indices could cause issues downstream.
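A sketch of what validated layer parsing could look like for a comma-separated layer list. The function name, the std::optional-based error signaling, and the num_layers bound are all assumptions; main_probing.cpp's actual parser is not shown in this review.

```cpp
#include <optional>
#include <set>
#include <sstream>
#include <string>

// Parse "1,5,22" into a set of layer indices; return nullopt on any
// non-numeric, negative, or out-of-range entry instead of accepting it.
std::optional<std::set<int>> parse_layers(const std::string& csv, int num_layers) {
  std::set<int> out;
  std::stringstream ss(csv);
  std::string tok;
  while (std::getline(ss, tok, ',')) {
    try {
      int v = std::stoi(tok);
      if (v < 0 || v >= num_layers) return std::nullopt;  // out of range
      out.insert(v);
    } catch (const std::exception&) {
      return std::nullopt;  // non-numeric token
    }
  }
  return out;
}
```

Rejecting the whole list on the first bad entry keeps the failure visible instead of silently probing the wrong layers.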

examples/qwen3_service/test_accuracy.cpp (1)

138-139: Signed/unsigned comparison in loop condition.

int i compared with samples.size() (returns size_t). Same issue at line 177. Use size_t for loop variables iterating over container sizes.

Proposed fix
-  for (int i = 0; i < samples.size(); ++i) {
-    if (limit > 0 && i >= limit) break;
+  for (size_t i = 0; i < samples.size(); ++i) {
+    if (limit > 0 && static_cast<int>(i) >= limit) break;
examples/qwen3_service/test_trivia_probing.cpp (1)

128-140: This normalize function is better than the one in test_accuracy.cpp.

This implementation correctly uses unsigned char and handles control characters. Consider extracting to a shared utility header to avoid duplication and inconsistency.
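The shared helper the comment suggests extracting could look like the following. The function name and header placement are assumptions; the filtering policy mirrors the stricter test_trivia_probing.cpp version described above (unsigned char, punctuation and control characters stripped, non-ASCII bytes dropped).

```cpp
#include <cctype>
#include <string>

// Hypothetical shared utility: normalize an answer string for comparison.
// Casting to unsigned char avoids UB in std::ispunct/std::tolower on
// negative char values from non-ASCII input.
std::string normalize_answer(const std::string& s) {
  std::string out;
  out.reserve(s.size());
  for (unsigned char c : s) {
    if (c < 128) {
      if (!std::ispunct(c) && !std::iscntrl(c))
        out += static_cast<char>(std::tolower(c));
    }
    // bytes >= 128 are dropped, matching the stricter implementation
  }
  return out;
}
```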

Comment on lines +100 to +142
  while (true) {
    std::string resp = mllm::service::getResponse("chat-probing");
    auto j = nlohmann::json::parse(resp);

    if (j.contains("choices") && j["choices"].size() > 0) {
      auto& choice = j["choices"][0];
      auto content = choice["delta"]["content"];

      if (content.is_string()) {
        std::string s = content.get<std::string>();
        if (s.find("early_exit") != std::string::npos) {
          try {
            auto warn = nlohmann::json::parse(s);
            fmt::print(fmt::fg(fmt::color::red) | fmt::emphasis::bold,
                       "\n[Hallucination] Phase: {} | Layer: {} | Score: {:.4f}\n",
                       warn.value("phase", "unknown"), warn.value("layer", -1), warn.value("score", 0.0f));
          } catch (...) { fmt::print(fmt::fg(fmt::color::red), "\n[Hallucination] Raw: {}\n", s); }

          if (!history.empty() && history.back()["role"] == "user") history.pop_back();
          break;
        }

        if (s == "<think>") {
          thinking = true;
          continue;
        }
        if (s == "</think>") {
          thinking = false;
          continue;
        }

        if (thinking)
          fmt::print(fmt::fg(fmt::color::gray), "{}", s);
        else {
          fmt::print("{}", s);
          assistant_content += s;
        }
        std::fflush(stdout);
      }

      if (choice["finish_reason"] == "stop") break;
    }
  }

⚠️ Potential issue | 🟡 Minor

Add error handling for JSON parsing and access.

Line 102 calls nlohmann::json::parse(resp) without try-catch, which will throw on malformed JSON. Line 106 assumes j["choices"][0]["delta"]["content"] exists without null checks.

Proposed fix
     while (true) {
       std::string resp = mllm::service::getResponse("chat-probing");
-      auto j = nlohmann::json::parse(resp);
+      nlohmann::json j;
+      try {
+        j = nlohmann::json::parse(resp);
+      } catch (const nlohmann::json::parse_error& e) {
+        std::cerr << "JSON parse error: " << e.what() << std::endl;
+        break;
+      }

       if (j.contains("choices") && j["choices"].size() > 0) {
         auto& choice = j["choices"][0];
-        auto content = choice["delta"]["content"];
+        if (!choice.contains("delta") || !choice["delta"].contains("content")) continue;
+        auto content = choice["delta"]["content"];

Comment on lines +84 to +90
std::string normalize(std::string s) {
  std::string out;
  for (char c : s) {
    if (!std::ispunct(c)) out += std::tolower(c);
  }
  return out;
}

⚠️ Potential issue | 🟡 Minor

Potential undefined behavior with signed char and duplicate code.

  1. Using char instead of unsigned char can cause issues with std::ispunct() and std::tolower() for non-ASCII characters (negative values → undefined behavior).

  2. This function is duplicated in test_trivia_probing.cpp (lines 128-140) with a more robust implementation. Consider extracting to a shared utility.

Proposed fix (align with test_trivia_probing.cpp)
 std::string normalize(std::string s) {
+  if (s.empty()) return "";
   std::string out;
-  for (char c : s) {
-    if (!std::ispunct(c)) out += std::tolower(c);
+  out.reserve(s.size());
+  for (size_t i = 0; i < s.size(); ++i) {
+    unsigned char c = s[i];
+    if (c < 128) {
+      if (!std::ispunct(c) && !std::iscntrl(c)) out += std::tolower(c);
+    }
   }
   return out;
 }

std::string model_path = argv[1];
std::string probes_path = argv[2];
std::string csv_path = argv[3];
int limit = (argc > 4) ? std::stoi(argv[4]) : -1;

⚠️ Potential issue | 🟡 Minor

std::stoi can throw on invalid input.

If argv[4] is not a valid integer, std::stoi will throw std::invalid_argument or std::out_of_range, causing an unhandled exception crash.

Proposed fix
-  int limit = (argc > 4) ? std::stoi(argv[4]) : -1;
+  int limit = -1;
+  if (argc > 4) {
+    try {
+      limit = std::stoi(argv[4]);
+    } catch (const std::exception& e) {
+      std::cerr << "Invalid limit argument: " << argv[4] << std::endl;
+      return 1;
+    }
+  }

Comment on lines +92 to +96
if (q.size() >= 2 && q.front() == '"' && q.back() == '"') q = q.substr(1, q.size() - 2);
if (a_str.size() >= 2 && a_str.front() == '"' && a_str.back() == '"') a_str = a_str.substr(1, a_str.size() - 2);
if (q.find("bt_") == 0 || q.find("tc_") == 0 || q.length() < 5) continue;
if (q.size() >= 2 && q.front() == '"' && q.back() == '"') q = q.substr(1, q.size() - 2);
if (a_str.size() >= 2 && a_str.front() == '"' && a_str.back() == '"') a_str = a_str.substr(1, a_str.size() - 2);

⚠️ Potential issue | 🟡 Minor

Redundant quote stripping applied twice.

Lines 92-93 strip quotes, then lines 95-96 do the same operation again. This is either redundant or indicates a logic error.

Proposed fix
     if (q.size() >= 2 && q.front() == '"' && q.back() == '"') q = q.substr(1, q.size() - 2);
     if (a_str.size() >= 2 && a_str.front() == '"' && a_str.back() == '"') a_str = a_str.substr(1, a_str.size() - 2);
     if (q.find("bt_") == 0 || q.find("tc_") == 0 || q.length() < 5) continue;
-    if (q.size() >= 2 && q.front() == '"' && q.back() == '"') q = q.substr(1, q.size() - 2);
-    if (a_str.size() >= 2 && a_str.front() == '"' && a_str.back() == '"') a_str = a_str.substr(1, a_str.size() - 2);

Comment on lines +346 to +353
try {
  session->streamGenerate(req, [&](const nlohmann::json& j, bool finished) {
    if (j.is_string()) {
      generated_text += j.get<std::string>();
      tok_cnt++;
    }
  });
} catch (...) { continue; }

⚠️ Potential issue | 🟡 Minor

Silent exception swallowing hides failures.

The catch (...) { continue; } silently ignores all exceptions during generation, making it difficult to diagnose issues. At minimum, log the sample index and exception.

Proposed fix
     try {
       session->streamGenerate(req, [&](const nlohmann::json& j, bool finished) {
         if (j.is_string()) {
           generated_text += j.get<std::string>();
           tok_cnt++;
         }
       });
-    } catch (...) { continue; }
+    } catch (const std::exception& e) {
+      std::cerr << "Generation failed for sample " << i << ": " << e.what() << std::endl;
+      continue;
+    } catch (...) {
+      std::cerr << "Generation failed for sample " << i << " (unknown error)" << std::endl;
+      continue;
+    }

Comment on lines +36 to +38
using namespace mllm;
using namespace mllm::nn;
using namespace mllm::models::qwen3;

⚠️ Potential issue | 🟠 Major

Avoid using namespace directives in header files.

These directives pollute the namespace for all translation units that include this header, potentially causing name collisions. This is flagged by static analysis (google-build-using-namespace). Consider using explicit namespace qualifications or scoped using declarations within function bodies instead.

Proposed fix
-using namespace mllm;
-using namespace mllm::nn;
-using namespace mllm::models::qwen3;
+// Use explicit qualifications or add these inside function scopes only

Then qualify types explicitly, e.g., mllm::Tensor, mllm::nn::Linear, etc.

🧰 Tools
🪛 Clang (14.0.6)

[error] 36-36: do not use namespace using-directives; use using-declarations instead

(google-build-using-namespace,-warnings-as-errors)


Comment on lines 164 to 166
} else if (logits.dtype() == mllm::kFloat16) {
  val = (float)logits.ptr<__fp16>()[0];
}

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Scripts executed against UbiquitousLearning/mllm:

# Search for __fp16 usage in the codebase
rg -n '__fp16' --type=cpp --type=h -B 2 -A 2

# Look for platform-specific guards or conditionals
rg -n '#ifdef|#if defined|_MSC_VER|__GNUC__|__clang__' --type=cpp --type=h | head -50

# Check if there are any half-float utilities or conversion functions
rg -n 'half|fp16|float16|Float16' --type=cpp --type=h -i | head -40

# Look for CMake or build configuration to understand target platforms
find . -name 'CMakeLists.txt' -o -name '*.cmake' -o -name 'setup.py' | head -20

# Check the specific file context and surrounding code
head -200 mllm/models/qwen3/modeling_qwen3_probing_service.hpp | tail -50

# Check the exact context around the problematic line
sed -n '150,170p' mllm/models/qwen3/modeling_qwen3_probing_service.hpp

# Search for how the codebase handles float16 conversion portably
rg -n 'kFloat16|mllm::kFloat16' --type=cpp -B 1 -A 1 | head -80


Replace direct __fp16 cast with portable conversion or add platform guards.

Line 165 uses __fp16 without platform-specific guards. Other model files in the codebase handle kFloat16 by converting to kFloat32 (e.g., key_states.to(kFloat32)) or, in ARM kernel files, by wrapping __fp16 with #ifdef __ARM_NEON. Use one of these portable patterns instead of directly casting to __fp16:

} else if (logits.dtype() == mllm::kFloat16) {
  val = (float)logits.to(mllm::kFloat32).ptr<float>()[0];
}

Or add proper platform guards if direct __fp16 access is intentional.

🤖 Prompt for AI Agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp` around lines 164 - 166,
The code is directly casting logits.ptr<__fp16>()[0] when checking
logits.dtype() == mllm::kFloat16, which is non-portable; update the branch that
handles mllm::kFloat16 (the logits/type check and assignment to val) to either
convert logits to mllm::kFloat32 and read the first element (e.g., use
logits.to(mllm::kFloat32).ptr<float>()[0]) or wrap the __fp16 access with
appropriate platform guards (e.g., `#ifdef` __ARM_NEON) so the use of __fp16 is
only compiled on supported targets. Ensure you modify the same logits/kFloat16
branch and assign the converted float to val.

Comment on lines +475 to +487
try {
size_t num_start = layer_pos + 6;
size_t num_end = fn.find_first_not_of("0123456789", num_start);
parsed_layer = std::stoi(fn.substr(num_start, num_end - num_start));
} catch (...) {}
}

bool use_scaler = has_scaler;
bool use_pca = has_pca;

if (!use_pca && linear_in_dim != cfg.hidden_size) {
// If linear_in_dim differs from hidden_size, PCA must be used
}
⚠️ Potential issue | 🟡 Minor

Empty catch block and dead code detected.

  1. Line 479: Empty catch block silently swallows parse exceptions. At minimum, log a warning.
  2. Lines 485-487: Empty if-block with only a comment is dead code.
Proposed fix
         try {
           size_t num_start = layer_pos + 6;
           size_t num_end = fn.find_first_not_of("0123456789", num_start);
           parsed_layer = std::stoi(fn.substr(num_start, num_end - num_start));
-        } catch (...) {}
+        } catch (const std::exception& e) {
+          std::cerr << "Warning: Failed to parse layer number from " << fn << ": " << e.what() << std::endl;
+        }
       }
 
       bool use_scaler = has_scaler;
       bool use_pca = has_pca;
 
-      if (!use_pca && linear_in_dim != cfg.hidden_size) {
-        // If linear_in_dim differs from hidden_size, PCA must be used
-      }
+      // Note: If linear_in_dim differs from hidden_size without PCA, the probe may malfunction
🤖 Prompt for AI Agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp` around lines 475 - 487,
The try/catch around parsing parsed_layer currently swallows exceptions—replace
the empty catch with handling that logs the parse failure (e.g., using existing
logging facility or std::cerr) and preserves a sensible default for
parsed_layer; also remove the dead if-block and implement its intended behavior
by setting use_pca = true (and optionally logging that PCA is forced) when
linear_in_dim != cfg.hidden_size so the code path is exercised; update
references to parsed_layer, use_scaler, use_pca, linear_in_dim and
cfg.hidden_size accordingly.

Comment on lines +751 to +754
callback(stop_resp.dump(), true);
stop_generating = true;
throw std::runtime_error("PROBING_INTERRUPT");
}
⚠️ Potential issue | 🟠 Major

Using exceptions for control flow is an anti-pattern.

throw std::runtime_error("PROBING_INTERRUPT") at lines 753 and 852 uses exceptions for normal control flow (early exit). This is expensive (stack unwinding) and makes the code harder to reason about.

Consider using a return value or flag-based approach instead.

🤖 Prompt for AI Agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp` around lines 751 - 754,
The code is using throw std::runtime_error("PROBING_INTERRUPT") in the probing
path (after calling callback(...) and setting stop_generating) which uses
exceptions for normal control flow; change this to a non-exception early-exit by
returning a distinct status (e.g., bool or enum) or propagating the
stop_generating flag: remove the throw, ensure the function containing the throw
(and any callers) returns a status indicating "interrupted" (or checks
stop_generating) and short-circuits further work, and update callers of that
function to handle the new return value/status instead of relying on catching
PROBING_INTERRUPT; reference the symbols callback, stop_generating and the
"PROBING_INTERRUPT" concept when making these changes so all control paths are
updated consistently.

Comment on lines +824 to +827
for (auto& [layer_idx, tensor] : candidate_key->activations) {
// Only check Layer 22 for hallucination as mapped to user request
if (layer_idx != 22) continue;

⚠️ Potential issue | 🟡 Minor

Hardcoded layer index (22) reduces flexibility.

The hallucination check is restricted to layer 22 only. Consider making this configurable via ProbingArgs to allow different models or configurations.

🤖 Prompt for AI Agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp` around lines 824 - 827,
The loop in probing logic currently hardcodes layer 22 when inspecting
candidate_key->activations; make the target layer configurable by adding a field
(e.g., hallucination_layer or hallucination_layers) to ProbingArgs, defaulting
to 22, use that field in the check inside the loop (replace the literal 22 with
ProbingArgs::hallucination_layer or the instance field from where ProbingArgs is
passed into the probing service), and validate the value(s) when ProbingArgs is
constructed or parsed so the ProbingService/modeling_qwen3_probing_service.hpp
code (the loop over candidate_key->activations) supports different
models/configs. Ensure call sites that construct ProbingArgs are updated to
supply or inherit the default.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp`:
- Line 580: Remove the duplicate non-const accessor and keep a single const
accessor Qwen3ProbingForCausalLM::getLastProbeResults() const marked
[[nodiscard]]; inside it, null-check model_ before accessing
model_->last_probe_results_ (return an empty vector if model_ is null) so
callers get a safe, const-qualified result and static analysis is satisfied
instead of having two conflicting overloads.
- Around line 580-581: Both getLastProbeResults() and clearLastProbeResults()
dereference model_ without a null check; add checks to avoid dereferencing
before fromPreTrain(): in getLastProbeResults() return an empty
std::vector<Qwen3ProbingForCausalLM::ProbeResult> if model_ is null, otherwise
return model_->last_probe_results_; in clearLastProbeResults() make it a no-op
when model_ is null, otherwise call model_->clearProbeResults(). Use the
existing symbols getLastProbeResults, clearLastProbeResults, model_,
last_probe_results_, clearProbeResults(), and fromPreTrain() to locate and
implement the checks.
- Around line 161-167: The current code silently returns sigmoid(0) when
logits.dtype() != mllm::kFloat32; update the branch around logits.dtype() to
handle other dtypes explicitly: if logits.dtype() == mllm::kFloat32 keep reading
via logits.ptr<float>()[0], else if logits.dtype() == mllm::kFloat16 (or other
supported numeric types) convert/cast the tensor to float32 (or read via the
appropriate ptr<T>() and cast to float) before assigning to val, and for any
truly unsupported dtype emit a warning or throw an exception; ensure you
reference the existing logits, val, logits.dtype(), and ptr<> usage so the
change replaces the silent fallback with explicit conversion or error handling.
- Around line 355-361: The memcpy code incorrectly assumes only two dtype sizes
and casts mlp_out.ptr<float>() to char*, which breaks when mlp_out.dtype() is
not kFloat32; update the logic to (1) compute dtype_size from mlp_out.dtype()
with a complete switch/if handling all possible mllm dtypes instead of a
ternary, (2) obtain a raw byte pointer from mlp_out (e.g., use a void*/uint8_t*
accessor such as mlp_out.ptr<uint8_t>() or a generic data() method) rather than
casting ptr<float>(), (3) ensure dest_ptr->ptr<T>() is accessed with the correct
template T matching the destination dtype (or copy into dest as bytes), and (4)
validate both source and destination pointers are non-null before calling
std::memcpy with byte count = hidden_dim * dtype_size using token_offset to
compute the byte_offset.
- Line 5: The `#include <span>` directive in modeling_qwen3_probing_service.hpp is
unnecessary and breaks builds on older compilers; remove the unused include (the
'#include <span>' line) or replace it with a conditional fallback that only
includes <span> when __cplusplus >= 202002L and otherwise provides a
compatibility shim or alternative type. Ensure you update
modeling_qwen3_probing_service.hpp so no references remain to std::span (or add
the shim header) and re-run the build to confirm portability.
🧹 Nitpick comments (7)
mllm/models/qwen3/modeling_qwen3_probing_service.hpp (7)

40-68: Add documentation for public structs.

ProbingArgs and ProbingContext are public API types but lack documentation explaining their purpose, field meanings, and usage. As per coding guidelines, public APIs should have clear docstrings or comments.

Suggested documentation
+/**
+ * @brief Configuration arguments for probing behavior during inference.
+ */
 struct ProbingArgs {
-  bool enable_prefill_check = false;
-  float prefill_stop_threshold = 0.7f;
+  bool enable_prefill_check = false;       ///< Enable hallucination check during prefill phase
+  float prefill_stop_threshold = 0.7f;     ///< Threshold for early exit during prefill
   std::vector<int> default_prefill_layers;
 
-  bool enable_decode_check = false;
-  float decode_stop_threshold = 0.8f;
-  float pos_threshold = 0.9f;
+  bool enable_decode_check = false;        ///< Enable hallucination check during decode phase
+  float decode_stop_threshold = 0.8f;      ///< Threshold for early exit during decode
+  float pos_threshold = 0.9f;              ///< Threshold for key token position detection
 };

+/**
+ * @brief Runtime context for collecting MLP activations during probing.
+ */
 struct ProbingContext {

324-329: Add braces around single-line statements.

Static analysis flags missing braces at lines 326 and 328. While functional, this can lead to maintenance issues and is flagged as a style violation.

Proposed fix
       if (probe_ctx && probe_ctx->collecting) {
         bool layer_needed = false;
-        if (probe_ctx->target_layers.empty())
+        if (probe_ctx->target_layers.empty()) {
           layer_needed = true;
-        else if (probe_ctx->target_layers.count(layer_idx_))
+        } else if (probe_ctx->target_layers.count(layer_idx_)) {
           layer_needed = true;
+        }

547-564: Public member variables and naming inconsistency.

Several member variables are public but would benefit from encapsulation:

  • prefill_probes, decode_probes, pos_probe, pos_probe_layer_idx are directly accessible
  • last_probe_results_ uses trailing underscore convention (typically private) but is accessed directly by Qwen3ProbingSession

Consider providing accessor methods for better encapsulation, or removing the trailing underscore if the field is intentionally public.


593-594: Use English for code comments.

Line 593 contains a Chinese comment (// 简短指令). For consistency and accessibility across the team, use English comments.

-    // 简短指令
+    // Concise instruction prompt
     std::string concise_instruction = " Please answer in a single, complete sentence. Keep it concise.";

588-874: streamGenerate function is excessively complex.

This ~300-line function with a ~200-line nested lambda (wrapped_callback) has very high cyclomatic complexity, handling:

  • Message preprocessing and chat template application
  • Prefill probing and early exit
  • Decode probing with debouncing
  • Key token detection and hallucination checks
  • Context management between tokens

Consider decomposing into smaller, focused functions:

  • prepareChatInput() - handle message preprocessing
  • runPrefillCheck() - prefill probing logic
  • runDecodeCheck() - decode probing with debouncing
  • handleProbingResult() - callback and early exit logic

This would significantly improve readability and testability. Based on learnings: "Flag overly complex functions (e.g., high cyclomatic complexity) and suggest breaking them down."


739-741: Magic numbers for result limits should be named constants.

The values 10000 and 100000 appear as limits for last_probe_results_ size checks. These should be named constants for clarity and maintainability.

+  static constexpr size_t kMaxPrefillResults = 10000;
+  static constexpr size_t kMaxDecodeResults = 100000;
   ...
-  if (model_->last_probe_results_.size() < 10000) {
+  if (model_->last_probe_results_.size() < kMaxPrefillResults) {

Also applies to: 809-810, 833-834


789-789: Magic number for debounce window.

The debounce counter is hardcoded to 5. Consider making this configurable via ProbingArgs.

 struct ProbingArgs {
   ...
+  int debounce_window = 5;  ///< Number of tokens to wait before confirming key token
 };
 ...
-            debounce_counter = 5;  // Reset window
+            debounce_counter = probing_args_.debounce_window;  // Reset window

// Licensed under the MIT License.
#pragma once

#include <span>
⚠️ Potential issue | 🟡 Minor

C++20 <span> header may cause portability issues.

Static analysis reports 'span' file not found on Clang 14.0.6. Additionally, <span> doesn't appear to be used anywhere in this file. Consider removing this unused include or adding a fallback for older compilers.

Proposed fix
-#include <span>

If <span> is needed elsewhere, consider using a compatibility shim or conditional include:

#if __cplusplus >= 202002L
#include <span>
#else
// fallback or alternative
#endif
🧰 Tools
🪛 Clang (14.0.6)

[error] 5-5: 'span' file not found

(clang-diagnostic-error)

🤖 Prompt for AI Agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp` at line 5: the `#include <span>`
directive in modeling_qwen3_probing_service.hpp is unnecessary and breaks
builds on older compilers; remove the unused include (the '#include <span>'
line) or replace it with a conditional fallback that only includes <span> when
__cplusplus >= 202002L and otherwise provides a compatibility shim or
alternative type. Ensure you update modeling_qwen3_probing_service.hpp so no
references remain to std::span (or add the shim header) and re-run the build to
confirm portability.

Comment on lines +161 to +167
float val = 0.0f;
if (logits.dtype() == mllm::kFloat32) {
val = logits.ptr<float>()[0];
}

return 1.0f / (1.0f + std::exp(-val));
}
⚠️ Potential issue | 🟡 Minor

Silent failure when logits dtype is not kFloat32.

If logits.dtype() is anything other than kFloat32 (e.g., kFloat16), val remains 0.0f and the function returns 0.5 (sigmoid of 0). This could mask issues silently.

Consider either:

  1. Converting non-float32 tensors to float32 before reading
  2. Throwing an exception for unsupported dtypes
  3. At minimum, logging a warning
Proposed fix with fallback conversion
     float val = 0.0f;
     if (logits.dtype() == mllm::kFloat32) {
       val = logits.ptr<float>()[0];
-    } 
+    } else if (logits.dtype() == mllm::kFloat16) {
+      // Convert to float32 for portable access
+      auto logits_f32 = logits.to(mllm::kFloat32);
+      val = logits_f32.ptr<float>()[0];
+    } else {
+      std::cerr << "Warning: Unsupported logits dtype in ProbeClassifier::predict" << std::endl;
+    }
🤖 Prompt for AI Agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp` around lines 161 - 167,
The current code silently returns sigmoid(0) when logits.dtype() !=
mllm::kFloat32; update the branch around logits.dtype() to handle other dtypes
explicitly: if logits.dtype() == mllm::kFloat32 keep reading via
logits.ptr<float>()[0], else if logits.dtype() == mllm::kFloat16 (or other
supported numeric types) convert/cast the tensor to float32 (or read via the
appropriate ptr<T>() and cast to float) before assigning to val, and for any
truly unsupported dtype emit a warning or throw an exception; ensure you
reference the existing logits, val, logits.dtype(), and ptr<> usage so the
change replaces the silent fallback with explicit conversion or error handling.

Comment on lines +355 to +361
size_t dtype_size = (mlp_out.dtype() == mllm::kFloat32) ? 4 : 2;
char* src_base_ptr = (char*)mlp_out.ptr<float>();
size_t byte_offset = (size_t)token_offset * hidden_dim * dtype_size;

if (src_base_ptr && dest_ptr->ptr<float>()) {
std::memcpy(dest_ptr->ptr<float>(), src_base_ptr + byte_offset, hidden_dim * dtype_size);
}
⚠️ Potential issue | 🟡 Minor

Unsafe dtype handling and pointer casting in memcpy logic.

  1. Line 355: dtype_size ternary only handles kFloat32 and assumes everything else is 2 bytes - this won't work correctly for other dtypes.
  2. Line 356: Casting mlp_out.ptr<float>() to char* is incorrect when dtype is not kFloat32 - the template parameter should match the actual dtype.
Proposed fix
-        size_t dtype_size = (mlp_out.dtype() == mllm::kFloat32) ? 4 : 2;
-        char* src_base_ptr = (char*)mlp_out.ptr<float>();
+        size_t dtype_size = 0;
+        if (mlp_out.dtype() == mllm::kFloat32) {
+          dtype_size = sizeof(float);
+        } else if (mlp_out.dtype() == mllm::kFloat16) {
+          dtype_size = 2;
+        } else {
+          // Skip unsupported dtype
+          continue;
+        }
+        char* src_base_ptr = static_cast<char*>(mlp_out.ptr<void>());
         size_t byte_offset = (size_t)token_offset * hidden_dim * dtype_size;
 
-        if (src_base_ptr && dest_ptr->ptr<float>()) {
-          std::memcpy(dest_ptr->ptr<float>(), src_base_ptr + byte_offset, hidden_dim * dtype_size);
+        if (src_base_ptr && dest_ptr->ptr<void>()) {
+          std::memcpy(dest_ptr->ptr<void>(), src_base_ptr + byte_offset, hidden_dim * dtype_size);
         }
🤖 Prompt for AI Agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp` around lines 355 - 361,
The memcpy code incorrectly assumes only two dtype sizes and casts
mlp_out.ptr<float>() to char*, which breaks when mlp_out.dtype() is not
kFloat32; update the logic to (1) compute dtype_size from mlp_out.dtype() with a
complete switch/if handling all possible mllm dtypes instead of a ternary, (2)
obtain a raw byte pointer from mlp_out (e.g., use a void*/uint8_t* accessor such
as mlp_out.ptr<uint8_t>() or a generic data() method) rather than casting
ptr<float>(), (3) ensure dest_ptr->ptr<T>() is accessed with the correct
template T matching the destination dtype (or copy into dest as bytes), and (4)
validate both source and destination pointers are non-null before calling
std::memcpy with byte count = hidden_dim * dtype_size using token_offset to
compute the byte_offset.

void setProbingArgs(const ProbingArgs& args) { probing_args_ = args; }
void loadProbes(const std::string& path, const ProbingArgs& args) { model_->loadProbesFromDirectory(path, args); }

std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() { return model_->last_probe_results_; }
⚠️ Potential issue | 🟡 Minor

Duplicate getLastProbeResults() method definitions.

There are two getLastProbeResults() methods:

  • Line 580: std::vector<...> getLastProbeResults() (non-const, public)
  • Line 1049: std::vector<...> getLastProbeResults() const (const, public)

This is confusing and the const version at line 1049 is flagged by static analysis as needing [[nodiscard]]. Consider keeping only the const version with proper null checking.

Proposed fix
-  std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() { return model_->last_probe_results_; }
+  // Removed in favor of const version below
   void clearLastProbeResults() { model_->clearProbeResults(); }
   ...
  public:
-  std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() const {
+  [[nodiscard]] std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() const {
     if (model_) return model_->last_probe_results_;
     return {};
   }

Also applies to: 1048-1052

🤖 Prompt for AI Agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp` at line 580, Remove the
duplicate non-const accessor and keep a single const accessor
Qwen3ProbingForCausalLM::getLastProbeResults() const marked [[nodiscard]];
inside it, null-check model_ before accessing model_->last_probe_results_
(return an empty vector if model_ is null) so callers get a safe,
const-qualified result and static analysis is satisfied instead of having two
conflicting overloads.

Comment on lines +580 to +581
std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() { return model_->last_probe_results_; }
void clearLastProbeResults() { model_->clearProbeResults(); }
⚠️ Potential issue | 🟡 Minor

Missing null check before dereferencing model_.

getLastProbeResults() and clearLastProbeResults() access model_-> without verifying that model_ is not null. If called before fromPreTrain(), this will cause a null pointer dereference.

Proposed fix
-  std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() { return model_->last_probe_results_; }
-  void clearLastProbeResults() { model_->clearProbeResults(); }
+  std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() {
+    if (!model_) return {};
+    return model_->last_probe_results_;
+  }
+  void clearLastProbeResults() {
+    if (model_) model_->clearProbeResults();
+  }
🤖 Prompt for AI Agents
In `@mllm/models/qwen3/modeling_qwen3_probing_service.hpp` around lines 580 - 581,
Both getLastProbeResults() and clearLastProbeResults() dereference model_
without a null check; add checks to avoid dereferencing before fromPreTrain():
in getLastProbeResults() return an empty
std::vector<Qwen3ProbingForCausalLM::ProbeResult> if model_ is null, otherwise
return model_->last_probe_results_; in clearLastProbeResults() make it a no-op
when model_ is null, otherwise call model_->clearProbeResults(). Use the
existing symbols getLastProbeResults, clearLastProbeResults, model_,
last_probe_results_, clearProbeResults(), and fromPreTrain() to locate and
implement the checks.
