Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/Npgsql/Internal/NpgsqlConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public sealed partial class NpgsqlConnector : IDisposable
ProvideClientCertificatesCallback? ProvideClientCertificatesCallback { get; }
RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; }
ProvidePasswordCallback? ProvidePasswordCallback { get; }
PhysicalOpenCallback? PhysicalOpenCallback { get; set; }
PhysicalOpenAsyncCallback? PhysicalOpenAsyncCallback { get; set; }

internal Encoding TextEncoding { get; private set; } = default!;

Expand Down Expand Up @@ -304,6 +306,8 @@ internal NpgsqlConnector(ConnectorSource connectorSource, NpgsqlConnection conn)
ProvideClientCertificatesCallback = conn.ProvideClientCertificatesCallback;
UserCertificateValidationCallback = conn.UserCertificateValidationCallback;
ProvidePasswordCallback = conn.ProvidePasswordCallback;
PhysicalOpenCallback = conn.PhysicalOpenCallback;
PhysicalOpenAsyncCallback = conn.PhysicalOpenAsyncCallback;
}

NpgsqlConnector(NpgsqlConnector connector)
Expand Down Expand Up @@ -491,6 +495,11 @@ internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken ca
OpenTimestamp = DateTime.UtcNow;
Log.Trace($"Opened connection to {Host}:{Port}");

if (async && PhysicalOpenAsyncCallback is not null)
await PhysicalOpenAsyncCallback(this);
else if (!async && PhysicalOpenCallback is not null)
PhysicalOpenCallback(this);

if (Settings.Multiplexing)
{
// Start an infinite async loop, which processes incoming multiplexing traffic.
Expand Down
28 changes: 25 additions & 3 deletions src/Npgsql/NpgsqlConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,16 @@ public override string ConnectionString
/// </remarks>
public ProvidePasswordCallback? ProvidePasswordCallback { get; set; }

/// <summary>
/// Gets or sets the delegate used to setup a connection whenever a physical connection is opened synchronously.
/// </summary>
public PhysicalOpenCallback? PhysicalOpenCallback { get; set; }

/// <summary>
/// Gets or sets the delegate used to setup a connection whenever a physical connection is opened asynchronously.
/// </summary>
public PhysicalOpenAsyncCallback? PhysicalOpenAsyncCallback { get; set; }

#endregion Connection string management

#region Configuration settings
Expand Down Expand Up @@ -2036,16 +2046,16 @@ enum ConnectorBindingScope
public delegate void NotificationEventHandler(object sender, NpgsqlNotificationEventArgs e);

/// <summary>
/// Represents the method that allows the application to provide a certificate collection to be used for SSL client authentication
/// Represents a method that allows the application to provide a certificate collection to be used for SSL client authentication
/// </summary>
/// <param name="certificates">
/// A <see cref="System.Security.Cryptography.X509Certificates.X509CertificateCollection"/> to be filled with one or more client
/// A <see cref="X509CertificateCollection"/> to be filled with one or more client
/// certificates.
/// </param>
public delegate void ProvideClientCertificatesCallback(X509CertificateCollection certificates);

/// <summary>
/// Represents the method that allows the application to provide a password at connection time in code rather than configuration
/// Represents a method that allows the application to provide a password at connection time in code rather than configuration
/// </summary>
/// <param name="host">Hostname</param>
/// <param name="port">Port</param>
Expand All @@ -2054,5 +2064,17 @@ enum ConnectorBindingScope
/// <returns>A valid password for connecting to the database</returns>
public delegate string ProvidePasswordCallback(string host, int port, string database, string username);

/// <summary>
/// Represents a method that allows the application to setup a connection with custom commands.
/// </summary>
/// <param name="connection">Physical connection to the database</param>
public delegate void PhysicalOpenCallback(NpgsqlConnector connection);

/// <summary>
/// Represents an asynchronous method that allows the application to setup a connection with custom commands.
/// </summary>
/// <param name="connection">Physical connection to the database</param>
public delegate Task PhysicalOpenAsyncCallback(NpgsqlConnector connection);

#endregion
}
6 changes: 6 additions & 0 deletions src/Npgsql/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ Npgsql.NpgsqlConnection.BeginBinaryImportAsync(string! copyFromCommand, System.T
Npgsql.NpgsqlConnection.BeginRawBinaryCopyAsync(string! copyCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<Npgsql.NpgsqlRawCopyStream!>!
Npgsql.NpgsqlConnection.BeginTextExportAsync(string! copyToCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<System.IO.TextReader!>!
Npgsql.NpgsqlConnection.BeginTextImportAsync(string! copyFromCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<System.IO.TextWriter!>!
Npgsql.NpgsqlConnection.PhysicalOpenAsyncCallback.get -> Npgsql.PhysicalOpenAsyncCallback?
Npgsql.NpgsqlConnection.PhysicalOpenAsyncCallback.set -> void
Npgsql.NpgsqlConnection.PhysicalOpenCallback.get -> Npgsql.PhysicalOpenCallback?
Npgsql.NpgsqlConnection.PhysicalOpenCallback.set -> void
Npgsql.NpgsqlConnection.Settings.get -> Npgsql.NpgsqlConnectionStringBuilder!
Npgsql.NpgsqlConnectionStringBuilder.HostRecheckSeconds.get -> int
Npgsql.NpgsqlConnectionStringBuilder.HostRecheckSeconds.set -> void
Expand All @@ -27,6 +31,8 @@ Npgsql.NpgsqlConnectionStringBuilder.TargetSessionAttributes.set -> void
*REMOVED*Npgsql.NpgsqlDatabaseInfo.NpgsqlDatabaseInfo(string! host, int port, string! databaseName, System.Version! version) -> void
*REMOVED*Npgsql.NpgsqlDatabaseInfo.Port.get -> int
*REMOVED*Npgsql.NpgsqlDatabaseInfo.Version.get -> System.Version!
Npgsql.PhysicalOpenAsyncCallback
Npgsql.PhysicalOpenCallback
NpgsqlTypes.NpgsqlTsQuery.Write(System.Text.StringBuilder! stringBuilder) -> void
override NpgsqlTypes.NpgsqlTsQuery.Equals(object? obj) -> bool
override NpgsqlTypes.NpgsqlTsQuery.GetHashCode() -> int
Expand Down
112 changes: 112 additions & 0 deletions test/Npgsql.Tests/ConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,118 @@ public async Task TimeoutDuringAuthentication()
.With.InnerException.TypeOf<TimeoutException>());
}

[Test]
public async Task Physical_open_callback_sync()
{
await using var defaultConn = await OpenConnectionAsync();
await using var _ = await CreateTempTable(defaultConn, "ID INTEGER", out var table);

using var __ = CreateTempPool(ConnectionString, out var connectionString);
using var conn = new NpgsqlConnection(connectionString);
conn.PhysicalOpenCallback = connector =>
{
using var cmd = connector.CreateCommand($"INSERT INTO \"{table}\" VALUES(1)");
cmd.ExecuteNonQuery();
};
conn.PhysicalOpenAsyncCallback = _ => throw new NotImplementedException();

Assert.DoesNotThrow(conn.Open);

var rowsCount = (long)(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM \"{table}\""))!;
Assert.AreEqual(1, rowsCount);
}

[Test]
public async Task Physical_open_async_callback()
{
await using var defaultConn = await OpenConnectionAsync();
await using var _ = await CreateTempTable(defaultConn, "ID INTEGER", out var table);

using var __ = CreateTempPool(ConnectionString, out var connectionString);
await using var conn = new NpgsqlConnection(connectionString);
conn.PhysicalOpenAsyncCallback = async connector =>
{
using var cmd = connector.CreateCommand($"INSERT INTO \"{table}\" VALUES(1)");
await cmd.ExecuteNonQueryAsync();
};
conn.PhysicalOpenCallback = _ => throw new NotImplementedException();

Assert.DoesNotThrowAsync(conn.OpenAsync);

var rowsCount = (long)(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM \"{table}\""))!;
Assert.AreEqual(1, rowsCount);
}

