# Discrete Diffusion for Language Modelling
Reproducing a research model

In this lab session we are going to reproduce the experiments given in https://arxiv.org/pdf/2102.05379

We use the code provided in https://github.com/ehoogeboom/multinomial_diffusion


In [11]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset

import os, zipfile, urllib, json, math
import numpy as np

In [12]:
# Data-loading settings used when the DataLoaders are built in get_data().
batch_size = 128  # passed to get_data() as the DataLoader batch size
num_workers = 1  # passed to get_data() as the DataLoader worker count
pin_memory = True  # read by get_data() as a module-level global, not passed as an argument

In [13]:
class Vocab():
    """Symbol-level vocabulary mapping single characters to integer ids.

    ``stoi`` maps symbol -> id and ``itos`` is the inverse map (id -> symbol);
    both are rebuilt together by :meth:`fill`.
    """

    def __init__(self, stoi=None):
        # Build a fresh dict per instance: a literal ``{}`` default would be
        # one shared object mutated by every Vocab created without arguments.
        self.fill({} if stoi is None else stoi)

    def fill(self, stoi):
        """Install ``stoi`` (symbol -> id) and derive the inverse map."""
        self.stoi = stoi
        self.itos = {i: s for s, i in stoi.items()}

    def save_json(self, path):
        """Write ``stoi`` to ``<path>/vocab.json``, creating ``path`` if needed."""
        os.makedirs(path, exist_ok=True)
        vocab_file = os.path.join(path, 'vocab.json')
        with open(vocab_file, 'w') as f:
            json.dump(self.stoi, f, indent=4)

    def load_json(self, path):
        """Replace the vocabulary with the contents of ``<path>/vocab.json``."""
        vocab_file = os.path.join(path, 'vocab.json')
        with open(vocab_file, 'r') as f:
            stoi = json.load(f)
        self.fill(stoi)

    def string_to_idx(self, string):
        """Return the list of ids for each character of ``string``.

        Raises ``KeyError`` for a character missing from the vocabulary.
        """
        assert isinstance(string, str)
        return [self.stoi[s] for s in string]

    def idx_to_string(self, idx):
        """Join the symbols for the ids in ``idx``; unknown ids become ``'?'``.

        A warning listing the unknown ids is printed when any are present.
        """
        assert isinstance(idx, list)
        unknown = [i for i in idx if i not in self.itos]
        if unknown:
            print(f'Warning, {len(unknown)} decodings were not in vocab.')
            print(set(unknown))
        return ''.join([self.itos[i] if i in self.itos else '?' for i in idx])

    def encode(self, text, padding_value=0):
        """Encode a list of strings into a padded id tensor.

        Returns:
            (tensor, length): ``tensor`` is the batch-first padded id tensor
            built by ``pad_sequence``; ``length`` holds each string's
            unpadded length.
        """
        assert isinstance(text, list)
        length = torch.tensor([len(string) for string in text])
        tensor_list = [torch.tensor(self.string_to_idx(string)) for string in text]
        tensor = nn.utils.rnn.pad_sequence(tensor_list, batch_first=True, padding_value=padding_value)
        return tensor, length

    def decode(self, tensor, length):
        """Decode a (batch_size, seq_len) id tensor back into a list of strings.

        Row ``b`` is truncated to ``length[b]`` ids before decoding.
        """
        assert torch.is_tensor(tensor)
        assert tensor.dim() == 2, 'Tensor should have shape (batch_size, seq_len)'
        text = [self.idx_to_string(tensor[b][:length[b]].tolist()) for b in range(tensor.shape[0])]
        return text

In [14]:
DATA_PATH = './datasets'


