low_rank_dist_completion

PURPOSE

Perform low-rank distance matrix completion w/ automatic rank detection.

SYNOPSIS

function [Y, infos, problem_description] = low_rank_dist_completion(problem_description)

DESCRIPTION

 Perform low-rank distance matrix completion w/ automatic rank detection.

 function Y = low_rank_dist_completion(problem_description)
 function [Y, infos, out_problem_description] = low_rank_dist_completion(problem_description)

 It implements the ideas of Journee, Bach, Absil and Sepulchre, SIOPT, 2010,
 applied to the problem of low-rank Euclidean distance matrix completion.
 The details are in the paper "Low-rank optimization for distance matrix completion",
 B. Mishra, G. Meyer, and R. Sepulchre, IEEE CDC, 2011.

 Paper link: http://arxiv.org/abs/1304.6663.
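 The objective and the dual variable handled by the code can be summarized as follows. This summary is read off the source listing (cost_evaluation, the grad function inside low_rank_dist_completion_fixedrank, and the Sy line of the main loop) rather than quoted from the paper. The notation is introduced here: y_i^T is the i-th row of Y, Omega is the set of the N training pairs (i, j), d_ij is the corresponding known squared distance, and e_i is the i-th standard basis vector of R^n.

```latex
\begin{aligned}
e_{ij} &= e_i - e_j, \qquad \|y_i - y_j\|_2^2 = e_{ij}^\top Y Y^\top e_{ij}, \\
f(Y) &= \frac{1}{2N} \sum_{(i,j) \in \Omega} \left( e_{ij}^\top Y Y^\top e_{ij} - d_{ij} \right)^2, \\
S_Y &= \frac{1}{N} \sum_{(i,j) \in \Omega} \left( e_{ij}^\top Y Y^\top e_{ij} - d_{ij} \right) e_{ij} e_{ij}^\top, \\
\nabla f(Y) &= 2\, S_Y Y .
\end{aligned}
```

 S_Y is the gradient of the convex function F(X) = (1/(2N)) sum (e_ij^T X e_ij - d_ij)^2 at X = Y Y^T. At a critical point Y (so S_Y Y = 0), S_Y being positive semidefinite makes X = Y Y^T a global minimizer of F over the positive semidefinite cone. If instead the smallest eigenvalue lambda_min of S_Y is negative, with unit eigenvector v, then F(Y Y^T + t^2 v v^T) = F(Y Y^T) + t^2 lambda_min + O(t^4), which is why the code appends a column along v when it increases the rank. The rank search stops once lambda_min exceeds params.tolSmin.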

 Input:
 -------

 problem_description: Structure describing the problem instance.


 - problem_description.data_train: Data structure for known squared distances that are used to learn a low-rank model.
                                   It contains the 3 fields that are shown
                                   below. An empty "data_train" structure
                                   will generate the 3d Helix instance.

       -- data_train.entries:      A column vector consisting of known squared
                                   distances. An empty "data_train.entries"
                                   field will generate the 3d Helix
                                   instance.

       -- data_train.rows:         The row positions of the corresponding
                                   distances. An empty "data_train.rows"
                                   field will generate the 3d Helix
                                   instance.

       -- data_train.cols:         The column positions of the corresponding
                                   distances. An empty "data_train.cols"
                                   field will generate the 3d Helix
                                   instance.



 - problem_description.data_test:  Data structure of squared distances that are "unknown" to the algorithm, used to compute the test error.
                                   It contains the 3 fields that are shown
                                   below. An empty "data_test" structure
                                   disables the test error computation.

       -- data_test.entries:       A column vector of squared distances that
                                   are "unknown" to the algorithm. An empty
                                   "data_test.entries" field disables the test error.
       -- data_test.rows:          The row positions of the corresponding
                                   distances. An empty "data_test.rows"
                                   field disables the test error.
       -- data_test.cols:          The column positions of the corresponding
                                   distances. An empty "data_test.cols"
                                   field disables the test error.



 - problem_description.n:          The number of data points. If "n" is
                                   missing or empty while "data_train" is
                                   given, an error is raised to avoid
                                   potential data inconsistency.





 - problem_description.rank_initial: Starting rank. By default, it is 1.



 - problem_description.rank_max:     Maximum rank. By default, it is equal to
                                     "problem_description.n".




 - problem_description.params:  Structure containing algorithm
                                parameters (stopping criteria and solver).
       -- params.abstolcost:    Tolerance on the absolute decrease of the
                                cost over the last two iterations.
                                By default, it is 1e-3.
       -- params.reltolcost:    Tolerance on the relative decrease of the
                                cost over the last two iterations.
                                By default, it is 1e-3.
       -- params.tolgradnorm:   Tolerance on the norm of the gradient.
                                By default, it is 1e-5.
       -- params.maxiter:       Maximum number of fixed-rank iterations per rank.
                                By default, it is 100.
       -- params.tolSmin:       Tolerance on the smallest eigenvalue of Sy,
                                the dual variable; the rank search stops
                                once this eigenvalue exceeds tolSmin.
                                By default, it is -1e-3.
       -- params.tolrankdeficiency:   Tolerance on the smallest singular value
                                      of Y; the rank search stops once it falls
                                      below this value. By default, it is 1e-3.
       -- params.solver:        Fixed-rank solver, given as a function handle:
                                @trustregions for trust-regions, @conjugategradient
                                for conjugate gradients, @steepestdescent for
                                steepest descent. By default, it is @trustregions.
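 As a sketch of how the fields above fit together for data other than the default Helix instance, the following builds problem_description from a synthetic point cloud. The point cloud P, the 20% training fraction, and the chosen rank_max and maxiter values are illustrative assumptions. The entries are squared distances, matching the cost evaluated in the source code, and data_train.nentries is set explicitly because the fixed-rank solver reads that field. The call itself needs Manopt, so it is left as a comment.

```matlab
% Hypothetical point cloud standing in for the user's data: n points in 3D.
n = 50;
P = randn(n, 3);

% All pairs (i, j) with i > j, the same convention as the Helix instance.
[I, J] = find(tril(true(n), -1));
sqDists = sum((P(I, :) - P(J, :)).^2, 2);   % squared Euclidean distances

% Reveal a random 20% of the pairs for training; hold out the rest for testing.
npairs = numel(I);
perm = randperm(npairs);
ntrain = floor(0.2 * npairs);
trainIdx = perm(1:ntrain);
testIdx = perm(ntrain+1:end);

problem_description = struct();
problem_description.n = n;
problem_description.data_train.rows = I(trainIdx);
problem_description.data_train.cols = J(trainIdx);
problem_description.data_train.entries = sqDists(trainIdx);
% The fixed-rank solver reads nentries, so setting it explicitly is harmless.
problem_description.data_train.nentries = ntrain;
problem_description.data_test.rows = I(testIdx);
problem_description.data_test.cols = J(testIdx);
problem_description.data_test.entries = sqDists(testIdx);
problem_description.rank_initial = 1;
problem_description.rank_max = 10;           % must not exceed n
problem_description.params.maxiter = 200;    % other params keep their defaults

% With Manopt on the MATLAB path, the solver would then be called as:
% [Y, infos, used_description] = low_rank_dist_completion(problem_description);
```

 After such a call, infos.rank_change_stats(k).Y holds the iterate obtained at the k-th rank tried, as stored by the main loop.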


 Output:
 --------

   Y:                    n-by-r solution matrix of rank r.
   infos:                Structure with the computed statistics (time, cost, rank, test_error, rank_change_stats).
   problem_description:  The problem description actually used, with defaults filled in.
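 The returned Y determines every pairwise squared distance, not only the observed ones. A minimal sketch, with a random Y standing in for the solver output: it forms the full matrix of estimated squared distances and checks it against the pairwise formula the solver uses. It also illustrates that these distances are unchanged when Y is rotated and translated, so Y recovers the configuration only up to such a rigid motion. All variable names are hypothetical.

```matlab
% Stand-in for the solver output: any n-by-r matrix works for this sketch.
n = 6; r = 2;
Y = randn(n, r);

% Full matrix of estimated squared distances, using
% ||y_i - y_j||^2 = ||y_i||^2 + ||y_j||^2 - 2*y_i'*y_j.
sqNorms = sum(Y.^2, 2);
D = bsxfun(@plus, sqNorms, sqNorms') - 2 * (Y * Y');
D = max(D, 0);         % clip tiny negative values caused by round-off
D(1:n+1:end) = 0;      % exact zeros on the diagonal

% The solver itself only evaluates listed pairs, in this way:
rows = [2; 5; 6];
cols = [1; 3; 4];
estimDists = sum((Y(rows, :) - Y(cols, :)).^2, 2);

% A rotation Q plus a translation leaves all distances unchanged.
[Q, ~] = qr(randn(r));
Yrot = Y * Q + ones(n, 1) * randn(1, r);
estimRot = sum((Yrot(rows, :) - Yrot(cols, :)).^2, 2);
```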



 Please cite the Manopt paper as well as the research paper:
     @InProceedings{mishra2011dist,
       Title        = {Low-rank optimization for distance matrix completion},
       Author       = {Mishra, B. and Meyer, G. and Sepulchre, R.},
       Booktitle    = {{50th IEEE Conference on Decision and Control}},
       Year         = {2011},
       Organization = {{IEEE CDC}}
     }


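SUBFUNCTIONS

 function val = cost_evaluation(Y, data_train)
 function localdefaults = getlocaldefaults()
 function [Yopt, infos] = low_rank_dist_completion_fixedrank(data_train, data_test, Y_initial, params)
 function problem_description = get_3d_Helix_instance()
 function checked_problem_description = check_problem_description(problem_description)
 function show_plots(problem_description, infos)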

SOURCE CODE

0001 function [Y, infos, problem_description] =  low_rank_dist_completion(problem_description)
0002 % Perform low-rank distance matrix completion w/ automatic rank detection.
0003 %
0004 % function Y = low_rank_dist_completion(problem_description)
0005 % function [Y, infos, out_problem_description] = low_rank_dist_completion(problem_description)
0006 %
0007 % It implements the ideas of Journee, Bach, Absil and Sepulchre, SIOPT, 2010,
0008 % applied to the problem of low-rank Euclidean distance matrix completion.
0009 % The details are in the paper "Low-rank optimization for distance matrix completion",
0010 % B. Mishra, G. Meyer, and R. Sepulchre, IEEE CDC, 2011.
0011 %
0012 % Paper link: http://arxiv.org/abs/1304.6663.
0013 %
0014 % Input:
0015 % -------
0016 %
0017 % problem_description: Structure describing the problem instance.
0018 %
0019 %
0020 % - problem_description.data_train: Data structure for known squared distances that are used to learn a low-rank model.
0021 %                                   It contains the 3 fields that are shown
0022 %                                   below. An empty "data_train" structure
0023 %                                   will generate the 3d Helix instance.
0024 %
0025 %       -- data_train.entries:      A column vector consisting of known squared
0026 %                                   distances. An empty "data_train.entries"
0027 %                                   field will generate the 3d Helix
0028 %                                   instance.
0029 %
0030 %       -- data_train.rows:         The row positions of the corresponding
0031 %                                   distances. An empty "data_train.rows"
0032 %                                   field will generate the 3d Helix
0033 %                                   instance.
0034 %
0035 %       -- data_train.cols:         The column positions of the corresponding
0036 %                                   distances. An empty "data_train.cols"
0037 %                                   field will generate the 3d Helix
0038 %                                   instance.
0039 %
0040 %
0041 %
0042 % - problem_description.data_test:  Data structure of squared distances that are "unknown" to the algorithm, used to compute the test error.
0043 %                                   It contains the 3 fields that are shown
0044 %                                   below. An empty "data_test" structure
0045 %                                   disables the test error computation.
0046 %
0047 %       -- data_test.entries:       A column vector of squared distances that
0048 %                                   are "unknown" to the algorithm. An empty
0049 %                                   "data_test.entries" field disables the test error.
0050 %       -- data_test.rows:          The row positions of the corresponding
0051 %                                   distances. An empty "data_test.rows"
0052 %                                   field disables the test error.
0053 %       -- data_test.cols:          The column positions of the corresponding
0054 %                                   distances. An empty "data_test.cols"
0055 %                                   field disables the test error.
0056 %
0057 %
0058 %
0059 % - problem_description.n:          The number of data points. If "n" is
0060 %                                   missing or empty while "data_train" is
0061 %                                   given, an error is raised to avoid
0062 %                                   potential data inconsistency.
0063 %
0064 %
0065 %
0066 %
0067 %
0068 % - problem_description.rank_initial: Starting rank. By default, it is 1.
0069 %
0070 %
0071 %
0072 % - problem_description.rank_max:     Maximum rank. By default, it is equal to
0073 %                                     "problem_description.n".
0074 %
0075 %
0076 %
0077 %
0078 % - problem_description.params:  Structure containing algorithm
0079 %                                parameters (stopping criteria and solver).
0080 %       -- params.abstolcost:    Tolerance on the absolute decrease of the
0081 %                                cost over the last two iterations.
0082 %                                By default, it is 1e-3.
0083 %       -- params.reltolcost:    Tolerance on the relative decrease of the
0084 %                                cost over the last two iterations.
0085 %                                By default, it is 1e-3.
0086 %       -- params.tolgradnorm:   Tolerance on the norm of the gradient.
0087 %                                By default, it is 1e-5.
0088 %       -- params.maxiter:       Maximum number of fixed-rank iterations per rank.
0089 %                                By default, it is 100.
0090 %       -- params.tolSmin:       Tolerance on the smallest eigenvalue of Sy,
0091 %                                the dual variable; the rank search stops
0092 %                                once this eigenvalue exceeds tolSmin.
0093 %                                By default, it is -1e-3.
0094 %       -- params.tolrankdeficiency:   Tolerance on the smallest singular value
0095 %                                      of Y; the rank search stops once it falls
0096 %                                      below this value. By default, it is 1e-3.
0097 %       -- params.solver:        Fixed-rank solver, given as a function handle:
0098 %                                @trustregions for trust-regions, @conjugategradient
0099 %                                for conjugate gradients, @steepestdescent for
0100 %                                steepest descent. By default, it is @trustregions.
0101 %
0102 %
0103 % Output:
0104 % --------
0105 %
0106 %   Y:                    n-by-r solution matrix of rank r.
0107 %   infos:                Structure with the computed statistics (time, cost, rank, test_error, rank_change_stats).
0108 %   problem_description:  The problem description actually used, with defaults filled in.
0109 %
0110 %
0111 %
0112 % Please cite the Manopt paper as well as the research paper:
0113 %     @InProceedings{mishra2011dist,
0114 %       Title        = {Low-rank optimization for distance matrix completion},
0115 %       Author       = {Mishra, B. and Meyer, G. and Sepulchre, R.},
0116 %       Booktitle    = {{50th IEEE Conference on Decision and Control}},
0117 %       Year         = {2011},
0118 %       Organization = {{IEEE CDC}}
0119 %     }
0120 
0121 
0122 % This file is part of Manopt: www.manopt.org.
0123 % Original author: Bamdev Mishra, April 06, 2015.
0124 % Contributors: Nicolas Boumal.
0125 % Change log:
0126 %   August 30 2016 (BM):
0127 %                   Corrected some logic flaws while plotting and storing
0128 %                   rank information. A typo was also corrected.
0129 %   August 20 2021 (XJ):
0130 %                   Added AD to compute the egrad and the ehess
0131     
0132     % Check problem description
0133     if ~exist('problem_description', 'var')
0134         problem_description = struct();
0135     end
0136     problem_description = check_problem_description(problem_description); % Check the problem description;
0137     
0138     
0139     % Common quantities
0140     data_train = problem_description.data_train;
0141     data_test =  problem_description.data_test;
0142     n =  problem_description.n;
0143     rank_initial = problem_description.rank_initial;
0144     rank_max =  problem_description.rank_max;
0145     params =  problem_description.params;
0146     N = data_train.nentries; % Number of known distances
0147     EIJ = speye(n);
0148     EIJ = EIJ(:, data_train.rows) - EIJ(:, data_train.cols);
0149     rr = rank_initial; % Starting rank.
0150     Y = randn(n, rr); % Random starting initialization.
0151     
0152     
0153     % Information
0154     time = [];               % Time for each iteration per rank
0155     cost = [];               % Cost at each iteration per rank
0156     test_error = [];         % Test error at each iteration per rank
0157     rank = [];               % Rank at each iteration
0158     rank_change_stats = [];  % Some stats relating the change of ranks
0159     
0160     
0161     
0162     % Main loop rank search
0163     rank_search = 0;
0164     while (rr <= rank_max) % When r = n a global min is attained for sure.
0165         rank_search = rank_search + 1;
0166         
0167         fprintf('>> Rank %d <<\n', rr);
0168         
0169         % Follow the descent direction to compute an iterate in a higher dimension
0170         if (rr > rank_initial)
0171             if isempty(restartDir) % If no restart dir avail. do a random restart
0172                 disp('No restart dir available, random restart is performed');
0173                 Y = randn(n, rr);
0174                 
0175             else % Perform a simple line-search based on the restart direction
0176                 disp('>> Line-search with restart direction');
0177                 Y(:, rr) = 0; % Append a column of zeroes
0178                 
0179                 Z = Y(data_train.rows, :) - Y(data_train.cols,:);
0180                 estimDists = sum(Z.^2, 2);
0181                 errors = (estimDists - data_train.entries);
0182                 costBefore = 0.5*mean(errors.^2);
0183                 fprintf('>> Cost before = %f\n',costBefore);
0184                 
0185                 % Simple linesearch to maintain monotonicity
0186                 problem.M = symfixedrankYYfactory(n, rr);
0187                 problem.cost = @(Y)  cost_evaluation(Y, data_train);
0188                 d = zeros(size(Y));
0189                 d(:, rr) = restartDir;
0190                 [unused, Y] = linesearch_decrease(problem, Y, d, costBefore); %#ok<ASGLU>
0191                 
0192                 Z = Y(data_train.rows, :) - Y(data_train.cols,:);
0193                 estimDists = sum(Z.^2, 2);
0194                 errors = (estimDists - data_train.entries);
0195                 costAfter = 0.5*mean(errors.^2);
0196                 
0197                 % Check for decrease
0198                 if costAfter >= costBefore - 1e-8
0199                     disp('Decrease is not sufficient, random restart');
0200                     Y = randn(n, rr);
0201                 end
0202                 
0203             end
0204             
0205         end
0206         
0207         % Fixed-rank optimization with Manopt
0208         [Y, infos_fixedrank] = low_rank_dist_completion_fixedrank(data_train, data_test, Y, params);
0209 
0210         % Some info logging
0211         thistime = [infos_fixedrank.time];
0212         if ~isempty(time)
0213             thistime = time(end) + thistime;
0214         end
0215         time = [time thistime]; %#ok<AGROW>
0216         cost = [cost [infos_fixedrank.cost]]; %#ok<AGROW>
0217         rank = [rank [infos_fixedrank.rank]]; %#ok<AGROW>
0218         rank_change_stats(rank_search).rank = rr; %#ok<AGROW>
0219         rank_change_stats(rank_search).iter = length([infos_fixedrank.cost]); %#ok<AGROW>
0220         rank_change_stats(rank_search).Y = Y; %#ok<AGROW>
0221         if isfield(infos_fixedrank, 'test_error')
0222             test_error = [test_error [infos_fixedrank.test_error]]; %#ok<AGROW>
0223         end
0224         
0225         
0226         % Evaluate gradient of the convex cost function (i.e. wrt X).
0227         Z = Y(data_train.rows, :) - Y(data_train.cols,:);
0228         estimDists = sum(Z.^2,2);
0229         errors = (estimDists - data_train.entries);
0230         
0231       
0232         % Dual variable and its minimum eigenvalue that is used to guarantee convergence.
0233         Sy = (0.5)*EIJ * sparse(1:N,1:N,2 * errors / N,N,N) * EIJ'; % "0.5" comes from 0.5 in cost evaluation
0234         
0235         
0236         % Compute smallest algebraic eigenvalue of Sy,
0237         % this gives us a descent direction for the next rank (v)
0238         % as well as a way to control progress toward the global
0239         % optimum (s_min).
0240         
0241         % Make eigs silent.
0242         opts.disp = 0;
0243         [v, s_min] = eigs(Sy, 1, 'SA', opts);
0244         
0245         
0246         % Check whether Y is rank deficient.
0247         vp = svd(Y);
0248         
0249         % Stopping criterion.
0250         fprintf('>> smin = %.3e, and min(vp) = %.3e\n',s_min,min(vp));
0251         if (s_min  > params.tolSmin) || (min(vp) < params.tolrankdeficiency)
0252             break;
0253         end
0254         
0255         % Update rank
0256         rr = rr + 1;
0257         
0258         % Compute descent direction
0259         if (s_min < -1e-10)
0260             restartDir = v;
0261         else
0262             restartDir = [];
0263         end
0264     end
0265     
0266     
0267     % Collect relevant statistics
0268     infos.time = time;
0269     infos.cost = cost;
0270     infos.rank = rank;
0271     infos.test_error = test_error;
0272     infos.rank_change_stats = rank_change_stats;
0273     
0274     % Few plots.
0275     show_plots(problem_description, infos);
0276     
0277 end
0278 
0279 
0280 
0281 
0282 %% Cost function evaluation.
0283 function val = cost_evaluation(Y, data_train)
0284     Z = Y(data_train.rows, :) - Y(data_train.cols,:);
0285     estimDists = sum(Z.^2, 2);
0286     errors = (estimDists - data_train.entries);
0287     val = 0.5*mean(errors.^2);
0288 end
0289 
0290 
0291 
0292 
0293 %% Local defaults
0294 function localdefaults = getlocaldefaults()
0295     localdefaults.abstolcost = 1e-3;
0296     localdefaults.reltolcost = 1e-3;
0297     localdefaults.tolSmin = -1e-3;
0298     localdefaults.tolrankdeficiency = 1e-3;
0299     localdefaults.tolgradnorm = 1e-5;
0300     localdefaults.maxiter = 100;
0301     localdefaults.solver = @trustregions; % Trust-regions
0302 end
0303 
0304 
0305 
0306 
0307 
0308 
0309 
0310 %% Fixed-rank optimization
0311 function [Yopt, infos] = low_rank_dist_completion_fixedrank(data_train, data_test, Y_initial, params)
0312     % Common quantities that are used often in the optimization process.
0313     [n, r] = size(Y_initial);
0314     EIJ = speye(n);
0315     EIJ = EIJ(:, data_train.rows) - EIJ(:, data_train.cols);
0316     
0317     % Create problem structure
0318     problem.M = symfixedrankYYfactory(n,  r);
0319     
0320     
0321     % Cost evaluation
0322     problem.cost = @cost;
0323     function [f, store] = cost(Y, store)
0324         if ~isfield(store, 'xij')
0325             store.xij = EIJ'*Y;
0326         end
0327         xij = store.xij;
0328         estimDists = sum(xij.^2,2);
0329         f = 0.5*mean((estimDists - data_train.entries).^2);
0330     end
0331     
0332     % Gradient evaluation
0333     problem.grad = @grad;
0334     function [g, store] = grad(Y, store)
0335         N = data_train.nentries;
0336         if ~isfield(store, 'xij')
0337             store.xij = EIJ'*Y;
0338         end
0339         xij = store.xij;
0340         estimDists = sum(xij.^2,2);
0341         g = EIJ * sparse(1:N,1:N,2 * (estimDists - data_train.entries) / N, N, N) * xij;
0342     end
0343     
0344     
0345     % Hessian evaluation
0346     problem.hess = @hess;
0347     function [Hess, store] = hess(Y, eta, store)
0348         N = data_train.nentries;
0349         if ~isfield(store, 'xij')
0350             store.xij = EIJ'*Y;
0351         end
0352         xij = store.xij;
0353         zij = EIJ'*eta;
0354         estimDists = sum(xij.^2,2);
0355         crossYZ = 2*sum(xij .* zij,2);
0356         Hess = (EIJ*sparse(1:N,1:N,2 * (estimDists - data_train.entries) / N,N,N))*zij + (EIJ*sparse(1:N,1:N,2 * crossYZ / N,N,N))*xij;
0357         Hess = problem.M.proj(Y, Hess);
0358     end
0359     
0360     % An alternative way to compute the egrad and the ehess is to use
0361     % automatic differentiation provided in the deep learning toolbox (slower)
0362     % problem.cost = @cost_AD;
0363     %    function f = cost_AD(Y)
0364     %        xij = EIJ'*Y;
0365     %        estimDists = sum(xij.^2,2);
0366     %        f = 0.5*mean((estimDists - data_train.entries).^2);
0367     %    end
0368     % call manoptAD to prepare AD for the problem structure
0369     % problem = manoptAD(problem);
0370     
0371     %     % Check numerically whether gradient and Hessian are correct
0372     %     checkgradient(problem);
0373     %     drawnow;
0374     %     pause;
0375     %     checkhessian(problem);
0376     %     drawnow;
0377     %     pause;
0378     
0379     
0380     % When asked, ask Manopt to compute the test error at every iteration.
0381     if ~isempty(data_test)
0382         options.statsfun = @compute_test_error;
0383         EIJ_test = speye(n);
0384         EIJ_test = EIJ_test(:, data_test.rows) - EIJ_test(:, data_test.cols);
0385     end
0386     function stats = compute_test_error(problem, Y, stats) %#ok<INUSL>
0387         xij = EIJ_test'*Y;
0388         estimDists_test = sum(xij.^2,2);
0389         stats.test_error = 0.5*mean((estimDists_test - data_test.entries).^2);
0390     end
0391     
0392     
0393     % Stopping criteria options
0394     options.stopfun = @mystopfun;
0395     function stopnow = mystopfun(problem, Y, info, last) %#ok<INUSL>
0396         stopnow = (last >= 5 && (info(last-2).cost - info(last).cost < params.abstolcost || abs(info(last-2).cost - info(last).cost)/info(last).cost < params.reltolcost));
0397     end
0398     options.tolgradnorm = params.tolgradnorm;
0399     options.maxiter = params.maxiter;
0400     
0401     
0402     % Call appropriate algorithm
0403     options.solver = params.solver;
0404     [Yopt, ~, infos] = manoptsolve(problem, Y_initial, options);
0405     [infos.rank] = deal(r);
0406 end
0407 
0408 
0409 
0410 
0411 
0412 
0413 %% 3d Helix problem instance
0414 function problem_description = get_3d_Helix_instance()
0415     
0416     % Helix curve in 3d
0417     tvec = 0:2*pi/100:2*pi;
0418     tvec = tvec'; % column vector
0419     xvec = 4*cos(3*tvec);
0420     yvec = 4*sin(3*tvec);
0421     zvec = 2*tvec;
0422     Yo = [xvec, yvec, zvec];
0423     n = size(Yo, 1); % Number of points
0424     
0425     % Fraction of unknown distances
0426     fractionOfUnknown = 0.85;
0427     
0428     % True distances among points in 3d Helix.
0429     % The pdist function is part of the Statistics and ML toolbox.
0430     trueDists = pdist(Yo)'.^2; % True distances
0431     
0432     
0433     % Add noise (set noise_level = 0 for clean measurements)
0434     noise_level = 0; % 0.01;
0435     trueDists = trueDists + noise_level * std(trueDists) * randn(size(trueDists));
0436     
0437     
0438     % Compute all pairs of indices
0439     H = tril(true(n), -1);
0440     [I, J] = ind2sub([n, n], find(H(:)));
0441     clear 'H';
0442     
0443     
0444     % Train data
0445     train = false(length(trueDists), 1);
0446     train(1:floor(length(trueDists)*(1- fractionOfUnknown))) = true;
0447     train = train(randperm(length(train)));
0448     
0449     data_train.rows = I(train);
0450     data_train.cols = J(train);
0451     data_train.entries = trueDists(train);
0452     data_train.nentries = length(data_train.entries);
0453     
0454     
0455     % Test data
0456     data_test.nentries = 1*data_train.nentries; % Adjust to the amount of test data that can be handled.
0457     unknown = find(~train); % Only pairs not used for training are eligible
0458     test = false(length(trueDists),1);
0459     test(unknown(randperm(length(unknown), min(data_test.nentries, length(unknown))))) = true;
0460     data_test.rows = I(test);
0461     data_test.cols = J(test);
0462     data_test.entries = trueDists(test);
0463     
0464     
0465     % Rank bounds
0466     rank_initial = 1; % Starting rank
0467     rank_max = n; % Maximum rank
0468     
0469     
0470     % Basic parameters used in optimization
0471     params = struct();
0472     params = mergeOptions(getlocaldefaults, params);
0473     
0474     
0475     % Problem description
0476     problem_description.data_train = data_train;
0477     problem_description.data_test = data_test;
0478     problem_description.n = n;
0479     problem_description.rank_initial = rank_initial;
0480     problem_description.rank_max = rank_max;
0481     problem_description.params = params;
0482     problem_description.Yo = Yo; % Store original Helix structure
0483 end
0484 
0485 
0486 
0487 
0488 
0489 %% Problem description check
0490 function checked_problem_description = check_problem_description(problem_description)
0491     checked_problem_description = problem_description;
0492     
0493     % Check train data
0494     if isempty(problem_description)...
0495             || ~all(isfield(problem_description,{'data_train'}) == 1)...
0496             || ~all(isfield(problem_description.data_train,{'cols', 'rows', 'entries'}) == 1)...
0497             || isempty(problem_description.data_train.cols)...
0498             || isempty(problem_description.data_train.rows)...
0499             || isempty(problem_description.data_train.entries)
0500         
0501         fprintf(['The training set is empty or not properly defined.\n' ...
0502                  'We work with the default 3d Helix example.\n']);
0503         checked_problem_description = get_3d_Helix_instance();
0504         checked_problem_description.helix_example = true;
0505         return; % No need for further check
0506     end
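0507     % The number of known distances is derived from the entries (used by cost, gradient, and Hessian).
0508     checked_problem_description.data_train.nentries = length(problem_description.data_train.entries);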
0509     % Check number of data points
0510     if ~isfield(problem_description, 'n') || isempty(problem_description.n)
0511         error('low_rank_dist_completion:problem_description',...
0512             'The field "n" (number of data points) of the problem description must be given.');
0513     end
0514     
0515     
0516     % Check initial rank
0517     if ~isfield(problem_description, 'rank_initial')...
0518             || isempty(problem_description.rank_initial)...
0519             || ~(floor(problem_description.rank_initial) == problem_description.rank_initial)
0520         warning('low_rank_dist_completion:problem_description', ...
0521             'The field "rank_initial" is not properly defined. We work with the default "1".\n');
0522         rank_initial = 1;
0523     else
0524         rank_initial = problem_description.rank_initial;
0525     end
0526     checked_problem_description.rank_initial = rank_initial;
0527     
0528     
0529     % Check maximum rank
0530     if ~isfield(problem_description, 'rank_max')...
0531             || isempty(problem_description.rank_max)...
0532             || ~(floor(problem_description.rank_max) == problem_description.rank_max)...
0533             || problem_description.rank_max > problem_description.n
0534         warning('low_rank_dist_completion:problem_description', ...
0535             'The field "rank_max" is not properly defined. We work with the default "n".\n');
0536         rank_max = problem_description.n;
0537     else
0538         rank_max = problem_description.rank_max;
0539     end
0540     checked_problem_description.rank_max = rank_max;
0541     
0542     
0543     % Check testing dataset
0544     if ~isfield(problem_description,{'data_test'})...
0545             || ~all(isfield(problem_description.data_test,{'cols', 'rows', 'entries'}) == 1)...
0546             || isempty(problem_description.data_test.cols)...
0547             || isempty(problem_description.data_test.rows)...
0548             || isempty(problem_description.data_test.entries)
0549         
0550         warning('low_rank_dist_completion:problem_description', ...
0551             'The field "data_test" is not properly defined. We work with the default "[]".\n');
0552         data_test = [];
0553     else
0554         data_test = problem_description.data_test;
0555     end
0556     checked_problem_description.data_test = data_test;
0557     
0558     
0559     % Check parameters
0560     if isfield(problem_description, 'params')
0561         params = problem_description.params;
0562     else
0563         params = struct();
0564     end
0565     params = mergeOptions(getlocaldefaults, params);
0566     checked_problem_description.params = params;
0567      
0568 end
0569 
0570 
0571 
0572 
0573 %% Show plots
0574 function  show_plots(problem_description, infos)
0575    
0576     solver = problem_description.params.solver;
0577     rank_change_stats = infos.rank_change_stats;
0578     rank_change_stats_rank = [rank_change_stats.rank];
0579     rank_change_stats_iter = [rank_change_stats.iter];
0580     rank_change_stats_iter = cumsum(rank_change_stats_iter);
0581     N = problem_description.data_train.nentries;
0582     n = problem_description.n;
0583     
0584    
0585     % Plot: train error
0586     fs = 20;
0587     figure('name', 'Training on the known distances');
0588     
0589     line(1:length([infos.cost]),log10([infos.cost]),'Marker','O','LineStyle','-','Color','blue','LineWidth',1.5);
0590     ax1 = gca;
0591     
0592     set(ax1,'FontSize',fs);
0593     xlabel(ax1,'Number of iterations','FontSize',fs);
0594     ylabel(ax1,'Cost (log scale) on known distances','FontSize',fs);
0595     
0596     ax2 = axes('Position',get(ax1,'Position'),...
0597         'XAxisLocation','top',...
0598         'YAxisLocation','right',...
0599         'Color','none',...
0600         'XColor','k');
0601     
0602     set(ax2,'FontSize',fs);
0603     line(1:length([infos.cost]),log10([infos.cost]),'Marker','O','LineStyle','-','Color','blue','LineWidth',1.5,'Parent',ax2);
0604     set(ax2,'XTick',rank_change_stats_iter(1:max(1,end-1)),...
0605         'XTickLabel',rank_change_stats_rank(1) + 1 : rank_change_stats_rank(max(1,end-1)) + 1,...
0606         'YTick',[]);
0607     
0608     set(ax2,'XGrid','on');
0609     legend(func2str(solver));
0610     title('Rank');
0611     legend 'boxoff';
0612     
0613     
0614     % Plot: test error
0615     if isfield(infos, 'test_error') && ~isempty(infos.test_error)
0616
0617
0618         fs = 20;
0619         figure('name','Test error on a set of distances different from the training set');
0620
0621         line(1:length([infos.test_error]),log10([infos.test_error]),'Marker','O','LineStyle','-','Color','blue','LineWidth',1.5);
0622         ax1 = gca;
0623
0624         set(ax1,'FontSize',fs);
0625         xlabel(ax1,'Number of iterations','FontSize',fs);
0626         ylabel(ax1,'Cost (log scale) on testing set','FontSize',fs);
0627
0628         ax2 = axes('Position',get(ax1,'Position'),...
0629             'XAxisLocation','top',...
0630             'YAxisLocation','right',...
0631             'Color','none',...
0632             'XColor','k');
0633
0634         set(ax2,'FontSize',fs);
0635         line(1:length([infos.test_error]),log10([infos.test_error]),'Marker','O','LineStyle','-','Color','blue','LineWidth',1.5,'Parent',ax2);
0636         set(ax2,'XTick',rank_change_stats_iter(1:max(1,end-1)),...
0637             'XTickLabel',rank_change_stats_rank(1) + 1 : rank_change_stats_rank(max(1,end-1)) + 1,...
0638             'YTick',[]);
0639
0640         set(ax2,'XGrid','on');
0641         legend(func2str(solver));
0642         title('Rank');
0643         legend 'boxoff';
0644
0645
0646
0647     end
0648
0649
0650     % Plot: visualize Helix curve
0651     if isfield(problem_description, 'helix_example')
0652         jj = ceil((length(rank_change_stats_rank) + 1)/2);
0653         Yo = problem_description.Yo; % Original Helix points (only the Helix instance stores them)
0654         
0655         figure('name',['3D structure with ', num2str(N/((n^2 -n)/2)),' fraction known distances'])
0656         fs = 20;
0657         ax1 = gca;
0658         set(ax1,'FontSize',fs);
0659         subplot(jj,2,1);
0660         plot3(Yo(:,1), Yo(:,2), Yo(:,3),'*','Color', 'b','LineWidth',1.0);
0661         title('Original 3D structure');
0662         for kk = 1 : length(rank_change_stats_rank)
0663             subplot(jj, 2, kk + 1);
0664             rank_change_stats_kk = rank_change_stats(kk);
0665             Ykk = rank_change_stats_kk.Y;
0666             if size(Ykk, 2) == 1
0667                 plot(Ykk(:,1), zeros(size(Ykk, 1), 1),'*','Color', 'r','LineWidth',1.0);
0668                 legend(func2str(solver))
0669                 title(['Recovery at rank ',num2str(size(Ykk, 2))]);
0670                 
0671             elseif size(Ykk, 2) == 2
0672                 plot(Ykk(:,1), Ykk(:,2),'*','Color', 'r','LineWidth',1.0);
0673                 title(['Recovery at rank ',num2str(size(Ykk, 2))]);
0674                 
0675             else  % Project onto the dominant 3D subspace
0676                 [U1, S1, ~] = svds(Ykk, 3);
0677                 Yhat = U1*S1; % Coordinates of the points in that subspace
0678                 plot3(Yhat(:,1), Yhat(:,2), Yhat(:,3),'*','Color', 'r','LineWidth',1.0);
0679                 title(['Recovery at rank ',num2str(size(Ykk, 2))]);
0680             end
0681             
0682             axis equal;
0683             
0684         end
0685         
0686         % Trick to add a global title to the whole subplot collection.
0687         % HitTest is disabled to make it easier to select the individual
0688         % subplots (for example, to rotate the viewing angle).
0689         ha = axes('Position',[0 0 1 1],'Xlim',[0 1],'Ylim',[0 1],'Box','off','Visible','off','Units','normalized', 'clipping' , 'off' );
0690         set(ha, 'HitTest', 'off');
0691         text(0.5, 1,['Recovery of Helix from ',num2str(N/((n^2 -n)/2)),' fraction known distances'],'HorizontalAlignment','center','VerticalAlignment', 'top');
0692     end
0693        
0694 end
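The rank-increase test in the main loop can be reproduced outside Manopt on a small instance. The sketch below builds the incidence matrix EIJ as the listing does, evaluates the residuals at a random rank-1 iterate, forms Sy, and checks numerically that the Euclidean gradient computed as in the grad function equals 2*Sy*Y. It uses spdiags instead of sparse(1:N,1:N,...) and a dense eig instead of eigs(Sy, 1, 'SA', opts); both are swaps made only for this small example, and the point cloud P is hypothetical.

```matlab
% Hypothetical small instance: n points in 2D, with every pair observed.
n = 8;
P = randn(n, 2);
[rows, cols] = find(tril(true(n), -1));
entries = sum((P(rows, :) - P(cols, :)).^2, 2);   % known squared distances
N = numel(entries);

% Incidence matrix as in the listing: column k equals e_rows(k) - e_cols(k).
EIJ = speye(n);
EIJ = EIJ(:, rows) - EIJ(:, cols);

% At the true configuration every residual vanishes (so Sy would be zero).
errors_true = sum((EIJ' * P).^2, 2) - entries;

% Evaluate at a random rank-1 iterate, deliberately below the true rank 2.
Y = randn(n, 1);
xij = EIJ' * Y;                        % row k holds y_rows(k) - y_cols(k)
errors = sum(xij.^2, 2) - entries;     % residuals on the squared distances

% Dual variable Sy = (1/N) * sum_k errors(k) * EIJ(:,k) * EIJ(:,k)'.
Sy = EIJ * spdiags(errors / N, 0, N, N) * EIJ';

% Euclidean gradient of f(Y) = 0.5*mean(errors.^2), computed as in grad.
g = EIJ * spdiags(2 * errors / N, 0, N, N) * xij;

% Smallest eigenpair of Sy; a dense eig is enough at this size.
[V, Lambda] = eig(full(Sy));
[s_min, k] = min(diag(Lambda));
v = V(:, k);   % candidate restart direction for the new column

% As in the main loop: a clearly negative s_min calls for a higher rank.
increase_rank = s_min < -1e-10;
```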

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005