Move socket to global, move and merge functions to net/structs.rs

toolchain_update
Jiajie Chen 6 years ago
parent f6352b2688
commit 5d601c3ea4

@ -10,10 +10,15 @@ use crate::sync::{Condvar, MutexGuard, SpinNoIrq};
use self::block::virtio_blk::VirtIOBlkDriver;
mod device_tree;
#[allow(dead_code)]
pub mod bus;
#[allow(dead_code)]
pub mod net;
#[allow(dead_code)]
pub mod block;
#[allow(dead_code)]
mod gpu;
#[allow(dead_code)]
mod input;
#[derive(Debug, Eq, PartialEq)]
@ -45,9 +50,6 @@ pub trait NetDriver : Driver {
// get ipv4 address
fn ipv4_address(&self) -> Option<Ipv4Address>;
// get sockets
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq>;
// manually trigger a poll, use it after sending packets
fn poll(&self);
}

@ -23,6 +23,7 @@ use smoltcp::Result;
use volatile::Volatile;
use crate::memory::active_table;
use crate::net::SOCKETS;
use crate::sync::SpinNoIrqLock as Mutex;
use crate::sync::{MutexGuard, SpinNoIrq};
use crate::HEAP_ALLOCATOR;
@ -72,7 +73,6 @@ const E1000_RAH: usize = 0x5404 / 4;
pub struct E1000Interface {
iface: Mutex<EthernetInterface<'static, 'static, 'static, E1000Driver>>,
driver: E1000Driver,
sockets: Mutex<SocketSet<'static, 'static, 'static>>,
}
impl Driver for E1000Interface {
@ -104,7 +104,7 @@ impl Driver for E1000Interface {
if irq {
let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64);
let mut sockets = self.sockets.lock();
let mut sockets = SOCKETS.lock();
match self.iface.lock().poll(&mut sockets, timestamp) {
Ok(_) => {
SOCKET_ACTIVITY.notify_all();
@ -136,13 +136,9 @@ impl NetDriver for E1000Interface {
self.iface.lock().ipv4_address()
}
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
self.sockets.lock()
}
fn poll(&self) {
let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64);
let mut sockets = self.sockets.lock();
let mut sockets = SOCKETS.lock();
match self.iface.lock().poll(&mut sockets, timestamp) {
Ok(_) => {
SOCKET_ACTIVITY.notify_all();
@ -195,8 +191,9 @@ impl<'a> phy::Device<'a> for E1000Driver {
}
}
let e1000 =
unsafe { slice::from_raw_parts_mut(driver.header as *mut Volatile<u32>, driver.size / 4) };
let e1000 = unsafe {
slice::from_raw_parts_mut(driver.header as *mut Volatile<u32>, driver.size / 4)
};
let send_queue_size = PAGE_SIZE / size_of::<E1000SendDesc>();
let send_queue = unsafe {
@ -214,8 +211,8 @@ impl<'a> phy::Device<'a> for E1000Driver {
let index = (rdt as usize + 1) % recv_queue_size;
let recv_desc = &mut recv_queue[index];
let transmit_avail = driver.first_trans || (*send_desc).status & 1 != 0;
let receive_avail = (*recv_desc).status & 1 != 0;
let transmit_avail = driver.first_trans || (*send_desc).status & 1 != 0;
let receive_avail = (*recv_desc).status & 1 != 0;
if transmit_avail && receive_avail {
let buffer = unsafe {
@ -247,8 +244,9 @@ impl<'a> phy::Device<'a> for E1000Driver {
}
}
let e1000 =
unsafe { slice::from_raw_parts_mut(driver.header as *mut Volatile<u32>, driver.size / 4) };
let e1000 = unsafe {
slice::from_raw_parts_mut(driver.header as *mut Volatile<u32>, driver.size / 4)
};
let send_queue_size = PAGE_SIZE / size_of::<E1000SendDesc>();
let send_queue = unsafe {
@ -428,7 +426,6 @@ pub fn e1000_init(header: usize, size: usize) {
// IPGT=0xa | IPGR1=0x8 | IPGR2=0xc
e1000[E1000_TIPG].write(0xa | (0x8 << 10) | (0xc << 20)); // TIPG
// 4.6.5 Receive Initialization
let mut ral: u32 = 0;
let mut rah: u32 = 0;
@ -502,7 +499,6 @@ pub fn e1000_init(header: usize, size: usize) {
let e1000_iface = E1000Interface {
iface: Mutex::new(iface),
sockets: Mutex::new(SocketSet::new(vec![])),
driver: net_driver.clone(),
};

@ -1,7 +1,6 @@
//! Intel 10Gb Network Adapter 82599 i.e. ixgbe network driver
use alloc::alloc::{GlobalAlloc, Layout};
use alloc::format;
use alloc::prelude::*;
use alloc::sync::Arc;
use core::mem::size_of;
@ -14,7 +13,7 @@ use log::*;
use rcore_memory::paging::PageTable;
use rcore_memory::PAGE_SIZE;
use smoltcp::iface::*;
use smoltcp::phy::{self, DeviceCapabilities, Checksum};
use smoltcp::phy::{self, Checksum, DeviceCapabilities};
use smoltcp::socket::*;
use smoltcp::time::Instant;
use smoltcp::wire::EthernetAddress;
@ -23,6 +22,7 @@ use smoltcp::Result;
use volatile::Volatile;
use crate::memory::active_table;
use crate::net::SOCKETS;
use crate::sync::SpinNoIrqLock as Mutex;
use crate::sync::{MutexGuard, SpinNoIrq};
use crate::HEAP_ALLOCATOR;
@ -139,7 +139,6 @@ const IXGBE_EEC: usize = 0x10010 / 4;
pub struct IXGBEInterface {
iface: Mutex<EthernetInterface<'static, 'static, 'static, IXGBEDriver>>,
driver: IXGBEDriver,
sockets: Mutex<SocketSet<'static, 'static, 'static>>,
name: String,
irq: Option<u32>,
}
@ -194,7 +193,7 @@ impl Driver for IXGBEInterface {
if rx {
let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64);
let mut sockets = self.sockets.lock();
let mut sockets = SOCKETS.lock();
match self.iface.lock().poll(&mut sockets, timestamp) {
Ok(_) => {
SOCKET_ACTIVITY.notify_all();
@ -226,13 +225,9 @@ impl NetDriver for IXGBEInterface {
self.iface.lock().ipv4_address()
}
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
self.sockets.lock()
}
fn poll(&self) {
let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64);
let mut sockets = self.sockets.lock();
let mut sockets = SOCKETS.lock();
match self.iface.lock().poll(&mut sockets, timestamp) {
Ok(_) => {
SOCKET_ACTIVITY.notify_all();
@ -576,7 +571,6 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
// CRCStrip | RSCACKC | FCOE_WRFIX
ixgbe[IXGBE_RDRXCTL].write(ixgbe[IXGBE_RDRXCTL].read() | (1 << 0) | (1 << 25) | (1 << 26));
/* Not completed part
// Program RXPBSIZE, MRQC, PFQDE, RTRUP2TC, MFLCN.RPFCE, and MFLCN.RFCE according to the DCB and virtualization modes (see Section 4.6.11.3).
// 4.6.11.3.4 DCB-Off, VT-Off
@ -717,7 +711,8 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
// Program the HLREG0 register according to the required MAC behavior.
// TXCRCEN | RXCRCSTRP | TXPADEN | RXLNGTHERREN
// ixgbe[IXGBE_HLREG0].write(ixgbe[IXGBE_HLREG0].read() & !(1 << 0) & !(1 << 1));
ixgbe[IXGBE_HLREG0].write(ixgbe[IXGBE_HLREG0].read() | (1 << 0) | (1 << 1) | (1 << 10) | (1 << 27));
ixgbe[IXGBE_HLREG0]
.write(ixgbe[IXGBE_HLREG0].read() | (1 << 0) | (1 << 1) | (1 << 10) | (1 << 27));
// The following steps should be done once per transmit queue:
// 1. Allocate a region of memory for the transmit descriptor list.
@ -746,7 +741,6 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
ixgbe[IXGBE_TXDCTL].write(ixgbe[IXGBE_TXDCTL].read() | 1 << 25);
while ixgbe[IXGBE_TXDCTL].read() & (1 << 25) == 0 {}
// 4.6.6 Interrupt Initialization
// The software driver associates between Tx and Rx interrupt causes and the EICR register by setting the IVAR[n] registers.
// map Rx0 to interrupt 0
@ -758,7 +752,7 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
// CNT_WDIS | ITR Interval=100us
// if sys_read() spin more times, the interval here should be larger
// Linux use dynamic ETIR based on statistics
ixgbe[IXGBE_EITR].write(((100/2) << 3) | (1 << 31));
ixgbe[IXGBE_EITR].write(((100 / 2) << 3) | (1 << 31));
// Disable general purpose interrupt
// We don't need them
ixgbe[IXGBE_GPIE].write(0);
@ -789,7 +783,6 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
let ixgbe_iface = IXGBEInterface {
iface: Mutex::new(iface),
sockets: Mutex::new(SocketSet::new(vec![])),
driver: net_driver.clone(),
name,
irq,
@ -798,5 +791,4 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
let driver = Arc::new(ixgbe_iface);
DRIVERS.write().push(driver.clone());
NET_DRIVERS.write().push(driver);
}

@ -1,3 +1,3 @@
pub mod virtio_net;
pub mod e1000;
pub mod ixgbe;
pub mod ixgbe;
pub mod virtio_net;

@ -6,25 +6,25 @@ use core::mem::size_of;
use core::slice;
use bitflags::*;
use device_tree::Node;
use device_tree::util::SliceRead;
use device_tree::Node;
use log::*;
use rcore_memory::PAGE_SIZE;
use rcore_memory::paging::PageTable;
use rcore_memory::PAGE_SIZE;
use smoltcp::phy::{self, DeviceCapabilities};
use smoltcp::Result;
use smoltcp::socket::SocketSet;
use smoltcp::time::Instant;
use smoltcp::wire::{EthernetAddress, Ipv4Address};
use smoltcp::socket::SocketSet;
use smoltcp::Result;
use volatile::{ReadOnly, Volatile};
use crate::HEAP_ALLOCATOR;
use crate::memory::active_table;
use crate::sync::SpinNoIrqLock as Mutex;
use crate::sync::{MutexGuard, SpinNoIrq};
use crate::HEAP_ALLOCATOR;
use super::super::{DeviceType, Driver, DRIVERS, NET_DRIVERS, NetDriver};
use super::super::bus::virtio_mmio::*;
use super::super::{DeviceType, Driver, NetDriver, DRIVERS, NET_DRIVERS};
pub struct VirtIONet {
interrupt_parent: u32,
@ -71,7 +71,6 @@ impl VirtIONet {
self.queues[VIRTIO_QUEUE_TRANSMIT].can_add(1, 0)
}
fn receive_available(&self) -> bool {
self.queues[VIRTIO_QUEUE_RECEIVE].can_get()
}
@ -90,10 +89,6 @@ impl NetDriver for VirtIONetDriver {
unimplemented!()
}
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
unimplemented!()
}
fn poll(&self) {
unimplemented!()
}
@ -110,8 +105,10 @@ impl<'a> phy::Device<'a> for VirtIONetDriver {
let driver = self.0.lock();
if driver.transmit_available() && driver.receive_available() {
// potential racing
Some((VirtIONetRxToken(self.clone()),
VirtIONetTxToken(self.clone())))
Some((
VirtIONetRxToken(self.clone()),
VirtIONetTxToken(self.clone()),
))
} else {
None
}
@ -134,9 +131,10 @@ impl<'a> phy::Device<'a> for VirtIONetDriver {
}
}
impl phy::RxToken for VirtIONetRxToken {
impl phy::RxToken for VirtIONetRxToken {
fn consume<R, F>(self, _timestamp: Instant, f: F) -> Result<R>
where F: FnOnce(&[u8]) -> Result<R>
where
F: FnOnce(&[u8]) -> Result<R>,
{
let (input, output, _, user_data) = {
let mut driver = (self.0).0.lock();
@ -156,7 +154,8 @@ impl phy::RxToken for VirtIONetRxToken {
impl phy::TxToken for VirtIONetTxToken {
fn consume<R, F>(self, _timestamp: Instant, len: usize, f: F) -> Result<R>
where F: FnOnce(&mut [u8]) -> Result<R>,
where
F: FnOnce(&mut [u8]) -> Result<R>,
{
let output = {
let mut driver = (self.0).0.lock();
@ -165,16 +164,18 @@ impl phy::TxToken for VirtIONetTxToken {
active_table().map_if_not_exists(driver.header as usize, driver.header as usize);
if let Some((_, output, _, _)) = driver.queues[VIRTIO_QUEUE_TRANSMIT].get() {
unsafe { slice::from_raw_parts_mut(output[0].as_ptr() as *mut u8, output[0].len())}
unsafe { slice::from_raw_parts_mut(output[0].as_ptr() as *mut u8, output[0].len()) }
} else {
// allocate a page for buffer
let page = unsafe {
HEAP_ALLOCATOR.alloc_zeroed(Layout::from_size_align(PAGE_SIZE, PAGE_SIZE).unwrap())
HEAP_ALLOCATOR
.alloc_zeroed(Layout::from_size_align(PAGE_SIZE, PAGE_SIZE).unwrap())
} as usize;
unsafe { slice::from_raw_parts_mut(page as *mut u8, PAGE_SIZE) }
}
};
let output_buffer = &mut output[size_of::<VirtIONetHeader>()..(size_of::<VirtIONetHeader>() + len)];
let output_buffer =
&mut output[size_of::<VirtIONetHeader>()..(size_of::<VirtIONetHeader>() + len)];
let result = f(output_buffer);
let mut driver = (self.0).0.lock();
@ -183,7 +184,6 @@ impl phy::TxToken for VirtIONetTxToken {
}
}
bitflags! {
struct VirtIONetFeature : u64 {
const CSUM = 1 << 0;
@ -234,7 +234,7 @@ bitflags! {
#[derive(Debug)]
struct VirtIONetworkConfig {
mac: [u8; 6],
status: ReadOnly<u16>
status: ReadOnly<u16>,
}
// virtio 5.1.6 Device Operation
@ -250,7 +250,6 @@ struct VirtIONetHeader {
// payload starts from here
}
pub fn virtio_net_init(node: &Node) {
let reg = node.prop_raw("reg").unwrap();
let from = reg.as_slice().read_be_u64(0).unwrap();
@ -283,8 +282,10 @@ pub fn virtio_net_init(node: &Node) {
interrupt_parent: node.prop_u32("interrupt-parent").unwrap(),
header: from as usize,
mac: EthernetAddress(mac),
queues: [VirtIOVirtqueue::new(header, VIRTIO_QUEUE_RECEIVE, queue_num),
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num)],
queues: [
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_RECEIVE, queue_num),
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num),
],
};
// allocate a page for buffer
@ -300,4 +301,4 @@ pub fn virtio_net_init(node: &Node) {
DRIVERS.write().push(net_driver.clone());
NET_DRIVERS.write().push(net_driver);
}
}

@ -1,2 +1,5 @@
mod structs;
mod test;
pub use self::test::server;
pub use self::structs::*;
pub use self::test::server;

@ -0,0 +1,249 @@
use alloc::sync::Arc;
use core::fmt;
use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY};
use crate::process::structs::Process;
use crate::sync::SpinNoIrqLock as Mutex;
use crate::syscall::*;
use smoltcp::socket::*;
use smoltcp::wire::*;
lazy_static! {
pub static ref SOCKETS: Arc<Mutex<SocketSet<'static, 'static, 'static>>> =
Arc::new(Mutex::new(SocketSet::new(vec![])));
}
#[derive(Clone, Debug)]
pub struct TcpSocketState {
pub local_endpoint: Option<IpEndpoint>, // save local endpoint for bind()
pub is_listening: bool,
}
#[derive(Clone, Debug)]
pub struct UdpSocketState {
pub remote_endpoint: Option<IpEndpoint>, // remember remote endpoint for connect()
}
#[derive(Clone, Debug)]
pub enum SocketType {
Raw,
Tcp(TcpSocketState),
Udp(UdpSocketState),
Icmp,
}
#[derive(Debug)]
pub struct SocketWrapper {
pub handle: SocketHandle,
pub socket_type: SocketType,
}
pub fn get_ephemeral_port() -> u16 {
// TODO selects non-conflict high port
static mut EPHEMERAL_PORT: u16 = 49152;
unsafe {
if EPHEMERAL_PORT == 65535 {
EPHEMERAL_PORT = 49152;
} else {
EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
}
EPHEMERAL_PORT
}
}
/// Safety: call this without SOCKETS locked
pub fn poll_ifaces() {
for iface in NET_DRIVERS.read().iter() {
iface.poll();
}
}
impl Drop for SocketWrapper {
fn drop(&mut self) {
let mut sockets = SOCKETS.lock();
sockets.release(self.handle);
sockets.prune();
// send FIN immediately when applicable
drop(sockets);
poll_ifaces();
}
}
impl SocketWrapper {
pub fn write(&self, data: &[u8], sendto_endpoint: Option<IpEndpoint>) -> SysResult {
if let SocketType::Raw = self.socket_type {
if let Some(endpoint) = sendto_endpoint {
// temporary solution
let iface = &*(NET_DRIVERS.read()[0]);
let v4_src = iface.ipv4_address().unwrap();
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<RawSocket>(self.handle);
if let IpAddress::Ipv4(v4_dst) = endpoint.addr {
let len = data.len();
// using 20-byte IPv4 header
let mut buffer = vec![0u8; len + 20];
let mut packet = Ipv4Packet::new_unchecked(&mut buffer);
packet.set_version(4);
packet.set_header_len(20);
packet.set_total_len((20 + len) as u16);
packet.set_protocol(socket.ip_protocol().into());
packet.set_src_addr(v4_src);
packet.set_dst_addr(v4_dst);
let payload = packet.payload_mut();
payload.copy_from_slice(data);
packet.fill_checksum();
socket.send_slice(&buffer).unwrap();
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(len)
} else {
unimplemented!("ip type")
}
} else {
Err(SysError::ENOTCONN)
}
} else if let SocketType::Tcp(_) = self.socket_type {
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<TcpSocket>(self.handle);
if socket.is_open() {
if socket.can_send() {
match socket.send_slice(&data) {
Ok(size) => {
// avoid deadlock
drop(socket);
drop(sockets);
poll_ifaces();
Ok(size)
}
Err(err) => Err(SysError::ENOBUFS),
}
} else {
Err(SysError::ENOBUFS)
}
} else {
Err(SysError::ENOTCONN)
}
} else if let SocketType::Udp(ref state) = self.socket_type {
let remote_endpoint = {
if let Some(ref endpoint) = sendto_endpoint {
endpoint
} else if let Some(ref endpoint) = state.remote_endpoint {
endpoint
} else {
return Err(SysError::ENOTCONN);
}
};
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<UdpSocket>(self.handle);
if socket.endpoint().port == 0 {
let temp_port = get_ephemeral_port();
socket
.bind(IpEndpoint::new(IpAddress::Unspecified, temp_port))
.unwrap();
}
if socket.can_send() {
match socket.send_slice(&data, *remote_endpoint) {
Ok(()) => {
// avoid deadlock
drop(socket);
drop(sockets);
poll_ifaces();
Ok(data.len())
}
Err(err) => Err(SysError::ENOBUFS),
}
} else {
Err(SysError::ENOBUFS)
}
} else {
unimplemented!("socket type")
}
}
pub fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) {
if let SocketType::Raw = self.socket_type {
loop {
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<RawSocket>(self.handle);
if let Ok(size) = socket.recv_slice(data) {
let packet = Ipv4Packet::new_unchecked(data);
return (
Ok(size),
IpEndpoint {
addr: IpAddress::Ipv4(packet.src_addr()),
port: 0,
},
);
}
// avoid deadlock
drop(socket);
drop(sockets);
SOCKET_ACTIVITY._wait()
}
} else if let SocketType::Tcp(_) = self.socket_type {
spin_and_wait(&[&SOCKET_ACTIVITY], move || {
poll_ifaces();
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<TcpSocket>(self.handle);
if socket.is_open() {
if let Ok(size) = socket.recv_slice(data) {
if size > 0 {
let endpoint = socket.remote_endpoint();
// avoid deadlock
drop(socket);
drop(sockets);
poll_ifaces();
return Some((Ok(size), endpoint));
}
}
} else {
return Some((Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED));
}
None
})
} else if let SocketType::Udp(ref state) = self.socket_type {
loop {
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<UdpSocket>(self.handle);
if socket.is_open() {
if let Ok((size, remote_endpoint)) = socket.recv_slice(data) {
let endpoint = remote_endpoint;
// avoid deadlock
drop(socket);
drop(sockets);
poll_ifaces();
return (Ok(size), endpoint);
}
} else {
return (Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED);
}
// avoid deadlock
drop(socket);
SOCKET_ACTIVITY._wait()
}
} else {
unimplemented!("socket type")
}
}
}

@ -1,11 +1,12 @@
use crate::thread;
use crate::drivers::NET_DRIVERS;
use smoltcp::socket::*;
use crate::drivers::NetDriver;
use crate::drivers::NET_DRIVERS;
use crate::net::SOCKETS;
use crate::thread;
use alloc::vec;
use core::fmt::Write;
use smoltcp::socket::*;
pub extern fn server(_arg: usize) -> ! {
pub extern "C" fn server(_arg: usize) -> ! {
if NET_DRIVERS.read().len() < 1 {
loop {
thread::yield_now();
@ -24,18 +25,15 @@ pub extern fn server(_arg: usize) -> ! {
let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 1024]);
let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer);
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let udp_handle = sockets.add(udp_socket);
let tcp_handle = sockets.add(tcp_socket);
let tcp2_handle = sockets.add(tcp2_socket);
drop(sockets);
drop(iface);
loop {
{
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
// udp server
{
@ -45,10 +43,8 @@ pub extern fn server(_arg: usize) -> ! {
}
let client = match socket.recv() {
Ok((_, endpoint)) => {
Some(endpoint)
}
Err(_) => None
Ok((_, endpoint)) => Some(endpoint),
Err(_) => None,
};
if let Some(endpoint) = client {
let hello = b"hello\n";
@ -85,5 +81,4 @@ pub extern fn server(_arg: usize) -> ! {
thread::yield_now();
}
}

@ -4,8 +4,6 @@ use core::fmt;
use log::*;
use spin::{Mutex, RwLock};
use xmas_elf::{ElfFile, header, program::{Flags, Type}};
use smoltcp::socket::SocketHandle;
use smoltcp::wire::IpEndpoint;
use rcore_memory::PAGE_SIZE;
use rcore_thread::Tid;
@ -14,6 +12,7 @@ use crate::memory::{ByFrame, GlobalFrameAlloc, KernelStack, MemoryAttr, MemorySe
use crate::fs::{FileHandle, OpenOptions};
use crate::sync::Condvar;
use crate::drivers::NET_DRIVERS;
use crate::net::{SocketWrapper, SOCKETS};
use super::abi::{self, ProcInitInfo};
@ -27,30 +26,6 @@ pub struct Thread {
pub proc: Arc<Mutex<Process>>,
}
#[derive(Clone, Debug)]
pub struct TcpSocketState {
pub local_endpoint: Option<IpEndpoint>, // save local endpoint for bind()
pub is_listening: bool,
}
#[derive(Clone, Debug)]
pub struct UdpSocketState {
pub remote_endpoint: Option<IpEndpoint>, // remember remote endpoint for connect()
}
#[derive(Clone, Debug)]
pub enum SocketType {
Raw,
Tcp(TcpSocketState),
Udp(UdpSocketState),
Icmp
}
#[derive(Debug)]
pub struct SocketWrapper {
pub handle: SocketHandle,
pub socket_type: SocketType,
}
#[derive(Clone)]
pub enum FileLike {
@ -63,12 +38,7 @@ impl fmt::Debug for FileLike {
match self {
FileLike::File(_) => write!(f, "File"),
FileLike::Socket(wrapper) => {
match wrapper.socket_type {
SocketType::Raw => write!(f, "RawSocket"),
SocketType::Tcp(_) => write!(f, "TcpSocket"),
SocketType::Udp(_) => write!(f, "UdpSocket"),
SocketType::Icmp => write!(f, "IcmpSocket"),
}
write!(f, "{:?}", wrapper)
},
}
}
@ -324,12 +294,10 @@ impl Thread {
debug!("fork: temporary copy data!");
let kstack = KernelStack::new();
for iface in NET_DRIVERS.read().iter() {
let mut sockets = iface.sockets();
for (_fd, file) in files.iter() {
if let FileLike::Socket(wrapper) = file {
sockets.retain(wrapper.handle);
}
let mut sockets = SOCKETS.lock();
for (_fd, file) in files.iter() {
if let FileLike::Socket(wrapper) = file {
sockets.retain(wrapper.handle);
}
}

@ -2,7 +2,10 @@
use super::*;
use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY};
use crate::process::structs::TcpSocketState;
use crate::net::{
get_ephemeral_port, poll_ifaces, SocketType, SocketWrapper, TcpSocketState, UdpSocketState,
SOCKETS,
};
use core::cmp::min;
use core::mem::size_of;
use smoltcp::socket::*;
@ -23,26 +26,12 @@ const IPPROTO_TCP: usize = 6;
const TCP_SENDBUF: usize = 512 * 1024; // 512K
const TCP_RECVBUF: usize = 512 * 1024; // 512K
fn get_ephemeral_port() -> u16 {
// TODO selects non-conflict high port
static mut EPHEMERAL_PORT: u16 = 49152;
unsafe {
if EPHEMERAL_PORT == 65535 {
EPHEMERAL_PORT = 49152;
} else {
EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
}
EPHEMERAL_PORT
}
}
pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResult {
info!(
"socket: domain: {}, socket_type: {}, protocol: {}",
domain, socket_type, protocol
);
let mut proc = process();
let iface = &*(NET_DRIVERS.read()[0]);
match domain {
AF_INET | AF_UNIX => match socket_type & SOCK_TYPE_MASK {
SOCK_STREAM => {
@ -52,7 +41,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; TCP_SENDBUF]);
let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
let tcp_handle = iface.sockets().add(tcp_socket);
let tcp_handle = SOCKETS.lock().add(tcp_socket);
proc.files.insert(
fd,
FileLike::Socket(SocketWrapper {
@ -75,7 +64,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 2048]);
let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
let udp_handle = iface.sockets().add(udp_socket);
let udp_handle = SOCKETS.lock().add(udp_socket);
proc.files.insert(
fd,
FileLike::Socket(SocketWrapper {
@ -102,7 +91,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
raw_tx_buffer,
);
let raw_handle = iface.sockets().add(raw_socket);
let raw_handle = SOCKETS.lock().add(raw_socket);
proc.files.insert(
fd,
FileLike::Socket(SocketWrapper {
@ -211,8 +200,7 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu
let wrapper = &mut proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
let temp_port = get_ephemeral_port();
@ -225,10 +213,9 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu
// wait for connection result
loop {
let iface = &*(NET_DRIVERS.read()[0]);
iface.poll();
poll_ifaces();
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let socket = sockets.get::<TcpSocket>(wrapper.handle);
if socket.state() == TcpState::SynSent {
// still connecting
@ -256,307 +243,73 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu
}
pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
let slice = unsafe { slice::from_raw_parts(base, len) };
if socket.is_open() {
if socket.can_send() {
match socket.send_slice(&slice) {
Ok(size) => {
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(size)
}
Err(err) => Err(SysError::ENOBUFS),
}
} else {
Err(SysError::ENOBUFS)
}
} else {
Err(SysError::ENOTCONN)
}
} else if let SocketType::Udp(ref state) = wrapper.socket_type {
if let Some(ref remote_endpoint) = state.remote_endpoint {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
if socket.endpoint().port == 0 {
let v4_src = iface.ipv4_address().unwrap();
let temp_port = get_ephemeral_port();
socket
.bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port))
.unwrap();
}
let slice = unsafe { slice::from_raw_parts(base, len) };
if socket.is_open() {
if socket.can_send() {
match socket.send_slice(&slice, *remote_endpoint) {
Ok(()) => {
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(len)
}
Err(err) => Err(SysError::ENOBUFS),
}
} else {
Err(SysError::ENOBUFS)
}
} else {
Err(SysError::ENOTCONN)
}
} else {
Err(SysError::ENOTCONN)
}
} else {
unimplemented!("socket type")
}
let slice = unsafe { slice::from_raw_parts(base, len) };
wrapper.write(&slice, None)
}
pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
spin_and_wait(&[&SOCKET_ACTIVITY], move || {
iface.poll();
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
if socket.is_open() {
let mut slice = unsafe { slice::from_raw_parts_mut(base, len) };
if let Ok(size) = socket.recv_slice(&mut slice) {
if size > 0 {
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
return Some(Ok(size));
}
}
} else {
return Some(Err(SysError::ENOTCONN));
}
None
})
} else if let SocketType::Udp(_) = wrapper.socket_type {
loop {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
if socket.is_open() {
let mut slice = unsafe { slice::from_raw_parts_mut(base, len) };
if let Ok((size, _)) = socket.recv_slice(&mut slice) {
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
return Ok(size);
}
} else {
return Err(SysError::ENOTCONN);
}
// avoid deadlock
drop(socket);
SOCKET_ACTIVITY._wait()
}
} else {
unimplemented!("socket type")
}
let mut slice = unsafe { slice::from_raw_parts_mut(base, len) };
let (result, _) = wrapper.read(&mut slice);
result
}
pub fn sys_sendto(
fd: usize,
buffer: *const u8,
base: *const u8,
len: usize,
flags: usize,
addr: *const SockAddr,
addr_len: usize,
) -> SysResult {
info!(
"sys_sendto: fd: {} buffer: {:?} len: {} addr: {:?} addr_len: {}",
fd, buffer, len, addr, addr_len
"sys_sendto: fd: {} base: {:?} len: {} addr: {:?} addr_len: {}",
fd, base, len, addr, addr_len
);
let mut proc = process();
proc.memory_set.check_array(buffer, len)?;
let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?;
let iface = &*(NET_DRIVERS.read()[0]);
proc.memory_set.check_array(base, len)?;
let wrapper = proc.get_socket(fd)?;
if let SocketType::Raw = wrapper.socket_type {
let v4_src = iface.ipv4_address().unwrap();
let mut sockets = iface.sockets();
let mut socket = sockets.get::<RawSocket>(wrapper.handle);
if let IpAddress::Ipv4(v4_dst) = endpoint.addr {
let slice = unsafe { slice::from_raw_parts(buffer, len) };
// using 20-byte IPv4 header
let mut buffer = vec![0u8; len + 20];
let mut packet = Ipv4Packet::new_unchecked(&mut buffer);
packet.set_version(4);
packet.set_header_len(20);
packet.set_total_len((20 + len) as u16);
packet.set_protocol(socket.ip_protocol().into());
packet.set_src_addr(v4_src);
packet.set_dst_addr(v4_dst);
let payload = packet.payload_mut();
payload.copy_from_slice(slice);
packet.fill_checksum();
socket.send_slice(&buffer).unwrap();
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(len)
} else {
unimplemented!("ip type")
}
} else if let SocketType::Udp(_) = wrapper.socket_type {
let v4_src = iface.ipv4_address().unwrap();
let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
if socket.endpoint().port == 0 {
let temp_port = get_ephemeral_port();
socket
.bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port))
.unwrap();
}
let slice = unsafe { slice::from_raw_parts(buffer, len) };
socket.send_slice(&slice, endpoint).unwrap();
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(len)
} else {
unimplemented!("socket type")
}
let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?;
let slice = unsafe { slice::from_raw_parts(base, len) };
wrapper.write(&slice, Some(endpoint))
}
pub fn sys_recvfrom(
fd: usize,
buffer: *mut u8,
base: *mut u8,
len: usize,
flags: usize,
addr: *mut SockAddr,
addr_len: *mut u32,
) -> SysResult {
info!(
"sys_recvfrom: fd: {} buffer: {:?} len: {} flags: {} addr: {:?} addr_len: {:?}",
fd, buffer, len, flags, addr, addr_len
"sys_recvfrom: fd: {} base: {:?} len: {} flags: {} addr: {:?} addr_len: {:?}",
fd, base, len, flags, addr, addr_len
);
let mut proc = process();
proc.memory_set.check_mut_array(buffer, len)?;
let iface = &*(NET_DRIVERS.read()[0]);
proc.memory_set.check_mut_array(base, len)?;
let wrapper = proc.get_socket(fd)?;
// TODO: move some part of these into one generic function
if let SocketType::Raw = wrapper.socket_type {
loop {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<RawSocket>(wrapper.handle);
let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) };
if let Ok(size) = socket.recv_slice(&mut slice) {
let packet = Ipv4Packet::new_unchecked(&slice);
if !addr.is_null() {
// FIXME: check size as per sin_family
let sockaddr_in = SockAddr::from(IpEndpoint {
addr: IpAddress::Ipv4(packet.src_addr()),
port: 0,
});
unsafe {
sockaddr_in.write_to(&mut proc, addr, addr_len)?;
}
}
let mut slice = unsafe { slice::from_raw_parts_mut(base, len) };
let (result, endpoint) = wrapper.read(&mut slice);
return Ok(size);
}
// avoid deadlock
drop(socket);
drop(sockets);
SOCKET_ACTIVITY._wait()
}
} else if let SocketType::Udp(_) = wrapper.socket_type {
loop {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) };
if let Ok((size, endpoint)) = socket.recv_slice(&mut slice) {
if !addr.is_null() {
let sockaddr_in = SockAddr::from(endpoint);
unsafe {
sockaddr_in.write_to(&mut proc, addr, addr_len)?;
}
}
return Ok(size);
}
// avoid deadlock
drop(socket);
drop(sockets);
SOCKET_ACTIVITY._wait()
}
} else if let SocketType::Tcp(_) = wrapper.socket_type {
loop {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) };
if let Ok(size) = socket.recv_slice(&mut slice) {
if !addr.is_null() {
let sockaddr_in = SockAddr::from(socket.remote_endpoint());
unsafe {
sockaddr_in.write_to(&mut proc, addr, addr_len)?;
}
}
return Ok(size);
}
// avoid deadlock
drop(socket);
drop(sockets);
SOCKET_ACTIVITY._wait()
if result.is_ok() && !addr.is_null() {
let sockaddr_in = SockAddr::from(endpoint);
unsafe {
sockaddr_in.write_to(&mut proc, addr, addr_len)?;
}
} else {
unimplemented!("socket type")
}
result
}
impl Clone for SocketWrapper {
fn clone(&self) -> Self {
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
sockets.retain(self.handle);
SocketWrapper {
@ -566,19 +319,6 @@ impl Clone for SocketWrapper {
}
}
impl Drop for SocketWrapper {
fn drop(&mut self) {
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
sockets.release(self.handle);
sockets.prune();
// send FIN immediately when applicable
drop(sockets);
iface.poll();
}
}
pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult {
info!("sys_bind: fd: {} addr: {:?} len: {}", fd, addr, addr_len);
let mut proc = process();
@ -589,7 +329,6 @@ pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult
}
info!("sys_bind: fd: {} bind to {}", fd, endpoint);
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = &mut proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
wrapper.socket_type = SocketType::Tcp(TcpSocketState {
@ -598,7 +337,7 @@ pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult
});
Ok(0)
} else if let SocketType::Udp(_) = wrapper.socket_type {
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
match socket.bind(endpoint) {
Ok(()) => Ok(0),
@ -615,14 +354,13 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult {
// open multiple sockets for each connection
let mut proc = process();
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(ref mut tcp_state) = wrapper.socket_type {
if tcp_state.is_listening {
// it is ok to listen twice
Ok(0)
} else if let Some(local_endpoint) = tcp_state.local_endpoint {
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
info!("socket {} listening on {:?}", fd, local_endpoint);
@ -649,10 +387,9 @@ pub fn sys_shutdown(fd: usize, how: usize) -> SysResult {
info!("sys_shutdown: fd: {} how: {}", fd, how);
let mut proc = process();
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
socket.close();
Ok(0)
@ -686,8 +423,7 @@ pub fn sys_accept(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResu
if let SocketType::Tcp(tcp_state) = wrapper.socket_type.clone() {
if let Some(endpoint) = tcp_state.local_endpoint {
loop {
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let socket = sockets.get::<TcpSocket>(wrapper.handle);
if socket.is_active() {
@ -736,14 +472,13 @@ pub fn sys_accept(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResu
drop(sockets);
drop(proc);
iface.poll();
poll_ifaces();
return Ok(new_fd);
}
// avoid deadlock
drop(socket);
drop(sockets);
drop(iface);
SOCKET_ACTIVITY._wait()
}
} else {
@ -767,7 +502,6 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
return Err(SysError::EINVAL);
}
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(state) = &wrapper.socket_type {
if let Some(endpoint) = state.local_endpoint {
@ -777,7 +511,7 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
}
Ok(0)
} else {
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let socket = sockets.get::<TcpSocket>(wrapper.handle);
let endpoint = socket.local_endpoint();
if endpoint.port != 0 {
@ -791,7 +525,7 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
}
}
} else if let SocketType::Udp(_) = &wrapper.socket_type {
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let socket = sockets.get::<UdpSocket>(wrapper.handle);
let endpoint = socket.endpoint();
if endpoint.port != 0 {
@ -822,10 +556,9 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
return Err(SysError::EINVAL);
}
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let socket = sockets.get::<TcpSocket>(wrapper.handle);
if socket.is_open() {
@ -860,8 +593,7 @@ pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) {
let mut output = false;
let mut err = false;
if let SocketType::Tcp(state) = wrapper.socket_type.clone() {
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let socket = sockets.get::<TcpSocket>(wrapper.handle);
if state.is_listening && socket.is_active() {
@ -879,8 +611,7 @@ pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) {
}
}
} else if let SocketType::Udp(_) = wrapper.socket_type {
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
let mut sockets = SOCKETS.lock();
let socket = sockets.get::<UdpSocket>(wrapper.handle);
if socket.can_recv() {

Loading…
Cancel
Save