Source code for ls_mlkit.pipeline.callback

from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import Iterable, List, Optional


class CallbackEvent(Enum):
    """Identifiers for the events passed to callbacks via ``CallbackManager.trigger``.

    Every member's value is its name in lowercase. Members are grouped by
    pipeline phase, and their declaration order is the enum's iteration order.
    """

    # training lifecycle
    TRAINING_START = "training_start"
    TRAINING_END = "training_end"
    EPOCH_START = "epoch_start"
    EPOCH_END = "epoch_end"
    STEP_START = "step_start"
    STEP_END = "step_end"
    # save / load
    BEFORE_SAVE = "before_save"
    AFTER_SAVE = "after_save"
    BEFORE_LOAD = "before_load"
    AFTER_LOAD = "after_load"
    # optimization step: loss -> backward -> optimizer step
    BEFORE_COMPUTE_LOSS = "before_compute_loss"
    AFTER_COMPUTE_LOSS = "after_compute_loss"
    BEFORE_BACKWARD = "before_backward"
    AFTER_BACKWARD = "after_backward"
    BEFORE_OPTIMIZER_STEP = "before_optimizer_step"
    AFTER_OPTIMIZER_STEP = "after_optimizer_step"
    # evaluation
    BEFORE_EVAL = "before_eval"
    AFTER_EVAL = "after_eval"
    BEFORE_EVAL_STEP = "before_eval_step"
    AFTER_EVAL_STEP = "after_eval_step"
class BaseCallback(metaclass=ABCMeta):
    """Interface for objects that react to ``CallbackEvent`` notifications.

    Concrete callbacks must override :meth:`on_event`; because it is an
    abstract method, a subclass that does not override it cannot be
    instantiated.
    """

    @abstractmethod
    def on_event(self, event: CallbackEvent, *args, **kwargs):
        """Handle one dispatched event.

        Args:
            event (CallbackEvent): the event being dispatched.
            *args: positional payload supplied by whoever triggered the event.
            **kwargs: keyword payload supplied by whoever triggered the event.

        Returns:
            None. The base implementation does nothing, so an override may
            call ``super().on_event(...)`` without side effects.
        """
        return None
class CallbackManager:
    """Hold an ordered collection of callbacks and dispatch events to them.

    Callbacks are invoked in registration order. ``None`` is never stored,
    so every registered entry is expected to provide ``on_event``.
    """

    def __init__(self):
        # Registration order is dispatch order (see ``trigger``).
        self.callbacks: List["BaseCallback"] = []

    def add_callback(self, callback: Optional["BaseCallback"]):
        """Add a callback.

        Args:
            callback (BaseCallback | None): the callback to add. ``None`` is
                ignored so callers can pass optional callbacks unconditionally.
        """
        if callback is not None:
            self.callbacks.append(callback)

    def add_callbacks(self, callbacks: Optional[Iterable["BaseCallback"]]):
        """Add several callbacks, preserving their order.

        Args:
            callbacks (Iterable[BaseCallback] | None): the callbacks to add.
                Any iterable is accepted (list, tuple, generator). ``None``
                and empty iterables are no-ops. ``None`` entries are skipped,
                matching ``add_callback``.
        """
        if callbacks is None:
            return
        # Build the filtered list before extending, so one-shot iterators
        # work and passing ``self.callbacks`` itself cannot loop forever.
        # Dropping ``None`` entries keeps ``trigger`` from hitting
        # ``AttributeError`` on ``None.on_event``.
        self.callbacks.extend([cb for cb in callbacks if cb is not None])

    def trigger(self, event: "CallbackEvent", *args, **kwargs):
        """Trigger all callbacks for a given event.

        Callbacks registered while dispatch is in progress (for example from
        inside another callback's ``on_event``) do not receive the current
        event. They only see events triggered afterwards. An exception raised
        by a callback propagates to the caller, and the remaining callbacks
        are not invoked for this event.

        Args:
            event (CallbackEvent): the event to trigger.
            *args: positional arguments forwarded to every ``on_event``.
            **kwargs: keyword arguments forwarded to every ``on_event``.
        """
        # Iterate over a snapshot: mutating a list while iterating it would
        # otherwise make newly added callbacks part of the current dispatch.
        for callback in tuple(self.callbacks):
            callback.on_event(event, *args, **kwargs)