Skip to content

Commit f144a80

Browse files
authored
Rename APIs + add tests (#5)
* rename * readme viols * BOI->BNK * fix viols * fix cons * fix cons cont * no param crc tests * jac/hess coord tests * typo * typo cont.
1 parent f315dd5 commit f144a80

15 files changed

Lines changed: 302 additions & 202 deletions

README.md

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55

66
`BatchNLPKernels.jl` provides [`KernelAbstractions.jl`](https://github.com/JuliaGPU/KernelAbstractions.jl) kernels for evaluating problem data from a (parametric) [`ExaModel`](https://github.com/exanauts/ExaModels.jl) for batches of solutions (and parameters). Currently the following functions (as well as their non-parametric variants) are exported:
77

8-
- `obj_batch!(::BatchModel, X, Θ)`
9-
- `grad_batch!(::BatchModel, X, Θ)`
10-
- `cons_nln_batch!(::BatchModel, X, Θ)`
11-
- `jac_coord_batch!(::BatchModel, X, Θ)`
12-
- `hess_coord_batch!(::BatchModel, X, Θ, Y; obj_weight=1.0)`
13-
- `jprod_nln_batch!(::BatchModel, X, Θ, V)`
14-
- `jtprod_nln_batch!(::BatchModel, X, Θ, V)`
15-
- `hprod_batch!(::BatchModel, X, Θ, Y, V; obj_weight=1.0)`
16-
8+
- `objective!(::BatchModel, X, Θ)`
9+
- `objective_gradient!(::BatchModel, X, Θ)`
10+
- `constraints!(::BatchModel, X, Θ)`
11+
- `constraints_jacobian!(::BatchModel, X, Θ)`
12+
- `lagrangian_hessian!(::BatchModel, X, Θ, Y; obj_weight=1.0)`
13+
- `constraints_jprod!(::BatchModel, X, Θ, V)`
14+
- `constraints_jtprod!(::BatchModel, X, Θ, V)`
15+
- `lagrangian_hprod!(::BatchModel, X, Θ, Y, V; obj_weight=1.0)`
16+
- `all_violations!(::BatchModel, X, Θ)`
17+
- `constraint_violations!(::BatchModel, X, Θ)`
18+
- `bound_violations!(::BatchModel, X)`
1719

1820
To use these functions, first wrap your `ExaModel` in a `BatchModel`:
1921

@@ -30,7 +32,7 @@ This pre-allocates work and output buffers. By default, only the buffers to supp
3032
Then, you can call the batch functions as follows:
3133

3234
```julia
33-
objs = obj_batch!(bm, X, Θ)
35+
objs = objective!(bm, X, Θ)
3436
```
3537

3638
where `X` and `Θ` are (device) matrices with dimensions `(nvar, batch_size)` and `(nθ, batch_size)` respectively.

ext/BNKChainRulesCore.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ module BNKChainRulesCore
33
using BatchNLPKernels
44
using ChainRulesCore
55

6-
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.obj_batch!), bm::BatchModel, X, Θ)
7-
y = BatchNLPKernels.obj_batch!(bm, X, Θ)
6+
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.objective!), bm::BatchModel, X, Θ)
7+
y = BatchNLPKernels.objective!(bm, X, Θ)
88

99
function obj_batch_pullback(Ȳ)
1010
Ȳ = ChainRulesCore.unthunk(Ȳ)
11-
gradients = BatchNLPKernels.grad_batch!(bm, X, Θ)
11+
gradients = BatchNLPKernels.objective_gradient!(bm, X, Θ)
1212

1313
= gradients .* Ȳ'
1414

@@ -17,12 +17,12 @@ function ChainRulesCore.rrule(::typeof(BatchNLPKernels.obj_batch!), bm::BatchMod
1717

1818
return y, obj_batch_pullback
1919
end
20-
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.obj_batch!), bm::BatchModel, X)
21-
y = BatchNLPKernels.obj_batch!(bm, X)
20+
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.objective!), bm::BatchModel, X)
21+
y = BatchNLPKernels.objective!(bm, X)
2222

2323
function obj_batch_pullback(Ȳ)
2424
Ȳ = ChainRulesCore.unthunk(Ȳ)
25-
gradients = BatchNLPKernels.grad_batch!(bm, X)
25+
gradients = BatchNLPKernels.objective_gradient!(bm, X)
2626

2727
= gradients .* Ȳ'
2828

