diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 11b139b620175..a7588cc741352 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -196,15 +196,23 @@ def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): if_exists='append' else: if_exists='fail' + + if if_exists not in ('fail', 'replace', 'append'): + raise ValueError, "'%s' is not valid for if_exists" % if_exists + exists = table_exists(name, con, flavor) - if if_exists == 'fail' and exists: - raise ValueError, "Table '%s' already exists." % name - #create or drop-recreate if necessary + # creation/replacement dependent on the table existing and if_exist criteria create = None - if exists and if_exists == 'replace': - create = "DROP TABLE %s" % name - elif not exists: + if exists: + if if_exists == 'fail': + raise ValueError, "Table '%s' already exists." % name + elif if_exists == 'replace': + cur = con.cursor() + cur.execute("DROP TABLE %s;" % name) + cur.close() + create = get_schema(frame, name, flavor) + else: create = get_schema(frame, name, flavor) if create is not None: diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index 5b23bf173ec4e..0d4cd9b52023d 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -240,6 +240,65 @@ def test_onecolumn_of_integer(self): result = sql.read_frame("select * from mono_df",con_x) tm.assert_frame_equal(result,mono_df) + def test_if_exists(self): + df_if_exists_1 = DataFrame({'col1': [1, 2], 'col2': ['A', 'B']}) + df_if_exists_2 = DataFrame({'col1': [3, 4, 5], 'col2': ['C', 'D', 'E']}) + table_name = 'table_if_exists' + sql_select = "SELECT * FROM %s" % table_name + + def clean_up(test_table_to_drop): + """ + Drops tables created from individual tests + so no dependencies arise from sequential tests + """ + if sql.table_exists(test_table_to_drop, self.db, flavor='sqlite'): + cur = self.db.cursor() + cur.execute("DROP TABLE %s" % test_table_to_drop) + cur.close() + + # test if invalid value for if_exists raises appropriate error + self.assertRaises(ValueError, + sql.write_frame, + frame=df_if_exists_1, + con=self.db, + name=table_name, + flavor='sqlite', + if_exists='notvalidvalue') + clean_up(table_name) + + # test if_exists='fail' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='sqlite', if_exists='fail') + self.assertRaises(ValueError, + sql.write_frame, + frame=df_if_exists_1, + con=self.db, + name=table_name, + flavor='sqlite', + if_exists='fail') + + # test if_exists='replace' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='sqlite', if_exists='replace') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B')]) + sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + flavor='sqlite', if_exists='replace') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(3, 'C'), (4, 'D'), (5, 'E')]) + clean_up(table_name) + + # test if_exists='append' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='sqlite', if_exists='fail') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B')]) + sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + flavor='sqlite', if_exists='append') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B'), (3, 'C'), (4, 'D'), (5, 'E')]) + clean_up(table_name) + class TestMySQL(unittest.TestCase): @@ -483,6 +542,66 @@ def test_keyword_as_column_names(self): sql.write_frame(df, con = self.db, name = 'testkeywords', if_exists='replace', flavor='mysql') + def test_if_exists(self): + _skip_if_no_MySQLdb() + df_if_exists_1 = DataFrame({'col1': [1, 2], 'col2': ['A', 'B']}) + df_if_exists_2 = DataFrame({'col1': [3, 4, 5], 'col2': ['C', 'D', 'E']}) + table_name = 'table_if_exists' + sql_select = "SELECT * FROM %s" % table_name + + def clean_up(test_table_to_drop): + """ + Drops tables created from individual tests + so no dependencies arise from sequential tests + """ + if sql.table_exists(test_table_to_drop, self.db, flavor='mysql'): + cur = self.db.cursor() + cur.execute("DROP TABLE %s" % test_table_to_drop) + cur.close() + + # test if invalid value for if_exists raises appropriate error + self.assertRaises(ValueError, + sql.write_frame, + frame=df_if_exists_1, + con=self.db, + name=table_name, + flavor='mysql', + if_exists='notvalidvalue') + clean_up(table_name) + + # test if_exists='fail' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='mysql', if_exists='fail') + self.assertRaises(ValueError, + sql.write_frame, + frame=df_if_exists_1, + con=self.db, + name=table_name, + flavor='mysql', + if_exists='fail') + + # test if_exists='replace' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='mysql', if_exists='replace') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B')]) + sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + flavor='mysql', if_exists='replace') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(3, 'C'), (4, 'D'), (5, 'E')]) + clean_up(table_name) + + # test if_exists='append' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='mysql', if_exists='fail') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B')]) + sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + flavor='mysql', if_exists='append') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B'), (3, 'C'), (4, 'D'), (5, 'E')]) + clean_up(table_name) + if __name__ == '__main__': # unittest.main()