import torch
import tntorch as tn
def weight_mask(N, weight, nsymbols=2):
    """
    Accepts a string iff its number of 1's equals (or is in) `weight`

    :param N: number of dimensions
    :param weight: an integer (or list thereof): recognized weight(s). Every weight must be non-negative
    :param nsymbols: slices per core (default is 2); forwarded unchanged to `tn.weight_one_hot`
    :return: a mask tensor
    :raises AssertionError: if any of the requested weights is negative
    """

    if not hasattr(weight, '__len__'):
        weight = [weight]
    weight = torch.Tensor(weight).long()

    # The weights are used as indices into the last core below. A negative index would
    # not fail: it would wrap around and silently select the wrong slice, so every
    # entry has to be validated (not only the first one).
    assert bool((weight >= 0).all()), 'all weights must be non-negative'

    # The one-hot encoding only needs to distinguish counts up to the largest accepted weight.
    # With no accepted weights at all, a single slot suffices (it is discarded below anyway).
    r = int(weight.max()) + 1 if weight.numel() > 0 else 1
    t = tn.weight_one_hot(N, r, nsymbols)

    # Collapse the last core's trailing axis: keep only the slices at the accepted weights
    # and add them up, leaving a trailing axis of size 1.
    t.cores[-1] = torch.sum(t.cores[-1][:, :, weight], dim=2, keepdim=True)
    return t
def weight_one_hot(N, r=None, nsymbols=2):
    """
    Given a string with :math:`k` 1's, it produces a vector that represents :math:`k` in `one hot encoding <https://en.wikipedia.org/wiki/One-hot>`_

    :param N: number of dimensions
    :param r: size of the square slices of every core (defaults to `N+1`)
    :param nsymbols: number of slices per core: either a single integer used for all `N` cores, or a sequence of `N` integers (default is 2)
    :return: a `tn.Tensor` built from the `N` cores
    """

    # Broadcast a scalar symbol count to one entry per dimension
    if not hasattr(nsymbols, '__len__'):
        nsymbols = [nsymbols]*N
    assert len(nsymbols) == N

    if r is None:
        r = N+1

    # All slices are column-shifted copies of the same identity matrix
    identity = torch.eye(r)

    cores = []
    for dim in range(N):
        n_slices = nsymbols[dim]
        core = torch.zeros([r, n_slices, r])
        # Symbol 0: plain identity
        core[:, 0, :] = identity
        # Symbol s > 0: identity shifted s columns to the right (overflow is dropped)
        for shift in range(1, n_slices):
            core[:, shift, shift:] = identity[:, :-shift]
        cores.append(core)

    # Only the first row of the first core is kept, so its leading axis has size 1
    cores[0] = cores[0][0:1, :, :]
    return tn.Tensor(cores)
def weight(N, nsymbols=2):
    """
    For any string, counts how many 1's it has

    Each core holds, for every symbol `s`, a 2x2 identity slice whose entry `[1, 0]` is set to `s`.
    NOTE(review): assuming `tn.Tensor` contracts these cores as a tensor train, the result is the
    sum of the symbol indices, which equals the number of 1's when `nsymbols` is 2 -- confirm for larger alphabets.

    :param N: number of dimensions
    :param nsymbols: slices per core: either a single integer used for all `N` cores, or a sequence of `N` integers (default is 2)
    :return: a `tn.Tensor` built from the `N` cores
    """

    # Accept a per-dimension list of symbol counts, like `weight_one_hot` does
    if not hasattr(nsymbols, '__len__'):
        nsymbols = [nsymbols]*N
    assert len(nsymbols) == N

    cores = []
    for n in range(N):
        # One 2x2 identity slice per symbol...
        core = torch.eye(2)[:, None, :].repeat(1, nsymbols[n], 1)
        # ...with the symbol's index written into the lower-left entry
        core[1, :, 0] = torch.arange(nsymbols[n])
        cores.append(core)

    # Boundary cores: keep row 1 of the first core and column 0 of the last one
    cores[0] = cores[0][1:2, :, :]
    cores[-1] = cores[-1][:, :, 0:1]
    return tn.Tensor(cores)
def length(N):  # TODO
    """
    Placeholder: this function has no implementation yet.

    :todo:
    :param N:
    :return:
    :raises NotImplementedError: always
    """

    raise NotImplementedError()