2 changes: 1 addition & 1 deletion paddle/fluid/operators/bilinear_tensor_product_op.h
@@ -16,7 +16,7 @@ limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {
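The one-line include swap above is the shape of the whole PR: math/math_function.h exposed free functions such as math::matmul and math::gemv that took the device context as their first argument, while math/blas.h provides a Blas wrapper obtained once per kernel through math::GetBlas. A minimal sketch of the new call pattern, assuming tensors A, B, and C are already allocated inside an OpKernel::Compute body (the tensor names are illustrative, not from the PR):

    // Fetch the device context once, as the kernels below already do.
    auto& dev_ctx = context.template device_context<DeviceContext>();
    // The wrapper captures dev_ctx, so later GEMM/GEMV calls no longer
    // thread the context through every argument list.
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    // C = A * B, replacing math::matmul<DeviceContext, T>(dev_ctx, A,
    // false, B, false, T(1.0), &C, T(0.0)).
    blas.MatMul(A, B, &C);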
16 changes: 7 additions & 9 deletions paddle/fluid/operators/conv_op.h
@@ -17,9 +17,9 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/vol2col.h"

namespace paddle {
@@ -161,6 +161,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;

auto& dev_ctx = context.template device_context<DeviceContext>();
+auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
@@ -186,8 +187,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-math::matmul<DeviceContext, T>(dev_ctx, filter_slice, false, col_matrix,
-                               false, T(1.0), &out_slice, T(0.0));
+blas.MatMul(filter_slice, col_matrix, &out_slice);
}
}
}
@@ -274,6 +274,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {

math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
+auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);

if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
@@ -303,9 +304,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(in_grad_slice);
col_matrix.Resize(col_matrix_shape);
}
-math::matmul<DeviceContext, T>(dev_ctx, filter_slice, true,
-                               out_grad_slice, false, T(1.0),
-                               &col_matrix, T(0.0));
+blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix);

if (is_expand && data_dim == 2U) {
col2im(dev_ctx, col, dilations, strides,
@@ -352,9 +351,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
-math::matmul<DeviceContext, T>(dev_ctx, out_grad_slice, false,
-                               col_matrix, true, T(1.0),
-                               &filter_grad_slice, T(1.0));
+blas.MatMul(out_grad_slice, false, col_matrix, true,
+            &filter_grad_slice);
}
}
}
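The conv kernels exercise both MatMul shapes introduced by the wrapper: a two-operand form for the forward GEMM and a four-flag form for the backward GEMMs. Side by side, copied from the hunks above (the alpha/beta defaults applied by the short overloads are defined in math/blas.h and not restated in this diff):

    // Forward: out_slice = filter_slice * col_matrix.
    //   old: math::matmul<DeviceContext, T>(dev_ctx, filter_slice, false,
    //            col_matrix, false, T(1.0), &out_slice, T(0.0));
    blas.MatMul(filter_slice, col_matrix, &out_slice);

    // Backward w.r.t. input: transpose flags stay explicit.
    //   old: math::matmul<DeviceContext, T>(dev_ctx, filter_slice, true,
    //            out_grad_slice, false, T(1.0), &col_matrix, T(0.0));
    blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix);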
16 changes: 6 additions & 10 deletions paddle/fluid/operators/conv_transpose_op.h
@@ -16,8 +16,8 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/vol2col.h"

namespace paddle {
@@ -118,6 +118,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
output->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
+auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
set_zero(dev_ctx, output, static_cast<T>(0));

math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
@@ -134,9 +135,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {

// col_matrix = filter * input_batch
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-math::matmul<DeviceContext, T>(dev_ctx, filter, true, input_batch, false,
-                               static_cast<T>(1.0), &col_matrix,
-                               static_cast<T>(0.0));
+blas.MatMul(filter, true, input_batch, false, &col_matrix);

if (data_dim == 2U) {
// col2im: col_matrix -> dy
@@ -213,6 +212,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// im2col + gemm (similar to conv-forward)
// input need to compute gradient
auto& dev_ctx = context.template device_context<DeviceContext>();
+auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (input_grad || filter_grad) {
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
@@ -267,9 +267,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
// d, h, w)
-math::matmul<DeviceContext, T>(
-    dev_ctx, filter, false, col_matrix, false, static_cast<T>(1.0),
-    &input_grad_batch, static_cast<T>(0.0));
+blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
}
if (filter_grad) {
// input batch
@@ -279,9 +277,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
// k_h * k_w)
-math::matmul<DeviceContext, T>(dev_ctx, in_batch, false, col_matrix,
-                               true, static_cast<T>(1.0),
-                               &filter_grad_, static_cast<T>(1.0));
+blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
}
}
}
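One detail worth flagging in both this file and conv_op.h: the old filter-gradient calls passed beta = static_cast<T>(1.0) so each batch's contribution accumulates into the gradient, but the replacement uses the overload that spells out neither alpha nor beta:

    // old: math::matmul<DeviceContext, T>(dev_ctx, in_batch, false,
    //          col_matrix, true, static_cast<T>(1.0), &filter_grad_,
    //          static_cast<T>(1.0));  // beta = 1 accumulates
    blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);

The rewrite is behavior-preserving only if that overload's default beta keeps the accumulation; the defaults live in paddle/fluid/operators/math/blas.h and are worth checking while reviewing these hunks.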
5 changes: 2 additions & 3 deletions paddle/fluid/operators/gru_unit_op.h
@@ -14,11 +14,10 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/math_function.h"

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {
14 changes: 7 additions & 7 deletions paddle/fluid/operators/layer_norm_op.h
@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

#include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
@@ -46,9 +46,9 @@ class RowwiseMean2D<platform::CUDADeviceContext, T> {
}
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* out) {
-math::gemv<platform::CUDADeviceContext, T>(
-    context, false, left_, right_, 1., input.data<T>(), divisor_.data<T>(),
-    0., out->data<T>());
+math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
+    false, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
+    out->data<T>());
}

private:
@@ -93,9 +93,9 @@ class ColwiseSum2D<platform::CUDADeviceContext, T> {

void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* out) {
-math::gemv<platform::CUDADeviceContext, T>(
-    context, true, left_, right_, 1., input.data<T>(), divisor_.data<T>(),
-    0., out->data<T>());
+math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
+    true, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
+    out->data<T>());
}

private:
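Here the rewrite is a one-for-one renaming: every argument of math::gemv survives in the same order, and only the context moves into the wrapper returned by math::GetBlas. GEMV computes y = alpha * op(A) * x + beta * y, which is how both reductions are expressed against the divisor_ vector each class prepares (its setup is outside these hunks):

    // RowwiseMean2D: trans = false, so out = input * divisor_, reducing
    // each row of the (left_ x right_) matrix to a single value.
    // ColwiseSum2D passes trans = true for the column-wise reduction.
    math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
        false, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
        out->data<T>());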
34 changes: 15 additions & 19 deletions paddle/fluid/operators/lstm_op.h
@@ -15,9 +15,9 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"

namespace paddle {
@@ -114,6 +114,7 @@ class LSTMKernel : public framework::OpKernel<T> {
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));

+auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
@@ -129,9 +130,8 @@ class LSTMKernel : public framework::OpKernel<T> {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
-math::matmul<DeviceContext, T>(device_ctx, pre_hidden_t, false, *weight,
-                               false, static_cast<T>(1.0), &gate_t,
-                               static_cast<T>(1.0));
+blas.MatMul(pre_hidden_t, false, *weight, false, static_cast<T>(1.0),
+            &gate_t, static_cast<T>(1.0));
} else if (hidden_t0) {
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skipped.
@@ -143,9 +143,8 @@ class LSTMKernel : public framework::OpKernel<T> {
Tensor ordered_h0;
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
&ordered_h0, true);
-math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false, *weight,
-                               false, static_cast<T>(1.0), &gate_t,
-                               static_cast<T>(1.0));
+blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
+            &gate_t, static_cast<T>(1.0));
}

lstm_value.gate_value = gate_t.data<T>();
@@ -282,6 +281,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {

auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
+auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
@@ -320,29 +320,25 @@ class LSTMGradKernel : public framework::OpKernel<T> {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
-math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
-                               static_cast<T>(1.0), &pre_hidden_g,
-                               static_cast<T>(1.0));
+blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
+            &pre_hidden_g, static_cast<T>(1.0));
if (weight_g) {
/* backward weight */
auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
-math::matmul<DeviceContext, T>(device_ctx, pre_hidden, true, gate_g,
-                               false, static_cast<T>(1.0), weight_g,
-                               static_cast<T>(1.0));
+blas.MatMul(pre_hidden, true, gate_g, false, static_cast<T>(1.0),
+            weight_g, static_cast<T>(1.0));
}
} else {
if (h0 && weight_g) {
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
&ordered_h0, true);
-math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true, gate_g,
-                               false, static_cast<T>(1.0), weight_g,
-                               static_cast<T>(1.0));
+blas.MatMul(ordered_h0, true, gate_g, false, static_cast<T>(1.0),
+            weight_g, static_cast<T>(1.0));
}
if (h0 && h0_g) {
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
-math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
-                               true, static_cast<T>(1.0),
-                               &ordered_h0_g, static_cast<T>(0.0));
+blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
+            &ordered_h0_g, static_cast<T>(0.0));
}
}
}
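Unlike the conv kernels, the LSTM rewrites keep alpha and beta explicit. In the forward pass gate_t already holds the input-to-hidden projection when the recurrent GEMM runs, so the product must be accumulated, not overwritten; the h0 gradient, by contrast, is freshly computed. Both calls are from the hunks above, annotated:

    // gate_t = 1.0 * pre_hidden_t * weight + 1.0 * gate_t
    // (beta = 1 adds the recurrent term to the gates already computed).
    blas.MatMul(pre_hidden_t, false, *weight, false, static_cast<T>(1.0),
                &gate_t, static_cast<T>(1.0));

    // ordered_h0_g receives a fresh result, so beta = 0 overwrites it.
    blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
                &ordered_h0_g, static_cast<T>(0.0));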