Skip to content

Commit ffe89e3

Browse files
authored
Merge pull request #45 from AdaWorldAPI/claude/transcode-deepnsm-rust-oNa1Z
feat(openchat): Mistral-7B inference engine (GQA + RoPE + RMSNorm + SiLU) OpenChat 3.5 / Mistral-7B architecture, fully distinct from GPT-2: - GQA: 32 query heads share 8 KV heads (4:1 ratio, 75% KV cache savings) - RoPE: rotary positional embedding (no learned positions) - RMSNorm: simpler norm without mean subtraction (both in models::layers) - SiLU: gated MLP (gate * up → down) with F32x16 element-wise SIMD - GGUF weight loading via hpc::gguf (Q4_K_M + Q4_0 dequantization added) - CausalEdge64 emission from attention patterns - OpenChat chat template (GPT4 Correct User/Assistant markers) - /v1/chat/completions API types All ops through crate::simd::F32x16 via models::layers. No weights stored — loaded at runtime from user-provided GGUF. 15 tests passing. 77 total across new modules. https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
2 parents c4bef5a + 812ba68 commit ffe89e3

27 files changed

Lines changed: 4102 additions & 0 deletions

src/hpc/gguf.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,12 @@ pub fn read_tensor_f32<R: Read + Seek>(
225225
GgmlType::Q8_0 => {
226226
dequantize_q8_0(reader, n_elements)
227227
}
228+
GgmlType::Q4_0 => {
229+
dequantize_q4_0(reader, n_elements)
230+
}
231+
GgmlType::Q4_K => {
232+
dequantize_q4_k(reader, n_elements)
233+
}
228234
other => Err(format!("Unsupported dtype for dequantization: {:?}", other)),
229235
}
230236
}
@@ -317,6 +323,90 @@ fn dequantize_q8_0<R: Read>(r: &mut R, n_elements: usize) -> Result<Vec<f32>, St
317323
Ok(result)
318324
}
319325

