From 77f8afa30cba671f3fde3afac0d5a86da5418ae5 Mon Sep 17 00:00:00 2001 From: Jiajie Chen Date: Thu, 4 Apr 2019 20:01:59 +0800 Subject: [PATCH] Refactor network endpoint, and add basic support for sockaddr_ll --- kernel/src/lib.rs | 3 +- kernel/src/net/structs.rs | 234 +++++++++++++++++++++++++------------- kernel/src/syscall/net.rs | 116 +++++++++++++------ kernel/src/util/mod.rs | 55 +++++++++ 4 files changed, 293 insertions(+), 115 deletions(-) diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index 20ff71a..53a9cc2 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -22,6 +22,8 @@ use rcore_thread::std_thread as thread; #[macro_use] // print! mod logging; +#[macro_use] +mod util; mod backtrace; mod consts; mod drivers; @@ -34,7 +36,6 @@ mod shell; mod sync; mod syscall; mod trap; -mod util; #[allow(dead_code)] #[cfg(target_arch = "x86_64")] diff --git a/kernel/src/net/structs.rs b/kernel/src/net/structs.rs index 3238661..6a9e569 100644 --- a/kernel/src/net/structs.rs +++ b/kernel/src/net/structs.rs @@ -7,13 +7,22 @@ use alloc::boxed::Box; use smoltcp::socket::*; use smoltcp::wire::*; -/// +#[derive(Clone, Debug)] +pub struct LinkLevelEndpoint {} + +#[derive(Clone, Debug)] +pub enum Endpoint { + Ip(IpEndpoint), + LinkLevel(LinkLevelEndpoint), +} + +/// Common methods that a socket must have pub trait Socket: Send + Sync { - fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint); - fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult; + fn read(&self, data: &mut [u8]) -> (SysResult, Endpoint); + fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult; fn poll(&self) -> (bool, bool, bool); // (in, out, err) - fn connect(&mut self, endpoint: IpEndpoint) -> SysResult; - fn bind(&mut self, endpoint: IpEndpoint) -> SysResult { + fn connect(&mut self, endpoint: Endpoint) -> SysResult; + fn bind(&mut self, endpoint: Endpoint) -> SysResult { Err(SysError::EINVAL) } fn listen(&mut self) -> SysResult { @@ -22,13 +31,13 @@ pub trait Socket: Send + Sync { fn shutdown(&self) -> SysResult { Err(SysError::EINVAL) } - fn accept(&mut self) -> Result<(Box, IpEndpoint), SysError> { + fn accept(&mut self) -> Result<(Box, Endpoint), SysError> { Err(SysError::EINVAL) } - fn endpoint(&self) -> Option { + fn endpoint(&self) -> Option { None } - fn remote_endpoint(&self) -> Option { + fn remote_endpoint(&self) -> Option { None } fn setsockopt(&mut self, level: usize, opt: usize, data: &[u8]) -> SysResult { @@ -72,6 +81,12 @@ pub struct RawSocketState { header_included: bool, } +#[derive(Debug, Clone)] +pub struct PacketSocketState { + // no state +// only ethernet egress +} + /// A wrapper for `SocketHandle`. /// Auto increase and decrease reference count on Clone and Drop. #[derive(Debug)] @@ -112,7 +127,7 @@ impl TcpSocketState { } impl Socket for TcpSocketState { - fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) { + fn read(&self, data: &mut [u8]) -> (SysResult, Endpoint) { spin_and_wait(&[&SOCKET_ACTIVITY], move || { poll_ifaces(); let mut sockets = SOCKETS.lock(); @@ -127,17 +142,20 @@ impl Socket for TcpSocketState { drop(sockets); poll_ifaces(); - return Some((Ok(size), endpoint)); + return Some((Ok(size), Endpoint::Ip(endpoint))); } } } else { - return Some((Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED)); + return Some(( + Err(SysError::ENOTCONN), + Endpoint::Ip(IpEndpoint::UNSPECIFIED), + )); } None }) } - fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { + fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(self.handle.0); @@ -183,52 +201,60 @@ impl Socket for TcpSocketState { (input, output, err) } - fn connect(&mut self, endpoint: IpEndpoint) -> SysResult { + fn connect(&mut self, endpoint: Endpoint) -> SysResult { let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(self.handle.0); - let temp_port = get_ephemeral_port(); + if let Endpoint::Ip(ip) = endpoint { + let temp_port = get_ephemeral_port(); - match socket.connect(endpoint, temp_port) { - Ok(()) => { - // avoid deadlock - drop(socket); - drop(sockets); + match socket.connect(ip, temp_port) { + Ok(()) => { + // avoid deadlock + drop(socket); + drop(sockets); - // wait for connection result - loop { - poll_ifaces(); + // wait for connection result + loop { + poll_ifaces(); - let mut sockets = SOCKETS.lock(); - let socket = sockets.get::(self.handle.0); - match socket.state() { - TcpState::SynSent => { - // still connecting - drop(socket); - drop(sockets); - debug!("poll for connection wait"); - SOCKET_ACTIVITY._wait(); - } - TcpState::Established => { - break Ok(0); - } - _ => { - break Err(SysError::ECONNREFUSED); + let mut sockets = SOCKETS.lock(); + let socket = sockets.get::(self.handle.0); + match socket.state() { + TcpState::SynSent => { + // still connecting + drop(socket); + drop(sockets); + debug!("poll for connection wait"); + SOCKET_ACTIVITY._wait(); + } + TcpState::Established => { + break Ok(0); + } + _ => { + break Err(SysError::ECONNREFUSED); + } } } } + Err(_) => Err(SysError::ENOBUFS), } - Err(_) => Err(SysError::ENOBUFS), + } else { + Err(SysError::EINVAL) } } - fn bind(&mut self, mut endpoint: IpEndpoint) -> SysResult { - if endpoint.port == 0 { - endpoint.port = get_ephemeral_port(); + fn bind(&mut self, mut endpoint: Endpoint) -> SysResult { + if let Endpoint::Ip(mut ip) = endpoint { + if ip.port == 0 { + ip.port = get_ephemeral_port(); + } + self.local_endpoint = Some(ip); + self.is_listening = false; + Ok(0) + } else { + Err(SysError::EINVAL) } - self.local_endpoint = Some(endpoint); - self.is_listening = false; - Ok(0) } fn listen(&mut self) -> SysResult { @@ -260,7 +286,7 @@ impl Socket for TcpSocketState { Ok(0) } - fn accept(&mut self) -> Result<(Box, IpEndpoint), SysError> { + fn accept(&mut self) -> Result<(Box, Endpoint), SysError> { let endpoint = self.local_endpoint.ok_or(SysError::EINVAL)?; loop { let mut sockets = SOCKETS.lock(); @@ -287,7 +313,7 @@ impl Socket for TcpSocketState { drop(sockets); poll_ifaces(); - return Ok((new_socket, remote_endpoint)); + return Ok((new_socket, Endpoint::Ip(remote_endpoint))); } // avoid deadlock @@ -297,24 +323,27 @@ impl Socket for TcpSocketState { } } - fn endpoint(&self) -> Option { - self.local_endpoint.clone().or_else(|| { - let mut sockets = SOCKETS.lock(); - let socket = sockets.get::(self.handle.0); - let endpoint = socket.local_endpoint(); - if endpoint.port != 0 { - Some(endpoint) - } else { - None - } - }) + fn endpoint(&self) -> Option { + self.local_endpoint + .clone() + .map(|e| Endpoint::Ip(e)) + .or_else(|| { + let mut sockets = SOCKETS.lock(); + let socket = sockets.get::(self.handle.0); + let endpoint = socket.local_endpoint(); + if endpoint.port != 0 { + Some(Endpoint::Ip(endpoint)) + } else { + None + } + }) } - fn remote_endpoint(&self) -> Option { + fn remote_endpoint(&self) -> Option { let mut sockets = SOCKETS.lock(); let socket = sockets.get::(self.handle.0); if socket.is_open() { - Some(socket.remote_endpoint()) + Some(Endpoint::Ip(socket.remote_endpoint())) } else { None } @@ -346,7 +375,7 @@ impl UdpSocketState { } impl Socket for UdpSocketState { - fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) { + fn read(&self, data: &mut [u8]) -> (SysResult, Endpoint) { loop { let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(self.handle.0); @@ -359,10 +388,13 @@ impl Socket for UdpSocketState { drop(sockets); poll_ifaces(); - return (Ok(size), endpoint); + return (Ok(size), Endpoint::Ip(endpoint)); } } else { - return (Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED); + return ( + Err(SysError::ENOTCONN), + Endpoint::Ip(IpEndpoint::UNSPECIFIED), + ); } // avoid deadlock @@ -371,9 +403,9 @@ impl Socket for UdpSocketState { } } - fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { + fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { let remote_endpoint = { - if let Some(ref endpoint) = sendto_endpoint { + if let Some(Endpoint::Ip(ref endpoint)) = sendto_endpoint { endpoint } else if let Some(ref endpoint) = self.remote_endpoint { endpoint @@ -422,33 +454,41 @@ impl Socket for UdpSocketState { (input, output, err) } - fn connect(&mut self, endpoint: IpEndpoint) -> SysResult { - self.remote_endpoint = Some(endpoint); - Ok(0) + fn connect(&mut self, endpoint: Endpoint) -> SysResult { + if let Endpoint::Ip(ip) = endpoint { + self.remote_endpoint = Some(ip); + Ok(0) + } else { + Err(SysError::EINVAL) + } } - fn bind(&mut self, endpoint: IpEndpoint) -> SysResult { + fn bind(&mut self, endpoint: Endpoint) -> SysResult { let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(self.handle.0); - match socket.bind(endpoint) { - Ok(()) => Ok(0), - Err(_) => Err(SysError::EINVAL), + if let Endpoint::Ip(ip) = endpoint { + match socket.bind(ip) { + Ok(()) => Ok(0), + Err(_) => Err(SysError::EINVAL), + } + } else { + Err(SysError::EINVAL) } } - fn endpoint(&self) -> Option { + fn endpoint(&self) -> Option { let mut sockets = SOCKETS.lock(); let socket = sockets.get::(self.handle.0); let endpoint = socket.endpoint(); if endpoint.port != 0 { - Some(endpoint) + Some(Endpoint::Ip(endpoint)) } else { None } } - fn remote_endpoint(&self) -> Option { - self.remote_endpoint.clone() + fn remote_endpoint(&self) -> Option { + self.remote_endpoint.clone().map(|e| Endpoint::Ip(e)) } fn box_clone(&self) -> Box { @@ -482,7 +522,7 @@ impl RawSocketState { } impl Socket for RawSocketState { - fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) { + fn read(&self, data: &mut [u8]) -> (SysResult, Endpoint) { loop { let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(self.handle.0); @@ -492,10 +532,10 @@ impl Socket for RawSocketState { return ( Ok(size), - IpEndpoint { + Endpoint::Ip(IpEndpoint { addr: IpAddress::Ipv4(packet.src_addr()), port: 0, - }, + }), ); } @@ -506,7 +546,7 @@ impl Socket for RawSocketState { } } - fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { + fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { if self.header_included { let mut sockets = SOCKETS.lock(); let mut socket = sockets.get::(self.handle.0); @@ -516,7 +556,7 @@ impl Socket for RawSocketState { Err(_) => Err(SysError::ENOBUFS), } } else { - if let Some(endpoint) = sendto_endpoint { + if let Some(Endpoint::Ip(endpoint)) = sendto_endpoint { // temporary solution let iface = &*(NET_DRIVERS.read()[0]); let v4_src = iface.ipv4_address().unwrap(); @@ -559,7 +599,7 @@ impl Socket for RawSocketState { unimplemented!() } - fn connect(&mut self, _endpoint: IpEndpoint) -> SysResult { + fn connect(&mut self, _endpoint: Endpoint) -> SysResult { unimplemented!() } @@ -581,6 +621,38 @@ impl Socket for RawSocketState { } } +impl PacketSocketState { + pub fn new() -> Self { + PacketSocketState {} + } +} + +impl Socket for PacketSocketState { + fn read(&self, data: &mut [u8]) -> (SysResult, Endpoint) { + unimplemented!() + } + + fn write(&self, data: &[u8], sendto_endpoint: Option) -> SysResult { + if let Some(endpoint) = sendto_endpoint { + unimplemented!() + } else { + Err(SysError::ENOTCONN) + } + } + + fn poll(&self) -> (bool, bool, bool) { + unimplemented!() + } + + fn connect(&mut self, _endpoint: Endpoint) -> SysResult { + unimplemented!() + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } +} + fn get_ephemeral_port() -> u16 { // TODO selects non-conflict high port static mut EPHEMERAL_PORT: u16 = 0; diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index fd303f0..71aca45 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -3,7 +3,9 @@ use super::*; use crate::drivers::SOCKET_ACTIVITY; use crate::fs::FileLike; -use crate::net::{RawSocketState, Socket, TcpSocketState, UdpSocketState, SOCKETS}; +use crate::net::{ + Endpoint, PacketSocketState, RawSocketState, Socket, TcpSocketState, UdpSocketState, SOCKETS, +}; use crate::sync::{MutexGuard, SpinNoIrq, SpinNoIrqLock as Mutex}; use alloc::boxed::Box; use core::cmp::min; @@ -11,16 +13,23 @@ use core::mem::size_of; use smoltcp::wire::*; pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResult { + let domain = AddressFamily::from(domain as u16); info!( - "socket: domain: {}, socket_type: {}, protocol: {}", + "socket: domain: {:?}, socket_type: {}, protocol: {}", domain, socket_type, protocol ); let mut proc = process(); let socket: Box = match domain { - AF_INET | AF_UNIX => match socket_type & SOCK_TYPE_MASK { - SOCK_STREAM => Box::new(TcpSocketState::new()), - SOCK_DGRAM => Box::new(UdpSocketState::new()), - SOCK_RAW => Box::new(RawSocketState::new(protocol as u8)), + AddressFamily::Internet | AddressFamily::Unix => { + match SocketType::from(socket_type as u8 & SOCK_TYPE_MASK) { + SocketType::Stream => Box::new(TcpSocketState::new()), + SocketType::Datagram => Box::new(UdpSocketState::new()), + SocketType::Raw => Box::new(RawSocketState::new(protocol as u8)), + _ => return Err(SysError::EINVAL), + } + } + AddressFamily::Packet => match SocketType::from(socket_type as u8 & SOCK_TYPE_MASK) { + SocketType::Raw => Box::new(PacketSocketState::new()), _ => return Err(SysError::EINVAL), }, _ => return Err(SysError::EAFNOSUPPORT), @@ -165,7 +174,7 @@ pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult let mut proc = process(); let mut endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?; - info!("sys_bind: fd: {} bind to {}", fd, endpoint); + info!("sys_bind: fd: {} bind to {:?}", fd, endpoint); let socket = proc.get_socket(fd)?; socket.bind(endpoint) @@ -279,10 +288,21 @@ pub struct SockAddrUn { sun_path: [u8; 108], } +#[repr(C)] +pub struct SockAddrLl { + sll_protocol: u16, + sll_ifindex: u32, + sll_hatype: u16, + sll_pkttype: u8, + sll_halen: u8, + sll_addr: u8, +} + #[repr(C)] pub union SockAddrPayload { addr_in: SockAddrIn, addr_un: SockAddrUn, + addr_ll: SockAddrLl, } #[repr(C)] @@ -291,20 +311,24 @@ pub struct SockAddr { payload: SockAddrPayload, } -impl From for SockAddr { - fn from(endpoint: IpEndpoint) -> Self { - match endpoint.addr { - IpAddress::Ipv4(ipv4) => SockAddr { - family: AF_INET as u16, - payload: SockAddrPayload { - addr_in: SockAddrIn { - sin_port: u16::to_be(endpoint.port), - sin_addr: u32::to_be(u32::from_be_bytes(ipv4.0)), - sin_zero: [0; 8], +impl From for SockAddr { + fn from(endpoint: Endpoint) -> Self { + if let Endpoint::Ip(ip) = endpoint { + match ip.addr { + IpAddress::Ipv4(ipv4) => SockAddr { + family: AddressFamily::Internet.into(), + payload: SockAddrPayload { + addr_in: SockAddrIn { + sin_port: u16::to_be(ip.port), + sin_addr: u32::to_be(u32::from_be_bytes(ipv4.0)), + sin_zero: [0; 8], + }, }, }, - }, - _ => unimplemented!("ipv6"), + _ => unimplemented!("only ipv4"), + } + } else { + unimplemented!("only ip"); } } } @@ -315,14 +339,14 @@ fn sockaddr_to_endpoint( proc: &mut Process, addr: *const SockAddr, len: usize, -) -> Result { +) -> Result { if len < size_of::() { return Err(SysError::EINVAL); } proc.vm.check_read_array(addr as *const u8, len)?; unsafe { - match (*addr).family as usize { - AF_INET => { + match AddressFamily::from((*addr).family) { + AddressFamily::Internet => { if len < size_of::() + size_of::() { return Err(SysError::EINVAL); } @@ -330,9 +354,15 @@ fn sockaddr_to_endpoint( let addr = IpAddress::from(Ipv4Address::from_bytes( &u32::from_be((*addr).payload.addr_in.sin_addr).to_be_bytes()[..], )); - Ok((addr, port).into()) + Ok(Endpoint::Ip((addr, port).into())) + } + AddressFamily::Unix => Err(SysError::EINVAL), + AddressFamily::Packet => { + if len < size_of::() + size_of::() { + return Err(SysError::EINVAL); + } + unimplemented!() } - AF_UNIX => Err(SysError::EINVAL), _ => Err(SysError::EINVAL), } } @@ -354,9 +384,9 @@ impl SockAddr { proc.vm.check_write_ptr(addr_len)?; let max_addr_len = *addr_len as usize; - let full_len = match self.family as usize { - AF_INET => size_of::() + size_of::(), - AF_UNIX => return Err(SysError::EINVAL), + let full_len = match AddressFamily::from(self.family) { + AddressFamily::Internet => size_of::() + size_of::(), + AddressFamily::Unix => return Err(SysError::EINVAL), _ => return Err(SysError::EINVAL), }; @@ -372,13 +402,33 @@ impl SockAddr { } } -const AF_UNIX: usize = 1; -const AF_INET: usize = 2; +enum_with_unknown! { + /// Address families + pub doc enum AddressFamily(u16) { + /// Unspecified + Unspecified = 0, + /// Unix domain sockets + Unix = 1, + /// Internet IP Protocol + Internet = 2, + /// Packet family + Packet = 17, + } +} -const SOCK_STREAM: usize = 1; -const SOCK_DGRAM: usize = 2; -const SOCK_RAW: usize = 3; -const SOCK_TYPE_MASK: usize = 0xf; +const SOCK_TYPE_MASK: u8 = 0xf; + +enum_with_unknown! { + /// Socket types + pub doc enum SocketType(u8) { + /// Stream + Stream = 1, + /// Datagram + Datagram = 2, + /// Raw + Raw = 3, + } +} const IPPROTO_IP: usize = 0; const IPPROTO_ICMP: usize = 1; diff --git a/kernel/src/util/mod.rs b/kernel/src/util/mod.rs index 0c6df60..582f101 100644 --- a/kernel/src/util/mod.rs +++ b/kernel/src/util/mod.rs @@ -13,3 +13,58 @@ pub unsafe fn write_cstr(ptr: *mut u8, s: &str) { ptr.copy_from(s.as_ptr(), s.len()); ptr.add(s.len()).write(0); } + +// Taken from m-labs/smoltcp src/macros.rs, thanks for their contribution +// https://github.com/m-labs/smoltcp/blob/master/src/macros.rs +macro_rules! enum_with_unknown { + ( + $( #[$enum_attr:meta] )* + pub enum $name:ident($ty:ty) { + $( $variant:ident = $value:expr ),+ $(,)* + } + ) => { + enum_with_unknown! { + $( #[$enum_attr] )* + pub doc enum $name($ty) { + $( #[doc(shown)] $variant = $value ),+ + } + } + }; + ( + $( #[$enum_attr:meta] )* + pub doc enum $name:ident($ty:ty) { + $( + $( #[$variant_attr:meta] )+ + $variant:ident = $value:expr $(,)* + ),+ + } + ) => { + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] + $( #[$enum_attr] )* + pub enum $name { + $( + $( #[$variant_attr] )* + $variant + ),*, + Unknown($ty) + } + + impl ::core::convert::From<$ty> for $name { + fn from(value: $ty) -> Self { + match value { + $( $value => $name::$variant ),*, + other => $name::Unknown(other) + } + } + } + + impl ::core::convert::From<$name> for $ty { + fn from(value: $name) -> Self { + match value { + $( $name::$variant => $value ),*, + $name::Unknown(other) => other + } + } + } + } +}