
factorygpuhelper

PURPOSE

Adds GPU support to a manifold structure returned by a Manopt factory.

SYNOPSIS

function M = factorygpuhelper(M)

DESCRIPTION

 Adds GPU support to a manifold structure returned by a Manopt factory.

 function M = factorygpuhelper(M)

 Helper tool to add GPU support to factories. The input is a manifold
 structure M created by one of Manopt's factories. The output is the same
 structure, with gather() and gpuArray() calls wrapped around M.inner,
 M.norm, M.dist, M.hash, M.vec and M.mat. This follows the logic that
 points, tangent vectors and ambient vectors are stored on the GPU, while
 scalars are 'gathered' back to the CPU. The string returned by M.name()
 also gets ' (GPU)' appended.

 This tool is typically called inside the factory itself, at the very end.
 It is not enough to call this tool: one also needs to create all arrays
 on the GPU directly. See spherefactory for an example.

 This tool is still in beta: please let us know about any issues via the
 forum on http://www.manopt.org. Thanks!

 See also: spherefactory
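The workflow described above (create every array directly on the GPU, then call factorygpuhelper at the very end of the factory) can be sketched as follows. This is a hypothetical toy factory body for unit-norm vectors, not Manopt's spherefactory: the names n, gpuflag, array_type and unitize are illustrative assumptions. It relies on randn and zeros accepting a class name ('double' or 'gpuArray') as their last argument, so the same code creates CPU or GPU arrays. With gpuflag = false it runs without a GPU and without Manopt.

```matlab
% Hypothetical toy factory body: unit-norm vectors in R^n with an optional
% GPU mode. Not Manopt's spherefactory; all names here are illustrative.
n = 5;
gpuflag = false;   % true requires the Parallel Computing Toolbox and a GPU

% Pick the array class once, so every array is created on the target device.
if gpuflag
    array_type = 'gpuArray';
else
    array_type = 'double';
end

M = struct();
M.name    = @() sprintf('Toy sphere S^%d', n-1);
M.inner   = @(x, u, v) u(:)' * v(:);
M.norm    = @(x, u) norm(u(:));
unitize   = @(y) y / norm(y);
M.rand    = @() unitize(randn(n, 1, array_type));  % random point, created on device
M.zerovec = @(x) zeros(n, 1, array_type);           % zero tangent vector, on device

% Add the gather()/gpuArray() wrappers last, once every field exists.
if gpuflag
    M = factorygpuhelper(M);   % Manopt's helper; needs Manopt on the path
end

x = M.rand();
fprintf('%s, norm(x) = %g\n', M.name(), M.norm(x, x));
```

Calling factorygpuhelper alone would not be enough here: without the array_type argument, M.rand and M.zerovec would still produce CPU arrays.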

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SOURCE CODE

function M = factorygpuhelper(M)
% Adds GPU support to a manifold structure returned by a Manopt factory.
%
% function M = factorygpuhelper(M)
%
% Helper tool to add GPU support to factories. The input is a factory M
% created by one of Manopt's factories. The output is the same factory,
% with gather() and gpuArray() added in a number of places, following the
% logic that points, tangent vectors and ambient vectors are stored on the
% GPU (but scalars should be 'gathered' to the CPU). The name of the
% factory is also appended with ' (GPU)'.
%
% This tool is typically called inside the factory itself, at the very end.
% It is not enough to call this tool: one also needs to create all arrays
% on the GPU directly. See spherefactory for an example.
%
% This tool is still in beta: please let us know about any issues via the
% forum on http://www.manopt.org. Thanks!
%
% See also: spherefactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Aug. 3, 2018.
% Contributors:
% Change log:

    % Tag the factory name.
    M.name = @() [M.name(), ' (GPU)'];

    % Gathering scalar outputs: it's unclear whether this is necessary.
    M.inner = @(x, u, v) gather(M.inner(x, u, v));
    M.norm = @(x, u) gather(M.norm(x, u));
    M.dist = @(x, y) gather(M.dist(x, y));

    % TODO: check that this works for manifolds whose points are not
    % matrices (but are structs or cells).
    M.hash = @(x) M.hash(gather(x));

    % The vec/mat pair is mostly used in the hessianspectrum tool, where
    % the vector representation of tangent vectors is assumed to be in
    % 'normal' memory (as opposed to GPU). But it's unclear whether we
    % actually need this too.
    M.vec = @(x, u_mat) gather(M.vec(x, u_mat));
    M.mat = @(x, u_vec) M.mat(x, gpuArray(u_vec));

end
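Each wrapper in the listing uses M on its own right-hand side, for example M.inner = @(x, u, v) gather(M.inner(x, u, v)). This does not recurse forever, because a MATLAB anonymous function captures the values of the variables it uses at the moment it is created. The M on the right-hand side is therefore the structure as it was before the new handle was stored. The sketch below checks this on a hypothetical CPU-only toy structure. tocpu is an assumed identity stand-in for gather(), so the code runs without a GPU.

```matlab
% Hypothetical CPU-only toy structure standing in for a factory output.
M = struct();
M.name  = @() 'Euclidean space R^3';
M.inner = @(x, u, v) u(:)' * v(:);
M.norm  = @(x, u) sqrt(M.inner(x, u, u));   % captures the current M

% Decorate fields the way factorygpuhelper does. Each new handle captures
% a copy of M as it is right now, so M.name and M.inner on the right-hand
% sides refer to the original functions: there is no infinite recursion.
tocpu = @(a) a;   % assumed identity stand-in for gather() on a CPU-only machine
M.name  = @() [M.name(), ' (GPU)'];
M.inner = @(x, u, v) tocpu(M.inner(x, u, v));

u = [1; 2; 2];
fprintf('%s\n', M.name());          % Euclidean space R^3 (GPU)
fprintf('%g\n', M.inner([], u, u)); % 9
fprintf('%g\n', M.norm([], u));     % 3
```

The same capture rule is what lets factorygpuhelper wrap norm, dist, hash, vec and mat in place without keeping separate copies of the original handles.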

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005