Pytorch深度学习教程_6_激活函数

news/2025/2/26 3:25:58

欢迎来到《pytorch深度学习教程》系列的第六篇!在前面的五篇中,我们已经介绍了Python、numpy及pytorch的基本使用,进行了梯度及神经网络的实践。今天,我们将深入理解激活函数并进行简单的实践学习

欢迎订阅专栏进行系统学习:

深度学习保姆教程_tRNA做科研的博客-CSDN博客


目录

激活函数

1.线性和非线性函数

线性函数

非线性函数

2.Sigmoid、Tanh和ReLU

Sigmoid函数

Tanh函数

ReLU(修正线性单元)‌

选择合适的激活函数

3.其他激活函数

Leaky ReLU(ReLU的一种变体)‌

Parametric ReLU(PReLU)‌

Exponential Linear Unit(ELU)‌

Swish

选择合适的激活函数

4.结语


激活函数

激活函数为神经网络引入了非线性,使它们能够学习复杂的模式。它们决定了神经元基于其输入的输出。

1.线性和非线性函数

理解线性和非线性函数之间的基本差异在机器学习中至关重要。它们构成了构建复杂模型的基础。

线性函数

线性函数展示了输入和输出之间的直线关系。它们的特点是变化率恒定

一般形式:‌ y = mx + b

  • m: 斜率,表示变化率。
  • b: 截距,表示线与y轴的交点。

示例:

import numpy as np
import matplotlib.pyplot as plt

def linear_function(x, m, b):
    """
    线性函数
    :param x: 输入值
    :param m: 斜率
    :param b: 截距
    :return: 输出值
    """
    return m * x + b

# 生成从-5到5的100个等间隔点
x = np.linspace(-5, 5, 100)
# 计算对应的y值,斜率为2,截距为1
y = linear_function(x, 2, 1)

# 绘制图形
plt.plot(x, y)
plt.xlabel('x')        # x轴标签
plt.ylabel('y')        # y轴标签
plt.title('linear function')  # 图形标题
plt.show()             # 显示图形

非线性函数

非线性函数不遵循直线模式。它们引入了复杂性,并允许模型捕捉数据中的复杂关系。

常见示例:

  • 多项式函数:y = ax^2 + bx + c
  • 指数函数:y = a^x
  • 对数函数:y = log(x)
  • 三角函数:sin(x), cos(x), tan(x)

示例:

import numpy as np
import matplotlib.pyplot as plt

# 定义多个非线性函数
def quadratic_function(x):
    """
    二次函数
    :param x: 输入值
    :return: 输出值
    """
    return x**2

def cubic_function(x):
    """
    三次函数
    :param x: 输入值
    :return: 输出值
    """
    return x**3

def exponential_function(x):
    """
    指数函数
    :param x: 输入值
    :return: 输出值
    """
    return np.exp(x)

def logarithmic_function(x):
    """
    对数函数
    :param x: 输入值
    :return: 输出值
    """
    return np.log(np.abs(x) + 1)  # 使用绝对值避免负数对数

# 生成从-5到5的100个等间隔点
x = np.linspace(-5, 5, 100)

# 计算各个函数对应的y值
y_quadratic = quadratic_function(x)
y_cubic = cubic_function(x)
y_exponential = exponential_function(x)
y_logarithmic = logarithmic_function(x)

# 创建一个新的图形
plt.figure(figsize=(10, 6))

# 绘制各个函数的图像
plt.plot(x, y_quadratic, label='Quadratic ($x^2$)', color='blue')
plt.plot(x, y_cubic, label='Cubic ($x^3$)', color='red')
plt.plot(x, y_exponential, label='Exponential ($e^x$)', color='green')
plt.plot(x, y_logarithmic, label='Logarithmic ($\\log(|x|+1)$)', color='orange')

# 添加图例
plt.legend()

# 设置轴标签和标题
plt.xlabel('x')          # x轴标签
plt.ylabel('y')          # y轴标签
plt.title('function')  # 图形标题

# 显示网格
plt.grid(True)

# 显示图形
plt.show()

 

为什么非线性在机器学习中至关重要

  • 复杂模式: 真实世界的数据通常表现出非线性关系
  • 决策边界: 非线性函数使模型能够学习复杂的决策边界
  • 深度学习 非线性激活函数对于深度神经网络是必不可少的。

实际应用

  • 线性回归: 基于线性关系预测连续数值。
  • 逻辑回归: 使用非线性的sigmoid函数将数据分类到不同的类别。
  • 神经网络: 使用多层非线性函数来学习复杂的模式。

2.Sigmoid、Tanh和ReLU

激活函数是神经网络的心脏和灵魂。它们引入了非线性,使模型能够学习复杂的模式。让我们探索一些最常用的激活函数:Sigmoid、Tanh和ReLU。

Sigmoid函数

Sigmoid函数将任何实数映射到0和1之间的值。它常用于二分类问题的输出层。

