From ba4f1c6c3608aa405c00b53c1a3f7fadce8aefa9 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 16:58:06 +0330 Subject: [PATCH 01/14] remove old icnf types --- src/ContinuousNormalizingFlows.jl | 6 --- src/core/base_icnf.jl | 32 ++++++---------- src/core/icnf.jl | 61 ++---------------------------- src/exts/mlj_ext/core_cond_icnf.jl | 2 +- src/exts/mlj_ext/core_icnf.jl | 2 +- src/layers/cond_layer.jl | 2 +- 6 files changed, 18 insertions(+), 87 deletions(-) diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index fe57b04d..d5826584 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -33,12 +33,6 @@ export construct, generate, loss, ICNF, - RNODE, - CondRNODE, - FFJORD, - CondFFJORD, - Planar, - CondPlanar, TestMode, TrainMode, DIVecJacVectorMode, diff --git a/src/core/base_icnf.jl b/src/core/base_icnf.jl index b73de94b..7df0b779 100644 --- a/src/core/base_icnf.jl +++ b/src/core/base_icnf.jl @@ -1,12 +1,14 @@ -function construct( - aicnf::Type{<:AbstractICNF}, - nn::LuxCore.AbstractLuxLayer, +function construct(; nvars::Int, - naugmented::Int = 0; + naugmented::Int = 0, + nn::LuxCore.AbstractLuxLayer = Lux.Chain( + Lux.Dense((nvars + naugmented) => (nvars + naugmented), tanh), + ), + aicnf::Type{<:AbstractICNF} = ICNF, data_type::Type{<:AbstractFloat} = Float32, compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()), inplace::Bool = false, - cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar}, + cond::Bool = false, device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(), basedist::Distributions.Distribution = Distributions.MvNormal( FillArrays.Zeros{data_type}(nvars + naugmented), @@ -20,25 +22,13 @@ function construct( ), sol_kwargs::NamedTuple = (;), rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), - λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} - convert(data_type, 1.0e-2) - else - zero(data_type) - end, - λ₂::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} - convert(data_type, 1.0e-2) - else - zero(data_type) - end, - λ₃::AbstractFloat = if naugmented >= nvars - convert(data_type, 1.0e-2) - else - zero(data_type) - end, + λ₁::AbstractFloat = zero(data_type), + λ₂::AbstractFloat = zero(data_type), + λ₃::AbstractFloat = zero(data_type), ) steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) - return ICNF{ + return aicnf{ data_type, typeof(compute_mode), inplace, diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 0c4622f0..869bde00 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -1,60 +1,3 @@ -struct Planar{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end -struct CondPlanar{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end - -struct FFJORD{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end -struct CondFFJORD{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end - -struct RNODE{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end -struct CondRNODE{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end - """ Implementation of ICNF. @@ -65,6 +8,10 @@ Refs: [Grathwohl, Will, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. "Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv preprint arXiv:1810.01367 (2018).](https://arxiv.org/abs/1810.01367) [Finlay, Chris, Jörn-Henrik Jacobsen, Levon Nurbekyan, and Adam M. Oberman. "How to train your neural ODE: the world of Jacobian and kinetic regularization." arXiv preprint arXiv:2002.02798 (2020).](https://arxiv.org/abs/2002.02798) + +[Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. "Augmented Neural ODEs." arXiv preprint arXiv:1904.01681 (2019).](https://arxiv.org/abs/1904.01681) + +[Ghosh, Arnab, Harkirat Singh Behl, Emilien Dupont, Philip HS Torr, and Vinay Namboodiri. "STEER: Simple Temporal Regularization For Neural ODEs." arXiv preprint arXiv:2006.10711 (2020).](https://arxiv.org/abs/2006.10711) """ struct ICNF{ T <: AbstractFloat, diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 834d08aa..3330f741 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -40,7 +40,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) ps, data, ) - res_stats = SciMLBase.OptimizationStats[] + res_stats = Any[] for opt in model.optimizers optprob_re = SciMLBase.remake(optprob; u0 = ps) res = SciMLBase.solve(optprob_re, opt; model.sol_kwargs...) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 2db39ff3..7e1db2f9 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -37,7 +37,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X) ps, data, ) - res_stats = SciMLBase.OptimizationStats[] + res_stats = Any[] for opt in model.optimizers optprob_re = SciMLBase.remake(optprob; u0 = ps) res = SciMLBase.solve(optprob_re, opt; model.sol_kwargs...) diff --git a/src/layers/cond_layer.jl b/src/layers/cond_layer.jl index cdd75f00..e97bf9fa 100644 --- a/src/layers/cond_layer.jl +++ b/src/layers/cond_layer.jl @@ -1,4 +1,4 @@ -struct CondLayer{NN <: LuxCore.AbstractLuxLayer, AT <: AbstractArray} <: +struct CondLayer{NN <: LuxCore.AbstractLuxLayer, AT <: Any} <: LuxCore.AbstractLuxWrapperLayer{:nn} nn::NN ys::AT From 5181ce7eadf9cc709027900c3574ceb113179f59 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 17:19:41 +0330 Subject: [PATCH 02/14] more cleaning --- benchmark/benchmarks.jl | 23 ++++++++++---------- examples/usage.jl | 12 +++++------ src/core/base_icnf.jl | 2 +- test/ci_tests/regression_tests.jl | 15 ++++++------- test/ci_tests/smoke_tests.jl | 28 ++++++++++++------------- test/ci_tests/speed_tests.jl | 15 ++++++------- test/quality_tests/checkby_JET_tests.jl | 27 ++++++++++++------------ 7 files changed, 62 insertions(+), 60 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 1c65edde..ea696b4f 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -21,18 +21,19 @@ r = rand(rng, data_dist, ndimension, ndata) r = convert.(Float32, r) nvars = size(r, 1) -naugs = nvars +naugs = nvars + 1 n_in = nvars + naugs -nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh)) +nn = Lux.Chain( + Lux.Dense(n_in => (2 * n_in + 1), tanh), + Lux.Dense((2 * n_in + 1) => n_in, tanh), +) -icnf = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.ICNF, - nn, +icnf = ContinuousNormalizingFlows.construct(; nvars, - naugs; + naugmented = naugs, + nn, compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), - tspan = (0.0f0, 1.0f0), steer_rate = 1.0f-1, λ₁ = 1.0f-2, λ₂ = 1.0f-2, @@ -40,14 +41,12 @@ icnf = ContinuousNormalizingFlows.construct( rng, ) -icnf2 = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.ICNF, +icnf2 = ContinuousNormalizingFlows.construct(; + nvars; + naugmented = naugs, nn, - nvars, - naugs; inplace = true, compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), - tspan = (0.0f0, 1.0f0), steer_rate = 1.0f-1, λ₁ = 1.0f-2, λ₂ = 1.0f-2, diff --git a/examples/usage.jl b/examples/usage.jl index 82b09192..4e4f11fa 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -25,14 +25,14 @@ using ContinuousNormalizingFlows, # To use gpu, add related packages # using LuxCUDA, CUDA, cuDNN -nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh)) -icnf = construct( - ICNF, - nn, - nvars, # number of variables - naugs; # number of augmented dimensions +nn = Chain(Dense(n_in => (2 * n_in + 1), tanh), Dense((2 * n_in + 1) => n_in, tanh)) +icnf = construct(; + nvars = nvars, # number of variables + naugmented = naugs, # number of augmented dimensions + nn = nn, compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote inplace = false, # not using the inplace version of functions + cond = false, # not conditioning on auxiliary input device = cpu_device(), # process data by CPU # device = gpu_device(), # process data by GPU tspan = (0.0f0, 1.0f0), # time span diff --git a/src/core/base_icnf.jl b/src/core/base_icnf.jl index 7df0b779..5406136a 100644 --- a/src/core/base_icnf.jl +++ b/src/core/base_icnf.jl @@ -1,5 +1,5 @@ function construct(; - nvars::Int, + nvars::Int = 1, naugmented::Int = 0, nn::LuxCore.AbstractLuxLayer = Lux.Chain( Lux.Dense((nvars + naugmented) => (nvars + naugmented), tanh), diff --git a/test/ci_tests/regression_tests.jl b/test/ci_tests/regression_tests.jl index 07f09ada..630bf028 100644 --- a/test/ci_tests/regression_tests.jl +++ b/test/ci_tests/regression_tests.jl @@ -7,18 +7,19 @@ Test.@testset verbose = true showtiming = true failfast = false "Regression Test r = convert.(Float32, r) nvars = size(r, 1) - naugs = nvars + naugs = nvars + 1 n_in = nvars + naugs - nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh)) + nn = Lux.Chain( + Lux.Dense(n_in => (2 * n_in + 1), tanh), + Lux.Dense((2 * n_in + 1) => n_in, tanh), + ) - icnf = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.ICNF, - nn, + icnf = ContinuousNormalizingFlows.construct(; nvars, - naugs; + naugmented = naugs, + nn, compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), - tspan = (0.0f0, 1.0f0), steer_rate = 1.0f-1, λ₁ = 1.0f-2, λ₂ = 1.0f-2, diff --git a/test/ci_tests/smoke_tests.jl b/test/ci_tests/smoke_tests.jl index fae767e8..afcb54c5 100644 --- a/test/ci_tests/smoke_tests.jl +++ b/test/ci_tests/smoke_tests.jl @@ -1,5 +1,4 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" begin - mts = Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.ICNF] omodes = ContinuousNormalizingFlows.Mode[ ContinuousNormalizingFlows.TrainMode{true}(), ContinuousNormalizingFlows.TestMode{true}(), @@ -76,8 +75,8 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ), ] - Test.@testset verbose = true showtiming = true failfast = false "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in - devices, + Test.@testset verbose = true showtiming = true failfast = false "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode" for device in + devices, data_type in data_types, compute_mode in compute_modes, ndata in ndata_, @@ -85,8 +84,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be inplace in inplaces, cond in conds, planar in planars, - omode in omodes, - mt in mts + omode in omodes data_dist = Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...) @@ -107,21 +105,24 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ifelse( planar, Lux.Chain( - ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh; n_cond = nvars), + ContinuousNormalizingFlows.PlanarLayer( + nvars * 2 + 1, + tanh; + n_cond = nvars, + ), ), - Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)), + Lux.Chain(Lux.Dense(nvars * 3 + 1 => nvars * 2 + 1, tanh)), ), ifelse( planar, - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)), - Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)), + Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2 + 1, tanh)), + Lux.Chain(Lux.Dense(nvars * 2 + 1 => nvars * 2 + 1, tanh)), ), ) - icnf = ContinuousNormalizingFlows.construct( - mt, - nn, + icnf = ContinuousNormalizingFlows.construct(; nvars, - nvars; + naugmented = nvars + 1, + nn, data_type, compute_mode, inplace, @@ -202,7 +203,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && ( omode isa ContinuousNormalizingFlows.TrainMode || ( diff --git a/test/ci_tests/speed_tests.jl b/test/ci_tests/speed_tests.jl index 9dcf3f51..7a48c1a7 100644 --- a/test/ci_tests/speed_tests.jl +++ b/test/ci_tests/speed_tests.jl @@ -32,7 +32,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be Test.@testset verbose = true showtiming = true failfast = false "$compute_mode" for compute_mode in compute_modes - @show compute_mode rng = StableRNGs.StableRNG(1) @@ -43,16 +42,18 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be r = convert.(Float32, r) nvars = size(r, 1) - naugs = nvars + naugs = nvars + 1 n_in = nvars + naugs - nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh)) + nn = Lux.Chain( + Lux.Dense(n_in => (2 * n_in + 1), tanh), + Lux.Dense((2 * n_in + 1) => n_in, tanh), + ) - icnf = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.ICNF, - nn, + icnf = ContinuousNormalizingFlows.construct(; nvars, - naugs; + naugmented = naugs, + nn, compute_mode, tspan = (0.0f0, 1.0f0), steer_rate = 1.0f-1, diff --git a/test/quality_tests/checkby_JET_tests.jl b/test/quality_tests/checkby_JET_tests.jl index 72ea31af..7b7a5c8d 100644 --- a/test/quality_tests/checkby_JET_tests.jl +++ b/test/quality_tests/checkby_JET_tests.jl @@ -4,7 +4,6 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg target_modules = (ContinuousNormalizingFlows,), ) - mts = Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.ICNF] omodes = ContinuousNormalizingFlows.Mode[ ContinuousNormalizingFlows.TrainMode{true}(), ContinuousNormalizingFlows.TestMode{true}(), @@ -61,8 +60,8 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg ), ] - Test.@testset verbose = true showtiming = true failfast = false "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in - devices, + Test.@testset verbose = true showtiming = true failfast = false "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode" for device in + devices, data_type in data_types, compute_mode in compute_modes, ndata in ndata_, @@ -70,8 +69,7 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg inplace in inplaces, cond in conds, planar in planars, - omode in omodes, - mt in mts + omode in omodes data_dist = Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...) @@ -90,21 +88,24 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg ifelse( planar, Lux.Chain( - ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh; n_cond = nvars), + ContinuousNormalizingFlows.PlanarLayer( + nvars * 2 + 1, + tanh; + n_cond = nvars, + ), ), - Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)), + Lux.Chain(Lux.Dense(nvars * 3 + 1 => nvars * 2 + 1, tanh)), ), ifelse( planar, - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)), - Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)), + Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2 + 1, tanh)), + Lux.Chain(Lux.Dense(nvars * 2 + 1 => nvars * 2 + 1, tanh)), ), ) - icnf = ContinuousNormalizingFlows.construct( - mt, - nn, + icnf = ContinuousNormalizingFlows.construct(; nvars, - nvars; + naugmented = nvars + 1, + nn, data_type, compute_mode, inplace, From 5a6729d559c96b72fa568e5f4534bf0159a44167 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 17:25:43 +0330 Subject: [PATCH 03/14] use fitted in usage --- examples/usage.jl | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/examples/usage.jl b/examples/usage.jl index 4e4f11fa..951304cb 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -45,22 +45,25 @@ icnf = construct(; ## Fit It using DataFrames, MLJBase, Zygote, ADTypes, OptimizationOptimisers -df = DataFrame(transpose(r), :auto) -model = ICNFModel( - icnf; - optimizers = (Adam(),), - adtype = AutoZygote(), - batchsize = 512, - sol_kwargs = (; epochs = 300, progress = true), # pass to the solver -) -mach = machine(model, df) -fit!(mach) -# CUDA.@allowscalar fit!(mach) # needed for gpu -## Store It icnf_mach_fn = "icnf_mach.jls" -MLJBase.save(icnf_mach_fn, mach) # save it -mach = machine(icnf_mach_fn) # load it +if ispath(icnf_mach_fn) + mach = machine(icnf_mach_fn) # load it +else + df = DataFrame(transpose(r), :auto) + model = ICNFModel( + icnf; + optimizers = (Adam(),), + adtype = AutoZygote(), + batchsize = 512, + sol_kwargs = (; epochs = 300, progress = true), # pass to the solver + ) + mach = machine(model, df) + fit!(mach) + # CUDA.@allowscalar fit!(mach) # needed for gpu + + MLJBase.save(icnf_mach_fn, mach) # save it +end ## Use It d = ICNFDist(mach, TestMode()) From 6f627e6a1a76dfefc7b220999b1a3fbccfef91de Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 17:29:49 +0330 Subject: [PATCH 04/14] test by more julia versions --- .github/workflows/CI-CheckBy.yml | 7 +++++-- .github/workflows/CI.yml | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/CI-CheckBy.yml b/.github/workflows/CI-CheckBy.yml index e03dfbeb..21f3790d 100644 --- a/.github/workflows/CI-CheckBy.yml +++ b/.github/workflows/CI-CheckBy.yml @@ -24,8 +24,11 @@ jobs: - CheckByJET - CheckByExplicitImports version: - - release - - lts + - 1.10 + - 1.11 + - 1.12 + # - release + # - lts # - nightly os: - ubuntu-latest diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 55012191..a7939109 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,8 +27,11 @@ jobs: - Regression - Speed version: - - release - - lts + - 1.10 + - 1.11 + - 1.12 + # - release + # - lts # - nightly os: - ubuntu-latest From 1ae6c07e03cf49c46fc249ce1f04a392965f63db Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 17:33:44 +0330 Subject: [PATCH 05/14] fix --- benchmark/benchmarks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index ea696b4f..1ab154ba 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -42,7 +42,7 @@ icnf = ContinuousNormalizingFlows.construct(; ) icnf2 = ContinuousNormalizingFlows.construct(; - nvars; + nvars, naugmented = naugs, nn, inplace = true, From 13cb0386073c24f0d62c2605997d369028dc02e2 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 17:45:00 +0330 Subject: [PATCH 06/14] fix --- .github/workflows/CI-CheckBy.yml | 6 +++--- .github/workflows/CI.yml | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/CI-CheckBy.yml b/.github/workflows/CI-CheckBy.yml index 21f3790d..c84b393f 100644 --- a/.github/workflows/CI-CheckBy.yml +++ b/.github/workflows/CI-CheckBy.yml @@ -24,9 +24,9 @@ jobs: - CheckByJET - CheckByExplicitImports version: - - 1.10 - - 1.11 - - 1.12 + - "1.10" + - "1.11" + - "1.12" # - release # - lts # - nightly diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a7939109..c33fca85 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,9 +27,9 @@ jobs: - Regression - Speed version: - - 1.10 - - 1.11 - - 1.12 + - "1.10" + - "1.11" + - "1.12" # - release # - lts # - nightly From 4565e9c71e47311b008abe964a7ae72089f43316 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 19:11:21 +0330 Subject: [PATCH 07/14] more cleaning --- Project.toml | 6 +- benchmark/Project.toml | 4 -- benchmark/benchmarks.jl | 28 +--------- examples/usage.jl | 27 ++++++--- src/ContinuousNormalizingFlows.jl | 6 +- src/core/base_icnf.jl | 67 ----------------------- src/core/icnf.jl | 73 +++++++++++++++++++++++++ src/core/types.jl | 2 +- src/exts/dist_ext/core.jl | 2 +- src/exts/dist_ext/core_cond_icnf.jl | 12 ++-- src/exts/dist_ext/core_icnf.jl | 12 ++-- src/exts/mlj_ext/core_cond_icnf.jl | 46 ++++++++-------- src/exts/mlj_ext/core_icnf.jl | 42 +++++++------- test/Project.toml | 4 -- test/ci_tests/regression_tests.jl | 16 +----- test/ci_tests/smoke_tests.jl | 14 ++--- test/ci_tests/speed_tests.jl | 11 +--- test/quality_tests/checkby_JET_tests.jl | 6 +- test/runtests.jl | 2 - 19 files changed, 175 insertions(+), 205 deletions(-) diff --git a/Project.toml b/Project.toml index 62530b6f..97374eac 100644 --- a/Project.toml +++ b/Project.toml @@ -23,11 +23,12 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" -OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7" +OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -52,11 +53,12 @@ MLUtils = "0.4" NNlib = "0.9" Optimisers = "0.4" OptimizationOptimisers = "0.3" -OrdinaryDiffEqDefault = "1" +OrdinaryDiffEqAdamsBashforthMoulton = "1" Random = "1" SciMLBase = "2" SciMLSensitivity = "7" ScientificTypesBase = "3" +Static = "1" Statistics = "1" WeightInitializers = "1" Zygote = "0.7" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 7fef44ff..f35c328f 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -7,9 +7,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" -SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -22,9 +20,7 @@ Distributions = "0.25" ForwardDiff = "1" Lux = "1" LuxCore = "1" -OrdinaryDiffEqDefault = "1" PkgBenchmark = "0.2" -SciMLSensitivity = "7" StableRNGs = "1" Zygote = "0.7" julia = "1.10" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 1ab154ba..fc19a444 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -6,9 +6,7 @@ import ADTypes, ForwardDiff, Lux, LuxCore, - OrdinaryDiffEqDefault, PkgBenchmark, - SciMLSensitivity, StableRNGs, Zygote, ContinuousNormalizingFlows @@ -29,30 +27,10 @@ nn = Lux.Chain( Lux.Dense((2 * n_in + 1) => n_in, tanh), ) -icnf = ContinuousNormalizingFlows.construct(; - nvars, - naugmented = naugs, - nn, - compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), - steer_rate = 1.0f-1, - λ₁ = 1.0f-2, - λ₂ = 1.0f-2, - λ₃ = 1.0f-2, - rng, -) +icnf = ContinuousNormalizingFlows.ICNF(; nvars, naugmented = naugs, nn, rng) -icnf2 = ContinuousNormalizingFlows.construct(; - nvars, - naugmented = naugs, - nn, - inplace = true, - compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), - steer_rate = 1.0f-1, - λ₁ = 1.0f-2, - λ₂ = 1.0f-2, - λ₃ = 1.0f-2, - rng, -) +icnf2 = + ContinuousNormalizingFlows.ICNF(; nvars, naugmented = naugs, nn, inplace = true, rng) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) diff --git a/examples/usage.jl b/examples/usage.jl index 951304cb..e786f637 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -20,13 +20,19 @@ n_in = nvars + naugs ## Model using ContinuousNormalizingFlows, - Lux, OrdinaryDiffEqAdamsBashforthMoulton, ADTypes, Zygote, MLDataDevices + Lux, + OrdinaryDiffEqAdamsBashforthMoulton, + SciMLSensitivity, + Static, + ADTypes, + Zygote, + MLDataDevices # To use gpu, add related packages # using LuxCUDA, CUDA, cuDNN nn = Chain(Dense(n_in => (2 * n_in + 1), tanh), Dense((2 * n_in + 1) => n_in, tanh)) -icnf = construct(; +icnf = ICNF(; nvars = nvars, # number of variables naugmented = naugs, # number of augmented dimensions nn = nn, @@ -40,7 +46,14 @@ icnf = construct(; λ₁ = 1.0f-2, # regulate flow λ₂ = 1.0f-2, # regulate volume change λ₃ = 1.0f-2, # regulate augmented dimensions - sol_kwargs = (; save_everystep = false, alg = VCABM()), # pass to the solver + sol_kwargs = (; + save_everystep = false, + reltol = 1.0f-4, + abstol = 1.0f-8, + maxiters = typemax(Int), + alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = True()), + sensealg = GaussAdjoint(; autodiff = true, checkpointing = true), + ), # pass to the solver ) ## Fit It @@ -51,11 +64,11 @@ if ispath(icnf_mach_fn) mach = machine(icnf_mach_fn) # load it else df = DataFrame(transpose(r), :auto) - model = ICNFModel( - icnf; - optimizers = (Adam(),), + model = ICNFModel(; + icnf, + optimizers = (OptimiserChain(WeightDecay(), Adam()),), adtype = AutoZygote(), - batchsize = 512, + batchsize = 1024, sol_kwargs = (; epochs = 300, progress = true), # pass to the solver ) mach = machine(model, df) diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index d5826584..24a2082c 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -19,17 +19,17 @@ import ADTypes, NNlib, Optimisers, OptimizationOptimisers, - OrdinaryDiffEqDefault, + OrdinaryDiffEqAdamsBashforthMoulton, Random, SciMLBase, SciMLSensitivity, ScientificTypesBase, + Static, Statistics, WeightInitializers, Zygote -export construct, - inference, +export inference, generate, loss, ICNF, diff --git a/src/core/base_icnf.jl b/src/core/base_icnf.jl index 5406136a..3052ee82 100644 --- a/src/core/base_icnf.jl +++ b/src/core/base_icnf.jl @@ -1,70 +1,3 @@ -function construct(; - nvars::Int = 1, - naugmented::Int = 0, - nn::LuxCore.AbstractLuxLayer = Lux.Chain( - Lux.Dense((nvars + naugmented) => (nvars + naugmented), tanh), - ), - aicnf::Type{<:AbstractICNF} = ICNF, - data_type::Type{<:AbstractFloat} = Float32, - compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()), - inplace::Bool = false, - cond::Bool = false, - device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(), - basedist::Distributions.Distribution = Distributions.MvNormal( - FillArrays.Zeros{data_type}(nvars + naugmented), - FillArrays.Eye{data_type}(nvars + naugmented), - ), - tspan::NTuple{2} = (zero(data_type), one(data_type)), - steer_rate::AbstractFloat = zero(data_type), - epsdist::Distributions.Distribution = Distributions.MvNormal( - FillArrays.Zeros{data_type}(nvars + naugmented), - FillArrays.Eye{data_type}(nvars + naugmented), - ), - sol_kwargs::NamedTuple = (;), - rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), - λ₁::AbstractFloat = zero(data_type), - λ₂::AbstractFloat = zero(data_type), - λ₃::AbstractFloat = zero(data_type), -) - steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) - - return aicnf{ - data_type, - typeof(compute_mode), - inplace, - cond, - !iszero(naugmented), - !iszero(steer_rate), - !iszero(λ₁), - !iszero(λ₂), - !iszero(λ₃), - typeof(nn), - typeof(nvars), - typeof(device), - typeof(basedist), - typeof(tspan), - typeof(steerdist), - typeof(epsdist), - typeof(sol_kwargs), - typeof(rng), - }( - nn, - nvars, - naugmented, - compute_mode, - device, - basedist, - tspan, - steerdist, - epsdist, - sol_kwargs, - rng, - λ₁, - λ₂, - λ₃, - ) -end - function Base.show(io::IO, icnf::AbstractICNF) return print(io, typeof(icnf)) end diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 869bde00..71c91a9d 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -50,6 +50,79 @@ struct ICNF{ λ₃::T end +function ICNF(; + nvars::Int = 1, + naugmented::Int = nvars + 1, + nn::LuxCore.AbstractLuxLayer = Lux.Chain( + Lux.Dense((nvars + naugmented) => (nvars + naugmented), tanh), + ), + data_type::Type{<:AbstractFloat} = Float32, + compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()), + inplace::Bool = false, + cond::Bool = false, + device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(), + basedist::Distributions.Distribution = Distributions.MvNormal( + FillArrays.Zeros{data_type}(nvars + naugmented), + FillArrays.Eye{data_type}(nvars + naugmented), + ), + tspan::NTuple{2} = (zero(data_type), one(data_type)), + steer_rate::AbstractFloat = convert(data_type, 1.0e-1), + epsdist::Distributions.Distribution = Distributions.MvNormal( + FillArrays.Zeros{data_type}(nvars + naugmented), + FillArrays.Eye{data_type}(nvars + naugmented), + ), + sol_kwargs::NamedTuple = (; + save_everystep = false, + reltol = convert(data_type, 1.0e-4), + abstol = convert(data_type, 1.0e-8), + maxiters = typemax(Int), + alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()), + sensealg = SciMLSensitivity.GaussAdjoint(; autodiff = true, checkpointing = true), + ), + rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), + λ₁::AbstractFloat = convert(data_type, 1.0e-2), + λ₂::AbstractFloat = convert(data_type, 1.0e-2), + λ₃::AbstractFloat = convert(data_type, 1.0e-2), +) + steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) + + return ICNF{ + data_type, + typeof(compute_mode), + inplace, + cond, + !iszero(naugmented), + !iszero(steer_rate), + !iszero(λ₁), + !iszero(λ₂), + !iszero(λ₃), + typeof(nn), + typeof(nvars), + typeof(device), + typeof(basedist), + typeof(tspan), + typeof(steerdist), + typeof(epsdist), + typeof(sol_kwargs), + typeof(rng), + }( + nn, + nvars, + naugmented, + compute_mode, + device, + basedist, + tspan, + steerdist, + epsdist, + sol_kwargs, + rng, + λ₁, + λ₂, + λ₃, + ) +end + function n_augment(::ICNF, ::Mode) return 2 end diff --git a/src/core/types.jl b/src/core/types.jl index 2fcba0cb..4ba974b3 100644 --- a/src/core/types.jl +++ b/src/core/types.jl @@ -3,7 +3,7 @@ struct TestMode{REG} <: Mode{REG} end struct TrainMode{REG} <: Mode{REG} end function TestMode() - return TestMode{false}() + return TestMode{true}() end function TrainMode() diff --git a/src/exts/dist_ext/core.jl b/src/exts/dist_ext/core.jl index a22171f4..97825690 100644 --- a/src/exts/dist_ext/core.jl +++ b/src/exts/dist_ext/core.jl @@ -4,7 +4,7 @@ abstract type ICNFDistribution{AICNF <: AbstractICNF} <: Distributions.ContinuousMultivariateDistribution end function Base.length(d::ICNFDistribution) - return d.m.nvars + return d.icnf.nvars end function Base.eltype(::ICNFDistribution{AICNF}) where {AICNF <: AbstractICNF} diff --git a/src/exts/dist_ext/core_cond_icnf.jl b/src/exts/dist_ext/core_cond_icnf.jl index 7043fde3..759cc989 100644 --- a/src/exts/dist_ext/core_cond_icnf.jl +++ b/src/exts/dist_ext/core_cond_icnf.jl @@ -1,5 +1,5 @@ struct CondICNFDist{AICNF <: AbstractICNF} <: ICNFDistribution{AICNF} - m::AICNF + icnf::AICNF mode::Mode ys::AbstractVecOrMat{<:Real} ps::Any @@ -12,14 +12,14 @@ function CondICNFDist( ys::AbstractVecOrMat{<:Real}, ) (ps, st) = MLJModelInterface.fitted_params(mach) - return CondICNFDist(mach.model.m, mode, ys, ps, st) + return CondICNFDist(mach.model.icnf, mode, ys, ps, st) end function Distributions._logpdf( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, x::AbstractVector{<:Real}, ) - return first(inference(d.m, d.mode, x, d.ys, d.ps, d.st)) + return first(inference(d.icnf, d.mode, x, d.ys, d.ps, d.st)) end function Distributions._logpdf( @@ -42,7 +42,7 @@ function Distributions._logpdf( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, A::AbstractMatrix{<:Real}, ) - return first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) + return first(inference(d.icnf, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) end function Distributions._rand!( @@ -50,7 +50,7 @@ function Distributions._rand!( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, x::AbstractVector{<:Real}, ) - return x .= generate(d.m, d.mode, d.ys, d.ps, d.st) + return x .= generate(d.icnf, d.mode, d.ys, d.ps, d.st) end function Distributions._rand!( @@ -76,5 +76,5 @@ function Distributions._rand!( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, A::AbstractMatrix{<:Real}, ) - return A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) + return A .= generate(d.icnf, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) end diff --git a/src/exts/dist_ext/core_icnf.jl b/src/exts/dist_ext/core_icnf.jl index 9f47a6da..21a3edf4 100644 --- a/src/exts/dist_ext/core_icnf.jl +++ b/src/exts/dist_ext/core_icnf.jl @@ -1,5 +1,5 @@ struct ICNFDist{AICNF <: AbstractICNF} <: ICNFDistribution{AICNF} - m::AICNF + icnf::AICNF mode::Mode ps::Any st::NamedTuple @@ -7,14 +7,14 @@ end function ICNFDist(mach::MLJBase.Machine{<:ICNFModel}, mode::Mode) (ps, st) = MLJModelInterface.fitted_params(mach) - return ICNFDist(mach.model.m, mode, ps, st) + return ICNFDist(mach.model.icnf, mode, ps, st) end function Distributions._logpdf( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, x::AbstractVector{<:Real}, ) - return first(inference(d.m, d.mode, x, d.ps, d.st)) + return first(inference(d.icnf, d.mode, x, d.ps, d.st)) end function Distributions._logpdf( @@ -37,7 +37,7 @@ function Distributions._logpdf( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, A::AbstractMatrix{<:Real}, ) - return first(inference(d.m, d.mode, A, d.ps, d.st)) + return first(inference(d.icnf, d.mode, A, d.ps, d.st)) end function Distributions._rand!( @@ -45,7 +45,7 @@ function Distributions._rand!( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, x::AbstractVector{<:Real}, ) - return x .= generate(d.m, d.mode, d.ps, d.st) + return x .= generate(d.icnf, d.mode, d.ps, d.st) end function Distributions._rand!( @@ -71,5 +71,5 @@ function Distributions._rand!( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, A::AbstractMatrix{<:Real}, ) - return A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2)) + return A .= generate(d.icnf, d.mode, d.ps, d.st, size(A, 2)) end diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 3330f741..a6de49a9 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -1,5 +1,5 @@ mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} - m::AICNF + icnf::AICNF loss::Function optimizers::Tuple @@ -9,32 +9,34 @@ mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} sol_kwargs::NamedTuple end -function CondICNFModel( - m::AbstractICNF, - loss::Function = loss; - optimizers::Tuple = (Optimisers.Adam(),), +function CondICNFModel(; + icnf::AbstractICNF, + loss::Function = loss, + optimizers::Tuple = ( + Optimisers.OptimiserChain(Optimisers.WeightDecay(), Optimisers.Adam()), + ), adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), - batchsize::Int = 32, - sol_kwargs::NamedTuple = (;), + batchsize::Int = 1024, + sol_kwargs::NamedTuple = (; epochs = 300, progress = true), ) - return CondICNFModel(m, loss, optimizers, adtype, batchsize, sol_kwargs) + return CondICNFModel(icnf, loss, optimizers, adtype, batchsize, sol_kwargs) end function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) X, Y = XY x = collect(transpose(MLJModelInterface.matrix(X))) y = collect(transpose(MLJModelInterface.matrix(Y))) - ps, st = LuxCore.setup(model.m.rng, model.m) + ps, st = LuxCore.setup(model.icnf.rng, model.icnf) ps = ComponentArrays.ComponentArray(ps) - x = model.m.device(x) - y = model.m.device(y) - ps = model.m.device(ps) - st = model.m.device(st) - data = make_dataloader(model.m, model.batchsize, (x, y)) - data = model.m.device(data) + x = model.icnf.device(x) + y = model.icnf.device(y) + ps = model.icnf.device(ps) + st = model.icnf.device(st) + data = make_dataloader(model.icnf, model.batchsize, (x, y)) + data = model.icnf.device(data) optprob = SciMLBase.OptimizationProblem{true}( SciMLBase.OptimizationFunction{true}( - make_opt_loss(model.m, TrainMode{true}(), st, model.loss), + make_opt_loss(model.icnf, TrainMode{true}(), st, model.loss), model.adtype, ), ps, @@ -58,21 +60,21 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) Xnew, Ynew = XYnew xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) ynew = collect(transpose(MLJModelInterface.matrix(Ynew))) - xnew = model.m.device(xnew) - ynew = model.m.device(ynew) + xnew = model.icnf.device(xnew) + ynew = model.icnf.device(ynew) (ps, st) = fitresult - logp̂x = if model.m.compute_mode isa VectorMode + logp̂x = if model.icnf.compute_mode isa VectorMode @warn maxlog = 1 "to compute by vectors, data should be a vector." broadcast( function (x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) - return first(inference(model.m, TestMode{false}(), x, y, ps, st)) + return first(inference(model.icnf, TestMode{false}(), x, y, ps, st)) end, collect(collect.(eachcol(xnew))), collect(collect.(eachcol(ynew))), ) - elseif model.m.compute_mode isa MatrixMode - first(inference(model.m, TestMode{false}(), xnew, ynew, ps, st)) + elseif model.icnf.compute_mode isa MatrixMode + first(inference(model.icnf, TestMode{false}(), xnew, ynew, ps, st)) else error("Not Implemented") end diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 7e1db2f9..0b39446f 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -1,5 +1,5 @@ mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} - m::AICNF + icnf::AICNF loss::Function optimizers::Tuple @@ -9,29 +9,31 @@ mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} sol_kwargs::NamedTuple end -function ICNFModel( - m::AbstractICNF, - loss::Function = loss; - optimizers::Tuple = (Optimisers.Adam(),), +function ICNFModel(; + icnf::AbstractICNF, + loss::Function = loss, + optimizers::Tuple = ( + Optimisers.OptimiserChain(Optimisers.WeightDecay(), Optimisers.Adam()), + ), adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), - batchsize::Int = 32, - sol_kwargs::NamedTuple = (;), + batchsize::Int = 1024, + sol_kwargs::NamedTuple = (; epochs = 300, progress = true), ) - return ICNFModel(m, loss, optimizers, adtype, batchsize, sol_kwargs) + return ICNFModel(icnf, loss, optimizers, adtype, batchsize, sol_kwargs) end function MLJModelInterface.fit(model::ICNFModel, verbosity, X) x = collect(transpose(MLJModelInterface.matrix(X))) - ps, st = LuxCore.setup(model.m.rng, model.m) + ps, st = LuxCore.setup(model.icnf.rng, model.icnf) ps = ComponentArrays.ComponentArray(ps) - x = model.m.device(x) - ps = model.m.device(ps) - st = model.m.device(st) - data = make_dataloader(model.m, model.batchsize, (x,)) - data = model.m.device(data) + x = model.icnf.device(x) + ps = model.icnf.device(ps) + st = model.icnf.device(st) + data = make_dataloader(model.icnf, model.batchsize, (x,)) + data = model.icnf.device(data) optprob = SciMLBase.OptimizationProblem{true}( SciMLBase.OptimizationFunction{true}( - make_opt_loss(model.m, TrainMode{true}(), st, model.loss), + make_opt_loss(model.icnf, TrainMode{true}(), st, model.loss), model.adtype, ), ps, @@ -53,19 +55,19 @@ end function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) - xnew = model.m.device(xnew) + xnew = model.icnf.device(xnew) (ps, st) = fitresult - logp̂x = if model.m.compute_mode isa VectorMode + logp̂x = if model.icnf.compute_mode isa VectorMode @warn maxlog = 1 "to compute by vectors, data should be a vector." broadcast( function (x::AbstractVector{<:Real}) - return first(inference(model.m, TestMode{false}(), x, ps, st)) + return first(inference(model.icnf, TestMode{false}(), x, ps, st)) end, collect(collect.(eachcol(xnew))), ) - elseif model.m.compute_mode isa MatrixMode - first(inference(model.m, TestMode{false}(), xnew, ps, st)) + elseif model.icnf.compute_mode isa MatrixMode + first(inference(model.icnf, TestMode{false}(), xnew, ps, st)) else error("Not Implemented") end diff --git a/test/Project.toml b/test/Project.toml index dd245cf4..ee838c79 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,8 +14,6 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7" -SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -36,8 +34,6 @@ Lux = "1" LuxCore = "1" MLDataDevices = "1" MLJBase = "1" -OrdinaryDiffEqDefault = "1" -SciMLSensitivity = "7" StableRNGs = "1" Test = "1" Zygote = "0.7" diff --git a/test/ci_tests/regression_tests.jl b/test/ci_tests/regression_tests.jl index 630bf028..00ef0e43 100644 --- a/test/ci_tests/regression_tests.jl +++ b/test/ci_tests/regression_tests.jl @@ -15,21 +15,11 @@ Test.@testset verbose = true showtiming = true failfast = false "Regression Test Lux.Dense((2 * n_in + 1) => n_in, tanh), ) - icnf = ContinuousNormalizingFlows.construct(; - nvars, - naugmented = naugs, - nn, - compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), - steer_rate = 1.0f-1, - λ₁ = 1.0f-2, - λ₂ = 1.0f-2, - λ₃ = 1.0f-2, - rng, - ) + icnf = ContinuousNormalizingFlows.ICNF(; nvars, naugmented = naugs, nn, rng) df = DataFrames.DataFrame(transpose(r), :auto) - model = ContinuousNormalizingFlows.ICNFModel( - icnf; + model = ContinuousNormalizingFlows.ICNFModel(; + icnf, batchsize = 0, sol_kwargs = (; epochs = 300), ) diff --git a/test/ci_tests/smoke_tests.jl b/test/ci_tests/smoke_tests.jl index afcb54c5..d945afd5 100644 --- a/test/ci_tests/smoke_tests.jl +++ b/test/ci_tests/smoke_tests.jl @@ -119,7 +119,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be Lux.Chain(Lux.Dense(nvars * 2 + 1 => nvars * 2 + 1, tanh)), ), ) - icnf = ContinuousNormalizingFlows.construct(; + icnf = ContinuousNormalizingFlows.ICNF(; nvars, naugmented = nvars + 1, nn, @@ -128,10 +128,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be inplace, cond, device, - steer_rate = convert(data_type, 1.0e-1), - λ₁ = convert(data_type, 1.0e-2), - λ₂ = convert(data_type, 1.0e-2), - λ₃ = convert(data_type, 1.0e-2), ) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) @@ -219,8 +215,8 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ) if cond - model = ContinuousNormalizingFlows.CondICNFModel( - icnf; + model = ContinuousNormalizingFlows.CondICNFModel(; + icnf, adtype, batchsize = 0, sol_kwargs = (; epochs = 2), @@ -240,8 +236,8 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2), ) broken = compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} else - model = ContinuousNormalizingFlows.ICNFModel( - icnf; + model = ContinuousNormalizingFlows.ICNFModel(; + icnf, adtype, batchsize = 0, sol_kwargs = (; epochs = 2), diff --git a/test/ci_tests/speed_tests.jl b/test/ci_tests/speed_tests.jl index 7a48c1a7..3c14ecf5 100644 --- a/test/ci_tests/speed_tests.jl +++ b/test/ci_tests/speed_tests.jl @@ -50,22 +50,17 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be Lux.Dense((2 * n_in + 1) => n_in, tanh), ) - icnf = ContinuousNormalizingFlows.construct(; + icnf = ContinuousNormalizingFlows.ICNF(; nvars, naugmented = naugs, nn, compute_mode, - tspan = (0.0f0, 1.0f0), - steer_rate = 1.0f-1, - λ₁ = 1.0f-2, - λ₂ = 1.0f-2, - λ₃ = 1.0f-2, rng, ) df = DataFrames.DataFrame(transpose(r), :auto) - model = ContinuousNormalizingFlows.ICNFModel( - icnf; + model = ContinuousNormalizingFlows.ICNFModel(; + icnf, batchsize = 0, sol_kwargs = (; epochs = 5), ) diff --git a/test/quality_tests/checkby_JET_tests.jl b/test/quality_tests/checkby_JET_tests.jl index 7b7a5c8d..91eb7600 100644 --- a/test/quality_tests/checkby_JET_tests.jl +++ b/test/quality_tests/checkby_JET_tests.jl @@ -102,7 +102,7 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg Lux.Chain(Lux.Dense(nvars * 2 + 1 => nvars * 2 + 1, tanh)), ), ) - icnf = ContinuousNormalizingFlows.construct(; + icnf = ContinuousNormalizingFlows.ICNF(; nvars, naugmented = nvars + 1, nn, @@ -111,10 +111,6 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg inplace, cond, device, - steer_rate = convert(data_type, 1.0e-1), - λ₁ = convert(data_type, 1.0e-2), - λ₂ = convert(data_type, 1.0e-2), - λ₃ = convert(data_type, 1.0e-2), ) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) diff --git a/test/runtests.jl b/test/runtests.jl index 0a4ba95d..604ccb44 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,8 +13,6 @@ import ADTypes, LuxCore, MLDataDevices, MLJBase, - OrdinaryDiffEqDefault, - SciMLSensitivity, StableRNGs, Test, Zygote, From 787d5bd808cdf53ba6ba7b27838b1702382c728e Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 19:47:16 +0330 Subject: [PATCH 08/14] fix maxlog --- src/exts/dist_ext/core_cond_icnf.jl | 8 ++++---- src/exts/dist_ext/core_icnf.jl | 8 ++++---- src/exts/mlj_ext/core_cond_icnf.jl | 2 +- src/exts/mlj_ext/core_icnf.jl | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/exts/dist_ext/core_cond_icnf.jl b/src/exts/dist_ext/core_cond_icnf.jl index 759cc989..6d3379de 100644 --- a/src/exts/dist_ext/core_cond_icnf.jl +++ b/src/exts/dist_ext/core_cond_icnf.jl @@ -26,7 +26,7 @@ function Distributions._logpdf( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, x::AbstractVector{<:Real}, ) - @warn maxlog = 1 "to compute by matrices, data should be a matrix." + @warn "to compute by matrices, data should be a matrix." maxlog = 1 return first(Distributions._logpdf(d, hcat(x))) end @@ -34,7 +34,7 @@ function Distributions._logpdf( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, A::AbstractMatrix{<:Real}, ) - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 return Distributions._logpdf.(d, collect(collect.(eachcol(A)))) end @@ -58,7 +58,7 @@ function Distributions._rand!( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, x::AbstractVector{<:Real}, ) - @warn maxlog = 1 "to compute by matrices, data should be a matrix." + @warn "to compute by matrices, data should be a matrix." maxlog = 1 return x .= Distributions._rand!(rng, d, hcat(x)) end @@ -67,7 +67,7 @@ function Distributions._rand!( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, A::AbstractMatrix{<:Real}, ) - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 return A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...) end diff --git a/src/exts/dist_ext/core_icnf.jl b/src/exts/dist_ext/core_icnf.jl index 21a3edf4..51d0c3b6 100644 --- a/src/exts/dist_ext/core_icnf.jl +++ b/src/exts/dist_ext/core_icnf.jl @@ -21,7 +21,7 @@ function Distributions._logpdf( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, x::AbstractVector{<:Real}, ) - @warn maxlog = 1 "to compute by matrices, data should be a matrix." + @warn "to compute by matrices, data should be a matrix." maxlog = 1 return first(Distributions._logpdf(d, hcat(x))) end @@ -29,7 +29,7 @@ function Distributions._logpdf( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, A::AbstractMatrix{<:Real}, ) - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 return Distributions._logpdf.(d, collect(collect.(eachcol(A)))) end @@ -53,7 +53,7 @@ function Distributions._rand!( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, x::AbstractVector{<:Real}, ) - @warn maxlog = 1 "to compute by matrices, data should be a matrix." + @warn "to compute by matrices, data should be a matrix." maxlog = 1 return x .= Distributions._rand!(rng, d, hcat(x)) end @@ -62,7 +62,7 @@ function Distributions._rand!( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, A::AbstractMatrix{<:Real}, ) - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 return A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...) end diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index a6de49a9..84f694e0 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -65,7 +65,7 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) (ps, st) = fitresult logp̂x = if model.icnf.compute_mode isa VectorMode - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 broadcast( function (x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) return first(inference(model.icnf, TestMode{false}(), x, y, ps, st)) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 0b39446f..29ab44ee 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -59,7 +59,7 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) (ps, st) = fitresult logp̂x = if model.icnf.compute_mode isa VectorMode - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 broadcast( function (x::AbstractVector{<:Real}) return first(inference(model.icnf, TestMode{false}(), x, ps, st)) From 34f4b022536a25227f0428bd0edadaadf544b196 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 20:56:00 +0330 Subject: [PATCH 09/14] fix --- src/core/icnf.jl | 5 ++++- src/exts/mlj_ext/core_cond_icnf.jl | 2 +- src/exts/mlj_ext/core_icnf.jl | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 71c91a9d..d3a10dc8 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -77,7 +77,10 @@ function ICNF(; abstol = convert(data_type, 1.0e-8), maxiters = typemax(Int), alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()), - sensealg = SciMLSensitivity.GaussAdjoint(; autodiff = true, checkpointing = true), + sensealg = SciMLSensitivity.InterpolatingAdjoint(; + autodiff = true, + checkpointing = true, + ), ), rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), λ₁::AbstractFloat = convert(data_type, 1.0e-2), diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 84f694e0..77ac5177 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -10,7 +10,7 @@ mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} end function CondICNFModel(; - icnf::AbstractICNF, + icnf::AbstractICNF = ICNF(), loss::Function = loss, optimizers::Tuple = ( Optimisers.OptimiserChain(Optimisers.WeightDecay(), Optimisers.Adam()), diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 29ab44ee..3cbeced3 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -10,7 +10,7 @@ mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} end function ICNFModel(; - icnf::AbstractICNF, + icnf::AbstractICNF = ICNF(), loss::Function = loss, optimizers::Tuple = ( Optimisers.OptimiserChain(Optimisers.WeightDecay(), Optimisers.Adam()), From 9418f6854d096b07e74c8c18ba30229d22ac8077 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 20:56:24 +0330 Subject: [PATCH 10/14] fix --- examples/usage.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/usage.jl b/examples/usage.jl index e786f637..8f5b0388 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -52,7 +52,7 @@ icnf = ICNF(; abstol = 1.0f-8, maxiters = typemax(Int), alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = True()), - sensealg = GaussAdjoint(; autodiff = true, checkpointing = true), + sensealg = InterpolatingAdjoint(; autodiff = true, checkpointing = true), ), # pass to the solver ) From a1b20b6331855fca8c2bc1fb87c08292c76bd32b Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 22:37:15 +0330 Subject: [PATCH 11/14] remove broken --- test/ci_tests/smoke_tests.jl | 45 ++++++++++-------------------------- 1 file changed, 12 insertions(+), 33 deletions(-) diff --git a/test/ci_tests/smoke_tests.jl b/test/ci_tests/smoke_tests.jl index d945afd5..28dd7a4e 100644 --- a/test/ci_tests/smoke_tests.jl +++ b/test/ci_tests/smoke_tests.jl @@ -199,20 +199,8 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && ( - omode isa ContinuousNormalizingFlows.TrainMode || ( - omode isa ContinuousNormalizingFlows.TestMode && - compute_mode isa ContinuousNormalizingFlows.VectorMode - ) - ) - Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && ( - omode isa ContinuousNormalizingFlows.TrainMode || ( - omode isa ContinuousNormalizingFlows.TestMode && - compute_mode isa ContinuousNormalizingFlows.VectorMode - ) - ) + Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) + Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) if cond model = ContinuousNormalizingFlows.CondICNFModel(; @@ -223,18 +211,14 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ) mach = MLJBase.machine(model, (df, df2)) - Test.@test !isnothing(MLJBase.fit!(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.serializable(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fit!(mach)) + Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) + Test.@test !isnothing(MLJBase.fitted_params(mach)) + Test.@test !isnothing(MLJBase.serializable(mach)) Test.@test !isnothing( ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2), - ) broken = compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + ) else model = ContinuousNormalizingFlows.ICNFModel(; icnf, @@ -244,17 +228,12 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.transform(mach, df)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.serializable(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fit!(mach)) + Test.@test !isnothing(MLJBase.transform(mach, df)) + Test.@test !isnothing(MLJBase.fitted_params(mach)) + Test.@test !isnothing(MLJBase.serializable(mach)) - Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) end end end From d633c3a3dd6ba44d5138ee1903fd4b3925065444 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 14 Feb 2026 23:45:30 +0330 Subject: [PATCH 12/14] cleaning --- examples/usage.jl | 2 +- src/core/icnf.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/usage.jl b/examples/usage.jl index 8f5b0388..d9ad60c0 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -52,7 +52,7 @@ icnf = ICNF(; abstol = 1.0f-8, maxiters = typemax(Int), alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = True()), - sensealg = InterpolatingAdjoint(; autodiff = true, checkpointing = true), + sensealg = InterpolatingAdjoint(; checkpointing = true, autodiff = true), ), # pass to the solver ) diff --git a/src/core/icnf.jl b/src/core/icnf.jl index d3a10dc8..a9ec0506 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -78,8 +78,8 @@ function ICNF(; maxiters = typemax(Int), alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()), sensealg = SciMLSensitivity.InterpolatingAdjoint(; - autodiff = true, checkpointing = true, + autodiff = true, ), ), rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), From b92e959f7c30906522157d84f9ee70ed728ed933 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 15 Feb 2026 17:58:59 +0330 Subject: [PATCH 13/14] more cleaning --- benchmark/benchmarks.jl | 4 +- examples/Project.toml | 17 ++++++ examples/usage.jl | 31 ++++------ src/core/icnf.jl | 80 ++++++++++++------------- src/exts/mlj_ext/core_cond_icnf.jl | 8 +-- src/exts/mlj_ext/core_icnf.jl | 8 +-- test/ci_tests/regression_tests.jl | 2 +- test/ci_tests/smoke_tests.jl | 14 ++--- test/ci_tests/speed_tests.jl | 4 +- test/quality_tests/checkby_JET_tests.jl | 10 ++-- 10 files changed, 89 insertions(+), 89 deletions(-) create mode 100644 examples/Project.toml diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index fc19a444..37312736 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -27,10 +27,10 @@ nn = Lux.Chain( Lux.Dense((2 * n_in + 1) => n_in, tanh), ) -icnf = ContinuousNormalizingFlows.ICNF(; nvars, naugmented = naugs, nn, rng) +icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng) icnf2 = - ContinuousNormalizingFlows.ICNF(; nvars, naugmented = naugs, nn, inplace = true, rng) + ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng, inplace = true) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 00000000..35a4f85d --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,17 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +ContinuousNormalizingFlows = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/usage.jl b/examples/usage.jl index d9ad60c0..f19c349c 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -1,5 +1,5 @@ # Switch To MKL For Faster Computation -# using MKL +using MKL ## Enable Logging using Logging, TerminalLoggers @@ -20,39 +20,32 @@ n_in = nvars + naugs ## Model using ContinuousNormalizingFlows, - Lux, - OrdinaryDiffEqAdamsBashforthMoulton, - SciMLSensitivity, - Static, - ADTypes, - Zygote, - MLDataDevices + Lux, OrdinaryDiffEqAdamsBashforthMoulton, Static, ADTypes, Zygote, MLDataDevices # To use gpu, add related packages -# using LuxCUDA, CUDA, cuDNN +# using LuxCUDA nn = Chain(Dense(n_in => (2 * n_in + 1), tanh), Dense((2 * n_in + 1) => n_in, tanh)) icnf = ICNF(; + nn = nn, nvars = nvars, # number of variables naugmented = naugs, # number of augmented dimensions - nn = nn, - compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote - inplace = false, # not using the inplace version of functions - cond = false, # not conditioning on auxiliary input - device = cpu_device(), # process data by CPU - # device = gpu_device(), # process data by GPU - tspan = (0.0f0, 1.0f0), # time span - steer_rate = 1.0f-1, # add random noise to end of the time span λ₁ = 1.0f-2, # regulate flow λ₂ = 1.0f-2, # regulate volume change λ₃ = 1.0f-2, # regulate augmented dimensions + steer_rate = 1.0f-1, # add random noise to end of the time span + tspan = (0.0f0, 1.0f0), # time span + device = cpu_device(), # process data by CPU + # device = gpu_device(), # process data by GPU + cond = false, # not conditioning on auxiliary input + inplace = false, # not using the inplace version of functions + compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote sol_kwargs = (; save_everystep = false, reltol = 1.0f-4, abstol = 1.0f-8, maxiters = typemax(Int), alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = True()), - sensealg = InterpolatingAdjoint(; checkpointing = true, autodiff = true), ), # pass to the solver ) @@ -67,8 +60,8 @@ else model = ICNFModel(; icnf, optimizers = (OptimiserChain(WeightDecay(), Adam()),), - adtype = AutoZygote(), batchsize = 1024, + adtype = AutoZygote(), sol_kwargs = (; epochs = 300, progress = true), # pass to the solver ) mach = machine(model, df) diff --git a/src/core/icnf.jl b/src/core/icnf.jl index a9ec0506..61933bb9 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -23,50 +23,53 @@ struct ICNF{ NORM_Z, NORM_J, NORM_Z_AUG, - NN <: LuxCore.AbstractLuxLayer, - NVARS <: Int, DEVICE <: MLDataDevices.AbstractDevice, - BASEDIST <: Distributions.Distribution, + RNG <: Random.AbstractRNG, TSPAN <: NTuple{2, T}, - STEERDIST <: Distributions.Distribution, + NVARS <: Int, + NN <: LuxCore.AbstractLuxLayer, + BASEDIST <: Distributions.Distribution, EPSDIST <: Distributions.Distribution, + STEERDIST <: Distributions.Distribution, SOL_KWARGS <: NamedTuple, - RNG <: Random.AbstractRNG, } <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} - nn::NN - nvars::NVARS - naugmented::NVARS - compute_mode::CM device::DEVICE - basedist::BASEDIST - tspan::TSPAN - steerdist::STEERDIST - epsdist::EPSDIST - sol_kwargs::SOL_KWARGS rng::RNG + tspan::TSPAN + nvars::NVARS + naugmented::NVARS + nn::NN λ₁::T λ₂::T λ₃::T + basedist::BASEDIST + epsdist::EPSDIST + steerdist::STEERDIST + sol_kwargs::SOL_KWARGS end function ICNF(; - nvars::Int = 1, - naugmented::Int = nvars + 1, - nn::LuxCore.AbstractLuxLayer = Lux.Chain( - Lux.Dense((nvars + naugmented) => (nvars + naugmented), tanh), - ), data_type::Type{<:AbstractFloat} = Float32, compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()), inplace::Bool = false, cond::Bool = false, device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(), + rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), + tspan::NTuple{2} = (zero(data_type), one(data_type)), + nvars::Int = 1, + naugmented::Int = nvars + 1, + nn::LuxCore.AbstractLuxLayer = Lux.Chain( + Lux.Dense((nvars + naugmented) => (nvars + naugmented), tanh), + ), + steer_rate::AbstractFloat = convert(data_type, 1.0e-1), + λ₁::AbstractFloat = convert(data_type, 1.0e-2), + λ₂::AbstractFloat = convert(data_type, 1.0e-2), + λ₃::AbstractFloat = convert(data_type, 1.0e-2), basedist::Distributions.Distribution = Distributions.MvNormal( FillArrays.Zeros{data_type}(nvars + naugmented), FillArrays.Eye{data_type}(nvars + naugmented), ), - tspan::NTuple{2} = (zero(data_type), one(data_type)), - steer_rate::AbstractFloat = convert(data_type, 1.0e-1), epsdist::Distributions.Distribution = Distributions.MvNormal( FillArrays.Zeros{data_type}(nvars + naugmented), FillArrays.Eye{data_type}(nvars + naugmented), @@ -77,18 +80,9 @@ function ICNF(; abstol = convert(data_type, 1.0e-8), maxiters = typemax(Int), alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()), - sensealg = SciMLSensitivity.InterpolatingAdjoint(; - checkpointing = true, - autodiff = true, - ), ), - rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), - λ₁::AbstractFloat = convert(data_type, 1.0e-2), - λ₂::AbstractFloat = convert(data_type, 1.0e-2), - λ₃::AbstractFloat = convert(data_type, 1.0e-2), ) steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) - return ICNF{ data_type, typeof(compute_mode), @@ -99,30 +93,30 @@ function ICNF(; !iszero(λ₁), !iszero(λ₂), !iszero(λ₃), - typeof(nn), - typeof(nvars), typeof(device), - typeof(basedist), + typeof(rng), typeof(tspan), - typeof(steerdist), + typeof(nvars), + typeof(nn), + typeof(basedist), typeof(epsdist), + typeof(steerdist), typeof(sol_kwargs), - typeof(rng), }( - nn, - nvars, - naugmented, compute_mode, device, - basedist, - tspan, - steerdist, - epsdist, - sol_kwargs, rng, + tspan, + nvars, + naugmented, + nn, λ₁, λ₂, λ₃, + basedist, + epsdist, + steerdist, + sol_kwargs, ) end diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 77ac5177..bb66ec83 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -1,11 +1,9 @@ mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} icnf::AICNF loss::Function - optimizers::Tuple - adtype::ADTypes.AbstractADType - batchsize::Int + adtype::ADTypes.AbstractADType sol_kwargs::NamedTuple end @@ -15,11 +13,11 @@ function CondICNFModel(; optimizers::Tuple = ( Optimisers.OptimiserChain(Optimisers.WeightDecay(), Optimisers.Adam()), ), - adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), batchsize::Int = 1024, + adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), sol_kwargs::NamedTuple = (; epochs = 300, progress = true), ) - return CondICNFModel(icnf, loss, optimizers, adtype, batchsize, sol_kwargs) + return CondICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) end function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 3cbeced3..97447d97 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -1,11 +1,9 @@ mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} icnf::AICNF loss::Function - optimizers::Tuple - adtype::ADTypes.AbstractADType - batchsize::Int + adtype::ADTypes.AbstractADType sol_kwargs::NamedTuple end @@ -15,11 +13,11 @@ function ICNFModel(; optimizers::Tuple = ( Optimisers.OptimiserChain(Optimisers.WeightDecay(), Optimisers.Adam()), ), - adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), batchsize::Int = 1024, + adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), sol_kwargs::NamedTuple = (; epochs = 300, progress = true), ) - return ICNFModel(icnf, loss, optimizers, adtype, batchsize, sol_kwargs) + return ICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) end function MLJModelInterface.fit(model::ICNFModel, verbosity, X) diff --git a/test/ci_tests/regression_tests.jl b/test/ci_tests/regression_tests.jl index 00ef0e43..f4ea9a24 100644 --- a/test/ci_tests/regression_tests.jl +++ b/test/ci_tests/regression_tests.jl @@ -15,7 +15,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Regression Test Lux.Dense((2 * n_in + 1) => n_in, tanh), ) - icnf = ContinuousNormalizingFlows.ICNF(; nvars, naugmented = naugs, nn, rng) + icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng) df = DataFrames.DataFrame(transpose(r), :auto) model = ContinuousNormalizingFlows.ICNFModel(; diff --git a/test/ci_tests/smoke_tests.jl b/test/ci_tests/smoke_tests.jl index 28dd7a4e..b96ba2dc 100644 --- a/test/ci_tests/smoke_tests.jl +++ b/test/ci_tests/smoke_tests.jl @@ -120,14 +120,14 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ), ) icnf = ContinuousNormalizingFlows.ICNF(; + nn, nvars, naugmented = nvars + 1, - nn, - data_type, - compute_mode, - inplace, - cond, device, + cond, + inplace, + compute_mode, + data_type, ) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) @@ -205,8 +205,8 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be if cond model = ContinuousNormalizingFlows.CondICNFModel(; icnf, - adtype, batchsize = 0, + adtype, sol_kwargs = (; epochs = 2), ) mach = MLJBase.machine(model, (df, df2)) @@ -222,8 +222,8 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be else model = ContinuousNormalizingFlows.ICNFModel(; icnf, - adtype, batchsize = 0, + adtype, sol_kwargs = (; epochs = 2), ) mach = MLJBase.machine(model, df) diff --git a/test/ci_tests/speed_tests.jl b/test/ci_tests/speed_tests.jl index 3c14ecf5..86491d87 100644 --- a/test/ci_tests/speed_tests.jl +++ b/test/ci_tests/speed_tests.jl @@ -51,11 +51,11 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be ) icnf = ContinuousNormalizingFlows.ICNF(; + nn, nvars, naugmented = naugs, - nn, - compute_mode, rng, + compute_mode, ) df = DataFrames.DataFrame(transpose(r), :auto) diff --git a/test/quality_tests/checkby_JET_tests.jl b/test/quality_tests/checkby_JET_tests.jl index 91eb7600..d5cca152 100644 --- a/test/quality_tests/checkby_JET_tests.jl +++ b/test/quality_tests/checkby_JET_tests.jl @@ -103,14 +103,14 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg ), ) icnf = ContinuousNormalizingFlows.ICNF(; + nn, nvars, naugmented = nvars + 1, - nn, - data_type, - compute_mode, - inplace, - cond, device, + cond, + inplace, + compute_mode, + data_type, ) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) From b7c35965cda2fc6c9455ac0867430261e3588730 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 15 Feb 2026 19:11:35 +0330 Subject: [PATCH 14/14] revert sensealg --- examples/Project.toml | 1 + examples/usage.jl | 13 ++++++++++--- src/core/icnf.jl | 6 +++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/examples/Project.toml b/examples/Project.toml index 35a4f85d..7d108e04 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -12,6 +12,7 @@ MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/usage.jl b/examples/usage.jl index f19c349c..27312bfb 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -20,7 +20,13 @@ n_in = nvars + naugs ## Model using ContinuousNormalizingFlows, - Lux, OrdinaryDiffEqAdamsBashforthMoulton, Static, ADTypes, Zygote, MLDataDevices + Lux, + OrdinaryDiffEqAdamsBashforthMoulton, + Static, + SciMLSensitivity, + ADTypes, + Zygote, + MLDataDevices # To use gpu, add related packages # using LuxCUDA @@ -42,10 +48,11 @@ icnf = ICNF(; compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote sol_kwargs = (; save_everystep = false, + maxiters = typemax(Int), reltol = 1.0f-4, abstol = 1.0f-8, - maxiters = typemax(Int), - alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = True()), + alg = VCABM(; thread = True()), + sensealg = InterpolatingAdjoint(; checkpointing = true, autodiff = true), ), # pass to the solver ) diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 61933bb9..7391b984 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -76,10 +76,14 @@ function ICNF(; ), sol_kwargs::NamedTuple = (; save_everystep = false, + maxiters = typemax(Int), reltol = convert(data_type, 1.0e-4), abstol = convert(data_type, 1.0e-8), - maxiters = typemax(Int), alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()), + sensealg = SciMLSensitivity.InterpolatingAdjoint(; + checkpointing = true, + autodiff = true, + ), ), ) steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate)