Skip to content

Commit 0f6d7cc

Browse files
Cleptomaniaclaude
andcommitted
Add RawHitBox TypedDict and normalize points to tuples
Add a RawHitBox TypedDict for typed serialization in to_dict/from_dict. Normalize all points to tuples of tuples on construction, add_region, and in adjusted point calculation for immutability and consistency. Update tests to expect tuple-based comparisons. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 50a79af commit 0f6d7cc

File tree

4 files changed

+47
-33
lines changed

4 files changed

+47
-33
lines changed

arcade/hitbox/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from arcade.types import Point2List
44

5-
from .base import HitBox, HitBoxAlgorithm
5+
from .base import HitBox, HitBoxAlgorithm, RawHitBox
66
from .bounding_box import BoundingHitBoxAlgorithm
77

88
from .simple import SimpleHitBoxAlgorithm
@@ -59,6 +59,7 @@ def calculate_hit_box_points_detailed(
5959
__all__ = [
6060
"HitBoxAlgorithm",
6161
"HitBox",
62+
"RawHitBox",
6263
"SimpleHitBoxAlgorithm",
6364
"PymunkHitBoxAlgorithm",
6465
"BoundingHitBoxAlgorithm",

arcade/hitbox/base.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,21 @@
44
import json
55
from math import cos, radians, sin
66
from pathlib import Path
7-
from typing import Any
7+
from typing import Any, TypedDict
88

99
from PIL.Image import Image
1010
from typing_extensions import Self
1111

1212
from arcade.types import EMPTY_POINT_LIST, Point2, Point2List
1313

14-
__all__ = ["HitBoxAlgorithm", "HitBox"]
14+
__all__ = ["HitBoxAlgorithm", "HitBox", "RawHitBox"]
15+
16+
17+
class RawHitBox(TypedDict):
18+
"""Typed dictionary representing the serialized form of a :py:class:`HitBox`."""
19+
20+
version: int
21+
regions: dict[str, Point2List]
1522

1623

1724
class HitBoxAlgorithm:
@@ -134,14 +141,21 @@ class HitBox:
134141
135142
loaded = HitBox.load("hitbox.json")
136143
137-
# Dict round-trip
144+
# Dict round-trip (see RawHitBox for the schema)
138145
data = box.to_dict()
139146
copy = HitBox.from_dict(data)
140147
148+
.. note::
149+
150+
All points are normalized to tuples of tuples on construction.
151+
Any sequence type is accepted as input, but regions will always
152+
store tuples internally.
153+
141154
Args:
142155
points:
143156
Either a single ``Point2List`` (creates a ``"default"`` region)
144157
or a ``dict[str, Point2List]`` mapping region names to point lists.
158+
Points are normalized to tuples on storage.
145159
position:
146160
The center around which the points will be offset.
147161
scale:
@@ -160,9 +174,11 @@ def __init__(
160174
angle: float = 0.0,
161175
):
162176
if isinstance(points, dict):
163-
self._regions: dict[str, Point2List] = dict(points)
177+
self._regions: dict[str, Point2List] = {
178+
name: tuple(tuple(p) for p in pts) for name, pts in points.items()
179+
}
164180
else:
165-
self._regions = {self.DEFAULT_REGION: points}
181+
self._regions = {self.DEFAULT_REGION: tuple(tuple(p) for p in points)}
166182

167183
self._position = position
168184
self._scale = scale
@@ -205,7 +221,7 @@ def add_region(self, name: str, points: Point2List) -> None:
205221
name: The name for the new region.
206222
points: The polygon points for the region.
207223
"""
208-
self._regions[name] = points
224+
self._regions[name] = tuple(tuple(p) for p in points)
209225
self._is_single_region = len(self._regions) == 1
210226
self._adjusted_cache_dirty = True
211227

@@ -324,7 +340,7 @@ def _adjust_point(point: Point2) -> Point2:
324340
return (x + position_x, y + position_y)
325341

326342
self._adjusted_regions = {
327-
name: [_adjust_point(p) for p in pts] for name, pts in self._regions.items()
343+
name: tuple(_adjust_point(p) for p in pts) for name, pts in self._regions.items()
328344
}
329345
self._adjusted_cache_dirty = False
330346

@@ -352,39 +368,36 @@ def get_all_adjusted_polygons(self) -> list[Point2List]:
352368

353369
# --- Serialization ---
354370

355-
def to_dict(self) -> dict:
371+
def to_dict(self) -> RawHitBox:
356372
"""
357-
Serialize the hitbox shape to a dictionary.
373+
Serialize the hitbox shape to a :py:class:`RawHitBox` dictionary.
358374
359375
Only the region definitions (point data) are serialized.
360376
Position, scale, and angle are runtime state and are not included.
361377
"""
362378
return {
363379
"version": 1,
364-
"regions": {name: [list(p) for p in pts] for name, pts in self._regions.items()},
380+
"regions": {name: pts for name, pts in self._regions.items()},
365381
}
366382

367383
@classmethod
368384
def from_dict(
369385
cls,
370-
data: dict,
386+
data: RawHitBox,
371387
position: Point2 = (0.0, 0.0),
372388
scale: Point2 = (1.0, 1.0),
373389
angle: float = 0.0,
374390
) -> HitBox:
375391
"""
376-
Create a HitBox from a serialized dictionary.
392+
Create a HitBox from a :py:class:`RawHitBox` dictionary.
377393
378394
Args:
379-
data: The dictionary to deserialize from.
395+
data: A :py:class:`RawHitBox` dictionary to deserialize from.
380396
position: The center offset.
381397
scale: The scaling factors.
382398
angle: The rotation angle in degrees.
383399
"""
384-
regions: dict[str, Point2List] = {
385-
name: tuple(tuple(p) for p in pts) for name, pts in data["regions"].items()
386-
}
387-
return cls(points=regions, position=position, scale=scale, angle=angle)
400+
return cls(points=data["regions"], position=position, scale=scale, angle=angle)
388401

389402
def save(self, path: str | Path) -> None:
390403
"""

tests/unit/hitbox/test_hitbox.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import pytest
66
from arcade import hitbox
77

8-
points = [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)]
9-
rot_90 = [(0.0, 0.0), (10.0, 0), (10.0, -10.0), (0.0, -10.0)]
8+
points = ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0))
9+
rot_90 = ((0.0, 0.0), (10.0, 0), (10.0, -10.0), (0.0, -10.0))
1010

