深度学习框架 2025 对比
PyTorch、JAX、Flax、Pax… 深度学习框架的选择比以往更复杂。
当前格局
| 框架 | 维护者 | 市场份额 | 适用场景 |
|---|
| PyTorch | Meta | ~60% | 研究、生产 |
| JAX | Google | ~20% | 研究、大模型 |
| TensorFlow | Google | ~15% | 生产(遗留) |
| PaddlePaddle | 百度 | ~3% | 中国市场 |
| Others | - | ~2% | 特定场景 |
PyTorch:研究者的首选
优点
- 动态图优先:调试友好
- 生态丰富:Hugging Face、torchvision 等
- 社区活跃:问题容易找到答案
- 工业采用:Meta、OpenAI、Tesla
2.x 的变化
# 1.x:eager mode
def train(model, x, y):
optimizer.zero_grad()
output = model(x)
loss = F.cross_entropy(output, y)
loss.backward()
optimizer.step()
# 2.x:compile 加速
model = torch.compile(model) # 一行加速
分布式训练
import torch.distributed as dist
import torch.nn.parallel.DistributedDataParallel as DDP
# 单机多卡
model = DDP(model, device_ids=[local_rank])
# 多机多卡
dist.init_process_group("nccl")
生态
| 库 | 用途 |
|---|
torchvision | 图像 |
torchaudio | 音频 |
transformers | NLP(Hugging Face) |
lightning | 训练框架 |
accelerate | 分布式简化 |
JAX:Google 的新赌注
核心特点
- 函数式:纯函数、不可变
- 可组合变换:grad、jit、vmap、pmap
- XLA 编译:TPU/CPU/GPU 优化
基础用法
import jax
import jax.numpy as jnp
# grad:自动微分
def loss_fn(params, x, y):
return jnp.mean((model(params, x) - y) ** 2)
grad_fn = jax.grad(loss_fn)
# jit:编译加速
@jax.jit
def train_step(params, x, y):
grads = grad_fn(params, x, y)
return update(params, grads)
# vmap:批量处理
batched_model = jax.vmap(model, in_axes=(None, 0))
pmap:分布式
# 多设备并行
@jax.pmap
def train_step(params, x, y):
grads = grad_fn(params, x, y)
return jax.tree_map(lambda p, g: p - lr * g, params, grads)
# 在 8 个 TPU 上运行
params = train_step(params, xs, ys)
生态
| 库 | 用途 |
|---|
Flax | 神经网络 API |
Optax | 优化器 |
Haiku | DeepMind 的 NN 库 |
Equinox | 更 Pythonic 的 API |
Flax vs PyTorch
定义模型
# PyTorch
class Model(nn.Module):
def __init__(self):
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x)
# Flax
class Model(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(256)(x)
x = nn.relu(x)
return nn.Dense(10)(x)
训练循环
# PyTorch
for epoch in range(epochs):
for x, y in dataloader:
optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()
# Flax
@jax.jit
def train_step(state, x, y):
def loss_fn(params):
pred = model.apply(params, x)
return cross_entropy(pred, y)
grads = jax.grad(loss_fn)(state.params)
return state.apply_gradients(grads=grads)
for epoch in range(epochs):
for x, y in dataloader:
state = train_step(state, x, y)
选择建议
用 PyTorch 如果
- 刚入门深度学习
- 需要快速原型
- 使用 Hugging Face 生态
- 需要调试友好
用 JAX/Flax 如果
- 需要 TPU 加速
- 做大模型训练
- 需要函数式纯度
- Google/DeepMind 生态
用 TensorFlow 如果
- 维护遗留项目
- 需要生产部署(TF Serving)
- Google Cloud 深度集成
性能对比
训练速度(ResNet-50,ImageNet)
| 框架 | 硬件 | 吞吐量(images/s) |
|---|
| PyTorch 2.x | A100 | ~3100 |
| JAX/Flax | A100 | ~3200 |
| JAX/Flax | TPU v4 | ~4200 |
大模型训练(LLM)
| 框架 | 优势 |
|---|
| PyTorch + FSDP | 生态成熟 |
| JAX + pjit | TPU 原生支持 |
| Megatron-LM | 多框架支持 |
新兴趋势
1. PyTorch 2.x 的 compile
# 自动优化
model = torch.compile(model, mode="reduce-overhead")
# 自定义后端
model = torch.compile(model, backend="inductor")
2. JAX 的 Pallas(自定义 kernel)
import jax.experimental.pallas as pl
@pl.jit
def custom_kernel(x_ref, o_ref):
# 写 GPU kernel,用 Python
...
3. 多框架互操作
# PyTorch → JAX
import torch2jax
jax_fn = torch2jax.torch2jax(torch_fn)
# JAX → PyTorch
import jax2torch
torch_fn = jax2torch.jax2torch(jax_fn)
生态系统成熟度
| 方面 | PyTorch | JAX |
|---|
| 教程/文档 | 优秀 | 良好 |
| 预训练模型 | 丰富 | 增长中 |
| 分布式 | 成熟 | 优秀(TPU) |
| 部署 | TorchServe/TorchScript | 不成熟 |
| 社区 | 最大 | 快速增长 |
实战建议
新项目
- 研究项目:PyTorch(快速迭代)
- 大模型训练:JAX(TPU 支持)
- 生产部署:PyTorch + TorchServe
迁移成本
| 从 → 到 | 成本 |
|---|
| TF → PyTorch | 中等 |
| PyTorch → JAX | 高(重写) |
| JAX → PyTorch | 高(重写) |
总结
2025 年的深度学习框架:
- PyTorch 仍然是主流:生态最成熟,社区最大
- JAX 正在崛起:Google/DeepMind 的选择,TPU 优化
- TensorFlow 逐渐淡出:主要用于遗留项目
选择取决于场景:
- 快速原型 → PyTorch
- 大规模训练 → JAX
- 生产部署 → PyTorch + TorchServe
学习资源