
How does Pytorch handle in-place operations without losing information necessary for backpropagation? - Stack Overflow


I've seen several questions about the efficiency of in-place operations, but I'm actually more confused about the inner workings of PyTorch.

Let's take a simple example like

a = torch.randn(10)
b = torch.randn(10)
c = torch.randn(10)

x1 = a * b
x2 = x1 * c 

In this case, things are easy. Backpropagation happens like this:

x2.grad <- 1
c.grad <- x2.grad * x1 = x1 = a * b
x1.grad <- x2.grad * c = c
b.grad <- x1.grad * a = c * a
a.grad <- x1.grad * b = c * b
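
(As a quick sanity check, the chain above can be verified with autograd itself; this is just a sketch that assumes elementwise tensors and seeds the backward pass with a vector of ones, mirroring x2.grad <- 1:)

import torch

a = torch.randn(10, requires_grad=True)
b = torch.randn(10, requires_grad=True)
c = torch.randn(10, requires_grad=True)

x1 = a * b
x2 = x1 * c
x2.backward(torch.ones_like(x2)) # plays the role of x2.grad <- 1

print(torch.allclose(c.grad, a * b)) # True
print(torch.allclose(b.grad, c * a)) # True
print(torch.allclose(a.grad, c * b)) # True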

Everything works correctly. However, in this scenario we are allocating two buffers: x1 and x2. Now, what happens when we do something like this:

x = a * b
x = x * c 

It seems to me that the overall expression is the same. However, if we try to compute the gradients the same way we did before, we will run into the following problem:

x.grad <- 1
c.grad <- x.grad * x = x = a * b * c

Uh oh, we already have a mistake. Since we performed the multiplication with c in place, we lost the buffer containing a * b, which was needed to calculate the gradient of c.

I can imagine two possible solutions to this:

  1. The previous code actually gets 'compiled' into something like 'x = a * b * c'. But I feel like this kind of optimization might fail when we try more complicated expressions.

  2. The previous code actually uses intermediate buffers (like x1) instead.

However, in this case, what happens if we try to compute something like:

x = a * b
x *= c
x *= d
x *= e
x *= f 

Does pytorch actually create 3 temporary buffers (x1, x2 and x3)?

How is this kind of problem solved in modern frameworks?



1 Answer


You cannot backprop through an in-place operation. However, there are cases where it can appear possible because the in-place operation happened in a way that did not impact the gradient calculation.

First, let's distinguish between what is happening in Python vs. PyTorch. PyTorch operations return new tensor objects. Your example of reusing the x variable is not a problem, since x simply ends up pointing to two different tensors over time. You can check this with the data_ptr() method:

a = torch.randn(10)
b = torch.randn(10)
c = torch.randn(10)

x = a * b
print(x.data_ptr())
> 605338432 # `x` points to 605338432
x = x * c 
print(x.data_ptr())
> 605338624 # `x` points to 605338624, a different object

Now to the PyTorch level. When you compute values, PyTorch saves only the tensors necessary for backprop. If one of these saved values is then modified by an in-place operation, you get an error at backward time. You can, however, apply in-place operations to values that are not required for the gradient calculation.

Consider y = exp(x). In this case, dy/dx = exp(x) = y. For backprop, PyTorch stores the value of y to use when the gradient of x is computed. This means that if you modify y with an in-place operation and try to backprop, you'll get an error:

x = torch.randn(10, requires_grad=True)
y = x.exp() # computing `y=exp(x)` saves `y` for backprop
y.add_(10) # in-place addition
y.backward(torch.ones_like(y)) # backprop throws error
> RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
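
Under the hood, this is detected with a per-tensor version counter: every in-place op bumps it, and at backward time autograd compares the saved tensor's recorded version against its current one. You can peek at the counter through the internal _version attribute (an implementation detail, so treat this sketch as illustrative only):

import torch

x = torch.randn(10, requires_grad=True)
y = x.exp()       # autograd saves `y` and records its current version
print(y._version) # 0: no in-place modification yet
y.add_(10)        # the in-place op bumps the counter
print(y._version) # 1: backward sees the mismatch and raises the error above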

Now consider y = x ** 2. In this case, dy/dx = 2 * x. For backprop, PyTorch stores the value of x. Here, y is not involved in the gradient computation, which means you can modify y with an in-place operation:

x = torch.randn(10, requires_grad=True)
y = x.pow(2) # computing `y = x**2` saves `x` for backprop
y.add_(10) # in-place addition
y.backward(torch.ones_like(y)) # backprop runs without error
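
If you are curious about what a given op saved, you can inspect the _saved_... attributes on the grad_fn node (these are internal and version-dependent, so the exact names may vary between PyTorch releases):

import torch

x = torch.randn(10, requires_grad=True)

y = x.exp()
print(y.grad_fn._saved_result.data_ptr() == y.data_ptr()) # True: exp saved its output

z = x.pow(2)
print(z.grad_fn._saved_self.data_ptr() == x.data_ptr())   # True: pow saved its input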

In general you should avoid using in-place operations for anything you want to backprop through.
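
To tie this back to your example: x = a * b followed by x = x * c is plain Python rebinding, not an in-place op, so autograd keeps whatever intermediates it needs and the chain backprops fine. A minimal sketch, assuming for brevity that only a requires grad:

import torch

a = torch.randn(10, requires_grad=True)
b = torch.randn(10)
c = torch.randn(10)

x = a * b # new tensor; autograd saves `b` to compute a's gradient
x = x * c # another new tensor; only the Python name `x` is rebound, the graph is untouched
x.backward(torch.ones_like(x))
print(torch.allclose(a.grad, b * c)) # True: d(a*b*c)/da = b*c

A genuinely in-place chain (x.mul_(c), x.mul_(d), and so on) may also get through backward in this particular setup, since each mul_ only needs to save the multiplier to compute x's gradient; but as soon as one of the multipliers itself requires grad, the overwritten buffer becomes a saved value and the version check fires.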
