Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 18 additions & 30 deletions src/Npgsql/Internal/NpgsqlConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -939,10 +939,19 @@ void Connect(NpgsqlTimeout timeout)

async Task ConnectAsync(NpgsqlTimeout timeout, CancellationToken cancellationToken)
{
// Note that there aren't any timeout-able or cancellable DNS methods
Task<IPAddress[]> GetHostAddressesAsync(CancellationToken ct) =>
#if NET6_0_OR_GREATER
Dns.GetHostAddressesAsync(Host, ct);
#else
Dns.GetHostAddressesAsync(Host);
#endif

// Whether the framework and/or the OS platform support Dns.GetHostAddressesAsync cancellation API or they do not,
// we always fake-cancel the operation with the help of TaskTimeoutAndCancellation.ExecuteAsync. It stops waiting
// and raises the exception, while the actual task may be left running.
var endpoints = NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath)
? new EndPoint[] { new UnixDomainSocketEndPoint(socketPath) }
: (await GetHostAddressesAsync(timeout, cancellationToken))
: (await TaskTimeoutAndCancellation.ExecuteAsync(GetHostAddressesAsync, timeout, cancellationToken))
.Select(a => new IPEndPoint(a, Port)).ToArray();

// Give each IP an equal share of the remaining time
Expand Down Expand Up @@ -995,39 +1004,18 @@ async Task ConnectAsync(NpgsqlTimeout timeout, CancellationToken cancellationTok
}
}

Task<IPAddress[]> GetHostAddressesAsync(NpgsqlTimeout timeout, CancellationToken cancellationToken)
{
// .NET 6.0 added cancellation support to GetHostAddressesAsync, which allows us to implement real
// cancellation and timeout. On older TFMs, we fake-cancel the operation, i.e. stop waiting
// and raise the exception, but the actual connection task is left running.

#if NET6_0_OR_GREATER
var task = TaskExtensions.ExecuteWithTimeout(
ct => Dns.GetHostAddressesAsync(Host, ct),
timeout, cancellationToken);
#else
var task = Dns.GetHostAddressesAsync(Host);
#endif

// As the cancellation support of GetHostAddressesAsync is not guaranteed on all platforms
// we apply the fake-cancel mechanism in all cases.
return task.WithCancellationAndTimeout(timeout, cancellationToken);
}

static Task OpenSocketConnectionAsync(Socket socket, EndPoint endpoint, NpgsqlTimeout perIpTimeout, CancellationToken cancellationToken)
{
// .NET 5.0 added cancellation support to ConnectAsync, which allows us to implement real
// cancellation and timeout. On older TFMs, we fake-cancel the operation, i.e. stop waiting
// and raise the exception, but the actual connection task is left running.

// Whether the framework and/or the OS platform support Socket.ConnectAsync cancellation API or they do not,
// we always fake-cancel the operation with the help of TaskTimeoutAndCancellation.ExecuteAsync. It stops waiting
// and raises the exception, while the actual task may be left running.
Task ConnectAsync(CancellationToken ct) =>
#if NET5_0_OR_GREATER
return TaskExtensions.ExecuteWithTimeout(
ct => socket.ConnectAsync(endpoint, ct).AsTask(),
perIpTimeout, cancellationToken);
socket.ConnectAsync(endpoint, ct).AsTask();
#else
return socket.ConnectAsync(endpoint)
.WithCancellationAndTimeout(perIpTimeout, cancellationToken);
socket.ConnectAsync(endpoint);
#endif
return TaskTimeoutAndCancellation.ExecuteAsync(ConnectAsync, perIpTimeout, cancellationToken);
}
}

Expand Down
65 changes: 65 additions & 0 deletions src/Npgsql/Shims/TaskExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#if !NET6_0_OR_GREATER
using System.Collections.Generic;

namespace System.Threading.Tasks;

static class TaskExtensions
{
/// <summary>
/// Gets a <see cref="Task"/> that will complete when this <see cref="Task"/> completes, when the specified timeout expires, or when the specified <see cref="CancellationToken"/> has cancellation requested.
/// </summary>
/// <param name="task">The <see cref="Task"/> representing the asynchronous wait.</param>
/// <param name="timeout">The timeout after which the <see cref="Task"/> should be faulted with a <see cref="TimeoutException"/> if it hasn't otherwise completed.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for a cancellation request.</param>
/// <returns>The <see cref="Task"/> representing the asynchronous wait.</returns>
/// <remarks>This method reproduces new to the .NET 6.0 API <see cref="Task"/>.WaitAsync.</remarks>
public static async Task WaitAsync(this Task task, TimeSpan timeout, CancellationToken cancellationToken)
{
var tasks = new List<Task>(3);

Task? cancellationTask = default;
CancellationTokenRegistration registration = default;
if (cancellationToken.CanBeCanceled)
{
var tcs = new TaskCompletionSource<bool>();
registration = cancellationToken.Register(s => ((TaskCompletionSource<bool>)s!).TrySetResult(true), tcs);
cancellationTask = tcs.Task;
tasks.Add(cancellationTask);
}

Task? delayTask = default;
CancellationTokenSource? delayCts = default;
if (timeout != Timeout.InfiniteTimeSpan)
{
var timeLeft = timeout;
delayCts = new CancellationTokenSource();
delayTask = Task.Delay(timeLeft, delayCts.Token);
tasks.Add(delayTask);
}

try
{
if (tasks.Count != 0)
{
tasks.Add(task);
var result = await Task.WhenAny(tasks);
if (result == cancellationTask)
{
task = Task.FromCanceled(cancellationToken);
}
else if (result == delayTask)
{
task = Task.FromException(new TimeoutException());
}
}
await task;
}
finally
{
delayCts?.Cancel();
delayCts?.Dispose();
registration.Dispose();
}
}
}
#endif
128 changes: 0 additions & 128 deletions src/Npgsql/TaskExtensions.cs

