Snippet: PyTorch Best Practices¶
Domain Context¶
Framework-specific rules for PyTorch training, model development, and debugging.
Model Definition¶
- Inherit from
nn.Module; implementforward()only — never override__call__ - Initialize weights explicitly when not using pretrained — document initialization scheme
- Use
nn.Sequentialornn.ModuleListfor dynamic layers — never plain Python lists (breaks.parameters()) - Register buffers with
self.register_buffer()for non-parameter tensors that need device movement - Type hints on forward:
def forward(self, x: torch.Tensor) -> torch.Tensor:
Training Loop¶
- Always call:
model.train()before training,model.eval()before validation - Use
torch.no_grad()ortorch.inference_mode()during evaluation — saves memory and speeds up - Gradient clipping:
torch.nn.utils.clip_grad_norm_()— set max_norm and log when clipping occurs - Zero gradients with
optimizer.zero_grad(set_to_none=True)— more memory efficient - Mixed precision: use
torch.amp.autocast+GradScalerfor fp16; bf16 doesn't need scaling
Data Loading¶
num_workers: start with 4 per GPU; tune based on CPU/IO bottleneck profilingpin_memory=Truewhen using GPU — speeds up host-to-device transferpersistent_workers=Truefor PyTorch ≥1.7 to avoid worker restart overhead- Custom datasets: implement
__len__and__getitem__; useIterableDatasetfor streaming data - Reproducible data loading: set
worker_init_fnwith seed offset per worker
Debugging & Profiling¶
torch.autograd.set_detect_anomaly(True)during debugging — disable in production (slow)- Use
torch.profilerwith TensorBoard for bottleneck analysis - Memory profiling:
torch.cuda.memory_summary()to diagnose OOM issues - Gradient checking:
torch.autograd.gradcheck()for custom autograd functions - NaN detection: check for NaN in loss before backward pass — fail fast with informative error
Checkpointing¶
- Save:
{'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'config': config} - Load with
strict=Falsewhen architecture has changed — log mismatched/missing keys - For distributed: save only on rank 0; load with
map_locationto handle device mapping - Use
safetensorsformat for security and speed when sharing model weights
Performance Optimization¶
torch.compile()(PyTorch 2.0+): use for production models — significant speedup with minimal code change- Channels-last memory format:
model.to(memory_format=torch.channels_last)for CNN speedup - Prefer
torch.einsumortorch.matmulover manual loops - Avoid Python-level loops over tensor dimensions — use vectorized operations
- Profile before optimizing: the bottleneck is rarely where you think it is
Common Pitfalls¶
- Forgetting
model.eval()→ BatchNorm and Dropout behave incorrectly during inference .item()in training loop → forces GPU sync, kills throughput — only use for logging- Tensor on wrong device → runtime error; always use
.to(device)consistently - In-place operations (e.g.,
x += 1) can break autograd — avoid in differentiable paths - DataLoader with
shuffle=True+DistributedSampler→ double shuffling; let sampler handle it