Skip to content

Commit 256b76e

Browse files
committed
Rename content_type column to req_type for tls_events. Remove version col in favor of later use of req_body/resp_body. Clean up duplicated body field on tls::Frame type
Signed-off-by: Dom Del Nano <[email protected]>
1 parent 3f80935 commit 256b76e

File tree

7 files changed

+43
-64
lines changed

7 files changed

+43
-64
lines changed

src/stirling/source_connectors/socket_tracer/protocols/tls/parse.cc

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "src/stirling/source_connectors/socket_tracer/protocols/tls/parse.h"
1919

2020
#include <map>
21+
#include <memory>
2122
#include <string>
2223
#include <utility>
2324
#include <vector>
@@ -41,7 +42,7 @@ constexpr size_t kSNIExtensionMinimumLength = 3;
4142
// In TLS 1.2 and earlier, gmt_unix_time is 4 bytes and Random is 28 bytes.
4243
constexpr size_t kRandomStructLength = 32;
4344

44-
StatusOr<ParseState> ExtractSNIExtension(ReqExtensions* exts, BinaryDecoder* decoder) {
45+
StatusOr<ParseState> ExtractSNIExtension(SharedExtensions* exts, BinaryDecoder* decoder) {
4546
PX_ASSIGN_OR(auto server_name_list_length, decoder->ExtractBEInt<uint16_t>(),
4647
return ParseState::kInvalid);
4748
while (server_name_list_length > 0) {
@@ -75,7 +76,7 @@ StatusOr<ParseState> ExtractSNIExtension(ReqExtensions* exts, BinaryDecoder* dec
7576
* diagram: https://en.wikipedia.org/wiki/Transport_Layer_Security#TLS_record
7677
*/
7778

78-
ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {
79+
ParseState ParseFullFrame(SharedExtensions* extensions, BinaryDecoder* decoder, Frame* frame) {
7980
PX_ASSIGN_OR(auto raw_content_type, decoder->ExtractBEInt<uint8_t>(),
8081
return ParseState::kInvalid);
8182
auto content_type = magic_enum::enum_cast<tls::ContentType>(raw_content_type);
@@ -161,8 +162,6 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {
161162
return ParseState::kSuccess;
162163
}
163164

164-
ReqExtensions req_extensions;
165-
RespExtensions resp_extensions;
166165
while (extensions_length > 0) {
167166
PX_ASSIGN_OR(auto extension_type, decoder->ExtractBEInt<uint16_t>(),
168167
return ParseState::kInvalid);
@@ -171,7 +170,7 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {
171170

172171
if (extension_length > 0) {
173172
if (extension_type == 0x00) {
174-
if (!ExtractSNIExtension(&req_extensions, decoder).ok()) {
173+
if (!ExtractSNIExtension(extensions, decoder).ok()) {
175174
return ParseState::kInvalid;
176175
}
177176
} else {
@@ -183,21 +182,17 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {
183182

184183
extensions_length -= kExtensionMinimumLength + extension_length;
185184
}
186-
JSONObjectBuilder req_body_builder;
187-
req_body_builder.WriteKVRecursive("extensions", req_extensions);
188-
frame->req_body = req_body_builder.GetString();
189-
190-
JSONObjectBuilder resp_body_builder;
191-
resp_body_builder.WriteKVRecursive("extensions", resp_extensions);
192-
frame->resp_body = resp_body_builder.GetString();
185+
JSONObjectBuilder body_builder;
186+
body_builder.WriteKVRecursive("extensions", *extensions);
187+
frame->body = body_builder.GetString();
193188

194189
return ParseState::kSuccess;
195190
}
196191

197192
} // namespace tls
198193

199194
template <>
200-
ParseState ParseFrame(message_type_t, std::string_view* buf, tls::Frame* frame, NoState*) {
195+
ParseState ParseFrame(message_type_t type, std::string_view* buf, tls::Frame* frame, NoState*) {
201196
// TLS record header is 5 bytes. The size of the record is in bytes 4 and 5.
202197
if (buf->length() < tls::kTLSRecordHeaderLength) {
203198
return ParseState::kNeedsMoreData;
@@ -208,7 +203,13 @@ ParseState ParseFrame(message_type_t, std::string_view* buf, tls::Frame* frame,
208203
}
209204

210205
BinaryDecoder decoder(*buf);
211-
auto parse_result = tls::ParseFullFrame(&decoder, frame);
206+
std::unique_ptr<tls::SharedExtensions> extensions;
207+
if (type == kRequest) {
208+
extensions = std::make_unique<tls::ReqExtensions>();
209+
} else {
210+
extensions = std::make_unique<tls::RespExtensions>();
211+
}
212+
auto parse_result = tls::ParseFullFrame(extensions.get(), &decoder, frame);
212213
if (parse_result == ParseState::kSuccess) {
213214
buf->remove_prefix(length + tls::kTLSRecordHeaderLength);
214215
}

src/stirling/source_connectors/socket_tracer/protocols/tls/parse.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace stirling {
2828
namespace protocols {
2929
namespace tls {
3030

31-
ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame);
31+
ParseState ParseFullFrame(SharedExtensions* extensions, BinaryDecoder* decoder, Frame* frame);
3232

3333
}
3434

src/stirling/source_connectors/socket_tracer/protocols/tls/types.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,20 +187,23 @@ enum class ExtensionType : uint16_t {
187187
// Extensions that are common to both the client and server side
188188
// of a TLS handshake
189189
struct SharedExtensions {
190-
void ToJSON(::px::utils::JSONObjectBuilder* /*builder*/) const {}
190+
std::vector<std::string> server_names;
191+
192+
virtual void ToJSON(::px::utils::JSONObjectBuilder* /*builder*/) const {}
193+
virtual ~SharedExtensions() = default;
191194
};
192195

193196
struct ReqExtensions : public SharedExtensions {
194-
std::vector<std::string> server_names;
195-
196-
void ToJSON(::px::utils::JSONObjectBuilder* builder) const {
197+
void ToJSON(::px::utils::JSONObjectBuilder* builder) const override {
197198
SharedExtensions::ToJSON(builder);
198199
builder->WriteKV("server_name", server_names);
199200
}
200201
};
201202

202203
struct RespExtensions : public SharedExtensions {
203-
void ToJSON(::px::utils::JSONObjectBuilder* builder) const { SharedExtensions::ToJSON(builder); }
204+
void ToJSON(::px::utils::JSONObjectBuilder* builder) const override {
205+
SharedExtensions::ToJSON(builder);
206+
}
204207
};
205208

206209
struct Frame : public FrameBase {
@@ -217,8 +220,7 @@ struct Frame : public FrameBase {
217220
LegacyVersion handshake_version;
218221

219222
std::string session_id;
220-
std::string req_body;
221-
std::string resp_body;
223+
std::string body;
222224

223225
bool consumed = false;
224226

@@ -227,9 +229,8 @@ struct Frame : public FrameBase {
227229
std::string ToString() const override {
228230
return absl::Substitute(
229231
"TLS Frame [len=$0 content_type=$1 legacy_version=$2 handshake_version=$3 "
230-
"handshake_type=$4 req_body=$5 resp_body=$6]",
231-
length, content_type, legacy_version, handshake_version, handshake_type, req_body,
232-
resp_body);
232+
"handshake_type=$4 body=$5]",
233+
length, content_type, legacy_version, handshake_version, handshake_type, body);
233234
}
234235
};
235236

src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,10 +1721,9 @@ void SocketTraceConnector::AppendMessage(ConnectorContext* ctx, const ConnTracke
17211721
r.Append<r.ColIndex("local_addr")>(conn_tracker.local_endpoint().AddrStr());
17221722
r.Append<r.ColIndex("local_port")>(conn_tracker.local_endpoint().port());
17231723
r.Append<r.ColIndex("trace_role")>(conn_tracker.role());
1724-
r.Append<r.ColIndex("content_type")>(static_cast<uint64_t>(req_message.content_type));
1725-
r.Append<r.ColIndex("version")>(static_cast<uint64_t>(req_message.legacy_version));
1726-
r.Append<r.ColIndex("req_body")>(req_message.req_body, kMaxTLSBodyBytes);
1727-
r.Append<r.ColIndex("resp_body")>(resp_message.resp_body, kMaxTLSBodyBytes);
1724+
r.Append<r.ColIndex("req_type")>(static_cast<uint64_t>(req_message.content_type));
1725+
r.Append<r.ColIndex("req_body")>(req_message.body, kMaxTLSBodyBytes);
1726+
r.Append<r.ColIndex("resp_body")>(resp_message.body, kMaxTLSBodyBytes);
17281727
r.Append<r.ColIndex("latency")>(
17291728
CalculateLatency(req_message.timestamp_ns, resp_message.timestamp_ns));
17301729
#ifndef NDEBUG

src/stirling/source_connectors/socket_tracer/testing/protocol_checkers.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ std::vector<tls::Record> ToRecordVector(const types::ColumnWrapperRecordBatch& r
113113
std::vector<tls::Record> result;
114114

115115
for (const auto& idx : indices) {
116-
auto version = rb[kTLSVersionIdx]->Get<types::Int64Value>(idx);
117116
tls::Record r;
118-
r.req.legacy_version = static_cast<tls::LegacyVersion>(version.val);
117+
r.req.body = rb[kTLSReqBodyIdx]->Get<types::StringValue>(idx);
118+
r.resp.body = rb[kTLSRespBodyIdx]->Get<types::StringValue>(idx);
119119
result.push_back(r);
120120
}
121121
return result;

src/stirling/source_connectors/socket_tracer/tls_table.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,15 @@ static constexpr DataElement kTLSElements[] = {
3737
canonical_data_elements::kLocalAddr,
3838
canonical_data_elements::kLocalPort,
3939
canonical_data_elements::kTraceRole,
40-
{"content_type", "The content type of the TLS record (e.g. handshake, alert, heartbeat, etc)",
40+
{"req_type", "The content type of the TLS record (e.g. handshake, alert, heartbeat, etc)",
4141
types::DataType::INT64,
4242
types::SemanticType::ST_NONE,
4343
types::PatternType::GENERAL_ENUM},
44-
{"version", "Version of TLS record",
45-
types::DataType::INT64,
46-
types::SemanticType::ST_NONE,
47-
types::PatternType::GENERAL_ENUM},
48-
{"req_body", "Request body in JSON format. Structure depends on content type (e.g. handshakes contain TLS extensions)",
44+
{"req_body", "Request body in JSON format. Structure depends on content type (e.g. handshakes contain TLS extensions, version negotiated, etc.)",
4945
types::DataType::STRING,
5046
types::SemanticType::ST_NONE,
5147
types::PatternType::STRUCTURED},
52-
{"resp_body", "Response body in JSON format. Structure depends on content type (e.g. handshakes contain TLS extensions)",
48+
{"resp_body", "Response body in JSON format. Structure depends on content type (e.g. handshakes contain TLS extensions, version negotiated, etc.)",
5349
types::DataType::STRING,
5450
types::SemanticType::ST_NONE,
5551
types::PatternType::STRUCTURED},
@@ -65,9 +61,9 @@ static constexpr auto kTLSTable =
6561
DEFINE_PRINT_TABLE(TLS)
6662

6763
constexpr int kTLSUPIDIdx = kTLSTable.ColIndex("upid");
68-
constexpr int kTLSCmdIdx = kTLSTable.ColIndex("content_type");
69-
constexpr int kTLSVersionIdx = kTLSTable.ColIndex("version");
64+
constexpr int kTLSCmdIdx = kTLSTable.ColIndex("req_type");
7065
constexpr int kTLSReqBodyIdx = kTLSTable.ColIndex("req_body");
66+
constexpr int kTLSRespBodyIdx = kTLSTable.ColIndex("resp_body");
7167

7268
} // namespace stirling
7369
} // namespace px

src/stirling/source_connectors/socket_tracer/tls_trace_bpf_test.cc

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,6 @@ using ::testing::SizeIs;
4848
using ::testing::StrEq;
4949
using ::testing::UnorderedElementsAre;
5050

51-
struct TraceRecords {
52-
std::vector<tls::Record> tls_records;
53-
std::vector<std::string> req_body;
54-
};
55-
5651
class NginxOpenSSL_3_0_8_ContainerWrapper
5752
: public ::px::stirling::testing::NginxOpenSSL_3_0_8_Container {
5853
public:
@@ -80,15 +75,6 @@ tls::Record GetExpectedTLSRecord() {
8075
return expected_record;
8176
}
8277

83-
inline std::vector<std::string> GetRequestBody(const types::ColumnWrapperRecordBatch& rb,
84-
const std::vector<size_t>& indices) {
85-
std::vector<std::string> exts;
86-
for (size_t idx : indices) {
87-
exts.push_back(rb[kTLSReqBodyIdx]->Get<types::StringValue>(idx));
88-
}
89-
return exts;
90-
}
91-
9278
class TLSVersionParameterizedTest
9379
: public SocketTraceBPFTestFixture</* TClientSideTracing */ false>,
9480
public ::testing::WithParamInterface<std::string> {
@@ -125,15 +111,15 @@ class TLSVersionParameterizedTest
125111
client.Wait();
126112
this->StopTransferDataThread();
127113

128-
TraceRecords records = this->GetTraceRecords(this->server_.PID());
129-
EXPECT_THAT(records.tls_records, SizeIs(1));
130-
EXPECT_THAT(records.req_body, SizeIs(1));
114+
auto records = this->GetTraceRecords(this->server_.PID());
115+
EXPECT_THAT(records, SizeIs(1));
116+
EXPECT_GT(records[0].req.body.size(), 0);
131117
auto sni_str = R"({"extensions":{"server_name":["test-host"]}})";
132-
EXPECT_THAT(records.req_body[0], StrEq(sni_str));
118+
EXPECT_THAT(records[0].req.body, StrEq(sni_str));
133119
}
134120

135121
// Returns the trace records of the process specified by the input pid.
136-
TraceRecords GetTraceRecords(int pid) {
122+
std::vector<tls::Record> GetTraceRecords(int pid) {
137123
std::vector<TaggedRecordBatch> tablets =
138124
this->ConsumeRecords(SocketTraceConnector::kTLSTableNum);
139125
if (tablets.empty()) {
@@ -142,11 +128,7 @@ class TLSVersionParameterizedTest
142128
types::ColumnWrapperRecordBatch record_batch = tablets[0].records;
143129
std::vector<size_t> server_record_indices =
144130
FindRecordIdxMatchesPID(record_batch, kTLSUPIDIdx, pid);
145-
std::vector<tls::Record> tls_records =
146-
ToRecordVector<tls::Record>(record_batch, server_record_indices);
147-
std::vector<std::string> extensions = GetRequestBody(record_batch, server_record_indices);
148-
149-
return {std::move(tls_records), std::move(extensions)};
131+
return ToRecordVector<tls::Record>(record_batch, server_record_indices);
150132
}
151133

152134
NginxOpenSSL_3_0_8_ContainerWrapper server_;

0 commit comments

Comments
 (0)