import tntorch as tn
import torch
import sys
import time
import numpy as np
import maxvolpy.maxvol
import logging
def cross(function, domain=None, tensors=None, function_arg='vectors', ranks_tt=None, kickrank=3, rmax=100, eps=1e-6, max_iter=25, val_size=1000, verbose=True, return_info=False):
"""
Cross-approximation routine that samples a black-box function and returns an N-dimensional tensor train approximating it. It accepts either:
- A domain (tensor product of :math:`N` given arrays) and a function :math:`\\mathbb{R}^N \\to \\mathbb{R}`
- A list of :math:`K` tensors of dimension :math:`N` and equal shape and a function :math:`\\mathbb{R}^K \\to \\mathbb{R}`
:Examples:
>>> tn.cross(function=lambda x: x**2, tensors=[t]) # Compute the element-wise square of `t` using 5 TT-ranks
>>> domain = [torch.linspace(-1, 1, 32)]*5
>>> tn.cross(function=lambda x, y, z, t, w: x**2 + y*z + torch.cos(t + w), domain=domain) # Approximate a function over the rectangle :math:`[-1, 1]^5`
>>> tn.cross(function=lambda x: torch.sum(x**2, dim=1), domain=domain, function_arg='matrix') # An example where the function accepts a matrix
References:
- I. Oseledets, E. Tyrtyshnikov: `"TT-cross Approximation for Multidimensional Arrays" (2009) <http://www.mat.uniroma2.it/~tvmsscho/papers/Tyrtyshnikov5.pdf>`_
- D. Savostyanov, I. Oseledets: `"Fast Adaptive Interpolation of Multi-dimensional Arrays in Tensor Train Format" (2011) <https://ieeexplore.ieee.org/document/6076873>`_
- S. Dolgov, R. Scheichl: `"A Hybrid Alternating Least Squares - TT Cross Algorithm for Parametric PDEs" (2018) <https://arxiv.org/pdf/1707.04562.pdf>`_
- A. Mikhalev's `maxvolpy package <https://bitbucket.org/muxas/maxvolpy>`_
- I. Oseledets (and others)'s `ttpy package <https://github.com/oseledets/ttpy>`_
:param function: should produce a vector of :math:`P` elements. Accepts either :math:`N` comma-separated vectors, or a matrix (see `function_arg`)
:param domain: a list of :math:`N` vectors (incompatible with `tensors`)
:param tensors: a :class:`Tensor` or list thereof (incompatible with `domain`)
:param function_arg: if 'vectors', `function` accepts :math:`N` vectors of length :math:`P` each. If 'matrix', a matrix of shape :math:`P \\times N`.
:param ranks_tt: int or list of :math:`N-1` ints. If None, will be determined adaptively
:param kickrank: when adaptively found, ranks will be increased by this amount after every iteration (full sweep left-to-right and right-to-left)
:param rmax: this rank will not be surpassed
:param eps: the procedure will stop after this validation error is met (as measured after each iteration)
:param max_iter: int
:param val_size: size of the validation set
:param verbose: default is True
:param return_info: if True, will also return a dictionary with informative metrics about the algorithm's outcome
:return: an N-dimensional TT :class:`Tensor` (if `return_info`=True, also a dictionary)
"""
assert domain is not None or tensors is not None
assert function_arg in ('vectors', 'matrix')
if function_arg == 'matrix':
def f(*args):
return function(torch.cat([arg[:, None] for arg in args], dim=1))
else:
f = function
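    # From here on, `f` always takes N comma-separated vectors; in 'matrix' mode,
    # the wrapper above stacks them column-wise into a P x N matrix first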
if tensors is None:
tensors = tn.meshgrid(domain)
if not hasattr(tensors, '__len__'):
tensors = [tensors]
tensors = [t.decompress_tucker_factors(_clone=False) for t in tensors]
Is = list(tensors[0].shape)
N = len(Is)
# Process ranks and cap them, if needed
if ranks_tt is None:
ranks_tt = 1
else:
kickrank = None
if not hasattr(ranks_tt, '__len__'):
ranks_tt = [ranks_tt]*(N-1)
ranks_tt = [1] + list(ranks_tt) + [1]
Rs = np.array(ranks_tt)
for n in list(range(1, N)) + list(range(N-1, -1, -1)):
Rs[n] = min(Rs[n-1]*Is[n-1], Rs[n], Is[n]*Rs[n+1])
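    # The sweep above enforces the TT rank bound Rs[n] <= min(Rs[n-1]*Is[n-1],
    # Is[n]*Rs[n+1]); e.g. for Is = [10, 10, 10] and ranks_tt = 100, the initial
    # ranks [1, 100, 100, 1] are capped to [1, 10, 10, 1]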
# Initialize cores at random
cores = [torch.randn(Rs[n], Is[n], Rs[n+1]) for n in range(N)]
# Prepare left and right sets
lsets = [np.array([[0]])] + [None]*(N-1)
    randint = np.hstack([np.random.randint(0, Is[n+1], [max(Rs), 1]) for n in range(N-1)] + [np.zeros([max(Rs), 1], dtype=np.int64)])  # integer dtype, since these are used as indices
rsets = [randint[:Rs[n+1], n:] for n in range(N-1)] + [np.array([[0]])]
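    # lsets[n] holds the Rs[n] left multi-indices (modes 0..n-1) and rsets[n] the
    # Rs[n+1] right multi-indices (modes n+1..N-1, plus one padding column);
    # lsets[0] and rsets[N-1] are singleton dummies. Together they determine the
    # crosses at which the function will be sampled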
# Initialize left and right interfaces for `tensors`
def init_interfaces():
t_linterfaces = []
t_rinterfaces = []
for t in tensors:
linterfaces = [torch.ones(1, t.ranks_tt[0])] + [None]*(N-1)
rinterfaces = [None]*(N-1) + [torch.ones(t.ranks_tt[t.dim()], 1)]
for j in range(N-1):
M = torch.ones(t.cores[-1].shape[-1], len(rsets[j]))
for n in range(N-1, j, -1):
if t.cores[n].dim() == 3: # TT core
M = torch.einsum('iaj,ja->ia', (t.cores[n][:, rsets[j][:, n-1-j], :], M))
else: # CP factor
M = torch.einsum('ai,ia->ia', (t.cores[n][rsets[j][:, n-1-j], :], M))
rinterfaces[j] = M
t_linterfaces.append(linterfaces)
t_rinterfaces.append(rinterfaces)
return t_linterfaces, t_rinterfaces
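    # t_linterfaces[k][j] contracts cores 0..j-1 of the k-th input tensor at the
    # indices in lsets[j] (one row per left multi-index); t_rinterfaces[k][j]
    # contracts cores j+1..N-1 at rsets[j]. evaluate_function() reuses them so
    # that each fiber evaluation only needs to touch the j-th core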
t_linterfaces, t_rinterfaces = init_interfaces()
# Create a validation set
Xs_val = [torch.as_tensor(np.random.choice(I, val_size)) for I in Is]
ys_val = f(*[t[Xs_val].torch() for t in tensors])
if ys_val.dim() > 1:
assert ys_val.dim() == 2
assert ys_val.shape[1] == 1
ys_val = ys_val[:, 0]
assert len(ys_val) == val_size
norm_ys_val = torch.norm(ys_val)
if verbose:
print('Cross-approximation over a {}D domain containing {:g} grid points:'.format(N, tensors[0].numel()))
start = time.time()
converged = False
info = {
'nsamples': 0,
'eval_time': 0,
'val_epss': []
}
    def evaluate_function(j):  # Evaluate the function over Rs[j] x Rs[j+1] fibers, each of size Is[j]
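        # Contract left interface, j-th core and right interface of each input
        # tensor: entry (a, b, c) of V is that tensor's value at left multi-index
        # lsets[j][a], physical index b and right multi-index rsets[j][c]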
Xs = []
for k, t in enumerate(tensors):
if tensors[k].cores[j].dim() == 3: # TT core
V = torch.einsum('ai,ibj,jc->abc', (t_linterfaces[k][j], tensors[k].cores[j], t_rinterfaces[k][j]))
else: # CP factor
V = torch.einsum('ai,bi,ic->abc', (t_linterfaces[k][j], tensors[k].cores[j], t_rinterfaces[k][j]))
Xs.append(V.flatten())
eval_start = time.time()
evaluation = f(*Xs)
info['eval_time'] += time.time() - eval_start
        # If the function returned a column matrix, flatten it to a vector
        if evaluation.dim() == 2:
            evaluation = evaluation[:, 0]
        # Check for nan/inf values
invalid = (torch.isnan(evaluation) | torch.isinf(evaluation)).nonzero()
if len(invalid) > 0:
invalid = invalid[0].item()
            raise ValueError('Invalid return value for function {}: f({}) = {}'.format(function, ', '.join('{:g}'.format(x[invalid].item()) for x in Xs),
                f(*[x[invalid:invalid+1][:, None] for x in Xs]).item()))
V = torch.reshape(evaluation, [Rs[j], Is[j], Rs[j + 1]])
info['nsamples'] += V.numel()
return V
# Sweeps
for i in range(max_iter):
if verbose:
print('iter: {: <{}}'.format(i, len('{}'.format(max_iter))+1), end='')
sys.stdout.flush()
left_locals = []
        # Left-to-right sweep
for j in range(0, N-1):
# Update tensors for current indices
V = evaluate_function(j)
# QR + maxvol towards the right
V = torch.reshape(V, [-1, V.shape[2]]) # Left unfolding
            Q, R = torch.linalg.qr(V)  # `torch.qr` is deprecated in favor of `torch.linalg.qr`
            local, _ = maxvolpy.maxvol.maxvol(Q.detach().numpy())
            V = torch.linalg.lstsq(Q[local, :].t(), Q.t()).solution.t()  # V = Q @ inv(Q[local, :]); replaces the removed `torch.gels`
cores[j] = torch.reshape(V, [Rs[j], Is[j], Rs[j+1]])
left_locals.append(local)
# Map local indices to global ones
local_r, local_i = np.unravel_index(local, [Rs[j], Is[j]])
lsets[j+1] = np.c_[lsets[j][local_r, :], local_i]
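            # maxvol returned Rs[j+1] rows of the (Rs[j]*Is[j])-row unfolding; each
            # row index unravels into (left-set row, physical index), and the new
            # left multi-indices extend lsets[j] by that physical index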
for k, t in enumerate(tensors):
if t.cores[j].dim() == 3: # TT core
t_linterfaces[k][j+1] = torch.einsum('ai,iaj->aj', (t_linterfaces[k][j][local_r, :], t.cores[j][:, local_i, :]))
else: # CP factor
t_linterfaces[k][j+1] = torch.einsum('ai,ai->ai', (t_linterfaces[k][j][local_r, :], t.cores[j][local_i, :]))
# Right-to-left sweep
for j in range(N-1, 0, -1):
# Update tensors for current indices
V = evaluate_function(j)
# QR + maxvol towards the left
V = torch.reshape(V, [Rs[j], -1]) # Right unfolding
            Q, R = torch.linalg.qr(V.t())  # `torch.qr` is deprecated in favor of `torch.linalg.qr`
            local, _ = maxvolpy.maxvol.maxvol(Q.detach().numpy())
            V = torch.linalg.lstsq(Q[local, :].t(), Q.t()).solution  # inv(Q[local, :].t()) @ Q.t(); replaces the removed `torch.gels`
            cores[j] = torch.reshape(V, [Rs[j], Is[j], Rs[j+1]])
# Map local indices to global ones
local_i, local_r = np.unravel_index(local, [Is[j], Rs[j+1]])
rsets[j-1] = np.c_[local_i, rsets[j][local_r, :]]
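            # Mirror image of the left-to-right update: each maxvol row unravels
            # into (physical index, right-set row), and the physical index is
            # prepended to the corresponding right multi-index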
for k, t in enumerate(tensors):
if t.cores[j].dim() == 3: # TT core
t_rinterfaces[k][j-1] = torch.einsum('iaj,ja->ia', (t.cores[j][:, local_i, :], t_rinterfaces[k][j][:, local_r]))
else: # CP factor
t_rinterfaces[k][j-1] = torch.einsum('ai,ia->ia', (t.cores[j][local_i, :], t_rinterfaces[k][j][:, local_r]))
# Leave the first core ready
V = evaluate_function(0)
cores[0] = V
# Evaluate validation error
val_eps = torch.norm(ys_val - tn.Tensor(cores)[Xs_val].torch()) / norm_ys_val
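        # val_eps is the relative L2 error of the current TT on the validation set;
        # it drives both the convergence test and the final warning below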
info['val_epss'].append(val_eps)
if val_eps < eps:
converged = True
if verbose: # Print status
print('| eps: {:.3e}'.format(val_eps), end='')
print(' | total time: {:8.4f} | largest rank: {:3d}'.format(time.time() - start, max(Rs)), end='')
if converged:
print(' <- converged: eps < {}'.format(eps))
elif i == max_iter-1:
print(' <- max_iter was reached: {}'.format(max_iter))
else:
print()
if converged:
break
elif i < max_iter-1 and kickrank is not None: # Augment ranks
newRs = Rs.copy()
newRs[1:-1] = np.minimum(rmax, newRs[1:-1]+kickrank)
for n in list(range(1, N)) + list(range(N-1, 0, -1)):
newRs[n] = min(newRs[n-1]*Is[n-1], newRs[n], Is[n]*newRs[n+1])
            extra = np.hstack([np.random.randint(0, Is[n+1], [max(newRs), 1]) for n in range(N-1)] + [np.zeros([max(newRs), 1], dtype=np.int64)])
for n in range(N-1):
if newRs[n+1] > Rs[n+1]:
rsets[n] = np.vstack([rsets[n], extra[:newRs[n+1]-Rs[n+1], n:]])
Rs = newRs
t_linterfaces, t_rinterfaces = init_interfaces() # Recompute interfaces
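            # The block above grows every internal rank by `kickrank` (capped by
            # `rmax` and by the unfolding sizes), pads the right index sets with
            # fresh random rows, and rebuilds the interfaces to match the new shapes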
if val_eps > eps:
logging.warning('eps={:g} (larger than {}) when cross-approximating {}'.format(val_eps, eps, function))
if verbose:
print('Did {} function evaluations, which took {:.4g}s ({:.4g} evals/s)'.format(info['nsamples'], info['eval_time'], info['nsamples'] / info['eval_time']))
print()
if return_info:
info['lsets'] = lsets
info['rsets'] = rsets
info['left_locals'] = left_locals
info['total_time'] = time.time()-start
info['val_eps'] = val_eps
return tn.Tensor([torch.Tensor(c) for c in cores]), info
else:
return tn.Tensor([torch.Tensor(c) for c in cores])
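

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the library API):
    # approximate f(x, y, z) = x**2 + y*z on a 64 x 64 x 64 grid over [-1, 1]^3,
    # mirroring the docstring examples
    example_domain = [torch.linspace(-1, 1, 64) for _ in range(3)]
    approximation = cross(function=lambda x, y, z: x**2 + y*z, domain=example_domain)
    print(approximation)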