diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 09c59710e9b0c..c960a73bb0f88 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -566,17 +566,17 @@ def __init__(self, name, pandas_sql_engine, frame=None, index=True, raise ValueError("Table '%s' already exists." % name) elif if_exists == 'replace': self.pd_sql.drop_table(self.name, self.schema) - self.table = self._create_table_statement() + self.table = self._create_table_setup() self.create() elif if_exists == 'append': self.table = self.pd_sql.get_table(self.name, self.schema) if self.table is None: - self.table = self._create_table_statement() + self.table = self._create_table_setup() else: raise ValueError( "'{0}' is not valid for if_exists".format(if_exists)) else: - self.table = self._create_table_statement() + self.table = self._create_table_setup() self.create() else: # no data provided, read-only mode @@ -703,23 +703,25 @@ def _get_column_names_and_types(self, dtype_mapper): for i, idx_label in enumerate(self.index): idx_type = dtype_mapper( self.frame.index.get_level_values(i)) - column_names_and_types.append((idx_label, idx_type)) + column_names_and_types.append((idx_label, idx_type, True)) column_names_and_types += [ (str(self.frame.columns[i]), - dtype_mapper(self.frame.iloc[:,i])) + dtype_mapper(self.frame.iloc[:,i]), + False) for i in range(len(self.frame.columns)) ] + return column_names_and_types - def _create_table_statement(self): + def _create_table_setup(self): from sqlalchemy import Table, Column column_names_and_types = \ self._get_column_names_and_types(self._sqlalchemy_type) - columns = [Column(name, typ) - for name, typ in column_names_and_types] + columns = [Column(name, typ, index=is_index) + for name, typ, is_index in column_names_and_types] return Table(self.name, self.pd_sql.meta, *columns, schema=self.schema) @@ -979,10 +981,12 @@ class PandasSQLTableLegacy(PandasSQLTable): Instead of a table variable just use the Create Table statement""" def sql_schema(self): - return 
str(self.table) + return str(";\n".join(self.table)) def create(self): - self.pd_sql.execute(self.table) + with self.pd_sql.con: + for stmt in self.table: + self.pd_sql.execute(stmt) def insert_statement(self): names = list(map(str, self.frame.columns)) @@ -1026,14 +1030,17 @@ def insert(self, chunksize=None): cur.executemany(ins, data_list) cur.close() - def _create_table_statement(self): - "Return a CREATE TABLE statement to suit the contents of a DataFrame." + def _create_table_setup(self): + """Return a list of SQL statements that create a table reflecting the + structure of a DataFrame. The first entry will be a CREATE TABLE + statement while the rest will be CREATE INDEX statements + """ column_names_and_types = \ self._get_column_names_and_types(self._sql_type_name) pat = re.compile('\s+') - column_names = [col_name for col_name, _ in column_names_and_types] + column_names = [col_name for col_name, _, _ in column_names_and_types] if any(map(pat.search, column_names)): warnings.warn(_SAFE_NAMES_WARNING) @@ -1044,13 +1051,21 @@ def _create_table_statement(self): col_template = br_l + '%s' + br_r + ' %s' - columns = ',\n '.join(col_template % - x for x in column_names_and_types) + columns = ',\n '.join(col_template % (cname, ctype) + for cname, ctype, _ in column_names_and_types) template = """CREATE TABLE %(name)s ( %(columns)s )""" - create_statement = template % {'name': self.name, 'columns': columns} - return create_statement + create_stmts = [template % {'name': self.name, 'columns': columns}, ] + + ix_tpl = "CREATE INDEX {br_l}ix_{tbl}_{col}{br_r} ON {tbl} ({br_l}{col}{br_r})" + for cname, _, is_index in column_names_and_types: + if not is_index: + continue + create_stmts.append(ix_tpl.format(tbl=self.name, col=cname, + br_l=br_l, br_r=br_r)) + + return create_stmts def _sql_type_name(self, col): pytype = col.dtype.type diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index 0108335c94249..3ad9669abb883 100644 --- a/pandas/io/tests/test_sql.py +++ 
b/pandas/io/tests/test_sql.py @@ -199,7 +199,7 @@ def _load_test2_data(self): E=['1990-11-22', '1991-10-26', '1993-11-26', '1995-12-12'])) df['E'] = to_datetime(df['E']) - self.test_frame3 = df + self.test_frame2 = df def _load_test3_data(self): columns = ['index', 'A', 'B'] @@ -324,6 +324,13 @@ def _execute_sql(self): row = iris_results.fetchone() tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) + def _to_sql_save_index(self): + df = DataFrame.from_records([(1,2.1,'line1'), (2,1.5,'line2')], + columns=['A','B','C'], index=['A']) + self.pandasSQL.to_sql(df, 'test_to_sql_saves_index') + ix_cols = self._get_index_columns('test_to_sql_saves_index') + self.assertEqual(ix_cols, [['A',],]) + #------------------------------------------------------------------------------ #--- Testing the public API @@ -694,6 +701,13 @@ def test_warning_case_insensitive_table_name(self): # Verify some things self.assertEqual(len(w), 0, "Warning triggered for writing a table") + def _get_index_columns(self, tbl_name): + from sqlalchemy.engine import reflection + insp = reflection.Inspector.from_engine(self.conn) + ixs = insp.get_indexes(tbl_name) + ixs = [i['column_names'] for i in ixs] + return ixs + class TestSQLLegacyApi(_TestSQLApi): """ @@ -1074,6 +1088,16 @@ def test_nan_string(self): result = sql.read_sql_query('SELECT * FROM test_nan', self.conn) tm.assert_frame_equal(result, df) + def _get_index_columns(self, tbl_name): + from sqlalchemy.engine import reflection + insp = reflection.Inspector.from_engine(self.conn) + ixs = insp.get_indexes(tbl_name) + ixs = [i['column_names'] for i in ixs] + return ixs + + def test_to_sql_save_index(self): + self._to_sql_save_index() + class TestSQLiteAlchemy(_TestSQLAlchemy): """ @@ -1368,6 +1392,20 @@ def test_datetime_time(self): # test support for datetime.time raise nose.SkipTest("datetime.time not supported for sqlite fallback") + def _get_index_columns(self, tbl_name): + ixs = sql.read_sql_query( + "SELECT * FROM 
sqlite_master WHERE type = 'index' " + + "AND tbl_name = '%s'" % tbl_name, self.conn) + ix_cols = [] + for ix_name in ixs.name: + ix_info = sql.read_sql_query( + "PRAGMA index_info(%s)" % ix_name, self.conn) + ix_cols.append(ix_info.name.tolist()) + return ix_cols + + def test_to_sql_save_index(self): + self._to_sql_save_index() + class TestMySQLLegacy(TestSQLiteLegacy): """ @@ -1424,6 +1462,19 @@ def test_a_deprecation(self): sql.has_table('test_frame1', self.conn, flavor='mysql'), 'Table not written to DB') + def _get_index_columns(self, tbl_name): + ixs = sql.read_sql_query( + "SHOW INDEX IN %s" % tbl_name, self.conn) + ix_cols = {} + for ix_name, ix_col in zip(ixs.Key_name, ixs.Column_name): + if ix_name not in ix_cols: + ix_cols[ix_name] = [] + ix_cols[ix_name].append(ix_col) + return list(ix_cols.values()) + + def test_to_sql_save_index(self): + self._to_sql_save_index() + #------------------------------------------------------------------------------ #--- Old tests from 0.13.1 (before refactor using sqlalchemy)