diff --git a/kernel/src/syscall/fs.rs b/kernel/src/syscall/fs.rs index dd20e26..580cd11 100644 --- a/kernel/src/syscall/fs.rs +++ b/kernel/src/syscall/fs.rs @@ -12,9 +12,11 @@ pub fn sys_read(fd: usize, base: *mut u8, len: usize) -> SysResult { info!("read: fd: {}, base: {:?}, len: {:#x}", fd, base, len); let mut proc = process(); proc.memory_set.check_mut_array(base, len)?; - let slice = unsafe { slice::from_raw_parts_mut(base, len) }; - let len = proc.get_file(fd)?.read(slice)?; - Ok(len as isize) + match proc.files.get(&fd) { + Some(FileLike::File(_)) => sys_read_file(&mut proc, fd, base, len), + Some(FileLike::Socket(_)) => sys_read_socket(&mut proc, fd, base, len), + None => Err(SysError::EINVAL) + } } pub fn sys_write(fd: usize, base: *const u8, len: usize) -> SysResult { @@ -28,12 +30,37 @@ pub fn sys_write(fd: usize, base: *const u8, len: usize) -> SysResult { } } +pub fn sys_read_file(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult { + let slice = unsafe { slice::from_raw_parts_mut(base, len) }; + let len = proc.get_file(fd)?.read(slice)?; + Ok(len as isize) +} + pub fn sys_write_file(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult { let slice = unsafe { slice::from_raw_parts(base, len) }; let len = proc.get_file(fd)?.write(slice)?; Ok(len as isize) } +#[repr(C)] +pub struct PollFd { + fd: u32, + events: u16, + revents: u16 +} + +pub fn sys_poll(ufds: *mut PollFd, nfds: usize, timeout_msecs: usize) -> SysResult { + info!("poll: ufds: {:?}, nfds: {}, timeout_msecs: {:#x}", ufds, nfds, timeout_msecs); + let mut proc = process(); + proc.memory_set.check_mut_array(ufds, nfds)?; + let slice = unsafe { slice::from_raw_parts_mut(ufds, nfds) }; + + // emulate it for now + use core::time::Duration; + thread::sleep(Duration::from_millis(timeout_msecs as u64)); + Ok(nfds as isize) +} + pub fn sys_readv(fd: usize, iov_ptr: *const IoVec, iov_count: usize) -> SysResult { info!("readv: fd: {}, iov: {:?}, count: {}", fd, iov_ptr, iov_count); let mut proc = process(); diff --git a/kernel/src/syscall/mod.rs b/kernel/src/syscall/mod.rs index d4caf88..4ebff9b 100644 --- a/kernel/src/syscall/mod.rs +++ b/kernel/src/syscall/mod.rs @@ -39,7 +39,7 @@ pub fn syscall(id: usize, args: [usize; 6], tf: &mut TrapFrame) -> isize { 004 => sys_stat(args[0] as *const u8, args[1] as *mut Stat), 005 => sys_fstat(args[0], args[1] as *mut Stat), 006 => sys_lstat(args[0] as *const u8, args[1] as *mut Stat), -// 007 => sys_poll(), + 007 => sys_poll(args[0] as *mut PollFd, args[1], args[2]), 008 => sys_lseek(args[0], args[1] as i64, args[2] as u8), 009 => sys_mmap(args[0], args[1], args[2], args[3], args[4] as i32, args[5]), 011 => sys_munmap(args[0], args[1]), @@ -55,7 +55,7 @@ pub fn syscall(id: usize, args: [usize; 6], tf: &mut TrapFrame) -> isize { 042 => sys_connect(args[0], args[1] as *const u8, args[2]), // 043 => sys_accept(), 044 => sys_sendto(args[0], args[1] as *const u8, args[2], args[3], args[4] as *const u8, args[5]), - 045 => sys_recvfrom(args[0], args[1] as *mut u8, args[2], args[3], args[4] as *mut u8, args[5] as *mut usize), + 045 => sys_recvfrom(args[0], args[1] as *mut u8, args[2], args[3], args[4] as *mut u8, args[5] as *mut u32), // 046 => sys_sendmsg(), // 047 => sys_recvmsg(), // 048 => sys_shutdown(), diff --git a/kernel/src/syscall/net.rs b/kernel/src/syscall/net.rs index 586e8f2..9c843d9 100644 --- a/kernel/src/syscall/net.rs +++ b/kernel/src/syscall/net.rs @@ -15,6 +15,19 @@ const SOCK_RAW: usize = 3; const IPPROTO_IP: usize = 0; const IPPROTO_ICMP: usize = 1; +fn get_ephemeral_port() -> u16 { + // TODO selects non-conflict high port + static mut EPHEMERAL_PORT: u16 = 49152; + unsafe { + if EPHEMERAL_PORT == 65535 { + EPHEMERAL_PORT = 49152; + } else { + EPHEMERAL_PORT = EPHEMERAL_PORT + 1; + } + EPHEMERAL_PORT + } +} + fn parse_addr(sockaddr_in: &SockaddrIn, dest: &mut Option, port: &mut u16) { if sockaddr_in.sin_family == AF_INET as u16 { *port = u16::from_be(sockaddr_in.sin_port); @@ -70,6 +83,26 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu Ok(fd as isize) } + SOCK_DGRAM => { + let fd = proc.get_free_inode(); + + let udp_rx_buffer = + UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 2048]); + let udp_tx_buffer = + UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 2048]); + let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + + let udp_handle = iface.sockets().add(udp_socket); + proc.files.insert( + fd, + FileLike::Socket(SocketWrapper { + handle: udp_handle, + socket_type: SocketType::Udp, + }), + ); + + Ok(fd as isize) + } SOCK_RAW => { let fd = proc.get_free_inode(); @@ -159,16 +192,7 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { let mut sockets = iface.sockets(); let mut socket = sockets.get::(wrapper.handle); - // TODO selects non-conflict high port - static mut EPHEMERAL_PORT: u16 = 49152; - let temp_port = unsafe { - if EPHEMERAL_PORT == 65535 { - EPHEMERAL_PORT = 49152; - } else { - EPHEMERAL_PORT = EPHEMERAL_PORT + 1; - } - EPHEMERAL_PORT - }; + let temp_port = get_ephemeral_port(); match socket.connect((dest.unwrap(), port), temp_port) { Ok(()) => { @@ -194,6 +218,9 @@ pub fn sys_connect(fd: usize, addr: *const u8, addrlen: usize) -> SysResult { } Err(_) => Err(SysError::ENOBUFS), } + } else if let SocketType::Udp = wrapper.socket_type { + // do nothing when only sendto() is used + Ok(0) } else { unimplemented!("socket type") } @@ -231,6 +258,38 @@ pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usi } } +pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult { + let iface = &mut *(NET_DRIVERS.lock()[0]); + let wrapper = proc.get_socket(fd)?; + if let SocketType::Udp = wrapper.socket_type { + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); + + let mut slice = unsafe { slice::from_raw_parts_mut(base, len) }; + if socket.is_open() { + if socket.can_recv() { + match socket.recv_slice(&mut slice) { + Ok((size, _)) => { + // avoid deadlock + drop(socket); + drop(sockets); + + iface.poll(); + Ok(size as isize) + } + Err(err) => Err(SysError::ENOBUFS), + } + } else { + Err(SysError::ENOBUFS) + } + } else { + Err(SysError::ENOTCONN) + } + } else { + unimplemented!("socket type") + } +} + pub fn sys_select( fd: usize, inp: *const u8, @@ -299,6 +358,43 @@ pub fn sys_sendto( drop(sockets); iface.poll(); + Ok(len as isize) + } else { + unimplemented!("ip type") + } + } else if let SocketType::Udp = wrapper.socket_type { + let v4_src = iface.ipv4_address().unwrap(); + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); + + let mut dest = None; + let mut port = 0; + + // FIXME: check size as per sin_family + let sockaddr_in = unsafe { &*(addr as *const SockaddrIn) }; + parse_addr(&sockaddr_in, &mut dest, &mut port); + + if dest == None { + return Err(SysError::EINVAL); + } else if let Some(dst_addr) = dest { + if !socket.endpoint().is_specified() { + let temp_port = get_ephemeral_port(); + socket + .bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port)) + .unwrap(); + } + + let slice = unsafe { slice::from_raw_parts(buffer, len) }; + + socket + .send_slice(&slice, IpEndpoint::new(dst_addr, port)) + .unwrap(); + + // avoid deadlock + drop(socket); + drop(sockets); + iface.poll(); + Ok(len as isize) } else { unimplemented!("ip type") @@ -314,17 +410,31 @@ pub fn sys_recvfrom( len: usize, flags: usize, addr: *mut u8, - addr_len: *mut usize, + addr_len: *mut u32, ) -> SysResult { info!( "sys_recvfrom: fd: {} buffer: {:?} len: {} flags: {} addr: {:?} addr_len: {:?}", fd, buffer, len, flags, addr, addr_len ); + let mut proc = process(); + proc.memory_set.check_mut_array(buffer, len)?; + + if addr as usize != 0 { + proc.memory_set.check_mut_ptr(addr_len)?; + + let max_addr_len = unsafe { *addr_len } as usize; + if max_addr_len < size_of::() { + return Err(SysError::EINVAL); + } + + proc.memory_set.check_mut_array(addr, max_addr_len)?; + } let iface = &mut *(NET_DRIVERS.lock()[0]); let wrapper = proc.get_socket(fd)?; + // TODO: move some part of these into one generic function if let SocketType::Raw = wrapper.socket_type { loop { let mut sockets = iface.sockets(); @@ -334,10 +444,32 @@ pub fn sys_recvfrom( if let Ok(size) = socket.recv_slice(&mut slice) { let mut packet = Ipv4Packet::new_unchecked(&slice); - // FIXME: check size as per sin_family - let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) }; - fill_addr(&mut sockaddr_in, IpAddress::Ipv4(packet.src_addr()), 0); - unsafe { *addr_len = size_of::() }; + if addr as usize != 0 { + // FIXME: check size as per sin_family + let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) }; + fill_addr(&mut sockaddr_in, IpAddress::Ipv4(packet.src_addr()), 0); + unsafe { *addr_len = size_of::() as u32 }; + } + + return Ok(size as isize); + } + + // avoid deadlock + drop(socket); + SOCKET_ACTIVITY._wait() + } + } else if let SocketType::Udp = wrapper.socket_type { + loop { + let mut sockets = iface.sockets(); + let mut socket = sockets.get::(wrapper.handle); + + let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) }; + if let Ok((size, endpoint)) = socket.recv_slice(&mut slice) { + if addr as usize != 0 { + let mut sockaddr_in = unsafe { &mut *(addr as *mut SockaddrIn) }; + fill_addr(&mut sockaddr_in, endpoint.addr, endpoint.port); + unsafe { *addr_len = size_of::() as u32 }; + } return Ok(size as isize); }