"""Mock TTL pulse signal generation for testing (module ``neuroconv.tools.testing.mock_ttl_signals``)."""

import math
from pathlib import Path
from typing import Optional, Union

import numpy as np
from numpy.typing import DTypeLike
from pydantic import DirectoryPath
from pynwb import NWBHDF5IO, H5DataIO, TimeSeries
from pynwb.testing.mock.file import mock_NWBFile

from ..importing import is_package_installed
from ...utils import ArrayType


def _check_parameter_dtype_consistency(
    parameter_name: str,
    parameter_value: Union[int, float],
    generic_dtype: type,  # Literal[np.integer, np.floating]
):
    """Helper for `generate_mock_ttl_signal` to assert consistency between parameters and expected trace dtype."""
    end_format = "an integer" if generic_dtype == np.integer else "a float"
    assert np.issubdtype(type(parameter_value), generic_dtype), (
        f"If specifying the '{parameter_name}' manually, please ensure it matches the 'dtype'! "
        f"Received '{type(parameter_value).__name__}', should be {end_format}."
    )


def generate_mock_ttl_signal(
    signal_duration: float = 7.0,
    ttl_times: Optional[ArrayType] = None,
    ttl_duration: float = 1.0,
    sampling_frequency_hz: float = 25_000.0,
    dtype: DTypeLike = "int16",
    baseline_mean: Optional[Union[int, float]] = None,
    signal_mean: Optional[Union[int, float]] = None,
    channel_noise: Optional[Union[int, float]] = None,
    random_seed: Optional[int] = 0,
) -> np.ndarray:
    """
    Generate a synthetic signal of TTL pulses similar to those seen in .nidq.bin files using SpikeGLX.

    Parameters
    ----------
    signal_duration : float, default: 7.0
        The number of seconds to simulate.
    ttl_times : array of floats, optional
        The times within the `signal_duration` to trigger the TTL pulse, in increasing order.
        In conjunction with the `ttl_duration`, these must produce disjoint 'on' intervals.
        The default generates a periodic 1 second on, 1 second off pattern.
    ttl_duration : float, default: 1.0
        How long the TTL pulse stays in the 'on' state when triggered, in seconds.
        In conjunction with the `ttl_times`, these must produce disjoint 'on' intervals.
    sampling_frequency_hz : float, default: 25,000.0
        The sampling frequency of the signal in Hz.
        The default is 25000 Hz; similar to that of typical .nidq.bin files.
    dtype : numpy data type or one of its accepted string input, default: "int16"
        The data type of the trace.
        Must match the data type of `baseline_mean`, `signal_mean`, and `channel_noise`, if any of those are specified.
        Recommended to be int16 for maximum efficiency, but can also be any size float to represent voltage scalings.
    baseline_mean : integer or float, depending on specified 'dtype', optional
        The average value for the baseline; usually around 0 Volts.
        The default is approximately 0.005645752 Volts (37 in integer units), estimated from a real example of a TTL
        pulse in a .nidq.bin file. For unsigned integer dtypes, the integer default is shifted up by half of the
        dtype's maximum value so that the noisy baseline does not wrap below zero.
    signal_mean : integer or float, optional
        Type depends on specified 'dtype'. The amount added on top of the baseline during each 'on' interval, so the
        'on' level is approximately `baseline_mean + signal_mean`; usually around 5 Volts.
        The default is approximately 4.980773925 Volts (32642 in integer units), estimated from a real example of a
        TTL pulse in a .nidq.bin file.
    channel_noise : integer or float, optional
        Type depends on specified 'dtype'. The standard deviation of white noise in the channel.
        Passing 0 produces a noiseless trace.
        The default is approximately 0.002288818 Volts (15 in integer units), estimated from a real example of a TTL
        pulse in a .nidq.bin file.
    random_seed : int or None, default: 0
        The seed passed to `numpy.random.seed` (note: this resets NumPy's global legacy random state).
        Set to None to choose the seed randomly. The default is kept at 0 for generating reproducible outputs.

    Returns
    -------
    trace: numpy.ndarray
        The synthetic trace representing a channel with TTL pulses.

    Raises
    ------
    AssertionError
        If a manually specified `baseline_mean`, `signal_mean`, or `channel_noise` does not match the integer/float
        kind of `dtype`, or if consecutive `ttl_times` are not separated by more than `ttl_duration`.
    """
    dtype = np.dtype(dtype)

    # Default values estimated from real files, expressed in int16 units
    baseline_mean_int16_default = 37
    signal_mean_int16_default = 32642
    channel_noise_int16_default = 15
    default_gain_to_volts = 152.58789062 * 1e-6

    if np.issubdtype(dtype, np.unsignedinteger):
        # Move the baseline to the midpoint of the unsigned range. The signal mean is deliberately NOT shifted: it is
        # an amplitude added on top of the baseline, so shifting it too would overflow (for uint16, 32804 + 65409
        # wraps around to ~32677, turning every pulse into a dip below the baseline).
        baseline_mean_int16_default += math.floor(np.iinfo(dtype).max / 2)

    if np.issubdtype(dtype, np.integer):
        generic_dtype = np.integer
        default_scale = 1  # keep the defaults as Python ints so they pass the integer dtype check
    else:
        generic_dtype = np.floating
        default_scale = default_gain_to_volts  # convert the int16-unit defaults to Volts

    # Compare against None rather than relying on truthiness so explicitly passed zeros are honored
    if baseline_mean is None:
        baseline_mean = baseline_mean_int16_default * default_scale
    if signal_mean is None:
        signal_mean = signal_mean_int16_default * default_scale
    if channel_noise is None:
        channel_noise = channel_noise_int16_default * default_scale

    parameters_to_check = dict(baseline_mean=baseline_mean, signal_mean=signal_mean, channel_noise=channel_noise)
    for parameter_name, parameter_value in parameters_to_check.items():
        _check_parameter_dtype_consistency(
            parameter_name=parameter_name, parameter_value=parameter_value, generic_dtype=generic_dtype
        )

    np.random.seed(seed=random_seed)
    num_frames = np.ceil(signal_duration * sampling_frequency_hz).astype(int)
    trace = (np.random.randn(num_frames) * channel_noise + baseline_mean).astype(dtype)

    if ttl_times is not None:
        ttl_times = np.array(ttl_times)
    else:
        ttl_times = np.arange(start=1.0, stop=signal_duration, step=2.0)

    # Each trigger must come strictly more than one pulse duration after the previous one; this also rejects
    # unsorted times (negative differences). Raised explicitly so the check survives `python -O`.
    if len(ttl_times) > 1 and np.any(np.diff(ttl_times) <= ttl_duration):
        raise AssertionError("There are overlapping TTL 'on' intervals! Please specify disjoint on/off periods.")

    ttl_start_frames = np.round(ttl_times * sampling_frequency_hz).astype(int)
    num_frames_ttl_duration = np.round(ttl_duration * sampling_frequency_hz).astype(int)
    for start in ttl_start_frames:
        trace[start : start + num_frames_ttl_duration] += signal_mean

    return trace


