function [Y, infos, problem_description] = low_rank_dist_completion(problem_description)
% Rank-incremental completion of squared Euclidean distances.
%
% function [Y, infos, problem_description] = low_rank_dist_completion(problem_description)
%
% Starting at rank_initial, an n-by-r matrix Y is optimized at fixed rank r
% so that the squared distances between its rows match the known entries
% ||Y(rows(k), :) - Y(cols(k), :)||^2 ~ entries(k). After each fixed-rank
% solve, the smallest eigenvalue of a matrix built from the residuals is
% inspected; as long as it is below params.tolSmin and Y is not numerically
% rank deficient, the rank is increased by one and the process is repeated,
% up to rank_max.
%
% Input (all fields are validated / defaulted by check_problem_description;
% calling without input runs the built-in 3d helix example):
%   problem_description.data_train   struct with fields rows, cols, entries
%                                    and nentries (known squared distances)
%   problem_description.data_test    same layout, or [] for no test set
%   problem_description.n            number of points (rows of Y)
%   problem_description.rank_initial rank at which the search starts
%   problem_description.rank_max     largest rank that is tried
%   problem_description.params       solver options, see getlocaldefaults
%
% Output:
%   Y                    n-by-r factor returned by the last fixed-rank solve
%   infos.time           solver times, accumulated across ranks
%   infos.cost           cost at every solver iteration, all ranks appended
%   infos.rank           rank in use at every solver iteration
%   infos.test_error     test error per iteration (empty without test set)
%   infos.rank_change_stats
%                        struct array with one element per rank tried,
%                        fields rank, iter (number of iterations) and Y
%   problem_description  the checked problem description actually used

    if ~exist('problem_description', 'var')
        problem_description = struct();
    end
    problem_description = check_problem_description(problem_description);

    % Unpack the checked problem description.
    data_train = problem_description.data_train;
    data_test = problem_description.data_test;
    n = problem_description.n;
    rank_initial = problem_description.rank_initial;
    rank_max = problem_description.rank_max;
    params = problem_description.params;
    N = data_train.nentries;

    % Column k of EIJ is e_{rows(k)} - e_{cols(k)}, so that EIJ'*Y stacks
    % the row differences Y(rows(k), :) - Y(cols(k), :).
    EIJ = speye(n);
    EIJ = EIJ(:, data_train.rows) - EIJ(:, data_train.cols);

    % Random initial point at the initial rank.
    rr = rank_initial;
    Y = randn(n, rr);

    % Statistics accumulated over the whole rank search. The names avoid
    % shadowing the built-in functions rank and cost.
    all_times = [];
    all_costs = [];
    all_test_errors = [];
    all_ranks = [];
    rank_change_stats = [];

    % Direction used to warm start the next rank. It is set at the end of
    % every loop iteration; it is initialised here so that its first use
    % does not depend on the loop ordering.
    restartDir = [];

    rank_search = 0;
    while (rr <= rank_max)
        rank_search = rank_search + 1;

        fprintf('>> Rank %d <<\n', rr);

        % From the second rank on, build an initial point of the new rank.
        if (rr > rank_initial)
            if isempty(restartDir)
                disp('No restart dir available, random restart is performed');
                Y = randn(n, rr);

            else
                disp('>> Line-search with restart direction');
                % Append a zero column: same distances, one more dimension.
                Y(:, rr) = 0;

                costBefore = cost_evaluation(Y, data_train);
                fprintf('>> Cost before = %f\n', costBefore);

                % Line search along the direction that only moves the new
                % column, using a cost-only problem structure.
                problem.M = symfixedrankYYfactory(n, rr);
                problem.cost = @(X) cost_evaluation(X, data_train);
                d = zeros(size(Y));
                d(:, rr) = restartDir;
                [~, Y] = linesearch_decrease(problem, Y, d, costBefore);

                costAfter = cost_evaluation(Y, data_train);

                % Fall back to a random point if the line search did not help.
                if costAfter >= costBefore - 1e-8
                    disp('Decrease is not sufficient, random restart');
                    Y = randn(n, rr);
                end

            end

        end

        % Fixed-rank optimization from the current initial point.
        [Y, infos_fixedrank] = low_rank_dist_completion_fixedrank(data_train, data_test, Y, params);

        % Append the statistics of this run; times continue from the
        % last recorded time of the previous ranks.
        thistime = [infos_fixedrank.time];
        if ~isempty(all_times)
            thistime = all_times(end) + thistime;
        end
        all_times = [all_times thistime]; %#ok<AGROW>
        all_costs = [all_costs [infos_fixedrank.cost]]; %#ok<AGROW>
        all_ranks = [all_ranks [infos_fixedrank.rank]]; %#ok<AGROW>
        rank_change_stats(rank_search).rank = rr;
        rank_change_stats(rank_search).iter = length([infos_fixedrank.cost]);
        rank_change_stats(rank_search).Y = Y;
        if isfield(infos_fixedrank, 'test_error')
            all_test_errors = [all_test_errors [infos_fixedrank.test_error]]; %#ok<AGROW>
        end

        % Residuals on the training distances at the new Y.
        Z = Y(data_train.rows, :) - Y(data_train.cols, :);
        estimDists = sum(Z.^2, 2);
        errors = (estimDists - data_train.entries);

        % Sy = EIJ * diag(errors / N) * EIJ'. Its smallest eigenvalue is
        % the quantity tested against params.tolSmin below (presumably the
        % dual certificate of the convex formulation - confirm against the
        % reference paper).
        Sy = (0.5)*EIJ * sparse(1:N,1:N,2 * errors / N,N,N) * EIJ';

        % Smallest algebraic eigenvalue of Sy and its eigenvector.
        opts.disp = 0;
        [v, s_min] = eigs(Sy, 1, 'SA', opts);

        % Singular values of Y, to detect numerical rank deficiency.
        vp = svd(Y);

        fprintf('>> smin = %.3e, and min(vp) = %.3e\n',s_min,min(vp));
        if (s_min > params.tolSmin) || (min(vp) < params.tolrankdeficiency)
            break;
        end

        rr = rr + 1;

        % Keep the eigenvector as restart direction only if the eigenvalue
        % is clearly negative.
        if (s_min < -1e-10)
            restartDir = v;
        else
            restartDir = [];
        end
    end

    % Collect the statistics.
    infos.time = all_times;
    infos.cost = all_costs;
    infos.rank = all_ranks;
    infos.test_error = all_test_errors;
    infos.rank_change_stats = rank_change_stats;

    show_plots(problem_description, infos);

