realtest_AD3

PURPOSE

Test AD for a real optimization problem on a manifold whose points are stored in a recursively defined data structure.

SYNOPSIS

function realtest_AD3()

DESCRIPTION

 Test automatic differentiation (AD) for a real optimization problem on a
 manifold whose points are stored in a recursively defined data structure:
 a struct whose fields hold an array and a cell.
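The layout described above can be illustrated without Manopt. The MATLAB sketch below builds, by hand, a hypothetical point with the same shape as a point of problem.M in this test: a struct whose field x is a unit-norm vector and whose field y is a 1-by-1 cell holding another unit-norm vector. It assumes that a power manifold with a single copy stores its point as a 1-by-1 cell, and it evaluates the test's cost function on that point. The sizes and variable names are illustrative only.

```matlab
% Hypothetical hand-built point with the same layout as a point of
% problem.M in realtest_AD3 (built without Manopt, for illustration only).
n = 5;
A = randn(n);
A = .5*(A + A');              % symmetric data matrix, as in the test

u = randn(n, 1);
v = randn(n, 1);
X.x = u / norm(u);            % field x: unit-norm column vector (sphere)
X.y = {v / norm(v)};          % field y: 1-by-1 cell holding a unit vector
                              % (assumed storage for a power manifold
                              % with a single copy)

% The cost reaches the leaves through the struct field and the cell.
cost = @(X) -X.x'*(A*X.y{1});
f = cost(X);

fprintf('cost at this point: %g\n', f);
```

Because both leaves have unit norm, the cost is bounded in absolute value by the largest singular value of A, which is what the test later uses as its ground truth.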

SOURCE CODE

function realtest_AD3()
% Test automatic differentiation (AD) for a real optimization problem on a
% manifold whose points are stored in a recursively defined data structure:
% a struct whose fields hold an array and a cell.

    % Verify that Manopt was indeed added to the MATLAB path.
    if isempty(which('spherefactory'))
        error(['You should first add Manopt to the MATLAB path. ' ...
               'Please run importmanopt.']);
    end

    % Verify that the Deep Learning Toolbox is installed.
    assert(exist('dlarray', 'file') == 2, ...
           ['The Deep Learning Toolbox is needed for automatic ' ...
            'differentiation. Please install the latest version of the ' ...
            'Deep Learning Toolbox and upgrade to MATLAB R2021b if possible.']);

    % Generate the problem data: a random symmetric n-by-n matrix.
    n = 100;
    A = randn(n);
    A = .5*(A+A');

    % Create the manifold. A point is a struct X with two fields:
    % X.x, a unit-norm vector in R^n, and X.y, a 1-by-1 cell holding
    % another unit-norm vector in R^n.
    S = spherefactory(n);
    P = powermanifold(S, 1);                 % points are stored as cells
    X.x = S;
    X.y = P;
    problem.M = productmanifold(X);          % points are stored as structs

    % Define the cost function f(X) = -X.x'*A*X.y{1}. Its minimum over the
    % manifold is minus the largest singular value of A.
    problem.cost = @(X) -X.x'*(A*X.y{1});

    % Define the gradient and the Hessian via automatic differentiation.
    problem = manoptAD(problem);

    % Numerically check gradient and Hessian consistency.
    figure;
    checkgradient(problem);
    figure;
    checkhessian(problem);

    % Solve.
    [x, xcost, info] = trustregions(problem);          %#ok<ASGLU>

    % Compare with the ground truth: svd returns the singular values in
    % decreasing order, so ground_truth(1) is the optimal value of -cost.
    ground_truth = svd(A);
    distance = abs(ground_truth(1) - (-problem.cost(x)));
    fprintf('The distance between the ground truth and the solution is %e\n', distance);

end
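The final comparison rests on a standard fact: for unit vectors x and y, x'*A*y is at most the largest singular value sigma_1 of A (by Cauchy-Schwarz, x'*A*y <= norm(A*y) <= sigma_1), with equality when x and y are the leading left and right singular vectors. The optimal cost is therefore -sigma_1, which is why the test compares -problem.cost(x) with the first entry of svd(A). The standalone MATLAB sketch below, which does not call Manopt, checks this numerically on a random symmetric matrix of the same size; because A is symmetric, sigma_1 also equals the largest eigenvalue in absolute value.

```matlab
% Check, without Manopt, that the leading singular vectors of A attain the
% optimal value that realtest_AD3 uses as its ground truth.
n = 100;
A = randn(n);
A = .5*(A + A');              % symmetric, as in the test

[U, S, V] = svd(A);           % singular values sorted in decreasing order
sigma = diag(S);
x = U(:, 1);                  % leading left singular vector (unit norm)
y = V(:, 1);                  % leading right singular vector (unit norm)

best = x' * A * y;            % equals sigma(1) up to rounding
gap = abs(sigma(1) - best);

% For a symmetric matrix, sigma(1) is the largest eigenvalue magnitude.
lambda = eig(A);
fprintf('sigma_1 = %g, x''*A*y = %g, max|lambda| = %g\n', ...
        sigma(1), best, max(abs(lambda)));
```

A small distance printed by realtest_AD3 therefore indicates that trustregions reached a global minimizer of the cost.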

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005