Skip to content

Commit

Permalink
Add SOCKS authentication fast path (#134)
Browse files Browse the repository at this point in the history
Add a fast path to SOCKS auto-authenticate when the server selects noneRequired as the SOCKS server authentication method.
  • Loading branch information
Davidde94 authored Jun 17, 2021
1 parent 548e0d4 commit b03d835
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 52 deletions.
31 changes: 14 additions & 17 deletions Sources/NIOSOCKS/Channel Handlers/SOCKSServerHandshakeHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,46 +71,43 @@ public final class SOCKSServerHandshakeHandler: ChannelDuplexHandler, RemovableC
public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
do {
let message = self.unwrapOutboundIn(data)
let outboundBuffer: ByteBuffer
switch message {
case .selectedAuthenticationMethod(let method):
try self.handleWriteSelectedAuthenticationMethod(method, context: context, promise: promise)
outboundBuffer = try self.handleWriteSelectedAuthenticationMethod(method, context: context)
case .response(let response):
try self.handleWriteResponse(response, context: context, promise: promise)
outboundBuffer = try self.handleWriteResponse(response, context: context)
case .authenticationData(let data, let complete):
try self.handleWriteAuthenticationData(data, complete: complete, context: context, promise: promise)
outboundBuffer = try self.handleWriteAuthenticationData(data, complete: complete, context: context)
}
context.write(self.wrapOutboundOut(outboundBuffer), promise: promise)

} catch {
context.fireErrorCaught(error)
promise?.fail(error)
}
}

private func handleWriteSelectedAuthenticationMethod(
_ method: SelectedAuthenticationMethod, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) throws {
_ method: SelectedAuthenticationMethod, context: ChannelHandlerContext) throws -> ByteBuffer {
try stateMachine.sendAuthenticationMethod(method)
var buffer = context.channel.allocator.buffer(capacity: 16)
buffer.writeMethodSelection(method)
context.write(self.wrapOutboundOut(buffer), promise: promise)
return buffer
}

private func handleWriteResponse(
_ response: SOCKSResponse, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) throws {
_ response: SOCKSResponse, context: ChannelHandlerContext) throws -> ByteBuffer {
try stateMachine.sendServerResponse(response)
var buffer = context.channel.allocator.buffer(capacity: 16)
buffer.writeServerResponse(response)
context.write(self.wrapOutboundOut(buffer), promise: promise)
return buffer
}

private func handleWriteAuthenticationData(_ data: ByteBuffer, complete: Bool, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) throws {
do {
try self.stateMachine.sendData()
if complete {
try self.stateMachine.authenticationComplete()
}
context.write(self.wrapOutboundOut(data), promise: promise)
} catch {
promise?.fail(error)
}
private func handleWriteAuthenticationData(
_ data: ByteBuffer, complete: Bool, context: ChannelHandlerContext) throws -> ByteBuffer {
try self.stateMachine.sendAuthenticationData(data, complete: complete)
return data
}

}
39 changes: 18 additions & 21 deletions Sources/NIOSOCKS/State/ServerStateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ enum ServerState: Hashable {
struct ServerStateMachine: Hashable {

private var state: ServerState
private var authenticationMethod: AuthenticationMethod?

var proxyEstablished: Bool {
switch self.state {
Expand Down Expand Up @@ -118,7 +119,7 @@ extension ServerStateMachine {
self.state = .waitingForClientGreeting
}

mutating func sendAuthenticationMethod(_ method: SelectedAuthenticationMethod) throws {
mutating func sendAuthenticationMethod(_ selected: SelectedAuthenticationMethod) throws {
switch self.state {
case .waitingToSendAuthenticationMethod:
()
Expand All @@ -131,7 +132,13 @@ extension ServerStateMachine {
.error:
throw SOCKSError.InvalidServerState()
}
self.state = .authenticating

self.authenticationMethod = selected.method
if selected.method == .noneRequired {
self.state = .waitingForClientRequest
} else {
self.state = .authenticating
}
}

mutating func sendServerResponse(_ response: SOCKSResponse) throws {
Expand All @@ -155,35 +162,25 @@ extension ServerStateMachine {
}
}

mutating func sendData() throws {
mutating func sendAuthenticationData(_ data: ByteBuffer, complete: Bool) throws {
switch self.state {
case .authenticating:
()
case .inactive,
.waitingForClientGreeting,
.waitingToSendAuthenticationMethod,
.waitingForClientRequest,
.waitingToSendResponse,
.active,
.error:
throw SOCKSError.InvalidServerState()
}
}

mutating func authenticationComplete() throws {
switch self.state {
case .authenticating:
()
break
case .waitingForClientRequest:
guard self.authenticationMethod == .noneRequired, complete, data.readableBytes == 0 else {
throw SOCKSError.InvalidServerState()
}
case .inactive,
.waitingForClientGreeting,
.waitingToSendAuthenticationMethod,
.waitingForClientRequest,
.waitingToSendResponse,
.active,
.error:
throw SOCKSError.InvalidServerState()
}

self.state = .waitingForClientRequest
if complete {
self.state = .waitingForClientRequest
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ extension SOCKSServerHandlerTests {
("testOutboundErrorsAreHandled", testOutboundErrorsAreHandled),
("testFlushOnHandlerRemoved", testFlushOnHandlerRemoved),
("testForceHandlerRemovalAfterAuth", testForceHandlerRemovalAfterAuth),
("testAutoAuthenticationComplete", testAutoAuthenticationComplete),
("testAutoAuthenticationCompleteWithManualCompletion", testAutoAuthenticationCompleteWithManualCompletion),
("testEagerClientRequestBeforeAuthenticationComplete", testEagerClientRequestBeforeAuthenticationComplete),
("testManualAuthenticationFailureExtraBytes", testManualAuthenticationFailureExtraBytes),
("testManualAuthenticationFailureInvalidCompletion", testManualAuthenticationFailureInvalidCompletion),
]
}
}
Expand Down
104 changes: 94 additions & 10 deletions Tests/NIOSOCKSTests/SOCKSServerHandshakeHandler+Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class SOCKSServerHandlerTests: XCTestCase {

// tests dripfeeding to ensure we buffer data correctly
func testTypicalWorkflowDripfeed() {
let expectedGreeting = ClientGreeting(methods: [.noneRequired])
let expectedGreeting = ClientGreeting(methods: [.gssapi])
let expectedRequest = SOCKSRequest(command: .connect, addressType: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))
let expectedData = ByteBuffer(string: "1234")
let testHandler = PromiseTestHandler(
Expand All @@ -168,16 +168,15 @@ class SOCKSServerHandlerTests: XCTestCase {
self.assertOutputBuffer([])
self.writeInbound([0x01])
self.assertOutputBuffer([])
self.writeInbound([0x00])
self.writeInbound([0x01])
self.assertOutputBuffer([])
XCTAssertTrue(testHandler.hadGreeting)

// write the auth selection
XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired))))
self.assertOutputBuffer([0x05, 0x00])
XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.selectedAuthenticationMethod(.init(method: .gssapi))))
self.assertOutputBuffer([0x05, 0x01])

// finish authentication - nothing should be written
// as this is informing the state machine only
// finish authentication with some bytes
XCTAssertNoThrow(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0xFF, 0xFF]), complete: true)))
self.assertOutputBuffer([0xFF, 0xFF])

Expand Down Expand Up @@ -217,10 +216,11 @@ class SOCKSServerHandlerTests: XCTestCase {
func testForceHandlerRemovalAfterAuth() {

// go through auth
self.writeInbound([0x05, 0x01, 0x00])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
self.assertOutputBuffer([0x05, 0x00])
XCTAssertNoThrow(try self.handler.stateMachine.authenticationComplete())
self.writeInbound([0x05, 0x01, 0x01])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi)))
self.assertOutputBuffer([0x05, 0x01])
self.writeOutbound(.authenticationData(ByteBuffer(), complete: true))
self.assertOutputBuffer([])
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
Expand All @@ -229,4 +229,88 @@ class SOCKSServerHandlerTests: XCTestCase {
// removing the handler, it should fail
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(string: "hello, world!"), complete: false)))
}

