#!/usr/bin/python
# vim: set expandtab ts=4 sw=4:
"""
Functions for computing power spectra based on windowed Fast Fourier Transforms.
The module contains four types of function.
* worker functions which perform detailed computation - these functions do not
necessarily have globally sensible or adaptive defaults and some additional
processing may be required to get a full spectrum from them. For example,
windows and scaling factors may need to be precomputed.
* config function which set, check and provide parameter values
* helper functions isolate code from worker and config functions and are typically _private
* user functions which provide a top-level interface to the module
The user functions are intended for day to day use, though the worker functions
can be called directly if needed.
User functions:
periodogram
sw_periodogram
multitaper
sw_multitaper
glm_periodogram
Config classes:
STFTConfig
PeriodogramConfig
MultiTaperConfig
GLMPeriodogramConfig
Worker functions:
apply_sliding_window
compute_fft
compute_stft
compute_multitaper_stft
"""
import logging
import typing
import warnings
from dataclasses import dataclass
from functools import wraps
import numpy as np
from scipy import fft as sp_fft
from scipy import stats, signal
from scipy.signal.windows import dpss
try:
from scipy.signal.spectral import _triage_segments
except ImportError:
# scipy.signal was refactored in v1.8.0
from scipy.signal._spectral_py import _triage_segments
# Will configure a proper logger later
logging.basicConfig(level=logging.WARNING)
# Decorator for setting and reverting logging level. Top level functions should
# specify verbose but within internal operations we ALWAYS set verbose=None to
# preserve current level. Only top-level user called function should have a
# specification.
def set_verbose(func):
    """Add option to change logging level for single function calls.

    The wrapped function accepts a ``verbose`` keyword set to one of
    'DEBUG', 'INFO', 'WARNING' or 'CRITICAL' (case-insensitive) to change
    the root logger level for the duration of that single call.
    ``verbose=None`` preserves the current level.

    Parameters
    ----------
    func : callable
        Function to wrap.

    Returns
    -------
    callable
        Wrapped function.

    Raises
    ------
    ValueError
        If an unrecognised ``verbose`` value is passed at call time.
    """
    # This is the actual decorator
    @wraps(func)
    def inner_verbose(*args, **kwargs):
        verbose = kwargs.get('verbose')
        previous_level = logging.getLogger().level
        level_changed = False
        # Change level if requested
        if verbose is not None:
            level_name = verbose.upper()
            if level_name not in ['DEBUG', 'INFO', 'WARNING', 'CRITICAL']:
                raise ValueError("Logger level '{}' not recognised".format(verbose))
            # Use the normalised upper-case name so lower-case inputs work
            # (previously 'debug' passed the check but getattr(logging, 'debug')
            # returned a function and setLevel raised a TypeError)
            logging.getLogger().setLevel(getattr(logging, level_name))
            level_changed = True
            if level_name == 'DEBUG':
                formatter = logging.Formatter('%(asctime)-s - %(levelname)-8s %(funcName)30s : %(message)s')
                formatter.datefmt = '%H:%M:%S'
            else:
                formatter = logging.Formatter('%(funcName)30s - %(message)s')
            # Guard against a logger with no handlers attached
            if logging.getLogger().handlers:
                logging.getLogger().handlers[0].setFormatter(formatter)
        try:
            # Call function itself
            return func(*args, **kwargs)
        finally:
            # If we changed anything, change it back - even if the call raised.
            # Restore the level that was active before the call rather than a
            # hard-coded default.
            if level_changed:
                logging.getLogger().setLevel(previous_level)
    return inner_verbose
# ------------------------------------------------------------------
# Worker Functions
#
# These functions are stand-alone data processors which are usable on their own
# Inputs are not sanity checked and documentation may point elsewhere but these
# are fast and flexible for expert users.
#
# Most users will likely interact with these via the high level and config functions.
@set_verbose
def apply_sliding_window(x, nperseg=256, nstep=128, window=None,
                         detrend_func=None, padded=False, verbose=None):
    """Apply a delay-embedding to the last axis of an array.

    Builds a windowed view of the dataset with options for zero-padding,
    detrending and window-function multiplication.

    Parameters
    ----------
    x : ndarray
        Array of data
    %(stft_window)s'
    %(verbose)s'

    Returns
    -------
    ndarray
        Data array with delay embedding applied

    Notes
    -----
    Strongly inspired by scipy.signal.spectral._fft_helper with the FFT
    computation separated out.
    """
    logging.info('Applying sliding windows of nperseg : {0} nstep : {1}'.format(nperseg, nstep))

    if padded:
        # Zero-pad so the vector length is an integer number of windows
        n_extra = (-(x.shape[-1] - nperseg) % nstep) % nperseg
        pad_shape = list(x.shape[:-1]) + [n_extra]
        data = np.concatenate((x, np.zeros(pad_shape)), axis=-1)
    else:
        data = x

    # Build a zero-copy strided view of overlapping segments
    # https://github.com/scipy/scipy/blob/v1.5.1/scipy/signal/spectral.py#L1896
    n_windows = (data.shape[-1] - (nperseg - nstep)) // nstep
    out_shape = data.shape[:-1] + (n_windows, nperseg)
    out_strides = data.strides[:-1] + (nstep * data.strides[-1], data.strides[-1])
    windows = np.lib.stride_tricks.as_strided(data, shape=out_shape, strides=out_strides)

    logging.info('Created {0} windows of length {1}'.format(windows.shape[-2], windows.shape[-1]))
    logging.debug('windowed data shape : {0}'.format(windows.shape))

    if detrend_func is not None:
        # Remove trends within each window before any window function
        logging.debug('applying detrending {0} '.format(detrend_func))
        windows = detrend_func(windows)

    if window is not None:
        # Apply windowing
        logging.debug('applying windowing {0} '.format(window.sum()))
        windows = window * windows

    logging.debug('output shape : {0}'.format(windows.shape))
    return windows
@set_verbose
def compute_fft(x, nfft=256, axis=-1,
                side='onesided', mode='psd', scale=1.0,
                fs=1.0, fmin=-0.5, fmax=0.5, verbose=None):
    """Compute, trim and post-process an FFT on last dimension of input array.

    Parameters
    ----------
    x : ndarray
        Array of data
    %(fft_core)s'
    %(verbose)s'

    Returns
    -------
    f : ndarray
        Array of sample frequencies.
    spec : ndarray
        FFT spectrum of x.
    """
    # Pick FFT routine - one-sided spectra are only defined for real input
    if side == 'twosided':
        fft_func = sp_fft.fft
    else:
        x = x.real
        fft_func = sp_fft.rfft
    logging.info('Computing {0}-point {1} FFT with {2}'.format(nfft, side, fft_func))

    spec = fft_func(x, nfft)
    logging.debug('fft output shape {0}'.format(spec.shape))

    # Convert raw FFT into requested output mode (psd, magnitude, ...)
    spec = _proc_spectrum_mode(spec, mode, axis=axis)
    # Apply pre-computed unit scaling
    spec = _proc_spectrum_scaling(spec, scale, side, mode, nfft)
    # Compute frequency values and restrict to the requested range
    freqvals = _set_freqvalues(nfft, fs, side)
    freqs, spec = _proc_trim_freq_range(spec, freqvals, fmin, fmax)
    return freqs, spec
