|
18 | 18 | from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator |
19 | 19 | from urllib.parse import urlparse |
20 | 20 |
|
| 21 | +import numpy as np |
21 | 22 | import requests |
22 | 23 | import torch |
23 | 24 | from torch.utils.model_zoo import tqdm |
@@ -483,3 +484,39 @@ def verify_str_arg( |
483 | 484 | raise ValueError(msg) |
484 | 485 |
|
485 | 486 | return value |
| 487 | + |
| 488 | + |
| 489 | +def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: |
| 490 | + """Read file in .pfm format. Might contain either 1 or 3 channels of data. |
| 491 | +
|
| 492 | + Args: |
| 493 | + file_name (str): Path to the file. |
| 494 | + slice_channels (int): Number of channels to slice out of the file. |
| 495 | + Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc. |
| 496 | + """ |
| 497 | + |
| 498 | + with open(file_name, "rb") as f: |
| 499 | + header = f.readline().rstrip() |
| 500 | + if header not in [b"PF", b"Pf"]: |
| 501 | + raise ValueError("Invalid PFM file") |
| 502 | + |
| 503 | + dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline()) |
| 504 | + if not dim_match: |
| 505 | + raise Exception("Malformed PFM header.") |
| 506 | + w, h = (int(dim) for dim in dim_match.groups()) |
| 507 | + |
| 508 | + scale = float(f.readline().rstrip()) |
| 509 | + if scale < 0: # little-endian |
| 510 | + endian = "<" |
| 511 | + scale = -scale |
| 512 | + else: |
| 513 | + endian = ">" # big-endian |
| 514 | + |
| 515 | + data = np.fromfile(f, dtype=endian + "f") |
| 516 | + |
| 517 | + pfm_channels = 3 if header == b"PF" else 1 |
| 518 | + |
| 519 | + data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1) |
| 520 | + data = np.flip(data, axis=1) # flip on h dimension |
| 521 | + data = data[:slice_channels, :, :] |
| 522 | + return data.astype(np.float32) |
0 commit comments