function varargout = process_beamformer_lcmv( varargin )
% PROCESS_BEAMFORMER_LCMV: Linearly-constrained minimum variance (LCMV) beamformer
%
% Computes a neural activity index (NAI) source map from the input recordings and,
% optionally, saves the corresponding LCMV spatial filters as an imaging kernel.
%
% NOTE(review): calls are routed to the local functions of this file (GetDescription,
% FormatComment, Run) through macro_methodcall, presumably with a syntax such as
% process_beamformer_lcmv('Run', sProcess, sInputs) - confirm against macro_methodcall.

% @=============================================================================
% This software is part of the Brainstorm software:
% http://neuroimage.usc.edu/brainstorm
% 
% Copyright (c)2000-2015 University of Southern California & McGill University
% This software is distributed under the terms of the GNU General Public License
% as published by the Free Software Foundation. Further details on the GPL
% license can be found at http://www.gnu.org/copyleft/gpl.html.
% 
% FOR RESEARCH PURPOSES ONLY. THE SOFTWARE IS PROVIDED "AS IS," AND THE
% UNIVERSITY OF SOUTHERN CALIFORNIA AND ITS COLLABORATORS DO NOT MAKE ANY
% WARRANTY, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO WARRANTIES OF
% MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, NOR DO THEY ASSUME ANY
% LIABILITY OR RESPONSIBILITY FOR THE USE OF THIS SOFTWARE.
%
% For more information type "brainstorm license" at command prompt.
% =============================================================================@
%
% Authors: Hui-Ling Chan, Francois Tadel, Sylvain Baillet, 2014

% Dispatch the call to the requested local function of this file
macro_methodcall;
end


%% ===== GET DESCRIPTION =====
function sProcess = GetDescription() %#ok<DEFNU>
    % GETDESCRIPTION: Build the description structure of the process (metadata + options).
    %
    % The options are assembled in the local structure "opt" and attached at the end.
    % Fields are created in a fixed order. NOTE(review): the options panel presumably
    % lists the options in field order, so keep this order when editing.
    %
    % Cell-valued fields are wrapped in double braces {{...}} so that STRUCT stores
    % the cell itself instead of creating a struct array.

    % === Process identification and accepted inputs/outputs
    sProcess = struct( ...
        'Comment',     'Beamformer: LCMV (20141022) [Experimental]', ...
        'FileTag',     '', ...
        'Category',    'Custom', ...
        'SubGroup',    'Sources', ...
        'Index',       340, ...
        'InputTypes',  {{'data'}}, ...
        'OutputTypes', {{'results'}}, ...
        'nInputs',     1, ...
        'nMinFiles',   1);

    opt = struct();
    % === Free text, used as a prefix for the comment of the output files
    opt.comment = struct('Comment', 'Comment: ', 'Type', 'text', 'Value', '');
    % === Source orientation (eigenvalue replaced by singular value on 2014-10-09)
    opt.label_orient  = struct('Comment', '<HTML><BR>&nbsp;Sources orientation:', 'Type', 'label');
    opt.oriconstraint = struct('Comment', {{'Unconstrained (max singular value)','Unconstrained (trace)', 'Constrained (normal to cortex)'}}, ...
                               'Type',    'radio', ...
                               'Value',   1);
    % === Time range of interest
    opt.label_range = struct('Comment', '<HTML><BR> ', 'Type', 'label');
    opt.nai_range   = struct('Comment', 'Time range of interest: ', 'Type', 'poststim', 'Value', []);
    % === Sliding active window: size and step
    opt.active_window_size = struct('Comment', 'Active window size: ',  'Type', 'value', 'Value', {{0, 'ms', 1}});
    opt.nai_tresolution    = struct('Comment', 'Temporal resolution: ', 'Type', 'value', 'Value', {{0, 'ms', 1}});
    % === Regularization, in percent (default increased on 2014-10-09)
    opt.reg = struct('Comment', 'Regularization parameter: ', 'Type', 'value', 'Value', {{1, '%', 4}});
    % === Channel selection
    opt.label_sensor = struct('Comment', '<HTML><BR> ', 'Type', 'label');
    opt.sensortypes  = struct('Comment', 'Sensor types or names (empty=all): ', 'Type', 'text', 'Value', 'MEG');
    % === Optionally save the spatial filters as an imaging kernel
    opt.savefilter = struct('Comment', 'Save spatial filters', 'Type', 'checkbox', 'Value', 0);

    sProcess.options = opt;
