-
Notifications
You must be signed in to change notification settings - Fork 723
[Store] Add ObjectDataType enum for type-aware metadata #1719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
55c7893
bfbdef2
e3eb2f1
96bad43
e68abff
515b36e
c632b5c
011f3ce
0a85afb
dd64ff2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -123,6 +123,44 @@ static constexpr uint64_t DEFAULT_PROCESSING_TASK_TIMEOUT_SEC = | |
| 300; // 0 to be no timeout | ||
| static constexpr uint32_t DEFAULT_MAX_RETRY_ATTEMPTS = 10; | ||
|
|
||
/**
 * @brief Data type classification for objects stored in Mooncake Store.
 *
 * This allows the store to track what kind of data each object holds,
 * enabling future type-aware policies (eviction priority, replication
 * strategies, etc.). Defaults to UNKNOWN for backward compatibility.
 *
 * The numeric values are persisted: the metadata serializer packs data_type
 * as a uint8_t. Append new types using the reserved range instead of
 * renumbering existing ones.
 */
enum class ObjectDataType : uint8_t {
    UNKNOWN = 0,
    KVCACHE = 1,
    TENSOR = 2,
    WEIGHT = 3,
    SAMPLE = 4,
    ACTIVATION = 5,
    GRADIENT = 6,
    OPTIMIZER_STATE = 7,
    METADATA = 8,
    // 9-255 reserved for future types
};

/**
 * @brief Writes the symbolic name of an ObjectDataType to a stream.
 *
 * Values without a named enumerator (for example anything in the reserved
 * 9-255 range) are printed as "UNKNOWN", the same text used for
 * ObjectDataType::UNKNOWN.
 *
 * @param os   Destination stream.
 * @param type Value to print.
 * @return os, so that insertions can be chained.
 */
