Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add uninit buffer ancillary APIs #1108

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions src/net/send_recv/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::net::UCred;

use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem::{align_of, size_of, size_of_val, take};
use core::mem::{align_of, size_of, size_of_val, take, MaybeUninit};
#[cfg(linux_kernel)]
use core::ptr::addr_of;
use core::{ptr, slice};
Expand All @@ -24,25 +24,31 @@ use super::{RecvFlags, SendFlags, SocketAddrAny, SocketAddrV4, SocketAddrV6};
///
/// Allocate a buffer for a single file descriptor:
/// ```
/// # use core::mem::MaybeUninit;
/// # use rustix::cmsg_space;
/// let mut space = [0; rustix::cmsg_space!(ScmRights(1))];
/// let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(1))];
/// # let _: &[MaybeUninit<u8>] = space.as_slice();
/// ```
///
/// Allocate a buffer for credentials:
/// ```
/// # #[cfg(linux_kernel)]
/// # {
/// # use core::mem::MaybeUninit;
/// # use rustix::cmsg_space;
/// let mut space = [0; rustix::cmsg_space!(ScmCredentials(1))];
/// let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmCredentials(1))];
/// # let _: &[MaybeUninit<u8>] = space.as_slice();
/// # }
/// ```
///
/// Allocate a buffer for two file descriptors and credentials:
/// ```
/// # #[cfg(linux_kernel)]
/// # {
/// # use core::mem::MaybeUninit;
/// # use rustix::cmsg_space;
/// let mut space = [0; rustix::cmsg_space!(ScmRights(2), ScmCredentials(1))];
/// let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(2), ScmCredentials(1))];
/// # let _: &[MaybeUninit<u8>] = space.as_slice();
/// # }
/// ```
#[macro_export]
Expand Down Expand Up @@ -164,7 +170,7 @@ pub enum RecvAncillaryMessage<'a> {
/// [`push`]: SendAncillaryBuffer::push
pub struct SendAncillaryBuffer<'buf, 'slice, 'fd> {
/// Raw byte buffer for messages.
buffer: &'buf mut [u8],
buffer: &'buf mut [MaybeUninit<u8>],

/// The amount of the buffer that is used.
length: usize,
Expand All @@ -179,6 +185,12 @@ impl<'buf> From<&'buf mut [u8]> for SendAncillaryBuffer<'buf, '_, '_> {
}
}

impl<'buf> From<&'buf mut [MaybeUninit<u8>]> for SendAncillaryBuffer<'buf, '_, '_> {
fn from(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self::new_(buffer)
}
}

