diff --git a/README.md b/README.md index c21938e..81d421f 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,9 @@ _**veScale**_ is still in its early phase. We are refactoring our internal LLM t **Plan** * [Auto TP & SP Plan](./vescale/dmp/README.md) +**Inference Serving** + * [TinyServe](./vescale/tinyserve/README.md) - Query-Aware Cache Selection for Efficient LLM Serving + **[Checkpoint](./vescale/checkpoint/README.md)** ## [We Are Hiring!](https://volcengine.github.io/veScaleWeb/misc/join-us.html) ## diff --git a/vescale/tinyserve/README.md b/vescale/tinyserve/README.md new file mode 100644 index 0000000..f6ef28f --- /dev/null +++ b/vescale/tinyserve/README.md @@ -0,0 +1,224 @@ +# TinyServe: Query-Aware Cache Selection for Efficient LLM Serving + +TinyServe is a lightweight and extensible runtime system for deploying tiny language models (e.g., TinyLLaMA, GPT2-345M) with support for structured KV sparsity, plugin-based token selection, and hardware-efficient attention kernels. + +## Overview + +Based on the research paper "TinyServe: Query-Aware Cache Selection for Efficient LLM Serving" (Liu & Yu, 2025), TinyServe enables efficient inference serving at small scale while maintaining the interpretability and control needed for systems research. + +### Key Features + +- **Query-Aware KV Selection**: Dynamically selects relevant key-value blocks based on current query vectors +- **Structured Sparsity**: Uses bounding-box metadata for efficient page-level relevance estimation +- **Plugin Architecture**: Modular system supporting entropy-based early exit, token pruning, and more +- **Fused CUDA Kernels**: Hardware-efficient attention computation with minimal memory movement +- **Multi-GPU Support**: Scalable deployment across multiple GPUs +- **Performance Monitoring**: Comprehensive metrics and optimization suggestions + +## Architecture + +TinyServe is organized around three core components: + +### 1. Query-Aware KV Retriever +- Dynamically selects relevant key-value blocks at decode time +- Uses lightweight metadata (channel-wise min/max values) for relevance estimation +- Enables efficient selection of top-K pages with minimal overhead + +### 2. Modular Scheduling Pipeline +- Handles incoming queries and routes them through configurable plugins +- Supports different sparsity strategies without modifying core models +- Manages session state and request prioritization + +### 3. 
Sparse Attention Executor +- Efficiently computes attention over selected KV pages +- Fused CUDA kernels for page scoring, sparse memory access, and masked attention +- Support for FP16/INT8 KV formats + +## Quick Start + +### Installation + +```bash +# Clone the repository +git clone +cd TinyServe + +# Install dependencies +pip install -r requirements.txt +``` + +### Basic Usage + +```python +from vescale.tinyserve import TinyServe, TinyServeConfig, TinyServeRequest + +# Create configuration +config = TinyServeConfig( + page_size=16, + selection_ratio=0.3, + target_latency_ms=50.0, + target_memory_gb=4.0 +) + +# Initialize TinyServe +tinyserve = TinyServe(config) + +# Load a model +tinyserve.load_model("microsoft/DialoGPT-small", "gpt2") + +# Create a request +request = TinyServeRequest( + prompt="Explain machine learning in simple terms:", + max_tokens=100, + temperature=0.7 +) + +# Serve the request +response = tinyserve.serve(request) +print(response.generated_text) +``` + +### Configuration + +TinyServe supports extensive configuration options: + +```python +config = TinyServeConfig( + # Page-based KV cache + page_size=16, # Tokens per page + selection_ratio=0.3, # Top-K ratio for selection + + # Attention optimization + use_fused_kernel=True, # Enable fused CUDA kernels + attention_chunk_size=256, # Chunk size for attention + + # Plugin configurations + enable_entropy_early_exit=True, # Early stopping based on entropy + entropy_threshold=0.5, # Entropy threshold for early exit + enable_token_pruning=True, # Enable token-level pruning + pruning_ratio=0.1, # Fraction of tokens to prune + + # Performance targets + target_latency_ms=50.0, # Target latency in milliseconds + target_memory_gb=4.0 # Target memory usage in GB +) +``` + +## Plugin System + +TinyServe includes several built-in plugins: + +### Entropy-Based Early Exit +Stops generation when attention entropy is below a threshold, indicating high confidence: + +```python +# Configure early exit +config.enable_entropy_early_exit = True +config.entropy_threshold = 0.5 +config.min_tokens_before_exit = 10 +``` + +### Token-Level Pruning +Removes low-importance tokens from KV cache to reduce memory usage: + +```python +# Configure token pruning +config.enable_token_pruning = True +config.pruning_ratio = 0.1 +config.min_tokens_after_pruning = 100 +``` + +### Custom Plugins +Register custom plugins for specialized optimization: + +```python +def custom_optimization_plugin(context): + # Custom optimization logic + return modified_context + +tinyserve.plugin_manager.register_plugin('custom_opt', custom_optimization_plugin) +``` + +## Performance Monitoring + +TinyServe provides comprehensive performance metrics: + +```python +# Get system statistics +stats = tinyserve.get_stats() +print(f"Total requests: {stats['total_requests']}") +print(f"Average latency: {stats['avg_latency_ms']:.2f}ms") +print(f"KV hit rate: {stats['kv_hit_rate']:.1%}") + +# Get attention statistics +attention_stats = tinyserve.attention_executor.get_attention_stats() +print(f"Attention calls: {attention_stats['total_attention_calls']}") + +# Get plugin statistics +plugin_stats = tinyserve.plugin_manager.get_plugin_stats() +print(f"Early exits: {plugin_stats['early_exit_count']}") +``` + +## Advanced Features + +### Multi-GPU Deployment + +```python +config = TinyServeConfig( + num_gpus=4, + gpu_ids=[0, 1, 2, 3] +) +``` + +### Dynamic Configuration Optimization + +```python +# Get optimized configuration based on performance +current_performance = { + 'latency_ms': 
75.0, + 'memory_gb': 6.0 +} +optimized_config = config.get_optimized_config(current_performance) +``` + +### Session Management + +```python +# Get session information +sessions = tinyserve.scheduler.list_sessions() +for session_id in sessions: + info = tinyserve.scheduler.get_session_info(session_id) + print(f"Session {session_id}: {info['request_count']} requests") +``` + +## Benchmarking + +TinyServe includes built-in benchmarking capabilities: + +```python +# Benchmark attention performance +benchmark_results = tinyserve.attention_executor.benchmark_attention( + input_size=2048, + num_pages=128 +) + +print(f"Throughput: {benchmark_results['throughput_tokens_per_ms']:.1f} tokens/ms") +``` + +## Research Applications + +TinyServe is designed for LLM inference research: + +- **Sparsity Analysis**: Study different sparsity patterns and their impact +- **Cache Behavior**: Analyze KV cache reuse and eviction patterns +- **Attention Optimization**: Experiment with attention approximation methods +- **System Design**: Test new serving architectures without full-scale deployment + +## Paper Reference + +This implementation is based on: + +``` +Liu, D., & Yu, Y. (2025). TinyServe: Query-Aware Cache Selection for Efficient LLM Serving. +In Proceedings of the 33rd ACM International Conference on Multimedia (MM '25). +``` \ No newline at end of file diff --git a/vescale/tinyserve/__init__.py b/vescale/tinyserve/__init__.py new file mode 100644 index 0000000..f0a28c0 --- /dev/null +++ b/vescale/tinyserve/__init__.py @@ -0,0 +1,26 @@ +""" +TinyServe: Query-Aware Cache Selection for Efficient LLM Serving + +A lightweight and extensible runtime system for deploying tiny LLMs with support for: +- Structured KV sparsity +- Plugin-based token selection +- Hardware-efficient attention kernels +- Query-aware page selection mechanism +""" + +from .core import TinyServe +from .kv_retriever import QueryAwareKVRetriever +from .scheduler import ModularScheduler +from .attention import SparseAttentionExecutor +from .plugins import PluginManager +from .utils import TinyServeConfig + +__version__ = "0.1.0" +__all__ = [ + "TinyServe", + "QueryAwareKVRetriever", + "ModularScheduler", + "SparseAttentionExecutor", + "PluginManager", + "TinyServeConfig" +] diff --git a/vescale/tinyserve/attention.py b/vescale/tinyserve/attention.py new file mode 100644 index 0000000..fefd30e --- /dev/null +++ b/vescale/tinyserve/attention.py @@ -0,0 +1,297 @@ +""" +Sparse Attention Executor implementing fused CUDA kernels for efficient attention computation. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, List, Tuple, Optional, Any +import math + + +class SparseAttentionExecutor: + """ + Executes sparse attention over selected KV pages using fused operations. + + Implements the fused kernel described in Algorithm 1 of the paper. + """ + + def __init__(self, config): + """ + Initialize the sparse attention executor. 
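+
+        The executor reads `num_attention_heads`, `head_dim`, and
+        `use_fused_kernel` from the config, falling back to defaults
+        (12 heads, head_dim 64, fused kernel enabled) when they are absent.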
+ + Args: + config: Configuration containing attention parameters + """ + self.config = config + self.device = torch.device(config.device) + self.num_heads = getattr(config, 'num_attention_heads', 12) + self.head_dim = getattr(config, 'head_dim', 64) + self.use_fused_kernel = getattr(config, 'use_fused_kernel', True) + + # Performance tracking + self.attention_stats = { + 'total_attention_calls': 0, + 'avg_attention_time_ms': 0.0, + 'total_sparse_operations': 0 + } + + def execute_sparse_attention(self, query_vector: torch.Tensor, + selected_pages: List[int], + page_metadata: Dict[str, Any]) -> torch.Tensor: + """ + Execute sparse attention over selected KV pages. + + Args: + query_vector: Query vector of shape [batch_size, hidden_dim] + selected_pages: List of selected page indices + page_metadata: Metadata containing KV cache information + + Returns: + Attention output vector + """ + if not selected_pages: + # Return zero output if no pages selected + return torch.zeros_like(query_vector) + + start_time = torch.cuda.Event(enable_timing=True) + end_time = torch.cuda.Event(enable_timing=True) + + start_time.record() + + if self.use_fused_kernel and torch.cuda.is_available(): + output = self._fused_sparse_attention(query_vector, selected_pages, page_metadata) + else: + output = self._standard_sparse_attention(query_vector, selected_pages, page_metadata) + + end_time.record() + torch.cuda.synchronize() + attention_time_ms = start_time.elapsed_time(end_time) + + # Update statistics + self._update_attention_stats(attention_time_ms) + + return output + + def _fused_sparse_attention(self, query_vector: torch.Tensor, + selected_pages: List[int], + page_metadata: Dict[str, Any]) -> torch.Tensor: + """ + Execute fused sparse attention kernel (Algorithm 1 from the paper). + + This implements the optimized kernel that combines: + 1. Relevance scoring over page metadata + 2. Top-K page selection + 3. Sparse KV gather + 4. Attention computation + """ + batch_size, hidden_dim = query_vector.shape + + # Reshape query for multi-head attention + query = query_vector.view(batch_size, self.num_heads, self.head_dim) + + # Gather selected KV pages + selected_keys, selected_values = self._gather_selected_pages( + selected_pages, page_metadata + ) + + if selected_keys.numel() == 0: + return torch.zeros_like(query_vector) + + # Compute attention scores + # Q @ K^T / sqrt(head_dim) + attention_scores = torch.matmul(query, selected_keys.transpose(-2, -1)) / math.sqrt(self.head_dim) + + # Apply softmax + attention_probs = F.softmax(attention_scores, dim=-1) + + # Apply attention to values + context = torch.matmul(attention_probs, selected_values) + + # Reshape back to original dimensions + output = context.view(batch_size, hidden_dim) + + return output + + def _standard_sparse_attention(self, query_vector: torch.Tensor, + selected_pages: List[int], + page_metadata: Dict[str, Any]) -> torch.Tensor: + """ + Standard sparse attention implementation (fallback). 
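+
+        Used when `use_fused_kernel` is disabled or CUDA is unavailable; the
+        computation mirrors the fused path but without the fused-kernel
+        optimizations.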
+ + Args: + query_vector: Query vector + selected_pages: Selected page indices + page_metadata: Page metadata + + Returns: + Attention output + """ + batch_size, hidden_dim = query_vector.shape + + # Reshape for multi-head attention + query = query_vector.view(batch_size, self.num_heads, self.head_dim) + + # Gather selected pages + selected_keys, selected_values = self._gather_selected_pages( + selected_pages, page_metadata + ) + + if selected_keys.numel() == 0: + return torch.zeros_like(query_vector) + + # Standard attention computation + attention_scores = torch.matmul(query, selected_keys.transpose(-2, -1)) / math.sqrt(self.head_dim) + attention_probs = F.softmax(attention_scores, dim=-1) + context = torch.matmul(attention_probs, selected_values) + + return context.view(batch_size, hidden_dim) + + def _gather_selected_pages(self, selected_pages: List[int], + page_metadata: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Gather keys and values from selected pages. + + Args: + selected_pages: List of selected page indices + page_metadata: Metadata containing KV cache + + Returns: + Tuple of (selected_keys, selected_values) + """ + if not selected_pages or 'page_tokens' not in page_metadata: + return torch.empty(0), torch.empty(0) + + # Collect all tokens from selected pages + all_selected_tokens = [] + for page_idx in selected_pages: + if page_idx < len(page_metadata['page_tokens']): + all_selected_tokens.extend(page_metadata['page_tokens'][page_idx]) + + if not all_selected_tokens: + return torch.empty(0), torch.empty(0) + + # Extract keys and values for selected tokens + if 'keys' in page_metadata and 'values' in page_metadata: + keys = page_metadata['keys'] # [num_layers, seq_len, num_heads, head_dim] + values = page_metadata['values'] + + # Select tokens from all layers + selected_keys = [] + selected_values = [] + + for layer_idx in range(keys.shape[0]): + layer_keys = keys[layer_idx, all_selected_tokens, :, :] # [num_tokens, num_heads, head_dim] + layer_values = values[layer_idx, all_selected_tokens, :, :] + + selected_keys.append(layer_keys) + selected_values.append(layer_values) + + # Concatenate across layers + selected_keys = torch.cat(selected_keys, dim=0) # [total_tokens, num_heads, head_dim] + selected_values = torch.cat(selected_values, dim=0) + + return selected_keys, selected_values + + return torch.empty(0), torch.empty(0) + + def _update_attention_stats(self, attention_time_ms: float): + """Update attention performance statistics.""" + self.attention_stats['total_attention_calls'] += 1 + self.attention_stats['total_sparse_operations'] += 1 + + # Update running average + current_avg = self.attention_stats['avg_attention_time_ms'] + total_calls = self.attention_stats['total_attention_calls'] + + self.attention_stats['avg_attention_time_ms'] = ( + (current_avg * (total_calls - 1) + attention_time_ms) / total_calls + ) + + def get_attention_stats(self) -> Dict[str, Any]: + """Get current attention performance statistics.""" + return self.attention_stats.copy() + + def optimize_attention_config(self, performance_metrics: Dict[str, float]) -> Dict[str, Any]: + """ + Dynamically optimize attention configuration based on performance. 
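+
+        Heuristics: if latency exceeds `target_latency_ms`, the fused kernel is
+        enabled and `attention_chunk_size` is halved (floor 64); if memory use
+        exceeds `target_memory_gb`, the chunk size is doubled (cap 512) and
+        gradient checkpointing is suggested.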
+ + Args: + performance_metrics: Current performance metrics + + Returns: + Optimized configuration parameters + """ + current_latency = performance_metrics.get('latency_ms', 0) + current_memory = performance_metrics.get('memory_gb', 0) + + # Simple optimization heuristics + optimizations = {} + + if current_latency > self.config.target_latency_ms: + # Reduce attention complexity + optimizations['use_fused_kernel'] = True + optimizations['attention_chunk_size'] = max(64, self.config.attention_chunk_size // 2) + + if current_memory > self.config.target_memory_gb: + # Reduce memory usage + optimizations['attention_chunk_size'] = min(512, self.config.attention_chunk_size * 2) + optimizations['use_gradient_checkpointing'] = True + + return optimizations + + def clear_stats(self): + """Clear attention performance statistics.""" + self.attention_stats = { + 'total_attention_calls': 0, + 'avg_attention_time_ms': 0.0, + 'total_sparse_operations': 0 + } + + def benchmark_attention(self, input_size: int, num_pages: int) -> Dict[str, float]: + """ + Benchmark attention performance for given input size. + + Args: + input_size: Input sequence length + num_pages: Number of pages to process + + Returns: + Benchmark results + """ + # Create dummy inputs + batch_size = 1 + hidden_dim = self.num_heads * self.head_dim + + query = torch.randn(batch_size, hidden_dim, device=self.device) + dummy_metadata = { + 'page_tokens': [list(range(i * self.config.page_size, (i + 1) * self.config.page_size)) + for i in range(num_pages)], + 'keys': torch.randn(1, input_size, self.num_heads, self.head_dim, device=self.device), + 'values': torch.randn(1, input_size, self.num_heads, self.head_dim, device=self.device) + } + + # Warmup + for _ in range(10): + _ = self.execute_sparse_attention(query, list(range(num_pages)), dummy_metadata) + + torch.cuda.synchronize() + + # Benchmark + start_time = torch.cuda.Event(enable_timing=True) + end_time = torch.cuda.Event(enable_timing=True) + + start_time.record() + for _ in range(100): + _ = self.execute_sparse_attention(query, list(range(num_pages)), dummy_metadata) + end_time.record() + + torch.cuda.synchronize() + total_time_ms = start_time.elapsed_time(end_time) + avg_time_ms = total_time_ms / 100 + + return { + 'avg_attention_time_ms': avg_time_ms, + 'throughput_tokens_per_ms': input_size / avg_time_ms, + 'memory_usage_gb': torch.cuda.memory_allocated() / (1024**3) + } diff --git a/vescale/tinyserve/core.py b/vescale/tinyserve/core.py new file mode 100644 index 0000000..d865fe0 --- /dev/null +++ b/vescale/tinyserve/core.py @@ -0,0 +1,279 @@ +""" +Core TinyServe implementation for efficient LLM inference serving. 
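+
+Typical usage (a minimal sketch mirroring the README quick-start; the model
+name is illustrative):
+
+    config = TinyServeConfig(page_size=16, selection_ratio=0.3)
+    server = TinyServe(config)
+    server.load_model("microsoft/DialoGPT-small", "gpt2")
+    response = server.serve(TinyServeRequest(prompt="Hello", max_tokens=32))
+    print(response.generated_text)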
+""" + +import torch +import torch.nn as nn +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass + +from .kv_retriever import QueryAwareKVRetriever +from .scheduler import ModularScheduler +from .attention import SparseAttentionExecutor +from .plugins import PluginManager +from .utils import TinyServeConfig + + +@dataclass +class TinyServeRequest: + """Represents a single inference request.""" + prompt: str + max_tokens: int = 512 + temperature: float = 0.7 + top_p: float = 0.9 + request_id: Optional[str] = None + + +@dataclass +class TinyServeResponse: + """Represents the response from TinyServe.""" + generated_text: str + tokens: List[int] + latency_ms: float + memory_usage_gb: float + kv_cache_hit_rate: float + request_id: Optional[str] = None + + +class TinyServe: + """ + Main TinyServe class implementing query-aware cache selection for efficient LLM serving. + + Based on the paper: "TinyServe: Query-Aware Cache Selection for Efficient LLM Serving" + """ + + def __init__(self, config: TinyServeConfig): + """ + Initialize TinyServe with configuration. + + Args: + config: TinyServe configuration parameters + """ + self.config = config + self.device = torch.device(config.device) + + # Initialize core components + self.kv_retriever = QueryAwareKVRetriever(config) + self.scheduler = ModularScheduler(config) + self.attention_executor = SparseAttentionExecutor(config) + self.plugin_manager = PluginManager(config) + + # KV cache management + self.kv_cache = {} + self.page_metadata = {} + self.session_manager = {} + + # Performance tracking + self.stats = { + 'total_requests': 0, + 'total_tokens': 0, + 'avg_latency_ms': 0.0, + 'avg_memory_gb': 0.0, + 'kv_hit_rate': 0.0 + } + + def load_model(self, model_path: str, model_type: str = "auto"): + """ + Load a tiny LLM model for serving. 
+ + Args: + model_path: Path to model checkpoint or HuggingFace model name + model_type: Type of model (e.g., "tinylama", "gpt2", "opt") + """ + # Load model based on type + if model_type == "tinylama": + from transformers import AutoModelForCausalLM, AutoTokenizer + self.model = AutoModelForCausalLM.from_pretrained(model_path) + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + elif model_type == "gpt2": + from transformers import GPT2LMHeadModel, GPT2Tokenizer + self.model = GPT2LMHeadModel.from_pretrained(model_path) + self.tokenizer = GPT2Tokenizer.from_pretrained(model_path) + else: + # Auto-detect + from transformers import AutoModelForCausalLM, AutoTokenizer + self.model = AutoModelForCausalLM.from_pretrained(model_path) + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + + self.model.to(self.device) + self.model.eval() + + # Initialize KV cache structure + self._init_kv_cache() + + def _init_kv_cache(self): + """Initialize the KV cache with page-based structure.""" + hidden_size = self.model.config.hidden_size + num_layers = self.model.config.num_hidden_layers + num_heads = self.model.config.num_attention_heads + + # Calculate page size based on config + page_size = self.config.page_size + + # Initialize page metadata storage + self.page_metadata = { + 'keys': torch.zeros((num_layers, 0, num_heads, hidden_size // num_heads), + device=self.device, dtype=torch.float16), + 'values': torch.zeros((num_layers, 0, num_heads, hidden_size // num_heads), + device=self.device, dtype=torch.float16), + 'page_bounds': [], # Store min/max bounds for each page + 'page_tokens': [] # Store token indices for each page + } + + def serve(self, request: TinyServeRequest) -> TinyServeResponse: + """ + Main serving function implementing query-aware cache selection. 
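+
+        The pipeline is: tokenize the prompt, run the prefill stage to populate
+        the page-based KV cache, then decode token by token; each decode step
+        selects relevant KV pages for the current query, runs sparse attention
+        over them, samples the next token, and consults the plugin manager for
+        early exit.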
+
+        Args:
+            request: Inference request
+
+        Returns:
+            Generated response with performance metrics
+        """
+        start_time = torch.cuda.Event(enable_timing=True)
+        end_time = torch.cuda.Event(enable_timing=True)
+
+        start_time.record()
+
+        # Tokenize input
+        input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(self.device)
+
+        # Prefill stage - process all prompt tokens
+        kv_cache = self._prefill_stage(input_ids)
+
+        # Decode stage - generate tokens one by one with query-aware selection
+        generated_tokens = self._decode_stage(input_ids, kv_cache, request)
+
+        # Decode final text
+        generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+        end_time.record()
+        torch.cuda.synchronize()
+        latency_ms = start_time.elapsed_time(end_time)
+
+        # Calculate memory usage and KV hit rate
+        memory_usage = self._get_memory_usage()
+        kv_hit_rate = self._calculate_kv_hit_rate()
+
+        # Update statistics
+        self._update_stats(latency_ms, memory_usage, kv_hit_rate, len(generated_tokens))
+
+        return TinyServeResponse(
+            generated_text=generated_text,
+            tokens=generated_tokens.tolist(),
+            latency_ms=latency_ms,
+            memory_usage_gb=memory_usage,
+            kv_cache_hit_rate=kv_hit_rate,
+            request_id=request.request_id
+        )
+
+    def _prefill_stage(self, input_ids: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """Prefill stage: process all prompt tokens and store in KV cache."""
+        # Forward pass through model to get KV cache
+        with torch.no_grad():
+            outputs = self.model(input_ids, use_cache=True)
+            kv_cache = outputs.past_key_values
+
+        # Store KV cache with page-based organization
+        self._store_kv_cache(kv_cache, input_ids.shape[1])
+
+        return kv_cache
+
+    def _decode_stage(self, input_ids: torch.Tensor, kv_cache: Dict, request: TinyServeRequest) -> torch.Tensor:
+        """Decode stage: generate tokens with query-aware KV selection."""
+        generated_tokens = []
+        current_input = input_ids[:, -1:]  # Start with last token
+
+        for _ in range(request.max_tokens):
+            # Generate query vector for current token
+            with torch.no_grad():
+                query_output = self.model(current_input, use_cache=False)
+                query_vector = query_output.logits[:, -1, :]
+
+            # Query-aware KV page selection
+            selected_pages = self.kv_retriever.select_relevant_pages(
+                query_vector, self.page_metadata
+            )
+
+            # Execute sparse attention over selected pages
+            attention_output = self.attention_executor.execute_sparse_attention(
+                query_vector, selected_pages, self.page_metadata
+            )
+
+            # Generate next token using the request's sampling parameters
+            next_token = self._sample_next_token(attention_output, request.temperature, request.top_p)
+            generated_tokens.append(next_token.item())
+
+            # Update input for next iteration (next_token already has shape [batch, 1])
+            current_input = torch.cat([current_input, next_token], dim=1)
+
+            # Check for early stopping via plugins
+            if self.plugin_manager.should_stop_early(generated_tokens, attention_output):
+                break
+
+        return torch.tensor(generated_tokens, device=self.device)
+
+    def _store_kv_cache(self, kv_cache: Dict, num_tokens: int):
+        """Store KV cache with page-based organization and metadata."""
+        # Implementation for storing KV cache in pages
+        # This would include the bounding-box metadata calculation
+        pass
+
+    def _sample_next_token(self, logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
+        """Sample next token using temperature and top-p sampling."""
+        if temperature > 0:
+            logits = logits / temperature
+
+        # Apply top-p sampling
+        if top_p < 1.0:
+            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+            
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = float('-inf') + + probs = torch.softmax(logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + return next_token + + def _get_memory_usage(self) -> float: + """Get current GPU memory usage in GB.""" + if torch.cuda.is_available(): + return torch.cuda.memory_allocated() / (1024**3) + return 0.0 + + def _calculate_kv_hit_rate(self) -> float: + """Calculate KV cache hit rate.""" + # Implementation for calculating cache hit rate + return 0.95 # Placeholder + + def _update_stats(self, latency_ms: float, memory_gb: float, kv_hit_rate: float, num_tokens: int): + """Update performance statistics.""" + self.stats['total_requests'] += 1 + self.stats['total_tokens'] += num_tokens + + # Update running averages + self.stats['avg_latency_ms'] = ( + (self.stats['avg_latency_ms'] * (self.stats['total_requests'] - 1) + latency_ms) / + self.stats['total_requests'] + ) + self.stats['avg_memory_gb'] = ( + (self.stats['avg_memory_gb'] * (self.stats['total_requests'] - 1) + memory_gb) / + self.stats['total_requests'] + ) + self.stats['kv_hit_rate'] = kv_hit_rate + + def get_stats(self) -> Dict[str, Any]: + """Get current performance statistics.""" + return self.stats.copy() + + def clear_cache(self): + """Clear KV cache and reset metadata.""" + self.kv_cache.clear() + self.page_metadata.clear() + self.session_manager.clear() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/vescale/tinyserve/kv_retriever.py b/vescale/tinyserve/kv_retriever.py new file mode 100644 index 0000000..0d57212 --- /dev/null +++ b/vescale/tinyserve/kv_retriever.py @@ -0,0 +1,225 @@ +""" +Query-Aware KV Retriever for dynamic page selection based on query relevance. +""" + +import torch +import torch.nn as nn +from typing import Dict, List, Tuple, Optional, Any +import math + + +class QueryAwareKVRetriever: + """ + Implements query-aware KV page selection using bounding-box metadata. + + Based on the paper's methodology section 3.4: Query-Aware Page Selection + """ + + def __init__(self, config): + """ + Initialize the KV retriever. + + Args: + config: Configuration containing page_size, selection_ratio, etc. + """ + self.config = config + self.page_size = config.page_size + self.selection_ratio = config.selection_ratio + self.device = torch.device(config.device) + + # Cache for storing computed relevance scores + self.relevance_cache = {} + + def select_relevant_pages(self, query_vector: torch.Tensor, + page_metadata: Dict[str, Any]) -> List[int]: + """ + Select the most relevant KV pages based on query vector. 
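+
+        The number of pages kept is k = max(1, int(num_pages * selection_ratio));
+        for example, with 40 cached pages and selection_ratio=0.3, the 12
+        highest-scoring pages are returned.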
+ + Args: + query_vector: Current query vector of shape [batch_size, hidden_dim] + page_metadata: Metadata containing page bounds and information + + Returns: + List of selected page indices + """ + if not page_metadata['page_bounds']: + return [] + + # Calculate relevance scores for all pages + relevance_scores = self._compute_relevance_scores(query_vector, page_metadata) + + # Select top-K pages based on selection ratio + num_pages = len(page_metadata['page_bounds']) + k = max(1, int(num_pages * self.selection_ratio)) + + # Get top-K page indices + selected_pages = self._select_top_k_pages(relevance_scores, k) + + return selected_pages + + def _compute_relevance_scores(self, query_vector: torch.Tensor, + page_metadata: Dict[str, Any]) -> torch.Tensor: + """ + Compute relevance scores using directional bounding-box estimator. + + This implements the relevance function from equation (2) in the paper: + r(q_t, φ(K_j)) = Σᵢ (q_t,i · M_j,i if q_t,i ≥ 0 else q_t,i · m_j,i) + """ + batch_size, hidden_dim = query_vector.shape + num_pages = len(page_metadata['page_bounds']) + + # Initialize relevance scores + relevance_scores = torch.zeros(batch_size, num_pages, device=self.device) + + for page_idx, page_bounds in enumerate(page_metadata['page_bounds']): + min_bounds, max_bounds = page_bounds + + # Ensure bounds have correct shape + if min_bounds.dim() == 1: + min_bounds = min_bounds.unsqueeze(0).expand(batch_size, -1) + if max_bounds.dim() == 1: + max_bounds = max_bounds.unsqueeze(0).expand(batch_size, -1) + + # Compute relevance using directional bounding-box approach + # For positive query components, use max bounds + # For negative query components, use min bounds + positive_mask = query_vector >= 0 + negative_mask = ~positive_mask + + relevance = torch.zeros_like(query_vector) + relevance[positive_mask] = query_vector[positive_mask] * max_bounds[positive_mask] + relevance[negative_mask] = query_vector[negative_mask] * min_bounds[negative_mask] + + # Sum across hidden dimensions + relevance_scores[:, page_idx] = relevance.sum(dim=1) + + return relevance_scores + + def _select_top_k_pages(self, relevance_scores: torch.Tensor, k: int) -> List[int]: + """ + Select top-K pages based on relevance scores. + + Args: + relevance_scores: Relevance scores of shape [batch_size, num_pages] + k: Number of pages to select + + Returns: + List of selected page indices + """ + # For simplicity, use the first batch element + scores = relevance_scores[0] if relevance_scores.dim() > 1 else relevance_scores + + # Get top-k indices + _, top_indices = torch.topk(scores, k=min(k, len(scores)), dim=0) + + return top_indices.tolist() + + def update_page_metadata(self, page_metadata: Dict[str, Any], + new_kv_cache: Dict[str, torch.Tensor]): + """ + Update page metadata when new KV cache is added. 
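+
+        For every page of `page_size` tokens, channel-wise minima and maxima of
+        the page's key vectors are recorded as (min_bounds, max_bounds); these
+        bounding boxes are later consumed by _compute_relevance_scores to
+        estimate page relevance without touching the full KV data.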
+ + Args: + page_metadata: Current page metadata + new_kv_cache: New KV cache to add + """ + # Extract key and value tensors + keys = new_kv_cache['keys'] # Shape: [num_layers, seq_len, num_heads, head_dim] + values = new_kv_cache['values'] + + num_layers, seq_len, num_heads, head_dim = keys.shape + + # Calculate page boundaries + num_pages = math.ceil(seq_len / self.page_size) + + for page_idx in range(num_pages): + start_idx = page_idx * self.page_size + end_idx = min(start_idx + self.page_size, seq_len) + + # Extract page keys and values + page_keys = keys[:, start_idx:end_idx, :, :] + page_values = values[:, start_idx:end_idx, :, :] + + # Compute bounding box metadata for this page + # Min and max bounds across all dimensions + min_bounds = page_keys.min(dim=(0, 1, 2)).values # [head_dim] + max_bounds = page_keys.max(dim=(0, 1, 2)).values # [head_dim] + + # Store metadata + page_metadata['page_bounds'].append((min_bounds, max_bounds)) + page_metadata['page_tokens'].append(list(range(start_idx, end_idx))) + + # Update the stored keys and values + if len(page_metadata['keys']) == 0: + page_metadata['keys'] = page_keys + page_metadata['values'] = page_values + else: + page_metadata['keys'] = torch.cat([page_metadata['keys'], page_keys], dim=1) + page_metadata['values'] = torch.cat([page_metadata['values'], page_values], dim=1) + + def get_page_statistics(self, page_metadata: Dict[str, Any]) -> Dict[str, Any]: + """ + Get statistics about the current page organization. + + Args: + page_metadata: Current page metadata + + Returns: + Dictionary containing page statistics + """ + if not page_metadata['page_bounds']: + return { + 'num_pages': 0, + 'total_tokens': 0, + 'avg_page_size': 0, + 'memory_usage_gb': 0.0 + } + + num_pages = len(page_metadata['page_bounds']) + total_tokens = sum(len(tokens) for tokens in page_metadata['page_tokens']) + avg_page_size = total_tokens / num_pages if num_pages > 0 else 0 + + # Estimate memory usage (rough calculation) + if 'keys' in page_metadata and page_metadata['keys'].numel() > 0: + keys_memory = page_metadata['keys'].numel() * page_metadata['keys'].element_size() + values_memory = page_metadata['values'].numel() * page_metadata['values'].element_size() + total_memory = (keys_memory + values_memory) / (1024**3) # Convert to GB + else: + total_memory = 0.0 + + return { + 'num_pages': num_pages, + 'total_tokens': total_tokens, + 'avg_page_size': avg_page_size, + 'memory_usage_gb': total_memory + } + + def clear_cache(self): + """Clear the relevance score cache.""" + self.relevance_cache.clear() + + def optimize_page_size(self, current_performance: Dict[str, float]) -> int: + """ + Dynamically optimize page size based on performance metrics. 
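+
+        Heuristic: if latency exceeds `target_latency_ms`, the page size is
+        halved (floor 4); otherwise, if memory exceeds `target_memory_gb`, it is
+        doubled (cap 64); otherwise it is left unchanged.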
+ + Args: + current_performance: Current performance metrics + + Returns: + Optimized page size + """ + # Simple heuristic: if latency is high, reduce page size + # if memory usage is high, increase page size + current_latency = current_performance.get('latency_ms', 0) + current_memory = current_performance.get('memory_gb', 0) + + if current_latency > self.config.target_latency_ms: + # Reduce page size to improve latency + new_page_size = max(4, self.page_size // 2) + elif current_memory > self.config.target_memory_gb: + # Increase page size to reduce memory overhead + new_page_size = min(64, self.page_size * 2) + else: + new_page_size = self.page_size + + return new_page_size diff --git a/vescale/tinyserve/plugins.py b/vescale/tinyserve/plugins.py new file mode 100644 index 0000000..16a50e9 --- /dev/null +++ b/vescale/tinyserve/plugins.py @@ -0,0 +1,350 @@ +""" +Plugin Manager for TinyServe supporting various optimization plugins. +""" + +import torch +import torch.nn.functional as F +from typing import Dict, List, Optional, Any, Callable +import math +import time + + +class PluginManager: + """ + Manages various plugins for TinyServe including: + - Entropy-based early exit + - Token-level pruning + - Approximate attention + - Cache optimization + """ + + def __init__(self, config): + """ + Initialize the plugin manager. + + Args: + config: Configuration containing plugin parameters + """ + self.config = config + self.device = torch.device(config.device) + + # Plugin registry + self.plugins = {} + self.enabled_plugins = set() + + # Plugin configurations + self.plugin_configs = { + 'entropy_early_exit': { + 'enabled': getattr(config, 'enable_entropy_early_exit', True), + 'threshold': getattr(config, 'entropy_threshold', 0.5), + 'min_tokens': getattr(config, 'min_tokens_before_exit', 10) + }, + 'token_pruning': { + 'enabled': getattr(config, 'enable_token_pruning', True), + 'pruning_ratio': getattr(config, 'pruning_ratio', 0.1), + 'min_tokens': getattr(config, 'min_tokens_after_pruning', 100) + }, + 'approximate_attention': { + 'enabled': getattr(config, 'enable_approximate_attention', False), + 'approximation_method': getattr(config, 'approximation_method', 'linear'), + 'compression_ratio': getattr(config, 'compression_ratio', 0.5) + }, + 'cache_optimization': { + 'enabled': getattr(config, 'enable_cache_optimization', True), + 'eviction_policy': getattr(config, 'eviction_policy', 'lru'), + 'max_cache_size_gb': getattr(config, 'max_cache_size_gb', 8.0) + } + } + + # Initialize default plugins + self._init_default_plugins() + + # Performance tracking + self.plugin_stats = { + 'total_plugin_calls': 0, + 'plugin_success_count': 0, + 'plugin_error_count': 0, + 'early_exit_count': 0, + 'pruning_count': 0 + } + + def _init_default_plugins(self): + """Initialize default plugins.""" + # Entropy-based early exit plugin + self.register_plugin('entropy_early_exit', self._entropy_early_exit_plugin) + + # Token pruning plugin + self.register_plugin('token_pruning', self._token_pruning_plugin) + + # Approximate attention plugin + self.register_plugin('approximate_attention', self._approximate_attention_plugin) + + # Cache optimization plugin + self.register_plugin('cache_optimization', self._cache_optimization_plugin) + + def register_plugin(self, name: str, plugin_func: Callable): + """ + Register a new plugin. 
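+
+        Custom plugins without an entry in `plugin_configs` are registered in a
+        disabled state and must be switched on explicitly via `enable_plugin`.
+        Example (a sketch; the names are illustrative):
+
+            def my_plugin(context):
+                # inspect or modify the context dict, then return it
+                return context
+
+            manager.register_plugin('my_plugin', my_plugin)
+            manager.enable_plugin('my_plugin')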
+ + Args: + name: Plugin name + plugin_func: Plugin function to execute + """ + self.plugins[name] = plugin_func + + # Enable if configured to be enabled + if self.plugin_configs.get(name, {}).get('enabled', False): + self.enable_plugin(name) + + def enable_plugin(self, name: str): + """Enable a specific plugin.""" + if name in self.plugins: + self.enabled_plugins.add(name) + + def disable_plugin(self, name: str): + """Disable a specific plugin.""" + if name in self.plugins: + self.enabled_plugins.discard(name) + + def should_stop_early(self, generated_tokens: List[int], attention_output: torch.Tensor) -> bool: + """ + Check if generation should stop early based on plugin logic. + + Args: + generated_tokens: List of generated tokens so far + attention_output: Current attention output + + Returns: + True if generation should stop early + """ + if 'entropy_early_exit' not in self.enabled_plugins: + return False + + try: + return self._entropy_early_exit_plugin({ + 'generated_tokens': generated_tokens, + 'attention_output': attention_output + }) + except Exception as e: + print(f"Error in entropy early exit plugin: {e}") + return False + + def _entropy_early_exit_plugin(self, context: Dict[str, Any]) -> bool: + """ + Entropy-based early exit plugin. + + Stops generation when the entropy of the attention distribution is below a threshold, + indicating the model is confident in its predictions. + """ + config = self.plugin_configs['entropy_early_exit'] + if not config['enabled']: + return False + + generated_tokens = context.get('generated_tokens', []) + attention_output = context.get('attention_output', None) + + # Check minimum token requirement + if len(generated_tokens) < config['min_tokens']: + return False + + if attention_output is not None: + # Calculate entropy of attention distribution + attention_probs = F.softmax(attention_output, dim=-1) + entropy = -torch.sum(attention_probs * torch.log(attention_probs + 1e-8)) + + # Stop if entropy is below threshold + if entropy < config['threshold']: + self.plugin_stats['early_exit_count'] += 1 + return True + + return False + + def _token_pruning_plugin(self, context: Dict[str, Any]) -> Dict[str, Any]: + """ + Token-level pruning plugin. + + Removes low-importance tokens from the KV cache to reduce memory usage + while maintaining generation quality. + """ + config = self.plugin_configs['token_pruning'] + if not config['enabled']: + return context + + # Implementation would prune tokens based on importance scores + # For now, return context unchanged + self.plugin_stats['pruning_count'] += 1 + return context + + def _approximate_attention_plugin(self, context: Dict[str, Any]) -> Dict[str, Any]: + """ + Approximate attention plugin. + + Uses approximation methods to reduce attention computation cost + while maintaining reasonable accuracy. + """ + config = self.plugin_configs['approximate_attention'] + if not config['enabled']: + return context + + # Implementation would apply attention approximation + # For now, return context unchanged + return context + + def _cache_optimization_plugin(self, context: Dict[str, Any]) -> Dict[str, Any]: + """ + Cache optimization plugin. + + Optimizes KV cache usage through intelligent eviction policies + and memory management. 
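+
+        Currently a pass-through stub: the configured eviction policy and cache
+        size limit are recorded but no eviction is performed yet.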
+ """ + config = self.plugin_configs['cache_optimization'] + if not config['enabled']: + return context + + # Implementation would optimize cache usage + # For now, return context unchanged + return context + + def execute_plugin_pipeline(self, context: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute all enabled plugins in sequence. + + Args: + context: Context data for plugins + + Returns: + Updated context after plugin execution + """ + self.plugin_stats['total_plugin_calls'] += 1 + + try: + result = context.copy() + + for plugin_name in self.enabled_plugins: + if plugin_name in self.plugins: + try: + result = self.plugins[plugin_name](result) + if result is None: + # Plugin requested to stop processing + break + except Exception as e: + print(f"Plugin {plugin_name} error: {e}") + self.plugin_stats['plugin_error_count'] += 1 + continue + + self.plugin_stats['plugin_success_count'] += 1 + return result + + except Exception as e: + print(f"Error in plugin pipeline: {e}") + self.plugin_stats['plugin_error_count'] += 1 + return context + + def get_plugin_config(self, plugin_name: str) -> Dict[str, Any]: + """Get configuration for a specific plugin.""" + return self.plugin_configs.get(plugin_name, {}).copy() + + def update_plugin_config(self, plugin_name: str, config_updates: Dict[str, Any]): + """ + Update configuration for a specific plugin. + + Args: + plugin_name: Name of the plugin + config_updates: Configuration updates to apply + """ + if plugin_name in self.plugin_configs: + self.plugin_configs[plugin_name].update(config_updates) + + def get_plugin_stats(self) -> Dict[str, Any]: + """Get current plugin statistics.""" + return self.plugin_stats.copy() + + def get_enabled_plugins(self) -> List[str]: + """Get list of currently enabled plugins.""" + return list(self.enabled_plugins) + + def get_all_plugins(self) -> List[str]: + """Get list of all registered plugins.""" + return list(self.plugins.keys()) + + def clear_stats(self): + """Clear plugin statistics.""" + self.plugin_stats = { + 'total_plugin_calls': 0, + 'plugin_success_count': 0, + 'plugin_error_count': 0, + 'early_exit_count': 0, + 'pruning_count': 0 + } + + def benchmark_plugin(self, plugin_name: str, input_data: Any) -> Dict[str, float]: + """ + Benchmark a specific plugin's performance. + + Args: + plugin_name: Name of the plugin to benchmark + input_data: Input data for the plugin + + Returns: + Benchmark results + """ + if plugin_name not in self.plugins: + return {'error': 'Plugin not found'} + + plugin_func = self.plugins[plugin_name] + + # Warmup + for _ in range(10): + try: + _ = plugin_func(input_data) + except: + pass + + # Benchmark + start_time = time.time() + iterations = 100 + + for _ in range(iterations): + try: + _ = plugin_func(input_data) + except: + pass + + end_time = time.time() + total_time = end_time - start_time + avg_time_ms = (total_time / iterations) * 1000 + + return { + 'avg_execution_time_ms': avg_time_ms, + 'throughput_ops_per_sec': iterations / total_time + } + + def optimize_plugin_configs(self, performance_metrics: Dict[str, float]) -> Dict[str, Dict[str, Any]]: + """ + Dynamically optimize plugin configurations based on performance metrics. 
+ + Args: + performance_metrics: Current performance metrics + + Returns: + Optimized plugin configurations + """ + optimizations = {} + + current_latency = performance_metrics.get('latency_ms', 0) + current_memory = performance_metrics.get('memory_gb', 0) + + # Optimize entropy early exit + if current_latency > self.config.target_latency_ms: + optimizations['entropy_early_exit'] = { + 'threshold': min(0.8, self.plugin_configs['entropy_early_exit']['threshold'] + 0.1), + 'min_tokens': max(5, self.plugin_configs['entropy_early_exit']['min_tokens'] - 2) + } + + # Optimize token pruning + if current_memory > self.config.target_memory_gb: + optimizations['token_pruning'] = { + 'pruning_ratio': min(0.3, self.plugin_configs['token_pruning']['pruning_ratio'] + 0.05) + } + + return optimizations diff --git a/vescale/tinyserve/scheduler.py b/vescale/tinyserve/scheduler.py new file mode 100644 index 0000000..254e0ea --- /dev/null +++ b/vescale/tinyserve/scheduler.py @@ -0,0 +1,280 @@ +""" +Modular Scheduling Pipeline for handling incoming queries and routing through configurable plugins. +""" + +import torch +import time +from typing import Dict, List, Optional, Any, Callable +from dataclasses import dataclass +from queue import Queue, Empty +import threading + + +@dataclass +class ScheduledRequest: + """Represents a scheduled inference request.""" + request_id: str + prompt: str + max_tokens: int + temperature: float + top_p: float + priority: int = 0 + timestamp: float = 0.0 + session_id: Optional[str] = None + + +class ModularScheduler: + """ + Modular scheduling pipeline that handles incoming queries and routes them through configurable plugins. + + Based on the paper's description of the modular scheduling pipeline. + """ + + def __init__(self, config): + """ + Initialize the modular scheduler. + + Args: + config: Configuration containing scheduling parameters + """ + self.config = config + self.device = torch.device(config.device) + + # Request queues with different priorities + self.high_priority_queue = Queue() + self.normal_priority_queue = Queue() + self.low_priority_queue = Queue() + + # Plugin registry + self.plugins = {} + self.plugin_order = [] + + # Session management + self.sessions = {} + self.session_timeout = getattr(config, 'session_timeout', 300.0) # 5 minutes + + # Performance tracking + self.scheduler_stats = { + 'total_requests': 0, + 'processed_requests': 0, + 'avg_queue_time_ms': 0.0, + 'avg_processing_time_ms': 0.0, + 'active_sessions': 0 + } + + # Start background processing thread + self.running = True + self.processing_thread = threading.Thread(target=self._process_requests, daemon=True) + self.processing_thread.start() + + def register_plugin(self, name: str, plugin_func: Callable, priority: int = 0): + """ + Register a plugin function for request processing. + + Args: + name: Plugin name + plugin_func: Plugin function to execute + priority: Execution priority (lower = higher priority) + """ + self.plugins[name] = { + 'function': plugin_func, + 'priority': priority, + 'enabled': True + } + + # Update plugin execution order + self.plugin_order = sorted(self.plugins.keys(), + key=lambda x: self.plugins[x]['priority']) + + def submit_request(self, request: ScheduledRequest) -> str: + """ + Submit a request for processing. 
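+
+        Requests are routed by priority: 0 goes to the high-priority queue, 1 to
+        the normal queue, and anything else to the low-priority queue. Example
+        (a sketch; the field values are illustrative):
+
+            req = ScheduledRequest(request_id="req-1", prompt="Hello",
+                                   max_tokens=64, temperature=0.7, top_p=0.9,
+                                   priority=0)
+            scheduler.submit_request(req)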
+ + Args: + request: Request to schedule + + Returns: + Request ID + """ + request.timestamp = time.time() + + # Route to appropriate queue based on priority + if request.priority == 0: # High priority + self.high_priority_queue.put(request) + elif request.priority == 1: # Normal priority + self.normal_priority_queue.put(request) + else: # Low priority + self.low_priority_queue.put(request) + + self.scheduler_stats['total_requests'] += 1 + + return request.request_id + + def _process_requests(self): + """Background thread for processing requests.""" + while self.running: + try: + # Process high priority requests first + request = self._get_next_request() + if request: + self._execute_request(request) + else: + time.sleep(0.001) # Small sleep to prevent busy waiting + + except Exception as e: + print(f"Error in request processing: {e}") + time.sleep(0.1) + + def _get_next_request(self) -> Optional[ScheduledRequest]: + """Get the next request to process based on priority.""" + # Try high priority queue first + try: + return self.high_priority_queue.get_nowait() + except Empty: + pass + + # Try normal priority queue + try: + return self.normal_priority_queue.get_nowait() + except Empty: + pass + + # Try low priority queue + try: + return self.low_priority_queue.get_nowait() + except Empty: + pass + + return None + + def _execute_request(self, request: ScheduledRequest): + """Execute a request through the plugin pipeline.""" + start_time = time.time() + + try: + # Execute plugins in order + result = request + for plugin_name in self.plugin_order: + plugin = self.plugins[plugin_name] + if plugin['enabled']: + try: + result = plugin['function'](result) + if result is None: + # Plugin requested to stop processing + break + except Exception as e: + print(f"Plugin {plugin_name} error: {e}") + continue + + # Update statistics + processing_time = (time.time() - start_time) * 1000 # Convert to ms + self._update_stats(processing_time) + + # Clean up session if needed + if request.session_id: + self._update_session(request.session_id, result) + + except Exception as e: + print(f"Error executing request {request.request_id}: {e}") + + def _update_session(self, session_id: str, result: Any): + """Update session information.""" + if session_id not in self.sessions: + self.sessions[session_id] = { + 'created_at': time.time(), + 'last_activity': time.time(), + 'request_count': 0, + 'total_tokens': 0 + } + + session = self.sessions[session_id] + session['last_activity'] = time.time() + session['request_count'] += 1 + + # Clean up expired sessions + self._cleanup_expired_sessions() + + def _cleanup_expired_sessions(self): + """Remove expired sessions.""" + current_time = time.time() + expired_sessions = [] + + for session_id, session in self.sessions.items(): + if current_time - session['last_activity'] > self.session_timeout: + expired_sessions.append(session_id) + + for session_id in expired_sessions: + del self.sessions[session_id] + + def _update_stats(self, processing_time_ms: float): + """Update scheduler statistics.""" + self.scheduler_stats['processed_requests'] += 1 + + # Update running averages + current_avg = self.scheduler_stats['avg_processing_time_ms'] + total_processed = self.scheduler_stats['processed_requests'] + + self.scheduler_stats['avg_processing_time_ms'] = ( + (current_avg * (total_processed - 1) + processing_time_ms) / total_processed + ) + + # Update active sessions count + self.scheduler_stats['active_sessions'] = len(self.sessions) + + def get_queue_status(self) -> Dict[str, Any]: + 
"""Get current queue status.""" + return { + 'high_priority_queue_size': self.high_priority_queue.qsize(), + 'normal_priority_queue_size': self.normal_priority_queue.qsize(), + 'low_priority_queue_size': self.low_priority_queue.qsize(), + 'total_queued': (self.high_priority_queue.qsize() + + self.normal_priority_queue.qsize() + + self.low_priority_queue.qsize()) + } + + def get_scheduler_stats(self) -> Dict[str, Any]: + """Get current scheduler statistics.""" + return self.scheduler_stats.copy() + + def enable_plugin(self, plugin_name: str): + """Enable a specific plugin.""" + if plugin_name in self.plugins: + self.plugins[plugin_name]['enabled'] = True + + def disable_plugin(self, plugin_name: str): + """Disable a specific plugin.""" + if plugin_name in self.plugins: + self.plugins[plugin_name]['enabled'] = False + + def get_plugin_status(self) -> Dict[str, Dict[str, Any]]: + """Get status of all plugins.""" + return { + name: { + 'enabled': plugin['enabled'], + 'priority': plugin['priority'] + } + for name, plugin in self.plugins.items() + } + + def stop(self): + """Stop the scheduler and background processing.""" + self.running = False + if self.processing_thread.is_alive(): + self.processing_thread.join(timeout=1.0) + + def clear_queues(self): + """Clear all request queues.""" + while not self.high_priority_queue.empty(): + self.high_priority_queue.get() + while not self.normal_priority_queue.empty(): + self.normal_priority_queue.get() + while not self.low_priority_queue.empty(): + self.low_priority_queue.get() + + def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]: + """Get information about a specific session.""" + return self.sessions.get(session_id) + + def list_sessions(self) -> List[str]: + """List all active session IDs.""" + return list(self.sessions.keys()) diff --git a/vescale/tinyserve/utils.py b/vescale/tinyserve/utils.py new file mode 100644 index 0000000..a5cbfb4 --- /dev/null +++ b/vescale/tinyserve/utils.py @@ -0,0 +1,317 @@ +""" +Configuration and utility functions for TinyServe. +""" + +import torch +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Any, Union +import os +import json + + +@dataclass +class TinyServeConfig: + """ + Configuration class for TinyServe. 
+ + Contains all the parameters described in the paper for: + - Query-aware KV selection + - Page-based memory management + - Plugin configurations + - Performance targets + """ + + # Device configuration + device: str = "cuda" if torch.cuda.is_available() else "cpu" + + # Page-based KV cache configuration + page_size: int = 16 # Tokens per page (from paper: tested 4, 8, 16, 32, 64) + selection_ratio: float = 0.3 # Top-K ratio for page selection (from paper: tested 0.1, 0.2, 0.3, 0.5) + + # Attention configuration + num_attention_heads: int = 12 + head_dim: int = 64 + use_fused_kernel: bool = True + attention_chunk_size: int = 256 + + # Plugin configurations + enable_entropy_early_exit: bool = True + entropy_threshold: float = 0.5 + min_tokens_before_exit: int = 10 + + enable_token_pruning: bool = True + pruning_ratio: float = 0.1 + min_tokens_after_pruning: int = 100 + + enable_approximate_attention: bool = False + approximation_method: str = "linear" + compression_ratio: float = 0.5 + + enable_cache_optimization: bool = True + eviction_policy: str = "lru" + max_cache_size_gb: float = 8.0 + + # Performance targets + target_latency_ms: float = 50.0 + target_memory_gb: float = 4.0 + + # Session management + session_timeout: float = 300.0 # 5 minutes + + # Multi-GPU configuration + num_gpus: int = 1 + gpu_ids: List[int] = field(default_factory=lambda: [0]) + + # Memory management + max_sequence_length: int = 8192 + kv_cache_dtype: str = "float16" + + # Logging and monitoring + enable_logging: bool = True + log_level: str = "INFO" + enable_profiling: bool = False + + # Model-specific configurations + model_type: str = "auto" # tinylama, gpt2, opt, auto + model_path: Optional[str] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if self.page_size <= 0: + raise ValueError("page_size must be positive") + + if not 0.0 < self.selection_ratio <= 1.0: + raise ValueError("selection_ratio must be between 0 and 1") + + if self.entropy_threshold < 0.0: + raise ValueError("entropy_threshold must be non-negative") + + if self.pruning_ratio < 0.0 or self.pruning_ratio > 1.0: + raise ValueError("pruning_ratio must be between 0 and 1") + + if self.max_cache_size_gb <= 0.0: + raise ValueError("max_cache_size_gb must be positive") + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> 'TinyServeConfig': + """Create configuration from dictionary.""" + return cls(**config_dict) + + @classmethod + def from_json_file(cls, file_path: str) -> 'TinyServeConfig': + """Load configuration from JSON file.""" + if not os.path.exists(file_path): + raise FileNotFoundError(f"Configuration file not found: {file_path}") + + with open(file_path, 'r') as f: + config_dict = json.load(f) + + return cls.from_dict(config_dict) + + def to_dict(self) -> Dict[str, Any]: + """Convert configuration to dictionary.""" + return { + 'device': self.device, + 'page_size': self.page_size, + 'selection_ratio': self.selection_ratio, + 'num_attention_heads': self.num_attention_heads, + 'head_dim': self.head_dim, + 'use_fused_kernel': self.use_fused_kernel, + 'attention_chunk_size': self.attention_chunk_size, + 'enable_entropy_early_exit': self.enable_entropy_early_exit, + 'entropy_threshold': self.entropy_threshold, + 'min_tokens_before_exit': self.min_tokens_before_exit, + 'enable_token_pruning': self.enable_token_pruning, + 'pruning_ratio': self.pruning_ratio, + 'min_tokens_after_pruning': self.min_tokens_after_pruning, + 'enable_approximate_attention': self.enable_approximate_attention, + 
'approximation_method': self.approximation_method, + 'compression_ratio': self.compression_ratio, + 'enable_cache_optimization': self.enable_cache_optimization, + 'eviction_policy': self.eviction_policy, + 'max_cache_size_gb': self.max_cache_size_gb, + 'target_latency_ms': self.target_latency_ms, + 'target_memory_gb': self.target_memory_gb, + 'session_timeout': self.session_timeout, + 'num_gpus': self.num_gpus, + 'gpu_ids': self.gpu_ids, + 'max_sequence_length': self.max_sequence_length, + 'kv_cache_dtype': self.kv_cache_dtype, + 'enable_logging': self.enable_logging, + 'log_level': self.log_level, + 'enable_profiling': self.enable_profiling, + 'model_type': self.model_type, + 'model_path': self.model_path + } + + def save_to_json(self, file_path: str): + """Save configuration to JSON file.""" + config_dict = self.to_dict() + + # Ensure directory exists + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + with open(file_path, 'w') as f: + json.dump(config_dict, f, indent=2) + + def update(self, updates: Dict[str, Any]): + """Update configuration with new values.""" + for key, value in updates.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + print(f"Warning: Unknown configuration key: {key}") + + # Re-validate after updates + self.__post_init__() + + def get_optimized_config(self, performance_metrics: Dict[str, float]) -> 'TinyServeConfig': + """ + Get optimized configuration based on performance metrics. + + Args: + performance_metrics: Current performance metrics + + Returns: + Optimized configuration + """ + current_latency = performance_metrics.get('latency_ms', 0) + current_memory = performance_metrics.get('memory_gb', 0) + + # Create a copy for optimization + optimized = TinyServeConfig.from_dict(self.to_dict()) + + # Optimize page size based on latency + if current_latency > self.target_latency_ms: + optimized.page_size = max(4, self.page_size // 2) + + # Optimize selection ratio based on memory + if current_memory > self.target_memory_gb: + optimized.selection_ratio = min(0.5, self.selection_ratio + 0.1) + + # Optimize plugin configurations + if current_latency > self.target_latency_ms: + optimized.enable_entropy_early_exit = True + optimized.entropy_threshold = min(0.8, self.entropy_threshold + 0.1) + + if current_memory > self.target_memory_gb: + optimized.enable_token_pruning = True + optimized.pruning_ratio = min(0.3, self.pruning_ratio + 0.05) + + return optimized + + def get_model_config(self) -> Dict[str, Any]: + """Get model-specific configuration.""" + return { + 'model_type': self.model_type, + 'model_path': self.model_path, + 'max_sequence_length': self.max_sequence_length, + 'kv_cache_dtype': self.kv_cache_dtype + } + + def get_plugin_config(self) -> Dict[str, Any]: + """Get plugin configuration.""" + return { + 'entropy_early_exit': { + 'enabled': self.enable_entropy_early_exit, + 'threshold': self.entropy_threshold, + 'min_tokens': self.min_tokens_before_exit + }, + 'token_pruning': { + 'enabled': self.enable_token_pruning, + 'pruning_ratio': self.pruning_ratio, + 'min_tokens': self.min_tokens_after_pruning + }, + 'approximate_attention': { + 'enabled': self.enable_approximate_attention, + 'method': self.approximation_method, + 'compression_ratio': self.compression_ratio + }, + 'cache_optimization': { + 'enabled': self.enable_cache_optimization, + 'eviction_policy': self.eviction_policy, + 'max_cache_size_gb': self.max_cache_size_gb + } + } + + +def create_default_config() -> TinyServeConfig: + """Create a default TinyServe 
configuration.""" + return TinyServeConfig() + + +def create_optimized_config_for_model(model_name: str, + target_latency_ms: float = 50.0, + target_memory_gb: float = 4.0) -> TinyServeConfig: + """ + Create an optimized configuration for a specific model. + + Args: + model_name: Name of the model (e.g., "tinylama", "gpt2", "opt") + target_latency_ms: Target latency in milliseconds + target_memory_gb: Target memory usage in GB + + Returns: + Optimized configuration + """ + config = TinyServeConfig() + + # Model-specific optimizations + if "tinylama" in model_name.lower(): + config.page_size = 16 + config.selection_ratio = 0.3 + config.num_attention_heads = 12 + config.head_dim = 64 + elif "gpt2" in model_name.lower(): + config.page_size = 32 + config.selection_ratio = 0.2 + config.num_attention_heads = 12 + config.head_dim = 64 + elif "opt" in model_name.lower(): + config.page_size = 16 + config.selection_ratio = 0.25 + config.num_attention_heads = 16 + config.head_dim = 64 + + # Performance targets + config.target_latency_ms = target_latency_ms + config.target_memory_gb = target_memory_gb + + return config + + +def validate_config(config: TinyServeConfig) -> List[str]: + """ + Validate configuration and return list of warnings/errors. + + Args: + config: Configuration to validate + + Returns: + List of validation messages + """ + warnings = [] + + # Check device availability + if config.device == "cuda" and not torch.cuda.is_available(): + warnings.append("CUDA device requested but not available") + + # Check memory constraints + if config.max_cache_size_gb > 32: + warnings.append("Very large cache size may cause memory issues") + + # Check performance targets + if config.target_latency_ms < 10: + warnings.append("Very low latency target may be unrealistic") + + if config.target_memory_gb < 1: + warnings.append("Very low memory target may cause issues") + + # Check plugin configurations + if config.enable_entropy_early_exit and config.entropy_threshold < 0.1: + warnings.append("Very low entropy threshold may cause premature stopping") + + if config.enable_token_pruning and config.pruning_ratio > 0.5: + warnings.append("High pruning ratio may significantly impact quality") + + return warnings