Skip to content

模型架构

核心路径 — 理解不同数据模态如何决定模型架构的选择,建立"数据结构 → 架构设计"的底层直觉

学习目标

完成本讲后,你应该能够:

  1. 识别 各种 AI 模型架构所对应的数据结构及其对称性/不变性
  2. 解释 为什么不同模态的数据需要不同的架构设计
  3. 比较 CNN、RNN、Transformer、GNN 的核心思想与适用场景
  4. 描述 自注意力机制如何取代循环和卷积成为通用建模范式
  5. 应用 几何深度学习框架统一理解各类模型架构

一、核心直觉:数据结构决定架构设计

模型架构不是随意发明的。每种经典架构的设计都源于其处理的数据结构所固有的对称性(symmetry)不变性(invariance)

这个问题由 Bronstein 等人在 [[02-基础/02-04-本周阅读|Geometric Deep Learning]] 中系统化阐述:

核心问题: 如果输入数据的结构发生变化(如平移、旋转、排列、缩放),模型的输出应该保持不变还是以可预测的方式变化?

不同架构对这个问题的回答决定了它的设计:

模型类型数据结构关键不变性典型应用奠基论文
CNN / ConvNet网格 (Grid)平移不变性图像、视频、任何规则网格LeNet-5 (1998), AlexNet (2012)
RNN / LSTM序列 (Sequence)时序平移不变性文本、语音、时间序列LSTM (1997)
Transformer序列 + 集合排列不变性 + 位置感知文本、图像、多模态Attention Is All You Need (2017)
Deep Sets集合 (Set)完全排列不变性点云、群体统计、元学习Deep Sets (2017)
GNN / GAT图 (Graph)图同构不变性分子、社交网络、知识图谱GCN (2017), GAT (2018)
Geometric DL齐次空间群不变性/等变性统一框架GDL Blueprint (2021)

关键洞察: 架构的本质是在模型中**编码(encode)**数据结构的先验知识。编码得越好,模型就越数据高效、越能泛化。


二、时序/序列模型:RNN 与 LSTM

数据特征

序列数据具有明确的顺序依赖关系:每个时间步的输入与前后步相关。典型例子包括文本(词的顺序影响语义)、语音(音素的时序性)、时间序列(股票价格)。

RNN 核心思想

RNN 通过**隐藏状态(hidden state)**在时间步间传递信息:

$$h_t = \tanh(W_{hh}h_{t-1} + W_{xh}x_t + b_h)$$

然而 RNN 存在严重的梯度消失/爆炸问题——长序列中早期信息很难传递到后期。

LSTM 的创新

LSTM(Long Short-Term Memory)通过引入门控机制解决这个问题:

  • 遗忘门:决定丢弃哪些旧信息
  • 输入门:决定存储哪些新信息
  • 输出门:决定输出哪些信息
  • 细胞状态(cell state):信息高速公路,让梯度顺畅流过

LSTM 使模型能够捕捉长距离依赖,成为 2010s 中后期 NLP 和语音领域的主流架构。

局限与演进

RNN/LSTM 的核心问题是顺序计算——每个时间步必须等待前一步完成,无法并行化。这为 Transformer 的登场埋下了伏笔。


三、空间卷积模型:CNN 与 ConvNet

数据特征

图像、视频等网格结构数据具有平移不变性(translation invariance):一个特征出现在图像左上角和右下角,其语义含义相同。

CNN 核心设计

CNN 通过三个关键设计编码平移不变性:

  1. 局部连接(Local Connectivity):每个神经元只与输入的一个局部区域连接(感受野),反映"相邻像素相关、远处像素不相关"的先验
  2. 权重共享(Weight Sharing):同一卷积核在整个输入上滑动,实现参数高效和局部特征检测的通用性
  3. 池化(Pooling):下采样引入空间层级结构,同时提供轻微的平移鲁棒性

关键架构演进

  • LeNet-5 (1998):卷积+池化+全连接的原始模板
  • AlexNet (2012):ReLU + Dropout + GPU 并行,引爆深度学习革命
  • VGG (2014):小卷积核堆叠策略,证明深度的重要性
  • ResNet (2016):残差连接解决深层网络退化问题,使上百层网络成为可能

局限

CNN 的核心局限是有限的感受野固定的网格拓扑。它难以处理全局依赖关系(需要极深的层级),也无法直接应用于非网格数据(如点云、图)。


四、Transformer 架构

从 Attention 到 Transformer

