Preserve dtypes, devices, and shapes

When writing PyTorch code, preserve tensor dtypes and devices and be explicit about vector/matrix shapes to avoid precision loss, incorrect broadcasting, and unnecessary CPU allocations.

Reviewer Prompt

When writing PyTorch code, preserve tensor dtypes and devices and be explicit about vector/matrix shapes to avoid precision loss, incorrect broadcasting, and unnecessary CPU allocations.

Why: mixed-precision and quantized weights are common in model deployment. Unnecessary up/downcasts, implicit device transfers, and incorrect assumptions about 1D tensor shapes (torch treats a 1D tensor as shape (d,) and matmul/outer have specific shape requirements) lead to subtle correctness and performance bugs.
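
For concreteness, here is a minimal sketch of the 1D shape semantics mentioned above (the names v, W, d, and k are illustrative only, not taken from any particular codebase):

import torch

d, k = 8, 4
v = torch.randn(d)         # 1D vector, shape (d,)
W = torch.randn(d, k)      # weight matrix, shape (d, k)

# matmul treats the 1D operand as (1, d), multiplies, then drops the added dimension
vW = torch.matmul(v, W)    # shape (k,)

# outer expects two 1D tensors and returns the full outer product
P = torch.outer(v, vW)     # shape (d, k), same shape as W

assert vW.shape == (k,)
assert P.shape == (d, k)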

Rules (actionable):

  • Do not cast dtypes unless required. If a vector is already float32, do not call .to(torch.float32) on it again; change only the device: v = v.to(module.weight.device)
  • When projecting against possibly low-precision weights, do projection math in a safe compute dtype, but avoid redundant round-trips (downcast then upcast). Example: if refusal_directions is float32 and W is bfloat16/4-bit, move the vector to W’s device but keep its dtype:

    # good: preserve dtype, only change device
    v = layer_refusal_direction.to(matrix.device)
    r_transpose_W = torch.matmul(v, matrix)
    matrix.sub_(weight * torch.outer(v, r_transpose_W))

  • Be explicit about shapes when using matmul vs outer:
    • Use torch.matmul for vector-matrix products where one operand is 1D with shape (d,) and the other is (d, k). torch.matmul prepends a 1 to the 1D operand, multiplies as (1, d) @ (d, k), then removes the added dimension, so the result has shape (k,).
    • Use torch.outer(a, b) when both a and b are 1D tensors and you want the full outer product of shape (d, k). torch.outer expects 1D inputs; do not pass reshaped tensors like (d, 1) and (1, k).
  • Avoid creating tensors on CPU unless necessary. Most torch operations preserve the device of their inputs; allocate new tensors on the device of the tensors they will be used with to prevent implicit transfers and extra memory usage (see the allocation sketch after this list).

  • API/type clarity: annotate functions that return modules vs. tensors correctly (e.g., dict[str, list[torch.nn.Module]]; see the annotation sketch after this list), and handle the LoRA/quantization lifecycle explicitly: attach adapters without unnecessary weight changes, and when merging on quantized models, reload the base model on CPU and copy the adapter weights.
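
The allocation sketch referenced above, kept minimal; matrix stands in for any tensor that already lives on the target device, and mask/scaled are illustrative names:

import torch

matrix = torch.randn(16, 8)   # stands in for a weight already on its target device
# Bad: torch.zeros(16) would allocate on CPU and force an implicit transfer later.
# Good: allocate directly on the source tensor's device (and, where appropriate, its dtype).
mask = torch.zeros(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
scaled = matrix * mask.unsqueeze(1)   # stays on matrix's device; no cross-device copies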
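
And the annotation sketch referenced above. collect_projection_targets is a hypothetical helper shown only to illustrate the return-type annotation; the LoRA merge lifecycle itself is not sketched here:

from __future__ import annotations

import torch

def collect_projection_targets(model: torch.nn.Module) -> dict[str, list[torch.nn.Module]]:
    """Group Linear submodules by their parent block name (hypothetical helper)."""
    targets: dict[str, list[torch.nn.Module]] = {}
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            block = name.rsplit(".", 1)[0] if "." in name else "root"
            targets.setdefault(block, []).append(module)
    return targets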

Example (combined pattern):

# r is (d,) float32 already
r_device = r.to(matrix.device)
# r_device: (d,), matrix: (d, k) -> torch.matmul yields (k,)
r_transpose_W = torch.matmul(r_device, matrix)
# outer(r_device, r_transpose_W) -> (d, k)
matrix.sub_(weight * torch.outer(r_device, r_transpose_W))

Apply these guidelines consistently to avoid unnecessary precision loss, incorrect broadcasting, and extraneous device transfers when working with models, LoRA adapters, and quantized weights.

Source discussions