impl Default for SendAncillaryBuffer<'_, '_, '_> {
fn default() -> Self {
Self {
Expand Down Expand Up @@ -231,6 +243,11 @@ impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> {
/// [`send`]: crate::net::send
#[inline]
pub fn new(buffer: &'buf mut [u8]) -> Self {
// SAFETY: T -> MaybeUninit<T> is always safe and we never uninitialize any bytes.
Self::new_(unsafe { core::mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buffer) })
}

fn new_(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self {
buffer: align_for_cmsghdr(buffer),
length: 0,
Expand All @@ -248,7 +265,7 @@ impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> {
return core::ptr::null_mut();
}

self.buffer.as_mut_ptr()
self.buffer.as_mut_ptr().cast()
}

/// Returns the length of the message data.
Expand Down Expand Up @@ -301,7 +318,7 @@ impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> {
let buffer = leap!(self.buffer.get_mut(..new_length));

// Fill the new part of the buffer with zeroes.
buffer[self.length..new_length].fill(0);
buffer[self.length..new_length].fill(MaybeUninit::new(0));
self.length = new_length;

// Get the last header in the buffer.
Expand Down Expand Up @@ -339,7 +356,7 @@ impl<'slice, 'fd> Extend<SendAncillaryMessage<'slice, 'fd>>
#[derive(Default)]
pub struct RecvAncillaryBuffer<'buf> {
/// Raw byte buffer for messages.
buffer: &'buf mut [u8],
buffer: &'buf mut [MaybeUninit<u8>],

/// The portion of the buffer we've read from already.
read: usize,
Expand All @@ -354,6 +371,12 @@ impl<'buf> From<&'buf mut [u8]> for RecvAncillaryBuffer<'buf> {
}
}

impl<'buf> From<&'buf mut [MaybeUninit<u8>]> for RecvAncillaryBuffer<'buf> {
fn from(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self::new_(buffer)
}
}

impl<'buf> RecvAncillaryBuffer<'buf> {
/// Create a new, empty `RecvAncillaryBuffer` from a raw byte buffer.
///
Expand Down Expand Up @@ -396,6 +419,11 @@ impl<'buf> RecvAncillaryBuffer<'buf> {
/// [`recv`]: crate::net::recv
#[inline]
pub fn new(buffer: &'buf mut [u8]) -> Self {
// SAFETY: T -> MaybeUninit<T> is always safe and we never uninitialize any bytes.
Self::new_(unsafe { core::mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buffer) })
}

fn new_(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self {
buffer: align_for_cmsghdr(buffer),
read: 0,
Expand All @@ -413,7 +441,7 @@ impl<'buf> RecvAncillaryBuffer<'buf> {
return core::ptr::null_mut();
}

self.buffer.as_mut_ptr()
self.buffer.as_mut_ptr().cast()
}

/// Returns the length of the message data.
Expand Down Expand Up @@ -454,7 +482,7 @@ impl Drop for RecvAncillaryBuffer<'_> {
/// Return a slice of `buffer` starting at the first `cmsghdr` alignment
/// boundary.
#[inline]
fn align_for_cmsghdr(buffer: &mut [u8]) -> &mut [u8] {
fn align_for_cmsghdr(buffer: &mut [MaybeUninit<u8>]) -> &mut [MaybeUninit<u8>] {
// If the buffer is empty, we won't be writing anything into it, so it
// doesn't need to be aligned.
if buffer.is_empty() {
Expand Down Expand Up @@ -920,6 +948,7 @@ mod messages {
use crate::backend::net::msghdr;
use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::ptr::NonNull;

/// An iterator over the messages in an ancillary buffer.
Expand All @@ -933,12 +962,16 @@ mod messages {
header: Option<NonNull<c::cmsghdr>>,

/// Capture the original lifetime of the buffer.
_buffer: PhantomData<&'buf mut [u8]>,
_buffer: PhantomData<&'buf mut [MaybeUninit<u8>]>,
}

pub(super) trait AllowedMsgBufType {}
impl AllowedMsgBufType for u8 {}
impl AllowedMsgBufType for MaybeUninit<u8> {}

impl<'buf> Messages<'buf> {
/// Create a new iterator over messages from a byte buffer.
pub(super) fn new(buf: &'buf mut [u8]) -> Self {
pub(super) fn new(buf: &'buf mut [impl AllowedMsgBufType]) -> Self {
let msghdr = {
let mut h = msghdr::zero_msghdr();
h.msg_control = buf.as_mut_ptr().cast();
Expand Down
30 changes: 16 additions & 14 deletions tests/net/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use rustix::net::{
accept, bind_unix, connect_unix, listen, socket, AddressFamily, SocketAddrUnix, SocketType,
};
use rustix::path::DecInt;
use std::mem::MaybeUninit;
use std::path::Path;
use std::str::FromStr;
use std::sync::{Arc, Condvar, Mutex};
Expand Down Expand Up @@ -447,13 +448,13 @@ fn test_unix_msg_with_scm_rights() {
let mut pipe_end = None;

let mut buffer = [0; BUFFER_SIZE];
let mut cmsg_space = [0; rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(1))];

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
loop {
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.as_mut_slice());
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down Expand Up @@ -555,8 +556,8 @@ fn test_unix_msg_with_scm_rights() {
// Format the CMSG.
let we = [write_end.as_fd()];
let msg = SendAncillaryMessage::ScmRights(&we);
let mut space = [0; rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::from(space.as_mut_slice());
assert!(cmsg_buffer.push(msg));

connect_unix(&data_socket, &addr).unwrap();
Expand Down Expand Up @@ -615,8 +616,8 @@ fn test_unix_peercred_explicit() {

let ucred = sockopt::get_socket_peercred(&send_sock).unwrap();
let msg = SendAncillaryMessage::ScmCredentials(ucred);
let mut space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = SendAncillaryBuffer::from(space.as_mut_slice());
assert!(cmsg_buffer.push(msg));

sendmsg(
Expand All @@ -627,8 +628,8 @@ fn test_unix_peercred_explicit() {
)
.unwrap();

let mut cmsg_space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.as_mut_slice());

let mut buffer = [0; BUFFER_SIZE];
recvmsg(
Expand Down Expand Up @@ -685,8 +686,8 @@ fn test_unix_peercred_implicit() {
)
.unwrap();

let mut cmsg_space = [0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.as_mut_slice());

let mut buffer = [0; BUFFER_SIZE];
recvmsg(
Expand Down Expand Up @@ -737,13 +738,14 @@ fn test_unix_msg_with_combo() {
let mut yet_another_pipe_end = None;

let mut buffer = [0; BUFFER_SIZE];
let mut cmsg_space = [0; rustix::cmsg_space!(ScmRights(2), ScmRights(1))];
let mut cmsg_space =
[MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(2), ScmRights(1))];

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
loop {
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.as_mut_slice());
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down Expand Up @@ -859,8 +861,8 @@ fn test_unix_msg_with_combo() {

let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap();

let mut space = [0; rustix::cmsg_space!(ScmRights(2), ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = [MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(2), ScmRights(1))];
let mut cmsg_buffer = SendAncillaryBuffer::from(space.as_mut_slice());

// Format a CMSG.
let we = [write_end.as_fd(), another_write_end.as_fd()];
Expand Down
27 changes: 15 additions & 12 deletions tests/net/unix_alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,13 +445,14 @@ fn test_unix_msg_with_scm_rights() {
let mut pipe_end = None;

let mut buffer = vec![0; BUFFER_SIZE];
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(1))];
let mut cmsg_space = Vec::with_capacity(rustix::cmsg_space!(ScmRights(1)));

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
loop {
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_buffer =
RecvAncillaryBuffer::from(cmsg_space.spare_capacity_mut());
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down Expand Up @@ -553,8 +554,8 @@ fn test_unix_msg_with_scm_rights() {
// Format the CMSG.
let we = [write_end.as_fd()];
let msg = SendAncillaryMessage::ScmRights(&we);
let mut space = vec![0; msg.size()];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = Vec::with_capacity(msg.size());
let mut cmsg_buffer = SendAncillaryBuffer::from(space.spare_capacity_mut());
assert!(cmsg_buffer.push(msg));

connect_unix(&data_socket, &addr).unwrap();
Expand Down Expand Up @@ -618,8 +619,8 @@ fn test_unix_peercred() {
assert_eq!(ucred.gid, getgid());

let msg = SendAncillaryMessage::ScmCredentials(ucred);
let mut space = vec![0; msg.size()];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = Vec::with_capacity(msg.size());
let mut cmsg_buffer = SendAncillaryBuffer::from(space.spare_capacity_mut());
assert!(cmsg_buffer.push(msg));

sendmsg(
Expand All @@ -630,8 +631,8 @@ fn test_unix_peercred() {
)
.unwrap();

let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmCredentials(1))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_space = Vec::with_capacity(rustix::cmsg_space!(ScmCredentials(1)));
let mut cmsg_buffer = RecvAncillaryBuffer::from(cmsg_space.spare_capacity_mut());

let mut buffer = vec![0; BUFFER_SIZE];
recvmsg(
Expand Down Expand Up @@ -682,13 +683,15 @@ fn test_unix_msg_with_combo() {
let mut yet_another_pipe_end = None;

let mut buffer = vec![0; BUFFER_SIZE];
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(1), ScmRights(2))];
let mut cmsg_space =
Vec::with_capacity(rustix::cmsg_space!(ScmRights(1), ScmRights(2)));

'exit: loop {
let data_socket = accept(&connection_socket).unwrap();
let mut sum = 0;
loop {
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut cmsg_buffer =
RecvAncillaryBuffer::from(cmsg_space.spare_capacity_mut());
let nread = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
Expand Down Expand Up @@ -804,8 +807,8 @@ fn test_unix_msg_with_combo() {

let data_socket = socket(AddressFamily::UNIX, SocketType::SEQPACKET, None).unwrap();

let mut space = vec![0; rustix::cmsg_space!(ScmRights(1), ScmRights(2))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
let mut space = Vec::with_capacity(rustix::cmsg_space!(ScmRights(1), ScmRights(2)));
let mut cmsg_buffer = SendAncillaryBuffer::from(space.spare_capacity_mut());

// Format a CMSG.
let we = [write_end.as_fd(), another_write_end.as_fd()];
Expand Down
Loading