1111

1212
def test_module():
@@ -34,14 +34,14 @@ def test_scale():
3434
hb = hitbox.HitBox(points)
3535
hb.scale = (2.0, 2.0)
3636
assert hb.scale == (2.0, 2.0)
37-
assert hb.get_adjusted_points() == [(0.0, 0.0), (0.0, 20.0), (20.0, 20.0), (20.0, 0.0)]
37+
assert hb.get_adjusted_points() == ((0.0, 0.0), (0.0, 20.0), (20.0, 20.0), (20.0, 0.0))
3838

3939

4040
def test_position():
4141
hb = hitbox.HitBox(points)
4242
hb.position = (10.0, 10.0)
4343
assert hb.position == (10.0, 10.0)
44-
assert hb.get_adjusted_points() == [(10.0, 10.0), (10.0, 20.0), (20.0, 20.0), (20.0, 10.0)]
44+
assert hb.get_adjusted_points() == ((10.0, 10.0), (10.0, 20.0), (20.0, 20.0), (20.0, 10.0))
4545

4646

4747
def test_rotation():
@@ -74,8 +74,8 @@ def test_multi_region_create():
7474
assert hb.has_region("body")
7575
assert hb.has_region("head")
7676
assert not hb.has_region("default")
77-
assert hb.regions["body"] == body_pts
78-
assert hb.regions["head"] == head_pts
77+
assert hb.regions["body"] == tuple(tuple(p) for p in body_pts)
78+
assert hb.regions["head"] == tuple(tuple(p) for p in head_pts)
7979

