-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsynthDataNew.m
More file actions
94 lines (67 loc) · 1.83 KB
/
synthDataNew.m
File metadata and controls
94 lines (67 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
function [X, Y, Psi, Theta, B] = synthDataNew(n, p, q)
%SYNTHDATANEW Draw a synthetic grouped multi-output regression data set.
%   [X, Y, Psi, Theta, B] = synthDataNew(n, p, q) samples n observations of
%   p grouped inputs and q grouped outputs.
%     X     - n-by-p inputs drawn from a zero-mean Gaussian whose covariance
%             comes from generateBlockCovariance, then z-scored per column
%     Y     - n-by-q outputs drawn around X*B with identity noise covariance,
%             then z-scored per column
%     Psi   - p-by-p matrix X'*X/(n-1), i.e. the sample correlation of X
%             (X is standardized)
%     Theta - q-by-q matrix Y'*Y/(n-1), i.e. the sample correlation of Y
%     B     - p-by-q regression weights from generateRegressionWeights

% random group labels; inputs are drawn before outputs
inGroups  = generateGroups(p);
outGroups = generateGroups(q);

% input covariance follows the input grouping; output noise is white
inCov    = generateBlockCovariance(inGroups);
noiseCov = eye(q);

% weights linking input groups to output groups
B = generateRegressionWeights(inGroups, outGroups);

% sample and standardize inputs, then outputs conditioned on them
X = zscore(mvnrnd(zeros(n, p), inCov));
Y = zscore(mvnrnd(X * B, noiseCov));

% with z-scored columns these Gram matrices are sample correlations
denom = n - 1;
Psi   = (X.' * X) / denom;
Theta = (Y.' * Y) / denom;
end
% partition variables 1:numVar into groups
function G = generateGroups(numVar)
%GENERATEGROUPS Randomly assign numVar variables to groups.
%   G = generateGroups(numVar) returns a numVar-by-1 column of group labels,
%   sorted in ascending order so that members of a group are adjacent.
%   About floor(numVar/3) groups are used (at least one). Labels are
%   contiguous, 1..K with no gaps, so max(G) equals the number of non-empty
%   groups.

% target number of groups; clamp to 1 so randi never receives 0
% (floor(numVar/3) is 0 when numVar < 3)
numGroup = max(1, floor(numVar/3));
% assign each variable to a group uniformly at random
groups = randi(numGroup, numVar, 1);
% randi may leave some labels unused; renumber the sorted labels to 1..K
% so that no empty group is implied by the labelling
[~, ~, G] = unique(sort(groups));
end
% generate block sparse covariance matrix
function S = generateBlockCovariance(G)
%GENERATEBLOCKCOVARIANCE Block covariance matrix induced by group labels.
%   S = generateBlockCovariance(G) returns a v-by-v matrix, v = numel(G),
%   with S(i,j) = 1 when G(i) == G(j) and 0 otherwise. Each within-group
%   block is all ones, so S is positive semidefinite but singular whenever
%   a group has more than one member: variables in the same group are
%   perfectly correlated.
%
%   Any label values are accepted, including non-contiguous labels or labels
%   larger than numel(G).
G = G(:);
% pairwise label comparison; bsxfun keeps this working on releases without
% implicit expansion
S = double(bsxfun(@eq, G, G.'));
end
% generate regression weights
function B = generateRegressionWeights(G_in,G_out,weight)
%GENERATEREGRESSIONWEIGHTS Group-structured sparse regression coefficients.
%   B = generateRegressionWeights(G_in, G_out) returns a p-by-q matrix,
%   p = numel(G_in), q = numel(G_out). For each output group:
%     1. A connection count K is drawn from 1..(number of input groups) with
%        P(K = k) proportional to 4^-(k-1), so a single connection is most
%        likely.
%     2. K distinct input groups are chosen uniformly at random.
%     3. Every entry (input in a chosen group, output in this group) is set
%        to 0.8.
%   All other entries are zero.
%
%   B = generateRegressionWeights(G_in, G_out, weight) uses WEIGHT instead
%   of 0.8 for the nonzero entries.
if nargin < 3
    weight = 0.8;
end
p = numel(G_in);
q = numel(G_out);
B = zeros(p, q);
% use only labels that actually occur, so gaps in the labelling never
% yield connections to empty groups (identical to 1:max for 1..K labels)
inLabels  = unique(G_in(:));
outLabels = unique(G_out(:));
numGin = numel(inLabels);
% P(K = k) proportional to 4^-(k-1). Writing it as 0.25.^(k-1) avoids the
% overflow of 4.^(k-1) to Inf (then NaN after normalisation) once there
% are more than about 512 input groups.
connDist = 0.25 .^ (0:numGin-1);
connDist = connDist ./ sum(connDist);
for k = 1:numel(outLabels)
    outMask = (G_out == outLabels(k));
    numConn = find(mnrnd(1, connDist));
    chosen = inLabels(randsample(numGin, numConn));
    B(ismember(G_in, chosen), outMask) = weight;
end
end