https://arxiv.org/pdf/2001.04451.pdf
Abstract
- 큰 transformer model은 성능이 좋지만 특히 긴 문장에서 학습하기에는 비용이 많이든다
- 2가지 방법으로 이 문제를 개선했다
- LSH를 이용하여 $O(L^2)$ 에서 $O(LlogL)$ 로 complexity를 개선했다
- reversible residual layer 방법을 사용하여 N번 activation을 저장하는게 아니라 한번만 저장하도록 개선했다
Introduction
- transformer 기반의 모델들은 성능을 높이기위해 layer당 parameter를 0.5B, layer의 수는 64개까지 늘려왔다
- 11K 정도의 long sequence 에서도 사용이되며 music, image같은 다른 영역에서도 long sequence는 필요하다
- 하지만 큰 규모의 transformer를 학습하기위해서는 대규모의 GPU 인프라가 필요하며 1개의 GPU로는 finetunning조차 힘들다
- 이런 transformer에서 연산량은 아래와 같다
- layer당 0.5B가 필요한 모델의 경우 layer당 2GB의 GPU memory 필요
- 64K token, embedding size 1024, batch size 8일때는 64K1K8 = 0.5B floats가 저장되어야 하며 2GB GPU memory가 필요하다
- 학습할때 layer당으로만 memory를 사용하면 이런 큰 모델도 1개의 GPU로 충분히 fine-tune 가능하다(기본 transformer의 경우 layer별 activation값들을 다 저장해서 back prop해야한다)
- 기존 Transformer의 문제
- back propagation을 하기 위해 N layer에서는 single layer보다 N배 memory가 필요하다
- feed-forward layer의 depth는 attention activation보다 커서 memory 사용에 큰 부분을 차지한다
- dot product attention은 $O(L^2)$ 의 연산, 메모리 사용량을 가져서 비효율적이다
- 아래 방법을 사용하여 위 문제를 해결
- reversible layer —> N개 block의 activation을 저장하는게 아니고 1개의 activation copy만으로 skip connection 방식을 학습할 수 있는 방법
- feed-forward layer의 activation을 chunk들로 쪼개서 메모리를 절약하는 방법
- LSH를 이용해서 $O(L^2)$ 에서 $O(LlogL)$ 로 효율화
- 이러한 기법들은 base transformer에 비해 구현적으로도 크게 변화가 없이 할 수 있고, 성능 감소도 없이 진행하였다
- LSH는 2가지 기법보다는 변화량이 컸는데 concurrent hash의 수의 따라 결과가 다르게 나왔다 후에 실험에서 자세히 설명
LOCALITY-SENSITIVE HASHING ATTENTION
Dot-product attention
Memory-efficient attention