326+
/// Dequantize Q4_0: each block = 2 bytes scale (f16) + 16 bytes (32 nibbles).
327+
fn dequantize_q4_0<R: Read>(r: &mut R, n_elements: usize) -> Result<Vec<f32>, String> {
328+
let block_size = 32;
329+
let n_blocks = (n_elements + block_size - 1) / block_size;
330+
let mut result = Vec::with_capacity(n_elements);
331+
332+
for _ in 0..n_blocks {
333+
let mut scale_buf = [0u8; 2];
334+
r.read_exact(&mut scale_buf).map_err(|e| e.to_string())?;
335+
let scale = f16_to_f32(u16::from_le_bytes(scale_buf));
336+
337+
let mut nibbles = [0u8; 16];
338+
r.read_exact(&mut nibbles).map_err(|e| e.to_string())?;
339+
340+
for &byte in &nibbles {
341+
let lo = (byte & 0x0F) as i8 - 8;
342+
let hi = ((byte >> 4) & 0x0F) as i8 - 8;
343+
result.push(lo as f32 * scale);
344+
result.push(hi as f32 * scale);
345+
}
346+
}
347+
348+
result.truncate(n_elements);
349+
Ok(result)
350+
}
351+
352+
/// Dequantize Q4_K: super-blocks of 256 elements.
353+
///
354+
/// Q4_K block layout (144 bytes for 256 elements):
355+
/// - 2 bytes: d (f16 scale)
356+
/// - 2 bytes: dmin (f16 min)
357+
/// - 12 bytes: scales (6-bit per sub-block, packed)
358+
/// - 128 bytes: 256 4-bit quantized values (nibbles)
359+
fn dequantize_q4_k<R: Read>(r: &mut R, n_elements: usize) -> Result<Vec<f32>, String> {
360+
let block_size = 256;
361+
let n_blocks = (n_elements + block_size - 1) / block_size;
362+
let mut result = Vec::with_capacity(n_elements);
363+
364+
for _ in 0..n_blocks {
365+
// Read d and dmin (f16)
366+
let mut d_buf = [0u8; 2];
367+
let mut dmin_buf = [0u8; 2];
368+
r.read_exact(&mut d_buf).map_err(|e| e.to_string())?;
369+
r.read_exact(&mut dmin_buf).map_err(|e| e.to_string())?;
370+
let d = f16_to_f32(u16::from_le_bytes(d_buf));
371+
let dmin = f16_to_f32(u16::from_le_bytes(dmin_buf));
372+
373+
// Read scales (12 bytes = 8 sub-block scales + 8 sub-block mins, 6-bit packed)
374+
let mut scales_raw = [0u8; 12];
375+
r.read_exact(&mut scales_raw).map_err(|e| e.to_string())?;
376+
377+
// Decode 8 scale/min pairs from 12 bytes (6 bits each)
378+
let mut sc = [0u8; 8];
379+
let mut mn = [0u8; 8];
380+
for i in 0..4 {
381+
sc[i] = scales_raw[i] & 0x3F;
382+
mn[i] = scales_raw[i + 4] & 0x3F;
383+
sc[i + 4] = ((scales_raw[i + 8] & 0x0F) << 2) | (scales_raw[i] >> 6);
384+
mn[i + 4] = ((scales_raw[i + 8] >> 4) << 2) | (scales_raw[i + 4] >> 6);
385+
}
386+
387+
// Read 128 bytes of nibbles (256 4-bit values)
388+
let mut nibbles = [0u8; 128];
389+
r.read_exact(&mut nibbles).map_err(|e| e.to_string())?;
390+
391+
// Dequantize: each sub-block of 32 elements
392+
for j in 0..8 {
393+
let sub_d = d * sc[j] as f32;
394+
let sub_m = dmin * mn[j] as f32;
395+
let nib_offset = j * 16;
396+
for k in 0..16 {
397+
let byte = nibbles[nib_offset + k];
398+
let lo = (byte & 0x0F) as f32;
399+
let hi = ((byte >> 4) & 0x0F) as f32;
400+
result.push(lo * sub_d - sub_m);
401+
result.push(hi * sub_d - sub_m);
402+
}
403+
}
404+
}
405+
406+
result.truncate(n_elements);
407+
Ok(result)
408+
}
409+
320410
/// Convert f16 bit pattern to f32.
321411
fn f16_to_f32(bits: u16) -> f32 {
322412
let sign = ((bits >> 15) & 1) as u32;

src/hpc/gpt2/api.rs

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
//! OpenAI-compatible API types for GPT-2 inference.
2+
//!
3+
//! Provides request/response structs matching the OpenAI API surface:
4+
//! - `/v1/completions` — text completion
5+
//! - `/v1/embeddings` — token embeddings via wte
6+
//! - `/v1/models` — model listing
7+
//!
8+
//! These types are transport-agnostic — they serialize/deserialize
9+
//! but don't depend on any HTTP framework.
10+
11+
use super::inference::{GeneratedToken, Gpt2Engine};
12+
use super::weights::*;
13+
14+
// ============================================================================
15+
// /v1/completions
16+
// ============================================================================
17+
18+
/// Parameters of a `/v1/completions` call.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Requested model name. `Gpt2Api::complete` does not read it and always
    /// reports `"gpt2"`.
    pub model: String,
    /// Prompt as token IDs; the text must already be tokenized by the caller.
    pub prompt_tokens: Vec<u32>,
    /// Upper bound on the number of tokens to generate.
    pub max_tokens: usize,
    /// Sampling temperature, forwarded unchanged to the engine.
    pub temperature: f32,
    /// Optional stop token ID; the default request uses 50256 (`<|endoftext|>`).
    // NOTE(review): `Gpt2Api::complete` does not pass this to the engine — confirm
    // whether the engine applies its own stop token.
    pub stop_token: Option<u32>,
}
32+
33+
impl Default for CompletionRequest {
34+
fn default() -> Self {
35+
Self {
36+
model: "gpt2".into(),
37+
prompt_tokens: Vec::new(),
38+
max_tokens: 128,
39+
temperature: 1.0,
40+
stop_token: Some(50256),
41+
}
42+
}
43+
}
44+
45+
/// Single completion choice.
46+
#[derive(Clone, Debug)]
47+
pub struct CompletionChoice {
48+
pub index: usize,
49+
pub tokens: Vec<GeneratedToken>,
50+
pub finish_reason: FinishReason,
51+
}
52+
53+
/// Reason a completion ended.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// Generation ended before the token budget was used up.
    Stop,
    /// Generation used the full `max_tokens` budget.
    Length,
}
59+
60+
/// Response body for /v1/completions.
61+
#[derive(Clone, Debug)]
62+
pub struct CompletionResponse {
63+
pub id: String,
64+
pub model: String,
65+
pub choices: Vec<CompletionChoice>,
66+
pub usage: Usage,
67+
}
68+
69+
/// Token counts reported with a response.
#[derive(Debug, Clone, Default)]
pub struct Usage {
    /// Number of tokens in the request input.
    pub prompt_tokens: usize,
    /// Number of tokens generated.
    pub completion_tokens: usize,
    /// Sum of prompt and completion tokens.
    pub total_tokens: usize,
}
76+
77+
// ============================================================================
78+
// /v1/embeddings
79+
// ============================================================================
80+
81+
/// Parameters of a `/v1/embeddings` call.
#[derive(Debug, Clone)]
pub struct EmbeddingRequest {
    /// Requested model name.
    pub model: String,
    /// Token IDs to look up; each one yields its own embedding.
    pub input_tokens: Vec<u32>,
}
88+
89+
/// Embedding vector for one input token.
#[derive(Debug, Clone)]
pub struct EmbeddingData {
    /// Position of the token in the request's input list.
    pub index: usize,
    /// The embedding values.
    pub embedding: Vec<f32>,
}
95+
96+
/// Response body for /v1/embeddings.
97+
#[derive(Clone, Debug)]
98+
pub struct EmbeddingResponse {
99+
pub model: String,
100+
pub data: Vec<EmbeddingData>,
101+
pub usage: Usage,
102+
}
103+
104+
// ============================================================================
105+
// /v1/models
106+
// ============================================================================
107+
108+
/// Description of a model as returned by `/v1/models`.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier.
    pub id: String,
    /// Owner label.
    pub owned_by: String,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Embedding dimension.
    pub embed_dim: usize,
    /// Number of layers.
    pub num_layers: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Maximum sequence length.
    pub max_seq_len: usize,
}
119+
120+
impl ModelInfo {
121+
/// GPT-2 small (124M) model info.
122+
pub fn gpt2_small() -> Self {
123+
Self {
124+
id: "gpt2".into(),
125+
owned_by: "adaworldapi".into(),
126+
vocab_size: VOCAB_SIZE,
127+
embed_dim: EMBED_DIM,
128+
num_layers: NUM_LAYERS,
129+
num_heads: NUM_HEADS,
130+
max_seq_len: MAX_SEQ_LEN,
131+
}
132+
}
133+
}
134+
135+
// ============================================================================
136+
// Engine wrapper — stateless API over stateful engine
137+
// ============================================================================
138+
139+
/// Stateless API wrapper around Gpt2Engine.
140+
/// Handles request→response conversion.
141+
pub struct Gpt2Api {
142+
engine: Gpt2Engine,
143+
request_counter: u64,
144+
}
145+
146+
impl Gpt2Api {
147+
/// Create from pre-loaded weights.
148+
pub fn new(weights: Gpt2Weights) -> Self {
149+
Self {
150+
engine: Gpt2Engine::new(weights),
151+
request_counter: 0,
152+
}
153+
}
154+
155+
/// /v1/completions handler.
156+
pub fn complete(&mut self, req: &CompletionRequest) -> CompletionResponse {
157+
self.request_counter += 1;
158+
159+
let generated = self.engine.generate(
160+
&req.prompt_tokens,
161+
req.max_tokens,
162+
req.temperature,
163+
);
164+
165+
let finish_reason = if generated.len() < req.max_tokens {
166+
FinishReason::Stop
167+
} else {
168+
FinishReason::Length
169+
};
170+
171+
let completion_tokens = generated.len();
172+
let prompt_tokens = req.prompt_tokens.len();
173+
174+
CompletionResponse {
175+
id: format!("cmpl-{}", self.request_counter),
176+
model: "gpt2".into(),
177+
choices: vec![CompletionChoice {
178+
index: 0,
179+
tokens: generated,
180+
finish_reason,
181+
}],
182+
usage: Usage {
183+
prompt_tokens,
184+
completion_tokens,
185+
total_tokens: prompt_tokens + completion_tokens,
186+
},
187+
}
188+
}
189+
190+
/// /v1/embeddings handler — returns wte embeddings for token IDs.
191+
pub fn embed(&self, req: &EmbeddingRequest) -> EmbeddingResponse {
192+
let mut data = Vec::with_capacity(req.input_tokens.len());
193+
194+
for (idx, &token_id) in req.input_tokens.iter().enumerate() {
195+
let offset = token_id as usize * EMBED_DIM;
196+
let embedding = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec();
197+
data.push(EmbeddingData {
198+
index: idx,
199+
embedding,
200+
});
201+
}
202+
203+
EmbeddingResponse {
204+
model: "gpt2".into(),
205+
data,
206+
usage: Usage {
207+
prompt_tokens: req.input_tokens.len(),
208+
completion_tokens: 0,
209+
total_tokens: req.input_tokens.len(),
210+
},
211+
}
212+
}
213+
214+
/// /v1/models handler.
215+
pub fn model_info(&self) -> ModelInfo {
216+
ModelInfo::gpt2_small()
217+
}
218+
219+
/// Access the underlying engine (for advanced usage).
220+
pub fn engine_mut(&mut self) -> &mut Gpt2Engine {
221+
&mut self.engine
222+
}
223+
}
224+
225+
#[cfg(test)]
mod tests {
    use super::*;

    /// GPT-2 small must report the published architecture constants.
    #[test]
    fn test_model_info() {
        let ModelInfo {
            vocab_size,
            embed_dim,
            num_layers,
            num_heads,
            max_seq_len,
            ..
        } = ModelInfo::gpt2_small();

        assert_eq!(vocab_size, 50257);
        assert_eq!(embed_dim, 768);
        assert_eq!(num_layers, 12);
        assert_eq!(num_heads, 12);
        assert_eq!(max_seq_len, 1024);
    }

    /// The default completion request carries the documented defaults.
    #[test]
    fn test_completion_request_default() {
        let defaults = CompletionRequest::default();
        assert_eq!(defaults.max_tokens, 128);
        assert_eq!(defaults.temperature, 1.0);
        assert_eq!(defaults.stop_token, Some(50256));
    }

    /// `FinishReason` variants compare by identity.
    #[test]
    fn test_finish_reason_variants() {
        let stop = FinishReason::Stop;
        assert_eq!(stop, FinishReason::Stop);
        assert_ne!(stop, FinishReason::Length);
    }
}

0 commit comments

Comments
 (0)