Skip to content

[NRBF] Fix bugs discovered by the fuzzer #107368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 6, 2024
5 changes: 4 additions & 1 deletion src/libraries/System.Formats.Nrbf/src/Resources/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
<value>{0} Record Type is not supported by design.</value>
</data>
<data name="Serialization_InvalidReference" xml:space="preserve">
<value>Member reference was pointing to a record of unexpected type.</value>
<value>Invalid member reference.</value>
</data>
<data name="Serialization_InvalidTypeName" xml:space="preserve">
<value>Invalid type name: `{0}`.</value>
Expand Down Expand Up @@ -162,4 +162,7 @@
<data name="Serialization_InvalidAssemblyName" xml:space="preserve">
<value>Invalid assembly name: `{0}`.</value>
</data>
<data name="Serialization_InvalidFormat" xml:space="preserve">
<value>Invalid format.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal static ClassTypeInfo Decode(BinaryReader reader, PayloadOptions options
string rawName = reader.ReadString();
SerializationRecordId libraryId = SerializationRecordId.Decode(reader);

BinaryLibraryRecord library = (BinaryLibraryRecord)recordMap[libraryId];
BinaryLibraryRecord library = recordMap.GetRecord<BinaryLibraryRecord>(libraryId);

return new ClassTypeInfo(rawName.ParseNonSystemClassRecordTypeName(library, options));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ internal static ClassWithIdRecord Decode(
SerializationRecordId id = SerializationRecordId.Decode(reader);
SerializationRecordId metadataId = SerializationRecordId.Decode(reader);

if (recordMap[metadataId] is not ClassRecord referencedRecord)
{
throw new SerializationException(SR.Serialization_InvalidReference);
}
ClassRecord referencedRecord = recordMap.GetRecord<ClassRecord>(metadataId);

return new ClassWithIdRecord(id, referencedRecord);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ internal static ClassWithMembersAndTypesRecord Decode(BinaryReader reader, Recor
MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, classInfo.MemberNames.Count, options, recordMap);
SerializationRecordId libraryId = SerializationRecordId.Decode(reader);

BinaryLibraryRecord library = (BinaryLibraryRecord)recordMap[libraryId];
BinaryLibraryRecord library = recordMap.GetRecord<BinaryLibraryRecord>(libraryId);
classInfo.LoadTypeName(library, options);

return new ClassWithMembersAndTypesRecord(classInfo, memberTypeInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ private MemberReferenceRecord(SerializationRecordId reference, RecordMap recordM
internal static MemberReferenceRecord Decode(BinaryReader reader, RecordMap recordMap)
=> new(SerializationRecordId.Decode(reader), recordMap);

internal SerializationRecord GetReferencedRecord() => RecordMap[Reference];
internal SerializationRecord GetReferencedRecord() => RecordMap.GetRecord(Reference);
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,14 @@ public static SerializationRecord Decode(Stream payload, out IReadOnlyDictionary
#endif

using BinaryReader reader = new(payload, ThrowOnInvalidUtf8Encoding, leaveOpen: leaveOpen);
return Decode(reader, options ?? new(), out recordMap);
try
{
return Decode(reader, options ?? new(), out recordMap);
}
catch (FormatException) // can be thrown by various BinaryReader methods
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pass to inner exception?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't do that on purpose, to avoid leaking any information from the inner exception that could be attacker controlled (like some weird string that could somehow affect the log file).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{
throw new SerializationException(SR.Serialization_InvalidFormat);
}
}

/// <summary>
Expand Down Expand Up @@ -213,12 +220,7 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op
private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap recordMap,
AllowedRecordTypes allowed, PayloadOptions options, out SerializationRecordType recordType)
{
byte nextByte = reader.ReadByte();
if (((uint)allowed & (1u << nextByte)) == 0)
{
ThrowHelper.ThrowForUnexpectedRecordType(nextByte);
}
recordType = (SerializationRecordType)nextByte;
recordType = reader.ReadSerializationRecordType(allowed);

SerializationRecord record = recordType switch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ internal void Add(SerializationRecord record)

internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header)
{
SerializationRecord rootRecord = _map[header.RootId];
SerializationRecord rootRecord = GetRecord(header.RootId);

if (rootRecord is SystemClassWithMembersAndTypesRecord systemClass)
{
// update the record map, so it's visible also to those who access it via Id
Expand All @@ -72,4 +73,14 @@ internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header)

return rootRecord;
}

internal SerializationRecord GetRecord(SerializationRecordId recordId)
=> _map.TryGetValue(recordId, out SerializationRecord? record)
? record
: throw new SerializationException(SR.Serialization_InvalidReference);

internal T GetRecord<T>(SerializationRecordId recordId) where T : SerializationRecord
=> _map.TryGetValue(recordId, out SerializationRecord? record) && record is T casted
? casted
: throw new SerializationException(SR.Serialization_InvalidReference);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ internal static class BinaryReaderExtensions
{
private static object? s_baseAmbiguousDstDateTime;

internal static SerializationRecordType ReadSerializationRecordType(this BinaryReader reader, AllowedRecordTypes allowed)
{
byte nextByte = reader.ReadByte();
if (nextByte > (byte)SerializationRecordType.MethodReturn // MethodReturn is the last defined value.
|| (nextByte > (byte)SerializationRecordType.ArraySingleString && nextByte < (byte)SerializationRecordType.MethodCall) // not part of the spec
|| ((uint)allowed & (1u << nextByte)) == 0) // valid, but not allowed
{
ThrowHelper.ThrowForUnexpectedRecordType(nextByte);
}

return (SerializationRecordType)nextByte;
}

internal static BinaryArrayType ReadArrayType(this BinaryReader reader)
{
byte arrayType = reader.ReadByte();
Expand Down
59 changes: 59 additions & 0 deletions src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,63 @@ public void ThrowsOnInvalidArrayType()
stream.Position = 0;
Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}

[Theory]
[InlineData(18, typeof(NotSupportedException))] // not part of the spec, but still less than max allowed value (22)
[InlineData(19, typeof(NotSupportedException))] // same as above
[InlineData(20, typeof(NotSupportedException))] // same as above
[InlineData(23, typeof(SerializationException))] // not part of the spec and more than max allowed value (22)
[InlineData(64, typeof(SerializationException))] // same as above but also matches AllowedRecordTypes.SerializedStreamHeader
public void InvalidSerializationRecordType(byte recordType, Type expectedException)
{
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);

WriteSerializedStreamHeader(writer);
writer.Write(recordType); // SerializationRecordType
writer.Write((byte)SerializationRecordType.MessageEnd);

stream.Position = 0;

Assert.Throws(expectedException, () => NrbfDecoder.Decode(stream));
}

[Fact]
public void MissingRootRecord()
{
const int RootRecordId = 1;
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);

WriteSerializedStreamHeader(writer, rootId: RootRecordId);
writer.Write((byte)SerializationRecordType.BinaryObjectString);
writer.Write(RootRecordId + 1); // a different ID
writer.Write("theString");
writer.Write((byte)SerializationRecordType.MessageEnd);

stream.Position = 0;

Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}

[Fact]
public void Invalid7BitEncodedStringLength()
{
// The highest bit of the last byte is set (so it's invalid).
byte[] invalidLength = [byte.MaxValue, byte.MaxValue, byte.MaxValue, byte.MaxValue, byte.MaxValue];

using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);

WriteSerializedStreamHeader(writer);
writer.Write((byte)SerializationRecordType.BinaryObjectString);
writer.Write(1); // root record Id
writer.Write(invalidLength); // the length prefix
writer.Write(Encoding.UTF8.GetBytes("theString"));
writer.Write((byte)SerializationRecordType.MessageEnd);

stream.Position = 0;

Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}
}
4 changes: 2 additions & 2 deletions src/libraries/System.Formats.Nrbf/tests/ReadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ protected static BinaryFormatter CreateBinaryFormatter()
};
#pragma warning restore SYSLIB0011 // Type or member is obsolete

protected static void WriteSerializedStreamHeader(BinaryWriter writer, int major = 1, int minor = 0)
protected static void WriteSerializedStreamHeader(BinaryWriter writer, int major = 1, int minor = 0, int rootId = 1)
{
writer.Write((byte)SerializationRecordType.SerializedStreamHeader);
writer.Write(1); // root ID
writer.Write(rootId); // root ID
writer.Write(1); // header ID
writer.Write(major); // major version
writer.Write(minor); // minor version
Expand Down
Loading