X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=lazy_linear.py;h=c49f0d0aed2f8cb475bfd7f1204ee8a46f33df9e;hp=97530ef2829bbae5912990adce85c31bdb4d93d9;hb=05b9b133a45ac9bd5abe6f8b6d29095f9c82797a;hpb=ca897077ed89fbc3c7e8d812ad262146a0c72b71 diff --git a/lazy_linear.py b/lazy_linear.py index 97530ef..c49f0d0 100755 --- a/lazy_linear.py +++ b/lazy_linear.py @@ -9,9 +9,9 @@ from torch import nn, Tensor ###################################################################### -class LazyLinear(nn.Module): - def __init__(self, out_dim, bias = True): +class LazyLinear(nn.Module): + def __init__(self, out_dim, bias=True): super().__init__() self.out_dim = out_dim self.bias = bias @@ -24,22 +24,25 @@ class LazyLinear(nn.Module): if self.training: self.core = nn.Linear(x.size(1), self.out_dim, self.bias) else: - raise RuntimeError('Undefined LazyLinear core in inference mode.') + raise RuntimeError("Undefined LazyLinear core in inference mode.") return self.core(x) - def named_parameters(self, memo=None, prefix=''): - assert self.core is not None, 'Parameters not yet defined' + def named_parameters(self, memo=None, prefix=""): + assert self.core is not None, "Parameters not yet defined" return super().named_parameters(memo, prefix) + ###################################################################### if __name__ == "__main__": - model = nn.Sequential(nn.Conv2d(3, 8, kernel_size = 5), - nn.ReLU(inplace = True), - LazyLinear(128), - nn.ReLU(inplace = True), - nn.Linear(128, 10)) + model = nn.Sequential( + nn.Conv2d(3, 8, kernel_size=5), + nn.ReLU(inplace=True), + LazyLinear(128), + nn.ReLU(inplace=True), + nn.Linear(128, 10), + ) # model.eval() @@ -49,4 +52,3 @@ if __name__ == "__main__": for n, x in model.named_parameters(): print(n, x.size()) -