inline std::ostream& operator<<(std::ostream& os,
                                const ObjectDataType& type) noexcept {
    // A switch needs no function-local static hash map, so the name lookup
    // allocates nothing. There is deliberately no default label: the
    // compiler's switch-coverage warning can then flag this function when a
    // new enumerator is added without a name.
    std::string_view name = "UNKNOWN";
    switch (type) {
        case ObjectDataType::UNKNOWN:
            break;
        case ObjectDataType::KVCACHE:
            name = "KVCACHE";
            break;
        case ObjectDataType::TENSOR:
            name = "TENSOR";
            break;
        case ObjectDataType::WEIGHT:
            name = "WEIGHT";
            break;
        case ObjectDataType::SAMPLE:
            name = "SAMPLE";
            break;
        case ObjectDataType::ACTIVATION:
            name = "ACTIVATION";
            break;
        case ObjectDataType::GRADIENT:
            name = "GRADIENT";
            break;
        case ObjectDataType::OPTIMIZER_STATE:
            name = "OPTIMIZER_STATE";
            break;
        case ObjectDataType::METADATA:
            name = "METADATA";
            break;
    }
    os << name;
    return os;
}
|
Comment on lines
+148
to
+165
|
||
|
|
||
| // Forward declarations | ||
| class BufferAllocatorBase; | ||
| class CachelibBufferAllocator; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -766,7 +766,7 @@ auto MasterService::PutStart(const UUID& client_id, const std::string& key, | |
| shard->metadata.emplace( | ||
| std::piecewise_construct, std::forward_as_tuple(key), | ||
| std::forward_as_tuple(client_id, now, total_length, std::move(replicas), | ||
| config.with_soft_pin)); | ||
| config.with_soft_pin, config.data_type)); | ||
| // Also insert the metadata into processing set for monitoring. | ||
| shard->processing_keys.insert(key); | ||
|
|
||
|
|
@@ -3485,7 +3485,8 @@ MasterService::MetadataSerializer::DeserializeShard(const msgpack::object& obj, | |
| std::forward_as_tuple( | ||
| metadata_ptr->client_id, metadata_ptr->put_start_time, | ||
| metadata_ptr->size, metadata_ptr->PopReplicas(), | ||
| metadata_ptr->soft_pin_timeout.has_value())); | ||
| metadata_ptr->soft_pin_timeout.has_value(), | ||
| metadata_ptr->data_type)); | ||
|
|
||
| it->second.lease_timeout = metadata_ptr->lease_timeout; | ||
| it->second.soft_pin_timeout = metadata_ptr->soft_pin_timeout; | ||
|
|
@@ -3500,10 +3501,12 @@ MasterService::MetadataSerializer::SerializeMetadata( | |
| MsgpackPacker& packer) const { | ||
| // Pack ObjectMetadata using array structure for efficiency | ||
| // Format: [client_id, put_start_time, size, lease_timeout, | ||
| // has_soft_pin_timeout, soft_pin_timeout, replicas_count, replicas...] | ||
| // has_soft_pin_timeout, soft_pin_timeout, replicas_count, data_type, | ||
| // replicas...] | ||
|
|
||
| size_t array_size = 7; // size, lease_timeout, has_soft_pin_timeout, | ||
| // soft_pin_timeout, replicas_count | ||
| size_t array_size = 8; // client_id, put_start_time, size, lease_timeout, | ||
| // has_soft_pin_timeout, soft_pin_timeout, | ||
| // replicas_count, data_type | ||
| array_size += metadata.CountReplicas(); // One element per replica | ||
| packer.pack_array(array_size); | ||
|
|
||
|
|
@@ -3543,6 +3546,9 @@ MasterService::MetadataSerializer::SerializeMetadata( | |
| // Serialize replicas count | ||
| packer.pack(static_cast<uint32_t>(metadata.CountReplicas())); | ||
|
|
||
| // Serialize data_type | ||
| packer.pack(static_cast<uint8_t>(metadata.data_type)); | ||
|
|
||
|
Comment on lines
4274
to
+4323
|
||
| // Serialize replicas | ||
| for (const auto& replica : metadata.GetAllReplicas()) { | ||
| auto result = Serializer<Replica>::serialize( | ||
|
|
@@ -3565,8 +3571,7 @@ MasterService::MetadataSerializer::DeserializeMetadata( | |
| "deserialize ObjectMetadata state is not an array")); | ||
| } | ||
|
|
||
| // Need at least 7 elements: client_id, put_start_time, size, lease_timeout, | ||
| // has_soft_pin_timeout, soft_pin_timeout, replicas_count | ||
| // Need at least 7 elements for old format (without data_type), 8 for new | ||
| if (obj.via.array.size < 7) { | ||
| return tl::unexpected(SerializationError( | ||
| ErrorCode::DESERIALIZE_FAIL, | ||
|
|
@@ -3599,8 +3604,14 @@ MasterService::MetadataSerializer::DeserializeMetadata( | |
| // Deserialize replicas count | ||
| uint32_t replicas_count = array[index++].as<uint32_t>(); | ||
|
|
||
| // Check if array size matches replicas_count | ||
| if (obj.via.array.size != 7 + replicas_count) { | ||
| // Detect old vs new format based on total array size. | ||
| // Old format: 7 + replicas_count elements | ||
| // New format: 8 + replicas_count elements (with data_type after | ||
| // replicas_count) | ||
| ObjectDataType data_type = ObjectDataType::UNKNOWN; | ||
| if (obj.via.array.size == 8 + replicas_count) { | ||
| data_type = static_cast<ObjectDataType>(array[index++].as<uint8_t>()); | ||
| } else if (obj.via.array.size != 7 + replicas_count) { | ||
| return tl::unexpected(SerializationError( | ||
| ErrorCode::DESERIALIZE_FAIL, | ||
| "deserialize ObjectMetadata array size mismatch")); | ||
|
|
@@ -3625,7 +3636,7 @@ MasterService::MetadataSerializer::DeserializeMetadata( | |
| client_id, | ||
| std::chrono::system_clock::time_point( | ||
| std::chrono::milliseconds(put_start_time_timestamp)), | ||
| size, std::move(replicas), enable_soft_pin); | ||
| size, std::move(replicas), enable_soft_pin, data_type); | ||
| metadata->lease_timeout = std::chrono::system_clock::time_point( | ||
| std::chrono::milliseconds(lease_timestamp)); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| #include "types.h" | ||
| #include "replica.h" | ||
| #include "master_service.h" | ||
|
|
||
| #include <glog/logging.h> | ||
| #include <gtest/gtest.h> | ||
|
|
||
| #include <sstream> | ||
| #include <vector> | ||
|
|
||
| namespace mooncake::test { | ||
|
|
||
| class ObjectDataTypeTest : public ::testing::Test { | ||
| protected: | ||
| void SetUp() override { | ||
| google::InitGoogleLogging("ObjectDataTypeTest"); | ||
| FLAGS_logtostderr = true; | ||
| } | ||
|
|
||
| void TearDown() override { google::ShutdownGoogleLogging(); } | ||
|
|
||
| static constexpr size_t kDefaultSegmentBase = 0x300000000; | ||
| static constexpr size_t kDefaultSegmentSize = 1024 * 1024 * 16; | ||
|
|
||
| Segment MakeSegment(std::string name = "test_segment", | ||
| size_t base = kDefaultSegmentBase, | ||
| size_t size = kDefaultSegmentSize) const { | ||
| Segment segment; | ||
| segment.id = generate_uuid(); | ||
| segment.name = std::move(name); | ||
| segment.base = base; | ||
| segment.size = size; | ||
| segment.te_endpoint = segment.name; | ||
| return segment; | ||
| } | ||
| }; | ||
|
|
||
| // Verify enum values match the RFC spec | ||
| TEST_F(ObjectDataTypeTest, EnumValues) { | ||
| EXPECT_EQ(static_cast<uint8_t>(ObjectDataType::UNKNOWN), 0); | ||
| EXPECT_EQ(static_cast<uint8_t>(ObjectDataType::KVCACHE), 1); | ||
| EXPECT_EQ(static_cast<uint8_t>(ObjectDataType::TENSOR), 2); | ||
| EXPECT_EQ(static_cast<uint8_t>(ObjectDataType::WEIGHT), 3); | ||
| EXPECT_EQ(static_cast<uint8_t>(ObjectDataType::SAMPLE), 4); | ||
| EXPECT_EQ(static_cast<uint8_t>(ObjectDataType::ACTIVATION), 5); | ||
| EXPECT_EQ(static_cast<uint8_t>(ObjectDataType::GRADIENT), 6); | ||
| EXPECT_EQ(static_cast<uint8_t>(ObjectDataType::OPTIMIZER_STATE), 7); | ||
| EXPECT_EQ(static_cast<uint8_t>(ObjectDataType::METADATA), 8); | ||
| } | ||
|
|
||
| // Verify stream operator produces readable output | ||
| TEST_F(ObjectDataTypeTest, StreamOperator) { | ||
| std::ostringstream oss; | ||
| oss << ObjectDataType::KVCACHE; | ||
| EXPECT_EQ(oss.str(), "KVCACHE"); | ||
|
|
||
| oss.str(""); | ||
| oss << ObjectDataType::UNKNOWN; | ||
| EXPECT_EQ(oss.str(), "UNKNOWN"); | ||
|
|
||
| oss.str(""); | ||
| oss << ObjectDataType::OPTIMIZER_STATE; | ||
| EXPECT_EQ(oss.str(), "OPTIMIZER_STATE"); | ||
|
|
||
| // Out-of-range value should print "UNKNOWN" | ||
| oss.str(""); | ||
| oss << static_cast<ObjectDataType>(200); | ||
| EXPECT_EQ(oss.str(), "UNKNOWN"); | ||
| } | ||
|
|
||
| // ReplicateConfig defaults to UNKNOWN | ||
| TEST_F(ObjectDataTypeTest, ReplicateConfigDefaultDataType) { | ||
| ReplicateConfig config; | ||
| EXPECT_EQ(config.data_type, ObjectDataType::UNKNOWN); | ||
| } | ||
|
|
||
| // ReplicateConfig can be set to other types | ||
| TEST_F(ObjectDataTypeTest, ReplicateConfigSetDataType) { | ||
| ReplicateConfig config; | ||
| config.data_type = ObjectDataType::WEIGHT; | ||
| EXPECT_EQ(config.data_type, ObjectDataType::WEIGHT); | ||
| } | ||
|
|
||
| // ReplicateConfig stream output includes data_type | ||
| TEST_F(ObjectDataTypeTest, ReplicateConfigStreamIncludesDataType) { | ||
| ReplicateConfig config; | ||
| config.data_type = ObjectDataType::TENSOR; | ||
| std::ostringstream oss; | ||
| oss << config; | ||
| EXPECT_NE(oss.str().find("data_type: TENSOR"), std::string::npos); | ||
| } | ||
|
|
||
| // PutStart accepts a non-default data_type (WEIGHT) and the put completes; the stored ObjectMetadata is not inspected here | ||
| TEST_F(ObjectDataTypeTest, PutStartWithDataType) { | ||
| std::unique_ptr<MasterService> service(new MasterService()); | ||
| Segment segment = MakeSegment(); | ||
| UUID client_id = generate_uuid(); | ||
| auto mount_result = service->MountSegment(segment, client_id); | ||
| ASSERT_TRUE(mount_result.has_value()); | ||
|
|
||
| UUID put_client = generate_uuid(); | ||
|
|
||
| // Put with WEIGHT type | ||
| ReplicateConfig config; | ||
| config.replica_num = 1; | ||
| config.data_type = ObjectDataType::WEIGHT; | ||
|
|
||
| auto result = service->PutStart(put_client, "key_weight", 1024, config); | ||
| ASSERT_TRUE(result.has_value()); | ||
| EXPECT_FALSE(result.value().empty()); | ||
|
|
||
| auto end_result = | ||
| service->PutEnd(put_client, "key_weight", ReplicaType::MEMORY); | ||
| EXPECT_TRUE(end_result.has_value()); | ||
| } | ||
|
Comment on lines
+98
to
+120
|
||
|
|
||
| // PutStart with default UNKNOWN data_type still works (backward compat) | ||
| TEST_F(ObjectDataTypeTest, PutStartDefaultDataType) { | ||
| std::unique_ptr<MasterService> service(new MasterService()); | ||
| Segment segment = MakeSegment(); | ||
| UUID client_id = generate_uuid(); | ||
| auto mount_result = service->MountSegment(segment, client_id); | ||
| ASSERT_TRUE(mount_result.has_value()); | ||
|
|
||
| UUID put_client = generate_uuid(); | ||
| ReplicateConfig config; | ||
| config.replica_num = 1; | ||
| // data_type left as default (UNKNOWN) | ||
|
|
||
| auto result = service->PutStart(put_client, "key_default", 1024, config); | ||
| ASSERT_TRUE(result.has_value()); | ||
| EXPECT_FALSE(result.value().empty()); | ||
| } | ||
|
|
||
| // Verify all enum values can roundtrip through uint8_t cast | ||
| TEST_F(ObjectDataTypeTest, EnumRoundtrip) { | ||
| std::vector<ObjectDataType> all_types = { | ||
| ObjectDataType::UNKNOWN, ObjectDataType::KVCACHE, | ||
| ObjectDataType::TENSOR, ObjectDataType::WEIGHT, | ||
| ObjectDataType::SAMPLE, ObjectDataType::ACTIVATION, | ||
| ObjectDataType::GRADIENT, ObjectDataType::OPTIMIZER_STATE, | ||
| ObjectDataType::METADATA, | ||
| }; | ||
|
|
||
| for (auto type : all_types) { | ||
| uint8_t raw = static_cast<uint8_t>(type); | ||
| auto recovered = static_cast<ObjectDataType>(raw); | ||
| EXPECT_EQ(type, recovered); | ||
| } | ||
| } | ||
|
|
||
| } // namespace mooncake::test | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should add a `general` type?