Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > @TTeMPS_op > apply.m

apply

PURPOSE ^

APPLY Application of a TT/MPS operator to a TT/MPS tensor

SYNOPSIS ^

function y = apply( A, x, idx )

DESCRIPTION ^

APPLY Application of a TT/MPS operator to a TT/MPS tensor
   Y = APPLY(A, X) applies the TT/MPS operator A to the TT/MPS tensor X.

   Y = APPLY(A, X, idx) applies only the idx-th core of A.
       Note that in this case, X is assumed to be a single core given as a
       standard MATLAB array, not a TTeMPS tensor.

   In both cases, X can come from a block-TT format, that is, it may have a
   four-dimensional core instead of a three-dimensional one; the fourth
   (block) dimension is carried through to Y unchanged.

   See also CONTRACT

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function y = apply( A, x, idx )
    %APPLY Application of a TT/MPS operator to a TT/MPS tensor
    %   Y = APPLY(A, X) applies the TT/MPS operator A to the TT/MPS tensor X.
    %       X must be a TTeMPS or a TTeMPS_block object; Y is of the same
    %       class. The ranks of Y are the products of the ranks of A and X.
    %
    %   Y = APPLY(A, X, idx) applies only the idx-th core of A.
    %       In this case, X is assumed to be a single core given as a standard
    %       MATLAB array of size r1 x n x r2 (or r1 x n x r2 x p in block
    %       format) with n = A.size_row(idx), not a TTeMPS tensor. Y has size
    %       (A.rank(idx)*r1) x A.size_col(idx) x (A.rank(idx+1)*r2) (x p).
    %
    %   In both cases, X can come from a block-TT format, that is, with a
    %   four-dimensional core instead of a three-dimensional one; the fourth
    %   (block) dimension is carried through to Y unchanged.
    %
    %   See also CONTRACT

    %   TTeMPS Toolbox.
    %   Michael Steinlechner, 2013-2016
    %   Questions and contact: michael.steinlechner@epfl.ch
    %   BSD 2-clause license, see LICENSE.txt

    if nargin < 3
        % First case: contract every core of A with the matching core of x.
        V = cell( 1, A.order );

        if isa( x, 'TTeMPS' )
            for i = 1:A.order
                % apply_core detects a 4-D (block) core on its own.
                V{i} = apply_core( A.U{i}, A.rank(i:i+1), A.size_col(i), ...
                                   A.size_row(i), x.U{i}, x.rank(i:i+1) );
            end
            y = TTeMPS( V );

        elseif isa( x, 'TTeMPS_block' )
            % Only core x.mu carries the block dimension; apply_core handles
            % 3-D and 4-D cores alike, so all cores share one code path.
            for i = 1:A.order
                V{i} = apply_core( A.U{i}, A.rank(i:i+1), A.size_col(i), ...
                                   A.size_row(i), x.U{i}, x.rank(i:i+1) );
            end
            y = TTeMPS_block( V, x.mu, x.p );

        else
            error('Unsupported class type of vector argument. Must be TTeMPS or TTeMPS_block object')
        end

    else
        % Second case: x is a single core (plain array); apply core idx only.
        y = apply_core( A.U{idx}, A.rank(idx:idx+1), A.size_col(idx), ...
                        A.size_row(idx), x, [size(x, 1), size(x, 3)] );
    end
end

function W = apply_core( Acore, rA, nOut, nIn, X, rX )
    %APPLY_CORE Contract one operator core with one tensor core.
    %   Acore is the operator core. Its dimensions 1, 2 and 4 have sizes
    %   rA(1), nOut and rA(2). Its dimension 3 (size nIn) is contracted with
    %   mode 2 of X.
    %   X has size rX(1) x nIn x rX(2), optionally x p (block-TT format).
    %   W has size (rA(1)*rX(1)) x nOut x (rA(2)*rX(2)) x p. A trailing p == 1
    %   is dropped by reshape, so W is 3-D in that case.
    p = size( X, 4 );

    % Mode-2 unfolding of X: nIn x (rX(1)*rX(2)*p), columns ordered (rX1, rX2, p).
    Xmat = reshape( permute( X, [2 1 3 4] ), [nIn, rX(1)*rX(2)*p] );

    % Operator core as (rA(1)*nOut*rA(2)) x nIn, rows ordered (rA1, nOut, rA2).
    Amat = reshape( permute( Acore, [1 2 4 3] ), rA(1)*nOut*rA(2), nIn );

    % Split rows and columns in the SAME order in which they were merged above
    % (the block-format paths previously split rows as (rA1, rA2, nOut), which
    % scrambled the result). Then interleave the rank indices of A and X.
    W = reshape( Amat*Xmat, [rA(1), nOut, rA(2), rX(1), rX(2), p] );
    W = permute( W, [1 4 2 3 5 6] );
    W = reshape( W, [rA(1)*rX(1), nOut, rA(2)*rX(2), p] );
end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005