classdef TTeMPS_tangent_orth
    % TTeMPS_tangent_orth  Tangent vector of a TT/MPS tensor stored in the
    % "orthogonal" representation.
    %
    % The represented tensor is assembled by tangent_to_TTeMPS from the
    % block cores [dU{1} U{1}], [V{i} 0; dU{i} U{i}] and [V{d}; dU{d}], i.e.
    %   sum_k  U{1} ... U{k-1} dU{k} V{k+1} ... V{d}.
    % The constructor removes from every dU{i}, i < d, its component in the
    % column span of unfold(U{i}, 'left') (presumably the gauge condition
    % for left-orthogonal U cores -- confirm against the TTeMPS docs).

    properties (SetAccess = public, GetAccess = public)
        dU      % cell(1, d) of variation cores, r(i) x n(i) x r(i+1)
        U       % cores of xL (expected left-orthogonal -- TODO confirm)
        V       % cores of xR (expected right-orthogonal -- TODO confirm)
        rank    % TT rank vector copied from xL
        order   % number of cores d, copied from xL
        size    % mode sizes copied from xL
    end

    methods (Access = public)

        function Y = TTeMPS_tangent_orth(xL, xR, Z, ind, storedb, key)
            % Construct a tangent vector at the point represented by xL/xR.
            %
            %   TTeMPS_tangent_orth()            empty object (needed by
            %                                    MATLAB for preallocation)
            %   TTeMPS_tangent_orth(xL)          zero tangent vector; xR is
            %                                    obtained as orthogonalize(xL, 1)
            %   TTeMPS_tangent_orth(xL, xR)      random projected tangent
            %                                    vector, scaled so that the
            %                                    dU cores have total norm 1
            %   TTeMPS_tangent_orth(xL, xR, Z)   projection of Z, where Z is a
            %                                    TTeMPS, a TTeMPS_tangent_orth
            %                                    or a full array
            %   TTeMPS_tangent_orth(xL, xR, Z, ind [, storedb, key])
            %                                    projection of the sparse
            %                                    tensor with values Z at the
            %                                    indices ind (computed by the
            %                                    TTeMPS_tangent_orth_omega MEX
            %                                    routine; ind is transposed
            %                                    before the call)
            %
            % If storedb is given, the raw MEX output is stored in the field
            % inner_dU_ of storedb.getWithShared(key) and written back with
            % setWithShared (presumably a Manopt StoreDB -- confirm).

            if nargin == 0
                % Default object: all properties keep their empty defaults.
                return
            end

            d = xL.order;
            r = xL.rank;
            n = xL.size;

            if nargin == 1
                xR = orthogonalize(xL, 1);
            end

            Y.order = d;
            Y.rank = r;
            Y.size = n;
            Y.U = xL.U;
            Y.V = xR.U;

            if nargin == 1
                % Zero tangent vector with cores shaped like U.
                Y.dU = cell(1, d);
                for k = 1:d
                    Y.dU{k} = zeros(size(Y.U{k}));
                end

            elseif nargin == 2
                % Random tangent vector. An explicit loop is kept (rather
                % than cellfun) so the randn draws happen in core order.
                Y.dU = cell(1, d);
                for k = 1:d
                    Y.dU{k} = randn(size(Y.U{k}));
                end
                Y = TTeMPS_tangent_orth(xL, xR, Y);

                % Give every non-negligible core norm 1/sqrt(#such cores).
                core_norms = zeros(1, d);
                for k = 1:d
                    core_norms(k) = norm(Y.dU{k}(:));
                end
                active = find(core_norms > 1e-14);
                scale = sqrt(numel(active));
                for k = active
                    Y.dU{k} = Y.dU{k} / (scale * core_norms(k));
                end

            else
                % Sampled mode is selected by passing the index argument.
                sampled = nargin >= 4;

                if isa(Z, 'TTeMPS')
                    % Contract Z with the left/right interface matrices.
                    right = innerprod(xR, Z, 'RL', 2, true);
                    left = innerprod(xL, Z, 'LR', d - 1, true);

                    M = cell(1, d);
                    M{1} = tensorprod_ttemps(Z.U{1}, right{2}, 3);
                    for i = 2:d - 1
                        partial = tensorprod_ttemps(Z.U{i}, left{i - 1}, 1);
                        M{i} = tensorprod_ttemps(partial, right{i + 1}, 3);
                    end
                    M{d} = tensorprod_ttemps(Z.U{d}, left{d - 1}, 1);

                    for i = 1:d - 1
                        M{i} = unfold(M{i}, 'left');
                    end
                    Y.dU = TTeMPS_tangent_orth.fold_gauge_projected(Y.U, M, r, n);

                elseif isa(Z, 'TTeMPS_tangent_orth')
                    % Re-project another tangent vector via its TT form.
                    Y = TTeMPS_tangent_orth(xL, xR, tangent_to_TTeMPS(Z));

                elseif ~sampled
                    % Z is a full array.
                    M = cell(1, d);
                    M{d} = Z(:);

                    % Right-to-left: contract trailing modes with the V cores.
                    % After step i, M{i} is prod(n(1:i)) x r(i+1).
                    acc = Z;
                    for i = d - 1:-1:1
                        Vt = transpose(unfold(Y.V{i + 1}, 'right'));
                        acc = reshape(acc, [prod(n(1:i)), n(i + 1) * r(i + 2)]) * Vt;
                        M{i} = acc;
                    end

                    % Left-to-right: contract leading modes with the U cores,
                    % leaving M{k} as an (r(k)*n(k)) x r(k+1) matrix.
                    for k = 2:d
                        for i = 1:k - 1
                            Ui = unfold(Y.U{i}, 'left');
                            M{k} = Ui' * reshape(M{k}, ...
                                [r(i) * n(i), prod(n(i + 1:k)) * r(k + 1)]);
                        end
                        M{k} = reshape(M{k}, [r(k) * n(k), r(k + 1)]);
                    end

                    Y.dU = TTeMPS_tangent_orth.fold_gauge_projected(Y.U, M, r, n);

                else
                    % Z holds the sampled values at the indices ind.
                    vals = Z;
                    CU = cell(1, d);
                    CV = cell(1, d);
                    for i = 1:d
                        CU{i} = permute(Y.U{i}, [1 3 2]);
                        CV{i} = permute(Y.V{i}, [1 3 2]);
                    end

                    res = TTeMPS_tangent_orth.TTeMPS_tangent_orth_omega(n, r, CU, CV, ind.', vals);

                    if nargin >= 5
                        % Cache the raw MEX result for later reuse by the caller.
                        store = storedb.getWithShared(key);
                        store.inner_dU_ = res;
                        storedb.setWithShared(store, key);
                    end

                    % MEX output core i is laid out as r(i) x r(i+1) x n(i).
                    M = cell(1, d);
                    for i = 1:d
                        res{i} = reshape(res{i}, [r(i), r(i + 1), n(i)]);
                        M{i} = unfold(permute(res{i}, [1 3 2]), 'left');
                    end
                    Y.dU = TTeMPS_tangent_orth.fold_gauge_projected(Y.U, M, r, n);
                end
            end
        end

        function Zfull = full(Z)
            % Full (dense) array of the represented tangent tensor.
            Zfull = full(tangent_to_TTeMPS(Z));
        end

        function res = plus(X, Y)
            % Sum of two tangent vectors; adds the dU cores, keeps X's U/V.
            res = X;
            res.dU = cellfun(@plus, X.dU, Y.dU, 'UniformOutput', false);
        end

        function X = minus(X, Y)
            % Difference of two tangent vectors; subtracts the dU cores.
            X.dU = cellfun(@minus, X.dU, Y.dU, 'UniformOutput', false);
        end

        function X = mtimes(a, X)
            % Scale a tangent vector by a; accepts both a*X and X*a.
            if isa(a, 'TTeMPS_tangent_orth')
                % Called as X*a: swap so that a is the scalar factor.
                [a, X] = deal(X, a);
            end
            X.dU = cellfun(@(x) a * x, X.dU, 'UniformOutput', false);
        end

        function X = uminus(X)
            % Negation, implemented as scaling by -1.
            X = mtimes(-1, X);
        end

        function Xnew = tangentAdd(Z, t, retract)
            % Point U{1}...U{d} + t * Z as a TTeMPS.
            %
            % The result has doubled ranks. If retract is true (default
            % false), it is truncated back to the ranks of the U cores.
            if ~exist('retract', 'var')
                retract = false;
            end

            d = length(Z.dU);

            % Ranks of the base point, with boundary ranks fixed to 1.
            r = ones(1, d + 1);
            for i = 2:d
                r(i) = size(Z.U{i}, 1);
            end

            % Scaling every dU by t and adding U{d} into the last variation
            % core gives exactly the block cores
            % [t*dU{1} U{1}], [V 0; t*dU U], [V{d}; U{d} + t*dU{d}].
            W = mtimes(t, Z);
            W.dU{d} = Z.U{d} + W.dU{d};
            Xnew = tangent_to_TTeMPS(W);

            if retract
                Xnew = truncate(Xnew, r);
            end
        end

        function res = innerprod(Z1, Z2)
            % Euclidean inner product of the dU cores, summed over all cores.
            % This presumably equals the inner product of the represented
            % tensors when the gauge conditions hold -- confirm.
            res = 0;
            for i = 1:length(Z1.dU)
                res = res + Z1.dU{i}(:)' * Z2.dU{i}(:);
            end
        end

        function n = norm(Z)
            % Frobenius norm, computed from the assembled TTeMPS form.
            n = norm(tangent_to_TTeMPS(Z));
        end

        function res = at_Omega(Z, ind)
            % Entries of the represented tensor at the indices ind.
            Xnew = tangent_to_TTeMPS(Z);
            res = Xnew(ind);
        end

        function res = tangent_to_TTeMPS(Z)
            % TTeMPS with block cores [dU{1} U{1}], [V{i} 0; dU{i} U{i}],
            % [V{d}; dU{d}], representing
            % sum_k U{1}..U{k-1} dU{k} V{k+1}..V{d}.
            d = length(Z.dU);
            C = cell(1, d);

            C{1} = cat(3, Z.dU{1}, Z.U{1});
            for i = 2:d - 1
                zeroblock = zeros(size(Z.U{i}));
                top = cat(3, Z.V{i}, zeroblock);
                bottom = cat(3, Z.dU{i}, Z.U{i});
                C{i} = cat(1, top, bottom);
            end
            C{d} = cat(1, Z.V{d}, Z.dU{d});

            res = TTeMPS(C);
        end

        function z = vectorize_tangent(Z)
            % Column vector of all dU entries, core by core, column-major.
            z = cellfun(@(x) x(:), Z.dU, 'UniformOutput', false);
            z = cell2mat(z(:));
        end

        function Z = fill_with_vectorized(Z, z)
            % Inverse of vectorize_tangent: refill the dU cores from z,
            % keeping each core's current size.
            offset = 1;
            for i = 1:length(Z.dU)
                s = size(Z.dU{i});
                len = prod(s);
                Z.dU{i} = reshape(z(offset:offset + len - 1), s);
                offset = offset + len;
            end
        end

        function xi = tangent_orth_to_tangent(Z)
            % Convert to a TTeMPS_tangent using gauge_matrices of TTeMPS(U).
            % Cores 1..d-1 are multiplied in mode 3 by inv(G{i}'); the last
            % core is copied unchanged.
            d = Z.order;
            xL = TTeMPS(Z.U);
            [xL, ~, G] = gauge_matrices(xL);

            xi = TTeMPS_tangent(xL);
            for ii = 1:d - 1
                xi.dU{ii} = tensorprod_ttemps(Z.dU{ii}, inv(G{ii}'), 3);
            end
            xi.dU{d} = Z.dU{d};
        end

    end

    methods (Static, Access = private)

        x = TTeMPS_tangent_orth_omega(n, r, CU, CV, ind, vals);
        x = TTeMPS_tangent_orth_omega_openmp(n, r, CU, CV, ind, vals);

    end

    methods (Static, Access = private)

        function dU = fold_gauge_projected(U, M, r, n)
            % Fold left-unfolded components into tangent cores.
            %
            % For i < d, M{i} ((r(i)*n(i)) x r(i+1)) loses its component in
            % the column span of unfold(U{i}, 'left'), computed as
            % M - Ui*(Ui'*M), which assumes orthonormal columns. It is then
            % reshaped to r(i) x n(i) x r(i+1). M{d} is only reshaped.
            d = numel(M);
            dU = cell(1, d);
            for i = 1:d - 1
                Ui = unfold(U{i}, 'left');
                dU{i} = reshape(M{i} - Ui * (Ui' * M{i}), [r(i), n(i), r(i + 1)]);
            end
            dU{d} = reshape(M{d}, [r(d), n(d), r(d + 1)]);
        end

    end

end