end
0278
0279
0280
0281
0282
function val = cost_evaluation(Y, data_train)
% Half of the mean squared residual on the known squared distances.
%
% Y           matrix whose rows are the points
% data_train  struct with fields rows, cols (point indices of each pair)
%             and entries (known squared distance of each pair)
% val         0.5 * mean over pairs of
%             (||Y(rows(k), :) - Y(cols(k), :)||^2 - entries(k))^2

    % One row of coordinate differences per known pair.
    rowDiffs = Y(data_train.rows, :) - Y(data_train.cols, :);

    % Residual between the squared distances implied by Y and the data.
    residuals = sum(rowDiffs.^2, 2) - data_train.entries;

    val = 0.5*mean(residuals.^2);
end
0289
0290
0291
0292
0293
function localdefaults = getlocaldefaults()
% Default values of the params structure used by this file.
%
% Fields read in this file:
%   abstolcost, reltolcost  thresholds on the cost decrease in the stopping
%                           criterion of the fixed-rank solver
%   tolSmin                 threshold on the smallest eigenvalue that ends
%                           the rank search
%   tolrankdeficiency       threshold on the smallest singular value of Y
%                           that ends the rank search
%   tolgradnorm, maxiter    forwarded to the solver options
%   solver                  function handle forwarded as options.solver

    localdefaults = struct( ...
        'abstolcost',        1e-3, ...
        'reltolcost',        1e-3, ...
        'tolSmin',           -1e-3, ...
        'tolrankdeficiency', 1e-3, ...
        'tolgradnorm',       1e-5, ...
        'maxiter',           100, ...
        'solver',            @trustregions);
