- Why JAX + TPU v5e for SDXL?
- How to write an image generation pipeline in JAX
- Benchmark
- How does the demo work?
Generative AI models, such as Stable Diffusion XL (SDXL), enable the creation of high-quality, realistic content with wide-ranging applications. However, harnessing the power of such models presents significant challenges and computational costs. SDXL is a large image generation model whose UNet component is about three times as large as the one in the previous version of the model. Deploying a model like this in production is challenging due to the increased memory requirements, as well as increased inference times. Hugging Face Diffusers now offers support for serving SDXL with JAX on Cloud TPUs, providing high-performance, cost-efficient inference.
Google Cloud TPUs are custom-designed AI accelerators, which are optimized for training and inference of large AI models, including state-of-the-art LLMs and generative AI models such as SDXL. The new Cloud TPU v5e is purpose-built to bring the cost-efficiency and performance required for large-scale AI training and inference. At less than half the cost of TPU v4, TPU v5e makes it possible for more organizations to train and deploy AI models.
🧨 Diffusers JAX integration offers a convenient way to run SDXL on TPU via XLA, and a demo was built to showcase it. You can try it out in this Space or in the playground embedded below:
Under the hood, this demo runs on several TPU v5e-4 instances (each instance has 4 TPU chips) and takes advantage of parallelization to serve four large 1024×1024 images in about 4 seconds. This time includes format conversions, communications time, and frontend processing; the actual generation time is about 2.3s, as will be shown below!
This blog post covers:
- Why JAX + TPU + Diffusers is a powerful framework for running SDXL
- How to write a simple image generation pipeline with Diffusers and JAX
- Benchmarks comparing different TPU settings
Why JAX + TPU v5e for SDXL?
Serving SDXL with JAX on Cloud TPU v5e with high performance and cost-efficiency is possible thanks to the combination of purpose-built TPU hardware and a software stack optimized for performance. Two key factors are highlighted below: JAX just-in-time (jit) compilation and XLA compiler-driven parallelism with JAX pmap.
JIT compilation
A notable feature of JAX is its just-in-time (jit) compilation. The JIT compiler traces the code during the first run and generates highly optimized TPU binaries that are re-used in subsequent calls. The catch is that this process requires all input, intermediate, and output shapes to be static, meaning they must be known in advance. Every time the shapes change, a new, costly compilation is triggered. JIT compilation is ideal for services that can be designed around static shapes: compilation runs once, and then super-fast inference times can be achieved.
Image generation is well-suited for JIT compilation. If the same number of images are always generated and they have the same size, then the output shapes are constant and known in advance. The text inputs are also constant: by design, Stable Diffusion and SDXL use fixed-shape embedding vectors (with padding) to represent the prompts typed by the user. Therefore, JAX code can be written that relies on fixed shapes, and that can be greatly optimized!
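As a standalone illustration of this shape dependence (a minimal sketch, not part of the SDXL code), the snippet below shows how a jitted function reuses its compiled binary while shapes stay fixed and recompiles when they change:

import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x):
    # Traced and compiled on the first call for each new input shape/dtype.
    return (x * 2.0).sum()

a = jnp.ones((4, 1024))
scaled_sum(a)                     # first call: traces and compiles
scaled_sum(a + 1.0)               # same shape and dtype: reuses the cached binary
scaled_sum(jnp.ones((8, 1024)))   # new shape: triggers another compilation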
High-performance throughput for high batch sizes
Workloads can be scaled across multiple devices using JAX’s pmap, which expresses single-program multiple-data (SPMD) programs. Applying pmap to a function compiles it with XLA and then executes it in parallel on multiple XLA devices. For text-to-image generation workloads this means that increasing the number of images rendered simultaneously is straightforward to implement and doesn’t compromise performance. For example, running SDXL on a TPU with 8 chips will generate 8 images in the same time it takes for 1 chip to create a single image.
TPU v5e instances come in multiple shapes, including 1, 4 and 8-chip shapes, all the way up to 256 chips (a full TPU v5e pod), with ultra-fast ICI links between chips. This makes it possible to select the TPU shape that best suits a given use case and to easily take advantage of the parallelism that JAX and TPUs provide.
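To see the SPMD pattern in isolation, here is a minimal sketch (independent of the SDXL pipeline, using made-up data) that applies jax.pmap to run the same computation once per available device, in parallel:

import jax
import jax.numpy as jnp

num_devices = jax.device_count()

# One batch element per device; the leading axis must match the device count.
xs = jnp.stack([jnp.full((3,), i, dtype=jnp.float32) for i in range(num_devices)])

# pmap compiles the function with XLA and runs one copy per device in parallel.
parallel_square = jax.pmap(lambda x: x ** 2)
print(parallel_square(xs))  # result shape: (num_devices, 3)

On a TPU v5e-4 the leading axis has size 4 and each chip processes its own slice; the same pattern is what lets the SDXL pipeline produce one image per chip.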
How to write an image generation pipeline in JAX
This section walks through the code required to run inference super fast using JAX! First, the dependencies are imported:
# Dependencies for running SDXL inference with JAX
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
import time
The base SDXL model and other inference components are now loaded. The diffusers pipeline handles downloading and caching all necessary elements. Adhering to JAX’s functional approach, the model’s parameters are returned separately and will have to be passed to the pipeline during inference:
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", split_head_dim=True
)
Model parameters are downloaded in 32-bit precision by default. To save memory and run computation faster, the parameters are converted to bfloat16, an efficient 16-bit representation. However, there’s a caveat: for best results, the scheduler state must be kept in float32, otherwise precision errors accumulate and result in low-quality or even black images.
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state
The default prompt and other pipeline inputs are set up next:
default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = "illustration, low-quality"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25
The prompts have to be supplied as tensors to the pipeline, and they must always have the same dimensions across invocations so that the inference call can be compiled. The pipeline's prepare_inputs method performs all the necessary steps, so a helper function is created to prepare both the prompt and the negative prompt as tensors. It will be used later from the generate function:
def tokenize_prompt(prompt, neg_prompt):
prompt_ids = pipeline.prepare_inputs(prompt)
neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
return prompt_ids, neg_prompt_ids
To take advantage of parallelization, inputs will be replicated across devices. A Cloud TPU v5e-4 has 4 chips, so by replicating the inputs each chip generates a different image, in parallel. Care must be taken to supply a different random seed to each chip so the 4 images are different:
NUM_DEVICES = jax.device_count()
# Model parameters don't change during inference,
# so we only need to replicate them once.
p_params = replicate(params)
def replicate_all(prompt_ids, neg_prompt_ids, seed):
p_prompt_ids = replicate(prompt_ids)
p_neg_prompt_ids = replicate(neg_prompt_ids)
rng = jax.random.PRNGKey(seed)
rng = jax.random.split(rng, NUM_DEVICES)
return p_prompt_ids, p_neg_prompt_ids, rng
All components are now ready to be assembled into a generate function:
def generate(
prompt,
negative_prompt,
seed=default_seed,
guidance_scale=default_guidance_scale,
num_inference_steps=default_num_steps,
):
prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
images = pipeline(
prompt_ids,
p_params,
rng,
num_inference_steps=num_inference_steps,
neg_prompt_ids=neg_prompt_ids,
guidance_scale=guidance_scale,
jit=True,
).images
# convert the images to PIL
images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
return pipeline.numpy_to_pil(np.array(images))
jit=True indicates that the pipeline call is intended for compilation. This will happen the first time generate is called, and it will be very slow: JAX needs to trace the operations, optimize them, and convert them to low-level primitives. A first generation run will be performed to complete this process and warm things up:
start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")
This took about three minutes on the first execution. But once the code has been compiled, inference will be super fast, as a second run shows:
start = time.time()
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")
It now took about 2s to generate the 4 images!
Benchmark
The following measurements were obtained by running SDXL 1.0 base for 20 steps, with the default Euler Discrete scheduler. Cloud TPU v5e is compared with TPU v4 for the same batch sizes. Do note that, due to parallelism, a TPU v5e-4 like those used in the demo will generate 4 images when using a batch size of 1 (or 8 images with a batch size of 2). Similarly, a TPU v5e-8 will generate 8 images when using a batch size of 1.
The Cloud TPU tests were run using Python 3.10 and jax version 0.4.16. These are the same specs used in the demo Space.
Benchmark results:

| Hardware | Batch Size | Latency (s) | Perf/$ |
|---|---|---|---|
| TPU v5e-4 (JAX) | 4 | 2.33 | 21.46 |
| TPU v5e-4 (JAX) | 8 | 4.99 | 20.04 |
| TPU v4-8 (JAX) | 4 | 2.16 | 9.05 |
| TPU v4-8 (JAX) | 8 | 4.17 | 8.98 |
TPU v5e achieves up to 2.4x greater perf/$ on SDXL compared to TPU v4, demonstrating the cost-efficiency of the latest TPU generation.
To measure inference performance, the industry-standard metric of throughput is used. First, latency per image is measured once the model has been compiled and loaded. Then, throughput is calculated by dividing the batch size by the latency, per chip. As a result, throughput measures how the model performs in production environments regardless of how many chips are used. Finally, throughput is divided by the list price to get performance per dollar.
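Written out, the metric defined above is:

$$\text{throughput} = \frac{\text{batch size}}{\text{latency per chip}}, \qquad \text{perf}/\$ = \frac{\text{throughput}}{\text{list price}}$$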
How does the demo work?
The demo presented above was built using a script that essentially follows the code shown in this blog post. It runs on a few Cloud TPU v5e devices with 4 chips each, and a simple load-balancing server routes user requests to the backend servers at random. When a prompt is entered in the demo, the request is assigned to one of the backend servers, which returns the generated images.
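As a purely illustrative sketch of that routing logic (the backend URLs, endpoint path, and payload format below are assumptions, not the demo's actual code), a random load balancer could look like this:

import random
import requests  # assumed dependency for this illustrative sketch

# Hypothetical pool of backend servers, one per TPU v5e instance.
BACKENDS = [
    "http://sdxl-backend-0:8000",
    "http://sdxl-backend-1:8000",
    "http://sdxl-backend-2:8000",
]

def route_request(prompt: str, negative_prompt: str = "") -> dict:
    # Pick a backend at random, mirroring the simple load balancing described above.
    backend = random.choice(BACKENDS)
    response = requests.post(
        f"{backend}/generate",  # illustrative endpoint name, not the demo's actual API
        json={"prompt": prompt, "negative_prompt": negative_prompt},
        timeout=60,
    )
    response.raise_for_status()
    return response.json()  # e.g. the generated images, as encoded by the backend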
The demo is a simple solution based on several pre-allocated TPU instances. A future post will cover how to create dynamic solutions that adapt to load using GKE.
All the code for the demo is open-source and available in Hugging Face Diffusers today. The potential for creations with Diffusers + JAX + Cloud TPUs is exciting!

