Skip to content
Merged

test #114

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ LITELLM_LOG=FAKE_LOG_LEVEL
HASH_SALT="FAKE_HASH_SALT"
HASH_ALGO="FAKE_HASH_ALGO"
AUTH_TOKEN_EXPIRATION=9999
DATA_COLLECTION_HOST_PREFIX="fake_prefix"
DATA_COLLECTION_ORIGIN_PREFIX="fake_prefix"
2 changes: 1 addition & 1 deletion k8s/welearn-api/values.dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ config:
nonSensitive:
PG_HOST: dev-lab-projects-backend.postgres.database.azure.com
TIKA_URL_BASE: https://tika.k8s.lp-i.dev/
DATA_COLLECTION_HOST_PREFIX: welearn
DATA_COLLECTION_ORIGIN_PREFIX: welearn
allowedHostsRegexes:
mainUrl: |-
https:\/\/welearn\.k8s\.lp-i\.dev
Expand Down
2 changes: 1 addition & 1 deletion k8s/welearn-api/values.prod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ config:
nonSensitive:
PG_HOST: prod-prod-projects-backend.postgres.database.azure.com
TIKA_URL_BASE: https://tika.k8s.lp-i.org/
DATA_COLLECTION_HOST_PREFIX: workshop
DATA_COLLECTION_ORIGIN_PREFIX: workshop
allowedHostsRegexes:
alphaUrls: |-
https://[a-zA-Z0-9-]*\.alpha-welearn\.lp-i\.org
Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ env =
RUN_ENV=development
TIKA_URL_BASE=https://tika.example.com
USE_CACHED_SETTINGS=True
DATA_COLLECTION_HOST_PREFIX=workshop
DATA_COLLECTION_ORIGIN_PREFIX=workshop

filterwarnings =
ignore:.*U.*mode is deprecated:DeprecationWarning
2 changes: 1 addition & 1 deletion src/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Settings(BaseSettings):
"""

BACKEND_CORS_ORIGINS_REGEX: str = CLIENT_ORIGINS_REGEX
DATA_COLLECTION_HOST_PREFIX: str
DATA_COLLECTION_ORIGIN_PREFIX: str

def get_api_version(self, cls):
return {
Expand Down
23 changes: 14 additions & 9 deletions src/app/services/data_collection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import uuid
from datetime import datetime, timedelta
from typing import Any
Expand Down Expand Up @@ -25,13 +26,15 @@


class DataCollection:
def __init__(self, host: str):
def __init__(self, origin: str):
is_campaign_active = self.get_campaign_state()
host_settings = settings.DATA_COLLECTION_HOST_PREFIX
self.should_collect = host.startswith(host_settings) and is_campaign_active
origin_settings = settings.DATA_COLLECTION_ORIGIN_PREFIX.strip()

self.should_collect = origin.startswith(origin_settings) and is_campaign_active
logger.info(
"data_collection: host_settings=%s, is_campaign=%s, should_collect=%s",
host_settings,
"data_collection: origin=%s, origin_settings=%s, is_campaign=%s, should_collect=%s",
origin,
origin_settings,
is_campaign_active,
self.should_collect,
)
Expand Down Expand Up @@ -114,7 +117,9 @@ async def register_document_click(


def get_data_collection_service(request: Request) -> DataCollection:
host = request.url.hostname
if host is None:
return DataCollection(host="")
return DataCollection(host=host)
origin = request.headers["origin"]
stripped_origin = re.sub(r"https?://www\.|https?://", "", origin).strip("/")

if stripped_origin is None:
return DataCollection(origin="")
return DataCollection(origin=stripped_origin)
6 changes: 3 additions & 3 deletions src/app/tests/api/api_v1/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def test_chat(self, chat_mock, *mocks):
response = client.post(
f"{settings.API_V1_STR}/qna/chat/answer",
json=JSON,
headers={"X-API-Key": "test"},
headers={"X-API-Key": "test", "origin": "test"},
)

response_json = response.json()
Expand All @@ -97,7 +97,7 @@ async def test_chat_empty_history(self, chat_mock, *mocks):
response = client.post(
f"{settings.API_V1_STR}/qna/chat/answer",
json=JSON_NO_HIST,
headers={"X-API-Key": "test"},
headers={"X-API-Key": "test", "origin": "test"},
)

chat_mock.assert_called_with(
Expand Down Expand Up @@ -140,7 +140,7 @@ async def test_chat_not_supported_lang(self, chat_mock, *mocks):
response = client.post(
f"{settings.API_V1_STR}/qna/chat/answer",
json=JSON_NO_HIST,
headers={"X-API-Key": "test"},
headers={"X-API-Key": "test", "origin": "test"},
)
self.assertEqual(response.status_code, 400)

Expand Down
16 changes: 8 additions & 8 deletions src/app/tests/services/test_data_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_campaign_active(self, mock_get_campaign):
mock_campaign.is_active = True
mock_get_campaign.return_value = mock_campaign

dc = DataCollection(host="workshop.example.com")
dc = DataCollection(origin="workshop.example.com")

self.assertTrue(dc.should_collect)

Expand All @@ -32,17 +32,17 @@ def test_campaign_inactive(self, mock_get_campaign):
mock_campaign.is_active = False
mock_get_campaign.return_value = mock_campaign

dc = DataCollection(host="workshop.example.com")
dc = DataCollection(origin="workshop.example.com")

self.assertFalse(dc.should_collect)

@patch("src.app.services.data_collection.get_current_data_collection_campaign")
def test_non_workshop_host(self, mock_get_campaign):
def test_non_workshop_origin(self, mock_get_campaign):
mock_campaign = MagicMock()
mock_campaign.is_active = True
mock_get_campaign.return_value = mock_campaign

dc = DataCollection(host="example.com")
dc = DataCollection(origin="example.com")

self.assertFalse(dc.should_collect)

Expand Down Expand Up @@ -74,7 +74,7 @@ async def test_register_chat_data_success(
mock_write_query.return_value = conversation_id
mock_write_answer.return_value = message_id

dc = DataCollection(host="workshop.example.com")
dc = DataCollection(origin="workshop.example.com")

result = await dc.register_chat_data(
session_id=str(uuid.uuid4()),
Expand All @@ -95,7 +95,7 @@ async def test_register_chat_data_success(
@patch("src.app.services.data_collection.get_user_from_session_id")
@patch("src.app.services.data_collection.get_current_data_collection_campaign")
async def test_register_chat_data_no_session(self, *args):
dc = DataCollection(host="workshop.example.com")
dc = DataCollection(origin="workshop.example.com")

with self.assertRaises(HTTPException) as ctx:
await dc.register_chat_data(
Expand All @@ -119,7 +119,7 @@ async def test_register_chat_data_no_session(self, *args):
async def test_register_chat_data_user_not_found(self, mock_campaign, _, __):
mock_campaign.return_value = MagicMock(is_active=True)

dc = DataCollection(host="workshop.example.com")
dc = DataCollection(origin="workshop.example.com")

with self.assertRaises(HTTPException) as ctx:
await dc.register_chat_data(
Expand All @@ -144,7 +144,7 @@ class TestRegisterDocumentClick(unittest.IsolatedAsyncioTestCase):
async def test_register_document_click(self, mock_campaign, mock_update, _):
mock_campaign.return_value = MagicMock(is_active=True)

dc = DataCollection(host="workshop.example.com")
dc = DataCollection(origin="workshop.example.com")

doc_id = uuid.uuid4()
message_id = uuid.uuid4()
Expand Down