Skip to content

Commit cfd10d1

Browse files
committed
Add test and fix bug
1 parent b61db42 commit cfd10d1

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

markupsafe/_speedups.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ escape_unicode(PyObject *in)
8888
next_escp = inp;
8989
while (next_escp < inp_end) {
9090
if (*next_escp < ESCAPED_CHARS_TABLE_SIZE &&
91-
(delta_len = escaped_chars_delta_len[*next_escp])) {
91+
(delta_len = escaped_chars_delta_len[*next_escp])) {
9292
++delta_len;
9393
break;
9494
}
@@ -133,7 +133,7 @@ escape_unicode(PyObject *in)
133133
{ \
134134
Py_ssize_t ncopy = 0; \
135135
while (inp < inp_end) { \
136-
switch (*inp++) { \
136+
switch (*inp) { \
137137
case '"': \
138138
memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \
139139
outp += ncopy; ncopy = 0; \
@@ -180,6 +180,7 @@ escape_unicode(PyObject *in)
180180
default: \
181181
ncopy++; \
182182
} \
183+
inp++; \
183184
} \
184185
memcpy(outp, inp-ncopy, sizeof(*outp)*ncopy); \
185186
}
@@ -297,7 +298,7 @@ escape(PyObject *self, PyObject *text)
297298

298299
if (!id__html__) {
299300
#if PY_MAJOR_VERSION < 3
300-
id__html__ = PyStr_FromString("__html__");
301+
id__html__ = PyString_FromString("__html__");
301302
#else
302303
id__html__ = PyUnicode_FromString("__html__");
303304
#endif

tests.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
import unittest
66
from markupsafe import Markup, escape, escape_silent
77
from markupsafe._compat import text_type, PY2
8+
from markupsafe import _native
9+
try:
10+
from markupsafe import _speedups
11+
have_speedups = True
12+
except ImportError:
13+
have_speedups = False
814

915

1016
class MarkupTestCase(unittest.TestCase):
@@ -191,6 +197,51 @@ def test_markup_leaks(self):
191197
'leak objects, got: ' + str(len(counts))
192198

193199

200+
class NativeEscapeTestCase(unittest.TestCase):
201+
202+
escape = staticmethod(_native.escape)
203+
204+
def test_empty(self):
205+
self.assertEqual(Markup(u''), self.escape(u''))
206+
207+
def test_ascii(self):
208+
self.assertEqual(
209+
Markup(u'abcd&amp;&gt;&lt;&#39;&#34;efgh'),
210+
self.escape(u'abcd&><\'"efgh'))
211+
self.assertEqual(
212+
Markup(u'&amp;&gt;&lt;&#39;&#34;efgh'),
213+
self.escape(u'&><\'"efgh'))
214+
self.assertEqual(
215+
Markup(u'abcd&amp;&gt;&lt;&#39;&#34;'),
216+
self.escape(u'abcd&><\'"'))
217+
218+
def test_2byte(self):
219+
self.assertEqual(
220+
Markup(u'こんにちは&amp;&gt;&lt;&#39;&#34;こんばんは'),
221+
self.escape(u'こんにちは&><\'"こんばんは'))
222+
self.assertEqual(
223+
Markup(u'&amp;&gt;&lt;&#39;&#34;こんばんは'),
224+
self.escape(u'&><\'"こんばんは'))
225+
self.assertEqual(
226+
Markup(u'こんにちは&amp;&gt;&lt;&#39;&#34;'),
227+
self.escape(u'こんにちは&><\'"'))
228+
229+
def test_4byte(self):
230+
self.assertEqual(
231+
Markup(u'\U0001F363\U0001F362&amp;&gt;&lt;&#39;&#34;\U0001F37A xyz'),
232+
self.escape(u'\U0001F363\U0001F362&><\'"\U0001F37A xyz'))
233+
self.assertEqual(
234+
Markup(u'&amp;&gt;&lt;&#39;&#34;\U0001F37A xyz'),
235+
self.escape(u'&><\'"\U0001F37A xyz'))
236+
self.assertEqual(
237+
Markup(u'\U0001F363\U0001F362&amp;&gt;&lt;&#39;&#34;'),
238+
self.escape(u'\U0001F363\U0001F362&><\'"'))
239+
240+
if have_speedups:
241+
class SpeedupEscapeTestCase(NativeEscapeTestCase):
242+
escape = _speedups.escape
243+
244+
194245
def suite():
195246
suite = unittest.TestSuite()
196247
suite.addTest(unittest.makeSuite(MarkupTestCase))
@@ -199,6 +250,10 @@ def suite():
199250
if not hasattr(escape, 'func_code'):
200251
suite.addTest(unittest.makeSuite(MarkupLeakTestCase))
201252

253+
suite.addTest(unittest.makeSuite(NativeEscapeTestCase))
254+
if have_speedups:
255+
suite.addTest(unittest.makeSuite(SpeedupEscapeTestCase))
256+
202257
return suite
203258

204259

0 commit comments

Comments
 (0)