func testAutoAuthenticationComplete() {

// server selects none-required, this should mean we can continue without
// having to manually inform the state machine
self.writeInbound([0x05, 0x01, 0x00])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
self.assertOutputBuffer([0x05, 0x00])

// if we try and write the request then the data would be read
// as authentication data, and so the server wouldn't reply
// with a response
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
}

func testAutoAuthenticationCompleteWithManualCompletion() {

// server selects none-required, this should mean we can continue without
// having to manually inform the state machine. However, informing the state
// machine manually shouldn't break anything.
self.writeInbound([0x05, 0x01, 0x00])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
self.assertOutputBuffer([0x05, 0x00])

// complete authentication, but nothing should be written
// to the network
self.writeOutbound(.authenticationData(ByteBuffer(), complete: true))
self.assertOutputBuffer([])

// if we try and write the request then the data would be read
// as authentication data, and so the server wouldn't reply
// with a response
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
self.writeOutbound(.response(.init(reply: .succeeded, boundAddress: .address(try! .init(ipAddress: "127.0.0.1", port: 80)))))
self.assertOutputBuffer([0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
}

func testEagerClientRequestBeforeAuthenticationComplete() {

// server selects none-required, this should mean we can continue without
// having to manually inform the state machine. However, informing the state
// machine manually shouldn't break anything.
self.writeInbound([0x05, 0x01, 0x01])
self.assertInbound(.greeting(.init(methods: [.gssapi])))
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .gssapi)))
self.assertOutputBuffer([0x05, 0x01])

// at this point authentication isn't complete
// so if the client sends a request then the
// server will read those as authentication bytes
self.writeInbound([0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
self.assertInbound(.authenticationData(ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])))
}

func testManualAuthenticationFailureExtraBytes() {
// server selects none-required, this should mean we can continue without
// having to manually inform the state machine. However, informing the state
// machine manually shouldn't break anything.
self.writeInbound([0x05, 0x01, 0x00])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
self.assertOutputBuffer([0x05, 0x00])

// invalid authentication completion
// we've selected `noneRequired`, so no
// bytes should be written
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: [0x00]), complete: true)))
}

func testManualAuthenticationFailureInvalidCompletion() {
// server selects none-required, this should mean we can continue without
// having to manually inform the state machine. However, informing the state
// machine manually shouldn't break anything.
self.writeInbound([0x05, 0x01, 0x00])
self.writeOutbound(.selectedAuthenticationMethod(.init(method: .noneRequired)))
self.assertOutputBuffer([0x05, 0x00])

// invalid authentication completion
// authentication should have already completed
// as we selected `noneRequired`, so sending
// `complete = false` should be an error
XCTAssertThrowsError(try self.channel.writeOutbound(ServerMessage.authenticationData(ByteBuffer(bytes: []), complete: false)))
}
}
4 changes: 0 additions & 4 deletions Tests/NIOSOCKSTests/ServerStateMachine+Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ public class ServerStateMachineTests: XCTestCase {
XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired)))
XCTAssertFalse(stateMachine.proxyEstablished)

// authentication is now finished, as we didn't send any
XCTAssertNoThrow(try stateMachine.authenticationComplete())

// send the client request
var request = ByteBuffer(bytes: [0x05, 0x01, 0x00, 0x01, 127, 0, 0, 1, 0, 80])
XCTAssertNoThrow(try stateMachine.receiveBuffer(&request))
Expand All @@ -61,7 +58,6 @@ public class ServerStateMachineTests: XCTestCase {
XCTAssertNoThrow(try stateMachine.connectionEstablished())
XCTAssertNoThrow(try stateMachine.receiveBuffer(&greeting))
XCTAssertNoThrow(try stateMachine.sendAuthenticationMethod(.init(method: .noneRequired)))
XCTAssertNoThrow(try stateMachine.authenticationComplete())

// write some invalid bytes from the client
// the state machine should throw
Expand Down

0 comments on commit b03d835

Please sign in to comment.