How do populations of neurons encode movement?#
This notebook will help you regenerate some of the figures in Churchland & Cunningham et al., 2012, as well as understand how changing parameters in the analysis changes these visualizations (and how we interpret them).
Setup#
First, we’ll clone and install a super useful jPCA package, written by Benjamin Antin at Columbia University. If you’ve already installed this package, the cell below will simply import it.
try:
    import jPCA
    print('jPCA already installed.')
except ImportError:
    !git clone https://github.com/bantin/jPCA.git
    %cd jPCA
    !pip install .
    %cd ..
    import jPCA
    print('jPCA installed successfully.')
jPCA already installed.
After installing and/or importing the jPCA package, we also need to import numpy and matplotlib.pyplot.
import numpy as np
import matplotlib.pyplot as plt
print('Packages imported.')
Packages imported.
Load data and inspect it#
Below, we’ll use the helper function load_churchland_data to load the example data from Churchland & Cunningham et al. (2012). The raw recordings are also available on the DANDI Archive (dandiset 000070, from the Shenoy lab); the first cell below looks up the S3 URL for one session from monkey Jenkins. After that, we’ll download the pre-processed exampleData.mat file and take a look at one of the returned objects, data.
from dandi.dandiapi import DandiAPIClient

dandiset_id = '000070'  # ephys dataset from the Shenoy lab
filepath = 'sub-Jenkins/sub-Jenkins_ses-20090916_behavior+ecephys.nwb'  # one file from one monkey

with DandiAPIClient() as client:
    asset = client.get_dandiset(dandiset_id, 'draft').get_asset_by_path(filepath)
    s3_url = asset.get_content_url(follow_redirects=1, strip_query=True)
    print(s3_url)
https://dandiarchive.s3.amazonaws.com/blobs/2dd/66a/2dd66ad9-e7d9-4b96-9de1-4a90ca416048
import os
import urllib.request

import jPCA
from jPCA.util import load_churchland_data, plot_projections

# Check if the data file exists locally; otherwise, download it
data_file = 'exampleData.mat'

if not os.path.exists(data_file):
    data_url = 'https://github.com/nwb4edu/development/blob/0f181c8092d79278fcb0320d9f53bc33fbb0df85/exampleData.mat?raw=true'
    path, headers = urllib.request.urlretrieve(data_url, data_file)
else:
    path = data_file
data, times = load_churchland_data(path)
print(len(data)) # Show the length of data
data[:10] # Look at the first 10 entries in data
108
[array([[1.60276178, 3.62279645, 1.51476444, ..., 5.60466456, 8.12611812, 4.51181022],
        ...,
        [4.81611363, 0.71241948, 0.        , ..., 1.11870277, 3.78184694, 2.23669761]],
       shape=(61, 218)),
 ...]
(Output truncated: data[:10] is a list of ten arrays, each a time x neuron matrix with shape (61, 218).)
Hmm, why is data a list of 108 arrays? If we read the paper carefully, it says:
“In the ‘maze task’ monkeys J and N made both straight reaches and reaches that curved around one or more intervening barriers. This task was beneficial because of the large variety of different reaches that could be evoked. Typically we used 27 conditions: each providing a particular arrangement of target and barriers. Monkey J performed the task for four different sets of 27 conditions, resulting in four datasets (J1 through J4). For the monkey J-array and N-array datasets, 108 conditions were presented in the same recording session.”
There are four different sets of 27 conditions (giving 108 total), each a particular arrangement of a target and barriers. Each of the arrays in the list contains a time x neuron matrix for one condition (hereafter, for clarity, we’ll call these trials). We can inspect the first array to see how many timepoints and neurons there are.
data[0].shape
(61, 218)
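Every condition should share this same time x neuron shape. As a quick sanity check (a minimal sketch using only the data list we loaded above), we can collect the set of shapes across all 108 conditions:
# Collect the set of shapes across all conditions; we expect a single entry, (61, 218)
shapes = {condition.shape for condition in data}
print(shapes)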
print(len(times)) # Show the length of time
times[:10] # Look at the first 10 entries in time
61
[np.int16(-50),
np.int16(-40),
np.int16(-30),
np.int16(-20),
np.int16(-10),
np.int16(0),
np.int16(10),
np.int16(20),
np.int16(30),
np.int16(40)]
The meaning of the times values is a bit trickier to derive from the paper, but the documentation for exampleData.mat tells us that these times are in milliseconds (ms) and start 50 ms before the reach onset at time 0. Firing rates were sampled every 10 ms, which is why these timestamps are 10 ms apart.
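We can confirm this directly from the times array itself (a quick sketch using numpy, which we imported above):
# Check the sampling interval and the time range covered by each trial
print(np.unique(np.diff(times)))  # step size(s) in ms; we expect [10]
print(times[0], times[-1])        # first and last timestamps in ms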
Visually explore the data#
Now that we understand the structure of the data, we can take a look at some of these trials. Below, we’ll create a firing rate plot for one trial at a time. In this plot, each neuron will be a different color.
Change the index into data below (keeping in mind how many conditions we have in total, as well as 0-indexing in Python!) to look through a few different trials for patterns.
trial = data[25] # You can change the index here if you'd like to look through different trials.
# Plot each column of 'trial' as a firing rate trace over time,
# with a dashed gray line marking reach onset at time 0
def plotFiringRates(times, trial):
    plt.plot(times, trial)
    plt.xlabel('Time (ms)')
    plt.ylabel('Firing Rate (Hz)')
    plt.vlines(0, np.min(trial), np.max(trial), color='gray', linestyle='--')
    plt.show()
plotFiringRates(times,trial)
As you can appreciate from this plot, it is hard to make much sense of what this population of neurons is doing. Some neurons respond during the plotted condition, while others do not.
We could also investigate what happens if we average across all of the neurons within each trial, and then plot the resulting traces to compare trials against one another. In this plot, each line is a different condition.
# Average across neurons (axis 1) within each condition to get one trace per condition
all_trials = np.zeros((61, 108))

for idx, this_trial in enumerate(data):
    trial_average = np.mean(this_trial, axis=1)
    all_trials[:, idx] = trial_average
plotFiringRates(times,all_trials)
This is also not particularly informative! Clearly, neural activity increases overall when a monkey initiates a reach, but what does this tell us about neural coding?
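As a rough check on that overall increase, we can compare the mean rate (averaged over neurons and conditions) before and after reach onset at time 0. This is a minimal sketch using the all_trials array we computed above; the exact numbers will depend on the window you choose.
# Compare the mean population firing rate before vs. after reach onset (time 0)
times_arr = np.array(times)
pre_reach = all_trials[times_arr < 0, :].mean()
post_reach = all_trials[times_arr >= 0, :].mean()
print(f'Mean rate before reach onset: {pre_reach:.2f} Hz')
print(f'Mean rate after reach onset:  {post_reach:.2f} Hz')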
Let’s do one final check, where we look at a single cell across multiple conditions. This is similar to Figure 2 of the paper, except we won’t color these by preparatory activity.
Change the value of neuron below (keeping in mind how many neurons we have in total, as well as 0-indexing in Python!) to look through a few neurons for patterns.
neuron = 200
# Pull out this neuron's firing rate trace from every condition
trials_for_neuron = np.zeros((61, 108))

for idx, this_trial in enumerate(data):
    trials_for_neuron[:, idx] = this_trial[:, neuron]
plotFiringRates(times,trials_for_neuron)
In summary, we can take a few things away from the data exploration above:
Neural activity increases overall when the monkey reaches.
Individual neurons show different patterns of responses to reaches. Some neurons increase their firing rate for any reach, while others are more selective. The timing of their firing rate changes relative to the reach also varies.
jPCA#
Okay, enough of the old school, simple analyses! Below, we’ll use jPCA to retain information from each individual neuron and each condition, looking at how the neural population moves through state space over time. To create our jPCA plots, we’ll use the jPCA package.
# Create a jPCA object
jpca = jPCA.JPCA(num_jpcs=6)
# Fit the jPCA object to data
projected, full_data_var, pca_var_capt, jpca_var_capt = jpca.fit(data, times=times, tstart=-50, tend=150)
# Plot the projected data
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
plot_projections(projected, axis=axes[0], x_idx=0, y_idx=1) # Plot the first jPCA plane
plot_projections(projected, axis=axes[1], x_idx=2, y_idx=3) # Plot the second jPCA plane
axes[0].set_title('jPCA Plane 1')
axes[1].set_title('jPCA Plane 2')
plt.tight_layout()
plt.show()
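The fit call above also returned several variance measures. Printing them gives a sense of how much of the population activity the PCA preprocessing and the jPCA planes capture (a short sketch using the variable names from the cell above; the exact values depend on the fitting window):
# Print the variance-related values returned by jpca.fit above
print('full_data_var:', full_data_var)
print('pca_var_capt: ', pca_var_capt)
print('jpca_var_capt:', jpca_var_capt)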
Above, we’ve generated figures very similar to Figure 3 of the paper. However, the utility of working with these data live is that we can change features of the analysis and observe how this changes the plots (and our interpretation of the data).
By default, the jPCA is fit to the data from -50 ms to 150 ms. Change tend to include more of the data in the analysis, and observe how this changes the jPCA plots.
Bonus: What dictates the coloring of the lines here? Dig into jPCA.util to find out.
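For example, a re-fit over a longer window might look like the sketch below. Here tend=550 is an assumption based on the last timestamp in times; adjust it to whatever window you want to explore.
# Re-fit jPCA over a longer window and re-plot both planes
# (tend=550 assumes times runs from -50 ms to 550 ms; change it to explore other windows)
jpca_long = jPCA.JPCA(num_jpcs=6)
projected_long, _, _, _ = jpca_long.fit(data, times=times, tstart=-50, tend=550)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
plot_projections(projected_long, axis=axes[0], x_idx=0, y_idx=1)  # first jPCA plane
plot_projections(projected_long, axis=axes[1], x_idx=2, y_idx=3)  # second jPCA plane
axes[0].set_title('jPCA Plane 1 (longer window)')
axes[1].set_title('jPCA Plane 2 (longer window)')
plt.tight_layout()
plt.show()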