Fix sys_accept deadlock

master
Jiajie Chen 6 years ago
parent 6697861860
commit d041884cc2

@ -371,9 +371,7 @@ pub fn sys_sendto(
let slice = unsafe { slice::from_raw_parts(buffer, len) }; let slice = unsafe { slice::from_raw_parts(buffer, len) };
socket socket.send_slice(&slice, endpoint).unwrap();
.send_slice(&slice, endpoint)
.unwrap();
// avoid deadlock // avoid deadlock
drop(socket); drop(socket);
@ -433,7 +431,9 @@ pub fn sys_recvfrom(
addr: IpAddress::Ipv4(packet.src_addr()), addr: IpAddress::Ipv4(packet.src_addr()),
port: 0, port: 0,
}); });
unsafe { sockaddr_in.write_to(addr, addr_len); } unsafe {
sockaddr_in.write_to(addr, addr_len);
}
} }
return Ok(size); return Ok(size);
@ -453,7 +453,9 @@ pub fn sys_recvfrom(
if let Ok((size, endpoint)) = socket.recv_slice(&mut slice) { if let Ok((size, endpoint)) = socket.recv_slice(&mut slice) {
if !addr.is_null() { if !addr.is_null() {
let sockaddr_in = SockaddrIn::from(endpoint); let sockaddr_in = SockaddrIn::from(endpoint);
unsafe { sockaddr_in.write_to(addr, addr_len); } unsafe {
sockaddr_in.write_to(addr, addr_len);
}
} }
return Ok(size); return Ok(size);
@ -473,7 +475,9 @@ pub fn sys_recvfrom(
if let Ok(size) = socket.recv_slice(&mut slice) { if let Ok(size) = socket.recv_slice(&mut slice) {
if !addr.is_null() { if !addr.is_null() {
let sockaddr_in = SockaddrIn::from(socket.remote_endpoint()); let sockaddr_in = SockaddrIn::from(socket.remote_endpoint());
unsafe { sockaddr_in.write_to(addr, addr_len); } unsafe {
sockaddr_in.write_to(addr, addr_len);
}
} }
return Ok(size); return Ok(size);
@ -497,7 +501,7 @@ impl Clone for SocketWrapper {
SocketWrapper { SocketWrapper {
handle: self.handle.clone(), handle: self.handle.clone(),
socket_type: self.socket_type.clone() socket_type: self.socket_type.clone(),
} }
} }
} }
@ -534,7 +538,7 @@ pub fn sys_bind(fd: usize, addr: *const SockaddrIn, len: usize) -> SysResult {
if let SocketType::Tcp(_) = wrapper.socket_type { if let SocketType::Tcp(_) = wrapper.socket_type {
wrapper.socket_type = SocketType::Tcp(TcpSocketState { wrapper.socket_type = SocketType::Tcp(TcpSocketState {
local_endpoint: Some(endpoint), local_endpoint: Some(endpoint),
is_listening: false is_listening: false,
}); });
Ok(0) Ok(0)
} else { } else {
@ -564,10 +568,8 @@ pub fn sys_listen(fd: usize, backlog: usize) -> SysResult {
Ok(()) => { Ok(()) => {
tcp_state.is_listening = true; tcp_state.is_listening = true;
Ok(0) Ok(0)
}, }
Err(err) => { Err(err) => Err(SysError::EINVAL),
Err(SysError::EINVAL)
},
} }
} else { } else {
Ok(0) Ok(0)
@ -639,6 +641,7 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe
tcp_socket.listen(endpoint).unwrap(); tcp_socket.listen(endpoint).unwrap();
let tcp_handle = sockets.add(tcp_socket); let tcp_handle = sockets.add(tcp_socket);
let mut orig_socket = proc let mut orig_socket = proc
.files .files
.insert( .insert(
@ -650,21 +653,22 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe
) )
.unwrap(); .unwrap();
if let FileLike::Socket(wrapper) = orig_socket { if let FileLike::Socket(ref mut wrapper) = orig_socket {
proc.files.insert(new_fd, FileLike::Socket(SocketWrapper { if let SocketType::Tcp(ref mut state) = wrapper.socket_type {
handle: wrapper.handle, state.is_listening = false;
socket_type: SocketType::Tcp(TcpSocketState {
local_endpoint: Some(endpoint),
is_listening: false,
})
}));
} else { } else {
panic!("impossible"); panic!("impossible");
} }
} else {
panic!("impossible");
}
proc.files.insert(new_fd, orig_socket);
if !addr.is_null() { if !addr.is_null() {
let sockaddr_in = SockaddrIn::from(remote_endpoint); let sockaddr_in = SockaddrIn::from(remote_endpoint);
unsafe { sockaddr_in.write_to(addr, addr_len); } unsafe {
sockaddr_in.write_to(addr, addr_len);
}
} }
return Ok(new_fd); return Ok(new_fd);
} }
@ -672,6 +676,7 @@ pub fn sys_accept(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) -> SysRe
// avoid deadlock // avoid deadlock
drop(socket); drop(socket);
drop(sockets); drop(sockets);
drop(iface);
SOCKET_ACTIVITY._wait() SOCKET_ACTIVITY._wait()
} }
} else { } else {
@ -711,7 +716,9 @@ pub fn sys_getsockname(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) ->
if let SocketType::Tcp(state) = &wrapper.socket_type { if let SocketType::Tcp(state) = &wrapper.socket_type {
if let Some(endpoint) = state.local_endpoint { if let Some(endpoint) = state.local_endpoint {
let sockaddr_in = SockaddrIn::from(endpoint); let sockaddr_in = SockaddrIn::from(endpoint);
unsafe { sockaddr_in.write_to(addr, addr_len); } unsafe {
sockaddr_in.write_to(addr, addr_len);
}
Ok(0) Ok(0)
} else { } else {
Err(SysError::EINVAL) Err(SysError::EINVAL)
@ -753,7 +760,9 @@ pub fn sys_getpeername(fd: usize, addr: *mut SockaddrIn, addr_len: *mut u32) ->
if socket.is_open() { if socket.is_open() {
let remote_endpoint = socket.remote_endpoint(); let remote_endpoint = socket.remote_endpoint();
let sockaddr_in = SockaddrIn::from(remote_endpoint); let sockaddr_in = SockaddrIn::from(remote_endpoint);
unsafe { sockaddr_in.write_to(addr, addr_len); } unsafe {
sockaddr_in.write_to(addr, addr_len);
}
Ok(0) Ok(0)
} else { } else {
Err(SysError::EINVAL) Err(SysError::EINVAL)
@ -799,10 +808,7 @@ pub fn sys_dup2_socket(proc: &mut Process, wrapper: SocketWrapper, fd: usize) ->
let iface = &*(NET_DRIVERS.read()[0]); let iface = &*(NET_DRIVERS.read()[0]);
let mut sockets = iface.sockets(); let mut sockets = iface.sockets();
sockets.retain(wrapper.handle); sockets.retain(wrapper.handle);
proc.files.insert( proc.files.insert(fd, FileLike::Socket(wrapper));
fd,
FileLike::Socket(wrapper),
);
Ok(fd) Ok(fd)
} }
@ -817,15 +823,13 @@ pub struct SockaddrIn {
impl From<IpEndpoint> for SockaddrIn { impl From<IpEndpoint> for SockaddrIn {
fn from(endpoint: IpEndpoint) -> Self { fn from(endpoint: IpEndpoint) -> Self {
match endpoint.addr { match endpoint.addr {
IpAddress::Ipv4(ipv4) => { IpAddress::Ipv4(ipv4) => SockaddrIn {
SockaddrIn {
sin_family: AF_INET as u16, sin_family: AF_INET as u16,
sin_port: u16::to_be(endpoint.port), sin_port: u16::to_be(endpoint.port),
sin_addr: u32::to_be(u32::from_be_bytes(ipv4.0)), sin_addr: u32::to_be(u32::from_be_bytes(ipv4.0)),
sin_zero: [0; 8], sin_zero: [0; 8],
} },
} _ => unimplemented!("ipv6"),
_ => unimplemented!("ipv6")
} }
} }
} }
@ -836,11 +840,13 @@ impl SockaddrIn {
if self.sin_family == AF_INET as u16 { if self.sin_family == AF_INET as u16 {
let port = u16::from_be(self.sin_port); let port = u16::from_be(self.sin_port);
let addr = IpAddress::from(Ipv4Address::from_bytes( let addr = IpAddress::from(Ipv4Address::from_bytes(
&u32::from_be(self.sin_addr).to_be_bytes()[..] &u32::from_be(self.sin_addr).to_be_bytes()[..],
)); ));
Ok((addr, port).into()) Ok((addr, port).into())
} else if self.sin_family == AF_UNIX as u16 { } else if self.sin_family == AF_UNIX as u16 {
debug!("unix socket path {}", unsafe {util::from_cstr((self as *const SockaddrIn as *const u8).add(2)) }); debug!("unix socket path {}", unsafe {
util::from_cstr((self as *const SockaddrIn as *const u8).add(2))
});
Err(SysError::EINVAL) Err(SysError::EINVAL)
} else { } else {
Err(SysError::EINVAL) Err(SysError::EINVAL)

@ -3,8 +3,8 @@
use super::*; use super::*;
use crate::arch::consts::USEC_PER_TICK; use crate::arch::consts::USEC_PER_TICK;
use crate::arch::driver::rtc_cmos; use crate::arch::driver::rtc_cmos;
use lazy_static::lazy_static;
use core::time::Duration; use core::time::Duration;
use lazy_static::lazy_static;
lazy_static! { lazy_static! {
pub static ref EPOCH_BASE: u64 = unsafe { rtc_cmos::read_epoch() }; pub static ref EPOCH_BASE: u64 = unsafe { rtc_cmos::read_epoch() };

Loading…
Cancel
Save