@set_verbose
def compute_stft(x,
                 # STFT window args
                 nperseg=256, nstep=256, window=None, detrend_func=None, padded=False,
                 # FFT core args
                 nfft=256, axis=-1, side='onesided', mode='psd', scale=1.0,
                 fs=1.0, fmin=-0.5, fmax=0.5,
                 # misc
                 output_axis='auto', verbose=None):
    """Compute a short-time Fourier transform to a dataset.

    Parameters
    ----------
    x : ndarray
        Array of data
    %(stft_window)s'
    %(fft_core)s'
    %(output_axis)s'
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.

    Notes
    -----
    Initially adapted from scipy.signal.spectral._spectral_helper.
    """
    # ---- Work starts here
    # Normalise axis, then move it into final position for windowing/FFT
    if axis == -1:
        axis = x.ndim - 1
    x = _proc_roll_input(x, axis=axis)

    # Delay-embed into overlapping windows
    windowed = apply_sliding_window(x, nperseg, nstep, detrend_func=detrend_func,
                                    window=window, padded=padded)

    # Run actual FFT over the within-window axis
    freqs, result = compute_fft(windowed, nfft=nfft, axis=axis, side=side, mode=mode,
                                scale=scale, fs=fs, fmin=fmin, fmax=fmax)

    # Time stamp (seconds) of each window centre
    time = np.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1, nstep) / float(fs)

    # Final two axes are now [..., time x freq] - move to requested position
    result = _proc_unroll_output(result, axis, output_axis=output_axis)
    return freqs, time, result
@set_verbose
def compute_multitaper_stft(x, num_tapers='auto', freq_resolution=1, time_bandwidth=5,
                            apply_tapers='broadcast',
                            # STFT window args
                            nperseg=256, nstep=256, window=None, detrend_func=None, padded=False,
                            # FFT core args
                            nfft=256, axis=-1, side='onesided', mode='psd', scale=1.0,
                            fs=1.0, fmin=-0.5, fmax=0.5,
                            # misc
                            output_axis='auto', verbose=None):
    """Compute a multi-tapered short time fourier transform.

    Parameters
    ----------
    x : ndarray
        Array of data
    %(multitaper_core)s'
    %(stft_window)s'
    %(fft_core)s'
    %(output_axis)s'
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    """
    if num_tapers == 'auto':
        # Heuristic: num_tapers = 2*NW - 1 where NW is the time half-bandwidth
        # product implied by the segment duration and requested resolution
        seconds_perseg = nperseg / fs
        time_half_bandwidth = int(seconds_perseg * freq_resolution / 2)
        num_tapers = 2 * time_half_bandwidth - 1
        msg = 'freq_resolution : {0} time_bandwidth : {1} time_half_bandwidth: {2}'
        logging.debug(msg.format(freq_resolution, time_bandwidth, time_half_bandwidth))
        logging.info('Using auto-computed number of tapers - {0}'.format(num_tapers))
    else:
        logging.info('Using user-specified number of tapers - {0}'.format(num_tapers))
    # DPSS taper set; concentration `ratios` are computed but not used here.
    # NOTE(review): dpss is driven by `time_bandwidth`, not the half-bandwidth
    # computed above - confirm these are intended to be independent parameters
    tapers, ratios = dpss(nperseg, time_bandwidth, num_tapers, return_ratios=True)
    # Uniform weights - a plain average over tapers
    taper_weights = np.ones((num_tapers,)) / num_tapers
    # ---- Work start here
    if axis == -1:
        axis = x.ndim-1
    # Move the transform axis into final position
    x = _proc_roll_input(x, axis=axis)
    # delay embedding - don't apply normal window function... will apply tapers next
    y = apply_sliding_window(x, nperseg, nstep,
                             detrend_func=detrend_func,
                             window=None, padded=padded)
    if apply_tapers == 'broadcast':
        # Apply tapers - via broadcasting to avoid loops. Adds a taper axis:
        # [..., windows, tapers, nperseg]
        to_shape = np.r_[np.ones((len(y.shape)-1),), num_tapers, nperseg].astype(int)
        z = y[..., np.newaxis, :] * np.broadcast_to(tapers, to_shape)
        logging.debug('tapered data shape {0}'.format(z.shape))
        # Run actual FFT
        freqs, result = compute_fft(z, nfft=nfft, axis=-1, side=side, mode=mode,
                                    scale=scale, fs=fs, fmin=fmin, fmax=fmax)
        logging.debug('tapered and fftd data shape {0}'.format(result.shape))
        # Average over tapers - could be high level option? mean or median?
        result = np.average(result, weights=taper_weights, axis=-2)
    elif apply_tapers == 'loop':
        # Apply tapers in a loop - slower but uses much less RAM
        to_shape = np.r_[np.ones((len(y.shape)-1),), nperseg].astype(int)
        for ii in range(num_tapers):
            logging.info('running taper: {0}'.format(ii))
            z = y * np.broadcast_to(tapers[ii, :], to_shape)
            # Run actual FFT
            freqs, taper_result = compute_fft(z, nfft=nfft, axis=-1, side=side, mode=mode,
                                              scale=scale, fs=fs, fmin=fmin, fmax=fmax)
            logging.debug('tapered and fftd data shape {0}'.format(taper_result.shape))
            # Run an incremental average so we don't have to store anything
            # https://math.stackexchange.com/questions/106700/incremental-averaging/1836447
            if ii == 0:
                result = taper_result
            else:
                result = result + (taper_result - result) / (ii + 1)
    else:
        msg = "'apply_tapers' option '{0}' not recognised. Use one of 'broadcast' or 'loop''"
        raise ValueError(msg.format(apply_tapers))
    # Periodogram Scaling
    result = result / fs
    # Create time window vector (seconds, one value per window centre)
    noverlap = nperseg - nstep
    time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1,
                     nperseg - noverlap)/float(fs)
    # Final two axes are now [..., time x freq] - return them to requested position
    result = _proc_unroll_output(result, axis, output_axis=output_axis)
    return freqs, time, result
def compute_spectral_matrix_fft(pxx):
    """Compute a spectral matrix from a periodogram.

    Parameters
    ----------
    %(pxx_complex)s'

    Returns
    -------
    S
        Cross-Spectral Density Matrix.
    """
    # pxx should be [channels x freq] complex for now. Form the outer
    # product of the channel vector with its conjugate at every frequency:
    # S[i, j, f] = pxx[i, f] * conj(pxx[j, f])
    S = np.einsum('if,jf->ijf', pxx, np.conj(pxx))
    return S.astype(complex, copy=False)
# Helpers - private functions assisting low-level processors
def _proc_roll_input(x, axis=-1):
"""Move axis to be transformed to final position.
Parameters
----------
x : ndarray
array of numeric data values
%(axis)s'
Returns
-------
x : ndarray
A view of x with the specified axes rolled to final position
"""
logging.debug('Rolling input axis position {0} to end'.format(axis))
logging.debug('Pre-rolled shape - {0}'.format(x.shape))
if axis != -1:
x = np.rollaxis(x, axis, len(x.shape))
logging.debug('Post-rolled shape - {0}'.format(x.shape))
return x
def _proc_unroll_output(result, axis, output_axis='auto'):
    """Move time and frequency dimensions to user specified position.

    Parameters
    ----------
    result : ndarray
        array of numeric data values, typically an output from `compute_stft`
        with time and frequency on the final two axes
    %(axis)s'
    %(output_axis)s'

    Returns
    -------
    result : ndarray
        A view of the input array with the axis rolled to specified positions.

    Notes
    -----
    The `time_first` option is used to simplify the subsequent temporal
    averaging of standard periodograms and the computation of GLM-periodograms
    which require the temporal dimension in the first position.
    """
    logging.debug('Rolling output axis {0} to position {1}'.format(axis, output_axis))
    logging.debug('Pre-rolled shape {0}'.format(result.shape))
    if output_axis == 'auto':
        # Return time and freq back to original position.
        # np.rollaxis semantics: the source axis is rolled until it lies
        # *before* the given start position.
        result = np.rollaxis(result, -2, axis)
        result = np.rollaxis(result, -1, axis+1)
    elif output_axis == 'time_first':
        # Put time at front and freq in original position
        result = np.rollaxis(result, -2, 0)
        result = np.rollaxis(result, -1, axis+1)
    # NOTE(review): any other `output_axis` value silently leaves the array
    # in [..., time, freq] order - confirm this fall-through is intended
    logging.debug('Post-rolled shape {0}'.format(result.shape))
    return result
