+
+ if memex_mask is None:
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ else:
+ loss = F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ loss_regular = (loss * (1 - memex_mask)).mean()
+ loss_memex = (loss * memex_mask).mean()
+
+ if not torch.is_tensor(movave_dot_products) or torch.rand(1) < 0.01:
+ dot_products = the_dot_products(
+ loss_regular, loss_memex, model.parameters()
+ )
+ eps = 1e-3
+ movave_dot_products = (
+ 1 - eps
+ ) * movave_dot_products + eps * dot_products
+
+ grgr, grgm, gmgm = movave_dot_products
+ l = (max(grgr, gmgm) - grgr) / gmgm
+ loss = loss_regular + l * loss_memex
+