PixArt-α con menos de 8GB de VRAM

Realiza el proceso de inferencia de este modelo generativo de imágenes con tan solo 6.4GB de VRAM.

Introducción

En un artículo anterior, vimos cómo implementar cada componente de Stable Diffusion por separado para entender cómo interactúan entre sí. Gracias a este conocimiento podemos optimizar cualquier modelo de difusión para distintas situaciones, por ejemplo para disminuir el consumo de memoria o reducir el tiempo de inferencia.

Ya que hay mundo más allá de Stable Diffusion, en este artículo vamos a optimizar el uso de memoria del modelo de difusión fotorealístico llamado PixArt-α. Este modelo tan creativo ha sido entrenado para generar imágenes de 1024x1024 y cuenta con una serie de ventajas respecto a Stable Diffusion, siendo la mayor de ellas la reducción en un 90% del tiempo y costes del proceso de entrenamiento.

La arquitectura es similar a Stable Diffusion al tratarse de un modelo de difusión. Se utiliza un tokenizador (T5Tokenizer), un transformer del tipo text encoder (T5EncoderModel), un variational autoencoder (VAE), un scheduler (DPM-Solver) y para limpiar el ruido se utiliza otro modelo de tipo transformer (Transformer2DModel) en vez de un modelo U-Net.

Este modelo necesita 23GB de memoria pero gracias a diffusers los requerimientos bajan a 11GB. Nosotros lo haremos en 6.4GB.

Al instanciar los componentes del modelo por separado podemos elegir en qué momento cargamos cada componente en memoria, para así no sobrepasar nunca los 8GB y poder ejecutar este modelo en tarjetas gráficas de gama media. Por supuesto, la velocidad se resiente, pero podremos generar imágenes en apenas 20 segundos utilizando una NVIDIA RTX 2070.

FP8 vs FP16

La documentación oficial posee un artículo explicando como ejecutar este modelo utilizando menos de 8GB, pero utiliza bitsandbytes para bajar la precisión del text encoder a 8 bits, algo que reduce la calidad del resultado. En este artículo utilizaremos un enfoque distinto manteniendo 16 bits de precisión.

Instalación de librerías

La puesta a punto es la misma de siempre: utilizar Python 3.10, instalar CUDA, crear/activar un entorno virtual e instalar las librerías necesarias:

Crear el entorno virtual
python -m venv .venv
Activar el entorno virtual
# Unix
source .venv/bin/activate

# Windows
.venv\Scripts\activate
Instalar librerías necesarias
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install transformers diffusers accelerate sentencepiece beautifulsoup4 ftfy

Proceso de inferencia

Creamos el fichero inference.py y empezamos a escribir nuestra pequeña aplicación.

Repositorio del blog

Si quieres copiar y pegar el código entero, recuerda que lo tienes disponible en articles/pixart-a-with-less-than-8gb-vram/inference.py.

En el repositorio del blog en GitHub encontrarás todo el contenido asociado con este y otros artículos.

Importar lo necesario

Vamos a importar unas cuantas librerías:

  • Python
import torch
from diffusers import PixArtAlphaPipeline
from transformers import T5EncoderModel
import gc

Haremos uso del garbage collector para sacar de la memoria los modelos tras haberlos utilizado.

Inicializar parámetros

No podemos generar varias imágenes a la vez (en paralelo / batch_size), ya que esto aumentaría bastante el uso de memoria. Tampoco es muy buena idea generar solo una imagen, ya que cargar los modelos en memoria también tarda algo de tiempo. La solución ideal para mantener un uso mínimo de memoria y exprimir la velocidad de generación al máximo, es utilizar una cola de generación (queue) para generar una imagen tras otra (en serie).

  • Python
queue = []

# Generar imagen con este prompt. Mantener el resto de parámetros por defecto
queue.extend([{ 'prompt': 'Oppenheimer sits on the beach on a chair, watching a nuclear explosion with a huge mushroom cloud, 1200mm' }])

# Generar imagen utilizando valores específicos para todos los parámetros
queue.extend([{
  'prompt': 'pirate ship trapped in a cosmic malestrom nebula',
  'width': 1024,
  'height': 1024,
  'seed': 1152753,
  'cfg': 5,
  'steps': 30,
}])

# Generar 4 imágenes con este prompt. No utilizar semilla aquí o saldrán todas iguales
queue.extend([{ 'prompt': 'supercar', 'cfg': 4 } for _ in range(3)])

Embeddings y Transformer