8080

8181
def test_multi_region_adjusted():
@@ -84,8 +84,8 @@ def test_multi_region_adjusted():
8484
hb = hitbox.HitBox({"body": body_pts, "head": head_pts}, position=(5.0, 5.0))
8585
body_adj = hb.get_adjusted_points("body")
8686
head_adj = hb.get_adjusted_points("head")
87-
assert body_adj == [(5.0, 5.0), (5.0, 15.0), (15.0, 15.0), (15.0, 5.0)]
88-
assert head_adj == [(7.0, 15.0), (7.0, 20.0), (13.0, 20.0), (13.0, 15.0)]
87+
assert body_adj == ((5.0, 5.0), (5.0, 15.0), (15.0, 15.0), (15.0, 5.0))
88+
assert head_adj == ((7.0, 15.0), (7.0, 20.0), (13.0, 20.0), (13.0, 15.0))
8989

9090

9191
def test_multi_region_boundaries():
@@ -134,7 +134,7 @@ def test_single_region_fast_path():
134134
hb = hitbox.HitBox(points)
135135
polys = hb.get_all_adjusted_polygons()
136136
assert len(polys) == 1
137-
assert polys[0] == list(points)
137+
assert polys[0] == points
138138

139139

140140
# --- Serialization tests ---
@@ -167,7 +167,7 @@ def test_from_dict():
167167
}
168168
hb = hitbox.HitBox.from_dict(d)
169169
assert hb.points == ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0))
170-
assert hb.get_adjusted_points() == [(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)]
170+
assert hb.get_adjusted_points() == ((0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0))
171171

172172

173173
def test_roundtrip_dict():

tests/unit/sprite/test_sprite_hitbox.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,27 @@ def test_1():
1515
print()
1616
hitbox = my_sprite.hit_box.get_adjusted_points()
1717
print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}")
18-
assert hitbox == [(90.0, 90.0), (90.0, 110.0), (110.0, 110.0), (110.0, 90.0)]
18+
assert hitbox == ((90.0, 90.0), (90.0, 110.0), (110.0, 110.0), (110.0, 90.0))
1919

2020
my_sprite.scale = 0.5, 0.5
2121
hitbox = my_sprite.hit_box.get_adjusted_points()
2222
print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}")
23-
assert hitbox == [(95.0, 95.0), (95.0, 105.0), (105.0, 105.0), (105.0, 95.0)]
23+
assert hitbox == ((95.0, 95.0), (95.0, 105.0), (105.0, 105.0), (105.0, 95.0))
2424

2525
my_sprite.scale = 1.0
2626
hitbox = my_sprite.hit_box.get_adjusted_points()
2727
print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}")
28-
assert hitbox == [(90.0, 90.0), (90.0, 110.0), (110.0, 110.0), (110.0, 90.0)]
28+
assert hitbox == ((90.0, 90.0), (90.0, 110.0), (110.0, 110.0), (110.0, 90.0))
2929

3030
my_sprite.scale = 2.0
3131
hitbox = my_sprite.hit_box.get_adjusted_points()
3232
print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}")
33-
assert hitbox == [(80.0, 80.0), (80.0, 120.0), (120.0, 120.0), (120.0, 80.0)]
33+
assert hitbox == ((80.0, 80.0), (80.0, 120.0), (120.0, 120.0), (120.0, 80.0))
3434

3535
my_sprite.scale = 2.0
3636
hitbox = my_sprite.hit_box.get_adjusted_points()
3737
print(f"Hitbox: {my_sprite.scale} -> {my_sprite.hit_box.points} -> {hitbox}")
38-
assert hitbox == [(80.0, 80.0), (80.0, 120.0), (120.0, 120.0), (120.0, 80.0)]
38+
assert hitbox == ((80.0, 80.0), (80.0, 120.0), (120.0, 120.0), (120.0, 80.0))
3939

4040

4141
def test_2():

0 commit comments

Comments
 (0)