Skip to content

Commit 2800743

Browse files
committed
adjust the max based on the current plot
1 parent 2ebf9f9 commit 2800743

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

conjugate/plot.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,6 @@ def _reshape_x_values(self, x: np.ndarray) -> np.ndarray:
111111

112112
return x
113113

114-
def _settle_axis(self, ax: Axes | None = None) -> Axes:
115-
return ax if ax is not None else plt.gca()
116-
117114

118115
class ContinuousPlotDistMixin(PlotDistMixin):
119116
"""Functionality for plot_pdf method of continuous distributions."""
@@ -122,7 +119,7 @@ def _plot(self, ax: Axes | None = None, cdf: bool = False, **kwargs) -> Axes:
122119
x = self._create_x_values()
123120
x = self._reshape_x_values(x)
124121

125-
ax = self._settle_axis(ax=ax)
122+
ax = ax if ax is not None else plt.gca()
126123

127124
return self._create_plot_on_axis(x=x, cdf=cdf, ax=ax, **kwargs)
128125

@@ -183,7 +180,14 @@ def _create_plot_on_axis(self, x, cdf: bool, ax: Axes, **kwargs) -> Axes:
183180

184181
ax.plot(x, yy, label=label, **kwargs)
185182
self._setup_labels(ax=ax, cdf=cdf)
186-
ax.set_ylim(0, None)
183+
184+
y_max = np.max(yy)
185+
_, current_y_max = ax.get_ylim()
186+
new_max_value = max(y_max, current_y_max) * 1.02
187+
if not np.isfinite(new_max_value):
188+
new_max_value = None
189+
190+
ax.set_ylim(None, new_max_value)
187191
return ax
188192

189193

@@ -211,7 +215,7 @@ def plot_pdf(
211215
"""
212216
distribution_samples = self.dist.rvs(size=samples, random_state=random_state)
213217

214-
ax = self._settle_axis(ax=ax)
218+
ax = ax if ax is not None else plt.gca()
215219
xx = self._create_x_values()
216220

217221
labels = label_to_iterable(
@@ -242,7 +246,7 @@ def _plot(
242246
x = self._create_x_values()
243247
x = self._reshape_x_values(x)
244248

245-
ax = self._settle_axis(ax=ax)
249+
ax = ax if ax is not None else plt.gca()
246250
return self._create_plot_on_axis(
247251
x,
248252
ax=ax,
@@ -353,5 +357,12 @@ def _create_plot_on_axis(
353357

354358
ax.set_xlabel("Domain")
355359
ax.set_ylabel(ylabel)
356-
ax.set_ylim(0, None)
360+
361+
y_max = np.max(yy)
362+
_, current_y_max = ax.get_ylim()
363+
new_max_value = max(y_max, current_y_max) * 1.02
364+
if not np.isfinite(new_max_value):
365+
new_max_value = None
366+
367+
ax.set_ylim(None, new_max_value)
357368
return ax

0 commit comments

Comments
 (0)