Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,11 @@
/photos/output.mp4
photos/output.mp4
outputs/

.env/
venv/
__pycache__/

*.DS_Store

checkpoints/*
!checkpoints/.gitkeep
16 changes: 14 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,29 @@ Argument list:
# Testing exported model
The following script creates an MP4 video of interpolated frames between 2 input images:
```
python inference.py "model_path" "img1" "img2" [--save_path SAVE_PATH] [--gpu] [--fp16] [--frames FRAMES] [--fps FPS]
python inference.py "model_path" "img1" "img2" [--save_path SAVE_PATH] [--device DEVICE_NAME] [--fp16] [--frames FRAMES] [--fps FPS]
```
* ```model_path``` Path to the exported TorchScript checkpoint
* ```img1``` Path to the first image
* ```img2``` Path to the second image
* ```--save_path SAVE_PATH``` Path to save the interpolated frames as a video, if absent it will be saved in the same directory as ```img1``` is located and named ```output.mp4```
* ```--gpu``` Whether to attempt to use GPU for predictions
* ```--device``` Choose between `gpu`, `cpu` and `auto`. The latter will be selected as the default setting.
* ```--fp16``` Whether to use fp16 for calculations, speeds inference up on GPUs with tensor cores
* ```--frames FRAMES``` Number of frames to interpolate between the input images
* ```--fps FPS``` FPS of the output video

Or try the **gradio application**. First you need to install gradio on your local environment:

```sh
pip install gradio
```

Then, you can run the application:

```sh
python gradio_app.py gradio_app.py [--port PORT] [--debug] [--share]
```

### Results on the 2 example photos from original repository:
<p float="left">
<img src="photos/one.png" width="384px" />
Expand Down
Empty file added checkpoints/.gitkeep
Empty file.
125 changes: 125 additions & 0 deletions gradio_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os
import gradio as gr

from inference import inference, select_device, load_model

this_dir = os.path.dirname(__file__)
output_folder_path = os.path.join(this_dir, "outputs")

output_count = 1
continue_counting = True

device = select_device("auto")
model = None

with gr.Blocks() as app:
with gr.Row():
with gr.Column():
model_path = gr.FileExplorer("**/*.pt", root_dir=os.path.join(this_dir, "checkpoints"), file_count="single", label="Model path")
device_dropdown = gr.Dropdown(["auto", "gpu", "cpu"], value="auto", interactive=True, label="Device")
with gr.Row():
fp16 = gr.Checkbox(value=False, label="fp16")
frames = gr.Number(18, precision=0, minimum=1, label="Number of frames to interpolate")
fps = gr.Number(10, precision=0, minimum=1, label="FPS")
with gr.Row():
frame1 = gr.Image(
value=os.path.join(this_dir, "photos/one.png"),
type="filepath",
label="Frame 1",
)
frame2 = gr.Image(
value=os.path.join(this_dir, "photos/two.png"),
type="filepath",
label="Frame 2",
)
save_path = gr.Text(os.path.join(output_folder_path, f"{output_count}.mp4"), interactive=True, label="Save path")
with gr.Column():
video = gr.Video(label="Output video")
with gr.Row():
generate_button = gr.Button(value="Generate video", variant="primary")

def model_loading_trigger(checkpoint_path: str | None, is_fp16: bool):
if checkpoint_path is not None:
global model, device
model = load_model(checkpoint_path, device, is_fp16)
print("Model loaded/re-loaded.")
else:
model = None
print("Model unloaded.")

model_path.change(model_loading_trigger, inputs=[model_path, fp16])
fp16.change(model_loading_trigger, inputs=[model_path, fp16])


def device_loading_trigger(device_name: str, checkpoint_path: str, is_fp16: bool):
global device
force_settings = {}
try:
device = select_device(device_name)
except RuntimeError:
device = select_device("auto")
print(f"{device_name} cannot be loaded. Switch to device {device}")
force_settings["value"] = "auto"
model_loading_trigger(checkpoint_path, is_fp16)
return gr.update(**force_settings)

device_dropdown.change(device_loading_trigger, inputs=[device_dropdown, model_path, fp16], outputs=device_dropdown)


