import numpy as np
from scipy import linalg
import matplotlib.pyplot as plt
from mne.time_frequency import psd_array_welch
from mne.viz import plot_topomap
import mne
import seaborn as sns
from .misc import _get_ixs_goods
from .base import RawCLAM, EpochsCLAM, EpochsCLAMVariable
def _get_lcmv_weights(COV, forward):
COVinv = linalg.pinv(COV)
w = ((COVinv @ forward[:, None])).squeeze() / \
(forward[None, :] @ COVinv @ forward[:, None])
w /= np.sqrt(w @ COV @ w.T)
return w
def _update_dict_with_mean(dct, keys):
mean = np.mean([dct[key] for key in keys], axis=0)
for key in keys:
dct[key] = mean
return dct
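# NOTE(review): a "[docs]" link marker from the rendered source listing stood here; it is not module code.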
def get_target(obj):
    """Compute the target signal from EEG data using LCMV beamforming.

    This function computes the target signal from EEG data using Linearly
    Constrained Minimum Variance (LCMV) beamforming [1]. It extracts the
    target signal by spatially filtering the EEG data based on the forward
    model stored on ``obj`` (``obj.forward_full``) and finally multiplies the
    result by ``obj.flip``.

    For epoched data one spatial filter is computed per marker code, from the
    covariance of all epochs carrying that code. The covariances of all codes
    whose phase is neither ``'open-loop'`` nor ``'no-stim'`` are replaced by
    their common mean before the filters are computed. Marker codes without
    any epoch are skipped. For raw data a single filter is used: if a marker
    definition exists, its covariance is the mean of the per-code epoch
    covariances, otherwise it is the covariance of the whole recording.

    Parameters:
    -----------
    obj : RawCLAM, EpochsCLAM or EpochsCLAMVariable object
        The object containing EEG data.

    Returns:
    --------
    target : ndarray or list of ndarray
        The target signal extracted from the EEG data using LCMV beamforming.
        For EpochsCLAMVariable input a list with one 1-D array per epoch is
        returned (epochs may differ in length); otherwise a single squeezed
        array is returned.

    Raises:
    -------
    Exception:
        If the input object is not an instance of RawCLAM, EpochsCLAM or
        EpochsCLAMVariable.
    ValueError:
        If raw data has a marker definition but none of its codes has epochs.
    KeyError:
        If an epoch carries an event code for which no filter was computed.

    References:
    -----------
    [1] Van Veen, Barry D., and Kevin M. Buckley. "Beamforming: A versatile approach to spatial filtering."
    IEEE assp magazine 5.2 (1988): 4-24.
    """
    if not isinstance(obj, (RawCLAM, EpochsCLAM, EpochsCLAMVariable)):
        raise Exception('get_target can only be applied to RawCLAM, EpochsCLAM, or EpochsCLAMVariable objects')
    ixs_goods = _get_ixs_goods(obj)
    marker_definition = obj.marker_definition

    if isinstance(obj, (EpochsCLAMVariable, mne.Epochs)):
        event_codes = obj.events[:, 2]
        epochs_data = obj.get_data(ixs_goods)
        forward_goods = obj.forward_full[ixs_goods]
        ws = _get_condition_weights(
            marker_definition, epochs_data, event_codes, forward_goods)
        filtered = [ws[code] @ epoch
                    for code, epoch in zip(event_codes, epochs_data, strict=True)]
        if isinstance(obj, EpochsCLAMVariable):
            # return list of 1D arrays (variable length)
            target = [signal.squeeze() for signal in filtered]
        else:
            target = np.array(filtered).squeeze()
    else:
        # TODO: Warn here if raw data contains a mixture of no stimulation and stimulation segments.
        raw_data = obj.get_data(ixs_goods)
        if marker_definition:
            epochs = EpochsCLAM(obj)
            forward_goods = epochs.forward_full[ixs_goods]
            covs = _get_condition_covs(
                marker_definition, epochs.get_data(ixs_goods), epochs.events[:, 2])
            if not covs:
                raise ValueError(
                    'No epochs found for any of the codes in marker_definition')
            COV = np.mean(list(covs.values()), axis=0)
        else:
            forward_goods = obj.forward_full[ixs_goods]
            COV = np.cov(raw_data)
        w = _get_lcmv_weights(COV, forward_goods)
        target = (w @ raw_data).squeeze()

    if isinstance(target, list):
        target = [signal * obj.flip for signal in target]
    else:
        target *= obj.flip
    return target


def _get_condition_covs(marker_definition, epochs_data, event_codes):
    """Return ``{code: covariance}`` for every marker code that has epochs.

    All epochs of one code are concatenated along their last axis before the
    covariance of the real part is computed. ``epochs_data`` may be a 3-D
    array or a list of 2-D arrays of differing length. Codes without a
    matching epoch are left out instead of failing inside ``np.concatenate``.
    """
    covs = {}
    for code in marker_definition:
        matching = [epoch for epoch, event_code in zip(epochs_data, event_codes)
                    if event_code == code]
        if matching:
            covs[code] = np.cov(np.real(np.concatenate(matching, axis=-1)))
    return covs


def _get_condition_weights(marker_definition, epochs_data, event_codes, forward_goods):
    """Return ``{code: LCMV weights}`` for every marker code that has epochs."""
    covs = _get_condition_covs(marker_definition, epochs_data, event_codes)
    # ensure the covariance matrix for closed-loop conditions is the average of all
    # closed-loop conditions to avoid suppression of correlated sources
    cl_codes = [code for code, phase in marker_definition.items()
                if phase not in ('open-loop', 'no-stim') and code in covs]
    if cl_codes:
        _update_dict_with_mean(covs, cl_codes)
    return {code: _get_lcmv_weights(cov, forward_goods)
            for code, cov in covs.items()}
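# NOTE(review): a "[docs]" link marker from the rendered source listing stood here; it is not module code.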
def set_forward(raw, l_freq_noise, h_freq_noise, n_comp=4):
    """Compute and set the forward model for target source reconstruction.

    This method computes and sets the forward model necessary for LCMV beamforming.
    A data-driven approach called Spatio-Spectral Decomposition (SSD) [1] is used to
    find components in the signal that maximize power in the target frequency range
    while minimizing power in the noise frequency range.
    A component must be selected by the user by clicking on the respective power
    spectrum before closing the plot. Clicks outside the power spectra are ignored,
    so the highlighted spectrum is always the one that will be used.

    The selected pattern is written to ``raw.forward_full``; channels that are not
    among the indices returned by ``_get_ixs_goods(raw)`` keep a value of zero.

    Parameters:
    -----------
    raw : RawCLAM object
        The RawCLAM object containing EEG data.
    l_freq_noise : float
        Lower edge of the noise frequency range.
    h_freq_noise : float
        Higher edge of the noise frequency range.
    n_comp : int, optional
        Number of SSD components to plot. Default is 4.

    Returns:
    --------
    None

    Raises:
    -------
    Exception:
        If the input object is not an instance of RawCLAM.
        If the data does not contain frequencies of at least 1 - 40 Hz.
        If no target for stimulation (forward model) is selected by the user.

    References:
    -----------
    [1] Nikulin, Vadim V., Guido Nolte, and Gabriel Curio. "A novel method for reliable and fast extraction of neuronal
    EEG/MEG oscillations on the basis of spatio-spectral decomposition." NeuroImage 55.4 (2011): 1528-1535.
    """
    if not isinstance(raw, RawCLAM):
        raise Exception('set_forward can only be applied to RawCLAM objects')
    sfreq = raw.info['sfreq']
    l_freq = raw.info['highpass']
    h_freq = raw.info['lowpass']
    l_freq_target = raw.l_freq_target
    h_freq_target = raw.h_freq_target
    ixs_goods = _get_ixs_goods(raw)
    if not (l_freq <= 1 and 40 <= h_freq):
        raise Exception(
            'Data should contain frequencies of at least 1 - 40 Hz')

    data_broad = raw.copy().filter(1, 40).get_data(ixs_goods)
    data_signal = raw.copy().filter(
        l_freq_target, h_freq_target).get_data(ixs_goods)
    data_noise = raw.copy().filter(
        l_freq_noise, h_freq_noise).get_data(ixs_goods)
    COV_B = np.cov(data_broad)
    COV_S = np.cov(data_signal)
    COV_N = np.cov(data_noise)

    # SSD: generalized eigendecomposition of signal vs. noise covariance,
    # components ordered by descending eigenvalue.
    evals, evecs = linalg.eig(COV_S, COV_N)
    order = np.argsort(evals)[::-1]
    D = evecs[:, order].T
    # Column i of M is the pattern belonging to the i-th filter (row) of D.
    M = linalg.pinv(D)

    # squeeze=False keeps `axes` two-dimensional even for n_comp == 1, so the
    # [row, column] indexing below is always valid.
    fig, axes = plt.subplots(
        n_comp, 2, figsize=(7, 7),
        gridspec_kw={'width_ratios': [1, 1]}, squeeze=False)
    info_goods = mne.pick_info(raw.info, ixs_goods)
    for ix_row in range(n_comp):
        ax_psd, ax_topo = axes[ix_row, 0], axes[ix_row, 1]
        w = _get_lcmv_weights(COV_B, M[:, ix_row])
        psd, freqs = psd_array_welch(
            w @ data_broad, sfreq, fmin=1, fmax=40, n_fft=int(3 * sfreq))
        ax_psd.semilogy(freqs, psd.flatten(), c='k')
        ax_psd.tick_params(axis='x', labelsize=8)
        ax_psd.tick_params(axis='y', labelsize=5)
        for edge in (l_freq_target, h_freq_target):
            ax_psd.axvline(edge, color='grey', linestyle='--', linewidth=0.5)
        plot_topomap(M[:, ix_row], info_goods, axes=ax_topo,
                     sensors=False, contours=0, show=False)
    plt.suptitle(
        'Please select a target for stimulation by clicking anywhere inside the left plot')
    plt.figtext(0.25, 0, 'Frequency (Hz)')
    plt.figtext(0, 0.5, 'Power (a.u.)', rotation='vertical')
    plt.tight_layout()
    default_linewidth = axes[0, 0].spines['bottom'].get_linewidth()
    ix_comp = None

    def _onclick_ax(event, axes=axes[:, 0], fig=fig):
        """Highlight the clicked power spectrum and remember its index."""
        nonlocal ix_comp
        ix_pressed = None
        for ix, ax in enumerate(axes):
            if ax.contains(event)[0]:
                ix_pressed = ix
        if ix_pressed is None:
            # Click landed outside every power spectrum: keep the current
            # selection and its highlight untouched.
            return
        ix_comp = ix_pressed
        for ix, ax in enumerate(axes):
            if ix == ix_pressed:
                for side in ('top', 'right'):
                    ax.spines[side].set_visible(True)
                for side in ('top', 'right', 'bottom', 'left'):
                    ax.spines[side].set_color('g')
                    ax.spines[side].set_linewidth(2 * default_linewidth)
            else:
                for side in ('top', 'right'):
                    ax.spines[side].set_visible(False)
                for side in ('bottom', 'left'):
                    ax.spines[side].set_color('k')
                    ax.spines[side].set_linewidth(default_linewidth)
        fig.canvas.draw()

    fig.canvas.mpl_connect("button_press_event", _onclick_ax)
    sns.despine()
    plt.show()
    if ix_comp is None:
        raise Exception(
            'No target for stimulation (forward model) was selected')
    forward_full = np.zeros(raw.n_chs)
    forward_full[ixs_goods] = M[:, int(ix_comp)]
    raw.forward_full = forward_full
# from scipy.io import loadmat
# raw = mne.io.read_raw_brainvision('C:\\Users\\hasla\Desktop\\rising_falling_cwm_tims\\data\\P1_DH\\calibration_no_stim.vhdr',preload=True)
# raw = raw.pick_channels(raw.ch_names[:64])
# raw.set_montage('easycap-M1',match_case=False)
# forward_model = loadmat('C:\\Users\\hasla\Desktop\\rising_falling_cwm_tims\\data\\P1_DH\\P_TARGET_64.mat')['P_TARGET_64'].squeeze()
# flip = loadmat('C:\\Users\\hasla\Desktop\\rising_falling_cwm_tims\\data\\P1_DH\\flip.mat')['flip'].squeeze()
# sfreq = raw.info['sfreq']
# mask_bad = forward_model==0
# bads = np.array(raw.ch_names)[:64][mask_bad]
# raw.drop_channels(bads)
# get_forward(raw,8,14,1,40)