classdef TTeMPS_tangent
    %TTEMPS_TANGENT Tangent vector at a point X of a TT/MPS tensor manifold.
    %   A tangent vector is stored only through its variation cores dU; the
    %   base point X (a TT/MPS tensor providing order, rank, size and the
    %   cores U) is passed explicitly to every method that needs it.
    %   Y.dU{i} has the same size as X.U{i}.
    %
    %   Dense form (see TANGENT_TO_TTEMPS): the tangent vector equals the
    %   TT/MPS tensor with block cores
    %       [dU{1} U{1}],  [U{i} 0; dU{i} U{i}] (1<i<d),  [U{d}; dU{d}].

    properties( SetAccess = public, GetAccess = public )
        % Cell array (1-by-d) of variation cores, one per core of X.
        dU
    end

    methods( Access = public )

        function Y = TTeMPS_tangent(X, Z, ind, vec)
            %TTEMPS_TANGENT Construct a tangent vector at the point X.
            %   Y = TTEMPS_TANGENT(X) returns the zero tangent vector at X:
            %   every Y.dU{i} is zeros(size(X.U{i})).
            %
            %   Y = TTEMPS_TANGENT(X, Z) projects Z onto the tangent space at
            %   X. Z may be a TTeMPS tensor, a TTeMPS_tangent (first expanded
            %   with TANGENT_TO_TTEMPS), or a full array that is reshaped
            %   according to X.size.
            %
            %   Y = TTEMPS_TANGENT(X, Z, ind) projects sampled data: Z holds
            %   values at the indices ind, and the contraction is delegated to
            %   TTEMPS_TANGENT_OMEGA (ind is passed transposed).
            %   NOTE(review): expected layout of ind - confirm against
            %   TTEMPS_TANGENT_OMEGA.
            %
            %   The fourth argument vec is accepted for backward compatibility
            %   but is not used.
            %
            %   For i < d the cores are multiplied by pinv of the right Gram
            %   matrix of cores i+1..d and then projected onto the orthogonal
            %   complement of the range of the left unfolding of X.U{i}.
            %   NOTE(review): the TTeMPS and sampled branches use U*U' as that
            %   projector, which presumes X is left-orthogonalized - confirm.

            if nargin == 1
                % Zero tangent vector: one zero core per core of X.
                Y.dU = cellfun( @(x) zeros(size(x)), X.U, 'UniformOutput', false );
                return
            end

            if isa( Z, 'TTeMPS_tangent' )
                % Expand Z into a TTeMPS and project that instead. This runs
                % before the Gram matrices below are formed because the
                % recursive call computes them itself.
                Znew = tangent_to_TTeMPS( Z, X );
                Y = TTeMPS_tangent( X, Znew );
                return
            end

            % Point samples are supplied exactly when an index set is passed.
            sampled = nargin >= 3;

            d = X.order;
            r = X.rank;
            n = X.size;

            % invXtX{i}: pseudo-inverse (tolerance 1e-8) of the right Gram
            % matrix of cores i..d, accumulated from the right. Only
            % entries 2..d are filled.
            invXtX = cell(1,d);
            tmp = conj(unfold( X.U{d}, 'right')) * unfold( X.U{d}, 'right').';
            invXtX{d} = pinv(tmp,1e-8);
            for i = d-1:-1:2
                tmp = tensorprod_ttemps( X.U{i}, tmp', 3);
                tmp = conj(unfold( tmp, 'right')) * unfold( X.U{i}, 'right').';
                invXtX{i} = pinv(tmp,1e-8);
            end

            if isa(Z, 'TTeMPS')

                Y.dU = cell(1, d);

                for i = 1:d
                    Y.dU{i} = contract( X, Z, i);
                end

                for i = 1:d-1
                    Y.dU{i} = unfold( Y.dU{i}, 'left' );
                    Y.dU{i} = Y.dU{i} * invXtX{i+1};

                    % Remove the component lying in the range of the left
                    % unfolding of X.U{i}.
                    U = unfold(X.U{i},'left');
                    Y.dU{i} = Y.dU{i} - U * ( U' * Y.dU{i});
                    Y.dU{i} = reshape( Y.dU{i}, [r(i), n(i), r(i+1)] );
                end

            elseif ~sampled

                ZZ = cell(1,d);

                % Contract the full array with the right parts of X, from
                % core d down to core 2. ZZ{i} (i < d) holds the partial
                % contraction times invXtX{i+1}; ZZ{d} is Z itself.
                ZZ{d} = Z(:);
                for i = d-1:-1:1
                    zz = reshape( Z, [prod(n(1:i)), n(i+1)*r(i+2)] );
                    xx = transpose( unfold( X.U{i+1}, 'right') );
                    Z = zz*xx;
                    ZZ{i} = Z * invXtX{i+1};
                end

                % Contract each ZZ{k} (k >= 2) with the left parts
                % X.U{1..k-1}.
                for k = 2:d
                    for i = 1:k-1
                        Z_i = reshape( ZZ{k}, [r(i)*n(i), prod(n(i+1:k))*r(k+1)] );
                        X_i = unfold( X.U{i}, 'left');
                        Z = X_i' * Z_i;
                        ZZ{k} = Z;
                    end
                    ZZ{k} = reshape( ZZ{k}, [r(k)*n(k), r(k+1)] );
                end

                Y.dU = cell(1,d);

                for i = 1:d-1
                    % Project onto the orthogonal complement of the range
                    % of the left unfolding of X.U{i}.
                    U = orth(unfold(X.U{i},'left'));
                    ZZ{i} = ZZ{i} - U * ( U' * ZZ{i});
                    Y.dU{i} = reshape( ZZ{i}, [r(i), n(i), r(i+1)] );
                end
                Y.dU{d} = reshape( ZZ{d}, [r(d), n(d), r(d+1)] );

            else

                vals = Z;
                C = cell(1,d);
                Y.dU = cell(1,d);
                for i=1:d
                    % TTEMPS_TANGENT_OMEGA receives the cores with modes 2
                    % and 3 swapped.
                    C{i} = permute( X.U{i}, [1 3 2]);
                end
                res = TTeMPS_tangent_omega( n, r, C, ind.', vals);

                for i = 1:d
                    res{i} = reshape( res{i}, [r(i), r(i+1), n(i)] );
                    Y.dU{i} = unfold( permute( res{i}, [1 3 2]), 'left');
                end

                for i=1:d-1
                    Y.dU{i} = Y.dU{i} * invXtX{i+1};
                end

                for i=1:d-1
                    % Remove the component lying in the range of the left
                    % unfolding of X.U{i}.
                    U = unfold(X.U{i},'left');
                    Y.dU{i} = Y.dU{i} - U * ( U' * Y.dU{i});
                    Y.dU{i} = reshape( Y.dU{i}, [r(i), n(i), r(i+1)] );
                end
                Y.dU{d} = reshape( Y.dU{d}, [r(d), n(d), r(d+1)] );

            end

        end

        function Zfull = full( Z, X )
            %FULL Dense array of the tangent vector Z at the point X.
            %   ZFULL = FULL(Z, X) expands Z with TANGENT_TO_TTEMPS and returns
            %   the full array of that TT/MPS tensor.
            Zfull = full( tangent_to_TTeMPS( Z, X ) );
        end

        function res = plus( X, Y )
            %PLUS Sum of two tangent vectors at the same point (core-wise).
            res = X;
            res.dU = cellfun(@plus, X.dU, Y.dU, 'UniformOutput', false);
        end

        function X = minus( X, Y )
            %MINUS Difference X - Y of two tangent vectors at the same point
            %   (core-wise).
            X.dU = cellfun(@minus, X.dU, Y.dU, 'UniformOutput', false);
        end

        function X = mtimes( a, X )
            %MTIMES Scale a tangent vector by a scalar.
            %   Z = MTIMES(a, X) and Z = MTIMES(X, a) both multiply every
            %   core of the tangent vector X by a.
            if isa( a, 'TTeMPS_tangent' ) && ~isa( X, 'TTeMPS_tangent' )
                % Called as X*a: put the scalar first.
                [a, X] = deal( X, a );
            end
            X.dU = cellfun(@(x) a*x, X.dU, 'UniformOutput', false);
        end

        function X = uminus( X )
            %UMINUS Negate a tangent vector (multiply by -1).
            X = mtimes( -1, X );
        end

        function Xnew = tangentAdd( Z, t, X, retract )
            %TANGENTADD Step from X along the tangent vector Z.
            %   XNEW = TANGENTADD(Z, t, X) returns the TT/MPS tensor X + t*Z,
            %   with Z expanded as in TANGENT_TO_TTEMPS. The interior TT ranks
            %   of the result are twice those of X.
            %
            %   XNEW = TANGENTADD(Z, t, X, true) additionally truncates the
            %   result back to the ranks X.rank.
            if nargin < 4
                retract = false;
            end

            d = X.order;
            C = cell(1, d);

            C{1} = cat( 3, t*Z.dU{1}, X.U{1} );
            for i = 2:d-1
                zeroblock = zeros( size(X.U{i}) );
                tmp1 = cat( 3, X.U{i}, zeroblock );
                tmp2 = cat( 3, t*Z.dU{i}, X.U{i} );
                C{i} = cat( 1, tmp1, tmp2 );
            end
            % Unlike TANGENT_TO_TTEMPS, the last core adds X.U{d} so that
            % the product of the block cores contains X itself as well.
            C{d} = cat( 1, X.U{d}, t*Z.dU{d} + X.U{d} );

            Xnew = TTeMPS( C );

            if retract
                Xnew = truncate( Xnew, X.rank );
            end
        end

        function res = innerprod( Z1, Z2, X )
            %INNERPROD Inner product of two tangent vectors at X.
            %   Both are expanded with TANGENT_TO_TTEMPS and the TTeMPS
            %   inner product of the expansions is returned.
            X1 = tangent_to_TTeMPS( Z1, X );
            X2 = tangent_to_TTeMPS( Z2, X );

            res = innerprod( X1, X2 );
        end

        function res = at_Omega( Z, ind, X )
            %AT_OMEGA Entries of the tangent vector Z at the indices ind.
            %   Evaluated by indexing the TTeMPS expansion of Z with ind.
            Xnew = tangent_to_TTeMPS( Z, X );
            res = Xnew(ind);
        end

        function res = tangent_to_TTeMPS( Z, X )
            %TANGENT_TO_TTEMPS TT/MPS representation of the tangent vector Z.
            %   Builds the block cores
            %       [dU{1} U{1}],  [U{i} 0; dU{i} U{i}] (1<i<d),  [U{d}; dU{d}]
            %   whose product is the tangent vector. The interior ranks are
            %   twice those of X.
            d = X.order;
            C = cell(1, d);

            C{1} = cat( 3, Z.dU{1}, X.U{1} );
            for i = 2:d-1
                zeroblock = zeros( size(X.U{i}) );
                tmp1 = cat( 3, X.U{i}, zeroblock );
                tmp2 = cat( 3, Z.dU{i}, X.U{i} );
                C{i} = cat( 1, tmp1, tmp2 );
            end
            C{d} = cat( 1, X.U{d}, Z.dU{d});

            res = TTeMPS( C );
        end
    end

end