Source code for clamnibs.sensor

import numpy as np
from scipy import linalg
import mne
from scipy.signal import hilbert
from . import misc
from mne.stats import permutation_cluster_test
from scipy.sparse import coo_matrix
import scipy
from mne.stats import ttest_ind_no_p
from mne.viz.topomap import _get_pos_outlines
from mne.utils.check import _check_sphere
import matplotlib.pyplot as plt
import seaborn as sns
from mne.viz import plot_sensors
import pandas as pd
from .misc import _get_ixs_goods, _get_main_target_phase
from .base import RawCLAM, EpochsCLAM, EpochsCLAMVariable
from .source import get_target
from tqdm import tqdm


def _find_n_nulls(A, B, D, M):
    errors = []
    for n_nulls in range(A.shape[0]):
        DI = np.ones(M.shape[0])
        DI[:n_nulls] = 0
        DI = np.diag(DI)
        P = M.dot(DI).dot(D)
        errors.append(linalg.norm(B - P @ A @ P.T))
    return np.argmin(errors)


def _sass_stacked_data(inst, picks):
    """Return the real part of ``inst``'s data for ``picks`` as one 2-D array.

    Raw data is returned directly; epoched data (fixed- or variable-length)
    is concatenated across epochs along the last (time) axis, giving one
    channels-by-samples array suitable for ``np.cov``.
    """
    if isinstance(inst, RawCLAM):
        return np.real(inst.get_data(picks))
    # Taking the real part for every container type keeps the covariance real
    # even when the epochs hold analytic (complex) signals; for real data this
    # is a no-op.
    return np.concatenate([np.real(ep) for ep in inst.get_data(picks)], axis=-1)


def _sass_apply_projection(inst, projector, picks):
    """Apply ``projector`` in place to the channels ``picks`` of ``inst``'s data."""
    if isinstance(inst, RawCLAM):
        inst._data[picks] = projector @ inst._data[picks]
    elif isinstance(inst, EpochsCLAMVariable):
        # Variable-length epochs are stored per epoch; copy each one before
        # writing so the projected rows never alias another buffer.
        cleaned = []
        for epoch in inst._data:
            epoch = np.copy(epoch)
            epoch[picks] = projector @ epoch[picks]
            cleaned.append(epoch)
        inst._data = cleaned
    else:
        # (k, k) @ (n_epochs, k, n_times) broadcasts over epochs.
        inst._data[:, picks] = projector @ inst._data[:, picks]


def clean_sensor_data(obj_no_stim, obj_stim):
    """
    Use Stimulation Artifact Source Separation (SASS) [1] to remove electric
    stimulation artifacts from EEG sensor data recorded during CLAM-NIBS.

    Channels marked bad in either recording are excluded from the procedure
    and interpolated in ``obj_stim`` after application of SASS. ``obj_stim``
    is modified in place; nothing is returned.

    Parameters:
    -----------
    obj_no_stim : RawCLAM, EpochsCLAM or EpochsCLAMVariable
        The object containing the sensor data without CLAM-NIBS stimulation.
    obj_stim : RawCLAM, EpochsCLAM or EpochsCLAMVariable
        The object containing the sensor data with CLAM-NIBS stimulation.
        Its data are overwritten with the cleaned data.

    Raises:
    -------
    Exception
        If the inputs are not RawCLAM, EpochsCLAM or EpochsCLAMVariable objects.
        If the data with and without CLAM-NIBS have different bandpass filter
        or target frequency range settings.
        If the data are not bandpass-filtered into the target frequency range.

    References:
    -----------
    [1] Haslacher, David, et al. "Stimulation artifact source separation (SASS)
    for assessing electric brain oscillations during transcranial alternating
    current stimulation (tACS)." Neuroimage 228 (2021): 117571.
    """
    supported = (RawCLAM, EpochsCLAM, EpochsCLAMVariable)
    if not (isinstance(obj_no_stim, supported) and isinstance(obj_stim, supported)):
        raise Exception('clean_sensor_data can only be applied to RawCLAM, EpochsCLAM, or EpochsCLAMVariable objects')

    same_settings = (
        obj_no_stim.info['highpass'] == obj_stim.info['highpass']
        and obj_no_stim.info['lowpass'] == obj_stim.info['lowpass']
        and obj_no_stim.l_freq_target == obj_stim.l_freq_target
        and obj_no_stim.h_freq_target == obj_stim.h_freq_target
    )
    if not same_settings:
        raise Exception(
            'Data with and without CLAM-tACS have different bandpass-filter or target frequency range settings')
    filtered_to_target = (obj_no_stim.info['highpass'] == obj_no_stim.l_freq_target
                          and obj_no_stim.info['lowpass'] == obj_no_stim.h_freq_target)
    if not filtered_to_target:
        raise Exception(
            'Data must be filtered into the target frequency range to apply SASS')

    # Use only channels that are good in BOTH recordings, so a channel that
    # is bad in just one of them neither enters the covariance estimates nor
    # gets mixed into the good channels by the projector.
    # NOTE(review): assumes _get_ixs_goods returns the indices of channels not
    # listed in info['bads'], in the same channel order for both objects.
    goods_stim = set(_get_ixs_goods(obj_stim))
    ixs_goods = [ix for ix in _get_ixs_goods(obj_no_stim) if ix in goods_stim]

    A = np.cov(_sass_stacked_data(obj_stim, ixs_goods))
    B = np.cov(_sass_stacked_data(obj_no_stim, ixs_goods))

    # Generalized eigendecomposition of (A, B); components sorted by
    # decreasing eigenvalue, i.e. by how much stronger they are with stimulation.
    eigen_values, eigen_vectors = linalg.eig(A, B)
    order = np.argsort(eigen_values.real)[::-1]
    D = eigen_vectors.real[:, order].T
    M = linalg.pinv(D)

    # Null the leading artifact components and project back to sensor space.
    n_nulls = _find_n_nulls(A, B, D, M)
    keep = np.ones(M.shape[0])
    keep[:n_nulls] = 0
    P = M.dot(np.diag(keep)).dot(D)
    _sass_apply_projection(obj_stim, P, ixs_goods)

    # Every channel excluded from SASS (bad in either recording) must be
    # interpolated, otherwise it would keep the stimulation artifact.
    bads = set(obj_no_stim.info['bads']) | set(obj_stim.info['bads'])
    obj_stim.info['bads'] = [ch for ch in obj_stim.info['ch_names'] if ch in bads]
    obj_stim.interpolate_bads(reset_bads=True)
