From 1a1e39c9608ffbf5619f9523a305a6426c212c32 Mon Sep 17 00:00:00 2001 From: Jiajie Chen Date: Mon, 4 Mar 2019 15:34:02 +0800 Subject: [PATCH] Move socket set to iface, redesign NetDriver trait and implement blocking net syscalls --- kernel/Makefile | 2 +- kernel/src/arch/x86_64/interrupt/handler.rs | 12 ++- kernel/src/drivers/bus/pci.rs | 65 +++++++++--- kernel/src/drivers/mod.rs | 13 ++- kernel/src/drivers/net/e1000.rs | 112 ++++++++++++++------ kernel/src/drivers/net/virtio_net.rs | 9 +- kernel/src/net/test.rs | 18 ++-- kernel/src/process/structs.rs | 9 +- kernel/src/shell.rs | 4 +- kernel/src/syscall/net.rs | 100 ++++++++++------- 10 files changed, 227 insertions(+), 117 deletions(-) diff --git a/kernel/Makefile b/kernel/Makefile index 685f362..5c26af8 100644 --- a/kernel/Makefile +++ b/kernel/Makefile @@ -71,7 +71,7 @@ qemu_opts += \ -serial mon:stdio \ -device isa-debug-exit qemu_net_opts += \ - -device e1000,netdev=net0 + -device e1000e,netdev=net0 else ifeq ($(arch), riscv32) qemu_opts += \ diff --git a/kernel/src/arch/x86_64/interrupt/handler.rs b/kernel/src/arch/x86_64/interrupt/handler.rs index 22406e4..ca554b3 100644 --- a/kernel/src/arch/x86_64/interrupt/handler.rs +++ b/kernel/src/arch/x86_64/interrupt/handler.rs @@ -67,6 +67,7 @@ use super::consts::*; use super::TrapFrame; use log::*; +use crate::drivers::DRIVERS; global_asm!(include_str!("trap.asm")); global_asm!(include_str!("vector.asm")); @@ -89,7 +90,16 @@ pub extern fn rust_trap(tf: &mut TrapFrame) { COM1 => com1(), COM2 => com2(), IDE => ide(), - _ => panic!("Invalid IRQ number: {}", irq), + _ => { + let mut drivers = DRIVERS.lock(); + for driver in drivers.iter_mut() { + if driver.try_handle_interrupt() == true { + debug!("driver processed interrupt"); + return; + } + } + warn!("unhandled external IRQ number: {}", irq); + }, } } SwitchToKernel => to_kernel(tf), diff --git a/kernel/src/drivers/bus/pci.rs b/kernel/src/drivers/bus/pci.rs index f6fd456..fa13db6 100644 --- a/kernel/src/drivers/bus/pci.rs +++ b/kernel/src/drivers/bus/pci.rs @@ -1,14 +1,25 @@ use crate::drivers::net::e1000; use x86_64::instructions::port::Port; -const VENDOR: u32 = 0x00; -const DEVICE: u32 = 0x02; -const COMMAND: u32 = 0x04; -const STATUS: u32 = 0x06; -const SUBCLASS: u32 = 0x0a; -const CLASS: u32 = 0x0b; -const HEADER: u32 = 0x0e; -const BAR0: u32 = 0x10; +const PCI_VENDOR: u32 = 0x00; +const PCI_DEVICE: u32 = 0x02; +const PCI_COMMAND: u32 = 0x04; +const PCI_STATUS: u32 = 0x06; +const PCI_SUBCLASS: u32 = 0x0a; +const PCI_CLASS: u32 = 0x0b; +const PCI_HEADER: u32 = 0x0e; +const PCI_BAR0: u32 = 0x10; // first +const PCI_BAR5: u32 = 0x24; // last +const PCI_CAP_PTR: u32 = 0x34; +const PCI_INTERRUPT_LINE: u32 = 0x3c; +const PCI_INTERRUPT_PIN: u32 = 0x3d; + +const PCI_MSI_CTRL_CAP: u32 = 0x00; +const PCI_MSI_ADDR: u32 = 0x04; +const PCI_MSI_UPPER_ADDR: u32 = 0x08; +const PCI_MSI_DATA: u32 = 0x0C; + +const PCI_CAP_ID_MSI: u32 = 0x05; const PCI_ADDR_PORT: u16 = 0xcf8; const PCI_DATA_PORT: u16 = 0xcfc; @@ -93,7 +104,7 @@ impl PciTag { // return (addr, len) pub unsafe fn get_bar_mem(&self, bar_number: u32) -> Option<(usize, usize)> { assert!(bar_number <= 4); - let bar = BAR0 + 4 * bar_number; + let bar = PCI_BAR0 + 4 * bar_number; let mut base = self.read(bar, 4); self.write(bar, 0xffffffff); let mut max_base = self.read(bar, 4); @@ -129,14 +140,14 @@ impl PciTag { // returns a tuple of (vid, did, next) pub fn probe(&self) -> Option<(u32, u32, bool)> { unsafe { - let v = self.read(VENDOR, 2); + let v = self.read(PCI_VENDOR, 2); if v == 0xffff { return None; } - let d = self.read(DEVICE, 2); - let mf = self.read(HEADER, 1); - let cl = self.read(CLASS, 1); - let scl = self.read(SUBCLASS, 1); + let d = self.read(PCI_DEVICE, 2); + let mf = self.read(PCI_HEADER, 1); + let cl = self.read(PCI_CLASS, 1); + let scl = self.read(PCI_SUBCLASS, 1); info!( "{}: {}: {}: {:#X} {:#X} ({} {})", self.bus(), @@ -153,9 +164,29 @@ impl PciTag { } pub unsafe fn enable(&self) { - let orig = self.read(COMMAND, 2); - // IO_ENABLE | MEM_ENABLE | MASTER_ENABLE - self.write(COMMAND, orig | 0xf); + let orig = self.read(PCI_COMMAND, 2); + // IO Space | MEM Space | Bus Mastering | Special Cycles | PCI Interrupt Disable + self.write(PCI_COMMAND, orig | 0x40f); + + // find MSI cap + let mut cap_ptr = self.read(PCI_CAP_PTR, 1); + while cap_ptr > 0 { + let cap_id = self.read(cap_ptr, 1); + if cap_id == PCI_CAP_ID_MSI { + self.write(cap_ptr + PCI_MSI_ADDR, 0xfee << 20); + // irq 23 temporarily + self.write(cap_ptr + PCI_MSI_DATA, 55 | 0 << 12); + + let orig_ctrl = self.read(cap_ptr + PCI_MSI_CTRL_CAP, 4); + debug!("orig ctrl {:b}", orig_ctrl); + self.write(cap_ptr + PCI_MSI_CTRL_CAP, orig_ctrl | 0x10000); + debug!("new ctrl {:b}", self.read(cap_ptr + PCI_MSI_CTRL_CAP, 2)); + break; + } + info!("cap id {} at {:#X}", self.read(cap_ptr, 1), cap_ptr); + cap_ptr = self.read(cap_ptr + 1, 1); + } + } } diff --git a/kernel/src/drivers/mod.rs b/kernel/src/drivers/mod.rs index b2dd301..a30f0ca 100644 --- a/kernel/src/drivers/mod.rs +++ b/kernel/src/drivers/mod.rs @@ -5,7 +5,7 @@ use lazy_static::lazy_static; use smoltcp::wire::{EthernetAddress, Ipv4Address}; use smoltcp::socket::SocketSet; -use crate::sync::SpinNoIrqLock; +use crate::sync::{SpinNoIrqLock, Condvar, MutexGuard, SpinNoIrq}; mod device_tree; pub mod bus; @@ -40,8 +40,11 @@ pub trait NetDriver : Send { // get ipv4 address fn ipv4_address(&self) -> Option; - // poll for sockets - fn poll(&mut self, socket: &mut SocketSet) -> Option; + // get sockets + fn sockets(&mut self) -> MutexGuard, SpinNoIrq>; + + // manually trigger a poll, use it after sending packets + fn poll(&mut self); } @@ -53,6 +56,10 @@ lazy_static! { pub static ref NET_DRIVERS: SpinNoIrqLock>> = SpinNoIrqLock::new(Vec::new()); } +lazy_static! { + pub static ref SOCKET_ACTIVITY: Condvar = Condvar::new(); +} + #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] pub fn init(dtb: usize) { device_tree::init(dtb); diff --git a/kernel/src/drivers/net/e1000.rs b/kernel/src/drivers/net/e1000.rs index cf61b64..236df30 100644 --- a/kernel/src/drivers/net/e1000.rs +++ b/kernel/src/drivers/net/e1000.rs @@ -2,7 +2,7 @@ use alloc::alloc::{GlobalAlloc, Layout}; use alloc::format; use alloc::prelude::*; use alloc::sync::Arc; -use core::mem::size_of; +use core::mem::{size_of, transmute}; use core::slice; use core::sync::atomic::{fence, Ordering}; @@ -22,9 +22,10 @@ use volatile::{Volatile}; use crate::memory::active_table; use crate::sync::SpinNoIrqLock as Mutex; +use crate::sync::{MutexGuard, SpinNoIrq}; use crate::HEAP_ALLOCATOR; -use super::super::{DeviceType, Driver, NetDriver, NET_DRIVERS}; +use super::super::{DeviceType, Driver, NetDriver, DRIVERS, NET_DRIVERS, SOCKET_ACTIVITY}; pub struct E1000 { header: usize, @@ -37,8 +38,13 @@ pub struct E1000 { first_trans: bool, } +#[derive(Clone)] +pub struct E1000Driver(Arc>); + const E1000_STATUS: usize = 0x0008 / 4; +const E1000_ICR: usize = 0x00C0 / 4; const E1000_IMS: usize = 0x00D0 / 4; +const E1000_IMC: usize = 0x00D8 / 4; const E1000_RCTL: usize = 0x0100 / 4; const E1000_TCTL: usize = 0x0400 / 4; const E1000_TIPG: usize = 0x0410 / 4; @@ -56,21 +62,50 @@ const E1000_MTA: usize = 0x5200 / 4; const E1000_RAL: usize = 0x5400 / 4; const E1000_RAH: usize = 0x5404 / 4; +#[derive(Clone)] pub struct E1000Interface { - iface: EthernetInterface<'static, 'static, 'static, E1000Driver> + iface: Arc>>, + driver: E1000Driver, + sockets: Arc>>, } -#[derive(Clone)] -pub struct E1000Driver(Arc>); - -impl Driver for E1000Driver { +impl Driver for E1000Interface { fn try_handle_interrupt(&mut self) -> bool { - let driver = self.0.lock(); + let irq = { + let driver = self.driver.0.lock(); + let mut current_addr = driver.header; + while current_addr < driver.header + driver.size { + active_table().map_if_not_exists(current_addr, current_addr); + current_addr = current_addr + PAGE_SIZE; + } + + let e1000 = + unsafe { slice::from_raw_parts_mut(driver.header as *mut Volatile, driver.size / 4) }; + + let icr = e1000[E1000_ICR].read(); + if icr != 0 { + // clear it + e1000[E1000_ICR].write(icr); + true + } else { + false + } + }; - // ensure header page is mapped - active_table().map_if_not_exists(driver.header, driver.size); + if irq { + let timestamp = Instant::from_millis(unsafe { crate::trap::TICK as i64 }); + let mut sockets = self.sockets.lock(); + match self.iface.lock().poll(&mut sockets, timestamp) { + Ok(_) => { + SOCKET_ACTIVITY.notify_all(); + } + Err(err) => { + debug!("poll got err {}", err); + } + } + } - return false; + return irq; } fn device_type(&self) -> DeviceType { @@ -90,14 +125,13 @@ impl E1000 { let e1000 = unsafe { slice::from_raw_parts_mut(self.header as *mut Volatile, self.size / 4) }; let send_queue_size = PAGE_SIZE / size_of::(); - let mut send_queue = unsafe { + let send_queue = unsafe { slice::from_raw_parts_mut(self.send_page as *mut E1000RecvDesc, send_queue_size) }; - let mut tdt = e1000[E1000_TDT].read(); - let index = (tdt as usize + 1) % send_queue_size; + let tdt = e1000[E1000_TDT].read(); + let index = (tdt as usize) % send_queue_size; let send_desc = &mut send_queue[index]; - // TODO: fix it return self.first_trans || (*send_desc).status & 1 != 0; } @@ -124,29 +158,33 @@ impl E1000 { impl NetDriver for E1000Interface { fn get_mac(&self) -> EthernetAddress { - self.iface.ethernet_addr() + self.iface.lock().ethernet_addr() } fn get_ifname(&self) -> String { format!("e1000") } - fn poll(&mut self, sockets: &mut SocketSet) -> Option { + fn ipv4_address(&self) -> Option { + self.iface.lock().ipv4_address() + } + + fn sockets(&mut self) -> MutexGuard, SpinNoIrq> { + self.sockets.lock() + } + + fn poll(&mut self) { let timestamp = Instant::from_millis(unsafe { crate::trap::TICK as i64 }); - match self.iface.poll(sockets, timestamp) { - Ok(update) => { - Some(update) + let mut sockets = self.sockets.lock(); + match self.iface.lock().poll(&mut sockets, timestamp) { + Ok(_) => { + SOCKET_ACTIVITY.notify_all(); } Err(err) => { debug!("poll got err {}", err); - None } } } - - fn ipv4_address(&self) -> Option { - self.iface.ipv4_address() - } } #[repr(C)] @@ -201,7 +239,7 @@ impl<'a> phy::Device<'a> for E1000Driver { fn capabilities(&self) -> DeviceCapabilities { let mut caps = DeviceCapabilities::default(); caps.max_transmission_unit = 1536; - caps.max_burst_size = Some(1); + caps.max_burst_size = Some(32); caps } } @@ -263,12 +301,10 @@ impl phy::TxToken for E1000TxToken { }; let mut tdt = e1000[E1000_TDT].read(); - let index_next = (tdt as usize + 1) % send_queue_size; - let send_desc = &mut send_queue[index_next]; - assert!(driver.first_trans || send_desc.status & 1 != 0); - let index = (tdt as usize) % send_queue_size; let send_desc = &mut send_queue[index]; + assert!(driver.first_trans || send_desc.status & 1 != 0); + let target = unsafe { slice::from_raw_parts_mut(driver.send_buffers[index] as *mut u8, len) }; target.copy_from_slice(&buffer[..len]); @@ -400,7 +436,13 @@ pub fn e1000_init(header: usize, size: usize) { for i in E1000_MTA..E1000_RAL { e1000[i].write(0); } - e1000[E1000_IMS].write(0); // IMS + + // enable interrupt + // RXT0 + e1000[E1000_IMS].write(1 << 7); // IMS + + // clear interrupt + e1000[E1000_ICR].write(e1000[E1000_ICR].read()); e1000[E1000_RDBAL].write(recv_page_pa as u32); // RDBAL e1000[E1000_RDBAH].write((recv_page_pa >> 32) as u32); // RDBAH @@ -431,16 +473,18 @@ pub fn e1000_init(header: usize, size: usize) { let ethernet_addr = EthernetAddress::from_bytes(&mac); let ip_addrs = [IpCidr::new(IpAddress::v4(10,0,0,2), 24)]; let neighbor_cache = NeighborCache::new(BTreeMap::new()); - let iface = EthernetInterfaceBuilder::new(net_driver) + let iface = EthernetInterfaceBuilder::new(net_driver.clone()) .ethernet_addr(ethernet_addr) .ip_addrs(ip_addrs) .neighbor_cache(neighbor_cache) .finalize(); let e1000_iface = E1000Interface { - iface, + iface: Arc::new(Mutex::new(iface)), + sockets: Arc::new(Mutex::new(SocketSet::new(vec![]))), + driver: net_driver.clone(), }; - //DRIVERS.lock().push(Box::new(net_driver.clone())); + DRIVERS.lock().push(Box::new(e1000_iface.clone())); NET_DRIVERS.lock().push(Box::new(e1000_iface)); } diff --git a/kernel/src/drivers/net/virtio_net.rs b/kernel/src/drivers/net/virtio_net.rs index 850fd43..09ba8ef 100644 --- a/kernel/src/drivers/net/virtio_net.rs +++ b/kernel/src/drivers/net/virtio_net.rs @@ -21,6 +21,7 @@ use volatile::{ReadOnly, Volatile}; use crate::HEAP_ALLOCATOR; use crate::memory::active_table; use crate::sync::SpinNoIrqLock as Mutex; +use crate::sync::{MutexGuard, SpinNoIrq}; use super::super::{DeviceType, Driver, DRIVERS, NET_DRIVERS, NetDriver}; use super::super::bus::virtio_mmio::*; @@ -85,11 +86,15 @@ impl NetDriver for VirtIONetDriver { format!("virtio{}", self.0.lock().interrupt) } - fn poll(&mut self, sockets: &mut SocketSet) -> Option { + fn ipv4_address(&self) -> Option { unimplemented!() } - fn ipv4_address(&self) -> Option { + fn sockets(&mut self) -> MutexGuard, SpinNoIrq> { + unimplemented!() + } + + fn poll(&mut self) { unimplemented!() } } diff --git a/kernel/src/net/test.rs b/kernel/src/net/test.rs index ec7842f..854985c 100644 --- a/kernel/src/net/test.rs +++ b/kernel/src/net/test.rs @@ -24,24 +24,18 @@ pub extern fn server(_arg: usize) -> ! { let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 1024]); let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer); - let mut sockets = SocketSet::new(vec![]); + let iface = &mut *(NET_DRIVERS.lock()[0]); + let mut sockets = iface.sockets(); let udp_handle = sockets.add(udp_socket); let tcp_handle = sockets.add(tcp_socket); let tcp2_handle = sockets.add(tcp2_socket); + drop(sockets); + drop(iface); loop { { - let iface = &mut *NET_DRIVERS.lock()[0]; - match iface.poll(&mut sockets) { - Some(event) => { - if !event { - continue; - } - }, - None => { - continue - } - } + let iface = &mut *(NET_DRIVERS.lock()[0]); + let mut sockets = iface.sockets(); // udp server { diff --git a/kernel/src/process/structs.rs b/kernel/src/process/structs.rs index 6218338..508bc76 100644 --- a/kernel/src/process/structs.rs +++ b/kernel/src/process/structs.rs @@ -9,6 +9,8 @@ use smoltcp::socket::{SocketSet, SocketHandle}; use crate::arch::interrupt::{Context, TrapFrame}; use crate::memory::{ByFrame, GlobalFrameAlloc, KernelStack, MemoryAttr, MemorySet}; use crate::fs::{FileHandle, OpenOptions}; +use crate::sync::Condvar; +use crate::drivers::NET_DRIVERS; use super::abi::{self, ProcInitInfo}; @@ -43,8 +45,6 @@ pub struct Process { pub memory_set: MemorySet, pub files: BTreeMap, pub cwd: String, - // TODO: discuss: move it to interface or leave it here - pub sockets: SocketSet<'static, 'static, 'static>, } /// Let `rcore_thread` can switch between our `Thread` @@ -66,7 +66,6 @@ impl Thread { memory_set: MemorySet::new(), files: BTreeMap::default(), cwd: String::from("/"), - sockets: SocketSet::new(vec![]) })), }) } @@ -82,7 +81,6 @@ impl Thread { memory_set, files: BTreeMap::default(), cwd: String::from("/"), - sockets: SocketSet::new(vec![]) })), }) } @@ -161,7 +159,6 @@ impl Thread { memory_set, files, cwd: String::from("/"), - sockets: SocketSet::new(vec![]) })), }) } @@ -193,8 +190,6 @@ impl Thread { memory_set, files: self.proc.lock().files.clone(), cwd: self.proc.lock().cwd.clone(), - // TODO: duplicate sockets for child process - sockets: SocketSet::new(vec![]) })), }) } diff --git a/kernel/src/shell.rs b/kernel/src/shell.rs index 2031fc0..4355d2c 100644 --- a/kernel/src/shell.rs +++ b/kernel/src/shell.rs @@ -7,8 +7,8 @@ use crate::process::*; use crate::thread; pub fn run_user_shell() { - use crate::net::server; - processor().manager().add(Thread::new_kernel(server, 0), 0); + //use crate::net::server; + //processor().manager().add(Thread::new_kernel(server, 0), 0); if let Ok(inode) = ROOT_INODE.lookup("sh") { println!("Going to user mode shell."); println!("Use 'ls' to list available programs."); diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index c28db71..5ff813f 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -1,7 +1,7 @@ //! Syscalls for networking use super::*; -use crate::drivers::NET_DRIVERS; +use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY}; use core::mem::size_of; use smoltcp::socket::*; use smoltcp::wire::*; @@ -49,6 +49,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu domain, socket_type, protocol ); let mut proc = process(); + let iface = &mut *(NET_DRIVERS.lock()[0]); match domain { AF_INET => match socket_type { SOCK_STREAM => { @@ -58,7 +59,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; 2048]); let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - let tcp_handle = proc.sockets.add(tcp_socket); + let tcp_handle = iface.sockets().add(tcp_socket); proc.files.insert( fd, FileLike::Socket(SocketWrapper { @@ -83,7 +84,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu raw_tx_buffer, ); - let raw_handle = proc.sockets.add(raw_socket); + let raw_handle = iface.sockets().add(raw_socket); proc.files.insert( fd, FileLike::Socket(SocketWrapper { @@ -140,6 +141,7 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { let mut proc = process(); proc.memory_set.check_ptr(addr)?; + let iface = &mut *(NET_DRIVERS.lock()[0]); let mut dest = None; let mut port = 0; @@ -152,14 +154,10 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { return Err(SysError::EINVAL); } - let mut proc = process(); - // little hack: kick it forward - let iface = &mut *NET_DRIVERS.lock()[0]; - iface.poll(&mut proc.sockets); - let wrapper = proc.get_socket(fd)?; if let SocketType::Tcp = wrapper.socket_type { - let mut socket = proc.sockets.get::(wrapper.handle); + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); // TODO selects non-conflict high port static mut EPHEMERAL_PORT: u16 = 49152; @@ -173,8 +171,27 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { }; match socket.connect((dest.unwrap(), port), temp_port) { - Ok(()) => Ok(0), - Err(_) => Err(SysError::EISCONN), + Ok(()) => { + drop(socket); + drop(sockets); + + // wait for connection result + loop { + iface.poll(); + + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); + if socket.state() == TcpState::SynSent { + // still connecting + SOCKET_ACTIVITY._wait() + } else if socket.state() == TcpState::Established { + break Ok(0) + } else if socket.state() == TcpState::Closed { + break Err(SysError::ECONNREFUSED) + } + } + }, + Err(_) => Err(SysError::ENOBUFS), } } else { unimplemented!("socket type") @@ -182,18 +199,22 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { } pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult { - // little hack: kick it forward - let iface = &mut *NET_DRIVERS.lock()[0]; - iface.poll(&mut proc.sockets); - + let iface = &mut *(NET_DRIVERS.lock()[0]); let wrapper = proc.get_socket(fd)?; if let SocketType::Tcp = wrapper.socket_type { - let mut socket = proc.sockets.get::(wrapper.handle); + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); + let slice = unsafe { slice::from_raw_parts(base, len) }; if socket.is_open() { if socket.can_send() { match socket.send_slice(&slice) { - Ok(size) => Ok(size as isize), + Ok(size) => { + drop(socket); + drop(sockets); + iface.poll(); + Ok(size as isize) + }, Err(err) => Err(SysError::ENOBUFS), } } else { @@ -231,17 +252,18 @@ pub fn sys_sendto( "sys_sendto: fd: {} buffer: {:?} len: {} addr: {:?} addr_len: {}", fd, buffer, len, addr, addr_len ); + let mut proc = process(); proc.memory_set.check_ptr(addr)?; proc.memory_set.check_array(buffer, len)?; - // little hack: kick it forward - let iface = &mut *NET_DRIVERS.lock()[0]; - iface.poll(&mut proc.sockets); + let iface = &mut *(NET_DRIVERS.lock()[0]); let wrapper = proc.get_socket(fd)?; if let SocketType::Raw = wrapper.socket_type { - let mut socket = proc.sockets.get::(wrapper.handle); + let v4_src = iface.ipv4_address().unwrap(); + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); let mut dest = None; let mut port = 0; @@ -252,7 +274,7 @@ pub fn sys_sendto( if dest == None { return Err(SysError::EINVAL); - } else if let Some(IpAddress::Ipv4(v4_dest)) = dest { + } else if let Some(IpAddress::Ipv4(v4_dst)) = dest { let slice = unsafe { slice::from_raw_parts(buffer, len) }; // using 20-byte IPv4 header let mut buffer = vec![0u8; len + 20]; @@ -261,14 +283,18 @@ pub fn sys_sendto( packet.set_header_len(20); packet.set_total_len((20 + len) as u16); packet.set_protocol(socket.ip_protocol().into()); - packet.set_src_addr(iface.ipv4_address().unwrap()); - packet.set_dst_addr(v4_dest); + packet.set_src_addr(v4_src); + packet.set_dst_addr(v4_dst); let payload = packet.payload_mut(); payload.copy_from_slice(slice); packet.fill_checksum(); socket.send_slice(&buffer).unwrap(); + drop(socket); + drop(sockets); + iface.poll(); + Ok(len as isize) } else { unimplemented!("ip type") @@ -292,18 +318,16 @@ pub fn sys_recvfrom( ); let mut proc = process(); - // little hack: kick it forward - let iface = &mut *NET_DRIVERS.lock()[0]; - iface.poll(&mut proc.sockets); + let iface = &mut *(NET_DRIVERS.lock()[0]); let wrapper = proc.get_socket(fd)?; if let SocketType::Raw = wrapper.socket_type { - let mut socket = proc.sockets.get::(wrapper.handle); - + loop { + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); - let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) }; - match socket.recv_slice(&mut slice) { - Ok(size) => { + let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) }; + if let Ok(size) = socket.recv_slice(&mut slice) { let mut packet = Ipv4Packet::new_unchecked(&slice); // FIXME: check size as per sin_family @@ -311,12 +335,11 @@ pub fn sys_recvfrom( fill_addr(&mut sockaddr_in, IpAddress::Ipv4(packet.src_addr()), 0); unsafe { *addr_len = size_of::() }; - Ok(size as isize) - } - Err(err) => { - warn!("err {:?}", err); - Err(SysError::ENOBUFS) + return Ok(size as isize); } + + drop(socket); + SOCKET_ACTIVITY._wait() } } else { unimplemented!("socket type") @@ -324,7 +347,8 @@ pub fn sys_recvfrom( } pub fn sys_close_socket(proc: &mut Process, fd: usize, handle: SocketHandle) -> SysResult { - let mut socket = proc.sockets.remove(handle); + let iface = &mut *(NET_DRIVERS.lock()[0]); + let mut socket = iface.sockets().remove(handle); match socket { Socket::Tcp(ref mut tcp_socket) => { tcp_socket.close();