class Text8Dataset(Dataset):
    """
    The text8 dataset consisting of 100M characters (with vocab size 27).
    We here split the dataset into (90M, 5M, 5M) characters for
    (train, val, test) as in [1,2,3].

    The sets are then split into chunks of equal length as specified by `seq_len`.
    The default is 256, corresponding to what was used in [1]. Other choices
    include 180, as [2] reports using.

    Files are kept under ``<root>/text8``: the raw ``text8.zip``, a
    ``vocab.json`` and one ``processed_<split>.pt`` tensor per split.
    Each item is a ``(sequence, seq_len)`` pair.

    [1] Discrete Flows: Invertible Generative Models of Discrete Data
        Tran et al., 2019, https://arxiv.org/abs/1905.10347
    [2] Architectural Complexity Measures of Recurrent Neural Networks
        Zhang et al., 2016, https://arxiv.org/abs/1602.08210
    [3] Subword Language Modeling with Neural Networks
        Mikolov et al., 2013, http://www.fit.vutbr.cz/~imikolov/rnnlm/char.pdf
    """

    def __init__(self, root=DATA_PATH, seq_len=256, split='train', download=False):
        """Load (and if necessary download / preprocess) one split.

        Args:
            root: parent directory; data lives in ``<root>/text8``.
            seq_len: length of each returned chunk.
            split: one of 'train', 'valid', 'test'.
            download: fetch the raw archive when it is missing.

        Raises:
            RuntimeError: the raw archive is missing and ``download`` is False.
        """
        assert split in {'train', 'valid', 'test'}
        self.root = os.path.join(root, 'text8')
        self.seq_len = seq_len
        self.split = split

        if not os.path.exists(self.raw_file):
            if download:
                self.download()
            else:
                raise RuntimeError('Dataset not found. You can use download=True to download it.')

        # Get vocabulary
        self.vocab = Vocab()
        vocab_file = os.path.join(self.root, 'vocab.json')
        if os.path.exists(vocab_file):
            self.vocab.load_json(self.root)
        else:
            stoi = self._create_stoi()
            self.vocab.fill(stoi)
            self.vocab.save_json(self.root)

        # Preprocess data
        if not os.path.exists(self.processed_file(split)):
            self._preprocess_data(split)

        # Load data
        self.data = torch.load(self.processed_file(split))

        # The cache file name depends on the split only, so a file written
        # with another seq_len would otherwise be reused with the wrong chunk
        # length. Rebuild it when the stored width does not match.
        if self.data.dim() != 2 or self.data.size(1) != seq_len:
            self._preprocess_data(split)
            self.data = torch.load(self.processed_file(split))

    def __getitem__(self, index):
        """Return ``(chunk, seq_len)`` for row ``index`` of the split."""
        return self.data[index], self.seq_len

    def __len__(self):
        """Number of chunks in the split."""
        return len(self.data)

    def _read_raw(self):
        """Return the decoded 'text8' member of the raw archive."""
        # Context manager so the archive handle is closed after reading.
        with zipfile.ZipFile(self.raw_file) as archive:
            return archive.read('text8').decode('utf-8')

    def _create_stoi(self):
        """Build a symbol -> id map from the sorted set of raw characters."""
        symbols = sorted(set(self._read_raw()))
        return {symbol: i for i, symbol in enumerate(symbols)}

    def _preprocess_data(self, split):
        """Encode one split, cut it into ``seq_len`` chunks and save it."""
        # Read raw data
        rawdata = self._read_raw()

        # Extract subset
        if split == 'train':
            rawdata = rawdata[:90000000]
        elif split == 'valid':
            rawdata = rawdata[90000000:95000000]
        elif split == 'test':
            rawdata = rawdata[95000000:]

        # Encode characters
        data = torch.tensor([self.vocab.stoi[s] for s in rawdata])

        # Split into chunks (the trailing partial chunk is dropped)
        data = data[:self.seq_len*(len(data)//self.seq_len)]
        data = data.reshape(-1, self.seq_len)

        # Save processed data
        torch.save(data, self.processed_file(split))

    @property
    def raw_file(self):
        """Path of the raw ``text8.zip`` archive."""
        return os.path.join(self.root, 'text8.zip')

    def processed_file(self, split):
        """Path of the cached tensor file for ``split``."""
        return os.path.join(self.root, 'processed_{}.pt'.format(split))

    def download(self):
        """Download ``text8.zip`` into ``self.root``."""
        # ``import urllib`` alone does not guarantee that the ``request``
        # submodule is loaded, so import it explicitly here.
        import urllib.request

        os.makedirs(self.root, exist_ok=True)

        print('Downloading text8...')
        url = 'http://mattmahoney.net/dc/text8.zip'
        print('Downloading from {}...'.format(url))
        urllib.request.urlretrieve(url, self.raw_file)
        print('Saved to {}'.format(self.raw_file))

In [15]:
def get_data(batch_size, num_workers):
    """Build the text8 train / validation loaders.

    All three splits are instantiated (only the train split may trigger a
    download); loaders are returned for train and valid only. ``pin_memory``
    is taken from the module-level global of the same name.

    Returns:
        (train_loader, eval_loader, data_shape, num_classes)
    """
    seq_len = 256
    datasets = {
        name: Text8Dataset(seq_len=seq_len, split=name, download=(name == 'train'))
        for name in ('train', 'valid', 'test')
    }

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(datasets['train'], shuffle=True, **loader_kwargs)
    eval_loader = DataLoader(datasets['valid'], shuffle=False, **loader_kwargs)

    data_shape = (seq_len,)
    num_classes = 27
    return train_loader, eval_loader, data_shape, num_classes

In [16]:
# Build the text8 loaders once; get_data() also returns the per-example data
# shape and the number of classes.
train_loader, eval_loader, data_shape, num_classes = get_data(batch_size, num_workers)

In [21]:
import torch.nn.functional as F

def cosine_beta_schedule(timesteps, s = 0.008):
    """
    Cosine noise schedule, as proposed in
    https://openreview.net/forum?id=-NEXDKk8gZ

    Returns a NumPy array of ``timesteps`` per-step alpha values: the ratios
    of consecutive cumulative cosine-squared values, clipped to [0.001, 1]
    and then square-rooted.
    """
    n_points = timesteps + 1
    grid = np.linspace(0, n_points, n_points)

    phase = ((grid / n_points) + s) / (1 + s) * np.pi * 0.5
    cumulative = np.cos(phase) ** 2
    cumulative = cumulative / cumulative[0]

    step_ratio = cumulative[1:] / cumulative[:-1]
    step_ratio = np.clip(step_ratio, a_min=0.001, a_max=1.)

    # Use sqrt of this, so the alpha in our paper is the alpha_sqrt from the
    # Gaussian diffusion in Ho et al.
    return np.sqrt(step_ratio)

def log_1_min_a(a):
    """Return ``log(1 - exp(a) + 1e-40)`` element-wise for a log-space tensor."""
    one_minus = 1 - torch.exp(a) + 1e-40
    return torch.log(one_minus)


def log_add_exp(a, b):
    """Return ``log(exp(a) + exp(b))`` element-wise.

    The element-wise maximum is subtracted before exponentiating and added
    back afterwards.
    """
    pivot = torch.max(a, b)
    shifted_sum = (a - pivot).exp() + (b - pivot).exp()
    return pivot + shifted_sum.log()

def extract(a, t, x_shape):
    """Look up ``a[t]`` per batch element and shape it for broadcasting.

    The gathered values are reshaped to ``(b, 1, ..., 1)`` where ``b`` is the
    first dimension of ``t`` and the number of axes equals ``len(x_shape)``.
    """
    batch, *_ = t.shape
    singleton_axes = (1,) * (len(x_shape) - 1)
    picked = torch.gather(a, -1, t)
    return picked.reshape(batch, *singleton_axes)

def index_to_log_onehot(x, num_classes):
    """Convert an integer class tensor into a log one-hot tensor.

    The class axis is moved to dim 1, so an input of shape ``(B, ...)``
    yields ``(B, num_classes, ...)``. One-hot zeros are clamped to 1e-30
    before the log, so entries are either ``0`` or ``log(1e-30)``.
    """
    largest = x.max().item()
    assert largest < num_classes, \
        f'Error: {x.max().item()} >= {num_classes}'

    onehot = F.one_hot(x, num_classes)

    # The one-hot axis is appended last (index x.dim()); bring it to dim 1.
    axes = [0, x.dim()] + list(range(1, x.dim()))
    onehot = onehot.permute(*axes)

    return onehot.float().clamp(min=1e-30).log()


def log_onehot_to_index(log_x):
    """Return the index of the largest entry along dim 1 of ``log_x``."""
    return log_x.argmax(1)


# Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281

# Small constant; NOTE(review): not referenced anywhere in the visible code — confirm whether it is still needed.
eps = 1e-8


def sum_except_batch(x, num_dims=1):
    '''
    Collapse and sum every dimension after the leading batch dims.

    Args:
        x: Tensor, shape (batch_size, ...)
        num_dims: int, how many leading dims to keep (default=1)

    Returns:
        x_sum: Tensor with the first `num_dims` dims of `x`,
            shape (batch_size,) for the default
    '''
    kept = tuple(x.shape[:num_dims])
    flattened = x.reshape(kept + (-1,))
    return flattened.sum(dim=-1)

def exists(x):
    """Return True for anything other than ``None``."""
    return not (x is None)

def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function (as judged by
    ``inspect.isfunction``) it is called with no arguments and its result is
    returned; any other ``d`` is returned as-is.
    """
    # ``isfunction`` is not imported anywhere at file level, so the fallback
    # path used to raise NameError; import it where it is needed.
    from inspect import isfunction

    if val is not None:
        return val
    return d() if isfunction(d) else d


def log_categorical(log_x_start, log_prob):
    """Sum ``exp(log_x_start) * log_prob`` over dim 1."""
    weighted = torch.exp(log_x_start) * log_prob
    return weighted.sum(dim=1)





## Multinomial Diffusion

Comment functions to explain what they do

In [22]:
class MultinomialDiffusion(torch.nn.Module):
    """Multinomial (categorical) diffusion over integer-coded data.

    Distributions are represented as log-probability tensors whose class
    axis is dim 1 (the layout produced by ``index_to_log_onehot``); every
    reduction over classes in this class uses ``dim=1``.

    The noise schedule comes from ``cosine_beta_schedule`` and is stored in
    log space as buffers. ``Lt_history`` / ``Lt_count`` keep running
    per-timestep loss statistics used for importance sampling of timesteps.
    """

    def __init__(self, num_classes, shape, denoise_fn, timesteps=1000,
                 parametrization='x0'):
        """
        Args:
            num_classes: number of categories K.
            shape: per-example data shape (tuple), used when sampling.
            denoise_fn: callable invoked as ``denoise_fn(t, x_t)`` with
                integer timesteps and integer-coded noisy data.
            timesteps: number of diffusion steps.
            parametrization: 'x0' (the network output is plugged into the
                posterior) or 'direct' (the network output is used as the
                reverse-step distribution).
        """
        super(MultinomialDiffusion, self).__init__()

        assert parametrization in ('x0', 'direct')

        self.num_classes = num_classes
        self._denoise_fn = denoise_fn
        self.shape = shape
        self.num_timesteps = timesteps
        self.parametrization = parametrization

        alphas = cosine_beta_schedule(timesteps)

        # Schedule arithmetic is done in float64 and in log space. Torch ops
        # are used directly: applying np.log / np.cumsum to a tensor relies
        # on NumPy's array-wrapping fallbacks.
        alphas = torch.tensor(alphas.astype('float64'))
        log_alpha = torch.log(alphas)
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)

        log_1_min_alpha = log_1_min_a(log_alpha)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

        # Sanity checks: alpha + (1 - alpha) must be 1, i.e. 0 in log space.
        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (torch.cumsum(log_alpha, dim=0) - log_cumprod_alpha).abs().sum().item() < 1.e-5

        # Convert to float32 and register buffers.
        self.register_buffer('log_alpha', log_alpha.float())
        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())
        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())
        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())

        # Running per-timestep statistics, updated in _train_loss().
        self.register_buffer('Lt_history', torch.zeros(timesteps))
        self.register_buffer('Lt_count', torch.zeros(timesteps))

    def forward(self, x):
        """Identity; the model is used through ``log_prob`` and ``sample``."""
        return x

    def multinomial_kl(self, log_prob1, log_prob2):
        """KL(p1 || p2) between categorical distributions given as log-probs.

        The class axis (dim 1) is summed out.
        """
        kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
        return kl

    def q_pred_one_timestep(self, log_x_t, t):
        """One forward noising step in log space, evaluated at ``log_x_t``.

        Computes ``log(alpha_t * x_t + (1 - alpha_t) / K)``.
        """
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)

        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )

        return log_probs

    def q_pred(self, log_x_start, t):
        """Log-probabilities of q(x_t | x_0).

        Same form as ``q_pred_one_timestep`` but with the cumulative
        product of alphas up to ``t``.
        """
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)

        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )

        return log_probs

    def predict_start(self, log_x_t, t):
        """Run the denoiser on x_t and return class log-probabilities.

        ``log_x_t`` is converted to integer indices, passed to
        ``denoise_fn(t, x_t)``, and the output is permuted with (0, 2, 1)
        so that the class axis is dim 1 before ``log_softmax``.
        """
        x_t = log_onehot_to_index(log_x_t)

        out = self._denoise_fn(t, x_t)
        # Swap the last two axes so classes sit at dim 1 (checked below).
        out = out.permute(0,2,1)

        assert out.size(0) == x_t.size(0)
        assert out.size(1) == self.num_classes
        assert out.size()[2:] == x_t.size()[1:]
        log_pred = F.log_softmax(out, dim=1)
        return log_pred

    def q_posterior(self, log_x_start, log_x_t, t):
        """Normalised log-probabilities of q(x_{t-1} | x_t, x_0)."""
        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
        # where q(xt | xt-1, x0) = q(xt | xt-1).

        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)

        # Where t == 0 there is no earlier step: use x_0 itself.
        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)

        # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
        # Not very easy to see why this is true. But it is :)
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)

        # Normalise over the class axis.
        log_EV_xtmin_given_xt_given_xstart = \
            unnormed_logprobs \
            - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)

        return log_EV_xtmin_given_xt_given_xstart

    def p_pred(self, log_x, t):
        """Model log-probabilities for the reverse step at time ``t``.

        With 'x0' the predicted x_0 is fed through ``q_posterior``; with
        'direct' the network output is returned unchanged.
        """
        if self.parametrization == 'x0':
            log_x_recon = self.predict_start(log_x, t=t)
            log_model_pred = self.q_posterior(
                log_x_start=log_x_recon, log_x_t=log_x, t=t)
        elif self.parametrization == 'direct':
            log_model_pred = self.predict_start(log_x, t=t)
        else:
            raise ValueError
        return log_model_pred

    @torch.no_grad()
    def p_sample(self, log_x, t):
        """Draw one reverse-step sample (as log one-hot) given ``log_x`` at ``t``."""
        model_log_prob = self.p_pred(log_x=log_x, t=t)
        out = self.log_sample_categorical(model_log_prob)
        return out

    @torch.no_grad()
    def p_sample_loop(self, shape):
        """Reverse loop starting from Gaussian noise.

        NOTE(review): this starts from ``torch.randn`` rather than a log
        one-hot tensor and stops at step 1; it looks like a leftover from
        continuous image diffusion. ``sample`` is the categorical sampler
        defined in this class — confirm before using this method.
        """
        device = self.log_alpha.device

        b = shape[0]
        # start with random normal image.
        img = torch.randn(shape, device=device)

        for i in reversed(range(1, self.num_timesteps)):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
        return img

    @torch.no_grad()
    def _sample(self, image_size, batch_size = 16):
        """Call ``p_sample_loop`` with a (batch, 3, H, W) shape.

        NOTE(review): image-shaped; presumably unused for text — confirm.
        """
        return self.p_sample_loop((batch_size, 3, image_size, image_size))

    @torch.no_grad()
    def interpolate(self, x1, x2, t = None, lam = 0.5):
        """Noise ``x1`` and ``x2`` to step ``t``, mix them linearly, denoise.

        NOTE(review): the linear mix of two sampled tensors is passed to
        ``p_sample`` as if it were a log one-hot tensor; presumably carried
        over from continuous diffusion — confirm before relying on it.
        """
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in reversed(range(0, t)):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))

        return img

    def log_sample_categorical(self, logits):
        """Sample classes with the Gumbel-max trick and return log one-hot.

        Gumbel noise is added to ``logits`` and the argmax over dim 1 is
        taken; the small constants guard the two logs against zero.
        """
        uniform = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
        sample = (gumbel_noise + logits).argmax(dim=1)
        log_sample = index_to_log_onehot(sample, self.num_classes)
        return log_sample

    def q_sample(self, log_x_start, t):
        """Sample x_t ~ q(x_t | x_0), returned as log one-hot."""
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)

        log_sample = self.log_sample_categorical(log_EV_qxt_x0)

        return log_sample

    def nll(self, log_x_start):
        """Sum ``compute_Lt`` over every timestep and add the prior KL.

        One x_t is sampled per timestep, so the result is stochastic.
        """
        b = log_x_start.size(0)
        device = log_x_start.device
        loss = 0
        for t in range(0, self.num_timesteps):
            t_array = (torch.ones(b, device=device) * t).long()

            kl = self.compute_Lt(
                log_x_start=log_x_start,
                log_x_t=self.q_sample(log_x_start=log_x_start, t=t_array),
                t=t_array)

            loss += kl

        loss += self.kl_prior(log_x_start)

        return loss

    def kl_prior(self, log_x_start):
        """KL between q(x_T | x_0) at the last step and the uniform distribution.

        Returns one value per batch element.
        """
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device).long()

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        # log(1 / K): the uniform distribution over classes.
        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))

        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
        return sum_except_batch(kl_prior)

    def compute_Lt(self, log_x_start, log_x_t, t, detach_mean=False):
        """Per-example loss term for timestep ``t``.

        For ``t == 0`` this is the decoder negative log-likelihood of x_0
        under the model; otherwise it is the KL between the true posterior
        and the model's reverse-step distribution.
        """
        log_true_prob = self.q_posterior(
            log_x_start=log_x_start, log_x_t=log_x_t, t=t)

        log_model_prob = self.p_pred(log_x=log_x_t, t=t)

        if detach_mean:
            log_model_prob = log_model_prob.detach()

        kl = self.multinomial_kl(log_true_prob, log_model_prob)
        kl = sum_except_batch(kl)

        decoder_nll = -log_categorical(log_x_start, log_model_prob)
        decoder_nll = sum_except_batch(decoder_nll)

        # Select the decoder term where t == 0, the KL term elsewhere.
        mask = (t == torch.zeros_like(t)).float()
        loss = mask * decoder_nll + (1. - mask) * kl

        return loss

    def sample_time(self, b, device, method='uniform'):
        """Sample ``b`` timesteps and return ``(t, pt)``.

        ``pt`` is the probability with which each ``t`` was drawn.
        'importance' draws proportionally to ``sqrt(Lt_history)`` and falls
        back to 'uniform' until every timestep has been seen more than 10
        times.

        Raises:
            ValueError: unknown ``method``.
        """
        if method == 'importance':
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)

            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt
        else:
            raise ValueError

    def _train_loss(self, x):
        """Stochastic training objective for integer data ``x``.

        One timestep per example is importance-sampled, the running
        ``Lt_history`` (moving average of the squared loss) and ``Lt_count``
        are updated, and the negated bound estimate ``-(Lt / pt + prior KL)``
        is returned per example.
        """
        b, device = x.size(0), x.device

        x_start = x

        t, pt = self.sample_time(b, device, 'importance')

        log_x_start = index_to_log_onehot(x_start, self.num_classes)

        kl = self.compute_Lt(log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t)

        # Exponential moving average of Lt^2 for the sampled timesteps.
        Lt2 = kl.pow(2)
        Lt2_prev = self.Lt_history.gather(dim=0, index=t)
        new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()
        self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)
        self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))

        kl_prior = self.kl_prior(log_x_start)

        # Upweigh loss term of the kl
        vb_loss = kl / pt + kl_prior

        return -vb_loss

    def log_prob(self, x):
        """Stochastic per-example estimate of the log-likelihood bound.

        In training mode this is ``_train_loss`` (which also updates the
        loss history). In eval mode the same estimate is computed without
        touching the history.
        """
        b, device = x.size(0), x.device
        if self.training:
            return self._train_loss(x)

        else:
            log_x_start = index_to_log_onehot(x, self.num_classes)

            t, pt = self.sample_time(b, device, 'importance')

            kl = self.compute_Lt(
                log_x_start, self.q_sample(log_x_start=log_x_start, t=t), t)

            kl_prior = self.kl_prior(log_x_start)

            # Upweigh loss term of the kl
            loss = kl / pt + kl_prior

            return -loss

    def sample(self, num_samples):
        """Generate ``num_samples`` examples as integer class indices.

        Starts from a uniform categorical sample and applies ``p_sample``
        for every timestep from the last down to 0, printing progress.
        """
        b = num_samples
        device = self.log_alpha.device
        # All-zero logits give a uniform categorical distribution.
        uniform_logits = torch.zeros((b, self.num_classes) + self.shape, device=device)
        log_z = self.log_sample_categorical(uniform_logits)
        for i in reversed(range(0, self.num_timesteps)):
            print(f'Sample timestep {i:4d}', end='\r')

            t = torch.full((b,), i, device=device, dtype=torch.long)
            log_z = self.p_sample(log_z, t)
        print()
        return log_onehot_to_index(log_z)

    def sample_chain(self, num_samples):
        """Like ``sample`` but return the state after every reverse step.

        Returns a long tensor of shape ``(num_timesteps, num_samples) + shape``;
        entry ``i`` holds the indices produced by the step at timestep ``i``.
        """
        b = num_samples
        device = self.log_alpha.device
        uniform_logits = torch.zeros(
            (b, self.num_classes) + self.shape, device=device)

        zs = torch.zeros((self.num_timesteps, b) + self.shape).long()

        log_z = self.log_sample_categorical(uniform_logits)
        for i in reversed(range(0, self.num_timesteps)):
            print(f'Chain timestep {i:4d}', end='\r')
            t = torch.full((b,), i, device=device, dtype=torch.long)
            log_z = self.p_sample(log_z, t)

            zs[i] = log_onehot_to_index(log_z)
        print()
        return zs

In [24]:
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from torch.optim.lr_scheduler import _LRScheduler


class LinearWarmupScheduler(_LRScheduler):
    """Learning-rate warm-up that ramps linearly from zero.

    Each base learning rate is scaled by ``last_epoch / total_epoch``,
    capped at 1.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        total_epoch: number of scheduler steps after which the base
            learning rate is reached.
    """

    def __init__(self, optimizer, total_epoch, last_epoch=-1):
        # Must be set before the parent constructor runs, because the parent
        # performs an initial step that calls get_lr().
        self.total_epoch = total_epoch
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the warmed-up learning rate for every parameter group."""
        fraction = self.last_epoch / self.total_epoch
        scaled = []
        for base_lr in self.base_lrs:
            scaled.append(base_lr * min(1, fraction))
        return scaled


In [26]:
import os
import pickle
import torch
from prettytable import PrettyTable


def get_args_table(args_dict):
    """Render a mapping of arguments as a two-column PrettyTable (Arg, Value)."""
    args_table = PrettyTable(['Arg', 'Value'])
    for pair in args_dict.items():
        # One row per (name, value) entry.
        args_table.add_row(list(pair))
    return args_table


def get_metric_table(metric_dict, epochs):
    """Build a PrettyTable with an 'Epoch' column plus one column per metric.

    Args:
        metric_dict: mapping metric name -> list of values.
        epochs: values for the 'Epoch' column.
    """
    metric_table = PrettyTable()
    metric_table.add_column('Epoch', epochs)
    if len(metric_dict) == 0:
        return metric_table
    for name in metric_dict:
        metric_table.add_column(name, metric_dict[name])
    return metric_table

class BaseExperiment(object):
    """Skeleton of a train / eval / log / checkpoint loop.

    Subclasses implement ``train_fn``, ``eval_fn`` and ``log_fn``; ``run``
    drives them once per epoch, records the returned metric dicts and
    writes metrics and checkpoints under ``log_path``.
    """

    def __init__(self, model, optimizer, scheduler_iter, scheduler_epoch,
                 log_path, eval_every, check_every):
        """
        Args:
            model, optimizer: objects whose ``state_dict`` is checkpointed.
            scheduler_iter, scheduler_epoch: optional schedulers (may be None).
            log_path: directory for logs; checkpoints go to ``<log_path>/check``.
            eval_every: stored on the instance; ``run`` in this class
                evaluates every epoch and does not consult it.
            check_every: checkpoint interval in epochs, or None to disable.
        """
        # Objects
        self.model = model
        self.optimizer = optimizer
        self.scheduler_iter = scheduler_iter
        self.scheduler_epoch = scheduler_epoch

        # Paths
        self.log_path = log_path
        self.check_path = os.path.join(log_path, 'check')

        # Intervals
        self.eval_every = eval_every
        self.check_every = check_every

        # Initialize
        self.current_epoch = 0
        self.train_metrics = {}
        self.eval_metrics = {}
        self.eval_epochs = []

    # The hooks accept ``epochs`` because run() passes it; it defaults to
    # None so a one-argument call still reaches NotImplementedError.
    def train_fn(self, epoch, epochs=None):
        """Train for one epoch and return a dict of metric name -> value."""
        raise NotImplementedError()

    def eval_fn(self, epoch, epochs=None):
        """Evaluate and return a dict of metric name -> value."""
        raise NotImplementedError()

    def log_fn(self, epoch, train_dict, eval_dict):
        """Report the metrics of one epoch (e.g. to a logging backend)."""
        raise NotImplementedError()

    def log_train_metrics(self, train_dict):
        """Append each value of ``train_dict`` to its history list."""
        # setdefault also copes with a metric that first appears in a later
        # epoch, which previously raised KeyError.
        for metric_name, metric_value in train_dict.items():
            self.train_metrics.setdefault(metric_name, []).append(metric_value)

    def log_eval_metrics(self, eval_dict):
        """Append each value of ``eval_dict`` to its history list."""
        for metric_name, metric_value in eval_dict.items():
            self.eval_metrics.setdefault(metric_name, []).append(metric_value)

    def create_folders(self):
        """Create the log folder and, if checkpointing is on, the check folder.

        ``os.makedirs`` is called without ``exist_ok``, so an existing
        directory raises ``FileExistsError``.
        """
        # Create log folder
        os.makedirs(self.log_path)
        print("Storing logs in:", self.log_path)

        # Create check folder
        if self.check_every is not None:
            os.makedirs(self.check_path)
            print("Storing checkpoints in:", self.check_path)

    def save_args(self, args):
        """Pickle ``args`` and write a readable table of ``vars(args)``."""
        # Save args
        with open(os.path.join(self.log_path, 'args.pickle'), "wb") as f:
            pickle.dump(args, f)

        # Save args table
        args_table = get_args_table(vars(args))
        with open(os.path.join(self.log_path,'args_table.txt'), "w") as f:
            f.write(str(args_table))

    def save_metrics(self):
        """Write train / eval metric histories as pickles and text tables."""
        # Save metrics
        with open(os.path.join(self.log_path,'metrics_train.pickle'), 'wb') as f:
            pickle.dump(self.train_metrics, f)
        with open(os.path.join(self.log_path,'metrics_eval.pickle'), 'wb') as f:
            pickle.dump(self.eval_metrics, f)

        # Save metrics table. run() calls this before incrementing
        # current_epoch, hence the +2 to include the epoch just finished.
        metric_table = get_metric_table(self.train_metrics, epochs=list(range(1, self.current_epoch+2)))
        with open(os.path.join(self.log_path,'metrics_train.txt'), "w") as f:
            f.write(str(metric_table))
        metric_table = get_metric_table(self.eval_metrics, epochs=[e+1 for e in self.eval_epochs])
        with open(os.path.join(self.log_path,'metrics_eval.txt'), "w") as f:
            f.write(str(metric_table))

    def checkpoint_save(self, name='checkpoint.pt'):
        """Save epoch counter, metric histories and all state dicts."""
        checkpoint = {'current_epoch': self.current_epoch,
                      'train_metrics': self.train_metrics,
                      'eval_metrics': self.eval_metrics,
                      'eval_epochs': self.eval_epochs,
                      'model': self.model.state_dict(),
                      'optimizer': self.optimizer.state_dict(),
                      'scheduler_iter': self.scheduler_iter.state_dict() if self.scheduler_iter else None,
                      'scheduler_epoch': self.scheduler_epoch.state_dict() if self.scheduler_epoch else None}
        torch.save(checkpoint, os.path.join(self.check_path, name))

    def checkpoint_load(self, check_path, name='checkpoint.pt'):
        """Restore everything written by ``checkpoint_save``."""
        checkpoint = torch.load(os.path.join(check_path, name))
        self.current_epoch = checkpoint['current_epoch']
        self.train_metrics = checkpoint['train_metrics']
        self.eval_metrics = checkpoint['eval_metrics']
        self.eval_epochs = checkpoint['eval_epochs']
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter'])
        if self.scheduler_epoch: self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch'])

    def run(self, epochs):
        """Run epochs ``current_epoch .. epochs-1``: train, eval, log, checkpoint."""
        for epoch in range(self.current_epoch, epochs):

            # Train
            train_dict = self.train_fn(epoch, epochs)
            self.log_train_metrics(train_dict)

            # Eval
            eval_dict = self.eval_fn(epoch, epochs)
            self.log_eval_metrics(eval_dict)
            self.eval_epochs.append(epoch)

            # Log
            self.save_metrics()
            self.log_fn(epoch, train_dict, eval_dict)

            # Checkpoint. check_every=None disables checkpointing, in line
            # with create_folders(), which then creates no check folder.
            self.current_epoch += 1
            if self.check_every is not None and (epoch+1) % self.check_every == 0:
                self.checkpoint_save()



In [27]:
# Path
import os
import time
import pathlib
HOME = str(pathlib.Path.home())

import copy

def clean_dict(d, keys):
    """Return a deep copy of ``d`` without the entries named in ``keys``.

    Keys that are absent are ignored; ``d`` itself is left untouched.
    """
    pruned = copy.deepcopy(d)
    for unwanted in keys:
        pruned.pop(unwanted, None)
    return pruned

#  Logging frameworks
from torch.utils.tensorboard import SummaryWriter

class DiffusionExperiment(BaseExperiment):
    """BaseExperiment wired for the diffusion model with TensorBoard logging.

    Logs go to ``~/log/flow/<data_id>/<model_id>/<optim_id>/<timestamp>``;
    the model is moved to ``cuda:0``.
    """

    # A comma was missing after 'parallel', which silently fused it with
    # 'pin_memory' into the single string 'parallelpin_memory'.
    no_log_keys = ['project', 'name',
                   'log_tb', 'log_wandb',
                   'check_every', 'eval_every',
                   'device', 'parallel',
                   'pin_memory', 'num_workers']

    def __init__(self,
                 data_id, model_id, optim_id,
                 train_loader, eval_loader,
                 model, optimizer, scheduler_iter, scheduler_epoch):
        """
        Args:
            data_id, model_id, optim_id: strings used to build the log path.
            train_loader, eval_loader: data loaders stored on the instance.
            model, optimizer, scheduler_iter, scheduler_epoch: forwarded to
                ``BaseExperiment``.
        """
        self.log_base = os.path.join(HOME, 'log', 'flow')

        # Edit args (intervals are fixed here rather than configurable)
        eval_every = 20
        check_every = 10
        name = time.strftime("%Y-%m-%d_%H-%M-%S")

        # Move model
        model = model.to('cuda:0')

        # Init parent
        super(DiffusionExperiment, self).__init__(model=model,
                                                  optimizer=optimizer,
                                                  scheduler_iter=scheduler_iter,
                                                  scheduler_epoch=scheduler_epoch,
                                                  log_path=os.path.join(self.log_base, data_id, model_id, optim_id, name),
                                                  eval_every=eval_every,
                                                  check_every=check_every)

        # Store args
        self.create_folders()

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))

    def log_fn(self, epoch, train_dict, eval_dict):
        """Write train metrics under 'base/' and eval metrics under 'eval/'.

        The TensorBoard step is ``epoch + 1``; an empty or None ``eval_dict``
        is skipped.
        """
        # Tensorboard
        for metric_name, metric_value in train_dict.items():
            self.writer.add_scalar('base/{}'.format(metric_name), metric_value, global_step=epoch+1)
        if eval_dict:
            for metric_name, metric_value in eval_dict.items():
                self.writer.add_scalar('eval/{}'.format(metric_name), metric_value, global_step=epoch+1)

    # Disabled resume logic, kept for reference. NOTE(review): it reads
    # self.args, which this class does not set.
    # def resume(self):
    #     resume_path = os.path.join(self.log_base, self.data_id, self.model_id, self.optim_id, self.args.resume, 'check')
    #     self.checkpoint_load(resume_path)
    #     for epoch in range(self.current_epoch):
    #         train_dict = {}
    #         for metric_name, metric_values in self.train_metrics.items():
    #             train_dict[metric_name] = metric_values[epoch]
    #         if epoch in self.eval_epochs:
    #             eval_dict = {}
    #             for metric_name, metric_values in self.eval_metrics.items():
    #                 eval_dict[metric_name] = metric_values[self.eval_epochs.index(epoch)]
    #         else: eval_dict = None
    #         self.log_fn(epoch, train_dict=train_dict, eval_dict=eval_dict)

    def run(self, epochs):
        """Run the parent training loop for ``epochs`` epochs."""
        super(DiffusionExperiment, self).run(epochs=epochs)

