diff --git a/kernel/src/fs/file_like.rs b/kernel/src/fs/file_like.rs new file mode 100644 index 0000000..d765f2c --- /dev/null +++ b/kernel/src/fs/file_like.rs @@ -0,0 +1,40 @@ +use core::fmt; + +use super::FileHandle; +use crate::net::Socket; +use crate::syscall::SysResult; +use alloc::boxed::Box; + +// TODO: merge FileLike to FileHandle ? +// TODO: fix dup and remove Clone +#[derive(Clone)] +pub enum FileLike { + File(FileHandle), + Socket(Box), +} + +impl FileLike { + pub fn read(&mut self, buf: &mut [u8]) -> SysResult { + let len = match self { + FileLike::File(file) => file.read(buf)?, + FileLike::Socket(socket) => socket.read(buf).0?, + }; + Ok(len) + } + pub fn write(&mut self, buf: &[u8]) -> SysResult { + let len = match self { + FileLike::File(file) => file.write(buf)?, + FileLike::Socket(socket) => socket.write(buf, None)?, + }; + Ok(len) + } +} + +impl fmt::Debug for FileLike { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + FileLike::File(_) => write!(f, "File"), + FileLike::Socket(_) => write!(f, "Socket"), + } + } +} diff --git a/kernel/src/fs/mod.rs b/kernel/src/fs/mod.rs index cfcafaf..488dcb2 100644 --- a/kernel/src/fs/mod.rs +++ b/kernel/src/fs/mod.rs @@ -7,11 +7,13 @@ use rcore_fs_sfs::SimpleFileSystem; use crate::arch::driver::ide; pub use self::file::*; +pub use self::file_like::*; pub use self::pipe::Pipe; pub use self::stdio::{STDIN, STDOUT}; mod device; mod file; +mod file_like; mod pipe; mod stdio; diff --git a/kernel/src/net/structs.rs b/kernel/src/net/structs.rs index 0295e54..bc076fb 100644 --- a/kernel/src/net/structs.rs +++ b/kernel/src/net/structs.rs @@ -2,68 +2,87 @@ use crate::arch::rand; use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY}; use crate::sync::SpinNoIrqLock as Mutex; use crate::syscall::*; -use alloc::sync::Arc; +use alloc::boxed::Box; use smoltcp::socket::*; use smoltcp::wire::*; +/// +pub trait Socket: Send + Sync { + fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint); + fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult; + fn poll(&self) -> (bool, bool, bool); // (in, out, err) + fn connect(&mut self, endpoint: IpEndpoint) -> SysResult; + fn bind(&mut self, endpoint: IpEndpoint) -> SysResult { + Err(SysError::EINVAL) + } + fn listen(&mut self) -> SysResult { + Err(SysError::EINVAL) + } + fn shutdown(&self) -> SysResult { + Err(SysError::EINVAL) + } + fn accept(&mut self) -> Result<(Box, IpEndpoint), SysError> { + Err(SysError::EINVAL) + } + fn endpoint(&self) -> Option { + None + } + fn remote_endpoint(&self) -> Option { + None + } + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Self { + self.box_clone() + } +} + lazy_static! { - pub static ref SOCKETS: Arc>> = - Arc::new(Mutex::new(SocketSet::new(vec![]))); + /// Global SocketSet in smoltcp. + /// + /// Because smoltcp is a single thread network stack, + /// every socket operation needs to lock this. + pub static ref SOCKETS: Mutex> = + Mutex::new(SocketSet::new(vec![])); } -#[derive(Clone, Debug)] +#[derive(Debug, Clone)] pub struct TcpSocketState { - pub local_endpoint: Option, // save local endpoint for bind() - pub is_listening: bool, + handle: GlobalSocketHandle, + local_endpoint: Option, // save local endpoint for bind() + is_listening: bool, } -#[derive(Clone, Debug)] +#[derive(Debug, Clone)] pub struct UdpSocketState { - pub remote_endpoint: Option, // remember remote endpoint for connect() + handle: GlobalSocketHandle, + remote_endpoint: Option, // remember remote endpoint for connect() } -#[derive(Clone, Debug)] -pub enum SocketType { - Raw, - Tcp(TcpSocketState), - Udp(UdpSocketState), - Icmp, +#[derive(Debug, Clone)] +pub struct RawSocketState { + handle: GlobalSocketHandle, } +/// A wrapper for `SocketHandle`. +/// Auto increase and decrease reference count on Clone and Drop. #[derive(Debug)] -pub struct SocketWrapper { - pub handle: SocketHandle, - pub socket_type: SocketType, -} - -pub fn get_ephemeral_port() -> u16 { - // TODO selects non-conflict high port - static mut EPHEMERAL_PORT: u16 = 0; - unsafe { - if EPHEMERAL_PORT == 0 { - EPHEMERAL_PORT = (49152 + rand::rand() % (65536 - 49152)) as u16; - } - if EPHEMERAL_PORT == 65535 { - EPHEMERAL_PORT = 49152; - } else { - EPHEMERAL_PORT = EPHEMERAL_PORT + 1; - } - EPHEMERAL_PORT - } -} +struct GlobalSocketHandle(SocketHandle); -/// Safety: call this without SOCKETS locked -pub fn poll_ifaces() { - for iface in NET_DRIVERS.read().iter() { - iface.poll(); +impl Clone for GlobalSocketHandle { + fn clone(&self) -> Self { + SOCKETS.lock().retain(self.0); + Self(self.0) } } -impl Drop for SocketWrapper { +impl Drop for GlobalSocketHandle { fn drop(&mut self) { let mut sockets = SOCKETS.lock(); - sockets.release(self.handle); + sockets.release(self.0); sockets.prune(); // send FIN immediately when applicable @@ -72,97 +91,61 @@ impl Drop for SocketWrapper { } } -impl SocketWrapper { - pub fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { - if let SocketType::Raw = self.socket_type { - if let Some(endpoint) = sendto_endpoint { - // temporary solution - let iface = &*(NET_DRIVERS.read()[0]); - let v4_src = iface.ipv4_address().unwrap(); - let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(self.handle); - - if let IpAddress::Ipv4(v4_dst) = endpoint.addr { - let len = data.len(); - // using 20-byte IPv4 header - let mut buffer = vec![0u8; len + 20]; - let mut packet = Ipv4Packet::new_unchecked(&mut buffer); - packet.set_version(4); - packet.set_header_len(20); - packet.set_total_len((20 + len) as u16); - packet.set_protocol(socket.ip_protocol().into()); - packet.set_src_addr(v4_src); - packet.set_dst_addr(v4_dst); - let payload = packet.payload_mut(); - payload.copy_from_slice(data); - packet.fill_checksum(); - - socket.send_slice(&buffer).unwrap(); +impl TcpSocketState { + pub fn new() -> Self { + let rx_buffer = TcpSocketBuffer::new(vec![0; TCP_RECVBUF]); + let tx_buffer = TcpSocketBuffer::new(vec![0; TCP_SENDBUF]); + let socket = TcpSocket::new(rx_buffer, tx_buffer); + let handle = GlobalSocketHandle(SOCKETS.lock().add(socket)); - // avoid deadlock - drop(socket); - drop(sockets); - iface.poll(); + TcpSocketState { + handle, + local_endpoint: None, + is_listening: false, + } + } +} - Ok(len) - } else { - unimplemented!("ip type") - } - } else { - Err(SysError::ENOTCONN) - } - } else if let SocketType::Tcp(_) = self.socket_type { +impl Socket for TcpSocketState { + fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) { + spin_and_wait(&[&SOCKET_ACTIVITY], move || { + poll_ifaces(); let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(self.handle); + let mut socket = sockets.get::(self.handle.0); if socket.is_open() { - if socket.can_send() { - match socket.send_slice(&data) { - Ok(size) => { - // avoid deadlock - drop(socket); - drop(sockets); + if let Ok(size) = socket.recv_slice(data) { + if size > 0 { + let endpoint = socket.remote_endpoint(); + // avoid deadlock + drop(socket); + drop(sockets); - poll_ifaces(); - Ok(size) - } - Err(err) => Err(SysError::ENOBUFS), + poll_ifaces(); + return Some((Ok(size), endpoint)); } - } else { - Err(SysError::ENOBUFS) } } else { - Err(SysError::ENOTCONN) + return Some((Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED)); } - } else if let SocketType::Udp(ref state) = self.socket_type { - let remote_endpoint = { - if let Some(ref endpoint) = sendto_endpoint { - endpoint - } else if let Some(ref endpoint) = state.remote_endpoint { - endpoint - } else { - return Err(SysError::ENOTCONN); - } - }; - let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(self.handle); + None + }) + } - if socket.endpoint().port == 0 { - let temp_port = get_ephemeral_port(); - socket - .bind(IpEndpoint::new(IpAddress::Unspecified, temp_port)) - .unwrap(); - } + fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + if socket.is_open() { if socket.can_send() { - match socket.send_slice(&data, *remote_endpoint) { - Ok(()) => { + match socket.send_slice(&data) { + Ok(size) => { // avoid deadlock drop(socket); drop(sockets); poll_ifaces(); - Ok(data.len()) + Ok(size) } Err(err) => Err(SysError::ENOBUFS), } @@ -170,81 +153,433 @@ impl SocketWrapper { Err(SysError::ENOBUFS) } } else { - unimplemented!("socket type") + Err(SysError::ENOTCONN) } } - pub fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) { - if let SocketType::Raw = self.socket_type { - loop { - let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(self.handle); + fn poll(&self) -> (bool, bool, bool) { + let mut sockets = SOCKETS.lock(); + let socket = sockets.get::(self.handle.0); - if let Ok(size) = socket.recv_slice(data) { - let packet = Ipv4Packet::new_unchecked(data); - - return ( - Ok(size), - IpEndpoint { - addr: IpAddress::Ipv4(packet.src_addr()), - port: 0, - }, - ); - } + let (mut input, mut output, mut err) = (false, false, false); + if self.is_listening && socket.is_active() { + // a new connection + input = true; + } else if !socket.is_open() { + err = true; + } else { + if socket.can_recv() { + input = true; + } + if socket.can_send() { + output = true; + } + } + (input, output, err) + } + + fn connect(&mut self, endpoint: IpEndpoint) -> SysResult { + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + let temp_port = get_ephemeral_port(); + + match socket.connect(endpoint, temp_port) { + Ok(()) => { // avoid deadlock drop(socket); drop(sockets); - SOCKET_ACTIVITY._wait() - } - } else if let SocketType::Tcp(_) = self.socket_type { - spin_and_wait(&[&SOCKET_ACTIVITY], move || { - poll_ifaces(); - let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(self.handle); - - if socket.is_open() { - if let Ok(size) = socket.recv_slice(data) { - if size > 0 { - let endpoint = socket.remote_endpoint(); - // avoid deadlock + + // wait for connection result + loop { + poll_ifaces(); + + let mut sockets = SOCKETS.lock(); + let socket = sockets.get::(self.handle.0); + match socket.state() { + TcpState::SynSent => { + // still connecting drop(socket); drop(sockets); - - poll_ifaces(); - return Some((Ok(size), endpoint)); + debug!("poll for connection wait"); + SOCKET_ACTIVITY._wait(); + } + TcpState::Established => { + break Ok(0); + } + _ => { + break Err(SysError::ECONNREFUSED); } } - } else { - return Some((Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED)); } + } + Err(_) => Err(SysError::ENOBUFS), + } + } + + fn bind(&mut self, mut endpoint: IpEndpoint) -> SysResult { + if endpoint.port == 0 { + endpoint.port = get_ephemeral_port(); + } + self.local_endpoint = Some(endpoint); + self.is_listening = false; + Ok(0) + } + + fn listen(&mut self) -> SysResult { + if self.is_listening { + // it is ok to listen twice + return Ok(0); + } + let local_endpoint = self.local_endpoint.ok_or(SysError::EINVAL)?; + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + + info!("socket listening on {:?}", local_endpoint); + if socket.is_listening() { + return Ok(0); + } + match socket.listen(local_endpoint) { + Ok(()) => { + self.is_listening = true; + Ok(0) + } + Err(_) => Err(SysError::EINVAL), + } + } + + fn shutdown(&self) -> SysResult { + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + socket.close(); + Ok(0) + } + + fn accept(&mut self) -> Result<(Box, IpEndpoint), SysError> { + let endpoint = self.local_endpoint.ok_or(SysError::EINVAL)?; + loop { + let mut sockets = SOCKETS.lock(); + let socket = sockets.get::(self.handle.0); + + if socket.is_active() { + let remote_endpoint = socket.remote_endpoint(); + drop(socket); + + let new_socket = { + let rx_buffer = TcpSocketBuffer::new(vec![0; TCP_RECVBUF]); + let tx_buffer = TcpSocketBuffer::new(vec![0; TCP_SENDBUF]); + let mut socket = TcpSocket::new(rx_buffer, tx_buffer); + socket.listen(endpoint).unwrap(); + let new_handle = GlobalSocketHandle(sockets.add(socket)); + let old_handle = ::core::mem::replace(&mut self.handle, new_handle); + + Box::new(TcpSocketState { + handle: old_handle, + local_endpoint: self.local_endpoint, + is_listening: false, + }) + }; + + drop(sockets); + poll_ifaces(); + return Ok((new_socket, remote_endpoint)); + } + + // avoid deadlock + drop(socket); + drop(sockets); + SOCKET_ACTIVITY._wait(); + } + } + + fn endpoint(&self) -> Option { + self.local_endpoint.clone().or_else(|| { + let mut sockets = SOCKETS.lock(); + let socket = sockets.get::(self.handle.0); + let endpoint = socket.local_endpoint(); + if endpoint.port != 0 { + Some(endpoint) + } else { None - }) - } else if let SocketType::Udp(ref state) = self.socket_type { - loop { - let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(self.handle); - - if socket.is_open() { - if let Ok((size, remote_endpoint)) = socket.recv_slice(data) { - let endpoint = remote_endpoint; - // avoid deadlock - drop(socket); - drop(sockets); + } + }) + } - poll_ifaces(); - return (Ok(size), endpoint); - } - } else { - return (Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED); + fn remote_endpoint(&self) -> Option { + let mut sockets = SOCKETS.lock(); + let socket = sockets.get::(self.handle.0); + if socket.is_open() { + Some(socket.remote_endpoint()) + } else { + None + } + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } +} + +impl UdpSocketState { + pub fn new() -> Self { + let rx_buffer = UdpSocketBuffer::new( + vec![UdpPacketMetadata::EMPTY; UDP_METADATA_BUF], + vec![0; UDP_RECVBUF], + ); + let tx_buffer = UdpSocketBuffer::new( + vec![UdpPacketMetadata::EMPTY; UDP_METADATA_BUF], + vec![0; UDP_SENDBUF], + ); + let socket = UdpSocket::new(rx_buffer, tx_buffer); + let handle = GlobalSocketHandle(SOCKETS.lock().add(socket)); + + UdpSocketState { + handle, + remote_endpoint: None, + } + } +} + +impl Socket for UdpSocketState { + fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) { + loop { + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + + if socket.is_open() { + if let Ok((size, remote_endpoint)) = socket.recv_slice(data) { + let endpoint = remote_endpoint; + // avoid deadlock + drop(socket); + drop(sockets); + + poll_ifaces(); + return (Ok(size), endpoint); + } + } else { + return (Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED); + } + + // avoid deadlock + drop(socket); + SOCKET_ACTIVITY._wait() + } + } + + fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { + let remote_endpoint = { + if let Some(ref endpoint) = sendto_endpoint { + endpoint + } else if let Some(ref endpoint) = self.remote_endpoint { + endpoint + } else { + return Err(SysError::ENOTCONN); + } + }; + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + + if socket.endpoint().port == 0 { + let temp_port = get_ephemeral_port(); + socket + .bind(IpEndpoint::new(IpAddress::Unspecified, temp_port)) + .unwrap(); + } + + if socket.can_send() { + match socket.send_slice(&data, *remote_endpoint) { + Ok(()) => { + // avoid deadlock + drop(socket); + drop(sockets); + + poll_ifaces(); + Ok(data.len()) } + Err(err) => Err(SysError::ENOBUFS), + } + } else { + Err(SysError::ENOBUFS) + } + } + + fn poll(&self) -> (bool, bool, bool) { + let mut sockets = SOCKETS.lock(); + let socket = sockets.get::(self.handle.0); + + let (mut input, mut output, err) = (false, false, false); + if socket.can_recv() { + input = true; + } + if socket.can_send() { + output = true; + } + (input, output, err) + } + + fn connect(&mut self, endpoint: IpEndpoint) -> SysResult { + self.remote_endpoint = Some(endpoint); + Ok(0) + } + + fn bind(&mut self, endpoint: IpEndpoint) -> SysResult { + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + match socket.bind(endpoint) { + Ok(()) => Ok(0), + Err(_) => Err(SysError::EINVAL), + } + } + + fn endpoint(&self) -> Option { + let mut sockets = SOCKETS.lock(); + let socket = sockets.get::(self.handle.0); + let endpoint = socket.endpoint(); + if endpoint.port != 0 { + Some(endpoint) + } else { + None + } + } + + fn remote_endpoint(&self) -> Option { + self.remote_endpoint.clone() + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } +} + +impl RawSocketState { + pub fn new(protocol: u8) -> Self { + let rx_buffer = RawSocketBuffer::new( + vec![RawPacketMetadata::EMPTY; RAW_METADATA_BUF], + vec![0; RAW_RECVBUF], + ); + let tx_buffer = RawSocketBuffer::new( + vec![RawPacketMetadata::EMPTY; RAW_METADATA_BUF], + vec![0; RAW_SENDBUF], + ); + let socket = RawSocket::new( + IpVersion::Ipv4, + IpProtocol::from(protocol), + rx_buffer, + tx_buffer, + ); + let handle = GlobalSocketHandle(SOCKETS.lock().add(socket)); + + RawSocketState { handle } + } +} + +impl Socket for RawSocketState { + fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) { + loop { + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + + if let Ok(size) = socket.recv_slice(data) { + let packet = Ipv4Packet::new_unchecked(data); + + return ( + Ok(size), + IpEndpoint { + addr: IpAddress::Ipv4(packet.src_addr()), + port: 0, + }, + ); + } + + // avoid deadlock + drop(socket); + drop(sockets); + SOCKET_ACTIVITY._wait() + } + } + + fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { + if let Some(endpoint) = sendto_endpoint { + // temporary solution + let iface = &*(NET_DRIVERS.read()[0]); + let v4_src = iface.ipv4_address().unwrap(); + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + + if let IpAddress::Ipv4(v4_dst) = endpoint.addr { + let len = data.len(); + // using 20-byte IPv4 header + let mut buffer = vec![0u8; len + 20]; + let mut packet = Ipv4Packet::new_unchecked(&mut buffer); + packet.set_version(4); + packet.set_header_len(20); + packet.set_total_len((20 + len) as u16); + packet.set_protocol(socket.ip_protocol().into()); + packet.set_src_addr(v4_src); + packet.set_dst_addr(v4_dst); + let payload = packet.payload_mut(); + payload.copy_from_slice(data); + packet.fill_checksum(); + + socket.send_slice(&buffer).unwrap(); // avoid deadlock drop(socket); - SOCKET_ACTIVITY._wait() + drop(sockets); + iface.poll(); + + Ok(len) + } else { + unimplemented!("ip type") } } else { - unimplemented!("socket type") + Err(SysError::ENOTCONN) } } + + fn poll(&self) -> (bool, bool, bool) { + unimplemented!() + } + + fn connect(&mut self, _endpoint: IpEndpoint) -> SysResult { + unimplemented!() + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } } + +fn get_ephemeral_port() -> u16 { + // TODO selects non-conflict high port + static mut EPHEMERAL_PORT: u16 = 0; + unsafe { + if EPHEMERAL_PORT == 0 { + EPHEMERAL_PORT = (49152 + rand::rand() % (65536 - 49152)) as u16; + } + if EPHEMERAL_PORT == 65535 { + EPHEMERAL_PORT = 49152; + } else { + EPHEMERAL_PORT = EPHEMERAL_PORT + 1; + } + EPHEMERAL_PORT + } +} + +/// Safety: call this without SOCKETS locked +fn poll_ifaces() { + for iface in NET_DRIVERS.read().iter() { + iface.poll(); + } +} + +pub const TCP_SENDBUF: usize = 512 * 1024; // 512K +pub const TCP_RECVBUF: usize = 512 * 1024; // 512K + +const UDP_METADATA_BUF: usize = 1024; +const UDP_SENDBUF: usize = 64 * 1024; // 64K +const UDP_RECVBUF: usize = 64 * 1024; // 64K + +const RAW_METADATA_BUF: usize = 2; +const RAW_SENDBUF: usize = 2 * 1024; // 2K +const RAW_RECVBUF: usize = 2 * 1024; // 2K diff --git a/kernel/src/process/mod.rs b/kernel/src/process/mod.rs index e476fd4..ff69345 100644 --- a/kernel/src/process/mod.rs +++ b/kernel/src/process/mod.rs @@ -1,10 +1,10 @@ pub use self::structs::*; use crate::arch::cpu; use crate::consts::{MAX_CPU_NUM, MAX_PROCESS_NUM}; +use crate::sync::{MutexGuard, SpinNoIrq}; use alloc::{boxed::Box, sync::Arc}; use log::*; pub use rcore_thread::*; -use spin::MutexGuard; mod abi; pub mod structs; @@ -37,13 +37,13 @@ static PROCESSORS: [Processor; MAX_CPU_NUM] = [ ]; /// Get current process -pub fn process() -> MutexGuard<'static, Process> { +pub fn process() -> MutexGuard<'static, Process, SpinNoIrq> { current_thread().proc.lock() } /// Get current process, ignoring its lock /// Only use this when necessary -pub unsafe fn process_unsafe() -> MutexGuard<'static, Process> { +pub unsafe fn process_unsafe() -> MutexGuard<'static, Process, SpinNoIrq> { let thread = current_thread(); thread.proc.force_unlock(); thread.proc.lock() diff --git a/kernel/src/process/structs.rs b/kernel/src/process/structs.rs index 038eee0..93820af 100644 --- a/kernel/src/process/structs.rs +++ b/kernel/src/process/structs.rs @@ -5,7 +5,7 @@ use core::str; use log::*; use rcore_memory::PAGE_SIZE; use rcore_thread::Tid; -use spin::{Mutex, RwLock}; +use spin::RwLock; use xmas_elf::{ header, program::{Flags, SegmentData, Type}, @@ -13,10 +13,10 @@ use xmas_elf::{ }; use crate::arch::interrupt::{Context, TrapFrame}; -use crate::fs::{FileHandle, INodeExt, OpenOptions, FOLLOW_MAX_DEPTH}; +use crate::fs::{FileHandle, FileLike, INodeExt, OpenOptions, FOLLOW_MAX_DEPTH}; use crate::memory::{ByFrame, GlobalFrameAlloc, KernelStack, MemoryAttr, MemorySet}; -use crate::net::{SocketWrapper, SOCKETS}; -use crate::sync::Condvar; +use crate::net::{Socket, SOCKETS}; +use crate::sync::{Condvar, SpinNoIrqLock as Mutex}; use super::abi::{self, ProcInitInfo}; @@ -30,21 +30,6 @@ pub struct Thread { pub proc: Arc>, } -#[derive(Clone)] -pub enum FileLike { - File(FileHandle), - Socket(SocketWrapper), -} - -impl fmt::Debug for FileLike { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - FileLike::File(_) => write!(f, "File"), - FileLike::Socket(wrapper) => write!(f, "{:?}", wrapper), - } - } -} - /// Pid type /// For strong type separation #[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] @@ -343,13 +328,6 @@ impl Thread { debug!("fork: temporary copy data!"); let kstack = KernelStack::new(); - let mut sockets = SOCKETS.lock(); - for (_fd, file) in files.iter() { - if let FileLike::Socket(wrapper) = file { - sockets.retain(wrapper.handle); - } - } - Box::new(Thread { context: unsafe { Context::new_fork(tf, kstack.top(), vm.token()) }, kstack, diff --git a/kernel/src/syscall/fs.rs b/kernel/src/syscall/fs.rs index 6f9ca29..b16fe73 100644 --- a/kernel/src/syscall/fs.rs +++ b/kernel/src/syscall/fs.rs @@ -19,11 +19,10 @@ pub fn sys_read(fd: usize, base: *mut u8, len: usize) -> SysResult { info!("read: fd: {}, base: {:?}, len: {:#x}", fd, base, len); } proc.vm.check_write_array(base, len)?; - match proc.files.get(&fd) { - Some(FileLike::File(_)) => sys_read_file(&mut proc, fd, base, len), - Some(FileLike::Socket(_)) => sys_read_socket(&mut proc, fd, base, len), - None => Err(SysError::EINVAL), - } + let slice = unsafe { slice::from_raw_parts_mut(base, len) }; + let file_like = proc.get_file_like(fd)?; + let len = file_like.read(slice)?; + Ok(len) } pub fn sys_write(fd: usize, base: *const u8, len: usize) -> SysResult { @@ -33,12 +32,10 @@ pub fn sys_write(fd: usize, base: *const u8, len: usize) -> SysResult { info!("write: fd: {}, base: {:?}, len: {:#x}", fd, base, len); } proc.vm.check_read_array(base, len)?; - - match proc.files.get(&fd) { - Some(FileLike::File(_)) => sys_write_file(&mut proc, fd, base, len), - Some(FileLike::Socket(_)) => sys_write_socket(&mut proc, fd, base, len), - None => Err(SysError::EINVAL), - } + let slice = unsafe { slice::from_raw_parts(base, len) }; + let file_like = proc.get_file_like(fd)?; + let len = file_like.write(slice)?; + Ok(len) } pub fn sys_pread(fd: usize, base: *mut u8, len: usize, offset: usize) -> SysResult { @@ -67,18 +64,6 @@ pub fn sys_pwrite(fd: usize, base: *const u8, len: usize, offset: usize) -> SysR Ok(len) } -pub fn sys_read_file(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult { - let slice = unsafe { slice::from_raw_parts_mut(base, len) }; - let len = proc.get_file(fd)?.read(slice)?; - Ok(len) -} - -pub fn sys_write_file(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult { - let slice = unsafe { slice::from_raw_parts(base, len) }; - let len = proc.get_file(fd)?.write(slice)?; - Ok(len) -} - pub fn sys_poll(ufds: *mut PollFd, nfds: usize, timeout_msecs: usize) -> SysResult { info!( "poll: ufds: {:?}, nfds: {}, timeout_msecs: {:#x}", @@ -110,8 +95,8 @@ pub fn sys_poll(ufds: *mut PollFd, nfds: usize, timeout_msecs: usize) -> SysResu events = events + 1; } } - Some(FileLike::Socket(wrapper)) => { - let (input, output, err) = poll_socket(&wrapper); + Some(FileLike::Socket(socket)) => { + let (input, output, err) = socket.poll(); if err { poll.revents = poll.revents | PE::HUP; events = events + 1; @@ -146,61 +131,6 @@ pub fn sys_poll(ufds: *mut PollFd, nfds: usize, timeout_msecs: usize) -> SysResu } } -const FD_PER_ITEM: usize = 8 * size_of::(); -const MAX_FDSET_SIZE: usize = 1024 / FD_PER_ITEM; - -struct FdSet { - addr: *mut u32, - nfds: usize, - saved: [u32; MAX_FDSET_SIZE], -} - -impl FdSet { - /// Initialize a `FdSet` from pointer and number of fds - /// Check if the array is large enough - fn new(vm: &MemorySet, addr: *mut u32, nfds: usize) -> Result { - let mut saved = [0u32; MAX_FDSET_SIZE]; - if addr as usize != 0 { - let len = (nfds + FD_PER_ITEM - 1) / FD_PER_ITEM; - vm.check_write_array(addr, len)?; - if len > MAX_FDSET_SIZE { - return Err(SysError::EINVAL); - } - let slice = unsafe { slice::from_raw_parts_mut(addr, len) }; - - // save the fdset, and clear it - for i in 0..len { - saved[i] = slice[i]; - slice[i] = 0; - } - } - - Ok(FdSet { addr, nfds, saved }) - } - - /// Try to set fd in `FdSet` - /// Return true when `FdSet` is valid, and false when `FdSet` is bad (i.e. null pointer) - /// Fd should be less than nfds - fn set(&mut self, fd: usize) -> bool { - if self.addr as usize != 0 { - assert!(fd < self.nfds); - unsafe { - *self.addr.add(fd / 8 / size_of::()) |= 1 << (fd % (8 * size_of::())); - } - true - } else { - false - } - } - - /// Check to see fd is see in original `FdSet` - /// Fd should be less than nfds - fn is_set(&mut self, fd: usize) -> bool { - assert!(fd < self.nfds); - self.saved[fd / 8 / size_of::()] & (1 << (fd % (8 * size_of::()))) != 0 - } -} - pub fn sys_select( nfds: usize, read: *mut u32, @@ -230,9 +160,9 @@ pub fn sys_select( loop { let proc = process(); let mut events = 0; - for (fd, file) in proc.files.iter() { + for (fd, file_like) in proc.files.iter() { if *fd < nfds { - match file { + match file_like { FileLike::File(_) => { // FIXME: assume it is stdin for now if STDIN.can_read() { @@ -242,8 +172,8 @@ pub fn sys_select( } } } - FileLike::Socket(wrapper) => { - let (input, output, err) = poll_socket(&wrapper); + FileLike::Socket(socket) => { + let (input, output, err) = socket.poll(); if err && err_fds.is_set(*fd) { err_fds.set(*fd); events = events + 1; @@ -290,9 +220,9 @@ pub fn sys_readv(fd: usize, iov_ptr: *const IoVec, iov_count: usize) -> SysResul let mut iovs = IoVecs::check_and_new(iov_ptr, iov_count, &proc.vm, true)?; // read all data to a buf - let mut file = proc.get_file(fd)?; + let file_like = proc.get_file_like(fd)?; let mut buf = iovs.new_buf(true); - let len = file.read(buf.as_mut_slice())?; + let len = file_like.read(buf.as_mut_slice())?; // copy data to user iovs.write_all_from_slice(&buf[..len]); Ok(len) @@ -309,15 +239,11 @@ pub fn sys_writev(fd: usize, iov_ptr: *const IoVec, iov_count: usize) -> SysResu let buf = iovs.read_all_to_vec(); let len = buf.len(); - match proc.files.get(&fd) { - Some(FileLike::File(_)) => sys_write_file(&mut proc, fd, buf.as_ptr(), len), - Some(FileLike::Socket(_)) => sys_write_socket(&mut proc, fd, buf.as_ptr(), len), - None => Err(SysError::EINVAL), - } + let file_like = proc.get_file_like(fd)?; + let len = file_like.write(buf.as_slice())?; + Ok(len) } -const AT_FDCWD: usize = -100isize as usize; - pub fn sys_open(path: *const u8, flags: usize, mode: usize) -> SysResult { sys_openat(AT_FDCWD, path, flags, mode) } @@ -539,18 +465,9 @@ pub fn sys_dup2(fd1: usize, fd2: usize) -> SysResult { // close fd2 first if it is opened proc.files.remove(&fd2); - match proc.files.get(&fd1) { - Some(FileLike::File(file)) => { - let new_file = FileLike::File(file.clone()); - proc.files.insert(fd2, new_file); - Ok(fd2) - } - Some(FileLike::Socket(wrapper)) => { - let new_wrapper = wrapper.clone(); - sys_dup2_socket(&mut proc, new_wrapper, fd2) - } - None => Err(SysError::EINVAL), - } + let file_like = proc.get_file_like(fd1)?.clone(); + proc.files.insert(fd2, file_like); + Ok(fd2) } pub fn sys_chdir(path: *const u8) -> SysResult { @@ -785,14 +702,14 @@ pub fn sys_sendfile(out_fd: usize, in_fd: usize, offset: *mut usize, count: usiz } impl Process { + pub fn get_file_like(&mut self, fd: usize) -> Result<&mut FileLike, SysError> { + self.files.get_mut(&fd).ok_or(SysError::EBADF) + } pub fn get_file(&mut self, fd: usize) -> Result<&mut FileHandle, SysError> { - self.files - .get_mut(&fd) - .ok_or(SysError::EBADF) - .and_then(|f| match f { - FileLike::File(file) => Ok(file), - _ => Err(SysError::EBADF), - }) + match self.get_file_like(fd)? { + FileLike::File(file) => Ok(file), + _ => Err(SysError::EBADF), + } } pub fn lookup_inode(&self, path: &str) -> Result, SysError> { debug!("lookup_inode: cwd {} path {}", self.cwd, path); @@ -1250,3 +1167,60 @@ bitflags! { const INVAL = 0x0020; } } + +const FD_PER_ITEM: usize = 8 * size_of::(); +const MAX_FDSET_SIZE: usize = 1024 / FD_PER_ITEM; + +struct FdSet { + addr: *mut u32, + nfds: usize, + saved: [u32; MAX_FDSET_SIZE], +} + +impl FdSet { + /// Initialize a `FdSet` from pointer and number of fds + /// Check if the array is large enough + fn new(vm: &MemorySet, addr: *mut u32, nfds: usize) -> Result { + let mut saved = [0u32; MAX_FDSET_SIZE]; + if addr as usize != 0 { + let len = (nfds + FD_PER_ITEM - 1) / FD_PER_ITEM; + vm.check_write_array(addr, len)?; + if len > MAX_FDSET_SIZE { + return Err(SysError::EINVAL); + } + let slice = unsafe { slice::from_raw_parts_mut(addr, len) }; + + // save the fdset, and clear it + for i in 0..len { + saved[i] = slice[i]; + slice[i] = 0; + } + } + + Ok(FdSet { addr, nfds, saved }) + } + + /// Try to set fd in `FdSet` + /// Return true when `FdSet` is valid, and false when `FdSet` is bad (i.e. null pointer) + /// Fd should be less than nfds + fn set(&mut self, fd: usize) -> bool { + if self.addr as usize != 0 { + assert!(fd < self.nfds); + unsafe { + *self.addr.add(fd / 8 / size_of::()) |= 1 << (fd % (8 * size_of::())); + } + true + } else { + false + } + } + + /// Check to see fd is see in original `FdSet` + /// Fd should be less than nfds + fn is_set(&mut self, fd: usize) -> bool { + assert!(fd < self.nfds); + self.saved[fd / 8 / size_of::()] & (1 << (fd % (8 * size_of::()))) != 0 + } +} + +const AT_FDCWD: usize = -100isize as usize; diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index a955b9e..2811f12 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -2,116 +2,32 @@ use super::*; use crate::drivers::SOCKET_ACTIVITY; -use crate::net::{ - get_ephemeral_port, poll_ifaces, SocketType, SocketWrapper, TcpSocketState, UdpSocketState, - SOCKETS, -}; +use crate::fs::FileLike; +use crate::net::{RawSocketState, Socket, TcpSocketState, UdpSocketState, SOCKETS}; +use crate::sync::{MutexGuard, SpinNoIrq, SpinNoIrqLock as Mutex}; +use alloc::boxed::Box; use core::cmp::min; use core::mem::size_of; -use smoltcp::socket::*; use smoltcp::wire::*; -const AF_UNIX: usize = 1; -const AF_INET: usize = 2; - -const SOCK_STREAM: usize = 1; -const SOCK_DGRAM: usize = 2; -const SOCK_RAW: usize = 3; -const SOCK_TYPE_MASK: usize = 0xf; - -const IPPROTO_IP: usize = 0; -const IPPROTO_ICMP: usize = 1; -const IPPROTO_TCP: usize = 6; - -const TCP_SENDBUF: usize = 512 * 1024; // 512K -const TCP_RECVBUF: usize = 512 * 1024; // 512K - -const UDP_SENDBUF: usize = 64 * 1024; // 64K -const UDP_RECVBUF: usize = 64 * 1024; // 64K - pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResult { info!( "socket: domain: {}, socket_type: {}, protocol: {}", domain, socket_type, protocol ); let mut proc = process(); - match domain { + let socket: Box = match domain { AF_INET | AF_UNIX => match socket_type & SOCK_TYPE_MASK { - SOCK_STREAM => { - let fd = proc.get_free_fd(); - - let tcp_rx_buffer = TcpSocketBuffer::new(vec![0; TCP_RECVBUF]); - let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; TCP_SENDBUF]); - let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - - let tcp_handle = SOCKETS.lock().add(tcp_socket); - proc.files.insert( - fd, - FileLike::Socket(SocketWrapper { - handle: tcp_handle, - socket_type: SocketType::Tcp(TcpSocketState { - local_endpoint: None, - is_listening: false, - }), - }), - ); - - Ok(fd) - } - SOCK_DGRAM => { - let fd = proc.get_free_fd(); - - let udp_rx_buffer = UdpSocketBuffer::new( - vec![UdpPacketMetadata::EMPTY; 1024], - vec![0; UDP_RECVBUF], - ); - let udp_tx_buffer = UdpSocketBuffer::new( - vec![UdpPacketMetadata::EMPTY; 1024], - vec![0; UDP_SENDBUF], - ); - let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); - - let udp_handle = SOCKETS.lock().add(udp_socket); - proc.files.insert( - fd, - FileLike::Socket(SocketWrapper { - handle: udp_handle, - socket_type: SocketType::Udp(UdpSocketState { - remote_endpoint: None, - }), - }), - ); - - Ok(fd) - } - SOCK_RAW => { - let fd = proc.get_free_fd(); - - let raw_rx_buffer = - RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; 2], vec![0; 2048]); - let raw_tx_buffer = - RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; 2], vec![0; 2048]); - let raw_socket = RawSocket::new( - IpVersion::Ipv4, - IpProtocol::from(protocol as u8), - raw_rx_buffer, - raw_tx_buffer, - ); - - let raw_handle = SOCKETS.lock().add(raw_socket); - proc.files.insert( - fd, - FileLike::Socket(SocketWrapper { - handle: raw_handle, - socket_type: SocketType::Raw, - }), - ); - Ok(fd) - } - _ => Err(SysError::EINVAL), + SOCK_STREAM => Box::new(TcpSocketState::new()), + SOCK_DGRAM => Box::new(UdpSocketState::new()), + SOCK_RAW => Box::new(RawSocketState::new(protocol as u8)), + _ => return Err(SysError::EINVAL), }, - _ => Err(SysError::EAFNOSUPPORT), - } + _ => return Err(SysError::EAFNOSUPPORT), + }; + let fd = proc.get_free_fd(); + proc.files.insert(fd, FileLike::Socket(socket)); + Ok(fd) } pub fn sys_setsockopt( @@ -129,13 +45,6 @@ pub fn sys_setsockopt( Ok(0) } -const SOL_SOCKET: usize = 1; -const SO_SNDBUF: usize = 7; -const SO_RCVBUF: usize = 8; -const SO_LINGER: usize = 13; - -const TCP_CONGESTION: usize = 13; - pub fn sys_getsockopt( fd: usize, level: usize, @@ -154,7 +63,7 @@ pub fn sys_getsockopt( SO_SNDBUF => { proc.vm.check_write_array(optval, 4)?; unsafe { - *(optval as *mut u32) = TCP_SENDBUF as u32; + *(optval as *mut u32) = crate::net::TCP_SENDBUF as u32; *optlen = 4; } Ok(0) @@ -162,7 +71,7 @@ pub fn sys_getsockopt( SO_RCVBUF => { proc.vm.check_write_array(optval, 4)?; unsafe { - *(optval as *mut u32) = TCP_RECVBUF as u32; + *(optval as *mut u32) = crate::net::TCP_RECVBUF as u32; *optlen = 4; } Ok(0) @@ -177,24 +86,6 @@ pub fn sys_getsockopt( } } -impl Process { - fn get_socket(&mut self, fd: usize) -> Result { - let file = self.files.get_mut(&fd).ok_or(SysError::EBADF)?; - match file { - FileLike::Socket(wrapper) => Ok(wrapper.clone()), - _ => Err(SysError::ENOTSOCK), - } - } - - fn get_socket_mut(&mut self, fd: usize) -> Result<&mut SocketWrapper, SysError> { - let file = self.files.get_mut(&fd).ok_or(SysError::EBADF)?; - match file { - FileLike::Socket(ref mut wrapper) => Ok(wrapper), - _ => Err(SysError::ENOTSOCK), - } - } -} - pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult { info!( "sys_connect: fd: {}, addr: {:?}, addr_len: {}", @@ -202,64 +93,10 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu ); let mut proc = process(); - let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?; - - let wrapper = &mut proc.get_socket_mut(fd)?; - if let SocketType::Tcp(_) = wrapper.socket_type { - let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(wrapper.handle); - - let temp_port = get_ephemeral_port(); - - match socket.connect(endpoint, temp_port) { - Ok(()) => { - // avoid deadlock - drop(socket); - drop(sockets); - - // wait for connection result - loop { - poll_ifaces(); - - let mut sockets = SOCKETS.lock(); - let socket = sockets.get::(wrapper.handle); - if socket.state() == TcpState::SynSent { - // still connecting - drop(socket); - drop(sockets); - debug!("poll for connection wait"); - SOCKET_ACTIVITY._wait(); - } else if socket.state() == TcpState::Established { - break Ok(0); - } else { - break Err(SysError::ECONNREFUSED); - } - } - } - Err(_) => Err(SysError::ENOBUFS), - } - } else if let SocketType::Udp(_) = wrapper.socket_type { - wrapper.socket_type = SocketType::Udp(UdpSocketState { - remote_endpoint: Some(endpoint), - }); - Ok(0) - } else { - unimplemented!("socket type") - } -} - -pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult { - let wrapper = proc.get_socket(fd)?; - let slice = unsafe { slice::from_raw_parts(base, len) }; - wrapper.write(&slice, None) -} - -pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult { - let wrapper = proc.get_socket(fd)?; - let mut slice = unsafe { slice::from_raw_parts_mut(base, len) }; - let (result, _) = wrapper.read(&mut slice); - result + let socket = proc.get_socket(fd)?; + socket.connect(endpoint)?; + Ok(0) } pub fn sys_sendto( @@ -278,15 +115,16 @@ pub fn sys_sendto( let mut proc = process(); proc.vm.check_read_array(base, len)?; - let wrapper = proc.get_socket(fd)?; let slice = unsafe { slice::from_raw_parts(base, len) }; - if addr.is_null() { - wrapper.write(&slice, None) + let endpoint = if addr.is_null() { + None } else { let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?; info!("sys_sendto: sending to endpoint {:?}", endpoint); - wrapper.write(&slice, Some(endpoint)) - } + Some(endpoint) + }; + let socket = proc.get_socket(fd)?; + socket.write(&slice, endpoint) } pub fn sys_recvfrom( @@ -305,9 +143,9 @@ pub fn sys_recvfrom( let mut proc = process(); proc.vm.check_write_array(base, len)?; - let wrapper = proc.get_socket(fd)?; + let socket = proc.get_socket(fd)?; let mut slice = unsafe { slice::from_raw_parts_mut(base, len) }; - let (result, endpoint) = wrapper.read(&mut slice); + let (result, endpoint) = socket.read(&mut slice); if result.is_ok() && !addr.is_null() { let sockaddr_in = SockAddr::from(endpoint); @@ -319,45 +157,15 @@ pub fn sys_recvfrom( result } -impl Clone for SocketWrapper { - fn clone(&self) -> Self { - let mut sockets = SOCKETS.lock(); - sockets.retain(self.handle); - - SocketWrapper { - handle: self.handle.clone(), - socket_type: self.socket_type.clone(), - } - } -} - pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult { info!("sys_bind: fd: {} addr: {:?} len: {}", fd, addr, addr_len); let mut proc = process(); let mut endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?; - if endpoint.port == 0 { - endpoint.port = get_ephemeral_port(); - } info!("sys_bind: fd: {} bind to {}", fd, endpoint); - let wrapper = &mut proc.get_socket_mut(fd)?; - if let SocketType::Tcp(_) = wrapper.socket_type { - wrapper.socket_type = SocketType::Tcp(TcpSocketState { - local_endpoint: Some(endpoint), - is_listening: false, - }); - Ok(0) - } else if let SocketType::Udp(_) = wrapper.socket_type { - let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(wrapper.handle); - match socket.bind(endpoint) { - Ok(()) => Ok(0), - Err(_) => Err(SysError::EINVAL), - } - } else { - Err(SysError::EINVAL) - } + let socket = proc.get_socket(fd)?; + socket.bind(endpoint) } pub fn sys_listen(fd: usize, backlog: usize) -> SysResult { @@ -366,48 +174,16 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult { // open multiple sockets for each connection let mut proc = process(); - let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(ref mut tcp_state) = wrapper.socket_type { - if tcp_state.is_listening { - // it is ok to listen twice - Ok(0) - } else if let Some(local_endpoint) = tcp_state.local_endpoint { - let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(wrapper.handle); - - info!("socket {} listening on {:?}", fd, local_endpoint); - if !socket.is_listening() { - match socket.listen(local_endpoint) { - Ok(()) => { - tcp_state.is_listening = true; - Ok(0) - } - Err(_err) => Err(SysError::EINVAL), - } - } else { - Ok(0) - } - } else { - Err(SysError::EINVAL) - } - } else { - Err(SysError::EINVAL) - } + let socket = proc.get_socket(fd)?; + socket.listen() } pub fn sys_shutdown(fd: usize, how: usize) -> SysResult { info!("sys_shutdown: fd: {} how: {}", fd, how); let mut proc = process(); - let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(_) = wrapper.socket_type { - let mut sockets = SOCKETS.lock(); - let mut socket = sockets.get::(wrapper.handle); - socket.close(); - Ok(0) - } else { - Err(SysError::EINVAL) - } + let socket = proc.get_socket(fd)?; + socket.shutdown() } pub fn sys_accept(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResult { @@ -419,75 +195,19 @@ pub fn sys_accept(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResu // open multiple sockets for each connection let mut proc = process(); - let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(tcp_state) = wrapper.socket_type.clone() { - if let Some(endpoint) = tcp_state.local_endpoint { - loop { - let mut sockets = SOCKETS.lock(); - let socket = sockets.get::(wrapper.handle); - - if socket.is_active() { - let remote_endpoint = socket.remote_endpoint(); - drop(socket); - - // move the current one to new_fd - // create a new one in fd - let new_fd = proc.get_free_fd(); - - let tcp_rx_buffer = TcpSocketBuffer::new(vec![0; TCP_RECVBUF]); - let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; TCP_SENDBUF]); - let mut tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - tcp_socket.listen(endpoint).unwrap(); - - let tcp_handle = sockets.add(tcp_socket); - - let mut orig_socket = proc - .files - .insert( - fd, - FileLike::Socket(SocketWrapper { - handle: tcp_handle, - socket_type: SocketType::Tcp(tcp_state), - }), - ) - .unwrap(); - - if let FileLike::Socket(ref mut wrapper) = orig_socket { - if let SocketType::Tcp(ref mut state) = wrapper.socket_type { - state.is_listening = false; - } else { - panic!("impossible"); - } - } else { - panic!("impossible"); - } - proc.files.insert(new_fd, orig_socket); - - if !addr.is_null() { - let sockaddr_in = SockAddr::from(remote_endpoint); - unsafe { - sockaddr_in.write_to(&mut proc, addr, addr_len)?; - } - } - - drop(sockets); - drop(proc); - poll_ifaces(); - return Ok(new_fd); - } + let socket = proc.get_socket(fd)?; + let (new_socket, remote_endpoint) = socket.accept()?; - // avoid deadlock - drop(socket); - drop(sockets); - SOCKET_ACTIVITY._wait() - } - } else { - Err(SysError::EINVAL) + let new_fd = proc.get_free_fd(); + proc.files.insert(new_fd, FileLike::Socket(new_socket)); + + if !addr.is_null() { + let sockaddr_in = SockAddr::from(remote_endpoint); + unsafe { + sockaddr_in.write_to(&mut proc, addr, addr_len)?; } - } else { - debug!("bad socket type {:?}", wrapper); - Err(SysError::EINVAL) } + Ok(new_fd) } pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResult { @@ -502,44 +222,13 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy return Err(SysError::EINVAL); } - let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(state) = &wrapper.socket_type { - if let Some(endpoint) = state.local_endpoint { - let sockaddr_in = SockAddr::from(endpoint); - unsafe { - sockaddr_in.write_to(&mut proc, addr, addr_len)?; - } - Ok(0) - } else { - let mut sockets = SOCKETS.lock(); - let socket = sockets.get::(wrapper.handle); - let endpoint = socket.local_endpoint(); - if endpoint.port != 0 { - let sockaddr_in = SockAddr::from(socket.local_endpoint()); - unsafe { - sockaddr_in.write_to(&mut proc, addr, addr_len)?; - } - Ok(0) - } else { - Err(SysError::EINVAL) - } - } - } else if let SocketType::Udp(_) = &wrapper.socket_type { - let mut sockets = SOCKETS.lock(); - let socket = sockets.get::(wrapper.handle); - let endpoint = socket.endpoint(); - if endpoint.port != 0 { - let sockaddr_in = SockAddr::from(endpoint); - unsafe { - sockaddr_in.write_to(&mut proc, addr, addr_len)?; - } - Ok(0) - } else { - Err(SysError::EINVAL) - } - } else { - Err(SysError::EINVAL) + let socket = proc.get_socket(fd)?; + let endpoint = socket.endpoint().ok_or(SysError::EINVAL)?; + let sockaddr_in = SockAddr::from(endpoint); + unsafe { + sockaddr_in.write_to(&mut proc, addr, addr_len)?; } + Ok(0) } pub fn sys_getpeername(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResult { @@ -556,81 +245,22 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy return Err(SysError::EINVAL); } - let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(_) = wrapper.socket_type { - let mut sockets = SOCKETS.lock(); - let socket = sockets.get::(wrapper.handle); - - if socket.is_open() { - let remote_endpoint = socket.remote_endpoint(); - let sockaddr_in = SockAddr::from(remote_endpoint); - unsafe { - sockaddr_in.write_to(&mut proc, addr, addr_len)?; - } - Ok(0) - } else { - Err(SysError::EINVAL) - } - } else if let SocketType::Udp(state) = &wrapper.socket_type { - if let Some(endpoint) = state.remote_endpoint { - let sockaddr_in = SockAddr::from(endpoint); - unsafe { - sockaddr_in.write_to(&mut proc, addr, addr_len)?; - } - Ok(0) - } else { - Err(SysError::EINVAL) - } - } else { - Err(SysError::EINVAL) + let socket = proc.get_socket(fd)?; + let remote_endpoint = socket.remote_endpoint().ok_or(SysError::EINVAL)?; + let sockaddr_in = SockAddr::from(remote_endpoint); + unsafe { + sockaddr_in.write_to(&mut proc, addr, addr_len)?; } + Ok(0) } -/// Check socket state -/// return (in, out, err) -pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) { - let mut input = false; - let mut output = false; - let mut err = false; - if let SocketType::Tcp(state) = wrapper.socket_type.clone() { - let mut sockets = SOCKETS.lock(); - let socket = sockets.get::(wrapper.handle); - - if state.is_listening && socket.is_active() { - // a new connection - input = true; - } else if !socket.is_open() { - err = true; - } else { - if socket.can_recv() { - input = true; - } - - if socket.can_send() { - output = true; - } - } - } else if let SocketType::Udp(_) = wrapper.socket_type { - let mut sockets = SOCKETS.lock(); - let socket = sockets.get::(wrapper.handle); - - if socket.can_recv() { - input = true; - } - - if socket.can_send() { - output = true; +impl Process { + fn get_socket(&mut self, fd: usize) -> Result<&mut Box, SysError> { + match self.get_file_like(fd)? { + FileLike::Socket(socket) => Ok(socket), + _ => Err(SysError::EBADF), } - } else { - unimplemented!() } - - (input, output, err) -} - -pub fn sys_dup2_socket(proc: &mut Process, wrapper: SocketWrapper, fd: usize) -> SysResult { - proc.files.insert(fd, FileLike::Socket(wrapper)); - Ok(fd) } // cancel alignment @@ -738,3 +368,22 @@ impl SockAddr { return Ok(0); } } + +const AF_UNIX: usize = 1; +const AF_INET: usize = 2; + +const SOCK_STREAM: usize = 1; +const SOCK_DGRAM: usize = 2; +const SOCK_RAW: usize = 3; +const SOCK_TYPE_MASK: usize = 0xf; + +const IPPROTO_IP: usize = 0; +const IPPROTO_ICMP: usize = 1; +const IPPROTO_TCP: usize = 6; + +const SOL_SOCKET: usize = 1; +const SO_SNDBUF: usize = 7; +const SO_RCVBUF: usize = 8; +const SO_LINGER: usize = 13; + +const TCP_CONGESTION: usize = 13; diff --git a/tools/.gitignore b/tools/.gitignore deleted file mode 100644 index 816158a..0000000 --- a/tools/.gitignore +++ /dev/null @@ -1 +0,0 @@ -llc