注意力机制最初由 Bahdanau 等人(2014)提出用于机器翻译——在解码时"关注"输入序列的不同位置。但 "Attention Is All You Need" (Vaswani et al., 2017) 做出了激进的决定:完全抛弃循环和卷积,只用注意力。

缩放点积注意力 (Scaled Dot-Product Attention)

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

  • Q (Query):当前位置"想找什么"
  • K (Key):所有位置"有什么"
  • V (Value):所有位置"提供什么信息"
  • 注意力权重由 Q 和 K 的点积计算,通过 softmax 归一化后对 V 加权求和

多头注意力 (Multi-Head Attention)

多头注意力并行运行多个注意力"头",每个头学习不同的关注模式:

$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W_O$$

这使模型能够在不同表示子空间同时关注不同类型的依赖关系(如语法关系、语义关系、长程依赖)。

Transformer 的关键优势

维度RNN/LSTMCNNTransformer
计算并行性❌ 顺序计算✅ 可并行✅ 完全并行
长程依赖❌ 随距离衰减❌ 需极深网络✅ 任意位置直接连接
参数效率✅ 时序共享✅ 空间共享⚠️ 计算成本随序列长度平方增长
全局感知❌ 依赖隐藏状态❌ 依赖深度✅ 一步到位

位置编码

由于自注意力本身是**排列等变(permutation equivariant)**的(下文将进一步讨论),Transformer 必须通过位置编码注入位置信息。原始论文使用正弦/余弦编码:

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{\text{model}}})$$ $$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}})$$

Transformer 的统治地位

  • NLP: BERT、GPT 系列、LLaMA——Transformer 已成为语言建模的事实标准
  • 视觉: [[02-基础/02-04-本周阅读|ViT]] 将图像拆分为 patch 序列,证明 CNN 不是视觉处理的唯一路径
  • 多模态: CLIP、LLaVA 等将文本和图像在 Transformer 框架中统一建模
  • 音频/视频: 类似地扩展到语音、音乐、视频理解

五、集合模型:Deep Sets

数据特征

集合数据中元素没有顺序,且数量可变。例如:点云(3D 空间中的点集)、文档中的关键词集合、群体统计数据。

排列不变性的核心约束

集合模型的输出必须对输入元素的任意排列保持不变。这就是排列不变性(permutation invariance)。

Deep Sets (Zaheer et al., 2017)

Deep Sets 证明:所有排列不变函数都可以表示为:

$$f(X) = \rho\left(\sum_{x \in X} \phi(x)\right)$$

其中 $\phi$ 是每个元素的变换(通常是一个 MLP),$\sum$ 是一个排列不变的聚合操作(求和、平均、最大值等),$\rho$ 是对聚合结果做进一步处理的变换。

核心洞察:通过将每个元素独立编码(即 $\phi$ 对所有元素共享参数),然后使用对称聚合函数($\sum$),架构自动获得了排列不变性。

应用场景

  • 3D 点云处理:PointNet 应用 Deep Sets 思想
  • 元学习:任务集合的归纳
  • 群体预测:从个体特征预测群体属性

六、图模型:GNN 与 GAT

数据特征

图数据由**节点(nodes)边(edges)**组成,编码了实体之间任意复杂的拓扑关系。例子包括分子结构(原子=节点,化学键=边)、社交网络、知识图谱、程序代码的抽象语法树。

GNN 核心思想:消息传递

图神经网络(Graph Neural Networks)通过**消息传递(message passing)**在图上传播信息:

$$h_v^{(k)} = \text{UPDATE}^{(k)}\left(h_v^{(k-1)}, \text{AGGREGATE}^{(k)}\left({h_u^{(k-1)} : u \in \mathcal{N}(v)}\right)\right)$$

每个节点 $v$ 的表示通过聚合其邻居 $\mathcal{N}(v)$ 的信息来更新。经过 K 层后,每个节点的表示编码了 K 跳邻域的信息。

GCN 与 GAT

  • GCN (Graph Convolutional Network):用归一化邻接矩阵做卷积,是 CNN 的图推广
  • GAT (Graph Attention Network):引入图注意力机制,每个邻居的权重由注意力计算得出,而非固定归一化——让模型学会"关注更重要的邻居"

GAT 的注意力公式

$$\alpha_{ij} = \frac{\exp(\text{LeakyReLU}(a^T[Wh_i | Wh_j]))}{\sum_{k \in \mathcal{N}_i} \exp(\text{LeakyReLU}(a^T[Wh_i | Wh_k]))}$$

