Home > manopt > tools > productmanifold.m

productmanifold

PURPOSE ^

Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.

SYNOPSIS ^

function M = productmanifold(elements)

DESCRIPTION ^

 Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.

 function M = productmanifold(elements)

 Input: an elements structure such that each field contains a manifold
 structure.
 
 Output: a manifold structure M representing the manifold obtained by
 taking the Cartesian product of the manifolds described in the elements
 structure, with the metric obtained by element-wise extension. Points
 and vectors are stored as structures with the same fieldnames as in
 elements.

 Example:
 M = productmanifold(struct('X', spherefactory(3), 'Y', spherefactory(4)))
 disp(M.name());
 x = M.rand()

 Points of M = S^2 x S^3 are represented as structures with two fields, X
 and Y. The values associated to X are points of S^2, and likewise points
 of S^3 for the field Y. Tangent vectors are also represented as
 structures with two corresponding fields X and Y.
 
 See also: powermanifold

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function M = productmanifold(elements)
% Returns a structure describing a product manifold M = M1 x M2 x ... x Mn.
%
% function M = productmanifold(elements)
%
% Input: a structure elements in which every field holds a manifold
% structure.
%
% Output: a manifold structure M representing the Cartesian product of the
% manifolds stored in elements, with the metric obtained by element-wise
% extension. Points and tangent vectors of M are stored as structures with
% the same field names as elements.
%
% Example:
% M = productmanifold(struct('X', spherefactory(3), 'Y', spherefactory(4)))
% disp(M.name());
% x = M.rand()
%
% Points of M = S^2 x S^3 are represented as structures with two fields, X
% and Y. The values associated to X are points of S^2, and likewise points
% of S^3 for the field Y. Tangent vectors are also represented as
% structures with two corresponding fields X and Y.
%
% See also: powermanifold

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   July  4, 2013 (NB):
%       Added support for vec, mat, tangent.
%       Added support for egrad2rgrad and ehess2rhess.
%       Modified hash function to make hash strings shorter.
%
%   Dec. 17, 2018 (NB):
%       Added check all_elements_provide() to many functions, so that if,
%       for example, one of the elements does not provide exp(), then the
%       product manifold also won't provide exp(). This makes it easier for
%       tools such as, for example, checkgradient, to determine whether exp
%       is available or not.
%
%   Feb. 10, 2020 (NB):
%       Added warnings about calling egrad2rgrad and ehess2rhess without
%       storedb and key, even if some base manifolds allow them.
%
%   Jan. 4, 2021 (NB):
%       Changes for compatibility with Octave 6.1.0: by introducing a
%       "helper" function, we separate out the pre-computations. This way,
%       all pre-computed quantities are passed as input to the helper
%       function. This makes them available to nested subfunctions.
%       The extra step is not necessary in Matlab.


    field_names = fieldnames(elements);
    num_fields = numel(field_names);

    assert(num_fields >= 1, ...
           'elements must be a structure with at least one field.');

    % --- Precomputations for the vec/mat pair ---------------------------
    %
    % For each element, record the length of the column vector its vec
    % function produces, measured on the zero tangent vector at a random
    % point. Stop at the first element that has no vec: in that case the
    % product does not offer vec/mat at all.
    has_vec = true;
    lens = zeros(num_fields, 1);
    k = 0;
    while has_vec && k < num_fields
        k = k + 1;
        Mk = elements.(field_names{k});
        if isfield(Mk, 'vec')
            pt = Mk.rand();
            lens(k) = length(Mk.vec(pt, Mk.zerovec(pt)));
        else
            has_vec = false;
        end
    end

    % starts(k) is the index where element k's entries begin in the
    % stacked column vector; starts(end) - 1 is the total length.
    starts = cumsum([1 ; lens]);

    % vec/mat are isometries on the product only if vec is available and
    % every element declares its own vec/mat pair to be isometries.
    isometric = has_vec;
    for k = 1 : num_fields
        Mk = elements.(field_names{k});
        if ~(isfield(Mk, 'vecmatareisometries') && Mk.vecmatareisometries())
            isometric = false;
            break;
        end
    end

    % The helper function is the actual factory. Passing the precomputed
    % quantities as inputs makes them visible to its nested functions in
    % Octave as well as in Matlab (see the Jan. 4, 2021 change log entry).
    M = productmanifoldhelper(elements, field_names, num_fields, ...
                              has_vec, starts, isometric);