def regenerate_test_cases(folder_path: DirectoryPath, regenerate_reference_images: bool = False):  # pragma: no cover
    """
    Rebuild the NWB file of mock TTL examples used by the main testing suite.

    The resulting file is frozen between breaking changes, so this should only be run when the mock TTL outputs are
    intentionally changed.

    Parameters
    ----------
    folder_path : DirectoryPath
        Folder in which 'mock_ttl_examples.nwb' (and, optionally, 'example_ttl_reference.png') is written.
        For use in the testing suite, this must be the '/test_testing/test_mock_ttl/' subfolder adjacent to the
        mock TTL test module.
    regenerate_reference_images : bool, default: False
        If True, also plot every test case with plotly and save the figure through kaleido (you may need to install
        both) as the reference image used in the documentation.
    """
    folder_path = Path(folder_path)

    if regenerate_reference_images:
        assert is_package_installed("plotly") and is_package_installed("kaleido"), (
            "To regenerate the reference images, you must install both plotly and kaleido!"
        )
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots

    image_file_path = folder_path / "example_ttl_reference.png"
    nwbfile_path = folder_path / "mock_ttl_examples.nwb"
    compression_options = dict(compression="gzip", compression_opts=9)
    unit = "Volts"
    rate = 1000.0  # Non-default series use a lower rate so they produce less data

    def _as_compressed_time_series(name, data, series_rate):
        # Each series is stored as one gzip-compressed chunk spanning the whole array
        return TimeSeries(
            name=name,
            unit=unit,
            rate=series_rate,
            data=H5DataIO(data=data, chunks=data.shape, **compression_options),
        )

    nwbfile = mock_NWBFile()

    # Test Case 1: Default
    default_ttl_signal = generate_mock_ttl_signal()
    nwbfile.add_acquisition(_as_compressed_time_series("DefaultTTLSignal", default_ttl_signal, 25000.0))

    # Settings shared by the non-default test cases
    irregular_short = dict(signal_duration=2.5, ttl_times=[0.22, 1.37], ttl_duration=0.25, sampling_frequency_hz=rate)
    regular = dict(signal_duration=2.7, ttl_times=[0.2, 1.2, 2.2], ttl_duration=0.3, sampling_frequency_hz=rate)

    # Test Cases 2-9, listed in the order they are written to the file and plotted
    test_case_kwargs = {
        "IrregularShortPulses": irregular_short,
        "NonDefaultRegular": regular,
        "NonDefaultRegularAdjustedMeans": {**regular, "baseline_mean": 300, "signal_mean": 20000},
        "IrregularShortPulsesAdjustedNoise": {**irregular_short, "channel_noise": 2},
        "NonDefaultRegularFloats": {**regular, "dtype": "float32"},
        # Float dtype, so the adjusted means and noise are floats as well
        "FloatsAdjustedMeansAndNoise": {
            **regular,
            "dtype": "float32",
            "baseline_mean": 1.1,
            "signal_mean": 7.2,
            "channel_noise": 0.4,
        },
        "NonDefaultRegularUInt16": {**regular, "dtype": "uint16"},
        "IrregularShortPulsesDifferentSeed": {**irregular_short, "random_seed": 1},
    }
    non_default_series = {name: generate_mock_ttl_signal(**kwargs) for name, kwargs in test_case_kwargs.items()}

    num_cols = 5
    if regenerate_reference_images:
        fig = make_subplots(rows=2, cols=num_cols, subplot_titles=["Default", *non_default_series])
        fig.add_trace(go.Scatter(y=default_ttl_signal, text="Default"), row=1, col=1)

    # Slot 0 of the 2 x num_cols grid holds the default series; the rest fill the grid row by row
    for plot_index, (time_series_name, time_series_data) in enumerate(non_default_series.items(), start=1):
        nwbfile.add_acquisition(_as_compressed_time_series(time_series_name, time_series_data, rate))
        if regenerate_reference_images:
            row_offset, col_offset = divmod(plot_index, num_cols)
            fig.add_trace(
                go.Scatter(y=time_series_data, text=time_series_name),
                row=row_offset + 1,
                col=col_offset + 1,
            )

    if regenerate_reference_images:
        fig.update_annotations(font_size=6)
        fig.update_layout(showlegend=False)
        fig.update_yaxes(tickfont=dict(size=5))
        fig.update_xaxes(showticklabels=False)
        fig.write_image(file=image_file_path)

    with NWBHDF5IO(path=nwbfile_path, mode="w") as io:
        io.write(nwbfile)