Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,25 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
worker: [aragorn, aragorn_lookup, aragorn_pathfinder, aragorn_omnicorp, aragorn_score, arax, bte, bte_lookup, sipr, filter_kgraph_orphans, filter_results_top_n, finish_query, merge_message, sort_results_score]
worker:
- aragorn
- aragorn_lookup
- aragorn_omnicorp
- aragorn_pathfinder
- aragorn_score
- arax
- bte
- bte_lookup
- filter_analyses_top_n
- filter_kgraph_orphans
- filter_results_top_n
- finish_query
- gandalf
- gandalf_rehydrate
- merge_message
- score_paths
- sipr
- sort_results_score
steps:
- name: Check out the repo
uses: actions/checkout@v4
Expand Down
65 changes: 65 additions & 0 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ services:
container_name: shepherd_db
environment:
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-supersecretpassw0rd}
command: postgres -c max_connections=200
build:
context: .
dockerfile: shepherd_db/Dockerfile
Expand Down Expand Up @@ -65,6 +66,42 @@ services:
environment:
- COLLECTOR_OTLP_ENABLED=true
- LOG_LEVEL=error

gandalf:
container_name: gandalf
platform: linux/amd64
build:
context: .
dockerfile: workers/gandalf/Dockerfile
restart: unless-stopped
depends_on:
shepherd_db:
condition: service_healthy
shepherd_broker:
condition: service_healthy
volumes:
- ./workers/gandalf/debug:/app/debug
- ./logs:/app/logs
- ./.env:/app/.env
- ./gandalf_mmap:/app/gandalf_mmap

gandalf_rehydrate:
container_name: gandalf_rehydrate
platform: linux/amd64
build:
context: .
dockerfile: workers/gandalf_rehydrate/Dockerfile
restart: unless-stopped
depends_on:
shepherd_db:
condition: service_healthy
shepherd_broker:
condition: service_healthy
volumes:
- ./workers/gandalf_rehydrate/debug:/app/debug
- ./logs:/app/logs
- ./.env:/app/.env
- ./gandalf_mmap:/app/gandalf_mmap


########## Shared Workers
Expand Down Expand Up @@ -138,6 +175,34 @@ services:
volumes:
- ./logs:/app/logs
- ./.env:/app/.env
filter_analyses_top_n:
container_name: filter_analyses_top_n
build:
context: .
dockerfile: workers/filter_analyses_top_n/Dockerfile
restart: unless-stopped
depends_on:
shepherd_db:
condition: service_healthy
shepherd_broker:
condition: service_healthy
volumes:
- ./logs:/app/logs
- ./.env:/app/.env
score_paths:
container_name: score_paths
build:
context: .
dockerfile: workers/score_paths/Dockerfile
restart: unless-stopped
depends_on:
shepherd_db:
condition: service_healthy
shepherd_broker:
condition: service_healthy
volumes:
- ./logs:/app/logs
- ./.env:/app/.env


######### Example ARA
Expand Down
4 changes: 2 additions & 2 deletions shepherd_server/aras/aragorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@

@ARAGORN.post("/query")
async def sync_query(
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> Response:
response = await run_sync_query(ARATargetEnum.ARAGORN, query)
return response


@ARAGORN.post("/asyncquery")
async def async_query(
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> Response:
response = await run_async_query(ARATargetEnum.ARAGORN, query)
return response
Expand Down
4 changes: 2 additions & 2 deletions shepherd_server/aras/arax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@

@ARAX.post("/query")
async def sync_query(
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> Response:
response = await run_sync_query(ARATargetEnum.ARAX, query)
return response


@ARAX.post("/asyncquery")
async def async_query(
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> Response:
response = await run_async_query(ARATargetEnum.ARAX, query)
return response
Expand Down
4 changes: 2 additions & 2 deletions shepherd_server/aras/bte.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@

@BTE.post("/query")
async def sync_query(
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> Response:
response = await run_sync_query(ARATargetEnum.BTE, query)
return response


@BTE.post("/asyncquery")
async def async_query(
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> Response:
response = await run_async_query(ARATargetEnum.BTE, query)
return response
Expand Down
4 changes: 2 additions & 2 deletions shepherd_server/aras/sipr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@

@SIPR.post("/query")
async def sync_query(
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> Response:
response = await run_sync_query(ARATargetEnum.SIPR, query)
return response


@SIPR.post("/asyncquery")
async def async_query(
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> Response:
response = await run_async_query(ARATargetEnum.SIPR, query)
return response
Expand Down
22 changes: 16 additions & 6 deletions shepherd_server/base_routes.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Base API routes that all Shepherd ARAs can use."""

import asyncio
from enum import Enum
import json
import logging
import time
import uuid
from enum import Enum
from typing import Optional, Tuple

from fastapi import APIRouter, Body, Response
from fastapi.responses import JSONResponse
from opentelemetry import trace
from opentelemetry.propagate import inject

from shepherd_utils.broker import add_task
Expand Down Expand Up @@ -79,7 +80,7 @@ async def run_query(

logger.info(f"Sending {query_id} to {target}")

with tracer.start_as_current_span("") as span:
with tracer.start_as_current_span(""):
span_carrier = {}
# adds otel trace to carrier for next worker
inject(span_carrier)
Expand All @@ -96,6 +97,8 @@ async def run_query(
"sort_results_score",
"filter_results_top_n",
"filter_kgraph_ophans",
"score_paths",
"filter_analyses_top_n",
]
)
workflow = None
Expand Down Expand Up @@ -131,7 +134,7 @@ async def run_query(

async def run_sync_query(
target: ARATargetEnum,
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> dict:
"""Handle synchronous TRAPI queries."""
# query_dict = query.dict()
Expand Down Expand Up @@ -168,7 +171,7 @@ async def run_sync_query(

async def run_async_query(
target: ARATargetEnum,
query: dict = Body(..., example=default_input_query),
query: dict = Body(..., examples=[default_input_query]),
) -> JSONResponse:
"""Handle asynchronous TRAPI queries."""
callback_url = query.get("callback")
Expand Down Expand Up @@ -221,6 +224,13 @@ async def callback(
logger.info(f"Got original query id: {query_id}")
if query_id is None:
return Response("Couldn't find original query.", 500)
# if len(response["message"]["results"]) > 0:
# with open(
# f"shepherd_server/debug/{query_id}_{callback_id}_response.json",
# "w",
# encoding="utf-8",
# ) as f:
# json.dump(response, f, indent=2)
query_state = await get_query_state(query_id, logger)
if query_state is None:
return Response("Failed to get query state.", 500)
Expand Down Expand Up @@ -249,7 +259,7 @@ async def query_status(
qid: str,
) -> dict:
"""Handle query status requests."""
# get query status from db
# TODO: get query status from db
return {
"status": "Queued",
"description": "Query is currently waiting to be run.",
Expand Down
1 change: 0 additions & 1 deletion shepherd_server/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import uvicorn


if __name__ == "__main__":
uvicorn.run(
"shepherd_server.server:APP",
Expand Down
2 changes: 1 addition & 1 deletion shepherd_server/openapi-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ contact:
x-id: https://github.com/maximusunc
x-role: responsible developer
description: '<img src="/static/favicon.png" width="200px"><br /><br />Shepherd: Translator Autonomous Relay Agent Platform'
version: 0.5.7
version: 0.6.0
servers:
- description: Default server
url: https://shepherd.renci.org
Expand Down
4 changes: 1 addition & 3 deletions shepherd_utils/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@
lock_client = aioredis.Redis(connection_pool=lock_redis_pool)


async def create_consumer_group(
stream, group, logger: logging.Logger
):
async def create_consumer_group(stream, group, logger: logging.Logger):
"""Ensure a redis consumer group exists."""
try:
await broker_client.xgroup_create(stream, group, "0", mkstream=True)
Expand Down
3 changes: 1 addition & 2 deletions shepherd_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@

async def check_connection(conn):
"""Check if the postgres connection is still alive."""
if conn.closed:
raise OperationalError("Connection is closed.")
await conn.execute("SELECT 1")


pool = AsyncConnectionPool(
Expand Down
2 changes: 1 addition & 1 deletion shepherd_utils/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def merge_kgraph(og_message, new_message, source, logger: logging.Logger):
else:
merged_kgraph["edges"][key] = value

if value["sources"] and not is_support_edge(value):
if value.get("sources") and not is_support_edge(value):
new_sources = combine_unique_dicts(
value["sources"],
[
Expand Down
4 changes: 4 additions & 0 deletions workers/aragorn/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ async def aragorn(task, logger: logging.Logger):
{"id": "aragorn.pathfinder"},
# {"id": "aragorn.omnicorp"},
# {"id": "aragorn.score"},
{"id": "score_paths"},
{"id": "sort_results_score"},
{"id": "filter_analyses_top_n", "parameters": {"max_analyses": 500}},
{"id": "filter_kgraph_orphans"},
{"id": "gandalf.rehydrate"},
]
else:
workflow = [
Expand Down
Loading
Loading