Skip to content

Commit e7db7eb

Browse files
committed
feat(models): complete 1:1 OpenAI API surface + unified router
Rewrite api_types.rs as single source of truth with exact OpenAI field names and object types: - CompletionRequest/Response with logprobs, stop, penalties, seed - ChatCompletionRequest/Response with tools, streaming chunks, deltas - EmbeddingRequest/Response with encoding_format, dimensions - ImageGenerationRequest/Response with size parsing, quality, style - Model/ModelList with object fields ("model", "list") - ErrorResponse envelope matching OpenAI error format - ChatRole with as_str()/from_str() for wire format Add router.rs — unified ModelRouter dispatching by model ID: - complete() → GPT-2 - chat_complete() → OpenChat 3.5 (or GPT-2 via chat↔completion adapter) - embed() → GPT-2 wte (extensible to Jina/BERT) - list_models() / get_model() → all registered models - chat↔completion adapters for cross-model compatibility Rewire gpt2/api.rs, openchat/api.rs, stable_diffusion/api.rs to use shared types. No JSON/serde surface — pure Rust types for in-binary consumers (q2, lance-graph). 58 API tests passing. https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent ffe89e3 commit e7db7eb

6 files changed

Lines changed: 1124 additions & 444 deletions

File tree

src/hpc/gpt2/api.rs

Lines changed: 66 additions & 195 deletions
Original file line numberDiff line numberDiff line change
@@ -1,222 +1,97 @@
1-
//! OpenAI-compatible API types for GPT-2 inference.
1+
//! GPT-2 API — wraps the inference engine with OpenAI-compatible types.
22
//!
3-
//! Provides request/response structs matching the OpenAI API surface:
3+
//! Endpoints:
44
//! - `/v1/completions` — text completion
55
//! - `/v1/embeddings` — token embeddings via wte
6-
//! - `/v1/models` — model listing
7-
//!
8-
//! These types are transport-agnostic — they serialize/deserialize
9-
//! but don't depend on any HTTP framework.
6+
//! - `/v1/models` — model info
107
8+
use crate::hpc::models::api_types::*;
119
use super::inference::{GeneratedToken, Gpt2Engine};
1210
use super::weights::*;
1311

