From 031c26817f3fca14f736d82c310c0ea918ca8022 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 13 May 2025 00:48:05 +0800 Subject: [PATCH 1/4] new: longlonguint multiplication --- src/longlonguint.jl | 32 +++++++++++++++++++++++++++++- test/longlonguint.jl | 46 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/src/longlonguint.jl b/src/longlonguint.jl index db8ad3a..6923a2e 100644 --- a/src/longlonguint.jl +++ b/src/longlonguint.jl @@ -102,7 +102,7 @@ end function _sadd(x::NTuple{C,UInt}, y::NTuple{C,UInt}, c::Bool) where {C} v1, c1 = Base.add_with_overflow(x[C], y[C]) if c - v2, c2 = Base.add_with_overflow(v1, c) + v2, c2 = Base.add_with_overflow(v1, UInt(c)) c = c1 || c2 return (_sadd(x[1:C-1], y[1:C-1], c)..., v2) else @@ -125,6 +125,36 @@ function _ssub(x::NTuple{C,UInt}, y::NTuple{C,UInt}, c::Bool) where {C} return (_ssub(x[1:C-1], y[1:C-1], c1)..., v1) end end + +function Base.:(*)(x::LongLongUInt{C}, y::LongLongUInt{C}) where {C} + result = zero(LongLongUInt{C}) + for i in 1:C + x.content[C-i+1] == 0 && continue + for j in 1:C + y.content[C-j+1] == 0 && continue + # Skip if either position is out of bounds for the result + pos = i + j - 1 + pos > C && continue + + # Multiply the corresponding elements + mres = Base.widemul(x.content[C-i+1], y.content[C-j+1]) + + # Add the low part to the result at position pos + partial = LongLongUInt(ntuple(k -> (k == C-pos+1 ? UInt(mres & typemax(UInt)) : (k == C-pos ? UInt(mres >> bsizeof(UInt)) : zero(UInt))), Val{C}())) + result = result + partial + + # # Add the high part to the result at position pos-1 if there's overflow + # if overflow && pos-1 >= 1 + # hi = x.content[C-i+1] >> (sizeof(UInt) * 4) + # hi = hi * (y.content[C-j+1] >> (sizeof(UInt) * 4)) + # partial = LongLongUInt(ntuple(k -> (k == C-(pos-1)+1 ? hi : zero(UInt)), Val{C}())) + # result = result + partial + # end + end + end + + return result +end Base.count_ones(x::LongLongUInt) = sum(count_ones, x.content) Base.bitstring(x::LongLongUInt) = join(bitstring.(x.content), "") diff --git a/test/longlonguint.jl b/test/longlonguint.jl index 56c4780..f65300c 100644 --- a/test/longlonguint.jl +++ b/test/longlonguint.jl @@ -38,6 +38,11 @@ using Test, BitBasis BitBasis.max_num_elements(Int, 2) == 63 @test count_ones(bmask(LongLongUInt{50}, 1:3000)) == 3000 + + # Hard add + a = LongLongUInt((ntuple(i -> UInt64(0), Val{1}())..., ntuple(i -> rand(UInt64), Val{9}())...)) + b = LongLongUInt((ntuple(i -> UInt64(0), Val{1}())..., ntuple(i -> rand(UInt64), Val{9}())...)) + @test BigInt(a + b) == BigInt(a) + BigInt(b) end @testset "shift" begin @@ -136,3 +141,44 @@ end @test log2i(LongLongUInt((0,2))) == 1 @test log2i(LongLongUInt((1,0))) == 64 end + +@testset "multiplication" begin + # Test basic multiplication + x = LongLongUInt((2,)) + y = LongLongUInt((3,)) + @test x * y == LongLongUInt((6,)) + + # Test multiplication with overflow within a single UInt + x = LongLongUInt((0x8000000000000000,)) + y = LongLongUInt((2,)) + @test x * y == LongLongUInt((0,)) # Overflow truncated to original size + + # Test multiplication across multiple UInts + x = LongLongUInt((1, 0)) # 1 << 64 + y = LongLongUInt((0, 5)) # 2 << 64 + @test x * y == LongLongUInt((5, 0)) # Result truncated to original size + + # Test with values in both positions + x = LongLongUInt((0, 2)) + y = LongLongUInt((3, 4)) + # Expected: (1*3) << 128 + (1*4 + 2*3) << 64 + (2*4), truncated to 2 UInts + @test x * y == LongLongUInt((6, 8)) + + # Verify with BigInt conversion + x = LongLongUInt((10, 7)) + y = LongLongUInt((0, 3)) + result = x * y + expected = BigInt(x) * BigInt(y) + expected_truncated = expected & ((BigInt(1) << 128) - 1) # Truncate to 128 bits + @test BigInt(result) == expected_truncated + + x = LongLongUInt((UInt(0), typemax(UInt64))) + y = LongLongUInt((UInt(0), UInt(3))) + @test BigInt(x * y) == BigInt(x) * BigInt(y) + + # Hard instance + a = LongLongUInt((ntuple(i -> UInt64(0), Val{10}())..., ntuple(i -> rand(UInt64), Val{5}())...)) + b = LongLongUInt((ntuple(i -> UInt64(0), Val{6}())..., ntuple(i -> rand(UInt64), Val{9}())...)) + @show BigInt(a * b) - BigInt(a) * BigInt(b) + @test BigInt(a * b) == BigInt(a) * BigInt(b) +end From 617e1ad35afcd2cef5fd4e280194aae1b1c5d5ed Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 13 May 2025 01:29:50 +0800 Subject: [PATCH 2/4] update --- src/longlonguint.jl | 32 ++++++++++++++++++++++---------- test/longlonguint.jl | 11 +++++++++++ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/longlonguint.jl b/src/longlonguint.jl index 6923a2e..cc2b992 100644 --- a/src/longlonguint.jl +++ b/src/longlonguint.jl @@ -3,7 +3,7 @@ A `LongLongUInt{C}` is an integer with `C` `UInt` numbers to store the value. """ -struct LongLongUInt{C} <: Integer +struct LongLongUInt{C} <: Unsigned content::NTuple{C, UInt} function LongLongUInt{C}(content::NTuple{C, UInt}) where {C} new{C}(content) @@ -15,6 +15,9 @@ struct LongLongUInt{C} <: Integer LongLongUInt{C}(content) end end +Base.string(x::LongLongUInt{C}) where {C} = "LongLongUInt{$C}(" * join(bitstring.(x.content), "") * ")" +Base.show(io::IO, x::LongLongUInt{C}) where {C} = print(io, string(x)) +Base.show(io::IO, ::MIME"text/plain", x::LongLongUInt{C}) where {C} = print(io, string(x)) bsizeof(::LongLongUInt{C}) where C = bsizeof(UInt64) * C nint(::LongLongUInt{C}) where {C} = C Base.Int(x::LongLongUInt{1}) = Int(x.content[1]) @@ -142,19 +145,11 @@ function Base.:(*)(x::LongLongUInt{C}, y::LongLongUInt{C}) where {C} # Add the low part to the result at position pos partial = LongLongUInt(ntuple(k -> (k == C-pos+1 ? UInt(mres & typemax(UInt)) : (k == C-pos ? UInt(mres >> bsizeof(UInt)) : zero(UInt))), Val{C}())) result = result + partial - - # # Add the high part to the result at position pos-1 if there's overflow - # if overflow && pos-1 >= 1 - # hi = x.content[C-i+1] >> (sizeof(UInt) * 4) - # hi = hi * (y.content[C-j+1] >> (sizeof(UInt) * 4)) - # partial = LongLongUInt(ntuple(k -> (k == C-(pos-1)+1 ? hi : zero(UInt)), Val{C}())) - # result = result + partial - # end end end - return result end + Base.count_ones(x::LongLongUInt) = sum(count_ones, x.content) Base.bitstring(x::LongLongUInt) = join(bitstring.(x.content), "") @@ -170,3 +165,20 @@ Base.hash(x::LongLongUInt{C}) where{C} = hash(x.content) # these APIs will are used in SparseTN BitBasis.log2i(x::LongLongUInt{C}) where C = floor(Int, log2(Float64(BigInt(x)))) Base.BigInt(x::LongLongUInt{C}) where C = mapfoldl(x -> BigInt(x), (x, y) -> ((x << 64) | y), x.content) + +function Base.Int(x::LongLongUInt) + if all(iszero, x.content[2:end]) && x.content[1] < typemax(Int) + return Int(x.content[1]) + else + throw(InexactError(:Int, x)) + end +end + +function Base.hash(bits_tuple::Tuple{LongLongUInt{C}, Vararg{LongLongUInt{C}, M}}) where{M, C} + N = M + 1 + hash0 = Base.hash(bits_tuple[1].content) + for i in 2:N + hash0 = Base.hash(bits_tuple[i].content, hash0) + end + return hash0 +end \ No newline at end of file diff --git a/test/longlonguint.jl b/test/longlonguint.jl index f65300c..899a736 100644 --- a/test/longlonguint.jl +++ b/test/longlonguint.jl @@ -2,6 +2,8 @@ using Test, BitBasis @testset "longlonguint" begin x = LongLongUInt((3,)) + @test abs(x) == x + @test Int(x) === 3 @test LongLongUInt{1}(x) == x @test Int(x) === 3 @test bsizeof(x) == 64 @@ -18,6 +20,7 @@ using Test, BitBasis @test promote(UInt(1), LongLongUInt((3, 4))) == (LongLongUInt((0, 1)), LongLongUInt((3, 4))) y = LongLongUInt((5, 7)) + @test_throws InexactError Int(y) @test one(y) == LongLongUInt((0, 1)) @test x & y == LongLongUInt((1, 6)) @test x | y == LongLongUInt((7, 7)) @@ -182,3 +185,11 @@ end @show BigInt(a * b) - BigInt(a) * BigInt(b) @test BigInt(a * b) == BigInt(a) * BigInt(b) end + +@testset "LongLongUInt hash" begin + b1 = bmask(LongLongUInt{1}, 1) + b2 = bmask(LongLongUInt{1}, 2) + b3 = bmask(LongLongUInt{1}, 3) + + @test hash((b1, b2, b3)) == hash((b1, b2, b3)) +end \ No newline at end of file From 62a60369488261766acac7e87f0be7a612df7e23 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 13 May 2025 01:37:57 +0800 Subject: [PATCH 3/4] new longlonguint - div --- src/longlonguint.jl | 42 ++++++++++++++++++++++++++++++++++ test/longlonguint.jl | 54 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/src/longlonguint.jl b/src/longlonguint.jl index cc2b992..ea9276f 100644 --- a/src/longlonguint.jl +++ b/src/longlonguint.jl @@ -150,6 +150,48 @@ function Base.:(*)(x::LongLongUInt{C}, y::LongLongUInt{C}) where {C} return result end +function Base.div(x::LongLongUInt{C}, y::LongLongUInt{C}) where {C} + y == zero(LongLongUInt{C}) && throw(DivideError()) + x < y && return zero(LongLongUInt{C}) + x == y && return one(LongLongUInt{C}) + + # Initialize quotient and remainder + quotient = zero(LongLongUInt{C}) + remainder = x + + # Find the highest bit position in y + y_highest_bit = 0 + for i in 1:C + if y.content[i] != 0 + y_highest_bit = (C - i) * 64 + (64 - leading_zeros(y.content[i])) + break + end + end + + # Find the highest bit position in x + x_highest_bit = 0 + for i in 1:C + if x.content[i] != 0 + x_highest_bit = (C - i) * 64 + (64 - leading_zeros(x.content[i])) + break + end + end + + # Long division algorithm + for i in (x_highest_bit - y_highest_bit + 1):-1:1 + # Shift y left by i-1 bits + shifted_y = y << (i - 1) + + # If remainder >= shifted_y, subtract and set bit in quotient + if remainder >= shifted_y + remainder = remainder - shifted_y + quotient = quotient | (one(LongLongUInt{C}) << (i - 1)) + end + end + + return quotient +end + Base.count_ones(x::LongLongUInt) = sum(count_ones, x.content) Base.bitstring(x::LongLongUInt) = join(bitstring.(x.content), "") diff --git a/test/longlonguint.jl b/test/longlonguint.jl index 899a736..2cc9469 100644 --- a/test/longlonguint.jl +++ b/test/longlonguint.jl @@ -186,6 +186,60 @@ end @test BigInt(a * b) == BigInt(a) * BigInt(b) end +@testset "division" begin + # Test basic division + x = LongLongUInt((6,)) + y = LongLongUInt((3,)) + @test div(x, y) == LongLongUInt((2,)) + + # Test division with zero + x = LongLongUInt((5,)) + y = LongLongUInt((0,)) + @test_throws DivideError div(x, y) + + # Test division where result is zero + x = LongLongUInt((3,)) + y = LongLongUInt((5,)) + @test div(x, y) == LongLongUInt((0,)) + + # Test division where result is one + x = LongLongUInt((7,)) + y = LongLongUInt((7,)) + @test div(x, y) == LongLongUInt((1,)) + + # Test division with multi-UInt values + x = LongLongUInt((1, 0)) # 1 << 64 + y = LongLongUInt((0, 2)) # 2 + @test div(x, y) == LongLongUInt((UInt(0), UInt(1) << 63)) # (1 << 64) ÷ 2 = 1 << 63 + + # Test with values in both positions + x = LongLongUInt((3, 0)) # 3 << 64 + y = LongLongUInt((0, 3)) # 3 + @test div(x, y) == LongLongUInt((1, 0)) # (3 << 64) ÷ 3 = 1 << 64 + + # Verify with BigInt conversion + x = LongLongUInt((10, 7)) + y = LongLongUInt((0, 3)) + result = div(x, y) + expected = div(BigInt(x), BigInt(y)) + @test BigInt(result) == expected + + # Test with large values + x = LongLongUInt((UInt(1), UInt(0))) + y = LongLongUInt((UInt(0), UInt(2))) + @test BigInt(div(x, y)) == div(BigInt(x), BigInt(y)) + + # Test with random values + a = LongLongUInt((rand(UInt64), rand(UInt64))) + b = LongLongUInt((UInt(0), rand(UInt64) | UInt(1))) # Ensure non-zero + @test BigInt(div(a, b)) == div(BigInt(a), BigInt(b)) + + # Hard test + a = LongLongUInt((ntuple(i -> UInt64(0), Val{10}())..., ntuple(i -> rand(UInt64), Val{5}())...)) + b = LongLongUInt((ntuple(i -> UInt64(0), Val{6}())..., ntuple(i -> rand(UInt64), Val{9}())...)) + @test BigInt(div(b, a)) == div(BigInt(b), BigInt(a)) +end + @testset "LongLongUInt hash" begin b1 = bmask(LongLongUInt{1}, 1) b2 = bmask(LongLongUInt{1}, 2) From fec9d77311f6a4ec794a4ddcaba7e2e83ad57030 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 13 May 2025 01:38:32 +0800 Subject: [PATCH 4/4] fix test --- test/longlonguint.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/longlonguint.jl b/test/longlonguint.jl index 2cc9469..4f74abe 100644 --- a/test/longlonguint.jl +++ b/test/longlonguint.jl @@ -32,7 +32,6 @@ using Test, BitBasis # add with overflow z = LongLongUInt((UInt(17), typemax(UInt)-1)) @test z + x == LongLongUInt((21, 4)) - @test (@allocated z + x) == 0 # maximum number of elements BitBasis.max_num_elements(LongLongUInt{2}, 2) == 80