其中注意力系数 $\alpha_{ij}$ 表示节点 $j$ 对节点 $i$ 的重要性。

关键局限

  • 过平滑(Over-smoothing):随着层数增加,所有节点的表示趋向一致
  • 计算复杂度:大图上全批训练困难,需要采样技术(GraphSAGE)
  • 表达力上限:经典 GNN 的表达力上限不超过 Weisfeiler-Lehman 图同构测试(1-WL)

七、几何深度学习:统一框架

Geometric Deep Learning(Bronstein et al., 2021)提出了一个极具洞察的统一视角:

所有模型架构都是特定数据结构下对称性/不变性先验的编码。

核心概念

术语定义实例
对称性 (Symmetry)数据的某种变换不改变其"本质"旋转一幅猫的图像,它仍然是一只猫
不变性 (Invariance)输入变换后模型输出保持不变图像分类应对平移保持不变
等变性 (Equivariance)输入变换后模型输出以可预测方式变换目标检测——图像平移后检测框也平移
群 (Group)所有对称性变换的集合平移群、旋转群、排列群

统一视角下的架构

数据结构     →     对称性群       →     架构设计
网格(Grid)      平移群              CNN (权重共享)
序列(Sequence)   时序平移群          RNN (时序共享参数)
集合(Set)       排列群              Deep Sets (置换不变聚合)
图(Graph)       图同构群            GNN (消息传递)
齐次流形        旋转/平移群         组卷积网络

为什么这很重要?

  1. 指导架构选择:面对新数据类型,分析其对称性就能推导出合适的架构设计
  2. 解释成功原因:为什么 Transformer 如此通用?因为自注意力本质上是排列等变的,而排列群是各种数据结构的公共子群
  3. 提示新方向:可以设计新的群等变架构来处理具身 AI 中的 SE(3) 等变等

八、架构选择的实战指南

面对实际问题如何在各架构间做选择?

阶段问题推荐架构
你的数据是固定长度的规则网格吗?图像、视频帧CNN → ViT
你的数据是变长序列吗?文本、语音、时间序列Transformer
数据量极小、训练资源有限?小样本序列建模LSTM (比 Transformer 更数据高效)
数据是无序集合点云、关键词集Deep Sets / PointNet
数据有复杂拓扑关系分子、知识图谱GNN / GAT
你的数据既不是网格也不是图流形、3D 形状组卷积 / 等变网络

实用提示: 在实际项目中,Transformer 已成为"万金油"架构。但理解其背后的对称性假设(排列等变性 + 位置编码补偿)对于何时选择它、何时选择其他架构至关重要。


关键概念

概念定义
对称性(Symmetry)数据在某种变换下其"本质"保持不变的性质
不变性(Invariance)模型输出对输入的某种变换不敏感
等变性(Equivariance)输入变换后模型输出以可预测的对应方式变化
自注意力(Self-Attention)序列内部每个位置关注其他所有位置的机制
消息传递(Message Passing)GNN 中节点通过边交换信息的范式
排列不变性(Permutation Invariance)模型对输入元素的任意排列输出不变

讨论问题

  1. 为什么自注意力机制是排列等变的?Transformer 如何补偿这一点?
  2. 如果你要处理一个"社交网络时间线"数据集(每个帖子是文本 + 用户的社交关系图),你会选择什么架构组合?为什么?
  3. CNN 的平移不变性在处理旋转后的图像时会失效。有哪些方法可以让 CNN 获得旋转不变性?
  4. Geometric Deep Learning 框架是否能预测未来新的架构设计范式?你有哪些想法?
  5. Transformer 的 $O(n^2)$ 复杂度是其主要瓶颈。你能想到哪些可能的改进方向?

延伸阅读

相关笔记

  • [[01-AI导论/01-01-AI导论|AI导论]]
  • [[02-基础/02-01-数据与结构|数据、结构与信息]]
  • [[02-基础/02-02-实用AI工具|实用AI工具]]
  • [[02-基础/02-04-本周阅读|本周阅读]]
  • [[03-多模态/03-01-连接与对齐|连接与对齐]]
  • [[讨论课/讨论02-现代AI架构|讨论02:现代AI架构]]
  • [[讨论课/讨论01-学习与泛化|讨论01:学习与泛化]]
  • [[MOC-如何AI一切|🗺️ 返回内容地图]]

基于 MIT MAS.S60 How to AI (Almost) Anything 翻译改编