✨TL;DR Summary
- Serverless Diffusion Bottleneck: The baseline cold start for loading the 18GB FLUX.2-Kontext-Dev (9B) model was ~50 seconds, making real-time serverless deployment impossible.
- Memory Snapshot Acceleration: Using Modal's GPU memory snapshots to capture the fully initialized VRAM and system memory state cut cold starts 4.2× (48.3s to 11.4s).
- Async VRAM Transfers: Replacing Hugging Face's synchronous, blocking device transfers with non-blocking transfers on a custom PyTorch CUDA stream further reduced cold starts to 7.2s.
- Batching Strategy: Abandoned automatic batching (which introduced a 500ms latency tax and VRAM OOMs) in favor of single-request concurrency for stable, progressive generation.
Deploying massive diffusion models like FLUX.2-Kontext-Dev (9B parameters) on serverless GPUs for real-time applications requires a different set of optimization strategies than serving LLMs. Instead of token-by-token streaming, we deal with high-bandwidth image tensors, large weight transfers, and multi-second blocks of diffusion steps.
Performance Benchmarks
| Stage | Architecture Detail | Latency (Cold) | Latency (Warm) | Cold-Start Speedup |
|---|---|---|---|---|
| Baseline | Unoptimized Hugging Face pipeline | 48.3s | 11.6s | 1.0× (baseline) |
| Memory Snapshot | CRIU GPU state restoration | 11.4s | 6.2s | 4.2× |
| Full Pipeline | GPU Snapshot + Async VRAM transfers | 7.2s | 4.1s | 6.7× |
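Each cold-start speedup in the table is the baseline cold start divided by that stage's cold start. Here is a minimal check in Python that uses only the latencies listed in the table. The stage names are reused as dictionary keys for readability.

```python
# Cold-start latencies (seconds) taken from the benchmark table
BASELINE_COLD_S = 48.3
STAGE_COLD_S = {
    "Memory Snapshot": 11.4,
    "Full Pipeline": 7.2,
}


def cold_start_speedup(baseline_s: float, stage_s: float) -> float:
    """Ratio of the baseline cold start to an optimized stage's cold start."""
    return baseline_s / stage_s


for name, cold_s in STAGE_COLD_S.items():
    print(f"{name}: {cold_start_speedup(BASELINE_COLD_S, cold_s):.1f}x")
```

The same arithmetic applies to the warm column: 11.6s / 4.1s works out to about a 2.8× improvement for the full pipeline.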
Overcoming the 48-Second Cold Start
On startup, loading the ~18 GB of weights for the FLUX.2 pipeline from disk into system memory and then transferring them to the GPU (L40S) takes nearly 50 seconds. Since users expect images in under 5 seconds, scale-to-zero serverless was unusable.
GPU Memory Snapshots (CRIU)
By using Modal's GPU memory snapshots, we capture the contents of GPU VRAM and system memory after the model is loaded and fully initialized. A newly spawned container skips the entire loading process and restores directly from the memory snapshot in 3.1 seconds.
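```python
# Pre-loading hook in Modal class
import modal

app = modal.App("flux-inference")  # app name is illustrative
image = modal.Image.debian_slim().pip_install(
    "torch", "diffusers", "transformers", "accelerate", "sentencepiece"
)
@app.cls(
    gpu="L40S",
    image=image,
    enable_memory_snapshot=True,
    experimental_options={"enable_gpu_snapshot": True},
)
class FluxModel:
    @modal.enter(snap=True)
    def enter(self):
        # snap=True: runs ONCE before the snapshot is taken; new containers
        # restore the resulting CPU and GPU state instead of re-running it
        import torch
        from diffusers import FluxPipeline
        # Load the pipeline and move it into VRAM so the GPU state is captured
        self.pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        # Warm up CUDA kernels and memory paths before the snapshot
        _ = self.pipe("warmup prompt", num_inference_steps=4)
```
Async GPU Transfers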
To cut another 4.2 seconds from cold starts (11.4s to 7.2s), we rewrote the default Hugging Face device-transfer code. Instead of synchronous, blocking transfers, we issue non-blocking copies on a dedicated PyTorch CUDA stream:
```python
# Custom async CUDA transfer stream (excerpt from a model-class method)
import torch
cuda_stream = torch.cuda.Stream()
with torch.cuda.stream(cuda_stream):
    # Non-blocking copies on the side stream let the CPU tokenize prompts while
    # weights move to VRAM; overlap requires source weights in pinned host memory
    self.transformer.to(device="cuda", non_blocking=True)
    self.text_encoder.to(device="cuda", non_blocking=True)
# Make the default stream wait for the copies before any kernel reads the weights
torch.cuda.current_stream().wait_stream(cuda_stream)
```
Evaluating Batching Strategies
To handle production load from REimagineHome.AI, we evaluated three strategies for serving concurrent requests:
1. Automatic Batching (@modal.batched) — REVERTED
Modal intercepts requests and waits up to 500ms to group them into a batch. This introduced a mandatory **500ms latency tax** on all requests. With 20 concurrent requests, it spun up 8 containers (5 of which sat idle due to batch over-provisioning), and batch sizes larger than 12 crashed with VRAM OOM errors.
2. Single Request Concurrency (@modal.concurrent) — ADOPTED
No wait-time tax. Each container handles a single image generation request. Autoscaling is predictable and stable, and clients receive a progressive update for each image immediately.
3. Explicit Batch API (generate_batch) — HIGH THROUGHPUT
For background processing pipelines where clients can group requests, we exposed a batch generation endpoint. Batched generation uses the GPU more efficiently, cutting processing time to **3.22 seconds per image** at batch=10, compared to 4.40 seconds for single images.
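A simplified model of a wait-window batcher shows where the latency tax of automatic batching comes from. This is a hypothetical sketch, not Modal's actual scheduler. It assumes a batch opens at its first request and flushes either when it reaches `max_batch_size` or when `wait_ms` has elapsed, and it ignores container autoscaling entirely. `batch_wait_times` is an illustrative helper name.

```python
from typing import List


def batch_wait_times(arrivals_ms: List[float], max_batch_size: int,
                     wait_ms: float) -> List[float]:
    """Queueing delay each request spends waiting for its batch to start.

    Simplified model (an assumption, not Modal's scheduler): a batch opens at
    its first request and flushes as soon as it is full, or wait_ms after it
    opened, whichever comes first. arrivals_ms must be sorted ascending.
    """
    if max_batch_size < 1 or wait_ms < 0:
        raise ValueError("max_batch_size must be >= 1 and wait_ms >= 0")
    delays: List[float] = []
    i, n = 0, len(arrivals_ms)
    while i < n:
        deadline = arrivals_ms[i] + wait_ms
        j = i
        # Gather requests that arrive inside the window, up to the size cap
        while j < n and j - i < max_batch_size and arrivals_ms[j] <= deadline:
            j += 1
        # A full batch flushes at its last arrival; a partial one waits out the window
        flush_ms = arrivals_ms[j - 1] if j - i == max_batch_size else deadline
        delays.extend(flush_ms - arrivals_ms[k] for k in range(i, j))
        i = j
    return delays


# A lone request pays the entire 500 ms window before generation starts
print(batch_wait_times([0], max_batch_size=8, wait_ms=500))  # → [500]
```

Under this model, a request avoids the window only when enough traffic arrives to fill its batch quickly. At low traffic, nearly every request pays close to the full `wait_ms`.

The explicit batch numbers reduce to plain arithmetic:
- At 3.22 s/image, a batch of 10 takes about 32.2s.
- Ten single-image runs at 4.40s take about 44.0s.
- That is roughly 1.37× the throughput.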
Quantization & Compilation Anti-Patterns
- `torch.compile` on distilled models: FLUX.2-Schnell/Klein is step-distilled (4 inference steps). Running `torch.compile()` costs up to 45 seconds of JIT compilation during inference. With only 4 steps per request, there is no window in which to amortize that cost, so latency becomes worse and less predictable.
- Multiresolution warmup: Warming up the pipeline at multiple resolutions (e.g., 512px, 768px, 1024px) before snapshotting added 50 seconds to snapshot build time but did not speed up inference, because CUDA kernels compile dynamically at runtime for variable-resolution inputs anyway.
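The amortization argument is a break-even calculation: a one-time compile cost only pays off after enough requests have each saved some time. Here is a minimal sketch. The 45-second compile cost comes from the text. The per-request saving is a hypothetical placeholder, since no measured saving is reported here.

```python
import math

COMPILE_S = 45.0  # worst-case torch.compile JIT time quoted in the text


def break_even_requests(compile_s: float, saved_per_request_s: float) -> float:
    """Number of requests needed to repay a one-time compile cost.

    Returns math.inf when compilation saves nothing per request.
    """
    if saved_per_request_s <= 0:
        return math.inf
    return math.ceil(compile_s / saved_per_request_s)


# HYPOTHETICAL saving of 0.5 s per 4-step request (illustrative, not measured)
print(break_even_requests(COMPILE_S, 0.5))  # → 90
```

If a container serves fewer requests than the break-even count before it scales down, compilation is a net loss.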
Key Learnings
- Use shared storage for model weights: Mounting model cache folders on a shared `modal.Volume` ensures that when snapshots are updated, containers do not re-download multi-gigabyte models from the Hugging Face Hub.
- Avoid text encoders on CPU: Moving the T5-XXL text encoder to the CPU to save VRAM causes a severe execution bottleneck. The L40S GPU has more than enough memory (48GB) to house both the text encoder and the transformer, so keep both on-device.
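Whether both components fit on one GPU can be estimated from bytes per parameter. This rough sketch rests on three assumptions:
- Weights are stored in bf16 (2 bytes per parameter).
- The 9B parameters quoted above belong to the transformer alone.
- The T5-XXL encoder has roughly 4.7B parameters. This is an approximate figure for T5-XXL's encoder stack and is not taken from this article.

Activations and the CUDA context are ignored, so real headroom is smaller.

```python
BYTES_PER_PARAM_BF16 = 2
L40S_VRAM_GB = 48  # from the text


def weights_gb(num_params: float, bytes_per_param: int = BYTES_PER_PARAM_BF16) -> float:
    """Approximate weight footprint in decimal gigabytes."""
    return num_params * bytes_per_param / 1e9


transformer_gb = weights_gb(9e9)  # ASSUMPTION: 9B params are the transformer alone
t5_gb = weights_gb(4.7e9)         # ASSUMPTION: approximate T5-XXL encoder size
headroom_gb = L40S_VRAM_GB - transformer_gb - t5_gb
print(f"weights: {transformer_gb + t5_gb:.1f} GB, headroom: {headroom_gb:.1f} GB")
```

On these assumptions, both models together occupy roughly 27 GB of the 48 GB, leaving about 20 GB for activations.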
Source: §1 (FLUX.2-klein-9B).