It was announced that PyTorch now officially supports TPUs. I wanted to know whether this is actually usable today, and whether it requires changing code all over the place, so I went through the PyTorch/XLA API docs, the GCP documentation on PyTorch with TPUs, and the PyTorch TPU announcement video, and summarized what I found below.
PyTorch official FB post: Google Cloud TPU support in PyTorch is now broadly available. Hear how engineers from Facebook, Google, and Salesforce worked together to enable and pilot this integration. Learn how to get started here: http://bit.ly/2NU755z
PyTorch on Google Cloud TPUs - Google, Salesforce, Facebook
Training PyTorch models on Cloud TPU Pods | Cloud TPU | Google Cloud
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp

# Enumerate the XLA (TPU) devices visible to this process
devices = xm.get_xla_supported_devices()
# DataParallel takes the model class (here MNIST) and creates one replica per device
model_parallel = dp.DataParallel(MNIST, device_ids=devices)
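Note that dp.DataParallel is given the model class itself (MNIST), not an instance, and builds one replica per device. The MNIST class is not shown in the docs excerpt, so here is a minimal sketch of what such a class could look like; the exact architecture is my assumption (a standard small convnet), not necessarily the one from the torch_xla example.

import torch.nn as nn
import torch.nn.functional as F

# Hypothetical MNIST network (assumed architecture), compatible with NLLLoss below
class MNIST(nn.Module):
    def __init__(self):
        super(MNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)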
import torch.nn as nn
import torch.optim as optim

# Per-device training loop: DataParallel calls this once per replica,
# passing the replica model, a per-device loader, the device, and a context object.
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    model.train()
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        # xm.optimizer_step() replaces optimizer.step(): it synchronizes
        # gradients across replicas and triggers XLA graph execution
        xm.optimizer_step(optimizer)

for epoch in range(1, num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
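DataParallel aside, if you just want to run existing single-GPU-style code on one TPU core, the changes look quite small: get the device from xm.xla_device() instead of torch.device("cuda"), and call xm.optimizer_step(optimizer, barrier=True) instead of optimizer.step(). A minimal sketch, assuming the MNIST model above plus lr, momentum, and train_loader defined elsewhere:

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# Acquire the default XLA (TPU) device, analogous to torch.device("cuda")
device = xm.xla_device()
model = MNIST().to(device)

loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

model.train()
for data, target in train_loader:
    # Move each batch onto the TPU device, just like .to("cuda")
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    # barrier=True forces the pending XLA graph to execute at each step
    xm.optimizer_step(optimizer, barrier=True)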
# On the GCE VM: pull and run the prebuilt PyTorch/XLA Docker image
(vm)$ docker run -it --shm-size 16G gcr.io/tpu-pytorch/xla:r0.5
# Inside the container: point the XRT runtime at the TPU node, then run the MNIST test
(pytorch) root@CONTAINERID:/$ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
(pytorch) root@CONTAINERID:/$ python pytorch/xla/test/test_train_mnist.py