0001 function M = fixedranktensorembeddedfactory(tensor_size, tensor_rank)
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
% Fail early with an explicit message when the Tensor Toolbox is missing:
% the tensor, ttensor, tenmat, ttm, ... calls below all come from it.
assert(exist('ttensor', 'file') == 2, sprintf( ...
    ['It seems the Tensor Toolbox is not installed.\nIt is needed ', ...
     'for the execution of fixedranktensorembeddedfactory.m.\n', ...
     'Please download the toolbox at https://www.tensortoolbox.org/', ...
     '\nor https://gitlab.com/tensors/tensor_toolbox.']));

% Tensor order d; sizes and multilinear ranks must have one entry per mode.
d = length(tensor_size);
if d ~= length(tensor_rank)
    error(['Tensor dimensions and ranks do not match: ' ...
           'the two inputs should have the same length.']);
end

% Store sizes and ranks as row vectors. Size-vector calls used later
% (zeros(2*r), randn(r), tenrand(r)) need a row vector, and the
% element-wise product r.*(n-r) below would silently broadcast to a
% d-by-d matrix if one input were a row and the other a column.
n = reshape(tensor_size, 1, []);
r = reshape(tensor_rank, 1, []);

manifold_name = mfname(n, r, d);
M.name = @() manifold_name;

