Skip to content

Different behaviors using model_path #142

@yuanze1024

Description

@yuanze1024

LPIPS is a useful tool to evaluate the quality of generated images.

However, I met some questions using a cached vgg ckpt to save downloading time across different environments:

model = lpips.LPIPS(net="vgg", model_path="./assets/vgg16-397923af.pth").to(get_device())

The cached ckpt is exactly the same as newly downloaded.

root@xx:path# md5sum /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
463aeb51ba5e122501bd03f4ad6d5374  /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
root@xx:path# md5sum ./assets/vgg16-397923af.pth
463aeb51ba5e122501bd03f4ad6d5374  ./assets/vgg16-397923af.pth

A test snippet is recorded below:

import lpips
import torch

torch.set_default_device('cuda' if torch.cuda.is_available() else 'cpu')

def test_fixed_lpips():
    """Test the fixed LPIPS implementation"""
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    print("=== Testing Fixed LPIPS ===")
    
    # Create test data
    torch.manual_seed(42)
    img1 = torch.rand(4, 3, 256, 256).to(device)
    img2 = torch.rand(4, 3, 256, 256).to(device)
    
    # Original LPIPS (reference standard)
    original_lpips = lpips.LPIPS(net='vgg').to(device)
    
    # Fixed LPIPS
    fixed_lpips = lpips.LPIPS(net="vgg", model_path="./assets/vgg16-397923af.pth").to(device)
    
    print("1. Same image test:")
    orig_same = original_lpips(img1, img1).mean()
    fixed_same = fixed_lpips(img1, img1).mean()
    print(f"   Original LPIPS: {orig_same.item():.8f}")
    print(f"   Fixed LPIPS: {fixed_same.item():.8f}")
    
    print("\n2. Different image test:")
    orig_diff = original_lpips(img1, img2).mean()
    fixed_diff = fixed_lpips(img1, img2, normalize=True).mean()
    print(f"   Original LPIPS: {orig_diff.item():.8f}")
    print(f"   Fixed LPIPS: {fixed_diff.item():.8f}")
    
    print("\n3. Similar images test:")
    similar_img = img1 + 0.001 * torch.randn_like(img1)
    orig_similar = original_lpips(img1, similar_img).mean()
    fixed_similar = fixed_lpips(img1, similar_img).mean()
    print(f"   Original LPIPS: {orig_similar.item():.8f}")
    print(f"   Fixed LPIPS: {fixed_similar.item():.8f}")
    
    print("\n4. Simulated training scenario test:")
    # Simulate the situation in your training code
    orig_mv_frames = torch.rand(4, 3, 256, 256).to(device)
    recon_mv_frames = torch.rand(4, 3, 256, 256).to(device)
    
    # Training call (without input_range)
    lpips_train = fixed_lpips(orig_mv_frames, recon_mv_frames).mean()
    print(f"   Training call (without input_range): {lpips_train.item():.8f}")

    print("\n5. Stability test over multiple runs:")
    for i in range(5):
        test_loss = fixed_lpips(img1, img1).mean()
        print(f"   Run {i+1} (same image): {test_loss.item():.8f}")
    
if __name__ == '__main__':
    test_fixed_lpips()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions