X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pysvrt.git;a=blobdiff_plain;f=cnn-svrt.py;h=3550d854b1d5bb240b0cabf969f3715d5b604d48;hp=1d5e887c832f59cbb9d56317f3413ba4d24dbe03;hb=2fd030cd849fa7879211128c15d4a1fbf9d6e7f4;hpb=3f3a2df9cb54730206a94a294d60d48422333a11 diff --git a/cnn-svrt.py b/cnn-svrt.py index 1d5e887..3550d85 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -122,6 +122,11 @@ class Net(nn.Module): def train_model(train_input, train_target): model, criterion = Net(), nn.CrossEntropyLoss() + nb_parameters = 0 + for p in model.parameters(): + nb_parameters += p.numel() + log_string('NB_PARAMETERS {:d}'.format(nb_parameters)) + if torch.cuda.is_available(): model.cuda() criterion.cuda()