[Test]
public async Task Physical_open_callback_throws()
{
using var _ = CreateTempPool(ConnectionString, out var connectionString);
await using var conn = new NpgsqlConnection(connectionString);
conn.PhysicalOpenCallback = _ => throw new NotImplementedException();

Assert.Throws<NotImplementedException>(conn.Open);
}

[Test]
public async Task Physical_open_async_callback_throws()
{
PhysicalOpenAsyncCallback callback = _ => throw new NotImplementedException();

using var _ = CreateTempPool(ConnectionString, out var connectionString);
await using var conn = new NpgsqlConnection(connectionString);
conn.PhysicalOpenAsyncCallback = callback;

Assert.ThrowsAsync<NotImplementedException>(conn.OpenAsync);

if (IsMultiplexing)
{
// With multiplexing a physical connection might open on NpgsqlConnection.OpenAsync (if there was no completed bootstrap beforehand)
// or on NpgsqlCommand.ExecuteReaderAsync.
// We've already tested the first case above, testing the second one below.
conn.PhysicalOpenAsyncCallback = null;
// Allow the bootstrap to complete
Assert.DoesNotThrowAsync(conn.OpenAsync);

NpgsqlConnection.ClearPool(conn);

conn.PhysicalOpenAsyncCallback = callback;
Assert.ThrowsAsync<NotImplementedException>(() => conn.ExecuteNonQueryAsync("SELECT 1"));
}
}

[Test]
public async Task Physical_open_callback_idle_connection()
{
if (IsMultiplexing)
return;

using var _ = CreateTempPool(ConnectionString, out var connectionString);
await using var conn = new NpgsqlConnection(connectionString);

Assert.DoesNotThrow(conn.Open);
conn.Close();

conn.PhysicalOpenCallback = _ => throw new NotImplementedException();

Assert.DoesNotThrow(conn.Open);
Assert.DoesNotThrow(() => conn.ExecuteNonQuery("SELECT 1"));
}

[Test]
public async Task Physical_open_async_callback_idle_connection()
{
using var _ = CreateTempPool(ConnectionString, out var connectionString);
await using var conn = new NpgsqlConnection(connectionString);

Assert.DoesNotThrowAsync(conn.OpenAsync);
await conn.CloseAsync();

conn.PhysicalOpenAsyncCallback = _ => throw new NotImplementedException();

Assert.DoesNotThrowAsync(conn.OpenAsync);
Assert.DoesNotThrowAsync(() => conn.ExecuteNonQueryAsync("SELECT 1"));
}

public ConnectionTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {}
}
}