Source code for tntorch.tools

import tntorch as tn
import torch
import numpy as np
import time
import scipy.fftpack


"""
Array-like manipulations
"""


def squeeze(t, dim=None):
    """
    Removes singleton dimensions.

    :param t: input :class:`Tensor`
    :param dim: which dim to delete. By default, all that have size 1

    :return: another :class:`Tensor`, without dummy (singleton) indices
    """

    if dim is None:
        dim = np.where([s == 1 for s in t.shape])[0]
    if not hasattr(dim, '__len__'):
        dim = [dim]

    assert np.all(np.array(t.shape)[dim] == 1)

    idx = [slice(None) for n in range(t.dim())]
    for m in dim:
        idx[m] = 0
    return t[tuple(idx)]
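
# Usage (a minimal sketch; `tn.rand`, tntorch's random TT constructor, is assumed):
#
# >>> t = tn.rand([1, 5, 1, 6], ranks_tt=3)
# >>> tn.squeeze(t)         # drops both singleton modes -> a 5 x 6 tensor
# >>> tn.squeeze(t, dim=0)  # drops only the first mode -> a 5 x 1 x 6 tensor
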
def unsqueeze(t, dim):
    """
    Inserts singleton dimensions at specified positions.

    :param t: input :class:`Tensor`
    :param dim: int or list of int

    :return: a :class:`Tensor` with dummy (singleton) dimensions inserted at the positions given by `dim`
    """

    if not hasattr(dim, '__len__'):
        dim = [dim]

    idx = [slice(None) for n in range(t.dim()+len(dim))]
    for d in dim:
        idx[d] = None
    return t[tuple(idx)]
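
# Usage (sketch, same assumptions as above):
#
# >>> t = tn.rand([5, 6], ranks_tt=3)
# >>> tn.unsqueeze(t, dim=[0, 3])  # -> a 1 x 5 x 6 x 1 tensor
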
def cat(*ts, dim):
    """
    Concatenate two or more tensors along a given dim, similarly to PyTorch's `cat()`.

    :param ts: a list of :class:`Tensor`
    :param dim: an int

    :return: a :class:`Tensor` of the same shape as all tensors in the list, except along `dim`, where it has the sum of shapes
    """

    if hasattr(ts[0], '__len__'):
        ts = ts[0]
    if len(ts) == 1:
        return ts[0].clone()
    if any([any([t.shape[n] != ts[0].shape[n] for n in np.delete(range(ts[0].dim()), dim)]) for t in ts[1:]]):
        raise ValueError('To concatenate tensors, all must have the same shape along all but the given dim')

    shapes = np.array([t.shape[dim] for t in ts])
    sumshapes = np.concatenate([np.array([0]), np.cumsum(shapes)])
    for i in range(len(ts)):
        t = ts[i].clone()
        if t.Us[dim] is None:
            if t.cores[dim].dim() == 2:
                t.cores[dim] = torch.zeros(sumshapes[-1], t.cores[dim].shape[-1])
            else:
                t.cores[dim] = torch.zeros(t.cores[dim].shape[0], sumshapes[-1], t.cores[dim].shape[-1])
            t.cores[dim][..., sumshapes[i]:sumshapes[i+1], :] += ts[i].cores[dim]
        else:
            t.Us[dim] = torch.zeros(sumshapes[-1], t.Us[dim].shape[-1])
            t.Us[dim][sumshapes[i]:sumshapes[i+1], :] += ts[i].Us[dim]
        if i == 0:
            result = t
        else:
            result += t
    return result
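
# Usage (sketch): concatenating two TT tensors along mode 1.
#
# >>> t1 = tn.rand([4, 3, 5], ranks_tt=2)
# >>> t2 = tn.rand([4, 6, 5], ranks_tt=2)
# >>> tn.cat(t1, t2, dim=1)  # -> a 4 x 9 x 5 tensor
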
def transpose(t):
    """
    Inverts the dimension order of a tensor, e.g. :math:`I_1 \\times I_2 \\times I_3` becomes :math:`I_3 \\times I_2 \\times I_1`.

    :param t: input tensor

    :return: another :class:`Tensor`, indexed by dimensions in inverse order
    """

    cores = []
    Us = []
    idxs = []
    for n in range(t.dim()-1, -1, -1):
        if t.cores[n].dim() == 3:
            cores.append(t.cores[n].permute(2, 1, 0))
        else:
            cores.append(t.cores[n])
        if t.Us[n] is None:
            Us.append(None)
        else:
            Us.append(t.Us[n].clone())
        try:
            idxs.append(t.idxs[n].clone())
        except Exception:
            idxs.append(None)
    return tn.Tensor(cores, Us, idxs)
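
# Usage (sketch):
#
# >>> t = tn.rand([2, 3, 4], ranks_tt=2)
# >>> tn.transpose(t)  # -> a 4 x 3 x 2 tensor
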
def meshgrid(*axes):
    """
    See NumPy's or PyTorch's `meshgrid()`.

    :param axes: a list of N ints or torch vectors

    :return: a list of N :class:`Tensor`, of N dimensions each
    """

    if not hasattr(axes, '__len__'):
        axes = [axes]
    if hasattr(axes[0], '__len__'):
        axes = axes[0]
    axes = list(axes)
    N = len(axes)
    for n in range(N):
        if not hasattr(axes[n], '__len__'):
            axes[n] = torch.arange(axes[n], dtype=torch.get_default_dtype())

    tensors = []
    for n in range(N):
        cores = [torch.ones(1, len(ax), 1) for ax in axes]
        cores[n] = torch.Tensor(axes[n].to(torch.get_default_dtype()))[None, :, None]
        tensors.append(tn.Tensor(cores))
    return tensors
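
# Usage (sketch): a grid over two axes, given either as sizes or as a list of vectors.
#
# >>> X, Y = tn.meshgrid(3, 4)  # X[i, j] == i and Y[i, j] == j
# >>> X, Y = tn.meshgrid([torch.linspace(0, 1, 3), torch.linspace(0, 1, 4)])
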
def flip(t, dim):
    """
    Reverses the order of a tensor along one or several dimensions; see NumPy's or PyTorch's `flip()`.

    :param t: input :class:`Tensor`
    :param dim: an int or list of ints

    :return: another :class:`Tensor` of the same shape
    """

    if not hasattr(dim, '__len__'):
        dim = [dim]

    shape = t.shape
    result = t.clone()
    for d in dim:
        idx = np.arange(shape[d]-1, -1, -1)
        if result.Us[d] is not None:
            result.Us[d] = result.Us[d][idx, :]
        else:
            result.cores[d] = result.cores[d][..., idx, :]
    return result
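
# Usage (sketch):
#
# >>> t = tn.rand([5, 6], ranks_tt=2)
# >>> tn.flip(t, dim=0)       # reverses the first mode
# >>> tn.flip(t, dim=[0, 1])  # reverses both modes
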
def unbind(t, dim):
    """
    Slices a tensor along a dimension and returns the slices as a sequence, like PyTorch's `unbind()`.

    :param t: input :class:`Tensor`
    :param dim: an int

    :return: a list of :class:`Tensor`, as many as `t.shape[dim]`
    """

    if dim < 0:
        dim += t.dim()
    return [t[[slice(None)]*dim + [sl] + [slice(None)]*(t.dim()-1-dim)] for sl in range(t.shape[dim])]
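
# Usage (sketch):
#
# >>> t = tn.rand([4, 3, 5], ranks_tt=2)
# >>> slices = tn.unbind(t, dim=1)  # a list of three 4 x 5 tensors
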
def unfolding(data, n):
    """
    Computes the `n-th mode unfolding <https://epubs.siam.org/doi/pdf/10.1137/07070111X>`_ of a PyTorch tensor.

    :param data: a PyTorch tensor
    :param n: unfolding mode

    :return: a PyTorch matrix
    """

    return data.permute([n] + list(range(n)) + list(range(n + 1, data.dim()))).reshape([data.shape[n], -1])
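
# Usage (sketch, on a plain PyTorch tensor):
#
# >>> data = torch.rand(2, 3, 4)
# >>> tn.unfolding(data, 1).shape  # mode-1 unfolding -> torch.Size([3, 8])
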
def right_unfolding(core):
    """
    Computes the `right unfolding <https://epubs.siam.org/doi/pdf/10.1137/090752286>`_ of a 3D PyTorch tensor.

    :param core: a PyTorch tensor of shape :math:`I_1 \\times I_2 \\times I_3`

    :return: a PyTorch matrix of shape :math:`I_1 \\times I_2 I_3`
    """

    return core.reshape([core.shape[0], -1])

def left_unfolding(core):
    """
    Computes the `left unfolding <https://epubs.siam.org/doi/pdf/10.1137/090752286>`_ of a 3D PyTorch tensor.

    :param core: a PyTorch tensor of shape :math:`I_1 \\times I_2 \\times I_3`

    :return: a PyTorch matrix of shape :math:`I_1 I_2 \\times I_3`
    """

    return core.reshape([-1, core.shape[-1]])
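
# Usage (sketch): the two unfoldings of a TT core of shape R1 x I x R2.
#
# >>> core = torch.rand(2, 5, 3)
# >>> tn.right_unfolding(core).shape  # torch.Size([2, 15])
# >>> tn.left_unfolding(core).shape   # torch.Size([10, 3])
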
""" Multilinear algebra """
def ttm(t, U, dim=None, transpose=False):
    """
    `Tensor-times-matrix (TTM) <https://epubs.siam.org/doi/pdf/10.1137/07070111X>`_ along one or several dimensions.

    :param t: input :class:`Tensor`
    :param U: one or several factors
    :param dim: one or several dimensions (may be vectors or matrices). If None, the first len(U) dims are assumed
    :param transpose: if False (default) the contraction is performed along U's rows, else along its columns

    :return: transformed :class:`Tensor`
    """

    if not isinstance(U, (list, tuple)):
        U = [U]
    if dim is None:
        dim = range(len(U))
    if not hasattr(dim, '__len__'):
        dim = [dim]
    dim = list(dim)
    for i in range(len(dim)):
        if dim[i] < 0:
            dim[i] += t.dim()

    cores = []
    Us = []
    for n in range(t.dim()):
        if n in dim:
            if transpose:
                factor = U[dim.index(n)].t()
            else:
                factor = U[dim.index(n)]
            if factor.dim() == 1:
                factor = factor[None, :]
            if t.Us[n] is None:
                if t.cores[n].dim() == 3:
                    cores.append(torch.einsum('iak,ja->ijk', (t.cores[n], factor)))
                else:
                    cores.append(torch.einsum('ai,ja->ji', (t.cores[n], factor)))
                Us.append(None)
            else:
                cores.append(t.cores[n].clone())
                Us.append(torch.matmul(factor, t.Us[n]))
        else:
            cores.append(t.cores[n].clone())
            if t.Us[n] is None:
                Us.append(None)
            else:
                Us.append(t.Us[n].clone())
    return tn.Tensor(cores, Us=Us, idxs=t.idxs)
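
# Usage (sketch): contract a factor matrix along one mode.
#
# >>> t = tn.rand([4, 5, 6], ranks_tt=2)
# >>> U = torch.rand(10, 4)
# >>> tn.ttm(t, U, dim=0)  # maps the size-4 mode through U -> a 10 x 5 x 6 tensor
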
""" Miscellaneous """
def mask(t, mask):
    """
    Masks a tensor. Basically an element-wise product, but this function makes sure slices are matched according to their "meaning" (as annotated by the tensor's `idxs` field, if available).

    :param t: input :class:`Tensor`
    :param mask: a mask :class:`Tensor`

    :return: masked :class:`Tensor`
    """

    if not hasattr(t, 'idxs'):
        idxs = [np.arange(sh) for sh in t.shape]
    else:
        idxs = t.idxs
    cores = []
    Us = []
    for n in range(t.dim()):
        idx = np.array(idxs[n])
        idx[idx >= mask.shape[n]] = mask.shape[n]-1  # Clamp out-of-range indices
        if mask.Us[n] is None:
            cores.append(mask.cores[n][..., idx, :])
            Us.append(None)
        else:
            cores.append(mask.cores[n])
            Us.append(mask.Us[n][idx, :])
    mask = tn.Tensor(cores, Us)
    return t*mask
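
# Usage (a hedged sketch; builds a rank-1 indicator mask from raw TT cores by hand):
#
# >>> t = tn.rand([5, 5], ranks_tt=2)
# >>> c = torch.zeros(1, 5, 1)
# >>> c[0, :3, 0] = 1                          # keep only the first 3 slices
# >>> m = tn.Tensor([c, torch.ones(1, 5, 1)])  # rank-1 TT mask
# >>> tn.mask(t, m)                            # zeroes out slices 3 and 4 of mode 0
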
def sample(t, P=1):
    """
    Generate P points (with replacement) from a joint PDF distribution represented by a tensor. The tensor does not have to sum to 1 (it will be normalized internally).

    :param t: a :class:`Tensor`
    :param P: how many samples to draw (default: 1)

    :return Xs: an integer matrix of size :math:`P \\times N`
    """

    def from_matrix(M):
        """
        Treat each row of a matrix M as a PMF and select a column per row according to it
        """

        M /= torch.sum(M, dim=1)[:, None]  # Normalize row-wise
        M = np.hstack([np.zeros([M.shape[0], 1]), M])
        M = np.cumsum(M, axis=1)
        thresh = np.random.rand(M.shape[0])
        M -= thresh[:, np.newaxis]
        shiftand = np.logical_and(M[:, :-1] <= 0, M[:, 1:] > 0)  # Find where the sign switches
        return np.where(shiftand)[1]

    N = t.dim()
    tsum = tn.sum(t, dim=np.arange(N), keepdim=True).decompress_tucker_factors()
    Xs = torch.zeros([P, N])
    rights = [torch.ones(1)]
    for core in tsum.cores[::-1]:
        rights.append(torch.matmul(torch.sum(core, dim=1), rights[-1]))
    rights = rights[::-1]
    lefts = torch.ones([P, 1])

    t = t.decompress_tucker_factors()
    for mu in range(t.dim()):
        fiber = torch.einsum('ijk,k->ij', (t.cores[mu], rights[mu + 1]))
        per_point = torch.einsum('ij,jk->ik', (lefts, fiber))
        rows = from_matrix(per_point)
        Xs[:, mu] = torch.Tensor(rows)
        lefts = torch.einsum('ij,jik->ik', (lefts, t.cores[mu][:, rows, :]))

    return Xs
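
# Usage (sketch): draw 100 index vectors from an unnormalized, non-negative tensor.
#
# >>> pdf = tn.rand([10, 10, 10], ranks_tt=3)  # random TT cores here are non-negative
# >>> Xs = tn.sample(pdf, P=100)               # a 100 x 3 matrix of integer indices
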
def hash(t):
    """
    Computes a scalar that depends on the tensor entries (not on its internal compressed representation). We obtain it as :math:`\\langle T, W \\rangle`, where :math:`W` is a rank-1 tensor of weights selected at random (always the same seed).

    :return: a scalar
    """

    gen = torch.Generator()
    gen.manual_seed(0)
    cores = [torch.ones(1, 1, 1) for n in range(t.dim())]
    Us = [torch.rand([sh, 1], generator=gen) for sh in t.shape]
    w = tn.Tensor(cores, Us)
    return t.dot(w)
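
# Usage (sketch): the hash depends on the entries, not on the representation's ranks.
#
# >>> t = tn.rand([4, 4, 4], ranks_tt=3)
# >>> tn.hash(t)
# >>> tn.hash(tn.round(t))  # approximately equal, up to floating-point error
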
def generate_basis(name, shape, orthonormal=False):
    """
    Generate a factor matrix whose columns are functions of a truncated basis.

    :param name: 'dct', 'identity', 'legendre', 'chebyshev' or 'hermite'
    :param shape: two integers
    :param orthonormal: whether to orthonormalize the basis

    :return: a PyTorch matrix of `shape`
    """

    if name == "dct":
        U = scipy.fftpack.dct(np.eye(shape[0]), norm="ortho")[:, :shape[1]]
    elif name == 'identity':
        U = np.eye(shape[0], shape[1])
    else:
        eval_points = np.linspace(-1, 1, shape[0])
        if name == "legendre":
            U = np.polynomial.legendre.legval(eval_points, np.eye(shape[0], shape[1])).T
        elif name == "chebyshev":
            U = np.polynomial.chebyshev.chebval(eval_points, np.eye(shape[0], shape[1])).T
        elif name == "hermite":
            U = np.polynomial.hermite.hermval(eval_points, np.eye(shape[0], shape[1])).T
        else:
            raise ValueError("Unsupported basis function")
    if orthonormal:
        U /= np.sqrt(np.sum(U*U, axis=0))  # Normalize each column in place
    return torch.from_numpy(U)
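
# Usage (sketch): the first 4 Legendre polynomials sampled at 64 points in [-1, 1].
#
# >>> U = tn.generate_basis('legendre', [64, 4])
# >>> U.shape  # torch.Size([64, 4])
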
def reduce(ts, function, eps=0, rmax=np.iinfo(np.int32).max, algorithm='svd', verbose=False, **kwargs):
    """
    Compute a tensor as a function applied pairwise to all tensors in a sequence, climbing up a binary tree.

    :Example 1 (addition):

    >>> import operator
    >>> tn.reduce([t1, t2], operator.add)

    :Example 2 (cat with bounded rank):

    >>> tn.reduce([t1, t2], tn.cat, rmax=10)

    :param ts: a generator (or list) of :class:`Tensor`
    :param eps: intermediate tensors will be rounded at this error when climbing up the hierarchy
    :param rmax: no node should exceed this number of ranks
    :param algorithm: passed to :func:`round.round()`
    :param verbose: Boolean

    :return: the reduced result
    """

    d = dict()
    start = time.time()
    for i, elem in enumerate(ts):
        if verbose and i % 100 == 0:
            print("reduce: element {}, time={:g}".format(i, time.time()-start))
        climb = 0  # For climbing up the tree
        while climb in d:
            elem = tn.round(function(d[climb], elem, **kwargs), eps=eps, rmax=rmax, algorithm=algorithm)
            d.pop(climb)
            climb += 1
        d[climb] = elem
    keys = list(d.keys())
    result = d[keys[0]]
    for key in keys[1:]:
        result = tn.round(function(result, d[key], **kwargs), eps=eps, rmax=rmax, algorithm=algorithm)
    return result
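
# Usage (sketch): sum many random tensors while bounding intermediate TT ranks.
#
# >>> import operator
# >>> ts = (tn.rand([8]*4, ranks_tt=2) for i in range(1000))
# >>> result = tn.reduce(ts, operator.add, rmax=20)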