Skip to content

Commit 5024e64

Browse files
committed
fix(lfm): fix LFM 1.2B Thinking model — ORT upgrade, KV cache dims, streaming tokens, sampling
1 parent be8e97b commit 5024e64

3 files changed

Lines changed: 85 additions & 20 deletions

File tree

changelogs/CHANGELOG-lfm-fix.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# CHANGELOG — LFM 1.2B Thinking Model Fix
2+
3+
## Summary
4+
Fixed the LFM 1.2B Thinking (Liquid AI) local model — it was completely non-functional due to multiple issues in the worker and main thread handler.
5+
6+
## Changes
7+
8+
### `public/ai-worker-lfm.js`
9+
- **Upgraded ONNX Runtime Web** from v1.22.0 → v1.24.3 — old version crashed with WASM errors on LFM's hybrid SSM+Transformer operators
10+
- **Fixed HEAD_DIM constant** from 256 → 64 (= hidden_size 2048 ÷ 32 attention heads) — wrong value caused KV cache shape mismatch during inference
11+
- **Replaced greedy decoding with temperature sampling** — top-k=40, temp=0.7 (0.6 in thinking mode) produces more detailed responses instead of ultra-terse one-liners
12+
- **Increased minimum token budget** to 2048 (4096 for thinking mode) — LFM always generates `<think>` reasoning which consumes tokens before the answer
13+
- **Added detail prompt hint** for chat/generate/qa/explain tasks to encourage comprehensive responses
14+
- **Fixed error handling** — ONNX Runtime throws raw WASM memory pointers (numbers) instead of Error objects; now safely extracts error messages with `String()` fallback
15+
- **Improved download status message** — shows "Downloading LFM 1.2B Thinking weights — this may take a few minutes..." instead of misleading "Downloading model_q4.onnx..."
16+
17+
### `js/ai-assistant.js`
18+
- **Added missing `case 'token'` handler** in the local worker message listener — streaming tokens from LFM (and Qwen) were silently dropped
19+
- **Changed `complete` handler** from `handleAiResponse` → `handleGroqComplete` — prevents duplicate response bubbles when streaming tokens are followed by a complete message
20+
21+
## Files Modified
22+
- `public/ai-worker-lfm.js` — 7 fixes
23+
- `js/ai-assistant.js` — 2 fixes

js/ai-assistant.js

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,14 @@
787787
break;
788788

789789
case 'complete':
790-
M._ai.handleAiResponse(msg.text, msg.messageId);
790+
// Use handleGroqComplete (not handleAiResponse) because this worker
791+
// also streams tokens. handleGroqComplete reuses the existing streaming
792+
// bubble; handleAiResponse would create a duplicate.
793+
M._ai.handleGroqComplete(msg.text, msg.messageId);
794+
break;
795+
796+
case 'token':
797+
M._ai.handleStreamingToken(msg.token, msg.messageId);
791798
break;
792799

793800
case 'error':

public/ai-worker-lfm.js

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import { TOKEN_LIMITS, buildMessages as _buildMessages } from './ai-worker-commo
1818

1919
// CDN URLs
2020
const TRANSFORMERS_URL = "https://cdn.jsdelivr.net/npm/@huggingface/transformers@4.0.0-next.6";
21-
const ORT_URL = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.22.0/dist/ort.webgpu.min.mjs";
21+
const ORT_URL = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.24.3/dist/ort.webgpu.min.mjs";
2222

2323
// Model host
2424
const MODEL_HOST = "https://huggingface.co";
@@ -35,8 +35,8 @@ let MODEL_DTYPE = "q4"; // 'q4', 'q8', or 'fp16'
3535

3636
// LFM architecture constants (from config.json)
3737
const HIDDEN_SIZE = 2048;
38-
const NUM_KV_HEADS = 8;
39-
const HEAD_DIM = 256;
38+
const NUM_KV_HEADS = 8; // num_key_value_heads
39+
const HEAD_DIM = 64; // hidden_size / num_attention_heads = 2048 / 32
4040