end
0303
0304
0305
0306
0307
0308
0309
0310
function [Yopt, infos] = low_rank_dist_completion_fixedrank(data_train, data_test, Y_initial, params)
% Fixed-rank distance completion, started from Y_initial.
%
% data_train  struct with fields rows, cols, entries, nentries
% data_test   same layout, or empty to skip the test-error statistic
% Y_initial   n-by-r initial point; r is the rank used for this run
% params      struct with fields abstolcost, reltolcost, tolgradnorm,
%             maxiter and solver
%
% Yopt        point returned by manoptsolve
% infos       struct array returned by manoptsolve, with an extra field
%             rank set to r in every element

    [n, r] = size(Y_initial);

    % Column k of EIJ is e_{rows(k)} - e_{cols(k)}.
    EIJ = speye(n);
    EIJ = EIJ(:, data_train.rows) - EIJ(:, data_train.cols);

    % Search space for the n-by-r factor.
    problem.M = symfixedrankYYfactory(n, r);

    % Cost, gradient and Hessian callbacks (nested functions below).
    problem.cost = @cost;
    problem.grad = @grad;
    problem.hess = @hess;

    % Optional test-error statistic, recorded at every iteration.
    if ~isempty(data_test)
        options.statsfun = @compute_test_error;
        EIJ_test = speye(n);
        EIJ_test = EIJ_test(:, data_test.rows) - EIJ_test(:, data_test.cols);
    end

    % Solver options.
    options.stopfun = @mystopfun;
    options.tolgradnorm = params.tolgradnorm;
    options.maxiter = params.maxiter;
    options.solver = params.solver;

    [Yopt, ~, infos] = manoptsolve(problem, Y_initial, options);
    [infos.rank] = deal(r);


    % ---------------------------------------------------------------------
    % Nested helpers. They share EIJ, EIJ_test, data_train, data_test,
    % params and problem with the enclosing function.
    % ---------------------------------------------------------------------

    % Make sure the row differences EIJ'*Y are cached in store.xij.
    function store = cache_differences(Y, store)
        if ~isfield(store, 'xij')
            store.xij = EIJ'*Y;
        end
    end

    % EIJ times the diagonal matrix with entries 2*w/N.
    function W = weighted_incidence(w)
        N = data_train.nentries;
        W = EIJ * sparse(1:N,1:N,2 * w / N,N,N);
    end

    % Cost: half of the mean squared residual on the training distances.
    function [f, store] = cost(Y, store)
        store = cache_differences(Y, store);
        distEstimates = sum(store.xij.^2,2);
        f = 0.5*mean((distEstimates - data_train.entries).^2);
    end

    % Gradient of the cost with respect to Y.
    function [g, store] = grad(Y, store)
        store = cache_differences(Y, store);
        diffsY = store.xij;
        distEstimates = sum(diffsY.^2,2);
        g = weighted_incidence(distEstimates - data_train.entries) * diffsY;
    end

    % Derivative of the gradient along eta, passed through problem.M.proj.
    function [Hess, store] = hess(Y, eta, store)
        store = cache_differences(Y, store);
        diffsY = store.xij;
        diffsEta = EIJ'*eta;
        distEstimates = sum(diffsY.^2,2);
        crossTerm = 2*sum(diffsY .* diffsEta,2);
        Hess = weighted_incidence(distEstimates - data_train.entries)*diffsEta ...
             + weighted_incidence(crossTerm)*diffsY;
        Hess = problem.M.proj(Y, Hess);
    end

    % Statistic: half of the mean squared residual on the test distances.
    function stats = compute_test_error(~, Y, stats)
        diffsTest = EIJ_test'*Y;
        distEstimatesTest = sum(diffsTest.^2,2);
        stats.test_error = 0.5*mean((distEstimatesTest - data_test.entries).^2);
    end

    % Stopping criterion: from the fifth info entry on, stop when the cost
    % decrease over the last two entries is small in absolute or relative
    % terms.
    function stopnow = mystopfun(~, ~, info, last)
        stopnow = false;
        if last >= 5
            recentDecrease = info(last-2).cost - info(last).cost;
            stopnow = (recentDecrease < params.abstolcost ...
                || abs(recentDecrease)/info(last).cost < params.reltolcost);
        end
    end

