diff --git a/count_expirement_1.mat b/count_expirement_1.mat new file mode 100644 index 0000000..8d3bf2d Binary files /dev/null and b/count_expirement_1.mat differ diff --git a/demo1.m b/demo1.m index bfba2a7..d1eb8c0 100644 --- a/demo1.m +++ b/demo1.m @@ -9,18 +9,21 @@ D(2) = init_D_from_txt('solway1.txt'); D(3) = init_D_from_txt('solway2.txt'); D(4) = init_D_from_txt('schapiro.txt'); +D(5) = init_D_from_txt('lynn.txt'); -for i = 1:length(D) +for i = 5:length(D) + tic [samples, post] = sample(D(i), h); for j = 1:length(samples) H(i,j) = samples(j); P(i,j) = post(j); end + toc end -%} -%load four.mat; -load four_repro.mat; +save demo1.mat +%} +load demo1.mat; figure; diff --git a/demo1.mat b/demo1.mat new file mode 100644 index 0000000..fef4276 Binary files /dev/null and b/demo1.mat differ diff --git a/demo_generate_samples.m b/demo_generate_samples.m new file mode 100644 index 0000000..79a6168 --- /dev/null +++ b/demo_generate_samples.m @@ -0,0 +1,24 @@ +h.alpha = 1.5; + +D(1) = init_D_from_txt('hourglass1.txt'); + +nsubjects = 40; + +count_4_1 = 0; +count_4_7 = 0; +both = 0; +for i = 1:nsubjects + [~, samples, ~] = sample_graph_update(D(1), h); + H = samples(1); + if (H.c(4) == H.c(1)) + count_4_1 = count_4_1 + 1; + end + if (H.c(4) == H.c(7)) + count_4_7 = count_4_7 + 1; + end + if (H.c(4) == H.c(7)) && (H.c(4) == H.c(1)) + both = both + 1; + end +end + +save count_expirement_1.mat \ No newline at end of file diff --git a/demo_update_graph.m b/demo_update_graph.m new file mode 100644 index 0000000..ac1f592 --- /dev/null +++ b/demo_update_graph.m @@ -0,0 +1,38 @@ +clear all; +% { +h.alpha = 1.5; + +D(1) = init_D_from_txt('hourglass.txt'); + +for i = 1:1 + tic + [D(i), samples, post] = sample_graph_update(D(i), h); + for j = 1:length(samples) + H(i,j) = samples(j); + P(i,j) = post(j); + end + toc +end + +save demo_update_graph.mat +% } + +load demo_update_graph.mat; + +figure; + +k = 1; +for i = 1:length(D) + post = P(i,:); + [~,I] = maxk(post, k); + for j = 1:k + subplot(length(D),k, (i-1)*k+j); + plot_H(H(i,I(j)), D(i)); + if j == ceil(k/2) + %ylabel(D(i).name); + title(D(i).name); + end + set(gca, 'xtick', []); + set(gca, 'ytick', []); + end +end \ No newline at end of file diff --git a/hourglass.txt b/hourglass.txt index 74e4a11..30738b5 100644 --- a/hourglass.txt +++ b/hourglass.txt @@ -1,10 +1,19 @@ Hourglass -6 7 -1 3 -3 5 -1 5 -3 4 -4 2 -2 6 -4 6 +6 15 +1 3 1 +3 5 1 +1 5 1 +3 4 1 +4 2 1 +1 2 0 +1 4 0 +2 5 0 +3 6 0 +4 5 0 +5 6 0 +2 6 1 +1 6 0 +2 3 0 +4 6 1 +0 0 diff --git a/hourglass1.txt b/hourglass1.txt new file mode 100644 index 0000000..38ecc5e --- /dev/null +++ b/hourglass1.txt @@ -0,0 +1,37 @@ +Hourglass +7 21 +1 2 0 +1 3 0 +1 4 0 +1 5 0 +1 6 0 +1 7 0 +2 3 0 +2 4 0 +2 5 0 +2 6 0 +2 7 0 +3 4 0 +3 5 0 +3 6 0 +3 7 0 +4 5 0 +4 6 0 +4 7 0 +5 6 0 +5 7 0 +6 7 0 +0 +12 +1 2 1 +1 3 1 +1 4 1 +2 3 1 +2 4 1 +3 4 1 +4 5 1 +4 6 1 +4 7 1 +5 6 1 +5 7 1 +6 7 1 \ No newline at end of file diff --git a/hourglass2.txt b/hourglass2.txt new file mode 100644 index 0000000..42919bf --- /dev/null +++ b/hourglass2.txt @@ -0,0 +1,37 @@ +Hourglass +7 21 +1 2 0 +1 3 0 +1 4 0 +1 5 0 +1 6 0 +1 7 0 +2 3 0 +2 4 0 +2 5 0 +2 6 0 +2 7 0 +3 4 0 +3 5 0 +3 6 0 +3 7 0 +4 5 0 +4 6 0 +4 7 0 +5 6 0 +5 7 0 +6 7 0 +0 +12 +4 5 1 +4 6 1 +4 7 1 +5 6 1 +5 7 1 +6 7 1 +1 2 1 +1 3 1 +1 4 1 +2 3 1 +2 4 1 +3 4 1 \ No newline at end of file diff --git a/init_D_from_txt.m b/init_D_from_txt.m index e99270b..68f697c 100644 --- a/init_D_from_txt.m +++ b/init_D_from_txt.m @@ -7,11 +7,14 @@ N = A(1); M = A(2); D.G.N = N; D.G.E = zeros(N, N); % TODO sparse? + D.G.I = zeros(N, N); for k = 1:M - A = freadline(f, '%d %d'); - i = A(1); j = A(2); - D.G.E(i,j) = 1; - D.G.E(j,i) = 1; + A = freadline(f, '%d %d %d'); + i = A(1); j = A(2); exists = A(3); + D.G.E(i,j) = exists; + D.G.E(j,i) = exists; + D.G.I(i,j) = 1; + D.G.I(j,i) = 1; end D.tasks.s = []; @@ -23,6 +26,16 @@ D.tasks.s = [D.tasks.s s]; D.tasks.g = [D.tasks.g g]; end + + % updates + A = freadline(f, '%d'); + num_updates = A(1); + D.updates = []; + for k = 1:num_updates + A = freadline(f, '%d %d %d'); + i = A(1); j = A(2); exists = A(3); + D.updates = [D.updates; i j exists]; + end fclose(f); end diff --git a/init_H.m b/init_H.m index 6548adb..85d64a9 100644 --- a/init_H.m +++ b/init_H.m @@ -18,21 +18,6 @@ H.p = betarnd(1,1); % TODO const H.q = betarnd(1,1); % TODO const - H.tp = betarnd(1,1); % TODO const - - H.N = length(cnt); - H.hp = betarnd(1,1); % TODO const - H.E = zeros(H.N, H.N); % TODO sparse ? - for k = 1:H.N - for l = 1:k-1 - if rand < H.hp - H.E(k,l) = 1; - H.E(l,k) = 1; - end - end - end - - % TODO bridges end diff --git a/loglik.m b/loglik.m index 68ae47b..f0fbd69 100644 --- a/loglik.m +++ b/loglik.m @@ -1,10 +1,13 @@ function logp = loglik(H, D, h) - % P(D|H) = P(G,tasks|H) = P(tasks|G,H) P(G|H) + % P(D|H) = P(G|H) % logp = 0; for i = 1:D.G.N for j = 1:i-1 + if D.G.I(i, j) == 0 + continue; + end if H.c(i) == H.c(j) if D.G.E(i,j) logp = logp + log(H.p); @@ -22,15 +25,5 @@ % TODO bridges end end - - for i = 1:length(D.tasks.s) - s = D.tasks.s(i); - logp = logp + log(1 / D.G.N); - - g = D.tasks.g(i); - P = ones(1, D.G.N); - P(H.c ~= H.c(s)) = H.tp; - logp = logp + log(P(g)) - log(sum(P)); - end end diff --git a/logprior.m b/logprior.m index 737d172..12f0c06 100644 --- a/logprior.m +++ b/logprior.m @@ -17,20 +17,8 @@ end assert(isequal(cnt, H.cnt)); - logp = logp + log(betapdf(H.p,1,1)) + log(betapdf(H.q,1,1)) + log(betapdf(H.tp,1,1)); % TODO const + logp = logp + log(betapdf(H.p,1,1)) + log(betapdf(H.q,1,1)); % TODO const - for k = 1:H.N - for l = 1:k-1 - if H.E(k,l) - logp = logp + log(H.hp); - else - logp = logp + log(1 - H.hp); - end - end - end - - logp = logp + log(betapdf(H.hp,1,1)); - % TODO bridges end diff --git a/logsumexp.m b/logsumexp.m new file mode 100644 index 0000000..90d01ea --- /dev/null +++ b/logsumexp.m @@ -0,0 +1,19 @@ +function s = logsumexp(x,dim) + + % Returns log(sum(exp(x),dim)) while avoiding numerical underflow. + % Default is dim = 1 (columns). + + if nargin == 1 + % Determine which dimension sum will use + dim = find(size(x)~=1,1); + if isempty(dim), dim = 1; end + end + + % subtract the largest in each column + y = max(x,[],dim); + x = bsxfun(@minus,x,y); + s = y + log(sum(exp(x),dim)); + i = find(~isfinite(y)); + if ~isempty(i) + s(i) = y(i); + end \ No newline at end of file diff --git a/lynn.txt b/lynn.txt new file mode 100644 index 0000000..2eb2625 --- /dev/null +++ b/lynn.txt @@ -0,0 +1,33 @@ +Lynn 2018 +15 30 +1 2 +2 3 +3 4 +4 5 +5 6 +6 7 +7 8 +8 9 +9 10 +10 11 +11 12 +12 13 +13 14 +14 15 +15 1 +1 3 +2 4 +3 5 +4 6 +5 7 +6 8 +7 9 +8 10 +9 11 +10 12 +11 13 +12 14 +13 15 +14 1 +15 2 +0 diff --git a/read_samples.m b/read_samples.m new file mode 100644 index 0000000..0de08ac --- /dev/null +++ b/read_samples.m @@ -0,0 +1,10 @@ +load count_expirement_1.mat + +% 4 is connected to 5, 6, 7 + +n = nsubjects; +c1 = count_4_1; +c2 = count_4_7; + +p = 2 * binocdf(min(c1,c2), n, 0.5); +y = binoinv([0.025 0.975], n, 0.5); \ No newline at end of file diff --git a/sample.m b/sample.m index 74add91..a26b5ae 100644 --- a/sample.m +++ b/sample.m @@ -2,10 +2,8 @@ % % Draw samples from posterior P(H|D) using Metropolis-Hastings-within-Gibbs sampling. % hierarchy H = (c, p, q, p', p", E', V') - % data D = (G, tasks) + % data D = (G) % graph G = (E, V) - % tasks = (task_1, task_2 ...) - % task = (s, g) % % Generative model: % @@ -13,17 +11,10 @@ % state chunks = c ~ CRP % within-cluster density = p ~ Beta % across-cluster density = pq, q ~ Beta - % H graph density = p' = hp ~ Beta - % probability goal state is in different chunk from starting state = p" = tp ~ Beta % % P(G|H): % E(i,j) ~ Bern(p) if c(i) == c(j) % E(i,j) ~ Bern(pq) if c(i) != c(j) - % - % P(tasks|G,H) = product P(task|G,H) - % P(task|G,H): - % starting state = s ~ Cat(all vertices in G) - % goal state = g ~ Cat(1,1,1,1... for all i s.t. c(i) == c(s), ... p", p", p"... for all i s.t. c(i) != c(s) % if ~exist('nsamples', 'var') @@ -61,30 +52,13 @@ [q, accept] = mhsample(H.q, 1, 'logpdf', logp, 'proprnd', proprnd, 'logproppdf', logprop); H.q = q; - [tp, accept] = mhsample(H.tp, 1, 'logpdf', logp, 'proprnd', proprnd, 'logproppdf', logprop); - H.tp = tp; - - [hp, accept] = mhsample(H.hp, 1, 'logpdf', logp, 'proprnd', proprnd, 'logproppdf', logprop); - H.hp = hp; - - for k = 1:H.N - for l = 1:k-1 - logp = @(e) logpost_E_k_l(e, k, l, H, D, h); - proprnd = @(e_old) proprnd_E_k_l(e_old, k, l, H, D, h); - logprop = @(e_new, e_old) logprop_E_k_l(e_new, e_old, k, l, H, D, h); - - [e, accept] = mhsample(H.E(k,l), 1, 'logpdf', logp, 'proprnd', proprnd, 'logproppdf', logprop); % TODO adaptive - H.E(k,l) = e; - H.E(l,k) = e; - end - end - % TODO bridges samples(n) = H; post(n) = logpost(H,D,h); end - + + post = post(burnin:lag:end); samples = samples(burnin:lag:end); end @@ -99,7 +73,7 @@ % Update H.c(i) and counts % TODO makes copy of H -- super slow... % -function H = update_c_i(c_i, i, H) % TODO FIXME H.N is broken!! +function H = update_c_i(c_i, i, H) H.cnt(H.c(i)) = H.cnt(H.c(i)) - 1; H.c(i) = c_i; if c_i <= length(H.cnt) @@ -173,31 +147,3 @@ logp = log(normpdf(p_new, p_old, 1)) - log(Z); end - -% P(H|D) for updates of E -% -function logp = logpost_E_k_l(e, k, l, H, D, h) - H.E(k,l) = e; - H.E(l,k) = e; - logp = logpost(H, D, h); -end - -% proposal PMF for E -% keep the same, or draw from prior w/ some small prob -% -function P = propP_E_k_l(e_old, k, l, H, D, h) - P = [1 - H.hp, H.hp] * 0.3; % draw from prior w/ some small prob - P(e_old + 1) = P(e_old + 1) + 0.7; % or keep the same TODO consts -end - -% proposal for E -% -function e_new = proprnd_E_k_l(e_old, k, l, H, D, h) - P = propP_E_k_l(e_old, k, l, H, D, h); - e_new = find(mnrnd(1, P)) - 1; -end - -function logp = logprop_E_k_l(e_new, e_old, k, l, H, D, h) % TODO merge w/ proprnd - P = propP_E_k_l(e_old, k, l, H, D, h); - logp = log(P(e_new + 1)); -end diff --git a/sample_graph_update.m b/sample_graph_update.m new file mode 100644 index 0000000..dfa4550 --- /dev/null +++ b/sample_graph_update.m @@ -0,0 +1,218 @@ +function [D, samples, post] = sample_graph_update(D, h, nwait_update, nsamples, nparticles, burnin, lag) + if ~exist('nsamples', 'var') + nsamples = 10; + end + + if ~exist('burnin', 'var') + burnin = 1; % no burn-in + end + + if ~exist('lag', 'var') + lag = 1; + end + + if ~exist('nparticles', 'var') + nparticles = 20; + end + + if ~exist('nwait_update', 'var') + nwait_update = 1; + end + + % initialize + W = zeros(nparticles, 1); + + if any(D.G.I(:)) + [H, W] = sample(D, h, nparticles, 1000, 10); + else + for i = 1:nparticles + H(i) = init_H(D, h); + % [samples_t{i}, post_t{i}, H(i)] = sample_Hm(D, H(i), h, 1000, 1, 1); + [~, ~, H(i)] = sample_Hm(D, H(i), h, 1000, 1, 1); + W(i) = exp(loglik(H(i), D, h)); + end + end + + W = W/sum(W); + + s = size(D.updates); + num_updates = s(1); + for i = 1:num_updates + new_edge = D.updates(i, :); + D = update_D(D, new_edge); + if mod(i, nwait_update) == 0 + for j = 1:length(H) +% [samples_t{j}, post_t{j}, H(j)] = sample_Hm(D, H(j), h, nsamples, 1, 10); + [~, ~, H(j)] = sample_Hm(D, H(j), h, nsamples, 1, 10); + end + end + W = update_weights(W, H, new_edge); + end + + % sample from H +% pd = makedist('Multinomial','probabilities', W); +% r = random(pd); +% samples = samples_t{r}; post = post_t{r}; +% samples = H(r); post = 1; + post = zeros(size(H)); + len = size(H); + len = len(2); + for i = 1:len + post(i) = logpost(H(i),D,h); + end + [~, I] = max(post); + post = 1; + samples = H(I); +end + +function [samples, post, H] = sample_Hm(D, H, h, nsamples, burnin, lag) + % Roberts & Rosenthal (2009) + for n = 1:nsamples * lag + burnin + + % single MH step + for i = 1:D.G.N + logp = @(c_i) logpost_c_i(c_i, i, H, D, h); + proprnd = @(c_i_old) proprnd_c_i(c_i_old, i, H, D, h); + logprop = @(c_i_new, c_i_old) logprop_c_i(c_i_new, c_i_old, i, H, D, h); + + [c_i, accept] = mhsample(H.c(i), 1, 'logpdf', logp, 'proprnd', proprnd, 'logproppdf', logprop); + H = update_c_i(c_i, i, H); + end + + logp = @(p) logpost_p(p, H, D, h); + proprnd = @(p_old) proprnd_p(p_old, H, D, h); + logprop = @(p_new, p_old) logprop_p(p_new, p_old, H, D, h); + + [p, accept] = mhsample(H.p, 1, 'logpdf', logp, 'proprnd', proprnd, 'logproppdf', logprop); % TODO adaptive + H.p = p; + + [q, accept] = mhsample(H.q, 1, 'logpdf', logp, 'proprnd', proprnd, 'logproppdf', logprop); + H.q = q; + + % TODO bridges + samples(n) = H; + post(n) = logpost(H,D,h); + end + + post = post(burnin:lag:end); + samples = samples(burnin:lag:end); +end + +% P(H|D) up to proportionality constant +% +function logp = logpost(H, D, h) + logp = loglik(H, D, h) + logprior(H, D, h); +end + +% Update H.c(i) and counts +% TODO makes copy of H -- super slow... +% +function H = update_c_i(c_i, i, H) + H.cnt(H.c(i)) = H.cnt(H.c(i)) - 1; + H.c(i) = c_i; + if c_i <= length(H.cnt) + H.cnt(H.c(i)) = H.cnt(H.c(i)) + 1; + else + H.cnt = [H.cnt 1]; + end +end + + +% P(H|D) for updates of c_i +% i.e. with new c's up to c_i, the candidate c_i, then old c's after (and old rest of H) +% +function logp = logpost_c_i(c_i, i, H, D, h) + H = update_c_i(c_i, i, H); + logp = logpost(H, D, h); +end + +% proposal PMF for c_i +% inspired by Algorithm 5 from Neal 1998: MCMC for DP mixtures +% +function P = propP_c_i(c_i_old, i, H, D, h) + cnt = H.cnt; + cnt(H.c(i)) = cnt(H.c(i)) - 1; + z = find(cnt == 0); % reuse empty bins TODO is this legit? + if isempty(z) + cnt = [cnt h.alpha]; + else + cnt(z) = h.alpha; + end + P = cnt / sum(cnt); +end + +% propose c_i +% +function c_i_new = proprnd_c_i(c_i_old, i, H, D, h) + P = propP_c_i(c_i_old, i, H, D, h); + c_i_new = find(mnrnd(1, P)); + + % TODO bridges +end + +function [logP, P] = logprop_c_i(c_i_new, c_i_old, i, H, D, h) % TODO merge w/ proprnd + P = propP_c_i(c_i_old, i, H, D, h); + logP = log(P(c_i_new)); +end + + +% P(H|D) for updates of p +% +function logp = logpost_p(p, H, D, h) + H.p = p; + logp = logpost(H, D, h); +end + +% proposals for p; random walk +% +function p_new = proprnd_p(p_old, H, D, h) + while true % TODO can use universality of uniform inverse CDF thingy + p_new = normrnd(p_old, 0.1); % TODO const TODO adaptive + if p_new <= 1 && p_new >= 0 + break; % keep params within bounds + end + end +end + +% account for truncating that keeps params within bounds +% +function logp = logprop_p(p_new, p_old, H, D, h) + Z = normcdf(1, p_old, 0.1) - normcdf(0, p_old, 0.1); % TODO consts TODO adaptive + logp = log(normpdf(p_new, p_old, 1)) - log(Z); +end + +function D = update_D(D, new_edge) + i = new_edge(1); j = new_edge(2); exists = new_edge(3); + + D.G.E(i,j) = exists; + D.G.E(j,i) = exists; + D.G.I(i,j) = 1; + D.G.I(j,i) = 1; + +end + +function W = update_weights(W, H, new_edge) + assert(length(H) == length(W)); + for i = 1:length(W) + W(i) = W(i) * exp(loglik_update(H(i), new_edge)); + end + W = W/sum(W); +end + +% assumes new_edge state is known +function logp = loglik_update(H, new_edge) + i = new_edge(1); j = new_edge(2); exists = new_edge(3); + if H.c(i) == H.c(j) + if exists + logp = log(H.p); + else + logp = log(1 - H.p); + end + else + if exists + logp = log(H.p * H.q); + else + logp = log(1 - H.p * H.q); + end + end +end