diff --git a/src/stun/stun.zig b/src/stun/stun.zig index 33c596d..305ff07 100644 --- a/src/stun/stun.zig +++ b/src/stun/stun.zig @@ -4,6 +4,7 @@ const Io = std.Io; pub const magic_cookie: u32 = 0x2112A442; pub const header_size = 20; +const fingerprint_xor: u32 = 0x5354554e; pub const Class = enum(u2) { request, @@ -125,10 +126,26 @@ pub const Attribute = union(AttributeType) { userhash: []const u8, priority: u32, software: []const u8, - fingerprint: u32, + fingerprint: void, ice_controlled: u64, ice_controlling: u64, unknown: struct { AttributeType, []const u8 }, + + pub fn size(attribute: Attribute) u16 { + return switch (attribute) { + .priority, .fingerprint => 4, + .ice_controlled, .ice_controlling => 8, + .message_integrity => 20, + .use_candidate => 0, + .software, .username, .userhash => |slice| @intCast(slice.len), + .mapped_address, .xor_mapped_address => |ip| switch (ip) { + .ip4 => 8, + .ip6 => 20, + }, + .error_code => |err| @intCast(err.reason.len + 4), + else => 0, + }; + } }; pub const AttributeIterator = struct { @@ -193,7 +210,7 @@ pub const AttributeIterator = struct { if (attr_len != 4) break :blk error.InvalidAttribute; const fingerprint = std.mem.readInt(u32, attr_value[0..4], .big); try it.verifyFingerprint(fingerprint); - break :blk .{ .fingerprint = fingerprint }; + break :blk .fingerprint; }, else => .{ .unknown = .{ attr_type, attr_value } }, }; @@ -281,12 +298,143 @@ pub const AttributeIterator = struct { var hasher: std.hash.Crc32 = .init(); hasher.update(msg[0 .. msg.len - 8]); - if (hasher.final() ^ 0x5354554e != expected_value) { + if (hasher.final() ^ fingerprint_xor != expected_value) { return error.FingerprintCheckFailed; } } }; +pub const Writer = struct { + writer: Io.Writer, + options: WriterOptions, + + pub const WriterOptions = struct { + password: ?[]const u8 = null, + padding_byte: u8 = 0, + }; + + pub fn init(buffer: []u8, options: WriterOptions) Writer { + return .{ .writer = .fixed(buffer), .options = options }; + } + + pub fn writeHeader(msg_writer: *Writer, header: Header) !void { + try msg_writer.writer.writeStruct(header, .big); + } + + pub fn writeRaw(msg_writer: *Writer, attr_type: AttributeType, content: [][]const u8) !void { + var w = &msg_writer.writer; + + try w.writeInt(u16, @intFromEnum(attr_type), .big); + const length = try w.writableArray(2); + const pos = w.end; + try msg_writer.writer.writeVecAll(content); + + const attr_size: u16 = @intCast(w.end - pos); + const padding = switch (@rem(attr_size, 4)) { + 0 => 0, + else => |v| 4 - v, + }; + @memset(try w.writableSlice(padding), msg_writer.options.padding_byte); + std.mem.writeInt(u16, length, attr_size, .big); + } + + pub fn writeAttribute(msg_writer: *Writer, attribute: Attribute) !void { + var out = &msg_writer.writer; + + try out.writeInt(u16, @intFromEnum(attribute), .big); + try out.writeInt(u16, attribute.size(), .big); + switch (attribute) { + .priority => |p| try out.writeInt(u32, p, .big), + .ice_controlled, .ice_controlling => |tie_breaker| try out.writeInt(u64, tie_breaker, .big), + .message_integrity => try msg_writer.writeMessageIntegrity(), + .fingerprint => try writeFingerprint(&msg_writer.writer), + .software, .username, .userhash => |slice| try out.writeAll(slice), + .mapped_address => |addr| try msg_writer.writeIpAddress(addr, false), + .xor_mapped_address => |addr| try msg_writer.writeIpAddress(addr, true), + .error_code => |err| { + try msg_writer.writer.writeInt(u32, @as(u24, err.code / 100) << 16 | (err.code % 100), .big); + try msg_writer.writer.writeAll(err.reason); + }, + else => return error.UnknownAttribute, + } + + const padding = switch (@rem(out.end, 4)) { + 0 => 0, + else => |v| 4 - v, + }; + @memset(try out.writableSlice(padding), msg_writer.options.padding_byte); + } + + pub fn final(msg_writer: *Writer) []const u8 { + const result = msg_writer.writer.buffered(); + const msg_length: u16 = @intCast(result.len - header_size); + std.mem.writeInt(u16, result[2..4], msg_length, .big); + return result; + } + + fn writeMessageIntegrity(msg_writer: *Writer) !void { + var w = &msg_writer.writer; + + const buf = w.buffered(); + const hash = try w.writableArray(20); + const msg_length: u16 = @intCast(w.end - header_size); // 4 bytes of already written attribute header + + var hasher = std.crypto.auth.hmac.HmacSha1.init(msg_writer.options.password.?); + hasher.update(buf[0..2]); + hasher.update(&std.mem.toBytes(std.mem.nativeToBig(u16, msg_length))); + hasher.update(buf[4 .. buf.len - 4]); + hasher.final(hash); + } + + fn writeFingerprint(w: *Io.Writer) !void { + const buf = w.buffered(); + const msg_size: u16 = @intCast(buf.len - header_size + 4); + + var hasher: std.hash.Crc32 = .init(); + hasher.update(buf[0..2]); + hasher.update(&std.mem.toBytes(std.mem.nativeToBig(u16, msg_size))); + hasher.update(buf[4 .. buf.len - 4]); + + try w.writeInt(u32, hasher.final() ^ fingerprint_xor, .big); + } + + fn writeIpAddress(msg_writer: *Writer, addr: Io.net.IpAddress, xor: bool) !void { + var out = &msg_writer.writer; + const cookie = std.mem.toBytes(std.mem.nativeToBig(u32, magic_cookie)); + switch (addr) { + .ip4 => |ipv4| { + try out.writeInt(u16, 1, .big); + if (xor) { + const xor_port: u16 = ipv4.port ^ @as(u16, magic_cookie >> 16); + try out.writeAll(&std.mem.toBytes(std.mem.nativeToBig(u16, xor_port))); + + const slice = try out.writableSlice(cookie.len); + for (slice, 0..) |*b, idx| b.* = cookie[idx] ^ ipv4.bytes[idx]; + } else { + try out.writeInt(u16, ipv4.port, .big); + try out.writeAll(&ipv4.bytes); + } + }, + .ip6 => |ipv6| { + try out.writeInt(u16, 2, .big); + if (xor) { + const xor_port: u16 = ipv6.port ^ @as(u16, magic_cookie >> 16); + try out.writeAll(&std.mem.toBytes(std.mem.nativeToBig(u16, xor_port))); + + const slice = try out.writableSlice(ipv6.bytes.len); + const txid = out.buffer[8..20]; + + for (slice, 0..) |*b, idx| b.* = cookie[idx] ^ ipv6.bytes[idx]; + for (slice[4..], 0..) |*b, idx| b.* = txid[idx] ^ ipv6.bytes[idx]; + } else { + try out.writeInt(u16, ipv6.port, .big); + try out.writeAll(&ipv6.bytes); + } + }, + } + } +}; + const testing = std.testing; test "MessageType: round-trip all classes" { @@ -405,39 +553,39 @@ test "Message.iterateAttributes: invalid attribute length zero" { try testing.expectError(error.InvalidAttribute, it.next()); } +const rfc_5769_test_vector = [_]u8{ + 0x00, 0x01, 0x00, 0x58, + 0x21, 0x12, 0xa4, 0x42, + 0xb7, 0xe7, 0xa7, 0x01, + 0xbc, 0x34, 0xd6, 0x86, + 0xfa, 0x87, 0xdf, 0xae, + 0x80, 0x22, 0x00, 0x10, + 0x53, 0x54, 0x55, 0x4e, + 0x20, 0x74, 0x65, 0x73, + 0x74, 0x20, 0x63, 0x6c, + 0x69, 0x65, 0x6e, 0x74, + 0x00, 0x24, 0x00, 0x04, + 0x6e, 0x00, 0x01, 0xff, + 0x80, 0x29, 0x00, 0x08, + 0x93, 0x2f, 0xf9, 0xb1, + 0x51, 0x26, 0x3b, 0x36, + 0x00, 0x06, 0x00, 0x09, + 0x65, 0x76, 0x74, 0x6a, + 0x3a, 0x68, 0x36, 0x76, + 0x59, 0x20, 0x20, 0x20, + 0x00, 0x08, 0x00, 0x14, + 0x9a, 0xea, 0xa7, 0x0c, + 0xbf, 0xd8, 0xcb, 0x56, + 0x78, 0x1e, 0xf2, 0xb5, + 0xb2, 0xd3, 0xf2, 0x49, + 0xc1, 0xb5, 0x71, 0xa2, + 0x80, 0x28, 0x00, 0x04, + 0xe5, 0x7a, 0x3b, 0xcf, +}; + // test vectors (RFC 5769) test "Message: iterator attributes" { - const data = [_]u8{ - 0x00, 0x01, 0x00, 0x58, - 0x21, 0x12, 0xa4, 0x42, - 0xb7, 0xe7, 0xa7, 0x01, - 0xbc, 0x34, 0xd6, 0x86, - 0xfa, 0x87, 0xdf, 0xae, - 0x80, 0x22, 0x00, 0x10, - 0x53, 0x54, 0x55, 0x4e, - 0x20, 0x74, 0x65, 0x73, - 0x74, 0x20, 0x63, 0x6c, - 0x69, 0x65, 0x6e, 0x74, - 0x00, 0x24, 0x00, 0x04, - 0x6e, 0x00, 0x01, 0xff, - 0x80, 0x29, 0x00, 0x08, - 0x93, 0x2f, 0xf9, 0xb1, - 0x51, 0x26, 0x3b, 0x36, - 0x00, 0x06, 0x00, 0x09, - 0x65, 0x76, 0x74, 0x6a, - 0x3a, 0x68, 0x36, 0x76, - 0x59, 0x20, 0x20, 0x20, - 0x00, 0x08, 0x00, 0x14, - 0x9a, 0xea, 0xa7, 0x0c, - 0xbf, 0xd8, 0xcb, 0x56, - 0x78, 0x1e, 0xf2, 0xb5, - 0xb2, 0xd3, 0xf2, 0x49, - 0xc1, 0xb5, 0x71, 0xa2, - 0x80, 0x28, 0x00, 0x04, - 0xe5, 0x7a, 0x3b, 0xcf, - }; - - const message = try Message.parse(&data); + const message = try Message.parse(&rfc_5769_test_vector); try testing.expectEqual(.request, message.header.message_type.class()); try testing.expectEqual(.binding, message.header.message_type.method()); @@ -474,3 +622,75 @@ test "Message.iterateAttributes: invalid attribute length not multiple of 4" { var it = msg.iterateAttributes(&.{}); try testing.expectError(error.InvalidAttribute, it.next()); } + +test "Writer: write rfc message" { + var buffer: [1024]u8 = undefined; + + var out = Writer.init(&buffer, .{ + .password = "VOkJxbRl1RmTxUk/WvJxBt", + .padding_byte = 0x20, + }); + + try out.writeHeader(.{ + .message_type = .fromClassAndMethod(.request, .binding), + .transaction_id = std.mem.readInt(u96, rfc_5769_test_vector[8..20], .big), + .message_length = 0, + }); + + try out.writeAttribute(.{ .software = "STUN test client" }); + try out.writeAttribute(.{ .priority = 0x6E0001FF }); + try out.writeAttribute(.{ .ice_controlled = 0x932FF9B151263B36 }); + var username = [_][]const u8{ "evtj", ":", "h6vY" }; + try out.writeRaw(.username, &username); + try out.writeAttribute(.{ .message_integrity = &.{} }); + try out.writeAttribute(.fingerprint); + + try std.testing.expectEqualSlices(u8, &rfc_5769_test_vector, out.final()); +} + +test "Writer: write mapped and xor mapped addresses" { + const expected = [_]u8{ + 0x01, 0x01, 0x00, 0x30, + 0x21, 0x12, 0xA4, 0x42, + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, + // MAPPED-ADDRESS IPv4 + 0x00, 0x01, 0x00, 0x08, + 0x00, 0x01, 0x80, 0x55, + 192, 0, 2, 1, + // MAPPED-ADDRESS IPv6 + 0x00, 0x01, 0x00, 0x14, + 0x00, 0x02, 0x80, 0x55, + 0x20, 0x01, 0x0D, 0xB8, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, + // XOR-MAPPED-ADDRESS IPv4 + 0x00, 0x20, 0x00, 0x08, + 0x00, 0x01, 0xA1, 0x47, + 0xE1, 0x12, 0xA6, 0x43, + }; + + const ipv4 = Io.net.IpAddress{ .ip4 = .{ .bytes = .{ 192, 0, 2, 1 }, .port = 32853 } }; + const ipv6 = Io.net.IpAddress{ .ip6 = .{ + .bytes = .{ + 0x20, 0x01, 0x0D, 0xB8, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + }, + .port = 32853, + } }; + + var buffer: [1024]u8 = undefined; + var writer: Writer = .init(&buffer, .{}); + try writer.writeHeader(.{ + .message_type = .fromClassAndMethod(.success_response, .binding), + .transaction_id = std.mem.readInt(u96, expected[8..20], .big), + .message_length = 0, + }); + + try writer.writeAttribute(.{ .mapped_address = ipv4 }); + try writer.writeAttribute(.{ .mapped_address = ipv6 }); + try writer.writeAttribute(.{ .xor_mapped_address = ipv4 }); + try testing.expectEqualSlices(u8, &expected, writer.final()); +}