Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > examples > ex_completion_compare_als_riemann.m

ex_completion_compare_als_riemann

PURPOSE ^

This example shows a simple comparison of two different algorithms for tensor completion:

SYNOPSIS ^

This is a script file.

DESCRIPTION ^

 This example shows a simple comparison of two different algorithms for tensor completion:

   -- ALS completion
   -- Riemannian tensor completion (RTTC)

 in a comparison very similar to Figure 5.2 in
   
   Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
   Technical report, March 2015, revised December 2015. 
   To appear in SIAM J. Sci. Comput. 

 See this report for more details about the algorithms and the setup. 
 The only difference from the setup described therein is a reduced problem size (d, n, r),
 so that the results take less time to compute.

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SOURCE CODE ^

0001 % This example shows a simple comparison of two different algorithms for tensor completion:
0002 %
0003 %   -- ALS completion
0004 %   -- Riemannian tensor completion (RTTC)
0005 %
0006 % in a comparison very similar to Figure 5.2 in
0007 %
0008 %   Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
0009 %   Technical report, March 2015, revised December 2015.
0010 %   To appear in SIAM J. Sci. Comput.
0011 %
0012 % See this report for more details about the algorithms and the setup.
0013 % The only difference from the setup described therein is a reduced problem
0014 % size (d, n, r), so that the results take less time to compute.
0015 
0016 %   TTeMPS Toolbox.
0017 %   Michael Steinlechner, 2013-2016
0018 %   Questions and contact: michael.steinlechner@epfl.ch
0019 %   BSD 2-clause license, see LICENSE.txt
0020 
0021 rng(13);
0022 d = 10;                  % tensor order (number of modes)
0023 nn = 20;                 % size of every mode
0024 n = nn*ones(1,d);
0025 
0026 ranks = [4, 6, 8];       % TT ranks to compare; the plots adapt to any number of ranks
0027 nRanks = numel(ranks);
0028 
0029 % All tolerances are 0, presumably so that only 'maxiter' ends a run.
0030 opts    = struct('maxiter', 50, 'tol', 0, 'reltol',0, 'gradtol',0);   % ALS
0031 opts_tt = struct('maxiter', 60, 'tol', 0, 'reltol',0, 'gradtol',0);   % RTTC
0032 
0033 % Preallocate the result cells that the loop below actually fills.
0034 cost_als  = cell(1,nRanks);
0035 test_als  = cell(1,nRanks);
0036 stats_als = cell(1,nRanks);
0037 cost_tt   = cell(1,nRanks);
0038 test_tt   = cell(1,nRanks);
0039 stats_tt  = cell(1,nRanks);
0040 
0041 for j = 1:nRanks
0042     r = ranks(j);
0043     rr = [1, r*ones(1,d-1), 1];
0044 
0045     % d*nn*r^2 bounds the number of entries in the d TT cores (each at most
0046     % r x nn x r). The sampling set Omega and the test set Gamma are both
0047     % requested with ten times that size.
0048     dof = d*nn*r^2;
0049     sizeOmega = 10*dof;
0050     sizeGamma = sizeOmega;
0051 
0052     Omega = makeOmegaSet_mod(n, sizeOmega);
0053     Gamma = makeOmegaSet_mod(n, sizeGamma);
0054 
0055     % Random tensor with TT ranks rr, scaled to unit norm.
0056     A = TTeMPS_rand( rr, n );
0057     A = 1/norm(A) * A;
0058 
0059     A_Omega = A(Omega);
0060     A_Gamma = A(Gamma);
0061 
0062     % Both algorithms start from the same normalized random initial guess.
0063     X0 = TTeMPS_rand( rr, n );
0064     X0 = 1/norm(X0) * X0;
0065     X0 = orthogonalize( X0, X0.order );
0066 
0067     [X,cost_als{j},test_als{j},stats_als{j}] = completion_als( A_Omega, Omega, A_Gamma, Gamma, X0, opts );
0068     [X,cost_tt{j},test_tt{j},stats_tt{j}] = completion_orth( A_Omega, Omega, A_Gamma, Gamma, X0, opts_tt );
0069 end
0070 
0071 % Base colors: red shades for ALS, blue shades for RTTC (last and first rows of lines(7)).
0072 lineColors = lines(7);
0073 midred  = lineColors(end,:);
0074 midblue = lineColors(1,:);
0075 
0076 % One shade per rank, from darker (first rank) to lighter (last rank);
0077 % brighten(c, 0) returns c unchanged.
0078 shades = linspace(-0.7, 0.7, nRanks);
0079 
0080 legendEntries = [arrayfun(@(r) sprintf('ALS, rank %d', r), ranks, 'UniformOutput', false), ...
0081                  arrayfun(@(r) sprintf('RTTC, rank %d', r), ranks, 'UniformOutput', false)];
0082 
0083 % Left: error on the test set vs. iteration count.
0084 subplot(1,2,1)
0085 for j = 1:nRanks
0086     semilogy( test_als{j},'color',brighten(midred,shades(j)),'linewidth',2)
0087     hold on   % hold only after the first semilogy has set up the log-scaled axes
0088 end
0089 for j = 1:nRanks
0090     semilogy( test_tt{j},'--','color',brighten(midblue,shades(j)),'linewidth',2)
0091 end
0092 
0093 xlabel('Iterations')
0094 ylabel('Error on test set')
0095 legend(legendEntries)
0096 
0097 % Right: error on the test set vs. elapsed time.
0098 subplot(1,2,2)
0099 for j = 1:nRanks
0100     loglog( stats_als{j}.time, test_als{j},'color',brighten(midred,shades(j)),'linewidth',2)
0101     hold on
0102 end
0103 for j = 1:nRanks
0104     loglog( stats_tt{j}.time, test_tt{j},'--','color',brighten(midblue,shades(j)),'linewidth',2)
0105 end
0106 xlim([1e-1,1e3])
0107 
0108 xlabel('Time [s]')
0109 ylabel('Error on test set')

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005