diff --git a/genetic-rs-common/src/builtin/eliminator.rs b/genetic-rs-common/src/builtin/eliminator.rs index 24b15cd..27fea10 100644 --- a/genetic-rs-common/src/builtin/eliminator.rs +++ b/genetic-rs-common/src/builtin/eliminator.rs @@ -35,6 +35,7 @@ impl + Send + Sync> FeatureBoundedFitne /// A trait for observing fitness scores. This can be used to implement things like logging or statistics collection. pub trait FitnessObserver<G> { /// Observes the fitness scores of a generation of genomes. + /// The input slice is always sorted in descending order by fitness (highest fitness first). fn observe(&mut self, fitnesses: &[(G, f32)]); /// Layers this observer with another, calling both in sequence. @@ -732,42 +733,60 @@ mod speciation { fn eliminate(&mut self, genomes: Vec<G>) -> Vec<G> { let (raw, divided) = self.calculate_fitnesses(&genomes); - let mut sorted: Vec<(G, f32, f32)> = genomes + let mut data: Vec<(G, f32, f32)> = genomes .into_iter() .enumerate() .map(|(i, g)| (g, raw[i], divided[i])) .collect(); - sorted.sort_by(|(_, _, a), (_, _, b)| b.partial_cmp(a).unwrap()); - let median_index = (sorted.len() as f32) * self.inner.threshold; + let median_index = (data.len() as f32) * self.inner.threshold; + + // Sort by raw fitness so observer inputs are ordered by fitness descending. + data.sort_by(|(_, a, _), (_, b, _)| b.partial_cmp(a).unwrap()); + + // Split raw-sorted pairs for the observer while retaining divided values. + let (observer_pairs, divided_vals): (Vec<(G, f32)>, Vec<f32>) = data + .into_iter() + .map(|(g, raw, div)| ((g, raw), div)) + .unzip(); - let mut observer_pairs: Vec<(G, f32)> = - sorted.into_iter().map(|(g, raw, _)| (g, raw)).collect(); self.inner.observer.observe(&observer_pairs); - observer_pairs.truncate(median_index as usize + 1); - observer_pairs.into_iter().map(|(g, _)| g).collect() + // Re-sort by divided fitness and truncate for speciation-aware elimination. 
+ let mut with_divided: Vec<_> = observer_pairs.into_iter().zip(divided_vals).collect(); + with_divided.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap()); + with_divided.truncate(median_index as usize + 1); + with_divided.into_iter().map(|((g, _), _)| g).collect() } #[cfg(feature = "rayon")] fn eliminate(&mut self, genomes: Vec<G>) -> Vec<G> { let (raw, divided) = self.calculate_fitnesses(&genomes); - let mut sorted: Vec<(G, f32, f32)> = genomes + let mut data: Vec<(G, f32, f32)> = genomes .into_iter() .enumerate() .map(|(i, g)| (g, raw[i], divided[i])) .collect(); - sorted.sort_by(|(_, _, a), (_, _, b)| b.partial_cmp(a).unwrap()); - let median_index = (sorted.len() as f32) * self.inner.threshold; + let median_index = (data.len() as f32) * self.inner.threshold; + + // Sort by raw fitness so observer inputs are ordered by fitness descending. + data.sort_by(|(_, a, _), (_, b, _)| b.partial_cmp(a).unwrap()); + + // Split raw-sorted pairs for the observer while retaining divided values. + let (observer_pairs, divided_vals): (Vec<(G, f32)>, Vec<f32>) = data + .into_iter() + .map(|(g, raw, div)| ((g, raw), div)) + .unzip(); - let mut observer_pairs: Vec<(G, f32)> = - sorted.into_iter().map(|(g, raw, _)| (g, raw)).collect(); self.inner.observer.observe(&observer_pairs); - observer_pairs.truncate(median_index as usize + 1); - observer_pairs.into_par_iter().map(|(g, _)| g).collect() + // Re-sort by divided fitness and truncate for speciation-aware elimination. 
+ let mut with_divided: Vec<_> = observer_pairs.into_iter().zip(divided_vals).collect(); + with_divided.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap()); + with_divided.truncate(median_index as usize + 1); + with_divided.into_par_iter().map(|((g, _), _)| g).collect() } } } diff --git a/genetic-rs/tests/speciation.rs b/genetic-rs/tests/speciation.rs index 89a2562..a18af19 100644 --- a/genetic-rs/tests/speciation.rs +++ b/genetic-rs/tests/speciation.rs @@ -266,6 +266,61 @@ fn speciation_protects_rare_species() { ); } +/// The fitness observer on [`SpeciatedFitnessEliminator`] must receive fitness scores +/// sorted in descending order by raw (pre-division) fitness. +/// +/// Setup (deliberately unsorted input — low fitness genome placed first): +/// - 1 genome of class 1 with val = 0.5 → raw fitness = 0.5, divided = 0.5 +/// - 4 genomes of class 0 with val = 1.0 → raw fitness = 1.0, divided = 0.25 +/// +/// If the eliminator forwards the input order unchanged, the observer would see +/// `[0.5, 1.0, 1.0, 1.0, 1.0]` (not sorted). If sorted by divided fitness the +/// class-1 genome (divided = 0.5) would come first, yielding `[0.5, 1.0, …]`. +/// Only when sorted by raw fitness does the observer see `[1.0, 1.0, 1.0, 1.0, 0.5]`. +#[test] +fn observer_receives_fitness_sorted_by_raw_descending() { + use std::sync::{Arc, Mutex}; + + let observed: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new())); + let observed_clone = Arc::clone(&observed); + + let observer = move |fitnesses: &[(Genome, f32)]| { + let mut v = observed_clone.lock().unwrap(); + v.extend(fitnesses.iter().map(|(_, f)| *f)); + }; + + // Put the low-fitness genome first so the input is intentionally unsorted. 
+ let mut genomes = vec![Genome { class: 1, val: 0.5 }]; + genomes.extend((0..4).map(|_| Genome { class: 0, val: 1.0 })); + + let mut eliminator = SpeciatedFitnessEliminator::new(fitness, 0.5, 0.5, observer, ()); + eliminator.eliminate(genomes); + + let scores = observed.lock().unwrap(); + assert_eq!(scores.len(), 5, "observer must receive all genomes"); + + // Scores must be in non-increasing order (sorted descending by raw fitness). + for window in scores.windows(2) { + assert!( + window[0] >= window[1], + "observer inputs must be sorted descending by fitness, but got: {:?}", + *scores + ); + } + + // The full expected sequence is [1.0, 1.0, 1.0, 1.0, 0.5]. + assert!( + (scores[0] - 1.0_f32).abs() < 1e-6, + "first fitness must be the highest raw fitness (1.0), but got: {:?}", + *scores + ); + assert!( + (scores[4] - 0.5_f32).abs() < 1e-6, + "last fitness must be the lowest raw fitness (0.5), but got: {:?}", + *scores + ); +} + /// The fitness observer on [`SpeciatedFitnessEliminator`] must receive the raw /// (pre-division) fitness values, not the values after they have been divided by /// the number of genomes in the species.