From 78a70a502021e82878b456203a96642553771147 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen Date: Tue, 7 Apr 2026 11:08:39 -0400 Subject: [PATCH 1/2] fix(event-stream): Handle unknown event types gracefully instead of crashing --- .../codegen/generators/UnionGenerator.java | 5 ++- .../_private/deserializers.py | 38 ++++++++++++++++--- .../smithy_aws_event_stream/aio/__init__.py | 2 +- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java index b3beb2cf3..34b6275d9 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/UnionGenerator.java @@ -149,6 +149,7 @@ private void generateDeserializer() { var symbol = symbolProvider.toSymbol(shape); var deserializerSymbol = symbol.expectProperty(SymbolProperties.DESERIALIZER); var schemaSymbol = symbol.expectProperty(SymbolProperties.SCHEMA); + var unknownSymbol = symbol.expectProperty(SymbolProperties.UNION_UNKNOWN); writer.putContext("schema", schemaSymbol); writer.write(""" class $1L: @@ -168,6 +169,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: ${4C|} case _: logger.debug("Unexpected member schema: %s", schema) + self._set_result($5L(tag=schema.member_name or "")) def _set_result(self, value: $2T) -> None: if self._result is not None: @@ -177,7 +179,8 @@ raise SerializationError("Unions must have exactly one value, but found more tha deserializerSymbol.getName(), symbol, schemaSymbol, - writer.consumer(w -> deserializeMembers())); + writer.consumer(w -> deserializeMembers()), + unknownSymbol.getName()); } private void deserializeMembers() { diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py index f52791450..cf0bf48cb 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py @@ -50,11 +50,39 @@ def read_struct( message_deserializer = self._create_deserializer(schema, headers) message_deserializer.read_struct(schema, consumer) else: - member_schema = schema.members[member_name] - message_deserializer = self._create_deserializer( - member_schema, headers - ) - consumer(member_schema, message_deserializer) + member_schema = schema.members.get(member_name) + if member_schema is None: + # Unknown event type. Call the consumer with a + # schema that carries the event type name as + # member_name and a member_index of -1 so the + # generated default branch constructs the unknown + # variant with the correct tag. + logger.debug( + "Unknown event type: %s", member_name + ) + from smithy_core.shapes import ShapeID + + _UNKNOWN_TARGET = Schema( + id=ShapeID("smithy.unknown#Unknown"), + shape_type=ShapeType.STRUCTURE, + ) + unknown_schema = Schema( + id=ShapeID( + f"smithy.unknown#Unknown${member_name}" + ), + shape_type=ShapeType.STRUCTURE, + member_target=_UNKNOWN_TARGET, + member_index=-1, + ) + consumer( + unknown_schema, + self._payload_codec.create_deserializer(b"{}"), + ) + else: + message_deserializer = self._create_deserializer( + member_schema, headers + ) + consumer(member_schema, message_deserializer) case "exception": member_name = expect_type(str, headers[":exception-type"]) member_schema = schema.members[member_name] diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py index 58ed7f184..a746fd61d 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py @@ -132,7 +132,7 @@ async def receive(self) -> E | None: ) result = self._deserializer(deserializer) logger.debug("Successfully deserialized event: %s", result) - if isinstance(getattr(result, "value"), Exception): + if isinstance(getattr(result, "value", None), Exception): raise result.value # type: ignore return result From e8f84ce8f8b07a7e8ad8310e320ee60a59f3c651 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen Date: Tue, 7 Apr 2026 14:53:00 -0400 Subject: [PATCH 2/2] Fix format issue and add unit tests for handling unknown events --- .../_private/deserializers.py | 11 +++-------- .../smithy_aws_event_stream/aio/__init__.py | 5 +++-- .../tests/unit/_private/__init__.py | 19 ++++++++++++++++--- .../tests/unit/_private/test_deserializers.py | 18 ++++++++++++++++++ 4 files changed, 40 insertions(+), 13 deletions(-) diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py index cf0bf48cb..6b8783916 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/_private/deserializers.py @@ -10,7 +10,7 @@ SpecificShapeDeserializer, ) from smithy_core.schemas import Schema -from smithy_core.shapes import ShapeType +from smithy_core.shapes import ShapeID, ShapeType from smithy_core.traits import EventHeaderTrait from smithy_core.utils import expect_type @@ -57,19 +57,14 @@ def read_struct( # member_name and a member_index of -1 so the # generated default branch constructs the unknown # variant with the correct tag. - logger.debug( - "Unknown event type: %s", member_name - ) - from smithy_core.shapes import ShapeID + logger.debug("Unknown event type: %s", member_name) _UNKNOWN_TARGET = Schema( id=ShapeID("smithy.unknown#Unknown"), shape_type=ShapeType.STRUCTURE, ) unknown_schema = Schema( - id=ShapeID( - f"smithy.unknown#Unknown${member_name}" - ), + id=ShapeID(f"smithy.unknown#Unknown${member_name}"), shape_type=ShapeType.STRUCTURE, member_target=_UNKNOWN_TARGET, member_index=-1, diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py index a746fd61d..401442130 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py @@ -132,8 +132,9 @@ async def receive(self) -> E | None: ) result = self._deserializer(deserializer) logger.debug("Successfully deserialized event: %s", result) - if isinstance(getattr(result, "value", None), Exception): - raise result.value # type: ignore + value = getattr(result, "value", None) + if isinstance(value, Exception): + raise value return result async def close(self) -> None: diff --git a/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py b/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py index 09162ecce..2c12808aa 100644 --- a/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py +++ b/packages/smithy-aws-event-stream/tests/unit/_private/__init__.py @@ -381,7 +381,7 @@ def serialize_members(self, serializer: ShapeSerializer): @dataclass -class EventStreamUnknownEvent: +class EventStreamUnknown: tag: str def serialize(self, serializer: ShapeSerializer): @@ -396,7 +396,7 @@ def serialize_members(self, serializer: ShapeSerializer): | EventStreamPayloadEvent | EventStreamBlobPayloadEvent | EventStreamErrorEvent - | EventStreamUnknownEvent + | EventStreamUnknown ) @@ -429,7 +429,7 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de))) case _: - raise SmithyError(f"Unexpected member schema: {schema}") + self._set_result(EventStreamUnknown(tag=schema.member_name or "")) def _set_result(self, value: EventStream) -> None: if self._result is not None: @@ -635,6 +635,19 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None: ] +UNKNOWN_EVENT_CASE = ( + EventStreamUnknown(tag="intermediateGroupEvent"), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "intermediateGroupEvent", + ":content-type": "application/json", + }, + payload=b"{}", + ), +) + + INITIAL_REQUEST_CASE = ( EventStreamOperationInputOutput(message="The initial request!"), EventMessage( diff --git a/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py b/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py index 34081bbf8..41c635e03 100644 --- a/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/packages/smithy-aws-event-stream/tests/unit/_private/test_deserializers.py @@ -20,6 +20,7 @@ EventStreamDeserializer, EventStreamErrorEvent, EventStreamOperationInputOutput, + EventStreamUnknown, ) @@ -126,3 +127,20 @@ async def test_read_closed_receiver_source() -> None: with pytest.raises(IOError): await receiver.receive() assert receiver.closed + + +def test_deserialize_unknown_event_type(): + message = EventMessage( + headers={ + ":message-type": "event", + ":event-type": "intermediateGroupEvent", + ":content-type": "application/json", + }, + payload=b"{}", + ) + source = Event.decode(BytesIO(message.encode())) + assert source is not None + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) + result = EventStreamDeserializer().deserialize(deserializer) + assert isinstance(result, EventStreamUnknown) + assert result.tag == "intermediateGroupEvent"