Home > manopt > tools > productmanifold.m

productmanifold

PURPOSE ^

Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.

SYNOPSIS ^

function M = productmanifold(elements)

DESCRIPTION ^

 Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.

 function M = productmanifold(elements)

 Input: an elements structure such that each field contains a manifold
 structure.
 
 Output: a manifold structure M representing the manifold obtained by
 taking the Cartesian product of the manifolds described in the elements
 structure, with the metric obtained by element-wise extension. Points
 and vectors are stored as structures with the same fieldnames as in
 elements.

 Example:
 M = productmanifold(struct('X', spherefactory(3), 'Y', spherefactory(4)))
 disp(M.name());
 x = M.rand()

 Points of M = S^2 x S^3 are represented as structures with two fields, X
 and Y. The values associated to X are points of S^2, and likewise points
 of S^3 for the field Y. Tangent vectors are also represented as
 structures with two corresponding fields X and Y.
 
 See also: powermanifold

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function M = productmanifold(elements)
% Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.
%
% function M = productmanifold(elements)
%
% Input: elements is a structure in which every field holds a manifold
% structure (one factor of the product).
%
% Output: M is a manifold structure representing the Cartesian product of
% the manifolds stored in elements, with the metric obtained by
% element-wise extension. Points and vectors are stored as structures with
% the same fieldnames as in elements.
%
% Example:
% M = productmanifold(struct('X', spherefactory(3), 'Y', spherefactory(4)))
% disp(M.name());
% x = M.rand()
%
% Points of M = S^2 x S^3 are represented as structures with two fields, X
% and Y. The values associated to X are points of S^2, and likewise points
% of S^3 for the field Y. Tangent vectors are also represented as
% structures with two corresponding fields X and Y.
%
% See also: powermanifold

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%   NB, July 4, 2013: Added support for vec, mat, tangent.
%                     Added support for egrad2rgrad and ehess2rhess.
%                     Modified hash function to make hash strings shorter.


    % Field names of the factors, in the order reported by fieldnames.
    names = fieldnames(elements);
    count = numel(names);

    assert(count >= 1, ...
           'elements must be a structure with at least one field.');

    % ---------------------------------------------------------------------
    % Register the operations of the product manifold. The nested functions
    % they point to are defined at the bottom of this file.
    % ---------------------------------------------------------------------
    M.name = @name;
    M.dim = @dim;
    M.inner = @inner;
    M.norm = @(x, d) sqrt(M.inner(x, d, d));
    M.dist = @dist;
    M.typicaldist = @typicaldist;
    M.proj = @proj;
    M.tangent = @tangent;

    % The flag starts out true and is cleared as soon as one factor
    % explicitly declares it false. Factors without the field are skipped.
    M.tangent2ambient_is_identity = true;
    for idx = 1 : count
        Mk = elements.(names{idx});
        if ~isfield(Mk, 'tangent2ambient_is_identity')
            continue;
        end
        if ~Mk.tangent2ambient_is_identity
            M.tangent2ambient_is_identity = false;
            break;
        end
    end

    M.tangent2ambient = @tangent2ambient;
    M.egrad2rgrad = @egrad2rgrad;
    M.ehess2rhess = @ehess2rhess;
    M.exp = @exp;
    M.retr = @retr;
    M.log = @log;
    M.hash = @hash;
    M.lincomb = @lincomb;
    M.rand = @rand;
    M.randvec = @randvec;
    M.zerovec = @zerovec;
    M.transp = @transp;
    M.pairmean = @pairmean;

    % Probe each factor for the length of its vectorized tangent vectors by
    % vectorizing the zero vector at a random point. The probe stops at the
    % first factor that has no vec field; in that case the product offers
    % neither vec nor mat.
    has_vec = true;
    lengths = zeros(count, 1);
    for idx = 1 : count
        Mk = elements.(names{idx});
        if ~isfield(Mk, 'vec')
            has_vec = false;
            break;
        end
        probe_x = Mk.rand();
        probe_u = Mk.zerovec(probe_x);
        lengths(idx) = length(Mk.vec(probe_x, probe_u));
    end
    % offsets(j) is the first index of factor j in the stacked vector;
    % offsets(end)-1 is the total length.
    offsets = cumsum([1 ; lengths]);

    if has_vec
        M.vec = @vec;
        M.mat = @mat;
    end

    % vec/mat of the product are reported as isometries only if every
    % factor has a vecmatareisometries field and it returns true.
    all_isometric = true;
    for idx = 1 : count
        Mk = elements.(names{idx});
        if ~(isfield(Mk, 'vecmatareisometries') && ...
             Mk.vecmatareisometries())
            all_isometric = false;
            break;
        end
    end
    M.vecmatareisometries = @() all_isometric;


    % ---------------------------------------------------------------------
    % Nested functions implementing the operations registered above.
    % ---------------------------------------------------------------------

    % Human readable description listing each field name with the name of
    % the manifold stored under it.
    function str = name()
        str = 'Product manifold: ';
        sep = '';
        for j = 1 : count
            key = names{j};
            str = [str sprintf([sep '[%s: %s]'], ...
                               key, elements.(key).name())]; %#ok<AGROW>
            sep = ' x ';
        end
    end

    % Sum of the dimensions reported by the factors.
    function d = dim()
        d = 0;
        for j = 1 : count
            Mj = elements.(names{j});
            d = d + Mj.dim();
        end
    end

    % Sum over the factors of the inner products of the matching
    % components of u and v at x.
    function val = inner(x, u, v)
        val = 0;
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            val = val + Mj.inner(x.(key), u.(key), v.(key));
        end
    end

    % Square root of the sum of the squared distances in each factor.
    function d = dist(x, y)
        sqd = 0;
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            sqd = sqd + Mj.dist(x.(key), y.(key))^2;
        end
        d = sqrt(sqd);
    end

    % Square root of the sum of the squared typical distances of the
    % factors.
    function d = typicaldist
        sqd = 0;
        for j = 1 : count
            Mj = elements.(names{j});
            sqd = sqd + Mj.typicaldist()^2;
        end
        d = sqrt(sqd);
    end

    % Applies proj of each factor to the matching components of x and u.
    function v = proj(x, u)
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            v.(key) = Mj.proj(x.(key), u.(key));
        end
    end

    % Applies tangent of each factor to the matching components of x and u.
    function v = tangent(x, u)
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            v.(key) = Mj.tangent(x.(key), u.(key));
        end
    end

    % Applies tangent2ambient of each factor that provides it; for the
    % other factors the component of u is copied unchanged.
    function v = tangent2ambient(x, u)
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            if isfield(Mj, 'tangent2ambient')
                v.(key) = Mj.tangent2ambient(x.(key), u.(key));
            else
                v.(key) = u.(key);
            end
        end
    end

    % Replaces each component of g by the result of the egrad2rgrad of the
    % matching factor. g is updated in place and returned.
    function g = egrad2rgrad(x, g)
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            g.(key) = Mj.egrad2rgrad(x.(key), g.(key));
        end
    end

    % Replaces each component of h by the result of the ehess2rhess of the
    % matching factor, called with the components of x, eg, eh and h.
    % h is updated in place and returned.
    function h = ehess2rhess(x, eg, eh, h)
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            h.(key) = Mj.ehess2rhess(x.(key), eg.(key), eh.(key), h.(key));
        end
    end

    % Applies exp of each factor with a common step t (1.0 if omitted).
    function y = exp(x, u, t)
        if nargin <= 2
            t = 1.0;
        end
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            y.(key) = Mj.exp(x.(key), u.(key), t);
        end
    end

    % Applies retr of each factor with a common step t (1.0 if omitted).
    function y = retr(x, u, t)
        if nargin <= 2
            t = 1.0;
        end
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            y.(key) = Mj.retr(x.(key), u.(key), t);
        end
    end

    % Applies log of each factor to the matching components of x1 and x2.
    function u = log(x1, x2)
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            u.(key) = Mj.log(x1.(key), x2.(key));
        end
    end

    % Concatenates the hashes produced by the factors, passes the result
    % through hashmd5 and prefixes it with the letter 'z'.
    function str = hash(x)
        str = '';
        for j = 1 : count
            key = names{j};
            str = [str elements.(key).hash(x.(key))]; %#ok<AGROW>
        end
        str = ['z' hashmd5(str)];
    end

    % Linear combination, delegated to lincomb of each factor. Accepts
    % either (x, a1, u1) or (x, a1, u1, a2, u2); the scalars a1 and a2 are
    % passed unchanged to every factor. Any other argument count is an
    % error.
    function v = lincomb(x, a1, u1, a2, u2)
        switch nargin
            case 3
                for j = 1 : count
                    key = names{j};
                    Mj = elements.(key);
                    v.(key) = Mj.lincomb(x.(key), a1, u1.(key));
                end
            case 5
                for j = 1 : count
                    key = names{j};
                    Mj = elements.(key);
                    v.(key) = Mj.lincomb(x.(key), a1, u1.(key), ...
                                                  a2, u2.(key));
                end
            otherwise
                error('Bad usage of productmanifold.lincomb');
        end
    end

    % Point whose components are produced by rand of each factor.
    function x = rand()
        for j = 1 : count
            key = names{j};
            x.(key) = elements.(key).rand();
        end
    end

    % Vector whose components are produced by randvec of each factor, then
    % scaled by 1/sqrt(number of factors) through M.lincomb.
    % NOTE(review): the scaling presumably compensates for stacking
    % unit-norm components -- confirm against the factors' randvec.
    function u = randvec(x)
        for j = 1 : count
            key = names{j};
            u.(key) = elements.(key).randvec(x.(key));
        end
        scale = 1/sqrt(count);
        u = M.lincomb(x, scale, u);
    end

    % Vector whose components are produced by zerovec of each factor.
    function u = zerovec(x)
        for j = 1 : count
            key = names{j};
            u.(key) = elements.(key).zerovec(x.(key));
        end
    end

    % Applies transp of each factor to the matching components.
    function v = transp(x1, x2, u)
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            v.(key) = Mj.transp(x1.(key), x2.(key), u.(key));
        end
    end

    % Applies pairmean of each factor to the matching components.
    function y = pairmean(x1, x2)
        for j = 1 : count
            key = names{j};
            Mj = elements.(key);
            y.(key) = Mj.pairmean(x1.(key), x2.(key));
        end
    end

    % Stacks the vectorized components of u_mat into one column vector,
    % factor j occupying the indices offsets(j) : offsets(j+1)-1.
    function u_vec = vec(x, u_mat)
        u_vec = zeros(offsets(end)-1, 1);
        for j = 1 : count
            key = names{j};
            span = offsets(j) : (offsets(j+1)-1);
            u_vec(span) = elements.(key).vec(x.(key), u_mat.(key));
        end
    end

    % Inverse layout of vec: hands each factor its slice of u_vec and
    % stores the result of the factor's mat under the factor's field name.
    function u_mat = mat(x, u_vec)
        u_mat = struct();
        for j = 1 : count
            key = names{j};
            span = offsets(j) : (offsets(j+1)-1);
            u_mat.(key) = elements.(key).mat(x.(key), u_vec(span));
        end
    end

end

Generated on Mon 10-Sep-2018 11:48:06 by m2html © 2005