diff --git a/docs/track.md b/docs/track.md new file mode 100644 index 000000000..6c65d0335 --- /dev/null +++ b/docs/track.md @@ -0,0 +1,48 @@ +# Tracking events + +`track` lets you record things happening in your app — like failed logins, signups, or password resets. Zen sends these to Aikido so patterns can be detected, like someone failing to log in 50 times in a minute. + +```js +const Zen = require("@aikidosec/firewall"); + +app.post("/login", async (req, res) => { + const user = await authenticate(req.body.username, req.body.password); + + if (!user) { + Zen.track("user.login_failed"); + return res.status(401).json({ error: "Invalid credentials" }); + } + + Zen.setUser({ id: user.id }); + Zen.track("user.login_succeeded"); + res.json({ token: createToken(user) }); +}); +``` + +Zen automatically picks up the IP address, user agent, and current user (if you called [`setUser`](./user.md)) from the request — you don't need to pass those yourself. + +## More examples + +```js +Zen.track("user.signed_up"); +Zen.track("user.password_reset_requested"); +Zen.track("plan.invite_sent"); +Zen.track("payment.failed"); +``` + +## Naming events + +Use lowercase with dots to group related events: + +- `user.login_failed` +- `user.login_succeeded` +- `user.signed_up` +- `user.password_reset_requested` +- `payment.failed` +- `plan.invite_sent` + +## Things to know + +`track` only works inside an HTTP request. If you call it in a background job or a script, nothing gets sent and you'll see a warning in the console. + +If you haven't called `setUser` yet, the event still goes through — it just won't have a user ID attached. diff --git a/end2end/server/app.ts b/end2end/server/app.ts index da888d0a0..332c17638 100644 --- a/end2end/server/app.ts +++ b/end2end/server/app.ts @@ -10,6 +10,8 @@ import { updateConfig } from "./src/handlers/updateConfig.ts"; import { lists } from "./src/handlers/lists.ts"; import { updateIPLists } from "./src/handlers/updateLists.ts"; import { realtimeConfig } from "./src/handlers/realtimeConfig.ts"; +import { stream, disconnectStreams } from "./src/handlers/stream.ts"; +import { deleteApp } from "./src/handlers/deleteApp.ts"; const app = express(); app.set("trust proxy", false); @@ -24,6 +26,8 @@ app.post("/api/runtime/config", checkToken, updateConfig); // Realtime polling endpoint app.get("/config", checkToken, realtimeConfig); +app.get("/api/runtime/stream", checkToken, stream); +app.post("/api/runtime/stream/disconnect", checkToken, disconnectStreams); app.get("/api/runtime/events", checkToken, listEvents); app.post("/api/runtime/events", checkToken, captureEvent); @@ -32,6 +36,7 @@ app.get("/api/runtime/firewall/lists", checkToken, lists); app.post("/api/runtime/firewall/lists", checkToken, updateIPLists); app.post("/api/runtime/apps", createApp); +app.delete("/api/runtime/apps", checkToken, deleteApp); app.listen(port, () => { console.log(`Server is running on port ${port}`); diff --git a/end2end/server/src/handlers/deleteApp.ts b/end2end/server/src/handlers/deleteApp.ts new file mode 100644 index 000000000..5e39657ab --- /dev/null +++ b/end2end/server/src/handlers/deleteApp.ts @@ -0,0 +1,14 @@ +import type { Response } from "express"; +import { removeApp } from "../zen/apps.ts"; +import { closeStreams } from "./stream.ts"; +import type { ZenRequest } from "../types.ts"; + +export function deleteApp(req: ZenRequest, res: Response) { + if (!req.zenApp) { + throw new Error("App is missing"); + } + + removeApp(req.zenApp); + closeStreams(req.zenApp.id); + res.json({ ok: true }); +} diff --git a/end2end/server/src/handlers/stream.ts b/end2end/server/src/handlers/stream.ts new file mode 100644 index 000000000..90a9873f1 --- /dev/null +++ b/end2end/server/src/handlers/stream.ts @@ -0,0 +1,64 @@ +import type { Response } from "express"; +import { getAppConfig, configEvents } from "../zen/config.ts"; +import type { ZenRequest } from "../types.ts"; + +const connections = new Map>(); + +export function stream(req: ZenRequest, res: Response) { + if (!req.zenApp) { + throw new Error("App is missing"); + } + + const app = req.zenApp; + + if (!connections.has(app.id)) { + connections.set(app.id, new Set()); + } + connections.get(app.id)!.add(res); + + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + + function sendConfig() { + const config = getAppConfig(app); + const data = { serviceId: app.id, configUpdatedAt: config.configUpdatedAt }; + res.write(`event: config-updated\ndata: ${JSON.stringify(data)}\n\n`); + } + + sendConfig(); + + const eventName = `config-updated:${app.id}`; + configEvents.on(eventName, sendConfig); + + const ping = setInterval(() => { + res.write(": ping\n\n"); + }, 30_000); + + req.on("close", () => { + connections.get(app.id)?.delete(res); + configEvents.off(eventName, sendConfig); + clearInterval(ping); + }); +} + +export function closeStreams(appId: number) { + const appConnections = connections.get(appId); + if (appConnections) { + for (const conn of appConnections) { + conn.end(); + } + appConnections.clear(); + } +} + +export function disconnectStreams(req: ZenRequest, res: Response) { + if (!req.zenApp) { + throw new Error("App is missing"); + } + + closeStreams(req.zenApp.id); + res.json({ ok: true }); +} diff --git a/end2end/server/src/zen/apps.ts b/end2end/server/src/zen/apps.ts index e11e2cd99..a95680e0f 100644 --- a/end2end/server/src/zen/apps.ts +++ b/end2end/server/src/zen/apps.ts @@ -6,7 +6,7 @@ export type App = { configUpdatedAt: number; }; -const apps: App[] = []; +let apps: App[] = []; let id = 1; export function createApp(): string { @@ -20,6 +20,10 @@ export function createApp(): string { return token; } +export function removeApp(app: App): void { + apps = apps.filter((a) => a.id !== app.id); +} + export function getByToken(token: string): App | undefined { return apps.find((app) => { if (app.token.length !== token.length) { diff --git a/end2end/server/src/zen/config.ts b/end2end/server/src/zen/config.ts index c59198c20..cc81d6af4 100644 --- a/end2end/server/src/zen/config.ts +++ b/end2end/server/src/zen/config.ts @@ -1,5 +1,8 @@ +import { EventEmitter } from "node:events"; import type { App } from "./apps.ts"; +export const configEvents = new EventEmitter(); + type AppConfig = { success: boolean; serviceId: number; @@ -53,6 +56,7 @@ export function updateAppConfig(app: App, newConfig: Partial) { ...newConfig, configUpdatedAt: Date.now(), }; + configEvents.emit(`config-updated:${app.id}`); return true; } diff --git a/end2end/tests-new/heartbeat.test.mjs b/end2end/tests-new/heartbeat.test.mjs index 1da88a6c2..548dee847 100644 --- a/end2end/tests-new/heartbeat.test.mjs +++ b/end2end/tests-new/heartbeat.test.mjs @@ -90,7 +90,7 @@ test("It reports own http requests in heartbeat events", async () => { { hostname: "localhost", port: 5874, - hits: 3, + hits: 4, }, ], agent: { diff --git a/end2end/tests-new/hono-pg-esm-outbound.test.mjs b/end2end/tests-new/hono-pg-esm-outbound.test.mjs index c63ad5556..a8d8fb77a 100644 --- a/end2end/tests-new/hono-pg-esm-outbound.test.mjs +++ b/end2end/tests-new/hono-pg-esm-outbound.test.mjs @@ -132,6 +132,8 @@ test("blockNewOutgoingRequests is true", async () => { domains: [ { hostname: "ssrf-redirects.testssandbox.com", mode: "block" }, { hostname: "aikido.dev", mode: "allow" }, + // Otherwise we cannot communicate with the mock server + { hostname: "localhost", mode: "allow" }, ], }), }); diff --git a/end2end/tests-new/realtime-config-updates.test.mjs b/end2end/tests-new/realtime-config-updates.test.mjs new file mode 100644 index 000000000..8bb8e7b9a --- /dev/null +++ b/end2end/tests-new/realtime-config-updates.test.mjs @@ -0,0 +1,223 @@ +import { spawn } from "child_process"; +import { resolve } from "path"; +import { test } from "node:test"; +import { equal, doesNotMatch, match, fail } from "node:assert"; +import { getRandomPort } from "./utils/get-port.mjs"; +import { timeout } from "./utils/timeout.mjs"; + +const pathToAppDir = resolve( + import.meta.dirname, + "../../sample-apps/hono-pg-ts-esm" +); + +const testServerUrl = "http://localhost:5874"; + +function spawnApp(token, port) { + return spawn( + `node`, + [ + "--require", + "@aikidosec/firewall/instrument", + "--experimental-strip-types", + "./app.ts", + port, + ], + { + cwd: pathToAppDir, + env: { + ...process.env, + AIKIDO_TOKEN: token, + AIKIDO_ENDPOINT: testServerUrl, + AIKIDO_REALTIME_ENDPOINT: testServerUrl, + AIKIDO_DEBUG: "true", + AIKIDO_DEBUG_SSE: "true", + AIKIDO_BLOCK: "true", + }, + } + ); +} + +test("it picks up blocked IP via SSE config update", async () => { + const response = await fetch(`${testServerUrl}/api/runtime/apps`, { + method: "POST", + }); + const body = await response.json(); + const token = body.token; + const port = await getRandomPort(); + + const server = spawnApp(token, port); + + try { + server.on("error", (err) => { + fail(err); + }); + + let stdout = ""; + server.stdout.on("data", (data) => { + stdout += data.toString(); + }); + + let stderr = ""; + server.stderr.on("data", (data) => { + stderr += data.toString(); + }); + + // Wait for the server to start and SSE to connect + await timeout(3000); + + // Verify request from 5.6.7.8 is allowed before blocking + const before = await fetch(`http://127.0.0.1:${port}/`, { + headers: { "x-forwarded-for": "5.6.7.8" }, + signal: AbortSignal.timeout(5000), + }); + equal(before.status, 200); + + // Block IP 5.6.7.8 via the test server API + await fetch(`${testServerUrl}/api/runtime/firewall/lists`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: token, + }, + body: JSON.stringify({ + blockedIPAddresses: ["5.6.7.8"], + }), + }); + + // Wait for SSE config-updated event to propagate + await timeout(2000); + + // Verify request from 5.6.7.8 is now blocked + const after = await fetch(`http://127.0.0.1:${port}/`, { + headers: { "x-forwarded-for": "5.6.7.8" }, + signal: AbortSignal.timeout(5000), + }); + equal(after.status, 403); + } catch (err) { + fail(err); + } finally { + server.kill(); + } +}); + +test("it reconnects SSE after server disconnects", async () => { + const response = await fetch(`${testServerUrl}/api/runtime/apps`, { + method: "POST", + }); + const body = await response.json(); + const token = body.token; + const port = await getRandomPort(); + + const server = spawnApp(token, port); + + try { + server.on("error", (err) => { + fail(err); + }); + + let stdout = ""; + server.stdout.on("data", (data) => { + stdout += data.toString(); + }); + + let stderr = ""; + server.stderr.on("data", (data) => { + stderr += data.toString(); + }); + + // Wait for the server to start and SSE to connect + await timeout(3000); + match(stdout, /SSE connected successfully/); + + // Disconnect SSE from the server side + await fetch(`${testServerUrl}/api/runtime/stream/disconnect`, { + method: "POST", + headers: { Authorization: token }, + }); + + // Wait for reconnect (initial reconnect delay is 5s + jitter up to 7.5s) + await timeout(10000); + match(stdout, /SSE connection closed by server/); + + // Verify SSE reconnected + const connectedCount = + stdout.split("SSE connected successfully").length - 1; + equal(connectedCount >= 2, true); + + // Block IP 9.8.7.6 after reconnect to verify the new connection works + await fetch(`${testServerUrl}/api/runtime/firewall/lists`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: token, + }, + body: JSON.stringify({ + blockedIPAddresses: ["9.8.7.6"], + }), + }); + + // Wait for SSE config-updated event to propagate + await timeout(2000); + + // Verify the blocked IP is picked up via the reconnected SSE + const blocked = await fetch(`http://127.0.0.1:${port}/`, { + headers: { "x-forwarded-for": "9.8.7.6" }, + signal: AbortSignal.timeout(5000), + }); + equal(blocked.status, 403); + } catch (err) { + fail(err); + } finally { + server.kill(); + } +}); + +test("it stops SSE reconnect on 401", async () => { + const response = await fetch(`${testServerUrl}/api/runtime/apps`, { + method: "POST", + }); + const body = await response.json(); + const token = body.token; + const port = await getRandomPort(); + + const server = spawnApp(token, port); + + try { + server.on("error", (err) => { + fail(err); + }); + + let stdout = ""; + server.stdout.on("data", (data) => { + stdout += data.toString(); + }); + + let stderr = ""; + server.stderr.on("data", (data) => { + stderr += data.toString(); + }); + + // Wait for the server to start and SSE to connect + await timeout(3000); + match(stdout, /SSE connected successfully/); + + // Revoke the token and disconnect SSE so it tries to reconnect with 401 + await fetch(`${testServerUrl}/api/runtime/apps`, { + method: "DELETE", + headers: { Authorization: token }, + }); + + // Wait for reconnect attempts (may take multiple due to backoff + jitter) + await timeout(15000); + + match(stdout, /SSE connection rejected with status 401, stopping/); + // Should not schedule a reconnect after 401 + const rejectedIndex = stdout.indexOf("SSE connection rejected"); + const afterRejected = stdout.slice(rejectedIndex); + doesNotMatch(afterRejected, /SSE scheduling reconnect/); + } catch (err) { + fail(err); + } finally { + server.kill(); + } +}); diff --git a/end2end/tests/hono-xml-outbound.test.js b/end2end/tests/hono-xml-outbound.test.js index 400350bc5..66f39e5e3 100644 --- a/end2end/tests/hono-xml-outbound.test.js +++ b/end2end/tests/hono-xml-outbound.test.js @@ -132,6 +132,8 @@ t.test("blockNewOutgoingRequests is true", (t) => { domains: [ { hostname: "ssrf-redirects.testssandbox.com", mode: "block" }, { hostname: "aikido.dev", mode: "allow" }, + // Otherwise we cannot communicate with the mock server + { hostname: "localhost", mode: "allow" }, ], }), }); diff --git a/library/agent/Agent.test.ts b/library/agent/Agent.test.ts index 8eb6354bd..394027dc5 100644 --- a/library/agent/Agent.test.ts +++ b/library/agent/Agent.test.ts @@ -459,6 +459,8 @@ t.test( // After a minute, we'll see that the dashboard didn't receive any stats yet // And then send a heartbeat clock.tick(60 * 1000); + // Extra nextAsync to drain the fetch timeout from probeRealtimeURL + await clock.nextAsync(); await clock.nextAsync(); t.match(api.getEvents(), [ { @@ -526,6 +528,7 @@ t.test( // But the stats is still empty, so we won't send a heartbeat clock.tick(60 * 1000); await clock.nextAsync(); + await clock.nextAsync(); t.match(api.getEvents(), [ { type: "started", @@ -586,6 +589,7 @@ t.test("it sends heartbeat when reached max timings", async () => { // After 30 seconds, the first heartbeat should be sent clock.tick(30 * 1000); await clock.nextAsync(); + await clock.nextAsync(); t.match(api.getEvents(), [ { @@ -734,6 +738,7 @@ t.test("unable to prevent prototype pollution", async () => { clock.tick(1000 * 60 * 30); await clock.nextAsync(); + await clock.nextAsync(); t.same(api.getEvents().length, 2); const [_, heartbeat] = api.getEvents(); diff --git a/library/agent/Agent.ts b/library/agent/Agent.ts index 9dfd4aea1..1f724fe88 100644 --- a/library/agent/Agent.ts +++ b/library/agent/Agent.ts @@ -13,10 +13,11 @@ import type { DetectedAttack, DetectedAttackWave, } from "./api/Event"; +import { sendUserEvent, type UserEvent } from "./api/UserEventsAPI"; import { Token } from "./api/Token"; import { Kind } from "./Attack"; -import { Endpoint } from "./Config"; -import { pollForChanges } from "./realtime/pollForChanges"; +import { type Config, Endpoint } from "./Config"; +import { listenForConfigUpdates } from "./realtime/listenForConfigUpdates"; import { Context } from "./Context"; import { Hostnames } from "./Hostnames"; import { InspectionStatistics } from "./InspectionStatistics"; @@ -37,6 +38,9 @@ import type { FetchListsAPI } from "./api/FetchListsAPI"; import { PendingEvents } from "./PendingEvents"; import type { IdorProtectionConfig } from "./IdorProtectionConfig"; import { warnIfTsxIsUsed } from "../helpers/warnIfTsxIsUsed"; +import { pollForChanges } from "./realtime/pollForChanges"; +import { getRealtimeURL } from "./realtime/getRealtimeURL"; +import { probeRealtimeURL } from "./realtime/probeRealtimeURL"; type WrappedPackage = { version: string; supported: boolean }; @@ -451,17 +455,40 @@ export class Agent { } } - private startPollingForConfigChanges() { + private async startCheckingForConfigUpdates() { + if (!this.token) { + return; + } + + const onConfigUpdate = (config: Config) => { + this.updateServiceConfig({ success: true, ...config }); + this.updateBlockedLists().catch((error) => { + this.logger.log(`Failed to update blocked lists: ${error.message}`); + }); + }; + + const lastUpdatedAt = this.serviceConfig.getLastUpdatedAt(); + + const { pollingURL, realtimeReachable } = await probeRealtimeURL( + this.token, + this.logger + ); + + if (realtimeReachable) { + listenForConfigUpdates({ + token: this.token, + logger: this.logger, + lastUpdatedAt, + onConfigUpdate, + }); + } + pollForChanges({ token: this.token, logger: this.logger, - lastUpdatedAt: this.serviceConfig.getLastUpdatedAt(), - onConfigUpdate: (config) => { - this.updateServiceConfig({ success: true, ...config }); - this.updateBlockedLists().catch((error) => { - this.logger.log(`Failed to update blocked lists: ${error.message}`); - }); - }, + lastUpdatedAt, + realtimeURL: pollingURL, + onConfigUpdate, }); } @@ -557,7 +584,7 @@ export class Agent { this.onStart() .then(() => { this.startHeartbeats(); - this.startPollingForConfigChanges(); + this.startCheckingForConfigUpdates(); }) .catch((err) => { console.error(`Aikido: Failed to start agent: ${err.message}`); @@ -760,6 +787,19 @@ export class Agent { } } + onTrackEvent(event: UserEvent) { + if (!this.token) { + return; + } + + const promise = sendUserEvent(this.token, event).catch(() => { + this.logger.log( + `Can't send tracked event, make sure ${getRealtimeURL().hostname} is in your outbound firewall allowlist` + ); + }); + this.pendingEvents.onAPICall(promise); + } + public async shutdown(timeoutInMS = 1000): Promise { this.logger.log("Shutting down agent..."); await this.flushStats(timeoutInMS); diff --git a/library/agent/api/UserEventsAPI.ts b/library/agent/api/UserEventsAPI.ts new file mode 100644 index 000000000..384d1c7ff --- /dev/null +++ b/library/agent/api/UserEventsAPI.ts @@ -0,0 +1,25 @@ +import { fetch } from "../../helpers/fetch"; +import { getRealtimeURL } from "../realtime/getRealtimeURL"; +import type { Token } from "./Token"; + +export type UserEvent = { + name: string; + userId: string | undefined; + ipAddress: string | undefined; +}; + +export async function sendUserEvent( + token: Token, + event: UserEvent +): Promise { + await fetch({ + url: new URL(`${getRealtimeURL().toString()}api/runtime/events`), + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: token.asString(), + }, + body: JSON.stringify(event), + timeoutInMS: 5000, + }); +} diff --git a/library/agent/context/track.ts b/library/agent/context/track.ts new file mode 100644 index 000000000..862d14a75 --- /dev/null +++ b/library/agent/context/track.ts @@ -0,0 +1,42 @@ +import { getInstance } from "../AgentSingleton"; +import { ContextStorage } from "./ContextStorage"; + +export function track(eventName: string): void { + const agent = getInstance(); + + if (!agent) { + return; + } + + if (typeof eventName !== "string" || eventName.length === 0) { + agent.log(`track(...) expects a non-empty string as event name.`); + return; + } + + const context = ContextStorage.getStore(); + if (!context) { + logWarningTrackCalledWithoutContext(); + return; + } + + agent.onTrackEvent({ + name: eventName, + userId: context.user?.id, + ipAddress: context.remoteAddress, + }); +} + +let loggedWarningTrackCalledWithoutContext = false; + +function logWarningTrackCalledWithoutContext() { + if (loggedWarningTrackCalledWithoutContext) { + return; + } + + // oxlint-disable-next-line no-console + console.warn( + "track(...) was called without a context. The event will not be tracked. Make sure to call track(...) within an HTTP request." + ); + + loggedWarningTrackCalledWithoutContext = true; +} diff --git a/library/agent/realtime/connectToSSE.401.test.ts b/library/agent/realtime/connectToSSE.401.test.ts new file mode 100644 index 000000000..a3a258aae --- /dev/null +++ b/library/agent/realtime/connectToSSE.401.test.ts @@ -0,0 +1,43 @@ +import * as t from "tap"; +import { setTimeout } from "node:timers/promises"; +import { createServer } from "http"; +import { Token } from "../api/Token"; +import { LoggerForTesting } from "../logger/LoggerForTesting"; +import { connectToSSE } from "./connectToSSE"; + +t.test("it stops reconnecting on 401", async (t) => { + let connectionCount = 0; + + const server = createServer((_req, res) => { + connectionCount++; + res.writeHead(401); + res.end(); + }); + + await new Promise((resolve) => server.listen(0, resolve)); + server.unref(); + server.on("connection", (socket) => socket.unref()); + const port = (server.address() as { port: number }).port; + process.env.AIKIDO_REALTIME_ENDPOINT = `http://localhost:${port}/`; + + const logger = new LoggerForTesting(); + + try { + connectToSSE({ + token: new Token("bad-token"), + logger, + onEvent() {}, + }); + + await setTimeout(500); + + t.equal(connectionCount, 1); + t.equal(logger.getMessages().length, 1); + t.match( + logger.getMessages()[0], + /SSE connection rejected with status 401, stopping/ + ); + } finally { + server.close(); + } +}); diff --git a/library/agent/realtime/connectToSSE.500.test.ts b/library/agent/realtime/connectToSSE.500.test.ts new file mode 100644 index 000000000..667002208 --- /dev/null +++ b/library/agent/realtime/connectToSSE.500.test.ts @@ -0,0 +1,46 @@ +import * as t from "tap"; +import { setTimeout } from "node:timers/promises"; +import { createServer } from "http"; +import { Token } from "../api/Token"; +import { LoggerForTesting } from "../logger/LoggerForTesting"; +import { connectToSSE } from "./connectToSSE"; + +t.test("it reconnects on non-200 status", async (t) => { + let connectionCount = 0; + + const server = createServer((_req, res) => { + connectionCount++; + if (connectionCount === 1) { + res.writeHead(500); + res.end(); + return; + } + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + res.write(": ping\n\n"); + }); + + await new Promise((resolve) => server.listen(0, resolve)); + server.unref(); + server.on("connection", (socket) => socket.unref()); + const port = (server.address() as { port: number }).port; + process.env.AIKIDO_REALTIME_ENDPOINT = `http://localhost:${port}/`; + + try { + connectToSSE({ + token: new Token("test-token"), + logger: new LoggerForTesting(), + onEvent() {}, + }); + + // Wait for reconnect after 500 (initial delay 5s + up to 2.5s jitter) + await setTimeout(8000); + + t.equal(connectionCount, 2); + } finally { + server.close(); + } +}); diff --git a/library/agent/realtime/connectToSSE.connrefused.test.ts b/library/agent/realtime/connectToSSE.connrefused.test.ts new file mode 100644 index 000000000..6c2ef323b --- /dev/null +++ b/library/agent/realtime/connectToSSE.connrefused.test.ts @@ -0,0 +1,29 @@ +import * as t from "tap"; +import { setTimeout } from "node:timers/promises"; +import { createServer } from "http"; +import { Token } from "../api/Token"; +import { LoggerForTesting } from "../logger/LoggerForTesting"; +import { connectToSSE } from "./connectToSSE"; + +t.test("it handles connection refused", async (t) => { + // Bind and immediately close to get a port that's definitely not in use + const server = createServer(); + await new Promise((resolve) => server.listen(0, resolve)); + const port = (server.address() as { port: number }).port; + await new Promise((resolve) => server.close(() => resolve())); + + process.env.AIKIDO_REALTIME_ENDPOINT = `http://localhost:${port}/`; + process.env.AIKIDO_DEBUG_SSE = "true"; + + const logger = new LoggerForTesting(); + + connectToSSE({ + token: new Token("test-token"), + logger, + onEvent() {}, + }); + + await setTimeout(500); + + t.ok(logger.getMessages().some((m) => m.includes("SSE connection error:"))); +}); diff --git a/library/agent/realtime/connectToSSE.reconnect.test.ts b/library/agent/realtime/connectToSSE.reconnect.test.ts new file mode 100644 index 000000000..a5fc4f669 --- /dev/null +++ b/library/agent/realtime/connectToSSE.reconnect.test.ts @@ -0,0 +1,48 @@ +import * as t from "tap"; +import { setTimeout } from "node:timers/promises"; +import { createServer, type ServerResponse } from "http"; +import { Token } from "../api/Token"; +import { LoggerForTesting } from "../logger/LoggerForTesting"; +import { connectToSSE } from "./connectToSSE"; + +t.test("it reconnects when server closes connection", async (t) => { + let connectionCount = 0; + let sseRes: ServerResponse | null = null; + + const server = createServer((_req, res) => { + connectionCount++; + sseRes = res; + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + res.write(": ping\n\n"); + }); + + await new Promise((resolve) => server.listen(0, resolve)); + server.unref(); + server.on("connection", (socket) => socket.unref()); + const port = (server.address() as { port: number }).port; + process.env.AIKIDO_REALTIME_ENDPOINT = `http://localhost:${port}/`; + + try { + connectToSSE({ + token: new Token("test-token"), + logger: new LoggerForTesting(), + onEvent() {}, + }); + + await setTimeout(200); + t.equal(connectionCount, 1); + + sseRes!.end(); + + // Wait for reconnect (initial delay is 5s + up to 2.5s jitter) + await setTimeout(8000); + + t.equal(connectionCount, 2); + } finally { + server.close(); + } +}); diff --git a/library/agent/realtime/connectToSSE.test.ts b/library/agent/realtime/connectToSSE.test.ts new file mode 100644 index 000000000..699d948ba --- /dev/null +++ b/library/agent/realtime/connectToSSE.test.ts @@ -0,0 +1,71 @@ +import * as t from "tap"; +import { setTimeout } from "node:timers/promises"; +import { createServer, type ServerResponse } from "http"; +import { Token } from "../api/Token"; +import { LoggerForTesting } from "../logger/LoggerForTesting"; +import { connectToSSE } from "./connectToSSE"; +import type { EventSourceMessage } from "../../helpers/eventsource-parser/types"; + +t.test( + "it connects with auth header and receives events, ignoring pings", + async (t) => { + let receivedAuth: string | undefined; + let sseRes: ServerResponse | null = null; + + const server = createServer((req, res) => { + receivedAuth = req.headers.authorization; + sseRes = res; + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + + res.write(": ping\n\n"); + + const data = JSON.stringify({ serviceId: 1, configUpdatedAt: 100 }); + res.write(`event: config-updated\ndata: ${data}\n\n`); + }); + + await new Promise((resolve) => server.listen(0, resolve)); + server.unref(); + server.on("connection", (socket) => socket.unref()); + const port = (server.address() as { port: number }).port; + process.env.AIKIDO_REALTIME_ENDPOINT = `http://localhost:${port}/`; + + const events: EventSourceMessage[] = []; + + try { + connectToSSE({ + token: new Token("my-secret-token"), + logger: new LoggerForTesting(), + onEvent(event) { + events.push(event); + }, + }); + + await setTimeout(200); + + t.equal(receivedAuth, "my-secret-token"); + t.equal(events.length, 1); + t.equal(events[0].event, "config-updated"); + t.same(JSON.parse(events[0].data), { + serviceId: 1, + configUpdatedAt: 100, + }); + + const data2 = JSON.stringify({ serviceId: 1, configUpdatedAt: 200 }); + sseRes!.write(`event: config-updated\ndata: ${data2}\n\n`); + + await setTimeout(100); + + t.equal(events.length, 2); + t.same(JSON.parse(events[1].data), { + serviceId: 1, + configUpdatedAt: 200, + }); + } finally { + server.close(); + } + } +); diff --git a/library/agent/realtime/connectToSSE.timeout.test.ts b/library/agent/realtime/connectToSSE.timeout.test.ts new file mode 100644 index 000000000..b378f88b6 --- /dev/null +++ b/library/agent/realtime/connectToSSE.timeout.test.ts @@ -0,0 +1,43 @@ +import * as t from "tap"; +import { setTimeout } from "node:timers/promises"; +import { createServer } from "http"; +import { Token } from "../api/Token"; +import { LoggerForTesting } from "../logger/LoggerForTesting"; +import { connectToSSE } from "./connectToSSE"; + +t.test("it reconnects on read timeout", async (t) => { + let connectionCount = 0; + + const server = createServer((_req, res) => { + connectionCount++; + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + }); + + await new Promise((resolve) => server.listen(0, resolve)); + server.unref(); + server.on("connection", (socket) => socket.unref()); + const port = (server.address() as { port: number }).port; + process.env.AIKIDO_REALTIME_ENDPOINT = `http://localhost:${port}/`; + + try { + connectToSSE({ + token: new Token("test-token"), + logger: new LoggerForTesting(), + onEvent() {}, + readTimeoutMs: 200, + initialReconnectMs: 100, + }); + + await setTimeout(200); + t.equal(connectionCount, 1); + + await setTimeout(500); + t.equal(connectionCount, 2); + } finally { + server.close(); + } +}); diff --git a/library/agent/realtime/connectToSSE.ts b/library/agent/realtime/connectToSSE.ts new file mode 100644 index 000000000..ff6bbed4f --- /dev/null +++ b/library/agent/realtime/connectToSSE.ts @@ -0,0 +1,174 @@ +import { request as requestHttp } from "http"; +import { request as requestHttps } from "https"; +import { setTimeout } from "node:timers/promises"; +import { createParser } from "../../helpers/eventsource-parser/parse"; +import type { EventSourceMessage } from "../../helpers/eventsource-parser/types"; +import { isDebuggingSSE } from "../../helpers/isDebuggingSSE"; +import { Token } from "../api/Token"; +import { Logger } from "../logger/Logger"; +import { getRealtimeURL } from "./getRealtimeURL"; + +const INITIAL_RECONNECT_MS = 5000; +const MAX_RECONNECT_MS = 60 * 1000; +const STABLE_CONNECTION_MS = 30 * 1000; +const READ_TIMEOUT_MS = 70 * 1000; + +type ConnectResult = + | { outcome: "error" } + | { outcome: "disconnected"; statusCode: number }; + +function connect({ + token, + onEvent, + readTimeoutMs, + logDebug, +}: { + token: Token; + onEvent: (event: EventSourceMessage) => void; + readTimeoutMs: number; + logDebug: (msg: string) => void; +}): Promise { + return new Promise((resolve) => { + let resolved = false; + + function resolveOnce(result: ConnectResult) { + if (resolved) { + return; + } + resolved = true; + resolve(result); + } + + const url = new URL(`${getRealtimeURL().toString()}api/runtime/stream`); + + logDebug(`SSE connecting to ${url.toString()}`); + + const requestFn = url.protocol === "https:" ? requestHttps : requestHttp; + + const req = requestFn( + url.toString(), + { + method: "GET", + headers: { + Authorization: token.asString(), + Accept: "text/event-stream", + "Cache-Control": "no-cache", + }, + }, + (response) => { + const statusCode = response.statusCode!; + + if (statusCode !== 200) { + response.destroy(); + resolveOnce({ outcome: "disconnected", statusCode }); + return; + } + + logDebug("SSE connected successfully"); + + const parser = createParser({ + onEvent(event) { + onEvent(event); + }, + }); + + response.setEncoding("utf-8"); + + response.on("data", (chunk: string) => { + logDebug(`SSE received chunk: ${chunk.trimEnd()}`); + parser.feed(chunk); + }); + + response.on("end", () => { + logDebug("SSE connection closed by server"); + parser.reset(); + resolveOnce({ outcome: "disconnected", statusCode }); + }); + + response.on("error", (error) => { + logDebug(`SSE stream error: ${error.message}`); + parser.reset(); + resolveOnce({ outcome: "disconnected", statusCode }); + }); + } + ); + + req.on("socket", (socket) => { + socket.setTimeout(readTimeoutMs, () => { + if (socket.destroyed) { + return; + } + logDebug("SSE read timeout"); + resolveOnce({ outcome: "error" }); + req.destroy(); + }); + socket.unref(); + }); + + req.on("error", (error) => { + logDebug(`SSE connection error: ${error.message}`); + resolveOnce({ outcome: "error" }); + }); + + req.end(); + }); +} + +export function connectToSSE({ + token, + logger, + onEvent, + initialReconnectMs = INITIAL_RECONNECT_MS, + readTimeoutMs = READ_TIMEOUT_MS, +}: { + token: Token; + logger: Logger; + onEvent: (event: EventSourceMessage) => void; + initialReconnectMs?: number; + readTimeoutMs?: number; +}) { + let reconnectMs = initialReconnectMs; + + const debugSSE = isDebuggingSSE(); + + function logDebug(msg: string) { + if (debugSSE) { + logger.log(msg); + } + } + + async function loop() { + while (true) { + const start = Date.now(); + const result = await connect({ token, onEvent, readTimeoutMs, logDebug }); + + if ( + result.outcome === "disconnected" && + (result.statusCode === 401 || result.statusCode === 403) + ) { + logger.log( + `SSE connection rejected with status ${result.statusCode}, stopping` + ); + return; + } + + if (Date.now() - start >= STABLE_CONNECTION_MS) { + reconnectMs = initialReconnectMs; + } + + const jitter = Math.random() * (reconnectMs / 2); + const delayMs = reconnectMs + jitter; + + logDebug(`SSE scheduling reconnect in ${Math.round(delayMs)}ms`); + + reconnectMs = Math.min(reconnectMs * 2, MAX_RECONNECT_MS); + + // ref: false so the timer doesn't keep the process alive + await setTimeout(delayMs, undefined, { ref: false }); + } + } + + loop().catch((error) => { + logger.log(`SSE loop error: ${error.message}`); + }); +} diff --git a/library/agent/realtime/getConfigLastUpdatedAt.ts b/library/agent/realtime/getConfigLastUpdatedAt.ts index 28f6512fa..741acc434 100644 --- a/library/agent/realtime/getConfigLastUpdatedAt.ts +++ b/library/agent/realtime/getConfigLastUpdatedAt.ts @@ -1,12 +1,14 @@ import { fetch } from "../../helpers/fetch"; import { Token } from "../api/Token"; -import { getRealtimeURL } from "./getRealtimeURL"; type RealtimeResponse = { configUpdatedAt: number }; -export async function getConfigLastUpdatedAt(token: Token): Promise { +export async function getConfigLastUpdatedAt( + token: Token, + realtimeURL: URL +): Promise { const { body, statusCode } = await fetch({ - url: new URL(`${getRealtimeURL().toString()}config`), + url: new URL(`${realtimeURL.toString()}config`), method: "GET", headers: { Authorization: token.asString(), diff --git a/library/agent/realtime/getRealtimeURL.ts b/library/agent/realtime/getRealtimeURL.ts index 6f2cc465d..a177ddf34 100644 --- a/library/agent/realtime/getRealtimeURL.ts +++ b/library/agent/realtime/getRealtimeURL.ts @@ -3,5 +3,5 @@ export function getRealtimeURL() { return new URL(process.env.AIKIDO_REALTIME_ENDPOINT); } - return new URL("https://runtime.aikido.dev"); + return new URL("https://zen.aikido.dev"); } diff --git a/library/agent/realtime/listenForConfigUpdates.ts b/library/agent/realtime/listenForConfigUpdates.ts new file mode 100644 index 000000000..807479ac0 --- /dev/null +++ b/library/agent/realtime/listenForConfigUpdates.ts @@ -0,0 +1,71 @@ +import type { Config } from "../Config"; +import type { Token } from "../api/Token"; +import type { Logger } from "../logger/Logger"; +import { isDebuggingSSE } from "../../helpers/isDebuggingSSE"; +import { connectToSSE } from "./connectToSSE"; +import { getConfig } from "./getConfig"; + +type OnConfigUpdate = (config: Config) => void; + +export function listenForConfigUpdates({ + onConfigUpdate, + token, + logger, + lastUpdatedAt, +}: { + onConfigUpdate: OnConfigUpdate; + token: Token | undefined; + logger: Logger; + lastUpdatedAt: number; +}) { + if (!token) { + logger.log("No token provided, not listening for config updates"); + return; + } + + const validToken = token; + const debugSSE = isDebuggingSSE(); + + function logDebug(msg: string) { + if (debugSSE) { + logger.log(msg); + } + } + + let currentLastUpdatedAt = lastUpdatedAt; + + connectToSSE({ + token, + logger, + onEvent(event) { + logDebug(`SSE event received: ${event.event}`); + if (event.event !== "config-updated") { + return; + } + + try { + const payload: { configUpdatedAt: number } = JSON.parse(event.data); + if (payload.configUpdatedAt <= currentLastUpdatedAt) { + return; + } + } catch { + logDebug(`SSE config-updated event has invalid payload: ${event.data}`); + return; + } + + logDebug("SSE config-updated event, fetching new config"); + + getConfig(validToken) + .then((config) => { + logDebug( + `SSE config fetched, configUpdatedAt: ${config.configUpdatedAt}` + ); + currentLastUpdatedAt = config.configUpdatedAt; + onConfigUpdate(config); + }) + .catch((error) => { + logDebug(`Failed to fetch config after SSE event: ${error.message}`); + }); + }, + }); +} diff --git a/library/agent/realtime/pollForChanges.test.ts b/library/agent/realtime/pollForChanges.test.ts index abfe9be68..a77cb81f2 100644 --- a/library/agent/realtime/pollForChanges.test.ts +++ b/library/agent/realtime/pollForChanges.test.ts @@ -15,6 +15,7 @@ t.test("it does not start interval if no token", async (t) => { logger: logger, token: undefined, lastUpdatedAt: 0, + realtimeURL: new URL("https://zen.aikido.dev"), }); t.same(logger.getMessages(), [ @@ -35,7 +36,7 @@ t.test("it checks for config updates", async () => { method: params.method, }); - if (params.url.hostname.startsWith("runtime")) { + if (params.url.hostname.startsWith("zen")) { return { body: JSON.stringify({ configUpdatedAt: configUpdatedAt, @@ -68,6 +69,7 @@ t.test("it checks for config updates", async () => { logger: new LoggerNoop(), token: new Token("123"), lastUpdatedAt: 0, + realtimeURL: new URL("https://zen.aikido.dev"), }); t.same(configUpdates, []); @@ -78,7 +80,7 @@ t.test("it checks for config updates", async () => { t.same(configUpdates, []); t.same(calls, [ { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, ]); @@ -95,11 +97,11 @@ t.test("it checks for config updates", async () => { ]); t.same(calls, [ { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, { @@ -119,11 +121,11 @@ t.test("it checks for config updates", async () => { ]); t.same(calls, [ { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, { @@ -131,7 +133,7 @@ t.test("it checks for config updates", async () => { method: "GET", }, { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, ]); @@ -153,11 +155,11 @@ t.test("it checks for config updates", async () => { ]); t.same(calls, [ { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, { @@ -165,11 +167,11 @@ t.test("it checks for config updates", async () => { method: "GET", }, { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, { - url: "https://runtime.aikido.dev/config", + url: "https://zen.aikido.dev/config", method: "GET", }, { @@ -200,6 +202,7 @@ t.test("it deals with API throwing errors", async () => { logger: logger, token: new Token("123"), lastUpdatedAt: 0, + realtimeURL: new URL("https://zen.aikido.dev"), }); t.same(configUpdates, []); diff --git a/library/agent/realtime/pollForChanges.ts b/library/agent/realtime/pollForChanges.ts index 25825d410..5630e43bd 100644 --- a/library/agent/realtime/pollForChanges.ts +++ b/library/agent/realtime/pollForChanges.ts @@ -14,11 +14,13 @@ export function pollForChanges({ token, logger, lastUpdatedAt, + realtimeURL, }: { onConfigUpdate: OnConfigUpdate; token: Token | undefined; logger: Logger; lastUpdatedAt: number; + realtimeURL: URL; }) { if (!token) { logger.log("No token provided, not polling for config updates"); @@ -32,7 +34,7 @@ export function pollForChanges({ } interval = setInterval(() => { - check(token, onConfigUpdate).catch((error) => { + check(token, realtimeURL, onConfigUpdate).catch((error) => { logger.log(`Failed to check for config updates: ${error.message}`); }); }, 60 * 1000); @@ -40,8 +42,12 @@ export function pollForChanges({ interval.unref(); } -async function check(token: Token, onConfigUpdate: OnConfigUpdate) { - const configLastUpdatedAt = await getConfigLastUpdatedAt(token); +async function check( + token: Token, + realtimeURL: URL, + onConfigUpdate: OnConfigUpdate +) { + const configLastUpdatedAt = await getConfigLastUpdatedAt(token, realtimeURL); if ( typeof currentLastUpdatedAt === "number" && diff --git a/library/agent/realtime/probeRealtimeURL.ts b/library/agent/realtime/probeRealtimeURL.ts new file mode 100644 index 000000000..b059a5d18 --- /dev/null +++ b/library/agent/realtime/probeRealtimeURL.ts @@ -0,0 +1,44 @@ +import { fetch } from "../../helpers/fetch"; +import type { Token } from "../api/Token"; +import type { Logger } from "../logger/Logger"; +import { getRealtimeURL } from "./getRealtimeURL"; + +type RealtimeProbeResult = { + pollingURL: URL; + realtimeReachable: boolean; +}; + +export async function probeRealtimeURL( + token: Token, + logger: Logger +): Promise { + const realtimeURL = getRealtimeURL(); + + if (process.env.AIKIDO_REALTIME_ENDPOINT) { + return { pollingURL: realtimeURL, realtimeReachable: true }; + } + + const configURL = new URL(`${realtimeURL.toString()}config`); + + try { + await fetch({ + url: configURL, + method: "GET", + headers: { + Authorization: token.asString(), + }, + timeoutInMS: 5000, + }); + + return { pollingURL: realtimeURL, realtimeReachable: true }; + } catch { + logger.log( + `Can't reach ${realtimeURL.hostname}, make sure it's in your outbound firewall allowlist. Realtime config updates won't be available, switched to polling.` + ); + + return { + pollingURL: new URL("https://runtime.aikido.dev"), + realtimeReachable: false, + }; + } +} diff --git a/library/helpers/eventsource-parser/errors.ts b/library/helpers/eventsource-parser/errors.ts new file mode 100644 index 000000000..d34b2f29e --- /dev/null +++ b/library/helpers/eventsource-parser/errors.ts @@ -0,0 +1,23 @@ +// Based on https://github.com/rexxars/eventsource-parser +// MIT License - Copyright (c) 2025 Espen Hovlandsdal + +export type ErrorType = "invalid-retry" | "unknown-field"; + +export class ParseError extends Error { + type: ErrorType; + field?: string | undefined; + value?: string | undefined; + line?: string | undefined; + + constructor( + message: string, + options: { type: ErrorType; field?: string; value?: string; line?: string } + ) { + super(message); + this.name = "ParseError"; + this.type = options.type; + this.field = options.field; + this.value = options.value; + this.line = options.line; + } +} diff --git a/library/helpers/eventsource-parser/parse.ts b/library/helpers/eventsource-parser/parse.ts new file mode 100644 index 000000000..780320021 --- /dev/null +++ b/library/helpers/eventsource-parser/parse.ts @@ -0,0 +1,170 @@ +// Based on https://github.com/rexxars/eventsource-parser +// MIT License - Copyright (c) 2025 Espen Hovlandsdal + +import { ParseError } from "./errors"; +import type { EventSourceParser, ParserCallbacks } from "./types"; + +function noop(_arg: unknown) { + // intentional noop +} + +export function createParser(callbacks: ParserCallbacks): EventSourceParser { + const { + onEvent = noop, + onError = noop, + onRetry = noop, + onComment, + } = callbacks; + + let incompleteLine = ""; + let isFirstChunk = true; + let id: string | undefined; + let data = ""; + let eventType = ""; + + function feed(newChunk: string) { + // Strip any UTF8 byte order mark (BOM) at the start of the stream + const chunk = isFirstChunk + ? newChunk.replace(/^\xEF\xBB\xBF/, "") + : newChunk; + + const [complete, incomplete] = splitLines(`${incompleteLine}${chunk}`); + + for (const line of complete) { + parseLine(line); + } + + incompleteLine = incomplete; + isFirstChunk = false; + } + + function parseLine(line: string) { + if (line === "") { + dispatchEvent(); + return; + } + + if (line.startsWith(":")) { + if (onComment) { + onComment(line.slice(line.startsWith(": ") ? 2 : 1)); + } + return; + } + + const fieldSeparatorIndex = line.indexOf(":"); + if (fieldSeparatorIndex !== -1) { + const field = line.slice(0, fieldSeparatorIndex); + const offset = line[fieldSeparatorIndex + 1] === " " ? 2 : 1; + const value = line.slice(fieldSeparatorIndex + offset); + processField(field, value, line); + return; + } + + processField(line, "", line); + } + + function processField(field: string, value: string, line: string) { + switch (field) { + case "event": + eventType = value; + break; + case "data": + data = `${data}${value}\n`; + break; + case "id": + id = value.includes("\0") ? undefined : value; + break; + case "retry": + if (/^\d+$/.test(value)) { + onRetry(parseInt(value, 10)); + } else { + onError( + new ParseError(`Invalid \`retry\` value: "${value}"`, { + type: "invalid-retry", + value, + line, + }) + ); + } + break; + default: + onError( + new ParseError( + `Unknown field "${field.length > 20 ? `${field.slice(0, 20)}…` : field}"`, + { type: "unknown-field", field, value, line } + ) + ); + break; + } + } + + function dispatchEvent() { + const shouldDispatch = data.length > 0; + if (shouldDispatch) { + onEvent({ + id, + event: eventType || undefined, + data: data.endsWith("\n") ? data.slice(0, -1) : data, + }); + } + + id = undefined; + data = ""; + eventType = ""; + } + + function reset(options: { consume?: boolean } = {}) { + if (incompleteLine && options.consume) { + parseLine(incompleteLine); + } + + isFirstChunk = true; + id = undefined; + data = ""; + eventType = ""; + incompleteLine = ""; + } + + return { feed, reset }; +} + +function splitLines( + chunk: string +): [complete: Array, incomplete: string] { + const lines: Array = []; + let incompleteLine = ""; + let searchIndex = 0; + + while (searchIndex < chunk.length) { + const crIndex = chunk.indexOf("\r", searchIndex); + const lfIndex = chunk.indexOf("\n", searchIndex); + + let lineEnd = -1; + if (crIndex !== -1 && lfIndex !== -1) { + lineEnd = Math.min(crIndex, lfIndex); + } else if (crIndex !== -1) { + if (crIndex === chunk.length - 1) { + lineEnd = -1; + } else { + lineEnd = crIndex; + } + } else if (lfIndex !== -1) { + lineEnd = lfIndex; + } + + if (lineEnd === -1) { + incompleteLine = chunk.slice(searchIndex); + break; + } else { + const line = chunk.slice(searchIndex, lineEnd); + lines.push(line); + + searchIndex = lineEnd + 1; + if (chunk[searchIndex - 1] === "\r" && chunk[searchIndex] === "\n") { + searchIndex++; + } + } + } + + return [lines, incompleteLine]; +} diff --git a/library/helpers/eventsource-parser/types.ts b/library/helpers/eventsource-parser/types.ts new file mode 100644 index 000000000..cf8130d56 --- /dev/null +++ b/library/helpers/eventsource-parser/types.ts @@ -0,0 +1,22 @@ +// Based on https://github.com/rexxars/eventsource-parser +// MIT License - Copyright (c) 2025 Espen Hovlandsdal + +import type { ParseError } from "./errors"; + +export interface EventSourceParser { + feed(chunk: string): void; + reset(options?: { consume?: boolean }): void; +} + +export interface EventSourceMessage { + event?: string | undefined; + id?: string | undefined; + data: string; +} + +export interface ParserCallbacks { + onEvent?: ((event: EventSourceMessage) => void) | undefined; + onRetry?: ((retry: number) => void) | undefined; + onComment?: ((comment: string) => void) | undefined; + onError?: ((error: ParseError) => void) | undefined; +} diff --git a/library/helpers/isDebuggingSSE.ts b/library/helpers/isDebuggingSSE.ts new file mode 100644 index 000000000..9282e532b --- /dev/null +++ b/library/helpers/isDebuggingSSE.ts @@ -0,0 +1,5 @@ +import { envToBool } from "./envToBool"; + +export function isDebuggingSSE() { + return envToBool(process.env.AIKIDO_DEBUG_SSE); +} diff --git a/library/index.ts b/library/index.ts index 8c476c354..f665daade 100644 --- a/library/index.ts +++ b/library/index.ts @@ -1,6 +1,7 @@ import isFirewallSupported from "./helpers/isFirewallSupported"; import shouldEnableFirewall from "./helpers/shouldEnableFirewall"; import { setUser } from "./agent/context/user"; +import { track } from "./agent/context/track"; import { markUnsafe } from "./agent/context/markUnsafe"; import { shouldBlockRequest } from "./middleware/shouldBlockRequest"; import { addExpressMiddleware } from "./middleware/express"; @@ -64,6 +65,7 @@ if (!isNewHookSystemUsed()) { export { setUser, + track, markUnsafe, shouldBlockRequest, addExpressMiddleware, @@ -84,6 +86,7 @@ export { // e.g. import Zen from '@aikidosec/firewall'; would not work without this, as Zen.setUser would be undefined export default { setUser, + track, markUnsafe, shouldBlockRequest, addExpressMiddleware, diff --git a/library/sinks/Fetch.test.ts b/library/sinks/Fetch.test.ts index 3b312660d..17a8b77d9 100644 --- a/library/sinks/Fetch.test.ts +++ b/library/sinks/Fetch.test.ts @@ -1,4 +1,5 @@ import * as t from "tap"; +import { setTimeout } from "timers/promises"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; @@ -89,6 +90,9 @@ t.test( agent.start([new Fetch()]); + // Let the realtime probe resolve before we start asserting + await setTimeout(500); + agent.getHostnames().clear(); t.same(agent.getHostnames().asArray(), []); await fetch("http://app.aikido.dev"); diff --git a/library/sinks/HTTPRequest.got.test.ts b/library/sinks/HTTPRequest.got.test.ts index 818103916..0435c219c 100644 --- a/library/sinks/HTTPRequest.got.test.ts +++ b/library/sinks/HTTPRequest.got.test.ts @@ -1,5 +1,4 @@ import * as t from "tap"; -import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; import { HTTPRequest } from "./HTTPRequest"; import { createTestAgent } from "../helpers/createTestAgent"; @@ -51,9 +50,7 @@ t.before(async () => { }); t.test("it works", opts, async (t) => { - const agent = createTestAgent({ - token: new Token("123"), - }); + const agent = createTestAgent(); agent.start([new HTTPRequest()]); t.same(agent.getHostnames().asArray(), []); diff --git a/library/sinks/HTTPRequest.test.ts b/library/sinks/HTTPRequest.test.ts index b7869304f..e3d6a95d0 100644 --- a/library/sinks/HTTPRequest.test.ts +++ b/library/sinks/HTTPRequest.test.ts @@ -71,6 +71,7 @@ const https = require("https") as typeof import("https"); const oldUrl = require("url"); t.test("it works", (t) => { + agent.getHostnames().clear(); t.same(agent.getHostnames().asArray(), []); runWithContext(createContext(), () => { diff --git a/library/sinks/Undici.tests.ts b/library/sinks/Undici.tests.ts index 671b2e7f3..339ae16b0 100644 --- a/library/sinks/Undici.tests.ts +++ b/library/sinks/Undici.tests.ts @@ -1,4 +1,5 @@ import * as t from "tap"; +import { setTimeout } from "timers/promises"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; @@ -71,6 +72,11 @@ export async function createUndiciTests(undiciPkgName: string, port: number) { undiciPkgName ) as typeof import("undici-v6"); + // Let the realtime probe resolve before we start asserting + await setTimeout(500); + agent.getHostnames().clear(); + t.same(agent.getHostnames().asArray(), []); + await request("https://ssrf-redirects.testssandbox.com"); t.same(agent.getHostnames().asArray(), [ { diff --git a/sample-apps/express-mysql/app.js b/sample-apps/express-mysql/app.js index 7f549eec5..d7306affe 100644 --- a/sample-apps/express-mysql/app.js +++ b/sample-apps/express-mysql/app.js @@ -1,5 +1,5 @@ require("dotenv").config(); -require("@aikidosec/firewall"); +const { track, setUser } = require("@aikidosec/firewall"); const Sentry = require("@sentry/node"); Sentry.init({ @@ -108,6 +108,30 @@ async function main(port) { res.status(200).send("Done"); }); + app.post( + "/login", + express.json(), + asyncHandler(async (req, res) => { + const { username, password } = req.body; + + if (!username || !password) { + return res + .status(400) + .json({ error: "username and password required" }); + } + + // Hardcoded credentials for demo purposes + if (username === "admin" && password === "admin") { + setUser({ id: "admin", name: "Admin" }); + track("login_success"); + return res.json({ message: "Login successful" }); + } + + track("login_failure"); + res.status(401).json({ error: "Invalid credentials" }); + }) + ); + // This route is for testing purposes only and uses internal APIs // Normal users should NOT rely on these internals as they may change without notice app.get("/pending-events", (req, res) => { diff --git a/scripts/helpers/test-helpers.mjs b/scripts/helpers/test-helpers.mjs index e74df83ef..5e881b3f8 100644 --- a/scripts/helpers/test-helpers.mjs +++ b/scripts/helpers/test-helpers.mjs @@ -31,7 +31,7 @@ export function throws(...args) { assert.fail("Missing expected exception"); } -export function match(actual, expected, message) { +export function match(actual, expected, ...rest) { if (typeof expected === "string") { expected = new RegExp(RegExp.escape(expected)); } @@ -41,11 +41,11 @@ export function match(actual, expected, message) { actual = String(actual); } - assert.match(actual, expected, message); + assert.match(actual, expected, ...rest); return; } - assert.partialDeepStrictEqual(actual, expected, message); + assert.partialDeepStrictEqual(actual, expected, ...rest); } function toPlainObject(value) {