侧边栏壁纸
博主头像
ZHD的小窝博主等级

行动起来,活在当下

  • 累计撰写 81 篇文章
  • 累计创建 53 个标签
  • 累计收到 1 条评论

目 录CONTENT

文章目录

AI大模型学习之路(五)

江南的风
2025-05-19 / 0 评论 / 0 点赞 / 6 阅读 / 5388 字 / 正在检测是否收录...

阶段1:基础知识储备

第四课:信息论基础及其应用


学习目标

  1. 理解信息熵、交叉熵与KL散度的定义与物理意义。
  2. 掌握交叉熵作为分类任务损失函数的设计原理。
  3. 学会用PyTorch实现交叉熵损失函数。
  4. 实战:用交叉熵损失训练逻辑回归模型。

1. 信息熵(Entropy)

(1) 定义

  • 信息熵:度量随机变量的不确定性(单位为比特或纳特)。

    H(X) = -\sum_{i=1}^n p(x_i) \log p(x_i)
    • 示例:公平硬币的熵为 H = -2 \times 0.5 \log_2 0.5 = 1 比特。

(2) 性质

  • 分布越均匀,熵越大(如均匀分布的熵最大)。
  • 确定事件的熵为0(如概率全为1的分布)。

2. 交叉熵(Cross Entropy)

(1) 定义

  • 交叉熵:用分布 q 近似真实分布 p 时的平均信息量。
    H(p, q) = -\sum_{i=1}^n p(x_i) \log q(x_i)
  • 关键性质H(p, q) \geq H(p)(等号当且仅当 p = q 时成立)。

(2) 机器学习中的应用

  • 分类任务中,真实分布 p 是One-hot编码,预测分布 q 是Softmax输出。
    损失函数简化为
    H(p, q) = -\log q_{\text{true\_class}}

3. KL散度(Kullback-Leibler Divergence)

(1) 定义

  • KL散度:衡量两个概率分布 pq 之间的差异。
    D_{\text{KL}}(p \parallel q) = \sum_{i=1}^n p(x_i) \log \frac{p(x_i)}{q(x_i)} = H(p, q) - H(p)
  • 非对称性D_{\text{KL}}(p \parallel q) \neq D_{\text{KL}}(q \parallel p)

(2) 应用场景

  • 变分推断(如VAE)、强化学习中的策略优化。

4. 实战:交叉熵损失与逻辑回归

(1) 问题定义

  • 任务:二分类(生成两个高斯分布的数据集)。
  • 模型:逻辑回归 P(y=1|x) = \sigma(wx + b),其中 \sigma 为Sigmoid函数。

(2) Python实现

import torch
import numpy as np
import matplotlib.pyplot as plt

# 生成二分类数据
np.random.seed(42)
X0 = np.random.randn(100, 2) + np.array([2, 2])    # 类别0(标签为0)
X1 = np.random.randn(100, 2) + np.array([-2, -2])   # 类别1(标签为1)
X = torch.tensor(np.vstack([X0, X1]), dtype=torch.float32)
y = torch.tensor([0]*100 + [1]*100, dtype=torch.float32).view(-1, 1)

# 定义模型
model = torch.nn.Linear(2, 1)
criterion = torch.nn.BCEWithLogitsLoss()  # 内置Sigmoid + BCE损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 训练
loss_history = []
for epoch in range(100):
    # 前向传播
    outputs = model(X)
    loss = criterion(outputs, y)
    
    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    loss_history.append(loss.item())

# 可视化决策边界
w = model.weight.data.numpy()[0]
b = model.bias.data.numpy()
x1 = np.linspace(-5, 5, 100)
x2 = -(w[0] * x1 + b) / w[1]

plt.figure(figsize=(12, 5))
plt.subplot(121)
plt.plot(loss_history)
plt.title("Loss Curve")

plt.subplot(122)
plt.scatter(X0[:,0], X0[:,1], label='Class 0')
plt.scatter(X1[:,0], X1[:,1], label='Class 1')
plt.plot(x1, x2, color='red', label='Decision Boundary')
plt.legend()
plt.show()
0

评论区