X-Git-Url: https://www.fleuret.org/cgi-bin/gitweb/gitweb.cgi?p=pytorch.git;a=blobdiff_plain;f=speed.py;h=10a008f68dc6c3865fd3a0e00740016055260306;hp=f682a169076aca29f3f7b4b695bb55fc179b305d;hb=ab248b18b1167753b4f992fc083f34ef6c2fd016;hpb=31b3e7141d2acb0cb1ff298d2b7fca3911889f25

diff --git a/speed.py b/speed.py
index f682a16..10a008f 100755
--- a/speed.py
+++ b/speed.py
@@ -4,20 +4,24 @@ import time, torch
 
 if torch.cuda.is_available():
     device = torch.device('cuda')
+    sync = lambda: torch.cuda.synchronize()
 else:
     device = torch.device('cpu')
+    sync = lambda: None
 
 nb_runs = 10000
-d1, d2, d3 = 50000, 256, 512
+d1, d2, d3 = 2048, 2048, 2048
 
 a, b = torch.rand(d1, d2).to(device), torch.rand(d2, d3).to(device)
 
+sync() # CUDA work is queued asynchronously: drain it before starting the timer
 start_time = time.perf_counter()
 for k in range(nb_runs):
     c = a @ b
+sync() # wait for every queued matmul to finish before reading the clock
 duration = time.perf_counter() - start_time
 
-nb_flop = float(nb_runs * d1 * d2 * d3)
+nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
 speed = nb_flop / duration
 
 for u in [ '', 'K', 'M', 'G', 'T', 'P' ]: