364 changes: 185 additions & 179 deletions monai/inferers/utils.py
@@ -24,6 +24,7 @@
from monai.utils import (
BlendMode,
PytorchPadMode,
Range,
convert_data_type,
convert_to_dst_type,
ensure_tuple,
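
The only functional addition in this import hunk is `Range`; the second hunk wraps each stage of `sliding_window_inference` in `with Range("..."):` blocks so every stage appears as a named span in the profiler timeline. As a rough sketch of the behaviour those blocks rely on — not MONAI's actual `Range` implementation — an NVTX-style range context manager can be as small as the following; `simple_nvtx_range` is a hypothetical stand-in:

```python
# Minimal sketch of an NVTX range context manager, assuming the same usage
# pattern as the `with Range("name"):` blocks in the hunk below.
# `simple_nvtx_range` is illustrative only, not monai.utils.Range.
from contextlib import contextmanager

import torch


@contextmanager
def simple_nvtx_range(name: str):
    # Requires a CUDA-enabled PyTorch build; ranges show up in Nsight Systems.
    torch.cuda.nvtx.range_push(name)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()  # close the range even if the body raises


# usage: the wrapped section becomes one labelled span on the profiler timeline
# with simple_nvtx_range("SW_preprocessing"):
#     ...  # preprocessing work
```
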
@@ -135,189 +136,194 @@ def sliding_window_inference(
- input must be channel-first and have a batch dim, supports N-D sliding window.

"""
buffered = buffer_steps is not None and buffer_steps > 0
num_spatial_dims = len(inputs.shape) - 2
if buffered:
if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
if buffer_dim < 0:
buffer_dim += num_spatial_dims
overlap = ensure_tuple_rep(overlap, num_spatial_dims)
for o in overlap:
if o < 0 or o >= 1:
raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.")
compute_dtype = inputs.dtype

# determine image spatial size and batch size
# Note: all input images must have the same image size and batch size
batch_size, _, *image_size_ = inputs.shape
device = device or inputs.device
sw_device = sw_device or inputs.device

temp_meta = None
if isinstance(inputs, MetaTensor):
temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0]
roi_size = fall_back_tuple(roi_size, image_size_)

# in case that image size is smaller than roi size
image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
pad_size = []
for k in range(len(inputs.shape) - 1, 1, -1):
diff = max(roi_size[k - 2] - inputs.shape[k], 0)
half = diff // 2
pad_size.extend([half, diff - half])
if any(pad_size):
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

# Store all slices
scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=not buffered)

num_win = len(slices) # number of windows per image
total_slices = num_win * batch_size # total number of windows
windows_range: Iterable
if not buffered:
non_blocking = False
windows_range = range(0, total_slices, sw_batch_size)
else:
slices, n_per_batch, b_slices, windows_range = _create_buffered_slices(
slices, batch_size, sw_batch_size, buffer_dim, buffer_steps
)
non_blocking, _ss = torch.cuda.is_available(), -1
for x in b_slices[:n_per_batch]:
if x[1] < _ss: # detect overlapping slices
non_blocking = False
break
_ss = x[2]

# Create window-level importance map
valid_patch_size = get_valid_patch_size(image_size, roi_size)
if valid_patch_size == roi_size and (roi_weight_map is not None):
importance_map_ = roi_weight_map
else:
try:
valid_p_size = ensure_tuple(valid_patch_size)
importance_map_ = compute_importance_map(
valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype
)
if len(importance_map_.shape) == num_spatial_dims and not process_fn:
importance_map_ = importance_map_[None, None] # adds batch, channel dimensions
except Exception as e:
raise RuntimeError(
f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n"
"Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
) from e
importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0]

# stores output and count map
output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore
# for each patch
for slice_g in tqdm(windows_range) if progress else windows_range:
slice_range = range(slice_g, min(slice_g + sw_batch_size, b_slices[b_s][0] if buffered else total_slices))
unravel_slice = [
[slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
for idx in slice_range
]
if sw_batch_size > 1:
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
else:
win_data = inputs[unravel_slice[0]].to(sw_device)
if with_coord:
seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs) # batched patch
else:
seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch

# convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
if process_fn:
seg_tuple, w_t = process_fn(seg_tuple, win_data, importance_map_)
else:
w_t = importance_map_
if len(w_t.shape) == num_spatial_dims:
w_t = w_t[None, None]
w_t = w_t.to(dtype=compute_dtype, device=sw_device)
with Range("SW_preprocessing"):
buffered = buffer_steps is not None and buffer_steps > 0
num_spatial_dims = len(inputs.shape) - 2
if buffered:
c_start, c_end = b_slices[b_s][1:]
if not sw_device_buffer:
k = seg_tuple[0].shape[1] # len(seg_tuple) > 1 is currently ignored
sp_size = list(image_size)
sp_size[buffer_dim] = c_end - c_start
sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)]
for p, s in zip(seg_tuple[0], unravel_slice):
offset = s[buffer_dim + 2].start - c_start
s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
s[0] = slice(0, 1)
sw_device_buffer[0][s] += p * w_t
b_i += len(unravel_slice)
if b_i < b_slices[b_s][0]:
continue
if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
if buffer_dim < 0:
buffer_dim += num_spatial_dims
overlap = ensure_tuple_rep(overlap, num_spatial_dims)
for o in overlap:
if o < 0 or o >= 1:
raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.")
compute_dtype = inputs.dtype

# determine image spatial size and batch size
# Note: all input images must have the same image size and batch size
batch_size, _, *image_size_ = inputs.shape
device = device or inputs.device
sw_device = sw_device or inputs.device

temp_meta = None
if isinstance(inputs, MetaTensor):
temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0]
roi_size = fall_back_tuple(roi_size, image_size_)

# in case that image size is smaller than roi size
image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
pad_size = []
for k in range(len(inputs.shape) - 1, 1, -1):
diff = max(roi_size[k - 2] - inputs.shape[k], 0)
half = diff // 2
pad_size.extend([half, diff - half])
if any(pad_size):
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

# Store all slices
scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=not buffered)

num_win = len(slices) # number of windows per image
total_slices = num_win * batch_size # total number of windows
windows_range: Iterable
if not buffered:
non_blocking = False
windows_range = range(0, total_slices, sw_batch_size)
else:
sw_device_buffer = list(seg_tuple)

for ss in range(len(sw_device_buffer)):
b_shape = sw_device_buffer[ss].shape
seg_chns, seg_shape = b_shape[1], b_shape[2:]
z_scale = None
if not buffered and seg_shape != roi_size:
z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)]
w_t = F.interpolate(w_t, seg_shape, mode=_nearest_mode)
if len(output_image_list) <= ss:
output_shape = [batch_size, seg_chns]
output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size)
# allocate memory to store the full output and the count for overlapping parts
new_tensor: Callable = torch.empty if non_blocking else torch.zeros # type: ignore
output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device))
count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
w_t_ = w_t.to(device)
for __s in slices:
if z_scale is not None:
__s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale))
count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_
if buffered:
o_slice = [slice(None)] * len(inputs.shape)
o_slice[buffer_dim + 2] = slice(c_start, c_end)
img_b = b_s // n_per_batch # image batch index
o_slice[0] = slice(img_b, img_b + 1)
if non_blocking:
output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
else:
output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
else:
sw_device_buffer[ss] *= w_t
sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
_compute_coords(unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss])
sw_device_buffer = []
if buffered:
b_s += 1

if non_blocking:
torch.cuda.current_stream().synchronize()

# account for any overlapping sections
for ss in range(len(output_image_list)):
output_image_list[ss] /= count_map_list.pop(0)

# remove padding if image_size smaller than roi_size
if any(pad_size):
kwargs.update({"pad_size": pad_size})
for ss, output_i in enumerate(output_image_list):
zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
final_slicing: list[slice] = []
for sp in range(num_spatial_dims):
si = num_spatial_dims - sp - 1
slice_dim = slice(
int(round(pad_size[sp * 2] * zoom_scale[si])),
int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])),
slices, n_per_batch, b_slices, windows_range = _create_buffered_slices(
slices, batch_size, sw_batch_size, buffer_dim, buffer_steps
)
non_blocking, _ss = torch.cuda.is_available(), -1
for x in b_slices[:n_per_batch]:
if x[1] < _ss: # detect overlapping slices
non_blocking = False
break
_ss = x[2]

# Create window-level importance map
valid_patch_size = get_valid_patch_size(image_size, roi_size)
if valid_patch_size == roi_size and (roi_weight_map is not None):
importance_map_ = roi_weight_map
else:
try:
valid_p_size = ensure_tuple(valid_patch_size)
importance_map_ = compute_importance_map(
valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype
)
final_slicing.insert(0, slice_dim)
output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)]
if len(importance_map_.shape) == num_spatial_dims and not process_fn:
importance_map_ = importance_map_[None, None] # adds batch, channel dimensions
except Exception as e:
raise RuntimeError(
f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n"
"Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
) from e
importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0]

with Range("SW_PatchForward_Loop"):
# stores output and count map
output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore
# for each patch
for slice_g in tqdm(windows_range) if progress else windows_range:
with Range("SW_Model_Computation"):
slice_range = range(slice_g, min(slice_g + sw_batch_size, b_slices[b_s][0] if buffered else total_slices))
unravel_slice = [
[slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
for idx in slice_range
]
if sw_batch_size > 1:
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
else:
win_data = inputs[unravel_slice[0]].to(sw_device)
if with_coord:
seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs) # batched patch
else:
seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch

final_output = _pack_struct(output_image_list, dict_keys)
if temp_meta is not None:
final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0]
else:
final_output = convert_to_dst_type(final_output, inputs, device=device)[0]
with Range("SW_PatchProb_Processing"):
# convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
if process_fn:
seg_tuple, w_t = process_fn(seg_tuple, win_data, importance_map_)
else:
w_t = importance_map_
if len(w_t.shape) == num_spatial_dims:
w_t = w_t[None, None]
w_t = w_t.to(dtype=compute_dtype, device=sw_device)
if buffered:
c_start, c_end = b_slices[b_s][1:]
if not sw_device_buffer:
k = seg_tuple[0].shape[1] # len(seg_tuple) > 1 is currently ignored
sp_size = list(image_size)
sp_size[buffer_dim] = c_end - c_start
sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)]
for p, s in zip(seg_tuple[0], unravel_slice):
offset = s[buffer_dim + 2].start - c_start
s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
s[0] = slice(0, 1)
sw_device_buffer[0][s] += p * w_t
b_i += len(unravel_slice)
if b_i < b_slices[b_s][0]:
continue
else:
sw_device_buffer = list(seg_tuple)

for ss in range(len(sw_device_buffer)):
b_shape = sw_device_buffer[ss].shape
seg_chns, seg_shape = b_shape[1], b_shape[2:]
z_scale = None
if not buffered and seg_shape != roi_size:
z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)]
w_t = F.interpolate(w_t, seg_shape, mode=_nearest_mode)
if len(output_image_list) <= ss:
output_shape = [batch_size, seg_chns]
output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size)
# allocate memory to store the full output and the count for overlapping parts
new_tensor: Callable = torch.empty if non_blocking else torch.zeros # type: ignore
output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device))
count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
w_t_ = w_t.to(device)
for __s in slices:
if z_scale is not None:
__s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale))
count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_
if buffered:
o_slice = [slice(None)] * len(inputs.shape)
o_slice[buffer_dim + 2] = slice(c_start, c_end)
img_b = b_s // n_per_batch # image batch index
o_slice[0] = slice(img_b, img_b + 1)
if non_blocking:
output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
else:
output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
else:
sw_device_buffer[ss] *= w_t
sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
_compute_coords(unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss])
sw_device_buffer = []
if buffered:
b_s += 1

with Range("SW_Postprocessing"):
if non_blocking:
torch.cuda.current_stream().synchronize()

# account for any overlapping sections
for ss in range(len(output_image_list)):
output_image_list[ss] /= count_map_list.pop(0)

# remove padding if image_size smaller than roi_size
if any(pad_size):
kwargs.update({"pad_size": pad_size})
for ss, output_i in enumerate(output_image_list):
zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
final_slicing: list[slice] = []
for sp in range(num_spatial_dims):
si = num_spatial_dims - sp - 1
slice_dim = slice(
int(round(pad_size[sp * 2] * zoom_scale[si])),
int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])),
)
final_slicing.insert(0, slice_dim)
output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)]

final_output = _pack_struct(output_image_list, dict_keys)
if temp_meta is not None:
final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0]
else:
final_output = convert_to_dst_type(final_output, inputs, device=device)[0]

return final_output # type: ignore
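
A note on the window grid used throughout the hunk: `_get_scan_interval` and `dense_patch_slices` step each spatial dimension by roughly `roi_size * (1 - overlap)` and keep the last window flush with the image border. A simplified, self-contained sketch of that placement logic for a single axis (an illustration, not the MONAI helpers themselves):

```python
# Simplified sketch of sliding-window start positions along one spatial axis,
# approximating what _get_scan_interval / dense_patch_slices compute above.
def window_starts(image_len: int, roi_len: int, overlap: float) -> list[int]:
    if not 0.0 <= overlap < 1.0:
        raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.")
    interval = max(int(roi_len * (1.0 - overlap)), 1)  # step between window origins
    starts = list(range(0, max(image_len - roi_len, 0) + 1, interval))
    if starts[-1] + roi_len < image_len:
        starts.append(image_len - roi_len)  # keep the last window flush with the border
    return starts


# e.g. window_starts(96, 64, 0.25) -> [0, 32]: two 64-wide windows covering 96 voxels
```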

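The window-level importance map built by `compute_importance_map` (mode="gaussian") gives the centre of each patch more weight than its borders, so overlapping predictions blend smoothly instead of producing seams. A rough equivalent, assuming a separable Gaussian with sigma proportional to `sigma_scale * roi_size` per axis (illustrative only, not MONAI's implementation):

```python
import torch


# Rough stand-in for a Gaussian window-weight map: highest weight at the patch
# centre, decaying towards the borders. Illustrative, not compute_importance_map.
def gaussian_weight_map(roi_size: tuple[int, ...], sigma_scale: float = 0.125) -> torch.Tensor:
    axes = []
    for dim in roi_size:
        centre = (dim - 1) / 2.0
        sigma = max(sigma_scale * dim, 1e-3)
        coords = torch.arange(dim, dtype=torch.float32)
        axes.append(torch.exp(-0.5 * ((coords - centre) / sigma) ** 2))
    weights = axes[0]
    for axis in axes[1:]:
        weights = weights[..., None] * axis  # outer product builds the N-D separable map
    return weights


# e.g. gaussian_weight_map((64, 64, 32)).shape -> torch.Size([64, 64, 32])
```
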
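The paired `output_image_list` / `count_map_list` bookkeeping is what turns the weighted patches into a seamless volume: each prediction is added into the output multiplied by the importance map, the same weights are accumulated into a count map, and the final division cancels the overlap. A compact sketch of that pattern, reduced to a single-channel 2-D case with hypothetical names:

```python
import torch


def blend_patches(
    image_size: tuple[int, int],
    patches: list[tuple[tuple[slice, slice], torch.Tensor]],
    weight: torch.Tensor,
) -> torch.Tensor:
    # Accumulate weighted predictions plus the weights themselves, then normalise,
    # mirroring the output_image_list / count_map_list logic in the hunk above.
    output = torch.zeros(image_size)
    count = torch.zeros(image_size)
    for region, pred in patches:
        output[region] += pred * weight
        count[region] += weight
    return output / count.clamp_min(1e-8)  # overlapping pixels are averaged by their total weight
```
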
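Finally, the `pad_size` loop near the top of the hunk and the `final_slicing` loop near the bottom are mirror images: spatial dims smaller than the ROI are padded symmetrically before inference (in `F.pad`'s last-dim-first order) and the padding is sliced off the stitched output afterwards. A minimal sketch of that pairing for a channel-first tensor, ignoring the output-rescaling case (`pad_to_roi` and `unpad` are hypothetical helpers):

```python
import torch
import torch.nn.functional as F


def pad_to_roi(x: torch.Tensor, roi_size: tuple[int, ...]) -> tuple[torch.Tensor, list[int]]:
    # x is channel-first with a batch dim: (N, C, *spatial). F.pad expects pad
    # amounts for the last spatial dim first, hence the reverse walk over dims,
    # just like the pad_size loop in the hunk above.
    pad_size: list[int] = []
    for k in range(x.ndim - 1, 1, -1):
        diff = max(roi_size[k - 2] - x.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    return F.pad(x, pad_size), pad_size


def unpad(y: torch.Tensor, pad_size: list[int], spatial_size: tuple[int, ...]) -> torch.Tensor:
    # Undo pad_to_roi: pad_size pairs are stored last-dim-first, so index them
    # from the end when rebuilding the per-dim crop.
    slices = [slice(None), slice(None)]
    for d, dim_len in enumerate(spatial_size):
        start = pad_size[2 * (len(spatial_size) - 1 - d)]
        slices.append(slice(start, start + dim_len))
    return y[tuple(slices)]


# usage: padded, pads = pad_to_roi(torch.zeros(1, 1, 50, 50), (64, 64))
#        restored = unpad(padded, pads, (50, 50))  # back to shape (1, 1, 50, 50)
```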