import numpy as np
from mne.utils.check import _check_sphere
from mne.viz.topomap import _get_pos_outlines, _draw_outlines
from matplotlib import rcParams
import matplotlib.pyplot as plt
from functools import partial
from .base import RawCLAM, EpochsCLAM
from .misc import df_to_array
import seaborn as sns
from mne.viz import plot_sensors
import pandas as pd
from .stats import _dft
from mne import pick_info, pick_types
def _onpick_sensor(event, fig, ax, pos, ch_names, bads, scatter):
    """Toggle the picked channel in ``bads`` and refresh marker outlines.

    Used as the matplotlib ``pick_event`` callback of the sensor plot. It can
    also be called with ``event=None`` to only redraw the current state.

    Parameters:
    -----------
    event : matplotlib pick event or None
        The pick event. If None, no channel is toggled and only the marker
        edge colors are refreshed.
    fig : matplotlib.figure.Figure
        The figure whose canvas is redrawn.
    ax : matplotlib.axes.Axes
        The axes holding the sensor scatter. Picks from other axes are ignored.
    pos : array-like
        Sensor positions. Not used by this callback.
    ch_names : list of str
        Channel names, in the same order as the scatter points.
    bads : list of str
        The list of bad channels. Modified in place.
    scatter : matplotlib.collections.PathCollection
        The scatter whose edge colors mark bad ('r') and good ('k') channels.
    """
    is_real_click = event is not None

    # A pick that originated in a different axes is none of our business.
    if is_real_click and event.mouseevent.inaxes != ax:
        return

    if is_real_click:
        picked_name = ch_names[event.ind[0]]
        # Clicking a bad channel un-marks it; clicking a good one marks it.
        toggle = bads.remove if picked_name in bads else bads.append
        toggle(picked_name)

    colour_for = {True: 'r', False: 'k'}
    scatter.set_edgecolors([colour_for[name in bads] for name in ch_names])
    fig.canvas.draw()
def set_bads(obj, default_bads=None):
    """Interactive tool to mark bad channels in EEG data.

    Displays a scalp plot with the EEG channels labeled, allowing the user to
    click on channels to toggle them as bad. The call blocks until the plot
    window is closed; the selection is then written to ``obj.info['bads']``.

    Parameters:
    -----------
    obj : RawCLAM or EpochsCLAM object
        The RawCLAM or EpochsCLAM object containing EEG data.
    default_bads : list of str or None
        The channels that are marked bad by default. Names that are not EEG
        channels of ``obj`` are ignored. None means no channel is pre-marked.

    Raises:
    -------
    TypeError:
        If the input object is not an instance of RawCLAM or EpochsCLAM.

    Notes:
    ------
    - This function requires Matplotlib to display the interactive visualization.
    - Bad channels are marked by clicking on the corresponding channel in the plot.
    - The plot window must be closed to finalize the selection of bad channels.
    - The updated list of bad channels replaces ``obj.info['bads']``; it only
      ever contains EEG channels, because only those are shown.
    """
    if not isinstance(obj, (RawCLAM, EpochsCLAM)):
        # TypeError is still caught by callers that catch Exception.
        raise TypeError('set_bads can only be applied to RawCLAM or EpochsCLAM objects')

    # Restrict to EEG channels; exclude=[] keeps channels already marked bad.
    info = pick_info(obj.info, pick_types(obj.info, eeg=True, exclude=[]))
    ch_names = info.ch_names
    sphere = _check_sphere(None, info)

    fig, ax = plt.subplots(1, figsize=(max(rcParams["figure.figsize"]),) * 2)

    # 2D sensor positions projected onto the head sphere, plus head outlines.
    pos, outlines = _get_pos_outlines(
        info, range(len(ch_names)), sphere, to_sphere=True)
    _draw_outlines(ax, outlines)

    scatter = ax.scatter(
        pos[:, 0],
        pos[:, 1],
        picker=True,
        clip_on=False,
        c='k',
        edgecolors='k',
        s=150,
        lw=2,
    )
    ax.set(aspect="equal")
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
    ax.axis("off")

    for ch_name, this_pos in zip(ch_names, pos):
        ax.text(
            this_pos[0],
            this_pos[1] - 0.007,
            ch_name,
            ha="center",
            va="center",
        )

    # Make room on the right if any channel label sticks out of the figure.
    # The figure bbox starts at 0, so its right edge equals its pixel width.
    fig_right = fig.get_window_extent().x1
    renderer = fig.canvas.get_renderer()
    label_right_edges = np.array(
        [text.get_window_extent(renderer=renderer).xmax for text in ax.texts])
    overflow = label_right_edges - fig_right
    if (overflow > 0).any():
        needed_space = overflow.max() / fig_right
        fig.subplots_adjust(right=1 - 1.1 * needed_space)

    candidates = [] if default_bads is None else default_bads
    bads = [ch for ch in candidates if ch in ch_names]

    picker = partial(
        _onpick_sensor,
        fig=fig,
        ax=ax,
        pos=pos,
        ch_names=ch_names,
        bads=bads,
        scatter=scatter
    )
    fig.canvas.mpl_connect("pick_event", picker)
    # Call the update function with a dummy event to draw the initial bads.
    picker(None)

    fig.set_size_inches(10, 12)
    ax.text(
        0.05,
        0.95,
        'Please mark all bad channels by clicking on them.\nClose the window when done',
        transform=ax.transAxes,
        verticalalignment='top',
        horizontalalignment='center',
        fontsize=14,
        weight='bold')
    plt.tight_layout()

    # Block explicitly: in interactive mode a plain plt.show() returns at once
    # and the selection would be stored before the user clicked anything.
    plt.show(block=True)
    obj.info['bads'] = bads
