-for t in [ torch.float32, torch.float16 ]:
- a = torch.rand(d1, d2, device = device, dtype = t)
- b = torch.rand(d2, d3, device = device, dtype = t)
+for t in [torch.float32, torch.float16]:
+ try:
+ a = torch.rand(d1, d2, device=device, dtype=t)
+ b = torch.rand(d2, d3, device=device, dtype=t)
+ nb_runs = 0
+
+ sync()
+ start_time = time.perf_counter()
+ while time.perf_counter() - start_time < max_duration:
+ c = a @ b
+ nb_runs += 1
+ sync()
+ duration = time.perf_counter() - start_time