##############################
-class Residual(nn.Module):
+class WithResidual(nn.Module):
def __init__(self, *f):
# Constructor: takes any number of positional arguments and collapses them
# into the single attribute self.f.
super().__init__()
# Exactly one argument is stored as-is; any other count (including zero) is
# passed to nn.Sequential, so self.f is always a single object.
# NOTE(review): the code that applies self.f is not visible in this excerpt;
# presumably it adds the input back to self.f's output — confirm.
self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
for _ in range(nb_blocks):
trunk_blocks += [
- Residual(
+ WithResidual(
nn.LayerNorm((dim_model,)),
QKVAttention(
dim_in = dim_model,
causal = True, attention_dropout = dropout
),
),
- Residual(
+ WithResidual(
nn.LayerNorm((dim_model,)),
nn.Linear(in_features = dim_model, out_features = dim_hidden),
nn.ReLU(),