diff --git a/matlab/BMatchingSolver.cpp b/matlab/BMatchingSolver.cpp new file mode 100644 index 0000000..fa22e4d --- /dev/null +++ b/matlab/BMatchingSolver.cpp @@ -0,0 +1,176 @@ +/*! + * BMatchingSolverMex + * Bert Huang + */ + +#include +#include +#include "mex.h" +#include "BMatchingLibrary.h" +#include "SparseMatrix.h" +#include "utils.h" + +using namespace std; +using namespace bmatchingLibrary; + +double ** getMatrix(const mxArray *pm) { + double ** A; + double * matrix = mxGetPr(pm); + int m = mxGetM(pm); + int n = mxGetN(pm); + + if (m == 0) + return 0; + + A = new double*[m]; + + for (int i = 0; i < m; i++) + A[i] = new double[n]; + + for (int i = 0; i < m; i++) + for (int j = 0; j < n; j++) + A[i][j] = matrix[j*m+i]; + + return A; +} + +void deleteMatrix(double ** A, int size) { + for (int i = 0; i < size; i++) + delete[](A[i]); + delete[](A); +} + +void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { + int weightType = 0, d = 0, m = 0, n = 0, cacheSize = 0, + wsize = 0, xsize = 0, ysize = 0; + double ** W = 0, ** X = 0, ** Y = 0; + int * b, blen; + bool matrix = true; + bool verbose = false; + + + //(W, b, X, Y, weightType, cacheSize) + W = getMatrix(prhs[0]); + + m = mxGetM(prhs[0]); + wsize = m; + n = mxGetN(prhs[0]); + double * bdouble = mxGetPr(prhs[1]); + blen = mxGetM(prhs[1]) + mxGetN(prhs[1]) - 1; + b = new int[blen]; + for (int i = 0; i < blen; i++) + b[i] = round(bdouble[i]); + + if (nrhs > 2) { + X = getMatrix(prhs[2]); + + xsize = mxGetM(prhs[2]); + + if (xsize > 0) { + m = mxGetM(prhs[2]); + d = mxGetN(prhs[2]); + matrix = false; + weightType = 1; // default weight type + if (nrhs > 3) { + Y = getMatrix(prhs[3]); + n = mxGetM(prhs[3]); + ysize = n; + } else + n = 0; + } else + X = 0; + } + + if (nrhs > 4 && mxGetM(prhs[4]) > 0) + weightType = round(mxGetScalar(prhs[4])); + if (nrhs > 5 && mxGetM(prhs[4]) > 0) + cacheSize = round(mxGetScalar(prhs[5])); + else + cacheSize = round(2 * sqrt(m + n)); + + if (nrhs > 6) + verbose = mxGetScalar(prhs[6]) > 0; + + SparseMatrix * solution; + + // By default, perform no more than 100*(m+n) iterations + int maxIter = 100*(m+n); + + if (blen == m) { + // unipartite + if (matrix) + solution = bMatchMatrixCache(m, W, b, cacheSize, maxIter, verbose); + else if (weightType == 1) + solution = bMatchEuclideanCache(m, d, X, b, cacheSize, maxIter, verbose); + else if (weightType == 2) + solution = bMatchInnerProductCache(m, d, X, b, cacheSize, maxIter, verbose); + else + mexErrMsgTxt("Unrecognized weight type"); + } else { + // bipartite + if (matrix) + solution = bMatchBipartiteMatrixCache(m, n, W, b, b+m, + cacheSize, maxIter, verbose); + else if (weightType == 1) + solution = bMatchBipartiteEuclideanCache(m, n, d, X, Y, b, + b+m, cacheSize, maxIter, verbose); + else if (weightType == 2) + solution = bMatchBipartiteInnerProductCache(m, n, d, X, Y, b, + b+m, cacheSize, maxIter, verbose); + else + mexErrMsgTxt("Unrecognized weight type"); + } + + if (wsize > 0) + deleteMatrix(W, wsize); + if (xsize > 0) + deleteMatrix(X, xsize); + if (ysize > 0) + deleteMatrix(Y, ysize); + + delete[](b); + + + int nnz = solution->getNNz(); + + mxArray * I = mxCreateDoubleMatrix(nnz, 1, mxREAL); + mxArray * J = mxCreateDoubleMatrix(nnz, 1, mxREAL); + mxArray * V = mxCreateDoubleMatrix(nnz, 1, mxREAL); + + double * rows = new double[nnz]; + double * cols = new double[nnz]; + double * vals = new double[nnz]; + + for (int i=0; i < nnz; i++) { + rows[i] = solution->getRows()[i]+1; + cols[i] = solution->getCols()[i]+1; + vals[i] = 1.0; + } + + delete(solution); + + memcpy(mxGetPr(I), rows, nnz*sizeof(double)); + memcpy(mxGetPr(J), cols, nnz*sizeof(double)); + memcpy(mxGetPr(V), vals, nnz*sizeof(double)); + + delete[](rows); + delete[](cols); + delete[](vals); + + mxArray * rhs[5]; + + rhs[0] = I; + rhs[1] = J; + rhs[2] = V; + rhs[3] = mxCreateDoubleScalar(blen); + rhs[4] = mxCreateDoubleScalar(blen); + + mexCallMATLAB(1, plhs, 5, rhs, "sparse"); + + for (int i = 0; i < 5; i++) + mxDestroyArray(rhs[i]); + + + + return; +} diff --git a/matlab/BMatchingSolverCmd.m b/matlab/BMatchingSolverCmd.m new file mode 100644 index 0000000..be21347 --- /dev/null +++ b/matlab/BMatchingSolverCmd.m @@ -0,0 +1,90 @@ +function [A, time] = BMatchingSolver(W, b, X, Y, weightType, cacheSize, flags) + +% solves for the maximum weight b-matching +% W is the weight matrix (option 1) +% b is the vector of target degrees +% X is the first bipartition (option 2) +% Y is the second bipartition (option 2) +% weightType is 1 for negative Euclidean distance, 2 for inner product +% cacheSize is the size of the weight cache, default is 2*sqrt(m+n) +% calls mex version if mex version is compiled and in current path + +[m, n] = size(W); +if exist('weightType', 'var') && ~isempty(weightType) && ismember(weightType, [1 2]) + m = size(X,1); + n = size(Y,1); +end + +global tmp_dir +if isempty(tmp_dir) + tmp_dir = '~/tmp'; +end + +global bmatchingsolver +if isempty(bmatchingsolver) + bmatchingsolver = '~/Dropbox/workspace/BMatchingSolver/Release/BMatchingSolver'; +end + +persistent problem_id; +if isempty(problem_id) + problem_id = uint32(randi(9999)); +else + problem_id = uint32(mod(problem_id + randi(9999), 100000)); +end + +outFile = sprintf('%s/tmp_%d_output.txt', tmp_dir, problem_id); +degFile = sprintf('%s/tmp_%d_degrees.txt', tmp_dir, problem_id); +dlmwrite(degFile, b, 'precision', '%9.0f'); + +if (m == n && length(b) == m) + % assume unipartite + cmd = sprintf('%s -n %d -d %s -o %s', bmatchingsolver, m, ... + degFile, outFile); + N = m; +else + % bipartite + cmd = sprintf('%s -n %d --bipartite %d -d %s -o %s', bmatchingsolver, m+n, m, ... + degFile, outFile); + N = m+n; +end + +if ~isempty(W) + weightFile = sprintf('%s/tmp_%d_weights.txt', tmp_dir, problem_id); + save(weightFile, 'W', '-ascii', '-double'); + + cmd = sprintf('%s -w %s', cmd, weightFile); +elseif exist('weightType', 'var') + dataFile = sprintf('%s/tmp_%d_data.txt', tmp_dir, problem_id); + data = [X; Y]; + save(dataFile, 'data', '-ascii', '-double'); + + cmd = sprintf('%s -x %s -t %d -D %d', cmd, dataFile, weightType,... + size(data,2)); +end + +% add cache size parameter +if ~exist('cacheSize', 'var') + cacheSize = round(2*sqrt(m+n)); +end +cmd = sprintf('%s -c %d', cmd, cacheSize); + +if exist('flags', 'var') + cmd = sprintf('%s %s', cmd, flags); +end + +tic; +system(cmd); +time = toc; + +IJ = dlmread(outFile, ' ') + 1; +A = sparse(IJ(:,1), IJ(:,2), true(size(IJ,1),1), N, N); + +delete(degFile); +if (exist('weightFile', 'var')) + delete(weightFile); +end +if (exist('dataFile', 'var')) + delete(dataFile); +end +delete(outFile); + diff --git a/matlab/bdmatch_augment.m b/matlab/bdmatch_augment.m new file mode 100644 index 0000000..cbe8b0d --- /dev/null +++ b/matlab/bdmatch_augment.m @@ -0,0 +1,60 @@ +function [W, X, Y, b, I, J] = bdmatch_augment(W, X, Y, lb, ub) +% function [W, X, Y, b] = bdmatch_mex(W, X, Y, lb, ub) +% creates augmented b-matching to solve bd-matching problem + +b = ub; + +auxcount = ub - lb; + +maxaux = max(auxcount); + +unused = maxaux - auxcount; + +if ~isempty(W) + [m,n] = size(W); + + W = [W zeros(m, maxaux); zeros(maxaux, n+maxaux)]; + if (m == n && length(lb) == m) + % unipartite + for i = 1:m + W(i,end-unused(i)+1:end) = -inf; + W(end-unused(i)+1:end,i) = -inf; + end + b = [b(:); -ones(maxaux,1)]; + I = 1:m; + J = 1:m; + else + % bipartite + for i = 1:m + W(i,end-unused(i)+1:end) = -inf; + end + for i = 1:n + W(end-unused(i+m)+1:end,i) = -inf; + end + b = [b(1:m); -ones(maxaux,1); b(m+1:end); -ones(maxaux,1)]; + + I = 1:m; + J = m+maxaux+1:m+maxaux+n; + end +elseif ~isempty(X) + [m,d] = size(X); + [n,d] = size(Y); + + if any(unused(1:m) ~= unused(1)) || n > 0 && any(unused(m+1:end) ~= unused(m+1)) + error('This script can only handle Euclidean or inner product problems with the same lb and ub in each bipartition'); + end + + if (n > 0) + % bipartite + X = [X; nan(auxcount(m+1), d)]; + Y = [Y; nan(auxcount(1), d)]; + b = [b(1:m); -ones(auxcount(m+1),1); b(m+1:end); -ones(auxcount(1),1)]; + I = 1:m; + J = m+auxcount(m+1)+1:m+auxcount(m+1)+n; + else + X = [X; nan(auxcount(1), d)]; + b = [b(:); -ones(auxcount(1),1)]; + I = 1:m; + J = 1:m; + end +end diff --git a/matlab/lprelax.m b/matlab/lprelax.m new file mode 100644 index 0000000..11d225e --- /dev/null +++ b/matlab/lprelax.m @@ -0,0 +1,46 @@ +function [f, A, b] = lprelax(W, mask, lb, ub) + +N = size(W,1); + +if (N ~= length(lb)) + [m,n] = size(W); + W = [sparse(m, m) W; W' sparse(n, n)]; + mask = [sparse(m, m) mask; mask' sparse(n, n)]; + N = n+m; + bipartite = true; +else + bipartite = false; +end + +f = -nonzeros(triu(W.*mask)); + +[I,J] = find(triu(mask)); + + +A = zeros(2*N, nnz(triu(mask)))'; +b = zeros(2*N, 1); + +for i=1:N + A(I==i | J==i, i) = 1; + b(i) = ub(i); + + A(I==i | J==i, i+N) = -1; + b(i+N) = -lb(i); +end +A = A'; + +if (nargout==1) + + options.Display = 'none'; + + x = linprog(f, A, b, [], [], zeros(size(f)), ones(size(f)), ... + zeros(size(f)), options); + + X = sparse(I, J, x, N, N); + + f = X+triu(X,1)'; + + if bipartite + f = f(1:m, m+1:end); + end +end \ No newline at end of file diff --git a/matlab/makeMex.m b/matlab/makeMex.m new file mode 100644 index 0000000..62ffeb4 --- /dev/null +++ b/matlab/makeMex.m @@ -0,0 +1,33 @@ +src_dir = '/Users/bert/Dropbox/workspace/BMatchingSolver/src/'; + +objects = {'IntSet', + 'utils', + 'IndexHeap', + 'IntDoubleMap', + 'WeightFunction', + 'FunctionOracle', + 'EuclideanDistance', + 'InnerProduct', + 'WeightOracle', + 'OscillationDetector', + 'BipartiteFunctionOracle', + 'BipartiteMatrixOracle', + 'MatrixOracle', + 'BMatchingProblem', + 'BeliefPropagator', + 'BMatchingLibrary'}; + +object_string = ''; + +for i = 1:length(objects) + object_string = sprintf('%s%s%s.cpp ', ... + object_string, src_dir, objects{i}); +end +%% + +cmd = sprintf('mex -O -D_MEX_HACK -I%s -O BMatchingSolver.cpp %s', src_dir, object_string); + +eval(cmd); + +delete *.o *~; + diff --git a/matlab/memory_test.m b/matlab/memory_test.m new file mode 100644 index 0000000..1759cc5 --- /dev/null +++ b/matlab/memory_test.m @@ -0,0 +1,15 @@ +n = 10; +d = 5; + +X = rand(n,d); +Y = rand(n,d); + +b = ones(2*n, 1); + + +fprintf('Starting test. Hit ctrl-c to exit\n'); +while(1) + A = BMatchingSolver([], b, X, Y); +end + + diff --git a/matlab/tester.m b/matlab/tester.m new file mode 100644 index 0000000..50fbce9 --- /dev/null +++ b/matlab/tester.m @@ -0,0 +1,166 @@ +% test suite: +% - mex euclidean, inner product, matrix +% - cmd line euclidean, inner product, matrix +% - verbose and not verbose +% n = 1, 10, 50 +% b = 1, 2, n, +% cache size = 0, 1 n +% compare against lprelax result + +Nvec = [2 10 50]; + +d = 20; + +for i = 1:length(Nvec) + + n = Nvec(i); + + Cvec = [0 1 n]; + bvec = [1 ceil(n/5) n]; + + for c = 1:length(Cvec) + cacheSize = Cvec(c); + + for j = 1:length(bvec) + for s = 1:3 + if (s == 1) + % weight matrix + lpW = randn(n); + W = lpW; + X = []; + Y = []; + weightType = 0; + elseif (s == 2) + % euclidean + X = randn(n, d); + Y = randn(n, d); + W = []; + + selfX = sum(X.^2,2); + selfY = sum(Y.^2,2); + lpW = -sqrt(bsxfun(@plus, selfX, selfY') - 2*X*Y'); + weightType = 1; + elseif (s == 3) + % inner product + X = randn(n,d); + Y = randn(n,d); + W = []; + lpW = X*Y'; + weightType = 2; + end + + b = bvec(j)*ones(2*n,1); + lpsol = round(lprelax(lpW, true(n), b, b)); + + mexsol = BMatchingSolver(W, b, X, Y, weightType, cacheSize); + mexsol = mexsol(1:n, n+1:end); + cmdsol = BMatchingSolverCmd(W, b, X, Y, weightType, cacheSize); + cmdsol = cmdsol(1:n, n+1:end); + + subplot(211); + imagesc(lpsol); + subplot(212); + imagesc(mexsol); + + if (nnz(mexsol - lpsol) || nnz(mexsol - cmdsol)) + fprintf('different solution for type %d, n = %d, b = %d, c = %d\n', s, n, bvec(j), cacheSize); + nnz(mexsol-lpsol) + nnz(mexsol-cmdsol) + nnz(lpsol-cmdsol) + pause + else + fprintf('type %d, n = %d, b = %d, c = %d \t OK\n', s, n, bvec(j), cacheSize); + end + + drawnow; + end + end + end +end +%% +fprintf('Running with verbose on: mex version\n'); +n = 50; +d = 5; +X = randn(n, d); +Y = randn(n, d); + +b = ones(2*n,1); + +BMatchingSolver([], b, X, Y, weightType, cacheSize, true); + +BMatchingSolverCmd([], b, X, Y, weightType, cacheSize, '-v'); + +%% comparing running times +fprintf('Running on larger size and dimensionality input and measuring running time (standby...)\n'); +n = 1024; +d = 784; +X = randn(n, d); +Y = randn(n, d); +cacheSize = 100; + +b = 4*ones(2*n,1); + +tic; +BMatchingSolver([], b, X, Y, weightType, cacheSize); +mextime = toc; + +tic; +BMatchingSolverCmd([], b, X, Y, weightType, cacheSize); +cmdtime = toc; + +fprintf('Mex took %f seconds, CMD line took %f seconds\n', mextime, cmdtime); + + +%% trying bdmatching + +m = 50; +n = 100; + +W = rand(m,n); + +lb = [ones(m,1); zeros(n,1)]; +ub = ones(m+n,1); + +[W2, ~, ~, b, I, J] = bdmatch_augment(W, [], [], lb, ub); + +mexsol = BMatchingSolver(W2, b, [], [], 0, 10); +mexsol = mexsol(I,J); + +lpsol = round(lprelax(W, true(m,n), lb, ub)); + +if nnz(mexsol - lpsol) + fprintf('different solution for bd-matching matrix\n'); +else + fprintf('bd-match weight matrix \t\tOK\n'); +end + +%% +m = 50; +n = 100; +d = 5; +X = floor(256*rand(m, d)); +Y = floor(256*rand(n, d)); + +lb = [ones(m,1); zeros(n,1)]; +ub = ones(m+n,1); + +selfX = sum(X.^2,2); +selfY = sum(Y.^2,2); +W = -sqrt(bsxfun(@plus, selfX, selfY') - 2*X*Y'); + +lpsol = round(lprelax(W, true(m,n), lb, ub)); + +%[W, ~, ~, b, I, J] = bdmatch_augment(W, [], [], lb, ub); +%mexsol = BMatchingSolver(W, b, [], [], 1, 10); + +[~, X2, Y2, b, I, J] = bdmatch_augment([], X, Y, lb, ub); + +mexsol = BMatchingSolver([], b, X2, Y2, 1, 10); +mexsol = mexsol(I,J); + +if nnz(mexsol - lpsol) + fprintf('different solution for bd-matching Euclidean dist.\n'); +else + fprintf('bd-match Euclidean dist. \tOK\n'); +end +