Adds GPU support to a manifold structure returned by a Manopt factory.

function M = factorygpuhelper(M)

Helper tool to add GPU support to factories. The input M is a manifold structure created by one of Manopt's factories. The output is the same structure, with calls to gather() and gpuArray() inserted in a number of places. The convention is that points, tangent vectors and ambient vectors are stored on the GPU, while scalar outputs are gathered back to the CPU. The string ' (GPU)' is also appended to the factory name.

This tool is typically called inside the factory itself, at the very end. Calling it is not enough on its own: the factory must also create all of its arrays directly on the GPU. See spherefactory for an example.

This tool is still in beta: please let us know about any issues via the forum on http://www.manopt.org. Thanks!

See also: spherefactory
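The requirement that the factory itself create its arrays on the GPU can be illustrated with a minimal sketch. The toy factory below is hypothetical: it is not Manopt's spherefactory, and the names gpuflag, array_type and normalize are illustrative choices. It assumes the Parallel Computing Toolbox convention that array constructors such as randn and zeros accept the class name 'gpuArray' as their last argument. With gpuflag = false it runs on the CPU without Manopt; with gpuflag = true it needs Manopt on the path and a supported GPU.

```matlab
% Hypothetical toy factory for the unit sphere in R^n, written as a script.
% It only illustrates the pattern: choose where arrays live, create every
% array with that class, and call factorygpuhelper at the very end.
n = 5;
gpuflag = false;  % true requires Manopt and the Parallel Computing Toolbox

% Every array below is created with this class, so it lands on the GPU
% when gpuflag is true and in ordinary memory otherwise.
if gpuflag
    array_type = 'gpuArray';
else
    array_type = 'double';
end

% Scale a nonzero vector to unit norm (keeps the array on its device).
normalize = @(y) y / norm(y);

M = struct();
M.name    = @() sprintf('Toy sphere S^%d', n - 1);
M.inner   = @(x, u, v) u(:)' * v(:);
M.norm    = @(x, u) norm(u(:));
M.rand    = @() normalize(randn(n, 1, array_type));
M.zerovec = @(x) zeros(n, 1, array_type);

% Only once the structure is complete, hand it to the helper.
if gpuflag
    M = factorygpuhelper(M);
end

% Draw a random point: class 'double' here, 'gpuArray' if gpuflag is true.
x = M.rand();
```

With gpuflag = true, the same script would return x as a gpuArray and M.name() would read 'Toy sphere S^4 (GPU)'. Without the array_type argument, randn and zeros would silently create CPU arrays, which is why wrapping the structure alone does not suffice.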
```matlab
function M = factorygpuhelper(M)
% Adds GPU support to a manifold structure returned by a Manopt factory.
%
% function M = factorygpuhelper(M)
%
% Helper tool to add GPU support to factories. The input is a factory M
% created by one of Manopt's factories. The output is the same factory,
% with gather() and gpuArray() added in a number of places, following the
% logic that points, tangent vectors and ambient vectors are stored on the
% GPU (but scalars should be 'gathered' to the CPU). The name of the
% factory is also appended with ' (GPU)'.
%
% This tool is typically called inside the factory itself, at the very end.
% It is not enough to call this tool: one also needs to create all arrays
% on the GPU directly. See spherefactory for an example.
%
% This tool is still in beta: please let us know about any issues via the
% forum on http://www.manopt.org. Thanks!
%
% See also: spherefactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Aug. 3, 2018.
% Contributors:
% Change log:

    % Tag the factory name.
    M.name = @() [M.name(), ' (GPU)'];

    % Gathering scalar outputs: it's unclear whether this is necessary.
    M.inner = @(x, u, v) gather(M.inner(x, u, v));
    M.norm = @(x, u) gather(M.norm(x, u));
    M.dist = @(x, y) gather(M.dist(x, y));

    % TODO: check that this works for manifolds whose points are not
    % matrices (but are structs or cells).
    M.hash = @(x) M.hash(gather(x));

    % The vec/mat pair is mostly used in the hessianspectrum tool, where
    % the vector representation of tangent vectors is assumed to be in
    % 'normal' memory (as opposed to GPU). But it's unclear whether we
    % actually need this too.
    M.vec = @(x, u_mat) gather(M.vec(x, u_mat));
    M.mat = @(x, u_vec) M.mat(x, gpuArray(u_vec));

end
```
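Lines such as M.inner = @(x, u, v) gather(M.inner(x, u, v)) work because a MATLAB anonymous function captures the value of M at the moment it is created. The M.inner called on the right-hand side is therefore the previous handle, not the new one, and there is no infinite recursion. The script below reproduces this wrapping pattern on a hypothetical toy structure; to_cpu is an assumed stand-in for gather() so that the sketch runs without a GPU.

```matlab
% Hypothetical toy manifold structure (not produced by a Manopt factory).
M = struct();
M.name  = @() 'Toy sphere';
M.inner = @(x, u, v) u(:)' * v(:);
M.norm  = @(x, u) sqrt(M.inner(x, u, u));

% Stand-in for gather(): here we merely convert to double. In
% factorygpuhelper, the real gather() moves GPU results to the CPU.
to_cpu = @(a) double(a);

% Decorate the fields the same way factorygpuhelper does. Each new handle
% captures a snapshot of M taken just before the assignment, so the call
% on the right-hand side reaches the original handle.
M.name  = @() [M.name(), ' (GPU)'];
M.inner = @(x, u, v) to_cpu(M.inner(x, u, v));
M.norm  = @(x, u) to_cpu(M.norm(x, u));
```

A consequence of this design is that the helper is not idempotent: applying it twice to the same structure would append the tag twice and wrap every handle twice.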