Source code for ls_mlkit.model.decoder_tf.tokenizer

import collections
import copy
import json
import pickle
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm


class Tokenizer:
    def __init__(self):
        self.special_vocab = None
        self.inv_vocab = None
        self.vocab = None
        self.has_add_special_tokens = False
        self.has_build_vocab = False
        self.default_special_token_list = ["[BOS]", "[EOS]", "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]
        self.eos_token = None
        self.pad_token = None
        self.set_eos_token("[EOS]")
        self.set_pad_token("[PAD]")
        # Ids are filled in by build_vocab once the vocabulary exists.
        self.eos_id = None
        self.pad_id = None

    def set_eos_token(self, val: str = "[EOS]"):
        self.eos_token = val

    def set_pad_token(self, val: str = "[PAD]"):
        self.pad_token = val

    def tokenize(self, text):
        # Character-level tokenization: split the text into individual characters.
        return list(text)

    def convert_token_to_id(self, token_list):
        assert self.has_build_vocab, "vocabulary has not been built yet; call the <build_vocab> method first!"
        id_list = []
        for token in token_list:
            # Fall back to the [UNK] id for out-of-vocabulary tokens.
            id_list.append(self.vocab.get(token, self.vocab.get("[UNK]", 0)))
        return id_list

    def convert_id_to_token(self, id_list):
        token_list = []
        for index in id_list:
            token_list.append(self.inv_vocab.get(index, "[UNK]"))
        return token_list

    def build_vocab(self, text_list: list, max_vocab_size=10000, min_freq=1):
        # Count character-level token frequencies over the whole corpus.
        counter = collections.Counter()
        p_bar = tqdm(total=len(text_list), desc="counting tokens in texts")
        for text in text_list:
            tokens = list(text)
            counter.update(tokens)
            p_bar.update(1)
        if not self.has_add_special_tokens:
            self.add_special_tokens(self.default_special_token_list)
        # Special tokens occupy the lowest ids; frequent tokens are appended after them.
        vocab = copy.deepcopy(self.special_vocab)
        p_bar = tqdm(total=min(len(counter), max_vocab_size), desc="assigning ids to tokens")
        for token, freq in counter.most_common(max_vocab_size):
            if freq >= min_freq and token not in vocab:
                vocab[token] = len(vocab)
            p_bar.update(1)
        self.vocab = vocab
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.has_build_vocab = True
        self.eos_id = vocab[self.eos_token]
        self.pad_id = vocab[self.pad_token]

    def add_special_tokens(self, special_token_list: List[str]):
        # Assign consecutive ids to the special tokens, starting at 0.
        self.special_vocab = {token: i for i, token in enumerate(special_token_list)}
        self.has_add_special_tokens = True

    def save_state_dict(self, save_directory="model_pretrained/gpt2"):
        # Persist the full tokenizer state as a pickle and the vocabulary as readable JSON.
        state_dict = self.__dict__
        with open(f"{save_directory}/tokenizer.pkl", "wb") as f:
            pickle.dump(state_dict, f)
        with open(f"{save_directory}/vocab.json", "w") as f:
            json.dump(self.vocab, f)

    def load_state_dict(self, save_directory="model_pretrained/gpt2"):
        with open(f"{save_directory}/tokenizer.pkl", "rb") as f:
            state_dict = pickle.load(f)
        self.__dict__.update(state_dict)

    def get_vocab_size(self):
        return len(self.vocab)

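
# A minimal usage sketch (not part of the original module; the sample texts and
# sizes are illustrative only): build a character-level vocabulary, then round-trip
# a string through ids and back.
def _tokenizer_usage_example():
    tok = Tokenizer()
    tok.build_vocab(["hello world", "hello there"], max_vocab_size=100)
    ids = tok.convert_token_to_id(tok.tokenize("hello"))
    tokens = tok.convert_id_to_token(ids)
    assert tokens == ["h", "e", "l", "l", "o"]
    return ids
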

def get_masks(data: torch.Tensor, tokenizer):
    # Causal (subsequent-position) mask: True marks positions that may not be attended to.
    seq_len = data.shape[1]
    attention_mask = (
        (torch.ones(seq_len, seq_len) - torch.triu(torch.ones(seq_len, seq_len))).type(torch.bool).transpose(0, 1)
    )
    # Padding mask: True wherever the input id equals the PAD id.
    padding_mask = data == tokenizer.pad_id
    return attention_mask, padding_mask

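
# A small sketch (not part of the original module) illustrating the mask semantics,
# assuming build_vocab has already been run so that tokenizer.pad_id is an integer id.
# The attention mask is True strictly above the diagonal (future positions blocked,
# following the torch.nn.Transformer convention that True means "do not attend"),
# and the padding mask is True exactly where the input equals the PAD id.
def _mask_usage_example(tokenizer: Tokenizer):
    # The ids 5, 6, 7 are arbitrary non-PAD ids used purely for illustration.
    batch = torch.tensor([[5, 6, 7, tokenizer.pad_id]])
    att_mask, pad_mask = get_masks(batch, tokenizer)
    assert att_mask.shape == (4, 4) and bool(att_mask[0, 1]) and not bool(att_mask[1, 0])
    assert pad_mask.tolist() == [[False, False, False, True]]
    return att_mask, pad_mask
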

def get_collate_fn(tokenizer: Tokenizer, max_len: int = 500, train=True):
    def transform_text_to_tensor(text: str, tokenizer: Tokenizer):
        # Character-level tokenization; append [EOS] during training so the model learns to stop.
        return torch.Tensor(
            tokenizer.convert_token_to_id(tokenizer.tokenize(text) + ([tokenizer.eos_token] if train else []))
        )

    def collate_fn(batch):
        collated_batch = []
        for sample in batch:
            collated_batch.append(transform_text_to_tensor(sample.rstrip("\n"), tokenizer))
        # Pad to the longest sequence in the batch, then truncate to max_len.
        collated_batch = pad_sequence(
            collated_batch, padding_value=tokenizer.convert_token_to_id([tokenizer.pad_token])[0], batch_first=True
        )
        collated_batch = collated_batch.long()[:, :max_len]
        attention_mask, padding_mask = get_masks(collated_batch[:, :-1] if train else collated_batch, tokenizer)
        # During training, x is the sequence without its last token and y the sequence
        # shifted left by one (next-token targets).
        result = {
            "x": collated_batch[:, :-1] if train else collated_batch,
            "y": collated_batch[:, 1:],
            "att_mask": attention_mask,
            "pad_mask": padding_mask,
        }
        return result

    return collate_fn

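
# Usage sketch wiring the collate function into a DataLoader (not part of the original
# module; the corpus, batch size and max_len below are illustrative only).
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    corpus = ["hello world\n", "goodbye world\n"]
    tokenizer = Tokenizer()
    tokenizer.build_vocab(corpus, max_vocab_size=1000)

    loader = DataLoader(corpus, batch_size=2, collate_fn=get_collate_fn(tokenizer, max_len=32))
    batch = next(iter(loader))
    # batch["x"] / batch["y"] hold shifted input/target ids, batch["att_mask"] is the
    # causal mask and batch["pad_mask"] marks PAD positions.
    print({k: tuple(v.shape) for k, v in batch.items()})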