Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions src/core/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ export class SessionManager {
private tokenExpiresIn: number = 0;
private refreshTimer: ReturnType<typeof setTimeout> | undefined;
private gateway: string = '';
private apiBaseOverride: string = '';
private clientId: string = '';
private accessKey: string = '';

Expand All @@ -21,7 +20,7 @@ export class SessionManager {
* API calls are proxied through this gateway.
*/
public get baseUrl(): string {
return this.apiBaseOverride || this.gateway;
return this.gateway;
}

/** Current JWT token. */
Expand All @@ -33,7 +32,7 @@ export class SessionManager {
}

/** Connect and obtain initial JWT token. */
async connect(gateway: string, clientId: string, accessKey: string, apiBaseOverride?: string): Promise<void> {
async connect(gateway: string, clientId: string, accessKey: string): Promise<void> {
// Normalize gateway to protocol + host only
try {
const url = new URL(gateway);
Expand All @@ -44,15 +43,6 @@ export class SessionManager {
this.clientId = clientId;
this.accessKey = accessKey;

if (apiBaseOverride) {
try {
this.apiBaseOverride = new URL(apiBaseOverride).origin;
} catch {
this.apiBaseOverride = apiBaseOverride;
}
console.error(`[auth] API base overridden to ${this.apiBaseOverride}`);
}

await this.performLogin();
this.scheduleRefresh();
console.error(`[auth] Connected to ${this.gateway}, token expires in ${this.tokenExpiresIn}s`);
Expand Down Expand Up @@ -97,6 +87,7 @@ export class SessionManager {
console.error(`[auth] Token refreshed, expires in ${this.tokenExpiresIn}s`);
} catch (err) {
console.error(`[auth] Token refresh failed: ${err}`);
this.scheduleRefresh();
}
}, delaySeconds * 1000);
// Don't keep the process alive just for token refresh
Expand Down
27 changes: 2 additions & 25 deletions src/core/streamable-http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ const JSON_RPC = '2.0';
class MCPStreamableHttpServer {
createServer: () => Server;
transports: { [sessionId: string]: StreamableHTTPServerTransport } = {};
lastSessionId: string | undefined;

constructor(createServer: () => Server) {
this.createServer = createServer;
}
Expand Down Expand Up @@ -76,14 +74,10 @@ class MCPStreamableHttpServer {
if (newSessionId) {
console.error(`New session established: ${newSessionId}`);
this.transports[newSessionId] = transport;
this.lastSessionId = newSessionId;

transport.onclose = () => {
console.error(`Session closed: ${newSessionId}`);
delete this.transports[newSessionId];
if (this.lastSessionId === newSessionId) {
this.lastSessionId = undefined;
}
};
}

Expand All @@ -94,24 +88,7 @@ class MCPStreamableHttpServer {
return toFetchResponse(res);
}

// Fallback: route session-less non-initialize requests to the last active session
if (!sessionId && this.lastSessionId && this.transports[this.lastSessionId]) {
// Inject the session ID header so the transport's own validation passes
const fallbackRequest = new Request(c.req.raw.url, {
method: c.req.raw.method,
headers: new Headers([...c.req.raw.headers.entries(), [SESSION_ID_HEADER_NAME, this.lastSessionId]]),
body: bodyText,
});
const { req: fallbackReq, res: fallbackRes } = toReqRes(fallbackRequest);
const transport = this.transports[this.lastSessionId];
await transport.handleRequest(fallbackReq, fallbackRes, body);
fallbackRes.on('close', () => {
console.error(`Request closed for fallback session ${this.lastSessionId}`);
});
return toFetchResponse(fallbackRes);
}

return c.json(this.createErrorResponse('Bad Request: invalid session ID or method.'), 400);
return c.json(this.createErrorResponse('Bad Request: missing or invalid session ID.'), 400);
} catch (error) {
console.error('Error handling MCP request:', error);
return c.json(this.createErrorResponse('Internal server error.'), 500);
Expand Down Expand Up @@ -141,7 +118,7 @@ class MCPStreamableHttpServer {
export async function setupStreamableHttpServer(createServer: () => Server, port = 9096) {
const app = new Hono();

app.use('*', cors());
app.use('*', cors({ origin: [`http://localhost:${port}`, `http://127.0.0.1:${port}`] }));

const mcpHandler = new MCPStreamableHttpServer(createServer);

Expand Down
6 changes: 0 additions & 6 deletions src/executer/executer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,6 @@ export function buildRequest(
throw new Error(`Failed to resolve path parameters: ${urlPath}`);
}

// Strip configured prefix for local service testing
const removePrefix = process.env.REMOVE_CI_PREFIX;
if (removePrefix && urlPath.startsWith(removePrefix)) {
urlPath = urlPath.slice(removePrefix.length);
}

// Handle request body
if (definition.requestBodyContentType && typeof toolArgs['requestBody'] !== 'undefined') {
requestBodyData = toolArgs['requestBody'];
Expand Down
13 changes: 7 additions & 6 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* CP_CI_CLIENT_ID - CloudInfra API key client ID (required)
* CP_CI_ACCESS_KEY - CloudInfra API key secret (required)
* CP_CI_GATEWAY - CloudInfra gateway URL (required)
* WRITE_MODE - Enable write tools (default: false)
* WRITE_MODE - Enable write tools (default: false)
* MCP_MODE - Transport mode: "stdio" or "http" (required)
* PORT - HTTP port (required when MCP_MODE=http)
*/
Expand Down Expand Up @@ -72,7 +72,6 @@ async function main() {
const clientId = process.env.CP_CI_CLIENT_ID;
const accessKey = process.env.CP_CI_ACCESS_KEY;
const gateway = process.env.CP_CI_GATEWAY;
const apiBaseOverride = process.env.CP_CI_API_BASE_URL;
const enableWrite = process.env.WRITE_MODE?.toLowerCase() === 'true';
const rawMcpMode = process.env.MCP_MODE;
const port = +(process.env.PORT || '0');
Expand Down Expand Up @@ -106,7 +105,7 @@ async function main() {

// Connect to CloudInfra
try {
await sessionManager.connect(gateway, clientId, accessKey, apiBaseOverride);
await sessionManager.connect(gateway, clientId, accessKey);
} catch (error) {
console.error('Error connecting to CloudInfra gateway:', error);
process.exit(1);
Expand Down Expand Up @@ -139,12 +138,14 @@ async function cleanup() {
process.on('SIGINT', cleanup);
process.on('SIGTERM', cleanup);

// Prevent crashes from unhandled errors
// Log and exit on unhandled errors to avoid running in a corrupt state
process.on('uncaughtException', (error) => {
console.error('Uncaught exception (server will continue):', error);
console.error('Uncaught exception — shutting down:', error);
process.exit(1);
});
process.on('unhandledRejection', (reason) => {
console.error('Unhandled rejection (server will continue):', reason);
console.error('Unhandled rejection — shutting down:', reason);
process.exit(1);
});

main().catch((error) => {
Expand Down
47 changes: 1 addition & 46 deletions test/executer.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { describe, it, expect, vi, afterEach } from 'vitest';
import { describe, it, expect, vi } from 'vitest';
import { buildRequest, buildResult, buildErrorResult } from '../src/executer/executer.js';
import { McpToolDefinition } from '../src/types/types.js';
import { AxiosResponse } from 'axios';
Expand Down Expand Up @@ -31,20 +31,9 @@ function makeDef(overrides: Partial<McpToolDefinition> = {}): McpToolDefinition
}

describe('buildRequest', () => {
const originalPrefix = process.env.REMOVE_CI_PREFIX;

afterEach(() => {
if (originalPrefix === undefined) {
delete process.env.REMOVE_CI_PREFIX;
} else {
process.env.REMOVE_CI_PREFIX = originalPrefix;
}
});

// --- basic request building ---

it('builds GET request with correct URL and auth header', () => {
delete process.env.REMOVE_CI_PREFIX;
const req = buildRequest(makeDef(), {});
expect(req.method).toBe('GET');
expect(req.url).toBe(
Expand All @@ -63,7 +52,6 @@ describe('buildRequest', () => {
// --- path parameters ---

it('replaces path parameters and encodes them', () => {
delete process.env.REMOVE_CI_PREFIX;
const def = makeDef({
pathTemplate: '/api/{groupId}/items/{itemId}',
executionParameters: [
Expand Down Expand Up @@ -142,39 +130,6 @@ describe('buildRequest', () => {
expect(req.data).toBeUndefined();
});

// --- REMOVE_CI_PREFIX ---

it('strips matching prefix from path', () => {
process.env.REMOVE_CI_PREFIX = '/app/genai-protect-policy';
const req = buildRequest(makeDef(), {});
expect(req.url).toBe('https://gw.example.com/mcp/v1/policy/chats/rulebase');
});

it('leaves path unchanged when prefix does not match', () => {
process.env.REMOVE_CI_PREFIX = '/other/prefix';
const req = buildRequest(makeDef(), {});
expect(req.url).toBe(
'https://gw.example.com/app/genai-protect-policy/mcp/v1/policy/chats/rulebase',
);
});

it('leaves path unchanged when env var is not set', () => {
delete process.env.REMOVE_CI_PREFIX;
const req = buildRequest(makeDef(), {});
expect(req.url).toBe(
'https://gw.example.com/app/genai-protect-policy/mcp/v1/policy/chats/rulebase',
);
});

it('strips prefix after path parameter resolution', () => {
process.env.REMOVE_CI_PREFIX = '/app/genai-protect-policy';
const def = makeDef({
pathTemplate: '/app/genai-protect-policy/mcp/v1/policy/{policyId}/rules',
executionParameters: [{ name: 'policyId', in: 'path' }],
});
const req = buildRequest(def, { policyId: '42' });
expect(req.url).toBe('https://gw.example.com/mcp/v1/policy/42/rules');
});
});

describe('buildResult', () => {
Expand Down
10 changes: 0 additions & 10 deletions test/session.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,11 @@ describe('SessionManager', () => {
expect(sm.baseUrl).toBe('https://gw.example.com');
});

it('baseUrl returns override when provided', async () => {
await sm.connect('https://gw.example.com', 'id', 'key', 'http://localhost:8080');
expect(sm.baseUrl).toBe('http://localhost:8080');
});

it('normalizes gateway to origin (strips path)', async () => {
await sm.connect('https://gw.example.com/some/path', 'id', 'key');
expect(sm.baseUrl).toBe('https://gw.example.com');
});

it('normalizes override to origin (strips path)', async () => {
await sm.connect('https://gw.example.com', 'id', 'key', 'http://localhost:8080/some/path');
expect(sm.baseUrl).toBe('http://localhost:8080');
});

// --- authToken ---

it('authToken returns the JWT after connect', async () => {
Expand Down
Loading