From cb516fe7e6cc348eaf6a821f4f7d9e051aa08525 Mon Sep 17 00:00:00 2001 From: meet Date: Wed, 3 Jul 2024 13:50:00 +0530 Subject: [PATCH 1/2] Created abstract class ensuring consistency for future clients --- pyepsilla/abstract_class/client.py | 19 ++++++++++++++ pyepsilla/abstract_class/vector_db.py | 36 +++++++++++++++++++++++++++ pyepsilla/cloud/client.py | 7 ++++-- pyepsilla/enterprise/client.py | 7 ++++-- 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 pyepsilla/abstract_class/client.py create mode 100644 pyepsilla/abstract_class/vector_db.py diff --git a/pyepsilla/abstract_class/client.py b/pyepsilla/abstract_class/client.py new file mode 100644 index 0000000..e51ca3b --- /dev/null +++ b/pyepsilla/abstract_class/client.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + + +class AbstractClient(ABC): + @abstractmethod + def __init__(self, project_id: str, api_key: str, headers: dict = None): + pass + + @abstractmethod + def get_db_list(self): + pass + + @abstractmethod + def get_db_info(self, db_id: str): + pass + + @abstractmethod + def vectordb(self, db_id: str): + pass diff --git a/pyepsilla/abstract_class/vector_db.py b/pyepsilla/abstract_class/vector_db.py new file mode 100644 index 0000000..13b71c1 --- /dev/null +++ b/pyepsilla/abstract_class/vector_db.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from typing import Optional, Union + + +class AbstractVectordb(ABC): + @abstractmethod + def __init__(self, project_id: str, db_id: str, api_key: str, public_endpoint: str, headers: dict = None): + pass + + @abstractmethod + def list_tables(self): + pass + + @abstractmethod + def create_table(self, table_name: str, table_fields: list[dict] = None, indices: list[dict] = None): + pass + + @abstractmethod + def drop_table(self, table_name: str): + pass + + @abstractmethod + def insert(self, table_name: str, records: list[dict]): + pass + + @abstractmethod + def upsert(self, table_name: str, records: list[dict]): + pass + + @abstractmethod + def query(self, table_name: str, query_text: str = None, query_index: str = None, query_field: str = None, query_vector: Union[list, dict] = None, response_fields: Optional[list] = None, limit: int = 2, filter: Optional[str] = None, with_distance: Optional[bool] = False, facets: Optional[list[dict]] = None): + pass + + @abstractmethod + def delete(self, table_name: str, primary_keys: Optional[list[Union[str, int]]] = None, ids: Optional[list[Union[str, int]]] = None, filter: Optional[str] = None): + pass diff --git a/pyepsilla/cloud/client.py b/pyepsilla/cloud/client.py index 376c1ce..1bce77c 100644 --- a/pyepsilla/cloud/client.py +++ b/pyepsilla/cloud/client.py @@ -12,12 +12,15 @@ import sentry_sdk from pydantic import BaseModel, Field, constr +from ..abstract_class.vector_db import AbstractVectordb +from ..abstract_class.client import AbstractClient + from ..utils.search_engine import SearchEngine requests.packages.urllib3.disable_warnings() -class Client(object): +class Client(AbstractClient): def __init__(self, project_id: str, api_key: str, headers: dict = None): self._project_id = project_id self._apikey = api_key @@ -104,7 +107,7 @@ def vectordb(self, db_id: str): raise Exception("Failed to get db info") -class Vectordb(Client): +class Vectordb(Client, AbstractVectordb): def __init__( self, project_id: str, diff --git a/pyepsilla/enterprise/client.py b/pyepsilla/enterprise/client.py index 222cb76..621b3f4 100644 --- a/pyepsilla/enterprise/client.py +++ b/pyepsilla/enterprise/client.py @@ -12,6 +12,9 @@ 
import sentry_sdk from pydantic import BaseModel, Field, constr +from ..abstract_class.vector_db import AbstractVectordb +from ..abstract_class.client import AbstractClient + from ..utils.search_engine import SearchEngine requests.packages.urllib3.disable_warnings() # type: ignore @@ -26,7 +29,7 @@ class DbModel(BaseModel): project_id: Optional[str] = "default" -class Client(cloud.Client): +class Client(cloud.Client, AbstractClient): def __init__( self, base_url: str, project_id: Optional[str] = "default", headers: dict = None ): @@ -166,7 +169,7 @@ def drop_db(self, db_id: str): return status_code, body -class Vectordb(object): +class Vectordb(AbstractVectordb): def __init__(self, project_url: str, db_id: str, header: dict): self._db_id = db_id self._baseurl = "{}/vectordb/{}".format(project_url, db_id) From 09d4a60916e1c606c3261336589e32e04b43c30a Mon Sep 17 00:00:00 2001 From: meet Date: Wed, 3 Jul 2024 20:06:47 +0530 Subject: [PATCH 2/2] Centralize API query logic to avoid code duplication --- pyepsilla/cloud/client.py | 86 +++++---------------------- pyepsilla/enterprise/client.py | 104 ++++++-------------------------- pyepsilla/utils/rest_api.py | 42 +++++++++++++ pyepsilla/vectordb/client.py | 105 +++++---------------------------- 4 files changed, 92 insertions(+), 245 deletions(-) create mode 100644 pyepsilla/utils/rest_api.py diff --git a/pyepsilla/cloud/client.py b/pyepsilla/cloud/client.py index 1bce77c..a6318ac 100644 --- a/pyepsilla/cloud/client.py +++ b/pyepsilla/cloud/client.py @@ -15,6 +15,7 @@ from ..abstract_class.vector_db import AbstractVectordb from ..abstract_class.client import AbstractClient +from ..utils.rest_api import get_call, post_call, delete_call from ..utils.search_engine import SearchEngine requests.packages.urllib3.disable_warnings() @@ -37,49 +38,32 @@ def __init__(self, project_id: str, api_key: str, headers: dict = None): self._header.update(headers) def validate(self): - resp = requests.get( + _, data = get_call( url=self._baseurl + "/vectordb/list", data=None, headers=self._header, verify=False, ) - data = resp.json() - resp.close() - del resp return data def get_db_list(self): db_list = [] req_url = "{}/vectordb/list".format(self._baseurl) - resp = requests.get(url=req_url, data=None, headers=self._header, verify=False) - status_code = resp.status_code - body = resp.json() - if status_code == 200 and body["statusCode"] == 200: + status_code, body = get_call(url=req_url, data=None, headers=self._header, verify=False) + if status_code == requests.ok and body["statusCode"] == requests.ok: db_list = [db_id for db_id in body["result"]] - resp.close() - del resp return db_list def get_db_info(self, db_id: str): req_url = "{}/vectordb/{}".format(self._baseurl, db_id) - resp = requests.get(url=req_url, data=None, headers=self._header, verify=False) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body + return get_call(url=req_url, data=None, headers=self._header, verify=False) def get_db_statistics(self, db_id: str): req_url = "{}/vectordb/{}/statistics".format(self._baseurl, db_id) req_data = None - resp = requests.get( + return get_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def vectordb(self, db_id: str): # validate project_id and api_key @@ -97,7 +81,7 @@ def vectordb(self, db_id: str): # fetch db public endpoint status_code, resp = 
self.get_db_info(db_id=db_id) - if resp["statusCode"] == 200: + if resp["statusCode"] == requests.ok: return Vectordb( self._project_id, db_id, self._apikey, resp["result"]["public_endpoint"] ) @@ -132,12 +116,7 @@ def list_tables(self): if self._db_id is None: raise Exception("[ERROR] db_id is None!") req_url = "{}/table/list".format(self._baseurl) - resp = requests.get(url=req_url, headers=self._header, verify=False) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body + return get_call(url=req_url, headers=self._header, verify=False) # Create table def create_table( @@ -154,14 +133,9 @@ def create_table( req_data = {"name": table_name, "fields": table_fields} if indices is not None: req_data["indices"] = indices - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Drop table def drop_table(self, table_name: str): @@ -169,39 +143,24 @@ def drop_table(self, table_name: str): raise Exception("[ERROR] db_id is None!") req_url = "{}/table/delete?table_name={}".format(self._baseurl, table_name) req_data = {} - resp = requests.delete( + return delete_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Insert data into table def insert(self, table_name: str, records: list[dict]): req_url = "{}/data/insert".format(self._baseurl) req_data = {"table": table_name, "data": records} - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def upsert(self, table_name: str, records: list[dict]): req_url = "{}/data/insert".format(self._baseurl) req_data = {"table": table_name, "data": records, "upsert": True} - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Query data from table def query( @@ -248,14 +207,9 @@ def query( else: req_data["facets"] = facets - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Delete data from table def delete( @@ -289,14 +243,9 @@ def delete( if filter is not None: req_data["filter"] = filter - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Get data from table def get( @@ -346,14 +295,9 @@ def get( req_data["facets"] = facets req_url = "{}/data/get".format(self._baseurl) - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def as_search_engine(self): return SearchEngine(self) diff --git a/pyepsilla/enterprise/client.py b/pyepsilla/enterprise/client.py index 621b3f4..7c2068a 100644 --- a/pyepsilla/enterprise/client.py +++ 
b/pyepsilla/enterprise/client.py @@ -15,6 +15,7 @@ from ..abstract_class.vector_db import AbstractVectordb from ..abstract_class.client import AbstractClient +from ..utils.rest_api import get_call, post_call, delete_call from ..utils.search_engine import SearchEngine requests.packages.urllib3.disable_warnings() # type: ignore @@ -52,24 +53,15 @@ def hello(self): def get_db_list(self): db_list = [] req_url = "{}/vectordb/list".format(self._baseurl) - resp = requests.get(url=req_url, data=None, headers=self._header, verify=False) - status_code = resp.status_code - body = resp.json() - if status_code == 200 and body["statusCode"] == 200: - db_list = resp.json()["result"]["uuids"] - resp.close() - del resp + status_code, body = get_call(url=req_url, data=None, headers=self._header, verify=False) + if status_code == requests.ok and body["statusCode"] == requests.ok: + db_list = body.get("result", {}).get("uuids", []) return db_list # Get DB Information by db_id def get_db_info(self, db_id: str): req_url = "{}/vectordb/{}".format(self._baseurl, db_id) - resp = requests.get(url=req_url, data=None, headers=self._header, verify=False) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body + return get_call(url=req_url, data=None, headers=self._header, verify=False) # Connect to DB def vectordb(self, db_id: str): @@ -77,8 +69,8 @@ def vectordb(self, db_id: str): if db_id not in self.get_db_list(): raise Exception("Invalid db_id") - status_code, resp = self.get_db_info(db_id=db_id) - if resp["statusCode"] == 200: + _, resp = self.get_db_info(db_id=db_id) + if resp["statusCode"] == requests.ok: return Vectordb(self._baseurl, db_id, self._header) else: print(resp) @@ -109,64 +101,45 @@ def create_db( "sharding_capacity": sharding_capacity, "sharding_increase_threshold": sharding_increase_threshold, } - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False, ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Load DB def load_db(self, db_id: str): req_url = "{}/vectordb/{}/load".format(self._baseurl, db_id) req_data = {} - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False, ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Unload DB def unload_db(self, db_id: str): req_url = "{}/vectordb/{}/unload".format(self._baseurl, db_id) req_data = {} - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False, ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Delete DB def drop_db(self, db_id: str): req_url = "{}/vectordb/{}".format(self._baseurl, db_id) req_data = {} - resp = requests.delete( + return delete_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False, ) - status_code = resp.status_code - body = resp.json() - resp.close() - return status_code, body class Vectordb(AbstractVectordb): @@ -180,11 +153,7 @@ def list_tables(self): if self._db_id is None: raise Exception("[ERROR] db_id is None!") req_url = "{}/table/list".format(self._baseurl) - resp = requests.get(url=req_url, headers=self._header, verify=False) - status_code = resp.status_code - body = resp.json() - resp.close() - return status_code, body + return get_call(url=req_url, headers=self._header, 
verify=False) # Create table def create_table( @@ -201,14 +170,9 @@ def create_table( req_data = {"name": table_name, "fields": table_fields} if indices is not None: req_data["indices"] = indices - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Drop table def drop_table(self, table_name: str): @@ -216,14 +180,9 @@ def drop_table(self, table_name: str): raise Exception("[ERROR] db_id is None!") req_url = "{}/table/delete?table_name={}".format(self._baseurl, table_name) req_data = {} - resp = requests.delete( + return delete_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Insert data into table def insert(self, table_name: str, records: list[dict]): @@ -233,14 +192,9 @@ def insert(self, table_name: str, records: list[dict]): records = [] req_url = "{}/data/insert".format(self._baseurl) req_data = {"table": table_name, "data": records} - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def upsert(self, table_name: str, records: list[dict]): if self._db_id is None: @@ -249,14 +203,9 @@ def upsert(self, table_name: str, records: list[dict]): records = [] req_url = "{}/data/insert".format(self._baseurl) req_data = {"table": table_name, "data": records, "upsert": True} - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Query data from table def query( @@ -303,14 +252,9 @@ def query( else: req_data["facets"] = facets - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body # Delete data from table def delete( @@ -344,14 +288,9 @@ def delete( if filter is not None: req_data["filter"] = filter - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body ## get data from table def get( @@ -401,14 +340,9 @@ def get( req_data["facets"] = facets req_url = "{}/data/get".format(self._baseurl) - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def as_search_engine(self): return SearchEngine(self) diff --git a/pyepsilla/utils/rest_api.py b/pyepsilla/utils/rest_api.py new file mode 100644 index 0000000..8ee8f3f --- /dev/null +++ b/pyepsilla/utils/rest_api.py @@ -0,0 +1,42 @@ +import requests + + +def get_call(url, headers, data, timeout=None, verify=False): + resp = requests.get( + url=url, + data=data, + headers=headers, + timeout=timeout, + verify=verify, + ) + body = None + status_code = resp.status_code + if status_code == requests.ok: + body = resp.json() + resp.close() + del resp + return status_code, body + + +def post_call(url, 
headers, data, verify=False): + resp = requests.post( + url=url, data=data, headers=headers, verify=verify) + status_code = resp.status_code + body = None + if status_code == requests.ok: + body = resp.json() + resp.close() + del resp + return status_code, body + + +def delete_call(url, headers, data, verify=False): + resp = requests.delete( + url=url, data=data, headers=headers, verify=verify) + status_code = resp.status_code + body = None + if status_code == requests.ok: + body = resp.json() + resp.close() + del resp + return status_code, body diff --git a/pyepsilla/vectordb/client.py b/pyepsilla/vectordb/client.py index e5761a6..edfeed9 100644 --- a/pyepsilla/vectordb/client.py +++ b/pyepsilla/vectordb/client.py @@ -12,6 +12,7 @@ import sentry_sdk from requests.packages.urllib3.exceptions import InsecureRequestWarning +from ..utils.rest_api import get_call, post_call, delete_call from ..utils.search_engine import SearchEngine requests.packages.urllib3.disable_warnings(InsecureRequestWarning) @@ -53,30 +54,20 @@ def check_networking(self): def welcome(self): req_url = "{}/".format(self._baseurl) req_data = {} - resp = requests.get( + return get_call( url=req_url, data=json.dumps(req_data), headers=self._header, timeout=self._timeout, verify=False, ) - status_code = resp.status_code - body = resp.text - resp.close() - del resp - return status_code, body def state(self): req_url = "{}/state".format(self._baseurl) req_data = {} - resp = requests.get( + return get_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def use_db(self, db_name: str): self._db = db_name @@ -94,40 +85,25 @@ def load_db( req_data["vectorScale"] = vector_scale if wal_enabled is not None: req_data["walEnabled"] = wal_enabled - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def unload_db(self, db_name: str): req_url = "{}/api/{}/unload".format(self._baseurl, db_name) req_data = {} - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def statistics(self): if self._db is None: raise Exception("[ERROR] Please use_db() first!") req_url = "{}/api/{}/statistics".format(self._baseurl, self._db) req_data = {} - resp = requests.get( + return get_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def create_table( self, @@ -143,25 +119,15 @@ def create_table( req_data = {"name": table_name, "fields": table_fields} if indices is not None: req_data["indices"] = indices - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def list_tables(self): if self._db is None: raise Exception("[ERROR] Please use_db() first!") req_url = "{}/api/{}/schema/tables/show".format(self._baseurl, self._db) - resp = requests.get(url=req_url, headers=self._header, verify=False) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - 
return status_code, body + return get_call(url=req_url, headers=self._header, verify=False) def insert(self, table_name: str, records: list = None): if self._db is None: @@ -170,14 +136,9 @@ def insert(self, table_name: str, records: list = None): records = [] req_url = "{}/api/{}/data/insert".format(self._baseurl, self._db) req_data = {"table": table_name, "data": records} - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def upsert(self, table_name: str, records: list = None): if self._db is None: @@ -186,14 +147,9 @@ def upsert(self, table_name: str, records: list = None): records = [] req_url = "{}/api/{}/data/insert".format(self._baseurl, self._db) req_data = {"table": table_name, "data": records, "upsert": True} - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def delete( self, @@ -228,21 +184,16 @@ def delete( req_data["primaryKeys"] = primary_keys if filter != None: req_data["filter"] = filter - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def rebuild(self, timeout: int = 7200): req_url = "{}/api/rebuild".format(self._baseurl) req_data = {} print("[INFO] waiting until rebuild is finished ...") start_time = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") - resp = requests.post( + status_code, body = post_call( url=req_url, data=json.dumps(req_data), headers=self._header, @@ -251,10 +202,6 @@ def rebuild(self, timeout: int = 7200): ) end_time = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") print("[INFO] Start Time:{}\n End Time:{}".format(start_time, end_time)) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp return status_code, body def query( @@ -303,14 +250,9 @@ def query( else: req_data["facets"] = facets - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def get( self, @@ -361,14 +303,9 @@ def get( req_data["facets"] = facets req_url = "{}/api/{}/data/get".format(self._baseurl, self._db) - resp = requests.post( + return post_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def drop_table(self, table_name: str = None): if self._db is None: @@ -377,26 +314,16 @@ def drop_table(self, table_name: str = None): self._baseurl, self._db, table_name ) req_data = {} - resp = requests.delete( + return delete_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - return status_code, body def drop_db(self, db_name: str): req_url = "{}/api/{}/drop".format(self._baseurl, db_name) req_data = {} - resp = requests.delete( + return delete_call( url=req_url, data=json.dumps(req_data), headers=self._header, verify=False ) - status_code = resp.status_code - body = resp.json() - resp.close() - del resp - 
return status_code, body def as_search_engine(self): return SearchEngine(self)
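
For patch 1/2, the consistency guarantee comes from Python's abc machinery: a future client that inherits AbstractClient but skips any @abstractmethod cannot be instantiated at all. A minimal, hypothetical sketch of that check (IncompleteClient is not part of this patch; only the AbstractClient import matches the code above):

from pyepsilla.abstract_class.client import AbstractClient


class IncompleteClient(AbstractClient):
    """Hypothetical client that implements only part of the interface."""

    def __init__(self, project_id: str, api_key: str, headers: dict = None):
        self._project_id = project_id
        self._apikey = api_key

    def get_db_list(self):
        return []

    # get_db_info() and vectordb() are deliberately left unimplemented.


try:
    IncompleteClient("my_project", "my_api_key")
except TypeError as err:
    # TypeError: Can't instantiate abstract class IncompleteClient with
    # abstract methods get_db_info, vectordb
    print(err)
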
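For patch 2/2, the centralization means each client method reduces to a single return get_call(...) / post_call(...) / delete_call(...), and the (status_code, body) pair is produced in exactly one place. Below is a minimal sketch of pyepsilla/utils/rest_api.py under a few assumptions that go beyond the patch as posted: the status comparison uses requests.codes.ok (the requests package has no top-level requests.ok attribute; ok is a property of Response objects), data defaults to None so call sites such as list_tables() that pass only url/headers/verify still work, post_call accepts a timeout keyword because the rebuild() call site appears to forward one, and the shared tail is factored into a hypothetical _finish() helper.

import requests


def _finish(resp):
    """Turn a Response into (status_code, body) and release the connection."""
    status_code = resp.status_code
    # Assumption: only parse JSON on a 200; callers receive body=None otherwise.
    body = resp.json() if status_code == requests.codes.ok else None
    resp.close()
    return status_code, body


def get_call(url, headers, data=None, timeout=None, verify=False):
    return _finish(
        requests.get(url=url, data=data, headers=headers, timeout=timeout, verify=verify)
    )


def post_call(url, headers, data=None, timeout=None, verify=False):
    return _finish(
        requests.post(url=url, data=data, headers=headers, timeout=timeout, verify=verify)
    )


def delete_call(url, headers, data=None, verify=False):
    return _finish(
        requests.delete(url=url, data=data, headers=headers, verify=verify)
    )

One behavioral consequence of this convention: the replaced methods always called resp.json(), so error responses still surfaced the server's JSON body; with body set to None on non-200 responses, callers that inspect error payloads need to handle the None case.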