X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=speed.py;h=e4add261984da7a5275ceb3401d75fa06c8413a2;hp=10a008f68dc6c3865fd3a0e00740016055260306;hb=31ffed8a581c04094e8d2727c88de6e1d2c07a65;hpb=ab248b18b1167753b4f992fc083f34ef6c2fd016 diff --git a/speed.py b/speed.py index 10a008f..e4add26 100755 --- a/speed.py +++ b/speed.py @@ -14,11 +14,11 @@ d1, d2, d3 = 2048, 2048, 2048 a, b = torch.rand(d1, d2).to(device), torch.rand(d2, d3).to(device) -sync +sync() start_time = time.perf_counter() for k in range(nb_runs): c = a @ b -sync +sync() duration = time.perf_counter() - start_time nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops