%% Perform the k-fold validation for H
% Leave-one-video-out validation of the logistic regression focus model.
% For each of the 15 videos in focus_layers, a model is trained on the other
% 14 videos (class-balanced by random undersampling) and evaluated on the
% held-out video. The features are the median-centred, moving-median
% filtered ssim and correlation traces of every mask with a valid quadrant.
clear
close all
load('..\datasets\figure_4\focuscorrelation_forLR_kmeans.mat');
rng(3)

% number of frames for the moving median filter
medlength = 21;
% part(:, fn) lists the video(s) held out in fold fn
part = 1:15;

% For all moving median filter lengths
for nn = 1:length(medlength) % (use one filter length or more if desired)
    % Go through each of the 15-fold validation steps
    for fn = 1:15
        scorecorr = [];
        scoresnr = [];
        scoressim = [];
        GTm = [];
        scoretestcorr = [];
        scoretestsnr = [];
        scoretestssim = [];
        GTmtest = [];

        % Assemble the training and testing data from the structure
        % containing the data for each individual video
        for k = 1:15
            vcorr = focus_layers(k).correlation;
            ssim = focus_layers(k).ssim;
            snr = focus_layers(k).snr;
            quadloc = focus_layers(k).quadrant;
            focusTable = focus_layers(k).focusTable;

            % Per-frame label for each of the 4 quadrants. Columns 1-2 of
            % focusTable are the first/last frame of a period and columns
            % 3-6 flag the quadrants that are labelled during that period.
            GT = zeros(size(vcorr, 2), 4);
            for period = 1:size(focusTable, 1)
                for q = 1:4
                    if focusTable(period, q+2) == 1
                        GT(focusTable(period, 1):focusTable(period, 2), q) = 1;
                    end
                end
            end

            isTraining = all(k ~= part(:, fn));
            for mask = 1:size(vcorr, 1)
                % Skip masks that have no quadrant assigned
                if isnan(quadloc(mask)) || quadloc(mask) == 0
                    continue
                end
                fcorr = movmedian(normalize(vcorr(mask,:),'center','median'),medlength(nn),'omitnan');
                % The snr feature is assembled but is not used by the model below
                fsnr = movmedian(normalize(snr(mask,:),'center','median', 'scale', 'std'),medlength(nn),'omitnan');
                fssim = movmedian(normalize(ssim(mask,:),'center','median'),medlength(nn),'omitnan');
                labels = GT(:, quadloc(mask));
                if isTraining
                    scorecorr = cat(2, scorecorr, fcorr);
                    scoresnr = cat(2, scoresnr, fsnr);
                    scoressim = cat(2, scoressim, fssim);
                    GTm = cat(1, GTm, labels);
                else
                    scoretestcorr = cat(2, scoretestcorr, fcorr);
                    scoretestsnr = cat(2, scoretestsnr, fsnr);
                    scoretestssim = cat(2, scoretestssim, fssim);
                    GTmtest = cat(1, GTmtest, labels);
                end
            end
        end

        % Predictor order is [ssim, correlation]
        X = cat(2, scoressim(:), scorecorr(:));
        Y = GTm(:);
        Pind = find(Y == 1);
        Nind = find(Y == 0);
        % There is a class imbalance, so set the number of in-focus and out
        % of focus frames equal to each other
        rNind = randsample(Nind, length(Pind));
        tInd = cat(1, Pind, rNind);
        mdl = fitglm(X(tInd, :),Y(tInd),'Distribution','binomial'); %Train the model
        LRscores = mdl.Fitted.Probability;

        %Test the model on the left out video
        Xtest = cat(2, scoretestssim(:), scoretestcorr(:));
        Ytest = GTmtest(:);
        ypredtest = predict(mdl,Xtest);

        % Optimal operating point of the training ROC curve
        [fpr,tpr,T,AUC,opt] = perfcurve(Y(tInd),LRscores, 1, 'Cost', [0 1; 1 0]);
        % Several ROC points can coincide with the optimum; take the first so
        % that the threshold is always a scalar
        optIdx = find((fpr==opt(1))&(tpr==opt(2)), 1, 'first');

        LogRmodel{fn, nn} = mdl;
        saveY{fn} = Ytest;
        savepred{fn, nn} = ypredtest;
        threshold(fn, nn) = T(optIdx);
        saveX{fn, nn} = Xtest;

        % perfcurve counts scores >= T as positive, so use the same
        % convention here; NaN predictions fall in neither group
        thr = threshold(fn, nn);
        predPos = ypredtest >= thr;
        predNeg = ypredtest < thr;
        isPos = Ytest ~= 0;
        isNeg = Ytest == 0;

        % number of true positives
        TP(fn, nn) = sum(predPos & isPos);
        % number of true negatives
        TN(fn, nn) = sum(predNeg & isNeg);
        % number of false positives
        FP(fn, nn) = sum(predPos & isNeg);
        % number of false negatives
        FN(fn, nn) = sum(predNeg & isPos);
        % number of positive results
        P(fn, nn) = sum(isPos);
        % number of negative results
        N(fn, nn) = sum(isNeg);
        % Accuracy
        Acc(fn, nn) = (TP(fn, nn)+TN(fn, nn))/(P(fn, nn)+N(fn, nn));
        % The MCC metric
        phi(fn, nn) = (TP(fn, nn)*TN(fn, nn)-FP(fn, nn)*FN(fn, nn))/sqrt((TP(fn, nn)+FP(fn, nn))*(TP(fn, nn)+FN(fn, nn))*(TN(fn, nn)+FP(fn, nn))*(TN(fn, nn)+FN(fn, nn)));
    end
end

% Summary bar chart with the per-fold values and standard deviations
% (written for a single median filter length, i.e. one column per metric)
bar([1 2 3 4], [mean(phi) mean(Acc) mean(TN./N) mean(TP./P)], 'FaceColor', [0.8 0.8 0.8])
hold on
plot(ones(size(phi)), phi, 'ko')
errorbar(1, mean(phi), std(phi), 'k', 'LineWidth', 2)
plot(ones(size(Acc))*2, Acc, 'ko')
errorbar(2, mean(Acc), std(Acc), 'k', 'LineWidth', 2)
plot(ones(size(Acc))*3, TN./N, 'ko')
errorbar(3, mean(TN./N), std(TN./N), 'k', 'LineWidth', 2)
plot(ones(size(Acc))*4, TP./P, 'ko')
errorbar(4, mean(TP./P), std(TP./P), 'k', 'LineWidth', 2)
xticklabels({'Matthews correlation coefficient', 'Accuracy', 'True negative rate', 'True positive rate'})
ylabel('Value')

%% Train the single model to be used in the remainder of the paper and for F and G of figure 4
% Same feature construction as the validation above, but pooled over all 15
% videos. Leaves mdl (the fitted model) and Topt (the ROC-optimal
% probability threshold) in the workspace for the Figure 4F section.
clear
load('..\datasets\figure_4\focuscorrelation_forLR_kmeans.mat');
rng(3)
scorecorr = [];
scoresnr = [];
scoressim = [];
GTm = [];
scoretestcorr = [];
scoretestsnr = [];
scoretestssim = [];
scorediff = [];
scoreIOU = [];
GTmtest = [];
nn = 1;
medlength = 21;

for k = 1:15 %Pool the data across the 15 videos
    vcorr = focus_layers(k).correlation;
    ssim = focus_layers(k).ssim;
    snr = focus_layers(k).snr;
    quadloc = focus_layers(k).quadrant;
    focusTable = focus_layers(k).focusTable;
    sdiff = focus_layers(k).diff;
    sIOU = focus_layers(k).IOU;

    % Per-frame, per-quadrant labels from the annotated periods
    GT = zeros(size(vcorr, 2), 4);
    for period = 1:size(focusTable, 1)
        for q = 1:4
            if focusTable(period, q+2) == 1
                GT(focusTable(period, 1):focusTable(period, 2), q) = 1;
            end
        end
    end

    for mask = 1:size(vcorr, 1)
        if ~isnan(quadloc(mask)) && quadloc(mask) ~= 0
            scorecorr = cat(2, scorecorr, movmedian(normalize(vcorr(mask,:),'center','median'),medlength(nn),'omitnan'));
            % snr, diff/IOU and IOU scores are assembled but not used by the model
            scoresnr = cat(2, scoresnr, movmedian(normalize(snr(mask,:),'zscore'),medlength(nn),'omitnan'));
            scoressim = cat(2, scoressim, movmedian(normalize(ssim(mask,:),'center','median'),medlength(nn),'omitnan'));
            scorediff = cat(2, scorediff, movmedian(normalize(sdiff(mask,:)./sIOU(mask,:),'zscore'),medlength(nn),'omitnan'));
            scoreIOU = cat(2, scoreIOU, movmedian(normalize(sIOU(mask,:),'center','median'),medlength(nn),'omitnan'));
            GTm = cat(1, GTm, GT(:, quadloc(mask)));
        end
    end
end

% Predictor order is [ssim, correlation]
X = cat(2,scoressim(:), scorecorr(:));
Y = GTm(:);
% Pick a set of frames with an even number of in and out of focus frames
Pind = find(Y == 1);
Nind = find(Y == 0);
rNind = randsample(Nind, length(Pind));
tInd = cat(1, Pind, rNind);
mdl = fitglm(X(tInd, :),Y(tInd),'Distribution','binomial');
LRscores = mdl.Fitted.Probability;

%Plot the model output
figure
[fpr,tpr,T,AUC,opt] = perfcurve(Y(tInd),LRscores, 1, 'Cost', [0 1; 1 0]);
% Scalar threshold at the optimal operating point (first match if several
% ROC points coincide). Kept separate from T so the sections can be re-run.
Topt = T(find((fpr==opt(1))&(tpr==opt(2)), 1, 'first'));
plot(fpr,tpr,'k-',opt(1),opt(2),'ro')
hold on
axis equal
title('Logistic Regression')
xlabel('False positive rate')
ylabel('True positive rate')
axis([0, 1, 0, 1])
legend({sprintf("AUC: %.2f\nThreshold: %.2f",AUC,Topt)},'Location','SouthEast')

%% Figure 4F
% Contours of the (correlation, ssim) distributions for the frames flagged
% in focusTable (oof*) and for all remaining frames (inf*), with the
% logistic regression decision boundary overlaid.
% Requires medlength, mdl and Topt from the previous section.
load('..\datasets\figure_4\focuscorrelation_forLR_kmeans.mat');
oofcorr = [];
infcorr = [];
oofssim = [];
infssim = [];

% Pool the data across the 15 videos
for vidind = 1:15
    correlation = focus_layers(vidind).correlation;
    ssim = focus_layers(vidind).ssim;
    focusTable = focus_layers(vidind).focusTable;
    quad = focus_layers(vidind).quadrant;

    % Frame indices flagged in focusTable, one list per quadrant
    oofind = cell(4, 1);
    for period = 1:size(focusTable, 1)
        for q = 1:4
            if focusTable(period, q+2) == 1
                oofind{q} = cat(2, oofind{q}, focusTable(period, 1):focusTable(period, 2));
            end
        end
    end

    for ganglia = 1:size(correlation, 1)
        if quad(ganglia) ~= 0 && ~isnan(quad(ganglia))
            flagged = oofind{quad(ganglia)};

            corrt = movmedian(normalize(correlation(ganglia, :),'center','median'),medlength,'omitnan');
            oofcorrelation = corrt(flagged);
            corrt(flagged) = [];

            ssimt = movmedian(normalize(ssim(ganglia, :),'center','median'),medlength,'omitnan');
            oofssimt = ssimt(flagged);
            ssimt(flagged) = [];

            oofcorr = cat(2, oofcorr, oofcorrelation);
            infcorr = cat(2, infcorr, corrt);
            oofssim = cat(2, oofssim, oofssimt);
            infssim = cat(2, infssim, ssimt);
        end
    end
end

%Get the 2D histogram that is used for the contour plot
[oofN,oofXedges,oofYedges] = histcounts2(oofcorr, oofssim, -2:0.01:2, -2:0.01:2, 'Normalization', 'probability');
[infN,infXedges,infYedges] = histcounts2(infcorr, infssim, -2:0.01:2, -2:0.01:2, 'Normalization', 'probability');

% From the logistic regression fit. The model was trained on [ssim, corr],
% so row 2 is the ssim coefficient and row 3 the correlation coefficient.
mdlarray = table2array(mdl.Coefficients);
B0 = mdlarray(1, 1);
Bssim = mdlarray(2, 1);
Bcorr = mdlarray(3, 1);

