
ctrace

PURPOSE

Computes the sum of diagonal elements of A.

SYNOPSIS

function traceA = ctrace(A)

DESCRIPTION

 Computes the sum of diagonal elements of A.

 function traceA = ctrace(A)

 Returns the sum of the diagonal elements of A. The input A does not have
 to be square: for an m-by-n input, the sum runs over the min(m, n)
 entries A(k, k). The function supports both numeric arrays and structs
 with fields real and imag. This file was created because trace is not
 currently supported by dlarrays, which we use for automatic
 differentiation: ctrace is a backup function.

 See also: manoptADhelp
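
The description above can be illustrated with a short sketch of the linear-indexing idea the function relies on. It does not call ctrace or any other Manopt function, so it runs in plain MATLAB; the variable names (A, B, Z, idxA, idxB, tA, tB, tZ) are chosen here for illustration only.

```matlab
% Sketch: sum the diagonal of a non-square array with linear indexing.
% In column-major storage, A(k, k) of an m-by-n array sits at linear
% index (k-1)*m + k, so diagonal entries are m+1 apart.

A = [1 2 3; 4 5 6];                 % wide: 2-by-3, diagonal is A(1,1), A(2,2)
[m, n] = size(A);
k = min(m, n);                      % number of diagonal entries
idxA = 1:(m+1):((k-1)*m + k);       % linear indices of the diagonal
tA = sum(A(idxA));
assert(tA == sum(diag(A)));         % agrees with the built-in diag

B = [1 2; 3 4; 5 6];                % tall: 3-by-2, diagonal is B(1,1), B(2,2)
[m, n] = size(B);
k = min(m, n);
idxB = 1:(m+1):((k-1)*m + k);
tB = sum(B(idxB));
assert(tB == sum(diag(B)));

% Struct input with fields real and imag (square case for brevity):
% each part is summed separately, as ctrace does for such structs.
Z.real = [1 2; 3 4];
Z.imag = [5 6; 7 8];
tZ.real = sum(Z.real(1:3:4));       % Z.real(1,1) + Z.real(2,2)
tZ.imag = sum(Z.imag(1:3:4));       % Z.imag(1,1) + Z.imag(2,2)
```

The single upper bound (k-1)*m + k covers both branches of the source: it equals m^2 when n >= m and m*n-m+n when m > n.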

CROSS-REFERENCE INFORMATION

This function calls:

This function is called by:

SOURCE CODE

function traceA = ctrace(A)
% Computes the sum of diagonal elements of A.
%
% function traceA = ctrace(A)
%
% Returns the sum of the diagonal elements of A. The input A does not have
% to be square: for an m-by-n input, the sum runs over the min(m, n)
% entries A(k, k). The function supports both numeric arrays and structs
% with fields real and imag. This file was created because trace is not
% currently supported by dlarrays, which we use for automatic
% differentiation: ctrace is a backup function.
%
% See also: manoptADhelp

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, July 31, 2021.
% Contributors: Nicolas Boumal
% Change log:

    % In column-major storage, the diagonal entry A(k, k) of an m-by-n
    % array has linear index (k-1)*m + k, so consecutive diagonal entries
    % are m+1 apart. The last one is m^2 if n >= m and m*n-m+n otherwise.
    if iscstruct(A)
        assert(length(size(A.real)) == 2, 'Input should be a 2-D array')
        m = size(A.real, 1);
        n = size(A.real, 2);
        realA = A.real;
        imagA = A.imag;

        % Sum the real and imaginary parts of the diagonal separately.
        if n >= m
            traceA.real = sum(realA(1:m+1:m^2));
            traceA.imag = sum(imagA(1:m+1:m^2));
        else
            traceA.real = sum(realA(1:m+1:m*n-m+n));
            traceA.imag = sum(imagA(1:m+1:m*n-m+n));
        end

    elseif isnumeric(A)
        assert(length(size(A)) == 2, 'Input should be a 2-D array')
        m = size(A,1);
        n = size(A,2);
        if n >= m
            traceA = sum(A(1:m+1:m^2));
        else
            traceA = sum(A(1:m+1:m*n-m+n));
        end

    else
        ME = MException('ctrace:inputError', ...
                        'Input does not have the expected format.');
        throw(ME);

    end

end
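
The help text motivates ctrace by automatic differentiation with dlarrays. The following is a minimal sketch of that use case, assuming the Deep Learning Toolbox is installed (it provides dlarray, dlfeval, dlgradient and extractdata). The helper diagSumAndGrad is hypothetical and not part of Manopt; it repeats the square-matrix branch of the listing above on a dlarray so that the diagonal sum can be differentiated. Mathematically, the gradient of the diagonal sum with respect to X is the identity matrix.

```matlab
% Sketch only: assumes the Deep Learning Toolbox is available.
X = dlarray(magic(3));                  % input traced for automatic differentiation
[t, g] = dlfeval(@diagSumAndGrad, X);   % evaluate the value and its gradient together
disp(extractdata(t));                   % diagonal sum as a plain numeric value
disp(extractdata(g));                   % gradient with respect to X

function [t, g] = diagSumAndGrad(X)
    % Hypothetical helper (not part of Manopt): diagonal sum of a square
    % dlarray via linear indexing, and its gradient with respect to X.
    m = size(X, 1);
    t = sum(X(1:m+1:m^2));              % same indexing as the n >= m branch of ctrace
    g = dlgradient(t, X);               % requires t to be a traced scalar
end
```

Indexing and sum are used here instead of trace because, as the help text notes, trace was not supported for dlarrays when the file was written.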

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005