|
29 | 29 | "import numpy as np\n",
|
30 | 30 | "import pandas as pd\n",
|
31 | 31 | "import xarray as xr\n",
|
32 |
| - "import matplotlib.pyplot as plt " |
| 32 | + "import matplotlib.pyplot as plt" |
33 | 33 | ]
|
34 | 34 | },
|
35 | 35 | {
|
|
50 | 50 | },
|
51 | 51 | "outputs": [],
|
52 | 52 | "source": [
|
53 |
| - "ds = xr.tutorial.open_dataset('rasm').load()\n", |
| 53 | + "ds = xr.tutorial.open_dataset(\"rasm\").load()\n", |
54 | 54 | "ds"
|
55 | 55 | ]
|
56 | 56 | },
|
|
88 | 88 | "outputs": [],
|
89 | 89 | "source": [
|
90 | 90 | "# Calculate the weights by grouping by 'time.season'.\n",
|
91 |
| - "weights = month_length.groupby('time.season') / month_length.groupby('time.season').sum()\n", |
| 91 | + "weights = (\n", |
| 92 | + " month_length.groupby(\"time.season\") / month_length.groupby(\"time.season\").sum()\n", |
| 93 | + ")\n", |
92 | 94 | "\n",
|
93 | 95 | "# Test that the sum of the weights for each season is 1.0\n",
|
94 |
| - "np.testing.assert_allclose(weights.groupby('time.season').sum().values, np.ones(4))\n", |
| 96 | + "np.testing.assert_allclose(weights.groupby(\"time.season\").sum().values, np.ones(4))\n", |
95 | 97 | "\n",
|
96 | 98 | "# Calculate the weighted average\n",
|
97 |
| - "ds_weighted = (ds * weights).groupby('time.season').sum(dim='time')" |
| 99 | + "ds_weighted = (ds * weights).groupby(\"time.season\").sum(dim=\"time\")" |
98 | 100 | ]
|
99 | 101 | },
|
100 | 102 | {
|
|
123 | 125 | "outputs": [],
|
124 | 126 | "source": [
|
125 | 127 | "# only used for comparisons\n",
|
126 |
| - "ds_unweighted = ds.groupby('time.season').mean('time')\n", |
| 128 | + "ds_unweighted = ds.groupby(\"time.season\").mean(\"time\")\n", |
127 | 129 | "ds_diff = ds_weighted - ds_unweighted"
|
128 | 130 | ]
|
129 | 131 | },
|
|
139 | 141 | "outputs": [],
|
140 | 142 | "source": [
|
141 | 143 | "# Quick plot to show the results\n",
|
142 |
| - "notnull = pd.notnull(ds_unweighted['Tair'][0])\n", |
143 |
| - "\n", |
144 |
| - "fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14,12))\n", |
145 |
| - "for i, season in enumerate(('DJF', 'MAM', 'JJA', 'SON')):\n", |
146 |
| - " ds_weighted['Tair'].sel(season=season).where(notnull).plot.pcolormesh(\n", |
147 |
| - " ax=axes[i, 0], vmin=-30, vmax=30, cmap='Spectral_r', \n", |
148 |
| - " add_colorbar=True, extend='both')\n", |
149 |
| - " \n", |
150 |
| - " ds_unweighted['Tair'].sel(season=season).where(notnull).plot.pcolormesh(\n", |
151 |
| - " ax=axes[i, 1], vmin=-30, vmax=30, cmap='Spectral_r', \n", |
152 |
| - " add_colorbar=True, extend='both')\n", |
153 |
| - "\n", |
154 |
| - " ds_diff['Tair'].sel(season=season).where(notnull).plot.pcolormesh(\n", |
155 |
| - " ax=axes[i, 2], vmin=-0.1, vmax=.1, cmap='RdBu_r',\n", |
156 |
| - " add_colorbar=True, extend='both')\n", |
| 144 | + "notnull = pd.notnull(ds_unweighted[\"Tair\"][0])\n", |
| 145 | + "\n", |
| 146 | + "fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14, 12))\n", |
| 147 | + "for i, season in enumerate((\"DJF\", \"MAM\", \"JJA\", \"SON\")):\n", |
| 148 | + " ds_weighted[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n", |
| 149 | + " ax=axes[i, 0],\n", |
| 150 | + " vmin=-30,\n", |
| 151 | + " vmax=30,\n", |
| 152 | + " cmap=\"Spectral_r\",\n", |
| 153 | + " add_colorbar=True,\n", |
| 154 | + " extend=\"both\",\n", |
| 155 | + " )\n", |
| 156 | + "\n", |
| 157 | + " ds_unweighted[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n", |
| 158 | + " ax=axes[i, 1],\n", |
| 159 | + " vmin=-30,\n", |
| 160 | + " vmax=30,\n", |
| 161 | + " cmap=\"Spectral_r\",\n", |
| 162 | + " add_colorbar=True,\n", |
| 163 | + " extend=\"both\",\n", |
| 164 | + " )\n", |
| 165 | + "\n", |
| 166 | + " ds_diff[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n", |
| 167 | + " ax=axes[i, 2],\n", |
| 168 | + " vmin=-0.1,\n", |
| 169 | + " vmax=0.1,\n", |
| 170 | + " cmap=\"RdBu_r\",\n", |
| 171 | + " add_colorbar=True,\n", |
| 172 | + " extend=\"both\",\n", |
| 173 | + " )\n", |
157 | 174 | "\n",
|
158 | 175 | " axes[i, 0].set_ylabel(season)\n",
|
159 |
| - " axes[i, 1].set_ylabel('')\n", |
160 |
| - " axes[i, 2].set_ylabel('')\n", |
| 176 | + " axes[i, 1].set_ylabel(\"\")\n", |
| 177 | + " axes[i, 2].set_ylabel(\"\")\n", |
161 | 178 | "\n",
|
162 | 179 | "for ax in axes.flat:\n",
|
163 | 180 | " ax.axes.get_xaxis().set_ticklabels([])\n",
|
164 | 181 | " ax.axes.get_yaxis().set_ticklabels([])\n",
|
165 |
| - " ax.axes.axis('tight')\n", |
166 |
| - " ax.set_xlabel('')\n", |
167 |
| - " \n", |
168 |
| - "axes[0, 0].set_title('Weighted by DPM')\n", |
169 |
| - "axes[0, 1].set_title('Equal Weighting')\n", |
170 |
| - "axes[0, 2].set_title('Difference')\n", |
171 |
| - " \n", |
| 182 | + " ax.axes.axis(\"tight\")\n", |
| 183 | + " ax.set_xlabel(\"\")\n", |
| 184 | + "\n", |
| 185 | + "axes[0, 0].set_title(\"Weighted by DPM\")\n", |
| 186 | + "axes[0, 1].set_title(\"Equal Weighting\")\n", |
| 187 | + "axes[0, 2].set_title(\"Difference\")\n", |
| 188 | + "\n", |
172 | 189 | "plt.tight_layout()\n",
|
173 | 190 | "\n",
|
174 |
| - "fig.suptitle('Seasonal Surface Air Temperature', fontsize=16, y=1.02)" |
| 191 | + "fig.suptitle(\"Seasonal Surface Air Temperature\", fontsize=16, y=1.02)" |
175 | 192 | ]
|
176 | 193 | },
|
177 | 194 | {
|
|
186 | 203 | "outputs": [],
|
187 | 204 | "source": [
|
188 | 205 | "# Wrap it into a simple function\n",
|
189 |
| - "def season_mean(ds, calendar='standard'):\n", |
| 206 | + "def season_mean(ds, calendar=\"standard\"):\n", |
190 | 207 | " # Make a DataArray with the number of days in each month, size = len(time)\n",
|
191 | 208 | " month_length = ds.time.dt.days_in_month\n",
|
192 | 209 | "\n",
|
193 | 210 | " # Calculate the weights by grouping by 'time.season'\n",
|
194 |
| - " weights = month_length.groupby('time.season') / month_length.groupby('time.season').sum()\n", |
| 211 | + " weights = (\n", |
| 212 | + " month_length.groupby(\"time.season\") / month_length.groupby(\"time.season\").sum()\n", |
| 213 | + " )\n", |
195 | 214 | "\n",
|
196 | 215 | " # Test that the sum of the weights for each season is 1.0\n",
|
197 |
| - " np.testing.assert_allclose(weights.groupby('time.season').sum().values, np.ones(4))\n", |
| 216 | + " np.testing.assert_allclose(weights.groupby(\"time.season\").sum().values, np.ones(4))\n", |
198 | 217 | "\n",
|
199 | 218 | " # Calculate the weighted average\n",
|
200 |
| - " return (ds * weights).groupby('time.season').sum(dim='time')" |
| 219 | + " return (ds * weights).groupby(\"time.season\").sum(dim=\"time\")" |
201 | 220 | ]
|
202 | 221 | }
|
203 | 222 | ],
|
|
0 commit comments