Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions src/PeriodicArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ PeriodicMatrix(args...) = PeriodicArray(args...)
Base.IndexStyle(::Type{PeriodicArray{T, N, A, F}}) where {T, N, A, F} = IndexCartesian()
Base.IndexStyle(::Type{<:PeriodicVector}) = IndexLinear()

function cell_position(arr::AbstractArray{T, N}, I::Vararg{Int, N}) where {T, N}
function cell_position(arr::AbstractArray{T, N}, I::Vararg{Integer, N}) where {T, N}
axs = axes(arr)
i_base = ntuple(N) do d
ax = axs[d]
Expand All @@ -81,7 +81,7 @@ function cell_position(arr::AbstractArray{T, N}, I::Vararg{Int, N}) where {T, N}
i_shift = ntuple(d -> fld(I[d] - i_base[d], length(axs[d])), N)
return i_base, i_shift
end
function inverse_cell_position(arr::AbstractArray{T, N}, I::Vararg{Int, N}) where {T, N}
function inverse_cell_position(arr::AbstractArray{T, N}, I::Vararg{Integer, N}) where {T, N}
axs = axes(arr)
i_base = ntuple(N) do d
ax = axs[d]
Expand Down Expand Up @@ -269,7 +269,7 @@ function Base.repeat(A::PeriodicArray{T, N}; inner = nothing, outer = nothing) w
end
end

@inline function map_new(x::T, shift::Vararg{Int, N})
@inline function map_new(x, shift::Vararg{Integer, N})
# shifts passed to this map refer to super-cell shifts; amplify
# by `outer` to convert them to original unit-cell shifts.
amplified = ntuple(i -> shift[i] * outer[i], N)
Expand All @@ -282,6 +282,51 @@ function Base.repeat(A::PeriodicArray{T, N}; inner = nothing, outer = nothing) w
return PeriodicArray(A_new, map)
end

_circshift_amounts(::Val{N}, s::Integer) where {N} = ntuple(d -> d == 1 ? Int(s) : 0, N)
_circshift_amounts(::Val{N}, s) where {N} = ntuple(d -> d <= length(s) ? Int(s[d]) : 0, N)

function _circshift_pa!(
dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, shifts
) where {T, N}
s = _circshift_amounts(Val(N), shifts)
src_data = parent(src)
dest_data = parent(dest)
for k in CartesianIndices(dest_data)
i = ntuple(d -> k[d] - s[d], N)
i_base, i_shift = cell_position(src_data, i...)
v = src_data[i_base...]
dest_data[k] = src.map(v, i_shift...)
end
return dest
end

# circshift: multiple signatures to disambiguate from Base methods
Base.circshift(arr::PeriodicArray{T, N}, shifts::NTuple{M, Integer}) where {T, N, M} =
_circshift_pa!(similar(arr), arr, shifts)
Base.circshift(arr::PeriodicArray{T, N}, shift::Real) where {T, N} =
_circshift_pa!(similar(arr), arr, shift)
Base.circshift(arr::PeriodicArray{T, N}, shifts::AbstractVector{<:Integer}) where {T, N} =
_circshift_pa!(similar(arr), arr, shifts)

# circshift! 2-arg (in-place)
function Base.circshift!(arr::PeriodicArray{T, N}, shifts) where {T, N}
src = PeriodicArray(copy(parent(arr)), arr.map)
return _circshift_pa!(arr, src, shifts)
end
# disambiguate with Base.circshift!(::AbstractVector, ::Integer)
Base.circshift!(arr::PeriodicVector, shift::Integer) = circshift!(arr, (shift,))

# circshift! 3-arg: specific shift types to disambiguate from Base methods
Base.circshift!(
dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, shifts::NTuple{M, Integer}
) where {T, N, M} = _circshift_pa!(dest, src, shifts)
Base.circshift!(
dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, ::Tuple{}
) where {T, N} = _circshift_pa!(dest, src, ())
Base.circshift!(
dest::PeriodicArray{T, N}, src::PeriodicArray{T, N}, shifts::AbstractVector{<:Integer}
) where {T, N} = _circshift_pa!(dest, src, shifts)

function Base.reverse(arr::PeriodicArray{T, N, A, F}; dims = :) where {T, N, A, F}
dims == Colon() && return _reverse(arr)
return _reverse(arr, dims)
Expand All @@ -290,7 +335,7 @@ end
function _reverse(arr::PeriodicArray{T, N, A, F}) where {T, N, A, F}
base = reverse(parent(arr))

