diff --git a/kernel/src/drivers/mod.rs b/kernel/src/drivers/mod.rs index 2c47328..b1eb9ec 100644 --- a/kernel/src/drivers/mod.rs +++ b/kernel/src/drivers/mod.rs @@ -10,10 +10,15 @@ use crate::sync::{Condvar, MutexGuard, SpinNoIrq}; use self::block::virtio_blk::VirtIOBlkDriver; mod device_tree; +#[allow(dead_code)] pub mod bus; +#[allow(dead_code)] pub mod net; +#[allow(dead_code)] pub mod block; +#[allow(dead_code)] mod gpu; +#[allow(dead_code)] mod input; #[derive(Debug, Eq, PartialEq)] @@ -45,9 +50,6 @@ pub trait NetDriver : Driver { // get ipv4 address fn ipv4_address(&self) -> Option; - // get sockets - fn sockets(&self) -> MutexGuard, SpinNoIrq>; - // manually trigger a poll, use it after sending packets fn poll(&self); } diff --git a/kernel/src/drivers/net/e1000.rs b/kernel/src/drivers/net/e1000.rs index 6f68d34..de8a7ae 100644 --- a/kernel/src/drivers/net/e1000.rs +++ b/kernel/src/drivers/net/e1000.rs @@ -23,6 +23,7 @@ use smoltcp::Result; use volatile::Volatile; use crate::memory::active_table; +use crate::net::SOCKETS; use crate::sync::SpinNoIrqLock as Mutex; use crate::sync::{MutexGuard, SpinNoIrq}; use crate::HEAP_ALLOCATOR; @@ -72,7 +73,6 @@ const E1000_RAH: usize = 0x5404 / 4; pub struct E1000Interface { iface: Mutex>, driver: E1000Driver, - sockets: Mutex>, } impl Driver for E1000Interface { @@ -104,7 +104,7 @@ impl Driver for E1000Interface { if irq { let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64); - let mut sockets = self.sockets.lock(); + let mut sockets = SOCKETS.lock(); match self.iface.lock().poll(&mut sockets, timestamp) { Ok(_) => { SOCKET_ACTIVITY.notify_all(); @@ -136,13 +136,9 @@ impl NetDriver for E1000Interface { self.iface.lock().ipv4_address() } - fn sockets(&self) -> MutexGuard, SpinNoIrq> { - self.sockets.lock() - } - fn poll(&self) { let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64); - let mut sockets = self.sockets.lock(); + let mut sockets = SOCKETS.lock(); match self.iface.lock().poll(&mut sockets, timestamp) { Ok(_) => { SOCKET_ACTIVITY.notify_all(); @@ -195,8 +191,9 @@ impl<'a> phy::Device<'a> for E1000Driver { } } - let e1000 = - unsafe { slice::from_raw_parts_mut(driver.header as *mut Volatile, driver.size / 4) }; + let e1000 = unsafe { + slice::from_raw_parts_mut(driver.header as *mut Volatile, driver.size / 4) + }; let send_queue_size = PAGE_SIZE / size_of::(); let send_queue = unsafe { @@ -214,8 +211,8 @@ impl<'a> phy::Device<'a> for E1000Driver { let index = (rdt as usize + 1) % recv_queue_size; let recv_desc = &mut recv_queue[index]; - let transmit_avail = driver.first_trans || (*send_desc).status & 1 != 0; - let receive_avail = (*recv_desc).status & 1 != 0; + let transmit_avail = driver.first_trans || (*send_desc).status & 1 != 0; + let receive_avail = (*recv_desc).status & 1 != 0; if transmit_avail && receive_avail { let buffer = unsafe { @@ -247,8 +244,9 @@ impl<'a> phy::Device<'a> for E1000Driver { } } - let e1000 = - unsafe { slice::from_raw_parts_mut(driver.header as *mut Volatile, driver.size / 4) }; + let e1000 = unsafe { + slice::from_raw_parts_mut(driver.header as *mut Volatile, driver.size / 4) + }; let send_queue_size = PAGE_SIZE / size_of::(); let send_queue = unsafe { @@ -428,7 +426,6 @@ pub fn e1000_init(header: usize, size: usize) { // IPGT=0xa | IPGR1=0x8 | IPGR2=0xc e1000[E1000_TIPG].write(0xa | (0x8 << 10) | (0xc << 20)); // TIPG - // 4.6.5 Receive Initialization let mut ral: u32 = 0; let mut rah: u32 = 0; @@ -502,7 +499,6 @@ pub fn e1000_init(header: usize, size: usize) { let e1000_iface = E1000Interface { iface: Mutex::new(iface), - sockets: Mutex::new(SocketSet::new(vec![])), driver: net_driver.clone(), }; diff --git a/kernel/src/drivers/net/ixgbe.rs b/kernel/src/drivers/net/ixgbe.rs index c6471ba..bb5fefe 100644 --- a/kernel/src/drivers/net/ixgbe.rs +++ b/kernel/src/drivers/net/ixgbe.rs @@ -1,7 +1,6 @@ //! Intel 10Gb Network Adapter 82599 i.e. ixgbe network driver use alloc::alloc::{GlobalAlloc, Layout}; -use alloc::format; use alloc::prelude::*; use alloc::sync::Arc; use core::mem::size_of; @@ -14,7 +13,7 @@ use log::*; use rcore_memory::paging::PageTable; use rcore_memory::PAGE_SIZE; use smoltcp::iface::*; -use smoltcp::phy::{self, DeviceCapabilities, Checksum}; +use smoltcp::phy::{self, Checksum, DeviceCapabilities}; use smoltcp::socket::*; use smoltcp::time::Instant; use smoltcp::wire::EthernetAddress; @@ -23,6 +22,7 @@ use smoltcp::Result; use volatile::Volatile; use crate::memory::active_table; +use crate::net::SOCKETS; use crate::sync::SpinNoIrqLock as Mutex; use crate::sync::{MutexGuard, SpinNoIrq}; use crate::HEAP_ALLOCATOR; @@ -139,7 +139,6 @@ const IXGBE_EEC: usize = 0x10010 / 4; pub struct IXGBEInterface { iface: Mutex>, driver: IXGBEDriver, - sockets: Mutex>, name: String, irq: Option, } @@ -194,7 +193,7 @@ impl Driver for IXGBEInterface { if rx { let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64); - let mut sockets = self.sockets.lock(); + let mut sockets = SOCKETS.lock(); match self.iface.lock().poll(&mut sockets, timestamp) { Ok(_) => { SOCKET_ACTIVITY.notify_all(); @@ -226,13 +225,9 @@ impl NetDriver for IXGBEInterface { self.iface.lock().ipv4_address() } - fn sockets(&self) -> MutexGuard, SpinNoIrq> { - self.sockets.lock() - } - fn poll(&self) { let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64); - let mut sockets = self.sockets.lock(); + let mut sockets = SOCKETS.lock(); match self.iface.lock().poll(&mut sockets, timestamp) { Ok(_) => { SOCKET_ACTIVITY.notify_all(); @@ -576,7 +571,6 @@ pub fn ixgbe_init(name: String, irq: Option, header: usize, size: usize) { // CRCStrip | RSCACKC | FCOE_WRFIX ixgbe[IXGBE_RDRXCTL].write(ixgbe[IXGBE_RDRXCTL].read() | (1 << 0) | (1 << 25) | (1 << 26)); - /* Not completed part // Program RXPBSIZE, MRQC, PFQDE, RTRUP2TC, MFLCN.RPFCE, and MFLCN.RFCE according to the DCB and virtualization modes (see Section 4.6.11.3). // 4.6.11.3.4 DCB-Off, VT-Off @@ -717,7 +711,8 @@ pub fn ixgbe_init(name: String, irq: Option, header: usize, size: usize) { // Program the HLREG0 register according to the required MAC behavior. // TXCRCEN | RXCRCSTRP | TXPADEN | RXLNGTHERREN // ixgbe[IXGBE_HLREG0].write(ixgbe[IXGBE_HLREG0].read() & !(1 << 0) & !(1 << 1)); - ixgbe[IXGBE_HLREG0].write(ixgbe[IXGBE_HLREG0].read() | (1 << 0) | (1 << 1) | (1 << 10) | (1 << 27)); + ixgbe[IXGBE_HLREG0] + .write(ixgbe[IXGBE_HLREG0].read() | (1 << 0) | (1 << 1) | (1 << 10) | (1 << 27)); // The following steps should be done once per transmit queue: // 1. Allocate a region of memory for the transmit descriptor list. @@ -746,7 +741,6 @@ pub fn ixgbe_init(name: String, irq: Option, header: usize, size: usize) { ixgbe[IXGBE_TXDCTL].write(ixgbe[IXGBE_TXDCTL].read() | 1 << 25); while ixgbe[IXGBE_TXDCTL].read() & (1 << 25) == 0 {} - // 4.6.6 Interrupt Initialization // The software driver associates between Tx and Rx interrupt causes and the EICR register by setting the IVAR[n] registers. // map Rx0 to interrupt 0 @@ -758,7 +752,7 @@ pub fn ixgbe_init(name: String, irq: Option, header: usize, size: usize) { // CNT_WDIS | ITR Interval=100us // if sys_read() spin more times, the interval here should be larger // Linux use dynamic ETIR based on statistics - ixgbe[IXGBE_EITR].write(((100/2) << 3) | (1 << 31)); + ixgbe[IXGBE_EITR].write(((100 / 2) << 3) | (1 << 31)); // Disable general purpose interrupt // We don't need them ixgbe[IXGBE_GPIE].write(0); @@ -789,7 +783,6 @@ pub fn ixgbe_init(name: String, irq: Option, header: usize, size: usize) { let ixgbe_iface = IXGBEInterface { iface: Mutex::new(iface), - sockets: Mutex::new(SocketSet::new(vec![])), driver: net_driver.clone(), name, irq, @@ -798,5 +791,4 @@ pub fn ixgbe_init(name: String, irq: Option, header: usize, size: usize) { let driver = Arc::new(ixgbe_iface); DRIVERS.write().push(driver.clone()); NET_DRIVERS.write().push(driver); - } diff --git a/kernel/src/drivers/net/mod.rs b/kernel/src/drivers/net/mod.rs index eafbb9c..0019043 100644 --- a/kernel/src/drivers/net/mod.rs +++ b/kernel/src/drivers/net/mod.rs @@ -1,3 +1,3 @@ -pub mod virtio_net; pub mod e1000; -pub mod ixgbe; \ No newline at end of file +pub mod ixgbe; +pub mod virtio_net; diff --git a/kernel/src/drivers/net/virtio_net.rs b/kernel/src/drivers/net/virtio_net.rs index 91f1cf8..f2492fe 100644 --- a/kernel/src/drivers/net/virtio_net.rs +++ b/kernel/src/drivers/net/virtio_net.rs @@ -6,25 +6,25 @@ use core::mem::size_of; use core::slice; use bitflags::*; -use device_tree::Node; use device_tree::util::SliceRead; +use device_tree::Node; use log::*; -use rcore_memory::PAGE_SIZE; use rcore_memory::paging::PageTable; +use rcore_memory::PAGE_SIZE; use smoltcp::phy::{self, DeviceCapabilities}; -use smoltcp::Result; +use smoltcp::socket::SocketSet; use smoltcp::time::Instant; use smoltcp::wire::{EthernetAddress, Ipv4Address}; -use smoltcp::socket::SocketSet; +use smoltcp::Result; use volatile::{ReadOnly, Volatile}; -use crate::HEAP_ALLOCATOR; use crate::memory::active_table; use crate::sync::SpinNoIrqLock as Mutex; use crate::sync::{MutexGuard, SpinNoIrq}; +use crate::HEAP_ALLOCATOR; -use super::super::{DeviceType, Driver, DRIVERS, NET_DRIVERS, NetDriver}; use super::super::bus::virtio_mmio::*; +use super::super::{DeviceType, Driver, NetDriver, DRIVERS, NET_DRIVERS}; pub struct VirtIONet { interrupt_parent: u32, @@ -71,7 +71,6 @@ impl VirtIONet { self.queues[VIRTIO_QUEUE_TRANSMIT].can_add(1, 0) } - fn receive_available(&self) -> bool { self.queues[VIRTIO_QUEUE_RECEIVE].can_get() } @@ -90,10 +89,6 @@ impl NetDriver for VirtIONetDriver { unimplemented!() } - fn sockets(&self) -> MutexGuard, SpinNoIrq> { - unimplemented!() - } - fn poll(&self) { unimplemented!() } @@ -110,8 +105,10 @@ impl<'a> phy::Device<'a> for VirtIONetDriver { let driver = self.0.lock(); if driver.transmit_available() && driver.receive_available() { // potential racing - Some((VirtIONetRxToken(self.clone()), - VirtIONetTxToken(self.clone()))) + Some(( + VirtIONetRxToken(self.clone()), + VirtIONetTxToken(self.clone()), + )) } else { None } @@ -134,9 +131,10 @@ impl<'a> phy::Device<'a> for VirtIONetDriver { } } -impl phy::RxToken for VirtIONetRxToken { +impl phy::RxToken for VirtIONetRxToken { fn consume(self, _timestamp: Instant, f: F) -> Result - where F: FnOnce(&[u8]) -> Result + where + F: FnOnce(&[u8]) -> Result, { let (input, output, _, user_data) = { let mut driver = (self.0).0.lock(); @@ -156,7 +154,8 @@ impl phy::RxToken for VirtIONetRxToken { impl phy::TxToken for VirtIONetTxToken { fn consume(self, _timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result, + where + F: FnOnce(&mut [u8]) -> Result, { let output = { let mut driver = (self.0).0.lock(); @@ -165,16 +164,18 @@ impl phy::TxToken for VirtIONetTxToken { active_table().map_if_not_exists(driver.header as usize, driver.header as usize); if let Some((_, output, _, _)) = driver.queues[VIRTIO_QUEUE_TRANSMIT].get() { - unsafe { slice::from_raw_parts_mut(output[0].as_ptr() as *mut u8, output[0].len())} + unsafe { slice::from_raw_parts_mut(output[0].as_ptr() as *mut u8, output[0].len()) } } else { // allocate a page for buffer let page = unsafe { - HEAP_ALLOCATOR.alloc_zeroed(Layout::from_size_align(PAGE_SIZE, PAGE_SIZE).unwrap()) + HEAP_ALLOCATOR + .alloc_zeroed(Layout::from_size_align(PAGE_SIZE, PAGE_SIZE).unwrap()) } as usize; unsafe { slice::from_raw_parts_mut(page as *mut u8, PAGE_SIZE) } } }; - let output_buffer = &mut output[size_of::()..(size_of::() + len)]; + let output_buffer = + &mut output[size_of::()..(size_of::() + len)]; let result = f(output_buffer); let mut driver = (self.0).0.lock(); @@ -183,7 +184,6 @@ impl phy::TxToken for VirtIONetTxToken { } } - bitflags! { struct VirtIONetFeature : u64 { const CSUM = 1 << 0; @@ -234,7 +234,7 @@ bitflags! { #[derive(Debug)] struct VirtIONetworkConfig { mac: [u8; 6], - status: ReadOnly + status: ReadOnly, } // virtio 5.1.6 Device Operation @@ -250,7 +250,6 @@ struct VirtIONetHeader { // payload starts from here } - pub fn virtio_net_init(node: &Node) { let reg = node.prop_raw("reg").unwrap(); let from = reg.as_slice().read_be_u64(0).unwrap(); @@ -283,8 +282,10 @@ pub fn virtio_net_init(node: &Node) { interrupt_parent: node.prop_u32("interrupt-parent").unwrap(), header: from as usize, mac: EthernetAddress(mac), - queues: [VirtIOVirtqueue::new(header, VIRTIO_QUEUE_RECEIVE, queue_num), - VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num)], + queues: [ + VirtIOVirtqueue::new(header, VIRTIO_QUEUE_RECEIVE, queue_num), + VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num), + ], }; // allocate a page for buffer @@ -300,4 +301,4 @@ pub fn virtio_net_init(node: &Node) { DRIVERS.write().push(net_driver.clone()); NET_DRIVERS.write().push(net_driver); -} \ No newline at end of file +} diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index aff647a..3773914 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -1,2 +1,5 @@ +mod structs; mod test; -pub use self::test::server; \ No newline at end of file + +pub use self::structs::*; +pub use self::test::server; diff --git a/kernel/src/net/structs.rs b/kernel/src/net/structs.rs new file mode 100644 index 0000000..5779600 --- /dev/null +++ b/kernel/src/net/structs.rs @@ -0,0 +1,249 @@ +use alloc::sync::Arc; +use core::fmt; + +use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY}; +use crate::process::structs::Process; +use crate::sync::SpinNoIrqLock as Mutex; +use crate::syscall::*; + +use smoltcp::socket::*; +use smoltcp::wire::*; + +lazy_static! { + pub static ref SOCKETS: Arc>> = + Arc::new(Mutex::new(SocketSet::new(vec![]))); +} + +#[derive(Clone, Debug)] +pub struct TcpSocketState { + pub local_endpoint: Option, // save local endpoint for bind() + pub is_listening: bool, +} + +#[derive(Clone, Debug)] +pub struct UdpSocketState { + pub remote_endpoint: Option, // remember remote endpoint for connect() +} + +#[derive(Clone, Debug)] +pub enum SocketType { + Raw, + Tcp(TcpSocketState), + Udp(UdpSocketState), + Icmp, +} + +#[derive(Debug)] +pub struct SocketWrapper { + pub handle: SocketHandle, + pub socket_type: SocketType, +} + +pub fn get_ephemeral_port() -> u16 { + // TODO selects non-conflict high port + static mut EPHEMERAL_PORT: u16 = 49152; + unsafe { + if EPHEMERAL_PORT == 65535 { + EPHEMERAL_PORT = 49152; + } else { + EPHEMERAL_PORT = EPHEMERAL_PORT + 1; + } + EPHEMERAL_PORT + } +} + +/// Safety: call this without SOCKETS locked +pub fn poll_ifaces() { + for iface in NET_DRIVERS.read().iter() { + iface.poll(); + } +} + +impl Drop for SocketWrapper { + fn drop(&mut self) { + let mut sockets = SOCKETS.lock(); + sockets.release(self.handle); + sockets.prune(); + + // send FIN immediately when applicable + drop(sockets); + poll_ifaces(); + } +} + +impl SocketWrapper { + pub fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { + if let SocketType::Raw = self.socket_type { + if let Some(endpoint) = sendto_endpoint { + // temporary solution + let iface = &*(NET_DRIVERS.read()[0]); + let v4_src = iface.ipv4_address().unwrap(); + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle); + + if let IpAddress::Ipv4(v4_dst) = endpoint.addr { + let len = data.len(); + // using 20-byte IPv4 header + let mut buffer = vec![0u8; len + 20]; + let mut packet = Ipv4Packet::new_unchecked(&mut buffer); + packet.set_version(4); + packet.set_header_len(20); + packet.set_total_len((20 + len) as u16); + packet.set_protocol(socket.ip_protocol().into()); + packet.set_src_addr(v4_src); + packet.set_dst_addr(v4_dst); + let payload = packet.payload_mut(); + payload.copy_from_slice(data); + packet.fill_checksum(); + + socket.send_slice(&buffer).unwrap(); + + // avoid deadlock + drop(socket); + drop(sockets); + iface.poll(); + + Ok(len) + } else { + unimplemented!("ip type") + } + } else { + Err(SysError::ENOTCONN) + } + } else if let SocketType::Tcp(_) = self.socket_type { + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle); + + if socket.is_open() { + if socket.can_send() { + match socket.send_slice(&data) { + Ok(size) => { + // avoid deadlock + drop(socket); + drop(sockets); + + poll_ifaces(); + Ok(size) + } + Err(err) => Err(SysError::ENOBUFS), + } + } else { + Err(SysError::ENOBUFS) + } + } else { + Err(SysError::ENOTCONN) + } + } else if let SocketType::Udp(ref state) = self.socket_type { + let remote_endpoint = { + if let Some(ref endpoint) = sendto_endpoint { + endpoint + } else if let Some(ref endpoint) = state.remote_endpoint { + endpoint + } else { + return Err(SysError::ENOTCONN); + } + }; + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle); + + if socket.endpoint().port == 0 { + let temp_port = get_ephemeral_port(); + socket + .bind(IpEndpoint::new(IpAddress::Unspecified, temp_port)) + .unwrap(); + } + + if socket.can_send() { + match socket.send_slice(&data, *remote_endpoint) { + Ok(()) => { + // avoid deadlock + drop(socket); + drop(sockets); + + poll_ifaces(); + Ok(data.len()) + } + Err(err) => Err(SysError::ENOBUFS), + } + } else { + Err(SysError::ENOBUFS) + } + } else { + unimplemented!("socket type") + } + } + + pub fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) { + if let SocketType::Raw = self.socket_type { + loop { + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle); + + if let Ok(size) = socket.recv_slice(data) { + let packet = Ipv4Packet::new_unchecked(data); + + return ( + Ok(size), + IpEndpoint { + addr: IpAddress::Ipv4(packet.src_addr()), + port: 0, + }, + ); + } + + // avoid deadlock + drop(socket); + drop(sockets); + SOCKET_ACTIVITY._wait() + } + } else if let SocketType::Tcp(_) = self.socket_type { + spin_and_wait(&[&SOCKET_ACTIVITY], move || { + poll_ifaces(); + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle); + + if socket.is_open() { + if let Ok(size) = socket.recv_slice(data) { + if size > 0 { + let endpoint = socket.remote_endpoint(); + // avoid deadlock + drop(socket); + drop(sockets); + + poll_ifaces(); + return Some((Ok(size), endpoint)); + } + } + } else { + return Some((Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED)); + } + None + }) + } else if let SocketType::Udp(ref state) = self.socket_type { + loop { + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle); + + if socket.is_open() { + if let Ok((size, remote_endpoint)) = socket.recv_slice(data) { + let endpoint = remote_endpoint; + // avoid deadlock + drop(socket); + drop(sockets); + + poll_ifaces(); + return (Ok(size), endpoint); + } + } else { + return (Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED); + } + + // avoid deadlock + drop(socket); + SOCKET_ACTIVITY._wait() + } + } else { + unimplemented!("socket type") + } + } +} diff --git a/kernel/src/net/test.rs b/kernel/src/net/test.rs index 3d72d78..ebfb007 100644 --- a/kernel/src/net/test.rs +++ b/kernel/src/net/test.rs @@ -1,11 +1,12 @@ -use crate::thread; -use crate::drivers::NET_DRIVERS; -use smoltcp::socket::*; use crate::drivers::NetDriver; +use crate::drivers::NET_DRIVERS; +use crate::net::SOCKETS; +use crate::thread; use alloc::vec; use core::fmt::Write; +use smoltcp::socket::*; -pub extern fn server(_arg: usize) -> ! { +pub extern "C" fn server(_arg: usize) -> ! { if NET_DRIVERS.read().len() < 1 { loop { thread::yield_now(); @@ -24,18 +25,15 @@ pub extern fn server(_arg: usize) -> ! { let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 1024]); let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer); - let iface = &*(NET_DRIVERS.read()[0]); - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let udp_handle = sockets.add(udp_socket); let tcp_handle = sockets.add(tcp_socket); let tcp2_handle = sockets.add(tcp2_socket); drop(sockets); - drop(iface); loop { { - let iface = &*(NET_DRIVERS.read()[0]); - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); // udp server { @@ -45,10 +43,8 @@ pub extern fn server(_arg: usize) -> ! { } let client = match socket.recv() { - Ok((_, endpoint)) => { - Some(endpoint) - } - Err(_) => None + Ok((_, endpoint)) => Some(endpoint), + Err(_) => None, }; if let Some(endpoint) = client { let hello = b"hello\n"; @@ -85,5 +81,4 @@ pub extern fn server(_arg: usize) -> ! { thread::yield_now(); } - } diff --git a/kernel/src/process/structs.rs b/kernel/src/process/structs.rs index e62d741..48921c0 100644 --- a/kernel/src/process/structs.rs +++ b/kernel/src/process/structs.rs @@ -4,8 +4,6 @@ use core::fmt; use log::*; use spin::{Mutex, RwLock}; use xmas_elf::{ElfFile, header, program::{Flags, Type}}; -use smoltcp::socket::SocketHandle; -use smoltcp::wire::IpEndpoint; use rcore_memory::PAGE_SIZE; use rcore_thread::Tid; @@ -14,6 +12,7 @@ use crate::memory::{ByFrame, GlobalFrameAlloc, KernelStack, MemoryAttr, MemorySe use crate::fs::{FileHandle, OpenOptions}; use crate::sync::Condvar; use crate::drivers::NET_DRIVERS; +use crate::net::{SocketWrapper, SOCKETS}; use super::abi::{self, ProcInitInfo}; @@ -27,30 +26,6 @@ pub struct Thread { pub proc: Arc>, } -#[derive(Clone, Debug)] -pub struct TcpSocketState { - pub local_endpoint: Option, // save local endpoint for bind() - pub is_listening: bool, -} - -#[derive(Clone, Debug)] -pub struct UdpSocketState { - pub remote_endpoint: Option, // remember remote endpoint for connect() -} - -#[derive(Clone, Debug)] -pub enum SocketType { - Raw, - Tcp(TcpSocketState), - Udp(UdpSocketState), - Icmp -} - -#[derive(Debug)] -pub struct SocketWrapper { - pub handle: SocketHandle, - pub socket_type: SocketType, -} #[derive(Clone)] pub enum FileLike { @@ -63,12 +38,7 @@ impl fmt::Debug for FileLike { match self { FileLike::File(_) => write!(f, "File"), FileLike::Socket(wrapper) => { - match wrapper.socket_type { - SocketType::Raw => write!(f, "RawSocket"), - SocketType::Tcp(_) => write!(f, "TcpSocket"), - SocketType::Udp(_) => write!(f, "UdpSocket"), - SocketType::Icmp => write!(f, "IcmpSocket"), - } + write!(f, "{:?}", wrapper) }, } } @@ -324,12 +294,10 @@ impl Thread { debug!("fork: temporary copy data!"); let kstack = KernelStack::new(); - for iface in NET_DRIVERS.read().iter() { - let mut sockets = iface.sockets(); - for (_fd, file) in files.iter() { - if let FileLike::Socket(wrapper) = file { - sockets.retain(wrapper.handle); - } + let mut sockets = SOCKETS.lock(); + for (_fd, file) in files.iter() { + if let FileLike::Socket(wrapper) = file { + sockets.retain(wrapper.handle); } } diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index 6fdd673..d2b7000 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -2,7 +2,10 @@ use super::*; use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY}; -use crate::process::structs::TcpSocketState; +use crate::net::{ + get_ephemeral_port, poll_ifaces, SocketType, SocketWrapper, TcpSocketState, UdpSocketState, + SOCKETS, +}; use core::cmp::min; use core::mem::size_of; use smoltcp::socket::*; @@ -23,26 +26,12 @@ const IPPROTO_TCP: usize = 6; const TCP_SENDBUF: usize = 512 * 1024; // 512K const TCP_RECVBUF: usize = 512 * 1024; // 512K -fn get_ephemeral_port() -> u16 { - // TODO selects non-conflict high port - static mut EPHEMERAL_PORT: u16 = 49152; - unsafe { - if EPHEMERAL_PORT == 65535 { - EPHEMERAL_PORT = 49152; - } else { - EPHEMERAL_PORT = EPHEMERAL_PORT + 1; - } - EPHEMERAL_PORT - } -} - pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResult { info!( "socket: domain: {}, socket_type: {}, protocol: {}", domain, socket_type, protocol ); let mut proc = process(); - let iface = &*(NET_DRIVERS.read()[0]); match domain { AF_INET | AF_UNIX => match socket_type & SOCK_TYPE_MASK { SOCK_STREAM => { @@ -52,7 +41,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; TCP_SENDBUF]); let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - let tcp_handle = iface.sockets().add(tcp_socket); + let tcp_handle = SOCKETS.lock().add(tcp_socket); proc.files.insert( fd, FileLike::Socket(SocketWrapper { @@ -75,7 +64,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 2048]); let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); - let udp_handle = iface.sockets().add(udp_socket); + let udp_handle = SOCKETS.lock().add(udp_socket); proc.files.insert( fd, FileLike::Socket(SocketWrapper { @@ -102,7 +91,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu raw_tx_buffer, ); - let raw_handle = iface.sockets().add(raw_socket); + let raw_handle = SOCKETS.lock().add(raw_socket); proc.files.insert( fd, FileLike::Socket(SocketWrapper { @@ -211,8 +200,7 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu let wrapper = &mut proc.get_socket_mut(fd)?; if let SocketType::Tcp(_) = wrapper.socket_type { - let iface = &*(NET_DRIVERS.read()[0]); - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(wrapper.handle); let temp_port = get_ephemeral_port(); @@ -225,10 +213,9 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu // wait for connection result loop { - let iface = &*(NET_DRIVERS.read()[0]); - iface.poll(); + poll_ifaces(); - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let socket = sockets.get::(wrapper.handle); if socket.state() == TcpState::SynSent { // still connecting @@ -256,307 +243,73 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu } pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult { - let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket(fd)?; - if let SocketType::Tcp(_) = wrapper.socket_type { - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - let slice = unsafe { slice::from_raw_parts(base, len) }; - if socket.is_open() { - if socket.can_send() { - match socket.send_slice(&slice) { - Ok(size) => { - // avoid deadlock - drop(socket); - drop(sockets); - - iface.poll(); - Ok(size) - } - Err(err) => Err(SysError::ENOBUFS), - } - } else { - Err(SysError::ENOBUFS) - } - } else { - Err(SysError::ENOTCONN) - } - } else if let SocketType::Udp(ref state) = wrapper.socket_type { - if let Some(ref remote_endpoint) = state.remote_endpoint { - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - if socket.endpoint().port == 0 { - let v4_src = iface.ipv4_address().unwrap(); - let temp_port = get_ephemeral_port(); - socket - .bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port)) - .unwrap(); - } - - let slice = unsafe { slice::from_raw_parts(base, len) }; - if socket.is_open() { - if socket.can_send() { - match socket.send_slice(&slice, *remote_endpoint) { - Ok(()) => { - // avoid deadlock - drop(socket); - drop(sockets); - - iface.poll(); - Ok(len) - } - Err(err) => Err(SysError::ENOBUFS), - } - } else { - Err(SysError::ENOBUFS) - } - } else { - Err(SysError::ENOTCONN) - } - } else { - Err(SysError::ENOTCONN) - } - } else { - unimplemented!("socket type") - } + let slice = unsafe { slice::from_raw_parts(base, len) }; + wrapper.write(&slice, None) } pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult { - let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket(fd)?; - if let SocketType::Tcp(_) = wrapper.socket_type { - spin_and_wait(&[&SOCKET_ACTIVITY], move || { - iface.poll(); - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - if socket.is_open() { - let mut slice = unsafe { slice::from_raw_parts_mut(base, len) }; - if let Ok(size) = socket.recv_slice(&mut slice) { - if size > 0 { - // avoid deadlock - drop(socket); - drop(sockets); - - iface.poll(); - return Some(Ok(size)); - } - } - } else { - return Some(Err(SysError::ENOTCONN)); - } - None - }) - } else if let SocketType::Udp(_) = wrapper.socket_type { - loop { - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - if socket.is_open() { - let mut slice = unsafe { slice::from_raw_parts_mut(base, len) }; - if let Ok((size, _)) = socket.recv_slice(&mut slice) { - // avoid deadlock - drop(socket); - drop(sockets); - - iface.poll(); - return Ok(size); - } - } else { - return Err(SysError::ENOTCONN); - } - - // avoid deadlock - drop(socket); - SOCKET_ACTIVITY._wait() - } - } else { - unimplemented!("socket type") - } + let mut slice = unsafe { slice::from_raw_parts_mut(base, len) }; + let (result, _) = wrapper.read(&mut slice); + result } pub fn sys_sendto( fd: usize, - buffer: *const u8, + base: *const u8, len: usize, flags: usize, addr: *const SockAddr, addr_len: usize, ) -> SysResult { info!( - "sys_sendto: fd: {} buffer: {:?} len: {} addr: {:?} addr_len: {}", - fd, buffer, len, addr, addr_len + "sys_sendto: fd: {} base: {:?} len: {} addr: {:?} addr_len: {}", + fd, base, len, addr, addr_len ); let mut proc = process(); - proc.memory_set.check_array(buffer, len)?; - - let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?; - - let iface = &*(NET_DRIVERS.read()[0]); + proc.memory_set.check_array(base, len)?; let wrapper = proc.get_socket(fd)?; - if let SocketType::Raw = wrapper.socket_type { - let v4_src = iface.ipv4_address().unwrap(); - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - if let IpAddress::Ipv4(v4_dst) = endpoint.addr { - let slice = unsafe { slice::from_raw_parts(buffer, len) }; - // using 20-byte IPv4 header - let mut buffer = vec![0u8; len + 20]; - let mut packet = Ipv4Packet::new_unchecked(&mut buffer); - packet.set_version(4); - packet.set_header_len(20); - packet.set_total_len((20 + len) as u16); - packet.set_protocol(socket.ip_protocol().into()); - packet.set_src_addr(v4_src); - packet.set_dst_addr(v4_dst); - let payload = packet.payload_mut(); - payload.copy_from_slice(slice); - packet.fill_checksum(); - - socket.send_slice(&buffer).unwrap(); - - // avoid deadlock - drop(socket); - drop(sockets); - iface.poll(); - - Ok(len) - } else { - unimplemented!("ip type") - } - } else if let SocketType::Udp(_) = wrapper.socket_type { - let v4_src = iface.ipv4_address().unwrap(); - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - if socket.endpoint().port == 0 { - let temp_port = get_ephemeral_port(); - socket - .bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port)) - .unwrap(); - } - - let slice = unsafe { slice::from_raw_parts(buffer, len) }; - - socket.send_slice(&slice, endpoint).unwrap(); - - // avoid deadlock - drop(socket); - drop(sockets); - iface.poll(); - - Ok(len) - } else { - unimplemented!("socket type") - } + let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?; + let slice = unsafe { slice::from_raw_parts(base, len) }; + wrapper.write(&slice, Some(endpoint)) } pub fn sys_recvfrom( fd: usize, - buffer: *mut u8, + base: *mut u8, len: usize, flags: usize, addr: *mut SockAddr, addr_len: *mut u32, ) -> SysResult { info!( - "sys_recvfrom: fd: {} buffer: {:?} len: {} flags: {} addr: {:?} addr_len: {:?}", - fd, buffer, len, flags, addr, addr_len + "sys_recvfrom: fd: {} base: {:?} len: {} flags: {} addr: {:?} addr_len: {:?}", + fd, base, len, flags, addr, addr_len ); let mut proc = process(); - proc.memory_set.check_mut_array(buffer, len)?; - - let iface = &*(NET_DRIVERS.read()[0]); + proc.memory_set.check_mut_array(base, len)?; let wrapper = proc.get_socket(fd)?; - // TODO: move some part of these into one generic function - if let SocketType::Raw = wrapper.socket_type { - loop { - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) }; - if let Ok(size) = socket.recv_slice(&mut slice) { - let packet = Ipv4Packet::new_unchecked(&slice); - - if !addr.is_null() { - // FIXME: check size as per sin_family - let sockaddr_in = SockAddr::from(IpEndpoint { - addr: IpAddress::Ipv4(packet.src_addr()), - port: 0, - }); - unsafe { - sockaddr_in.write_to(&mut proc, addr, addr_len)?; - } - } + let mut slice = unsafe { slice::from_raw_parts_mut(base, len) }; + let (result, endpoint) = wrapper.read(&mut slice); - return Ok(size); - } - - // avoid deadlock - drop(socket); - drop(sockets); - SOCKET_ACTIVITY._wait() - } - } else if let SocketType::Udp(_) = wrapper.socket_type { - loop { - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) }; - if let Ok((size, endpoint)) = socket.recv_slice(&mut slice) { - if !addr.is_null() { - let sockaddr_in = SockAddr::from(endpoint); - unsafe { - sockaddr_in.write_to(&mut proc, addr, addr_len)?; - } - } - - return Ok(size); - } - - // avoid deadlock - drop(socket); - drop(sockets); - SOCKET_ACTIVITY._wait() - } - } else if let SocketType::Tcp(_) = wrapper.socket_type { - loop { - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) }; - if let Ok(size) = socket.recv_slice(&mut slice) { - if !addr.is_null() { - let sockaddr_in = SockAddr::from(socket.remote_endpoint()); - unsafe { - sockaddr_in.write_to(&mut proc, addr, addr_len)?; - } - } - - return Ok(size); - } - - // avoid deadlock - drop(socket); - drop(sockets); - SOCKET_ACTIVITY._wait() + if result.is_ok() && !addr.is_null() { + let sockaddr_in = SockAddr::from(endpoint); + unsafe { + sockaddr_in.write_to(&mut proc, addr, addr_len)?; } - } else { - unimplemented!("socket type") } + + result } impl Clone for SocketWrapper { fn clone(&self) -> Self { - let iface = &*(NET_DRIVERS.read()[0]); - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); sockets.retain(self.handle); SocketWrapper { @@ -566,19 +319,6 @@ impl Clone for SocketWrapper { } } -impl Drop for SocketWrapper { - fn drop(&mut self) { - let iface = &*(NET_DRIVERS.read()[0]); - let mut sockets = iface.sockets(); - sockets.release(self.handle); - sockets.prune(); - - // send FIN immediately when applicable - drop(sockets); - iface.poll(); - } -} - pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult { info!("sys_bind: fd: {} addr: {:?} len: {}", fd, addr, addr_len); let mut proc = process(); @@ -589,7 +329,6 @@ pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult } info!("sys_bind: fd: {} bind to {}", fd, endpoint); - let iface = &*(NET_DRIVERS.read()[0]); let wrapper = &mut proc.get_socket_mut(fd)?; if let SocketType::Tcp(_) = wrapper.socket_type { wrapper.socket_type = SocketType::Tcp(TcpSocketState { @@ -598,7 +337,7 @@ pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult }); Ok(0) } else if let SocketType::Udp(_) = wrapper.socket_type { - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(wrapper.handle); match socket.bind(endpoint) { Ok(()) => Ok(0), @@ -615,14 +354,13 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult { // open multiple sockets for each connection let mut proc = process(); - let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket_mut(fd)?; if let SocketType::Tcp(ref mut tcp_state) = wrapper.socket_type { if tcp_state.is_listening { // it is ok to listen twice Ok(0) } else if let Some(local_endpoint) = tcp_state.local_endpoint { - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(wrapper.handle); info!("socket {} listening on {:?}", fd, local_endpoint); @@ -649,10 +387,9 @@ pub fn sys_shutdown(fd: usize, how: usize) -> SysResult { info!("sys_shutdown: fd: {} how: {}", fd, how); let mut proc = process(); - let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket_mut(fd)?; if let SocketType::Tcp(_) = wrapper.socket_type { - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(wrapper.handle); socket.close(); Ok(0) @@ -686,8 +423,7 @@ pub fn sys_accept(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResu if let SocketType::Tcp(tcp_state) = wrapper.socket_type.clone() { if let Some(endpoint) = tcp_state.local_endpoint { loop { - let iface = &*(NET_DRIVERS.read()[0]); - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let socket = sockets.get::(wrapper.handle); if socket.is_active() { @@ -736,14 +472,13 @@ pub fn sys_accept(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResu drop(sockets); drop(proc); - iface.poll(); + poll_ifaces(); return Ok(new_fd); } // avoid deadlock drop(socket); drop(sockets); - drop(iface); SOCKET_ACTIVITY._wait() } } else { @@ -767,7 +502,6 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy return Err(SysError::EINVAL); } - let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket_mut(fd)?; if let SocketType::Tcp(state) = &wrapper.socket_type { if let Some(endpoint) = state.local_endpoint { @@ -777,7 +511,7 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy } Ok(0) } else { - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let socket = sockets.get::(wrapper.handle); let endpoint = socket.local_endpoint(); if endpoint.port != 0 { @@ -791,7 +525,7 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy } } } else if let SocketType::Udp(_) = &wrapper.socket_type { - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let socket = sockets.get::(wrapper.handle); let endpoint = socket.endpoint(); if endpoint.port != 0 { @@ -822,10 +556,9 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy return Err(SysError::EINVAL); } - let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket_mut(fd)?; if let SocketType::Tcp(_) = wrapper.socket_type { - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let socket = sockets.get::(wrapper.handle); if socket.is_open() { @@ -860,8 +593,7 @@ pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) { let mut output = false; let mut err = false; if let SocketType::Tcp(state) = wrapper.socket_type.clone() { - let iface = &*(NET_DRIVERS.read()[0]); - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let socket = sockets.get::(wrapper.handle); if state.is_listening && socket.is_active() { @@ -879,8 +611,7 @@ pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) { } } } else if let SocketType::Udp(_) = wrapper.socket_type { - let iface = &*(NET_DRIVERS.read()[0]); - let mut sockets = iface.sockets(); + let mut sockets = SOCKETS.lock(); let socket = sockets.get::(wrapper.handle); if socket.can_recv() {