Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > @TTeMPS > contract.m

contract

PURPOSE ^

CONTRACT Contraction of two TT/MPS tensors.

SYNOPSIS ^

function res = contract( x, y, idx )

DESCRIPTION ^

CONTRACT Contraction of two TT/MPS tensors.
   Z = CONTRACT(X,Y,IDX) contracts all cores of the two TT/MPS tensors X and Y except 
   core IDX. Result Z is a tensor of size [X.rank(IDX),Y.size(IDX),X.rank(IDX+1)].

   RES = CONTRACT(X,Y,[IDX1, IDX2]) contracts all cores of the two TT/MPS tensors X and Y except 
   cores [IDX1, IDX2]. IDX1 and IDX2 must be two consecutive integers in ascending order. 
   Result RES is a cell array with two tensors of size [X.rank(IDX1),Y.size(IDX1),Y.rank(IDX2)]
   and [Y.rank(IDX2),Y.size(IDX2),X.rank(IDX2+1)], respectively.

   See also INNERPROD.

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 function res = contract( x, y, idx )
0002     %CONTRACT Contraction of two TT/MPS tensors.
0003     %   Z = CONTRACT(X,Y,IDX) contracts all cores of the two TT/MPS tensors X and Y except
0004     %   core IDX. Result Z is a tensor of size [X.rank(IDX),Y.size(IDX),X.rank(IDX+1)].
0005     %
0006     %   RES = CONTRACT(X,Y,[IDX1, IDX2]) contracts all cores of the two TT/MPS tensors X and Y except
0007     %   cores [IDX1, IDX2]. IDX1 and IDX2 must be two consecutive integers in ascending order.
0008     %   Result RES is a cell array with two tensors of size [X.rank(IDX1),Y.size(IDX1),Y.rank(IDX2)]
0009     %   and [Y.rank(IDX2),Y.size(IDX2),X.rank(IDX2+1)], respectively.
0010     %
0011     %   In the two-IDX case, the uncontracted cores of Y may carry a fourth
0012     %   (block) dimension of size > 1. It is carried through unchanged, so the
0013     %   corresponding result tensor then has that extra trailing dimension.
0014     %   Both cores may be in block format simultaneously.
0015     %
0016     %   See also INNERPROD.
0017
0018     %   TTeMPS Toolbox.
0019     %   Michael Steinlechner, 2013-2016
0020     %   Questions and contact: michael.steinlechner@epfl.ch
0021     %   BSD 2-clause license, see LICENSE.txt
0022
0023     sz = size(idx);
0024
0025     % Only a scalar or a two-element row/column vector is accepted.
0026     if min(sz) ~= 1 || max(sz) > 2
0027         error('Unknown IDX format. Either scalar or two-dim. row-/column array expected.')
0028     end
0029
0030     if max(sz) == 1
0031         % Scalar IDX case: only one node not to contract.
0032         % At the first (last) core only the right (left) interface is applied.
0033         if idx == 1
0034             right = innerprod( x, y, 'RL', idx+1 );
0035             res = tensorprod_ttemps( y.U{idx}, right, 3 );
0036         elseif idx == x.order
0037             left = innerprod( x, y, 'LR', idx-1 );
0038             res = tensorprod_ttemps( y.U{idx}, left, 1 );
0039         else
0040             left = innerprod( x, y, 'LR', idx-1 );
0041             right = innerprod( x, y, 'RL', idx+1 );
0042             res = tensorprod_ttemps( y.U{idx}, left, 1 );
0043             res = tensorprod_ttemps( res, right, 3 );
0044         end
0045     else
0046         % Two-IDX case: two neighboring nodes not to contract.
0047         if diff(idx) ~= 1
0048             error('Choose two neighboring nodes in ascending order.')
0049         end
0050
0051         res = cell(1, 2);
0052         if idx(1) == 1
0053             % Nothing to the left of the first core: it is returned as is.
0054             right = innerprod( x, y, 'RL', idx(2)+1 );
0055             res{1} = y.U{1};
0056             res{2} = apply_right( y.U{idx(2)}, right );
0057         elseif idx(2) == x.order
0058             % Nothing to the right of the last core: it is returned as is.
0059             left = innerprod( x, y, 'LR', idx(1)-1 );
0060             res{1} = apply_left( y.U{idx(1)}, left );
0061             res{2} = y.U{x.order};
0062         else
0063             left = innerprod( x, y, 'LR', idx(1)-1 );
0064             right = innerprod( x, y, 'RL', idx(2)+1 );
0065             % The two cores are handled independently, so the block format is
0066             % supported on either (or both) of them.
0067             res{1} = apply_left( y.U{idx(1)}, left );
0068             res{2} = apply_right( y.U{idx(2)}, right );
0069         end
0070     end
0071 end
0072
0073 function C = apply_left( U, left )
0074     % APPLY_LEFT Multiply LEFT onto mode 1 of core U (mode-1 product),
0075     %   preserving an optional fourth (block) dimension of U.
0076     p = size(U, 4);
0077     if p ~= 1
0078         % Block format: work on the mode-1 unfolding of the 4-D core.
0079         s = size(U);
0080         C = reshape( U, [s(1), s(2)*s(3)*p] );
0081         C = left * C;
0082         C = reshape( C, [size(left,1), s(2), s(3), p] );
0083     else
0084         % NOTE(review): assumes tensorprod_ttemps(U, A, 1) is the same
0085         % mode-1 product A*U_(1) for 3-D cores -- confirm.
0086         C = tensorprod_ttemps( U, left, 1 );
0087     end
0088 end
0089
0090 function C = apply_right( U, right )
0091     % APPLY_RIGHT Multiply RIGHT onto mode 3 of core U (mode-3 product),
0092     %   preserving an optional fourth (block) dimension of U.
0093     q = size(U, 4);
0094     if q ~= 1
0095         % Block format: move mode 3 to the front, multiply, and move it back.
0096         s = size(U);
0097         C = reshape( permute( U, [3 1 2 4] ), [s(3), s(1)*s(2)*q] );
0098         C = right * C;
0099         C = reshape( C, [size(right,1), s(1), s(2), q] );
0100         C = ipermute( C, [3 1 2 4] );
0101     else
0102         % NOTE(review): assumes tensorprod_ttemps(U, A, 3) is the same
0103         % mode-3 product A*U_(3) for 3-D cores -- confirm.
0104         C = tensorprod_ttemps( U, right, 3 );
0105     end
0106 end
0107

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005