Source code for neuroconv.datainterfaces.ecephys.baserecordingextractorinterface

from typing import Literal, Optional, Union

import numpy as np
from pynwb import NWBFile
from pynwb.device import Device
from pynwb.ecephys import ElectricalSeries, ElectrodeGroup

from ...baseextractorinterface import BaseExtractorInterface
from ...utils import (
    DeepDict,
    get_base_schema,
    get_schema_from_hdmf_class,
)


class BaseRecordingExtractorInterface(BaseExtractorInterface):
    """Parent class for all RecordingExtractorInterfaces."""

    keywords = ("extracellular electrophysiology", "voltage", "recording")

    ExtractorModuleName = "spikeinterface.extractors"

    def __init__(self, verbose: bool = False, es_key: str = "ElectricalSeries", **source_data):
        """
        Parameters
        ----------
        verbose : bool, default: False
            If True, will print out additional information.
        es_key : str, default: "ElectricalSeries"
            The key of this ElectricalSeries in the metadata dictionary.
        source_data : dict
            The key-value pairs of extractor-specific arguments.
        """
        super().__init__(**source_data)
        self.recording_extractor = self._extractor_instance
        property_names = self.recording_extractor.get_property_keys()
        # TODO remove this and go and change all the uses of channel_name once spikeinterface > 0.101.0 is released
        if "channel_name" not in property_names and "channel_names" in property_names:
            # Move the plural "channel_names" property over to the singular "channel_name" key.
            channel_names = self.recording_extractor.get_property("channel_names")
            self.recording_extractor.set_property("channel_name", channel_names)
            self.recording_extractor.delete_property("channel_names")

        self.subset_channels = None
        self.verbose = verbose
        self.es_key = es_key
        # Cached once; every segment-aware method below branches on this value.
        self._number_of_segments = self.recording_extractor.get_num_segments()

    def get_metadata_schema(self) -> dict:
        """
        Compile metadata schema for the RecordingExtractor.

        Returns
        -------
        dict
            The metadata schema dictionary containing definitions for Device, ElectrodeGroup, Electrodes,
            and optionally ElectricalSeries.
        """
        metadata_schema = super().get_metadata_schema()
        metadata_schema["properties"]["Ecephys"] = get_base_schema(tag="Ecephys")
        metadata_schema["properties"]["Ecephys"]["required"] = ["Device", "ElectrodeGroup"]
        metadata_schema["properties"]["Ecephys"]["properties"] = dict(
            Device=dict(type="array", minItems=1, items={"$ref": "#/properties/Ecephys/definitions/Device"}),
            ElectrodeGroup=dict(
                type="array", minItems=1, items={"$ref": "#/properties/Ecephys/definitions/ElectrodeGroup"}
            ),
            Electrodes=dict(
                type="array",
                minItems=0,
                renderForm=False,
                items={"$ref": "#/properties/Ecephys/definitions/Electrodes"},
            ),
        )
        # Schema definition for arrays
        metadata_schema["properties"]["Ecephys"]["definitions"] = dict(
            Device=get_schema_from_hdmf_class(Device),
            ElectrodeGroup=get_schema_from_hdmf_class(ElectrodeGroup),
            Electrodes=dict(
                type="object",
                additionalProperties=False,
                required=["name"],
                properties=dict(
                    name=dict(type="string", description="name of this electrodes column"),
                    description=dict(type="string", description="description of this electrodes column"),
                ),
            ),
        )

        # The ElectricalSeries entry is only advertised when this interface has a key for it.
        if self.es_key is not None:
            metadata_schema["properties"]["Ecephys"]["properties"].update(
                {self.es_key: get_schema_from_hdmf_class(ElectricalSeries)}
            )
        return metadata_schema

    def get_metadata(self) -> DeepDict:
        """
        Build the default metadata for this recording.

        Adds an ``Ecephys`` entry containing a single placeholder device, one electrode group per
        distinct channel group name and, when ``es_key`` is not None, a default entry for the
        ElectricalSeries stored under that key.

        Returns
        -------
        DeepDict
            The metadata of the parent class extended with the ``Ecephys`` entry.
        """
        metadata = super().get_metadata()

        from ...tools.spikeinterface.spikeinterface import _get_group_name

        channel_groups_array = _get_group_name(recording=self.recording_extractor)
        if channel_groups_array is not None:
            # De-duplicate while keeping first-appearance order. A plain `set` would make the order of
            # the ElectrodeGroup entries depend on string hash randomization and differ between runs.
            unique_channel_groups = list(dict.fromkeys(channel_groups_array))
        else:
            unique_channel_groups = ["ElectrodeGroup"]

        electrode_metadata = [
            dict(name=str(group_id), description="no description", location="unknown", device="DeviceEcephys")
            for group_id in unique_channel_groups
        ]

        metadata["Ecephys"] = dict(
            Device=[dict(name="DeviceEcephys", description="no description")],
            ElectrodeGroup=electrode_metadata,
        )

        if self.es_key is not None:
            metadata["Ecephys"][self.es_key] = dict(
                name=self.es_key, description=f"Acquisition traces for the {self.es_key}."
            )

        return metadata

    @property
    def channel_ids(self):
        """Gets the channel ids of the data."""
        return self.recording_extractor.get_channel_ids()

    def get_original_timestamps(self) -> Union[np.ndarray, list[np.ndarray]]:
        """
        Retrieve the original unaltered timestamps for the data in this interface.

        This function should retrieve the data on-demand by re-initializing the IO.

        Returns
        -------
        timestamps: numpy.ndarray or list of numpy.ndarray
            The timestamps for the data stream; if the recording has multiple segments, then a list of timestamps is returned.
        """
        # "verbose" and "es_key" are filtered out before the extractor is re-instantiated.
        new_recording = self.get_extractor()(
            **{
                keyword: value
                for keyword, value in self.extractor_kwargs.items()
                if keyword not in ["verbose", "es_key"]
            }
        )
        if self._number_of_segments == 1:
            return new_recording.get_times()
        else:
            return [
                new_recording.get_times(segment_index=segment_index)
                for segment_index in range(self._number_of_segments)
            ]

    def get_timestamps(self) -> Union[np.ndarray, list[np.ndarray]]:
        """
        Retrieve the timestamps for the data in this interface.

        Returns
        -------
        timestamps: numpy.ndarray or list of numpy.ndarray
            The timestamps for the data stream; if the recording has multiple segments, then a list of timestamps is returned.
        """
        if self._number_of_segments == 1:
            return self.recording_extractor.get_times()
        else:
            return [
                self.recording_extractor.get_times(segment_index=segment_index)
                for segment_index in range(self._number_of_segments)
            ]

    def set_aligned_timestamps(self, aligned_timestamps: np.ndarray):
        """
        Replace all timestamps of a single-segment recording.

        Parameters
        ----------
        aligned_timestamps : numpy.ndarray
            The timestamps to set on the recording extractor.

        Raises
        ------
        AssertionError
            If the recording has more than one segment; use `set_aligned_segment_timestamps` in that case.
        """
        assert (
            self._number_of_segments == 1
        ), "This recording has multiple segments; please use 'set_aligned_segment_timestamps' instead."

        self.recording_extractor.set_times(times=aligned_timestamps, with_warning=False)

    def set_aligned_segment_timestamps(self, aligned_segment_timestamps: list[np.ndarray]):
        """
        Replace all timestamps for all segments in this interface with those aligned to the common session start time.

        Must be in units seconds relative to the common 'session_start_time'.

        Parameters
        ----------
        aligned_segment_timestamps : list of numpy.ndarray
            The synchronized timestamps for segment of data in this interface.

        Raises
        ------
        AssertionError
            If the input is not a list, or its length differs from the number of segments.
        """
        assert isinstance(
            aligned_segment_timestamps, list
        ), "Recording has multiple segment! Please pass a list of timestamps to align each segment."
        assert (
            len(aligned_segment_timestamps) == self._number_of_segments
        ), f"The number of timestamp vectors ({len(aligned_segment_timestamps)}) does not match the number of segments ({self._number_of_segments})!"

        for segment_index in range(self._number_of_segments):
            self.recording_extractor.set_times(
                times=aligned_segment_timestamps[segment_index],
                segment_index=segment_index,
                with_warning=False,
            )

    def set_aligned_starting_time(self, aligned_starting_time: float):
        """
        Shift every timestamp of the recording by a constant offset.

        The offset is added to the current timestamps of each segment, and the shifted values are
        written back to the recording extractor.

        Parameters
        ----------
        aligned_starting_time : float
            The value added to every timestamp of every segment.
        """
        if self._number_of_segments == 1:
            self.set_aligned_timestamps(aligned_timestamps=self.get_timestamps() + aligned_starting_time)
        else:
            self.set_aligned_segment_timestamps(
                aligned_segment_timestamps=[
                    segment_timestamps + aligned_starting_time for segment_timestamps in self.get_timestamps()
                ]
            )

    def set_aligned_segment_starting_times(self, aligned_segment_starting_times: list[float]):
        """
        Align the starting time for each segment in this interface relative to the common session start time.

        Must be in units seconds relative to the common 'session_start_time'.

        Parameters
        ----------
        aligned_segment_starting_times : list of floats
            The starting time for each segment of data in this interface.

        Raises
        ------
        AssertionError
            If the number of starting times differs from the number of segments.
        """
        # Both fragments must be f-strings; otherwise the segment count is printed as a raw placeholder.
        assert len(aligned_segment_starting_times) == self._number_of_segments, (
            f"The length of the starting_times ({len(aligned_segment_starting_times)}) does not match the "
            f"number of segments ({self._number_of_segments})!"
        )

        if self._number_of_segments == 1:
            self.set_aligned_starting_time(aligned_starting_time=aligned_segment_starting_times[0])
        else:
            aligned_segment_timestamps = [
                segment_timestamps + aligned_segment_starting_time
                for segment_timestamps, aligned_segment_starting_time in zip(
                    self.get_timestamps(), aligned_segment_starting_times
                )
            ]
            self.set_aligned_segment_timestamps(aligned_segment_timestamps=aligned_segment_timestamps)

    def set_probe(self, probe, group_mode: Literal["by_shank", "by_probe"]):
        """
        Set the probe information via a ProbeInterface object.

        Parameters
        ----------
        probe : probeinterface.Probe
            The probe object.
        group_mode : {'by_shank', 'by_probe'}
            How to group the channels. If 'by_shank', channels are grouped by the shank_id column.
            If 'by_probe', channels are grouped by the probe_id column.
            This is a required parameter to avoid the pitfall of using the wrong mode.
        """
        # Set the probe to the recording extractor
        self.recording_extractor.set_probe(
            probe,
            in_place=True,
            group_mode=group_mode,
        )
        # Spike interface sets the "group" property
        # But neuroconv allows "group_name" property to override spike interface "group" value
        self.recording_extractor.set_property("group_name", self.recording_extractor.get_property("group").astype(str))

    def has_probe(self) -> bool:
        """
        Check if the recording extractor has probe information.

        Returns
        -------
        bool
            True if the recording extractor has probe information, False otherwise.
        """
        return self.recording_extractor.has_probe()

    def align_by_interpolation(
        self,
        unaligned_timestamps: np.ndarray,
        aligned_timestamps: np.ndarray,
    ):
        """
        Re-map the current timestamps through a piecewise-linear interpolation.

        The current timestamps are passed as the query points to `numpy.interp`, with
        ``unaligned_timestamps`` as the sample coordinates and ``aligned_timestamps`` as the sample
        values; the result replaces the timestamps of the recording.

        Parameters
        ----------
        unaligned_timestamps : numpy.ndarray
            The sample coordinates of the mapping (passed as ``xp`` to `numpy.interp`).
        aligned_timestamps : numpy.ndarray
            The sample values of the mapping (passed as ``fp`` to `numpy.interp`).

        Raises
        ------
        NotImplementedError
            If the recording has more than one segment.
        """
        if self._number_of_segments == 1:
            self.set_aligned_timestamps(
                aligned_timestamps=np.interp(x=self.get_timestamps(), xp=unaligned_timestamps, fp=aligned_timestamps)
            )
        else:
            raise NotImplementedError("Multi-segment support for aligning by interpolation has not been added yet.")

    def subset_recording(self, stub_test: bool = False):
        """
        Subset a recording extractor according to stub and channel subset options.

        NOTE(review): the current implementation never reads `stub_test` or `self.subset_channels`;
        every segment is always truncated to at most 100 frames. The behavior is kept as-is so that
        existing callers are unaffected.

        Parameters
        ----------
        stub_test : bool, default: False
            If True, only a subset of frames will be included.

        Returns
        -------
        spikeinterface.core.BaseRecording
            The subsetted recording extractor.
        """
        from spikeinterface.core.segmentutils import AppendSegmentRecording

        max_frames = 100

        recording_extractor = self.recording_extractor
        number_of_segments = recording_extractor.get_num_segments()
        recording_segments = [recording_extractor.select_segments([index]) for index in range(number_of_segments)]
        # Segments shorter than `max_frames` are kept whole.
        end_frame_list = [min(max_frames, segment.get_num_frames()) for segment in recording_segments]
        recording_segments_stubbed = [
            segment.frame_slice(start_frame=0, end_frame=end_frame)
            for segment, end_frame in zip(recording_segments, end_frame_list)
        ]
        recording_extractor_stubbed = AppendSegmentRecording(recording_list=recording_segments_stubbed)

        # Carry the (possibly aligned) timestamps of the source recording over to the stubbed copy.
        times_stubbed = [
            recording_extractor.get_times(segment_index=segment_index)[:end_frame]
            for segment_index, end_frame in zip(range(number_of_segments), end_frame_list)
        ]
        for segment_index in range(number_of_segments):
            recording_extractor_stubbed.set_times(
                times=times_stubbed[segment_index],
                segment_index=segment_index,
                with_warning=False,
            )

        return recording_extractor_stubbed

    def add_to_nwbfile(
        self,
        nwbfile: NWBFile,
        metadata: Optional[dict] = None,
        stub_test: bool = False,
        starting_time: Optional[float] = None,
        write_as: Literal["raw", "lfp", "processed"] = "raw",
        write_electrical_series: bool = True,
        iterator_type: Optional[str] = "v2",
        iterator_opts: Optional[dict] = None,
        always_write_timestamps: bool = False,
    ):
        """
        Primary function for converting raw (unprocessed) RecordingExtractor data to the NWB standard.

        Parameters
        ----------
        nwbfile : NWBFile
            NWBFile to which the recording information is to be added
        metadata : dict, optional
            metadata info for constructing the NWB file.
            Should be of the format::

                metadata['Ecephys']['ElectricalSeries'] = dict(name=my_name, description=my_description)

            If omitted (or empty), the output of `get_metadata` is used.
        starting_time : float, optional
            Sets the starting time of the ElectricalSeries to a manually set value.
        stub_test : bool, default: False
            If True, will truncate the data to run the conversion faster and take up less memory.
        write_as : {'raw', 'processed', 'lfp'}, default='raw'
            Specifies how to save the trace data in the NWB file. Options are:

            - 'raw': Save the data in the acquisition group.
            - 'processed': Save the data as FilteredEphys in a processing module.
            - 'lfp': Save the data as LFP in a processing module.
        write_electrical_series : bool, default: True
            Electrical series are written in acquisition. If False, only device, electrode_groups,
            and electrodes are written to NWB.
        iterator_type : {'v2'}
            The type of DataChunkIterator to use.
            'v2' is the locally developed RecordingExtractorDataChunkIterator, which offers full control over chunking
        iterator_opts : dict, optional
            Dictionary of options for the RecordingExtractorDataChunkIterator (iterator_type='v2').
            Valid options are:

            * buffer_gb : float, default: 1.0
                In units of GB. Recommended to be as much free RAM as available. Automatically calculates suitable
                buffer shape.
            * buffer_shape : tuple, optional
                Manual specification of buffer shape to return on each iteration.
                Must be a multiple of chunk_shape along each axis.
                Cannot be set if `buffer_gb` is specified.
            * chunk_mb : float. default: 1.0
                Should be below 1 MB. Automatically calculates suitable chunk shape.
            * chunk_shape : tuple, optional
                Manual specification of the internal chunk shape for the HDF5 dataset.
                Cannot be set if `chunk_mb` is also specified.
            * display_progress : bool, default: False
                Display a progress bar with iteration rate and estimated completion time.
            * progress_bar_options : dict, optional
                Dictionary of keyword arguments to be passed directly to tqdm.
                See https://github.com/tqdm/tqdm#parameters for options.
        always_write_timestamps : bool, default: False
            Set to True to always write timestamps.
            By default (False), the function checks if the timestamps are uniformly sampled, and if so, stores the data
            using a regular sampling rate instead of explicit timestamps. If set to True, timestamps will be written
            explicitly, regardless of whether the sampling rate is uniform.
        """
        from ...tools.spikeinterface import add_recording_to_nwbfile

        if stub_test or self.subset_channels is not None:
            recording = self.subset_recording(stub_test=stub_test)
        else:
            recording = self.recording_extractor

        metadata = metadata or self.get_metadata()

        add_recording_to_nwbfile(
            recording=recording,
            nwbfile=nwbfile,
            metadata=metadata,
            starting_time=starting_time,
            write_as=write_as,
            write_electrical_series=write_electrical_series,
            es_key=self.es_key,
            iterator_type=iterator_type,
            iterator_opts=iterator_opts,
            always_write_timestamps=always_write_timestamps,
        )