end
0407
0408
0409
0410
0411
0412
0413
function problem_description = get_3d_Helix_instance()
% Default problem instance: points sampled on a 3d helix.
%
% Builds 101 points on the curve (4cos(3t), 4sin(3t), 2t), t in [0, 2*pi],
% and all their pairwise squared distances. A random 15% of the pairs forms
% the training set; a test set of the same size is drawn at random among
% the remaining pairs, so that training and test pairs never coincide.
%
% Output: problem_description with fields data_train, data_test, n,
% rank_initial (1), rank_max (n), params (local defaults) and Yo (the
% n-by-3 matrix of helix points).

    % Helix points, one per row.
    tvec = (0:2*pi/100:2*pi)';
    xvec = 4*cos(3*tvec);
    yvec = 4*sin(3*tvec);
    zvec = 2*tvec;
    Yo = [xvec, yvec, zvec];
    n = size(Yo, 1);

    % Fraction of the pairwise distances that is hidden from training.
    fractionOfUnknown = 0.85;

    % Index pairs (I, J) with I > J: strict lower triangle, column-major.
    [I, J] = find(tril(true(n), -1));

    % Squared distances, computed directly from the index pairs. This
    % keeps entries and indices aligned by construction and does not
    % require pdist from the Statistics Toolbox.
    trueDists = sum((Yo(I, :) - Yo(J, :)).^2, 2);

    % Optional additive noise (disabled by default).
    noise_level = 0;
    trueDists = trueDists + noise_level * std(trueDists) * randn(size(trueDists));

    % Training set: a random subset of the pairs.
    numPairs = length(trueDists);
    train = false(numPairs, 1);
    train(1:floor(numPairs*(1 - fractionOfUnknown))) = true;
    train = train(randperm(numPairs));

    data_train.rows = I(train);
    data_train.cols = J(train);
    data_train.entries = trueDists(train);
    data_train.nentries = length(data_train.entries);

    % Test set: as many pairs as the training set, drawn among the pairs
    % that are NOT used for training, so the test error is measured on
    % distances the model has not seen.
    heldOut = find(~train);
    heldOut = heldOut(randperm(length(heldOut)));
    numTest = min(data_train.nentries, length(heldOut));
    test = false(numPairs, 1);
    test(heldOut(1:numTest)) = true;

    data_test.nentries = numTest;
    data_test.rows = I(test);
    data_test.cols = J(test);
    data_test.entries = trueDists(test);

    % Rank search range.
    rank_initial = 1;
    rank_max = n;

    % Solver parameters: local defaults only.
    params = struct();
    params = mergeOptions(getlocaldefaults, params);

    % Assemble the problem description.
    problem_description.data_train = data_train;
    problem_description.data_test = data_test;
    problem_description.n = n;
    problem_description.rank_initial = rank_initial;
    problem_description.rank_max = rank_max;
    problem_description.params = params;
    problem_description.Yo = Yo;
