diff --git a/burr/core/action.py b/burr/core/action.py index eb6fceff..c42cff41 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -1511,14 +1511,15 @@ def pydantic( writes: List[str], state_input_type: Type["BaseModel"], state_output_type: Type["BaseModel"], - stream_type: Union[Type["BaseModel"], Type[dict]], + stream_type: Union[Type["BaseModel"], Type[dict], object], tags: Optional[List[str]] = None, ) -> Callable: """Creates a streaming action that uses pydantic models. :param reads: The fields this consumes from the state. :param writes: The fields this writes to the state. - :param stream_type: The pydantic model or dictionary type that is used to represent the partial results. + :param stream_type: The pydantic model, dictionary type, or a union of pydantic models + (e.g. ``ModelA | ModelB``) used to represent the partial results. Use a dict if you want this untyped. :param state_input_type: The pydantic model type that is used to represent the input state. :param state_output_type: The pydantic model type that is used to represent the output state. diff --git a/burr/integrations/pydantic.py b/burr/integrations/pydantic.py index 300cbb6a..75fb86b9 100644 --- a/burr/integrations/pydantic.py +++ b/burr/integrations/pydantic.py @@ -269,7 +269,7 @@ async def async_action_function(state: State, **kwargs) -> State: return decorator -PartialType = Union[Type[pydantic.BaseModel], Type[dict]] +PartialType = Union[Type[pydantic.BaseModel], Type[dict], object] PydanticStreamingActionFunctionSync = Callable[ ..., Generator[Tuple[Union[pydantic.BaseModel, dict], Optional[pydantic.BaseModel]], None, None] @@ -290,7 +290,7 @@ async def async_action_function(state: State, **kwargs) -> State: def _validate_and_extract_signature_types_streaming( fn: PydanticStreamingActionFunction, - stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict]]], + stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict], object]], state_input_type: Optional[Type[pydantic.BaseModel]] = None, state_output_type: Optional[Type[pydantic.BaseModel]] = None, ) -> Tuple[