def _proc_spectrum_mode(pxx, mode, axis=-1):
"""Apply specified transformation to STFT result.
Parameters
----------
%(pxx_complex)s'
%(spec_mode)s'
%(axis)s'
Returns
-------
pxx : ndarray
Array containing the transformed power spectrum
"""
logging.debug('computing {0} spectrum'.format(mode))
if mode == 'magnitude':
pxx = np.abs(pxx)
elif mode == 'psd':
pxx = (np.conjugate(pxx) * pxx).real
elif mode in ['angle', 'phase']:
pxx = np.angle(pxx)
if mode == 'phase':
# pxx has one additional dimension for time strides
if axis < 0:
axis -= 1
pxx = np.unwrap(pxx, axis=axis)
elif mode == 'complex':
pass
return pxx
def _proc_spectrum_scaling(pxx, scale, side, mode, nfft):
"""Apply specified unit scaling to STFT output.
Need to ignore DC and Nyquist frequencies to ensure that overall power is
consistent with time-dimension.
Parameters
----------
%(pxx_complex)s'
%(fft_scale)s'
%(fft_side)s'
%(spec_mode)s'
%(nfft)s'
Returns
-------
pxx : ndarray
Scaled version of input power spectrum
"""
logging.debug('Applying scaling factor {0}'.format(scale))
pxx *= scale
# Need to handle first and last points differently in onesided and twosided modes
if side == 'onesided' and mode == 'psd':
if nfft % 2:
pxx[..., 1:] *= 2
else:
# Last point is unpaired Nyquist freq point, don't double
pxx[..., 1:-1] *= 2
return pxx
def _proc_trim_freq_range(result, freqvals, fmin, fmax):
"""Trim an FFT output to desired frequency range.
This helper function assumes that we want to trim the final axis.
Parameters
----------
result : array_like
Spectrum result with frequency on the final axis
freqvals : vector
Vector of frequency values with length matching the final axis of result
%(freq_range)s'
Returns
-------
result : array_like
Input array with final dimension trimmed in-place
freqs : array_like
New frequency array matching the final axis of output
"""
logging.info('Trimming freq axis to range {0} - {1}'.format(fmin, fmax))
fidx = (freqvals >= fmin) & \
(freqvals <= fmax)
result = result[..., fidx]
freqs = freqvals[fidx]
logging.debug('fft trimmed output shape {0}'.format(result.shape))
return freqs, result
# ------------------------------------------------------------------
# Config Functions
#
# These functions parse inputs, set defaults and return sets of configured
# options.
def _set_freqvalues(nfft, fs, side):
"""Set frequency values for FFT.
Parameters
----------
%(nfft)s'
%(fs)s'
%(fft_side)s'
Returns
-------
ndarray
Vector of frequency values
"""
if side == 'twosided':
freqs = sp_fft.fftfreq(nfft, 1/fs)
elif side == 'onesided':
freqs = sp_fft.rfftfreq(nfft, 1/fs)
return freqs
def _set_onesided(return_onesided, input_complex):
"""Set flag indicating whether FFT will be one- or two-sided.
Parameters
----------
return_onesided : bool
Flag indicating whether one-sided FFT is preferred
input_complex : bool
Flag indicating whether input array is complex
Returns
-------
str
One of 'onesided' or 'twosided'
"""
if return_onesided:
if input_complex:
sides = 'twosided'
warnings.warn('Input data is complex, switching to return_onesided=False')
else:
sides = 'onesided'
else:
sides = 'twosided'
return sides
def _set_nfft(nfft, nperseg):
"""Set FFT length.
Parameters
----------
%(nfft)s'
%(nperseg)s'
Returns
-------
int
Selected length of FFT
"""
if nfft is None:
nfft = nperseg
elif nfft < nperseg:
raise ValueError('nfft must be greater than or equal to nperseg.')
else:
nfft = int(nfft)
return nfft
def _set_noverlap(noverlap, nperseg):
"""Set length over overlap between successive windows.
Parameters
----------
%(noverlap)s'
%(nperseg)s'
Returns
-------
int
Number of overlapping samples between successive windows.
"""
if noverlap is None:
noverlap = nperseg//2
else:
noverlap = int(noverlap)
if noverlap >= nperseg:
raise ValueError('noverlap must be less than nperseg.')
return noverlap
def _set_scaling(scale, fs, window):
"""Set scaling to be applied to FFT output.
Parameters
----------
%(fft_scale)s'
%(fs)s'
%(window)s'
Returns
-------
float
Scaling factor to apply to FFT result
"""
logging.debug("setting scaling '{0}' {1}".format(scale, fs))
if scale == 'density':
sfactor = 1.0 / (fs * (window*window).sum())
elif scale == 'spectrum':
sfactor = 1.0 / window.sum()**2
elif scale is None:
sfactor = 1.0
else:
raise ValueError('Unknown scale: %r' % scale)
return sfactor
def _set_heinzel_scaling(fs, win, input_length):
"""Compute FFT scaling factors with Heinzel method.
https://holometer.fnal.gov/GH_FFT.pdf - section 9
EXPERIMENTAL - NOT YET PLUGGED IN.
Parameters
----------
fs : float
Sampling rate of the data
win : ndarray
Windowing function
input_length : int
Length of input data
Returns
-------
nenbw : float
normalised effective noise bandwidth
enbw : float
effective noise bandwidth
"""
s1 = win.sum()
s2 = (win*win).sum()
nenbw = len(win) * (s2 / s1**2)
enbw = fs * (s2 / s1**2)
return nenbw, enbw
def _set_detrend(detrend, axis):
"""Set a detrending function to be applied to STFT windows prior to FFT.
Parameters
----------
detrend : {False, str, func}
One of either:
* False or None - Detrend function does nothing
* str - detrend function defined by scipy.signal.detrend with type=detrend.
* func - specified detrend function is returned
%(axis)s'
Returns
-------
func
Detrending function
"""
# Handle detrending and window functions
if not detrend:
def detrend_func(d):
return d
elif not hasattr(detrend, '__call__'):
def detrend_func(d):
return signal.detrend(d, type=detrend, axis=-1)
elif axis != -1:
# Wrap this function so that it receives a shape that it could
# reasonably expect to receive.
def detrend_func(d):
d = np.rollaxis(d, -1, axis)
d = detrend(d)
return np.rollaxis(d, axis, len(d.shape))
else:
detrend_func = detrend
return detrend_func
def _set_mode(mode):
"""Set output mode for FFT result.
Parameters
----------
%(spec_mode)s'
Raises
------
ValueError
If non-valid mode is specified
"""
modelist = ['psd', 'complex', 'magnitude', 'angle', 'phase']
if mode not in modelist:
raise ValueError("Invalid value ('{}') for mode, must be one of {}"
.format(mode, modelist))
def _set_frange(fs, fmin, fmax, side):
"""Set range of frequenecy values to be returned from FFT result.
Parameters
----------
%(fs)s'
%(freq_range)s'
%(fft_side)s'
Returns
-------
float, float
fmin and fmax
"""
if fmin is None and side == 'onesided':
fmin = 0
elif fmin is None and side == 'twosided':
fmin = -fs/2
if fmax is None:
fmax = fs/2
return fmin, fmax
@dataclass
class STFTConfig:
    """Configuration options for a Short Time Fourier Transform.

    This sets user options and sensible defaults for an STFT to be applied to a
    specific dataset. It may not generalise to datasets of different lengths,
    sampling rates etc.
    """

    # Data specific args
    input_len: int                # number of samples along the transformed axis
    axis: int = -1
    input_complex: bool = False
    # General FFT args
    fs: float = 1.0
    window_type: str = 'hann'
    nperseg: int = None
    noverlap: int = None
    nfft: int = None
    detrend: typing.Union[typing.Callable, str] = 'constant'
    return_onesided: bool = True
    scaling: str = 'density'
    mode: str = 'psd'
    boundary: str = None  # Not currently used...
    fmin: float = None
    fmax: float = None
    output_axis: typing.Union[int, str] = 'auto'
    # BUGFIX: this was 'padded = bool = False', which rebound `bool` in the
    # class namespace and, lacking an annotation, left `padded` out of the
    # dataclass fields so it could never be set via the constructor.
    # Declared last to preserve the original positional argument order.
    padded: bool = False

    def __post_init__(self):
        """Set user picks and fill rest with sensible defaults."""
        if self.window_type is None:
            # Set a rectangular boxcar of 1s as the window if None is requested
            # - keeps following code cleaner.
            self.window_type = 'boxcar'
        self.window, self.nperseg = _triage_segments(self.window_type, self.nperseg, input_length=self.input_len)
        self.nfft = _set_nfft(self.nfft, self.nperseg)
        self.noverlap = _set_noverlap(self.noverlap, self.nperseg)
        self.nstep = self.nperseg - self.noverlap
        # Approximate number of windows the input supports at this step size
        self.nwindows = np.fix(self.input_len/self.nstep - 1).astype(int)
        self.scale = _set_scaling(self.scaling, self.fs, self.window)
        self.detrend_func = _set_detrend(self.detrend, axis=self.axis)
        self.side = _set_onesided(self.return_onesided, self.input_complex)
        self.fullfreqvals = _set_freqvalues(self.nfft, self.fs, self.side)
        self.fmin, self.fmax = _set_frange(self.fs, self.fmin, self.fmax, self.side)
        # Frequency values retained after trimming to [fmin, fmax]
        fidx = (self.fullfreqvals >= self.fmin) & (self.fullfreqvals <= self.fmax)
        self.freqvals = self.fullfreqvals[fidx]
        # Validate mode (raises ValueError on bad input)
        _set_mode(self.mode)
        logging.debug(self)

    @property
    def stft_args(self):
        """Get keyword arguments for a call to compute_stft."""
        args = {}
        for key in ['fs', 'nperseg', 'nstep', 'nfft', 'detrend_func',
                    'side', 'scale', 'axis', 'mode', 'window',
                    'padded', 'fmin', 'fmax', 'output_axis']:
            args[key] = getattr(self, key)
        return args

    @property
    def sliding_window_args(self):
        """Get keyword arguments for a call to apply_sliding_window."""
        args = {}
        for key in ['nperseg', 'nstep', 'detrend_func', 'window', 'padded']:
            args[key] = getattr(self, key)
        return args

    @property
    def fft_args(self):
        """Get keyword arguments for a call to compute_fft."""
        args = {}
        for key in ['nfft', 'axis', 'side', 'mode', 'scale', 'fs', 'fmin', 'fmax']:
            args[key] = getattr(self, key)
        return args
