 //! Memory usage: O(num_experts_per_tok × expert_size) for the buffer pool,
 //! plus whatever the OS decides to keep in the page cache.
 
-use std::sync::Arc;
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
 
 use candle_core::{DType, Device, Result, Tensor};
 
 use crate::utils::tensor_storage::TensorStorageProvider;
 
 use super::expert_provider::{ExpertProvider, ExpertWeights};
 
+/// Simple bounded cache for dequantized expert weights (insert-only; no eviction).
+/// Avoids re-dequantizing the same popular experts across tokens.
+struct ExpertCache {
+    entries: RwLock<HashMap<usize, ExpertWeights>>,
+    capacity: usize,
+}
+
 /// Pre-computed tensor names for a single expert (avoids format! on hot path).
 struct ExpertNames {
     gate_proj: String,
@@ -86,6 +94,8 @@ pub struct DiskExpertProvider {
     stacked_meta: Option<StackedMeta>,
     /// Pre-computed stacked quantization tensor names.
     stacked_quant_names: Option<StackedQuantNames>,
+    /// Cache of dequantized expert weights (avoids repeated dequantization for popular experts).
+    cache: Option<ExpertCache>,
 }
 
 impl std::fmt::Debug for DiskExpertProvider {
@@ -170,6 +180,7 @@ impl DiskExpertProvider {
             gptq_group_size,
             stacked_meta: None,
             stacked_quant_names: None,
+            cache: None,
         }
     }
 
@@ -272,7 +283,7 @@ impl DiskExpertProvider {
         let storage_dtype = stacked_meta.as_ref().map(|m| m.storage_dtype);
         let use_f32_zerocopy = !is_affine && dtype == DType::F32
             && storage_dtype.is_some_and(|sd| sd == DType::F32);
-        Self {
+        let provider = Self {
             storage,
             layer_prefix,
             expert_names,
@@ -284,7 +295,33 @@ impl DiskExpertProvider {
             gptq_group_size: None,
             stacked_meta,
             stacked_quant_names,
+            // Enable cache for quantized experts — dequantization is expensive
+            cache: if is_affine {
+                Some(ExpertCache {
+                    entries: RwLock::new(HashMap::with_capacity(num_experts)),
+                    capacity: num_experts,
+                })
+            } else {
+                None
+            },
+        };
+
+        // Pre-warm: dequantize all experts at construction (moves cost from first token to loading)
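+        // A failed dequantization here is non-fatal: that expert is simply left out of the
+        // cache and will be loaded on demand by get_expert.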
+        if provider.cache.is_some() {
+            log::info!("pre-warming expert cache for {} experts...", num_experts);
+            for i in 0..num_experts {
+                if let Ok(ew) = provider.get_expert_uncached(i) {
+                    if let Some(ref cache) = provider.cache {
+                        if let Ok(mut entries) = cache.entries.write() {
+                            entries.insert(i, ew);
+                        }
+                    }
+                }
+            }
+            log::info!("expert cache warmed ({} entries)", num_experts);
         }
+
+        provider
     }
 
     /// Reinterpret a `Vec<u8>` as `Vec<f32>` without copying.
@@ -410,6 +447,87 @@ impl ExpertProvider for DiskExpertProvider {
             )));
         }
 
+        // Check cache first — stores CPU-side dequantized tensors to avoid repeated dequant
+        if let Some(ref cache) = self.cache {
+            if let Ok(entries) = cache.entries.read() {
+                if let Some(ew) = entries.get(&idx) {
+                    // Cache hit: transfer CPU tensors to target device
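+                    // (the read guard stays held across these copies, briefly blocking writers)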
+                    return if self.needs_device_transfer {
+                        Ok(ExpertWeights {
+                            gate_proj: ew.gate_proj.to_device(&self.device)?,
+                            up_proj: ew.up_proj.to_device(&self.device)?,
+                            down_proj: ew.down_proj.to_device(&self.device)?,
+                        })
+                    } else {
+                        Ok(ew.clone())
+                    };
+                }
+            }
+        }
+
+        // get_expert_uncached returns CPU tensors when cache is enabled (stacked affine path)
+        let cpu_result = self.get_expert_uncached(idx)?;
+
+        // Store CPU tensors in cache
+        if let Some(ref cache) = self.cache {
+            if let Ok(mut entries) = cache.entries.write() {
+                if entries.len() < cache.capacity {
+                    entries.insert(idx, cpu_result.clone());
+                }
+            }
+        }
+
+        // Transfer to target device — batch gate+up into one PCIe transfer
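+        // (two host-to-device copies per expert: the concatenated gate+up, then down)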
+        if self.needs_device_transfer {
+            let gate_up = Tensor::cat(&[&cpu_result.gate_proj, &cpu_result.up_proj], 0)?;
+            let gate_up_gpu = gate_up.to_device(&self.device)?;
+            let g_rows = cpu_result.gate_proj.dim(0)?;
+            Ok(ExpertWeights {
+                gate_proj: gate_up_gpu.narrow(0, 0, g_rows)?,
+                up_proj: gate_up_gpu.narrow(0, g_rows, g_rows)?,
+                down_proj: cpu_result.down_proj.to_device(&self.device)?,
+            })
+        } else {
+            Ok(cpu_result)
+        }
+    }
+
+    fn num_experts(&self) -> usize {
+        self.num_experts
+    }
+
+    #[cfg(unix)]
+    fn prefetch_experts(&self, indices: &[usize]) {
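+        // Best-effort hint: ask the kernel to page in each projection's mmap region ahead of use.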
+        for &idx in indices {
+            if idx >= self.num_experts {
+                continue;
+            }
+            let names = &self.expert_names[idx];
+            for name in [&names.gate_proj, &names.up_proj, &names.down_proj] {
+                if let Some((bytes, _, _)) = self.storage.tensor_bytes(name) {
+                    unsafe {
+                        libc::posix_madvise(
+                            bytes.as_ptr() as *mut _,
+                            bytes.len(),
+                            libc::POSIX_MADV_WILLNEED,
+                        );
+                    }
+                }
+            }
+        }
+    }
+}
+
+impl DiskExpertProvider {
+    /// Internal: load expert weights without cache lookup.
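+    /// On the quantized paths the tensors come back CPU-resident; get_expert performs
+    /// the device transfer after caching.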
+    fn get_expert_uncached(&self, idx: usize) -> Result<ExpertWeights> {
+        if idx >= self.num_experts {
+            return Err(candle_core::Error::Msg(format!(
+                "expert index {idx} out of range (num_experts={})",
+                self.num_experts
+            )));
+        }
+
         let names = &self.expert_names[idx];
 
         // GPTQ path: read and dequantize each weight individually
@@ -443,8 +561,8 @@ impl ExpertProvider for DiskExpertProvider {
             let scales = Tensor::from_raw_buffer(s_bytes, DType::BF16, &proj.scales_shape, &Device::Cpu)?;
             let biases = Tensor::from_raw_buffer(b_bytes, DType::BF16, &proj.scales_shape, &Device::Cpu)?;
             let weight = crate::utils::gptq::dequantize_packed_4bit(&packed, &scales, &biases, sm.group_size)?;
-            let weight = weight.to_dtype(self.dtype)?;
-            if self.needs_device_transfer { weight.to_device(target_device) } else { Ok(weight) }
+            // Return CPU tensor — device transfer happens in get_expert after caching
+            weight.to_dtype(self.dtype)
         };
         return Ok(ExpertWeights {
             gate_proj: read_dequant(&names.gate_proj, &sm.gate, &qn.gate_scales, &qn.gate_biases)?,
@@ -563,31 +681,6 @@ impl ExpertProvider for DiskExpertProvider {
         })
     }
 
-    fn num_experts(&self) -> usize {
-        self.num_experts
-    }
-
-    #[cfg(unix)]
-    fn prefetch_experts(&self, indices: &[usize]) {
-        for &idx in indices {
-            if idx >= self.num_experts {
-                continue;
-            }
-            let names = &self.expert_names[idx];
-            // Issue madvise(WILLNEED) for each projection's mmap region
-            for name in [&names.gate_proj, &names.up_proj, &names.down_proj] {
-                if let Some((bytes, _, _)) = self.storage.tensor_bytes(name) {
-                    unsafe {
-                        libc::posix_madvise(
-                            bytes.as_ptr() as *mut _,
-                            bytes.len(),
-                            libc::POSIX_MADV_WILLNEED,
-                        );
-                    }
-                }
-            }
-        }
-    }
 }
 
 #[cfg(test)]