Using torch.index_select, torch.gather and torch.take
In some situations, you’ll need to do some advanced indexing / selection with PyTorch, e.g. to answer the question: “how can I select elements from Tensor A following the indices specified in Tensor B?”
In this post we’ll present the three most common methods for such tasks, namely torch.index_select, torch.gather and torch.take. We’ll explain all of them in detail and contrast them with one another.
Admittedly, one motivation for this post was me forgetting how and when to use which function, and ending up googling, browsing Stack Overflow and reading the official documentation, which in my opinion is relatively brief and not too helpful. So here we do a deep dive into these functions: we motivate when to use which, give examples in 2D and 3D, and show the resulting selection graphically.
I hope this post will bring clarity about said functions and remove the need for further exploration — thanks for reading!
And now, without further ado, let’s dive into the functions one by one. For each, we first start with a 2D example and visualize the resulting selection, and then move to a somewhat more complex example in 3D. Further, we re-implement the executed operation in simple Python, so that you can look at pseudocode as another source of information about what these functions do. In the end, we summarize the functions and their differences in a table.
torch.index_select selects elements along one dimension, while keeping the other ones unchanged. That is: keep all elements from all other dimensions, but pick elements in the target dimension following the index tensor. Let’s demonstrate this with a 2D example, in which we select along dimension 1:
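import torch

len_dim_0, len_dim_1 = 3, 4  # example sizes, chosen only for illustration
num_picks = 2
values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# picked has shape [len_dim_0, num_picks]
picked = torch.index_select(values, 1, indices)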
The resulting tensor has shape [len_dim_0, num_picks]: for every index along dimension 0 (every row), we have picked the same elements from dimension 1 (the same columns). Let’s visualize this:
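With small fixed values, the selection can also be read off directly from the numbers. The following is a minimal sketch, not part of the original example: the tensor contents and the index values [2, 0] are made up for illustration. It also checks that, for this 2D case, the result matches plain indexing with values[:, indices].

```python
import torch

# Small fixed tensor (illustrative values) so the result is easy to read:
# the tens digit encodes the row, the ones digit encodes the column.
values = torch.tensor([[10, 11, 12, 13],
                       [20, 21, 22, 23],
                       [30, 31, 32, 33]])

# Pick columns 2 and 0, in that order (assumed example indices).
indices = torch.tensor([2, 0])

# Select along dimension 1; dimension 0 is kept in full.
picked = torch.index_select(values, 1, indices)
print(picked)
# tensor([[12, 10],
#         [22, 20],
#         [32, 30]])

# The same selection written with plain indexing along dimension 1.
same = values[:, indices]
assert torch.equal(picked, same)
```

Note that the order of the picked columns follows the order of the index tensor, and that an index may be repeated, in which case the corresponding column appears more than once.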