From 20d30fd93b4f3260e8f9522925063bcff4791133 Mon Sep 17 00:00:00 2001 From: gaudyb <85708998+gaudyb@users.noreply.github.com> Date: Tue, 17 Feb 2026 15:59:36 -0600 Subject: [PATCH] Cosmosdb communities bug (#2232) * work in progress * cosmosdb output error fix * semserver update * remove unnecessary code * clean code * remove unnecessary prints --- .../patch-20260215034903124458.json | 4 + .../graphrag_storage/azure_cosmos_storage.py | 214 +++++++++++++----- packages/graphrag/graphrag/data_model/dfs.py | 15 +- .../graphrag/index/run/run_pipeline.py | 1 + .../index/workflows/create_communities.py | 1 - 5 files changed, 167 insertions(+), 68 deletions(-) create mode 100644 .semversioner/next-release/patch-20260215034903124458.json diff --git a/.semversioner/next-release/patch-20260215034903124458.json b/.semversioner/next-release/patch-20260215034903124458.json new file mode 100644 index 000000000..15c4511f4 --- /dev/null +++ b/.semversioner/next-release/patch-20260215034903124458.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "add support for cosmosdb output" +} diff --git a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py index ff3ec6dec..5423c216e 100644 --- a/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py +++ b/packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py @@ -25,6 +25,8 @@ logger = logging.getLogger(__name__) +_DEFAULT_PAGE_SIZE = 100 + class AzureCosmosStorage(Storage): """The CosmosDB-Storage Implementation.""" @@ -37,7 +39,7 @@ class AzureCosmosStorage(Storage): _database_name: str _container_name: str _encoding: str - _no_id_prefixes: list[str] + _no_id_prefixes: set[str] def __init__( self, @@ -51,7 +53,7 @@ def __init__( """Create a CosmosDB storage instance.""" logger.info("Creating cosmosdb storage") database_name = database_name - if database_name is None: + if not database_name: msg = "CosmosDB Storage requires a base_dir to be specified. This is used as the database name." logger.error(msg) raise ValueError(msg) @@ -81,7 +83,7 @@ def __init__( self._cosmosdb_account_name = ( account_url.split("//")[1].split(".")[0] if account_url else None ) - self._no_id_prefixes = [] + self._no_id_prefixes = set() logger.debug( "Creating cosmosdb storage with account [%s] and database [%s] and container [%s]", self._cosmosdb_account_name, @@ -150,12 +152,10 @@ def find( {"name": "@pattern", "value": file_pattern.pattern} ] - items = list( - self._container_client.query_items( - query=query, - parameters=parameters, - enable_cross_partition_query=True, - ) + items = self._query_all_items( + self._container_client, + query=query, + parameters=parameters, ) logger.debug("All items: %s", [item["id"] for item in items]) num_loaded = 0 @@ -192,72 +192,111 @@ async def get( try: if not self._database_client or not self._container_client: return None + if as_bytes: prefix = self._get_prefix(key) - query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}')" # noqa: S608 - queried_items = self._container_client.query_items( - query=query, enable_cross_partition_query=True + query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')" # noqa: S608 + items_list = self._query_all_items( + self._container_client, + query=query, ) - items_list = list(queried_items) + + logger.info("Cosmos load prefix=%s count=%d", prefix, len(items_list)) + + if not items_list: + logger.warning("No items found for prefix %s (key=%s)", prefix, key) + return None + for item in items_list: - item["id"] = item["id"].split(":")[1] + item["id"] = item["id"].split(":", 1)[1] items_json_str = json.dumps(items_list) - items_df = pd.read_json( StringIO(items_json_str), orient="records", lines=False ) - # Drop the "id" column if the original dataframe does not include it - # TODO: Figure out optimal way to handle missing id keys in input dataframes - if prefix in self._no_id_prefixes: - items_df.drop(columns=["id"], axis=1, inplace=True) - + if prefix == "entities": + # Always preserve the Cosmos suffix for debugging/migrations + items_df["cosmos_id"] = items_df["id"] + items_df["id"] = items_df["id"].astype( + str + ) # Only restore pipeline UUID id if we actually have it + + if "human_readable_id" in items_df.columns: + # Fill any NaN values before converting to int + items_df["human_readable_id"] = ( + items_df["human_readable_id"] + .fillna(items_df["id"]) + .astype(int) + ) + else: + # Fresh run case: extract_graph entities may not have entity_id yet + # Keep id as the suffix (stable_key/index) for now. + logger.info( + "Entities loaded without entity_id; leaving id as cosmos suffix." + ) + + if items_df.empty: + logger.warning( + "No rows returned for prefix %s (key=%s)", prefix, key + ) + return None return items_df.to_parquet() item = self._container_client.read_item(item=key, partition_key=key) item_body = item.get("body") return json.dumps(item_body) - except Exception: # noqa: BLE001 - logger.warning("Error reading item %s", key) + except Exception: + logger.exception("Error reading item %s", key) return None async def set(self, key: str, value: Any, encoding: str | None = None) -> None: - """Insert the contents of a file into a cosmosdb container for the given filename key. - - For better optimization, the file is destructured such that each row is a unique cosmosdb item. - """ + """Write an item to Cosmos DB. If the value is bytes, we assume it's a parquet file and we write each row as a separate item with id formatted as {prefix}:{stable_key_or_index}.""" + if not self._database_client or not self._container_client: + error_msg = "Database or container not initialized. Cannot write item." + raise ValueError(error_msg) try: - if not self._database_client or not self._container_client: - msg = "Database or container not initialized" - raise ValueError(msg) # noqa: TRY301 - # value represents a parquet file if isinstance(value, bytes): prefix = self._get_prefix(key) value_df = pd.read_parquet(BytesIO(value)) - value_json = value_df.to_json( - orient="records", lines=False, force_ascii=False - ) - if value_json is None: - logger.error("Error converting output %s to json", key) + + # Decide once per dataframe + df_has_id = "id" in value_df.columns + + # IMPORTANT: if we now have ids, undo the earlier "no id" marking + if df_has_id: + self._no_id_prefixes.discard(prefix) else: - cosmosdb_item_list = json.loads(value_json) - for index, cosmosdb_item in enumerate(cosmosdb_item_list): - # If the id key does not exist in the input dataframe json, create a unique id using the prefix and item index - # TODO: Figure out optimal way to handle missing id keys in input dataframes - if "id" not in cosmosdb_item: - prefixed_id = f"{prefix}:{index}" - self._no_id_prefixes.append(prefix) + self._no_id_prefixes.add(prefix) + + cosmosdb_item_list = json.loads( + value_df.to_json(orient="records", lines=False, force_ascii=False) + ) + + for index, cosmosdb_item in enumerate(cosmosdb_item_list): + if prefix == "entities": + # Stable key for Cosmos identity + stable_key = cosmosdb_item.get("human_readable_id", index) + cosmos_id = f"{prefix}:{stable_key}" + + # If the pipeline provided a final UUID, store it separately + if "id" in cosmosdb_item: + cosmosdb_item["entity_id"] = cosmosdb_item["id"] + + # Cosmos identity must be stable and NEVER change + cosmosdb_item["id"] = cosmos_id + + else: + # Original behavior for non-entity prefixes + if df_has_id: + cosmosdb_item["id"] = f"{prefix}:{cosmosdb_item['id']}" else: - prefixed_id = f"{prefix}:{cosmosdb_item['id']}" - cosmosdb_item["id"] = prefixed_id - self._container_client.upsert_item(body=cosmosdb_item) - # value represents a cache output or stats.json + cosmosdb_item["id"] = f"{prefix}:{index}" + + self._container_client.upsert_item(body=cosmosdb_item) else: - cosmosdb_item = { - "id": key, - "body": json.loads(value), - } + cosmosdb_item = {"id": key, "body": json.loads(value)} self._container_client.upsert_item(body=cosmosdb_item) + except Exception: logger.exception("Error writing item %s", key) @@ -267,16 +306,66 @@ async def has(self, key: str) -> bool: return False if ".parquet" in key: prefix = self._get_prefix(key) - query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}')" # noqa: S608 - queried_items = self._container_client.query_items( - query=query, enable_cross_partition_query=True + count = self._query_count( + self._container_client, + query_filter=f"STARTSWITH(c.id, '{prefix}:')", ) - return len(list(queried_items)) > 0 - query = f"SELECT * FROM c WHERE c.id = '{key}'" # noqa: S608 - queried_items = self._container_client.query_items( - query=query, enable_cross_partition_query=True + return count > 0 + count = self._query_count( + self._container_client, + query_filter=f"c.id = '{key}'", ) - return len(list(queried_items)) == 1 + return count >= 1 + + def _query_all_items( + self, + container_client: ContainerProxy, + query: str, + parameters: list[dict[str, Any]] | None = None, + page_size: int = _DEFAULT_PAGE_SIZE, + ) -> list[dict[str, Any]]: + """Fetch all items from a Cosmos DB query using pagination. + + This avoids the pitfalls of calling list() on the full pager, which can + time out or return incomplete results for large result sets. + """ + results: list[dict[str, Any]] = [] + query_kwargs: dict[str, Any] = { + "query": query, + "enable_cross_partition_query": True, + "max_item_count": page_size, + } + if parameters: + query_kwargs["parameters"] = parameters + + pager = container_client.query_items(**query_kwargs).by_page() + for page in pager: + results.extend(page) + return results + + def _query_count( + self, + container_client: ContainerProxy, + query_filter: str, + parameters: list[dict[str, Any]] | None = None, + ) -> int: + """Return the count of items matching a filter, without fetching them all. + + Parameters + ---------- + query_filter: + The WHERE clause (without 'WHERE'), e.g. "STARTSWITH(c.id, 'entities:')". + """ + count_query = f"SELECT VALUE COUNT(1) FROM c WHERE {query_filter}" # noqa: S608 + query_kwargs: dict[str, Any] = { + "query": count_query, + "enable_cross_partition_query": True, + } + if parameters: + query_kwargs["parameters"] = parameters + + results = list(container_client.query_items(**query_kwargs)) + return int(results[0]) if results else 0 # type: ignore[arg-type] async def delete(self, key: str) -> None: """Delete all cosmosdb items belonging to the given filename key.""" @@ -285,11 +374,12 @@ async def delete(self, key: str) -> None: try: if ".parquet" in key: prefix = self._get_prefix(key) - query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}')" # noqa: S608 - queried_items = self._container_client.query_items( - query=query, enable_cross_partition_query=True + query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')" # noqa: S608 + items = self._query_all_items( + self._container_client, + query=query, ) - for item in queried_items: + for item in items: self._container_client.delete_item( item=item["id"], partition_key=item["id"] ) diff --git a/packages/graphrag/graphrag/data_model/dfs.py b/packages/graphrag/graphrag/data_model/dfs.py index d6d7e729f..7d8ae1f5a 100644 --- a/packages/graphrag/graphrag/data_model/dfs.py +++ b/packages/graphrag/graphrag/data_model/dfs.py @@ -28,6 +28,11 @@ ) +def _safe_int(series: pd.Series, fill: int = -1) -> pd.Series: + """Convert a series to int, filling NaN values first.""" + return series.fillna(fill).astype(int) + + def _split_list_column(value: Any) -> list[Any]: """Split a column containing a list string into an actual list.""" if isinstance(value, str): @@ -38,13 +43,13 @@ def _split_list_column(value: Any) -> list[Any]: def entities_typed(df: pd.DataFrame) -> pd.DataFrame: """Return the entities dataframe with correct types, in case it was stored in a weakly-typed format.""" if SHORT_ID in df.columns: - df[SHORT_ID] = df[SHORT_ID].astype(int) + df[SHORT_ID] = _safe_int(df[SHORT_ID]) if TEXT_UNIT_IDS in df.columns: df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column) if NODE_FREQUENCY in df.columns: - df[NODE_FREQUENCY] = df[NODE_FREQUENCY].astype(int) + df[NODE_FREQUENCY] = _safe_int(df[NODE_FREQUENCY], 0) if NODE_DEGREE in df.columns: - df[NODE_DEGREE] = df[NODE_DEGREE].astype(int) + df[NODE_DEGREE] = _safe_int(df[NODE_DEGREE], 0) return df @@ -52,11 +57,11 @@ def entities_typed(df: pd.DataFrame) -> pd.DataFrame: def relationships_typed(df: pd.DataFrame) -> pd.DataFrame: """Return the relationships dataframe with correct types, in case it was stored in a weakly-typed format.""" if SHORT_ID in df.columns: - df[SHORT_ID] = df[SHORT_ID].astype(int) + df[SHORT_ID] = _safe_int(df[SHORT_ID]) if EDGE_WEIGHT in df.columns: df[EDGE_WEIGHT] = df[EDGE_WEIGHT].astype(float) if EDGE_DEGREE in df.columns: - df[EDGE_DEGREE] = df[EDGE_DEGREE].astype(int) + df[EDGE_DEGREE] = _safe_int(df[EDGE_DEGREE], 0) if TEXT_UNIT_IDS in df.columns: df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column) diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index 7f4c45f37..55a4b0017 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -39,6 +39,7 @@ async def run_pipeline( input_storage = create_storage(config.input_storage) output_storage = create_storage(config.output_storage) + output_table_provider = create_table_provider(config.table_provider, output_storage) cache = create_cache(config.cache) diff --git a/packages/graphrag/graphrag/index/workflows/create_communities.py b/packages/graphrag/graphrag/index/workflows/create_communities.py index 28ef3f542..d80b78b1a 100644 --- a/packages/graphrag/graphrag/index/workflows/create_communities.py +++ b/packages/graphrag/graphrag/index/workflows/create_communities.py @@ -30,7 +30,6 @@ async def run_workflow( reader = DataReader(context.output_table_provider) entities = await reader.entities() relationships = await reader.relationships() - max_cluster_size = config.cluster_graph.max_cluster_size use_lcc = config.cluster_graph.use_lcc seed = config.cluster_graph.seed