Home > examples > low_rank_tensor_completion_TT.m

low_rank_tensor_completion_TT

PURPOSE ^

Example file for the manifold encoded in fixedTTrankfactory.

SYNOPSIS ^

function low_rank_tensor_completion_TT()

DESCRIPTION ^

 Example file for the manifold encoded in fixedTTrankfactory.

 The script runs a tensor completion task, where tensors are controlled
 by a low Tensor-Train rank. The factory fixedTTrankfactory rests heavily
 on TTeMPS 1.1 (slightly modified for Manopt), coded by M. Steinlechner.
 See manopt/manifolds/ttfixedrank/TTeMPS_1.1/ for license and installation
 instructions (in particular, some MEX files may need to be compiled).

 This script generates results from figure 1 of the following paper:
   Michael Psenka and Nicolas Boumal
   Second-order optimization for tensors with fixed tensor-train rank
   NeurIPS OPT2020 workshop
   https://arxiv.org/abs/2011.13395

 See also: fixedTTrankfactory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function low_rank_tensor_completion_TT()
% Example file for the manifold encoded in fixedTTrankfactory.
%
% The script runs a tensor completion task, where tensors are controlled
% by a low Tensor-Train rank. The factory fixedTTrankfactory rests heavily
% on TTeMPS 1.1 (slightly modified for Manopt), coded by M. Steinlechner.
% See manopt/manifolds/ttfixedrank/TTeMPS_1.1/ for license and installation
% instructions (in particular, some MEX files may need to be compiled).
%
% Four solvers are run from the same random starting point X0 on the same
% random completion problem:
%   FD-TR : Manopt trustregions without ehess (finite-difference Hessian),
%   RTR   : Manopt trustregions with the analytic Hessian (eHessCompl),
%   RTTC  : completion_orth from TTeMPS,
%   ALS   : completion_als from TTeMPS.
% Three figures are then drawn against time: test cost, training cost
% and gradient norm.
%
% This script generates results from figure 1 of the following paper:
%   Michael Psenka and Nicolas Boumal
%   Second-order optimization for tensors with fixed tensor-train rank
%   NeurIPS OPT2020 workshop
%   https://arxiv.org/abs/2011.13395
%
% See also: fixedTTrankfactory

% This file is part of Manopt and is copyrighted. See the license file.
%
% Main author: Michael Psenka, Jan. 6, 2021
% Contributors: Nicolas Boumal
%
% Change log:

% Set the random seed for reproducible results.
% rng(15);

% Order of the tensors.
d = 9;

% Size vector of the tensors: every mode has size nn.
nn = 4;
n = nn * ones(1, d);

% Base TT rank. The rank vector starts as (1, r, ..., r, 1) and is then
% raised towards the middle cores inside the experiment loop.
ranks_ = 3;
% Ratio |Omega| / prod(n) of observed entries.
omegaRatio_ = 0.1;

% Sampling weights for the index values of each mode (length must be nn).
prob_dist = [1 1 1 1];

% How many times should the experiments be run? Set to 10 in the paper.
count = 1;

% Final iterates of each method, one cell per experiment.
finalTR_ = cell(1, count);
finalRTTC_ = cell(1, count);
finalALS_ = cell(1, count);

% Costs and statistics returned by each method, one cell per experiment
% (preallocated so they do not grow inside the loop).
cost_man_fd = cell(1, count);
stats_man_fd = cell(1, count);
cost_man = cell(1, count);
stats_man = cell(1, count);
cost_tt = cell(1, count);
test_tt = cell(1, count);
stats_tt = cell(1, count);
cost_als = cell(1, count);
test_als = cell(1, count);
stats_als = cell(1, count);

% Targets, training sets (Omega) and test sets (Gamma) of each experiment,
% kept for the plots at the end.
targets_ = cell(1, count);
omegas_ = cell(1, count);
gammas_ = cell(1, count);

% Max inner iterations for trust regions.
maxInner_ = 10000;

