From d3a462e8a0d8b5a129b7e039207220a9bef2e6d1 Mon Sep 17 00:00:00 2001 From: Jiajie Chen Date: Thu, 4 Apr 2019 14:32:58 +0800 Subject: [PATCH] Implement hdrincl for raw socket --- kernel/src/net/structs.rs | 89 ++++++++++++++++++++++++++------------- kernel/src/syscall/net.rs | 13 ++++-- 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/kernel/src/net/structs.rs b/kernel/src/net/structs.rs index bc076fb..719cc48 100644 --- a/kernel/src/net/structs.rs +++ b/kernel/src/net/structs.rs @@ -31,6 +31,10 @@ pub trait Socket: Send + Sync { fn remote_endpoint(&self) -> Option { None } + fn setsockopt(&mut self, level: usize, opt: usize, data: &[u8]) -> SysResult { + warn!("setsockopt is unimplemented"); + Ok(0) + } fn box_clone(&self) -> Box; } @@ -65,6 +69,7 @@ pub struct UdpSocketState { #[derive(Debug, Clone)] pub struct RawSocketState { handle: GlobalSocketHandle, + header_included: bool, } /// A wrapper for `SocketHandle`. @@ -469,7 +474,10 @@ impl RawSocketState { ); let handle = GlobalSocketHandle(SOCKETS.lock().add(socket)); - RawSocketState { handle } + RawSocketState { + handle, + header_included: false, + } } } @@ -499,41 +507,51 @@ impl Socket for RawSocketState { } fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { - if let Some(endpoint) = sendto_endpoint { - // temporary solution - let iface = &*(NET_DRIVERS.read()[0]); - let v4_src = iface.ipv4_address().unwrap(); + if self.header_included { let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(self.handle.0); - if let IpAddress::Ipv4(v4_dst) = endpoint.addr { - let len = data.len(); - // using 20-byte IPv4 header - let mut buffer = vec![0u8; len + 20]; - let mut packet = Ipv4Packet::new_unchecked(&mut buffer); - packet.set_version(4); - packet.set_header_len(20); - packet.set_total_len((20 + len) as u16); - packet.set_protocol(socket.ip_protocol().into()); - packet.set_src_addr(v4_src); - packet.set_dst_addr(v4_dst); - let payload = packet.payload_mut(); - payload.copy_from_slice(data); - packet.fill_checksum(); - - socket.send_slice(&buffer).unwrap(); + match socket.send_slice(&data) { + Ok(()) => Ok(data.len()), + Err(_) => Err(SysError::ENOBUFS), + } + } else { + if let Some(endpoint) = sendto_endpoint { + // temporary solution + let iface = &*(NET_DRIVERS.read()[0]); + let v4_src = iface.ipv4_address().unwrap(); + let mut sockets = SOCKETS.lock(); + let mut socket = sockets.get::(self.handle.0); + + if let IpAddress::Ipv4(v4_dst) = endpoint.addr { + let len = data.len(); + // using 20-byte IPv4 header + let mut buffer = vec![0u8; len + 20]; + let mut packet = Ipv4Packet::new_unchecked(&mut buffer); + packet.set_version(4); + packet.set_header_len(20); + packet.set_total_len((20 + len) as u16); + packet.set_protocol(socket.ip_protocol().into()); + packet.set_src_addr(v4_src); + packet.set_dst_addr(v4_dst); + let payload = packet.payload_mut(); + payload.copy_from_slice(data); + packet.fill_checksum(); + + socket.send_slice(&buffer).unwrap(); - // avoid deadlock - drop(socket); - drop(sockets); - iface.poll(); + // avoid deadlock + drop(socket); + drop(sockets); + iface.poll(); - Ok(len) + Ok(len) + } else { + unimplemented!("ip type") + } } else { - unimplemented!("ip type") + Err(SysError::ENOTCONN) } - } else { - Err(SysError::ENOTCONN) } } @@ -548,6 +566,19 @@ impl Socket for RawSocketState { fn box_clone(&self) -> Box { Box::new(self.clone()) } + + fn setsockopt(&mut self, level: usize, opt: usize, data: &[u8]) -> SysResult { + match (level, opt) { + (IPPROTO_IP, IP_HDRINCL) => { + if let Some(arg) = data.first() { + self.header_included = *arg > 0; + debug!("hdrincl set to {}", self.header_included); + } + } + _ => {} + } + Ok(0) + } } fn get_ephemeral_port() -> u16 { diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index 2811f12..fd303f0 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -34,15 +34,18 @@ pub fn sys_setsockopt( fd: usize, level: usize, optname: usize, - _optval: *const u8, - _optlen: usize, + optval: *const u8, + optlen: usize, ) -> SysResult { info!( "setsockopt: fd: {}, level: {}, optname: {}", fd, level, optname ); - warn!("sys_setsockopt is unimplemented"); - Ok(0) + let mut proc = process(); + proc.vm.check_read_array(optval, optlen)?; + let data = unsafe { slice::from_raw_parts(optval, optlen) }; + let socket = proc.get_socket(fd)?; + socket.setsockopt(level, optname, data) } pub fn sys_getsockopt( @@ -387,3 +390,5 @@ const SO_RCVBUF: usize = 8; const SO_LINGER: usize = 13; const TCP_CONGESTION: usize = 13; + +const IP_HDRINCL: usize = 3;