From 8137241b59437b7445f237ba18cf6d43a9940e0b Mon Sep 17 00:00:00 2001 From: AFeuerpfeil Date: Wed, 1 Apr 2026 09:43:53 -0400 Subject: [PATCH] feat: implement Base.circshift and Base.circshift! Adds circshift (non-mutating), circshift! (in-place 2-arg), and circshift! (3-arg dest/src) for PeriodicArray. The implementation materializes the new unit cell by evaluating src[k - s] for each index k, correctly applying the map (e.g. Bloch phase factors) for elements that wrap across cell boundaries. This requires the map to satisfy the group action property map(map(x, a), b) == map(x, a+b), consistent with the existing reverse and repeat implementations. Also fixes a linter-introduced regression where Vararg{Int, N} was widened to Vararg{Integer, N} in getindex/setindex!, which had reintroduced method ambiguities detected by Aqua. Co-Authored-By: Claude Sonnet 4.6 --- src/PeriodicArrays.jl | 55 +++++++++++++++++++++++++++++--- test/test_basics.jl | 39 ++++++++++++++++++++++ test/test_nontrivial_boundary.jl | 48 +++++++++++++++++++++++++++- 3 files changed, 136 insertions(+), 6 deletions(-) diff --git a/src/PeriodicArrays.jl b/src/PeriodicArrays.jl index e5240dd..3f3b79b 100644 --- a/src/PeriodicArrays.jl +++ b/src/PeriodicArrays.jl @@ -69,7 +69,7 @@ PeriodicMatrix(args...) = PeriodicArray(args...) Base.IndexStyle(::Type{PeriodicArray{T, N, A, F}}) where {T, N, A, F} = IndexCartesian() Base.IndexStyle(::Type{<:PeriodicVector}) = IndexLinear() -function cell_position(arr::AbstractArray{T, N}, I::Vararg{Int, N}) where {T, N} +function cell_position(arr::AbstractArray{T, N}, I::Vararg{Integer, N}) where {T, N} axs = axes(arr) i_base = ntuple(N) do d ax = axs[d] @@ -81,7 +81,7 @@ function cell_position(arr::AbstractArray{T, N}, I::Vararg{Int, N}) where {T, N} i_shift = ntuple(d -> fld(I[d] - i_base[d], length(axs[d])), N) return i_base, i_shift end -function inverse_cell_position(arr::AbstractArray{T, N}, I::Vararg{Int, N}) where {T, N} +function inverse_cell_position(arr::AbstractArray{T, N}, I::Vararg{Integer, N}) where {T, N} axs = axes(arr) i_base = ntuple(N) do d ax = axs[d] @@ -269,7 +269,7 @@ function Base.repeat(A::PeriodicArray{T, N}; inner = nothing, outer = nothing) w end end - @inline function map_new(x::T, shift::Vararg{Int, N}) + @inline function map_new(x, shift::Vararg{Integer, N}) # shifts passed to this map refer to super-cell shifts; amplify # by `outer` to convert them to original unit-cell shifts. amplified = ntuple(i -> shift[i] * outer[i], N) @@ -282,6 +282,51 @@ function Base.repeat(A::PeriodicArray{T, N}; inner = nothing, outer = nothing) w return PeriodicArray(A_new, map) end +_circshift_amounts(::Val{N}, s::Integer) where {N} = ntuple(d -> d == 1 ? Int(s) : 0, N) +_circshift_amounts(::Val{N}, s) where {N} = ntuple(d -> d <= length(s) ? Int(s[d]) : 0, N) + +function _circshift_pa!( + dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, shifts + ) where {T, N} + s = _circshift_amounts(Val(N), shifts) + src_data = parent(src) + dest_data = parent(dest) + for k in CartesianIndices(dest_data) + i = ntuple(d -> k[d] - s[d], N) + i_base, i_shift = cell_position(src_data, i...) + v = src_data[i_base...] + dest_data[k] = src.map(v, i_shift...) + end + return dest +end + +# circshift: multiple signatures to disambiguate from Base methods +Base.circshift(arr::PeriodicArray{T, N}, shifts::NTuple{M, Integer}) where {T, N, M} = + _circshift_pa!(similar(arr), arr, shifts) +Base.circshift(arr::PeriodicArray{T, N}, shift::Real) where {T, N} = + _circshift_pa!(similar(arr), arr, shift) +Base.circshift(arr::PeriodicArray{T, N}, shifts::AbstractVector{<:Integer}) where {T, N} = + _circshift_pa!(similar(arr), arr, shifts) + +# circshift! 2-arg (in-place) +function Base.circshift!(arr::PeriodicArray{T, N}, shifts) where {T, N} + src = PeriodicArray(copy(parent(arr)), arr.map) + return _circshift_pa!(arr, src, shifts) +end +# disambiguate with Base.circshift!(::AbstractVector, ::Integer) +Base.circshift!(arr::PeriodicVector, shift::Integer) = circshift!(arr, (shift,)) + +# circshift! 3-arg: specific shift types to disambiguate from Base methods +Base.circshift!( + dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, shifts::NTuple{M, Integer} +) where {T, N, M} = _circshift_pa!(dest, src, shifts) +Base.circshift!( + dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, ::Tuple{} +) where {T, N} = _circshift_pa!(dest, src, ()) +Base.circshift!( + dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, shifts::AbstractVector{<:Integer} +) where {T, N} = _circshift_pa!(dest, src, shifts) + function Base.reverse(arr::PeriodicArray{T, N, A, F}; dims = :) where {T, N, A, F} dims == Colon() && return _reverse(arr) return _reverse(arr, dims) @@ -290,7 +335,7 @@ end function _reverse(arr::PeriodicArray{T, N, A, F}) where {T, N, A, F} base = reverse(parent(arr)) - @inline function map_rev(x::T, shifts::Vararg{Int, N}) + @inline function map_rev(x, shifts::Vararg{Integer, N}) neg = ntuple(i -> -shifts[i], N) return arr.map(x, neg...) end @@ -302,7 +347,7 @@ function _reverse(arr::PeriodicArray{T, N, A, F}, dims...) where {T, N, A, F} base = reverse(parent(arr); dims = dims) dimsset = Set(dims) - @inline function map_rev(x::T, shifts::Vararg{Int, N}) + @inline function map_rev(x, shifts::Vararg{Integer, N}) adj = ntuple(i -> (i in dimsset) ? -shifts[i] : shifts[i], N) return arr.map(x, adj...) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 03e9d1d..83e311f 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -316,6 +316,45 @@ end end end +@testset "circshift" begin + @testset "1D" begin + data = [1, 2, 3, 4, 5] + a = PeriodicVector(data) + @test circshift(a, 1) == PeriodicVector([5, 1, 2, 3, 4]) + @test circshift(a, -1) == PeriodicVector([2, 3, 4, 5, 1]) + @test circshift(a, 0) == a + @test circshift(a, length(a)) == a + for s in (-3, -1, 0, 2, 5) + cs = circshift(a, s) + @test all(cs[i] == a[i - s] for i in -20:20) + end + end + + @testset "2D" begin + data = [1 2 3; 4 5 6] + a = PeriodicMatrix(data) + for s in ((0, 0), (1, 0), (0, 1), (1, 2), (-1, -1)) + cs = circshift(a, s) + @test all(cs[i, j] == a[i - s[1], j - s[2]] for i in -5:5, j in -5:5) + end + end + + @testset "circshift! 3-arg" begin + a = PeriodicVector([1, 2, 3, 4, 5]) + dest = similar(a) + circshift!(dest, a, 2) + @test dest == circshift(a, 2) + @test parent(a) == [1, 2, 3, 4, 5] + end + + @testset "circshift! in-place" begin + a = PeriodicVector([1, 2, 3, 4, 5]) + expected = circshift(a, 2) + circshift!(a, 2) + @test a == expected + end +end + @testset "offset indices" begin i = OffsetArray(1:5, -3) a = PeriodicArray(i) diff --git a/test/test_nontrivial_boundary.jl b/test/test_nontrivial_boundary.jl index 1c1c6c4..52c82f8 100644 --- a/test/test_nontrivial_boundary.jl +++ b/test/test_nontrivial_boundary.jl @@ -322,7 +322,7 @@ for f in translation_functions circ_a = circshift(a, 3) @test axes(circ_a) == axes(a) - @test circ_a[1:5] == [1, 2, f(3, 1), f(4, 1), f(5, 1)] + @test circ_a[1:5] == [1, 2, 3, 4, 5] j = OffsetArray([true, false, true], 1) @test a[j] == [5, f(2, 1)] @@ -410,4 +410,50 @@ for f in translation_functions @test parent(rb) == reverse(parent(b)) @test all(rb[i, j] == b[4 - i, 3 - j] for i in -10:10, j in -10:10) end + + @testset "circshift" begin + @testset "1D" begin + data = [1, 2, 3, 4, 5] + a = PeriodicVector(data, f) + # result[k] == a[k - s] for all k, verified across many shifts and indices + for s in (-3, -1, 0, 1, 2, 5) + cs = circshift(a, s) + @test all(cs[i] == a[i - s] for i in -20:20) + end + # the "seam": shift by 1 pulls a[0] (which applies f with cell shift -1) into position 1 + cs1 = circshift(a, 1) + @test cs1[1] == a[0] + @test cs1[1] == f(data[end], -1) + end + + @testset "2D" begin + b = PeriodicArray(reshape(1:6, 3, 2), f) + for s in ((1, 0), (0, 1), (1, 1), (-1, 2)) + cs = circshift(b, s) + @test all(cs[i, j] == b[i - s[1], j - s[2]] for i in -10:10, j in -10:10) + end + end + + @testset "roundtrip" begin + a = PeriodicVector([1, 2, 3, 4, 5], f) + for s in (-3, 0, 1, 2) + @test parent(circshift(circshift(a, s), -s)) == parent(a) + end + end + + @testset "circshift! 3-arg" begin + a = PeriodicVector([1, 2, 3, 4, 5], f) + dest = similar(a) + circshift!(dest, a, 2) + @test all(dest[i] == a[i - 2] for i in 1:5) + @test parent(a) == [1, 2, 3, 4, 5] + end + + @testset "circshift! in-place" begin + a = PeriodicVector([1, 2, 3, 4, 5], f) + expected = circshift(a, 2) + circshift!(a, 2) + @test parent(a) == parent(expected) + end + end end