
Make-A-Video SD JAX Model Card

A latent diffusion model for text-to-video synthesis.

Try it with an interactive demo on Model Database spaces.

Training code and the PyTorch and FLAX implementations are available here: https://github.com/lopho/makeavid-sd-tpu

This model extends an inpainting latent-diffusion image generation model (Stable Diffusion v1.5 Inpaint) with temporal convolution and temporal self-attention layers ported from Make-A-Video PyTorch.
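The card does not spell out the internals of the temporal layers. As a rough illustration only, the NumPy sketch below shows what "temporal" means here: a 1-D convolution and a single-head self-attention that both operate along the frame axis, with every spatial position treated independently. The function names, the channels-last (B, F, H, W, C) layout, the single attention head, and the residual wiring are all assumptions for illustration, not the layers used in makeavid-sd-tpu.

```python
import numpy as np


def temporal_conv1d(x, kernel, bias):
    """1-D convolution over the frame axis of a (B, F, H, W, C) tensor.

    kernel: (K, C_in, C_out) with odd K; frames are zero-padded so the
    output keeps F frames. bias: (C_out,).
    """
    k, _, c_out = kernel.shape
    pad = k // 2
    frames = x.shape[1]
    xp = np.pad(x, ((0, 0), (pad, pad), (0, 0), (0, 0), (0, 0)))
    out = np.zeros(x.shape[:4] + (c_out,))
    for i in range(k):
        # tap i reads the frame at offset (i - pad) and mixes its channels
        out += xp[:, i:i + frames] @ kernel[i]
    return out + bias


def temporal_self_attention(x, wq, wk, wv, wo):
    """Single-head self-attention across frames, independently per pixel."""
    b, f, h, w, c = x.shape
    # fold spatial positions into the batch so attention only mixes frames
    seq = x.transpose(0, 2, 3, 1, 4).reshape(b * h * w, f, c)
    q, k, v = seq @ wq, seq @ wk, seq @ wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    out = (attn @ v) @ wo
    # unfold back to (B, F, H, W, C)
    return out.reshape(b, h, w, f, c).transpose(0, 3, 1, 2, 4)


rng = np.random.default_rng(0)
B, F, H, W, C, D = 1, 24, 8, 8, 16, 16
x = rng.standard_normal((B, F, H, W, C))
kernel = rng.standard_normal((3, C, C)) * 0.1
bias = np.zeros(C)
wq, wk, wv = (rng.standard_normal((C, D)) * 0.1 for _ in range(3))
wo = rng.standard_normal((D, C)) * 0.1

# assumed residual wiring: each temporal layer adds its output to its input
y = x + temporal_conv1d(x, kernel, bias)
y = y + temporal_self_attention(y, wq, wk, wv, wo)
print(y.shape)  # (1, 24, 8, 8, 16)
```

Because spatial positions are folded into the batch, these layers only exchange information between frames at the same location. Spatial mixing is left to the pretrained image layers of the UNet.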

It was then fine-tuned for ~150k steps on a dataset of 10,000 dance-themed videos, and for an additional ~50k steps with generic videos mixed into the original set.

The model was initialized with weights pretrained by lxj616 on 286 timelapse video clips.

Model Details

Uses

  • Understanding limitations and biases of generative video models
  • Development of educational or creative tools
  • Artistic usage
  • Whatever you want

Limitations

  • Limited knowledge of temporal concepts not seen during training (see Training Data)
  • Emergent flashing lights, most likely because the dance videos used for training include many scenes with bright, neon, and flashing lights
  • The model has only been trained with English captions and will not perform as well in other languages

Training

Training Data

  • S(mall)dance: 10,000 video-caption pairs of dancing videos (as encoded image latents, text embeddings and metadata).
  • small: 7,000 video-caption pairs of general videos (as encoded image latents, text embeddings and metadata).

Training Procedure

  • From each video sample, a random range of 24 consecutive frames is selected
  • Each video is encoded into latent representations of shape 4 x 24 x H/8 x W/8
  • The latent of the first frame of each video is repeated along the frame dimension as additional guidance (referred to as the hint image)
  • The hint latent and the video latent are stacked along the channel dimension to produce a shape of 8 x 24 x H/8 x W/8
  • The last input channel is reserved for masking (not used during training, set to zero)
  • Text prompts are encoded by the CLIP text encoder
  • Video latents with added noise, together with the CLIP-encoded text prompts, are fed into the UNet to predict the added noise
  • The loss is the mean squared error (MSE/L2) between the added noise and the predicted noise
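The steps above can be sketched in NumPy roughly as follows. This is a minimal illustration under several assumptions that are not stated in the card: the noise is added with the standard DDPM forward process using Stable Diffusion v1's "scaled_linear" beta schedule, only the video latents (not the hint) are noised, the channel order is noisy latents, then hint, then mask, and `dummy_unet` is a hypothetical stand-in for the real video UNet.

```python
import numpy as np

FRAMES = 24
rng = np.random.default_rng(0)


def sample_window(video_latents, frames=FRAMES, rng=rng):
    """Pick a random contiguous range of `frames` frames from (4, T, h, w)."""
    t = video_latents.shape[1]
    start = rng.integers(0, t - frames + 1)
    return video_latents[:, start:start + frames]


def scaled_linear_alphas_cumprod(steps=1000, beta_start=0.00085, beta_end=0.012):
    # Stable Diffusion v1 style "scaled_linear" beta schedule (assumed here)
    betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, steps) ** 2
    return np.cumprod(1.0 - betas)


def build_unet_input(latents, noise, t, alphas_cumprod):
    """(4, F, h, w) clean video latents -> (9, F, h, w) UNet input."""
    a = alphas_cumprod[t]
    noisy = np.sqrt(a) * latents + np.sqrt(1.0 - a) * noise
    # hint: the first frame's latent repeated along the frame dimension
    hint = np.repeat(latents[:, :1], latents.shape[1], axis=1)
    # reserved mask channel, all zeros during training
    mask = np.zeros((1,) + latents.shape[1:])
    return np.concatenate([noisy, hint, mask], axis=0)


def dummy_unet(x, t, text_emb):
    """Hypothetical stand-in for the video UNet: returns a noise-shaped tensor."""
    return np.zeros_like(x[:4])


def mse_loss(predicted_noise, noise):
    return np.mean((predicted_noise - noise) ** 2)


video = rng.standard_normal((4, 40, 64, 64))  # 512 x 512 frames -> 64 x 64 latents
latents = sample_window(video)
noise = rng.standard_normal(latents.shape)
alphas = scaled_linear_alphas_cumprod()
t = int(rng.integers(0, len(alphas)))
x = build_unet_input(latents, noise, t, alphas)
text_emb = np.zeros((77, 768))  # CLIP ViT-L/14 text embedding shape used by SD v1.x
loss = mse_loss(dummy_unet(x, t, text_emb), noise)
print(x.shape)  # (9, 24, 64, 64)
```

The 4 + 4 + 1 channel layout matches the 9-channel input of the Stable Diffusion v1.5 inpainting UNet, which would let the pretrained input convolution be reused without resizing.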

Hyperparameters

  • Batch size: 1 x 4
  • Image size: 512 x 512
  • Frame count: 24
  • Optimizer: AdamW (beta_1 = 0.9, beta_2 = 0.999, weight decay = 0.02)
  • Schedule:
    • 2 x 10 epochs: LR warmup for 1 epoch, then held constant at 5e-5 (10,000 samples per ep)
    • 2 x 20 epochs: LR warmup for 1 epoch, then held constant at 5e-5 (10,000 samples per ep)
    • 1 x 9 epochs: LR warmup for 1 epoch to 5e-5 then cosine annealing to 1e-8
    • Additional data mixed in (see Training Data)
    • 1 x 5 epochs: LR warmup for 0.5 epochs to 2.5e-5 then constant (17,000 samples per ep)
    • 1 x 5 epochs: LR warmup for 0.5 epochs to 5e-6 then cosine annealing to 2.5e-6 (17,000 samples per ep)
    • Some restarts were required due to NaNs appearing in the gradients (see training logs)
  • Total update steps: ~200,000
  • Hardware: TPUv4-8 (provided by Google Cloud for the Model Database JAX/Diffusers Sprint Event)
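As a rough cross-check of the schedule above, the Python sketch below converts epochs into update steps and implements a linear-warmup learning rate that is then either held constant or cosine-annealed. It assumes an effective batch of 4 samples per update ("1 x 4"), reads "2 x 10 epochs" as two runs of 10 epochs, and assumes 10,000 samples per epoch for the 9-epoch phase, where the card gives no count. `phase_steps` and `lr_at` are hypothetical helpers, not code from the training repository.

```python
import math

EFFECTIVE_BATCH = 4  # assumption: "1 x 4" means 4 samples per update step

# (runs, epochs per run, samples per epoch); multipliers read as separate runs
PHASES = [
    (2, 10, 10_000),
    (2, 20, 10_000),
    (1, 9, 10_000),   # samples per epoch not stated in the card; 10,000 assumed
    (1, 5, 17_000),
    (1, 5, 17_000),
]


def phase_steps(runs, epochs, samples_per_epoch, batch=EFFECTIVE_BATCH):
    return runs * epochs * samples_per_epoch // batch


def lr_at(step, total_steps, warmup_steps, peak_lr, final_lr=None):
    """Linear warmup to peak_lr, then constant (final_lr=None) or cosine to final_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    if final_lr is None:
        return peak_lr
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return final_lr + 0.5 * (peak_lr - final_lr) * (1 + math.cos(math.pi * progress))


steps = [phase_steps(*p) for p in PHASES]
print(steps)       # [50000, 100000, 22500, 21250, 21250]
print(sum(steps))  # 215000

# last phase: 0.5-epoch warmup to 5e-6, then cosine annealing to 2.5e-6
steps_per_epoch = 17_000 // EFFECTIVE_BATCH  # 4250
total = 5 * steps_per_epoch
warmup = steps_per_epoch // 2
for s in (0, warmup, total // 2, total):
    print(s, lr_at(s, total, warmup, 5e-6, 2.5e-6))
```

Under these assumptions the phases add up to 215,000 steps, the same order as the reported ~200,000. The card does not detail how the run multipliers and NaN restarts map to update steps, so treat this only as an approximation.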

Training statistics are available at Weights and Biases.

Acknowledgements

Citation

@misc{TempoFunk2023,
      author = {Lopho and Chavez, Carlos},
      title = {TempoFunk: Extending latent diffusion image models to Video},
      url = {https://github.com/lopho/makeavid-sd-tpu},
      month = {5},
      year = {2023}
}

This model card was written by Lopho, Chavinlo, and Julian Herrera, and is based on the DALL-E Mini model card.
