From: Francois Fleuret Date: Fri, 29 Apr 2022 12:08:20 +0000 (+0200) Subject: Update. X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=mygpt.git;a=commitdiff_plain;h=063e25c1e1442c406746a39220f3c3590882cf51 Update. --- diff --git a/mygpt.py b/mygpt.py index 7bf25b5..a23470b 100755 --- a/mygpt.py +++ b/mygpt.py @@ -119,3 +119,18 @@ class MyGPT(nn.Module): return x ###################################################################### + +if __name__ == '__main__': + vocabulary_size = 10 + x = torch.randint(vocabulary_size, (25, 100)) + + model = MyGPT( + vocabulary_size = vocabulary_size, + dim_model = 16, dim_keys = 50, dim_hidden = 100, + nb_heads = 2, nb_blocks = 3, + dropout = 0.1 + ) + + y = model(x) + +######################################################################