Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions needle/v1/collections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

from typing import Optional

import requests

from needle.v1.models import (
Expand Down Expand Up @@ -73,7 +72,7 @@ def create(self, name: str, file_ids: Optional[list[str]] = None):
created_at=c.get("created_at"),
updated_at=c.get("updated_at"),
)

def get(self, collection_id: str):
"""
Retrieves a collection by its ID.
Expand Down Expand Up @@ -114,6 +113,7 @@ def list(self):
Error: If the API request fails.
"""
resp = self.session.get(self.endpoint)
print(resp.json())
body = resp.json()
if resp.status_code >= 400:
error = body.get("error")
Expand All @@ -131,6 +131,25 @@ def list(self):
for c in body.get("result")
]

def identify_collection(self, collection_name: str) -> str:
"""Get collection ID from name.

Args:
collection_name: Name of the collection

Returns:
Collection ID

Raises:
ValueError: If collection name is not found
"""
collections_list = self.list()
collection_info = {collection.name: collection.id for collection in collections_list}

if collection_name not in collection_info:
raise ValueError(f"Collection '{collection_name}' not found")
return collection_info[collection_name]

def search(
self,
collection_id: str,
Expand Down
32 changes: 30 additions & 2 deletions needle/v1/collections/files.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""
This module provides NeedleCollectionsFiles class for interacting with
Needle API's collectiton files endpoint.
Needle API's collection files endpoint.
"""

from dataclasses import asdict
from typing import List, Optional
import requests

from needle.v1 import NeedleBaseClient, NeedleConfig
from needle.v1.models import FileToAdd, Error, CollectionFile
from time import sleep


class NeedleCollectionsFiles(NeedleBaseClient):
Expand Down Expand Up @@ -47,6 +49,7 @@ def add(self, collection_id: str, files: list[FileToAdd]):
req_body = {"files": [asdict(f) for f in files]}
resp = self.session.post(endpoint, json=req_body)
body = resp.json()

if resp.status_code >= 400:
error = body.get("error")
raise Error(**error)
Expand Down Expand Up @@ -80,9 +83,12 @@ def list(self, collection_id: str):
Raises:
Error: If the API request fails.
"""
endpoint = f"{self.collections_endpoint}/{collection_id}/files"
if collection_id is not None:
self.collection_id = collection_id
endpoint = f"{self.collections_endpoint}/{self.collection_id}/files"
resp = self.session.get(endpoint)
body = resp.json()

if resp.status_code >= 400:
error = body.get("error")
raise Error(**error)
Expand All @@ -102,3 +108,25 @@ def list(self, collection_id: str):
)
for cf in body.get("result")
]

def delete(self, collection_id: str, file_ids: List[str]) -> None:
"""Delete files from current collection.

Args:
collection_id (str): The ID of the collection from which files will be deleted.
file_ids (List[str]): List of file IDs to delete.

Raises:
ValueError: If no collection ID is set.
Error: If the API request fails.
"""
if not collection_id:
raise ValueError("No collection ID set")

endpoint = f"{self.collections_endpoint}/{collection_id}/files"
req_body = {"file_ids": file_ids}
resp = self.session.delete(endpoint, json=req_body)

if resp.status_code >= 400:
error = resp.json().get("error")
raise Error(**error)