Skip to content

Commit 2e78816

Browse files
authored
Merge pull request #393 from enthought/fix/consistent-floats
Clean up Float and BaseFloat validation
2 parents 1ab8ed1 + 0294a89 commit 2e78816

File tree

3 files changed

+201
-13
lines changed

3 files changed

+201
-13
lines changed

traits/ctraits.c

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3318,13 +3318,52 @@ validate_trait_integer ( trait_object * trait, has_traits_object * obj,
33183318
return raise_trait_error( trait, obj, name, value );
33193319
}
33203320

3321+
3322+
/*-----------------------------------------------------------------------------
3323+
| Verifies that a Python value is convertible to float
3324+
|
3325+
| Will convert anything whose type has a __float__ method to a Python
3326+
| float. Returns a Python object of exact type "float". Raises TraitError
3327+
| with a suitable message if the given value isn't convertible to float.
3328+
|
3329+
| Any exception other than TypeError raised by the value's __float__ method
3330+
| will be propagated. A TypeError will be caught and turned into TraitError.
3331+
|
3332+
+----------------------------------------------------------------------------*/
3333+
3334+
static PyObject *
3335+
validate_trait_float(trait_object * trait, has_traits_object * obj,
3336+
PyObject * name, PyObject * value) {
3337+
/* Fast path for the most common case. */
3338+
if (PyFloat_CheckExact(value)) {
3339+
Py_INCREF(value);
3340+
return value;
3341+
}
3342+
else {
3343+
double value_as_double = PyFloat_AsDouble(value);
3344+
/* Translate a TypeError to a TraitError, but propagate
3345+
other exceptions. */
3346+
if (value_as_double == -1.0 && PyErr_Occurred()) {
3347+
if (PyErr_ExceptionMatches(PyExc_TypeError)) {
3348+
PyErr_Clear();
3349+
goto error;
3350+
}
3351+
return NULL;
3352+
}
3353+
return PyFloat_FromDouble(value_as_double);
3354+
}
3355+
3356+
error:
3357+
return raise_trait_error(trait, obj, name, value);
3358+
}
3359+
33213360
/*-----------------------------------------------------------------------------
33223361
| Verifies a Python value is a float within a specified range:
33233362
+----------------------------------------------------------------------------*/
33243363

