2020-02-24-what-does-roll-do-in-numpy-and-pytorch

What does roll do in numpy and pytorch?

Rolling in numpy or pytorch shifts the data circularly along a particular axis: elements pushed off one end of the axis reappear at the other. Let's look at some examples.
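For instance, on a one-dimensional array, a positive shift moves elements to the right and a negative shift moves them to the left:

```python
import numpy as np

a = np.arange(5)          # [0, 1, 2, 3, 4]
print(np.roll(a, 1))      # shift right by 1: [4, 0, 1, 2, 3]
print(np.roll(a, -2))     # shift left by 2:  [2, 3, 4, 0, 1]
```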

In [1]:
import numpy as np
import torch

Roll in Numpy

Suppose we have a 6x8 matrix in numpy. For clarity, we can initialize it so each column holds its own index.

In [2]:
state = np.zeros((6, 8))
for i in range(6):
    for j in range(8):
        state[i][j] = j
state
Out[2]:
array([[0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.]])

We can then roll it leftward (indicated by the shift of -1) along the horizontal axis (axis 1). The zeroes which would have been shifted off the beginning of each row are now at the end.

In [3]:
state = np.roll(state, -1, axis=1); state
Out[3]:
array([[1., 2., 3., 4., 5., 6., 7., 0.],
       [1., 2., 3., 4., 5., 6., 7., 0.],
       [1., 2., 3., 4., 5., 6., 7., 0.],
       [1., 2., 3., 4., 5., 6., 7., 0.],
       [1., 2., 3., 4., 5., 6., 7., 0.],
       [1., 2., 3., 4., 5., 6., 7., 0.]])

Why might we want to do this? Suppose this is some sort of time series data and we are getting rid of the oldest elements, which are kept on the left. We could then overwrite the last column with the latest data like so:

In [4]:
for i in range(6):
    state[i][-1] = 8
state
Out[4]:
array([[1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.]])

The zeroes were the oldest, so we got rid of them, and now we have the newest "8" data on the right.

I've seen this done sometimes in reinforcement learning settings, where someone wants to build up state from a stack of recent time frames.

You can see an example of it using numpy here.
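As a sketch of that frame-stacking pattern (`push_frame` is my own illustrative helper here, not from any particular library):

```python
import numpy as np

def push_frame(stack, frame):
    """Shift the frame stack one step along axis 0 and append the newest frame.

    stack: array of shape (n_frames, ...); frame: array matching stack[0].
    """
    stack = np.roll(stack, -1, axis=0)  # oldest frame rotates to the end...
    stack[-1] = frame                   # ...and is immediately overwritten
    return stack

stack = np.zeros((4, 2))
for t in range(1, 6):
    stack = push_frame(stack, np.full(2, t))
print(stack)
```

After the five pushes, the four-slot stack holds frames 2 through 5: the oldest frames 0 and 1 have been rolled off and overwritten.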

Roll in Pytorch

We can do something similar in pytorch. Let's reinitialize the state as a pytorch tensor and see an example.

In [5]:
state = torch.zeros((6, 8))
for i in range(6):
    for j in range(8):
        state[i, j] = j
state
Out[5]:
tensor([[0., 1., 2., 3., 4., 5., 6., 7.],
        [0., 1., 2., 3., 4., 5., 6., 7.],
        [0., 1., 2., 3., 4., 5., 6., 7.],
        [0., 1., 2., 3., 4., 5., 6., 7.],
        [0., 1., 2., 3., 4., 5., 6., 7.],
        [0., 1., 2., 3., 4., 5., 6., 7.]])

The syntax is slightly different (torch.roll takes a dims argument instead of numpy's axis), but the effect is the same.

In [6]:
state = torch.roll(state, -1, dims=-1)
state
Out[6]:
tensor([[1., 2., 3., 4., 5., 6., 7., 0.],
        [1., 2., 3., 4., 5., 6., 7., 0.],
        [1., 2., 3., 4., 5., 6., 7., 0.],
        [1., 2., 3., 4., 5., 6., 7., 0.],
        [1., 2., 3., 4., 5., 6., 7., 0.],
        [1., 2., 3., 4., 5., 6., 7., 0.]])

Similarly, if we were working with time series state data, we could overwrite the last column with something new.

In [7]:
for i in range(6):
    state[i][-1] = 8
state
Out[7]:
tensor([[1., 2., 3., 4., 5., 6., 7., 8.],
        [1., 2., 3., 4., 5., 6., 7., 8.],
        [1., 2., 3., 4., 5., 6., 7., 8.],
        [1., 2., 3., 4., 5., 6., 7., 8.],
        [1., 2., 3., 4., 5., 6., 7., 8.],
        [1., 2., 3., 4., 5., 6., 7., 8.]])

Here is a pytorch implementation example of the previous code, doing the same thing to the state.

Rolling Manually in numpy

I've also seen folks use a similar strategy, but with manual slicing instead of roll. Here is an example of that in numpy.

In [8]:
state = np.zeros((6, 8))
for i in range(6):
    for j in range(8):
        state[i][j] = j
state
Out[8]:
array([[0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.],
       [0., 1., 2., 3., 4., 5., 6., 7.]])

Instead of rolling, we just replace part of the dataset with a shifted copy of itself. Notice that without the roll function, nothing wraps around: the last column simply keeps a duplicate of its previous value.

In [9]:
state[:, :-1] = state[:, 1:]
state
Out[9]:
array([[1., 2., 3., 4., 5., 6., 7., 7.],
       [1., 2., 3., 4., 5., 6., 7., 7.],
       [1., 2., 3., 4., 5., 6., 7., 7.],
       [1., 2., 3., 4., 5., 6., 7., 7.],
       [1., 2., 3., 4., 5., 6., 7., 7.],
       [1., 2., 3., 4., 5., 6., 7., 7.]])

If we are just going to update those fields anyway, however, perhaps we don't care about that.

In [10]:
for i in range(6):
    state[i][-1] = 8
state
Out[10]:
array([[1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.],
       [1., 2., 3., 4., 5., 6., 7., 8.]])

Here's a reinforcement learning example of formatting the state data using the manual strategy.