@@ -33,32 +33,32 @@ function ChainRulesCore.rrule(::typeof(BatchNLPKernels.obj_batch!), bm::BatchMod
3333
end
3434

3535

36-
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.cons_nln_batch!), bm::BatchModel, X, Θ)
37-
y = BatchNLPKernels.cons_nln_batch!(bm, X, Θ)
36+
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.constraints!), bm::BatchModel, X, Θ)
37+
y = BatchNLPKernels.constraints!(bm, X, Θ)
3838

3939
function cons_nln_batch_pullback(Ȳ)
4040
Ȳ = ChainRulesCore.unthunk(Ȳ)
41-
= BatchNLPKernels.jtprod_nln_batch!(bm, X, Θ, Ȳ)
41+
= BatchNLPKernels.constraints_jtprod!(bm, X, Θ, Ȳ)
4242
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), X̄, ChainRulesCore.NoTangent()
4343
end
4444

4545
return y, cons_nln_batch_pullback
4646
end
47-
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.cons_nln_batch!), bm::BatchModel, X)
48-
y = BatchNLPKernels.cons_nln_batch!(bm, X)
47+
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.constraints!), bm::BatchModel, X)
48+
y = BatchNLPKernels.constraints!(bm, X)
4949

5050
function cons_nln_batch_pullback(Ȳ)
5151
Ȳ = ChainRulesCore.unthunk(Ȳ)
52-
= BatchNLPKernels.jtprod_nln_batch!(bm, X, Ȳ)
52+
= BatchNLPKernels.constraints_jtprod!(bm, X, Ȳ)
5353
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), X̄
5454
end
5555

5656
return y, cons_nln_batch_pullback
5757
end
5858

5959

60-
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.constraint_violations!), bm::BatchModel, V)
61-
Vc = BatchNLPKernels.constraint_violations!(bm, V)
60+
function ChainRulesCore.rrule(::typeof(BatchNLPKernels._constraint_violations!), bm::BatchModel, V)
61+
Vc = BatchNLPKernels._constraint_violations!(bm, V)
6262

6363
function constraint_violations_pullback(Ȳ)
6464
Ȳ = ChainRulesCore.unthunk(Ȳ)

src/BatchNLPKernels.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ const KAExtension = ExaKA.KAExtension
99
include("interval.jl")
1010
include("batch_model.jl")
1111

12-
const BOI = BatchNLPKernels
13-
export BOI, BatchModel, BatchModelConfig
14-
export obj_batch!, grad_batch!, cons_nln_batch!, jac_coord_batch!, hess_coord_batch!
15-
export jprod_nln_batch!, jtprod_nln_batch!, hprod_batch!
12+
const BNK = BatchNLPKernels
13+
export BNK, BatchModel, BatchModelConfig
14+
export objective!, objective_gradient!, constraints!, constraints_jacobian!, lagrangian_hessian!
15+
export constraints_jprod!, constraints_jtprod!, lagrangian_hprod!
16+
export all_violations!, constraint_violations!, bound_violations!
1617

1718
include("utils.jl")
1819
include("kernels.jl")

src/api/cons.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
"""
2-
cons_nln_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
2+
constraints!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
33
44
Evaluate constraints for a batch of solutions and parameters.
55
"""
6-
function cons_nln_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
6+
function constraints!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
77
C = _maybe_view(bm, :cons_out, X)
8-
cons_nln_batch!(bm, X, Θ, C)
8+
constraints!(bm, X, Θ, C)
99
return C
1010
end
1111

1212
"""
13-
cons_nln_batch!(bm::BatchModel, X::AbstractMatrix)
13+
constraints!(bm::BatchModel, X::AbstractMatrix)
1414
1515
Evaluate constraints for a batch of solutions.
1616
"""
17-
function cons_nln_batch!(bm::BatchModel, X::AbstractMatrix)
17+
function constraints!(bm::BatchModel, X::AbstractMatrix)
1818
Θ = _repeat_params(bm, X)
19-
cons_nln_batch!(bm, X, Θ)
19+
constraints!(bm, X, Θ)
2020
end
2121

2222

