self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)
def forward(self, x):
    """Run the model on ``x`` with the input shifted right by one position.

    A column of zeros is prepended to ``x`` along dim 1, the result is passed
    through ``self.embedding``, ``self.trunk`` and ``self.readout`` in that
    order, and the final position along dim 1 is dropped. Provided the three
    sub-modules preserve dim 1, the output therefore has the same length
    along dim 1 as the original ``x``.

    Args:
        x: 2-D tensor; the concatenation along dim 1 with an
            ``(x.size(0), 1)`` block requires this rank.
            NOTE(review): presumably integer token indices of shape
            (batch, sequence) -- confirm against ``self.embedding``.

    Returns:
        The readout output with the last position along dim 1 removed.
    """
    # Prepend a zero column; new_zeros() inherits x's dtype and device.
    # NOTE(review): presumably this shift makes the output at position t
    # depend only on inputs before t when the trunk is causal -- confirm.
    x = torch.cat((x.new_zeros(x.size(0), 1), x), 1)
    x = self.embedding(x)
    x = self.trunk(x)
    x = self.readout(x)
    # Drop the extra trailing position introduced by the prepended column.
    return x[:, :-1]
######################################################################