@dataclass
class PeriodogramConfig(STFTConfig):
    """Configuration options for a Periodogram.

    Inherits from STFTConfig and adds the periodogram-specific 'average'
    option. Sets user options and sensible defaults for a periodogram of a
    specific dataset; it may not generalise to datasets of different
    lengths, sampling rates etc.
    """

    # Method used to average across STFT windows
    average: str = 'mean'

    def __post_init__(self):
        """Set user picks and fill rest with sensible defaults."""
        super().__post_init__()
@dataclass
class GLMPeriodogramConfig(STFTConfig):
    """Configuration options for a GLM Periodogram.

    Inherits from STFTConfig and adds the GLM-specific options:
    'reg_ztrans', 'reg_unitmax', 'fit_method' and 'fit_intercept'.
    Sets user options and sensible defaults for a GLM periodogram of a
    specific dataset; it may not generalise to datasets of different
    lengths, sampling rates etc.
    """

    # Regressor specifications - presumably name -> values mappings; see
    # _process_regressor for how confound/covariate scaling is applied
    reg_ztrans: dict = None
    reg_unitmax: dict = None
    reg_categorical: dict = None
    # Optional contrast definitions
    contrasts: dict = None
    # GLM fitting method
    fit_method: str = 'pinv'
    # Whether to include an intercept regressor in the design
    fit_intercept: bool = True

    def __post_init__(self):
        """Set user picks and fill rest with sensible defaults."""
        super().__post_init__()
@dataclass
class MultiTaperConfig(STFTConfig):
    """Configuration options for a Multi-Taper spectrum.

    Inherits from STFTConfig and adds the multitaper-specific options:
    'average', 'time_bandwidth', 'num_tapers', 'freq_resolution' and
    'apply_tapers'. Sets user options and sensible defaults for a
    multitaper estimate of a specific dataset; it may not generalise to
    datasets of different lengths, sampling rates etc.
    """

    average: str = 'mean'
    time_bandwidth: int = 3
    num_tapers: typing.Union[str, int] = 'auto'
    freq_resolution: int = 1
    apply_tapers: str = 'broadcast'

    def __post_init__(self):
        """Set user picks and fill rest with sensible defaults."""
        super().__post_init__()

    @property
    def multitaper_stft_args(self):
        """Get keyword arguments for a call to compute_multitaper_stft."""
        keys = ('time_bandwidth', 'num_tapers', 'freq_resolution',
                'apply_tapers', 'fs', 'nperseg', 'nstep', 'nfft',
                'detrend_func', 'side', 'scale', 'axis', 'mode', 'padded',
                'fmin', 'fmax', 'output_axis')
        return {key: getattr(self, key) for key in keys}