El siguiente paso es convertir el prompt en tokens, después en un embedding y por último pasarlo por el transformer que aplica los mecanismos de atención. En PixArt-α, al igual que en Stable Diffusion, el text encoder ya se encarga de producir un embedding transformado, por lo que las dos últimas partes ya están unidas.

Además, para no hacer todo desde cero, vamos a abstraer el código un poco más utilizando la pipeline proporcionada por Hugging Face y así nos ahorramos tokenizar nosotros mismos.

Cargamos el text encoder con una precisión de 16 bits (fp16):

  • Python
text_encoder = T5EncoderModel.from_pretrained(
  'PixArt-alpha/PixArt-XL-2-1024-MS',
  subfolder='text_encoder',
  torch_dtype=torch.float16,
  device_map='auto',
)

El parámetro device_map='auto' carga los modelos (en este caso el text encoder) donde sea posible. Primero utiliza la memoria de la tarjeta gráfica, después empieza a utilizar la memoria RAM y, por último, el disco duro (aunque por el bien de tu salud mental, espero que no tenga que recurrir a la RAM y menos aún al disco).

Asignamos este text encoder a la pipeline y también le indicamos que no queremos utilizar ningún transformer, así evitamos cargarlo en memoria de momento (transformer=None se refiere al transformer que se encarga de limpiar el ruido).

  • Python
pipe = PixArtAlphaPipeline.from_pretrained(
  'PixArt-alpha/PixArt-XL-2-1024-MS',
  torch_dtype=torch.float16,
  text_encoder=text_encoder,
  transformer=None,
  device_map='auto',
)
Bye VAE

Tampoco necesitamos cargar el VAE en memoria así que podríamos haber añadido vae=None. El problema es que recibiremos un error porque la implementación del pipeline parece no estar preparada para ser usada así. No he querido indagar mucho ya que el objetivo del artículo es nunca sobrepasar los 8GB de memoria y lo cumplimos con creces. Pero que sepas que podríamos ahorrar otros 300MB extra de esta manera e incluso poder llegar a utilizar menos de 6GB.

Para no tener que volver a cargar este modelo, procesaremos todos los prompts y guardaremos el resultado para su posterior uso:

  • Python
with torch.no_grad():
  for generation in queue:
    generation['embeddings'] = pipe.encode_prompt(generation['prompt'])

La línea with torch.no_grad() desactiva el cálculo automático del gradiente. Sin entrar en detalle, es algo que no necesitamos para el proceso de inferencia y evitamos utilizar memoria innecesariamente. Entraremos en este contexto cada vez que hagamos uso de un modelo parametrizado.

Como ya no necesitamos el text encoder, eliminamos las referencias al modelo para que el garbage collector haga limpieza y también borramos la caché de CUDA:

  • Python
del text_encoder
del pipe
gc.collect()
torch.cuda.empty_cache()
max_memory_allocated

Puedes ver el máximo de memoria que se ha consumido en cada punto de la aplicación mediante:

  • Python
print(f'Máxima memoria utilizada: {torch.cuda.max_memory_allocated(device="cuda")}')

Generaración de imágenes

Los modelos de difusión trabajan sobre un tensor lleno de ruido que van limpiando a lo largo de una serie de pasos gestionados por el scheduler. Cuando finaliza este bucle se pasa el tensor por el VAE para obtener la imagen final.

En el caso del pipeline que estamos utilizando ya se abstrae en una única función todo este proceso. Así que, vamos a instanciar de nuevo otra pipeline solo que en este caso no cargaremos el text encoder. Es decir, esta pipeline contendrá el resto de componentes: el modelo transformer que se encarga de limpiar el ruido (lo que sería la U-Net de Stable Diffusion), el scheduler y el VAE, que se encarga de convertir el tensor limpio de ruido en una imagen.

  • Python
pipe = PixArtAlphaPipeline.from_pretrained(
  'PixArt-alpha/PixArt-XL-2-1024-MS',
  torch_dtype=torch.float16,
  text_encoder=None,
).to('cuda')

En este caso es necesario utilizar to('cuda') en vez de device_map='auto'.

Ahora utilizamos un bucle para procesar todos los tensores de embedding, uno tras otro. Recuerda que esta función ya se encarga de generar un tensor lleno de ruido y de limpiarlo durante varios pasos.

  • Python