14-
// ============================================================================
15-
// /v1/completions
16-
// ============================================================================
17-
18-
/// Request body for /v1/completions.
19-
#[derive(Clone, Debug)]
20-
pub struct CompletionRequest {
21-
/// Model name (ignored — we only have gpt2).
22-
pub model: String,
23-
/// Input text prompt (will be tokenized externally).
24-
pub prompt_tokens: Vec<u32>,
25-
/// Maximum tokens to generate.
26-
pub max_tokens: usize,
27-
/// Sampling temperature (1.0 = greedy effective).
28-
pub temperature: f32,
29-
/// Stop token ID (default: 50256 = <|endoftext|>).
30-
pub stop_token: Option<u32>,
31-
}
32-
33-
impl Default for CompletionRequest {
34-
fn default() -> Self {
35-
Self {
36-
model: "gpt2".into(),
37-
prompt_tokens: Vec::new(),
38-
max_tokens: 128,
39-
temperature: 1.0,
40-
stop_token: Some(50256),
41-
}
42-
}
43-
}
44-
45-
/// Single completion choice.
46-
#[derive(Clone, Debug)]
47-
pub struct CompletionChoice {
48-
pub index: usize,
49-
pub tokens: Vec<GeneratedToken>,
50-
pub finish_reason: FinishReason,
51-
}
52-
53-
/// Why generation stopped.
54-
#[derive(Clone, Debug, PartialEq, Eq)]
55-
pub enum FinishReason {
56-
Stop,
57-
Length,
58-
}
59-
60-
/// Response body for /v1/completions.
61-
#[derive(Clone, Debug)]
62-
pub struct CompletionResponse {
63-
pub id: String,
64-
pub model: String,
65-
pub choices: Vec<CompletionChoice>,
66-
pub usage: Usage,
67-
}
68-
69-
/// Token usage statistics.
70-
#[derive(Clone, Debug, Default)]
71-
pub struct Usage {
72-
pub prompt_tokens: usize,
73-
pub completion_tokens: usize,
74-
pub total_tokens: usize,
75-
}
76-
77-
// ============================================================================
78-
// /v1/embeddings
79-
// ============================================================================
80-
81-
/// Request body for /v1/embeddings.
82-
#[derive(Clone, Debug)]
83-
pub struct EmbeddingRequest {
84-
pub model: String,
85-
/// Token IDs to embed (one embedding per token).
86-
pub input_tokens: Vec<u32>,
87-
}
88-
89-
/// Single embedding result.
90-
#[derive(Clone, Debug)]
91-
pub struct EmbeddingData {
92-
pub index: usize,
93-
pub embedding: Vec<f32>,
94-
}
95-
96-
/// Response body for /v1/embeddings.
97-
#[derive(Clone, Debug)]
98-
pub struct EmbeddingResponse {
99-
pub model: String,
100-
pub data: Vec<EmbeddingData>,
101-
pub usage: Usage,
102-
}
103-
104-
// ============================================================================
105-
// /v1/models
106-
// ============================================================================
107-
108-
/// Model info for /v1/models.
109-
#[derive(Clone, Debug)]
110-
pub struct ModelInfo {
111-
pub id: String,
112-
pub owned_by: String,
113-
pub vocab_size: usize,
114-
pub embed_dim: usize,
115-
pub num_layers: usize,
116-
pub num_heads: usize,
117-
pub max_seq_len: usize,
118-
}
119-
120-
impl ModelInfo {
121-
/// GPT-2 small (124M) model info.
122-
pub fn gpt2_small() -> Self {
123-
Self {
124-
id: "gpt2".into(),
125-
owned_by: "adaworldapi".into(),
126-
vocab_size: VOCAB_SIZE,
127-
embed_dim: EMBED_DIM,
128-
num_layers: NUM_LAYERS,
129-
num_heads: NUM_HEADS,
130-
max_seq_len: MAX_SEQ_LEN,
131-
}
132-
}
133-
}
134-
135-
// ============================================================================
136-
// Engine wrapper — stateless API over stateful engine
137-
// ============================================================================
138-
13912
/// Stateless API wrapper around Gpt2Engine.
140-
/// Handles request→response conversion.
14113
pub struct Gpt2Api {
14214
engine: Gpt2Engine,
14315
request_counter: u64,
14416
}
14517