end
0098 
0099 
function M = productmanifoldhelper(elements, elems, nelems, ...
                                   vec_available, vec_pos, ...
                                   vecmatareisometries)
% Factory for the product manifold, given precomputed quantities.
%
% function M = productmanifoldhelper(elements, elems, nelems, ...
%                                    vec_available, vec_pos, ...
%                                    vecmatareisometries)
%
% This is the actual factory behind productmanifold. It is a separate
% function so that the precomputed quantities arrive as input arguments,
% which makes them visible to the nested functions below in Octave as
% well as in Matlab.
%
% Inputs:
%   elements            - structure whose fields are manifold structures.
%   elems               - cell array of the field names of elements; it
%                         fixes the order in which elements are visited.
%   nelems              - number of entries in elems.
%   vec_available       - if true, M.vec and M.mat are provided.
%   vec_pos             - vector with nelems+1 entries: the component of
%                         element i occupies entries
%                         vec_pos(i) : vec_pos(i+1)-1 of M.vec's output.
%   vecmatareisometries - value returned by M.vecmatareisometries().
%
% Output:
%   M - manifold structure. Points and tangent vectors of M are structures
%       with one field per entry of elems. Every operation is applied
%       field by field through the corresponding element. Optional
%       operations (dist, typicaldist, egrad2rgrad, ehess2rhess, exp, log,
%       transp, pairmean) are provided only if all elements provide them.

    % Returns true if and only if every element has a field method_name.
    function answer = all_elements_provide(method_name)
        answer = false;
        for i = 1 : nelems
            if ~isfield(elements.(elems{i}), method_name)
                return;
            end
        end
        answer = true;
    end

    % Name of the form 'Product manifold: [X: name1] x [Y: name2] ...'.
    M.name = @name;
    function str = name()
        str = 'Product manifold: ';
        str = [str sprintf('[%s: %s]', ...
                           elems{1}, elements.(elems{1}).name())];
        for i = 2 : nelems
            str = [str sprintf(' x [%s: %s]', ...
                   elems{i}, elements.(elems{i}).name())]; %#ok<AGROW>
        end
    end

    % The dimension of the product is the sum of the elements' dimensions.
    M.dim = @dim;
    function d = dim()
        d = 0;
        for i = 1 : nelems
            d = d + elements.(elems{i}).dim();
        end
    end

    % Product metric: sum of the elements' inner products.
    M.inner = @inner;
    function val = inner(x, u, v)
        val = 0;
        for i = 1 : nelems
            val = val + elements.(elems{i}).inner(x.(elems{i}), ...
                                               u.(elems{i}), v.(elems{i}));
        end
    end

    % The anonymous function captures M as it is here; it only needs
    % M.inner, which is already set.
    M.norm = @(x, d) sqrt(M.inner(x, d, d));

    % Distance: square root of the sum of squared element distances.
    if all_elements_provide('dist')
        M.dist = @dist;
    end
    function d = dist(x, y)
        sqd = 0;
        for i = 1 : nelems
            sqd = sqd + elements.(elems{i}).dist(x.(elems{i}), ...
                                                 y.(elems{i}))^2;
        end
        d = sqrt(sqd);
    end

    % Typical distance, combined in the same way as dist.
    if all_elements_provide('typicaldist')
        M.typicaldist = @typicaldist;
    end
    function d = typicaldist
        sqd = 0;
        for i = 1 : nelems
            sqd = sqd + elements.(elems{i}).typicaldist()^2;
        end
        d = sqrt(sqd);
    end

    % Projection, applied component-wise.
    M.proj = @proj;
    function v = proj(x, u)
        for i = 1 : nelems
            v.(elems{i}) = elements.(elems{i}).proj(x.(elems{i}), ...
                                                    u.(elems{i}));
        end
    end

    M.tangent = @tangent;
    function v = tangent(x, u)
        for i = 1 : nelems
            v.(elems{i}) = elements.(elems{i}).tangent(x.(elems{i}), ...
                                                       u.(elems{i}));
        end
    end

    % True unless some element explicitly sets the flag to false.
    M.tangent2ambient_is_identity = true;
    for k = 1 : nelems
        if isfield(elements.(elems{k}), 'tangent2ambient_is_identity')
            if ~elements.(elems{k}).tangent2ambient_is_identity
                M.tangent2ambient_is_identity = false;
                break;
            end
        end
    end

    % Elements without tangent2ambient get their component passed through
    % unchanged.
    M.tangent2ambient = @tangent2ambient;
    function v = tangent2ambient(x, u)
        for i = 1 : nelems
            if isfield(elements.(elems{i}), 'tangent2ambient')
                v.(elems{i}) = ...
                    elements.(elems{i}).tangent2ambient( ...
                                               x.(elems{i}), u.(elems{i}));
            else
                v.(elems{i}) = u.(elems{i});
            end
        end
    end

    % Euclidean-to-Riemannian gradient conversion, component-wise. Only
    % offered if every element offers it (accessing the field on an
    % element that lacks it would otherwise crash right here).
    if all_elements_provide('egrad2rgrad')
        M.egrad2rgrad = @egrad2rgrad;
        % The product calls each element with (x, g) only. Warn if some
        % element accepts more inputs, such as storedb and key.
        for ii = 1 : nelems
            if nargin(elements.(elems{ii}).egrad2rgrad) > 2
                warning('manopt:productmanifold:egrad2rgrad', ...
                       ['Product manifolds call M.egrad2rgrad with only two ', ...
                        'inputs:\nstoredb and key won''t be available.']);
                break;
            end
        end
    end
    function g = egrad2rgrad(x, g)
        for i = 1 : nelems
            g.(elems{i}) = elements.(elems{i}).egrad2rgrad(...
                                               x.(elems{i}), g.(elems{i}));
        end
    end

    % Euclidean-to-Riemannian Hessian conversion, component-wise, offered
    % under the same condition as egrad2rgrad.
    if all_elements_provide('ehess2rhess')
        M.ehess2rhess = @ehess2rhess;
        % The product calls each element with (x, eg, eh, h): four inputs.
        for ii = 1 : nelems
            if nargin(elements.(elems{ii}).ehess2rhess) > 4
                warning('manopt:productmanifold:ehess2rhess', ...
                       ['Product manifolds call M.ehess2rhess with only four ', ...
                        'inputs:\nstoredb and key won''t be available.']);
                break;
            end
        end
    end
    function h = ehess2rhess(x, eg, eh, h)
        for i = 1 : nelems
            h.(elems{i}) = elements.(elems{i}).ehess2rhess(...
                 x.(elems{i}), eg.(elems{i}), eh.(elems{i}), h.(elems{i}));
        end
    end

    % Exponential map, component-wise; the step t defaults to 1.
    if all_elements_provide('exp')
        M.exp = @exp;
    end
    function y = exp(x, u, t)
        if nargin < 3
            t = 1.0;
        end
        for i = 1 : nelems
            y.(elems{i}) = elements.(elems{i}).exp(x.(elems{i}), ...
                                                   u.(elems{i}), t);
        end
    end

    % Retraction, component-wise; the step t defaults to 1.
    M.retr = @retr;
    function y = retr(x, u, t)
        if nargin < 3
            t = 1.0;
        end
        for i = 1 : nelems
            y.(elems{i}) = elements.(elems{i}).retr(x.(elems{i}), ...
                                                    u.(elems{i}), t);
        end
    end

    % Logarithmic map, component-wise.
    if all_elements_provide('log')
        M.log = @log;
    end
    function u = log(x1, x2)
        for i = 1 : nelems
            u.(elems{i}) = elements.(elems{i}).log(x1.(elems{i}), ...
                                                   x2.(elems{i}));
        end
    end

    % Concatenates the element hashes and hashes the result again with
    % hashmd5 to keep the string short; the 'z' prefix is kept as-is.
    M.hash = @hash;
    function str = hash(x)
        str = '';
        for i = 1 : nelems
            str = [str elements.(elems{i}).hash(x.(elems{i}))]; %#ok<AGROW>
        end
        str = ['z' hashmd5(str)];
    end

    % v = a1*u1 (3 inputs) or v = a1*u1 + a2*u2 (5 inputs), component-wise.
    M.lincomb = @lincomb;
    function v = lincomb(x, a1, u1, a2, u2)
        if nargin == 3
            for i = 1 : nelems
                v.(elems{i}) = elements.(elems{i}).lincomb(x.(elems{i}), ...
                                                        a1, u1.(elems{i}));
            end
        elseif nargin == 5
            for i = 1 : nelems
                v.(elems{i}) = elements.(elems{i}).lincomb(x.(elems{i}), ...
                                     a1, u1.(elems{i}), a2, u2.(elems{i}));
            end
        else
            error('Bad usage of productmanifold.lincomb');
        end
    end

    M.rand = @rand;
    function x = rand()
        for i = 1 : nelems
            x.(elems{i}) = elements.(elems{i}).rand();
        end
    end

    % Random tangent vector: each component comes from its element, then
    % the whole is scaled by 1/sqrt(nelems). If each element's randvec
    % has unit norm, the result has unit norm too.
    M.randvec = @randvec;
    function u = randvec(x)
        for i = 1 : nelems
            u.(elems{i}) = elements.(elems{i}).randvec(x.(elems{i}));
        end
        u = M.lincomb(x, 1/sqrt(nelems), u);
    end

    M.zerovec = @zerovec;
    function u = zerovec(x)
        for i = 1 : nelems
            u.(elems{i}) = elements.(elems{i}).zerovec(x.(elems{i}));
        end
    end

    % Vector transport from x1 to x2, component-wise.
    if all_elements_provide('transp')
        M.transp = @transp;
    end
    function v = transp(x1, x2, u)
        for i = 1 : nelems
            v.(elems{i}) = elements.(elems{i}).transp(x1.(elems{i}), ...
                                              x2.(elems{i}), u.(elems{i}));
        end
    end

    if all_elements_provide('pairmean')
        M.pairmean = @pairmean;
    end
    function y = pairmean(x1, x2)
        for i = 1 : nelems
            y.(elems{i}) = elements.(elems{i}).pairmean(x1.(elems{i}), ...
                                                        x2.(elems{i}));
        end
    end

    % vec stacks the elements' vec outputs into one column vector, using
    % the ranges recorded in vec_pos; mat undoes this.
    if vec_available
        M.vec = @vec;
        M.mat = @mat;
    end

    function u_vec = vec(x, u_mat)
        u_vec = zeros(vec_pos(end)-1, 1);
        for i = 1 : nelems
            range = vec_pos(i) : (vec_pos(i+1)-1);
            u_vec(range) = elements.(elems{i}).vec(x.(elems{i}), ...
                                                   u_mat.(elems{i}));
        end
    end

    function u_mat = mat(x, u_vec)
        u_mat = struct();
        for i = 1 : nelems
            range = vec_pos(i) : (vec_pos(i+1)-1);
            u_mat.(elems{i}) = elements.(elems{i}).mat(x.(elems{i}), ...
                                                       u_vec(range));
        end
    end

    M.vecmatareisometries = @() vecmatareisometries;

end

Generated on Sun 05-Sep-2021 17:57:00 by m2html © 2005