Source code for ls_mlkit.scheduler.lr_scheduler_factory
import math
import torch
from diffusers.optimization import get_cosine_schedule_with_warmup
from torch.optim.lr_scheduler import LambdaLR
def get_lambda_lr_scheduler(optimizer, num_warmup_steps, num_training_steps, lr_scheduler_type="linear"):
    def cosine_lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # linear warmup from 0 to the base learning rate
            return float(current_step) / float(max(1, num_warmup_steps))
        else:
            # cosine decay from the base learning rate down to 0 over the remaining steps
            return (
                math.cos((current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps) * math.pi) + 1
            ) / 2

    def linear_lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # linear warmup from 0 to the base learning rate
            return float(current_step) / float(max(1, num_warmup_steps))
        else:
            # linear decay from the base learning rate down to 0 over the remaining steps
            return 1 - ((current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps))

    def constant_lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # linear warmup from 0 to the base learning rate
            return float(current_step) / float(max(1, num_warmup_steps))
        else:
            # hold the base learning rate constant after warmup
            return 1.0

    match lr_scheduler_type:
        case "cosine":
            lr_lambda = cosine_lr_lambda
        case "linear":
            lr_lambda = linear_lr_lambda
        case "constant":
            lr_lambda = constant_lr_lambda
        case _:
            # avoid an unbound lr_lambda for unknown types
            raise ValueError(f"LR Scheduler Type {lr_scheduler_type} is not supported")
    return LambdaLR(optimizer, lr_lambda)
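
Each lambda returns a multiplier applied to the optimizer's base learning rate: it ramps linearly from 0 to 1 over num_warmup_steps, then decays back to 0 (cosine or linear) or stays at 1 (constant). A minimal sketch of the cosine multiplier at a few steps, assuming a dummy SGD optimizer and 10 warmup / 100 training steps (the lambdas are read back from LambdaLR's lr_lambdas list purely for illustration):

# Sketch only: the dummy parameter, optimizer, and step counts are illustrative assumptions.
import torch

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = get_lambda_lr_scheduler(
    optimizer, num_warmup_steps=10, num_training_steps=100, lr_scheduler_type="cosine"
)

cosine_lambda = scheduler.lr_lambdas[0]
print(cosine_lambda(0))    # 0.0   -> lr starts at 0
print(cosine_lambda(10))   # 1.0   -> full base lr at the end of warmup
print(cosine_lambda(55))   # ~0.5  -> halfway down the cosine curve
print(cosine_lambda(100))  # 0.0   -> lr decays to 0 at num_training_steps
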
def get_lr_scheduler(optimizer, n_warmup_steps, n_training_steps, lr_scheduler_type="linear"):
    if lr_scheduler_type in ["cosine", "linear", "constant"]:
        return get_lambda_lr_scheduler(optimizer, n_warmup_steps, n_training_steps, lr_scheduler_type)
    elif lr_scheduler_type in ["cosine_annealing"]:
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_training_steps)
    elif lr_scheduler_type in ["cosine_with_warmup"]:
        return get_cosine_schedule_with_warmup(optimizer, n_warmup_steps, n_training_steps)
    else:
        raise ValueError(f"LR Scheduler Type {lr_scheduler_type} is not supported")
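
get_lr_scheduler is the single entry point: it dispatches on lr_scheduler_type to the lambda-based schedules above, to PyTorch's CosineAnnealingLR, or to diffusers' get_cosine_schedule_with_warmup. A minimal usage sketch, assuming a toy model, optimizer, and step counts that are not part of this module:

# Sketch only: model, optimizer, and step counts are illustrative assumptions.
import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = get_lr_scheduler(
    optimizer, n_warmup_steps=500, n_training_steps=10_000, lr_scheduler_type="cosine_with_warmup"
)

for step in range(10_000):
    loss = model(torch.randn(8, 16)).pow(2).mean()  # placeholder loss
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the LR schedule once per optimizer step
    optimizer.zero_grad()
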