function M = poincareballfactory(k, n, gpuflag)
% Factory for the Poincare ball model of hyperbolic space.
%
% function M = poincareballfactory(k)
% function M = poincareballfactory(k, n)
% function M = poincareballfactory(k, n, gpuflag)
%
% Returns a structure M describing the product of n copies of the
% k-dimensional Poincare ball. A point is a k-by-n matrix whose columns each
% lie strictly inside the open unit ball of R^k. Tangent vectors are also
% k-by-n matrices: the ball is an open subset of R^k, so M.proj and
% M.tangent are the identity.
%
% The metric is conformal. At a point x, column j contributes
%     lambda(x_j)^2 * u_j' * v_j,   lambda(x_j) = 2 / (1 - ||x_j||^2),
% and M.inner sums these contributions over the n columns.
%
% Inputs:
%   k       - dimension of each ball.
%   n       - number of balls (columns); defaults to 1.
%   gpuflag - if true (default false), arrays created by M.rand, M.randvec
%             and M.zerovec are gpuArrays, and M is passed through
%             factorygpuhelper.
%
% M.exp / M.retr and M.log / M.invretr are the closed-form exponential and
% logarithm maps, built on Mobius addition (also exposed as
% M.mobius_addition). M.transp returns the vector unchanged.

    if ~exist('n', 'var') || isempty(n)
        n = 1;
    end

    if ~exist('gpuflag', 'var') || isempty(gpuflag)
        gpuflag = false;
    end

    % Class used when creating new arrays with rand/randn/zeros below.
    if gpuflag
        array_type = 'gpuArray';
    else
        array_type = 'double';
    end

    if n == 1
        M.name = @() sprintf('Poincare ball B_%d', k);
    else
        M.name = @() sprintf('Poincare ball B_%d^%d', k, n);
    end

    M.dim = @() k * n;

    % Per-column conformal factor lambda(x_j) = 2 / (1 - ||x_j||^2),
    % returned as a 1-by-n row.
    M.conformal_factor = @(x) 2 ./ (1 - sum(x .* x, 1));

    % Euclidean inner product of each column pair, scaled by lambda^2 at that
    % column, summed over all columns.
    M.inner = @(x, u, v) sum(sum(u .* v, 1) .* (M.conformal_factor(x).^2));

    M.norm = @(x, d) sqrt(M.inner(x, d, d));

    M.dist = @dist;
    function d = dist(x, y)
        % Per-column distance
        %   acosh(1 + 2 ||x_j - y_j||^2 / ((1 - ||x_j||^2) (1 - ||y_j||^2))),
        % combined over the columns as the square root of the sum of squares.
        norms2x = sum(x .* x, 1);
        norms2y = sum(y .* y, 1);
        norms2diff = sum((x - y) .* (x - y), 1);
        d = sqrt(sum(acosh(1 + 2 * norms2diff ./ (1 - norms2x) ./ (1 - norms2y)) .^ 2));
    end

    M.typicaldist = @() M.dim() / 8;

    % Open subset of R^(k x n): every matrix is a tangent vector.
    M.proj = @(x, d) d;

    M.tangent = M.proj;

    % With metric lambda^2 * I (per column), the Riemannian gradient is the
    % Euclidean gradient divided column-wise by lambda^2.
    M.egrad2rgrad = @egrad2rgrad;
    function rgrad = egrad2rgrad(x, egrad)
        factor = M.conformal_factor(x);
        rgrad = egrad .* ((1 ./ factor).^2);
    end

    % Riemannian Hessian along u for the conformal metric. With g = egrad,
    % H[u] = ehess and lambda = conformal factor, column-wise:
    %   ( <g,x> u - <u,x> g - <u,g> x + H[u] / lambda ) / lambda
    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(x, egrad, ehess, u)
        factor = M.conformal_factor(x);
        rhess = ( u .* sum(egrad .* x, 1) - ...
                  egrad .* sum(u .* x, 1) - ...
                  x .* sum(u .* egrad, 1) + ...
                  ehess ./ factor ...
                ) ./ factor;
    end

    % Column-wise Mobius addition. This handle refers to the file-level local
    % function mobius_addition, the same helper used by the exponential and
    % logarithm maps, so there is a single implementation.
    M.mobius_addition = @mobius_addition;

    M.exp = @exponential;
    M.log = @logarithm;

    % The exponential map and its inverse serve as retraction and inverse
    % retraction.
    M.retr = M.exp;
    M.invretr = M.log;

    % All tangent spaces are R^(k x n): transport leaves vectors unchanged.
    M.transp = @(x1, x2, v) v;

    M.hash = @(x) ['z' hashmd5(x(:))];

    % Random point, sampled column-by-column by sample_ball_uniformly.
    M.rand = @() sample_ball_uniformly(k, n, array_type);

    % Random Gaussian direction rescaled to unit Riemannian norm at x.
    M.randvec = @randvec;
    function v = randvec(x)
        v = randn(k, n, array_type);
        v = v / M.norm(x, v);
    end

    M.zerovec = @(x) zeros(k, n, array_type);

    M.lincomb = @matrixlincomb;

    % Midpoint of the geodesic from x1 to x2.
    M.pairmean = @pairmean;
    function y = pairmean(x1, x2)
        y = M.exp(x1, M.log(x1, x2) / 2);
    end

    % vec scales column j by lambda(x_j) and mat undoes it, so the Euclidean
    % inner product of vectorized tangent vectors equals M.inner.
    M.vec = @vec;
    function u_vec = vec(x, u_mat)
        u_vec = bsxfun(@times, u_mat, M.conformal_factor(x));
        u_vec = u_vec(:);
    end
    M.mat = @mat;
    function u_mat = mat(x, u_vec)
        u_mat = reshape(u_vec, [k, n]);
        u_mat = bsxfun(@times, u_mat, 1./M.conformal_factor(x));
    end
    M.vecmatareisometries = @() true;

    if gpuflag
        M = factorygpuhelper(M);
    end

