Simplified notes that look only at how the Transformer structure actually runs, including the inference process.
The underlying theory is not covered; for a detailed treatment, see this post or other good articles on the topic.
Basic structure
By default, a Transformer consists of an encoder and a decoder:
Core structure:
Each core block contains:
- Layer norm
- Multi-headed attention
- A skip connection
- A second layer norm
- A feed-forward network
- Another skip connection
Let's look at the llama decoder code, taken from transformers/models/llama/modeling_llama.py. The whole forward pass matches the figure above exactly, except that layer norm is replaced with LlamaRMSNorm:
```python
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention (pre-norm: normalize, attend, then add the residual back)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected (second pre-norm sub-block: normalize, MLP, residual)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
```
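As a quick sanity check, the layer can be run on its own. This is a sketch only: constructor and forward signatures vary across transformers versions, and it assumes one close to the code quoted above (~v4.36/4.37); the config values are arbitrary small numbers:

```python
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Tiny config so the layer runs instantly on CPU
config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_attention_heads=4, num_key_value_heads=4)
layer = LlamaDecoderLayer(config, layer_idx=0)

bsz, seq_len = 2, 8
hidden_states = torch.randn(bsz, seq_len, config.hidden_size)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)

outputs = layer(hidden_states, position_ids=position_ids)
print(outputs[0].shape)  # torch.Size([2, 8, 64]) -- shape is preserved
```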
Multi-Head Attention
The core is multi-head attention. To understand multiple heads, start with a single one.
Single-head attention
This is scaled dot-product attention: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$.
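As a minimal sketch in plain PyTorch (not the transformers implementation), a single attention head is just a few lines:

```python
import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    """q, k, v: [batch, seq_len, head_dim]; mask: additive, broadcastable."""
    # Similarity of every query against every key, scaled by sqrt(d_k)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask  # e.g. -inf on future positions for causal LMs
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ v                        # weighted sum of the values

q = k = v = torch.randn(2, 8, 16)
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 16])
```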
MHA
Self-attention is applied in parallel across multiple heads, and the results are concatenated at the end.
Here is how llama implements it:
```python
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        # Sanity check: hidden_size must be divisible by num_heads.
        # In multi-head attention (MHA), hidden_size is split across the heads,
        # each head working on a slice of size head_dim = hidden_size / num_heads.
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        ...  # omitted

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Tensor-parallel path: split the projection weights into
            # pretraining_tp slices, apply them separately, then concatenate.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # [bsz, q_len, n_heads * head_dim] -> [bsz, n_heads, q_len, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Rotary position embedding: rotates Q and K, leaving shapes unchanged.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Grouped-query attention: repeat each K/V head num_key_value_groups
        # times so K/V line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product: QK^T / sqrt(head_dim)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # [bsz, n_heads, q_len, head_dim] -> [bsz, q_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
```
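The forward pass above calls repeat_kv to expand the K/V heads for grouped-query attention (when num_key_value_heads < num_heads). Paraphrasing the helper in modeling_llama.py, it is essentially:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep):
    [batch, num_key_value_heads, seq, head_dim] -> [batch, num_attention_heads, seq, head_dim]."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states  # plain MHA: nothing to repeat
    # Insert a repeat axis, broadcast it, then fold it into the head axis
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```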
The multi-head attention (MHA) computation can be summarized in the following steps. Assume the input tensor hidden_states has shape [batch_size, seq_length, hidden_size].

1. Linear projections:
   - Query: $Q = hidden\_states \times W^Q$
   - Key: $K = hidden\_states \times W^K$
   - Value: $V = hidden\_states \times W^V$

   Here $W^Q, W^K, W^V \in \mathbb{R}^{hidden\_size \times (num\_heads \times head\_dim)}$ are learnable parameter matrices. After projection, Q, K, and V each have shape [batch_size, seq_length, num_heads * head_dim].
2. Reshape and transpose: Q, K, and V are reshaped and transposed to support per-head computation; the new shape is [batch_size, num_heads, seq_length, head_dim].
3. Apply RoPE (when used): Q and K pass through RoPE with their shapes unchanged, still [batch_size, num_heads, seq_length, head_dim].
4. Compute attention: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$, where $\sqrt{d_k}$ is the scaling factor, usually the square root of head_dim. The attention scores have shape [batch_size, num_heads, seq_length, seq_length].
5. Apply the attention mask (if any): the mask modifies the attention scores to block the model from attending to certain positions. Its shape is typically [batch_size, 1, seq_length, seq_length], and applying it leaves the score shape unchanged.
6. Compute the weighted sum: the weighted values give attn_output = Attention(Q, K, V), with shape [batch_size, num_heads, seq_length, head_dim].
7. Reshape and output projection: attn_output is reshaped back to [batch_size, seq_length, num_heads * head_dim], then passed through an output linear layer that projects it back to [batch_size, seq_length, hidden_size].
Summarized as a formula, the multi-head attention output is:

$$\mathrm{MHA}(hidden\_states) = \mathrm{Concat}(\mathrm{head}_1, \mathrm{head}_2, \ldots, \mathrm{head}_{num\_heads})\,W^O$$

where

$$\mathrm{head}_i = \mathrm{Attention}(hidden\_states\,W^Q_i,\; hidden\_states\,W^K_i,\; hidden\_states\,W^V_i)$$

and $W^O \in \mathbb{R}^{(num\_heads \times head\_dim) \times hidden\_size}$ is another learnable parameter matrix.
This lets the input be processed by multiple attention "heads" in parallel, each attending to different parts of the input; the final output is the concatenation of all head outputs followed by a linear transformation. The mechanism strengthens the model's expressive power by letting it capture information from several subspaces at once.
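To make the shapes concrete, here is a minimal end-to-end sketch of steps 1-7 in plain PyTorch (RoPE omitted at step 3; the names are illustrative, not the transformers API):

```python
import math
import torch
import torch.nn as nn

batch_size, seq_length, hidden_size, num_heads = 2, 8, 64, 4
head_dim = hidden_size // num_heads
x = torch.randn(batch_size, seq_length, hidden_size)

W_q = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
W_k = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
W_v = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
W_o = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

# 1. Linear projections: [batch, seq, num_heads * head_dim]
q, k, v = W_q(x), W_k(x), W_v(x)
# 2. Reshape and transpose: [batch, num_heads, seq, head_dim]
q = q.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
k = k.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
v = v.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
# (3. RoPE would rotate q and k here without changing their shapes.)
# 4. Attention scores: [batch, num_heads, seq, seq]
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
# 5. Causal mask: additive -inf above the diagonal blocks future positions
mask = torch.triu(torch.full((seq_length, seq_length), float("-inf")), diagonal=1)
scores = scores + mask
# 6. Weighted sum of values: [batch, num_heads, seq, head_dim]
out = torch.softmax(scores, dim=-1) @ v
# 7. Reshape back and project: [batch, seq, hidden_size]
out = out.transpose(1, 2).reshape(batch_size, seq_length, num_heads * head_dim)
print(W_o(out).shape)  # torch.Size([2, 8, 64])
```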
Feed Forward
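This is the self.mlp = LlamaMLP(config) call in the decoder layer above. As a sketch of the idea, structured like LlamaMLP's gated SiLU (SwiGLU-style) network; treat the details as version-dependent:

```python
import torch
import torch.nn as nn

class GatedMLP(nn.Module):
    """SwiGLU-style feed-forward block, structured like LlamaMLP."""
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # down(silu(gate(x)) * up(x)): the gate branch modulates the up branch
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
```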
Token
A token embedding table provides an embedding for each token.
There is also a positional embedding table that helps the network understand the relative positions of tokens within each block.
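A minimal sketch of the two tables (learned absolute position embeddings here for illustration; llama itself uses RoPE instead):

```python
import torch
import torch.nn as nn

vocab_size, block_size, hidden_size = 32000, 256, 64

token_embedding_table = nn.Embedding(vocab_size, hidden_size)     # one row per token id
position_embedding_table = nn.Embedding(block_size, hidden_size)  # one row per position

idx = torch.randint(0, vocab_size, (2, 16))           # [batch, seq] of token ids
tok_emb = token_embedding_table(idx)                  # [2, 16, 64]
pos_emb = position_embedding_table(torch.arange(16))  # [16, 64], broadcast over batch
x = tok_emb + pos_emb                                 # input to the first block
print(x.shape)  # torch.Size([2, 16, 64])
```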
Decoder
Cross-Attention
In cross-attention, the queries come from one sequence while the keys and values come from another; Stable Diffusion, for example, uses it to condition image generation on the text prompt.
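The only structural difference from self-attention is where K and V come from. A minimal single-head sketch (illustrative names, not any library's API):

```python
import math
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    """Single-head cross-attention: Q from x, K/V from a separate context sequence."""
    def __init__(self, hidden_size: int, context_size: int):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(context_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(context_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        q = self.q_proj(x)        # [batch, seq_x, hidden]
        k = self.k_proj(context)  # [batch, seq_ctx, hidden]
        v = self.v_proj(context)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        return torch.softmax(scores, dim=-1) @ v  # [batch, seq_x, hidden]
```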
References
Some of the images in this article come from the following sources:
- https://twitter.com/akshay_pachaar/status/1741074169272713577
- Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch
- https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
- https://jalammar.github.io/illustrated-transformer/
- https://zhuanlan.zhihu.com/p/420820453
- Notebooks