diff --git a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py
index e06c8c7268..ffbc5519db 100644
--- a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -7,6 +7,8 @@
 
 # pyre-strict
 
+import torch
+
 from fbgemm_gpu.sll.cpu_sll import (  # noqa F401
     cpu_array_jagged_bmm_jagged_out,
     cpu_dense_jagged_cat_jagged_out,
@@ -21,14 +23,14 @@
     cpu_jagged_jagged_bmm_jagged_out,
     cpu_jagged_self_substraction_jagged_out,
     cpu_jagged_softmax,
-    meta_jagged_dense_elementwise_mul_jagged_out,
-    meta_jagged_self_substraction_jagged_out,
 )
 
 from fbgemm_gpu.sll.meta_sll import (  # noqa F401
     meta_array_jagged_bmm_jagged_out,
     meta_jagged2_softmax,
+    meta_jagged_dense_elementwise_mul_jagged_out,
     meta_jagged_jagged_bmm_jagged_out,
+    meta_jagged_self_substraction_jagged_out,
 )
 
 from fbgemm_gpu.sll.triton_sll import (  # noqa F401
@@ -208,144 +210,131 @@
     """
 )
 
-# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same function
-# however, this is not ideal because in the inference case, we don't need the autograd forward
-# to save the context because we don't need to do backward.
-lib.register(
-    "sll_jagged_dense_bmm",
-    {
-        "CUDA": jagged_dense_bmm,
-        "AutogradCUDA": jagged_dense_bmm,
+# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same
+# function however, this is not ideal because in the inference case, we don't
+# need the autograd forward to save the context because we don't need to do
+# backward.
+
+# pyre-ignore[5]
+sll_cpu_registrations = {
+    "sll_jagged_dense_bmm": {
         "CPU": cpu_jagged_dense_bmm,
         "AutogradCPU": cpu_jagged_dense_bmm,
     },
-)
-
-lib.register(
-    "sll_jagged_jagged_bmm",
-    {
-        "CUDA": jagged_jagged_bmm,
-        "AutogradCUDA": jagged_jagged_bmm,
+    "sll_jagged_jagged_bmm": {
         "CPU": cpu_jagged_jagged_bmm,
         "AutogradCPU": cpu_jagged_jagged_bmm,
     },
-)
-
-lib.register(
-    "sll_dense_jagged_cat_jagged_out",
-    {
-        "CUDA": dense_jagged_cat_jagged_out,
+    "sll_dense_jagged_cat_jagged_out": {
         "CPU": cpu_dense_jagged_cat_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_self_substraction_jagged_out",
-    {
-        "CUDA": triton_jagged_self_substraction_jagged_out,
+    "sll_jagged_self_substraction_jagged_out": {
         "CPU": cpu_jagged_self_substraction_jagged_out,
         "Meta": meta_jagged_self_substraction_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged2_to_padded_dense",
-    {
-        "CUDA": jagged2_to_padded_dense,
-        "AutogradCUDA": jagged2_to_padded_dense,
+    "sll_jagged2_to_padded_dense": {
         "CPU": cpu_jagged2_to_padded_dense,
         "AutogradCPU": cpu_jagged2_to_padded_dense,
     },
-)
-
-lib.register(
-    "sll_jagged_dense_elementwise_mul_jagged_out",
-    {
-        "CUDA": jagged_dense_elementwise_mul_jagged_out,
-        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
+    "sll_jagged_dense_elementwise_mul_jagged_out": {
         "CPU": cpu_jagged_dense_elementwise_mul_jagged_out,
         "AutogradCPU": cpu_jagged_dense_elementwise_mul_jagged_out,
         "Meta": meta_jagged_dense_elementwise_mul_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_softmax",
-    {
-        "CUDA": jagged_softmax,
-        "AutogradCUDA": jagged_softmax,
+    "sll_jagged_softmax": {
         "CPU": cpu_jagged_softmax,
         "AutogradCPU": cpu_jagged_softmax,
     },
-)
-
-lib.register(
-    "sll_jagged2_softmax",
-    {
-        "CUDA": jagged2_softmax,
-        "AutogradCUDA": jagged2_softmax,
+    "sll_jagged2_softmax": {
         "CPU": cpu_jagged2_softmax,
         "AutogradCPU": cpu_jagged2_softmax,
         "AutogradMeta": meta_jagged2_softmax,
     },
-)
-
-lib.register(
-    "sll_array_jagged_bmm_jagged_out",
-    {
-        "CUDA": array_jagged_bmm_jagged_out,
-        "AutogradCUDA": array_jagged_bmm_jagged_out,
+    "sll_array_jagged_bmm_jagged_out": {
         "CPU": cpu_array_jagged_bmm_jagged_out,
         "AutogradCPU": cpu_array_jagged_bmm_jagged_out,
         "AutogradMeta": meta_array_jagged_bmm_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_jagged_bmm_jagged_out",
-    {
-        "CUDA": jagged_jagged_bmm_jagged_out,
-        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
+    "sll_jagged_jagged_bmm_jagged_out": {
         "CPU": cpu_jagged_jagged_bmm_jagged_out,
         "AutogradCPU": cpu_jagged_jagged_bmm_jagged_out,
         "AutogradMeta": meta_jagged_jagged_bmm_jagged_out,
     },
-)
-
-lib.register(
-    "sll_jagged_flash_attention_basic",
-    {
-        "CUDA": jagged_flash_attention_basic,
-        "AutogradCUDA": jagged_flash_attention_basic,
+    "sll_jagged_flash_attention_basic": {
         "CPU": cpu_jagged_flash_attention_basic,
         "AutogradCPU": cpu_jagged_flash_attention_basic,
     },
-)
-
-lib.register(
-    "sll_jagged_dense_elementwise_add",
-    {
-        "CUDA": jagged_dense_elementwise_add,
-        "AutogradCUDA": jagged_dense_elementwise_add,
+    "sll_jagged_dense_elementwise_add": {
         "CPU": cpu_jagged_dense_elementwise_add,
         "AutogradCPU": cpu_jagged_dense_elementwise_add,
     },
-)
-
-lib.register(
-    "sll_jagged_dense_flash_attention",
-    {
-        "CUDA": jagged_dense_flash_attention,
-        "AutogradCUDA": jagged_dense_flash_attention,
+    "sll_jagged_dense_flash_attention": {
         "CPU": cpu_jagged_dense_flash_attention,
         "AutogradCPU": cpu_jagged_dense_flash_attention,
     },
-)
+}
 
-lib.register(
-    "sll_multi_head_jagged_flash_attention",
-    {
+# pyre-ignore[5]
+sll_gpu_registrations = {
+    "sll_jagged_dense_bmm": {
+        "CUDA": jagged_dense_bmm,
+        "AutogradCUDA": jagged_dense_bmm,
+    },
+    "sll_jagged_jagged_bmm": {
+        "CUDA": jagged_jagged_bmm,
+        "AutogradCUDA": jagged_jagged_bmm,
+    },
+    "sll_dense_jagged_cat_jagged_out": {
+        "CUDA": dense_jagged_cat_jagged_out,
+    },
+    "sll_jagged_self_substraction_jagged_out": {
+        "CUDA": triton_jagged_self_substraction_jagged_out,
+    },
+    "sll_jagged2_to_padded_dense": {
+        "CUDA": jagged2_to_padded_dense,
+        "AutogradCUDA": jagged2_to_padded_dense,
+    },
+    "sll_jagged_dense_elementwise_mul_jagged_out": {
+        "CUDA": jagged_dense_elementwise_mul_jagged_out,
+        "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
+    },
+    "sll_jagged_softmax": {
+        "CUDA": jagged_softmax,
+        "AutogradCUDA": jagged_softmax,
+    },
+    "sll_jagged2_softmax": {
+        "CUDA": jagged2_softmax,
+        "AutogradCUDA": jagged2_softmax,
+    },
+    "sll_array_jagged_bmm_jagged_out": {
+        "CUDA": array_jagged_bmm_jagged_out,
+        "AutogradCUDA": array_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_jagged_bmm_jagged_out": {
+        "CUDA": jagged_jagged_bmm_jagged_out,
+        "AutogradCUDA": jagged_jagged_bmm_jagged_out,
+    },
+    "sll_jagged_flash_attention_basic": {
+        "CUDA": jagged_flash_attention_basic,
+        "AutogradCUDA": jagged_flash_attention_basic,
+    },
+    "sll_jagged_dense_elementwise_add": {
+        "CUDA": jagged_dense_elementwise_add,
+        "AutogradCUDA": jagged_dense_elementwise_add,
+    },
+    "sll_jagged_dense_flash_attention": {
+        "CUDA": jagged_dense_flash_attention,
+        "AutogradCUDA": jagged_dense_flash_attention,
+    },
+    "sll_multi_head_jagged_flash_attention": {
         "CUDA": multi_head_jagged_flash_attention,
         "AutogradCUDA": multi_head_jagged_flash_attention,
     },
-)
+}
+
+for op_name, dispatches in sll_cpu_registrations.items():
+    lib.register(op_name, dispatches)
+
+if torch.cuda.is_available():
+    for op_name, dispatches in sll_gpu_registrations.items():
+        lib.register(op_name, dispatches)
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py
index bf2349429f..f50260ae0e 100644
--- a/fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py
+++ b/fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py
@@ -213,19 +213,6 @@ def cpu_jagged_self_substraction_jagged_out(
     return jagged_B
 
 
-def meta_jagged_self_substraction_jagged_out(
-    jagged_A: torch.Tensor,
-    offsets_a: torch.Tensor,
-    offsets_b: torch.Tensor,
-    max_seq_len: int,
-) -> torch.Tensor:
-    return torch.empty(
-        [torch.library.get_ctx().new_dynamic_size()],
-        dtype=jagged_A.dtype,
-        device=jagged_A.device,
-    )
-
-
 def cpu_jagged2_to_padded_dense(
     values: torch.Tensor,
     offsets: torch.Tensor,
@@ -352,65 +339,6 @@ def cpu_jagged_dense_elementwise_mul_jagged_out(
     )
 
 
-class MetaJaggedDenseElementwiseMul(torch.autograd.Function):
-    @staticmethod
-    # pyre-fixme
-    def forward(
-        ctx,  # pyre-ignore [2]
-        x: torch.Tensor,
-        y: torch.Tensor,
-        x_seq_lengths: torch.Tensor,
-        x_offsets: torch.Tensor,
-        max_seq_len: int,
-    ) -> torch.Tensor:
-        ctx.max_seq_len = max_seq_len
-
-        ctx.save_for_backward(
-            x,
-            y,
-            x_seq_lengths,
-            x_offsets,
-        )
-
-        total_L = x.size(0)
-        jagged_C = torch.zeros((total_L), device=x.device, dtype=x.dtype)
-
-        return jagged_C
-
-    @staticmethod
-    # pyre-fixme
-    def backward(ctx, grad_output: torch.Tensor):
-        (
-            x,
-            y,
-            x_seq_lengths,
-            x_offsets,
-        ) = ctx.saved_tensors
-
-        total_L = grad_output.size(0)
-        jagged_C = torch.zeros(
-            (total_L), device=grad_output.device, dtype=grad_output.dtype
-        )
-
-        return jagged_C, None, None, None, None
-
-
-def meta_jagged_dense_elementwise_mul_jagged_out(
-    x: torch.Tensor,
-    y: torch.Tensor,
-    x_seq_lengths: torch.Tensor,
-    x_offsets: torch.Tensor,
-    max_seq_len: int,
-) -> torch.Tensor:
-    return MetaJaggedDenseElementwiseMul.apply(
-        x,
-        y,
-        x_seq_lengths,
-        x_offsets,
-        max_seq_len,
-    )
-
-
 class JaggedSoftmaxCPU(torch.autograd.Function):
     @staticmethod
     # pyre-fixme
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py
index 924f13b260..c74d900739 100644
--- a/fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py
+++ b/fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py
@@ -8,6 +8,78 @@
 import torch
 
 
+def meta_jagged_self_substraction_jagged_out(
+    jagged_A: torch.Tensor,
+    offsets_a: torch.Tensor,
+    offsets_b: torch.Tensor,
+    max_seq_len: int,
+) -> torch.Tensor:
+    return torch.empty(
+        [torch.library.get_ctx().new_dynamic_size()],
+        dtype=jagged_A.dtype,
+        device=jagged_A.device,
+    )
+
+
+class MetaJaggedDenseElementwiseMul(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx,  # pyre-ignore [2]
+        x: torch.Tensor,
+        y: torch.Tensor,
+        x_seq_lengths: torch.Tensor,
+        x_offsets: torch.Tensor,
+        max_seq_len: int,
+    ) -> torch.Tensor:
+        ctx.max_seq_len = max_seq_len
+
+        ctx.save_for_backward(
+            x,
+            y,
+            x_seq_lengths,
+            x_offsets,
+        )
+
+        total_L = x.size(0)
+        jagged_C = torch.zeros((total_L), device=x.device, dtype=x.dtype)
+
+        return jagged_C
+
+    @staticmethod
+    # pyre-fixme
+    def backward(ctx, grad_output: torch.Tensor):
+        (
+            x,
+            y,
+            x_seq_lengths,
+            x_offsets,
+        ) = ctx.saved_tensors
+
+        total_L = grad_output.size(0)
+        jagged_C = torch.zeros(
+            (total_L), device=grad_output.device, dtype=grad_output.dtype
+        )
+
+        return jagged_C, None, None, None, None
+
+
+def meta_jagged_dense_elementwise_mul_jagged_out(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_seq_lengths: torch.Tensor,
+    x_offsets: torch.Tensor,
+    max_seq_len: int,
+) -> torch.Tensor:
+    return MetaJaggedDenseElementwiseMul.apply(
+        x,
+        y,
+        x_seq_lengths,
+        x_offsets,
+        max_seq_len,
+    )
+
+
 class Jagged2SoftmaxMeta(torch.autograd.Function):
     @staticmethod
     # pyre-fixme
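The registration change above boils down to one pattern: keep per-backend dispatch tables as plain dicts and only walk the CUDA table when a GPU is actually present, so importing the module on a CPU-only host no longer trips over CUDA/Triton registrations. Below is a minimal, self-contained sketch of that pattern, not part of the diff: `register_op` is a stand-in for the SLL `lib.register` call, and the `*_stub` kernels are hypothetical placeholders.

    # Sketch of the CPU/GPU split-registration pattern introduced by the diff.
    from typing import Callable, Dict

    import torch


    def register_op(op_name: str, dispatches: Dict[str, Callable]) -> None:
        # Stand-in for `lib.register(op_name, dispatches)` in fbgemm_gpu.sll.
        print(f"registering {op_name} for dispatch keys {sorted(dispatches)}")


    def cpu_jagged_dense_bmm_stub(*args, **kwargs) -> None:  # hypothetical CPU kernel
        ...


    def cuda_jagged_dense_bmm_stub(*args, **kwargs) -> None:  # hypothetical CUDA kernel
        ...


    # One dict per backend, mirroring sll_cpu_registrations / sll_gpu_registrations.
    cpu_registrations: Dict[str, Dict[str, Callable]] = {
        "sll_jagged_dense_bmm": {
            "CPU": cpu_jagged_dense_bmm_stub,
            "AutogradCPU": cpu_jagged_dense_bmm_stub,
        },
    }

    gpu_registrations: Dict[str, Dict[str, Callable]] = {
        "sll_jagged_dense_bmm": {
            "CUDA": cuda_jagged_dense_bmm_stub,
            "AutogradCUDA": cuda_jagged_dense_bmm_stub,
        },
    }

    # CPU ops are always registered; CUDA ops only when a device is available.
    for op_name, dispatches in cpu_registrations.items():
        register_op(op_name, dispatches)

    if torch.cuda.is_available():
        for op_name, dispatches in gpu_registrations.items():
            register_op(op_name, dispatches)

The same idea explains why the two meta functions moved from cpu_sll.py to meta_sll.py: they back the "Meta" dispatch keys in the CPU table, so they must be importable without pulling in any GPU-only module.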