refactor driver: make (Net)Driver Sync. may help avoid deadlock?

master
WangRunji 6 years ago
parent 1f2625e565
commit 9e6483f488

@ -91,8 +91,7 @@ pub extern fn rust_trap(tf: &mut TrapFrame) {
COM2 => com2(), COM2 => com2(),
IDE => ide(), IDE => ide(),
_ => { _ => {
let mut drivers = DRIVERS.lock(); for driver in DRIVERS.read().iter() {
for driver in drivers.iter_mut() {
if driver.try_handle_interrupt() == true { if driver.try_handle_interrupt() == true {
debug!("driver processed interrupt"); debug!("driver processed interrupt");
return; return;

@ -28,8 +28,7 @@ pub struct VirtIOBlk {
capacity: usize capacity: usize
} }
#[derive(Clone)] pub struct VirtIOBlkDriver(Mutex<VirtIOBlk>);
pub struct VirtIOBlkDriver(Arc<Mutex<VirtIOBlk>>);
#[repr(C)] #[repr(C)]
@ -92,7 +91,7 @@ bitflags! {
} }
impl Driver for VirtIOBlkDriver { impl Driver for VirtIOBlkDriver {
fn try_handle_interrupt(&mut self) -> bool { fn try_handle_interrupt(&self) -> bool {
let mut driver = self.0.lock(); let mut driver = self.0.lock();
// ensure header page is mapped // ensure header page is mapped
@ -167,15 +166,15 @@ pub fn virtio_blk_init(node: &Node) {
// configure two virtqueues: ingress and egress // configure two virtqueues: ingress and egress
header.guest_page_size.write(PAGE_SIZE as u32); // one page header.guest_page_size.write(PAGE_SIZE as u32); // one page
let mut driver = VirtIOBlkDriver(Arc::new(Mutex::new(VirtIOBlk { let mut driver = VirtIOBlkDriver(Mutex::new(VirtIOBlk {
interrupt: node.prop_u32("interrupts").unwrap(), interrupt: node.prop_u32("interrupts").unwrap(),
interrupt_parent: node.prop_u32("interrupt-parent").unwrap(), interrupt_parent: node.prop_u32("interrupt-parent").unwrap(),
header: from as usize, header: from as usize,
queue: VirtIOVirtqueue::new(header, 0, 16), queue: VirtIOVirtqueue::new(header, 0, 16),
capacity: config.capacity.read() as usize, capacity: config.capacity.read() as usize,
}))); }));
header.status.write(VirtIODeviceStatus::DRIVER_OK.bits()); header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
DRIVERS.lock().push(Box::new(driver)); DRIVERS.write().push(Arc::new(driver));
} }

@ -1,5 +1,6 @@
use alloc::alloc::{GlobalAlloc, Layout}; use alloc::alloc::{GlobalAlloc, Layout};
use alloc::prelude::*; use alloc::prelude::*;
use alloc::sync::Arc;
use core::slice; use core::slice;
use bitflags::*; use bitflags::*;
@ -14,6 +15,7 @@ use crate::arch::cpu;
use crate::HEAP_ALLOCATOR; use crate::HEAP_ALLOCATOR;
use crate::memory::active_table; use crate::memory::active_table;
use crate::arch::consts::{KERNEL_OFFSET, MEMORY_OFFSET}; use crate::arch::consts::{KERNEL_OFFSET, MEMORY_OFFSET};
use crate::sync::SpinNoIrqLock as Mutex;
use super::super::{DeviceType, Driver, DRIVERS}; use super::super::{DeviceType, Driver, DRIVERS};
use super::super::bus::virtio_mmio::*; use super::super::bus::virtio_mmio::*;
@ -24,7 +26,7 @@ const VIRTIO_GPU_EVENT_DISPLAY : u32 = 1 << 0;
struct VirtIOGpu { struct VirtIOGpu {
interrupt_parent: u32, interrupt_parent: u32,
interrupt: u32, interrupt: u32,
header: usize, header: &'static mut VirtIOHeader,
queue_buffer: [usize; 2], queue_buffer: [usize; 2],
frame_buffer: usize, frame_buffer: usize,
rect: VirtIOGpuRect, rect: VirtIOGpuRect,
@ -185,20 +187,25 @@ const VIRTIO_BUFFER_RECEIVE: usize = 1;
const VIRTIO_GPU_RESOURCE_ID: u32 = 0xbabe; const VIRTIO_GPU_RESOURCE_ID: u32 = 0xbabe;
impl Driver for VirtIOGpu { pub struct VirtIOGpuDriver(Mutex<VirtIOGpu>);
fn try_handle_interrupt(&mut self) -> bool {
impl Driver for VirtIOGpuDriver {
fn try_handle_interrupt(&self) -> bool {
// for simplicity // for simplicity
if cpu::id() > 0 { if cpu::id() > 0 {
return false return false
} }
let mut driver = self.0.lock();
// ensure header page is mapped // ensure header page is mapped
active_table().map_if_not_exists(self.header as usize, self.header as usize); // TODO: this should be mapped in all page table by default
let header_addr = &mut driver.header as *mut _ as usize;
active_table().map_if_not_exists(header_addr, header_addr);
let header = unsafe { &mut *(self.header as *mut VirtIOHeader) }; let interrupt = driver.header.interrupt_status.read();
let interrupt = header.interrupt_status.read();
if interrupt != 0 { if interrupt != 0 {
header.interrupt_ack.write(interrupt); driver.header.interrupt_ack.write(interrupt);
debug!("Got interrupt {:?}", interrupt); debug!("Got interrupt {:?}", interrupt);
return true; return true;
} }
@ -331,15 +338,18 @@ pub fn virtio_gpu_init(node: &Node) {
header.guest_page_size.write(PAGE_SIZE as u32); // one page header.guest_page_size.write(PAGE_SIZE as u32); // one page
let queue_num = 2; let queue_num = 2;
let queues = [
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num),
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_CURSOR, queue_num)
];
let mut driver = VirtIOGpu { let mut driver = VirtIOGpu {
interrupt: node.prop_u32("interrupts").unwrap(), interrupt: node.prop_u32("interrupts").unwrap(),
interrupt_parent: node.prop_u32("interrupt-parent").unwrap(), interrupt_parent: node.prop_u32("interrupt-parent").unwrap(),
header: from as usize, header,
queue_buffer: [0, 0], queue_buffer: [0, 0],
frame_buffer: 0, frame_buffer: 0,
rect: VirtIOGpuRect::default(), rect: VirtIOGpuRect::default(),
queues: [VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num), queues,
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_CURSOR, queue_num)]
}; };
for buffer in 0..2 { for buffer in 0..2 {
@ -351,9 +361,10 @@ pub fn virtio_gpu_init(node: &Node) {
debug!("buffer {} using page address {:#X}", buffer, page as usize); debug!("buffer {} using page address {:#X}", buffer, page as usize);
} }
header.status.write(VirtIODeviceStatus::DRIVER_OK.bits()); driver.header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
setup_framebuffer(&mut driver); setup_framebuffer(&mut driver);
DRIVERS.lock().push(Box::new(driver)); let driver = Arc::new(VirtIOGpuDriver(Mutex::new(driver)));
DRIVERS.write().push(driver);
} }

@ -1,5 +1,6 @@
use alloc::prelude::*; use alloc::prelude::*;
use alloc::vec; use alloc::vec;
use alloc::sync::Arc;
use core::fmt; use core::fmt;
use core::mem::size_of; use core::mem::size_of;
use core::mem::transmute_copy; use core::mem::transmute_copy;
@ -15,6 +16,7 @@ use volatile::Volatile;
use crate::arch::cpu; use crate::arch::cpu;
use crate::memory::active_table; use crate::memory::active_table;
use crate::sync::SpinNoIrqLock as Mutex;
use super::super::{DeviceType, Driver, DRIVERS}; use super::super::{DeviceType, Driver, DRIVERS};
use super::super::bus::virtio_mmio::*; use super::super::bus::virtio_mmio::*;
@ -22,7 +24,7 @@ use super::super::bus::virtio_mmio::*;
struct VirtIOInput { struct VirtIOInput {
interrupt_parent: u32, interrupt_parent: u32,
interrupt: u32, interrupt: u32,
header: usize, header: &'static mut VirtIOHeader,
// 0 for event, 1 for status // 0 for event, 1 for status
queues: [VirtIOVirtqueue; 2], queues: [VirtIOVirtqueue; 2],
x: isize, x: isize,
@ -121,7 +123,9 @@ bitflags! {
const VIRTIO_QUEUE_EVENT: usize = 0; const VIRTIO_QUEUE_EVENT: usize = 0;
const VIRTIO_QUEUE_STATUS: usize = 1; const VIRTIO_QUEUE_STATUS: usize = 1;
impl Driver for VirtIOInput { pub struct VirtIOInputDriver(Mutex<VirtIOInput>);
impl VirtIOInput {
fn try_handle_interrupt(&mut self) -> bool { fn try_handle_interrupt(&mut self) -> bool {
// for simplicity // for simplicity
if cpu::id() > 0 { if cpu::id() > 0 {
@ -129,14 +133,16 @@ impl Driver for VirtIOInput {
} }
// ensure header page is mapped // ensure header page is mapped
active_table().map_if_not_exists(self.header as usize, self.header as usize); // TODO: this should be mapped in all page table by default
let header = unsafe { &mut *(self.header as *mut VirtIOHeader) }; let header_addr = self.header as *mut _ as usize;
let interrupt = header.interrupt_status.read(); active_table().map_if_not_exists(header_addr, header_addr);
let interrupt = self.header.interrupt_status.read();
if interrupt != 0 { if interrupt != 0 {
header.interrupt_ack.write(interrupt); self.header.interrupt_ack.write(interrupt);
debug!("Got interrupt {:?}", interrupt); debug!("Got interrupt {:?}", interrupt);
loop { loop {
if let Some((input, output, _, _)) = self.queues[VIRTIO_QUEUE_EVENT].get() { if let Some((input, output, _, _)) = self.queues[VIRTIO_QUEUE_EVENT].get() {
let event: VirtIOInputEvent = unsafe { transmute_copy(&input[0][0]) }; let event: VirtIOInputEvent = unsafe { transmute_copy(&input[0][0]) };
if event.event_type == 2 && event.code == 0 { if event.event_type == 2 && event.code == 0 {
// X // X
@ -156,6 +162,12 @@ impl Driver for VirtIOInput {
} }
return false; return false;
} }
}
impl Driver for VirtIOInputDriver {
fn try_handle_interrupt(&self) -> bool {
self.0.lock().try_handle_interrupt()
}
fn device_type(&self) -> DeviceType { fn device_type(&self) -> DeviceType {
DeviceType::Input DeviceType::Input
@ -187,12 +199,15 @@ pub fn virtio_input_init(node: &Node) {
header.guest_page_size.write(PAGE_SIZE as u32); // one page header.guest_page_size.write(PAGE_SIZE as u32); // one page
let queue_num = 32; let queue_num = 32;
let queues = [
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_EVENT, queue_num),
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_STATUS, queue_num),
];
let mut driver = VirtIOInput { let mut driver = VirtIOInput {
interrupt: node.prop_u32("interrupts").unwrap(), interrupt: node.prop_u32("interrupts").unwrap(),
interrupt_parent: node.prop_u32("interrupt-parent").unwrap(), interrupt_parent: node.prop_u32("interrupt-parent").unwrap(),
header: from as usize, header,
queues: [VirtIOVirtqueue::new(header, VIRTIO_QUEUE_EVENT, queue_num), queues,
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_STATUS, queue_num)],
x: 0, x: 0,
y: 0 y: 0
}; };
@ -204,7 +219,8 @@ pub fn virtio_input_init(node: &Node) {
driver.queues[VIRTIO_QUEUE_EVENT].add(&[buffer], &[], 0); driver.queues[VIRTIO_QUEUE_EVENT].add(&[buffer], &[], 0);
} }
header.status.write(VirtIODeviceStatus::DRIVER_OK.bits()); driver.header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
DRIVERS.lock().push(Box::new(driver)); let driver = Arc::new(VirtIOInputDriver(Mutex::new(driver)));
DRIVERS.write().push(driver);
} }

@ -1,11 +1,12 @@
use alloc::prelude::*; use alloc::prelude::*;
use core::any::Any; use alloc::sync::Arc;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use smoltcp::wire::{EthernetAddress, Ipv4Address}; use smoltcp::wire::{EthernetAddress, Ipv4Address};
use smoltcp::socket::SocketSet; use smoltcp::socket::SocketSet;
use spin::RwLock;
use crate::sync::{ThreadLock, SpinLock, Condvar, MutexGuard, SpinNoIrq}; use crate::sync::{Condvar, MutexGuard, SpinNoIrq};
mod device_tree; mod device_tree;
pub mod bus; pub mod bus;
@ -21,16 +22,16 @@ pub enum DeviceType {
Block Block
} }
pub trait Driver : Send { pub trait Driver : Send + Sync {
// if interrupt belongs to this driver, handle it and return true // if interrupt belongs to this driver, handle it and return true
// return false otherwise // return false otherwise
fn try_handle_interrupt(&mut self) -> bool; fn try_handle_interrupt(&self) -> bool;
// return the correspondent device type, see DeviceType // return the correspondent device type, see DeviceType
fn device_type(&self) -> DeviceType; fn device_type(&self) -> DeviceType;
} }
pub trait NetDriver : Send { pub trait NetDriver : Driver {
// get mac address for this device // get mac address for this device
fn get_mac(&self) -> EthernetAddress; fn get_mac(&self) -> EthernetAddress;
@ -41,19 +42,21 @@ pub trait NetDriver : Send {
fn ipv4_address(&self) -> Option<Ipv4Address>; fn ipv4_address(&self) -> Option<Ipv4Address>;
// get sockets // get sockets
fn sockets(&mut self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq>; fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq>;
// manually trigger a poll, use it after sending packets // manually trigger a poll, use it after sending packets
fn poll(&mut self); fn poll(&self);
} }
lazy_static! { lazy_static! {
pub static ref DRIVERS: SpinLock<Vec<Box<Driver>>> = SpinLock::new(Vec::new()); // NOTE: RwLock only write when initializing drivers
pub static ref DRIVERS: RwLock<Vec<Arc<Driver>>> = RwLock::new(Vec::new());
} }
lazy_static! { lazy_static! {
pub static ref NET_DRIVERS: ThreadLock<Vec<Box<NetDriver>>> = ThreadLock::new(Vec::new()); // NOTE: RwLock only write when initializing drivers
pub static ref NET_DRIVERS: RwLock<Vec<Arc<NetDriver>>> = RwLock::new(Vec::new());
} }
lazy_static!{ lazy_static!{

@ -62,15 +62,14 @@ const E1000_MTA: usize = 0x5200 / 4;
const E1000_RAL: usize = 0x5400 / 4; const E1000_RAL: usize = 0x5400 / 4;
const E1000_RAH: usize = 0x5404 / 4; const E1000_RAH: usize = 0x5404 / 4;
#[derive(Clone)]
pub struct E1000Interface { pub struct E1000Interface {
iface: Arc<Mutex<EthernetInterface<'static, 'static, 'static, E1000Driver>>>, iface: Mutex<EthernetInterface<'static, 'static, 'static, E1000Driver>>,
driver: E1000Driver, driver: E1000Driver,
sockets: Arc<Mutex<SocketSet<'static, 'static, 'static>>>, sockets: Mutex<SocketSet<'static, 'static, 'static>>,
} }
impl Driver for E1000Interface { impl Driver for E1000Interface {
fn try_handle_interrupt(&mut self) -> bool { fn try_handle_interrupt(&self) -> bool {
let irq = { let irq = {
let driver = self.driver.0.lock(); let driver = self.driver.0.lock();
let mut current_addr = driver.header; let mut current_addr = driver.header;
@ -172,11 +171,11 @@ impl NetDriver for E1000Interface {
self.iface.lock().ipv4_address() self.iface.lock().ipv4_address()
} }
fn sockets(&mut self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> { fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
self.sockets.lock() self.sockets.lock()
} }
fn poll(&mut self) { fn poll(&self) {
let timestamp = Instant::from_millis(unsafe { let timestamp = Instant::from_millis(unsafe {
(crate::trap::TICK / crate::consts::USEC_PER_TICK / 1000) as i64 (crate::trap::TICK / crate::consts::USEC_PER_TICK / 1000) as i64
}); });
@ -485,11 +484,12 @@ pub fn e1000_init(header: usize, size: usize) {
.finalize(); .finalize();
let e1000_iface = E1000Interface { let e1000_iface = E1000Interface {
iface: Arc::new(Mutex::new(iface)), iface: Mutex::new(iface),
sockets: Arc::new(Mutex::new(SocketSet::new(vec![]))), sockets: Mutex::new(SocketSet::new(vec![])),
driver: net_driver.clone(), driver: net_driver.clone(),
}; };
DRIVERS.lock().push(Box::new(e1000_iface.clone())); let driver = Arc::new(e1000_iface);
NET_DRIVERS.lock().push(Box::new(e1000_iface)); DRIVERS.write().push(driver.clone());
NET_DRIVERS.write().push(driver);
} }

@ -42,7 +42,7 @@ const VIRTIO_QUEUE_RECEIVE: usize = 0;
const VIRTIO_QUEUE_TRANSMIT: usize = 1; const VIRTIO_QUEUE_TRANSMIT: usize = 1;
impl Driver for VirtIONetDriver { impl Driver for VirtIONetDriver {
fn try_handle_interrupt(&mut self) -> bool { fn try_handle_interrupt(&self) -> bool {
let driver = self.0.lock(); let driver = self.0.lock();
// ensure header page is mapped // ensure header page is mapped
@ -90,11 +90,11 @@ impl NetDriver for VirtIONetDriver {
unimplemented!() unimplemented!()
} }
fn sockets(&mut self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> { fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
unimplemented!() unimplemented!()
} }
fn poll(&mut self) { fn poll(&self) {
unimplemented!() unimplemented!()
} }
} }
@ -296,8 +296,8 @@ pub fn virtio_net_init(node: &Node) {
header.status.write(VirtIODeviceStatus::DRIVER_OK.bits()); header.status.write(VirtIODeviceStatus::DRIVER_OK.bits());
let net_driver = VirtIONetDriver(Arc::new(Mutex::new(driver))); let net_driver = Arc::new(VirtIONetDriver(Arc::new(Mutex::new(driver))));
DRIVERS.lock().push(Box::new(net_driver.clone())); DRIVERS.write().push(net_driver.clone());
NET_DRIVERS.lock().push(Box::new(net_driver)); NET_DRIVERS.write().push(net_driver);
} }

@ -6,7 +6,7 @@ use alloc::vec;
use core::fmt::Write; use core::fmt::Write;
pub extern fn server(_arg: usize) -> ! { pub extern fn server(_arg: usize) -> ! {
if NET_DRIVERS.lock().len() < 1 { if NET_DRIVERS.read().len() < 1 {
loop { loop {
thread::yield_now(); thread::yield_now();
} }
@ -24,7 +24,7 @@ pub extern fn server(_arg: usize) -> ! {
let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 1024]); let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 1024]);
let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer); let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer);
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let udp_handle = sockets.add(udp_socket); let udp_handle = sockets.add(udp_socket);
let tcp_handle = sockets.add(tcp_socket); let tcp_handle = sockets.add(tcp_socket);
@ -34,7 +34,7 @@ pub extern fn server(_arg: usize) -> ! {
loop { loop {
{ {
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
// udp server // udp server

@ -187,7 +187,7 @@ impl Thread {
info!("temporary copy data!"); info!("temporary copy data!");
let kstack = KernelStack::new(); let kstack = KernelStack::new();
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
for (_fd, file) in files.iter() { for (_fd, file) in files.iter() {
if let FileLike::Socket(wrapper) = file { if let FileLike::Socket(wrapper) = file {

@ -73,7 +73,7 @@ pub fn sys_poll(ufds: *mut PollFd, nfds: usize, timeout_msecs: usize) -> SysResu
}, },
Some(FileLike::Socket(wrapper)) => { Some(FileLike::Socket(wrapper)) => {
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle); let mut socket = sockets.get::<TcpSocket>(wrapper.handle);

@ -62,7 +62,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
domain, socket_type, protocol domain, socket_type, protocol
); );
let mut proc = process(); let mut proc = process();
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
match domain { match domain {
AF_INET => match socket_type { AF_INET => match socket_type {
SOCK_STREAM => { SOCK_STREAM => {
@ -196,8 +196,7 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult {
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
let mut drivers = NET_DRIVERS.lock(); let iface = &*(NET_DRIVERS.read()[0]);
let iface = &mut *(drivers[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle); let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
@ -208,13 +207,10 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult {
// avoid deadlock // avoid deadlock
drop(socket); drop(socket);
drop(sockets); drop(sockets);
drop(iface);
drop(drivers);
// wait for connection result // wait for connection result
loop { loop {
let mut drivers = NET_DRIVERS.lock(); let iface = &*(NET_DRIVERS.read()[0]);
let iface = &mut *(drivers[0]);
iface.poll(); iface.poll();
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
@ -223,8 +219,6 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult {
// still connecting // still connecting
drop(socket); drop(socket);
drop(sockets); drop(sockets);
drop(iface);
drop(drivers);
debug!("poll for connection wait"); debug!("poll for connection wait");
SOCKET_ACTIVITY._wait(); SOCKET_ACTIVITY._wait();
} else if socket.state() == TcpState::Established { } else if socket.state() == TcpState::Established {
@ -245,7 +239,7 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult {
} }
pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult { pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult {
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
@ -277,7 +271,7 @@ pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usi
} }
pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult { pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult {
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
loop { loop {
@ -359,7 +353,7 @@ pub fn sys_sendto(
proc.memory_set.check_ptr(addr)?; proc.memory_set.check_ptr(addr)?;
proc.memory_set.check_array(buffer, len)?; proc.memory_set.check_array(buffer, len)?;
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
if let SocketType::Raw = wrapper.socket_type { if let SocketType::Raw = wrapper.socket_type {
@ -471,7 +465,7 @@ pub fn sys_recvfrom(
proc.memory_set.check_mut_array(addr, max_addr_len)?; proc.memory_set.check_mut_array(addr, max_addr_len)?;
} }
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
// TODO: move some part of these into one generic function // TODO: move some part of these into one generic function
@ -524,7 +518,7 @@ pub fn sys_recvfrom(
} }
pub fn sys_close_socket(proc: &mut Process, fd: usize, handle: SocketHandle) -> SysResult { pub fn sys_close_socket(proc: &mut Process, fd: usize, handle: SocketHandle) -> SysResult {
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
sockets.release(handle); sockets.release(handle);
sockets.prune(); sockets.prune();
@ -553,7 +547,7 @@ pub fn sys_bind(fd: usize, addr: *const u8, len: usize) -> SysResult {
return Err(SysError::EINVAL); return Err(SysError::EINVAL);
} }
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = &mut proc.get_socket_mut(fd)?; let wrapper = &mut proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
wrapper.socket_type = SocketType::Tcp(Some(IpEndpoint::new(host.unwrap(), port))); wrapper.socket_type = SocketType::Tcp(Some(IpEndpoint::new(host.unwrap(), port)));
@ -572,7 +566,7 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult {
// open multiple sockets for each connection // open multiple sockets for each connection
let mut proc = process(); let mut proc = process();
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
@ -611,7 +605,7 @@ pub fn sys_accept(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResult {
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
loop { loop {
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle); let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
@ -652,7 +646,6 @@ pub fn sys_accept(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResult {
// avoid deadlock // avoid deadlock
drop(socket); drop(socket);
drop(sockets); drop(sockets);
drop(iface);
SOCKET_ACTIVITY._wait() SOCKET_ACTIVITY._wait()
} }
} else { } else {
@ -684,7 +677,7 @@ pub fn sys_getsockname(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResul
proc.memory_set.check_mut_array(addr, max_addr_len)?; proc.memory_set.check_mut_array(addr, max_addr_len)?;
let iface = &mut *(NET_DRIVERS.lock()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) }; let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) };

Loading…
Cancel
Save