Add is_listening to TcpSocketState, support sys_poll for listen

master
Jiajie Chen 6 years ago
parent 6ed66d03d8
commit 6697861860

@ -5,7 +5,7 @@ use log::*;
use rcore_fs::vfs::INode;
use spin::Mutex;
use xmas_elf::{ElfFile, header, program::{Flags, Type}};
use smoltcp::socket::{SocketSet, SocketHandle};
use smoltcp::socket::SocketHandle;
use smoltcp::wire::IpEndpoint;
use rcore_memory::PAGE_SIZE;
@ -27,10 +27,16 @@ pub struct Thread {
pub proc: Arc<Mutex<Process>>,
}
#[derive(Clone, Debug)]
pub struct TcpSocketState {
pub local_endpoint: Option<IpEndpoint>, // save local endpoint for bind()
pub is_listening: bool,
}
#[derive(Clone, Debug)]
pub enum SocketType {
Raw,
Tcp(Option<IpEndpoint>), // save local endpoint for bind()
Tcp(TcpSocketState),
Udp,
Icmp
}
@ -51,7 +57,14 @@ impl fmt::Debug for FileLike {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FileLike::File(_) => write!(f, "File"),
FileLike::Socket(_) => write!(f, "Socket"),
FileLike::Socket(wrapper) => {
match wrapper.socket_type {
SocketType::Raw => write!(f, "RawSocket"),
SocketType::Tcp(_) => write!(f, "TcpSocket"),
SocketType::Udp => write!(f, "UdpSocket"),
SocketType::Icmp => write!(f, "IcmpSocket"),
}
},
}
}
}

@ -1,9 +1,6 @@
use super::*;
use core::mem::size_of;
use core::sync::atomic::{AtomicI32, Ordering};
use alloc::collections::btree_map::BTreeMap;
use crate::sync::Condvar;
use crate::sync::SpinNoIrqLock as Mutex;
pub fn sys_arch_prctl(code: i32, addr: usize, tf: &mut TrapFrame) -> SysResult {
const ARCH_SET_FS: i32 = 0x1002;

@ -6,10 +6,8 @@ use core::{slice, str, fmt};
use bitflags::bitflags;
use rcore_memory::VMError;
use rcore_fs::vfs::{FileType, FsError, INode, Metadata};
use spin::{Mutex, MutexGuard};
use crate::arch::interrupt::TrapFrame;
use crate::fs::FileHandle;
use crate::process::*;
use crate::thread;
use crate::util;

@ -2,6 +2,7 @@
use super::*;
use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY};
use crate::process::structs::TcpSocketState;
use core::mem::size_of;
use smoltcp::socket::*;
use smoltcp::wire::*;
@ -51,7 +52,10 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
fd,
FileLike::Socket(SocketWrapper {
handle: tcp_handle,
socket_type: SocketType::Tcp(None),
socket_type: SocketType::Tcp(TcpSocketState {
local_endpoint: None,
is_listening: false,
}),
}),
);
@ -410,6 +414,7 @@ pub fn sys_recvfrom(
}
let iface = &*(NET_DRIVERS.read()[0]);
debug!("sockets {:#?}", proc.files);
let wrapper = proc.get_socket(fd)?;
// TODO: move some part of these into one generic function
@ -527,7 +532,10 @@ pub fn sys_bind(fd: usize, addr: *const SockaddrIn, len: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = &mut proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
wrapper.socket_type = SocketType::Tcp(Some(endpoint));
wrapper.socket_type = SocketType::Tcp(TcpSocketState {
local_endpoint: Some(endpoint),
is_listening: false
});
Ok(0)
} else {
Err(SysError::EINVAL)
@ -542,14 +550,21 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
if let SocketType::Tcp(ref mut tcp_state) = wrapper.socket_type {
if tcp_state.is_listening {
// it is ok to listen twice
Ok(0)
} else if let Some(local_endpoint) = tcp_state.local_endpoint {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
info!("socket {} listening on {:?}", fd, endpoint);
info!("socket {} listening on {:?}", fd, local_endpoint);
if !socket.is_listening() {
match socket.listen(endpoint) {
Ok(()) => Ok(0),
match socket.listen(local_endpoint) {
Ok(()) => {
tcp_state.is_listening = true;
Ok(0)
},
Err(err) => {
Err(SysError::EINVAL)
},
@ -560,6 +575,9 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult {
} else {
Err(SysError::EINVAL)
}
} else {
Err(SysError::EINVAL)
}
}
pub fn sys_shutdown(fd: usize, how: usize) -> SysResult {
@ -568,7 +586,7 @@ pub fn sys_shutdown(fd: usize, how: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
socket.close();
@ -600,7 +618,8 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe
}
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
if let SocketType::Tcp(tcp_state) = wrapper.socket_type.clone() {
if let Some(endpoint) = tcp_state.local_endpoint {
loop {
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
@ -620,17 +639,28 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe
tcp_socket.listen(endpoint).unwrap();
let tcp_handle = sockets.add(tcp_socket);
let orig_handle = proc
let mut orig_socket = proc
.files
.insert(
fd,
FileLike::Socket(SocketWrapper {
handle: tcp_handle,
socket_type: SocketType::Tcp(Some(endpoint)),
socket_type: SocketType::Tcp(tcp_state),
}),
)
.unwrap();
proc.files.insert(new_fd, orig_handle);
if let FileLike::Socket(wrapper) = orig_socket {
proc.files.insert(new_fd, FileLike::Socket(SocketWrapper {
handle: wrapper.handle,
socket_type: SocketType::Tcp(TcpSocketState {
local_endpoint: Some(endpoint),
is_listening: false,
})
}));
} else {
panic!("impossible");
}
if !addr.is_null() {
let sockaddr_in = SockaddrIn::from(remote_endpoint);
@ -644,6 +674,9 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe
drop(sockets);
SOCKET_ACTIVITY._wait()
}
} else {
Err(SysError::EINVAL)
}
} else {
debug!("bad socket type {:?}", wrapper);
Err(SysError::EINVAL)
@ -675,10 +708,14 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) ->
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
if let SocketType::Tcp(state) = &wrapper.socket_type {
if let Some(endpoint) = state.local_endpoint {
let sockaddr_in = SockaddrIn::from(endpoint);
unsafe { sockaddr_in.write_to(addr, addr_len); }
return Ok(0);
Ok(0)
} else {
Err(SysError::EINVAL)
}
} else {
Err(SysError::EINVAL)
}
@ -709,7 +746,7 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) ->
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type {
if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets();
let socket = sockets.get::<TcpSocket>(wrapper.handle);
@ -732,12 +769,15 @@ pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) {
let mut input = false;
let mut output = false;
let mut err = false;
if let SocketType::Tcp(_) = wrapper.socket_type {
if let SocketType::Tcp(state) = wrapper.socket_type.clone() {
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
if !socket.is_open() {
if state.is_listening && socket.is_active() {
// a new connection
input = true;
} else if !socket.is_open() {
err = true;
} else {
if socket.can_recv() {

@ -162,11 +162,13 @@ pub fn sys_kill(pid: usize) -> SysResult {
/// Get the current process id
pub fn sys_getpid() -> SysResult {
info!("getpid");
Ok(thread::current().id())
}
/// Get the current thread id
pub fn sys_gettid() -> SysResult {
info!("gettid");
// use pid as tid for now
Ok(thread::current().id())
}

Loading…
Cancel
Save