33253364
static PyObject *
3326-
validate_trait_float ( trait_object * trait, has_traits_object * obj,
3327-
PyObject * name, PyObject * value ) {
3365+
validate_trait_float_range ( trait_object * trait, has_traits_object * obj,
3366+
PyObject * name, PyObject * value ) {
33283367

33293368
register PyObject * low;
33303369
register PyObject * high;
@@ -4015,7 +4054,7 @@ validate_trait_complex ( trait_object * trait, has_traits_object * obj,
40154054
goto done;
40164055
break;
40174056

4018-
case 20: /* Integer check: */
4057+
case 20: /* Integer check: */
40194058

40204059
/* Fast paths for the most common cases. */
40214060
#if PY_MAJOR_VERSION < 3
@@ -4063,6 +4102,29 @@ validate_trait_complex ( trait_object * trait, has_traits_object * obj,
40634102
Py_DECREF(int_value);
40644103
return result;
40654104

4105+
case 21: /* Float check */
4106+
/* Fast path for most common case. */
4107+
if (PyFloat_CheckExact(value)) {
4108+
Py_INCREF(value);
4109+
return value;
4110+
}
4111+
else {
4112+
double value_as_double = PyFloat_AsDouble(value);
4113+
if (value_as_double == -1.0 && PyErr_Occurred()) {
4114+
/* TypeError indicates that we don't have a match;
4115+
clear the error and continue with the next item
4116+
in the complex sequence. */
4117+
if (PyErr_ExceptionMatches(PyExc_TypeError)) {
4118+
PyErr_Clear();
4119+
break;
4120+
}
4121+
/* Any other exception is unexpected and likely
4122+
a code bug; propagate it. */
4123+
return NULL;
4124+
}
4125+
return PyFloat_FromDouble(value_as_double);
4126+
}
4127+
40664128
default: /* Should never happen...indicates an internal error: */
40674129
goto error;
40684130
}
@@ -4086,7 +4148,7 @@ static trait_validate validate_handlers[] = {
40864148
#else
40874149
validate_trait_self_type, NULL,
40884150
#endif // #if PY_MAJOR_VERSION < 3
4089-
validate_trait_float, validate_trait_enum,
4151+
validate_trait_float_range, validate_trait_enum,
40904152
validate_trait_map, validate_trait_complex,
40914153
NULL, validate_trait_tuple,
40924154
validate_trait_prefix_map, validate_trait_coerce_type,
@@ -4097,6 +4159,7 @@ static trait_validate validate_handlers[] = {
40974159
setattr_validate2, setattr_validate3,
40984160
/* ...End of __getstate__ method entries */
40994161
validate_trait_adapt, validate_trait_integer,
4162+
validate_trait_float,
41004163
};
41014164

41024165
static PyObject *
@@ -4246,6 +4309,12 @@ _trait_set_validate ( trait_object * trait, PyObject * args ) {
42464309
if ( n == 1 )
42474310
goto done;
42484311
break;
4312+
4313+
case 21: /* Float check: */
4314+
if ( n == 1 )
4315+
goto done;
4316+
break;
4317+
42494318
}
42504319
}
42514320
}

traits/tests/test_float.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,48 @@
1717
"""
1818
import sys
1919

20+
try:
21+
import numpy
22+
except ImportError:
23+
numpy_available = False
24+
else:
25+
numpy_available = True
26+
2027
from traits.testing.unittest_tools import unittest
2128

22-
from ..api import BaseFloat, Float, HasTraits
29+
from ..api import BaseFloat, Either, Float, HasTraits, TraitError, Unicode
30+
31+
32+
class MyFloat(object):
33+
def __init__(self, value):
34+
self._value = value
35+
36+
def __float__(self):
37+
return self._value
38+
39+
40+
class BadFloat(object):
41+
def __float__(self):
42+
raise ZeroDivisionError
2343

2444

2545
class FloatModel(HasTraits):
2646
value = Float
2747

48+
# Assignment to the `Either` trait exercises a different C code path (see
49+
# validate_trait_complex in ctraits.c).
50+
value_or_none = Either(None, Float)
51+
52+
float_or_text = Either(Float, Unicode)
53+
2854

2955
class BaseFloatModel(HasTraits):
3056
value = BaseFloat
3157

58+
value_or_none = Either(None, BaseFloat)
59+
60+
float_or_text = Either(Float, Unicode)
61+
3262

3363
class CommonFloatTests(object):
3464
""" Common tests for Float and BaseFloat """
@@ -38,36 +68,110 @@ def test_default(self):
3868

3969
def test_accepts_float(self):
4070
a = self.test_class()
71+
4172
a.value = 5.6
4273
self.assertIs(type(a.value), float)
4374
self.assertEqual(a.value, 5.6)
4475

76+
a.value_or_none = 5.6
77+
self.assertIs(type(a.value_or_none), float)
78+
self.assertEqual(a.value_or_none, 5.6)
79+
4580
def test_accepts_int(self):
4681
a = self.test_class()
82+
4783
a.value = 2
4884
self.assertIs(type(a.value), float)
4985
self.assertEqual(a.value, 2.0)
5086

87+
a.value_or_none = 2
88+
self.assertIs(type(a.value_or_none), float)
89+
self.assertEqual(a.value_or_none, 2.0)
90+
91+
def test_accepts_float_like(self):
92+
a = self.test_class()
93+
94+
a.value = MyFloat(1729.0)
95+
self.assertIs(type(a.value), float)
96+
self.assertEqual(a.value, 1729.0)
97+
98+
a.value = MyFloat(594.0)
99+
self.assertIs(type(a.value), float)
100+
self.assertEqual(a.value, 594.0)
101+
102+
def test_rejects_string(self):
103+
a = self.test_class()
104+
with self.assertRaises(TraitError):
105+
a.value = "2.3"
106+
with self.assertRaises(TraitError):
107+
a.value_or_none = "2.3"
108+
109+
def test_bad_float_exceptions_propagated(self):
110+
a = self.test_class()
111+
with self.assertRaises(ZeroDivisionError):
112+
a.value = BadFloat()
113+
114+
def test_compound_trait_float_conversion_fail(self):
115+
# Check that a failure to convert to float doesn't terminate
116+
# an assignment to a compound trait.
117+
a = self.test_class()
118+
a.float_or_text = u"not a float"
119+
self.assertEqual(a.float_or_text, u"not a float")
120+
51121
@unittest.skipUnless(sys.version_info < (3,), "Not applicable to Python 3")
52122
def test_accepts_small_long(self):
53123
a = self.test_class()
124+
54125
a.value = long(2)
55126
self.assertIs(type(a.value), float)
56127
self.assertEqual(a.value, 2.0)
57128

129+
a.value_or_none = long(2)
130+
self.assertIs(type(a.value_or_none), float)
131+
self.assertEqual(a.value_or_none, 2.0)
132+
58133
@unittest.skipUnless(sys.version_info < (3,), "Not applicable to Python 3")
59134
def test_accepts_large_long(self):
60135
a = self.test_class()
136+
61137
# Value large enough to be a long on Python 2.
62138
a.value = 2**64
63139
self.assertIs(type(a.value), float)
64140
self.assertEqual(a.value, 2**64)
65141

142+
a.value_or_none = 2**64
143+
self.assertIs(type(a.value_or_none), float)
144+
self.assertEqual(a.value_or_none, 2**64)
145+
146+
@unittest.skipUnless(numpy_available, "Test requires NumPy")
147+
def test_accepts_numpy_floats(self):
148+
test_values = [
149+
numpy.float64(2.3),
150+
numpy.float32(3.7),
151+
numpy.float16(1.28),
152+
]
153+
a = self.test_class()
154+
for test_value in test_values:
155+
a.value = test_value
156+
self.assertIs(type(a.value), float)
157+
self.assertEqual(a.value, test_value)
158+
159+
a.value_or_none = test_value
160+
self.assertIs(type(a.value_or_none), float)
161+
self.assertEqual(a.value_or_none, test_value)
162+
66163

67164
class TestFloat(unittest.TestCase, CommonFloatTests):
68165
def setUp(self):
69166
self.test_class = FloatModel
70167

168+
def test_exceptions_propagate_in_compound_trait(self):
169+
# This test doesn't currently pass for BaseFloat, which is why it's not
170+
# in the common tests. That's probably a bug.
171+
a = self.test_class()
172+
with self.assertRaises(ZeroDivisionError):
173+
a.value_or_none = BadFloat()
174+
71175

72176
class TestBaseFloat(unittest.TestCase, CommonFloatTests):
73177
def setUp(self):

traits/trait_types.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,23 @@ def default_text_editor ( trait, type = None ):
123123
enter_set = enter_set,
124124
evaluate = type )
125125

126+
127+
# Generic validators
128+
129+
def _validate_float(value):
130+
"""
131+
Convert an arbitrary Python object to a float, or raise TypeError.
132+
"""
133+
if type(value) == float: # fast path for common case
134+
return value
135+
try:
136+
nb_float = type(value).__float__
137+
except AttributeError:
138+
raise TypeError(
139+
"Object of type {!r} not convertible to float".format(type(value)))
140+
return nb_float(value)
141+
142+
126143
#-------------------------------------------------------------------------------
127144
# 'Any' trait:
128145
#-------------------------------------------------------------------------------
@@ -241,6 +258,7 @@ class Long ( BaseLong ):
241258
#: The C-level fast validator to use:
242259
fast_validate = long_fast_validate
243260

261+
244262
#-------------------------------------------------------------------------------
245263
# 'BaseFloat' and 'Float' traits:
246264
#-------------------------------------------------------------------------------
@@ -262,13 +280,10 @@ def validate ( self, object, name, value ):
262280
263281
Note: The 'fast validator' version performs this check in C.
264282
"""
265-
if isinstance( value, float ):
266-
return value
267-
268-
if isinstance( value, ( int, long ) ):
269-
return float( value )
270-
271-
self.error( object, name, value )
283+
try:
284+
return _validate_float(value)
285+
except TypeError:
286+
self.error(object, name, value)
272287

273288
def create_editor ( self ):
274289
""" Returns the default traits UI editor for this type of trait.
@@ -282,7 +297,7 @@ class Float ( BaseFloat ):
282297
"""
283298

284299
#: The C-level fast validator to use:
285-
fast_validate = float_fast_validate
300+
fast_validate = ( 21, )
286301

287302
#-------------------------------------------------------------------------------
288303
# 'BaseComplex' and 'Complex' traits:

0 commit comments

Comments
 (0)