diff --git a/genetic-rs-common/src/builtin/eliminator.rs b/genetic-rs-common/src/builtin/eliminator.rs index 5309d4e..b2a60bc 100644 --- a/genetic-rs-common/src/builtin/eliminator.rs +++ b/genetic-rs-common/src/builtin/eliminator.rs @@ -654,10 +654,10 @@ mod speciation { SpeciatedPopulation::from_genomes(&genomes, self.speciation_threshold, &self.ctx); let mut fitnesses = vec![0.0; genomes.len()]; - for species in population.species { + for species in population.species() { let len = species.len() as f32; debug_assert!(len != 0.0); - for index in species { + for &index in species { let genome = &genomes[index]; let fitness = self.inner.fitness_fn.fitness(genome); if fitness < 0.0 { @@ -682,10 +682,10 @@ mod speciation { let mut fitnesses = vec![0.0; genomes.len()]; - for species in population.species { + for species in population.species() { let len = species.len() as f32; debug_assert!(len != 0.0); - for index in species { + for &index in species { let genome = &genomes[index]; let fitness = self.inner.fitness_fn.fitness(genome); if fitness < 0.0 { diff --git a/genetic-rs-common/src/builtin/repopulator.rs b/genetic-rs-common/src/builtin/repopulator.rs index f4c3c88..67c31c3 100644 --- a/genetic-rs-common/src/builtin/repopulator.rs +++ b/genetic-rs-common/src/builtin/repopulator.rs @@ -283,7 +283,7 @@ mod speciation { // if all species are isolated, we fall back to the inner crossover repopulator to avoid an infinite loop. if matches!(self.action_if_isolated, ActionIfIsolated::DoNothing) - && !population.species.iter().any(|s| s.len() >= 2) + && !population.species().iter().any(|s| s.len() >= 2) { self.inner.repopulate(genomes, target_size); return; @@ -295,7 +295,7 @@ mod speciation { let mut i = 0; while i < amount_to_make { let (species_i, genome_i) = species_cycle.next().unwrap(); - let species = &population.species[species_i]; + let species = &population.species()[species_i]; let parent1 = &genomes[genome_i]; if species.len() < 2 { match self.action_if_isolated { @@ -314,7 +314,7 @@ mod speciation { ActionIfIsolated::CrossoverSimilarSpecies => { let mut best_species_i = 0; let mut best_divergence = f32::MAX; - for (j, species) in population.species.iter().enumerate() { + for (j, species) in population.species().iter().enumerate() { if j == species_i || species.is_empty() { continue; } @@ -326,7 +326,7 @@ mod speciation { } } - let best_species = &population.species[best_species_i]; + let best_species = &population.species()[best_species_i]; let j = rng.random_range(0..best_species.len()); let parent2 = &genomes[best_species[j]]; let child = parent1.crossover( diff --git a/genetic-rs-common/src/speciation.rs b/genetic-rs-common/src/speciation.rs index 6dd8fcb..402ab18 100644 --- a/genetic-rs-common/src/speciation.rs +++ b/genetic-rs-common/src/speciation.rs @@ -19,8 +19,9 @@ pub trait Speciated { pub struct SpeciatedPopulation { /// The species in this population. Each species is a vector of indices into the original genome vector. /// The first genome in a species is its representation (i.e. the one that gets compared to other genomes to determine - /// if they belong in the species) - pub species: Vec>, + /// if they belong in the species). + /// Invariant: every inner `Vec` is non-empty. + species: Vec>, /// The threshold used to determine if a genome belongs in a species. If the divergence between a genome and the representative genome /// of a species is less than this threshold, then the genome belongs in that species. @@ -28,6 +29,20 @@ pub struct SpeciatedPopulation { } impl SpeciatedPopulation { + /// Creates a new, empty [`SpeciatedPopulation`] with the given threshold. + pub fn new(threshold: f32) -> Self { + Self { + species: Vec::new(), + threshold, + } + } + + /// Returns the species in this population. + /// Each inner slice is guaranteed to be non-empty. + pub fn species(&self) -> &[Vec] { + &self.species + } + /// Inserts a genome into the speciated population. /// Returns whether a new species was created by this insertion. pub fn insert_genome( @@ -48,10 +63,7 @@ impl SpeciatedPopulation { /// Note that this can be O(n^2) worst case, but is typically much faster in practice, /// especially if the genome structure doesn't mutate often. pub fn from_genomes(population: &[G], threshold: f32, ctx: &G::Context) -> Self { - let mut speciated_population = SpeciatedPopulation { - species: Vec::new(), - threshold, - }; + let mut speciated_population = SpeciatedPopulation::new(threshold); for index in 0..population.len() { speciated_population.insert_genome(index, population, ctx); } diff --git a/genetic-rs/tests/speciation.rs b/genetic-rs/tests/speciation.rs index 448ccf4..751a7ca 100644 --- a/genetic-rs/tests/speciation.rs +++ b/genetic-rs/tests/speciation.rs @@ -69,9 +69,9 @@ fn fitness(g: &Genome) -> f32 { fn identical_genomes_in_same_species() { let genomes: Vec = (0..5).map(|_| Genome { class: 0, val: 0.0 }).collect(); let pop = SpeciatedPopulation::from_genomes(&genomes, 0.5, &()); - assert_eq!(pop.species.len(), 1, "expected a single species"); + assert_eq!(pop.species().len(), 1, "expected a single species"); assert_eq!( - pop.species[0].len(), + pop.species()[0].len(), 5, "all genomes must belong to the single species" ); @@ -82,7 +82,7 @@ fn identical_genomes_in_same_species() { fn different_class_genomes_in_different_species() { let genomes: Vec = (0..4).map(|i| Genome { class: i, val: 0.0 }).collect(); let pop = SpeciatedPopulation::from_genomes(&genomes, 0.5, &()); - assert_eq!(pop.species.len(), 4); + assert_eq!(pop.species().len(), 4); } /// With a threshold > 1.0 (larger than the max divergence) all genomes, @@ -97,7 +97,7 @@ fn high_threshold_groups_all_genomes() { .collect(); // Divergence is at most 1.0; with threshold 1.5 everything is "close enough". let pop = SpeciatedPopulation::from_genomes(&genomes, 1.5, &()); - assert_eq!(pop.species.len(), 1, "all genomes must be in one species"); + assert_eq!(pop.species().len(), 1, "all genomes must be in one species"); } /// Every genome index must appear in exactly one species. @@ -113,7 +113,7 @@ fn every_genome_index_appears_exactly_once() { let pop = SpeciatedPopulation::from_genomes(&genomes, 0.5, &()); let mut seen = vec![false; n]; - for species in &pop.species { + for species in pop.species() { for &idx in species { assert!( !seen[idx], @@ -139,13 +139,11 @@ fn insert_genome_creates_new_species_for_novel_genome() { Genome { class: 0, val: 0.0 }, Genome { class: 1, val: 0.0 }, // divergence 1.0 > threshold 0.5 → new species ]; - let mut pop = SpeciatedPopulation { - species: vec![vec![0]], - threshold: 0.5, - }; + let mut pop = SpeciatedPopulation::new(0.5); + pop.insert_genome(0, &genomes, &()); let created_new = pop.insert_genome(1, &genomes, &()); assert!(created_new, "expected a new species to be created"); - assert_eq!(pop.species.len(), 2); + assert_eq!(pop.species().len(), 2); } /// Inserting a genome from an existing class must join that species. @@ -155,17 +153,15 @@ fn insert_genome_joins_existing_species_for_similar_genome() { Genome { class: 0, val: 0.0 }, Genome { class: 0, val: 1.0 }, // divergence 0.0 < threshold 0.5 → joins species ]; - let mut pop = SpeciatedPopulation { - species: vec![vec![0]], - threshold: 0.5, - }; + let mut pop = SpeciatedPopulation::new(0.5); + pop.insert_genome(0, &genomes, &()); let created_new = pop.insert_genome(1, &genomes, &()); assert!( !created_new, "must not create a new species for a similar genome" ); - assert_eq!(pop.species.len(), 1); - assert_eq!(pop.species[0].len(), 2); + assert_eq!(pop.species().len(), 1); + assert_eq!(pop.species()[0].len(), 2); } // ───────────────────────────────────────────────────────────────────────────── @@ -208,7 +204,7 @@ fn round_robin_enumerate_species_index_is_valid() { for (species_i, genome_i) in pop.round_robin_enumerate().take(12) { assert!( - pop.species[species_i].contains(&genome_i), + pop.species()[species_i].contains(&genome_i), "genome index {genome_i} is not in species {species_i}" ); }