% Max iterations for RTTC and ALS.
maxIterRTTC_ = 5000;
maxIterALS_ = 1000;

% Set to true to compute Hessian condition numbers at the targets.
computeCondition = false;
cond_nums = [];

for p = 1:count

    % TT rank vector of length d+1: 1 at both ends, ranks_ next to them,
    % then 5, then 10 for the middle cores.
    % NOTE: the index ranges below are hard-coded for d = 9.
    r = [1, ranks_ * ones(1, d - 1), 1];
    r(3:8) = 5;
    r(4:7) = 10;

    % The target tensor has the same TT rank as the search manifold.
    rTarg = r;

    % Options for Steinlechner's algorithms: ALS (opts) and RTTC (opts_tt).
    % RTTC additionally stops on a gradient tolerance of 1e-8.
    opts = struct('maxiter', maxIterALS_, 'tol', 1e-14, 'reltol', 0, 'gradtol', 0);
    opts_tt = struct('maxiter', maxIterRTTC_, 'tol', 1e-14, 'reltol', 0, 'gradtol', 1e-8);

    % Observed entries (Omega) and held-out test entries (Gamma) of the
    % same size. Gamma is used to check that the algorithms converge to
    % the right tensor.
    sizeOmega = round(omegaRatio_ * prod(n));
    sizeGamma = sizeOmega;

    Omega = makeOmegaSet_local(n, sizeOmega, prob_dist);
    Gamma = makeOmegaSet_local(n, sizeGamma);

    omegas_{p} = Omega;
    gammas_{p} = Gamma;

    % Random target tensor.
    A = TTeMPS_randn(rTarg, n);
    targets_{p} = A;

    % Entries of A at Omega (training data) and Gamma (test data).
    A_Omega = A(Omega);
    A_Gamma = A(Gamma);

    % Unit-norm random starting point shared by all methods.
    X0 = TTeMPS_randn(r, n);
    X0 = (1 / norm(X0)) * X0;
    X0 = orthogonalize(X0, X0.order);

    % Manopt factory for the fixed TT-rank manifold: n is the size vector,
    % r the rank vector, and Omega an optional parameter specifying which
    % entries are observed.
    TT = fixedTTrankfactory(n, r, Omega);

    % checkmanifold(TT)

    disp("Oversampling ratio: " + sizeOmega / TT.dim());

    % Tensor completion problem for Manopt.
    problem = struct();
    problem.M = TT;
    problem.cost = @(x) eCostCompl(x, A_Omega, Omega);
    problem.egrad = @(x) eGradCompl(x, A_Omega, Omega);
    ehess = @(x, u) eHessCompl(u, Omega);

    % Optionally compute the Hessian spectrum at the target point.
    if computeCondition
        problem.ehess = ehess;
        A_base = orthogonalize(A, A.order);
        spec = hessianspectrum(problem, A_base);
        cond_nums(end + 1) = spec(end) / spec(1); %#ok<AGROW>
        problem = rmfield(problem, 'ehess');
    end

    % Options for trust regions.
    options = struct();
    options.Delta0 = 100;
    options.Delta_bar = 100 * 2^11;
    options.maxiter = 250;
    options.maxinner = maxInner_;
    options.maxtime = inf;
    options.tolgradnorm = 1e-8;

    % Record the test cost on Gamma at each iteration (test_cost_manopt
    % reads problem.Gamma and problem.A_Gamma).
    problem.Gamma = Gamma;
    problem.A_Gamma = A_Gamma;
    options.statsfun = @test_cost_manopt;

    % FD-TR: no ehess in the problem, so the Hessian is approximated
    % by finite differences. Only its history is kept.
    [~, cost_man_fd{p}, stats_man_fd{p}] = trustregions(problem, X0, options);

    % RTR: same problem, now with the analytic Hessian.
    problem.ehess = ehess;
    [finalTR_{p}, cost_man{p}, stats_man{p}] = trustregions(problem, X0, options);

    % RTTC (Riemannian completion from TTeMPS), with options opts_tt.
    [finalRTTC_{p}, cost_tt{p}, test_tt{p}, stats_tt{p}] = ...
        completion_orth(A_Omega, Omega, A_Gamma, Gamma, X0, opts_tt);

    % ALS completion.
    [finalALS_{p}, cost_als{p}, test_als{p}, stats_als{p}] = ...
        completion_als(A_Omega, Omega, A_Gamma, Gamma, X0, opts);

