Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!

Robust Video Matting (RVM)

Official repository for the paper Robust High-Resolution Video Matting with Temporal Guidance. RVM is specifically designed for robust human video matting. Unlike existing neural models that process frames as independent images, RVM uses a recurrent neural network to process videos with temporal memory. RVM can perform matting in real time on any video without additional inputs. It achieves 4K 76FPS and HD 104FPS on an Nvidia GTX 1080 Ti GPU. The project was developed at ByteDance Inc.


News

  • [Aug 25 2021] Source code and pretrained models are published.
  • [Jul 27 2021] Paper is accepted by WACV 2022.

Showreel

Watch the showreel video (YouTube, Bilibili) to see the model's performance.

All footage in the video is available in Google Drive and Baidu Pan (code: tb3w).


Demo

  • Webcam Demo: Run the model live in your browser. Visualize recurrent states.
  • Colab Demo: Test our model on your own videos with a free GPU.

Download

We recommend MobileNetv3 models for most use cases. ResNet50 models are the larger variant with small performance improvements. Our model is available on various inference frameworks. See inference documentation for more instructions.

Framework, download files, and notes:

  • PyTorch
    Download: rvm_mobilenetv3.pth, rvm_resnet50.pth
    Notes: Official weights for PyTorch. Doc
  • TorchHub
    Download: Nothing to download.
    Notes: Easiest way to use our model in your PyTorch project. Doc
  • TorchScript
    Download: rvm_mobilenetv3_fp32.torchscript, rvm_mobilenetv3_fp16.torchscript, rvm_resnet50_fp32.torchscript, rvm_resnet50_fp16.torchscript
    Notes: For mobile inference, consider exporting int8-quantized models yourself. Doc
  • ONNX
    Download: rvm_mobilenetv3_fp32.onnx, rvm_mobilenetv3_fp16.onnx, rvm_resnet50_fp32.onnx, rvm_resnet50_fp16.onnx
    Notes: Tested on ONNX Runtime with CPU and CUDA backends. Provided models use opset 12. Doc, Exporter.
  • TensorFlow
    Download: rvm_mobilenetv3_tf.zip, rvm_resnet50_tf.zip
    Notes: TensorFlow 2 SavedModel. Doc
  • TensorFlow.js
    Download: rvm_mobilenetv3_tfjs_int8.zip
    Notes: Run the model on the web. Demo, Starter Code
  • CoreML
    Download: rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel, rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel
    Notes: CoreML does not support dynamic resolution. You can export other resolutions yourself. Models require iOS 13+. The s in each file name denotes downsample_ratio. Doc, Exporter

All models are available in Google Drive and Baidu Pan (code: gym7).


PyTorch Example

  1. Install dependencies:
pip install -r requirements_inference.txt
  2. Load the model:
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
  3. To convert videos, we provide a simple conversion API:
from inference import convert_video

convert_video(
    model,                           # The model, can be on any device (cpu or cuda).
    input_source='input.mp4',        # A video file or an image sequence directory.
    output_type='video',             # Choose "video" or "png_sequence"
    output_composition='output.mp4', # File path if video; directory path if png sequence.
    output_video_mbps=4,             # Output video mbps. Not needed for png sequence.
    downsample_ratio=None,           # A hyperparameter to adjust or use None for auto.
    seq_chunk=12,                    # Process n frames at once for better parallelism.
)
  4. Or write your own inference code:
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

# Uses the `model` loaded in step 2.
reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
rec = [None] * 4                                       # Initial recurrent states.
downsample_ratio = 0.25                                # Adjust based on your video.

with torch.no_grad():
    for src in DataLoader(reader):                     # RGB tensor normalized to 0 ~ 1.
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Cycle the recurrent states.
        com = fgr * pha + bgr * (1 - pha)              # Composite to green background.
        writer.write(com)                              # Write frame.
  5. The models and converter API are also available through TorchHub:
import torch

# Load the model.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"

# Converter API.
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
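
The compositing line in the custom inference loop, com = fgr * pha + bgr * (1 - pha), is a per-pixel, per-channel alpha blend. Here is a minimal pure-Python sketch for a single pixel, for illustration only. The helper name composite_pixel and the sample foreground color are hypothetical; the real code applies the same formula to whole tensors with broadcasting.

```python
def composite_pixel(fgr, pha, bgr):
    """Alpha-blend one RGB foreground pixel over a background color.

    fgr, bgr: (r, g, b) tuples with values in 0..1; pha: alpha in 0..1.
    """
    return tuple(f * pha + b * (1 - pha) for f, b in zip(fgr, bgr))


green = (0.47, 1.0, 0.6)  # Same green background as in the example above.
skin = (0.8, 0.6, 0.5)    # Arbitrary foreground color, chosen for illustration.

print(composite_pixel(skin, 1.0, green))  # Fully opaque: foreground only.
print(composite_pixel(skin, 0.0, green))  # Fully transparent: background only.
print(composite_pixel(skin, 0.5, green))  # Soft edge: an even mix of both.
```

Intermediate alpha values are what make hair and motion-blurred edges blend smoothly into the new background.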

Please see the inference documentation for details on the downsample_ratio hyperparameter, more converter arguments, and more advanced usage.
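
As a rough guide to what downsample_ratio controls, the sketch below picks a ratio that scales the longer side of a frame to a fixed target and never upsamples. The function name suggest_downsample_ratio and the 512 px target are assumptions made for illustration; they are not a statement of what downsample_ratio=None selects, so check the inference documentation for the actual behavior.

```python
def suggest_downsample_ratio(height, width, target_long_side=512):
    """Return a ratio that scales the longer side to about target_long_side.

    Never upsamples: the ratio is capped at 1.
    """
    return min(target_long_side / max(height, width), 1.0)


resolutions = {"HD": (1080, 1920), "4K": (2160, 3840), "small": (240, 320)}
for name, (h, w) in resolutions.items():
    ratio = suggest_downsample_ratio(h, w)
    print(f"{name}: ratio={ratio:.3f}, internal size ~{round(w * ratio)}x{round(h * ratio)}")
```

Under this assumed target, HD and 4K land near the 0.25 and 0.125 values used in the speed measurements below.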


Training and Evaluation

Please refer to the training documentation to train and evaluate your own model.


Speed

Speed is measured with inference_speed_test.py for reference.

GPU             dType  HD (1920x1080)  4K (3840x2160)
RTX 3090        FP16   172 FPS         154 FPS
RTX 2060 Super  FP16   134 FPS         108 FPS
GTX 1080 Ti     FP32   104 FPS         74 FPS
  • Note 1: HD uses downsample_ratio=0.25, 4K uses downsample_ratio=0.125. All tests use batch size 1 and frame chunk 1.
  • Note 2: GPUs before the Turing architecture do not support FP16 inference, so GTX 1080 Ti uses FP32.
  • Note 3: We only measure tensor throughput. The provided video conversion script in this repo is expected to be much slower, because it does not utilize hardware video encoding/decoding and does not have the tensor transfer done on parallel threads. If you are interested in implementing hardware video encoding/decoding in Python, please refer to PyNvCodec.
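
To put the table in per-frame terms, the sketch below converts each throughput figure to milliseconds per frame and compares it with a 30 FPS budget. The 30 FPS target is an assumption chosen for illustration. Because only tensor throughput is measured (Note 3), the remaining headroom is what decoding, encoding, and tensor transfer would have to fit into.

```python
# Throughput figures copied from the speed table above (frames per second).
speed_fps = {
    ("RTX 3090", "HD"): 172, ("RTX 3090", "4K"): 154,
    ("RTX 2060 Super", "HD"): 134, ("RTX 2060 Super", "4K"): 108,
    ("GTX 1080 Ti", "HD"): 104, ("GTX 1080 Ti", "4K"): 74,
}


def ms_per_frame(fps):
    """Convert throughput in frames per second to milliseconds per frame."""
    return 1000.0 / fps


REALTIME_BUDGET_MS = ms_per_frame(30)  # Assumed target: 30 FPS playback.

for (gpu, res), fps in speed_fps.items():
    ms = ms_per_frame(fps)
    headroom = REALTIME_BUDGET_MS - ms
    print(f"{gpu:15s} {res}: {ms:5.2f} ms/frame, headroom {headroom:5.2f} ms")
```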

Comments
  • [Questions] - Training Procedure

    Hi,

    I have some questions about the training procedure:

    1. In the paper, you mention training Stage 1 for 15 epochs, while the instructions in the code specify 20 epochs. Is there a reason for this change? Will the results be similar?
    2. I could not get access to Distinctions-646; I had no reply from the authors/maintainers of the dataset. Based on your indicated file structure, I've built a similar dataset, which adds uncertainty to the quality of my training, but it is a risk I am willing to take. So that I have a point of comparison for the stages that do not depend on this dataset (stages 1-3), would you mind sharing your partial PyTorch training weights (stage1/epoch19.pth, stage2/epoch21.pth, and stage3/epoch22.pth)?
    3. What is the minimum resolution you used for the background images during training?

    For the 3rd time, thank you very much for your contribution to the field. It was a brilliant work. Looking forward to your future work.

  • hardsigmoid replacement

    I've been trying to export an onnx model replacing the hardsigmoid operator.

    I have modified the site-packages/torch/onnx/symbolic_opset9.py file this way:

    @parse_args("v")
    def hardswish(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    @parse_args("v")
    def hardsigmoid(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    But I am not sure at all if this is the way to replace them with primitive ops.

    When I export the ONNX model with this change, I still get an error, "OnnxImportException: Unknown type HardSigmoid encountered while parsing layer 396", with the inference engine I am trying to use.
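
    For reference, both activations can be written with only elementwise add, multiply, divide, and clamp, which is the kind of decomposition the question asks about. The following is a plain-Python numeric sketch of the math, not an ONNX symbolic function. The function names are illustrative, and mapping the expressions onto exporter ops (for example Add, Clip, Div, Mul) is an assumption to verify against your exporter and inference engine.

```python
def clamp(x, lo, hi):
    """Clamp x into the closed interval [lo, hi]."""
    return max(lo, min(hi, x))


def hardsigmoid(x):
    """PyTorch-style hard sigmoid: relu6(x + 3) / 6."""
    return clamp(x + 3.0, 0.0, 6.0) / 6.0


def hardswish(x):
    """Hard swish: x scaled by its hard sigmoid."""
    return x * hardsigmoid(x)


for x in (-4.0, -3.0, 0.0, 1.5, 3.0, 4.0):
    print(f"x={x:5.1f}  hardsigmoid={hardsigmoid(x):.4f}  hardswish={hardswish(x):.4f}")
```

    Note that both functions in the snippet from the question still emit a HardSigmoid node, which is consistent with the error persisting after the change.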

  • VideoMatte240K-HD

    If I'm going to train stage 3 and stage 4, the VideoMatte240K-HD data will be used. Is it right to modify the following paths, changing VideoMatte240K_JPEG_SD to VideoMatte240K_JPEG_HD?

    'videomatte': {
        'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
        'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
    },

  • Some questions about training

    1. How can I eliminate or reduce the edge flickering problem? Can I set --seq-length-lr to increase the sequence length, and would that help?
    2. My data has only composite images and no foreground images. Is it possible to remove foreground training and the foreground loss, or is there a better way?
    3. How important is foreground prediction for matting?

    Looking forward to your reply

  • Add Unity example to README?

    Hey there, I just ported RVM to Unity using NatML, an open-source machine learning runtime. I have questions:

    1. Can I make a PR to add a link into the README to a Unity example project demonstrating using RVM?
    2. I published the model under my account on NatML Hub. Would you be interested in signing up on Hub, so that I can transfer the model to you?

    Here's the model on NatML Hub:

    @natsuite/robust-video-matting

  • Not Issue 👉 Few questions

    First of all, thank you for working on this project! It looks much stronger than BMV2!

    1. Will it work on Anaconda and Windows 10 just like BMV2 does, without being more complicated?

    2. Will it support the same hardware, or does it need a much more powerful CPU / GPU compared to BMV2?

    3. Can you please tell me when you will release it again? I missed the first release, so I can't test it while it's still offline. It would be very nice to have it this week, if possible.

    Thanks ahead for the answers, please keep up the good work! ❤

  • Synchronization issues between inferred mask and original video

    Hello! Thanks for the code. I have had some timing issues between the inferred mask output and the original video. I compared them by splitting both my original video and the output mask video into frames. Both give the same number of frames, so the difference may be caused by a misconfiguration on my side. My original video is 30fps and 1080x1920. If you have a suggestion, I would appreciate it.

  • Weird results when use Segmentation Pass for inference

    https://github.com/PeterL1n/RobustVideoMatting/blob/f8a26e27198a93a94bfd06e96b8d5a34d0660f81/inference.py#L127

    I changed this line to use the segmentation pass (using the pretrained weights rvm_mobilenetv3.pth):

    pha, *rec = model(src, *rec, segmentation_pass=True)
    fgr = src * pha
    

    But I got weird mask results, something like the attached image (seg_pass_alpha). Why?

  • Beginner questions about model results

    Hello, I have two questions:

    1. Besides changing downsample_ratio to adjust the matting accuracy, which other parameters can be changed to affect the result?

    2. Does this project place higher requirements on the graphics card? Does the GPU model affect the final results?

    I am running the model in my own project, but the results are not ideal. Thanks again! (Translated from the original Chinese.)

  • FP16 is slower than FP32

    I use pre-trained ONNX model parameters for inference tests (in Python not C++), only onnxruntime, cv2 and numpy libraries, nothing extra. Parameters downloaded from https://github.com/PeterL1n/RobustVideoMatting/releases/: rvm_mobilenetv3_fp32.onnx and rvm_mobilenetv3_fp16.onnx

    Inference is on a 1080x1920 video with downsample_ratio=0.25. FP32 takes about 170 ms per frame, and FP16 takes about 240 ms. Why is FP16 so slow?

    I have set the inputs correctly: src, r1i, r2i, r3i, and r4i are np.array([[[[]]]]) with dtype np.float32 or np.float16, and downsample_ratio is always np.array([0.25], dtype=np.float32).

    I use a CPU (Intel i5) for inference. Is it slow because the CPU does not support FP16 operations?

  • foreground prediction details

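    Hello, I have a question about foreground prediction. In the official web demo, I see that the foreground image predicted by the model contains not only the foreground (the person) but also background details from the input image (non-person pixels). However, the model I trained myself predicts a foreground image that contains only the person, with no background details from the input image. I did not modify any detail of the official code; the only difference is that I used background images I collected myself. At first I suspected that the foreground loss might cover all pixels (where alpha can be any value, not only greater than 0 as stated in the paper), but after checking the code I found no problem; it is consistent with the paper. What causes this difference? Thank you.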

  • Problem with exporting alpha-mask on the Replicate/COG version

    I tried both the replicate page and local COG variants. When predicting with alpha-mask, this error is a constant:

    FileNotFoundError: [Errno 2] No such file or directory: 'alpha-mask.mp4'

    Exporting with green-screen and foreground-mask works, however. Maybe this is an issue with mp4 not supporting alpha-transparent video, so it fails?

  • A question about src_sm in the model.py

    I see "x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])" in mobilenetv3.py, and the following in model.py:

        f1, f2, f3, f4 = self.backbone(src_sm)
        ...
        hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

    This means the input src_sm of the decoder has not been normalized. Is that your intention?

  • Performance using grayscale images

    Hey there,

    Thanks for this amazing tool! Does anybody know how well it performs on grayscale images? I want to use it in the dark, and I have an infrared camera.

    If retraining is required, how many GPU hours do you think are necessary?

    :)

  • Slow inference and low GPU use.

    I am running inference.py and it runs at ~4.2 it/s. It barely loads my RTX 2060 (0-13% use). The inference_speed_test script gives me ~33.2 it/s on the same model and video settings. Changing the --workers setting for convert_video() did nothing. Am I missing something? How can I run inference faster using the full hardware potential?

    Thanks.
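
    One possible contributor, per Note 3 in the Speed section, is that the conversion script does not overlap frame reading with model execution. The sketch below shows the general read-ahead pattern using only the Python standard library. The helper name prefetch is hypothetical and not part of this repo, and whether it helps depends on where the bottleneck actually is.

```python
import queue
import threading


def prefetch(iterable, max_items=8):
    """Yield items from iterable while a background thread reads ahead."""
    buffer = queue.Queue(maxsize=max_items)
    done = object()  # Sentinel marking the end of the stream.

    def worker():
        for item in iterable:
            buffer.put(item)  # Blocks when the buffer is full.
        buffer.put(done)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = buffer.get()
        if item is done:
            return
        yield item


# Stand-in for a video reader: any iterable of frames works.
frames = (f"frame-{i}" for i in range(5))
processed = [frame.upper() for frame in prefetch(frames)]
print(processed)
```

    In a real pipeline the iterable would be the frame reader and the loop body would run the model, so reading the next frames proceeds while the current one is being processed.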

  • How do I display results locally using a local camera

    I really want to display the result locally using a local camera. I notice that inference_utils.py uses pims.PyAVVideoReader to get information about the video, so I would like to know how to read from the local camera instead.
