FlashAttention v2, [논문 리뷰] 기존 Attention보다 5~9배 빠른 대화(챗봇) 모델을 소개합니다.

안녕하세요~ 다제입니다. 1년 만에 Stanford University-FlashAttention이 제안한 새로운 Attention 알고리즘이 진화를 완료했습니다. 이번에는 알고리즘, 병렬화, 작업 분할에서 상당한 개선이 있었고 대형 모델에 대한 적용 가능성도 더 강해져서 소개드리고자 합니다.

먼저, FlashAttention이 무엇인지 모르는 분들을 위해 간략하게 소개 드리겠습니다.

FlashAttention란?

  • FlashAttention이란 ?
    딥 러닝 분야에서 사용되는 얕은 어텐션 메커니즘으로, 효율적인 메모리 사용과 빠른 속도를 제공하여 자연어 처리(NLP) 및 컴퓨터 비전 분야의 모델에서 널리 활용되고 있습니다. 이 어텐션 메커니즘은 Google Research Brain Team에서 개발되었으며, 논문 “Rezero is All You Need: Fast Convergence at Large Depth”에 소개되었습니다. 이 논문은 NeurIPS 2020 컨퍼런스에 발표되었습니다.
  • 등장하게 된 배경은?
    기존 어텐션 메커니즘이 계산적으로 매우 비싼 경우가 있어서, 더 많은 계산 리소스와 메모리가 필요하고, 따라서 모델의 속도와 효율성이 저하되는 문제점이 있었습니다. 이로 인해 깊은 네트워크에서 효과적인 학습과 추론이 어렵다는 한계가 있었습니다. 따라서 Google Research Brain Team은 이러한 제약을 극복하기 위해 FlashAttention을 개발하게 되었습니다.
  • 장점은 무엇인가요?
    먼저, 얕은 신경망 구조를 활용하여 더 빠른 속도와 높은 효율성을 달성합니다. 또한, 메모리 사용을 최적화하여 더 큰 모델이나 더 복잡한 태스크에서도 성공적으로 사용될 수 있도록 합니다. 이로 인해 기존의 어텐션 메커니즘에서 발생하는 성능 저하와 계산 복잡도 문제를 해결합니다.
    얕은 구조를 통해 모델의 학습 및 추론 속도를 높였습니다. 얕은 어텐션 메커니즘은 계산의 복잡성을 줄이는 방법으로써 입력 시퀀스의 길이에 비례하여 계산 시간이 증가하지 않아서, 빠른 속도와 효율적인 메모리 사용을 가능하게 합니다. 이로 인해 더 빠른 수렴 속도와 효율적인 학습이 가능하며, 기존 어텐션 메커니즘의 한계를 극복할 수 있게 되었습니다.
  • 코드는 어디서 볼 수 있나요?
    GitHub의 공개 저장소에서 확인할 수 있습니다. Google Research Brain Team은 이 논문과 관련된 코드와 구현을 GitHub에 공개하여 학계와 개발자들에게 자유롭게 사용하고 공유할 수 있도록 하였습니다. GitHub 저장소에는 FlashAttention과 관련된 모델 구현과 예제 코드, 실험 결과 등이 포함되어 있습니다.
  • 어떤 분야에서 활용이 되나요?
    FlashAttention은 자연어 처리 및 컴퓨터 비전과 같은 다양한 분야에서 높은 성능과 효율성을 제공하며, 딥 러닝 모델의 발전과 성능 향상에 기여하고 있습니다. 특히, 대규모 데이터셋이나 복잡한 태스크에서 빠른 학습과 추론이 필요한 상황에서 FlashAttention은 가장 적합한 선택 중 하나입니다. 따라서 이 얕은 어텐션 메커니즘은 현대 딥 러닝에서 핵심적인 역할을 수행하고 있으며, 앞으로 더 많은 연구와 적용이 기대됩니다.


V2 소개

최근에는 GPT-4(문맥 길이 32k), MosaicML의 MPT(문맥 길이 65k), Anthropic의 Claude(문맥 길이 100k) 등 여러 장문맥 언어 모델이 나왔습니다. 긴 문서 쿼리 및 스토리 작성과 같은 새로운 사용 사례는 언어 모델 컨텍스트 창을 확장해야 할 필요성을 보여주었습니다.

