Going Deeper With Directly-Trained Larger Spiking Neural Networks

前言

  • Going Deeper With Directly-Trained Larger Spiking Neural Networks - 原文
  • 个人理解:听起来和结果上都很厉害的工作。主要解决直接训练SNN算法中的各类Normalization的问题。如STBP这类算法因为是BPTT-inspired的,也会出现梯度消失和爆炸的现象。而且由于脉冲激活不可导,为了BP算法的应用通常都要引入梯度的近似,造成梯度更新有误差,带来了不稳定因素。同时由于信息在LIF模型中传输是“输入-电位-输出”的模型,信息需要在电位上周转一次。如果timestep低且发放率(firing rate)低,“电位-输出”这一环节就会被卡住,造成信息损失;如果脉冲发放率高,神经元的输出对输入就不敏感(实际上从编码的角度来看,firing rate无论是过低还是过高,其信息表达能力都很有限,从这个角度来看理想的发放率是0.5)。论文提出了适用于直接训练SNN算法的BN改进版,并且在残差网络上跑了ImageNet数据集,总的来说算是比较solid的,要是能开放源码就更solid了。
  • 后续补:给个Vth和tau也行啊,还以为能解决硬调Vth的问题呢。

Iterative LIF

  • 迭代LIF模型,引自此文
  • STBP的基础。分析可见此处

Threshold-dependent batch normalization

  • ANN-to-SNN的方法能使用ANN中所有的训练trick,包括知名的batch normalization,但是模型在ANN-to-SNN映射时会有映射损失。基于梯度近似的BP算法(此文中为STBP)能够直接在SNN上训练,避免先训练后转换的损失,但是由于SNN结构不同,无法直接应用ANN的BN。所以这篇文章提出了适用于直接训练SNN的BN方法tdBN。

  • 公式:

    其中 $\alpha$ 是超参,$V_{th}$ 是阈值电压,其他的参数和ANN的BN意义相同,即做channel-wise的统计并取得统计量用于inference的计算,训练时则使用minibatch 的统计量, 不同的是求均值/方差的时候还要在SNN新增的时间维度上求,即对(N,H,W,T)的张量求均值等。

image-20201122170903068

  • 理论解释:

    • 关于防止梯度消失/爆炸的作用:有点长,先码着。涉及到另一篇分析grad norm的文章,理论性比较强。
    • 关于公式里多出来的scaling factor:总体是为了平衡firing rate设计的,理论细节先鸽了。

Deep Spiking Residual Network

  • 结构总体和ANN相同:ReLU替换为LIF(激活),BN替换为tdBN。区别主要是
    • 在shortcut的连接上也加了tdBN。
    • 与最后的add操作相连的tdBN中,$\alpha=1/\sqrt{2}$ 。
  • 因为tdBN的目的是平衡firing rate,所以要将fire单元前的所有输入都norm一遍。α的取值能保证相加后特征图方差符合要求。

image-20201122170903063

Experiment

  • CIFAR10上的结果很漂亮。由于加入了BN,(从表上看)训练出来的ResNet19对其他结果都是降维打击。在timesteps=6的情况下就能达到SOTA(一般timesteps调高能很暴力地涨点)。不过表中用作baseline的ANN结构ResNet19只有90.6%的ACC,我记着这网络起码也能有93%左右的,不知道这个数据是什么意思。
  • 跑了ImageNet,文章声称它是第一个直接训练SNN算法跑ImageNet的。结果也不错,而且timesteps也只有6,相比其他动辄几百上千的timesteps确实低很多(有一个用VGG跑了2500steps的,还好是ANN2SNN的方法)。
  • 脉冲数据集上测了DVS-Gesture和他们自录的DVS-CIFAR10。对于DVS-Gesture,超参设置是dt=30ms,T=40 steps。结果平均比之前的SOTA高了1到2个点,有点夸张。
  • DVS-CIFAR10的结果也是SOTA。

result

  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2019-2020 thiswinex
  • 访问人数: | 浏览次数:

请我喝奶茶~

支付宝
微信