Skip to content

Commit e3f53a4

Browse files
authored
Fix DataFrame.LoadCsv can not load CSV with duplicate column names (#6772)
1 parent 39235a7 commit e3f53a4

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

src/Microsoft.Data.Analysis/DataFrame.IO.cs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.Data.Common;
99
using System.Globalization;
1010
using System.IO;
11+
using System.Linq;
1112
using System.Text;
1213
using System.Threading.Tasks;
1314

@@ -349,8 +350,8 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int
349350
private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringReader wrappedReader,
350351
char separator = ',', bool header = true,
351352
string[] columnNames = null, Type[] dataTypes = null,
352-
long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false
353-
)
353+
long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false,
354+
bool renameDuplicatedColumns = false)
354355
{
355356
if (dataTypes == null && guessRows <= 0)
356357
{
@@ -376,6 +377,25 @@ private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringRe
376377
// First pass: schema and number of rows.
377378
while ((fields = parser.ReadFields()) != null)
378379
{
380+
if (renameDuplicatedColumns)
381+
{
382+
var names = new Dictionary<string, int>();
383+
384+
for (int i = 0; i < fields.Length; i++)
385+
{
386+
if (names.TryGetValue(fields[i], out int index))
387+
{
388+
var newName = String.Format("{0}.{1}", fields[i], index);
389+
names[fields[i]] = ++index;
390+
fields[i] = newName;
391+
}
392+
else
393+
{
394+
names.Add(fields[i], 1);
395+
}
396+
}
397+
}
398+
379399
if ((numberOfRowsToRead == -1) || rowline < numberOfRowsToRead)
380400
{
381401
if (linesForGuessType.Count < guessRows || (header && rowline == 0))
@@ -524,12 +544,13 @@ public static DataFrame LoadCsvFromString(string csvString,
524544
/// <param name="guessRows">number of rows used to guess types</param>
525545
/// <param name="addIndexColumn">add one column with the row index</param>
526546
/// <param name="encoding">The character encoding. Defaults to UTF8 if not specified</param>
547+
/// <param name="renameDuplicatedColumns">If set to true, columns with repeated names are auto-renamed.</param>
527548
/// <returns><see cref="DataFrame"/></returns>
528549
public static DataFrame LoadCsv(Stream csvStream,
529550
char separator = ',', bool header = true,
530551
string[] columnNames = null, Type[] dataTypes = null,
531552
long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false,
532-
Encoding encoding = null)
553+
Encoding encoding = null, bool renameDuplicatedColumns = false)
533554
{
534555
if (!csvStream.CanSeek)
535556
{
@@ -542,7 +563,7 @@ public static DataFrame LoadCsv(Stream csvStream,
542563
}
543564

544565
WrappedStreamReaderOrStringReader wrappedStreamReaderOrStringReader = new WrappedStreamReaderOrStringReader(csvStream, encoding ?? Encoding.UTF8);
545-
return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn);
566+
return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn, renameDuplicatedColumns);
546567
}
547568

548569
/// <summary>

test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,46 @@ void ReducedRowsTest(DataFrame reducedRows)
154154
ReducedRowsTest(csvDf);
155155
}
156156

157+
[Fact]
158+
public void TestReadCsvWithHeaderAndDuplicatedColumns_WithoutRenaming()
159+
{
160+
161+
string data = @$"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,payment_type,fare_amount
162+
CMT,1,1,1271,3.8,CRD,CRD,17.5
163+
CMT,1,1,474,1.5,CRD,CRD,8
164+
CMT,1,1,637,1.4,CRD,CRD,8.5
165+
CMT,1,1,181,0.6,CSH,CSH,4.5";
166+
167+
Assert.Throws<System.ArgumentException>(() => DataFrame.LoadCsv(GetStream(data)));
168+
}
169+
170+
[Fact]
171+
public void TestReadCsvWithHeaderAndDuplicatedColumns_WithDuplicateColumnRenaming()
172+
{
173+
174+
string data = @$"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,payment_type,payment_type,fare_amount
175+
CMT,1,1,1271,3.8,CRD,CRD_1,Test,17.5
176+
CMT,1,1,474,1.5,CRD,CRD,Test,8
177+
CMT,1,1,637,1.4,CRD,CRD,Test,8.5
178+
CMT,1,1,181,0.6,CSH,CSH,Test,4.5";
179+
180+
DataFrame df = DataFrame.LoadCsv(GetStream(data), renameDuplicatedColumns: true);
181+
182+
Assert.Equal(4, df.Rows.Count);
183+
Assert.Equal(9, df.Columns.Count);
184+
Assert.Equal("CMT", df.Columns["vendor_id"][3]);
185+
186+
Assert.Equal("payment_type", df.Columns[5].Name);
187+
Assert.Equal("payment_type.1", df.Columns[6].Name);
188+
Assert.Equal("payment_type.2", df.Columns[7].Name);
189+
190+
Assert.Equal("CRD", df.Columns["payment_type"][0]);
191+
Assert.Equal("CRD_1", df.Columns["payment_type.1"][0]);
192+
Assert.Equal("Test", df.Columns["payment_type.2"][0]);
193+
194+
VerifyColumnTypes(df);
195+
}
196+
157197
[Fact]
158198
public void TestReadCsvSplitAcrossMultipleLines()
159199
{

0 commit comments

Comments
 (0)