Attention 다시 보기¶
이 챕터에서 배우는 것
- scaled dot-product attention — 식 한 줄, 코드 5 줄
- causal mask 가 왜 필요한가 (생성 모델의 핵심 제약)
- PyTorch 의
F.scaled_dot_product_attention한 줄 — 직접 짠 것과 같은지 확인 - multi-head 는 그저 reshape — 새 알고리즘 아님
전제
행렬곱 감, softmax 정의, PyTorch 텐서 연산 기본. Attention 을 이미 한 번 들어봤다면 더 좋다 — 이 챕터는 처음 배우기보단 손으로 다시 짜기.
1. 개념 — 식 한 줄¶
세 입력 모두 같은 형태 — (seq, d_k). 출력도 (seq, d_k).
직관: 각 토큰이 나머지 토큰들에 점수를 매기고, 그 점수로 value 를 가중 평균 한다. "지금 만들 토큰에 어디를 얼마나 참고할까" 의 미분가능한 버전.
5 단계로 풀면:
- Q · Kᵀ — 토큰 i 가 토큰 j 에 얼마나 관심?
(seq, seq)점수. - ÷ √d_k — 점수 분산 안정. d_k 가 크면 dot product 가 너무 커져 softmax 가 한 곳으로 쏠림.
- causal mask — 미래 위치는 -∞ 로 (생성 모델 한정).
- softmax — 확률 분포로.
- × V — value 가중 합.
2. 왜 필요한가 — RNN/CNN 대비¶
| 방식 | "어디를 보나" | 거리 의존 | 병렬화 |
|---|---|---|---|
| RNN | 직전 hidden state 만 (간접적으로 멀리) | O(n) hop | 어려움 (순차) |
| CNN | window 안만 (예: 3~7) | window 한정 | 좋음 |
| Attention | 모든 위치 직접 참조 | O(1) | 좋음 (matmul) |
직접 참조 + 병렬화. 이 두 속성이 트랜스포머가 RNN/CNN 을 갈아치운 이유.
대가: 메모리 O(n²) — Q · Kᵀ 가 (seq, seq). seq=4K 면 그 한 행렬만 64MB (fp32). FlashAttention(§7) 이 이 문제를 다룸.
3. 어디에 쓰이나¶
- 모든 트랜스포머 층 — encoder · decoder · cross-attention 모두 같은 식.
- GPT 계열 (decoder-only) — causal mask 적용. 우리가 만들 모델.
- BERT 계열 (encoder-only) — mask 없음 (양방향).
- T5 (encoder-decoder) — encoder 는 mask 없음, decoder 는 causal + cross-attention.
이 책 본문은 causal self-attention 만 다룬다 (Part 7 Ch 25 에서 encoder, Ch 28 에서 encoder-decoder 한 번씩).
4. 최소 예제 — 손으로 짜기 30줄¶
Q @ K.T— 각 위치 쌍 (i, j) 의 dot product. shape(seq, seq).√d_k로 나누기 — 식의 핵심. 안 나누면 softmax 가 평평해지거나 한 곳에 쏠림 (d 가 커질수록 심함).triu(diagonal=1)— 주대각선 위쪽이 True. 그 자리에 -∞ 채우면 softmax 후 그 자리는 0. 아래쪽 (j ≤ i) 만 본다.- attention 가중치로 V 의 가중 평균. 출력은 입력과 같은 shape.
전형적 출력:
attention weights row 0: tensor([1., 0., 0., 0.]) # 위치 0
attention weights row 1: tensor([0.31, 0.69, 0., 0.]) # 위치 1
attention weights row 2: tensor([0.20, 0.45, 0.35, 0.]) # 위치 2
위치 i 가 항상 자기를 포함한 0..i 만 본다 — causal 의 정의.
5. 실전 튜토리얼 — F.scaled_dot_product_attention 한 줄과 비교¶
PyTorch 2.x 부터 F.scaled_dot_product_attention 한 줄에 같은 연산. 내부적으로 FlashAttention (Dao et al., 2022) 또는 효율 구현 자동 선택.
is_causal=True면 mask 자동 적용. shape 추론도 자동 — 우리가 짠 5 줄이 한 줄로.- 같은 결과여야 한다. 1e-6 미만 차이는 부동소수점 오차.
왜 한 줄을 쓰는가:
- 속도: GPU 에서 FlashAttention 이 메모리 효율적 (
O(n²)아닌O(n)메모리). seq=2K 부터 체감. - 메모리: 큰 attention matrix 를 메모리에 올리지 않음.
- 유지보수: 미래 PyTorch 업데이트가 알아서 더 빨라짐.
언제 직접 짜는가: 디버깅, attention 가중치 시각화 (Ch 18), 새 변형 (RoPE 의 일부) 구현.
6. multi-head — 그저 reshape¶
d_model=64, n_head=8 이면 head 마다 head_dim = 8. 각 head 가 독립적으로 attention 한 다음 concat.
view + transpose두 줄. 새 알고리즘이 아니라 차원 쪼개기.- SDPA 가 head 차원을 자동 broadcast. 각 head 가 독립 attention.
왜 head 를 쪼개나: 여러 "보는 관점" 을 학습. head 1 은 "직전 토큰" 에, head 2 는 "마지막 명사" 에 가중을 두는 식 (학습 후 시각화로 확인 — Ch 18).
7. 자주 깨지는 포인트¶
1. √d_k 를 잊는다 — 학습 초반 손실이 안 떨어진다. d_k=64 면 dot product 평균이 √64=8 만큼 더 커져 softmax 가 한 곳으로 쏠림.
2. mask shape 실수 — causal mask 는 (T, T). attention scores 가 (B, H, T, T) 면 broadcast 가 자동이지만 변형 시 (e.g. padding mask 추가) 손으로 shape 맞춰야 함.
3. mask 자리에 0 을 채운다 — softmax 전에는 -inf 가 맞다. softmax 가 -inf 를 0 확률로 만든다.
4. dtype 불일치 — Q, K, V 가 fp16 인데 mask 가 fp32 면 cast 비용. .to(Q.dtype) 한 번.
5. is_causal=True 와 직접 mask 동시 사용 — SDPA 가 헷갈려서 두 번 적용될 수 있음. 둘 중 하나만.
6. transpose(-2, -1) vs .T — 다차원 텐서에 .T 는 모든 차원 뒤집기. 항상 transpose(-2, -1) 가 안전.
8. 운영 시 체크할 점¶
-
F.scaled_dot_product_attention사용 (직접 구현은 디버깅 때만) - PyTorch ≥ 2.0 — FlashAttention 자동 선택
- seq_len 큰 모델이면
is_causal=True로 mask 비메모리화 - head_dim 은 16의 배수 (Tensor Core 효율) — 보통 32, 64, 128
- 추론 시 KV cache 별도 (Ch 11 메모리 산수 + Part 6)
- attention 가중치 시각화는 학습 후 별도 hook 으로 (forward 안에서 저장하지 말 것 — 메모리 폭발)
9. 연습문제¶
- §4 의 5 줄 attention 을 batch B=2, seq T=8, hidden D=16 으로 돌려보고
attnshape 와 합 (attn.sum(-1)) 이 모두 1 인지 확인하라. - §5 의 SDPA 한 줄 결과와 수동 결과의 차이를 다양한 dtype (fp32, fp16, bf16) 으로 비교. 어느 dtype 이 가장 차이가 큰가?
- causal mask 를 반대로 (
triu(diagonal=0)— 자기 포함 미래만 봄) 적용하면 모델이 어떻게 학습될까? 한 epoch 돌려 손실 곡선을 정상 mask 와 비교. - multi-head 8개를 모두 같은 weight 로 초기화하면 무엇이 일어날까? PyTorch 기본 초기화가 head 마다 다른 결과를 자동 보장하는 이유는?
- (생각해볼 것) seq=10K 인 모델에서 attention
O(n²)가 메모리 100MB 를 잡는다. seq=100K 면 단순 계산으로 10GB. FlashAttention 은 어떻게 이 문제를 해결하는가? (한 문단으로 핵심만)
원전¶
- Vaswani et al. (2017). Attention Is All You Need. arXiv:1706.03762
- Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arXiv:2205.14135
- PyTorch docs —
torch.nn.functional.scaled_dot_product_attention - Karpathy. Let's build GPT (YouTube, 2023) — 같은 5줄을 영상으로