Skip to content

Commit 4ed261d

Browse files
committed
Add websocket subprotocol endpoints
1 parent a20e17b commit 4ed261d

4 files changed

Lines changed: 262 additions & 1 deletion

File tree

src/endpoints/ws-index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ export interface WebSocketEndpoint {
88
/** Return true to match, false to skip, or throw StatusError for invalid params. */
99
matchPath: (path: string, hostnamePrefix?: string) => boolean;
1010
getRemainingPath?: (path: string) => string | undefined;
11+
/** Return a subprotocol name to select, or false to explicitly select none. */
12+
getProtocol?: (path: string) => string | false;
1113
handle: (ws: WebSocket, req: IncomingMessage, options: {
1214
path: string;
1315
query: URLSearchParams;
@@ -21,3 +23,4 @@ export * from './ws/close.js';
2123
export * from './ws/message.js';
2224
export * from './ws/reset.js';
2325
export * from './ws/repeat.js';
26+
export * from './ws/subprotocol.js';

src/endpoints/ws/subprotocol.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import { WebSocketEndpoint } from '../ws-index.js';
2+
import { wsConnection } from '../groups.js';
3+
4+
const SUBPROTOCOL_PREFIX = '/ws/subprotocol/';
5+
6+
export const wsSubprotocolEndpoint: WebSocketEndpoint = {
7+
matchPath: (path) => {
8+
return path.startsWith(SUBPROTOCOL_PREFIX) && path.length > SUBPROTOCOL_PREFIX.length;
9+
},
10+
getRemainingPath: (path) => {
11+
const idx = path.indexOf('/', SUBPROTOCOL_PREFIX.length);
12+
return idx !== -1 ? '/ws' + path.slice(idx) : undefined;
13+
},
14+
getProtocol: (path) => {
15+
const idx = path.indexOf('/', SUBPROTOCOL_PREFIX.length);
16+
const end = idx !== -1 ? idx : path.length;
17+
return decodeURIComponent(path.slice(SUBPROTOCOL_PREFIX.length, end));
18+
},
19+
handle: () => {},
20+
meta: {
21+
path: '/ws/subprotocol/{name}',
22+
description: 'Forces the specified subprotocol in the upgrade response Sec-WebSocket-Protocol header, regardless of what the client requested.',
23+
examples: ['/ws/subprotocol/graphql-ws/echo', '/ws/subprotocol/mqtt/message/hello/close/1000'],
24+
group: wsConnection
25+
}
26+
};
27+
28+
const NO_SUBPROTOCOL_PATH = '/ws/no-subprotocol';
29+
30+
export const wsNoSubprotocolEndpoint: WebSocketEndpoint = {
31+
matchPath: (path) => {
32+
return path === NO_SUBPROTOCOL_PATH || path.startsWith(NO_SUBPROTOCOL_PATH + '/');
33+
},
34+
getRemainingPath: (path) => {
35+
return path.length > NO_SUBPROTOCOL_PATH.length
36+
? '/ws' + path.slice(NO_SUBPROTOCOL_PATH.length)
37+
: undefined;
38+
},
39+
getProtocol: () => false,
40+
handle: () => {},
41+
meta: {
42+
path: '/ws/no-subprotocol',
43+
description: 'Explicitly omits the Sec-WebSocket-Protocol header from the upgrade response, overriding the default behavior where the server auto-selects the first client-offered protocol.',
44+
examples: ['/ws/no-subprotocol/echo'],
45+
group: wsConnection
46+
}
47+
};

src/ws-handler.ts

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,26 @@ import { StatusError } from '@httptoolkit/util';
66
import { wsEndpoints } from './endpoints/endpoint-index.js';
77
import { resolveEndpointChain } from './endpoint-chain.js';
88

9-
const wss = new WebSocketServer({ noServer: true });
9+
const FORCED_PROTOCOL = Symbol('ws-forced-protocol');
10+
11+
declare module 'http' {
12+
interface IncomingMessage {
13+
[FORCED_PROTOCOL]?: string | false;
14+
}
15+
}
16+
17+
const wss = new WebSocketServer({
18+
noServer: true,
19+
handleProtocols(clientProtocols, req) {
20+
const forced = req[FORCED_PROTOCOL];
21+
if (forced === undefined) {
22+
// No subprotocol endpoint in chain — use default ws behavior
23+
return clientProtocols.values().next().value || false;
24+
}
25+
26+
return forced;
27+
}
28+
});
1029

1130
export function handleWebSocketUpgrade(
1231
req: IncomingMessage,
@@ -45,6 +64,23 @@ export function handleWebSocketUpgrade(
4564
console.log('WebSocket upgrade socket error:', err.message);
4665
});
4766

67+
const protocolEntries = entries.filter(e => e.endpoint.getProtocol);
68+
if (protocolEntries.length > 1) {
69+
console.log(`WebSocket upgrade to ${path}: multiple subprotocol endpoints`);
70+
socket.write('HTTP/1.1 400 Bad Request\r\n\r\n');
71+
socket.destroy();
72+
return;
73+
}
74+
if (protocolEntries.length === 1) {
75+
req[FORCED_PROTOCOL] = protocolEntries[0].endpoint.getProtocol!(protocolEntries[0].path);
76+
77+
// ws only calls handleProtocols when the client sends Sec-WebSocket-Protocol.
78+
// Ensure the header exists so our handler always runs.
79+
if (!req.headers['sec-websocket-protocol']) {
80+
req.headers['sec-websocket-protocol'] = '_';
81+
}
82+
}
83+
4884
wss.handleUpgrade(req, socket, head, async (ws) => {
4985
ws.on('error', (err) => {
5086
console.log(`WebSocket error on ${path}:`, err.message);

test/ws-subprotocol.spec.ts

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import * as net from 'net';
2+
import * as http from 'http';
3+
import * as crypto from 'crypto';
4+
import { expect } from 'chai';
5+
import { WebSocket } from 'ws';
6+
import { DestroyableServer, makeDestroyable } from 'destroyable-server';
7+
8+
import { createServer } from '../src/server.js';
9+
10+
// Raw WebSocket upgrade that skips the ws client's strict subprotocol validation
11+
// (which rejects mismatched or missing server-selected protocols with an error)
12+
function rawUpgrade(port: number, path: string, protocols?: string[]): Promise<{
13+
statusCode: number;
14+
headers: http.IncomingHttpHeaders;
15+
}> {
16+
return new Promise((resolve, reject) => {
17+
const key = crypto.randomBytes(16).toString('base64');
18+
const headers: Record<string, string> = {
19+
'Connection': 'Upgrade',
20+
'Upgrade': 'websocket',
21+
'Sec-WebSocket-Version': '13',
22+
'Sec-WebSocket-Key': key,
23+
};
24+
if (protocols?.length) {
25+
headers['Sec-WebSocket-Protocol'] = protocols.join(', ');
26+
}
27+
28+
const req = http.request({
29+
host: 'localhost',
30+
port,
31+
path,
32+
headers
33+
});
34+
35+
req.on('upgrade', (res, socket) => {
36+
resolve({ statusCode: 101, headers: res.headers });
37+
socket.destroy();
38+
});
39+
req.on('response', (res) => {
40+
resolve({ statusCode: res.statusCode!, headers: res.headers });
41+
});
42+
req.on('error', reject);
43+
req.end();
44+
});
45+
}
46+
47+
describe("WebSocket Subprotocol endpoints", () => {
48+
49+
let server: DestroyableServer;
50+
let serverPort: number;
51+
52+
beforeEach(async () => {
53+
server = makeDestroyable(await createServer({
54+
domain: 'localhost'
55+
}));
56+
await new Promise<void>((resolve) => server.listen(resolve));
57+
serverPort = (server.address() as net.AddressInfo).port;
58+
});
59+
60+
afterEach(async () => {
61+
await server.destroy();
62+
});
63+
64+
describe("/ws/subprotocol/{name}", () => {
65+
66+
it("selects the specified subprotocol", async () => {
67+
const ws = new WebSocket(`ws://localhost:${serverPort}/ws/subprotocol/graphql-ws/echo`, ['graphql-ws']);
68+
69+
await new Promise<void>((resolve, reject) => {
70+
ws.on('open', resolve);
71+
ws.on('error', reject);
72+
});
73+
74+
expect(ws.protocol).to.equal('graphql-ws');
75+
ws.close();
76+
});
77+
78+
it("forces the specified protocol even if client offered a different one", async () => {
79+
const result = await rawUpgrade(serverPort, '/ws/subprotocol/mqtt/echo', ['other-protocol']);
80+
81+
expect(result.statusCode).to.equal(101);
82+
expect(result.headers['sec-websocket-protocol']).to.equal('mqtt');
83+
});
84+
85+
it("forces the specified protocol even when client sends no protocols", async () => {
86+
const result = await rawUpgrade(serverPort, '/ws/subprotocol/mqtt/echo');
87+
88+
expect(result.statusCode).to.equal(101);
89+
expect(result.headers['sec-websocket-protocol']).to.equal('mqtt');
90+
});
91+
92+
it("works with chained endpoints", async () => {
93+
const ws = new WebSocket(`ws://localhost:${serverPort}/ws/subprotocol/test-proto/echo`, ['test-proto']);
94+
95+
await new Promise<void>((resolve, reject) => {
96+
ws.on('open', resolve);
97+
ws.on('error', reject);
98+
});
99+
100+
expect(ws.protocol).to.equal('test-proto');
101+
102+
ws.send('hello');
103+
const msg = await new Promise<string>((resolve, reject) => {
104+
ws.on('message', (data) => resolve(data.toString()));
105+
ws.on('error', reject);
106+
});
107+
expect(msg).to.equal('hello');
108+
109+
ws.close();
110+
});
111+
112+
});
113+
114+
describe("/ws/no-subprotocol", () => {
115+
116+
it("omits subprotocol header even when client requests one", async () => {
117+
const result = await rawUpgrade(serverPort, '/ws/no-subprotocol/echo', ['graphql-ws']);
118+
119+
expect(result.statusCode).to.equal(101);
120+
expect(result.headers['sec-websocket-protocol']).to.equal(undefined);
121+
});
122+
123+
});
124+
125+
describe("multiple subprotocol endpoints", () => {
126+
127+
it("rejects multiple subprotocol endpoints in the same chain", async () => {
128+
const result = await rawUpgrade(
129+
serverPort,
130+
'/ws/subprotocol/proto-a/subprotocol/proto-b/echo'
131+
);
132+
133+
expect(result.statusCode).to.equal(400);
134+
});
135+
136+
it("rejects mixing subprotocol and no-subprotocol", async () => {
137+
const result = await rawUpgrade(
138+
serverPort,
139+
'/ws/no-subprotocol/subprotocol/mqtt/echo'
140+
);
141+
142+
expect(result.statusCode).to.equal(400);
143+
});
144+
145+
});
146+
147+
describe("default behavior", () => {
148+
149+
it("auto-selects first client protocol when no subprotocol endpoint is used", async () => {
150+
const ws = new WebSocket(`ws://localhost:${serverPort}/ws/echo`, ['proto-a', 'proto-b']);
151+
152+
await new Promise<void>((resolve, reject) => {
153+
ws.on('open', resolve);
154+
ws.on('error', reject);
155+
});
156+
157+
expect(ws.protocol).to.equal('proto-a');
158+
ws.close();
159+
});
160+
161+
it("returns empty protocol when client sends none", async () => {
162+
const ws = new WebSocket(`ws://localhost:${serverPort}/ws/echo`);
163+
164+
await new Promise<void>((resolve, reject) => {
165+
ws.on('open', resolve);
166+
ws.on('error', reject);
167+
});
168+
169+
expect(ws.protocol).to.equal('');
170+
ws.close();
171+
});
172+
173+
});
174+
175+
});

0 commit comments

Comments
 (0)