Move socket to global, move and merge functions to net/structs.rs

toolchain_update
Jiajie Chen 6 years ago
parent f6352b2688
commit 5d601c3ea4

@ -10,10 +10,15 @@ use crate::sync::{Condvar, MutexGuard, SpinNoIrq};
use self::block::virtio_blk::VirtIOBlkDriver; use self::block::virtio_blk::VirtIOBlkDriver;
mod device_tree; mod device_tree;
#[allow(dead_code)]
pub mod bus; pub mod bus;
#[allow(dead_code)]
pub mod net; pub mod net;
#[allow(dead_code)]
pub mod block; pub mod block;
#[allow(dead_code)]
mod gpu; mod gpu;
#[allow(dead_code)]
mod input; mod input;
#[derive(Debug, Eq, PartialEq)] #[derive(Debug, Eq, PartialEq)]
@ -45,9 +50,6 @@ pub trait NetDriver : Driver {
// get ipv4 address // get ipv4 address
fn ipv4_address(&self) -> Option<Ipv4Address>; fn ipv4_address(&self) -> Option<Ipv4Address>;
// get sockets
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq>;
// manually trigger a poll, use it after sending packets // manually trigger a poll, use it after sending packets
fn poll(&self); fn poll(&self);
} }

@ -23,6 +23,7 @@ use smoltcp::Result;
use volatile::Volatile; use volatile::Volatile;
use crate::memory::active_table; use crate::memory::active_table;
use crate::net::SOCKETS;
use crate::sync::SpinNoIrqLock as Mutex; use crate::sync::SpinNoIrqLock as Mutex;
use crate::sync::{MutexGuard, SpinNoIrq}; use crate::sync::{MutexGuard, SpinNoIrq};
use crate::HEAP_ALLOCATOR; use crate::HEAP_ALLOCATOR;
@ -72,7 +73,6 @@ const E1000_RAH: usize = 0x5404 / 4;
pub struct E1000Interface { pub struct E1000Interface {
iface: Mutex<EthernetInterface<'static, 'static, 'static, E1000Driver>>, iface: Mutex<EthernetInterface<'static, 'static, 'static, E1000Driver>>,
driver: E1000Driver, driver: E1000Driver,
sockets: Mutex<SocketSet<'static, 'static, 'static>>,
} }
impl Driver for E1000Interface { impl Driver for E1000Interface {
@ -104,7 +104,7 @@ impl Driver for E1000Interface {
if irq { if irq {
let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64); let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64);
let mut sockets = self.sockets.lock(); let mut sockets = SOCKETS.lock();
match self.iface.lock().poll(&mut sockets, timestamp) { match self.iface.lock().poll(&mut sockets, timestamp) {
Ok(_) => { Ok(_) => {
SOCKET_ACTIVITY.notify_all(); SOCKET_ACTIVITY.notify_all();
@ -136,13 +136,9 @@ impl NetDriver for E1000Interface {
self.iface.lock().ipv4_address() self.iface.lock().ipv4_address()
} }
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
self.sockets.lock()
}
fn poll(&self) { fn poll(&self) {
let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64); let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64);
let mut sockets = self.sockets.lock(); let mut sockets = SOCKETS.lock();
match self.iface.lock().poll(&mut sockets, timestamp) { match self.iface.lock().poll(&mut sockets, timestamp) {
Ok(_) => { Ok(_) => {
SOCKET_ACTIVITY.notify_all(); SOCKET_ACTIVITY.notify_all();
@ -195,8 +191,9 @@ impl<'a> phy::Device<'a> for E1000Driver {
} }
} }
let e1000 = let e1000 = unsafe {
unsafe { slice::from_raw_parts_mut(driver.header as *mut Volatile<u32>, driver.size / 4) }; slice::from_raw_parts_mut(driver.header as *mut Volatile<u32>, driver.size / 4)
};
let send_queue_size = PAGE_SIZE / size_of::<E1000SendDesc>(); let send_queue_size = PAGE_SIZE / size_of::<E1000SendDesc>();
let send_queue = unsafe { let send_queue = unsafe {
@ -247,8 +244,9 @@ impl<'a> phy::Device<'a> for E1000Driver {
} }
} }
let e1000 = let e1000 = unsafe {
unsafe { slice::from_raw_parts_mut(driver.header as *mut Volatile<u32>, driver.size / 4) }; slice::from_raw_parts_mut(driver.header as *mut Volatile<u32>, driver.size / 4)
};
let send_queue_size = PAGE_SIZE / size_of::<E1000SendDesc>(); let send_queue_size = PAGE_SIZE / size_of::<E1000SendDesc>();
let send_queue = unsafe { let send_queue = unsafe {
@ -428,7 +426,6 @@ pub fn e1000_init(header: usize, size: usize) {
// IPGT=0xa | IPGR1=0x8 | IPGR2=0xc // IPGT=0xa | IPGR1=0x8 | IPGR2=0xc
e1000[E1000_TIPG].write(0xa | (0x8 << 10) | (0xc << 20)); // TIPG e1000[E1000_TIPG].write(0xa | (0x8 << 10) | (0xc << 20)); // TIPG
// 4.6.5 Receive Initialization // 4.6.5 Receive Initialization
let mut ral: u32 = 0; let mut ral: u32 = 0;
let mut rah: u32 = 0; let mut rah: u32 = 0;
@ -502,7 +499,6 @@ pub fn e1000_init(header: usize, size: usize) {
let e1000_iface = E1000Interface { let e1000_iface = E1000Interface {
iface: Mutex::new(iface), iface: Mutex::new(iface),
sockets: Mutex::new(SocketSet::new(vec![])),
driver: net_driver.clone(), driver: net_driver.clone(),
}; };

@ -1,7 +1,6 @@
//! Intel 10Gb Network Adapter 82599 i.e. ixgbe network driver //! Intel 10Gb Network Adapter 82599 i.e. ixgbe network driver
use alloc::alloc::{GlobalAlloc, Layout}; use alloc::alloc::{GlobalAlloc, Layout};
use alloc::format;
use alloc::prelude::*; use alloc::prelude::*;
use alloc::sync::Arc; use alloc::sync::Arc;
use core::mem::size_of; use core::mem::size_of;
@ -14,7 +13,7 @@ use log::*;
use rcore_memory::paging::PageTable; use rcore_memory::paging::PageTable;
use rcore_memory::PAGE_SIZE; use rcore_memory::PAGE_SIZE;
use smoltcp::iface::*; use smoltcp::iface::*;
use smoltcp::phy::{self, DeviceCapabilities, Checksum}; use smoltcp::phy::{self, Checksum, DeviceCapabilities};
use smoltcp::socket::*; use smoltcp::socket::*;
use smoltcp::time::Instant; use smoltcp::time::Instant;
use smoltcp::wire::EthernetAddress; use smoltcp::wire::EthernetAddress;
@ -23,6 +22,7 @@ use smoltcp::Result;
use volatile::Volatile; use volatile::Volatile;
use crate::memory::active_table; use crate::memory::active_table;
use crate::net::SOCKETS;
use crate::sync::SpinNoIrqLock as Mutex; use crate::sync::SpinNoIrqLock as Mutex;
use crate::sync::{MutexGuard, SpinNoIrq}; use crate::sync::{MutexGuard, SpinNoIrq};
use crate::HEAP_ALLOCATOR; use crate::HEAP_ALLOCATOR;
@ -139,7 +139,6 @@ const IXGBE_EEC: usize = 0x10010 / 4;
pub struct IXGBEInterface { pub struct IXGBEInterface {
iface: Mutex<EthernetInterface<'static, 'static, 'static, IXGBEDriver>>, iface: Mutex<EthernetInterface<'static, 'static, 'static, IXGBEDriver>>,
driver: IXGBEDriver, driver: IXGBEDriver,
sockets: Mutex<SocketSet<'static, 'static, 'static>>,
name: String, name: String,
irq: Option<u32>, irq: Option<u32>,
} }
@ -194,7 +193,7 @@ impl Driver for IXGBEInterface {
if rx { if rx {
let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64); let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64);
let mut sockets = self.sockets.lock(); let mut sockets = SOCKETS.lock();
match self.iface.lock().poll(&mut sockets, timestamp) { match self.iface.lock().poll(&mut sockets, timestamp) {
Ok(_) => { Ok(_) => {
SOCKET_ACTIVITY.notify_all(); SOCKET_ACTIVITY.notify_all();
@ -226,13 +225,9 @@ impl NetDriver for IXGBEInterface {
self.iface.lock().ipv4_address() self.iface.lock().ipv4_address()
} }
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
self.sockets.lock()
}
fn poll(&self) { fn poll(&self) {
let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64); let timestamp = Instant::from_millis(crate::trap::uptime_msec() as i64);
let mut sockets = self.sockets.lock(); let mut sockets = SOCKETS.lock();
match self.iface.lock().poll(&mut sockets, timestamp) { match self.iface.lock().poll(&mut sockets, timestamp) {
Ok(_) => { Ok(_) => {
SOCKET_ACTIVITY.notify_all(); SOCKET_ACTIVITY.notify_all();
@ -576,7 +571,6 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
// CRCStrip | RSCACKC | FCOE_WRFIX // CRCStrip | RSCACKC | FCOE_WRFIX
ixgbe[IXGBE_RDRXCTL].write(ixgbe[IXGBE_RDRXCTL].read() | (1 << 0) | (1 << 25) | (1 << 26)); ixgbe[IXGBE_RDRXCTL].write(ixgbe[IXGBE_RDRXCTL].read() | (1 << 0) | (1 << 25) | (1 << 26));
/* Not completed part /* Not completed part
// Program RXPBSIZE, MRQC, PFQDE, RTRUP2TC, MFLCN.RPFCE, and MFLCN.RFCE according to the DCB and virtualization modes (see Section 4.6.11.3). // Program RXPBSIZE, MRQC, PFQDE, RTRUP2TC, MFLCN.RPFCE, and MFLCN.RFCE according to the DCB and virtualization modes (see Section 4.6.11.3).
// 4.6.11.3.4 DCB-Off, VT-Off // 4.6.11.3.4 DCB-Off, VT-Off
@ -717,7 +711,8 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
// Program the HLREG0 register according to the required MAC behavior. // Program the HLREG0 register according to the required MAC behavior.
// TXCRCEN | RXCRCSTRP | TXPADEN | RXLNGTHERREN // TXCRCEN | RXCRCSTRP | TXPADEN | RXLNGTHERREN
// ixgbe[IXGBE_HLREG0].write(ixgbe[IXGBE_HLREG0].read() & !(1 << 0) & !(1 << 1)); // ixgbe[IXGBE_HLREG0].write(ixgbe[IXGBE_HLREG0].read() & !(1 << 0) & !(1 << 1));
ixgbe[IXGBE_HLREG0].write(ixgbe[IXGBE_HLREG0].read() | (1 << 0) | (1 << 1) | (1 << 10) | (1 << 27)); ixgbe[IXGBE_HLREG0]
.write(ixgbe[IXGBE_HLREG0].read() | (1 << 0) | (1 << 1) | (1 << 10) | (1 << 27));
// The following steps should be done once per transmit queue: // The following steps should be done once per transmit queue:
// 1. Allocate a region of memory for the transmit descriptor list. // 1. Allocate a region of memory for the transmit descriptor list.
@ -746,7 +741,6 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
ixgbe[IXGBE_TXDCTL].write(ixgbe[IXGBE_TXDCTL].read() | 1 << 25); ixgbe[IXGBE_TXDCTL].write(ixgbe[IXGBE_TXDCTL].read() | 1 << 25);
while ixgbe[IXGBE_TXDCTL].read() & (1 << 25) == 0 {} while ixgbe[IXGBE_TXDCTL].read() & (1 << 25) == 0 {}
// 4.6.6 Interrupt Initialization // 4.6.6 Interrupt Initialization
// The software driver associates between Tx and Rx interrupt causes and the EICR register by setting the IVAR[n] registers. // The software driver associates between Tx and Rx interrupt causes and the EICR register by setting the IVAR[n] registers.
// map Rx0 to interrupt 0 // map Rx0 to interrupt 0
@ -789,7 +783,6 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
let ixgbe_iface = IXGBEInterface { let ixgbe_iface = IXGBEInterface {
iface: Mutex::new(iface), iface: Mutex::new(iface),
sockets: Mutex::new(SocketSet::new(vec![])),
driver: net_driver.clone(), driver: net_driver.clone(),
name, name,
irq, irq,
@ -798,5 +791,4 @@ pub fn ixgbe_init(name: String, irq: Option<u32>, header: usize, size: usize) {
let driver = Arc::new(ixgbe_iface); let driver = Arc::new(ixgbe_iface);
DRIVERS.write().push(driver.clone()); DRIVERS.write().push(driver.clone());
NET_DRIVERS.write().push(driver); NET_DRIVERS.write().push(driver);
} }

@ -1,3 +1,3 @@
pub mod virtio_net;
pub mod e1000; pub mod e1000;
pub mod ixgbe; pub mod ixgbe;
pub mod virtio_net;

@ -6,25 +6,25 @@ use core::mem::size_of;
use core::slice; use core::slice;
use bitflags::*; use bitflags::*;
use device_tree::Node;
use device_tree::util::SliceRead; use device_tree::util::SliceRead;
use device_tree::Node;
use log::*; use log::*;
use rcore_memory::PAGE_SIZE;
use rcore_memory::paging::PageTable; use rcore_memory::paging::PageTable;
use rcore_memory::PAGE_SIZE;
use smoltcp::phy::{self, DeviceCapabilities}; use smoltcp::phy::{self, DeviceCapabilities};
use smoltcp::Result; use smoltcp::socket::SocketSet;
use smoltcp::time::Instant; use smoltcp::time::Instant;
use smoltcp::wire::{EthernetAddress, Ipv4Address}; use smoltcp::wire::{EthernetAddress, Ipv4Address};
use smoltcp::socket::SocketSet; use smoltcp::Result;
use volatile::{ReadOnly, Volatile}; use volatile::{ReadOnly, Volatile};
use crate::HEAP_ALLOCATOR;
use crate::memory::active_table; use crate::memory::active_table;
use crate::sync::SpinNoIrqLock as Mutex; use crate::sync::SpinNoIrqLock as Mutex;
use crate::sync::{MutexGuard, SpinNoIrq}; use crate::sync::{MutexGuard, SpinNoIrq};
use crate::HEAP_ALLOCATOR;
use super::super::{DeviceType, Driver, DRIVERS, NET_DRIVERS, NetDriver};
use super::super::bus::virtio_mmio::*; use super::super::bus::virtio_mmio::*;
use super::super::{DeviceType, Driver, NetDriver, DRIVERS, NET_DRIVERS};
pub struct VirtIONet { pub struct VirtIONet {
interrupt_parent: u32, interrupt_parent: u32,
@ -71,7 +71,6 @@ impl VirtIONet {
self.queues[VIRTIO_QUEUE_TRANSMIT].can_add(1, 0) self.queues[VIRTIO_QUEUE_TRANSMIT].can_add(1, 0)
} }
fn receive_available(&self) -> bool { fn receive_available(&self) -> bool {
self.queues[VIRTIO_QUEUE_RECEIVE].can_get() self.queues[VIRTIO_QUEUE_RECEIVE].can_get()
} }
@ -90,10 +89,6 @@ impl NetDriver for VirtIONetDriver {
unimplemented!() unimplemented!()
} }
fn sockets(&self) -> MutexGuard<SocketSet<'static, 'static, 'static>, SpinNoIrq> {
unimplemented!()
}
fn poll(&self) { fn poll(&self) {
unimplemented!() unimplemented!()
} }
@ -110,8 +105,10 @@ impl<'a> phy::Device<'a> for VirtIONetDriver {
let driver = self.0.lock(); let driver = self.0.lock();
if driver.transmit_available() && driver.receive_available() { if driver.transmit_available() && driver.receive_available() {
// potential racing // potential racing
Some((VirtIONetRxToken(self.clone()), Some((
VirtIONetTxToken(self.clone()))) VirtIONetRxToken(self.clone()),
VirtIONetTxToken(self.clone()),
))
} else { } else {
None None
} }
@ -136,7 +133,8 @@ impl<'a> phy::Device<'a> for VirtIONetDriver {
impl phy::RxToken for VirtIONetRxToken { impl phy::RxToken for VirtIONetRxToken {
fn consume<R, F>(self, _timestamp: Instant, f: F) -> Result<R> fn consume<R, F>(self, _timestamp: Instant, f: F) -> Result<R>
where F: FnOnce(&[u8]) -> Result<R> where
F: FnOnce(&[u8]) -> Result<R>,
{ {
let (input, output, _, user_data) = { let (input, output, _, user_data) = {
let mut driver = (self.0).0.lock(); let mut driver = (self.0).0.lock();
@ -156,7 +154,8 @@ impl phy::RxToken for VirtIONetRxToken {
impl phy::TxToken for VirtIONetTxToken { impl phy::TxToken for VirtIONetTxToken {
fn consume<R, F>(self, _timestamp: Instant, len: usize, f: F) -> Result<R> fn consume<R, F>(self, _timestamp: Instant, len: usize, f: F) -> Result<R>
where F: FnOnce(&mut [u8]) -> Result<R>, where
F: FnOnce(&mut [u8]) -> Result<R>,
{ {
let output = { let output = {
let mut driver = (self.0).0.lock(); let mut driver = (self.0).0.lock();
@ -169,12 +168,14 @@ impl phy::TxToken for VirtIONetTxToken {
} else { } else {
// allocate a page for buffer // allocate a page for buffer
let page = unsafe { let page = unsafe {
HEAP_ALLOCATOR.alloc_zeroed(Layout::from_size_align(PAGE_SIZE, PAGE_SIZE).unwrap()) HEAP_ALLOCATOR
.alloc_zeroed(Layout::from_size_align(PAGE_SIZE, PAGE_SIZE).unwrap())
} as usize; } as usize;
unsafe { slice::from_raw_parts_mut(page as *mut u8, PAGE_SIZE) } unsafe { slice::from_raw_parts_mut(page as *mut u8, PAGE_SIZE) }
} }
}; };
let output_buffer = &mut output[size_of::<VirtIONetHeader>()..(size_of::<VirtIONetHeader>() + len)]; let output_buffer =
&mut output[size_of::<VirtIONetHeader>()..(size_of::<VirtIONetHeader>() + len)];
let result = f(output_buffer); let result = f(output_buffer);
let mut driver = (self.0).0.lock(); let mut driver = (self.0).0.lock();
@ -183,7 +184,6 @@ impl phy::TxToken for VirtIONetTxToken {
} }
} }
bitflags! { bitflags! {
struct VirtIONetFeature : u64 { struct VirtIONetFeature : u64 {
const CSUM = 1 << 0; const CSUM = 1 << 0;
@ -234,7 +234,7 @@ bitflags! {
#[derive(Debug)] #[derive(Debug)]
struct VirtIONetworkConfig { struct VirtIONetworkConfig {
mac: [u8; 6], mac: [u8; 6],
status: ReadOnly<u16> status: ReadOnly<u16>,
} }
// virtio 5.1.6 Device Operation // virtio 5.1.6 Device Operation
@ -250,7 +250,6 @@ struct VirtIONetHeader {
// payload starts from here // payload starts from here
} }
pub fn virtio_net_init(node: &Node) { pub fn virtio_net_init(node: &Node) {
let reg = node.prop_raw("reg").unwrap(); let reg = node.prop_raw("reg").unwrap();
let from = reg.as_slice().read_be_u64(0).unwrap(); let from = reg.as_slice().read_be_u64(0).unwrap();
@ -283,8 +282,10 @@ pub fn virtio_net_init(node: &Node) {
interrupt_parent: node.prop_u32("interrupt-parent").unwrap(), interrupt_parent: node.prop_u32("interrupt-parent").unwrap(),
header: from as usize, header: from as usize,
mac: EthernetAddress(mac), mac: EthernetAddress(mac),
queues: [VirtIOVirtqueue::new(header, VIRTIO_QUEUE_RECEIVE, queue_num), queues: [
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num)], VirtIOVirtqueue::new(header, VIRTIO_QUEUE_RECEIVE, queue_num),
VirtIOVirtqueue::new(header, VIRTIO_QUEUE_TRANSMIT, queue_num),
],
}; };
// allocate a page for buffer // allocate a page for buffer

@ -1,2 +1,5 @@
mod structs;
mod test; mod test;
pub use self::structs::*;
pub use self::test::server; pub use self::test::server;

@ -0,0 +1,249 @@
use alloc::sync::Arc;
use core::fmt;
use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY};
use crate::process::structs::Process;
use crate::sync::SpinNoIrqLock as Mutex;
use crate::syscall::*;
use smoltcp::socket::*;
use smoltcp::wire::*;
lazy_static! {
pub static ref SOCKETS: Arc<Mutex<SocketSet<'static, 'static, 'static>>> =
Arc::new(Mutex::new(SocketSet::new(vec![])));
}
#[derive(Clone, Debug)]
pub struct TcpSocketState {
pub local_endpoint: Option<IpEndpoint>, // save local endpoint for bind()
pub is_listening: bool,
}
#[derive(Clone, Debug)]
pub struct UdpSocketState {
pub remote_endpoint: Option<IpEndpoint>, // remember remote endpoint for connect()
}
#[derive(Clone, Debug)]
pub enum SocketType {
Raw,
Tcp(TcpSocketState),
Udp(UdpSocketState),
Icmp,
}
#[derive(Debug)]
pub struct SocketWrapper {
pub handle: SocketHandle,
pub socket_type: SocketType,
}
pub fn get_ephemeral_port() -> u16 {
// TODO selects non-conflict high port
static mut EPHEMERAL_PORT: u16 = 49152;
unsafe {
if EPHEMERAL_PORT == 65535 {
EPHEMERAL_PORT = 49152;
} else {
EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
}
EPHEMERAL_PORT
}
}
/// Safety: call this without SOCKETS locked
pub fn poll_ifaces() {
for iface in NET_DRIVERS.read().iter() {
iface.poll();
}
}
impl Drop for SocketWrapper {
fn drop(&mut self) {
let mut sockets = SOCKETS.lock();
sockets.release(self.handle);
sockets.prune();
// send FIN immediately when applicable
drop(sockets);
poll_ifaces();
}
}
impl SocketWrapper {
pub fn write(&self, data: &[u8], sendto_endpoint: Option<IpEndpoint>) -> SysResult {
if let SocketType::Raw = self.socket_type {
if let Some(endpoint) = sendto_endpoint {
// temporary solution
let iface = &*(NET_DRIVERS.read()[0]);
let v4_src = iface.ipv4_address().unwrap();
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<RawSocket>(self.handle);
if let IpAddress::Ipv4(v4_dst) = endpoint.addr {
let len = data.len();
// using 20-byte IPv4 header
let mut buffer = vec![0u8; len + 20];
let mut packet = Ipv4Packet::new_unchecked(&mut buffer);
packet.set_version(4);
packet.set_header_len(20);
packet.set_total_len((20 + len) as u16);
packet.set_protocol(socket.ip_protocol().into());
packet.set_src_addr(v4_src);
packet.set_dst_addr(v4_dst);
let payload = packet.payload_mut();
payload.copy_from_slice(data);
packet.fill_checksum();
socket.send_slice(&buffer).unwrap();
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(len)
} else {
unimplemented!("ip type")
}
} else {
Err(SysError::ENOTCONN)
}
} else if let SocketType::Tcp(_) = self.socket_type {
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<TcpSocket>(self.handle);
if socket.is_open() {
if socket.can_send() {
match socket.send_slice(&data) {
Ok(size) => {
// avoid deadlock
drop(socket);
drop(sockets);
poll_ifaces();
Ok(size)
}
Err(err) => Err(SysError::ENOBUFS),
}
} else {
Err(SysError::ENOBUFS)
}
} else {
Err(SysError::ENOTCONN)
}
} else if let SocketType::Udp(ref state) = self.socket_type {
let remote_endpoint = {
if let Some(ref endpoint) = sendto_endpoint {
endpoint
} else if let Some(ref endpoint) = state.remote_endpoint {
endpoint
} else {
return Err(SysError::ENOTCONN);
}
};
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<UdpSocket>(self.handle);
if socket.endpoint().port == 0 {
let temp_port = get_ephemeral_port();
socket
.bind(IpEndpoint::new(IpAddress::Unspecified, temp_port))
.unwrap();
}
if socket.can_send() {
match socket.send_slice(&data, *remote_endpoint) {
Ok(()) => {
// avoid deadlock
drop(socket);
drop(sockets);
poll_ifaces();
Ok(data.len())
}
Err(err) => Err(SysError::ENOBUFS),
}
} else {
Err(SysError::ENOBUFS)
}
} else {
unimplemented!("socket type")
}
}
pub fn read(&self, data: &mut [u8]) -> (SysResult, IpEndpoint) {
if let SocketType::Raw = self.socket_type {
loop {
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<RawSocket>(self.handle);
if let Ok(size) = socket.recv_slice(data) {
let packet = Ipv4Packet::new_unchecked(data);
return (
Ok(size),
IpEndpoint {
addr: IpAddress::Ipv4(packet.src_addr()),
port: 0,
},
);
}
// avoid deadlock
drop(socket);
drop(sockets);
SOCKET_ACTIVITY._wait()
}
} else if let SocketType::Tcp(_) = self.socket_type {
spin_and_wait(&[&SOCKET_ACTIVITY], move || {
poll_ifaces();
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<TcpSocket>(self.handle);
if socket.is_open() {
if let Ok(size) = socket.recv_slice(data) {
if size > 0 {
let endpoint = socket.remote_endpoint();
// avoid deadlock
drop(socket);
drop(sockets);
poll_ifaces();
return Some((Ok(size), endpoint));
}
}
} else {
return Some((Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED));
}
None
})
} else if let SocketType::Udp(ref state) = self.socket_type {
loop {
let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<UdpSocket>(self.handle);
if socket.is_open() {
if let Ok((size, remote_endpoint)) = socket.recv_slice(data) {
let endpoint = remote_endpoint;
// avoid deadlock
drop(socket);
drop(sockets);
poll_ifaces();
return (Ok(size), endpoint);
}
} else {
return (Err(SysError::ENOTCONN), IpEndpoint::UNSPECIFIED);
}
// avoid deadlock
drop(socket);
SOCKET_ACTIVITY._wait()
}
} else {
unimplemented!("socket type")
}
}
}

@ -1,11 +1,12 @@
use crate::thread;
use crate::drivers::NET_DRIVERS;
use smoltcp::socket::*;
use crate::drivers::NetDriver; use crate::drivers::NetDriver;
use crate::drivers::NET_DRIVERS;
use crate::net::SOCKETS;
use crate::thread;
use alloc::vec; use alloc::vec;
use core::fmt::Write; use core::fmt::Write;
use smoltcp::socket::*;
pub extern fn server(_arg: usize) -> ! { pub extern "C" fn server(_arg: usize) -> ! {
if NET_DRIVERS.read().len() < 1 { if NET_DRIVERS.read().len() < 1 {
loop { loop {
thread::yield_now(); thread::yield_now();
@ -24,18 +25,15 @@ pub extern fn server(_arg: usize) -> ! {
let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 1024]); let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 1024]);
let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer); let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer);
let iface = &*(NET_DRIVERS.read()[0]); let mut sockets = SOCKETS.lock();
let mut sockets = iface.sockets();
let udp_handle = sockets.add(udp_socket); let udp_handle = sockets.add(udp_socket);
let tcp_handle = sockets.add(tcp_socket); let tcp_handle = sockets.add(tcp_socket);
let tcp2_handle = sockets.add(tcp2_socket); let tcp2_handle = sockets.add(tcp2_socket);
drop(sockets); drop(sockets);
drop(iface);
loop { loop {
{ {
let iface = &*(NET_DRIVERS.read()[0]); let mut sockets = SOCKETS.lock();
let mut sockets = iface.sockets();
// udp server // udp server
{ {
@ -45,10 +43,8 @@ pub extern fn server(_arg: usize) -> ! {
} }
let client = match socket.recv() { let client = match socket.recv() {
Ok((_, endpoint)) => { Ok((_, endpoint)) => Some(endpoint),
Some(endpoint) Err(_) => None,
}
Err(_) => None
}; };
if let Some(endpoint) = client { if let Some(endpoint) = client {
let hello = b"hello\n"; let hello = b"hello\n";
@ -85,5 +81,4 @@ pub extern fn server(_arg: usize) -> ! {
thread::yield_now(); thread::yield_now();
} }
} }

@ -4,8 +4,6 @@ use core::fmt;
use log::*; use log::*;
use spin::{Mutex, RwLock}; use spin::{Mutex, RwLock};
use xmas_elf::{ElfFile, header, program::{Flags, Type}}; use xmas_elf::{ElfFile, header, program::{Flags, Type}};
use smoltcp::socket::SocketHandle;
use smoltcp::wire::IpEndpoint;
use rcore_memory::PAGE_SIZE; use rcore_memory::PAGE_SIZE;
use rcore_thread::Tid; use rcore_thread::Tid;
@ -14,6 +12,7 @@ use crate::memory::{ByFrame, GlobalFrameAlloc, KernelStack, MemoryAttr, MemorySe
use crate::fs::{FileHandle, OpenOptions}; use crate::fs::{FileHandle, OpenOptions};
use crate::sync::Condvar; use crate::sync::Condvar;
use crate::drivers::NET_DRIVERS; use crate::drivers::NET_DRIVERS;
use crate::net::{SocketWrapper, SOCKETS};
use super::abi::{self, ProcInitInfo}; use super::abi::{self, ProcInitInfo};
@ -27,30 +26,6 @@ pub struct Thread {
pub proc: Arc<Mutex<Process>>, pub proc: Arc<Mutex<Process>>,
} }
#[derive(Clone, Debug)]
pub struct TcpSocketState {
pub local_endpoint: Option<IpEndpoint>, // save local endpoint for bind()
pub is_listening: bool,
}
#[derive(Clone, Debug)]
pub struct UdpSocketState {
pub remote_endpoint: Option<IpEndpoint>, // remember remote endpoint for connect()
}
#[derive(Clone, Debug)]
pub enum SocketType {
Raw,
Tcp(TcpSocketState),
Udp(UdpSocketState),
Icmp
}
#[derive(Debug)]
pub struct SocketWrapper {
pub handle: SocketHandle,
pub socket_type: SocketType,
}
#[derive(Clone)] #[derive(Clone)]
pub enum FileLike { pub enum FileLike {
@ -63,12 +38,7 @@ impl fmt::Debug for FileLike {
match self { match self {
FileLike::File(_) => write!(f, "File"), FileLike::File(_) => write!(f, "File"),
FileLike::Socket(wrapper) => { FileLike::Socket(wrapper) => {
match wrapper.socket_type { write!(f, "{:?}", wrapper)
SocketType::Raw => write!(f, "RawSocket"),
SocketType::Tcp(_) => write!(f, "TcpSocket"),
SocketType::Udp(_) => write!(f, "UdpSocket"),
SocketType::Icmp => write!(f, "IcmpSocket"),
}
}, },
} }
} }
@ -324,14 +294,12 @@ impl Thread {
debug!("fork: temporary copy data!"); debug!("fork: temporary copy data!");
let kstack = KernelStack::new(); let kstack = KernelStack::new();
for iface in NET_DRIVERS.read().iter() { let mut sockets = SOCKETS.lock();
let mut sockets = iface.sockets();
for (_fd, file) in files.iter() { for (_fd, file) in files.iter() {
if let FileLike::Socket(wrapper) = file { if let FileLike::Socket(wrapper) = file {
sockets.retain(wrapper.handle); sockets.retain(wrapper.handle);
} }
} }
}
Box::new(Thread { Box::new(Thread {

@ -2,7 +2,10 @@
use super::*; use super::*;
use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY}; use crate::drivers::{NET_DRIVERS, SOCKET_ACTIVITY};
use crate::process::structs::TcpSocketState; use crate::net::{
get_ephemeral_port, poll_ifaces, SocketType, SocketWrapper, TcpSocketState, UdpSocketState,
SOCKETS,
};
use core::cmp::min; use core::cmp::min;
use core::mem::size_of; use core::mem::size_of;
use smoltcp::socket::*; use smoltcp::socket::*;
@ -23,26 +26,12 @@ const IPPROTO_TCP: usize = 6;
const TCP_SENDBUF: usize = 512 * 1024; // 512K const TCP_SENDBUF: usize = 512 * 1024; // 512K
const TCP_RECVBUF: usize = 512 * 1024; // 512K const TCP_RECVBUF: usize = 512 * 1024; // 512K
fn get_ephemeral_port() -> u16 {
// TODO selects non-conflict high port
static mut EPHEMERAL_PORT: u16 = 49152;
unsafe {
if EPHEMERAL_PORT == 65535 {
EPHEMERAL_PORT = 49152;
} else {
EPHEMERAL_PORT = EPHEMERAL_PORT + 1;
}
EPHEMERAL_PORT
}
}
pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResult { pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResult {
info!( info!(
"socket: domain: {}, socket_type: {}, protocol: {}", "socket: domain: {}, socket_type: {}, protocol: {}",
domain, socket_type, protocol domain, socket_type, protocol
); );
let mut proc = process(); let mut proc = process();
let iface = &*(NET_DRIVERS.read()[0]);
match domain { match domain {
AF_INET | AF_UNIX => match socket_type & SOCK_TYPE_MASK { AF_INET | AF_UNIX => match socket_type & SOCK_TYPE_MASK {
SOCK_STREAM => { SOCK_STREAM => {
@ -52,7 +41,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; TCP_SENDBUF]); let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; TCP_SENDBUF]);
let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer);
let tcp_handle = iface.sockets().add(tcp_socket); let tcp_handle = SOCKETS.lock().add(tcp_socket);
proc.files.insert( proc.files.insert(
fd, fd,
FileLike::Socket(SocketWrapper { FileLike::Socket(SocketWrapper {
@ -75,7 +64,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 2048]); UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 2048]);
let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer);
let udp_handle = iface.sockets().add(udp_socket); let udp_handle = SOCKETS.lock().add(udp_socket);
proc.files.insert( proc.files.insert(
fd, fd,
FileLike::Socket(SocketWrapper { FileLike::Socket(SocketWrapper {
@ -102,7 +91,7 @@ pub fn sys_socket(domain: usize, socket_type: usize, protocol: usize) -> SysResu
raw_tx_buffer, raw_tx_buffer,
); );
let raw_handle = iface.sockets().add(raw_socket); let raw_handle = SOCKETS.lock().add(raw_socket);
proc.files.insert( proc.files.insert(
fd, fd,
FileLike::Socket(SocketWrapper { FileLike::Socket(SocketWrapper {
@ -211,8 +200,7 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu
let wrapper = &mut proc.get_socket_mut(fd)?; let wrapper = &mut proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
let iface = &*(NET_DRIVERS.read()[0]); let mut sockets = SOCKETS.lock();
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle); let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
let temp_port = get_ephemeral_port(); let temp_port = get_ephemeral_port();
@ -225,10 +213,9 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu
// wait for connection result // wait for connection result
loop { loop {
let iface = &*(NET_DRIVERS.read()[0]); poll_ifaces();
iface.poll();
let mut sockets = iface.sockets(); let mut sockets = SOCKETS.lock();
let socket = sockets.get::<TcpSocket>(wrapper.handle); let socket = sockets.get::<TcpSocket>(wrapper.handle);
if socket.state() == TcpState::SynSent { if socket.state() == TcpState::SynSent {
// still connecting // still connecting
@ -256,307 +243,73 @@ pub fn sys_connect(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResu
} }
pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult { pub fn sys_write_socket(proc: &mut Process, fd: usize, base: *const u8, len: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
let slice = unsafe { slice::from_raw_parts(base, len) }; let slice = unsafe { slice::from_raw_parts(base, len) };
if socket.is_open() { wrapper.write(&slice, None)
if socket.can_send() {
match socket.send_slice(&slice) {
Ok(size) => {
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(size)
}
Err(err) => Err(SysError::ENOBUFS),
}
} else {
Err(SysError::ENOBUFS)
}
} else {
Err(SysError::ENOTCONN)
}
} else if let SocketType::Udp(ref state) = wrapper.socket_type {
if let Some(ref remote_endpoint) = state.remote_endpoint {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
if socket.endpoint().port == 0 {
let v4_src = iface.ipv4_address().unwrap();
let temp_port = get_ephemeral_port();
socket
.bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port))
.unwrap();
}
let slice = unsafe { slice::from_raw_parts(base, len) };
if socket.is_open() {
if socket.can_send() {
match socket.send_slice(&slice, *remote_endpoint) {
Ok(()) => {
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(len)
}
Err(err) => Err(SysError::ENOBUFS),
}
} else {
Err(SysError::ENOBUFS)
}
} else {
Err(SysError::ENOTCONN)
}
} else {
Err(SysError::ENOTCONN)
}
} else {
unimplemented!("socket type")
}
} }
pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult { pub fn sys_read_socket(proc: &mut Process, fd: usize, base: *mut u8, len: usize) -> SysResult {
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type {
spin_and_wait(&[&SOCKET_ACTIVITY], move || {
iface.poll();
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
if socket.is_open() {
let mut slice = unsafe { slice::from_raw_parts_mut(base, len) };
if let Ok(size) = socket.recv_slice(&mut slice) {
if size > 0 {
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
return Some(Ok(size));
}
}
} else {
return Some(Err(SysError::ENOTCONN));
}
None
})
} else if let SocketType::Udp(_) = wrapper.socket_type {
loop {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
if socket.is_open() {
let mut slice = unsafe { slice::from_raw_parts_mut(base, len) }; let mut slice = unsafe { slice::from_raw_parts_mut(base, len) };
if let Ok((size, _)) = socket.recv_slice(&mut slice) { let (result, _) = wrapper.read(&mut slice);
// avoid deadlock result
drop(socket);
drop(sockets);
iface.poll();
return Ok(size);
}
} else {
return Err(SysError::ENOTCONN);
}
// avoid deadlock
drop(socket);
SOCKET_ACTIVITY._wait()
}
} else {
unimplemented!("socket type")
}
} }
pub fn sys_sendto( pub fn sys_sendto(
fd: usize, fd: usize,
buffer: *const u8, base: *const u8,
len: usize, len: usize,
flags: usize, flags: usize,
addr: *const SockAddr, addr: *const SockAddr,
addr_len: usize, addr_len: usize,
) -> SysResult { ) -> SysResult {
info!( info!(
"sys_sendto: fd: {} buffer: {:?} len: {} addr: {:?} addr_len: {}", "sys_sendto: fd: {} base: {:?} len: {} addr: {:?} addr_len: {}",
fd, buffer, len, addr, addr_len fd, base, len, addr, addr_len
); );
let mut proc = process(); let mut proc = process();
proc.memory_set.check_array(buffer, len)?; proc.memory_set.check_array(base, len)?;
let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?;
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
if let SocketType::Raw = wrapper.socket_type { let endpoint = sockaddr_to_endpoint(&mut proc, addr, addr_len)?;
let v4_src = iface.ipv4_address().unwrap(); let slice = unsafe { slice::from_raw_parts(base, len) };
let mut sockets = iface.sockets(); wrapper.write(&slice, Some(endpoint))
let mut socket = sockets.get::<RawSocket>(wrapper.handle);
if let IpAddress::Ipv4(v4_dst) = endpoint.addr {
let slice = unsafe { slice::from_raw_parts(buffer, len) };
// using 20-byte IPv4 header
let mut buffer = vec![0u8; len + 20];
let mut packet = Ipv4Packet::new_unchecked(&mut buffer);
packet.set_version(4);
packet.set_header_len(20);
packet.set_total_len((20 + len) as u16);
packet.set_protocol(socket.ip_protocol().into());
packet.set_src_addr(v4_src);
packet.set_dst_addr(v4_dst);
let payload = packet.payload_mut();
payload.copy_from_slice(slice);
packet.fill_checksum();
socket.send_slice(&buffer).unwrap();
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(len)
} else {
unimplemented!("ip type")
}
} else if let SocketType::Udp(_) = wrapper.socket_type {
let v4_src = iface.ipv4_address().unwrap();
let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
if socket.endpoint().port == 0 {
let temp_port = get_ephemeral_port();
socket
.bind(IpEndpoint::new(IpAddress::Ipv4(v4_src), temp_port))
.unwrap();
}
let slice = unsafe { slice::from_raw_parts(buffer, len) };
socket.send_slice(&slice, endpoint).unwrap();
// avoid deadlock
drop(socket);
drop(sockets);
iface.poll();
Ok(len)
} else {
unimplemented!("socket type")
}
} }
pub fn sys_recvfrom( pub fn sys_recvfrom(
fd: usize, fd: usize,
buffer: *mut u8, base: *mut u8,
len: usize, len: usize,
flags: usize, flags: usize,
addr: *mut SockAddr, addr: *mut SockAddr,
addr_len: *mut u32, addr_len: *mut u32,
) -> SysResult { ) -> SysResult {
info!( info!(
"sys_recvfrom: fd: {} buffer: {:?} len: {} flags: {} addr: {:?} addr_len: {:?}", "sys_recvfrom: fd: {} base: {:?} len: {} flags: {} addr: {:?} addr_len: {:?}",
fd, buffer, len, flags, addr, addr_len fd, base, len, flags, addr, addr_len
); );
let mut proc = process(); let mut proc = process();
proc.memory_set.check_mut_array(buffer, len)?; proc.memory_set.check_mut_array(base, len)?;
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket(fd)?; let wrapper = proc.get_socket(fd)?;
// TODO: move some part of these into one generic function let mut slice = unsafe { slice::from_raw_parts_mut(base, len) };
if let SocketType::Raw = wrapper.socket_type { let (result, endpoint) = wrapper.read(&mut slice);
loop {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<RawSocket>(wrapper.handle);
let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) };
if let Ok(size) = socket.recv_slice(&mut slice) {
let packet = Ipv4Packet::new_unchecked(&slice);
if !addr.is_null() {
// FIXME: check size as per sin_family
let sockaddr_in = SockAddr::from(IpEndpoint {
addr: IpAddress::Ipv4(packet.src_addr()),
port: 0,
});
unsafe {
sockaddr_in.write_to(&mut proc, addr, addr_len)?;
}
}
return Ok(size);
}
// avoid deadlock
drop(socket);
drop(sockets);
SOCKET_ACTIVITY._wait()
}
} else if let SocketType::Udp(_) = wrapper.socket_type {
loop {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) }; if result.is_ok() && !addr.is_null() {
if let Ok((size, endpoint)) = socket.recv_slice(&mut slice) {
if !addr.is_null() {
let sockaddr_in = SockAddr::from(endpoint); let sockaddr_in = SockAddr::from(endpoint);
unsafe { unsafe {
sockaddr_in.write_to(&mut proc, addr, addr_len)?; sockaddr_in.write_to(&mut proc, addr, addr_len)?;
} }
} }
return Ok(size); result
}
// avoid deadlock
drop(socket);
drop(sockets);
SOCKET_ACTIVITY._wait()
}
} else if let SocketType::Tcp(_) = wrapper.socket_type {
loop {
let mut sockets = iface.sockets();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
let mut slice = unsafe { slice::from_raw_parts_mut(buffer, len) };
if let Ok(size) = socket.recv_slice(&mut slice) {
if !addr.is_null() {
let sockaddr_in = SockAddr::from(socket.remote_endpoint());
unsafe {
sockaddr_in.write_to(&mut proc, addr, addr_len)?;
}
}
return Ok(size);
}
// avoid deadlock
drop(socket);
drop(sockets);
SOCKET_ACTIVITY._wait()
}
} else {
unimplemented!("socket type")
}
} }
impl Clone for SocketWrapper { impl Clone for SocketWrapper {
fn clone(&self) -> Self { fn clone(&self) -> Self {
let iface = &*(NET_DRIVERS.read()[0]); let mut sockets = SOCKETS.lock();
let mut sockets = iface.sockets();
sockets.retain(self.handle); sockets.retain(self.handle);
SocketWrapper { SocketWrapper {
@ -566,19 +319,6 @@ impl Clone for SocketWrapper {
} }
} }
impl Drop for SocketWrapper {
fn drop(&mut self) {
let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets();
sockets.release(self.handle);
sockets.prune();
// send FIN immediately when applicable
drop(sockets);
iface.poll();
}
}
pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult { pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult {
info!("sys_bind: fd: {} addr: {:?} len: {}", fd, addr, addr_len); info!("sys_bind: fd: {} addr: {:?} len: {}", fd, addr, addr_len);
let mut proc = process(); let mut proc = process();
@ -589,7 +329,6 @@ pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult
} }
info!("sys_bind: fd: {} bind to {}", fd, endpoint); info!("sys_bind: fd: {} bind to {}", fd, endpoint);
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = &mut proc.get_socket_mut(fd)?; let wrapper = &mut proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
wrapper.socket_type = SocketType::Tcp(TcpSocketState { wrapper.socket_type = SocketType::Tcp(TcpSocketState {
@ -598,7 +337,7 @@ pub fn sys_bind(fd: usize, addr: *const SockAddr, addr_len: usize) -> SysResult
}); });
Ok(0) Ok(0)
} else if let SocketType::Udp(_) = wrapper.socket_type { } else if let SocketType::Udp(_) = wrapper.socket_type {
let mut sockets = iface.sockets(); let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<UdpSocket>(wrapper.handle); let mut socket = sockets.get::<UdpSocket>(wrapper.handle);
match socket.bind(endpoint) { match socket.bind(endpoint) {
Ok(()) => Ok(0), Ok(()) => Ok(0),
@ -615,14 +354,13 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult {
// open multiple sockets for each connection // open multiple sockets for each connection
let mut proc = process(); let mut proc = process();
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(ref mut tcp_state) = wrapper.socket_type { if let SocketType::Tcp(ref mut tcp_state) = wrapper.socket_type {
if tcp_state.is_listening { if tcp_state.is_listening {
// it is ok to listen twice // it is ok to listen twice
Ok(0) Ok(0)
} else if let Some(local_endpoint) = tcp_state.local_endpoint { } else if let Some(local_endpoint) = tcp_state.local_endpoint {
let mut sockets = iface.sockets(); let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle); let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
info!("socket {} listening on {:?}", fd, local_endpoint); info!("socket {} listening on {:?}", fd, local_endpoint);
@ -649,10 +387,9 @@ pub fn sys_shutdown(fd: usize, how: usize) -> SysResult {
info!("sys_shutdown: fd: {} how: {}", fd, how); info!("sys_shutdown: fd: {} how: {}", fd, how);
let mut proc = process(); let mut proc = process();
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets(); let mut sockets = SOCKETS.lock();
let mut socket = sockets.get::<TcpSocket>(wrapper.handle); let mut socket = sockets.get::<TcpSocket>(wrapper.handle);
socket.close(); socket.close();
Ok(0) Ok(0)
@ -686,8 +423,7 @@ pub fn sys_accept(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResu
if let SocketType::Tcp(tcp_state) = wrapper.socket_type.clone() { if let SocketType::Tcp(tcp_state) = wrapper.socket_type.clone() {
if let Some(endpoint) = tcp_state.local_endpoint { if let Some(endpoint) = tcp_state.local_endpoint {
loop { loop {
let iface = &*(NET_DRIVERS.read()[0]); let mut sockets = SOCKETS.lock();
let mut sockets = iface.sockets();
let socket = sockets.get::<TcpSocket>(wrapper.handle); let socket = sockets.get::<TcpSocket>(wrapper.handle);
if socket.is_active() { if socket.is_active() {
@ -736,14 +472,13 @@ pub fn sys_accept(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> SysResu
drop(sockets); drop(sockets);
drop(proc); drop(proc);
iface.poll(); poll_ifaces();
return Ok(new_fd); return Ok(new_fd);
} }
// avoid deadlock // avoid deadlock
drop(socket); drop(socket);
drop(sockets); drop(sockets);
drop(iface);
SOCKET_ACTIVITY._wait() SOCKET_ACTIVITY._wait()
} }
} else { } else {
@ -767,7 +502,6 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
return Err(SysError::EINVAL); return Err(SysError::EINVAL);
} }
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(state) = &wrapper.socket_type { if let SocketType::Tcp(state) = &wrapper.socket_type {
if let Some(endpoint) = state.local_endpoint { if let Some(endpoint) = state.local_endpoint {
@ -777,7 +511,7 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
} }
Ok(0) Ok(0)
} else { } else {
let mut sockets = iface.sockets(); let mut sockets = SOCKETS.lock();
let socket = sockets.get::<TcpSocket>(wrapper.handle); let socket = sockets.get::<TcpSocket>(wrapper.handle);
let endpoint = socket.local_endpoint(); let endpoint = socket.local_endpoint();
if endpoint.port != 0 { if endpoint.port != 0 {
@ -791,7 +525,7 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
} }
} }
} else if let SocketType::Udp(_) = &wrapper.socket_type { } else if let SocketType::Udp(_) = &wrapper.socket_type {
let mut sockets = iface.sockets(); let mut sockets = SOCKETS.lock();
let socket = sockets.get::<UdpSocket>(wrapper.handle); let socket = sockets.get::<UdpSocket>(wrapper.handle);
let endpoint = socket.endpoint(); let endpoint = socket.endpoint();
if endpoint.port != 0 { if endpoint.port != 0 {
@ -822,10 +556,9 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockAddr, addr_len: *mut u32) -> Sy
return Err(SysError::EINVAL); return Err(SysError::EINVAL);
} }
let iface = &*(NET_DRIVERS.read()[0]);
let wrapper = proc.get_socket_mut(fd)?; let wrapper = proc.get_socket_mut(fd)?;
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
let mut sockets = iface.sockets(); let mut sockets = SOCKETS.lock();
let socket = sockets.get::<TcpSocket>(wrapper.handle); let socket = sockets.get::<TcpSocket>(wrapper.handle);
if socket.is_open() { if socket.is_open() {
@ -860,8 +593,7 @@ pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) {
let mut output = false; let mut output = false;
let mut err = false; let mut err = false;
if let SocketType::Tcp(state) = wrapper.socket_type.clone() { if let SocketType::Tcp(state) = wrapper.socket_type.clone() {
let iface = &*(NET_DRIVERS.read()[0]); let mut sockets = SOCKETS.lock();
let mut sockets = iface.sockets();
let socket = sockets.get::<TcpSocket>(wrapper.handle); let socket = sockets.get::<TcpSocket>(wrapper.handle);
if state.is_listening && socket.is_active() { if state.is_listening && socket.is_active() {
@ -879,8 +611,7 @@ pub fn poll_socket(wrapper: &SocketWrapper) -> (bool, bool, bool) {
} }
} }
} else if let SocketType::Udp(_) = wrapper.socket_type { } else if let SocketType::Udp(_) = wrapper.socket_type {
let iface = &*(NET_DRIVERS.read()[0]); let mut sockets = SOCKETS.lock();
let mut sockets = iface.sockets();
let socket = sockets.get::<UdpSocket>(wrapper.handle); let socket = sockets.get::<UdpSocket>(wrapper.handle);
if socket.can_recv() { if socket.can_recv() {

Loading…
Cancel
Save