def plot_network_modulation_values(df_network_results, df_network_data, participant_identified_in, participant_applied_to):
    """Plot connectivity values averaged within each modulated network as a
    box-/stripplot featuring each target phase.

    One new figure is created per cluster (row) found for
    ``participant_identified_in``. Nothing is returned.

    Parameters:
    -----------
    df_network_results: pandas.DataFrame
        The dataframe containing the connections which were modulated at the
        participant- or group-level. The columns 'participant', 'p_value' and
        'connections' are read.
    df_network_data: pandas.DataFrame
        The dataframe containing the connectivity matrix for each trial
        and participant. The columns 'participant', 'measure' and 'value'
        are read here; the rest is consumed by ``df_to_array``.
    participant_identified_in : str
        The participant in which the network was identified. This network mask
        will be applied to the data and plotted. Can be 'group' for group-level
        network.
    participant_applied_to : str
        The participant to whose data the network mask should be applied for
        plotting. Can be 'group' for group-level data.
    """
    df_network_results = df_network_results[df_network_results['participant'] == participant_identified_in]
    if participant_applied_to == 'group':
        # NOTE(review): 'participant' is one of the grouping keys, so this
        # only averages rows that are identical in every column but 'value';
        # confirm that pooling (not averaging) participants is intended.
        group_keys = [col for col in df_network_data.columns if col != 'value']
        # The string 'mean' avoids the pandas deprecation of passing np.mean.
        df_network_data = df_network_data.groupby(group_keys).agg({'value': 'mean'}).reset_index()
    else:
        df_network_data = df_network_data[df_network_data['participant'] == participant_applied_to]

    network_data, target_phases = df_to_array(df_network_data)
    # this is now (n_phases, n_epochs, n_chs, n_chs)
    measure = '{}'.format(df_network_data['measure'].iloc[0])

    # Phases given as strings are treated as named conditions, not angles.
    has_string_phases = any(isinstance(ph, str) for ph in target_phases)
    x_label = 'Condition' if has_string_phases else 'Target Phase (°)'

    for ix_cluster, cluster in df_network_results.iterrows():
        p_value = cluster['p_value']  # this is one p-value for the cluster
        # indices marking connections between sensors in the cluster
        ixs_row, ixs_col = np.array(cluster['connections']).T
        # average over all connections of the cluster -> (n_phases, n_epochs)
        cluster_data = network_data[:, :, ixs_row, ixs_col].mean(-1)

        x = np.concatenate([[target_phases[ix]] * len(cluster_data[ix])
                            for ix in range(len(target_phases))])
        x = [ph if isinstance(ph, str) else round(np.rad2deg(ph)) for ph in x]
        y = np.concatenate(cluster_data)
        df_plot = pd.DataFrame({x_label: x, measure: y})
        df_plot_agg = df_plot.sort_values(x_label).groupby(x_label) \
            .agg({measure: 'mean'}).reset_index()

        plt.figure()
        # data= as keyword works on every seaborn version.
        sns.boxplot(
            data=df_plot,
            x=x_label,
            y=measure,
            color='k',
            boxprops=dict(
                alpha=0.5),
            showmeans=True,
            zorder=0,
            showfliers=False)
        sns.stripplot(
            data=df_plot,
            x=x_label,
            y=measure,
            color='r',
            alpha=0.8,
            zorder=1)
        if len(target_phases) > 2:
            # Overlay the sine fit of the per-phase averages.
            avgs = df_plot_agg[measure].to_numpy()
            _dft(avgs, plot_sine=True)
        plt.title('Identified in {}, applied to {}, Cluster {:d}, p = {:.3e}'.format(participant_identified_in,
                                                                                      participant_applied_to,
                                                                                      int(ix_cluster),
                                                                                      p_value))
        sns.despine()
