forked from exporl/adaptive-procedures-comparison
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalculateConvergence.m
98 lines (85 loc) · 3.96 KB
/
calculateConvergence.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
function convergence = calculateConvergence(procedure,values,reversals,nStaircasePlots)
% Calculate the convergence points for simulations of an adaptive
% procedure.
%
% The convergence point of a simulation is the mean of its last
% procedure.nLast trial values (nLastType 'trials') or of its last
% procedure.nLast reversal values (nLastType 'reversals'). NaN entries in
% VALUES are ignored (presumably padding after a staircase that ended
% early). Simulations with fewer than procedure.nLast usable
% trials/reversals are dropped with a warning.
%
% INPUTS:
% procedure        struct that defines the procedure parameters,
%                  with required fields:
%   procedure.nLast      positive integer: number of last trials/reversals
%                        to take into account for the calculation of the
%                        convergence point
%   procedure.nLastType  either 'trials' or 'reversals'
% values           matrix with a staircase for each simulation of
%                  the respective procedure (the amount of rows
%                  corresponds to the amount of simulations that
%                  were done)
% reversals        boolean matrix (same size as VALUES) that specifies for
%                  each value of the staircases whether a reversal
%                  occurred. Only read when nLastType is 'reversals', so
%                  it may be [] for 'trials'.
% nStaircasePlots  (optional, default 0): number of staircases to plot in
%                  the current axes to give an idea of the adaptive track.
%                  The staircases are drawn at random with replacement, and
%                  the values used for the convergence point are marked
%                  with dots.
%
% OUTPUTS:
% convergence      column vector with the estimated convergence points of
%                  the simulations that were not dropped (its length is
%                  the amount of simulations minus the dropped ones)
%
% Author: Benjamin Dieudonné, KU Leuven, Department of Neurosciences, ExpORL
% Correspondence: [email protected]
if nargin<4
    nStaircasePlots = 0;
end

nLast = procedure.nLast;
if ~(isscalar(nLast) && isnumeric(nLast) && nLast>=1 && nLast==round(nLast))
    error('procedure.nLast must be a positive integer');
end

% Mask of the entries that are candidates for the convergence point.
% REVERSALS is only read in the 'reversals' case.
switch procedure.nLastType
    case 'trials'
        isUsable = ~isnan(values);
    case 'reversals'
        isUsable = logical(reversals) & ~isnan(values);
    otherwise
        error('This type of procedure.nLastType is not supported');
end

nSimulations = size(values,1);
convergence = nan(nSimulations,1);
for i = 1:nSimulations
    rowValues = values(i,isUsable(i,:));
    if numel(rowValues) < nLast
        % Too short to average nLast values: drop this experiment.
        warning('Experiment %i of %i is dropped because it does not have enough %s.',i,nSimulations,procedure.nLastType);
    else
        convergence(i) = mean(rowValues(end-nLast+1:end));
    end
end

if nStaircasePlots>0 && nSimulations>0
    plotStaircases(values,isUsable,nLast,nStaircasePlots);
end

% Dropped experiments are still NaN; remove them from the output.
convergence = convergence(~isnan(convergence));
nDropped = nSimulations-numel(convergence);
if nDropped>0
    warning('%i of %i (%0.1f%%) experiments dropped because they did not have enough %s.',nDropped,nSimulations,100*nDropped/nSimulations,procedure.nLastType);
end
end

function plotStaircases(values,isUsable,nLast,nStaircasePlots)
% Plot nStaircasePlots randomly chosen staircases (with replacement) into
% the current axes, and mark the last nLast usable entries of each one.
nSimulations = size(values,1);
plotIDs = randi([1 nSimulations],nStaircasePlots,1);
trialIndices = 1:size(values,2);
hold on;
for iPlotID = 1:numel(plotIDs)
    plotID = plotIDs(iPlotID);
    markerIndices = trialIndices(isUsable(plotID,:));
    % A dropped staircase has fewer than nLast usable entries; mark the ones it has.
    markerIndices = markerIndices(max(1,numel(markerIndices)-nLast+1):end);
    h = plot(trialIndices,values(plotID,:),'LineWidth',1.5);
    plot(markerIndices,values(plotID,markerIndices),'.','color',get(h,'color'),'MarkerSize',10);
end
hold off;
xlabel('Trial index');
end