def trigger_inference(frame1, frame2, save_path, frames, fps, is_fp16):
if model is None:
print("The model must be loaded first to generate the video.")
return [gr.update(), gr.update()]
inference(model, device, frame1, frame2, save_path, frames, fps, is_fp16)
if continue_counting:
global output_count
output_count += 1
next_save_file_name = os.path.join(output_folder_path, f"{output_count}.mp4")
video_update = gr.update(value=save_path)
return [gr.update(value=next_save_file_name), video_update]
return [gr.update(), video_update]

generate_button.click(
lambda: gr.update(interactive=False),
outputs=generate_button,
).then(
trigger_inference,
inputs=[frame1, frame2, save_path, frames, fps, fp16],
outputs=[save_path, video],
).then(
lambda: gr.update(interactive=True),
outputs=generate_button,
)


def stop_counting_trigger():
continue_counting = False

save_path.change(stop_counting_trigger)


if __name__ == "__main__":
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--port", "-p", type=int)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--share", action="store_true")
parser.add_argument(
"--max_text_chunk_multiplier",
type=int,
default=10,
help="for preventing the issue: 'ValueError: Decompressed Data Too Large'",
)

args = parser.parse_args()

from PIL import PngImagePlugin
# issue: https://github.com/python-pillow/Pillow/issues/5610
PngImagePlugin.MAX_TEXT_CHUNK *= args.max_text_chunk_multiplier

app.launch(
debug=args.debug,
server_port=args.port,
share=args.share,
)
56 changes: 35 additions & 21 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,39 @@
import numpy as np
import cv2

from accelerate import Accelerator

from util import load_image


def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
def select_device(device: str) -> torch.device:
if device == "auto":
accelerator = Accelerator()
if accelerator.device.__str__().startswith("mps"):
print("WARNING: MPS device not (yet?) supported by the project.")
return torch.device("cpu")
else:
return accelerator.device
return torch.device(device)



def load_model(model_path: str, device: torch.device, half: bool) -> torch.nn.Module:
model = torch.jit.load(model_path, map_location='cpu')
if half:
model = model.half()
model = model.to(device)
return model


def inference(model: torch.nn.Module, device: torch.device, img1: str, img2: str, save_path: str, inter_frames: int, fps: int, half: bool):
model.eval()
img_batch_1, crop_region_1 = load_image(img1)
img_batch_2, crop_region_2 = load_image(img2)

img_batch_1 = torch.from_numpy(img_batch_1).permute(0, 3, 1, 2)
img_batch_2 = torch.from_numpy(img_batch_2).permute(0, 3, 1, 2)

if not half:
model.float()

if gpu and torch.cuda.is_available():
if half:
model = model.half()
else:
model.float()
model = model.cuda()


if save_path == 'img1 folder':
save_path = os.path.join(os.path.split(img1)[0], 'output.mp4')

Expand All @@ -51,12 +62,11 @@ def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
x0 = results[start_i]
x1 = results[end_i]

if gpu and torch.cuda.is_available():
if half:
x0 = x0.half()
x1 = x1.half()
x0 = x0.cuda()
x1 = x1.cuda()
if half:
x0 = x0.half()
x1 = x1.half()
x0 = x0.to(device)
x1 = x1.to(device)

dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])

Expand Down Expand Up @@ -95,11 +105,15 @@ def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
parser.add_argument('img2', type=str, help='Path to the second image')

parser.add_argument('--save_path', type=str, default='img1 folder', help='Path to save the interpolated frames')
parser.add_argument('--gpu', action='store_true', help='Use GPU')
parser.add_argument('--device', type=str, choices=["auto", "gpu", "cpu"], default="auto", help='Device to run the inference (default: choose the optimal device)')
parser.add_argument('--fp16', action='store_true', help='Use FP16')
parser.add_argument('--frames', type=int, default=18, help='Number of frames to interpolate')
parser.add_argument('--fps', type=int, default=10, help='FPS of the output video')

args = parser.parse_args()

inference(args.model_path, args.img1, args.img2, args.save_path, args.gpu, args.frames, args.fps, args.fp16)

device = select_device(args.device)
model = load_model(args.model_path, device, args.fp16)
print(f"Loaded model sent to {device}.")

inference(model, device, args.img1, args.img2, args.save_path, args.frames, args.fps, args.fp16)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
opencv-python
torch
accelerate
tqdm