diff --git a/README.md b/README.md index 9f0f454..a2ae264 100644 --- a/README.md +++ b/README.md @@ -7,3 +7,4 @@ The projects is structured into modules, each module is a separate library that * `rtp` - [RTP (Real-time Transport Protocol)](https://datatracker.ietf.org/doc/html/rfc3550) implementation for media streaming based on RFC 3550. * `sdp` - [SDP (Session Description Protocol)](https://datatracker.ietf.org/doc/html/rfc4566) implementation for describing multimedia sessions based on RFC 4566. * `rtsp` - [RTSP (Real Time Streaming Protocol)](https://datatracker.ietf.org/doc/html/rfc2326) implementation for controlling streaming media servers based on RFC 2326. +* `stun` - [STUN (Session Traversal Utilities for NAT)](https://datatracker.ietf.org/doc/html/rfc8489) implementation for NAT traversal based on RFC 5389. diff --git a/bench/stun/message.zig b/bench/stun/message.zig new file mode 100644 index 0000000..79ad3dc --- /dev/null +++ b/bench/stun/message.zig @@ -0,0 +1,97 @@ +const std = @import("std"); +const zbench = @import("zbench"); +const stun = @import("stun"); + +// Binding request, no attributes (header only) +const empty_request = [_]u8{ + 0x00, 0x01, 0x00, 0x00, + 0x21, 0x12, 0xA4, 0x42, + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, +}; + +// Binding success response with MAPPED-ADDRESS (IPv4) + XOR-MAPPED-ADDRESS (IPv4) +const mapped_response = [_]u8{ + 0x01, 0x01, 0x00, 0x18, + 0x21, 0x12, 0xA4, 0x42, + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, + // MAPPED-ADDRESS IPv4 + 0x00, 0x01, 0x00, 0x08, + 0x00, 0x01, 0x80, 0x55, + 192, 0, 2, 1, + // XOR-MAPPED-ADDRESS IPv4 + 0x00, 0x20, 0x00, 0x08, + 0x00, 0x01, 0xA1, 0x47, + 0xE1, 0x12, 0xA6, 0x43, +}; + +// RFC 5769 ยง2.1 sample binding request: SOFTWARE, PRIORITY, ICE-CONTROLLED, +// USERNAME, MESSAGE-INTEGRITY, FINGERPRINT +const rfc5769_request = [_]u8{ + 0x00, 0x01, 0x00, 0x58, + 0x21, 0x12, 0xa4, 0x42, + 0xb7, 0xe7, 0xa7, 0x01, + 0xbc, 0x34, 0xd6, 0x86, + 0xfa, 0x87, 0xdf, 0xae, + 0x80, 0x22, 0x00, 0x10, + 0x53, 0x54, 0x55, 0x4e, + 0x20, 0x74, 0x65, 0x73, + 0x74, 0x20, 0x63, 0x6c, + 0x69, 0x65, 0x6e, 0x74, + 0x00, 0x24, 0x00, 0x04, + 0x6e, 0x00, 0x01, 0xff, + 0x80, 0x29, 0x00, 0x08, + 0x93, 0x2f, 0xf9, 0xb1, + 0x51, 0x26, 0x3b, 0x36, + 0x00, 0x06, 0x00, 0x09, + 0x65, 0x76, 0x74, 0x6a, + 0x3a, 0x68, 0x36, 0x76, + 0x59, 0x20, 0x20, 0x20, + 0x00, 0x08, 0x00, 0x14, + 0x9a, 0xea, 0xa7, 0x0c, + 0xbf, 0xd8, 0xcb, 0x56, + 0x78, 0x1e, 0xf2, 0xb5, + 0xb2, 0xd3, 0xf2, 0x49, + 0xc1, 0xb5, 0x71, 0xa2, + 0x80, 0x28, 0x00, 0x04, + 0xe5, 0x7a, 0x3b, 0xcf, +}; + +const rfc5769_password = "VOkJxbRl1RmTxUk/WvJxBt"; + +fn parseHeaderOnly(_: std.mem.Allocator) void { + const msg = stun.Message.parse(&empty_request) catch unreachable; + var it = msg.iterateAttributes(&.{}); + while (it.next() catch unreachable) |attr| std.mem.doNotOptimizeAway(attr); + std.mem.doNotOptimizeAway(msg); +} + +fn parseMappedAddresses(_: std.mem.Allocator) void { + const msg = stun.Message.parse(&mapped_response) catch unreachable; + var it = msg.iterateAttributes(&.{}); + while (it.next() catch unreachable) |attr| std.mem.doNotOptimizeAway(attr); + std.mem.doNotOptimizeAway(msg); +} + +fn parseRfc5769(_: std.mem.Allocator) void { + const msg = stun.Message.parse(&rfc5769_request) catch unreachable; + var it = msg.iterateAttributes(rfc5769_password); + while (it.next() catch unreachable) |attr| std.mem.doNotOptimizeAway(attr); + std.mem.doNotOptimizeAway(msg); +} + +pub fn main(init: std.process.Init) !void { + const io = init.io; + const stdout: std.Io.File = .stdout(); + + var bench = zbench.Benchmark.init(init.gpa, .{}); + defer bench.deinit(); + + try bench.add("STUN: header only", parseHeaderOnly, .{}); + try bench.add("STUN: mapped + xor-mapped (IPv4)", parseMappedAddresses, .{}); + try bench.add("STUN: RFC 5769 with integrity + fingerprint", parseRfc5769, .{}); + try bench.run(io, stdout); +} diff --git a/build.zig b/build.zig index 751d5c2..21326b9 100644 --- a/build.zig +++ b/build.zig @@ -31,6 +31,12 @@ pub fn build(b: *std.Build) void { }, }); + const stun = b.addModule("stun", .{ + .root_source_file = b.path("src/stun/stun.zig"), + .target = target, + .optimize = optimize, + }); + _ = b.addModule("protocols", .{ .root_source_file = b.path("src/root.zig"), .target = target, @@ -38,6 +44,7 @@ pub fn build(b: *std.Build) void { .{ .name = "rtp", .module = rtp }, .{ .name = "sdp", .module = sdp }, .{ .name = "rtsp", .module = rtsp }, + .{ .name = "stun", .module = stun }, }, }); @@ -61,10 +68,17 @@ pub fn build(b: *std.Build) void { }); const run_rtsp_tests = b.addRunArtifact(rtsp_tests); + const stun_tests = b.addTest(.{ + .root_module = stun, + .filters = test_filters, + }); + const run_stun_tests = b.addRunArtifact(stun_tests); + const test_step = b.step("test", "Run tests"); test_step.dependOn(&run_rtp_tests.step); test_step.dependOn(&run_sdp_tests.step); test_step.dependOn(&run_rtsp_tests.step); + test_step.dependOn(&run_stun_tests.step); } { @@ -73,6 +87,7 @@ pub fn build(b: *std.Build) void { const benches = .{ .{ .name = "rtp_packet", .src = "bench/rtp/packet.zig" }, .{ .name = "sdp_session", .src = "bench/sdp/session.zig" }, + .{ .name = "stun_message", .src = "bench/stun/message.zig" }, }; inline for (benches) |bench| { @@ -86,6 +101,7 @@ pub fn build(b: *std.Build) void { .{ .name = "zbench", .module = zbench.module("zbench") }, .{ .name = "rtp", .module = rtp }, .{ .name = "sdp", .module = sdp }, + .{ .name = "stun", .module = stun }, }, }), }); diff --git a/src/root.zig b/src/root.zig index e8a1cfc..b436f2b 100644 --- a/src/root.zig +++ b/src/root.zig @@ -1,3 +1,4 @@ pub const rtp = @import("rtp"); pub const sdp = @import("sdp"); pub const rtsp = @import("rtsp"); +pub const stun = @import("stun"); diff --git a/src/stun/stun.zig b/src/stun/stun.zig new file mode 100644 index 0000000..5616347 --- /dev/null +++ b/src/stun/stun.zig @@ -0,0 +1,439 @@ +const std = @import("std"); + +const Io = std.Io; + +pub const magic_cookie: u32 = 0x2112A442; +pub const header_size = 20; + +pub const Class = enum(u2) { + request, + indication, + success_response, + error_response, +}; + +pub const Method = enum(u12) { + binding = 1, + _, +}; + +pub const MessageType = packed struct { + m1: u4, + c1: u1, + m2: u3, + c2: u1, + m3: u5, + + pub fn fromClassAndMethod(c: Class, m: Method) MessageType { + const m_int = @intFromEnum(m); + const cl_int = @intFromEnum(c); + + return MessageType{ + .m1 = @intCast(m_int & 0x0F), + .c1 = @intCast(cl_int & 1), + .m2 = @intCast((m_int >> 4) & 0x07), + .c2 = @intCast(cl_int >> 1), + .m3 = @intCast((m_int >> 7) & 0x1F), + }; + } + + pub fn method(self: MessageType) Method { + const method_int = @as(u12, self.m3) << 7 | @as(u12, self.m2) << 4 | self.m1; + return @enumFromInt(method_int); + } + + pub fn class(self: MessageType) Class { + return @enumFromInt((@as(u2, self.c2) << 1) | self.c1); + } +}; + +pub const Header = packed struct { + transaction_id: u96, + magic_cookie: u32 = magic_cookie, + message_length: u16, + message_type: MessageType, + _pad: u2 = 0, +}; + +pub const Message = struct { + header: Header, + bytes: []const u8, + + pub fn iterateAttributes(message: *const Message, passwd: []const u8) AttributeIterator { + var reader = std.Io.Reader.fixed(message.bytes); + reader.toss(header_size); + return .{ .reader = reader, .password = passwd }; + } + + pub fn parse(msg: []const u8) !Message { + std.debug.assert(msg.len >= header_size); + + const header_int = std.mem.readInt(@typeInfo(Header).@"struct".backing_integer.?, msg[0..header_size], .big); + const header: Header = @bitCast(header_int); + if (header.magic_cookie != magic_cookie) { + return error.WrongMagicCookie; + } + + return .{ + .header = header, + .bytes = msg, + }; + } +}; + +pub const AttributeType = enum(u16) { + mapped_address = 0x0001, + username = 0x0006, + message_integrity = 0x0008, + xor_mapped_address = 0x0020, + userhash = 0x001E, + priority = 0x0024, + software = 0x8022, + fingerprint = 0x8028, + ice_controlled = 0x8029, + unknown = 0xFFFF, + _, +}; + +pub const Attribute = union(AttributeType) { + mapped_address: Io.net.IpAddress, + username: []const u8, + message_integrity: []const u8, + xor_mapped_address: Io.net.IpAddress, + userhash: []const u8, + priority: u32, + software: []const u8, + fingerprint: u32, + ice_controlled: u64, + unknown: struct { AttributeType, []const u8 }, +}; + +pub const AttributeIterator = struct { + reader: Io.Reader, + password: []const u8, + + pub const Error = error{ + InvalidAttribute, + MessageIntegrityCheckFailed, + FingerprintCheckFailed, + }; + + pub fn next(it: *AttributeIterator) Error!?Attribute { + if (it.reader.bufferedLen() == 0) return null; + + const attr_type = it.reader.takeEnum(AttributeType, .big) catch return error.InvalidAttribute; + const attr_len = it.reader.takeInt(u16, .big) catch return error.InvalidAttribute; + + const padding = switch (@rem(attr_len, 4)) { + 0 => 0, + else => |v| 4 - v, + }; + const attr_value = it.reader.take(attr_len + padding) catch return error.InvalidAttribute; + + return switch (attr_type) { + .mapped_address => try parseMappedAddress(attr_value), + .xor_mapped_address => try parseXorMappedAddress(attr_value, it.reader.buffer[8..20]), + .username => .{ .username = attr_value[0..attr_len] }, + .software => .{ .software = attr_value[0..attr_len] }, + .userhash => blk: { + if (attr_len != 32) break :blk error.InvalidAttribute; + break :blk .{ .userhash = attr_value[0..attr_len] }; + }, + .priority => blk: { + if (attr_len != 4) return error.InvalidAttribute; + break :blk .{ .priority = std.mem.readInt(u32, attr_value[0..4], .big) }; + }, + .ice_controlled => blk: { + if (attr_len != 8) return error.InvalidAttribute; + break :blk .{ .ice_controlled = std.mem.readInt(u64, attr_value[0..8], .big) }; + }, + .message_integrity => blk: { + if (attr_len != 20) break :blk error.InvalidAttribute; + try it.verifyMessageIntegrity(attr_value); + break :blk .{ .message_integrity = attr_value }; + }, + .fingerprint => blk: { + if (attr_len != 4) break :blk error.InvalidAttribute; + const fingerprint = std.mem.readInt(u32, attr_value[0..4], .big); + try it.verifyFingerprint(fingerprint); + break :blk .{ .fingerprint = fingerprint }; + }, + else => .{ .unknown = .{ attr_type, attr_value } }, + }; + } + + fn parseMappedAddress(value: []const u8) !Attribute { + if (value.len < 8) return error.InvalidAttribute; + + const family = switch (value[1]) { + 1 => Io.net.IpAddress.Family.ip4, + 2 => Io.net.IpAddress.Family.ip6, + else => return error.InvalidAttribute, + }; + + if (family == .ip4 and value.len != 8 or family == .ip6 and value.len != 20) { + return error.InvalidAttribute; + } + + const port = std.mem.readInt(u16, value[2..4], .big); + const ip = blk: switch (family) { + .ip4 => { + var ip = Io.net.IpAddress{ .ip4 = .unspecified(port) }; + @memcpy(&ip.ip4.bytes, value[4..8]); + break :blk ip; + }, + .ip6 => { + var ip = Io.net.IpAddress{ .ip6 = .unspecified(port) }; + @memcpy(&ip.ip6.bytes, value[4..]); + break :blk ip; + }, + }; + + return .{ .mapped_address = ip }; + } + + fn parseXorMappedAddress(value: []const u8, tx_id: []const u8) !Attribute { + if (value.len < 8) return error.InvalidAttribute; + const family = switch (value[1]) { + 1 => Io.net.IpAddress.Family.ip4, + 2 => Io.net.IpAddress.Family.ip6, + else => return error.InvalidAttribute, + }; + + if (family == .ip4 and value.len != 8 or family == .ip6 and value.len != 20) { + return error.InvalidAttribute; + } + + const cookie = std.mem.toBytes(std.mem.nativeToBig(u32, magic_cookie)); + const port: u16 = std.mem.readInt(u16, &[_]u8{ value[2] ^ cookie[0], value[3] ^ cookie[1] }, .big); + const ip = blk: switch (family) { + .ip4 => { + var ip = Io.net.IpAddress{ .ip4 = .unspecified(port) }; + for (&ip.ip4.bytes, 0..) |*b, idx| b.* = value[4 + idx] ^ cookie[idx]; + break :blk ip; + }, + .ip6 => { + var ip = Io.net.IpAddress{ .ip6 = .unspecified(port) }; + for (ip.ip6.bytes[0..4], 0..) |*b, idx| b.* = value[4 + idx] ^ cookie[idx]; + for (ip.ip4.bytes[4..], 0..) |*b, idx| b.* = value[8 + idx] ^ tx_id[idx]; + break :blk ip; + }, + }; + + return .{ .xor_mapped_address = ip }; + } + + fn verifyMessageIntegrity(it: *const AttributeIterator, expected_hash: []u8) !void { + var hash: [20]u8 = undefined; + const msg = it.reader.buffer[0..it.reader.seek]; + const msg_size = msg.len - header_size; + + var hasher: std.crypto.auth.hmac.HmacSha1 = .init(it.password); + hasher.update(msg[0..2]); + hasher.update(&std.mem.toBytes(std.mem.nativeToBig(u16, @intCast(msg_size)))); + hasher.update(msg[4 .. msg_size - 4]); + hasher.final(&hash); + + if (!std.mem.eql(u8, &hash, expected_hash)) { + return error.MessageIntegrityCheckFailed; + } + } + + fn verifyFingerprint(it: *const AttributeIterator, expected_value: u32) !void { + const msg = it.reader.buffer; + + var hasher: std.hash.Crc32 = .init(); + hasher.update(msg[0 .. msg.len - 8]); + if (hasher.final() ^ 0x5354554e != expected_value) { + return error.FingerprintCheckFailed; + } + } +}; + +const testing = std.testing; + +test "MessageType: round-trip all classes" { + const classes = [_]Class{ .request, .indication, .success_response, .error_response }; + for (classes) |c| { + const mt = MessageType.fromClassAndMethod(c, .binding); + try testing.expectEqual(c, mt.class()); + try testing.expectEqual(Method.binding, mt.method()); + } +} + +test "Header: size matches STUN spec" { + try testing.expectEqual(@as(usize, 20), @divExact(@bitSizeOf(Header), 8)); +} + +test "Message.parse: binding request header" { + const bytes = [_]u8{ + 0x00, 0x01, 0x00, 0x00, + 0x21, 0x12, 0xA4, 0x42, + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, + }; + + const msg = try Message.parse(&bytes); + try testing.expectEqual(Class.request, msg.header.message_type.class()); + try testing.expectEqual(Method.binding, msg.header.message_type.method()); + try testing.expectEqual(@as(u16, 0), msg.header.message_length); + try testing.expectEqual(@as(u32, magic_cookie), msg.header.magic_cookie); + try testing.expectEqual(@as(u96, 0x000102030405060708090A0B), msg.header.transaction_id); + + var it = msg.iterateAttributes(&.{}); + try testing.expect((try it.next()) == null); +} + +test "Message.parse: binding success response header" { + const bytes = [_]u8{ + 0x01, 0x01, 0x00, 0x00, + 0x21, 0x12, 0xA4, 0x42, + 0xB7, 0xE7, 0xA7, 0x01, + 0xBC, 0x34, 0xD6, 0x86, + 0xFA, 0x87, 0xDF, 0xAE, + }; + + const msg = try Message.parse(&bytes); + try testing.expectEqual(Class.success_response, msg.header.message_type.class()); + try testing.expectEqual(Method.binding, msg.header.message_type.method()); +} + +test "Message.iterateAttributes" { + const bytes = [_]u8{ + 0x01, 0x01, 0x00, 0x38, + 0x21, 0x12, 0xA4, 0x42, + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, + // MAPPED-ADDRESS IPv4 + 0x00, 0x01, 0x00, 0x08, + 0x00, 0x01, 0x80, 0x55, + 192, 0, 2, 1, + // MAPPED-ADDRESS IPv6 + 0x00, 0x01, 0x00, 0x14, + 0x00, 0x02, 0x80, 0x55, + 0x20, 0x01, 0x0D, 0xB8, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, + // XOR-MAPPED-ADDRESS IPv4 + 0x00, 0x20, 0x00, 0x08, + 0x00, 0x01, 0xA1, 0x47, + 0xE1, 0x12, 0xA6, 0x43, + // SOFTWARE (unknown) + 0x80, 0x22, 0x00, 0x04, + 't', 'e', 's', 't', + }; + + const expected_v4 = Io.net.IpAddress{ .ip4 = .{ .bytes = .{ 192, 0, 2, 1 }, .port = 32853 } }; + const expected_v6 = Io.net.IpAddress{ .ip6 = .{ + .bytes = .{ + 0x20, 0x01, 0x0D, 0xB8, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + }, + .port = 32853, + } }; + + const msg = try Message.parse(&bytes); + var it = msg.iterateAttributes(&.{}); + + const a1 = try it.next(); + try testing.expect(a1.?.mapped_address.eql(&expected_v4)); + + const a2 = (try it.next()) orelse return error.MissingAttribute; + try testing.expect(a2.mapped_address.eql(&expected_v6)); + + const a3 = (try it.next()) orelse return error.MissingAttribute; + try testing.expect(a3.xor_mapped_address.eql(&expected_v4)); + + const a4 = (try it.next()) orelse return error.MissingAttribute; + try testing.expectEqualStrings("test", a4.software); + + try testing.expectEqual(null, try it.next()); +} + +test "Message.iterateAttributes: invalid attribute length zero" { + const bytes = [_]u8{ + 0x00, 0x01, 0x00, 0x04, + 0x21, 0x12, 0xA4, 0x42, + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, + 0x00, 0x01, 0x00, 0x00, + }; + + const msg = try Message.parse(&bytes); + var it = msg.iterateAttributes(&.{}); + try testing.expectError(error.InvalidAttribute, it.next()); +} + +// test vectors (RFC 5769) +test "Message: iterator attributes" { + const data = [_]u8{ + 0x00, 0x01, 0x00, 0x58, + 0x21, 0x12, 0xa4, 0x42, + 0xb7, 0xe7, 0xa7, 0x01, + 0xbc, 0x34, 0xd6, 0x86, + 0xfa, 0x87, 0xdf, 0xae, + 0x80, 0x22, 0x00, 0x10, + 0x53, 0x54, 0x55, 0x4e, + 0x20, 0x74, 0x65, 0x73, + 0x74, 0x20, 0x63, 0x6c, + 0x69, 0x65, 0x6e, 0x74, + 0x00, 0x24, 0x00, 0x04, + 0x6e, 0x00, 0x01, 0xff, + 0x80, 0x29, 0x00, 0x08, + 0x93, 0x2f, 0xf9, 0xb1, + 0x51, 0x26, 0x3b, 0x36, + 0x00, 0x06, 0x00, 0x09, + 0x65, 0x76, 0x74, 0x6a, + 0x3a, 0x68, 0x36, 0x76, + 0x59, 0x20, 0x20, 0x20, + 0x00, 0x08, 0x00, 0x14, + 0x9a, 0xea, 0xa7, 0x0c, + 0xbf, 0xd8, 0xcb, 0x56, + 0x78, 0x1e, 0xf2, 0xb5, + 0xb2, 0xd3, 0xf2, 0x49, + 0xc1, 0xb5, 0x71, 0xa2, + 0x80, 0x28, 0x00, 0x04, + 0xe5, 0x7a, 0x3b, 0xcf, + }; + + const message = try Message.parse(&data); + try testing.expectEqual(.request, message.header.message_type.class()); + try testing.expectEqual(.binding, message.header.message_type.method()); + + var it = message.iterateAttributes("VOkJxbRl1RmTxUk/WvJxBt"); + var attribute = try it.next() orelse return error.ExpectedAttribute; + try testing.expectEqualStrings("STUN test client", attribute.software); + + attribute = try it.next() orelse return error.ExpectedAttribute; + try testing.expectEqual(0x6E0001FF, attribute.priority); + + attribute = try it.next() orelse return error.ExpectedAttribute; + try testing.expectEqual(0x932FF9B151263B36, attribute.ice_controlled); + + attribute = try it.next() orelse return error.ExpectedAttribute; + try testing.expectEqualStrings("evtj:h6vY", attribute.username); + + _ = try it.next() orelse return error.ExpectedAttribute; // Message Integrity + _ = try it.next() orelse return error.ExpectedAttribute; // Fingerprint + try testing.expectEqual(null, try it.next()); +} + +test "Message.iterateAttributes: invalid attribute length not multiple of 4" { + const bytes = [_]u8{ + 0x00, 0x01, 0x00, 0x08, + 0x21, 0x12, 0xA4, 0x42, + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, + 0x00, 0x01, 0x00, 0x07, + 0xAA, 0xBB, 0xCC, 0xDD, + }; + + const msg = try Message.parse(&bytes); + var it = msg.iterateAttributes(&.{}); + try testing.expectError(error.InvalidAttribute, it.next()); +}