2023-01-19 • [input] to ‘Fit a line’

A distillation of the notebook "2022-10-24 N-to-1 with lognormal inputs",
for use in "2023-01-19 Fit a line", so that that notebook can remain concise.

Imports

#
@showtime using Revise
@showtime using MyToolbox
@showtime using SpikeWorks
@showtime using Sciplotlib
@showtime using VoltoMapSim
WARNING: using MyToolbox.@withfb in module Main conflicts with an existing identifier.
using MyToolbox: 3.606175 seconds (3.39 M allocations: 212.212 MiB, 4.14% gc time, 32.11% compilation time: 60% of which was recompilation)
using SpikeWorks: 2.186514 seconds (2.65 M allocations: 159.204 MiB, 2.88% gc time, 32.05% compilation time: 85% of which was recompilation)
using Sciplotlib: 15.150923 seconds (13.26 M allocations: 747.952 MiB, 4.13% gc time, 58.14% compilation time: 63% of which was recompilation)
[ Info: Precompiling VoltoMapSim [f713100b-c48c-421a-b480-5fcb4c589a9e]
using VoltoMapSim: 30.190371 seconds (9.76 M allocations: 622.510 MiB, 0.84% gc time, 2.84% compilation time: 30% of which was recompilation)

Start

Neuron-model parameters

# Parameters of the conductance-based Izhikevich neuron and its synapses.
# NOTE(review): the Izhikevich values appear to match Izhikevich's (2007)
# "regular spiking" cortical cell — confirm the intended source.
# `@typed` is a project macro; presumably it fixes the types of these globals so
# that functions reading them (e.g. `f!`) stay type-stable — TODO confirm.
@typed begin
    # Izhikevich params
    C  =  100    * pF        # Cell capacitance
    k  =    0.7  * (nS/mV)   # Steepness of parabola in v̇(v)
    vₗ = - 60    * mV        # Resting ('leak') membrane potential
    vₜ = - 40    * mV        # Spiking threshold (when no syn. & adaptation currents)
    a  =    0.03 / ms        # Reciprocal of time constant of adaptation current `u`
    b  = -  2    * nS        # (v-vₗ)→u coupling strength (negative: see `Dₜ.u` in `f!`)
    vₛ =   35    * mV        # Spike cutoff (defines spike time)
    vᵣ = - 50    * mV        # Reset voltage after spike
    Δu =  100    * pA        # Adaptation current inflow on self-spike
    # Conductance-based synapses
    Eₑ =   0 * mV            # Reversal potential at excitatory synapses
    Eᵢ = -80 * mV            # Reversal potential at inhibitory synapses
    τ  =   7 * ms            # Time constant for synaptic conductances' decay (same for exc. & inh.)
end;

Simulated variables and their initial values

# Initial state of the simulated variables.
# Field order matters: `f!` destructures its state positionally as (v, u, gₑ, gᵢ);
# presumably `NeuronModel` keeps this order — TODO confirm.
x₀ = (;
    v  = vᵣ,          # membrane potential: start at the post-spike reset voltage
    u  = 0 * pA,      # adaptation current
    gₑ = 0 * nS,      # total conductance of all excitatory synapses
    gᵢ = 0 * nS,      # total conductance of all inhibitory synapses
);

Differential equations:

Calculate the time derivatives of the simulated variables
(and store them in-place, in `Dₜ`).

# Right-hand side of the neuron's ODE system: writes dx/dt for each state variable
# into the fields of `Dₜ`.
function f!(Dₜ, vars)
    v, u, gₑ, gᵢ = vars

    # Conductance-based synaptic current, split into its two components
    I_exc = gₑ * (v - Eₑ)
    I_inh = gᵢ * (v - Eᵢ)
    Iₛ = I_exc + I_inh

    # Izhikevich 2D system: quadratic membrane equation + linear adaptation variable
    Dₜ.v = (k * (v - vₗ) * (v - vₜ) - u - Iₛ) / C
    Dₜ.u = a * (b * (v - vₗ) - u)

    # Both synaptic conductances decay exponentially, with the same time constant
    Dₜ.gₑ = -gₑ / τ
    Dₜ.gᵢ = -gᵢ / τ
end;

Spike discontinuity

# Spike condition: the membrane potential has reached the spike cutoff `vₛ`.
# (The comparison operator had been lost in a character-encoding mishap, leaving
# `vars.v  vₛ`, which is a syntax error. An ASCII `>`/`>=` would have survived,
# so the lost character was presumably the Unicode `≥`.)
has_spiked(vars) = (vars.v ≥ vₛ)

# Discontinuity applied when the neuron itself spikes:
# reset the membrane potential and increment the adaptation current.
on_self_spike!(vars) = begin
    vars.v = vᵣ
    vars.u = vars.u + Δu
end;

Conductance-based Izhikevich neuron

coba_izh_neuron = NeuronModel(
    x₀,              # initial state
    f!;              # time derivatives of the state
    has_spiked,      # spike-detection predicate
    on_self_spike!,  # reset applied on each self-spike
);

More parameters, and input spikers

using SpikeWorks.Units
using SpikeWorks: LogNormal
# Simulation time step and total simulated duration.
# The cell output `600` shows that `10minutes` evaluates to 600, so times are
# presumably expressed in seconds (SpikeWorks.Units).
# `@typed`: project macro; presumably fixes these globals' types — TODO confirm.
@typed begin
    Δt = 0.1ms
    sim_duration = 10minutes
end
600

Firing rates λ for the Poisson inputs

# Distribution of the input firing rates λ: log-normal with a median of 4 Hz and a
# multiplicative spread factor g = 2 (i.e. μ = log 4, σ = log 2; see cell output).
fr_distr = LogNormal(; median = 4Hz, g = 2)
Distributions.LogNormal{Float64}(μ=1.39, σ=0.693)
# Type of an input neuron: excitatory or inhibitory.
@enum NeuronType begin
    exc
    inh
end
"""
    input(; N = 100, EIratio = 4//1, scaling = N)

Generate `N` Poisson input spike trains for an N-to-1 simulation, and the synaptic
response to their spikes.

Each input's firing rate λ is drawn from `fr_distr`, and its spike train spans
`sim_duration`. Inputs with IDs `1:Nₑ` are excitatory; the rest are inhibitory.
`Nₑ` is determined by `EIMix(N, EIratio)`. An excitatory spike increments `gₑ` by
`60nS / scaling`, and an inhibitory spike increments `gᵢ` by `EIratio` times that.

Returns a NamedTuple `(; firing_rates, inputs, on_spike_arrival!, Nₑ)`, where
`firing_rates` are the drawn rates λ (not the realised rates).
"""
input(;
    N = 100,
    EIratio = 4//1,
    scaling = N,
) = begin
    firing_rates = rand(fr_distr, N)
    input_IDs = 1:N
    inputs = [
        Nto1Input(ID, poisson_SpikeTrain(λ, sim_duration))
        for (ID, λ) in zip(input_IDs, firing_rates)
    ]
    Nₑ = EIMix(N, EIratio).Nₑ
    # The lowest IDs are excitatory. This must agree with `conntype_vec`, which
    # labels inputs `1:Nₑ` as `:exc`. (The `≤` here had been lost to an encoding
    # error, leaving the syntax error `ID  Nₑ`.)
    neuron_type(ID) = (ID ≤ Nₑ) ? exc : inh
    Δgₑ = 60nS / scaling
    # Inhibitory synapses are `EIratio` times stronger than excitatory ones
    # (presumably to balance the `EIratio`-times more numerous excitatory inputs).
    Δgᵢ = Δgₑ * EIratio
    on_spike_arrival!(vars, spike) =
        if neuron_type(source(spike)) == exc
            vars.gₑ += Δgₑ
        else
            vars.gᵢ += Δgᵢ
        end
    return (;
        firing_rates,
        inputs,
        on_spike_arrival!,
        Nₑ,
    )
end;
using SpikeWorks: Simulation, step!, run!, unpack, newsim,
                  get_new_spikes!, next_spike, index_of_next
"""
    new(; kw...)

Create a fresh N-to-1 simulation of `coba_izh_neuron` with time step `Δt`.
Keyword arguments are forwarded to `input`.
Returns `(sim = …, input = …)`.
"""
function new(; kw...)
    generated = input(; kw...)
    sim = newsim(coba_izh_neuron, generated.inputs, generated.on_spike_arrival!, Δt)
    return (sim = sim, input = generated)
end;

Multi sim

(These Ns are the same as in, e.g., https://tfiers.github.io/phd/nb/2022-10-11__Nto1_output_rate__Edit_of_2022-05-02.html)

using SpikeWorks: spikerate
sim_duration/minutes
10
using Printf
# Print the wall-clock time elapsed since `t0` (a `time()` value), to 2 significant digits.
function print_Δt(t0)
    elapsed = time() - t0
    @printf("%.2G seconds\n", elapsed)
end
"""
    @timeh ex

Evaluate `ex`, print the elapsed wall-clock time via `print_Δt` (e.g. "2.7 seconds"),
and return the value of `ex`. The original version returned `nothing` and discarded
that value.
"""
macro timeh(ex)
    quote
        t0 = time()
        result = $(esc(ex))
        print_Δt(t0)
        result
    end
end;
# (N, f) pairs: number of inputs N, and scaling factor f. The sims below use
# `scaling = f * N`, so each excitatory spike adds Δgₑ = 60 nS / (f·N).
Ns_and_scalings = [
    # (   N,   f )
    (     5, 2.4 ),   # 4 exc + 1 inh inputs (N_inh = 1)
    (    20, 1.3 ),   # Originally N = 21; but pₑ = 0.8 does not split 21 into integer parts.
    (   100, 0.8 ),
    (   400, 0.6 ),
    (  1600, 0.5 ),
    (  6500, 0.5 ),
];
Ns = [N for (N, _) in Ns_and_scalings];
# Prefix of all cache files written from this notebook.
# (An earlier version used "2022-10-24__Nto1_with_fixed_lognormal_inputs".)
nbname = "2023-01-19__[input]"
# Cache key (the file-name stem of the cached `.jld2` output) for the sim with `N` inputs.
cachekey(N) = string(nbname, "__N=", N);
cachekey(Ns[end])
"2023-01-19__[input]__N=6500"
"""
    runsim(N, scaling)

Create a new N-to-1 simulation (via `new`) with `N` inputs and synaptic `scaling`,
and run it to completion. Prints `N`, the wall-clock run time, and the output
neuron's spike rate. Returns `(; sim, input)`.
"""
function runsim(N, scaling)
    println()
    fresh = new(; N, scaling)
    sim = fresh.sim
    @show N
    @timeh run!(sim)
    @show spikerate(sim)
    return (; sim, input = fresh.input)
end

# Run one simulation per (N, f) pair, with synaptic scaling `f * N`, or load the
# result from the on-disk cache if it was computed before.
# Built as a comprehension rather than by `push!`ing into an untyped `[]`, so the
# vector gets a concrete element type. The trailing `;` suppresses the (large) display.
simruns = [
    cached(runsim, (N, f * N), key = cachekey(N))
    for (N, f) in Ns_and_scalings
];
2.7 seconds
spikerate(sim) = 3.21
Saving output at `C:\Users\tfiers\.phdcache\runsim\2023-01-19__[input]__N=5.jld2` … done (6.3 s)

2 seconds
spikerate(sim) = 13.2
Saving output at `C:\Users\tfiers\.phdcache\runsim\2023-01-19__[input]__N=20.jld2` … done (0.1 s)

1.5 seconds
spikerate(sim) = 3.28
Saving output at `C:\Users\tfiers\.phdcache\runsim\2023-01-19__[input]__N=100.jld2` … done (0.1 s)

1.7 seconds
spikerate(sim) = 3.69
Saving output at `C:\Users\tfiers\.phdcache\runsim\2023-01-19__[input]__N=400.jld2` … done (0.1 s)

2.8 seconds
spikerate(sim) = 6.84
Saving output at `C:\Users\tfiers\.phdcache\runsim\2023-01-19__[input]__N=1600.jld2` … done (0.5 s)

4 seconds
spikerate(sim) = 4.9
Saving output at `C:\Users\tfiers\.phdcache\runsim\2023-01-19__[input]__N=6500.jld2` … done (0.9 s)
# Split the `(sim, input)` pairs into two parallel vectors.
sims = map(first, simruns)
inps = map(last, simruns);
# Memory footprint of the 6th (N = 6500) sim run, in GB.
Base.summarysize(simruns[6]) / GB
0.535

Disentangle

# Spike times of an input's spike train.
function spiketimes(input::Nto1Input)
    train = input.train
    return train.spiketimes
end;
# Recorded membrane-potential (`v`) trace of an N-to-1 simulation.
function vrec(s::Simulation{<:Nto1System})
    recording = s.rec
    return recording.v
end;

Conntest

winsize = 1000   # STA window length passed to `calc_STA` (presumably in samples)

# Spike-triggered average of `sim`'s recorded voltage, triggered on `spiketimes`.
# (Checked with `@code_warntype calc_STA(vrec(s), st1, s.Δt, winsize)`: type-stable.)
function calcSTA(sim, spiketimes)
    v = vrec(sim)
    return calc_STA(v, spiketimes, sim.Δt, winsize)
end;

STA calculation (with caching)

"""
    calc_STA_and_shufs(spiketimes, sim; nshufs = 100)

Compute the spike-triggered average (STA) of `sim`'s recorded voltage, triggered on
`spiketimes`. Also compute `nshufs` STAs triggered on ISI-shuffled versions of
`spiketimes`, presumably as a null distribution for the connection test.

The number of shuffles used to be hard-coded to 100; that remains the default.

Returns `(; realSTA, shufs)`.
"""
function calc_STA_and_shufs(spiketimes, sim; nshufs = 100)
    realSTA = calcSTA(sim, spiketimes)
    shufs = [calcSTA(sim, shuffle_ISIs(spiketimes)) for _ in 1:nshufs]
    return (; realSTA, shufs)
end

"""
    calc_all_STAz(inputs, sim)
    calc_all_STAz(simrun)

calc_all_STAs_and_shufs: for every input, compute its real STA and its
ISI-shuffled STAs (see `calc_STA_and_shufs`), showing a progress bar.
The one-argument method takes a `(; sim, input)` sim run.
"""
function calc_all_STAz(inputs, sim)
    sta_and_shufs_of(inp) = calc_STA_and_shufs(spiketimes(inp), sim)
    return @showprogress map(sta_and_shufs_of, inputs)
end
calc_all_STAz(simrun) = calc_all_STAz(unpakk(simrun)...);
# Extract the arguments of `calc_all_STAz(inputs, sim)` from a sim run, in that order.
unpakk(simrun) = (inputs = simrun.input.inputs, sim = simrun.sim);

# out = calc_all_STAz(simruns[1])
# print(Base.summary(out))
# All STAs and shuffled STAs for the `i`-th sim run, loaded from the on-disk cache
# if they were computed before.
function calc_all_cached(i)
    args = [simruns[i]]
    return cached(calc_all_STAz, args; key = cachekey(Ns[i]))
end

# STA results per sim run (entry `i` would be `calc_all_cached(i)`).
# Left empty here, presumably because the computation is slow. To fill it:
#     append!(out, [calc_all_cached(i) for i in eachindex(simruns)])
out = Any[]
Any[]
# True connection type (`:exc` or `:inh`) of every input of the `i`-th sim run:
# inputs `1:Nₑ` are excitatory, the remaining ones inhibitory.
conntype_vec(i) = begin
    _, inp = simruns[i]
    Nₑ = inp.Nₑ
    N = Ns[i]
    return Symbol[j ≤ Nₑ ? :exc : :inh for j in 1:N]
end;
"""
    conntestresults(i, teststat = ptp_test; α = 0.05)

Run `test_conn` with test statistic `teststat` on every input of the `i`-th sim run,
using the precomputed `(realSTA, shufs)` pairs in `out[i]`. `α` is presumably the
significance level.

Returns a `DataFrame` of the test results, with an added `:conntype` column holding
the true connection types (see `conntype_vec`).
"""
function conntestresults(i, teststat = ptp_test; α = 0.05)
    run_test((sta, shufs)) = test_conn(teststat, sta, shufs; α)
    results = @showprogress map(run_test, out[i])
    df = DataFrame(results)
    df[!, :conntype] = conntype_vec(i)
    return df
end;

# conntestresults(1)
using Sciplotlib: plot
# Mean firing rate of a spike train: number of spikes divided by `sim_duration`.
spikerate_(spiketimes) = length(spiketimes) / sim_duration;
spikerate_(inp::Nto1Input) = spikerate_(spiketimes(inp));
# Realised firing rates of all inputs of the `i`-th sim run.
firing_rates(i) = [spikerate_(spiketimes(inp)) for inp in inps[i].inputs];