PixArt-α with less than 8GB VRAM

Run inference with this generative image model using just 6.4GB of VRAM.

Introduction

In a previous article, we saw how to implement each component of Stable Diffusion separately to understand how they interact with each other. Thanks to this knowledge we can optimize any diffusion model for different situations, for example to decrease memory consumption or reduce inference time.

Since there is a world beyond Stable Diffusion, in this article we are going to optimize the memory usage of the photorealistic diffusion model called PixArt-α. This very creative model has been trained to generate 1024x1024 images and has a series of advantages over Stable Diffusion, the most significant being a 90% reduction in training time and cost.

The architecture is similar to Stable Diffusion as it is a diffusion model. It employs a tokenizer (T5Tokenizer), a text encoder transformer (T5EncoderModel), a variational autoencoder (VAE), a scheduler (DPM-Solver), and for noise removal, it utilizes another transformer model (Transformer2DModel) instead of a U-Net model.

This model needs 23GB of memory but thanks to diffusers the requirements drop to 11GB. We will do it in 6.4GB.

By instantiating the model components separately we can choose when to load each component into memory, thereby never exceeding the 8GB limit and enabling the execution of this model on mid-range graphics cards. Of course, speed is compromised, but we can still generate images in just 20 seconds using an NVIDIA RTX 2070.

8-bit vs 16-bit

The official documentation has an article explaining how to run this model using less than 8GB, but it uses bitsandbytes to lower the precision of the text encoder to 8 bits, which reduces the quality of the result. In this article we will use a different approach, maintaining 16 bits of precision.

Installation of libraries

The setup is the same as usual: use Python 3.10, install CUDA, create and activate a virtual environment, and install the necessary libraries:

Create the virtual environment
python -m venv .venv
Activate the virtual environment
# Unix
source .venv/bin/activate

# Windows
.venv\Scripts\activate
Install required libraries
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install transformers diffusers accelerate sentencepiece beautifulsoup4 ftfy

Inference process

We create the file inference.py and start writing our small application.

Blog repository

If you want to copy and paste the entire code, remember that it is available at articles/pixart-a-with-less-than-8gb-vram/inference.py.

In the blog repository on GitHub you will find all the content associated with this and other articles.

Import what we need

We are going to import a few libraries:

  • Python
import torch
from diffusers import PixArtAlphaPipeline
from transformers import T5EncoderModel
import gc

We will make use of the garbage collector to remove the models from memory after they have been used.
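
If you prefer, this cleanup can be wrapped in a small helper. This is just a convenience sketch (the flush name is my own); the code below performs these calls inline when the time comes.

  • Python
def flush():
  # Drop objects that no longer have references and release the cached CUDA memory
  gc.collect()
  torch.cuda.empty_cache()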

Initialize parameters

We cannot generate multiple images at once (in parallel, with a higher batch_size), as this would increase memory usage quite a lot. It is also not a very good idea to generate just one image, since loading the models into memory also takes some time. The ideal solution for keeping memory usage to a minimum while maximizing generation speed is to use a generation queue and produce one image after another (serially).

  • Python
queue = []

# Generate image with this prompt. Keep the rest of the parameters as default
queue.extend([{ 'prompt': 'Oppenheimer sits on the beach on a chair, watching a nuclear explosion with a huge mushroom cloud, 1200mm' }])

# Generate image using specific values for all parameters
queue.extend([{
  'prompt': 'pirate ship trapped in a cosmic maelstrom nebula',
  'width': 1024,
  'height': 1024,
  'seed': 1152753,
  'cfg': 5,
  'steps': 30,
}])

# Generate 4 images with this prompt. Do not use seed here or they will all come out the same
queue.extend([{ 'prompt': 'supercar', 'cfg': 4 } for _ in range(4)])

Embeddings and Transformer

The next step is to convert the prompt into tokens, then into an embedding, and finally pass it through the transformer that applies the attention mechanisms. In PixArt-α, as in Stable Diffusion, the text encoder already takes care of producing the transformed embedding, so the last two steps are already combined.

Also, to avoid doing everything from scratch, let's abstract the code a bit more by using the pipeline provided by Hugging Face, which saves us from having to tokenize the prompt ourselves.

We load the text encoder with 16-bit precision (fp16):

  • Python
text_encoder = T5EncoderModel.from_pretrained(
  'PixArt-alpha/PixArt-XL-2-1024-MS',
  subfolder='text_encoder',
  torch_dtype=torch.float16,
  device_map='auto',
)

The device_map='auto' parameter loads the models (in this case the text encoder) wherever possible. It first utilizes the graphics card memory, then starts using RAM, and finally, the hard drive (although for the sake of your mental health, I hope it doesn't have to resort to RAM, let alone the disk).
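
If you are curious about where the weights actually ended up, transformers exposes the resulting placement when device_map is used; a quick sanity check could look like this:

  • Python
# Shows on which device(s) the text encoder weights were placed
print(text_encoder.hf_device_map)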

We assign this text encoder to the pipeline and also tell it that we don't want to use any transformer, thus avoiding loading it into memory for now (transformer=None refers to the transformer responsible for cleaning the noise).

  • Python
pipe = PixArtAlphaPipeline.from_pretrained(
  'PixArt-alpha/PixArt-XL-2-1024-MS',
  torch_dtype=torch.float16,
  text_encoder=text_encoder,
  transformer=None,
  device_map='auto',
)
Bye VAE

We also don't need to load the VAE into memory, so we could have added vae=None. The issue is that we get an error because the pipeline implementation doesn't seem to be ready to be used this way. I didn't want to dig too much, since the article's goal is to never exceed 8GB of memory and we are already well below that limit. But you should know that we could save another 300MB this way and even manage to use less than 6GB.

In order to avoid reloading this model, we will process all prompts and save the results for later use:

  • Python
with torch.no_grad():
  for generation in queue:
    generation['embeddings'] = pipe.encode_prompt(generation['prompt'])

The torch.no_grad() context manager disables automatic gradient calculation. Without going into detail, it is something we don't need for inference, so we avoid using memory unnecessarily. We will enter this context every time we run a parameterized model.
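
By the way, if you wonder what encode_prompt is abstracting away, a rough sketch of the manual equivalent would be tokenizing with the repository's T5Tokenizer and passing the result through the text encoder. Treat it as an illustration only; the max_length of 120 is an assumption on my part:

  • Python
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained(
  'PixArt-alpha/PixArt-XL-2-1024-MS',
  subfolder='tokenizer',
)

# Prompt -> token ids, padded/truncated to a fixed length (assumed here to be 120)
tokens = tokenizer(
  'supercar',
  padding='max_length',
  max_length=120,
  truncation=True,
  return_tensors='pt',
)

# Token ids -> embeddings already transformed by the T5 encoder
with torch.no_grad():
  embeddings = text_encoder(
    tokens.input_ids.to(text_encoder.device),
    attention_mask=tokens.attention_mask.to(text_encoder.device),
  )[0]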

As we no longer need the text encoder, we remove references to the model so that the garbage collector can clean up and we also clear the CUDA cache:

  • Python
del text_encoder
del pipe
gc.collect()
torch.cuda.empty_cache()
max_memory_allocated

You can see the maximum amount of memory that has been consumed at each point in the application using:

  • Python
print(f'Maximum memory used: {torch.cuda.max_memory_allocated(device="cuda")}')
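
The value is returned in bytes, so it can be handy to convert it to gigabytes. And if you want to measure each stage separately, PyTorch also lets you reset the peak counter; just a small optional convenience:

  • Python
# Peak memory in gigabytes, easier to read
peak_gb = torch.cuda.max_memory_allocated(device='cuda') / 1024**3
print(f'Maximum memory used: {peak_gb:.2f} GB')

# Optional: reset the counter to measure the next stage on its own
torch.cuda.reset_peak_memory_stats(device='cuda')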

Image generation

Diffusion models work on a tensor filled with noise, gradually cleaning it through a series of steps managed by the scheduler. When this loop ends, the tensor is passed through the VAE to obtain the final image.

In the case of the pipeline that we are using, all this process is already abstracted into a single function. So, let's instantiate another pipeline, but this time without loading the text encoder. In other words, this pipeline will contain the remaining components: the transformer model in charge of cleaning the noise (akin to the U-Net in Stable Diffusion), the scheduler, and the VAE, which is responsible for converting the noise-cleaned tensor into an image.

  • Python
pipe = PixArtAlphaPipeline.from_pretrained(
  'PixArt-alpha/PixArt-XL-2-1024-MS',
  torch_dtype=torch.float16,
  text_encoder=None,
).to('cuda')

In this case it is necessary to use to('cuda') instead of device_map='auto'.

Now we use a loop to process all the embedding tensors, one after another. Remember that this function already takes care of generating a tensor filled with noise and cleaning it over several steps.

  • Python
for i, generation in enumerate(queue, start=1):
  generator = torch.Generator(device='cuda')

  if 'seed' in generation:
    generator.manual_seed(generation['seed'])
  else:
    generator.seed()

  image = pipe(
    negative_prompt=None,
    width=generation['width'] if 'width' in generation else 1024,
    height=generation['height'] if 'height' in generation else 1024,
    guidance_scale=generation['cfg'] if 'cfg' in generation else 7,
    num_inference_steps=generation['steps'] if 'steps' in generation else 20,
    generator=generator,
    prompt_embeds=generation['embeddings'][0],
    prompt_attention_mask=generation['embeddings'][1],
    negative_prompt_embeds=generation['embeddings'][2],
    negative_prompt_attention_mask=generation['embeddings'][3],
    num_images_per_prompt=1,
  ).images[0]

  image.save(f'image_{i}.png')

In this loop we have instantiated a generator to which we assign the seed we have defined (generator.manual_seed(generation['seed'])) or a random seed otherwise (generator.seed()).

After that, all we have to do is pass all the necessary arguments to the pipeline:

  • negative_prompt: PixArt-α does not accept a negative prompt (as far as I know).
  • width and height: The size of the image we have specified or its default value (1024).
  • guidance_scale: The value of CFG that we have specified or its default value (7).
  • num_inference_steps: The steps value we have specified or its default value (20).
  • generator: The generator containing the seed.
  • prompt_embeds, prompt_attention_mask, negative_prompt_embeds and negative_prompt_attention_mask: These values are the ones returned by the text encoder and we had saved to use them here (they are inside the list in this order).
  • num_images_per_prompt: The number of images that are generated at once (batch size). If we optimize for memory usage it does not make sense to change this value.

The pipeline returns an object whose images property contains the already decoded images. Since num_images_per_prompt is always going to be 1, we can directly access the single image with images[0]. We then save the images to disk.

The results are these:

Using the VAE manually

As I mentioned above, we cannot use the pipeline with the vae=None parameter, as it results in an error. If it were possible (or if we simply didn't want to use the abstraction provided by the pipeline), we could run the VAE separately as well.

In this hypothetical scenario, if we use the output_type='latent' parameter in the pipeline, it will return a tensor in the latent space within the images property (don't use [0] here).

  • Python
for generation in queue:
  # generator = ...

  generation['latents'] = pipe(
    # Rest of parameters
    output_type='latent',
  ).images

We cannot clean the tensor and pass it through the VAE in the same loop, since both models would be loaded into memory and we would exceed 8GB of memory. It must be done in two loops.

After finishing the first loop we can remove the transformer since we won't need it anymore (although it seems to be optional).

  • Python
del pipe.transformer
gc.collect()
torch.cuda.empty_cache()

And now, we start the second one that decodes the tensor from the latent space to generate a tensor in the image space:

  • Python
with torch.no_grad():
  for i, generation in enumerate(queue, start=1):
    image = pipe.vae.decode(
      generation['latents'] / pipe.vae.config.scaling_factor,
      return_dict=False,
    )[0]

    image = pipe.image_processor.postprocess(image, output_type='pil')[0]
    image.save(f'image_{i}.png')

Remember that as we saw in the previous article, the VAE scale factor must be taken into account.

The postprocess method is responsible for converting the tensor into a Pillow image (output_type='pil'). This is the same process we did in the other article using the ToPILImage transform from the torchvision library.
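
For reference, a minimal sketch of that manual conversion, assuming the decoded tensor is still in the usual [-1, 1] range and that you have torchvision installed (it is not in the install list above), could be:

  • Python
from torchvision.transforms import ToPILImage

# Denormalize from [-1, 1] to [0, 1] and convert the first image of the batch
image = (image / 2 + 0.5).clamp(0, 1)
image = ToPILImage()(image[0].float().cpu())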

Finally we save the image in the same way.

Conclusion

In this article we have stepped out of the Stable Diffusion orbit to explore the architecture of PixArt-α, a quite creative diffusion model that uses different components. And since it's a somewhat demanding model, we have taken advantage of prior knowledge to optimize memory usage by loading the components only when necessary.

You can support me so that I can dedicate even more time to writing articles and have resources to create new projects. Thank you!