When an operation can fail in a known way (edge-case inputs, risky math, I/O/parsing, optional capabilities), handle it explicitly:

Example (defensive guard + clean recovery):

# count the non-ignored target tokens (targets < 0 are masked) in this window
num_tokens = torch.tensor(0, device=device)
steps = [next(train_loader) for _ in range(grad_accum_steps)]  # list of (inputs, targets)
num_tokens += sum((targets >= 0).sum() for _, targets in steps)

# (optional) sync across ranks
if ddp:
    dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM)

# guard: with zero valid tokens, loss / num_tokens would divide by zero (NaN)
if num_tokens.item() == 0:
    model.zero_grad(set_to_none=True)
    # skip this accumulation window
    return

for train_inputs, train_targets in steps:
    loss = model(train_inputs, train_targets, loss_reduction='sum')
    loss = loss / num_tokens  # normalize the summed loss by the global token count
    loss.backward()

Example (don’t swallow intended errors):

# WRONG: raises, then immediately swallows every exception, including the intended one
try:
    raise RuntimeError("something important")
except Exception:
    pass

# RIGHT: either remove the raise/silencing pair entirely, or catch only the exception you expect
try:
    import torch._dynamo
    if torch._dynamo.is_compiling():
        raise RuntimeError("RoPE cache too small during torch.compile")
except ImportError:
    # only ignore the missing module case
    pass

Example (cleanup on integrity failure):

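A minimal sketch of the pattern, assuming a sha256-checked dataset file; the helper name `download_and_verify` and the `download_fn` callback are illustrative, not from any particular library:

```python
import hashlib
import os

def download_and_verify(path, expected_sha256, download_fn):
    """Download to `path`, verify its checksum, and clean up on failure.

    `download_fn(path)` stands in for whatever fetch routine writes the file.
    """
    try:
        download_fn(path)
        with open(path, "rb") as f:
            digest = hashlib.sha256(f.read()).hexdigest()
        if digest != expected_sha256:
            raise RuntimeError(f"checksum mismatch for {path}")
    except BaseException:
        # integrity failure or interrupted download: remove the partial/corrupt
        # file so the next attempt starts clean, then re-raise (don't swallow)
        if os.path.exists(path):
            os.remove(path)
        raise
    return path
```

The key point is that the `except` block cleans up and re-raises: the caller still sees the failure, but no corrupt artifact is left behind to poison a later run.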
Apply this consistently across training loops, dataset download/verification, and state load paths.