HPC Project
Optimizing GPT-2 on Multi-Node, Multi-GPU Environments
The goal of this project is to optimize GPT-2 inference on a cluster consisting of four nodes, each equipped with two GPUs.
I used MPI to distribute the workload evenly across the nodes and keep every GPU fully utilized. All computation kernels were implemented from scratch in CUDA and C.
The core of the project lives in four files: main.cpp, tensor.cu, model.cu, and layer.cu.
main.cpp handles all the messy setup work (initialization plus MPI and CUDA configuration) and then calls the main generation function, generate_tokens().
generate_tokens() iterates over the transformer blocks and invokes the corresponding groups of GPU kernels for inference.
Each transformer block is computed by transformer_block_batch_multi_gpu(), which is built from several finer-grained subfunctions.
Of these, mha_batch_multi_gpu() and ffn_batch_multi_gpu() account for most of the latency.
Both are in turn composed of even finer-grained operations, and drilling down through them reveals the dominant operation of the entire system: GEMM (mha -> attn -> matmul_batch_cuda).
The CUDA kernel for the GEMM operation is implemented as follows.
First, I partition the output matrix into tiles so that each thread block computes its own independent region of the GEMM.
Then, I have each thread compute multiple output elements within its block, so every value it loads is reused across several results and memory traffic is reduced. The core part of this kernel design is shown in the figure below.
Valuable Reference: https://siboehm.com/articles/22/CUDA-MMM
Now let me introduce the most important optimization of this project: KV caching, which raised throughput dramatically. KV caching stores the Key and Value vectors computed for previous tokens and simply reads them back when generating the next token. Because the keys and values of earlier tokens no longer have to be recomputed at every step, only the newest token needs to be processed, which significantly reduces the overall computation during the decoding phase.
Together, these optimizations brought the GPT-2 inference system to a throughput of approximately 65,000 tokens/sec. The full codebase of this project is available on GitHub.