Skip to content

Commit c255f15

Browse files
Merge pull request #26051 from ClickHouse/fix_21184
Fix sequence_id in MySQL protocol
2 parents 06570d2 + dd0ad58 commit c255f15

File tree

12 files changed

+133
-54
lines changed

12 files changed

+133
-54
lines changed

src/Core/MySQL/MySQLClient.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ namespace ErrorCodes
2626
MySQLClient::MySQLClient(const String & host_, UInt16 port_, const String & user_, const String & password_)
2727
: host(host_), port(port_), user(user_), password(std::move(password_))
2828
{
29-
client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
29+
mysql_context.client_capabilities = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
3030
}
3131

3232
MySQLClient::MySQLClient(MySQLClient && other)
3333
: host(std::move(other.host)), port(other.port), user(std::move(other.user)), password(std::move(other.password))
34-
, client_capability_flags(other.client_capability_flags)
34+
, mysql_context(other.mysql_context)
3535
{
36+
mysql_context.sequence_id = 0;
3637
}
3738

3839
void MySQLClient::connect()
@@ -56,7 +57,7 @@ void MySQLClient::connect()
5657

5758
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
5859
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
59-
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, seq);
60+
packet_endpoint = mysql_context.makeEndpoint(*in, *out);
6061
handshake();
6162
}
6263

@@ -68,7 +69,7 @@ void MySQLClient::disconnect()
6869
socket->close();
6970
socket = nullptr;
7071
connected = false;
71-
seq = 0;
72+
mysql_context.sequence_id = 0;
7273
}
7374

7475
/// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html
@@ -87,10 +88,10 @@ void MySQLClient::handshake()
8788
String auth_plugin_data = native41.getAuthPluginData();
8889

8990
HandshakeResponse handshake_response(
90-
client_capability_flags, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
91+
mysql_context.client_capabilities, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
9192
packet_endpoint->sendPacket<HandshakeResponse>(handshake_response, true);
9293

93-
ResponsePacket packet_response(client_capability_flags, true);
94+
ResponsePacket packet_response(mysql_context.client_capabilities, true);
9495
packet_endpoint->receivePacket(packet_response);
9596
packet_endpoint->resetSequenceId();
9697

@@ -105,7 +106,7 @@ void MySQLClient::writeCommand(char command, String query)
105106
WriteCommand write_command(command, query);
106107
packet_endpoint->sendPacket<WriteCommand>(write_command, true);
107108

108-
ResponsePacket packet_response(client_capability_flags);
109+
ResponsePacket packet_response(mysql_context.client_capabilities);
109110
packet_endpoint->receivePacket(packet_response);
110111
switch (packet_response.getType())
111112
{
@@ -124,7 +125,7 @@ void MySQLClient::registerSlaveOnMaster(UInt32 slave_id)
124125
RegisterSlave register_slave(slave_id);
125126
packet_endpoint->sendPacket<RegisterSlave>(register_slave, true);
126127

127-
ResponsePacket packet_response(client_capability_flags);
128+
ResponsePacket packet_response(mysql_context.client_capabilities);
128129
packet_endpoint->receivePacket(packet_response);
129130
packet_endpoint->resetSequenceId();
130131
if (packet_response.getType() == PACKET_ERR)

src/Core/MySQL/MySQLClient.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ class MySQLClient
4545
String password;
4646

4747
bool connected = false;
48-
UInt32 client_capability_flags = 0;
49-
50-
uint8_t seq = 0;
48+
MySQLWireContext mysql_context;
5149
const UInt8 charset_utf8 = 33;
5250
const String mysql_native_password = "mysql_native_password";
5351

src/Core/MySQL/PacketEndpoint.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,15 @@ String PacketEndpoint::packetToText(const String & payload)
6868

6969
}
7070

71+
72+
MySQLProtocol::PacketEndpointPtr MySQLWireContext::makeEndpoint(WriteBuffer & out)
73+
{
74+
return MySQLProtocol::PacketEndpoint::create(out, sequence_id);
75+
}
76+
77+
MySQLProtocol::PacketEndpointPtr MySQLWireContext::makeEndpoint(ReadBuffer & in, WriteBuffer & out)
78+
{
79+
return MySQLProtocol::PacketEndpoint::create(in, out, sequence_id);
80+
}
81+
7182
}

