diff --git a/Cargo.toml b/Cargo.toml
index 38e6654c0..8f11fc550 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -63,6 +63,8 @@ linfa-datasets = { path = "datasets", features = [
     "diabetes",
     "generate",
 ] }
+linfa-linear = { path = "algorithms/linfa-linear" }
+linfa-svm = { path = "algorithms/linfa-svm" }
 statrs = "0.18"
 
 [target.'cfg(not(windows))'.dependencies]
diff --git a/src/composing/mod.rs b/src/composing/mod.rs
index a1f2acc37..bb7271889 100644
--- a/src/composing/mod.rs
+++ b/src/composing/mod.rs
@@ -1,12 +1,15 @@
 //! Composition models
 //!
-//! This module contains three composition models:
+//! This module contains four composition models:
 //! * `MultiClassModel`: combine multiple binary decision models to a single multi-class model
 //! * `MultiTargetModel`: combine multiple univariate models to a single multi-target model
 //! * `Platt`: calibrate a classifier (i.e. SVC) to predicted posterior probabilities
+//! * `ResidualChain`: fit models sequentially on the residuals of the previous one
+//!   (forward stagewise additive modeling / L2Boosting); see [`residual_chain::Stagewise`]
 
 mod multi_class_model;
 mod multi_target_model;
 pub mod platt_scaling;
+pub mod residual_chain;
 
 pub use multi_class_model::MultiClassModel;
 pub use multi_target_model::MultiTargetModel;
