Skip to content

Commit 55e3459

Browse files
authored
FIX-#1976: indices matching at reduction functions fixed (#2270)
Signed-off-by: Dmitry Chigarev <[email protected]>
1 parent 6697e05 commit 55e3459

File tree

11 files changed

+373
-478
lines changed

11 files changed

+373
-478
lines changed

modin/backends/pandas/query_compiler.py

Lines changed: 87 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@
2222
is_scalar,
2323
)
2424
from pandas.core.base import DataError
25+
from typing import Type, Callable
2526
import warnings
2627

28+
2729
from modin.backends.base.query_compiler import BaseQueryCompiler
2830
from modin.error_message import ErrorMessage
2931
from modin.utils import try_cast_to_pandas, wrap_udf_function
3032
from modin.data_management.functions import (
33+
Function,
3134
FoldFunction,
3235
MapFunction,
3336
MapReduceFunction,
@@ -150,6 +153,34 @@ def caller(df, *args, **kwargs):
150153
return caller
151154

152155

156+
def _numeric_only_reduce_fn(applier: Type[Function], *funcs) -> Callable:
157+
"""
158+
Build reduce function for statistic operations with `numeric_only` parameter.
159+
160+
Parameters
161+
----------
162+
applier: Callable
163+
Function object to register `funcs`
164+
*funcs: list
165+
List of functions to register in `applier`
166+
167+
Returns
168+
-------
169+
callable
170+
A callable function to be applied in the partitions
171+
"""
172+
173+
def caller(self, *args, **kwargs):
174+
# If `numeric_only` is None then we don't know what columns/indices will
175+
# be dropped at the result of reduction function, and so can't preserve labels
176+
preserve_index = kwargs.get("numeric_only", None) is not None
177+
return applier.register(*funcs, preserve_index=preserve_index)(
178+
self, *args, **kwargs
179+
)
180+
181+
return caller
182+
183+
153184
class PandasQueryCompiler(BaseQueryCompiler):
154185
"""This class implements the logic necessary for operating on partitions
155186
with a Pandas backend. This logic is specific to Pandas."""
@@ -625,29 +656,54 @@ def is_monotonic_decreasing(self):
625656
is_monotonic = _is_monotonic
626657

627658
count = MapReduceFunction.register(pandas.DataFrame.count, pandas.DataFrame.sum)
628-
max = MapReduceFunction.register(pandas.DataFrame.max, pandas.DataFrame.max)
629-
min = MapReduceFunction.register(pandas.DataFrame.min, pandas.DataFrame.min)
630-
sum = MapReduceFunction.register(pandas.DataFrame.sum, pandas.DataFrame.sum)
631-
prod = MapReduceFunction.register(pandas.DataFrame.prod, pandas.DataFrame.prod)
659+
max = _numeric_only_reduce_fn(MapReduceFunction, pandas.DataFrame.max)
660+
min = _numeric_only_reduce_fn(MapReduceFunction, pandas.DataFrame.min)
661+
sum = _numeric_only_reduce_fn(MapReduceFunction, pandas.DataFrame.sum)
662+
prod = _numeric_only_reduce_fn(MapReduceFunction, pandas.DataFrame.prod)
632663
any = MapReduceFunction.register(pandas.DataFrame.any, pandas.DataFrame.any)
633664
all = MapReduceFunction.register(pandas.DataFrame.all, pandas.DataFrame.all)
634665
memory_usage = MapReduceFunction.register(
635666
pandas.DataFrame.memory_usage,
636667
lambda x, *args, **kwargs: pandas.DataFrame.sum(x),
637668
axis=0,
638669
)
639-
mean = MapReduceFunction.register(
640-
lambda df, **kwargs: df.apply(
641-
lambda x: (x.sum(skipna=kwargs.get("skipna", True)), x.count()),
642-
axis=kwargs.get("axis", 0),
643-
result_type="reduce",
644-
).set_axis(df.axes[kwargs.get("axis", 0) ^ 1], axis=0),
645-
lambda df, **kwargs: df.apply(
646-
lambda x: x.apply(lambda d: d[0]).sum(skipna=kwargs.get("skipna", True))
647-
/ x.apply(lambda d: d[1]).sum(skipna=kwargs.get("skipna", True)),
648-
axis=kwargs.get("axis", 0),
649-
).set_axis(df.axes[kwargs.get("axis", 0) ^ 1], axis=0),
650-
)
670+
671+
def mean(self, axis, **kwargs):
    """
    Compute the mean as a map-reduce over per-partition (sum, count) pairs.

    Parameters
    ----------
    axis : int
        Passed to ``DataFrame.apply`` and to the map-reduce call. It is also
        used as ``axis ^ 1`` to pick the result labels, so it must be an int.
    **kwargs : dict
        Forwarded unchanged. Keys read here: ``level`` (a non-None value
        falls back to ``default_to_pandas``), ``skipna`` (default ``True``)
        and ``numeric_only`` (a non-None value sets ``preserve_index=True``).

    Returns
    -------
    The result of ``self.default_to_pandas`` when ``level`` is given,
    otherwise the result of the registered ``MapReduceFunction`` call.
    """
    if kwargs.get("level") is not None:
        return self.default_to_pandas(pandas.DataFrame.mean, axis=axis, **kwargs)

    skipna = kwargs.get("skipna", True)

    def make_applier(func, **apply_kwargs):
        # Wrap `func` so it runs through `DataFrame.apply` and the result is
        # relabelled with the labels of the axis that is not reduced.
        def applier(df, **kwargs):
            applied = df.apply(func, **apply_kwargs)
            return applied.set_axis(df.axes[axis ^ 1], axis=0)

        return applier

    def partial_sum_and_count(ser, **kwargs):
        # Map phase: a (sum, count) pair, or None when sum/count raise
        # TypeError for this data.
        try:
            return (ser.sum(skipna=skipna), ser.count())
        except TypeError:
            return None

    def combine_partials(ser, **kwargs):
        # Reduce phase: total of the partial sums over total of the counts.
        total = ser.apply(lambda pair: pair[0]).sum(skipna=skipna)
        count = ser.apply(lambda pair: pair[1]).sum(skipna=skipna)
        return total / count

    def reduce_fn(df, **kwargs):
        # Drop, in place, every column holding a missing partial result
        # before combining.
        df.dropna(axis=1, inplace=True, how="any")
        return make_applier(combine_partials, axis=axis)(df)

    preserve_index = kwargs.get("numeric_only") is not None
    mean_fn = MapReduceFunction.register(
        make_applier(partial_sum_and_count, axis=axis, result_type="reduce"),
        reduce_fn,
        preserve_index=preserve_index,
    )
    return mean_fn(self, axis=axis, **kwargs)
651707

652708
def value_counts(self, **kwargs):
653709
"""
@@ -664,7 +720,7 @@ def value_counts(self, **kwargs):
664720
return self.__constructor__(new_modin_frame)
665721

666722
def map_func(df, *args, **kwargs):
667-
return df.squeeze(axis=1).value_counts(**kwargs)
723+
return df.squeeze(axis=1).value_counts(**kwargs).to_frame()
668724

669725
def reduce_func(df, *args, **kwargs):
670726
normalize = kwargs.get("normalize", False)
@@ -735,28 +791,30 @@ def sort_index_for_equal_values(result, ascending):
735791
else:
736792
new_index[j] = result.index[j]
737793
i += 1
738-
return pandas.DataFrame(result, index=new_index)
794+
return pandas.DataFrame(
795+
result, index=new_index, columns=["__reduced__"]
796+
)
739797

740798
return sort_index_for_equal_values(result, ascending)
741799

742-
return MapReduceFunction.register(map_func, reduce_func, preserve_index=False)(
743-
self, **kwargs
744-
)
800+
return MapReduceFunction.register(
801+
map_func, reduce_func, axis=0, preserve_index=False
802+
)(self, **kwargs)
745803

746804
# END MapReduce operations
747805

748806
# Reduction operations
749807
idxmax = ReductionFunction.register(pandas.DataFrame.idxmax)
750808
idxmin = ReductionFunction.register(pandas.DataFrame.idxmin)
751-
median = ReductionFunction.register(pandas.DataFrame.median)
809+
median = _numeric_only_reduce_fn(ReductionFunction, pandas.DataFrame.median)
752810
nunique = ReductionFunction.register(pandas.DataFrame.nunique)
753-
skew = ReductionFunction.register(pandas.DataFrame.skew)
754-
kurt = ReductionFunction.register(pandas.DataFrame.kurt)
755-
sem = ReductionFunction.register(pandas.DataFrame.sem)
756-
std = ReductionFunction.register(pandas.DataFrame.std)
757-
var = ReductionFunction.register(pandas.DataFrame.var)
758-
sum_min_count = ReductionFunction.register(pandas.DataFrame.sum)
759-
prod_min_count = ReductionFunction.register(pandas.DataFrame.prod)
811+
skew = _numeric_only_reduce_fn(ReductionFunction, pandas.DataFrame.skew)
812+
kurt = _numeric_only_reduce_fn(ReductionFunction, pandas.DataFrame.kurt)
813+
sem = _numeric_only_reduce_fn(ReductionFunction, pandas.DataFrame.sem)
814+
std = _numeric_only_reduce_fn(ReductionFunction, pandas.DataFrame.std)
815+
var = _numeric_only_reduce_fn(ReductionFunction, pandas.DataFrame.var)
816+
sum_min_count = _numeric_only_reduce_fn(ReductionFunction, pandas.DataFrame.sum)
817+
prod_min_count = _numeric_only_reduce_fn(ReductionFunction, pandas.DataFrame.prod)
760818
quantile_for_single_value = ReductionFunction.register(pandas.DataFrame.quantile)
761819
mad = ReductionFunction.register(pandas.DataFrame.mad)
762820
to_datetime = ReductionFunction.register(

modin/data_management/functions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific language
1212
# governing permissions and limitations under the License.
1313

14+
from .function import Function
1415
from .mapfunction import MapFunction
1516
from .mapreducefunction import MapReduceFunction
1617
from .reductionfunction import ReductionFunction
@@ -19,6 +20,7 @@
1920
from .groupby_function import GroupbyReduceFunction
2021

2122
__all__ = [
23+
"Function",
2224
"MapFunction",
2325
"MapReduceFunction",
2426
"ReductionFunction",

modin/data_management/functions/foldfunction.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@ class FoldFunction(Function):
1818
@classmethod
1919
def call(cls, fold_function, **call_kwds):
2020
def caller(query_compiler, *args, **kwargs):
21+
axis = call_kwds.get("axis", kwargs.get("axis"))
2122
return query_compiler.__constructor__(
2223
query_compiler._modin_frame._fold(
23-
call_kwds.get("axis")
24-
if "axis" in call_kwds
25-
else kwargs.get("axis"),
24+
cls.validate_axis(axis),
2625
lambda x: fold_function(x, *args, **kwargs),
2726
)
2827
)

modin/data_management/functions/function.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# ANY KIND, either express or implied. See the License for the specific language
1212
# governing permissions and limitations under the License.
1313

14+
from typing import Optional
15+
1416

1517
class Function(object):
1618
def __init__(self):
@@ -27,3 +29,7 @@ def call(cls, func, **call_kwds):
2729
@classmethod
2830
def register(cls, func, **kwargs):
2931
return cls.call(func, **kwargs)
32+
33+
@classmethod
def validate_axis(cls, axis: Optional[int]) -> int:
    """
    Normalize an optional axis value.

    Parameters
    ----------
    axis : int or None
        Axis value supplied by the caller.

    Returns
    -------
    int
        ``0`` when `axis` is None, otherwise `axis` unchanged.
    """
    if axis is None:
        return 0
    return axis

modin/data_management/functions/mapreducefunction.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@ class MapReduceFunction(Function):
1919
def call(cls, map_function, reduce_function, **call_kwds):
    """
    Build a query-compiler method that runs a map phase and a reduce phase.

    Parameters
    ----------
    map_function : callable
        Invoked as ``map_function(x, *args, **kwargs)`` in the map phase.
    reduce_function : callable
        Invoked as ``reduce_function(y, *args, **kwargs)`` in the reduce phase.
    **call_kwds : dict
        Registration-time options. ``axis`` takes precedence over an ``axis``
        passed at call time; ``preserve_index`` (default ``True``) is
        forwarded to ``_map_reduce``.

    Returns
    -------
    callable
        ``caller(query_compiler, *args, **kwargs)`` returning
        ``query_compiler.__constructor__`` applied to the ``_map_reduce``
        result.
    """
    # Read the option once without mutating `call_kwds`. Popping it inside
    # `caller` removed the key on the first invocation, so every later call
    # of the same registered function silently reverted to the default True.
    preserve_index = call_kwds.get("preserve_index", True)

    def caller(query_compiler, *args, **kwargs):
        axis = call_kwds.get("axis", kwargs.get("axis"))
        return query_compiler.__constructor__(
            query_compiler._modin_frame._map_reduce(
                cls.validate_axis(axis),
                lambda x: map_function(x, *args, **kwargs),
                lambda y: reduce_function(y, *args, **kwargs),
                preserve_index=preserve_index,
            )
        )

    return caller
3433

3534
@classmethod
def register(cls, map_function, reduce_function=None, **kwargs):
    """
    Register a map-reduce operation.

    Parameters
    ----------
    map_function : callable
        Function for the map phase.
    reduce_function : callable, optional
        Function for the reduce phase; when None, `map_function` is used
        for both phases.
    **kwargs : dict
        Forwarded to ``cls.call``.

    Returns
    -------
    The value returned by ``cls.call``.
    """
    reducer = map_function if reduce_function is None else reduce_function
    return cls.call(map_function, reducer, **kwargs)

modin/data_management/functions/reductionfunction.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@ class ReductionFunction(Function):
1818
@classmethod
1919
def call(cls, reduction_function, **call_kwds):
2020
def caller(query_compiler, *args, **kwargs):
21+
preserve_index = call_kwds.pop("preserve_index", True)
22+
axis = call_kwds.get("axis", kwargs.get("axis"))
2123
return query_compiler.__constructor__(
2224
query_compiler._modin_frame._fold_reduce(
23-
call_kwds.get("axis")
24-
if "axis" in call_kwds
25-
else kwargs.get("axis"),
25+
cls.validate_axis(axis),
2626
lambda x: reduction_function(x, *args, **kwargs),
27+
preserve_index=preserve_index,
2728
)
2829
)
2930

0 commit comments

Comments
 (0)