Skip to content
Open
14 changes: 14 additions & 0 deletions mooncake-integration/store/store_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,19 @@ class MooncakeHostMemAllocatorPyWrapper {
};

PYBIND11_MODULE(store, m) {
// Object data type classification
py::enum_<ObjectDataType>(m, "ObjectDataType")
.value("UNKNOWN", ObjectDataType::UNKNOWN)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should add a general type?

.value("KVCACHE", ObjectDataType::KVCACHE)
.value("TENSOR", ObjectDataType::TENSOR)
.value("WEIGHT", ObjectDataType::WEIGHT)
.value("SAMPLE", ObjectDataType::SAMPLE)
.value("ACTIVATION", ObjectDataType::ACTIVATION)
.value("GRADIENT", ObjectDataType::GRADIENT)
.value("OPTIMIZER_STATE", ObjectDataType::OPTIMIZER_STATE)
.value("METADATA", ObjectDataType::METADATA)
.export_values();

// Define the ReplicateConfig class
py::class_<ReplicateConfig>(m, "ReplicateConfig")
.def(py::init<>())
Expand All @@ -1102,6 +1115,7 @@ PYBIND11_MODULE(store, m) {
.def_readwrite("preferred_segment", &ReplicateConfig::preferred_segment)
.def_readwrite("prefer_alloc_in_same_node",
&ReplicateConfig::prefer_alloc_in_same_node)
.def_readwrite("data_type", &ReplicateConfig::data_type)
.def("__str__", [](const ReplicateConfig &config) {
std::ostringstream oss;
oss << config;
Expand Down
11 changes: 8 additions & 3 deletions mooncake-store/include/master_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,10 +484,12 @@ class MasterService {
const UUID& client_id_,
const std::chrono::system_clock::time_point put_start_time_,
size_t value_length, std::vector<Replica>&& reps,
bool enable_soft_pin)
bool enable_soft_pin,
ObjectDataType data_type_ = ObjectDataType::UNKNOWN)
: client_id(client_id_),
put_start_time(put_start_time_),
size(value_length),
data_type(data_type_),
lease_timeout(),
soft_pin_timeout(std::nullopt),
replicas_(std::move(reps)) {
Expand All @@ -507,6 +509,7 @@ class MasterService {
const UUID client_id;
const std::chrono::system_clock::time_point put_start_time;
const size_t size;
const ObjectDataType data_type{ObjectDataType::UNKNOWN};

mutable SpinLock lock;
// Default constructor, creates a time_point representing
Expand Down Expand Up @@ -897,15 +900,17 @@ class MasterService {
}

// Inserts a new metadata entry for key_ into the guarded shard and points
// it_ at the inserted element.
//
// client_id, total_length, replicas and enable_soft_pin are forwarded to the
// metadata constructor; replicas is taken by value and moved into the entry.
// data_type defaults to ObjectDataType::UNKNOWN, so callers that pass only
// the first four arguments are unaffected.
//
// Throws std::logic_error("Already exists") when Exists() reports true.
void Create(const UUID& client_id, uint64_t total_length,
std::vector<Replica> replicas, bool enable_soft_pin) {
std::vector<Replica> replicas, bool enable_soft_pin,
ObjectDataType data_type = ObjectDataType::UNKNOWN) {
if (Exists()) {
throw std::logic_error("Already exists");
}
// The current system_clock time is passed as the entry's timestamp argument.
const auto now = std::chrono::system_clock::now();
// piecewise_construct builds the key and the metadata value in place from
// the two argument tuples.
auto result = shard_guard_->metadata.emplace(
std::piecewise_construct, std::forward_as_tuple(key_),
std::forward_as_tuple(client_id, now, total_length,
std::move(replicas), enable_soft_pin));
std::move(replicas), enable_soft_pin,
data_type));
// emplace returns {iterator, inserted}; keep the iterator.
it_ = result.first;
}

Expand Down
4 changes: 3 additions & 1 deletion mooncake-store/include/replica.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ struct ReplicateConfig {
std::string preferred_segment{}; // Deprecated: Single preferred segment
// for backward compatibility
bool prefer_alloc_in_same_node{false};
ObjectDataType data_type{ObjectDataType::UNKNOWN};

friend std::ostream& operator<<(std::ostream& os,
const ReplicateConfig& config) noexcept {
Expand All @@ -105,7 +106,8 @@ struct ReplicateConfig {
<< config.preferred_segment;
}
os << ", prefer_alloc_in_same_node: "
<< config.prefer_alloc_in_same_node << " }";
<< config.prefer_alloc_in_same_node
<< ", data_type: " << config.data_type << " }";
return os;
}
};
Expand Down
38 changes: 38 additions & 0 deletions mooncake-store/include/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,44 @@ static constexpr uint64_t DEFAULT_PROCESSING_TASK_TIMEOUT_SEC =
300; // 0 to be no timeout
static constexpr uint32_t DEFAULT_MAX_RETRY_ATTEMPTS = 10;

/**
 * @brief Data type classification for objects stored in Mooncake Store.
 *
 * This allows the store to track what kind of data each object holds,
 * enabling future type-aware policies (eviction priority, replication
 * strategies, etc.). Defaults to UNKNOWN for backward compatibility.
 *
 * The underlying value is what the metadata serializer writes (it packs the
 * enum as a uint8_t), so the numeric value of an existing enumerator must
 * never change; new kinds are appended using the reserved range.
 */
enum class ObjectDataType : uint8_t {
    UNKNOWN = 0,
    KVCACHE = 1,
    TENSOR = 2,
    WEIGHT = 3,
    SAMPLE = 4,
    ACTIVATION = 5,
    GRADIENT = 6,
    OPTIMIZER_STATE = 7,
    METADATA = 8,
    // 9-255 reserved for future types
};

/**
 * @brief Streams the enumerator name of @p type (e.g. "KVCACHE").
 *
 * Values that do not correspond to a declared enumerator are printed as
 * "UNKNOWN", the same text used for ObjectDataType::UNKNOWN.
 *
 * A switch is used instead of a lookup table so that no allocation or
 * static initialisation happens inside this noexcept function.
 *
 * @param os   Destination stream.
 * @param type Value to print.
 * @return @p os, to allow chaining.
 */
inline std::ostream& operator<<(std::ostream& os,
                                const ObjectDataType& type) noexcept {
    // Fallback text for UNKNOWN and for any out-of-range value.
    const char* name = "UNKNOWN";
    // Deliberately no `default:` label: -Wswitch then flags any enumerator
    // added later that is missing a name here.
    switch (type) {
        case ObjectDataType::UNKNOWN:
            break;
        case ObjectDataType::KVCACHE:
            name = "KVCACHE";
            break;
        case ObjectDataType::TENSOR:
            name = "TENSOR";
            break;
        case ObjectDataType::WEIGHT:
            name = "WEIGHT";
            break;
        case ObjectDataType::SAMPLE:
            name = "SAMPLE";
            break;
        case ObjectDataType::ACTIVATION:
            name = "ACTIVATION";
            break;
        case ObjectDataType::GRADIENT:
            name = "GRADIENT";
            break;
        case ObjectDataType::OPTIMIZER_STATE:
            name = "OPTIMIZER_STATE";
            break;
        case ObjectDataType::METADATA:
            name = "METADATA";
            break;
    }
    os << name;
    return os;
}
Comment on lines +148 to +165
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

types.h now uses std::string_view in the ObjectDataType stream operator, but the header does not include <string_view> (and also relies on indirect includes for std::ostream). This makes the header non-self-contained and can cause build breaks for translation units that include types.h directly. Add the missing standard includes in types.h.

Copilot uses AI. Check for mistakes.

// Forward declarations
class BufferAllocatorBase;
class CachelibBufferAllocator;
Expand Down
31 changes: 21 additions & 10 deletions mooncake-store/src/master_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ auto MasterService::PutStart(const UUID& client_id, const std::string& key,
shard->metadata.emplace(
std::piecewise_construct, std::forward_as_tuple(key),
std::forward_as_tuple(client_id, now, total_length, std::move(replicas),
config.with_soft_pin));
config.with_soft_pin, config.data_type));
// Also insert the metadata into processing set for monitoring.
shard->processing_keys.insert(key);

Expand Down Expand Up @@ -3485,7 +3485,8 @@ MasterService::MetadataSerializer::DeserializeShard(const msgpack::object& obj,
std::forward_as_tuple(
metadata_ptr->client_id, metadata_ptr->put_start_time,
metadata_ptr->size, metadata_ptr->PopReplicas(),
metadata_ptr->soft_pin_timeout.has_value()));
metadata_ptr->soft_pin_timeout.has_value(),
metadata_ptr->data_type));

