阶段1:基础知识储备
第四课:信息论基础及其应用
学习目标
- 理解信息熵、交叉熵与KL散度的定义与物理意义。
- 掌握交叉熵作为分类任务损失函数的设计原理。
- 学会用PyTorch实现交叉熵损失函数。
- 实战:用交叉熵损失训练逻辑回归模型。
1. 信息熵(Entropy)
(1) 定义
-
信息熵:度量随机变量的不确定性(单位为比特或纳特)。
H(X) = -\sum_{i=1}^n p(x_i) \log p(x_i)- 示例:公平硬币的熵为 H = -2 \times 0.5 \log_2 0.5 = 1 比特。
(2) 性质
- 分布越均匀,熵越大(如均匀分布的熵最大)。
- 确定事件的熵为0(如概率全为1的分布)。
2. 交叉熵(Cross Entropy)
(1) 定义
- 交叉熵:用分布 q 近似真实分布 p 时的平均信息量。
H(p, q) = -\sum_{i=1}^n p(x_i) \log q(x_i)
- 关键性质:H(p, q) \geq H(p)(等号当且仅当 p = q 时成立)。
(2) 机器学习中的应用
- 分类任务中,真实分布 p 是One-hot编码,预测分布 q 是Softmax输出。
损失函数简化为:H(p, q) = -\log q_{\text{true\_class}}
3. KL散度(Kullback-Leibler Divergence)
(1) 定义
- KL散度:衡量两个概率分布 p 和 q 之间的差异。
D_{\text{KL}}(p \parallel q) = \sum_{i=1}^n p(x_i) \log \frac{p(x_i)}{q(x_i)} = H(p, q) - H(p)
- 非对称性:D_{\text{KL}}(p \parallel q) \neq D_{\text{KL}}(q \parallel p)。
(2) 应用场景
- 变分推断(如VAE)、强化学习中的策略优化。
4. 实战:交叉熵损失与逻辑回归
(1) 问题定义
- 任务:二分类(生成两个高斯分布的数据集)。
- 模型:逻辑回归 P(y=1|x) = \sigma(wx + b),其中 \sigma 为Sigmoid函数。
(2) Python实现
import torch
import numpy as np
import matplotlib.pyplot as plt
# 生成二分类数据
np.random.seed(42)
X0 = np.random.randn(100, 2) + np.array([2, 2]) # 类别0(标签为0)
X1 = np.random.randn(100, 2) + np.array([-2, -2]) # 类别1(标签为1)
X = torch.tensor(np.vstack([X0, X1]), dtype=torch.float32)
y = torch.tensor([0]*100 + [1]*100, dtype=torch.float32).view(-1, 1)
# 定义模型
model = torch.nn.Linear(2, 1)
criterion = torch.nn.BCEWithLogitsLoss() # 内置Sigmoid + BCE损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# 训练
loss_history = []
for epoch in range(100):
# 前向传播
outputs = model(X)
loss = criterion(outputs, y)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_history.append(loss.item())
# 可视化决策边界
w = model.weight.data.numpy()[0]
b = model.bias.data.numpy()
x1 = np.linspace(-5, 5, 100)
x2 = -(w[0] * x1 + b) / w[1]
plt.figure(figsize=(12, 5))
plt.subplot(121)
plt.plot(loss_history)
plt.title("Loss Curve")
plt.subplot(122)
plt.scatter(X0[:,0], X0[:,1], label='Class 0')
plt.scatter(X1[:,0], X1[:,1], label='Class 1')
plt.plot(x1, x2, color='red', label='Decision Boundary')
plt.legend()
plt.show()
评论区