diff --git a/ipinfo_django/middleware.py b/ipinfo_django/middleware.py index ce21d83..3d31354 100644 --- a/ipinfo_django/middleware.py +++ b/ipinfo_django/middleware.py @@ -119,3 +119,23 @@ def __init__(self, get_response): ipinfo_token = getattr(settings, "IPINFO_TOKEN", None) ipinfo_settings = getattr(settings, "IPINFO_SETTINGS", {}) self.ipinfo = ipinfo.getHandlerAsyncCore(ipinfo_token, **ipinfo_settings) + + +class IPinfoPlusMiddleware(IPinfoMiddleware): + def __init__(self, get_response): + super().__init__(get_response=get_response) + + ipinfo_token = getattr(settings, "IPINFO_TOKEN", None) + ipinfo_settings = getattr(settings, "IPINFO_SETTINGS", {}) + self.ipinfo = ipinfo.getHandlerPlus(ipinfo_token, **ipinfo_settings) + + +class IPinfoAsyncPlusMiddleware(IPinfoAsyncMiddleware): + sync_capable = False + async_capable = True + + def __init__(self, get_response): + super().__init__(get_response=get_response) + ipinfo_token = getattr(settings, "IPINFO_TOKEN", None) + ipinfo_settings = getattr(settings, "IPINFO_SETTINGS", {}) + self.ipinfo = ipinfo.getHandlerAsyncPlus(ipinfo_token, **ipinfo_settings) diff --git a/tests/conftest.py b/tests/conftest.py index f34ebb9..9a14bbe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,3 +41,17 @@ def ipinfo_async_core_middleware(settings): settings.MIDDLEWARE = [ "ipinfo_django.middleware.IPinfoAsyncCoreMiddleware", ] + + +@pytest.fixture +def ipinfo_plus_middleware(settings): + settings.MIDDLEWARE = [ + "ipinfo_django.middleware.IPinfoPlusMiddleware", + ] + + +@pytest.fixture +def ipinfo_async_plus_middleware(settings): + settings.MIDDLEWARE = [ + "ipinfo_django.middleware.IPinfoAsyncPlusMiddleware", + ] diff --git a/tests/test_async_plus_middleware.py b/tests/test_async_plus_middleware.py new file mode 100644 index 0000000..86b746d --- /dev/null +++ b/tests/test_async_plus_middleware.py @@ -0,0 +1,43 @@ +from http import HTTPStatus +from unittest import mock + +import pytest +from ipinfo.details import Details + + +@pytest.mark.asyncio +async def test_middleware_appends_ip_info(async_client, ipinfo_async_plus_middleware): + with mock.patch("ipinfo.AsyncHandlerPlus.getDetails") as mocked_getDetails: + mocked_getDetails.return_value = Details({"ip": "127.0.0.1"}) + res = await async_client.get("/test_view/") + assert res.status_code == HTTPStatus.OK + assert b"For testing: 127.0.0.1" in res.content + + +@pytest.mark.asyncio +async def test_middleware_filters(async_client, ipinfo_async_plus_middleware): + res = await async_client.get("/test_view/", USER_AGENT="some bot") + assert res.status_code == HTTPStatus.OK + assert b"Request filtered." in res.content + + +@pytest.mark.asyncio +async def test_middleware_behind_proxy(async_client, ipinfo_async_plus_middleware): + with mock.patch("ipinfo.AsyncHandlerPlus.getDetails") as mocked_getDetails: + mocked_getDetails.return_value = Details({"ip": "93.44.186.197"}) + res = await async_client.get("/test_view/", X_FORWARDED_FOR="93.44.186.197") + + mocked_getDetails.assert_called_once_with("93.44.186.197") + assert res.status_code == HTTPStatus.OK + assert b"For testing: 93.44.186.197" in res.content + + +@pytest.mark.asyncio +async def test_middleware_not_behind_proxy(async_client, ipinfo_async_plus_middleware): + with mock.patch("ipinfo.AsyncHandlerPlus.getDetails") as mocked_getDetails: + mocked_getDetails.return_value = Details({"ip": "127.0.0.1"}) + res = await async_client.get("/test_view/") + + mocked_getDetails.assert_called_once_with("127.0.0.1") + assert res.status_code == HTTPStatus.OK + assert b"For testing: 127.0.0.1" in res.content diff --git a/tests/test_plus_middleware.py b/tests/test_plus_middleware.py new file mode 100644 index 0000000..4dbbc45 --- /dev/null +++ b/tests/test_plus_middleware.py @@ -0,0 +1,38 @@ +from http import HTTPStatus +from unittest import mock + +from ipinfo.details import Details + + +def test_middleware_appends_ip_info(client, ipinfo_plus_middleware): + with mock.patch("ipinfo.HandlerPlus.getDetails") as mocked_getDetails: + mocked_getDetails.return_value = Details({"ip": "127.0.0.1"}) + res = client.get("/test_view/") + assert res.status_code == HTTPStatus.OK + assert b"For testing: 127.0.0.1" in res.content + + +def test_middleware_filters(client, ipinfo_plus_middleware): + res = client.get("/test_view/", HTTP_USER_AGENT="some bot") + assert res.status_code == HTTPStatus.OK + assert b"Request filtered." in res.content + + +def test_middleware_behind_proxy(client, ipinfo_plus_middleware): + with mock.patch("ipinfo.HandlerPlus.getDetails") as mocked_getDetails: + mocked_getDetails.return_value = Details({"ip": "93.44.186.197"}) + res = client.get("/test_view/", HTTP_X_FORWARDED_FOR="93.44.186.197") + + mocked_getDetails.assert_called_once_with("93.44.186.197") + assert res.status_code == HTTPStatus.OK + assert b"For testing: 93.44.186.197" in res.content + + +def test_middleware_not_behind_proxy(client, ipinfo_plus_middleware): + with mock.patch("ipinfo.HandlerPlus.getDetails") as mocked_getDetails: + mocked_getDetails.return_value = Details({"ip": "127.0.0.1"}) + res = client.get("/test_view/") + + mocked_getDetails.assert_called_once_with("127.0.0.1") + assert res.status_code == HTTPStatus.OK + assert b"For testing: 127.0.0.1" in res.content