Skip to content

Commit 31d16bf

Browse files
committed
infer types
1 parent 576e90d commit 31d16bf

File tree

5 files changed

+25 -17 lines changed

5 files changed

+25 -17 lines changed

pandas/core/arrays/sparse.py

+9 -10
Original file line number | Diff line number | Diff line change
@@ -1907,18 +1907,17 @@ def make_sparse(arr, kind='block', fill_value=None, dtype=None, copy=False):
19071907
index = _make_index(length, indices, kind)
19081908
sparsified_values = arr[mask]
19091909

1910-
# careful about casting here
1911-
# as we could easily specify a type that cannot hold the resulting values
1912-
# e.g. integer when we have floats
1910+
# careful about casting here as we could easily specify a type that
1911+
# cannot hold the resulting values, e.g. integer when we have floats
1912+
# if we don't have an object specified then use this as the cast
19131913
if dtype is not None:
1914-
try:
1915-
sparsified_values = astype_nansafe(
1916-
sparsified_values, dtype=dtype, casting='same_kind')
1917-
except TypeError:
1918-
dtype = 'float64'
1919-
sparsified_values = astype_nansafe(
1920-
sparsified_values, dtype=dtype, casting='unsafe')
19211914

1915+
ok_to_cast = all(not (is_object_dtype(t) or is_bool_dtype(t))
1916+
for t in (dtype, sparsified_values.dtype))
1917+
if ok_to_cast:
1918+
dtype = find_common_type([dtype, sparsified_values.dtype])
1919+
sparsified_values = astype_nansafe(
1920+
sparsified_values, dtype=dtype)
19221921

19231922
# TODO: copy
19241923
return sparsified_values, index, fill_value

pandas/core/internals/construction.py

+4 -1
Original file line number | Diff line number | Diff line change
@@ -667,7 +667,10 @@ def sanitize_array(data, index, dtype=None, copy=False,
667667
data = np.array(data, dtype=dtype, copy=False)
668668
subarr = np.array(data, dtype=object, copy=copy)
669669

670-
if is_object_dtype(subarr.dtype) and dtype != 'object':
670+
if (not (is_extension_array_dtype(subarr.dtype) or
671+
is_extension_array_dtype(dtype)) and
672+
is_object_dtype(subarr.dtype) and
673+
not is_object_dtype(dtype)):
671674
inferred = lib.infer_dtype(subarr, skipna=False)
672675
if inferred == 'period':
673676
try:

pandas/core/sparse/frame.py

+10 -4
Original file line number | Diff line number | Diff line change
@@ -284,20 +284,26 @@ def _unpickle_sparse_frame_compat(self, state):
284284
def to_dense(self):
285285
return SparseFrameAccessor(self).to_dense()
286286

287-
def _apply_columns(self, func):
287+
def _apply_columns(self, func, *args, **kwargs):
288288
"""
289289
Get new SparseDataFrame applying func to each columns
290290
"""
291291

292-
new_data = {col: func(series)
292+
new_data = {col: func(series, *args, **kwargs)
293293
for col, series in self.items()}
294294

295295
return self._constructor(
296296
data=new_data, index=self.index, columns=self.columns,
297297
default_fill_value=self.default_fill_value).__finalize__(self)
298298

299-
def astype(self, dtype):
300-
return self._apply_columns(lambda x: x.astype(dtype))
299+
def astype(self, dtype, **kwargs):
300+
301+
def f(x, dtype, **kwargs):
302+
if isinstance(dtype, (dict, Series)):
303+
dtype = dtype[x.name]
304+
return x.astype(dtype, **kwargs)
305+
306+
return self._apply_columns(f, dtype=dtype, **kwargs)
301307

302308
def copy(self, deep=True):
303309
"""

pandas/tests/sparse/frame/test_analytics.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -55,5 +55,5 @@ def test_ufunc(data, dtype, func):
5555
result = func(df)
5656
expected = DataFrame(
5757
{'A': Series(func(data),
58-
dtype=dtype)})
58+
dtype=SparseDtype('float64', dtype.fill_value))})
5959
tm.assert_frame_equal(result, expected)

pandas/tests/sparse/series/test_analytics.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -16,5 +16,5 @@ def test_ufunc(data, dtype, func):
1616
s = Series(data, dtype=dtype)
1717
result = func(s)
1818
expected = Series(func(data),
19-
dtype=dtype)
19+
dtype=SparseDtype('float64', dtype.fill_value))
2020
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)