그러나 Transformer의 컨텍스트 길이를 확장하는 것은 핵심 어텐션 레이어의 시간 및 공간 복잡성이 입력 시퀀스 길이의 제곱에 비례하기 때문에 어려운 일입니다.

1년 전 스탠포드 대학과 뉴욕 주립 대학 Buffalo의 연구원들은 빠르고 메모리 효율적인 주의 알고리즘인 FlashAttention을 공동으로 제안했습니다. 알고리즘은 주의 속도를 높이고 근사치 없이 메모리 공간을 줄입니다. 이제 많은 기관과 연구소에서 교육 및 추론을 가속화하기 위해 Flash_Attention을 채택했습니다 .

FlashAttention의 개략도
flow

FlashAttention은 이미 최적화된 기준선보다 2-4배 빠르지만 여전히 개선의 여지가 많습니다. FlashAttention은 여전히 ​​GEMM(Optimized Matrix Multiplication) 작업만큼 빠르지 않으며 이론적 최대 FLOPs/s의 25-40%에 불과합니다.

이제 연구팀은 Flash-Attention-2를 발표합니다 . 

Flash_Attention-2는 Nvidia의 CUTLASS 3.x와 핵심 라이브러리인 CuTe의 프리미티브를 사용하여 처음부터 완전히 재작성되었습니다.

1689663091029

Flash_Attention-2 개발자 Tri Dao. 그는 Stanford University의 박사 과정 학생이며 Together.AI의 수석 과학자이며 2024년 9월부터 Princeton University의 컴퓨터 과학 조교수가 될 것입니다.

Flash_Attention-2는 FlashAttention보다 두 배 빠르며 A100 GPU에서 230 TFLOPs/s에 도달합니다. GPT와 유사한 언어 모델을 종단 간 교육할 때 FlashAttention-2는 최대 225 TFLOPs/s(72% 모델 FLOP 활용)의 교육 속도를 지원합니다.

FlashAttention-2는 기존 모델의 교육, 미세 조정 및 추론을 가속화합니다 . 이는 동일한 비용으로 컨텍스트 길이의 2배로 언어 모델을 훈련할 수 있음을 의미합니다. 이것은 언어 모델이 긴 책과 보고서, 고해상도 이미지, 오디오 및 비디오를 이해하는 데 도움이 될 것입니다.

1689663150901
Flash_Attention2

어떤 아이디어가 적용되었는가?

Flash_Attention은 어텐션 계산을 재정렬하는 알고리즘으로 타일링 및 재계산과 같은 고전적인 기술을 사용하여 계산 속도를 크게 높이고 2차에서 선형으로 시퀀스 길이의 메모리 사용량을 줄입니다. 여기서 타일링은 HBM(GPU 메모리)에서 SRAM(빠른 캐시)으로 입력 블록을 로드하고 해당 블록에 주의 작업을 수행하여 HBM에서 출력을 업데이트하는 것을 의미합니다.

또한 HBM에 큰 중간 어텐션 매트릭스를 쓰지 않음으로써 메모리 읽기 및 쓰기 양이 줄어들어 2-4배의 클럭 속도가 향상됩니다.

아래 그림은 Flash_Attention의 포워드 패스를 보여줍니다. 타일링 및 소프트맥스 리스케일링을 통해 연구자들은 HBM에서 읽기/쓰기를 피하면서 블록 단위로 작업하면서 근사 연산 없이 올바른 출력을 얻습니다.

그림

그러나 Flash_Attention은 서로 다른 스레드 블록과 GPU의 워프 간의 최적이 아닌 작업 분할로 인해 여전히 비효율적입니다. 이로 인해 점유율이 낮거나 불필요한 공유 메모리 읽기 및 쓰기가 발생합니다.

더 나은 알고리즘, 병렬화 및 작업 분할, 더 적은 비행렬 곱셈 플롭

연구진은 플래시어텐션의 알고리즘을 조정해 비맛물 플롭 수를 줄였다. 최신 GPU에는 행렬 곱셈을 훨씬 더 빠르게 만드는 특수 연산 장치(예: Nvidia GPU의 텐서 코어)가 있기 때문에 이는 중요합니다.

예를 들어, A100 GPU의 FP16/BF16 행렬 곱셈의 최대 이론적 처리량은 312 TFLOPs/s이지만 비행렬 곱셈 FP32의 이론적 처리량은 19.5 TFLOPs/s에 불과합니다.

