In [6]:
import numpy as np



def tt_svd(tensor: np.ndarray, truncerr: float = 10**-6, maxdim: int = 10**12) -> list:
    """
    Compress a tensor to a MPS/TT using the TT-SVD algorithm.

    The tensor is unfolded into a matrix one index at a time and split by a thin SVD:
    the (truncated) left singular vectors become the next MPS site, and the singular
    values times the right singular vectors are carried on to the next step.

    Args:
        tensor: The input tensor (anything accepted by ``np.asarray``), of order >= 1
        truncerr: Truncation error for each SVD. The smallest singular values are
            discarded as long as the norm of the discarded ones stays <= truncerr;
            at least one singular value is always kept.
        maxdim: Maximum MPS bond dimension (must be >= 1). It is applied after the
            truncerr criterion, so it can cause a larger truncation error.

    Return:
        An MPS/TT as a list of order-3 tensors with shape (left bond, index, right bond);
        dummy bonds of dimension 1 are added to the boundary tensors.

    Raises:
        ValueError: if the tensor has order 0 or maxdim < 1.
    """
    tensor = np.asarray(tensor)
    dimensions = tensor.shape
    order = len(dimensions)
    if order == 0:
        raise ValueError("tt_svd needs a tensor of order >= 1")
    if maxdim < 1:
        raise ValueError("maxdim must be at least 1")

    mps = []
    virtual_dimension = 1
    for i in range(order - 1):
        # Unfold to a matrix: (left bond * current index) x (all remaining indices)
        tensor = tensor.reshape(virtual_dimension * dimensions[i], -1)

        # Thin SVD: full_matrices=True would build an n x n Vt over *all* remaining indices,
        # although only its first len(s) rows are ever used.
        U, s, Vt = np.linalg.svd(tensor, full_matrices=False)

        # Discard the largest number of trailing (smallest) singular values whose squared sum
        # stays <= truncerr**2. The cumulative sum is non-decreasing, so searchsorted(side="right")
        # counts exactly the entries <= truncerr**2.
        discarded_weight = np.cumsum(s[::-1] ** 2)
        number_of_singular_values_to_discard = int(np.searchsorted(discarded_weight, truncerr**2, side="right"))
        new_virtual_dimension = min(maxdim, max(1, len(s) - number_of_singular_values_to_discard))

        # Reshape and truncate U matrix, store in return list
        mps.append(U[:, :new_virtual_dimension].reshape(virtual_dimension, dimensions[i], new_virtual_dimension))

        # Contract s and Vt: same as diag(s) @ Vt, without building the dense diagonal matrix
        tensor = s[:new_virtual_dimension, None] * Vt[:new_virtual_dimension, :]
        virtual_dimension = new_virtual_dimension

    mps.append(tensor.reshape(virtual_dimension, dimensions[-1], 1))
    return mps

In [7]:
# Create a random tensor (seeded so that Restart & Run All reproduces the same tensor and outputs)
SEED = 42
rng = np.random.default_rng(SEED)
tensor = rng.random((2, 2, 2, 2, 2))

In [8]:
# Compress to MPS/TT. TRUNCERR bounds the norm of the singular values discarded in each SVD;
# MAXDIM caps every bond dimension.
TRUNCERR = 10**-6
MAXDIM = 10
mps = tt_svd(tensor, truncerr=TRUNCERR, maxdim=MAXDIM)

# Sanity-check the MPS structure: one order-3 site per tensor index, dummy boundary bonds, capped bonds
assert len(mps) == tensor.ndim
assert all(site.ndim == 3 for site in mps)
assert mps[0].shape[0] == 1 and mps[-1].shape[-1] == 1
assert all(site.shape[-1] <= MAXDIM for site in mps)

In [4]:
from functools import reduce

def get_index(mps: list, index: list):
    """
    Retrieve a single element of a tensor represented by an MPS/TT by performing the matrix product

    Args:
        mps: The MPS/TT as a list of order-3 tensors indexed (left bond, index, right bond)
        index: The index of the element to retrieve, with exactly one entry per MPS site

    Return:
        Entry [0, 0] of the product of the selected site matrices. This is the tensor
        element when both boundary bonds have dimension 1, as produced by ``tt_svd``.

    Raises:
        ValueError: if the MPS is empty or ``index`` does not have one entry per site.
    """
    # zip() would silently stop at the shorter sequence and return a meaningless partial product
    if len(index) != len(mps):
        raise ValueError(f"index has {len(index)} entries but the MPS has {len(mps)} sites")
    if not mps:
        raise ValueError("cannot index an empty MPS")
    # Select the matrix for each site's index value and multiply them left to right
    return reduce(np.matmul, (site[:, ind, :] for site, ind in zip(mps, index)))[0, 0]

In [5]:
# Retrieve an element from the MPS and check it against the same entry of the uncompressed tensor
element_index = [0, 0, 1, 0, 0]
mps_element = get_index(mps, element_index)
# With the small truncation error used above, the reconstructed entry should match closely.
# tensor[...] needs a tuple: a list would trigger NumPy fancy indexing along the first axis.
assert np.isclose(mps_element, tensor[tuple(element_index)], atol=1e-5)
mps_element

Out[5]:
0.2956102700141592