4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260215034903124458.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "add support for cosmosdb output"
}
214 changes: 152 additions & 62 deletions packages/graphrag-storage/graphrag_storage/azure_cosmos_storage.py
@@ -25,6 +25,8 @@

logger = logging.getLogger(__name__)

_DEFAULT_PAGE_SIZE = 100


class AzureCosmosStorage(Storage):
"""The CosmosDB-Storage Implementation."""
@@ -37,7 +39,7 @@ class AzureCosmosStorage(Storage):
_database_name: str
_container_name: str
_encoding: str
_no_id_prefixes: list[str]
_no_id_prefixes: set[str]

def __init__(
self,
@@ -51,7 +53,7 @@ def __init__(
"""Create a CosmosDB storage instance."""
logger.info("Creating cosmosdb storage")
database_name = database_name
if database_name is None:
if not database_name:
msg = "CosmosDB Storage requires a base_dir to be specified. This is used as the database name."
logger.error(msg)
raise ValueError(msg)
@@ -81,7 +83,7 @@ def __init__(
self._cosmosdb_account_name = (
account_url.split("//")[1].split(".")[0] if account_url else None
)
self._no_id_prefixes = []
self._no_id_prefixes = set()
logger.debug(
"Creating cosmosdb storage with account [%s] and database [%s] and container [%s]",
self._cosmosdb_account_name,
@@ -150,12 +152,10 @@ def find(
{"name": "@pattern", "value": file_pattern.pattern}
]

items = list(
self._container_client.query_items(
query=query,
parameters=parameters,
enable_cross_partition_query=True,
)
items = self._query_all_items(
self._container_client,
query=query,
parameters=parameters,
)
logger.debug("All items: %s", [item["id"] for item in items])
num_loaded = 0
@@ -192,72 +192,111 @@ async def get(
try:
if not self._database_client or not self._container_client:
return None

if as_bytes:
prefix = self._get_prefix(key)
query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}')" # noqa: S608
queried_items = self._container_client.query_items(
query=query, enable_cross_partition_query=True
query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')" # noqa: S608
items_list = self._query_all_items(
self._container_client,
query=query,
)
items_list = list(queried_items)

logger.info("Cosmos load prefix=%s count=%d", prefix, len(items_list))

if not items_list:
logger.warning("No items found for prefix %s (key=%s)", prefix, key)
return None

for item in items_list:
item["id"] = item["id"].split(":")[1]
item["id"] = item["id"].split(":", 1)[1]

items_json_str = json.dumps(items_list)

items_df = pd.read_json(
StringIO(items_json_str), orient="records", lines=False
)

# Drop the "id" column if the original dataframe does not include it
# TODO: Figure out optimal way to handle missing id keys in input dataframes
if prefix in self._no_id_prefixes:
items_df.drop(columns=["id"], axis=1, inplace=True)

if prefix == "entities":
# Always preserve the Cosmos suffix for debugging/migrations
items_df["cosmos_id"] = items_df["id"]
items_df["id"] = items_df["id"].astype(
str
) # Only restore pipeline UUID id if we actually have it

if "human_readable_id" in items_df.columns:
# Fill any NaN values before converting to int
items_df["human_readable_id"] = (
items_df["human_readable_id"]
.fillna(items_df["id"])
.astype(int)
)
else:
# Fresh run case: extract_graph entities may not have entity_id yet
# Keep id as the suffix (stable_key/index) for now.
logger.info(
"Entities loaded without entity_id; leaving id as cosmos suffix."
)

if items_df.empty:
logger.warning(
"No rows returned for prefix %s (key=%s)", prefix, key
)
return None
return items_df.to_parquet()
item = self._container_client.read_item(item=key, partition_key=key)
item_body = item.get("body")
return json.dumps(item_body)
except Exception: # noqa: BLE001
logger.warning("Error reading item %s", key)
except Exception:
logger.exception("Error reading item %s", key)
return None
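
One detail worth calling out from the hunk above: switching to `split(":", 1)` means the restored row id survives even when the id itself contains a colon. A minimal sketch:

```python
# Why get() strips the prefix with maxsplit=1 (hypothetical id value):
cosmos_id = "sources:doc:42"
cosmos_id.split(":")[1]     # "doc"    -- drops everything after the second colon
cosmos_id.split(":", 1)[1]  # "doc:42" -- keeps the full original id
```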

async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
"""Insert the contents of a file into a cosmosdb container for the given filename key.

For better optimization, the file is destructured such that each row is a unique cosmosdb item.
"""
"""Write an item to Cosmos DB. If the value is bytes, we assume it's a parquet file and we write each row as a separate item with id formatted as {prefix}:{stable_key_or_index}."""
if not self._database_client or not self._container_client:
error_msg = "Database or container not initialized. Cannot write item."
raise ValueError(error_msg)
try:
if not self._database_client or not self._container_client:
msg = "Database or container not initialized"
raise ValueError(msg) # noqa: TRY301
# value represents a parquet file
if isinstance(value, bytes):
prefix = self._get_prefix(key)
value_df = pd.read_parquet(BytesIO(value))
value_json = value_df.to_json(
orient="records", lines=False, force_ascii=False
)
if value_json is None:
logger.error("Error converting output %s to json", key)

# Decide once per dataframe
df_has_id = "id" in value_df.columns

# IMPORTANT: if we now have ids, undo the earlier "no id" marking
if df_has_id:
self._no_id_prefixes.discard(prefix)
else:
cosmosdb_item_list = json.loads(value_json)
for index, cosmosdb_item in enumerate(cosmosdb_item_list):
# If the id key does not exist in the input dataframe json, create a unique id using the prefix and item index
# TODO: Figure out optimal way to handle missing id keys in input dataframes
if "id" not in cosmosdb_item:
prefixed_id = f"{prefix}:{index}"
self._no_id_prefixes.append(prefix)
self._no_id_prefixes.add(prefix)

cosmosdb_item_list = json.loads(
value_df.to_json(orient="records", lines=False, force_ascii=False)
)

for index, cosmosdb_item in enumerate(cosmosdb_item_list):
if prefix == "entities":
# Stable key for Cosmos identity
stable_key = cosmosdb_item.get("human_readable_id", index)
cosmos_id = f"{prefix}:{stable_key}"

# If the pipeline provided a final UUID, store it separately
if "id" in cosmosdb_item:
cosmosdb_item["entity_id"] = cosmosdb_item["id"]

# Cosmos identity must be stable and NEVER change
cosmosdb_item["id"] = cosmos_id

else:
# Original behavior for non-entity prefixes
if df_has_id:
cosmosdb_item["id"] = f"{prefix}:{cosmosdb_item['id']}"
else:
prefixed_id = f"{prefix}:{cosmosdb_item['id']}"
cosmosdb_item["id"] = prefixed_id
self._container_client.upsert_item(body=cosmosdb_item)
# value represents a cache output or stats.json
cosmosdb_item["id"] = f"{prefix}:{index}"

self._container_client.upsert_item(body=cosmosdb_item)
else:
cosmosdb_item = {
"id": key,
"body": json.loads(value),
}
cosmosdb_item = {"id": key, "body": json.loads(value)}
self._container_client.upsert_item(body=cosmosdb_item)

except Exception:
logger.exception("Error writing item %s", key)

@@ -267,16 +306,66 @@ async def has(self, key: str) -> bool:
return False
if ".parquet" in key:
prefix = self._get_prefix(key)
query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}')" # noqa: S608
queried_items = self._container_client.query_items(
query=query, enable_cross_partition_query=True
count = self._query_count(
self._container_client,
query_filter=f"STARTSWITH(c.id, '{prefix}:')",
)
return len(list(queried_items)) > 0
query = f"SELECT * FROM c WHERE c.id = '{key}'" # noqa: S608
queried_items = self._container_client.query_items(
query=query, enable_cross_partition_query=True
return count > 0
count = self._query_count(
self._container_client,
query_filter=f"c.id = '{key}'",
)
return len(list(queried_items)) == 1
return count >= 1

def _query_all_items(
self,
container_client: ContainerProxy,
query: str,
parameters: list[dict[str, Any]] | None = None,
page_size: int = _DEFAULT_PAGE_SIZE,
) -> list[dict[str, Any]]:
"""Fetch all items from a Cosmos DB query using pagination.

This avoids the pitfalls of calling list() on the full pager, which can
time out or return incomplete results for large result sets.
"""
results: list[dict[str, Any]] = []
query_kwargs: dict[str, Any] = {
"query": query,
"enable_cross_partition_query": True,
"max_item_count": page_size,
}
if parameters:
query_kwargs["parameters"] = parameters

pager = container_client.query_items(**query_kwargs).by_page()
for page in pager:
results.extend(page)
return results
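
A usage sketch (hypothetical call site inside the class) mirroring how `find` and `get` drive this helper; bound parameters pass straight through to the SDK:

```python
# Hypothetical call: page through every item under an id prefix.
items = self._query_all_items(
    self._container_client,
    query="SELECT * FROM c WHERE STARTSWITH(c.id, @prefix)",
    parameters=[{"name": "@prefix", "value": "relationships:"}],
)
logger.debug("Fetched %d items across all pages", len(items))
```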

def _query_count(
self,
container_client: ContainerProxy,
query_filter: str,
parameters: list[dict[str, Any]] | None = None,
) -> int:
"""Return the count of items matching a filter, without fetching them all.

Parameters
----------
query_filter:
The WHERE clause (without 'WHERE'), e.g. "STARTSWITH(c.id, 'entities:')".
"""
count_query = f"SELECT VALUE COUNT(1) FROM c WHERE {query_filter}" # noqa: S608
query_kwargs: dict[str, Any] = {
"query": count_query,
"enable_cross_partition_query": True,
}
if parameters:
query_kwargs["parameters"] = parameters

results = list(container_client.query_items(**query_kwargs))
return int(results[0]) if results else 0 # type: ignore[arg-type]
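
Both call sites above interpolate the filter directly (hence the `noqa: S608`), but the `parameters` hook supports bound values too — a sketch of what a parameterized count could look like:

```python
# Hypothetical parameterized count, avoiding string interpolation entirely.
count = self._query_count(
    self._container_client,
    query_filter="STARTSWITH(c.id, @prefix)",
    parameters=[{"name": "@prefix", "value": "entities:"}],
)
```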

async def delete(self, key: str) -> None:
"""Delete all cosmosdb items belonging to the given filename key."""
@@ -285,11 +374,12 @@ async def delete(self, key: str) -> None:
try:
if ".parquet" in key:
prefix = self._get_prefix(key)
query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}')" # noqa: S608
queried_items = self._container_client.query_items(
query=query, enable_cross_partition_query=True
query = f"SELECT * FROM c WHERE STARTSWITH(c.id, '{prefix}:')" # noqa: S608
items = self._query_all_items(
self._container_client,
query=query,
)
for item in queried_items:
for item in items:
self._container_client.delete_item(
item=item["id"], partition_key=item["id"]
)
15 changes: 10 additions & 5 deletions packages/graphrag/graphrag/data_model/dfs.py
@@ -28,6 +28,11 @@
)


def _safe_int(series: pd.Series, fill: int = -1) -> pd.Series:
"""Convert a series to int, filling NaN values first."""
return series.fillna(fill).astype(int)
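
Filling before casting is the whole point: pandas refuses to cast NaN to int. A quick sketch:

```python
import pandas as pd

s = pd.Series([3.0, None, 7.0])
# s.astype(int) would raise IntCastingNaNError on the NaN.
_safe_int(s).tolist()     # [3, -1, 7]  -- default -1 sentinel
_safe_int(s, 0).tolist()  # [3, 0, 7]
```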


def _split_list_column(value: Any) -> list[Any]:
"""Split a column containing a list string into an actual list."""
if isinstance(value, str):
@@ -38,25 +43,25 @@ def _split_list_column(value: Any) -> list[Any]:
def entities_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the entities dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
df[SHORT_ID] = df[SHORT_ID].astype(int)
df[SHORT_ID] = _safe_int(df[SHORT_ID])
if TEXT_UNIT_IDS in df.columns:
df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column)
if NODE_FREQUENCY in df.columns:
df[NODE_FREQUENCY] = df[NODE_FREQUENCY].astype(int)
df[NODE_FREQUENCY] = _safe_int(df[NODE_FREQUENCY], 0)
if NODE_DEGREE in df.columns:
df[NODE_DEGREE] = df[NODE_DEGREE].astype(int)
df[NODE_DEGREE] = _safe_int(df[NODE_DEGREE], 0)

return df


def relationships_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the relationships dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
df[SHORT_ID] = df[SHORT_ID].astype(int)
df[SHORT_ID] = _safe_int(df[SHORT_ID])
if EDGE_WEIGHT in df.columns:
df[EDGE_WEIGHT] = df[EDGE_WEIGHT].astype(float)
if EDGE_DEGREE in df.columns:
df[EDGE_DEGREE] = df[EDGE_DEGREE].astype(int)
df[EDGE_DEGREE] = _safe_int(df[EDGE_DEGREE], 0)
if TEXT_UNIT_IDS in df.columns:
df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column)

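As a usage illustration (hypothetical frame), the typed converters let a weakly-typed frame — e.g. one rebuilt from Cosmos JSON — regain integer ids without crashing on gaps:

```python
import pandas as pd

# Hypothetical frame after a JSON round trip: ints arrive as floats with NaN.
raw = pd.DataFrame({SHORT_ID: [0.0, None, 2.0]})
typed = entities_typed(raw)
list(typed[SHORT_ID])  # [0, -1, 2] -- NaN becomes the -1 sentinel
```
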
1 change: 1 addition & 0 deletions packages/graphrag/graphrag/index/run/run_pipeline.py
@@ -39,6 +39,7 @@ async def run_pipeline(
input_storage = create_storage(config.input_storage)

output_storage = create_storage(config.output_storage)

output_table_provider = create_table_provider(config.table_provider, output_storage)

cache = create_cache(config.cache)
@@ -30,7 +30,6 @@ async def run_workflow(
reader = DataReader(context.output_table_provider)
entities = await reader.entities()
relationships = await reader.relationships()

max_cluster_size = config.cluster_graph.max_cluster_size
use_lcc = config.cluster_graph.use_lcc
seed = config.cluster_graph.seed