end


%% ===== FORMAT COMMENT =====
function Comment = FormatComment(sProcess) %#ok<DEFNU>
    % FORMATCOMMENT: Return the label of the process, i.e. its "Comment" field unchanged.
    Comment = sProcess.('Comment');
end


%% ===== RUN =====
function OutputFiles = Run(sProcess, sInputs) %#ok<DEFNU>
    % RUN: Compute the LCMV neural activity index (NAI) for the input recordings.
    %
    % INPUTS:
    %    - sProcess : Process structure, with the options defined in GetDescription
    %    - sInputs  : Array of input structures (fields used: FileName, ChannelFile, iStudy)
    % OUTPUT:
    %    - OutputFiles : Cell array of the created files
    %                    {1} NAI source map (saved in the study of the first input)
    %                    {2} spatial filter kernel (only if "savefilter" is checked,
    %                        saved in the channel study)
    %                    Empty if the process stopped on an error.
    %
    % NOTES:
    %    - The data covariance over the whole time range of interest is used (with
    %      regularization) to build the spatial filters.
    %    - One data covariance per sliding active window is used for the NAI time course.
    %    - The default head model and the noise covariance of the channel study are required.
    %    - Time values (range, window size, resolution) are handled in seconds.

    % Initialize returned list of files
    OutputFiles = {};
    % Get option values
    NAIRange         = sProcess.options.nai_range.Value{1};
    Reg              = sProcess.options.reg.Value{1};
    SensorTypes      = sProcess.options.sensortypes.Value;
    ActiveWindowSize = sProcess.options.active_window_size.Value{1};
    NAITResolu       = sProcess.options.nai_tresolution.Value{1};
    Comment          = sProcess.options.comment.Value;
    OrientConstr     = sProcess.options.oriconstraint.Value;
    isSaveFilter     = sProcess.options.savefilter.Value;

    % ===== LOAD THE DATA =====
    bst_progress('start', 'Beamformer: LCMV', 'Loading input data...');
    % Read the first file in the list, to initialize the loop
    DataMat = in_bst_data(sInputs(1).FileName);
    nChannels = size(DataMat.F,1);
    nTime     = size(DataMat.F,2);
    Time = DataMat.Time;
    % Lock input time window on the real time vector
    iNAIRange = panel_time('GetTimeIndices', Time, NAIRange);
    if isempty(iNAIRange)
        bst_report('Error', sProcess, sInputs, 'The setting of time range of interest is incorrect.');
        return;
    end
    NAIRange = [Time(iNAIRange(1)), Time(iNAIRange(end))];

    % ===== CHECK THE TIME WINDOWS =====
    % Verify the coherence of the inputs
    if (NAIRange(1) > NAIRange(2))
        bst_report('Error', sProcess, sInputs, 'The setting of time range of interest is incorrect.');
        return;
    end
    if (NAIRange(1) < Time(1))
        bst_report('Warning', sProcess, sInputs, 'The start for time range of interest is reset to the first time point of data');
        NAIRange(1) = Time(1);
    end
    if (NAIRange(2) > Time(end))
        bst_report('Warning', sProcess, sInputs, 'The end for time range of interest is reset to the end point of data');
        NAIRange(2) = Time(end);
    end
    if ((NAIRange(1)+ActiveWindowSize) > NAIRange(2))
        bst_report('Warning', sProcess, sInputs, 'The active window size is reset to the same as the time range of interest.');
        ActiveWindowSize = NAIRange(2) - NAIRange(1);
    end
    % Temporal resolution=0 or window size=0: a single active window covering the whole range of interest
    if (NAITResolu == 0) || (ActiveWindowSize == 0)
        NAITResolu = 1;
        ActiveWindowSize = NAIRange(2) - NAIRange(1);
    end

    % Number of sliding active windows
    NAIPoints = panel_time('GetTimeIndices', Time, [NAIRange(1)+ActiveWindowSize, NAIRange(2)]);
    if (length(NAIPoints) <= 1)
        nNAI = 1;
    else
        nNAI = length((NAIRange(1)+ActiveWindowSize):NAITResolu:NAIRange(2));
    end
    % [start, end] of each active window, shifted by NAITResolu
    NAITimeList = zeros(nNAI,2);
    for i = 1:nNAI
        NAITimeList(i,:) = NAIRange(1) + NAITResolu*(i-1) + [0 ActiveWindowSize];
    end
    % Time range used for the covariance of the spatial filter
    MinVarTime = [NAIRange(1) NAIRange(2)];

    % Time indices of each active window (all windows forced to the same number of samples)
    nActiveWindowPoints = length(panel_time('GetTimeIndices', Time, NAITimeList(1,:)));
    iActiveWindowTime = zeros(nNAI, nActiveWindowPoints);
    for i = 1:nNAI
        if (NAITimeList(i,1) < Time(1) || NAITimeList(i,2) > Time(end))
            bst_report('Error', sProcess, sInputs, 'One active time window is not within the data time range.');
            return;
        end
        single_iActiveWindow = panel_time('GetTimeIndices', Time, NAITimeList(i,:));
        % Window one sample short: rounding error caused by a non-integer sampling rate (added 2014-10-08)
        if (nActiveWindowPoints - length(single_iActiveWindow) == 1)
            single_iActiveWindow(nActiveWindowPoints) = single_iActiveWindow(end) + 1;
        end
        iActiveWindowTime(i,:) = single_iActiveWindow(1:nActiveWindowPoints);
    end
    iMinVarTime = panel_time('GetTimeIndices', Time, MinVarTime);

    % ===== LOAD CHANNEL FILE =====
    % Load channel file
    ChannelMat = in_bst_channel(sInputs(1).ChannelFile);
    % Get the channels of interest
    iChannels = channel_find(ChannelMat.Channel, SensorTypes);
    if isempty(iChannels)
        bst_report('Error', sProcess, sInputs, ['No channels found for sensor types: ' SensorTypes]);
        return;
    end

    % ===== LOAD HEAD MODEL =====
    % Get channel study
    [sChannelStudy, iChannelStudy] = bst_get('ChannelFile', sInputs(1).ChannelFile);
    if isempty(sChannelStudy) || isempty(sChannelStudy.HeadModel) || isempty(sChannelStudy.iHeadModel)
        bst_report('Error', sProcess, sInputs, 'No head model available for this input. Please compute one first.');
        return;
    end
    % Load the default head model
    HeadModelFile = sChannelStudy.HeadModel(sChannelStudy.iHeadModel).FileName;
    sHeadModel = in_headmodel_bst(HeadModelFile);
    if (OrientConstr == 3) && (isequal(sHeadModel.HeadModelType,'volume') || isempty(sHeadModel.GridOrient))
        bst_report('Error', sProcess, sInputs, 'No dipole orientation for cortical constrained beamformer estimation.');
        return;
    end

    % ===== LOAD NOISE COVARIANCE =====
    % Checked before reading all the files, to fail early
    if isempty(sChannelStudy.NoiseCov) || isempty(sChannelStudy.NoiseCov(1).FileName)
        bst_report('Error', sProcess, sInputs, 'No noise covariance matrix available for this input. Please compute one first.');
        return;
    end
    NoiseCovMat = load(file_fullpath(sChannelStudy.NoiseCov(1).FileName));

    % ===== LOAD ALL INPUT FILES =====
    % Initialize the covariance matrices (accumulators + number of samples per entry)
    MinVarCov =  zeros(nChannels, nChannels);
    nTotalMinVar = zeros(nChannels, nChannels);
    NAIActiveCov = zeros(nChannels, nChannels, nNAI);
    nTotalNAIActive = zeros(nChannels, nChannels, nNAI);
    bst_progress('start', 'Beamformer: LCMV', 'Calulating covariance matrix...', 0, length(sInputs));
    for i = 1:length(sInputs)
        % Read the file #i
        DataMat = in_bst_data(sInputs(i).FileName);
        % Check the dimensions of the recordings matrix in this file
        if (size(DataMat.F,1) ~= nChannels) || (size(DataMat.F,2) ~= nTime)
            bst_report('Error', sProcess, sInputs, 'One file has a different number of channels or a different number of time samples.');
            return;
        end
        % Get good channels
        iGoodChan = find(DataMat.ChannelFlag == 1);

        % Covariance over the whole time range of interest (mean removed)
        FavgActive = mean(DataMat.F(iGoodChan,iMinVarTime), 2);
        DataActive = bst_bsxfun(@minus, DataMat.F(iGoodChan,iMinVarTime), FavgActive);
        % Weighted by the number of averaged trials of the file
        fileActiveCov = DataMat.nAvg .* (DataActive * DataActive');
        MinVarCov(iGoodChan,iGoodChan) = MinVarCov(iGoodChan,iGoodChan) + fileActiveCov;
        nTotalMinVar(iGoodChan,iGoodChan) = nTotalMinVar(iGoodChan,iGoodChan) + length(iMinVarTime);

        % Covariance of each active window (mean removed)
        for j = 1:nNAI
            FavgNAIActive = mean(DataMat.F(iGoodChan,iActiveWindowTime(j,:)), 2);
            DataNAIActive = bst_bsxfun(@minus, DataMat.F(iGoodChan,iActiveWindowTime(j,:)), FavgNAIActive);
            fileNAIActiveCov = DataMat.nAvg * (DataNAIActive * DataNAIActive');
            NAIActiveCov(iGoodChan,iGoodChan,j) = NAIActiveCov(iGoodChan,iGoodChan,j) + fileNAIActiveCov;
            nTotalNAIActive(iGoodChan,iGoodChan,j) = nTotalNAIActive(iGoodChan,iGoodChan,j) + nActiveWindowPoints;
        end
        bst_progress('inc',1);
    end

    % Avoid divisions by zero for entries that were never accumulated (bad channels)
    nTotalMinVar(nTotalMinVar <= 1) = 2;
    nTotalNAIActive(nTotalNAIActive <= 1) = 2;
    % Divide final matrix by number of samples
    MinVarCov = MinVarCov ./ (nTotalMinVar - 1);
    NAIActiveCov = NAIActiveCov ./ (nTotalNAIActive - 1);

    % ===== PROCESS =====
    % Extract the covariance matrices of the used channels
    NoiseCov     = NoiseCovMat.NoiseCov(iChannels,iChannels);
    NAIActiveCov = NAIActiveCov(iChannels,iChannels,:);
    MinVarCov    = MinVarCov(iChannels,iChannels);
    nSelChannels = length(iChannels);

    bst_progress('start', 'Beamformer: LCMV', 'Calulating inverse covariance matrix...', 0, nNAI+2);
    % Pseudo-inverse of the noise covariance (no regularization here)
    Cn_inv = pinv(NoiseCov);
    bst_progress('inc',1);
    % Pseudo-inverse of the data covariance of each active window (no regularization here)
    Cm_inv = zeros(nSelChannels, nSelChannels, nNAI);
    for i = 1:nNAI
        Cm_inv(:,:,i) = pinv(NAIActiveCov(:,:,i));
        bst_progress('inc',1);
    end
    % Regularized inverse of the data covariance of the whole range (spatial filters)
    MinVarCov_inv = regulatePInv(MinVarCov,Reg);
    bst_progress('inc',1);

    % Get forward field
    if (OrientConstr == 3) % constrained LCMV
        Kernel = bst_gain_orient(sHeadModel.Gain(iChannels,:), sHeadModel.GridOrient);
    else
        Kernel = sHeadModel.Gain(iChannels,:);
    end
    Kernel(abs(Kernel(:)) < eps) = eps; % Set zero elements to strictly non-zero

    % ===== CALCULATE NEURAL ACTIVITY INDEX AND FILTERS =====
    % fcn_lcmv converts the matrix c = G'*C^-1*G of one source into a scalar power
    switch (OrientConstr)
        % Unconstrained: largest singular value of the rank-2 pseudo-inverse of c
        case 1
            nComponents = 3;
            strConstr = 'Unconstr/svd'; % eig -> svd, corrected in 20141009
            fcn_lcmv = @(c)lambda1(lcmvPInv(c));
        % Unconstrained: trace of the rank-2 pseudo-inverse of c
        case 2
            nComponents = 3;
            strConstr = 'Unconstr/trace';
            fcn_lcmv = @(c)trace(lcmvPInv(c));
        % Constrained: c is a scalar, power = 1/c
        case 3
            nComponents = 1;
            strConstr = 'Constr';
            fcn_lcmv = @(c)1./c;
        otherwise
            bst_report('Error', sProcess, sInputs, 'Invalid source orientation option.');
            return;
    end
    % Number of sources (Kernel has nComponents columns per source)
    nSources = size(Kernel, 2) / nComponents;
    nBlockProgress = 300;
    bst_progress('start', 'Beamformer: LCMV', 'Calculating neural activity index...', 0, round(nSources/nBlockProgress));
    % Initialize variables
    NAI = zeros(nSources, nNAI);
    if isSaveFilter
        % One row per source component
        spatialFilter = zeros(nComponents * nSources, nSelChannels);
    end
    % Loop on sources
    for i = 1:nSources
        % Columns of this source in the gain matrix
        iu = nComponents * (i-1) + (1:nComponents);
        Gi = Kernel(:,iu);
        % Noise power at this source: based on (G' Cn^-1 G)^-1
        NoisePower = fcn_lcmv(Gi' * Cn_inv * Gi);
        % For each active window
        for j = 1:nNAI
            % Source power at this source: based on (G' Cm^-1 G)^-1
            SourcePower = fcn_lcmv(Gi' * Cm_inv(:,:,j) * Gi);
            % Neural activity index = source power / noise power
            NAI(i,j) = SourcePower / NoisePower;
        end
        % Spatial filter
        if isSaveFilter
            % W = (G' C^-1 G)^-1 G' C^-1, with C the regularized full-range covariance
            A = Gi' * MinVarCov_inv;
            W = pinv(A * Gi) * A;
            % Normalize the spatial filter using the noise power
            spatialFilter(iu,:) = W / sqrt(NoisePower);
        end
        % Increase by one every block of sources
        if (mod(i,nBlockProgress) == 0)
            bst_progress('inc',1);
        end
    end

    % ===== ASSIGN MAPS AND KERNELS =====
    bst_progress('start', 'Beamformer: LCMV', 'Saving Results...');
    if (nNAI == 1)
        % Single window: constant map over two time points
        ImageGridAmp = [NAI, NAI];
        TimeIndex = [Time(1), Time(end)];
    else
        % Linear interpolation between window centers, to get the temporal resolution of the data
        ImageGridAmpOriginal = NAI;
        ImageGridAmp = zeros(nSources, nTime);
        for i = 1:(nNAI-1)
            InterpolateTimeWindow = NAIRange(1) + ActiveWindowSize/2 + [(i-1)*NAITResolu, i*NAITResolu];
            iInterpolateTime = panel_time('GetTimeIndices', Time, InterpolateTimeWindow);
            nInterpolateTime = length(iInterpolateTime);

            ImageGridAmp(:,iInterpolateTime(1)) = ImageGridAmpOriginal(:,i);
            ImageGridAmp(:,iInterpolateTime(end)) = ImageGridAmpOriginal(:,i+1);

            for j = 2:(nInterpolateTime-1)
                InterpolatePercentage = (j-1)/(nInterpolateTime-1);
                ImageGridAmp(:,iInterpolateTime(j)) = ImageGridAmpOriginal(:,i+1)*InterpolatePercentage + ImageGridAmpOriginal(:,i)*(1-InterpolatePercentage);
            end
        end
        TimeIndex = Time;
    end

    % ===== NAI: SAVE THE RESULTS =====
    % Comment
    if ~isempty(Comment)
        Comment = [Comment ': '];
    end
    if (nNAI == 1)
        Comment1 = sprintf('%sLCMV (%s, %d-%dms)', Comment, strConstr, round(NAIRange(1)*1000), round(NAIRange(2)*1000));
    else
        Comment1 = sprintf('%sLCMV (%s, %d-%dms, ws:%dms, tr:%dms)', Comment, strConstr, round((NAIRange(1)+ActiveWindowSize/2)*1000), ...
            round((NAIRange(1) + ActiveWindowSize/2 + (nNAI-1)*NAITResolu)*1000), round(ActiveWindowSize*1000), round(NAITResolu*1000));
    end
    % Create a new data file structure
    ResultsMat = db_template('resultsmat');
    ResultsMat.ImagingKernel = [];
    ResultsMat.ImageGridAmp  = ImageGridAmp;
    ResultsMat.Comment       = Comment1;
    ResultsMat.nComponents   = 1;   % 1 or 3
    ResultsMat.Function      = 'LCMVneuralActivityIndex';
    ResultsMat.Time          = TimeIndex;           % Leave it empty if using ImagingKernel
    ResultsMat.DataFile      = [];
    ResultsMat.HeadModelFile = HeadModelFile;
    ResultsMat.HeadModelType = sHeadModel.HeadModelType;
    ResultsMat.ChannelFlag   = [];
    ResultsMat.GoodChannel   = iChannels;
    ResultsMat.SurfaceFile   = sHeadModel.SurfaceFile;
    ResultsMat.GridLoc       = []; %sHeadModel.GridLoc
    % Get the output study (pick the one from the first file)
    iStudy = sInputs(1).iStudy;
    % Create a default output filename
    OutputFiles{1} = bst_process('GetNewFilename', fileparts(sInputs(1).FileName), 'results_LCMV_NAI');
    % Save on disk
    save(OutputFiles{1}, '-struct', 'ResultsMat');
    % Register in database
    db_add_data(iStudy, OutputFiles{1}, ResultsMat);

    % ===== SPATIAL FILTER: SAVE FILE =====
    if isSaveFilter
        % Set file comment
        Comment2 = sprintf('%sLCMV/filter (%s, %d-%dms)', Comment, strConstr, round(MinVarTime(1)*1000), round(MinVarTime(2)*1000));
        % Create new structure
        ResultsMat2 = db_template('resultsmat');
        ResultsMat2.ImagingKernel = spatialFilter;
        ResultsMat2.ImageGridAmp  = [];
        ResultsMat2.Comment       = Comment2;
        ResultsMat2.Function      = 'LCMVspatialfilter';
        ResultsMat2.Time          = [];           % Leave it empty if using ImagingKernel
        ResultsMat2.DataFile      = [];
        ResultsMat2.HeadModelFile = HeadModelFile;
        ResultsMat2.HeadModelType = sHeadModel.HeadModelType;
        ResultsMat2.nComponents   = nComponents;
        ResultsMat2.ChannelFlag   = [];
        ResultsMat2.GoodChannel   = iChannels;
        ResultsMat2.SurfaceFile   = sHeadModel.SurfaceFile;
        ResultsMat2.GridLoc       = sHeadModel.GridLoc;
        % The kernel is stored in the channel study (shared kernel)
        iStudy = iChannelStudy;
        OutputFiles{2} = bst_process('GetNewFilename', fileparts(sInputs(1).ChannelFile), 'results_LCMV_KERNEL');
        % Save on disk
        save(OutputFiles{2}, '-struct', 'ResultsMat2');
        % Register in database
        db_add_data(iStudy, OutputFiles{2}, ResultsMat2);
    end
end



%% ===== TRUNCATED PSEUDO-INVERSE =====
function X = lcmvPInv(A)
    % LCMVPINV: Rank-2 truncated pseudo-inverse, used on the 3x3 matrices G'*C^-1*G
    % of the unconstrained beamformers.
    % Most head models only have rank 2 at each vertex, so only the two largest
    % singular values are inverted and the remaining ones are discarded.
    [leftVec, singVal, rightVec] = svd(A, 0);
    singVal = diag(singVal);
    iKeep   = 1:2;
    invSing = diag(1 ./ singVal(iKeep));
    X = rightVec(:, iKeep) * invSing * leftVec(:, iKeep)';
end

function X = regulatePInv(A,Reg)
    % REGULATEPINV: Regularized inverse of a square (covariance) matrix.
    %
    %     X = inv(A + alpha*I),  with  alpha = (Reg/100) * (largest eigenvalue of A)
    %
    % INPUTS:
    %    - A   : Square matrix, expected to be symmetric (covariance matrix)
    %    - Reg : Regularization parameter, in percent of the largest eigenvalue
    % OUTPUT:
    %    - X   : Regularized inverse of A
    %
    % The eigenvalues are taken from the symmetric part of A. For an exactly symmetric A
    % this is bit-identical to eig(A). When A is symmetric only up to rounding errors,
    % it guarantees real eigenvalues. Otherwise eig may return complex values, which
    % would make alpha and the inverse complex.
    eigValues = eig((A + A') / 2);
    Reg_alpha = (Reg/100) * max(eigValues);
    X = inv(A + Reg_alpha*eye(size(A)));
end

function X = lambda1(A)
    % LAMBDA1: Largest singular value of matrix A.
    % Only the singular values are computed: the singular vectors are not needed.
    % svd returns the singular values in decreasing order, so the first one is the largest.
    s = svd(A);
    X = s(1);
end



