Skip to content

Commit 1a78ae9

Browse files
committed
Improvements to iterators including taking length and iteration for one sweep case. (Includes changes from PR #262.)
1 parent 8cb36ef commit 1a78ae9

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

src/solvers/iterators.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ abstract type AbstractNetworkIterator end
1212
islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)
1313

1414
function Base.iterate(iterator::AbstractNetworkIterator, init = true)
15-
islaststep(iterator) && return nothing
15+
# The assumption is that first "increment!" is implicit, therefore we must skip the
16+
# the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not
17+
# defined when length < 1,
18+
init || islaststep(iterator) && return nothing
1619
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
1720
# define a method for increment! This way we avoid cases where one may wish to nest
1821
# calls to different step! methods accidentaly incrementing multiple times.
@@ -44,6 +47,9 @@ mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
4447
which_region::Int
4548
const which_sweep::Int
4649
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
50+
if length(region_plan) == 0
51+
throw(BoundsError("Cannot construct a region iterator with 0 elements."))
52+
end
4753
return new{P, R}(problem, region_plan, 1, sweep)
4854
end
4955
end
@@ -115,26 +121,33 @@ region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs
115121

116122
mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator
117123
region_iter::RegionIterator{Problem}
118-
sweep_kwargs::Iterators.Stateful{Iter}
124+
sweep_kwargs::Iter
119125
which_sweep::Int
126+
nsweeps::Int
120127
function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter}
121-
stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)
122-
first_kwargs, _ = Iterators.peel(stateful_sweep_kwargs)
128+
first_state = Iterators.peel(sweep_kwargs)
129+
if isnothing(first_state)
130+
throw(BoundsError("Cannot construct a sweep iterator with 0 elements."))
131+
end
132+
first_kwargs, sweep_kwargs_rest = first_state
123133
region_iter = RegionIterator(problem; sweep = 1, first_kwargs...)
124-
return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1)
134+
return new{Prob, typeof(sweep_kwargs_rest)}(region_iter, sweep_kwargs_rest, 1, length(sweep_kwargs))
125135
end
126136
end
127137

128-
islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs))
138+
islaststep(sweep_iter::SweepIterator) = isempty(sweep_iter.sweep_kwargs)
129139

130140
region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter
141+
131142
problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter))
132143

133144
state(sweep_iter::SweepIterator) = sweep_iter.which_sweep
134-
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs)
145+
146+
Base.length(sweep_iter::SweepIterator) = sweep_iter.nsweeps
147+
135148
function increment!(sweep_iter::SweepIterator)
136149
sweep_iter.which_sweep += 1
137-
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs)
150+
sweep_kwargs, sweep_iter.sweep_kwargs = Iterators.peel(sweep_iter.sweep_kwargs)
138151
update_region_iterator!(sweep_iter; sweep_kwargs...)
139152
return sweep_iter
140153
end

0 commit comments

Comments
 (0)