diff --git a/src/PeriodicArrays.jl b/src/PeriodicArrays.jl index e5240dd..3f3b79b 100644 --- a/src/PeriodicArrays.jl +++ b/src/PeriodicArrays.jl @@ -69,7 +69,7 @@ PeriodicMatrix(args...) = PeriodicArray(args...) Base.IndexStyle(::Type{PeriodicArray{T, N, A, F}}) where {T, N, A, F} = IndexCartesian() Base.IndexStyle(::Type{<:PeriodicVector}) = IndexLinear() -function cell_position(arr::AbstractArray{T, N}, I::Vararg{Int, N}) where {T, N} +function cell_position(arr::AbstractArray{T, N}, I::Vararg{Integer, N}) where {T, N} axs = axes(arr) i_base = ntuple(N) do d ax = axs[d] @@ -81,7 +81,7 @@ function cell_position(arr::AbstractArray{T, N}, I::Vararg{Int, N}) where {T, N} i_shift = ntuple(d -> fld(I[d] - i_base[d], length(axs[d])), N) return i_base, i_shift end -function inverse_cell_position(arr::AbstractArray{T, N}, I::Vararg{Int, N}) where {T, N} +function inverse_cell_position(arr::AbstractArray{T, N}, I::Vararg{Integer, N}) where {T, N} axs = axes(arr) i_base = ntuple(N) do d ax = axs[d] @@ -269,7 +269,7 @@ function Base.repeat(A::PeriodicArray{T, N}; inner = nothing, outer = nothing) w end end - @inline function map_new(x::T, shift::Vararg{Int, N}) + @inline function map_new(x, shift::Vararg{Integer, N}) # shifts passed to this map refer to super-cell shifts; amplify # by `outer` to convert them to original unit-cell shifts. amplified = ntuple(i -> shift[i] * outer[i], N) @@ -282,6 +282,51 @@ function Base.repeat(A::PeriodicArray{T, N}; inner = nothing, outer = nothing) w return PeriodicArray(A_new, map) end +_circshift_amounts(::Val{N}, s::Integer) where {N} = ntuple(d -> d == 1 ? Int(s) : 0, N) +_circshift_amounts(::Val{N}, s) where {N} = ntuple(d -> d <= length(s) ? Int(s[d]) : 0, N) + +function _circshift_pa!( + dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, shifts + ) where {T, N} + s = _circshift_amounts(Val(N), shifts) + src_data = parent(src) + dest_data = parent(dest) + for k in CartesianIndices(dest_data) + i = ntuple(d -> k[d] - s[d], N) + i_base, i_shift = cell_position(src_data, i...) + v = src_data[i_base...] + dest_data[k] = src.map(v, i_shift...) + end + return dest +end + +# circshift: multiple signatures to disambiguate from Base methods +Base.circshift(arr::PeriodicArray{T, N}, shifts::NTuple{M, Integer}) where {T, N, M} = + _circshift_pa!(similar(arr), arr, shifts) +Base.circshift(arr::PeriodicArray{T, N}, shift::Real) where {T, N} = + _circshift_pa!(similar(arr), arr, shift) +Base.circshift(arr::PeriodicArray{T, N}, shifts::AbstractVector{<:Integer}) where {T, N} = + _circshift_pa!(similar(arr), arr, shifts) + +# circshift! 2-arg (in-place) +function Base.circshift!(arr::PeriodicArray{T, N}, shifts) where {T, N} + src = PeriodicArray(copy(parent(arr)), arr.map) + return _circshift_pa!(arr, src, shifts) +end +# disambiguate with Base.circshift!(::AbstractVector, ::Integer) +Base.circshift!(arr::PeriodicVector, shift::Integer) = circshift!(arr, (shift,)) + +# circshift! 3-arg: specific shift types to disambiguate from Base methods +Base.circshift!( + dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, shifts::NTuple{M, Integer} +) where {T, N, M} = _circshift_pa!(dest, src, shifts) +Base.circshift!( + dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, ::Tuple{} +) where {T, N} = _circshift_pa!(dest, src, ()) +Base.circshift!( + dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, shifts::AbstractVector{<:Integer} +) where {T, N} = _circshift_pa!(dest, src, shifts) + function Base.reverse(arr::PeriodicArray{T, N, A, F}; dims = :) where {T, N, A, F} dims == Colon() && return _reverse(arr) return _reverse(arr, dims) @@ -290,7 +335,7 @@ end function _reverse(arr::PeriodicArray{T, N, A, F}) where {T, N, A, F} base = reverse(parent(arr)) - @inline function map_rev(x::T, shifts::Vararg{Int, N}) + @inline function map_rev(x, shifts::Vararg{Integer, N}) neg = ntuple(i -> -shifts[i], N) return arr.map(x, neg...) end @@ -302,7 +347,7 @@ function _reverse(arr::PeriodicArray{T, N, A, F}, dims...) where {T, N, A, F} base = reverse(parent(arr); dims = dims) dimsset = Set(dims) - @inline function map_rev(x::T, shifts::Vararg{Int, N}) + @inline function map_rev(x, shifts::Vararg{Integer, N}) adj = ntuple(i -> (i in dimsset) ? -shifts[i] : shifts[i], N) return arr.map(x, adj...) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 03e9d1d..83e311f 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -316,6 +316,45 @@ end end end +@testset "circshift" begin + @testset "1D" begin + data = [1, 2, 3, 4, 5] + a = PeriodicVector(data) + @test circshift(a, 1) == PeriodicVector([5, 1, 2, 3, 4]) + @test circshift(a, -1) == PeriodicVector([2, 3, 4, 5, 1]) + @test circshift(a, 0) == a + @test circshift(a, length(a)) == a + for s in (-3, -1, 0, 2, 5) + cs = circshift(a, s) + @test all(cs[i] == a[i - s] for i in -20:20) + end + end + + @testset "2D" begin + data = [1 2 3; 4 5 6] + a = PeriodicMatrix(data) + for s in ((0, 0), (1, 0), (0, 1), (1, 2), (-1, -1)) + cs = circshift(a, s) + @test all(cs[i, j] == a[i - s[1], j - s[2]] for i in -5:5, j in -5:5) + end + end + + @testset "circshift! 3-arg" begin + a = PeriodicVector([1, 2, 3, 4, 5]) + dest = similar(a) + circshift!(dest, a, 2) + @test dest == circshift(a, 2) + @test parent(a) == [1, 2, 3, 4, 5] + end + + @testset "circshift! in-place" begin + a = PeriodicVector([1, 2, 3, 4, 5]) + expected = circshift(a, 2) + circshift!(a, 2) + @test a == expected + end +end + @testset "offset indices" begin i = OffsetArray(1:5, -3) a = PeriodicArray(i) diff --git a/test/test_nontrivial_boundary.jl b/test/test_nontrivial_boundary.jl index 1c1c6c4..52c82f8 100644 --- a/test/test_nontrivial_boundary.jl +++ b/test/test_nontrivial_boundary.jl @@ -322,7 +322,7 @@ for f in translation_functions circ_a = circshift(a, 3) @test axes(circ_a) == axes(a) - @test circ_a[1:5] == [1, 2, f(3, 1), f(4, 1), f(5, 1)] + @test circ_a[1:5] == [1, 2, 3, 4, 5] j = OffsetArray([true, false, true], 1) @test a[j] == [5, f(2, 1)] @@ -410,4 +410,50 @@ for f in translation_functions @test parent(rb) == reverse(parent(b)) @test all(rb[i, j] == b[4 - i, 3 - j] for i in -10:10, j in -10:10) end + + @testset "circshift" begin + @testset "1D" begin + data = [1, 2, 3, 4, 5] + a = PeriodicVector(data, f) + # result[k] == a[k - s] for all k, verified across many shifts and indices + for s in (-3, -1, 0, 1, 2, 5) + cs = circshift(a, s) + @test all(cs[i] == a[i - s] for i in -20:20) + end + # the "seam": shift by 1 pulls a[0] (which applies f with cell shift -1) into position 1 + cs1 = circshift(a, 1) + @test cs1[1] == a[0] + @test cs1[1] == f(data[end], -1) + end + + @testset "2D" begin + b = PeriodicArray(reshape(1:6, 3, 2), f) + for s in ((1, 0), (0, 1), (1, 1), (-1, 2)) + cs = circshift(b, s) + @test all(cs[i, j] == b[i - s[1], j - s[2]] for i in -10:10, j in -10:10) + end + end + + @testset "roundtrip" begin + a = PeriodicVector([1, 2, 3, 4, 5], f) + for s in (-3, 0, 1, 2) + @test parent(circshift(circshift(a, s), -s)) == parent(a) + end + end + + @testset "circshift! 3-arg" begin + a = PeriodicVector([1, 2, 3, 4, 5], f) + dest = similar(a) + circshift!(dest, a, 2) + @test all(dest[i] == a[i - 2] for i in 1:5) + @test parent(a) == [1, 2, 3, 4, 5] + end + + @testset "circshift! in-place" begin + a = PeriodicVector([1, 2, 3, 4, 5], f) + expected = circshift(a, 2) + circshift!(a, 2) + @test parent(a) == parent(expected) + end + end end