end
0151
function z = mobius_addition(x, y)
% Column-wise Mobius addition x (+) y in the Poincare ball:
%   x (+) y = ((1 + 2<x,y> + ||y||^2) x + (1 - ||x||^2) y)
%             / (1 + 2<x,y> + ||x||^2 ||y||^2),
% evaluated independently for each pair of matching columns of x and y.
    xy = sum(x .* y, 1);   % <x_j, y_j>
    xx = sum(x .* x, 1);   % ||x_j||^2
    yy = sum(y .* y, 1);   % ||y_j||^2

    coeff_x = 1 + 2 * xy + yy;
    coeff_y = 1 - xx;
    denom = 1 + 2 * xy + xx .* yy;

    z = (coeff_x .* x + coeff_y .* y) ./ denom;
end
0158
0159
function y = exponential(x, d, t)
% Exponential map of the Poincare ball, applied column-wise:
%   y_j = x_j (+) tanh(||t d_j|| / (1 - ||x_j||^2)) * (t d_j) / ||t d_j||,
% where (+) is Mobius addition. t is an optional step size (default 1).
% A zero column of t*d leaves the corresponding column of x unchanged.
    if nargin == 2
        td = d;
    else
        td = t*d;
    end

    % One norm per column. The explicit dim = 1 matters when k = 1: for a
    % 1-by-n row, vecnorm(td) would return a single norm of the whole row.
    normstd = vecnorm(td, 2, 1);
    % 1 - ||x_j||^2 = 2 / lambda(x_j), so normstd ./ factor equals
    % lambda(x_j) * ||td_j|| / 2.
    factor = (1 - sum(x .* x, 1));

    % Adding (normstd == 0) avoids 0/0 for zero columns; tanh(0) = 0 then
    % makes those columns of w zero.
    w = td .* (tanh(normstd ./ factor) ./ (normstd + (normstd == 0)));
    y = mobius_addition(x, w);
end
0174
function v = logarithm(x, y)
% Logarithm map (inverse of the exponential map), applied column-wise:
%   w_j = (-x_j) (+) y_j,
%   v_j = (1 - ||x_j||^2) * atanh(||w_j||) * w_j / ||w_j||,
% where (+) is Mobius addition. If w_j is exactly zero (y_j == x_j), v_j is
% zero instead of NaN.
    w = mobius_addition(-x, y);
    % Explicit dim = 1 so that k = 1 still yields one norm per column.
    normsw = vecnorm(w, 2, 1);
    factor = 1 - sum(x .* x, 1);
    % Guard the denominator: for w_j = 0, atanh(0) = 0 gives v_j = 0 rather
    % than 0/0 = NaN. Nonzero columns are divided by normsw exactly as before.
    v = w .* factor .* atanh(normsw) ./ (normsw + (normsw == 0));
end
0181
function x = sample_ball_uniformly(k, n, array_type)
% Draws n points at random in the open unit ball of R^k and returns them as
% the columns of a k-by-n array created with class array_type (the factory
% passes 'double' or 'gpuArray').
% Each direction is a normalized standard Gaussian column. The radius is
% u^(1/k) with u ~ U(0,1), which has density k r^(k-1), so points are
% uniform in volume.
    isotropic = randn(k, n, array_type);
    % dim = 1 is explicit so that each column is normalized on its own; for
    % k = 1, vecnorm of a 1-by-n row would otherwise return one norm for the
    % whole row.
    isotropic = isotropic ./ vecnorm(isotropic, 2, 1);
    radiuses = rand(1, n, array_type) .^ (1 / k);
    x = isotropic .* radiuses;
end