Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions packages/core/sdk/src/oauth-discovery.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { afterEach, describe, expect, it, vi } from "@effect/vitest";
import { Cause, Effect, Exit } from "effect";
import { Cause, Effect, Exit, Schema } from "effect";

import {
OAuthDiscoveryError,
Expand All @@ -11,6 +11,14 @@ import {

type Handler = (url: string, init: RequestInit) => Response | Promise<Response>;

const DcrRequestBody = Schema.Struct({
redirect_uris: Schema.Array(Schema.String),
token_endpoint_auth_method: Schema.String,
});
const decodeDcrRequestBody = Schema.decodeUnknownSync(
Schema.fromJsonString(DcrRequestBody),
);

const installFetchRouter = (
handlers: readonly { match: (url: string) => boolean; handle: Handler }[],
): { calls: Array<{ url: string; init: RequestInit }> } => {
Expand Down Expand Up @@ -196,7 +204,7 @@ describe("registerDynamicClient", () => {

const call = calls[0]!;
expect(call.init.method).toBe("POST");
const body = JSON.parse(call.init.body as string);
const body = decodeDcrRequestBody(call.init.body);
expect(body.redirect_uris).toEqual(["https://app.example.com/cb"]);
expect(body.token_endpoint_auth_method).toBe("none");
});
Expand Down Expand Up @@ -251,11 +259,12 @@ describe("registerDynamicClient", () => {
expect(Exit.isFailure(exit)).toBe(true);
if (!Exit.isFailure(exit)) return;
const reason = exit.cause.reasons.find(Cause.isFailReason);
if (!(reason?.error instanceof OAuthDiscoveryError)) {
throw new Error("expected OAuthDiscoveryError");
}
expect(reason.error.status).toBe(400);
expect(reason.error.message).toMatch(/invalid_client_metadata/);
const error = reason?.error;
expect(error).toEqual(expect.objectContaining({
_tag: "OAuthDiscoveryError",
status: 400,
message: expect.stringMatching(/invalid_client_metadata/),
}));
});
});

Expand Down
154 changes: 63 additions & 91 deletions packages/core/sdk/src/oauth-discovery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
// callers actually need.
// ---------------------------------------------------------------------------

import { Data, Effect, Result, Schema } from "effect";
import { Data, Effect, Option, Result, Schema } from "effect";
import * as oauth from "oauth4webapi";

import {
Expand All @@ -45,22 +45,15 @@ export class OAuthDiscoveryError extends Data.TaggedError(
readonly cause?: unknown;
}> {}

const discoveryError = (
message: string,
options: { status?: number; cause?: unknown } = {},
): OAuthDiscoveryError =>
new OAuthDiscoveryError({
message,
status: options.status,
cause: options.cause,
});

// ---------------------------------------------------------------------------
// Schemas (narrow structural parsing — the RFCs leave many fields
// optional; we validate only the subset consumers read)
// ---------------------------------------------------------------------------

const StringArray = Schema.Array(Schema.String);
const JsonUnknownFromString = Schema.fromJsonString(Schema.Unknown);
const decodeJsonUnknownSync = Schema.decodeUnknownSync(JsonUnknownFromString);
const decodeJsonUnknownOption = Schema.decodeUnknownOption(JsonUnknownFromString);

export const OAuthProtectedResourceMetadataSchema = Schema.Struct({
resource: Schema.optional(Schema.String),
Expand Down Expand Up @@ -155,20 +148,17 @@ export interface DiscoveryRequestOptions {
const MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version";

const isLoopbackHttpUrl = (value: string): boolean => {
try {
const url = new URL(value);
if (url.protocol !== "http:") return false;
const hostname = url.hostname.toLowerCase();
return (
hostname === "localhost" ||
hostname === "0.0.0.0" ||
hostname === "::1" ||
hostname === "[::1]" ||
hostname.startsWith("127.")
);
} catch {
return false;
}
if (!URL.canParse(value)) return false;
const url = new URL(value);
if (url.protocol !== "http:") return false;
const hostname = url.hostname.toLowerCase();
return (
hostname === "localhost" ||
hostname === "0.0.0.0" ||
hostname === "::1" ||
hostname === "[::1]" ||
hostname.startsWith("127.")
);
};

const oauth4webapiOptions = (
Expand Down Expand Up @@ -260,30 +250,26 @@ export const discoverProtectedResourceMetadata = (
}
const text = await response.text();
if (text.length === 0) return "skip" as const;
return { status: response.status, body: JSON.parse(text) } as const;
return { status: response.status, body: decodeJsonUnknownSync(text) } as const;
},
catch: (cause) =>
discoveryError(
`Failed to fetch ${url}: ${cause instanceof Error ? cause.message : String(cause)}`,
{ cause },
),
new OAuthDiscoveryError({
message: `Failed to fetch protected resource metadata at ${url}`,
cause,
}),
});
if (result === "skip") continue;
if (!("body" in result)) {
return yield* Effect.fail(
discoveryError(
`Protected resource metadata returned status ${result.status}`,
{ status: result.status },
),
);
return yield* new OAuthDiscoveryError({
message: `Protected resource metadata returned status ${result.status}`,
status: result.status,
});
}
const metadata = yield* decodeResourceMetadata(result.body).pipe(
Effect.mapError(
(err) =>
new OAuthDiscoveryError({
message: `Protected resource metadata is malformed: ${
Schema.isSchemaError(err) ? err.message : String(err)
}`,
message: "Protected resource metadata is malformed",
cause: err,
}),
),
Expand Down Expand Up @@ -348,15 +334,11 @@ export const discoverAuthorizationServerMetadata = (
raw: as,
};
},
catch: (cause) => {
if (cause instanceof OAuthDiscoveryError) return cause;
return discoveryError(
`Discovery (${algorithm}) failed for ${issuer}: ${
cause instanceof Error ? cause.message : String(cause)
}`,
{ cause },
);
},
catch: (cause) =>
new OAuthDiscoveryError({
message: `Discovery (${algorithm}) failed for ${issuer}`,
cause,
}),
}).pipe(
// If one algorithm fails mid-roundtrip (network, parse, issuer
// mismatch) we still want to try the other before giving up.
Expand All @@ -370,9 +352,7 @@ export const discoverAuthorizationServerMetadata = (
Effect.mapError(
(err) =>
new OAuthDiscoveryError({
message: `Authorization server metadata is malformed: ${
Schema.isSchemaError(err) ? err.message : String(err)
}`,
message: "Authorization server metadata is malformed",
cause: err,
}),
),
Expand Down Expand Up @@ -437,11 +417,7 @@ const interpretDcrFailure = (
): DcrErrorBody | DcrTransport => {
// RFC 6749 error envelope: `{error, error_description?}` with 4xx.
if (status >= 400 && status < 500) {
const parsed = Result.try({
try: () => (text ? (JSON.parse(text) as unknown) : null),
catch: () => null,
});
const body = Result.isSuccess(parsed) ? parsed.success : null;
const body = text ? Option.getOrNull(decodeJsonUnknownOption(text)) : null;
if (
body &&
typeof body === "object" &&
Expand Down Expand Up @@ -497,17 +473,18 @@ export const registerDynamicClient = (
}),
catch: (cause) =>
new DcrTransport({
message: `Dynamic Client Registration request failed: ${cause instanceof Error ? cause.message : String(cause)}`,
message: "Dynamic Client Registration request failed",
cause,
}),
});

// Accept both 200 and 201 as success — RFC 7591 mandates 201, but
// Todoist (and others) return 200 OK with the client information body.
if (response.status !== 200 && response.status !== 201) {
const text = yield* Effect.promise(() =>
response.text().catch(() => ""),
);
const text = yield* Effect.tryPromise({
try: () => response.text(),
catch: () => "",
});
return yield* interpretDcrFailure(response.status, text);
}

Expand All @@ -520,22 +497,20 @@ export const registerDynamicClient = (
cause,
}),
});
const json = yield* Effect.try({
try: () => JSON.parse(text) as unknown,
catch: (cause) =>
const json = yield* Schema.decodeUnknownEffect(JsonUnknownFromString)(text).pipe(
Effect.mapError((cause) =>
new DcrTransport({
message: "Dynamic Client Registration response was not valid JSON",
status: response.status,
cause,
}),
});
),
);
return yield* decodeClientInformation(json).pipe(
Effect.mapError(
(err) =>
new OAuthDiscoveryError({
message: `Dynamic Client Registration response is malformed: ${
Schema.isSchemaError(err) ? err.message : String(err)
}`,
message: "Dynamic Client Registration response is malformed",
cause: err,
}),
),
Expand All @@ -544,16 +519,18 @@ export const registerDynamicClient = (
Effect.catchTags({
DcrErrorBody: (err) =>
Effect.fail(
discoveryError(
`Dynamic Client Registration failed: ${err.error}${
err.error_description ? ` ${err.error_description}` : ""
new OAuthDiscoveryError({
message: `Dynamic Client Registration failed: ${err.error}${
err.error_description ? ` - ${err.error_description}` : ""
}`,
{ status: err.status, cause: err },
),
status: err.status,
cause: err,
}),
),
DcrTransport: (err) =>
Effect.fail(
discoveryError(`Dynamic Client Registration failed: ${err.message}`, {
new OAuthDiscoveryError({
message: "Dynamic Client Registration failed",
status: err.status,
cause: err.cause ?? err,
}),
Expand Down Expand Up @@ -649,29 +626,23 @@ export const beginDynamicAuthorization = (
);

if (!authServer) {
return yield* Effect.fail(
discoveryError(
`No OAuth authorization server metadata at ${authorizationServerUrl}`,
),
);
return yield* new OAuthDiscoveryError({
message: `No OAuth authorization server metadata at ${authorizationServerUrl}`,
});
}

const pkceMethods = authServer.metadata.code_challenge_methods_supported ?? [];
if (pkceMethods.length > 0 && !pkceMethods.includes("S256")) {
return yield* Effect.fail(
discoveryError(
`Authorization server does not support PKCE S256 (advertised: ${pkceMethods.join(", ")})`,
),
);
return yield* new OAuthDiscoveryError({
message: `Authorization server does not support PKCE S256 (advertised: ${pkceMethods.join(", ")})`,
});
}

const responseTypes = authServer.metadata.response_types_supported ?? [];
if (responseTypes.length > 0 && !responseTypes.includes("code")) {
return yield* Effect.fail(
discoveryError(
`Authorization server does not support response_type=code (advertised: ${responseTypes.join(", ")})`,
),
);
return yield* new OAuthDiscoveryError({
message: `Authorization server does not support response_type=code (advertised: ${responseTypes.join(", ")})`,
});
}

const baseClientMetadata: DynamicClientMetadata = {
Expand All @@ -689,9 +660,10 @@ export const beginDynamicAuthorization = (
const reg = authServer.metadata.registration_endpoint;
if (!reg) {
return Effect.fail(
discoveryError(
"Authorization server does not advertise registration_endpoint — cannot auto-register a client",
),
new OAuthDiscoveryError({
message:
"Authorization server does not advertise registration_endpoint - cannot auto-register a client",
}),
);
}
return registerDynamicClient(
Expand Down
14 changes: 8 additions & 6 deletions packages/core/sdk/src/oauth-helpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ describe("exchangeAuthorizationCode", () => {
});

it("returns a typed OAuth2Error on transport failure", async () => {
globalThis.fetch = vi.fn().mockRejectedValue(new Error("boom")) as typeof fetch;
globalThis.fetch = vi.fn().mockRejectedValue({ message: "boom" }) as typeof fetch;
const exit = await Effect.runPromiseExit(
exchangeAuthorizationCode({
tokenUrl: "https://example.com/token",
Expand All @@ -380,7 +380,7 @@ describe("exchangeAuthorizationCode", () => {
const err = exit.cause;
const failure = JSON.stringify(err);
expect(failure).toContain("OAuth2Error");
expect(failure).toContain("boom");
expect(failure).toContain("OAuth token exchange failed");
});

it("propagates RFC 6749 error_description text in the OAuth2Error", async () => {
Expand Down Expand Up @@ -563,7 +563,7 @@ describe("shouldRefreshToken", () => {

describe("OAuth2Error tagging", () => {
beforeEach(() => {
globalThis.fetch = vi.fn().mockRejectedValue(new Error("network down")) as typeof fetch;
globalThis.fetch = vi.fn().mockRejectedValue({ message: "network down" }) as typeof fetch;
});
afterEach(() => {
globalThis.fetch = originalFetch;
Expand All @@ -585,8 +585,10 @@ describe("OAuth2Error tagging", () => {

it("OAuth2Error is constructable directly with message and cause", () => {
const err = new OAuth2Error({ message: "test", cause: { foo: 1 } });
expect(err._tag).toBe("OAuth2Error");
expect(err.message).toBe("test");
expect(err.cause).toEqual({ foo: 1 });
expect(err).toMatchObject({
_tag: "OAuth2Error",
message: "test",
cause: { foo: 1 },
});
});
});
Loading
Loading