Skip to content

Commit 93a4039

Browse files
Zac-HDfmaussion
authored andcommitted
Use correct dtype for RGB image alpha channel (#1893)
* Fix alpha channel logic for integer RGB images * Use named argument for concat axis
1 parent ee38ff0 commit 93a4039

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

doc/gallery/plot_rasterio_rgb.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@
2626
# Read the data
2727
da = xr.open_rasterio('RGB.byte.tif')
2828

29-
# Normalize the image
30-
da = da / 255
31-
3229
# The data is in UTM projection. We have to set it manually until
3330
# https://github.com/SciTools/cartopy/issues/813 is implemented
3431
crs = ccrs.UTM('18N')

xarray/plot/plot.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,12 @@ def imshow(x, y, z, ax, **kwargs):
710710
# missing data transparent. We therefore add an alpha channel if
711711
# there isn't one, and set it to transparent where data is masked.
712712
if z.shape[-1] == 3:
713-
z = np.ma.concatenate((z, np.ma.ones(z.shape[:2] + (1,))), 2)
714-
z = z.copy()
713+
alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype)
714+
if np.issubdtype(z.dtype, np.integer):
715+
alpha *= 255
716+
z = np.ma.concatenate((z, alpha), axis=2)
717+
else:
718+
z = z.copy()
715719
z[np.any(z.mask, axis=-1), -1] = 0
716720

717721
primitive = ax.imshow(z, **defaults)

xarray/tests/test_plot.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,13 @@ def test_normalize_rgb_one_arg_error(self):
11461146
for kwds in [dict(vmax=-1, vmin=-1.2), dict(vmin=2, vmax=2.1)]:
11471147
da.plot.imshow(**kwds)
11481148

1149+
def test_imshow_rgb_values_in_valid_range(self):
1150+
da = DataArray(np.arange(75, dtype='uint8').reshape((5, 5, 3)))
1151+
_, ax = plt.subplots()
1152+
out = da.plot.imshow(ax=ax).get_array()
1153+
assert out.dtype == np.uint8
1154+
assert (out[..., :3] == da.values).all() # Compare without added alpha
1155+
11491156

11501157
class TestFacetGrid(PlotTestCase):
11511158
def setUp(self):

0 commit comments

Comments
 (0)