Skip to content

BUG: When creating table, db indexes should be created from DataFrame indexes #8083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 11, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,17 @@ def __init__(self, name, pandas_sql_engine, frame=None, index=True,
raise ValueError("Table '%s' already exists." % name)
elif if_exists == 'replace':
self.pd_sql.drop_table(self.name, self.schema)
self.table = self._create_table_statement()
self.table = self._create_table_setup()
self.create()
elif if_exists == 'append':
self.table = self.pd_sql.get_table(self.name, self.schema)
if self.table is None:
self.table = self._create_table_statement()
self.table = self._create_table_setup()
else:
raise ValueError(
"'{0}' is not valid for if_exists".format(if_exists))
else:
self.table = self._create_table_statement()
self.table = self._create_table_setup()
self.create()
else:
# no data provided, read-only mode
Expand Down Expand Up @@ -703,23 +703,25 @@ def _get_column_names_and_types(self, dtype_mapper):
for i, idx_label in enumerate(self.index):
idx_type = dtype_mapper(
self.frame.index.get_level_values(i))
column_names_and_types.append((idx_label, idx_type))
column_names_and_types.append((idx_label, idx_type, True))

column_names_and_types += [
(str(self.frame.columns[i]),
dtype_mapper(self.frame.iloc[:,i]))
dtype_mapper(self.frame.iloc[:,i]),
False)
for i in range(len(self.frame.columns))
]

return column_names_and_types

def _create_table_statement(self):
def _create_table_setup(self):
from sqlalchemy import Table, Column

column_names_and_types = \
self._get_column_names_and_types(self._sqlalchemy_type)

columns = [Column(name, typ)
for name, typ in column_names_and_types]
columns = [Column(name, typ, index=is_index)
for name, typ, is_index in column_names_and_types]

return Table(self.name, self.pd_sql.meta, *columns, schema=self.schema)

Expand Down Expand Up @@ -979,10 +981,12 @@ class PandasSQLTableLegacy(PandasSQLTable):
Instead of a table variable just use the Create Table
statement"""
def sql_schema(self):
return str(self.table)
return str(";\n".join(self.table))

def create(self):
self.pd_sql.execute(self.table)
with self.pd_sql.con:
for stmt in self.table:
self.pd_sql.execute(stmt)

def insert_statement(self):
names = list(map(str, self.frame.columns))
Expand Down Expand Up @@ -1026,14 +1030,17 @@ def insert(self, chunksize=None):
cur.executemany(ins, data_list)
cur.close()

def _create_table_statement(self):
"Return a CREATE TABLE statement to suit the contents of a DataFrame."
def _create_table_setup(self):
"""Return a list of SQL statement that create a table reflecting the
structure of a DataFrame. The first entry will be a CREATE TABLE
statement while the rest will be CREATE INDEX statements
"""

column_names_and_types = \
self._get_column_names_and_types(self._sql_type_name)

pat = re.compile('\s+')
column_names = [col_name for col_name, _ in column_names_and_types]
column_names = [col_name for col_name, _, _ in column_names_and_types]
if any(map(pat.search, column_names)):
warnings.warn(_SAFE_NAMES_WARNING)

Expand All @@ -1044,13 +1051,21 @@ def _create_table_statement(self):

col_template = br_l + '%s' + br_r + ' %s'

columns = ',\n '.join(col_template %
x for x in column_names_and_types)
columns = ',\n '.join(col_template % (cname, ctype)
for cname, ctype, _ in column_names_and_types)
template = """CREATE TABLE %(name)s (
%(columns)s
)"""
create_statement = template % {'name': self.name, 'columns': columns}
return create_statement
create_stmts = [template % {'name': self.name, 'columns': columns}, ]

ix_tpl = "CREATE INDEX ix_{tbl}_{col} ON {tbl} ({br_l}{col}{br_r})"
for cname, _, is_index in column_names_and_types:
if not is_index:
continue
create_stmts.append(ix_tpl.format(tbl=self.name, col=cname,
br_l=br_l, br_r=br_r))

return create_stmts

def _sql_type_name(self, col):
pytype = col.dtype.type
Expand Down
53 changes: 52 additions & 1 deletion pandas/io/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _load_test2_data(self):
E=['1990-11-22', '1991-10-26', '1993-11-26', '1995-12-12']))
df['E'] = to_datetime(df['E'])

self.test_frame3 = df
self.test_frame2 = df

def _load_test3_data(self):
columns = ['index', 'A', 'B']
Expand Down Expand Up @@ -324,6 +324,13 @@ def _execute_sql(self):
row = iris_results.fetchone()
tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa'])

def _to_sql_save_index(self):
df = DataFrame.from_records([(1,2.1,'line1'), (2,1.5,'line2')],
columns=['A','B','C'], index=['A'])
self.pandasSQL.to_sql(df, 'test_to_sql_saves_index')
ix_cols = self._get_index_columns('test_to_sql_saves_index')
self.assertEqual(ix_cols, [['A',],])


#------------------------------------------------------------------------------
#--- Testing the public API
Expand Down Expand Up @@ -694,6 +701,13 @@ def test_warning_case_insensitive_table_name(self):
# Verify some things
self.assertEqual(len(w), 0, "Warning triggered for writing a table")

def _get_index_columns(self, tbl_name):
from sqlalchemy.engine import reflection
insp = reflection.Inspector.from_engine(self.conn)
ixs = insp.get_indexes('test_index_saved')
ixs = [i['column_names'] for i in ixs]
return ixs


class TestSQLLegacyApi(_TestSQLApi):
"""
Expand Down Expand Up @@ -1074,6 +1088,16 @@ def test_nan_string(self):
result = sql.read_sql_query('SELECT * FROM test_nan', self.conn)
tm.assert_frame_equal(result, df)

