Author here. It's a design choice, but there are two reasons I chose to use imports like this:
1) For demonstrative purposes. The title of the post is `A GPT in 60 Lines of NumPy`, so I kinda wanted to show "hey, it's just numpy, nothing to be scared about!". Also, if an import is ONLY used in a single function, keeping it inside that function makes the dependency visually obvious; when it's at the top of the file, you can't easily tell when, where, or how many times it's actually used.
2) Scoping. `load_encoder_hparams_and_params` imports tensorflow, which is really slow to import. When I was testing, I used randomly initialized weights instead of loading the checkpoint (which is slow), so I was only using the `gpt2` function. If the import had been at the top level, every run would have paid the tensorflow import cost unnecessarily.
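To make the scoping point concrete, here's a minimal sketch of the pattern (the function names follow the post; the bodies are stubs, and the tensorflow usage is an assumption about what the loader does):

```python
import sys

def gpt2(inputs, params, n_head):
    """Forward-pass stub: only needs numpy, so importing this module stays fast."""
    ...

def load_encoder_hparams_and_params(model_size, models_dir):
    """The slow dependency is only imported when a checkpoint is actually
    loaded; testing with random weights never pays this cost."""
    import tensorflow as tf  # deferred: runs only when this function is called
    ...
```

Merely defining (or importing a module containing) `load_encoder_hparams_and_params` does not trigger the tensorflow import; it happens on the first call.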
Thank you for all the nice and constructive comments!
For clarity, this is ONLY the forward pass of the model. There's no training code, batching, kv cache for efficiency, GPU support, etc ...
The goal here was to provide a simple yet complete technical introduction to the GPT as an educational tool. Tried to make the first two sections something any programmer can understand, but yeah, beyond that you're gonna need to know some deep learning.
Btw, I tried to make the implementation as hackable as possible. For example, if you change the import from `import numpy as np` to `import jax.numpy as np`, the code becomes end-to-end differentiable:
```python
import jax
import jax.numpy as np

def lm_loss(params, inputs, n_head) -> float:
    x, y = inputs[:-1], inputs[1:]  # predict token i+1 from tokens 0..i
    output = gpt2(x, **params, n_head=n_head)
    # probability the model assigned to each correct next token
    loss = np.mean(-np.log(output[np.arange(len(y)), y]))
    return loss

grads = jax.grad(lm_loss)(params, inputs, n_head)
```
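To make the indexing in that loss concrete, here's a toy pure-NumPy version of the same cross-entropy computation, with random softmaxed values standing in for the model's output (shapes and names are illustrative, not the post's):

```python
import numpy as np

np.random.seed(0)
seq_len, vocab = 4, 10

# stand-in for the model output: one probability row per input position
logits = np.random.randn(seq_len, vocab)
output = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

y = np.array([3, 1, 4, 1])  # next-token targets, one per position

# pick out the probability assigned to each target token at its position
p_correct = output[np.arange(seq_len), y]
loss = np.mean(-np.log(p_correct))
```

The paired-index form `output[np.arange(seq_len), y]` selects one scalar per position; plain `output[y]` would instead select whole rows indexed by the token ids, which is not the quantity the loss needs.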
"hackable" and "simple yet complete technical introduction"
Music to my ears, well done and don't worry too much about the negative comments! They'll come out for anything you do I think.
I saw a tweet from someone the other day talking about how they massively increased their training speed by changing part of their architecture to have dimensions that were a multiple of 64 rather than something prime-like.
One of the comments below it? ~"Seems very architecture specific."
lol.
So don't sweat it! <3 Great work and thanks for putting yourself out there, super job! :D :D :D :D :)))))) <3 :D :D :fireworks:
This is beautiful. Having worked with everything from nanoGPT to Megatron, sitting down and reading through picoGPT.py was clear and refreshing with just the essential details. Nothing left to add, nothing left to take away: perfection.
If you haven't tried cuNumeric [1], you really ought to. It's a drop-in NumPy wrapper for distributed GPU acceleration. Would be interesting to see if it works for this.
The problem with drop-in replacements between CPU and GPU code is that performant GPU code often requires rethinking the dataflow -- so even if the code itself is a drop-in, the "make it good" part still requires some rewriting.
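A CPU-side illustration of that point: the two versions below compute the same result, but the looped one issues many tiny array operations, each of which would become its own kernel launch under a GPU array library, while the batched form is a single large operation. (Pure NumPy here so it runs anywhere; the shapes are arbitrary.)

```python
import numpy as np

np.random.seed(1)
x = np.random.randn(1000, 64)
w = np.random.randn(64, 64)

# many small ops: one matmul per row (one kernel launch each on a GPU)
rows = np.stack([x[i] @ w for i in range(x.shape[0])])

# one big op: a single batched matmul (one kernel launch)
batched = x @ w
```

Same answer either way, but only the second shape of dataflow actually benefits from swapping in a GPU-backed `np`.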
I'd be curious how that library compares to other numeric Python GPU libraries.
I want to commend you for one of the best-written introductions in this space that I've seen, especially the excellent use of hyperlinking that points to really good resources at exactly the right time!