Skip to content

Fix DataFrame.LoadCsv can not load CSV with duplicate column names #6772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/Microsoft.Data.Analysis/DataFrame.IO.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Data.Common;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

Expand Down Expand Up @@ -349,8 +350,8 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int
private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringReader wrappedReader,
char separator = ',', bool header = true,
string[] columnNames = null, Type[] dataTypes = null,
long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false
)
long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false,
bool renameDuplicatedColumns = false)
{
if (dataTypes == null && guessRows <= 0)
{
Expand All @@ -376,6 +377,25 @@ private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringRe
// First pass: schema and number of rows.
while ((fields = parser.ReadFields()) != null)
{
// Only the header line holds column names. Renaming on every line would also
// rewrite repeated *values* in data rows (e.g. "1.5,1.5" -> "1.5,1.5.1"), and
// those rewritten fields would then be used for type guessing.
if (renameDuplicatedColumns && header && rowline == 0)
{
    // Maps each column name seen so far to the suffix the next duplicate will receive.
    var names = new Dictionary<string, int>();

    for (int i = 0; i < fields.Length; i++)
    {
        string name = fields[i];

        if (names.TryGetValue(name, out int index))
        {
            // Second and later occurrences become "name.1", "name.2", ...
            fields[i] = String.Format("{0}.{1}", name, index);
            names[name] = index + 1;
        }
        else
        {
            // First occurrence keeps its original name.
            names.Add(name, 1);
        }
    }
}

if ((numberOfRowsToRead == -1) || rowline < numberOfRowsToRead)
{
if (linesForGuessType.Count < guessRows || (header && rowline == 0))
Expand Down Expand Up @@ -524,12 +544,13 @@ public static DataFrame LoadCsvFromString(string csvString,
/// <param name="guessRows">number of rows used to guess types</param>
/// <param name="addIndexColumn">add one column with the row index</param>
/// <param name="encoding">The character encoding. Defaults to UTF8 if not specified</param>
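/// <param name="renameDuplicatedColumns">If set to true, columns with repeated names are auto-renamed by appending a numeric suffix (for example, a second and third <c>payment_type</c> column become <c>payment_type.1</c> and <c>payment_type.2</c>).</param>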
/// <returns><see cref="DataFrame"/></returns>
public static DataFrame LoadCsv(Stream csvStream,
char separator = ',', bool header = true,
string[] columnNames = null, Type[] dataTypes = null,
long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false,
Encoding encoding = null)
Encoding encoding = null, bool renameDuplicatedColumns = false)
{
if (!csvStream.CanSeek)
{
Expand All @@ -542,7 +563,7 @@ public static DataFrame LoadCsv(Stream csvStream,
}

WrappedStreamReaderOrStringReader wrappedStreamReaderOrStringReader = new WrappedStreamReaderOrStringReader(csvStream, encoding ?? Encoding.UTF8);
return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn);
return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn, renameDuplicatedColumns);
}

/// <summary>
Expand Down
40 changes: 40 additions & 0 deletions test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,46 @@ void ReducedRowsTest(DataFrame reducedRows)
ReducedRowsTest(csvDf);
}

[Fact]
public void TestReadCsvWithHeaderAndDuplicatedColumns_WithoutRenaming()
{
    // The header repeats "payment_type"; LoadCsv is called without renameDuplicatedColumns.
    string csv = @$"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,payment_type,fare_amount
CMT,1,1,1271,3.8,CRD,CRD,17.5
CMT,1,1,474,1.5,CRD,CRD,8
CMT,1,1,637,1.4,CRD,CRD,8.5
CMT,1,1,181,0.6,CSH,CSH,4.5";

    System.Func<object> load = () => DataFrame.LoadCsv(GetStream(csv));

    // Loading is expected to be rejected with an ArgumentException.
    Assert.Throws<System.ArgumentException>(load);
}

[Fact]
public void TestReadCsvWithHeaderAndDuplicatedColumns_WithDuplicateColumnRenaming()
{
    // The header contains "payment_type" three times; each of those columns
    // carries a different value in the first row so they can be told apart.
    string csv = @$"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,payment_type,payment_type,fare_amount
CMT,1,1,1271,3.8,CRD,CRD_1,Test,17.5
CMT,1,1,474,1.5,CRD,CRD,Test,8
CMT,1,1,637,1.4,CRD,CRD,Test,8.5
CMT,1,1,181,0.6,CSH,CSH,Test,4.5";

    DataFrame frame = DataFrame.LoadCsv(GetStream(csv), renameDuplicatedColumns: true);

    // Overall shape and a spot check on a non-duplicated column.
    Assert.Equal(4, frame.Rows.Count);
    Assert.Equal(9, frame.Columns.Count);
    Assert.Equal("CMT", frame.Columns["vendor_id"][3]);

    // The three duplicated columns sit at positions 5, 6 and 7.
    const int firstDuplicateIndex = 5;
    string[] expectedNames = { "payment_type", "payment_type.1", "payment_type.2" };
    string[] expectedFirstRowValues = { "CRD", "CRD_1", "Test" };

    // Names: the first occurrence is kept, later ones get ".1", ".2".
    for (int offset = 0; offset < expectedNames.Length; offset++)
    {
        Assert.Equal(expectedNames[offset], frame.Columns[firstDuplicateIndex + offset].Name);
    }

    // Values: each renamed column is addressable by its new name.
    for (int offset = 0; offset < expectedNames.Length; offset++)
    {
        Assert.Equal(expectedFirstRowValues[offset], frame.Columns[expectedNames[offset]][0]);
    }

    VerifyColumnTypes(frame);
}

[Fact]
public void TestReadCsvSplitAcrossMultipleLines()
{
Expand Down