"""Source code for ls_mlkit.util.cgraph: render a model's autograd compute graph with torchviz."""

import os
from collections.abc import Mapping

import torch
from torchviz import make_dot


def get_compute_graph(
    model: torch.nn.Module,
    input_shape=None,
    input: dict = None,
    dir: str = "compute_graph",
    filename: str = "simple_net_graph",
    format: str = "pdf",
) -> None:
    """Generate the computing graph of model (format default is pdf).

    Runs one forward pass of ``model`` and renders the autograd graph of its
    output with ``torchviz.make_dot`` (parameters labelled by name, with
    ``show_attrs`` and ``show_saved`` enabled).

    Args:
        model (torch.nn.Module): the model to generate the computing graph.
        input_shape (tuple, optional): shape of a random ``torch.randn`` input
            that is passed to ``model`` as a single positional argument. The
            tensor is created on the device (and, if floating point, with the
            dtype) of the model's first parameter. Only used when ``input`` is
            None. Defaults to None.
        input (dict, optional): the input of the model. A mapping is unpacked
            as keyword arguments (``model(**input)``); any other object is
            passed as a single positional argument. Defaults to None.
        dir (str, optional): the directory to save the computing graph.
            Defaults to "compute_graph".
        filename (str, optional): the name of the file to save the computing
            graph. Defaults to "simple_net_graph".
        format (str, optional): the format of the computing graph.
            Defaults to "pdf".

    Returns:
        None

    Raises:
        ValueError: if both ``input`` and ``input_shape`` are None, or if the
            model output is not a tensor, has no non-None ``logits``/``loss``
            attribute, and is not a tuple/list containing a tensor.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if input is None and input_shape is None:
        raise ValueError("error: input is None and input_shape is None")

    output = _extract_output_tensor(_run_forward(model, input_shape, input))
    make_dot(
        output,
        params=dict(model.named_parameters()),
        show_attrs=True,
        show_saved=True,
    ).render(os.path.join(dir, filename), format=format)


def _run_forward(model, input_shape, example_input):
    """Run one forward pass on ``example_input`` or on a random tensor of ``input_shape``."""
    if example_input is None:
        # Match the model's device/dtype so GPU or half-precision models do not
        # fail on a default CPU float32 input.
        param = next(model.parameters(), None)
        device = param.device if param is not None else None
        dtype = param.dtype if param is not None and param.is_floating_point() else None
        random_input = torch.randn(input_shape, device=device, dtype=dtype)
        # A tensor cannot be `**`-unpacked; pass it positionally.
        return model(random_input)
    if isinstance(example_input, Mapping):
        return model(**example_input)
    return model(example_input)


def _extract_output_tensor(output):
    """Return the tensor whose autograd graph should be drawn for a model output."""
    if isinstance(output, torch.Tensor):
        return output
    # Prefer `logits`, then `loss`; skip attributes that exist but are None,
    # since make_dot needs an actual tensor.
    for attr in ("logits", "loss"):
        value = getattr(output, attr, None)
        if value is not None:
            return value
    # Plain tuple/list outputs: draw the graph of the first tensor element.
    if isinstance(output, (tuple, list)):
        for item in output:
            if isinstance(item, torch.Tensor):
                return item
    raise ValueError("Unsupported output type")