Add is_listening to TcpSocketState, support sys_poll for listen

master
Jiajie Chen 6 years ago
parent 6ed66d03d8
commit 6697861860

@ -5,7 +5,7 @@ use log::*;
use rcore_fs::vfs::INode; use rcore_fs::vfs::INode;
use spin::Mutex; use spin::Mutex;
use xmas_elf::{ElfFile, header, program::{Flags, Type}}; use xmas_elf::{ElfFile, header, program::{Flags, Type}};
use smoltcp::socket::{SocketSet, SocketHandle}; use smoltcp::socket::SocketHandle;
use smoltcp::wire::IpEndpoint; use smoltcp::wire::IpEndpoint;
use rcore_memory::PAGE_SIZE; use rcore_memory::PAGE_SIZE;
@ -27,10 +27,16 @@ pub struct Thread {
pub proc: Arc<Mutex<Process>>, pub proc: Arc<Mutex<Process>>,
} }
#[derive(Clone, Debug)]
pub struct TcpSocketState {
pub local_endpoint: Option<IpEndpoint>, // save local endpoint for bind()
pub is_listening: bool,
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum SocketType { pub enum SocketType {
Raw, Raw,
Tcp(Option<IpEndpoint>), // save local endpoint for bind() Tcp(TcpSocketState),
Udp, Udp,
Icmp Icmp
} }
@ -51,7 +57,14 @@ impl fmt::Debug for FileLike {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
FileLike::File(_) => write!(f, "File"), FileLike::File(_) => write!(f, "File"),
FileLike::Socket(_) => write!(f, "Socket"), FileLike::Socket(wrapper) => {
match wrapper.socket_type {
SocketType::Raw => write!(f, "RawSocket"),
SocketType::Tcp(_) => write!(f, "TcpSocket"),
SocketType::Udp => write!(f, "UdpSocket"),
SocketType::Icmp => write!(f, "IcmpSocket"),
}
},
} }
} }
} }

