Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/CI-CheckBy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ jobs:
- CheckByJET
- CheckByExplicitImports
version:
- release
- lts
- "1.10"
- "1.11"
- "1.12"
# - release
# - lts
# - nightly
os:
- ubuntu-latest
Expand Down
7 changes: 5 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ jobs:
- Regression
- Speed
version:
- release
- lts
- "1.10"
- "1.11"
- "1.12"
# - release
# - lts
# - nightly
os:
- ubuntu-latest
Expand Down
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -52,11 +53,12 @@ MLUtils = "0.4"
NNlib = "0.9"
Optimisers = "0.4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEqDefault = "1"
OrdinaryDiffEqAdamsBashforthMoulton = "1"
Random = "1"
SciMLBase = "2"
SciMLSensitivity = "7"
ScientificTypesBase = "3"
Static = "1"
Statistics = "1"
WeightInitializers = "1"
Zygote = "0.7"
Expand Down
4 changes: 0 additions & 4 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -22,9 +20,7 @@ Distributions = "0.25"
ForwardDiff = "1"
Lux = "1"
LuxCore = "1"
OrdinaryDiffEqDefault = "1"
PkgBenchmark = "0.2"
SciMLSensitivity = "7"
StableRNGs = "1"
Zygote = "0.7"
julia = "1.10"
39 changes: 8 additions & 31 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import ADTypes,
ForwardDiff,
Lux,
LuxCore,
OrdinaryDiffEqDefault,
PkgBenchmark,
SciMLSensitivity,
StableRNGs,
Zygote,
ContinuousNormalizingFlows
Expand All @@ -21,39 +19,18 @@ r = rand(rng, data_dist, ndimension, ndata)
r = convert.(Float32, r)

nvars = size(r, 1)
naugs = nvars
naugs = nvars + 1
n_in = nvars + naugs

nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh))

icnf = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.ICNF,
nn,
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 1.0f0),
steer_rate = 1.0f-1,
λ₁ = 1.0f-2,
λ₂ = 1.0f-2,
λ₃ = 1.0f-2,
rng,
nn = Lux.Chain(
Lux.Dense(n_in => (2 * n_in + 1), tanh),
Lux.Dense((2 * n_in + 1) => n_in, tanh),
)

icnf2 = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.ICNF,
nn,
nvars,
naugs;
inplace = true,
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 1.0f0),
steer_rate = 1.0f-1,
λ₁ = 1.0f-2,
λ₂ = 1.0f-2,
λ₃ = 1.0f-2,
rng,
)
icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng)

icnf2 =
ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng, inplace = true)

ps, st = LuxCore.setup(icnf.rng, icnf)
ps = ComponentArrays.ComponentArray(ps)
Expand Down
18 changes: 18 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ContinuousNormalizingFlows = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEqAdamsBashforthMoulton = "89bda076-bce5-4f1c-845f-551c83cdda9a"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
76 changes: 46 additions & 30 deletions examples/usage.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Switch To MKL For Faster Computation
# using MKL
using MKL

## Enable Logging
using Logging, TerminalLoggers
Expand All @@ -20,47 +20,63 @@ n_in = nvars + naugs

## Model
using ContinuousNormalizingFlows,
Lux, OrdinaryDiffEqAdamsBashforthMoulton, ADTypes, Zygote, MLDataDevices
Lux,
OrdinaryDiffEqAdamsBashforthMoulton,
Static,
SciMLSensitivity,
ADTypes,
Zygote,
MLDataDevices

# To use gpu, add related packages
# using LuxCUDA, CUDA, cuDNN
# using LuxCUDA

nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
icnf = construct(
ICNF,
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
inplace = false, # not using the inplace version of functions
device = cpu_device(), # process data by CPU
# device = gpu_device(), # process data by GPU
tspan = (0.0f0, 1.0f0), # time span
steer_rate = 1.0f-1, # add random noise to end of the time span
nn = Chain(Dense(n_in => (2 * n_in + 1), tanh), Dense((2 * n_in + 1) => n_in, tanh))
icnf = ICNF(;
nn = nn,
nvars = nvars, # number of variables
naugmented = naugs, # number of augmented dimensions
λ₁ = 1.0f-2, # regulate flow
λ₂ = 1.0f-2, # regulate volume change
λ₃ = 1.0f-2, # regulate augmented dimensions
sol_kwargs = (; save_everystep = false, alg = VCABM()), # pass to the solver
steer_rate = 1.0f-1, # add random noise to end of the time span
tspan = (0.0f0, 1.0f0), # time span
device = cpu_device(), # process data by CPU
# device = gpu_device(), # process data by GPU
cond = false, # not conditioning on auxiliary input
inplace = false, # not using the inplace version of functions
compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
sol_kwargs = (;
save_everystep = false,
maxiters = typemax(Int),
reltol = 1.0f-4,
abstol = 1.0f-8,
alg = VCABM(; thread = True()),
sensealg = InterpolatingAdjoint(; checkpointing = true, autodiff = true),
), # pass to the solver
)

## Fit It
using DataFrames, MLJBase, Zygote, ADTypes, OptimizationOptimisers
df = DataFrame(transpose(r), :auto)
model = ICNFModel(
icnf;
optimizers = (Adam(),),
adtype = AutoZygote(),
batchsize = 512,
sol_kwargs = (; epochs = 300, progress = true), # pass to the solver
)
mach = machine(model, df)
fit!(mach)
# CUDA.@allowscalar fit!(mach) # needed for gpu

## Store It
icnf_mach_fn = "icnf_mach.jls"
MLJBase.save(icnf_mach_fn, mach) # save it
mach = machine(icnf_mach_fn) # load it
if ispath(icnf_mach_fn)
mach = machine(icnf_mach_fn) # load it
else
df = DataFrame(transpose(r), :auto)
model = ICNFModel(;
icnf,
optimizers = (OptimiserChain(WeightDecay(), Adam()),),
batchsize = 1024,
adtype = AutoZygote(),
sol_kwargs = (; epochs = 300, progress = true), # pass to the solver
)
mach = machine(model, df)
fit!(mach)
# CUDA.@allowscalar fit!(mach) # needed for gpu

MLJBase.save(icnf_mach_fn, mach) # save it
end

## Use It
d = ICNFDist(mach, TestMode())
Expand Down
12 changes: 3 additions & 9 deletions src/ContinuousNormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,20 @@ import ADTypes,
NNlib,
Optimisers,
OptimizationOptimisers,
OrdinaryDiffEqDefault,
OrdinaryDiffEqAdamsBashforthMoulton,
Random,
SciMLBase,
SciMLSensitivity,
ScientificTypesBase,
Static,
Statistics,
WeightInitializers,
Zygote

export construct,
inference,
export inference,
generate,
loss,
ICNF,
RNODE,
CondRNODE,
FFJORD,
CondFFJORD,
Planar,
CondPlanar,
TestMode,
TrainMode,
DIVecJacVectorMode,
Expand Down
77 changes: 0 additions & 77 deletions src/core/base_icnf.jl
Original file line number Diff line number Diff line change
@@ -1,80 +1,3 @@
function construct(
aicnf::Type{<:AbstractICNF},
nn::LuxCore.AbstractLuxLayer,
nvars::Int,
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()),
inplace::Bool = false,
cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar},
device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(),
basedist::Distributions.Distribution = Distributions.MvNormal(
FillArrays.Zeros{data_type}(nvars + naugmented),
FillArrays.Eye{data_type}(nvars + naugmented),
),
tspan::NTuple{2} = (zero(data_type), one(data_type)),
steer_rate::AbstractFloat = zero(data_type),
epsdist::Distributions.Distribution = Distributions.MvNormal(
FillArrays.Zeros{data_type}(nvars + naugmented),
FillArrays.Eye{data_type}(nvars + naugmented),
),
sol_kwargs::NamedTuple = (;),
rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device),
λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE}
convert(data_type, 1.0e-2)
else
zero(data_type)
end,
λ₂::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE}
convert(data_type, 1.0e-2)
else
zero(data_type)
end,
λ₃::AbstractFloat = if naugmented >= nvars
convert(data_type, 1.0e-2)
else
zero(data_type)
end,
)
steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate)

return ICNF{
data_type,
typeof(compute_mode),
inplace,
cond,
!iszero(naugmented),
!iszero(steer_rate),
!iszero(λ₁),
!iszero(λ₂),
!iszero(λ₃),
typeof(nn),
typeof(nvars),
typeof(device),
typeof(basedist),
typeof(tspan),
typeof(steerdist),
typeof(epsdist),
typeof(sol_kwargs),
typeof(rng),
}(
nn,
nvars,
naugmented,
compute_mode,
device,
basedist,
tspan,
steerdist,
epsdist,
sol_kwargs,
rng,
λ₁,
λ₂,
λ₃,
)
end

function Base.show(io::IO, icnf::AbstractICNF)
return print(io, typeof(icnf))
end
Expand Down
Loading
Loading