riccorl / super-slomo-tf2

Tensorflow 2 implementation of Super SloMo paper

super-slomo-tf2's Introduction

Super Slo Mo TF2

Tensorflow 2 implementation of "Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation" by Jiang H., Sun D., Jampani V., Yang M., Learned-Miller E. and Kautz J.

Setup

The code is based on TensorFlow 2.1. To install all the required dependencies, use one of the following:

Conda
conda env create -f environment.yml
source activate super-slomo
Pip
python3 -m venv super-slomo
source super-slomo/bin/activate
pip install -r requirements.txt

Inference

You can download the pre-trained model here. It was trained for 259 epochs on the adobe240fps dataset and uses the single-frame prediction mode.

To generate a slomo video run:

python super-slomo/inference.py path/to/source/video path/to/slomo/video --model path/to/checkpoint --n_frames 20 --fps 480

Train

Data Extraction

Before the training phase, the frames must be extracted from the original video sources. This code uses the adobe240fps dataset to train the model. To extract frames, run the following command:

python super-slomo/frame_extraction.py path/to/dataset path/to/destination 

It will use ffmpeg to extract the frames and put them in the destination folder, grouped in folders of 12 consecutive frames. If ffmpeg is not available, it falls back to the slower OpenCV.
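
The grouping into folders of 12 consecutive frames can be pictured with the minimal sketch below. It is not the code from frame_extraction.py: the helper name group_frames is hypothetical, and dropping a trailing incomplete group is an assumption made only for this illustration.

```python
def group_frames(frame_names, group_size=12):
    """Split an ordered list of frame names into groups of `group_size`.

    Assumption for this sketch: a trailing group with fewer than
    `group_size` frames is discarded.
    """
    groups = [
        frame_names[i:i + group_size]
        for i in range(0, len(frame_names), group_size)
    ]
    return [group for group in groups if len(group) == group_size]


# 30 frames named like ffmpeg's %04d.jpg pattern
frames = [f"{i:04d}.jpg" for i in range(1, 31)]
groups = group_frames(frames)
```

With 30 input frames this yields two complete groups and leaves the last six frames out.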

For info run:

python super-slomo/frame_extraction.py -h

Train the model

You can start to train the model by running:

python super-slomo/train.py path/to/frames --model path/to/checkpoints --epochs 100 --batch-size 32

If the model directory contains a checkpoint, the model will continue to train from that epoch until the total number of epochs provided is reached.
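
In other words, --epochs is a total, not a number of additional epochs. A minimal sketch of that arithmetic, assuming epochs are counted from 1 and using a hypothetical helper name (this is not taken from train.py):

```python
def epochs_to_run(total_epochs, last_checkpoint_epoch=0):
    """Return the epoch numbers still to be trained, counting from 1.

    `last_checkpoint_epoch` is 0 when no checkpoint exists yet.
    """
    return range(last_checkpoint_epoch + 1, total_epochs + 1)


# A checkpoint saved at epoch 40 and --epochs 100 leaves epochs 41..100.
remaining = list(epochs_to_run(100, last_checkpoint_epoch=40))
```

If the checkpoint epoch already equals the requested total, the range is empty and no further training would happen.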

You can also visualize the training with TensorBoard, using the following command:

tensorboard --logdir log --port 6006

and go to http://localhost:6006.

For info run:

python super-slomo/train.py -h

Multi-frame model

The model above predicts only one frame at a time, due to hardware limitations. If you have access to powerful GPUs, you can predict more frames with a single sample (as in the original paper). To start, clone the multi-frame branch:

git clone --branch multi-frame https://github.com/Riccorl/Super-SloMo-tf2.git 

Then follow the instructions above to set up the environment and extract the frames. The training command has one additional parameter, --frames, which controls the number of frames to predict:

python super-slomo/train.py path/to/frames --model path/to/checkpoints --epochs 100 --batch-size 32 --frames 9

super-slomo-tf2's Issues

The error message "Cannot convert a symbolic Tensor to a numpy array", TensorFlow and NumPy versions

Any suggestions?
TensorFlow Version: 2.1.0
TensorFlow Addons Version: 0.8.2
NumPy Version: 1.21.5

The NotImplementedError you encountered is raised inside tf.meshgrid, which is called by the dense_image_warp function of tensorflow_addons. According to the traceback, TensorFlow's _constant_if_small helper calls np.prod on a shape that is a symbolic tensor, and that tensor cannot be converted into a NumPy array, hence the message "Cannot convert a symbolic Tensor to a numpy array". This is a known issue in TensorFlow 2.x and is usually caused by an incompatible combination of TensorFlow and NumPy versions.

To address this issue, try the following steps:

1. Update TensorFlow and TensorFlow Addons: Ensure you are using the latest versions of TensorFlow and TensorFlow Addons. Sometimes, such issues are resolved in updated versions.

2. Check TensorFlow and NumPy versions: Some TensorFlow versions may not be fully compatible with specific NumPy versions, so you might need to experiment with different version combinations. The working environment listed further down this page, for example, pairs TensorFlow 2.4.1 with NumPy 1.18.5.
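
A quick way to spot a suspicious combination before running inference is to compare the installed NumPy version against a threshold. The sketch below is only an illustration: the helper names are hypothetical, and the default threshold of 1.20 is an assumption based on common reports of this error with older TensorFlow 2.x releases, not something confirmed by this repository.

```python
def version_tuple(version):
    """Turn '1.21.5' into (1, 21, 5), ignoring suffixes such as 'rc1'."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def numpy_too_new(numpy_version, first_bad=(1, 20)):
    """Return True if `numpy_version` is at or above the assumed threshold."""
    return version_tuple(numpy_version) >= first_bad


reported = numpy_too_new("1.21.5")  # version from the issue report
working = numpy_too_new("1.18.5")   # version from the working conda list
```

In a real environment the argument would come from numpy.__version__; if the check returns True, pinning an older NumPy release is one thing to try.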


(super-slomo) PS E:\vedio-interpolation\Super-SolMo\Super-SloMo-tf2-master> python super-slomo/inference.py resources/UAV/1.mp4 resources/UAV/1output.mp4 --model chckpnt259/ckpt-259 --n_frames 20 --fps 480
The following is the error output:
2023-12-31 19:29:04.119197: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
WARNING:tensorflow:AutoGraph could not transform <function load_dataset.. at 0x000002316D37A1F8> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output.
Cause: No module named 'tensorflow_core.estimator'
WARNING:tensorflow:AutoGraph could not transform <function load_frames at 0x000002316D0DE708> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output.
Cause: No module named 'tensorflow_core.estimator'
WARNING:tensorflow:AutoGraph could not transform <function dense_image_warp at 0x000002316D227E58> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output.
Cause: No module named 'tensorflow_core.estimator'
Traceback (most recent call last):
File "super-slomo/inference.py", line 164, in
main()
File "super-slomo/inference.py", line 159, in main
predict(video_path, model_path, output_path, args.n_frames, args.fps)
File "super-slomo/inference.py", line 120, in predict
predictions, _ = model(frames + ([f],), training=False)
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 822, in call
outputs = self.call(cast_inputs, *args, **kwargs)
File "E:\jupyterDir\vedio-interpolation\Super-SolMo\Super-SloMo-tf2-master\super-slomo\models\slomo_model.py", line 31, in call
f_t0, v_t0, f_t1, v_t1, g_i0_ft0, g_i1_ft1 = self.optical_flow(optical_input)
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 822, in call
outputs = self.call(cast_inputs, *args, **kwargs)
File "E:\jupyterDir\vedio-interpolation\Super-SolMo\Super-SloMo-tf2-master\super-slomo\models\layers.py", line 166, in call
g_i0_ft0 = self.backwarp_layer_t0([frames_0, f_t0_t])
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 822, in call
outputs = self.call(cast_inputs, *args, **kwargs)
File "E:\jupyterDir\vedio-interpolation\Super-SolMo\Super-SloMo-tf2-master\super-slomo\models\layers.py", line 142, in call
img_backwarp = self.backwarp(image, flow)
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 568, in call
result = self._call(*args, **kwds)
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 615, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 497, in _initialize
*args, **kwds))
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\eager\function.py", line 2389, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\eager\function.py", line 2703, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\eager\function.py", line 2593, in _create_graph_function
capture_by_value=self._capture_by_value),
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 978, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 439, in wrapped_fn
return weak_wrapped_fn().wrapped(*args, **kwds)
File "E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 968, in wrapper
raise e.ag_error_metadata.to_exception(e)
NotImplementedError: in converted code:

E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_addons\image\dense_image_warp.py:235 dense_image_warp
    grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\ops\array_ops.py:3065 meshgrid
    mult_fact = ones(shapes, output_dtype)
E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\ops\array_ops.py:2659 ones
    output = _constant_if_small(one, shape, dtype, name)
E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\ops\array_ops.py:2391 _constant_if_small
    if np.prod(shape) < 1000:
<__array_function__ internals>:6 prod

E:\Miniconda3\envs\super-slomo\lib\site-packages\numpy\core\fromnumeric.py:3052 prod
    keepdims=keepdims, initial=initial, where=where)
E:\Miniconda3\envs\super-slomo\lib\site-packages\numpy\core\fromnumeric.py:86 _wrapreduction
    return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
E:\Miniconda3\envs\super-slomo\lib\site-packages\tensorflow_core\python\framework\ops.py:728 __array__
    " array.".format(self.name))

NotImplementedError: Cannot convert a symbolic Tensor (dense_image_warp/meshgrid/Size_1:0) to a numpy array.

cv2 frame extraction currently broken

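This does not work as expected. If os.system(cmd) does not find ffmpeg, no exception is thrown: os.system only returns the command's exit status, so the except branch with the OpenCV fallback is never reached.
See the relevant Stack Overflow discussion.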

        output_filename = output_dir / video_file.name
        Path(output_filename).mkdir(parents=True, exist_ok=True)
        try:
            cmd = "ffmpeg -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.jpg".format(
                video_file, width, height, output_filename
            )
            os.system(cmd)
        except:
            print("ffmpeg not found, using opencv")
            vidcap = cv2.VideoCapture(str(video_file))
            success, image = vidcap.read()

If you need to bypass ffmpeg, add a raise Exception at the start of the try block:

        try:
            raise Exception("ffmpeg now disabled")
            cmd = "ffmpeg -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.jpg".format(
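
A possible way to make the fallback trigger reliably is to check for the ffmpeg binary explicitly and to inspect the exit status, instead of relying on an exception. This is a minimal sketch, not the repository's code: the function names and the binary parameter are hypothetical, and only the ffmpeg arguments are taken from the snippet above.

```python
import shutil
import subprocess
from pathlib import Path


def build_ffmpeg_cmd(video_file, width, height, output_dir, binary="ffmpeg"):
    """Build the same ffmpeg invocation as above, as an argument list."""
    return [
        binary, "-i", str(video_file),
        "-vf", f"scale={width}:{height}",
        "-vsync", "0", "-qscale:v", "2",
        str(Path(output_dir) / "%04d.jpg"),
    ]


def extract_with_ffmpeg(video_file, width, height, output_dir, binary="ffmpeg"):
    """Return True if ffmpeg ran successfully, False if a fallback is needed."""
    if shutil.which(binary) is None:
        # The binary is not on PATH, so the caller should use OpenCV instead.
        return False
    result = subprocess.run(
        build_ffmpeg_cmd(video_file, width, height, output_dir, binary),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    return result.returncode == 0
```

The caller would then run the cv2.VideoCapture branch whenever extract_with_ffmpeg returns False, which covers both a missing binary and a failed ffmpeg run.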

multi-GPU support still WIP

According to this post, the following is needed:

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():

This alone is still not enough; training then fails with:

File "/home/rac/slomo/multifix/super-slomo/train.py", line 207, in train_step  *
    loss_values = loss_obj.compute_losses(
File "/home/rac/slomo/multifix/super-slomo/models/losses.py", line 122, in compute_losses  *
    p_rec_loss += self.reconstruction_loss(true, pred)
File "/home/rac/slomo/multifix/super-slomo/models/losses.py", line 26, in reconstruction_loss  *
    return self.mae(y_true, y_pred)
File "/home/rac/slomo/tf-cuda.env/lib/python3.10/site-packages/keras/losses.py", line 166, in __call__  **
    reduction = self._get_reduction()
File "/home/rac/slomo/tf-cuda.env/lib/python3.10/site-packages/keras/losses.py", line 217, in _get_reduction
    raise ValueError(

ValueError: Please use `tf.keras.losses.Reduction.SUM` or `tf.keras.losses.Reduction.NONE` for loss reduction when losses are used with `tf.distribute.Strategy` outside of the built-in training loops. You can implement `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch size like:
```
with strategy.scope():
    loss_obj = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
....
    loss = tf.reduce_sum(loss_obj(labels, predictions)) * (1. / global_batch_size)
```
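
The scaling suggested by the error message can be checked with plain arithmetic: each replica sums its per-example losses and divides by the global batch size, and adding the replica results gives the mean over the whole batch. The sketch below uses plain Python lists instead of TensorFlow tensors, and the function name and loss values are made up for illustration.

```python
import math


def scaled_replica_loss(per_example_losses, global_batch_size):
    """Sum the losses seen by one replica and scale by the global batch size."""
    return sum(per_example_losses) / global_batch_size


# Two replicas, each holding half of a global batch of four examples.
replica_a = [1.0, 3.0]
replica_b = [2.0, 6.0]
global_batch_size = len(replica_a) + len(replica_b)

# Summing the scaled replica losses mimics the cross-replica reduction.
combined = (
    scaled_replica_loss(replica_a, global_batch_size)
    + scaled_replica_loss(replica_b, global_batch_size)
)
global_mean = sum(replica_a + replica_b) / global_batch_size
```

Dividing by the per-replica batch size instead would count every example twice in this two-replica case, which is why the default reduction is rejected. In TensorFlow, the same scaling is what tf.nn.compute_average_loss is intended for when the loss object is created with Reduction.NONE.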

frame_extraction frames created per second

inference.py uses OpenCV for frame extraction, where VideoCapture::read might actually process every frame, to the extent that compressed videos have individual frames.

By default, frame_extraction.py uses ffmpeg (if found), and the ffmpeg default settings are (at least on my machine) 2 frames/sec. The frames are 640x360 unless specified otherwise.

Are these good values for training? Are there benefits in targeting the full frame rate? What is the speed loss?

Pre-trained model for the multi-frame prediction mode

Hi,
There is a pre-trained model for using the single frame prediction mode. Is there a pre-trained model available for the multi-frame prediction mode in order to reproduce the results of the paper?
Best Regards,
Fabien

My full conda list with which I got it to work (CPU)

Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
_tflow_select             2.3.0                       mkl  
absl-py                   1.3.0            py38h06a4308_0  
aiohttp                   3.8.3            py38h5eee18b_0  
aiosignal                 1.2.0              pyhd3eb1b0_0  
appdirs                   1.4.4              pyhd3eb1b0_0  
astor                     0.8.1            py38h06a4308_0  
astunparse                1.6.3                      py_0  
async-timeout             4.0.2            py38h06a4308_0  
attrs                     22.1.0           py38h06a4308_0  
black                     19.10b0                    py_0  
blas                      1.0                         mkl  
blinker                   1.4              py38h06a4308_0  
brotlipy                  0.7.0           py38h27cfd23_1003  
c-ares                    1.19.0               h5eee18b_0  
ca-certificates           2023.01.10           h06a4308_0  
cachetools                4.2.2              pyhd3eb1b0_0  
certifi                   2022.12.7        py38h06a4308_0  
cffi                      1.15.1           py38h5eee18b_3  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
click                     8.0.4            py38h06a4308_0  
cryptography              39.0.1           py38h9ce1e76_0  
flit-core                 3.6.0              pyhd3eb1b0_0  
frozenlist                1.3.3            py38h5eee18b_0  
gast                      0.4.0              pyhd3eb1b0_0  
google-auth               2.6.0              pyhd3eb1b0_0  
google-auth-oauthlib      0.4.4              pyhd3eb1b0_0  
google-pasta              0.2.0              pyhd3eb1b0_0  
grpcio                    1.42.0           py38hce63b2e_0  
h5py                      2.10.0           py38hd6299e0_1  
hdf5                      1.10.6               hb1b8bf9_0  
idna                      3.4              py38h06a4308_0  
importlib-metadata        6.0.0            py38h06a4308_0  
intel-openmp              2023.0.0         h9e868ea_25371  
keras                     2.12.0                   pypi_0    pypi
keras-preprocessing       1.1.2              pyhd3eb1b0_0  
ld_impl_linux-64          2.38                 h1181459_1  
libffi                    3.4.2                h6a678d5_6  
libgcc-ng                 11.2.0               h1234567_1  
libgfortran-ng            7.5.0               ha8ba4b0_17  
libgfortran4              7.5.0               ha8ba4b0_17  
libgomp                   11.2.0               h1234567_1  
libprotobuf               3.20.3               he621ea3_0  
libstdcxx-ng              11.2.0               h1234567_1  
markdown                  3.4.1            py38h06a4308_0  
markupsafe                2.1.1            py38h7f8727e_0  
mkl                       2020.2                      256  
mkl-service               2.3.0            py38he904b0f_0  
mkl_fft                   1.3.0            py38h54f3939_0  
mkl_random                1.1.1            py38h0573a6f_0  
multidict                 6.0.2            py38h5eee18b_0  
mypy_extensions           0.4.3            py38h06a4308_1  
ncurses                   6.4                  h6a678d5_0  
numpy                     1.18.5           py38ha1c710e_0  
numpy-base                1.18.5           py38hde5b4d6_0  
oauthlib                  3.2.1            py38h06a4308_0  
opencv-python             4.7.0.72                 pypi_0    pypi
openssl                   1.1.1t               h7f8727e_0  
opt_einsum                3.3.0              pyhd3eb1b0_1  
packaging                 23.0                     pypi_0    pypi
pathspec                  0.10.3           py38h06a4308_0  
pip                       23.0.1           py38h06a4308_0  
protobuf                  3.20.3           py38h6a678d5_0  
pyasn1                    0.4.8              pyhd3eb1b0_0  
pyasn1-modules            0.2.8                      py_0  
pycparser                 2.21               pyhd3eb1b0_0  
pyjwt                     2.4.0            py38h06a4308_0  
pyopenssl                 23.0.0           py38h06a4308_0  
pysocks                   1.7.1            py38h06a4308_0  
python                    3.8.16               h7a1cb2a_3  
python-flatbuffers        2.0                pyhd3eb1b0_0  
readline                  8.2                  h5eee18b_0  
regex                     2022.7.9         py38h5eee18b_0  
requests                  2.28.1           py38h06a4308_1  
requests-oauthlib         1.3.0                      py_0  
rsa                       4.7.2              pyhd3eb1b0_1  
scipy                     1.6.2            py38h91f5cce_0  
setuptools                65.6.3           py38h06a4308_0  
six                       1.16.0             pyhd3eb1b0_1  
sqlite                    3.41.1               h5eee18b_0  
tensorboard               2.11.0           py38h06a4308_0  
tensorboard-data-server   0.6.1            py38h52d8a92_0  
tensorboard-plugin-wit    1.8.1            py38h06a4308_0  
tensorflow                2.4.1           mkl_py38hb2083e0_0  
tensorflow-addons         0.10.0                   pypi_0    pypi
tensorflow-base           2.4.1           mkl_py38h43e0292_0  
tensorflow-estimator      2.4.1              pyheb71bc4_0  
termcolor                 2.1.0            py38h06a4308_0  
tk                        8.6.12               h1ccaba5_0  
toml                      0.10.2             pyhd3eb1b0_0  
tqdm                      4.64.1           py38h06a4308_0  
typed-ast                 1.4.3            py38h7f8727e_1  
typeguard                 3.0.1                    pypi_0    pypi
typing_extensions         4.4.0            py38h06a4308_0  
urllib3                   1.26.14          py38h06a4308_0  
werkzeug                  2.2.2            py38h06a4308_0  
wheel                     0.38.4           py38h06a4308_0  
wrapt                     1.14.1           py38h5eee18b_0  
xz                        5.2.10               h5eee18b_1  
yarl                      1.8.1            py38h5eee18b_0  
zipp                      3.11.0           py38h06a4308_0  
zlib                      1.2.13               h5eee18b_0  

NumPy has upgrade issues; I always use 1.18.4 or 1.18.5.

Opencv-python is difficult.
