【PyTorch】第九节:Softmax 函数与交叉熵函数

news/2024/7/24 12:07:42 标签: pytorch, 深度学习, 机器学习

作者🕵️‍♂️:让机器理解语言か

专栏🎇:PyTorch

描述🎨:PyTorch 是一个基于 Torch 的 Python 开源机器学习库。

寄语💓:🐾没有白走的路,每一步都算数!🐾 

介绍💬

        本实验主要讲解了分类问题中的二分类问题和多分类问题之间的区别,以及每种问题下的交叉熵损失的定义方法。由于多分类问题的输出为属于每个类别的概率,要求概率和为 1 。因此,我们还介绍了如何利用 Softmax 函数,处理神经网络的输出,使其满足损失函数的格式要求。

本文参考蓝桥云课:PyTorch 基础入门实战_机器学习 - 蓝桥云课 (lanqiao.cn) 

知识点🍀

  • 👉二分类和多分类

  • 👉交叉熵损失

  • 👉PyTorch 中的 Softmax 和交叉熵


二分类问题和多分类问题

        二分类问题:表示分类任务有两个类别。比如我们想要识别一副图是否是猫,我们一般会训练出一个分类器,输入一副图片(用向量 x 表示),输出该图片是猫的概率 p。我们可以使用 max 函数判断 p 和 0.5 的大小。如果 p 大,则输出 1(表示该图像为猫)。如果 0.5 大,则输出 0 (表示该图像不为猫)。这就是二分类问题,即输出只为 0 或 1 的分类问题。

        多分类问题:表示分类任务有多个类别。比如我们需要建立一个分类器,用于分辨一堆水果图片中,哪些是橘子、哪些是苹果还有哪些是香蕉。

        在二分类问题中,我们可以使用 max(代码描述为 if a > b return a; else b) 来判断结果,就是非黑即白。但是在多分类问题中,我们就不能这样做。我们需要引入 Softmax 的概念

🌐Softmax

        在机器学习尤其是深度学习中,Softmax 是个非常常用的函数,尤其在多分类的场景中使用广泛。Softmax把输入映射为 0-1 之间的实数,并且通过归一化保证和为 1

        在多分类问题中,我们需要分类器输出每种分类的概率,且为了能够比较概率之间的大小,我们还希望概率之和能够为 1。因此,我们就需要使用 Softmax 函数。 特别是在利用神经网络解决多分类问题时,我们一般都会将输出的最后一层,加上 Softmax 函数,用于规则化输出

假设一个数组为 V,v_i表示 V 中的第 i 个元素,那么这个元素经历了 Softmax 函数后的输出为:

S_i = \frac{e^{v_i}}{\sum_{i=0}^Ne^{v_i}}

        这个定义其实很简单,也就是对输入的值进行了指数化,我认为这里进行指数化的目的是为了扩大任意两个输入之间的差距。将指数化后的值除以总的值,目的是将所有的值缩放到 0-1 之间,并保证所有的输出值相加的和为 1 。

我们可以先利用 NumPy 对其进行实现:

import numpy as np


def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=0)    # 按行相加


x = np.array([2.0, 1.0, 0.1])
outputs = softmax(x)

print('numpy 版 softmax 的输入 :', x)
print('numpy 版 softmax 的输出 :', outputs)
print('numpy 版 softmax 的输出之和:', outputs.sum())

        当然,我们也可以使用 PyTorch 中自带的 Softmax 函数,完成数据的处理:

import torch
import torch.nn as nn
x = torch.tensor([2.0, 1.0, 0.1])
outputs = torch.softmax(x, dim=0)  # dim=0,表示处理的是第 1 维的数据
print('torch 版 softmax 的输入 :', x)
print('torch 版 softmax 的输出 :', outputs)
print('torch 版 softmax 的输出之和:', outputs.sum())

🌐损失函数

        损失函数反映的是预测结果和实际结果之间的差距,即从预测结果到实际结果需要走的距离,即所需要消耗的成本,故称之为损失函数。

        这里让我们介绍一种常用的损失函数:交叉熵损失。那么为什么我们需要使用交叉熵损失作为我们的损失函数呢?为什么我们不直接用错误率来作为损失函数,用梯度下降算法得到错误率最低时的模型呢?为了能够更好的阐述上面这个问题,让我们来举个例子。

葡萄酒的种类预测

        我们希望通过葡萄酒的酒精浓度 、苹果酸浓度 、灰分浓度等独立特征,来预测该葡萄酒源产地。假设数据集中有三种源产地:英国、法国和美国。

        这里我们建立了两个模型用以预测葡萄酒的种类。每个模型都会输出三个值,即输入的葡萄酒来源于英国、法国和美国的概率。

        这里我们对两个模型输入了三条相同的数据,得到的结果如下:

模型 1 :

        从结果可以看出,模型 1 对于样本 1 和样本 2 以非常微弱的优势(概率只比其他结果高 0.1)判断正确,对于样本 3 的判断则彻底错误。

模型 2 :

        模型 2 对于样本 1 和样本 2 的判断非常准确(概率比其他结果高很多)。模型 2 对于样本 3 的判断错误,但是相对来说没有错得太离谱(概率只比其他结果高 0.1)。好了,有了模型之后,我们需要通过定义损失函数来判断模型在样本上的表现,那么我们可以定义哪些损失函数呢?

        如果使用简单的分类错误率作为损失函数,那么两个模型的分类错误率为:

模型 1:

classification\_error = \frac{1}{3}

模型 2:

classification\_error = \frac{1}{3}

        从结果可以看出,如果使用分类错误率来衡量两个模型的好坏,那么这两个模型的好坏程度相同。我们从上面的结果可以看出,虽然模型 1 和模型 2 都预测错了 1 个,但是相对来说,模型 2 的预测效果更好,损失函数照理来说应该更小。因此,我们使用分类错误率不能很好的描述模型的优劣。

        为此,我们引入了交叉熵损失函数用以描述模型的优劣。

🌐交叉熵损失函数

交叉熵损失函数有两种形式:二分类形式多分类形式

  • 🍒二分类的任务 

在 二分类的任务 中,模型最后需要预测的结果只有两种情况,对于每个类别,我们的预测得到的概率为 p 和 1−p 。此时的交叉熵损失(又叫二进制交叉熵)为:

L = \frac{1}{N}\sum_iL_i=\frac{1}{N}\sum_i-[y_i\cdot log(p_i)+(1-y_i)\cdot log(1-p_i)]

其中:

  • y_i:表示样本 i 的真实标签值,正类为 1,负类为 0。
  • p_i:表示样本 i 的预测为正的概率。

        当然,我们不必手动实现上面的损失函数。我们可以利用 nn.BCELoss() 定义二值交叉熵损失。如下:

loss = nn.BCELoss()
# 假设两个模型最后的预测结果相同,但是概率不同
model_one_pred = torch.tensor([0.9, 0.6])
model_two_pred = torch.tensor([0.6, 0.9])
# 真实结果
target = torch.FloatTensor([1, 0])

# 计算两种模型的损失
l1 = loss(model_one_pred, target)
l2 = loss(model_two_pred, target)
l1, l2
# (tensor(0.5108), tensor(1.4067))

        从上面代码中可以看出,其实模型 1 和模型 2 都预测对了一条数据,预测错了一条数据。但是,由于模型 1 是在概率差距很大的情况下,预测正确的。因此,模型 1 的预测效果比模型 2 好,即模型 1 的损失应当比模型 2 的损失小。

  • 🍒 多分类任务 

        在 多分类任务 中,交叉熵损失的函数形式会发生一定的改变(其实就是二进制交叉熵损失的扩展):

L=\frac{1}{M}\sum_iL_i=-\frac{1}{M}\sum_{c=1}^My_{ic}log(p_{ic})

其中:

  • M: 类别的数量。
  • y_{ic}​: 标签的 one-hot 编码,如果该类别和样本 i 的类别相同,就是 1,否则为 0。
  • p_{ic}​: 对于观察样本 i 属于类别 c 的预测概率。

        利用 PyTorch 中的 nn.CrossEntropyLoss() 定义多分类任务的交叉熵损失函数。但是,实际上 nn.CrossEntropyLoss() 是包含了 nn.LogSoftmax()  nn.NLLLoss() 。 因为这里已经是概率值了,所以我们使用 nn.NLLLoss() 来计算交叉熵损失。

        接下来让我们利用交叉熵损失来评估一下上面建立的两个葡萄酒预测模型的好坏:

loss = nn.NLLLoss()
# 三条数据的真实结果:法国、美国、英国
Y = torch.tensor([2, 1, 0])

# 模型一对每条数据的预测,每条数据对应三个概率,表示该条数据属于第 i 类的概率值
model_one_pred = torch.tensor(
    [[0.3, 0.3, 0.4],  # predict class 2
     [0.3, 0.4, 0.3],  # predict class 1
     [0.1, 0.2, 0.7]])  # predict class 0