@inline function map_rev(x::T, shifts::Vararg{Int, N})
@inline function map_rev(x, shifts::Vararg{Integer, N})
neg = ntuple(i -> -shifts[i], N)
return arr.map(x, neg...)
end
Expand All @@ -302,7 +347,7 @@ function _reverse(arr::PeriodicArray{T, N, A, F}, dims...) where {T, N, A, F}
base = reverse(parent(arr); dims = dims)
dimsset = Set(dims)

@inline function map_rev(x::T, shifts::Vararg{Int, N})
@inline function map_rev(x, shifts::Vararg{Integer, N})
adj = ntuple(i -> (i in dimsset) ? -shifts[i] : shifts[i], N)
return arr.map(x, adj...)
end
Expand Down
39 changes: 39 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,45 @@ end
end
end

@testset "circshift" begin
@testset "1D" begin
data = [1, 2, 3, 4, 5]
a = PeriodicVector(data)
@test circshift(a, 1) == PeriodicVector([5, 1, 2, 3, 4])
@test circshift(a, -1) == PeriodicVector([2, 3, 4, 5, 1])
@test circshift(a, 0) == a
@test circshift(a, length(a)) == a
for s in (-3, -1, 0, 2, 5)
cs = circshift(a, s)
@test all(cs[i] == a[i - s] for i in -20:20)
end
end

@testset "2D" begin
data = [1 2 3; 4 5 6]
a = PeriodicMatrix(data)
for s in ((0, 0), (1, 0), (0, 1), (1, 2), (-1, -1))
cs = circshift(a, s)
@test all(cs[i, j] == a[i - s[1], j - s[2]] for i in -5:5, j in -5:5)
end
end

@testset "circshift! 3-arg" begin
a = PeriodicVector([1, 2, 3, 4, 5])
dest = similar(a)
circshift!(dest, a, 2)
@test dest == circshift(a, 2)
@test parent(a) == [1, 2, 3, 4, 5]
end

@testset "circshift! in-place" begin
a = PeriodicVector([1, 2, 3, 4, 5])
expected = circshift(a, 2)
circshift!(a, 2)
@test a == expected
end
end

@testset "offset indices" begin
i = OffsetArray(1:5, -3)
a = PeriodicArray(i)
Expand Down
48 changes: 47 additions & 1 deletion test/test_nontrivial_boundary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ for f in translation_functions

circ_a = circshift(a, 3)
@test axes(circ_a) == axes(a)
@test circ_a[1:5] == [1, 2, f(3, 1), f(4, 1), f(5, 1)]
@test circ_a[1:5] == [1, 2, 3, 4, 5]

j = OffsetArray([true, false, true], 1)
@test a[j] == [5, f(2, 1)]
Expand Down Expand Up @@ -410,4 +410,50 @@ for f in translation_functions
@test parent(rb) == reverse(parent(b))
@test all(rb[i, j] == b[4 - i, 3 - j] for i in -10:10, j in -10:10)
end

@testset "circshift" begin
@testset "1D" begin
data = [1, 2, 3, 4, 5]
a = PeriodicVector(data, f)
# result[k] == a[k - s] for all k, verified across many shifts and indices
for s in (-3, -1, 0, 1, 2, 5)
cs = circshift(a, s)
@test all(cs[i] == a[i - s] for i in -20:20)
end
# the "seam": shift by 1 pulls a[0] (which applies f with cell shift -1) into position 1
cs1 = circshift(a, 1)
@test cs1[1] == a[0]
@test cs1[1] == f(data[end], -1)
end

@testset "2D" begin
b = PeriodicArray(reshape(1:6, 3, 2), f)
for s in ((1, 0), (0, 1), (1, 1), (-1, 2))
cs = circshift(b, s)
@test all(cs[i, j] == b[i - s[1], j - s[2]] for i in -10:10, j in -10:10)
end
end

@testset "roundtrip" begin
a = PeriodicVector([1, 2, 3, 4, 5], f)
for s in (-3, 0, 1, 2)
@test parent(circshift(circshift(a, s), -s)) == parent(a)
end
end

@testset "circshift! 3-arg" begin
a = PeriodicVector([1, 2, 3, 4, 5], f)
dest = similar(a)
circshift!(dest, a, 2)
@test all(dest[i] == a[i - 2] for i in 1:5)
@test parent(a) == [1, 2, 3, 4, 5]
end

@testset "circshift! in-place" begin
a = PeriodicVector([1, 2, 3, 4, 5], f)
expected = circshift(a, 2)
circshift!(a, 2)
@test parent(a) == parent(expected)
end
end
end
Loading