def _single_trial_pli_matrix(phase):
    """Return the symmetric channel-by-channel phase lag index matrix of one trial.

    Parameters:
    -----------
    phase : numpy.ndarray
        Instantaneous phase of one trial with channels along the first axis
        (presumably shape (n_channels, n_times) -- TODO confirm).

    Returns:
    --------
    numpy.ndarray
        (n_channels, n_channels) matrix. The upper triangle is filled with
        ``misc._pli`` of each channel pair, mirrored to the lower triangle,
        and the diagonal is set to 1.
    """
    n_chs = phase.shape[0]
    conn = np.zeros((n_chs, n_chs))
    for ix1 in range(n_chs):
        for ix2 in range(ix1 + 1, n_chs):
            conn[ix1, ix2] = misc._pli(phase[ix1], phase[ix2])
    conn += conn.T
    # The diagonal is still zero here, so this equals adding an identity matrix.
    np.fill_diagonal(conn, 1.0)
    return conn


def compute_single_trial_connectivity(raw, measure='phase_lag_index', end_codes=None):
    """
    Compute single-trial EEG connectivity and assign it to the CLAM-NIBS target phase.

    Parameters:
    -----------
    raw : RawCLAM
        The RawCLAM object containing the raw data to analyze.
    measure : str, optional
        The connectivity measure to compute. Currently, only 'phase_lag_index'
        is supported (default).
    end_codes : list of int or None, optional
        If provided, use variable-length epochs defined by start markers (from
        marker_definition) to end markers (end_codes). If None, use fixed-length
        epochs (default). Only used when ``raw.design == 'trial_wise'``.

    Returns:
    --------
    pandas.DataFrame
        One row per trial (a single row when the design is not trial-wise)
        with columns 'participant', 'session', 'design', 'target_phase',
        'measure' and 'value', where 'value' holds the channel-by-channel
        connectivity matrix (unit diagonal) of that trial.

    Raises:
    -------
    Exception
        If the input raw data is not of type RawCLAM.
        If the connectivity measure is not supported.
        If the data are not filtered into the target frequency range.
        If the data still contain bad channels.
    """
    if not isinstance(raw, RawCLAM):
        raise Exception('compute_single_trial_connectivity can only be applied to RawCLAM objects')
    if measure not in ['phase_lag_index']:
        raise Exception('Connectivity measure must be "phase_lag_index"')
    if not (raw.info['highpass'] == raw.l_freq_target
            and raw.info['lowpass'] == raw.h_freq_target):
        raise Exception(
            'Data must be filtered into the target frequency range to compute connectivity')
    if len(raw.info['bads']) != 0:
        raise Exception(
            'Bads must be interpolated before connectivity computation')

    marker_definition = raw.marker_definition
    design = raw.design

    if design == 'trial_wise':
        if end_codes is not None:
            epochs = EpochsCLAMVariable(raw, end_codes=end_codes)
        else:
            epochs = EpochsCLAM(raw)
        data_hil = epochs.get_data(picks='eeg')
        # Codes without an entry in marker_definition map to None.
        target_phases = [marker_definition.get(code) for code in epochs.events[:, 2]]
    else:
        # The whole recording is treated as one trial. Raw.__getitem__ only
        # accepts (picks, times), so slice the array from get_data() instead.
        analytic = raw.copy().apply_hilbert(picks='eeg')
        data_hil = analytic.get_data(picks='eeg')[None, :, :]
        events = mne.events_from_annotations(raw)[0]
        target_phases = [_get_main_target_phase(marker_definition, events)]

    # Works for both a 3-D array (fixed-length epochs / continuous data) and a
    # list of 2-D arrays (variable-length epochs).
    phases = [np.angle(trial) for trial in data_hil]
    conns = [_single_trial_pli_matrix(phase)
             for phase in tqdm(phases, desc='Computing single-trial connectivity')]

    n_trials = len(conns)
    return pd.DataFrame({'participant': [raw.participant] * n_trials,
                         'session': [raw.session] * n_trials,
                         'design': [design] * n_trials,
                         'target_phase': target_phases,
                         'measure': [measure] * n_trials,
                         'value': conns})