I am training a ProGAN based on this GitHub implementation. For those of you not familiar with it, don't worry: the network architecture won't play a serious role here.
I have an input convolutional layer that ends up with NaN weights after a bit of training. I set the seed to 0 for reproducibility, and the NaNs appear at epoch 780. So I trained for 779 epochs, saved the "pre-NaN" weights, and am now experimenting to see what is wrong. At this step, regardless of the input, I still get NaN gradients (and therefore NaN weights after one training step), but I really can't find out why.
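To pin down where the NaNs first appear, the step can be run under autograd anomaly detection, which makes the backward pass raise an error naming the first op that produces NaN. A minimal sketch, reusing critic, opt_critic, cur_step, and input from the snippet further down:

import torch

# Make autograd raise a RuntimeError on the first backward op that produces NaN
torch.autograd.set_detect_anomaly(True)

x = critic.rgb_layers[cur_step](input)
loss = torch.mean(x)
opt_critic.zero_grad()
loss.backward()  # fails here, naming the offending function, if a NaN is produced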
The convolution is defined as follows:
import torch.nn as nn

class WSConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # Equalized learning rate: He-style constant applied at runtime, not at init
        self.scale = (gain / (in_channels * kernel_size * kernel_size)) ** 0.5
        # Keep the bias outside the conv so it is not multiplied by the scale
        self.bias = self.conv.bias
        self.conv.bias = None
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
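For context, this is ProGAN's equalized-learning-rate trick: the weights are drawn from a standard normal, and the He-style constant sqrt(gain / fan_in) is applied to the input at runtime instead of being folded into the initialization. A quick standalone sanity check (the channel sizes are hypothetical, chosen to match the shapes reported below):

import torch

layer = WSConv2d(in_channels=8, out_channels=512, kernel_size=1)
out = layer(torch.randn(16, 8, 4, 4))
print(out.shape)    # torch.Size([16, 512, 4, 4])
print(layer.scale)  # 0.5, i.e. (2 / (8 * 1 * 1)) ** 0.5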
The shape of the input is torch.Size([16, 8, 4, 4])
The shape of the convolution's weights is torch.Size([512, 8, 1, 1]).
The shape of the bias is torch.Size([512]).
The scale is 0.5, which matches sqrt(gain / fan_in) = (2 / (8 * 1 * 1)) ** 0.5.
There are no NaN values in any of them.
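A check along these lines confirms that, with layer being critic.rgb_layers[cur_step]:

layer = critic.rgb_layers[cur_step]
print(torch.isnan(layer.conv.weight).any())  # tensor(False)
print(torch.isnan(layer.bias).any())         # tensor(False)
print(torch.isnan(input).any())              # tensor(False)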
Here is the code that turns all of the weights and biases to NaN:
# Load the weights saved just before the NaNs appeared (epoch 779)
critic.load_state_dict(torch.load('test.pth'))

# The layers are in descending order, so 6 is the input layer
cur_step = 6

# This is just the convolution defined above
x = critic.rgb_layers[cur_step](input)
loss = torch.mean(x)

opt_critic.zero_grad()
loss.backward()
opt_critic.step()
The loss is around 0.1322, depending on the input.
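Inspecting the gradients right after loss.backward() but before opt_critic.step() shows where things go wrong; a sketch of that check:

for name, p in critic.named_parameters():
    if p.grad is not None and torch.isnan(p.grad).any():
        print(name, p.grad.shape)  # every parameter whose gradient contains NaN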