end
0484
0485
0486
0487
0488
0489
function checked_problem_description = check_problem_description(problem_description)
% Validate a problem description and fill in defaults.
%
% If the training data is missing or empty, the default 3d helix instance
% is returned (with the extra field helix_example set to true). Otherwise:
%   - data_train.rows/cols/entries are stored as column vectors and
%     data_train.nentries is set to the number of entries (the rest of the
%     file reads it, but the caller is only required to give the three
%     vectors);
%   - field n is mandatory;
%   - rank_initial defaults to 1 and rank_max to n when missing or invalid;
%   - data_test defaults to [] when missing or invalid;
%   - params is completed with the values of getlocaldefaults.
%
% Raises low_rank_dist_completion:problem_description if n is missing or if
% rows, cols and entries of the training set differ in length.

    checked_problem_description = problem_description;

    % Training data: fall back to the helix example if unusable.
    if isempty(problem_description)...
            || ~all(isfield(problem_description,{'data_train'}) == 1)...
            || ~all(isfield(problem_description.data_train,{'cols', 'rows', 'entries'}) == 1)...
            || isempty(problem_description.data_train.cols)...
            || isempty(problem_description.data_train.rows)...
            || isempty(problem_description.data_train.entries)

        fprintf(['The training set is empty or not properly defined.\n' ...
            'We work with the default 3d Helix example.\n']);
        checked_problem_description = get_3d_Helix_instance();
        checked_problem_description.helix_example = true;
        return;
    end

    % Normalize the training data. Column vectors are needed because the
    % residuals are computed as (N-by-1 estimated distances) - entries; a
    % row vector of entries would not give an N-by-1 result.
    data_train = problem_description.data_train;
    data_train.rows = data_train.rows(:);
    data_train.cols = data_train.cols(:);
    data_train.entries = data_train.entries(:);
    if length(data_train.rows) ~= length(data_train.entries)...
            || length(data_train.cols) ~= length(data_train.entries)
        error('low_rank_dist_completion:problem_description',...
            'Error. The fields "rows", "cols" and "entries" of "data_train" must have the same number of elements. \n');
    end
    data_train.nentries = length(data_train.entries);
    checked_problem_description.data_train = data_train;

    % Number of points.
    if ~isfield(problem_description, 'n')
        error('low_rank_dist_completion:problem_description',...
            'Error. The scalar corresponding to field "n" of problem description must be given. \n');
    end

    % Initial rank: must be a positive integer.
    if ~isfield(problem_description, 'rank_initial')...
            || isempty(problem_description.rank_initial)...
            || ~(floor(problem_description.rank_initial) == problem_description.rank_initial)...
            || problem_description.rank_initial < 1
        warning('low_rank_dist_completion:problem_description', ...
            'The field "rank_initial" is not properly defined. We work with the default "1".\n');
        rank_initial = 1;
    else
        rank_initial = problem_description.rank_initial;
    end
    checked_problem_description.rank_initial = rank_initial;

    % Maximum rank: an integer between rank_initial and n. A value below
    % rank_initial would skip the rank search entirely.
    if ~isfield(problem_description, 'rank_max')...
            || isempty(problem_description.rank_max)...
            || ~(floor(problem_description.rank_max) == problem_description.rank_max)...
            || problem_description.rank_max > problem_description.n...
            || problem_description.rank_max < rank_initial
        warning('low_rank_dist_completion:problem_description', ...
            'The field "rank_max" is not properly defined. We work with the default "n".\n');
        rank_max = problem_description.n;
    else
        rank_max = problem_description.rank_max;
    end
    checked_problem_description.rank_max = rank_max;

    % Test data: optional.
    if ~isfield(problem_description,{'data_test'})...
            || ~all(isfield(problem_description.data_test,{'cols', 'rows', 'entries'}) == 1)...
            || isempty(problem_description.data_test.cols)...
            || isempty(problem_description.data_test.rows)...
            || isempty(problem_description.data_test.entries)

        warning('low_rank_dist_completion:problem_description', ...
            'The field "data_test" is not properly defined. We work with the default "[]".\n');
        data_test = [];
    else
        % Same column-vector convention as for the training data.
        data_test = problem_description.data_test;
        data_test.rows = data_test.rows(:);
        data_test.cols = data_test.cols(:);
        data_test.entries = data_test.entries(:);
    end
    checked_problem_description.data_test = data_test;

    % Parameters: user values take precedence over the local defaults.
    if isfield(problem_description, 'params')
        params = problem_description.params;
    else
        params = struct();
    end
    params = mergeOptions(getlocaldefaults, params);
    checked_problem_description.params = params;

