From f78c5e66610f45b6c368406f08aea1d71103da93 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 16 Dec 2025 17:19:50 -0500 Subject: [PATCH 1/3] [WIP] Sweeping algorithms based on AlgorithmsInterface.jl --- Project.toml | 4 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- .../AlgorithmsInterfaceExtensions.jl | 281 ++++++++++++++++++ src/ITensorNetworksNext.jl | 6 +- src/abstract_problem.jl | 1 - src/adapters.jl | 45 --- src/eigenproblem.jl | 93 ++++++ src/iterators.jl | 170 ----------- src/sweeping.jl | 54 ++++ test/Project.toml | 2 +- test/test_iterators.jl | 221 -------------- 12 files changed, 437 insertions(+), 444 deletions(-) create mode 100644 src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl delete mode 100644 src/abstract_problem.jl delete mode 100644 src/adapters.jl create mode 100644 src/eigenproblem.jl delete mode 100644 src/iterators.jl create mode 100644 src/sweeping.jl delete mode 100644 test/test_iterators.jl diff --git a/Project.toml b/Project.toml index 85efef2..1dc3858 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" -version = "0.2.3" +version = "0.3.0" authors = ["ITensor developers and contributors"] [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" @@ -32,6 +33,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations" [compat] AbstractTrees = "0.4.5" Adapt = "4.3" +AlgorithmsInterface = "0.1.0" BackendSelection = "0.1.6" Combinatorics = "1" DataGraphs = "0.2.7" diff --git a/docs/Project.toml b/docs/Project.toml index 266b345..7262ba4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" [compat] Documenter = "1" Literate = "2" -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" diff --git a/examples/Project.toml b/examples/Project.toml index 1e3b0ad..eb46f1a 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,4 +2,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" [compat] -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl new file mode 100644 index 0000000..df636b2 --- /dev/null +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -0,0 +1,281 @@ +module AlgorithmsInterfaceExtensions + +import AlgorithmsInterface as AI + +#========================== Patches for AlgorithmsInterface.jl ============================# + +abstract type Problem <: AI.Problem end +abstract type Algorithm <: AI.Algorithm end +abstract type State <: AI.State end + +function AI.initialize_state!( + problem::Problem, algorithm::Algorithm, state::State; iterate = nothing + ) + !isnothing(iterate) && (state.iterate = iterate) + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AI.initialize_state( + problem::Problem, algorithm::Algorithm; kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + return DefaultState(; stopping_criterion_state, kwargs...) +end + +#============================ DefaultState ================================================# + +@kwdef mutable struct DefaultState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +#============================ increment! ==================================================# + +# Custom version of `increment!` that also takes the problem and algorithm as arguments. +function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) + return AI.increment!(state) +end + +#============================ solve! ======================================================# + +# Custom version of `solve!` that allows specifying the logger and also overloads +# `increment!` on the problem and algorithm. +function basetypenameof(x) + return Symbol(last(split(String(Symbol(Base.typename(typeof(x)).wrapper)), "."))) +end +default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_) +function default_logging_context_prefix(problem::Problem, algorithm::Algorithm) + return Symbol( + default_logging_context_prefix(problem), + default_logging_context_prefix(algorithm), + ) +end +function AI.solve!( + problem::Problem, algorithm::Algorithm, state::State; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + kwargs..., + ) + logger = AI.algorithm_logger() + + context_suffixes = [:Start, :PreStep, :PostStep, :Stop] + contexts = Dict(context_suffixes .=> Symbol.(logging_context_prefix, context_suffixes)) + + # initialize the state and emit message + AI.initialize_state!(problem, algorithm, state; kwargs...) + AI.emit_message(logger, problem, algorithm, state, contexts[:Start]) + + # main body of the algorithm + while !AI.is_finished!(problem, algorithm, state) + AI.increment!(problem, algorithm, state) + + # logging event between convergence check and algorithm step + AI.emit_message(logger, problem, algorithm, state, contexts[:PreStep]) + + # algorithm step + AI.step!(problem, algorithm, state; logging_context_prefix) + + # logging event between algorithm step and convergence check + AI.emit_message(logger, problem, algorithm, state, contexts[:PostStep]) + end + + # emit message about finished state + AI.emit_message(logger, problem, algorithm, state, contexts[:Stop]) + return state +end + +function AI.solve( + problem::Problem, algorithm::Algorithm; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + kwargs..., + ) + state = AI.initialize_state(problem, algorithm; kwargs...) + return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...) +end + +#============================ AlgorithmIterator ===========================================# + +abstract type AlgorithmIterator end + +function algorithm_iterator( + problem::Problem, algorithm::Algorithm, state::State + ) + return DefaultAlgorithmIterator(problem, algorithm, state) +end + +function AI.is_finished!(iterator::AlgorithmIterator) + return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.is_finished(iterator::AlgorithmIterator) + return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.increment!(iterator::AlgorithmIterator) + return AI.increment!(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.step!(iterator::AlgorithmIterator) + return AI.step!(iterator.problem, iterator.algorithm, iterator.state) +end +function Base.iterate(iterator::AlgorithmIterator, init = nothing) + AI.is_finished!(iterator) && return nothing + AI.increment!(iterator) + AI.step!(iterator) + return iterator.state, nothing +end + +struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator + problem::Problem + algorithm::Algorithm + state::State +end + +#============================ with_algorithmlogger ========================================# + +# Allow passing functions, not just CallbackActions. +@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) + return AI.with_algorithmlogger(f, args...) +end +@inline function with_algorithmlogger(f, args::Pair{Symbol}...) + return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) +end + +#============================ NestedAlgorithm =============================================# + +abstract type NestedAlgorithm <: Algorithm end + +function nested_algorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultNestedAlgorithm(f, nalgorithms; kwargs...) +end + +max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) + +function get_subproblem( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end + +function set_substate!( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State + ) + state.iterate = substate.iterate + return state +end + +function AI.step!( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State; + logging_context_prefix = Symbol() + ) + # Get the subproblem, subalgorithm, and substate. + subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state) + + # Solve the subproblem with the subalgorithm. + logging_context_prefix = Symbol( + logging_context_prefix, default_logging_context_prefix(subalgorithm) + ) + AI.solve!(subproblem, subalgorithm, substate; logging_context_prefix) + + # Update the state with the substate. + set_substate!(problem, algorithm, state, substate) + + return state +end + +#= + DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm}) + +An algorithm that consists of running an algorithm at each iteration +from a list of stored algorithms. +=# +@kwdef struct DefaultNestedAlgorithm{ + Algorithms <: AbstractVector{<:Algorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end +function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +end + +#============================ FlattenedAlgorithm ==========================================# + +# Flatten a nested algorithm. +abstract type FlattenedAlgorithm <: Algorithm end +abstract type FlattenedAlgorithmState <: State end + +function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...) +end + +function AI.initialize_state( + problem::Problem, algorithm::FlattenedAlgorithm; kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...) +end +function AI.increment!( + problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState + ) + # Increment the total iteration count. + state.iteration += 1 + # TODO: Use `is_finished!` instead? + if state.child_iteration ≥ max_iterations(algorithm.algorithms[state.parent_iteration]) + # We're on the last iteration of the child algorithm, so move to the next + # child algorithm. + state.parent_iteration += 1 + state.child_iteration = 1 + else + # Iterate the child algorithm. + state.child_iteration += 1 + end + return state +end +function AI.step!( + problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState; + logging_context_prefix = Symbol() + ) + algorithm_sweep = algorithm.algorithms[state.parent_iteration] + state_sweep = AI.initialize_state( + problem, algorithm_sweep; + state.iterate, iteration = state.child_iteration + ) + AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix) + state.iterate = state_sweep.iterate + return state +end + +@kwdef struct DefaultFlattenedAlgorithm{ + Algorithms <: AbstractVector{<:Algorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: FlattenedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = + AI.StopAfterIteration(sum(max_iterations, algorithms)) +end +function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +end + +@kwdef mutable struct DefaultFlattenedAlgorithmState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: FlattenedAlgorithmState + iterate::Iterate + iteration::Int = 0 + parent_iteration::Int = 1 + child_iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 19c4109..5c48287 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,12 +1,12 @@ module ITensorNetworksNext +include("AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl") include("LazyNamedDimsArrays/LazyNamedDimsArrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") -include("abstract_problem.jl") -include("iterators.jl") -include("adapters.jl") +include("sweeping.jl") +include("eigenproblem.jl") end diff --git a/src/abstract_problem.jl b/src/abstract_problem.jl deleted file mode 100644 index 5a65e0a..0000000 --- a/src/abstract_problem.jl +++ /dev/null @@ -1 +0,0 @@ -abstract type AbstractProblem end diff --git a/src/adapters.jl b/src/adapters.jl deleted file mode 100644 index 28318fb..0000000 --- a/src/adapters.jl +++ /dev/null @@ -1,45 +0,0 @@ -""" - struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator - -Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the -process. This allows one to manually call a custom `compute!` or insert their own code it in -the loop body in place of `compute!`. -""" -struct IncrementOnly{S <: AbstractNetworkIterator} <: AbstractNetworkIterator - parent::S -end - -islaststep(adapter::IncrementOnly) = islaststep(adapter.parent) -increment!(adapter::IncrementOnly) = increment!(adapter.parent) -compute!(adapter::IncrementOnly) = adapter - -IncrementOnly(adapter::IncrementOnly) = adapter - -""" - struct EachRegion{SweepIterator} <: AbstractNetworkIterator - -Adapter that flattens each region iterator in the parent sweep iterator into a single -iterator. -""" -struct EachRegion{SI <: SweepIterator} <: AbstractNetworkIterator - parent::SI -end - -# In keeping with Julia convention. -eachregion(iter::SweepIterator) = EachRegion(iter) - -# Essential definitions -function islaststep(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - return islaststep(adapter.parent) && islaststep(region_iter) -end -function increment!(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter) - return adapter -end -function compute!(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - compute!(region_iter) - return adapter -end diff --git a/src/eigenproblem.jl b/src/eigenproblem.jl new file mode 100644 index 0000000..2b4ec37 --- /dev/null +++ b/src/eigenproblem.jl @@ -0,0 +1,93 @@ +import AlgorithmsInterface as AI +import .AlgorithmsInterfaceExtensions as AIE + +maybe_fill(value, len::Int) = fill(value, len) +function maybe_fill(v::AbstractVector, len::Int) + @assert length(v) == len + return v +end + +function dmrg_sweep(operator, algorithm, state) + problem = select_problem(dmrg_sweep, operator, algorithm, state) + return AI.solve(problem, algorithm; iterate = state).iterate +end +function dmrg_sweep(operator, state; kwargs...) + algorithm = select_algorithm(dmrg_sweep, operator, state; kwargs...) + return dmrg_sweep(operator, algorithm, state) +end + +function select_problem(::typeof(dmrg_sweep), operator, algorithm, state) + return EigenProblem(operator) +end +function select_algorithm(::typeof(dmrg_sweep), operator, state; regions, region_kwargs) + region_kwargs′ = maybe_fill(region_kwargs, length(regions)) + return Sweep(length(regions)) do i + return Returns(Region(regions[i]; region_kwargs′[i]...)) + end +end + +function dmrg(operator, algorithm, state) + problem = select_problem(dmrg, operator, algorithm, state) + return AI.solve(problem, algorithm; iterate = state).iterate +end +function dmrg(operator, state; kwargs...) + algorithm = select_algorithm(dmrg, operator, state; kwargs...) + return dmrg(operator, algorithm, state) +end + +function select_problem(::typeof(dmrg), operator, algorithm, state) + return EigenProblem(operator) +end +function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, region_kwargs) + region_kwargs′ = maybe_fill(region_kwargs, nsweeps) + return Sweeping(nsweeps) do i + return select_algorithm( + dmrg_sweep, operator, state; + regions, region_kwargs = region_kwargs′[i], + ) + end +end + +#= + EigenProblem(operator) + +Represents the problem we are trying to solve and minimal algorithm-independent +information, so for an eigenproblem it is the operator we want the eigenvector of. +=# +struct EigenProblem{Operator} <: AIE.Problem + operator::Operator +end + +function AI.step!(problem::EigenProblem, algorithm::Sweep, state::AI.State; kwargs...) + iterate = solve_region!!( + problem, algorithm.region_algorithms[state.iteration](state.iterate), state.iterate + ) + state.iterate = iterate + return state +end + +# extract!, update!, insert! for the region. +function solve_region!!(problem::EigenProblem, algorithm::RegionAlgorithm, state) + operator = problem.operator + region = algorithm.region + region_kwargs = algorithm.kwargs + + #= + # Reduce the `operator` and state `x` onto the region `region`, + # and call `eigsolve` on the reduced operator and state using the + # keyword arguments determined from `region_kwargs`. + operator_region = reduced_operator(operator, x, region) + x_region = reduced_state(x, region) + x_region′ = eigsolve(operator_region, x_region; region_kwargs.update...) + x′ = insert(x, region, x_region′; region_kwargs.insert...) + state.state = x′ + =# + + # Dummy update for demonstration purposes. + state′ = "region = $region" * + ", update_kwargs = $(region_kwargs.update)" * + ", insert_kwargs = $(region_kwargs.insert)" + state = [state; [state′]] + + return state +end diff --git a/src/iterators.jl b/src/iterators.jl deleted file mode 100644 index 62d5b21..0000000 --- a/src/iterators.jl +++ /dev/null @@ -1,170 +0,0 @@ -""" - abstract type AbstractNetworkIterator - -A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins -with a call to `increment!` before executing `compute!`, however the initial call to -`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that -this call is implict. Termination of the iterator is controlled by the function `done`. -""" -abstract type AbstractNetworkIterator end - -# We use greater than or equals here as we increment the state at the start of the iteration -islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) - -function Base.iterate(iterator::AbstractNetworkIterator, init = true) - # The assumption is that first "increment!" is implicit, therefore we must skip the - # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not - # defined when length < 1, - init || islaststep(iterator) && return nothing - # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* - # define a method for increment! This way we avoid cases where one may wish to nest - # calls to different step! methods accidentaly incrementing multiple times. - init || increment!(iterator) - rv = compute!(iterator) - return rv, false -end - -increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)})) -compute!(iterator::AbstractNetworkIterator) = iterator - -step!(iterator::AbstractNetworkIterator) = step!(identity, iterator) -function step!(f, iterator::AbstractNetworkIterator) - compute!(iterator) - f(iterator) - increment!(iterator) - return iterator -end - -# -# RegionIterator -# -""" - struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator -""" -mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator - problem::Problem - region_plan::RegionPlan - which_region::Int - const which_sweep::Int - function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} - if isempty(region_plan) - throw(ArgumentError("Cannot construct a region iterator with 0 elements.")) - end - return new{P, R}(problem, region_plan, 1, sweep) - end -end - -function RegionIterator(problem; sweep, sweep_kwargs...) - plan = region_plan(problem; sweep_kwargs...) - return RegionIterator(problem, plan, sweep) -end - -state(region_iter::RegionIterator) = region_iter.which_region -Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) - -problem(region_iter::RegionIterator) = region_iter.problem - -function current_region_plan(region_iter::RegionIterator) - return region_iter.region_plan[region_iter.which_region] -end - -function current_region(region_iter::RegionIterator) - region, _ = current_region_plan(region_iter) - return region -end - -function region_kwargs(region_iter::RegionIterator) - _, kwargs = current_region_plan(region_iter) - return kwargs -end -function region_kwargs(f::Function, iter::RegionIterator) - return get(region_kwargs(iter), Symbol(f, :_kwargs), (;)) -end - -function prev_region(region_iter::RegionIterator) - state(region_iter) <= 1 && return nothing - prev, _ = region_iter.region_plan[region_iter.which_region - 1] - return prev -end - -function next_region(region_iter::RegionIterator) - islaststep(region_iter) && return nothing - next, _ = region_iter.region_plan[region_iter.which_region + 1] - return next -end - -# -# Functions associated with RegionIterator -# -function increment!(region_iter::RegionIterator) - region_iter.which_region += 1 - return region_iter -end - -function compute!(iter::RegionIterator) - extract!(iter; region_kwargs(extract!, iter)...) - update!(iter; region_kwargs(update!, iter)...) - insert!(iter; region_kwargs(insert!, iter)...) - - return iter -end - -region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) - -# -# SweepIterator -# - -mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator - region_iter::RegionIterator{Problem} - sweep_kwargs::Iterators.Stateful{Iter} - which_sweep::Int - function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} - stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) - - first_state = Iterators.peel(stateful_sweep_kwargs) - - if isnothing(first_state) - throw(ArgumentError("Cannot construct a sweep iterator with 0 elements.")) - end - - first_kwargs, _ = first_state - region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) - - return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) - end -end - -islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) - -region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter -problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) - -state(sweep_iter::SweepIterator) = sweep_iter.which_sweep -Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) -function increment!(sweep_iter::SweepIterator) - sweep_iter.which_sweep += 1 - sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) - update_region_iterator!(sweep_iter; sweep_kwargs...) - return sweep_iter -end - -function update_region_iterator!(iterator::SweepIterator; kwargs...) - sweep = state(iterator) - iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...) - return iterator -end - -function compute!(sweep_iter::SweepIterator) - for _ in sweep_iter.region_iter - # TODO: Is it sensible to execute the default region callback function? - end - return -end - -# More basic constructor where sweep_kwargs are constant throughout sweeps -function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) - # Initialize this to an empty RegionIterator - sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) - return SweepIterator(problem, sweep_kwargs_iter) -end diff --git a/src/sweeping.jl b/src/sweeping.jl new file mode 100644 index 0000000..2c250e9 --- /dev/null +++ b/src/sweeping.jl @@ -0,0 +1,54 @@ +import AlgorithmsInterface as AI +import .AlgorithmsInterfaceExtensions as AIE + +@kwdef struct Sweeping{ + Algorithms <: AbstractVector{<:AI.Algorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end +function Sweeping(f::Function, nalgorithms::Int; kwargs...) + return Sweeping(; algorithms = f.(1:nalgorithms), kwargs...) +end + +#= + Sweep(regions::AbsractVector, region_kwargs::Function, iteration::Int = 0) + Sweep(regions::AbsractVector, region_kwargs::NamedTuple, iteration::Int = 0) + +The "algorithm" for performing a single sweep over a list of regions. It also +stores a function that takes the problem, algorithm, and state (tensor network, current +region, etc.) and returns keyword arguments for performing the region update on the +current region. For simplicity, it also accepts a `NamedTuple` of keyword arguments +which is converted into a function that always returns the same keyword arguments +for an region. +=# +@kwdef struct Sweep{ + RegionAlgorithms <: AbstractVector, StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.Algorithm + region_algorithms::RegionAlgorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(region_algorithms)) +end +function Sweep(f, nalgorithms::Int; kwargs...) + region_algorithms = to_region_algorithm.(f.(1:nalgorithms)) + return Sweep(; region_algorithms, kwargs...) +end +to_region_algorithm(algorithm::Function) = algorithm +to_region_algorithm(algorithm) = Returns(region_algorithm(algorithm)) + +AIE.max_iterations(algorithm::Sweep) = length(algorithm.algorithms) + +abstract type RegionAlgorithm end +region_algorithm(algorithm::RegionAlgorithm) = algorithm +region_algorithm(algorithm::NamedTuple) = Region(; algorithm...) + +struct Region{R, Kwargs <: NamedTuple} <: RegionAlgorithm + region::R + kwargs::Kwargs +end +function Region(; region, kwargs...) + return Region(region, (; kwargs...)) +end +function Region(region; kwargs...) + return Region(region, (; kwargs...)) +end diff --git a/test/Project.toml b/test/Project.toml index cd6028d..2ba51bd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -23,7 +23,7 @@ DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.3" -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" NamedDimsArrays = "0.8" NamedGraphs = "0.6.8, 0.7, 0.8" QuadGK = "2.11.2" diff --git a/test/test_iterators.jl b/test/test_iterators.jl deleted file mode 100644 index a17c7be..0000000 --- a/test/test_iterators.jl +++ /dev/null @@ -1,221 +0,0 @@ -using Test: @test, @testset, @test_throws -import ITensorNetworksNext as ITensorNetworks -using .ITensorNetworks: RegionIterator, SweepIterator, IncrementOnly, compute!, increment!, islaststep, state, eachregion - -module TestIteratorUtils - - import ITensorNetworksNext as ITensorNetworks - using .ITensorNetworks - - struct TestProblem <: ITensorNetworks.AbstractProblem - data::Vector{Int} - end - ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)] - function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem}) - kwargs = ITensorNetworks.region_kwargs(iter) - push!(ITensorNetworks.problem(iter).data, kwargs.val) - return iter - end - - - mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator - state::Int - max::Int - output::Vector{Int} - end - - ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1 - Base.length(TI::TestIterator) = TI.max - ITensorNetworks.state(TI::TestIterator) = TI.state - function ITensorNetworks.compute!(TI::TestIterator) - push!(TI.output, ITensorNetworks.state(TI)) - return TI - end - - mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator - parent::TestIterator - end - - Base.length(SA::SquareAdapter) = length(SA.parent) - ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent) - ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent) - function ITensorNetworks.compute!(SA::SquareAdapter) - ITensorNetworks.compute!(SA.parent) - return last(SA.parent.output)^2 - end - -end - -@testset "Iterators" begin - - import .TestIteratorUtils - - @testset "`AbstractNetworkIterator` Interface" begin - - @testset "Edge cases" begin - TI = TestIteratorUtils.TestIterator(1, 1, []) - cb = [] - @test islaststep(TI) - for _ in TI - @test islaststep(TI) - push!(cb, state(TI)) - end - @test length(cb) == 1 - @test length(TI.output) == 1 - @test only(cb) == 1 - - prob = TestIteratorUtils.TestProblem([]) - @test_throws ArgumentError SweepIterator(prob, 0) - @test_throws ArgumentError RegionIterator(prob, [], 1) - end - - TI = TestIteratorUtils.TestIterator(1, 4, []) - - @test !islaststep((TI)) - - # First iterator should compute only - rv, st = iterate(TI) - @test !islaststep((TI)) - @test !st - @test rv === TI - @test length(TI.output) == 1 - @test only(TI.output) == 1 - @test state(TI) == 1 - @test !st - - rv, st = iterate(TI, st) - @test !islaststep((TI)) - @test !st - @test length(TI.output) == 2 - @test state(TI) == 2 - @test TI.output == [1, 2] - - increment!(TI) - @test !islaststep((TI)) - @test state(TI) == 3 - @test length(TI.output) == 2 - @test TI.output == [1, 2] - - compute!(TI) - @test !islaststep((TI)) - @test state(TI) == 3 - @test length(TI.output) == 3 - @test TI.output == [1, 2, 3] - - # Final Step - iterate(TI, false) - @test islaststep((TI)) - @test state(TI) == 4 - @test length(TI.output) == 4 - @test TI.output == [1, 2, 3, 4] - - @test iterate(TI, false) === nothing - - TI = TestIteratorUtils.TestIterator(1, 5, []) - - cb = [] - - for _ in TI - @test length(cb) == length(TI.output) - 1 - @test cb == (TI.output)[1:(end - 1)] - push!(cb, state(TI)) - @test cb == TI.output - end - - @test islaststep((TI)) - @test length(TI.output) == 5 - @test length(cb) == 5 - @test cb == TI.output - - - TI = TestIteratorUtils.TestIterator(1, 5, []) - end - - @testset "Adapters" begin - TI = TestIteratorUtils.TestIterator(1, 5, []) - SA = TestIteratorUtils.SquareAdapter(TI) - - @testset "Generic" begin - - i = 0 - for rv in SA - i += 1 - @test rv isa Int - @test rv == i^2 - @test state(SA) == i - end - - @test islaststep((SA)) - - TI = TestIteratorUtils.TestIterator(1, 5, []) - SA = TestIteratorUtils.SquareAdapter(TI) - - SA_c = collect(SA) - - @test SA_c isa Vector - @test length(SA_c) == 5 - @test SA_c == [1, 4, 9, 16, 25] - - end - - @testset "IncrementOnly" begin - TI = TestIteratorUtils.TestIterator(1, 5, []) - NI = IncrementOnly(TI) - - NI_c = [] - - for _ in IncrementOnly(TI) - push!(NI_c, state(TI)) - end - - @test length(NI_c) == 5 - @test isempty(TI.output) - end - - @testset "EachRegion" begin - prob = TestIteratorUtils.TestProblem([]) - prob_region = TestIteratorUtils.TestProblem([]) - - SI = SweepIterator(prob, 5) - SI_region = SweepIterator(prob_region, 5) - - callback = [] - callback_region = [] - - let i = 1 - for _ in SI - push!(callback, i) - i += 1 - end - end - - @test length(callback) == 5 - - let i = 1 - for _ in eachregion(SI_region) - push!(callback_region, i) - i += 1 - end - end - - @test length(callback_region) == 10 - - @test prob.data == prob_region.data - - @test prob.data[1:2:end] == fill(1, 5) - @test prob.data[2:2:end] == fill(2, 5) - - - let i = 1, prob = TestIteratorUtils.TestProblem([]) - SI = SweepIterator(prob, 1) - cb = [] - for _ in eachregion(SI) - push!(cb, i) - i += 1 - end - @test length(cb) == 2 - end - - end - end -end From 85f9440b50ab1cabb0eaa2311ddae4f0416563d3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 17 Dec 2025 20:13:53 -0500 Subject: [PATCH 2/3] Reorganization, add tests from Claude --- src/ITensorNetworksNext.jl | 4 +- src/{ => sweeping}/eigenproblem.jl | 0 src/{ => sweeping}/sweeping.jl | 0 test/Project.toml | 1 + test/test_algorithmsinterfaceextensions.jl | 470 +++++++++++++++++++++ 5 files changed, 473 insertions(+), 2 deletions(-) rename src/{ => sweeping}/eigenproblem.jl (100%) rename src/{ => sweeping}/sweeping.jl (100%) create mode 100644 test/test_algorithmsinterfaceextensions.jl diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 5c48287..d985782 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -6,7 +6,7 @@ include("abstracttensornetwork.jl") include("tensornetwork.jl") include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") -include("sweeping.jl") -include("eigenproblem.jl") +include("sweeping/sweeping.jl") +include("sweeping/eigenproblem.jl") end diff --git a/src/eigenproblem.jl b/src/sweeping/eigenproblem.jl similarity index 100% rename from src/eigenproblem.jl rename to src/sweeping/eigenproblem.jl diff --git a/src/sweeping.jl b/src/sweeping/sweeping.jl similarity index 100% rename from src/sweeping.jl rename to src/sweeping/sweeping.jl diff --git a/test/Project.toml b/test/Project.toml index 2630e94..e71e7a4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl new file mode 100644 index 0000000..2ae4a4c --- /dev/null +++ b/test/test_algorithmsinterfaceextensions.jl @@ -0,0 +1,470 @@ +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +# Define test problems, algorithms, and states for testing +struct TestProblem <: AIE.Problem + data::Vector{Float64} +end + +@kwdef struct TestAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(10) +end + +@kwdef struct TestAlgorithmStep{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(5) +end + +function AI.step!( + problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState; + logging_context_prefix = Symbol() + ) + state.iterate .+= 1 # Simple increment step + return state +end + +function AI.step!( + problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState; + logging_context_prefix = Symbol() + ) + state.iterate .+= 2 # Different increment step + return state +end + +@testset "AlgorithmsInterfaceExtensions" begin + @testset "DefaultState" begin + # Test DefaultState construction + iterate = [1.0, 2.0, 3.0] + stopping_criterion_state = AI.initialize_state( + TestProblem([1.0]), TestAlgorithm(), TestAlgorithm().stopping_criterion + ) + state = AIE.DefaultState(; iterate = copy(iterate), stopping_criterion_state) + @test state.iterate == iterate + @test state.iteration == 0 + @test state.stopping_criterion_state isa AI.StoppingCriterionState + + # Test DefaultState with custom iteration + state.iteration = 5 + @test state.iteration == 5 + end + + @testset "initialize_state!" begin + # Test initialize_state! with iterate kwarg + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) + + initial_iterate = [1.0, 2.0] + AIE.AI.initialize_state!(problem, algorithm, state; iterate = initial_iterate) + @test state.iterate == initial_iterate + @test state.iteration == 0 + end + + @testset "initialize_state" begin + # Test initialize_state without exclamation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + + state = AIE.AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) + @test state isa AIE.DefaultState + @test state.iteration == 0 + end + + @testset "increment!" begin + # Test increment! with problem and algorithm + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) + + # Increment and verify iteration counter increases + AI.increment!(problem, algorithm, state) + @test state.iteration == 1 + + AI.increment!(problem, algorithm, state) + @test state.iteration == 2 + end + + @testset "solve! and solve" begin + # Test solve! with simple problem + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(3)) + + initial_iterate = [10.0, 20.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + + # Solve with custom initial iterate + final_state = AI.solve!( + problem, algorithm, state; iterate = copy(initial_iterate) + ) + + @test final_state.iteration == 3 + # Each step increments by 1, so after 3 steps: [10, 20] + 3 = [13, 23] + @test final_state.iterate ≈ [13.0, 23.0] + + # Test solve without exclamation + problem2 = TestProblem([1.0, 2.0]) + algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate2 = [5.0, 10.0] + + final_state2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) + @test final_state2.iteration == 2 + @test final_state2.iterate ≈ [7.0, 12.0] + end + + @testset "DefaultAlgorithmIterator" begin + # Test algorithm iterator creation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + @test iterator isa AIE.DefaultAlgorithmIterator + @test iterator.problem === problem + @test iterator.algorithm === algorithm + @test iterator.state === state + + # Test iteration interface + @test !AI.is_finished!(iterator) + + # Step through iterator + state_out, _ = iterate(iterator) + @test state_out.iteration == 1 + @test state_out.iterate ≈ [1.0, 1.0] # Incremented by step! + + state_out, _ = iterate(iterator) + @test state_out.iteration == 2 + + @test AI.is_finished!(iterator) + end + + @testset "with_algorithmlogger" begin + # Test with_algorithmlogger with functions + results = [] + function callback1(problem, algorithm, state) + push!(results, :callback1) + return nothing + end + function callback2(problem, algorithm, state) + push!(results, :callback2) + return nothing + end + + problem = TestProblem([1.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + + # Test with CallbackAction (wrapped functions) + state = AIE.with_algorithmlogger( + :TestProblem_TestAlgorithm_PreStep => callback1, + :TestProblem_TestAlgorithm_PostStep => callback2, + ) do + return AI.solve(problem, algorithm; iterate = [0.0]) + end + @test results == [:callback1, :callback2] + end + + @testset "DefaultNestedAlgorithm" begin + # Test creating nested algorithm with function + nested_alg = AIE.nested_algorithm(3) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + @test nested_alg isa AIE.DefaultNestedAlgorithm + @test length(nested_alg.algorithms) == 3 + @test AIE.max_iterations(nested_alg) == 3 + + # Test stepping through nested algorithm + problem = TestProblem([1.0, 2.0]) + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) + + initial_iterate = [0.0, 0.0] + AI.solve!( + problem, nested_alg, state; iterate = copy(initial_iterate) + ) + + @test state.iteration == 3 + # Each nested algorithm runs once with 2 steps, incrementing by 2 + # Total: 3 algorithms × 2 iterations × 2 increment = 12 + @test state.iterate ≈ [12.0, 12.0] + end + + @testset "NestedAlgorithm basic tests" begin + # Test basic nested algorithm functionality + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + problem = TestProblem([1.0, 2.0]) + + # Test state initialization + state_nested = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) + + @test state_nested isa AIE.DefaultState + @test state_nested.iteration == 0 + @test AIE.max_iterations(nested_alg) == 2 + end + + @testset "increment! for nested algorithms" begin + # Test increment! logic for nested algorithm state + problem = TestProblem([1.0]) + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; + iterate = [0.0], + stopping_criterion_state = stopping_criterion_state, + ) + + # Test progression through iterations + @test state.iteration == 0 + + AI.increment!(problem, nested_alg, state) + @test state.iteration == 1 + + AI.increment!(problem, nested_alg, state) + @test state.iteration == 2 + end + + @testset "get_subproblem and set_substate!" begin + # Test get_subproblem + problem = TestProblem([1.0, 2.0]) + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(1)) + end + + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; + iterate = [5.0, 10.0], + iteration = 1, + stopping_criterion_state, + ) + + subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) + @test subproblem === problem + @test subalgorithm === nested_alg.algorithms[1] + @test substate.iterate ≈ [5.0, 10.0] + + # Test set_substate! + new_substate = AIE.DefaultState(; + iterate = [100.0, 200.0], + substate.stopping_criterion_state, + ) + AIE.set_substate!(problem, nested_alg, state, new_substate) + @test state.iterate ≈ [100.0, 200.0] + end + + @testset "basetypenameof and default_logging_context_prefix" begin + # Test basetypenameof utility + problem = TestProblem([1.0]) + algorithm = TestAlgorithm() + + prefix_problem = AIE.default_logging_context_prefix(problem) + prefix_algorithm = AIE.default_logging_context_prefix(algorithm) + prefix_combined = AIE.default_logging_context_prefix(problem, algorithm) + + @test prefix_problem isa Symbol + @test prefix_algorithm isa Symbol + @test prefix_combined isa Symbol + @test contains(String(prefix_combined), String(prefix_problem)) + end + + @testset "DefaultFlattenedAlgorithm" begin + # Create nested algorithms that support max_iterations + nested_algs = map(1:3) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(6) # 3 algorithms × 2 iterations each + ) + + @test flattened_alg isa AIE.DefaultFlattenedAlgorithm + @test length(flattened_alg.algorithms) == 3 + + # Test state initialization + problem = TestProblem([1.0, 2.0]) + state_flat = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) + + @test state_flat isa AIE.DefaultFlattenedAlgorithmState + @test state_flat.iteration == 0 + @test state_flat.parent_iteration == 1 + @test state_flat.child_iteration == 0 + end + + @testset "DefaultFlattenedAlgorithmState increment!" begin + # Create nested algorithms for flattened algorithm + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(4), + ) + + problem = TestProblem([1.0]) + stopping_criterion_state = AI.initialize_state( + problem, flattened_alg, flattened_alg.stopping_criterion + ) + state = AIE.DefaultFlattenedAlgorithmState(; + iterate = [0.0], + stopping_criterion_state = stopping_criterion_state, + ) + + # Test initial state + @test state.iteration == 0 + @test state.parent_iteration == 1 + @test state.child_iteration == 0 + + # First increment - should increment child_iteration + AI.increment!(problem, flattened_alg, state) + @test state.iteration == 1 + @test state.parent_iteration == 1 + @test state.child_iteration == 1 + + # Second increment - should increment child_iteration again + AI.increment!(problem, flattened_alg, state) + @test state.iteration == 2 + @test state.parent_iteration == 2 # Should move to next parent + @test state.child_iteration == 1 + end + + @testset "FlattenedAlgorithm step!" begin + # Test individual step! calls for flattened algorithm + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(4) + ) + + problem = TestProblem([1.0, 2.0]) + state = AIE.AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) + + # Manually step through to test step! functionality + AIE.AI.increment!(problem, flattened_alg, state) + @test state.parent_iteration == 1 + @test state.child_iteration == 1 + + AIE.AI.step!(problem, flattened_alg, state) + # The nested algorithm runs TestAlgorithmStep with 2 iterations, each incrementing by 2 + @test state.iterate ≈ [4.0, 4.0] + end + + @testset "flattened_algorithm helper" begin + # Test the flattened_algorithm helper function + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + # Using the helper function + flattened_alg = AIE.flattened_algorithm(2) do i + AIE.nested_algorithm(1) do j + TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + @test flattened_alg isa AIE.DefaultFlattenedAlgorithm + @test length(flattened_alg.algorithms) == 2 + end + + @testset "AlgorithmIterator is_finished (without !)" begin + # Test is_finished without mutation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + # Before any iterations + @test !AI.is_finished(iterator) + + # Run the algorithm + AI.solve!(problem, algorithm, state; iterate = copy(initial_iterate)) + + # After completion + @test AI.is_finished(iterator) + end + + @testset "AlgorithmIterator step!" begin + # Test step! method for iterator + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + # Step the iterator + AI.step!(iterator) + @test iterator.state.iterate ≈ [1.0, 1.0] + + AI.step!(iterator) + @test iterator.state.iterate ≈ [2.0, 2.0] + end + + @testset "NestedAlgorithm with different sub-algorithms" begin + # Test nested algorithm with varying sub-algorithms + nested_alg = AIE.DefaultNestedAlgorithm(; + algorithms = [ + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)), + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + ] + ) + + @test AIE.max_iterations(nested_alg) == 3 + @test length(nested_alg.algorithms) == 3 + + problem = TestProblem([1.0, 2.0]) + state = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) + + AI.solve!(problem, nested_alg, state; iterate = [0.0, 0.0]) + + # First algorithm: 1 iteration × 1 increment = 1 + # Second algorithm: 2 iterations × 2 increment = 4 + # Third algorithm: 1 iteration × 1 increment = 1 + # Total: 1 + 4 + 1 = 6 + @test state.iterate ≈ [6.0, 6.0] + @test state.iteration == 3 + end + + @testset "Edge cases" begin + # Test with single nested algorithm + nested_alg = AIE.nested_algorithm(1) do i + return TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + end + + problem = TestProblem([1.0]) + state = AI.initialize_state(problem, nested_alg; iterate = [0.0]) + AI.solve!(problem, nested_alg, state; iterate = [0.0]) + + @test state.iterate ≈ [1.0] + @test state.iteration == 1 + end +end From 5a85467afafc54009e9c45198dba4e5b893ae96c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 17 Dec 2025 20:23:09 -0500 Subject: [PATCH 3/3] Simplify tests --- test/test_algorithmsinterfaceextensions.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 2ae4a4c..0edf7a4 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -25,7 +25,7 @@ end function AI.step!( problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState; - logging_context_prefix = Symbol() + kwargs... ) state.iterate .+= 2 # Different increment step return state @@ -58,7 +58,7 @@ end state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) initial_iterate = [1.0, 2.0] - AIE.AI.initialize_state!(problem, algorithm, state; iterate = initial_iterate) + AI.initialize_state!(problem, algorithm, state; iterate = initial_iterate) @test state.iterate == initial_iterate @test state.iteration == 0 end @@ -68,7 +68,7 @@ end problem = TestProblem([1.0, 2.0]) algorithm = TestAlgorithm() - state = AIE.AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) + state = AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) @test state isa AIE.DefaultState @test state.iteration == 0 end @@ -363,14 +363,14 @@ end ) problem = TestProblem([1.0, 2.0]) - state = AIE.AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) + state = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) # Manually step through to test step! functionality - AIE.AI.increment!(problem, flattened_alg, state) + AI.increment!(problem, flattened_alg, state) @test state.parent_iteration == 1 @test state.child_iteration == 1 - AIE.AI.step!(problem, flattened_alg, state) + AI.step!(problem, flattened_alg, state) # The nested algorithm runs TestAlgorithmStep with 2 iterations, each incrementing by 2 @test state.iterate ≈ [4.0, 4.0] end