#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Classify view
"""
from PyQt5.QtWidgets import QWidget, QVBoxLayout, QLabel, QHBoxLayout, QPushButton, QGridLayout, QCheckBox, \
QDoubleSpinBox, QSpinBox
from utils.elements_selector.elements_selector_controller import multipleSelectorController
from utils.view.separator import create_layout_separator
__author__ = "Lemahieu Antoine"
__copyright__ = "Copyright 2022"
__credits__ = ["Lemahieu Antoine"]
__license__ = "GNU General Public License v3.0"
__maintainer__ = "Lemahieu Antoine"
__email__ = "Antoine.Lemahieu@ulb.be"
__status__ = "Dev"
class classifyView(QWidget):
    """
    Window displaying the parameters used to run a classification on the current dataset.
    """

    def __init__(self, number_of_channels, event_values, event_ids):
        """
        Window displaying the parameters for performing the classification.
        :param number_of_channels: The number of channels in the dataset.
        :type number_of_channels: int
        :param event_values: Events associated to each epoch/trial; the third item (index 2) of each entry is the
        event id of that trial.
        :type event_values: list of list of int
        :param event_ids: Name of the events associated to their id.
        :type event_ids: dict
        """
        super().__init__()
        self.classify_listener = None
        self.number_of_channels = number_of_channels
        self.event_values = event_values
        self.event_ids = event_ids
        # Controllers of the selector windows. NOTE(review): presumably stored on self so the opened windows stay
        # referenced while they are displayed — confirm.
        self.pipeline_selector_controller = None
        self.pipeline_selected = None
        self.events_selector_controller = None
        # None means "no explicit selection": every trial is used (see confirm_classification_trigger).
        self.trials_selected = None

        self.setWindowTitle("Classification")
        self.vertical_layout = QVBoxLayout()
        self.setLayout(self.vertical_layout)

        # Parameters
        self.grid_widget = QWidget()
        self.grid_layout = QGridLayout()
        self.pipeline_selection_button = QPushButton("&Pipelines ...", self)
        self.pipeline_selection_button.clicked.connect(self.pipeline_selection_trigger)
        self.feature_selection = QCheckBox()
        self.number_of_features = QSpinBox()
        self.number_of_features.setRange(1, self.number_of_channels)
        # Default to 20 features, or to every channel when the dataset has fewer than 20 channels.
        self.number_of_features.setValue(min(20, self.number_of_channels))
        self.hyper_tuning = QCheckBox()
        # The k-fold count is an integer shown in a double spin box with 0 decimals; it is converted back to int
        # when the classification is confirmed.
        self.cross_validation_number = QDoubleSpinBox()
        self.cross_validation_number.setValue(5)
        self.cross_validation_number.setMinimum(1)
        self.cross_validation_number.setDecimals(0)

        # Layout of parameters: one (label, widget) pair per grid row.
        parameter_rows = [
            ("Pipeline selection : ", self.pipeline_selection_button),
            ("Feature selection : ", self.feature_selection),
            ("Number of features to select : ", self.number_of_features),
            ("Hyper-parameters tuning : ", self.hyper_tuning),
            ("Cross-validation k-fold : ", self.cross_validation_number),
        ]
        for row, (label_text, widget) in enumerate(parameter_rows):
            self.grid_layout.addWidget(QLabel(label_text), row, 0)
            self.grid_layout.addWidget(widget, row, 1)
        self.grid_widget.setLayout(self.grid_layout)

        # Trial selection
        self.trial_selection_widget = QWidget()
        self.trial_selection_layout = QGridLayout()
        self.trial_selection_label = QLabel("Trials indexes to compute (default : all) :")
        self.trial_selection_indexes = QPushButton("Select by trials indexes")
        self.trial_selection_indexes.clicked.connect(self.trial_selection_indexes_trigger)
        self.trial_selection_events = QPushButton("Select by events")
        self.trial_selection_events.clicked.connect(self.trial_selection_events_trigger)
        self.trial_selection_layout.addWidget(self.trial_selection_label, 0, 0)
        self.trial_selection_layout.addWidget(self.trial_selection_indexes, 0, 1)
        self.trial_selection_layout.addWidget(self.trial_selection_events, 1, 1)
        self.trial_selection_widget.setLayout(self.trial_selection_layout)

        # Cancel Confirm
        self.cancel_confirm_widget = QWidget()
        self.cancel_confirm_layout = QHBoxLayout()
        self.cancel = QPushButton("&Cancel", self)
        self.cancel.clicked.connect(self.cancel_classification_trigger)
        self.confirm = QPushButton("&Confirm", self)
        self.confirm.clicked.connect(self.confirm_classification_trigger)
        self.cancel_confirm_layout.addWidget(self.cancel)
        self.cancel_confirm_layout.addWidget(self.confirm)
        self.cancel_confirm_widget.setLayout(self.cancel_confirm_layout)

        self.vertical_layout.addWidget(self.grid_widget)
        self.vertical_layout.addWidget(create_layout_separator())
        self.vertical_layout.addWidget(self.trial_selection_widget)
        self.vertical_layout.addWidget(create_layout_separator())
        self.vertical_layout.addWidget(self.cancel_confirm_widget)

    # ------------------------------------------------------------------ Plot

    @staticmethod
    def plot_results(classifier):
        """
        Plot the classification results.
        :param classifier: The classifier that did the classification.
        :type classifier: ApplePyClassifier
        """
        # NOTE(review): the meaning of the argument 4 is defined by ApplePyClassifier.show_results — confirm there.
        classifier.show_results(4)

    # ------------------------------------------------------------------ Triggers

    def cancel_classification_trigger(self):
        """
        Send the information to the controller that the computation is cancelled.
        """
        self.classify_listener.cancel_button_clicked()

    def confirm_classification_trigger(self):
        """
        Retrieve the parameters and send the information to the controller.
        If no trial selection was made, every trial index is sent.
        """
        feature_selection = self.feature_selection.isChecked()
        number_of_channels_to_select = self.number_of_features.value()
        hyper_tuning = self.hyper_tuning.isChecked()
        # The double spin box returns a float (with 0 decimals); pass it on as an int.
        cross_val_number = int(self.cross_validation_number.value())
        if self.trials_selected is None:
            # No explicit selection: use every trial.
            trials_selected = list(range(len(self.event_values)))
        else:
            trials_selected = self.trials_selected
        self.classify_listener.confirm_button_clicked(self.pipeline_selected, feature_selection,
                                                      number_of_channels_to_select, hyper_tuning, cross_val_number,
                                                      trials_selected)

    def pipeline_selection_trigger(self):
        """
        Open the multiple selector window.
        The user can select multiple pipelines used for the classification.
        """
        title = "Select the pipelines used for the classification :"
        # Not offered in the selector: 'XdawnCovTSLR', 'Cosp'.
        all_pipelines = ['XdawnCov', 'Xdawn', 'CSP', 'CSP2', 'cov', 'HankelCov', 'CSSP', 'PSD', 'MDM', 'FgMDM']
        self.pipeline_selector_controller = multipleSelectorController(all_pipelines, title, box_checked=True,
                                                                       element_type="pipeline")
        self.pipeline_selector_controller.set_listener(self.classify_listener)

    def trial_selection_indexes_trigger(self):
        """
        Open the multiple selector window.
        The user can select the trials indexes he wants the classification to be computed on.
        Indexes are displayed starting from 1.
        """
        title = "Select the trials indexes used for the classification:"
        indexes_list = [str(i + 1) for i in range(len(self.event_values))]
        self._open_trials_selector(indexes_list, title, "indexes")

    def trial_selection_events_trigger(self):
        """
        Open the multiple selector window.
        The user can select the events he wants the classification to be computed on.
        """
        title = "Select the trial's events used for the classification:"
        events_ids_list = list(self.event_ids.keys())
        self._open_trials_selector(events_ids_list, title, "events")

    def _open_trials_selector(self, elements, title, element_type):
        """
        Open a multiple selector window (all boxes checked by default) used to choose the trials.
        :param elements: Elements displayed in the selector.
        :type elements: list of str
        :param title: Title of the selector window.
        :type title: str
        :param element_type: "indexes" or "events", returned with the selection to distinguish it.
        :type element_type: str
        """
        self.events_selector_controller = multipleSelectorController(elements, title, box_checked=True,
                                                                     element_type=element_type)
        self.events_selector_controller.set_listener(self.classify_listener)

    # ------------------------------------------------------------------ Utils

    def check_element_type(self, elements_selected, element_type):
        """
        Dispatch the elements returned by a multiple selector window according to their type.
        :param elements_selected: Elements selected by the user.
        :type elements_selected: list of str
        :param element_type: "pipeline", "indexes" or "events"; any other value is ignored.
        :type element_type: str
        """
        if element_type == "pipeline":
            self.set_pipeline_selected(elements_selected)
        elif element_type in ("indexes", "events"):
            self.set_trials_selected(elements_selected, element_type)

    # ------------------------------------------------------------------ Setters

    def set_listener(self, listener):
        """
        Set the listener to the controller.
        :param listener: Listener to the controller.
        :type listener: classifyController
        """
        self.classify_listener = listener

    def set_trials_selected(self, elements_selected, element_type):
        """
        Set the trials selected in the multiple selector window.
        :param elements_selected: Trials (1-based indexes as strings) or Events (names) selected.
        :type elements_selected: list of str
        :param element_type: Type of the element selected, used in case multiple element selector windows can be open in
        a window. Can thus distinguish the returned elements.
        :type element_type: str
        """
        if element_type == "indexes":
            # The selector displays 1-based positions; convert them back to 0-based list indexes.
            trials_to_use = [int(trial) - 1 for trial in elements_selected]
        elif element_type == "events":
            # Ids of the selected events, as a set for constant-time membership tests.
            event_ids_selected = {self.event_ids[event] for event in elements_selected}
            # Keep the index of every trial whose event id (third item of the entry) is selected.
            trials_to_use = [index for index, event in enumerate(self.event_values)
                             if event[2] in event_ids_selected]
        else:
            trials_to_use = []
        self.trials_selected = trials_to_use

    def set_pipeline_selected(self, pipeline):
        """
        Set the pipelines selected in the multiple selector window.
        :param pipeline: Pipelines selected
        :type pipeline: list of str
        """
        self.pipeline_selected = pipeline