]> Repositorios git - scryer-prolog.git/commitdiff
Fix integer overflow in >>/2 and <</2
authorEmilie Burgun <[email protected]>
Sun, 26 Jan 2025 22:56:52 +0000 (23:56 +0100)
committerEmilie Burgun <[email protected]>
Sun, 26 Jan 2025 23:10:47 +0000 (00:10 +0100)
src/machine/arithmetic_ops.rs

index 5328901247ed8c3bb6067b663d0af81beb3eb744..eab7e7080102e62a184ac4b4669668b8fb0d23da 100644 (file)
@@ -629,103 +629,110 @@ pub(crate) fn int_floor_div(
     idiv(n1, n2, arena)
 }
 
-pub(crate) fn shr(n1: Number, n2: Number, arena: &mut Arena) -> Result<Number, MachineStubGen> {
+pub(crate) fn shr(lhs: Number, rhs: Number, arena: &mut Arena) -> Result<Number, MachineStubGen> {
     let stub_gen = || {
         let shr_atom = atom!(">>");
         functor_stub(shr_atom, 2)
     };
 
-    if n2.is_integer() && n2.is_negative() {
-        return shl(n1, neg(n2, arena), arena);
+    if rhs.is_integer() && rhs.is_negative() {
+        return shl(lhs, neg(rhs, arena), arena);
     }
 
-    match (n1, n2) {
-        (Number::Fixnum(n1), Number::Fixnum(n2)) => {
-            let n1_i = n1.get_num();
-            let n2_i = n2.get_num();
-
-            // FIXME(arithmetic_overflow)
-            // what should this do for too large n2,
-            // - logical right shift should probably turn to 0
-            // - arithmetic right shift should maybe differ for negative numbers
-            //
-            // note: negaitve n2 is already handled above
-            #[allow(arithmetic_overflow)]
-            if let Ok(n2) = usize::try_from(n2_i) {
-                Ok(Number::arena_from(n1_i >> n2, arena))
-            } else {
-                Ok(Number::arena_from(n1_i >> usize::MAX, arena))
-            }
-        }
-        (Number::Fixnum(n1), Number::Integer(n2)) => {
-            let n1 = Integer::from(n1.get_num());
-
-            let result: Result<usize, _> = (&*n2).try_into();
+    match lhs {
+        Number::Fixnum(lhs) => {
+            let rhs = match rhs {
+                Number::Fixnum(fix) => fix.get_num().try_into().unwrap_or(u32::MAX),
+                Number::Integer(int) => (&*int).try_into().unwrap_or(u32::MAX),
+                other => {
+                    return Err(numerical_type_error(ValidType::Integer, other, stub_gen));
+                }
+            };
 
-            match result {
-                Ok(n2) => Ok(Number::arena_from(n1 >> n2, arena)),
-                Err(_) => Ok(Number::arena_from(n1 >> usize::MAX, arena)),
-            }
-        }
-        (Number::Integer(n1), Number::Fixnum(n2)) => match usize::try_from(n2.get_num()) {
-            Ok(n2) => Ok(Number::arena_from(Integer::from(&*n1 >> n2), arena)),
-            _ => Ok(Number::arena_from(Integer::from(&*n1 >> usize::MAX), arena)),
-        },
-        (Number::Integer(n1), Number::Integer(n2)) => {
-            let result: Result<usize, _> = (&*n2).try_into();
+            let res = lhs.get_num().checked_shr(rhs).unwrap_or(0);
+            Ok(Number::arena_from(res, arena))
+        }
+        Number::Integer(lhs) => {
+            // Note: bigints require `log(n)` bits of space. If `rhs > usize::MAX`,
+            // then this clamping only becomes an issue when `lhs < 2 ^ (usize::MAX)`:
+            // - on 32-bit systems, `lhs` would need to be 512MiB big (1/8th of the addressable memory)
+            // - on 64-bit systems, `lhs` would need to be 2EiB big (!!!)
+            let rhs = match rhs {
+                Number::Fixnum(fix) => fix.get_num().try_into().unwrap_or(usize::MAX),
+                Number::Integer(int) => (&*int).try_into().unwrap_or(usize::MAX),
+                other => {
+                    return Err(numerical_type_error(ValidType::Integer, other, stub_gen));
+                }
+            };
 
-            match result {
-                Ok(n2) => Ok(Number::arena_from(Integer::from(&*n1 >> n2), arena)),
-                Err(_) => Ok(Number::arena_from(Integer::from(&*n1 >> usize::MAX), arena)),
-            }
+            Ok(Number::arena_from(Integer::from(&*lhs >> rhs), arena))
         }
-        (Number::Integer(_), n2) => Err(numerical_type_error(ValidType::Integer, n2, stub_gen)),
-        (Number::Fixnum(_), n2) => Err(numerical_type_error(ValidType::Integer, n2, stub_gen)),
-        (n1, _) => Err(numerical_type_error(ValidType::Integer, n1, stub_gen)),
+        other => Err(numerical_type_error(ValidType::Integer, other, stub_gen)),
     }
 }
 
-pub(crate) fn shl(n1: Number, n2: Number, arena: &mut Arena) -> Result<Number, MachineStubGen> {
+pub(crate) fn shl(lhs: Number, rhs: Number, arena: &mut Arena) -> Result<Number, MachineStubGen> {
     let stub_gen = || {
         let shl_atom = atom!("<<");
         functor_stub(shl_atom, 2)
     };
 
-    if n2.is_integer() && n2.is_negative() {
-        return shr(n1, neg(n2, arena), arena);
+    if rhs.is_integer() && rhs.is_negative() {
+        return shr(lhs, neg(rhs, arena), arena);
     }
 
-    match (n1, n2) {
-        (Number::Fixnum(n1), Number::Fixnum(n2)) => {
-            let n1_i = n1.get_num();
-            let n2_i = n2.get_num();
+    let rhs = match rhs {
+        Number::Fixnum(fix) => fix.get_num().try_into().unwrap_or(usize::MAX),
+        Number::Integer(int) => (&*int).try_into().unwrap_or(usize::MAX),
+        other => {
+            return Err(numerical_type_error(ValidType::Integer, other, stub_gen));
+        }
+    };
+
+    match lhs {
+        Number::Fixnum(lhs) => {
+            let lhs = lhs.get_num();
 
-            if let Ok(n2) = usize::try_from(n2_i) {
-                Ok(Number::arena_from(n1_i << n2, arena))
+            if let Some(res) = checked_signed_shl(lhs, rhs) {
+                Ok(Number::arena_from(res, arena))
             } else {
-                let n1 = Integer::from(n1_i);
-                Ok(Number::arena_from(n1 << usize::MAX, arena))
+                let lhs = Integer::from(lhs);
+                Ok(Number::arena_from(
+                    Integer::from(lhs << (rhs as usize)),
+                    arena,
+                ))
             }
         }
-        (Number::Fixnum(n1), Number::Integer(n2)) => {
-            let n1 = Integer::from(n1.get_num());
+        Number::Integer(lhs) => Ok(Number::arena_from(
+            Integer::from(&*lhs << (rhs as usize)),
+            arena,
+        )),
+        other => Err(numerical_type_error(ValidType::Integer, other, stub_gen)),
+    }
+}
 
-            match (&*n2).try_into() as Result<usize, _> {
-                Ok(n2) => Ok(Number::arena_from(n1 << n2, arena)),
-                _ => Ok(Number::arena_from(n1 << usize::MAX, arena)),
-            }
+/// Returns `x << shift`, checking for overflow and for values of `shift` that are too big.
+#[inline]
+fn checked_signed_shl(x: i64, shift: usize) -> Option<i64> {
+    if shift == 0 {
+        return Some(x);
+    }
+
+    if x >= 0 {
+        // Note: for unsigned integers, the condition would usually be spelled
+        // `shift <= x.leading_zeros()`, but since the MSB for signed integers
+        // controls the sign, we need to make sure that `shift` is at most
+        // `x.leading_zeros() - 1`.
+        if shift < x.leading_zeros() as usize {
+            Some(x << shift)
+        } else {
+            None
         }
-        (Number::Integer(n1), Number::Fixnum(n2)) => match usize::try_from(n2.get_num()) {
-            Ok(n2) => Ok(Number::arena_from(Integer::from(&*n1 << n2), arena)),
-            _ => Ok(Number::arena_from(Integer::from(&*n1 << usize::MAX), arena)),
-        },
-        (Number::Integer(n1), Number::Integer(n2)) => match (&*n2).try_into() as Result<usize, _> {
-            Ok(n2) => Ok(Number::arena_from(Integer::from(&*n1 << n2), arena)),
-            _ => Ok(Number::arena_from(Integer::from(&*n1 << usize::MAX), arena)),
-        },
-        (Number::Integer(_), n2) => Err(numerical_type_error(ValidType::Integer, n2, stub_gen)),
-        (Number::Fixnum(_), n2) => Err(numerical_type_error(ValidType::Integer, n2, stub_gen)),
-        (n1, _) => Err(numerical_type_error(ValidType::Integer, n1, stub_gen)),
+    } else {
+        let y = x.checked_neg()?;
+        // FIXME: incorrectly rejects `-2 ^ 62 << 1`. This is currently a non-issue,
+        // since the bitshift is then done as a `Number::Integer`
+        checked_signed_shl(y, shift).and_then(|res| res.checked_neg())
     }
 }