¡Reproduzcamos NanoGPT con JAX! (Parte 1) | de Luis Wang | julio de 2024

¡Reproduzcamos NanoGPT con JAX!  (Parte 1) |  de Luis Wang |  julio de 2024

Inspirado en el reciente vídeo de YouTube de Andrej Kapathy en Reproduzcamos GPT-2 (124M)Me gustaría reconstruirlo con la mayoría de las optimizaciones de entrenamiento en Jax. Jax está diseñado para una velocidad de computación muy eficiente y es bastante interesante comparar Pytorch con su reciente optimización de entrenamiento y Jax con sus bibliotecas asociadas como Flax (API de capas para entrenar redes neuronales para Jax) y Optax (una biblioteca de optimización y procesamiento de gradientes). para JAX). Aprenderemos rápidamente qué es Jax y reconstruiremos el GPT con Jax. ¡Al final, compararemos el token/seg con el entrenamiento multiGPU entre Pytorch y Jax!

GPT generado por IA

¿Qué es Jax?

Basado en su lee el documentoJAX es una biblioteca de Python para computación de matrices orientada a aceleradores y transformación de programas, diseñada para computación numérica de alto rendimiento y aprendizaje automático a gran escala. Me gustaría presentarles a JAX con su nombre. Mientras alguien lo llama simplemente otro XLA (Álgebra lineal acelerada), prefiero llamarlo J(it) A(utograd) X(LA) para demostrar su capacidad de alta eficiencia.

J: compilación justo a tiempo (JIT). Cuando ejecuta su función Python, Jax la convierte en un conjunto primitivo de operaciones llamado Jaxpr. Luego, la expresión Jaxpr se convertirá en una entrada para XLA, que compila los scripts de nivel inferior para producir un ejecutable optimizado para el dispositivo de destino (CPU, GPU o TPU).

A — Autogrado. Calcular gradientes es una parte esencial de los métodos modernos de aprendizaje automático y simplemente puede llamar jax.grad() para obtener gradientes que permitan optimizar los modelos.

X-XLA. Es un compilador de aprendizaje automático de código abierto para aceleradores de CPU, GPU y ML. Normalmente, XLA realiza múltiples pases integrados de optimización y análisis en el HLO estable gráfico, luego envía el cálculo de HLO a un backend para optimizaciones adicionales en el nivel de HLO. Luego, el backend realiza la generación de código específico del objetivo.

Estas son solo algunas de las características clave de JAX, pero también tiene muchas API similares a numpy fáciles de usar. jax.numpy y vectorización automática con jax.vmap y paralelice sus códigos en múltiples dispositivos a través de jax.pmap Cubriremos más conceptos y aplicaciones de Jax en blogs futuros, ¡pero ahora repliquemos NanoGPT con Jax!

De la atención al transformador

GPT es un modelo de transformador solo decodificador y el componente clave es el módulo de Atención. Primero podemos definir una clase de datos de configuración del modelo para guardar los hiperparámetros del modelo, de modo que el módulo del modelo pueda consumirlos de manera eficiente para inicializar la arquitectura del modelo. De manera similar al modelo GPT 124M, aquí inicializamos un decodificador transformador de 12 capas con 12 cabezales y un tamaño de vocabulario de 50257 tokens, cada uno con 768 dimensiones de incrustación. El tamaño del bloque para el cálculo de la atención es 1024.

from dataclasses import dataclass

@dataclass
class ModelConfig:
vocab_size: int = 50257
n_head: int = 12
n_embd: int = 768
block_size: int = 1024
n_layer: int = 12
dropout_rate: float = 0.1

Pasemos ahora al elemento clave del modelo transformador: la atención. La idea es procesar las entradas en tres matrices de peso: clave, consulta y valor. Aquí confiamos en flax una capa Jax y una biblioteca API de entrenamiento para inicializar la matriz de 3 pesos, simplemente llamando al flax.linen.Dense . Como se mencionó, Jax tiene muchas API similares a numpy, por lo que remodelamos las salidas después de la matriz de peso con jax.numpy.reshape desde [batch_size, sequence_length, embedding_dim] tiene [batch_size, sequence_length, num_head, embedding_dim / num_head]. Dado que necesitamos realizar la multiplicación de matrices en las matrices de clave y valor, jax también tiene jax.numpy.matmul API y jax.numpy.transpose (transponer la matriz clave para la multiplicación).

Atención multidireccional

Tenga en cuenta que debemos poner una máscara en la matriz de atención para evitar la fuga de información (evitar que los tokens anteriores tengan acceso a los tokens posteriores). jax.numpy.tril ayuda a construir una red de triángulos inferiores, y jax.numpy.where podemos llenar el número infinito para que obtengamos 0 después de softmax jax.nn.softmax Los códigos completos para la atención multidireccional se pueden encontrar a continuación.

from flax import linen as nn
import jax.numpy as jnp

class CausalSelfAttention(nn.Module):

config: ModelConfig

@nn.compact
def __call__(self, x, deterministic=True):

assert len(x.shape) == 3

b, l, d = x.shape

q = nn.Dense(self.config.n_embd)(x)
k = nn.Dense(self.config.n_embd)(x)
v = nn.Dense(self.config.n_embd)(x)
# q*k / sqrt(dim) -> softmax -> @v
q = jnp.reshape(q, (b, l, d//self.config.n_head , self.config.n_head))
k = jnp.reshape(k, (b, l, d//self.config.n_head , self.config.n_head))
v = jnp.reshape(v, (b, l, d//self.config.n_head , self.config.n_head))
norm = jnp.sqrt(list(jnp.shape(k))[-1])
attn = jnp.matmul(q,jnp.transpose(k, (0,1,3,2))) / norm
mask = jnp.tril(attn)
attn = jnp.where(mask[:,:,:l,:l], attn, float("-inf"))
probs = jax.nn.softmax(attn, axis=-1)
y = jnp.matmul(probs, v)
y = jnp.reshape(y, (b,l,d))
y = nn.Dense(self.config.n_embd)(y)
return y

Puedes notar que no hay __init__ O forward métodos como puedes ver en pytorch. Esta es la característica especial de jax, donde puedes definir explícitamente capas con setup métodos, o definirlos implícitamente en el pase directo agregando nn.compact sobre __call__ método. [ref]

A continuación, construyamos la capa MLP y Bloque, que incluye la capa Densa, la función de activación Gelu, LayerNorm y Dropout. Nuevamente, flax.linen tiene las API de capa para ayudarnos a construir el módulo. Tenga en cuenta que vamos a pasar un deterministic Variable booleana para controlar diferentes comportamientos durante el entrenamiento o evaluación para ciertas capas como Dropout.

class MLP(nn.Module):

config: ModelConfig

@nn.compact
def __call__(self, x, deterministic=True):
x = nn.Dense(self.config.n_embd*4)(x)
x = nn.gelu(x, approximate=True)
x = nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=deterministic)
x = nn.Dense(self.config.n_embd)(x)
x = nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=deterministic)
return x

class Block(nn.Module):

config: ModelConfig

@nn.compact
def __call__(self, x):
x = nn.LayerNorm()(x)
x = x + CausalSelfAttention(self.config)(x)
x = nn.LayerNorm()(x)
x = x + MLP(self.config)(x)
return x

Ahora usemos los bloques anteriores para construir el NanoGPT:

Dadas las entradas de un token de secuencia de identificadores, utilizamos el flax.linen.Embed capa para obtener incrustaciones de posición e incrustaciones de tokens. Luego los pasamos al módulo Bloque N veces, donde N es el número de capas definidas en la configuración del modelo. Al final, asignamos las salidas del último bloque a las probabilidades de cada token en el vocabulario para predecir el siguiente token. además del delantero __call__ método, creemos también un init métodos para obtener las entradas ficticias para obtener los parámetros del modelo.

class GPT(nn.Module):

config: ModelConfig

@nn.compact
def __call__(self, x, deterministic=False):

B, T = x.shape
assert T <= self.config.block_size

pos = jnp.arange(0, T)[None]
pos_emb = nn.Embed(self.config.block_size, self.config.n_embd)(pos)
wte = nn.Embed(self.config.vocab_size, self.config.n_embd)
tok_emb = wte(x)
x = tok_emb + pos_emb

for _ in range(self.config.n_layer):
x = Block(self.config)(x)
x = nn.LayerNorm()(x)
logits = nn.Dense(config.n_embd, config.vocab_size)
# logits = wte.attend(x) # parameter sharing
return logits

def init(self, rng):
tokens = jnp.zeros((1, self.config.block_size), dtype=jnp.uint16)
params = jax.jit(super().init, static_argnums=(2,))(rng, tokens, True)
return params

Ahora verifiquemos la cantidad de parámetros: primero inicializamos la clase de datos de configuración del modelo y la clave aleatoria, luego creamos entradas ficticias y las introducimos en el modelo GPT. Luego usamos el jax.util.treemap API para crear una función de recuento de parámetros. hemos adquirido 124439808 (124M) parámetros, la misma cantidad que GPT2 de Huggingface, BOOM!

Resultado de colab: número de parámetros
Verifique el número de parámetros en huggingface GPT2

DataLoader y bucle de entrenamiento

Ahora sobreajustaremos un pequeño conjunto de datos. Para hacerlo comparable en el video de Andrej sobre Pytorch NanoGPT, usemos el juguete base de datos que compartió en su video. Usamos el tokenizador GPT2 de tiktoken biblioteca para tokenizar todos los textos en el archivo de entrada y convertir los tokens en jax.numpy.array para entrenar el modelo de Jax.

class DataLoader:
def __init__(self, B, T):
self.current_position = 0
self.B = B
self.T = T

with open("input.txt","r") as f:
text = f.read()
enc = tiktoken.get_encoding("gpt2")
self.tokens = jnp.array(enc.encode(text))
print(f"loaded {len(self.tokens)} tokens in the datasets" )
print(f" 1 epoch = {len(self.tokens)//(B*T)} batches")

def next_batch(self):
B,T = self.B, self.T
buf = self.tokens[self.current_position:self.current_position+B*T+1]
x,y = jnp.reshape(buf[:-1],(B,T)), jnp.reshape(buf[1:],(B,T))
self.current_position += B*T
if self.current_position + B*T+1 > len(self.tokens):
self.current_position = 0
return x,y

Resultado de colaboración: cargador de datos simple con 4 tamaños de lote y 128 longitudes de secuencia

A continuación, olvidémonos primero de la optimización y el entrenamiento distribuido y simplemente creemos un ciclo de entrenamiento ingenuo para verificar la coherencia. Lo primero que debe hacer después de inicializar el modelo es crear un Estado del trenun estado del modelo en el que podemos actualizar parámetros y gradientes. TrainState toma tres entradas importantes: apply_fn (función de transferencia del modelo), params (parámetros del modelo del método init) y tx (una transformación de gradiente Optax).

Luego usamos la función train_step para actualizar el estado del modelo (gradientes y parámetros) para continuar con el entrenamiento del modelo. Optax proporcionar la entropía cruzada softmax como función de pérdida para la siguiente tarea de predicción de token, y jax.value_and_grad Calcula los gradientes y el valor de pérdida para la función de pérdida. Finalmente, actualizamos el estado del modelo con los nuevos parámetros usando la función apply_gradients API. [ref] ¡No olvides ejecutar la función train_step para reducir la carga computacional!

def init_train_state(key, config) -> TrainState:
model = GPT(config)
params = model.init(key)
optimizer = optax.adamw(3e-4, b1=0.9, b2=0.98, eps=1e-9, weight_decay=1e-1)
train_state = TrainState.create(
apply_fn=model.apply,
params=params,
tx=optimizer)
return train_state

@jax.jit
def train_step(state: TrainState, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, TrainState]:

def loss_fn(params: FrozenDict) -> jnp.ndarray:

logits = state.apply_fn(params, x, False)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
return loss

loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state.params)
new_state = state.apply_gradients(grads=grads)
return loss, new_state

Ahora todo está listo para el circuito de entrenamiento del pobre. Comprobemos el valor de la pérdida. La predicción del modelo debería ser mejor que la suposición aleatoria, por lo que la pérdida debería ser menor que -ln(1/50257)≈10,825. Lo que esperamos del sobreajuste de un solo lote es que: al principio la pérdida es cercana a 10,825, luego baja a cerca de 0. Tomemos un lote de (x, y) y ejecutemos el ciclo de entrenamiento 50 veces. También agrego un logaritmo similar para calcular la velocidad de entrenamiento.

Como podemos ver, el valor de pérdida es exactamente lo que esperamos y el rendimiento del entrenamiento es de alrededor de 400-500 mil tokens/seg. Que ya es 40 veces más rápido que la versión inicial de Pytorch sin ninguna optimización en el vídeo de Andrej. Tenga en cuenta que estamos ejecutando los scripts Jax en 1 GPU A100, lo que debería eliminar la diferencia de hardware para la comparación de velocidades. No hay .to(device) elementos para mover su modelo o datos desde la CPU host a la GPU del dispositivo, que es uno de los beneficios de Jax!

Eso es todo, está hecho y lo hemos logrado. Haremos que el entrenamiento sea 10 veces más rápido en la parte 2 con más optimizaciones…

Parte 2:¡El viaje de optimización del entrenamiento a 1350.000 tokens/seg en una sola GPU!

«A menos que se indique lo contrario, todas las imágenes son del autor»