# 模型 2 对每条数据的预测,每条数据对应三个概率,表示该条数据属于第 i 类的概率值
model_two_pred = torch.tensor(
    [[0.1, 0.2, 0.7],  # predict class 2
     [0.1, 0.7, 0.2],  # predict class 1
     [0.4, 0.3, 0.3]])  # predict class 0
l1 = loss(torch.log(model_one_pred), Y)
l2 = loss(torch.log(model_two_pred), Y)
l1, l2
# (tensor(1.3784), tensor(0.5432))

        从上面的损失大小可以看出,通过交叉熵损失对模型进行评估的话,模型 2 的效果要优于模型 1 的效果,这是符合我们的客观想法的。

        综上,这就是我们为什么使用交叉熵损失的原因。交叉熵损失函数除了考虑模型的准确率之外,还将模型的鲁棒性等因素考虑了进去,能够更好的评价模型的好坏,使训练出来的模型具有更加稳定的预测准确率。

        从上面的输入可以看出,交叉熵损失需要的的输入每条数据所属种类的概率,且这些概率之和为 1。说到这里,我想你已经联想到了本实验开始学到的 Softmax 函数了吧。

        实际上,我们一般无法严格的将神经网络的输出控制在 0 - 1 之间,更无法使这些值之和等于 1。因此我们一般会在神经网络的最后一层,加上 softmax 函数,得到每种种类的概率值。然后将概率值放入交叉熵损失函数之中,得到预测结果和真实结果之间的距离

实验总结📌

        本实验以葡萄酒的类别为预测的例,以一种简单的方式,理解了交叉熵损失与其他损失的不同,以及引入交叉熵损失的原因。当然,从模型的训练角度来讲,引入交叉熵损失函数也是为了能够加快模型的收敛速度。


http://www.niftyadmin.cn/n/233524.html

相关文章

怎么远程控制电脑

为什么要从另一台电脑远程控制电脑? 如今,Splashtop已广泛应用于各个领域。 在很多情况下,您需要从另一台远程电脑控制一台电脑。 这里演示了两个例子: 1:当您不在同一楼层时,您的同事需要您的帮助来解决…

Flutter开发日常练习-小猫咪杂货店

贴三张效果图 1.欢迎页面 2.商品展示列表 3.购物车页面 因为数据是本地的所以创建本地数据 final List _shopItems [["ZaoShui.", "25.00", "assets/8b10de68e58cfef6bd5f22e5321537.jpg", Colors.green],["ZaoQi.", "25.0…

核心业务5:充值业务实现

核心业务5:我要充值 1.充值业务流程图 2.充值业务流程逻辑 3.数据库表 4.前端逻辑代码 5.汇付宝代码逻辑 6.尚融宝代码逻辑 7.幂等性判断原理和解决方案 8.代码规范和原理了解 核心业务5:我要充值 1.充值业务流程图

【Redis数据库】异地公网远程登录连接Redis教程

文章目录1. Linux(centos8)安装redis数据库2. 配置redis数据库3. 内网穿透3.1 安装cpolar内网穿透3.2 创建隧道映射本地端口4. 配置固定TCP端口地址4.1 保留一个固定tcp地址4.2 配置固定TCP地址4.3 使用固定的tcp地址连接转发自CSDN远程穿透的文章:公网远程连接Redi…

[NLP] SentenceTransformers使用介绍

SentenceTransformers 是一个可以用于句子、文本和图像嵌入的Python库。 可以为 100 多种语言计算文本的嵌入并且可以轻松地将它们用于语义文本相似性、语义搜索和同义词挖掘等常见任务。 该框架基于 PyTorch 和 Transformers,并提供了大量针对各种任务的预训练模型…

NanoPC-T4 RK3399:(三)Kernel编译框架

一:获取Kernel Source build_main() {build_task_is_enabled "kernel" && build_get_kernel_sources }build_get_kernel_sources() {fetch_from_repo "$KERNELSOURCE" "$KERNELDIR" "$KERNELBRANCH" "yes" }$KERN…

电网开源数据-持续更新

1、风电光伏出力 1)欧盟官方——PVGIS(https://re.jrc.ec.europa.eu/pvg_tools/en/) 2)帝国理工学院——Renewables.ninja(https://www.renewables.ninja/) 2、负荷 1)比利时Elia电网公司风、光、负荷的历史、实时、…

devOps

实现DevOps需要什么? 硬性要求:工具上的准备 上文提到了工具链的打通,那么工具自然就需要做好准备。现将工具类型及对应的不完全列举整理如下: 代码管理(SCM):GitHub、GitLab、BitBucket、Su…