end
0569
0570
0571
0572
0573
function show_plots(problem_description, infos)
% Plot the statistics collected during the rank search.
%
% problem_description  checked problem description (reads params.solver,
%                      data_train.nentries, n and, for the helix example,
%                      helix_example and Yo)
% infos                statistics structure (reads cost, test_error and
%                      rank_change_stats)
%
% Figures produced:
%   1. training cost (log10) against the iteration count, with a top axis
%      that marks the iterations at which the rank was increased;
%   2. the same for the test error, when one was recorded;
%   3. for the helix example only, the original points next to the points
%      recovered at each rank.

    solver = problem_description.params.solver;
    rank_change_stats = infos.rank_change_stats;

    % Nothing to plot if no fixed-rank run was recorded.
    if isempty(rank_change_stats)
        return;
    end

    rank_change_stats_rank = [rank_change_stats.rank];
    rank_change_stats_iter = cumsum([rank_change_stats.iter]);
    N = problem_description.data_train.nentries;
    n = problem_description.n;

    % Top-axis ticks: iteration at which each run (but the last) ended,
    % labelled with the rank that was tried next.
    tick_positions = rank_change_stats_iter(1:max(1,end-1));
    tick_labels = rank_change_stats_rank(1) + 1 : rank_change_stats_rank(max(1,end-1)) + 1;

    % Training cost.
    plot_log_curve_with_rank_axis([infos.cost], ...
        'Training on the known distances', ...
        'Cost (log scale) on known distances', ...
        solver, tick_positions, tick_labels);

    % Test error, if any was recorded.
    if isfield(infos, 'test_error') && ~isempty(infos.test_error)
        plot_log_curve_with_rank_axis([infos.test_error], ...
            'Test error on a set of distances different from the training set', ...
            'Cost (log scale) on testing set', ...
            solver, tick_positions, tick_labels);
    end

    % Helix example: compare the original and the recovered structures.
    % Yo is read here, where it is used, because only the helix instance
    % is guaranteed to carry it.
    if isfield(problem_description, 'helix_example') && isfield(problem_description, 'Yo')
        Yo = problem_description.Yo;
        jj = ceil((length(rank_change_stats_rank) + 1)/2);

        figure('name',['3D structure with ', num2str(N/((n^2 -n)/2)),' fraction known distances'])
        fs = 20;
        ax1 = gca;
        set(ax1,'FontSize',fs);
        subplot(jj,2,1);
        plot3(Yo(:,1), Yo(:,2), Yo(:,3),'*','Color', 'b','LineWidth',1.0);
        title('Original 3D structure');
        for kk = 1 : length(rank_change_stats_rank)
            subplot(jj, 2, kk + 1);
            Ykk = rank_change_stats(kk).Y;
            if size(Ykk, 2) == 1
                % Rank 1: points on a line. The ordinate is a vector of
                % zeros (zeros(m) alone would create an m-by-m matrix and
                % one line object per column).
                plot(Ykk(:,1), zeros(size(Ykk, 1), 1),'*','Color', 'r','LineWidth',1.0);
                legend(func2str(solver))
                title(['Recovery at rank ',num2str(size(Ykk, 2))]);

            elseif size(Ykk, 2) == 2
                plot(Ykk(:,1), Ykk(:,2),'*','Color', 'r','LineWidth',1.0);
                title(['Recovery at rank ',num2str(size(Ykk, 2))]);

            else
                % Rank 3 or more: plot the first three columns of the
                % rank-3 truncated SVD reconstruction of Ykk.
                [U1, S1, V1] = svds(Ykk, 3);
                Yhat = U1*S1*V1';
                plot3(Yhat(:,1), Yhat(:,2), Yhat(:,3),'*','Color', 'r','LineWidth',1.0);
                title(['Recovery at rank ',num2str(size(Ykk, 2))]);
            end

            axis equal;

        end

        % Invisible full-figure axes carrying a common title.
        ha = axes('Position',[0 0 1 1],'Xlim',[0 1],'Ylim',[0 1],'Box','off','Visible','off','Units','normalized', 'clipping' , 'off' );
        set(ha, 'HitTest', 'off');
        text(0.5, 1,['Recovery of Helix from ',num2str(N/((n^2 -n)/2)),' fraction known distances'],'HorizontalAlignment','center','VerticalAlignment', 'top');
    end

end


function plot_log_curve_with_rank_axis(values, figure_name, y_label, solver, tick_positions, tick_labels)
% Private helper of show_plots: new figure with log10(values) against the
% iteration index, overlaid with a transparent second axes whose top x-axis
% carries the rank-change ticks.

    fs = 20;
    figure('name', figure_name);

    % Bottom/left axes: the curve itself.
    line(1:length(values),log10(values),'Marker','O','LineStyle','-','Color','blue','LineWidth',1.5);
    ax1 = gca;

    set(ax1,'FontSize',fs);
    xlabel(ax1,'Number of iterations','FontSize',fs);
    ylabel(ax1,y_label,'FontSize',fs);

    % Top/right axes at the same position, transparent background.
    ax2 = axes('Position',get(ax1,'Position'),...
        'XAxisLocation','top',...
        'YAxisLocation','right',...
        'Color','none',...
        'XColor','k');

    set(ax2,'FontSize',fs);
    line(1:length(values),log10(values),'Marker','O','LineStyle','-','Color','blue','LineWidth',1.5,'Parent',ax2);
    set(ax2,'XTick',tick_positions,...
        'XTickLabel',tick_labels,...
        'YTick',[]);

    set(ax2,'XGrid','on');
    legend(func2str(solver));
    title('Rank');
    legend('boxoff');
end