In [28]:
class Experiment(DiffusionExperiment):
    """Training and evaluation loops reporting bits per character.

    The loss for a batch is ``-sum(log_prob) / (ln 2 * total length)``.
    """

    def train_fn(self, epoch, epochs):
        """Train for one epoch; the optimizer steps every second batch.

        Returns:
            {'bpc': batch-size-weighted mean of the per-batch loss}
        """
        self.model.train()
        weighted_total = 0.0
        seen = 0
        smoothed = None
        for step, (x, length) in enumerate(self.train_loader):
            x, length = x.to('cuda:0'), length.to('cuda:0')
            num_elem = length.sum()
            loss = - self.model.log_prob(x).sum() / (math.log(2) * num_elem)
            loss.backward()

            # Gradients accumulate over two batches before each update.
            if (step + 1) % 2 == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.scheduler_iter: self.scheduler_iter.step()

            batch_loss = loss.detach().cpu().item()
            weighted_total += batch_loss * len(x)
            seen += len(x)

            # Exponential moving average used for the progress line only.
            if smoothed is None:
                smoothed = batch_loss
            else:
                smoothed = .99 * smoothed + .01 * batch_loss

            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/char: {:.3f}'.format(epoch+1, epochs, seen, len(self.train_loader.dataset), smoothed), end='\r')
        print('')
        if self.scheduler_epoch: self.scheduler_epoch.step()
        return {'bpc': weighted_total/seen}

    def _eval_pass(self, template, epoch, epochs):
        """One no-grad pass over ``eval_loader``; returns (weighted sum, count)."""
        with torch.no_grad():
            weighted_total = 0.0
            seen = 0
            for x, length in self.eval_loader:
                x, length = x.to('cuda:0'), length.to('cuda:0')
                num_elem = length.sum()
                loss = - self.model.log_prob(x).sum() / (math.log(2) * num_elem)
                weighted_total += loss.detach().cpu().item() * len(x)
                seen += len(x)
                print(template.format(epoch+1, epochs, seen, len(self.eval_loader.dataset), weighted_total/seen), end='\r')
            print('')
        return weighted_total, seen

    def eval_fn(self, epoch, epochs):
        """Print the loss history, then evaluate on ``eval_loader``.

        Two passes are made over ``eval_loader``; only the second one
        determines the returned value.
        NOTE(review): the first pass is labelled 'Evaluating train' but also
        iterates ``eval_loader`` — confirm which loader was intended.

        Returns:
            {'bpc': batch-size-weighted mean of the per-batch loss}
        """
        self.model.eval()

        print('sqrt |Lt_history|^2')
        history_root = torch.sqrt(self.model.Lt_history)
        print(' '.join(f'{item.item():.2f}' for item in history_root))
        print()

        self._eval_pass('Evaluating train. Epoch: {}/{}, Datapoint: {}/{}, Bits/char: {:.3f}', epoch, epochs)
        weighted_total, seen = self._eval_pass('Evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/char: {:.3f}', epoch, epochs)
        return {'bpc': weighted_total/seen}