示例:

import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):
    """
    Sigmoid激活函数
    :param x: 输入值
    :return: 输出值
    """
    return 1 / (1 + np.exp(-x))

# 生成从-10到10的100个等间隔点
x = np.linspace(-10, 10, 100)
# 计算对应的y值
y = sigmoid(x)

# 绘制图形
plt.plot(x, y)
plt.xlabel('x')            # x轴标签
plt.ylabel('y')            # y轴标签
plt.title('Sigmoid')  # 图形标题
plt.show()                 # 显示图形

挑战:

  • 梯度消失问题:‌ 梯度可能变得非常小,从而减慢训练速度。
  • 不是零中心化的:‌ 输出总是正的,这可能影响收敛。

Tanh函数

Tanh函数将输入值映射到-1到1的范围内。由于它是零中心化的,因此通常比Sigmoid更受青睐。

示例:

import numpy as np
import matplotlib.pyplot as plt

def tanh(x):
    """
    Tanh激活函数
    :param x: 输入值
    :return: 输出值
    """
    return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))

# 生成从-10到10的100个等间隔点
x = np.linspace(-10, 10, 100)
# 计算对应的y值
y = tanh(x)

# 绘制图形
plt.plot(x, y)
plt.xlabel('x')            # x轴标签
plt.ylabel('y')            # y轴标签
plt.title('Tanh')      # 图形标题
plt.show()                 # 显示图形

 

ReLU(修正线性单元)

ReLU函数是目前使用最广泛的激活函数。它输出输入值和0之间的最大值。

示例:

import numpy as np
import matplotlib.pyplot as plt

def relu(x):
    """
    ReLU激活函数
    :param x: 输入值
    :return: 输出值
    """
    return np.maximum(0, x)

# 生成从-5到5的100个等间隔点
x = np.linspace(-5, 5, 100)
# 计算对应的y值
y = relu(x)

# 绘制图形
plt.plot(x, y)
plt.xlabel('x')            # x轴标签
plt.ylabel('y')            # y轴标签
plt.title('ReLU函数')      # 图形标题
plt.show()                 # 显示图形

ReLU的优势:

  • 计算效率高。
  • 缓解了梯度消失问题。

选择合适的激活函数

激活函数的选择取决于问题和神经网络的架构。

  • Sigmoid:‌ 常用于二分类问题的输出层。
  • Tanh:‌ 在隐藏层中通常比Sigmoid表现更好。
  • ReLU:‌ 由于其简单性和高效性,是隐藏层中最受欢迎的选择。

3.其他激活函数

虽然Sigmoid、Tanh和ReLU是基础的激活函数,但激活函数的世界提供了多种多样的选项,以适应不同的神经网络架构和问题领域。

Leaky ReLU(ReLU的一种变体)

Leaky ReLU是ReLU函数的一个变体,旨在通过为负输入引入一个小的、非零的梯度来解决“死亡ReLU”问题。

公式:

LeakyReLU(x) = max(αx, x)

其中α是一个小的正常数(通常为0.01)。

import numpy as np
import matplotlib.pyplot as plt

# 定义Leaky ReLU函数
def leaky_relu(x, alpha=0.01):
    return np.maximum(alpha * x, x)

# 创建一个输入数组
x = np.linspace(-5, 5, 1000)
y = leaky_relu(x)

# 绘制Leaky ReLU函数图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='Leaky ReLU')
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Leaky ReLU Activation Function')
plt.legend()
plt.grid(True)
plt.show()

可以看到0之前的线是不等于0的非常小的数值

Parametric ReLU(PReLU)

PReLU是Leaky ReLU的一个扩展,其中负输入的斜率是一个可学习的参数。

PReLU的数学表达式如下:

f(x) = max(αx, x)

import numpy as np
import matplotlib.pyplot as plt

# 定义PReLU函数
def prelu(x, alpha):
    return np.where(x >= 0, x, alpha * x)

# 创建一个输入数组
x = np.linspace(-5, 5, 1000)
alpha = 0.05  # 初始设定α为0.01,实际应用中α是可学习的参数
y = prelu(x, alpha)

# 绘制PReLU函数图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='PReLU with α={}'.format(alpha))
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Parametric ReLU (PReLU) Activation Function')
plt.legend()
plt.grid(True)
plt.show()

 

Exponential Linear Unit(ELU)

ELU试图结合ReLU和tanh的优点。它对负输入输出负值,有助于梯度流动。

ELU的数学表达式如下:

f(x) = { α(e^x - 1) if x ≤ 0
x if x > 0 }

其中,α是一个超参数,通常设置为一个小的常数,例如0.1。当x为负值时,ELU函数的输出是α乘以(e^x - 1),这确保了即使在负值区域,梯度也不会完全消失,从而允许网络继续学习。

import numpy as np
import matplotlib.pyplot as plt

