diff --git a/needle/v1/collections/__init__.py b/needle/v1/collections/__init__.py index 3f29ee6..8f25f63 100644 --- a/needle/v1/collections/__init__.py +++ b/needle/v1/collections/__init__.py @@ -3,7 +3,6 @@ """ from typing import Optional - import requests from needle.v1.models import ( @@ -73,7 +72,7 @@ def create(self, name: str, file_ids: Optional[list[str]] = None): created_at=c.get("created_at"), updated_at=c.get("updated_at"), ) - + def get(self, collection_id: str): """ Retrieves a collection by its ID. @@ -114,6 +113,7 @@ def list(self): Error: If the API request fails. """ resp = self.session.get(self.endpoint) + print(resp.json()) body = resp.json() if resp.status_code >= 400: error = body.get("error") @@ -131,6 +131,25 @@ def list(self): for c in body.get("result") ] + def identify_collection(self, collection_name: str) -> str: + """Get collection ID from name. + + Args: + collection_name: Name of the collection + + Returns: + Collection ID + + Raises: + ValueError: If collection name is not found + """ + collections_list = self.list() + collection_info = {collection.name: collection.id for collection in collections_list} + + if collection_name not in collection_info: + raise ValueError(f"Collection '{collection_name}' not found") + return collection_info[collection_name] + def search( self, collection_id: str, diff --git a/needle/v1/collections/files.py b/needle/v1/collections/files.py index 8a9fd5c..a524ecf 100644 --- a/needle/v1/collections/files.py +++ b/needle/v1/collections/files.py @@ -1,13 +1,15 @@ """ This module provides NeedleCollectionsFiles class for interacting with -Needle API's collectiton files endpoint. +Needle API's collection files endpoint. """ from dataclasses import asdict +from typing import List, Optional import requests from needle.v1 import NeedleBaseClient, NeedleConfig from needle.v1.models import FileToAdd, Error, CollectionFile +from time import sleep class NeedleCollectionsFiles(NeedleBaseClient): @@ -47,6 +49,7 @@ def add(self, collection_id: str, files: list[FileToAdd]): req_body = {"files": [asdict(f) for f in files]} resp = self.session.post(endpoint, json=req_body) body = resp.json() + if resp.status_code >= 400: error = body.get("error") raise Error(**error) @@ -80,9 +83,12 @@ def list(self, collection_id: str): Raises: Error: If the API request fails. """ - endpoint = f"{self.collections_endpoint}/{collection_id}/files" + if collection_id is not None: + self.collection_id = collection_id + endpoint = f"{self.collections_endpoint}/{self.collection_id}/files" resp = self.session.get(endpoint) body = resp.json() + if resp.status_code >= 400: error = body.get("error") raise Error(**error) @@ -102,3 +108,25 @@ def list(self, collection_id: str): ) for cf in body.get("result") ] + + def delete(self, collection_id: str, file_ids: List[str]) -> None: + """Delete files from current collection. + + Args: + collection_id (str): The ID of the collection from which files will be deleted. + file_ids (List[str]): List of file IDs to delete. + + Raises: + ValueError: If no collection ID is set. + Error: If the API request fails. + """ + if not collection_id: + raise ValueError("No collection ID set") + + endpoint = f"{self.collections_endpoint}/{collection_id}/files" + req_body = {"file_ids": file_ids} + resp = self.session.delete(endpoint, json=req_body) + + if resp.status_code >= 400: + error = resp.json().get("error") + raise Error(**error)