From 0e8c2201a05b1a00e681748a9fd917076edcc8b1 Mon Sep 17 00:00:00 2001 From: Billal GHILAS Date: Wed, 6 May 2026 09:19:29 +0100 Subject: [PATCH] feat(h264): add parameter sets iterator to decorder configuration record --- src/h264.zig | 125 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 99 insertions(+), 26 deletions(-) diff --git a/src/h264.zig b/src/h264.zig index 7db2634..340a9bb 100644 --- a/src/h264.zig +++ b/src/h264.zig @@ -43,6 +43,10 @@ pub const NalType = enum(u5) { fu_b, unspecified30, unspecified31, + + pub inline fn isKeyframe(self: NalType) bool { + return self == .idr; + } }; /// Describes the NAL unit header, which is the first byte of a NAL unit. @@ -329,6 +333,9 @@ pub const DecoderConfigurationRecord = struct { profile_compatibility: u8, avc_level_indication: u8, length_size: u8, + ps_bytes: []const u8 = &.{}, + + pub const H264Error = error{InvalidH264DCR}; pub fn initFromSps(sps: *const Sps) DecoderConfigurationRecord { return DecoderConfigurationRecord{ @@ -339,48 +346,82 @@ pub const DecoderConfigurationRecord = struct { }; } - pub fn parse(data: []const u8) !DecoderConfigurationRecord { - var reader = std.Io.Reader.fixed(data); - _ = try reader.takeByte(); + pub fn parse(data: []const u8) H264Error!DecoderConfigurationRecord { + if (data.len < 7) return error.InvalidH264DCR; return DecoderConfigurationRecord{ - .avc_profile_indication = try reader.takeByte(), - .profile_compatibility = try reader.takeByte(), - .avc_level_indication = try reader.takeByte(), - .length_size = (try reader.takeByte() & 0x03) + 1, + .avc_profile_indication = data[1], + .profile_compatibility = data[2], + .avc_level_indication = data[3], + .length_size = (data[4] & 0x03) + 1, + .ps_bytes = data[5..], }; } - pub const Writer = struct { - writer: *std.Io.Writer, + pub fn iterateParameterSets(self: *const DecoderConfigurationRecord) Iterator { + return .init(self.ps_bytes); + } - pub fn init(writer: *std.Io.Writer) Writer { - return Writer{ .writer = writer }; + pub fn writer(self: *const DecoderConfigurationRecord, w: *std.Io.Writer) !Writer { + try w.writeByte(1); // configurationVersion + try w.writeByte(self.avc_profile_indication); + try w.writeByte(self.profile_compatibility); + try w.writeByte(self.avc_level_indication); + try w.writeByte(0xFC | (self.length_size - 1)); + return .{ .writer = w }; + } + + pub const Iterator = struct { + reader: std.Io.Reader, + nal_type: NalType, + count: u8, + + pub fn init(bytes: []const u8) Iterator { + return .{ + .reader = std.Io.Reader.fixed(bytes[1..]), + .nal_type = .sps, + .count = bytes[0] & 0x1F, + }; } - pub fn write(self: *@This(), config: *const DecoderConfigurationRecord) !void { - try self.writer.writeByte(1); // configurationVersion - try self.writer.writeByte(config.avc_profile_indication); - try self.writer.writeByte(config.profile_compatibility); - try self.writer.writeByte(config.avc_level_indication); - try self.writer.writeByte(0xFC | (config.length_size - 1)); + pub fn next(it: *Iterator) H264Error!?[]const u8 { + if (it.count == 0) { + switch (it.nal_type) { + .sps => { + it.nal_type = .pps; + it.count = it.reader.takeByte() catch return error.InvalidH264DCR; + return it.next(); + }, + .pps => return null, + else => unreachable, + } + } + + const nal_size = it.reader.takeInt(u16, .big) catch return error.InvalidH264DCR; + const result = it.reader.take(nal_size) catch return error.InvalidH264DCR; + it.count -= 1; + return result; } + }; + + pub const Writer = struct { + writer: *std.Io.Writer, - pub fn writeSpsCount(self: *@This(), count: u8) !void { + pub fn writeSpsCount(self: *Writer, count: u8) !void { std.debug.assert(count <= 31); try self.writer.writeByte(0xE0 | (count & 0x1F)); } - pub fn writePpsCount(self: *@This(), count: u8) !void { + pub fn writePpsCount(self: *Writer, count: u8) !void { try self.writer.writeByte(count); } - pub fn writeNalUnit(self: *@This(), nal_data: []const u8) !void { + pub fn writeNalUnit(self: *Writer, nal_data: []const u8) !void { try self.writer.writeInt(u16, @intCast(nal_data.len), .big); try self.writer.writeAll(nal_data); } - pub fn writeBase64NalUnit(self: *@This(), nal_data: []const u8) !void { + pub fn writeBase64NalUnit(self: *Writer, nal_data: []const u8) !void { var decoder = std.base64.standard.Decoder; const nal_size = try decoder.calcSizeForSlice(nal_data); @@ -392,7 +433,7 @@ pub const DecoderConfigurationRecord = struct { }; test "parse valid configuration" { - const data = [_]u8{ 1, 100, 0, 40, 0xFF, 0x00 }; + const data = [_]u8{ 1, 100, 0, 40, 0xFF, 0xE0, 0x00 }; const config = try DecoderConfigurationRecord.parse(&data); try std.testing.expect(config.avc_profile_indication == 100); @@ -413,8 +454,7 @@ pub const DecoderConfigurationRecord = struct { var buf: [64]u8 = undefined; var w = std.Io.Writer.fixed(&buf); - var dcr_writer = DecoderConfigurationRecord.Writer.init(&w); - try dcr_writer.write(&config); + var dcr_writer = try config.writer(&w); try dcr_writer.writeSpsCount(1); try dcr_writer.writeNalUnit(&sps_nal); try dcr_writer.writePpsCount(1); @@ -438,8 +478,7 @@ pub const DecoderConfigurationRecord = struct { var buf: [64]u8 = undefined; var w = std.Io.Writer.fixed(&buf); - var dcr_writer = DecoderConfigurationRecord.Writer.init(&w); - try dcr_writer.write(&original); + var dcr_writer = try original.writer(&w); try dcr_writer.writeSpsCount(0); try dcr_writer.writePpsCount(0); @@ -465,6 +504,40 @@ pub const DecoderConfigurationRecord = struct { try std.testing.expectEqual(sps.level_idc, config.avc_level_indication); try std.testing.expectEqual(@as(u8, 4), config.length_size); } + + test "iterator parameter sets" { + const bytes = [_]u8{ + 0x01, 0x4d, 0x40, 0x1f, 0xff, 0xe1, 0x00, + 0x1d, 0x67, 0x4d, 0x40, 0x1f, 0xec, 0xa0, + 0x6c, 0x1f, 0xf2, 0x44, 0x7f, 0xe1, 0xe2, + 0x01, 0xe2, 0xa2, 0x00, 0x00, 0x03, 0x00, + 0x64, 0x00, 0x00, 0x12, 0xbc, 0x1e, 0x30, + 0x63, 0x2c, 0x01, 0x00, 0x04, 0x68, 0xef, + 0x86, 0xf2, + }; + + const dcr = try parse(&bytes); + try std.testing.expectEqual(77, dcr.avc_profile_indication); + try std.testing.expectEqual(64, dcr.profile_compatibility); + try std.testing.expectEqual(31, dcr.avc_level_indication); + try std.testing.expectEqual(4, dcr.length_size); + + const expected_sps = [_]u8{ + 0x67, 0x4d, 0x40, 0x1f, 0xec, + 0xa0, 0x6c, 0x1f, 0xf2, 0x44, + 0x7f, 0xe1, 0xe2, 0x01, 0xe2, + 0xa2, 0x00, 0x00, 0x03, 0x00, + 0x64, 0x00, 0x00, 0x12, 0xbc, + 0x1e, 0x30, 0x63, 0x2c, + }; + + const expected_pps = [_]u8{ 0x68, 0xef, 0x86, 0xf2 }; + + var it = dcr.iterateParameterSets(); + try std.testing.expectEqualSlices(u8, &expected_sps, (try it.next()).?); + try std.testing.expectEqualSlices(u8, &expected_pps, (try it.next()).?); + try std.testing.expectEqual(null, try it.next()); + } }; const ParameterSetReader = struct {