1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
# Number of training examples.
n_train = 50
# Training inputs: n_train uniform samples scaled to [0, 5), sorted ascending.
# The permutation indices returned by torch.sort are not needed.
x_train, _ = torch.sort(torch.rand(n_train) * 5)
def f(x):
    """Return 2 * sin(x) + x ** 0.8, evaluated element-wise on tensor ``x``."""
    sine_term = torch.sin(x) * 2
    power_term = x ** 0.8
    return sine_term + power_term
# Noisy training targets: f at the training inputs plus Gaussian noise
# (mean 0.0, std 0.5), one noise sample per training example.
y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))
# Test inputs: evenly spaced values 0, 0.1, ... below 5.
x_test = torch.arange(0, 5, 0.1)
# Noise-free values of f at the test inputs.
y_truth = f(x_test)
# Number of test examples.
n_test = len(x_test)
# Boolean mask that is False on the diagonal and True everywhere else.
# Built once and reused for both keys and values.
off_diag_mask = (1 - torch.eye(n_train)).type(torch.bool)

# Each row of X_tile / Y_tile is a full copy of the training inputs / outputs.
X_tile = x_train.repeat((n_train, 1))
Y_tile = y_train.repeat((n_train, 1))

# Row i keeps every training example except example i itself, so each
# example is predicted only from the other examples.
keys = X_tile[off_diag_mask].reshape((n_train, -1))
values = Y_tile[off_diag_mask].reshape((n_train, -1))

# NOTE(review): `net` is not defined in the visible part of this file;
# presumably it is an NWKernelRegression instance -- confirm it is created
# before this call.
net(x_train, keys, values)
class NWKernelRegression(nn.Module):
    """Kernel regression with a single learnable scale on the query-key gap.

    Attention weights are a softmax over ``-((query - key) * w) ** 2 / 2``
    along the key axis; the prediction for each query is the corresponding
    weighted sum of its values.
    """

    def __init__(self, **kwargs):
        """Create the module and initialise the scalar parameter ``w``.

        Args:
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        # One learnable scalar, initialised uniformly at random in [0, 1).
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
    ) -> torch.Tensor:
        """Predict one output per query as an attention-weighted sum of values.

        Args:
            queries: Tensor with one element per query (n_queries in total).
            keys: Tensor of shape (n_queries, n_pairs).
            values: Tensor of shape (n_queries, n_pairs).

        Returns:
            A 1-D tensor with one prediction per query.

        Side effects:
            Stores the (n_queries, n_pairs) attention weights on
            ``self.attention_weights``.
        """
        n_pairs = keys.shape[1]
        # Repeat each query n_pairs times so it lines up with its row of keys.
        queries = queries.repeat_interleave(n_pairs).reshape((-1, n_pairs))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w) ** 2 / 2, dim=1
        )
        # (n_queries, 1, n_pairs) @ (n_queries, n_pairs, 1) -> (n_queries, 1, 1)
        weighted = torch.bmm(
            self.attention_weights.unsqueeze(1), values.unsqueeze(-1)
        )
        return weighted.reshape(-1)
|