Source code for ls_mlkit.model.longLinear

import torch


[docs] class LongLinearModel(torch.nn.Module): def __init__(self, n_layers=200, dim=1024): super().__init__() self.nb = n_layers self.dim = dim self.fc = torch.nn.ModuleList([torch.nn.Linear(dim, dim, bias=False) for i in range(self.nb)])
[docs] def forward(self, x: torch.Tensor): i = 0 for m in self.fc: x = m(x) i += 1 return x