diff --git a/integration/test_collection_config.py b/integration/test_collection_config.py index 3a7d5b428..f8ab40e3c 100644 --- a/integration/test_collection_config.py +++ b/integration/test_collection_config.py @@ -1655,3 +1655,119 @@ def test_uncompressed_quantitizer(collection_factory: CollectionFactory) -> None assert config.vector_index_config is not None assert isinstance(config.vector_index_config, _VectorIndexConfigHNSW) assert config.vector_index_config.quantizer is None + + +def test_quantitizer_update(collection_factory: CollectionFactory) -> None: + dummy = collection_factory("dummy", ports=(8090, 50061)) + if dummy._connection._weaviate_version.is_lower_than(1, 33, 0): + pytest.skip("uncompressed is not supported in Weaviate versions lower than 1.33.0") + + collection = collection_factory( + vector_config=[ + Configure.Vectors.self_provided( + name="hnsw", + vector_index_config=wvc.config.Configure.VectorIndex.hnsw( + quantizer=Configure.VectorIndex.Quantizer.none() + ), + ), + Configure.Vectors.self_provided( + name="flat", + vector_index_config=wvc.config.Configure.VectorIndex.flat( + quantizer=Configure.VectorIndex.Quantizer.none() + ), + ), + Configure.Vectors.self_provided( + name="dynamic", + vector_index_config=wvc.config.Configure.VectorIndex.dynamic( + flat=Configure.VectorIndex.flat( + quantizer=Configure.VectorIndex.Quantizer.none() + ), + hnsw=Configure.VectorIndex.hnsw( + quantizer=Configure.VectorIndex.Quantizer.none() + ), + ), + ), + ], + ports=(8090, 50061), # async + ) + + collection.config.update_quantizer( + name="hnsw", hnsw_quantizer=Reconfigure.VectorIndex.Quantizer.pq() + ) + + config = collection.config.get() + assert config.vector_config is not None + hnsw_config = config.vector_config["hnsw"].vector_index_config + assert hnsw_config is not None + assert isinstance(hnsw_config, _VectorIndexConfigHNSW) + assert hnsw_config.quantizer is not None + assert isinstance(hnsw_config.quantizer, _PQConfig) + + # other indices are not 
changed + flat_config = config.vector_config["flat"].vector_index_config + assert flat_config is not None + assert isinstance(flat_config, _VectorIndexConfigFlat) + assert flat_config.quantizer is None + + dynamic_config = config.vector_config["dynamic"].vector_index_config + assert dynamic_config is not None + assert isinstance(dynamic_config, _VectorIndexConfigDynamic) + assert dynamic_config.hnsw.quantizer is None + assert dynamic_config.flat.quantizer is None + + # collection.config.update_quantizer( + # name="flat", + # flat_quantizer=Reconfigure.VectorIndex.Quantizer.bq() + # ) + + # collection.config.update( + # vector_config=wvc.config.Reconfigure.Vectors.update( + # name="flat", + # vector_index_config=Reconfigure.VectorIndex.flat(quantizer=Reconfigure.VectorIndex.Quantizer.bq()) + # ) + # ) + + # config = collection.config.get() + # assert config.vector_config is not None + # flat_config = config.vector_config["flat"].vector_index_config + # assert flat_config is not None + # assert isinstance(flat_config, _VectorIndexConfigFlat) + # assert flat_config.quantizer is not None + # assert isinstance(flat_config.quantizer, _BQConfig) + + collection.config.update_quantizer( + name="dynamic", + hnsw_quantizer=Reconfigure.VectorIndex.Quantizer.sq(), + # flat_quantizer=Reconfigure.VectorIndex.Quantizer.pq() + ) + + config = collection.config.get() + assert config.vector_config is not None + dynamic_config = config.vector_config["dynamic"].vector_index_config + assert dynamic_config is not None + assert isinstance(dynamic_config, _VectorIndexConfigDynamic) + assert dynamic_config.hnsw.quantizer is not None + assert isinstance(dynamic_config.hnsw.quantizer, _SQConfig) + assert dynamic_config.flat.quantizer is None + + +def test_quantitizer_update_legacy(collection_factory: CollectionFactory) -> None: + dummy = collection_factory("dummy", ports=(8090, 50061)) + if dummy._connection._weaviate_version.is_lower_than(1, 33, 0): + pytest.skip("uncompressed is not 
supported in Weaviate versions lower than 1.33.0") + + collection = collection_factory( + vectorizer_config=Configure.Vectorizer.none(), + vector_index_config=wvc.config.Configure.VectorIndex.dynamic( + flat=Configure.VectorIndex.flat(quantizer=Configure.VectorIndex.Quantizer.none()), + hnsw=Configure.VectorIndex.hnsw(quantizer=Configure.VectorIndex.Quantizer.none()), + ), + ports=(8090, 50061), # async + ) + + collection.config.update_quantizer(hnsw_quantizer=Reconfigure.VectorIndex.Quantizer.pq()) + config = collection.config.get() + assert config.vector_index_config is not None + assert isinstance(config.vector_index_config, _VectorIndexConfigDynamic) + assert config.vector_index_config.hnsw.quantizer is not None + assert isinstance(config.vector_index_config.hnsw.quantizer, _PQConfig) diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index c30c9967e..c4ae85667 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -1128,6 +1128,29 @@ def mutual_exclusivity( ) return v + def __check_vectorizer_quantizer( + self, vectorIndexConfig: _VectorIndexConfigUpdate, vector_index_config: dict + ) -> None: + if isinstance(vectorIndexConfig, _VectorIndexConfigHNSWUpdate) or isinstance( + vectorIndexConfig, _VectorIndexConfigFlatUpdate + ): + self.__check_quantizers( + vectorIndexConfig.quantizer, + vector_index_config, + ) + else: + assert isinstance(vectorIndexConfig, _VectorIndexConfigDynamicUpdate) + if vectorIndexConfig.hnsw is not None: + self.__check_quantizers( + vectorIndexConfig.hnsw.quantizer, + vector_index_config, + ) + if vectorIndexConfig.flat is not None: + self.__check_quantizers( + vectorIndexConfig.flat.quantizer, + vector_index_config, + ) + def __check_quantizers( self, quantizer: Optional[_QuantizerConfigUpdate], @@ -1145,7 +1168,7 @@ def __check_quantizers( or ( isinstance(quantizer, _BQConfigUpdate) and ( - vector_index_config["pq"]["enabled"] + 
vector_index_config.get("pq", {"enabled": False})["enabled"] or vector_index_config.get("sq", {"enabled": False})["enabled"] or vector_index_config.get("rq", {"enabled": False})["enabled"] ) @@ -1153,7 +1176,7 @@ def __check_quantizers( or ( isinstance(quantizer, _SQConfigUpdate) and ( - vector_index_config["pq"]["enabled"] + vector_index_config.get("pq", {"enabled": False})["enabled"] or vector_index_config.get("bq", {"enabled": False})["enabled"] or vector_index_config.get("rq", {"enabled": False})["enabled"] ) @@ -1161,7 +1184,7 @@ def __check_quantizers( or ( isinstance(quantizer, _RQConfigUpdate) and ( - vector_index_config["pq"]["enabled"] + vector_index_config.get("pq", {"enabled": False})["enabled"] or vector_index_config.get("bq", {"enabled": False})["enabled"] or vector_index_config.get("sq", {"enabled": False})["enabled"] ) @@ -1200,7 +1223,7 @@ def merge_with_existing(self, schema: Dict[str, Any]) -> Dict[str, Any]: schema["multiTenancyConfig"] ) if self.vectorIndexConfig is not None: - self.__check_quantizers(self.vectorIndexConfig.quantizer, schema["vectorIndexConfig"]) + self.__check_vectorizer_quantizer(self.vectorIndexConfig, schema["vectorIndexConfig"]) schema["vectorIndexConfig"] = self.vectorIndexConfig.merge_with_existing( schema["vectorIndexConfig"] ) @@ -1228,9 +1251,10 @@ def merge_with_existing(self, schema: Dict[str, Any]) -> Dict[str, Any]: ) if self.vectorizerConfig is not None: if isinstance(self.vectorizerConfig, _VectorIndexConfigUpdate): - self.__check_quantizers( - self.vectorizerConfig.quantizer, schema["vectorIndexConfig"] + self.__check_vectorizer_quantizer( + self.vectorizerConfig, schema["vectorIndexConfig"] ) + schema["vectorIndexConfig"] = self.vectorizerConfig.merge_with_existing( schema["vectorIndexConfig"] ) @@ -1240,10 +1264,10 @@ def merge_with_existing(self, schema: Dict[str, Any]) -> Dict[str, Any]: raise WeaviateInvalidInputError( f"Vector config with name {vc.name} does not exist in the existing vector config" ) - 
self.__check_quantizers( - vc.vectorIndexConfig.quantizer, - schema["vectorConfig"][vc.name]["vectorIndexConfig"], + self.__check_vectorizer_quantizer( + vc.vectorIndexConfig, schema["vectorConfig"][vc.name]["vectorIndexConfig"] ) + schema["vectorConfig"][vc.name]["vectorIndexConfig"] = ( vc.vectorIndexConfig.merge_with_existing( schema["vectorConfig"][vc.name]["vectorIndexConfig"] @@ -1263,10 +1287,10 @@ def merge_with_existing(self, schema: Dict[str, Any]) -> Dict[str, Any]: raise WeaviateInvalidInputError( f"Vector config with name {vc.name} does not exist in the existing vector config" ) - self.__check_quantizers( - vc.vectorIndexConfig.quantizer, - schema["vectorConfig"][vc.name]["vectorIndexConfig"], + self.__check_vectorizer_quantizer( + vc.vectorIndexConfig, schema["vectorConfig"][vc.name]["vectorIndexConfig"] ) + schema["vectorConfig"][vc.name]["vectorIndexConfig"] = ( vc.vectorIndexConfig.merge_with_existing( schema["vectorConfig"][vc.name]["vectorIndexConfig"] @@ -1586,8 +1610,8 @@ def vector_index_type() -> str: @dataclass class _VectorIndexConfigDynamic(_ConfigBase): distance_metric: VectorDistances - hnsw: Optional[VectorIndexConfigHNSW] - flat: Optional[VectorIndexConfigFlat] + hnsw: VectorIndexConfigHNSW + flat: VectorIndexConfigFlat threshold: Optional[int] @staticmethod @@ -2292,7 +2316,6 @@ def dynamic( threshold: Optional[int] = None, hnsw: Optional[_VectorIndexConfigHNSWUpdate] = None, flat: Optional[_VectorIndexConfigFlatUpdate] = None, - quantizer: Optional[_BQConfigUpdate] = None, ) -> _VectorIndexConfigDynamicUpdate: """Create an `_VectorIndexConfigDynamicUpdate` object to update the configuration of the Dynamic vector index. 
@@ -2305,7 +2328,6 @@ def dynamic( threshold=threshold, hnsw=hnsw, flat=flat, - quantizer=quantizer, ) diff --git a/weaviate/collections/classes/config_base.py b/weaviate/collections/classes/config_base.py index 2ad42bad6..cd0d49594 100644 --- a/weaviate/collections/classes/config_base.py +++ b/weaviate/collections/classes/config_base.py @@ -32,7 +32,12 @@ def merge_with_existing(self, schema: Dict[str, Any]) -> Dict[str, Any]: schema[cls_field] = val elif isinstance(val, _QuantizerConfigUpdate): quantizers = ["pq", "bq", "sq"] - schema[val.quantizer_name()] = val.merge_with_existing(schema[val.quantizer_name()]) + if val.quantizer_name() in schema: + schema[val.quantizer_name()] = val.merge_with_existing( + schema[val.quantizer_name()] + ) + else: + schema[val.quantizer_name()] = val.merge_with_existing({}) for quantizer in quantizers: if quantizer == val.quantizer_name() or quantizer not in schema: continue diff --git a/weaviate/collections/classes/config_vector_index.py b/weaviate/collections/classes/config_vector_index.py index d514bfef7..6c90eded5 100644 --- a/weaviate/collections/classes/config_vector_index.py +++ b/weaviate/collections/classes/config_vector_index.py @@ -95,8 +95,6 @@ def _to_dict(self) -> Dict[str, Any]: class _VectorIndexConfigUpdate(_ConfigUpdateModel): - quantizer: Optional[_QuantizerConfigUpdate] = Field(exclude=True) - @staticmethod @abstractmethod def vector_index_type() -> VectorIndexType: ... 
@@ -136,13 +134,14 @@ def vector_index_type() -> VectorIndexType: class _VectorIndexConfigHNSWUpdate(_VectorIndexConfigUpdate): - dynamicEfMin: Optional[int] - dynamicEfMax: Optional[int] - dynamicEfFactor: Optional[int] - ef: Optional[int] - filterStrategy: Optional[VectorFilterStrategy] - flatSearchCutoff: Optional[int] - vectorCacheMaxObjects: Optional[int] + quantizer: Optional[_QuantizerConfigUpdate] = Field(exclude=True) + dynamicEfMin: Optional[int] = None + dynamicEfMax: Optional[int] = None + dynamicEfFactor: Optional[int] = None + ef: Optional[int] = None + filterStrategy: Optional[VectorFilterStrategy] = None + flatSearchCutoff: Optional[int] = None + vectorCacheMaxObjects: Optional[int] = None @staticmethod def vector_index_type() -> VectorIndexType: @@ -150,7 +149,8 @@ def vector_index_type() -> VectorIndexType: class _VectorIndexConfigFlatUpdate(_VectorIndexConfigUpdate): - vectorCacheMaxObjects: Optional[int] + vectorCacheMaxObjects: Optional[int] = None + quantizer: Optional[_QuantizerConfigUpdate] = Field(exclude=True) @staticmethod def vector_index_type() -> VectorIndexType: @@ -179,9 +179,9 @@ def _to_dict(self) -> dict: class _VectorIndexConfigDynamicUpdate(_VectorIndexConfigUpdate): - threshold: Optional[int] - hnsw: Optional[_VectorIndexConfigHNSWUpdate] - flat: Optional[_VectorIndexConfigFlatUpdate] + threshold: Optional[int] = None + hnsw: Optional[_VectorIndexConfigHNSWUpdate] = None + flat: Optional[_VectorIndexConfigFlatUpdate] = None @staticmethod def vector_index_type() -> VectorIndexType: diff --git a/weaviate/collections/config/async_.pyi b/weaviate/collections/config/async_.pyi index 9fcfefdb3..4aff531dc 100644 --- a/weaviate/collections/config/async_.pyi +++ b/weaviate/collections/config/async_.pyi @@ -22,6 +22,7 @@ from weaviate.collections.classes.config import ( _VectorIndexConfigFlatUpdate, _VectorIndexConfigHNSWUpdate, ) +from weaviate.collections.classes.config_base import _QuantizerConfigUpdate from 
weaviate.collections.classes.config_vector_index import _VectorIndexConfigDynamicUpdate from weaviate.connect.v4 import ConnectionAsync @@ -80,3 +81,6 @@ class _ConfigCollectionAsync(_ConfigCollectionExecutor[ConnectionAsync]): async def add_vector( self, *, vector_config: Union[_VectorConfigCreate, List[_VectorConfigCreate]] ) -> None: ... + async def update_quantizer( + self, *, name: Optional[str] = None, hnsw_quantizer: Optional[_QuantizerConfigUpdate] = None + ) -> None: ... diff --git a/weaviate/collections/config/executor.py b/weaviate/collections/config/executor.py index e5772b76a..0b5dc4474 100644 --- a/weaviate/collections/config/executor.py +++ b/weaviate/collections/config/executor.py @@ -36,9 +36,13 @@ _ShardStatus, _VectorConfigCreate, _VectorConfigUpdate, + _VectorIndexConfigDynamic, + _VectorIndexConfigFlat, _VectorIndexConfigFlatUpdate, + _VectorIndexConfigHNSW, _VectorIndexConfigHNSWUpdate, ) +from weaviate.collections.classes.config_base import _QuantizerConfigUpdate from weaviate.collections.classes.config_methods import ( _collection_config_from_json, _collection_config_simple_from_json, @@ -577,3 +581,130 @@ async def _execute() -> None: return _execute() schema = executor.result(self.__get()) return executor.result(resp(schema)) + + def update_quantizer( + self, + *, + name: Optional[str] = None, + hnsw_quantizer: Optional[_QuantizerConfigUpdate] = None, + # flat_quantizer: Optional[_QuantizerConfigUpdate] = None, not yet supported by Weaviate + ) -> executor.Result[None]: + """Update the quantizer configurations. + + Args: + name: Name of the vector whose quantizer settings are updated. Required for collections with named vectors; unused otherwise. + hnsw_quantizer: The HNSW quantizer configuration to update. + flat_quantizer: Not accepted as an argument yet; flat quantizer updates are not yet supported by Weaviate. + + Raises: + weaviate.exceptions.WeaviateConnectionError: If the network connection to Weaviate fails. + weaviate.exceptions.UnexpectedStatusCodeError: If Weaviate reports a non-OK status. 
+ weaviate.exceptions.WeaviateInvalidInputError: If no quantizer is provided, the collection has neither a vector config nor a vector index config, the vector `name` is not found in the collection, or the provided quantizer does not match the vector index type. + """ + flat_quantizer = None + if hnsw_quantizer is None and flat_quantizer is None: + raise WeaviateInvalidInputError("At least one quantizer must be provided.") + + def resp(schema: Dict[str, Any]) -> executor.Result[None]: + config = _collection_config_from_json(schema) + if config.vector_config is None and config.vector_index_config is None: + raise WeaviateInvalidInputError( + "Collection must contain either a vector config or a vector index config." + ) + + if config.vector_config is None: + if isinstance(config.vector_index_config, _VectorIndexConfigHNSW): + if hnsw_quantizer is None: + raise WeaviateInvalidInputError( + "HNSW quantizer must be provided for updating HNSW vector index." + ) + vector_index_config = _VectorIndexConfigHNSWUpdate(quantizer=hnsw_quantizer) + elif isinstance(config.vector_index_config, _VectorIndexConfigFlat): + if flat_quantizer is None: + raise WeaviateInvalidInputError( + "Flat quantizer must be provided for updating flat vector index." 
+ ) + + vector_index_config = _VectorIndexConfigFlatUpdate(quantizer=flat_quantizer) + else: + assert isinstance(config.vector_index_config, _VectorIndexConfigDynamic) + hnsw_update = ( + _VectorIndexConfigHNSWUpdate(quantizer=hnsw_quantizer) + if hnsw_quantizer + else None + ) + flat_update = ( + _VectorIndexConfigFlatUpdate(quantizer=flat_quantizer) + if flat_quantizer + else None + ) + vector_index_config = _VectorIndexConfigDynamicUpdate( + hnsw=hnsw_update, flat=flat_update + ) + + updated_config = _CollectionConfigUpdate( + vector_index_config=vector_index_config, + ) + else: + if name not in config.vector_config: + raise WeaviateInvalidInputError(f"Vector {name} not found in collection.") + + if isinstance( + config.vector_config[name].vector_index_config, _VectorIndexConfigHNSW + ): + if hnsw_quantizer is None: + raise WeaviateInvalidInputError( + "HNSW quantizer must be provided for updating HNSW vector index." + ) + vector_config = _VectorIndexConfigHNSWUpdate(quantizer=hnsw_quantizer) + elif isinstance( + config.vector_config[name].vector_index_config, _VectorIndexConfigFlat + ): + if flat_quantizer is None: + raise WeaviateInvalidInputError( + "Flat quantizer must be provided for updating flat vector index." 
+ ) + + vector_config = _VectorIndexConfigFlatUpdate(quantizer=flat_quantizer) + else: + assert isinstance( + config.vector_config[name].vector_index_config, _VectorIndexConfigDynamic + ) + hnsw_update = ( + _VectorIndexConfigHNSWUpdate(quantizer=hnsw_quantizer) + if hnsw_quantizer + else None + ) + flat_update = ( + _VectorIndexConfigFlatUpdate(quantizer=flat_quantizer) + if flat_quantizer + else None + ) + vector_config = _VectorIndexConfigDynamicUpdate( + hnsw=hnsw_update, flat=flat_update + ) + + updated_config = _CollectionConfigUpdate( + vector_config=_VectorConfigUpdate(name=name, vector_index_config=vector_config), + ) + + return executor.execute( + response_callback=lambda _: None, + method=self._connection.put, + path=f"/schema/{self._name}", + weaviate_object=updated_config.merge_with_existing(schema), + error_msg="Quantizer configuration may not have been updated.", + status_codes=_ExpectedStatusCodes( + ok_in=200, error="Update quantizer configuration" + ), + ) + + if isinstance(self._connection, ConnectionAsync): + + async def _execute() -> None: + schema = await executor.aresult(self.__get()) + return await executor.aresult(resp(schema)) + + return _execute() + schema = executor.result(self.__get()) + return executor.result(resp(schema)) diff --git a/weaviate/collections/config/sync.pyi b/weaviate/collections/config/sync.pyi index 89f37615e..f430a3b70 100644 --- a/weaviate/collections/config/sync.pyi +++ b/weaviate/collections/config/sync.pyi @@ -22,6 +22,7 @@ from weaviate.collections.classes.config import ( _VectorIndexConfigFlatUpdate, _VectorIndexConfigHNSWUpdate, ) +from weaviate.collections.classes.config_base import _QuantizerConfigUpdate from weaviate.collections.classes.config_vector_index import _VectorIndexConfigDynamicUpdate from weaviate.connect.v4 import ConnectionSync @@ -78,3 +79,6 @@ class _ConfigCollection(_ConfigCollectionExecutor[ConnectionSync]): def add_vector( self, *, vector_config: Union[_VectorConfigCreate, 
List[_VectorConfigCreate]] ) -> None: ... + def update_quantizer( + self, *, name: Optional[str] = None, hnsw_quantizer: Optional[_QuantizerConfigUpdate] = None + ) -> None: ...