jax.lax.scan is a function that makes loops jit-able
jax.lax.scan
Jax is a machine learning library developed and used mostly by Google. Jax traces your implementation into a computation graph which is executed on your CPU, GPU or TPU. One of the main advantages of using Jax for your implementation is that Jax allows you to jit your functions. jit stands for Just-In-Time compilation: your function is compiled into code native to your GPU (or other accelerator), which makes it significantly faster. The catch is that jit works by tracing your function, so ordinary Python loops get unrolled at trace time, and loops whose length depends on a runtime value cannot be traced at all. jax.lax.scan allows you to get around this limitation by letting you define a loop with a pre-specified length. But how does it work?
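To make the limitation concrete, here is a small sketch of my own (not from the article or the Jax docs) that compares a plain Python loop under jit with the same sum written as a jax.lax.scan:

import jax
import jax.numpy as jnp

@jax.jit
def sum_python_loop(xs):
    # The Python loop is unrolled while tracing: one add per element is baked
    # into the graph, so compile time grows with the length of xs.
    total = jnp.zeros((), dtype=xs.dtype)
    for i in range(xs.shape[0]):
        total = total + xs[i]
    return total

@jax.jit
def sum_scan(xs):
    # scan compiles to a single loop primitive, independent of the length of xs.
    total, _ = jax.lax.scan(lambda carry, x: (carry + x, None), jnp.zeros((), dtype=xs.dtype), xs)
    return total

xs = jnp.arange(5.0)
print(sum_python_loop(xs), sum_scan(xs))  # 10.0 10.0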
According to the Jax documentation, the following code is essentially a pure-Python translation of the function:
import numpy as np

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        # f consumes the carry and the current element, and returns
        # the new carry plus an output for this step
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
The code may be a little convoluted to understand. A simpler way to understand it is to look at some simple examples. The scan function takes three main arguments and scans over the third one. The first argument is a function to execute on each scan iteration. The second argument is some pytree structure which we initially start from (the initial carry). Let's look at a simple example.
scan(lambda x, y: (x+y, y+2,), 0, [1, 2, 3])
The above code scans through the list [1, 2, 3] and, for each element, the lambda returns a tuple consisting of the new carry and an output for that step. The output of the above code will be:
(6, [3, 4, 5])
Here x is the carry argument, which is initialized to 0. y is each element of the array [1, 2, 3], passed sequentially to the function. In the first pass x is 0 and y is 1, so the return value is (x+y, y+2,) = (0+1, 1+2,) = (1, 3,). For the next iteration, x is 1, since 1 is the value of the carry returned in the last iteration, and y is 2, since 2 is the next value in the input array. Thus, (x+y, y+2,) = (1+2, 2+2,) = (3, 4,). In the final iteration, x is 3 and y is 3, so the carry ends at 6 while the last y+2 is 5. Thus, the scan function returns 6 (the carry) and [3, 4, 5] (all the y+2 values stacked together).
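To run the same example with jax.lax.scan itself rather than the pure-Python reference above, the scanned input has to be an array with a leading axis (a plain Python list of scalars will not work); a minimal sketch:

import jax
import jax.numpy as jnp

carry, ys = jax.lax.scan(lambda x, y: (x + y, y + 2), jnp.array(0), jnp.array([1, 2, 3]))
print(carry)  # 6
print(ys)     # [3 4 5]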
Here we used a simple example, but scan can be used for more complicated computations, such as the training loop of a neural network. This makes it possible to jit-compile the entire training process, resulting in a large gain in training speed.
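As an illustration (a minimal sketch of my own, not the article's model), here is a tiny linear-regression training loop where every scan step performs one SGD update on one batch, so the whole training loop is jit-compiled as a single function:

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train(params, xs, ys):
    lr = 0.1

    def step(params, batch):
        x, y = batch
        grads = jax.grad(loss_fn)(params, x, y)
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return new_params, loss_fn(new_params, x, y)  # carry the params, record the loss

    return jax.lax.scan(step, params, (xs, ys))

key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (100, 8, 2))   # 100 batches of 8 two-feature examples
ys = xs @ jnp.array([2.0, -1.0]) + 3.0     # synthetic targets
params = (jnp.zeros(2), jnp.zeros(()))     # (weights, bias)
params, losses = train(params, xs, ys)
print(losses[-1])                          # loss after the final step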
In Jax, it is common to use NamedTuples to store various things such as model parameters. The third argument of scan can then be used to iterate over the leading dimension of every array in the NamedTuple. Here is an example. Consider a transformer decoder with 32 blocks (as in Llama). The parameters of all 32 blocks may be saved in a single NamedTuple whose array shapes look as follows:
DecoderBlock(
    input_norm=(32, 4096),
    attention=Attention(
        q_proj=(32, 4096, 1, 32, 128),
        k_proj=(32, 4096, 32, 128),
        v_proj=(32, 4096, 32, 128),
        out_proj=(32, 1, 32, 128, 4096)
    ),
    post_attn_norm=(32, 4096),
    gate_proj=(32, 4096, 11008),
    up_proj=(32, 4096, 11008),
    down_proj=(32, 11008, 4096)
)
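Such a stacked structure is typically built by stacking the 32 per-block parameter trees along a new leading axis. A hypothetical helper (not from the article) might look like this:

import jax
import jax.numpy as jnp

def stack_blocks(blocks):
    # Given a list of per-block parameter NamedTuples with identical structure,
    # stack corresponding leaves so every array gains a leading (32, ...) axis,
    # which is the axis jax.lax.scan iterates over.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *blocks)

# params = stack_blocks([block_0, block_1, ...])  # hypothetical list of 32 per-block NamedTuples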
The following function can be used to run inference through each of the 32 blocks of the DecoderBlock:
def inner(state, input_):
    # state carries the PRNG key and the running sequence;
    # input_ is the parameter slice for the current block (one step along the leading 32 axis)
    key, seq = state
    key, subkey = split_key_nullable(key)
    seq = decoder_block(input_, seq, attn_mask, key=subkey, model_config=model_config)
    return (key, seq), None

(key, seq), _ = jax.lax.scan(inner, (key, seq), params)
Inside the inner function, input_ is represented in the following format:
DecoderBlock(
    input_norm=(4096,),
    attention=Attention(
        q_proj=(4096, 1, 32, 128),
        k_proj=(4096, 32, 128),
        v_proj=(4096, 32, 128),
        out_proj=(1, 32, 128, 4096)
    ),
    post_attn_norm=(4096,),
    gate_proj=(4096, 11008),
    up_proj=(4096, 11008),
    down_proj=(11008, 4096)
)
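In other words, what scan passes to inner on each iteration is equivalent to slicing the leading axis of every leaf yourself; assuming the params structure above, the first block's slice would be:

# Hypothetical check: the slice that inner sees on the first iteration
block_0 = jax.tree_util.tree_map(lambda leaf: leaf[0], params)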