-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add mean IOU op. #10519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add mean IOU op. #10519
Changes from 5 commits
7732d89
3bb246c
f161d69
5855926
c7eb88b
13e7b5a
70ae3d2
ea505a4
e706336
8c8d004
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/fluid/operators/mean_iou_op.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| class MeanIoUOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| void InferShape(framework::InferShapeContext* ctx) const override { | ||
| PADDLE_ENFORCE(ctx->HasInput("predictions"), | ||
| "Input (predictions) of MeanIoU op should not be null."); | ||
| PADDLE_ENFORCE(ctx->HasInput("labels"), | ||
| "Input (labels) of MeanIoU op should not be null."); | ||
| PADDLE_ENFORCE(ctx->HasOutput("out_mean_iou"), | ||
| "Output (out_mean_iou) of MeanIoU op should not be null."); | ||
| PADDLE_ENFORCE(ctx->HasOutput("out_wrong"), | ||
| "Output (out_wrong) of MeanIoU op should not be null."); | ||
| PADDLE_ENFORCE(ctx->HasOutput("out_correct"), | ||
| "Output (out_wrong) of MeanIoU op should not be null."); | ||
|
|
||
| int64_t num_classes = | ||
| static_cast<int64_t>(ctx->Attrs().Get<int>("num_classes")); | ||
|
|
||
| ctx->SetOutputDim("out_mean_iou", {1}); | ||
| ctx->SetOutputDim("out_wrong", {num_classes}); | ||
| ctx->SetOutputDim("out_correct", {num_classes}); | ||
| } | ||
|
|
||
| protected: | ||
| framework::OpKernelType GetExpectedKernelType( | ||
| const framework::ExecutionContext& ctx) const override { | ||
| return framework::OpKernelType( | ||
| framework::ToDataType(ctx.Input<Tensor>("predictions")->type()), | ||
| ctx.GetPlace()); | ||
| } | ||
| }; | ||
|
|
||
| class MeanIoUOpMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| MeanIoUOpMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInput("predictions", | ||
| "A Tensor of prediction results for semantic labels" | ||
| " with type int32 or int64."); | ||
| AddInput("labels", | ||
| "A Tensor of ground truth labels with type int32 or int64." | ||
| "Its shape should be the same as Input(predictions)."); | ||
| AddInput("in_wrongs", | ||
| "A list of Tensor with shape " | ||
| "[num_classes]. They are used to collect wrong number among " | ||
| "batches. Empty list is also valid here.") | ||
| .AsDuplicable() | ||
| .AsDispensable(); | ||
| AddInput( | ||
| "in_corrects", | ||
| "A list of Tensor with shape " | ||
| "[num_classes]. They are used to collect correct number among batches. " | ||
| "Empty list is also valid here.") | ||
| .AsDuplicable() | ||
| .AsDispensable(); | ||
| AddInput("in_mean_iou", | ||
| "A list of Tensor that Output(mean_iou) should " | ||
| "be added to. Empty list is also valid here.") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in_wrongs, in_corrects, in_mean_iou是干啥的?和out_wrong/correct/mean_iou有啥区别?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in_wrongs, in_corrects, in_mean_iou之前当前batch之前累计的数据,加上当前batch的统计结果,就得到:out_wrong/correct/mean_iou |
||
| .AsDuplicable() | ||
| .AsDispensable(); | ||
| AddOutput("out_mean_iou", | ||
| "A Tensor representing the" | ||
| " mean intersection-over-union."); | ||
|
||
| AddOutput("out_wrong", "A Tensor with shape [num_classes]. "); | ||
| AddOutput("out_correct", "A Tensor with shape [num_classes]. "); | ||
|
||
| AddAttr<int>("num_classes", "The possible number of labels."); | ||
|
|
||
| AddComment(R"DOC( | ||
| mean-IOU Operator. | ||
| Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation, which first computes the IOU for each semantic class and then computes the average over classes. IOU is defined as follows: IOU = true_positive / (true_positive + false_positive + false_negative). The predictions are accumulated in a confusion matrix and mean-IOU is then calculated from it. | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we have iou_similarity_op: The doc here better to give more details for the difference.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thx. FIxed. |
||
| )DOC"); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
| REGISTER_OPERATOR(mean_iou, ops::MeanIoUOp, ops::MeanIoUOpMaker, | ||
| paddle::framework::EmptyGradOpMaker); | ||
| REGISTER_OP_CPU_KERNEL(mean_iou, ops::MeanIoUKernel<int>, | ||
| ops::MeanIoUKernel<int64_t>); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 上面文档描述里是支持int32和int64,这里没有注册int32。
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thx. FIxed. |
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,164 @@ | ||||||||||||||||||||||||||||||||||||||||||||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||||
| you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||||
| You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||
| See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||
| limitations under the License. */ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| #include "paddle/fluid/operators/math/math_function.h" | ||||||||||||||||||||||||||||||||||||||||||||||
| #include "paddle/fluid/operators/mean_iou_op.h" | ||||||||||||||||||||||||||||||||||||||||||||||
| #include "paddle/fluid/platform/cuda_primitives.h" | ||||||||||||||||||||||||||||||||||||||||||||||
| #include "paddle/fluid/platform/gpu_info.h" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| namespace paddle { | ||||||||||||||||||||||||||||||||||||||||||||||
| namespace operators { | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| using platform::PADDLE_CUDA_NUM_THREADS; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| #define CUDA_1D_KERNEL_LOOP(i, n) \ | ||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ | ||||||||||||||||||||||||||||||||||||||||||||||
| i += blockDim.x * gridDim.x) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||||||
| __global__ void CountCUDAKernel(const int num_classes, const int count, | ||||||||||||||||||||||||||||||||||||||||||||||
| const T* predictions, const T* labels, | ||||||||||||||||||||||||||||||||||||||||||||||
| int* wrong, int* correct) { | ||||||||||||||||||||||||||||||||||||||||||||||
| extern __shared__ int blcok_cache[]; | ||||||||||||||||||||||||||||||||||||||||||||||
| int* wrong_c = blcok_cache; | ||||||||||||||||||||||||||||||||||||||||||||||
| int* correct_c = blcok_cache + num_classes; | ||||||||||||||||||||||||||||||||||||||||||||||
| // init cache | ||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = threadIdx.x; i < num_classes * 2; i += blockDim.x) { | ||||||||||||||||||||||||||||||||||||||||||||||
| blcok_cache[i] = 0; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| __syncthreads(); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| T pred; | ||||||||||||||||||||||||||||||||||||||||||||||
| T label; | ||||||||||||||||||||||||||||||||||||||||||||||
| CUDA_1D_KERNEL_LOOP(i, count) { | ||||||||||||||||||||||||||||||||||||||||||||||
| pred = predictions[i]; | ||||||||||||||||||||||||||||||||||||||||||||||
| label = labels[i]; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (pred == label) { | ||||||||||||||||||||||||||||||||||||||||||||||
| atomicAdd(correct_c + pred, 1); | ||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||
| atomicAdd(wrong_c + pred, 1); | ||||||||||||||||||||||||||||||||||||||||||||||
| atomicAdd(wrong_c + label, 1); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| __syncthreads(); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = threadIdx.x; i < num_classes; i += blockDim.x) { | ||||||||||||||||||||||||||||||||||||||||||||||
| atomicAdd(wrong + i, wrong_c[i]); | ||||||||||||||||||||||||||||||||||||||||||||||
| atomicAdd(correct + i, correct_c[i]); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果num_classes较小, predictions的shape较大,会导致这个kernel的性能非常低效,其实感觉类似这样的kernel,先CPU即可,后续最好评估下时间。
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| __global__ void ComputeIoUCUDAKernel(const int num_classes, int* wrong, | ||||||||||||||||||||||||||||||||||||||||||||||
| int* correct, float* ious, | ||||||||||||||||||||||||||||||||||||||||||||||
| int* valid_count) { | ||||||||||||||||||||||||||||||||||||||||||||||
| __shared__ int valid_count_c; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (threadIdx.x == 0) { | ||||||||||||||||||||||||||||||||||||||||||||||
| valid_count_c = 0; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| __syncthreads(); | ||||||||||||||||||||||||||||||||||||||||||||||
| CUDA_1D_KERNEL_LOOP(i, num_classes) { | ||||||||||||||||||||||||||||||||||||||||||||||
| int wrong_n = wrong[i]; | ||||||||||||||||||||||||||||||||||||||||||||||
| int correct_n = correct[i]; | ||||||||||||||||||||||||||||||||||||||||||||||
| int denominator = wrong_n + correct_n; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (denominator > 0) { | ||||||||||||||||||||||||||||||||||||||||||||||
| atomicAdd(&valid_count_c, 1); | ||||||||||||||||||||||||||||||||||||||||||||||
| ious[i] = static_cast<float>(correct_n) / denominator; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| __syncthreads(); | ||||||||||||||||||||||||||||||||||||||||||||||
| if (threadIdx.x == 0) { | ||||||||||||||||||||||||||||||||||||||||||||||
| atomicAdd(valid_count, valid_count_c); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||||||
| class MeanIoUCUDAOpKernel : public framework::OpKernel<T> { | ||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||
| void Compute(const framework::ExecutionContext& ctx) const override { | ||||||||||||||||||||||||||||||||||||||||||||||
| auto& place = *ctx.template device_context<platform::CUDADeviceContext>() | ||||||||||||||||||||||||||||||||||||||||||||||
| .eigen_device(); | ||||||||||||||||||||||||||||||||||||||||||||||
| // get input and output tensor | ||||||||||||||||||||||||||||||||||||||||||||||
| auto* predictions = ctx.Input<Tensor>("predictions"); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto* labels = ctx.Input<Tensor>("labels"); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto* out_mean_iou = ctx.Output<Tensor>("out_mean_iou"); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto* out_wrong = ctx.Output<Tensor>("out_wrong"); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto* out_correct = ctx.Output<Tensor>("out_correct"); | ||||||||||||||||||||||||||||||||||||||||||||||
| int num_classes = static_cast<int>(ctx.Attr<int>("num_classes")); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // get data ptr | ||||||||||||||||||||||||||||||||||||||||||||||
| const T* predictions_data = predictions->data<T>(); | ||||||||||||||||||||||||||||||||||||||||||||||
| const T* labels_data = labels->data<T>(); | ||||||||||||||||||||||||||||||||||||||||||||||
| int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace()); | ||||||||||||||||||||||||||||||||||||||||||||||
| int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace()); | ||||||||||||||||||||||||||||||||||||||||||||||
| float* out_mean_iou_data = | ||||||||||||||||||||||||||||||||||||||||||||||
| out_mean_iou->mutable_data<float>(ctx.GetPlace()); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // get eigen tensor | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
| auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto out_correct_t = EigenTensor<int, 1>::From(*out_correct); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // Tmp tensor | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
| Tensor valid_count; | ||||||||||||||||||||||||||||||||||||||||||||||
| int* valid_count_data = valid_count.mutable_data<int>({1}, ctx.GetPlace()); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto valid_count_t = EigenTensor<int, 1>::From(valid_count); | ||||||||||||||||||||||||||||||||||||||||||||||
| Tensor ious; | ||||||||||||||||||||||||||||||||||||||||||||||
| float* ious_data = ious.mutable_data<float>( | ||||||||||||||||||||||||||||||||||||||||||||||
| {static_cast<int64_t>(num_classes)}, ctx.GetPlace()); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto ious_t = EigenTensor<float, 1>::From(ious); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // init out_wrong, out_correct and out_mean_iou | ||||||||||||||||||||||||||||||||||||||||||||||
| out_wrong_t.device(place) = out_wrong_t.constant(0); | ||||||||||||||||||||||||||||||||||||||||||||||
| out_correct_t.device(place) = out_correct_t.constant(0); | ||||||||||||||||||||||||||||||||||||||||||||||
| valid_count_t.device(place) = valid_count_t.constant(0); | ||||||||||||||||||||||||||||||||||||||||||||||
| out_mean_iou_t.device(place) = out_mean_iou_t.constant(0.0f); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // collect pre wrong, correct and mean_iou | ||||||||||||||||||||||||||||||||||||||||||||||
| auto in_mean_ious = ctx.MultiInput<Tensor>("in_mean_iou"); | ||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < in_mean_ious.size(); ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||
| out_mean_iou_t.device(place) += | ||||||||||||||||||||||||||||||||||||||||||||||
| EigenTensor<float, 1>::From(*in_mean_ious[i]); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| auto in_wrongs = ctx.MultiInput<Tensor>("in_wrongs"); | ||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < in_wrongs.size(); ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||
| out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrongs[i]); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| auto in_corrects = ctx.MultiInput<Tensor>("in_corrects"); | ||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < in_corrects.size(); ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||
| out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_corrects[i]); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| // compute | ||||||||||||||||||||||||||||||||||||||||||||||
| auto stream = ctx.cuda_device_context().stream(); | ||||||||||||||||||||||||||||||||||||||||||||||
| int block = PADDLE_CUDA_NUM_THREADS; | ||||||||||||||||||||||||||||||||||||||||||||||
| int grid = (predictions->numel() + block - 1) / block; | ||||||||||||||||||||||||||||||||||||||||||||||
| int cache_size = (num_classes * 2 + 1) * sizeof(int); | ||||||||||||||||||||||||||||||||||||||||||||||
| CountCUDAKernel<T><<<grid, block, cache_size, stream>>>( | ||||||||||||||||||||||||||||||||||||||||||||||
| num_classes, predictions->numel(), predictions_data, labels_data, | ||||||||||||||||||||||||||||||||||||||||||||||
| out_wrong_data, out_correct_data); | ||||||||||||||||||||||||||||||||||||||||||||||
| ctx.device_context().Wait(); | ||||||||||||||||||||||||||||||||||||||||||||||
| grid = (num_classes + block - 1) / block; | ||||||||||||||||||||||||||||||||||||||||||||||
| ComputeIoUCUDAKernel<<<grid, block, 0, stream>>>( | ||||||||||||||||||||||||||||||||||||||||||||||
| num_classes, out_wrong_data, out_correct_data, ious_data, | ||||||||||||||||||||||||||||||||||||||||||||||
| valid_count_data); | ||||||||||||||||||||||||||||||||||||||||||||||
| ctx.device_context().Wait(); | ||||||||||||||||||||||||||||||||||||||||||||||
| out_mean_iou_t.device(place) += ious_t.sum() / valid_count_t.cast<float>(); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace operators | ||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace paddle | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| namespace ops = paddle::operators; | ||||||||||||||||||||||||||||||||||||||||||||||
| REGISTER_OP_CUDA_KERNEL(mean_iou, ops::MeanIoUCUDAOpKernel<int>, | ||||||||||||||||||||||||||||||||||||||||||||||
| ops::MeanIoUKernel<int64_t>); | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
| #include <algorithm> | ||
| #include "paddle/fluid/framework/op_registry.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
| using Tensor = framework::Tensor; | ||
|
|
||
| template <typename T, int D, int MajorType = Eigen::RowMajor, | ||
| typename IndexType = Eigen::DenseIndex> | ||
| using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>; | ||
|
|
||
| template <typename T> | ||
| class MeanIoUKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& ctx) const override { | ||
| auto& place = *ctx.template device_context<platform::CPUDeviceContext>() | ||
| .eigen_device(); | ||
| // get input and output tensor | ||
| auto* predictions = ctx.Input<Tensor>("predictions"); | ||
| auto* labels = ctx.Input<Tensor>("labels"); | ||
| auto* out_mean_iou = ctx.Output<Tensor>("out_mean_iou"); | ||
| auto* out_wrong = ctx.Output<Tensor>("out_wrong"); | ||
| auto* out_correct = ctx.Output<Tensor>("out_correct"); | ||
| int num_classes = static_cast<int>(ctx.Attr<int>("num_classes")); | ||
|
|
||
| // get data ptr | ||
| const T* predictions_data = predictions->data<T>(); | ||
| const T* labels_data = labels->data<T>(); | ||
| float* out_mean_iou_data = | ||
| out_mean_iou->mutable_data<float>(ctx.GetPlace()); | ||
| int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace()); | ||
| int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace()); | ||
|
|
||
| // get eigen tensor | ||
| auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou); | ||
| auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong); | ||
| auto out_correct_t = EigenTensor<int, 1>::From(*out_correct); | ||
|
|
||
| // Tmp tensor | ||
| Tensor denominator; | ||
| Tensor valid_count; | ||
| Tensor iou_sum; | ||
|
|
||
| // get data ptr of tmp tensor | ||
| int* denominator_data = denominator.mutable_data<int>( | ||
| {static_cast<int64_t>(num_classes)}, ctx.GetPlace()); | ||
| int* valid_count_data = valid_count.mutable_data<int>({1}, ctx.GetPlace()); | ||
| float* iou_sum_data = iou_sum.mutable_data<float>({1}, ctx.GetPlace()); | ||
|
|
||
| // get eigen tensor of tmp tensor | ||
| auto denominator_t = EigenTensor<int, 1>::From(denominator); | ||
| auto valid_count_t = EigenTensor<int, 1>::From(valid_count); | ||
| auto iou_sum_t = EigenTensor<float, 1>::From(iou_sum); | ||
|
|
||
| // init out_wrong, out_correct and out_mean_iou | ||
| out_wrong_t = out_wrong_t.constant(0); | ||
| out_correct_t = out_correct_t.constant(0); | ||
| out_mean_iou_t = out_mean_iou_t.constant(0); | ||
|
|
||
| // collect pre wrong, correct and mean_iou | ||
| auto in_mean_ious = ctx.MultiInput<Tensor>("in_mean_iou"); | ||
| for (int i = 0; i < in_mean_ious.size(); ++i) { | ||
| out_mean_iou_t.device(place) += | ||
| EigenTensor<float, 1>::From(*in_mean_ious[i]); | ||
| } | ||
| auto in_wrongs = ctx.MultiInput<Tensor>("in_wrongs"); | ||
| for (int i = 0; i < in_wrongs.size(); ++i) { | ||
| out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrongs[i]); | ||
| } | ||
| auto in_corrects = ctx.MultiInput<Tensor>("in_corrects"); | ||
| for (int i = 0; i < in_corrects.size(); ++i) { | ||
| out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_corrects[i]); | ||
| } | ||
|
|
||
| // compute | ||
| for (int i = 0; i < predictions->numel(); ++i) { | ||
| if (predictions_data[i] == labels_data[i]) { | ||
| out_correct_data[predictions_data[i]] += 1; | ||
| } else { | ||
| out_wrong_data[labels_data[i]] += 1; | ||
| out_wrong_data[predictions_data[i]] += 1; | ||
| } | ||
| } | ||
|
|
||
| denominator_t = out_wrong_t + out_correct_t; | ||
| valid_count_t = | ||
| (denominator_t > denominator_t.constant(0.0f)).cast<int>().sum(); | ||
|
|
||
| for (int i = 0; i < num_classes; ++i) { | ||
| if (denominator_data[i] == 0) { | ||
| denominator_data[i] = 1; | ||
| } | ||
| } | ||
|
|
||
| iou_sum_t = | ||
| (out_correct_t.cast<float>() / denominator_t.cast<float>()).sum(); | ||
| out_mean_iou_data[0] += (iou_sum_data[0] / valid_count_data[0]); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to give the shape
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx. FIxed.