# ------------------------------------------------------------------------
# Top-level computation functions
#
# These functions take input data, run the option handling and execute whatever
# computations are needed
@set_verbose
def sw_periodogram(x,
                   # STFT window args
                   nperseg=None, noverlap=None, window_type='hann', detrend='constant',
                   # FFT core args
                   nfft=None, axis=-1, return_onesided=True, mode='psd',
                   scaling='density', fs=1.0, fmin=None, fmax=None,
                   # misc
                   return_config=False, verbose=None):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.
    """
    # Gather all options in one place, filling anything unspecified with
    # sensible defaults for the data in hand. average=None as no window
    # averaging is applied here.
    config = PeriodogramConfig(x.shape[axis],
                               input_complex=np.any(np.iscomplex(x)),
                               average=None, fs=fs, window_type=window_type,
                               nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                               detrend=detrend, mode=mode,
                               return_onesided=return_onesided, scaling=scaling,
                               axis=axis, fmin=fmin, fmax=fmax,
                               output_axis='auto')
    freqs, times, spectra = compute_stft(x, **config.stft_args)
    logging.debug(spectra.shape)
    if return_config:
        return freqs, times, spectra, config
    return freqs, times, spectra
@set_verbose
def periodogram(x, average='mean',
                # STFT window args
                nperseg=None, noverlap=None, window_type='hann', detrend='constant',
                # FFT core args
                nfft=None, axis=-1, return_onesided=True, mode='psd',
                scaling='density', fs=1.0, fmin=None, fmax=None,
                # misc
                return_config=False, verbose=None):
    """Compute Periodogram by averaging across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    average : { 'mean', 'median', 'median_bias' }, optional
        Method to use when averaging periodograms. Defaults to 'mean'.
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : PeriodogramConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.

    Raises
    ------
    ValueError
        If `average` is not one of 'mean', 'median' or 'median_bias'.
    """
    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand
    logging.info('Setting config options')
    config = PeriodogramConfig(x.shape[axis], input_complex=np.any(np.iscomplex(x)),
                               average=average, fs=fs, window_type=window_type, nperseg=nperseg,
                               noverlap=noverlap, nfft=nfft, detrend=detrend, mode=mode,
                               return_onesided=return_onesided, scaling=scaling, axis=axis,
                               fmin=fmin, fmax=fmax, output_axis='time_first')
    logging.info('Starting computation')
    f, t, p = compute_stft(x, **config.stft_args)
    logging.info('Averaging across first dim of result using method {0}'.format(config.average))
    if config.average == 'mean':
        p = np.nanmean(p, axis=0)
    elif config.average == 'median':
        p = np.nanmedian(p, axis=0)
    elif config.average == 'median_bias':
        # NOTE(review): relies on scipy's private _median_bias helper - may
        # break across scipy versions
        bias = signal._spectral_py._median_bias(p.shape[0])
        p = np.nanmedian(p, axis=0) / bias
    else:
        # BUGFIX: the message previously omitted the valid 'median_bias' option
        msg = "'average' value of '{0}' not recognised - please use 'mean', 'median' or 'median_bias'"
        raise ValueError(msg.format(config.average))
    logging.info('Returning spectrum of shape {0}'.format(p.shape))
    if return_config:
        return f, p, config
    else:
        return f, p
@set_verbose
def sw_multitaper(x,
                  # Multitaper core
                  num_tapers='auto', freq_resolution=1, time_bandwidth=5, apply_tapers='broadcast',
                  # STFT window args
                  nperseg=None, noverlap=None, window_type='hann', detrend='constant',
                  # FFT core args
                  nfft=None, axis=-1, return_onesided=True, mode='psd',
                  scaling='density', fs=1.0, fmin=None, fmax=None,
                  # misc
                  return_config=False, verbose=None):
    """Compute a multi-tapered power spectrum across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    %(multitaper_core)s'
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    t : ndarray
        Array of times corresponding to each data segment
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : MultiTaperConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.
    """
    # Gather options, filling anything unspecified with sensible defaults for
    # the data in hand. window_type=None as DPSS tapers are applied instead of
    # a standard window function; average=None as no averaging happens here.
    config = MultiTaperConfig(x.shape[axis],
                              input_complex=np.any(np.iscomplex(x)),
                              time_bandwidth=time_bandwidth,
                              num_tapers=num_tapers, apply_tapers=apply_tapers,
                              freq_resolution=freq_resolution, average=None,
                              fs=fs, window_type=None, nperseg=nperseg,
                              noverlap=noverlap, nfft=nfft, detrend=detrend,
                              return_onesided=return_onesided, mode=mode,
                              scaling=scaling, axis=axis, fmin=fmin, fmax=fmax,
                              output_axis='auto')
    freqs, times, spectra = compute_multitaper_stft(x, **config.multitaper_stft_args)
    if return_config:
        return freqs, times, spectra, config
    return freqs, times, spectra
@set_verbose
def multitaper(x, average='mean',
               # Multitaper core
               num_tapers='auto', freq_resolution=1, time_bandwidth=5, apply_tapers='broadcast',
               # STFT window args
               nperseg=None, noverlap=None, window_type='hann', detrend='constant',
               # FFT core args
               nfft=None, axis=-1, return_onesided=True, mode='psd',
               scaling='density', fs=1.0, fmin=None, fmax=None,
               # misc
               return_config=False, verbose=None):
    """Compute a multi-tapered power spectrum averaged across windows in a STFT.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    average : { 'mean', 'median', 'median_bias' }, optional
        Method to use when averaging periodograms. Defaults to 'mean'.
    %(multitaper_core)s'
    %(stft_window_user)s'
    %(fft_user)s'
    return_config : bool
        Indicate whether parameter configuration object should be returned
        alongside result (Default value = False)
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    result : ndarray
        Array of output data, contents dependent on *mode* kwarg.
    config : MultiTaperConfig, optional
        Configuration object containing all parameters used to compute
        spectrum, optionally returned based on value of `return_config`.

    Raises
    ------
    ValueError
        If `average` is not one of 'mean', 'median' or 'median_bias'.
    """
    # Config object stores options in one place and sets sensible defaults for
    # unspecified options given the data in-hand. window_type=None as DPSS
    # tapers are applied instead of a standard window function.
    config = MultiTaperConfig(x.shape[axis],
                              input_complex=np.any(np.iscomplex(x)),
                              time_bandwidth=time_bandwidth,
                              num_tapers=num_tapers, apply_tapers=apply_tapers,
                              freq_resolution=freq_resolution, average=average,
                              fs=fs, window_type=None, nperseg=nperseg,
                              noverlap=noverlap, nfft=nfft, detrend=detrend,
                              return_onesided=return_onesided, mode=mode,
                              scaling=scaling, axis=axis, fmin=fmin, fmax=fmax,
                              output_axis='time_first')
    f, t, p = compute_multitaper_stft(x, **config.multitaper_stft_args)
    # Average across the first (window/time) dimension
    if config.average == 'mean':
        p = np.nanmean(p, axis=0)
    elif config.average == 'median':
        p = np.nanmedian(p, axis=0)
    elif config.average == 'median_bias':
        # NOTE(review): relies on scipy's private _median_bias helper - may
        # break across scipy versions
        bias = signal._spectral_py._median_bias(p.shape[0])
        p = np.nanmedian(p, axis=0) / bias
    else:
        # BUGFIX: the message previously omitted the valid 'median_bias' option
        msg = "'average' value of '{0}' not recognised - please use 'mean', 'median' or 'median_bias'"
        raise ValueError(msg.format(config.average))
    if return_config:
        return f, p, config
    else:
        return f, p
# -----------------------------------------------------------------------
# GLM Spectrogram Functions
class GLMSpectrumResult:
    """Container for the results of a GLM spectrum fit.

    Note: this was previously decorated with ``@dataclass`` while also
    defining ``__init__`` by hand; with no annotated fields the decorator
    generated a field-less ``__eq__`` (all instances compared equal) and an
    empty ``__repr__``. The decorator is removed and a useful ``__repr__``
    added instead.
    """

    def __init__(self, f, model, design, data, config=None):
        """Store the components of a fitted GLM spectrum.

        Parameters
        ----------
        f : ndarray
            Vector of frequency values for the spectrum estimate
        model : obj
            Fitted model object - expected to expose a ``copes`` attribute
        design : obj
            Design specification used for the fit
        data : obj
            Data object that was fitted
        config : obj, optional
            Configuration used to compute the spectrum (Default value = None)
        """
        self.f = f
        self.config = config
        self.model = model
        self.design = design
        self.data = data

    def __repr__(self):
        return '{0}(f={1!r}, config={2!r})'.format(type(self).__name__,
                                                   self.f, self.config)

    @property
    def copes(self):
        """Shortcut to the fitted model's contrast parameter estimates."""
        return self.model.copes