it->second.lease_timeout = metadata_ptr->lease_timeout;
it->second.soft_pin_timeout = metadata_ptr->soft_pin_timeout;
Expand All @@ -3500,10 +3501,12 @@ MasterService::MetadataSerializer::SerializeMetadata(
MsgpackPacker& packer) const {
// Pack ObjectMetadata using array structure for efficiency
// Format: [client_id, put_start_time, size, lease_timeout,
// has_soft_pin_timeout, soft_pin_timeout, replicas_count, replicas...]
// has_soft_pin_timeout, soft_pin_timeout, replicas_count, data_type,
// replicas...]

size_t array_size = 7; // size, lease_timeout, has_soft_pin_timeout,
// soft_pin_timeout, replicas_count
size_t array_size = 8; // client_id, put_start_time, size, lease_timeout,
// has_soft_pin_timeout, soft_pin_timeout,
// replicas_count, data_type
array_size += metadata.CountReplicas(); // One element per replica
packer.pack_array(array_size);

Expand Down Expand Up @@ -3543,6 +3546,9 @@ MasterService::MetadataSerializer::SerializeMetadata(
// Serialize replicas count
packer.pack(static_cast<uint32_t>(metadata.CountReplicas()));

// Serialize data_type
packer.pack(static_cast<uint8_t>(metadata.data_type));

Comment on lines 4274 to +4323
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Snapshot serialization now always writes the new array shape (8 + replicas_count) by inserting data_type. While the new deserializer can read both formats, this is only backward-compatible in the "new code reads old snapshots" direction. If any mixed-version deployment expects older binaries to load snapshots produced by newer binaries, they will likely fail on the extra field. Consider adding an explicit snapshot format/version marker or a compatibility mode to write the old format during rolling upgrades, and/or clarify the upgrade guarantees.

Copilot uses AI. Check for mistakes.
// Serialize replicas
for (const auto& replica : metadata.GetAllReplicas()) {
auto result = Serializer<Replica>::serialize(
Expand All @@ -3565,8 +3571,7 @@ MasterService::MetadataSerializer::DeserializeMetadata(
"deserialize ObjectMetadata state is not an array"));
}

// Need at least 7 elements: client_id, put_start_time, size, lease_timeout,
// has_soft_pin_timeout, soft_pin_timeout, replicas_count
// Need at least 7 elements for old format (without data_type), 8 for new
if (obj.via.array.size < 7) {
return tl::unexpected(SerializationError(
ErrorCode::DESERIALIZE_FAIL,
Expand Down Expand Up @@ -3599,8 +3604,14 @@ MasterService::MetadataSerializer::DeserializeMetadata(
// Deserialize replicas count
uint32_t replicas_count = array[index++].as<uint32_t>();

// Check if array size matches replicas_count
if (obj.via.array.size != 7 + replicas_count) {
// Detect old vs new format based on total array size.
// Old format: 7 + replicas_count elements
// New format: 8 + replicas_count elements (with data_type after
// replicas_count)
ObjectDataType data_type = ObjectDataType::UNKNOWN;
if (obj.via.array.size == 8 + replicas_count) {
data_type = static_cast<ObjectDataType>(array[index++].as<uint8_t>());
} else if (obj.via.array.size != 7 + replicas_count) {
return tl::unexpected(SerializationError(
ErrorCode::DESERIALIZE_FAIL,
"deserialize ObjectMetadata array size mismatch"));
Expand All @@ -3625,7 +3636,7 @@ MasterService::MetadataSerializer::DeserializeMetadata(
client_id,
std::chrono::system_clock::time_point(
std::chrono::milliseconds(put_start_time_timestamp)),
size, std::move(replicas), enable_soft_pin);
size, std::move(replicas), enable_soft_pin, data_type);
metadata->lease_timeout = std::chrono::system_clock::time_point(
std::chrono::milliseconds(lease_timestamp));

Expand Down
1 change: 1 addition & 0 deletions mooncake-store/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ add_store_test(task_executor_test task_executor_test.cpp)
add_store_test(task_integration_test task_integration_test.cpp)
add_store_test(dummy_client_get_buffer_test dummy_client_get_buffer_test.cpp)
add_store_test(health_check_test health_check_test.cpp)
add_store_test(object_data_type_test object_data_type_test.cpp)
add_subdirectory(e2e)

add_executable(high_availability_test high_availability_test.cpp)
Expand Down
152 changes: 152 additions & 0 deletions mooncake-store/tests/object_data_type_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#include "types.h"
#include "replica.h"
#include "master_service.h"

#include <glog/logging.h>
#include <gtest/gtest.h>

#include <sstream>
#include <vector>

namespace mooncake::test {

// Fixture for the ObjectDataType tests. glog is started before and shut down
// after every test, and MakeSegment() builds the Segment values the
// MasterService tests mount.
class ObjectDataTypeTest : public ::testing::Test {
   protected:
    void SetUp() override {
        google::InitGoogleLogging("ObjectDataTypeTest");
        FLAGS_logtostderr = true;
    }

    void TearDown() override { google::ShutdownGoogleLogging(); }

    static constexpr size_t kDefaultSegmentBase = 0x300000000;
    static constexpr size_t kDefaultSegmentSize = 1024 * 1024 * 16;

    // Returns a Segment with a freshly generated id and the given name, base
    // and size; te_endpoint is set to the same string as the name.
    Segment MakeSegment(std::string name = "test_segment",
                        size_t base = kDefaultSegmentBase,
                        size_t size = kDefaultSegmentSize) const {
        Segment seg;
        seg.id = generate_uuid();
        seg.base = base;
        seg.size = size;
        seg.name = std::move(name);
        seg.te_endpoint = seg.name;
        return seg;
    }
};

// Each enumerator must keep the numeric value assigned in the RFC spec.
TEST_F(ObjectDataTypeTest, EnumValues) {
    const auto raw = [](ObjectDataType type) {
        return static_cast<uint8_t>(type);
    };

    EXPECT_EQ(raw(ObjectDataType::UNKNOWN), 0);
    EXPECT_EQ(raw(ObjectDataType::KVCACHE), 1);
    EXPECT_EQ(raw(ObjectDataType::TENSOR), 2);
    EXPECT_EQ(raw(ObjectDataType::WEIGHT), 3);
    EXPECT_EQ(raw(ObjectDataType::SAMPLE), 4);
    EXPECT_EQ(raw(ObjectDataType::ACTIVATION), 5);
    EXPECT_EQ(raw(ObjectDataType::GRADIENT), 6);
    EXPECT_EQ(raw(ObjectDataType::OPTIMIZER_STATE), 7);
    EXPECT_EQ(raw(ObjectDataType::METADATA), 8);
}

// operator<< must print the enumerator name, and "UNKNOWN" for any value
// that is not a declared enumerator.
TEST_F(ObjectDataTypeTest, StreamOperator) {
    // Formats a single value through a fresh stream.
    const auto render = [](ObjectDataType type) {
        std::ostringstream out;
        out << type;
        return out.str();
    };

    EXPECT_EQ(render(ObjectDataType::KVCACHE), "KVCACHE");
    EXPECT_EQ(render(ObjectDataType::UNKNOWN), "UNKNOWN");
    EXPECT_EQ(render(ObjectDataType::OPTIMIZER_STATE), "OPTIMIZER_STATE");

    // Out-of-range value should print "UNKNOWN"
    EXPECT_EQ(render(static_cast<ObjectDataType>(200)), "UNKNOWN");
}

// A default-constructed ReplicateConfig reports data_type == UNKNOWN.
TEST_F(ObjectDataTypeTest, ReplicateConfigDefaultDataType) {
    const ReplicateConfig fresh_config{};
    const ObjectDataType actual = fresh_config.data_type;
    EXPECT_EQ(actual, ObjectDataType::UNKNOWN);
}

// data_type is a plain writable field: a value assigned to it reads back.
TEST_F(ObjectDataTypeTest, ReplicateConfigSetDataType) {
    const ObjectDataType wanted = ObjectDataType::WEIGHT;

    ReplicateConfig cfg;
    cfg.data_type = wanted;

    EXPECT_EQ(cfg.data_type, wanted);
}

// Streaming a ReplicateConfig must mention its data_type by name.
TEST_F(ObjectDataTypeTest, ReplicateConfigStreamIncludesDataType) {
    ReplicateConfig cfg;
    cfg.data_type = ObjectDataType::TENSOR;

    std::ostringstream out;
    out << cfg;
    const std::string rendered = out.str();

    EXPECT_NE(rendered.find("data_type: TENSOR"), std::string::npos);
}

// PutStart/PutEnd succeed when the config carries a non-default data_type.
// NOTE: this test only checks that both calls return a value. It does not
// read the stored ObjectMetadata back, so it does not by itself prove that
// data_type was recorded.
TEST_F(ObjectDataTypeTest, PutStartWithDataType) {
    auto service = std::make_unique<MasterService>();
    Segment segment = MakeSegment();
    UUID client_id = generate_uuid();
    auto mount_result = service->MountSegment(segment, client_id);
    ASSERT_TRUE(mount_result.has_value());

    UUID put_client = generate_uuid();

    // Put with WEIGHT type
    ReplicateConfig config;
    config.replica_num = 1;
    config.data_type = ObjectDataType::WEIGHT;

    auto result = service->PutStart(put_client, "key_weight", 1024, config);
    ASSERT_TRUE(result.has_value());
    EXPECT_FALSE(result.value().empty());

    auto end_result =
        service->PutEnd(put_client, "key_weight", ReplicaType::MEMORY);
    EXPECT_TRUE(end_result.has_value());
}
Comment on lines +98 to +120
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PutStartWithDataType is described as verifying propagation to ObjectMetadata, but it never asserts that data_type was actually stored/serialized—only that PutStart/PutEnd succeed. Either add an assertion that reads back the stored metadata (e.g., via an existing test-only hook/snapshot roundtrip) or rename/re-scope the test so it matches what it verifies.

Copilot uses AI. Check for mistakes.

// PutStart with default UNKNOWN data_type still works (backward compat):
// a config that never touches data_type must still be accepted.
TEST_F(ObjectDataTypeTest, PutStartDefaultDataType) {
    auto service = std::make_unique<MasterService>();
    Segment segment = MakeSegment();
    UUID client_id = generate_uuid();
    auto mount_result = service->MountSegment(segment, client_id);
    ASSERT_TRUE(mount_result.has_value());

    UUID put_client = generate_uuid();
    ReplicateConfig config;
    config.replica_num = 1;
    // data_type left as default (UNKNOWN)

    auto result = service->PutStart(put_client, "key_default", 1024, config);
    ASSERT_TRUE(result.has_value());
    EXPECT_FALSE(result.value().empty());
}

// Casting any declared enumerator to uint8_t and back yields the same value.
TEST_F(ObjectDataTypeTest, EnumRoundtrip) {
    for (const ObjectDataType original :
         {ObjectDataType::UNKNOWN, ObjectDataType::KVCACHE,
          ObjectDataType::TENSOR, ObjectDataType::WEIGHT,
          ObjectDataType::SAMPLE, ObjectDataType::ACTIVATION,
          ObjectDataType::GRADIENT, ObjectDataType::OPTIMIZER_STATE,
          ObjectDataType::METADATA}) {
        const uint8_t as_byte = static_cast<uint8_t>(original);
        const ObjectDataType restored = static_cast<ObjectDataType>(as_byte);
        EXPECT_EQ(original, restored);
    }
}

} // namespace mooncake::test
Loading