diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index fd9ce452..6623685a 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -8,6 +8,8 @@ on: schedule: - cron: '0 11 * * 2' +permissions: {} + jobs: CodeQL-Build: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index eb8b999b..d4cdac91 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,6 +10,8 @@ on: types: - published +permissions: {} + jobs: build: name: Build source distribution diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1a215745..a3ebec08 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,6 +6,8 @@ on: schedule: - cron: '3 15 * * SUN' +permissions: {} + jobs: build: diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml index c09cf120..7ca9cb2b 100644 --- a/.github/workflows/zizmor.yml +++ b/.github/workflows/zizmor.yml @@ -6,6 +6,8 @@ on: pull_request: branches: ["**"] +permissions: {} + jobs: zizmor: name: zizmor latest via PyPI diff --git a/HISTORY.rst b/HISTORY.rst index 2ff32b20..3dd2f5a6 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,9 +4,18 @@ History ------- -4.9.0 -++++++++++++++++++ - +5.0.0 +++++++++++++++++++ + +* BREAKING: The ``raw`` attribute on the model classes has been replaced + with a ``to_dict()`` method. This can be used to get a representation of + the object that is suitable for serialization. +* BREAKING: The ``ip_address`` property on the model classes now always returns + a ``ipaddress.IPv4Address`` or ``ipaddress.IPv6Address``. +* BREAKING: The model and record classes now require all arguments other than + ``locales`` and ``ip_address`` to be keyword arguments. +* BREAKING: ``geoip2.mixins`` has been made internal. This normally would not + have been used by external code. * IMPORTANT: Python 3.9 or greater is required. If you are using an older version, please use an earlier release. * ``metro_code`` on ``geoip2.record.Location`` has been deprecated. The diff --git a/geoip2/_internal.py b/geoip2/_internal.py new file mode 100644 index 00000000..6f37ced5 --- /dev/null +++ b/geoip2/_internal.py @@ -0,0 +1,52 @@ +"""This package contains internal utilities""" + +# pylint: disable=too-few-public-methods +from abc import ABCMeta +from typing import Any + + +class Model(metaclass=ABCMeta): + """Shared methods for MaxMind model classes""" + + def __eq__(self, other: Any) -> bool: + return isinstance(other, self.__class__) and self.to_dict() == other.to_dict() + + def __ne__(self, other): + return not self.__eq__(other) + + # pylint: disable=too-many-branches + def to_dict(self): + """Returns a dict of the object suitable for serialization""" + result = {} + for key, value in self.__dict__.items(): + if key.startswith("_"): + continue + if hasattr(value, "to_dict") and callable(value.to_dict): + if d := value.to_dict(): + result[key] = d + elif isinstance(value, (list, tuple)): + ls = [] + for e in value: + if hasattr(e, "to_dict") and callable(e.to_dict): + if e := e.to_dict(): + ls.append(e) + elif e is not None: + ls.append(e) + if ls: + result[key] = ls + # We only have dicts of strings currently. Do not bother with + # the general case. + elif isinstance(value, dict): + if value: + result[key] = value + elif value is not None and value is not False: + result[key] = value + + # network and ip_address are properties for performance reasons + # pylint: disable=no-member + if hasattr(self, "ip_address") and self.ip_address is not None: + result["ip_address"] = str(self.ip_address) + if hasattr(self, "network") and self.network is not None: + result["network"] = str(self.network) + + return result diff --git a/geoip2/database.py b/geoip2/database.py index ab985076..4652d7d0 100644 --- a/geoip2/database.py +++ b/geoip2/database.py @@ -252,10 +252,9 @@ def _model_for( ip_address: IPAddress, ) -> Union[Country, Enterprise, City]: (record, prefix_len) = self._get(types, ip_address) - traits = record.setdefault("traits", {}) - traits["ip_address"] = ip_address - traits["prefix_len"] = prefix_len - return model_class(record, locales=self._locales) + return model_class( + self._locales, ip_address=ip_address, prefix_len=prefix_len, **record + ) def _flat_model_for( self, @@ -266,9 +265,7 @@ def _flat_model_for( ip_address: IPAddress, ) -> Union[ConnectionType, ISP, AnonymousIP, Domain, ASN]: (record, prefix_len) = self._get(types, ip_address) - record["ip_address"] = ip_address - record["prefix_len"] = prefix_len - return model_class(record) + return model_class(ip_address, prefix_len=prefix_len, **record) def metadata( self, diff --git a/geoip2/mixins.py b/geoip2/mixins.py deleted file mode 100644 index 2581f5fb..00000000 --- a/geoip2/mixins.py +++ /dev/null @@ -1,15 +0,0 @@ -"""This package contains utility mixins""" - -# pylint: disable=too-few-public-methods -from abc import ABCMeta -from typing import Any - - -class SimpleEquality(metaclass=ABCMeta): - """Naive __dict__ equality mixin""" - - def __eq__(self, other: Any) -> bool: - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not self.__eq__(other) diff --git a/geoip2/models.py b/geoip2/models.py index 693a305a..ac0199c7 100644 --- a/geoip2/models.py +++ b/geoip2/models.py @@ -11,16 +11,17 @@ """ -# pylint: disable=too-many-instance-attributes,too-few-public-methods +# pylint: disable=too-many-instance-attributes,too-few-public-methods,too-many-arguments import ipaddress from abc import ABCMeta -from typing import Any, cast, Dict, Optional, Sequence, Union +from typing import Dict, List, Optional, Sequence, Union import geoip2.records -from geoip2.mixins import SimpleEquality +from geoip2._internal import Model +from geoip2.types import IPAddress -class Country(SimpleEquality): +class Country(Model): """Model for the Country web service and Country database. This class provides the following attributes: @@ -76,32 +77,43 @@ class Country(SimpleEquality): traits: geoip2.records.Traits def __init__( - self, raw_response: Dict[str, Any], locales: Optional[Sequence[str]] = None + self, + locales: Optional[Sequence[str]], + *, + continent: Optional[Dict] = None, + country: Optional[Dict] = None, + ip_address: Optional[IPAddress] = None, + maxmind: Optional[Dict] = None, + prefix_len: Optional[int] = None, + registered_country: Optional[Dict] = None, + represented_country: Optional[Dict] = None, + traits: Optional[Dict] = None, + **_, ) -> None: - if locales is None: - locales = ["en"] self._locales = locales - self.continent = geoip2.records.Continent( - locales, **raw_response.get("continent", {}) - ) - self.country = geoip2.records.Country( - locales, **raw_response.get("country", {}) - ) + self.continent = geoip2.records.Continent(locales, **(continent or {})) + self.country = geoip2.records.Country(locales, **(country or {})) self.registered_country = geoip2.records.Country( - locales, **raw_response.get("registered_country", {}) + locales, **(registered_country or {}) ) self.represented_country = geoip2.records.RepresentedCountry( - locales, **raw_response.get("represented_country", {}) + locales, **(represented_country or {}) ) - self.maxmind = geoip2.records.MaxMind(**raw_response.get("maxmind", {})) + self.maxmind = geoip2.records.MaxMind(**(maxmind or {})) + + traits = traits or {} + if ip_address is not None: + traits["ip_address"] = ip_address + if prefix_len is not None: + traits["prefix_len"] = prefix_len - self.traits = geoip2.records.Traits(**raw_response.get("traits", {})) - self.raw = raw_response + self.traits = geoip2.records.Traits(**traits) def __repr__(self) -> str: return ( - f"{self.__module__}.{self.__class__.__name__}({self.raw}, {self._locales})" + f"{self.__module__}.{self.__class__.__name__}({repr(self._locales)}, " + f"{', '.join(f'{k}={repr(v)}' for k, v in self.to_dict().items())})" ) @@ -182,15 +194,38 @@ class City(Country): subdivisions: geoip2.records.Subdivisions def __init__( - self, raw_response: Dict[str, Any], locales: Optional[Sequence[str]] = None + self, + locales: Optional[Sequence[str]], + *, + city: Optional[Dict] = None, + continent: Optional[Dict] = None, + country: Optional[Dict] = None, + location: Optional[Dict] = None, + ip_address: Optional[IPAddress] = None, + maxmind: Optional[Dict] = None, + postal: Optional[Dict] = None, + prefix_len: Optional[int] = None, + registered_country: Optional[Dict] = None, + represented_country: Optional[Dict] = None, + subdivisions: Optional[List[Dict]] = None, + traits: Optional[Dict] = None, + **_, ) -> None: - super().__init__(raw_response, locales) - self.city = geoip2.records.City(locales, **raw_response.get("city", {})) - self.location = geoip2.records.Location(**raw_response.get("location", {})) - self.postal = geoip2.records.Postal(**raw_response.get("postal", {})) - self.subdivisions = geoip2.records.Subdivisions( - locales, *raw_response.get("subdivisions", []) + super().__init__( + locales, + continent=continent, + country=country, + ip_address=ip_address, + maxmind=maxmind, + prefix_len=prefix_len, + registered_country=registered_country, + represented_country=represented_country, + traits=traits, ) + self.city = geoip2.records.City(locales, **(city or {})) + self.location = geoip2.records.Location(**(location or {})) + self.postal = geoip2.records.Postal(**(postal or {})) + self.subdivisions = geoip2.records.Subdivisions(locales, *(subdivisions or [])) class Insights(City): @@ -321,22 +356,49 @@ class Enterprise(City): """ -class SimpleModel(SimpleEquality, metaclass=ABCMeta): +class SimpleModel(Model, metaclass=ABCMeta): """Provides basic methods for non-location models""" - raw: Dict[str, Union[bool, str, int]] - ip_address: str + _ip_address: IPAddress _network: Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] - _prefix_len: int + _prefix_len: Optional[int] - def __init__(self, raw: Dict[str, Union[bool, str, int]]) -> None: - self.raw = raw - self._network = None - self._prefix_len = cast(int, raw.get("prefix_len")) - self.ip_address = cast(str, raw.get("ip_address")) + def __init__( + self, + ip_address: IPAddress, + network: Optional[str], + prefix_len: Optional[int], + ) -> None: + if network: + self._network = ipaddress.ip_network(network, False) + self._prefix_len = self._network.prefixlen + else: + # This case is for MMDB lookups where performance is paramount. + # This is why we don't generate the network unless .network is + # used. + self._network = None + self._prefix_len = prefix_len + self._ip_address = ip_address def __repr__(self) -> str: - return f"{self.__module__}.{self.__class__.__name__}({self.raw})" + d = self.to_dict() + d.pop("ip_address", None) + return ( + f"{self.__module__}.{self.__class__.__name__}(" + + repr(str(self._ip_address)) + + ", " + + ", ".join(f"{k}={repr(v)}" for k, v in d.items()) + + ")" + ) + + @property + def ip_address(self): + """The IP address for the record""" + if not isinstance( + self._ip_address, (ipaddress.IPv4Address, ipaddress.IPv6Address) + ): + self._ip_address = ipaddress.ip_address(self._ip_address) + return self._ip_address @property def network(self) -> Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]: @@ -425,14 +487,27 @@ class AnonymousIP(SimpleModel): is_residential_proxy: bool is_tor_exit_node: bool - def __init__(self, raw: Dict[str, bool]) -> None: - super().__init__(raw) # type: ignore - self.is_anonymous = raw.get("is_anonymous", False) - self.is_anonymous_vpn = raw.get("is_anonymous_vpn", False) - self.is_hosting_provider = raw.get("is_hosting_provider", False) - self.is_public_proxy = raw.get("is_public_proxy", False) - self.is_residential_proxy = raw.get("is_residential_proxy", False) - self.is_tor_exit_node = raw.get("is_tor_exit_node", False) + def __init__( + self, + ip_address: IPAddress, + *, + is_anonymous: bool = False, + is_anonymous_vpn: bool = False, + is_hosting_provider: bool = False, + is_public_proxy: bool = False, + is_residential_proxy: bool = False, + is_tor_exit_node: bool = False, + network: Optional[str] = None, + prefix_len: Optional[int] = None, + **_, + ) -> None: + super().__init__(ip_address, network, prefix_len) + self.is_anonymous = is_anonymous + self.is_anonymous_vpn = is_anonymous_vpn + self.is_hosting_provider = is_hosting_provider + self.is_public_proxy = is_public_proxy + self.is_residential_proxy = is_residential_proxy + self.is_tor_exit_node = is_tor_exit_node class ASN(SimpleModel): @@ -472,14 +547,19 @@ class ASN(SimpleModel): autonomous_system_organization: Optional[str] # pylint:disable=too-many-arguments,too-many-positional-arguments - def __init__(self, raw: Dict[str, Union[str, int]]) -> None: - super().__init__(raw) - self.autonomous_system_number = cast( - Optional[int], raw.get("autonomous_system_number") - ) - self.autonomous_system_organization = cast( - Optional[str], raw.get("autonomous_system_organization") - ) + def __init__( + self, + ip_address: IPAddress, + *, + autonomous_system_number: Optional[int] = None, + autonomous_system_organization: Optional[str] = None, + network: Optional[str] = None, + prefix_len: Optional[int] = None, + **_, + ) -> None: + super().__init__(ip_address, network, prefix_len) + self.autonomous_system_number = autonomous_system_number + self.autonomous_system_organization = autonomous_system_organization class ConnectionType(SimpleModel): @@ -518,9 +598,17 @@ class ConnectionType(SimpleModel): connection_type: Optional[str] - def __init__(self, raw: Dict[str, Union[str, int]]) -> None: - super().__init__(raw) - self.connection_type = cast(Optional[str], raw.get("connection_type")) + def __init__( + self, + ip_address: IPAddress, + *, + connection_type: Optional[str] = None, + network: Optional[str] = None, + prefix_len: Optional[int] = None, + **_, + ) -> None: + super().__init__(ip_address, network, prefix_len) + self.connection_type = connection_type class Domain(SimpleModel): @@ -552,9 +640,17 @@ class Domain(SimpleModel): domain: Optional[str] - def __init__(self, raw: Dict[str, Union[str, int]]) -> None: - super().__init__(raw) - self.domain = cast(Optional[str], raw.get("domain")) + def __init__( + self, + ip_address: IPAddress, + *, + domain: Optional[str] = None, + network: Optional[str] = None, + prefix_len: Optional[int] = None, + **_, + ) -> None: + super().__init__(ip_address, network, prefix_len) + self.domain = domain class ISP(ASN): @@ -624,9 +720,28 @@ class ISP(ASN): organization: Optional[str] # pylint:disable=too-many-arguments,too-many-positional-arguments - def __init__(self, raw: Dict[str, Union[str, int]]) -> None: - super().__init__(raw) - self.isp = cast(Optional[str], raw.get("isp")) - self.mobile_country_code = cast(Optional[str], raw.get("mobile_country_code")) - self.mobile_network_code = cast(Optional[str], raw.get("mobile_network_code")) - self.organization = cast(Optional[str], raw.get("organization")) + def __init__( + self, + ip_address: IPAddress, + *, + autonomous_system_number: Optional[int] = None, + autonomous_system_organization: Optional[str] = None, + isp: Optional[str] = None, + mobile_country_code: Optional[str] = None, + mobile_network_code: Optional[str] = None, + organization: Optional[str] = None, + network: Optional[str] = None, + prefix_len: Optional[int] = None, + **_, + ) -> None: + super().__init__( + autonomous_system_number=autonomous_system_number, + autonomous_system_organization=autonomous_system_organization, + ip_address=ip_address, + network=network, + prefix_len=prefix_len, + ) + self.isp = isp + self.mobile_country_code = mobile_country_code + self.mobile_network_code = mobile_network_code + self.organization = organization diff --git a/geoip2/records.py b/geoip2/records.py index 4481a702..17c1f37b 100644 --- a/geoip2/records.py +++ b/geoip2/records.py @@ -13,14 +13,14 @@ from abc import ABCMeta from typing import Dict, Optional, Type, Sequence, Union -from geoip2.mixins import SimpleEquality +from geoip2._internal import Model -class Record(SimpleEquality, metaclass=ABCMeta): +class Record(Model, metaclass=ABCMeta): """All records are subclasses of the abstract class ``Record``.""" def __repr__(self) -> str: - args = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items()) + args = ", ".join(f"{k}={v!r}" for k, v in self.to_dict().items()) return f"{self.__module__}.{self.__class__.__name__}({args})" @@ -32,8 +32,8 @@ class PlaceRecord(Record, metaclass=ABCMeta): def __init__( self, - locales: Optional[Sequence[str]] = None, - names: Optional[Dict[str, str]] = None, + locales: Optional[Sequence[str]], + names: Optional[Dict[str, str]], ) -> None: if locales is None: locales = ["en"] @@ -93,7 +93,8 @@ class City(PlaceRecord): def __init__( self, - locales: Optional[Sequence[str]] = None, + locales: Optional[Sequence[str]], + *, confidence: Optional[int] = None, geoname_id: Optional[int] = None, names: Optional[Dict[str, str]] = None, @@ -147,7 +148,8 @@ class Continent(PlaceRecord): def __init__( self, - locales: Optional[Sequence[str]] = None, + locales: Optional[Sequence[str]], + *, code: Optional[str] = None, geoname_id: Optional[int] = None, names: Optional[Dict[str, str]] = None, @@ -217,7 +219,8 @@ class Country(PlaceRecord): def __init__( self, - locales: Optional[Sequence[str]] = None, + locales: Optional[Sequence[str]], + *, confidence: Optional[int] = None, geoname_id: Optional[int] = None, is_in_european_union: bool = False, @@ -298,7 +301,8 @@ class RepresentedCountry(Country): def __init__( self, - locales: Optional[Sequence[str]] = None, + locales: Optional[Sequence[str]], + *, confidence: Optional[int] = None, geoname_id: Optional[int] = None, is_in_european_union: bool = False, @@ -310,7 +314,12 @@ def __init__( ) -> None: self.type = type super().__init__( - locales, confidence, geoname_id, is_in_european_union, iso_code, names + locales, + confidence=confidence, + geoname_id=geoname_id, + is_in_european_union=is_in_european_union, + iso_code=iso_code, + names=names, ) @@ -391,6 +400,7 @@ class Location(Record): def __init__( self, + *, average_income: Optional[int] = None, accuracy_radius: Optional[int] = None, latitude: Optional[float] = None, @@ -425,7 +435,7 @@ class MaxMind(Record): queries_remaining: Optional[int] - def __init__(self, queries_remaining: Optional[int] = None, **_) -> None: + def __init__(self, *, queries_remaining: Optional[int] = None, **_) -> None: self.queries_remaining = queries_remaining @@ -460,7 +470,7 @@ class Postal(Record): confidence: Optional[int] def __init__( - self, code: Optional[str] = None, confidence: Optional[int] = None, **_ + self, *, code: Optional[str] = None, confidence: Optional[int] = None, **_ ) -> None: self.code = code self.confidence = confidence @@ -519,7 +529,8 @@ class Subdivision(PlaceRecord): def __init__( self, - locales: Optional[Sequence[str]] = None, + locales: Optional[Sequence[str]], + *, confidence: Optional[int] = None, geoname_id: Optional[int] = None, iso_code: Optional[str] = None, @@ -552,7 +563,9 @@ def __new__( return obj def __init__( - self, locales: Optional[Sequence[str]], *subdivisions # pylint:disable=W0613 + self, + locales: Optional[Sequence[str]], + *subdivisions, # pylint:disable=W0613 ) -> None: self._locales = locales super().__init__() @@ -825,7 +838,7 @@ class Traits(Record): autonomous_system_organization: Optional[str] connection_type: Optional[str] domain: Optional[str] - ip_address: Optional[str] + _ip_address: Optional[str] is_anonymous: bool is_anonymous_proxy: bool is_anonymous_vpn: bool @@ -848,6 +861,7 @@ class Traits(Record): def __init__( self, + *, autonomous_system_number: Optional[int] = None, autonomous_system_organization: Optional[str] = None, connection_type: Optional[str] = None, @@ -895,7 +909,7 @@ def __init__( self.static_ip_score = static_ip_score self.user_type = user_type self.user_count = user_count - self.ip_address = ip_address + self._ip_address = ip_address if network is None: self._network = None else: @@ -905,6 +919,15 @@ def __init__( # much more performance sensitive than web service users. self._prefix_len = prefix_len + @property + def ip_address(self): + """The IP address for the record""" + if not isinstance( + self._ip_address, (ipaddress.IPv4Address, ipaddress.IPv6Address) + ): + self._ip_address = ipaddress.ip_address(self._ip_address) + return self._ip_address + @property def network(self) -> Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]: """The network for the record""" diff --git a/geoip2/webservice.py b/geoip2/webservice.py index b390e94e..3158d735 100644 --- a/geoip2/webservice.py +++ b/geoip2/webservice.py @@ -350,7 +350,7 @@ async def _response_for( if status != 200: raise self._exception_for_error(status, content_type, body, uri) decoded_body = self._handle_success(body, uri) - return model_class(decoded_body, locales=self._locales) + return model_class(self._locales, **decoded_body) async def close(self): """Close underlying session @@ -499,7 +499,7 @@ def _response_for( if status != 200: raise self._exception_for_error(status, content_type, body, uri) decoded_body = self._handle_success(body, uri) - return model_class(decoded_body, locales=self._locales) + return model_class(self._locales, **decoded_body) def close(self): """Close underlying session diff --git a/tests/database_test.py b/tests/database_test.py index 73279dcd..3008f574 100644 --- a/tests/database_test.py +++ b/tests/database_test.py @@ -35,7 +35,7 @@ def test_unknown_address(self) -> None: reader = geoip2.database.Reader("tests/data/test-data/GeoIP2-City-Test.mmdb") with self.assertRaisesRegex( geoip2.errors.AddressNotFoundError, - "The address 10.10.10.10 is not in the " "database.", + "The address 10.10.10.10 is not in the database.", ): reader.city("10.10.10.10") reader.close() @@ -56,7 +56,7 @@ def test_wrong_database(self) -> None: reader = geoip2.database.Reader("tests/data/test-data/GeoIP2-City-Test.mmdb") with self.assertRaisesRegex( TypeError, - "The country method cannot be used with " "the GeoIP2-City database", + "The country method cannot be used with the GeoIP2-City database", ): reader.country("1.1.1.1") reader.close() @@ -64,7 +64,7 @@ def test_wrong_database(self) -> None: def test_invalid_address(self) -> None: reader = geoip2.database.Reader("tests/data/test-data/GeoIP2-City-Test.mmdb") with self.assertRaisesRegex( - ValueError, "u?'invalid' does not appear to be an " "IPv4 or IPv6 address" + ValueError, "u?'invalid' does not appear to be an IPv4 or IPv6 address" ): reader.city("invalid") reader.close() @@ -82,7 +82,7 @@ def test_anonymous_ip(self) -> None: self.assertEqual(record.is_public_proxy, False) self.assertEqual(record.is_residential_proxy, False) self.assertEqual(record.is_tor_exit_node, False) - self.assertEqual(record.ip_address, ip_address) + self.assertEqual(record.ip_address, ipaddress.ip_address(ip_address)) self.assertEqual(record.network, ipaddress.ip_network("1.2.0.0/16")) reader.close() @@ -99,7 +99,7 @@ def test_anonymous_ip_all_set(self) -> None: self.assertEqual(record.is_public_proxy, True) self.assertEqual(record.is_residential_proxy, True) self.assertEqual(record.is_tor_exit_node, True) - self.assertEqual(record.ip_address, ip_address) + self.assertEqual(record.ip_address, ipaddress.ip_address(ip_address)) self.assertEqual(record.network, ipaddress.ip_network("81.2.69.0/24")) reader.close() @@ -113,7 +113,7 @@ def test_asn(self) -> None: self.assertEqual(record.autonomous_system_number, 1221) self.assertEqual(record.autonomous_system_organization, "Telstra Pty Ltd") - self.assertEqual(record.ip_address, ip_address) + self.assertEqual(record.ip_address, ipaddress.ip_address(ip_address)) self.assertEqual(record.network, ipaddress.ip_network("1.128.0.0/11")) self.assertRegex( @@ -156,12 +156,12 @@ def test_connection_type(self) -> None: ) self.assertEqual(record.connection_type, "Cellular") - self.assertEqual(record.ip_address, ip_address) + self.assertEqual(record.ip_address, ipaddress.ip_address(ip_address)) self.assertEqual(record.network, ipaddress.ip_network("1.0.1.0/24")) self.assertRegex( str(record), - r"ConnectionType\(\{.*Cellular.*\}\)", + r"ConnectionType\(.*Cellular.*\)", "ConnectionType str representation is reasonable", ) @@ -171,7 +171,9 @@ def test_country(self) -> None: reader = geoip2.database.Reader("tests/data/test-data/GeoIP2-Country-Test.mmdb") record = reader.country("81.2.69.160") self.assertEqual( - record.traits.ip_address, "81.2.69.160", "IP address is added to model" + record.traits.ip_address, + ipaddress.ip_address("81.2.69.160"), + "IP address is added to model", ) self.assertEqual(record.traits.network, ipaddress.ip_network("81.2.69.160/27")) self.assertEqual(record.country.is_in_european_union, False) @@ -192,12 +194,12 @@ def test_domain(self) -> None: self.assertEqual(record, eval(repr(record)), "Domain repr can be eval'd") self.assertEqual(record.domain, "maxmind.com") - self.assertEqual(record.ip_address, ip_address) + self.assertEqual(record.ip_address, ipaddress.ip_address(ip_address)) self.assertEqual(record.network, ipaddress.ip_network("1.2.0.0/16")) self.assertRegex( str(record), - r"Domain\(\{.*maxmind.com.*\}\)", + r"Domain\(.*maxmind.com.*\)", "Domain str representation is reasonable", ) @@ -217,7 +219,7 @@ def test_enterprise(self) -> None: self.assertEqual(record.registered_country.is_in_european_union, False) self.assertEqual(record.traits.connection_type, "Cable/DSL") self.assertTrue(record.traits.is_legitimate_proxy) - self.assertEqual(record.traits.ip_address, ip_address) + self.assertEqual(record.traits.ip_address, ipaddress.ip_address(ip_address)) self.assertEqual( record.traits.network, ipaddress.ip_network("74.209.16.0/20") ) @@ -242,12 +244,12 @@ def test_isp(self) -> None: self.assertEqual(record.autonomous_system_organization, "Telstra Pty Ltd") self.assertEqual(record.isp, "Telstra Internet") self.assertEqual(record.organization, "Telstra Internet") - self.assertEqual(record.ip_address, ip_address) + self.assertEqual(record.ip_address, ipaddress.ip_address(ip_address)) self.assertEqual(record.network, ipaddress.ip_network("1.128.0.0/11")) self.assertRegex( str(record), - r"ISP\(\{.*Telstra.*\}\)", + r"ISP\(.*Telstra.*\)", "ISP str representation is reasonable", ) @@ -261,7 +263,9 @@ def test_context_manager(self) -> None: "tests/data/test-data/GeoIP2-Country-Test.mmdb" ) as reader: record = reader.country("81.2.69.160") - self.assertEqual(record.traits.ip_address, "81.2.69.160") + self.assertEqual( + record.traits.ip_address, ipaddress.ip_address("81.2.69.160") + ) @patch("maxminddb.open_database") def test_modes(self, mock_open) -> None: diff --git a/tests/models_test.py b/tests/models_test.py index 9f4fa72e..3f72ec27 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals import sys +import ipaddress from typing import Dict import unittest @@ -13,6 +14,9 @@ class TestModels(unittest.TestCase): + def setUp(self): + self.maxDiff = 20_000 + def test_insights_full(self) -> None: raw = { "city": { @@ -85,7 +89,6 @@ def test_insights_full(self) -> None: "is_satellite_provider": True, "is_tor_exit_node": True, "isp": "Comcast", - "network_speed": "cable/DSL", "organization": "Blorg", "static_ip_score": 1.3, "user_count": 2, @@ -93,7 +96,7 @@ def test_insights_full(self) -> None: }, } - model = geoip2.models.Insights(raw) + model = geoip2.models.Insights(["en"], **raw) # type: ignore self.assertEqual( type(model), geoip2.models.Insights, "geoip2.models.Insights object" ) @@ -131,7 +134,7 @@ def test_insights_full(self) -> None: self.assertEqual( type(model.traits), geoip2.records.Traits, "geoip2.records.Traits object" ) - self.assertEqual(model.raw, raw, "raw method returns raw input") + self.assertEqual(model.to_dict(), raw, "to_dict() method matches raw input") self.assertEqual( model.subdivisions[0].iso_code, "MN", "div 1 has correct iso_code" ) @@ -172,7 +175,7 @@ def test_insights_full(self) -> None: self.assertRegex( str(model), - r"^geoip2.models.Insights\(\{.*geoname_id.*\}, \[.*en.*\]\)", + r"^geoip2.models.Insights\(\[.*en.*\]\, .*geoname_id.*\)", "Insights str representation looks reasonable", ) @@ -205,7 +208,7 @@ def test_insights_full(self) -> None: self.assertEqual(model.traits.static_ip_score, 1.3) def test_insights_min(self) -> None: - model = geoip2.models.Insights({"traits": {"ip_address": "5.6.7.8"}}) + model = geoip2.models.Insights(["en"], traits={"ip_address": "5.6.7.8"}) self.assertEqual( type(model), geoip2.models.Insights, "geoip2.models.Insights object" ) @@ -236,8 +239,7 @@ def test_insights_min(self) -> None: self.assertEqual( type(model.subdivisions.most_specific), geoip2.records.Subdivision, - "geoip2.records.Subdivision object returned even" - "when none are available.", + "geoip2.records.Subdivision object returned even when none are available.", ) self.assertEqual( model.subdivisions.most_specific.names, {}, "Empty names hash returned" @@ -265,7 +267,7 @@ def test_city_full(self) -> None: "is_satellite_provider": True, }, } - model = geoip2.models.City(raw) + model = geoip2.models.City(["en"], **raw) # type: ignore self.assertEqual(type(model), geoip2.models.City, "geoip2.models.City object") self.assertEqual( type(model.city), geoip2.records.City, "geoip2.records.City object" @@ -291,7 +293,9 @@ def test_city_full(self) -> None: self.assertEqual( type(model.traits), geoip2.records.Traits, "geoip2.records.Traits object" ) - self.assertEqual(model.raw, raw, "raw method returns raw input") + self.assertEqual( + model.to_dict(), raw, "to_dict method output matches raw input" + ) self.assertEqual(model.continent.geoname_id, 42, "continent geoname_id is 42") self.assertEqual(model.continent.code, "NA", "continent code is NA") self.assertEqual( @@ -339,57 +343,58 @@ def test_city_full(self) -> None: True, "traits is_setellite_provider is True", ) - self.assertEqual(model.raw, raw, "raw method produces raw output") + self.assertEqual(model.to_dict(), raw, "to_dict method matches raw input") self.assertRegex( - str(model), r"^geoip2.models.City\(\{.*geoname_id.*\}, \[.*en.*\]\)" + str(model), r"^geoip2.models.City\(\[.*en.*\], .*geoname_id.*\)" ) self.assertFalse(model == True, "__eq__ does not blow up on weird input") def test_unknown_keys(self) -> None: model = geoip2.models.City( - { - "city": {"invalid": 0}, - "continent": { - "invalid": 0, - "names": {"invalid": 0}, - }, - "country": { - "invalid": 0, - "names": {"invalid": 0}, - }, - "location": {"invalid": 0}, - "postal": {"invalid": 0}, - "subdivisions": [ - { - "invalid": 0, - "names": { - "invalid": 0, - }, - }, - ], - "registered_country": { + ["en"], + city={"invalid": 0}, + continent={ + "invalid": 0, + "names": {"invalid": 0}, + }, + country={ + "invalid": 0, + "names": {"invalid": 0}, + }, + location={"invalid": 0}, + postal={"invalid": 0}, + subdivisions=[ + { "invalid": 0, "names": { "invalid": 0, }, }, - "represented_country": { + ], + registered_country={ + "invalid": 0, + "names": { "invalid": 0, - "names": { - "invalid": 0, - }, }, - "traits": {"ip_address": "1.2.3.4", "invalid": "blah"}, - "unk_base": {"blah": 1}, - } + }, + represented_country={ + "invalid": 0, + "names": { + "invalid": 0, + }, + }, + traits={"ip_address": "1.2.3.4", "invalid": "blah"}, + unk_base={"blah": 1}, ) with self.assertRaises(AttributeError): model.unk_base # type: ignore with self.assertRaises(AttributeError): model.traits.invalid # type: ignore - self.assertEqual(model.traits.ip_address, "1.2.3.4", "correct ip") + self.assertEqual( + model.traits.ip_address, ipaddress.ip_address("1.2.3.4"), "correct ip" + ) class TestNames(unittest.TestCase): @@ -422,7 +427,7 @@ class TestNames(unittest.TestCase): } def test_names(self) -> None: - model = geoip2.models.Country(self.raw, locales=["sq", "ar"]) + model = geoip2.models.Country(["sq", "ar"], **self.raw) self.assertEqual( model.continent.names, self.raw["continent"]["names"], @@ -435,7 +440,7 @@ def test_names(self) -> None: ) def test_three_locales(self) -> None: - model = geoip2.models.Country(self.raw, locales=["fr", "zh-CN", "en"]) + model = geoip2.models.Country(locales=["fr", "zh-CN", "en"], **self.raw) self.assertEqual( model.continent.name, "北美洲", @@ -444,16 +449,16 @@ def test_three_locales(self) -> None: self.assertEqual(model.country.name, "États-Unis", "country name is in French") def test_two_locales(self) -> None: - model = geoip2.models.Country(self.raw, locales=["ak", "fr"]) + model = geoip2.models.Country(locales=["ak", "fr"], **self.raw) self.assertEqual( model.continent.name, None, - "continent name is undef (no Akan or French " "available)", + "continent name is undef (no Akan or French available)", ) self.assertEqual(model.country.name, "États-Unis", "country name is in French") def test_unknown_locale(self) -> None: - model = geoip2.models.Country(self.raw, locales=["aa"]) + model = geoip2.models.Country(locales=["aa"], **self.raw) self.assertEqual( model.continent.name, None, "continent name is undef (no Afar available)" ) @@ -462,7 +467,7 @@ def test_unknown_locale(self) -> None: ) def test_german(self) -> None: - model = geoip2.models.Country(self.raw, locales=["de"]) + model = geoip2.models.Country(locales=["de"], **self.raw) self.assertEqual( model.continent.name, "Nordamerika", "Correct german name for continent" ) diff --git a/tests/webservice_test.py b/tests/webservice_test.py index c50e5fc3..0d6cc496 100644 --- a/tests/webservice_test.py +++ b/tests/webservice_test.py @@ -110,7 +110,7 @@ def test_country_ok(self): country.traits.network, ipaddress.ip_network("1.2.3.0/24"), "network" ) self.assertTrue(country.traits.is_anycast) - self.assertEqual(country.raw, self.country, "raw response is correct") + self.assertEqual(country.to_dict(), self.country, "raw response is correct") def test_me(self): self.httpserver.expect_request( @@ -147,7 +147,7 @@ def test_200_error(self): def test_bad_ip_address(self): with self.assertRaisesRegex( - ValueError, "'1.2.3' does not appear to be an IPv4 " "or IPv6 address" + ValueError, "'1.2.3' does not appear to be an IPv4 or IPv6 address" ): self.run_client(self.client.country("1.2.3")) @@ -165,7 +165,6 @@ def test_no_body_error(self): self.run_client(self.client.country("1.2.3.7")) def test_weird_body_error(self): - self.httpserver.expect_request( "/geoip/v2.1/country/1.2.3.8", method="GET" ).respond_with_json( @@ -176,12 +175,11 @@ def test_weird_body_error(self): with self.assertRaisesRegex( HTTPError, - "Response contains JSON but it does not " "specify code or error keys", + "Response contains JSON but it does not specify code or error keys", ): self.run_client(self.client.country("1.2.3.8")) def test_bad_body_error(self): - self.httpserver.expect_request( "/geoip/v2.1/country/1.2.3.9", method="GET" ).respond_with_data( @@ -206,7 +204,6 @@ def test_500_error(self): self.run_client(self.client.country("1.2.3.10")) def test_300_error(self): - self.httpserver.expect_request( "/geoip/v2.1/country/1.2.3.11", method="GET" ).respond_with_data( @@ -360,6 +357,7 @@ def setUp(self): self.client_class = Client self.client = Client(42, "abcdef123456") self.client._base_uri = self.httpserver.url_for("/geoip/v2.1") + self.maxDiff = 20_000 def run_client(self, v): return v @@ -371,6 +369,7 @@ def setUp(self): self.client_class = AsyncClient self.client = AsyncClient(42, "abcdef123456") self.client._base_uri = self.httpserver.url_for("/geoip/v2.1") + self.maxDiff = 20_000 def tearDown(self): self._loop.run_until_complete(self.client.close())