From a67076d7cb1bb3cd7eb23abf0c3d95c3906a0565 Mon Sep 17 00:00:00 2001 From: Ian Jackson Date: Thu, 14 Nov 2024 20:55:26 +0000 Subject: [PATCH] wip --- src/lib.rs | 85 +++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 72 insertions(+), 13 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 444317a..96abfe0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![allow(unused)] +use std::collections::HashSet; use std::fmt::{self, Debug}; use std::hint::unreachable_unchecked; use std::marker::PhantomData; @@ -45,6 +46,16 @@ pub unsafe trait IsMutToken<'a>: IsRefToken<'a> + Sized { l: (), } } + + fn multi_runtime<'r>(mut self) -> MultiRuntime<'r> + where 'a: 'r, + { + MultiRuntime { + tok: self.mut_token(), + ref_given: HashSet::new(), + mut_given: HashSet::new(), + } + } } unsafe impl<'a> IsRefToken<'a> for &'a NoAliasSingleton {} unsafe impl<'a> IsRefToken<'a> for &'a mut NoAliasSingleton {} @@ -233,39 +244,38 @@ pub struct MultiStatic<'a, L> { l: L, } -fn forbid_alias(this: &T, new: *const ()) -> Result<(), BorrowConflict> { - let this = this as *const T as *const (); - if this == new { return Err(BorrowConflict) } +fn forbid_alias(this: &T, new: NonNull<()>) -> Result<(), BorrowConflict> { + if NonNull::from(this) == new.cast() { return Err(BorrowConflict) } Ok(()) } pub unsafe trait MultiStaticList { - fn alias_check_ref(&self, p: *const ()) -> Result<(), BorrowConflict>; - fn alias_check_mut(&self, p: *const ()) -> Result<(), BorrowConflict>; + fn alias_check_ref(&self, p: NonNull<()>) -> Result<(), BorrowConflict>; + fn alias_check_mut(&self, p: NonNull<()>) -> Result<(), BorrowConflict>; } unsafe impl MultiStaticList for () { - fn alias_check_ref(&self, p: *const ()) -> Result<(), BorrowConflict> { + fn alias_check_ref(&self, p: NonNull<()>) -> Result<(), BorrowConflict> { Ok(()) } - fn alias_check_mut(&self, p: *const ()) -> Result<(), BorrowConflict> { + fn alias_check_mut(&self, p: NonNull<()>) -> Result<(), BorrowConflict> { Ok(()) } } unsafe impl MultiStaticList for (L, &T) { - fn alias_check_ref(&self, p: *const ()) -> Result<(), BorrowConflict> { + fn alias_check_ref(&self, p: NonNull<()>) -> Result<(), BorrowConflict> { self.0.alias_check_ref(p) } - fn alias_check_mut(&self, p: *const ()) -> Result<(), BorrowConflict> { + fn alias_check_mut(&self, p: NonNull<()>) -> Result<(), BorrowConflict> { forbid_alias(self.1, p)?; self.0.alias_check_mut(p) } } unsafe impl MultiStaticList for (L, &mut T) { - fn alias_check_ref(&self, p: *const ()) -> Result<(), BorrowConflict> { + fn alias_check_ref(&self, p: NonNull<()>) -> Result<(), BorrowConflict> { forbid_alias(&*self.1, p)?; self.0.alias_check_ref(p) } - fn alias_check_mut(&self, p: *const ()) -> Result<(), BorrowConflict> { + fn alias_check_mut(&self, p: NonNull<()>) -> Result<(), BorrowConflict> { forbid_alias(&*self.1, p)?; self.0.alias_check_mut(p) } @@ -278,7 +288,7 @@ impl<'a, L: MultiStaticList> MultiStatic<'a, L> { > where 'a: 'r { - match self.l.alias_check_ref(p.ptr.as_ptr() as *const ()) { + match self.l.alias_check_ref(p.ptr.cast()) { Err(e) => Err(self), Ok(()) => Ok(MultiStatic { tok: self.tok, @@ -293,7 +303,7 @@ impl<'a, L: MultiStaticList> MultiStatic<'a, L> { > where 'a: 'r { - match self.l.alias_check_mut(p.ptr.as_ptr() as *const ()) { + match self.l.alias_check_mut(p.ptr.cast()) { Err(e) => Err(self), Ok(()) => Ok(MultiStatic { tok: self.tok, @@ -307,6 +317,54 @@ impl<'a, L: MultiStaticList> MultiStatic<'a, L> { } } +pub struct MultiRuntime<'a> { + tok: MutToken<'a>, + ref_given: HashSet>, + mut_given: HashSet>, +} + +impl<'a> MultiRuntime<'a> { + #[inline] + pub fn borrow<'r, T>(&mut self, p: Ptr) + -> Result<&'r T, BorrowConflict> + where 'a: 'r + { + self.borrow_inner_check(p.ptr.cast())?; + Ok(unsafe { p.ptr.as_ref() }) + } + + fn borrow_inner_check(&mut self, p: NonNull<()>) + -> Result<(), BorrowConflict> + { + if self.mut_given.contains(&p) { + return Err(BorrowConflict) + } + self.ref_given.insert(p); + Ok(()) + } + + #[inline] + pub fn borrow_mut<'r, T>(&mut self, mut p: Ptr) + -> Result<&'r mut T, BorrowConflict> + where 'a: 'r + { + self.borrow_mut_inner_check(p.ptr.cast())?; + Ok(unsafe { p.ptr.as_mut() }) + } + + fn borrow_mut_inner_check(&mut self, p: NonNull<()>) + -> Result<(), BorrowConflict> +{ + if self.ref_given.contains(&p) { + return Err(BorrowConflict) + } + if !self.mut_given.insert(p) { + return Err(BorrowConflict) + } + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; @@ -478,3 +536,4 @@ mod tests { } + -- 2.30.2