diff --git a/asyncpg/protocol/buffer.pxd b/asyncpg/protocol/buffer.pxd index 6687e8e5..1093c688 100644 --- a/asyncpg/protocol/buffer.pxd +++ b/asyncpg/protocol/buffer.pxd @@ -48,7 +48,7 @@ cdef class WriteBuffer: cdef write_bytes(self, bytes data) cdef write_bytestring(self, bytes string) cdef write_str(self, str string, str encoding) - cdef write_cstr(self, char *data, ssize_t len) + cdef write_cstr(self, const char *data, ssize_t len) cdef write_int16(self, int16_t i) cdef write_int32(self, int32_t i) cdef write_int64(self, int64_t i) diff --git a/asyncpg/protocol/buffer.pyx b/asyncpg/protocol/buffer.pyx index e999b7ad..6f803340 100644 --- a/asyncpg/protocol/buffer.pyx +++ b/asyncpg/protocol/buffer.pyx @@ -169,7 +169,7 @@ cdef class WriteBuffer: cdef write_str(self, str string, str encoding): self.write_bytestring(string.encode(encoding)) - cdef write_cstr(self, char *data, ssize_t len): + cdef write_cstr(self, const char *data, ssize_t len): self._check_readonly() self._ensure_alloced(len) diff --git a/asyncpg/protocol/codecs/array.pyx b/asyncpg/protocol/codecs/array.pyx index d2dd9cd4..92c76a32 100644 --- a/asyncpg/protocol/codecs/array.pyx +++ b/asyncpg/protocol/codecs/array.pyx @@ -140,6 +140,127 @@ cdef inline array_encode(ConnectionSettings settings, WriteBuffer buf, buf.write_buffer(elem_data) +cdef _write_textarray_data(ConnectionSettings settings, object obj, + int32_t ndims, int32_t dim, WriteBuffer array_data, + encode_func_ex encoder, const void *encoder_arg, + Py_UCS4 typdelim): + cdef: + ssize_t i = 0 + int8_t delim = typdelim + WriteBuffer elem_data + Py_buffer pybuf + const char *elem_str + char ch + ssize_t elem_len + ssize_t quoted_elem_len + bint need_quoting + + array_data.write_byte(b'{') + + if dim < ndims - 1: + for item in obj: + if i > 0: + array_data.write_byte(delim) + array_data.write_byte(b' ') + _write_textarray_data(settings, item, ndims, dim + 1, array_data, + encoder, encoder_arg, typdelim) + i += 1 + else: + for item in obj: + elem_data = WriteBuffer.new() + + if i > 0: + array_data.write_byte(delim) + array_data.write_byte(b' ') + + if item is None: + array_data.write_bytes(b'NULL') + i += 1 + continue + else: + try: + encoder(settings, elem_data, item, encoder_arg) + except TypeError as e: + raise ValueError( + 'invalid array element: {}'.format( + e.args[0])) from None + + # element string length (first four bytes are the encoded length.) + elem_len = elem_data.len() - 4 + + if elem_len == 0: + # Empty string + array_data.write_bytes(b'""') + else: + cpython.PyObject_GetBuffer( + elem_data, &pybuf, cpython.PyBUF_SIMPLE) + + elem_str = (pybuf.buf) + 4 + + try: + if not apg_strcasecmp_char(elem_str, b'NULL'): + array_data.write_bytes(b'"NULL"') + else: + quoted_elem_len = elem_len + need_quoting = False + + for i in range(elem_len): + ch = elem_str[i] + if ch == b'"' or ch == b'\\': + # Quotes and backslashes need escaping. + quoted_elem_len += 1 + need_quoting = True + elif (ch == b'{' or ch == b'}' or ch == delim or + apg_ascii_isspace(ch)): + need_quoting = True + + if need_quoting: + array_data.write_byte(b'"') + + if quoted_elem_len == elem_len: + array_data.write_cstr(elem_str, elem_len) + else: + # Escaping required. + for i in range(elem_len): + ch = elem_str[i] + if ch == b'"' or ch == b'\\': + array_data.write_byte(b'\\') + array_data.write_byte(ch) + + array_data.write_byte(b'"') + else: + array_data.write_cstr(elem_str, elem_len) + finally: + cpython.PyBuffer_Release(&pybuf) + + i += 1 + + array_data.write_byte(b'}') + + +cdef inline textarray_encode(ConnectionSettings settings, WriteBuffer buf, + object obj, encode_func_ex encoder, + const void *encoder_arg, Py_UCS4 typdelim): + cdef: + WriteBuffer array_data + int32_t dims[ARRAY_MAXDIM] + int32_t ndims = 1 + int32_t i + + if not _is_container(obj): + raise TypeError( + 'a non-trivial iterable expected (got type {!r})'.format( + type(obj).__name__)) + + _get_array_shape(obj, dims, &ndims) + + array_data = WriteBuffer.new() + _write_textarray_data(settings, obj, ndims, 0, array_data, + encoder, encoder_arg, typdelim) + buf.write_int32(array_data.len()) + buf.write_buffer(array_data) + + cdef inline array_decode(ConnectionSettings settings, FastReadBuffer buf, decode_func_ex decoder, const void *decoder_arg): cdef: diff --git a/asyncpg/protocol/codecs/base.pxd b/asyncpg/protocol/codecs/base.pxd index fe5d7b01..19b9bb76 100644 --- a/asyncpg/protocol/codecs/base.pxd +++ b/asyncpg/protocol/codecs/base.pxd @@ -81,6 +81,9 @@ cdef class Codec: cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf, object obj) + cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf, + object obj) + cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf, object obj) @@ -137,6 +140,7 @@ cdef class Codec: cdef Codec new_composite_codec(uint32_t oid, str name, str schema, + CodecFormat format, list element_codecs, tuple element_type_oids, object element_names) diff --git a/asyncpg/protocol/codecs/base.pyx b/asyncpg/protocol/codecs/base.pyx index 70aa650e..b51a114d 100644 --- a/asyncpg/protocol/codecs/base.pyx +++ b/asyncpg/protocol/codecs/base.pyx @@ -50,15 +50,24 @@ cdef class Codec: self.encoder = &self.encode_scalar self.decoder = &self.decode_scalar elif type == CODEC_ARRAY: - self.encoder = &self.encode_array if format == PG_FORMAT_BINARY: + self.encoder = &self.encode_array self.decoder = &self.decode_array else: + self.encoder = &self.encode_array_text self.decoder = &self.decode_array_text elif type == CODEC_RANGE: + if format != PG_FORMAT_BINARY: + raise RuntimeError( + 'cannot encode type "{}"."{}": text encoding of ' + 'range types is not supported'.format(schema, name)) self.encoder = &self.encode_range self.decoder = &self.decode_range elif type == CODEC_COMPOSITE: + if format != PG_FORMAT_BINARY: + raise RuntimeError( + 'cannot encode type "{}"."{}": text encoding of ' + 'composite types is not supported'.format(schema, name)) self.encoder = &self.encode_composite self.decoder = &self.decode_composite elif type == CODEC_PY: @@ -91,6 +100,13 @@ cdef class Codec: codec_encode_func_ex, (self.element_codec)) + cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf, + object obj): + return textarray_encode(settings, buf, obj, + codec_encode_func_ex, + (self.element_codec), + self.element_delimiter) + cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf, object obj): range_encode(settings, buf, obj, self.element_codec.oid, @@ -269,22 +285,22 @@ cdef class Codec: Codec element_codec): cdef Codec codec codec = Codec(oid) - codec.init(name, schema, 'range', CODEC_RANGE, PG_FORMAT_BINARY, - NULL, NULL, None, None, element_codec, None, None, None, - 0) + codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format, + NULL, NULL, None, None, element_codec, None, None, None, 0) return codec @staticmethod cdef Codec new_composite_codec(uint32_t oid, str name, str schema, + CodecFormat format, list element_codecs, tuple element_type_oids, object element_names): cdef Codec codec codec = Codec(oid) codec.init(name, schema, 'composite', CODEC_COMPOSITE, - PG_FORMAT_BINARY, NULL, NULL, None, None, None, + format, NULL, NULL, None, None, None, element_type_oids, element_names, element_codecs, 0) return codec @@ -368,11 +384,12 @@ cdef class DataCodecConfig: elem_format = PG_FORMAT_BINARY else: elem_format = PG_FORMAT_TEXT + elem_codec = self.get_codec(array_element_oid, elem_format) if elem_codec is None: - raise RuntimeError( - 'no codec for array element type {}'.format( - array_element_oid)) + elem_format = PG_FORMAT_TEXT + elem_codec = self.declare_fallback_codec( + array_element_oid, name, schema) elem_delim = ti['elemdelim'][0] @@ -410,9 +427,8 @@ cdef class DataCodecConfig: self._type_codecs_cache[oid, format] = \ Codec.new_composite_codec( - oid, name, schema, comp_elem_codecs, - comp_type_attrs, - element_names) + oid, name, schema, format, comp_elem_codecs, + comp_type_attrs, element_names) elif ti['kind'] == b'd': # Domain type @@ -424,8 +440,9 @@ cdef class DataCodecConfig: elem_codec = self.get_codec(base_type, format) if elem_codec is None: - raise RuntimeError( - 'no codec for domain base type {}'.format(base_type)) + format = PG_FORMAT_TEXT + elem_codec = self.declare_fallback_codec( + base_type, name, schema) self._type_codecs_cache[oid, format] = elem_codec @@ -441,34 +458,18 @@ cdef class DataCodecConfig: elem_format = PG_FORMAT_BINARY else: elem_format = PG_FORMAT_TEXT + elem_codec = self.get_codec(range_subtype_oid, elem_format) if elem_codec is None: - raise RuntimeError( - 'no codec for range element type {}'.format( - range_subtype_oid)) + elem_format = PG_FORMAT_TEXT + elem_codec = self.declare_fallback_codec( + range_subtype_oid, name, schema) self._type_codecs_cache[oid, elem_format] = \ Codec.new_range_codec(oid, name, schema, elem_codec) else: - if oid <= MAXBUILTINOID: - # This is a non-BKI type, for which ayncpg has no - # defined codec. This should only happen for newly - # added builtin types, for which this version of - # asyncpg is lacking support. - # - raise NotImplementedError( - 'unhandled standard data type {!r} (OID {})'.format( - name, oid)) - else: - # This is a non-BKI type, and as such, has no - # stable OID, so no possibility of a builtin codec. - # In this case, fallback to text format. Applications - # can avoid this by specifying a codec for this type - # using Connection.set_type_codec(). - # - self.set_builtin_type_codec(oid, name, schema, 'scalar', - UNKNOWNOID) + self.declare_fallback_codec(oid, name, schema) def add_python_codec(self, typeoid, typename, typeschema, typekind, encoder, decoder, binary): @@ -478,13 +479,20 @@ cdef class DataCodecConfig: Codec.new_python_codec(typeoid, typename, typeschema, typekind, encoder, decoder, format) + self.clear_type_cache() + def set_builtin_type_codec(self, typeoid, typename, typeschema, typekind, - alias_to): + alias_to, format=PG_FORMAT_ANY): cdef: Codec codec Codec target_codec - for format in (PG_FORMAT_BINARY, PG_FORMAT_TEXT): + if format == PG_FORMAT_ANY: + formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT) + else: + formats = (format,) + + for format in formats: if self.get_codec(typeoid, format) is not None: raise ValueError('cannot override codec for type {}'.format( typeoid)) @@ -509,9 +517,41 @@ cdef class DataCodecConfig: (typeoid, PG_FORMAT_TEXT) not in self._local_type_codecs): raise ValueError('unknown alias target: {}'.format(alias_to)) + self.clear_type_cache() + def clear_type_cache(self): self._type_codecs_cache.clear() + def declare_fallback_codec(self, uint32_t oid, str name, str schema): + cdef Codec codec + + codec = self.get_codec(oid, PG_FORMAT_TEXT) + if codec is not None: + return codec + + if oid <= MAXBUILTINOID: + # This is a BKI type, for which ayncpg has no + # defined codec. This should only happen for newly + # added builtin types, for which this version of + # asyncpg is lacking support. + # + raise NotImplementedError( + 'unhandled standard data type {!r} (OID {})'.format( + name, oid)) + else: + # This is a non-BKI type, and as such, has no + # stable OID, so no possibility of a builtin codec. + # In this case, fallback to text format. Applications + # can avoid this by specifying a codec for this type + # using Connection.set_type_codec(). + # + self.set_builtin_type_codec(oid, name, schema, 'scalar', + TEXTOID, PG_FORMAT_TEXT) + + codec = self.get_codec(oid, PG_FORMAT_TEXT) + + return codec + cdef inline Codec get_codec(self, uint32_t oid, CodecFormat format): cdef Codec codec diff --git a/asyncpg/protocol/codecs/text.pyx b/asyncpg/protocol/codecs/text.pyx index b8397d95..b745dcac 100644 --- a/asyncpg/protocol/codecs/text.pyx +++ b/asyncpg/protocol/codecs/text.pyx @@ -63,5 +63,9 @@ cdef init_text_codecs(): &text_decode, PG_FORMAT_BINARY) + register_core_codec(oid, + &text_encode, + &text_decode, + PG_FORMAT_TEXT) init_text_codecs() diff --git a/asyncpg/protocol/codecs/textutils.pyx b/asyncpg/protocol/codecs/textutils.pyx index 1a09c179..d9716106 100644 --- a/asyncpg/protocol/codecs/textutils.pyx +++ b/asyncpg/protocol/codecs/textutils.pyx @@ -5,6 +5,12 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +cdef inline uint32_t _apg_tolower(uint32_t c): + if c >= 'A' and c <= 'Z': + return c + 'a' - 'A' + else: + return c + cdef int apg_strcasecmp(const Py_UCS4 *s1, const Py_UCS4 *s2): cdef: @@ -17,11 +23,8 @@ cdef int apg_strcasecmp(const Py_UCS4 *s1, const Py_UCS4 *s2): c2 = s2[i] if c1 != c2: - if c1 >= 'A' and c1 <= 'Z': - c1 += 'a' - 'A' - if c2 >= 'A' and c2 <= 'Z': - c2 += 'a' - 'A' - + c1 = _apg_tolower(c1) + c2 = _apg_tolower(c2) if c1 != c2: return c1 - c2 @@ -33,6 +36,30 @@ cdef int apg_strcasecmp(const Py_UCS4 *s1, const Py_UCS4 *s2): return 0 +cdef int apg_strcasecmp_char(const char *s1, const char *s2): + cdef: + uint8_t c1 + uint8_t c2 + int i = 0 + + while True: + c1 = s1[i] + c2 = s2[i] + + if c1 != c2: + c1 = _apg_tolower(c1) + c2 = _apg_tolower(c2) + if c1 != c2: + return c1 - c2 + + if c1 == 0 or c2 == 0: + break + + i += 1 + + return 0 + + cdef inline bint apg_ascii_isspace(Py_UCS4 ch): return ( ch == ' ' or diff --git a/tests/test_codecs.py b/tests/test_codecs.py index a810bd00..14402ab4 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -1077,3 +1077,87 @@ async def test_enum(self): DROP TABLE tab; DROP TYPE enum_t; ''') + + async def test_unknown_type_text_fallback(self): + await self.con.execute(r'CREATE EXTENSION citext') + await self.con.execute(r''' + CREATE DOMAIN citext_dom AS citext + ''') + await self.con.execute(r''' + CREATE TYPE citext_range AS RANGE (SUBTYPE = citext) + ''') + await self.con.execute(r''' + CREATE TYPE citext_comp AS (t citext) + ''') + + try: + # Check that plain fallback works. + result = await self.con.fetchval(''' + SELECT $1::citext + ''', 'citext') + + self.assertEqual(result, 'citext') + + # Check that domain fallback works. + result = await self.con.fetchval(''' + SELECT $1::citext_dom + ''', 'citext') + + self.assertEqual(result, 'citext') + + # Check that array fallback works. + cases = [ + ['a', 'b'], + [None, 'b'], + [], + [' a', ' b'], + ['"a', r'\""'], + [['"a', r'\""'], [',', '",']], + ] + + for case in cases: + result = await self.con.fetchval(''' + SELECT + $1::citext[] + ''', case) + + self.assertEqual(result, case) + + # Text encoding of ranges and composite types + # is not supported yet. + with self.assertRaisesRegex( + RuntimeError, + 'text encoding of range types is not supported'): + + await self.con.fetchval(''' + SELECT + $1::citext_range + ''', ['a', 'z']) + + with self.assertRaisesRegex( + RuntimeError, + 'text encoding of composite types is not supported'): + + await self.con.fetchval(''' + SELECT + $1::citext_comp + ''', ('a',)) + + # Check that setting a custom codec clears the codec + # cache properly and that subsequent queries work + # as expected. + await self.con.set_type_codec( + 'citext', encoder=lambda d: d, decoder=lambda d: 'CI: ' + d) + + result = await self.con.fetchval(''' + SELECT + $1::citext[] + ''', ['a', 'b']) + + self.assertEqual(result, ['CI: a', 'CI: b']) + + finally: + await self.con.execute(r'DROP TYPE citext_comp') + await self.con.execute(r'DROP TYPE citext_range') + await self.con.execute(r'DROP TYPE citext_dom') + await self.con.execute(r'DROP EXTENSION citext')