Skip to content

Commit 66cec62

Browse files
committed
refs #21199. to_sql() support passing a callable to method parameter.
1 parent 8528936 commit 66cec62

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

pandas/io/sql.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from datetime import datetime, date, time
99
import csv
1010
from io import StringIO
11+
from functools import partial
1112

1213
import warnings
1314
import re
@@ -660,8 +661,9 @@ def insert(self, chunksize=None, method=None):
660661
exec_insert = self._execute_insert_multi
661662
elif method == 'copy':
662663
exec_insert = self._execute_insert_copy
664+
elif callable(method):
665+
exec_insert = partial(method, self)
663666
else:
664-
# TODO: support callables?
665667
raise ValueError('Invalid parameter `method`: {}'.format(method))
666668

667669
keys, data_list = self.insert_data()

pandas/tests/io/test_sql.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,25 @@ def _to_sql_append(self):
435435
assert num_rows == num_entries
436436
self.drop_table('test_frame1')
437437

438+
def _to_sql_method_callable(self):
439+
check = [] # used to double check function below is really being used
440+
441+
def sample(pd_table, conn, keys, data_iter):
442+
check.append(1)
443+
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
444+
conn.execute(pd_table.table.insert(), data)
445+
self.drop_table('test_frame1')
446+
447+
self.pandasSQL.to_sql(self.test_frame1, 'test_frame1', method=sample)
448+
assert self.pandasSQL.has_table('test_frame1')
449+
450+
assert check == [1]
451+
num_entries = len(self.test_frame1)
452+
num_rows = self._count_rows('test_frame1')
453+
assert num_rows == num_entries
454+
# Nuke table
455+
self.drop_table('test_frame1')
456+
438457
def _roundtrip(self):
439458
self.drop_table('test_frame_roundtrip')
440459
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
@@ -1211,6 +1230,9 @@ def test_to_sql_append(self):
12111230
def test_to_sql_method_multi(self):
12121231
self._to_sql(method='multi')
12131232

1233+
def test_to_sql_method_callable(self):
1234+
self._to_sql_method_callable()
1235+
12141236
def test_create_table(self):
12151237
temp_conn = self.connect()
12161238
temp_frame = DataFrame(
@@ -1686,7 +1708,6 @@ class _TestSQLiteAlchemy(object):
16861708
16871709
"""
16881710
flavor = 'sqlite'
1689-
supports_multivalues_insert = True
16901711

16911712
@classmethod
16921713
def connect(cls):
@@ -1735,7 +1756,6 @@ class _TestMySQLAlchemy(object):
17351756
17361757
"""
17371758
flavor = 'mysql'
1738-
supports_multivalues_insert = True
17391759

17401760
@classmethod
17411761
def connect(cls):
@@ -1805,7 +1825,6 @@ class _TestPostgreSQLAlchemy(object):
18051825
18061826
"""
18071827
flavor = 'postgresql'
1808-
supports_multivalues_insert = True
18091828

18101829
@classmethod
18111830
def connect(cls):

0 commit comments

Comments
 (0)