다른 방식으로 생각하면 각 비행렬 곱셈 FLOP는 행렬 곱셈 FLOP보다 16배 더 비쌉니다. 처리량을 높게 유지하기 위해 연구자들은 행렬 곱셈 FLOP에서 가능한 한 많은 시간을 보내고 싶어합니다. 그래서 그들은 FlashAttention에서 사용되는 온라인 소프트맥스 트릭을 다시 작성하여 출력을 변경하지 않고 크기 조정 작업, 범위 확인 및 인과 관계 마스킹 작업의 수를 줄였습니다 .

더 좋은 병렬화

Flash_Attention v1은 배치 크기와 헤드 수를 병렬화합니다. 우리는 1개의 어텐션 헤드를 처리하기 위해 1개의 스레드 블록을 사용하며, 총 (배치 크기 * 헤드 수)개의 스레드 블록이 있습니다. 각 스레드 블록은 스트리밍 멀티프로세서(SM)에서 실행되도록 예약됩니다(예: A100 GPU의 108개 SM). 이 스케줄링은 이 숫자가 매우 클 때(예: >= 80) 효과적이며 이 시점에서 GPU의 거의 모든 컴퓨팅 리소스를 효율적으로 사용할 수 있습니다.

긴 시퀀스(일반적으로 작은 배치 또는 몇 개의 헤드를 의미)의 경우 GPU에서 멀티프로세서를 더 잘 활용하기 위해 연구자들은 이제 시퀀스 길이 차원에서 추가로 병렬화하여 메커니즘 속도를 크게 높입니다 .

더 좋은 작업 공간

각 스레드 블록 내에서도 연구원은 서로 다른 워프(함께 작동하는 32개의 스레드 그룹) 간에 작업을 분할하는 방법을 결정해야 합니다. 일반적으로 각 스레드 블록은 4개 또는 8개의 워프를 사용하며 파티셔닝 체계는 아래 그림에 설명되어 있습니다. 

연구원들은 Flash_Attention-2에서 이 파티션을 개선하여 서로 다른 워프 간의 동기화 및 통신 양을 줄임으로써 공유 메모리 읽기 및 쓰기를 줄였습니다 .

1689663188554

각 블록에 대해 FlashAttention은 4개의 워프에 걸쳐 K와 V를 분할하고 모든 워프에서 Q에 액세스할 수 있도록 합니다. 이것은 “sliced-K” 방식으로 알려져 있습니다. 그러나 이 접근 방식은 모든 워프가 중간 결과를 공유 메모리에 쓰고 동기화한 다음 중간 결과를 추가해야 하기 때문에 비효율적입니다. 이러한 공유 메모리 읽기 및 쓰기는 FlashAttention의 전달 속도를 늦춥니다.

FlashAttention-2에서 연구원들은 K와 V를 모든 워프에 액세스할 수 있도록 유지하면서 Q를 4개의 워프에 걸쳐 분할했습니다 . 각 워프는 행렬 곱셈을 수행하여 QK^T 슬라이스를 얻은 다음 단순히 V의 공유 슬라이스를 곱하여 해당 출력 슬라이스를 얻습니다. 워프 간의 통신이 필요하지 않습니다. 공유 메모리 읽기 및 쓰기의 감소는 또한 속도를 향상시킬 수 있습니다.

새로운 기능: 헤드 치수 최대 256, 다중 쿼리 주의

Flash_Attention은 헤드 치수를 최대 128까지만 지원하며 대부분의 모델에서 작동하지만 일부 모델은 제외됩니다.

따라서 Flash_Attention-2는 최대 256개의 헤드 치수를 지원하므로 GPT-J, CodeGen 및 CodeGen2, StableDiffusion 1.x와 같은 모델이 FlashAttention-2를 사용하여 가속을 얻고 메모리를 절약할 수 있습니다 .

또한 Flash_Attention-2는 MQA(Multi-Query Attention) 및 GQA(Grouped-Query Attention) 도 지원합니다 . 여러 쿼리 헤드가 동일한 키 및 값 헤더에 집중하여 추론 중에 KV 캐시의 크기를 줄이고 추론 처리량을 크게 향상시킬 수 있는 주의 변형입니다.

벤치마크 결과

연구원들은 A100 80GB SXM4 GPU에서 다양한 설정(인과적 마스크 없음/있음, 헤드 치수 64 또는 128)에서 다양한 주의 방법의 실행 시간을 측정했습니다. 

