diff --git a/.github/workflows/CI-CheckBy.yml b/.github/workflows/CI-CheckBy.yml index e03dfbeb..c84b393f 100644 --- a/.github/workflows/CI-CheckBy.yml +++ b/.github/workflows/CI-CheckBy.yml @@ -24,8 +24,11 @@ jobs: - CheckByJET - CheckByExplicitImports version: - - release - - lts + - "1.10" + - "1.11" + - "1.12" + # - release + # - lts # - nightly os: - ubuntu-latest diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 55012191..c33fca85 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,8 +27,11 @@ jobs: - Regression - Speed version: - - release - - lts + - "1.10" + - "1.11" + - "1.12" + # - release + # - lts # - nightly os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index 62530b6f..97374eac 100644 --- a/Project.toml +++ b/Project.toml @@ -23,11 +23,12 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" -OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7" +OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -52,11 +53,12 @@ MLUtils = "0.4" NNlib = "0.9" Optimisers = "0.4" OptimizationOptimisers = "0.3" -OrdinaryDiffEqDefault = "1" +OrdinaryDiffEqAdamsBashforthMoulton = "1" Random = "1" SciMLBase = "2" SciMLSensitivity = "7" ScientificTypesBase = "3" +Static = "1" Statistics = "1" WeightInitializers = "1" Zygote = "0.7" diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 7fef44ff..f35c328f 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -7,9 +7,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" -SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -22,9 +20,7 @@ Distributions = "0.25" ForwardDiff = "1" Lux = "1" LuxCore = "1" -OrdinaryDiffEqDefault = "1" PkgBenchmark = "0.2" -SciMLSensitivity = "7" StableRNGs = "1" Zygote = "0.7" julia = "1.10" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 1c65edde..37312736 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -6,9 +6,7 @@ import ADTypes, ForwardDiff, Lux, LuxCore, - OrdinaryDiffEqDefault, PkgBenchmark, - SciMLSensitivity, StableRNGs, Zygote, ContinuousNormalizingFlows @@ -21,39 +19,18 @@ r = rand(rng, data_dist, ndimension, ndata) r = convert.(Float32, r) nvars = size(r, 1) -naugs = nvars +naugs = nvars + 1 n_in = nvars + naugs -nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh)) - -icnf = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.ICNF, - nn, - nvars, - naugs; - compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), - tspan = (0.0f0, 1.0f0), - steer_rate = 1.0f-1, - λ₁ = 1.0f-2, - λ₂ = 1.0f-2, - λ₃ = 1.0f-2, - rng, +nn = Lux.Chain( + Lux.Dense(n_in => (2 * n_in + 1), tanh), + Lux.Dense((2 * n_in + 1) => n_in, tanh), ) -icnf2 = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.ICNF, - nn, - nvars, - naugs; - inplace = true, - compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), - tspan = (0.0f0, 1.0f0), - steer_rate = 1.0f-1, - λ₁ = 1.0f-2, - λ₂ = 1.0f-2, - λ₃ = 1.0f-2, - rng, -) +icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng) + +icnf2 = + ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng, inplace = true) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 00000000..7d108e04 --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,18 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +ContinuousNormalizingFlows = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/usage.jl b/examples/usage.jl index 82b09192..27312bfb 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -1,5 +1,5 @@ # Switch To MKL For Faster Computation -# using MKL +using MKL ## Enable Logging using Logging, TerminalLoggers @@ -20,47 +20,63 @@ n_in = nvars + naugs ## Model using ContinuousNormalizingFlows, - Lux, OrdinaryDiffEqAdamsBashforthMoulton, ADTypes, Zygote, MLDataDevices + Lux, + OrdinaryDiffEqAdamsBashforthMoulton, + Static, + SciMLSensitivity, + ADTypes, + Zygote, + MLDataDevices # To use gpu, add related packages -# using LuxCUDA, CUDA, cuDNN +# using LuxCUDA -nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh)) -icnf = construct( - ICNF, - nn, - nvars, # number of variables - naugs; # number of augmented dimensions - compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote - inplace = false, # not using the inplace version of functions - device = cpu_device(), # process data by CPU - # device = gpu_device(), # process data by GPU - tspan = (0.0f0, 1.0f0), # time span - steer_rate = 1.0f-1, # add random noise to end of the time span +nn = Chain(Dense(n_in => (2 * n_in + 1), tanh), Dense((2 * n_in + 1) => n_in, tanh)) +icnf = ICNF(; + nn = nn, + nvars = nvars, # number of variables + naugmented = naugs, # number of augmented dimensions λ₁ = 1.0f-2, # regulate flow λ₂ = 1.0f-2, # regulate volume change λ₃ = 1.0f-2, # regulate augmented dimensions - sol_kwargs = (; save_everystep = false, alg = VCABM()), # pass to the solver + steer_rate = 1.0f-1, # add random noise to end of the time span + tspan = (0.0f0, 1.0f0), # time span + device = cpu_device(), # process data by CPU + # device = gpu_device(), # process data by GPU + cond = false, # not conditioning on auxiliary input + inplace = false, # not using the inplace version of functions + compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote + sol_kwargs = (; + save_everystep = false, + maxiters = typemax(Int), + reltol = 1.0f-4, + abstol = 1.0f-8, + alg = VCABM(; thread = True()), + sensealg = InterpolatingAdjoint(; checkpointing = true, autodiff = true), + ), # pass to the solver ) ## Fit It using DataFrames, MLJBase, Zygote, ADTypes, OptimizationOptimisers -df = DataFrame(transpose(r), :auto) -model = ICNFModel( - icnf; - optimizers = (Adam(),), - adtype = AutoZygote(), - batchsize = 512, - sol_kwargs = (; epochs = 300, progress = true), # pass to the solver -) -mach = machine(model, df) -fit!(mach) -# CUDA.@allowscalar fit!(mach) # needed for gpu -## Store It icnf_mach_fn = "icnf_mach.jls" -MLJBase.save(icnf_mach_fn, mach) # save it -mach = machine(icnf_mach_fn) # load it +if ispath(icnf_mach_fn) + mach = machine(icnf_mach_fn) # load it +else + df = DataFrame(transpose(r), :auto) + model = ICNFModel(; + icnf, + optimizers = (OptimiserChain(WeightDecay(), Adam()),), + batchsize = 1024, + adtype = AutoZygote(), + sol_kwargs = (; epochs = 300, progress = true), # pass to the solver + ) + mach = machine(model, df) + fit!(mach) + # CUDA.@allowscalar fit!(mach) # needed for gpu + + MLJBase.save(icnf_mach_fn, mach) # save it +end ## Use It d = ICNFDist(mach, TestMode()) diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index fe57b04d..24a2082c 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -19,26 +19,20 @@ import ADTypes, NNlib, Optimisers, OptimizationOptimisers, - OrdinaryDiffEqDefault, + OrdinaryDiffEqAdamsBashforthMoulton, Random, SciMLBase, SciMLSensitivity, ScientificTypesBase, + Static, Statistics, WeightInitializers, Zygote -export construct, - inference, +export inference, generate, loss, ICNF, - RNODE, - CondRNODE, - FFJORD, - CondFFJORD, - Planar, - CondPlanar, TestMode, TrainMode, DIVecJacVectorMode, diff --git a/src/core/base_icnf.jl b/src/core/base_icnf.jl index b73de94b..3052ee82 100644 --- a/src/core/base_icnf.jl +++ b/src/core/base_icnf.jl @@ -1,80 +1,3 @@ -function construct( - aicnf::Type{<:AbstractICNF}, - nn::LuxCore.AbstractLuxLayer, - nvars::Int, - naugmented::Int = 0; - data_type::Type{<:AbstractFloat} = Float32, - compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()), - inplace::Bool = false, - cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar}, - device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(), - basedist::Distributions.Distribution = Distributions.MvNormal( - FillArrays.Zeros{data_type}(nvars + naugmented), - FillArrays.Eye{data_type}(nvars + naugmented), - ), - tspan::NTuple{2} = (zero(data_type), one(data_type)), - steer_rate::AbstractFloat = zero(data_type), - epsdist::Distributions.Distribution = Distributions.MvNormal( - FillArrays.Zeros{data_type}(nvars + naugmented), - FillArrays.Eye{data_type}(nvars + naugmented), - ), - sol_kwargs::NamedTuple = (;), - rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), - λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} - convert(data_type, 1.0e-2) - else - zero(data_type) - end, - λ₂::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} - convert(data_type, 1.0e-2) - else - zero(data_type) - end, - λ₃::AbstractFloat = if naugmented >= nvars - convert(data_type, 1.0e-2) - else - zero(data_type) - end, -) - steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) - - return ICNF{ - data_type, - typeof(compute_mode), - inplace, - cond, - !iszero(naugmented), - !iszero(steer_rate), - !iszero(λ₁), - !iszero(λ₂), - !iszero(λ₃), - typeof(nn), - typeof(nvars), - typeof(device), - typeof(basedist), - typeof(tspan), - typeof(steerdist), - typeof(epsdist), - typeof(sol_kwargs), - typeof(rng), - }( - nn, - nvars, - naugmented, - compute_mode, - device, - basedist, - tspan, - steerdist, - epsdist, - sol_kwargs, - rng, - λ₁, - λ₂, - λ₃, - ) -end - function Base.show(io::IO, icnf::AbstractICNF) return print(io, typeof(icnf)) end diff --git a/src/core/icnf.jl b/src/core/icnf.jl index 0c4622f0..7391b984 100644 --- a/src/core/icnf.jl +++ b/src/core/icnf.jl @@ -1,60 +1,3 @@ -struct Planar{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end -struct CondPlanar{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end - -struct FFJORD{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end -struct CondFFJORD{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end - -struct RNODE{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end -struct CondRNODE{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end - """ Implementation of ICNF. @@ -65,6 +8,10 @@ Refs: [Grathwohl, Will, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. "Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv preprint arXiv:1810.01367 (2018).](https://arxiv.org/abs/1810.01367) [Finlay, Chris, Jörn-Henrik Jacobsen, Levon Nurbekyan, and Adam M. Oberman. "How to train your neural ODE: the world of Jacobian and kinetic regularization." arXiv preprint arXiv:2002.02798 (2020).](https://arxiv.org/abs/2002.02798) + +[Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. "Augmented Neural ODEs." arXiv preprint arXiv:1904.01681 (2019).](https://arxiv.org/abs/1904.01681) + +[Ghosh, Arnab, Harkirat Singh Behl, Emilien Dupont, Philip HS Torr, and Vinay Namboodiri. "STEER: Simple Temporal Regularization For Neural ODEs." arXiv preprint arXiv:2006.10711 (2020).](https://arxiv.org/abs/2006.10711) """ struct ICNF{ T <: AbstractFloat, @@ -76,31 +23,105 @@ struct ICNF{ NORM_Z, NORM_J, NORM_Z_AUG, - NN <: LuxCore.AbstractLuxLayer, - NVARS <: Int, DEVICE <: MLDataDevices.AbstractDevice, - BASEDIST <: Distributions.Distribution, + RNG <: Random.AbstractRNG, TSPAN <: NTuple{2, T}, - STEERDIST <: Distributions.Distribution, + NVARS <: Int, + NN <: LuxCore.AbstractLuxLayer, + BASEDIST <: Distributions.Distribution, EPSDIST <: Distributions.Distribution, + STEERDIST <: Distributions.Distribution, SOL_KWARGS <: NamedTuple, - RNG <: Random.AbstractRNG, } <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} - nn::NN - nvars::NVARS - naugmented::NVARS - compute_mode::CM device::DEVICE - basedist::BASEDIST - tspan::TSPAN - steerdist::STEERDIST - epsdist::EPSDIST - sol_kwargs::SOL_KWARGS rng::RNG + tspan::TSPAN + nvars::NVARS + naugmented::NVARS + nn::NN λ₁::T λ₂::T λ₃::T + basedist::BASEDIST + epsdist::EPSDIST + steerdist::STEERDIST + sol_kwargs::SOL_KWARGS +end + +function ICNF(; + data_type::Type{<:AbstractFloat} = Float32, + compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()), + inplace::Bool = false, + cond::Bool = false, + device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(), + rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device), + tspan::NTuple{2} = (zero(data_type), one(data_type)), + nvars::Int = 1, + naugmented::Int = nvars + 1, + nn::LuxCore.AbstractLuxLayer = Lux.Chain( + Lux.Dense((nvars + naugmented) => (nvars + naugmented), tanh), + ), + steer_rate::AbstractFloat = convert(data_type, 1.0e-1), + λ₁::AbstractFloat = convert(data_type, 1.0e-2), + λ₂::AbstractFloat = convert(data_type, 1.0e-2), + λ₃::AbstractFloat = convert(data_type, 1.0e-2), + basedist::Distributions.Distribution = Distributions.MvNormal( + FillArrays.Zeros{data_type}(nvars + naugmented), + FillArrays.Eye{data_type}(nvars + naugmented), + ), + epsdist::Distributions.Distribution = Distributions.MvNormal( + FillArrays.Zeros{data_type}(nvars + naugmented), + FillArrays.Eye{data_type}(nvars + naugmented), + ), + sol_kwargs::NamedTuple = (; + save_everystep = false, + maxiters = typemax(Int), + reltol = convert(data_type, 1.0e-4), + abstol = convert(data_type, 1.0e-8), + alg = OrdinaryDiffEqAdamsBashforthMoulton.VCABM(; thread = Static.True()), + sensealg = SciMLSensitivity.InterpolatingAdjoint(; + checkpointing = true, + autodiff = true, + ), + ), +) + steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) + return ICNF{ + data_type, + typeof(compute_mode), + inplace, + cond, + !iszero(naugmented), + !iszero(steer_rate), + !iszero(λ₁), + !iszero(λ₂), + !iszero(λ₃), + typeof(device), + typeof(rng), + typeof(tspan), + typeof(nvars), + typeof(nn), + typeof(basedist), + typeof(epsdist), + typeof(steerdist), + typeof(sol_kwargs), + }( + compute_mode, + device, + rng, + tspan, + nvars, + naugmented, + nn, + λ₁, + λ₂, + λ₃, + basedist, + epsdist, + steerdist, + sol_kwargs, + ) end function n_augment(::ICNF, ::Mode) diff --git a/src/core/types.jl b/src/core/types.jl index 2fcba0cb..4ba974b3 100644 --- a/src/core/types.jl +++ b/src/core/types.jl @@ -3,7 +3,7 @@ struct TestMode{REG} <: Mode{REG} end struct TrainMode{REG} <: Mode{REG} end function TestMode() - return TestMode{false}() + return TestMode{true}() end function TrainMode() diff --git a/src/exts/dist_ext/core.jl b/src/exts/dist_ext/core.jl index a22171f4..97825690 100644 --- a/src/exts/dist_ext/core.jl +++ b/src/exts/dist_ext/core.jl @@ -4,7 +4,7 @@ abstract type ICNFDistribution{AICNF <: AbstractICNF} <: Distributions.ContinuousMultivariateDistribution end function Base.length(d::ICNFDistribution) - return d.m.nvars + return d.icnf.nvars end function Base.eltype(::ICNFDistribution{AICNF}) where {AICNF <: AbstractICNF} diff --git a/src/exts/dist_ext/core_cond_icnf.jl b/src/exts/dist_ext/core_cond_icnf.jl index 7043fde3..6d3379de 100644 --- a/src/exts/dist_ext/core_cond_icnf.jl +++ b/src/exts/dist_ext/core_cond_icnf.jl @@ -1,5 +1,5 @@ struct CondICNFDist{AICNF <: AbstractICNF} <: ICNFDistribution{AICNF} - m::AICNF + icnf::AICNF mode::Mode ys::AbstractVecOrMat{<:Real} ps::Any @@ -12,21 +12,21 @@ function CondICNFDist( ys::AbstractVecOrMat{<:Real}, ) (ps, st) = MLJModelInterface.fitted_params(mach) - return CondICNFDist(mach.model.m, mode, ys, ps, st) + return CondICNFDist(mach.model.icnf, mode, ys, ps, st) end function Distributions._logpdf( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, x::AbstractVector{<:Real}, ) - return first(inference(d.m, d.mode, x, d.ys, d.ps, d.st)) + return first(inference(d.icnf, d.mode, x, d.ys, d.ps, d.st)) end function Distributions._logpdf( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, x::AbstractVector{<:Real}, ) - @warn maxlog = 1 "to compute by matrices, data should be a matrix." + @warn "to compute by matrices, data should be a matrix." maxlog = 1 return first(Distributions._logpdf(d, hcat(x))) end @@ -34,7 +34,7 @@ function Distributions._logpdf( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, A::AbstractMatrix{<:Real}, ) - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 return Distributions._logpdf.(d, collect(collect.(eachcol(A)))) end @@ -42,7 +42,7 @@ function Distributions._logpdf( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, A::AbstractMatrix{<:Real}, ) - return first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) + return first(inference(d.icnf, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) end function Distributions._rand!( @@ -50,7 +50,7 @@ function Distributions._rand!( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, x::AbstractVector{<:Real}, ) - return x .= generate(d.m, d.mode, d.ys, d.ps, d.st) + return x .= generate(d.icnf, d.mode, d.ys, d.ps, d.st) end function Distributions._rand!( @@ -58,7 +58,7 @@ function Distributions._rand!( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, x::AbstractVector{<:Real}, ) - @warn maxlog = 1 "to compute by matrices, data should be a matrix." + @warn "to compute by matrices, data should be a matrix." maxlog = 1 return x .= Distributions._rand!(rng, d, hcat(x)) end @@ -67,7 +67,7 @@ function Distributions._rand!( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, A::AbstractMatrix{<:Real}, ) - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 return A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...) end @@ -76,5 +76,5 @@ function Distributions._rand!( d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, A::AbstractMatrix{<:Real}, ) - return A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) + return A .= generate(d.icnf, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) end diff --git a/src/exts/dist_ext/core_icnf.jl b/src/exts/dist_ext/core_icnf.jl index 9f47a6da..51d0c3b6 100644 --- a/src/exts/dist_ext/core_icnf.jl +++ b/src/exts/dist_ext/core_icnf.jl @@ -1,5 +1,5 @@ struct ICNFDist{AICNF <: AbstractICNF} <: ICNFDistribution{AICNF} - m::AICNF + icnf::AICNF mode::Mode ps::Any st::NamedTuple @@ -7,21 +7,21 @@ end function ICNFDist(mach::MLJBase.Machine{<:ICNFModel}, mode::Mode) (ps, st) = MLJModelInterface.fitted_params(mach) - return ICNFDist(mach.model.m, mode, ps, st) + return ICNFDist(mach.model.icnf, mode, ps, st) end function Distributions._logpdf( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, x::AbstractVector{<:Real}, ) - return first(inference(d.m, d.mode, x, d.ps, d.st)) + return first(inference(d.icnf, d.mode, x, d.ps, d.st)) end function Distributions._logpdf( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, x::AbstractVector{<:Real}, ) - @warn maxlog = 1 "to compute by matrices, data should be a matrix." + @warn "to compute by matrices, data should be a matrix." maxlog = 1 return first(Distributions._logpdf(d, hcat(x))) end @@ -29,7 +29,7 @@ function Distributions._logpdf( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, A::AbstractMatrix{<:Real}, ) - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 return Distributions._logpdf.(d, collect(collect.(eachcol(A)))) end @@ -37,7 +37,7 @@ function Distributions._logpdf( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, A::AbstractMatrix{<:Real}, ) - return first(inference(d.m, d.mode, A, d.ps, d.st)) + return first(inference(d.icnf, d.mode, A, d.ps, d.st)) end function Distributions._rand!( @@ -45,7 +45,7 @@ function Distributions._rand!( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, x::AbstractVector{<:Real}, ) - return x .= generate(d.m, d.mode, d.ps, d.st) + return x .= generate(d.icnf, d.mode, d.ps, d.st) end function Distributions._rand!( @@ -53,7 +53,7 @@ function Distributions._rand!( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, x::AbstractVector{<:Real}, ) - @warn maxlog = 1 "to compute by matrices, data should be a matrix." + @warn "to compute by matrices, data should be a matrix." maxlog = 1 return x .= Distributions._rand!(rng, d, hcat(x)) end @@ -62,7 +62,7 @@ function Distributions._rand!( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}}, A::AbstractMatrix{<:Real}, ) - @warn maxlog = 1 "to compute by vectors, data should be a vector." + @warn "to compute by vectors, data should be a vector." maxlog = 1 return A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...) end @@ -71,5 +71,5 @@ function Distributions._rand!( d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}}, A::AbstractMatrix{<:Real}, ) - return A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2)) + return A .= generate(d.icnf, d.mode, d.ps, d.st, size(A, 2)) end diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 834d08aa..bb66ec83 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -1,46 +1,46 @@ mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} - m::AICNF + icnf::AICNF loss::Function - optimizers::Tuple - adtype::ADTypes.AbstractADType - batchsize::Int + adtype::ADTypes.AbstractADType sol_kwargs::NamedTuple end -function CondICNFModel( - m::AbstractICNF, - loss::Function = loss; - optimizers::Tuple = (Optimisers.Adam(),), +function CondICNFModel(; + icnf::AbstractICNF = ICNF(), + loss::Function = loss, + optimizers::Tuple = ( + Optimisers.OptimiserChain(Optimisers.WeightDecay(), Optimisers.Adam()), + ), + batchsize::Int = 1024, adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), - batchsize::Int = 32, - sol_kwargs::NamedTuple = (;), + sol_kwargs::NamedTuple = (; epochs = 300, progress = true), ) - return CondICNFModel(m, loss, optimizers, adtype, batchsize, sol_kwargs) + return CondICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) end function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) X, Y = XY x = collect(transpose(MLJModelInterface.matrix(X))) y = collect(transpose(MLJModelInterface.matrix(Y))) - ps, st = LuxCore.setup(model.m.rng, model.m) + ps, st = LuxCore.setup(model.icnf.rng, model.icnf) ps = ComponentArrays.ComponentArray(ps) - x = model.m.device(x) - y = model.m.device(y) - ps = model.m.device(ps) - st = model.m.device(st) - data = make_dataloader(model.m, model.batchsize, (x, y)) - data = model.m.device(data) + x = model.icnf.device(x) + y = model.icnf.device(y) + ps = model.icnf.device(ps) + st = model.icnf.device(st) + data = make_dataloader(model.icnf, model.batchsize, (x, y)) + data = model.icnf.device(data) optprob = SciMLBase.OptimizationProblem{true}( SciMLBase.OptimizationFunction{true}( - make_opt_loss(model.m, TrainMode{true}(), st, model.loss), + make_opt_loss(model.icnf, TrainMode{true}(), st, model.loss), model.adtype, ), ps, data, ) - res_stats = SciMLBase.OptimizationStats[] + res_stats = Any[] for opt in model.optimizers optprob_re = SciMLBase.remake(optprob; u0 = ps) res = SciMLBase.solve(optprob_re, opt; model.sol_kwargs...) @@ -58,21 +58,21 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) Xnew, Ynew = XYnew xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) ynew = collect(transpose(MLJModelInterface.matrix(Ynew))) - xnew = model.m.device(xnew) - ynew = model.m.device(ynew) + xnew = model.icnf.device(xnew) + ynew = model.icnf.device(ynew) (ps, st) = fitresult - logp̂x = if model.m.compute_mode isa VectorMode - @warn maxlog = 1 "to compute by vectors, data should be a vector." + logp̂x = if model.icnf.compute_mode isa VectorMode + @warn "to compute by vectors, data should be a vector." maxlog = 1 broadcast( function (x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) - return first(inference(model.m, TestMode{false}(), x, y, ps, st)) + return first(inference(model.icnf, TestMode{false}(), x, y, ps, st)) end, collect(collect.(eachcol(xnew))), collect(collect.(eachcol(ynew))), ) - elseif model.m.compute_mode isa MatrixMode - first(inference(model.m, TestMode{false}(), xnew, ynew, ps, st)) + elseif model.icnf.compute_mode isa MatrixMode + first(inference(model.icnf, TestMode{false}(), xnew, ynew, ps, st)) else error("Not Implemented") end diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 2db39ff3..97447d97 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -1,43 +1,43 @@ mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} - m::AICNF + icnf::AICNF loss::Function - optimizers::Tuple - adtype::ADTypes.AbstractADType - batchsize::Int + adtype::ADTypes.AbstractADType sol_kwargs::NamedTuple end -function ICNFModel( - m::AbstractICNF, - loss::Function = loss; - optimizers::Tuple = (Optimisers.Adam(),), +function ICNFModel(; + icnf::AbstractICNF = ICNF(), + loss::Function = loss, + optimizers::Tuple = ( + Optimisers.OptimiserChain(Optimisers.WeightDecay(), Optimisers.Adam()), + ), + batchsize::Int = 1024, adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), - batchsize::Int = 32, - sol_kwargs::NamedTuple = (;), + sol_kwargs::NamedTuple = (; epochs = 300, progress = true), ) - return ICNFModel(m, loss, optimizers, adtype, batchsize, sol_kwargs) + return ICNFModel(icnf, loss, optimizers, batchsize, adtype, sol_kwargs) end function MLJModelInterface.fit(model::ICNFModel, verbosity, X) x = collect(transpose(MLJModelInterface.matrix(X))) - ps, st = LuxCore.setup(model.m.rng, model.m) + ps, st = LuxCore.setup(model.icnf.rng, model.icnf) ps = ComponentArrays.ComponentArray(ps) - x = model.m.device(x) - ps = model.m.device(ps) - st = model.m.device(st) - data = make_dataloader(model.m, model.batchsize, (x,)) - data = model.m.device(data) + x = model.icnf.device(x) + ps = model.icnf.device(ps) + st = model.icnf.device(st) + data = make_dataloader(model.icnf, model.batchsize, (x,)) + data = model.icnf.device(data) optprob = SciMLBase.OptimizationProblem{true}( SciMLBase.OptimizationFunction{true}( - make_opt_loss(model.m, TrainMode{true}(), st, model.loss), + make_opt_loss(model.icnf, TrainMode{true}(), st, model.loss), model.adtype, ), ps, data, ) - res_stats = SciMLBase.OptimizationStats[] + res_stats = Any[] for opt in model.optimizers optprob_re = SciMLBase.remake(optprob; u0 = ps) res = SciMLBase.solve(optprob_re, opt; model.sol_kwargs...) @@ -53,19 +53,19 @@ end function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) - xnew = model.m.device(xnew) + xnew = model.icnf.device(xnew) (ps, st) = fitresult - logp̂x = if model.m.compute_mode isa VectorMode - @warn maxlog = 1 "to compute by vectors, data should be a vector." + logp̂x = if model.icnf.compute_mode isa VectorMode + @warn "to compute by vectors, data should be a vector." maxlog = 1 broadcast( function (x::AbstractVector{<:Real}) - return first(inference(model.m, TestMode{false}(), x, ps, st)) + return first(inference(model.icnf, TestMode{false}(), x, ps, st)) end, collect(collect.(eachcol(xnew))), ) - elseif model.m.compute_mode isa MatrixMode - first(inference(model.m, TestMode{false}(), xnew, ps, st)) + elseif model.icnf.compute_mode isa MatrixMode + first(inference(model.icnf, TestMode{false}(), xnew, ps, st)) else error("Not Implemented") end diff --git a/src/layers/cond_layer.jl b/src/layers/cond_layer.jl index cdd75f00..e97bf9fa 100644 --- a/src/layers/cond_layer.jl +++ b/src/layers/cond_layer.jl @@ -1,4 +1,4 @@ -struct CondLayer{NN <: LuxCore.AbstractLuxLayer, AT <: AbstractArray} <: +struct CondLayer{NN <: LuxCore.AbstractLuxLayer, AT <: Any} <: LuxCore.AbstractLuxWrapperLayer{:nn} nn::NN ys::AT diff --git a/test/Project.toml b/test/Project.toml index dd245cf4..ee838c79 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,8 +14,6 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7" -SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -36,8 +34,6 @@ Lux = "1" LuxCore = "1" MLDataDevices = "1" MLJBase = "1" -OrdinaryDiffEqDefault = "1" -SciMLSensitivity = "7" StableRNGs = "1" Test = "1" Zygote = "0.7" diff --git a/test/ci_tests/regression_tests.jl b/test/ci_tests/regression_tests.jl index 07f09ada..f4ea9a24 100644 --- a/test/ci_tests/regression_tests.jl +++ b/test/ci_tests/regression_tests.jl @@ -7,28 +7,19 @@ Test.@testset verbose = true showtiming = true failfast = false "Regression Test r = convert.(Float32, r) nvars = size(r, 1) - naugs = nvars + naugs = nvars + 1 n_in = nvars + naugs - nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh)) - - icnf = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.ICNF, - nn, - nvars, - naugs; - compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), - tspan = (0.0f0, 1.0f0), - steer_rate = 1.0f-1, - λ₁ = 1.0f-2, - λ₂ = 1.0f-2, - λ₃ = 1.0f-2, - rng, + nn = Lux.Chain( + Lux.Dense(n_in => (2 * n_in + 1), tanh), + Lux.Dense((2 * n_in + 1) => n_in, tanh), ) + icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng) + df = DataFrames.DataFrame(transpose(r), :auto) - model = ContinuousNormalizingFlows.ICNFModel( - icnf; + model = ContinuousNormalizingFlows.ICNFModel(; + icnf, batchsize = 0, sol_kwargs = (; epochs = 300), ) diff --git a/test/ci_tests/smoke_tests.jl b/test/ci_tests/smoke_tests.jl index fae767e8..b96ba2dc 100644 --- a/test/ci_tests/smoke_tests.jl +++ b/test/ci_tests/smoke_tests.jl @@ -1,5 +1,4 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" begin - mts = Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.ICNF] omodes = ContinuousNormalizingFlows.Mode[ ContinuousNormalizingFlows.TrainMode{true}(), ContinuousNormalizingFlows.TestMode{true}(), @@ -76,8 +75,8 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ), ] - Test.@testset verbose = true showtiming = true failfast = false "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in - devices, + Test.@testset verbose = true showtiming = true failfast = false "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode" for device in + devices, data_type in data_types, compute_mode in compute_modes, ndata in ndata_, @@ -85,8 +84,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be inplace in inplaces, cond in conds, planar in planars, - omode in omodes, - mt in mts + omode in omodes data_dist = Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...) @@ -107,30 +105,29 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be ifelse( planar, Lux.Chain( - ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh; n_cond = nvars), + ContinuousNormalizingFlows.PlanarLayer( + nvars * 2 + 1, + tanh; + n_cond = nvars, + ), ), - Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)), + Lux.Chain(Lux.Dense(nvars * 3 + 1 => nvars * 2 + 1, tanh)), ), ifelse( planar, - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)), - Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)), + Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2 + 1, tanh)), + Lux.Chain(Lux.Dense(nvars * 2 + 1 => nvars * 2 + 1, tanh)), ), ) - icnf = ContinuousNormalizingFlows.construct( - mt, + icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, - nvars; - data_type, - compute_mode, - inplace, - cond, + naugmented = nvars + 1, device, - steer_rate = convert(data_type, 1.0e-1), - λ₁ = convert(data_type, 1.0e-2), - λ₂ = convert(data_type, 1.0e-2), - λ₃ = convert(data_type, 1.0e-2), + cond, + inplace, + compute_mode, + data_type, ) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) @@ -202,63 +199,41 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in adtypes - - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && ( - omode isa ContinuousNormalizingFlows.TrainMode || ( - omode isa ContinuousNormalizingFlows.TestMode && - compute_mode isa ContinuousNormalizingFlows.VectorMode - ) - ) - Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && ( - omode isa ContinuousNormalizingFlows.TrainMode || ( - omode isa ContinuousNormalizingFlows.TestMode && - compute_mode isa ContinuousNormalizingFlows.VectorMode - ) - ) + Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) + Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) if cond - model = ContinuousNormalizingFlows.CondICNFModel( - icnf; - adtype, + model = ContinuousNormalizingFlows.CondICNFModel(; + icnf, batchsize = 0, + adtype, sol_kwargs = (; epochs = 2), ) mach = MLJBase.machine(model, (df, df2)) - Test.@test !isnothing(MLJBase.fit!(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.serializable(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fit!(mach)) + Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) + Test.@test !isnothing(MLJBase.fitted_params(mach)) + Test.@test !isnothing(MLJBase.serializable(mach)) Test.@test !isnothing( ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2), - ) broken = compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + ) else - model = ContinuousNormalizingFlows.ICNFModel( - icnf; - adtype, + model = ContinuousNormalizingFlows.ICNFModel(; + icnf, batchsize = 0, + adtype, sol_kwargs = (; epochs = 2), ) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.transform(mach, df)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - Test.@test !isnothing(MLJBase.serializable(mach)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(MLJBase.fit!(mach)) + Test.@test !isnothing(MLJBase.transform(mach, df)) + Test.@test !isnothing(MLJBase.fitted_params(mach)) + Test.@test !isnothing(MLJBase.serializable(mach)) - Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken = - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} + Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) end end end diff --git a/test/ci_tests/speed_tests.jl b/test/ci_tests/speed_tests.jl index 9dcf3f51..86491d87 100644 --- a/test/ci_tests/speed_tests.jl +++ b/test/ci_tests/speed_tests.jl @@ -32,7 +32,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be Test.@testset verbose = true showtiming = true failfast = false "$compute_mode" for compute_mode in compute_modes - @show compute_mode rng = StableRNGs.StableRNG(1) @@ -43,28 +42,25 @@ Test.@testset verbose = true showtiming = true failfast = false "Speed Tests" be r = convert.(Float32, r) nvars = size(r, 1) - naugs = nvars + naugs = nvars + 1 n_in = nvars + naugs - nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh)) + nn = Lux.Chain( + Lux.Dense(n_in => (2 * n_in + 1), tanh), + Lux.Dense((2 * n_in + 1) => n_in, tanh), + ) - icnf = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.ICNF, + icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, - naugs; - compute_mode, - tspan = (0.0f0, 1.0f0), - steer_rate = 1.0f-1, - λ₁ = 1.0f-2, - λ₂ = 1.0f-2, - λ₃ = 1.0f-2, + naugmented = naugs, rng, + compute_mode, ) df = DataFrames.DataFrame(transpose(r), :auto) - model = ContinuousNormalizingFlows.ICNFModel( - icnf; + model = ContinuousNormalizingFlows.ICNFModel(; + icnf, batchsize = 0, sol_kwargs = (; epochs = 5), ) diff --git a/test/quality_tests/checkby_JET_tests.jl b/test/quality_tests/checkby_JET_tests.jl index 72ea31af..d5cca152 100644 --- a/test/quality_tests/checkby_JET_tests.jl +++ b/test/quality_tests/checkby_JET_tests.jl @@ -4,7 +4,6 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg target_modules = (ContinuousNormalizingFlows,), ) - mts = Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.ICNF] omodes = ContinuousNormalizingFlows.Mode[ ContinuousNormalizingFlows.TrainMode{true}(), ContinuousNormalizingFlows.TestMode{true}(), @@ -61,8 +60,8 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg ), ] - Test.@testset verbose = true showtiming = true failfast = false "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in - devices, + Test.@testset verbose = true showtiming = true failfast = false "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode" for device in + devices, data_type in data_types, compute_mode in compute_modes, ndata in ndata_, @@ -70,8 +69,7 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg inplace in inplaces, cond in conds, planar in planars, - omode in omodes, - mt in mts + omode in omodes data_dist = Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...) @@ -90,30 +88,29 @@ Test.@testset verbose = true showtiming = true failfast = false "CheckByJET" beg ifelse( planar, Lux.Chain( - ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh; n_cond = nvars), + ContinuousNormalizingFlows.PlanarLayer( + nvars * 2 + 1, + tanh; + n_cond = nvars, + ), ), - Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)), + Lux.Chain(Lux.Dense(nvars * 3 + 1 => nvars * 2 + 1, tanh)), ), ifelse( planar, - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)), - Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)), + Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2 + 1, tanh)), + Lux.Chain(Lux.Dense(nvars * 2 + 1 => nvars * 2 + 1, tanh)), ), ) - icnf = ContinuousNormalizingFlows.construct( - mt, + icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, - nvars; - data_type, - compute_mode, - inplace, - cond, + naugmented = nvars + 1, device, - steer_rate = convert(data_type, 1.0e-1), - λ₁ = convert(data_type, 1.0e-2), - λ₂ = convert(data_type, 1.0e-2), - λ₃ = convert(data_type, 1.0e-2), + cond, + inplace, + compute_mode, + data_type, ) ps, st = LuxCore.setup(icnf.rng, icnf) ps = ComponentArrays.ComponentArray(ps) diff --git a/test/runtests.jl b/test/runtests.jl index 0a4ba95d..604ccb44 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,8 +13,6 @@ import ADTypes, LuxCore, MLDataDevices, MLJBase, - OrdinaryDiffEqDefault, - SciMLSensitivity, StableRNGs, Test, Zygote,