diff --git a/cli/build.gradle.kts b/cli/build.gradle.kts index 82253d1e..c890dddc 100644 --- a/cli/build.gradle.kts +++ b/cli/build.gradle.kts @@ -80,6 +80,7 @@ sourceSets.main { dependencies { implementation(projects.config) implementation(projects.library.server) + implementation(projects.library.mcp) implementation(projects.ai.discover) implementation(projects.library.core) implementation(projects.library.sqlite) diff --git a/cli/src/main/kotlin/com/linroid/ketch/cli/Main.kt b/cli/src/main/kotlin/com/linroid/ketch/cli/Main.kt index 69636583..f5f90d0f 100644 --- a/cli/src/main/kotlin/com/linroid/ketch/cli/Main.kt +++ b/cli/src/main/kotlin/com/linroid/ketch/cli/Main.kt @@ -22,6 +22,7 @@ import com.linroid.ketch.config.generateConfig import com.linroid.ketch.core.Ketch import com.linroid.ketch.engine.KtorHttpEngine import com.linroid.ketch.ftp.FtpDownloadSource +import com.linroid.ketch.mcp.KetchMcpServer import com.linroid.ketch.server.KetchServer import com.linroid.ketch.sqlite.DriverFactory import com.linroid.ketch.sqlite.SqliteTaskStore @@ -56,6 +57,10 @@ fun main(args: Array) { runAiDiscover(remaining.drop(1)) return } + "mcp" -> { + runMcp(remaining.drop(1)) + return + } } var url: String? = null @@ -577,9 +582,112 @@ private fun runAiDiscover(args: List) { } } +private fun runMcp(args: List) { + var configPath: String? = null + var cliDownloadDir: String? = null + + var i = 0 + while (i < args.size) { + when (args[i]) { + "--help", "-h" -> { + printMcpUsage() + return + } + "--config" -> { + if (i + 1 >= args.size) { + println("Error: --config requires a value") + println() + printMcpUsage() + return + } + configPath = args[++i] + } + "--dir" -> { + if (i + 1 >= args.size) { + println("Error: --dir requires a value") + println() + printMcpUsage() + return + } + cliDownloadDir = args[++i] + } + } + i++ + } + + val fileConfig = if (configPath != null) { + FileConfigStore(configPath).load() + } else { + val defaultPath = defaultConfigPath() + if (File(defaultPath).exists()) { + FileConfigStore(defaultPath).load() + } else { + KetchConfig() + } + } + + val defaultDownloadDir = System.getProperty("user.home") + + File.separator + "Downloads" + val downloadConfig = fileConfig.download.copy( + defaultDirectory = cliDownloadDir + ?: fileConfig.download.defaultDirectory + ?: defaultDownloadDir, + ) + + File(downloadConfig.defaultDirectory!!).mkdirs() + + val dbPath = defaultDbPath() + val driver = DriverFactory(dbPath).createDriver() + val taskStore = SqliteTaskStore(driver) + + val ketch = Ketch( + httpEngine = KtorHttpEngine(), + taskStore = taskStore, + config = downloadConfig, + logger = Logger.console(ketchLogLevel), + additionalSources = listOf(FtpDownloadSource()), + ) + + Runtime.getRuntime().addShutdownHook(Thread { + ketch.close() + }) + + val mcpServer = KetchMcpServer(ketch) + + runBlocking { + ketch.start() + mcpServer.startStdio() + } +} + +private fun printMcpUsage() { + println("Usage: ketch mcp [options]") + println() + println("Start Ketch as an MCP (Model Context Protocol) server") + println("using stdio transport. AI agents like Claude Desktop") + println("can manage downloads through MCP tools.") + println() + println("Options:") + println(" --config Path to TOML config file") + println(" --dir Download directory") + println(" (default: ~/Downloads)") + println(" --help, -h Show this help message") + println() + println("MCP client configuration (e.g. claude_desktop_config.json):") + println(" {") + println(" \"mcpServers\": {") + println(" \"ketch\": {") + println(" \"command\": \"ketch\",") + println(" \"args\": [\"mcp\"]") + println(" }") + println(" }") + println(" }") +} + private fun printUsage() { println("Usage: ketch [options] [destination]") println(" ketch server [options]") + println(" ketch mcp [options]") println(" ketch ai-discover [options]") println() println("Global Options:") @@ -600,6 +708,11 @@ private fun printUsage() { println(" Run `ketch server --help`") println(" for server options") println() + println("MCP Server:") + println(" mcp [options] Start MCP server (stdio)") + println(" Run `ketch mcp --help`") + println(" for MCP options") + println() println("AI Discovery:") println(" ai-discover Discover downloadable resources") println(" --sites Comma-separated domain allowlist") diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 4e82beda..d67b968a 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -62,6 +62,7 @@ okio-nodefilesystem = { module = "com.squareup.okio:okio-nodefilesystem", versio kotlinx-datetime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "kotlinx-datetime" } kermit = { module = "co.touchlab:kermit", version.ref = "kermit" } koog-agents = { module = "ai.koog:koog-agents", version.ref = "koog" } +koog-mcp-server = { module = "ai.koog:agents-mcp-server", version.ref = "koog" } kotlinx-coroutines-android = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-android", version.ref = "kotlinx-coroutines" } kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "kotlinx-coroutines" } kotlinx-coroutinesSwing = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-swing", version.ref = "kotlinx-coroutines" } diff --git a/library/mcp/build.gradle.kts b/library/mcp/build.gradle.kts new file mode 100644 index 00000000..26192908 --- /dev/null +++ b/library/mcp/build.gradle.kts @@ -0,0 +1,15 @@ +plugins { + alias(libs.plugins.kotlinJvm) + alias(libs.plugins.kotlinx.serialization) +} + +dependencies { + api(projects.library.api) + + implementation(libs.koog.mcp.server) + implementation(libs.kotlinx.coroutines.core) + implementation(libs.kotlinx.serialization.json) + + testImplementation(libs.kotlin.test) + testImplementation(libs.kotlinx.coroutines.test) +} diff --git a/library/mcp/src/main/kotlin/com/linroid/ketch/mcp/KetchMcpServer.kt b/library/mcp/src/main/kotlin/com/linroid/ketch/mcp/KetchMcpServer.kt new file mode 100644 index 00000000..dc730a8e --- /dev/null +++ b/library/mcp/src/main/kotlin/com/linroid/ketch/mcp/KetchMcpServer.kt @@ -0,0 +1,65 @@ +package com.linroid.ketch.mcp + +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.core.tools.reflect.tools +import ai.koog.agents.mcp.server.startSseMcpServer +import ai.koog.agents.mcp.server.startStdioMcpServer +import com.linroid.ketch.api.KetchApi +import io.ktor.server.engine.ApplicationEngineFactory +import kotlinx.coroutines.Job + +/** + * Exposes a [KetchApi] instance as an MCP (Model Context Protocol) + * server, allowing AI agents to manage downloads via MCP tools. + * + * Supports two transport modes: + * - **stdio** — for CLI/editor integration (Claude Desktop, VS Code, etc.) + * - **SSE** — for remote HTTP access + * + * Usage: + * ```kotlin + * val ketch = Ketch(httpEngine = KtorHttpEngine()) + * val mcp = KetchMcpServer(ketch) + * mcp.startStdio() // suspends until closed + * ``` + */ +class KetchMcpServer( + private val ketch: KetchApi, +) { + private val toolRegistry = ToolRegistry { + tools(KetchToolSet(ketch)) + } + + /** + * Starts the MCP server using stdio transport. + * Reads JSON-RPC messages from stdin and writes responses to stdout. + * This is the standard transport for MCP clients like Claude Desktop. + * + * This function suspends until the server is closed. + */ + suspend fun startStdio() { + val server = startStdioMcpServer(toolRegistry) + val done = Job() + server.onClose { done.complete() } + done.join() + } + + /** + * Starts the MCP server using SSE (Server-Sent Events) transport + * over HTTP. + * + * @param factory the Ktor server engine factory (e.g., `CIO`) + * @param port the port to listen on + * @param host the host to bind to + */ + suspend fun startSse( + factory: ApplicationEngineFactory<*, *>, + port: Int = 3001, + host: String = "localhost", + ) { + val server = startSseMcpServer(factory, port, host, toolRegistry) + val done = Job() + server.onClose { done.complete() } + done.join() + } +} diff --git a/library/mcp/src/main/kotlin/com/linroid/ketch/mcp/KetchToolSet.kt b/library/mcp/src/main/kotlin/com/linroid/ketch/mcp/KetchToolSet.kt new file mode 100644 index 00000000..4ccb300a --- /dev/null +++ b/library/mcp/src/main/kotlin/com/linroid/ketch/mcp/KetchToolSet.kt @@ -0,0 +1,327 @@ +package com.linroid.ketch.mcp + +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.agents.core.tools.annotations.Tool +import ai.koog.agents.core.tools.reflect.ToolSet +import com.linroid.ketch.api.Destination +import com.linroid.ketch.api.DownloadPriority +import com.linroid.ketch.api.DownloadRequest +import com.linroid.ketch.api.DownloadState +import com.linroid.ketch.api.KetchApi +import com.linroid.ketch.api.SpeedLimit +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.encodeToJsonElement +import kotlinx.serialization.json.put + +/** + * Koog [ToolSet] exposing [KetchApi] download management capabilities + * as MCP tools for AI agents. + * + * Each `@Tool` method wraps a [KetchApi] or `DownloadTask` operation + * and returns a JSON-encoded string. + */ +@LLMDescription( + "Download manager tools for starting, managing, and monitoring file downloads", +) +class KetchToolSet( + private val ketch: KetchApi, + private val json: Json = Json { + encodeDefaults = true + ignoreUnknownKeys = true + }, +) : ToolSet { + + @Tool + @LLMDescription( + "List all download tasks with their current state and progress. " + + "Returns JSON array of task snapshots.", + ) + fun listDownloads(): String { + val tasks = ketch.tasks.value + return json.encodeToString( + buildJsonArray { + tasks.forEach { task -> add(taskToJson(task)) } + }, + ) + } + + @Tool + @LLMDescription( + "Get details of a specific download task by its ID. " + + "Returns JSON object with task state, progress, and segments.", + ) + fun getDownload( + @LLMDescription("The unique task ID") + taskId: String, + ): String { + val task = findTask(taskId) ?: return notFound(taskId) + return json.encodeToString(taskToJson(task)) + } + + @Tool + @LLMDescription( + "Start a new download from a URL. Returns the created task snapshot.", + ) + suspend fun startDownload( + @LLMDescription("The URL to download from") + url: String, + @LLMDescription( + "Where to save the file. Can be a directory path " + + "(ending with /), a filename, or a full file path. " + + "Omit to use the default directory.", + ) + destination: String = "", + @LLMDescription( + "Number of concurrent connections (segments). " + + "0 uses the default from config.", + ) + connections: Int = 0, + @LLMDescription( + "Download priority: LOW, NORMAL, HIGH, or URGENT", + ) + priority: String = "NORMAL", + @LLMDescription( + "Speed limit, e.g. '1m' for 1 MB/s, '500k' for 500 KB/s, " + + "or 'unlimited'", + ) + speedLimit: String = "unlimited", + ): String { + val request = DownloadRequest( + url = url, + destination = destination.ifEmpty { null }?.let { Destination(it) }, + connections = connections, + priority = parsePriority(priority), + speedLimit = parseSpeedLimit(speedLimit), + ) + val task = ketch.download(request) + return json.encodeToString(taskToJson(task)) + } + + @Tool + @LLMDescription("Pause a running download. Preserves progress for later resume.") + suspend fun pauseDownload( + @LLMDescription("The unique task ID to pause") + taskId: String, + ): String { + val task = findTask(taskId) ?: return notFound(taskId) + task.pause() + return json.encodeToString(taskToJson(task)) + } + + @Tool + @LLMDescription("Resume a paused or failed download from where it left off.") + suspend fun resumeDownload( + @LLMDescription("The unique task ID to resume") + taskId: String, + ): String { + val task = findTask(taskId) ?: return notFound(taskId) + task.resume() + return json.encodeToString(taskToJson(task)) + } + + @Tool + @LLMDescription("Cancel a download. This is a terminal action and cannot be undone.") + suspend fun cancelDownload( + @LLMDescription("The unique task ID to cancel") + taskId: String, + ): String { + val task = findTask(taskId) ?: return notFound(taskId) + task.cancel() + return json.encodeToString(taskToJson(task)) + } + + @Tool + @LLMDescription( + "Remove a download task from the task list. " + + "Cancels the download if still active.", + ) + suspend fun removeDownload( + @LLMDescription("The unique task ID to remove") + taskId: String, + ): String { + val task = findTask(taskId) ?: return notFound(taskId) + task.remove() + return buildJsonObject { put("removed", taskId) }.toString() + } + + @Tool + @LLMDescription( + "Resolve URL metadata without downloading. Returns file size, " + + "resume support, suggested filename, and source type.", + ) + suspend fun resolveUrl( + @LLMDescription("The URL to resolve") + url: String, + ): String { + val resolved = ketch.resolve(url) + return json.encodeToString( + json.encodeToJsonElement(resolved), + ) + } + + @Tool + @LLMDescription( + "Get server status including version, uptime, configuration, " + + "and system information.", + ) + suspend fun getStatus(): String { + val status = ketch.status() + return json.encodeToString( + json.encodeToJsonElement(status), + ) + } + + @Tool + @LLMDescription("Set the speed limit for a specific download task.") + suspend fun setSpeedLimit( + @LLMDescription("The unique task ID") + taskId: String, + @LLMDescription( + "Speed limit, e.g. '1m' for 1 MB/s, '500k' for 500 KB/s, " + + "or 'unlimited' to remove the limit", + ) + speedLimit: String, + ): String { + val task = findTask(taskId) ?: return notFound(taskId) + task.setSpeedLimit(parseSpeedLimit(speedLimit)) + return json.encodeToString(taskToJson(task)) + } + + @Tool + @LLMDescription("Set the priority of a download task in the queue.") + suspend fun setPriority( + @LLMDescription("The unique task ID") + taskId: String, + @LLMDescription("Priority level: LOW, NORMAL, HIGH, or URGENT") + priority: String, + ): String { + val task = findTask(taskId) ?: return notFound(taskId) + task.setPriority(parsePriority(priority)) + return json.encodeToString(taskToJson(task)) + } + + @Tool + @LLMDescription( + "Update global download configuration such as speed limit " + + "and concurrency settings.", + ) + suspend fun updateConfig( + @LLMDescription( + "Global speed limit, e.g. '10m' for 10 MB/s, " + + "'unlimited' to remove. Empty string to keep current.", + ) + speedLimit: String = "", + @LLMDescription( + "Maximum concurrent downloads. 0 to keep current.", + ) + maxConcurrentDownloads: Int = 0, + @LLMDescription( + "Maximum connections per download. 0 to keep current.", + ) + maxConnectionsPerDownload: Int = 0, + ): String { + val current = ketch.status().config + val updated = current.copy( + speedLimit = if (speedLimit.isEmpty()) { + current.speedLimit + } else { + parseSpeedLimit(speedLimit) + }, + maxConcurrentDownloads = if (maxConcurrentDownloads > 0) { + maxConcurrentDownloads + } else { + current.maxConcurrentDownloads + }, + maxConnectionsPerDownload = if (maxConnectionsPerDownload > 0) { + maxConnectionsPerDownload + } else { + current.maxConnectionsPerDownload + }, + ) + ketch.updateConfig(updated) + return json.encodeToString( + json.encodeToJsonElement(updated), + ) + } + + private fun findTask(taskId: String) = + ketch.tasks.value.find { it.taskId == taskId } + + private fun notFound(taskId: String): String = + buildJsonObject { + put("error", "not_found") + put("message", "Task not found: $taskId") + }.toString() + + private fun taskToJson(task: com.linroid.ketch.api.DownloadTask): JsonObject { + val state = task.state.value + return buildJsonObject { + put("taskId", task.taskId) + put("url", task.request.url) + put("destination", task.request.destination?.value) + put("state", stateName(state)) + put("createdAt", task.createdAt.toString()) + when (state) { + is DownloadState.Downloading -> { + val p = state.progress + put("downloadedBytes", p.downloadedBytes) + put("totalBytes", p.totalBytes) + put("bytesPerSecond", p.bytesPerSecond) + put("percent", p.percent.toDouble()) + } + is DownloadState.Paused -> { + val p = state.progress + put("downloadedBytes", p.downloadedBytes) + put("totalBytes", p.totalBytes) + put("percent", p.percent.toDouble()) + } + is DownloadState.Completed -> { + put("outputPath", state.outputPath) + } + is DownloadState.Failed -> { + put("error", state.error.message ?: "Unknown error") + } + else -> {} + } + if (task.request.priority != DownloadPriority.NORMAL) { + put("priority", task.request.priority.name) + } + if (!task.request.speedLimit.isUnlimited) { + put( + "speedLimit", + "${task.request.speedLimit.bytesPerSecond}", + ) + } + } + } + + private fun stateName(state: DownloadState): String = when (state) { + is DownloadState.Scheduled -> "scheduled" + is DownloadState.Queued -> "queued" + is DownloadState.Downloading -> "downloading" + is DownloadState.Paused -> "paused" + is DownloadState.Completed -> "completed" + is DownloadState.Failed -> "failed" + is DownloadState.Canceled -> "canceled" + } + + private fun parsePriority(value: String): DownloadPriority = + DownloadPriority.entries.find { + it.name.equals(value, ignoreCase = true) + } ?: DownloadPriority.NORMAL + + private fun parseSpeedLimit(value: String): SpeedLimit = + if (value.equals("unlimited", ignoreCase = true) || value.isEmpty()) { + SpeedLimit.Unlimited + } else { + SpeedLimit.parse(value) + ?: throw IllegalArgumentException( + "Invalid speed limit '$value'. " + + "Use e.g. '1m' (1 MB/s), '500k' (500 KB/s), " + + "a raw byte count, or 'unlimited'.", + ) + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index dfec9904..fe9cf5d2 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -45,6 +45,7 @@ include(":library:sqlite") include(":library:ftp") include(":library:torrent") include(":library:server") +include(":library:mcp") // AI modules include(":ai:discover")