projects
/
beaver.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update
[beaver.git]
/
mygpt.py
diff --git
a/mygpt.py
b/mygpt.py
index
75adbf6
..
7166788
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@
-132,13
+132,28
@@
class AddPositionalEncoding(nn.Module):
class QKVAttention(nn.Module):
def __init__(
class QKVAttention(nn.Module):
def __init__(
- self, dim_in, dim_qk, dim_v, nb_heads=1, causal=False, attention_dropout=0.0
+ self,
+ dim_in,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ causal=False,
+ attention_dropout=0.0,
+ amm_generator=None,
):
super().__init__()
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
):
super().__init__()
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+ if amm_generator is None:
+ self.amm_generator = (
+ lambda d: torch.arange(d)[None, None, :, None]
+ < torch.arange(d)[None, None, None, :]
+ )
+ else:
+ self.amm_generator = amm_generator
+
self.causal = causal
self.attention_dropout = attention_dropout
self.causal = causal
self.attention_dropout = attention_dropout
@@
-175,10
+190,7
@@
class QKVAttention(nn.Module):
if self.causal:
if bs_q.first == 0:
if self.causal:
if bs_q.first == 0:
- self.cache_attzero = (
- torch.arange(x_q.size(1), device=q.device)[None, None, :, None]
- < torch.arange(x_q.size(1), device=q.device)[None, None, None, :]
- )
+ self.cache_attzero = self.amm_generator(x_q.size(1)).to(q.device)
a = a.masked_fill(
self.cache_attzero[
:, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
a = a.masked_fill(
self.cache_attzero[
:, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
@@
-215,6
+227,7
@@
class MyGPT(nn.Module):
causal=False,
dropout=0.0,
len_max=1e5,
causal=False,
dropout=0.0,
len_max=1e5,
+ amm_generator=None,
):
super().__init__()
):
super().__init__()
@@
-238,6
+251,7
@@
class MyGPT(nn.Module):
nb_heads=nb_heads,
causal=causal,
attention_dropout=dropout,
nb_heads=nb_heads,
causal=causal,
attention_dropout=dropout,
+ amm_generator=amm_generator,
),
),
WithResidual(
),
),
WithResidual(