4 changes: 4 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -176,6 +176,10 @@ void TensorDistAttr::set_default_dynamic_dims(
dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
}

void TensorDistAttr::set_default_dynamic_dims(int64_t tensor_shape_size) {
dynamic_dims_ = std::vector<bool>(tensor_shape_size, false);
}

void TensorDistAttr::mark_annotated(const std::string& name) {
auto result = std::find(std::begin(fields_), std::end(fields_), name);
if (result != std::end(fields_)) {
2 changes: 2 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.h
@@ -156,6 +156,8 @@ class TEST_API TensorDistAttr {

void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape);

void set_default_dynamic_dims(int64_t tensor_shape_size);

const std::map<std::string, bool>& annotated() const { return annotated_; }

void set_annotated(const std::map<std::string, bool>& annotated);
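The new overload reproduces the existing default initialization when only the tensor's rank is known rather than its full shape. A minimal standalone sketch of the two overloads and how a caller would pick between them (the struct below is an illustration, not the real TensorDistAttr class):

```cpp
#include <cstdint>
#include <vector>

// Illustrative stand-in for TensorDistAttr: both overloads mark every
// dimension as non-dynamic, one taking the shape, one taking just the rank.
struct TensorDistAttrSketch {
  std::vector<bool> dynamic_dims_;

  void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape) {
    dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
  }

  void set_default_dynamic_dims(int64_t tensor_shape_size) {
    dynamic_dims_ = std::vector<bool>(tensor_shape_size, false);
  }
};

int main() {
  TensorDistAttrSketch attr;
  attr.set_default_dynamic_dims({8, 16, 32});  // rank inferred from the shape
  attr.set_default_dynamic_dims(3);            // rank passed directly
  return 0;
}
```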
256 changes: 255 additions & 1 deletion paddle/phi/infermeta/spmd_rules/dim_trans.cc
@@ -161,7 +161,7 @@ std::shared_ptr<DimTrans> make_split(const std::shared_ptr<DimTrans> dim,
// map between from idx in shape to new_shape
std::vector<int64_t> idx_map(shape.size(), -1);
for (int i = 0, n = static_cast<int>(shape.size()); i < n; ++i) {
- if (shape[id] != 1) {
+ if (shape[i] != 1) {
idx_map[i] = static_cast<int64_t>(new_shape.size());
new_shape.emplace_back(shape[i]);
}
@@ -272,6 +272,139 @@ std::vector<std::shared_ptr<DimTrans>> GetDimTrans(
return ret_dim_trans;
}

std::vector<std::shared_ptr<DimTrans>> GetDimTransCoShard(
const std::shared_ptr<DimTrans> dim_trans,
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& mesh_shape,
const std::vector<std::vector<int64_t>>& input_dims_mapping,
const std::set<int64_t>& sharded_input_dims,
std::vector<std::vector<bool>>* shardable,
std::set<int64_t>* seen_dims) {
DimTrans::Type type = dim_trans->type();
std::vector<std::shared_ptr<DimTrans>> ret_dim_trans;

if (type == DimTrans::Type::INPUTDIM) {
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(dim_trans);
int64_t dim = inputdim->input_dim();
seen_dims->insert(dim);

if (sharded_input_dims.count(dim) > 0) {
ret_dim_trans.push_back(dim_trans);
}
} else if (type == DimTrans::Type::FLATTEN) {
std::shared_ptr<Flatten> flatten =
std::dynamic_pointer_cast<Flatten>(dim_trans);
const std::vector<std::shared_ptr<DimTrans>>& inputs = flatten->inputs();

int64_t nmesh = (*shardable)[0].size(); // NOLINT
int64_t mesh_shape_prod = 1;

int last_shard_idx = -1;
int64_t first_shard_idx = -1;
int64_t first_sharded_shape = -1;

for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) {
std::shared_ptr<DimTrans> input = inputs[i];
if (input->type() != DimTrans::Type::INPUTDIM) {
break;
}
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(input);
if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
if (first_shard_idx == -1) {
first_shard_idx = i;
first_sharded_shape = input_shape[inputdim->input_dim()];
}
for (const auto& dim : input_dims_mapping[inputdim->input_dim()]) {
mesh_shape_prod *= mesh_shape[dim];
}
if (first_sharded_shape % mesh_shape_prod == 0) {
ret_dim_trans.push_back(inputdim);
} else {
break;
}
} else {
break;
}
last_shard_idx = i;
}

for (int i = last_shard_idx + 1, n = static_cast<int>(inputs.size()); i < n;
i++) {
std::shared_ptr<DimTrans> input = inputs[i];
if (input->type() == DimTrans::Type::INPUTDIM) {
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(input);
(*shardable)[inputdim->input_dim()].assign(nmesh, false);
}

GetDimTransCoShard(input,
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims,
shardable,
seen_dims);
}
} else if (type == DimTrans::Type::SPLIT) {
std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
std::vector<std::shared_ptr<DimTrans>> dims =
GetDimTransCoShard(split->input(),
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims,
shardable,
seen_dims);
int64_t ret_size = split->local_split_shape_value();

if (split->split_id() == 0) {
int64_t mesh_shape_prod = 1;
int64_t first_shard_idx = -1;
int64_t first_sharded_shape = -1;
for (const auto& dim : dims) {
PADDLE_ENFORCE_EQ(dim->type(),
DimTrans::Type::INPUTDIM,
common::errors::InvalidArgument(
"The returned dim_trans must be INPUTDIM."));
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(dim);
int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
int64_t input_axis = inputdim->input_dim();

// Check whether the sharded dim can be sharded on
// each mesh dimension. The dimension should be
// divisible by the mesh size that it is sharded on
for (int64_t imesh = 0; imesh < nmesh; imesh++) {
(*shardable)[input_axis][imesh] = (ret_size % mesh_shape[imesh] == 0);
}

if (first_shard_idx == -1) {
first_shard_idx = input_axis;
first_sharded_shape = input_shape[input_axis];
}

if (sharded_input_dims.count(input_axis) > 0) {
for (const auto& dim : input_dims_mapping[input_axis]) {
mesh_shape_prod *= mesh_shape[dim];
}
if ((ret_size % mesh_shape_prod == 0) &&
(first_sharded_shape % mesh_shape_prod == 0)) {
ret_dim_trans.push_back(dim);
} else {
break;
}
} else {
break;
}
}
}
} else if (type == DimTrans::Type::SINGLETON) {
}
return ret_dim_trans;
}

void GetUsedInputDim(const std::shared_ptr<DimTrans> dim_trans,
std::set<int64_t>* seen_dims) {
if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
@@ -311,6 +444,27 @@ InferFromDimTrans(const DistMetaTensor& input_spec,
return InferFromDimTrans(input_spec, input_shape, dim_trans);
}

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
const DistMetaTensor& input_spec,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
auto input_shape = phi::vectorize(input_spec.dims());
// deal with reshape xshape in dynamic
if (input_shape[0] == 0 &&
input_shape.size() !=
input_spec.dist_attr().multi_dims_mapping().size()) {
input_shape.erase(input_shape.begin());
}
PADDLE_ENFORCE_EQ(input_shape.size(),
input_spec.dist_attr().multi_dims_mapping().size(),
common::errors::InvalidArgument(
"The Tensor X's rank [%d] and X's "
"dims_mapping size [%d] are not matched.",
input_shape.size(),
input_spec.dist_attr().multi_dims_mapping().size()));
return InferFromDimTransCoShard(input_spec, input_shape, dim_trans);
}

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTrans(const DistMetaTensor& input,
const std::vector<int64_t>& input_shape,
@@ -400,4 +554,104 @@ InferFromDimTrans(const DistMetaTensor& input,
return {new_input_dims_mapping, out_dims_mapping};
}

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
const DistMetaTensor& input,
const std::vector<int64_t>& input_shape,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
const std::vector<std::vector<int64_t>>& input_dims_mapping =
input.dist_attr().multi_dims_mapping();
const ProcessMesh& mesh = input.dist_attr().process_mesh();
const std::vector<int64_t>& mesh_shape = mesh.shape();

std::set<int64_t> sharded_input_dims;
for (int64_t i = 0, n = static_cast<int64_t>(input_dims_mapping.size());
i < n;
++i) {
if (std::any_of(input_dims_mapping[i].begin(),
input_dims_mapping[i].end(),
[](int64_t dim) { return dim > -1; })) {
sharded_input_dims.insert(i);
}
}
int64_t ndim = static_cast<int64_t>(input_shape.size());
int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
std::vector<std::vector<bool>> shardable(ndim,
std::vector<bool>(nmesh, true));

std::set<int64_t> seen_input_dims;
for (const std::shared_ptr<DimTrans>& trans : dim_trans) {
GetUsedInputDim(trans, &seen_input_dims);
}

for (int64_t idim = 0; idim < ndim; idim++) {
bool seen = seen_input_dims.count(idim);
if (!seen) {
shardable[idim].assign(nmesh, seen);
}
}

// get the map from sharded input dimensions to output dimensions.
// key is src dim, value is dst dim.
std::vector<int64_t> dim_map_src2tgt(ndim, -1);
std::unordered_map<int, std::vector<int>> dim_map_dst2src;
for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
std::vector<std::shared_ptr<DimTrans>> dims =
GetDimTransCoShard(dim_trans[i],
input_shape,
mesh_shape,
input_dims_mapping,
sharded_input_dims,
&shardable,
&seen_input_dims);
for (auto dim : dims) {
if (dim->type() == DimTrans::Type::INPUTDIM) {
std::shared_ptr<InputDim> inputdim =
std::dynamic_pointer_cast<InputDim>(dim);
dim_map_src2tgt[inputdim->input_dim()] = i;
dim_map_dst2src[i].push_back(inputdim->input_dim());
}
}
}

std::vector<std::vector<int64_t>> out_dims_mapping(dim_trans.size());
std::vector<std::vector<int64_t>> new_input_dims_mapping(
input_dims_mapping.size());

// set output dims mapping with corresponding input dimensions.
// if one input dimension is sharded on an unshardable mesh after
// splitting, we need to make it replicated.
for (int64_t i = 0; i < ndim; i++) {
const auto& mesh_dims = input_dims_mapping[i];
if (!std::all_of(mesh_dims.begin(),
mesh_dims.end(),
[](int64_t dim) { return dim >= 0; }) ||
dim_map_src2tgt[i] == -1) {
continue;
}

bool is_unshardable = false;
for (const auto& mesh_dim : mesh_dims) {
if (mesh_dim >= 0 && !shardable[i][mesh_dim]) {
is_unshardable = true;
break;
}
}
if (!is_unshardable) {
int dst_dim = dim_map_src2tgt[i];
const auto& src_dims = dim_map_dst2src[dst_dim];
auto min_dim_it = std::min_element(src_dims.begin(), src_dims.end());
int64_t min_dim = *min_dim_it;
out_dims_mapping[dst_dim].insert(
out_dims_mapping[dst_dim].end(), mesh_dims.begin(), mesh_dims.end());
new_input_dims_mapping[min_dim].insert(
new_input_dims_mapping[min_dim].end(),
mesh_dims.begin(),
mesh_dims.end());
}
}

return {new_input_dims_mapping, out_dims_mapping};
}

} // namespace phi::distributed
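The recurring check in GetDimTransCoShard and InferFromDimTransCoShard boils down to a divisibility rule: with co-sharding, one tensor dimension may be placed on several mesh dimensions at once, and it can only stay sharded through the reshape if its size is divisible by the product of those mesh dimensions (the mesh_shape_prod accumulated above). A minimal standalone sketch of that rule, using illustrative names rather than the Paddle API:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// dims_mapping lists the mesh axes one tensor dimension is placed on
// (-1 entries mean "not sharded on any axis"). The dimension can keep its
// co-sharding only if its size divides evenly across all of those axes.
bool CanKeepCoShard(int64_t dim_size,
                    const std::vector<int64_t>& dims_mapping,
                    const std::vector<int64_t>& mesh_shape) {
  int64_t mesh_shape_prod = 1;
  for (int64_t mesh_axis : dims_mapping) {
    if (mesh_axis >= 0) {
      mesh_shape_prod *= mesh_shape[mesh_axis];
    }
  }
  return dim_size % mesh_shape_prod == 0;
}

int main() {
  const std::vector<int64_t> mesh_shape = {2, 4};  // a 2x4 process mesh
  // A dim of size 16 co-sharded on both mesh axes: 16 % (2*4) == 0 -> keep it sharded.
  std::cout << CanKeepCoShard(16, {0, 1}, mesh_shape) << "\n";  // prints 1
  // A dim of size 6 co-sharded on both mesh axes: 6 % 8 != 0 -> must be replicated.
  std::cout << CanKeepCoShard(6, {0, 1}, mesh_shape) << "\n";   // prints 0
  return 0;
}
```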
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/spmd_rules/dim_trans.h
@@ -158,10 +158,21 @@ std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTrans(const DistMetaTensor& input_spec,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
const DistMetaTensor& input_spec,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTrans(const DistMetaTensor& input_spec,
const std::vector<int64_t>& input_shape,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);

std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
const DistMetaTensor& input_spec,
const std::vector<int64_t>& input_shape,
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);

} // namespace distributed
} // namespace phi
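The CoShard entry points work with multi_dims_mapping, where every tensor dimension carries a list of mesh axes, which is what allows a single dimension to be co-sharded on more than one axis. A small illustrative example of the representation these functions consume and return, with made-up values:

```cpp
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical multi-dims-mapping for a rank-3 tensor on a 2x4 mesh:
  // dim 0 is co-sharded on mesh axes 0 and 1, dim 1 is sharded on axis 1,
  // and dim 2 is replicated (marked by -1), mirroring the
  // std::vector<std::vector<int64_t>> layout used by InferFromDimTransCoShard
  // for both the input and output mappings.
  std::vector<std::vector<int64_t>> dims_mapping = {
      {0, 1},  // tensor dim 0 -> mesh axes 0 and 1
      {1},     // tensor dim 1 -> mesh axis 1
      {-1},    // tensor dim 2 -> replicated
  };
  (void)dims_mapping;  // placeholder use; a real SPMD rule would consume this
  return 0;
}
```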