% Manifold dimension: prod(r) core entries plus r(i)*(n(i)-r(i)) per mode.
mfdim = prod(r) + sum(r.*(n-r));
M.dim = @() mfdim;
0094
0095
% Inner product of two tangent vectors eta, zeta at X (structs with a
% core-sized tensor G and a 1-by-d cell V):
%   <eta.G, zeta.G> + sum_k <C, C x_k (eta.V{k}' * zeta.V{k})>,
% where C is the core of X. The k-th term equals
% <C x_k eta.V{k}, C x_k zeta.V{k}>, but it only involves core-sized tensors.
M.inner = @iproduct;
function ip = iproduct(X, eta, zeta)
    coreX = X.X.core;
    ip = innerprod(eta.G, zeta.G);
    for k = 1:d
        gram = eta.V{k}' * zeta.V{k};
        ip = ip + innerprod(coreX, ttm(coreX, gram, k));
    end
end

% Norm induced by the inner product above.
M.norm = @(X, eta) sqrt(iproduct(X, eta, eta));

% Geodesic distance is not available for this geometry.
M.dist = @(x, y) error('fixedranktensorembeddedfactory.dist not implemented yet.');

% Heuristic scale used by solvers (e.g. for initial trust-region radii).
M.typicaldist = @() 10*mean(n)*mean(r);
0111
0112
0113
% Orthogonal projection of an ambient tensor E onto the tangent space at X.
% The result is a tangent-vector struct with
%   G    = E x_1 U1' x_2 ... x_d Ud'
%   V{k} = (I - Uk*Uk') * unfold_k(E x_{j~=k} Uj') * X.Cpinv{k}
% where Uk are the factors of X and X.Cpinv{k} is the cached pseudo-inverse
% of the mode-k unfolding of the core. Struct (tangent-format) inputs raise
% an error.
M.proj = @projection;
function Eproj = projection(X, E)
    if isstruct(E)
        error(['fixedranktensorembeddedfactory.proj only ' ...
               'implemented for ambient tensors so far.']);
    end

    factors = X.X.U;
    Eproj.G = ttm(E, factors, 't');
    Eproj.V = cell(1, d);
    for k = 1:d
        others = [1:k-1, k+1:d];
        % Contract E with every factor except the k-th one.
        partial = ttm(E, factors(others), others, 't');
        unprojected = tenmat(partial, k) * X.Cpinv{k};
        Uk = factors{k};
        % Remove the component lying in span(Uk).
        Eproj.V{k} = double(unprojected - Uk*(Uk'*unprojected));
    end
end
0148
0149
% Riemannian gradient: projection of the Euclidean gradient onto the
% tangent space at X.
M.egrad2rgrad = @egrad2rgrad;
function rgrad = egrad2rgrad(X, egrad)
    rgrad = projection(X, egrad);
end

% Riemannian Hessian along eta: projected Euclidean Hessian plus the
% correction term computed from the Euclidean gradient by curvature_term.
M.ehess2rhess = @ehess2rhess;
function Hess = ehess2rhess(X, egrad, ehess, eta)
    projected = projection(X, ehess);
    correction = curvature_term(egrad, X, eta);
    Hess = lincomb(X, 1, projected, 1, correction);
end
0162
0163
0164
0165
% Re-project a tangent-vector struct: every V{k} is made orthogonal to the
% column space of the k-th factor of X. The G field is passed through as is.
M.tangent = @tangent;
function xi = tangent(X, eta)
    xi = eta;
    factors = X.X.U;
    for k = 1:d
        Uk = factors{k};
        xi.V{k} = xi.V{k} - Uk*(Uk'*xi.V{k});
    end
end
0173
0174
0175
% Expand a tangent vector at X into a full ambient tensor:
%   eta.G x_1 U1 ... x_d Ud  +  sum_k  C x_k eta.V{k} x_{j~=k} Uj
% with C the core and Uj the factors of X.
M.tangent2ambient = @tan2amb;
function E = tan2amb(X, eta)
    factors = X.X.U;
    coreX = X.X.core;
    E = ttm(eta.G, factors);
    for k = 1:d
        others = [1:k-1, k+1:d];
        E = E + ttm(ttm(coreX, eta.V{k}, k), factors(others), others);
    end
end
0193
0194
% Retraction: map X + alpha*xi back onto the manifold (alpha defaults to 1).
% With [Uk, Vk] = Qk*Rk (economy QR), X + alpha*xi equals a block core S
% of size 2*r multiplied in mode k by [Uk, Vk], where
%   block (1st half in every mode)                 = C + alpha*G
%   block (2nd half in mode k, 1st half elsewhere) = alpha*C
% and all other blocks are zero. S x_k Rk is truncated to rank r with
% hosvd, and the resulting factors are rotated by Qk. The pseudo-inverses
% of the core unfoldings are cached in Y.Cpinv.
M.retr = @retraction;
function Y = retraction(X, xi, alpha)
    if nargin < 3
        alpha = 1.0;
    end

    Q = cell(1, d);
    R = cell(1, d);
    for k = 1:d
        [Q{k}, R{k}] = qr([X.X.U{k}, xi.V{k}], 0);
    end

    % Index ranges of the first and second halves of each mode of S.
    firstHalf = cell(1, d);
    secondHalf = cell(1, d);
    for k = 1:d
        firstHalf{k} = 1:r(k);
        secondHalf{k} = r(k)+1 : 2*r(k);
    end

    % The blocks are disjoint; they are accumulated onto zeros.
    blocks = zeros(2*r);
    blocks(firstHalf{:}) = blocks(firstHalf{:}) + double(X.X.core + alpha*xi.G);
    scaledCore = double(alpha*X.X.core);
    for k = 1:d
        sub = firstHalf;
        sub{k} = secondHalf{k};
        blocks(sub{:}) = blocks(sub{:}) + scaledCore;
    end

    S = ttm(tensor(blocks), R);
    sHosvd = hosvd(S, r);

    U = cell(1, d);
    for k = 1:d
        U{k} = Q{k} * sHosvd.U{k};
    end

    Y.X = ttensor(sHosvd.core, U);
    Y.Cpinv = cell(1, d);
    for k = 1:d
        Y.Cpinv{k} = pinv(double(tenmat(Y.X.core, k)));
    end
end
0253
% Hash of a point: 'z' followed by the MD5 hash of the vector
% [sums of each factor; sum of the core; sums of each cached Cpinv].
M.hash = @hashing;
function h = hashing(X)
    factors = X.X.U;
    sums = zeros(2*d+1, 1);
    for k = 1:d
        sums(k) = sum(factors{k}(:));
        sums(d+1+k) = sum(X.Cpinv{k}(:));
    end
    sums(d+1) = sum(X.X.core(:));
    h = ['z' hashmd5(sums)];
end
0266
0267
% Random point: orthonormal factors from QR of uniform random matrices, and
% a random core (tenrand) absorbed with the triangular QR factors. The result
% is then re-factored by a truncated HOSVD. Core-unfolding pseudo-inverses
% are cached in X.Cpinv.
M.rand = @random;
function X = random()
    factors = cell(1, d);
    triangles = cell(1, d);
    for k = 1:d
        [factors{k}, triangles{k}] = qr(rand(n(k), r(k)), 0);
    end

    Z = hosvd(ttm(tenrand(r), triangles), r);
    for k = 1:d
        factors{k} = factors{k} * Z.U{k};
    end

    X.X = ttensor(Z.core, factors);
    X.Cpinv = cell(1, d);
    for k = 1:d
        X.Cpinv{k} = pinv(double(tenmat(X.X.core, k)));
    end
end
0292
0293
% Random unit-norm tangent vector at X. Gaussian G and V{k} are drawn,
% V{k} is made tangent with M.tangent, and the result is scaled by 1/norm.
M.randvec = @randomvec;
function eta = randomvec(X)
    xi.G = tensor(randn(r));
    xi.V = cell(1, d);
    for k = 1:d
        xi.V{k} = randn(n(k), r(k));
    end

    xi = M.tangent(X, xi);
    nrm = M.norm(X, xi);

    eta.G = xi.G / nrm;
    eta.V = cellfun(@(v) v / nrm, xi.V, 'UniformOutput', false);
end
0314
0315
% Linear combination of tangent vectors at X:
%   lincomb(X, a, u)       -> a*u
%   lincomb(X, a, u, b, v) -> a*u + b*v
% Any other number of inputs raises an error.
M.lincomb = @lincomb;
function xi = lincomb(X, lambda1, eta1, lambda2, eta2)
    switch nargin
        case 3
            xi.G = lambda1*eta1.G;
            xi.V = cell(1, d);
            for k = 1:d
                xi.V{k} = lambda1*eta1.V{k};
            end
        case 5
            xi.G = lambda1*eta1.G + lambda2*eta2.G;
            xi.V = cell(1, d);
            for k = 1:d
                xi.V{k} = lambda1*eta1.V{k} + lambda2*eta2.V{k};
            end
        otherwise
            error('fixedranktensorembeddedfactory.lincomb takes 3 or 5 inputs.');
    end
end
0336
% Zero tangent vector: all-zero core-sized G and n(k)-by-r(k) zero V{k}.
M.zerovec = @zerovector;
function eta = zerovector(X)
    eta.G = tenzeros(r);
    eta.V = arrayfun(@(k) zeros(n(k), r(k)), 1:d, 'UniformOutput', false);
end
0347
0348
0349
% Vector transport by projection: the tangent vector xi at X is expanded
% to its ambient tensor (tan2amb), and that tensor is orthogonally projected
% onto the tangent space at Y (projection). This is algebraically identical
% to expanding every term of the projection separately. Here the full
% n(1)-by-...-by-n(d) ambient tensor is built only once instead of
% (d+1)^2 times.
M.transp = @transport;
function eta = transport(X, Y, xi)
    ambient = tan2amb(X, xi);
    eta = projection(Y, ambient);
end
0402
% Vectorize a tangent vector as [V{1}(:); ...; V{d}(:); G(:)], the layout
% that M.mat (normrep) reads back. Works for any tensor order d.
M.vec = @tovec;
function eta_vec = tovec(X, eta) %#ok<INUSL>
    pieces = cell(d + 1, 1);
    for k = 1:d
        pieces{k} = eta.V{k}(:);
    end
    % double() turns a Tensor Toolbox tensor into a plain array.
    pieces{d + 1} = reshape(double(eta.G), [], 1);
    eta_vec = vertcat(pieces{:});
end
0404
% Inverse of M.vec: split a column vector into n(k)-by-r(k) blocks V{k}
% (in mode order), and reshape the remaining entries into the core-sized
% tensor G.
M.mat = @normrep;
function eta = normrep(X, eta_vec)
    V = cell(1, d);
    offset = 0;
    for k = 1:d
        blockLength = n(k)*r(k);
        V{k} = reshape(eta_vec(offset + (1:blockLength)), n(k), r(k));
        offset = offset + blockLength;
    end

    eta.G = tensor(reshape(eta_vec(offset+1 : end), r));
    eta.V = V;
end

% The inner product is not the Euclidean one on the vectorized form.
M.vecmatareisometries = @() false;
0422
0423 end
0424
0425
function T = hosvd(X, r)
% HOSVD Truncated higher-order SVD of a Tensor Toolbox tensor X.
% The k-th factor holds the leading r(k) left singular vectors (via svds) of
% the mode-k unfolding of X. The core is X multiplied in every mode by the
% transposed factors. Returns a ttensor. Errors when length(r) ~= ndims(X).
    if ndims(X) ~= length(r)
        error('Dimensions of tensor and multilinear rank vector do not match.')
    end
    d = ndims(X);

    factors = cell(1, d);
    for k = 1:d
        [factors{k}, ~, ~] = svds(double(tenmat(X, k)), r(k));
    end

    T = ttensor(ttm(X, factors, 't'), factors);
end
0447
0448
0449
function eta = curvature_term(E, X, xi)
% CURVATURE_TERM Correction added to the projected Euclidean Hessian in
% ehess2rhess. E is the Euclidean gradient (ambient tensor), X the point,
% and xi the tangent direction. Returns a tangent-vector struct (G, V) whose
% V{i} are projected orthogonally to the i-th factor of X.
    coreX = X.X.core;
    factors = X.X.U;
    G = tenzeros(size(coreX));
    d = length(size(coreX));
    V = cell(1, d);

    for i = 1:d
        others = [1:i-1, i+1:d];

        % E contracted with every factor except the i-th, and its unfolding.
        EUit = ttm(E, factors(others), others, 't');
        EUit_i = tenmat(EUit, i);
        Gi = double(tenmat(xi.G, i));
        Ci = double(tenmat(coreX, i));
        Cp = X.Cpinv{i};

        G = G + ttm(EUit, xi.V{i}, i, 't') ...
              - ttm(coreX, double(xi.V{i}'*(EUit_i*Cp)), i);

        CpCp = Cp'*Cp;
        Vi = (EUit_i*Gi')*CpCp + (EUit_i*Cp)*(Ci*Gi')*CpCp;

        % Mixed terms: contract with all factors except i and m, and with
        % xi.V{m} in mode m.
        for k = 1:length(others)
            rest = others;
            rest(k) = [];
            m = others(k);
            mixed = ttm(ttm(E, factors(rest), rest, 't'), xi.V{m}, m, 't');
            Vi = Vi + tenmat(mixed, i)*Cp;
        end

        Ui = factors{i};
        V{i} = double(Vi - Ui*(Ui'*Vi));
    end

    eta.G = G;
    eta.V = V;
end
0490
function spf = mfname(n, r, d)
% MFNAME Human-readable manifold description returned by M.name(), e.g.
%   'C x U1 x U2 x U3 Tucker manifold of 100-by-50-by-50 tensors of rank 5-by-5-by-5'
%
% Inputs:
%   n - tensor sizes (one entry per mode)
%   r - multilinear ranks (one entry per mode)
%   d - tensor order
%
% The first size and the first rank are printed with a field width one
% larger than their digit count, which produces a single separating space.
% The width is computed from n(1) and r(1) themselves.

    % Decimal digit count of v plus one for the leading space.
    fieldwidth = @(v) numel(int2str(v)) + 1;

    factorNames = sprintf(' x U%d', 1:d);
    remainingModes = repmat('-by-%d', 1, d - 1);

    fmt = ['C', factorNames, ' Tucker manifold of', ...
           '%', int2str(fieldwidth(n(1))), 'd', remainingModes, ...
           ' tensors of rank', ...
           '%', int2str(fieldwidth(r(1))), 'd', remainingModes];
    spf = sprintf(fmt, n, r);
end