When writing PyTorch code, preserve tensor dtypes and devices and be explicit about vector/matrix shapes to avoid precision loss, incorrect broadcasting, and unnecessary CPU allocations.
Why: mixed-precision and quantized weights are common in model deployment. Unnecessary up/downcasts, implicit device transfers, and incorrect assumptions about 1D tensor shapes (a 1D tensor has shape (d,), and matmul/outer impose specific shape requirements on their operands) lead to subtle correctness and performance bugs.
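As a quick illustration of the shape semantics mentioned above (tensor names and sizes here are arbitrary, for demonstration only):

```python
import torch

d, k = 4, 3
v = torch.randn(d)             # 1D tensor, shape (d,)
W = torch.randn(d, k)          # matrix, shape (d, k)

row = torch.matmul(v, W)       # (d,) @ (d, k) contracts over d -> shape (k,)
update = torch.outer(v, row)   # outer product of (d,) and (k,) -> shape (d, k)

assert row.shape == (k,) and update.shape == (d, k)
```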
Rules (actionable):
When projecting against possibly low-precision weights, do the projection math in a safe compute dtype, and avoid redundant round-trips (downcasting and then upcasting again). Example: if refusal_directions is float32 and W is stored in bfloat16 or 4-bit, move the vector to W's device but keep its float32 dtype, and run the projection against the weight in that same compute dtype (upcasting or dequantizing the weight once if needed) rather than casting the vector down to the weight dtype and back:

```python
# v stays in its float32 compute dtype; only its device changes
v = layer_refusal_direction.to(matrix.device)
# matrix is assumed to be in the same compute dtype here (e.g. upcast/dequantized to float32)
r_transpose_W = torch.matmul(v, matrix)               # (d,) @ (d, k) -> (k,)
matrix.sub_(weight * torch.outer(v, r_transpose_W))   # rank-1 update of shape (d, k)
```
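For contrast, a minimal sketch of the redundant round-trip this rule warns against, assuming a float32 direction vector and a bfloat16 weight dtype (names are illustrative):

```python
import torch

refusal_dir = torch.randn(4096, dtype=torch.float32)  # hypothetical float32 direction
weight_dtype = torch.bfloat16                          # dtype of a low-precision weight

# Redundant round-trip: float32 -> bfloat16 -> float32 silently drops mantissa bits.
lossy = refusal_dir.to(weight_dtype).to(torch.float32)
print((refusal_dir - lossy).abs().max())               # non-zero: precision was lost

# Preferred: keep the vector in float32 and cast at most once, at the point where
# the computed update is written back to the low-precision weight.
```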
Avoid creating tensors on CPU unless necessary. Operations on existing tensors keep their results on the source tensor's device, but tensor factory calls default to CPU unless a device is passed; respect the device of source tensors to prevent implicit transfers and extra memory usage.
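A small sketch of the device-preservation point, assuming a tensor x that may live on GPU (names are illustrative):

```python
import torch

x = torch.randn(8, device="cuda" if torch.cuda.is_available() else "cpu")

# Preferred: *_like helpers and explicit device= arguments inherit x's device (and dtype).
mask = torch.zeros_like(x)                            # same device and dtype as x
ones = torch.ones(8, device=x.device, dtype=x.dtype)

# Anti-pattern: a bare factory call lands on CPU regardless of where x lives, so
# combining it with a GPU tensor later raises a device-mismatch error or forces
# an extra .to(x.device) copy.
cpu_ones = torch.ones(8)
```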
Example (combined pattern):
```python
# r is (d,) and already float32; matrix is assumed to be in the same compute dtype
r_device = r.to(matrix.device)
# r_device: (d,), matrix: (d, k) -> torch.matmul yields (k,)
r_transpose_W = torch.matmul(r_device, matrix)
# outer(r_device, r_transpose_W) -> (d, k)
matrix.sub_(weight * torch.outer(r_device, r_transpose_W))
```
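One possible way to package the combined pattern is sketched below; the function name, arguments, and the upcast-then-single-downcast handling of a low-precision weight are illustrative assumptions, not the original implementation (a 4-bit quantized weight would additionally need explicit dequantization before this step):

```python
import torch

@torch.no_grad()
def remove_direction_(weight_matrix: torch.Tensor, refusal_dir: torch.Tensor, scale: float = 1.0) -> None:
    """In-place rank-1 projection W -= scale * outer(r, r^T W). Illustrative sketch only."""
    # Move the float32 direction to the weight's device without changing its dtype.
    r = refusal_dir.to(weight_matrix.device)
    # Do the projection math in the direction's float32 compute dtype; this upcasts
    # the weight once if it is stored in a lower-precision dtype, and is a no-op copy
    # avoidance when dtypes already match.
    W = weight_matrix.to(r.dtype)                      # (d, k) in float32
    r_transpose_W = torch.matmul(r, W)                 # (d,) @ (d, k) -> (k,)
    update = scale * torch.outer(r, r_transpose_W)     # (d, k) in float32
    # Single downcast when applying the update to the stored weight.
    weight_matrix.sub_(update.to(weight_matrix.dtype))
```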
Apply these guidelines consistently to avoid unnecessary precision loss, incorrect broadcasting, and extraneous device transfers when working with models, LoRA adapters, and quantized weights.