-
Notifications
You must be signed in to change notification settings - Fork 519
Open
Description
LPIPS is a useful tool to evaluate the quality of generated images.
However, I ran into a problem when loading a locally cached VGG checkpoint, which I use to avoid re-downloading the weights in every environment:
model = lpips.LPIPS(net="vgg", model_path="./assets/vgg16-397923af.pth").to(get_device())
The cached checkpoint is identical to a freshly downloaded one; their MD5 checksums match:
root@xx:path# md5sum /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
463aeb51ba5e122501bd03f4ad6d5374 /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
root@xx:path# md5sum ./assets/vgg16-397923af.pth
463aeb51ba5e122501bd03f4ad6d5374 ./assets/vgg16-397923af.pth
The test script below compares the default model with the one loaded from the cached checkpoint:
import lpips
import torch
# Pick the process-wide default device for tensor factories that are not
# given an explicit device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    torch.set_default_device('cuda')
else:
    torch.set_default_device('cpu')
def _lpips_score(model, a, b):
    """Return the batch-mean LPIPS distance between ``a`` and ``b`` as a float.

    Inputs are assumed to lie in [0, 1]. ``normalize=True`` asks lpips to
    rescale them to [-1, 1], the range its networks expect (per the lpips
    README). Gradients are disabled because the score is only printed.
    """
    with torch.no_grad():
        return model(a, b, normalize=True).mean().item()


def test_fixed_lpips():
    """Compare a default LPIPS-VGG model with one built from a local ``model_path``.

    Both models score identical inputs with identical flags, so any
    difference in the printed values comes from how each model's weights
    were loaded.

    NOTE(review): in the lpips package, ``model_path`` appears to select the
    LPIPS linear-layer weights (by default ``weights/v0.1/vgg.pth``), not the
    torchvision VGG16 backbone checkpoint ``vgg16-397923af.pth``. If so, the
    backbone file's parameter names match nothing in the LPIPS module, the
    linear layers keep their initial values, and the two models disagree.
    Reusing a cached backbone is probably done by pointing ``TORCH_HOME`` at
    the cache directory instead. Confirm this against the lpips and
    torch.hub sources.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("=== Testing Fixed LPIPS ===")

    # Create test data; torch.rand samples uniformly from [0, 1).
    torch.manual_seed(42)
    img1 = torch.rand(4, 3, 256, 256).to(device)
    img2 = torch.rand(4, 3, 256, 256).to(device)

    # Original LPIPS (reference standard): weights resolved by lpips itself.
    original_lpips = lpips.LPIPS(net='vgg').to(device)
    # "Fixed" LPIPS: same architecture, explicit local model_path.
    fixed_lpips = lpips.LPIPS(net="vgg", model_path="./assets/vgg16-397923af.pth").to(device)

    def report(title, a, b):
        # Score both models with the same flags. The previous version passed
        # normalize=True to only one model in test 2, so that comparison
        # was not like-for-like.
        print(title)
        print(f" Original LPIPS: {_lpips_score(original_lpips, a, b):.8f}")
        print(f" Fixed LPIPS: {_lpips_score(fixed_lpips, a, b):.8f}")

    report("1. Same image test:", img1, img1)
    report("\n2. Different image test:", img1, img2)

    # Clamp so the perturbed image stays in the [0, 1] range normalize=True assumes.
    similar_img = (img1 + 0.001 * torch.randn_like(img1)).clamp(0.0, 1.0)
    report("\n3. Similar images test:", img1, similar_img)

    print("\n4. Simulated training scenario test:")
    # Simulate the situation in the training code. If the real frames are in
    # [0, 1] as they are here, the training call needs normalize=True as well.
    orig_mv_frames = torch.rand(4, 3, 256, 256).to(device)
    recon_mv_frames = torch.rand(4, 3, 256, 256).to(device)
    lpips_train = _lpips_score(fixed_lpips, orig_mv_frames, recon_mv_frames)
    print(f" Training call (normalize=True): {lpips_train:.8f}")

    print("\n5. Stability test over multiple runs:")
    for run in range(1, 6):
        test_loss = _lpips_score(fixed_lpips, img1, img1)
        print(f" Run {run} (same image): {test_loss:.8f}")
# Script entry point: run the LPIPS comparison only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    test_fixed_lpips()
Metadata
Assignees
Labels
No labels