classdef TTeMPS_op_NN_hermite < TTeMPS_op_laplace
    % TTeMPS_op_NN_hermite  Rank-3 TT operator with one-site terms and
    % nearest-neighbour couplings, discretised on n Hermite nodes per mode.
    %
    %   A = TTeMPS_op_NN_hermite( n, d )
    %   A = TTeMPS_op_NN_hermite( n, d, lambda )      (default lambda = 0.11)
    %
    % With q the grid nodes and D = diag(q), the constructor assembles
    % (reading the stacked rank index with r_left varying fastest)
    %
    %   sum_k  Lk(mode k)  +  sum_{k=1}^{d-1}  B(mode k) * C(mode k+1)
    %
    % where the one-site terms are
    %   L1 = L + D^2/2 - lambda/3*D^3 + lambda^2/16*D^4   (first mode)
    %   L2 = L + D^2/2 - lambda/3*D^3 + lambda^2/8 *D^4   (middle modes)
    %   L3 = L + D^2/2                + lambda^2/16*D^4   (last mode)
    % and the coupling factors are B = D + lambda/8*D^2, C = lambda*D^2.
    % NOTE(review): this looks like a Henon-Heiles-type potential with L the
    % Hermite-DVR kinetic matrix -- confirm against the model being solved.

    properties
        L       % n x n matrix built from the nodes; passed to TTeMPS_op_laplace
        L1      % one-site term on the first mode
        L2      % one-site term on the middle modes
        L3      % one-site term on the last mode
        B       % left coupling factor,  D + lambda/8*D.^2
        C       % right coupling factor, lambda*D.^2

        E_L1    % eigenvalues of L1 (column vector)
        V_L1    % eigenvectors of L1
        E_L2    % eigenvalues of L2 (column vector)
        V_L2    % eigenvectors of L2
        E_L3    % eigenvalues of L3 (column vector)
        V_L3    % eigenvectors of L3
    end

    methods

        function A = update_properties( A )
            % Recompute rank, sizes and order from the current cores A.U.
            % Each core is stored as a (r_left*r_right*size_col) x size_row
            % matrix; all inner TT ranks are fixed to 3.
            A.rank = [1, 3*ones(1, length(A.U)-1), 1];
            size_col_ = cellfun( @(y) size(y,1), A.U );
            A.size_col = size_col_ ./ (A.rank(1:end-1).*A.rank(2:end));
            A.size_row = cellfun( @(y) size(y,2), A.U );
            A.order = length( A.size_row );
        end
    end

    methods( Access = public )

        function A = TTeMPS_op_NN_hermite( n, d, lambda )
            % Build the operator on d modes with n grid points per mode.
            %   n      - number of grid points per mode
            %   d      - number of modes (the core layout below assumes d >= 2)
            %   lambda - coupling strength (optional, default 0.11)
            if nargin < 3
                lambda = 0.11;
            end

            % Nodes: eigenvalues of the symmetric tridiagonal matrix with
            % off-diagonal entries sqrt(k/2), k = 1..n-1 (Gauss-Hermite nodes).
            offdiag = sqrt( 0.5*(1:n-1) );
            q = sort( eig( diag(offdiag, -1) + diag(offdiag, 1) ) );

            % L(i,i) = (4n - 1 - 2 q_i^2)/6
            % L(i,j) = (-1)^(i-j) * (2/(q_i - q_j)^2 - 1/2)   for i ~= j
            % The diagonal of the off-diagonal formula is Inf (division by
            % zero) and is overwritten immediately afterwards.
            [I, J] = ndgrid( 1:n );
            dq = bsxfun( @minus, q, q.' );
            L = (-1).^(I - J) .* ( 2 ./ dq.^2 - 0.5 );
            L(1:n+1:end) = ( 4*n - 1 - 2*q.^2 ) / 6;

            D = spdiags( q, 0, n, n );      % sparse diag(q)

            A = A@TTeMPS_op_laplace( L, d );

            A.L  = L;
            A.L1 = L + 0.5*D.^2 - lambda/3*D.^3 + lambda^2/16*D.^4;
            A.L2 = L + 0.5*D.^2 - lambda/3*D.^3 + lambda^2/8*D.^4;
            A.L3 = L + 0.5*D.^2 + lambda^2/16*D.^4;

            % Eigen-decompositions reused by the preconditioners.
            [A.V_L1, A.E_L1] = eig( full(A.L1) );
            A.E_L1 = diag( A.E_L1 );
            [A.V_L2, A.E_L2] = eig( full(A.L2) );
            A.E_L2 = diag( A.E_L2 );
            [A.V_L3, A.E_L3] = eig( full(A.L3) );
            A.E_L3 = diag( A.E_L3 );

            A.B = D + lambda/8*D.^2;
            A.C = lambda * D.^2;
            M = speye( n, n );

            % Selectors for the stacked rank index of the boundary cores (3 slots).
            e1 = sparse( 1, 1, 1, 3, 1 );
            e2 = sparse( 2, 1, 1, 3, 1 );
            e3 = sparse( 3, 1, 1, 3, 1 );

            % Selectors for the 3x3 middle core (9 slots); with r_left fastest:
            % slot 1 = (1,1), 2 = (2,1), 3 = (3,1), 6 = (3,2), 9 = (3,3).
            l_mid = sparse( 3, 1, 1, 9, 1 );
            b_mid = sparse( 6, 1, 1, 9, 1 );
            m_mid = sparse( [1;9], [1;1], [1;1], 9, 1 );
            c_mid = sparse( 2, 1, 1, 9, 1 );

            A.U = cell( 1, d );
            A.U{1} = kron( A.L1, e1 ) + kron( A.B, e2 ) + kron( M, e3 );
            A_mid  = kron( A.L2, l_mid ) + kron( A.B, b_mid ) ...
                   + kron( M, m_mid ) + kron( A.C, c_mid );
            A.U(2:d-1) = { A_mid };
            A.U{d} = kron( M, e1 ) + kron( A.C, e2 ) + kron( A.L3, e3 );

            A = update_properties( A );
        end

        function expB = constr_precond_inner( A, X, mu )
            % Factors of an exponential-sum preconditioner for the local
            % problem at core mu, using only the one-site terms L1/L2/L3.
            %   X  - TT tensor (X.U, X.rank are read)
            %   mu - index of the core being updated
            % Returns a 3 x k cell (k = 3). Column i holds the i-th term:
            %   row 1: rank(mu) x rank(mu) factor (omega(i) folded in),
            %   row 2: factor for mode mu,
            %   row 3: rank(mu+1) x rank(mu+1) factor.
            % Each factor is exp(-alpha(i)*S) of the corresponding symmetric
            % part S. omega/alpha come from load_coefficients(lmax/lmin),
            % rescaled by 1/lmin -- presumably an exponential-sum fit of 1/x.
            B1 = left_interface( A, X, mu );
            B3 = right_interface( A, X, mu );
            [~, E_mu, V_mu] = site_term( A, mu );

            [V1, e1] = eig( B1 );
            e1 = diag( e1 );
            [V3, e3] = eig( B3 );
            e3 = diag( e3 );

            % Spectral bounds of the sum of the three (commuting) parts.
            lmin = min(e1) + min(E_mu) + min(e3);
            lmax = max(e1) + max(E_mu) + max(e3);

            [omega, alpha] = load_coefficients( lmax/lmin );
            omega = omega/lmin;
            alpha = alpha/lmin;

            k = 3;
            expB = cell( 3, k );
            for i = 1:k
                expB{1,i} = omega(i) * exp_of_eig( V1, e1, alpha(i) );
                expB{2,i} = exp_of_eig( V_mu, E_mu, alpha(i) );
                expB{3,i} = exp_of_eig( V3, e3, alpha(i) );
            end
        end

        function expB = constr_precond_outer( A, X, mu1, mu2 )
            % Two-core analogue of constr_precond_inner for cores mu1 and mu2
            % (presumably mu2 = mu1 + 1 -- confirm with callers).
            % Returns a 4 x k cell (k = 3). Column i holds the i-th term:
            %   row 1: rank(mu1) x rank(mu1) factor (omega(i) folded in),
            %   row 2: factor for mode mu1,
            %   row 3: factor for mode mu2,
            %   row 4: rank(mu2+1) x rank(mu2+1) factor.
            B1 = left_interface( A, X, mu1 );
            B3 = right_interface( A, X, mu2 );
            [~, E_a, V_a] = site_term( A, mu1 );
            [~, E_b, V_b] = site_term( A, mu2 );

            [V1, e1] = eig( B1 );
            e1 = diag( e1 );
            [V3, e3] = eig( B3 );
            e3 = diag( e3 );

            % The per-mode spectra are chosen exactly as for the factors below,
            % so the bounds stay consistent also when d = 2 (mu1 = 1, mu2 = d).
            lmin = min(e1) + min(E_a) + min(E_b) + min(e3);
            lmax = max(e1) + max(E_a) + max(E_b) + max(e3);

            [omega, alpha] = load_coefficients( lmax/lmin );
            omega = omega/lmin;
            alpha = alpha/lmin;

            k = 3;
            expB = cell( 4, k );
            for i = 1:k
                expB{1,i} = omega(i) * exp_of_eig( V1, e1, alpha(i) );
                expB{2,i} = exp_of_eig( V_a, E_a, alpha(i) );
                expB{3,i} = exp_of_eig( V_b, E_b, alpha(i) );
                expB{4,i} = exp_of_eig( V3, e3, alpha(i) );
            end
        end

        function P = constr_precond( A, k )
            % Global preconditioner P = sum_i omega(i) * kron_j expm(-alpha(i)*Lj),
            % a rank-k sum of rank-1 TT operators (one-site terms only).
            %   k - number of exponential terms, 3 or 7 (optional, default 3)
            if nargin < 2
                k = 3;
            end

            d = A.order;

            % Spectral bounds of sum_j Lj from the stored eigenvalues.
            lmin = min(A.E_L1) + (d-2)*min(A.E_L2) + min(A.E_L3);
            lmax = max(A.E_L1) + (d-2)*max(A.E_L2) + max(A.E_L3);
            R = lmax/lmin;

            if k == 3
                [omega, alpha] = load_coefficients( R );
            elseif k == 7
                % NOTE(review): these coefficients do not depend on R; presumably
                % tuned for one fixed condition number -- confirm before reuse.
                omega = [0.0133615547183825570028305575534521842940 0.0429728469424360175410925952177443321034 0.1143029399081515586560726591147663100401,...
                         0.2838881266934189482611071431161775535656 0.6622322841999484042811198458711174907876 1.4847175320092703810050463464342840325116,...
                         3.4859753729916252771962870138366952232900];
                alpha = [0.0050213411684266507485648978019454613531 0.0312546410994290844202411500801774835168 0.1045970270084145620410366606112262388706,...
                         0.2920522758702768403556507270657505159761 0.7407504784499061527671195936939341208927 1.7609744335543204401530945069076494746696,...
                         4.0759036969145123916954953635638503328664];
            else
                error('Unknown rank specified. Choose either k=3 or k=7');
            end

            omega = omega/lmin;
            alpha = alpha/lmin;

            % 4-D core (1 x size_row x size_col x 1) holding expm(-a*Lj) at mode j.
            core = @(Lj, j, a) reshape( expm( -a*Lj ), [1, A.size_row(j), A.size_col(j), 1] );
            % One rank-1 TT operator: first, d-2 middle, and last cores.
            term = @(a) TTeMPS_op( [ {core(A.L1, 1, a)}, ...
                                     repmat( {core(A.L2, 2, a)}, 1, d-2 ), ...
                                     {core(A.L3, d, a)} ] );

            P = omega(1) * term( alpha(1) );
            for i = 2:k
                P = P + omega(i) * term( alpha(i) );
            end
        end

    end

    methods( Access = private )

        function [Li, Ei, Vi] = site_term( A, i )
            % One-site term acting on mode i and its eigenpairs:
            % L1 on the first mode, L3 on the last mode, L2 otherwise.
            if i == 1
                Li = A.L1;  Ei = A.E_L1;  Vi = A.V_L1;
            elseif i == A.order
                Li = A.L3;  Ei = A.E_L3;  Vi = A.V_L3;
            else
                Li = A.L2;  Ei = A.E_L2;  Vi = A.V_L2;
            end
        end

        function Bl = left_interface( A, X, mu )
            % rank(mu) x rank(mu) matrix: the one-site terms on modes
            % 1..mu-1 contracted with X from the left (innerprod 'LR').
            % Zero when mu == 1.
            Bl = zeros( X.rank(mu) );
            for i = 1:mu-1
                Li = site_term( A, i );
                tmp = X;
                Xi = Li * matricize( X.U{i}, 2 );
                tmp.U{i} = tensorize( Xi, 2, [X.rank(i), size(Li,1), X.rank(i+1)] );
                Bl = Bl + innerprod( X, tmp, 'LR', mu-1 );
            end
            % Symmetric in exact arithmetic; remove rounding asymmetry so eig
            % returns orthonormal eigenvectors.
            Bl = ( Bl + Bl' ) / 2;
        end

        function Br = right_interface( A, X, mu )
            % rank(mu+1) x rank(mu+1) matrix: the one-site terms on modes
            % mu+1..d contracted with X from the right (innerprod 'RL').
            % Zero when mu == d.
            Br = zeros( X.rank(mu+1) );
            for i = mu+1:A.order
                Li = site_term( A, i );     % L3 on the last mode, L2 otherwise
                tmp = X;
                Xi = Li * matricize( X.U{i}, 2 );
                tmp.U{i} = tensorize( Xi, 2, [X.rank(i), size(Li,1), X.rank(i+1)] );
                Br = Br + innerprod( X, tmp, 'RL', mu+1 );
            end
            Br = ( Br + Br' ) / 2;
        end

    end

end

function E = exp_of_eig( V, e, a )
% Return V*diag(exp(-a*e))*V', i.e. expm(-a*S) for a symmetric S whose
% orthonormal eigenvectors are the columns of V and eigenvalues are e.
E = V * diag( exp( -a*e ) ) * V';
end