diff --git a/src/composing/residual_chain.rs b/src/composing/residual_chain.rs
new file mode 100644
index 000000000..3713703d1
--- /dev/null
+++ b/src/composing/residual_chain.rs
@@ -0,0 +1,584 @@
+//! L2Boosting (forward stagewise additive modelling with squared-error loss)
+//! for the linfa ML framework.
+//!
+//! This module provides [`ResidualChain`], which fits models sequentially on
+//! residuals. Chain as many stages as you like via [`Stagewise`]:
+//!
+//! 1. Fit `base` on `(X, Y)`
+//! 2. Compute residuals: `R = Y - base.predict(X)`
+//! 3. Fit `corrector` on `(X, R)`
+//! 4. Repeat for any further correctors stacked on top
+//!
+//! When predicting, all stages' outputs are summed.
+//!
+//! This is the special case of FSAM (Friedman 2001) where the loss is squared
+//! error. Shrinkage (learning rate ν ∈ (0, 1]) can be set per corrector via
+//! [`Shrunk::with_shrinkage`]; the default is ν = 1 (no scaling).
+//!
+//! # References
+//!
+//! - J. H. Friedman (2001). "Greedy function approximation: A gradient boosting machine."
+//!
+//! # Examples
+//!
+//! ## Linear + linear
+//!
+//! Two `linfa_linear::LinearRegression` models stacked: the corrector fits
+//! the residuals left by the base.
+//!
+//! ```
+//! use linfa::traits::{Fit, Predict};
+//! use linfa::DatasetBase;
+//! use linfa_linear::LinearRegression;
+//! use linfa::composing::residual_chain::{ResidualChain, Stagewise};
+//! use ndarray::{array, Array2};
+//!
+//! // y = 2x: perfectly linear, so the corrector should see zero residuals.
+//! let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64);
+//! let y = array![0., 2., 4., 6., 8.];
+//! let dataset = DatasetBase::new(x.clone(), y);
+//!
+//! let fitted = LinearRegression::default()
+//!     .chain(LinearRegression::default())
+//!     .fit(&dataset)
+//!     .unwrap();
+//!
+//! let _preds = fitted.predict(&x);
+//! ```
+//!
+//! ## The second model learns nothing when the first fits perfectly
+//!
+//! If the first model already captures the data exactly, the residuals are all
+//! zero and the second model has nothing to learn — its parameters come out
+//! at (or very near) zero.
+//!
+//! ```
+//! use linfa::traits::{Fit, Predict};
+//! use linfa::DatasetBase;
+//! use linfa_linear::LinearRegression;
+//! use linfa::composing::residual_chain::Stagewise;
+//! use ndarray::{array, Array2};
+//!
+//! // y = 2x: one linear model is enough to fit this perfectly.
+//! let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64);
+//! let y = array![0., 2., 4., 6., 8.];
+//! let dataset = DatasetBase::new(x.clone(), y);
+//!
+//! let fitted = LinearRegression::default()
+//!     .chain(LinearRegression::default())
+//!     .fit(&dataset)
+//!     .unwrap();
+//!
+//! // The corrector trained on zero residuals — nothing left to correct.
+//! assert!(fitted.corrector().model().params().iter().all(|&c: &f64| c.abs() < 1e-10));
+//! assert!(fitted.corrector().model().intercept().abs() < 1e-10);
+//! ```
+//!
+//! ## Chained SVMs and linear regression
+//!
+//! A linear-kernel `linfa_svm::Svm` captures the overall trend; two
+//! Gaussian-kernel SVMs and a `linfa_linear::LinearRegression` then fit
+//! successive residuals in a four-model chain.
+//!
+//! ```
+//! use linfa::traits::{Fit, Predict};
+//! use linfa::DatasetBase;
+//! use linfa_linear::LinearRegression;
+//! use linfa::composing::residual_chain::{ResidualChain, Stagewise};
+//! use linfa_svm::Svm;
+//! use ndarray::Array;
+//!
+//! // y = sin(x): the linear SVM captures the slope; the RBF SVM captures
+//! // the curvature left in the residuals.
+//! let x = Array::linspace(0f64, 6., 20)
+//!     .into_shape_with_order((20, 1))
+//!     .unwrap();
+//! let y = x.column(0).mapv(f64::sin);
+//! let dataset = DatasetBase::new(x.clone(), y);
+//!
+//! let fitted = Svm::<f64, f64>::params()
+//!     .c_svr(1., None)
+//!     .linear_kernel()
+//!     .chain(
+//!         Svm::<f64, f64>::params()
+//!             .c_svr(10., Some(0.1))
+//!             .gaussian_kernel(1.),
+//!     )
+//!     .chain(LinearRegression::default())
+//!     .chain(
+//!         Svm::<f64, f64>::params()
+//!             .c_svr(10., Some(0.1))
+//!             .gaussian_kernel(3.),
+//!     )
+//!     .fit(&dataset)
+//!     .unwrap();
+//!
+//! let _preds = fitted.predict(&x);
+//! ```
+
+use crate::dataset::{AsTargets, DatasetBase, Records};
+use crate::param_guard::ParamGuard;
+use crate::traits::{Fit, Predict, PredictInplace};
+use crate::Float;
+use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, RawDataClone};
+#[cfg(feature = "serde")]
+use serde_crate::{Deserialize, Serialize};
+use std::ops::{AddAssign, Mul};
+
+// Shorthand for a two-dimensional records array over any data ownership `D`.
+type Arr2<D> = ArrayBase<D, Ix2>;
+
+/// Error returned by [`ResidualChain::fit`].
+///
+/// Wraps the error from whichever of the two model fits failed, keeping them
+/// distinguishable without requiring both models to share the same error type.
+#[derive(Debug, thiserror::Error)]
+pub enum ResidualChainError<E1, E2> {
+    #[error("base model: {0}")]
+    Base(E1),
+    #[error("corrector: {0}")]
+    Corrector(E2),
+    // Satisfies the `Fit` trait's `E: From<crate::Error>` bound.
+    #[error(transparent)]
+    BaseCrate(#[from] crate::Error),
+}
+
+/// A pair of [`Fit`] params that fits sequentially on residuals.
+///
+/// `base` is fit on the original targets; `corrector` (a [`Shrunk`] model) is
+/// fit on the residuals left by `base` and scaled by its shrinkage factor ν.
+/// Prediction sums `base` and the scaled corrector output.
+#[cfg_attr(
+    feature = "serde",
+    derive(Serialize, Deserialize),
+    serde(crate = "serde_crate")
+)]
+#[derive(Debug, Clone, Copy)]
+pub struct ResidualChain<B, C, F> {
+    base: B,
+    corrector: Shrunk<C, F>,
+}
+
+impl<B, C, F> ResidualChain<B, C, F> {
+    pub fn base(&self) -> &B {
+        &self.base
+    }
+    pub fn corrector(&self) -> &Shrunk<C, F> {
+        &self.corrector
+    }
+}
+
+/// Extension trait that adds residual-chain composition methods to any type.
+///
+/// Blanket-implemented for all `Sized` types, so any model params type gains
+/// these methods automatically:
+///
+/// - [`chain`](Stagewise::chain): compose `self` (as the base) with a corrector
+///   that will be trained on the residuals left by `self`. The corrector is used
+///   without shrinkage (ν = 1). Returns a [`ResidualChainParams`] whose `.fit()`
+///   runs both stages. Calls can be chained to build arbitrarily deep sequences.
+/// - [`chain_shrunk`](Stagewise::chain_shrunk): like `chain`, but accepts a
+///   [`Shrunk`]-wrapped corrector so you can control the learning rate ν
+///   explicitly via [`shrink_by`](Stagewise::shrink_by).
+/// - [`shrink_by`](Stagewise::shrink_by): wrap `self` in a [`Shrunk`] with the
+///   given learning rate ν ∈ (0, 1], making it ready to pass as the `corrector`
+///   argument to [`Stagewise::chain_shrunk`].
+///
+/// # Example
+///
+/// ```
+/// use linfa::traits::Fit;
+/// use linfa::DatasetBase;
+/// use linfa_linear::LinearRegression;
+/// use linfa::composing::residual_chain::Stagewise;
+/// use ndarray::{array, Array2};
+///
+/// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64);
+/// let y = array![0., 2., 4., 6., 8.];
+/// let dataset = DatasetBase::new(x.clone(), y);
+///
+/// let fitted = LinearRegression::default()
+///     .chain(LinearRegression::default())
+///     .fit(&dataset)
+///     .unwrap();
+/// ```
+pub trait Stagewise: Sized {
+    /// Compose `self` (as the base model) with a [`Shrunk`]-wrapped `corrector`,
+    /// which will be trained on the residuals left by `self`. Further stages can
+    /// be appended by calling `.chain(...)` or `.chain_shrunk(...)` on the
+    /// returned [`ResidualChainParams`].
+    ///
+    /// Use [`chain`](Stagewise::chain) instead when you don't need to shrink
+    /// the corrector.
+    fn chain_shrunk<C, F>(self, corrector: Shrunk<C, F>) -> ResidualChainParams<Self, C, F>;
+
+    /// Compose `self` (as the base model) with `corrector`, which will be
+    /// trained on the residuals left by `self`. The corrector is used without
+    /// shrinkage (equivalent to `shrink_by(1.0)`). Further stages can be
+    /// appended by calling `.chain(...)` or `.chain_shrunk(...)` on the
+    /// returned [`ResidualChainParams`].
+    ///
+    /// Use [`chain_shrunk`](Stagewise::chain_shrunk) together with
+    /// [`shrink_by`](Stagewise::shrink_by) when you need to control the
+    /// learning rate ν of the corrector explicitly.
+    fn chain<C, F, D, E>(self, corrector: C) -> ResidualChainParams<Self, C, F>
+    where
+        F: Float,
+        D: Data<Elem = F>,
+        C: Fit<Arr2<D>, Array1<F>, E>,
+        E: std::error::Error + From<crate::Error>;
+
+    /// Wrap `self` in a [`Shrunk`] with learning rate `shrinkage` ∈ (0, 1],
+    /// making it ready to pass as the `corrector` argument to [`Stagewise::chain_shrunk`].
+    ///
+    /// The bound `Self: Fit<Arr2<D>, Array1<F>, E>` ensures at compile time
+    /// that the model's element type matches the shrinkage type `F`.
+    fn shrink_by<F, D, E>(self, shrinkage: F) -> Shrunk<Self, F>
+    where
+        F: Float,
+        D: Data<Elem = F>,
+        Self: Fit<Arr2<D>, Array1<F>, E>,
+        E: std::error::Error + From<crate::Error>;
+}
+
+impl<B> Stagewise for B {
+    fn chain_shrunk<C, F>(self, corrector: Shrunk<C, F>) -> ResidualChainParams<Self, C, F> {
+        ResidualChainParams(ResidualChain {
+            base: self,
+            corrector,
+        })
+    }
+    fn chain<C, F, D, E>(self, corrector: C) -> ResidualChainParams<Self, C, F>
+    where
+        F: Float,
+        D: Data<Elem = F>,
+        C: Fit<Arr2<D>, Array1<F>, E>,
+        E: std::error::Error + From<crate::Error>,
+    {
+        self.chain_shrunk(corrector.shrink_by(F::one()))
+    }
+    fn shrink_by<F, D, E>(self, shrinkage: F) -> Shrunk<Self, F>
+    where
+        F: Float,
+        D: Data<Elem = F>,
+        Self: Fit<Arr2<D>, Array1<F>, E>,
+        E: std::error::Error + From<crate::Error>,
+    {
+        Shrunk {
+            model: self,
+            shrinkage,
+        }
+    }
+}
+
+impl<F: Float, F1, F2, D: Data<Elem = F> + RawDataClone, T, E1, E2>
+    Fit<Arr2<D>, T, ResidualChainError<E1, E2>> for ResidualChain<F1, F2, F>
+where
+    Arr2<D>: Records<Elem = F>,
+    F1: Fit<Arr2<D>, T, E1>,
+    for<'a> F1::Object: Predict<&'a Arr2<D>, Array1<F>>,
+    F2: Fit<Arr2<D>, Array1<F>, E2>,
+    T: AsTargets<Elem = F, Ix = Ix1>,
+    E1: std::error::Error + From<crate::Error>,
+    E2: std::error::Error + From<crate::Error>,
+{
+    type Object = ResidualChain<F1::Object, F2::Object, F>;
+
+    fn fit(
+        &self,
+        dataset: &DatasetBase<Arr2<D>, T>,
+    ) -> Result<Self::Object, ResidualChainError<E1, E2>> {
+        let base = self.base.fit(dataset).map_err(ResidualChainError::Base)?;
+
+        // Whatever the base leaves unexplained becomes the corrector's target.
+        let y_pred = base.predict(dataset.records());
+        let residuals = &dataset.targets().as_targets() - &y_pred;
+
+        let residual_dataset = DatasetBase::new(dataset.records().clone(), residuals);
+        let corrector_model = self
+            .corrector
+            .model
+            .fit(&residual_dataset)
+            .map_err(ResidualChainError::Corrector)?;
+
+        Ok(ResidualChain {
+            base,
+            corrector: Shrunk {
+                model: corrector_model,
+                shrinkage: self.corrector.shrinkage,
+            },
+        })
+    }
+}
+
+impl<F: Float, R1, R2, D: Data<Elem = F>> PredictInplace<Arr2<D>, Array1<F>>
+    for ResidualChain<R1, R2, F>
+where
+    R1: PredictInplace<Arr2<D>, Array1<F>>,
+    R2: PredictInplace<Arr2<D>, Array1<F>>,
+{
+    fn predict_inplace<'a>(&'a self, x: &'a Arr2<D>, y: &mut Array1<F>) {
+        self.base.predict_inplace(x, y);
+        // Sum in the corrector's contribution, scaled by its shrinkage ν.
+        y.add_assign(
+            &self
+                .corrector
+                .model
+                .predict(x)
+                .mul(self.corrector.shrinkage),
+        );
+    }
+
+    fn default_target(&self, x: &Arr2<D>) -> Array1<F> {
+        Array1::zeros(x.nrows())
+    }
+}
+
+/// A model (params or fitted) paired with a shrinkage factor ν ∈ (0, 1].
+///
+/// Used in two roles:
+/// - **Before fitting**: `Shrunk<C, F>` wraps corrector params `C`; created by
+///   [`Stagewise::shrink_by`] and stored in [`ResidualChain`] / [`ResidualChainParams`].
+/// - **After fitting**: `Shrunk<C::Object, F>` wraps the fitted corrector model;
+///   prediction scales the corrector's output by ν before summing with the base.
+#[cfg_attr(
+    feature = "serde",
+    derive(Serialize, Deserialize),
+    serde(crate = "serde_crate")
+)]
+#[derive(Debug, Clone, Copy)]
+pub struct Shrunk<M, F> {
+    model: M,
+    shrinkage: F,
+}
+
+impl<M, F: Float> Shrunk<M, F> {
+    pub fn model(&self) -> &M {
+        &self.model
+    }
+    pub fn shrinkage(&self) -> F {
+        self.shrinkage
+    }
+    /// Set the shrinkage factor. Validation happens when the containing
+    /// [`ResidualChainParams`] is checked via [`ParamGuard`].
+    pub fn with_shrinkage(mut self, shrinkage: F) -> Self {
+        self.shrinkage = shrinkage;
+        self
+    }
+}
+
+/// Unvalidated [`ResidualChain`] parameters returned by [`Stagewise::chain_shrunk`].
+///
+/// Call `.fit()` to validate and fit in one step — the [`ParamGuard`] blanket
+/// impl runs `check_ref` first, which verifies that the outermost corrector's
+/// shrinkage factor is in (0, 1]. Inner chains validate lazily when their own
+/// `.fit()` is called. You can also call `.check()` / `.check_unwrap()` to
+/// validate explicitly.
+///
+/// To set an explicit shrinkage factor on the corrector use
+/// [`Shrunk::with_shrinkage`]:
+///
+/// ```
+/// use linfa::traits::{Fit, Predict};
+/// use linfa::DatasetBase;
+/// use linfa_linear::LinearRegression;
+/// use linfa::composing::residual_chain::{Shrunk, Stagewise};
+/// use ndarray::{array, Array2};
+///
+/// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64);
+/// let y = array![0., 2., 4., 6., 8.];
+/// let dataset = DatasetBase::new(x.clone(), y);
+///
+/// // The corrector's contribution is scaled by 0.1.
+/// let fitted = LinearRegression::default()
+///     .chain_shrunk(LinearRegression::default().shrink_by(0.1))
+///     .fit(&dataset)
+///     .unwrap();
+/// ```
+#[cfg_attr(
+    feature = "serde",
+    derive(Serialize, Deserialize),
+    serde(crate = "serde_crate")
+)]
+#[derive(Debug, Clone, Copy)]
+pub struct ResidualChainParams<B, C, F>(ResidualChain<B, C, F>);
+
+impl<B, C, F: Float> ParamGuard for ResidualChainParams<B, C, F> {
+    type Checked = ResidualChain<B, C, F>;
+    type Error = crate::Error;
+
+    fn check_ref(&self) -> Result<&ResidualChain<B, C, F>, crate::Error> {
+        let v = self.0.corrector.shrinkage;
+        if v > F::zero() && v <= F::one() {
+            Ok(&self.0)
+        } else {
+            Err(crate::Error::Parameters(format!(
+                "shrinkage must be in (0, 1], got {v}"
+            )))
+        }
+    }
+
+    fn check(self) -> Result<ResidualChain<B, C, F>, crate::Error> {
+        self.check_ref()?;
+        Ok(self.0)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::error::Error as LinfaError;
+    use crate::DatasetBase;
+    use ndarray::{array, Array1, Array2};
+
+    #[derive(thiserror::Error, Debug)]
+    #[error("dummy error")]
+    struct DummyError(#[from] LinfaError);
+
+    // Params that fits by recording the mean of the targets.
+    struct MeanParams;
+
+    // Model that predicts the mean it saw during fit.
+    struct MeanModel(f64);
+
+    impl Fit<Array2<f64>, Array1<f64>, DummyError> for MeanParams {
+        type Object = MeanModel;
+        fn fit(
+            &self,
+            dataset: &DatasetBase<Array2<f64>, Array1<f64>>,
+        ) -> Result<Self::Object, DummyError> {
+            let mean = dataset.targets().iter().sum::<f64>() / dataset.targets().len() as f64;
+            Ok(MeanModel(mean))
+        }
+    }
+
+    impl PredictInplace<Array2<f64>, Array1<f64>> for MeanModel {
+        fn predict_inplace(&self, x: &Array2<f64>, y: &mut Array1<f64>) {
+            y.assign(&Array1::from_elem(x.nrows(), self.0));
+        }
+        fn default_target(&self, x: &Array2<f64>) -> Array1<f64> {
+            Array1::zeros(x.nrows())
+        }
+    }
+
+    #[test]
+    fn corrector_is_fit_on_residuals() {
+        // targets = [1, 3]. base sees mean=2, predicts 2 for all.
+        // residuals = [1-2, 3-2] = [-1, 1]. corrector sees mean=0.
+        let model = MeanParams.chain(MeanParams);
+        let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]);
+        let fitted = model.fit(&dataset).unwrap();
+
+        assert_eq!(fitted.base().0, 2.0); // mean of [1, 3]
+        assert_eq!(fitted.corrector().model().0, 0.0); // mean of residuals [-1, 1]
+    }
+
+    #[test]
+    fn predict_sums_both_models() {
+        // base predicts 2.0, corrector predicts 0.0 → sum = 2.0
+        let model = MeanParams.chain(MeanParams);
+        let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]);
+        let fitted = model.fit(&dataset).unwrap();
+
+        let records = array![[0.0_f64], [1.0]];
+        let predictions = fitted.predict(&records);
+        assert_eq!(predictions, array![2.0, 2.0]);
+    }
+
+    #[test]
+    fn predict_recovers_targets_when_residuals_fit_perfectly() {
+        // If the corrector perfectly fits the residuals, the combined prediction = original targets.
+        struct FixedParams(f64);
+        struct FixedModel(f64);
+
+        impl Fit<Array2<f64>, Array1<f64>, DummyError> for FixedParams {
+            type Object = FixedModel;
+            fn fit(
+                &self,
+                _dataset: &DatasetBase<Array2<f64>, Array1<f64>>,
+            ) -> Result<Self::Object, DummyError> {
+                Ok(FixedModel(self.0))
+            }
+        }
+
+        impl PredictInplace<Array2<f64>, Array1<f64>> for FixedModel {
+            fn predict_inplace(&self, x: &Array2<f64>, y: &mut Array1<f64>) {
+                y.assign(&Array1::from_elem(x.nrows(), self.0));
+            }
+            fn default_target(&self, x: &Array2<f64>) -> Array1<f64> {
+                Array1::zeros(x.nrows())
+            }
+        }
+
+        // base predicts 3.0, corrector predicts 1.0 → sum = 4.0
+        let model = FixedParams(3.0)
+            .chain(FixedParams(1.0))
+            .chain(FixedParams(0.0));
+        let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]);
+        let fitted = model.fit(&dataset).unwrap();
+
+        let predictions = fitted.predict(&array![[0.0_f64], [1.0]]);
+        assert_eq!(predictions, array![4.0, 4.0]);
+    }
+
+    #[test]
+    fn deep_chain_accessors() {
+        let model = MeanParams
+            .chain(MeanParams)
+            .chain(MeanParams)
+            .chain(MeanParams);
+        let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]);
+        let fitted = model.fit(&dataset).unwrap();
+
+        assert_eq!(fitted.base().base().base().0, 2.0); // params trained on original targets
+    }
+
+    #[test]
+    fn shrinkage_scales_corrector_prediction() {
+        // base predicts mean=2.0, corrector predicts mean=0.0 (residuals [-1,1]).
+        // With shrinkage=0.5, corrector contributes 0.5*0.0 = 0.0 → total = 2.0.
+        let model = MeanParams.chain_shrunk(MeanParams.shrink_by(0.5));
+        let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]);
+        let fitted = model.fit(&dataset).unwrap();
+
+        let preds = fitted.predict(&array![[0.0_f64], [1.0]]);
+        assert_eq!(preds, array![2.0, 2.0]);
+        assert_eq!(fitted.corrector().shrinkage(), 0.5);
+    }
+
+    #[test]
+    fn shrinkage_corrector_sees_scaled_residuals() {
+        // base predicts 3.0 always. targets = [4.0, 4.0].
+        // residuals = [1.0, 1.0]. corrector (mean) sees mean=1.0.
+        // With shrinkage=0.5: prediction = 3.0 + 0.5*1.0 = 3.5.
+        struct FixedParams(f64);
+        struct FixedModel(f64);
+
+        impl Fit<Array2<f64>, Array1<f64>, DummyError> for FixedParams {
+            type Object = FixedModel;
+            fn fit(
+                &self,
+                _dataset: &DatasetBase<Array2<f64>, Array1<f64>>,
+            ) -> Result<Self::Object, DummyError> {
+                Ok(FixedModel(self.0))
+            }
+        }
+
+        impl PredictInplace<Array2<f64>, Array1<f64>> for FixedModel {
+            fn predict_inplace(&self, x: &Array2<f64>, y: &mut Array1<f64>) {
+                y.assign(&Array1::from_elem(x.nrows(), self.0));
+            }
+            fn default_target(&self, x: &Array2<f64>) -> Array1<f64> {
+                Array1::zeros(x.nrows())
+            }
+        }
+
+        let model = FixedParams(3.0).chain_shrunk(MeanParams.shrink_by(0.5));
+        let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]);
+        let fitted = model.fit(&dataset).unwrap();
+
+        let preds = fitted.predict(&array![[0.0_f64], [1.0]]);
+        // corrector saw residuals [1.0, 1.0], mean=1.0, shrunk by 0.5 → 0.5
+        assert!((preds[0] - 3.5_f64).abs() < 1e-10);
+    }
+
+    #[test]
+    fn shrinkage_invalid_value_returns_error() {
+        let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]);
+        let model = MeanParams.chain_shrunk(MeanParams.shrink_by(0.0));
+        assert!(model.fit(&dataset).is_err());
+
+        let model = MeanParams.chain_shrunk(MeanParams.shrink_by(1.5));
+        assert!(model.fit(&dataset).is_err());
+    }
+}