Source code for ls_mlkit.model.decoder_tf.generate

import torch

from .tokenizer import get_collate_fn, get_masks


def greedy_decode(model, tokenizer, prompt, max_len, device):
    """Greedily decode up to ``max_len`` tokens from ``prompt``."""
    model.eval()
    model.to(device)
    with torch.no_grad():
        # Tokenize and pad the prompt; the collate fn returns a batch dict whose "x" holds the input ids.
        input_seq = get_collate_fn(tokenizer, max_len=500, train=False)([prompt])["x"]
        input_seq = input_seq.to(device)
        output_seq = []
        for _ in range(max_len):
            # Build the attention mask and padding mask for the current sequence.
            att_mask, pad_mask = get_masks(input_seq, tokenizer)
            att_mask = att_mask.to(device)
            pad_mask = pad_mask.to(device)
            # Forward pass; keep only the logits at the last position.
            logits = model(input_seq, att_mask, pad_mask)
            logits = logits[:, -1, :]
            # Greedy choice: take the highest-scoring token.
            next_token = logits.argmax(dim=-1)
            output_seq.append(next_token.item())
            # Append the new token to the input and continue generating.
            input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
            # Stop once the end-of-sequence token is produced.
            if next_token.item() == tokenizer.eos_id:
                break
        return output_seq
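
For context, a minimal usage sketch follows. It assumes you already have a trained decoder model and a tokenizer from this package; the variable names, the prompt string, and the trailing `tokenizer.decode` call are illustrative assumptions, not part of this module.

    # Hypothetical usage sketch; `model` and `tokenizer` are assumed to be constructed elsewhere.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    token_ids = greedy_decode(
        model=model,          # trained decoder transformer
        tokenizer=tokenizer,  # tokenizer exposing eos_id, as used by get_masks / get_collate_fn
        prompt="Once upon a time",
        max_len=64,
        device=device,
    )

    # If the tokenizer provides a decode method (an assumption), recover the text:
    # text = tokenizer.decode(token_ids)
    print(token_ids)

The function returns the list of generated token ids (not including the prompt), stopping early if the end-of-sequence id is emitted.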