%Plot the contour plot
figure
ax1 = axes;
contour(oofXedges(1:end-1), oofYedges(1:end-1),oofN', 5, 'LineWidth', 1.5);
hold on
% Decision boundary B0 + Bssim*ssim + Bcorr*corr = logit(Topt), solved for
% ssim (y axis) as a function of correlation (x axis)
X1 = -1:0.1:1;
plot(X1, (1/Bssim)*(log(Topt/(1-Topt))-B0-Bcorr*X1), 'k')
ax2 = axes;
hold on
contour(infXedges(1:end-1), infYedges(1:end-1),infN', 5, 'LineWidth', 1.5);
ax2.Visible = 'off';
linkaxes([ax1 ax2]) % this bit is critical, it makes sure that the limits match up
xlim([-0.5 0.5])
ylim([-0.5 0.5])
xticks(-0.5:0.5:0.5)
yticks(-0.5:0.5:0.5)
cvec = linspace(0.4,1,255)';
cvec2 = linspace(0.4,1,255)';
z = zeros(255, 1);
z2 = zeros(255,1);
colormap(ax1,[cvec z z])
colormap(ax2,[z2 z2 cvec])
xlabel(ax1, 'Normalized Correlation')
ylabel(ax1, 'Normalized ssim')
legend(ax1, {'OOF', 'LR'})
legend(ax2, {'INF'}, 'Location', 'South')

%% Examples for A-E
% Names for the selected sample data
vidfolder = '..\datasets\videos\registered\';
infocusmaskfolder = '..\datasets\masks\';
focusfolder = '..\datasets\masks\consensus-labels\';
fluorIFF = '..\datasets\figure_4\050820-10x-2xbin-Mp50-chatgc3-%5B2.041-0.06%5Defs021_Green-in_to_focus-example-fluorescence.mat';
fluorOOF = strcat(vidfolder, '050820-10x-2xbin-Mp50-chatgc3-%5B2.041-0.06%5Defs021_Green_okada1cubic_registered_Fluorescence.mat');
% Frames to show what the video looks like at baseline
framestemp = 30:35;
% Frames to show how neurons can go out of focus
framesoof = 850:855;
% Frames to show how neurons can come into focus
framesinf = 770:775;
% Selected cells to show loss of focus
oofind = [96 97 99 100];
framerate = 20;

%Load in the data and make the plots
% NOTE: the load order matters - 'masks' and 'maxmove' are loaded twice and
% copied to differently named variables in between.
load(strcat(infocusmaskfolder, '050820-10x-2xbin-Mp50-chatgc3-%5B2.041-0.06%5Defs021_Green_appliedRegistration_projection_consensusmasks.mat'))
load(strcat(infocusmaskfolder,'050820-10x-2xbin-Mp50-chatgc3-%5B2.041-0.06%5Defs021_Green_okada1cubic_registered_associated.mat'), 'maxmove');
infmaxmove = maxmove;
focusmasks = masks;
load(strcat(focusfolder,'050820-10x-2xbin-Mp50-chatgc3-%5B2.041-0.06%5Defs021_Green_appliedRegistration.mat'))
load('..\datasets\figure_4\050820-10x-2xbin-Mp50-chatgc3-%5B2.041-0.06%5Defs021_Green_okada1cubic_registered_comeintofocus_masks.mat')
load(strcat(vidfolder,'050820-10x-2xbin-Mp50-chatgc3-%5B2.041-0.06%5Defs021_Green_okada1cubic_registered_associated.mat'), 'maxmove');
oofmasks = masks;
% Transpose with flips to bring these masks into the orientation used by
% the other masks/video (presumably a 90 degree rotation - verify)
oofmasks = fliplr(permute(flipud(oofmasks), [2 1 3]));
video = h5read(strcat(vidfolder,'050820-10x-2xbin-Mp50-chatgc3-%5B2.041-0.06%5Defs021_Green_okada1cubic_registered.h5'), '/video3');
lims = [0 100];

[focusmasks_rs, video, ~] = convergeMaxMove(focusmasks, video, infmaxmove, maxmove);
[focusmasks, oofmasks, ~] = convergeMaxMove(focusmasks, oofmasks, infmaxmove, maxmove);

figure
subplot(1, 3, 1)
imagesc(mean(video(:, :, framestemp), 3), lims)
title('Template image')
colormap gray
axis equal
subplot(1, 3, 2)
mask2image_contour(focusmasks(:, :, oofind), mean(video(:, :, framesoof), 3), lims)
title('Out of focus neurons')
axis equal
subplot(1, 3, 3)
mask2image_contour(oofmasks, mean(video(:, :, framesinf), 3), lims)
title('New in focus neurons')
axis equal

%Plot the fluorescence activity. Both were calculated using the
%registration coordinates
figure
subplot(2, 1, 1)
load(fluorOOF)
% dF/F relative to each trace's median
dFF = (F-repmat(median(F, 2),1 , size(F, 2)))./repmat(median(F, 2),1 , size(F, 2));
plotOpticalTraces(dFF', oofind, framerate)
yv = ylim;
hold on
% Mark the periods flagged in column 6 of focusTable along the top of the axes
for fn = 1:size(focusTable, 1)
    if focusTable(fn, 6) == 1
        plot([focusTable(fn, 1) focusTable(fn, 2)]/framerate, [yv(2) yv(2)], 'k', 'Linewidth', 2)
    end
end
xlabel('Time (s)')
ylabel('dFF')
title('Neurons that go out of focus')

subplot(2, 1, 2)
load(fluorIFF)
dFF = (F-repmat(median(F, 2),1 , size(F, 2)))./repmat(median(F, 2),1 , size(F, 2));
plotOpticalTraces(dFF', 1:size(F, 1), framerate)
yv = ylim;
hold on
% Mark the periods flagged in column 3 of focusTable along the top of the axes
for fn = 1:size(focusTable, 1)
    if focusTable(fn, 3) == 1
        plot([focusTable(fn, 1) focusTable(fn, 2)]/framerate, [yv(2) yv(2)], 'k', 'Linewidth', 2)
    end
end
xlabel('Time (s)')
ylabel('dFF')
title('Neurons that go into focus')

%% Perform the simulation for figure 4 I-P
gangmat = '..\datasets\figure_4\consensus_masks_ganglia.mat';
load(gangmat)
rng(1)

% Set up the parameters for the simulation
sensor = 'GCaMP6s';
srange = 10;
smin = 10;
useind = 12;
t = 0:0.05:120;

% Select a ganglia and create a box of background around it
[d1, d2, d3] = size(gangmask{useind});
allgang = sum(gangmask{useind}, 3);
ind = find(allgang);
[r, c] = ind2sub(size(allgang), ind);
% 10 pixel margin, clamped to the image: rows against d1, columns against d2
r = [max([1 min(r)-10]) min([d1 max(r)+10])];
c = [max([1 min(c)-10]) min([d2 max(c)+10])];
cgang = double(gangmask{useind}(r(1):r(2), c(1):c(2), :));
[X, Y] = findCOM(double(cgang));
[d1, d2, d3] = size(cgang);
[Xm, Ym] = meshgrid(1:d2, 1:d1);

%Create the nuclear hole in the masks to make them look a bit more
%realistic
% (this loop and the next are kept separate so that the order of the random
% number draws is unchanged)
for fn = 1:size(cgang, 3)
    pdist = sqrt((Ym-Y(fn)).^2+(Xm-X(fn)).^2);
    s = regionprops(logical(cgang(:, :, fn)), 'MinorAxisLength');
    ind = find(pdist < s.MinorAxisLength/5);
    mask = cgang(:, :, fn);
    mask(ind) = rand*0.2+0.1;
    cgang(:, :, fn) = mask;
end

%Determine the calcium transient times
for fn = 1:size(cgang, 3)
    nCT(fn) = randi(srange)+smin;
    CTtime{fn} = randsample(length(t)-51, nCT(fn));
end

%Set up the figures and colormaps
f1 = figure;
f1.Position(1) = f1.Position(1) - f1.Position(3);
f1.Position(3) = f1.Position(3)*3;
tg = uitabgroup;
f2 = figure;
snr_a = [2 3 4 6 8 10 12 14 16 18];
focr_a = [1 2 3 4 5 6 7 8];
bluered = ones(255, 3);
bluered(1:127, 1) = linspace(0, 1, 127);
bluered(1:127, 2) = linspace(0, 1, 127);
bluered(129:255, 2) = linspace(1, 0, 127);
bluered(129:255, 3) = linspace(1, 0, 127);
whitered = ones(255, 3);
whitered(1:255, 2) = linspace(1, 0, 255);
whitered(1:255, 3) = linspace(1, 0, 255);
count = 1;
medframe = 200;   % window of the moving median baseline subtraction
meanframe = 5;    % window of the moving mean smoothing
lim = 5;

for fn = 1:length(snr_a)
    for nn = 1:length(focr_a)
        focr = focr_a(nn);
        snr = snr_a(fn);
        % Logistic focus profile centred on xs with half-width x0
        k = 1;
        xs = 60;
        x0 = 30;
        focus = focr-focr./(1+exp(-k*(abs(t-xs)-x0)));
        frames = find(focus > max(focus)/2);

        %Simulate the data when neurons go out of focus
        [origvid, focvid, ganglia, Fc, dFFc, Ft, dFF, Ffoc, dFFfoc] = simulateCalciumVid(cgang, snr, t, focus, nCT, CTtime);
        %Simulate the data when neurons come into focus
        focus = focr - focus;
        [origvid, iffvid, ganglia, F, dFFc, Ft, dFF, Fiff, dFFiff] = simulateCalciumVid(cgang, snr, t, focus, nCT, CTtime);

        % Transient detection on the smoothed traces
        for trace = 1:size(dFF, 2)
            GT = CTtime{trace}(ismember(CTtime{trace}, frames));
            [pLb, tLb{trace}, tb(trace)] = peakLocation_prominence(movmean(dFF(:, trace), meanframe), 20, sensor);
            f1b(fn, nn, trace) = calcF1(GT, tLb{trace}, frames, lim);
            [pLfoc, tLfoc{trace}, tfoc(trace)] = peakLocation_prominence(movmean(dFFfoc(:, trace), meanframe), 20, sensor);
            f1foc(fn, nn, trace) = calcF1(GT, tLfoc{trace}, frames, lim);
            [pLiff, tLiff{trace}, tiff(trace)] = peakLocation_prominence(movmean(dFFiff(:, trace), meanframe), 20, sensor);
            f1iff(fn, nn, trace) = calcF1(GT, tLiff{trace}, frames, lim);
        end

        % Same detection after subtracting a moving median baseline
        for trace = 1:size(dFF, 2)
            GT = CTtime{trace}(ismember(CTtime{trace}, frames));
            mdFF(:, trace) = dFF(:, trace)-movmedian(dFF(:, trace), medframe);
            [pLb, mtLb{trace}, mt(trace)] = peakLocation_prominence(movmean(mdFF(:, trace), meanframe), 20, sensor);
            mf1b(fn, nn, trace) = calcF1(GT, mtLb{trace}, frames, lim);
            mdFFfoc(:, trace) = dFFfoc(:, trace)-movmedian(dFFfoc(:, trace), medframe);
            [pLfoc, mtLfoc{trace}] = peakLocation_prominence(movmean(mdFFfoc(:, trace), meanframe), 20, sensor);
            mf1foc(fn, nn, trace) = calcF1(GT, mtLfoc{trace}, frames, lim);
            mdFFiff(:, trace) = dFFiff(:, trace)-movmedian(dFFiff(:, trace), medframe);
            [pLiff, mtLiff{trace}] = peakLocation_prominence(movmean(mdFFiff(:, trace), meanframe), 20, sensor);
            mf1iff(fn, nn, trace) = calcF1(GT, mtLiff{trace}, frames, lim);
        end

        %Plot the identified calcium transients for the no loss of focus,
        %loss of focus, and come into focus conditions
        figure(f1)
        thistab = uitab(tg);
        axes('Parent',thistab);
        panelYLim = zeros(3, 2); % y-limits of the three panels
        subplot(1, 3, 1)
        plotOpticalTraces_transients_compare(dFF, tb, CTtime, tLb, mtLb, 1:5, 20, medframe, meanframe)
        xlabel('Time (s)')
        xlim([0 length(dFF)/20])
        panelYLim(1, :) = ylim;
        subplot(1, 3, 2)
        plotOpticalTraces_transients_compare(dFFfoc, tb, CTtime, tLfoc, mtLfoc, 1:5, 20, medframe, meanframe)
        xlabel('time (s)')
        xlim([0 length(dFF)/20])
        panelYLim(2, :) = ylim;
        subplot(1, 3, 3)
        plotOpticalTraces_transients_compare(dFFiff, tb, CTtime, tLiff, mtLiff, 1:5, 20, medframe, meanframe)
        xlabel('time (s)')
        xlim([0 length(dFF)/20])
        panelYLim(3, :) = ylim;

        % Give all three panels the same y range
        my = max(panelYLim(:));
        for panel = 1:3
            subplot(1, 3, panel)
            ylim([0 my])
            xticklabels([0:20:120])
        end

        %Show an example frame during a calcium transient under different
        %SNR and sigma values for the gaussian filter to approximate loss
        %of focus.
        thistab.Title = sprintf("%i, %i",focr,snr);
        sgtitle(strcat('Gaussian \sigma =', 32, num2str(focr), 32, 'SNR =', 32, num2str(snr)))
        figure(f2)
        % Only every second SNR / sigma combination is shown in the montage
        if iseven(fn) && iseven(nn)
            subplot(length(snr_a)/2, length(focr_a)/2, count)
            % First transient of neuron 1 within 200 frames of frame 1200
            frame = CTtime{1}(find(abs(CTtime{1}-1200) < 200, 1, 'first'));
            mask2image_contour(logical(cgang), focvid(:, :, frame), [0 200])
            axis equal
            x = xlim;
            pixel = 20;
            hold on
            % Inset trace of the transient with scale bars
            xt = linspace(x(2)*0.75, x(2)*0.9, 100);
            plot(xt, -(dFFfoc((-19:80)+frame, 1)-median(dFFfoc((-19:100)+frame, 1)))*80+pixel, 'k')
            plot(xt(1:20), ones(20, 1)*(pixel+3), 'k', 'LineWidth', 2)
            plot(ones(9, 1)*xt(1)-1, (pixel-8):pixel, 'k', 'LineWidth', 2)
            count = count+1;
            title(strcat('Gaussian \sigma =', 32, num2str(focr), 32, 'SNR =', 32, num2str(snr)))
            colormap gray
        end
        % Free the simulated videos before the next parameter combination
        clear origvid focvid
    end
end

%Plot the f1 scores under different conditions
figure
%Baseline accuracy of our algorithm
imagesc(mean(f1b, 3), [0 1])
colormap(whitered)
colorbar
title('F1_{base}')
xticks([1:length(focr_a)])
yticks([1:length(snr_a)])
xticklabels(num2str(focr_a'))
yticklabels(num2str(snr_a'))
xlabel('Gaussian filter \sigma')
ylabel('SNR')

figure
subplot(2, 2, 1)
%Loss of focus
imagesc(mean(f1foc, 3), [0 1])
colormap(whitered)
colorbar
title('F1_{oof}')
xticks([1:length(focr_a)])
yticks([1:length(snr_a)])
xticklabels(num2str(focr_a'))
yticklabels(num2str(snr_a'))
xlabel('Gaussian filter \sigma')
ylabel('SNR')

subplot(2, 2, 2)
%Gain of focus
imagesc(mean(f1iff, 3), [0 1])
colormap(whitered)
colorbar
title('F1_{iff}')
xticks([1:length(focr_a)])
yticks([1:length(snr_a)])
xticklabels(num2str(focr_a'))
yticklabels(num2str(snr_a'))
xlabel('Gaussian filter \sigma')
ylabel('SNR')

subplot(2, 2, 3)
%Loss of focus with median filtering
imagesc(mean(mf1foc, 3), [0 1])
colormap(whitered)
colorbar
title('F1_{oof} with median filtering')
xticks([1:length(focr_a)])
yticks([1:length(snr_a)])
xticklabels(num2str(focr_a'))
yticklabels(num2str(snr_a'))
xlabel('Gaussian filter \sigma')
ylabel('SNR')

subplot(2, 2, 4)
%Gain of focus with median filtering
imagesc(mean(mf1iff, 3), [0 1])
colormap(whitered)
colorbar
title('F1_{iff} with median filtering')
xticks([1:length(focr_a)])
yticks([1:length(snr_a)])
xticklabels(num2str(focr_a'))
yticklabels(num2str(snr_a'))
xlabel('Gaussian filter \sigma')
ylabel('SNR')