Manifold of n-by-m positive matrices whose row sums are p and whose column sums are q.

function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)

X > 0.
X1 = p, where p is a positive column vector of size n.
X'1 = q, where q is a positive column vector of size m.
(Here 1 denotes the all-ones vector of the appropriate size.)

Ensure that p > 0 and q > 0. Also, ensure that sum(p) == sum(q).

Please cite the Manopt paper as well as the following research papers:

@Techreport{mishra21a,
  Title = {Manifold optimization for optimal transport},
  Author = {Mishra, B. and Satya Dev, N. T. V. and Kasai, H. and Jawanpuria, P.},
  Journal = {Arxiv preprint arXiv:2103.00902},
  Year = {2021}
}

@article{douik2019manifold,
  title={Manifold optimization over the set of doubly stochastic matrices: A second-order geometry},
  author={Douik, A. and Hassibi, B.},
  journal={IEEE Transactions on Signal Processing},
  volume={67},
  number={22},
  pages={5761--5774},
  year={2019}
}

@article{shi21a,
  title={Coupling matrix manifolds assisted optimization for optimal transport problems},
  author={Shi, D. and Gao, J. and Hong, X. and Choy, ST. B. and Wang, Z.},
  journal={Machine Learning},
  pages={1--26},
  year={2021}
}

This factory extends multinomialdoublystochasticfactory to handle general scaling of rows and columns, that is, arbitrary positive row sums p and column sums q.

See also multinomialdoublystochasticfactory multinomialsymmetricfactory multinomialfactory
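The constraints X > 0, X1 = p and X'1 = q can be reached from any strictly positive matrix by alternately rescaling its rows and columns (Sinkhorn-style scaling). This is presumably the role of the helper `doubly_stochastic_general` that `M.rand` and `M.retr` call in the listing, but its source is not part of this page, so the following is only a minimal standalone MATLAB sketch of the idea under that assumption, not a copy of the helper. The values of `p`, `q`, `maxiters` and `tol` are illustrative.

```matlab
% Minimal Sinkhorn-style scaling sketch (assumption: it plays the same role as
% doubly_stochastic_general, whose implementation is not shown here).
n = 4; m = 3;
p = [1; 2; 3; 4];          % target row sums, p > 0
q = [5; 3; 2];             % target column sums, q > 0, sum(q) == sum(p)
assert(abs(sum(p) - sum(q)) < 1e-12, 'sum(p) must equal sum(q).');

X = 0.5 + rand(n, m);      % any strictly positive starting matrix
maxiters = 1000;           % illustrative iteration cap
tol = 1e-12;               % illustrative stopping tolerance on the row sums
for k = 1:maxiters
    X = X .* (p ./ sum(X, 2));    % rescale rows:    X*ones(m, 1) = p
    X = X .* (q' ./ sum(X, 1));   % rescale columns: X'*ones(n, 1) = q
    if norm(sum(X, 2) - p, inf) < tol
        break;                    % columns match exactly; rows are now within tol
    end
end
```

The loop always ends on a column rescaling, so the column sums equal q to machine precision while the row sums are only within `tol`. The requirement sum(p) == sum(q) is what makes the iteration converge: row scaling sets the total mass to sum(p), column scaling sets it to sum(q), and both sets of constraints can only hold together when these agree.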
function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)
% Manifold of n-by-m positive matrices such that row sum is p and column sum is q.
%
% function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)
%
% X > 0.
% X1 = p, p is a column positive vector of size n.
% X'1 = q, q is a column positive vector of size m.
%
% Ensure that p > 0 and q > 0. Also, ensure that sum(p) == sum(q).
%
%
% Please cite the Manopt paper as well as the research papers:
%
%
% @Techreport{mishra21a,
%   Title = {Manifold optimization for optimal transport},
%   Author = {Mishra, B. and Satya Dev, N. T. V. and Kasai, H. and Jawanpuria, P.},
%   Journal = {Arxiv preprint arXiv:2103.00902},
%   Year = {2021}
% }
%
% @article{douik2019manifold,
%   title={Manifold optimization over the set of doubly stochastic matrices: A second-order geometry},
%   author={Douik, A. and Hassibi, B.},
%   journal={IEEE Transactions on Signal Processing},
%   volume={67},
%   number={22},
%   pages={5761--5774},
%   year={2019}
% }
%
%
% @article{shi21a,
%   title={Coupling matrix manifolds assisted optimization for optimal transport problems},
%   author={Shi, D. and Gao, J. and Hong, X. and Choy, ST. B. and Wang, Z.},
%   journal={Machine Learning},
%   pages={1--26},
%   year={2021}
% }
%
%
% The factory file extends the factory file
% multinomialdoublystochasticfactory
% to handle general scaling of rows and columns.
%
%
% See also multinomialdoublystochasticfactory multinomialsymmetricfactory multinomialfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Bamdev Mishra, Oct 30, 2020.
% Contributors:
% Change log:

    e1 = ones(n, 1);
    e2 = ones(m, 1);

    maxDSiters = min(1000, n*m); % Ideally it should be supplied by the user.

    if size(p, 1) ~= n
        error('p should be a column vector of size n.');
    end

    if size(q, 1) ~= m
        error('q should be a column vector of size m.');
    end

    function [alpha, beta] = mylinearsolve(X, b) % BM okay
        % zeta = sparse(A)\b; % sparse might not be better perf.-wise.
        % where A = [diag(p) X ; X' diag(q)];
        %
        % Even faster is to create a function handle
        % computing A*x (x is a given vector).
        % Make sure that A is not created, and X is only
        % passed to mylinearsolve and not A.
        % Requesting the flag output keeps pcg from printing convergence messages.
        [zeta, ~] = pcg(@mycompute, b, 1e-6, 100);
        function Ax = mycompute(x) % BM okay
            xtop = x(1:n,1); % vector of size n akin to alpha
            xbottom = x(n+1:end,1); % vector of size m akin to beta
            Axtop = xtop.*p + X*xbottom;
            Axbottom = X'*xtop + xbottom.*q;
            Ax = [Axtop; Axbottom];
        end
        alpha = zeta(1:n, 1);
        beta = zeta(n+1:end, 1);
    end

    M.name = @() sprintf('%dx%d matrices with positive entries such that row sum is p and column sum is q', n, m);

    M.dim = @() (n-1)*(m-1); % BM okay

    % Fisher metric
    M.inner = @iproduct; % BM okay
    function ip = iproduct(X, eta, zeta)
        ip = sum((eta(:).*zeta(:))./X(:));
    end

    M.norm = @(X, eta) sqrt(M.inner(X, eta, eta)); % BM okay

    M.dist = @(X, Y) error('multinomialdoublystochasticgeneralfactory.dist not implemented yet.');

    % The manifold is not compact as a result of the choice of the metric,
    % thus any choice here is arbitrary. This is notably used to pick
    % default values of initial and maximal trust-region radius in the
    % trustregions solver.
    M.typicaldist = @() m+n;

    % Pick a random point on the manifold
    M.rand = @random; % BM okay
    function X = random()
        Z = abs(randn(n, m)); % Random point in the ambient space
        X = doubly_stochastic_general(Z, p, q, maxDSiters); % Projection onto the manifold
    end

    % Pick a random vector in the tangent space at X.
    M.randvec = @randomvec; % BM okay
    function eta = randomvec(X) % A random vector in the tangent space
        % A random vector in the ambient space
        Z = randn(n, m);
        % Projection of the vector onto the tangent space
        b = [sum(Z, 2) ; sum(Z, 1)'];
        [alpha, beta] = mylinearsolve(X, b);
        eta = Z - (alpha*e2' + e1*beta').*X;
        % Normalizing the vector
        nrm = M.norm(X, eta);
        eta = eta / nrm;
    end

    % Projection of vector eta in the ambient space to the tangent space.
    M.proj = @projection; % BM okay
    function etaproj = projection(X, eta) % Projection of the vector eta in the ambient space onto the tangent space
        b = [sum(eta, 2) ; sum(eta, 1)'];
        [alpha, beta] = mylinearsolve(X, b);
        etaproj = eta - (alpha*e2' + e1*beta').*X;
    end

    M.tangent = M.proj;
    M.tangent2ambient = @(X, eta) eta; % BM okay

    % Conversion of Euclidean to Riemannian gradient
    M.egrad2rgrad = @egrad2rgrad; % BM okay
    function rgrad = egrad2rgrad(X, egrad) % projection of the Euclidean gradient
        mu = (X.*egrad);
        b = [sum(mu, 2) ; sum(mu, 1)'];
        [alpha, beta] = mylinearsolve(X, b);
        rgrad = mu - (alpha*e2' + e1*beta').*X;
    end

    % First-order retraction
    M.retr = @retraction;
    function Y = retraction(X, eta, t)
        if nargin < 3
            t = 1.0;
        end
        Y = X.*exp(t*(eta./X));

        Y = min(Y, 1e50); % For numerical stability
        Y = max(Y, 1e-50); % For numerical stability

        Y = doubly_stochastic_general(Y, p, q, maxDSiters);
    end

    % Conversion of Euclidean to Riemannian Hessian
    M.ehess2rhess = @ehess2rhess; % BM okay
    function rhess = ehess2rhess(X, egrad, ehess, eta)

        % Computing the directional derivative of the Riemannian
        % gradient
        gamma = egrad.*X;
        gammadot = ehess.*X + egrad.*eta;

        b = [sum(gamma, 2) ; sum(gamma, 1)'];
        bdot = [sum(gammadot, 2) ; sum(gammadot, 1)'];
        [alpha, beta] = mylinearsolve(X, b);
        [alphadot, betadot] = mylinearsolve(X, bdot - [eta*beta; eta'*alpha]);

        S = (alpha*e2' + e1*beta');
        deltadot = gammadot - (alphadot*e2' + e1*betadot').*X - S.*eta; % rgraddot

        % Computing Riemannian gradient
        delta = gamma - S.*X; % rgrad

        % Riemannian Hessian in the ambient space
        nabla = deltadot - 0.5*(delta.*eta)./X;

        % Riemannian Hessian on the tangent space
        rhess = projection(X, nabla);
    end


    % Miscellaneous manifold functions % BM okay
    M.hash = @(X) ['z' hashmd5(X(:))];
    M.lincomb = @matrixlincomb;
    M.zerovec = @(X) zeros(n, m);
    M.transp = @(X1, X2, d) projection(X2, d);
    M.vec = @(X, U) U(:);
    M.mat = @(X, u) reshape(u, n, m);
    M.vecmatareisometries = @() false;

end
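`M.proj`, `M.randvec`, `M.egrad2rgrad` and the last step of `ehess2rhess` all subtract a term of the form (alpha*e2' + e1*beta').*X, where [alpha; beta] solves the system [diag(p) X; X' diag(q)] [alpha; beta] = b that `mylinearsolve` handles matrix-free with `pcg`. The MATLAB sketch below forms that matrix explicitly for a small example (sensible only for small n and m) and checks two properties: the projected direction has zero row and column sums, and the removed component is orthogonal to it under the Fisher metric of `iproduct`. It uses `pinv` instead of `pcg` because the matrix is singular at a feasible X ([ones(n,1); -ones(m,1)] lies in its null space); that substitution, the sizes and the random data are illustrative assumptions.

```matlab
% Sketch: tangent-space projection at a feasible X with an explicit system matrix.
% Assumptions: small illustrative sizes; pinv replaces the factory's matrix-free pcg.
n = 4; m = 3;
p = [1; 2; 3; 4];
q = [5; 3; 2];

% Build a feasible X (positive, row sums p, column sums q) by alternate rescaling.
X = 0.5 + rand(n, m);
for k = 1:1000
    X = X .* (p ./ sum(X, 2));
    X = X .* (q' ./ sum(X, 1));
end

e1 = ones(n, 1);
e2 = ones(m, 1);
eta = randn(n, m);                         % arbitrary ambient direction

A = [diag(p), X; X', diag(q)];             % operator applied by mycompute
b = [sum(eta, 2); sum(eta, 1)'];
zeta = pinv(A) * b;                        % min-norm solution of the singular system
alpha = zeta(1:n);
beta = zeta(n+1:end);
etaproj = eta - (alpha*e2' + e1*beta') .* X;

% Fisher metric, same formula as iproduct.
fisher = @(X, u, v) sum((u(:) .* v(:)) ./ X(:));
ortho = fisher(X, eta - etaproj, etaproj); % should vanish up to rounding
```

The orthogonality comes from the Fisher metric: for a normal term S.*X with S = alpha*e2' + e1*beta', the inner product with any xi reduces to alpha'*(xi*e2) + beta'*(xi'*e1), which is zero whenever xi has zero row and column sums. Because alpha*e2' + e1*beta' does not change when a constant is added to alpha and subtracted from beta, any solution of the singular system yields the same projection, which is why an iterative solver such as `pcg` is adequate here.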