def _is_sklearn_estimator(fit_method):
"""Check (in the duck sense) if object is a skearn fitter.
Parameters
----------
fit_method : obj
Initialised sklearn fitter object
Returns
-------
bool
flag indicating whether obj is a likely sklearn fitter
"""
test1 = hasattr(fit_method, 'fit') and callable(getattr(fit_method, 'fit'))
test2 = hasattr(fit_method, 'get_params') and callable(getattr(fit_method, 'get_params'))
test3 = hasattr(fit_method, 'set_params') and callable(getattr(fit_method, 'set_params'))
return test1 and test2 and test3
def _compute_ols_varcopes(design_matrix, data, contrasts, betas):
"""Compute variance of cope estimates.
Parameters
----------
design_matrix : ndarray
Matrix specifying GLM design matrix of shape [num_observations x num_regressors]
data : ndarray
Matrix of data to be fitted
contrasts : ndarray
Matrix of contrasts
betas : ndarray
Array of fitted regression parameters
Returns
-------
ndarray
Standard error of GLM parameter estimates
"""
# Compute varcopes
varcopes = np.zeros((contrasts.shape[0], data.shape[1]))
# Compute varcopes
residue_forming_matrix = np.linalg.pinv(design_matrix.T.dot(design_matrix))
var_forming_matrix = np.diag(np.linalg.multi_dot([contrasts,
residue_forming_matrix,
contrasts.T]))
resid = data - design_matrix.dot(betas)
# This is equivalent to >> np.diag( resid.T.dot(resid) )
resid_dots = np.einsum('ij,ji->i', resid.T, resid)
del resid
dof_error = data.shape[0] - np.linalg.matrix_rank(design_matrix)
V = resid_dots / dof_error
varcopes = var_forming_matrix[:, None] * V[None, :]
return varcopes
def _process_regressor(Y, config, mode='confound'):
    """Prepare a vector of data into a STFT regressor.

    Y is [nregs x nsamples]. The regressor is windowed to match the STFT
    windows and summed within each window. 'condition' regressors are
    normalised by the window length, 'covariate' regressors are z-transformed,
    'confound' regressors are scaled 0->1 and mode=None applies windowing only.

    Parameters
    ----------
    Y : ndarray
        Covariate vector
    config : obj
        Initialised object of STFT options
    mode : {'condition', 'covariate', 'confound', None}
        Type of regressor to create (Default value = 'confound')

    Returns
    -------
    ndarray
        Processed regressor
    """
    # Condition regressors are kept rectangular - the STFT taper is not applied
    window = None if mode == 'condition' else config.window
    windowed = apply_sliding_window(Y, config.nperseg, config.nstep,
                                    window=window, padded=config.padded)
    # Sum within each window (NaN-safe)
    y = np.nansum(windowed, axis=-1)
    if mode == 'condition':
        # Mean value per window (the occupancy fraction for binary conditions)
        y = y / config.nperseg
    elif mode == 'covariate':
        y = stats.zscore(y, axis=-1)
    elif mode == 'confound':
        # Scale each regressor into the range 0 -> 1
        y = y - y.min(axis=-1)[:, np.newaxis]
        y = y / y.max(axis=-1)[:, np.newaxis]
    elif mode is None:
        # Windowed sum only, no scaling
        pass
    return y
def _process_input_covariate(cov, input_len):
"""Prepare and check user specified GLM covariates.
Parameters
----------
cov :{dict, None, ndarray}
Specified covariates. One of a dictionary of vectors, a vector array, a
[num_regressors x num_samples] array or None.
input_len : int
Number of samples in input data
Returns
-------
dict
Set of covariates
"""
if isinstance(cov, dict):
# We have a dictionary - ensure every entry is array_like numeric vector with
# expected length
for key, var in cov.items():
if len(cov[key]) != input_len:
msg = "Regressor '{0}' shape ({1}) not matched to input data length ({2})"
raise ValueError(msg.format(key, len(cov[key]), input_len))
ret = cov # pass back out
elif cov is None:
ret = {} # No regressors defined
else:
# Check array_like inputs
cov = np.array(cov)
if np.issubdtype(cov.dtype, np.number) is False:
msg = "Regressor inputs must be numeric in type - input has type '{0}'"
raise ValueError(msg.format(cov.dtype))
# Add dummy dim if input is vector
if cov.ndim == 1:
cov = cov[np.newaxis, :]
# Check length of regressors are correct
if cov.shape[1] != input_len:
msg = 'Regressor shape ({0}) not matched to input data length ({1})'
raise ValueError(msg.format(cov.shape[1], input_len))
# Give regressor a dummy name
ret = {}
for ii in range(cov.shape[0]):
ret[chr(65 + ii)] = cov[ii, :]
return ret
def _specify_design(reg_categorical, reg_ztrans, reg_unitmax, config, fit_intercept=True):
"""Create a design matrix.
Parameters
----------
reg_ztrans : dict
Dictionary of covariate variables
reg_unitmax : dict
Dictionary of confound variables
config : obj
User specified STFT options
fit_intercept : bool
Flag indicating whether to include a constant term (Default value = True)
Returns
-------
design_matrix : ndarray
[num_observations x num_regressors] matrix of regressors
contrasts : ndarray
[num_regressors x num_regressors] matrix of contrasts (identity)
Xlabels : list of str
List of regressor names
"""
X = []
Xlabels = []
if fit_intercept:
logging.info("Adding constant")
X.append(np.ones((config.nwindows,)))
Xlabels.append('Constant')
# Add reg_categorical
for idx, var in enumerate(reg_categorical.keys()):
logging.info("Adding condition '{0}'".format(var))
X.append(_process_regressor(reg_categorical[var], config, mode='condition'))
Xlabels.append(var)
# Add reg_ztrans
for idx, var in enumerate(reg_ztrans.keys()):
logging.info("Adding covariate '{0}'".format(var))
X.append(_process_regressor(reg_ztrans[var], config, mode='covariate'))
Xlabels.append(var)
# Add reg_unitmax
for idx, var in enumerate(reg_unitmax.keys()):
logging.info("Adding confound '{0}'".format(var))
X.append(_process_regressor(reg_ztrans[var], config, mode='confound'))
Xlabels.append(var)
design_matrix = np.vstack(X).T
contrasts = np.eye(design_matrix.shape[1])
return design_matrix, contrasts, Xlabels
def _run_prefit_checks(data, design_matrix, contrasts):
"""Run a few checks to catch likely errors in GLM fit.
Parameters
----------
data : ndarray
Matrix of data to be fitted
design_matrix : ndarray
Matrix specifying GLM design matrix of shape [num_observations x num_regressors]
contrasts : ndarray
Matrix of contrasts
Returns
-------
None
"""
# Make sure we're set for model fitting
assert(data.shape[0] == design_matrix.shape[0])
assert(design_matrix.shape[1] == contrasts.shape[0])
def _glm_fit_simple(pxx, reg_categorical, reg_ztrans, reg_unitmax, config, fit_method='pinv', fit_intercept=True,
                    ret_arrays=True):
    """Fit a GLM using a standard OLS fitting method.

    Parameters
    ----------
    pxx : ndarray
        Power spectrum estimate with sliding windows in the first dimension
    reg_categorical : dict
        Dictionary of condition variables
    reg_ztrans : dict
        Dictionary of covariate variables
    reg_unitmax : dict
        Dictionary of confound variables
    config : obj
        User specified STFT options
    fit_method : {'pinv', 'lstsq'}
        Fitting method to use (Default value = 'pinv')
    fit_intercept : bool
        Flag indicating whether to fit a constant (Default value = True)
    ret_arrays : bool
        If True return (betas, copes, varcopes) arrays, otherwise return a
        GLMSpectrumResult object (Default value = True)

    Returns
    -------
    betas : ndarray
        array of fitted regression parameters
    copes : ndarray
        array of fitted parameter estimates
    varcopes : ndarray
        array of standard errors of parameter estimates
    """
    # Prepare GLM components
    design_matrix, contrasts, Xlabels = _specify_design(reg_categorical, reg_ztrans, reg_unitmax,
                                                        config, fit_intercept=fit_intercept)
    # Check we're probably good to go
    _run_prefit_checks(pxx, design_matrix, contrasts)

    # Compute parameters
    if fit_method == 'pinv':
        logging.debug('using np.linalg.pinv')
        betas = np.linalg.pinv(design_matrix).dot(pxx)
    elif fit_method == 'lstsq':
        logging.debug('using np.linalg.lstsq')
        # rcond=None selects numpy's machine-precision based cutoff and
        # silences the FutureWarning raised by the old default
        betas, resids, rank, s = np.linalg.lstsq(design_matrix, pxx, rcond=None)
    else:
        raise ValueError("'fit_method' input {0} not recognised".format(fit_method))

    # Compute COPES and VARCOPES
    copes = contrasts.dot(betas)
    varcopes = _compute_ols_varcopes(design_matrix, pxx, contrasts, betas)

    if ret_arrays:
        return betas, copes, varcopes

    # NOTE(review): the previous constructor call here passed an unsupported
    # 'design_matrix' keyword and mis-ordered positionals, raising TypeError.
    # Pack the fitted arrays into a minimal model-like container so that
    # GLMSpectrumResult.copes still resolves.
    from types import SimpleNamespace
    model = SimpleNamespace(betas=betas, copes=copes, varcopes=varcopes)
    return GLMSpectrumResult(config.freqvals, model, design_matrix, pxx, config=config)
