Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions paddle/fluid/framework/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,25 @@ class ReaderHolder {

ReaderBase* Get() const { return reader_.get(); }

void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
void ReInit() { reader_->ReInit(); }
void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out);
}
void ReInit() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit();
}

DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); }
DDim shape(size_t idx) const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shape(idx);
}
std::vector<DDim> shapes() const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shapes();
}
void set_shapes(const std::vector<DDim>& shapes) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->set_shapes(shapes);
}

Expand Down
38 changes: 21 additions & 17 deletions paddle/fluid/operators/reader/reader_op_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,25 @@ FileReaderMakerBase::FileReaderMakerBase(
}

void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(
!ctx->IsRuntime(),
"'FileReaderInferShape' should only be invoked during compile time.");

PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null.");
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes);

if (ctx->IsRuntime()) {
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}

void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
Expand All @@ -77,19 +79,21 @@ void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,

void DecoratedReaderInferShape::operator()(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"'DecoratedReaderInferShape' should only be invoked during "
"compile time.");

PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
"Input(UnderlyingReader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));

if (ctx->IsRuntime()) {
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
framework::VarDesc* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
framework::VarDesc* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
void DecoratedReaderInferVarType::operator()(
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
Expand Down