
Conversation

@cyy536 cyy536 (Contributor) commented Aug 25, 2025

PR Category

User Experience

PR Types

New features

Description

Add paddle.compat.softmax for compatibility with torch.softmax.
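
A minimal usage sketch of the behaviour this PR targets; the exact parameter handling follows the review discussion below and may differ from the merged code:

import paddle

x = paddle.to_tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])

# torch-style call through the new compat namespace
y_compat = paddle.compat.softmax(x, dim=-1)

# equivalent native Paddle call
y_native = paddle.nn.functional.softmax(x, axis=-1)

assert bool(paddle.allclose(y_compat, y_native))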


paddle-bot bot commented Aug 25, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@codecov-commenter codecov-commenter commented Aug 25, 2025

Codecov Report

❌ Patch coverage is 88.67925% with 6 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@2fd8a7e). Learn more about missing BASE report.

Files with missing lines                  Patch %   Lines
python/paddle/utils/decorator_utils.py    70.00%    6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #74874   +/-   ##
==========================================
  Coverage           ?   88.67%           
==========================================
  Files              ?        7           
  Lines              ?       53           
  Branches           ?        0           
==========================================
  Hits               ?       47           
  Misses             ?        6           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.

@zhwesky2010 zhwesky2010 (Contributor) left a comment

Why are there so many separate softmax implementations here?

There should only be two: compat.softmax should simply set dim and call the existing softmax directly.

@zhwesky2010 zhwesky2010 changed the title add compat.softmax to compat with torch.softmax [API Compatiblity] add compat.softmax to compat with torch.softmax Aug 26, 2025
@cyy536 cyy536 (Contributor, Author) commented Aug 26, 2025

I wrote four in total, and their signatures differ:
paddle.Tensor.softmax == torch.Tensor.softmax, with signature (int dim, torch.dtype dtype = None)
paddle.softmax == torch.softmax, with signature (Tensor input, int dim, torch.dtype dtype = None, *, Tensor out = None)
paddle.compat.softmax == torch.nn.functional.softmax, with signature (input, dim=None, _stacklevel=3, dtype=None)
paddle.nn.functional.softmax has the signature (input, dim=None, _stacklevel=3, dtype=None), but its default values differ
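
A minimal sketch of the wrapping approach suggested above, mapping the torch-style dim argument onto Paddle's existing axis parameter; the dim=None fallback and parameter list here are illustrative, not the merged implementation:

from __future__ import annotations

import paddle
from paddle import Tensor


def compat_softmax(
    input: Tensor,
    dim: int | None = None,
    dtype: str | None = None,
) -> Tensor:
    # Illustrative fallback only: use the last axis when dim is omitted.
    axis = -1 if dim is None else dim
    return paddle.nn.functional.softmax(input, axis=axis, dtype=dtype)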



@softmax_param_ignore_alias
def compat_softmax(

Contributor:

Move this implementation to the right place, i.e. tensor/compat.py. paddle.compat.softmax can call paddle.softmax directly.

Contributor (Author):

Adjusted.

@@ -159,7 +164,156 @@ def softmax(
[0.03205860, 0.08714432, 0.23688282, 0.64391426],
[0.03205860, 0.08714432, 0.23688282, 0.64391426]]])
"""
return _softmax_impl(x, axis, dtype, name, out=out)

Contributor:

This can be inlined directly.

Contributor (Author):

Inlined.


@softmax_param_ignore_alias
def compat_softmax(
x: Tensor,

Contributor:

For the compat namespace, follow the input/dim signature style, consistent with min/max/sort.

Contributor (Author):

Adjusted.

@@ -19,5 +19,6 @@
sort,
split,
)
from .tensor.softmax import compat_softmax as softmax

Contributor:

Implement this under .tensor.compat.

return _softmax_impl(x, axis, dtype, name, out=out)


@softmax_param_ignore_alias

@zhwesky2010 zhwesky2010 (Contributor) commented Aug 26, 2025:

Can this reuse some of the existing infrastructure? Warnings and messages with a consistent style would be better. Refer to the paddle.compat.min style.

[screenshot: infoflow 2025-08-26 18-17-31]

Contributor (Author):

Reused.

@@ -494,6 +494,8 @@
)
from .to_string import set_printoptions # noqa: F401

__all__ = ['softmax']

Contributor:

This must not go in __all__; it has to go into tensor_method_func below, which is what binds it onto paddle.Tensor.
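
For reference, a hedged sketch of the registration pattern being pointed at; the entries shown are illustrative and the real tensor_method_func list (in python/paddle/tensor/__init__.py) is much longer. Names in this list get patched onto paddle.Tensor, whereas __all__ only controls the module's public exports.

# Illustrative excerpt only; not the actual list contents.
tensor_method_func = [
    'sort',
    'split',
    'softmax',  # adding the new name here is what exposes paddle.Tensor.softmax
]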

if in_dynamic_or_pir_mode():
outs_cast = input if dtype is None else _C_ops.cast(input, dtype)
return paddle.assign(_C_ops.softmax(outs_cast, dim), out)
else:

Contributor:

The old static-graph branch below does not need to be implemented.

Size2,
)


from paddle import nn
from paddle.base.data_feeder import check_dtype, check_variable_and_dtype
from paddle.base.framework import convert_np_dtype_to_dtype_
from paddle.base.layer_helper import LayerHelper

Contributor:

The old static-graph branch does not need to be implemented; remove the imports that are no longer needed.

def softmax(
input: Tensor,
dim: int | None = None,
_stacklevel: int = 3,

Contributor:

Can this be aligned with torch.softmax, regarding the _stacklevel parameter?

Couldn't you inherit from ForbidKeywordsDecorator and handle _stacklevel internally (hiding it), so that the outer signature stays input, dim, dtype, *, out?
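
A minimal, self-contained sketch of that idea; the decorator below is a stand-in for illustration, while the real handling would extend ForbidKeywordsDecorator in python/paddle/utils/decorator_utils.py:

import functools


def ignore_stacklevel(fn):
    # Stand-in decorator: silently drop the torch-only ``_stacklevel`` keyword
    # so the visible signature can stay (input, dim, dtype, *, out).
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        kwargs.pop("_stacklevel", None)
        return fn(*args, **kwargs)

    return wrapper


@ignore_stacklevel
def softmax(input, dim=None, dtype=None, *, out=None):
    ...  # delegate to the existing paddle softmax implementation


# A torch-style call that still passes _stacklevel works; the keyword is ignored.
softmax(object(), dim=-1, _stacklevel=3)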

dtype = convert_np_dtype_to_dtype_(dtype)
if in_dynamic_or_pir_mode():
outs_cast = input if dtype is None else _C_ops.cast(input, dtype)
return paddle.assign(_C_ops.softmax(outs_cast, dim), out)

Contributor:

Just use _C_ops.softmax(outs_cast, dim, out=out) directly here.

@@ -121,6 +121,7 @@ def softmax(
:math:`axis + D` . Default is -1.
dtype (str, optional): The data type of the output tensor, can be bfloat16, float16, float32, float64.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
out (Tensor, optional): The output Tensor.

Contributor:

Can this be implemented with _C_ops.softmax(outs_cast, axis, out=out)?

@@ -31,6 +31,7 @@
real,
shape,
)
from .compat_softmax import softmax as softmax

Contributor:

This has to be added to tensor_method_func so that it gets bound onto paddle.Tensor.

Contributor (Author):

It has been added to tensor_method_func all along.

@@ -403,6 +396,70 @@ def process(
return args, kwargs


class ForbidKeywordsIgnoreOneParamDecorator(DecoratorBase):

Contributor:

This can inherit from ForbidKeywordsDecorator: extract only its ignore logic and delegate everything else via super().__init__(*args, **kwargs), which maximizes code reuse.
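
A hedged sketch of the refactor described here. The ForbidKeywordsDecorator base below is a minimal stub for illustration; the real class in python/paddle/utils/decorator_utils.py defines the actual __init__ and process behaviour:

class ForbidKeywordsDecorator:
    # Minimal stub of the existing base class, for illustration only.
    def __init__(self, *args, **kwargs):
        pass

    def process(self, args, kwargs):
        # The real base rejects forbidden torch-style keywords and emits
        # the style-consistent warnings; the stub just passes through.
        return args, kwargs


class ForbidKeywordsIgnoreOneParamDecorator(ForbidKeywordsDecorator):
    # Reuse the base behaviour via super(); add only the "ignore" step.
    def __init__(self, *args, ignored_param: str = "_stacklevel", **kwargs):
        super().__init__(*args, **kwargs)
        self.ignored_param = ignored_param

    def process(self, args, kwargs):
        kwargs.pop(self.ignored_param, None)  # the extracted ignore logic
        return super().process(args, kwargs)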

Contributor (Author):

Fixed.

@cyy536 cyy536 (Contributor, Author) commented Aug 28, 2025

/re-run all-failed

@zhwesky2010 zhwesky2010 (Contributor) left a comment

LGTM

@XiaoguangHu01 XiaoguangHu01 (Contributor) left a comment

LGTM

@XiaoguangHu01 XiaoguangHu01 (Contributor) left a comment

LGTM

@cyy536 cyy536 (Contributor, Author) commented Aug 28, 2025

/re-run coverage build

@zhwesky2010 zhwesky2010 merged commit cd64b23 into PaddlePaddle:develop Aug 28, 2025
153 of 162 checks passed