ls_mlkit.diffuser.sde.score_fn_utils module

ls_mlkit.diffuser.sde.score_fn_utils.get_model_fn(model, train=False)[source]

Create a function to give the output of the score-based model.

Parameters:
  • model – The score model.

  • train – True for training and False for evaluation.

Returns:

A model function.
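
Example (illustrative sketch, not the library source): the block below shows one plausible way such a wrapper could behave, assuming the model exposes PyTorch-style train() and eval() methods and is called as model(x, labels). ToyModel and get_model_fn_sketch are hypothetical names introduced here; the actual call signature of the function returned by get_model_fn is not stated on this page.

```python
class ToyModel:
    """Hypothetical stand-in for a score network; not part of ls_mlkit."""

    def __init__(self):
        self.training = True

    def train(self):
        # Mimics the mode switch of a PyTorch-style module.
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, x, labels):
        # Toy forward pass: scale every input element by the label.
        return [xi * labels for xi in x]


def get_model_fn_sketch(model, train=False):
    """Assumed behavior: set the model's mode, then forward the inputs."""

    def model_fn(x, labels):
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn


toy = ToyModel()
eval_fn = get_model_fn_sketch(toy, train=False)
out = eval_fn([1.0, 2.0], 0.5)
```

Setting the mode inside the returned function, rather than once at creation time, keeps each call consistent even if other code toggles the model's mode in between.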

ls_mlkit.diffuser.sde.score_fn_utils.get_score_fn(sde, model, train=False, continuous=False)[source]

Wraps the model function so that its output corresponds to a true time-dependent score function.

Parameters:
  • sde – An sde_lib.SDE object that represents the forward SDE.

  • model – A score model.

  • train – True for training and False for evaluation.

  • continuous – If True, the score-based model is expected to directly take continuous time steps.

Returns:

A score function.
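
Example (illustrative sketch, not the library source): one common convention in score-based diffusion code is that a variance-preserving model predicts the noise, and the score is recovered by dividing by the negative marginal standard deviation at time t. Whether get_score_fn does exactly this is an assumption. ToyVPSDE, marginal_std, toy_noise_model, and get_score_fn_sketch are hypothetical names used only for illustration.

```python
import math


class ToyVPSDE:
    """Hypothetical variance-preserving SDE with a linear beta schedule."""

    def __init__(self, beta_min=0.1, beta_max=20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def marginal_std(self, t):
        # Standard deviation of p(x_t | x_0) for the linear-beta VP SDE.
        log_mean_coeff = (
            -0.25 * t ** 2 * (self.beta_max - self.beta_min)
            - 0.5 * t * self.beta_min
        )
        return math.sqrt(1.0 - math.exp(2.0 * log_mean_coeff))


def toy_noise_model(x, t):
    # Pretend network that "predicts" the noise as half of the input.
    return [0.5 * xi for xi in x]


def get_score_fn_sketch(sde, model_fn):
    """Assumed conversion: score = -predicted_noise / std(t)."""

    def score_fn(x, t):
        eps = model_fn(x, t)
        std = sde.marginal_std(t)
        return [-e / std for e in eps]

    return score_fn


sde = ToyVPSDE()
score_fn = get_score_fn_sketch(sde, toy_noise_model)
score = score_fn([1.0, -2.0], 1.0)
```

In this sketch the time t is passed as a continuous value in (0, 1], which corresponds to the continuous=True case described above; a discrete-time model would instead need t mapped to an integer step index.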