class LearningRateScheduler:
    """Base class for learning-rate schedules.

    Subclasses override get_learning_rate / update / reset. get_state and
    set_state provide generic attribute-based (de)serialization so a
    scheduler can be checkpointed alongside the model.
    """

    def get_learning_rate(self):
        """Return the learning rate to use next (None in this base class)."""
        pass

    def update(self, nb_finished_epochs, loss):
        """Notify the scheduler that an epoch finished with the given loss."""
        pass

    def reset(self):
        """Restore the scheduler to its initial state."""
        pass

    def get_state(self):
        """Return a dict snapshot of all instance attributes."""
        return vars(self)

    def set_state(self, state):
        """Restore attributes from a dict produced by get_state.

        Fix: removed a leftover debug print of the incoming state, which
        wrote to stdout on every checkpoint restore.
        """
        for k, v in state.items():
            setattr(self, k, v)
+
+
class StepWiseScheduler(LearningRateScheduler):
    """Scheduler that reads the learning rate from a fixed per-epoch table.

    The reported loss is ignored; only the epoch counter selects the rate.
    Indexing past the end of the table raises, which surfaces a schedule
    that is shorter than the training run.
    """

    def __init__(self, schedule):
        self.nb_finished_epochs = 0
        self.schedule = schedule

    def get_learning_rate(self):
        """Look up the rate scheduled for the current epoch."""
        current = self.nb_finished_epochs
        return self.schedule[current]

    def update(self, nb_finished_epochs, loss):
        """Advance the epoch counter; the loss plays no role here."""
        self.nb_finished_epochs = nb_finished_epochs

    def reset(self):
        """Rewind to the first epoch."""
        self.nb_finished_epochs = 0

    def get_state(self):
        """Only the counter needs saving; the table is configuration."""
        return dict(nb_finished_epochs=self.nb_finished_epochs)
+
+
class AutoScheduler(LearningRateScheduler):
    """Loss-driven scheduler: shrink the rate when the loss stops improving.

    After each epoch the loss is compared to the previous one: on a
    non-improvement the rate is multiplied by ``degrowth``, on an
    improvement by ``growth`` (1.0 by default, i.e. left unchanged).
    """

    def __init__(self, learning_rate_init, growth=1.0, degrowth=0.2):
        self.learning_rate_init = learning_rate_init
        self.learning_rate = learning_rate_init
        self.growth = growth
        self.degrowth = degrowth
        # Loss of the previous epoch; None until the first update.
        self.pred_loss = None

    def get_learning_rate(self):
        """Return the current, adaptively scaled learning rate."""
        return self.learning_rate

    def update(self, nb_finished_epochs, loss):
        """Rescale the learning rate according to the loss trend."""
        if self.pred_loss is not None:
            if loss >= self.pred_loss:
                self.learning_rate *= self.degrowth
            else:
                self.learning_rate *= self.growth
        self.pred_loss = loss

    def reset(self):
        """Restore the initial learning rate and forget the loss history."""
        self.learning_rate = self.learning_rate_init
        # Fix: also clear the remembered loss, otherwise the first update
        # after a reset compares against a stale value from the previous run.
        self.pred_loss = None

    def get_state(self):
        """Return the state needed to resume via set_state."""
        return {
            "learning_rate_init": self.learning_rate_init,
            # Fix: the adapted rate must be saved too, or restoring through
            # set_state() silently reverts to the initial learning rate.
            "learning_rate": self.learning_rate,
            "pred_loss": self.pred_loss,
        }
+
+
+######################################################################
+
+