X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=95e552720e06031d0f0b95c4c3fc6beed0b4eae7;hb=75e1ddcb8de30a4a7be16c80c4f258da662837a6;hp=bd870bc67ec1c8895abe4cd8c81d2b113e4666f9;hpb=5c298b53859b4d97aa85331034af952aae3b0c05;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index bd870bc..95e5527 100755 --- a/mygpt.py +++ b/mygpt.py @@ -774,7 +774,6 @@ class MyGPT(nn.Module): nb_blocks, nb_lines=None, caterpillar_height=None, - dim_rec_v=-1, causal=False, dropout=0.0, len_max=1e5, @@ -820,7 +819,7 @@ class MyGPT(nn.Module): return DumbRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -829,7 +828,7 @@ class MyGPT(nn.Module): return KVRec( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, nb_lines=nb_lines, attention_dropout=dropout, @@ -838,7 +837,7 @@ class MyGPT(nn.Module): return Caterpillar( dim_model=dim_model, dim_qk=dim_keys, - dim_v=dim_rec_v, + dim_v=dim_model // nb_heads, nb_heads=nb_heads, caterpillar_length=self.caterpillar_length, caterpillar_height=self.caterpillar_height,