diff --git a/.gitignore b/.gitignore
index 598912f..9c85796 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,11 @@
-/photos/output.mp4
+photos/output.mp4
+outputs/
+
+.env/
+venv/
+__pycache__/
+
+*.DS_Store
+
+checkpoints/*
+!checkpoints/.gitkeep
\ No newline at end of file
diff --git a/README.md b/README.md
index 832c1e8..0e7b53e 100644
--- a/README.md
+++ b/README.md
@@ -70,17 +70,29 @@
 Argument list:
 
 # Testing exported model
 The following script creates an MP4 video of interpolated frames between 2 input images:
 ```
-python inference.py "model_path" "img1" "img2" [--save_path SAVE_PATH] [--gpu] [--fp16] [--frames FRAMES] [--fps FPS]
+python inference.py "model_path" "img1" "img2" [--save_path SAVE_PATH] [--device DEVICE_NAME] [--fp16] [--frames FRAMES] [--fps FPS]
 ```
 * ```model_path``` Path to the exported TorchScript checkpoint
 * ```img1``` Path to the first image
 * ```img2``` Path to the second image
 * ```--save_path SAVE_PATH``` Path to save the interpolated frames as a video, if absent it will be saved in the same directory as ```img1``` is located and named ```output.mp4```
-* ```--gpu``` Whether to attempt to use GPU for predictions
+* ```--device``` One of `gpu`, `cpu` or `auto`; defaults to `auto`, which picks the best available device
 * ```--fp16``` Whether to use fp16 for calculations, speeds inference up on GPUs with tensor cores
 * ```--frames FRAMES``` Number of frames to interpolate between the input images
 * ```--fps FPS``` FPS of the output video
 
+Alternatively, try the **gradio application**. First, install gradio in your local environment:
+
+```sh
+pip install gradio
+```
+
+Then run the application:
+
+```sh
+python gradio_app.py [--port PORT] [--debug] [--share]
+```
+
 ### Results on the 2 example photos from original repository:
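For quick reference, a concrete invocation of the CLI documented above might look like this on a CUDA machine (the checkpoint filename is a placeholder; `photos/one.png` and `photos/two.png` are the sample images the gradio app below uses as defaults):

```sh
python inference.py checkpoints/model.pt photos/one.png photos/two.png \
    --device gpu --fp16 --frames 18 --fps 10
```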

diff --git a/checkpoints/.gitkeep b/checkpoints/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/gradio_app.py b/gradio_app.py
new file mode 100644
index 0000000..f7de945
--- /dev/null
+++ b/gradio_app.py
@@ -0,0 +1,131 @@
+import os
+import gradio as gr
+
+from inference import inference, select_device, load_model
+
+this_dir = os.path.dirname(__file__)
+output_folder_path = os.path.join(this_dir, "outputs")
+# Make sure the default output folder exists before the first run.
+os.makedirs(output_folder_path, exist_ok=True)
+
+output_count = 1
+continue_counting = True
+
+device = select_device("auto")
+model = None
+
+with gr.Blocks() as app:
+    with gr.Row():
+        with gr.Column():
+            model_path = gr.FileExplorer("**/*.pt", root_dir=os.path.join(this_dir, "checkpoints"), file_count="single", label="Model path")
+            device_dropdown = gr.Dropdown(["auto", "gpu", "cpu"], value="auto", interactive=True, label="Device")
+            with gr.Row():
+                fp16 = gr.Checkbox(value=False, label="fp16")
+                frames = gr.Number(18, precision=0, minimum=1, label="Number of frames to interpolate")
+                fps = gr.Number(10, precision=0, minimum=1, label="FPS")
+            with gr.Row():
+                frame1 = gr.Image(
+                    value=os.path.join(this_dir, "photos/one.png"),
+                    type="filepath",
+                    label="Frame 1",
+                )
+                frame2 = gr.Image(
+                    value=os.path.join(this_dir, "photos/two.png"),
+                    type="filepath",
+                    label="Frame 2",
+                )
+            save_path = gr.Text(os.path.join(output_folder_path, f"{output_count}.mp4"), interactive=True, label="Save path")
+        with gr.Column():
+            video = gr.Video(label="Output video")
+            with gr.Row():
+                generate_button = gr.Button(value="Generate video", variant="primary")
+
+    def model_loading_trigger(checkpoint_path: str | None, is_fp16: bool):
+        global model
+        if checkpoint_path is not None:
+            model = load_model(checkpoint_path, device, is_fp16)
+            print("Model loaded/re-loaded.")
+        else:
+            model = None
+            print("Model unloaded.")
+
+    model_path.change(model_loading_trigger, inputs=[model_path, fp16])
+    fp16.change(model_loading_trigger, inputs=[model_path, fp16])
+
+    def device_loading_trigger(device_name: str, checkpoint_path: str, is_fp16: bool):
+        global device
+        force_settings = {}
+        try:
+            device = select_device(device_name)
+        except RuntimeError:
+            # The requested device is unavailable: fall back to automatic
+            # selection and reset the dropdown to "auto".
+            device = select_device("auto")
+            print(f"{device_name} could not be selected; falling back to {device}.")
+            force_settings["value"] = "auto"
+        model_loading_trigger(checkpoint_path, is_fp16)
+        return gr.update(**force_settings)
+
+    device_dropdown.change(device_loading_trigger, inputs=[device_dropdown, model_path, fp16], outputs=device_dropdown)
+
+    def trigger_inference(frame1, frame2, save_path, frames, fps, is_fp16):
+        if model is None:
+            print("The model must be loaded first to generate the video.")
+            return [gr.update(), gr.update()]
+        inference(model, device, frame1, frame2, save_path, frames, fps, is_fp16)
+        video_update = gr.update(value=save_path)
+        if continue_counting:
+            global output_count
+            output_count += 1
+            next_save_file_name = os.path.join(output_folder_path, f"{output_count}.mp4")
+            return [gr.update(value=next_save_file_name), video_update]
+        return [gr.update(), video_update]
+
+    # Disable the button while a video is being generated, then re-enable it.
+    generate_button.click(
+        lambda: gr.update(interactive=False),
+        outputs=generate_button,
+    ).then(
+        trigger_inference,
+        inputs=[frame1, frame2, save_path, frames, fps, fp16],
+        outputs=[save_path, video],
+    ).then(
+        lambda: gr.update(interactive=True),
+        outputs=generate_button,
+    )
+
+    def stop_counting_trigger():
+        global continue_counting
+        continue_counting = False
+
+    # Use .input (not .change) so only manual edits stop the automatic
+    # output numbering; .change would also fire on programmatic updates.
+    save_path.input(stop_counting_trigger)
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+
+    parser = ArgumentParser()
+    parser.add_argument("--port", "-p", type=int)
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--share", action="store_true")
+    parser.add_argument(
+        "--max_text_chunk_multiplier",
+        type=int,
+        default=10,
+        help="for preventing the issue: 'ValueError: Decompressed Data Too Large'",
+    )
+
+    args = parser.parse_args()
+
+    # Raise Pillow's PNG text-chunk limit to work around
+    # https://github.com/python-pillow/Pillow/issues/5610
+    from PIL import PngImagePlugin
+    PngImagePlugin.MAX_TEXT_CHUNK *= args.max_text_chunk_multiplier
+
+    app.launch(
+        debug=args.debug,
+        server_port=args.port,
+        share=args.share,
+    )
\ No newline at end of file
diff --git a/inference.py b/inference.py
index 344d8bb..e25a4e8 100644
--- a/inference.py
+++ b/inference.py
@@ -5,28 +5,43 @@
 import numpy as np
 import cv2
 
+from accelerate import Accelerator
+
 from util import load_image
 
 
-def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
+def select_device(device: str) -> torch.device:
+    if device == "auto":
+        accelerator = Accelerator()
+        # Accelerate may pick MPS on Apple Silicon, which this project does not yet support.
+        if str(accelerator.device).startswith("mps"):
+            print("WARNING: MPS device is not yet supported by this project; falling back to CPU.")
+            return torch.device("cpu")
+        return accelerator.device
+    if device == "gpu":
+        # torch has no "gpu" device type; map the CLI choice to CUDA and fail loudly if it is missing.
+        if not torch.cuda.is_available():
+            raise RuntimeError("GPU requested but CUDA is not available.")
+        return torch.device("cuda")
+    return torch.device(device)
+
+
+def load_model(model_path: str, device: torch.device, half: bool) -> torch.nn.Module:
     model = torch.jit.load(model_path, map_location='cpu')
+    if half:
+        model = model.half()
+    model = model.to(device)
+    return model
+
+
+def inference(model: torch.nn.Module, device: torch.device, img1: str, img2: str, save_path: str, inter_frames: int, fps: int, half: bool):
     model.eval()
 
     img_batch_1, crop_region_1 = load_image(img1)
     img_batch_2, crop_region_2 = load_image(img2)
 
     img_batch_1 = torch.from_numpy(img_batch_1).permute(0, 3, 1, 2)
     img_batch_2 = torch.from_numpy(img_batch_2).permute(0, 3, 1, 2)
-
-    if not half:
-        model.float()
-
-    if gpu and torch.cuda.is_available():
-        if half:
-            model = model.half()
-        else:
-            model.float()
-        model = model.cuda()
-
+
     if save_path == 'img1 folder':
         save_path = os.path.join(os.path.split(img1)[0], 'output.mp4')
@@ -51,12 +66,11 @@ def inference(model_path, img1, img2, save_path, gpu, inter_frames, fps, half):
         x0 = results[start_i]
         x1 = results[end_i]
 
-        if gpu and torch.cuda.is_available():
-            if half:
-                x0 = x0.half()
-                x1 = x1.half()
-            x0 = x0.cuda()
-            x1 = x1.cuda()
+        if half:
+            x0 = x0.half()
+            x1 = x1.half()
+        x0 = x0.to(device)
+        x1 = x1.to(device)
 
         dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
 
@@ -95,11 +109,16 @@
     parser.add_argument('img1', type=str, help='Path to the first image')
     parser.add_argument('img2', type=str, help='Path to the second image')
     parser.add_argument('--save_path', type=str, default='img1 folder', help='Path to save the interpolated frames')
-    parser.add_argument('--gpu', action='store_true', help='Use GPU')
+    parser.add_argument('--device', type=str, choices=["auto", "gpu", "cpu"], default="auto", help='Device to run inference on (default: pick the best available device)')
     parser.add_argument('--fp16', action='store_true', help='Use FP16')
    parser.add_argument('--frames', type=int, default=18, help='Number of frames to interpolate')
     parser.add_argument('--fps', type=int, default=10, help='FPS of the output video')
 
     args = parser.parse_args()
-
-    inference(args.model_path, args.img1, args.img2, args.save_path, args.gpu, args.frames, args.fps, args.fp16)
+
+    # Resolve the device once, then reuse it for model loading and inference.
+    device = select_device(args.device)
+    model = load_model(args.model_path, device, args.fp16)
+    print(f"Model loaded on {device}.")
+
+    inference(model, device, args.img1, args.img2, args.save_path, args.frames, args.fps, args.fp16)
diff --git a/requirements.txt b/requirements.txt
index 8cf85fc..15be0f3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 opencv-python
 torch
+accelerate
 tqdm
\ No newline at end of file
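Since `inference.py` now exposes `select_device`, `load_model` and `inference` as importable functions (this is exactly how `gradio_app.py` consumes them), the refactored API can also be scripted directly. A minimal sketch, assuming a checkpoint at the hypothetical path `checkpoints/model.pt`:

```python
# Sketch only: the checkpoint path is hypothetical; the photo paths match
# the repository's sample images. Create the outputs/ folder before running.
from inference import inference, select_device, load_model

device = select_device("auto")  # CUDA when available, otherwise CPU (MPS falls back to CPU)
model = load_model("checkpoints/model.pt", device, half=False)
inference(model, device, "photos/one.png", "photos/two.png",
          "outputs/demo.mp4", inter_frames=18, fps=10, half=False)
```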