1- //! OpenAI-compatible API types for GPT-2 inference.
1+ //! GPT-2 API — wraps the inference engine with OpenAI-compatible types .
22//!
3- //! Provides request/response structs matching the OpenAI API surface :
3+ //! Endpoints :
44//! - `/v1/completions` — text completion
55//! - `/v1/embeddings` — token embeddings via wte
6- //! - `/v1/models` — model listing
7- //!
8- //! These types are transport-agnostic — they serialize/deserialize
9- //! but don't depend on any HTTP framework.
6+ //! - `/v1/models` — model info
107
8+ use crate :: hpc:: models:: api_types:: * ;
119use super :: inference:: { GeneratedToken , Gpt2Engine } ;
1210use super :: weights:: * ;
1311
14- // ============================================================================
15- // /v1/completions
16- // ============================================================================
17-
18- /// Request body for /v1/completions.
19- #[ derive( Clone , Debug ) ]
20- pub struct CompletionRequest {
21- /// Model name (ignored — we only have gpt2).
22- pub model : String ,
23- /// Input text prompt (will be tokenized externally).
24- pub prompt_tokens : Vec < u32 > ,
25- /// Maximum tokens to generate.
26- pub max_tokens : usize ,
27- /// Sampling temperature (1.0 = greedy effective).
28- pub temperature : f32 ,
29- /// Stop token ID (default: 50256 = <|endoftext|>).
30- pub stop_token : Option < u32 > ,
31- }
32-
33- impl Default for CompletionRequest {
34- fn default ( ) -> Self {
35- Self {
36- model : "gpt2" . into ( ) ,
37- prompt_tokens : Vec :: new ( ) ,
38- max_tokens : 128 ,
39- temperature : 1.0 ,
40- stop_token : Some ( 50256 ) ,
41- }
42- }
43- }
44-
45- /// Single completion choice.
46- #[ derive( Clone , Debug ) ]
47- pub struct CompletionChoice {
48- pub index : usize ,
49- pub tokens : Vec < GeneratedToken > ,
50- pub finish_reason : FinishReason ,
51- }
52-
53- /// Why generation stopped.
54- #[ derive( Clone , Debug , PartialEq , Eq ) ]
55- pub enum FinishReason {
56- Stop ,
57- Length ,
58- }
59-
60- /// Response body for /v1/completions.
61- #[ derive( Clone , Debug ) ]
62- pub struct CompletionResponse {
63- pub id : String ,
64- pub model : String ,
65- pub choices : Vec < CompletionChoice > ,
66- pub usage : Usage ,
67- }
68-
69- /// Token usage statistics.
70- #[ derive( Clone , Debug , Default ) ]
71- pub struct Usage {
72- pub prompt_tokens : usize ,
73- pub completion_tokens : usize ,
74- pub total_tokens : usize ,
75- }
76-
77- // ============================================================================
78- // /v1/embeddings
79- // ============================================================================
80-
81- /// Request body for /v1/embeddings.
82- #[ derive( Clone , Debug ) ]
83- pub struct EmbeddingRequest {
84- pub model : String ,
85- /// Token IDs to embed (one embedding per token).
86- pub input_tokens : Vec < u32 > ,
87- }
88-
89- /// Single embedding result.
90- #[ derive( Clone , Debug ) ]
91- pub struct EmbeddingData {
92- pub index : usize ,
93- pub embedding : Vec < f32 > ,
94- }
95-
96- /// Response body for /v1/embeddings.
97- #[ derive( Clone , Debug ) ]
98- pub struct EmbeddingResponse {
99- pub model : String ,
100- pub data : Vec < EmbeddingData > ,
101- pub usage : Usage ,
102- }
103-
104- // ============================================================================
105- // /v1/models
106- // ============================================================================
107-
108- /// Model info for /v1/models.
109- #[ derive( Clone , Debug ) ]
110- pub struct ModelInfo {
111- pub id : String ,
112- pub owned_by : String ,
113- pub vocab_size : usize ,
114- pub embed_dim : usize ,
115- pub num_layers : usize ,
116- pub num_heads : usize ,
117- pub max_seq_len : usize ,
118- }
119-
120- impl ModelInfo {
121- /// GPT-2 small (124M) model info.
122- pub fn gpt2_small ( ) -> Self {
123- Self {
124- id : "gpt2" . into ( ) ,
125- owned_by : "adaworldapi" . into ( ) ,
126- vocab_size : VOCAB_SIZE ,
127- embed_dim : EMBED_DIM ,
128- num_layers : NUM_LAYERS ,
129- num_heads : NUM_HEADS ,
130- max_seq_len : MAX_SEQ_LEN ,
131- }
132- }
133- }
134-
135- // ============================================================================
136- // Engine wrapper — stateless API over stateful engine
137- // ============================================================================
138-
13912/// Stateless API wrapper around Gpt2Engine.
140- /// Handles request→response conversion.
14113pub struct Gpt2Api {
14214 engine : Gpt2Engine ,
14315 request_counter : u64 ,
14416}
14517
14618impl Gpt2Api {
147- /// Create from pre-loaded weights.
14819 pub fn new ( weights : Gpt2Weights ) -> Self {
149- Self {
150- engine : Gpt2Engine :: new ( weights) ,
151- request_counter : 0 ,
152- }
20+ Self { engine : Gpt2Engine :: new ( weights) , request_counter : 0 }
15321 }
15422
155- /// /v1/completions handler.
23+ /// ` /v1/completions`
15624 pub fn complete ( & mut self , req : & CompletionRequest ) -> CompletionResponse {
15725 self . request_counter += 1 ;
26+ let tokens = req. prompt_tokens . as_deref ( ) . unwrap_or ( & [ ] ) ;
27+ let max = req. max_tokens . unwrap_or ( 128 ) ;
28+ let temp = req. temperature . unwrap_or ( 1.0 ) ;
15829
159- let generated = self . engine . generate (
160- & req. prompt_tokens ,
161- req. max_tokens ,
162- req. temperature ,
163- ) ;
30+ let generated = self . engine . generate ( tokens, max, temp) ;
16431
165- let finish_reason = if generated. len ( ) < req . max_tokens {
32+ let finish_reason = if generated. len ( ) < max {
16633 FinishReason :: Stop
16734 } else {
16835 FinishReason :: Length
16936 } ;
17037
171- let completion_tokens = generated. len ( ) ;
172- let prompt_tokens = req. prompt_tokens . len ( ) ;
173-
174- CompletionResponse {
175- id : format ! ( "cmpl-{}" , self . request_counter) ,
176- model : "gpt2" . into ( ) ,
177- choices : vec ! [ CompletionChoice {
38+ let text = generated. iter ( ) . map ( |t| format ! ( "[{}]" , t. token_id) ) . collect :: < String > ( ) ;
39+ let logprobs: Vec < LogprobInfo > = generated. iter ( ) . map ( |t| LogprobInfo {
40+ token : format ! ( "{}" , t. token_id) ,
41+ token_id : t. token_id ,
42+ logprob : t. logprob ,
43+ bytes : None ,
44+ top_logprobs : Vec :: new ( ) ,
45+ } ) . collect ( ) ;
46+
47+ let use_logprobs = req. logprobs . is_some ( ) ;
48+
49+ CompletionResponse :: new (
50+ format ! ( "cmpl-{}" , self . request_counter) ,
51+ "gpt2" . into ( ) ,
52+ vec ! [ CompletionChoice {
17853 index: 0 ,
179- tokens: generated,
180- finish_reason,
54+ text,
55+ logprobs: if use_logprobs { Some ( logprobs) } else { None } ,
56+ finish_reason: Some ( finish_reason) ,
18157 } ] ,
182- usage : Usage {
183- prompt_tokens,
184- completion_tokens,
185- total_tokens : prompt_tokens + completion_tokens ,
58+ Usage {
59+ prompt_tokens : tokens . len ( ) ,
60+ completion_tokens : generated . len ( ) ,
61+ total_tokens : tokens . len ( ) + generated . len ( ) ,
18662 } ,
187- }
63+ 0 ,
64+ )
18865 }
18966
190- /// /v1/embeddings handler — returns wte embeddings for token IDs.
67+ /// ` /v1/embeddings`
19168 pub fn embed ( & self , req : & EmbeddingRequest ) -> EmbeddingResponse {
192- let mut data = Vec :: with_capacity ( req. input_tokens . len ( ) ) ;
193-
194- for ( idx, & token_id) in req. input_tokens . iter ( ) . enumerate ( ) {
195- let offset = token_id as usize * EMBED_DIM ;
196- let embedding = self . engine . weights ( ) . wte [ offset..offset + EMBED_DIM ] . to_vec ( ) ;
197- data. push ( EmbeddingData {
198- index : idx,
199- embedding,
200- } ) ;
201- }
202-
203- EmbeddingResponse {
204- model : "gpt2" . into ( ) ,
69+ let token_ids: Vec < u32 > = match & req. input {
70+ EmbeddingInput :: TokenIds ( ids) => ids. clone ( ) ,
71+ _ => req. input_tokens . clone ( ) . unwrap_or_default ( ) ,
72+ } ;
73+
74+ let data: Vec < EmbeddingData > = token_ids. iter ( ) . enumerate ( ) . map ( |( idx, & tid) | {
75+ let offset = tid as usize * EMBED_DIM ;
76+ let mut emb = self . engine . weights ( ) . wte [ offset..offset + EMBED_DIM ] . to_vec ( ) ;
77+ if let Some ( dim) = req. dimensions {
78+ emb. truncate ( dim) ;
79+ }
80+ EmbeddingData :: new ( idx, emb)
81+ } ) . collect ( ) ;
82+
83+ EmbeddingResponse :: new (
84+ "gpt2" . into ( ) ,
20585 data,
206- usage : Usage {
207- prompt_tokens : req. input_tokens . len ( ) ,
208- completion_tokens : 0 ,
209- total_tokens : req. input_tokens . len ( ) ,
210- } ,
211- }
86+ Usage { prompt_tokens : token_ids. len ( ) , completion_tokens : 0 , total_tokens : token_ids. len ( ) } ,
87+ )
21288 }
21389
214- /// /v1/models handler.
215- pub fn model_info ( & self ) -> ModelInfo {
216- ModelInfo :: gpt2_small ( )
90+ /// ` /v1/models/{id}`
91+ pub fn model_info ( ) -> Model {
92+ Model :: new ( "gpt2" , "adaworldapi" , 0 )
21793 }
21894
219- /// Access the underlying engine (for advanced usage).
22095 pub fn engine_mut ( & mut self ) -> & mut Gpt2Engine {
22196 & mut self . engine
22297 }
@@ -228,25 +103,21 @@ mod tests {
228103
/// The model description reports the `gpt2` id and the `model` object tag.
#[test]
fn test_model_info() {
    let model = Gpt2Api::model_info();
    assert_eq!(model.id, "gpt2");
    assert_eq!(model.object, "model");
}
238110
/// A default completion request targets `gpt2` with a 128-token budget.
#[test]
fn test_completion_defaults() {
    let defaults = CompletionRequest::default();
    assert_eq!(defaults.model, "gpt2");
    assert_eq!(defaults.max_tokens, Some(128));
}
246117
/// A completion response is tagged with the `text_completion` object type.
#[test]
fn test_completion_response_object() {
    let no_choices = vec![];
    let response =
        CompletionResponse::new("x".into(), "gpt2".into(), no_choices, Usage::default(), 0);
    assert_eq!(response.object, "text_completion");
}
252123}
0 commit comments