Hacker News | new | past | comments | ask | show | jobs | submit | jaykmody's comments


Author here. It's a design choice, but there are two reasons I chose to use imports like this:

1) For demonstrative purposes. The title of the post is `A GPT in 60 Lines of NumPy`, so I kinda wanted to show "hey, it's just numpy, nothing to be scared about!". Also, if an import is ONLY used in a single function, putting it inside that function visually signals "hey, this import is only used here", whereas when it's at the top of the file you can't easily tell where or how many times it's used.

2) Scoping. `load_encoder_hparams_and_params` imports tensorflow, which is really slow to import. When I was testing, I used randomly initialized weights instead of loading the checkpoint (which is slow), so I only needed the `gpt2` function. If I had kept the import at the top level, it would've slowed things down unnecessarily.
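The function-local import pattern described above can be sketched in isolation. Here `wave` is just a stand-in for a heavy dependency like tensorflow, and the function body is hypothetical, only the import placement matters:

    import sys

    def load_encoder_hparams_and_params(checkpoint_path):
        # The slow dependency is imported only when this function runs,
        # so callers that never load a checkpoint pay no import cost.
        # `wave` stands in for something heavy like tensorflow.
        import wave
        return {"checkpoint": checkpoint_path}

    # Nothing heavy is imported until the function is actually called:
    params = load_encoder_hparams_and_params("models/124M")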


Hey y'all, author here!

Thank you for all the nice and constructive comments!

For clarity, this is ONLY the forward pass of the model. There's no training code, batching, kv cache for efficiency, GPU support, etc ...

The goal here was to provide a simple yet complete technical introduction to the GPT as an educational tool. Tried to make the first two sections something any programmer can understand, but yeah, beyond that you're gonna need to know some deep learning.

Btw, I tried to make the implementation as hackable as possible. For example, if you change the import from `import numpy as np` to `import jax.numpy as np`, the code becomes end-to-end differentiable:

    import jax
    import jax.numpy as np

    def lm_loss(params, inputs, n_head) -> float:
        x, y = inputs[:-1], inputs[1:]
        output = gpt2(x, **params, n_head=n_head)
        # probability assigned to the correct next token at each position
        loss = np.mean(-np.log(output[np.arange(len(y)), y]))
        return loss

    grads = jax.grad(lm_loss)(params, inputs, n_head)
You can even support batching with `jax.vmap` (https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.h...):

    gpt2_batched = jax.vmap(gpt2, in_axes=0)
    gpt2_batched(batched_inputs) # [batch, seq_len] -> [batch, seq_len, vocab]
Of course, JAX comes with built-in GPU and even TPU support!

As for training code and a KV cache for inference efficiency, I leave those as an exercise for the reader lol
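If you want a head start on that exercise, here's a rough sketch of what a single JAX training step could look like. The embedding-only `lm_loss` is a toy stand-in for the post's actual `gpt2` forward pass, and the names are hypothetical:

    import jax
    import jax.numpy as np

    def lm_loss(params, inputs):
        # toy stand-in for the gpt2 forward pass: logits come straight
        # from an embedding lookup (for illustration only)
        x, y = inputs[:-1], inputs[1:]
        logits = params["wte"][x]  # [seq_len, vocab]
        probs = jax.nn.softmax(logits, axis=-1)
        return np.mean(-np.log(probs[np.arange(len(y)), y]))

    @jax.jit
    def sgd_step(params, inputs, lr=0.1):
        grads = jax.grad(lm_loss)(params, inputs)
        # one plain SGD update per leaf of the params tree
        return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

Swapping `lm_loss` for one that calls the real `gpt2` is the actual exercise.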


"hackable" and "simple yet complete technical introduction"

Music to my ears, well done and don't worry too much about the negative comments! They'll come out for anything you do I think.

I saw a tweet from someone the other day talking about how they massively increased their training speed by changing part of their architecture to have dimensions that were a multiple of 64 rather than a prime-like number.

One of the comments below it? ~"Seems very architecture specific."

lol.

So don't sweat it! <3 Great work and thanks for putting yourself out there, super job! :D


At CentML, we work on GPU-specific training and inference speedups.


Congrats, well deserved.


This is beautiful. Having worked with everything from nanoGPT to Megatron, sitting down and reading through picoGPT.py was clear and refreshing with just the essential details. Nothing left to add, nothing left to take away: perfection.


This looks like something Peter Norvig would write, and that’s about the highest compliment I can give.


> GPU support

If you haven't tried cuNumeric [1], you really ought to. It's a drop-in NumPy wrapper for distributed GPU acceleration. Would be interesting to see if it works for this.

[1]: https://github.com/nv-legate/cunumeric
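A sketch of the drop-in idea, with a NumPy fallback so it still runs on machines without a GPU. The `import cunumeric as np` line follows the project's README; the `softmax` below is just an illustrative workload:

    try:
        # cuNumeric aims to be a drop-in replacement: same API, GPU-backed
        import cunumeric as np
    except ImportError:
        # fall back to plain NumPy on machines without cuNumeric
        import numpy as np

    def softmax(x):
        # identical code either way: that's the appeal of a drop-in
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)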


The problem with drop-in replacements between CPU and GPU code is that performant GPU code often requires rethinking the dataflow, so even if the code itself is a drop-in, the "make it fast" part still requires some rewriting.

I'd be curious how that library compares to other numeric Python GPU libraries.


> For clarity, this is ONLY the forward pass of the model. There's no training code, batching, kv cache for efficiency, GPU support, etc ...

Neat, but please add one-line comments/docstrings where these missing bits would go.


Hi there, thank you for putting this together!

I want to commend you for one of the best written introductions in this space that I've seen, especially the excellent use of hyperlinking that points to really good resources at exactly the right time!


Hope this evolves the way the open Go AI projects did. AlphaGo came and went, but the community needed an open version and built one. Hope the same happens here.



