refactor driver: make (Net)Driver Sync. may help avoid deadlock?

master
WangRunji 6 years ago
parent 1f2625e565
commit 9e6483f488

@ -91,8 +91,7 @@ pub extern fn rust_trap(tf: &mut TrapFrame) {
COM2 => com2(),
IDE => ide(),
_ => {
let mut drivers = DRIVERS.lock();
for driver in drivers.iter_mut() {
for driver in DRIVERS.read().iter() {
if driver.try_handle_interrupt() == true {
debug!("driver processed interrupt");
return;

@ -28,8 +28,7 @@ pub struct VirtIOBlk {
capacity: usize
}
#[derive(Clone)]
pub struct VirtIOBlkDriver(Arc<Mutex<VirtIOBlk>>);
pub struct VirtIOBlkDriver(Mutex<VirtIOBlk>);
#[repr(C)]
@ -92,7 +91,7 @@ bitflags! {
}
impl Driver for VirtIOBlkDriver {
fn try_handle_interrupt(&mut self) -> bool {
fn try_handle_interrupt(&self) -> bool {
let mut driver = self.0.lock();
// ensure header page is mapped
@ -167,15 +166,15 @@ pub fn virtio_blk_init(node: &Node) {
// configure two virtqueues: ingress and egress
header.guest_page_size.write(PAGE_SIZE as u32); // one page
let mut driver = VirtIOBlkDriver(Arc::new(Mutex::new(VirtIOBlk {
let mut driver = VirtIOBlkDriver(Mutex::new(VirtIOBlk {
interrupt: node.prop_u32("interrupts").unwrap(),
interrupt_parent: node.prop_u32("interrupt-parent").unwrap(),
header: from as usize,
queue: VirtIOVirtqueue::new(header, 0, 16),
capacity: config.capacity.read() as usize,
})));
}));
header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
DRIVERS.lock().push(Box::new(driver));
DRIVERS.write().push(Arc::new(driver));
}

@ -1,5 +1,6 @@
use alloc::alloc::{GlobalAlloc, Layout};
use alloc::prelude::*;
use alloc::sync::Arc;
use core::slice;
use bitflags::*;
@ -14,6 +15,7 @@ use crate::arch::cpu;
use crate::HEAP_ALLOCATOR;
use crate::memory::active_table;
use crate::arch::consts::{KERNEL_OFFSET, MEMORY_OFFSET};
use crate::sync::SpinNoIrqLock as Mutex;
use super::super::{DeviceType, Driver, DRIVERS};
use super::super::bus::virtio_mmio::*;
@ -24,7 +26,7 @@ const VIRTIO_GPU_EVENT_DISPLAY : u32 = 1 << 0;
struct VirtIOGpu {
interrupt_parent: u32,
interrupt: u32,
header: usize,
header: &'static mut VirtIOHeader,
queue_buffer: [usize; 2],
frame_buffer: usize,
rect: VirtIOGpuRect,
@ -185,20 +187,25 @@ const VIRTIO_BUFFER_RECEIVE: usize = 1;
const VIRTIO_GPU_RESOURCE_ID: u32 = 0xbabe;
impl Driver for VirtIOGpu {
fn try_handle_interrupt(&mut self) -> bool {
pub struct VirtIOGpuDriver(Mutex<VirtIOGpu>);
impl Driver for VirtIOGpuDriver {
fn try_handle_interrupt(&self) -> bool {
// for simplicity
if cpu::id() > 0 {
return false
}
let mut driver = self.0.lock();
// ensure header page is mapped
active_table().map_if_not_exists(self.header as usize, self.header as usize);
// TODO: this should be mapped in all page table by default
let header_addr = &mut driver.header as *mut _ as usize;
active_table().map_if_not_exists(header_addr, header_addr);
let header = unsafe { &mut *(self.header as *mut VirtIOHeader) };
let interrupt = header.interrupt_status.read();
let interrupt = driver.header.interrupt_status.read();
if interrupt != 0 {
header.interrupt_ack.write(interrupt);
driver.header.interrupt_ack.write(interrupt);
debug!("Got interrupt {:?}", interrupt);
return true;
}
@ -331,15 +338,18 @@ pub fn virtio_gpu_init(node: &Node) {
header.guest_page_size.write(PAGE_SIZE as u32); // one page
let queue_num = 2;
let queues = [
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num),
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_CURSOR, queue_num)
];
let mut driver = VirtIOGpu {
interrupt: node.prop_u32("interrupts").unwrap(),
interrupt_parent: node.prop_u32("interrupt-parent").unwrap(),
header: from as usize,
header,
queue_buffer: [0, 0],
frame_buffer: 0,
rect: VirtIOGpuRect::default(),
queues: [VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num),
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_CURSOR, queue_num)]
queues,
};
for buffer in 0..2 {
@ -351,9 +361,10 @@ pub fn virtio_gpu_init(node: &Node) {
debug!("buffer {} using page address {:#X}", buffer, page as usize);
}
header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
driver.header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
setup_framebuffer(&mut driver);
DRIVERS.lock().push(Box::new(driver));
let driver = Arc::new(VirtIOGpuDriver(Mutex::new(driver)));
DRIVERS.write().push(driver);
}

@ -1,5 +1,6 @@
use alloc::prelude::*;
use alloc::vec;
use alloc::sync::Arc;
use core::fmt;
use core::mem::size_of;
use core::mem::transmute_copy;
@ -15,6 +16,7 @@ use volatile::Volatile;
use crate::arch::cpu;
use crate::memory::active_table;
use crate::sync::SpinNoIrqLock as Mutex;
use super::super::{DeviceType, Driver, DRIVERS};
use super::super::bus::virtio_mmio::*;
@ -22,7 +24,7 @@ use super::super::bus::virtio_mmio::*;
struct VirtIOInput {
interrupt_parent: u32,
interrupt: u32,
header: usize,
header: &'static mut VirtIOHeader,
// 0 for event, 1 for status
queues: [VirtIOVirtqueue; 2],
x: isize,
@ -121,7 +123,9 @@ bitflags! {
const VIRTIO_QUEUE_EVENT: usize = 0;
const VIRTIO_QUEUE_STATUS: usize = 1;
impl Driver for VirtIOInput {
pub struct VirtIOInputDriver(Mutex<VirtIOInput>);
impl VirtIOInput {
fn try_handle_interrupt(&mut self) -> bool {
// for simplicity
if cpu::id() > 0 {
@ -129,14 +133,16 @@ impl Driver for VirtIOInput {
}
// ensure header page is mapped
active_table().map_if_not_exists(self.header as usize, self.header as usize);
let header = unsafe { &mut *(self.header as *mut VirtIOHeader) };
let interrupt = header.interrupt_status.read();
// TODO: this should be mapped in all page table by default
let header_addr = self.header as *mut _ as usize;
active_table().map_if_not_exists(header_addr, header_addr);
let interrupt = self.header.interrupt_status.read();
if interrupt != 0 {
header.interrupt_ack.write(interrupt);
self.header.interrupt_ack.write(interrupt);
debug!("Got interrupt {:?}", interrupt);
loop {
if let Some((input, output, _, _)) = self.queues[VIRTIO_QUEUE_EVENT].get() {
if let Some((input, output, _, _)) = self.queues[VIRTIO_QUEUE_EVENT].get() {
let event: VirtIOInputEvent = unsafe { transmute_copy(&input[0][0]) };
if event.event_type == 2 && event.code == 0 {
// X
@ -156,6 +162,12 @@ impl Driver for VirtIOInput {
}
return false;
}
}
impl Driver for VirtIOInputDriver {
fn try_handle_interrupt(&self) -> bool {
self.0.lock().try_handle_interrupt()
}
fn device_type(&self) -> DeviceType {
DeviceType::Input
@ -187,12 +199,15 @@ pub fn virtio_input_init(node: &Node) {
header.guest_page_size.write(PAGE_SIZE as u32); // one page
let queue_num = 32;
let queues = [
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_EVENT, queue_num),
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_STATUS, queue_num),
];
let mut driver = VirtIOInput {
interrupt: node.prop_u32("interrupts").unwrap(),
interrupt_parent: node.prop_u32("interrupt-parent").unwrap(),
header: from as usize,
queues: [VirtIOVirtqueue::new(header, VIRTIO_QUEUE_EVENT, queue_num),
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_STATUS, queue_num)],
header,
queues,
x: 0,
y: 0
};
@ -204,7 +219,8 @@ pub fn virtio_input_init(node: &Node) {
driver.queues[VIRTIO_QUEUE_EVENT].add(&[buffer], &[], 0);
}
header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
driver.header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
DRIVERS.lock().push(Box::new(driver));
let driver = Arc::new(VirtIOInputDriver(Mutex::new(driver)));
DRIVERS.write().push(driver);
}

@ -1,11 +1,12 @@
use alloc::prelude::*;
use core::any::Any;
use alloc::sync::Arc;
use lazy_static::lazy_static;
use smoltcp::wire::{EthernetAddress, Ipv4Address};
use smoltcp::socket::SocketSet;
use spin::RwLock;
use crate::sync::{ThreadLock, SpinLock, Condvar, MutexGuard, SpinNoIrq};
use crate::sync::{Condvar, MutexGuard, SpinNoIrq};
mod device_tree;
pub mod bus;
@ -21,16 +22,16 @@ pub enum DeviceType {
Block
}
pub trait Driver : Send {
pub trait Driver : Send + Sync {
// if interrupt belongs to this driver, handle it and return true
// return false otherwise
fn try_handle_interrupt(&mut self) -> bool;
fn try_handle_interrupt(&self) -> bool;
// return the correspondent device type, see DeviceType
fn device_type(&self) -> DeviceType;
}
pub trait NetDriver : Send {
pub trait NetDriver : Driver {
// get mac address for this device
fn get_mac(&self) -> EthernetAddress;
@ -41,19 +42,21 @@ pub trait NetDriver : Send {
fn ipv4_address(&self) -> Option<Ipv4Address>;
// get sockets
fn sockets(&mut self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq>;
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq>;
// manually trigger a poll, use it after sending packets
fn poll(&mut self);
fn poll(&self);
}
lazy_static! {
pub static ref DRIVERS: SpinLock<Vec<Box<Driver>>> = SpinLock::new(Vec::new());
// NOTE: RwLock only write when initializing drivers
pub static ref DRIVERS: RwLock<Vec<Arc<Driver>>> = RwLock::new(Vec::new());
}
lazy_static! {
pub static ref NET_DRIVERS: ThreadLock<Vec<Box<NetDriver>>> = ThreadLock::new(Vec::new());
// NOTE: RwLock only write when initializing drivers
pub static ref NET_DRIVERS: RwLock<Vec<Arc<NetDriver>>> = RwLock::new(Vec::new());
}
lazy_static!{

@ -62,15 +62,14 @@ const E1000_MTA: usize = 0x5200 / 4;
const E1000_RAL: usize = 0x5400 / 4;
const E1000_RAH: usize = 0x5404 / 4;
#[derive(Clone)]
pub struct E1000Interface {
iface: Arc<Mutex<EthernetInterface<'static, 'static, 'static, E1000Driver>>>,
iface: Mutex<EthernetInterface<'static, 'static, 'static, E1000Driver>>,
driver: E1000Driver,
sockets: Arc<Mutex<SocketSet<'static, 'static, 'static>>>,
sockets: Mutex<SocketSet<'static, 'static, 'static>>,
}
impl Driver for E1000Interface {
fn try_handle_interrupt(&mut self) -> bool {
fn try_handle_interrupt(&self) -> bool {
let irq = {
let driver = self.driver.0.lock();
let mut current_addr = driver.header;
@ -172,11 +171,11 @@ impl NetDriver for E1000Interface {
self.iface.lock().ipv4_address()
}
fn sockets(&mut self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
self.sockets.lock()
}
fn poll(&mut self) {
fn poll(&self) {
let timestamp = Instant::from_millis(unsafe {
(crate::trap::TICK / crate::consts::USEC_PER_TICK / 1000) as i64
});
@ -485,11 +484,12 @@ pub fn e1000_init(header: usize, size: usize) {
.finalize();
let e1000_iface = E1000Interface {
iface: Arc::new(Mutex::new(iface)),
sockets: Arc::new(Mutex::new(SocketSet::new(vec![]))),
iface: Mutex::new(iface),
sockets: Mutex::new(SocketSet::new(vec![])),
driver: net_driver.clone(),
};
DRIVERS.lock().push(Box::new(e1000_iface.clone()));
NET_DRIVERS.lock().push(Box::new(e1000_iface));
let driver = Arc::new(e1000_iface);
DRIVERS.write().push(driver.clone());
NET_DRIVERS.write().push(driver);
}

@ -42,7 +42,7 @@ const VIRTIO_QUEUE_RECEIVE: usize = 0;
const VIRTIO_QUEUE_TRANSMIT: usize = 1;
impl Driver for VirtIONetDriver {
fn try_handle_interrupt(&mut self) -> bool {
fn try_handle_interrupt(&self) -> bool {
let driver = self.0.lock();
// ensure header page is mapped
@ -90,11 +90,11 @@ impl NetDriver for VirtIONetDriver {
unimplemented!()
}
fn sockets(&mut self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
unimplemented!()
}
fn poll(&mut self) {
fn poll(&self) {
unimplemented!()
}
}
@ -296,8 +296,8 @@ pub fn virtio_net_init(node: &Node) {
header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
let net_driver = VirtIONetDriver(Arc::new(Mutex::new(driver)));
let net_driver = Arc::new(VirtIONetDriver(Arc::new(Mutex::new(driver))));
DRIVERS.lock().push(Box::new(net_driver.clone()));
NET_DRIVERS.lock().push(Box::new(net_driver));
DRIVERS.write().push(net_driver.clone());
NET_DRIVERS.write().push(net_driver);
}

@ -6,7 +6,7 @@ use alloc::vec;
use core::fmt::Write;
pub extern fn server(_arg: usize) -> ! {
if NET_DRIVERS.lock().len() < 1 {
if NET_DRIVERS.read().len() < 1 {
loop {
thread::yield_now();
}
@ -24,7 +24,7 @@ pub extern fn server(_arg: usize) -> ! {
let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 1024]);
let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer);
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let udp_handle = sockets.add(udp_socket);
let tcp_handle = sockets.add(tcp_socket);
@ -34,7 +34,7 @@ pub extern fn server(_arg: usize) -> ! {
loop {
{
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
// udp server

@ -187,7 +187,7 @@ impl Thread {
info!("temporary copy data!");
let kstack = KernelStack::new();
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
for (_fd, file) in files.iter() {
if let FileLike::Socket(wrapper) = file {

@ -73,7 +73,7 @@ pub fn sys_poll(ufds: *mut PollFd, nfds: usize, timeout_msecs: usize) -> SysResu
},
Some(FileLike::Socket(wrapper)) => {
if let SocketType::Tcp(_) = wrapper.socket_type {
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);

@ -62,7 +62,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
domain, socket_type, protocol
);
let mut proc = process();
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
match domain {
AF_INET => match socket_type {
SOCK_STREAM => {
@ -196,8 +196,7 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult {
let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
let mut drivers = NET_DRIVERS.lock();
let iface = &mut *(drivers[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
@ -208,13 +207,10 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult {
// avoid deadlock
drop(socket);
drop(sockets);
drop(iface);
drop(drivers);
// wait for connection result
loop {
let mut drivers = NET_DRIVERS.lock();
let iface = &mut *(drivers[0]);
let iface = &*(NET_DRIVERS.read()[0]);
iface.poll();
let mut sockets = iface.sockets();
@ -223,8 +219,6 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult {
// still connecting
drop(socket);
drop(sockets);
drop(iface);
drop(drivers);
debug!("poll for connection wait");
SOCKET_ACTIVITY._wait();
} else if socket.state() == TcpState::Established {
@ -245,7 +239,7 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult {
}
pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult {
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets();
@ -277,7 +271,7 @@ pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usi
}
pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult {
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
loop {
@ -359,7 +353,7 @@ pub fn sys_sendto(
proc.memory_set.check_ptr(addr)?;
proc.memory_set.check_array(buffer, len)?;
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?;
if let SocketType::Raw = wrapper.socket_type {
@ -471,7 +465,7 @@ pub fn sys_recvfrom(
proc.memory_set.check_mut_array(addr, max_addr_len)?;
}
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?;
// TODO: move some part of these into one generic function
@ -524,7 +518,7 @@ pub fn sys_recvfrom(
}
pub fn sys_close_socket(proc: &mut Process, fd: usize, handle: SocketHandle) -> SysResult {
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
sockets.release(handle);
sockets.prune();
@ -553,7 +547,7 @@ pub fn sys_bind(fd: usize, addr: *const u8, len: usize) -> SysResult {
return Err(SysError::EINVAL);
}
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = &mut proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
wrapper.socket_type = SocketType::Tcp(Some(IpEndpoint::new(host.unwrap(), port)));
@ -572,7 +566,7 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult {
// open multiple sockets for each connection
let mut proc = process();
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
let mut sockets = iface.sockets();
@ -611,7 +605,7 @@ pub fn sys_accept(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResult {
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
loop {
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
@ -652,7 +646,6 @@ pub fn sys_accept(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResult {
// avoid deadlock
drop(socket);
drop(sockets);
drop(iface);
SOCKET_ACTIVITY._wait()
}
} else {
@ -684,7 +677,7 @@ pub fn sys_getsockname(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResul
proc.memory_set.check_mut_array(addr, max_addr_len)?;
let iface = &mut *(NET_DRIVERS.lock()[0]);
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) };

Loading…
Cancel
Save