Parallel Training in Spiking Neural Networks
About
The bio-inspired integrate-fire-reset mechanism of spiking neurons is the foundation of efficient processing in Spiking Neural Networks (SNNs). Recent progress in large models demands that spiking neurons support highly parallel computation to scale efficiently on modern GPUs. This work proposes a novel functional perspective that provides general guidance for designing parallel spiking neurons. We argue that the reset mechanism, which induces complex temporal dependencies and hinders parallel training, should be removed. However, any such modification should satisfy two principles: 1) preserving the functions of reset as a core biological mechanism; and 2) enabling parallel training without sacrificing the serial inference ability of spiking neurons, which underpins their efficiency at test time. To this end, we identify the functions of the reset mechanism and analyze how to reconcile parallel training with serial inference, upon which we propose a dynamic-decay spiking neuron. We evaluate our method comprehensively in terms of: 1) Training efficiency and extrapolation capability: on 16k-length sequences, we achieve a 25.6x training speedup over the pioneering parallel spiking neuron, and models trained on 2k-length sequences can stably perform inference on sequences as long as 30k. 2) Generality: we demonstrate consistent effectiveness across five task categories (image classification, neuromorphic event processing, time-series forecasting, language modeling, and reinforcement learning), three network architectures (spiking CNN/Transformer/SSMs), and two spike activation modes (spike/integer activation). 3) Energy consumption: the spike firing rate of our neuron is lower than that of vanilla and existing parallel spiking neurons.
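To make the core idea concrete, the sketch below shows why removing the reset enables parallel training while preserving serial inference. Without reset, the membrane potential is a decay-weighted sum of past inputs, so all time steps can be computed at once; step-by-step inference remains available and produces identical spikes. This is a minimal illustration with a fixed decay, not the paper's dynamic-decay neuron; the function names and the threshold value are illustrative assumptions.

```python
import numpy as np

def serial_no_reset_lif(inputs, decay, v_th=1.0):
    # Serial inference: step-by-step membrane update with no reset.
    # V[t] = decay * V[t-1] + I[t]; spike when V[t] >= v_th.
    v, spikes = 0.0, []
    for x in inputs:
        v = decay * v + x
        spikes.append(float(v >= v_th))
    return np.array(spikes)

def parallel_no_reset_lif(inputs, decay, v_th=1.0):
    # Parallel training view: unrolling the recurrence gives
    # V[t] = sum_{k<=t} decay^(t-k) * I[k], a decay-weighted
    # cumulative sum computable for all t simultaneously.
    T = len(inputs)
    t = np.arange(T)
    # Lower-triangular kernel K[t, k] = decay^(t-k) for k <= t.
    K = np.tril(decay ** (t[:, None] - t[None, :]))
    v = K @ inputs
    return (v >= v_th).astype(float)

x = np.array([0.6, 0.2, 0.9, 0.1])
# The two formulations produce identical spike trains.
assert np.allclose(serial_no_reset_lif(x, 0.5),
                   parallel_no_reset_lif(x, 0.5))
```

With a reset, the weight of past inputs depends on when spikes occurred, so no such input-independent kernel exists; removing the reset is what makes the closed-form parallel computation possible.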
Related benchmarks
| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Language Modeling | WikiText-103 (test) | Perplexity | 28.5 | 524 |
| Classification | CIFAR10-DVS | Accuracy | 85.3 | 133 |
| Sequential Image Classification | Sequential CIFAR10 | Accuracy | 90.1 | 48 |
| Time Series Forecasting | METR-LA | Avg R2 | 71.1 | 39 |
| Time Series Forecasting | PEMS-BAY | R2 (Horizon 6) | 0.883 | 19 |
| Time Series Forecasting | solar | R2 (6h) | 0.964 | 19 |
| Reinforcement Learning | Walker2d v4 | Average Return | 4.44e+6 | 13 |
| Reinforcement Learning | Hopper v4 | Average Return | 3.57e+5 | 13 |
| Reinforcement Learning | IDP v4 | Average Return | 9.35e+4 | 8 |
| Sequential Image Classification | S-CIFAR100 | Accuracy (%) | 64.7 | 7 |