Skip to content

Commit 14b9cc6

Browse files
author
briochh
committed
More checks to help catch misalignment at parameter-apply time
1 parent d4ec97d commit 14b9cc6

File tree

4 files changed

+48
-37
lines changed

4 files changed

+48
-37
lines changed

autotest/pst_from_tests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,12 @@ def freyberg_prior_build_test(tmp_path):
555555
par_name_base="welflux_grid",
556556
zone_array=m.bas6.ibound.array,
557557
geostruct=geostruct, lower_bound=0.25, upper_bound=1.75)
558+
pf.add_parameters(filenames=well_mfiles,
559+
par_type="grid", index_cols=[0, 1, 2], use_cols=3,
560+
par_name_base="welflux_grid",
561+
zone_array=m.bas6.ibound.array,
562+
use_rows=(1, 3, 4),
563+
geostruct=geostruct, lower_bound=0.25, upper_bound=1.75)
558564
# global constant across all files
559565
pf.add_parameters(filenames=well_mfiles,
560566
par_type="constant",

pyemu/utils/geostats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1383,7 +1383,7 @@ def _calc_factors_org(
13831383
if self.interp_data is None:
13841384
self.interp_data = df
13851385
else:
1386-
self.interp_data = self.interp_data.append(df)
1386+
self.interp_data = pd.concat([self.interp_data, df])
13871387
# correct for negative kriging factors, if requested
13881388
if remove_negative_factors == True:
13891389
self._remove_neg_factors()

pyemu/utils/helpers.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,11 +2047,12 @@ def _process_chunk_list_files(chunk, i, df):
20472047
print("process", i, " processed ", len(chunk), "process_list_file calls")
20482048

20492049

2050-
def _list_index_caster(x,add1):
2050+
def _list_index_caster(x, add1):
20512051
vals = []
20522052
for xx in x:
20532053
if xx:
2054-
if xx.strip().isdigit() or (xx.strip()[0] == '-' and xx.strip()[1:].isdigit()):
2054+
if (xx.strip().isdigit() or
2055+
(xx.strip()[0] == '-' and xx.strip()[1:].isdigit())):
20552056
vals.append(add1 + int(xx))
20562057
else:
20572058
try:
@@ -2061,11 +2062,12 @@ def _list_index_caster(x,add1):
20612062

20622063
return tuple(vals)
20632064

2064-
def _list_index_splitter_and_caster(x,add1):
2065-
return _list_index_caster(x.strip("()").replace('\'','').split(","),add1)
20662065

2067-
def _process_list_file(model_file, df):
2066+
def _list_index_splitter_and_caster(x, add1):
2067+
return _list_index_caster(x.strip("()").replace('\'', '').split(","), add1)
2068+
20682069

2070+
def _process_list_file(model_file, df):
20692071
# print("processing model file:", model_file)
20702072
df_mf = df.loc[df.model_file == model_file, :].copy()
20712073
# read data stored in org (mults act on this)
@@ -2110,7 +2112,7 @@ def _process_list_file(model_file, df):
21102112
# index_cols can be from header str
21112113
header = 0
21122114
hheader = True
2113-
elif isinstance(index_col_eg, int):
2115+
elif isinstance(index_col_eg, (int, np.integer)):
21142116
# index_cols are column numbers in input file
21152117
header = None
21162118
hheader = None
@@ -2169,14 +2171,13 @@ def _process_list_file(model_file, df):
21692171
common_idx = (
21702172
new_df.index.intersection(mlts.index).sort_values().drop_duplicates()
21712173
)
2172-
if common_idx.shape[0] == 0:
2173-
raise Exception("error: common_idx is empty")
21742174
mlt_cols = [str(col) for col in mlt.use_cols]
2175-
assert len(common_idx) * len(mlt_cols) == mlt.chkpar, (
2176-
"probable miss-alignment in tpl indices and original file:\n"
2177-
f"mult idx[:10] : {mlts.index.values.tolist()[:10]}\n"
2178-
f"org file idx[:10]: {new_df.index.value[:10]}\n"
2179-
f"n common: {len(common_idx)}, n cols: {len(mlt_cols)}"
2175+
assert len(common_idx) == mlt.chkpar, (
2176+
"Probable miss-alignment in tpl indices and original file:\n"
2177+
f"mult idx[:10] : {mlts.index.sort_values().tolist()[:10]}\n"
2178+
f"org file idx[:10]: {new_df.index.sort_values().to_list()[:10]}\n"
2179+
f"n common: {len(common_idx)}, n cols: {len(mlt_cols)}, "
2180+
f"expected: {mlt.chkpar}."
21802181
)
21812182
operator = mlt.operator
21822183
if operator == "*" or operator.lower()[0] == "m":

pyemu/utils/pst_from.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def parfile_relations(self):
219219
if x is not None
220220
else lb_max["lbound"]
221221
)
222-
pr["zero_based"] = self.zero_based
222+
pr["zero_based"] = self.zero_based # todo -- chase this out if going to file specific zero based def
223223
return pr
224224

225225
def _generic_get_xy(self, args, **kwargs):
@@ -968,7 +968,7 @@ def _par_prep(
968968
sep = " "
969969
if rel_filepath.suffix.lower() == ".csv":
970970
sep = ","
971-
if df.columns.is_integer():
971+
if pd.api.types.is_integer_dtype(df.columns): # df.columns.is_integer(): # really!???
972972
hheader = False
973973
else:
974974
hheader = df.columns
@@ -1912,7 +1912,9 @@ def add_parameters(
19121912
par_style = par_style[0]
19131913
if par_style not in ["m", "d", "a"]:
19141914
self.logger.lraise(
1915-
"add_parameters(): unrecognized 'style': {0}, should be either 'm'/'mult'/'multiplier', 'a'/'add'/'addend' or 'd'/'direct'".format(
1915+
"add_parameters(): unrecognized 'style': {0}, "
1916+
"should be either 'm'/'mult'/'multiplier', "
1917+
"'a'/'add'/'addend' or 'd'/'direct'".format(
19161918
par_style
19171919
)
19181920
)
@@ -2138,7 +2140,7 @@ def add_parameters(
21382140

21392141
pp_filename = None # setup placeholder variables
21402142
fac_filename = None
2141-
2143+
nxs = None
21422144
# Process model parameter files to produce appropriate pest pars
21432145
if index_cols is not None: # Assume list/tabular type input files
21442146
# ensure inputs are provided for all required cols
@@ -2167,7 +2169,7 @@ def add_parameters(
21672169
par_type.startswith("grid") or par_type.startswith("p")
21682170
) and geostruct is not None:
21692171
get_xy = self.get_xy
2170-
df = write_list_tpl(
2172+
df, nxs = write_list_tpl(
21712173
filenames,
21722174
dfs,
21732175
par_name_base,
@@ -2189,6 +2191,7 @@ def add_parameters(
21892191
fill_value=initial_value,
21902192
logger=self.logger,
21912193
)
2194+
nxs = {fname: nx for fname, nx in zip(filenames, nxs)}
21922195
assert (
21932196
np.mod(len(df), len(use_cols)) == 0.0
21942197
), "Parameter dataframe wrong shape for number of cols {0}" "".format(
@@ -2273,14 +2276,13 @@ def add_parameters(
22732276
structured = True
22742277
for mod_file, ar in file_dict.items():
22752278
orgdata = ar.shape
2276-
if spatial_reference_type=='vertex':
2279+
if spatial_reference_type == 'vertex':
22772280
assert orgdata[0] == spatial_reference.ncpl, (
22782281
"Spatial reference ncpl not equal to original data ncpl for\n"
22792282
+ os.path.join(
22802283
*os.path.split(self.original_file_d)[1:], mod_file
22812284
)
22822285
)
2283-
22842286
else:
22852287
assert orgdata[0] == spatial_reference.nrow, (
22862288
"Spatial reference nrow not equal to original data nrow for\n"
@@ -2643,7 +2645,7 @@ def add_parameters(
26432645
zone_filename = zone_filename.name
26442646

26452647
relate_parfiles = []
2646-
for mod_file in file_dict.keys():
2648+
for mod_file, pdf in file_dict.items():
26472649
mult_dict = {
26482650
"org_file": Path(self.original_file_d.name, mod_file.name),
26492651
"model_file": mod_file,
@@ -2655,8 +2657,9 @@ def add_parameters(
26552657
"upper_bound": ult_ubound,
26562658
"lower_bound": ult_lbound,
26572659
"operator": par_style,
2658-
"chkpar": len(df)
26592660
}
2661+
if nxs:
2662+
mult_dict["chkpar"] = nxs[mod_file]
26602663
if par_style in ["m", "a"]:
26612664
mult_dict["mlt_file"] = Path(self.mult_file_d.name, mlt_filename)
26622665

@@ -3094,7 +3097,7 @@ def write_list_tpl(
30943097
# get dataframe with autogenerated parnames based on `name`, `index_cols`,
30953098
# `use_cols`, `suffix` and `par_type`
30963099
if par_style == "d":
3097-
df_tpl = _write_direct_df_tpl(
3100+
df_tpl, nxs = _write_direct_df_tpl(
30983101
filenames[0],
30993102
tpl_filename,
31003103
dfs[0],
@@ -3130,8 +3133,8 @@ def write_list_tpl(
31303133
par_fill_value=fill_value,
31313134
par_style=par_style,
31323135
)
3133-
idxs = [df.loc[:, index_cols].values.tolist() for df in dfs]
3134-
use_rows = _get_use_rows(
3136+
idxs = [[tuple(s) for s in df.loc[:, index_cols].values] for df in dfs]
3137+
use_rows, nxs = _get_use_rows(
31353138
df_tpl, idxs, use_rows, zero_based, tpl_filename, logger=logger
31363139
)
31373140
df_tpl = df_tpl.loc[use_rows, :] # direct pars done in direct function
@@ -3227,7 +3230,7 @@ def write_list_tpl(
32273230
df_par.loc[:, "tpl_filename"] = tpl_filename
32283231
df_par.loc[:, "input_filename"] = input_filename
32293232
df_par.loc[:, "parval1"] = parval
3230-
return df_par
3233+
return df_par, nxs
32313234

32323235

32333236
def _write_direct_df_tpl(
@@ -3311,8 +3314,8 @@ def _write_direct_df_tpl(
33113314
init_df=df,
33123315
init_fname=in_filename,
33133316
)
3314-
idxs = df.loc[:, index_cols].values.tolist()
3315-
use_rows = _get_use_rows(
3317+
idxs = [tuple(s) for s in df.loc[:, index_cols].values]
3318+
use_rows, nxs = _get_use_rows(
33163319
df_ti, [idxs], use_rows, zero_based, tpl_filename, logger=logger
33173320
)
33183321
df_ti = df_ti.loc[use_rows]
@@ -3325,7 +3328,7 @@ def _write_direct_df_tpl(
33253328
pyemu.helpers._write_df_tpl(
33263329
tpl_filename, direct_tpl_df, index=False, header=header, headerlines=headerlines
33273330
)
3328-
return df_ti
3331+
return df_ti, nxs
33293332

33303333

33313334
def _get_use_rows(tpldf, idxcolvals, use_rows, zero_based, fnme, logger=None):
@@ -3345,19 +3348,23 @@ def _get_use_rows(tpldf, idxcolvals, use_rows, zero_based, fnme, logger=None):
33453348
"""
33463349
if use_rows is None:
33473350
use_rows = tpldf.index
3348-
return use_rows
3351+
nxs = [len(set(idx)) for idx in idxcolvals]
3352+
return use_rows, nxs
33493353
if np.ndim(use_rows) == 0:
33503354
use_rows = [use_rows]
33513355
if np.ndim(use_rows) == 1: # assume we have collection of int that describe iloc
33523356
use_rows = [idx[i] for i in use_rows for idx in idxcolvals]
3357+
else:
3358+
use_rows = [tuple(r) for r in use_rows]
3359+
nxs = [len(set(use_rows).intersection(idx)) for idx in idxcolvals]
3360+
orig_use_rows = use_rows.copy()
33533361
if not zero_based: # assume passed indicies are 1 based
33543362
use_rows = [
3355-
tuple([i - 1 if isinstance(i, int) else i for i in r])
3363+
tuple([i - 1 if isinstance(i, (int, np.integer)) else i for i in r])
33563364
if not isinstance(r, str)
33573365
else r
33583366
for r in use_rows
33593367
]
3360-
orig_use_rows = use_rows
33613368
use_rows = set(use_rows)
33623369
sel = tpldf.sidx.isin(use_rows) | tpldf.idx_strs.isin(use_rows)
33633370
if not sel.any(): # use_rows must be ints
@@ -3387,7 +3394,7 @@ def _get_use_rows(tpldf, idxcolvals, use_rows, zero_based, fnme, logger=None):
33873394
else:
33883395
warnings.warn(msg, PyemuWarning)
33893396
use_rows = tpldf.index
3390-
return use_rows
3397+
return use_rows, nxs
33913398

33923399

33933400
def _get_index_strfmt(index_cols):
@@ -3590,7 +3597,6 @@ def _get_tpl_or_ins_df(
35903597
Private method to auto-generate parameter or obs names from tabular
35913598
model files (input or output) read into pandas dataframes
35923599
Args:
3593-
filenames (`str` or `list` of str`): filenames
35943600
dfs (`pandas.DataFrame` or `list`): DataFrames (can be list of DataFrames)
35953601
to set up parameters or observations
35963602
name (`str`): Parameter name or Observation name prefix
@@ -3633,8 +3639,6 @@ def _get_tpl_or_ins_df(
36333639

36343640
# work out the union of indices across all dfs
36353641
if typ != "obs":
3636-
3637-
36383642
sidx = set()
36393643
for df in dfs:
36403644
# looses ordering

0 commit comments

Comments (0)