function low_rank_tensor_completion()
% Completes a 3-way array from a random subset of its entries.
%
% function low_rank_tensor_completion()
%
% Demo script: takes no input and returns nothing; results are printed.
%
% Steps performed:
%   1. Build a random n1-by-n2-by-n3 array A from three factor matrices
%      (n_i-by-r_i) and an r1-by-r2-by-r3 core through tucker2multiarray.
%   2. Draw a random logical mask P that marks 10% of the entries as
%      observed; only PA = P.*A is used by the cost.
%   3. Minimize f(X) = 0.5*||P.*X - PA||_F^2 over the manifold returned by
%      fixedrankfactory_tucker_preconditioned, where X is a struct with
%      fields U1, U2, U3 and G:
%        a. with trustregions (cost, egrad and ehess are supplied);
%        b. with conjugategradient and a problem-specific step-size
%           function registered as problem.linesearch;
%        c. with conjugategradient again, without that function.
%
% Dependencies not defined in this file (presumably the Manopt toolbox,
% to be confirmed): tucker2multiarray, trustregions, conjugategradient,
% fixedrankfactory_tucker_preconditioned.
%
% NOTE(review): the comments below assume tucker2multiarray(S) returns the
% full array obtained by multiplying the core S.G along modes 1, 2, 3 by
% S.U1, S.U2, S.U3. The reshapes in this file rely on the result having
% size(S.U1,1)-by-size(S.U2,1)-by-size(S.U3,1).

    % ------------------------------------------------------------------
    % Problem size: tensor dimensions and core dimensions.
    % ------------------------------------------------------------------
    n1 = 70;
    n2 = 60;
    n3 = 50;
    r1 = 3;
    r2 = 4;
    r3 = 5;
    tensor_dims = [n1 n2 n3];
    core_dims = [r1 r2 r3];
    total_entries = n1*n2*n3;

    % ------------------------------------------------------------------
    % Ground-truth tensor A. Entries are drawn with rand (uniform on
    % [0, 1]). The economy-size QR gives factors U_i with orthonormal
    % columns; the triangular factors R_i are folded into the core.
    % The order of the rand/randperm calls is kept as in the original so
    % that a given RNG state produces the same data.
    % ------------------------------------------------------------------
    [U1, R1] = qr(rand(n1, r1), 0);
    [U2, R2] = qr(rand(n2, r2), 0);
    [U3, R3] = qr(rand(n3, r3), 0);

    Z.U1 = R1;
    Z.U2 = R2;
    Z.U3 = R3;
    Z.G = rand(core_dims);
    Core = tucker2multiarray(Z);

    Y.U1 = U1;
    Y.U2 = U2;
    Y.U3 = U3;
    Y.G = Core;
    A = tucker2multiarray(Y);

    % ------------------------------------------------------------------
    % Observation mask: P(i, j, k) is true for observed entries.
    % ------------------------------------------------------------------
    fraction = 0.1;                         % fraction of observed entries
    nr = round(fraction * total_entries);   % number of observed entries
    ind = randperm(total_entries);
    ind = ind(1 : nr);
    P = false(tensor_dims);
    P(ind) = true;

    % Observed data: entries of A on the mask, zeros elsewhere.
    PA = P.*A;

    % ------------------------------------------------------------------
    % Optimization problem.
    % ------------------------------------------------------------------
    problem.M = fixedrankfactory_tucker_preconditioned(tensor_dims, core_dims);
    problem.cost = @cost;
    problem.egrad = @egrad;
    problem.ehess = @ehess;

    options.maxiter = 200;
    options.maxinner = 30;
    options.maxtime = inf;
    options.tolgradnorm = 1e-5;

    % ------------------------------------------------------------------
    % Solver 1: trust-regions, random initial point ([]).
    % ------------------------------------------------------------------
    Xtr = trustregions(problem, [], options);

    % Report the distance to the ground truth over ALL entries (not only
    % the observed ones).
    Xtrmultiarray = tucker2multiarray(Xtr);
    fprintf('||X-A||_F = %g\n', norm(reshape(Xtrmultiarray - A, [n1 n2*n3]), 'fro'));

    % ------------------------------------------------------------------
    % Solver 2: conjugate gradient with the step-size function below
    % registered in the problem structure.
    % ------------------------------------------------------------------
    problem.linesearch = @linesearch_helper;

    % Capture the options struct returned by the solver so that what is
    % displayed really is what CG used (input options merged with the
    % solver's defaults), not merely what was passed in.
    [~, ~, infocg, optionscg] = conjugategradient(problem, [], options);

    fprintf('Take a look at the options that CG used:\n');
    disp(optionscg);
    fprintf('And see how many trials were made at each line search call:\n');
    info_ls = [infocg.linesearch];
    disp([info_ls.costevals]);

    % ------------------------------------------------------------------
    % Solver 3: conjugate gradient without the step-size function. As in
    % the original script, no options are passed, so this run uses the
    % solver's default options.
    % ------------------------------------------------------------------
    fprintf('Try it again without the linesearch helper.\n');

    problem = rmfield(problem, 'linesearch');

    % The returned options are stored in their own variable so that the
    % user-defined 'options' above is not overwritten.
    [~, ~, info, options_default] = conjugategradient(problem, []);

    fprintf('Take a look at the options that CG used:\n');
    disp(options_default);
    fprintf('And see how many trials were made at each line search call:\n');
    info_ls = [info.linesearch];
    disp([info_ls.costevals]);


    % ==================================================================
    % Nested functions. They read P, PA, n1..n3, r1..r3, tensor_dims and
    % core_dims from the parent workspace.
    % ==================================================================

    function f = cost(X)
    % Cost f(X) = 0.5 * ||P.*X - PA||_F^2, with X expanded to a full array.
        residual = P.*tucker2multiarray(X) - PA;
        f = .5*norm(reshape(residual, n1, n2*n3), 'fro')^2;
    end

    function g = egrad(X)
    % Euclidean gradient of the cost with respect to the fields of X.
    % Returns a struct with fields U1, U2, U3 (same sizes as X.U1, X.U2,
    % X.U3) and G (r1-by-r2-by-r3).
        S = double(P.*tucker2multiarray(X) - PA);   % masked residual
        [S1, S2, S3] = unfold3(S, tensor_dims);
        [G1, G2, G3] = unfold3(X.G, core_dims);

        g.U1 = S1 * kron(X.U3, X.U2) * G1';
        g.U2 = S2 * kron(X.U3, X.U1) * G2';
        g.U3 = S3 * kron(X.U2, X.U1) * G3';
        % kron(X.U3', X.U2')' is written out explicitly as in the original.
        g.G = reshape(X.U1' * S1 * kron(X.U3', X.U2')', r1, r2, r3);
    end

    function Hess = ehess(X, eta)
    % Euclidean Hessian of the cost at X applied to the direction eta,
    % i.e., the directional derivative of egrad at X along eta.
    % eta is a struct with the same fields and sizes as X.
        % Masked residual and its three unfoldings.
        S = double(P.*tucker2multiarray(X) - PA);
        [S1, S2, S3] = unfold3(S, tensor_dims);

        % Directional derivative of the masked residual along eta.
        Sdot = double(masked_direction_multiarray(X, eta));
        [S1dot, S2dot, S3dot] = unfold3(Sdot, tensor_dims);

        % Unfoldings of the core and of its perturbation.
        [X_G1, X_G2, X_G3] = unfold3(double(X.G), core_dims);
        [eta_G1, eta_G2, eta_G3] = unfold3(double(eta.G), core_dims);

        % Product rule applied to each g.Ui of egrad: one term from the
        % change in the residual, three from the change in the factors
        % and the core that multiply it.
        Hess.U1 = S1dot * (kron(X.U3, X.U2)*X_G1') ...
            + S1 * (kron(eta.U3, X.U2)*X_G1' ...
            + kron(X.U3, eta.U2)*X_G1' + kron(X.U3, X.U2)*eta_G1');

        Hess.U2 = S2dot * (kron(X.U3, X.U1)*X_G2') ...
            + S2 * (kron(eta.U3, X.U1)*X_G2' ...
            + kron(X.U3, eta.U1)*X_G2' + kron(X.U3, X.U1)*eta_G2');

        Hess.U3 = S3dot * (kron(X.U2, X.U1)*X_G3') ...
            + S3 * (kron(eta.U2, X.U1)*X_G3' ...
            + kron(X.U2, eta.U1)*X_G3' + kron(X.U2, X.U1)*eta_G3');

        % Core part: product rule applied to g.G of egrad. The
        % multiplications by transposed factors along the three modes are
        % delegated to tucker2multiarray by passing the transposed
        % factors and using the full-size array as the "core".
        N.U1 = X.U1';
        N.U2 = X.U2';
        N.U3 = X.U3';
        N.G = Sdot;
        M0array = tucker2multiarray(N);

        M1.U1 = eta.U1';
        M1.U2 = X.U2';
        M1.U3 = X.U3';
        M1.G = S;
        M1array = tucker2multiarray(M1);

        M2.U1 = X.U1';
        M2.U2 = eta.U2';
        M2.U3 = X.U3';
        M2.G = S;
        M2array = tucker2multiarray(M2);

        M3.U1 = X.U1';
        M3.U2 = X.U2';
        M3.U3 = eta.U3';
        M3.G = S;
        M3array = tucker2multiarray(M3);

        Hess.G = M0array + M1array + M2array + M3array;
    end

    function tmin = linesearch_helper(X, eta)
    % Step size along eta obtained from a quadratic model of the cost.
    %
    % The full array of X is replaced by its first-order expansion along
    % eta, so the cost becomes 0.5*||term0 + t*term1||^2. Up to a constant
    % and a factor 0.5 this is a2*t^2 + a1*t, whose minimizer is
    % returned.
    %
    % NOTE(review): if the masked direction term1 is identically zero,
    % a2 is 0 and the result is NaN; this case is not guarded (same as
    % the original code).
        residual = P.*tucker2multiarray(X) - PA;
        term0 = residual(:);

        direction = masked_direction_multiarray(X, eta);
        term1 = direction(:);

        a2 = (term1'*term1);
        a1 = 2*(term1'*term0);
        tmin = - 0.5*(a1 / a2);
    end

    function Xdot = masked_direction_multiarray(X, eta)
    % Masked first-order variation of the full array of X along eta.
    %
    % The variation is the sum of four terms, each obtained by replacing
    % exactly one of U1, U2, U3, G by its counterpart in eta. That sum is
    % evaluated with a single tucker2multiarray call on an enlarged
    % struct: the core is block diagonal with four r1-by-r2-by-r3 blocks
    % and the k-th column block of each factor is paired with the k-th
    % core block.
        G1 = zeros(4*size(X.G));
        G1(1:r1, 1:r2, 1:r3) = X.G;                                        % pairs with eta.U1
        G1(r1 + 1 : 2*r1, r2 + 1 : 2*r2, r3 + 1 : 2*r3) = X.G;             % pairs with eta.U2
        G1(2*r1 + 1 : 3*r1, 2*r2 + 1 : 3*r2, 2*r3 + 1 : 3*r3) = X.G;       % pairs with eta.U3
        G1(3*r1 + 1 : 4*r1, 3*r2 + 1 : 4*r2, 3*r3 + 1 : 4*r3) = eta.G;     % core variation

        X1.G = G1;
        X1.U1 = [eta.U1 X.U1 X.U1 X.U1];
        X1.U2 = [X.U2 eta.U2 X.U2 X.U2];
        X1.U3 = [X.U3 X.U3 eta.U3 X.U3];

        Xdot = P.*tucker2multiarray(X1);
    end

    function [T1, T2, T3] = unfold3(T, dims)
    % Mode-1, mode-2 and mode-3 matrix unfoldings of a 3-way array T of
    % size dims = [d1 d2 d3]: T1 is d1-by-(d2*d3), T2 is d2-by-(d1*d3)
    % and T3 is d3-by-(d1*d2), with the column ordering fixed by the
    % permutations below.
        T1 = reshape(T, [dims(1), dims(2)*dims(3)]);
        T2 = reshape(permute(T, [2 1 3]), [dims(2), dims(1)*dims(3)]);
        T3 = reshape(permute(T, [3 1 2]), [dims(3), dims(1)*dims(2)]);
    end

end