diff --git a/.gitignore b/.gitignore
index 598912f..9c85796 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,11 @@
-/photos/output.mp4
+photos/output.mp4
+outputs/
+
+.env/
+venv/
+__pycache__/
+
+*.DS_Store
+
+checkpoints/*
+!checkpoints/.gitkeep
\ No newline at end of file
diff --git a/README.md b/README.md
index 832c1e8..0e7b53e 100644
--- a/README.md
+++ b/README.md
@@ -70,17 +70,29 @@ Argument list:
 
 # Testing exported model
 
 The following script creates an MP4 video of interpolated frames between 2 input images:
 ```
-python inference.py "model_path" "img1" "img2" [--save_path SAVE_PATH] [--gpu] [--fp16] [--frames FRAMES] [--fps FPS]
+python inference.py "model_path" "img1" "img2" [--save_path SAVE_PATH] [--device DEVICE_NAME] [--fp16] [--frames FRAMES] [--fps FPS]
 ```
 * ```model_path``` Path to the exported TorchScript checkpoint
 * ```img1``` Path to the first image
 * ```img2``` Path to the second image
 * ```--save_path SAVE_PATH``` Path to save the interpolated frames as a video, if absent it will be saved in the same directory as ```img1``` is located and named ```output.mp4```
-* ```--gpu``` Whether to attempt to use GPU for predictions
+* ```--device``` Choose between `gpu`, `cpu` and `auto`; `auto` is the default and picks the best available device
 * ```--fp16``` Whether to use fp16 for calculations, speeds inference up on GPUs with tensor cores
 * ```--frames FRAMES``` Number of frames to interpolate between the input images
 * ```--fps FPS``` FPS of the output video
 
+Or try the **Gradio application**. First, install Gradio in your local environment:
+
+```sh
+pip install gradio
+```
+
+Then run the application:
+
+```sh
+python gradio_app.py [--port PORT] [--debug] [--share]
+```
+
 ### Results on the 2 example photos from original repository:
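For context on the new `--device auto` default documented above: device resolution is delegated to Hugging Face `accelerate` (see `select_device()` in the inference.py diff below). A minimal sketch of what `auto` resolves to, assuming `accelerate` is installed:

```python
# Minimal sketch: Accelerator().device resolves to CUDA when available,
# otherwise MPS on Apple Silicon, otherwise CPU. select_device() below
# additionally falls back from MPS to CPU, since this project does not
# support MPS yet.
from accelerate import Accelerator

print(Accelerator().device)  # prints e.g. "cuda", "mps" or "cpu"
```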
diff --git a/checkpoints/.gitkeep b/checkpoints/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/gradio_app.py b/gradio_app.py
new file mode 100644
index 0000000..f7de945
--- /dev/null
+++ b/gradio_app.py
@@ -0,0 +1,125 @@
+import os
+import gradio as gr
+
+from inference import inference, select_device, load_model
+
+this_dir = os.path.dirname(__file__)
+output_folder_path = os.path.join(this_dir, "outputs")
+# The default save path points into outputs/; create it up front, since
+# cv2.VideoWriter will not create missing directories.
+os.makedirs(output_folder_path, exist_ok=True)
+
+output_count = 1
+continue_counting = True
+
+device = select_device("auto")
+model = None
+
+with gr.Blocks() as app:
+ with gr.Row():
+ with gr.Column():
+ model_path = gr.FileExplorer("**/*.pt", root_dir=os.path.join(this_dir, "checkpoints"), file_count="single", label="Model path")
+ device_dropdown = gr.Dropdown(["auto", "gpu", "cpu"], value="auto", interactive=True, label="Device")
+ with gr.Row():
+ fp16 = gr.Checkbox(value=False, label="fp16")
+ frames = gr.Number(18, precision=0, minimum=1, label="Number of frames to interpolate")
+ fps = gr.Number(10, precision=0, minimum=1, label="FPS")
+ with gr.Row():
+ frame1 = gr.Image(
+ value=os.path.join(this_dir, "photos/one.png"),
+ type="filepath",
+ label="Frame 1",
+ )
+ frame2 = gr.Image(
+ value=os.path.join(this_dir, "photos/two.png"),
+ type="filepath",
+ label="Frame 2",
+ )
+ save_path = gr.Text(os.path.join(output_folder_path, f"{output_count}.mp4"), interactive=True, label="Save path")
+ with gr.Column():
+ video = gr.Video(label="Output video")
+ with gr.Row():
+ generate_button = gr.Button(value="Generate video", variant="primary")
+
+    def model_loading_trigger(checkpoint_path: str | None, is_fp16: bool):
+        global model
+        if checkpoint_path is not None:
+            model = load_model(checkpoint_path, device, is_fp16)
+            print("Model loaded/re-loaded.")
+        else:
+            model = None
+            print("Model unloaded.")
+
+ model_path.change(model_loading_trigger, inputs=[model_path, fp16])
+ fp16.change(model_loading_trigger, inputs=[model_path, fp16])
+
+
+ def device_loading_trigger(device_name: str, checkpoint_path: str, is_fp16: bool):
+ global device
+ force_settings = {}
+ try:
+ device = select_device(device_name)
+ except RuntimeError:
+ device = select_device("auto")
+            print(f"Device '{device_name}' is not available. Switching to {device}.")
+ force_settings["value"] = "auto"
+ model_loading_trigger(checkpoint_path, is_fp16)
+ return gr.update(**force_settings)
+
+ device_dropdown.change(device_loading_trigger, inputs=[device_dropdown, model_path, fp16], outputs=device_dropdown)
+
+
+    def trigger_inference(frame1, frame2, save_path, frames, fps, is_fp16):
+        if model is None:
+            print("The model must be loaded first to generate the video.")
+            return [gr.update(), gr.update()]
+        inference(model, device, frame1, frame2, save_path, frames, fps, is_fp16)
+        # Point the video player at the file that was just written.
+        video_update = gr.update(value=save_path)
+        if continue_counting:
+            global output_count
+            output_count += 1
+            next_save_file_name = os.path.join(output_folder_path, f"{output_count}.mp4")
+            return [gr.update(value=next_save_file_name), video_update]
+        return [gr.update(), video_update]
+
+ generate_button.click(
+ lambda: gr.update(interactive=False),
+ outputs=generate_button,
+ ).then(
+ trigger_inference,
+ inputs=[frame1, frame2, save_path, frames, fps, fp16],
+ outputs=[save_path, video],
+ ).then(
+ lambda: gr.update(interactive=True),
+ outputs=generate_button,
+ )
+
+
+    def stop_counting_trigger():
+        # Once the user edits the save path manually, stop auto-numbering outputs.
+        global continue_counting
+        continue_counting = False
+
+    # .input fires only on user edits; .change would also fire on the
+    # programmatic save-path update after each generation and stop counting.
+    save_path.input(stop_counting_trigger)
+
+
+if __name__ == "__main__":
+ from argparse import ArgumentParser
+
+ parser = ArgumentParser()
+ parser.add_argument("--port", "-p", type=int)
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument(
+ "--max_text_chunk_multiplier",
+ type=int,
+ default=10,
+        help="multiplier for Pillow's MAX_TEXT_CHUNK; works around 'ValueError: Decompressed Data Too Large'",
+ )
+
+ args = parser.parse_args()
+
+ from PIL import PngImagePlugin
+ # issue: https://github.com/python-pillow/Pillow/issues/5610
+ PngImagePlugin.MAX_TEXT_CHUNK *= args.max_text_chunk_multiplier
+
+ app.launch(
+ debug=args.debug,
+ server_port=args.port,
+ share=args.share,
+ )
\ No newline at end of file
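The `generate_button` wiring above chains three events so the button is greyed out while the video is being generated. A self-contained sketch of the same disable-run-re-enable pattern, assuming Gradio 4.x (`demo` and `slow_job` are illustrative names, not part of the diff):

```python
import time

import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Run")
    out = gr.Textbox(label="Result")

    def slow_job():
        time.sleep(2)  # stand-in for the actual video generation
        return "done"

    # Disable the button, run the job, then re-enable the button;
    # each .then() step starts after the previous one finishes.
    btn.click(
        lambda: gr.update(interactive=False),
        outputs=btn,
    ).then(
        slow_job,
        outputs=out,
    ).then(
        lambda: gr.update(interactive=True),
        outputs=btn,
    )

demo.launch()
```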
diff --git a/inference.py b/inference.py
index 344d8bb..e25a4e8 100644
--- a/inference.py
+++ b/inference.py
@@ -5,28 +5,39 @@
import numpy as np
import cv2
+from accelerate import Accelerator
+
from util import load_image
-def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
+def select_device(device: str) -> torch.device:
+    if device == "auto":
+        accelerator = Accelerator()
+        if str(accelerator.device).startswith("mps"):
+            print("WARNING: MPS device not (yet?) supported by the project.")
+            return torch.device("cpu")
+        return accelerator.device
+    if device == "gpu":
+        # torch has no "gpu" device type; map the CLI choice to CUDA and
+        # fail fast when CUDA is unavailable so callers can fall back.
+        if not torch.cuda.is_available():
+            raise RuntimeError("GPU requested but CUDA is not available.")
+        return torch.device("cuda")
+    return torch.device(device)
+
+
+def load_model(model_path: str, device: torch.device, half: bool) -> torch.nn.Module:
model = torch.jit.load(model_path, map_location='cpu')
+ if half:
+ model = model.half()
+ model = model.to(device)
+ return model
+
+
+def inference(model: torch.nn.Module, device: torch.device, img1: str, img2: str, save_path: str, inter_frames: int, fps: int, half: bool):
model.eval()
img_batch_1, crop_region_1 = load_image(img1)
img_batch_2, crop_region_2 = load_image(img2)
img_batch_1 = torch.from_numpy(img_batch_1).permute(0, 3, 1, 2)
img_batch_2 = torch.from_numpy(img_batch_2).permute(0, 3, 1, 2)
-
- if not half:
- model.float()
-
- if gpu and torch.cuda.is_available():
- if half:
- model = model.half()
- else:
- model.float()
- model = model.cuda()
-
+
if save_path == 'img1 folder':
save_path = os.path.join(os.path.split(img1)[0], 'output.mp4')
@@ -51,12 +62,11 @@ def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
x0 = results[start_i]
x1 = results[end_i]
- if gpu and torch.cuda.is_available():
- if half:
- x0 = x0.half()
- x1 = x1.half()
- x0 = x0.cuda()
- x1 = x1.cuda()
+ if half:
+ x0 = x0.half()
+ x1 = x1.half()
+ x0 = x0.to(device)
+ x1 = x1.to(device)
dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
@@ -95,11 +105,15 @@ def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
parser.add_argument('img2', type=str, help='Path to the second image')
parser.add_argument('--save_path', type=str, default='img1 folder', help='Path to save the interpolated frames')
- parser.add_argument('--gpu', action='store_true', help='Use GPU')
+    parser.add_argument('--device', type=str, choices=["auto", "gpu", "cpu"], default="auto", help='Device to run inference on (default: auto, which picks the best available device)')
parser.add_argument('--fp16', action='store_true', help='Use FP16')
parser.add_argument('--frames', type=int, default=18, help='Number of frames to interpolate')
parser.add_argument('--fps', type=int, default=10, help='FPS of the output video')
args = parser.parse_args()
-
- inference(args.model_path, args.img1, args.img2, args.save_path, args.gpu, args.frames, args.fps, args.fp16)
+
+ device = select_device(args.device)
+ model = load_model(args.model_path, device, args.fp16)
+ print(f"Loaded model sent to {device}.")
+
+ inference(model, device, args.img1, args.img2, args.save_path, args.frames, args.fps, args.fp16)
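Since inference.py is now split into `select_device`, `load_model` and `inference`, it can also be driven as a library. A minimal sketch under the signatures above; the checkpoint path is hypothetical:

```python
# Minimal sketch of the refactored inference API.
from inference import inference, load_model, select_device

device = select_device("auto")  # resolves cuda/cpu via accelerate
model = load_model("checkpoints/model.pt", device, half=False)  # hypothetical checkpoint path
inference(
    model,
    device,
    "photos/one.png",      # first keyframe (ships with the repo)
    "photos/two.png",      # second keyframe
    "photos/output.mp4",   # where to write the interpolated video
    inter_frames=18,
    fps=10,
    half=False,
)
```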
diff --git a/requirements.txt b/requirements.txt
index 8cf85fc..15be0f3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
opencv-python
torch
+accelerate
tqdm
\ No newline at end of file