File tree Expand file tree Collapse file tree 3 files changed +23
-5
lines changed
fbgemm_gpu/src/ssd_split_embeddings_cache Expand file tree Collapse file tree 3 files changed +23
-5
lines changed Original file line number Diff line number Diff line change @@ -47,7 +47,9 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
47
47
const int64_t length,
48
48
const at::Tensor& weights);
49
49
50
- c10::IntArrayRef size ();
50
+ c10::IntArrayRef sizes ();
51
+
52
+ c10::IntArrayRef strides ();
51
53
52
54
c10::ScalarType dtype ();
53
55
Original file line number Diff line number Diff line change @@ -57,11 +57,16 @@ void KVTensorWrapper::set_range(
57
57
FBEXCEPTION (" Not implemented" );
58
58
}
59
59
60
- c10::IntArrayRef KVTensorWrapper::size () {
60
+ c10::IntArrayRef KVTensorWrapper::sizes () {
61
61
FBEXCEPTION (" Not implemented" );
62
62
return shape_;
63
63
}
64
64
65
+ c10::IntArrayRef KVTensorWrapper::strides () {
66
+ FBEXCEPTION (" Not implemented" );
67
+ return shape_; // make linter happy.
68
+ }
69
+
65
70
c10::ScalarType KVTensorWrapper::dtype () {
66
71
FBEXCEPTION (" Not implemented" );
67
72
return options_.dtype ().toScalarType ();
Original file line number Diff line number Diff line change @@ -347,10 +347,20 @@ void KVTensorWrapper::set_range(
347
347
}
348
348
}
349
349
350
- c10::IntArrayRef KVTensorWrapper::size () {
350
+ c10::IntArrayRef KVTensorWrapper::sizes () {
351
351
return shape_;
352
352
}
353
353
354
+ c10::IntArrayRef KVTensorWrapper::strides () {
355
+ // Assume contiguous tensor.
356
+ std::vector<int64_t > strides (shape_.size (), 1 );
357
+ for (int i = shape_.size () - 2 ; i > -1 ; i--) {
358
+ int prev = i + 1 ;
359
+ strides[i] = strides[prev] * std::max<int64_t >(shape_[prev], 1 );
360
+ }
361
+ return strides;
362
+ }
363
+
354
364
c10::ScalarType KVTensorWrapper::dtype () {
355
365
return options_.dtype ().toScalarType ();
356
366
}
@@ -500,9 +510,10 @@ static auto kv_tensor_wrapper =
500
510
.def_property(" layout_str" , &KVTensorWrapper::layout_str)
501
511
.def_property(
502
512
" shape" ,
503
- &KVTensorWrapper::size ,
513
+ &KVTensorWrapper::sizes ,
504
514
std::string (
505
- " Returns the shape of the original tensor. Only the narrowed part is materialized." ));
515
+ " Returns the shape of the original tensor. Only the narrowed part is materialized." ))
516
+ .def_property(" strides" , &KVTensorWrapper::strides);
506
517
507
518
TORCH_LIBRARY_FRAGMENT (fbgemm, m) {
508
519
m.def (
You can’t perform that action at this time.
0 commit comments