end

if computeCondition
    disp('Hessian condition numbers at the targets:');
    disp(cond_nums);
end

%% Plots
l = lines(7);
midred = l(end, :);
lightred = brighten(midred, 0.7);

midblue = l(1, :);
lightblue = brighten(midblue, 0.7);

% Append an alpha (transparency) value to each RGB color.
lightred(end + 1) = 0.7;
lightblue(end + 1) = 0.7;
midred(end + 1) = 0.7;
midblue(end + 1) = 0.7;

% Test cost. Manopt costs are converted to sqrt(2 * cost) / ||A(Gamma)||.
figure;

for k = 1:count
    A = targets_{k};
    normGamma = norm(A(gammas_{k}));
    semilogy([stats_man_fd{k}.time], sqrt(2 * [stats_man_fd{k}.cost_test]) / normGamma, 'color', midred, 'linewidth', 2)
    hold on
    semilogy([stats_man{k}.time], sqrt(2 * [stats_man{k}.cost_test]) / normGamma, 'color', midblue, 'linewidth', 2)
    semilogy(stats_tt{k}.time, test_tt{k}, 'color', lightred, 'linewidth', 2)
    semilogy(stats_als{k}.time, test_als{k}, 'color', lightblue, 'linewidth', 2)
end

legend({'FD-TR', 'RTR', 'RTTC', 'ALS'})
xlabel('Time (s)')
ylabel('Test Cost')

% Training cost. Manopt costs are converted to sqrt(2 * cost) / ||A(Omega)||.
figure;

for k = 1:count
    A = targets_{k};
    normOmega = norm(A(omegas_{k}));
    semilogy([stats_man_fd{k}.time], sqrt(2 * [stats_man_fd{k}.cost]) / normOmega, 'color', midred, 'linewidth', 2)
    hold on
    semilogy([stats_man{k}.time], sqrt(2 * [stats_man{k}.cost]) / normOmega, 'color', midblue, 'linewidth', 2)
    semilogy(stats_tt{k}.time, cost_tt{k}, 'color', lightred, 'linewidth', 2)
    semilogy(stats_als{k}.time, cost_als{k}, 'color', lightblue, 'linewidth', 2)
end

legend({'FD-TR', 'RTR', 'RTTC', 'ALS'})
xlabel('Time (s)')
ylabel('Training Cost')

% Gradient norm, normalized by RTTC's initial gradient norm (same X0).
figure;

for k = 1:count
    gradScale = stats_tt{k}.gradnorm(1);
    semilogy([stats_man_fd{k}.time], [stats_man_fd{k}.gradnorm] / gradScale, 'color', midred, 'linewidth', 2)
    hold on
    semilogy([stats_man{k}.time], [stats_man{k}.gradnorm] / gradScale, 'color', midblue, 'linewidth', 2)
    semilogy(stats_tt{k}.time, (stats_tt{k}.gradnorm) / gradScale, 'color', lightred, 'linewidth', 2)
end

legend({'FD-TR', 'RTR', 'RTTC'})
xlabel('Time (s)')
ylabel('Gradient Norm')

end
0239 
0240 %%%%%%%%%%%%%%%%%%%%%%%%%%%% Stats function %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0241 
function stats = test_cost_manopt(problem, x, stats)
% Manopt statsfun: records in stats.cost_test the test cost of the iterate x,
% i.e. half the squared 2-norm of x(problem.Gamma) - problem.A_Gamma.
% All other fields of stats are passed through unchanged.
    gammaResidual = x(problem.Gamma) - problem.A_Gamma;
    stats.cost_test = 0.5 * norm(gammaResidual)^2;
