@@ -34,12 +34,13 @@ namespace gtsam {
3434 /* ******************************************************************************** */
3535 DecisionTreeFactor::DecisionTreeFactor (const DiscreteKeys& keys,
3636 const ADT& potentials) :
37- DiscreteFactor (keys.indices()), Potentials(keys, potentials) {
37+ DiscreteFactor (keys.indices()), ADT(potentials),
38+ cardinalities_ (keys.cardinalities()) {
3839 }
3940
4041 /* *************************************************************************/
4142 DecisionTreeFactor::DecisionTreeFactor (const DiscreteConditional& c) :
42- DiscreteFactor (c.keys()), Potentials(c ) {
43+ DiscreteFactor (c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_ ) {
4344 }
4445
4546 /* ************************************************************************* */
@@ -48,16 +49,24 @@ namespace gtsam {
4849 return false ;
4950 }
5051 else {
51- const DecisionTreeFactor & f (static_cast <const DecisionTreeFactor&>(other));
52- return Potentials ::equals (f, tol);
52+ const auto & f (static_cast <const DecisionTreeFactor&>(other));
53+ return ADT ::equals (f, tol);
5354 }
5455 }
5556
57+ /* ************************************************************************* */
58+ double DecisionTreeFactor::safe_div (const double &a, const double &b) {
59+ // The use for safe_div is when we divide the product factor by the sum
60+ // factor. If the product or sum is zero, we accord zero probability to the
61+ // event.
62+ return (a == 0 || b == 0 ) ? 0 : (a / b);
63+ }
64+
5665 /* ************************************************************************* */
5766 void DecisionTreeFactor::print (const string& s,
5867 const KeyFormatter& formatter) const {
5968 cout << s;
60- Potentials ::print (" Potentials:" ,formatter);
69+ ADT ::print (" Potentials:" ,formatter);
6170 }
6271
6372 /* ************************************************************************* */
@@ -162,20 +171,20 @@ namespace gtsam {
162171 void DecisionTreeFactor::dot (std::ostream& os,
163172 const KeyFormatter& keyFormatter,
164173 bool showZero) const {
165- Potentials ::dot (os, keyFormatter, valueFormatter, showZero);
174+ ADT ::dot (os, keyFormatter, valueFormatter, showZero);
166175 }
167176
168177 /* * output to graphviz format, open a file */
169178 void DecisionTreeFactor::dot (const std::string& name,
170179 const KeyFormatter& keyFormatter,
171180 bool showZero) const {
172- Potentials ::dot (name, keyFormatter, valueFormatter, showZero);
181+ ADT ::dot (name, keyFormatter, valueFormatter, showZero);
173182 }
174183
175184 /* * output to graphviz format string */
176185 std::string DecisionTreeFactor::dot (const KeyFormatter& keyFormatter,
177186 bool showZero) const {
178- return Potentials ::dot (keyFormatter, valueFormatter, showZero);
187+ return ADT ::dot (keyFormatter, valueFormatter, showZero);
179188 }
180189
181190 /* ************************************************************************* */
@@ -209,5 +218,15 @@ namespace gtsam {
209218 return ss.str ();
210219 }
211220
221+ DecisionTreeFactor::DecisionTreeFactor (const DiscreteKeys &keys, const vector<double > &table) :
222+ DiscreteFactor (keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
223+ cardinalities_ (keys.cardinalities()) {
224+ }
225+
226+ DecisionTreeFactor::DecisionTreeFactor (const DiscreteKeys &keys, const string &table) :
227+ DiscreteFactor (keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
228+ cardinalities_ (keys.cardinalities()) {
229+ }
230+
212231 /* ************************************************************************* */
213232} // namespace gtsam
0 commit comments