function low_rank_tensor_completion_TT()
%LOW_RANK_TENSOR_COMPLETION_TT Compare TT completion solvers on a random tensor.
%   For each of COUNT trials this function
%     1. builds a random target tensor A in TT format (TTeMPS_randn) with
%        TT ranks rTarg,
%     2. samples a training index set Omega (weighted by prob_dist) and a
%        test index set Gamma (uniform), both of size sizeOmega,
%     3. starts four solvers from the same random, unit-norm,
%        orthogonalized initial guess X0 (TT ranks r) and fits A(Omega):
%          FD-TR : Manopt trustregions WITHOUT problem.ehess
%          RTR   : Manopt trustregions WITH problem.ehess = eHessCompl
%          RTTC  : completion_orth
%          ALS   : completion_als
%   It then plots the following against wall-clock time:
%     - the test error on Gamma,
%     - the training error on Omega,
%     - the gradient norm, normalized by the first RTTC gradient norm.
%
%   Final iterates are kept in the local cells finalTRfd_, finalTR_,
%   finalRTTC_ and finalALS_ (e.g. for inspection in the debugger).
%
%   External dependencies (not defined in this file): TTeMPS_randn,
%   orthogonalize, fixedTTrankfactory, trustregions, hessianspectrum,
%   completion_orth, completion_als, tangent_to_TTeMPS. makeOmegaSet_local
%   uses randsample when a sampling distribution is given.

% ---- Problem setup --------------------------------------------------------
d = 9;                      % tensor order
nn = 4;                     % size of every mode
n = nn * ones(1, d);

ranks_ = 3;                 % TT rank next to the boundary cores
omegaRatio_ = 0.1;          % fraction of entries used for training
prob_dist = ones(1, nn);    % per-mode sampling weights for Omega (uniform)

count = 1;                  % number of independent trials

% Per-trial storage (preallocated instead of grown inside the loop).
finalTRfd_ = cell(1, count);    % FD-TR final iterates
finalTR_ = cell(1, count);      % RTR (exact Hessian) final iterates
finalRTTC_ = cell(1, count);
finalALS_ = cell(1, count);

targets_ = cell(1, count);
omegas_ = cell(1, count);
gammas_ = cell(1, count);

cost_man_fd = cell(1, count);
stats_man_fd = cell(1, count);
cost_man = cell(1, count);
stats_man = cell(1, count);
cost_tt = cell(1, count);
test_tt = cell(1, count);
stats_tt = cell(1, count);
cost_als = cell(1, count);
test_als = cell(1, count);
stats_als = cell(1, count);

% Iteration limits.
maxInner_ = 10000;          % trust-region inner (tCG) iterations
maxIterRTTC_ = 5000;
maxIterALS_ = 1000;

% Computing the Hessian spectrum at the target is optional.
computeCondition = false;
cond_nums = [];

for p = 1:count

    % TT ranks [1 r_1 ... r_{d-1} 1]; for d = 9 this gives
    % [1 3 5 10 10 10 10 5 3 1].
    % NOTE: the index ranges below are written for d = 9.
    r = [1, ranks_ * ones(1, d - 1), 1];
    r(3:8) = 5;
    r(4:7) = 10;

    % Target and model use the same TT ranks.
    rTarg = r;

    opts = struct('maxiter', maxIterALS_, 'tol', 1e-14, 'reltol', 0, 'gradtol', 0);
    opts_tt = struct('maxiter', maxIterRTTC_, 'tol', 1e-14, 'reltol', 0, 'gradtol', 1e-8);

    % Training set (weighted sampling) and test set (uniform sampling).
    sizeOmega = round(omegaRatio_ * prod(n));
    sizeGamma = sizeOmega;

    Omega = makeOmegaSet_local(n, sizeOmega, prob_dist);
    Gamma = makeOmegaSet_local(n, sizeGamma);

    omegas_{p} = Omega;
    gammas_{p} = Gamma;

    A = TTeMPS_randn(rTarg, n);
    targets_{p} = A;

    A_Omega = A(Omega);
    A_Gamma = A(Gamma);

    % Shared starting point: random, scaled to unit norm, orthogonalized.
    X0 = TTeMPS_randn(r, n);
    X0 = (1 / norm(X0)) * X0;
    X0 = orthogonalize(X0, X0.order);

    TT = fixedTTrankfactory(n, r, Omega);

    % Number of samples relative to the manifold dimension.
    disp("Oversampling ratio: " + sizeOmega / TT.dim());

    problem = struct();
    problem.M = TT;
    problem.cost = @(x) eCostCompl(x, A_Omega, Omega);
    problem.egrad = @(x) eGradCompl(x, A_Omega, Omega);
    problem.ehess = @(x, u) eHessCompl(u, Omega);

    if computeCondition
        A_base = orthogonalize(A, A.order);
        spec = hessianspectrum(problem, A_base);
        cond_nums(end + 1) = spec(end) / spec(1); %#ok<AGROW>
    end

    % Manopt trust-region settings shared by both TR runs.
    options = struct();
    options.Delta0 = 100;
    options.Delta_bar = 100 * 2^11;
    options.maxiter = 250;
    options.maxinner = maxInner_;
    options.maxtime = inf;
    options.tolgradnorm = 1e-8;

    % test_cost_manopt reads Gamma / A_Gamma from the problem structure
    % to record the test cost at every iteration.
    problem.Gamma = Gamma;
    problem.A_Gamma = A_Gamma;
    options.statsfun = @test_cost_manopt;

    % FD-TR: without ehess, trustregions falls back to an approximate
    % Hessian (finite differences of the gradient in Manopt).
    problemFD = rmfield(problem, 'ehess');
    [finalTRfd_{p}, cost_man_fd{p}, stats_man_fd{p}] = trustregions(problemFD, X0, options);

    % RTR with the exact Euclidean Hessian.
    [finalTR_{p}, cost_man{p}, stats_man{p}] = trustregions(problem, X0, options);

    [finalRTTC_{p}, cost_tt{p}, test_tt{p}, stats_tt{p}] = ...
        completion_orth(A_Omega, Omega, A_Gamma, Gamma, X0, opts_tt);

    [finalALS_{p}, cost_als{p}, test_als{p}, stats_als{p}] = ...
        completion_als(A_Omega, Omega, A_Gamma, Gamma, X0, opts);