23-
function cons_nln_batch!(
23+
function constraints!(
2424
bm::BatchModel,
2525
X::AbstractMatrix,
2626
Θ::AbstractMatrix,
@@ -34,7 +34,7 @@ function cons_nln_batch!(
3434
_assert_batch_size(batch_size, bm.batch_size)
3535
backend = _get_backend(bm.model)
3636

37-
_cons_nln_batch!(backend, C, bm.model.cons, X, Θ)
37+
_constraints!(backend, C, bm.model.cons, X, Θ)
3838

3939
conbuffers_batch = _maybe_view(bm, :cons_work, X)
4040

@@ -53,17 +53,17 @@ function cons_nln_batch!(
5353
return C
5454
end
5555

56-
function _cons_nln_batch!(backend, C, con::ExaModels.Constraint, X, Θ)
56+
function _constraints!(backend, C, con::ExaModels.Constraint, X, Θ)
5757
if !isempty(con.itr)
5858
batch_size = size(X, 2)
5959
kerf_batch(backend)(C, con.f, con.itr, X, Θ; ndrange = (length(con.itr), batch_size))
6060
end
61-
_cons_nln_batch!(backend, C, con.inner, X, Θ)
61+
_constraints!(backend, C, con.inner, X, Θ)
6262
synchronize(backend)
6363
end
64-
function _cons_nln_batch!(backend, C, con::ExaModels.ConstraintNull, X, Θ) end
65-
function _cons_nln_batch!(backend, C, con::ExaModels.ConstraintAug, X, Θ)
66-
_cons_nln_batch!(backend, C, con.inner, X, Θ)
64+
function _constraints!(backend, C, con::ExaModels.ConstraintNull, X, Θ) end
65+
function _constraints!(backend, C, con::ExaModels.ConstraintAug, X, Θ)
66+
_constraints!(backend, C, con.inner, X, Θ)
6767
end
6868

6969
function _conaugs_batch!(backend, conbuffers, con::ExaModels.ConstraintAug, X, Θ)

src/api/grad.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
"""
2-
grad_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
2+
objective_gradient!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
33
44
Evaluate objective gradient for a batch of points.
55
"""
6-
function grad_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
6+
function objective_gradient!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
77
G = _maybe_view(bm, :grad_out, X)
8-
grad_batch!(bm, X, Θ, G)
8+
objective_gradient!(bm, X, Θ, G)
99
return G
1010
end
1111

1212
"""
13-
grad_batch!(bm::BatchModel, X::AbstractMatrix)
13+
objective_gradient!(bm::BatchModel, X::AbstractMatrix)
1414
1515
Evaluate objective gradient for a batch of points.
1616
"""
17-
function grad_batch!(bm::BatchModel, X::AbstractMatrix)
17+
function objective_gradient!(bm::BatchModel, X::AbstractMatrix)
1818
Θ = _repeat_params(bm, X)
19-
grad_batch!(bm, X, Θ)
19+
objective_gradient!(bm, X, Θ)
2020
end
2121

22-
function _grad_batch!(backend, grad_work, objs, X, Θ)
22+
function _objective_gradient!(backend, grad_work, objs, X, Θ)
2323
sgradient_batch!(backend, grad_work, objs, X, Θ, one(eltype(grad_work)))
24-
_grad_batch!(backend, grad_work, objs.inner, X, Θ)
24+
_objective_gradient!(backend, grad_work, objs.inner, X, Θ)
2525
synchronize(backend)
2626
end
27-
function _grad_batch!(backend, grad_work, objs::ExaModels.ObjectiveNull, X, Θ) end
27+
function _objective_gradient!(backend, grad_work, objs::ExaModels.ObjectiveNull, X, Θ) end
2828

2929
function sgradient_batch!(
3030
backend::B,
@@ -41,11 +41,11 @@ function sgradient_batch!(
4141
end
4242

4343
"""
44-
grad_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, G::AbstractMatrix)
44+
objective_gradient!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, G::AbstractMatrix)
4545
4646
Evaluate gradients for a batch of points with different parameters.
4747
"""
48-
function grad_batch!(
48+
function objective_gradient!(
4949
bm::BatchModel,
5050
X::AbstractMatrix,
5151
Θ::AbstractMatrix,
@@ -63,7 +63,7 @@ function grad_batch!(
6363
if !isempty(grad_work)
6464
fill!(grad_work, zero(eltype(grad_work)))
6565

66-
_grad_batch!(backend, grad_work, bm.model.objs, X, Θ)
66+
_objective_gradient!(backend, grad_work, bm.model.objs, X, Θ)
6767

6868
fill!(G, zero(eltype(G)))
6969
compress_to_dense_batch(backend)(

src/api/hess.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
"""
2-
hess_coord_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
2+
lagrangian_hessian!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
33
44
Evaluate Hessian coordinates for a batch of points.
55
"""
6-
function hess_coord_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
6+
function lagrangian_hessian!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
77
H_view = _maybe_view(bm, :hprod_work, X)
8-
hess_coord_batch!(bm, X, Θ, Y, H_view; obj_weight=obj_weight)
8+
lagrangian_hessian!(bm, X, Θ, Y, H_view; obj_weight=obj_weight)
99
return H_view
1010
end
1111

1212
"""
13-
hess_coord_batch!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
13+
lagrangian_hessian!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
1414
1515
Evaluate Hessian coordinates for a batch of points.
1616
"""
17-
function hess_coord_batch!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
17+
function lagrangian_hessian!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
1818
Θ = _repeat_params(bm, X)
19-
hess_coord_batch!(bm, X, Θ, Y; obj_weight=obj_weight)
19+
lagrangian_hessian!(bm, X, Θ, Y; obj_weight=obj_weight)
2020
end
2121

22-
function hess_coord_batch!(
22+
function lagrangian_hessian!(
2323
bm::BatchModel,
2424
X::AbstractMatrix,
2525
Θ::AbstractMatrix,
@@ -37,24 +37,24 @@ function hess_coord_batch!(
3737
backend = _get_backend(bm.model)
3838

3939
fill!(H, zero(eltype(H)))
40-
_obj_hess_coord_batch!(backend, H, bm.model.objs, X, Θ, obj_weight)
41-
_con_hess_coord_batch!(backend, H, bm.model.cons, X, Θ, Y)
40+
_obj_lagrangian_hessian!(backend, H, bm.model.objs, X, Θ, obj_weight)
41+
_con_lagrangian_hessian!(backend, H, bm.model.cons, X, Θ, Y)
4242
return H
4343
end
4444

45-
function _obj_hess_coord_batch!(backend, H, objs, X, Θ, obj_weight)
45+
function _obj_lagrangian_hessian!(backend, H, objs, X, Θ, obj_weight)
4646
shessian_batch!(backend, H, nothing, objs, X, Θ, obj_weight, zero(eltype(H)))
47-
_obj_hess_coord_batch!(backend, H, objs.inner, X, Θ, obj_weight)
47+
_obj_lagrangian_hessian!(backend, H, objs.inner, X, Θ, obj_weight)
4848
synchronize(backend)
4949
end
50-
function _obj_hess_coord_batch!(backend, H, objs::ExaModels.ObjectiveNull, X, Θ, obj_weight) end
50+
function _obj_lagrangian_hessian!(backend, H, objs::ExaModels.ObjectiveNull, X, Θ, obj_weight) end
5151

52-
function _con_hess_coord_batch!(backend, H, cons, X, Θ, Y)
52+
function _con_lagrangian_hessian!(backend, H, cons, X, Θ, Y)
5353
shessian_batch!(backend, H, nothing, cons, X, Θ, Y, zero(eltype(H)))
54-
_con_hess_coord_batch!(backend, H, cons.inner, X, Θ, Y)
54+
_con_lagrangian_hessian!(backend, H, cons.inner, X, Θ, Y)
5555
synchronize(backend)
5656
end
57-
function _con_hess_coord_batch!(backend, H, cons::ExaModels.ConstraintNull, X, Θ, Y) end
57+
function _con_lagrangian_hessian!(backend, H, cons::ExaModels.ConstraintNull, X, Θ, Y) end
5858

5959
function shessian_batch!(
6060
backend::B,

src/api/hprod.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
"""
2-
hprod_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
2+
lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
33
44
Evaluate Hessian-vector products for a batch of points.
55
"""
6-
function hprod_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
6+
function lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
77
Hv = _maybe_view(bm, :hprod_out, X)
8-
hprod_batch!(bm, X, Θ, Y, V, Hv; obj_weight=obj_weight)
8+
lagrangian_hprod!(bm, X, Θ, Y, V, Hv; obj_weight=obj_weight)
99
return Hv
1010
end
1111

1212
"""
13-
hprod_batch!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
13+
lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
1414
1515
Evaluate Hessian-vector products for a batch of points.
1616
"""
17-
function hprod_batch!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
17+
function lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
1818
Θ = _repeat_params(bm, X)
19-
hprod_batch!(bm, X, Θ, Y, V; obj_weight=obj_weight)
19+
lagrangian_hprod!(bm, X, Θ, Y, V; obj_weight=obj_weight)
2020
return Hv
2121
end
2222

23-
function hprod_batch!(
23+
function lagrangian_hprod!(
2424
bm::BatchModel,
2525
X::AbstractMatrix,
2626
Θ::AbstractMatrix,
@@ -40,7 +40,7 @@ function hprod_batch!(
4040

4141
H_batch = _maybe_view(bm, :hprod_work, X)
4242

43-
hess_coord_batch!(bm, X, Θ, Y, H_batch; obj_weight=obj_weight)
43+
lagrangian_hessian!(bm, X, Θ, Y, H_batch; obj_weight=obj_weight)
4444

4545
fill!(Hv, zero(eltype(Hv)))
4646
kersyspmv_batch(backend)(

0 commit comments

Comments
 (0)