Skip to content

Commit 48f46fb

Browse files
authored
refactor(gridutil): improve arg handling in get_disu_kwargs (#2480)
1 parent 932bdcc commit 48f46fb

File tree

2 files changed

+60
-28
lines changed

2 files changed

+60
-28
lines changed

autotest/test_gridutil.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ def test_get_lni_infers_layer_count_when_int_ncpl(ncpl, nodes, expected):
7272
61,
7373
np.array(61 * [50]),
7474
np.array(61 * [50]),
75-
np.array([-10]),
76-
np.array([-30.0, -50.0]),
75+
-10,
76+
np.array([-30.0]),
7777
),
7878
(
7979
2,
8080
61,
8181
61,
8282
np.array(61 * [50]),
8383
np.array(61 * [50]),
84-
np.array([-10]),
84+
-10,
8585
np.array([-30.0, -50.0]),
8686
),
8787
(
@@ -90,8 +90,8 @@ def test_get_lni_infers_layer_count_when_int_ncpl(ncpl, nodes, expected):
9090
4, # ncol
9191
np.array(4 * [4.0]), # delr
9292
np.array(3 * [3.0]), # delc
93-
np.array([-10]), # top
94-
np.array([-30.0]), # botm
93+
-10, # top
94+
-30.0, # botm
9595
),
9696
],
9797
)
@@ -107,10 +107,6 @@ def test_get_disu_kwargs(nlay, nrow, ncol, delr, delc, tp, botm):
107107
return_vertices=True,
108108
)
109109

110-
from pprint import pprint
111-
112-
pprint(kwargs["area"])
113-
114110
assert kwargs["nodes"] == nlay * nrow * ncol
115111
assert kwargs["nvert"] == (nrow + 1) * (ncol + 1)
116112

flopy/utils/gridutil.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,45 @@ def get_disu_kwargs(
9898
def get_nn(k, i, j):
9999
return k * nrow * ncol + i * ncol + j
100100

101-
if not isinstance(delr, np.ndarray):
102-
delr = np.array(delr)
103-
if not isinstance(delc, np.ndarray):
104-
delc = np.array(delc)
105-
assert delr.shape == (ncol,)
106-
assert delc.shape == (nrow,)
101+
# delr check
102+
if np.isscalar(delr):
103+
delr = delr * np.ones(ncol, dtype=float)
104+
else:
105+
assert np.asanyarray(delr).shape == (ncol,), (
106+
"delr must be array with shape (ncol,), got {}".format(delr.shape)
107+
)
108+
109+
# delc check
110+
if np.isscalar(delc):
111+
delc = delc * np.ones(nrow, dtype=float)
112+
else:
113+
assert np.asanyarray(delc).shape == (nrow,), (
114+
"delc must be array with shape (nrow,), got {}".format(delc.shape)
115+
)
116+
117+
# tp check
118+
if np.isscalar(tp):
119+
tp = tp * np.ones((nrow, ncol), dtype=float)
120+
else:
121+
assert np.asanyarray(tp).shape == (
122+
nrow,
123+
ncol,
124+
), "tp must be scalar or array with shape (nrow, ncol), got {}".format(tp.shape)
125+
126+
# botm check
127+
if np.isscalar(botm):
128+
botm = botm * np.ones((nlay, nrow, ncol), dtype=float)
129+
elif np.asanyarray(botm).shape == (nlay,):
130+
b = np.empty((nlay, nrow, ncol), dtype=float)
131+
for k in range(nlay):
132+
b[k] = botm[k]
133+
botm = b
134+
else:
135+
assert np.asanyarray(botm).shape == (
136+
nlay,
137+
nrow,
138+
ncol,
139+
), "botm must be array with shape (nlay, nrow, ncol), got {}".format(botm.shape)
107140

108141
nodes = nlay * nrow * ncol
109142
iac = np.zeros((nodes), dtype=int)
@@ -126,16 +159,16 @@ def get_nn(k, i, j):
126159
cl12.append(n + 1)
127160
hwva.append(n + 1)
128161
if k == 0:
129-
top[n] = tp.item() if isinstance(tp, np.ndarray) else tp
162+
top[n] = tp[i, j]
130163
else:
131-
top[n] = botm[k - 1]
132-
bot[n] = botm[k]
164+
top[n] = botm[k - 1, i, j]
165+
bot[n] = botm[k, i, j]
133166
# up
134167
if k > 0:
135168
ja.append(get_nn(k - 1, i, j))
136169
iac[n] += 1
137170
ihc.append(0)
138-
dz = botm[k - 1] - botm[k]
171+
dz = botm[k - 1, i, j] - botm[k, i, j]
139172
cl12.append(0.5 * dz)
140173
hwva.append(delr[j] * delc[i])
141174
# back
@@ -172,9 +205,9 @@ def get_nn(k, i, j):
172205
iac[n] += 1
173206
ihc.append(0)
174207
if k == 0:
175-
dz = tp - botm[k]
208+
dz = tp[i, j] - botm[k, i, j]
176209
else:
177-
dz = botm[k - 1] - botm[k]
210+
dz = botm[k - 1, i, j] - botm[k, i, j]
178211
cl12.append(0.5 * dz)
179212
hwva.append(delr[j] * delc[i])
180213
ja = np.array(ja, dtype=int)
@@ -272,28 +305,31 @@ def get_disv_kwargs(
272305
if np.isscalar(delr):
273306
delr = delr * np.ones(ncol, dtype=float)
274307
else:
275-
assert delr.shape == (ncol,), "delr must be array with shape (ncol,)"
308+
assert np.asanyarray(delr).shape == (ncol,), (
309+
"delr must be array with shape (ncol,), got {}".format(delr.shape)
310+
)
276311

277312
# delc check
278313
if np.isscalar(delc):
279314
delc = delc * np.ones(nrow, dtype=float)
280315
else:
281-
assert delc.shape == (nrow,), "delc must be array with shape (nrow,)"
316+
assert np.asanyarray(delc).shape == (nrow,), (
317+
"delc must be array with shape (nrow,), got {}".format(delc.shape)
318+
)
282319

283320
# tp check
284321
if np.isscalar(tp):
285322
tp = tp * np.ones((nrow, ncol), dtype=float)
286323
else:
287-
assert tp.shape == (
324+
assert np.asanyarray(tp).shape == (
288325
nrow,
289326
ncol,
290-
), "tp must be scalar or array with shape (nrow, ncol)"
327+
), "tp must be scalar or array with shape (nrow, ncol), got {}".format(tp.shape)
291328

292329
# botm check
293330
if np.isscalar(botm):
294331
botm = botm * np.ones((nlay, nrow, ncol), dtype=float)
295-
elif isinstance(botm, list):
296-
assert len(botm) == nlay, "if botm provided as a list it must have length nlay"
332+
elif np.asanyarray(botm).shape == (nlay,):
297333
b = np.empty((nlay, nrow, ncol), dtype=float)
298334
for k in range(nlay):
299335
b[k] = botm[k]
@@ -303,7 +339,7 @@ def get_disv_kwargs(
303339
nlay,
304340
nrow,
305341
ncol,
306-
), "botm must be array with shape (nlay, nrow, ncol)"
342+
), "botm must be array with shape (nlay, nrow, ncol), got {}".format(botm.shape)
307343

308344
# build vertices
309345
xv = np.cumsum(delr)

0 commit comments

Comments
 (0)