From cf51338a770d68661f3939e9cdb3161ec769e298 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Bennet=20Ble=C3=9Fmann?= Date: Mon, 17 Feb 2025 21:23:08 +0100 Subject: [PATCH] further ffi cleanup --- src/ffi.rs | 268 +++++++++++++++++++++++++++-------------------------- 1 file changed, 139 insertions(+), 129 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index 1629a7b4..e3421220 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -29,13 +29,13 @@ use libffi::middle::{Arg, Cif, CodePtr, Type}; use libloading::{Library, Symbol}; use ordered_float::OrderedFloat; use std::alloc::{self, Layout}; -use std::any::Any; use std::collections::HashMap; use std::error::Error; use std::ffi::{c_void, CString}; use std::fmt::Debug; use std::marker::PhantomData; use std::ops::Deref; +use std::ptr::NonNull; pub struct FunctionDefinition { pub name: String, @@ -66,7 +66,7 @@ struct StructImpl { struct PointerArgs<'a, 'val> { memory: Vec, - phantom: PhantomData<&'a mut ArgValues<'val>>, + phantom: PhantomData<&'a mut ArgValue<'val>>, } impl Deref for PointerArgs<'_, '_> { @@ -77,7 +77,7 @@ impl Deref for PointerArgs<'_, '_> { } } -enum ArgValues<'a> { +enum ArgValue<'a> { U8(u8), I8(i8), U16(u16), @@ -89,14 +89,14 @@ enum ArgValues<'a> { F32(f32), F64(f64), Ptr(*mut c_void, PhantomData<&'a CString>), - Struct(Box), + Struct(FfiStruct), } -impl<'val> ArgValues<'val> { +impl<'val> ArgValue<'val> { fn new( val: &'val mut Value, arg_type: &Type, - structs_table: &mut HashMap, + structs_table: &HashMap, ) -> Result { match (unsafe { *arg_type.as_raw_ptr() }).type_ as u32 { libffi::raw::FFI_TYPE_UINT8 => Ok(Self::U8(val.as_int()?)), @@ -110,12 +110,49 @@ impl<'val> ArgValues<'val> { libffi::raw::FFI_TYPE_FLOAT => Ok(Self::F32(val.as_float()? as f32)), libffi::raw::FFI_TYPE_DOUBLE => Ok(Self::F64(val.as_float()?)), libffi::raw::FFI_TYPE_POINTER => Ok(Self::Ptr(val.as_ptr()?, PhantomData)), - libffi::raw::FFI_TYPE_STRUCT => Ok(Self::Struct( - ForeignFunctionTable::build_struct(val, structs_table)?.0, - )), + libffi::raw::FFI_TYPE_STRUCT => Ok(Self::Struct(ForeignFunctionTable::build_struct( + val, + structs_table, + )?)), _ => Err(FFIError::InvalidFFIType), } } + + fn build_args( + args: &'val mut [Value], + types: &[Type], + structs_table: &HashMap, + ) -> Result, FFIError> { + if types.len() != args.len() { + return Err(FFIError::ArgCountMismatch); + } + + args.iter_mut() + .zip(types) + .map(|(arg, arg_type)| ArgValue::new(arg, arg_type, structs_table)) + .collect::, _>>() + } +} + +struct FfiStruct { + ptr: NonNull, + layout: Layout, +} + +impl FfiStruct { + fn new(layout: Layout) -> Result { + if let Some(ptr) = NonNull::new(unsafe { alloc::alloc(layout) as *mut c_void }) { + Ok(FfiStruct { ptr, layout }) + } else { + Err(FFIError::AllocationFailed) + } + } +} + +impl Drop for FfiStruct { + fn drop(&mut self) { + unsafe { alloc::dealloc(self.ptr.as_ptr().cast(), self.layout) }; + } } impl ForeignFunctionTable { @@ -187,8 +224,7 @@ impl ForeignFunctionTable { let library = Library::new(library_name)?; for function in functions { let symbol_name: CString = CString::new(function.name.clone())?; - let code_ptr: Symbol<*mut c_void> = - library.get(&symbol_name.into_bytes_with_nul())?; + let code_ptr: Symbol<*mut c_void> = library.get(symbol_name.as_bytes_with_nul())?; let args: Vec<_> = function .args .iter() @@ -196,7 +232,7 @@ impl ForeignFunctionTable { .collect::>()?; let result = self.map_type_ffi(&function.return_value)?; - let cif = libffi::middle::Cif::new(args.clone(), result.clone()); + let cif = libffi::middle::Cif::new(args.iter().cloned(), result.clone()); let return_struct_name = if (*result.as_raw_ptr()).type_ as u32 == libffi::raw::FFI_TYPE_STRUCT { @@ -221,25 +257,23 @@ impl ForeignFunctionTable { Ok(()) } - fn build_pointer_args<'args, 'val>(args: &[ArgValues<'val>]) -> PointerArgs<'args, 'val> { + fn build_pointer_args<'args, 'val>(args: &[ArgValue<'val>]) -> PointerArgs<'args, 'val> { let args = args .iter() .map(|arg| match arg { - ArgValues::U8(a) => libffi::middle::arg(a), - ArgValues::I8(a) => libffi::middle::arg(a), - ArgValues::U16(a) => libffi::middle::arg(a), - ArgValues::I16(a) => libffi::middle::arg(a), - ArgValues::U32(a) => libffi::middle::arg(a), - ArgValues::I32(a) => libffi::middle::arg(a), - ArgValues::U64(a) => libffi::middle::arg(a), - ArgValues::I64(a) => libffi::middle::arg(a), - ArgValues::F32(a) => libffi::middle::arg(a), - ArgValues::F64(a) => libffi::middle::arg(a), - ArgValues::Ptr(ptr, _) => unsafe { std::mem::transmute::<*mut c_void, Arg>(*ptr) }, - ArgValues::Struct(s) => unsafe { - std::mem::transmute::<*const c_void, Arg>( - s.as_ref() as *const _ as *const c_void - ) + ArgValue::U8(a) => libffi::middle::arg(a), + ArgValue::I8(a) => libffi::middle::arg(a), + ArgValue::U16(a) => libffi::middle::arg(a), + ArgValue::I16(a) => libffi::middle::arg(a), + ArgValue::U32(a) => libffi::middle::arg(a), + ArgValue::I32(a) => libffi::middle::arg(a), + ArgValue::U64(a) => libffi::middle::arg(a), + ArgValue::I64(a) => libffi::middle::arg(a), + ArgValue::F32(a) => libffi::middle::arg(a), + ArgValue::F64(a) => libffi::middle::arg(a), + ArgValue::Ptr(ptr, _) => unsafe { std::mem::transmute::<*mut c_void, Arg>(*ptr) }, + ArgValue::Struct(s) => unsafe { + std::mem::transmute::<*mut c_void, Arg>(s.ptr.as_ptr()) }, }) .collect(); @@ -252,94 +286,78 @@ impl ForeignFunctionTable { fn build_struct( arg: &mut Value, - structs_table: &mut HashMap, - ) -> Result<(Box, usize, usize), FFIError> { - match arg { - Value::Struct(ref name, ref mut struct_args) => { - if let Some(ref mut struct_type) = structs_table.clone().get_mut(name) { - let ffi_type = unsafe { *struct_type.ffi_type.as_raw_ptr() }; - let layout = - Layout::from_size_align(ffi_type.size, ffi_type.alignment.into()).unwrap(); - let align = ffi_type.alignment as usize; - let size = ffi_type.size; - let ptr = unsafe { alloc::alloc(layout) as *mut c_void }; - - if ptr.is_null() { - panic!("allocation failed") - } + structs_table: &HashMap, + ) -> Result { + let Value::Struct(ref name, ref mut struct_args) = arg else { + return Err(FFIError::ValueCast); + }; - let mut field_ptr = ptr; - - #[allow(clippy::needless_range_loop)] - for i in 0..(struct_type.fields.len() - 1) { - macro_rules! try_write_int { - ($type:ty) => {{ - field_ptr = field_ptr - .add(field_ptr.align_offset(std::mem::align_of::<$type>())); - let n: $type = struct_args[i].as_int()?; - std::ptr::write(field_ptr as *mut $type, n); - field_ptr = field_ptr.add(std::mem::size_of::<$type>()); - }}; - } - - macro_rules! write { - ($type:ty, $value:expr) => {{ - let data: $type = $value; - std::ptr::write(field_ptr as *mut $type, data); - field_ptr = field_ptr.add(align); - }}; - } - - let field = &struct_type.fields[i]; - unsafe { - #[allow(clippy::wildcard_in_or_patterns)] - match (*field.as_raw_ptr()).type_ as u32 { - libffi::raw::FFI_TYPE_UINT8 => try_write_int!(u8), - libffi::raw::FFI_TYPE_SINT8 => try_write_int!(i8), - libffi::raw::FFI_TYPE_UINT16 => try_write_int!(u16), - libffi::raw::FFI_TYPE_SINT16 => try_write_int!(i16), - libffi::raw::FFI_TYPE_UINT32 => try_write_int!(u32), - libffi::raw::FFI_TYPE_SINT32 => try_write_int!(i32), - libffi::raw::FFI_TYPE_UINT64 => try_write_int!(u64), - libffi::raw::FFI_TYPE_SINT64 => try_write_int!(i64), - libffi::raw::FFI_TYPE_POINTER => { - write!(*mut c_void, struct_args[i].as_ptr()?) - } - libffi::raw::FFI_TYPE_FLOAT => { - write!(f32, struct_args[i].as_float()? as f32) - } - libffi::raw::FFI_TYPE_DOUBLE => { - write!(f64, struct_args[i].as_float()?) - } - libffi::raw::FFI_TYPE_STRUCT => { - let (struct_ptr, struct_size, struct_align) = - Self::build_struct(&mut struct_args[i], structs_table)?; - field_ptr = field_ptr.add(field_ptr.align_offset(struct_align)); - - std::ptr::copy( - &*struct_ptr as *const _ as *const c_void, - field_ptr, - struct_size, - ); - field_ptr = field_ptr.add(struct_size); - } - libffi::raw::FFI_TYPE_VOID - | libffi::raw::FFI_TYPE_INT - | libffi::raw::FFI_TYPE_LONGDOUBLE - | libffi::raw::FFI_TYPE_COMPLEX - | _ => return Err(FFIError::InvalidFFIType), - } - } - } + let Some(struct_type) = structs_table.get(name) else { + return Err(FFIError::InvalidStructName); + }; + + let args = ArgValue::build_args(struct_args, &struct_type.fields, structs_table)?; + + let ffi_type = unsafe { *struct_type.ffi_type.as_raw_ptr() }; + + let alloc = FfiStruct::new( + Layout::from_size_align(ffi_type.size, ffi_type.alignment.into()).unwrap(), + )?; + + let Ok(mut current_layout) = Layout::from_size_align(0, 1) else { + return Err(FFIError::AllocationFailed); + }; + + unsafe fn write_primitive( + ptr: NonNull, + layout: &mut Layout, + val: T, + ) -> Result<(), FFIError> { + let (new_layout, offset) = layout + .extend(Layout::new::()) + .map_err(|_| FFIError::AllocationFailed)?; + *layout = new_layout; + ptr.byte_offset(offset as isize).cast::().write(val); + Ok(()) + } - #[allow(clippy::from_raw_with_void_ptr)] - Ok((unsafe { Box::from_raw(ptr) }, size, align)) - } else { - Err(FFIError::InvalidStructName) + for arg in args { + unsafe { + match arg { + ArgValue::U8(i) => write_primitive(alloc.ptr, &mut current_layout, i)?, + ArgValue::I8(i) => write_primitive(alloc.ptr, &mut current_layout, i)?, + ArgValue::U16(i) => write_primitive(alloc.ptr, &mut current_layout, i)?, + ArgValue::I16(i) => write_primitive(alloc.ptr, &mut current_layout, i)?, + ArgValue::U32(i) => write_primitive(alloc.ptr, &mut current_layout, i)?, + ArgValue::I32(i) => write_primitive(alloc.ptr, &mut current_layout, i)?, + ArgValue::U64(i) => write_primitive(alloc.ptr, &mut current_layout, i)?, + ArgValue::I64(i) => write_primitive(alloc.ptr, &mut current_layout, i)?, + ArgValue::F32(f) => write_primitive(alloc.ptr, &mut current_layout, f)?, + ArgValue::F64(f) => write_primitive(alloc.ptr, &mut current_layout, f)?, + ArgValue::Ptr(p, _) => write_primitive(alloc.ptr, &mut current_layout, p)?, + ArgValue::Struct(arg) => { + let Ok((new_layout, offset)) = current_layout.extend(arg.layout) else { + return Err(FFIError::AllocationFailed); + }; + + current_layout = new_layout; + + std::ptr::copy( + arg.ptr.as_ptr(), + alloc.ptr.byte_offset(offset as isize).as_ptr(), + arg.layout.size(), + ); + } } } - _ => Err(FFIError::ValueCast), } + + if alloc.layout != current_layout.pad_to_align() { + // sanity check + return Err(FFIError::AllocationFailed); + } + + Ok(alloc) } pub fn exec( @@ -348,16 +366,9 @@ impl ForeignFunctionTable { mut args: Vec, arena: &mut Arena, ) -> Result { - let function_impl = self.table.get_mut(name).ok_or(FFIError::FunctionNotFound)?; - if function_impl.args.len() != args.len() { - return Err(FFIError::ArgCountMismatch); - } + let function_impl = self.table.get(name).ok_or(FFIError::FunctionNotFound)?; - let args = args - .iter_mut() - .zip(function_impl.args.iter()) - .map(|(arg, arg_type)| ArgValues::new(arg, arg_type, &mut self.structs)) - .collect::, _>>()?; + let args = ArgValue::build_args(&mut args, &function_impl.args, &self.structs)?; let args = Self::build_pointer_args(&args); @@ -409,19 +420,17 @@ impl ForeignFunctionTable { libffi::raw::FFI_TYPE_FLOAT => unsafe { call_and_return_float!(f32) }, libffi::raw::FFI_TYPE_DOUBLE => unsafe { call_and_return_float!(f64) }, libffi::raw::FFI_TYPE_STRUCT => { - let name = &function_impl + let name = function_impl .return_struct_name - .clone() + .as_ref() .ok_or(FFIError::StructNotFound)?; let struct_type = self.structs.get(name).ok_or(FFIError::StructNotFound)?; let ffi_type = unsafe { *struct_type.ffi_type.as_raw_ptr() }; + let layout = Layout::from_size_align(ffi_type.size, ffi_type.alignment.into()).unwrap(); - let ptr = unsafe { alloc::alloc(layout) }; - if ptr.is_null() { - return Err(FFIError::AllocationFailed); - } + let alloc = FfiStruct::new(layout)?; let ptr_args: &[Arg] = &args; @@ -429,13 +438,14 @@ impl ForeignFunctionTable { libffi::raw::ffi_call( function_impl.cif.as_raw_ptr(), Some(*function_impl.code_ptr.as_safe_fun()), - ptr as *mut c_void, + alloc.ptr.as_ptr(), ptr_args.as_ptr() as *mut *mut c_void, ) }; - let struct_val = self.read_struct(ptr as *mut c_void, name, struct_type, arena); + let struct_val = self.read_struct(alloc.ptr.as_ptr(), name, struct_type, arena); + + drop(alloc); - unsafe { alloc::dealloc(ptr, layout) }; struct_val } _ => unreachable!(), -- 2.54.0