diff --git a/README.md b/README.md index 6c5b8bc1..7234b397 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,6 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m + [2024/05/17] We release [meetkai/functionary-small-v2.5](https://huggingface.co/meetkai/functionary-small-v2.5) with better capability for function calling and code interpreter compared with [functionary-small-v2.4](https://huggingface.co/meetkai/functionary-small-v2.4) + [2024/05/06] Streaming support for functionary v2 to v2.4 models is released in [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)! + [2024/05/03] Added support for serverless vLLM deployment on [Modal.com](https://modal.com/) - + [2024/04/27] New and improved grammar sampling! Ensures 100% accuracy in generating function names, prompt template and parameters. + [2024/04/02] We release [meetkai/functionary-small-v2.4](https://huggingface.co/meetkai/functionary-small-v2.4) and [meetkai/functionary-medium-v2.4](https://huggingface.co/meetkai/functionary-medium-v2.4)! The first functionary models with code-interpreter ability (by passing in `{type: "code_interpreter"}` in tools)! @@ -114,17 +113,6 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \ ``` -### Grammar Sampling (Only in vLLM) - -We also offer our own function-calling grammar sampling feature which constrains the LLM's generation to always follow the prompt template, and ensures 100% accuracy for function name. The parameters are generated using the efficient [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer), which ensures that the parameters follow the schema of the tool called. To enable grammar sampling, run the vLLM server with the command-line argument --enable-grammar-sampling: - -```shell -python3 server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2 --enable-grammar-sampling -``` - -**Note:** Grammar Sampling support is applicable only for the V2, V3.0, V3.2 models. There is no such support for V1 and V3.1 models. - - ### Text-Generation-Inference (TGI) We also provide a service that performs inference on Functionary models using [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index) (TGI). Follow these steps to get started: @@ -711,11 +699,6 @@ Evaluation function call prediction in SGD dataset. The accuracy metric measures See training [README](functionary/train/README.md) -## Safety & Security - -While its not strictly enforced, to ensure more *secure* function execution, one can enable grammar sampling to enforce type checking. -Main safety checks needs to be done in the functions/actions themselves. Such as validation of the given input, or the ouput that will be given to the model. - ## Roadmap - [ ] OpenAPI specification based plugin support. @@ -724,7 +707,6 @@ Main safety checks needs to be done in the functions/actions themselves. Such as - [X] [text-generation-inference](https://github.com/huggingface/text-generation-inference) - [X] Streaming Support - [X] function_call parameter to server - - [X] Grammar Sampling to ensure 100% accuracy for function and parameter names - [X] Parallel function calling support - [X] Python function calling support (Automatic detection of type annotations and calling them automatically) - [X] Real world usage examples, such as creating agents. diff --git a/functionary/inference.py b/functionary/inference.py index f5062e7f..e2bc4658 100644 --- a/functionary/inference.py +++ b/functionary/inference.py @@ -1,24 +1,17 @@ from typing import Dict, List, Optional, Union import torch -from lmformatenforcer import CharacterLevelParser, JsonSchemaParser -from lmformatenforcer.integrations.vllm import build_vllm_logits_processor from transformers import ( LlamaForCausalLM, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, ) -from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( - _cached_build_vllm_token_enforcer_tokenizer_data, - _normalize_json_schema_object, -) -from vllm.sampling_params import LogitsProcessor +from functionary.inference_utils import StopWordsCriteria from functionary.openai_types import ChatMessage, Function, FunctionCall, Tool from functionary.prompt_template import get_prompt_template_from_tokenizer from functionary.prompt_template.prompt_utils import prepare_messages_for_inference -from functionary.inference_utils import StopWordsCriteria def tokenize(message: ChatMessage, tokenizer: LlamaTokenizer, device="cuda:0"): @@ -100,30 +93,6 @@ def generate_message( return ChatMessage(**result) -async def get_lm_format_enforcer_vllm_logits_processor_from_tool_name( - tool_name, tools_or_functions, tokenizer -) -> LogitsProcessor: - """ - Given a tool_name and list of tool definitions, find the json schema - of the tool with tool_name name and get the necessary vLLM logits processor - for the given tool schema.""" - - tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(tokenizer) - character_level_parser: CharacterLevelParser - - # Get the tool schema - for tool_or_function in tools_or_functions: - if tool_or_function["name"] == tool_name: - raw_tool_schema = tool_or_function["parameters"] - break - schema = _normalize_json_schema_object(raw_tool_schema) - character_level_parser = JsonSchemaParser(schema) - logits_processor = build_vllm_logits_processor( - tokenizer_data, character_level_parser - ) - return logits_processor - - if __name__ == "__main__": # First lets create an example messages list with all different types of roles and content. functions = [ diff --git a/functionary/vllm_inference.py b/functionary/vllm_inference.py index 6d61f6b5..166d2d66 100644 --- a/functionary/vllm_inference.py +++ b/functionary/vllm_inference.py @@ -180,7 +180,6 @@ async def process_chat_completion( served_model: List[str], served_loras: List[LoRARequest], engine_model_config: Any, - enable_grammar_sampling: bool, engine: Any, ): error_check_ret = await check_all_errors(request, served_model, served_loras) @@ -216,14 +215,6 @@ async def process_chat_completion( tok_ids = tokenizer.encode(stop_tok, add_special_tokens=False) stop_token_ids.append(tok_ids[-1]) - # In vLLM==0.4.1, SamplingParams.logprobs has a proportional effect on latency - # We need to limit the size of SamplingParams.logprobs as a temporary fix first - # while investigating this problem in vLLM - if enable_grammar_sampling is False: - logprobs = None - else: - logprobs = 200 - try: sampling_params = SamplingParams( n=request.n, @@ -238,28 +229,17 @@ async def process_chat_completion( top_k=request.top_k, ignore_eos=request.ignore_eos, skip_special_tokens=False, - logprobs=logprobs, + logprobs=None, ) except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) - if enable_grammar_sampling: - result_generator = engine.generate( - prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), - lora_request=lora_request, - sampling_params=sampling_params, - request_id=request_id, - tools_or_functions=tools_or_functions, - prompt_template_cls=prompt_template, - tool_choice=tool_func_choice, - ) - else: - result_generator = engine.generate( - prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), - lora_request=lora_request, - sampling_params=sampling_params, - request_id=request_id, - ) + result_generator = engine.generate( + prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), + lora_request=lora_request, + sampling_params=sampling_params, + request_id=request_id, + ) async def abort_request() -> None: await engine.abort(request_id) diff --git a/functionary/vllm_monkey_patch/async_llm_engine.py b/functionary/vllm_monkey_patch/async_llm_engine.py deleted file mode 100644 index 383b319b..00000000 --- a/functionary/vllm_monkey_patch/async_llm_engine.py +++ /dev/null @@ -1,1399 +0,0 @@ -import asyncio -import time -import weakref -from functools import partial -from typing import ( - Any, - AsyncGenerator, - AsyncIterator, - Callable, - Coroutine, - Dict, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - Type, - Union, - overload, -) -from weakref import ReferenceType - -import vllm.envs as envs -from vllm.config import ( - DecodingConfig, - EngineConfig, - LoRAConfig, - ModelConfig, - ParallelConfig, - SchedulerConfig, -) -from vllm.core.scheduler import SchedulerOutputs -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import build_guided_decoding_logits_processor_async -from vllm.engine.async_timeout import asyncio_timeout -from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState -from vllm.engine.metrics_types import StatLoggerBase -from vllm.engine.protocol import EngineClient -from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.gpu_executor import GPUExecutorAsync -from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.usage.usage_lib import UsageContext -from vllm.utils import deprecate_kwargs, weak_bind - -from functionary.inference import ( - get_lm_format_enforcer_vllm_logits_processor_from_tool_name, -) -from functionary.openai_types import Tool -from functionary.prompt_template.prompt_utils import resolve_json_refs - -logger = init_logger(__name__) -ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S - - -class AsyncEngineDeadError(RuntimeError): - pass - - -def _log_task_completion( - task: asyncio.Task, error_callback: Callable[[Exception], None] -) -> None: - """This function is only intended for the `engine.run_engine_loop()` task. - - In particular, that task runs a `while True` loop that can only exit if - there is an exception. - """ - - exception = None - try: - return_value = task.result() - raise AssertionError( - f"The engine background task should never finish without an " - f"exception. {return_value}" - ) - except asyncio.exceptions.CancelledError: - # We assume that if the task is cancelled, we are gracefully shutting - # down. This should only happen on program exit. - logger.info("Engine is gracefully shutting down.") - except Exception as e: - exception = e - logger.error("Engine background task failed", exc_info=e) - error_callback(exception) - raise AsyncEngineDeadError( - "Task finished unexpectedly. This should never happen! " - "Please open an issue on Github. See stack trace above for the " - "actual cause." - ) from e - - -STOP_ITERATION = Exception() # Sentinel - - -class AsyncStream: - """A stream of RequestOutputs or EmbeddingRequestOutputs for a request - that can be iterated over asynchronously via an async generator.""" - - def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: - self.request_id = request_id - self._cancel = cancel - self._queue: asyncio.Queue = asyncio.Queue() - self._finished = False - - def put( - self, item: Union[RequestOutput, EmbeddingRequestOutput, Exception] - ) -> None: - if not self._finished: - self._queue.put_nowait(item) - - def finish( - self, - exception: Optional[Union[BaseException, Type[BaseException]]] = None, - ) -> None: - if not self._finished: - self._finished = True - self._queue.put_nowait( - exception if self._is_raisable(exception) else STOP_ITERATION - ) - - @property - def finished(self) -> bool: - return self._finished - - async def generator( - self, - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: - try: - while True: - result = await self._queue.get() - if self._is_raisable(result): - if result == STOP_ITERATION: - return - raise result - yield result - except GeneratorExit: - self._cancel(self.request_id) - raise asyncio.CancelledError from None - - @staticmethod - def _is_raisable(value: Any): - return isinstance(value, BaseException) or ( - isinstance(value, type) and issubclass(value, BaseException) - ) - - -class RequestTracker: - """Synchronous abstraction for tracking requests.""" - - def __init__(self) -> None: - self._request_streams: Dict[str, AsyncStream] = {} - self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() - self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() - self.new_requests_event = asyncio.Event() - - def __contains__(self, item): - return item in self._request_streams - - def __len__(self) -> int: - return len(self._request_streams) - - def propagate_exception( - self, exc: Exception, request_id: Optional[str] = None - ) -> None: - """Propagate an exception to request streams - (all if request_id is None).""" - if request_id is not None: - self.abort_request(request_id, exception=exc) - else: - # NB: tuple() used here because self.abort_request pops the stream - # out of self._request_streams, so we can't iterate on it directly - for rid in tuple(self._request_streams.keys()): - self.abort_request(rid, exception=exc) - - def process_request_output( - self, - request_output: Union[RequestOutput, EmbeddingRequestOutput], - *, - verbose: bool = False, - ) -> None: - """Process a request output from the engine.""" - request_id = request_output.request_id - finished = request_output.finished - - if finished: - stream = self._request_streams.pop(request_id, None) - else: - stream = self._request_streams.get(request_id) - # Guard against a KeyError which can occur if the request was aborted - # while the output was generated - if stream is not None: - stream.put(request_output) - if finished: - stream.finish() - - if verbose and finished: - logger.info("Finished request %s.", request_id) - - def process_exception( - self, request_id: str, exception: BaseException, *, verbose: bool = False - ) -> None: - """Propagate an exception from the engine.""" - if verbose: - logger.info("Finished request %s.", request_id) - self.abort_request(request_id, exception=exception) - - def add_request( - self, request_id: str, *, verbose: bool = False, **engine_add_request_kwargs - ) -> AsyncStream: - """Add a request to be sent to the engine on the next background - loop iteration.""" - if request_id in self._request_streams: - raise KeyError(f"Request {request_id} already exists.") - - abort_request = partial(self.abort_request, verbose=verbose) - stream = AsyncStream(request_id, abort_request) - self._new_requests.put_nowait( - (stream, {"request_id": request_id, **engine_add_request_kwargs}) - ) - - self.new_requests_event.set() - - if verbose: - logger.info("Added request %s.", request_id) - - return stream - - def abort_request( - self, - request_id: str, - *, - exception: Optional[Union[BaseException, Type[BaseException]]] = None, - verbose: bool = False, - ) -> None: - """Abort a request during next background loop iteration.""" - if verbose: - logger.info("Aborted request %s.", request_id) - - self._aborted_requests.put_nowait(request_id) - - stream = self._request_streams.pop(request_id, None) - if stream is not None: - stream.finish(exception=exception) - - def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: - """Get the new requests and finished requests to be - sent to the engine.""" - new_requests: List[Dict] = [] - finished_requests: Set[str] = set() - - while not self._aborted_requests.empty(): - request_id = self._aborted_requests.get_nowait() - finished_requests.add(request_id) - - while not self._new_requests.empty(): - stream, new_request = self._new_requests.get_nowait() - request_id = stream.request_id - if request_id in finished_requests: - # The request has already been aborted. - stream.finish(asyncio.CancelledError) - finished_requests.discard(request_id) - else: - self._request_streams[request_id] = stream - new_requests.append(new_request) - - return new_requests, finished_requests - - async def wait_for_new_requests(self): - if not self.has_new_requests(): - await self.new_requests_event.wait() - self.new_requests_event.clear() - - def has_new_requests(self): - return not self._new_requests.empty() - - -class _AsyncLLMEngine(LLMEngine): - """Extension of LLMEngine to add async methods.""" - - # This is a dict mapping request_id to the list of tools/functions details - tools_or_functions: dict = {} - # This is a dict mapping request_id to the prompt_template_cls - prompt_templates: dict = {} - # This is a dict mappingg request_id to the generation_state. It contains - # the following information: - # - stage: one of the following: - # ["pre-function", "function", "pre-parameter", "parameter", "text-gen"] - # - curr_tokens: all the tokens for the current stage being generated - # - curr_text: curr_tokens but in string text form - # - func_name: the function name, if any - # - tool_choice: whether the user provided tool_choice - gen_states: dict = {} - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - async def step_async( - self, virtual_engine: int - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: - """Performs one decoding iteration and returns newly generated results. - The workers are ran asynchronously if possible. - - This function performs one decoding iteration of the engine. It first - schedules the sequences to be executed in the next iteration and the - token blocks to be swapped in/out/copy. Then, it executes the model - and updates the scheduler with the model outputs. Finally, it decodes - the sequences and returns the newly generated results. - """ - # these are cached outputs from previous iterations. None if on first - # iteration - cached_outputs = self.cached_scheduler_outputs[virtual_engine] - seq_group_metadata_list = cached_outputs.seq_group_metadata_list - scheduler_outputs = cached_outputs.scheduler_outputs - allow_async_output_proc = cached_outputs.allow_async_output_proc - - ctx = self.scheduler_contexts[virtual_engine] - - # Clear outputs for each new scheduler iteration - ctx.request_outputs.clear() - - # skip the scheduler if there are any remaining steps in the seq groups. - # This ensures that the scheduler is only called again when the current - # batch has completed. - if not self._has_remaining_steps(seq_group_metadata_list): - - # Schedule iteration - (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc) = ( - self.scheduler[virtual_engine].schedule() - ) - - ctx.seq_group_metadata_list = seq_group_metadata_list - ctx.scheduler_outputs = scheduler_outputs - - # Maybe switch from async mode to sync mode - if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - - if ( - self.scheduler_config.is_multi_step - and scheduler_outputs.num_lookahead_slots > 0 - ): - # cache the scheduler outputs for the next iteration if we have - # lookahead slots - self._cache_scheduler_outputs_for_multi_step( - virtual_engine, - seq_group_metadata_list, - scheduler_outputs, - allow_async_output_proc, - ) - - assert seq_group_metadata_list is not None - assert scheduler_outputs is not None - - tokenizer = self.get_tokenizer() - - # Loop through each request and turn on/off lm-format-enforcer logits - # processor before generating the next token - for i in range(len(seq_group_metadata_list)): - request_id = seq_group_metadata_list[i].request_id - gen_state = self.gen_states[request_id] - tools_or_functions = self.tools_or_functions[request_id] - tools = resolve_json_refs(tools_or_functions=tools_or_functions) - - # Check if the model just transitioned to "parameter" or "pre-function" - if ( - gen_state["stage"] == "parameter" - and seq_group_metadata_list[i].sampling_params.logits_processors is None - ): - - seq_group_metadata_list[i].sampling_params.logits_processors = [ - await get_lm_format_enforcer_vllm_logits_processor_from_tool_name( - tool_name=gen_state["func_name"], - tools_or_functions=tools, - tokenizer=tokenizer, - ) - ] - elif ( - gen_state["stage"] != "parameter" - and seq_group_metadata_list[i].sampling_params.logits_processors - is not None - ): - seq_group_metadata_list[i].sampling_params.logits_processors = None - - if not scheduler_outputs.is_empty(): - finished_requests_ids = self.scheduler[ - virtual_engine - ].get_and_reset_finished_requests_ids() - - # Check if we have a cached last_output from the previous iteration. - # For supporting PP this is probably the best way to pass the - # sampled_token_ids, as a separate broadcast over all the PP stages - # will cause one virtual engine's microbatch to block the pipeline. - last_sampled_token_ids = self._get_last_sampled_token_ids(virtual_engine) - - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - virtual_engine=virtual_engine, - num_lookahead_slots=scheduler_outputs.num_lookahead_slots, - running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids, - # We use ExecuteModelRequest to pass the last sampled_token_ids - # to each of the non-last PP stages for in-place prepare_input. - last_sampled_token_ids=last_sampled_token_ids, - ) - - if allow_async_output_proc: - execute_model_req.async_callback = self.async_callbacks[virtual_engine] - - # Execute the model. - outputs = await self.model_executor.execute_model_async(execute_model_req) - - # Loop through all the output in the batch - for i in range(len(outputs[0])): - # Check whether grammar sampling is needed - model_sampled_token_id = outputs[0].outputs[i].samples[-1].output_token - request_id = seq_group_metadata_list[i].request_id - if ( - tokenizer.decode(model_sampled_token_id) == tokenizer.eos_token - or request_id not in self.prompt_templates - ): - continue - - # Get all the required variables for grammar sampling - prompt_template = self.prompt_templates[request_id] - gen_state = self.gen_states[request_id] - tools_or_functions = self.tools_or_functions[request_id] - - # Slot the first entry of logprobs into its original position - # before getting delta_token_ids_by_logprobs - delta_token_id_by_logprobs = list( - outputs[0].outputs[i].samples[-1].logprobs.keys() - ) - delta_logprobs = list( - outputs[0].outputs[i].samples[-1].logprobs.values() - ) - chosen_token_id = delta_token_id_by_logprobs[0] - chosen_logprob = delta_logprobs[0] - if chosen_logprob.rank != 1: - delta_token_id_by_logprobs.pop(0) - delta_logprobs.pop(0) - delta_token_id_by_logprobs.insert( - chosen_logprob.rank - 1, chosen_token_id - ) - - # Perform grammar sampling if needed and update the gen_state before returning - ( - grammar_sampled_token_id, - grammar_sampled_token, - self.gen_states[request_id], - ) = prompt_template.grammar_sample( - gen_state=gen_state, - tools_or_functions=tools_or_functions, - delta_token_ids=delta_token_id_by_logprobs, - model_sampled_token_id=model_sampled_token_id, - tokenizer=tokenizer, - ) - - # Update the output token to vllm with the newly sampled one - outputs[0].outputs[i].samples[ - -1 - ].output_token = grammar_sampled_token_id - - # we need to do this here so that last step's sampled_token_ids can - # be passed to the next iteration for PP. - if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, outputs) - else: - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - outputs = [] - - # Finish the current step for all the sequence groups. - if self.scheduler_config.is_multi_step: - for seq_group in seq_group_metadata_list: - seq_group.finish_step() - - if not self._has_remaining_steps(seq_group_metadata_list): - # Clear the cache if we have finished all the steps - if self.scheduler_config.is_multi_step: - self.cached_scheduler_outputs[virtual_engine] = SchedulerOutputState() - - # is_first_step_output is True only when the num_steps of all - # the sequences are 1. When the num_steps > 1, - # multi_step_model_runner does the first-step output append. - is_first_step_output: bool = ( - False - if not seq_group_metadata_list - else seq_group_metadata_list[0].state.num_steps == 1 - ) - - ctx.append_output( - outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - is_async=allow_async_output_proc, - is_last_step=True, - is_first_step_output=is_first_step_output, - ) - - if outputs and allow_async_output_proc: - assert ( - len(outputs) == 1 - ), "Async postprocessor expects only a single output set" - self._advance_to_next_step( - outputs[0], - seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups, - ) - - if not allow_async_output_proc: - self._process_model_outputs(ctx=ctx) - - # Log stats. - self.do_log_stats(scheduler_outputs, outputs) - - # Tracing - self.do_tracing(scheduler_outputs) - - else: - # Multi-step case - return ctx.request_outputs - - if not self.has_unfinished_requests(): - # Drain async postprocessor (if exists) - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - assert len(ctx.output_queue) == 0 - - return ctx.request_outputs - - async def stop_remote_worker_execution_loop_async(self) -> None: - """Stop the remote worker execution loop.""" - await self.model_executor.stop_remote_worker_execution_loop_async() - - @overload # DEPRECATED - async def add_request_async( - self, - request_id: str, - *, - inputs: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: ... - - @overload - async def add_request_async( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: ... - - @deprecate_kwargs( - "inputs", - additional_message="Please use the 'prompt' parameter instead.", - ) - async def add_request_async( - self, - request_id: str, - prompt: Optional[PromptType] = None, - params: Optional[Union[SamplingParams, PoolingParams]] = None, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - *, - inputs: Optional[PromptType] = None, # DEPRECATED - ) -> None: - """Async version of :meth:`add_request`.""" - if inputs is not None: - prompt = inputs - assert prompt is not None and params is not None - - if lora_request is not None and not self.lora_config: - raise ValueError( - f"Got lora_request {lora_request} but LoRA is " "not enabled!" - ) - if priority != 0 and not self.scheduler_config.policy == "priority": - raise ValueError( - f"Got priority {priority} but " "Priority scheduling is not enabled." - ) - if arrival_time is None: - arrival_time = time.time() - - preprocessed_inputs = await self.input_preprocessor.preprocess_async( - prompt, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - processed_inputs = self.input_processor(preprocessed_inputs) - - if isinstance(params, SamplingParams) and params.guided_decoding is not None: - # Guided decoding has an async implementation for building logits - # processors in a separate threadpool. - # We want to invoke that here instead of using the blocking - # implementation in the LLMEngine - params = await build_guided_decoding_logits_processor_async( - sampling_params=params, - tokenizer=self.get_tokenizer(lora_request), - default_guided_backend=self.decoding_config.guided_decoding_backend, - ) - - self._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, - priority=priority, - ) - - async def check_health_async(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() - self.model_executor.check_health() - - -class AsyncLLMEngine(EngineClient): - """An asynchronous wrapper for :class:`LLMEngine`. - - This class is used to wrap the :class:`LLMEngine` class to make it - asynchronous. It uses asyncio to create a background loop that keeps - processing incoming requests. The :class:`LLMEngine` is kicked by the - generate method when there are requests in the waiting queue. The generate - method yields the outputs from the :class:`LLMEngine` to the caller. - - Args: - log_requests: Whether to log the requests. - start_engine_loop: If True, the background task to run the engine - will be automatically started in the generate call. - *args: Arguments for :class:`LLMEngine`. - **kwargs: Arguments for :class:`LLMEngine`. - """ - - _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine - - def __init__( - self, *args, log_requests: bool = True, start_engine_loop: bool = True, **kwargs - ) -> None: - self.log_requests = log_requests - self.engine = self._engine_class(*args, **kwargs) - - # This ensures quick processing of request outputs - # so the append to asyncio queues is not delayed, - # especially for multi-step. - self.use_process_request_outputs_callback = ( - self.engine.model_config.use_async_output_proc - ) - - if self.use_process_request_outputs_callback: - self.engine.process_request_outputs_callback = weak_bind( - self.process_request_outputs - ) - - self.background_loop: Optional[asyncio.Future] = None - # We need to keep a reference to unshielded - # task as well to prevent it from being garbage - # collected - self._background_loop_unshielded: Optional[asyncio.Task] = None - self.start_engine_loop = start_engine_loop - self._errored_with: Optional[BaseException] = None - - # Lazy initialized fields - self._request_tracker: RequestTracker - - def __del__(self): - if rt := getattr(self, "request_tracker", None): - # Wake up engine loop so that it will exit cleanly - rt.new_requests_event.set() - - @classmethod - def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: - distributed_executor_backend = ( - engine_config.parallel_config.distributed_executor_backend - ) - if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorAsyncBase): - raise TypeError( - "distributed_executor_backend must be a subclass of " - f"ExecutorAsyncBase. Got {distributed_executor_backend}." - ) - executor_class = distributed_executor_backend - elif engine_config.device_config.device_type == "neuron": - from vllm.executor.neuron_executor import NeuronExecutorAsync - - executor_class = NeuronExecutorAsync - elif engine_config.device_config.device_type == "tpu": - if distributed_executor_backend == "ray": - from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync - - executor_class = RayTPUExecutorAsync - else: - assert distributed_executor_backend is None - from vllm.executor.tpu_executor import TPUExecutorAsync - - executor_class = TPUExecutorAsync - elif engine_config.device_config.device_type == "cpu": - from vllm.executor.cpu_executor import CPUExecutorAsync - - executor_class = CPUExecutorAsync - elif engine_config.device_config.device_type == "openvino": - assert distributed_executor_backend is None, ( - "Distributed execution is not supported with " "the OpenVINO backend." - ) - from vllm.executor.openvino_executor import OpenVINOExecutorAsync - - executor_class = OpenVINOExecutorAsync - elif engine_config.device_config.device_type == "xpu": - if distributed_executor_backend is None: - from vllm.executor.xpu_executor import XPUExecutorAsync - - executor_class = XPUExecutorAsync - elif distributed_executor_backend == "ray": - from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync - - executor_class = RayXPUExecutorAsync - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_xpu_executor import ( - MultiprocessingXPUExecutorAsync, - ) - - executor_class = MultiprocessingXPUExecutorAsync - else: - raise RuntimeError( - "Not supported distributed execution model on XPU device." - ) - elif distributed_executor_backend == "ray": - from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync - - executor_class = RayGPUExecutorAsync - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_gpu_executor import ( - MultiprocessingGPUExecutorAsync, - ) - - executor_class = MultiprocessingGPUExecutorAsync - else: - from vllm.executor.gpu_executor import GPUExecutorAsync - - executor_class = GPUExecutorAsync - return executor_class - - @classmethod - def from_engine_args( - cls, - engine_args: AsyncEngineArgs, - engine_config: Optional[EngineConfig] = None, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "AsyncLLMEngine": - """Creates an async LLM engine from the engine arguments.""" - # Create the engine configs. - if engine_config is None: - engine_config = engine_args.create_engine_config() - - executor_class = cls._get_executor_cls(engine_config) - - if executor_class.uses_ray: - initialize_ray_cluster(engine_config.parallel_config) - - # Create the async LLM engine. - engine = cls( - **engine_config.to_dict(), - executor_class=executor_class, - log_requests=not engine_args.disable_log_requests, - log_stats=not engine_args.disable_log_stats, - start_engine_loop=start_engine_loop, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - return engine - - @property - def is_running(self) -> bool: - return ( - self.background_loop is not None - and self._background_loop_unshielded is not None - and not self._background_loop_unshielded.done() - ) - - @property - def is_stopped(self) -> bool: - return self.errored or ( - self.background_loop is not None - and self._background_loop_unshielded is not None - and self._background_loop_unshielded.done() - ) - - @property - def errored(self) -> bool: - return self._errored_with is not None - - @property - def dead_error(self) -> BaseException: - return AsyncEngineDeadError( - "Background loop is not running. If it was running, " - "inspect the output to find the stacktrace of the " - "error that caused the background loop to stop " - "(AsyncEngineDeadError)." - ) - - def set_errored(self, exc: Exception) -> None: - self._errored_with = exc - - def _error_callback(self, exc: Exception) -> None: - self.set_errored(exc) - self._request_tracker.propagate_exception(exc) - - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return await self.engine.get_tokenizer_group().get_lora_tokenizer_async( - lora_request - ) - - def start_background_loop(self) -> None: - """Start the background loop.""" - if self.errored: - raise AsyncEngineDeadError( - "Background loop has errored already." - ) from self._errored_with - if self.is_running: - raise RuntimeError("Background loop is already running.") - # Initialize the RequestTracker here so it uses the right event loop. - self._request_tracker = RequestTracker() - - self._background_loop_unshielded = asyncio.get_event_loop().create_task( - self.run_engine_loop(weakref.ref(self)) - ) - self._background_loop_unshielded.add_done_callback( - partial(_log_task_completion, error_callback=self._error_callback) - ) - self.background_loop = asyncio.shield(self._background_loop_unshielded) - - def shutdown_background_loop(self) -> None: - """ - Shut down the background loop. - - This method needs to be called during cleanup to remove - references to `self` and properly GC the resources held - by the async LLM engine (e.g., the executors as well as - their resources). - """ - if self._background_loop_unshielded is not None: - self._background_loop_unshielded.cancel() - self._background_loop_unshielded = None - self.background_loop = None - - async def engine_step(self, virtual_engine: int) -> bool: - """Kick the engine to process the waiting requests. - - Returns True if there are in-progress requests.""" - - new_requests, aborted_requests = ( - self._request_tracker.get_new_and_aborted_requests() - ) - - for new_request in new_requests: - # Add the request into the vLLM engine's waiting queue. - try: - await self.engine.add_request_async(**new_request) - except ValueError as e: - # TODO: use a vLLM specific error for failed validation - self._request_tracker.process_exception( - new_request["request_id"], - e, - verbose=self.log_requests, - ) - - if aborted_requests: - await self._engine_abort(aborted_requests) - - request_outputs = await self.engine.step_async(virtual_engine) - - # Put the outputs into the corresponding streams. - # If used as a callback, then already invoked inside - # LLMEngine's _process_model_outputs - if not self.use_process_request_outputs_callback: - all_finished = self.process_request_outputs(request_outputs) - else: - # For callback case, we only need to detect when all - # requests are finished - all_finished = all( - request_output.finished for request_output in request_outputs - ) - - return not all_finished - - def process_request_outputs(self, request_outputs) -> bool: - # Put the outputs into the corresponding streams. - all_finished = True - for request_output in request_outputs: - self._request_tracker.process_request_output( - request_output, verbose=self.log_requests - ) - all_finished = all_finished and request_output.finished - - return all_finished - - async def _engine_abort(self, request_ids: Iterable[str]): - self.engine.abort_request(request_ids) - - @staticmethod - async def run_engine_loop(engine_ref: ReferenceType): - """We use a weakref to the engine so that the running loop - doesn't prevent the engine being garbage collected.""" - engine: Optional["AsyncLLMEngine"] = engine_ref() - if not engine: - return - - pipeline_parallel_size = engine.engine.parallel_config.pipeline_parallel_size - has_requests_in_progress = [False] * pipeline_parallel_size - while True: - if not any(has_requests_in_progress): - logger.debug("Waiting for new requests...") - # Stop the execute model loop in parallel workers until there - # are more requests to process. This avoids waiting - # indefinitely in torch.distributed ops which may otherwise - # timeout, and unblocks the RPC thread in the workers so that - # they can process any other queued control plane messages, - # such as add/remove lora adapters. - await engine.engine.stop_remote_worker_execution_loop_async() - request_tracker = engine._request_tracker - # Allow engine to be garbage collected while - # waiting for new requests - del engine - await asyncio.sleep(0) - if engine_ref() is None: - return - await request_tracker.wait_for_new_requests() - engine = engine_ref() - if not engine: - return - logger.debug("Got new requests!") - requests_in_progress = [ - asyncio.create_task(engine.engine_step(ve)) - for ve in range(pipeline_parallel_size) - ] - has_requests_in_progress = [True] * pipeline_parallel_size - - # Abort if iteration takes too long due to unrecoverable errors - # (eg. NCCL timeouts). - try: - async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S): - done, _ = await asyncio.wait( - requests_in_progress, return_when=asyncio.FIRST_COMPLETED - ) - for _ in range(pipeline_parallel_size): - await asyncio.sleep(0) - for task in done: - result = task.result() - virtual_engine = requests_in_progress.index(task) - has_unfinished_requests = ( - engine.engine.has_unfinished_requests_for_virtual_engine( - virtual_engine - ) - ) - if result or has_unfinished_requests: - requests_in_progress[virtual_engine] = asyncio.create_task( - engine.engine_step(virtual_engine) - ) - has_requests_in_progress[virtual_engine] = True - else: - has_requests_in_progress[virtual_engine] = False - except asyncio.TimeoutError as exc: - logger.error("Engine iteration timed out. This should never happen!") - engine.set_errored(exc) - raise - await asyncio.sleep(0) - - # This method does not need to be async, but kept that way - # for backwards compatibility. - @overload # DEPRECATED - def add_request( - self, - request_id: str, - *, - inputs: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> Coroutine[ - None, None, AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None] - ]: ... - - @overload - def add_request( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> Coroutine[ - None, None, AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None] - ]: ... - - @deprecate_kwargs( - "inputs", - additional_message="Please use the 'prompt' parameter instead.", - ) - async def add_request( - self, - request_id: str, - prompt: Optional[PromptType] = None, - params: Optional[Union[SamplingParams, PoolingParams]] = None, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - *, - inputs: Optional[PromptType] = None, # DEPRECATED - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: - if inputs is not None: - prompt = inputs - assert prompt is not None and params is not None - - if not self.is_running: - if self.start_engine_loop: - self.start_background_loop() - else: - raise AsyncEngineDeadError( - "Background loop is not running. If it was running, " - "inspect the output to find the stacktrace of the " - "error that caused the background loop to stop " - "(AsyncEngineDeadError)." - ) - - if priority != 0 and not self.engine.scheduler_config.policy == "priority": - raise ValueError( - f"Got priority {priority} but " "Priority scheduling is not enabled." - ) - - stream = self._request_tracker.add_request( - request_id, - verbose=self.log_requests, - prompt=prompt, - params=params, - arrival_time=arrival_time or time.time(), - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ) - - return stream.generator() - - async def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - tools_or_functions: Optional[List[dict]] = None, - prompt_template_cls: Optional[Any] = None, - tool_choice: Optional[Any] = None, - ) -> AsyncIterator[RequestOutput]: - """Generate outputs for a request. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` - for more details about the format of each input. - sampling_params: The sampling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request to use - for generation, if any. - priority: The priority of the request. - Only applicable with priority scheduling. - - Yields: - The output `RequestOutput` objects from the LLMEngine - for the request. - - Details: - - If the engine is not running, start the background loop, - which iteratively invokes - :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` - to process the waiting requests. - - Add the request to the engine's `RequestTracker`. - On the next background loop, this request will be sent to - the underlying engine. - Also, a corresponding `AsyncStream` will be created. - - Wait for the request outputs from `AsyncStream` and yield them. - - Example: - >>> # Please refer to entrypoints/api_server.py for - >>> # the complete example. - >>> - >>> # initialize the engine and the example input - >>> engine = AsyncLLMEngine.from_engine_args(engine_args) - >>> example_input = { - >>> "prompt": "What is LLM?", - >>> "stream": False, # assume the non-streaming case - >>> "temperature": 0.0, - >>> "request_id": 0, - >>> } - >>> - >>> # start the generation - >>> results_generator = engine.generate( - >>> example_input["prompt"], - >>> SamplingParams(temperature=example_input["temperature"]), - >>> example_input["request_id"]) - >>> - >>> # get the results - >>> final_output = None - >>> async for request_output in results_generator: - >>> if await request.is_disconnected(): - >>> # Abort the request if the client disconnects. - >>> await engine.abort(request_id) - >>> # Return or raise an error - >>> ... - >>> final_output = request_output - >>> - >>> # Process and return the final output - >>> ... - """ - # Initialize the request_id entry of self.engine.tools_or_functions - # and prompt_templates at the start of generate method - self.engine.tools_or_functions[request_id] = [] - for tool_or_func in tools_or_functions: - if "type" not in tool_or_func: - self.engine.tools_or_functions[request_id].append(tool_or_func) - elif tool_or_func["type"] == "function": - self.engine.tools_or_functions[request_id].append( - tool_or_func["function"] - ) - self.engine.prompt_templates[request_id] = prompt_template_cls - - # Initialize gen_state based on tool_choice - if tool_choice is not None: - if tool_choice in ["none", "required"]: - tool_choice_name = tool_choice - elif tool_choice == "auto": - tool_choice_name = "" - else: - tool_choice_name = ( - tool_choice.function.name - if isinstance(tool_choice, Tool) - else tool_choice.name - ) - else: - tool_choice_name = "" - curr_text, curr_tokens = "", [] - - # Initialize the request_id entry of self.gen_states - self.engine.gen_states[request_id] = self.engine.prompt_templates[ - request_id - ].initialize_fsm_gen_state( - tool_choice=tool_choice, - curr_text=curr_text, - curr_tokens=curr_tokens, - add_code_interpreter=( - True - if any( - [ - "type" in tool_or_func - and tool_or_func["type"] == "code_interpreter" - for tool_or_func in tools_or_functions - ] - ) - else False - ), - ) - - try: - async for output in await self.add_request( - request_id, - prompt, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ): - yield LLMEngine.validate_output(output, RequestOutput) - except Exception: - pass - finally: - # Delete request_id entry from self.engine before finishing the request - del self.engine.tools_or_functions[request_id] - del self.engine.prompt_templates[request_id] - del self.engine.gen_states[request_id] - - async def encode( - self, - prompt: PromptType, - pooling_params: PoolingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: - """Generate outputs for a request from an embedding model. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` - for more details about the format of each input. - pooling_params: The pooling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - priority: The priority of the request. - Only applicable with priority scheduling. - - Yields: - The output `EmbeddingRequestOutput` objects from the LLMEngine - for the request. - - Details: - - If the engine is not running, start the background loop, - which iteratively invokes - :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` - to process the waiting requests. - - Add the request to the engine's `RequestTracker`. - On the next background loop, this request will be sent to - the underlying engine. - Also, a corresponding `AsyncStream` will be created. - - Wait for the request outputs from `AsyncStream` and yield them. - - Example: - >>> # Please refer to entrypoints/api_server.py for - >>> # the complete example. - >>> - >>> # initialize the engine and the example input - >>> engine = AsyncLLMEngine.from_engine_args(engine_args) - >>> example_input = { - >>> "input": "What is LLM?", - >>> "request_id": 0, - >>> } - >>> - >>> # start the generation - >>> results_generator = engine.encode( - >>> example_input["input"], - >>> PoolingParams(), - >>> example_input["request_id"]) - >>> - >>> # get the results - >>> final_output = None - >>> async for request_output in results_generator: - >>> if await request.is_disconnected(): - >>> # Abort the request if the client disconnects. - >>> await engine.abort(request_id) - >>> # Return or raise an error - >>> ... - >>> final_output = request_output - >>> - >>> # Process and return the final output - >>> ... - """ - async for output in await self.add_request( - request_id, - prompt, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - ): - yield LLMEngine.validate_output(output, EmbeddingRequestOutput) - - async def abort(self, request_id: str) -> None: - """Abort a request. - - Abort a submitted request. If the request is finished or not found, - this method will be a no-op. - - Args: - request_id: The unique id of the request. - """ - if not self.is_running: - raise AsyncEngineDeadError( - "Background loop is not running. If it was running, " - "inspect the output to find the stacktrace of the " - "error that caused the background loop to stop " - "(AsyncEngineDeadError)." - ) - - return self._abort(request_id) - - def _abort(self, request_id: str) -> None: - """Abort a request. - - Abort a submitted request. If the request is finished or not found, - this method will be a no-op. - - Args: - request_id: The unique id of the request. - """ - self._request_tracker.abort_request( - request_id, exception=asyncio.CancelledError, verbose=self.log_requests - ) - - async def get_model_config(self) -> ModelConfig: - """Get the model configuration of the vLLM engine.""" - return self.engine.get_model_config() - - async def get_parallel_config(self) -> ParallelConfig: - """Get the parallel configuration of the vLLM engine.""" - return self.engine.get_parallel_config() - - async def get_decoding_config(self) -> DecodingConfig: - """Get the decoding configuration of the vLLM engine.""" - return self.engine.get_decoding_config() - - async def get_scheduler_config(self) -> SchedulerConfig: - """Get the scheduling configuration of the vLLM engine.""" - return self.engine.get_scheduler_config() - - async def get_lora_config(self) -> LoRAConfig: - """Get the lora configuration of the vLLM engine.""" - return self.engine.get_lora_config() - - async def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, - ) -> None: - self.engine.do_log_stats() - - async def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - t = time.perf_counter() - logger.debug("Starting health check...") - if self.is_stopped: - raise AsyncEngineDeadError("Background loop is stopped.") - - await self.engine.check_health_async() - logger.debug("Health check took %fs", time.perf_counter() - t) - - async def is_tracing_enabled(self) -> bool: - return self.engine.is_tracing_enabled() - - def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: - self.engine.add_logger(logger_name=logger_name, logger=logger) - - def remove_logger(self, logger_name: str) -> None: - self.engine.remove_logger(logger_name=logger_name) - - async def start_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes - if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 - self.engine.model_executor.start_profile() - else: - self.engine.model_executor._run_workers("start_profile") - - async def stop_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes - if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 - self.engine.model_executor.stop_profile() - else: - self.engine.model_executor._run_workers("stop_profile") diff --git a/server_vllm.py b/server_vllm.py index 56567380..0e5121f2 100644 --- a/server_vllm.py +++ b/server_vllm.py @@ -29,6 +29,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.api_server import mount_metrics from vllm.entrypoints.openai.protocol import ( LoadLoraAdapterRequest, @@ -125,7 +126,6 @@ async def create_chat_completion(raw_request: Request): served_model=served_model, served_loras=served_loras, engine_model_config=engine_model_config, - enable_grammar_sampling=args.grammar_sampling, engine=engine, ) @@ -181,28 +181,10 @@ async def unload_lora_adapter(request: UnloadLoraAdapterRequest): help="LoRA modules in the format 'name=path name=path ...'", default=[], ) - parser.add_argument( - "--enable-grammar-sampling", - dest="grammar_sampling", - action="store_true", - default=False, - help="enable grammar sampling for function names", - ) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() - v1_pattern = r"v1.*$" - v31_pattern = r"v3.1$" - if re.search(v1_pattern, args.model) or re.search(v31_pattern, args.model): - args.grammar_sampling = False - - if args.grammar_sampling: - logger.info("Grammar sampling enabled.") - from functionary.vllm_monkey_patch.async_llm_engine import AsyncLLMEngine - else: - from vllm.engine.async_llm_engine import AsyncLLMEngine - mount_metrics(app) app.add_middleware( @@ -235,8 +217,6 @@ async def unload_lora_adapter(request: UnloadLoraAdapterRequest): tokenizer = get_tokenizer( engine_args.tokenizer, tokenizer_mode=engine_args.tokenizer_mode ) - # Overwrite vLLM's default ModelConfig.max_logprobs of 5 - engine_args.max_logprobs = len(tokenizer.vocab.keys()) engine = AsyncLLMEngine.from_engine_args(engine_args) engine_model_config = asyncio.run(engine.get_model_config()) diff --git a/tests/test_server.py b/tests/test_server.py index 75b2526e..9216808f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -22,7 +22,6 @@ def popen_launch_server( base_url: str, timeout: float, context_length: int, - grammar_sampling: bool, env: Optional[dict] = None, return_stdout_stderr: bool = False, ) -> subprocess.Popen: @@ -35,7 +34,6 @@ def popen_launch_server( base_url (str): The base URL for the server. timeout (float): Maximum time to wait for server launch. context_length (int): The context length for the model. - grammar_sampling (bool): Whether to enable grammar sampling. env (Optional[dict]): Environment variables for the subprocess. Defaults to None. return_stdout_stderr (bool): Whether to capture and return stdout/stderr. Defaults to False. @@ -62,8 +60,6 @@ def popen_launch_server( command += ["--max-model-len", str(context_length)] else: command += ["--context-length", str(context_length)] - if grammar_sampling: - command += ["--enable-grammar-sampling"] if return_stdout_stderr: process = subprocess.Popen( @@ -667,16 +663,14 @@ def _evaluate_test_cases(self, model: str) -> None: def test_vllm_server(self): for model in self.served_models: - for grammar_sample in [False, True]: - self.process = popen_launch_server( - backend="vllm", - model=model, - base_url=self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - context_length=4096, - grammar_sampling=grammar_sample, - ) - self._evaluate_test_cases(model) + self.process = popen_launch_server( + backend="vllm", + model=model, + base_url=self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + context_length=4096, + ) + self._evaluate_test_cases(model) def test_sgl_server(self): for model in self.served_models: @@ -686,6 +680,5 @@ def test_sgl_server(self): base_url=self.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, context_length=4096, - grammar_sampling=False, ) self._evaluate_test_cases(model)