diff --git a/genetic-rs-common/src/builtin/eliminator.rs b/genetic-rs-common/src/builtin/eliminator.rs
index 9130359..24b15cd 100644
--- a/genetic-rs-common/src/builtin/eliminator.rs
+++ b/genetic-rs-common/src/builtin/eliminator.rs
@@ -644,58 +644,84 @@ mod speciation {
             }
         }
 
-        /// Calculates the fitness of each genome, dividing by the number of genomes in its species, and sorts them by fitness.
-        /// Returns a vector of tuples containing the genome and its fitness score.
+        /// Computes raw and species-divided fitness for every genome.
+        ///
+        /// Returns `(raw_fitnesses, divided_fitnesses)` where both vecs are indexed
+        /// the same way as `genomes`. The divided value is used for elimination
+        /// (to balance species pressure); the raw value is what observers see.
         #[cfg(not(feature = "rayon"))]
-        pub fn calculate_and_sort(&self, genomes: Vec<G>) -> Vec<(G, f32)> {
+        fn calculate_fitnesses(&self, genomes: &[G]) -> (Vec<f32>, Vec<f32>) {
             let population =
-                SpeciatedPopulation::from_genomes(&genomes, self.speciation_threshold, &self.ctx);
-            let mut fitnesses = vec![0.0; genomes.len()];
+                SpeciatedPopulation::from_genomes(genomes, self.speciation_threshold, &self.ctx);
+            let mut raw = vec![0.0_f32; genomes.len()];
+            let mut divided = vec![0.0_f32; genomes.len()];
 
             for species in population.species() {
                 let len = species.len() as f32;
                 debug_assert!(len != 0.0);
                 for &index in species {
-                    let genome = &genomes[index];
-                    let fitness = self.inner.fitness_fn.fitness(genome);
-                    if fitness < 0.0 {
-                        fitnesses[index] = fitness * len;
+                    let fitness = self.inner.fitness_fn.fitness(&genomes[index]);
+                    raw[index] = fitness;
+                    divided[index] = if fitness < 0.0 {
+                        fitness * len
                     } else {
-                        fitnesses[index] = fitness / len;
-                    }
+                        fitness / len
+                    };
                 }
             }
 
-            let mut fitnesses: Vec<(G, f32)> = genomes.into_iter().zip(fitnesses).collect();
-            fitnesses.sort_by(|(_a, afit), (_b, bfit)| bfit.partial_cmp(afit).unwrap());
-            fitnesses
+            (raw, divided)
         }
 
-        /// Calculates the fitness of each genome, dividing by the number of genomes in its species, and sorts them by fitness.
-        /// Returns a vector of tuples containing the genome and its fitness score.
+        /// Computes raw and species-divided fitness for every genome (parallel version).
+        ///
+        /// Species membership is determined sequentially first (greedy clustering), then
+        /// fitness functions are evaluated in parallel using rayon.
         #[cfg(feature = "rayon")]
-        pub fn calculate_and_sort(&self, genomes: Vec<G>) -> Vec<(G, f32)> {
+        fn calculate_fitnesses(&self, genomes: &[G]) -> (Vec<f32>, Vec<f32>) {
             let population =
-                SpeciatedPopulation::from_genomes(&genomes, self.speciation_threshold, &self.ctx);
-
-            let mut fitnesses = vec![0.0; genomes.len()];
+                SpeciatedPopulation::from_genomes(genomes, self.speciation_threshold, &self.ctx);
+            let mut species_lens = vec![0.0_f32; genomes.len()];
 
             for species in population.species() {
                 let len = species.len() as f32;
                 debug_assert!(len != 0.0);
                 for &index in species {
-                    let genome = &genomes[index];
-                    let fitness = self.inner.fitness_fn.fitness(genome);
-                    if fitness < 0.0 {
-                        fitnesses[index] = fitness * len;
-                    } else {
-                        fitnesses[index] = fitness / len;
-                    }
+                    species_lens[index] = len;
                 }
             }
 
-            let mut result: Vec<(G, f32)> = genomes.into_iter().zip(fitnesses).collect();
-            result.sort_by(|(_a, afit), (_b, bfit)| bfit.partial_cmp(afit).unwrap());
+            let fitness_fn = &self.inner.fitness_fn;
+            // Evaluate in parallel and split the (raw, divided) pairs into the two
+            // result vectors in a single pass; both keep the order of `genomes`.
+            genomes
+                .par_iter()
+                .zip(species_lens.par_iter())
+                .map(|(genome, &len)| {
+                    let fitness = fitness_fn.fitness(genome);
+                    let divided = if fitness < 0.0 {
+                        fitness * len
+                    } else {
+                        fitness / len
+                    };
+                    (fitness, divided)
+                })
+                .unzip()
+        }
+
+        /// Calculates the species-adjusted fitness of each genome and sorts the genomes by it,
+        /// best first.
+        ///
+        /// The score paired with each genome is its raw fitness divided by the size of its
+        /// species (or multiplied by it when the raw fitness is negative).
+        ///
+        /// # Panics
+        ///
+        /// Panics if a fitness value is NaN, because the scores cannot be ordered.
+        pub fn calculate_and_sort(&self, genomes: Vec<G>) -> Vec<(G, f32)> {
+            let (_, divided) = self.calculate_fitnesses(&genomes);
+            let mut result: Vec<(G, f32)> = genomes.into_iter().zip(divided).collect();
+            result.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
             result
         }
     }
@@ -708,20 +734,54 @@ mod speciation {
     {
         #[cfg(not(feature = "rayon"))]
         fn eliminate(&mut self, genomes: Vec<G>) -> Vec<G> {
-            let mut fitnesses = self.calculate_and_sort(genomes);
-            self.inner.observer.observe(&fitnesses);
-            let median_index = (fitnesses.len() as f32) * self.inner.threshold;
-            fitnesses.truncate(median_index as usize + 1);
-            fitnesses.into_iter().map(|(g, _)| g).collect()
+            let (raw, divided) = self.calculate_fitnesses(&genomes);
+
+            // Rank by species-adjusted fitness (best first) so elimination stays balanced
+            // across species, while carrying the raw fitness along for the observer.
+            let mut sorted: Vec<(G, f32, f32)> = genomes
+                .into_iter()
+                .zip(raw)
+                .zip(divided)
+                .map(|((genome, raw_fit), divided_fit)| (genome, raw_fit, divided_fit))
+                .collect();
+            sorted.sort_by(|(_, _, a), (_, _, b)| b.partial_cmp(a).unwrap());
+
+            let median_index = (sorted.len() as f32) * self.inner.threshold;
+
+            // The observer sees raw fitness, but in species-adjusted rank order: the
+            // slice is not necessarily sorted by the values it contains.
+            let mut observer_pairs: Vec<(G, f32)> =
+                sorted.into_iter().map(|(g, raw, _)| (g, raw)).collect();
+            self.inner.observer.observe(&observer_pairs);
+
+            observer_pairs.truncate(median_index as usize + 1);
+            observer_pairs.into_iter().map(|(g, _)| g).collect()
         }
 
         #[cfg(feature = "rayon")]
         fn eliminate(&mut self, genomes: Vec<G>) -> Vec<G> {
-            let mut fitnesses = self.calculate_and_sort(genomes);
-            self.inner.observer.observe(&fitnesses);
-            let median_index = (fitnesses.len() as f32) * self.inner.threshold;
-            fitnesses.truncate(median_index as usize + 1);
-            fitnesses.into_par_iter().map(|(g, _)| g).collect()
+            let (raw, divided) = self.calculate_fitnesses(&genomes);
+
+            // Rank by species-adjusted fitness (best first) so elimination stays balanced
+            // across species, while carrying the raw fitness along for the observer.
+            let mut sorted: Vec<(G, f32, f32)> = genomes
+                .into_iter()
+                .zip(raw)
+                .zip(divided)
+                .map(|((genome, raw_fit), divided_fit)| (genome, raw_fit, divided_fit))
+                .collect();
+            sorted.sort_by(|(_, _, a), (_, _, b)| b.partial_cmp(a).unwrap());
+
+            let median_index = (sorted.len() as f32) * self.inner.threshold;
+
+            // The observer sees raw fitness, but in species-adjusted rank order: the
+            // slice is not necessarily sorted by the values it contains.
+            let mut observer_pairs: Vec<(G, f32)> =
+                sorted.into_iter().map(|(g, raw, _)| (g, raw)).collect();
+            self.inner.observer.observe(&observer_pairs);
+
+            observer_pairs.truncate(median_index as usize + 1);
+            observer_pairs.into_par_iter().map(|(g, _)| g).collect()
         }
     }
 }
diff --git a/genetic-rs/examples/speciation.rs b/genetic-rs/examples/speciation.rs
index a9f3635..bab9a72 100644
--- a/genetic-rs/examples/speciation.rs
+++ b/genetic-rs/examples/speciation.rs
@@ -111,9 +111,9 @@ fn fitness(genome: &MyGenome) -> f32 {
 }
 
 fn print_fitnesses(fitnesses: &[(MyGenome, f32)]) {
-    // note that with SpeciatedFitnessEliminator,
-    // these values are divided by the number of genomes in the species if positive,
-    // multiplied if negative.
+    // note that with SpeciatedFitnessEliminator, these are raw fitness values,
+    // but the slice is ordered by species-adjusted fitness, so the first/middle/last
+    // entries are not necessarily the highest/median/lowest raw scores.
     let hi = fitnesses[0].1;
     let med = fitnesses[fitnesses.len() / 2].1;
     let lo = fitnesses[fitnesses.len() - 1].1;
diff --git a/genetic-rs/tests/speciation.rs b/genetic-rs/tests/speciation.rs
index 751a7ca..89a2562 100644
--- a/genetic-rs/tests/speciation.rs
+++ b/genetic-rs/tests/speciation.rs
@@ -265,3 +265,52 @@ fn speciation_protects_rare_species() {
         "the rare species genome must survive despite lower raw fitness"
     );
 }
+
+/// The fitness observer on [`SpeciatedFitnessEliminator`] must receive the raw
+/// (pre-division) fitness values, not the values after they have been divided by
+/// the number of genomes in the species.
+///
+/// Setup:
+/// - 4 genomes of class 0 with val = 1.0 → raw fitness = 1.0, divided = 0.25
+/// - 1 genome of class 1 with val = 0.5 → raw fitness = 0.5, divided = 0.5
+///
+/// If the observer sees raw values, it must observe 1.0 and 0.5 among the scores.
+/// If it sees divided values, it would observe 0.25 instead of 1.0 — the test
+/// would fail in that case.
+#[test]
+fn observer_receives_pre_division_fitness() {
+    use std::sync::{Arc, Mutex};
+
+    let observed: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));
+    let observed_clone = Arc::clone(&observed);
+
+    let observer = move |fitnesses: &[(Genome, f32)]| {
+        let mut v = observed_clone.lock().unwrap();
+        v.extend(fitnesses.iter().map(|(_, f)| *f));
+    };
+
+    let mut genomes: Vec<Genome> = (0..4).map(|_| Genome { class: 0, val: 1.0 }).collect();
+    genomes.push(Genome { class: 1, val: 0.5 });
+
+    let mut eliminator = SpeciatedFitnessEliminator::new(fitness, 0.5, 0.5, observer, ());
+    eliminator.eliminate(genomes);
+
+    let scores = observed.lock().unwrap();
+    // Raw scores: 1.0 (×4) and 0.5 (×1). Divided would be 0.25 and 0.5 — 0.25 must NOT appear.
+    assert_eq!(scores.len(), 5, "observer must see every genome once, but got: {:?}", *scores);
+    assert!(
+        scores.iter().any(|&f| (f - 1.0_f32).abs() < 1e-6),
+        "observer must see the raw fitness 1.0, but got: {:?}",
+        *scores,
+    );
+    assert!(
+        scores.iter().any(|&f| (f - 0.5_f32).abs() < 1e-6),
+        "observer must see the raw fitness 0.5, but got: {:?}",
+        *scores,
+    );
+    assert!(
+        !scores.iter().any(|&f| (f - 0.25_f32).abs() < 1e-6),
+        "observer must NOT see the divided fitness 0.25 (pre-division values expected), but got: {:?}",
+        *scores,
+    );
+}