diff --git a/kernel/src/drivers/mod.rs b/kernel/src/drivers/mod.rs index 4e48c6f..b2dd301 100644 --- a/kernel/src/drivers/mod.rs +++ b/kernel/src/drivers/mod.rs @@ -2,7 +2,7 @@ use alloc::prelude::*; use core::any::Any; use lazy_static::lazy_static; -use smoltcp::wire::EthernetAddress; +use smoltcp::wire::{EthernetAddress, Ipv4Address}; use smoltcp::socket::SocketSet; use crate::sync::SpinNoIrqLock; @@ -37,6 +37,10 @@ pub trait NetDriver : Send { // get interface name for this device fn get_ifname(&self) -> String; + // get ipv4 address + fn ipv4_address(&self) -> Option; + + // poll for sockets fn poll(&mut self, socket: &mut SocketSet) -> Option; } diff --git a/kernel/src/drivers/net/e1000.rs b/kernel/src/drivers/net/e1000.rs index d7cb399..cf61b64 100644 --- a/kernel/src/drivers/net/e1000.rs +++ b/kernel/src/drivers/net/e1000.rs @@ -143,6 +143,10 @@ impl NetDriver for E1000Interface { } } } + + fn ipv4_address(&self) -> Option { + self.iface.ipv4_address() + } } #[repr(C)] diff --git a/kernel/src/drivers/net/virtio_net.rs b/kernel/src/drivers/net/virtio_net.rs index 8ec7104..850fd43 100644 --- a/kernel/src/drivers/net/virtio_net.rs +++ b/kernel/src/drivers/net/virtio_net.rs @@ -14,7 +14,7 @@ use rcore_memory::paging::PageTable; use smoltcp::phy::{self, DeviceCapabilities}; use smoltcp::Result; use smoltcp::time::Instant; -use smoltcp::wire::EthernetAddress; +use smoltcp::wire::{EthernetAddress, Ipv4Address}; use smoltcp::socket::SocketSet; use volatile::{ReadOnly, Volatile}; @@ -88,6 +88,10 @@ impl NetDriver for VirtIONetDriver { fn poll(&mut self, sockets: &mut SocketSet) -> Option { unimplemented!() } + + fn ipv4_address(&self) -> Option { + unimplemented!() + } } pub struct VirtIONetRxToken(VirtIONetDriver); diff --git a/kernel/src/process/structs.rs b/kernel/src/process/structs.rs index 230469e..6218338 100644 --- a/kernel/src/process/structs.rs +++ b/kernel/src/process/structs.rs @@ -19,10 +19,24 @@ pub struct Thread { pub proc: Arc>, } +#[derive(Clone)] +pub enum SocketType { + Raw, + Tcp, + Udp, + Icmp +} + +#[derive(Clone)] +pub struct SocketWrapper { + pub handle: SocketHandle, + pub socket_type: SocketType, +} + #[derive(Clone)] pub enum FileLike { File(FileHandle), - Socket(SocketHandle) + Socket(SocketWrapper) } pub struct Process { diff --git a/kernel/src/syscall/fs.rs b/kernel/src/syscall/fs.rs index 27d0552..d08c3b9 100644 --- a/kernel/src/syscall/fs.rs +++ b/kernel/src/syscall/fs.rs @@ -110,7 +110,7 @@ pub fn sys_close(fd: usize) -> SysResult { let mut proc = process(); match proc.files.remove(&fd) { Some(FileLike::File(_)) => Ok(0), - Some(FileLike::Socket(handle)) => sys_close_socket(&mut proc, fd, handle), + Some(FileLike::Socket(wrapper)) => sys_close_socket(&mut proc, fd, wrapper.handle), None => Err(SysError::EINVAL), } } diff --git a/kernel/src/syscall/mod.rs b/kernel/src/syscall/mod.rs index 5843366..f5a9440 100644 --- a/kernel/src/syscall/mod.rs +++ b/kernel/src/syscall/mod.rs @@ -55,7 +55,7 @@ pub fn syscall(id: usize, args: [usize; 6], tf: &mut TrapFrame) -> isize { 042 => sys_connect(args[0], args[1] as *const u8, args[2]), // 043 => sys_accept(), 044 => sys_sendto(args[0], args[1] as *const u8, args[2], args[3], args[4] as *const u8, args[5]), - 045 => sys_recvfrom(args[0], args[1] as *mut u8, args[2], args[3], args[4] as *const u8, args[5]), + 045 => sys_recvfrom(args[0], args[1] as *mut u8, args[2], args[3], args[4] as *mut u8, args[5] as *mut usize), // 046 => sys_sendmsg(), // 047 => sys_recvmsg(), // 048 => sys_shutdown(), diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index b568fcf..1529482 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -15,6 +15,34 @@ const SOCK_RAW: usize = 3; const IPPROTO_IP: usize = 0; const IPPROTO_ICMP: usize = 1; +fn parse_addr(sockaddr_in: &SockaddrIn, dest: &mut Option, port: &mut u16) { + if sockaddr_in.sin_family == AF_INET as u16 { + *port = u16::from_be(sockaddr_in.sin_port); + let addr = u32::from_be(sockaddr_in.sin_addr); + *dest = Some(IpAddress::v4( + (addr >> 24) as u8, + ((addr >> 16) & 0xFF) as u8, + ((addr >> 8) & 0xFF) as u8, + (addr & 0xFF) as u8, + )); + } +} + +fn fill_addr(sockaddr_in: &mut SockaddrIn, dest: IpAddress, port: u16) { + if let IpAddress::Ipv4(ipv4) = dest { + sockaddr_in.sin_family = AF_INET as u16; + sockaddr_in.sin_port = u16::to_be(port); + sockaddr_in.sin_addr = u32::to_be( + ((ipv4.0[0] as u32) << 24) + | ((ipv4.0[1] as u32) << 16) + | ((ipv4.0[2] as u32) << 8) + | ipv4.0[3] as u32, + ); + } else { + unimplemented!("ipv6"); + } +} + pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResult { info!( "socket: domain: {}, socket_type: {}, protocol: {}", @@ -31,7 +59,13 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); let tcp_handle = proc.sockets.add(tcp_socket); - proc.files.insert(fd, FileLike::Socket(tcp_handle)); + proc.files.insert( + fd, + FileLike::Socket(SocketWrapper { + handle: tcp_handle, + socket_type: SocketType::Tcp, + }), + ); Ok(fd as isize) } @@ -50,7 +84,13 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu ); let raw_handle = proc.sockets.add(raw_socket); - proc.files.insert(fd, FileLike::Socket(raw_handle)); + proc.files.insert( + fd, + FileLike::Socket(SocketWrapper { + handle: raw_handle, + socket_type: SocketType::Raw, + }), + ); Ok(fd as isize) } _ => Err(SysError::EINVAL), @@ -83,10 +123,10 @@ struct SockaddrIn { } impl Process { - fn get_handle(&mut self, fd: usize) -> Result { + fn get_socket(&mut self, fd: usize) -> Result { let file = self.files.get_mut(&fd).ok_or(SysError::EBADF)?; match file { - FileLike::Socket(handle) => Ok(handle.clone()), + FileLike::Socket(wrapper) => Ok(wrapper.clone()), _ => Err(SysError::ENOTSOCK), } } @@ -98,18 +138,17 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { fd, addr, addrlen ); + let mut proc = process(); + if !proc.memory_set.check_ptr(addr) { + return Err(SysError::EFAULT); + } + let mut dest = None; let mut port = 0; - if addrlen == size_of::() { - let sockaddr_in = unsafe { &*(addr as *const SockaddrIn) }; - port = ((sockaddr_in.sin_port & 0xFF) << 8) | (sockaddr_in.sin_port >> 8); - dest = Some(IpAddress::v4( - (sockaddr_in.sin_addr & 0xFF) as u8, - ((sockaddr_in.sin_addr >> 8) & 0xFF) as u8, - ((sockaddr_in.sin_addr >> 16) & 0xFF) as u8, - (sockaddr_in.sin_addr >> 24) as u8, - )); - } + + // FIXME: check size as per sin_family + let sockaddr_in = unsafe { &*(addr as *const SockaddrIn) }; + parse_addr(&sockaddr_in, &mut dest, &mut port); if dest == None { return Err(SysError::EINVAL); @@ -120,52 +159,53 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { let iface = &mut *NET_DRIVERS.lock()[0]; iface.poll(&mut proc.sockets); - // TODO: check its type - let tcp_handle = proc.get_handle(fd)?; - let mut socket = proc.sockets.get::(tcp_handle); + let wrapper = proc.get_socket(fd)?; + if let SocketType::Tcp = wrapper.socket_type { + let mut socket = proc.sockets.get::(wrapper.handle); - // TODO selects non-conflict high port - static mut EPHEMERAL_PORT: u16 = 49152; - let temp_port = unsafe { - if EPHEMERAL_PORT == 65535 { - EPHEMERAL_PORT = 49152; - } else { - EPHEMERAL_PORT = EPHEMERAL_PORT + 1; - } - EPHEMERAL_PORT - }; + // TODO selects non-conflict high port + static mut EPHEMERAL_PORT: u16 = 49152; + let temp_port = unsafe { + if EPHEMERAL_PORT == 65535 { + EPHEMERAL_PORT = 49152; + } else { + EPHEMERAL_PORT = EPHEMERAL_PORT + 1; + } + EPHEMERAL_PORT + }; - match socket.connect((dest.unwrap(), port), temp_port) { - Ok(()) => Ok(0), - Err(_) => Err(SysError::EISCONN), + match socket.connect((dest.unwrap(), port), temp_port) { + Ok(()) => Ok(0), + Err(_) => Err(SysError::EISCONN), + } + } else { + unimplemented!("socket type") } } -pub fn sys_write_socket( - proc: &mut Process, - fd: usize, - base: *const u8, - len: usize, -) -> SysResult { +pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult { // little hack: kick it forward let iface = &mut *NET_DRIVERS.lock()[0]; iface.poll(&mut proc.sockets); - // TODO: check its type - let tcp_handle = proc.get_handle(fd)?; - let mut socket = proc.sockets.get::(tcp_handle); - let slice = unsafe { slice::from_raw_parts(base, len) }; - if socket.is_open() { - if socket.can_send() { - match socket.send_slice(&slice) { - Ok(size) => Ok(size as isize), - Err(err) => Err(SysError::ENOBUFS) + let wrapper = proc.get_socket(fd)?; + if let SocketType::Tcp = wrapper.socket_type { + let mut socket = proc.sockets.get::(wrapper.handle); + let slice = unsafe { slice::from_raw_parts(base, len) }; + if socket.is_open() { + if socket.can_send() { + match socket.send_slice(&slice) { + Ok(size) => Ok(size as isize), + Err(err) => Err(SysError::ENOBUFS), + } + } else { + Err(SysError::ENOBUFS) } } else { - Err(SysError::ENOBUFS) + Err(SysError::ECONNREFUSED) } } else { - Err(SysError::ECONNREFUSED) + unimplemented!("socket type") } } @@ -189,9 +229,60 @@ pub fn sys_sendto( addr: *const u8, addr_len: usize, ) -> SysResult { - info!("sys_sendto: fd: {} buffer: {:?} len: {}", fd, buffer, len); - warn!("sys_sendto is unimplemented"); - Err(SysError::EINVAL) + info!( + "sys_sendto: fd: {} buffer: {:?} len: {} addr: {:?} addr_len: {}", + fd, buffer, len, addr, addr_len + ); + let mut proc = process(); + if !proc.memory_set.check_ptr(addr) { + return Err(SysError::EFAULT); + } + + if !proc.memory_set.check_array(buffer, len) { + return Err(SysError::EINVAL); + } + + // little hack: kick it forward + let iface = &mut *NET_DRIVERS.lock()[0]; + iface.poll(&mut proc.sockets); + + let wrapper = proc.get_socket(fd)?; + if let SocketType::Raw = wrapper.socket_type { + let mut socket = proc.sockets.get::(wrapper.handle); + + let mut dest = None; + let mut port = 0; + + // FIXME: check size as per sin_family + let sockaddr_in = unsafe { &*(addr as *const SockaddrIn) }; + parse_addr(&sockaddr_in, &mut dest, &mut port); + + if dest == None { + return Err(SysError::EINVAL); + } else if let Some(IpAddress::Ipv4(v4_dest)) = dest { + let slice = unsafe { slice::from_raw_parts(buffer, len) }; + // using 20-byte IPv4 header + let mut buffer = vec![0u8; len + 20]; + let mut packet = Ipv4Packet::new_unchecked(&mut buffer); + packet.set_version(4); + packet.set_header_len(20); + packet.set_total_len((20 + len) as u16); + packet.set_protocol(socket.ip_protocol().into()); + packet.set_src_addr(iface.ipv4_address().unwrap()); + packet.set_dst_addr(v4_dest); + let payload = packet.payload_mut(); + payload.copy_from_slice(slice); + packet.fill_checksum(); + + socket.send_slice(&buffer).unwrap(); + + Ok(len as isize) + } else { + unimplemented!("ip type") + } + } else { + unimplemented!("socket type") + } } pub fn sys_recvfrom( @@ -199,12 +290,44 @@ pub fn sys_recvfrom( buffer: *mut u8, len: usize, flags: usize, - addr: *const u8, - addr_len: usize, + addr: *mut u8, + addr_len: *mut usize, ) -> SysResult { - info!("sys_recvfrom: fd: {} buffer: {:?} len: {}", fd, buffer, len); - warn!("sys_recvfrom is unimplemented"); - Err(SysError::EINVAL) + info!( + "sys_recvfrom: fd: {} buffer: {:?} len: {} flags: {} addr: {:?} addr_len: {:?}", + fd, buffer, len, flags, addr, addr_len + ); + let mut proc = process(); + + // little hack: kick it forward + let iface = &mut *NET_DRIVERS.lock()[0]; + iface.poll(&mut proc.sockets); + + let wrapper = proc.get_socket(fd)?; + if let SocketType::Raw = wrapper.socket_type { + let mut socket = proc.sockets.get::(wrapper.handle); + + + let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) }; + match socket.recv_slice(&mut slice) { + Ok(size) => { + let mut packet = Ipv4Packet::new_unchecked(&slice); + + // FIXME: check size as per sin_family + let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) }; + fill_addr(&mut sockaddr_in, IpAddress::Ipv4(packet.src_addr()), 0); + unsafe { *addr_len = size_of::() }; + + Ok(size as isize) + } + Err(err) => { + warn!("err {:?}", err); + Err(SysError::ENOBUFS) + } + } + } else { + unimplemented!("socket type") + } } pub fn sys_close_socket(proc: &mut Process, fd: usize, handle: SocketHandle) -> SysResult { @@ -217,4 +340,4 @@ pub fn sys_close_socket(proc: &mut Process, fd: usize, handle: SocketHandle) -> } Ok(0) -} \ No newline at end of file +}