src/Core/MySQL/PacketEndpoint.h

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "IMySQLReadPacket.h"
66
#include "IMySQLWritePacket.h"
77
#include "IO/MySQLPacketPayloadReadBuffer.h"
8+
#include <common/shared_ptr_helper.h>
89

910
namespace DB
1011
{
@@ -15,19 +16,13 @@ namespace MySQLProtocol
1516
/* Writes and reads packets, keeping sequence-id.
1617
* Throws ProtocolError, if packet with incorrect sequence-id was received.
1718
*/
18-
class PacketEndpoint
19+
class PacketEndpoint : public shared_ptr_helper<PacketEndpoint>
1920
{
2021
public:
2122
uint8_t & sequence_id;
2223
ReadBuffer * in;
2324
WriteBuffer * out;
2425

25-
/// For writing.
26-
PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_);
27-
28-
/// For reading and writing.
29-
PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_);
30-
3126
MySQLPacketPayloadReadBuffer getPayload();
3227

3328
void receivePacket(IMySQLReadPacket & packet);
@@ -48,8 +43,29 @@ class PacketEndpoint
4843

4944
/// Converts packet to text. Is used for debug output.
5045
static String packetToText(const String & payload);
46+
47+
protected:
48+
/// For writing.
49+
PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_);
50+
51+
/// For reading and writing.
52+
PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_);
53+
54+
friend struct shared_ptr_helper<PacketEndpoint>;
5155
};
5256

57+
using PacketEndpointPtr = std::shared_ptr<PacketEndpoint>;
58+
5359
}
5460

61+
struct MySQLWireContext
62+
{
63+
uint8_t sequence_id = 0;
64+
uint32_t client_capabilities = 0;
65+
size_t max_packet_size = 0;
66+
67+
MySQLProtocol::PacketEndpointPtr makeEndpoint(WriteBuffer & out);
68+
MySQLProtocol::PacketEndpointPtr makeEndpoint(ReadBuffer & in, WriteBuffer & out);
69+
};
70+
5571
}

src/Formats/FormatFactory.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace ErrorCodes
3333
extern const int LOGICAL_ERROR;
3434
extern const int FORMAT_IS_NOT_SUITABLE_FOR_INPUT;
3535
extern const int FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT;
36+
extern const int UNSUPPORTED_METHOD;
3637
}
3738

