19
19
from pandas .core .base import PandasObject
20
20
from pandas .tseries .tools import to_datetime
21
21
22
+ from contextlib import contextmanager
22
23
23
24
class SQLAlchemyRequired (ImportError ):
24
25
pass
@@ -645,13 +646,9 @@ def insert_data(self):
645
646
646
647
return column_names , data_list
647
648
648
- def get_session (self ):
649
- con = self .pd_sql .engine .connect ()
650
- return con .begin ()
651
-
652
- def _execute_insert (self , trans , keys , data_iter ):
649
+ def _execute_insert (self , conn , keys , data_iter ):
653
650
data = [dict ( (k , v ) for k , v in zip (keys , row ) ) for row in data_iter ]
654
- trans . connection .execute (self .insert_statement (), data )
651
+ conn .execute (self .insert_statement (), data )
655
652
656
653
def insert (self , chunksize = None ):
657
654
keys , data_list = self .insert_data ()
@@ -661,15 +658,15 @@ def insert(self, chunksize=None):
661
658
chunksize = nrows
662
659
chunks = int (nrows / chunksize ) + 1
663
660
664
- with self .get_session () as trans :
661
+ with self .pd_sql . run_transaction () as conn :
665
662
for i in range (chunks ):
666
663
start_i = i * chunksize
667
664
end_i = min ((i + 1 ) * chunksize , nrows )
668
665
if start_i >= end_i :
669
666
break
670
667
671
668
chunk_iter = zip (* [arr [start_i :end_i ] for arr in data_list ])
672
- self ._execute_insert (trans , keys , chunk_iter )
669
+ self ._execute_insert (conn , keys , chunk_iter )
673
670
674
671
def read (self , coerce_float = True , parse_dates = None , columns = None ):
675
672
@@ -892,6 +889,9 @@ def __init__(self, engine, schema=None, meta=None):
892
889
893
890
self .meta = meta
894
891
892
+ def run_transaction (self ):
893
+ return self .engine .begin ()
894
+
895
895
def execute (self , * args , ** kwargs ):
896
896
"""Simple passthrough to SQLAlchemy engine"""
897
897
return self .engine .execute (* args , ** kwargs )
@@ -1025,9 +1025,9 @@ def sql_schema(self):
1025
1025
return str (";\n " .join (self .table ))
1026
1026
1027
1027
def _execute_create (self ):
1028
- with self .get_session () :
1028
+ with self .pd_sql . run_transaction () as conn :
1029
1029
for stmt in self .table :
1030
- self . pd_sql .execute (stmt )
1030
+ conn .execute (stmt )
1031
1031
1032
1032
def insert_statement (self ):
1033
1033
names = list (map (str , self .frame .columns ))
@@ -1046,12 +1046,9 @@ def insert_statement(self):
1046
1046
self .name , col_names , wildcards )
1047
1047
return insert_statement
1048
1048
1049
- def get_session (self ):
1050
- return self .pd_sql .con
1051
-
1052
- def _execute_insert (self , trans , keys , data_iter ):
1049
+ def _execute_insert (self , conn , keys , data_iter ):
1053
1050
data_list = list (data_iter )
1054
- trans .executemany (self .insert_statement (), data_list )
1051
+ conn .executemany (self .insert_statement (), data_list )
1055
1052
1056
1053
def _create_table_setup (self ):
1057
1054
"""Return a list of SQL statement that create a table reflecting the
@@ -1133,6 +1130,17 @@ def __init__(self, con, flavor, is_cursor=False):
1133
1130
else :
1134
1131
self .flavor = flavor
1135
1132
1133
+ @contextmanager
1134
+ def run_transaction (self ):
1135
+ cur = self .con .cursor ()
1136
+ try :
1137
+ yield cur
1138
+ self .con .commit ()
1139
+ except :
1140
+ self .con .rollback ()
1141
+ finally :
1142
+ cur .close ()
1143
+
1136
1144
def execute (self , * args , ** kwargs ):
1137
1145
if self .is_cursor :
1138
1146
cur = self .con
0 commit comments