Tutorial 8 - Using A Custom Model Fit

SAILS provides implementations of several algorithms for fitting autoregressive models, but it is straightforward to create a custom class that implements a new model fit or uses one from another package.

This tutorial will outline how to create a custom model fit class using the Vector Autoregression class from statsmodels.tsa. We start by importing SAILS and creating a simulated time series to model.

import numpy as np
import sails

# Create a simulated signal
sample_rate = 100
siggen = sails.Baccala2001_fig2()
X = siggen.generate_signal(sample_rate=sample_rate,num_samples=1000)

We then fit a model using Ordinary Least Squares as implemented in SAILS.

# Fit an autoregressive model with order 3
sails_model = sails.OLSLinearModel.fit_model(X,np.arange(4))
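
The second argument is a delay vector rather than a bare model order. As a minimal sketch (the variable names `order` and `delay_vect` are illustrative), the vector for an order-3 model holds the lags 0 to 3, which is why `np.arange(4)` is passed above:

```python
import numpy as np

order = 3  # desired autoregressive model order

# The delay vector lists every lag, including the leading zero lag
delay_vect = np.arange(order + 1)
print(delay_vect)  # [0 1 2 3]

# The model order is recovered by discounting the zero lag
model_order = len(delay_vect) - 1
print(model_order)  # 3
```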

Now we will create a new model fit class based on sails.modelfit.AbstractLinearModel. This is a base class which contains a number of methods and properties to store and compute information on a model. The AbstractLinearModel is not usable on its own as the fit_model method is not implemented. When classes such as OLSLinearModel are defined in SAILS, they inherit from AbstractLinearModel to define the helper functions before a specific fit_model method is defined. We can do the same to define a custom model fit class using an external package. We will first create a new class which inherits from AbstractLinearModel and then define a fit_model method which computes the model fit and stores the outputs in a standard form.
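
The inheritance pattern described above can be sketched without SAILS at all. The classes below are hypothetical stand-ins rather than the SAILS implementation: `AbstractModel` plays the role of a base class whose `fit_model` is left unimplemented, and `MeanModel` is a toy subclass that supplies its own fit while reusing a helper property from the base.

```python
class AbstractModel:
    """Hypothetical stand-in for a base class such as AbstractLinearModel."""

    @classmethod
    def fit_model(cls, data, delay_vect):
        # The base class declares the interface but provides no fit
        raise NotImplementedError("Subclasses must implement fit_model")

    @property
    def nsignals(self):
        # Helper shared by every subclass, derived from stored parameters
        return len(self.parameters)


class MeanModel(AbstractModel):
    """Toy subclass: 'fits' a model by storing the per-channel mean."""

    @classmethod
    def fit_model(cls, data, delay_vect):
        ret = cls()  # empty instance, populated below
        ret.delay_vect = list(delay_vect)
        ret.parameters = [sum(ch) / len(ch) for ch in data]
        return ret


model = MeanModel.fit_model([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]], range(4))
print(model.parameters, model.nsignals)  # [2.0, 6.0] 2
```

Because `fit_model` is a classmethod, it is called on the class itself and acts as a constructor that returns a populated instance.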

Here is our custom model fit class; each section is described in the comments in the code.

# Define a new class inheriting from the SAILS base model
class TSALinearModel( sails.AbstractLinearModel ):

    # Define a fit_model method using the python @classmethod decorator
    @classmethod
    def fit_model( cls, data, delay_vect):

        # Some sanity checking of the input matrix. We make sure that the input
        # data is in 2d format or 3d format with a single epoch.
        # statsmodels.tsa doesn't currently support fitting multitrial data
        if data.ndim == 3:
            if data.shape[2] > 1:
                raise ValueError('This class is only implemented for single-trial data')
            # Take first trial if we have 3d data
            data = data[:,:,0]

        # Create object - classmethods act as object constructors. cls points
        # to TSALinearModel and ret is a specific instance of TSALinearModel
        # though it is currently empty.
        ret = cls()

        # The rest of this method will populate ret with information based on
        # our model fit

        # Import the model fit function
        from statsmodels.tsa.api import VAR

        # Initialise and fit model - we use a simple VAR with default options.
        # Note that we return the model and results to ret.tsa_model.  This
        # means that the statsmodels.tsa.api.VAR object will be stored and
        # returned with ret. Later we can use this to access the statsmodels
        # model and results through the SAILS object.
        ret.tsa_model = VAR(data.T) # SAILS assumes channels first, TSA samples first
        model_order = len(delay_vect) - 1 # delay_vect includes a leading zero
        ret.tsa_results = ret.tsa_model.fit(model_order)

        # The method must assign the following values for SAILS metrics to work properly
        ret.maxorder = model_order
        ret.delay_vect = np.arange(model_order + 1) # include the leading zero lag
        ret.parameters = np.concatenate((-np.eye(data.shape[0])[:,:,None],
                                          ret.tsa_results.coefs.transpose((1,2,0))), axis=2)
        ret.data_cov = sails.find_cov(data.T,data.T)
        ret.resid_cov = sails.find_cov(ret.tsa_results.resid.T,ret.tsa_results.resid.T)

        # Return fitted model within an instance of a TSALinearModel
        return ret
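
The line that builds `ret.parameters` rearranges array axes, which is easier to follow with concrete shapes. The sketch below uses a made-up `coefs` array in place of `ret.tsa_results.coefs`, assuming the statsmodels layout of [num_lags x num_channels x num_channels]; the names `lagged` and `zero_lag` are illustrative only.

```python
import numpy as np

num_lags, num_channels = 3, 2

# Stand-in for ret.tsa_results.coefs, assumed to be laid out as
# [num_lags x num_channels x num_channels]
coefs = np.arange(num_lags * num_channels * num_channels, dtype=float)
coefs = coefs.reshape(num_lags, num_channels, num_channels)

# Move the lag axis to the end: [num_channels x num_channels x num_lags]
lagged = coefs.transpose((1, 2, 0))

# Prepend the zero-lag term, a negated identity matrix as in the class above
zero_lag = -np.eye(num_channels)[:, :, None]
parameters = np.concatenate((zero_lag, lagged), axis=2)

print(parameters.shape)  # (2, 2, 4)
```

The final axis has one more entry than the number of fitted lags because of the prepended zero-lag matrix.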

It is crucial that the fit_model method returns an instance of the overall class. This instance must contain the following information. Other functions in SAILS assume that these values are stored in a fitted model class with specific formats and names.

  • maxorder: the model order of the fit
  • delay_vect: the vector of delays used in the model fit
  • parameters: the fitted autoregressive parameters of shape [num_channels x num_channels x model_order] with a leading identity
  • data_cov: the covariance matrix of the fitted data
  • resid_cov: the covariance matrix of the residuals of the fit

Other data can be added in as well (we store tsa_model and tsa_results in the example here) but these five must be defined within the returned class.
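
When writing a custom class it can help to confirm that all five values were set before passing the model on. The helper below, `missing_attributes`, is a hypothetical convenience and not part of SAILS; it only checks that each name is present and not `None`, and makes no claim about shapes.

```python
from types import SimpleNamespace

# The five names listed above
REQUIRED = ("maxorder", "delay_vect", "parameters", "data_cov", "resid_cov")


def missing_attributes(model):
    """Return the required names that are absent or None on `model`."""
    return [name for name in REQUIRED if getattr(model, name, None) is None]


# Stand-in object with only some of the attributes set
partial = SimpleNamespace(maxorder=3, delay_vect=[0, 1, 2, 3])
print(missing_attributes(partial))  # ['parameters', 'data_cov', 'resid_cov']
```

An empty list from `missing_attributes(tsa_model)` would indicate that the fitted object defines every required name.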

We can now fit a model using our new class:

tsa_model = TSALinearModel.fit_model(X,np.arange(4))

Finally, we compute connectivity metrics from each model fit and plot a comparison:

freq_vect = np.linspace(0,sample_rate/2)

sails_metrics = sails.FourierMvarMetrics.initialise(sails_model,sample_rate,freq_vect)
tsa_metrics = sails.FourierMvarMetrics.initialise(tsa_model,sample_rate,freq_vect)

PDC = np.concatenate( (sails_metrics.partial_directed_coherence,tsa_metrics.partial_directed_coherence),axis=3)

sails.plotting.plot_vector(PDC,freq_vect,line_labels=['SAILS','TSA'],diag=True,x_label='Frequency (Hz)')

We see that the partial directed coherence from the two models is nearly identical.
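
The visual comparison can be backed by a number. As a rough sketch, the hypothetical helper `max_abs_difference` below reports the largest element-wise gap between two arrays of equal shape; the toy arrays stand in for the two partial directed coherence estimates.

```python
import numpy as np


def max_abs_difference(a, b):
    """Largest absolute element-wise difference between two equal-shape arrays."""
    a, b = np.asarray(a), np.asarray(b)
    if a.shape != b.shape:
        raise ValueError("Arrays must have the same shape")
    return float(np.max(np.abs(a - b)))


# Toy arrays standing in for the two partial directed coherence estimates
pdc_a = np.array([0.10, 0.50, 0.90])
pdc_b = np.array([0.10, 0.52, 0.90])
print(max_abs_difference(pdc_a, pdc_b))
```

With the objects from this tutorial, the two inputs would be sails_metrics.partial_directed_coherence and tsa_metrics.partial_directed_coherence.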