Skip to content

Commit a7dc3b1

Browse files
ueshin authored and HyukjinKwon committed
Fix is_unique to respect the current Spark column (#981)
1 parent f228961 commit a7dc3b1

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

databricks/koalas/series.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -911,18 +911,17 @@ def is_unique(self):
911911
>>> ks.Series([1, 2, 3, None]).is_unique
912912
True
913913
"""
914-
sdf = self._kdf._sdf.select(self._scol)
915-
col = self._scol
914+
scol = self._scol
916915

917916
# Here we check:
918917
# 1. the distinct count without nulls and count without nulls for non-null values
919918
# 2. count null values and see if null is a distinct value.
920919
#
921920
# This workaround is in order to calculate the distinct count including nulls in
922921
# single pass. Note that COUNT(DISTINCT expr) in Spark is designed to ignore nulls.
923-
return sdf.select(
924-
(F.count(col) == F.countDistinct(col)) &
925-
(F.count(F.when(col.isNull(), 1).otherwise(None)) <= 1)
922+
return self._kdf._sdf.select(
923+
(F.count(scol) == F.countDistinct(scol)) &
924+
(F.count(F.when(scol.isNull(), 1).otherwise(None)) <= 1)
926925
).collect()[0][0]
927926

928927
def reset_index(self, level=None, drop=False, name=None, inplace=False):

databricks/koalas/tests/test_series.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -434,18 +434,22 @@ def test_is_unique(self):
434434
pser = pd.Series([1, 2, 2, None, None])
435435
kser = ks.from_pandas(pser)
436436
self.assertEqual(False, kser.is_unique)
437+
self.assertEqual(False, (kser + 1).is_unique)
437438

438439
pser = pd.Series([1, None, None])
439440
kser = ks.from_pandas(pser)
440441
self.assertEqual(False, kser.is_unique)
442+
self.assertEqual(False, (kser + 1).is_unique)
441443

442444
pser = pd.Series([1])
443445
kser = ks.from_pandas(pser)
444446
self.assertEqual(pser.is_unique, kser.is_unique)
447+
self.assertEqual((pser + 1).is_unique, (kser + 1).is_unique)
445448

446449
pser = pd.Series([1, 1, 1])
447450
kser = ks.from_pandas(pser)
448451
self.assertEqual(pser.is_unique, kser.is_unique)
452+
self.assertEqual((pser + 1).is_unique, (kser + 1).is_unique)
449453

450454
def test_to_list(self):
451455
if LooseVersion(pd.__version__) >= LooseVersion("0.24.0"):

0 commit comments

Comments
 (0)