General question (hopefully useful for people coming from Google): what do you do when the gradient explodes? When working with transformers and deep NNs (in PyTorch), do you have a mental checklist of things to investigate and try when your gradient explodes (and your loss becomes NaN)?
More context (my specific situation): I am training a model that takes chemical formulas as input, represents the elements as embedding vectors, and feeds the resulting sequences of embedding vectors through several transformer blocks until an MLP finally predicts a numerical property of the compound with that chemical formula.
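For reference, the inputs are built roughly like this (simplified sketch; the symbol-to-id table and the padding length here are only illustrative, not my actual values):

import torch

# Illustrative subset of the element-symbol -> integer-id table;
# id 0 is reserved for padding (the forward pass masks element_ids == 0).
ELEMENT_TO_ID = {"H": 1, "C": 2, "O": 3, "Ti": 4}

def formula_to_ids(elements, max_len=16):
    """Map a parsed formula (list of element symbols) to a padded id tensor."""
    ids = [ELEMENT_TO_ID[e] for e in elements]
    ids += [0] * (max_len - len(ids))  # pad with 0 up to max_len
    return torch.tensor(ids, dtype=torch.long)

# e.g. TiO2 and H2O as element sequences
element_ids = torch.stack([formula_to_ids(["Ti", "O", "O"]),
                           formula_to_ids(["H", "H", "O"])])  # shape (batch, max_len)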
Structure:
import torch.nn as nn

# ElementEmbedding, SelfAttentionBlock, MotifDiscovery, HierarchicalAggregation
# and PredictionMLP are my own modules (definitions omitted here).
class BandgapPredictionModel(nn.Module):
    def __init__(self, num_elements, embedding_dim, num_heads, num_layers, num_queries):
        super(BandgapPredictionModel, self).__init__()
        self.element_embedding = ElementEmbedding(num_elements, embedding_dim)
        self.attention_block = SelfAttentionBlock(embedding_dim, num_heads, num_layers)
        self.motif_discovery = MotifDiscovery(embedding_dim, num_queries)
        self.aggregation = HierarchicalAggregation(embedding_dim)
        self.prediction = PredictionMLP(embedding_dim)

    def forward(self, element_ids):
        embeddings = self.element_embedding(element_ids)  # Step 1: element ids -> embeddings
        mask = (element_ids == 0)  # padding mask
        attended_elements = self.attention_block(embeddings, src_key_padding_mask=mask)  # Step 2
        motifs = self.motif_discovery(attended_elements)  # Step 3: attention block (5 queries)
        # Step 4: aggregation and attention of attended elemental embeddings and motif embeddings
        global_representation = self.aggregation(motifs)
        bandgap = self.prediction(global_representation).squeeze(-1)  # Step 5
        return bandgap
and instantiated as follows (trained with a learning rate of 0.0001):
BandgapPredictionModel(num_elements=118, embedding_dim=32,
num_heads=4, num_layers=3, num_queries=5)
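For completeness, the training step looks roughly like this (simplified sketch: train_loader is a placeholder for my DataLoader, and Adam/MSELoss are the kind of optimizer/loss I mean, not necessarily the exact objects in my script):

import torch
import torch.nn as nn

model = BandgapPredictionModel(num_elements=118, embedding_dim=32,
                               num_heads=4, num_layers=3, num_queries=5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

for element_ids, targets in train_loader:  # placeholder DataLoader yielding (ids, bandgaps)
    optimizer.zero_grad()
    predictions = model(element_ids)       # no NaN here according to the logs
    loss = criterion(predictions, targets)
    loss.backward()                        # NaN shows up in the gradients at this point
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1000000.0)
    optimizer.step()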
The model performed reasonably well on 10,000 entries (sampled from a bigger dataset with a fixed random seed), but on more data it would begin to give NaN loss. The logs show:
2025-02-06 17:59:08,309 - INFO - Predictions: no nan in predictions
2025-02-06 17:59:08,312 - INFO - NaN detected in gradient for parameter: element_embedding.embedding.weight
2025-02-06 17:59:08,313 - INFO - NaN detected in gradient for parameter: attention_block.transformer.layers.0.self_attn.in_proj_weight
2025-02-06 17:59:08,313 - INFO - NaN detected in gradient for parameter: attention_block.transformer.layers.0.self_attn.in_proj_bias
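These messages come from a per-parameter check I run after loss.backward(); simplified, it is something along these lines (the real logging setup differs slightly):

import logging
import torch

logger = logging.getLogger(__name__)

def check_gradients(model, predictions):
    # Run after loss.backward(): log which parameters received NaN gradients.
    if torch.isnan(predictions).any():
        logger.info("Predictions: NaN in predictions")
    else:
        logger.info("Predictions: no nan in predictions")
    for name, param in model.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            logger.info(f"NaN detected in gradient for parameter: {name}")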
Summary:
A complex transformer model gives NaN gradients after a while. What steps should I take?
Steps Taken
Reducing the learning rate to 0.0001 did not solve the problem.
Weight Initialization
Used the Xavier init method for the weights, roughly as in the sketch below.
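(Sketch of what I mean by Xavier init; the exact handling of biases and the embedding table in my code may differ.)

import torch.nn as nn

def init_weights(module):
    # Xavier-uniform init for linear layers and the element embedding table.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.xavier_uniform_(module.weight)

model.apply(init_weights)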
Self Attention Norm: added self.norm = nn.LayerNorm(embedding_dim) in my Self Attention block.
Gradient Clipping: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1000000.0). This helped when training with fewer samples, but with more data the model still ended up predicting NaN despite the clipping. How should I choose the norm? (See the norm-logging sketch below.)
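To get a feel for what max_norm should be: clip_grad_norm_ returns the total gradient norm computed before clipping, so it can be logged each step to see what "normal" values look like for this model (sketch; max_norm=1.0 here is only a placeholder value, not a recommendation):

# clip_grad_norm_ returns the total gradient norm *before* clipping,
# so logging it shows typical magnitudes and whether max_norm=1e6
# effectively means no clipping at all.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
logger.info(f"Total grad norm before clipping: {total_norm:.3f}")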
I hope this question respects the rules (I am new)! I am ready to edit if needed.
Final Note: This model actually worked quite well when I got it to run on 10,000 samples, so I don't think the architecture is completely stupid. [Predicted vs. Actual plot]