]> Repositorios git - scryer-prolog.git/commitdiff
further ffi cleanup
authorBennet Bleßmann <[email protected]>
Mon, 17 Feb 2025 20:23:08 +0000 (21:23 +0100)
committerBennet Bleßmann <[email protected]>
Fri, 1 Aug 2025 17:10:38 +0000 (19:10 +0200)
src/ffi.rs

index 1629a7b45eb74586772134a5fdec72091736b806..e3421220ae402d5b1354dfd570c04826a4621285 100644 (file)
@@ -29,13 +29,13 @@ use libffi::middle::{Arg, Cif, CodePtr, Type};
 use libloading::{Library, Symbol};
 use ordered_float::OrderedFloat;
 use std::alloc::{self, Layout};
-use std::any::Any;
 use std::collections::HashMap;
 use std::error::Error;
 use std::ffi::{c_void, CString};
 use std::fmt::Debug;
 use std::marker::PhantomData;
 use std::ops::Deref;
+use std::ptr::NonNull;
 
 pub struct FunctionDefinition {
     pub name: String,
@@ -66,7 +66,7 @@ struct StructImpl {
 
 struct PointerArgs<'a, 'val> {
     memory: Vec<Arg>,
-    phantom: PhantomData<&'a mut ArgValues<'val>>,
+    phantom: PhantomData<&'a mut ArgValue<'val>>,
 }
 
 impl Deref for PointerArgs<'_, '_> {
@@ -77,7 +77,7 @@ impl Deref for PointerArgs<'_, '_> {
     }
 }
 
-enum ArgValues<'a> {
+enum ArgValue<'a> {
     U8(u8),
     I8(i8),
     U16(u16),
@@ -89,14 +89,14 @@ enum ArgValues<'a> {
     F32(f32),
     F64(f64),
     Ptr(*mut c_void, PhantomData<&'a CString>),
-    Struct(Box<dyn Any>),
+    Struct(FfiStruct),
 }
 
-impl<'val> ArgValues<'val> {
+impl<'val> ArgValue<'val> {
     fn new(
         val: &'val mut Value,
         arg_type: &Type,
-        structs_table: &mut HashMap<String, StructImpl>,
+        structs_table: &HashMap<String, StructImpl>,
     ) -> Result<Self, FFIError> {
         match (unsafe { *arg_type.as_raw_ptr() }).type_ as u32 {
             libffi::raw::FFI_TYPE_UINT8 => Ok(Self::U8(val.as_int()?)),
@@ -110,12 +110,49 @@ impl<'val> ArgValues<'val> {
             libffi::raw::FFI_TYPE_FLOAT => Ok(Self::F32(val.as_float()? as f32)),
             libffi::raw::FFI_TYPE_DOUBLE => Ok(Self::F64(val.as_float()?)),
             libffi::raw::FFI_TYPE_POINTER => Ok(Self::Ptr(val.as_ptr()?, PhantomData)),
-            libffi::raw::FFI_TYPE_STRUCT => Ok(Self::Struct(
-                ForeignFunctionTable::build_struct(val, structs_table)?.0,
-            )),
+            libffi::raw::FFI_TYPE_STRUCT => Ok(Self::Struct(ForeignFunctionTable::build_struct(
+                val,
+                structs_table,
+            )?)),
             _ => Err(FFIError::InvalidFFIType),
         }
     }
+
+    fn build_args(
+        args: &'val mut [Value],
+        types: &[Type],
+        structs_table: &HashMap<String, StructImpl>,
+    ) -> Result<Vec<Self>, FFIError> {
+        if types.len() != args.len() {
+            return Err(FFIError::ArgCountMismatch);
+        }
+
+        args.iter_mut()
+            .zip(types)
+            .map(|(arg, arg_type)| ArgValue::new(arg, arg_type, structs_table))
+            .collect::<Result<Vec<_>, _>>()
+    }
+}
+
+struct FfiStruct {
+    ptr: NonNull<c_void>,
+    layout: Layout,
+}
+
+impl FfiStruct {
+    fn new(layout: Layout) -> Result<Self, FFIError> {
+        if let Some(ptr) = NonNull::new(unsafe { alloc::alloc(layout) as *mut c_void }) {
+            Ok(FfiStruct { ptr, layout })
+        } else {
+            Err(FFIError::AllocationFailed)
+        }
+    }
+}
+
+impl Drop for FfiStruct {
+    fn drop(&mut self) {
+        unsafe { alloc::dealloc(self.ptr.as_ptr().cast(), self.layout) };
+    }
 }
 
 impl ForeignFunctionTable {
@@ -187,8 +224,7 @@ impl ForeignFunctionTable {
             let library = Library::new(library_name)?;
             for function in functions {
                 let symbol_name: CString = CString::new(function.name.clone())?;
-                let code_ptr: Symbol<*mut c_void> =
-                    library.get(&symbol_name.into_bytes_with_nul())?;
+                let code_ptr: Symbol<*mut c_void> = library.get(symbol_name.as_bytes_with_nul())?;
                 let args: Vec<_> = function
                     .args
                     .iter()
@@ -196,7 +232,7 @@ impl ForeignFunctionTable {
                     .collect::<Result<_, _>>()?;
                 let result = self.map_type_ffi(&function.return_value)?;
 
-                let cif = libffi::middle::Cif::new(args.clone(), result.clone());
+                let cif = libffi::middle::Cif::new(args.iter().cloned(), result.clone());
 
                 let return_struct_name =
                     if (*result.as_raw_ptr()).type_ as u32 == libffi::raw::FFI_TYPE_STRUCT {
@@ -221,25 +257,23 @@ impl ForeignFunctionTable {
         Ok(())
     }
 
-    fn build_pointer_args<'args, 'val>(args: &[ArgValues<'val>]) -> PointerArgs<'args, 'val> {
+    fn build_pointer_args<'args, 'val>(args: &[ArgValue<'val>]) -> PointerArgs<'args, 'val> {
         let args = args
             .iter()
             .map(|arg| match arg {
-                ArgValues::U8(a) => libffi::middle::arg(a),
-                ArgValues::I8(a) => libffi::middle::arg(a),
-                ArgValues::U16(a) => libffi::middle::arg(a),
-                ArgValues::I16(a) => libffi::middle::arg(a),
-                ArgValues::U32(a) => libffi::middle::arg(a),
-                ArgValues::I32(a) => libffi::middle::arg(a),
-                ArgValues::U64(a) => libffi::middle::arg(a),
-                ArgValues::I64(a) => libffi::middle::arg(a),
-                ArgValues::F32(a) => libffi::middle::arg(a),
-                ArgValues::F64(a) => libffi::middle::arg(a),
-                ArgValues::Ptr(ptr, _) => unsafe { std::mem::transmute::<*mut c_void, Arg>(*ptr) },
-                ArgValues::Struct(s) => unsafe {
-                    std::mem::transmute::<*const c_void, Arg>(
-                        s.as_ref() as *const _ as *const c_void
-                    )
+                ArgValue::U8(a) => libffi::middle::arg(a),
+                ArgValue::I8(a) => libffi::middle::arg(a),
+                ArgValue::U16(a) => libffi::middle::arg(a),
+                ArgValue::I16(a) => libffi::middle::arg(a),
+                ArgValue::U32(a) => libffi::middle::arg(a),
+                ArgValue::I32(a) => libffi::middle::arg(a),
+                ArgValue::U64(a) => libffi::middle::arg(a),
+                ArgValue::I64(a) => libffi::middle::arg(a),
+                ArgValue::F32(a) => libffi::middle::arg(a),
+                ArgValue::F64(a) => libffi::middle::arg(a),
+                ArgValue::Ptr(ptr, _) => unsafe { std::mem::transmute::<*mut c_void, Arg>(*ptr) },
+                ArgValue::Struct(s) => unsafe {
+                    std::mem::transmute::<*mut c_void, Arg>(s.ptr.as_ptr())
                 },
             })
             .collect();
@@ -252,94 +286,78 @@ impl ForeignFunctionTable {
 
     fn build_struct(
         arg: &mut Value,
-        structs_table: &mut HashMap<String, StructImpl>,
-    ) -> Result<(Box<dyn Any>, usize, usize), FFIError> {
-        match arg {
-            Value::Struct(ref name, ref mut struct_args) => {
-                if let Some(ref mut struct_type) = structs_table.clone().get_mut(name) {
-                    let ffi_type = unsafe { *struct_type.ffi_type.as_raw_ptr() };
-                    let layout =
-                        Layout::from_size_align(ffi_type.size, ffi_type.alignment.into()).unwrap();
-                    let align = ffi_type.alignment as usize;
-                    let size = ffi_type.size;
-                    let ptr = unsafe { alloc::alloc(layout) as *mut c_void };
-
-                    if ptr.is_null() {
-                        panic!("allocation failed")
-                    }
+        structs_table: &HashMap<String, StructImpl>,
+    ) -> Result<FfiStruct, FFIError> {
+        let Value::Struct(ref name, ref mut struct_args) = arg else {
+            return Err(FFIError::ValueCast);
+        };
 
-                    let mut field_ptr = ptr;
-
-                    #[allow(clippy::needless_range_loop)]
-                    for i in 0..(struct_type.fields.len() - 1) {
-                        macro_rules! try_write_int {
-                            ($type:ty) => {{
-                                field_ptr = field_ptr
-                                    .add(field_ptr.align_offset(std::mem::align_of::<$type>()));
-                                let n: $type = struct_args[i].as_int()?;
-                                std::ptr::write(field_ptr as *mut $type, n);
-                                field_ptr = field_ptr.add(std::mem::size_of::<$type>());
-                            }};
-                        }
-
-                        macro_rules! write {
-                            ($type:ty, $value:expr) => {{
-                                let data: $type = $value;
-                                std::ptr::write(field_ptr as *mut $type, data);
-                                field_ptr = field_ptr.add(align);
-                            }};
-                        }
-
-                        let field = &struct_type.fields[i];
-                        unsafe {
-                            #[allow(clippy::wildcard_in_or_patterns)]
-                            match (*field.as_raw_ptr()).type_ as u32 {
-                                libffi::raw::FFI_TYPE_UINT8 => try_write_int!(u8),
-                                libffi::raw::FFI_TYPE_SINT8 => try_write_int!(i8),
-                                libffi::raw::FFI_TYPE_UINT16 => try_write_int!(u16),
-                                libffi::raw::FFI_TYPE_SINT16 => try_write_int!(i16),
-                                libffi::raw::FFI_TYPE_UINT32 => try_write_int!(u32),
-                                libffi::raw::FFI_TYPE_SINT32 => try_write_int!(i32),
-                                libffi::raw::FFI_TYPE_UINT64 => try_write_int!(u64),
-                                libffi::raw::FFI_TYPE_SINT64 => try_write_int!(i64),
-                                libffi::raw::FFI_TYPE_POINTER => {
-                                    write!(*mut c_void, struct_args[i].as_ptr()?)
-                                }
-                                libffi::raw::FFI_TYPE_FLOAT => {
-                                    write!(f32, struct_args[i].as_float()? as f32)
-                                }
-                                libffi::raw::FFI_TYPE_DOUBLE => {
-                                    write!(f64, struct_args[i].as_float()?)
-                                }
-                                libffi::raw::FFI_TYPE_STRUCT => {
-                                    let (struct_ptr, struct_size, struct_align) =
-                                        Self::build_struct(&mut struct_args[i], structs_table)?;
-                                    field_ptr = field_ptr.add(field_ptr.align_offset(struct_align));
-
-                                    std::ptr::copy(
-                                        &*struct_ptr as *const _ as *const c_void,
-                                        field_ptr,
-                                        struct_size,
-                                    );
-                                    field_ptr = field_ptr.add(struct_size);
-                                }
-                                libffi::raw::FFI_TYPE_VOID
-                                | libffi::raw::FFI_TYPE_INT
-                                | libffi::raw::FFI_TYPE_LONGDOUBLE
-                                | libffi::raw::FFI_TYPE_COMPLEX
-                                | _ => return Err(FFIError::InvalidFFIType),
-                            }
-                        }
-                    }
+        let Some(struct_type) = structs_table.get(name) else {
+            return Err(FFIError::InvalidStructName);
+        };
+
+        let args = ArgValue::build_args(struct_args, &struct_type.fields, structs_table)?;
+
+        let ffi_type = unsafe { *struct_type.ffi_type.as_raw_ptr() };
+
+        let alloc = FfiStruct::new(
+            Layout::from_size_align(ffi_type.size, ffi_type.alignment.into()).unwrap(),
+        )?;
+
+        let Ok(mut current_layout) = Layout::from_size_align(0, 1) else {
+            return Err(FFIError::AllocationFailed);
+        };
+
+        unsafe fn write_primitive<T>(
+            ptr: NonNull<c_void>,
+            layout: &mut Layout,
+            val: T,
+        ) -> Result<(), FFIError> {
+            let (new_layout, offset) = layout
+                .extend(Layout::new::<T>())
+                .map_err(|_| FFIError::AllocationFailed)?;
+            *layout = new_layout;
+            ptr.byte_offset(offset as isize).cast::<T>().write(val);
+            Ok(())
+        }
 
-                    #[allow(clippy::from_raw_with_void_ptr)]
-                    Ok((unsafe { Box::from_raw(ptr) }, size, align))
-                } else {
-                    Err(FFIError::InvalidStructName)
+        for arg in args {
+            unsafe {
+                match arg {
+                    ArgValue::U8(i) => write_primitive(alloc.ptr, &mut current_layout, i)?,
+                    ArgValue::I8(i) => write_primitive(alloc.ptr, &mut current_layout, i)?,
+                    ArgValue::U16(i) => write_primitive(alloc.ptr, &mut current_layout, i)?,
+                    ArgValue::I16(i) => write_primitive(alloc.ptr, &mut current_layout, i)?,
+                    ArgValue::U32(i) => write_primitive(alloc.ptr, &mut current_layout, i)?,
+                    ArgValue::I32(i) => write_primitive(alloc.ptr, &mut current_layout, i)?,
+                    ArgValue::U64(i) => write_primitive(alloc.ptr, &mut current_layout, i)?,
+                    ArgValue::I64(i) => write_primitive(alloc.ptr, &mut current_layout, i)?,
+                    ArgValue::F32(f) => write_primitive(alloc.ptr, &mut current_layout, f)?,
+                    ArgValue::F64(f) => write_primitive(alloc.ptr, &mut current_layout, f)?,
+                    ArgValue::Ptr(p, _) => write_primitive(alloc.ptr, &mut current_layout, p)?,
+                    ArgValue::Struct(arg) => {
+                        let Ok((new_layout, offset)) = current_layout.extend(arg.layout) else {
+                            return Err(FFIError::AllocationFailed);
+                        };
+
+                        current_layout = new_layout;
+
+                        std::ptr::copy(
+                            arg.ptr.as_ptr(),
+                            alloc.ptr.byte_offset(offset as isize).as_ptr(),
+                            arg.layout.size(),
+                        );
+                    }
                 }
             }
-            _ => Err(FFIError::ValueCast),
         }
+
+        if alloc.layout != current_layout.pad_to_align() {
+            // sanity check
+            return Err(FFIError::AllocationFailed);
+        }
+
+        Ok(alloc)
     }
 
     pub fn exec(
@@ -348,16 +366,9 @@ impl ForeignFunctionTable {
         mut args: Vec<Value>,
         arena: &mut Arena,
     ) -> Result<Value, FFIError> {
-        let function_impl = self.table.get_mut(name).ok_or(FFIError::FunctionNotFound)?;
-        if function_impl.args.len() != args.len() {
-            return Err(FFIError::ArgCountMismatch);
-        }
+        let function_impl = self.table.get(name).ok_or(FFIError::FunctionNotFound)?;
 
-        let args = args
-            .iter_mut()
-            .zip(function_impl.args.iter())
-            .map(|(arg, arg_type)| ArgValues::new(arg, arg_type, &mut self.structs))
-            .collect::<Result<Vec<_>, _>>()?;
+        let args = ArgValue::build_args(&mut args, &function_impl.args, &self.structs)?;
 
         let args = Self::build_pointer_args(&args);
 
@@ -409,19 +420,17 @@ impl ForeignFunctionTable {
             libffi::raw::FFI_TYPE_FLOAT => unsafe { call_and_return_float!(f32) },
             libffi::raw::FFI_TYPE_DOUBLE => unsafe { call_and_return_float!(f64) },
             libffi::raw::FFI_TYPE_STRUCT => {
-                let name = &function_impl
+                let name = function_impl
                     .return_struct_name
-                    .clone()
+                    .as_ref()
                     .ok_or(FFIError::StructNotFound)?;
                 let struct_type = self.structs.get(name).ok_or(FFIError::StructNotFound)?;
                 let ffi_type = unsafe { *struct_type.ffi_type.as_raw_ptr() };
+
                 let layout =
                     Layout::from_size_align(ffi_type.size, ffi_type.alignment.into()).unwrap();
-                let ptr = unsafe { alloc::alloc(layout) };
 
-                if ptr.is_null() {
-                    return Err(FFIError::AllocationFailed);
-                }
+                let alloc = FfiStruct::new(layout)?;
 
                 let ptr_args: &[Arg] = &args;
 
@@ -429,13 +438,14 @@ impl ForeignFunctionTable {
                     libffi::raw::ffi_call(
                         function_impl.cif.as_raw_ptr(),
                         Some(*function_impl.code_ptr.as_safe_fun()),
-                        ptr as *mut c_void,
+                        alloc.ptr.as_ptr(),
                         ptr_args.as_ptr() as *mut *mut c_void,
                     )
                 };
-                let struct_val = self.read_struct(ptr as *mut c_void, name, struct_type, arena);
+                let struct_val = self.read_struct(alloc.ptr.as_ptr(), name, struct_type, arena);
+
+                drop(alloc);
 
-                unsafe { alloc::dealloc(ptr, layout) };
                 struct_val
             }
             _ => unreachable!(),