def _glm_fit_sklearn_estimator(pxx, reg_categorical, reg_ztrans, reg_unitmax, config, fit_method, fit_intercept=True):
    """Fit a GLM using a sklearn-like estimator object.

    Parameters
    ----------
    pxx : ndarray
        Power spectrum estimate with sliding windows in the first dimension
    reg_categorical : dict
        Dictionary of condition variables
    reg_ztrans : dict
        Dictionary of covariate variables
    reg_unitmax : dict
        Dictionary of confound variables
    config : obj
        User specified STFT options
    fit_method : obj
        Parametrised instance of a sklearn estimator exposing
        fit/get_params/set_params - fitted in-place
    fit_intercept : bool
        Flag indicating whether to fit a constant (Default value = True)

    Returns
    -------
    betas : ndarray
        array of fitted regression parameters
    copes : ndarray
        array of fitted parameter estimates
    varcopes : ndarray
        array of standard errors of parameter estimates
    fitter : obj
        the fitted sklearn estimator object
    """
    logging.info('Running sklearn GLM fit')
    # Prepare GLM components
    design_matrix, contrasts, Xlabels = _specify_design(reg_categorical, reg_ztrans, reg_unitmax,
                                                        config, fit_intercept=fit_intercept)
    # Check we're probably good to go
    _run_prefit_checks(pxx, design_matrix, contrasts)
    # Compute parameters - the estimator object is mutated by fit()
    fit_method.fit(design_matrix, pxx)
    if hasattr(fit_method, 'coef_'):
        betas = fit_method.coef_.T
    else:
        # Sometimes this is stored in a sub model...
        betas = fit_method.estimator_.coef_.T
    # Compute COPES and VARCOPES
    copes = contrasts.dot(betas)
    varcopes = _compute_ols_varcopes(design_matrix, pxx, contrasts, betas)
    # NOTE: the parentheses below do not create a tuple - the fourth return
    # value is the estimator object itself
    return betas, copes, varcopes, (fit_method)
def _glm_fit_glmtools(pxx, reg_categorical, reg_ztrans, reg_unitmax,
                      config, contrasts=None, fit_intercept=True):
    """Fit a GLM using the glmtools package.

    Parameters
    ----------
    pxx : ndarray
        Power spectrum estimate with sliding windows in the first dimension
    reg_categorical : dict
        Dictionary of condition variables
    reg_ztrans : dict
        Dictionary of covariate variables
    reg_unitmax : dict
        Dictionary of confound variables
    config : obj
        User specified STFT options
    contrasts : list of dict or None
        Optional extra contrasts, each passed as keyword arguments to
        glmtools DesignConfig.add_contrast (Default value = None)
    fit_intercept : bool
        Flag indicating whether to fit a constant (Default value = True)

    Returns
    -------
    model : obj
        Fitted glmtools OLSModel object
    des : obj
        glmtools design object built from the regressor specification
    data : obj
        glmtools data object wrapping pxx
    """
    logging.info('Running glmtools GLM fit')
    import glmtools as glm  # keep this as a soft dependency

    # Allocate GLM data object
    data = glm.data.TrialGLMData(data=pxx)

    # Add windowed reg_unitmax and reg_ztrans - no preproc yet, glmtools
    # applies the scaling itself via the 'preproc' options below
    for key, value in reg_categorical.items():
        logging.debug('Processing Condition Regressor : {0}'.format(key))
        data.info[key] = _process_regressor(value, config, mode='condition')
    for key, value in reg_ztrans.items():
        data.info[key] = _process_regressor(value, config, mode=None)
    for key, value in reg_unitmax.items():
        data.info[key] = _process_regressor(value, config, mode=None)

    # Build up the design specification
    DC = glm.design.DesignConfig()
    if fit_intercept:
        logging.debug('Adding Constant Regressor')
        DC.add_regressor(name='Constant', rtype='Constant')
    for key in reg_categorical.keys():
        logging.debug('Adding Condition : {0}'.format(key))
        DC.add_regressor(name=key, rtype='Categorical', datainfo=key, codes=[1])
    for key in reg_ztrans.keys():
        logging.debug('Adding Covariate : {0}'.format(key))
        DC.add_regressor(name=key, rtype='Parametric', datainfo=key, preproc='z')
    for key in reg_unitmax.keys():
        logging.debug('Adding Confound : {0}'.format(key))
        DC.add_regressor(name=key, rtype='Parametric', datainfo=key, preproc='unitmax')
    # User specified contrasts plus one simple contrast per regressor
    if contrasts is not None:
        for con in contrasts:
            DC.add_contrast(**con)
    DC.add_simple_contrasts()

    # Create design and fit model
    des = DC.design_from_datainfo(data.info)
    model = glm.fit.OLSModel(des, data)

    return model, des, data