3839
const FormatFactory::Creators & FormatFactory::getCreators(const String & name) const
@@ -207,6 +208,9 @@ BlockOutputStreamPtr FormatFactory::getOutputStreamParallelIfPossible(
207208
WriteCallback callback,
208209
const std::optional<FormatSettings> & _format_settings) const
209210
{
211+
if (context->getMySQLProtocolContext() && name != "MySQLWire")
212+
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats");
213+
210214
const auto & output_getter = getCreators(name).output_processor_creator;
211215

212216
const Settings & settings = context->getSettingsRef();
@@ -309,7 +313,10 @@ OutputFormatPtr FormatFactory::getOutputFormatParallelIfPossible(
309313
{
310314
const auto & output_getter = getCreators(name).output_processor_creator;
311315
if (!output_getter)
312-
throw Exception("Format " + name + " is not suitable for output (with processors)", ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT);
316+
throw Exception(ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT, "Format {} is not suitable for output (with processors)", name);
317+
318+
if (context->getMySQLProtocolContext() && name != "MySQLWire")
319+
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats");
313320

314321
auto format_settings = _format_settings ? *_format_settings : getFormatSettings(context);
315322

@@ -344,7 +351,7 @@ OutputFormatPtr FormatFactory::getOutputFormat(
344351
{
345352
const auto & output_getter = getCreators(name).output_processor_creator;
346353
if (!output_getter)
347-
throw Exception("Format " + name + " is not suitable for output (with processors)", ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT);
354+
throw Exception(ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT, "Format {} is not suitable for output (with processors)", name);
348355

349356
if (context->hasQueryContext() && context->getSettingsRef().log_queries)
350357
context->getQueryContext()->addQueryFactoriesInfo(Context::QueryLogFactories::Format, name);

src/Interpreters/Context.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2355,11 +2355,6 @@ OutputFormatPtr Context::getOutputFormatParallelIfPossible(const String & name,
23552355
return FormatFactory::instance().getOutputFormatParallelIfPossible(name, buf, sample, shared_from_this());
23562356
}
23572357

2358-
OutputFormatPtr Context::getOutputFormat(const String & name, WriteBuffer & buf, const Block & sample) const
2359-
{
2360-
return FormatFactory::instance().getOutputFormat(name, buf, sample, shared_from_this());
2361-
}
2362-
23632358

23642359
time_t Context::getUptimeSeconds() const
23652360
{
@@ -2732,4 +2727,18 @@ PartUUIDsPtr Context::getIgnoredPartUUIDs() const
27322727
return ignored_part_uuids;
27332728
}
27342729

2730+
void Context::setMySQLProtocolContext(MySQLWireContext * mysql_context)
2731+
{
2732+
assert(session_context.lock().get() == this);
2733+
assert(!mysql_protocol_context);
2734+
assert(mysql_context);
2735+
mysql_protocol_context = mysql_context;
2736+
}
2737+
2738+
MySQLWireContext * Context::getMySQLProtocolContext() const
2739+
{
2740+
assert(!mysql_protocol_context || session_context.lock().get());
2741+
return mysql_protocol_context;
2742+
}
2743+
27352744
}

src/Interpreters/Context.h

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ using ThrottlerPtr = std::shared_ptr<Throttler>;
119119
class ZooKeeperMetadataTransaction;
120120
using ZooKeeperMetadataTransactionPtr = std::shared_ptr<ZooKeeperMetadataTransaction>;
121121

122+
struct MySQLWireContext;
123+
122124
/// Callback for external tables initializer
123125
using ExternalTablesInitializer = std::function<void(ContextPtr)>;
124126

@@ -298,6 +300,8 @@ class Context: public std::enable_shared_from_this<Context>
298300
/// thousands of signatures.
299301
/// And I hope it will be replaced with more common Transaction sometime.
300302

303+
MySQLWireContext * mysql_protocol_context = nullptr;
304+
301305
Context();
302306
Context(const Context &);
303307
Context & operator=(const Context &);
@@ -538,7 +542,6 @@ class Context: public std::enable_shared_from_this<Context>
538542
BlockOutputStreamPtr getOutputStream(const String & name, WriteBuffer & buf, const Block & sample) const;
539543

540544
OutputFormatPtr getOutputFormatParallelIfPossible(const String & name, WriteBuffer & buf, const Block & sample) const;
541-
OutputFormatPtr getOutputFormat(const String & name, WriteBuffer & buf, const Block & sample) const;
542545

543546
InterserverIOHandler & getInterserverIOHandler();
544547

@@ -794,14 +797,10 @@ class Context: public std::enable_shared_from_this<Context>
794797
/// Returns context of current distributed DDL query or nullptr.
795798
ZooKeeperMetadataTransactionPtr getZooKeeperMetadataTransaction() const;
796799

797-
struct MySQLWireContext
798-
{
799-
uint8_t sequence_id = 0;
800-
uint32_t client_capabilities = 0;
801-
size_t max_packet_size = 0;
802-
};
803-
804-
MySQLWireContext mysql;
800+
/// Caller is responsible for lifetime of mysql_context.
801+
/// Used in MySQLHandler for session context.
802+
void setMySQLProtocolContext(MySQLWireContext * mysql_context);
803+
MySQLWireContext * getMySQLProtocolContext() const;
805804

806805
PartUUIDsPtr getPartUUIDs() const;
807806
PartUUIDsPtr getIgnoredPartUUIDs() const;

src/Processors/Formats/Impl/MySQLOutputFormat.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,22 @@ MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header_,
1717
{
1818
}
1919

20+
void MySQLOutputFormat::setContext(ContextPtr context_)
21+
{
22+
context = context_;
23+
/// MySQlWire is a special format that is usually used as output format for MySQL protocol connections.
24+
/// In this case we have to use the corresponding session context to set correct sequence_id.
25+
mysql_context = getContext()->getMySQLProtocolContext();
26+
if (!mysql_context)
27+
{
28+
/// But it's also possible to specify MySQLWire as output format for clickhouse-client or clickhouse-local.
29+
/// There is no MySQL protocol context in this case, so we create dummy one.
30+
own_mysql_context.emplace();
31+
mysql_context = &own_mysql_context.value();
32+
}
33+
packet_endpoint = mysql_context->makeEndpoint(out);
34+
}
35+
2036
void MySQLOutputFormat::initialize()
2137
{
2238
if (initialized)
@@ -40,7 +56,7 @@ void MySQLOutputFormat::initialize()
4056
packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
4157
}
4258

43-
if (!(getContext()->mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
59+
if (!(mysql_context->client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
4460
{
4561
packet_endpoint->sendPacket(EOFPacket(0, 0));
4662
}
@@ -79,10 +95,10 @@ void MySQLOutputFormat::finalize()
7995
const auto & header = getPort(PortKind::Main).getHeader();
8096
if (header.columns() == 0)
8197
packet_endpoint->sendPacket(
82-
OKPacket(0x0, getContext()->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
83-
else if (getContext()->mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
98+
OKPacket(0x0, mysql_context->client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
99+
else if (mysql_context->client_capabilities & CLIENT_DEPRECATE_EOF)
84100
packet_endpoint->sendPacket(
85-
OKPacket(0xfe, getContext()->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
101+
OKPacket(0xfe, mysql_context->client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
86102
else
87103
packet_endpoint->sendPacket(EOFPacket(0, 0), true);
88104
}

src/Processors/Formats/Impl/MySQLOutputFormat.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,7 @@ class MySQLOutputFormat final : public IOutputFormat, WithContext
2525

2626
String getName() const override { return "MySQLOutputFormat"; }
2727

28-
void setContext(ContextPtr context_)
29-
{
30-
context = context_;
31-
packet_endpoint = std::make_unique<MySQLProtocol::PacketEndpoint>(out, const_cast<uint8_t &>(getContext()->mysql.sequence_id)); /// TODO: fix it
32-
}
28+
void setContext(ContextPtr context_);
3329

3430
void consume(Chunk) override;
3531
void finalize() override;
@@ -41,7 +37,9 @@ class MySQLOutputFormat final : public IOutputFormat, WithContext
4137
private:
4238
bool initialized = false;
4339

44-
std::unique_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint;
40+
std::optional<MySQLWireContext> own_mysql_context;
41+
MySQLWireContext * mysql_context = nullptr;
42+
MySQLProtocol::PacketEndpointPtr packet_endpoint;
4543
FormatSettings format_settings;
4644
DataTypes data_types;
4745
Serializations serializations;

src/Server/MySQLHandler.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,11 @@ void MySQLHandler::run()
9595
connection_context->getClientInfo().interface = ClientInfo::Interface::MYSQL;
9696
connection_context->setDefaultFormat("MySQLWire");
9797
connection_context->getClientInfo().connection_id = connection_id;
98+
connection_context->setMySQLProtocolContext(&connection_context_mysql);
9899

99100
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
100101
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
101-
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context->mysql.sequence_id);
102+
packet_endpoint = connection_context_mysql.makeEndpoint(*in, *out);
102103

103104
try
104105
{
@@ -110,11 +111,11 @@ void MySQLHandler::run()
110111

111112
HandshakeResponse handshake_response;
112113
finishHandshake(handshake_response);
113-
connection_context->mysql.client_capabilities = handshake_response.capability_flags;
114+
connection_context_mysql.client_capabilities = handshake_response.capability_flags;
114115
if (handshake_response.max_packet_size)
115-
connection_context->mysql.max_packet_size = handshake_response.max_packet_size;
116-
if (!connection_context->mysql.max_packet_size)
117-
connection_context->mysql.max_packet_size = MAX_PACKET_LENGTH;
116+
connection_context_mysql.max_packet_size = handshake_response.max_packet_size;
117+
if (!connection_context_mysql.max_packet_size)
118+
connection_context_mysql.max_packet_size = MAX_PACKET_LENGTH;
118119

119120
LOG_TRACE(log,
120121
"Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: {}",
@@ -395,14 +396,14 @@ void MySQLHandlerSSL::finishHandshakeSSL(
395396
ReadBufferFromMemory payload(buf, pos);
396397
payload.ignore(PACKET_HEADER_SIZE);
397398
ssl_request.readPayloadWithUnpacked(payload);
398-
connection_context->mysql.client_capabilities = ssl_request.capability_flags;
399-
connection_context->mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
399+
connection_context_mysql.client_capabilities = ssl_request.capability_flags;
400+
connection_context_mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
400401
secure_connection = true;
401402
ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
402403
in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
403404
out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
404-
connection_context->mysql.sequence_id = 2;
405-
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context->mysql.sequence_id);
405+
connection_context_mysql.sequence_id = 2;
406+
packet_endpoint = connection_context_mysql.makeEndpoint(*in, *out);
406407
packet_endpoint->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
407408
}
408409

0 commit comments

Comments
 (0)