@ -1,9 +1,6 @@
use super::*; use super::*;
use core::mem::size_of; use core::mem::size_of;
use core::sync::atomic::{AtomicI32, Ordering}; use core::sync::atomic::{AtomicI32, Ordering};
use alloc::collections::btree_map::BTreeMap;
use crate::sync::Condvar;
use crate::sync::SpinNoIrqLock as Mutex;
pub fn sys_arch_prctl(code: i32, addr: usize, tf: &mut TrapFrame) -> SysResult { pub fn sys_arch_prctl(code: i32, addr: usize, tf: &mut TrapFrame) -> SysResult {
const ARCH_SET_FS: i32 = 0x1002; const ARCH_SET_FS: i32 = 0x1002;

@ -6,10 +6,8 @@ use core::{slice, str, fmt};
use bitflags::bitflags; use bitflags::bitflags;
use rcore_memory::VMError; use rcore_memory::VMError;
use rcore_fs::vfs::{FileType, FsError, INode, Metadata}; use rcore_fs::vfs::{FileType, FsError, INode, Metadata};
use spin::{Mutex, MutexGuard};
use crate::arch::interrupt::TrapFrame; use crate::arch::interrupt::TrapFrame;
use crate::fs::FileHandle;
use crate::process::*; use crate::process::*;
use crate::thread; use crate::thread;
use crate::util; use crate::util;

@ -2,6 +2,7 @@
use super::*; use super::*;
use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY}; use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY};
use crate::process::structs::TcpSocketState;
use core::mem::size_of; use core::mem::size_of;
use smoltcp::socket::*; use smoltcp::socket::*;
use smoltcp::wire::*; use smoltcp::wire::*;
@ -51,7 +52,10 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
fd, fd,
FileLike::Socket(SocketWrapper { FileLike::Socket(SocketWrapper {
handle: tcp_handle, handle: tcp_handle,
socket_type: SocketType::Tcp(None), socket_type: SocketType::Tcp(TcpSocketState {
local_endpoint: None,
is_listening: false,
}),
}), }),
); );
@ -410,6 +414,7 @@ pub fn sys_recvfrom(
} }
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
debug!("sockets {:#?}", proc.files);
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
// TODO: move some part of these into one generic function // TODO: move some part of these into one generic function
@ -527,7 +532,10 @@ pub fn sys_bind(fd: usize, addr: *const SockaddrIn, len: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = &mut proc.get_socket_mut(fd)?; let wrapper = &mut proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
wrapper.socket_type = SocketType::Tcp(Some(endpoint)); wrapper.socket_type = SocketType::Tcp(TcpSocketState {
local_endpoint: Some(endpoint),
is_listening: false
});
Ok(0) Ok(0)
} else { } else {
Err(SysError::EINVAL) Err(SysError::EINVAL)
@ -542,14 +550,21 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { if let SocketType::Tcp(ref mut tcp_state) = wrapper.socket_type {
if tcp_state.is_listening {
// it is ok to listen twice
Ok(0)
} else if let Some(local_endpoint) = tcp_state.local_endpoint {
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle); let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
info!("socket {} listening on {:?}", fd, endpoint); info!("socket {} listening on {:?}", fd, local_endpoint);
if !socket.is_listening() { if !socket.is_listening() {
match socket.listen(endpoint) { match socket.listen(local_endpoint) {
Ok(()) => Ok(0), Ok(()) => {
tcp_state.is_listening = true;
Ok(0)
},
Err(err) => { Err(err) => {
Err(SysError::EINVAL) Err(SysError::EINVAL)
}, },
@ -560,6 +575,9 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult {
} else { } else {
Err(SysError::EINVAL) Err(SysError::EINVAL)
} }
} else {
Err(SysError::EINVAL)
}
} }
pub fn sys_shutdown(fd: usize, how: usize) -> SysResult { pub fn sys_shutdown(fd: usize, how: usize) -> SysResult {
@ -568,7 +586,7 @@ pub fn sys_shutdown(fd: usize, how: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle); let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
socket.close(); socket.close();
@ -600,7 +618,8 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe
} }
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { if let SocketType::Tcp(tcp_state) = wrapper.socket_type.clone() {
if let Some(endpoint) = tcp_state.local_endpoint {
loop { loop {
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
@ -620,17 +639,28 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe
tcp_socket.listen(endpoint).unwrap(); tcp_socket.listen(endpoint).unwrap();
let tcp_handle = sockets.add(tcp_socket); let tcp_handle = sockets.add(tcp_socket);
let orig_handle = proc let mut orig_socket = proc
.files .files
.insert( .insert(
fd, fd,
FileLike::Socket(SocketWrapper { FileLike::Socket(SocketWrapper {
handle: tcp_handle, handle: tcp_handle,
socket_type: SocketType::Tcp(Some(endpoint)), socket_type: SocketType::Tcp(tcp_state),
}), }),
) )
.unwrap(); .unwrap();
proc.files.insert(new_fd, orig_handle);
if let FileLike::Socket(wrapper) = orig_socket {
proc.files.insert(new_fd, FileLike::Socket(SocketWrapper {
handle: wrapper.handle,
socket_type: SocketType::Tcp(TcpSocketState {
local_endpoint: Some(endpoint),
is_listening: false,
})
}));
} else {
panic!("impossible");
}
if !addr.is_null() { if !addr.is_null() {
let sockaddr_in = SockaddrIn::from(remote_endpoint); let sockaddr_in = SockaddrIn::from(remote_endpoint);
@ -644,6 +674,9 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe
drop(sockets); drop(sockets);
SOCKET_ACTIVITY._wait() SOCKET_ACTIVITY._wait()
} }
} else {
Err(SysError::EINVAL)
}
} else { } else {
debug!("bad socket type {:?}", wrapper); debug!("bad socket type {:?}", wrapper);
Err(SysError::EINVAL) Err(SysError::EINVAL)
@ -675,10 +708,14 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) ->
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { if let SocketType::Tcp(state) = &wrapper.socket_type {
if let Some(endpoint) = state.local_endpoint {
let sockaddr_in = SockaddrIn::from(endpoint); let sockaddr_in = SockaddrIn::from(endpoint);
unsafe { sockaddr_in.write_to(addr, addr_len); } unsafe { sockaddr_in.write_to(addr, addr_len); }
return Ok(0); Ok(0)
} else {
Err(SysError::EINVAL)
}
} else { } else {
Err(SysError::EINVAL) Err(SysError::EINVAL)
} }
@ -709,7 +746,7 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) ->
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let socket = sockets.get::<TcpSocket>(wrapper.handle); let socket = sockets.get::<TcpSocket>(wrapper.handle);
@ -732,12 +769,15 @@ pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) {
let mut input = false; let mut input = false;
let mut output = false; let mut output = false;
let mut err = false; let mut err = false;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(state) = wrapper.socket_type.clone() {
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle); let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
if !socket.is_open() { if state.is_listening && socket.is_active() {
// a new connection
input = true;
} else if !socket.is_open() {
err = true; err = true;
} else { } else {
if socket.can_recv() { if socket.can_recv() {

@ -162,11 +162,13 @@ pub fn sys_kill(pid: usize) -> SysResult {
/// Get the current process id /// Get the current process id
pub fn sys_getpid() -> SysResult { pub fn sys_getpid() -> SysResult {
info!("getpid");
Ok(thread::current().id()) Ok(thread::current().id())
} }
/// Get the current thread id /// Get the current thread id
pub fn sys_gettid() -> SysResult { pub fn sys_gettid() -> SysResult {
info!("gettid");
// use pid as tid for now // use pid as tid for now
Ok(thread::current().id()) Ok(thread::current().id())
} }

Loading…
Cancel
Save