From d229274a667649f6e46ac0e83c8b23fa0a745a54 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Mon, 18 May 2026 00:36:31 +0200 Subject: [PATCH] Add support for passing IO objects to Websockets.open() --- src/Connections.jl | 2 + src/WebSockets.jl | 133 ++++++++++++++++++++++++++------------ test/runtests.jl | 1 + test/websockets/raw_io.jl | 58 +++++++++++++++++ 4 files changed, 153 insertions(+), 41 deletions(-) create mode 100644 test/websockets/raw_io.jl diff --git a/src/Connections.jl b/src/Connections.jl index a8ba4051b..36d30c95e 100644 --- a/src/Connections.jl +++ b/src/Connections.jl @@ -142,6 +142,8 @@ This forces the socket to send whatever data is within its buffer immediately, rather than waiting 10's of milliseconds for the buffer to fill more. """ function Base.flush(c::Connection) + has_tcpsocket(c) || return + # Flushing the TCP buffer requires support for `Sockets.nagle()` # which was only added in Julia v1.3 @static if VERSION >= v"1.3" diff --git a/src/WebSockets.jl b/src/WebSockets.jl index 1afb5f761..d32f5166d 100644 --- a/src/WebSockets.jl +++ b/src/WebSockets.jl @@ -332,6 +332,56 @@ function hashedkey(key) return base64encode(digest(MD_SHA1, hashkey)) end +function upgrade_headers(headers, key) + return [ + "Upgrade" => "websocket", + "Connection" => "Upgrade", + "Sec-WebSocket-Key" => key, + "Sec-WebSocket-Version" => "13", + headers... + ] +end + +# Validate the server's upgrade response on `http` (a client-side `Stream` whose +# headers have already been read), then build a `WebSocket` and run `f(ws)` with +# the standard close-frame/error handling. +function runclient(f::Function, http, key; suppress_close_error::Bool, maxframesize::Integer, maxfragmentation::Integer) + startread(http) + isupgrade(http.message) || handshakeerror() + if header(http, "Sec-WebSocket-Accept") != hashedkey(key) + throw(WebSocketError("Invalid Sec-WebSocket-Accept\n" * "$(http.message)")) + end + # later stream logic checks to see if the HTTP message is "complete" + # by seeing if ntoread is 0, which is typemax(Int) for websockets by default + # so set it to 0 so it's correctly viewed as "complete" once we're done + # doing websocket things + http.ntoread = 0 + io = http.stream + ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation) + @debug "$(ws.id): WebSocket opened" + try + f(ws) + catch e + if !isok(e) + suppress_close_error || @error "$(ws.id): error" (e, catch_backtrace()) + end + if !isclosed(ws) + if e isa WebSocketError && e.message isa CloseFrameBody + close(ws, e.message) + else + close(ws, CloseFrameBody(1008, "Unexpected client websocket error")) + end + end + if !isok(e) + rethrow() + end + finally + if !isclosed(ws) + close(ws, CloseFrameBody(1000, "")) + end + end +end + """ WebSockets.open(handler, url; verbose=false, kw...) @@ -354,49 +404,50 @@ end """ function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) - headers = [ - "Upgrade" => "websocket", - "Connection" => "Upgrade", - "Sec-WebSocket-Key" => key, - "Sec-WebSocket-Version" => "13", - headers... - ] + headers = upgrade_headers(headers, key) + # HTTP.open open("GET", url, headers; verbose=verbose, kw...) do http - startread(http) - isupgrade(http.message) || handshakeerror() - if header(http, "Sec-WebSocket-Accept") != hashedkey(key) - throw(WebSocketError("Invalid Sec-WebSocket-Accept\n" * "$(http.message)")) - end - # later stream logic checks to see if the HTTP message is "complete" - # by seeing if ntoread is 0, which is typemax(Int) for websockets by default - # so set it to 0 so it's correctly viewed as "complete" once we're done - # doing websocket things - http.ntoread = 0 - io = http.stream - ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation) - @debug "$(ws.id): WebSocket opened" - try - f(ws) - catch e - if !isok(e) - suppress_close_error || @error "$(ws.id): error" (e, catch_backtrace()) - end - if !isclosed(ws) - if e isa WebSocketError && e.message isa CloseFrameBody - close(ws, e.message) - else - close(ws, CloseFrameBody(1008, "Unexpected client websocket error")) - end - end - if !isok(e) - rethrow() - end - finally - if !isclosed(ws) - close(ws, CloseFrameBody(1000, "")) - end - end + runclient(f, http, key; suppress_close_error, maxframesize, maxfragmentation) + end +end + +""" + WebSockets.open(handler, io::IO; target="/", host="", kw...) + +Perform the websocket client handshake directly over an already-connected `IO` +object (e.g. a raw `TCPSocket`, a TLS stream, or an in-memory `IOBuffer`-like +stream) instead of opening a new connection from a URL. This is useful when the +transport is established out-of-band, for testing, or for tunnelling websockets +over a custom byte stream. + +Because there is no URL to derive them from, the HTTP request line target and +`Host` header have simple defaults; override with the `target` and `host` +keyword arguments. The provided `io` is *not* closed by `open`, the caller +retains ownership of the underlying stream's lifetime. Other keyword arguments +match the URL-based [`WebSockets.open`](@ref). +""" +function open(f::Function, io::IO; target="/", host::AbstractString="", + suppress_close_error::Bool=false, headers=[], + maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) + key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) + reqheaders = upgrade_headers(headers, key) + isempty(host) || pushfirst!(reqheaders, "Host" => host) + + # Bypass the client request stack entirely: wrap the user-provided IO in a + # Connection and drive a Stream over it directly, mirroring what + # StreamRequest.streamlayer does around the websocket `iofunction`. Note + # that we pass `clientconnection=false` so closeread() won't spawn an idle + # monitor that closes the caller-owned `io` on EOF. + conn = Connections.Connection("", "", 0, false, true, io, false) + req = Request("GET", target, reqheaders) + stream = Stream(req.response, conn) + startwrite(stream) + try + runclient(f, stream, key; suppress_close_error, maxframesize, maxfragmentation) + finally + closewrite(stream) + closeread(stream) end end diff --git a/test/runtests.jl b/test/runtests.jl index a7d10f903..fd9593552 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,7 @@ isok(r) = r.status == 200 "httpversion.jl", "websockets/autobahn.jl", "websockets/multiple_writers.jl", + "websockets/raw_io.jl", ] # ARGS can be most easily passed like this: # import Pkg; Pkg.test("HTTP"; test_args=`ascii.jl parser.jl`) diff --git a/test/websockets/raw_io.jl b/test/websockets/raw_io.jl new file mode 100644 index 000000000..a544f55e0 --- /dev/null +++ b/test/websockets/raw_io.jl @@ -0,0 +1,58 @@ +using Test, Sockets +using HTTP +using HTTP.WebSockets + +@testset "WebSockets.open over a raw IO" begin + port = 8234 + server = WebSockets.listen!("127.0.0.1", port) do ws + # echo back, and surface the request target the client used + WebSockets.send(ws, "target:" * ws.request.target) + for msg in ws + WebSockets.send(ws, "echo: " * msg) + end + end + + try + @testset "echo round-trip over a connected TCPSocket" begin + sock = Sockets.connect("127.0.0.1", port) + got = String[] + WebSockets.open(sock; host="127.0.0.1:$port") do ws + push!(got, WebSockets.receive(ws)) + WebSockets.send(ws, "hello") + push!(got, WebSockets.receive(ws)) + WebSockets.send(ws, "world") + push!(got, WebSockets.receive(ws)) + end + @test got == ["target:/", "echo: hello", "echo: world"] + # open() must not close the caller-owned transport: the socket is + # still locally open (only the caller may close it). If the idle + # monitor were spawned it would have closed this on EOF. + @test isopen(sock) + close(sock) + end + + @testset "custom target keyword is sent in the request line" begin + sock = Sockets.connect("127.0.0.1", port) + target = nothing + WebSockets.open(sock; target="/ws/v1", host="127.0.0.1:$port") do ws + target = WebSockets.receive(ws) + end + @test target == "target:/ws/v1" + end + + @testset "works with a non-socket IO type" begin + # any IO is accepted, not just sockets: drive it through a TCP pair + sock = Sockets.connect("127.0.0.1", port) + io = IOContext(sock) # wrap so it's not a TCPSocket itself + got = String[] + WebSockets.open(io; host="127.0.0.1:$port") do ws + push!(got, WebSockets.receive(ws)) # target line + WebSockets.send(ws, "hi") + push!(got, WebSockets.receive(ws)) + end + @test got == ["target:/", "echo: hi"] + end + finally + close(server) + end +end