Task-based Socket Extension Methods (C# TPL)
Introduction
The Task-based Asynchronous Pattern (TAP) is the recommended way to write asynchronous code for .NET applications. As I explained in my last post, TAP is intended to replace the Asynchronous Programming Model (APM) and the Event-based Asynchronous Pattern (EAP); however, many classes in the .NET Framework still use these older patterns. Fortunately, these can be turned into TAP-style "awaitable" methods with relative ease. By doing so, you reap the benefits that come from working with Task (and Task<T>) objects. In this post, I will convert a set of APM-style methods from the System.Net.Sockets namespace to TAP methods and provide an end-to-end example of how to use them in a generic TCP socket server.
All of the code examples in this post can be downloaded as individual .cs files at my gist link below:
I'll go through the various Socket methods in the order they would occur for a simple scenario: connect to a remote host and send a short text message that is received and read by the server over TCP. This code also uses a Result object for all return values. The code for the Result class will also be covered.
Connect
We begin with the assumption that a remote end point exists, defined by an IP address and port number, that is bound and listening for incoming connections with a TCP socket. We will refer to this side of our example as the server. The other side, which attempts to connect to this end point by calling ConnectWithTimeoutAsync, will be referred to as the client. ConnectWithTimeoutAsync is an extension method that wraps the BeginConnect/EndConnect methods of the Socket class and returns a Task<Result> object:
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;
    public static partial class TplSocketExtensions
    {
        public static async Task<Result> ConnectWithTimeoutAsync(
            this Socket socket,
            string remoteIpAddress,
            int port,
            int timeoutMs)
        {
            try
            {
                var connectTask = Task.Factory.FromAsync(
                    socket.BeginConnect,
                    socket.EndConnect,
                    remoteIpAddress,
                    port,
                    null);
                if (connectTask == await Task.WhenAny(connectTask, Task.Delay(timeoutMs)).ConfigureAwait(false))
                {
                    await connectTask.ConfigureAwait(false);
                }
                else
                {
                    throw new TimeoutException();
                }
            }
            catch (SocketException ex)
            {
                return Result.Fail($"{ex.Message} ({ex.GetType()})");
            }
            catch (TimeoutException ex)
            {
                return Result.Fail($"{ex.Message} ({ex.GetType()})");
            }
            return Result.Ok();
        }
    }
}
Note that this method requires a timeoutMs value. The timeout behavior is achieved through the use of Task.WhenAny. Since the same timeout behavior has been applied to all of the socket methods (where appropriate), let's take a closer look at how this is done:
if (connectTask == await Task.WhenAny(connectTask, Task.Delay(timeoutMs)).ConfigureAwait(false))
{
    await connectTask.ConfigureAwait(false);
}
else
{
    throw new TimeoutException();
}
We pass two tasks into Task.WhenAny: connectTask and Task.Delay(timeoutMs). Whichever of these tasks completes first is returned. Let's say timeoutMs = 5000, which is a timeout value of five seconds. If the task returned from Task.WhenAny is connectTask, the connection attempt finished within the timeout and we await it, which either completes normally or rethrows the SocketException that caused the attempt to fail. However, if connectTask is not returned, our socket was unable to connect to the end point within five seconds and a TimeoutException is thrown. Either exception is handled by calling Result.Fail(string error), a static method which returns an instance of Result with Success = false and an Error string containing the Message property and type name of the exception.
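Because the same Task.WhenAny check recurs in every timeout-aware method, it can help to see it in isolation. The following is a minimal sketch and not part of the original extension methods: TimeoutSketch and WithTimeoutAsync are hypothetical names, and the two tasks created in Main are stand-ins for a real socket operation.

```csharp
using System;
using System.Threading.Tasks;

public static class TimeoutSketch
{
    // Hypothetical helper (not in the original code): awaits `task`, or
    // throws TimeoutException if `timeoutMs` elapses before it completes.
    public static async Task<T> WithTimeoutAsync<T>(this Task<T> task, int timeoutMs)
    {
        var completed = await Task.WhenAny(task, Task.Delay(timeoutMs)).ConfigureAwait(false);
        if (completed != task)
        {
            throw new TimeoutException();
        }

        // The task has already finished, so awaiting it returns its result
        // or rethrows the exception it faulted with.
        return await task.ConfigureAwait(false);
    }

    public static async Task Main()
    {
        // Stand-in for an operation that finishes well inside the timeout.
        var fast = Task.Delay(50).ContinueWith(_ => 42);
        Console.WriteLine(await fast.WithTimeoutAsync(2000));

        // Stand-in for an operation that does not finish inside the timeout.
        var slow = Task.Delay(5000).ContinueWith(_ => 42);
        try
        {
            await slow.WithTimeoutAsync(100);
        }
        catch (TimeoutException)
        {
            Console.WriteLine("timed out");
        }
    }
}
```

Keep in mind that this pattern only stops waiting. Task.WhenAny does not cancel the task that lost the race, so the underlying operation is still pending when the TimeoutException is thrown.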
Result Class, Explained
The Result<T> class encapsulates all information pertaining to the outcome of a task, conveniently providing an error message in case the task failed and an object instance in case the task succeeded.
I was introduced to the Result class by this blog post by Vladimir Khorikov, part of a series that applies functional programming techniques to C#. He has made the code available on GitHub, but I am still using the basic version below:
namespace AaronLuna.Common.Result
{
    public class Result
    {
        protected Result(bool success, string error)
        {
            Success = success;
            Error = error;
        }
        public bool Success { get; }
        public string Error { get; }
        public bool Failure => !Success;
        public static Result Fail(string message)
        {
            return new Result(false, message);
        }
        public static Result<T> Fail<T>(string message)
        {
            return new Result<T>(default(T), false, message);
        }
        public static Result Ok()
        {
            return new Result(true, string.Empty);
        }
        public static Result<T> Ok<T>(T value)
        {
            return new Result<T>(value, true, string.Empty);
        }
        public static Result Combine(params Result[] results)
        {
            foreach (Result result in results)
            {
                if (result.Failure)
                {
                    return result;
                }
            }
            return Ok();
        }
    }
    public class Result<T> : Result
    {
        protected internal Result(T value, bool success, string error)
            : base(success, error)
        {
            Value = value;
        }
        public T Value { get; }
    }
}
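To make the shape of the API concrete, here is a short usage sketch. It assumes the Result and Result<T> classes listed above are compiled into the same project; ResultUsageSketch and ParsePort are hypothetical names invented purely for illustration.

```csharp
using System;
using AaronLuna.Common.Result;

public static class ResultUsageSketch
{
    // Hypothetical helper, used only to produce a Result<int> to inspect.
    static Result<int> ParsePort(string text)
    {
        return int.TryParse(text, out var port) && port > 0 && port <= 65535
            ? Result.Ok(port)
            : Result.Fail<int>($"'{text}' is not a valid port number");
    }

    public static void Main()
    {
        var good = ParsePort("7003");
        var bad = ParsePort("not-a-port");

        // A successful result carries the value and an empty error string.
        Console.WriteLine($"{good.Success} {good.Value}");

        // A failed result carries the error message and default(T) as Value.
        Console.WriteLine($"{bad.Failure} {bad.Error}");

        // Combine returns the first failed Result, or Ok() if none failed.
        var combined = Result.Combine(good, bad);
        Console.WriteLine(combined.Success ? "all ok" : combined.Error);
    }
}
```

Since Result<T> derives from Result, typed and untyped results can be passed to Combine together, which is how the client/server example later in this post checks two operations at once.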
Accept
OK, back to our socket example. On the other side of our connectTask, the server is waiting for incoming connections inside the AcceptAsync method, which wraps the BeginAccept/EndAccept method pair:
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;
    public static partial class TplSocketExtensions
    {
        public static async Task<Result<Socket>> AcceptAsync(this Socket socket)
        {
            Socket transferSocket;
            try
            {
                var acceptTask = Task<Socket>.Factory.FromAsync(socket.BeginAccept, socket.EndAccept, null);
                transferSocket = await acceptTask.ConfigureAwait(false);
            }
            catch (SocketException ex)
            {
                return Result.Fail<Socket>($"{ex.Message} ({ex.GetType()})");
            }
            catch (InvalidOperationException ex)
            {
                return Result.Fail<Socket>($"{ex.Message} ({ex.GetType()})");
            }
            return Result.Ok(transferSocket);
        }
    }
}
Receive
After ConnectWithTimeoutAsync and AcceptAsync have successfully run to completion, the server will have a new Socket instance, which is named transferSocket in the code example above. The server will use this socket to receive data from the client. I created two different receive methods: ReceiveWithTimeoutAsync, which has the same timeout behavior as ConnectWithTimeoutAsync, and ReceiveAsync, which has no timeout:
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;
    public static partial class TplSocketExtensions
    {
        public static async Task<Result<int>> ReceiveWithTimeoutAsync(
            this Socket socket,
            byte[] buffer,
            int offset,
            int size,
            SocketFlags socketFlags,
            int timeoutMs)
        {
            int bytesReceived;
            try
            {
                var asyncResult = socket.BeginReceive(buffer, offset, size, socketFlags, null, null);
                var receiveTask = Task<int>.Factory.FromAsync(asyncResult, _ => socket.EndReceive(asyncResult));
                if (receiveTask == await Task.WhenAny(receiveTask, Task.Delay(timeoutMs)).ConfigureAwait(false))
                {
                    bytesReceived = await receiveTask.ConfigureAwait(false);
                }
                else
                {
                    throw new TimeoutException();
                }
            }
            catch (SocketException ex)
            {
                return Result.Fail<int>($"{ex.Message} ({ex.GetType()})");
            }
            catch (TimeoutException ex)
            {
                return Result.Fail<int>($"{ex.Message} ({ex.GetType()})");
            }
            return Result.Ok(bytesReceived);
        }
        public static async Task<Result<int>> ReceiveAsync(
            this Socket socket,
            byte[] buffer,
            int offset,
            int size,
            SocketFlags socketFlags)
        {
            int bytesReceived;
            try
            {
                var asyncResult = socket.BeginReceive(buffer, offset, size, socketFlags, null, null);
                // ConfigureAwait(false) added for consistency with the other methods
                bytesReceived = await Task<int>.Factory.FromAsync(asyncResult, _ => socket.EndReceive(asyncResult)).ConfigureAwait(false);
            }
            catch (SocketException ex)
            {
                return Result.Fail<int>($"{ex.Message} ({ex.GetType()})");
            }
            return Result.Ok(bytesReceived);
        }
    }
}
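One thing to keep in mind when calling these methods: TCP is a byte stream, so a single receive call may return fewer bytes than the sender wrote, and a return value of zero means the remote side closed the connection. The sketch below shows one way to keep receiving until a known number of bytes has arrived. It assumes the extension methods above are compiled into the same project and that the caller already knows the message length; ReceiveLoopSketch, ReceiveExactlyAsync and expectedBytes are hypothetical names, not part of the original code.

```csharp
namespace AaronLuna.TplSockets
{
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;

    public static class ReceiveLoopSketch
    {
        // Hypothetical helper: calls ReceiveWithTimeoutAsync repeatedly until
        // exactly `expectedBytes` bytes have been read into a new buffer.
        public static async Task<Result<byte[]>> ReceiveExactlyAsync(
            this Socket socket,
            int expectedBytes,
            int timeoutMs)
        {
            var buffer = new byte[expectedBytes];
            var totalReceived = 0;

            while (totalReceived < expectedBytes)
            {
                // Each call fills the part of the buffer that is still empty.
                var result = await socket.ReceiveWithTimeoutAsync(
                    buffer,
                    totalReceived,
                    expectedBytes - totalReceived,
                    SocketFlags.None,
                    timeoutMs).ConfigureAwait(false);

                if (result.Failure)
                {
                    // Pass the underlying socket or timeout error through.
                    return Result.Fail<byte[]>(result.Error);
                }

                if (result.Value == 0)
                {
                    return Result.Fail<byte[]>("Connection closed before all data was received");
                }

                totalReceived += result.Value;
            }

            return Result.Ok(buffer);
        }
    }
}
```

In this sketch the timeout applies to each individual receive call rather than to the loop as a whole, which is a design choice you may want to revisit for long messages.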
Send
On the client side, I created an equivalent SendWithTimeoutAsync method, but did not create a version without the timeout. My reasoning is that a Send operation is an active process while a Receive operation is a passive one. There are situations where a socket may be waiting for a Send operation on the other side to begin, with no way to predict how long it will be before any data arrives. In those cases a timeout would not be helpful. This is obviously not the case with a Send operation:
namespace AaronLuna.TplSockets
{
    using System;
    using System.Net.Sockets;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;
    public static partial class TplSocketExtensions
    {
        public static async Task<Result<int>> SendWithTimeoutAsync(
            this Socket socket,
            byte[] buffer,
            int offset,
            int size,
            SocketFlags socketFlags,
            int timeoutMs)
        {
            int bytesSent;
            try
            {
                var asyncResult = socket.BeginSend(buffer, offset, size, socketFlags, null, null);
                var sendBytesTask = Task<int>.Factory.FromAsync(asyncResult, _ => socket.EndSend(asyncResult));
                if (sendBytesTask != await Task.WhenAny(sendBytesTask, Task.Delay(timeoutMs)).ConfigureAwait(false))
                {
                    throw new TimeoutException();
                }
                // ConfigureAwait(false) added for consistency with the other methods
                bytesSent = await sendBytesTask.ConfigureAwait(false);
            }
            catch (SocketException ex)
            {
                return Result.Fail<int>($"{ex.Message} ({ex.GetType()})");
            }
            catch (TimeoutException ex)
            {
                return Result.Fail<int>($"{ex.Message} ({ex.GetType()})");
            }
            return Result.Ok(bytesSent);
        }
    }
}
Client/Server Example
It's time to put these pieces together and demonstrate how to use these methods and the Result objects:
namespace AaronLuna.TplSockets
{
    using System.Linq;
    using System.Net;
    using System.Net.Sockets;
    using System.Text;
    using System.Threading.Tasks;
    using AaronLuna.Common.Result;
    public class TplSocketExample
    {
        const int BufferSize = 8 * 1024;
        const int ConnectTimeoutMs = 3000;
        const int ReceiveTimeoutMs = 3000;
        const int SendTimeoutMs = 3000;
        Socket _listenSocket;
        Socket _clientSocket;
        Socket _transferSocket;
        public async Task<Result> SendAndReceiveTextMesageAsync()
        {
            // _transferSocket is not created here; it is assigned in Step 5 from the accepted connection
            _listenSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            _clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            var serverPort = 7003;
            var ipHostInfo = Dns.GetHostEntry(Dns.GetHostName());
            var ipAddress =
                ipHostInfo.AddressList
                    .FirstOrDefault(ip => ip.AddressFamily == AddressFamily.InterNetwork);
            var ipEndPoint = new IPEndPoint(ipAddress, serverPort);
            // Step 1: Bind a socket to a local TCP port and listen for incoming connections
            _listenSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true);
            _listenSocket.Bind(ipEndPoint);
            _listenSocket.Listen(5);
            // Step 2: Create a Task and accept the next incoming connection (ServerAcceptTask)
            // NOTE: The call to AcceptConnectionTask is not awaited, therefore this method
            // continues executing
            var acceptTask = Task.Run(AcceptConnectionTask);
            // Step 3: With another socket, connect to the bound socket and await the result (ClientConnectTask)
            var connectResult =
                await _clientSocket.ConnectWithTimeoutAsync(
                    ipAddress.ToString(),
                    serverPort,
                    ConnectTimeoutMs).ConfigureAwait(false);
            // Step 4: Await the result of the ServerAcceptTask
            var acceptResult = await acceptTask.ConfigureAwait(false);
            // If either ServerAcceptTask or ClientConnectTask did not complete successfully, stop execution and report the error
            if (Result.Combine(acceptResult, connectResult).Failure)
            {
                return Result.Fail("There was an error connecting to the server/accepting connection from the client");
            }
            // Step 5: Store the transfer socket if ServerAcceptTask was successful
            _transferSocket = acceptResult.Value;
            // Step 6: Create a Task and receive data from the transfer socket (ServerReceiveBytesTask)
            // NOTE: The call to ReceiveMessageAsync is not awaited, therefore this method
            // continues executing
            var receiveTask = Task.Run(ReceiveMessageAsync);
            // Step 7: Encode a string message before sending it to the server
            var messageToSend = "this is a text message from a socket";
            var messageData = Encoding.ASCII.GetBytes(messageToSend);
            // Step 8: Send the message data to the server and await the result (ClientSendBytesTask)
            var sendResult =
                await _clientSocket.SendWithTimeoutAsync(
                    messageData,
                    0,
                    messageData.Length,
                    0,
                    SendTimeoutMs).ConfigureAwait(false);
            // Step 9: Await the result of ServerReceiveBytesTask
            var receiveResult = await receiveTask.ConfigureAwait(false);
            // Step 10: If either ServerReceiveBytesTask or ClientSendBytesTask did not complete successfully, stop execution and report the error
            if (Result.Combine(sendResult, receiveResult).Failure)
            {
                return Result.Fail("There was an error sending/receiving data from the client");
            }
            // Step 11: Compare the string that was received to what was sent, report an error if not matching
            var messageReceived = receiveResult.Value;
            if (messageToSend != messageReceived)
            {
                return Result.Fail("Error: Message received from client did not match what was sent");
            }
            // Step 12: Report the entire task was successful since all subtasks were successful
            return Result.Ok();
        }
        async Task<Result<Socket>> AcceptConnectionTask()
        {
            return await _listenSocket.AcceptAsync().ConfigureAwait(false);
        }
        async Task<Result<string>> ReceiveMessageAsync()
        {
            var message = string.Empty;
            var buffer = new byte[BufferSize];
            var receiveResult =
                await _transferSocket.ReceiveWithTimeoutAsync(
                    buffer,
                    0,
                    BufferSize,
                    0,
                    ReceiveTimeoutMs).ConfigureAwait(false);
            // Report a socket error or timeout as-is instead of masking it as "no data"
            if (receiveResult.Failure)
            {
                return Result.Fail<string>(receiveResult.Error);
            }
            var bytesReceived = receiveResult.Value;
            if (bytesReceived == 0)
            {
                return Result.Fail<string>("Error reading message from client, no data was received");
            }
            message = Encoding.ASCII.GetString(buffer, 0, bytesReceived);
            return Result.Ok(message);
        }
    }
}
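To try the example end to end, something has to call it. Here is a minimal sketch of a console entry point, assuming all of the listings in this post are compiled into one project; the Program class and its output messages are hypothetical and not part of the original code.

```csharp
namespace AaronLuna.TplSockets
{
    using System;
    using System.Threading.Tasks;

    public static class Program
    {
        // An async Main method requires C# 7.1 or later.
        public static async Task Main()
        {
            var example = new TplSocketExample();
            var result = await example.SendAndReceiveTextMesageAsync();

            // Result.Error is an empty string when the operation succeeded.
            Console.WriteLine(result.Success
                ? "Message was sent and received successfully"
                : $"Example failed: {result.Error}");
        }
    }
}
```

Because every socket method reports failures through Result rather than by throwing, the caller only needs to check Success or Failure and read the Error string.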
Summary
As stated at the beginning of this post, all of the source code for these extension methods and the example can be found at the gist embedded below: