diff --git a/src/UnweightedSamplingSingle.jl b/src/UnweightedSamplingSingle.jl index 108e7f0..318fb71 100644 --- a/src/UnweightedSamplingSingle.jl +++ b/src/UnweightedSamplingSingle.jl @@ -44,9 +44,9 @@ function Base.merge(ss::SingleAlgRSWRSKIPSampler...) ns = [nobs(s) for s in ss] n_tot = sum(ns) ps = cumsum(ns ./ n_tot) - r = rand(s1.rng) - value = ss[findfirst(p -> r < p, ps)].value - return typeof(s1)(n_tot, sum(s.skip_k for s in ss), ss[1].rng, value) + r = rand(ss[1].rng) + value = ss[findfirst(p -> r < p, ps)].rvalue + return typeof(ss[1])(n_tot, sum(s.skip_k for s in ss), ss[1].rng, value) end function Base.merge!(s1::SingleAlgRSWRSKIPSampler_Mut, ss::SingleAlgRSWRSKIPSampler_Mut...) diff --git a/src/WeightedSamplingSingle.jl b/src/WeightedSamplingSingle.jl index 4579960..5ac3ef9 100644 --- a/src/WeightedSamplingSingle.jl +++ b/src/WeightedSamplingSingle.jl @@ -49,9 +49,9 @@ function Base.merge(ss::SingleAlgWRSWRSKIPSampler...) ns = [s.total_w for s in ss] n_tot = sum(ns) ps = cumsum(ns ./ n_tot) - r = rand(s1.rng) - value = ss[findfirst(p -> r < p, ps)].value - return typeof(s1)(sum(s.seen_k for s in ss), sum(s.total_w for s in ss), sum(s.skip_w for s in ss), + r = rand(ss[1].rng) + value = ss[findfirst(p -> r < p, ps)].rvalue + return typeof(ss[1])(sum(s.seen_k for s in ss), sum(s.total_w for s in ss), sum(s.skip_w for s in ss), ss[1].rng, value) end diff --git a/test/merge_tests.jl b/test/merge_tests.jl index 14d94f7..130ff82 100644 --- a/test/merge_tests.jl +++ b/test/merge_tests.jl @@ -43,4 +43,53 @@ m == AlgRSWRSKIP() ? fit!(s2, 2) : fit!(s2, 2, 1.0) @test value(merge!(s1, s2)) in (1, 2) end + + iters = (1:10, 11:30) + reps = 10000 + for m in (AlgRSWRSKIP(),) + count_s1 = 0 + for _ in 1:reps + s1 = ReservoirSampler{Int}(rng, m) + s2 = ReservoirSampler{Int}(rng, m) + for x in iters[1] fit!(s1, x) end + for x in iters[2] fit!(s2, x) end + s_merged = merge(s1, s2) + if value(s_merged) <= 10 + count_s1 += 1 + end + end + chisq_test = ChisqTest([count_s1, reps - count_s1], [1/3, 2/3]) + @test pvalue(chisq_test) > 0.05 + end + + for m in (AlgWRSWRSKIP(),) + count_s1 = 0 + for _ in 1:reps + s1 = ReservoirSampler{Int}(rng, m) + s2 = ReservoirSampler{Int}(rng, m) + for x in iters[1] fit!(s1, x, 1.0) end + for x in iters[2] fit!(s2, x, 1.0) end + s_merged = merge(s1, s2) + if value(s_merged) <= 10 + count_s1 += 1 + end + end + chisq_test = ChisqTest([count_s1, reps - count_s1], [1/3, 2/3]) + @test pvalue(chisq_test) > 0.05 + + rng = StableRNG(45) + count_s1 = 0 + for _ in 1:reps + s1 = ReservoirSampler{Int}(rng, m) + s2 = ReservoirSampler{Int}(rng, m) + fit!(s1, 1, 10.0) + fit!(s2, 2, 20.0) + s_merged = merge(s1, s2) + if value(s_merged) == 1 + count_s1 += 1 + end + end + chisq_test = ChisqTest([count_s1, reps - count_s1], [1/3, 2/3]) + @test pvalue(chisq_test) > 0.05 + end end