mxnet使用Gluon来创建神经网络(mxnet.whl)
前言
在上一篇MXNet快速入门之NDArray介绍中,我们已经详细介绍过了MXNet的核心数据结构NDArray的基本用法,在使用MXNet构建神经网络中,到处都可以看到它的影子,如果对NDArray不了解可以参考上一篇文章。本篇文章主要介绍如何使用MXNet来构建神经网络。
Gluon
可以利用Gluon中的nn模块很便利的去创建全连接层、卷积层、池化层、激活层
1.全连接层
mxnet.gluon.nn.Dense(units, activation=None, use_bias=True, flatten=True, dtype=‘float32’, weight_initializer=None, bias_initializer=‘zeros’, in_units=0, **kwargs)
函数功能说明:Dense函数主要实现了一个全连接层的功能output=activation(dot(input,weight)+bias),output表示的是函数的输出值,input表示的是输入值,activation表示输出元素输出时经过的激活函数,weight表示的是权重矩阵,bias表示的是偏置,当use_bias参数为True时偏置才会被创建。
参数说明:
- units(int):输出矩阵的大小
- activation(str):设置使用的激活函数,常用的激活函数有relu、sigmoid、tanh、softrelu、softsign等。如果activation为None,则不使用激活函数,即linear函数,f(x)=x。
- use_bias(bool,default True):是否创建偏置向量
- flatten(bool,default True):输入的向量是否被展开,如果为True则除了第一个axis保持不变,其他axis都要被折叠在一起。如果为false则,除了最后一个axis其他的都保持不变。怎么理解这段话呢?下面举个例子吧,比如说,你输入矩阵shape为(3,4,4),units为10,当flatten为True时,那么输出矩阵的shape就为(3,10),当flatten为False时,输出矩阵的shape就为(3,4,10)
- dtype(str or np.dtype,default “float32”):输出数据的数据类型
- bias_initializer(str or initializer):设置偏置的初始化函数
- in_units(int,optional):输入数据的大小,不需要特殊指定,mxnet会根据前向传播自动推断输入数据的shape
注意:输入矩阵必须是2阶的,否则我们需要通过flatten来将输入矩阵转换为2阶
利用nn.Dense方法来创建全连接层,还可以输出该层的结构
全连接层的前向传播,打印全连接层的输出值
输出全连接层的权重值
2.卷积层
mxnet.gluon.nn.Conv2D(channels, kernel_size, strides=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1, layout=‘NCHW’, activation=None, use_bias=True, weight_initializer=None, bias_initializer=‘zeros’, in_channels=0, **kwargs)
函数功能说明:实现2D卷积
参数说明:
- channels(int):设置卷积输出的通道数,也就是卷积核的个数
- kernel_size(int or tuple/list of 2 int):设置卷积核的shape
- strides(int or tuple/list of 2 int):设置卷积的步长
- padding(int or a tuple/list of 2 int):卷积的时候是否需要在输入矩阵的周围填充0,从而来控制输出矩阵的大小
- dilation(int or tuple/list of 2 int):设置膨胀卷积的参数,默认不使用膨胀卷积,膨胀卷积的目的是为了增大卷积核的感受野
- groups(int):用于控制输入和输出之间的连接。当groups=1时,所有输入进行卷积产生输出。当groups=2时,等价于两个卷积层并排进行卷积,然后再将输出的结果进行连接,每个卷积层只使用了输入举证一半的channels
- layout(str,default “NCHW”):输入数据的数据格式,只有"NCHW"和"NHWC"两种格式,其中"N"表示的是batch,"C"表示的是channel,"H"表示的是height,"W"表示的是width
- in_channels(int,default 0):输入矩阵的通道数,如果没有特殊指定,mxnet会根据输入数据的shape推断出channel
- activation(str):设置激活函数
- use_bias(bool):是否使用偏置
- weight_initializer(str or initializer):设置初始化权重的方法
- bias_initializer(str or initializer):设置偏置的初始化方法
输入数据:输入数据是一个4D矩阵,输入数据的格式请参考layout参数
输出数据:输出数据的格式和layout相同,输出数据的out_height和out_width计算公式如下:
out_height = floor((height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1 ) / stride[0] ) + 1
out_width = floor((width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) /stride[1] + 1
3.池化层
mxnet.gluon.nn.MaxPool2D(pool_size=(2, 2), strides=None, padding=0, layout=‘NCHW’, ceil_mode=False, **kwargs)
函数功能说明:实现一个最大池化功能
参数说明:
- pool_size(int or list/tuple of 2 ints):最大池化的窗口大小
- strides(int,list/tuple of 2 ints,or None):最大池化窗口的移动步长
- padding(int or list/tuple of 2 ints):输入矩阵的边缘填充大小设置
- layout(str,default “NCHW”):输入矩阵的数据设置
- ceil_mode(bool,default False):如果为True将使用ceil来替换floor去计算输出矩阵的shape
最大池化层输入矩阵尺寸计算公式:
out_height = floor((height + 2 * padding[0] - pool_size[0]) / strides[0] ) + 1
out_width = floor((width +2 * padding[1] - pool_size[1]) / strides[1] ) + 1
4.构建神经网络
- gluon的Sequential构建网络
利用mxnet的gluon很方便的构建一个链式网络
前向传播
查看指定层的权重参数
- nn.Block构建网络
在使用Sequential构建网络的时候很方便,但是有一个不好的地方在于,网络的前向传播是自动构建的。接下来我们介绍另一种方法,可以方便快捷构建网络同时还能自定义网络的前向传播
版权声明:
作者: freeclashnode
链接: https://www.freeclashnode.com/news/article-1159.htm
来源: FreeClashNode
文章版权归作者所有,未经允许请勿转载。
热门文章
- 9月15日|20.4M/S,Shadowrocket/V2ray/SSR/Clash免费节点订阅链接每天更新
- 10月1日|23M/S,Shadowrocket/Clash/SSR/V2ray免费节点订阅链接每天更新
- 9月20日|19.4M/S,V2ray/SSR/Shadowrocket/Clash免费节点订阅链接每天更新
- 9月19日|23M/S,Clash/SSR/Shadowrocket/V2ray免费节点订阅链接每天更新
- 9月18日|22.9M/S,Clash/Shadowrocket/V2ray/SSR免费节点订阅链接每天更新
- 9月16日|18M/S,SSR/Shadowrocket/Clash/V2ray免费节点订阅链接每天更新
- 10月2日|22.9M/S,V2ray/Shadowrocket/Clash/SSR免费节点订阅链接每天更新
- 10月3日|20.9M/S,SSR/V2ray/Clash/Shadowrocket免费节点订阅链接每天更新
- 9月17日|21.6M/S,SSR/Shadowrocket/V2ray/Clash免费节点订阅链接每天更新
- 10月5日|22.5M/S,Clash/V2ray/SSR/Shadowrocket免费节点订阅链接每天更新
最新文章
- 10月8日|18.9M/S,Clash/SSR/V2ray/Shadowrocket免费节点订阅链接每天更新
- 10月7日|21.5M/S,V2ray/Clash/Shadowrocket/SSR免费节点订阅链接每天更新
- 10月6日|19.5M/S,Shadowrocket/Clash/SSR/V2ray免费节点订阅链接每天更新
- 10月5日|22.5M/S,Clash/V2ray/SSR/Shadowrocket免费节点订阅链接每天更新
- 10月4日|22M/S,Clash/V2ray/SSR/Shadowrocket免费节点订阅链接每天更新
- 10月3日|20.9M/S,SSR/V2ray/Clash/Shadowrocket免费节点订阅链接每天更新
- 10月2日|22.9M/S,V2ray/Shadowrocket/Clash/SSR免费节点订阅链接每天更新
- 10月1日|23M/S,Shadowrocket/Clash/SSR/V2ray免费节点订阅链接每天更新
- 9月30日|18.8M/S,SSR/Clash/V2ray/Shadowrocket免费节点订阅链接每天更新
- 9月29日|20.6M/S,SSR/Shadowrocket/Clash/V2ray免费节点订阅链接每天更新