Update.
[python.git] / covid19.py
index 0fd9ecc..fdfe16e 100755 (executable)
@@ -1,70 +1,89 @@
 #!/usr/bin/env python
 
-import os, time, math
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import os, time
 import numpy, csv
 import matplotlib.pyplot as plt
 import matplotlib.dates as mdates
 import urllib.request
 
-url = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Confirmed.csv'
+######################################################################
+
+
+def gentle_download(url, delay=86400):
+    filename = url[url.rfind("/") + 1 :]
+    if not os.path.isfile(filename) or os.path.getmtime(filename) < time.time() - delay:
+        print(f"Retrieving {url}")
+        urllib.request.urlretrieve(url, filename)
+    return filename
 
-file = 'time_series_19-covid-Confirmed.csv'
 
 ######################################################################
 
-if not os.path.isfile(file) or os.path.getmtime(file) < time.time() - 86400:
-    print('Retrieving file')
-    urllib.request.urlretrieve(url, file)
+nbcases_filename = gentle_download(
+    "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv"
+)
 
 ######################################################################
 
-with open(file, newline='') as csvfile:
-    reader = csv.reader(csvfile, delimiter=',')
+with open(nbcases_filename, newline="") as csvfile:
+    reader = csv.reader(csvfile, delimiter=",")
     times = []
     nb_cases = {}
     time_col = 5
     for row_nb, row in enumerate(reader):
         for col_nb, field in enumerate(row):
-            if row_nb >= 1 and col_nb == 1:
-                country = field
-                if not country in nb_cases:
-                    nb_cases[country] = numpy.zeros(len(times))
-                # print(country)
             if row_nb == 0 and col_nb >= time_col:
-                times.append(time.mktime(time.strptime(field, '%m/%d/%y')))
-            if row_nb == 1 and col_nb == time_col:
-                nb_cases['World'] = numpy.zeros(len(times))
+                times.append(time.mktime(time.strptime(field, "%m/%d/%y")))
             if row_nb >= 1:
-                if col_nb >= time_col:
-                    nb_cases['World'][col_nb - time_col] += int(field)
+                if col_nb == 1:
+                    country = field
+                    if not country in nb_cases:
+                        nb_cases[country] = numpy.zeros(len(times))
+                elif col_nb >= time_col:
+                    # if field == '': field = '0'
                     nb_cases[country][col_nb - time_col] += int(field)
 
+countries = list(nb_cases.keys())
+countries.sort()
+print("Countries: ", countries)
+
+nb_cases["World"] = sum(nb_cases.values())
+
 ######################################################################
 
 fig = plt.figure()
 ax = fig.add_subplot(1, 1, 1)
 
-ax.grid(color='gray', linestyle='-', linewidth=0.25)
+ax.yaxis.grid(color="gray", linestyle="-", linewidth=0.25)
+ax.set_title("Nb. of COVID-19 cases")
+ax.set_xlabel("Date", labelpad=10)
+ax.set_yscale("log")
 
-ax.set_title('Nb. of COVID-19 cases')
-ax.set_xlabel('Date', labelpad = 10)
-ax.set_yscale('log')
+myFmt = mdates.DateFormatter("%b %d")
 
-myFmt = mdates.DateFormatter('%b %d')
 ax.xaxis.set_major_formatter(myFmt)
 dates = mdates.epoch2num(times)
 
-for label, color in [ ('World', 'blue'),
-                      ('Switzerland', 'red'),
-                      ('France', 'green'),
-                      ('South Korea', 'gray'),
-                      ('Mainland China', 'orange') ]:
-    ax.plot(dates, nb_cases[label], color = color, label = label)
+for key, color, label in [
+    ("World", "blue", "World"),
+    ("Switzerland", "red", "Switzerland"),
+    ("France", "lightgreen", "France"),
+    ("US", "black", "USA"),
+    ("Korea, South", "gray", "South Korea"),
+    ("Italy", "purple", "Italy"),
+    ("China", "orange", "China"),
+]:
+    ax.plot(dates, nb_cases[key], color=color, label=label, linewidth=2)
 
-# ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), frameon = False)
-ax.legend(frameon = False)
+ax.legend(frameon=False)
 
 plt.show()
-# fig.savefig('covid19.svg')
+
+fig.savefig("covid19_nb_cases.png")
 
 ######################################################################