Skip to content

Commit 3592f80

Browse files
committed
added SizeNe
1 parent b357d87 commit 3592f80

File tree

4 files changed

+41
-0
lines changed

4 files changed

+41
-0
lines changed

torch_xla/csrc/ops/dynamic_ir.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,25 @@ int64_t SizeEq::getDynamicValue() const {
112112

113113
std::string SizeEq::ToString() const { return "SizeEq"; }
114114

115+
SizeNe::SizeNe(torch::lazy::Value a, torch::lazy::Value b)
116+
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::ne")},
117+
{a, b}, xla::ShapeUtil::MakeShape(xla::S64, {}), 1) {
118+
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
119+
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
120+
XLA_CHECK(dim_node_0);
121+
XLA_CHECK(dim_node_1);
122+
};
123+
124+
int64_t SizeNe::getDynamicValue() const {
125+
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
126+
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
127+
XLA_CHECK(dim_node_0);
128+
XLA_CHECK(dim_node_1);
129+
return dim_node_0->getDynamicValue() != dim_node_1->getDynamicValue() ? 1 : 0;
130+
}
131+
132+
std::string SizeNe::ToString() const { return "SizeNe"; }
133+
115134
SizeConstant::SizeConstant(int64_t val) : Scalar(c10::Scalar{val}, xla::S64){};
116135

117136
SizeMul::SizeMul(torch::lazy::Value a, torch::lazy::Value b)

torch_xla/csrc/ops/dynamic_ir.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ class SizeEq : public XlaNode, public torch::lazy::DimensionNode {
6868
}
6969
};
7070

71+
class SizeNe : public XlaNode, public torch::lazy::DimensionNode {
72+
public:
73+
SizeNe(torch::lazy::Value a, torch::lazy::Value b);
74+
int64_t getDynamicValue() const override;
75+
int64_t getStaticValue() const override {
76+
TORCH_CHECK(false, "Comparison operators should be using getDynamicValue");
77+
}
78+
bool isSymbolic() const override { return true; }
79+
std::string ToString() const override;
80+
virtual XlaOpVector Lower(LoweringContext* loctx) const override {
81+
// TODO: not sure we will ever need it?
82+
TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
83+
}
84+
};
85+
7186
class SizeAdd : public XlaNode, public torch::lazy::DimensionNode {
7287
public:
7388
SizeAdd(torch::lazy::Value a, torch::lazy::Value b);

torch_xla/csrc/tensor.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,12 @@ c10::SymNode XLASymNodeImpl::eq(const c10::SymNode& other) {
615615
return c10::make_intrusive<XLASymNodeImpl>(neq);
616616
}
617617

618+
c10::SymNode XLASymNodeImpl::ne(const c10::SymNode& other) {
619+
auto pother = dynamic_cast<XLASymNodeImpl*>(other.get());
620+
auto nne = torch::lazy::MakeNode<SizeNe>(node(), pother->node());
621+
return c10::make_intrusive<XLASymNodeImpl>(nne);
622+
}
623+
618624
c10::SymNode XLASymNodeImpl::add(const c10::SymNode& other) {
619625
auto pother = dynamic_cast<XLASymNodeImpl*>(other.get());
620626
auto nadd = torch::lazy::MakeNode<SizeAdd>(node(), pother->node());

torch_xla/csrc/tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class TORCH_API XLASymNodeImpl : public c10::SymNodeImpl {
3333
bool is_int() override;
3434
bool is_float() override;
3535
c10::SymNode eq(const c10::SymNode& other) override;
36+
c10::SymNode ne(const c10::SymNode& other) override;
3637
c10::SymNode add(const c10::SymNode& other) override;
3738
c10::SymNode mul(const c10::SymNode& other) override;
3839
c10::SymNode floordiv(const c10::SymNode& other) override;

0 commit comments

Comments
 (0)