Wrong output of GroupBy transform with string input (e.g., transform('rank')) #22509
Here is the output I get after running the code I submitted:
xref #19354. Regarding adding 'rank' to the checks at `pandas/pandas/core/groupby/generic.py` line 524 (in cf25c5c) and `pandas/pandas/core/groupby/generic.py` line 919 (in cf25c5c): they should be used to discriminate between aggregating functions (which `_transform_fast` assumes) and non-aggregating functions (like `rank`). Whether they are cythonized is not the point. When passed a non-aggregating function, the fast path generates a rank series and then indexes it by group id, as if it held one value per group. The fast path was introduced in 2014 with fe55b89, back when `SeriesGroupBy` didn't actually have any non-aggregation functions implemented, so at the time it was good enough.
To illustrate:
In [34]: df=pd.DataFrame(dict(price=[10,10,20,20,30,30],cost=(100,200,300,400,500,600)))
...: df
Out[34]:
price cost
0 10 100
1 10 200
2 20 300
3 20 400
4 30 500
5 30 600
# correct
In [36]: df.groupby('price').cost.rank()
Out[36]:
0 1.0
1 2.0
2 1.0
3 2.0
4 1.0
5 2.0
Name: cost, dtype: float64
# wrong
In [35]: df.groupby('price').cost.transform('rank')
Out[35]:
0 1.0
1 1.0
2 2.0
3 2.0
4 1.0
5 1.0
Name: cost, dtype: float64
# the above generated via
In [47]: ids, _, _ = df.groupby('price').grouper.group_info
...: ids
Out[47]: array([0, 0, 1, 1, 2, 2])
In [48]: df.groupby('price').rank().loc[ids]
Out[48]:
cost
0 1.0
0 1.0
1 2.0
1 2.0
2 1.0
2 1.0
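The fast-path behaviour in the session above can be reproduced by hand. This is a minimal sketch, assuming a recent pandas: it computes the per-row ranks correctly, then indexes them by each row's group id as though they held one value per group, which is what `.rank().loc[ids]` does. It uses the public `GroupBy.ngroup()` for the group ids instead of the internal `grouper.group_info`, which is private and may differ across versions; `simulated_fast_path` is an illustrative name, not a pandas function.

```python
import pandas as pd

df = pd.DataFrame(dict(price=[10, 10, 20, 20, 30, 30],
                       cost=[100, 200, 300, 400, 500, 600]))
gb = df.groupby('price')['cost']

# Correct: one rank per row, computed within each group.
correct = gb.rank()

# Group id of each row (0, 0, 1, 1, 2, 2), via the public ngroup() API.
ids = gb.ngroup().to_numpy()

# Simulated fast path: treat the per-row result as if it held one value per
# group and broadcast it back by group id (mirrors `.rank().loc[ids]`).
simulated_fast_path = pd.Series(correct.to_numpy()[ids], index=df.index, name='cost')

print(correct.tolist())              # [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
print(simulated_fast_path.tolist())  # [1.0, 1.0, 2.0, 2.0, 1.0, 1.0]
```

The second list matches the wrong `transform('rank')` output in the session, which supports the diagnosis that per-row results are being broadcast as per-group results.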
Same issue with 'ffill' and 'bfill' (xref #14274).
Transforming with a non-aggregating function doesn't make any sense: how would you broadcast the results?
If it's non-aggregating AND shape-preserving, as many are, you don't need to broadcast.
But if you're suddenly a purist, see #26743.
@pilkibun you are missing the point: transform works by using groupby with an aggregator and then broadcasting. Not sure why you have 2 issues here; it just needs to raise on non-aggregated outputs.
I'm sure one of us is.
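The two positions in this exchange (broadcast an aggregated result, pass a shape-preserving one through, and raise otherwise) can be combined into a small dispatch check. This is a minimal sketch, not pandas' implementation: `checked_transform` is a hypothetical helper that decides by comparing the output's length and index with the input's, and it assumes a Series input with no missing group keys.

```python
import pandas as pd


def checked_transform(s: pd.Series, by, name: str) -> pd.Series:
    """Hypothetical helper (not a pandas API): apply the GroupBy method `name`
    to `s` grouped by `by`, then align the output with the original rows."""
    gb = s.groupby(by)
    result = getattr(gb, name)()
    if not isinstance(result, pd.Series):
        raise TypeError(f"{name!r} did not return a Series")

    if len(result) == len(s) and result.index.equals(s.index):
        # Shape-preserving (e.g. rank, cumsum): already one value per row.
        return result
    if len(result) == gb.ngroups:
        # Aggregating (e.g. mean): one value per group, broadcast by group id.
        ids = gb.ngroup().to_numpy()
        return pd.Series(result.to_numpy()[ids], index=s.index, name=s.name)
    # Neither shape: broadcasting would be meaningless, so refuse.
    raise ValueError(f"{name!r} is neither aggregating nor shape-preserving")


df = pd.DataFrame(dict(price=[10, 10, 20, 20, 30, 30],
                       cost=[100, 200, 300, 400, 500, 600]))
print(checked_transform(df['cost'], df['price'], 'mean').tolist())
# [150.0, 150.0, 350.0, 350.0, 550.0, 550.0]
print(checked_transform(df['cost'], df['price'], 'rank').tolist())
# [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
```

Relying on the result's shape rather than on a list of known function names is the design choice here: it does not need updating when new non-aggregating methods are added, at the cost of being ambiguous when every group has exactly one row.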
This is fixed on master but could use tests.
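A minimal sketch of such a regression check, assuming a pandas version that includes the fix: it asserts that `transform(name)` matches calling the `GroupBy` method directly, for both `SeriesGroupBy` and `DataFrameGroupBy` and for the non-aggregating functions raised in this thread, and also compares against the `lambda x: x.rank()` path from the issue description. `check_string_transform_matches_method` is a hypothetical name, not an existing pandas test.

```python
import pandas as pd
import pandas.testing as tm

df = pd.DataFrame(dict(price=[10, 10, 20, 20, 30, 30],
                       cost=[100, 200, 300, 400, 500, 600]))


def check_string_transform_matches_method(name: str) -> None:
    """Hypothetical regression check: transform(name) should equal calling
    the GroupBy method `name` directly, for both GroupBy flavours."""
    sgb = df.groupby('price')['cost']
    tm.assert_series_equal(sgb.transform(name), getattr(sgb, name)())

    dgb = df.groupby('price')
    tm.assert_frame_equal(dgb.transform(name), getattr(dgb, name)())


# Non-aggregating functions discussed in this issue.
for name in ['rank', 'ffill', 'bfill']:
    check_string_transform_matches_method(name)

# The per-group lambda path gives the same ranks as the direct method.
tm.assert_series_equal(
    df.groupby('price')['cost'].transform(lambda x: x.rank()),
    df.groupby('price')['cost'].rank(),
)
```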
This seems to be the same issue as #19354; closing.
Code Sample, a copy-pastable example if possible
Problem description
For simplicity, I will explain the issue for `SeriesGroupBy`, though the bug is present in `DataFrameGroupBy` as well.

When calling `transform` on a `SeriesGroupBy` object with a string input `'string_input'` (e.g., `.transform('mean')`), the relevant code inside `pandas/pandas/core/groupby/generic.py` will end up calling `SeriesGroupBy._transform_fast` on the function (unless `'string_input'` is inside `base.cython_transforms`, currently consisting of `['cumprod', 'cumsum', 'shift', 'cummin', 'cummax']`). Inside `_transform_fast`, the result of applying `func` to the `SeriesGroupBy` object is then broadcast to the entire index of the original object. This works as expected if `func` returns a single value per group in the `GroupBy` object (e.g., for functions like `'mean'`, `'std'`, etc.). However, for functions like `rank`, `cumcount`, etc., which return several values per group, the result of broadcasting is, to the best of my knowledge, nonsensical.

Note: The result is correct if `transform` is called with a (non-cython) function, for example `lambda x: x.rank()`. In that case, `_transform_fast` is never called and the result is a simple concatenation of the results for each group.

Expected Output
The results of `df.groupby('a')['b'].transform('rank')` and `df.groupby('a')['b'].rank()` should be identical. The same goes for `'cumcount'` and maybe other `GroupBy` functions.

Output of `pd.show_versions()`
INSTALLED VERSIONS
commit: None
python: 3.7.0.final.0
python-bits: 64
OS: Linux
OS-release: 4.18.3-arch1-1-ARCH
machine: x86_64
processor:
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
pandas: 0.23.4
pytest: None
pip: 10.0.1
setuptools: 40.0.0
Cython: None
numpy: 1.15.0
scipy: 1.1.0
pyarrow: None
xarray: None
IPython: 6.5.0
sphinx: None
patsy: 0.5.0
dateutil: 2.7.3
pytz: 2018.5
blosc: None
bottleneck: None
tables: 3.4.4
numexpr: 2.6.7
feather: None
matplotlib: 2.2.3
openpyxl: None
xlrd: None
xlwt: None
xlsxwriter: None
lxml: None
bs4: None
html5lib: 1.0.1
sqlalchemy: None
pymysql: None
psycopg2: None
jinja2: 2.10
s3fs: None
fastparquet: None
pandas_gbq: None
pandas_datareader: None