diff --git a/crate/memory/src/cow.rs b/crate/memory/src/cow.rs new file mode 100644 index 0000000..a62d629 --- /dev/null +++ b/crate/memory/src/cow.rs @@ -0,0 +1,175 @@ +//! Shared memory & Copy-on-write extension for page table + +use super::paging::*; +use super::*; +use alloc::BTreeMap; +use core::ops::{Deref, DerefMut}; + +/// Wrapper for page table, supporting shared map & copy-on-write +struct CowExt<T: PageTable> { + page_table: T, + rc_map: FrameRcMap, +} + +impl<T: PageTable> CowExt<T> { + pub fn new(page_table: T) -> Self { + CowExt { + page_table, + rc_map: FrameRcMap::default(), + } + } + pub fn map_to_shared(&mut self, addr: VirtAddr, target: PhysAddr, writable: bool) { + let entry = self.page_table.map(addr, target); + entry.set_writable(false); + entry.set_shared(writable); + let frame = target / PAGE_SIZE; + match writable { + true => self.rc_map.write_increase(&frame), + false => self.rc_map.read_increase(&frame), + } + } + pub fn unmap_shared(&mut self, addr: VirtAddr) { + { + let entry = self.page_table.get_entry(addr); + let frame = entry.target() / PAGE_SIZE; + if entry.readonly_shared() { + self.rc_map.read_decrease(&frame); + } else if entry.writable_shared() { + self.rc_map.write_decrease(&frame); + } + } + self.page_table.unmap(addr); + } + /// This function must be called whenever PageFault happens. + /// Return whether copy-on-write happens. + pub fn page_fault_handler(&mut self, addr: VirtAddr, alloc_frame: impl FnOnce() -> PhysAddr) -> bool { + { + let entry = self.page_table.get_entry(addr); + if !entry.readonly_shared() && !entry.writable_shared() { + return false; + } + let frame = entry.target() / PAGE_SIZE; + if self.rc_map.read_count(&frame) == 0 && self.rc_map.write_count(&frame) == 1 { + entry.clear_shared(); + entry.set_writable(true); + self.rc_map.write_decrease(&frame); + return true; + } + } + use core::mem::uninitialized; + let mut temp_data: [u8; PAGE_SIZE] = unsafe { uninitialized() }; + self.read_page(addr, &mut temp_data[..]); + + self.unmap_shared(addr); + self.map(addr, alloc_frame()); + + self.write_page(addr, &temp_data[..]); + true + } +} + +impl<T: PageTable> Deref for CowExt<T> { + type Target = T; + + fn deref(&self) -> &<Self as Deref>::Target { + &self.page_table + } +} + +impl<T: PageTable> DerefMut for CowExt<T> { + fn deref_mut(&mut self) -> &mut <Self as Deref>::Target { + &mut self.page_table + } +} + +/// A map contains reference count for shared frame +#[derive(Default)] +struct FrameRcMap(BTreeMap<Frame, (u8, u8)>); + +type Frame = usize; + +impl FrameRcMap { + fn read_count(&mut self, frame: &Frame) -> u8 { + self.0.get(frame).unwrap_or(&(0, 0)).0 + } + fn write_count(&mut self, frame: &Frame) -> u8 { + self.0.get(frame).unwrap_or(&(0, 0)).1 + } + fn read_increase(&mut self, frame: &Frame) { + let (r, w) = self.0.get(&frame).unwrap_or(&(0, 0)).clone(); + self.0.insert(frame.clone(), (r + 1, w)); + } + fn read_decrease(&mut self, frame: &Frame) { + self.0.get_mut(frame).unwrap().0 -= 1; + } + fn write_increase(&mut self, frame: &Frame) { + let (r, w) = self.0.get(&frame).unwrap_or(&(0, 0)).clone(); + self.0.insert(frame.clone(), (r, w + 1)); + } + fn write_decrease(&mut self, frame: &Frame) { + self.0.get_mut(frame).unwrap().1 -= 1; + } +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::boxed::Box; + + #[test] + fn test() { + let mut pt = CowExt::new(MockPageTable::new()); + let pt0 = unsafe { &mut *(&mut pt as *mut CowExt<MockPageTable>) }; + + struct FrameAlloc(usize); + impl FrameAlloc { + fn alloc(&mut self) -> PhysAddr { + let pa = self.0 * PAGE_SIZE; + self.0 += 1; + pa + } + } + let mut alloc = FrameAlloc(4); + + pt.page_table.set_handler(Box::new(move |_, addr: VirtAddr| { + pt0.page_fault_handler(addr, || alloc.alloc()); + })); + let target = 0x0; + let frame = 0x0; + + pt.map(0x1000, target); + pt.write(0x1000, 1); + assert_eq!(pt.read(0x1000), 1); + pt.unmap(0x1000); + + pt.map_to_shared(0x1000, target, true); + pt.map_to_shared(0x2000, target, true); + pt.map_to_shared(0x3000, target, false); + assert_eq!(pt.rc_map.read_count(&frame), 1); + assert_eq!(pt.rc_map.write_count(&frame), 2); + assert_eq!(pt.read(0x1000), 1); + assert_eq!(pt.read(0x2000), 1); + assert_eq!(pt.read(0x3000), 1); + + pt.write(0x1000, 2); + assert_eq!(pt.rc_map.read_count(&frame), 1); + assert_eq!(pt.rc_map.write_count(&frame), 1); + assert_ne!(pt.get_entry(0x1000).target(), target); + assert_eq!(pt.read(0x1000), 2); + assert_eq!(pt.read(0x2000), 1); + assert_eq!(pt.read(0x3000), 1); + + pt.unmap_shared(0x3000); + assert_eq!(pt.rc_map.read_count(&frame), 0); + assert_eq!(pt.rc_map.write_count(&frame), 1); + assert!(!pt.get_entry(0x3000).present()); + + pt.write(0x2000, 3); + assert_eq!(pt.rc_map.read_count(&frame), 0); + assert_eq!(pt.rc_map.write_count(&frame), 0); + assert_eq!(pt.get_entry(0x2000).target(), target, + "The last write reference should not allocate new frame."); + assert_eq!(pt.read(0x1000), 2); + assert_eq!(pt.read(0x2000), 3); + } +} \ No newline at end of file diff --git a/crate/memory/src/lib.rs b/crate/memory/src/lib.rs index 40a15e0..e75ca48 100644 --- a/crate/memory/src/lib.rs +++ b/crate/memory/src/lib.rs @@ -4,6 +4,7 @@ extern crate alloc; pub mod paging; +pub mod cow; //pub mod swap; type VirtAddr = usize; diff --git a/crate/memory/src/paging/mock_page_table.rs b/crate/memory/src/paging/mock_page_table.rs index 5c65849..3f10135 100644 --- a/crate/memory/src/paging/mock_page_table.rs +++ b/crate/memory/src/paging/mock_page_table.rs @@ -17,6 +17,8 @@ pub struct MockEntry { writable: bool, accessed: bool, dirty: bool, + writable_shared: bool, + readonly_shared: bool, } impl Entry for MockEntry { @@ -29,6 +31,17 @@ impl Entry for MockEntry { fn set_writable(&mut self, value: bool) { self.writable = value; } fn set_present(&mut self, value: bool) { self.present = value; } fn target(&self) -> usize { self.target } + + fn writable_shared(&self) -> bool { self.writable_shared } + fn readonly_shared(&self) -> bool { self.readonly_shared } + fn set_shared(&mut self, writable: bool) { + self.writable_shared = writable; + self.readonly_shared = !writable; + } + fn clear_shared(&mut self) { + self.writable_shared = false; + self.readonly_shared = false; + } } type PageFaultHandler = Box<FnMut(&mut MockPageTable, VirtAddr)>; @@ -50,10 +63,19 @@ impl PageTable for MockPageTable { assert!(entry.present); entry.present = false; } - fn get_entry(&mut self, addr: VirtAddr) -> &mut <Self as PageTable>::Entry { &mut self.entries[addr / PAGE_SIZE] } + fn read_page(&mut self, addr: usize, data: &mut [u8]) { + self._read(addr); + let pa = self.translate(addr) & !(PAGE_SIZE - 1); + data.copy_from_slice(&self.data[pa..pa + PAGE_SIZE]); + } + fn write_page(&mut self, addr: usize, data: &[u8]) { + self._write(addr); + let pa = self.translate(addr) & !(PAGE_SIZE - 1); + self.data[pa..pa + PAGE_SIZE].copy_from_slice(data); + } } impl MockPageTable { @@ -78,29 +100,32 @@ impl MockPageTable { fn translate(&self, addr: VirtAddr) -> PhysAddr { let entry = &self.entries[addr / PAGE_SIZE]; assert!(entry.present); - (entry.target & !(PAGE_SIZE - 1)) | (addr & (PAGE_SIZE - 1)) - } - fn get_data_mut(&mut self, addr: VirtAddr) -> &mut u8 { - let pa = self.translate(addr); + let pa = (entry.target & !(PAGE_SIZE - 1)) | (addr & (PAGE_SIZE - 1)); assert!(pa < self.data.len(), "Physical memory access out of range"); - &mut self.data[pa] + pa } - /// Read memory, mark accessed, trigger page fault if not present - pub fn read(&mut self, addr: VirtAddr) -> u8 { + fn _read(&mut self, addr: VirtAddr) { while !self.entries[addr / PAGE_SIZE].present { self.trigger_page_fault(addr); } self.entries[addr / PAGE_SIZE].accessed = true; - *self.get_data_mut(addr) } - /// Write memory, mark accessed and dirty, trigger page fault if not present - pub fn write(&mut self, addr: VirtAddr, data: u8) { + fn _write(&mut self, addr: VirtAddr) { while !(self.entries[addr / PAGE_SIZE].present && self.entries[addr / PAGE_SIZE].writable) { self.trigger_page_fault(addr); } self.entries[addr / PAGE_SIZE].accessed = true; self.entries[addr / PAGE_SIZE].dirty = true; - *self.get_data_mut(addr) = data; + } + /// Read memory, mark accessed, trigger page fault if not present + pub fn read(&mut self, addr: VirtAddr) -> u8 { + self._read(addr); + self.data[self.translate(addr)] + } + /// Write memory, mark accessed and dirty, trigger page fault if not present + pub fn write(&mut self, addr: VirtAddr, data: u8) { + self._write(addr); + self.data[self.translate(addr)] = data; } } diff --git a/crate/memory/src/paging/mod.rs b/crate/memory/src/paging/mod.rs index f2099ec..03f5f65 100644 --- a/crate/memory/src/paging/mod.rs +++ b/crate/memory/src/paging/mod.rs @@ -8,6 +8,8 @@ pub trait PageTable { fn map(&mut self, addr: VirtAddr, target: PhysAddr) -> &mut Self::Entry; fn unmap(&mut self, addr: VirtAddr); fn get_entry(&mut self, addr: VirtAddr) -> &mut Self::Entry; + fn read_page(&mut self, addr: VirtAddr, data: &mut [u8]); + fn write_page(&mut self, addr: VirtAddr, data: &[u8]); } pub trait Entry { @@ -26,4 +28,10 @@ pub trait Entry { fn set_present(&mut self, value: bool); fn target(&self) -> PhysAddr; + + // For Copy-on-write extension + fn writable_shared(&self) -> bool; + fn readonly_shared(&self) -> bool; + fn set_shared(&mut self, writable: bool); + fn clear_shared(&mut self); }