14618
impl Gpt2Api {
147-
/// Create from pre-loaded weights.
14819
pub fn new(weights: Gpt2Weights) -> Self {
149-
Self {
150-
engine: Gpt2Engine::new(weights),
151-
request_counter: 0,
152-
}
20+
Self { engine: Gpt2Engine::new(weights), request_counter: 0 }
15321
}
15422

155-
/// /v1/completions handler.
23+
/// `/v1/completions`
15624
pub fn complete(&mut self, req: &CompletionRequest) -> CompletionResponse {
15725
self.request_counter += 1;
26+
let tokens = req.prompt_tokens.as_deref().unwrap_or(&[]);
27+
let max = req.max_tokens.unwrap_or(128);
28+
let temp = req.temperature.unwrap_or(1.0);
15829

159-
let generated = self.engine.generate(
160-
&req.prompt_tokens,
161-
req.max_tokens,
162-
req.temperature,
163-
);
30+
let generated = self.engine.generate(tokens, max, temp);
16431

165-
let finish_reason = if generated.len() < req.max_tokens {
32+
let finish_reason = if generated.len() < max {
16633
FinishReason::Stop
16734
} else {
16835
FinishReason::Length
16936
};
17037

171-
let completion_tokens = generated.len();
172-
let prompt_tokens = req.prompt_tokens.len();
173-
174-
CompletionResponse {
175-
id: format!("cmpl-{}", self.request_counter),
176-
model: "gpt2".into(),
177-
choices: vec![CompletionChoice {
38+
let text = generated.iter().map(|t| format!("[{}]", t.token_id)).collect::<String>();
39+
let logprobs: Vec<LogprobInfo> = generated.iter().map(|t| LogprobInfo {
40+
token: format!("{}", t.token_id),
41+
token_id: t.token_id,
42+
logprob: t.logprob,
43+
bytes: None,
44+
top_logprobs: Vec::new(),
45+
}).collect();
46+
47+
let use_logprobs = req.logprobs.is_some();
48+
49+
CompletionResponse::new(
50+
format!("cmpl-{}", self.request_counter),
51+
"gpt2".into(),
52+
vec![CompletionChoice {
17853
index: 0,
179-
tokens: generated,
180-
finish_reason,
54+
text,
55+
logprobs: if use_logprobs { Some(logprobs) } else { None },
56+
finish_reason: Some(finish_reason),
18157
}],
182-
usage: Usage {
183-
prompt_tokens,
184-
completion_tokens,
185-
total_tokens: prompt_tokens + completion_tokens,
58+
Usage {
59+
prompt_tokens: tokens.len(),
60+
completion_tokens: generated.len(),
61+
total_tokens: tokens.len() + generated.len(),
18662
},
187-
}
63+
0,
64+
)
18865
}
18966

190-
/// /v1/embeddings handler — returns wte embeddings for token IDs.
67+
/// `/v1/embeddings`
19168
pub fn embed(&self, req: &EmbeddingRequest) -> EmbeddingResponse {
192-
let mut data = Vec::with_capacity(req.input_tokens.len());
193-
194-
for (idx, &token_id) in req.input_tokens.iter().enumerate() {
195-
let offset = token_id as usize * EMBED_DIM;
196-
let embedding = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec();
197-
data.push(EmbeddingData {
198-
index: idx,
199-
embedding,
200-
});
201-
}
202-
203-
EmbeddingResponse {
204-
model: "gpt2".into(),
69+
let token_ids: Vec<u32> = match &req.input {
70+
EmbeddingInput::TokenIds(ids) => ids.clone(),
71+
_ => req.input_tokens.clone().unwrap_or_default(),
72+
};
73+
74+
let data: Vec<EmbeddingData> = token_ids.iter().enumerate().map(|(idx, &tid)| {
75+
let offset = tid as usize * EMBED_DIM;
76+
let mut emb = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec();
77+
if let Some(dim) = req.dimensions {
78+
emb.truncate(dim);
79+
}
80+
EmbeddingData::new(idx, emb)
81+
}).collect();
82+
83+
EmbeddingResponse::new(
84+
"gpt2".into(),
20585
data,
206-
usage: Usage {
207-
prompt_tokens: req.input_tokens.len(),
208-
completion_tokens: 0,
209-
total_tokens: req.input_tokens.len(),
210-
},
211-
}
86+
Usage { prompt_tokens: token_ids.len(), completion_tokens: 0, total_tokens: token_ids.len() },
87+
)
21288
}
21389

214-
/// /v1/models handler.
215-
pub fn model_info(&self) -> ModelInfo {
216-
ModelInfo::gpt2_small()
90+
/// `/v1/models/{id}`
91+
pub fn model_info() -> Model {
92+
Model::new("gpt2", "adaworldapi", 0)
21793
}
21894

219-
/// Access the underlying engine (for advanced usage).
22095
pub fn engine_mut(&mut self) -> &mut Gpt2Engine {
22196
&mut self.engine
22297
}
@@ -228,25 +103,21 @@ mod tests {
228103

229104
#[test]
230105
fn test_model_info() {
231-
let info = ModelInfo::gpt2_small();
232-
assert_eq!(info.vocab_size, 50257);
233-
assert_eq!(info.embed_dim, 768);
234-
assert_eq!(info.num_layers, 12);
235-
assert_eq!(info.num_heads, 12);
236-
assert_eq!(info.max_seq_len, 1024);
106+
let m = Gpt2Api::model_info();
107+
assert_eq!(m.id, "gpt2");
108+
assert_eq!(m.object, "model");
237109
}
238110

239111
#[test]
240-
fn test_completion_request_default() {
112+
fn test_completion_defaults() {
241113
let req = CompletionRequest::default();
242-
assert_eq!(req.max_tokens, 128);
243-
assert_eq!(req.temperature, 1.0);
244-
assert_eq!(req.stop_token, Some(50256));
114+
assert_eq!(req.model, "gpt2");
115+
assert_eq!(req.max_tokens, Some(128));
245116
}
246117

247118
#[test]
248-
fn test_finish_reason_variants() {
249-
assert_eq!(FinishReason::Stop, FinishReason::Stop);
250-
assert_ne!(FinishReason::Stop, FinishReason::Length);
119+
fn test_completion_response_object() {
120+
let resp = CompletionResponse::new("x".into(), "gpt2".into(), vec![], Usage::default(), 0);
121+
assert_eq!(resp.object, "text_completion");
251122
}
252123
}

0 commit comments

Comments
 (0)