From 90fd419a2ac6c232e635c5a2b631e1c71e6f5f89 Mon Sep 17 00:00:00 2001 From: Anne Rodenburg Date: Thu, 20 Feb 2025 15:34:17 +0100 Subject: [PATCH 1/2] Added new version and included functions in v1 --- DELETEMEv1.ipynb | 115 ++++++++++++++++ DELETEMEv2.ipynb | 169 +++++++++++++++++++++++ needle/v1/collections/__init__.py | 31 ++++- needle/v1/collections/files.py | 86 +++++++++++- needle/v2/__init__.py | 47 +++++++ needle/v2/collections.py | 219 ++++++++++++++++++++++++++++++ needle/v2/files.py | 190 ++++++++++++++++++++++++++ needle/v2/models.py | 126 +++++++++++++++++ 8 files changed, 980 insertions(+), 3 deletions(-) create mode 100644 DELETEMEv1.ipynb create mode 100644 DELETEMEv2.ipynb create mode 100644 needle/v2/__init__.py create mode 100644 needle/v2/collections.py create mode 100644 needle/v2/files.py create mode 100644 needle/v2/models.py diff --git a/DELETEMEv1.ipynb b/DELETEMEv1.ipynb new file mode 100644 index 0000000..92159c2 --- /dev/null +++ b/DELETEMEv1.ipynb @@ -0,0 +1,115 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from needle.v1 import NeedleClient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### TEST attributes\n", + "api_key = \"\"\n", + "collection_name = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "needle_manager = NeedleClient(api_key)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Set the right collection\n", + "\n", + "collection_id = needle_manager.collections.identify_collection(collection_name)\n", + "needle_manager.collections.update_collection_id(collection_id)\n", + "needle_manager.collections.files.update_collection_id(collection_id)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "https://www.producthunt.com/products/needle-3 fle_01JMHW440SNS9619NZ01FJCX7X pending\n" + ] + } + ], + "source": [ + "# Get an overview of all existing files in the respective collection\n", + "needle_manager.collections.files.list(collection_id)\n", + "\n", + "for a in needle_manager.collections.files.list(collection_id):\n", + " print(a.name, a.id, a.status)\n", + " break\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[6], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# Update existing url if necessary\u001b[39;00m\n\u001b[0;32m 2\u001b[0m producthunt_url \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://www.producthunt.com/products/needle-3\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m----> 3\u001b[0m \u001b[43mneedle_manager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollections\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfiles\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_url\u001b[49m\u001b[43m(\u001b[49m\u001b[43mproducthunt_url\u001b[49m\u001b[43m \u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverwrite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32mc:\\Users\\anner\\Documents\\3_Developing\\GIT_repositories\\needle-python\\needle\\v1\\collections\\files.py:175\u001b[0m, in \u001b[0;36mNeedleCollectionsFiles.add_url\u001b[1;34m(self, file_url, collection_id, overwrite)\u001b[0m\n\u001b[0;32m 168\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdelete([file_id])\n\u001b[0;32m 170\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madd(\n\u001b[0;32m 171\u001b[0m collection_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcollection_id,\n\u001b[0;32m 172\u001b[0m files\u001b[38;5;241m=\u001b[39m[FileToAdd(name\u001b[38;5;241m=\u001b[39mfile_url, url\u001b[38;5;241m=\u001b[39mfile_url)]\n\u001b[0;32m 173\u001b[0m )\n\u001b[1;32m--> 175\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wait_for_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 176\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSuccessfully added \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile_url\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to collection \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcollection_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[1;32mc:\\Users\\anner\\Documents\\3_Developing\\GIT_repositories\\needle-python\\needle\\v1\\collections\\files.py:188\u001b[0m, in \u001b[0;36mNeedleCollectionsFiles._wait_for_indexing\u001b[1;34m(self, check_interval)\u001b[0m\n\u001b[0;32m 186\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mall\u001b[39m(f\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mindexed\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m all_files):\n\u001b[0;32m 187\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[1;32m--> 188\u001b[0m \u001b[43msleep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcheck_interval\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Update existing url if necessary\n", + "producthunt_url = \"https://www.producthunt.com/products/needle-3\"\n", + "needle_manager.collections.files.add_url(producthunt_url , overwrite=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/DELETEMEv2.ipynb b/DELETEMEv2.ipynb new file mode 100644 index 0000000..6811b14 --- /dev/null +++ b/DELETEMEv2.ipynb @@ -0,0 +1,169 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from needle.v2 import NeedleClient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### TEST attributes\n", + "api_key = \"\"\n", + "collection_name = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "needle_manager = NeedleClient(api_key)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Set the right collection\n", + "\n", + "\n", + "collection_id = needle_manager.collections.identify_collection(collection_name)\n", + "needle_manager.collections.update_collection_id(collection_id)\n", + "needle_manager.files.update_collection_id(collection_id)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "staging_operation_companies.txt fle_01JM4N6N287WYGPDNQECJRSWK7 indexed\n", + "staging_operation_api_tracking.txt fle_01JM4N6N296TWY6Z0022Y1GDXN indexed\n", + "staging_operation_datasources.txt fle_01JM4N6N29FNG7SFY9AAMNAT5F indexed\n", + "staging_flatfiles_countries.txt fle_01JM4N6N29J4VJQG3V88J6SJR3 indexed\n", + "staging_hubspot_engagements.txt fle_01JM4N6N29K5N68KNRD8DHNSP0 indexed\n", + "staging_operation_users.txt fle_01JM4N6N29VAYK5DP0GACJ25M1 indexed\n", + "staging_hubspot_owners.txt fle_01JM4N6N2ACMZN04WEECJA903T indexed\n", + "staging_personio_attendances.txt fle_01JM4N6N2AMEZ4H3QX34PSA5JP indexed\n", + "staging_personio_absences.txt fle_01JM4N6N2AQXNCR1DKJ1ENRRAV indexed\n", + "staging_sevdesk_invoices.txt fle_01JM4N6N2AV099XS4730HZ8QXW indexed\n", + "staging_sevdesk_contacts.txt fle_01JM4N6N2AX9FC3WQ8FMMGZJPE indexed\n", + "staging_sevdesk_receipts.txt fle_01JM4N6N2B7E8KSV3EE61S96GG indexed\n", + "staging_flatfiles_portals_suntrol.txt fle_01JM4N6N2BD7R4HGHXK7WS55BY indexed\n", + "staging_hubspot_pipelines.txt fle_01JM4N6N2BECJ1G76JEK08D0DJ indexed\n", + "staging_flatfiles_target.txt fle_01JM4N6N2BFR7QYJC8SNHA34X4 indexed\n", + "staging_personio_employees.txt fle_01JM4N6N2BNTPX7WTXXDWK79JW indexed\n", + "staging_hubspot_feedback_submissions.txt fle_01JM4N6N2BXSA0A81FBBMZN1CY indexed\n", + "staging_flatfiles_jira_issues.txt fle_01JM4N6N2BY9PRJEZZT1J8EP89 indexed\n", + "staging_flatfiles_ledgers_booked.txt fle_01JM4N6N2BYYAZVX727S0ZWJT0 indexed\n", + "staging_flatfiles_azure_costs.txt fle_01JM4N6N2BZNHQEF3M6C9YZF3V indexed\n", + "staging_sevdesk_payments.txt fle_01JM4N6N2CA7254V51DN54WQ60 indexed\n", + "staging_flatfiles_personnel_expenses.txt fle_01JM4N6N2CBMGY066NB5TJX3AS indexed\n", + "staging_operation_datasources_events.txt fle_01JM4N6N2CC8B6EVRC5GBA3SZD indexed\n", + "staging_hubspot_companies.txt fle_01JM4N6N2CES2H90EN2X5FTVBR indexed\n", + "staging_hubspot_products.txt fle_01JM4N6N2CH53MRR56DFMQ16KV indexed\n", + "staging_flatfiles_portals_solytic_1.txt fle_01JM4N6N2CNBWVY91BSZHW8767 indexed\n", + "staging_hubspot_deals.txt fle_01JM4N6N2CVCQMTTTQETYGS50G indexed\n", + "staging_hubspot_contacts.txt fle_01JM4N6N2CVM8TFWVFJVV7AV4M indexed\n", + "staging_operation_sites.txt fle_01JM4N6N2CWZ03RCSX8327DQZ7 indexed\n", + "staging_hubspot_tickets.txt fle_01JM4N6N2CY1XT5FYTY9QN51PT indexed\n", + "staging_hubspot_line_items.txt fle_01JM4N6N2CZQ331QMRA7KQTPVN indexed\n", + "staging_flatfiles_booked_revenue.txt fle_01JM4N6N2DA2CWYR88PMBH9G7Q indexed\n" + ] + } + ], + "source": [ + "# Get an overview of all existing files in the respective collection\n", + "needle_manager.files.list(collection_id)\n", + "\n", + "for a in needle_manager.files.list(collection_id):\n", + " print(a.name, a.id, a.status)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "for a in needle_manager.files.list(collection_id):\n", + " if a.name == \"https://www.producthunt.com/products/needle-3\":\n", + " print(a.name, a.id, a.status)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[14], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# Update existing url if necessary\u001b[39;00m\n\u001b[0;32m 2\u001b[0m producthunt_url \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://www.producthunt.com/products/needle-3\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m----> 3\u001b[0m \u001b[43mneedle_manager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfiles\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_file_from_url\u001b[49m\u001b[43m(\u001b[49m\u001b[43mproducthunt_url\u001b[49m\u001b[43m \u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverwrite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32mc:\\Users\\anner\\Documents\\3_Developing\\GIT_repositories\\needle-python\\needle\\v2\\files.py:171\u001b[0m, in \u001b[0;36madd_file_from_url\u001b[1;34m(self, file_url, overwrite)\u001b[0m\n\u001b[0;32m 165\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madd(\n\u001b[0;32m 166\u001b[0m collection_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcollection_id,\n\u001b[0;32m 167\u001b[0m files\u001b[38;5;241m=\u001b[39m[FileToAdd(name\u001b[38;5;241m=\u001b[39mfile_url, url\u001b[38;5;241m=\u001b[39mfile_url)]\n\u001b[0;32m 168\u001b[0m )\n\u001b[0;32m 170\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wait_for_indexing(file_name \u001b[38;5;241m=\u001b[39m file_url)\n\u001b[1;32m--> 171\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSuccessfully added \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile_url\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to collection \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcollection_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[1;32mc:\\Users\\anner\\Documents\\3_Developing\\GIT_repositories\\needle-python\\needle\\v2\\files.py:190\u001b[0m, in \u001b[0;36m_wait_for_indexing\u001b[1;34m(self, check_interval, file_name)\u001b[0m\n\u001b[0;32m 0\u001b[0m \n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Update existing url if necessary\n", + "producthunt_url = \"https://www.producthunt.com/products/needle-3\"\n", + "needle_manager.files.add_file_from_url(producthunt_url , overwrite=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the relevant data from the rag\n", + "rag_prompt = \"\"\"Retrieve people who commented or reviewd the product.Exclude unrelated content.\"\"\"\n", + "rag_results = needle_manager.collection.search_collection(rag_prompt)\n", + "print(rag_results)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/needle/v1/collections/__init__.py b/needle/v1/collections/__init__.py index 3f29ee6..e5c9b88 100644 --- a/needle/v1/collections/__init__.py +++ b/needle/v1/collections/__init__.py @@ -4,7 +4,7 @@ from typing import Optional -import requests +import requests # type: ignore from needle.v1.models import ( NeedleConfig, @@ -73,7 +73,7 @@ def create(self, name: str, file_ids: Optional[list[str]] = None): created_at=c.get("created_at"), updated_at=c.get("updated_at"), ) - + def get(self, collection_id: str): """ Retrieves a collection by its ID. @@ -131,6 +131,33 @@ def list(self): for c in body.get("result") ] + def identify_collection(self, collection_name: str) -> str: + """Get collection ID from name. + + Args: + collection_name: Name of the collection + + Returns: + Collection ID + + Raises: + ValueError: If collection name is not found + """ + collections_list = self.list() + collection_info = {collection.name: collection.id for collection in collections_list} + + if collection_name not in collection_info: + raise ValueError(f"Collection '{collection_name}' not found") + return collection_info[collection_name] + + def update_collection_id(self, collection_id: str) -> None: + """Set the self.collection ID + + Args: + collection_id: New collection ID to use + """ + self.collection_id = collection_id + def search( self, collection_id: str, diff --git a/needle/v1/collections/files.py b/needle/v1/collections/files.py index 8a9fd5c..a23feac 100644 --- a/needle/v1/collections/files.py +++ b/needle/v1/collections/files.py @@ -4,10 +4,12 @@ """ from dataclasses import asdict +from typing import List, Optional import requests from needle.v1 import NeedleBaseClient, NeedleConfig from needle.v1.models import FileToAdd, Error, CollectionFile +from time import sleep class NeedleCollectionsFiles(NeedleBaseClient): @@ -27,6 +29,7 @@ def __init__(self, config: NeedleConfig, headers: dict): self.session = requests.Session() self.session.headers.update(headers) self.session.timeout = 120 + self.collection_id = None def add(self, collection_id: str, files: list[FileToAdd]): """ @@ -67,6 +70,14 @@ def add(self, collection_id: str, files: list[FileToAdd]): for cf in body.get("result") ] + def update_collection_id(self, collection_id: str) -> None: + """Set the self.collection ID + + Args: + collection_id: New collection ID to use + """ + self.collection_id = collection_id + def list(self, collection_id: str): """ Lists all files in a specified collection. @@ -80,7 +91,9 @@ def list(self, collection_id: str): Raises: Error: If the API request fails. """ - endpoint = f"{self.collections_endpoint}/{collection_id}/files" + if collection_id is not None: + self.collection_id = collection_id + endpoint = f"{self.collections_endpoint}/{self.collection_id}/files" resp = self.session.get(endpoint) body = resp.json() if resp.status_code >= 400: @@ -102,3 +115,74 @@ def list(self, collection_id: str): ) for cf in body.get("result") ] + + def delete(self, file_ids: List[str]) -> None: + """Delete files from current collection. + + Args: + file_ids: List of file IDs to delete + + Raises: + ValueError: If no collection ID is set + requests.RequestException: If deletion request fails + """ + if not self.collection_id: + raise ValueError("No collection ID set") + + endpoint = f"{self.collections_endpoint}/{self.collection_id}/files" + req_body = {"file_ids": file_ids} + resp = self.session.delete(endpoint, json=req_body) + + if resp.status_code >= 400: + raise requests.RequestException(f"File deletion failed: {resp.status_code} - {resp.text}") + else: + f"File deletion successful" + + def add_url( + self, + file_url: str, + collection_id: Optional[str] = None, + overwrite: bool = False + ) -> None: + """Add files to collection from URL. + + Args: + file_url: URL of the file to add + collection_name: Optional collection name to use instead of ID + overwrite: Whether to overwrite existing files + + Raises: + ValueError: If collection cannot be determined or file exists + """ + if not self.collection_id or collection_id: + raise ValueError("Collection must be specified by ID or name") + elif collection_id: + self.collection_id = collection_id + + files_info = {f.name: f.id for f in self.list(self.collection_id)} + file_id = files_info.get(file_url) + + if file_id: + if not overwrite: + raise ValueError(f"File {file_url} already exists in collection {self.collection_id}") + self.delete([file_id]) + + self.add( + collection_id=self.collection_id, + files=[FileToAdd(name=file_url, url=file_url)] + ) + + self._wait_for_indexing() + print(f"Successfully added {file_url} to collection {self.collection_id}") + + def _wait_for_indexing(self, check_interval: int = 5) -> None: + """Wait for all files in collection to be indexed. + + Args: + check_interval: Seconds to wait between checks + """ + while True: + all_files = self.list(self.collection_id) + if all(f.status == "indexed" for f in all_files): + break + sleep(check_interval) diff --git a/needle/v2/__init__.py b/needle/v2/__init__.py new file mode 100644 index 0000000..e66be3c --- /dev/null +++ b/needle/v2/__init__.py @@ -0,0 +1,47 @@ +from .collections import NeedleCollections +from .files import NeedleFiles +from needle.v2.models import ( + NeedleConfig, + NeedleBaseClient, +) +from typing import Optional +import os +from urllib.parse import urlparse, urlunparse + +__all__ = ["NeedleCollections", "NeedleFiles"] + + +class NeedleClient(NeedleBaseClient): + """ + A client for interacting with the Needle API. + + This class provides a high-level interface for interacting with the Needle API, + including managing collections and performing searches. + + Initialize the client with an API key and an optional URL. + If no API key is provided, the client will use the `NEEDLE_API_KEY` environment variable. + If no URL is provided, the client will use the default Needle API URL, that is https://needle-ai.com. + + Attributes: + collections (NeedleCollections): A client for managing collections within the Needle API. + files (NeedleFiles): A client for managing files within the Needle API. + """ + + def __init__( + self, + api_key: Optional[str] = os.environ.get("NEEDLE_API_KEY"), + url: Optional[str] = "https://needle-ai.com", + _search_url: Optional[str] = None, + ): + if not _search_url: + parsed_url = urlparse(url) + new_netloc = f"search.{parsed_url.netloc}" + _search_url = urlunparse(parsed_url._replace(netloc=new_netloc)) + + config = NeedleConfig(api_key, url, search_url=_search_url) + headers = {"x-api-key": config.api_key} + super().__init__(config, headers) + + # sub clients + self.collections = NeedleCollections(config, headers) + self.files = NeedleFiles(config, headers) \ No newline at end of file diff --git a/needle/v2/collections.py b/needle/v2/collections.py new file mode 100644 index 0000000..d0ebb5f --- /dev/null +++ b/needle/v2/collections.py @@ -0,0 +1,219 @@ +""" +This module provides NeedleCollections class for interacting with +Needle API's collections endpoint. +""" + +from dataclasses import asdict +import requests # type: ignore + +from needle.v2.models import NeedleConfig, Error, Collection, SearchResult, CollectionStats, CollectionDataStats + +from typing import Any, Optional + +class NeedleCollections: + """ + A client for interacting with the Needle API's collections endpoint. + + This class provides methods to create and manage collections within the Needle API. + It uses a requests session to handle HTTP requests with a default timeout of 120 seconds. + """ + + def __init__(self, config: NeedleConfig, headers: dict): + self.config = config + self.headers = headers + self.collections_endpoint = f"{config.url}/api/v1/collections" + + # requests config + self.session = requests.Session() + self.session.headers.update(headers) + self.session.timeout = 120 + + def create(self, name: str, file_ids: Optional[list[str]] = None) -> Collection: + """ + Creates a new collection with the specified name and file IDs. + + Args: + name (str): The name of the collection. + file_ids (Optional[list[str]]): A list of file IDs to include in the collection. + + Returns: + Collection: The created collection object. + + Raises: + Error: If the API request fails. + """ + req_body = {"name": name, "file_ids": file_ids} + resp = self.session.post(self.collections_endpoint, json=req_body) + body = resp.json() + if resp.status_code >= 400: + error = body.get("error") + raise Error(**error) + c = body.get("result") + return Collection( + id=c.get("id"), + name=c.get("name"), + embedding_model=c.get("embedding_model"), + embedding_dimensions=c.get("embedding_dimensions"), + search_queries=c.get("search_queries"), + created_at=c.get("created_at"), + updated_at=c.get("updated_at"), + ) + + def list(self) -> list[Collection]: + """ + Lists all collections. + + Returns: + list[Collection]: A list of all collections. + + Raises: + Error: If the API request fails. + """ + resp = self.session.get(self.collections_endpoint) + body = resp.json() + if resp.status_code >= 400: + error = body.get("error") + raise Error(**error) + return [ + Collection( + id=c.get("id"), + name=c.get("name"), + embedding_model=c.get("embedding_model"), + embedding_dimensions=c.get("embedding_dimensions"), + search_queries=c.get("search_queries"), + created_at=c.get("created_at"), + updated_at=c.get("updated_at"), + ) + for c in body.get("result") + ] + + def get(self, collection_id: str) -> Collection: + """ + Retrieves a collection by its ID. + + Args: + collection_id (str): The ID of the collection to retrieve. + + Returns: + Collection: The retrieved collection object. + + Raises: + Error: If the API request fails. + """ + resp = self.session.get(f"{self.collections_endpoint}/{collection_id}") + body = resp.json() + if resp.status_code >= 400: + error = body.get("error") + raise Error(**error) + c = body.get("result") + return Collection( + id=c.get("id"), + name=c.get("name"), + embedding_model=c.get("embedding_model"), + embedding_dimensions=c.get("embedding_dimensions"), + search_queries=c.get("search_queries"), + created_at=c.get("created_at"), + updated_at=c.get("updated_at"), + ) + + + def identify_collection(self, collection_name: str) -> str: + """Get collection ID from name. + + Args: + collection_name: Name of the collection + + Returns: + Collection ID + + Raises: + ValueError: If collection name is not found + """ + collections_list = self.list() + collection_info = {collection.name: collection.id for collection in collections_list} + + if collection_name not in collection_info: + raise ValueError(f"Collection '{collection_name}' not found") + return collection_info[collection_name] + + def update_collection_id(self, collection_id: str) -> None: + """Set the self.collection ID + + Args: + collection_id: New collection ID to use + """ + self.collection_id = collection_id + + def search( + self, + collection_id: str, + text: str, + max_distance: Optional[float] = None, + top_k: Optional[int] = None, + ): + """ + Searches within a collection based on the provided parameters. + + Args: + params (SearchCollectionRequest): The search parameters. + + Returns: + list[dict]: The search results. + + Raises: + Error: If the API request fails. + """ + endpoint = f"{self.search_endpoint}/{collection_id}/search" + req_body = { + "text": text, + "max_distance": max_distance, + "top_k": top_k, + } + resp = self.session.post(endpoint, headers=self.headers, json=req_body) + body = resp.json() + if resp.status_code >= 400: + error = body.get("error") + raise Error(**error) + return [ + SearchResult( + content=r.get("content"), + file_id=r.get("file_id"), + ) + for r in body.get("result") + ] + + def get_stats(self, collection_id: str): + """ + Retrieves statistics of a collection. + + Args: + collection_id (str): The ID of the collection to retrieve statistics for. + + Returns: + dict: The collection statistics. + + Raises: + Error: If the API request fails. + """ + endpoint = f"{self.endpoint}/{collection_id}/stats" + resp = self.session.get(endpoint, headers=self.headers) + body = resp.json() + if resp.status_code >= 400: + error = body.get("error") + raise Error(**error) + + result = body.get("result") + data_stats = [ + CollectionDataStats( + status=ds.get("status"), + files=ds.get("files"), + bytes=ds.get("bytes"), + ) + for ds in result.get("data_stats") + ] + return CollectionStats( + data_stats=data_stats, + chunks_count=result.get("chunks_count"), + characters=result.get("characters"), + users=result.get("users"), + ) diff --git a/needle/v2/files.py b/needle/v2/files.py new file mode 100644 index 0000000..83c6e57 --- /dev/null +++ b/needle/v2/files.py @@ -0,0 +1,190 @@ +""" +This module provides NeedleFiles class for interacting with +Needle API's collection files endpoint. +""" + +from dataclasses import asdict +import requests +from time import sleep +from typing import List, Optional + +from needle.v2.models import NeedleConfig, FileToAdd, Error, CollectionFile + + +class NeedleFiles: + """ + A client for interacting with the Needle API's collection files endpoint. + + This class provides methods to create and manage collection files within the Needle API. + It uses a requests session to handle HTTP requests with a default timeout of 120 seconds. + """ + + def __init__(self, config: NeedleConfig, headers: dict): + self.config = config + self.headers = headers + self.collections_endpoint = f"{config.url}/api/v1/collections" + + # requests config + self.session = requests.Session() + self.session.headers.update(headers) + self.session.timeout = 120 + self.collection_id = None + + def add(self, collection_id: str, files: list[FileToAdd]): + """ + Adds files to a specified collection. Added files will be automatically indexed and after be available for search within the collection. + + Args: + collection_id (str): The ID of the collection to which files will be added. + files (list[FileToAdd]): A list of FileToAdd objects representing the files to be added. + + Returns: + list[CollectionFile]: A list of CollectionFile objects representing the added files. + + Raises: + Error: If the API request fails. + """ + endpoint = f"{self.collections_endpoint}/{collection_id}/files" + req_body = {"files": [asdict(f) for f in files]} + resp = self.session.post(endpoint, json=req_body) + body = resp.json() + if resp.status_code >= 400: + error = body.get("error") + raise Error(**error) + return [ + CollectionFile( + id=cf.get("id"), + name=cf.get("name"), + type=cf.get("type"), + url=cf.get("url"), + user_id=cf.get("user_id"), + connector_id=cf.get("connector_id"), + size=cf.get("size"), + md5_hash=cf.get("md5_hash"), + created_at=cf.get("created_at"), + updated_at=cf.get("updated_at"), + status=cf.get("status"), + ) + for cf in body.get("result") + ] + + def update_collection_id(self, collection_id: str) -> None: + """Set the self.collection ID + + Args: + collection_id: New collection ID to use + """ + self.collection_id = collection_id + + def list(self, collection_id: str = None) -> list[CollectionFile]: + """ + Lists all files in a specified collection. + + Args: + collection_id (str): The ID of the collection whose files will be listed. + + Returns: + list[CollectionFile]: A list of CollectionFile objects representing the files in the collection. + + Raises: + Error: If the API request fails. + """ + if collection_id is not None: + self.collection_id = collection_id + + endpoint = f"{self.collections_endpoint}/{self.collection_id}/files" + resp = self.session.get(endpoint) + body = resp.json() + if resp.status_code >= 400: + error = body.get("error") + raise Error(**error) + return [ + CollectionFile( + id=cf.get("id"), + name=cf.get("name"), + type=cf.get("type"), + url=cf.get("url"), + user_id=cf.get("user_id"), + connector_id=cf.get("connector_id"), + size=cf.get("size"), + md5_hash=cf.get("md5_hash"), + created_at=cf.get("created_at"), + updated_at=cf.get("updated_at"), + status=cf.get("status"), + ) + for cf in body.get("result") + ] + + def delete_file(self, file_ids: List[str]) -> None: + """Delete files from current collection. + + Args: + file_ids: List of file IDs to delete + + Raises: + ValueError: If no collection ID is set + requests.RequestException: If deletion request fails + """ + if not self.collection_id: + raise ValueError("No collection ID set") + + endpoint = f"{self.collections_endpoint}/{self.collection_id}/files" + req_body = {"file_ids": file_ids} + resp = self.session.delete(endpoint, json=req_body) + + if resp.status_code >= 400: + raise requests.RequestException(f"File deletion failed: {resp.status_code} - {resp.text}") + else: + f"File deletion successful" + + def add_file_from_url( + self, + file_url: str, + overwrite: bool = False + ) -> None: + """Add files to collection from URL. + + Args: + file_url: URL of the file to add + collection_name: Optional collection name to use instead of ID + overwrite: Whether to overwrite existing files + + Raises: + ValueError: If collection cannot be determined or file exists + """ + if not self.collection_id: + raise ValueError("Collection must be specified by ID or name") + + files_info = {f.name: f.id for f in self.list()} + file_id = files_info.get(file_url) + + if file_id: + if not overwrite: + raise ValueError(f"File {file_url} already exists in collection {self.collection_id}") + self.delete_file([file_id]) + + self.add( + collection_id=self.collection_id, + files=[FileToAdd(name=file_url, url=file_url)] + ) + + self._wait_for_indexing(file_name = file_url) + print(f"Successfully added {file_url} to collection {self.collection_id}") + + def _wait_for_indexing(self, check_interval: int = 5, file_name:str = None) -> None: + """Wait for all files in collection to be indexed. + + Args: + check_interval: Seconds to wait between checks + """ + if file_name is not None: + all_files = [f for f in self.list(self.collection_id) if f.name == file_name] + else: + all_files = self.list(self.collection_id) + while True: + if all(f.status in ("indexed", "error") for f in all_files): + error_files = [f.name for f in all_files if f.status == "error"] + if error_files: + print(f"Files failed to index: {', '.join(error_files)}") + break + sleep(check_interval) \ No newline at end of file diff --git a/needle/v2/models.py b/needle/v2/models.py new file mode 100644 index 0000000..77702b8 --- /dev/null +++ b/needle/v2/models.py @@ -0,0 +1,126 @@ +""" +This module contains the data models used in the Needle API client. +""" + +from typing import Any, Optional, Literal +from dataclasses import dataclass, asdict +import json + + +@dataclass(frozen=True) +class NeedleConfig: + """ + Configuration for the Needle API client. + """ + + api_key: Optional[str] + url: Optional[str] + search_url: Optional[str] + + +@dataclass(frozen=True) +class NeedleBaseClient: + """ + Base client for interacting with the Needle API. Not intended to be used directly. + """ + + config: NeedleConfig + headers: dict + + +FileType = Literal["application/pdf"] + + +@dataclass() +class Error(BaseException): + """ + Error response from the Needle API. An object of this class is raised when an API request fails. + """ + + code: int + message: str + data: Optional[Any] = None + + def __str__(self): + return json.dumps(asdict(self), allow_nan=False) + + +@dataclass(frozen=True) +class Collection: + """ + Represents a collection in the Needle API. + A collection is a group of files that can be searched together. + """ + + id: str + name: str + embedding_model: str + embedding_dimensions: str + search_queries: str + created_at: str + updated_at: str + + +@dataclass(frozen=True) +class FileToAdd: + """ + Represents file metadata, used when adding new files to a collection in the Needle API. + """ + + name: str + url: str + + +CollectionFileStatus = Literal["pending", "indexed", "error"] + + +@dataclass(frozen=True) +class CollectionFile: + """ + Represents a file in the Needle API. Note that a file can be part of multiple collections. + """ + + id: str + name: str + type: FileType + url: str + user_id: str + connector_id: str + size: int + md5_hash: str + created_at: str + updated_at: str + status: CollectionFileStatus + + +@dataclass(frozen=True) +class CollectionDataStats: + """ + Represents data statistics of a collection in the Needle API. + """ + + status: Optional[str] + files: int + bytes: int + + +@dataclass(frozen=True) +class CollectionStats: + """ + Represents statistics of a collection in the Needle API. + """ + + data_stats: list[CollectionDataStats] + chunks_count: int + characters: int + users: int + + +@dataclass(frozen=True) +class SearchResult: + """ + Represents a search result from the Needle API. + """ + + content: str + file_id: str From ec679937800f464461fc9276023e8f772d7906db Mon Sep 17 00:00:00 2001 From: Anne Rodenburg Date: Wed, 26 Feb 2025 12:16:09 +0100 Subject: [PATCH 2/2] Incorporated Comments --- DELETEMEv1.ipynb | 115 ---------------- DELETEMEv2.ipynb | 169 ----------------------- needle/v1/collections/__init__.py | 14 +- needle/v1/collections/files.py | 82 ++--------- needle/v2/__init__.py | 47 ------- needle/v2/collections.py | 219 ------------------------------ needle/v2/files.py | 190 -------------------------- needle/v2/models.py | 126 ----------------- 8 files changed, 16 insertions(+), 946 deletions(-) delete mode 100644 DELETEMEv1.ipynb delete mode 100644 DELETEMEv2.ipynb delete mode 100644 needle/v2/__init__.py delete mode 100644 needle/v2/collections.py delete mode 100644 needle/v2/files.py delete mode 100644 needle/v2/models.py diff --git a/DELETEMEv1.ipynb b/DELETEMEv1.ipynb deleted file mode 100644 index 92159c2..0000000 --- a/DELETEMEv1.ipynb +++ /dev/null @@ -1,115 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from needle.v1 import NeedleClient" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "### TEST attributes\n", - "api_key = \"\"\n", - "collection_name = \"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "needle_manager = NeedleClient(api_key)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# Set the right collection\n", - "\n", - "collection_id = needle_manager.collections.identify_collection(collection_name)\n", - "needle_manager.collections.update_collection_id(collection_id)\n", - "needle_manager.collections.files.update_collection_id(collection_id)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "https://www.producthunt.com/products/needle-3 fle_01JMHW440SNS9619NZ01FJCX7X pending\n" - ] - } - ], - "source": [ - "# Get an overview of all existing files in the respective collection\n", - "needle_manager.collections.files.list(collection_id)\n", - "\n", - "for a in needle_manager.collections.files.list(collection_id):\n", - " print(a.name, a.id, a.status)\n", - " break\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[6], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# Update existing url if necessary\u001b[39;00m\n\u001b[0;32m 2\u001b[0m producthunt_url \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://www.producthunt.com/products/needle-3\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m----> 3\u001b[0m \u001b[43mneedle_manager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollections\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfiles\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_url\u001b[49m\u001b[43m(\u001b[49m\u001b[43mproducthunt_url\u001b[49m\u001b[43m \u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverwrite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File \u001b[1;32mc:\\Users\\anner\\Documents\\3_Developing\\GIT_repositories\\needle-python\\needle\\v1\\collections\\files.py:175\u001b[0m, in \u001b[0;36mNeedleCollectionsFiles.add_url\u001b[1;34m(self, file_url, collection_id, overwrite)\u001b[0m\n\u001b[0;32m 168\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdelete([file_id])\n\u001b[0;32m 170\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madd(\n\u001b[0;32m 171\u001b[0m collection_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcollection_id,\n\u001b[0;32m 172\u001b[0m files\u001b[38;5;241m=\u001b[39m[FileToAdd(name\u001b[38;5;241m=\u001b[39mfile_url, url\u001b[38;5;241m=\u001b[39mfile_url)]\n\u001b[0;32m 173\u001b[0m )\n\u001b[1;32m--> 175\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wait_for_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 176\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSuccessfully added \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile_url\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to collection \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcollection_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[1;32mc:\\Users\\anner\\Documents\\3_Developing\\GIT_repositories\\needle-python\\needle\\v1\\collections\\files.py:188\u001b[0m, in \u001b[0;36mNeedleCollectionsFiles._wait_for_indexing\u001b[1;34m(self, check_interval)\u001b[0m\n\u001b[0;32m 186\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mall\u001b[39m(f\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mindexed\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m all_files):\n\u001b[0;32m 187\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[1;32m--> 188\u001b[0m \u001b[43msleep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcheck_interval\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[1;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "# Update existing url if necessary\n", - "producthunt_url = \"https://www.producthunt.com/products/needle-3\"\n", - "needle_manager.collections.files.add_url(producthunt_url , overwrite=True)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/DELETEMEv2.ipynb b/DELETEMEv2.ipynb deleted file mode 100644 index 6811b14..0000000 --- a/DELETEMEv2.ipynb +++ /dev/null @@ -1,169 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "from needle.v2 import NeedleClient" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "### TEST attributes\n", - "api_key = \"\"\n", - "collection_name = \"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "needle_manager = NeedleClient(api_key)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "# Set the right collection\n", - "\n", - "\n", - "collection_id = needle_manager.collections.identify_collection(collection_name)\n", - "needle_manager.collections.update_collection_id(collection_id)\n", - "needle_manager.files.update_collection_id(collection_id)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "staging_operation_companies.txt fle_01JM4N6N287WYGPDNQECJRSWK7 indexed\n", - "staging_operation_api_tracking.txt fle_01JM4N6N296TWY6Z0022Y1GDXN indexed\n", - "staging_operation_datasources.txt fle_01JM4N6N29FNG7SFY9AAMNAT5F indexed\n", - "staging_flatfiles_countries.txt fle_01JM4N6N29J4VJQG3V88J6SJR3 indexed\n", - "staging_hubspot_engagements.txt fle_01JM4N6N29K5N68KNRD8DHNSP0 indexed\n", - "staging_operation_users.txt fle_01JM4N6N29VAYK5DP0GACJ25M1 indexed\n", - "staging_hubspot_owners.txt fle_01JM4N6N2ACMZN04WEECJA903T indexed\n", - "staging_personio_attendances.txt fle_01JM4N6N2AMEZ4H3QX34PSA5JP indexed\n", - "staging_personio_absences.txt fle_01JM4N6N2AQXNCR1DKJ1ENRRAV indexed\n", - "staging_sevdesk_invoices.txt fle_01JM4N6N2AV099XS4730HZ8QXW indexed\n", - "staging_sevdesk_contacts.txt fle_01JM4N6N2AX9FC3WQ8FMMGZJPE indexed\n", - "staging_sevdesk_receipts.txt fle_01JM4N6N2B7E8KSV3EE61S96GG indexed\n", - "staging_flatfiles_portals_suntrol.txt fle_01JM4N6N2BD7R4HGHXK7WS55BY indexed\n", - "staging_hubspot_pipelines.txt fle_01JM4N6N2BECJ1G76JEK08D0DJ indexed\n", - "staging_flatfiles_target.txt fle_01JM4N6N2BFR7QYJC8SNHA34X4 indexed\n", - "staging_personio_employees.txt fle_01JM4N6N2BNTPX7WTXXDWK79JW indexed\n", - "staging_hubspot_feedback_submissions.txt fle_01JM4N6N2BXSA0A81FBBMZN1CY indexed\n", - "staging_flatfiles_jira_issues.txt fle_01JM4N6N2BY9PRJEZZT1J8EP89 indexed\n", - "staging_flatfiles_ledgers_booked.txt fle_01JM4N6N2BYYAZVX727S0ZWJT0 indexed\n", - "staging_flatfiles_azure_costs.txt fle_01JM4N6N2BZNHQEF3M6C9YZF3V indexed\n", - "staging_sevdesk_payments.txt fle_01JM4N6N2CA7254V51DN54WQ60 indexed\n", - "staging_flatfiles_personnel_expenses.txt fle_01JM4N6N2CBMGY066NB5TJX3AS indexed\n", - "staging_operation_datasources_events.txt fle_01JM4N6N2CC8B6EVRC5GBA3SZD indexed\n", - "staging_hubspot_companies.txt fle_01JM4N6N2CES2H90EN2X5FTVBR indexed\n", - "staging_hubspot_products.txt fle_01JM4N6N2CH53MRR56DFMQ16KV indexed\n", - "staging_flatfiles_portals_solytic_1.txt fle_01JM4N6N2CNBWVY91BSZHW8767 indexed\n", - "staging_hubspot_deals.txt fle_01JM4N6N2CVCQMTTTQETYGS50G indexed\n", - "staging_hubspot_contacts.txt fle_01JM4N6N2CVM8TFWVFJVV7AV4M indexed\n", - "staging_operation_sites.txt fle_01JM4N6N2CWZ03RCSX8327DQZ7 indexed\n", - "staging_hubspot_tickets.txt fle_01JM4N6N2CY1XT5FYTY9QN51PT indexed\n", - "staging_hubspot_line_items.txt fle_01JM4N6N2CZQ331QMRA7KQTPVN indexed\n", - "staging_flatfiles_booked_revenue.txt fle_01JM4N6N2DA2CWYR88PMBH9G7Q indexed\n" - ] - } - ], - "source": [ - "# Get an overview of all existing files in the respective collection\n", - "needle_manager.files.list(collection_id)\n", - "\n", - "for a in needle_manager.files.list(collection_id):\n", - " print(a.name, a.id, a.status)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "for a in needle_manager.files.list(collection_id):\n", - " if a.name == \"https://www.producthunt.com/products/needle-3\":\n", - " print(a.name, a.id, a.status)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[14], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# Update existing url if necessary\u001b[39;00m\n\u001b[0;32m 2\u001b[0m producthunt_url \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://www.producthunt.com/products/needle-3\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m----> 3\u001b[0m \u001b[43mneedle_manager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfiles\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_file_from_url\u001b[49m\u001b[43m(\u001b[49m\u001b[43mproducthunt_url\u001b[49m\u001b[43m \u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverwrite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File \u001b[1;32mc:\\Users\\anner\\Documents\\3_Developing\\GIT_repositories\\needle-python\\needle\\v2\\files.py:171\u001b[0m, in \u001b[0;36madd_file_from_url\u001b[1;34m(self, file_url, overwrite)\u001b[0m\n\u001b[0;32m 165\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madd(\n\u001b[0;32m 166\u001b[0m collection_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcollection_id,\n\u001b[0;32m 167\u001b[0m files\u001b[38;5;241m=\u001b[39m[FileToAdd(name\u001b[38;5;241m=\u001b[39mfile_url, url\u001b[38;5;241m=\u001b[39mfile_url)]\n\u001b[0;32m 168\u001b[0m )\n\u001b[0;32m 170\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wait_for_indexing(file_name \u001b[38;5;241m=\u001b[39m file_url)\n\u001b[1;32m--> 171\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSuccessfully added \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile_url\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to collection \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcollection_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[1;32mc:\\Users\\anner\\Documents\\3_Developing\\GIT_repositories\\needle-python\\needle\\v2\\files.py:190\u001b[0m, in \u001b[0;36m_wait_for_indexing\u001b[1;34m(self, check_interval, file_name)\u001b[0m\n\u001b[0;32m 0\u001b[0m \n", - "\u001b[1;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "# Update existing url if necessary\n", - "producthunt_url = \"https://www.producthunt.com/products/needle-3\"\n", - "needle_manager.files.add_file_from_url(producthunt_url , overwrite=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Get the relevant data from the rag\n", - "rag_prompt = \"\"\"Retrieve people who commented or reviewd the product.Exclude unrelated content.\"\"\"\n", - "rag_results = needle_manager.collection.search_collection(rag_prompt)\n", - "print(rag_results)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/needle/v1/collections/__init__.py b/needle/v1/collections/__init__.py index e5c9b88..8f25f63 100644 --- a/needle/v1/collections/__init__.py +++ b/needle/v1/collections/__init__.py @@ -3,8 +3,7 @@ """ from typing import Optional - -import requests # type: ignore +import requests from needle.v1.models import ( NeedleConfig, @@ -114,6 +113,7 @@ def list(self): Error: If the API request fails. """ resp = self.session.get(self.endpoint) + print(resp.json()) body = resp.json() if resp.status_code >= 400: error = body.get("error") @@ -149,15 +149,7 @@ def identify_collection(self, collection_name: str) -> str: if collection_name not in collection_info: raise ValueError(f"Collection '{collection_name}' not found") return collection_info[collection_name] - - def update_collection_id(self, collection_id: str) -> None: - """Set the self.collection ID - - Args: - collection_id: New collection ID to use - """ - self.collection_id = collection_id - + def search( self, collection_id: str, diff --git a/needle/v1/collections/files.py b/needle/v1/collections/files.py index a23feac..a524ecf 100644 --- a/needle/v1/collections/files.py +++ b/needle/v1/collections/files.py @@ -1,6 +1,6 @@ """ This module provides NeedleCollectionsFiles class for interacting with -Needle API's collectiton files endpoint. +Needle API's collection files endpoint. """ from dataclasses import asdict @@ -29,7 +29,6 @@ def __init__(self, config: NeedleConfig, headers: dict): self.session = requests.Session() self.session.headers.update(headers) self.session.timeout = 120 - self.collection_id = None def add(self, collection_id: str, files: list[FileToAdd]): """ @@ -50,6 +49,7 @@ def add(self, collection_id: str, files: list[FileToAdd]): req_body = {"files": [asdict(f) for f in files]} resp = self.session.post(endpoint, json=req_body) body = resp.json() + if resp.status_code >= 400: error = body.get("error") raise Error(**error) @@ -70,14 +70,6 @@ def add(self, collection_id: str, files: list[FileToAdd]): for cf in body.get("result") ] - def update_collection_id(self, collection_id: str) -> None: - """Set the self.collection ID - - Args: - collection_id: New collection ID to use - """ - self.collection_id = collection_id - def list(self, collection_id: str): """ Lists all files in a specified collection. @@ -96,6 +88,7 @@ def list(self, collection_id: str): endpoint = f"{self.collections_endpoint}/{self.collection_id}/files" resp = self.session.get(endpoint) body = resp.json() + if resp.status_code >= 400: error = body.get("error") raise Error(**error) @@ -116,73 +109,24 @@ def list(self, collection_id: str): for cf in body.get("result") ] - def delete(self, file_ids: List[str]) -> None: + def delete(self, collection_id: str, file_ids: List[str]) -> None: """Delete files from current collection. Args: - file_ids: List of file IDs to delete + collection_id (str): The ID of the collection from which files will be deleted. + file_ids (List[str]): List of file IDs to delete. Raises: - ValueError: If no collection ID is set - requests.RequestException: If deletion request fails + ValueError: If no collection ID is set. + Error: If the API request fails. """ - if not self.collection_id: + if not collection_id: raise ValueError("No collection ID set") - endpoint = f"{self.collections_endpoint}/{self.collection_id}/files" + endpoint = f"{self.collections_endpoint}/{collection_id}/files" req_body = {"file_ids": file_ids} resp = self.session.delete(endpoint, json=req_body) - - if resp.status_code >= 400: - raise requests.RequestException(f"File deletion failed: {resp.status_code} - {resp.text}") - else: - f"File deletion successful" - - def add_url( - self, - file_url: str, - collection_id: Optional[str] = None, - overwrite: bool = False - ) -> None: - """Add files to collection from URL. - - Args: - file_url: URL of the file to add - collection_name: Optional collection name to use instead of ID - overwrite: Whether to overwrite existing files - - Raises: - ValueError: If collection cannot be determined or file exists - """ - if not self.collection_id or collection_id: - raise ValueError("Collection must be specified by ID or name") - elif collection_id: - self.collection_id = collection_id - - files_info = {f.name: f.id for f in self.list(self.collection_id)} - file_id = files_info.get(file_url) - - if file_id: - if not overwrite: - raise ValueError(f"File {file_url} already exists in collection {self.collection_id}") - self.delete([file_id]) - - self.add( - collection_id=self.collection_id, - files=[FileToAdd(name=file_url, url=file_url)] - ) - self._wait_for_indexing() - print(f"Successfully added {file_url} to collection {self.collection_id}") - - def _wait_for_indexing(self, check_interval: int = 5) -> None: - """Wait for all files in collection to be indexed. - - Args: - check_interval: Seconds to wait between checks - """ - while True: - all_files = self.list(self.collection_id) - if all(f.status == "indexed" for f in all_files): - break - sleep(check_interval) + if resp.status_code >= 400: + error = resp.json().get("error") + raise Error(**error) diff --git a/needle/v2/__init__.py b/needle/v2/__init__.py deleted file mode 100644 index e66be3c..0000000 --- a/needle/v2/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -from .collections import NeedleCollections -from .files import NeedleFiles -from needle.v2.models import ( - NeedleConfig, - NeedleBaseClient, -) -from typing import Optional -import os -from urllib.parse import urlparse, urlunparse - -__all__ = ["NeedleCollections", "NeedleFiles"] - - -class NeedleClient(NeedleBaseClient): - """ - A client for interacting with the Needle API. - - This class provides a high-level interface for interacting with the Needle API, - including managing collections and performing searches. - - Initialize the client with an API key and an optional URL. - If no API key is provided, the client will use the `NEEDLE_API_KEY` environment variable. - If no URL is provided, the client will use the default Needle API URL, that is https://needle-ai.com. - - Attributes: - collections (NeedleCollections): A client for managing collections within the Needle API. - files (NeedleFiles): A client for managing files within the Needle API. - """ - - def __init__( - self, - api_key: Optional[str] = os.environ.get("NEEDLE_API_KEY"), - url: Optional[str] = "https://needle-ai.com", - _search_url: Optional[str] = None, - ): - if not _search_url: - parsed_url = urlparse(url) - new_netloc = f"search.{parsed_url.netloc}" - _search_url = urlunparse(parsed_url._replace(netloc=new_netloc)) - - config = NeedleConfig(api_key, url, search_url=_search_url) - headers = {"x-api-key": config.api_key} - super().__init__(config, headers) - - # sub clients - self.collections = NeedleCollections(config, headers) - self.files = NeedleFiles(config, headers) \ No newline at end of file diff --git a/needle/v2/collections.py b/needle/v2/collections.py deleted file mode 100644 index d0ebb5f..0000000 --- a/needle/v2/collections.py +++ /dev/null @@ -1,219 +0,0 @@ -""" -This module provides NeedleCollections class for interacting with -Needle API's collections endpoint. -""" - -from dataclasses import asdict -import requests # type: ignore - -from needle.v2.models import NeedleConfig, Error, Collection, SearchResult, CollectionStats, CollectionDataStats - -from typing import Any, Optional - -class NeedleCollections: - """ - A client for interacting with the Needle API's collections endpoint. - - This class provides methods to create and manage collections within the Needle API. - It uses a requests session to handle HTTP requests with a default timeout of 120 seconds. - """ - - def __init__(self, config: NeedleConfig, headers: dict): - self.config = config - self.headers = headers - self.collections_endpoint = f"{config.url}/api/v1/collections" - - # requests config - self.session = requests.Session() - self.session.headers.update(headers) - self.session.timeout = 120 - - def create(self, name: str, file_ids: Optional[list[str]] = None) -> Collection: - """ - Creates a new collection with the specified name and file IDs. - - Args: - name (str): The name of the collection. - file_ids (Optional[list[str]]): A list of file IDs to include in the collection. - - Returns: - Collection: The created collection object. - - Raises: - Error: If the API request fails. - """ - req_body = {"name": name, "file_ids": file_ids} - resp = self.session.post(self.collections_endpoint, json=req_body) - body = resp.json() - if resp.status_code >= 400: - error = body.get("error") - raise Error(**error) - c = body.get("result") - return Collection( - id=c.get("id"), - name=c.get("name"), - embedding_model=c.get("embedding_model"), - embedding_dimensions=c.get("embedding_dimensions"), - search_queries=c.get("search_queries"), - created_at=c.get("created_at"), - updated_at=c.get("updated_at"), - ) - - def list(self) -> list[Collection]: - """ - Lists all collections. - - Returns: - list[Collection]: A list of all collections. - - Raises: - Error: If the API request fails. - """ - resp = self.session.get(self.collections_endpoint) - body = resp.json() - if resp.status_code >= 400: - error = body.get("error") - raise Error(**error) - return [ - Collection( - id=c.get("id"), - name=c.get("name"), - embedding_model=c.get("embedding_model"), - embedding_dimensions=c.get("embedding_dimensions"), - search_queries=c.get("search_queries"), - created_at=c.get("created_at"), - updated_at=c.get("updated_at"), - ) - for c in body.get("result") - ] - - def get(self, collection_id: str) -> Collection: - """ - Retrieves a collection by its ID. - - Args: - collection_id (str): The ID of the collection to retrieve. - - Returns: - Collection: The retrieved collection object. - - Raises: - Error: If the API request fails. - """ - resp = self.session.get(f"{self.collections_endpoint}/{collection_id}") - body = resp.json() - if resp.status_code >= 400: - error = body.get("error") - raise Error(**error) - c = body.get("result") - return Collection( - id=c.get("id"), - name=c.get("name"), - embedding_model=c.get("embedding_model"), - embedding_dimensions=c.get("embedding_dimensions"), - search_queries=c.get("search_queries"), - created_at=c.get("created_at"), - updated_at=c.get("updated_at"), - ) - - - def identify_collection(self, collection_name: str) -> str: - """Get collection ID from name. - - Args: - collection_name: Name of the collection - - Returns: - Collection ID - - Raises: - ValueError: If collection name is not found - """ - collections_list = self.list() - collection_info = {collection.name: collection.id for collection in collections_list} - - if collection_name not in collection_info: - raise ValueError(f"Collection '{collection_name}' not found") - return collection_info[collection_name] - - def update_collection_id(self, collection_id: str) -> None: - """Set the self.collection ID - - Args: - collection_id: New collection ID to use - """ - self.collection_id = collection_id - - def search( - self, - collection_id: str, - text: str, - max_distance: Optional[float] = None, - top_k: Optional[int] = None, - ): - """ - Searches within a collection based on the provided parameters. - - Args: - params (SearchCollectionRequest): The search parameters. - - Returns: - list[dict]: The search results. - - Raises: - Error: If the API request fails. - """ - endpoint = f"{self.search_endpoint}/{collection_id}/search" - req_body = { - "text": text, - "max_distance": max_distance, - "top_k": top_k, - } - resp = self.session.post(endpoint, headers=self.headers, json=req_body) - body = resp.json() - if resp.status_code >= 400: - error = body.get("error") - raise Error(**error) - return [ - SearchResult( - content=r.get("content"), - file_id=r.get("file_id"), - ) - for r in body.get("result") - ] - - def get_stats(self, collection_id: str): - """ - Retrieves statistics of a collection. - - Args: - collection_id (str): The ID of the collection to retrieve statistics for. - - Returns: - dict: The collection statistics. - - Raises: - Error: If the API request fails. - """ - endpoint = f"{self.endpoint}/{collection_id}/stats" - resp = self.session.get(endpoint, headers=self.headers) - body = resp.json() - if resp.status_code >= 400: - error = body.get("error") - raise Error(**error) - - result = body.get("result") - data_stats = [ - CollectionDataStats( - status=ds.get("status"), - files=ds.get("files"), - bytes=ds.get("bytes"), - ) - for ds in result.get("data_stats") - ] - return CollectionStats( - data_stats=data_stats, - chunks_count=result.get("chunks_count"), - characters=result.get("characters"), - users=result.get("users"), - ) diff --git a/needle/v2/files.py b/needle/v2/files.py deleted file mode 100644 index 83c6e57..0000000 --- a/needle/v2/files.py +++ /dev/null @@ -1,190 +0,0 @@ -""" -This module provides NeedleFiles class for interacting with -Needle API's collection files endpoint. -""" - -from dataclasses import asdict -import requests -from time import sleep -from typing import List, Optional - -from needle.v2.models import NeedleConfig, FileToAdd, Error, CollectionFile - - -class NeedleFiles: - """ - A client for interacting with the Needle API's collection files endpoint. - - This class provides methods to create and manage collection files within the Needle API. - It uses a requests session to handle HTTP requests with a default timeout of 120 seconds. - """ - - def __init__(self, config: NeedleConfig, headers: dict): - self.config = config - self.headers = headers - self.collections_endpoint = f"{config.url}/api/v1/collections" - - # requests config - self.session = requests.Session() - self.session.headers.update(headers) - self.session.timeout = 120 - self.collection_id = None - - def add(self, collection_id: str, files: list[FileToAdd]): - """ - Adds files to a specified collection. Added files will be automatically indexed and after be available for search within the collection. - - Args: - collection_id (str): The ID of the collection to which files will be added. - files (list[FileToAdd]): A list of FileToAdd objects representing the files to be added. - - Returns: - list[CollectionFile]: A list of CollectionFile objects representing the added files. - - Raises: - Error: If the API request fails. - """ - endpoint = f"{self.collections_endpoint}/{collection_id}/files" - req_body = {"files": [asdict(f) for f in files]} - resp = self.session.post(endpoint, json=req_body) - body = resp.json() - if resp.status_code >= 400: - error = body.get("error") - raise Error(**error) - return [ - CollectionFile( - id=cf.get("id"), - name=cf.get("name"), - type=cf.get("type"), - url=cf.get("url"), - user_id=cf.get("user_id"), - connector_id=cf.get("connector_id"), - size=cf.get("size"), - md5_hash=cf.get("md5_hash"), - created_at=cf.get("created_at"), - updated_at=cf.get("updated_at"), - status=cf.get("status"), - ) - for cf in body.get("result") - ] - - def update_collection_id(self, collection_id: str) -> None: - """Set the self.collection ID - - Args: - collection_id: New collection ID to use - """ - self.collection_id = collection_id - - def list(self, collection_id: str = None) -> list[CollectionFile]: - """ - Lists all files in a specified collection. - - Args: - collection_id (str): The ID of the collection whose files will be listed. - - Returns: - list[CollectionFile]: A list of CollectionFile objects representing the files in the collection. - - Raises: - Error: If the API request fails. - """ - if collection_id is not None: - self.collection_id = collection_id - - endpoint = f"{self.collections_endpoint}/{self.collection_id}/files" - resp = self.session.get(endpoint) - body = resp.json() - if resp.status_code >= 400: - error = body.get("error") - raise Error(**error) - return [ - CollectionFile( - id=cf.get("id"), - name=cf.get("name"), - type=cf.get("type"), - url=cf.get("url"), - user_id=cf.get("user_id"), - connector_id=cf.get("connector_id"), - size=cf.get("size"), - md5_hash=cf.get("md5_hash"), - created_at=cf.get("created_at"), - updated_at=cf.get("updated_at"), - status=cf.get("status"), - ) - for cf in body.get("result") - ] - - def delete_file(self, file_ids: List[str]) -> None: - """Delete files from current collection. - - Args: - file_ids: List of file IDs to delete - - Raises: - ValueError: If no collection ID is set - requests.RequestException: If deletion request fails - """ - if not self.collection_id: - raise ValueError("No collection ID set") - - endpoint = f"{self.collections_endpoint}/{self.collection_id}/files" - req_body = {"file_ids": file_ids} - resp = self.session.delete(endpoint, json=req_body) - - if resp.status_code >= 400: - raise requests.RequestException(f"File deletion failed: {resp.status_code} - {resp.text}") - else: - f"File deletion successful" - - def add_file_from_url( - self, - file_url: str, - overwrite: bool = False - ) -> None: - """Add files to collection from URL. - - Args: - file_url: URL of the file to add - collection_name: Optional collection name to use instead of ID - overwrite: Whether to overwrite existing files - - Raises: - ValueError: If collection cannot be determined or file exists - """ - if not self.collection_id: - raise ValueError("Collection must be specified by ID or name") - - files_info = {f.name: f.id for f in self.list()} - file_id = files_info.get(file_url) - - if file_id: - if not overwrite: - raise ValueError(f"File {file_url} already exists in collection {self.collection_id}") - self.delete_file([file_id]) - - self.add( - collection_id=self.collection_id, - files=[FileToAdd(name=file_url, url=file_url)] - ) - - self._wait_for_indexing(file_name = file_url) - print(f"Successfully added {file_url} to collection {self.collection_id}") - - def _wait_for_indexing(self, check_interval: int = 5, file_name:str = None) -> None: - """Wait for all files in collection to be indexed. - - Args: - check_interval: Seconds to wait between checks - """ - if file_name is not None: - all_files = [f for f in self.list(self.collection_id) if f.name == file_name] - else: - all_files = self.list(self.collection_id) - while True: - if all(f.status in ("indexed", "error") for f in all_files): - error_files = [f.name for f in all_files if f.status == "error"] - if error_files: - print(f"Files failed to index: {', '.join(error_files)}") - break - sleep(check_interval) \ No newline at end of file diff --git a/needle/v2/models.py b/needle/v2/models.py deleted file mode 100644 index 77702b8..0000000 --- a/needle/v2/models.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -This module contains the data models used in the Needle API client. -""" - -from typing import Any, Optional, Literal -from dataclasses import dataclass, asdict -import json - - -@dataclass(frozen=True) -class NeedleConfig: - """ - Configuration for the Needle API client. - """ - - api_key: Optional[str] - url: Optional[str] - search_url: Optional[str] - - -@dataclass(frozen=True) -class NeedleBaseClient: - """ - Base client for interacting with the Needle API. Not intended to be used directly. - """ - - config: NeedleConfig - headers: dict - - -FileType = Literal["application/pdf"] - - -@dataclass() -class Error(BaseException): - """ - Error response from the Needle API. An object of this class is raised when an API request fails. - """ - - code: int - message: str - data: Optional[Any] = None - - def __str__(self): - return json.dumps(asdict(self), allow_nan=False) - - -@dataclass(frozen=True) -class Collection: - """ - Represents a collection in the Needle API. - A collection is a group of files that can be searched together. - """ - - id: str - name: str - embedding_model: str - embedding_dimensions: str - search_queries: str - created_at: str - updated_at: str - - -@dataclass(frozen=True) -class FileToAdd: - """ - Represents file metadata, used when adding new files to a collection in the Needle API. - """ - - name: str - url: str - - -CollectionFileStatus = Literal["pending", "indexed", "error"] - - -@dataclass(frozen=True) -class CollectionFile: - """ - Represents a file in the Needle API. Note that a file can be part of multiple collections. - """ - - id: str - name: str - type: FileType - url: str - user_id: str - connector_id: str - size: int - md5_hash: str - created_at: str - updated_at: str - status: CollectionFileStatus - - -@dataclass(frozen=True) -class CollectionDataStats: - """ - Represents data statistics of a collection in the Needle API. - """ - - status: Optional[str] - files: int - bytes: int - - -@dataclass(frozen=True) -class CollectionStats: - """ - Represents statistics of a collection in the Needle API. - """ - - data_stats: list[CollectionDataStats] - chunks_count: int - characters: int - users: int - - -@dataclass(frozen=True) -class SearchResult: - """ - Represents a search result from the Needle API. - """ - - content: str - file_id: str