Tensor Decompositions
The philosophy of tntorch is simple: one class for all formats. Different decompositions (CP, Tucker, TT, hybrids) all use the same interface.
Note: sometimes the internal format will change automatically. For example, no recompression algorithm is known for the CP format, and running round() on a CP tensor will convert it to the TT format.
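The conversion from CP to TT loses nothing. A rank-R CP decomposition is exactly a TT decomposition with TT ranks (1, R, ..., R, 1), where each middle core is diagonal in its two rank indices. The NumPy sketch below checks this identity on a small random example. The function name cp_to_tt is hypothetical, and we do not claim that tntorch performs the conversion exactly this way.

```python
import numpy as np

def cp_to_tt(factors):
    """Hypothetical helper: turn CP factors (each of shape (I_n, R)) into TT cores."""
    R = factors[0].shape[1]
    cores = [factors[0][None, :, :]]            # first core: shape (1, I_1, R)
    for U in factors[1:-1]:
        core = np.zeros((R, U.shape[0], R))
        for r in range(R):
            core[r, :, r] = U[:, r]             # middle cores are diagonal in the rank indices
        cores.append(core)
    cores.append(factors[-1].T[:, :, None])     # last core: shape (R, I_N, 1)
    return cores

rng = np.random.default_rng(0)
A, B, C = (rng.standard_normal((n, 3)) for n in (4, 5, 6))
cores = cp_to_tt([A, B, C])

# Contract the TT cores left to right and compare with the CP sum of rank-one terms
tt_full = np.tensordot(np.tensordot(cores[0], cores[1], axes=1), cores[2], axes=1)[0, ..., 0]
cp_full = np.einsum('ir,jr,kr->ijk', A, B, C)
print(np.allclose(tt_full, cp_full))  # True
```

The TT ranks can then be reduced by ordinary TT rounding, which is one reason TT is a natural target format.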
We will give a few examples of how to compress a full tensor into different tensor formats.
[1]:
import tntorch as tn
import torch
import time
import numpy as np
X, Y, Z = np.meshgrid(range(128), range(128), range(128))
full = torch.Tensor(np.sqrt(np.sqrt(X)*(Y+Z) + Y*Z**2)*(X + np.sin(Y)*np.cos(Z))) # Some analytical 3D function
print(full.shape)
torch.Size([128, 128, 128])
TT
To compress as a low-rank tensor train (TT), use the ranks_tt argument:
[2]:
t = tn.Tensor(full, ranks_tt=3)  # You can also pass a list of ranks
def metrics():
    # Print the decomposition, its compression ratio and three error measures w.r.t. `full`
    print(t)
    print('Compression ratio: {}/{} = {:g}'.format(full.numel(), t.numel(), full.numel() / t.numel()))
    print('Relative error:', tn.relative_error(full, t))
    print('RMSE:', tn.rmse(full, t))
    print('R^2:', tn.r_squared(full, t))
metrics()
3D TT tensor:
 128 128 128
  |   |   |
 (0) (1) (2)
 / \ / \ / \
1   3   3   1
Compression ratio: 2097152/1920 = 1092.27
Relative error: tensor(0.0005)
RMSE: tensor(22.0745)
R^2: tensor(1.0000)
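For reference, here are the standard definitions of the three error measures printed by metrics(). The NumPy sketch below uses hypothetical helper names. We assume, without having checked tntorch's source, that tn.relative_error, tn.rmse and tn.r_squared compute these same quantities.

```python
import numpy as np

def relative_error(gt, approx):
    # Frobenius norm of the residual divided by the Frobenius norm of the ground truth
    return np.linalg.norm(gt - approx) / np.linalg.norm(gt)

def rmse(gt, approx):
    # Root mean squared error over all entries
    return np.sqrt(np.mean((gt - approx) ** 2))

def r_squared(gt, approx):
    # Coefficient of determination: 1 - SS_res / SS_tot
    ss_res = np.sum((gt - approx) ** 2)
    ss_tot = np.sum((gt - gt.mean()) ** 2)
    return 1 - ss_res / ss_tot

gt = np.arange(8, dtype=float).reshape(2, 2, 2)
approx = gt + 0.1   # constant offset of 0.1 in every entry
print(relative_error(gt, approx), rmse(gt, approx), r_squared(gt, approx))
```

RMSE is measured in the units of full, while the relative error is scale-free. That is how an RMSE of about 22 can coexist with a relative error of about 0.0005 above.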
The TT cores are available as t.cores.
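In a TT decomposition, core n is a 3D tensor of shape (r_{n-1}, I_n, r_n), where I_n is the size of mode n and the boundary ranks are r_0 = r_N = 1. The NumPy sketch below is independent of tntorch and uses random arrays as stand-ins for t.cores. It shows how the cores contract back into a full tensor and where the 1920 in the compression ratio comes from. The helper name tt_to_full is hypothetical.

```python
import numpy as np

def tt_to_full(cores):
    """Contract TT cores of shapes (r_{n-1}, I_n, r_n) into a full array."""
    full = cores[0]                                     # shape (1, I_1, r_1)
    for core in cores[1:]:
        # sum over the shared rank index: last axis of `full`, first axis of `core`
        full = np.tensordot(full, core, axes=1)
    return full[0, ..., 0]                              # drop the boundary ranks r_0 = r_N = 1

sizes, ranks = [128, 128, 128], [1, 3, 3, 1]
rng = np.random.default_rng(0)
cores = [rng.standard_normal((ranks[n], sizes[n], ranks[n + 1])) for n in range(3)]

n_params = sum(core.size for core in cores)
print(tt_to_full(cores).shape, n_params)  # (128, 128, 128) 1920
```

Each entry of the full tensor is a product of matrices, one slice per core: full[i, j, k] = G_1[0, i, :] @ G_2[:, j, :] @ G_3[:, k, 0].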
Tucker
Use the ranks_tucker argument:
[3]:
t = tn.Tensor(full, ranks_tucker=3)
metrics()
3D TT-Tucker tensor:
 128 128 128
  |   |   |
  3   3   3
 (0) (1) (2)
 / \ / \ / \
1   9   3   1
Compression ratio: 2097152/1269 = 1652.6
Relative error: tensor(0.0005)
RMSE: tensor(22.0752)
R^2: tensor(1.0000)
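Even though this is technically a TT-Tucker tensor, it has exactly the same expressive power as a low-rank Tucker decomposition: its TT ranks (9 and 3) are at least as large as those of any 3×3×3 core, so the TT format places no further constraint on the Tucker core.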
The Tucker factors are stored in t.Us. To retrieve the full Tucker core, use tucker_core():
[4]:
t.tucker_core().shape
[4]:
torch.Size([3, 3, 3])
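A Tucker decomposition rebuilds the full tensor by multiplying the core G along each mode n by a factor matrix U_n. The NumPy sketch below uses random stand-ins for t.tucker_core() and t.Us, and assumes each factor has shape (I_n, R_n). It also counts parameters the way the printout above suggests they are stored: TT cores for the Tucker core (TT ranks 1, 9, 3, 1) plus three 128×3 factors. That count reproduces the 1269 in the compression ratio.

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.standard_normal((3, 3, 3))                      # stand-in for t.tucker_core()
Us = [rng.standard_normal((128, 3)) for _ in range(3)]  # stand-ins for the factors, shape (I_n, R_n)

# full[i, j, k] = sum over a, b, c of G[a, b, c] * U1[i, a] * U2[j, b] * U3[k, c]
full = np.einsum('abc,ia,jb,kc->ijk', G, *Us, optimize=True)

# Parameters of the printed TT-Tucker tensor: TT cores of the 3x3x3 core
# (TT ranks 1, 9, 3, 1) plus three 128x3 factor matrices
tt_ranks, tucker_ranks = [1, 9, 3, 1], [3, 3, 3]
core_params = sum(tt_ranks[n] * tucker_ranks[n] * tt_ranks[n + 1] for n in range(3))
factor_params = sum(U.size for U in Us)
print(full.shape, core_params + factor_params)  # (128, 128, 128) 1269
```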
CP
Use the ranks_cp argument:
[5]:
t = tn.Tensor(full, ranks_cp=3, verbose=True) # CP is computed using alternating least squares (ALS)
metrics()
ALS -- initialization time = 0.045638084411621094
iter: 0 | eps: 0.00098631 | total time: 0.0682
iter: 1 | eps: 0.00092816 | total time: 0.0896 <- converged (tol=0.0001)
3D CP tensor:
 128 128 128
  |   |   |
 <0> <1> <2>
 / \ / \ / \
3   3   3   3
Compression ratio: 2097152/1152 = 1820.44
Relative error: tensor(0.0009)
RMSE: tensor(39.9936)
R^2: tensor(1.0000)
The CP factors are t.cores (they are all 2D tensors).
Hybrid Formats
ranks_tucker can be combined with the other arguments to produce hybrid decompositions:
[6]:
t = tn.Tensor(full, ranks_cp=3, ranks_tucker=3)
metrics()
t = tn.Tensor(full, ranks_tt=2, ranks_tucker=4)
metrics()
3D CP-Tucker tensor:
 128 128 128
  |   |   |
  3   3   3
 <0> <1> <2>
 / \ / \ / \
3   3   3   3
Compression ratio: 2097152/1179 = 1778.75
Relative error: tensor(0.0035)
RMSE: tensor(149.4028)
R^2: tensor(1.0000)
3D TT-Tucker tensor:
 128 128 128
  |   |   |
  4   4   4
 (0) (1) (2)
 / \ / \ / \
1   2   2   1
Compression ratio: 2097152/1568 = 1337.47
Relative error: tensor(0.0012)
RMSE: tensor(51.8083)
R^2: tensor(1.0000)
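The denominators in the compression ratios follow directly from the shapes in the diagrams. In a hybrid, the CP or TT part acts on the small Tucker core, whose mode sizes are the Tucker ranks, and the Tucker factors map those ranks back to the 128-sized modes. The pure-Python helpers below are hypothetical and encode that reading of the printouts. They reproduce the counts shown above.

```python
def tt_params(sizes, tt_ranks):
    # TT core n has shape (r_{n-1}, I_n, r_n)
    return sum(tt_ranks[n] * sizes[n] * tt_ranks[n + 1] for n in range(len(sizes)))

def cp_params(sizes, rank):
    # CP factor n has shape (I_n, R)
    return sum(size * rank for size in sizes)

def factor_params(sizes, tucker_ranks):
    # Tucker factor n has shape (I_n, R_n)
    return sum(size * rank for size, rank in zip(sizes, tucker_ranks))

sizes = [128, 128, 128]
tt = tt_params(sizes, [1, 3, 3, 1])
cp = cp_params(sizes, 3)
cp_tucker = cp_params([3, 3, 3], 3) + factor_params(sizes, [3, 3, 3])
tt_tucker = tt_params([4, 4, 4], [1, 2, 2, 1]) + factor_params(sizes, [4, 4, 4])
print(tt, cp, cp_tucker, tt_tucker)  # 1920 1152 1179 1568
```

Almost all of the storage in the hybrids sits in the 128-row factor matrices. The compressed core adds only 27 or 32 numbers.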
Error-bounded Decompositions
If you instead pass the argument eps, tntorch computes a decomposition whose relative error does not exceed that value:
[7]:
t = tn.Tensor(full, eps=1e-5)
metrics()
3D TT-Tucker tensor:
 128 128 128
  |   |   |
  4   5   6
 (0) (1) (2)
 / \ / \ / \
1   4   6   1
Compression ratio: 2097152/2092 = 1002.46
Relative error: tensor(8.3402e-06)
RMSE: tensor(0.3594)
R^2: tensor(1.0000)
Passing eps alone always compresses in both the Tucker and the TT sense, and therefore always produces a TT-Tucker tensor. If you only want to compress in, say, the Tucker sense, you can do:
[8]:
t = tn.Tensor(full)
t.round_tucker(eps=1e-5)
metrics()
3D TT-Tucker tensor:
 128 128 128
  |   |   |
  5   4   7
 (0) (1) (2)
 / \ / \ / \
1   28  7   1
Compression ratio: 2097152/3021 = 694.191
Relative error: tensor(4.0447e-06)
RMSE: tensor(0.1743)
R^2: tensor(1.0000)
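One standard way to obtain an error-bounded Tucker compression is the truncated higher-order SVD (HOSVD). For each mode n, it keeps the fewest leading singular vectors of the mode-n unfolding such that the discarded squared singular values sum to at most (eps·‖T‖)²/N. The total squared error is bounded by the sum of all discarded squared singular values, so the relative error stays below eps. The NumPy sketch below uses hypothetical helper names and is not necessarily how round_tucker works internally.

```python
import numpy as np

def unfold(T, mode):
    # Mode-n unfolding: rows indexed by mode n, columns by all remaining modes
    return np.moveaxis(T, mode, 0).reshape(T.shape[mode], -1)

def mode_product(T, M, mode):
    # Multiply T along `mode` by matrix M of shape (new_size, T.shape[mode])
    return np.moveaxis(np.tensordot(M, np.moveaxis(T, mode, 0), axes=1), 0, mode)

def truncated_hosvd(T, eps):
    """Hypothetical helper: Tucker compression with relative error at most eps."""
    budget = (eps * np.linalg.norm(T)) ** 2 / T.ndim    # squared error allowed per mode
    Us = []
    for n in range(T.ndim):
        U, s, _ = np.linalg.svd(unfold(T, n), full_matrices=False)
        # tail[r] = sum(s[r:] ** 2): squared error of keeping only the first r singular vectors
        tail = np.append(np.cumsum((s ** 2)[::-1])[::-1], 0.0)
        rank = int(np.argmax(tail <= budget))            # smallest r within the budget
        Us.append(U[:, :rank])
    core = T
    for n, U in enumerate(Us):
        core = mode_product(core, U.T, n)               # project onto the kept subspaces
    return core, Us

def tucker_to_full(core, Us):
    full = core
    for n, U in enumerate(Us):
        full = mode_product(full, U, n)
    return full

# Same analytical function as in the notebook, on a smaller 32^3 grid
grid = np.arange(32.0)
X, Y, Z = np.meshgrid(grid, grid, grid)
T = np.sqrt(np.sqrt(X) * (Y + Z) + Y * Z**2) * (X + np.sin(Y) * np.cos(Z))

core, Us = truncated_hosvd(T, eps=1e-5)
rel_err = np.linalg.norm(T - tucker_to_full(core, Us)) / np.linalg.norm(T)
print(core.shape, rel_err)
```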
And conversely, for a TT-only compression:
[9]:
t = tn.Tensor(full)
t.round_tt(eps=1e-5)
metrics()
3D TT tensor:
 128 128 128
  |   |   |
 (0) (1) (2)
 / \ / \ / \
1   4   6   1
Compression ratio: 2097152/4352 = 481.882
Relative error: tensor(8.3358e-06)
RMSE: tensor(0.3592)
R^2: tensor(1.0000)
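For TT-only compression, the classical error-bounded algorithm is TT-SVD. It sweeps the modes from left to right and reshapes the remaining data into a matrix at each step. Each SVD is truncated so that the discarded squared singular values stay below (eps·‖T‖)²/(N−1). The N−1 truncation errors add up to at most eps²·‖T‖². The NumPy sketch below starts from the full array, and the name tt_svd is hypothetical. round_tt acts on a tensor that is already in TT form, so tntorch's code path may well differ.

```python
import numpy as np

def tt_svd(T, eps):
    """Hypothetical helper: TT-SVD with relative error at most eps."""
    N, shape = T.ndim, T.shape
    delta2 = (eps * np.linalg.norm(T)) ** 2 / (N - 1)   # squared error allowed per SVD
    cores, r_prev, C = [], 1, T
    for n in range(N - 1):
        C = C.reshape(r_prev * shape[n], -1)
        U, s, Vt = np.linalg.svd(C, full_matrices=False)
        tail = np.append(np.cumsum((s ** 2)[::-1])[::-1], 0.0)
        r = max(1, int(np.argmax(tail <= delta2)))        # smallest rank within the budget
        cores.append(U[:, :r].reshape(r_prev, shape[n], r))
        C = s[:r, None] * Vt[:r]                          # carry the remainder to the next mode
        r_prev = r
    cores.append(C.reshape(r_prev, shape[-1], 1))
    return cores

# Same analytical function as in the notebook, on a smaller 32^3 grid
grid = np.arange(32.0)
X, Y, Z = np.meshgrid(grid, grid, grid)
T = np.sqrt(np.sqrt(X) * (Y + Z) + Y * Z**2) * (X + np.sin(Y) * np.cos(Z))
cores = tt_svd(T, eps=1e-5)

full = cores[0]
for core in cores[1:]:
    full = np.tensordot(full, core, axes=1)   # contract neighbouring rank indices
full = full[0, ..., 0]
rel_err = np.linalg.norm(T - full) / np.linalg.norm(T)
print([c.shape[2] for c in cores[:-1]], rel_err)
```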