Optimize vsock code structure

This commit is contained in:
Anmin Liu
2024-05-14 10:36:14 +00:00
committed by Tate, Hongliang Tian
parent 60dd17fdd3
commit 646406115e
16 changed files with 189 additions and 94 deletions

View File

@ -53,7 +53,7 @@ pub enum VsockEventType {
/// The peer requests to establish a connection with us.
ConnectionRequest,
/// The connection was successfully established.
Connected,
ConnectionResponse,
/// The connection was closed.
Disconnected {
/// The reason for the disconnection.
@ -107,7 +107,7 @@ impl VsockEvent {
}
VirtioVsockOp::Response => {
header.check_data_is_empty()?;
VsockEventType::Connected
VsockEventType::ConnectionResponse
}
VirtioVsockOp::CreditUpdate => {
header.check_data_is_empty()?;

View File

@ -9,7 +9,7 @@ use log::debug;
use pod::Pod;
use super::{
buffer::{RxBuffer, RX_BUFFER_LEN},
buffer::RxBuffer,
config::{VirtioVsockConfig, VsockFeatures},
connect::{ConnectionInfo, VsockEvent},
error::SocketError,
@ -99,6 +99,7 @@ impl SocketDevice {
fn handle_vsock_event(_: &TrapFrame) {
handle_recv_irq(super::DEVICE_NAME);
}
// FIXME: handle event virtqueue notification in live migration
device
.transport
@ -184,6 +185,7 @@ impl SocketDevice {
header: &VirtioVsockHdr,
buffer: &[u8],
) -> Result<(), SocketError> {
debug!("Sent packet {:?}. Op {:?}", header, header.op());
debug!("buffer in send_packet_to_tx_queue: {:?}", buffer);
let tx_buffer = TxBuffer::new(header, buffer);
@ -254,9 +256,8 @@ impl SocketDevice {
/// Receive bytes from peer, returns the header
pub fn receive(
&mut self,
buffer: &mut [u8],
// connection_info: &mut ConnectionInfo,
) -> Result<VirtioVsockHdr, SocketError> {
) -> Result<RxBuffer, SocketError> {
let (token, len) = self.recv_queue.pop_used()?;
debug!(
"receive packet in rx_queue: token = {}, len = {}",
@ -268,21 +269,10 @@ impl SocketDevice {
.ok_or(QueueError::WrongToken)?;
rx_buffer.set_packet_len(len as usize);
let mut buf_reader = rx_buffer.buf();
let mut temp_buffer = vec![0u8; buf_reader.remain()];
buf_reader.read(&mut VmWriter::from(&mut temp_buffer as &mut [u8]));
let new_rx_buffer = RxBuffer::new(size_of::<VirtioVsockHdr>());
self.add_rx_buffer(new_rx_buffer, token)?;
let (header, payload) = read_header_and_body(&temp_buffer)?;
// The length written should be equal to len(header)+len(packet)
assert_eq!(len, header.len() + VIRTIO_VSOCK_HDR_LEN as u32);
debug!("Received packet {:?}. Op {:?}", header, header.op());
debug!("body is {:?}", payload);
assert!(buffer.len() >= payload.len());
buffer[..payload.len()].copy_from_slice(payload);
self.add_rx_buffer(rx_buffer, token)?;
Ok(header)
Ok(rx_buffer)
}
/// Polls the RX virtqueue for the next event, and calls the given handler function to handle it.
@ -294,10 +284,17 @@ impl SocketDevice {
if !self.recv_queue.can_pop() {
return Ok(None);
}
let mut body = vec![0u8; RX_BUFFER_LEN];
let header = self.receive(&mut body)?;
let rx_buffer = self.receive()?;
VsockEvent::from_header(&header).and_then(|event| handler(event, &body))
let mut buf_reader = rx_buffer.buf();
let mut temp_buffer = vec![0u8; buf_reader.remain()];
buf_reader.read(&mut VmWriter::from(&mut temp_buffer as &mut [u8]));
let (header, payload) = read_header_and_body(&temp_buffer)?;
// The length written should be equal to len(header)+len(packet)
debug!("Received packet {:?}. Op {:?}", header, header.op());
debug!("body is {:?}", payload);
VsockEvent::from_header(&header).and_then(|event| handler(event, payload))
}
/// Add a used rx buffer to recv queue,@index is only to check the correctness

View File

@ -21,18 +21,18 @@ pub fn register_device(name: String, device: Arc<SpinLock<SocketDevice>>) {
VSOCK_DEVICE_TABLE
.get()
.unwrap()
.lock()
.lock_irq_disabled()
.insert(name, (Arc::new(SpinLock::new(Vec::new())), device));
}
pub fn get_device(str: &str) -> Option<Arc<SpinLock<SocketDevice>>> {
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock();
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock_irq_disabled();
let (_, device) = lock.get(str)?;
Some(device.clone())
}
pub fn all_devices() -> Vec<(String, Arc<SpinLock<SocketDevice>>)> {
let vsock_devs = VSOCK_DEVICE_TABLE.get().unwrap().lock();
let vsock_devs = VSOCK_DEVICE_TABLE.get().unwrap().lock_irq_disabled();
vsock_devs
.iter()
.map(|(name, (_, device))| (name.clone(), device.clone()))
@ -40,20 +40,19 @@ pub fn all_devices() -> Vec<(String, Arc<SpinLock<SocketDevice>>)> {
}
pub fn register_recv_callback(name: &str, callback: impl VsockDeviceIrqHandler) {
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock();
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock_irq_disabled();
let Some((callbacks, _)) = lock.get(name) else {
return;
};
callbacks.lock().push(Arc::new(callback));
callbacks.lock_irq_disabled().push(Arc::new(callback));
}
pub fn handle_recv_irq(name: &str) {
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock();
let lock = VSOCK_DEVICE_TABLE.get().unwrap().lock_irq_disabled();
let Some((callbacks, _)) = lock.get(name) else {
return;
};
let callbacks = callbacks.clone();
let lock = callbacks.lock();
let lock = callbacks.lock_irq_disabled();
for callback in lock.iter() {
callback.call(())
}