Skip to content

Commit 807bf97

Browse files
pradeepfnfacebook-github-bot
authored andcommitted
Adding couple more APIs to KVTensorWrapper to bring partiy with torch::Tensor (#3645)
Summary: X-link: facebookresearch/FBGEMM#721 Differential Revision: D68934783
1 parent 98d54f7 commit 807bf97

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
4747
const int64_t length,
4848
const at::Tensor& weights);
4949

50-
c10::IntArrayRef size();
50+
c10::IntArrayRef sizes();
51+
52+
c10::IntArrayRef strides();
5153

5254
c10::ScalarType dtype();
5355

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,16 @@ void KVTensorWrapper::set_range(
5757
FBEXCEPTION("Not implemented");
5858
}
5959

60-
c10::IntArrayRef KVTensorWrapper::size() {
60+
c10::IntArrayRef KVTensorWrapper::sizes() {
6161
FBEXCEPTION("Not implemented");
6262
return shape_;
6363
}
6464

65+
c10::IntArrayRef KVTensorWrapper::strides() {
66+
FBEXCEPTION("Not implemented");
67+
return shape_; // make linter happy.
68+
}
69+
6570
c10::ScalarType KVTensorWrapper::dtype() {
6671
FBEXCEPTION("Not implemented");
6772
return options_.dtype().toScalarType();

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,10 +347,20 @@ void KVTensorWrapper::set_range(
347347
}
348348
}
349349

350-
c10::IntArrayRef KVTensorWrapper::size() {
350+
c10::IntArrayRef KVTensorWrapper::sizes() {
351351
return shape_;
352352
}
353353

354+
c10::IntArrayRef KVTensorWrapper::strides() {
355+
// Assume contiguous tensor.
356+
std::vector<int64_t> strides(shape_.size(), 1);
357+
for (int i = shape_.size() - 2; i > -1; i--) {
358+
int prev = i + 1;
359+
strides[i] = strides[prev] * std::max<int64_t>(shape_[prev], 1);
360+
}
361+
return strides;
362+
}
363+
354364
c10::ScalarType KVTensorWrapper::dtype() {
355365
return options_.dtype().toScalarType();
356366
}
@@ -500,9 +510,10 @@ static auto kv_tensor_wrapper =
500510
.def_property("layout_str", &KVTensorWrapper::layout_str)
501511
.def_property(
502512
"shape",
503-
&KVTensorWrapper::size,
513+
&KVTensorWrapper::sizes,
504514
std::string(
505-
"Returns the shape of the original tensor. Only the narrowed part is materialized."));
515+
"Returns the shape of the original tensor. Only the narrowed part is materialized."))
516+
.def_property("strides", &KVTensorWrapper::strides);
506517

507518
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
508519
m.def(

0 commit comments

Comments
 (0)