|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license.
|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
| 5 | +using System; |
| 6 | +using System.Collections.Generic; |
| 7 | +using System.Linq; |
5 | 8 | using Microsoft.Data.DataView;
|
6 | 9 | using Microsoft.ML.Data;
|
7 |
| -using Microsoft.ML.SamplesUtils; |
8 |
| -using Microsoft.ML.Trainers.HalLearners; |
| 10 | +using Microsoft.ML.Functional.Tests.Datasets; |
9 | 11 | using Xunit;
|
10 | 12 |
|
11 | 13 | namespace Microsoft.ML.Functional.Tests
|
12 | 14 | {
|
13 | 15 | internal static class Common
|
14 | 16 | {
|
| 17 | + /// <summary> |
| 18 | + /// Asssert that an <see cref="IDataView"/> rows are of <see cref="TypeTestData"/>. |
| 19 | + /// </summary> |
| 20 | + /// <param name="testTypeDataset">An <see cref="IDataView"/>.</param> |
| 21 | + public static void AssertTypeTestDataset(IDataView testTypeDataset) |
| 22 | + { |
| 23 | + var toyClassProperties = typeof(TypeTestData).GetProperties(); |
| 24 | + |
| 25 | + // Check that the schema is of the right size. |
| 26 | + Assert.Equal(toyClassProperties.Length, testTypeDataset.Schema.Count); |
| 27 | + |
| 28 | + // Create a lookup table for the types and counts of all properties. |
| 29 | + var types = new Dictionary<string, Type>(); |
| 30 | + var counts = new Dictionary<string, int>(); |
| 31 | + foreach (var property in toyClassProperties) |
| 32 | + { |
| 33 | + if (!property.PropertyType.IsArray) |
| 34 | + types[property.Name] = property.PropertyType; |
| 35 | + else |
| 36 | + { |
| 37 | + // Construct a VBuffer type for the array. |
| 38 | + var vBufferType = typeof(VBuffer<>); |
| 39 | + Type[] typeArgs = { property.PropertyType.GetElementType() }; |
| 40 | + Activator.CreateInstance(property.PropertyType.GetElementType()); |
| 41 | + types[property.Name] = vBufferType.MakeGenericType(typeArgs); |
| 42 | + } |
| 43 | + |
| 44 | + counts[property.Name] = 0; |
| 45 | + } |
| 46 | + |
| 47 | + foreach (var column in testTypeDataset.Schema) |
| 48 | + { |
| 49 | + Assert.True(types.ContainsKey(column.Name)); |
| 50 | + Assert.Equal(1, ++counts[column.Name]); |
| 51 | + Assert.Equal(types[column.Name], column.Type.RawType); |
| 52 | + } |
| 53 | + |
| 54 | + // Make sure we didn't miss any columns. |
| 55 | + foreach (var value in counts.Values) |
| 56 | + Assert.Equal(1, value); |
| 57 | + } |
| 58 | + |
| 59 | + /// <summary> |
| 60 | + /// Assert than two <see cref="TypeTestData"/> datasets are equal. |
| 61 | + /// </summary> |
| 62 | + /// <param name="mlContext">The ML Context.</param> |
| 63 | + /// <param name="data1">A <see cref="IDataView"/> of <see cref="TypeTestData"/></param> |
| 64 | + /// <param name="data2">A <see cref="IDataView"/> of <see cref="TypeTestData"/></param> |
| 65 | + public static void AssertTestTypeDatasetsAreEqual(MLContext mlContext, IDataView data1, IDataView data2) |
| 66 | + { |
| 67 | + // Confirm that they are both of the propery row type. |
| 68 | + AssertTypeTestDataset(data1); |
| 69 | + AssertTypeTestDataset(data2); |
| 70 | + |
| 71 | + // Validate that the two Schemas are the same. |
| 72 | + Common.AssertEqual(data1.Schema, data2.Schema); |
| 73 | + |
| 74 | + // Define how to serialize the IDataView to objects. |
| 75 | + var enumerable1 = mlContext.CreateEnumerable<TypeTestData>(data1, true); |
| 76 | + var enumerable2 = mlContext.CreateEnumerable<TypeTestData>(data2, true); |
| 77 | + |
| 78 | + AssertEqual(enumerable1, enumerable2); |
| 79 | + } |
| 80 | + |
| 81 | + /// <summary> |
| 82 | + /// Assert that two float arrays are equal. |
| 83 | + /// </summary> |
| 84 | + /// <param name="array1">An array of floats.</param> |
| 85 | + /// <param name="array2">An array of floats.</param> |
| 86 | + public static void AssertEqual(float[] array1, float[] array2) |
| 87 | + { |
| 88 | + Assert.NotNull(array1); |
| 89 | + Assert.NotNull(array2); |
| 90 | + Assert.Equal(array1.Length, array2.Length); |
| 91 | + |
| 92 | + for (int i = 0; i < array1.Length; i++) |
| 93 | + Assert.Equal(array1[i], array2[i]); |
| 94 | + } |
| 95 | + |
| 96 | + /// <summary> |
| 97 | + /// Assert that two <see cref="Schema"/> objects are equal. |
| 98 | + /// </summary> |
| 99 | + /// <param name="schema1">A <see cref="Schema"/> object.</param> |
| 100 | + /// <param name="schema2">A <see cref="Schema"/> object.</param> |
| 101 | + public static void AssertEqual(Schema schema1, Schema schema2) |
| 102 | + { |
| 103 | + Assert.NotNull(schema1); |
| 104 | + Assert.NotNull(schema2); |
| 105 | + |
| 106 | + Assert.Equal(schema1.Count(), schema2.Count()); |
| 107 | + |
| 108 | + foreach (var schemaPair in schema1.Zip(schema2, Tuple.Create)) |
| 109 | + { |
| 110 | + Assert.Equal(schemaPair.Item1.Name, schemaPair.Item2.Name); |
| 111 | + Assert.Equal(schemaPair.Item1.Index, schemaPair.Item2.Index); |
| 112 | + Assert.Equal(schemaPair.Item1.IsHidden, schemaPair.Item2.IsHidden); |
| 113 | + // Can probably do a better comparison of Metadata. |
| 114 | + AssertEqual(schemaPair.Item1.Metadata.Schema, schemaPair.Item1.Metadata.Schema); |
| 115 | + Assert.True((schemaPair.Item1.Type == schemaPair.Item2.Type) || |
| 116 | + (schemaPair.Item1.Type.RawType == schemaPair.Item2.Type.RawType)); |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + /// <summary> |
| 121 | + /// Assert than two <see cref="TypeTestData"/> enumerables are equal. |
| 122 | + /// </summary> |
| 123 | + /// <param name="data1">An enumerable of <see cref="TypeTestData"/></param> |
| 124 | + /// <param name="data2">An enumerable of <see cref="TypeTestData"/></param> |
| 125 | + public static void AssertEqual(IEnumerable<TypeTestData> data1, IEnumerable<TypeTestData> data2) |
| 126 | + { |
| 127 | + Assert.NotNull(data1); |
| 128 | + Assert.NotNull(data2); |
| 129 | + Assert.Equal(data1.Count(), data2.Count()); |
| 130 | + |
| 131 | + foreach (var rowPair in data1.Zip(data2, Tuple.Create)) |
| 132 | + { |
| 133 | + AssertEqual(rowPair.Item1, rowPair.Item2); |
| 134 | + } |
| 135 | + } |
| 136 | + |
| 137 | + /// <summary> |
| 138 | + /// Assert that two TypeTest datasets are equal. |
| 139 | + /// </summary> |
| 140 | + /// <param name="testType1">An <see cref="TypeTestData"/>.</param> |
| 141 | + /// <param name="testType2">An <see cref="TypeTestData"/>.</param> |
| 142 | + public static void AssertEqual(TypeTestData testType1, TypeTestData testType2) |
| 143 | + { |
| 144 | + Assert.Equal(testType1.Label, testType2.Label); |
| 145 | + Common.AssertEqual(testType1.Features, testType2.Features); |
| 146 | + Assert.Equal(testType1.I1, testType2.I1); |
| 147 | + Assert.Equal(testType1.U1, testType2.U1); |
| 148 | + Assert.Equal(testType1.I2, testType2.I2); |
| 149 | + Assert.Equal(testType1.U2, testType2.U2); |
| 150 | + Assert.Equal(testType1.I4, testType2.I4); |
| 151 | + Assert.Equal(testType1.U4, testType2.U4); |
| 152 | + Assert.Equal(testType1.I8, testType2.I8); |
| 153 | + Assert.Equal(testType1.U8, testType2.U8); |
| 154 | + Assert.Equal(testType1.R4, testType2.R4); |
| 155 | + Assert.Equal(testType1.R8, testType2.R8); |
| 156 | + Assert.Equal(testType1.Tx.ToString(), testType2.Tx.ToString()); |
| 157 | + Assert.True(testType1.Ts.Equals(testType2.Ts)); |
| 158 | + Assert.True(testType1.Dt.Equals(testType2.Dt)); |
| 159 | + Assert.True(testType1.Dz.Equals(testType2.Dz)); |
| 160 | + Assert.True(testType1.Ug.Equals(testType2.Ug)); |
| 161 | + } |
| 162 | + |
| 163 | + /// <summary> |
| 164 | + /// Check that a <see cref="RegressionMetrics"/> object is valid. |
| 165 | + /// </summary> |
| 166 | + /// <param name="metrics">The metrics object.</param> |
15 | 167 | public static void CheckMetrics(RegressionMetrics metrics)
|
16 | 168 | {
|
17 |
| - // Perform sanity checks on the metrics |
| 169 | + // Perform sanity checks on the metrics. |
18 | 170 | Assert.True(metrics.Rms >= 0);
|
19 | 171 | Assert.True(metrics.L1 >= 0);
|
20 | 172 | Assert.True(metrics.L2 >= 0);
|
|
0 commit comments