end
0245 
0246 %%%%%%%%%%%%%%%%%%%%%%%%%% FUNCTIONS FOR MANOPT TRUST REGIONS %%%%%%%%%%%%%%%%%%%%%%%%%
0247 
0248 % Non-regularized Euclidean functions
function c = eCostCompl(x, A_Omega, Omega)
% Euclidean tensor-completion cost: half the squared 2-norm of the
% difference between the entries of x at the multi-indices Omega
% (one per row) and the observed values A_Omega.
    residual = x(Omega) - A_Omega;
    c = 0.5 * norm(residual)^2;
end
0252 
function g = eGradCompl(x, A_Omega, Omega)
% Euclidean gradient of the completion cost, returned as the residual
% vector x(Omega) - A_Omega (one entry per row of Omega).
% NOTE(review): presumably fixedTTrankfactory (built with Omega) reads this
% vector as a sparse tensor supported on Omega; confirm in the factory.
    xOmega = x(Omega);
    g = xOmega - A_Omega;
end
0256 
function h = eHessCompl(u, Omega)
% Euclidean Hessian-vector product of the completion cost. The cost is
% quadratic, so the Hessian does not depend on the point x: it returns the
% entries of the direction u at the multi-indices Omega. u is first
% converted with tangent_to_TTeMPS.
    uTensor = tangent_to_TTeMPS(u);
    h = uTensor(Omega);
end
0261 
0262 %%%%%%%%%%%%%%%% CUSTOM OMEGA SET GENERATOR, NON_UNIFORM DIST. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0263 
function Omega = makeOmegaSet_local(n, sizeOmega, dist)
% Draws sizeOmega distinct random multi-indices of a tensor of size n.
%
% Omega = makeOmegaSet_local(n, sizeOmega) draws every index of every
% mode uniformly at random (randi).
%
% Omega = makeOmegaSet_local(n, sizeOmega, dist) draws the index of mode i
% from 1:n(i) with (unnormalized) weights dist, via randsample (Statistics
% and Machine Learning Toolbox). The same dist is used for every mode, so
% numel(dist) must equal n(i) for all i. An empty dist means uniform.
%
% Omega is a sizeOmega-by-numel(n) matrix with one multi-index per row and
% no repeated rows, in the sorted order produced by unique(..., 'rows').
% Duplicate draws are discarded and redrawn until sizeOmega rows remain.

    if nargin < 3
        dist = [];
    end

    d = length(n);

    if sizeOmega > prod(n)
        error('makeOmegaSet:sizeOmegaTooHigh', 'Requested size of Omega is bigger than the tensor itself!')
    end

    % With weights, only entries whose indices all have nonzero weight can
    % ever be drawn; requesting more distinct ones would loop forever below.
    if ~isempty(dist) && sizeOmega > nnz(dist)^d
        error('makeOmegaSet:sizeOmegaTooHigh', ...
              'Requested size of Omega exceeds the number of entries with nonzero sampling weight!')
    end

    Omega = unique(draw_subs_local(n, sizeOmega, dist), 'rows');
    m = size(Omega, 1);

    while m < sizeOmega
        % Keep the m distinct rows found so far and redraw the missing ones.
        Omega = unique([Omega; draw_subs_local(n, sizeOmega - m, dist)], 'rows');
        m = size(Omega, 1);
    end

end

function subs = draw_subs_local(n, numDraws, dist)
% Draws numDraws random multi-indices (one per row) for a tensor of size n,
% mode by mode: uniformly if dist is empty, with weights dist otherwise.
    d = length(n);
    subs = zeros(numDraws, d);
    for i = 1:d
        if isempty(dist)
            subs(:, i) = randi(n(i), numDraws, 1);
        else
            subs(:, i) = randsample(n(i), numDraws, true, dist);
        end
    end
end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005