From 68b9acbbd341316e51a6d5fc299cbe4cbb2754a4 Mon Sep 17 00:00:00 2001 From: Billal GHILAS Date: Wed, 13 May 2026 11:43:03 +0100 Subject: [PATCH 1/8] feat(ice): Add ice agent --- build.zig | 19 +- src/ice/agent.zig | 636 ++++++++++++++++++++++++++++++++++++++++++++++ src/ice/c.h | 2 + src/ice/ice.zig | 188 ++++++++++++++ src/root.zig | 1 + src/stun/stun.zig | 7 +- 6 files changed, 851 insertions(+), 2 deletions(-) create mode 100644 src/ice/agent.zig create mode 100644 src/ice/c.h create mode 100644 src/ice/ice.zig diff --git a/build.zig b/build.zig index a7f63aa..ae55e68 100644 --- a/build.zig +++ b/build.zig @@ -43,6 +43,22 @@ pub fn build(b: *std.Build) void { .optimize = optimize, }); + const translate_c = b.addTranslateC(.{ + .root_source_file = b.path("src/ice/c.h"), + .target = target, + .optimize = optimize, + }); + + const ice = b.addModule("ice", .{ + .root_source_file = b.path("src/ice/ice.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "stun", .module = stun }, + .{ .name = "c", .module = translate_c.createModule() }, + }, + .link_libc = true, + }); _ = b.addModule("protocols", .{ .root_source_file = b.path("src/root.zig"), .target = target, @@ -52,12 +68,13 @@ pub fn build(b: *std.Build) void { .{ .name = "sdp", .module = sdp }, .{ .name = "rtsp", .module = rtsp }, .{ .name = "stun", .module = stun }, + .{ .name = "ice", .module = ice }, }, }); { const test_filters = b.option([]const []const u8, "test-filter", "Skip tests that do not match any filter") orelse &[0][]const u8{}; - const modules = [_]*std.Build.Module{ rtp, rtcp, sdp, rtsp, stun }; + const modules = [_]*std.Build.Module{ rtp, rtcp, sdp, rtsp, stun, ice }; const test_step = b.step("test", "Run tests"); inline for (modules) |sub_module| { diff --git a/src/ice/agent.zig b/src/ice/agent.zig new file mode 100644 index 0000000..a4d31ef --- /dev/null +++ b/src/ice/agent.zig @@ -0,0 +1,636 @@ +const std = @import("std"); +const c = @import("c"); +const stun = @import("stun"); +const ice = @import("ice.zig"); + +const Io = std.Io; +const Socket = Io.net.Socket; +const IpAddress = Io.net.IpAddress; +const Candidate = ice.Candidate; +const CandidatePair = ice.CandidatePair; +const Agent = @This(); +const Logger = std.log.scoped(.ice); + +const max_message_size = 1500; +const max_binding_requests: usize = 7; +const connectivity_check_interval: std.Io.Duration = .fromMilliseconds(200); +const keep_alive_interval: std.Io.Duration = .fromMilliseconds(200); + +io: Io, +allocator: std.mem.Allocator, +buffer_pool: std.heap.MemoryPool([max_message_size]u8), +state: State = .new, + +// Stun related fields +role: Role, +credentials: ice.Credentials, +remote_credentials: ?ice.Credentials = null, +tie_breaker: u64, + +// Candidates and sockets +sockets: []Io.net.Socket = &.{}, +candidates: std.ArrayList(Candidate) = .empty, +pairs: std.ArrayList(CandidatePair) = .empty, +pending_requests: std.ArrayList(PendingRequest) = .empty, +nominated_pair: ?CandidatePair = null, + +// Io handling +group: Io.Group = .init, +queue_buffer: [1]InternalEvent = undefined, +queue: Io.Queue(InternalEvent) = undefined, + +pub const State = enum { new, checking, connected, disconnected, failed }; + +const Role = enum { controlling, controlled }; + +const InternalEvent = union(enum) { + add_candidate: Candidate, + message: struct { IpAddress, Io.net.IncomingMessage }, + check_connectivity: void, + data: []u8, +}; + +const StunRequest = struct { + username: []const u8 = &.{}, + ice_controlled: ?u64 = null, + ice_controlling: ?u64 = null, + use_candidate: bool = false, + priority: u32 = 0, +}; + +const PendingRequest = struct { + transaction_id: u96, + source: Io.net.IpAddress, + target: Io.net.IpAddress, +}; + +pub fn init(agent: *Agent, io: Io, allocator: std.mem.Allocator) !void { + agent.* = .{ + .io = io, + .allocator = allocator, + .buffer_pool = .empty, + .role = .controlled, + .credentials = try (ice.Credentials{ .username = "test", .password = "test" }).dupe(allocator), + .tie_breaker = generateTieBeaker(io), + }; + + agent.queue = .init(&agent.queue_buffer); +} + +pub fn deinit(agent: *Agent) void { + const io = agent.io; + const allocator = agent.allocator; + + agent.buffer_pool.deinit(allocator); + agent.candidates.deinit(allocator); + agent.pairs.deinit(allocator); + for (agent.sockets) |socket| socket.close(io); + allocator.free(agent.sockets); + + agent.credentials.deinit(allocator); + if (agent.remote_credentials) |*credens| credens.deinit(allocator); + + agent.queue.close(io); + agent.group.cancel(io); +} + +pub fn setRemoteCredentials(agent: *Agent, credentials: ice.Credentials) !void { + switch (agent.state) { + .new => { + agent.remote_credentials = try credentials.dupe(agent.allocator); + agent.state = .checking; + try agent.group.concurrent(agent.io, startConnectivityChecks, .{agent}); + }, + else => return error.CredentialsAlreadySet, + } +} + +pub fn addRemoteCandidate(agent: *Agent, remote_candidate: Candidate) !void { + switch (agent.state) { + .new => try agent.doAddRemoteCandidate(remote_candidate), + .checking => try agent.queue.putOne(agent.io, .{ .add_candidate = remote_candidate }), + else => {}, + } +} + +pub fn gatherCandidates(agent: *Agent) !void { + try agent.gatherHostCandidates(); + try agent.initSockets(); + try agent.group.concurrent(agent.io, listenForMessages, .{agent}); +} + +/// Poll for events +pub fn poll(agent: *Agent) !?[]u8 { + const io = agent.io; + + while (agent.queue.getOne(io)) |event| switch (event) { + .add_candidate => |remote_candidate| try agent.addRemoteCandidate(remote_candidate), + .message => |s| { + defer agent.buffer_pool.destroy(@ptrCast(@alignCast(s.@"1".data.ptr))); + if (stun.isMessage(s.@"1".data)) { + if (try agent.handleReceivedMessage(s.@"0", s.@"1")) |response| { + defer agent.buffer_pool.destroy(@ptrCast(@alignCast(@constCast(response.ptr)))); + try findSocket(agent.sockets, &s.@"0").send(io, &s.@"1".from, response); + } + } else { + for (agent.pairs.items) |*candidate_pair| if (candidate_pair.remote.address.eql(&s.@"1".from)) { + return s.@"1".data; + }; + continue; + } + + try agent.maybeSetNominatedCandidate(); + }, + .check_connectivity => try agent.batchSendConnectivityCheck(), + .data => |data| return data, + } else |err| switch (err) { + error.Canceled => return error.Canceled, + else => {}, + } + + return null; +} + +fn initSockets(agent: *Agent) !void { + agent.sockets = try agent.allocator.alloc(Io.net.Socket, agent.candidates.items.len); + var initialized: usize = 0; + errdefer { + for (0..initialized) |idx| agent.sockets[idx].close(agent.io); + agent.allocator.free(agent.sockets); + } + + for (agent.candidates.items) |*candidate| { + agent.sockets[initialized] = try candidate.address.bind( + agent.io, + .{ .mode = .dgram, .protocol = .udp }, + ); + candidate.base = agent.sockets[initialized].address; + candidate.address = agent.sockets[initialized].address; + initialized += 1; + } +} + +fn calculatePairPriority(l: u32, r: u32, role: Role) u64 { + var g = l; + var d = r; + if (role == .controlled) g, d = .{ d, g }; + + const last_part: u8 = if (g > d) 1 else 0; + return (@as(u64, 1) << 32) * @min(g, d) + 2 * @max(g, d) + last_part; +} + +fn generateTieBeaker(io: Io) u64 { + var bytes: [8]u8 = undefined; + io.random(&bytes); + return @bitCast(bytes); +} + +fn generateTrasactionId(io: Io) u96 { + var bytes: [12]u8 = undefined; + io.random(&bytes); + return std.mem.readInt(u96, &bytes, .big); +} + +fn gatherHostCandidates(agent: *Agent) !void { + var interfaces: [*c]c.ifaddrs = undefined; + if (c.getifaddrs(&interfaces) != 0) { + return error.GetIfAddrsFailed; + } + defer c.ifaddrs.freeifaddrs(interfaces); + + var it = interfaces; + while (it) |p_ifa| : (it = p_ifa.*.ifa_next) if (p_ifa.*.ifa_addr) |addr| { + switch (addr.*.sa_family) { + c.AF_INET => { + const sin: *const c.sockaddr_in = @ptrCast(@alignCast(addr)); + // Ignore loopback addresses. + if (sin.sin_addr.s_addr == std.mem.nativeToBig(u32, 0x7f000001)) { + continue; + } + + const ip_addr: Io.net.IpAddress = .{ + .ip4 = .{ .bytes = std.mem.toBytes(sin.sin_addr.s_addr), .port = 0 }, + }; + try agent.candidates.append(agent.allocator, .initHost(ip_addr)); + }, + else => {}, + } + }; +} + +fn doAddRemoteCandidate(agent: *Agent, remote_candidate: Candidate) !void { + for (agent.candidates.items) |candidate| { + for (agent.pairs.items) |*pair| + if (pair.local.base.eql(&candidate.base) and pair.remote.address.eql(&remote_candidate.address)) + continue; + + try agent.pairs.append(agent.allocator, .{ + .local = candidate, + .remote = remote_candidate, + .priority = calculatePairPriority(candidate.priority, remote_candidate.priority, agent.role), + }); + } +} + +fn handleReceivedMessage(agent: *Agent, base_addr: Io.net.IpAddress, incoming_message: Io.net.IncomingMessage) !?[]const u8 { + const msg = try stun.Message.parse(incoming_message.data); + switch (msg.header.message_type.class()) { + .request => return try agent.handleRequest(&msg, base_addr, incoming_message.from), + .success_response => { + Logger.debug("Handle success response on {f} from {f}", .{ base_addr, incoming_message.from }); + + const pending_request = blk: { + const tx_id = msg.header.transaction_id; + for (agent.pending_requests.items, 0..) |pr, i| { + if (pr.transaction_id == tx_id) { + const pending_request = agent.pending_requests.swapRemove(i); + break :blk pending_request; + } + } + + return null; + }; + + if (!pending_request.source.eql(&base_addr) or !pending_request.target.eql(&incoming_message.from)) return null; + + if (agent.findCandidatePair(&base_addr, &incoming_message.from)) |candidate_pair| { + const mapped_address = blk: { + var it = msg.iterateAttributes(&.{}); + while (try it.next()) |attribute| switch (attribute) { + .xor_mapped_address => |addr| break :blk addr, + else => {}, + }; + + return null; + }; + + if (mapped_address.eql(&base_addr)) { + candidate_pair.state.status = .succeeded; + if (agent.role == .controlled and candidate_pair.state.nominateOnBinding) { + candidate_pair.state.nominateOnBinding = false; + candidate_pair.state.nominated = true; + } + return null; + } + candidate_pair.state.status = .failed; + + if (agent.findCandidatePair(&mapped_address, &incoming_message.from)) |existing_candidate_pair| { + existing_candidate_pair.state.status = .succeeded; + return null; + } + + const reflexive_candidate: Candidate = .initPeerReflexive(base_addr, mapped_address); + try agent.pairs.append(agent.allocator, .{ + .local = reflexive_candidate, + .remote = candidate_pair.remote, + .priority = calculatePairPriority(reflexive_candidate.priority, candidate_pair.remote.priority, agent.role), + .state = .{ .status = .succeeded }, + }); + + return null; + } + }, + else => {}, + } + + return null; +} + +fn handleRequest(agent: *Agent, msg: *const stun.Message, base_addr: IpAddress, from: IpAddress) ![]const u8 { + Logger.debug("Handle request on {f} from {f}", .{ base_addr, from }); + const stun_req = try agent.parseAndValidateStunRequest(msg); + + if (agent.findCandidatePair(&base_addr, &from)) |candidate_pair| { + switch (candidate_pair.state.status) { + .succeeded => candidate_pair.state.nominated = stun_req.use_candidate, + else => candidate_pair.state.nominateOnBinding = stun_req.use_candidate, + } + } else { + const local: Candidate = .initHost(base_addr); + const remote: Candidate = .{ + .base = from, + .address = from, + .candidate_type = .peer_reflexive, + .priority = stun_req.priority, + }; + + try agent.pairs.append(agent.allocator, .{ + .local = local, + .remote = remote, + .priority = calculatePairPriority(local.priority, remote.priority, agent.role), + .state = .{ + .status = .in_progress, + .nominateOnBinding = stun_req.use_candidate, + }, + }); + } + + const buffer = try agent.buffer_pool.create(agent.allocator); + return try agent.buildSuccessResponse(msg, from, buffer); +} + +fn parseAndValidateStunRequest(agent: *Agent, msg: *const stun.Message) !StunRequest { + var it = msg.iterateAttributes(agent.credentials.password); + var has_fingerprint: bool = false; + var has_message_integrity = false; + var stun_request: StunRequest = .{}; + + while (try it.next()) |attribute| switch (attribute) { + .username => |u| stun_request.username = u, + .ice_controlled => |v| stun_request.ice_controlled = v, + .ice_controlling => |v| stun_request.ice_controlling = v, + .use_candidate => stun_request.use_candidate = true, + .priority => |p| stun_request.priority = p, + .fingerprint => has_fingerprint = true, + .message_integrity => has_message_integrity = true, + else => {}, + }; + + if (!has_fingerprint or !has_message_integrity) + return error.InvalidStunMessage; + if (stun_request.ice_controlling == null and stun_request.ice_controlled == null or + stun_request.ice_controlling != null and stun_request.ice_controlled != null) + return error.InvalidStunMessage; + + if (stun_request.ice_controlled != null and agent.role == .controlled) { + if (agent.tie_breaker >= stun_request.ice_controlled.?) + return error.SwitchRole + else + return error.RoleConflict; + } + + if (stun_request.ice_controlling != null and agent.role == .controlling) { + if (agent.tie_breaker >= stun_request.ice_controlling.?) + return error.RoleConflict + else + return error.SwitchRole; + } + + if (stun_request.use_candidate and agent.role == .controlling) + return error.InvalidStunMessage; + + //TODO: check username + + return stun_request; +} + +fn buildBindingRequest(agent: *Agent, tx_id: u96, buffer: *[max_message_size]u8) ![]const u8 { + var w = stun.Writer.init(&(buffer.*), .{ .password = agent.remote_credentials.?.password }); + try w.writeHeader(.{ + .message_type = .fromClassAndMethod(.request, .binding), + .transaction_id = tx_id, + .message_length = 0, + }); + + var username = [_][]const u8{ agent.remote_credentials.?.username, ":", agent.credentials.username }; + try w.writeRaw(.username, &username); + try w.writeAttribute(.{ .priority = 10 }); + const role_attribute: stun.Attribute = switch (agent.role) { + .controlled => .{ .ice_controlled = agent.tie_breaker }, + .controlling => .{ .ice_controlling = agent.tie_breaker }, + }; + try w.writeAttribute(role_attribute); + try w.writeAttribute(.{ .message_integrity = &.{} }); + try w.writeAttribute(.fingerprint); + + return w.final(); +} + +// Used for keep alive +fn buildIndicationRequest(buffer: []u8) ![]const u8 { + var w = stun.Writer.init(buffer, .{}); + try w.writeHeader(.{ + .message_type = .fromClassAndMethod(.indication, .binding), + .message_length = 0, + .transaction_id = 0x0010, + }); + + return w.final(); +} + +fn buildSuccessResponse( + agent: *const Agent, + msg: *const stun.Message, + from: Io.net.IpAddress, + buffer: *[max_message_size]u8, +) ![]const u8 { + var w = stun.Writer.init(&(buffer.*), .{ .password = agent.credentials.password }); + try w.writeHeader(.{ + .message_type = .fromClassAndMethod(.success_response, .binding), + .transaction_id = msg.header.transaction_id, + .message_length = 0, + }); + try w.writeAttribute(.{ .xor_mapped_address = from }); + try w.writeAttribute(.{ .message_integrity = &.{} }); + try w.writeAttribute(.fingerprint); + return w.final(); +} + +fn findSocket(sockets: []Io.net.Socket, addr: *const Io.net.IpAddress) *Io.net.Socket { + for (sockets) |*socket| if (socket.address.eql(addr)) return socket; + unreachable; +} + +fn findCandidatePair(agent: *Agent, local: *const Io.net.IpAddress, remote: *const Io.net.IpAddress) ?*CandidatePair { + for (agent.pairs.items) |*candidate| { + if (candidate.local.address.eql(local) and candidate.remote.address.eql(remote)) + return candidate; + } + + return null; +} + +fn maybeSetNominatedCandidate(agent: *Agent) !void { + if (agent.role == .controlling or agent.nominated_pair != null) return; + + for (agent.pairs.items) |candidate_pair| if (candidate_pair.state.nominated) { + agent.nominated_pair = candidate_pair; + agent.state = .connected; + agent.group.cancel(agent.io); + + // Clean up and listen on socket + agent.candidates.deinit(agent.allocator); + for (agent.sockets) |*socket| if (!socket.address.eql(&candidate_pair.local.base)) socket.close(agent.io); + try agent.pairs.shrinkAndFreePrecise(agent.allocator, 1); + agent.pairs.items[0] = candidate_pair; + + try agent.group.concurrent(agent.io, listen, .{agent}); + break; + }; +} + +// ============== Io related function ====================== +const Receive = union(enum) { + message: anyerror!struct { usize, Io.net.IncomingMessage }, +}; + +const ListenEvent = union(enum) { + message: Io.net.Socket.ReceiveTimeoutError!Io.net.IncomingMessage, + keep_alive: Io.Cancelable!void, +}; + +fn listen(agent: *Agent) !void { + agent.doListen() catch |err| switch (err) { + error.Canceled => return error.Canceled, + else => {}, + }; +} + +fn doListen(agent: *Agent) !void { + const ListenSelect = Io.Select(ListenEvent); + var listen_event_buffer: [1]ListenEvent = undefined; + var select = ListenSelect.init(agent.io, &listen_event_buffer); + defer select.cancelDiscard(); + + const socket = findSocket(agent.sockets, &agent.nominated_pair.?.local.base); + const receive_timeout: Io.Timeout = .{ .duration = .{ .clock = .awake, .raw = .fromSeconds(5) } }; + + const buffer = try agent.buffer_pool.create(agent.allocator); + defer agent.buffer_pool.destroy(buffer); + + var stun_indication: [20]u8 = undefined; + const dest = &agent.nominated_pair.?.remote.address; + + select.async(.message, Io.net.Socket.receiveTimeout, .{ socket, agent.io, &(buffer.*), receive_timeout }); + select.async(.keep_alive, Io.sleep, .{ agent.io, Io.Duration.fromSeconds(2), Io.Clock.awake }); + + while (true) switch (try select.await()) { + .message => |maybe_msg| { + const msg = maybe_msg catch |err| switch (err) { + error.Canceled => return error.Canceled, + error.Timeout => { + if (agent.state != .disconnected) { + Logger.warn("Agent state transitioned to disconnected", .{}); + agent.state = .disconnected; + } + continue; + }, + else => return, + }; + + select.async(.message, Io.net.Socket.receiveTimeout, .{ socket, agent.io, &(buffer.*), receive_timeout }); + if (stun.isMessage(msg.data)) continue; + try agent.queue.putOne(agent.io, .{ .data = msg.data }); + }, + .keep_alive => |timeout| { + try timeout; + select.async(.keep_alive, Io.sleep, .{ agent.io, keep_alive_interval, Io.Clock.awake }); + try socket.send(agent.io, dest, try buildIndicationRequest(&stun_indication)); + }, + }; +} + +fn listenForMessages(agent: *Agent) !void { + agent.doListenForMessages() catch |err| switch (err) { + error.Canceled => return error.Canceled, + else => {}, + }; +} + +fn startConnectivityChecks(agent: *Agent) !void { + while (true) { + agent.queue.putOne(agent.io, .check_connectivity) catch |err| switch (err) { + error.Canceled => return error.Canceled, + else => return, + }; + try agent.io.sleep(connectivity_check_interval, .awake); + } +} + +fn doListenForMessages(agent: *Agent) !void { + const IncomingMessageSelect = Io.Select(Receive); + + var queue: [4]Receive = undefined; + var select = IncomingMessageSelect.init(agent.io, &queue); + defer select.cancelDiscard(); + + for (agent.sockets, 0..) |*socket, idx| { + select.async(.message, receive, .{ agent, socket, idx }); + } + + while (true) { + const result = try select.await(); + + const index, const incoming_message = result.message catch |err| switch (err) { + error.Canceled => return error.Canceled, + else => |e| { + std.log.err("An error occurred when listening on socket: {}", .{e}); + continue; + }, + }; + + const socket = &agent.sockets[index]; + try agent.queue.putOne(agent.io, .{ .message = .{ socket.address, incoming_message } }); + select.async(.message, receive, .{ agent, socket, index }); + } +} + +fn doStartChecks(agent: *Agent) !void { + while (true) { + const buffer = try agent.buffer_pool.create(agent.allocator); + defer agent.buffer_pool.destroy(buffer); + + try agent.mutex.lock(agent.io); + for (agent.pairs.items) |*pair| switch (pair.state.status) { + .waiting, .in_progress => { + pair.conn_check_count += 1; + if (pair.conn_check_count > max_binding_requests) { + pair.state.status = .failed; + continue; + } + + const transaction_id = generateTrasactionId(agent.io); + const msg = try agent.buildBindingRequest(transaction_id, buffer); + + try agent.pending_requests.append(agent.allocator, .{ + .transaction_id = transaction_id, + .source = pair.local.base, + .target = pair.remote.address, + }); + + const socket = findSocket(agent.sockets, &pair.local.base); + try socket.send(agent.io, &pair.remote.address, msg); + }, + else => {}, + }; + agent.mutex.unlock(agent.io); + + try agent.io.sleep(.fromMilliseconds(200), .awake); + } +} + +fn batchSendConnectivityCheck(agent: *Agent) !void { + const buffer = try agent.buffer_pool.create(agent.allocator); + defer agent.buffer_pool.destroy(buffer); + + for (agent.pairs.items) |*candidate_pair| switch (candidate_pair.state.status) { + .waiting, .in_progress => { + candidate_pair.conn_check_count += 1; + if (candidate_pair.conn_check_count > max_binding_requests) { + candidate_pair.state.status = .failed; + continue; + } + + const transaction_id = generateTrasactionId(agent.io); + const msg = try agent.buildBindingRequest(transaction_id, buffer); + + try agent.pending_requests.append(agent.allocator, .{ + .transaction_id = transaction_id, + .source = candidate_pair.local.base, + .target = candidate_pair.remote.address, + }); + + const socket = findSocket(agent.sockets, &candidate_pair.local.base); + try socket.send(agent.io, &candidate_pair.remote.address, msg); + }, + else => {}, + }; +} + +fn receive(agent: *Agent, socket: *Socket, index: usize) !struct { usize, Io.net.IncomingMessage } { + const buffer = try agent.buffer_pool.create(agent.allocator); + errdefer agent.buffer_pool.destroy(buffer); + + const incoming_message = try socket.receive(agent.io, &(buffer.*)); + return .{ index, incoming_message }; +} diff --git a/src/ice/c.h b/src/ice/c.h new file mode 100644 index 0000000..393a8ff --- /dev/null +++ b/src/ice/c.h @@ -0,0 +1,2 @@ +#include "ifaddrs.h" +#include "netinet/in.h" \ No newline at end of file diff --git a/src/ice/ice.zig b/src/ice/ice.zig new file mode 100644 index 0000000..4471c36 --- /dev/null +++ b/src/ice/ice.zig @@ -0,0 +1,188 @@ +pub const Agent = @import("agent.zig"); + +const std = @import("std"); + +const Io = std.Io; + +pub const CandidateType = enum { + host, + server_reflexive, + peer_reflexive, + relayed, + + pub fn typePreference(self: CandidateType) u8 { + return switch (self) { + .host => 126, + .peer_reflexive => 110, + .server_reflexive => 100, + .relayed => 0, + }; + } + + pub fn name(self: CandidateType) []const u8 { + return switch (self) { + .host => "host", + .peer_reflexive => "prflx", + .server_reflexive => "srflx", + .relayed => "relay", + }; + } + + pub fn fromSlice(slice: []const u8) !CandidateType { + return if (std.mem.eql(u8, slice, "host")) + .host + else if (std.mem.eql(u8, slice, "prflx")) + .peer_reflexive + else if (std.mem.eql(u8, slice, "srflx")) + .server_reflexive + else if (std.mem.eql(u8, slice, "relay")) + .relayed + else + error.InvalidCandidateType; + } +}; + +pub const Candidate = struct { + candidate_type: CandidateType, + base: Io.net.IpAddress, + address: Io.net.IpAddress, + foundation: u32 = 0, + priority: u32 = 0, + + pub fn initHost(address: Io.net.IpAddress) Candidate { + var candidate: Candidate = .{ + .candidate_type = .host, + .base = address, + .address = address, + .priority = calculatePriority(.host), + }; + candidate.calculateFoundation(); + return candidate; + } + + pub fn initPeerReflexive(base: Io.net.IpAddress, address: Io.net.IpAddress) Candidate { + var candidate: Candidate = .{ + .candidate_type = .peer_reflexive, + .base = base, + .address = address, + .priority = calculatePriority(.peer_reflexive), + }; + candidate.calculateFoundation(); + return candidate; + } + + pub fn parse(value: []const u8) !Candidate { + var it = std.mem.tokenizeScalar(u8, value, ' '); + + const foundation = try std.fmt.parseUnsigned(u32, try nextToken(it.next()), 10); + _ = try nextToken(it.next()); // component + _ = try nextToken(it.next()); // assume udp + const priority = try std.fmt.parseUnsigned(u32, try nextToken(it.next()), 10); + + const address = try nextToken(it.next()); + const port = try std.fmt.parseUnsigned(u16, try nextToken(it.next()), 10); + const addr = try Io.net.IpAddress.parse(address, port); + + _ = try nextToken(it.next()); // typ + const candidate_type = try CandidateType.fromSlice(try nextToken(it.next())); + + return .{ + .foundation = foundation, + .priority = priority, + .base = addr, + .address = addr, + .candidate_type = candidate_type, + }; + } + + pub fn format(self: @This(), writer: *std.Io.Writer) !void { + try writer.print("{d:0>8} {} {s} {} ", .{ self.foundation, 1, "udp", self.priority }); + switch (self.address) { + .ip4 => |ip| try writer.print("{d}.{d}.{d}.{d} {d} ", .{ + ip.bytes[0], + ip.bytes[1], + ip.bytes[2], + ip.bytes[3], + ip.port, + }), + else => {}, + } + try writer.print("typ {s}", .{self.candidate_type.name()}); + } + + fn calculateFoundation(self: *Candidate) void { + var hasher = std.hash.Crc32.init(); + hasher.update(self.candidate_type.name()); + hasher.update(switch (self.address) { + .ip4 => |addr| &addr.bytes, + .ip6 => |addr| &addr.bytes, + }); + hasher.update("udp"); + self.foundation = hasher.final(); + } + + fn calculatePriority(t: CandidateType) u32 { + return (@as(u32, 1) << 24) * t.typePreference() + (1 << 8) * 65535 + 255; + } + + inline fn nextToken(maybe_token: ?[]const u8) ![]const u8 { + return if (maybe_token) |token| token else error.ParseError; + } +}; + +pub const Credentials = struct { + username: []const u8, + password: []const u8, + + pub fn dupe(credentials: *const Credentials, allocator: std.mem.Allocator) !Credentials { + const u = try allocator.dupe(u8, credentials.username); + errdefer allocator.free(u); + const p = try allocator.dupe(u8, credentials.password); + return .{ .username = u, .password = p }; + } + + pub fn deinit(credens: *Credentials, allocator: std.mem.Allocator) void { + allocator.free(credens.username); + allocator.free(credens.password); + } +}; + +pub const CandidatePair = struct { + local: Candidate, + remote: Candidate, + priority: u64, + state: PairState = .{}, + + /// The number of connectivity checks sent so far. + conn_check_count: u8 = 0, + + pub const Status = enum(u2) { waiting, in_progress, failed, succeeded }; + + pub const PairState = packed struct(u8) { + status: Status = .waiting, + nominated: bool = false, + nominateOnBinding: bool = false, + _pad: u4 = 0, + }; + + pub fn compare(_: void, lhs: CandidatePair, rhs: CandidatePair) bool { + return lhs.priority > rhs.priority; + } + + pub fn format( + self: CandidatePair, + writer: *std.Io.Writer, + ) std.Io.Writer.Error!void { + try writer.print("{f}({}) <=> {f}({})[{}]", .{ + self.local.address, + self.local.candidate_type, + self.remote.address, + self.remote.candidate_type, + self.priority, + }); + } +}; + +test { + _ = @import("agent.zig"); +} diff --git a/src/root.zig b/src/root.zig index b436f2b..ebc77a8 100644 --- a/src/root.zig +++ b/src/root.zig @@ -2,3 +2,4 @@ pub const rtp = @import("rtp"); pub const sdp = @import("sdp"); pub const rtsp = @import("rtsp"); pub const stun = @import("stun"); +pub const ice = @import("ice"); diff --git a/src/stun/stun.zig b/src/stun/stun.zig index 305ff07..de21b99 100644 --- a/src/stun/stun.zig +++ b/src/stun/stun.zig @@ -1,11 +1,16 @@ const std = @import("std"); - const Io = std.Io; pub const magic_cookie: u32 = 0x2112A442; pub const header_size = 20; + const fingerprint_xor: u32 = 0x5354554e; +/// Returns `true` if it's stun message. +pub fn isMessage(msg: []const u8) bool { + return msg.len >= header_size and std.mem.readInt(u32, msg[4..8], .big) == magic_cookie; +} + pub const Class = enum(u2) { request, indication, From fa2b2d91a703a2a421d3d225ba4087ed66359bf3 Mon Sep 17 00:00:00 2001 From: Billal GHILAS Date: Wed, 13 May 2026 18:01:11 +0100 Subject: [PATCH 2/8] Add some basic tests --- src/ice/agent.zig | 68 ++++++------------------ src/ice/ice.zig | 132 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 135 insertions(+), 65 deletions(-) diff --git a/src/ice/agent.zig b/src/ice/agent.zig index a4d31ef..907e8fa 100644 --- a/src/ice/agent.zig +++ b/src/ice/agent.zig @@ -116,7 +116,7 @@ pub fn addRemoteCandidate(agent: *Agent, remote_candidate: Candidate) !void { pub fn gatherCandidates(agent: *Agent) !void { try agent.gatherHostCandidates(); try agent.initSockets(); - try agent.group.concurrent(agent.io, listenForMessages, .{agent}); + try agent.group.concurrent(agent.io, listenForConnectivityChecks, .{agent}); } /// Poll for events @@ -384,7 +384,7 @@ fn buildBindingRequest(agent: *Agent, tx_id: u96, buffer: *[max_message_size]u8) var username = [_][]const u8{ agent.remote_credentials.?.username, ":", agent.credentials.username }; try w.writeRaw(.username, &username); - try w.writeAttribute(.{ .priority = 10 }); + try w.writeAttribute(.{ .priority = ice.CandidateType.peer_reflexive.priority() }); const role_attribute: stun.Attribute = switch (agent.role) { .controlled => .{ .ice_controlled = agent.tie_breaker }, .controlling => .{ .ice_controlling = agent.tie_breaker }, @@ -426,12 +426,12 @@ fn buildSuccessResponse( return w.final(); } -fn findSocket(sockets: []Io.net.Socket, addr: *const Io.net.IpAddress) *Io.net.Socket { +fn findSocket(sockets: []Io.net.Socket, addr: *const IpAddress) *Io.net.Socket { for (sockets) |*socket| if (socket.address.eql(addr)) return socket; unreachable; } -fn findCandidatePair(agent: *Agent, local: *const Io.net.IpAddress, remote: *const Io.net.IpAddress) ?*CandidatePair { +fn findCandidatePair(agent: *Agent, local: *const IpAddress, remote: *const IpAddress) ?*CandidatePair { for (agent.pairs.items) |*candidate| { if (candidate.local.address.eql(local) and candidate.remote.address.eql(remote)) return candidate; @@ -520,13 +520,6 @@ fn doListen(agent: *Agent) !void { }; } -fn listenForMessages(agent: *Agent) !void { - agent.doListenForMessages() catch |err| switch (err) { - error.Canceled => return error.Canceled, - else => {}, - }; -} - fn startConnectivityChecks(agent: *Agent) !void { while (true) { agent.queue.putOne(agent.io, .check_connectivity) catch |err| switch (err) { @@ -537,7 +530,14 @@ fn startConnectivityChecks(agent: *Agent) !void { } } -fn doListenForMessages(agent: *Agent) !void { +fn listenForConnectivityChecks(agent: *Agent) !void { + agent.doListenForConnectivityChecks() catch |err| switch (err) { + error.Canceled => return error.Canceled, + else => {}, + }; +} + +fn doListenForConnectivityChecks(agent: *Agent) !void { const IncomingMessageSelect = Io.Select(Receive); var queue: [4]Receive = undefined; @@ -565,38 +565,12 @@ fn doListenForMessages(agent: *Agent) !void { } } -fn doStartChecks(agent: *Agent) !void { - while (true) { - const buffer = try agent.buffer_pool.create(agent.allocator); - defer agent.buffer_pool.destroy(buffer); - - try agent.mutex.lock(agent.io); - for (agent.pairs.items) |*pair| switch (pair.state.status) { - .waiting, .in_progress => { - pair.conn_check_count += 1; - if (pair.conn_check_count > max_binding_requests) { - pair.state.status = .failed; - continue; - } - - const transaction_id = generateTrasactionId(agent.io); - const msg = try agent.buildBindingRequest(transaction_id, buffer); - - try agent.pending_requests.append(agent.allocator, .{ - .transaction_id = transaction_id, - .source = pair.local.base, - .target = pair.remote.address, - }); - - const socket = findSocket(agent.sockets, &pair.local.base); - try socket.send(agent.io, &pair.remote.address, msg); - }, - else => {}, - }; - agent.mutex.unlock(agent.io); +fn receive(agent: *Agent, socket: *Socket, index: usize) !struct { usize, Io.net.IncomingMessage } { + const buffer = try agent.buffer_pool.create(agent.allocator); + errdefer agent.buffer_pool.destroy(buffer); - try agent.io.sleep(.fromMilliseconds(200), .awake); - } + const incoming_message = try socket.receive(agent.io, &(buffer.*)); + return .{ index, incoming_message }; } fn batchSendConnectivityCheck(agent: *Agent) !void { @@ -626,11 +600,3 @@ fn batchSendConnectivityCheck(agent: *Agent) !void { else => {}, }; } - -fn receive(agent: *Agent, socket: *Socket, index: usize) !struct { usize, Io.net.IncomingMessage } { - const buffer = try agent.buffer_pool.create(agent.allocator); - errdefer agent.buffer_pool.destroy(buffer); - - const incoming_message = try socket.receive(agent.io, &(buffer.*)); - return .{ index, incoming_message }; -} diff --git a/src/ice/ice.zig b/src/ice/ice.zig index 4471c36..efff989 100644 --- a/src/ice/ice.zig +++ b/src/ice/ice.zig @@ -19,7 +19,7 @@ pub const CandidateType = enum { }; } - pub fn name(self: CandidateType) []const u8 { + pub fn toSlice(self: CandidateType) []const u8 { return switch (self) { .host => "host", .peer_reflexive => "prflx", @@ -28,7 +28,7 @@ pub const CandidateType = enum { }; } - pub fn fromSlice(slice: []const u8) !CandidateType { + pub fn fromSlice(slice: []const u8) error{InvalidCandidateType}!CandidateType { return if (std.mem.eql(u8, slice, "host")) .host else if (std.mem.eql(u8, slice, "prflx")) @@ -40,6 +40,48 @@ pub const CandidateType = enum { else error.InvalidCandidateType; } + + pub fn priority(self: CandidateType) u32 { + return (@as(u32, 1) << 24) * self.typePreference() + (1 << 8) * 65535 + 255; + } + + test "type preference" { + const types = [_]CandidateType{ .host, .peer_reflexive, .server_reflexive, .relayed }; + const preferences = [_]u8{ 126, 110, 100, 0 }; + + for (&types, &preferences) |t, preference| { + try std.testing.expectEqual(preference, t.typePreference()); + } + } + + test "toSlice" { + const types = [_]CandidateType{ .host, .peer_reflexive, .server_reflexive, .relayed }; + const names = [_][]const u8{ "host", "prflx", "srflx", "relay" }; + + for (&types, &names) |t, type_name| { + try std.testing.expectEqualStrings(type_name, t.toSlice()); + } + } + + test "fromSlice" { + const types = [_]CandidateType{ .host, .peer_reflexive, .server_reflexive, .relayed }; + const names = [_][]const u8{ "host", "prflx", "srflx", "relay" }; + + for (&types, &names) |t, name| { + try std.testing.expectEqual(t, try CandidateType.fromSlice(name)); + } + + try std.testing.expectError(error.InvalidCandidateType, CandidateType.fromSlice("unknown")); + } + + test "priority" { + const types = [_]CandidateType{ .host, .peer_reflexive, .server_reflexive, .relayed }; + const priorities = [_]u32{ 2130706431, 1862270975, 1694498815, 16777215 }; + + for (&types, &priorities) |t, type_priority| { + try std.testing.expectEqual(type_priority, t.priority()); + } + } }; pub const Candidate = struct { @@ -54,7 +96,7 @@ pub const Candidate = struct { .candidate_type = .host, .base = address, .address = address, - .priority = calculatePriority(.host), + .priority = CandidateType.host.priority(), }; candidate.calculateFoundation(); return candidate; @@ -65,7 +107,7 @@ pub const Candidate = struct { .candidate_type = .peer_reflexive, .base = base, .address = address, - .priority = calculatePriority(.peer_reflexive), + .priority = CandidateType.peer_reflexive.priority(), }; candidate.calculateFoundation(); return candidate; @@ -76,14 +118,17 @@ pub const Candidate = struct { const foundation = try std.fmt.parseUnsigned(u32, try nextToken(it.next()), 10); _ = try nextToken(it.next()); // component - _ = try nextToken(it.next()); // assume udp + + const transport = try nextToken(it.next()); + if (!std.mem.eql(u8, transport, "udp")) return error.UnsupportedTransport; + const priority = try std.fmt.parseUnsigned(u32, try nextToken(it.next()), 10); const address = try nextToken(it.next()); const port = try std.fmt.parseUnsigned(u16, try nextToken(it.next()), 10); const addr = try Io.net.IpAddress.parse(address, port); - _ = try nextToken(it.next()); // typ + if (!std.mem.eql(u8, try nextToken(it.next()), "typ")) return error.ParseError; const candidate_type = try CandidateType.fromSlice(try nextToken(it.next())); return .{ @@ -96,7 +141,7 @@ pub const Candidate = struct { } pub fn format(self: @This(), writer: *std.Io.Writer) !void { - try writer.print("{d:0>8} {} {s} {} ", .{ self.foundation, 1, "udp", self.priority }); + try writer.print("{d} {} {s} {} ", .{ self.foundation, 1, "udp", self.priority }); switch (self.address) { .ip4 => |ip| try writer.print("{d}.{d}.{d}.{d} {d} ", .{ ip.bytes[0], @@ -105,14 +150,14 @@ pub const Candidate = struct { ip.bytes[3], ip.port, }), - else => {}, + else => return error.WriteFailed, } - try writer.print("typ {s}", .{self.candidate_type.name()}); + try writer.print("typ {s}", .{self.candidate_type.toSlice()}); } fn calculateFoundation(self: *Candidate) void { var hasher = std.hash.Crc32.init(); - hasher.update(self.candidate_type.name()); + hasher.update(self.candidate_type.toSlice()); hasher.update(switch (self.address) { .ip4 => |addr| &addr.bytes, .ip6 => |addr| &addr.bytes, @@ -121,13 +166,70 @@ pub const Candidate = struct { self.foundation = hasher.final(); } - fn calculatePriority(t: CandidateType) u32 { - return (@as(u32, 1) << 24) * t.typePreference() + (1 << 8) * 65535 + 255; - } - inline fn nextToken(maybe_token: ?[]const u8) ![]const u8 { return if (maybe_token) |token| token else error.ParseError; } + + test "initHost" { + const ip_addr: Io.net.IpAddress = try .parse("192.168.8.10", 1000); + const candidate = initHost(ip_addr); + + try std.testing.expect(candidate.base.eql(&ip_addr)); + try std.testing.expect(candidate.address.eql(&ip_addr)); + try std.testing.expectEqual(.host, candidate.candidate_type); + try std.testing.expectEqual(CandidateType.host.priority(), candidate.priority); + try std.testing.expect(candidate.foundation != 0); + } + + test "initPeerReflexive" { + const ip_addr: Io.net.IpAddress = try .parse("192.168.8.10", 1000); + const reflexive_addr: Io.net.IpAddress = try .parse("192.168.6.20", 1000); + const candidate = initPeerReflexive(ip_addr, reflexive_addr); + + try std.testing.expect(candidate.base.eql(&ip_addr)); + try std.testing.expect(candidate.address.eql(&reflexive_addr)); + try std.testing.expectEqual(.peer_reflexive, candidate.candidate_type); + try std.testing.expectEqual(CandidateType.peer_reflexive.priority(), candidate.priority); + try std.testing.expect(candidate.foundation != 0); + } + + test "parse" { + const values = [_][]const u8{ + "1890 1 udp 998000 10.77.0.1 45909 typ prflx ufrag username", + "1890 1 tcp 998000 10.77.0.1 45909 typ prflx ufrag username", + }; + + { + const candidate = try parse(values[0]); + try std.testing.expectEqual(1890, candidate.foundation); + try std.testing.expectEqual(998000, candidate.priority); + try std.testing.expectEqual(.peer_reflexive, candidate.candidate_type); + + const expected_addr = Io.net.IpAddress{ .ip4 = .{ .bytes = [_]u8{ 10, 77, 0, 1 }, .port = 45909 } }; + try std.testing.expect(expected_addr.eql(&candidate.address)); + try std.testing.expect(expected_addr.eql(&candidate.base)); + } + { + try std.testing.expectError(error.UnsupportedTransport, parse(values[1])); + } + } + + test "format" { + const addr = Io.net.IpAddress{ .ip4 = .{ .bytes = [_]u8{ 10, 77, 0, 1 }, .port = 45909 } }; + const candidate: Candidate = .{ + .base = addr, + .address = addr, + .candidate_type = .peer_reflexive, + .priority = 998000, + .foundation = 1890, + }; + + var buffer: [64]u8 = undefined; + var w = Io.Writer.fixed(&buffer); + try candidate.format(&w); + + try std.testing.expectEqualStrings("1890 1 udp 998000 10.77.0.1 45909 typ prflx", w.buffered()); + } }; pub const Credentials = struct { @@ -184,5 +286,7 @@ pub const CandidatePair = struct { }; test { + std.testing.refAllDecls(@This()); + _ = @import("agent.zig"); } From 8d2b0fbfa80f6c525d4ca9563adec3e25484628b Mon Sep 17 00:00:00 2001 From: Billal GHILAS Date: Thu, 14 May 2026 07:31:00 +0100 Subject: [PATCH 3/8] report connection state update --- src/ice/agent.zig | 86 ++++++++++++++++++++++++++++------------------- src/ice/ice.zig | 2 ++ 2 files changed, 54 insertions(+), 34 deletions(-) diff --git a/src/ice/agent.zig b/src/ice/agent.zig index 907e8fa..aec9542 100644 --- a/src/ice/agent.zig +++ b/src/ice/agent.zig @@ -19,7 +19,7 @@ const keep_alive_interval: std.Io.Duration = .fromMilliseconds(200); io: Io, allocator: std.mem.Allocator, buffer_pool: std.heap.MemoryPool([max_message_size]u8), -state: State = .new, +connection_state: ice.ConnectionState = .new, // Stun related fields role: Role, @@ -39,7 +39,10 @@ group: Io.Group = .init, queue_buffer: [1]InternalEvent = undefined, queue: Io.Queue(InternalEvent) = undefined, -pub const State = enum { new, checking, connected, disconnected, failed }; +pub const Event = union(enum) { + data: []const u8, + connection_state: ice.ConnectionState, +}; const Role = enum { controlling, controlled }; @@ -48,6 +51,7 @@ const InternalEvent = union(enum) { message: struct { IpAddress, Io.net.IncomingMessage }, check_connectivity: void, data: []u8, + connection_state: ice.ConnectionState, }; const StunRequest = struct { @@ -95,10 +99,11 @@ pub fn deinit(agent: *Agent) void { } pub fn setRemoteCredentials(agent: *Agent, credentials: ice.Credentials) !void { - switch (agent.state) { + switch (agent.connection_state) { .new => { agent.remote_credentials = try credentials.dupe(agent.allocator); - agent.state = .checking; + agent.connection_state = .checking; + try agent.queue.putOne(agent.io, .{ .connection_state = agent.connection_state }); try agent.group.concurrent(agent.io, startConnectivityChecks, .{agent}); }, else => return error.CredentialsAlreadySet, @@ -106,7 +111,7 @@ pub fn setRemoteCredentials(agent: *Agent, credentials: ice.Credentials) !void { } pub fn addRemoteCandidate(agent: *Agent, remote_candidate: Candidate) !void { - switch (agent.state) { + switch (agent.connection_state) { .new => try agent.doAddRemoteCandidate(remote_candidate), .checking => try agent.queue.putOne(agent.io, .{ .add_candidate = remote_candidate }), else => {}, @@ -120,7 +125,7 @@ pub fn gatherCandidates(agent: *Agent) !void { } /// Poll for events -pub fn poll(agent: *Agent) !?[]u8 { +pub fn poll(agent: *Agent) !Event { const io = agent.io; while (agent.queue.getOne(io)) |event| switch (event) { @@ -134,40 +139,51 @@ pub fn poll(agent: *Agent) !?[]u8 { } } else { for (agent.pairs.items) |*candidate_pair| if (candidate_pair.remote.address.eql(&s.@"1".from)) { - return s.@"1".data; + return .{ .data = s.@"1".data }; }; + + std.log.warn("Drop non stun message from unknown remote candidate: {f}", .{s.@"1".from}); continue; } - try agent.maybeSetNominatedCandidate(); + if (try agent.maybeSetNominatedCandidate()) { + return .{ .connection_state = agent.connection_state }; + } }, .check_connectivity => try agent.batchSendConnectivityCheck(), - .data => |data| return data, + .data => |data| return .{ .data = data }, + .connection_state => |state| return .{ .connection_state = state }, } else |err| switch (err) { error.Canceled => return error.Canceled, - else => {}, + else => unreachable, } +} - return null; +pub fn destroyPacket(agent: *Agent, data: []const u8) void { + agent.buffer_pool.destroy(@ptrCast(@alignCast(@constCast(data)))); } fn initSockets(agent: *Agent) !void { - agent.sockets = try agent.allocator.alloc(Io.net.Socket, agent.candidates.items.len); - var initialized: usize = 0; - errdefer { - for (0..initialized) |idx| agent.sockets[idx].close(agent.io); - agent.allocator.free(agent.sockets); - } + var sockets: std.ArrayList(Io.net.Socket) = try .initCapacity(agent.allocator, agent.candidates.items.len); + errdefer sockets.deinit(agent.allocator); - for (agent.candidates.items) |*candidate| { - agent.sockets[initialized] = try candidate.address.bind( - agent.io, - .{ .mode = .dgram, .protocol = .udp }, - ); - candidate.base = agent.sockets[initialized].address; - candidate.address = agent.sockets[initialized].address; - initialized += 1; + const candidates = agent.candidates.items; + var index: usize = 0; + + while (true) { + if (index >= candidates.len) break; + const socket = candidates[index].address.bind(agent.io, .{ .mode = .dgram, .protocol = .udp }) catch { + _ = agent.candidates.swapRemove(index); + continue; + }; + + sockets.appendAssumeCapacity(socket); + candidates[index].base = socket.address; + candidates[index].address = socket.address; + index += 1; } + + agent.sockets = try sockets.toOwnedSlice(agent.allocator); } fn calculatePairPriority(l: u32, r: u32, role: Role) u64 { @@ -218,7 +234,7 @@ fn gatherHostCandidates(agent: *Agent) !void { }; } -fn doAddRemoteCandidate(agent: *Agent, remote_candidate: Candidate) !void { +fn doAddRemoteCandidate(agent: *Agent, remote_candidate: Candidate) std.mem.Allocator.Error!void { for (agent.candidates.items) |candidate| { for (agent.pairs.items) |*pair| if (pair.local.base.eql(&candidate.base) and pair.remote.address.eql(&remote_candidate.address)) @@ -440,12 +456,12 @@ fn findCandidatePair(agent: *Agent, local: *const IpAddress, remote: *const IpAd return null; } -fn maybeSetNominatedCandidate(agent: *Agent) !void { - if (agent.role == .controlling or agent.nominated_pair != null) return; +fn maybeSetNominatedCandidate(agent: *Agent) !bool { + if (agent.role == .controlling or agent.nominated_pair != null) return false; for (agent.pairs.items) |candidate_pair| if (candidate_pair.state.nominated) { agent.nominated_pair = candidate_pair; - agent.state = .connected; + agent.connection_state = .connected; agent.group.cancel(agent.io); // Clean up and listen on socket @@ -455,8 +471,10 @@ fn maybeSetNominatedCandidate(agent: *Agent) !void { agent.pairs.items[0] = candidate_pair; try agent.group.concurrent(agent.io, listen, .{agent}); - break; + return true; }; + + return false; } // ============== Io related function ====================== @@ -465,7 +483,7 @@ const Receive = union(enum) { }; const ListenEvent = union(enum) { - message: Io.net.Socket.ReceiveTimeoutError!Io.net.IncomingMessage, + message: Socket.ReceiveTimeoutError!Io.net.IncomingMessage, keep_alive: Io.Cancelable!void, }; @@ -499,9 +517,9 @@ fn doListen(agent: *Agent) !void { const msg = maybe_msg catch |err| switch (err) { error.Canceled => return error.Canceled, error.Timeout => { - if (agent.state != .disconnected) { - Logger.warn("Agent state transitioned to disconnected", .{}); - agent.state = .disconnected; + if (agent.connection_state != .disconnected) { + agent.connection_state = .disconnected; + try agent.queue.putOne(agent.io, .{ .connection_state = .disconnected }); } continue; }, diff --git a/src/ice/ice.zig b/src/ice/ice.zig index efff989..b97ed5b 100644 --- a/src/ice/ice.zig +++ b/src/ice/ice.zig @@ -285,6 +285,8 @@ pub const CandidatePair = struct { } }; +pub const ConnectionState = enum { new, checking, connected, completed, disconnected, failed }; + test { std.testing.refAllDecls(@This()); From b4285fcce78dfb67ed3b067795f79db59a341205 Mon Sep 17 00:00:00 2001 From: Billal GHILAS Date: Thu, 14 May 2026 13:28:07 +0100 Subject: [PATCH 4/8] Use callbacks for data and connection state --- src/ice/agent.zig | 286 +++++++++++++++++++++------------------------- 1 file changed, 133 insertions(+), 153 deletions(-) diff --git a/src/ice/agent.zig b/src/ice/agent.zig index aec9542..b37995e 100644 --- a/src/ice/agent.zig +++ b/src/ice/agent.zig @@ -4,6 +4,7 @@ const stun = @import("stun"); const ice = @import("ice.zig"); const Io = std.Io; +const Allocator = std.mem.Allocator; const Socket = Io.net.Socket; const IpAddress = Io.net.IpAddress; const Candidate = ice.Candidate; @@ -15,10 +16,17 @@ const max_message_size = 1500; const max_binding_requests: usize = 7; const connectivity_check_interval: std.Io.Duration = .fromMilliseconds(200); const keep_alive_interval: std.Io.Duration = .fromMilliseconds(200); +const disconnect_timeout: Io.Clock.Duration = .{ .clock = .awake, .raw = .fromSeconds(5) }; + +pub const AgentConfig = struct { + onConnectionState: *const fn (*Agent, ice.ConnectionState) void, + onData: *const fn (*Agent, []const u8) void, +}; io: Io, -allocator: std.mem.Allocator, +allocator: Allocator, buffer_pool: std.heap.MemoryPool([max_message_size]u8), +config: AgentConfig, connection_state: ice.ConnectionState = .new, // Stun related fields @@ -36,8 +44,6 @@ nominated_pair: ?CandidatePair = null, // Io handling group: Io.Group = .init, -queue_buffer: [1]InternalEvent = undefined, -queue: Io.Queue(InternalEvent) = undefined, pub const Event = union(enum) { data: []const u8, @@ -46,14 +52,6 @@ pub const Event = union(enum) { const Role = enum { controlling, controlled }; -const InternalEvent = union(enum) { - add_candidate: Candidate, - message: struct { IpAddress, Io.net.IncomingMessage }, - check_connectivity: void, - data: []u8, - connection_state: ice.ConnectionState, -}; - const StunRequest = struct { username: []const u8 = &.{}, ice_controlled: ?u64 = null, @@ -68,97 +66,66 @@ const PendingRequest = struct { target: Io.net.IpAddress, }; -pub fn init(agent: *Agent, io: Io, allocator: std.mem.Allocator) !void { - agent.* = .{ +pub fn init(io: Io, allocator: Allocator, config: AgentConfig) !Agent { + return .{ .io = io, .allocator = allocator, .buffer_pool = .empty, .role = .controlled, .credentials = try (ice.Credentials{ .username = "test", .password = "test" }).dupe(allocator), .tie_breaker = generateTieBeaker(io), + .config = config, }; - - agent.queue = .init(&agent.queue_buffer); } pub fn deinit(agent: *Agent) void { const io = agent.io; const allocator = agent.allocator; + agent.group.cancel(io); - agent.buffer_pool.deinit(allocator); - agent.candidates.deinit(allocator); - agent.pairs.deinit(allocator); for (agent.sockets) |socket| socket.close(io); allocator.free(agent.sockets); + agent.candidates.deinit(allocator); + agent.pairs.deinit(allocator); + agent.pending_requests.deinit(allocator); agent.credentials.deinit(allocator); if (agent.remote_credentials) |*credens| credens.deinit(allocator); - agent.queue.close(io); - agent.group.cancel(io); + agent.buffer_pool.deinit(allocator); } +/// Set remote credentials +/// +/// Calling this function will trigger connectivity checks. `gatherCandidates` should be called first. pub fn setRemoteCredentials(agent: *Agent, credentials: ice.Credentials) !void { switch (agent.connection_state) { .new => { agent.remote_credentials = try credentials.dupe(agent.allocator); - agent.connection_state = .checking; - try agent.queue.putOne(agent.io, .{ .connection_state = agent.connection_state }); - try agent.group.concurrent(agent.io, startConnectivityChecks, .{agent}); + agent.setConnectionState(.checking); }, else => return error.CredentialsAlreadySet, } } pub fn addRemoteCandidate(agent: *Agent, remote_candidate: Candidate) !void { + // TODO: Add mutex switch (agent.connection_state) { - .new => try agent.doAddRemoteCandidate(remote_candidate), - .checking => try agent.queue.putOne(agent.io, .{ .add_candidate = remote_candidate }), + .new, .checking, .connected => try agent.doAddRemoteCandidate(remote_candidate), else => {}, } } +/// Start gathering candidates and start inner event handler. +/// +/// This function should be called first after initializing the agent. pub fn gatherCandidates(agent: *Agent) !void { try agent.gatherHostCandidates(); try agent.initSockets(); - try agent.group.concurrent(agent.io, listenForConnectivityChecks, .{agent}); -} - -/// Poll for events -pub fn poll(agent: *Agent) !Event { - const io = agent.io; - - while (agent.queue.getOne(io)) |event| switch (event) { - .add_candidate => |remote_candidate| try agent.addRemoteCandidate(remote_candidate), - .message => |s| { - defer agent.buffer_pool.destroy(@ptrCast(@alignCast(s.@"1".data.ptr))); - if (stun.isMessage(s.@"1".data)) { - if (try agent.handleReceivedMessage(s.@"0", s.@"1")) |response| { - defer agent.buffer_pool.destroy(@ptrCast(@alignCast(@constCast(response.ptr)))); - try findSocket(agent.sockets, &s.@"0").send(io, &s.@"1".from, response); - } - } else { - for (agent.pairs.items) |*candidate_pair| if (candidate_pair.remote.address.eql(&s.@"1".from)) { - return .{ .data = s.@"1".data }; - }; - - std.log.warn("Drop non stun message from unknown remote candidate: {f}", .{s.@"1".from}); - continue; - } - - if (try agent.maybeSetNominatedCandidate()) { - return .{ .connection_state = agent.connection_state }; - } - }, - .check_connectivity => try agent.batchSendConnectivityCheck(), - .data => |data| return .{ .data = data }, - .connection_state => |state| return .{ .connection_state = state }, - } else |err| switch (err) { - error.Canceled => return error.Canceled, - else => unreachable, - } + try agent.group.concurrent(agent.io, innerEventHandlerWrapper, .{agent}); } +/// Free the buffer and return to the pool. pub fn destroyPacket(agent: *Agent, data: []const u8) void { agent.buffer_pool.destroy(@ptrCast(@alignCast(@constCast(data)))); } @@ -234,7 +201,7 @@ fn gatherHostCandidates(agent: *Agent) !void { }; } -fn doAddRemoteCandidate(agent: *Agent, remote_candidate: Candidate) std.mem.Allocator.Error!void { +fn doAddRemoteCandidate(agent: *Agent, remote_candidate: Candidate) Allocator.Error!void { for (agent.candidates.items) |candidate| { for (agent.pairs.items) |*pair| if (pair.local.base.eql(&candidate.base) and pair.remote.address.eql(&remote_candidate.address)) @@ -461,134 +428,147 @@ fn maybeSetNominatedCandidate(agent: *Agent) !bool { for (agent.pairs.items) |candidate_pair| if (candidate_pair.state.nominated) { agent.nominated_pair = candidate_pair; - agent.connection_state = .connected; - agent.group.cancel(agent.io); - - // Clean up and listen on socket - agent.candidates.deinit(agent.allocator); - for (agent.sockets) |*socket| if (!socket.address.eql(&candidate_pair.local.base)) socket.close(agent.io); - try agent.pairs.shrinkAndFreePrecise(agent.allocator, 1); - agent.pairs.items[0] = candidate_pair; - - try agent.group.concurrent(agent.io, listen, .{agent}); return true; }; return false; } +fn setConnectionState(agent: *Agent, new_state: ice.ConnectionState) void { + agent.connection_state = new_state; + agent.config.onConnectionState(agent, new_state); +} + // ============== Io related function ====================== -const Receive = union(enum) { - message: anyerror!struct { usize, Io.net.IncomingMessage }, +const MessageError = (Allocator.Error || Socket.ReceiveTimeoutError); + +const Message = struct { + socket: *const Socket, + incoming_message: Io.net.IncomingMessage, }; -const ListenEvent = union(enum) { - message: Socket.ReceiveTimeoutError!Io.net.IncomingMessage, +const InnerEvent = union(enum) { + message: MessageError!Message, + connectivity_check: Io.Cancelable!void, + send_message: (Allocator.Error || Socket.SendError)!void, + complete: Io.Cancelable!void, + // message received from the nominated peer + data_message: MessageError!Message, keep_alive: Io.Cancelable!void, }; -fn listen(agent: *Agent) !void { - agent.doListen() catch |err| switch (err) { +fn innerEventHandlerWrapper(agent: *Agent) !void { + agent.innerEventHandler() catch |err| switch (err) { error.Canceled => return error.Canceled, - else => {}, + else => |e| std.log.err("Error occurred in event handler: {}", .{e}), }; } -fn doListen(agent: *Agent) !void { - const ListenSelect = Io.Select(ListenEvent); - var listen_event_buffer: [1]ListenEvent = undefined; - var select = ListenSelect.init(agent.io, &listen_event_buffer); +fn innerEventHandler(agent: *Agent) !void { + const io = agent.io; + const Select = Io.Select(InnerEvent); + + var queue: [1]InnerEvent = undefined; + var select = Select.init(agent.io, &queue); defer select.cancelDiscard(); - const socket = findSocket(agent.sockets, &agent.nominated_pair.?.local.base); - const receive_timeout: Io.Timeout = .{ .duration = .{ .clock = .awake, .raw = .fromSeconds(5) } }; + for (agent.sockets) |*socket| { + select.async(.message, receiveTimeout, .{ agent, socket, .none }); + } + select.async(.connectivity_check, Io.sleep, .{ io, connectivity_check_interval, .awake }); - const buffer = try agent.buffer_pool.create(agent.allocator); - defer agent.buffer_pool.destroy(buffer); + var nominated_socket: Socket = undefined; - var stun_indication: [20]u8 = undefined; - const dest = &agent.nominated_pair.?.remote.address; + while (true) switch (try select.await()) { + .connectivity_check => |timeout| { + try timeout; + switch (agent.connection_state) { + .completed, .failed => {}, + else => { + try agent.batchSendConnectivityCheck(); + select.async(.connectivity_check, Io.sleep, .{ io, connectivity_check_interval, .awake }); + }, + } + }, + .message => |result| { + const message = try result; - select.async(.message, Io.net.Socket.receiveTimeout, .{ socket, agent.io, &(buffer.*), receive_timeout }); - select.async(.keep_alive, Io.sleep, .{ agent.io, Io.Duration.fromSeconds(2), Io.Clock.awake }); + const data = message.incoming_message.data; + const sender = message.incoming_message.from; - while (true) switch (try select.await()) { - .message => |maybe_msg| { - const msg = maybe_msg catch |err| switch (err) { - error.Canceled => return error.Canceled, + if (stun.isMessage(data)) { + defer agent.destroyPacket(data); + if (try agent.handleReceivedMessage(message.socket.address, message.incoming_message)) |response| + select.async(.send_message, send, .{ agent, message.socket, &sender, response }); + + if (try agent.maybeSetNominatedCandidate()) { + agent.setConnectionState(.connected); + nominated_socket = message.socket.*; + + select.async(.data_message, receiveTimeout, .{ agent, &nominated_socket, .{ .duration = disconnect_timeout } }); + select.async(.keep_alive, Io.sleep, .{ io, keep_alive_interval, .awake }); + select.async(.complete, Io.sleep, .{ io, .fromSeconds(3), .awake }); + continue; + } + } else { + for (agent.pairs.items) |*candidate_pair| if (candidate_pair.remote.address.eql(&sender)) { + agent.config.onData(agent, data); + } else { + std.log.warn("Drop non stun message from unknown remote candidate: {f}", .{sender}); + agent.destroyPacket(data); + }; + } + + select.async(.message, receiveTimeout, .{ agent, message.socket, .none }); + }, + .send_message => |result| result catch |err| std.log.err("failed to send response: {}", .{err}), + .data_message => |result| { + const message = result catch |err| switch (err) { error.Timeout => { - if (agent.connection_state != .disconnected) { - agent.connection_state = .disconnected; - try agent.queue.putOne(agent.io, .{ .connection_state = .disconnected }); - } + if (agent.connection_state != .disconnected) agent.setConnectionState(.disconnected); continue; }, - else => return, + else => |e| return e, }; - select.async(.message, Io.net.Socket.receiveTimeout, .{ socket, agent.io, &(buffer.*), receive_timeout }); - if (stun.isMessage(msg.data)) continue; - try agent.queue.putOne(agent.io, .{ .data = msg.data }); + if (stun.isMessage(message.incoming_message.data)) + agent.buffer_pool.destroy(@ptrCast(@alignCast(message.incoming_message.data.ptr))) + else + agent.config.onData(agent, message.incoming_message.data); + + select.async(.data_message, receiveTimeout, .{ agent, message.socket, .{ .duration = disconnect_timeout } }); }, .keep_alive => |timeout| { try timeout; - select.async(.keep_alive, Io.sleep, .{ agent.io, keep_alive_interval, Io.Clock.awake }); - try socket.send(agent.io, dest, try buildIndicationRequest(&stun_indication)); - }, - }; -} - -fn startConnectivityChecks(agent: *Agent) !void { - while (true) { - agent.queue.putOne(agent.io, .check_connectivity) catch |err| switch (err) { - error.Canceled => return error.Canceled, - else => return, - }; - try agent.io.sleep(connectivity_check_interval, .awake); - } -} + select.async(.keep_alive, Io.sleep, .{ io, keep_alive_interval, .awake }); -fn listenForConnectivityChecks(agent: *Agent) !void { - agent.doListenForConnectivityChecks() catch |err| switch (err) { - error.Canceled => return error.Canceled, - else => {}, + var buffer: [20]u8 = undefined; + try nominated_socket.send(agent.io, &agent.nominated_pair.?.remote.address, try buildIndicationRequest(&buffer)); + }, + .complete => |result| { + try result; + for (agent.sockets) |*socket| if (!socket.address.eql(&nominated_socket.address)) socket.close(io); + agent.sockets = try agent.allocator.realloc(agent.sockets, 1); + agent.sockets[0] = nominated_socket; + + agent.pairs.clearAndFree(agent.allocator); + agent.pending_requests.clearAndFree(agent.allocator); + agent.setConnectionState(.completed); + }, }; } -fn doListenForConnectivityChecks(agent: *Agent) !void { - const IncomingMessageSelect = Io.Select(Receive); - - var queue: [4]Receive = undefined; - var select = IncomingMessageSelect.init(agent.io, &queue); - defer select.cancelDiscard(); - - for (agent.sockets, 0..) |*socket, idx| { - select.async(.message, receive, .{ agent, socket, idx }); - } - - while (true) { - const result = try select.await(); - - const index, const incoming_message = result.message catch |err| switch (err) { - error.Canceled => return error.Canceled, - else => |e| { - std.log.err("An error occurred when listening on socket: {}", .{e}); - continue; - }, - }; - - const socket = &agent.sockets[index]; - try agent.queue.putOne(agent.io, .{ .message = .{ socket.address, incoming_message } }); - select.async(.message, receive, .{ agent, socket, index }); - } -} - -fn receive(agent: *Agent, socket: *Socket, index: usize) !struct { usize, Io.net.IncomingMessage } { +fn receiveTimeout(agent: *Agent, socket: *const Socket, timeout: Io.Timeout) !Message { const buffer = try agent.buffer_pool.create(agent.allocator); errdefer agent.buffer_pool.destroy(buffer); - const incoming_message = try socket.receive(agent.io, &(buffer.*)); - return .{ index, incoming_message }; + const incoming_message = try socket.receiveTimeout(agent.io, &(buffer.*), timeout); + return .{ .incoming_message = incoming_message, .socket = socket }; +} + +fn send(agent: *Agent, socket: *const Socket, address: *const IpAddress, buffer: []const u8) (Allocator.Error || Socket.SendError)!void { + defer agent.buffer_pool.destroy(@ptrCast(@alignCast(@constCast(buffer)))); + try socket.send(agent.io, address, buffer); } fn batchSendConnectivityCheck(agent: *Agent) !void { From 0bca001a080a59517e82966d9cc56f10738e553a Mon Sep 17 00:00:00 2001 From: Billal GHILAS Date: Fri, 15 May 2026 07:25:37 +0100 Subject: [PATCH 5/8] set connection to failed after some timeout --- src/ice/agent.zig | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/ice/agent.zig b/src/ice/agent.zig index b37995e..392455c 100644 --- a/src/ice/agent.zig +++ b/src/ice/agent.zig @@ -17,6 +17,7 @@ const max_binding_requests: usize = 7; const connectivity_check_interval: std.Io.Duration = .fromMilliseconds(200); const keep_alive_interval: std.Io.Duration = .fromMilliseconds(200); const disconnect_timeout: Io.Clock.Duration = .{ .clock = .awake, .raw = .fromSeconds(5) }; +const failing_timeout: Io.Clock.Duration = .{ .clock = .awake, .raw = .fromSeconds(25) }; pub const AgentConfig = struct { onConnectionState: *const fn (*Agent, ice.ConnectionState) void, @@ -524,9 +525,17 @@ fn innerEventHandler(agent: *Agent) !void { .send_message => |result| result catch |err| std.log.err("failed to send response: {}", .{err}), .data_message => |result| { const message = result catch |err| switch (err) { - error.Timeout => { - if (agent.connection_state != .disconnected) agent.setConnectionState(.disconnected); - continue; + error.Timeout => switch (agent.connection_state) { + .connected, .completed => { + agent.setConnectionState(.disconnected); + select.async(.data_message, receiveTimeout, .{ agent, &nominated_socket, .{ .duration = failing_timeout } }); + continue; + }, + .disconnected => { + agent.setConnectionState(.failed); + return; + }, + else => unreachable, }, else => |e| return e, }; @@ -547,13 +556,7 @@ fn innerEventHandler(agent: *Agent) !void { }, .complete => |result| { try result; - for (agent.sockets) |*socket| if (!socket.address.eql(&nominated_socket.address)) socket.close(io); - agent.sockets = try agent.allocator.realloc(agent.sockets, 1); - agent.sockets[0] = nominated_socket; - - agent.pairs.clearAndFree(agent.allocator); - agent.pending_requests.clearAndFree(agent.allocator); - agent.setConnectionState(.completed); + try agent.markConnectionCompleted(nominated_socket); }, }; } @@ -598,3 +601,13 @@ fn batchSendConnectivityCheck(agent: *Agent) !void { else => {}, }; } + +fn markConnectionCompleted(agent: *Agent, nominated_socket: Socket) !void { + for (agent.sockets) |*socket| if (!socket.address.eql(&nominated_socket.address)) socket.close(agent.io); + agent.sockets = try agent.allocator.realloc(agent.sockets, 1); + agent.sockets[0] = nominated_socket; + + agent.pairs.clearAndFree(agent.allocator); + agent.pending_requests.clearAndFree(agent.allocator); + agent.setConnectionState(.completed); +} From 35c6b121087f7e4e4b7fdca5be3f1c3419bd23c9 Mon Sep 17 00:00:00 2001 From: Billal GHILAS Date: Fri, 15 May 2026 21:15:26 +0100 Subject: [PATCH 6/8] send data --- src/ice/agent.zig | 73 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 19 deletions(-) diff --git a/src/ice/agent.zig b/src/ice/agent.zig index 392455c..8d4c8b2 100644 --- a/src/ice/agent.zig +++ b/src/ice/agent.zig @@ -15,7 +15,7 @@ const Logger = std.log.scoped(.ice); const max_message_size = 1500; const max_binding_requests: usize = 7; const connectivity_check_interval: std.Io.Duration = .fromMilliseconds(200); -const keep_alive_interval: std.Io.Duration = .fromMilliseconds(200); +const keep_alive_interval: std.Io.Duration = .fromSeconds(2); const disconnect_timeout: Io.Clock.Duration = .{ .clock = .awake, .raw = .fromSeconds(5) }; const failing_timeout: Io.Clock.Duration = .{ .clock = .awake, .raw = .fromSeconds(25) }; @@ -41,14 +41,23 @@ sockets: []Io.net.Socket = &.{}, candidates: std.ArrayList(Candidate) = .empty, pairs: std.ArrayList(CandidatePair) = .empty, pending_requests: std.ArrayList(PendingRequest) = .empty, -nominated_pair: ?CandidatePair = null, +nominated_pair: ?SelectedPair = null, // Io handling group: Io.Group = .init, -pub const Event = union(enum) { - data: []const u8, - connection_state: ice.ConnectionState, +const SelectedPair = struct { + pair: CandidatePair, + socket: Socket, + + fn keep_alive(self: *const SelectedPair, io: Io) !void { + var buffer: [20]u8 = undefined; + try self.socket.send(io, &self.pair.remote.address, try buildIndicationRequest(&buffer)); + } + + inline fn sendData(self: *const SelectedPair, io: Io, data: []const u8) Socket.SendError!void { + try self.socket.send(io, &self.pair.remote.address, data); + } }; const Role = enum { controlling, controlled }; @@ -126,6 +135,13 @@ pub fn gatherCandidates(agent: *Agent) !void { try agent.group.concurrent(agent.io, innerEventHandlerWrapper, .{agent}); } +pub fn sendData(agent: *const Agent, data: []const u8) Socket.SendError!void { + switch (agent.connection_state) { + .connected, .completed => try agent.nominated_pair.?.sendData(agent.io, data), + else => std.log.debug("Agent not connected: ignore send request", .{}), + } +} + /// Free the buffer and return to the pool. pub fn destroyPacket(agent: *Agent, data: []const u8) void { agent.buffer_pool.destroy(@ptrCast(@alignCast(@constCast(data)))); @@ -176,6 +192,13 @@ fn generateTrasactionId(io: Io) u96 { } fn gatherHostCandidates(agent: *Agent) !void { + switch (@import("builtin").os.tag) { + .linux => try agent.linuxGatherHostCandidates(), + else => {}, + } +} + +fn linuxGatherHostCandidates(agent: *Agent) !void { var interfaces: [*c]c.ifaddrs = undefined; if (c.getifaddrs(&interfaces) != 0) { return error.GetIfAddrsFailed; @@ -186,11 +209,11 @@ fn gatherHostCandidates(agent: *Agent) !void { while (it) |p_ifa| : (it = p_ifa.*.ifa_next) if (p_ifa.*.ifa_addr) |addr| { switch (addr.*.sa_family) { c.AF_INET => { + const c_flags: u16 = @truncate(p_ifa.*.ifa_flags); + const flags: std.os.linux.IFF = @bitCast(c_flags); + if (flags.LOOPBACK) continue; + const sin: *const c.sockaddr_in = @ptrCast(@alignCast(addr)); - // Ignore loopback addresses. - if (sin.sin_addr.s_addr == std.mem.nativeToBig(u32, 0x7f000001)) { - continue; - } const ip_addr: Io.net.IpAddress = .{ .ip4 = .{ .bytes = std.mem.toBytes(sin.sin_addr.s_addr), .port = 0 }, @@ -502,13 +525,23 @@ fn innerEventHandler(agent: *Agent) !void { if (try agent.handleReceivedMessage(message.socket.address, message.incoming_message)) |response| select.async(.send_message, send, .{ agent, message.socket, &sender, response }); - if (try agent.maybeSetNominatedCandidate()) { + const candidate_pair: ?CandidatePair = blk: { + if (agent.role == .controlling or agent.nominated_pair != null) break :blk null; + for (agent.pairs.items) |candidate_pair| if (candidate_pair.state.nominated) break :blk candidate_pair; + break :blk null; + }; + + if (candidate_pair != null) { + agent.nominated_pair = .{ + .pair = candidate_pair.?, + .socket = message.socket.*, + }; + nominated_socket = agent.nominated_pair.?.socket; agent.setConnectionState(.connected); - nominated_socket = message.socket.*; - select.async(.data_message, receiveTimeout, .{ agent, &nominated_socket, .{ .duration = disconnect_timeout } }); - select.async(.keep_alive, Io.sleep, .{ io, keep_alive_interval, .awake }); select.async(.complete, Io.sleep, .{ io, .fromSeconds(3), .awake }); + select.async(.keep_alive, Io.sleep, .{ io, keep_alive_interval, .awake }); + select.async(.data_message, receiveTimeout, .{ agent, &nominated_socket, .{ .duration = disconnect_timeout } }); continue; } } else { @@ -552,11 +585,11 @@ fn innerEventHandler(agent: *Agent) !void { select.async(.keep_alive, Io.sleep, .{ io, keep_alive_interval, .awake }); var buffer: [20]u8 = undefined; - try nominated_socket.send(agent.io, &agent.nominated_pair.?.remote.address, try buildIndicationRequest(&buffer)); + try nominated_socket.send(agent.io, &agent.nominated_pair.?.pair.remote.address, try buildIndicationRequest(&buffer)); }, .complete => |result| { try result; - try agent.markConnectionCompleted(nominated_socket); + try agent.markConnectionCompleted(); }, }; } @@ -602,10 +635,12 @@ fn batchSendConnectivityCheck(agent: *Agent) !void { }; } -fn markConnectionCompleted(agent: *Agent, nominated_socket: Socket) !void { - for (agent.sockets) |*socket| if (!socket.address.eql(&nominated_socket.address)) socket.close(agent.io); - agent.sockets = try agent.allocator.realloc(agent.sockets, 1); - agent.sockets[0] = nominated_socket; +fn markConnectionCompleted(agent: *Agent) !void { + const addr = agent.nominated_pair.?.socket.address; + for (agent.sockets) |*socket| if (!socket.address.eql(&addr)) socket.close(agent.io); + + agent.allocator.free(agent.sockets); + agent.sockets = &.{}; agent.pairs.clearAndFree(agent.allocator); agent.pending_requests.clearAndFree(agent.allocator); From 123a934e254273c0922620b59356dbb19c1963d3 Mon Sep 17 00:00:00 2001 From: Billal GHILAS Date: Sat, 16 May 2026 09:20:18 +0100 Subject: [PATCH 7/8] refactor getifaddrs --- src/ice/agent.zig | 65 ++++++++++++++++++++--------------------------- src/ice/c.h | 4 ++- 2 files changed, 31 insertions(+), 38 deletions(-) diff --git a/src/ice/agent.zig b/src/ice/agent.zig index 8d4c8b2..cd857aa 100644 --- a/src/ice/agent.zig +++ b/src/ice/agent.zig @@ -11,6 +11,7 @@ const Candidate = ice.Candidate; const CandidatePair = ice.CandidatePair; const Agent = @This(); const Logger = std.log.scoped(.ice); +const linux = std.os.linux; const max_message_size = 1500; const max_binding_requests: usize = 7; @@ -148,12 +149,15 @@ pub fn destroyPacket(agent: *Agent, data: []const u8) void { } fn initSockets(agent: *Agent) !void { - var sockets: std.ArrayList(Io.net.Socket) = try .initCapacity(agent.allocator, agent.candidates.items.len); - errdefer sockets.deinit(agent.allocator); - const candidates = agent.candidates.items; var index: usize = 0; + var sockets: std.ArrayList(Io.net.Socket) = try .initCapacity(agent.allocator, agent.candidates.items.len); + errdefer { + for (0..index) |idx| sockets.items[idx].close(agent.io); + sockets.deinit(agent.allocator); + } + while (true) { if (index >= candidates.len) break; const socket = candidates[index].address.bind(agent.io, .{ .mode = .dgram, .protocol = .udp }) catch { @@ -207,17 +211,16 @@ fn linuxGatherHostCandidates(agent: *Agent) !void { var it = interfaces; while (it) |p_ifa| : (it = p_ifa.*.ifa_next) if (p_ifa.*.ifa_addr) |addr| { - switch (addr.*.sa_family) { - c.AF_INET => { + const sockaddr: linux.sockaddr = @bitCast(addr.*); + + switch (sockaddr.family) { + linux.AF.INET => { const c_flags: u16 = @truncate(p_ifa.*.ifa_flags); - const flags: std.os.linux.IFF = @bitCast(c_flags); + const flags: linux.IFF = @bitCast(c_flags); if (flags.LOOPBACK) continue; - const sin: *const c.sockaddr_in = @ptrCast(@alignCast(addr)); - - const ip_addr: Io.net.IpAddress = .{ - .ip4 = .{ .bytes = std.mem.toBytes(sin.sin_addr.s_addr), .port = 0 }, - }; + const in: linux.sockaddr.in = @bitCast(sockaddr); + const ip_addr: Io.net.IpAddress = .{ .ip4 = .{ .bytes = std.mem.toBytes(in.addr), .port = 0 } }; try agent.candidates.append(agent.allocator, .initHost(ip_addr)); }, else => {}, @@ -447,23 +450,12 @@ fn findCandidatePair(agent: *Agent, local: *const IpAddress, remote: *const IpAd return null; } -fn maybeSetNominatedCandidate(agent: *Agent) !bool { - if (agent.role == .controlling or agent.nominated_pair != null) return false; - - for (agent.pairs.items) |candidate_pair| if (candidate_pair.state.nominated) { - agent.nominated_pair = candidate_pair; - return true; - }; - - return false; -} - fn setConnectionState(agent: *Agent, new_state: ice.ConnectionState) void { agent.connection_state = new_state; agent.config.onConnectionState(agent, new_state); } -// ============== Io related function ====================== +// ============== Io related functions ====================== const MessageError = (Allocator.Error || Socket.ReceiveTimeoutError); const Message = struct { @@ -474,7 +466,6 @@ const Message = struct { const InnerEvent = union(enum) { message: MessageError!Message, connectivity_check: Io.Cancelable!void, - send_message: (Allocator.Error || Socket.SendError)!void, complete: Io.Cancelable!void, // message received from the nominated peer data_message: MessageError!Message, @@ -496,10 +487,9 @@ fn innerEventHandler(agent: *Agent) !void { var select = Select.init(agent.io, &queue); defer select.cancelDiscard(); - for (agent.sockets) |*socket| { - select.async(.message, receiveTimeout, .{ agent, socket, .none }); - } select.async(.connectivity_check, Io.sleep, .{ io, connectivity_check_interval, .awake }); + for (agent.sockets) |*socket| + select.async(.message, receiveTimeout, .{ agent, socket, .none }); var nominated_socket: Socket = undefined; @@ -509,8 +499,8 @@ fn innerEventHandler(agent: *Agent) !void { switch (agent.connection_state) { .completed, .failed => {}, else => { - try agent.batchSendConnectivityCheck(); select.async(.connectivity_check, Io.sleep, .{ io, connectivity_check_interval, .awake }); + agent.batchSendConnectivityCheck() catch |err| std.log.err("connectivity check failed due to {}", .{err}); }, } }, @@ -522,8 +512,10 @@ fn innerEventHandler(agent: *Agent) !void { if (stun.isMessage(data)) { defer agent.destroyPacket(data); - if (try agent.handleReceivedMessage(message.socket.address, message.incoming_message)) |response| - select.async(.send_message, send, .{ agent, message.socket, &sender, response }); + if (try agent.handleReceivedMessage(message.socket.address, message.incoming_message)) |response| { + defer agent.destroyPacket(response); + try message.socket.send(io, &sender, response); + } const candidate_pair: ?CandidatePair = blk: { if (agent.role == .controlling or agent.nominated_pair != null) break :blk null; @@ -555,13 +547,12 @@ fn innerEventHandler(agent: *Agent) !void { select.async(.message, receiveTimeout, .{ agent, message.socket, .none }); }, - .send_message => |result| result catch |err| std.log.err("failed to send response: {}", .{err}), .data_message => |result| { const message = result catch |err| switch (err) { error.Timeout => switch (agent.connection_state) { .connected, .completed => { - agent.setConnectionState(.disconnected); select.async(.data_message, receiveTimeout, .{ agent, &nominated_socket, .{ .duration = failing_timeout } }); + agent.setConnectionState(.disconnected); continue; }, .disconnected => { @@ -573,12 +564,12 @@ fn innerEventHandler(agent: *Agent) !void { else => |e| return e, }; + select.async(.data_message, receiveTimeout, .{ agent, message.socket, .{ .duration = disconnect_timeout } }); + if (stun.isMessage(message.incoming_message.data)) - agent.buffer_pool.destroy(@ptrCast(@alignCast(message.incoming_message.data.ptr))) + agent.destroyPacket(message.incoming_message.data) else agent.config.onData(agent, message.incoming_message.data); - - select.async(.data_message, receiveTimeout, .{ agent, message.socket, .{ .duration = disconnect_timeout } }); }, .keep_alive => |timeout| { try timeout; @@ -589,7 +580,7 @@ fn innerEventHandler(agent: *Agent) !void { }, .complete => |result| { try result; - try agent.markConnectionCompleted(); + agent.markConnectionCompleted(); }, }; } @@ -635,7 +626,7 @@ fn batchSendConnectivityCheck(agent: *Agent) !void { }; } -fn markConnectionCompleted(agent: *Agent) !void { +fn markConnectionCompleted(agent: *Agent) void { const addr = agent.nominated_pair.?.socket.address; for (agent.sockets) |*socket| if (!socket.address.eql(&addr)) socket.close(agent.io); diff --git a/src/ice/c.h b/src/ice/c.h index 393a8ff..bd4f097 100644 --- a/src/ice/c.h +++ b/src/ice/c.h @@ -1,2 +1,4 @@ +#ifdef __linux #include "ifaddrs.h" -#include "netinet/in.h" \ No newline at end of file +#include "netinet/in.h" +#endif \ No newline at end of file From 35be42e0b3c01fe8ecd67b8091a3900d271ab27f Mon Sep 17 00:00:00 2001 From: Billal GHILAS Date: Sat, 16 May 2026 16:24:17 +0100 Subject: [PATCH 8/8] Add local credentials option to AgentConfig --- src/ice/agent.zig | 324 +++++++++++++++++++++++++++++++++++----------- src/ice/ice.zig | 33 +++++ src/stun/stun.zig | 1 + 3 files changed, 283 insertions(+), 75 deletions(-) diff --git a/src/ice/agent.zig b/src/ice/agent.zig index cd857aa..81f8b9e 100644 --- a/src/ice/agent.zig +++ b/src/ice/agent.zig @@ -21,16 +21,23 @@ const disconnect_timeout: Io.Clock.Duration = .{ .clock = .awake, .raw = .fromSe const failing_timeout: Io.Clock.Duration = .{ .clock = .awake, .raw = .fromSeconds(25) }; pub const AgentConfig = struct { - onConnectionState: *const fn (*Agent, ice.ConnectionState) void, - onData: *const fn (*Agent, []const u8) void, + on_connection_state_change: *const fn (*Agent, ice.ConnectionState) void, + on_data: *const fn (*Agent, []const u8) void, + /// Local credentials of the agent (ufrag and password) + /// + /// Generated automatically if not provided + credentials: ?ice.Credentials = null, }; io: Io, allocator: Allocator, buffer_pool: std.heap.MemoryPool([max_message_size]u8), -config: AgentConfig, connection_state: ice.ConnectionState = .new, +// callbacks +on_connection_state_change: *const fn (*Agent, ice.ConnectionState) void, +on_data: *const fn (*Agent, []const u8) void, + // Stun related fields role: Role, credentials: ice.Credentials, @@ -73,19 +80,26 @@ const StunRequest = struct { const PendingRequest = struct { transaction_id: u96, - source: Io.net.IpAddress, - target: Io.net.IpAddress, + source: IpAddress, + target: IpAddress, }; pub fn init(io: Io, allocator: Allocator, config: AgentConfig) !Agent { + const credens = + try if (config.credentials) |credens| + credens.dupe(allocator) + else + ice.Credentials.generate(io, allocator); + return .{ .io = io, .allocator = allocator, .buffer_pool = .empty, .role = .controlled, - .credentials = try (ice.Credentials{ .username = "test", .password = "test" }).dupe(allocator), .tie_breaker = generateTieBeaker(io), - .config = config, + .credentials = credens, + .on_connection_state_change = config.on_connection_state_change, + .on_data = config.on_data, }; } @@ -220,7 +234,7 @@ fn linuxGatherHostCandidates(agent: *Agent) !void { if (flags.LOOPBACK) continue; const in: linux.sockaddr.in = @bitCast(sockaddr); - const ip_addr: Io.net.IpAddress = .{ .ip4 = .{ .bytes = std.mem.toBytes(in.addr), .port = 0 } }; + const ip_addr: IpAddress = .{ .ip4 = .{ .bytes = std.mem.toBytes(in.addr), .port = 0 } }; try agent.candidates.append(agent.allocator, .initHost(ip_addr)); }, else => {}, @@ -242,73 +256,24 @@ fn doAddRemoteCandidate(agent: *Agent, remote_candidate: Candidate) Allocator.Er } } -fn handleReceivedMessage(agent: *Agent, base_addr: Io.net.IpAddress, incoming_message: Io.net.IncomingMessage) !?[]const u8 { +fn handleReceivedMessage(agent: *Agent, base_addr: IpAddress, incoming_message: Io.net.IncomingMessage) !?[]const u8 { const msg = try stun.Message.parse(incoming_message.data); - switch (msg.header.message_type.class()) { - .request => return try agent.handleRequest(&msg, base_addr, incoming_message.from), - .success_response => { - Logger.debug("Handle success response on {f} from {f}", .{ base_addr, incoming_message.from }); - - const pending_request = blk: { - const tx_id = msg.header.transaction_id; - for (agent.pending_requests.items, 0..) |pr, i| { - if (pr.transaction_id == tx_id) { - const pending_request = agent.pending_requests.swapRemove(i); - break :blk pending_request; - } - } - - return null; - }; - - if (!pending_request.source.eql(&base_addr) or !pending_request.target.eql(&incoming_message.from)) return null; - - if (agent.findCandidatePair(&base_addr, &incoming_message.from)) |candidate_pair| { - const mapped_address = blk: { - var it = msg.iterateAttributes(&.{}); - while (try it.next()) |attribute| switch (attribute) { - .xor_mapped_address => |addr| break :blk addr, - else => {}, - }; - - return null; - }; - - if (mapped_address.eql(&base_addr)) { - candidate_pair.state.status = .succeeded; - if (agent.role == .controlled and candidate_pair.state.nominateOnBinding) { - candidate_pair.state.nominateOnBinding = false; - candidate_pair.state.nominated = true; - } - return null; - } - candidate_pair.state.status = .failed; - - if (agent.findCandidatePair(&mapped_address, &incoming_message.from)) |existing_candidate_pair| { - existing_candidate_pair.state.status = .succeeded; - return null; - } - - const reflexive_candidate: Candidate = .initPeerReflexive(base_addr, mapped_address); - try agent.pairs.append(agent.allocator, .{ - .local = reflexive_candidate, - .remote = candidate_pair.remote, - .priority = calculatePairPriority(reflexive_candidate.priority, candidate_pair.remote.priority, agent.role), - .state = .{ .status = .succeeded }, - }); - - return null; - } - }, - else => {}, - } - - return null; + return switch (msg.header.message_type.class()) { + .request => try agent.handleRequest(&msg, base_addr, incoming_message.from), + .success_response => try agent.handleSuccessResponse(&msg, base_addr, incoming_message.from), + else => null, + }; } fn handleRequest(agent: *Agent, msg: *const stun.Message, base_addr: IpAddress, from: IpAddress) ![]const u8 { Logger.debug("Handle request on {f} from {f}", .{ base_addr, from }); - const stun_req = try agent.parseAndValidateStunRequest(msg); + const buffer = try agent.buffer_pool.create(agent.allocator); + errdefer agent.buffer_pool.destroy(buffer); + + const stun_req = agent.parseAndValidateStunRequest(msg) catch |err| switch (err) { + error.RoleConflict => return try agent.buildRoleConflictErrorMessage(msg.header.transaction_id, buffer), + else => |e| return e, + }; if (agent.findCandidatePair(&base_addr, &from)) |candidate_pair| { switch (candidate_pair.state.status) { @@ -335,10 +300,55 @@ fn handleRequest(agent: *Agent, msg: *const stun.Message, base_addr: IpAddress, }); } - const buffer = try agent.buffer_pool.create(agent.allocator); return try agent.buildSuccessResponse(msg, from, buffer); } +fn handleSuccessResponse(agent: *Agent, msg: *const stun.Message, base_addr: IpAddress, from: IpAddress) !?[]const u8 { + Logger.debug("Handle success response on {f} from {f}", .{ base_addr, from }); + + const pending_request = blk: { + const tx_id = msg.header.transaction_id; + for (agent.pending_requests.items, 0..) |pr, i| { + if (pr.transaction_id == tx_id) { + const pending_request = agent.pending_requests.swapRemove(i); + break :blk pending_request; + } + } + + return null; + }; + + if (!pending_request.source.eql(&base_addr) or !pending_request.target.eql(&from)) return null; + + if (agent.findCandidatePair(&base_addr, &from)) |candidate_pair| { + const mapped_address = try agent.parseAndValidateStunResponse(msg); + + if (mapped_address.eql(&base_addr)) { + candidate_pair.state.status = .succeeded; + if (candidate_pair.state.nominateOnBinding) { + candidate_pair.state.nominateOnBinding = false; + candidate_pair.state.nominated = true; + } + return null; + } + candidate_pair.state.status = .failed; + + if (agent.findCandidatePair(&mapped_address, &from)) |existing_candidate_pair| { + existing_candidate_pair.state.status = .succeeded; + return null; + } + + const reflexive_candidate: Candidate = .initPeerReflexive(base_addr, mapped_address); + try agent.pairs.append(agent.allocator, .{ + .local = reflexive_candidate, + .remote = candidate_pair.remote, + .priority = calculatePairPriority(reflexive_candidate.priority, candidate_pair.remote.priority, agent.role), + .state = .{ .status = .succeeded }, + }); + } + return null; +} + fn parseAndValidateStunRequest(agent: *Agent, msg: *const stun.Message) !StunRequest { var it = msg.iterateAttributes(agent.credentials.password); var has_fingerprint: bool = false; @@ -384,6 +394,23 @@ fn parseAndValidateStunRequest(agent: *Agent, msg: *const stun.Message) !StunReq return stun_request; } +fn parseAndValidateStunResponse(agent: *Agent, msg: *const stun.Message) !IpAddress { + var it = msg.iterateAttributes(agent.remote_credentials.?.password); + var has_fingerprint: bool = false; + var has_message_integrity = false; + var maybe_addr: ?IpAddress = null; + + while (try it.next()) |attribute| switch (attribute) { + .xor_mapped_address => |value| maybe_addr = value, + .fingerprint => has_fingerprint = true, + .message_integrity => has_message_integrity = true, + else => {}, + }; + + if (!has_fingerprint or !has_message_integrity) return error.InvalidStunMessage; + return if (maybe_addr) |addr| addr else error.MissingMappedAddress; +} + fn buildBindingRequest(agent: *Agent, tx_id: u96, buffer: *[max_message_size]u8) ![]const u8 { var w = stun.Writer.init(&(buffer.*), .{ .password = agent.remote_credentials.?.password }); try w.writeHeader(.{ @@ -421,7 +448,7 @@ fn buildIndicationRequest(buffer: []u8) ![]const u8 { fn buildSuccessResponse( agent: *const Agent, msg: *const stun.Message, - from: Io.net.IpAddress, + from: IpAddress, buffer: *[max_message_size]u8, ) ![]const u8 { var w = stun.Writer.init(&(buffer.*), .{ .password = agent.credentials.password }); @@ -436,6 +463,22 @@ fn buildSuccessResponse( return w.final(); } +fn buildRoleConflictErrorMessage(agent: *const Agent, transaction_id: u96, buffer: *[max_message_size]u8) ![]const u8 { + var w = stun.Writer.init(&(buffer.*), .{ .password = agent.credentials.password }); + try w.writeHeader(.{ + .message_type = .fromClassAndMethod(.error_response, .binding), + .transaction_id = transaction_id, + .message_length = 0, + }); + try w.writeAttribute(.{ .error_code = .{ + .code = 487, + .reason = "Role conflict", + } }); + try w.writeAttribute(.{ .message_integrity = &.{} }); + try w.writeAttribute(.fingerprint); + return w.final(); +} + fn findSocket(sockets: []Io.net.Socket, addr: *const IpAddress) *Io.net.Socket { for (sockets) |*socket| if (socket.address.eql(addr)) return socket; unreachable; @@ -452,7 +495,7 @@ fn findCandidatePair(agent: *Agent, local: *const IpAddress, remote: *const IpAd fn setConnectionState(agent: *Agent, new_state: ice.ConnectionState) void { agent.connection_state = new_state; - agent.config.onConnectionState(agent, new_state); + agent.on_connection_state_change(agent, new_state); } // ============== Io related functions ====================== @@ -538,7 +581,7 @@ fn innerEventHandler(agent: *Agent) !void { } } else { for (agent.pairs.items) |*candidate_pair| if (candidate_pair.remote.address.eql(&sender)) { - agent.config.onData(agent, data); + agent.on_data(agent, data); } else { std.log.warn("Drop non stun message from unknown remote candidate: {f}", .{sender}); agent.destroyPacket(data); @@ -569,7 +612,7 @@ fn innerEventHandler(agent: *Agent) !void { if (stun.isMessage(message.incoming_message.data)) agent.destroyPacket(message.incoming_message.data) else - agent.config.onData(agent, message.incoming_message.data); + agent.on_data(agent, message.incoming_message.data); }, .keep_alive => |timeout| { try timeout; @@ -637,3 +680,134 @@ fn markConnectionCompleted(agent: *Agent) void { agent.pending_requests.clearAndFree(agent.allocator); agent.setConnectionState(.completed); } + +const testing = std.testing; + +fn testNewAgent() !Agent { + return try .init(testing.io, testing.allocator, .{ + .on_connection_state_change = undefined, + .on_data = undefined, + }); +} + +fn testBuildRequest(req: StunRequest, peer_password: []const u8, buffer: []u8) !stun.Message { + var w = stun.Writer.init(buffer, .{ .password = peer_password }); + try w.writeHeader(.{ + .message_type = .fromClassAndMethod(.request, .binding), + .transaction_id = generateTrasactionId(testing.io), + .message_length = 0, + }); + try w.writeAttribute(.{ .username = req.username }); + try w.writeAttribute(.{ .priority = req.priority }); + if (req.ice_controlled != null) try w.writeAttribute(.{ .ice_controlled = req.ice_controlled.? }); + if (req.ice_controlling != null) try w.writeAttribute(.{ .ice_controlling = req.ice_controlling.? }); + if (req.use_candidate) try w.writeAttribute(.use_candidate); + try w.writeAttribute(.{ .message_integrity = &.{} }); + try w.writeAttribute(.fingerprint); + + return try stun.Message.parse(w.final()); +} + +test "init agent" { + var agent: Agent = try .init(testing.io, testing.allocator, .{ + .on_connection_state_change = undefined, + .on_data = undefined, + }); + defer agent.deinit(); +} + +test "handle request: generate success response" { + var agent: Agent = try testNewAgent(); + defer agent.deinit(); + + var buffer: [1024]u8 = undefined; + + const base_addr = try IpAddress.parse("192.168.1.100", 1000); + const from = try IpAddress.parse("192.168.1.120", 2000); + + const msg = try testBuildRequest(.{ + .ice_controlling = 0x10000, + .priority = 0x9090, + .username = agent.credentials.username, + }, agent.credentials.password, &buffer); + + const resp = try agent.handleRequest(&msg, base_addr, from); + const resp_msg = try stun.Message.parse(resp); + + try testing.expectEqual(.success_response, resp_msg.header.message_type.class()); + try testing.expectEqual(.binding, resp_msg.header.message_type.method()); + try testing.expectEqual(msg.header.transaction_id, resp_msg.header.transaction_id); + + var it = resp_msg.iterateAttributes(agent.credentials.password); + var attr = try it.next() orelse return error.ExpectedAttribute; + try testing.expect(attr.xor_mapped_address.eql(&from)); + + attr = try it.next() orelse return error.ExpectedAttribute; + try testing.expectEqual(.message_integrity, @as(stun.AttributeType, attr)); + + attr = try it.next() orelse return error.ExpectedAttribute; + try testing.expectEqual(.fingerprint, @as(stun.AttributeType, attr)); + try testing.expectEqual(null, try it.next()); +} + +test "handle request: create peer reflexive candidate" { + var agent: Agent = try testNewAgent(); + defer agent.deinit(); + + var buffer: [1024]u8 = undefined; + + const base_addr = try IpAddress.parse("192.168.1.100", 1000); + const from = try IpAddress.parse("192.168.1.120", 2000); + + const msg = try testBuildRequest(.{ + .ice_controlling = 0x10000, + .priority = 0x9090, + .username = agent.credentials.username, + }, agent.credentials.password, &buffer); + + _ = try agent.handleRequest(&msg, base_addr, from); + + try testing.expectEqual(1, agent.pairs.items.len); + + const candidate_pair = agent.pairs.items[0]; + try testing.expect(candidate_pair.remote.address.eql(&from)); + try testing.expectEqual(candidate_pair.remote.priority, 0x9090); + + // Send request again + _ = try agent.handleRequest(&msg, base_addr, from); + try testing.expectEqual(1, agent.pairs.items.len); // no new peer is created +} + +test "handle request: nominate peer" { + var agent: Agent = try testNewAgent(); + defer agent.deinit(); + + var buffer: [1024]u8 = undefined; + + const base_addr = try IpAddress.parse("192.168.1.100", 1000); + const from = try IpAddress.parse("192.168.1.120", 2000); + + try agent.pairs.append(testing.allocator, .{ + .local = .initHost(base_addr), + .remote = .initHost(from), + .state = .{ .status = .in_progress }, + .priority = 0, + }); + + const msg = try testBuildRequest(.{ + .ice_controlling = 0x10000, + .priority = 0x9090, + .username = agent.credentials.username, + .use_candidate = true, + }, agent.credentials.password, &buffer); + + _ = try agent.handleRequest(&msg, base_addr, from); + + const candidate_pair = &agent.pairs.items[0]; + try testing.expectEqual(true, candidate_pair.state.nominateOnBinding); + try testing.expectEqual(false, candidate_pair.state.nominated); + + candidate_pair.state.status = .succeeded; + _ = try agent.handleRequest(&msg, base_addr, from); + try testing.expectEqual(true, candidate_pair.state.nominated); +} diff --git a/src/ice/ice.zig b/src/ice/ice.zig index b97ed5b..8591684 100644 --- a/src/ice/ice.zig +++ b/src/ice/ice.zig @@ -247,6 +247,39 @@ pub const Credentials = struct { allocator.free(credens.username); allocator.free(credens.password); } + + pub fn generate(io: std.Io, allocator: std.mem.Allocator) !Credentials { + var encoder = std.base64.standard.Encoder; + + var user_bytes: [6]u8 = undefined; + io.random(&user_bytes); + const username = try allocator.alloc(u8, encoder.calcSize(user_bytes.len)); + errdefer allocator.free(username); + _ = encoder.encode(username, &user_bytes); + + var password_bytes: [12]u8 = undefined; + try io.randomSecure(&password_bytes); + const password = try allocator.alloc(u8, encoder.calcSize(password_bytes.len)); + _ = encoder.encode(password, &password_bytes); + + return .{ + .username = username, + .password = password, + }; + } + + test "credentials: generate" { + var creds = try Credentials.generate(std.testing.io, std.testing.allocator); + defer creds.deinit(std.testing.allocator); + + try std.testing.expect(creds.username.len >= 8); + try std.testing.expect(creds.password.len >= 16); + } + + test "credentials: clean up on failure" { + var failing_allocator = std.testing.FailingAllocator.init(std.testing.allocator, .{ .fail_index = 1 }); + try std.testing.expectError(error.OutOfMemory, Credentials.generate(std.testing.io, failing_allocator.allocator())); + } }; pub const CandidatePair = struct { diff --git a/src/stun/stun.zig b/src/stun/stun.zig index de21b99..83bba54 100644 --- a/src/stun/stun.zig +++ b/src/stun/stun.zig @@ -351,6 +351,7 @@ pub const Writer = struct { switch (attribute) { .priority => |p| try out.writeInt(u32, p, .big), .ice_controlled, .ice_controlling => |tie_breaker| try out.writeInt(u64, tie_breaker, .big), + .use_candidate => {}, .message_integrity => try msg_writer.writeMessageIntegrity(), .fingerprint => try writeFingerprint(&msg_writer.writer), .software, .username, .userhash => |slice| try out.writeAll(slice),