for i, generation in enumerate(queue, start=1):
  generator = torch.Generator(device='cuda')

  if 'seed' in generation:
    generator.manual_seed(generation['seed'])
  else:
    generator.seed()

  image = pipe(
    negative_prompt=None,
    width=generation['width'] if 'width' in generation else 1024,
    height=generation['height'] if 'height' in generation else 1024,
    guidance_scale=generation['cfg'] if 'cfg' in generation else 7,
    num_inference_steps=generation['steps'] if 'steps' in generation else 20,
    generator=generator,
    prompt_embeds=generation['embeddings'][0],
    prompt_attention_mask=generation['embeddings'][1],
    negative_prompt_embeds=generation['embeddings'][2],
    negative_prompt_attention_mask=generation['embeddings'][3],
    num_images_per_prompt=1,
  ).images[0]

  image.save(f'image_{i}.png')

En este bucle hemos instanciado un generador al que le asignamos la semilla que hemos definido (generator.manual_seed(generation['seed'])) o una semilla aleatoria en caso contrario (generator.seed()).

Después, lo único que tenemos que hacer es pasarle todos los argumentos necesarios al pipeline:

  • negative_prompt: PixArt-α no acepta un prompt negativo (que yo sepa).
  • width y height: El tamaño de la imagen que hemos especificado o su valor por defecto (1024).
  • guidance_scale: El valor de CFG que hemos especificado o su valor por defecto (7).
  • num_inference_steps: El valor de steps que hemos especificado o su valor por defecto (20).
  • generator: El generador que contiene la semilla.
  • prompt_embeds, prompt_attention_mask, negative_prompt_embeds y negative_prompt_attention_mask: Estos valores son los que devolvió el text encoder y habíamos guardado para utilizarlos aquí (están dentro de la lista en este orden).
  • num_images_per_prompt: La cantidad de imágenes que se generan a la vez (batch size). Si optimizamos por uso de memoria no tiene sentido cambiar este valor.

El pipeline nos devuelve un diccionario (images) en el que se encuentran las imágenes ya decodificadas. Como num_images_per_prompt siempre va a ser 1, podemos acceder directamente a la única imagen mediante images[0]. Guardamos las imágenes en el disco.

Los resultados son estos:

Utilizando el VAE manualmente

Como dije arriba, no podemos utilizar el pipeline con el parámetro vae=None porque da error. Si se pudiera, podríamos ejecutar el VAE también por separado (o si no queremos utilizar la abstracción que ofrece el pipeline).

En este hipotético caso, si utilizamos el parámetro output_type='latent' en el pipeline, éste nos devolverá un tensor en el espacio latente dentro de la propiedad images (no utilices [0] aquí).

  • Python
for generation in queue:
  # generator = ...

  generation['latents'] = pipe(
    # Resto de parámetros
    output_type='latent',
  ).images

No podemos limpiar el tensor y pasarlo por el VAE en el mismo bucle, ya que se cargarían en memoria ambos modelos y sobrepasaríamos los 8GB de memoria. Hay que hacerlo en dos bucles.

Tras finalizar el primer bucle podemos borrar el transformer ya que no lo vamos a necesitar más (aunque aparentemente no es obligatorio).

  • Python
del pipe.transformer
gc.collect()
torch.cuda.empty_cache()

Y ahora sí, iniciamos el segundo que decodificará el tensor del espacio latente para generar un tensor en el espacio de imagen:

  • Python
with torch.no_grad():
  for i, generation in enumerate(queue, start=1):
    image = pipe.vae.decode(
      generation['latents'] / pipe.vae.config.scaling_factor,
      return_dict=False,
    )[0]

    image = pipe.image_processor.postprocess(image, output_type='pil')[0]
    image.save(f'image_{i}.png')

Recuerda que tal como vimos en el artículo anterior, hay que tener en cuenta el factor de escala del VAE.

El método postprocess se encarga de convertir el tensor en una imagen de Pillow (output_type='pil'). Este es el mismo proceso que hicimos en el otro artículo mediante el método ToPILImage de la librería torchvision.

Por último guardamos la imagen de la misma manera.

Conclusión

En este artículo hemos salido de la órbita de Stable Diffusion para explorar la arquitectura de PixArt-α, un modelo de difusión bastante creativo que utiliza diferentes componentes. Y ya que es un modelo algo exigente, hemos aprovechado el conocimiento previo para optimizar el uso de memoria cargando los componentes solo cuando es necesario.

Puedes apoyarme para que pueda dedicar aún más tiempo a escribir artículos y tener recursos para crear nuevos proyectos. ¡Gracias!