Source code for ls_mlkit.util.utils_for_main
import math
import os
import torch
from omegaconf import DictConfig
from torch.nn import Module

def get_run_name(cfg: DictConfig, prefix: str = "", suffix: str = "") -> str:
    """Compose a human-readable run name from the dataset, model and training config."""
    run_name = f"{cfg.dataset.name}-{cfg.model.name.replace('/', '_').replace('-', '_')}-lr:{cfg.optimizer.lr}"
    if cfg.train.train_strategy in ["steps"]:
        run_name += f"-n_steps:{cfg.train.n_steps}"
    if cfg.train.train_strategy in ["epochs"]:
        run_name += f"-n_epochs:{cfg.train.n_epochs}"
    return f"{prefix}{run_name}{suffix}"
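
# Usage sketch (hypothetical config values, for illustration only):
#
#     cfg = OmegaConf.create({
#         "dataset": {"name": "mnist"},
#         "model": {"name": "google/vit-base"},
#         "optimizer": {"lr": 0.001},
#         "train": {"train_strategy": "steps", "n_steps": 1000},
#     })
#     get_run_name(cfg, prefix="exp-")
#     # -> "exp-mnist-google_vit_base-lr:0.001-n_steps:1000"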

def get_optimizer(model: Module, cfg: DictConfig):
    """Instantiate the ``torch.optim`` class named in ``cfg.optimizer.name``."""
    optimizer_config = dict(cfg.optimizer)
    optimizer_class_name = optimizer_config.pop("name")
    # Look the class up by name instead of using eval() on config input.
    optimizer_class = getattr(torch.optim, optimizer_class_name)
    return optimizer_class(model.parameters(), **optimizer_config)
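
# Usage sketch (assumed config layout; any ``torch.optim`` class name works):
#
#     cfg = OmegaConf.create({"optimizer": {"name": "AdamW", "lr": 1e-3, "weight_decay": 0.01}})
#     optimizer = get_optimizer(model, cfg)
#     # equivalent to torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)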

def get_correct_n_training_steps(accelerator, train_set, cfg: DictConfig, inplace: bool = True) -> int:
    """Resolve the total number of optimizer steps implied by the training strategy."""
    if cfg.train.train_strategy in ["epochs"]:
        # Each step consumes `num_processes * batch_size` samples across all workers.
        effective_batch_size = accelerator.num_processes * cfg.train.batch_size
        n_training_steps = math.ceil(1.0 * len(train_set) * cfg.train.n_epochs / effective_batch_size)
    elif cfg.train.train_strategy in ["steps"]:
        n_training_steps = cfg.train.n_steps
    else:
        raise ValueError(f"Train strategy {cfg.train.train_strategy} is not supported")
    if inplace:
        # Record the resolved step count back onto the config.
        cfg.train.n_steps = n_training_steps
    return n_training_steps
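
# Worked example (hypothetical numbers): with len(train_set) == 10_000,
# accelerator.num_processes == 4, cfg.train.batch_size == 32 and
# cfg.train.n_epochs == 3, the effective batch size is 4 * 32 == 128 and
# the resolved step count is math.ceil(10_000 * 3 / 128) == 235.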

def get_learning_rate_scheduler(optimizer, accelerator, train_set, cfg: DictConfig, inplace: bool = True):
    """Build a learning-rate scheduler sized to the resolved number of training steps."""
    from ls_mlkit.scheduler.lr_scheduler_factory import get_lr_scheduler

    n_training_steps = get_correct_n_training_steps(accelerator, train_set, cfg, inplace=inplace)
    lr_scheduler = get_lr_scheduler(
        optimizer=optimizer,
        n_warmup_steps=cfg.train.n_warmup_steps,
        n_training_steps=n_training_steps,
        lr_scheduler_type=cfg.train.lr_scheduler_type,
    )
    return lr_scheduler
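
# Usage sketch (assumed ``cfg.train`` fields, matching the accesses above;
# "cosine" is only an example scheduler type):
#
#     # cfg.train.n_warmup_steps = 100, cfg.train.lr_scheduler_type = "cosine"
#     lr_scheduler = get_learning_rate_scheduler(optimizer, accelerator, train_set, cfg)
#     # per training step: optimizer.step(); lr_scheduler.step()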

def get_train_class():
    """Return the pipeline and training-config classes (imported lazily)."""
    from ls_mlkit.pipeline import MyDistributedPipeline, MyTrainingConfig

    return MyDistributedPipeline, MyTrainingConfig
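
# Usage sketch (the deferred import keeps this module importable even when the
# pipeline package is unavailable):
#
#     MyDistributedPipeline, MyTrainingConfig = get_train_class()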

def get_new_save_dir(save_dir, cfg: DictConfig, prefix: str = "", suffix: str = ""):
    """Derive a checkpoint directory from the model, dataset and key hyperparameters."""
    if save_dir is None:
        save_dir = "checkpoints"
    model_name = str(cfg.model.name.replace("/", "_").replace("-", "_"))
    dataset_name = str(cfg.dataset.id)
    lr = str(cfg.optimizer.lr)
    batch_size = str(cfg.train.batch_size)
    new_save_dir = os.path.join(save_dir, model_name, dataset_name, f"lr:{lr}-bs:{batch_size}")
    return f"{prefix}{new_save_dir}{suffix}"
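
# Example (hypothetical values): with save_dir=None, model.name="google/vit-base",
# dataset.id="mnist", optimizer.lr=0.001 and train.batch_size=32:
#
#     get_new_save_dir(None, cfg)
#     # -> "checkpoints/google_vit_base/mnist/lr:0.001-bs:32"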

def load_checkpoint(model: Module, final_model_ckpt_path: str) -> Module:
    """Load weights into ``model`` from a ``.safetensors`` or ``torch.save`` checkpoint."""
    if final_model_ckpt_path.endswith(".safetensors"):
        print(f"Loading safetensors checkpoint: {final_model_ckpt_path}")
        try:
            from safetensors.torch import load_file
        except ImportError:
            print("safetensors library not found. Install with: pip install safetensors")
            raise
        # For safetensors, the loaded mapping is the state_dict itself.
        checkpoint = load_file(final_model_ckpt_path)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded model from safetensors: {final_model_ckpt_path}")
    else:
        print(f"Loading PyTorch checkpoint: {final_model_ckpt_path}")
        checkpoint = torch.load(final_model_ckpt_path, map_location="cpu", weights_only=False)
        if "model" in checkpoint:
            # If saved via the pipeline, the state_dict sits under the 'model' key.
            model.load_state_dict(checkpoint["model"])
        else:
            # Otherwise the checkpoint is the state_dict itself.
            model.load_state_dict(checkpoint)
        print(f"Loaded model from PyTorch checkpoint: {final_model_ckpt_path}")
    return model
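
# Usage sketch (hypothetical checkpoint paths):
#
#     model = load_checkpoint(model, "checkpoints/run1/model.safetensors")
#     # or, for a checkpoint written with torch.save():
#     model = load_checkpoint(model, "checkpoints/run1/final_model.pt")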