diff --git a/kernel/src/process/structs.rs b/kernel/src/process/structs.rs index 8cca5a1..dbf9bca 100644 --- a/kernel/src/process/structs.rs +++ b/kernel/src/process/structs.rs @@ -1,4 +1,4 @@ -use alloc::{boxed::Box, collections::BTreeMap, collections::btree_map::Entry, string::String, sync::Arc, vec::Vec, sync::Weak}; +use alloc::{boxed::Box, collections::BTreeMap, string::String, sync::Arc, vec::Vec, sync::Weak}; use core::fmt; use log::*; @@ -33,11 +33,16 @@ pub struct TcpSocketState { pub is_listening: bool, } +#[derive(Clone, Debug)] +pub struct UdpSocketState { + pub remote_endpoint: Option, // remember remote endpoint for connect(0) +} + #[derive(Clone, Debug)] pub enum SocketType { Raw, Tcp(TcpSocketState), - Udp, + Udp(UdpSocketState), Icmp } @@ -61,7 +66,7 @@ impl fmt::Debug for FileLike { match wrapper.socket_type { SocketType::Raw => write!(f, "RawSocket"), SocketType::Tcp(_) => write!(f, "TcpSocket"), - SocketType::Udp => write!(f, "UdpSocket"), + SocketType::Udp(_) => write!(f, "UdpSocket"), SocketType::Icmp => write!(f, "IcmpSocket"), } }, diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index b3fd600..6a0645e 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -80,7 +80,9 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu fd, FileLike::Socket(SocketWrapper { handle: udp_handle, - socket_type: SocketType::Udp, + socket_type: SocketType::Udp(UdpSocketState { + remote_endpoint: None + }), }), ); @@ -213,7 +215,7 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?; - let wrapper = proc.get_socket(fd)?; + let wrapper = &mut proc.get_socket_mut(fd)?; if let SocketType::Tcp(_) = wrapper.socket_type { let iface = &*(NET_DRIVERS.read()[0]); let mut sockets = iface.sockets(); @@ -249,8 +251,10 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu } Err(_) => Err(SysError::ENOBUFS), } - } else if let SocketType::Udp = wrapper.socket_type { - // do nothing when only sendto() is used + } else if let SocketType::Udp(_) = wrapper.socket_type { + wrapper.socket_type = SocketType::Udp(UdpSocketState { + remote_endpoint: Some(endpoint), + }); Ok(0) } else { unimplemented!("socket type") @@ -284,6 +288,42 @@ pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usi } else { Err(SysError::ENOTCONN) } + } else if let SocketType::Udp(ref state) = wrapper.socket_type { + if let Some(ref remote_endpoint) = state.remote_endpoint { + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); + + if !socket.endpoint().is_specified() { + let v4_src = iface.ipv4_address().unwrap(); + let temp_port = get_ephemeral_port(); + socket + .bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port)) + .unwrap(); + } + + let slice = unsafe { slice::from_raw_parts(base, len) }; + if socket.is_open() { + if socket.can_send() { + match socket.send_slice(&slice, *remote_endpoint) { + Ok(()) => { + // avoid deadlock + drop(socket); + drop(sockets); + + iface.poll(); + Ok(len) + } + Err(err) => Err(SysError::ENOBUFS), + } + } else { + Err(SysError::ENOBUFS) + } + } else { + Err(SysError::ENOTCONN) + } + } else { + Err(SysError::ENOTCONN) + } } else { unimplemented!("socket type") } @@ -316,7 +356,7 @@ pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) drop(sockets); SOCKET_ACTIVITY._wait() } - } else if let SocketType::Udp = wrapper.socket_type { + } else if let SocketType::Udp(_) = wrapper.socket_type { loop { let mut sockets = iface.sockets(); let mut socket = sockets.get::(wrapper.handle); @@ -396,7 +436,7 @@ pub fn sys_sendto( } else { unimplemented!("ip type") } - } else if let SocketType::Udp = wrapper.socket_type { + } else if let SocketType::Udp(_) = wrapper.socket_type { let v4_src = iface.ipv4_address().unwrap(); let mut sockets = iface.sockets(); let mut socket = sockets.get::(wrapper.handle); @@ -483,7 +523,7 @@ pub fn sys_recvfrom( drop(sockets); SOCKET_ACTIVITY._wait() } - } else if let SocketType::Udp = wrapper.socket_type { + } else if let SocketType::Udp(_) = wrapper.socket_type { loop { let mut sockets = iface.sockets(); let mut socket = sockets.get::(wrapper.handle); @@ -757,6 +797,19 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy Err(SysError::EINVAL) } } + } else if let SocketType::Udp(_) = &wrapper.socket_type { + let mut sockets = iface.sockets(); + let socket = sockets.get::(wrapper.handle); + let endpoint = socket.endpoint(); + if endpoint.is_specified() { + let sockaddr_in = SockAddr::from(endpoint); + unsafe { + sockaddr_in.write_to(&mut proc, addr, addr_len)?; + } + Ok(0) + } else { + Err(SysError::EINVAL) + } } else { Err(SysError::EINVAL) } @@ -792,6 +845,16 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy } else { Err(SysError::EINVAL) } + } else if let SocketType::Udp(state) = &wrapper.socket_type { + if let Some(endpoint) = state.remote_endpoint { + let sockaddr_in = SockAddr::from(endpoint); + unsafe { + sockaddr_in.write_to(&mut proc, addr, addr_len)?; + } + Ok(0) + } else { + Err(SysError::EINVAL) + } } else { Err(SysError::EINVAL) } diff --git a/kernel/src/syscall/proc.rs b/kernel/src/syscall/proc.rs index c5a1f9c..b2eab2d 100644 --- a/kernel/src/syscall/proc.rs +++ b/kernel/src/syscall/proc.rs @@ -1,7 +1,6 @@ //! Syscalls for process use super::*; -use crate::sync::Condvar; /// Fork the current process. Return the child's PID. pub fn sys_fork(tf: &TrapFrame) -> SysResult { @@ -152,7 +151,6 @@ pub fn sys_kill(pid: usize, sig: usize) -> SysResult { if current_pid == pid { // killing myself sys_exit_group(sig); - Ok(0) } else { if let Some(proc_arc) = PROCESSES.read().get(&pid).and_then(|weak| weak.upgrade()) { let proc = proc_arc.lock();