2023-04-11 • Nto1 AdEx sims: conntest method comparison

Data generated in 2023-02-24__multisim-winline.jl

using CSV
using DataFrames
df = CSV.read("../data/Nto1_AdEx.csv", DataFrame)
90×13 DataFrame
65 rows omitted
df[df.N .== 6500, :]
15×13 DataFrame
df2 = groupby(df, [:method, :N])

GroupedDataFrame with 18 groups based on keys: method, N

First Group (5 rows): method = "STA_corr_2pass", N = 5


Last Group (5 rows): method = "STA_height", N = 6500
using Statistics
dfm = sort!(combine(df2, Between(:TPRₑ, :AUCᵢ) .=> mean), :N)
18×9 DataFrame
using WithFeedback
@withfb import PyPlot     # mpl wrapper
@withfb using Sciplotlib  # pyplot wrapper
@withfb using PhDPlots    # local plotting funcs
import PyPlot … ✔ (4.8 s)
using Sciplotlib … ✔ (4.7 s)
using PhDPlots … ✔


function plotrates(m)

    df3 = df[df.method .== m, :]
    df3m = dfm[dfm.method .== m, :]

    # jitter(x) = x .* (0.95 .+ 0.1.*rand(length(x)))
    jitter(x) = x

    fig, ax = plt.subplots()
    clip_on = false
    plt.semilogx(jitter(df3.N), df3.TPR, ".", ms = 8; clip_on)
    plt.semilogx(jitter(df3m.N), df3m.TPR_mean, "C0-"; clip_on)
    plt.semilogx(jitter(df3.N), df3.FPR, ".", ms = 3; clip_on)

    Ns = unique(df.N)

    plt.xticks(Ns, Ns)

    plt.ylabel("Detection rate")
    plt.xlabel("Number of inputs (N)")
end

[Figures: detection rate (TPR, FPR) vs. number of inputs (N), one plot per method]

What’s up with that N=100, zero-detections point?

This has been fixed by now.

Which seed is it?

df[df.N .== 100 .&& df.seed .== 1, :]
3×13 DataFrame

First let’s inspect the simulation

using Distributed … ✔
using Revise … ✔ (2.0 s)
using SpikeWorks … ✔ (2.4 s)
using SpikeWorks.Units … ✔ (1.0 s)
using ConnectionTests … ✔ (0.2 s)
using DataFrames … ✔
using MemDiskCache … ✔ (2.0 s)
sd = sims(; N=100, seed=1, duration=10minutes, δ_nS=0.75);
Loading [/root/.julia/MemDiskCache.jl/2023-03-14__Nto1_AdEx/run_sim/_  N=100  δ_nS=0.75  duration=600.0  seed=1  _.jld2] … ✔ (4.7 s)
plotsig(voltsig(sd) / mV, [0,500], ms, hylabel = "Simulated membrane voltage (mV)");

The spike peaks need fixing for aesthetics, but otherwise this seems fine.

spikerate(sd) / Hz

Maybe something went wrong during the calculation and the wrong data was saved. On the other hand, both the STA methods and the upstroke fit (which uses the simulation directly) perform badly.

Ok but first, is this very different from another seed?

sd2 = sims(; N=100, seed=2, duration=10minutes, δ_nS=0.75);
Loading [/root/.julia/MemDiskCache.jl/2023-03-14__Nto1_AdEx/run_sim/_  N=100  δ_nS=0.75  duration=600.0  seed=2  _.jld2] … ✔ (0.5 s)
plotsig(voltsig(sd2) / mV, [0,500], ms, hylabel = "Simulated membrane voltage (mV)", ylim=[-80, 0]);

So no, looks similar.

Let’s then look at some STAs.

kw = (; N=100, duration=10minutes, δ_nS=0.75, Nᵤ=100, batch_size, part=1);
# reals, shufs = STA_sets(; kw..., seed=1);
# ylim = [-59, -56]
# plotSTA(reals[1]; ylim);

Seems proper

# reals2, shufs2 = STA_sets(; kw..., seed=2);
# plotSTA(reals2[1]; ylim, color=C1);

Not too different.

# plotSTA(shufs[1][1]; ylim);
# plotSTA(shufs2[1][1]; ylim);

So only difference is a lower avg Vm.

mean(voltsig(sd)) / mV
mean(voltsig(sd2)) / mV

Or maybe, let’s delete the cached data and recalculate, see if it stays the same.

Nope, it stayed the same.



(It hit me while offline.)

N = 100, Nᵤ = 100, and the seed used to generate both is the same.

So, add Nᵤ to the seed, e.g.
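A minimal sketch of why a shared seed matters here, under one reading of the note above: that "both" refers to the N connected and the Nᵤ unconnected input trains, each generated from an RNG seeded with the same value. `gen_trains` is a hypothetical stand-in (one random number per train), not the actual simulation code, and `seed + Nᵤ` is just one way to offset the seed.

```julia
using Random

N, Nᵤ, seed = 100, 100, 1

# Hypothetical stand-in for spike-train generation: one random number per train.
gen_trains(n, seed) = rand(MersenneTwister(seed), n)

connected         = gen_trains(N, seed)
unconnected       = gen_trains(Nᵤ, seed)        # same seed and same length: same stream
unconnected_fixed = gen_trains(Nᵤ, seed + Nᵤ)   # offset seed: a different stream

same_before = connected == unconnected         # identical when N == Nᵤ
same_after  = connected == unconnected_fixed   # no longer identical
```

With N ≠ Nᵤ the two streams would still share a common prefix, so offsetting the seed is the safer choice regardless of the sizes.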


Why is AUC < 0.5?

First, let’s plot em

# `plot_dots_and_means` is a placeholder name (the original name is not shown),
# and `:AUC` as the plotted column is an assumption.
plot_dots_and_means(
    df, ycol;
    xcol = :N,
    xlabel = (xcol == :N ? "Number of inputs (N)" : string(xcol)),
    hylabel = string(ycol),
    color = C0,
    ax = nothing,
    kw...
) = begin
    clip_on = false
    isnothing(ax) && (fig, ax = plt.subplots())
    x = df[:, xcol]
    y = df[:, ycol]
    # Individual runs as dots
    plot(x, y, "."; clip_on, ms=8, alpha=0.6, color, ax)
    xu = unique(x)
    ax.set_xticks(xu, xu)
    ax.set_xlim(xu[1]/1.4, xu[end]*1.4)
    set(ax; xlabel, hylabel, kw...)
    # Mean per x-value as a line
    dfm = combine(groupby(df, xcol), ycol => mean => ycol)
    ym = dfm[:, ycol]
    plot(xu, ym, "-"; clip_on, color, ax)
end

plot_AUC(df, method) = begin
    fig, ax = plt.subplots()
    ax.axhline(0.5, c="gray", lw=1, ls="--")  # chance level
    plot_dots_and_means(
        df[df.method .== method, :],
        :AUC,
        title = method,
        ylim = [0, 1];
        ax,
    )
end

methods = unique(df.method)
for m in methods
    plot_AUC(df, m)
end
[Figures: AUC vs. number of inputs (N), one plot per method]
kw = (; N=6500, duration=10minutes, δ_nS=0.02, Nᵤ=100, seed=1, method="STA_corr_2pass");
conntest = conntest_tables(; kw...);
using ConnTestEval
sweep = sweep_threshold(conntest);
plot_ROC(sweep) = begin
    fig, ax = plt.subplots()
        xlabel = "FPR",
        ylabel = ("TPR", :loc=>"top"),
        xlim = [0, 1],
        ylim = [0, 1],
ax = plot_ROC(sweep);

For the last point on the graph:

              Predicted
            exc   inh   unc
      exc  1349  3851     0
Real  inh  1139   161     0
      unc    58    42     0

So the problem is that we only count a detection as correct if it’s the right type: exc or inh. Hence why, even with a threshold as low as possible, we do not detect all real connections.
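The arithmetic behind that last ROC point can be checked directly from the table. A minimal sketch, assuming the TPR is taken over all real connections (exc and inh rows) and counts a detection only when the predicted type matches, as described above. The variable names are illustrative and not from `ConnTestEval`.

```julia
# Confusion matrix at the lowest threshold, copied from the table.
# Rows: real type, columns: predicted type, both in the order exc, inh, unc.
confusion = [1349 3851 0;
             1139  161 0;
               58   42 0]

n_connected   = sum(confusion[1:2, :])     # real exc + real inh inputs
n_unconnected = sum(confusion[3, :])       # real unconnected inputs

# Strict: only type-correct detections count (the diagonal of the exc/inh block).
strict_hits = confusion[1, 1] + confusion[2, 2]
# Loose: any real connection flagged as exc or inh counts.
loose_hits  = sum(confusion[1:2, 1:2])

TPR_strict = strict_hits / n_connected
TPR_loose  = loose_hits / n_connected
FPR        = sum(confusion[3, 1:2]) / n_unconnected
```

Under these assumptions the strict TPR is 1510 / 6500 ≈ 0.23 while the FPR is already 1, so the ROC curve ends well below the top-right corner, which is consistent with an AUC below 0.5.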

Two-pass STA template for high N

Finally, what does the template look like for the high-N, two-pass case?

I suspect it’s zeros: not enough (or no) connections detected in the first pass, with a high threshold.

kw = (; N=1600, duration=10minutes, δ_nS=0.08, Nᵤ=100, seed=1, method="STA_height");
table = conntest_tables(; kw...)

So there is indeed not one connection with t-value 1. So no excitatory STAs for the template.

We must change the method: auto-choose a threshold at which we do have at least some detections.

We could do the full sweep, start from the high thresholds, and pick the first one where we cross a given detection rate, or where we reach a fixed number of exc detections.

Aha, wait, no: of course we don’t have the ground truth, so a detection rate can’t be computed.

So rather, maybe loop and lower the threshold step by step? That requires knowing the range of t-values, which is no problem: just take the sorted unique values.

ts = sort!(unique(table.t), rev = true)
for θ in ts
    predtype = ConnTestEval.predicted_types(table.t, θ)
    N = count(predtype .== :exc)
    @show (θ, N)
    N > 5 && break
end
(θ, N) = (0.99, 0)
(θ, N) = (0.98, 305)

Something like that.
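The same idea as a self-contained function, as a rough sketch. `count_exc` is a hypothetical stand-in for `ConnTestEval.predicted_types` and simply assumes a connection is called excitatory when its t-value is at least θ; `pick_threshold` and `min_detections` are illustrative names too.

```julia
# Hypothetical stand-in for `ConnTestEval.predicted_types`: a connection is
# called excitatory when its t-value is at least θ (an assumption).
count_exc(t, θ) = count(≥(θ), t)

# Walk the thresholds from strict to lenient and return the first one that
# yields more than `min_detections` excitatory detections.
function pick_threshold(t; min_detections = 5)
    for θ in sort!(unique(t), rev = true)
        count_exc(t, θ) > min_detections && return θ
    end
    return nothing   # no threshold gives enough detections
end

# Made-up t-values, only to exercise the function.
t = [0.99, 0.98, 0.98, 0.98, 0.5, 0.5, 0.5, 0.1, -0.3, -0.98]
θ = pick_threshold(t; min_detections = 3)
```

Returning `nothing` when no threshold qualifies leaves it to the caller to decide what to do when there are too few candidates for a template.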

Now to insert it in the code.

Wait, it’s still so low after that change. OK, so how many detections do we have, and what does the template look like (now that we should have at least some)?

using Distributed … ✔
using Revise … ✔
using SpikeWorks … ✔
using SpikeWorks.Units … ✔
using ConnectionTests … ✔
using DataFrames … ✔
using MemDiskCache … ✔
simkw = (; N=1600, duration=10minutes, δ_nS=0.08, Nᵤ=100, seed=1);
m = TwoPassCorrTest();
template = get_template(m, simkw, batch_size);
Loading [/root/.julia/MemDiskCache.jl/2023-03-14__Nto1_AdEx/calc_all_STAs/_  N=1600  Nᵤ=100  δ_nS=0.08  duration=600.0  seed=1  batch_size=300  part=1  _.jld2] … ✔ (0.4 s)
Aha, here’s the problem. We do have an STA, but it’s negative. It’s classified as positive because the heuristic for that (‘net area above start’) is... ah yes, that is positive, so it’s classified as positive.

One ad-hoc solution would be to use a shorter stretch of time for this ‘area above start’, i.e. the first 20 ms. (Here, the first 5 or 10 ms would be even better, but imagine a spike transmission delay.)
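A rough sketch of how restricting the window can flip the sign of this heuristic. The function name, its signature, and the synthetic STA (an early dip followed by a longer, shallower overshoot) are all assumptions for illustration, not the code in `ConnectionTests`.

```julia
# Net area of an STA relative to its first sample, over the first `window` ms.
# `sta` (voltages) and `Δt` (sample spacing in ms) are hypothetical inputs.
function net_area_above_start(sta, Δt; window = length(sta) * Δt)
    n = clamp(round(Int, window / Δt), 1, length(sta))
    return sum(sta[1:n] .- sta[1]) * Δt
end

# Synthetic inhibitory-looking STA: 100 ms at 0.1 ms resolution, with an early
# dip (samples 11 to 200) followed by a longer, shallower overshoot.
Δt = 0.1
sta = fill(-60.0, 1000)
sta[11:200]   .-= 0.2
sta[201:1000] .+= 0.1

area_full  = net_area_above_start(sta, Δt)               # positive: would be called excitatory
area_early = net_area_above_start(sta, Δt; window = 20)  # negative: would be called inhibitory
```

The window length trades off robustness to slow rebounds against tolerance for transmission delays, which is why 20 ms rather than 5 or 10 ms is suggested above.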