🔥llm.c update: Our single file of 2,000 ~clean lines of C/CUDA code now trains GPT-2 (124M) on GPU at speeds ~matching PyTorch (fp32, no flash attention)
On my A100 I'm seeing 78ms/iter for llm.c and 80ms/iter for PyTorch. Keep in mind this is fp32, with no flash attention yet, and a slightly stale PyTorch (2.1.0).
- It is a direct implementation of the training loop and backpropagation in C/CUDA.
- It compiles and runs instantly. No more "hit run then wait for tens of seconds for unknown reasons", for mountains of inscrutable abstractions to build a Universe.
- It deletes the need for the Python interpreter and a deep learning library.
- It allocates all the memory a single time at the start.
- It's pretty cool.
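To make the "allocate once at the start" point concrete, here is a minimal host-side sketch in plain C. It is not the project's actual code: `Arena`, `arena_init`, `arena_free` and `NUM_TENSORS` are hypothetical names made up for illustration. The idea is to sum up the sizes of all tensors, grab one block, and hand out pointers into it. On the GPU the `malloc` would presumably be a single `cudaMalloc` instead.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>

#define NUM_TENSORS 3  // hypothetical: a real model has many more tensors

// Hypothetical container: one backing block plus pointers into it.
typedef struct {
    float *memory;                // the single allocation
    float *tensors[NUM_TENSORS];  // views into `memory`, one per tensor
    size_t total;                 // total number of floats allocated
} Arena;

// Allocate one block big enough for every tensor, then carve it up.
// Returns 0 on success, -1 if the allocation fails.
int arena_init(Arena *a, const size_t sizes[NUM_TENSORS]) {
    size_t total = 0;
    for (int i = 0; i < NUM_TENSORS; i++) {
        total += sizes[i];
    }
    a->memory = (float *)malloc(total * sizeof(float));
    if (a->memory == NULL) {
        return -1;
    }
    a->total = total;
    // Walk a cursor through the block, assigning each tensor its slice.
    float *cursor = a->memory;
    for (int i = 0; i < NUM_TENSORS; i++) {
        a->tensors[i] = cursor;
        cursor += sizes[i];
    }
    assert(cursor == a->memory + total);  // every float is accounted for
    return 0;
}

// Release the single block; the tensor pointers become invalid.
void arena_free(Arena *a) {
    free(a->memory);
    a->memory = NULL;
}
```

After `arena_init` there are no further allocations in the training loop, so memory use is known up front and there is exactly one `free` at the end.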
How:
Getting this to work required us to write a lot of custom CUDA kernels, and doing this manually (instead of using Tensor ops of aten/PyTorch and torch.compile etc.) is a bit like programming in assembly. And you spend quality time looking at more assembly (CUDA PTX/SASS). But this also means we get to hyperoptimize the code and possibly explore optimizations that torch.compile might find difficult to discover, which is awesome. Examples of optimizations that went in over the last few days:
- we're being clever with our memory consumption in the backward pass, only using a few buffers we need to propagate the gradients, saving memory capacity.
- one fused classifier kernel does the last layer forward pass, the loss, and kicks off the backward pass.
- many improvements to all the kernels involved, including e.g. gains from carefully constraining execution within the autoregressive mask in attention.
- cuBLAS(Lt) calls for all heavy lifting matmuls, and fused bias accumulation
Big credits to two CUDA experts who appeared from somewhere on the internet to help this open source project, ngc92 and ademeure. We're hanging out on GitHub and in the Discords of CUDAMODE and my NN Zero to Hero.
Next steps:
- more optimizing of our (fp32) kernels, and especially switch to flash attention.
- mixed precision training (fp16 to start).
- multi-gpu training (DDP to start).
- data & evals to set up proper GPT-2 training runs.
- 🚀 repro GPT-2 (1.6B) training run.
- more modern architectures etc. (Llama 3?)
- writing, videos, exercises on building all of this from scratch.
Figure 1: eye candy: timing profile of the kernels (one layer). NVIDIA cutlass kernels with solid compute throughput taking up a lot of the running time => nice.