
Flow Matching

Flow matching explained and implemented


Hello all, today I will cover flow matching. Flow matching is a (somewhat) novel technique for image generation that seems to outperform existing techniques. It is the backbone of many modern models such as Stable Diffusion 3. So we will start off with the beloved Literature Review and look through the existing techniques. Full code is available here

Image of some clothes generated by the model

Literature Review

So in the existing pool of techniques, we have these:

| | VAE | GAN | VAE-GAN | Diffusion | Flow matching |
|---|---|---|---|---|---|
| Training | Decent speed, albeit a bit wasteful - oftentimes the encoder is thrown away | Painful. Model collapse happens reasonably often. | Model collapse quite unlikely, but again, you usually throw away the encoder and discriminator. | Pretty decent speed, but there are many ways to screw up due to the high amounts of math | Very good - trains very fast, and simple to implement |
| Inference speed | Good - single pass | Good | Good | Poor - often needs many passes to generate decent quality | Mid - faster than diffusion, slower than single-pass models |
| Quality | Trash - images are often blurry | Good | Good, but blurrier than a pure GAN | Very good | Very good |

Overall, flow matching is a pretty good balance between training speed, quality, and inference time. However, because of the heavy use of attention in the U-net, diffusion and flow matching models may require unwieldy amounts of VRAM and training time if we try to make the images any bigger than 128x128 or so. This is why latent diffusion (which can be trivially extended to flow matching) was invented, but we will not cover that today, as we will just be working with the 28x28 Fashion MNIST dataset.

Math

I shall now cover the math of the flow matching technique. I will casually ignore all the fancy probability and integrals used in the official paper (because admittedly I can’t understand a good chunk of it either), and use the simplest interpretation instead - the linear interpolation interpretation.

Let us start off by defining some terms. Let $x$ be the image, and $z$ be random noise. We start at complete random noise at $t = 0$, and reach the clean image at $t = 1$. Then all we need to do is draw a line (i.e. linearly interpolate) between $z$ and $x$, from $t = 0$ to $t = 1$. We will represent the model as $f$.

\[\displaylines{ x_{t} = t\,x + (1-t)\,z\\ \frac{dx_{t}}{dt} = x - z\\ f(x_{t}, t) = \frac{dx_{t}}{dt} = x - z }\]

As such, all we have to do is train the model to find $x - z$, given $x_{t}$ and $t$.
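Written as a loss, this is just a mean-squared error between the model output and the constant velocity of the line (my compact restatement of the objective, not an equation from the original post):

\[\mathcal{L} = \mathbb{E}_{x,\, z,\, t}\left[\left\| f(x_{t}, t) - (x - z) \right\|^{2}\right]\]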

Architecture

We will design the model in a classical U-net format, but after each convolution block we introduce a self-attention block. You can probably skip this whole section if you have implemented a diffusion model before.

Note that this picture is just to give a rough idea of what a U-net looks like; we will implement something rather different. Image of a generic unet model

Time embedding

First we need to settle the time embedding, which we inject into every layer. Let $t$ be the timestep we want to embed, $d$ be half the embedding dimension, and $x$ be the embedding.

\[x_{i} = \sin\left(\frac{t}{10000^{\frac{i}{d}}}\right)\]

We take this equation but with cos instead of sin, and concatenate the two to form the entire embedding. Here it is in code:

def get_time_embedding(time_steps, temb_dim):
    factor = 10000 ** ((torch.arange(
        start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
    )
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb
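
As a quick sanity check (my own snippet, not part of the original code), calling this on a batch of timesteps should give one embedding row per sample:

t = torch.rand(8)                 # 8 random timesteps in [0, 1)
emb = get_time_embedding(t, 128)  # half sin, half cos -> shape (8, 128)
print(emb.shape)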

Downsample

I will skip a lot of the code for brevity; you can find the full thing here

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, down_sample = True, attend = True):
        super().__init__()
        ...
        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=1),
        )
        self.t_emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, out_channels)
        )
        self.resnet_conv_second = ...
        if attend:
            self.attention_norms = nn.GroupNorm(8, out_channels)
            
            self.attentions = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)

        self.residual_input_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
                                          4, 2, 1) if self.down_sample else nn.Identity()
    

And the forward code,

def forward(self, x, t_emb):
    out = x
    resnet_input = out
    out = self.resnet_conv_first(out)
    out = out + self.t_emb_layers(t_emb)[:, :, None, None]
    out = self.resnet_conv_second(out)
    out = out + self.residual_input_conv(resnet_input)
    if self.attend:
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norms(in_attn)
        in_attn = in_attn.transpose(1, 2)
        out_attn, _ = self.attentions(in_attn, in_attn, in_attn)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        out = out + out_attn
        
    out = self.down_sample_conv(out)
    return out

Note how we inject the time embedding in between the conv layers; we do this for every later layer as well. Another key thing to note is to only downsample AFTER everything else, and to do the reverse for the upsampling portion later in the up block. Otherwise you’ll just get a total mess of an output. Not too sure how related it is, but do check out this article.

Middle block

I will skip the initialization code; it’s just the same as that of the downsampler. For this block, instead of resnet -> attention, we do resnet -> attention -> resnet.

def forward(self, x, t_emb):
    out = x
    resnet_input = out
    out = self.resnet_convs[0](out)
    out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
    out = self.resnet_convs[1](out)
    out = out + self.residual_input_conv[0](resnet_input)

    batch_size, channels, h, w = out.shape
    in_attn = out.reshape(batch_size, channels, h * w)
    in_attn = self.attention_norms(in_attn)
    in_attn = in_attn.transpose(1, 2)
    out_attn, _ = self.attentions(in_attn, in_attn, in_attn)
    out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
    out = out + out_attn
    
    resnet_input = out
    out = self.resnet_convs[2](out)
    out = out + self.t_emb_layers[1](t_emb)[:, :, None, None]
    out = self.resnet_convs[3](out)
    out = out + self.residual_input_conv[1](resnet_input)
    
    return out

Upsample

Now for the upsample,

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, attend = True):
        super().__init__()
        ...
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()
    
    def forward(self, x, out_down, t_emb):
        x = self.up_sample_conv(x)
        x = torch.cat([x, out_down], dim=1)
        out = x
        resnet_input = out
        out = self.resnet_conv_first(out)
        out = out + self.t_emb_layers(t_emb)[:, :, None, None]
        out = self.resnet_conv_second(out)
        out = out + self.residual_input_conv(resnet_input)
        if self.attend:
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms(in_attn)
            in_attn = in_attn.transpose(1, 2)
            out_attn, _ = self.attentions(in_attn, in_attn, in_attn)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
        
        return out

As you will note, we first upsample with ConvTranspose2d, then concatenate with the output of the corresponding downsample block along the channel dimension - these form the U-net residual (skip) connections.
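
For concreteness, here is a toy illustration of that channel-wise concatenation (the shapes are made up for illustration and are not taken from the actual model):

x = torch.randn(4, 64, 14, 14)        # upsampled feature map
skip = torch.randn(4, 64, 14, 14)     # output of the matching DownBlock
merged = torch.cat([x, skip], dim=1)  # -> shape (4, 128, 14, 14)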

U-net

Now we piece them all together.

class Unet(nn.Module):
    def __init__(self, im_channels, down_channels, mid_channels, t_emb_dim, down_sample, attend):
        super().__init__()
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.t_emb_dim = t_emb_dim
        self.down_sample = down_sample

        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.up_sample = list(reversed(self.down_sample))
        ...
        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)
    
    def forward(self, x, t):
        out = self.conv_in(x)
        t_emb = get_time_embedding(t, self.t_emb_dim)
        t_emb = self.t_proj(t_emb)
        
        down_outs = []
        
        for idx, down in enumerate(self.downs):
            down_outs.append(out)
            out = down(out, t_emb)
        for mid in self.mids:
            out = mid(out, t_emb)
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)
        out = self.norm_out(out)
        out = nn.SiLU()(out)
        out = self.conv_out(out)
        return out
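
If you want to sanity-check the assembled model, a dummy forward pass should return a tensor with the same shape as the input (this snippet is my own, reusing the constructor arguments from the training section below):

model = Unet(im_channels=1, down_channels=[32, 64, 128], mid_channels=[128, 64],
             t_emb_dim=128, down_sample=[True, True], attend=[False, True])
x = torch.randn(2, 1, 28, 28)
t = torch.rand(2)
print(model(x, t).shape)  # expected: torch.Size([2, 1, 28, 28])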

And just like that, we have implemented our model! Finally, let us do some training.

Training

We prepare the dataset; we will use Fashion MNIST:

# Imports used by the training and sampling code below
import numpy as np
import torch
import torchvision
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid
from tqdm import tqdm

dataset_dir = "datasets/"

transforms_fmnist = transforms.Compose([transforms.ToTensor()])
train_dataset = FashionMNIST(root=dataset_dir,
                             train=True,
                             download=True,
                             transform=transforms_fmnist)
BATCH_SIZE = 256
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
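
As a quick check (my own addition), one batch from this loader should be 256 single-channel 28x28 images:

ims, labels = next(iter(train_dataloader))
print(ims.shape)  # torch.Size([256, 1, 28, 28])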

Prepare the model,

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Unet(im_channels = 1,
             down_channels = [32, 64, 128], 
             mid_channels = [128, 64], 
             t_emb_dim = 128, 
             down_sample = [True, True],
             attend = [False, True]).to(device)
model.train()
num_epochs = 10
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

Finally, train it

model.train()
for epoch_idx in range(num_epochs):
    losses = []
    for im, _ in tqdm(train_dataloader):
        optimizer.zero_grad()
        im = im.float().to(device)
        im = (2 * im) - 1
        
        t = torch.rand((im.shape[0])).to(device)
        noise = torch.randn_like(im).to(device)
        
        noisy_im = im * t[:, None, None, None] + (1 - t)[:, None, None, None] * noise
        v_pred = model(noisy_im, t)

        loss = criterion(v_pred, im - noise)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    print('Finished epoch:{} | Loss : {:.4f}'.format(
        epoch_idx + 1,
        np.mean(losses),
    ))

Note how, as mentioned, we are training the model to find $x - z$ given $x_{t}$ and $t$.

As for inference, we will just use the simple Euler method to solve the ODE given by $\frac{dx_{t}}{dt}$, starting from completely random noise at $t = 0$ and integrating up to $t = 1$ to find $x$.
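
In equation form, each Euler step in the loop below is simply (my restatement of what the code does):

\[x_{t + \Delta t} = x_{t} + f(x_{t}, t)\,\Delta t, \qquad \Delta t = \frac{1}{n_{\text{step}}}\]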

n_step = 100
model.eval()
xt = torch.randn((128, 1, 28, 28)).to(device)
with torch.no_grad():
    for i in tqdm(range(n_step)):
        velo = model(xt, torch.as_tensor(i / n_step).unsqueeze(0).to(device))
        xt = xt + velo / n_step

ims = torch.clamp(xt, -1., 1.).detach().cpu()
ims = (ims + 1) / 2
grid = make_grid(ims, nrow=16)
img = torchvision.transforms.ToPILImage()(grid)
img.save('out.png')
img.close()
model.train()

And that’s it! We have now implemented and trained a flow matching model on Fashion MNIST! Once again, the full code is available here

Here is some sample output of the model, trained for ~4 mins on a P100 GPU, 5 epochs:

Image of some clothes generated by the model

References