4141
// Runtime state
4242
let tokenizer = null;
@@ -127,9 +127,11 @@ async function loadModel() {
127127
const dataPath = `${modelBase}/onnx/${onnxFile}_data`;
128128

129129
// Report download progress for the main model file
130+
// Note: ONNX Runtime downloads the external data (~600 MB+) internally
131+
// without a progress callback, so we show an informative status message.
130132
self.postMessage({
131133
type: "status",
132-
message: `Downloading ${onnxFile}...`,
134+
message: `Downloading ${MODEL_LABEL} weights — this may take a few minutes...`,
133135
});
134136

135137
session = await ort.InferenceSession.create(onnxPath, {
@@ -142,7 +144,9 @@ async function loadModel() {
142144
await loadFromHost(MODEL_ID);
143145
} catch (primaryErr) {
144146
// Fallback to LiquidAI org
145-
console.warn(`textagent model failed: ${primaryErr.message}. Falling back to ${MODEL_ORG_FALLBACK}…`);
147+
const errMsg = (primaryErr && primaryErr.message) || String(primaryErr);
148+
console.warn(`textagent model failed: ${errMsg}. Falling back to ${MODEL_ORG_FALLBACK}…`);
149+
console.error('Primary load error:', primaryErr);
146150
self.postMessage({ type: "status", message: `Falling back to ${MODEL_ORG_FALLBACK} models…` });
147151
MODEL_ID = MODEL_ID.replace('textagent/', MODEL_ORG_FALLBACK + '/');
148152
tokenizer = null;
@@ -152,12 +156,14 @@ async function loadModel() {
152156

153157
self.postMessage({ type: "loaded", device: "webgpu" });
154158
} catch (error) {
155-
const hint = error.message.includes("Failed to fetch") || error.message.includes("NetworkError")
159+
const errMsg = (error && error.message) || String(error);
160+
const hint = errMsg.includes("Failed to fetch") || errMsg.includes("NetworkError")
156161
? " (Check your internet connection and ensure the model host is not blocked)"
157162
: "";
163+
console.error('LFM load error:', error);
158164
self.postMessage({
159165
type: "error",
160-
message: `Failed to load LFM model: ${error.message}${hint}`,
166+
message: `Failed to load LFM model: ${errMsg}${hint}`,
161167
});
162168
}
163169
}
@@ -222,12 +228,17 @@ async function generate(taskType, context, userPrompt, messageId, enableThinking
222228
augmentedPrompt += '\n\n[Attached File: ' + (att.name || 'file') + ']\n' + att.textContent;
223229
});
224230

225-
// Build messages
226-
const messages = buildMessages(taskType, context, augmentedPrompt || userPrompt, chatHistory);
231+
// Build messages — append detail hint since LFM Thinking tends to be terse
232+
// after its reasoning phase
233+
const detailHint = (taskType === 'generate' || taskType === 'chat' || taskType === 'qa' || taskType === 'explain')
234+
? '\n\nProvide a detailed, comprehensive response with explanations.' : '';
235+
const messages = buildMessages(taskType, context, (augmentedPrompt || userPrompt) + detailHint, chatHistory);
227236

228-
// Use task-specific token limit; thinking mode gets more
229-
let maxTokens = maxTokensOverride || TOKEN_LIMITS[taskType] || 512;
230-
if (enableThinking) maxTokens = Math.max(maxTokens * 2, 1024);
237+
// LFM is a Thinking model — it always generates <think>...</think> reasoning
238+
// before the answer, consuming significant tokens. Use higher minimums.
239+
let maxTokens = maxTokensOverride || TOKEN_LIMITS[taskType] || 2048;
240+
maxTokens = Math.max(maxTokens, 2048); // ensure thinking has room
241+
if (enableThinking) maxTokens = Math.max(maxTokens * 2, 4096);
231242

232243
// Apply chat template and tokenize
233244
const prompt = tokenizer.apply_chat_template(messages, {
@@ -261,18 +272,42 @@ async function generate(taskType, context, userPrompt, messageId, enableThinking
261272
...cache,
262273
});
263274

264-
// Greedy decode: argmax of last token logits
275+
// Temperature sampling with top-k for better response quality
276+
// (greedy argmax causes small models to produce very short answers)
265277
const logits = outputs.logits;
266278
const vocabSize = logits.dims[2];
267279
const lastLogits = logits.data.slice((logits.dims[1] - 1) * vocabSize);
268280

269-
let nextToken = 0;
270-
let maxVal = -Infinity;
281+
const temperature = enableThinking ? 0.6 : 0.7;
282+
const topK = 40;
283+
284+
// Apply temperature scaling
285+
const scaled = new Float32Array(lastLogits.length);
271286
for (let i = 0; i < lastLogits.length; i++) {
272-
if (lastLogits[i] > maxVal) {
273-
maxVal = lastLogits[i];
274-
nextToken = i;
275-
}
287+
scaled[i] = lastLogits[i] / temperature;
288+
}
289+
290+
// Top-k filtering: find top-k indices
291+
const indices = Array.from({ length: scaled.length }, (_, i) => i);
292+
indices.sort((a, b) => scaled[b] - scaled[a]);
293+
const topKIndices = indices.slice(0, topK);
294+
295+
// Softmax over top-k
296+
let maxLogit = scaled[topKIndices[0]];
297+
let sumExp = 0;
298+
const probs = new Float32Array(topK);
299+
for (let i = 0; i < topK; i++) {
300+
probs[i] = Math.exp(scaled[topKIndices[i]] - maxLogit);
301+
sumExp += probs[i];
302+
}
303+
for (let i = 0; i < topK; i++) probs[i] /= sumExp;
304+
305+
// Sample from the distribution
306+
let r = Math.random();
307+
let nextToken = topKIndices[0]; // fallback
308+
for (let i = 0; i < topK; i++) {
309+
r -= probs[i];
310+
if (r <= 0) { nextToken = topKIndices[i]; break; }
276311
}
277312

278313
generatedTokens.push(nextToken);

0 commit comments

Comments
 (0)