์ ์: ๋ง์ดํด ๊ทธ์ฌ๋น๋ ๋ฒ์ญ: ์ด์งํ
์ด ํํ ๋ฆฌ์ผ์์๋ PyTorch 1.12 ๋ฒ์ ์ ์ผ๋ถ๋ก Better Transformer (BT)๋ฅผ ์๊ฐํฉ๋๋ค. ์ฌ๊ธฐ์๋ torchtext๋ฅผ ์ฌ์ฉํด ์์ฉํ๋ ์ ํ ์์ค์ ์ถ๋ก ์์ Better Transformer๋ฅผ ์ ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค. Better Transformer๋ ์์ฉ ์ ํ ์์ค์ผ๋ก ๋ฐ๋ก ์ ์ฉ๊ฐ๋ฅํ fastpath์ ๋๋ค. ์ด๋, CPU์ GPU์์ ๊ณ ์ฑ๋ฅ์ผ๋ก ๋ ๋น ๋ฅด๊ฒ Transformer ๋ชจ๋ธ์ ๋ฐฐํฌํ ์ ์๊ฒ๋ ํด์ค๋๋ค. ์ด fastpath ๊ธฐ๋ฅ์ PyTorch ์ฝ์ด nn.module์ ์ง์ ๊ธฐ๋ฐ์ผ๋ก ํ๊ฑฐ๋ torchtext๋ฅผ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ ๋ํด ์ดํดํ๊ธฐ ์ฝ๊ณ ๋ช ํํ๊ฒ ์๋ํฉ๋๋ค.
Better Transformer fastpath๋ก ๊ฐ์ํ๋ ์ ์๋ ๋ชจ๋ธ์ PyTorch ์ฝ์ด torch.nn.module ํด๋์ค์ธ TransformerEncoder, TransformerEncoderLayer, ๊ทธ๋ฆฌ๊ณ MultiHeadAttention์ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ ๋๋ค. ๋ํ, torchtext๋ fastpath ๊ฐ์ํ์ ์ด์ ์ ์ป๊ธฐ ์ํด ์ฝ์ด ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋ชจ๋๋ค์ ์ฌ์ฉํ๋๋ก ์ ๋ฐ์ดํธ๋์์ต๋๋ค. (์ถํ ๋ ๋ง์ ๋ชจ๋์ด fastpath ์คํ์ ์ง์ํ ์ ์์ต๋๋ค.)
Better Transformer๋ ๋ ๊ฐ์ง ์ ํ์ ๊ฐ์ํ๋ฅผ ์ ๊ณตํฉ๋๋ค:
- CPU์ GPU์ ๋ํ Native multihead attention(MHA) ๊ตฌํ์ผ๋ก ์ ๋ฐ์ ์ธ ์คํ ํจ์จ์ฑ์ ํฅ์์ํต๋๋ค.
- NLP ์ถ๋ก ์์์ sparsity๋ฅผ ํ์ฉํฉ๋๋ค. ๊ฐ๋ณ ๊ธธ์ด ์ ๋ ฅ(variable input lengths)์ผ๋ก ์ธํด ์ ๋ ฅ ํ ํฐ์ ๋ง์ ์์ ํจ๋ฉ ํ ํฐ์ด ํฌํจ๋ ์ ์๋๋ฐ, ์ด๋ฌํ ํ ํฐ๋ค์ ์ฒ๋ฆฌ๋ฅผ ๊ฑด๋๋ฐ์ด ์๋นํ ์๋ ํฅ์์ ์ ๊ณตํฉ๋๋ค.
Fastpath ์คํ์ ๋ช ๊ฐ์ง ๊ธฐ์ค์ ์ถฉ์กฑํด์ผ ํฉ๋๋ค. ๊ฐ์ฅ ์ค์ํ ๊ฑด, ๋ชจ๋ธ์ด ์ถ๋ก ๋ชจ๋์์ ์คํ๋์ด์ผ ํ๋ฉฐ gradient tape ์ ๋ณด๋ฅผ ์์งํ์ง ์๋ ์ ๋ ฅ ํ ์์ ๋ํด ์๋ํด์ผ ํ๋ค๋ ๊ฒ์ ๋๋ค(์: torch.no_grad๋ฅผ ์ฌ์ฉํ์ฌ ์คํ).
์ด ์์ ๋ฅผ Google Colab์์ ๋ฐ๋ผํ๋ ค๋ฉด, ์ฌ๊ธฐ๋ฅผ ํด๋ฆญ.
- ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ ๋ก๋ (Better Transformer ์์ด PyTorch ๋ฒ์ 1.12 ์ด์ ์ ์์ฑ๋ ๋ชจ๋ธ)
- CPU์์ BT fastpath๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ์ ์ฌ์ฉํ์ง ์์ ๊ฒฝ์ฐ์ ์ถ๋ก ์ ์คํ ๋ฐ ๋ฒค์น๋งํฌ (๋ค์ดํฐ๋ธ MHA๋ง ํด๋น)
- (๊ตฌ์ฑ ๊ฐ๋ฅํ)๋๋ฐ์ด์ค์์ BT fastpath๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ์ ์ฌ์ฉํ์ง ์์ ๊ฒฝ์ฐ์ ์ถ๋ก ์ ์คํ ๋ฐ ๋ฒค์น๋งํฌ (๋ค์ดํฐ๋ธ MHA๋ง ํด๋น)
- sparsity ์ง์ ํ์ฑํ
- (๊ตฌ์ฑ ๊ฐ๋ฅํ) ๋๋ฐ์ด์ค์์ BT fastpath๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ์ ์ฌ์ฉํ์ง ์์ ๊ฒฝ์ฐ์ ์ถ๋ก ์ ์คํ ๋ฐ ๋ฒค์น๋งํฌ (๋ค์ดํฐ๋ธ MHA + ํฌ์์ฑ)
๋ ๋์ ํธ๋์คํฌ๋จธ์ ๋ํ ์ถ๊ฐ ์ ๋ณด๋ PyTorch.Org ๋ธ๋ก๊ทธ์์ ํ์ธํ ์ ์์ต๋๋ค. ๊ณ ์ ํธ๋์คํฌ๋จธ ์ถ๋ก ์ ์ํ Better Transformer.
- ์ค์
1.1 ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
torchtext.models ์ ์ง์นจ์ ๋ฐ๋ผ ๋ฏธ๋ฆฌ ์ ์๋ torchtext ๋ชจ๋ธ์์ XLM-R ๋ชจ๋ธ์ ๋ค์ด๋ก๋ํฉ๋๋ค. ๋ํ ๊ฐ์๊ธฐ ์์์์ ํ ์คํธ๋ฅผ ์คํํ๊ธฐ ์ํด DEVICE๋ฅผ ์ค์ ํฉ๋๋ค. (ํ์์ ๋ฐ๋ผ ์ฌ์ฉ ํ๊ฒฝ์ ๋ง๊ฒ GPU ์คํ์ ํ์ฑํ๋ฉด ๋ฉ๋๋ค.)
import torch
import torch.nn as nn
print(f"torch version: {torch.__version__}")
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"torch cuda available: {torch.cuda.is_available()}")
import torch, torchtext
from torchtext.models import RobertaClassificationHead
from torchtext.functional import to_tensor
xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = 1024)
model = xlmr_large.get_model(head=classifier_head)
transform = xlmr_large.transform()1.2 ๋ฐ์ดํฐ์ ์ค์
๋ ๊ฐ์ง ์ ํ์ ์ ๋ ฅ์ ์ค์ ํ๊ฒ ์ต๋๋ค. ์์ ์ ๋ ฅ ๋ฐฐ์น์ sparsity๋ฅผ ๊ฐ์ง ํฐ ์ ๋ ฅ ๋ฐฐ์น์ ๋๋ค.
small_input_batch = [
"Hello world",
"How are you!"
]
big_input_batch = [
"Hello world",
"How are you!",
"""`Well, Prince, so Genoa and Lucca are now just family estates of the
Buonapartes. But I warn you, if you don't tell me that this means war,
if you still try to defend the infamies and horrors perpetrated by
that Antichrist- I really believe he is Antichrist- I will have
nothing more to do with you and you are no longer my friend, no longer
my 'faithful slave,' as you call yourself! But how do you do? I see
I have frightened you- sit down and tell me all the news.`
It was in July, 1805, and the speaker was the well-known Anna
Pavlovna Scherer, maid of honor and favorite of the Empress Marya
Fedorovna. With these words she greeted Prince Vasili Kuragin, a man
of high rank and importance, who was the first to arrive at her
reception. Anna Pavlovna had had a cough for some days. She was, as
she said, suffering from la grippe; grippe being then a new word in
St. Petersburg, used only by the elite."""
]๋ค์์ผ๋ก, ์์ ์ ๋ ฅ ๋ฐฐ์น ๋๋ ํฐ ์ ๋ ฅ ๋ฐฐ์น ์ค ํ๋๋ฅผ ์ ํํ๊ณ , ์ ๋ ฅ์ ์ ์ฒ๋ฆฌํ ํ ๋ชจ๋ธ์ ํ ์คํธํฉ๋๋ค.
input_batch=big_input_batch
model_input = to_tensor(transform(input_batch), padding_value=1)
output = model(model_input)
output.shape๋ง์ง๋ง์ผ๋ก, ๋ฒค์น๋งํฌ ๋ฐ๋ณต ํ์๋ฅผ ์ค์ ํฉ๋๋ค.
ITERATIONS=10- ์คํ
2.1 CPU์์ BT fastpath๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ์ ์ฌ์ฉํ์ง ์์ ๊ฒฝ์ฐ์ ์ถ๋ก ์ ์คํ ๋ฐ ๋ฒค์น๋งํฌ (๋ค์ดํฐ๋ธ MHA๋ง ํด๋น)
CPU์์ ๋ชจ๋ธ์ ์คํํ๊ณ ํ๋กํ์ผ ์ ๋ณด๋ฅผ ์์งํฉ๋๋ค:
- ์ฒซ ๋ฒ์งธ ์คํ์ ์ ํต์ ์ธ ์คํ('slow path')์ ์ฌ์ฉํฉ๋๋ค.
- ๋ ๋ฒ์งธ ์คํ์ model.eval()์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ถ๋ก ๋ชจ๋๋ก ์ค์ ํ๊ณ torch.no_grad()๋ก ๋ณํ๋(gradient) ์์ง์ ๋นํ์ฑํํ์ฌ BT fastpath ์คํ์ ํ์ฑํํฉ๋๋ค.
CPU์์ ๋ชจ๋ธ์ ์คํํ ๋ ์ฑ๋ฅ์ด ํฅ์๋ ๊ฒ์ ๋ณผ ์ ์์ ๊ฒ๋๋ค.(ํฅ์ ์ ๋๋ CPU ๋ชจ๋ธ์ ๋ฐ๋ผ ๋ค๋ฆ ๋๋ค) fastpath ํ๋กํ์ผ์์ ๋๋ถ๋ถ์ ์คํ ์๊ฐ์ด ๋ค์ดํฐ๋ธ `TransformerEncoderLayer`์ ์ ์์ค ์ฐ์ฐ์ ๊ตฌํํ `aten::_transformer_encoder_layer_fwd`์ ์์๋๋ ๊ฒ์ ์ฃผ๋ชฉํ์ธ์:
print("slow path:")
print("==========")
with torch.autograd.profiler.profile(use_cuda=False) as prof:
for i in range(ITERATIONS):
output = model(model_input)
print(prof)
model.eval()
print("fast path:")
print("==========")
with torch.autograd.profiler.profile(use_cuda=False) as prof:
with torch.no_grad():
for i in range(ITERATIONS):
output = model(model_input)
print(prof)2.2 (๊ตฌ์ฑ ๊ฐ๋ฅํ)๋๋ฐ์ด์ค์์ BT fastpath๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ์ ์ฌ์ฉํ์ง ์์ ๊ฒฝ์ฐ์ ์ถ๋ก ์ ์คํ ๋ฐ ๋ฒค์น๋งํฌ (๋ค์ดํฐ๋ธ MHA๋ง ํด๋น)
BT sparsity ์ค์ ์ ํ์ธํด๋ณด๊ฒ ์ต๋๋ค.
model.encoder.transformer.layers.enable_nested_tensor์ด๋ฒ์ BT sparsity์ ๋นํ์ฑํํฉ๋๋ค.
model.encoder.transformer.layers.enable_nested_tensor=FalseDEVICE์์ ๋ชจ๋ธ์ ์คํํ๊ณ , DEVICE์์์ ๋ค์ดํฐ๋ธ MHA ์คํ์ ๋ํ ํ๋กํ์ผ ์ ๋ณด๋ฅผ ์์งํฉ๋๋ค:
- ์ฒซ ๋ฒ์งธ ์คํ์ ์ ํต์ ์ธ ('slow path') ์คํ์ ์ฌ์ฉํฉ๋๋ค.
- ๋ ๋ฒ์งธ ์คํ์ model.eval()์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ถ๋ก ๋ชจ๋๋ก ์ค์ ํ๊ณ torch.no_grad()๋ก ๋ณํ๋(gradient) ์์ง์ ๋นํ์ฑํํ์ฌ BT fastpath ์คํ์ ํ์ฑํํฉ๋๋ค.
GPU์์ ์คํํ ๋, ํนํ ์์ ์ ๋ ฅ ๋ฐฐ์น๋ก ์ค์ ํ ๊ฒฝ์ฐ ์๋๊ฐ ํฌ๊ฒ ํฅ์๋๋ ๊ฒ์ ๋ณผ ์ ์์ ๊ฒ๋๋ค.
model.to(DEVICE)
model_input = model_input.to(DEVICE)
print("slow path:")
print("==========")
with torch.autograd.profiler.profile(use_cuda=True) as prof:
for i in range(ITERATIONS):
output = model(model_input)
print(prof)
model.eval()
print("fast path:")
print("==========")
with torch.autograd.profiler.profile(use_cuda=True) as prof:
with torch.no_grad():
for i in range(ITERATIONS):
output = model(model_input)
print(prof)2.3 (๊ตฌ์ฑ ๊ฐ๋ฅํ) ๋๋ฐ์ด์ค์์ BT fastpath๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ์ ์ฌ์ฉํ์ง ์์ ๊ฒฝ์ฐ์ ์ถ๋ก ์ ์คํ ๋ฐ ๋ฒค์น๋งํฌ (๋ค์ดํฐ๋ธ MHA + ํฌ์์ฑ)
sparsity ์ง์์ ํ์ฑํํฉ๋๋ค.
model.encoder.transformer.layers.enable_nested_tensor = TrueDEVICE์์ ๋ชจ๋ธ์ ์คํํ๊ณ , DEVICE์์์ ๋ค์ดํฐ๋ธ MHA์ sparsity ์ง์ ์คํ์ ๋ํ ํ๋กํ์ผ ์ ๋ณด๋ฅผ ์์งํฉ๋๋ค:
- ์ฒซ ๋ฒ์งธ ์คํ์ ์ ํต์ ์ธ ('slow path') ์คํ์ ์ฌ์ฉํฉ๋๋ค.
- ๋ ๋ฒ์งธ ์คํ์ model.eval()์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ถ๋ก ๋ชจ๋๋ก ์ค์ ํ๊ณ torch.no_grad()๋ก ๋ณํ๋(gradient) ์์ง์ ๋นํ์ฑํํ์ฌ BT fastpath ์คํ์ ํ์ฑํํฉ๋๋ค.
GPU์์ ์คํํ ๋, ํนํ sparsity๋ฅผ ํฌํจํ๋ ํฐ ์ ๋ ฅ ๋ฐฐ์น ์ค์ ์์ ์๋นํ ์๋ ํฅ์์ ๋ณผ ์ ์์ ๊ฒ๋๋ค.
model.to(DEVICE)
model_input = model_input.to(DEVICE)
print("slow path:")
print("==========")
with torch.autograd.profiler.profile(use_cuda=True) as prof:
for i in range(ITERATIONS):
output = model(model_input)
print(prof)
model.eval()
print("fast path:")
print("==========")
with torch.autograd.profiler.profile(use_cuda=True) as prof:
with torch.no_grad():
for i in range(ITERATIONS):
output = model(model_input)
print(prof)์ด ํํ ๋ฆฌ์ผ์์๋ torchtext์์ PyTorch ์ฝ์ด์ ํธ๋์คํฌ๋จธ ์ธ์ฝ๋ ๋ชจ๋ธ์ ์ํ Better Transformer ์ง์์ ํ์ฉํ์ฌ, Better Transformer๋ฅผ ์ด์ฉํ ๊ณ ์ ํธ๋์คํฌ๋จธ ์ถ๋ก ์ ์๊ฐํ์ต๋๋ค. BT fastpath ์คํ์ด ๊ฐ๋ฅํด์ง๊ธฐ ์ด์ ์ ํ๋ จ๋ ๋ชจ๋ธ์์ Better Transformer์ ์ฌ์ฉ์ ์์ฐํ์ต๋๋ค. ๋ํ BT fastpath ์คํ์ ๋ ๊ฐ์ง ๋ชจ๋์ธ ๋ค์ดํฐ๋ธ MHA ์คํ๊ณผ BT sparsity ๊ฐ์ํ์ ์ฌ์ฉ์ ์์ฐ ๋ฐ ๋ฒค์น๋งํฌ๋ฅผ ํด๋ณด์์ต๋๋ค.