// required data message = 3;
uint32 last_type_id = 0;
+ enum class State { kNoTag, kHasType, kHasPayload, kDone };
+ State state = State::kNoTag;
// If we see message data before the type_id, we'll append it to this so
// we can parse it later.
case WireFormatLite::kMessageSetTypeIdTag: {
uint32 type_id;
if (!input->ReadVarint32(&type_id)) return false;
- last_type_id = type_id;
- if (!message_data.empty()) {
+ if (state == State::kNoTag) {
+ last_type_id = type_id;
+ state = State::kHasType;
+ } else if (state == State::kHasPayload) {
+ last_type_id = type_id;
// We saw some message data before the type_id. Have to parse it
// now.
io::CodedInputStream sub_input(
return false;
}
message_data.clear();
+ state = State::kDone;
}
break;
}
case WireFormatLite::kMessageSetMessageTag: {
- if (last_type_id == 0) {
+ if (state != State::kHasType) {
// We haven't seen a type_id yet. Append this data to message_data.
string temp;
uint32 length;
if (!input->ReadVarint32(&length)) return false;
if (!input->ReadString(&temp, length)) return false;
- io::StringOutputStream output_stream(&message_data);
- io::CodedOutputStream coded_output(&output_stream);
- coded_output.WriteVarint32(length);
- coded_output.WriteString(temp);
+ if (state == State::kNoTag) {
+ io::StringOutputStream output_stream(&message_data);
+ io::CodedOutputStream coded_output(&output_stream);
+ coded_output.WriteVarint32(length);
+ coded_output.WriteString(temp);
+ state = State::kHasPayload;
+ }
} else {
// Already saw type_id, so we can parse this directly.
if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
extension_finder, field_skipper)) {
return false;
}
+ state = State::kDone;
}
break;
// required data message = 3;
uint32 last_type_id = 0;
-
- // Once we see a type_id, we'll look up the FieldDescriptor for the
- // extension.
- const FieldDescriptor* field = NULL;
+ enum class State { kNoTag, kHasType, kHasPayload, kDone };
+ State state = State::kNoTag;
// If we see message data before the type_id, we'll append it to this so
// we can parse it later.
case WireFormatLite::kMessageSetTypeIdTag: {
uint32 type_id;
if (!input->ReadVarint32(&type_id)) return false;
- last_type_id = type_id;
- field = message_reflection->FindKnownExtensionByNumber(type_id);
- if (!message_data.empty()) {
+ if (state == State::kNoTag) {
+ last_type_id = type_id;
+ state = State::kHasType;
+ } else if (state == State::kHasPayload) {
// We saw some message data before the type_id. Have to parse it
// now.
+ last_type_id = type_id;
+ const FieldDescriptor* field = message_reflection->FindKnownExtensionByNumber(type_id);
io::ArrayInputStream raw_input(message_data.data(),
message_data.size());
io::CodedInputStream sub_input(&raw_input);
return false;
}
message_data.clear();
+ state = State::kDone;
}
break;
}
case WireFormatLite::kMessageSetMessageTag: {
- if (last_type_id == 0) {
+ if (state == State::kNoTag) {
// We haven't seen a type_id yet. Append this data to message_data.
string temp;
uint32 length;
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteVarint32(length);
coded_output.WriteString(temp);
- } else {
+ state = State::kHasPayload;
+ } else if (state == State::kHasType) {
// Already saw type_id, so we can parse this directly.
+ const FieldDescriptor* field = message_reflection->FindKnownExtensionByNumber(last_type_id);
if (!ParseAndMergeMessageSetField(last_type_id, field, message,
input)) {
return false;
}
+ state = State::kDone;
+ } else {
+ if (!SkipField(input, tag, NULL)) return false;
}
break;
#include <google/protobuf/wire_format.h>
#include <google/protobuf/wire_format_lite_inl.h>
#include <google/protobuf/descriptor.h>
+#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/unittest.pb.h>
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
}
-TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
+namespace {
+std::string BuildMessageSetItemStart() {
string data;
{
- unittest::TestMessageSetExtension1 message;
- message.set_i(123);
- // Build a MessageSet manually with its message content put before its
- // type_id.
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
+ }
+ return data;
+}
+std::string BuildMessageSetItemEnd() {
+ std::string data;
+ {
+ io::StringOutputStream output_stream(&data);
+ io::CodedOutputStream coded_output(&output_stream);
+ coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
+ }
+ return data;
+}
+std::string BuildMessageSetTestExtension1(int value = 123) {
+ std::string data;
+ {
+ unittest::TestMessageSetExtension1 message;
+ message.set_i(value);
+ io::StringOutputStream output_stream(&data);
+ io::CodedOutputStream coded_output(&output_stream);
// Write the message content first.
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&coded_output);
coded_output.WriteVarint32(message.ByteSize());
message.SerializeWithCachedSizes(&coded_output);
- // Write the type id.
- uint32 type_id = message.GetDescriptor()->extension(0)->number();
+ }
+ return data;
+}
+std::string BuildMessageSetItemTypeId(int extension_number) {
+ std::string data;
+ {
+ io::StringOutputStream output_stream(&data);
+ io::CodedOutputStream coded_output(&output_stream);
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
- type_id, &coded_output);
- coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
+ extension_number, &coded_output);
}
+ return data;
+}
+void ValidateTestMessageSet(const std::string& test_case,
+ const std::string& data) {
+ SCOPED_TRACE(test_case);
{
proto2_wireformat_unittest::TestMessageSet message_set;
ASSERT_TRUE(message_set.ParseFromString(data));
EXPECT_EQ(123, message_set.GetExtension(
unittest::TestMessageSetExtension1::message_set_extension).i());
+
+ // Make sure it does not contain anything else.
+ message_set.ClearExtension(
+ unittest::TestMessageSetExtension1::message_set_extension);
+ EXPECT_EQ(message_set.SerializeAsString(), "");
}
{
// Test parse the message via Reflection.
EXPECT_EQ(123, message_set.GetExtension(
unittest::TestMessageSetExtension1::message_set_extension).i());
}
+ {
+ // Test parse the message via DynamicMessage.
+ DynamicMessageFactory factory;
+ std::unique_ptr<Message> msg(
+ factory
+ .GetPrototype(
+ proto2_wireformat_unittest::TestMessageSet::descriptor())
+ ->New());
+ msg->ParseFromString(data);
+ auto* reflection = msg->GetReflection();
+ std::vector<const FieldDescriptor*> fields;
+ reflection->ListFields(*msg, &fields);
+ ASSERT_EQ(fields.size(), 1);
+ const auto& sub = reflection->GetMessage(*msg, fields[0]);
+ reflection = sub.GetReflection();
+ EXPECT_EQ(123, reflection->GetInt32(
+ sub, sub.GetDescriptor()->FindFieldByName("i")));
+ }
+}
+} // namespace
+
+TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
+ std::string start = BuildMessageSetItemStart();
+ std::string end = BuildMessageSetItemEnd();
+ std::string id = BuildMessageSetItemTypeId(
+ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
+ std::string message = BuildMessageSetTestExtension1();
+
+ ValidateTestMessageSet("id + message", start + id + message + end);
+ ValidateTestMessageSet("message + id", start + message + id + end);
+}
+
+TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
+ std::string start = BuildMessageSetItemStart();
+ std::string end = BuildMessageSetItemEnd();
+ std::string id = BuildMessageSetItemTypeId(
+ unittest::TestMessageSetExtension1::descriptor()->extension(0)->number());
+ std::string other_id = BuildMessageSetItemTypeId(123456);
+ std::string message = BuildMessageSetTestExtension1();
+ std::string other_message = BuildMessageSetTestExtension1(321);
+
+ // Double id
+ ValidateTestMessageSet("id + other_id + message",
+ start + id + other_id + message + end);
+ ValidateTestMessageSet("id + message + other_id",
+ start + id + message + other_id + end);
+ ValidateTestMessageSet("message + id + other_id",
+ start + message + id + other_id + end);
+ // Double message
+ ValidateTestMessageSet("id + message + other_message",
+ start + id + message + other_message + end);
+ ValidateTestMessageSet("message + id + other_message",
+ start + message + id + other_message + end);
+ ValidateTestMessageSet("message + other_message + id",
+ start + message + other_message + id + end);
}
TEST(WireFormatTest, ParseBrokenMessageSet) {