# 定义ELU函数
def elu(x, alpha=0.1):
    return np.where(x >= 0, x, alpha * (np.exp(x) - 1))

# 创建一个输入数组
x = np.linspace(-3, 3, 1000)
y = elu(x)

# 绘制ELU函数图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='ELU with α={}'.format(alpha))
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Exponential Linear Unit (ELU) Activation Function')
plt.legend()
plt.grid(True)
plt.show()

Swish

Swish是一个自门控的激活函数,平滑地插值于线性和ReLU行为之间。

公式为:

Swish(x) = x * sigmoid(βx)

其中β是一个可学习的参数。

Swish激活函数的特点包括:

  • 自我门控(Self-gating):Swish函数通过x*Sigmoid(βx)的形式实现,简化了gating机制,允许其直接替代ReLU等单输入激活函数,而无需改变网络结构。
  • 避免梯度消失问题:Swish函数的导数始终大于0,这有助于缓解梯度消失问题。
  • 平滑性:Swish函数具有平滑性,有利于优化和泛化。
import numpy as np
import matplotlib.pyplot as plt

# 定义Swish激活函数
def swish(x, beta=1):
    return x * (1 / (1 + np.exp(-beta * x)))

# 创建一个输入数组
x = np.linspace(-5, 5, 1000)
y = swish(x)

# 绘制Swish函数图像
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='Swish Activation Function')
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Swish Activation Function Visualization')
plt.legend()
plt.grid(True)
plt.show()

选择合适的激活函数

最佳的激活函数取决于多种因素:

  • 问题类型:‌ 分类、回归或生成任务。
  • 网络架构:‌ 网络的深度和复杂性。
  • 数据特性:‌ 输入数据的分布。
  • 计算资源:‌ 一些激活函数的计算成本更高。

实验和微调

确定最佳激活函数的最佳方法是通过实验。尝试不同的选项并评估它们在特定任务上的性能。

4.结语

以上就是激活函数本次的教程,如果有什么问题欢迎评论区一起讨论!


http://www.niftyadmin.cn/n/5867126.html

相关文章

算法-数据结构-图的构建(邻接矩阵表示)

数据定义 //邻接矩阵表示图 //1.无向图是对称的 //2.有权的把a,到b 对应的位置换成权的值/*** 无向图* A B* A 0 1* B 1 0*/ /*** 有向图* A B* A 0 1* B 0 0*/import java.util.ArrayList; import java.util.List;/*** 带权图* A B* A 0 1* B 0 0*/ p…

Android NDK基本开发流程

Android NDK(Native Development Kit)开发流程允许开发者使用C/C代码来开发Android应用的部分功能,通常用于性能敏感的场景,如游戏、图像处理等。以下是Android NDK开发的基本流程: 1. 环境准备 安装Android Studio&a…

LabVIEW不规则正弦波波峰波谷检测

在处理不规则正弦波信号时,准确检测波峰和波谷是分析和处理信号的关键任务。特别是在实验数据、传感器信号或其他非理想波形中,波峰和波谷的位置可以提供有价值的信息。然而,由于噪声干扰、信号畸变以及不规则性,波峰波谷的检测变…

Docker 搭建 Redis 数据库

Docker 搭建 Redis 数据库 前言一、准备工作二、创建 Redis 容器的目录结构三、启动 Redis 容器1. 通过 redis.conf 配置文件设置密码2. 通过 Docker 命令中的 requirepass 参数设置密码 四、Host 网络模式与 Port 映射模式五、检查 Redis 容器状态六、访问 Redis 服务总结 前言…

使用 AndroidNativeEmu 调用 JNI 函数

版权归作者所有,如有转发,请注明文章出处:https://cyrus-studio.github.io/blog/ AndroidNativeEmu AndroidNativeEmu 专为 Android 原生代码调试和模拟设计,特别关注 JNI 调用和 Android 环境。相比之下,Unicorn 是通…

Flask应用实战经验总结:使用工厂函数创建app与uWSGI服务部署启动失败解决方案

在 Flask 应用开发中,使用工厂函数创建应用实例,并借助 uWSGI 服务进行部署,是常见且高效的组合。 然而,在实际操作过程中,uWSGI 配置文件与应用启动函数之间的关系复杂,容易引发各种问题。 本文将详细探…

算法系列之搜素算法-二分查找

在算法中,查找算法是处理数据集合的基础操作之一。二分查找(Binary Search)是一种高效的查找算法,适用于有序数组或列表。本文将介绍二分查找的基本原理、Java实现。 二分查找介绍 二分查找是一种在有序数组中查找特定元素的算法…

《一起打怪兽吧》——自制一款Python小游戏

《一起消灭怪兽吧》——在深夜的屏幕前,你是指引光明的勇者。键盘化作利剑,用方向键在像素战场游走,发射吧,每次击杀都有代码绽放的烟火。这款由Python与Pygame铸就的小游戏,让0与1的世界生长出童真的浪漫。 文章目录…