Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 86 additions & 2 deletions src/longlonguint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

A `LongLongUInt{C}` is an integer with `C` `UInt` numbers to store the value.
"""
struct LongLongUInt{C} <: Integer
struct LongLongUInt{C} <: Unsigned
content::NTuple{C, UInt}
function LongLongUInt{C}(content::NTuple{C, UInt}) where {C}
new{C}(content)
Expand All @@ -15,6 +15,9 @@ struct LongLongUInt{C} <: Integer
LongLongUInt{C}(content)
end
end
Base.string(x::LongLongUInt{C}) where {C} = "LongLongUInt{$C}(" * join(bitstring.(x.content), "") * ")"
Base.show(io::IO, x::LongLongUInt{C}) where {C} = print(io, string(x))
Base.show(io::IO, ::MIME"text/plain", x::LongLongUInt{C}) where {C} = print(io, string(x))
bsizeof(::LongLongUInt{C}) where C = bsizeof(UInt64) * C
nint(::LongLongUInt{C}) where {C} = C
Base.Int(x::LongLongUInt{1}) = Int(x.content[1])
Expand Down Expand Up @@ -102,7 +105,7 @@ end
function _sadd(x::NTuple{C,UInt}, y::NTuple{C,UInt}, c::Bool) where {C}
v1, c1 = Base.add_with_overflow(x[C], y[C])
if c
v2, c2 = Base.add_with_overflow(v1, c)
v2, c2 = Base.add_with_overflow(v1, UInt(c))
c = c1 || c2
return (_sadd(x[1:C-1], y[1:C-1], c)..., v2)
else
Expand All @@ -125,6 +128,70 @@ function _ssub(x::NTuple{C,UInt}, y::NTuple{C,UInt}, c::Bool) where {C}
return (_ssub(x[1:C-1], y[1:C-1], c1)..., v1)
end
end

function Base.:(*)(x::LongLongUInt{C}, y::LongLongUInt{C}) where {C}
result = zero(LongLongUInt{C})
for i in 1:C
x.content[C-i+1] == 0 && continue
for j in 1:C
y.content[C-j+1] == 0 && continue
# Skip if either position is out of bounds for the result
pos = i + j - 1
pos > C && continue

# Multiply the corresponding elements
mres = Base.widemul(x.content[C-i+1], y.content[C-j+1])

# Add the low part to the result at position pos
partial = LongLongUInt(ntuple(k -> (k == C-pos+1 ? UInt(mres & typemax(UInt)) : (k == C-pos ? UInt(mres >> bsizeof(UInt)) : zero(UInt))), Val{C}()))
result = result + partial
end
end
return result
end

function Base.div(x::LongLongUInt{C}, y::LongLongUInt{C}) where {C}
y == zero(LongLongUInt{C}) && throw(DivideError())
x < y && return zero(LongLongUInt{C})
x == y && return one(LongLongUInt{C})

# Initialize quotient and remainder
quotient = zero(LongLongUInt{C})
remainder = x

# Find the highest bit position in y
y_highest_bit = 0
for i in 1:C
if y.content[i] != 0
y_highest_bit = (C - i) * 64 + (64 - leading_zeros(y.content[i]))
break
end
end

# Find the highest bit position in x
x_highest_bit = 0
for i in 1:C
if x.content[i] != 0
x_highest_bit = (C - i) * 64 + (64 - leading_zeros(x.content[i]))
break
end
end

# Long division algorithm
for i in (x_highest_bit - y_highest_bit + 1):-1:1
# Shift y left by i-1 bits
shifted_y = y << (i - 1)

# If remainder >= shifted_y, subtract and set bit in quotient
if remainder >= shifted_y
remainder = remainder - shifted_y
quotient = quotient | (one(LongLongUInt{C}) << (i - 1))
end
end

return quotient
end

Base.count_ones(x::LongLongUInt) = sum(count_ones, x.content)
Base.bitstring(x::LongLongUInt) = join(bitstring.(x.content), "")

Expand All @@ -140,3 +207,20 @@ Base.hash(x::LongLongUInt{C}) where{C} = hash(x.content)
# these APIs will are used in SparseTN
BitBasis.log2i(x::LongLongUInt{C}) where C = floor(Int, log2(Float64(BigInt(x))))
Base.BigInt(x::LongLongUInt{C}) where C = mapfoldl(x -> BigInt(x), (x, y) -> ((x << 64) | y), x.content)

function Base.Int(x::LongLongUInt)
if all(iszero, x.content[2:end]) && x.content[1] < typemax(Int)
return Int(x.content[1])
else
throw(InexactError(:Int, x))
end
end

function Base.hash(bits_tuple::Tuple{LongLongUInt{C}, Vararg{LongLongUInt{C}, M}}) where{M, C}
N = M + 1
hash0 = Base.hash(bits_tuple[1].content)
for i in 2:N
hash0 = Base.hash(bits_tuple[i].content, hash0)
end
return hash0
end
112 changes: 111 additions & 1 deletion test/longlonguint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using Test, BitBasis

@testset "longlonguint" begin
x = LongLongUInt((3,))
@test abs(x) == x
@test Int(x) === 3
@test LongLongUInt{1}(x) == x
@test Int(x) === 3
@test bsizeof(x) == 64
Expand All @@ -18,6 +20,7 @@ using Test, BitBasis
@test promote(UInt(1), LongLongUInt((3, 4))) == (LongLongUInt((0, 1)), LongLongUInt((3, 4)))

y = LongLongUInt((5, 7))
@test_throws InexactError Int(y)
@test one(y) == LongLongUInt((0, 1))
@test x & y == LongLongUInt((1, 6))
@test x | y == LongLongUInt((7, 7))
Expand All @@ -29,7 +32,6 @@ using Test, BitBasis
# add with overflow
z = LongLongUInt((UInt(17), typemax(UInt)-1))
@test z + x == LongLongUInt((21, 4))
@test (@allocated z + x) == 0

# maximum number of elements
BitBasis.max_num_elements(LongLongUInt{2}, 2) == 80
Expand All @@ -38,6 +40,11 @@ using Test, BitBasis
BitBasis.max_num_elements(Int, 2) == 63

@test count_ones(bmask(LongLongUInt{50}, 1:3000)) == 3000

# Hard add
a = LongLongUInt((ntuple(i -> UInt64(0), Val{1}())..., ntuple(i -> rand(UInt64), Val{9}())...))
b = LongLongUInt((ntuple(i -> UInt64(0), Val{1}())..., ntuple(i -> rand(UInt64), Val{9}())...))
@test BigInt(a + b) == BigInt(a) + BigInt(b)
end

@testset "shift" begin
Expand Down Expand Up @@ -136,3 +143,106 @@ end
@test log2i(LongLongUInt((0,2))) == 1
@test log2i(LongLongUInt((1,0))) == 64
end

@testset "multiplication" begin
# Test basic multiplication
x = LongLongUInt((2,))
y = LongLongUInt((3,))
@test x * y == LongLongUInt((6,))

# Test multiplication with overflow within a single UInt
x = LongLongUInt((0x8000000000000000,))
y = LongLongUInt((2,))
@test x * y == LongLongUInt((0,)) # Overflow truncated to original size

# Test multiplication across multiple UInts
x = LongLongUInt((1, 0)) # 1 << 64
y = LongLongUInt((0, 5)) # 2 << 64
@test x * y == LongLongUInt((5, 0)) # Result truncated to original size

# Test with values in both positions
x = LongLongUInt((0, 2))
y = LongLongUInt((3, 4))
# Expected: (1*3) << 128 + (1*4 + 2*3) << 64 + (2*4), truncated to 2 UInts
@test x * y == LongLongUInt((6, 8))

# Verify with BigInt conversion
x = LongLongUInt((10, 7))
y = LongLongUInt((0, 3))
result = x * y
expected = BigInt(x) * BigInt(y)
expected_truncated = expected & ((BigInt(1) << 128) - 1) # Truncate to 128 bits
@test BigInt(result) == expected_truncated

x = LongLongUInt((UInt(0), typemax(UInt64)))
y = LongLongUInt((UInt(0), UInt(3)))
@test BigInt(x * y) == BigInt(x) * BigInt(y)

# Hard instance
a = LongLongUInt((ntuple(i -> UInt64(0), Val{10}())..., ntuple(i -> rand(UInt64), Val{5}())...))
b = LongLongUInt((ntuple(i -> UInt64(0), Val{6}())..., ntuple(i -> rand(UInt64), Val{9}())...))
@show BigInt(a * b) - BigInt(a) * BigInt(b)
@test BigInt(a * b) == BigInt(a) * BigInt(b)
end

@testset "division" begin
# Test basic division
x = LongLongUInt((6,))
y = LongLongUInt((3,))
@test div(x, y) == LongLongUInt((2,))

# Test division with zero
x = LongLongUInt((5,))
y = LongLongUInt((0,))
@test_throws DivideError div(x, y)

# Test division where result is zero
x = LongLongUInt((3,))
y = LongLongUInt((5,))
@test div(x, y) == LongLongUInt((0,))

# Test division where result is one
x = LongLongUInt((7,))
y = LongLongUInt((7,))
@test div(x, y) == LongLongUInt((1,))

# Test division with multi-UInt values
x = LongLongUInt((1, 0)) # 1 << 64
y = LongLongUInt((0, 2)) # 2
@test div(x, y) == LongLongUInt((UInt(0), UInt(1) << 63)) # (1 << 64) ÷ 2 = 1 << 63

# Test with values in both positions
x = LongLongUInt((3, 0)) # 3 << 64
y = LongLongUInt((0, 3)) # 3
@test div(x, y) == LongLongUInt((1, 0)) # (3 << 64) ÷ 3 = 1 << 64

# Verify with BigInt conversion
x = LongLongUInt((10, 7))
y = LongLongUInt((0, 3))
result = div(x, y)
expected = div(BigInt(x), BigInt(y))
@test BigInt(result) == expected

# Test with large values
x = LongLongUInt((UInt(1), UInt(0)))
y = LongLongUInt((UInt(0), UInt(2)))
@test BigInt(div(x, y)) == div(BigInt(x), BigInt(y))

# Test with random values
a = LongLongUInt((rand(UInt64), rand(UInt64)))
b = LongLongUInt((UInt(0), rand(UInt64) | UInt(1))) # Ensure non-zero
@test BigInt(div(a, b)) == div(BigInt(a), BigInt(b))

# Hard test
a = LongLongUInt((ntuple(i -> UInt64(0), Val{10}())..., ntuple(i -> rand(UInt64), Val{5}())...))
b = LongLongUInt((ntuple(i -> UInt64(0), Val{6}())..., ntuple(i -> rand(UInt64), Val{9}())...))
@test BigInt(div(b, a)) == div(BigInt(b), BigInt(a))
end

@testset "LongLongUInt hash" begin
b1 = bmask(LongLongUInt{1}, 1)
b2 = bmask(LongLongUInt{1}, 2)
b3 = bmask(LongLongUInt{1}, 3)

@test hash((b1, b2, b3)) == hash((b1, b2, b3))
end
Loading