Source code for tntorch.tensor

import numpy as np
import torch
import tntorch as tn
torch.set_default_dtype(torch.float64)
import time


def _full_rank_tt(data):  # Naive TT formatting: store the tensor exactly, don't even attempt to compress
    data = torch.Tensor(data) if type(data) is not torch.Tensor else data  # Convert first, so .to() and .dim() below are valid
    data = data.to(torch.get_default_dtype())
    shape = data.shape
    result = []
    N = data.dim()
    device = data.device
    resh = torch.reshape(data, [shape[0], -1])
    for n in range(1, N):
        if resh.shape[0] < resh.shape[1]:
            result.append(torch.reshape(torch.eye(resh.shape[0]).to(device),
                                        [resh.shape[0] // shape[n - 1], shape[n - 1], resh.shape[0]]))
            resh = torch.reshape(resh, (resh.shape[0] * shape[n], resh.shape[1] // shape[n]))
        else:
            result.append(torch.reshape(resh, [resh.shape[0] // shape[n - 1], shape[n - 1], resh.shape[1]]))
            resh = torch.reshape(torch.eye(resh.shape[1]).to(device),
                                 (resh.shape[1] * shape[n], resh.shape[1] // shape[n]))
    result.append(torch.reshape(resh, [resh.shape[0] // shape[N - 1], shape[N - 1], 1]))
    return result
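
A minimal usage sketch of this helper, assuming only the constructor documented below: wrapping a full tensor without rank or eps arguments stores it losslessly as a full-rank TT network.

X = torch.randn(4, 5, 6)
t = tn.Tensor(X)                  # no ranks/eps given: the constructor calls _full_rank_tt(X)
print(t.ranks_tt)                 # [1 4 6 1]: exact, uncompressed TT ranks
print(torch.dist(t.torch(), X))   # ~0: the formatting is lossless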


[docs]class Tensor(object): """ Class for all tensor networks. Currently supported: `tensor train (TT) <https://epubs.siam.org/doi/pdf/10.1137/090752286>`_, `CANDECOMP/PARAFAC (CP) <https://epubs.siam.org/doi/pdf/10.1137/07070111X>`_, `Tucker <https://epubs.siam.org/doi/pdf/10.1137/S0895479898346995>`_, and hybrid formats. Internal representation: an ND tensor has N cores, with each core following one of four options: - Size :math:`R_{n-1} \\times I_n \\times R_n` (standard TT core) - Size :math:`R_{n-1} \\times S_n \\times R_n` (TT-Tucker core), accompanied by an :math:`I_n \\times S_n` factor matrix - Size :math:`I_n \\times R` (CP factor matrix) - Size :math:`S_n \\times R_n` (CP-Tucker core), accompanied by an :math:`I_n \\times S_n` factor matrix """ def __init__(self, data, Us=None, idxs=None, device=None, requires_grad=None, ranks_cp=None, ranks_tucker=None, ranks_tt=None, eps=None, max_iter=25, tol=1e-4, verbose=False): """ The constructor can either: - Decompose an uncompressed tensor - Use an explicit list of tensor cores (and optionally, factors) See `this notebook <https://github.com/rballester/tntorch/blob/master/tutorials/decompositions.ipynb>`_ for examples of use. :param data: a NumPy ndarray, PyTorch tensor, or a list of cores (which can represent either CP factors or TT cores) :param Us: optional list of Tucker factors :param idxs: annotate maskable tensors (*advanced users*) :param device: PyTorch device :param requires_grad: Boolean :param ranks_cp: an integer (or list) :param ranks_tucker: an integer (or list) :param ranks_tt: an integer (or list) :param eps: maximal error :param max_iter: maximum number of iterations when computing a CP decomposition using ALS :param tol: stopping criterion (change in relative error) when computing a CP decomposition using ALS :param verbose: Boolean :return: a :class:`Tensor` """ if isinstance(data, (list, tuple)): if not all([2 <= d.dim() <= 3 for d in data]): raise ValueError('All tensor cores must have 2 (for CP) or 3 (for TT) dimensions') for n in range(len(data)-1): if (data[n+1].dim() == 3 and data[n].shape[-1] != data[n+1].shape[0]) or (data[n+1].dim() == 2 and data[n].shape[-1] != data[n+1].shape[1]): raise ValueError('Core ranks do not match') self.cores = data N = len(data) else: if isinstance(data, np.ndarray): data = torch.Tensor(data, device=device) elif isinstance(data, torch.Tensor): data = data.to(device) else: raise ValueError('A tntorch.Tensor may be built either from a list of cores, one NumPy ndarray, or one PyTorch tensor') N = data.dim() if Us is None: Us = [None]*N self.Us = Us if isinstance(data, torch.Tensor): if data.dim() == 0: data = data*torch.ones(1, device=device) if ranks_cp is not None: # Compute CP from full tensor: CP-ALS if ranks_tt is not None: raise ValueError('ALS for CP-TT is not yet supported') assert not hasattr(ranks_cp, '__len__') start = time.time() if verbose: print('ALS', end='') if ranks_tucker is not None: # CP on Tucker's core self.cores = _full_rank_tt(data) self.round_tucker(rmax=ranks_tucker) data = self.tucker_core() data_norm = tn.norm(data) self.cores = [torch.randn(sh, ranks_cp, device=device) for sh in data.shape] else: # We initialize CP factor to HOSVD data_norm = torch.norm(data) self.cores = [] for n in range(data.dim()): gram = tn.unfolding(data, n) gram = gram.matmul(gram.t()) eigvals, eigvecs = torch.symeig(gram, eigenvectors=True) # Sort eigenvectors in decreasing importance reverse = np.arange(len(eigvals)-1, -1, -1) # Negative steps not yet supported in PyTorch idx = 
np.argsort(eigvals.to('cpu'))[reverse[:ranks_cp]] self.cores.append(eigvecs[:, idx]) if self.cores[-1].shape[1] < ranks_cp: # Complete with random entries self.cores[-1] = torch.cat((self.cores[-1], torch.randn(self.cores[-1].shape[0], ranks_cp-self.cores[-1].shape[1])), dim=1) if verbose: print(' -- initialization time =', time.time() - start) grams = [None] + [self.cores[n].t().matmul(self.cores[n]) for n in range(1, self.dim())] errors = [] converged = False for iter in range(max_iter): for n in range(self.dim()): khatri = torch.ones(1, ranks_cp, device=device) prod = torch.ones(ranks_cp, ranks_cp, device=device) for m in range(self.dim()-1, -1, -1): if m != n: prod *= grams[m] khatri = torch.reshape(torch.einsum('ir,jr->ijr', (self.cores[m], khatri)), [-1, ranks_cp]) unfolding = tn.unfolding(data, n) self.cores[n] = torch.gels(unfolding.matmul(khatri).t(), prod)[0].t() grams[n] = self.cores[n].t().matmul(self.cores[n]) errors.append(torch.norm(data - tn.Tensor(self.cores).torch()) / data_norm) if len(errors) >= 2 and errors[-2] - errors[-1] < tol: converged = True if verbose: print('iter: {: <{}} | eps: '.format(iter, len('{}'.format(max_iter))), end='') print('{:.8f}'.format(errors[-1]), end='') print(' | total time: {:9.4f}'.format(time.time() - start), end='') if converged: print(' <- converged (tol={})'.format(tol)) elif iter == max_iter-1: print(' <- max_iter was reached: {}'.format(max_iter)) else: print() if converged: break else: self.cores = _full_rank_tt(data) self.Us = [None]*data.dim() if ranks_tucker is not None: self.round_tucker(rmax=ranks_tucker) if ranks_tt is not None: self.round_tt(rmax=ranks_tt) # Check factor shapes for n in range(self.dim()): if self.Us[n] is None: continue assert self.Us[n].dim() == 2 assert self.cores[n].shape[-2] == self.Us[n].shape[1] # Set cores/Us requires_grad, if needed if requires_grad: for n in range(self.dim()): if self.Us[n] is not None: self.Us[n].requires_grad_() self.cores[n].requires_grad_() if idxs is None: idxs = [torch.arange(sh, device=device) for sh in self.shape] self.idxs = idxs if eps is not None: # TT-SVD (or TT-EIG) algorithm if ranks_cp is not None or ranks_tucker is not None or ranks_tt is not None: raise ValueError('Specify eps or ranks, but not both') self.round(eps) """ Arithmetic operations """ def __add__(self, other): if not isinstance(other, Tensor): factor = other other = Tensor([torch.ones([1, self.shape[n], 1]) for n in range(self.dim())]) other.cores[0].data *= factor if self.dim() == 1: # Special case return Tensor([self.decompress_tucker_factors().cores[0] + other.decompress_tucker_factors().cores[0]]) this, other = _broadcast(self, other) cores = [] Us = [] for n in range(this.dim()): core1 = this.cores[n] core2 = other.cores[n] # CP + CP -> CP, other combinations -> TT if core1.dim() == 2 and core2.dim() == 2: core1 = core1[None, :, :] core2 = core2[None, :, :] else: core1 = self._cp_to_tt(core1) core2 = self._cp_to_tt(core2) if this.Us[n] is not None and other.Us[n] is not None: # if core1.shape[1] + core2.shape[1] >= self.Us[n] and core1.shape[1] + core2.shape[1] >= self.Us[n] slice1 = torch.cat([core1, torch.zeros([core2.shape[0], core1.shape[1], core1.shape[2]])], dim=0) slice1 = torch.cat([slice1, torch.zeros(core1.shape[0]+core2.shape[0], core1.shape[1], core2.shape[2])], dim=2) slice2 = torch.cat([torch.zeros([core1.shape[0], core2.shape[1], core2.shape[2]]), core2], dim=0) slice2 = torch.cat([torch.zeros(core1.shape[0]+core2.shape[0], core2.shape[1], core1.shape[2]), slice2], dim=2) c = 
torch.cat([slice1, slice2], dim=1) cores.append(c) Us.append(torch.cat((self.Us[n], other.Us[n]), dim=1)) continue if this.Us[n] is not None: core1 = torch.einsum('ijk,aj->iak', (core1, self.Us[n])) if other.Us[n] is not None: core2 = torch.einsum('ijk,aj->iak', (core2, other.Us[n])) column1 = torch.cat([core1, torch.zeros([core2.shape[0], this.shape[n], core1.shape[2]], device=core1.device)], dim=0) column2 = torch.cat([torch.zeros([core1.shape[0], this.shape[n], core2.shape[2]], device=core2.device), core2], dim=0) c = torch.cat([column1, column2], dim=2) cores.append(c) Us.append(None) # First core should have first size 1 (if it's TT) if not (this.cores[0].dim() == 2 and other.cores[0].dim() == 2): cores[0] = torch.sum(cores[0], dim=0, keepdim=True) # Similarly for the last core and last size if not (this.cores[-1].dim() == 2 and other.cores[-1].dim() == 2): cores[-1] = torch.sum(cores[-1], dim=2, keepdim=True) # Set up cores that should be CP cores for n in range(0, this.dim()): if this.cores[n].dim() == 2 and other.cores[n].dim() == 2: cores[n] = torch.sum(cores[n], dim=0, keepdim=False) return Tensor(cores, Us=Us) def __radd__(self, other): if other is None: return self return self + other def __sub__(self, other): return self + -1*other def __rsub__(self, other): return -1*self + other def __neg__(self): return -1*self def __mul__(self, other): if not isinstance(other, Tensor): # A scalar result = self.clone() result.cores[0].data *= other return result this, other = _broadcast(self, other) cores = [] Us = [] for n in range(this.dim()): core1 = this.cores[n] core2 = other.cores[n] # CP + CP -> CP, other combinations -> TT if core1.dim() == 2 and core2.dim() == 2: core1 = core1[None, :, :] core2 = core2[None, :, :] else: core1 = this._cp_to_tt(core1) core2 = this._cp_to_tt(core2) # We do the product core along 3 axes, unless it would blow up if this.Us[n] is not None and other.Us[n] is not None and this.cores[n].shape[1]*other.cores[n].shape[1] < this.shape[n]: cores.append(torch.reshape(torch.einsum('ijk,abc->iajbkc', (core1, core2)), (core1.shape[0]*core2.shape[0], core1.shape[1]*core2.shape[1], core1.shape[2]*core2.shape[2]))) Us.append(torch.reshape(torch.einsum('ij,ik->ijk', (this.Us[n], other.Us[n])), (this.Us[n].shape[0], -1))) else: # Decompress spatially, then do normal TT-TT slice-wise kronecker product if this.Us[n] is not None: core1 = torch.einsum('ijk,aj->iak', (core1, this.Us[n])) if other.Us[n] is not None: core2 = torch.einsum('ijk,aj->iak', (core2, other.Us[n])) cores.append(_core_kron(core1, core2)) Us.append(None) if this.cores[n].dim() == 2 and other.cores[n].dim() == 2: cores[-1] = cores[-1][0, :, :] return tn.Tensor(cores, Us=Us) def __truediv__(self, other): return tn.cross(function=lambda x, y: x / y, tensors=[self, other], verbose=False) def __rtruediv__(self, other): return tn.cross(function=lambda x, y: x / y, tensors=[tn.full_like(self, fill_value=other), self], verbose=False) def __pow__(self, power): return tn.cross(function=lambda x, y: x**y, tensors=[self, tn.full_like(self, fill_value=power)], verbose=False) """ Boolean logic """ def __rmul__(self, other): return self * other def __truediv__(self, other): return self * (1./other) def __invert__(self): return 1 - self def __and__(self, other): return self*other def __or__(self, other): return self+other - self*other def __xor__(self, other): return self+other - 2*self*other def __eq__(self, other): return tn.dist(self, other) <= 1e-14 def __ne__(self, other): return not self == other """ Shapes and 
ranks """ @property def shape(self): """ Returns the shape of this tensor. :return: a PyTorch shape object """ shape = [] for n in range(self.dim()): if self.Us[n] is None: shape.append(self.cores[n].shape[-2]) else: shape.append(self.Us[n].shape[0]) return torch.Size(shape) @property def ranks_tt(self): """ Returns the TT ranks of this tensor. :return: a vector of integers """ if self.cores[0].dim() == 2: first = self.cores[0].shape[1] else: first = self.cores[0].shape[0] return np.array([first] + [c.shape[-1] for c in self.cores]) @ranks_tt.setter def ranks_tt(self, value): self.round_tt(rmax=value) @property def ranks_tucker(self): """ Returns the Tucker ranks of this tensor. :return: a vector of integers """ return np.array([c.shape[-2] for c in self.cores]) @ranks_tucker.setter def ranks_tucker(self, value): self.round_tucker(rmax=value)
    def dim(self):
        """
        Returns the number of dimensions of this tensor.

        :return: an int
        """

        return len(self.cores)
    def size(self):
        """
        Alias for :meth:`shape` (as PyTorch does)
        """

        return self.shape
def __repr__(self): format = [] if any([c.dim() == 3 for c in self.cores]): format.append('TT') if any([c.dim() == 2 for c in self.cores]): format.append('CP') if any([U is not None for U in self.Us]): format.append('Tucker') format = '-'.join(format) s = '{}D {} tensor:\n'.format(self.dim(), format) s += '\n' ttr = self.ranks_tt tuckerr = self.ranks_tucker if any([U is not None for U in self.Us]): # Shape row = [' ']*(4*self.dim()-1) shape = self.shape for n in range(self.dim()): if self.Us[n] is None: continue lenn = len('{}'.format(shape[n])) row[n*4-lenn//2+2:n*4-lenn//2+lenn+2] = '{}'.format(shape[n]) s += ''.join(row) s += '\n' # Tucker ranks row = [' ']*(4*self.dim()-1) for n in range(self.dim()): if self.Us[n] is None: lenr = len('{}'.format(tuckerr[n])) row[n*4-lenr//2+2:n*4-lenr//2+lenr+2] = '{}'.format(tuckerr[n]) else: row[n*4+2:n*4+3] = '|' s += ''.join(row) s += '\n' row = [' ']*(4*self.dim()-1) for n in range(self.dim()): if self.Us[n] is None: row[n*4+2:n*4+3] = '|' else: lenr = len('{}'.format(tuckerr[n])) row[n*4-lenr//2+2:n*4-lenr//2+lenr+2] = '{}'.format(tuckerr[n]) s += ''.join(row) s += '\n' # Nodes row = [' ']*(4*self.dim()-1) for n in range(self.dim()): if self.cores[n].dim() == 2: nodestr = '<{}>'.format(n) else: nodestr = '({})'.format(n) lenn = len(nodestr) row[(n+1)*4-(lenn-1)//2:(n+1)*4-(lenn-1)//2+lenn] = nodestr s += ''.join(row[2:]) s += '\n' # TT rank bars s += ' / \\'*self.dim() s += '\n' # Bottom: TT/CP ranks row = [' ']*(4*self.dim()) for n in range(self.dim()+1): lenr = len('{}'.format(ttr[n])) row[n*4:n*4+lenr] = '{}'.format(ttr[n]) s += ''.join(row) s += '\n' return s """ Decompression """ def _process_key(self, key): if not hasattr(key, '__len__'): key = (key,) fancy = False if isinstance(key, torch.Tensor): key = key.detach().numpy() if any([not np.isscalar(k) for k in key]): # Fancy key = list(key) fancy = True if isinstance(key, tuple): key = list(key) elif not fancy: key = [key] # Process ellipsis, if any nonecount = sum(1 for k in key if k is None) for i in range(len(key)): if key[i] is Ellipsis: key = key[:i] + [slice(None)] * (self.dim() - (len(key) - nonecount) + 1) + key[i + 1:] break if any([k is Ellipsis for k in key]): raise IndexError("Only one ellipsis is allowed, at most") if self.dim() - (len(key) - nonecount) < 0: raise IndexError("Too many index entries") # Fill remaining unspecified dimensions with slice(None) key = key + [slice(None)] * (self.dim() - (len(key) - nonecount)) return key def __getitem__(self, key): """ NumPy-style indexing for compressed tensors. There are 5 accessors supported: slices, index arrays, integers, None, or another Tensor (selection via binary indexing) - Index arrays can be lists, tuples, or vectors - All index arrays must have the same length P - In NumPy, index arrays and slices can be interleaved. 
We do not admit that, as it requires expensive transpose operations """ # Preprocessing if isinstance(key, Tensor): if torch.abs(tn.sum(key)-1) > 1e-8: raise ValueError("When indexing via a mask tensor, that mask should have exactly 1 accepting string") s = tn.accepted_inputs(key)[0] slicing = [] for n in range(self.dim()): idx = self.idxs[n].long() idx[idx > 1] = 1 idx = np.where(idx == s[n])[0] sl = slice(idx[0], idx[-1]+1) lenidx = len(idx) if lenidx == 1: sl = idx.item() slicing.append(sl) return self[slicing] if isinstance(key, torch.Tensor): key = np.array(key, dtype=np.int) if isinstance(key, np.ndarray) and key.ndim == 2: key = [key[:, col] for col in range(key.shape[1])] key = self._process_key(key) last_mode = None factors = {'int': None, 'index': None, 'index_done': False} cores = [] Us = [] counter = 0 def join_cores(c1, c2): if c1.dim() == 1 and c2.dim() == 2: return torch.einsum('i,ai->ai', (c1, c2)) elif c1.dim() == 2 and c2.dim() == 2: return torch.einsum('ij,aj->iaj', (c1, c2)) elif c1.dim() == 1 and c2.dim() == 3: return torch.einsum('i,iaj->iaj', (c1, c2)) elif c1.dim() == 2 and c2.dim() == 3: return torch.einsum('ij,jak->iak', (c1, c2)) else: raise ValueError def insert_core(factors, core=None, key=None, U=None): if factors['index'] is not None: if factors['int'] is not None: factors['index'] = join_cores(factors['int'], factors['index']) factors['int'] = None cores.append(factors['index']) Us.append(None) factors['index'] = None factors['index_done'] = True if core is not None: if factors['int'] is not None: # There is a previous 1D/2D core (CP/Tucker) from an integer slicing if U is None: cores.append(join_cores(factors['int'], core[..., key, :])) Us.append(None) else: cores.append(join_cores(factors['int'], core)) Us.append(U[key, :]) factors['int'] = None else: # Easiest case if U is None: cores.append(core[..., key, :]) Us.append(None) else: cores.append(core) Us.append(U[key, :]) def get_key(counter, key): if self.Us[counter] is None: return self.cores[counter][..., key, :] else: sl = self.Us[counter][key, :] if sl.dim() == 1: # key is an int if self.cores[counter].dim() == 3: return torch.einsum('ijk,j->ik', (self.cores[counter], sl)) else: return torch.einsum('ji,j->i', (self.cores[counter], sl)) else: if self.cores[counter].dim() == 3: return torch.einsum('ijk,aj->iak', (self.cores[counter], sl)) else: return torch.einsum('ji,aj->ai', (self.cores[counter], sl)) for i in range(len(key)): if hasattr(key[i], '__len__'): this_mode = 'index' elif key[i] is None: this_mode = 'none' elif isinstance(key[i], (int, np.integer)): this_mode = 'int' elif isinstance(key[i], slice): this_mode = 'slice' else: raise IndexError if this_mode == 'none': insert_core(factors, torch.eye(self.ranks_tt[counter].item())[:, None, :], key=slice(None), U=None) elif this_mode == 'slice': insert_core(factors, self.cores[counter], key=key[i], U=self.Us[counter]) counter += 1 elif this_mode == 'index': if factors['index_done']: raise IndexError("All index arrays must appear contiguously") if factors['index'] is None: factors['index'] = get_key(counter, key[i]) else: if factors['index'].shape[-2] != len(key[i]): raise ValueError("Index arrays must have the same length") a1 = factors['index'] a2 = get_key(counter, key[i]) if a1.dim() == 2 and a2.dim() == 2: factors['index'] = torch.einsum('ai,ai->ai', (a1, a2)) elif a1.dim() == 2 and a2.dim() == 3: factors['index'] = torch.einsum('ai,iaj->iaj', (a1, a2)) elif a1.dim() == 3 and a2.dim() == 2: factors['index'] = torch.einsum('iaj,aj->iaj', (a1, 
a2)) elif a1.dim() == 3 and a2.dim() == 3: # Until https://github.com/pytorch/pytorch/issues/10661 is fully resolved # TODO check efficiency for other cases factors['index'] = torch.sum(a1[:, :, :, None]*a2.permute(1, 0, 2)[None, :, :, :], dim=2) # factors['index'] = torch.einsum('iaj,jak->iak', (a1, a2)) counter += 1 elif this_mode == 'int': if last_mode == 'index': insert_core(factors) if factors['int'] is None: factors['int'] = get_key(counter, key[i]) else: c1 = factors['int'] c2 = get_key(counter, key[i]) if c1.dim() == 1 and c2.dim() == 1: factors['int'] = torch.einsum('i,i->i', (c1, c2)) elif c1.dim() == 1 and c2.dim() == 2: factors['int'] = torch.einsum('i,ij->ij', (c1, c2)) elif c1.dim() == 2 and c2.dim() == 1: factors['int'] = torch.einsum('ij,j->ij', (c1, c2)) elif c1.dim() == 2 and c2.dim() == 2: factors['int'] = torch.einsum('ij,jk->ik', (c1, c2)) counter += 1 last_mode = this_mode # At the end: handle possibly pending factors if last_mode == 'index': insert_core(factors, core=None, key=None, U=None) elif last_mode == 'int': if len(cores) > 0: # We return a tensor: absorb existing cores with int factor if cores[-1].dim() == 2 and factors['int'].dim() == 1: cores[-1] = torch.einsum('ai,i->ai', (cores[-1], factors['int'])) elif cores[-1].dim() == 2 and factors['int'].dim() == 2: cores[-1] = torch.einsum('ai,ij->iaj', (cores[-1], factors['int'])) elif cores[-1].dim() == 3 and factors['int'].dim() == 1: cores[-1] = torch.einsum('iaj,j->ai', (cores[-1], factors['int'])) elif cores[-1].dim() == 3 and factors['int'].dim() == 2: cores[-1] = torch.einsum('iaj,jk->iak', (cores[-1], factors['int'])) else: # We return a scalar if factors['int'].numel() > 1: return torch.sum(factors['int']) return torch.squeeze(factors['int']) return tn.Tensor(cores, Us=Us) def __setitem__(self, key, value): # TODO not fully working yet key = self._process_key(key) scalar = False if isinstance(value, np.ndarray): value = tn.Tensor(torch.Tensor(value)) elif isinstance(value, torch.Tensor): if value.dim() == 0: value = value.item() scalar = True # value = value*torch.ones(self.shape) else: value = tn.Tensor(value) elif isinstance(value, tn.Tensor): pass else: # It's a scalar scalar = True subtract_cores = [] add_cores = [] for i in range(len(key)): if not isinstance(key[i], slice) and not hasattr(key[i], '__len__'): key[i] = slice(key[i], key[i]+1) chunk = self.cores[i][..., key[i], :] subtract_core = torch.zeros_like(self.cores[i]) subtract_core[..., key[i], :] += chunk subtract_cores.append(subtract_core) if scalar: if self.cores[i].dim() == 3: add_core = torch.zeros(1, self.shape[i], 1) else: add_core = torch.zeros(self.shape[i], 1) add_core[..., key[i], :] += 1 if i == 0: add_core *= value else: if chunk.shape[1] != value.shape[i]: raise ValueError('{}-th dimension mismatch in tensor assignment: {} (lhs) != {} (rhs)'.format(i, chunk.shape[1], value.shape[i])) if self.cores[i].dim() == 3: add_core = torch.zeros(value.cores[i].shape[0], self.shape[i], value.cores[i].shape[2]) else: add_core = torch.zeros(self.shape[i], value.cores[i].shape[1]) add_core[..., key[i], :] += value.cores[i] add_cores.append(add_core) # print(tn.Tensor(subtract_cores), tn.Tensor(add_cores)) result = self - tn.Tensor(subtract_cores) + tn.Tensor(add_cores) self.__init__(result.cores, result.Us, self.idxs)
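
A sketch of the NumPy-style accessors implemented above, on an uncompressed tensor so results can be checked against plain indexing (values shown are expected up to floating-point error):

X = torch.randn(5, 6, 7)
t = tn.Tensor(X)
print(torch.dist(t[1, :, 2].torch(), X[1, :, 2]))   # slices and integers: ~0
print(t[0, 2, 3], X[0, 2, 3])                       # full integer indexing returns a scalar
print(torch.dist(t[[0, 1], [2, 3], [4, 5]].torch(),
                 X[[0, 1], [2, 3], [4, 5]]))        # index arrays of equal length: ~0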
    def tucker_core(self):
        """
        If this is a Tucker-like tensor, returns its Tucker core as an explicit PyTorch tensor.

        If this tensor does not have Tucker factors, then it returns the full decompressed tensor.

        :return: a PyTorch tensor
        """

        return tn.Tensor(self.cores).torch()
    def decompress_tucker_factors(self, dim='all', _clone=True):
        """
        Decompresses this tensor along the Tucker factors only.

        :param dim: int, list, or 'all' (default)

        :return: a :class:`Tensor` in CP/TT format, without Tucker factors
        """

        if dim == 'all':
            dim = range(self.dim())
        if not hasattr(dim, '__len__'):
            dim = [dim]*self.dim()

        cores = []
        Us = []
        for n in range(self.dim()):
            if n in dim and self.Us[n] is not None:
                if self.cores[n].dim() == 2:
                    cores.append(torch.einsum('jk,aj->ak', (self.cores[n], self.Us[n])))
                else:
                    cores.append(torch.einsum('ijk,aj->iak', (self.cores[n], self.Us[n])))
                Us.append(None)
            else:
                if _clone:
                    cores.append(self.cores[n].clone())
                    if self.Us[n] is not None:
                        Us.append(self.Us[n].clone())
                    else:
                        Us.append(None)
                else:
                    cores.append(self.cores[n])
                    Us.append(self.Us[n])
        return tn.Tensor(cores, Us, idxs=self.idxs)
    def tt(self):
        """
        Casts this tensor as a pure TT format.

        :return: a :class:`Tensor` in the TT format
        """

        t = self.decompress_tucker_factors()
        t._cp_to_tt()
        return t
    def torch(self):
        """
        Decompresses this tensor into a PyTorch tensor.

        :return: a PyTorch tensor
        """

        t = self.decompress_tucker_factors(_clone=False)
        shape = []
        device = t.cores[0].device
        factor = torch.ones(1, self.ranks_tt[0]).to(device)
        for n in range(t.dim()):
            shape.append(t.cores[n].shape[-2])
            if t.cores[n].dim() == 2:  # CP core
                if n < t.dim() - 1:
                    factor = torch.einsum('ai,bi->abi', (factor, t.cores[n]))
                else:
                    factor = torch.einsum('ai,bi->ab', (factor, t.cores[n]))[..., None]
            else:  # TT core
                factor = torch.einsum('ai,ibj->abj', (factor, t.cores[n]))
            factor = factor.reshape([-1, factor.shape[-1]])
        if factor.shape[-1] > 1:
            factor = torch.sum(factor, dim=-1)
        else:
            factor = factor[..., 0]
        factor = factor.reshape(shape)
        return factor
    def numpy(self):
        """
        Decompresses this tensor into a NumPy ndarray.

        :return: a NumPy tensor
        """

        return self.torch().detach().numpy()
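
A round-trip sketch: torch() and numpy() fully decompress the tensor, so they should only be called when the uncompressed size is moderate.

X = torch.arange(24.).reshape(2, 3, 4)
t = tn.Tensor(X)
print(torch.dist(t.torch(), X))          # ~0
print(type(t.numpy()), t.numpy().shape)  # <class 'numpy.ndarray'> (2, 3, 4)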
    def _cp_to_tt(self, factor=None):
        """
        Turn a CP factor into a TT core (each slice is a diagonal matrix)

        :param factor: a 2D CP factor matrix. If None, all cores of this tensor will be converted in place
        """

        if factor is None:
            if self.cores[0].dim() == 2:
                self.cores[0] = self.cores[0][None, :, :]
            for mu in range(1, self.dim()-1):
                self.cores[mu] = self._cp_to_tt(self.cores[mu])
            if self.cores[-1].dim() == 2:
                self.cores[-1] = self.cores[-1].transpose(1, 0)[:, :, None]
            return
        if factor.dim() == 3:  # Already a TT core
            return factor
        core = torch.zeros(factor.shape[1], factor.shape[1] + 1, factor.shape[0])
        core[:, 0, :] = factor.t()
        return core.reshape(factor.shape[1] + 1, factor.shape[1], factor.shape[0]).permute(0, 2, 1)[:-1, :, :]

    """
    Rounding and orthogonalization
    """
    def factor_orthogonalize(self, mu):
        """
        Pushes the factor's non-orthogonal part to its corresponding core.

        This method works in place.

        :param mu: an int between 0 and N-1
        """

        if self.Us[mu] is None:
            return
        Q, R = torch.qr(self.Us[mu])
        self.Us[mu] = Q
        if self.cores[mu].dim() == 2:
            self.cores[mu] = torch.einsum('jk,aj->ak', (self.cores[mu], R))
        else:
            self.cores[mu] = torch.einsum('ijk,aj->iak', (self.cores[mu], R))
    def left_orthogonalize(self, mu):
        """
        Makes the mu-th core left-orthogonal and pushes the R factor to its right core. This may change the ranks
        of the cores.

        This method works in place.

        Note: internally, this method will turn CP (or CP-Tucker) cores into TT (or TT-Tucker) ones.

        :param mu: an int between 0 and N-1

        :return: the R factor
        """

        assert 0 <= mu < self.dim()-1
        self.factor_orthogonalize(mu)
        Q, R = torch.qr(tn.left_unfolding(self.cores[mu]))
        self.cores[mu] = torch.reshape(Q, self.cores[mu].shape[:-1] + (Q.shape[1], ))
        rightcoreR = tn.right_unfolding(self.cores[mu+1])
        self.cores[mu+1] = torch.reshape(torch.mm(R, rightcoreR), (R.shape[0], ) + self.cores[mu+1].shape[1:])
        return R
    def right_orthogonalize(self, mu):
        """
        Makes the mu-th core right-orthogonal and pushes the L factor to its left core. This may change the ranks
        of the cores.

        This method works in place.

        Note: internally, this method will turn CP (or CP-Tucker) cores into TT (or TT-Tucker) ones.

        :param mu: an int between 0 and N-1

        :return: the L factor
        """

        assert 1 <= mu < self.dim()
        self.factor_orthogonalize(mu)
        Q, L = torch.qr(tn.right_unfolding(self.cores[mu]).permute(1, 0))  # Torch has no rq() decomposition
        L = L.permute(1, 0)
        Q = Q.permute(1, 0)
        self.cores[mu] = torch.reshape(Q, (Q.shape[0], ) + self.cores[mu].shape[1:])
        leftcoreL = tn.left_unfolding(self.cores[mu-1])
        self.cores[mu-1] = torch.reshape(torch.mm(leftcoreL, L), self.cores[mu-1].shape[:-1] + (L.shape[1], ))
        return L
    def orthogonalize(self, mu):
        """
        Applies all left and right orthogonalizations needed to make the tensor mu-orthogonal.

        This method works in place.

        Note: internally, this method will turn CP (or CP-Tucker) cores into TT (or TT-Tucker) ones.

        :param mu: an int between 0 and N-1

        :return: L, R: left and right factors
        """

        if mu < 0:
            mu += self.dim()
        self._cp_to_tt()
        L = torch.ones(1, 1)
        R = torch.ones(1, 1)
        for i in range(0, mu):
            R = self.left_orthogonalize(i)
        for i in range(self.dim()-1, mu, -1):
            L = self.right_orthogonalize(i)
        return R, L
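
A sketch of mu-orthogonality, using only methods defined in this module: after orthogonalize(mu), all other cores are orthogonal, so the norm of the whole tensor is concentrated in core mu (up to floating-point error).

t = tn.Tensor(torch.randn(5, 5, 5), ranks_tt=3)
t.orthogonalize(1)
print(torch.norm(t.cores[1]))   # should match tn.norm(t)
print(tn.norm(t))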
    def round_tucker(self, eps=1e-14, rmax=None, dim='all', algorithm='svd'):
        """
        Tries to recompress this tensor in place by reducing its Tucker ranks.

        Note: this method will turn CP (or CP-Tucker) cores into TT (or TT-Tucker) ones.

        :param eps: this relative error will not be exceeded
        :param rmax: all ranks should be rmax at most (default: no limit)
        :param dim: int, list, or 'all' (default): the dimensions to round
        :param algorithm: 'svd' (default) or 'eig'. The latter can be faster, but less accurate
        """

        N = self.dim()
        if not hasattr(rmax, '__len__'):
            rmax = [rmax]*N
        assert len(rmax) == N
        if dim == 'all':
            dim = range(N)
        if not hasattr(dim, '__len__'):
            dim = [dim]*N
        for m in dim:
            self.cores[m] = self._cp_to_tt(self.cores[m])
        self.orthogonalize(-1)
        for mu in range(N-1, -1, -1):
            if self.Us[mu] is None:
                device = self.cores[mu].device
                self.Us[mu] = torch.eye(self.shape[mu]).to(device)
            # Send non-orthogonality to factor
            Q, R = torch.qr(torch.reshape(self.cores[mu].permute(0, 2, 1), [-1, self.cores[mu].shape[1]]))
            self.cores[mu] = torch.reshape(Q, [self.cores[mu].shape[0], self.cores[mu].shape[2], -1]).permute(0, 2, 1)
            self.Us[mu] = torch.matmul(self.Us[mu], R.t())
            # Split factor according to error budget
            left, right = tn.truncated_svd(self.Us[mu], eps=eps/np.sqrt(len(dim)), rmax=rmax[mu], left_ortho=True, algorithm=algorithm)
            self.Us[mu] = left
            # Push the (non-orthogonal) remainder to the core
            self.cores[mu] = torch.einsum('ijk,aj->iak', (self.cores[mu], right))
            # Prepare next iteration
            if mu > 0:
                self.right_orthogonalize(mu)
    def round_tt(self, eps=1e-14, rmax=None, algorithm='svd', verbose=False):
        """
        Tries to recompress this tensor in place by reducing its TT ranks.

        Note: this method will turn CP (or CP-Tucker) cores into TT (or TT-Tucker) ones.

        :param eps: this relative error will not be exceeded
        :param rmax: all ranks should be rmax at most (default: no limit)
        :param algorithm: 'svd' (default) or 'eig'. The latter can be faster, but less accurate
        :param verbose: Boolean
        """

        N = self.dim()
        if not hasattr(rmax, '__len__'):
            rmax = [rmax]*(N-1)
        assert len(rmax) == N-1
        self._cp_to_tt()
        start = time.time()
        self.orthogonalize(N-1)  # Make everything left-orthogonal
        if verbose:
            print('Orthogonalization time:', time.time() - start)
        delta = eps/max(1, torch.sqrt(torch.Tensor([N-1])))*torch.norm(self.cores[-1])
        delta = delta.item()
        for mu in range(N - 1, 0, -1):
            M = tn.right_unfolding(self.cores[mu])
            left, right = tn.truncated_svd(M, delta=delta, rmax=rmax[mu-1], left_ortho=False, algorithm=algorithm, verbose=verbose)
            self.cores[mu] = torch.reshape(right, [-1, self.cores[mu].shape[1], self.cores[mu].shape[2]])
            self.cores[mu-1] = torch.einsum('ijk,kl', (self.cores[mu-1], left))  # Pass factor to the left
    def round(self, eps=1e-14, **kwargs):
        """
        General recompression. Attempts to reduce TT ranks first; then does Tucker rounding with the remaining
        error budget.

        :param eps: this relative error will not be exceeded
        :param kwargs: passed to `round_tt()` and `round_tucker()`
        """

        copy = self.clone()
        self.round_tt(eps, **kwargs)
        reached = tn.relative_error(copy, self)
        if reached < eps:
            self.round_tucker((1+eps) / (1+reached) - 1, **kwargs)
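
A rounding sketch on a compressible target (an additive function, which has low TT ranks); the printed ranks are what the truncation should find, and the final error stays on the order of the requested eps.

x = torch.linspace(0, 1, 32)
X = x[:, None, None] + x[None, :, None]**2 + torch.sin(x)[None, None, :]   # sum of separable terms
t = tn.Tensor(X)             # exact, full-rank TT
print(t.ranks_tt)            # [ 1 32 32  1]
t.round(1e-8)
print(t.ranks_tt)            # should drop to [1 2 2 1] for an additive function
print(tn.relative_error(tn.Tensor(X), t))   # ~1e-8 or below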
""" Convenience "methods" """
    def dot(self, other, **kwargs):
        """ See :func:`metrics.dot()`. """
        return tn.dot(self, other, **kwargs)

    def mean(self, **kwargs):
        """ See :func:`metrics.mean()`. """
        return tn.mean(self, **kwargs)

    def sum(self, **kwargs):
        """ See :func:`metrics.sum()`. """
        return tn.sum(self, **kwargs)

    def var(self, **kwargs):
        """ See :func:`metrics.var()`. """
        return tn.var(self, **kwargs)

    def std(self, **kwargs):
        """ See :func:`metrics.std()`. """
        return tn.std(self, **kwargs)

    def norm(self, **kwargs):
        """ See :func:`metrics.norm()`. """
        return tn.norm(self, **kwargs)

    def normsq(self, **kwargs):
        """ See :func:`metrics.normsq()`. """
        return tn.normsq(self, **kwargs)
""" Miscellaneous """
    def set_factors(self, name, dim='all', requires_grad=False):
        """
        Sets factors Us of this tensor to be of a certain family.

        :param name: see :func:`tools.generate_basis()`
        :param dim: list of factors to set; default is 'all'
        :param requires_grad: whether the new factors should be optimizable. Default is False
        """

        if dim == 'all':
            dim = range(self.dim())
        for m in dim:
            if self.Us[m] is None:
                self.Us[m] = tn.generate_basis(name, (self.shape[m], self.shape[m]))
            else:
                self.Us[m] = tn.generate_basis(name, self.Us[m].shape)
            self.Us[m].requires_grad = requires_grad
    def as_leaf(self):
        """
        Makes this tensor a leaf (optimizable) tensor, thus forgetting the operations from which it arose.

        :Example:

        >>> t = tn.rand([10]*3, requires_grad=True)  # Is a leaf
        >>> t *= 2  # Is not a leaf
        >>> t.as_leaf()  # Is a leaf again
        """

        for n in range(self.dim()):
            if self.Us[n] is not None:
                if self.Us[n].requires_grad:
                    self.Us[n] = self.Us[n].detach().clone().requires_grad_()
                else:
                    self.Us[n] = self.Us[n].detach().clone()
            if self.cores[n].requires_grad:
                self.cores[n] = self.cores[n].detach().clone().requires_grad_()
            else:
                self.cores[n] = self.cores[n].detach().clone()
    def clone(self):
        """
        Creates a copy of this tensor (calls PyTorch's `clone()` on all internal tensor network nodes)

        :return: another compressed tensor
        """

        cores = [self.cores[n].clone() for n in range(self.dim())]
        Us = []
        for n in range(self.dim()):
            if self.Us[n] is None:
                Us.append(None)
            else:
                Us.append(self.Us[n].clone())
        if hasattr(self, 'idxs'):
            return tn.Tensor(cores, Us=Us, idxs=self.idxs)
        return tn.Tensor(cores, Us=Us)
    def numel(self):
        """
        Counts the total number of uncompressed elements of this tensor.

        :return: an integer
        """

        return torch.prod(torch.Tensor(list(self.shape)))
    def numcoef(self):
        """
        Counts the total number of compressed coefficients of this tensor.

        :return: an integer
        """

        result = 0
        for n in range(self.dim()):
            result += self.cores[n].numel()
            if self.Us[n] is not None:
                result += self.Us[n].numel()
        return result
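
A small sketch comparing numel() and numcoef(), which gives the compression ratio achieved by the network (core shapes follow from the TT rounding above).

t = tn.Tensor(torch.randn(10, 10, 10), ranks_tt=3)
print(t.numel())                         # 1000 uncompressed entries (returned as a torch scalar)
print(t.numcoef())                       # 1*10*3 + 3*10*3 + 3*10*1 = 150 stored coefficients
print(float(t.numel()) / t.numcoef())    # compression ratio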
    def repeat(self, *rep):
        """
        Returns another tensor repeated along one or more axes; works like PyTorch's `repeat()`.

        :param rep: a list, possibly longer than the tensor's number of dimensions

        :return: another tensor
        """

        assert len(rep) >= self.dim()
        assert all([r >= 1 for r in rep])
        t = self.clone()
        if len(rep) > self.dim():  # If requested, add trailing new dimensions; we use CP cores since they are cheaper
            for n in range(self.dim(), len(rep)):
                t.cores.append(torch.ones(rep[n], self.cores[-1].shape[-1]))
                t.Us.append(None)
        for n in range(self.dim()):
            if t.Us[n] is not None:
                t.Us[n] = t.Us[n].repeat(rep[n], 1)
            else:
                if t.cores[n].dim() == 3:
                    t.cores[n] = t.cores[n].repeat(1, rep[n], 1)
                else:
                    t.cores[n] = t.cores[n].repeat(rep[n], 1)
        return t
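
A usage sketch of repeat(): modes are tiled like PyTorch's repeat, and extra trailing arguments append new (constant, CP-format) modes.

t = tn.Tensor(torch.randn(3, 4))
print(t.repeat(2, 1).shape)      # torch.Size([6, 4])
print(t.repeat(1, 1, 5).shape)   # torch.Size([3, 4, 5]): a new trailing mode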
def _broadcast(a, b):
    if a.shape == b.shape:
        return a, b
    elif a.dim() != b.dim():
        raise ValueError('Cannot broadcast: lhs has {} dimensions, rhs has {}'.format(a.dim(), b.dim()))
    result1 = a.repeat(*[int(round(max(sh2/sh1, 1))) for sh1, sh2 in zip(a.shape, b.shape)])
    result2 = b.repeat(*[int(round(max(sh1/sh2, 1))) for sh1, sh2 in zip(a.shape, b.shape)])
    return result1, result2


def _core_kron(a, b):
    # torch.reshape(torch.einsum('iaj,kal->ikajl', (a, b)), [a.shape[0]*b.shape[0], -1, a.shape[2]*b.shape[2]])  # Seems slower
    c = a[:, None, :, :, None] * b[None, :, :, None, :]
    c = c.reshape([a.shape[0]*b.shape[0], -1, a.shape[-1]*b.shape[-1]])
    return c
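
A broadcasting sketch: _broadcast() tiles singleton modes so that the elementwise operators above can combine tensors of compatible shapes (the number of dimensions must match).

a = tn.Tensor(torch.randn(4, 1, 6))
b = tn.Tensor(torch.randn(4, 5, 6))
print((a + b).shape)   # torch.Size([4, 5, 6]): the singleton mode of `a` was tiled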