Home > manopt > autodiff > basic_examples_AD > complextest_AD2.m

complextest_AD2

PURPOSE ^

Test AD for a complex optimization problem on a power manifold (cell)

SYNOPSIS ^

function complextest_AD2()

DESCRIPTION ^

 Test AD for a complex optimization problem on a power manifold (cell)

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SOURCE CODE ^

function complextest_AD2()
% Test AD for a complex optimization problem on a power manifold (cell)
%
% Minimizes the cost -real(X{1}'*A*X{2}) over pairs {X{1}, X{2}} of
% unit-norm complex vectors in C^n, where A is a random n-by-n Hermitian
% matrix. Gradient and Hessian are generated by automatic differentiation
% (manoptAD) and then checked numerically before solving with trustregions.
%
% Since max real(x1'*A*x2) over unit vectors x1, x2 equals the largest
% singular value of A, the optimal cost is -sigma_1(A). The function prints
% how far the cost reached by the solver is from that value.
%
% No inputs, no outputs. Side effects: opens two figures (gradient and
% Hessian checks) and prints to the command window.
%
% Requires Manopt on the Matlab path and the Deep Learning Toolbox (dlarray).

    % Verify that Manopt was indeed added to the Matlab path.
    % The format-string form of error() is used so that '\n' is rendered as
    % a newline; with a single message argument it would be printed verbatim.
    if isempty(which('spherecomplexfactory'))
        error('%s\n%s', 'You should first add Manopt to the Matlab path.', ...
              'Please run importmanopt.');
    end

    % Verify that the deep learning toolbox was installed. As above, the
    % message is passed as format + arguments so the newlines are rendered.
    assert(exist('dlarray', 'file') == 2, '%s\n%s\n%s', ...
           'Deep learning toolbox is needed for automatic differentiation.', ...
           'Please install the latest version of the deep learning toolbox and', ...
           'upgrade to Matlab R2021b if possible.');

    % Generate the problem data: a random n-by-n complex Hermitian matrix.
    n = 100;
    A = randn(n) + 1i*randn(n);
    A = .5*(A+A');

    % Create the power manifold: two copies of the complex unit sphere in
    % C^n; points are accessed as cells X{1}, X{2} (see the cost below).
    S = spherecomplexfactory(n);
    problem.M = powermanifold(S, 2); % cell

    % For Matlab R2021b or later, define the problem cost function as usual
    % problem.cost  = @(X) -real(X{1}'*A*X{2});

    % For Matlab R2021a or earlier, translate the cost function into a
    % particular format with the basic functions in /functions_AD
    problem.cost  = @(X) -creal(cprod(cprod(ctransp(X{1}), A), X{2}));

    % Define the gradient and the hessian via automatic differentiation
    problem = manoptAD(problem);

    % Numerically check gradient and Hessian consistency.
    figure;
    checkgradient(problem);
    figure;
    checkhessian(problem);

    % Solve. trustregions already returns the cost at the final iterate,
    % so there is no need to re-evaluate problem.cost afterwards.
    [~, xcost] = trustregions(problem);

    % Test: the optimal cost is -sigma_1(A); svd returns singular values in
    % nonincreasing order, so the first entry is the largest one.
    ground_truth = svd(A);
    distance = abs(ground_truth(1) + xcost);
    fprintf('The distance between the ground truth and the solution is %e \n',distance);

end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005