@set_verbose
def glm_periodogram(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
                    contrasts=None, fit_method='pinv', fit_intercept=True,
                    # STFT window args
                    nperseg=None, noverlap=None, window_type='hann', detrend='constant',
                    # FFT core args
                    nfft=None, axis=-1, return_onesided=True, mode='psd',
                    scaling='density', fs=1.0, fmin=None, fmax=None,
                    # misc
                    verbose=None):
    """Compute a Power Spectrum with a General Linear Model.

    Parameters
    ----------
    X : array_like
        Time series of measurement values
    reg_categorical : dict or None
        Dictionary of covariate time series to be added as binary regessors. (Default value = None)
    reg_ztrans : dict or None
        Dictionary of covariate time series to be added as z-standardised regessors. (Default value = None)
    reg_unitmax : dict or None
        Dictionary of confound time series to be added as positive-valued unitmax regessors. (Default value = None)
    contrasts : dict or None
        Dictionary of contrasts to be computed in the model.
        (Default value = None, will add a simple contrast for each regressor)
    fit_method : {'pinv', 'lstsq', 'glmtools', sklearn estimator instance}
        Specifies how the GLM parameters will be estimated.
        * `pinv` uses the design matrix psuedo-inverse method
        * `lstsq` uses np.linalg.lstsq.
        * `glmtools` uses the OLSModel from the glmtools package.
        * A parametrised instance of a sklearn estimator is used if specified here. (Default value = 'pinv')
    fit_intercept : bool
        Specifies whether a constant valued 'intercept' regressor is included in the model. (Default value = True)
    %(stft_window_user)s'
    %(fft_user)s'
    %(verbose)s'

    Returns
    -------
    GLMSpectrumResult : object
        Object containing the fitted GLM Periodogram
    """
    # Option housekeeping - resolve the negative axis to an absolute index
    if axis == -1:
        axis = X.ndim - 1
    # 'pinv' and 'lstsq' fits only support 1d input
    if X.ndim != 1 and fit_method in ['pinv', 'lstsq']:
        msg = "Data input should be vector for 'pinv' and 'lstsq' fits - data shape {0} was passed in"
        logging.error(msg.format(X.shape))
        logging.error("Use fit_method='glmtools' for multdimensional data")
        raise ValueError("Fit methods 'pinv' and 'lstsq' not implemented for multidimensional data")

    # Set configuration
    logging.info('Setting config options')
    config = GLMPeriodogramConfig(X.shape[axis], reg_ztrans=reg_ztrans,
                                  reg_unitmax=reg_unitmax,
                                  fit_method=fit_method, contrasts=contrasts,
                                  fit_intercept=fit_intercept,
                                  input_complex=np.iscomplexobj(X), fs=fs,
                                  fmin=fmin, fmax=fmax,
                                  window_type=window_type, nperseg=nperseg,
                                  noverlap=noverlap,
                                  nfft=nfft, detrend=detrend,
                                  return_onesided=return_onesided,
                                  scaling=scaling, axis=axis, mode=mode,
                                  output_axis='time_first')

    # Transform inputs into predicable, sanity checked dictionaries
    logging.info('Processing Conditions, Covariates and Confounds')
    reg_categorical = _process_input_covariate(reg_categorical, config.input_len)
    reg_ztrans = _process_input_covariate(reg_ztrans, config.input_len)
    reg_unitmax = _process_input_covariate(reg_unitmax, config.input_len)

    # Compute STFT
    logging.info('Computing sliding window periodogram')
    f, t, p = compute_stft(X, **config.stft_args)

    # Compute model
    # NOTE(review): although fit_method is validated above, the fit always
    # dispatches to glmtools here - 'pinv'/'lstsq'/sklearn values only affect
    # the input-shape check. Confirm whether other fit methods should be
    # dispatched as in glm_multitaper.
    model, des, data = _glm_fit_glmtools(p, reg_categorical, reg_ztrans,
                                         reg_unitmax, config,
                                         contrasts=contrasts,
                                         fit_intercept=fit_intercept)

    return GLMSpectrumResult(f, model, des, data, config=config)
@set_verbose
def glm_multitaper(X, reg_categorical=None, reg_ztrans=None, reg_unitmax=None,
                   contrasts=None, fit_method='pinv', fit_intercept=True,
                   # Multitaper kwargs
                   num_tapers='auto', freq_resolution=1, time_bandwidth=5, apply_tapers='broadcast',
                   # STFT window args
                   nperseg=None, noverlap=None, window_type='hann', detrend='constant',
                   # FFT core args
                   nfft=None, axis=-1, return_onesided=True, mode='psd',
                   scaling='density', fs=1.0, fmin=None, fmax=None,
                   # misc
                   verbose=None):
    """Compute a multitaper Power Spectrum with a General Linear Model.

    Parameters
    ----------
    X : array_like
        Time series of measurement values
    reg_categorical : dict or None
        Dictionary of covariate time series to be added as binary regessors. (Default value = None)
    reg_ztrans : dict or None
        Dictionary of covariate time series to be added as z-standardised regessors. (Default value = None)
    reg_unitmax : dict or None
        Dictionary of confound time series to be added as positive-valued unitmax regessors. (Default value = None)
    contrasts : dict or None
        Dictionary of contrasts to be computed in the model.
        (Default value = None, will add a simple contrast for each regressor)
    fit_method : {'pinv', 'lstsq', 'glmtools', sklearn estimator instance}
        Specifies how the GLM parameters will be estimated.
        * `pinv` uses the design matrix psuedo-inverse method
        * `lstsq` uses np.linalg.lstsq.
        * `glmtools` uses the OLSModel from the glmtools package.
        * A parametrised instance of a sklearn estimator is used if specified here. (Default value = 'pinv')
    fit_intercept : bool
        Specifies whether a constant valued 'intercept' regressor is included in the model. (Default value = True)
    %(multitaper_core)s'
    %(stft_window_user)s'
    %(fft_user)s'
    %(verbose)s'

    Returns
    -------
    freqs : ndarray
        Array of sample frequencies.
    copes : ndarray
        Array of fitted contrast parameter estimates.
    varcopes : ndarray
        Array of standard errors of the parameter estimates.
    extras : obj or None
        Fit-method specific extras - None for 'pinv'/'lstsq', a
        (model, design, data) tuple for 'glmtools', the fitted estimator
        object for sklearn fits.
    """
    # Option housekeeping - resolve the negative axis to an absolute index.
    # (This block was previously duplicated; the first copy never raised.)
    if axis == -1:
        axis = X.ndim - 1
    # 'pinv' and 'lstsq' fits only support 1d input
    if X.ndim != 1 and fit_method in ['pinv', 'lstsq']:
        msg = "Data input should be vector for 'pinv' and 'lstsq' fits - data shape {0} was passed in"
        logging.error(msg.format(X.shape))
        logging.error("Use fit_method='glmtools' for multdimensional data")
        raise ValueError("Fit methods 'pinv' and 'lstsq' not implemented for multidimensional data")

    # Set configuration
    logging.info('Setting config options')
    config = MultiTaperConfig(X.shape[axis],
                              input_complex=np.any(np.iscomplex(X)),
                              time_bandwidth=time_bandwidth,
                              num_tapers=num_tapers, apply_tapers=apply_tapers,
                              freq_resolution=freq_resolution, average=None,
                              fs=fs, window_type=None, nperseg=nperseg,
                              noverlap=noverlap, nfft=nfft, detrend=detrend,
                              return_onesided=return_onesided, mode=mode,
                              scaling=scaling, axis=axis, fmin=fmin, fmax=fmax,
                              output_axis='auto')

    # Transform inputs into predictable, sanity checked dictionaries
    logging.info('Processing Conditions, Covariates and Confounds')
    reg_categorical = _process_input_covariate(reg_categorical, config.input_len)
    reg_ztrans = _process_input_covariate(reg_ztrans, config.input_len)
    reg_unitmax = _process_input_covariate(reg_unitmax, config.input_len)

    # Compute multitaper STFT
    logging.info('Computing sliding window multitaper')
    f, t, p = compute_multitaper_stft(X, **config.multitaper_stft_args)

    # Compute model - each method MUST assign copes, varcopes and extras.
    # (bugfix: these calls previously dropped the reg_categorical argument,
    # shifting all positionals, and unpacked the returns with the wrong arity)
    if fit_method in ['pinv', 'lstsq']:
        logging.info('Running numpy GLM fit')
        betas, copes, varcopes = _glm_fit_simple(p, reg_categorical, reg_ztrans,
                                                 reg_unitmax, config,
                                                 fit_method=fit_method,
                                                 fit_intercept=fit_intercept)
        extras = None
    elif fit_method == 'glmtools':
        logging.info('Running glmtools GLM fit')
        model, des, data = _glm_fit_glmtools(p, reg_categorical, reg_ztrans,
                                             reg_unitmax, config,
                                             contrasts=contrasts,
                                             fit_intercept=fit_intercept)
        # NOTE(review): assumes the glmtools OLSModel exposes copes/varcopes
        # attributes (GLMSpectrumResult.copes relies on model.copes) - confirm
        copes = model.copes
        varcopes = model.varcopes
        extras = (model, des, data)
    elif _is_sklearn_estimator(fit_method):
        # Log the estimator instance itself, not the helper function
        logging.info('Running sklearn GLM fit with {0}'.format(fit_method))
        betas, copes, varcopes, extras = _glm_fit_sklearn_estimator(p, reg_categorical, reg_ztrans,
                                                                    reg_unitmax, config,
                                                                    fit_method=fit_method,
                                                                    fit_intercept=fit_intercept)
    else:
        raise ValueError('fit_method not recognised')

    return f, copes, varcopes, extras