Add udp remote endpoint state

master
Jiajie Chen 6 years ago
parent 33ce72703b
commit af63d937d6

@ -1,4 +1,4 @@
use alloc::{boxed::Box, collections::BTreeMap, collections::btree_map::Entry, string::String, sync::Arc, vec::Vec, sync::Weak}; use alloc::{boxed::Box, collections::BTreeMap, string::String, sync::Arc, vec::Vec, sync::Weak};
use core::fmt; use core::fmt;
use log::*; use log::*;
@ -33,11 +33,16 @@ pub struct TcpSocketState {
pub is_listening: bool, pub is_listening: bool,
} }
#[derive(Clone, Debug)]
pub struct UdpSocketState {
pub remote_endpoint: Option<IpEndpoint>, // remember remote endpoint for connect(0)
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum SocketType { pub enum SocketType {
Raw, Raw,
Tcp(TcpSocketState), Tcp(TcpSocketState),
Udp, Udp(UdpSocketState),
Icmp Icmp
} }
@ -61,7 +66,7 @@ impl fmt::Debug for FileLike {
match wrapper.socket_type { match wrapper.socket_type {
SocketType::Raw => write!(f, "RawSocket"), SocketType::Raw => write!(f, "RawSocket"),
SocketType::Tcp(_) => write!(f, "TcpSocket"), SocketType::Tcp(_) => write!(f, "TcpSocket"),
SocketType::Udp => write!(f, "UdpSocket"), SocketType::Udp(_) => write!(f, "UdpSocket"),
SocketType::Icmp => write!(f, "IcmpSocket"), SocketType::Icmp => write!(f, "IcmpSocket"),
} }
}, },

@ -80,7 +80,9 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
fd, fd,
FileLike::Socket(SocketWrapper { FileLike::Socket(SocketWrapper {
handle: udp_handle, handle: udp_handle,
socket_type: SocketType::Udp, socket_type: SocketType::Udp(UdpSocketState {
remote_endpoint: None
}),
}), }),
); );
@ -213,7 +215,7 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu
let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?; let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?;
let wrapper = proc.get_socket(fd)?; let wrapper = &mut proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
@ -249,8 +251,10 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu
} }
Err(_) => Err(SysError::ENOBUFS), Err(_) => Err(SysError::ENOBUFS),
} }
} else if let SocketType::Udp = wrapper.socket_type { } else if let SocketType::Udp(_) = wrapper.socket_type {
// do nothing when only sendto() is used wrapper.socket_type = SocketType::Udp(UdpSocketState {
remote_endpoint: Some(endpoint),
});
Ok(0) Ok(0)
} else { } else {
unimplemented!("socket type") unimplemented!("socket type")
@ -284,6 +288,42 @@ pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usi
} else { } else {
Err(SysError::ENOTCONN) Err(SysError::ENOTCONN)
} }
} else if let SocketType::Udp(ref state) = wrapper.socket_type {
if let Some(ref remote_endpoint) = state.remote_endpoint {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
if !socket.endpoint().is_specified() {
let v4_src = iface.ipv4_address().unwrap();
let temp_port = get_ephemeral_port();
socket
.bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port))
.unwrap();
}
let slice = unsafe { slice::from_raw_parts(base, len) };
if socket.is_open() {
if socket.can_send() {
match socket.send_slice(&slice, *remote_endpoint) {
Ok(()) => {
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(len)
}
Err(err) => Err(SysError::ENOBUFS),
}
} else {
Err(SysError::ENOBUFS)
}
} else {
Err(SysError::ENOTCONN)
}
} else {
Err(SysError::ENOTCONN)
}
} else { } else {
unimplemented!("socket type") unimplemented!("socket type")
} }
@ -316,7 +356,7 @@ pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize)
drop(sockets); drop(sockets);
SOCKET_ACTIVITY._wait() SOCKET_ACTIVITY._wait()
} }
} else if let SocketType::Udp = wrapper.socket_type { } else if let SocketType::Udp(_) = wrapper.socket_type {
loop { loop {
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle); let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
@ -396,7 +436,7 @@ pub fn sys_sendto(
} else { } else {
unimplemented!("ip type") unimplemented!("ip type")
} }
} else if let SocketType::Udp = wrapper.socket_type { } else if let SocketType::Udp(_) = wrapper.socket_type {
let v4_src = iface.ipv4_address().unwrap(); let v4_src = iface.ipv4_address().unwrap();
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle); let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
@ -483,7 +523,7 @@ pub fn sys_recvfrom(
drop(sockets); drop(sockets);
SOCKET_ACTIVITY._wait() SOCKET_ACTIVITY._wait()
} }
} else if let SocketType::Udp = wrapper.socket_type { } else if let SocketType::Udp(_) = wrapper.socket_type {
loop { loop {
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle); let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
@ -757,6 +797,19 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
Err(SysError::EINVAL) Err(SysError::EINVAL)
} }
} }
} else if let SocketType::Udp(_) = &wrapper.socket_type {
let mut sockets = iface.sockets();
let socket = sockets.get::<UdpSocket>(wrapper.handle);
let endpoint = socket.endpoint();
if endpoint.is_specified() {
let sockaddr_in = SockAddr::from(endpoint);
unsafe {
sockaddr_in.write_to(&mut proc, addr, addr_len)?;
}
Ok(0)
} else {
Err(SysError::EINVAL)
}
} else { } else {
Err(SysError::EINVAL) Err(SysError::EINVAL)
} }
@ -792,6 +845,16 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
} else { } else {
Err(SysError::EINVAL) Err(SysError::EINVAL)
} }
} else if let SocketType::Udp(state) = &wrapper.socket_type {
if let Some(endpoint) = state.remote_endpoint {
let sockaddr_in = SockAddr::from(endpoint);
unsafe {
sockaddr_in.write_to(&mut proc, addr, addr_len)?;
}
Ok(0)
} else {
Err(SysError::EINVAL)
}
} else { } else {
Err(SysError::EINVAL) Err(SysError::EINVAL)
} }

@ -1,7 +1,6 @@
//! Syscalls for process //! Syscalls for process
use super::*; use super::*;
use crate::sync::Condvar;
/// Fork the current process. Return the child's PID. /// Fork the current process. Return the child's PID.
pub fn sys_fork(tf: &TrapFrame) -> SysResult { pub fn sys_fork(tf: &TrapFrame) -> SysResult {
@ -152,7 +151,6 @@ pub fn sys_kill(pid: usize, sig: usize) -> SysResult {
if current_pid == pid { if current_pid == pid {
// killing myself // killing myself
sys_exit_group(sig); sys_exit_group(sig);
Ok(0)
} else { } else {
if let Some(proc_arc) = PROCESSES.read().get(&pid).and_then(|weak| weak.upgrade()) { if let Some(proc_arc) = PROCESSES.read().get(&pid).and_then(|weak| weak.upgrade()) {
let proc = proc_arc.lock(); let proc = proc_arc.lock();

Loading…
Cancel
Save