|
29 | 29 | #include "arrow/filesystem/localfs.h"
|
30 | 30 | #include "arrow/filesystem/path_util.h"
|
31 | 31 | #include "arrow/filesystem/util_internal.h"
|
| 32 | +#include "arrow/util/checked_cast.h" |
| 33 | +#include "arrow/util/make_unique.h" |
| 34 | +#include "arrow/util/uri.h" |
32 | 35 |
|
33 | 36 | namespace arrow {
|
| 37 | + |
| 38 | +using ::arrow::internal::UriFromAbsolutePath; |
| 39 | +using internal::checked_cast; |
| 40 | +using internal::make_unique; |
| 41 | + |
34 | 42 | namespace engine {
|
35 | 43 |
|
36 | 44 | template <typename RelMessage>
|
@@ -162,36 +170,45 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
|
162 | 170 | }
|
163 | 171 |
|
164 | 172 | path = path.substr(7);
|
165 |
| - if (item.path_type_case() == |
166 |
| - substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath) { |
167 |
| - ARROW_ASSIGN_OR_RAISE(auto file, filesystem->GetFileInfo(path)); |
168 |
| - if (file.type() == fs::FileType::File) { |
169 |
| - files.push_back(std::move(file)); |
170 |
| - } else if (file.type() == fs::FileType::Directory) { |
| 173 | + switch (item.path_type_case()) { |
| 174 | + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath: { |
| 175 | + ARROW_ASSIGN_OR_RAISE(auto file, filesystem->GetFileInfo(path)); |
| 176 | + if (file.type() == fs::FileType::File) { |
| 177 | + files.push_back(std::move(file)); |
| 178 | + } else if (file.type() == fs::FileType::Directory) { |
| 179 | + fs::FileSelector selector; |
| 180 | + selector.base_dir = path; |
| 181 | + selector.recursive = true; |
| 182 | + ARROW_ASSIGN_OR_RAISE(auto discovered_files, |
| 183 | + filesystem->GetFileInfo(selector)); |
| 184 | + std::move(files.begin(), files.end(), std::back_inserter(discovered_files)); |
| 185 | + } |
| 186 | + break; |
| 187 | + } |
| 188 | + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile: { |
| 189 | + files.emplace_back(path, fs::FileType::File); |
| 190 | + break; |
| 191 | + } |
| 192 | + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder: { |
171 | 193 | fs::FileSelector selector;
|
172 | 194 | selector.base_dir = path;
|
173 | 195 | selector.recursive = true;
|
174 | 196 | ARROW_ASSIGN_OR_RAISE(auto discovered_files,
|
175 | 197 | filesystem->GetFileInfo(selector));
|
176 |
| - std::move(files.begin(), files.end(), std::back_inserter(discovered_files)); |
| 198 | + std::move(discovered_files.begin(), discovered_files.end(), |
| 199 | + std::back_inserter(files)); |
| 200 | + break; |
| 201 | + } |
| 202 | + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPathGlob: { |
| 203 | + ARROW_ASSIGN_OR_RAISE(auto discovered_files, |
| 204 | + fs::internal::GlobFiles(filesystem, path)); |
| 205 | + std::move(discovered_files.begin(), discovered_files.end(), |
| 206 | + std::back_inserter(files)); |
| 207 | + break; |
| 208 | + } |
| 209 | + default: { |
| 210 | + return Status::Invalid("Unrecognized file type in LocalFiles"); |
177 | 211 | }
|
178 |
| - } |
179 |
| - if (item.path_type_case() == |
180 |
| - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile) { |
181 |
| - files.emplace_back(path, fs::FileType::File); |
182 |
| - } else if (item.path_type_case() == |
183 |
| - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder) { |
184 |
| - fs::FileSelector selector; |
185 |
| - selector.base_dir = path; |
186 |
| - selector.recursive = true; |
187 |
| - ARROW_ASSIGN_OR_RAISE(auto discovered_files, filesystem->GetFileInfo(selector)); |
188 |
| - std::move(discovered_files.begin(), discovered_files.end(), |
189 |
| - std::back_inserter(files)); |
190 |
| - } else { |
191 |
| - ARROW_ASSIGN_OR_RAISE(auto discovered_files, |
192 |
| - fs::internal::GlobFiles(filesystem, path)); |
193 |
| - std::move(discovered_files.begin(), discovered_files.end(), |
194 |
| - std::back_inserter(files)); |
195 | 212 | }
|
196 | 213 | }
|
197 | 214 |
|
@@ -421,5 +438,141 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
|
421 | 438 | rel.DebugString());
|
422 | 439 | }
|
423 | 440 |
|
| 441 | +namespace { |
| 442 | + |
| 443 | +Result<std::shared_ptr<Schema>> ExtractSchemaToBind(const compute::Declaration& declr) { |
| 444 | + std::shared_ptr<Schema> bind_schema; |
| 445 | + if (declr.factory_name == "scan") { |
| 446 | + const auto& opts = checked_cast<const dataset::ScanNodeOptions&>(*(declr.options)); |
| 447 | + bind_schema = opts.dataset->schema(); |
| 448 | + } else if (declr.factory_name == "filter") { |
| 449 | + auto input_declr = util::get<compute::Declaration>(declr.inputs[0]); |
| 450 | + ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); |
| 451 | + } else if (declr.factory_name == "sink") { |
| 452 | + // Note that the sink has no output_schema |
| 453 | + return bind_schema; |
| 454 | + } else { |
| 455 | + return Status::Invalid("Schema extraction failed, unsupported factory ", |
| 456 | + declr.factory_name); |
| 457 | + } |
| 458 | + return bind_schema; |
| 459 | +} |
| 460 | + |
| 461 | +Result<std::unique_ptr<substrait::ReadRel>> ScanRelationConverter( |
| 462 | + const std::shared_ptr<Schema>& schema, const compute::Declaration& declaration, |
| 463 | + ExtensionSet* ext_set, const ConversionOptions& conversion_options) { |
| 464 | + auto read_rel = make_unique<substrait::ReadRel>(); |
| 465 | + const auto& scan_node_options = |
| 466 | + checked_cast<const dataset::ScanNodeOptions&>(*declaration.options); |
| 467 | + auto dataset = |
| 468 | + dynamic_cast<dataset::FileSystemDataset*>(scan_node_options.dataset.get()); |
| 469 | + if (dataset == nullptr) { |
| 470 | + return Status::Invalid( |
| 471 | + "Can only convert scan node with FileSystemDataset to a Substrait plan."); |
| 472 | + } |
| 473 | + // set schema |
| 474 | + ARROW_ASSIGN_OR_RAISE(auto named_struct, |
| 475 | + ToProto(*dataset->schema(), ext_set, conversion_options)); |
| 476 | + read_rel->set_allocated_base_schema(named_struct.release()); |
| 477 | + |
| 478 | + // set local files |
| 479 | + auto read_rel_lfs = make_unique<substrait::ReadRel_LocalFiles>(); |
| 480 | + for (const auto& file : dataset->files()) { |
| 481 | + auto read_rel_lfs_ffs = make_unique<substrait::ReadRel_LocalFiles_FileOrFiles>(); |
| 482 | + read_rel_lfs_ffs->set_uri_path(UriFromAbsolutePath(file)); |
| 483 | + // set file format |
| 484 | + auto format_type_name = dataset->format()->type_name(); |
| 485 | + if (format_type_name == "parquet") { |
| 486 | + read_rel_lfs_ffs->set_allocated_parquet( |
| 487 | + new substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions()); |
| 488 | + } else if (format_type_name == "ipc") { |
| 489 | + read_rel_lfs_ffs->set_allocated_arrow( |
| 490 | + new substrait::ReadRel::LocalFiles::FileOrFiles::ArrowReadOptions()); |
| 491 | + } else if (format_type_name == "orc") { |
| 492 | + read_rel_lfs_ffs->set_allocated_orc( |
| 493 | + new substrait::ReadRel::LocalFiles::FileOrFiles::OrcReadOptions()); |
| 494 | + } else { |
| 495 | + return Status::NotImplemented("Unsupported file type: ", format_type_name); |
| 496 | + } |
| 497 | + read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); |
| 498 | + } |
| 499 | + read_rel->set_allocated_local_files(read_rel_lfs.release()); |
| 500 | + return std::move(read_rel); |
| 501 | +} |
| 502 | + |
| 503 | +Result<std::unique_ptr<substrait::FilterRel>> FilterRelationConverter( |
| 504 | + const std::shared_ptr<Schema>& schema, const compute::Declaration& declaration, |
| 505 | + ExtensionSet* ext_set, const ConversionOptions& conversion_options) { |
| 506 | + auto filter_rel = make_unique<substrait::FilterRel>(); |
| 507 | + const auto& filter_node_options = |
| 508 | + checked_cast<const compute::FilterNodeOptions&>(*(declaration.options)); |
| 509 | + |
| 510 | + auto filter_expr = filter_node_options.filter_expression; |
| 511 | + compute::Expression bound_expression; |
| 512 | + if (!filter_expr.IsBound()) { |
| 513 | + ARROW_ASSIGN_OR_RAISE(bound_expression, filter_expr.Bind(*schema)); |
| 514 | + } |
| 515 | + |
| 516 | + if (declaration.inputs.size() == 0) { |
| 517 | + return Status::Invalid("Filter node doesn't have an input."); |
| 518 | + } |
| 519 | + |
| 520 | + // handling input |
| 521 | + auto declr_input = declaration.inputs[0]; |
| 522 | + ARROW_ASSIGN_OR_RAISE( |
| 523 | + auto input_rel, |
| 524 | + ToProto(util::get<compute::Declaration>(declr_input), ext_set, conversion_options)); |
| 525 | + filter_rel->set_allocated_input(input_rel.release()); |
| 526 | + |
| 527 | + ARROW_ASSIGN_OR_RAISE(auto subs_expr, |
| 528 | + ToProto(bound_expression, ext_set, conversion_options)); |
| 529 | + filter_rel->set_allocated_condition(subs_expr.release()); |
| 530 | + return std::move(filter_rel); |
| 531 | +} |
| 532 | + |
| 533 | +} // namespace |
| 534 | + |
| 535 | +Status SerializeAndCombineRelations(const compute::Declaration& declaration, |
| 536 | + ExtensionSet* ext_set, |
| 537 | + std::unique_ptr<substrait::Rel>* rel, |
| 538 | + const ConversionOptions& conversion_options) { |
| 539 | + const auto& factory_name = declaration.factory_name; |
| 540 | + ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); |
| 541 | + // Note that the sink declaration factory doesn't exist for serialization as |
| 542 | + // Substrait doesn't deal with a sink node definition |
| 543 | + |
| 544 | + if (factory_name == "scan") { |
| 545 | + ARROW_ASSIGN_OR_RAISE( |
| 546 | + auto read_rel, |
| 547 | + ScanRelationConverter(schema, declaration, ext_set, conversion_options)); |
| 548 | + (*rel)->set_allocated_read(read_rel.release()); |
| 549 | + } else if (factory_name == "filter") { |
| 550 | + ARROW_ASSIGN_OR_RAISE( |
| 551 | + auto filter_rel, |
| 552 | + FilterRelationConverter(schema, declaration, ext_set, conversion_options)); |
| 553 | + (*rel)->set_allocated_filter(filter_rel.release()); |
| 554 | + } else if (factory_name == "sink") { |
| 555 | + // Generally when a plan is deserialized the declaration will be a sink declaration. |
| 556 | + // Since there is no Sink relation in substrait, this function would be recursively |
| 557 | + // called on the input of the Sink declaration. |
| 558 | + auto sink_input_decl = util::get<compute::Declaration>(declaration.inputs[0]); |
| 559 | + RETURN_NOT_OK( |
| 560 | + SerializeAndCombineRelations(sink_input_decl, ext_set, rel, conversion_options)); |
| 561 | + } else { |
| 562 | + return Status::NotImplemented("Factory ", factory_name, |
| 563 | + " not implemented for roundtripping."); |
| 564 | + } |
| 565 | + |
| 566 | + return Status::OK(); |
| 567 | +} |
| 568 | + |
| 569 | +Result<std::unique_ptr<substrait::Rel>> ToProto( |
| 570 | + const compute::Declaration& declr, ExtensionSet* ext_set, |
| 571 | + const ConversionOptions& conversion_options) { |
| 572 | + auto rel = make_unique<substrait::Rel>(); |
| 573 | + RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); |
| 574 | + return std::move(rel); |
| 575 | +} |
| 576 | + |
424 | 577 | } // namespace engine
|
425 | 578 | } // namespace arrow
|
0 commit comments