From 6697861860d7878789029a7767a57b362317b467 Mon Sep 17 00:00:00 2001 From: Jiajie Chen Date: Sun, 10 Mar 2019 00:58:10 +0800 Subject: [PATCH] Add is_listening to TcpSocketState, support sys_poll for listen --- kernel/src/process/structs.rs | 19 +++- kernel/src/syscall/misc.rs | 3 - kernel/src/syscall/mod.rs | 2 - kernel/src/syscall/net.rs | 160 +++++++++++++++++++++------------- kernel/src/syscall/proc.rs | 2 + 5 files changed, 118 insertions(+), 68 deletions(-) diff --git a/kernel/src/process/structs.rs b/kernel/src/process/structs.rs index a1b1dfd..96c9ba1 100644 --- a/kernel/src/process/structs.rs +++ b/kernel/src/process/structs.rs @@ -5,7 +5,7 @@ use log::*; use rcore_fs::vfs::INode; use spin::Mutex; use xmas_elf::{ElfFile, header, program::{Flags, Type}}; -use smoltcp::socket::{SocketSet, SocketHandle}; +use smoltcp::socket::SocketHandle; use smoltcp::wire::IpEndpoint; use rcore_memory::PAGE_SIZE; @@ -27,10 +27,16 @@ pub struct Thread { pub proc: Arc>, } +#[derive(Clone, Debug)] +pub struct TcpSocketState { + pub local_endpoint: Option, // save local endpoint for bind() + pub is_listening: bool, +} + #[derive(Clone, Debug)] pub enum SocketType { Raw, - Tcp(Option), // save local endpoint for bind() + Tcp(TcpSocketState), Udp, Icmp } @@ -51,7 +57,14 @@ impl fmt::Debug for FileLike { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { FileLike::File(_) => write!(f, "File"), - FileLike::Socket(_) => write!(f, "Socket"), + FileLike::Socket(wrapper) => { + match wrapper.socket_type { + SocketType::Raw => write!(f, "RawSocket"), + SocketType::Tcp(_) => write!(f, "TcpSocket"), + SocketType::Udp => write!(f, "UdpSocket"), + SocketType::Icmp => write!(f, "IcmpSocket"), + } + }, } } } diff --git a/kernel/src/syscall/misc.rs b/kernel/src/syscall/misc.rs index 5258379..5d557ca 100644 --- a/kernel/src/syscall/misc.rs +++ b/kernel/src/syscall/misc.rs @@ -1,9 +1,6 @@ use super::*; use core::mem::size_of; use core::sync::atomic::{AtomicI32, Ordering}; -use alloc::collections::btree_map::BTreeMap; -use crate::sync::Condvar; -use crate::sync::SpinNoIrqLock as Mutex; pub fn sys_arch_prctl(code: i32, addr: usize, tf: &mut TrapFrame) -> SysResult { const ARCH_SET_FS: i32 = 0x1002; diff --git a/kernel/src/syscall/mod.rs b/kernel/src/syscall/mod.rs index e2ee608..4e79efa 100644 --- a/kernel/src/syscall/mod.rs +++ b/kernel/src/syscall/mod.rs @@ -6,10 +6,8 @@ use core::{slice, str, fmt}; use bitflags::bitflags; use rcore_memory::VMError; use rcore_fs::vfs::{FileType, FsError, INode, Metadata}; -use spin::{Mutex, MutexGuard}; use crate::arch::interrupt::TrapFrame; -use crate::fs::FileHandle; use crate::process::*; use crate::thread; use crate::util; diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index fd0a1fa..92b67fa 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -2,6 +2,7 @@ use super::*; use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY}; +use crate::process::structs::TcpSocketState; use core::mem::size_of; use smoltcp::socket::*; use smoltcp::wire::*; @@ -51,7 +52,10 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu fd, FileLike::Socket(SocketWrapper { handle: tcp_handle, - socket_type: SocketType::Tcp(None), + socket_type: SocketType::Tcp(TcpSocketState { + local_endpoint: None, + is_listening: false, + }), }), ); @@ -410,6 +414,7 @@ pub fn sys_recvfrom( } let iface = &*(NET_DRIVERS.read()[0]); + debug!("sockets {:#?}", proc.files); let wrapper = proc.get_socket(fd)?; // TODO: move some part of these into one generic function @@ -527,7 +532,10 @@ pub fn sys_bind(fd: usize, addr: *const SockaddrIn, len: usize) -> SysResult { let iface = &*(NET_DRIVERS.read()[0]); let wrapper = &mut proc.get_socket_mut(fd)?; if let SocketType::Tcp(_) = wrapper.socket_type { - wrapper.socket_type = SocketType::Tcp(Some(endpoint)); + wrapper.socket_type = SocketType::Tcp(TcpSocketState { + local_endpoint: Some(endpoint), + is_listening: false + }); Ok(0) } else { Err(SysError::EINVAL) @@ -542,20 +550,30 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult { let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); + if let SocketType::Tcp(ref mut tcp_state) = wrapper.socket_type { + if tcp_state.is_listening { + // it is ok to listen twice + Ok(0) + } else if let Some(local_endpoint) = tcp_state.local_endpoint { + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); - info!("socket {} listening on {:?}", fd, endpoint); - if !socket.is_listening() { - match socket.listen(endpoint) { - Ok(()) => Ok(0), - Err(err) => { - Err(SysError::EINVAL) - }, + info!("socket {} listening on {:?}", fd, local_endpoint); + if !socket.is_listening() { + match socket.listen(local_endpoint) { + Ok(()) => { + tcp_state.is_listening = true; + Ok(0) + }, + Err(err) => { + Err(SysError::EINVAL) + }, + } + } else { + Ok(0) } } else { - Ok(0) + Err(SysError::EINVAL) } } else { Err(SysError::EINVAL) @@ -568,7 +586,7 @@ pub fn sys_shutdown(fd: usize, how: usize) -> SysResult { let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { + if let SocketType::Tcp(_) = wrapper.socket_type { let mut sockets = iface.sockets(); let mut socket = sockets.get::(wrapper.handle); socket.close(); @@ -600,49 +618,64 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe } let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { - loop { - let iface = &*(NET_DRIVERS.read()[0]); - let mut sockets = iface.sockets(); - let mut socket = sockets.get::(wrapper.handle); - - if socket.is_active() { - let remote_endpoint = socket.remote_endpoint(); - drop(socket); - - // move the current one to new_fd - // create a new one in fd - let new_fd = proc.get_free_fd(); + if let SocketType::Tcp(tcp_state) = wrapper.socket_type.clone() { + if let Some(endpoint) = tcp_state.local_endpoint { + loop { + let iface = &*(NET_DRIVERS.read()[0]); + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); + + if socket.is_active() { + let remote_endpoint = socket.remote_endpoint(); + drop(socket); - let tcp_rx_buffer = TcpSocketBuffer::new(vec![0; 2048]); - let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; 2048]); - let mut tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); - tcp_socket.listen(endpoint).unwrap(); - - let tcp_handle = sockets.add(tcp_socket); - let orig_handle = proc - .files - .insert( - fd, - FileLike::Socket(SocketWrapper { - handle: tcp_handle, - socket_type: SocketType::Tcp(Some(endpoint)), - }), - ) - .unwrap(); - proc.files.insert(new_fd, orig_handle); + // move the current one to new_fd + // create a new one in fd + let new_fd = proc.get_free_fd(); + + let tcp_rx_buffer = TcpSocketBuffer::new(vec![0; 2048]); + let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; 2048]); + let mut tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + tcp_socket.listen(endpoint).unwrap(); + + let tcp_handle = sockets.add(tcp_socket); + let mut orig_socket = proc + .files + .insert( + fd, + FileLike::Socket(SocketWrapper { + handle: tcp_handle, + socket_type: SocketType::Tcp(tcp_state), + }), + ) + .unwrap(); + + if let FileLike::Socket(wrapper) = orig_socket { + proc.files.insert(new_fd, FileLike::Socket(SocketWrapper { + handle: wrapper.handle, + socket_type: SocketType::Tcp(TcpSocketState { + local_endpoint: Some(endpoint), + is_listening: false, + }) + })); + } else { + panic!("impossible"); + } - if !addr.is_null() { - let sockaddr_in = SockaddrIn::from(remote_endpoint); - unsafe { sockaddr_in.write_to(addr, addr_len); } + if !addr.is_null() { + let sockaddr_in = SockaddrIn::from(remote_endpoint); + unsafe { sockaddr_in.write_to(addr, addr_len); } + } + return Ok(new_fd); } - return Ok(new_fd); - } - // avoid deadlock - drop(socket); - drop(sockets); - SOCKET_ACTIVITY._wait() + // avoid deadlock + drop(socket); + drop(sockets); + SOCKET_ACTIVITY._wait() + } + } else { + Err(SysError::EINVAL) } } else { debug!("bad socket type {:?}", wrapper); @@ -675,10 +708,14 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { - let sockaddr_in = SockaddrIn::from(endpoint); - unsafe { sockaddr_in.write_to(addr, addr_len); } - return Ok(0); + if let SocketType::Tcp(state) = &wrapper.socket_type { + if let Some(endpoint) = state.local_endpoint { + let sockaddr_in = SockaddrIn::from(endpoint); + unsafe { sockaddr_in.write_to(addr, addr_len); } + Ok(0) + } else { + Err(SysError::EINVAL) + } } else { Err(SysError::EINVAL) } @@ -709,7 +746,7 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket_mut(fd)?; - if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { + if let SocketType::Tcp(_) = wrapper.socket_type { let mut sockets = iface.sockets(); let socket = sockets.get::(wrapper.handle); @@ -732,12 +769,15 @@ pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) { let mut input = false; let mut output = false; let mut err = false; - if let SocketType::Tcp(_) = wrapper.socket_type { + if let SocketType::Tcp(state) = wrapper.socket_type.clone() { let iface = &*(NET_DRIVERS.read()[0]); let mut sockets = iface.sockets(); let mut socket = sockets.get::(wrapper.handle); - if !socket.is_open() { + if state.is_listening && socket.is_active() { + // a new connection + input = true; + } else if !socket.is_open() { err = true; } else { if socket.can_recv() { diff --git a/kernel/src/syscall/proc.rs b/kernel/src/syscall/proc.rs index dd69775..a79c72e 100644 --- a/kernel/src/syscall/proc.rs +++ b/kernel/src/syscall/proc.rs @@ -162,11 +162,13 @@ pub fn sys_kill(pid: usize) -> SysResult { /// Get the current process id pub fn sys_getpid() -> SysResult { + info!("getpid"); Ok(thread::current().id()) } /// Get the current thread id pub fn sys_gettid() -> SysResult { + info!("gettid"); // use pid as tid for now Ok(thread::current().id()) }