Skip to content

Commit 6cad5ca

Browse files
author
Shallow Copy Bot
committed
Use pandas testing utils for exact cases.
Original PR #1722 by ueshin. Original: databricks/koalas#1722
1 parent 9980a3f commit 6cad5ca

19 files changed

+329
-536
lines changed

databricks/koalas/base.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -212,10 +212,10 @@ def __sub__(self, other):
212212
)
213213
if isinstance(other, IndexOpsMixin) and isinstance(other.spark.data_type, DateType):
214214
warnings.warn(msg, UserWarning)
215-
return column_op(F.datediff)(self, other)
215+
return column_op(F.datediff)(self, other).astype("bigint")
216216
elif isinstance(other, datetime.date) and not isinstance(other, datetime.datetime):
217217
warnings.warn(msg, UserWarning)
218-
return column_op(F.datediff)(self, F.lit(other))
218+
return column_op(F.datediff)(self, F.lit(other)).astype("bigint")
219219
else:
220220
raise TypeError("date subtraction can only be applied to date series.")
221221
return column_op(Column.__sub__)(self, other)
@@ -286,7 +286,7 @@ def __rsub__(self, other):
286286
)
287287
if isinstance(other, datetime.date) and not isinstance(other, datetime.datetime):
288288
warnings.warn(msg, UserWarning)
289-
return -column_op(F.datediff)(self, F.lit(other))
289+
return -column_op(F.datediff)(self, F.lit(other)).astype("bigint")
290290
else:
291291
raise TypeError("date subtraction can only be applied to date series.")
292292
return column_op(Column.__rsub__)(self, other)

databricks/koalas/frame.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2093,10 +2093,18 @@ def transpose(self):
20932093

20942094
internal = self._internal.copy(
20952095
spark_frame=transposed_df,
2096-
index_map=OrderedDict((col, None) for col in internal_index_columns),
2096+
index_map=OrderedDict(
2097+
(col, name if name is None or isinstance(name, tuple) else (name,))
2098+
for col, name in zip(
2099+
internal_index_columns,
2100+
self._internal.column_label_names
2101+
if self._internal.column_label_names is not None
2102+
else ([None] * len(internal_index_columns)),
2103+
)
2104+
),
20972105
column_labels=[tuple(json.loads(col)["a"]) for col in new_data_columns],
20982106
data_spark_columns=[scol_for(transposed_df, col) for col in new_data_columns],
2099-
column_label_names=None,
2107+
column_label_names=self._internal.index_names,
21002108
)
21012109

21022110
return DataFrame(internal)

databricks/koalas/testing/utils.py

Lines changed: 47 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -133,27 +133,46 @@ def tearDownClass(cls):
133133

134134
def assertPandasEqual(self, left, right):
135135
if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
136-
msg = (
137-
"DataFrames are not equal: "
138-
+ "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
139-
+ "\n\nRight:\n%s\n%s" % (right, right.dtypes)
140-
)
141-
self.assertTrue(left.equals(right), msg=msg)
136+
try:
137+
pd.util.testing.assert_frame_equal(
138+
left,
139+
right,
140+
check_index_type=("equiv" if len(left.index) > 0 else False),
141+
check_column_type=("equiv" if len(left.columns) > 0 else False),
142+
check_exact=True,
143+
)
144+
except AssertionError as e:
145+
msg = (
146+
str(e)
147+
+ "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
148+
+ "\n\nRight:\n%s\n%s" % (right, right.dtypes)
149+
)
150+
raise AssertionError(msg) from e
142151
elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
143-
msg = (
144-
"Series are not equal: "
145-
+ "\n\nLeft:\n%s\n%s" % (left, left.dtype)
146-
+ "\n\nRight:\n%s\n%s" % (right, right.dtype)
147-
)
148-
self.assertEqual(str(left.name), str(right.name), msg=msg)
149-
self.assertTrue((left == right).all(), msg=msg)
152+
try:
153+
pd.util.testing.assert_series_equal(
154+
left,
155+
right,
156+
check_index_type=("equiv" if len(left.index) > 0 else False),
157+
check_exact=True,
158+
)
159+
except AssertionError as e:
160+
msg = (
161+
str(e)
162+
+ "\n\nLeft:\n%s\n%s" % (left, left.dtype)
163+
+ "\n\nRight:\n%s\n%s" % (right, right.dtype)
164+
)
165+
raise AssertionError(msg) from e
150166
elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
151-
msg = (
152-
"Indices are not equal: "
153-
+ "\n\nLeft:\n%s\n%s" % (left, left.dtype)
154-
+ "\n\nRight:\n%s\n%s" % (right, right.dtype)
155-
)
156-
self.assertTrue((left == right).all(), msg=msg)
167+
try:
168+
pd.util.testing.assert_index_equal(left, right, check_exact=True)
169+
except AssertionError as e:
170+
msg = (
171+
str(e)
172+
+ "\n\nLeft:\n%s\n%s" % (left, left.dtype)
173+
+ "\n\nRight:\n%s\n%s" % (right, right.dtype)
174+
)
175+
raise AssertionError(msg) from e
157176
else:
158177
raise ValueError("Unexpected values: (%s, %s)" % (left, right))
159178

@@ -190,6 +209,15 @@ def assertPandasAlmostEqual(self, left, right):
190209
self.assertEqual(lnull, rnull, msg=msg)
191210
for lval, rval in zip(left.dropna(), right.dropna()):
192211
self.assertAlmostEqual(lval, rval, msg=msg)
212+
elif isinstance(left, pd.MultiIndex) and isinstance(left, pd.MultiIndex):
213+
msg = (
214+
"MultiIndices are not almost equal: "
215+
+ "\n\nLeft:\n%s\n%s" % (left, left.dtype)
216+
+ "\n\nRight:\n%s\n%s" % (right, right.dtype)
217+
)
218+
self.assertEqual(len(left), len(right), msg=msg)
219+
for lval, rval in zip(left, right):
220+
self.assertAlmostEqual(lval, rval, msg=msg)
193221
elif isinstance(left, pd.Index) and isinstance(left, pd.Index):
194222
msg = (
195223
"Indices are not almost equal: "

databricks/koalas/tests/test_dataframe.py

Lines changed: 19 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -567,7 +567,7 @@ def mul10(x) -> int:
567567

568568
def test_dot_in_column_name(self):
569569
self.assert_eq(
570-
ks.DataFrame(ks.range(1)._internal.spark_frame.selectExpr("1 as `a.b`"))["a.b"],
570+
ks.DataFrame(ks.range(1)._internal.spark_frame.selectExpr("1L as `a.b`"))["a.b"],
571571
ks.Series([1], name="a.b"),
572572
)
573573

@@ -665,7 +665,7 @@ def _test_dropna(self, pdf, axis):
665665
pdf2.dropna(inplace=True)
666666
kdf2.dropna(inplace=True)
667667
self.assert_eq(kdf2, pdf2)
668-
self.assert_eq(kser, pser, almost=True)
668+
self.assert_eq(kser, pser)
669669

670670
# multi-index
671671
columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")])
@@ -805,7 +805,7 @@ def test_fillna(self):
805805
pdf.fillna({"x": -1, "y": -2, "z": -5}, inplace=True)
806806
kdf.fillna({"x": -1, "y": -2, "z": -5}, inplace=True)
807807
self.assert_eq(kdf, pdf)
808-
self.assert_eq(kser, pser, almost=True)
808+
self.assert_eq(kser, pser)
809809

810810
s_nan = pd.Series([-1, -2, -5], index=["x", "y", "z"], dtype=int)
811811
self.assert_eq(kdf.fillna(s_nan), pdf.fillna(s_nan))
@@ -942,7 +942,7 @@ def test_sort_values(self):
942942
kserA = kdf.a
943943
self.assert_eq(kdf.sort_values("b", inplace=True), pdf.sort_values("b", inplace=True))
944944
self.assert_eq(kdf, pdf)
945-
self.assert_eq(kserA, pserA, almost=True)
945+
self.assert_eq(kserA, pserA)
946946

947947
columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")])
948948
kdf.columns = columns
@@ -975,7 +975,7 @@ def test_sort_index(self):
975975
kserA = kdf.A
976976
self.assertEqual(kdf.sort_index(inplace=True), pdf.sort_index(inplace=True))
977977
self.assert_eq(kdf, pdf)
978-
self.assert_eq(kserA, pserA, almost=True)
978+
self.assert_eq(kserA, pserA)
979979

980980
# Assert multi-indices
981981
pdf = pd.DataFrame(
@@ -1759,7 +1759,7 @@ def get_data(left_columns=None, right_columns=None):
17591759
left_pdf.update(right_pdf)
17601760
left_kdf.update(right_kdf)
17611761
self.assert_eq(left_pdf.sort_values(by=["A", "B"]), left_kdf.sort_values(by=["A", "B"]))
1762-
self.assert_eq(kser.sort_index(), pser.sort_index(), almost=True)
1762+
self.assert_eq(kser.sort_index(), pser.sort_index())
17631763

17641764
left_kdf, left_pdf, right_kdf, right_pdf = get_data()
17651765
left_pdf.update(right_pdf, overwrite=False)
@@ -2063,7 +2063,7 @@ def test_stack(self):
20632063
)
20642064
kdf = ks.from_pandas(pdf)
20652065

2066-
self.assert_eq(kdf.stack().sort_index(), pdf.stack().sort_index(), almost=True)
2066+
self.assert_eq(kdf.stack().sort_index(), pdf.stack().sort_index())
20672067
self.assert_eq(kdf[[]].stack().sort_index(), pdf[[]].stack().sort_index(), almost=True)
20682068

20692069
def test_unstack(self):
@@ -3362,10 +3362,10 @@ def test_query(self):
33623362
kdf.query("('A', 'Z') > ('B', 'X')")
33633363

33643364
def test_take(self):
3365-
kdf = ks.DataFrame(
3365+
pdf = pd.DataFrame(
33663366
{"A": range(0, 50000), "B": range(100000, 0, -2), "C": range(100000, 50000, -1)}
33673367
)
3368-
pdf = kdf.to_pandas()
3368+
kdf = ks.from_pandas(pdf)
33693369

33703370
# axis=0 (default)
33713371
self.assert_eq(kdf.take([1, 2]).sort_index(), pdf.take([1, 2]).sort_index())
@@ -3438,6 +3438,7 @@ def test_take(self):
34383438
self.assert_eq(
34393439
kdf.take(range(-1, -3), axis=1).sort_index(),
34403440
pdf.take(range(-1, -3), axis=1).sort_index(),
3441+
almost=True,
34413442
)
34423443
self.assert_eq(
34433444
kdf.take([2, 1], axis=1).sort_index(), pdf.take([2, 1], axis=1).sort_index(),
@@ -3555,7 +3556,7 @@ def test_squeeze(self):
35553556
axises = [None, 0, 1, "rows", "index", "columns"]
35563557

35573558
# Multiple columns
3558-
pdf = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"])
3559+
pdf = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"], index=["x", "y"])
35593560
kdf = ks.from_pandas(pdf)
35603561
for axis in axises:
35613562
self.assert_eq(pdf.squeeze(axis), kdf.squeeze(axis))
@@ -3567,7 +3568,7 @@ def test_squeeze(self):
35673568
self.assert_eq(pdf.squeeze(axis), kdf.squeeze(axis))
35683569

35693570
# Single column with single value
3570-
pdf = pd.DataFrame([[1]], columns=["a"])
3571+
pdf = pd.DataFrame([[1]], columns=["a"], index=["x"])
35713572
kdf = ks.from_pandas(pdf)
35723573
for axis in axises:
35733574
self.assert_eq(pdf.squeeze(axis), kdf.squeeze(axis))
@@ -3880,15 +3881,15 @@ def test_iteritems(self):
38803881

38813882
def test_tail(self):
38823883
if LooseVersion(pyspark.__version__) >= LooseVersion("3.0"):
3883-
pdf = pd.DataFrame(range(1000))
3884+
pdf = pd.DataFrame({"x": range(1000)})
38843885
kdf = ks.from_pandas(pdf)
38853886

3886-
self.assert_eq(pdf.tail(), kdf.tail(), almost=True)
3887-
self.assert_eq(pdf.tail(10), kdf.tail(10), almost=True)
3888-
self.assert_eq(pdf.tail(-990), kdf.tail(-990), almost=True)
3889-
self.assert_eq(pdf.tail(0), kdf.tail(0), almost=True)
3890-
self.assert_eq(pdf.tail(-1001), kdf.tail(-1001), almost=True)
3891-
self.assert_eq(pdf.tail(1001), kdf.tail(1001), almost=True)
3887+
self.assert_eq(pdf.tail(), kdf.tail())
3888+
self.assert_eq(pdf.tail(10), kdf.tail(10))
3889+
self.assert_eq(pdf.tail(-990), kdf.tail(-990))
3890+
self.assert_eq(pdf.tail(0), kdf.tail(0))
3891+
self.assert_eq(pdf.tail(-1001), kdf.tail(-1001))
3892+
self.assert_eq(pdf.tail(1001), kdf.tail(1001))
38923893
with self.assertRaisesRegex(TypeError, "bad operand type for unary -: 'str'"):
38933894
kdf.tail("10")
38943895

0 commit comments

Comments
 (0)