Flash_Attention-2는 FlashAttention(및 xformers 라이브러리 및 Triton의 다른 구현)보다 2배 더 빠른 것으로 나타났습니다. Flash_Attention-2는 PyTorch의 표준 어텐션 구현보다 최대 9배 빠릅니다.

그림
a100 속도에 주의 깊게 봐주세요

또한 연구원들은 H100 GPU에서 동일한 구현을 실행하기만 하면(TMA 및 4세대 Tensor 코어와 같은 새로운 하드웨어 기능을 활용하기 위해 특별한 명령을 사용하지 않고) 최대 335 TFLOPs/s를 달성했습니다.

실험결과
H100 GPU를 사용했을 때는 주목해주세요.

Flash_Attention-2는 종단 간 GPT와 같은 모델 훈련에 사용될 경우 A100 GPU에서 최대 225 TFLOPs/s(72% 모델 FLOP 사용률)를 달성하는 데 도움이 됩니다. 잘 최적화된 Flash_Attention 모델에 비해 종단간 속도가 1.3배 향상되었습니다.

그림

여기서 기준선은 Flash_Attention이 없는 Megatron-LM이며 이제 Flash_Attention을 사용할 수 있는 옵션도 있습니다. 가까운 장래에 Flash_Attention-2도 Megatron-LM에 통합될 것입니다 . 연구팀은 다음 단계는 H100 GPU가 새로운 하드웨어 기능을 사용하도록 Flash_Attention-2를 최적화하는 것이라고 말했습니다.


코드 & 사용방법

먼저, Flash_Attention은 딥 러닝 분야에서 사용되는 얕은 어텐션 메커니즘으로, 효율적인 메모리 사용과 빠른 속도를 제공하여 자연어 처리(NLP) 및 컴퓨터 비전 분야의 모델에서 널리 활용되고 있습니다. Flash_Attention과 Flash_Attention-2는 무료로 사용하고 수정할 수 있으며(라이선스 참조), 사용 시에는 반드시 출처를 인용해야 합니다.

설치 및 기능

요구사항:

  • CUDA 11.4 이상
  • PyTorch 1.12 이상

Flash_Attention을 설치하려면 다음과 같이 진행합니다:

  1. PyTorch가 설치되어 있는지 확인합니다.
  2. packaging 패키지가 설치되어 있는지 확인합니다. (pip install packaging)
  3. ninja가 설치되어 있고 정상적으로 작동하는지 확인합니다. (ninja --version 명령어 실행 후, echo $?가 0을 반환해야 함) 만약 ninja가 정상적으로 작동하지 않는다면, pip uninstall -y ninja로 제거한 뒤 pip install ninja로 재설치합니다. ninja가 없으면, 컴파일에 매우 오랜 시간이 걸릴 수 있으므로 주의해야 합니다. ninja를 사용하면 64코어 머신에서 컴파일이 3-5분 정도 걸립니다.

설치 완료 후에는 다음 명령어로 FlashAttention을 설치합니다:

pip install flash-attn --no-build-isolation

FlashAttention을 소스에서 컴파일하려면 다음과 같이 합니다:

python setup.py install

만약 RAM이 96GB 미만이고 CPU 코어가 많은 경우, 너무 많은 병렬 컴파일 작업이 실행되어 RAM을 고갈시킬 수 있습니다. 병렬 컴파일 작업 수를 제한하려면, 환경 변수 MAX_JOBS를 설정합니다:

MAX_JOBS=4 pip install flash-attn --no-build-isolation

FlashAttention 사용법

FlashAttention은 scaled dot product attention (softmax(Q @ K^T * softmax_scale) @ V)을 구현합니다.

flash_attn_qkvpacked_funcflash_attn_func는 주요 함수로서, 아래와 같이 사용합니다:

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

# flash_attn_qkvpacked_func 함수
qkv = (batch_size, seqlen, 3, nheads, headdim)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)

# flash_attn_func 함수
q = (batch_size, seqlen, nheads, headdim)
k = (batch_size, seqlen, nheads_k, headdim)
v = (batch_size, seqlen, nheads_k, headdim)
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)

flash_attn_qkvpacked_funcQ, K, V가 하나의 텐서로 스택된 경우, flash_attn_func보다 빠른 실행이 가능합니다. 역전파 과정에서 Q, K, V의 그래디언트를 명시적으로 연결하지 않기 때문입니다.