def _get_index_columns(self, tbl_name):
from sqlalchemy.engine import reflection
insp = reflection.Inspector.from_engine(self.conn)
ixs = insp.get_indexes(tbl_name)
ixs = [i['column_names'] for i in ixs]
return ixs

def test_to_sql_save_index(self):
self._to_sql_save_index()


class TestSQLiteAlchemy(_TestSQLAlchemy):
"""
Expand Down Expand Up @@ -1368,6 +1392,20 @@ def test_datetime_time(self):
# test support for datetime.time
raise nose.SkipTest("datetime.time not supported for sqlite fallback")

def _get_index_columns(self, tbl_name):
ixs = sql.read_sql_query(
"SELECT * FROM sqlite_master WHERE type = 'index' " +
"AND tbl_name = '%s'" % tbl_name, self.conn)
ix_cols = []
for ix_name in ixs.name:
ix_info = sql.read_sql_query(
"PRAGMA index_info(%s)" % ix_name, self.conn)
ix_cols.append(ix_info.name.tolist())
return ix_cols

def test_to_sql_save_index(self):
self._to_sql_save_index()


class TestMySQLLegacy(TestSQLiteLegacy):
"""
Expand Down Expand Up @@ -1424,6 +1462,19 @@ def test_a_deprecation(self):
sql.has_table('test_frame1', self.conn, flavor='mysql'),
'Table not written to DB')

def _get_index_columns(self, tbl_name):
ixs = sql.read_sql_query(
"SHOW INDEX IN %s" % tbl_name, self.conn)
ix_cols = {}
for ix_name, ix_col in zip(ixs.Key_name, ixs.Column_name):
if ix_name not in ix_cols:
ix_cols[ix_name] = []
ix_cols[ix_name].append(ix_col)
return list(ix_cols.values())

def test_to_sql_save_index(self):
self._to_sql_save_index()


#------------------------------------------------------------------------------
#--- Old tests from 0.13.1 (before refactor using sqlalchemy)
Expand Down