
# multinomialdoublystochasticgeneralfactory

## PURPOSE

Manifold of n-by-m positive matrices whose row sums are p and whose column sums are q.
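In symbols, the set and the geometry implemented in the source code can be summarized as follows. This summary is read off the code rather than quoted from the cited papers. Entrywise product and division are written ⊙ and ⊘, exp acts entrywise, Π_X is the map computed by `projection`, S_{p,q} stands for the rescaling performed by `doubly_stochastic_general`, and δ̇ is the directional derivative of δ = grad f along η. The set is nonempty only if sum(p) = sum(q), because both families of constraints fix the total sum of X. Of the n + m linear constraints, only n + m - 1 are independent, which gives the stated dimension.

```latex
\begin{aligned}
\mathcal{M} &= \{ X \in \mathbb{R}^{n \times m} : X_{ij} > 0,\ X\mathbf{1}_m = p,\ X^\top\mathbf{1}_n = q \},
  \qquad \dim\mathcal{M} = (n-1)(m-1),\\
T_X\mathcal{M} &= \{ \eta \in \mathbb{R}^{n \times m} : \eta\mathbf{1}_m = 0,\ \eta^\top\mathbf{1}_n = 0 \},
  \qquad \langle \eta, \xi \rangle_X = \sum_{i,j} \frac{\eta_{ij}\,\xi_{ij}}{X_{ij}},\\
\Pi_X(Z) &= Z - (\alpha\mathbf{1}_m^\top + \mathbf{1}_n\beta^\top) \odot X,
  \qquad \begin{bmatrix} \operatorname{diag}(p) & X \\ X^\top & \operatorname{diag}(q) \end{bmatrix}
  \begin{bmatrix} \alpha \\ \beta \end{bmatrix}
  = \begin{bmatrix} Z\mathbf{1}_m \\ Z^\top\mathbf{1}_n \end{bmatrix},\\
\operatorname{grad} f(X) &= \Pi_X\big(X \odot \nabla f(X)\big),
  \qquad R_X(t\eta) = \mathcal{S}_{p,q}\big(X \odot \exp(t\,\eta \oslash X)\big),\\
\operatorname{Hess} f(X)[\eta] &= \Pi_X\big(\dot\delta - \tfrac12\,(\delta \odot \eta) \oslash X\big),
  \qquad \delta = \operatorname{grad} f(X).
\end{aligned}
```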

## SYNOPSIS

function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)
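The factory itself only checks that p has n rows and q has m rows. The positivity and equal-sum conditions stated in the description are left to the caller. A minimal MATLAB sketch of those checks follows, using made-up example values for n, m, p, and q. It also forms p*q'/sum(p), which is one strictly positive matrix with row sums p and column sums q. This shows the set is nonempty whenever the checks pass.

```matlab
% Made-up example data: n = 3 rows, m = 4 columns.
n = 3; m = 4;
p = [0.2; 0.3; 0.5];          % target row sums
q = [0.1; 0.4; 0.25; 0.25];   % target column sums

% Conditions from the description; the factory itself only checks sizes.
assert(isequal(size(p), [n, 1]), 'p must be an n-by-1 column vector.');
assert(isequal(size(q), [m, 1]), 'q must be an m-by-1 column vector.');
assert(all(p > 0) && all(q > 0), 'p and q must be entrywise positive.');
% Both marginals fix the total mass ones(n,1)'*X*ones(m,1).
assert(abs(sum(p) - sum(q)) <= 1e-12 * sum(p), 'sum(p) must equal sum(q).');

% One strictly positive matrix with these marginals: the scaled outer product.
% Row i sums to p(i)*sum(q)/sum(p) = p(i); column j sums to q(j).
X0 = p * q' / sum(p);
fprintf('row-sum residual:    %g\n', max(abs(sum(X0, 2) - p)));
fprintf('column-sum residual: %g\n', max(abs(sum(X0, 1)' - q)));
```

Any such X0 can serve as a deterministic starting point. The factory's own `M.rand` instead rescales a random positive matrix with `doubly_stochastic_general`.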

## DESCRIPTION

```
Manifold of n-by-m positive matrices whose row sums are p and whose column sums are q.

function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)

X > 0 (entrywise).
X1 = p, where p is a positive column vector of size n.
X'1 = q, where q is a positive column vector of size m.

Ensure that p > 0 and q > 0. Also, ensure that sum(p) == sum(q).

Please cite the Manopt paper as well as the research papers:

@Techreport{mishra21a,
Title   = {Manifold optimization for optimal transport},
Author  = {Mishra, B. and Satya Dev, N. T. V. and Kasai, H. and Jawanpuria, P.},
Journal = {arXiv preprint arXiv:2103.00902},
Year    = {2021}
}

@article{douik2019manifold,
title={Manifold optimization over the set of doubly stochastic matrices: A second-order geometry},
author={Douik, A. and Hassibi, B.},
journal={IEEE Transactions on Signal Processing},
volume={67},
number={22},
pages={5761--5774},
year={2019}
}

@article{shi21a,
title={Coupling matrix manifolds assisted optimization for optimal transport problems},
author={Shi, D. and Gao, J. and Hong, X. and Choy, S. T. B. and Wang, Z.},
journal={Machine Learning},
pages={1--26},
year={2021}
}

This factory extends the factory
multinomialdoublystochasticfactory
to handle general row and column sums.
```

## CROSS-REFERENCE INFORMATION

This function calls:
• doubly_stochastic_general Project a matrix to the doubly stochastic matrices (Sinkhorn's algorithm)
• hashmd5 Computes the MD5 hash of input data.
• matrixlincomb Linear combination function for tangent vectors represented as matrices.
This function is called by:

## SOURCE CODE

```
function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)
% Manifold of n-by-m positive matrices whose row sums are p and whose column sums are q.
%
% function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)
%
%  X > 0 (entrywise).
%  X1 = p, where p is a positive column vector of size n.
%  X'1 = q, where q is a positive column vector of size m.
%
% Ensure that p > 0 and q > 0. Also, ensure that sum(p) == sum(q).
%
%
% Please cite the Manopt paper as well as the research papers:
%
%
% @Techreport{mishra21a,
%   Title   = {Manifold optimization for optimal transport},
%   Author  = {Mishra, B. and Satya Dev, N. T. V. and Kasai, H. and Jawanpuria, P.},
%   Journal = {arXiv preprint arXiv:2103.00902},
%   Year    = {2021}
% }
%
% @article{douik2019manifold,
% title={Manifold optimization over the set of doubly stochastic matrices: A second-order geometry},
%  author={Douik, A. and Hassibi, B.},
%  journal={IEEE Transactions on Signal Processing},
%  volume={67},
%  number={22},
%  pages={5761--5774},
%  year={2019}
%}
%
%
% @article{shi21a,
% title={Coupling matrix manifolds assisted optimization for optimal transport problems},
%  author={Shi, D. and Gao, J. and Hong, X. and Choy, S. T. B. and Wang, Z.},
%  journal={Machine Learning},
%  pages={1--26},
%  year={2021}
%}
%
%
% This factory extends the factory
% multinomialdoublystochasticfactory
% to handle general row and column sums.
%
%

% This file is part of Manopt: www.manopt.org.
% Original author: Bamdev Mishra, Oct 30, 2020.
% Contributors:
% Change log:

    e1 = ones(n, 1);
    e2 = ones(m, 1);

    maxDSiters = min(1000, n*m); % Ideally, this should be supplied by the user.

    if size(p, 1) ~= n
        error('p should be a column vector of size n.');
    end

    if size(q, 1) ~= m
        error('q should be a column vector of size m.');
    end

    function [alpha, beta] = mylinearsolve(X, b) % BM okay
        % zeta = sparse(A)\b; % sparse might not be better perf.-wise.
        % where A = [diag(p) X ; X' diag(q)];
        %
        % Even faster is to create a function handle
        % computing A*x (x is a given vector).
        % Make sure that A is not created, and X is only
        % passed to mylinearsolve and not A.
        %
        % A is symmetric positive semidefinite, with null space spanned
        % by [e1; -e2]; right-hand sides of the form [sum(Z, 2); sum(Z, 1)']
        % are orthogonal to that vector, so the system is consistent.
        [zeta, ~, ~, iter] = pcg(@mycompute, b, 1e-6, 100);
        function Ax = mycompute(x) % BM okay
            xtop = x(1:n,1); % vector of size n akin to alpha
            xbottom = x(n+1:end,1); % vector of size m akin to beta
            Axtop = xtop.*p + X*xbottom;
            Axbottom = X'*xtop + xbottom.*q;
            Ax = [Axtop; Axbottom];
        end
        alpha = zeta(1:n, 1);
        beta = zeta(n+1:end, 1);
    end


    M.name = @() sprintf('%dx%d matrices with positive entries such that row sums are p and column sums are q', n, m);

    M.dim = @() (n-1)*(m-1); % BM okay

    % Fisher metric
    M.inner = @iproduct; % BM okay
    function ip = iproduct(X, eta, zeta)
        ip = sum((eta(:).*zeta(:))./X(:));
    end

    M.norm = @(X, eta) sqrt(M.inner(X, eta, eta)); % BM okay

    M.dist = @(X, Y) error('multinomialdoublystochasticgeneralfactory.dist not implemented yet.');

    % The manifold is not compact as a result of the choice of the metric,
    % thus any choice here is arbitrary. This is notably used to pick
    % default values of initial and maximal trust-region radius in the
    % trustregions solver.
    M.typicaldist = @() m+n;

    % Pick a random point on the manifold
    M.rand = @random; % BM okay
    function X = random()
        Z = abs(randn(n, m));     % Random point in the ambient space
        X = doubly_stochastic_general(Z, p, q, maxDSiters); % Sinkhorn scaling onto the manifold
    end

    % Pick a random vector in the tangent space at X.
    M.randvec = @randomvec; % BM okay
    function eta = randomvec(X) % A random vector in the tangent space
        % A random vector in the ambient space
        Z = randn(n, m);
        % Projection of the vector onto the tangent space
        b = [sum(Z, 2) ; sum(Z, 1)'];
        [alpha, beta] = mylinearsolve(X, b);
        eta = Z - (alpha*e2' + e1*beta').*X;
        % Normalizing the vector
        nrm = M.norm(X, eta);
        eta = eta / nrm;
    end

    % Projection of vector eta in the ambient space to the tangent space.
    M.proj = @projection;  % BM okay
    function etaproj = projection(X, eta) % Projection of the vector eta in the ambient space onto the tangent space
        b = [sum(eta, 2) ; sum(eta, 1)'];
        [alpha, beta] = mylinearsolve(X, b);
        etaproj = eta - (alpha*e2' + e1*beta').*X;
    end

    M.tangent = M.proj;
    M.tangent2ambient = @(X, eta) eta; % BM okay

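    % Conversion of Euclidean to Riemannian gradient
    M.egrad2rgrad = @egrad2rgrad;
    function rgrad = egrad2rgrad(X, egrad)
        mu = X.*egrad; % Riesz representation of egrad under the Fisher metric
        b = [sum(mu, 2) ; sum(mu, 1)'];
        [alpha, beta] = mylinearsolve(X, b);
        rgrad = mu - (alpha*e2' + e1*beta').*X;
    end
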
    % First-order retraction
    M.retr = @retraction;
    function Y = retraction(X, eta, t)
        if nargin < 3
            t = 1.0;
        end
        Y = X.*exp(t*(eta./X));

        Y = min(Y, 1e50); % For numerical stability
        Y = max(Y, 1e-50); % For numerical stability

        Y = doubly_stochastic_general(Y, p, q, maxDSiters);
    end

    % Conversion of Euclidean to Riemannian Hessian
    M.ehess2rhess = @ehess2rhess; % BM okay
    function rhess = ehess2rhess(X, egrad, ehess, eta)

        % Computing the directional derivative of the Riemannian
        % gradient
        gamma = egrad.*X;
        gammadot = ehess.*X + egrad.*eta;

        b = [sum(gamma, 2) ; sum(gamma, 1)'];
        bdot = [sum(gammadot, 2) ; sum(gammadot, 1)'];
        [alpha, beta] = mylinearsolve(X, b);
        [alphadot, betadot] = mylinearsolve(X, bdot - [eta*beta; eta'*alpha]);

        S = (alpha*e2' + e1*beta');
        deltadot = gammadot - (alphadot*e2' + e1*betadot').*X - S.*eta;

        % Projecting gamma
        delta = gamma - S.*X; % rgrad

        % Riemannian Hessian in the ambient space
        nabla = deltadot - 0.5*(delta.*eta)./X;

        % Riemannian Hessian on the tangent space
        rhess = projection(X, nabla);
    end


    % Miscellaneous manifold functions % BM okay
    M.hash = @(X) ['z' hashmd5(X(:))];
    M.lincomb = @matrixlincomb;
    M.zerovec = @(X) zeros(n, m);
    M.transp = @(X1, X2, d) projection(X2, d);
    M.vec = @(X, U) U(:);
    M.mat = @(X, u) reshape(u, n, m);
    M.vecmatareisometries = @() false;

end
```
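The projection behind `M.proj`, `M.randvec`, `egrad2rgrad`, and `ehess2rhess` solves, through `mylinearsolve`, an (n+m)-by-(n+m) system with matrix A = [diag(p) X; X' diag(q)]. It then subtracts (alpha*e2' + e1*beta').*X. The MATLAB sketch below, on made-up data, checks the properties this relies on:

- A is singular, with null vector [ones(n,1); -ones(m,1)], and has rank n + m - 1.
- The projected matrix has zero row and column sums.
- The removed part is orthogonal to it in the Fisher metric.

It then applies the retraction pattern of `M.retr`. Two substitutions are assumptions of this sketch, not the source's method. First, `pinv` replaces `pcg`. Second, a plain alternating row/column rescaling loop stands in for `doubly_stochastic_general`, whose internals may differ.

```matlab
% Small made-up example; X is a point of the set (row sums p, column sums q).
n = 3; m = 4;
p = [0.2; 0.3; 0.5];
q = [0.1; 0.4; 0.25; 0.25];
X = p * q' / sum(p);
e1 = ones(n, 1);
e2 = ones(m, 1);

% Fisher metric, as in iproduct.
fisher = @(Xpt, u, w) sum((u(:) .* w(:)) ./ Xpt(:));

% Dense form of the operator applied by mycompute. It equals C*diag(X(:))*C',
% where C maps vec(eta) to [eta*e2; eta'*e1], so rank(A) = rank(C) = n+m-1
% and n*m - rank(A) = (n-1)*(m-1), matching M.dim.
A = [diag(p), X; X', diag(q)];

% Tangent-space projection of an arbitrary ambient matrix Z.
% pinv stands in for pcg only because this example is tiny.
Z = randn(n, m);
b = [sum(Z, 2); sum(Z, 1)'];
zeta = pinv(A) * b;
alpha = zeta(1:n);
beta = zeta(n+1:end);
eta = Z - (alpha * e2' + e1 * beta') .* X;

% Retraction pattern from the source: multiplicative update, clipping,
% then alternating row/column rescaling (a stand-in for
% doubly_stochastic_general, whose internals may differ).
t = 0.01;
Y = X .* exp(t * (eta ./ X));
Y = max(min(Y, 1e50), 1e-50);
for k = 1:1000
    Y = Y .* ((p ./ sum(Y, 2)) * e2');   % rows now sum to p
    Y = Y .* (e1 * (q' ./ sum(Y, 1)));   % columns now sum to q
    if max(abs(sum(Y, 2) - p)) < 1e-12
        break;
    end
end

fprintf('null-vector residual  %g\n', norm(A * [e1; -e2]));
fprintf('rank(A)               %d\n', rank(A));
fprintf('tangent residual      %g\n', norm(sum(eta, 2)) + norm(sum(eta, 1)));
fprintf('Fisher orthogonality  %g\n', fisher(X, Z - eta, eta));
fprintf('retraction residual   %g\n', max(abs(sum(Y, 2) - p)) + max(abs(sum(Y, 1)' - q)));
```

Because A is only positive semidefinite and the right-hand sides are consistent, an iterative method such as `pcg` that only needs products A*x is a reasonable fit. The dense `pinv` route above is sensible only for tiny n and m.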

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005