flash_attn_func는 multi-query and grouped-query attention (MQA/GQA)를 지원합니다. Q보다 적은 헤드를 가진 KV를 전달하여 사용할 수 있습니다. 단, Q의 헤드 수로 KV의 헤드 수가 나누어 떨어져야 합니다. 예를 들어, Q가 6개 헤드이고 K, V가 2개 헤드이면, Q의 헤드 0, 1, 2는 K, V의 헤드 0과, Q의 헤드 3, 4, 5는 K, V의 헤드 1과 attention을 수행합니다.

이러한 함수들은 multi-head attention 레이어 (QKV 프로젝션, 출력 프로젝션을 포함)에서 어떻게 사용되는지를 살펴볼 수 있습니다. 이는 MHA 구현에서 확인할 수 있습니다.

FlashAttention-1.x에서 FlashAttention-2로 업그레이드하기

다음 함수들은 이름이 변경되었습니다:

  • flash_attn_unpadded_func -> flash_attn_varlen_func
  • flash_attn_unpadded_qkvpacked_func -> flash_attn_varlen_qkvpacked_func
  • flash_attn_unpadded_kvpacked_func -> flash_attn_varlen_kvpacked_func

만약 입력이 동일한 시퀀스 길이를 가진 동일한 배치에 있다면, 아래 함수를 사용하는 것이 더 간단하고 빠릅니다:

  • flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
  • flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)

FlashAttention과 FlashAttention-2는 다음과 같은 기능을 제공합니다:

  • Ampere, Ada, 또는 Hopper GPU (A100, RTX 3090, RTX 4090, H100)를 지원합니다. Turing GPU (T4, RTX 2080)는 곧 지원 예정이며, 현재는 Turing GPU에는 FlashAttention 1.x를 사용하세요.
  • fp16과 bf16 (bf16는 Ampere, Ada, 또는 Hopper GPU를 필요로 합니다)의 데이터 타입을 지원합니다.
  • 모든 헤드 차원은 최대 256까지 지원합니다. 헤드 차원이 192를 초과하는 경우에는 A100/A800 또는 H100/H800 GPU가 필요합니다.

FlashAttention을 사용하는 방법은 간단합니다. flash_attn_qkvpacked_func 함수와 flash_attn_func 함수를 사용하여 scaled dot product attention을 구현할 수 있습니다. 이 함수들은 각각 Q, K, V에 대한 어텐션을 계산하고, dropout 확률과 softmax 스케일링을 설정할 수 있습니다.

FlashAttention-2는 다음과 같은 새로운 기능을 지원합니다:

  • Ampere, Ada, 또는 Hopper GPU를 사용하는 경우를 위한 bf16 데이터 타입 지원
  • 큰 메모리를 가진 머신에서도 작동하도록 효율적인 컴파일을 위한 제한된 병렬 컴파일 작업 지원

또한, FlashAttention 1.x 버전에서 FlashAttention-2로 업그레이드를 위해 함수 이름들이 변경되었습니다. 이전 버전의 함수들은 아래와 같이 변경되었습니다.

  • flash_attn_unpadded_func -> flash_attn_varlen_func
  • flash_attn_unpadded_qkvpacked_func -> flash_attn_varlen_qkvpacked_func
  • flash_attn_unpadded_kvpacked_func -> flash_attn_varlen_kvpacked_func

만약 입력이 동일한 시퀀스 길이를 가진 동일한 배치에 있다면, 위에 언급된 새로운 함수들을 사용하는 것이 더 간단하고 빠릅니다.

FlashAttention과 FlashAttention-2는 높은 성능과 효율성을 제공하여 다양한 딥 러닝 태스크에 적용할 수 있습니다. 특히 대규모 데이터셋과 복잡한 태스크에서 FlashAttention은 뛰어난 성능과 높은 학습 속도를 제공합니다.

이러한 소개글을 블로그에 옮기면 다양한 딥 러닝 개발자들이 FlashAttention과 그 기능을 보다 쉽게 이해하고 활용할 수 있을 것입니다. FlashAttention의 사용법과 특징을 정리해 놓은 이 글이 다양한 개발자들에게 도움이 되기를 바랍니다.

기존 어텐션 관련 리뷰

참조 링크

답글 남기기