def plot_network_modulation_topo(df_network_results, n_conns, participant, info):
    """Plot modulated network as connections between sensors on a topoplot.

    All clusters found for ``participant`` are drawn into one sensor plot,
    one color per cluster. Nothing is returned.

    Parameters:
    -----------
    df_network_results: pandas.DataFrame
        The dataframe containing the connections which were modulated at the
        participant- or group-level. The columns 'participant', 't_values'
        and 'connections' are read.
    n_conns : int
        The number of connections to plot. The n_conns connections with the
        largest t-values of each cluster will be plotted.
    participant : str
        The participant in which the modulated network was identified. Can be 'group'
        for a group-level plot.
    info : mne.Info
        The Info object for topographic plotting.
    """
    df_network_results = df_network_results[df_network_results['participant'] == participant]
    colors = sns.color_palette("Set2")
    sphere = _check_sphere(None, info)
    # Only the projected 2D sensor positions are needed; outlines are drawn
    # by plot_sensors.
    pos, _ = _get_pos_outlines(info, range(
        len(info.ch_names)), sphere, to_sphere=True)
    fig = plot_sensors(info, show_names=False, show=False)
    fig.set_size_inches(12, 12)
    # Draw on the sensor axes explicitly rather than on pyplot's current axes.
    ax = fig.axes[0]

    for ix_cluster, cluster in df_network_results.iterrows():
        t_values = cluster['t_values']  # this is one t-value per connection in cluster
        connections = cluster['connections']  # index pairs marking connected sensors
        ixs_top_conns = np.argsort(t_values)[::-1][:n_conns]
        # The index label survives the participant filter and may exceed the
        # palette length, so wrap around instead of raising IndexError.
        color = colors[int(ix_cluster) % len(colors)]
        for rank, ix_conn in enumerate(ixs_top_conns):
            ix_ch_1, ix_ch_2 = connections[ix_conn]
            x_coords = [pos[ix_ch_1][0], pos[ix_ch_2][0]]
            y_coords = [pos[ix_ch_1][1], pos[ix_ch_2][1]]
            # Label only the first line so each cluster gets one legend entry.
            label = 'Cluster {:d}'.format(int(ix_cluster)) if rank == 0 else None
            ax.plot(x_coords, y_coords, c=color, label=label)

    # Skip the legend when nothing was drawn to avoid matplotlib's warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(frameon=False)
    ax.set_title('Identified in {}'.format(participant))
    fig.tight_layout()
def plot_modulation_amp_corr(df_amp_phase_results):
    """Plot correlation of amplitude (depth) of modulation of two
    outcome measures across participants.

    Parameters:
    -----------
    df_amp_phase_results: pandas.DataFrame
        The dataframe containing the amplitude (depth) of modulation
        for each participant and outcome measure.
    """
    # NOTE(review): no implementation is visible below the docstring, so a
    # call appears to return None without plotting anything -- confirm.
    # TODO
    # Do for all clusters in dataframe