diff --git a/kernel/src/syscall/mod.rs b/kernel/src/syscall/mod.rs index 167c2b7..ea1bb79 100644 --- a/kernel/src/syscall/mod.rs +++ b/kernel/src/syscall/mod.rs @@ -57,16 +57,16 @@ pub fn syscall(id: usize, args: [usize; 6], tf: &mut TrapFrame) -> isize { 035 => sys_sleep(args[0]), // TODO: nanosleep 039 => sys_getpid(), 041 => sys_socket(args[0], args[1], args[2]), - 042 => sys_connect(args[0], args[1] as *const u8, args[2]), - 043 => sys_accept(args[0], args[1] as *mut u8, args[2] as *mut u32), - 044 => sys_sendto(args[0], args[1] as *const u8, args[2], args[3], args[4] as *const u8, args[5]), - 045 => sys_recvfrom(args[0], args[1] as *mut u8, args[2], args[3], args[4] as *mut u8, args[5] as *mut u32), + 042 => sys_connect(args[0], args[1] as *const SockaddrIn, args[2]), + 043 => sys_accept(args[0], args[1] as *mut SockaddrIn, args[2] as *mut u32), + 044 => sys_sendto(args[0], args[1] as *const u8, args[2], args[3], args[4] as *const SockaddrIn, args[5]), + 045 => sys_recvfrom(args[0], args[1] as *mut u8, args[2], args[3], args[4] as *mut SockaddrIn, args[5] as *mut u32), // 046 => sys_sendmsg(), // 047 => sys_recvmsg(), // 048 => sys_shutdown(), - 049 => sys_bind(args[0], args[1] as *const u8, args[2]), + 049 => sys_bind(args[0], args[1] as *const SockaddrIn, args[2]), 050 => sys_listen(args[0], args[1]), - 051 => sys_getsockname(args[0], args[1] as *mut u8, args[2] as *mut u32), + 051 => sys_getsockname(args[0], args[1] as *mut SockaddrIn, args[2] as *mut u32), 054 => sys_setsockopt(args[0], args[1], args[2], args[3] as *const u8, args[4]), 055 => sys_getsockopt(args[0], args[1], args[2], args[3] as *mut u8, args[4] as *mut u32), // 056 => sys_clone(), diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index dc46e39..a8ca6b4 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -28,34 +28,6 @@ fn get_ephemeral_port() -> u16 { } } -fn parse_addr(sockaddr_in: &SockaddrIn, dest: &mut Option, port: &mut u16) { - if sockaddr_in.sin_family == AF_INET as u16 { - *port = u16::from_be(sockaddr_in.sin_port); - let addr = u32::from_be(sockaddr_in.sin_addr); - *dest = Some(IpAddress::v4( - (addr >> 24) as u8, - ((addr >> 16) & 0xFF) as u8, - ((addr >> 8) & 0xFF) as u8, - (addr & 0xFF) as u8, - )); - } -} - -fn fill_addr(sockaddr_in: &mut SockaddrIn, dest: IpAddress, port: u16) { - if let IpAddress::Ipv4(ipv4) = dest { - sockaddr_in.sin_family = AF_INET as u16; - sockaddr_in.sin_port = u16::to_be(port); - sockaddr_in.sin_addr = u32::to_be( - ((ipv4.0[0] as u32) << 24) - | ((ipv4.0[1] as u32) << 16) - | ((ipv4.0[2] as u32) << 8) - | ipv4.0[3] as u32, - ); - } else { - unimplemented!("ipv6"); - } -} - pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResult { info!( "socket: domain: {}, socket_type: {}, protocol: {}", @@ -163,14 +135,6 @@ pub fn sys_getsockopt( Err(SysError::ENOPROTOOPT) } -#[repr(C)] -struct SockaddrIn { - sin_family: u16, - sin_port: u16, - sin_addr: u32, - sin_zero: [u8; 8], -} - impl Process { fn get_socket(&mut self, fd: usize) -> Result { let file = self.files.get_mut(&fd).ok_or(SysError::EBADF)?; @@ -189,7 +153,7 @@ impl Process { } } -pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { +pub fn sys_connect(fd: usize, addr: *const SockaddrIn, addrlen: usize) -> SysResult { info!( "sys_connect: fd: {}, addr: {:?}, addrlen: {}", fd, addr, addrlen @@ -198,16 +162,9 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { let mut proc = process(); proc.memory_set.check_ptr(addr)?; - let mut dest = None; - let mut port = 0; - // FIXME: check size as per sin_family - let sockaddr_in = unsafe { &*(addr as *const SockaddrIn) }; - parse_addr(&sockaddr_in, &mut dest, &mut port); - - if dest == None { - return Err(SysError::EINVAL); - } + let sockaddr_in = unsafe { &*(addr) }; + let endpoint = sockaddr_in.to_endpoint()?; let wrapper = proc.get_socket(fd)?; if let SocketType::Tcp(_) = wrapper.socket_type { @@ -217,7 +174,7 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { let temp_port = get_ephemeral_port(); - match socket.connect((dest.unwrap(), port), temp_port) { + match socket.connect(endpoint, temp_port) { Ok(()) => { // avoid deadlock drop(socket); @@ -345,7 +302,7 @@ pub fn sys_sendto( buffer: *const u8, len: usize, flags: usize, - addr: *const u8, + addr: *const SockaddrIn, addr_len: usize, ) -> SysResult { info!( @@ -357,6 +314,9 @@ pub fn sys_sendto( proc.memory_set.check_ptr(addr)?; proc.memory_set.check_array(buffer, len)?; + let sockaddr_in = unsafe { &*(addr) }; + let endpoint = sockaddr_in.to_endpoint()?; + let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket(fd)?; @@ -365,16 +325,7 @@ pub fn sys_sendto( let mut sockets = iface.sockets(); let mut socket = sockets.get::(wrapper.handle); - let mut dest = None; - let mut port = 0; - - // FIXME: check size as per sin_family - let sockaddr_in = unsafe { &*(addr as *const SockaddrIn) }; - parse_addr(&sockaddr_in, &mut dest, &mut port); - - if dest == None { - return Err(SysError::EINVAL); - } else if let Some(IpAddress::Ipv4(v4_dst)) = dest { + if let IpAddress::Ipv4(v4_dst) = endpoint.addr { let slice = unsafe { slice::from_raw_parts(buffer, len) }; // using 20-byte IPv4 header let mut buffer = vec![0u8; len + 20]; @@ -405,38 +356,25 @@ pub fn sys_sendto( let mut sockets = iface.sockets(); let mut socket = sockets.get::(wrapper.handle); - let mut dest = None; - let mut port = 0; - - // FIXME: check size as per sin_family - let sockaddr_in = unsafe { &*(addr as *const SockaddrIn) }; - parse_addr(&sockaddr_in, &mut dest, &mut port); - - if dest == None { - return Err(SysError::EINVAL); - } else if let Some(dst_addr) = dest { - if !socket.endpoint().is_specified() { - let temp_port = get_ephemeral_port(); - socket - .bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port)) - .unwrap(); - } - - let slice = unsafe { slice::from_raw_parts(buffer, len) }; - + if !socket.endpoint().is_specified() { + let temp_port = get_ephemeral_port(); socket - .send_slice(&slice, IpEndpoint::new(dst_addr, port)) + .bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port)) .unwrap(); + } - // avoid deadlock - drop(socket); - drop(sockets); - iface.poll(); + let slice = unsafe { slice::from_raw_parts(buffer, len) }; - Ok(len as isize) - } else { - unimplemented!("ip type") - } + socket + .send_slice(&slice, endpoint) + .unwrap(); + + // avoid deadlock + drop(socket); + drop(sockets); + iface.poll(); + + Ok(len as isize) } else { unimplemented!("socket type") } @@ -447,7 +385,7 @@ pub fn sys_recvfrom( buffer: *mut u8, len: usize, flags: usize, - addr: *mut u8, + addr: *mut SockaddrIn, addr_len: *mut u32, ) -> SysResult { info!( @@ -458,7 +396,7 @@ pub fn sys_recvfrom( let mut proc = process(); proc.memory_set.check_mut_array(buffer, len)?; - if addr as usize != 0 { + if !addr.is_null() { proc.memory_set.check_mut_ptr(addr_len)?; let max_addr_len = unsafe { *addr_len } as usize; @@ -482,11 +420,13 @@ pub fn sys_recvfrom( if let Ok(size) = socket.recv_slice(&mut slice) { let mut packet = Ipv4Packet::new_unchecked(&slice); - if addr as usize != 0 { + if !addr.is_null() { // FIXME: check size as per sin_family - let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) }; - fill_addr(&mut sockaddr_in, IpAddress::Ipv4(packet.src_addr()), 0); - unsafe { *addr_len = size_of::() as u32 }; + let sockaddr_in = SockaddrIn::from(IpEndpoint { + addr: IpAddress::Ipv4(packet.src_addr()), + port: 0, + }); + unsafe { sockaddr_in.write_to(addr, addr_len); } } return Ok(size as isize); @@ -503,10 +443,9 @@ pub fn sys_recvfrom( let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) }; if let Ok((size, endpoint)) = socket.recv_slice(&mut slice) { - if addr as usize != 0 { - let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) }; - fill_addr(&mut sockaddr_in, endpoint.addr, endpoint.port); - unsafe { *addr_len = size_of::() as u32 }; + if !addr.is_null() { + let sockaddr_in = SockaddrIn::from(endpoint); + unsafe { sockaddr_in.write_to(addr, addr_len); } } return Ok(size as isize); @@ -526,14 +465,14 @@ pub fn sys_close_socket(proc: &mut Process, fd: usize, handle: SocketHandle) -> let mut sockets = iface.sockets(); sockets.release(handle); sockets.prune(); - + // send FIN immediately when applicable drop(sockets); iface.poll(); Ok(0) } -pub fn sys_bind(fd: usize, addr: *const u8, len: usize) -> SysResult { +pub fn sys_bind(fd: usize, addr: *const SockaddrIn, len: usize) -> SysResult { info!("sys_bind: fd: {} addr: {:?} len: {}", fd, addr, len); let mut proc = process(); proc.memory_set.check_array(addr, len)?; @@ -542,20 +481,13 @@ pub fn sys_bind(fd: usize, addr: *const u8, len: usize) -> SysResult { return Err(SysError::EINVAL); } - let mut host = None; - let mut port = 0; - - let sockaddr_in = unsafe { &*(addr as *const SockaddrIn) }; - parse_addr(&sockaddr_in, &mut host, &mut port); - - if host == None { - return Err(SysError::EINVAL); - } + let sockaddr_in = unsafe { &*(addr) }; + let endpoint = sockaddr_in.to_endpoint()?; let iface = &*(NET_DRIVERS.read()[0]); let wrapper = &mut proc.get_socket_mut(fd)?; if let SocketType::Tcp(_) = wrapper.socket_type { - wrapper.socket_type = SocketType::Tcp(Some(IpEndpoint::new(host.unwrap(), port))); + wrapper.socket_type = SocketType::Tcp(Some(endpoint)); Ok(0) } else { Err(SysError::EINVAL) @@ -583,7 +515,7 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult { } } -pub fn sys_accept(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResult { +pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysResult { info!( "sys_accept: fd: {} addr: {:?} addr_len: {:?}", fd, addr, addr_len @@ -592,7 +524,7 @@ pub fn sys_accept(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResult { // open multiple sockets for each connection let mut proc = process(); - if addr as usize != 0 { + if !addr.is_null() { proc.memory_set.check_mut_ptr(addr_len)?; let max_addr_len = unsafe { *addr_len } as usize; @@ -637,10 +569,9 @@ pub fn sys_accept(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResult { .unwrap(); proc.files.insert(new_fd, orig_handle); - if addr as usize != 0 { - let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) }; - fill_addr(&mut sockaddr_in, remote_endpoint.addr, remote_endpoint.port); - unsafe { *addr_len = size_of::() as u32 }; + if !addr.is_null() { + let sockaddr_in = SockaddrIn::from(remote_endpoint); + unsafe { sockaddr_in.write_to(addr, addr_len); } } return Ok(new_fd as isize); } @@ -656,7 +587,7 @@ pub fn sys_accept(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResult { } } -pub fn sys_getsockname(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResult { +pub fn sys_getsockname(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysResult { info!( "sys_getsockname: fd: {} addr: {:?} addr_len: {:?}", fd, addr, addr_len @@ -666,7 +597,7 @@ pub fn sys_getsockname(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResul // open multiple sockets for each connection let mut proc = process(); - if addr as usize == 0 { + if addr.is_null() { return Err(SysError::EINVAL); } @@ -682,9 +613,8 @@ pub fn sys_getsockname(fd: usize, addr: *mut u8, addr_len: *mut u32) -> SysResul let iface = &*(NET_DRIVERS.read()[0]); let wrapper = proc.get_socket_mut(fd)?; if let SocketType::Tcp(Some(endpoint)) = wrapper.socket_type { - let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) }; - fill_addr(&mut sockaddr_in, endpoint.addr, endpoint.port); - unsafe { *addr_len = size_of::() as u32 }; + let sockaddr_in = SockaddrIn::from(endpoint); + unsafe { sockaddr_in.write_to(addr, addr_len); } return Ok(0); } else { Err(SysError::EINVAL) @@ -729,4 +659,47 @@ pub fn sys_dup2_socket(proc: &mut Process, wrapper: SocketWrapper, fd: usize) -> FileLike::Socket(wrapper), ); Ok(fd as isize) -} \ No newline at end of file +} + +#[repr(C)] +pub struct SockaddrIn { + sin_family: u16, + sin_port: u16, + sin_addr: u32, + sin_zero: [u8; 8], +} + +impl From for SockaddrIn { + fn from(endpoint: IpEndpoint) -> Self { + match endpoint.addr { + IpAddress::Ipv4(ipv4) => { + SockaddrIn { + sin_family: AF_INET as u16, + sin_port: u16::to_be(endpoint.port), + sin_addr: u32::to_be(u32::from_be_bytes(ipv4.0)), + sin_zero: [0; 8], + } + } + _ => unimplemented!("ipv6") + } + } +} + +impl SockaddrIn { + fn to_endpoint(&self) -> Result { + // FIXME: check size as per sin_family + if self.sin_family == AF_INET as u16 { + let port = u16::from_be(self.sin_port); + let addr = IpAddress::from(Ipv4Address::from_bytes( + &u32::from_be(self.sin_addr).to_be_bytes()[..] + )); + Ok((addr, port).into()) + } else { + Err(SysError::EINVAL) + } + } + unsafe fn write_to(self, addr: *mut SockaddrIn, addr_len: *mut u32) { + addr.write(self); + addr_len.write(size_of::() as u32); + } +}