diff --git a/pycsw/core/repository.py b/pycsw/core/repository.py index d0d599f7f..a4d9d9db2 100644 --- a/pycsw/core/repository.py +++ b/pycsw/core/repository.py @@ -347,15 +347,18 @@ def query_ids(self, ids): query = self.session.query(self.dataset).filter(column.in_(ids)) return self._get_repo_filter(query).all() - def query_collections(self, filters=None, limit=10): + def query_collections(self, collection=None, filters=None, limit=10): ''' Query for parent collections ''' column = getattr(self.dataset, self.context.md_core_model['mappings']['pycsw:ParentIdentifier']) - collections = self.session.query(column).distinct() + if collection is not None: + collections = self.session.query(column).filter(column==collection) + else: + collections = self.session.query(column) - results = self._get_repo_filter(collections).all() + results = self._get_repo_filter(collections).distinct().all() ids = [res[0] for res in results if res[0] is not None] diff --git a/pycsw/ogc/api/records.py b/pycsw/ogc/api/records.py index 1652eb2c3..003676981 100644 --- a/pycsw/ogc/api/records.py +++ b/pycsw/ogc/api/records.py @@ -518,7 +518,7 @@ def queryables(self, headers_, args, collection='metadata:main'): if 'json' in headers_['Content-Type']: headers_['Content-Type'] = 'application/schema+json' - if collection not in self.get_all_collections(): + if collection not in self.get_collections(collection=collection): msg = 'Invalid collection' LOGGER.exception(msg) return self.get_exception(400, headers_, 'InvalidParameterValue', msg) @@ -599,7 +599,7 @@ def items(self, headers_, json_post_data, args, collection='metadata:main'): collections = [] cql_ops_list = [] - if collection not in self.get_all_collections(): + if collection not in self.get_collections(collection=collection): msg = 'Invalid collection' LOGGER.exception(msg) return self.get_exception(400, headers_, 'InvalidParameterValue', msg) @@ -968,7 +968,7 @@ def item(self, headers_, args, collection, item): record = None headers_['Content-Type'] = self.get_content_type(headers_, args) - if collection not in self.get_all_collections(): + if collection not in self.get_collections(collection=collection): msg = 'Invalid collection' LOGGER.exception(msg) return self.get_exception(400, headers_, 'InvalidParameterValue', msg) @@ -1247,7 +1247,7 @@ def federated_catalogue(self, headers_, args, collection, catalogue): return self.get_response(200, headers_, response, template) - def get_all_collections(self) -> list: + def get_collections(self, collection=None) -> list: """ Get all collections @@ -1255,7 +1255,7 @@ def get_all_collections(self) -> list: """ default_collection = 'metadata:main' - virtual_collections = self.repository.query_collections(limit=self.limit) + virtual_collections = self.repository.query_collections(collection=collection, limit=self.limit) return [default_collection] + [vc.identifier for vc in virtual_collections] diff --git a/pycsw/stac/api.py b/pycsw/stac/api.py index 0852afd9d..eab8ef388 100644 --- a/pycsw/stac/api.py +++ b/pycsw/stac/api.py @@ -256,7 +256,7 @@ def collections(self, headers_, args): filters = to_filter(ast, self.repository.dbtype, self.repository.query_mappings) LOGGER.debug(f'Filter: {filters}') - virtual_collections = self.repository.query_collections(filters, limit) + virtual_collections = self.repository.query_collections(filters=filters, limit=limit) for virtual_collection in virtual_collections: virtual_collection_info = self.get_collection_info( @@ -432,7 +432,7 @@ def items(self, headers_, json_post_data, args, collection='metadata:main'): distributed_search_args = deepcopy(args) distributed_search_args.pop('type', None) - if collection not in self.get_all_collections(): + if collection not in self.get_collections(collection=collection): msg = 'Invalid collection' LOGGER.exception(msg) return self.get_exception(400, headers_, 'InvalidParameterValue', msg) @@ -689,7 +689,7 @@ def item(self, headers_, args, collection, item): :returns: tuple of headers, status code, content """ - if collection not in self.get_all_collections(): + if collection not in self.get_collections(collection=collection): msg = 'Invalid collection' LOGGER.exception(msg) return self.get_exception(400, headers_, 'InvalidParameterValue', msg) @@ -773,7 +773,7 @@ def get_collection_info(self, collection_name: str = 'metadata:main', def manage_collection_item(self, headers_, action='create', item=None, data=None, collection=None): if action in ['create', 'update']: if (data is not None and data.get('type', '') == 'Feature' and - collection not in self.get_all_collections()): + collection not in self.get_collections(collection=collection)): msg = 'Invalid collection' LOGGER.exception(msg) return self.get_exception(400, headers_, 'InvalidParameterValue', msg)