classdef TTeMPS_op_NN < TTeMPS_op_laplace
% TTEMPS_OP_NN  Rank-3 TT operator assembled from 1D matrices L1, L2, L3, B, C.
%
%   A = TTeMPS_op_NN( n, d ) discretizes on an n-point uniform grid over
%   [-10, 2] and assembles a d-core TT operator of TT rank 3 (see the
%   constructor for the exact core layout).
%   A = TTeMPS_op_NN( n, d, lambda ) overrides the coupling parameter
%   lambda (default 0.11).
%
%   Convention used throughout: L1 acts on the first core, L3 on the last
%   core and L2 on every core in between.
%
%   NOTE(review): the potential terms suggest a coupled anharmonic
%   oscillator model -- confirm against the model reference.

    properties
        L       % 1D matrix -(second difference)/h^2 on the uniform grid
        L1      % 1D matrix acting on the first core: L + potential terms
        L2      % 1D matrix acting on the middle cores: L + potential terms
        L3      % 1D matrix acting on the last core: L + potential terms
        B       % coupling matrix D + lambda/8*D.^2 (D = diag(grid))
        C       % coupling matrix lambda*D.^2

        E_L1    % eigenvalues of full(L1), column vector
        V_L1    % eigenvectors of full(L1)
        E_L2    % eigenvalues of full(L2), column vector
        V_L2    % eigenvectors of full(L2)
        E_L3    % eigenvalues of full(L3), column vector
        V_L3    % eigenvectors of full(L3)
    end

    methods

        function A = update_properties( A )
        % UPDATE_PROPERTIES  Recompute the size bookkeeping from the cores A.U.
        %
        %   Sets A.rank = [1, 3, ..., 3, 1], A.size_row to the number of
        %   columns of each core, A.size_col to the number of rows of each
        %   core divided by the product of its two adjacent ranks, and
        %   A.order to the number of cores.
            A.rank = [1, 3*ones(1, length(A.U)-1), 1];
            size_col_ = cellfun( @(y) size(y,1), A.U );
            A.size_col = size_col_ ./ (A.rank(1:end-1).*A.rank(2:end));
            A.size_row = cellfun( @(y) size(y,2), A.U );
            A.order = length( A.size_row );
        end
    end


    methods( Access = public )

        function A = TTeMPS_op_NN( n, d, lambda )
        % TTEMPS_OP_NN  Construct the d-core operator on an n-point grid.
        %
        %   Inputs:
        %     n      - number of grid points on [-10, 2]
        %     d      - number of TT cores
        %     lambda - (optional) coupling parameter, default 0.11
        %
        %   Core layout (M = speye(n), e1/e2/e3 = unit vectors of length 3):
        %     U{1}       = kron(L1,e1) + kron(B,e2) + kron(M,e3)
        %     U{2..d-1}  = L2, B, M (two positions) and C placed at fixed
        %                  entries of the 3x3 rank block (column-major
        %                  indices 3, 6, {1,9}, 2 respectively)
        %     U{d}       = kron(M,e1) + kron(C,e2) + kron(L3,e3)
        %   The eigen-decompositions of L1, L2, L3 are cached for the
        %   preconditioners below.
            if nargin < 3 || isempty( lambda )
                lambda = 0.11;
            end

            one = ones( n, 1 );
            q = linspace( -10, 2, n )';
            h = abs( q(2) - q(1) );
            L = -spdiags( [one, -2*one, one], [-1 0 1], n, n ) / (h^2);

            D = spdiags( q, 0, n, n );

            A = A@TTeMPS_op_laplace( L, d );

            A.L = L;

            % D is diagonal, so D.^k is diag(q.^k).
            A.L1 = L + 0.5*D.^2 - lambda/3*D.^3 + lambda^2/16*D.^4;
            A.L2 = L + 0.5*D.^2 - lambda/3*D.^3 + lambda^2/8*D.^4;
            A.L3 = L + 0.5*D.^2 + lambda^2/16*D.^4;

            % L1, L2, L3 are real symmetric, so eig(full(.)) returns an
            % orthonormal eigenbasis; the preconditioners rely on this.
            [A.V_L1, A.E_L1] = eig( full( A.L1 ) );
            A.E_L1 = diag( A.E_L1 );
            [A.V_L2, A.E_L2] = eig( full( A.L2 ) );
            A.E_L2 = diag( A.E_L2 );
            [A.V_L3, A.E_L3] = eig( full( A.L3 ) );
            A.E_L3 = diag( A.E_L3 );

            A.B = D + lambda/8*D.^2;
            A.C = lambda * D.^2;
            M = speye( n, n );

            e1 = sparse( 1, 1, 1, 3, 1 );
            e2 = sparse( 2, 1, 1, 3, 1 );
            e3 = sparse( 3, 1, 1, 3, 1 );

            % Positions inside the vectorized 3x3 rank block of a middle core.
            l_mid = sparse( 3, 1, 1, 9, 1 );
            b_mid = sparse( 6, 1, 1, 9, 1 );
            m_mid = sparse( [1;9], [1;1], [1;1], 9, 1 );
            c_mid = sparse( 2, 1, 1, 9, 1 );

            A.U = cell( 1, d );
            A.U{1} = kron( A.L1, e1 ) + kron( A.B, e2 ) + kron( M, e3 );
            A_mid = kron( A.L2, l_mid ) + kron( A.B, b_mid ) + kron( M, m_mid ) + kron( A.C, c_mid );
            A.U(2:d-1) = {A_mid};
            A.U{d} = kron( M, e1 ) + kron( A.C, e2 ) + kron( A.L3, e3 );

            A = update_properties( A );
        end

        function expB = constr_precond_inner( A, X, mu )
        % CONSTR_PRECOND_INNER  Exponential-sum factors for the local system at core MU.
        %
        %   expB = constr_precond_inner( A, X, mu ) returns a 3-by-k cell
        %   array (k = 3). For term i:
        %     expB{1,i} = omega(i) * expm(-alpha(i)*B1)   (X.rank(mu)-square)
        %     expB{2,i} = expm(-alpha(i)*L_mu)            (n-by-n)
        %     expB{3,i} = expm(-alpha(i)*B3)              (X.rank(mu+1)-square)
        %   B1 and B3 are the left/right interface matrices of the diagonal
        %   (L1/L2/L3) part of A projected onto X. L_mu is L1, L2 or L3
        %   depending on the position of mu. omega/alpha come from
        %   load_coefficients(R) and are rescaled by the lower spectral bound.
        %   NOTE(review): presumably an exponential-sum approximation of 1/x
        %   on [1, R] -- confirm in load_coefficients.
            [V1, e1] = eig( left_interface( A, X, mu ) );
            e1 = diag( e1 );
            [V3, e3] = eig( right_interface( A, X, mu+1 ) );
            e3 = diag( e3 );
            [Vmu, Emu] = core_eig( A, mu );

            % Spectral bounds of B1 (x) I (x) I + I (x) L_mu (x) I + I (x) I (x) B3.
            lmin = min(e1) + min(Emu) + min(e3);
            lmax = max(e1) + max(Emu) + max(e3);
            R = lmax/lmin;

            [omega, alpha] = load_coefficients( R );

            k = 3;
            omega = omega/lmin;
            alpha = alpha/lmin;

            expB = cell( 3, k );
            for i = 1:k
                expB{1,i} = omega(i) * spectral_exp( V1, e1, alpha(i) );
                expB{2,i} = spectral_exp( Vmu, Emu, alpha(i) );
                expB{3,i} = spectral_exp( V3, e3, alpha(i) );
            end
        end

        function expB = constr_precond_outer( A, X, mu1, mu2 )
        % CONSTR_PRECOND_OUTER  Exponential-sum factors for the two-core system (MU1, MU2).
        %
        %   expB = constr_precond_outer( A, X, mu1, mu2 ) returns a 4-by-k
        %   cell array (k = 3). For term i:
        %     expB{1,i} = omega(i) * expm(-alpha(i)*B1)   (X.rank(mu1)-square)
        %     expB{2,i} = expm(-alpha(i)*L_mu1)           (n-by-n)
        %     expB{3,i} = expm(-alpha(i)*L_mu2)           (n-by-n)
        %     expB{4,i} = expm(-alpha(i)*B3)              (X.rank(mu2+1)-square)
        %   The 1D matrices are chosen per position (L1 first, L3 last, L2
        %   otherwise). The spectral bounds use exactly the matrices that
        %   are exponentiated, so the case d == 2 (mu1 == 1, mu2 == d) is
        %   also consistent.
            [V1, e1] = eig( left_interface( A, X, mu1 ) );
            e1 = diag( e1 );
            [V3, e3] = eig( right_interface( A, X, mu2+1 ) );
            e3 = diag( e3 );
            [Va, Ea] = core_eig( A, mu1 );
            [Vb, Eb] = core_eig( A, mu2 );

            lmin = min(e1) + min(Ea) + min(Eb) + min(e3);
            lmax = max(e1) + max(Ea) + max(Eb) + max(e3);
            R = lmax/lmin;

            [omega, alpha] = load_coefficients( R );

            k = 3;
            omega = omega/lmin;
            alpha = alpha/lmin;

            expB = cell( 4, k );
            for i = 1:k
                expB{1,i} = omega(i) * spectral_exp( V1, e1, alpha(i) );
                expB{2,i} = spectral_exp( Va, Ea, alpha(i) );
                expB{3,i} = spectral_exp( Vb, Eb, alpha(i) );
                expB{4,i} = spectral_exp( V3, e3, alpha(i) );
            end
        end

        function P = constr_precond( A, k )
        % CONSTR_PRECOND  Global rank-k exponential-sum preconditioner as a TTeMPS_op.
        %
        %   P = constr_precond( A, k ) returns
        %     sum_i omega(i) * expm(-alpha(i)*L1) (x) expm(-alpha(i)*L2) (x) ... (x) expm(-alpha(i)*L3)
        %   as a sum of k rank-one TT operators, each with d = A.order cores.
        %   k must be 3 (coefficients from load_coefficients) or 7
        %   (tabulated below); otherwise an error is raised.
        %   The cached eigenpairs are used instead of eig/expm on the sparse
        %   matrices. Since L1, L2, L3 are symmetric, V*diag(exp(-t*E))*V'
        %   equals expm(-t*L) and is a full matrix.
            d = A.order;

            lmin = min(A.E_L1) + (d-2)*min(A.E_L2) + min(A.E_L3);
            lmax = max(A.E_L1) + (d-2)*max(A.E_L2) + max(A.E_L3);
            R = lmax/lmin;

            if k == 3
                [omega, alpha] = load_coefficients( R );

            elseif k == 7
                omega = [0.0133615547183825570028305575534521842940 0.0429728469424360175410925952177443321034 0.1143029399081515586560726591147663100401,...
                         0.2838881266934189482611071431161775535656 0.6622322841999484042811198458711174907876 1.4847175320092703810050463464342840325116,...
                         3.4859753729916252771962870138366952232900];
                alpha = [0.0050213411684266507485648978019454613531 0.0312546410994290844202411500801774835168 0.1045970270084145620410366606112262388706,...
                         0.2920522758702768403556507270657505159761 0.7407504784499061527671195936939341208927 1.7609744335543204401530945069076494746696,...
                         4.0759036969145123916954953635638503328664];
            else
                error('Unknown rank specified. Choose either k=3 or k=7');
            end

            omega = omega/lmin;
            alpha = alpha/lmin;

            % Every term has exactly d cores (first, d-2 middle, last).
            P = omega(1) * TTeMPS_op( precond_cores( A, alpha(1) ) );
            for i = 2:k
                P = P + omega(i) * TTeMPS_op( precond_cores( A, alpha(i) ) );
            end
        end

    end

    methods( Access = private )

        function Lop = core_operator( A, i )
        % CORE_OPERATOR  1D matrix acting on core i: L1 (first), L3 (last), else L2.
            if i == 1
                Lop = A.L1;
            elseif i == A.order
                Lop = A.L3;
            else
                Lop = A.L2;
            end
        end

        function [V, E] = core_eig( A, i )
        % CORE_EIG  Cached eigenpair of the 1D matrix acting on core i.
            if i == 1
                V = A.V_L1;
                E = A.E_L1;
            elseif i == A.order
                V = A.V_L3;
                E = A.E_L3;
            else
                V = A.V_L2;
                E = A.E_L2;
            end
        end

        function Bl = left_interface( A, X, mu )
        % LEFT_INTERFACE  Sum over cores i < mu of innerprod(X, L_i-applied X, 'LR', mu-1).
        %
        %   Returns an X.rank(mu)-square matrix (1x1 zero when mu == 1).
            n = size( A.L2, 1 );
            Bl = zeros( X.rank(mu) );
            for i = 1:mu-1
                tmp = X;
                Xi = core_operator( A, i ) * matricize( tmp.U{i}, 2 );
                tmp.U{i} = tensorize( Xi, 2, [X.rank(i), n, X.rank(i+1)] );
                Bl = Bl + innerprod( X, tmp, 'LR', mu-1 );
            end
        end

        function Br = right_interface( A, X, nu )
        % RIGHT_INTERFACE  Sum over cores i >= nu of innerprod(X, L_i-applied X, 'RL', nu).
        %
        %   Returns an X.rank(nu)-square matrix (1x1 zero when nu > A.order).
            n = size( A.L2, 1 );
            Br = zeros( X.rank(nu) );
            for i = nu:A.order
                tmp = X;
                Xi = core_operator( A, i ) * matricize( tmp.U{i}, 2 );
                tmp.U{i} = tensorize( Xi, 2, [X.rank(i), n, X.rank(i+1)] );
                Br = Br + innerprod( X, tmp, 'RL', nu );
            end
        end

        function cores = precond_cores( A, t )
        % PRECOND_CORES  Rank-one TT cores {expm(-t*L1), expm(-t*L2) x (d-2), expm(-t*L3)}.
        %
        %   Each core is reshaped to [1, size_row, size_col, 1].
            d = A.order;
            E1 = reshape( spectral_exp( A.V_L1, A.E_L1, t ), [1, A.size_row(1), A.size_col(1), 1] );
            E2 = reshape( spectral_exp( A.V_L2, A.E_L2, t ), [1, A.size_row(2), A.size_col(2), 1] );
            E3 = reshape( spectral_exp( A.V_L3, A.E_L3, t ), [1, A.size_row(d), A.size_col(d), 1] );
            cores = [{E1}, repmat( {E2}, 1, d-2 ), {E3}];
        end

    end

end

function F = spectral_exp( V, E, t )
% SPECTRAL_EXP  V*diag(exp(-t*E))*V', i.e. expm(-t*S) for S = V*diag(E)*V' with orthonormal V.
    F = V * diag( exp( -t*E ) ) * V';
end