This file was deleted.

66 changes: 66 additions & 0 deletions src/Npgsql/TaskTimeoutAndCancellation.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Npgsql.Util;

namespace Npgsql;

/// <summary>
/// Utility class to execute a potentially non-cancellable <see cref="Task"/> while allowing to timeout and/or cancel awaiting for it and at the same time prevent <see cref="TaskScheduler.UnobservedTaskException"/> event if the original <see cref="Task"/> fails later.
/// </summary>
static class TaskTimeoutAndCancellation
{
/// <summary>
/// Executes a potentially non-cancellable <see cref="Task{TResult}"/> while allowing to timeout and/or cancel awaiting for it.
/// If the given task does not complete within <paramref name="timeout"/>, a <see cref="TimeoutException"/> is thrown.
/// The executed <see cref="Task{TResult}"/> may be left in an incomplete state after the <see cref="Task{TResult}"/> that this method returns completes dues to timeout and/or cancellation request.
/// The method guarantees that the abandoned, incomplete <see cref="Task{TResult}"/> is not going to produce <see cref="TaskScheduler.UnobservedTaskException"/> event if it fails later.
/// </summary>
/// <param name="getTaskFunc">Gets the <see cref="Task{TResult}"/> for execution with a combined <see cref="CancellationToken"/> that attempts to cancel the <see cref="Task{TResult}"/> in an event of the timeout or external cancellation request.</param>
/// <param name="timeout">The timeout after which the <see cref="Task{TResult}"/> should be faulted with a <see cref="TimeoutException"/> if it hasn't otherwise completed.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for a cancellation request.</param>
/// <typeparam name="TResult">The result <see cref="Type"/>.</typeparam>
/// <returns>The <see cref="Task{TResult}"/> representing the asynchronous wait.</returns>
internal static async Task<TResult> ExecuteAsync<TResult>(Func<CancellationToken, Task<TResult>> getTaskFunc, NpgsqlTimeout timeout, CancellationToken cancellationToken)
{
Task<TResult>? task = default;
await ExecuteAsync(ct => (Task)(task = getTaskFunc(ct)), timeout, cancellationToken);
return await task!;
}

/// <summary>
/// Executes a potentially non-cancellable <see cref="Task"/> while allowing to timeout and/or cancel awaiting for it.
/// If the given task does not complete within <paramref name="timeout"/>, a <see cref="TimeoutException"/> is thrown.
/// The executed <see cref="Task"/> may be left in an incomplete state after the <see cref="Task"/> that this method returns completes dues to timeout and/or cancellation request.
/// The method guarantees that the abandoned, incomplete <see cref="Task"/> is not going to produce <see cref="TaskScheduler.UnobservedTaskException"/> event if it fails later.
/// </summary>
/// <param name="getTaskFunc">Gets the <see cref="Task"/> for execution with a combined <see cref="CancellationToken"/> that attempts to cancel the <see cref="Task"/> in an event of the timeout or external cancellation request.</param>
/// <param name="timeout">The timeout after which the <see cref="Task"/> should be faulted with a <see cref="TimeoutException"/> if it hasn't otherwise completed.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for a cancellation request.</param>
/// <returns>The <see cref="Task"/> representing the asynchronous wait.</returns>
internal static async Task ExecuteAsync(Func<CancellationToken, Task> getTaskFunc, NpgsqlTimeout timeout, CancellationToken cancellationToken)
{
using var combinedCts = timeout.IsSet ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken) : null;
var task = getTaskFunc(combinedCts?.Token ?? cancellationToken);
try
{
try
{
await task.WaitAsync(timeout.CheckAndGetTimeLeft(), cancellationToken);
}
catch (TimeoutException) when (!task!.IsCompleted)
{
// Attempt to stop the Task in progress.
combinedCts?.Cancel();
throw;
}
}
catch
{
// Prevent unobserved Task notifications by observing the failed Task exception.
// To test: comment the next line out and re-run TaskExtensionsTest.DelayedFaultedTaskCancellation.
_ = task.ContinueWith(t => _ = t.Exception, CancellationToken.None, TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Current);
throw;
}
}
}
Loading