diff --git a/Cargo.toml b/Cargo.toml index f4d4372..698dcc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fast-socks5" -version = "0.4.3" +version = "0.5.0" authors = ["Jonathan Dizdarevic "] edition = "2018" license = "MIT" @@ -8,16 +8,17 @@ description = "Fast SOCKS5 client/server implementation written in Rust async/.a repository = "https://github.com/dizda/fast-socks5" [dependencies] -futures = "0.3.8" log = "0.4" -async-std = { version = "1.9.0", features = ["std"] } +tokio = { version = "1.12.0", features = ["io-util", "net", "time"] } anyhow = "1.0" thiserror = "1.0" +tokio-stream = "0.1.7" # Dependencies for examples/ [dev-dependencies] env_logger = "0.7" structopt = "0.3" +tokio = { version = "1.12.0", features = ["io-util", "net", "time", "rt-multi-thread", "macros"] } [[example]] name = "server" diff --git a/LICENSE b/LICENSE index 27839de..943eda1 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2020 Jonathan Dizdarevic +Copyright (c) 2021 Jonathan Dizdarevic Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 8ec6d85..dc4602d 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ - An `async`/`.await` [SOCKS5](https://tools.ietf.org/html/rfc1928) implementation. - No **unsafe** code -- Built on-top of `async-std` library +- Built on-top of `tokio` library - Ultra lightweight and scalable - No system dependencies - Cross-platform @@ -22,11 +22,11 @@ - `AsyncRead + AsyncWrite` traits are implemented on Socks5Stream & Socks5Socket - `IPv4`, `IPv6`, and `Domains` types are supported - Config helper for Socks5Server -- Helpers to run a Socks5Server à la *"async-std's TcpStream"* via `incoming.next().await` +- Helpers to run a Socks5Server à la *"std's TcpStream"* via `incoming.next().await` - Examples come with real cases commands scenarios - Can disable `DNS resolving` -- Can skip the authentication/handshake process, which will directly handle command's request (useful to save useless round-trips in an already authenticated environment) -- Can disable command execution (useful if you just want to forward the request to an another server) +- Can skip the authentication/handshake process, which will directly handle command's request (useful to save useless round-trips in a current authenticated environment) +- Can disable command execution (useful if you just want to forward the request to a different server) ## Install diff --git a/examples/client.rs b/examples/client.rs index 423957a..c736b1c 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -3,11 +3,10 @@ extern crate log; use anyhow::Context; -use async_std::task; use fast_socks5::client::Config; use fast_socks5::{client::Socks5Stream, Result}; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use structopt::StructOpt; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; /// # How to use it: /// @@ -46,10 +45,11 @@ struct Opt { pub skip_auth: bool, } -fn main() -> Result<()> { +#[tokio::main] +async fn main() -> Result<()> { env_logger::init(); - task::block_on(spawn_socks_client()) + spawn_socks_client().await } async fn spawn_socks_client() -> Result<()> { @@ -112,7 +112,10 @@ async fn http_request( .context("Can't read HTTP Response")?; info!("Response: {}", String::from_utf8_lossy(&result)); - assert!(result.starts_with(b"HTTP/1.1")); + + if result.starts_with(b"HTTP/1.1") { + info!("HTTP/1.1 Response detected!"); + } //assert!(result.ends_with(b"\r\n") || result.ends_with(b"")); Ok(()) diff --git a/examples/server.rs b/examples/server.rs index 0e163a1..6ac292b 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -2,13 +2,15 @@ #[macro_use] extern crate log; -use async_std::{future::Future, stream::StreamExt, task}; use fast_socks5::{ server::{Config, SimpleUserPassword, Socks5Server, Socks5Socket}, Result, SocksError, }; -use futures::{AsyncRead, AsyncWrite}; +use std::future::Future; use structopt::StructOpt; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::task; +use tokio_stream::StreamExt; /// # How to use it: /// @@ -61,10 +63,11 @@ enum AuthMode { /// /// TODO: Write functional tests: https://github.com/ark0f/async-socks5/blob/master/src/lib.rs#L762 /// TODO: Write functional tests with cURL? -fn main() -> Result<()> { +#[tokio::main] +async fn main() -> Result<()> { env_logger::init(); - task::block_on(spawn_socks_server()) + spawn_socks_server().await } async fn spawn_socks_server() -> Result<()> { @@ -79,7 +82,7 @@ async fn spawn_socks_server() -> Result<()> { if opt.skip_auth { return Err(SocksError::ArgumentInputError( "Can't use skip-auth flag and authentication altogether.", - ))?; + )); } config.set_authentication(SimpleUserPassword { username, password }); diff --git a/examples/simple_tcp_server.rs b/examples/simple_tcp_server.rs index 627b241..ac8ae1e 100644 --- a/examples/simple_tcp_server.rs +++ b/examples/simple_tcp_server.rs @@ -2,15 +2,18 @@ #[macro_use] extern crate log; -use async_std::net::TcpListener; -use async_std::sync::Arc; -use async_std::{future::Future, stream::StreamExt, task}; use fast_socks5::{ - server::{Config, SimpleUserPassword, Socks5Server, Socks5Socket}, + server::{Config, SimpleUserPassword, Socks5Socket}, Result, }; -use futures::{AsyncRead, AsyncWrite}; +use std::future::Future; +use std::sync::Arc; use structopt::StructOpt; +use tokio::task; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpListener, +}; /// # How to use it: /// @@ -61,10 +64,11 @@ enum AuthMode { /// TODO: Write functional tests: https://github.com/ark0f/async-socks5/blob/master/src/lib.rs#L762 /// TODO: Write functional tests with cURL? /// TODO: Move this to as a standalone library -fn main() -> Result<()> { +#[tokio::main] +async fn main() -> Result<()> { env_logger::init(); - task::block_on(spawn_socks_server()) + spawn_socks_server().await } async fn spawn_socks_server() -> Result<()> { @@ -82,30 +86,21 @@ async fn spawn_socks_server() -> Result<()> { let config = Arc::new(config); - let mut listener = TcpListener::bind(&opt.listen_addr).await?; + let listener = TcpListener::bind(&opt.listen_addr).await?; // listener.set_config(config); - let mut incoming = listener.incoming(); - info!("Listen for socks connections @ {}", &opt.listen_addr); // Standard TCP loop - while let Some(socket_res) = incoming.next().await { - match socket_res { - Ok(socket) => { + loop { + match listener.accept().await { + Ok((socket, _addr)) => { info!("Connection from {}", socket.peer_addr()?); let socket = Socks5Socket::new(socket, config.clone()); - // socket.upgrade_to_socks5().await; spawn_and_log_error(socket.upgrade_to_socks5()); - // match socket.upgrade_to_socks5().await { - // Ok(_) => {} - // Err(e) => error!("{:#}", &e), - // } - } - Err(err) => { - error!("accept error = {:?}", err); } + Err(err) => error!("accept error = {:?}", err), } } diff --git a/src/client.rs b/src/client.rs index 55be547..7508c69 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,10 +3,12 @@ use crate::read_exact; use crate::util::target_addr::{read_address, TargetAddr, ToTargetAddr}; use crate::{consts, AuthenticationMethod, ReplyError, Result, SocksError}; use anyhow::Context; -use async_std::net::{SocketAddr, TcpStream, ToSocketAddrs}; -use futures::{task::Poll, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use std::io; +use std::net::SocketAddr; use std::pin::Pin; +use std::task::Poll; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpStream, ToSocketAddrs}; const MAX_ADDR_LEN: usize = 260; @@ -65,7 +67,7 @@ where } // Handshake Lifecycle - if stream.config.skip_auth == false { + if !stream.config.skip_auth { let methods = stream.send_version_and_methods(methods).await?; stream.which_method_accepted(methods).await?; } else { @@ -171,7 +173,7 @@ where .await .context("Can't write that the methods are unsupported.")?; - return Err(SocksError::AuthMethodUnacceptable(vec![method]))?; + return Err(SocksError::AuthMethodUnacceptable(vec![method])); } } @@ -287,7 +289,7 @@ where TargetAddr::Domain(ref domain, port) => { debug!("TargetAddr::Domain"); if domain.len() > u8::max_value() as usize { - return Err(SocksError::ExceededMaxDomainLen(domain.len()))?; + return Err(SocksError::ExceededMaxDomainLen(domain.len())); } padding = 5 + domain.len() + 2; @@ -337,7 +339,7 @@ where } if reply != consts::SOCKS5_REPLY_SUCCEEDED { - return Err(ReplyError::from_u8(reply))?; // Convert reply received into correct error + return Err(ReplyError::from_u8(reply).into()); // Convert reply received into correct error } let address = read_address(&mut self.socket, address_type).await?; @@ -427,8 +429,8 @@ where fn poll_read( mut self: Pin<&mut Self>, context: &mut std::task::Context, - buf: &mut [u8], - ) -> Poll> { + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut self.socket).poll_read(context, buf) } } @@ -453,10 +455,10 @@ where Pin::new(&mut self.socket).poll_flush(context) } - fn poll_close( + fn poll_shutdown( mut self: Pin<&mut Self>, context: &mut std::task::Context, ) -> Poll> { - Pin::new(&mut self.socket).poll_close(context) + Pin::new(&mut self.socket).poll_shutdown(context) } } diff --git a/src/lib.rs b/src/lib.rs index f202800..a7a7db6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,8 +91,6 @@ impl fmt::Display for AuthenticationMethod { pub enum SocksError { #[error("i/o error: {0}")] Io(#[from] io::Error), - #[error("request timeout: {0}")] - FutureTimeout(#[from] async_std::future::TimeoutError), #[error("the data for key `{0}` is not available")] Redaction(String), #[error("invalid header (expected {expected:?}, found {found:?})")] diff --git a/src/server.rs b/src/server.rs index adb0c01..5913f93 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,22 +1,19 @@ use crate::read_exact; +use crate::ready; use crate::util::target_addr::{read_address, TargetAddr}; use crate::{consts, AuthenticationMethod, ReplyError, Result, SocksError}; use anyhow::Context; -use async_std::{ - future, - net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs as AsyncToSocketAddrs}, - sync::Arc, - task::ready, - task::{Context as AsyncContext, Poll}, -}; -use futures::{ - future::{Either, Future}, - stream::Stream, - AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, -}; +use std::future::Future; use std::io; -use std::net::ToSocketAddrs as StdToSocketAddrs; +use std::net::{SocketAddr, ToSocketAddrs as StdToSocketAddrs}; use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context as AsyncContext, Poll}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream, ToSocketAddrs as AsyncToSocketAddrs}; +use tokio::time::timeout; +use tokio_stream::{Stream, StreamExt}; #[derive(Clone)] pub struct Config { @@ -187,7 +184,7 @@ impl Socks5Socket { trace!("upgrading to socks5..."); // Handshake - if self.config.skip_auth == false { + if !self.config.skip_auth { let methods = self.get_methods().await?; self.can_accept_method(methods).await?; @@ -208,7 +205,7 @@ impl Socks5Socket { Err(SocksError::ReplyError(e)) => { // If a reply error has been returned, we send it to the client self.reply(&e).await?; - Err(e)? // propagate the error to end this connection's task + return Err(e.into()); // propagate the error to end this connection's task } // if any other errors has been detected, we simply end connection's task Err(d) => return Err(d), @@ -445,7 +442,7 @@ impl Socks5Socket { } if cmd != consts::SOCKS5_CMD_TCP_CONNECT { - return Err(ReplyError::CommandNotSupported)?; + return Err(ReplyError::CommandNotSupported.into()); } // Guess address type @@ -493,26 +490,32 @@ impl Socks5Socket { .next() .context("unreachable")?; + let fut = TcpStream::connect(addr); + let limit = Duration::from_secs(self.config.request_timeout); + // TCP connect with timeout, to avoid memory leak for connection that takes forever - let outbound = match future::timeout( - std::time::Duration::from_secs(self.config.request_timeout), - TcpStream::connect(addr), - ) - .await - { + let outbound = match timeout(limit, fut).await { Ok(e) => match e { Ok(o) => o, Err(e) => match e.kind() { // Match other TCP errors with ReplyError - io::ErrorKind::ConnectionRefused => Err(ReplyError::ConnectionRefused)?, - io::ErrorKind::ConnectionAborted => Err(ReplyError::ConnectionNotAllowed)?, - io::ErrorKind::ConnectionReset => Err(ReplyError::ConnectionNotAllowed)?, - io::ErrorKind::NotConnected => Err(ReplyError::NetworkUnreachable)?, - _ => Err(e)?, // #[error("General failure")] ? + io::ErrorKind::ConnectionRefused => { + return Err(ReplyError::ConnectionRefused.into()) + } + io::ErrorKind::ConnectionAborted => { + return Err(ReplyError::ConnectionNotAllowed.into()) + } + io::ErrorKind::ConnectionReset => { + return Err(ReplyError::ConnectionNotAllowed.into()) + } + io::ErrorKind::NotConnected => { + return Err(ReplyError::NetworkUnreachable.into()) + } + _ => return Err(e.into()), // #[error("General failure")] ? }, }, // Wrap timeout error in a proper ReplyError - Err(_) => Err(ReplyError::TtlExpired)?, + Err(_) => return Err(ReplyError::TtlExpired.into()), }; debug!("Connected to remote destination"); @@ -552,38 +555,14 @@ impl Socks5Socket { /// Copy data between two peers /// Using 2 different generators, because they could be different structs with same traits. -async fn transfer(mut inbound: I, outbound: O) -> Result<()> +async fn transfer(mut inbound: I, mut outbound: O) -> Result<()> where I: AsyncRead + AsyncWrite + Unpin, O: AsyncRead + AsyncWrite + Unpin, { - //TODO: use TcpStream.clone() https://github.com/async-rs/async-std/pull/689/files#diff-633608b66cafdfb86435918f3a48bea5R17 - - // let (mut ri, mut wi) = (&inbound, &inbound); - let (mut ri, mut wi) = futures::io::AsyncReadExt::split(&mut inbound); - // let (mut ro, mut wo) = (&outbound, &outbound); - let (mut ro, mut wo) = futures::io::AsyncReadExt::split(outbound); - - // Exchange data - // For some reasons, futures::future::select does not work with async_std::io::copy() 🤔 - let inbound_to_outbound = futures::io::copy(&mut ri, &mut wo); - let outbound_to_inbound = futures::io::copy(&mut ro, &mut wi); - - // I've chosen `select` over `join` because the inbound (client) is more likely to leave the connection open for a while, - // while it's not necessarily as the other part (outbound, aka remote server) has closed the communication. - match futures::future::select(inbound_to_outbound, outbound_to_inbound).await { - Either::Left((Ok(data), _)) => { - info!("local closed -> remote target ({} bytes consumed)", data) - } - Either::Left((Err(err), _)) => { - error!("local closed -> remote target with error {:?}", err,) - } - Either::Right((Ok(data), _)) => { - info!("local <- remote target closed ({} bytes consumed)", data) - } - Either::Right((Err(err), _)) => { - error!("local <- remote target closed with error {:?}", err,) - } + match tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await { + Ok(res) => info!("transfer closed ({}, {})", res.0, res.1), + Err(err) => error!("transfer error: {:?}", err), }; Ok(()) @@ -597,8 +576,8 @@ where fn poll_read( mut self: Pin<&mut Self>, context: &mut std::task::Context, - buf: &mut [u8], - ) -> Poll> { + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut self.inner).poll_read(context, buf) } } @@ -623,11 +602,11 @@ where Pin::new(&mut self.inner).poll_flush(context) } - fn poll_close( + fn poll_shutdown( mut self: Pin<&mut Self>, context: &mut std::task::Context, ) -> Poll> { - Pin::new(&mut self.inner).poll_close(context) + Pin::new(&mut self.inner).poll_shutdown(context) } } @@ -639,7 +618,7 @@ mod test { fn test_bind() { //dza async { - let server = Socks5Server::bind("127.0.0.1:1080").await.unwrap(); + let _server = Socks5Server::bind("127.0.0.1:1080").await.unwrap(); }; } } diff --git a/src/util/stream.rs b/src/util/stream.rs index 2280792..e5146ea 100644 --- a/src/util/stream.rs +++ b/src/util/stream.rs @@ -28,3 +28,13 @@ macro_rules! read_exact { $stream.read_exact(&mut x).await.map(|_| x) }}; } + +#[macro_export] +macro_rules! ready { + ($e:expr $(,)?) => { + match $e { + std::task::Poll::Ready(t) => t, + std::task::Poll::Pending => return std::task::Poll::Pending, + } + }; +} diff --git a/src/util/target_addr.rs b/src/util/target_addr.rs index 5a1805b..67afd93 100644 --- a/src/util/target_addr.rs +++ b/src/util/target_addr.rs @@ -1,12 +1,13 @@ use crate::consts; use crate::read_exact; use anyhow::Context; -use async_std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; -use futures::{AsyncRead, AsyncReadExt}; use std::fmt; use std::io; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::vec::IntoIter; use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::net::lookup_host; /// SOCKS5 reply code #[derive(Error, Debug)] @@ -49,8 +50,8 @@ impl TargetAddr { TargetAddr::Ip(ip) => Ok(TargetAddr::Ip(ip)), TargetAddr::Domain(domain, port) => { debug!("Attempt to DNS resolve the domain {}...", &domain); - let socket_addr = (&domain[..], port) - .to_socket_addrs() + + let socket_addr = lookup_host((&domain[..], port)) .await .context(AddrError::DNSResolutionFailed)? .next() @@ -181,7 +182,7 @@ pub async fn read_address( Addr::Domain(domain) } - _ => return Err(anyhow::anyhow!(AddrError::IncorrectAddressType))?, + _ => return Err(anyhow::anyhow!(AddrError::IncorrectAddressType)), }; // Find port number