2022-02-20 • Fixed timestep Euler solver in vanilla Julia
Contents
2022-02-20 • Fixed timestep Euler solver in vanilla Julia¶
i.e. no DifferentialEquations.jl
.
Hopefully this achieves better performance.
Setup¶
# using Pkg; Pkg.resolve()
println("start"); flush(stdout)
start
using Revise
using Distributions
using MyToolbox
using VoltageToMap
println("setup done") # feedback when running in terminal
setup done
Parameters¶
@kwdef struct PoissonInputParams
N_unconn ::Int = 100
N_exc ::Int = 5200
N_inh ::Int = N_exc ÷ 4
N_conn ::Int = N_inh + N_exc
N ::Int = N_conn + N_unconn
spike_rate::Distribution = LogNormal_with_mean(4Hz, √0.6) # (μₓ, σ)
end
const realistic_input = PoissonInputParams()
const slightly_smaller_input = PoissonInputParams(N_exc = 800)
const small_N__as_in_Python_2021 = PoissonInputParams(N_unconn = 9, N_exc = 17)
small_N__as_in_Python_2021.N
30
@kwdef struct SynapseParams
g_t0 ::Float64 = 0 * nS
τ_s ::Float64 = 7 * ms
E_exc ::Float64 = 0 * mV
E_inh ::Float64 = - 65 * mV
Δg_exc ::Float64 = 0.4 * nS
Δg_inh ::Float64 = 1.6 * nS
end
const semi_arbitrary_synaptic_params = SynapseParams();
@kwdef struct IzhNeuronParams
v_t0 ::Float64 = - 80 * mV
u_t0 ::Float64 = 0 * pA
C ::Float64 = 100 * pF
k ::Float64 = 0.7 * (nS/mV) # steepness of dv/dt's parabola
vr ::Float64 = - 60 * mV # resting v
vt ::Float64 = - 40 * mV # ~spiking thr
a ::Float64 = 0.03 / ms # reciprocal of `u`'s time constant
b ::Float64 = - 2 * nS # how strongly `(v - vr)` increases `u`
v_peak ::Float64 = 35 * mV # cutoff to define spike
v_reset ::Float64 = - 50 * mV # ..on spike. `c` in Izh.
Δu ::Float64 = 100 * pA # ..on spike. `d` in Izh. Free parameter.
end
const cortical_RS = IzhNeuronParams();
Base.@kwdef struct SimParams
sim_duration ::Float64 = 1.2 * seconds
Δt ::Float64 = 0.1 * ms
poisson_input ::PoissonInputParams = realistic_input
synapses ::SynapseParams = semi_arbitrary_synaptic_params
izh_neuron ::IzhNeuronParams = cortical_RS
Δg_multiplier ::Float64 = 1.0 # Free parameter, fiddled with until medium number of output spikes.
end;
Simulation¶
function sim(params::SimParams)
@unpack sim_duration, Δt, Δg_multiplier = params
@unpack N_unconn, N_exc, N_inh, N_conn, N, spike_rate = params.poisson_input
@unpack E_exc, E_inh, g_t0, τ_s, Δg_exc, Δg_inh = params.synapses
@unpack v_t0, u_t0, C, k, vr, vt, a, b, v_peak, v_reset, Δu = params.izh_neuron
input_neuron_IDs = idvec(conn = idvec(exc = N_exc, inh = N_inh), unconn = N_unconn)
synapse_IDs = idvec(exc = N_exc, inh = N_inh)
simulated_vars = idvec(t = nothing, v = nothing, u = nothing, g = similar(synapse_IDs))
# Connections
postsynapses = Dict{Int, Vector{Int}}() # input_neuron_ID => [synapse_IDs...]
for (n, s) in zip(input_neuron_IDs.conn, synapse_IDs)
postsynapses[n] = [s]
end
for n in input_neuron_IDs.unconn
postsynapses[n] = []
end
# Broadcast scalar parameters
Δg = similar(synapse_IDs, Float64)
Δg.exc .= Δg_multiplier * Δg_exc
Δg.inh .= Δg_multiplier * Δg_inh
E = similar(synapse_IDs, Float64)
E.exc .= E_exc
E.inh .= E_inh
# Inter-spike—interval distributions
λ = similar(input_neuron_IDs, Float64)
λ .= rand(spike_rate, length(λ))
β = 1 ./ λ
ISI_distributions = Exponential.(β)
first_input_spike_t = rand.(ISI_distributions)
upcoming_input_spikes = PriorityQueue{Int, Float64}()
for (neuron_ID, spike_t) in zip(input_neuron_IDs, first_input_spike_t)
enqueue!(upcoming_input_spikes, neuron_ID => spike_t)
end
next_input_spike_t = peek(upcoming_input_spikes).second # (`.first` is neuron ID).
# Initialize simulation vars and their derivatives
vars = similar(simulated_vars, Float64)
vars.t = zero(sim_duration)
vars.v = v_t0
vars.u = u_t0
vars.g .= g_t0
D = similar(vars)
D.t = 1
num_timesteps = round(Int, sim_duration / Δt) # Fixed timestep
v_rec = Vector{Float64}(undef, num_timesteps)
input_spike_t_rec = similar(input_neuron_IDs, Vector{Float64})
for i in eachindex(input_spike_t_rec)
input_spike_t_rec[i] = Vector{Float64}()
end
# package it all up
p = (;
vars, D, Δt, E, τ_s, Δg, params, v_rec, input_spike_t_rec,
upcoming_input_spikes, ISI_distributions, postsynapses
)
@showprogress 200ms for i in 1:num_timesteps
step!(p)
v_rec[i] = vars.v
end
return (
t = linspace(zero(sim_duration), sim_duration, num_timesteps),
v = v_rec,
input_spikes = input_spike_t_rec
)
end
sim (generic function with 1 method)
function step!(p)
@unpack vars, D, Δt, E, τ_s, Δg, input_spike_t_rec = p
@unpack upcoming_input_spikes, ISI_distributions, postsynapses = p
@unpack t, v, u, g = vars
@unpack C, k, vr, vt, a, b, v_peak, v_reset, Δu = p.params.izh_neuron
# Sum synaptic currents
I_s = zero(u)
for (gi, Ei) in zip(g, E)
I_s += gi * (v - Ei)
end
# Differential equations
D.v = (k * (v - vr) * (v - vt) - u - I_s) / C
D.u = a * (b * (v - vr) - u)
for i in eachindex(g)
D.g[i] = -g[i] / τ_s
end
# Euler integration
@. vars += D * Δt
# Izhikevich neuron spiking threshold
if vars.v ≥ v_peak
vars.v = v_reset
vars.u += Δu
end
# Input spikes
next_input_spike_t = peek(upcoming_input_spikes).second
if t ≥ next_input_spike_t
fired_neuron = dequeue!(upcoming_input_spikes)
push!(input_spike_t_rec[fired_neuron], t)
for synapse in postsynapses[fired_neuron]
g[synapse] += Δg[synapse]
end
new_spike_time = t + rand(ISI_distributions[fired_neuron])
enqueue!(upcoming_input_spikes, fired_neuron => new_spike_time)
end
end
step! (generic function with 2 methods)
println("defs done")
defs done
p = SimParams(poisson_input = small_N__as_in_Python_2021, Δg_multiplier = 7, sim_duration=1*minutes)
sim(p); # to trigger compilation
Progress: 100%|█████████████████████████████████████████| Time: 0:00:00
using Profile
Profile.clear_malloc_data()
p = SimParams(poisson_input = slightly_smaller_input, Δg_multiplier = 1, sim_duration = 1*minutes)
dump(p)
SimParams
sim_duration: Float64 60.0
Δt: Float64 0.0001
poisson_input: PoissonInputParams
N_unconn: Int64 100
N_exc: Int64 800
N_inh: Int64 200
N_conn: Int64 1000
N: Int64 1100
spike_rate: LogNormal{Float64}
μ: Float64 1.0862943611198905
σ: Float64 0.7745966692414834
synapses: SynapseParams
g_t0: Float64 0.0
τ_s: Float64 0.007
E_exc: Float64 0.0
E_inh: Float64 -0.065
Δg_exc: Float64 4.0000000000000007e-10
Δg_inh: Float64 1.6000000000000003e-9
izh_neuron: IzhNeuronParams
v_t0: Float64 -0.08
u_t0: Float64 0.0
C: Float64 1.0e-10
k: Float64 7.0e-7
vr: Float64 -0.06
vt: Float64 -0.04
a: Float64 30.0
b: Float64 -2.0e-9
v_peak: Float64 0.035
v_reset: Float64 -0.05
Δu: Float64 1.0e-10
Δg_multiplier: Float64 1.0
t, v, input_spikes = @time sim(p);
Progress: 100%|█████████████████████████████████████████| Time: 0:00:02
2.628588 seconds (2.11 M allocations: 68.386 MiB, 0.76% gc time)
num_spikes = length.(input_spikes)
ComponentVector{Int64}(conn = (exc = [0, 0, 2, 0, 0, 0, 0, 1, 0, 2 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], inh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1 … 0, 1, 0, 0, 1, 2, 0, 1, 1, 0]), unconn = [0, 1, 0, 2, 0, 0, 0, 0, 0, 1 … 0, 0, 0, 0, 0, 2, 0, 0, 0, 1])
Plot¶
# import PyPlot
# using Sciplotlib
""" tzoom = [200ms, 600ms] e.g. """
function plotsig(t, sig, tzoom = nothing)
isnothing(tzoom) && (tzoom = t[[1, end]])
izoom = first(tzoom) .≤ t .≤ last(tzoom)
plot(t[izoom], sig[izoom]; clip_on=false)
end;
# plotsig(t, v / mV);
# plotsig(t, v / mV, [200ms,400ms]);