[Bug Report] Backward hooks and "modifying gradients" #1160

@nkondapa

Describe the bug
When you add a backward hook and return a gradient from it, the returned gradient should be substituted into the backward pass. Returning the gradient unchanged should therefore be a no-op and raise no error, but it does raise one: RuntimeError: hook 'hook' has changed the size of value. The message implies that the shape of the gradient changes somewhere in the backward pass, even though the hook returns it untouched. The problem does not occur if the hook returns None, since the original gradient is then passed through.
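For reference, the expected "return value replaces the gradient" behaviour can be seen with a plain PyTorch tensor hook. This is a minimal sketch that uses only torch.Tensor.register_hook on a toy tensor, not TransformerLens; the names x, y, unchanged and scaled are illustrative.

```python
import torch

x = torch.ones(3, requires_grad=True)

# Case 1: the hook returns the gradient unchanged, which should be a no-op.
y = x * 2.0
y.register_hook(lambda grad: grad)
y.sum().backward()
unchanged = x.grad.clone()

# Case 2: the hook returns a modified gradient, which replaces the original.
x.grad = None
y = x * 2.0
y.register_hook(lambda grad: grad * 10.0)
y.sum().backward()
scaled = x.grad.clone()

print(unchanged)  # gradient of sum(2 * x) with respect to x
print(scaled)     # same gradient, scaled by the hook's factor of 10
```

In both cases the backward pass completes, which is the behaviour expected from the TransformerLens hook in the reproduction.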

Code example

import torch
from transformer_lens import HookedTransformer

# Minimal reproduction of backward hook bug
model = HookedTransformer.from_pretrained("gpt2-small")
model.eval()

prompt = "Hello world"
tokens = model.to_tokens(prompt)

# Simple backward hook that just returns the gradient unchanged
def simple_hook(grad, hook):
    print(f"{hook.name}: shape={grad.shape}")
    return grad

# Add backward hook
model.add_hook("blocks.0.hook_resid_post", simple_hook, dir='bwd')

# Forward pass
logits = model(tokens)

# Backward pass - this should work but raises:
# RuntimeError: hook 'hook' has changed the size of value
loss = logits[0, -1, 0]  # Single scalar
loss.backward()

model.reset_hooks()

Output

  File "/.../lib/python3.10/site-packages/torch/autograd/graph.py", line 841, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: hook 'hook' has changed the size of value
blocks.0.hook_resid_post: shape=torch.Size([1, 3, 768])
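
One possible reading of this error, offered as an assumption rather than a confirmed diagnosis: if the bwd hook is ultimately registered through torch.nn.Module.register_full_backward_hook, PyTorch expects the hook to return a tuple with one gradient per module input (or None), not a bare tensor. The sketch below shows the tuple form on a plain nn.Identity layer, which is a stand-in and not TransformerLens code; tuple_hook and layer are hypothetical names.

```python
import torch
import torch.nn as nn

layer = nn.Identity()  # stand-in module, assumed comparable to a hook point

def tuple_hook(module, grad_input, grad_output):
    # Full backward hooks receive tuples of gradients. A replacement for
    # grad_input must also be a tuple with one entry per module input.
    return (grad_output[0] * 3.0,)

handle = layer.register_full_backward_hook(tuple_hook)

x = torch.ones(1, 3, 4, requires_grad=True)
out = layer(x)
out.sum().backward()

print(x.grad.shape)  # same shape as x
print(x.grad)        # incoming gradient of ones, scaled by 3 in the hook

handle.remove()
```

Whether returning (grad,) from simple_hook in the reproduction above sidesteps the error has not been verified here.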

System Info
Describe the characteristic of your environment:
conda + pip
Ubuntu 22.04
Python 3.10.19

Checklist

  • I have checked that there is no similar issue in the repo (required)
    Searched for the error and didn't find it.