end

if computeCondition
    disp("Hessian condition number(s) at the target: " + mat2str(cond_nums));
end

% ---- Plotting -------------------------------------------------------------
% Colours from MATLAB's 'lines' colormap. The appended fourth component
% (0.7) is used by MATLAB graphics as line transparency (undocumented
% RGBA feature -- NOTE(review): confirm on the MATLAB version in use).
palette = lines(7);
redBase = palette(end, :);
blueBase = palette(1, :);
midred = [redBase, 0.7];
midblue = [blueBase, 0.7];
lightred = [brighten(redBase, 0.7), 0.7];
lightblue = [brighten(blueBase, 0.7), 0.7];

% Legends use explicit handles, so they stay correct when count > 1
% (otherwise the legend would label only the first trial's lines by order).
% semilogy is called before 'hold on' so the axes get a log y-scale.

% Figure 1: test error on Gamma. Manopt stores 0.5*||res||^2, so
% sqrt(2*cost)/||A(Gamma)|| is the relative residual norm.
figure;
h = gobjects(1, 4);
for k = 1:count
    A = targets_{k};
    A_Gamma = A(gammas_{k});
    scale = norm(A_Gamma);
    h(1) = semilogy([stats_man_fd{k}.time], sqrt(2 * [stats_man_fd{k}.cost_test]) / scale, 'color', midred, 'linewidth', 2);
    hold on
    h(2) = semilogy([stats_man{k}.time], sqrt(2 * [stats_man{k}.cost_test]) / scale, 'color', midblue, 'linewidth', 2);
    h(3) = semilogy(stats_tt{k}.time, test_tt{k}, 'color', lightred, 'linewidth', 2);
    h(4) = semilogy(stats_als{k}.time, test_als{k}, 'color', lightblue, 'linewidth', 2);
end

legend(h, {'FD-TR', 'RTR', 'RTTC', 'ALS'})
xlabel('Time (s)')
ylabel('Test Cost')

% Figure 2: training error on Omega.
figure;
h = gobjects(1, 4);
for k = 1:count
    A = targets_{k};
    A_Omega = A(omegas_{k});
    scale = norm(A_Omega);
    h(1) = semilogy([stats_man_fd{k}.time], sqrt(2 * [stats_man_fd{k}.cost]) / scale, 'color', midred, 'linewidth', 2);
    hold on
    h(2) = semilogy([stats_man{k}.time], sqrt(2 * [stats_man{k}.cost]) / scale, 'color', midblue, 'linewidth', 2);
    h(3) = semilogy(stats_tt{k}.time, cost_tt{k}, 'color', lightred, 'linewidth', 2);
    h(4) = semilogy(stats_als{k}.time, cost_als{k}, 'color', lightblue, 'linewidth', 2);
end

legend(h, {'FD-TR', 'RTR', 'RTTC', 'ALS'})
xlabel('Time (s)')
ylabel('Training Cost')

% Figure 3: gradient norms, all normalized by the first RTTC value.
figure;
h = gobjects(1, 3);
for k = 1:count
    g0 = stats_tt{k}.gradnorm(1);
    h(1) = semilogy([stats_man_fd{k}.time], [stats_man_fd{k}.gradnorm] / g0, 'color', midred, 'linewidth', 2);
    hold on
    h(2) = semilogy([stats_man{k}.time], [stats_man{k}.gradnorm] / g0, 'color', midblue, 'linewidth', 2);
    h(3) = semilogy(stats_tt{k}.time, (stats_tt{k}.gradnorm) / g0, 'color', lightred, 'linewidth', 2);
end

legend(h, {'FD-TR', 'RTR', 'RTTC'})
xlabel('Time (s)')
ylabel('Gradient Norm')

end
0239
0240
0241
function stats = test_cost_manopt(problem, x, stats)
%TEST_COST_MANOPT Manopt statsfun that records the held-out cost.
%   Adds field cost_test = 0.5 * ||x(problem.Gamma) - problem.A_Gamma||^2
%   to the per-iteration stats structure and returns it.
heldOutResidual = x(problem.Gamma) - problem.A_Gamma;
stats.cost_test = 0.5 * norm(heldOutResidual)^2;
end
0245
0246
0247
0248
function c = eCostCompl(x, A, A0)
%ECOSTCOMPL Least-squares misfit on the sampled entries.
%   c = 0.5 * ||x(A0) - A||^2, where A0 indexes the sampled entries of x
%   and A holds the observed values at those entries.
sampledResidual = x(A0) - A;
c = 0.5 * norm(sampledResidual)^2;
end
0252
function g = eGradCompl(x, A, A0)
%EGRADCOMPL Euclidean gradient of eCostCompl, restricted to the samples.
%   Returns the residual x(A0) - A on the sampled entries. Presumably the
%   manifold factory (fixedTTrankfactory) interprets this vector as the
%   sparse Euclidean gradient supported on A0 -- NOTE(review): confirm.
sampled = x(A0);
g = sampled - A;
end
0256
function h = eHessCompl(u, A0)
%EHESSCOMPL Euclidean Hessian-vector product of eCostCompl.
%   The cost is quadratic in the tensor entries, so the Hessian applied to
%   a direction u is u sampled on A0. It does not depend on the current
%   point. The tangent vector u is first converted to a TT tensor with
%   tangent_to_TTeMPS so that it can be indexed by A0.
uAsTensor = tangent_to_TTeMPS(u);
h = uAsTensor(A0);
end
0261
0262
0263
0264
function Omega = makeOmegaSet_local(n, sizeOmega, dist)
%MAKEOMEGASET_LOCAL Sample sizeOmega distinct multi-indices of a tensor.
%   Omega = makeOmegaSet_local(n, sizeOmega) draws every mode index
%   uniformly with randi. It returns a sizeOmega-by-length(n) matrix of
%   distinct subscripts, with rows sorted as returned by unique(..., 'rows').
%
%   Omega = makeOmegaSet_local(n, sizeOmega, dist) draws the mode-i indices
%   with replacement using the weights dist (randsample). The same weight
%   vector is used for every mode, so numel(dist) must equal every n(i).
%   An empty dist is treated as uniform sampling.
%
%   Duplicate rows are discarded and re-drawn until sizeOmega distinct
%   rows exist.
%
%   Errors:
%     makeOmegaSet:sizeOmegaTooHigh
%         sizeOmega exceeds prod(n), or it exceeds the number of subscripts
%         that have nonzero weight. Without this check the re-draw loop
%         would never terminate.
%     makeOmegaSet:distSizeMismatch
%         numel(dist) differs from some n(i).

if nargin < 3
    dist = [];
end
useDist = ~isempty(dist);

if sizeOmega > prod(n)
    error('makeOmegaSet:sizeOmegaTooHigh', 'Requested size of Omega is bigger than the tensor itself!')
end

d = length(n);

if useDist
    if any(n(:) ~= numel(dist))
        error('makeOmegaSet:distSizeMismatch', ...
            'The sampling distribution must contain one weight per index of every mode.')
    end
    % Only indices with positive weight can ever be drawn.
    nReachable = nnz(dist > 0)^d;
    if sizeOmega > nReachable
        error('makeOmegaSet:sizeOmegaTooHigh', ...
            'Requested size of Omega exceeds the number of indices with nonzero sampling weight!')
    end
end

% Initial draw: one column of subscripts per mode.
subs = zeros(sizeOmega, d);
for i = 1:d
    subs(:, i) = draw_mode_indices_local(n(i), sizeOmega, useDist, dist);
end

Omega = unique(subs, 'rows');
m = size(Omega, 1);

% Keep the distinct rows and re-draw only the rows lost to duplicates.
while m < sizeOmega
    subs(1:m, :) = Omega;
    for i = 1:d
        subs(m + 1:sizeOmega, i) = draw_mode_indices_local(n(i), sizeOmega - m, useDist, dist);
    end
    Omega = unique(subs, 'rows');
    m = size(Omega, 1);
end

end

function idx = draw_mode_indices_local(nMode, k, useDist, dist)
% Draw a k-by-1 vector of indices in 1..nMode with replacement, either
% uniformly (randi) or weighted by dist (randsample).
if useDist
    idx = randsample(nMode, k, true, dist);
else
    idx = randi(nMode, k, 1);
end
end