Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 254 additions & 34 deletions src/stun/stun.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const Io = std.Io;

pub const magic_cookie: u32 = 0x2112A442;
pub const header_size = 20;
const fingerprint_xor: u32 = 0x5354554e;

pub const Class = enum(u2) {
request,
Expand Down Expand Up @@ -125,10 +126,26 @@ pub const Attribute = union(AttributeType) {
userhash: []const u8,
priority: u32,
software: []const u8,
fingerprint: u32,
fingerprint: void,
ice_controlled: u64,
ice_controlling: u64,
unknown: struct { AttributeType, []const u8 },

pub fn size(attribute: Attribute) u16 {
return switch (attribute) {
.priority, .fingerprint => 4,
.ice_controlled, .ice_controlling => 8,
.message_integrity => 20,
.use_candidate => 0,
.software, .username, .userhash => |slice| @intCast(slice.len),
.mapped_address, .xor_mapped_address => |ip| switch (ip) {
.ip4 => 8,
.ip6 => 20,
},
.error_code => |err| @intCast(err.reason.len + 4),
else => 0,
};
}
};

pub const AttributeIterator = struct {
Expand Down Expand Up @@ -193,7 +210,7 @@ pub const AttributeIterator = struct {
if (attr_len != 4) break :blk error.InvalidAttribute;
const fingerprint = std.mem.readInt(u32, attr_value[0..4], .big);
try it.verifyFingerprint(fingerprint);
break :blk .{ .fingerprint = fingerprint };
break :blk .fingerprint;
},
else => .{ .unknown = .{ attr_type, attr_value } },
};
Expand Down Expand Up @@ -281,12 +298,143 @@ pub const AttributeIterator = struct {

var hasher: std.hash.Crc32 = .init();
hasher.update(msg[0 .. msg.len - 8]);
if (hasher.final() ^ 0x5354554e != expected_value) {
if (hasher.final() ^ fingerprint_xor != expected_value) {
return error.FingerprintCheckFailed;
}
}
};

pub const Writer = struct {
writer: Io.Writer,
options: WriterOptions,

pub const WriterOptions = struct {
password: ?[]const u8 = null,
padding_byte: u8 = 0,
};

pub fn init(buffer: []u8, options: WriterOptions) Writer {
return .{ .writer = .fixed(buffer), .options = options };
}

pub fn writeHeader(msg_writer: *Writer, header: Header) !void {
try msg_writer.writer.writeStruct(header, .big);
}

pub fn writeRaw(msg_writer: *Writer, attr_type: AttributeType, content: [][]const u8) !void {
var w = &msg_writer.writer;

try w.writeInt(u16, @intFromEnum(attr_type), .big);
const length = try w.writableArray(2);
const pos = w.end;
try msg_writer.writer.writeVecAll(content);

const attr_size: u16 = @intCast(w.end - pos);
const padding = switch (@rem(attr_size, 4)) {
0 => 0,
else => |v| 4 - v,
};
@memset(try w.writableSlice(padding), msg_writer.options.padding_byte);
std.mem.writeInt(u16, length, attr_size, .big);
}

pub fn writeAttribute(msg_writer: *Writer, attribute: Attribute) !void {
var out = &msg_writer.writer;

try out.writeInt(u16, @intFromEnum(attribute), .big);
try out.writeInt(u16, attribute.size(), .big);
switch (attribute) {
.priority => |p| try out.writeInt(u32, p, .big),
.ice_controlled, .ice_controlling => |tie_breaker| try out.writeInt(u64, tie_breaker, .big),
.message_integrity => try msg_writer.writeMessageIntegrity(),
.fingerprint => try writeFingerprint(&msg_writer.writer),
.software, .username, .userhash => |slice| try out.writeAll(slice),
.mapped_address => |addr| try msg_writer.writeIpAddress(addr, false),
.xor_mapped_address => |addr| try msg_writer.writeIpAddress(addr, true),
.error_code => |err| {
try msg_writer.writer.writeInt(u32, @as(u24, err.code / 100) << 16 | (err.code % 100), .big);
try msg_writer.writer.writeAll(err.reason);
},
else => return error.UnknownAttribute,
}

const padding = switch (@rem(out.end, 4)) {
0 => 0,
else => |v| 4 - v,
};
@memset(try out.writableSlice(padding), msg_writer.options.padding_byte);
}

pub fn final(msg_writer: *Writer) []const u8 {
const result = msg_writer.writer.buffered();
const msg_length: u16 = @intCast(result.len - header_size);
std.mem.writeInt(u16, result[2..4], msg_length, .big);
return result;
}

fn writeMessageIntegrity(msg_writer: *Writer) !void {
var w = &msg_writer.writer;

const buf = w.buffered();
const hash = try w.writableArray(20);
const msg_length: u16 = @intCast(w.end - header_size); // 4 bytes of already written attribute header

var hasher = std.crypto.auth.hmac.HmacSha1.init(msg_writer.options.password.?);
hasher.update(buf[0..2]);
hasher.update(&std.mem.toBytes(std.mem.nativeToBig(u16, msg_length)));
hasher.update(buf[4 .. buf.len - 4]);
hasher.final(hash);
}

fn writeFingerprint(w: *Io.Writer) !void {
const buf = w.buffered();
const msg_size: u16 = @intCast(buf.len - header_size + 4);

var hasher: std.hash.Crc32 = .init();
hasher.update(buf[0..2]);
hasher.update(&std.mem.toBytes(std.mem.nativeToBig(u16, msg_size)));
hasher.update(buf[4 .. buf.len - 4]);

try w.writeInt(u32, hasher.final() ^ fingerprint_xor, .big);
}

fn writeIpAddress(msg_writer: *Writer, addr: Io.net.IpAddress, xor: bool) !void {
var out = &msg_writer.writer;
const cookie = std.mem.toBytes(std.mem.nativeToBig(u32, magic_cookie));
switch (addr) {
.ip4 => |ipv4| {
try out.writeInt(u16, 1, .big);
if (xor) {
const xor_port: u16 = ipv4.port ^ @as(u16, magic_cookie >> 16);
try out.writeAll(&std.mem.toBytes(std.mem.nativeToBig(u16, xor_port)));

const slice = try out.writableSlice(cookie.len);
for (slice, 0..) |*b, idx| b.* = cookie[idx] ^ ipv4.bytes[idx];
} else {
try out.writeInt(u16, ipv4.port, .big);
try out.writeAll(&ipv4.bytes);
}
},
.ip6 => |ipv6| {
try out.writeInt(u16, 2, .big);
if (xor) {
const xor_port: u16 = ipv6.port ^ @as(u16, magic_cookie >> 16);
try out.writeAll(&std.mem.toBytes(std.mem.nativeToBig(u16, xor_port)));

const slice = try out.writableSlice(ipv6.bytes.len);
const txid = out.buffer[8..20];

for (slice, 0..) |*b, idx| b.* = cookie[idx] ^ ipv6.bytes[idx];
for (slice[4..], 0..) |*b, idx| b.* = txid[idx] ^ ipv6.bytes[idx];
} else {
try out.writeInt(u16, ipv6.port, .big);
try out.writeAll(&ipv6.bytes);
}
},
}
}
};

const testing = std.testing;

test "MessageType: round-trip all classes" {
Expand Down Expand Up @@ -405,39 +553,39 @@ test "Message.iterateAttributes: invalid attribute length zero" {
try testing.expectError(error.InvalidAttribute, it.next());
}

const rfc_5769_test_vector = [_]u8{
0x00, 0x01, 0x00, 0x58,
0x21, 0x12, 0xa4, 0x42,
0xb7, 0xe7, 0xa7, 0x01,
0xbc, 0x34, 0xd6, 0x86,
0xfa, 0x87, 0xdf, 0xae,
0x80, 0x22, 0x00, 0x10,
0x53, 0x54, 0x55, 0x4e,
0x20, 0x74, 0x65, 0x73,
0x74, 0x20, 0x63, 0x6c,
0x69, 0x65, 0x6e, 0x74,
0x00, 0x24, 0x00, 0x04,
0x6e, 0x00, 0x01, 0xff,
0x80, 0x29, 0x00, 0x08,
0x93, 0x2f, 0xf9, 0xb1,
0x51, 0x26, 0x3b, 0x36,
0x00, 0x06, 0x00, 0x09,
0x65, 0x76, 0x74, 0x6a,
0x3a, 0x68, 0x36, 0x76,
0x59, 0x20, 0x20, 0x20,
0x00, 0x08, 0x00, 0x14,
0x9a, 0xea, 0xa7, 0x0c,
0xbf, 0xd8, 0xcb, 0x56,
0x78, 0x1e, 0xf2, 0xb5,
0xb2, 0xd3, 0xf2, 0x49,
0xc1, 0xb5, 0x71, 0xa2,
0x80, 0x28, 0x00, 0x04,
0xe5, 0x7a, 0x3b, 0xcf,
};

// test vectors (RFC 5769)
test "Message: iterator attributes" {
const data = [_]u8{
0x00, 0x01, 0x00, 0x58,
0x21, 0x12, 0xa4, 0x42,
0xb7, 0xe7, 0xa7, 0x01,
0xbc, 0x34, 0xd6, 0x86,
0xfa, 0x87, 0xdf, 0xae,
0x80, 0x22, 0x00, 0x10,
0x53, 0x54, 0x55, 0x4e,
0x20, 0x74, 0x65, 0x73,
0x74, 0x20, 0x63, 0x6c,
0x69, 0x65, 0x6e, 0x74,
0x00, 0x24, 0x00, 0x04,
0x6e, 0x00, 0x01, 0xff,
0x80, 0x29, 0x00, 0x08,
0x93, 0x2f, 0xf9, 0xb1,
0x51, 0x26, 0x3b, 0x36,
0x00, 0x06, 0x00, 0x09,
0x65, 0x76, 0x74, 0x6a,
0x3a, 0x68, 0x36, 0x76,
0x59, 0x20, 0x20, 0x20,
0x00, 0x08, 0x00, 0x14,
0x9a, 0xea, 0xa7, 0x0c,
0xbf, 0xd8, 0xcb, 0x56,
0x78, 0x1e, 0xf2, 0xb5,
0xb2, 0xd3, 0xf2, 0x49,
0xc1, 0xb5, 0x71, 0xa2,
0x80, 0x28, 0x00, 0x04,
0xe5, 0x7a, 0x3b, 0xcf,
};

const message = try Message.parse(&data);
const message = try Message.parse(&rfc_5769_test_vector);
try testing.expectEqual(.request, message.header.message_type.class());
try testing.expectEqual(.binding, message.header.message_type.method());

Expand Down Expand Up @@ -474,3 +622,75 @@ test "Message.iterateAttributes: invalid attribute length not multiple of 4" {
var it = msg.iterateAttributes(&.{});
try testing.expectError(error.InvalidAttribute, it.next());
}

test "Writer: write rfc message" {
var buffer: [1024]u8 = undefined;

var out = Writer.init(&buffer, .{
.password = "VOkJxbRl1RmTxUk/WvJxBt",
.padding_byte = 0x20,
});

try out.writeHeader(.{
.message_type = .fromClassAndMethod(.request, .binding),
.transaction_id = std.mem.readInt(u96, rfc_5769_test_vector[8..20], .big),
.message_length = 0,
});

try out.writeAttribute(.{ .software = "STUN test client" });
try out.writeAttribute(.{ .priority = 0x6E0001FF });
try out.writeAttribute(.{ .ice_controlled = 0x932FF9B151263B36 });
var username = [_][]const u8{ "evtj", ":", "h6vY" };
try out.writeRaw(.username, &username);
try out.writeAttribute(.{ .message_integrity = &.{} });
try out.writeAttribute(.fingerprint);

try std.testing.expectEqualSlices(u8, &rfc_5769_test_vector, out.final());
}

test "Writer: write mapped and xor mapped addresses" {
const expected = [_]u8{
0x01, 0x01, 0x00, 0x30,
0x21, 0x12, 0xA4, 0x42,
0x00, 0x01, 0x02, 0x03,
0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0A, 0x0B,
// MAPPED-ADDRESS IPv4
0x00, 0x01, 0x00, 0x08,
0x00, 0x01, 0x80, 0x55,
192, 0, 2, 1,
// MAPPED-ADDRESS IPv6
0x00, 0x01, 0x00, 0x14,
0x00, 0x02, 0x80, 0x55,
0x20, 0x01, 0x0D, 0xB8,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
// XOR-MAPPED-ADDRESS IPv4
0x00, 0x20, 0x00, 0x08,
0x00, 0x01, 0xA1, 0x47,
0xE1, 0x12, 0xA6, 0x43,
};

const ipv4 = Io.net.IpAddress{ .ip4 = .{ .bytes = .{ 192, 0, 2, 1 }, .port = 32853 } };
const ipv6 = Io.net.IpAddress{ .ip6 = .{
.bytes = .{
0x20, 0x01, 0x0D, 0xB8, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
},
.port = 32853,
} };

var buffer: [1024]u8 = undefined;
var writer: Writer = .init(&buffer, .{});
try writer.writeHeader(.{
.message_type = .fromClassAndMethod(.success_response, .binding),
.transaction_id = std.mem.readInt(u96, expected[8..20], .big),
.message_length = 0,
});

try writer.writeAttribute(.{ .mapped_address = ipv4 });
try writer.writeAttribute(.{ .mapped_address = ipv6 });
try writer.writeAttribute(.{ .xor_mapped_address = ipv4 });
try testing.expectEqualSlices(u8, &expected, writer.final());
}
Loading