I have been wanting to find a project to try out the data-parallel language Futhark. They have a very good blog that I’ve been
following for years, but I’ve never actually written anything in it.
Andrej Karpathy’s microgpt, a
self-contained implementation of a GPT-2-like neural network in 200 lines of
Python, finally provided the excuse. I like microgpt, but it does not
scale at all. Obviously the point of this implementation is not
efficiency, but it’s not just that it’s slow: you also can’t scale up to
even slightly larger networks, because you quickly hit Python recursion depth
errors.
So, I was curious whether I could port it as 1-to-1 as possible and get much
better scaling without losing too much concision. The answer, as it turns out,
is sort-of: the port scales much better but is not as concise. Parts of it
translate quite nicely though.
This post, Part I, will start with just the forward pass. I’ll alternate code
from Karpathy’s original Python version with my Futhark translation, attempting
to keep things as similar as possible, while still taking advantage of
Futhark’s parallel
primitives.
LLM parameters
First, the data structures holding the LLM parameters (weights). We will assume
these are pretrained (the training code will come in Part II).
Python:
n_layer = 1
n_embd = 16
block_size = 16
n_head = 4
head_dim = n_embd // n_head
matrix = lambda nout, nin, std=0.08: [[Value(random.gauss(0, std)) for _ in range(nin)] for _ in range(nout)]
state_dict = {'wte': matrix(vocab_size, n_embd), 'wpe': matrix(block_size, n_embd), 'lm_head': matrix(vocab_size, n_embd)}
for i in range(n_layer):
    state_dict[f'layer{i}.attn_wq'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.attn_wk'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.attn_wv'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.attn_wo'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.mlp_fc1'] = matrix(4 * n_embd, n_embd)
    state_dict[f'layer{i}.mlp_fc2'] = matrix(n_embd, 4 * n_embd)
params = [p for mat in state_dict.values() for row in mat for p in row]
Futhark:
def n_layer : i64 = 1
def n_embd : i64 = 16
def block_size : i64 = 16
def n_head : i64 = 4
def head_dim : i64 = n_embd / n_head
type params [v] = {
  wte: [v][n_embd]f32,                       -- token embeddings
  wpe: [block_size][n_embd]f32,              -- position embeddings
  lm_head: [v][n_embd]f32,                   -- output projection
  attn_wq: [n_layer][n_embd][n_embd]f32,     -- query weights
  attn_wk: [n_layer][n_embd][n_embd]f32,     -- key weights
  attn_wv: [n_layer][n_embd][n_embd]f32,     -- value weights
  attn_wo: [n_layer][n_embd][n_embd]f32,     -- output weights
  mlp_fc1: [n_layer][4 * n_embd][n_embd]f32, -- MLP up-projection
  mlp_fc2: [n_layer][n_embd][4 * n_embd]f32  -- MLP down-projection
}
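Since the weights are assumed pretrained, there is no initialisation code on the Futhark side yet. Just to show how a value of this record type is constructed, here is a constant-filled dummy constructor (a sketch only; real weights would be loaded from the host or produced by the training code in Part II):
-- Sketch only: build a params record filled with a constant, purely to
-- exercise the type. Not part of the port; real weights come from training.
def dummy_params (v: i64) (c: f32) : params [v] =
  { wte     = replicate v (replicate n_embd c),
    wpe     = replicate block_size (replicate n_embd c),
    lm_head = replicate v (replicate n_embd c),
    attn_wq = replicate n_layer (replicate n_embd (replicate n_embd c)),
    attn_wk = replicate n_layer (replicate n_embd (replicate n_embd c)),
    attn_wv = replicate n_layer (replicate n_embd (replicate n_embd c)),
    attn_wo = replicate n_layer (replicate n_embd (replicate n_embd c)),
    mlp_fc1 = replicate n_layer (replicate (4 * n_embd) (replicate n_embd c)),
    mlp_fc2 = replicate n_layer (replicate n_embd (replicate (4 * n_embd) c)) }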
Next, the model architecture.
Python:
def linear(x, w):
    return [sum(wi * xi for wi, xi in zip(wo, x)) for wo in w]

def softmax(logits):
    max_val = max(val.data for val in logits)
    exps = [(val - max_val).exp() for val in logits]
    total = sum(exps)
    return [e / total for e in exps]

def rmsnorm(x):
    ms = sum(xi * xi for xi in x) / len(x)
    scale = (ms + 1e-5) ** -0.5
    return [xi * scale for xi in x]
Futhark:
def linear [n][m] (x: [n]f32) (w: [m][n]f32) : [m]f32 =
  map (\w_row -> reduce (+) 0f32 (map2 (*) w_row x)) w

def softmax [n] (logits: [n]f32) : [n]f32 =
  let max_val = reduce f32.max f32.lowest logits
  let exps = map (\v -> f32.exp (v - max_val)) logits
  let total = reduce (+) 0f32 exps
  in map (/ total) exps

def rmsnorm [n] (x: [n]f32) : [n]f32 =
  let ms = reduce (+) 0f32 (map (\xi -> xi * xi) x) / f32.i64 n
  let scale = 1f32 / f32.sqrt (ms + 1e-5)
  in map (* scale) x
I’m pleased at how nicely these three functions translate. The explicit typing
does add a little syntax noise. Arguably reduce (+) 0f32 is also not as
nice a way to spell sum. But it’s generally readable, especially if
you are already familiar with these kinds of functional combinators. The number
of lines of code stayed exactly the same.
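For what it is worth, the prelude does provide f32.sum and f32.maximum, so the reductions can be spelled a little more directly. A sketch (primed names to avoid clashing with the versions above; the behaviour should be identical):
-- Sketch: the same three functions using the prelude's f32.sum / f32.maximum
-- instead of spelling out the reductions by hand.
def linear' [n][m] (x: [n]f32) (w: [m][n]f32) : [m]f32 =
  map (\w_row -> f32.sum (map2 (*) w_row x)) w

def softmax' [n] (logits: [n]f32) : [n]f32 =
  let max_val = f32.maximum logits
  let exps = map (\v -> f32.exp (v - max_val)) logits
  let total = f32.sum exps
  in map (/ total) exps

def rmsnorm' [n] (x: [n]f32) : [n]f32 =
  let ms = f32.sum (map (\xi -> xi * xi) x) / f32.i64 n
  let scale = 1f32 / f32.sqrt (ms + 1e-5)
  in map (* scale) x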
And finally, the GPT forward pass, complete with a KV cache.
Karpathy’s Python original:
def gpt(token_id, pos_id, keys, values):
    tok_emb = state_dict['wte'][token_id] # token embedding
    pos_emb = state_dict['wpe'][pos_id] # position embedding
    x = [t + p for t, p in zip(tok_emb, pos_emb)] # joint token and position embedding
    x = rmsnorm(x) # note: not redundant due to backward pass via the residual connection
    for li in range(n_layer):
        # 1) Multi-head Attention block
        x_residual = x
        x = rmsnorm(x)
        q = linear(x, state_dict[f'layer{li}.attn_wq'])
        k = linear(x, state_dict[f'layer{li}.attn_wk'])
        v = linear(x, state_dict[f'layer{li}.attn_wv'])
        keys[li].append(k)
        values[li].append(v)
        x_attn = []
        for h in range(n_head):
            hs = h * head_dim
            q_h = q[hs:hs+head_dim]
            k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
            v_h = [vi[hs:hs+head_dim] for vi in values[li]]
            attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
            attn_weights = softmax(attn_logits)
            head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
            x_attn.extend(head_out)
        x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
        x = [a + b for a, b in zip(x, x_residual)]
        # 2) MLP block
        x_residual = x
        x = rmsnorm(x)
        x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
        x = [xi.relu() for xi in x]
        x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
        x = [a + b for a, b in zip(x, x_residual)]
    logits = linear(x, state_dict['lm_head'])
    return logits
My Futhark port:
def gpt [v]
        (p: params [v])
        (token_id: i64) (pos_id: i64)
        (keys: *[n_layer][block_size][n_embd]f32)
        (values: *[n_layer][block_size][n_embd]f32)
        : ([v]f32,
           *[n_layer][block_size][n_embd]f32,
           *[n_layer][block_size][n_embd]f32) =
  let tok_emb = p.wte[token_id]      -- token embedding
  let pos_emb = p.wpe[pos_id]        -- position embedding
  let x = map2 (+) tok_emb pos_emb   -- joint token and position embedding
  let x = rmsnorm x
  let (x, keys, values) =
    loop (x, keys, values) for li < n_layer do
      -- 1) Multi-head Attention block
      let x_residual = x
      let x_norm = rmsnorm x
      let q = linear x_norm p.attn_wq[li]
      let k = linear x_norm p.attn_wk[li]
      let v_vec = linear x_norm p.attn_wv[li]
      let keys = keys with [li, pos_id] = k
      let values = values with [li, pos_id] = v_vec
      let x_attn = flatten (
        tabulate n_head (\h ->
          let hs = h * head_dim
          let q_h = tabulate head_dim (\j -> q[hs + j])
          let scale = 1f32 / f32.sqrt (f32.i64 head_dim)
          let attn_logits = tabulate block_size (\t ->
            let dot = reduce (+) 0f32 (
              tabulate head_dim (\j -> q_h[j] * keys[li, t, hs + j])
            )
            in if t <= pos_id then dot * scale else -1e30f32
          )
          let attn_weights = softmax attn_logits
          in tabulate head_dim (\j ->
            reduce (+) 0f32 (
              tabulate block_size (\t -> attn_weights[t] * values[li, t, hs + j])
            )
          )
        )
      ) :> [n_embd]f32
      let x_out = linear x_attn p.attn_wo[li]
      let x = map2 (+) x_out x_residual
      -- 2) MLP block
      let x_residual = x
      let x_norm = rmsnorm x
      let x_mlp = linear x_norm p.mlp_fc1[li]
      let x_mlp = map (f32.max 0) x_mlp
      let x_mlp = linear x_mlp p.mlp_fc2[li]
      let x = map2 (+) x_mlp x_residual
      in (x, keys, values)
  let logits = linear x p.lm_head
  in (logits, keys, values)
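As a quick sanity check, the forward pass can be exercised on a zeroed KV cache. This little driver is just a sketch for illustration, not part of the original:
-- Hypothetical smoke test, not from the original: run a single forward step
-- for one token at position 0, starting from a freshly zeroed KV cache.
def gpt_step [v] (p: params [v]) (token_id: i64) : [v]f32 =
  let keys   = replicate n_layer (replicate block_size (replicate n_embd 0f32))
  let values = replicate n_layer (replicate block_size (replicate n_embd 0f32))
  let (logits, _, _) = gpt p token_id 0 keys values
  in logits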
The for loops mostly translate to tabulate, which you
can think of as essentially a parallel for loop. The MLP block at the end
translated directly too. The attention block was a bit hairier to wrangle into
Futhark’s constraints, due to using some imperative Python features and
destructively updated data structures. But it was not too bad. The main change
was to preallocate the KV cache in a fixed-size array (size
[n_layer][block_size][n_embd]), and then, in the attn_logits
calculation, mask out “future” tokens to keep the model causal (not needed in
the Python version because the list was constructed in causal order).
The total lines of code for this function (excluding comments and blank lines)
crept up from 33 to 51, though partly just because I broke statements across lines more
liberally than the original did.
I will admit that, even as someone who likes this style of functional
programming, the end result is arguably less readable. The deep nesting in
particular is a bit hard to follow: once you’re inside a lambda inside a
tabulate inside a reduce inside a tabulate inside
another tabulate inside a flatten, it can be easy to lose
track of what’s going on. This could probably be refactored to be more
readable, but for now I stuck to as close a translation as possible: since
Karpathy had a bunch of list comprehensions inside nested loops, I kept the
same structure.
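To give an idea of what such a refactor might look like (a sketch only, not something I have polished or benchmarked), the per-head computation could be pulled out into its own function:
-- Sketch of a possible refactor (not part of the port): one attention head,
-- reading its slice of the per-layer KV cache.
def attention_head [n] (pos_id: i64) (hs: i64) (q: [n]f32)
                       (keys: [block_size][n]f32) (values: [block_size][n]f32)
                       : [head_dim]f32 =
  let scale = 1f32 / f32.sqrt (f32.i64 head_dim)
  let attn_logits = tabulate block_size (\t ->
    let dot = reduce (+) 0f32 (tabulate head_dim (\j -> q[hs + j] * keys[t, hs + j]))
    in if t <= pos_id then dot * scale else -1e30f32)
  let attn_weights = softmax attn_logits
  in tabulate head_dim (\j ->
       reduce (+) 0f32 (tabulate block_size (\t -> attn_weights[t] * values[t, hs + j])))

-- The attention block in gpt would then shrink to roughly:
--   let x_attn = flatten (tabulate n_head (\h ->
--                  attention_head pos_id (h * head_dim) q keys[li] values[li]))
--                :> [n_embd]f32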
There is also one minor bit of ceremony required to please the size-typing system:
Futhark infers that the result of flatten is a 1d array of size n_head * head_dim
(because we flattened an array of size [n_head][head_dim]), but it isn’t able to
further infer that this is the same as a 1d array of size n_embd. So we need the
:> size coercion operator. On the other hand, there is a minor readability
improvement: map2, which maps a 2-parameter function elementwise across two
arrays, is more intuitive imo than the zip version.
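Concretely, the coercion just asserts, with a runtime check, that the two sizes agree. In isolation (the helper name is mine, purely for illustration):
-- Illustration only: flatten produces a [n_head * head_dim]f32, and `:>`
-- checks at runtime that this size equals n_embd.
def heads_to_embedding (hs: [n_head][head_dim]f32) : [n_embd]f32 =
  flatten hs :> [n_embd]f32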
* * *
Missing here, of course, is the star of the show: training the model!
That will come in Part II, along with some benchmarks.