3D U-Net in Video Diffusion Models (VDMs)
Latent Diffusion Models = Autoencoder + Forward process + Denoising process
The Denoising part is the core because it’s the generative process that reconstructs the image in latent space.
Here’s a simple way to think about it: In the forward process, we start with a clear picture and gradually add noise to it. In the reverse process, we start with a noisy picture and try to clean it up.
U-Net is one of the popular neural networks for the denoising process. It is trained to iteratively estimate the noise at each step, refining the noisy latent toward the clean target.
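To make this concrete, here is a minimal, deterministic DDIM-style reverse loop. It is only a sketch of the general idea, not VideoCrafter's actual sampler; `unet` and `alphas_cumprod` are hypothetical placeholders.

import torch

# Minimal deterministic DDIM-style reverse loop (sketch only, not VideoCrafter's sampler).
# `unet` predicts the noise in x at step t; `alphas_cumprod` is a 1-D tensor noise schedule.
def denoise(unet, x_T, alphas_cumprod, num_steps):
    x = x_T  # start from pure noise in latent space
    for t in reversed(range(num_steps)):
        eps = unet(x, t)                                     # U-Net's noise estimate at step t
        a_t = alphas_cumprod[t]
        a_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
        x0 = (x - (1 - a_t).sqrt() * eps) / a_t.sqrt()       # estimate of the clean latent
        x = a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps   # step to the previous noise level
    return x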
What’s the hack in it?
2D U-Net
It has a U-shaped structure that resembles the encoder/decoder design. On the left, the contracting path compresses the features: the feature map is downsampled while the number of channels grows.
On the right, the expansive path recovers the feature map's resolution while reducing the number of channels.
And here's the special mention: at the start of each expansive stage, we grab the corresponding feature map from the contracting path and concatenate it with the upsampled feature map (along the channel axis). This is called the skip/residual connection, one of the most essential hacks in deep neural networks. You can check this reddit post for an interesting interpretation.
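For intuition, here is a minimal PyTorch sketch of one expansive stage with a skip connection; the layer names and channel sizes are made up for illustration and are not VideoCrafter's actual code.

import torch
import torch.nn as nn

# Minimal sketch of one expansive (decoder) stage with a skip connection.
# Names and channel sizes are illustrative, not taken from VideoCrafter.
class UpStage(nn.Module):
    def __init__(self, in_ch=128, skip_ch=64, out_ch=64):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = nn.Conv2d(out_ch + skip_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x, skip):
        x = self.up(x)                        # upsample the feature map
        x = torch.cat([x, skip], dim=1)       # skip connection: concat along channels
        return torch.relu(self.conv(x))

x = torch.randn(1, 128, 16, 16)      # feature map coming up from the bottleneck
skip = torch.randn(1, 64, 32, 32)    # feature map saved from the contracting path
out = UpStage()(x, skip)             # -> (1, 64, 32, 32)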
From 2D to 3D U-Net
When it comes to video, you’re dealing with a stack of images. So, while a 2D U-Net is processing images with input dimensions like:
# 2D U-net
input_x = (batch, channel, height, width)
For the 3D U-Net, we just add a whole new dimension:
# 3D U-net
input_x = (batch, frames, channel, height, width)
The key requirement for a 3D U-Net is the ability to upsample and downsample with both spatial and temporal awareness. In the VideoCrafter model, this is achieved by (1) ResNet blocks fused with timestep and FPS information and (2) a spatio-temporal attention mechanism.
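Before getting into those, here is a minimal sketch of how a purely spatial operation can still be applied to the 5D video tensor by folding the frame dimension into the batch, following the (batch, frames, channel, height, width) layout above; the downsampling layer is illustrative, not VideoCrafter's.

import torch
import torch.nn as nn
from einops import rearrange

# Illustrative only: apply a spatial-only downsampling conv to a video latent
# by folding the frame dimension into the batch, then unfolding it back.
x = torch.randn(2, 16, 4, 32, 32)                 # (batch, frames, channel, h, w)
spatial_down = nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1)

b, t = x.shape[0], x.shape[1]
x2d = rearrange(x, 'b t c h w -> (b t) c h w')    # treat every frame as an image
x2d = spatial_down(x2d)                           # (b*t, 8, 16, 16)
x = rearrange(x2d, '(b t) c h w -> b t c h w', b=b, t=t)
print(x.shape)                                    # torch.Size([2, 16, 8, 16, 16])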
3D U-Net in VideoCrafter
VideoCrafter2 is a text-to-video (T2V) and image-to-video (I2V) generative model; this architecture is also used in StyleCrafter, DynamiCrafter, etc. The authors (Chen et al.) show a simplified figure of its architecture:
As always, there is more than meets the eye in any deep learning paper! I looked into VideoCrafter's codebase and drew more detailed figures, as follows:
Google Drive: 3D U-Net in VideoCrafter
Google Drive: Detailed 3D U-Net in VideoCrafter; the network flows from bottom to top.
The size changes are shown in the following table, where \(f\) denotes the number of frames, \(c\) the channels, and \(h\) and \(w\) the feature map's height and width.
- The denoising 3D U-Net consists of 12 encoder layers and 12 decoder layers.
- In the contracting path, the feature map is spatially compressed down to 1/4.
- In the contracting path, the number of channels is expanded 4 times.
- The temporal dimension (16 frames) is neither compressed nor expanded, but it is used throughout the network.
Building Blocks in 3D U-Net
The core building block of the 3D U-Net in video diffusion models combines residual networks (ResNet) with spatio-temporal attention mechanisms. Additional conditioning information includes: text embeddings (prompts) for T2V, image embeddings for I2V, and temporal data such as the denoising timestep and the FPS used as motion-speed control.
Temporal data embedding
The temporal data consists of the diffusion step (timestep) and the motion-speed control (FPS). They are embedded and fused as follows (code):
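Assuming the linked code follows the common pattern of sinusoidal embeddings passed through small MLPs and then summed, the fusion can be sketched roughly like this; module names and sizes are illustrative placeholders, not VideoCrafter's exact code.

import math
import torch
import torch.nn as nn

def sinusoidal_embedding(t, dim):
    # Standard sinusoidal embedding commonly used for diffusion timesteps.
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class TemporalEmbedding(nn.Module):
    # Illustrative fusion of timestep and FPS conditioning into one embedding.
    def __init__(self, dim=320, emb_dim=1280):
        super().__init__()
        self.dim = dim
        self.time_mlp = nn.Sequential(nn.Linear(dim, emb_dim), nn.SiLU(), nn.Linear(emb_dim, emb_dim))
        self.fps_mlp = nn.Sequential(nn.Linear(dim, emb_dim), nn.SiLU(), nn.Linear(emb_dim, emb_dim))

    def forward(self, timesteps, fps):
        t_emb = self.time_mlp(sinusoidal_embedding(timesteps, self.dim))
        fps_emb = self.fps_mlp(sinusoidal_embedding(fps, self.dim))
        return t_emb + fps_emb   # fused temporal conditioning

emb = TemporalEmbedding()(torch.tensor([500]), torch.tensor([16]))
print(emb.shape)   # torch.Size([1, 1280])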
ResNet
ResNet modules are in charge of both upsampling and downsampling: they contain the convolutional layers responsible for resizing the feature map (code). The temporal embedding from the previous step is integrated into the ResNet module by adding it to the latent along the channel dimension. Additionally, ResNet modules contain temporal convolutions that further refine the temporal representation within the latent space (code).
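Here is a minimal sketch of such a block, assuming the temporal embedding is projected to the channel dimension and added with broadcasting, followed by a temporal convolution over the frame axis; layer names and sizes are illustrative, not VideoCrafter's exact code.

import torch
import torch.nn as nn
from einops import rearrange

class ResBlock3D(nn.Module):
    # Illustrative residual block: spatial conv + injected temporal embedding + temporal conv.
    def __init__(self, ch=64, emb_dim=1280):
        super().__init__()
        self.spatial_conv = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.emb_proj = nn.Linear(emb_dim, ch)
        self.temporal_conv = nn.Conv1d(ch, ch, kernel_size=3, padding=1)

    def forward(self, x, emb):                      # x: (b, t, c, h, w), emb: (b, emb_dim)
        b, t, c, h, w = x.shape
        res = x
        # spatial convolution, with frames folded into the batch
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        x = torch.relu(self.spatial_conv(x))
        x = rearrange(x, '(b t) c h w -> b t c h w', b=b, t=t)
        # add the temporal (timestep + FPS) embedding along the channel dimension
        x = x + self.emb_proj(emb)[:, None, :, None, None]
        # temporal convolution over the frame axis, one 1D conv per spatial location
        x = rearrange(x, 'b t c h w -> (b h w) c t')
        x = self.temporal_conv(x)
        x = rearrange(x, '(b h w) c t -> b t c h w', b=b, h=h, w=w)
        return x + res                              # residual connection

out = ResBlock3D()(torch.randn(1, 16, 64, 32, 32), torch.randn(1, 1280))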
Text and Image Embedding
In the denoising process, text and image data serve as conditional information. Both are encoded using the CLIP model. For text-to-video (T2V) diffusion, the CLIP-encoded text becomes the context embedding, which is later fed into cross-attention to generate keys and values. In VideoCrafter’s image-to-video (I2V) model, a separate projection network is trained to align the image embedding with the CLIP-encoded text embedding (code). These are then concatenated together to form the context embedding (code).
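Under that description, building the context embedding might look roughly like this; the projection network, token counts, and dimensions are hypothetical placeholders for illustration, not VideoCrafter's exact sizes.

import torch
import torch.nn as nn

# Illustrative: form the conditioning context for cross-attention.
# clip_text: CLIP-encoded text tokens; clip_img: CLIP-encoded image tokens.
# Token counts and dimensions below are hypothetical placeholders.
clip_text = torch.randn(1, 77, 1024)
clip_img = torch.randn(1, 257, 1280)

# Stand-in for the trained projection network that aligns image and text embeddings.
img_proj = nn.Sequential(nn.Linear(1280, 1024), nn.GELU(), nn.Linear(1024, 1024))

# T2V: the context is just the CLIP text embedding.
context_t2v = clip_text

# I2V: project the image embedding into the text embedding space, then concatenate.
context_i2v = torch.cat([clip_text, img_proj(clip_img)], dim=1)   # (1, 77 + 257, 1024)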
Spatio-Temporal Attention Mechanisms
The Spatio-Temporal Attention mechanism (ST-attention) is a sequential block of a spatial transformer followed by a temporal transformer. Each transformer — both spatial and temporal — includes two attention modules: self-attention and cross-attention.
The modules within ST-Attention operate as follows:
- Spatial self-attention: Each “pixel” in the feature map attends to every other pixel, creating a pixel-to-pixel attention map for a single frame in latent space.
- Spatial cross-attention: The output from spatial self-attention is projected as queries, while the context embedding is projected into keys and values.
- Temporal self-attention: Each frame attends to all other frames, building a frame-to-frame attention map across a batch of latent frames.
- Temporal cross-attention: The output from temporal self-attention is projected as queries, and the context embedding is again used for keys and values (a minimal sketch follows this list).
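As a rough illustration of how the same attention primitive becomes "spatial" or "temporal" purely through reshaping, here is a sketch using the (batch, frames, channel, height, width) layout from above; the modules and sizes are illustrative, not VideoCrafter's exact code.

import torch
import torch.nn as nn
from einops import rearrange

# Illustrative: the same attention layer acts spatially or temporally depending on
# how the video tensor is reshaped before attention.
x = torch.randn(1, 16, 64, 8, 8)          # (batch, frames, channel, h, w)
b, t, c, h, w = x.shape
attn = nn.MultiheadAttention(embed_dim=c, num_heads=4, batch_first=True)

# Spatial self-attention: pixels attend to pixels within a single frame.
xs = rearrange(x, 'b t c h w -> (b t) (h w) c')
xs, _ = attn(xs, xs, xs)

# Temporal self-attention: each spatial location attends across frames.
xt = rearrange(x, 'b t c h w -> (b h w) t c')
xt, _ = attn(xt, xt, xt)

# Cross-attention: queries come from the latent, keys/values from the context embedding.
context = torch.randn(1, 77, c)           # hypothetical context embedding
xq = rearrange(x, 'b t c h w -> b (t h w) c')
xc, _ = attn(xq, context, context)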
ST-Attention for Text-to-Video (T2V)
ST-Attention for Image-to-Video (I2V)
The I2V model in VideoCrafter adds one more cross-attention block.
For I2V, the modules within ST-Attention operate as follows:
- Spatial self-attention
- Spatial cross-attention: After spatial self-attention, the output is projected as queries, while the context embedding is split into text and image embeddings. These are then separately projected into keys and values, resulting in two attention outputs (\(\text{Out}_{\text{text}}\) and \(\text{Out}_{\text{img}}\)). The final output is obtained by adding these two outputs together, referred to as “dual cross-attention” in the paper.
- Temporal self-attention
- Temporal cross-attention: Similar dual cross-attention is applied here as well (a sketch of dual cross-attention follows this list).
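Here is a minimal sketch of dual cross-attention, assuming the context is split into its text and image parts, each attended to separately, and the two outputs summed; names, token counts, and sizes are illustrative placeholders.

import torch
import torch.nn as nn

# Illustrative dual cross-attention: one set of queries, two key/value sources
# (text and image), with the two attention outputs summed.
dim, n_heads = 64, 4
attn_text = nn.MultiheadAttention(dim, n_heads, batch_first=True)
attn_img = nn.MultiheadAttention(dim, n_heads, batch_first=True)

q = torch.randn(1, 1024, dim)       # queries projected from the latent (e.g., flattened pixels)
text_ctx = torch.randn(1, 77, dim)  # text part of the context embedding
img_ctx = torch.randn(1, 257, dim)  # image part of the context embedding

out_text, _ = attn_text(q, text_ctx, text_ctx)
out_img, _ = attn_img(q, img_ctx, img_ctx)
out = out_text + out_img            # dual cross-attention output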