Implement hdrincl for raw socket

master
Jiajie Chen 6 years ago
parent 025007c8bf
commit d3a462e8a0

@ -31,6 +31,10 @@ pub trait Socket: Send + Sync {
fn remote_endpoint(&self) -> Option<IpEndpoint> { fn remote_endpoint(&self) -> Option<IpEndpoint> {
None None
} }
fn setsockopt(&mut self, level: usize, opt: usize, data: &[u8]) -> SysResult {
warn!("setsockopt is unimplemented");
Ok(0)
}
fn box_clone(&self) -> Box<dyn Socket>; fn box_clone(&self) -> Box<dyn Socket>;
} }
@ -65,6 +69,7 @@ pub struct UdpSocketState {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RawSocketState { pub struct RawSocketState {
handle: GlobalSocketHandle, handle: GlobalSocketHandle,
header_included: bool,
} }
/// A wrapper for `SocketHandle`. /// A wrapper for `SocketHandle`.
@ -469,7 +474,10 @@ impl RawSocketState {
); );
let handle = GlobalSocketHandle(SOCKETS.lock().add(socket)); let handle = GlobalSocketHandle(SOCKETS.lock().add(socket));
RawSocketState { handle } RawSocketState {
handle,
header_included: false,
}
} }
} }
@ -499,41 +507,51 @@ impl Socket for RawSocketState {
} }
fn write(&self, data: &[u8], sendto_endpoint: Option<IpEndpoint>) -> SysResult { fn write(&self, data: &[u8], sendto_endpoint: Option<IpEndpoint>) -> SysResult {
if let Some(endpoint) = sendto_endpoint { if self.header_included {
// temporary solution
let iface = &*(NET_DRIVERS.read()[0]);
let v4_src = iface.ipv4_address().unwrap();
let mut sockets = SOCKETS.lock(); let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<RawSocket>(self.handle.0); let mut socket = sockets.get::<RawSocket>(self.handle.0);
if let IpAddress::Ipv4(v4_dst) = endpoint.addr { match socket.send_slice(&data) {
let len = data.len(); Ok(()) => Ok(data.len()),
// using 20-byte IPv4 header Err(_) => Err(SysError::ENOBUFS),
let mut buffer = vec![0u8; len + 20]; }
let mut packet = Ipv4Packet::new_unchecked(&mut buffer); } else {
packet.set_version(4); if let Some(endpoint) = sendto_endpoint {
packet.set_header_len(20); // temporary solution
packet.set_total_len((20 + len) as u16); let iface = &*(NET_DRIVERS.read()[0]);
packet.set_protocol(socket.ip_protocol().into()); let v4_src = iface.ipv4_address().unwrap();
packet.set_src_addr(v4_src); let mut sockets = SOCKETS.lock();
packet.set_dst_addr(v4_dst); let mut socket = sockets.get::<RawSocket>(self.handle.0);
let payload = packet.payload_mut();
payload.copy_from_slice(data); if let IpAddress::Ipv4(v4_dst) = endpoint.addr {
packet.fill_checksum(); let len = data.len();
// using 20-byte IPv4 header
socket.send_slice(&buffer).unwrap(); let mut buffer = vec![0u8; len + 20];
let mut packet = Ipv4Packet::new_unchecked(&mut buffer);
packet.set_version(4);
packet.set_header_len(20);
packet.set_total_len((20 + len) as u16);
packet.set_protocol(socket.ip_protocol().into());
packet.set_src_addr(v4_src);
packet.set_dst_addr(v4_dst);
let payload = packet.payload_mut();
payload.copy_from_slice(data);
packet.fill_checksum();
socket.send_slice(&buffer).unwrap();
// avoid deadlock // avoid deadlock
drop(socket); drop(socket);
drop(sockets); drop(sockets);
iface.poll(); iface.poll();
Ok(len) Ok(len)
} else {
unimplemented!("ip type")
}
} else { } else {
unimplemented!("ip type") Err(SysError::ENOTCONN)
} }
} else {
Err(SysError::ENOTCONN)
} }
} }
@ -548,6 +566,19 @@ impl Socket for RawSocketState {
fn box_clone(&self) -> Box<dyn Socket> { fn box_clone(&self) -> Box<dyn Socket> {
Box::new(self.clone()) Box::new(self.clone())
} }
fn setsockopt(&mut self, level: usize, opt: usize, data: &[u8]) -> SysResult {
match (level, opt) {
(IPPROTO_IP, IP_HDRINCL) => {
if let Some(arg) = data.first() {
self.header_included = *arg > 0;
debug!("hdrincl set to {}", self.header_included);
}
}
_ => {}
}
Ok(0)
}
} }
fn get_ephemeral_port() -> u16 { fn get_ephemeral_port() -> u16 {

@ -34,15 +34,18 @@ pub fn sys_setsockopt(
fd: usize, fd: usize,
level: usize, level: usize,
optname: usize, optname: usize,
_optval: *const u8, optval: *const u8,
_optlen: usize, optlen: usize,
) -> SysResult { ) -> SysResult {
info!( info!(
"setsockopt: fd: {}, level: {}, optname: {}", "setsockopt: fd: {}, level: {}, optname: {}",
fd, level, optname fd, level, optname
); );
warn!("sys_setsockopt is unimplemented"); let mut proc = process();
Ok(0) proc.vm.check_read_array(optval, optlen)?;
let data = unsafe { slice::from_raw_parts(optval, optlen) };
let socket = proc.get_socket(fd)?;
socket.setsockopt(level, optname, data)
} }
pub fn sys_getsockopt( pub fn sys_getsockopt(
@ -387,3 +390,5 @@ const SO_RCVBUF: usize = 8;
const SO_LINGER: usize = 13; const SO_LINGER: usize = 13;
const TCP_CONGESTION: usize = 13; const TCP_CONGESTION: usize = 13;
const IP_HDRINCL: usize = 3;

Loading…
Cancel
Save