Source code for ls_mlkit.model.decoder_tf.causal_transformer

import math
from typing import Optional, Tuple

import torch
from torch import Tensor
from transformers import AutoModelForCausalLM, GenerationMixin, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

ACTIVATION_MAP = {
    "relu": torch.nn.ReLU,
    "gelu": torch.nn.GELU,
    "silu": torch.nn.SiLU,
    "tanh": torch.nn.Tanh,
    "sigmoid": torch.nn.Sigmoid,
}


class FeedForwardBlock(torch.nn.Module):
    def __init__(self, embed_dim, k=4, dropout=0.0, bias=False, act="relu"):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embed_dim, k * embed_dim, bias=bias)
        self.act = ACTIVATION_MAP[act]()
        self.linear_2 = torch.nn.Linear(k * embed_dim, embed_dim, bias=bias)
        self.dropout = torch.nn.Dropout(p=dropout)
    def forward(self, x):
        x = self.linear_1(x)
        x = self.act(x)
        x = self.linear_2(x)
        x = self.dropout(x)
        return x

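# Illustrative usage sketch (not part of the original module): the hypothetical
# _demo_feed_forward_block helper below shows the expected tensor shapes for
# FeedForwardBlock. Batch and sequence sizes are arbitrary.
def _demo_feed_forward_block():
    block = FeedForwardBlock(embed_dim=64, k=4, dropout=0.0, act="gelu")
    x = torch.randn(2, 10, 64)  # (batch_size, seq_len, embed_dim)
    y = block(x)
    # The hidden layer expands to k * embed_dim and projects back, so shapes match.
    assert y.shape == x.shape
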
class MultiHeadAttention(torch.nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0,
        bias=False,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.d_model = embed_dim
        self.d_head = embed_dim // num_heads
        self.num_heads = num_heads
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.q_linear = torch.nn.Linear(embed_dim, self.kdim, bias=bias)
        self.k_linear = torch.nn.Linear(embed_dim, self.kdim, bias=bias)
        self.v_linear = torch.nn.Linear(embed_dim, self.vdim, bias=bias)
        self.dropout = torch.nn.Dropout(dropout)
        self.out_linear = torch.nn.Linear(self.vdim, embed_dim, bias=bias)
    def merge_masks(
        self,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        batch_size: int,
        seq_len: int,
    ) -> Tuple[Optional[Tensor], Optional[int]]:
        r"""Determine the mask type and combine masks if necessary.

        If only one mask is provided, that mask and the corresponding mask type are
        returned. If both masks are provided, they are expanded to shape
        ``(batch_size, num_heads, seq_len, seq_len)``, combined with a logical ``or``,
        and mask type 2 is returned.

        Args:
            attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
            key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
            batch_size: batch size of the query embeddings
            seq_len: sequence length of the query embeddings

        Returns:
            merged_mask: merged mask
            mask_type: merged mask type (0, 1, or 2)
        """
        mask_type: Optional[int] = None
        merged_mask: Optional[Tensor] = None

        if key_padding_mask is not None:
            mask_type = 1
            # (batch_size, seq_L)
            merged_mask = key_padding_mask

        if attn_mask is not None:
            mask_type = 2
            # Always expand attn_mask to 4D
            if attn_mask.dim() == 3:  # (batch_size, seq_L, seq_L)
                # (batch_size, seq_L, seq_L) -> (batch_size, 1, seq_L, seq_L)
                attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
            else:  # attn_mask.dim() == 2: (seq_L, seq_L)
                # (seq_L, seq_L) -> (1, 1, seq_L, seq_L) -> (batch_size, head_num, seq_L, seq_L)
                attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
                    batch_size, self.num_heads, -1, -1
                )
            merged_mask = attn_mask_expanded

            if key_padding_mask is not None:
                # (bs, seq_L) -> (bs, 1, 1, seq_L) -> (batch_size, head_num, 1, seq_L)
                key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                    -1, self.num_heads, -1, -1
                )
                # (bs, 1 or head_num, seq_L, seq_L) + (bs, head_num, 1, seq_L) -> (bs, head_num, seq_L, seq_L)
                merged_mask = attn_mask_expanded + key_padding_mask_expanded

        # no attn_mask and no key_padding_mask: returns None, None
        return merged_mask, mask_type
    def attention(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        use_cache: bool = False,
        past_key_values=None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # The mask is only built when there is no cached prefix (i.e. the full
        # sequence is being processed); during cached decoding no mask is applied.
        mask = None
        if not use_cache or (use_cache and past_key_values is None):
            mask, mask_type = self.merge_masks(
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
                batch_size=query.shape[0],
                seq_len=query.shape[-2],
            )
            if mask is not None:
                mask = mask.to(device=query.device, dtype=query.dtype)
        # (bs, head_num, seq_L, kdim) @ (bs, head_num, kdim, seq_L) -> (bs, head_num, seq_L, seq_L)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask != 0.0, float("-inf"))
        scores = torch.softmax(scores, dim=-1)
        scores = self.dropout(scores)
        # (bs, head_num, seq_L, seq_L) @ (bs, head_num, seq_L, vdim) -> (bs, head_num, seq_L, vdim)
        output = torch.matmul(scores, value)
        if need_weights:
            if average_attn_weights:
                # (bs, head_num, seq_L, seq_L) -> (bs, seq_L, seq_L)
                scores = torch.mean(scores, dim=1)
            return output, scores
        else:
            return output
    def forward(
        self,
        q,
        k,
        v,
        key_padding_mask=None,
        attn_mask=None,
        average_attn_weights=True,
        need_weights=True,
        use_cache=False,
        past_key_values=None,
        is_causal=False,
    ):
        # q, k, v: (bs, seq_L, d_model)
        bs = q.size(0)
        k = self.k_linear(k).view(bs, -1, self.num_heads, self.d_head)
        q = self.q_linear(q).view(bs, -1, self.num_heads, self.d_head)
        v = self.v_linear(v).view(bs, -1, self.num_heads, self.d_head)
        # (bs, seq_L, head_num, head_dim) -> (bs, head_num, seq_L, head_dim)
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        if use_cache and past_key_values is not None:
            # Prepend the cached keys/values along the sequence dimension.
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        x = self.attention(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
            use_cache=use_cache,
            past_key_values=past_key_values,
        )
        att_weight = None
        if need_weights:
            x, att_weight = x
        x = x.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        x = self.out_linear(x)
        result = {
            "x": x,
            "attentions": att_weight if need_weights else None,
            "past_key_values": (k, v) if use_cache else None,
        }
        return result

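# Illustrative usage sketch (not part of the original module): the hypothetical
# _demo_multi_head_attention helper shows self-attention with a boolean causal
# mask, where True marks positions that must not be attended to.
def _demo_multi_head_attention():
    mha = MultiHeadAttention(embed_dim=64, num_heads=4)
    x = torch.randn(2, 10, 64)  # (batch_size, seq_len, embed_dim)
    causal_mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
    out = mha(q=x, k=x, v=x, attn_mask=causal_mask, need_weights=True)
    assert out["x"].shape == (2, 10, 64)
    # With average_attn_weights=True (the default) the head dimension is averaged out.
    assert out["attentions"].shape == (2, 10, 10)
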
class AttentionBlock(torch.nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0, batch_first=True):
        super().__init__()
        self.att = MultiHeadAttention(
            embed_dim,
            num_heads,
            dropout=dropout,
            bias=False,
            add_bias_kv=False,
            add_zero_attn=False,
            kdim=embed_dim,
            vdim=embed_dim,
            batch_first=batch_first,
            device=None,
            dtype=None,
        )
        self.dropout = torch.nn.Dropout(p=dropout)
    def forward(
        self,
        x,
        att_mask=None,
        key_padding_mask=None,
        need_weights=True,
        average_attn_weights=True,
        is_causal=True,
        use_cache=False,
        past_key_values=None,
    ):
        # x.shape = (batch_size, seq_len, embed_dim)
        # if use_cache, x.shape = (batch_size, 1, embed_dim)
        att_out = self.att(
            q=x,
            k=x,
            v=x,
            attn_mask=att_mask,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
            use_cache=use_cache,
            past_key_values=past_key_values,
            is_causal=is_causal,
        )
        x = att_out["x"]
        x = self.dropout(x)
        return {
            "x": x,
            "attentions": att_out["attentions"] if need_weights else None,
            "past_key_values": att_out["past_key_values"] if use_cache else None,
        }

class TransformerBlock(torch.nn.Module):
    def __init__(self, embed_dim, num_head, dropout, batch_first=True):
        super().__init__()
        self.att = AttentionBlock(
            embed_dim=embed_dim,
            num_heads=num_head,
            dropout=dropout,
            batch_first=batch_first,
        )
        self.ff = FeedForwardBlock(embed_dim=embed_dim, k=4, dropout=dropout)
    def forward(
        self,
        x,
        att_mask=None,
        key_padding_mask=None,
        need_weights=True,
        average_attn_weights=True,
        use_cache=False,
        past_key_values=None,
        is_causal=True,
    ):
        x_residual = x
        att_out = self.att(
            x,
            att_mask=att_mask,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
            use_cache=use_cache,
            past_key_values=past_key_values,
            is_causal=is_causal,
        )
        x = att_out["x"]
        x = x_residual + x
        x_residual = x
        x = self.ff(x)
        x = x_residual + x
        return {
            "x": x,
            "attentions": att_out["attentions"] if need_weights else None,
            "past_key_values": att_out["past_key_values"] if use_cache else None,
        }

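# Illustrative usage sketch (not part of the original module): the hypothetical
# _demo_transformer_block helper runs one decoder block (self-attention plus
# feed-forward, each wrapped in a residual connection; note there is no layer
# normalization in this implementation).
def _demo_transformer_block():
    block = TransformerBlock(embed_dim=64, num_head=4, dropout=0.0)
    x = torch.randn(2, 10, 64)
    causal_mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
    out = block(x, att_mask=causal_mask)
    # Residual connections keep the output shape identical to the input shape.
    assert out["x"].shape == (2, 10, 64)
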
class CausalLanguageModel(torch.nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_head,
        dropout=0,
        num_block=3,
        max_pos_len=5000,
        batch_first=True,
    ):
        super().__init__()
        self.wte = torch.nn.Embedding(vocab_size, embed_dim)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.blocks = torch.nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dim,
                    num_head=num_head,
                    dropout=dropout,
                    batch_first=batch_first,
                )
                for _ in range(num_block)
            ]
        )
    def generate_square_subsequent_mask(self, sz: int, device=None, dtype=None):
        r"""Generate a square causal mask for the sequence.

        Masked positions are filled with ``True``; unmasked positions are filled with ``False``.
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.bool
        return torch.triu(torch.ones(sz, sz, device=device, dtype=dtype), diagonal=1)
    def forward(
        self,
        x: torch.Tensor,
        att_mask: torch.Tensor = None,
        key_padding_mask: torch.Tensor = None,
        need_weights: bool = True,
        average_attn_weights: bool = True,
        use_cache: bool = False,
        past_key_values: torch.Tensor = None,
        is_causal: bool = True,
        need_hidden_states: bool = False,
    ):
        if is_causal and att_mask is None:
            att_mask = self.generate_square_subsequent_mask(x.size(1), device=x.device)
        att_weight_list = []
        past_key_values_list = []
        x = self.wte(x)
        x = self.dropout(x)
        kv_cache_iter = iter(past_key_values) if (use_cache and past_key_values is not None) else None
        hidden_states_list = []
        for block in self.blocks:
            block_out = block(
                x,
                att_mask=att_mask,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
                use_cache=use_cache,
                past_key_values=next(kv_cache_iter) if kv_cache_iter else None,
            )
            x = block_out["x"]
            if need_weights:
                att_weight_list.append(block_out["attentions"])
            if use_cache:
                past_key_values_list.append(block_out["past_key_values"])
            if need_hidden_states:
                hidden_states_list.append(x)
        return {
            "x": x,
            "attentions": att_weight_list if need_weights else None,
            "past_key_values": past_key_values_list if use_cache else None,
            "hidden_states": hidden_states_list if need_hidden_states else None,
        }

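# Illustrative usage sketch (not part of the original module): the hypothetical
# _demo_causal_language_model helper feeds token ids through the bare backbone.
# The backbone returns hidden states, not vocabulary logits; the LM head lives
# in CausalLanguageModelForAuto below.
def _demo_causal_language_model():
    model = CausalLanguageModel(vocab_size=100, embed_dim=64, num_head=4, num_block=2)
    input_ids = torch.randint(0, 100, (2, 10))
    out = model(input_ids)  # is_causal=True builds the triangular mask automatically
    assert out["x"].shape == (2, 10, 64)
    assert len(out["attentions"]) == 2  # one averaged attention map per block
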
class CausalLanguageModelConfig:
    def __init__(
        self,
        vocab_size=32000,
        embed_dim=1024,
        num_head=2,
        dropout=0,
        num_block=3,
        max_pos_len=5000,
        batch_first=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_head = num_head
        self.dropout = dropout
        self.num_block = num_block
        self.max_pos_len = max_pos_len
        self.batch_first = batch_first
        self.kwargs = kwargs

class CausalLanguageModelConfigForAuto(PretrainedConfig):
    model_type = "D-TF-no-PE"

    def __init__(
        self,
        vocab_size=30000,
        embed_dim=1024,
        num_head=2,
        dropout=0,
        num_block=3,
        max_pos_len=5000,
        batch_first=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_head = num_head
        self.dropout = dropout
        self.num_block = num_block
        self.max_pos_len = max_pos_len
        self.batch_first = batch_first

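# Illustrative usage sketch (not part of the original module): the hypothetical
# _demo_config_round_trip helper shows that CausalLanguageModelConfigForAuto
# serializes through the standard PretrainedConfig machinery.
def _demo_config_round_trip():
    cfg = CausalLanguageModelConfigForAuto(vocab_size=1000, embed_dim=128)
    restored = CausalLanguageModelConfigForAuto.from_dict(cfg.to_dict())
    assert restored.vocab_size == cfg.vocab_size
    assert restored.model_type == "D-TF-no-PE"
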
class CausalLanguageModelForAuto(PreTrainedModel, GenerationMixin):
    config_class = CausalLanguageModelConfigForAuto
    base_model_prefix = "zls_causal_tf"

    def __init__(self, config: CausalLanguageModelConfigForAuto):
        super().__init__(config)
        self.model = CausalLanguageModel(
            vocab_size=config.vocab_size,
            embed_dim=config.embed_dim,
            num_head=config.num_head,
            dropout=config.dropout,
            num_block=config.num_block,
            max_pos_len=config.max_pos_len,
            batch_first=config.batch_first,
        )
        self.lm_head = torch.nn.Linear(config.embed_dim, config.vocab_size, bias=False)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = True,
        average_attn_weights: bool = True,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        # Adjust the inputs to match the format expected by the inner model.
        if use_cache and past_key_values is not None:
            # With a KV cache, only the most recent token needs to be processed.
            input_ids = input_ids[:, -1:]
        model_out = self.model(
            input_ids,
            att_mask=None,
            key_padding_mask=(~attention_mask.bool() if attention_mask is not None else None),
            need_weights=output_attentions,
            average_attn_weights=average_attn_weights,
            use_cache=use_cache,
            past_key_values=past_key_values,
            need_hidden_states=output_hidden_states,
        )
        x = model_out["x"]
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift logits and labels for next-token prediction.
            # logits shape: (batch_size, seq_length, vocab_size)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fn = torch.nn.CrossEntropyLoss()
            loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=model_out["past_key_values"] if use_cache else None,
            hidden_states=model_out["hidden_states"] if output_hidden_states else None,
            attentions=model_out["attentions"] if output_attentions else None,
            cross_attentions=None,
        )
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        # Prepare the inputs consumed by GenerationMixin.generate.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
        }
    def get_input_embeddings(self):
        return self.model.wte
    def get_output_embeddings(self):
        return self.lm_head

def register_model():
    from transformers import AutoConfig, AutoModelForCausalLM

    model_name = CausalLanguageModelConfigForAuto.model_type
    AutoConfig.register(model_name, CausalLanguageModelConfigForAuto)
    AutoModelForCausalLM.register(CausalLanguageModelConfigForAuto, CausalLanguageModelForAuto)

def get_causal_model(
    vocab_size=5000,
    embed_dim=1024,
    num_head=8,
    dropout=0,
    num_block=16,
    max_pos_len=5000,
    batch_first=True,
    **kwargs,
):
    register_model()
    pretrained = kwargs.get("pretrained", False)
    model_name = kwargs.get("model_name", None)
    if pretrained and model_name is not None:
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return model
    config = CausalLanguageModelConfigForAuto(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_head=num_head,
        dropout=dropout,
        num_block=num_block,
        max_pos_len=max_pos_len,
        batch_first=batch_first,
    )
    model = AutoModelForCausalLM.from_config(config=config)
    return model

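# Illustrative usage sketch (not part of the original module): the hypothetical
# _demo_get_causal_model helper builds a small randomly initialized model via the
# AutoModelForCausalLM registry and runs a forward pass with labels to obtain the
# language-modeling loss, mirroring the test in the __main__ block below.
def _demo_get_causal_model():
    model = get_causal_model(vocab_size=1000, embed_dim=128, num_head=4, num_block=2)
    input_ids = torch.randint(0, 1000, (2, 8))
    outputs = model(
        input_ids,
        attention_mask=torch.ones_like(input_ids).bool(),
        labels=input_ids,
    )
    assert outputs.logits.shape == (2, 8, 1000)
    assert outputs.loss is not None  # cross-entropy over shifted next-token targets
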
def generate(
    prompt_tokens: torch.Tensor,
    max_new_tokens: int,
    model,
    use_cache: bool = False,
) -> torch.Tensor:
    """Generate text tokens autoregressively.

    Args:
        prompt_tokens: Input token ids of shape (batch_size, seq_len)
        max_new_tokens: Number of new tokens to generate
        model: Causal language model used for generation
        use_cache: Whether to use the KV cache during generation

    Returns:
        Generated token ids including the prompt, shape (batch_size, seq_len + max_new_tokens)
    """
    # Initialize the generated sequence with the prompt
    generated = prompt_tokens.clone()

    # Initialize past key values for caching
    past_key_values = None

    # Generate tokens one by one
    for _ in range(max_new_tokens):
        # Forward pass
        with torch.no_grad():
            outputs = model(
                input_ids=generated,
                attention_mask=torch.ones_like(generated).bool(),
                past_key_values=past_key_values if use_cache else None,
                use_cache=use_cache,
            )

        # Get the next-token logits
        next_token_logits = outputs.logits[:, -1, :]

        # Simple greedy decoding: take the most likely token
        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        # Concatenate with the generated sequence
        generated = torch.cat([generated, next_tokens], dim=-1)

        # Update past key values if using the cache
        if use_cache:
            past_key_values = outputs.past_key_values

    return generated

def test_gen(test_model, test_config):
    test_model.eval()
    input_ids = torch.randint(0, test_config.vocab_size, size=(1, 5))
    print(input_ids)
    gen_len = 50
    generated_ids = test_model.generate(
        input_ids,
        max_length=gen_len + 5,
        num_return_sequences=1,
        do_sample=False,
        temperature=0,
        # top_k=50,
        # top_p=0.95,
        use_cache=True,
        past_key_values=None,
    )
    print(generated_ids)
    print("HF-Generated sequence shape:", generated_ids.shape)
    generated_ids = generate(input_ids, max_new_tokens=gen_len, model=test_model, use_cache=True)
    print(generated_ids)
    print("use-cache-Generated sequence shape:", generated_ids.shape)
    generated_ids = generate(input_ids, max_new_tokens=gen_len, model=test_model, use_cache=False)
    print(generated_ids)
    print("no-cache-Generated sequence shape:", generated_ids.shape)

if __name__ == "__main__":
    from ...util.seed import seed_everything

    seed_everything(0)
    register_model()
    register_model()
    register_model()
    from transformers import AutoModelForCausalLM

    test_config = CausalLanguageModelConfigForAuto(
        vocab_size=1000,
        embed_dim=256,
        num_head=2,
        dropout=0.1,
        num_block=2,
        batch_first=True,
    )
    print("Test Configuration:")
    print(test_config)
    test_model = AutoModelForCausalLM.from_config(test_config)
    print("\nModel structure:")
    print(test_model)

    # Test the forward pass
    import torch

    batch_size = 2
    seq_length = 20
    input_ids = torch.randint(0, test_config.vocab_size, size=(batch_size, seq_length))
    attention_mask = torch.ones_like(input_ids).bool()
    outputs = test_model(
        input_ids,
        attention_mask=attention_mask,
        labels=input_ids,
        output_attentions=True,
    )
    print("\nOutput shape:")
    print(outputs["logits"].shape)
    print(len(outputs.attentions))
    print(type(outputs))

    def push_to_hub(model, config):
        model.save_pretrained("test_model", repo_id="zengls/decoder-tf", push_to_hub=True)
        config.save_pretrained("test_config", repo_id="zengls/decoder-tf", push_to_hub=True)

    test_gen(test_model, test_config)
    print(test_model.config)