## Implementation of the Denoiser



In [None]:
# Denoiser hyperparameters. Their consumer is not in the visible code, so the
# meanings below are inferred from the names only — TODO confirm when wiring
# them into the denoiser.
char_dim = 16  # presumably the character embedding width
num_heads = 8  # presumably the number of attention heads
num_layers = 4  # presumably the number of encoder layers
time_dim = 32  # presumably the timestep embedding width

### Position Embeddings

This implementation is taken from *https://nlp.seas.harvard.edu/annotated-transformer*

In [18]:
class TimeStepEmb(nn.Module):
    """Embed time steps: sinusoidal features followed by a two-layer MLP.

    Args:
        dim: Size passed to ``SinusoidalPosEmb`` and width of the MLP input
            and output.
        num_steps: Forwarded to ``SinusoidalPosEmb`` as ``num_steps``.
            Defaults to 1000, the value that used to be hard-coded.
        rescale_steps: Forwarded to ``SinusoidalPosEmb`` as ``rescale_steps``.
            Defaults to 40000, the value that used to be hard-coded.
    """

    def __init__(self, dim, num_steps=1000, rescale_steps=40000):
        super().__init__()

        self.spe = SinusoidalPosEmb(dim, num_steps, rescale_steps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.Softplus(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        """Return the MLP applied to the sinusoidal features of ``x``."""
        return self.mlp(self.spe(x))

### Timestep Embedding Vectors

In [17]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of (rescaled) step indices.

    Args:
        dim: Requested embedding size. The output has ``2 * (dim // 2)``
            features, i.e. ``dim`` when ``dim`` is even.
        num_steps: Divisor applied to the input before rescaling.
        rescale_steps: Multiplier applied after dividing by ``num_steps``.
    """

    def __init__(self, dim, num_steps, rescale_steps=4000):
        super().__init__()
        self.dim = dim
        self.num_steps = float(num_steps)
        self.rescale_steps = float(rescale_steps)

    def forward(self, x):
        """Return ``cat(sin, cos)`` features for ``x``.

        ``x`` is indexed as ``x[:, None]``, so it is expected to be a tensor
        with a leading batch dimension (presumably 1-D step indices).
        """
        x = x / self.num_steps * self.rescale_steps
        half_dim = self.dim // 2
        # With dim 2 or 3 there is a single frequency and ``half_dim - 1`` is
        # zero; guard the divisor so those sizes no longer raise
        # ZeroDivisionError. Larger sizes are unaffected.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


In [19]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of sequences.

    Args:
        d_model: Size of the last dimension of the inputs (even or odd).
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so trim div_term; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        # Registered as a buffer: it follows the module across devices and
        # carries no gradient.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is expected to be ``(batch, seq_len, d_model)`` with
        ``seq_len <= max_len``, as implied by the slicing of ``self.pe``.
        """
        return x + self.pe[:, : x.size(1)]


Implement class *TransDenoiser* based on Transformer Encoders which implements the diffusion denoiser. From a noisy sequence of characters it produces the clean (denoised) text.

The forward method takes as input a batch of timesteps (one per sequence in the batch) and a batch of sequences of (noisy) characters.
1. It concatenates the timestep embedding of each sequence with each character embedding of that sequence.
2. It passes the sequences through a Transformer encoder consisting of several layers.
3. It maps each output to a vector of class scores (via a linear mapping or an MLP).


In [20]:
class TransDenoiser(nn.Module):
    """Transformer-encoder denoiser for the diffusion model (exercise skeleton).

    The intended behaviour, to be implemented, is:
      1. concatenate the time-step embedding of each sequence with every
         character embedding of that sequence;
      2. run the result through a multi-layer Transformer encoder;
      3. map every output position to a vector of class scores.

    Args:
        char_dim: Size of the character embeddings.
        time_dim: Size of the time-step embeddings.
        num_heads: Number of attention heads of the encoder layers.
        num_layers: Number of encoder layers.
        num_classes: Number of character classes to score.
    """

    def __init__(self, char_dim, time_dim, num_heads, num_layers, num_classes):
        super().__init__()
        # TODO: create the embeddings, the Transformer encoder and the
        # output projection here.

    def forward(self, t, x):
        """Denoise ``x`` given the time steps ``t``.

        Args:
            t: Batch of time steps, one per sequence.
            x: Batch of sequences of (noisy) characters.

        Raises:
            NotImplementedError: Always, until the exercise is completed.
                Failing here is clearer than silently returning ``None`` to
                the diffusion model.
        """
        raise NotImplementedError(
            "TransDenoiser.forward is not implemented yet; "
            "complete the exercise before training."
        )


# Instantiate the denoiser from the hyper-parameters defined earlier.
denoiser = TransDenoiser(
    char_dim=char_dim,
    time_dim=time_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    num_classes=num_classes,
)


In [None]:
# L is read from the first entry of data_shape (presumably the sequence
# length -- confirm against the dataset cell); C is kept for later cells.
C = 1
L = data_shape[0]

current_shape = (L,)

# Wrap the denoiser in the multinomial diffusion process.
model = MultinomialDiffusion(
    num_classes,
    current_shape,
    denoiser,
    timesteps=1000,
    parametrization='x0',
)

In [25]:
# Optimisation settings: learning rate, warm-up length and decay factor.
lr, warmup, gamma = 3e-4, 10, 0.9

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Two schedulers share the optimizer: a linear warm-up and an exponential
# decay with factor gamma.
scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=warmup)
scheduler_epoch = ExponentialLR(optimizer, gamma=gamma)



### Running the experiments

Now you have all the pieces to train the diffusion model!
The training loss is not given explicitly, but the number of bits per character gives you an indication of the perplexity of the model.


In [None]:
# Assemble the experiment from the loaders, model, optimizer and schedulers
# built above, then run it for a fixed number of epochs.
epochs = 100

exp = Experiment(
    'text8',
    model_id='diff',
    optim_id='adam',
    train_loader=train_loader,
    eval_loader=eval_loader,
    model=model,
    optimizer=optimizer,
    scheduler_iter=scheduler_iter,
    scheduler_epoch=scheduler_epoch,
)
exp.run(epochs)

### Generate Examples

Modify the *Experiment* class so that after each eval a batch of sentences is generated and printed.

Bonus: print the batch of sentences after 0,200,400,800,1000 denoising steps




### Improvement: The Diffusion Transformer architecture

The DDiT architecture, see
https://arxiv.org/pdf/2212.09748,
is a modification of Transformer Encoders specialized for denoising diffusion processes.
While it was proposed for images, and as such was presented in terms of continuous diffusion, it has since been adapted to discrete (multinomial) diffusion processes, see
https://aclanthology.org/2023.findings-emnlp.860.pdf or
https://arxiv.org/pdf/2310.16834
for instance.

Replace the *TransDenoiser* architecture with DDiT.
You can take inspiration from the implementation provided in https://github.com/hzy312/DiffusionSL
