Skip to content

Commit 8670c12

Browse files
[release/9.0] Fix BigInteger.Rotate{Left,Right} for backport (#112991)
* Add BigInteger.Rotate* tests * Fix BigInteger.Rotate* * avoid stackalloc * Add comment * Fix the unsigned right shift operator of BigInteger (#112879) * Add tests for the shift operator of BigInteger * Fix the unsigned right shift operator of BigInteger * avoid stackalloc * external sign element --------- Co-authored-by: kzrnm <[email protected]>
1 parent b64f47a commit 8670c12

File tree

7 files changed

+1123
-121
lines changed

7 files changed

+1123
-121
lines changed

src/libraries/System.Runtime.Numerics/src/System/Numerics/BigInteger.cs

Lines changed: 83 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,7 +1694,7 @@ private static BigInteger Add(ReadOnlySpan<uint> leftBits, int leftSign, ReadOnl
16941694
}
16951695

16961696
if (bitsFromPool != null)
1697-
ArrayPool<uint>.Shared.Return(bitsFromPool);
1697+
ArrayPool<uint>.Shared.Return(bitsFromPool);
16981698

16991699
return result;
17001700
}
@@ -2629,7 +2629,7 @@ public static implicit operator BigInteger(nuint value)
26292629

26302630
if (zdFromPool != null)
26312631
ArrayPool<uint>.Shared.Return(zdFromPool);
2632-
exit:
2632+
exit:
26332633
if (xdFromPool != null)
26342634
ArrayPool<uint>.Shared.Return(xdFromPool);
26352635

@@ -3232,7 +3232,27 @@ public static BigInteger PopCount(BigInteger value)
32323232
public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
32333233
{
32343234
value.AssertValid();
3235-
int byteCount = (value._bits is null) ? sizeof(int) : (value._bits.Length * 4);
3235+
3236+
bool negx = value._sign < 0;
3237+
uint smallBits = NumericsHelpers.Abs(value._sign);
3238+
scoped ReadOnlySpan<uint> bits = value._bits;
3239+
if (bits.IsEmpty)
3240+
{
3241+
bits = new ReadOnlySpan<uint>(in smallBits);
3242+
}
3243+
3244+
int xl = bits.Length;
3245+
if (negx && (bits[^1] >= kuMaskHighBit) && ((bits[^1] != kuMaskHighBit) || bits.IndexOfAnyExcept(0u) != (bits.Length - 1)))
3246+
{
3247+
// We check for a special case where its sign bit could be outside the uint array after 2's complement conversion.
3248+
// For example given [0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF], its 2's complement is [0x01, 0x00, 0x00]
3249+
// After a 32 bit right shift, it becomes [0x00, 0x00] which is [0x00, 0x00] when converted back.
3250+
// The expected result is [0x00, 0x00, 0xFFFFFFFF] (2's complement) or [0x00, 0x00, 0x01] when converted back
3251+
// If the 2's component's last element is a 0, we will track the sign externally
3252+
++xl;
3253+
}
3254+
3255+
int byteCount = xl * 4;
32363256

32373257
// Normalize the rotate amount to drop full rotations
32383258
rotateAmount = (int)(rotateAmount % (byteCount * 8L));
@@ -3249,14 +3269,13 @@ public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
32493269
(int digitShift, int smallShift) = Math.DivRem(rotateAmount, kcbitUint);
32503270

32513271
uint[]? xdFromPool = null;
3252-
int xl = value._bits?.Length ?? 1;
3253-
32543272
Span<uint> xd = (xl <= BigIntegerCalculator.StackAllocThreshold)
32553273
? stackalloc uint[BigIntegerCalculator.StackAllocThreshold]
32563274
: xdFromPool = ArrayPool<uint>.Shared.Rent(xl);
32573275
xd = xd.Slice(0, xl);
3276+
xd[^1] = 0;
32583277

3259-
bool negx = value.GetPartsForBitManipulation(xd);
3278+
bits.CopyTo(xd);
32603279

32613280
int zl = xl;
32623281
uint[]? zdFromPool = null;
@@ -3367,7 +3386,28 @@ public static BigInteger RotateLeft(BigInteger value, int rotateAmount)
33673386
public static BigInteger RotateRight(BigInteger value, int rotateAmount)
33683387
{
33693388
value.AssertValid();
3370-
int byteCount = (value._bits is null) ? sizeof(int) : (value._bits.Length * 4);
3389+
3390+
3391+
bool negx = value._sign < 0;
3392+
uint smallBits = NumericsHelpers.Abs(value._sign);
3393+
scoped ReadOnlySpan<uint> bits = value._bits;
3394+
if (bits.IsEmpty)
3395+
{
3396+
bits = new ReadOnlySpan<uint>(in smallBits);
3397+
}
3398+
3399+
int xl = bits.Length;
3400+
if (negx && (bits[^1] >= kuMaskHighBit) && ((bits[^1] != kuMaskHighBit) || bits.IndexOfAnyExcept(0u) != (bits.Length - 1)))
3401+
{
3402+
// We check for a special case where its sign bit could be outside the uint array after 2's complement conversion.
3403+
// For example given [0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF], its 2's complement is [0x01, 0x00, 0x00]
3404+
// After a 32 bit right shift, it becomes [0x00, 0x00] which is [0x00, 0x00] when converted back.
3405+
// The expected result is [0x00, 0x00, 0xFFFFFFFF] (2's complement) or [0x00, 0x00, 0x01] when converted back
3406+
// If the 2's component's last element is a 0, we will track the sign externally
3407+
++xl;
3408+
}
3409+
3410+
int byteCount = xl * 4;
33713411

33723412
// Normalize the rotate amount to drop full rotations
33733413
rotateAmount = (int)(rotateAmount % (byteCount * 8L));
@@ -3384,14 +3424,13 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
33843424
(int digitShift, int smallShift) = Math.DivRem(rotateAmount, kcbitUint);
33853425

33863426
uint[]? xdFromPool = null;
3387-
int xl = value._bits?.Length ?? 1;
3388-
33893427
Span<uint> xd = (xl <= BigIntegerCalculator.StackAllocThreshold)
33903428
? stackalloc uint[BigIntegerCalculator.StackAllocThreshold]
33913429
: xdFromPool = ArrayPool<uint>.Shared.Rent(xl);
33923430
xd = xd.Slice(0, xl);
3431+
xd[^1] = 0;
33933432

3394-
bool negx = value.GetPartsForBitManipulation(xd);
3433+
bits.CopyTo(xd);
33953434

33963435
int zl = xl;
33973436
uint[]? zdFromPool = null;
@@ -3438,19 +3477,12 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
34383477
{
34393478
int carryShift = kcbitUint - smallShift;
34403479

3441-
int dstIndex = 0;
3442-
int srcIndex = digitShift;
3480+
int dstIndex = xd.Length - 1;
3481+
int srcIndex = digitShift == 0
3482+
? xd.Length - 1
3483+
: digitShift - 1;
34433484

3444-
uint carry = 0;
3445-
3446-
if (digitShift == 0)
3447-
{
3448-
carry = xd[^1] << carryShift;
3449-
}
3450-
else
3451-
{
3452-
carry = xd[srcIndex - 1] << carryShift;
3453-
}
3485+
uint carry = xd[digitShift] << carryShift;
34543486

34553487
do
34563488
{
@@ -3459,22 +3491,22 @@ public static BigInteger RotateRight(BigInteger value, int rotateAmount)
34593491
zd[dstIndex] = (part >> smallShift) | carry;
34603492
carry = part << carryShift;
34613493

3462-
dstIndex++;
3463-
srcIndex++;
3494+
dstIndex--;
3495+
srcIndex--;
34643496
}
3465-
while (srcIndex < xd.Length);
3497+
while ((uint)srcIndex < (uint)xd.Length); // is equivalent to (srcIndex >= 0 && srcIndex < xd.Length)
34663498

3467-
srcIndex = 0;
3499+
srcIndex = xd.Length - 1;
34683500

3469-
while (dstIndex < zd.Length)
3501+
while ((uint)dstIndex < (uint)zd.Length) // is equivalent to (dstIndex >= 0 && dstIndex < zd.Length)
34703502
{
34713503
uint part = xd[srcIndex];
34723504

34733505
zd[dstIndex] = (part >> smallShift) | carry;
34743506
carry = part << carryShift;
34753507

3476-
dstIndex++;
3477-
srcIndex++;
3508+
dstIndex--;
3509+
srcIndex--;
34783510
}
34793511
}
34803512

@@ -5232,13 +5264,32 @@ static bool INumberBase<BigInteger>.TryConvertToTruncating<TOther>(BigInteger va
52325264

52335265
BigInteger result;
52345266

5267+
bool negx = value._sign < 0;
5268+
uint smallBits = NumericsHelpers.Abs(value._sign);
5269+
scoped ReadOnlySpan<uint> bits = value._bits;
5270+
if (bits.IsEmpty)
5271+
{
5272+
bits = new ReadOnlySpan<uint>(in smallBits);
5273+
}
5274+
5275+
int xl = bits.Length;
5276+
if (negx && (bits[^1] >= kuMaskHighBit) && ((bits[^1] != kuMaskHighBit) || bits.IndexOfAnyExcept(0u) != (bits.Length - 1)))
5277+
{
5278+
// For a shift of N x 32 bit,
5279+
// We check for a special case where its sign bit could be outside the uint array after 2's complement conversion.
5280+
// For example given [0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF], its 2's complement is [0x01, 0x00, 0x00]
5281+
// After a 32 bit right shift, it becomes [0x00, 0x00] which is [0x00, 0x00] when converted back.
5282+
// The expected result is [0x00, 0x00, 0xFFFFFFFF] (2's complement) or [0x00, 0x00, 0x01] when converted back
5283+
// If the 2's component's last element is a 0, we will track the sign externally
5284+
++xl;
5285+
}
5286+
52355287
uint[]? xdFromPool = null;
5236-
int xl = value._bits?.Length ?? 1;
52375288
Span<uint> xd = (xl <= BigIntegerCalculator.StackAllocThreshold
52385289
? stackalloc uint[BigIntegerCalculator.StackAllocThreshold]
52395290
: xdFromPool = ArrayPool<uint>.Shared.Rent(xl)).Slice(0, xl);
5240-
5241-
bool negx = value.GetPartsForBitManipulation(xd);
5291+
xd[^1] = 0;
5292+
bits.CopyTo(xd);
52425293

52435294
if (negx)
52445295
{

src/libraries/System.Runtime.Numerics/tests/BigInteger/MyBigInt.cs

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,14 @@ public static BigInteger DoBinaryOperatorMine(BigInteger num1, BigInteger num2,
108108
return new BigInteger(Max(bytes1, bytes2).ToArray());
109109
case "b>>":
110110
return new BigInteger(ShiftLeft(bytes1, Negate(bytes2)).ToArray());
111+
case "b>>>":
112+
return new BigInteger(ShiftRightUnsigned(bytes1, bytes2).ToArray());
111113
case "b<<":
112114
return new BigInteger(ShiftLeft(bytes1, bytes2).ToArray());
115+
case "bRotateLeft":
116+
return new BigInteger(RotateLeft(bytes1, bytes2).ToArray());
117+
case "bRotateRight":
118+
return new BigInteger(RotateLeft(bytes1, Negate(bytes2)).ToArray());
113119
case "b^":
114120
return new BigInteger(Xor(bytes1, bytes2).ToArray());
115121
case "b|":
@@ -637,11 +643,68 @@ public static List<byte> Not(List<byte> bytes)
637643
return bnew;
638644
}
639645

646+
public static List<byte> ShiftRightUnsigned(List<byte> bytes1, List<byte> bytes2)
647+
{
648+
int byteShift = (int)new BigInteger(Divide(Copy(bytes2), new List<byte>(new byte[] { 8 })).ToArray());
649+
sbyte bitShift = (sbyte)new BigInteger(Remainder(Copy(bytes2), new List<byte>(new byte[] { 8 })).ToArray());
650+
651+
if (byteShift == 0 && bitShift == 0)
652+
return bytes1;
653+
654+
if (byteShift < 0 || bitShift < 0)
655+
return ShiftLeft(bytes1, Negate(bytes2));
656+
657+
Trim(bytes1);
658+
659+
byte fill = (bytes1[bytes1.Count - 1] & 0x80) != 0 ? byte.MaxValue : (byte)0;
660+
661+
if (fill == byte.MaxValue)
662+
{
663+
while (bytes1.Count % 4 != 0)
664+
{
665+
bytes1.Add(fill);
666+
}
667+
}
668+
669+
if (byteShift >= bytes1.Count)
670+
{
671+
return [fill];
672+
}
673+
674+
if (fill == byte.MaxValue)
675+
{
676+
bytes1.Add(0);
677+
}
678+
679+
for (int i = 0; i < bitShift; i++)
680+
{
681+
bytes1 = ShiftRight(bytes1);
682+
}
683+
684+
List<byte> temp = new List<byte>();
685+
for (int i = byteShift; i < bytes1.Count; i++)
686+
{
687+
temp.Add(bytes1[i]);
688+
}
689+
bytes1 = temp;
690+
691+
if (fill == byte.MaxValue && bytes1.Count % 4 == 1)
692+
{
693+
bytes1.RemoveAt(bytes1.Count - 1);
694+
}
695+
696+
Trim(bytes1);
697+
698+
return bytes1;
699+
}
700+
640701
public static List<byte> ShiftLeft(List<byte> bytes1, List<byte> bytes2)
641702
{
642703
int byteShift = (int)new BigInteger(Divide(Copy(bytes2), new List<byte>(new byte[] { 8 })).ToArray());
643704
sbyte bitShift = (sbyte)new BigInteger(Remainder(bytes2, new List<byte>(new byte[] { 8 })).ToArray());
644705

706+
Trim(bytes1);
707+
645708
for (int i = 0; i < Math.Abs(bitShift); i++)
646709
{
647710
if (bitShift < 0)
@@ -774,6 +837,105 @@ public static List<byte> ShiftRight(List<byte> bytes)
774837
return bresult;
775838
}
776839

840+
public static List<byte> RotateRight(List<byte> bytes)
841+
{
842+
List<byte> bresult = new List<byte>();
843+
844+
byte bottom = (byte)(bytes[0] & 0x01);
845+
846+
for (int i = 0; i < bytes.Count; i++)
847+
{
848+
byte newbyte = bytes[i];
849+
850+
newbyte = (byte)(newbyte / 2);
851+
if ((i != (bytes.Count - 1)) && ((bytes[i + 1] & 0x01) == 1))
852+
{
853+
newbyte += 128;
854+
}
855+
if ((i == (bytes.Count - 1)) && (bottom != 0))
856+
{
857+
newbyte += 128;
858+
}
859+
bresult.Add(newbyte);
860+
}
861+
862+
return bresult;
863+
}
864+
865+
public static List<byte> RotateLeft(List<byte> bytes)
866+
{
867+
List<byte> bresult = new List<byte>();
868+
869+
bool prevHead = (bytes[bytes.Count - 1] & 0x80) != 0;
870+
871+
for (int i = 0; i < bytes.Count; i++)
872+
{
873+
byte newbyte = bytes[i];
874+
875+
newbyte = (byte)(newbyte * 2);
876+
if (prevHead)
877+
{
878+
newbyte += 1;
879+
}
880+
881+
bresult.Add(newbyte);
882+
883+
prevHead = (bytes[i] & 0x80) != 0;
884+
}
885+
886+
return bresult;
887+
}
888+
889+
890+
public static List<byte> RotateLeft(List<byte> bytes1, List<byte> bytes2)
891+
{
892+
List<byte> bytes1Copy = Copy(bytes1);
893+
int byteShift = (int)new BigInteger(Divide(Copy(bytes2), new List<byte>(new byte[] { 8 })).ToArray());
894+
sbyte bitShift = (sbyte)new BigInteger(Remainder(bytes2, new List<byte>(new byte[] { 8 })).ToArray());
895+
896+
Trim(bytes1);
897+
898+
byte fill = (bytes1[bytes1.Count - 1] & 0x80) != 0 ? byte.MaxValue : (byte)0;
899+
900+
if (fill == 0 && bytes1.Count > 1 && bytes1[bytes1.Count - 1] == 0)
901+
bytes1.RemoveAt(bytes1.Count - 1);
902+
903+
while (bytes1.Count % 4 != 0)
904+
{
905+
bytes1.Add(fill);
906+
}
907+
908+
byteShift %= bytes1.Count;
909+
if (byteShift == 0 && bitShift == 0)
910+
return bytes1Copy;
911+
912+
for (int i = 0; i < Math.Abs(bitShift); i++)
913+
{
914+
if (bitShift < 0)
915+
{
916+
bytes1 = RotateRight(bytes1);
917+
}
918+
else
919+
{
920+
bytes1 = RotateLeft(bytes1);
921+
}
922+
}
923+
924+
List<byte> temp = new List<byte>();
925+
for (int i = 0; i < bytes1.Count; i++)
926+
{
927+
temp.Add(bytes1[(i - byteShift + bytes1.Count) % bytes1.Count]);
928+
}
929+
bytes1 = temp;
930+
931+
if (fill == 0)
932+
bytes1.Add(0);
933+
934+
Trim(bytes1);
935+
936+
return bytes1;
937+
}
938+
777939
public static List<byte> SetLength(List<byte> bytes, int size)
778940
{
779941
List<byte> bresult = new List<byte>();

0 commit comments

Comments
 (0)