Skip to content

Commit 6c75e52

Browse files
committed
Fix catalog identifier matching to exact match
1 parent acc46d7 commit 6c75e52

File tree

2 files changed

+99
-37
lines changed

2 files changed

+99
-37
lines changed

crates/persistence/src/backend/catalog.rs

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,14 +1251,9 @@ impl ParquetDataCatalog {
12511251
let start_u64 = start.map(|s| s.as_u64());
12521252
let end_u64 = end.map(|e| e.as_u64());
12531253

1254-
let safe_ids = instrument_ids.as_ref().map(|ids| {
1255-
ids.iter()
1256-
.map(|id| urisafe_instrument_id(id))
1257-
.collect::<Vec<String>>()
1258-
});
1259-
12601254
let base_dir = self.make_path(data_cls, None)?;
12611255

1256+
// Use recursive listing to match Python's glob behavior
12621257
let list_result = self.execute_async(async {
12631258
let prefix = ObjectPath::from(format!("{base_dir}/"));
12641259
let mut stream = self.object_store.list(Some(&prefix));
@@ -1269,23 +1264,68 @@ impl ParquetDataCatalog {
12691264
Ok::<Vec<_>, anyhow::Error>(objects)
12701265
})?;
12711266

1272-
for object in list_result {
1273-
let path_str = object.location.to_string();
1274-
if path_str.ends_with(".parquet") {
1275-
if let Some(ids) = &safe_ids {
1276-
let matches_any_id = ids.iter().any(|safe_id| path_str.contains(safe_id));
1277-
if !matches_any_id {
1278-
continue;
1279-
}
1267+
let mut file_paths: Vec<String> = list_result
1268+
.into_iter()
1269+
.filter_map(|object| {
1270+
let path_str = object.location.to_string();
1271+
if path_str.ends_with(".parquet") {
1272+
Some(path_str)
1273+
} else {
1274+
None
12801275
}
1276+
})
1277+
.collect();
12811278

1282-
if query_intersects_filename(&path_str, start_u64, end_u64) {
1283-
let full_uri = self.reconstruct_full_uri(&path_str);
1284-
files.push(full_uri);
1285-
}
1279+
// Apply identifier filtering if provided
1280+
if let Some(identifiers) = instrument_ids {
1281+
let safe_identifiers: Vec<String> = identifiers
1282+
.iter()
1283+
.map(|id| urisafe_instrument_id(id))
1284+
.collect();
1285+
1286+
// Exact match by default for instrument_ids or bar_types
1287+
let exact_match_file_paths: Vec<String> = file_paths
1288+
.iter()
1289+
.filter(|file_path| {
1290+
// Extract the directory name (second to last path component)
1291+
let path_parts: Vec<&str> = file_path.split('/').collect();
1292+
if path_parts.len() >= 2 {
1293+
let dir_name = path_parts[path_parts.len() - 2];
1294+
safe_identifiers.iter().any(|safe_id| safe_id == dir_name)
1295+
} else {
1296+
false
1297+
}
1298+
})
1299+
.cloned()
1300+
.collect();
1301+
1302+
if exact_match_file_paths.is_empty() && data_cls == "bars" {
1303+
// Partial match of instrument_ids in bar_types for bars
1304+
file_paths.retain(|file_path| {
1305+
let path_parts: Vec<&str> = file_path.split('/').collect();
1306+
if path_parts.len() >= 2 {
1307+
let dir_name = path_parts[path_parts.len() - 2];
1308+
safe_identifiers
1309+
.iter()
1310+
.any(|safe_id| dir_name.starts_with(&format!("{}-", safe_id)))
1311+
} else {
1312+
false
1313+
}
1314+
});
1315+
} else {
1316+
file_paths = exact_match_file_paths;
12861317
}
12871318
}
12881319

1320+
// Apply timestamp filtering
1321+
file_paths.retain(|file_path| query_intersects_filename(file_path, start_u64, end_u64));
1322+
1323+
// Convert to full URIs
1324+
for file_path in file_paths {
1325+
let full_uri = self.reconstruct_full_uri(&file_path);
1326+
files.push(full_uri);
1327+
}
1328+
12891329
Ok(files)
12901330
}
12911331

nautilus_trader/persistence/catalog/parquet.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,7 +1195,7 @@ def query(
11951195
An additional SQL WHERE clause to filter the data (used in Rust queries).
11961196
files : list[str], optional
11971197
A specific list of files to query from. If provided, these files are used
1198-
instead of discovering files through the normal process. Forces PyArrow backend.
1198+
instead of discovering files through the normal process.
11991199
**kwargs : Any
12001200
Additional keyword arguments passed to the underlying query implementation.
12011201
@@ -1233,6 +1233,7 @@ def query(
12331233
start=start,
12341234
end=end,
12351235
where=where,
1236+
files=files,
12361237
**kwargs,
12371238
)
12381239
else:
@@ -1270,6 +1271,7 @@ def _query_rust(
12701271
start: TimestampLike | None = None,
12711272
end: TimestampLike | None = None,
12721273
where: str | None = None,
1274+
files: list[str] | None = None,
12731275
**kwargs: Any,
12741276
) -> list[Data]:
12751277
query_data_cls = OrderBookDelta if data_cls == OrderBookDeltas else data_cls
@@ -1279,6 +1281,7 @@ def _query_rust(
12791281
start=start,
12801282
end=end,
12811283
where=where,
1284+
files=files,
12821285
**kwargs,
12831286
)
12841287
result = session.to_query_result()
@@ -1304,6 +1307,7 @@ def backend_session(
13041307
end: TimestampLike | None = None,
13051308
where: str | None = None,
13061309
session: DataBackendSession | None = None,
1310+
files: list[str] | None = None,
13071311
**kwargs: Any,
13081312
) -> DataBackendSession:
13091313
"""
@@ -1327,6 +1331,9 @@ def backend_session(
13271331
An additional SQL WHERE clause to filter the data.
13281332
session : DataBackendSession, optional
13291333
An existing session to update. If None, a new session is created.
1334+
files : list[str], optional
1335+
A specific list of files to query from. If provided, these files are used
1336+
instead of discovering files through the normal process.
13301337
**kwargs : Any
13311338
Additional keyword arguments.
13321339
@@ -1351,7 +1358,7 @@ def backend_session(
13511358
13521359
"""
13531360
data_type: NautilusDataType = ParquetDataCatalog._nautilus_data_cls_to_data_type(data_cls)
1354-
files = self._query_files(data_cls, identifiers, start, end)
1361+
file_list = files if files else self._query_files(data_cls, identifiers, start, end)
13551362
file_prefix = class_to_filename(data_cls)
13561363

13571364
if session is None:
@@ -1361,7 +1368,7 @@ def backend_session(
13611368
if self.fs_protocol != "file":
13621369
self._register_object_store_with_session(session)
13631370

1364-
for idx, file in enumerate(files):
1371+
for idx, file in enumerate(file_list):
13651372
table = f"{file_prefix}_{idx}"
13661373
query = self._build_query(
13671374
table,
@@ -1492,10 +1499,7 @@ def _query_pyarrow(
14921499
**kwargs: Any,
14931500
) -> list[Data]:
14941501
# Load dataset - use provided files or query for them
1495-
if files is not None:
1496-
file_list = files
1497-
else:
1498-
file_list = self._query_files(data_cls, identifiers, start, end)
1502+
file_list = files if files else self._query_files(data_cls, identifiers, start, end)
14991503

15001504
if not file_list:
15011505
return []
@@ -1536,32 +1540,50 @@ def _query_files(
15361540
file_prefix = class_to_filename(data_cls)
15371541
base_path = self.path.rstrip("/")
15381542
glob_path = f"{base_path}/data/{file_prefix}/**/*.parquet"
1539-
file_names: list[str] = self.fs.glob(glob_path)
1543+
file_paths: list[str] = self.fs.glob(glob_path)
15401544

15411545
if identifiers:
15421546
if not isinstance(identifiers, list):
15431547
identifiers = [identifiers]
15441548

15451549
safe_identifiers = [urisafe_identifier(identifier) for identifier in identifiers]
1546-
file_names = [
1547-
file_name
1548-
for file_name in file_names
1549-
if any(safe_identifier in file_name for safe_identifier in safe_identifiers)
1550+
1551+
# Exact match by default for instrument_ids or bar_types
1552+
exact_match_file_paths = [
1553+
file_path
1554+
for file_path in file_paths
1555+
if any(
1556+
safe_identifier == file_path.split("/")[-2]
1557+
for safe_identifier in safe_identifiers
1558+
)
15501559
]
15511560

1561+
if not exact_match_file_paths and data_cls in [Bar, *Bar.__subclasses__()]:
1562+
# Partial match of instrument_ids in bar_types for bars
1563+
file_paths = [
1564+
file_path
1565+
for file_path in file_paths
1566+
if any(
1567+
file_path.split("/")[-2].startswith(f"{safe_identifier}-")
1568+
for safe_identifier in safe_identifiers
1569+
)
1570+
]
1571+
else:
1572+
file_paths = exact_match_file_paths
1573+
15521574
used_start: pd.Timestamp | None = time_object_to_dt(start)
15531575
used_end: pd.Timestamp | None = time_object_to_dt(end)
1554-
file_names = [
1555-
file_name
1556-
for file_name in file_names
1557-
if _query_intersects_filename(file_name, used_start, used_end)
1576+
file_paths = [
1577+
file_path
1578+
for file_path in file_paths
1579+
if _query_intersects_filename(file_path, used_start, used_end)
15581580
]
15591581

15601582
if self.show_query_paths:
1561-
for file_name in file_names:
1562-
print(file_name)
1583+
for file_path in file_paths:
1584+
print(file_path)
15631585

1564-
return file_names
1586+
return file_paths
15651587

15661588
@staticmethod
15671589
def _handle_table_nautilus(

0 commit comments

Comments
 (0)