Sharpness Aware Surrogate Training for Spiking Neural Networks

Spiking neural networks are trained with smooth approximations but deployed with discontinuous spikes. SAST narrows that gap by seeking flat loss minima. Here is the full theory, the intuition behind it, and a concrete worked example.

TLDR: Apply sharpness-aware minimization to a fully surrogate-forward SNN. The flat minima it finds are robust to the surrogate-to-hard spike substitution at deployment. On N-MNIST the surrogate-to-hard transfer gap falls from 30.3 pp to 2.5 pp. Preprint: arXiv:2603.18039.

65.7% → 94.7%: N-MNIST hard-spike accuracy at $\rho = 0.30$

31.8% → 63.3%: DVS Gesture hard-spike accuracy at $\rho = 0.40$

30.3 pp → 2.5 pp: N-MNIST surrogate-to-hard transfer gap after SAST

2.08×: per-epoch training-time overhead on N-MNIST

Spiking neural networks

The LIF neuron

The workhorse model is the Leaky Integrate-and-Fire (LIF) neuron. At each discrete time step $t$ a neuron accumulates charge into its membrane potential $u[t]$, leaks a fraction of it away, and fires a binary spike $s[t] \in \{0,1\}$ whenever that potential crosses a threshold $\vartheta$. After firing, the potential resets.

Concretely, for layer $\ell$ with weight matrix $W^\ell$ and input spikes $s^{\ell-1}[t]$ from the layer below:

$$u^\ell[t] = \tau \, u^\ell[t-1]\bigl(1 - s^\ell[t-1]\bigr) + W^\ell s^{\ell-1}[t]$$
$$s^\ell[t] = H\!\bigl(u^\ell[t] - \vartheta\bigr)$$

where $\tau \in (0,1)$ is the membrane decay constant and $H(\cdot)$ is the Heaviside step function. The factor $(1 - s^\ell[t-1])$ implements the reset: once the neuron fires, the carried-over potential is zeroed before new input is integrated.
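The two update equations translate almost line for line into a simulation loop. Below is a minimal NumPy sketch for a single layer, assuming input spikes are given as a `(T, n_in)` array and taking $H(0) = 1$; the names `lif_layer`, `heaviside`, and `spike_fn` are our own, not the paper's. The `spike_fn` argument is there so the same weights can be run with hard spikes or with a smooth surrogate.

```python
import numpy as np


def heaviside(x):
    # Hard spike: 1.0 where x >= 0, else 0.0 (this sketch takes H(0) = 1).
    return (x >= 0).astype(float)


def lif_layer(W, s_in, tau=0.9, theta=1.0, spike_fn=heaviside):
    """Simulate one LIF layer over T steps.

    W: (n_out, n_in) weight matrix; s_in: (T, n_in) input spikes.
    Returns membrane potentials u and output spikes s, each of shape (T, n_out).
    """
    T, n_out = s_in.shape[0], W.shape[0]
    u = np.zeros((T, n_out))
    s = np.zeros((T, n_out))
    u_prev = np.zeros(n_out)
    s_prev = np.zeros(n_out)
    for t in range(T):
        # Leak the previous potential, zero it where the neuron just fired, add new input.
        u[t] = tau * u_prev * (1.0 - s_prev) + W @ s_in[t]
        s[t] = spike_fn(u[t] - theta)
        u_prev, s_prev = u[t], s[t]
    return u, s


# One neuron with weight 0.6, receiving an input spike on every one of 4 steps.
u, s = lif_layer(np.array([[0.6]]), np.ones((4, 1)))
print(u[:, 0], s[:, 0])
```

With $\tau = 0.9$ and $\vartheta = 1$, the potential alternates 0.6, 1.14, 0.6, 1.14 and the neuron fires on every second step, because each spike zeroes the carried-over potential.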

Surrogate gradients

Training requires backpropagating through $H$, which has zero gradient almost everywhere and an undefined gradient at zero. The standard fix is to replace $H'(x)$ with a smooth surrogate derivative $\sigma'_\beta(x)$ during the backward pass while keeping the forward spike binary. A common choice is the arctangent surrogate, whose steepness is set by $\beta > 0$:

$$\sigma_\beta(x) = \frac{1}{\pi}\arctan(\beta x) + \frac{1}{2}, \qquad \sigma'_\beta(x) = \frac{\beta}{\pi(1+\beta^2 x^2)}$$
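A quick way to sanity-check the surrogate pair is to compare the closed-form derivative with a central finite difference. The sketch assumes $\beta = 2$ purely for illustration; the paper's setting may differ.

```python
import numpy as np


def surrogate(x, beta=2.0):
    # sigma_beta(x) = arctan(beta * x) / pi + 1/2: a smooth step from 0 to 1.
    return np.arctan(beta * x) / np.pi + 0.5


def surrogate_grad(x, beta=2.0):
    # sigma'_beta(x) = beta / (pi * (1 + beta^2 x^2)), peaking at beta / pi when x = 0.
    return beta / (np.pi * (1.0 + (beta * x) ** 2))


# Compare the closed-form derivative with a central finite difference.
xs = np.linspace(-2.0, 2.0, 9)
h = 1e-6
fd = (surrogate(xs + h) - surrogate(xs - h)) / (2 * h)
max_err = np.max(np.abs(fd - surrogate_grad(xs)))
print(max_err)  # should be close to zero
```

Larger $\beta$ makes $\sigma_\beta$ a closer approximation to $H$ but raises the peak derivative $\sigma'_\beta(0) = \beta/\pi$.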

SAST takes a different route. Rather than keeping the forward pass hard and smoothing only the backward pass, it trains a surrogate-forward network in which the spike function itself is replaced by $\sigma_\beta$ everywhere. The gradient is therefore exact for the model being trained: there is no straight-through approximation anywhere in the computation graph.
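To see what "exact" means here, compare one neuron under the two schemes. This is a toy illustration, assuming a squared-error loss, $\beta = 2$, $\vartheta = 1$, and a membrane input of $w x$ with no temporal dynamics. In surrogate-forward mode the chain-rule gradient matches a finite difference of the loss; with a hard forward pass and a surrogate backward pass, the gradient used for training is nonzero even though the hard-forward loss is locally flat.

```python
import numpy as np

beta, theta = 2.0, 1.0  # illustrative values


def sigma(x):
    return np.arctan(beta * x) / np.pi + 0.5


def dsigma(x):
    return beta / (np.pi * (1.0 + (beta * x) ** 2))


def heaviside(x):
    return float(x >= 0)


def loss(w, x, target, spike_fn):
    # Squared error of a single neuron's (surrogate or hard) output.
    return (spike_fn(w * x - theta) - target) ** 2


w, x, target = 0.8, 1.0, 1.0
h = 1e-6

# Surrogate forward: the chain-rule gradient is the true gradient of this model.
g_surr = 2 * (sigma(w * x - theta) - target) * dsigma(w * x - theta) * x
fd_surr = (loss(w + h, x, target, sigma) - loss(w - h, x, target, sigma)) / (2 * h)

# Hard forward, surrogate backward: the training signal is nonzero ...
g_mixed = 2 * (heaviside(w * x - theta) - target) * dsigma(w * x - theta) * x
# ... but the hard-forward loss does not change under a small change of w.
fd_hard = (loss(w + h, x, target, heaviside) - loss(w - h, x, target, heaviside)) / (2 * h)

print(g_surr, fd_surr)   # these agree
print(g_mixed, fd_hard)  # nonzero vs 0.0
```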

The transfer gap

After surrogate training, the learned weights are evaluated under two regimes. In surrogate-forward mode, the smooth $\sigma_\beta$ is used at inference; this is what training optimised. In hard-spike mode, the smooth activations are replaced with $H$; this is what a real neuromorphic deployment requires.

The transfer gap is simply:

$$\Delta = \text{Acc}_{\text{surrogate}} - \text{Acc}_{\text{hard}}$$

For a standard baseline on N-MNIST, $\Delta = 30.3$ percentage points (pp). SAST reduces this to $2.5$ pp. Understanding why requires thinking about the shape of the loss landscape.

Why flat minima help

A concrete example

Consider a single hidden neuron whose membrane potential at some time step is $u = 0.999$ and whose threshold is $\vartheta = 1.0$.

Under surrogate training: $\sigma_\beta(u - \vartheta) = \sigma_\beta(-0.001) \approx 0.499$. The neuron contributes a near-half activation, and the network can learn to rely on this precisely calibrated fractional signal.

Under hard-spike evaluation: $H(u - \vartheta) = H(-0.001) = 0$. The neuron is completely silent. If the downstream layers depended on that roughly 0.5 activation, their inputs change dramatically.
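The two numbers can be reproduced in a few lines. This assumes $\beta = 2$, an illustrative value rather than one taken from the paper.

```python
import numpy as np

beta = 2.0               # illustrative surrogate steepness
u, theta = 0.999, 1.0

soft = np.arctan(beta * (u - theta)) / np.pi + 0.5   # surrogate-forward activation
hard = float(u - theta >= 0)                          # hard spike: H(u - theta)
print(soft, hard, soft - hard)  # soft is just under 0.5, hard is 0.0
```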

The problem scales: a network with many neurons hovering near threshold accumulates these mismatches across layers and time steps. A small shift in the input distribution can flip dozens of spike decisions simultaneously.

The solution is to make the network stop relying on neurons that sit on the fence. If every neuron whose surrogate output is near 0.5 were nudged away from threshold, so that $|u - \vartheta|$ is large compared with the surrogate width $1/\beta$ (that is, $\beta\,|u - \vartheta| \gg 1$), then replacing $\sigma_\beta$ with $H$ would barely change anything, because the two functions agree far from zero.
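How far is "far from zero"? For $x \neq 0$ the gap between the two functions is $|\sigma_\beta(x) - H(x)| = \tfrac{1}{2} - \tfrac{1}{\pi}\arctan(\beta|x|)$, which depends only on $\beta|x|$. The sketch below evaluates it at a few points; the function name `mismatch` is ours.

```python
import numpy as np


def mismatch(z):
    # |sigma_beta(x) - H(x)| written in terms of z = beta * |x| (valid for x != 0).
    # For large z this behaves like 1 / (pi * z).
    return 0.5 - np.arctan(z) / np.pi


for z in [0.002, 1.0, 10.0, 100.0]:
    print(f"beta*|u - theta| = {z:>6}: mismatch = {mismatch(z):.4f}")
```

The mismatch drops from about 0.50 at $\beta|u - \vartheta| = 0.002$ (the example above with $\beta = 2$) to 0.25 at 1, 0.032 at 10, and 0.0032 at 100. Because the arctangent tails decay only like $1/(\pi\beta|x|)$, a neuron has to sit well away from threshold, relative to $1/\beta$, before the hard spike faithfully reproduces the surrogate.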

SAST achieves this indirectly, without explicitly penalising fence-sitter neurons. It does so by seeking flat regions of the parameter space.

The loss-landscape view

Imagine the surrogate training loss as a mountain range and the parameters $w$ as a location in that landscape. Standard gradient descent finds a valley (a local minimum) but says nothing about how steep the valley walls are. A narrow, sharp valley means that tiny perturbations to $w$ send the loss shooting upward. A wide, flat valley means the loss stays low across an entire neighbourhood.

Switching from surrogate to hard spikes acts like such a perturbation. Every neuron near threshold has its activation discretised, a jump that cannot be predicted from the gradient alone. If the network sits in a flat valley, this perturbation is more likely to stay within the low-loss region; if it sits in a sharp valley, it is likely to escape it.

Sharpness-aware minimization is designed to find flat valleys. That is the geometric intuition for why SAST reduces the transfer gap.
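The valley picture can be made concrete with two one-dimensional quadratics that share the same minimum but differ in curvature. This is a hypothetical toy landscape, not the SNN loss; it only shows how curvature controls the damage done by a fixed-size perturbation.

```python
# Two hypothetical 1-D valleys, L(w) = 0.5 * curvature * w**2, both with minimum 0 at w = 0.
def valley_loss(w, curvature):
    return 0.5 * curvature * w ** 2


rho = 0.1  # size of the parameter perturbation
worst_case = {}
for name, curvature in [("sharp", 100.0), ("flat", 1.0)]:
    # In one dimension the worst perturbation of size rho is either +rho or -rho.
    worst_case[name] = max(valley_loss(w, curvature) for w in (-rho, rho))

print(worst_case)
```

The sharp valley's worst-case loss is about 0.5 and the flat valley's about 0.005: the same perturbation costs 100 times more, exactly the curvature ratio.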

The SAST method

SAM recap

Standard training minimises the empirical risk $\mathcal{L}(w)$. Sharpness-Aware Minimization (Foret et al., 2021) instead minimises the worst-case loss in a ball around $w$:

$$\min_w \; \max_{\|\varepsilon\|_2 \leq \rho} \mathcal{L}(w + \varepsilon)$$

The inner maximisation is intractable in general but admits a first-order approximation. Expanding around $w$ gives $\mathcal{L}(w + \varepsilon) \approx \mathcal{L}(w) + \varepsilon^\top \nabla_w \mathcal{L}(w)$, and this linear term is maximised over the ball by pointing $\varepsilon$ along the gradient:

$$\varepsilon^*(w) \approx \rho \,\frac{\nabla_w \mathcal{L}(w)}{\|\nabla_w \mathcal{L}(w)\|_2}$$

The outer update then follows the gradient at the perturbed point $w + \varepsilon^*(w)$, not at $w$ itself. This pushes $w$ toward regions where the loss is simultaneously low and locally flat.
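The quality of the first-order approximation can be checked on a toy problem where the worst case is easy to find by brute force. The sketch assumes an ill-conditioned quadratic $\mathcal{L}(w) = \tfrac{1}{2} w^\top A w$ with $A = \mathrm{diag}(10, 1)$, chosen only for illustration.

```python
import numpy as np

# Hypothetical ill-conditioned quadratic L(w) = 0.5 * w^T A w (not the SNN loss).
A = np.diag([10.0, 1.0])


def loss(w):
    return 0.5 * w @ A @ w


def grad(w):
    return A @ w


w = np.array([1.0, 1.0])
rho = 0.05

# First-order SAM perturbation: a step of length rho along the normalised gradient.
g = grad(w)
eps_first_order = rho * g / np.linalg.norm(g)
loss_first_order = loss(w + eps_first_order)

# Brute-force worst case over 3600 directions on the circle ||eps||_2 = rho.
angles = np.linspace(0.0, 2 * np.pi, 3600, endpoint=False)
candidates = rho * np.stack([np.cos(angles), np.sin(angles)], axis=1)
loss_brute = max(loss(w + e) for e in candidates)

# The outer SAM update descends along the gradient taken at the perturbed point.
sam_grad = grad(w + eps_first_order)
print(loss(w), loss_first_order, loss_brute)
```

Here the first-order perturbation raises the loss from 5.5 to about 6.015, essentially matching the brute-force worst case on the circle.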

The SAST update

SAST applies SAM to the surrogate-forward loss $\mathcal{L}_{\text{surr}}(w)$, the cross-entropy of the SNN when evaluated with $\sigma_\beta$ everywhere. Each training step proceeds in two phases:

Phase 1: compute the perturbation. Sample a minibatch $\mathcal{B}$, compute the surrogate gradient, and form:

$$\hat{\varepsilon}(w) = \rho \,\frac{g}{\|g\|_2 + \delta}, \qquad g = \nabla_w \mathcal{L}_{\text{surr}}(w;\,\mathcal{B})$$

The small constant $\delta > 0$ prevents division by zero.

Phase 2: update at the perturbed point. Sample an independent minibatch $\mathcal{B}'$ and update:

$$w \;\leftarrow\; w \;-\; \eta \,\nabla_w \mathcal{L}_{\text{surr}}\!\bigl(w + \hat{\varepsilon}(w);\;\mathcal{B}'\bigr)$$

where $\eta$ is the learning rate and the gradient is evaluated at the perturbed weights $w + \hat{\varepsilon}(w)$, with $\hat{\varepsilon}$ held fixed.

The independence of $\mathcal{B}$ and $\mathcal{B}'$ is not just an implementation detail: it is the condition required by the convergence theorem (Theorem 3, summarised under Convergence below). In pseudocode:

for each step:
  sample two independent minibatches B and B'
  g   = grad surrogate_loss(w; B)              # phase 1 batch
  eps = rho * g / (norm(g) + delta)            # perturbation of norm just under rho
  w  -= lr * grad surrogate_loss(w + eps; B')  # phase 2 batch
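The pseudocode can be turned into a runnable loop. The sketch below is a stand-in, not the paper's implementation: a single linear layer with an arctangent surrogate output and a squared-error loss replaces the surrogate-forward SNN, and the data are synthetic. What it does reproduce is the SAST step structure: a normalised ascent step of radius `rho` computed on one minibatch, followed by a descent step using the gradient at the perturbed weights on an independently sampled minibatch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic, linearly separable toy data (illustrative only, not from the paper).
X_all = rng.normal(size=(512, 5))
y_all = (X_all @ rng.normal(size=5) > 0).astype(float)

beta = 2.0  # illustrative surrogate steepness


def sigma(z):
    return np.arctan(beta * z) / np.pi + 0.5


def dsigma(z):
    return beta / (np.pi * (1.0 + (beta * z) ** 2))


def surrogate_loss_and_grad(w, X, y):
    # Mean squared error of the surrogate-forward output and its exact gradient in w.
    z = X @ w
    r = sigma(z) - y
    loss = np.mean(r ** 2)
    grad = 2.0 * X.T @ (r * dsigma(z)) / len(y)
    return loss, grad


def sample_batch(batch_size=32):
    idx = rng.choice(len(y_all), size=batch_size, replace=False)
    return X_all[idx], y_all[idx]


w = np.zeros(5)
lr, rho, delta = 0.5, 0.05, 1e-12

for step in range(300):
    # Phase 1: normalised ascent direction from minibatch B.
    _, g = surrogate_loss_and_grad(w, *sample_batch())
    eps = rho * g / (np.linalg.norm(g) + delta)
    # Phase 2: gradient at w + eps on an independently sampled minibatch B'.
    _, g_sam = surrogate_loss_and_grad(w + eps, *sample_batch())
    w = w - lr * g_sam

final_loss, _ = surrogate_loss_and_grad(w, X_all, y_all)
acc_hard = np.mean((X_all @ w >= 0) == (y_all == 1))  # hard threshold in place of sigma
print(final_loss, acc_hard)
```

The `acc_hard` line swaps `sigma` for a hard threshold at evaluation time, mirroring the hard-spike evaluation described earlier. Each step calls `surrogate_loss_and_grad` twice, which is where the extra training cost comes from.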

Theory

The theoretical results are scoped to the smooth auxiliary model — the surrogate-forward SNN that SAST actually optimises. They do not claim to govern the hard-spike network directly; that connection is empirical.

State stability

The first result establishes that the surrogate state sequence is well-behaved. Under the assumption that the spectral norm of each weight matrix is bounded:

$$\|W^\ell\|_2 \leq \gamma_\ell, \qquad \tau \gamma_\ell < 1 \quad \forall \ell$$

the surrogate state map is a contraction in the $\ell_2$ norm across time steps. Concretely, for two input sequences $x$ and $x'$ that differ only at time step $t$, the states at the final step $T$ satisfy:

$$\|u^\ell[T](x) - u^\ell[T](x')\|_2 \;\leq\; \kappa^{T-t}\,\|x - x'\|_2$$

where $\kappa = \tau\gamma < 1$. This fading-memory property means early perturbations wash out over time, a necessary precondition for the Lipschitz and smoothness results that follow.

Lipschitz bound and smoothness

Under the same spectral-norm bounds and a bound on the surrogate derivative $\|\sigma'_\beta\|_\infty \leq C_\sigma$ (for the arctangent surrogate, $C_\sigma = \beta/\pi$), the output of a depth-$D$ network is Lipschitz in its input with constant:

$$L \;=\; C_\sigma \prod_{\ell=1}^{D} \gamma_\ell \cdot \frac{1}{1-\kappa}$$

The empirical surrogate loss is then $\beta$-smooth in the parameters (here $\beta$ denotes the smoothness constant, not the surrogate steepness in $\sigma_\beta$), meaning its gradient is $\beta$-Lipschitz; for a twice-differentiable loss this is equivalent to a bounded Hessian:

$$\|\nabla^2_w \mathcal{L}_{\text{surr}}(w)\|_2 \;\leq\; \beta$$

Smoothness is the key property SAM theory requires. Without it, the first-order perturbation $\hat{\varepsilon}(w)$ would not reliably approximate the true worst-case direction. Because the surrogate-forward model is smooth by construction (every activation is differentiable), this holds cleanly. It would not hold for the hard-spike model.
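The Lipschitz constant above is straightforward to evaluate once the spectral norms are known. The sketch uses hypothetical random weights rescaled to spectral norm 0.8, with $\tau = 0.9$ and $\beta = 2$; it also assumes $\kappa$ is formed from the largest $\gamma_\ell$, since the text leaves the single $\gamma$ in $\kappa = \tau\gamma$ unspecified. For the arctangent surrogate, $C_\sigma = \sup_x \sigma'_\beta(x) = \beta/\pi$.

```python
import numpy as np

rng = np.random.default_rng(1)
beta, tau = 2.0, 0.9  # illustrative values

# Hypothetical 3-layer stack, each matrix rescaled to spectral norm 0.8.
weights = [rng.normal(size=(16, 16)) for _ in range(3)]
weights = [0.8 * W / np.linalg.norm(W, 2) for W in weights]

gammas = [np.linalg.norm(W, 2) for W in weights]  # gamma_l = ||W^l||_2 (largest singular value)
assert all(tau * g < 1 for g in gammas)           # the tau * gamma_l < 1 assumption

C_sigma = beta / np.pi          # sup of sigma'_beta, attained at x = 0
kappa = tau * max(gammas)       # assumption: kappa uses the largest gamma_l
L = C_sigma * np.prod(gammas) / (1.0 - kappa)
print(L)
```

With three layers this gives $L = (2/\pi)(0.8^3)/(1 - 0.72) \approx 1.16$.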

Convergence

Given $\beta$-smoothness and independent minibatches, SAST satisfies a standard nonconvex convergence bound. After $T$ optimisation steps (here $T$ counts training iterations, not simulation time steps) with learning rate $\eta = O(1/\sqrt{T})$:

$$\frac{1}{T}\sum_{t=1}^{T}\mathbb{E}\!\left[\|\nabla_w \mathcal{L}_{\text{surr}}(w_t)\|_2^2\right] \;\leq\; O\!\left(\frac{1}{\sqrt{T}}\right)$$

This guarantees convergence to a first-order stationary point of the surrogate loss at the standard $O(1/\sqrt{T})$ rate for nonconvex stochastic optimisation. The independent-minibatch requirement is what keeps the gradient estimate at the perturbed point unbiased. If the same batch were used for both phases, the estimate would be biased and this proof would no longer apply.

The paper also includes a local mechanism proposition: under a bound on the per-sample Jacobian $\|J_w s_i\|_F \leq C_J$, the input gradient norm satisfies:

$$\|\nabla_x \mathcal{L}_{\text{surr}}\|_2 \;\leq\; L \cdot C_J \cdot \|\nabla_s \mathcal{L}\|_2$$

Smaller parameter gradients (as SAM encourages) imply smaller $C_J$, which in turn implies smaller input sensitivity. This is the local theoretical connection between sharpness and the transfer gap, but it is presented as a proposition under local conditions, not a universal theorem.

Results

The empirical story is about the transfer gap, not raw surrogate accuracy. Surrogate accuracy changes modestly; hard-spike accuracy changes dramatically.

Transfer results summary

| Dataset | Best $\rho$ | Baseline surrogate | Best surrogate | Baseline hard | Best hard | Transfer gap |
|---|---|---|---|---|---|---|
| N-MNIST | 0.30 | 96.1% | 97.2% | 65.7% | 94.7% | 30.3 pp → 2.5 pp |
| DVS Gesture | 0.40 | 75.0% | 76.9% | 31.8% | 63.3% | 43.2 pp → 13.6 pp |
Numbers from the main results table in arXiv:2603.18039.

On N-MNIST, surrogate accuracy improves from 96.1% to 97.2% — a 1.1 pp gain. But hard-spike accuracy improves from 65.7% to 94.7% — a 29 pp gain. The transfer gap collapses from 30.3 pp to 2.5 pp. On DVS Gesture the pattern repeats: the surrogate gain is modest (1.9 pp) while the hard-spike gain is large (31.5 pp).

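The overhead is real. Each SAST step requires two forward-backward passes, one to compute $\hat{\varepsilon}$ and one to compute the update gradient, giving a roughly 2× per-epoch training cost (2.08× on N-MNIST). The cost is confined to training: the deployed hard-spike network has the same architecture and inference cost as the baseline. Whether the extra training time is acceptable depends on the context.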

Limitations

Compute-matched baselines. A 2× compute budget could alternatively be spent on longer standard training or larger batch sizes. Whether SAST still wins under a matched compute budget is an open question.

Post-hoc threshold calibration. Some of the hard-spike gap may be recoverable by simply re-tuning $\vartheta$ after surrogate training. This is a cheap alternative that has not been isolated as a control.

Scope of benchmarks. Results are reported on N-MNIST and DVS Gesture — both relatively small neuromorphic datasets. Whether the transfer-gap reduction generalises to larger, noisier, or temporally longer tasks is unknown.

Theory scope. The convergence and smoothness guarantees apply to the surrogate-forward auxiliary model. The connection between flat surrogate minima and hard-spike robustness is empirically demonstrated and locally motivated, but not proved in full generality.

References

Nicholson, M. (2026). Sharpness Aware Surrogate Training for Spiking Neural Networks. arXiv:2603.18039. arxiv.org/abs/2603.18039

Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2021). Sharpness-Aware Minimization for Efficiently Improving Generalization. ICLR 2021. arXiv:2010.01412.

Neftci, E. O., Mostafa, H., & Zenke, F. (2019). Surrogate Gradient Learning in Spiking Neural Networks. IEEE Signal Processing Magazine, 36(6), 51–63.

Mahowald, M., & Douglas, R. (1991). A silicon neuron. Nature, 354, 515–518.

Zenke, F., & Ganguli, S. (2018). SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks. Neural Computation, 30(6), 1514–1541.