
Commit 31a677f

eliminate usage of deserialize_and_register_object_ref
1 parent bd414af commit 31a677f

3 files changed: +68 −64 lines changed

core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala
Lines changed: 28 additions & 35 deletions

@@ -67,42 +67,35 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
       queue: ObjectRefHolder.Queue,
       ownerName: String): RecordBatch = {
 
-    // NOTE: We intentionally do NOT pass an owner argument to Ray.put anymore.
-    //
-    // - When ownerName is empty, route the put via the JVM RayAppMaster actor.
-    // - When ownerName is set to a Python actor name (e.g. RayDPSparkMaster),
-    //   invoke that Python actor's put_data(data) method via Ray cross-language
-    //   calls so that the Python actor becomes the owner of the created object.
-    val objectRef: ObjectRef[_] =
-      if (ownerName == "") {
-        Ray.put(data)
-      } else {
-        // Ray.getActor(String) is a raw Java Optional in Ray's Java API.
-        // If we don't cast it to an explicit reference type here, Scala may infer
-        // Optional[Nothing] and insert an invalid cast at runtime.
-        val opt = Ray.getActor(ownerName).asInstanceOf[Optional[AnyRef]]
-        if (!opt.isPresent) {
-          throw new RayDPException(s"Actor $ownerName not found when putting dataset block.")
-        }
-        val handleAny: AnyRef = opt.get()
-        if (!handleAny.isInstanceOf[PyActorHandle]) {
-          throw new RayDPException(
-            s"Actor $ownerName is not a Python actor; cannot invoke put_data."
-          )
-        }
-        val pyHandle = handleAny.asInstanceOf[PyActorHandle]
-        val method = PyActorMethod.of("put_data", classOf[AnyRef])
-        val refOfRef = pyHandle.task(method, data).remote()
-        refOfRef
-      }
+    // Owner-transfer-only implementation:
+    // - ownerName must always be provided (non-empty) and refer to a Python actor.
+    // - JVM never creates/handles Ray ObjectRefs for the dataset blocks.
+    // - JVM returns only a per-batch key encoded in RecordBatch.objectId (bytes),
+    //   and Python will fetch the real ObjectRefs from the owner actor by key.
+
+    if (ownerName == null || ownerName.isEmpty) {
+      throw new RayDPException("ownerName must be set for Spark->Ray conversion.")
+    }
+
+    val opt = Ray.getActor(ownerName).asInstanceOf[Optional[AnyRef]]
+    if (!opt.isPresent) {
+      throw new RayDPException(s"Actor $ownerName not found when putting dataset block.")
+    }
+    val handleAny: AnyRef = opt.get()
+    if (!handleAny.isInstanceOf[PyActorHandle]) {
+      throw new RayDPException(s"Actor $ownerName is not a Python actor; cannot invoke put_data.")
+    }
+    val pyHandle = handleAny.asInstanceOf[PyActorHandle]
+    val batchKey = UUID.randomUUID().toString
+
+    // put_data(batchKey, arrowBytes) -> boolean ack
+    val method = PyActorMethod.of("put_data", classOf[java.lang.Boolean])
+    val args: Array[AnyRef] = Array(batchKey, data.asInstanceOf[AnyRef])
+    new PyActorTaskCaller(pyHandle, method, args).remote().get()
 
-    // add the objectRef to the objectRefHolder to avoid reference GC
-    queue.add(objectRef)
-    val objectRefImpl = RayDPUtils.convert(objectRef)
-    val objectId = objectRefImpl.getId
-    val runtime = Ray.internal.asInstanceOf[AbstractRayRuntime]
-    val addressInfo = runtime.getObjectStore.getOwnershipInfo(objectId)
-    RecordBatch(addressInfo, objectId.getBytes, numRecords)
+    // ownerAddress/objectId here are not Ray's object metadata; objectId encodes the key.
+    // Python side will treat objectId as UTF-8 key bytes.
+    RecordBatch(Array.emptyByteArray, batchKey.getBytes("UTF-8"), numRecords)
   }
 
   /**
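
For orientation, the cross-language task invocation above is the JVM-side equivalent of calling the owner actor's put_data from Python. A minimal sketch under stated assumptions: the owner actor is registered as "RayDPSparkMaster" (the name the comments cite as the default owner), a Ray cluster is already running, and the Arrow payload is a hand-built stand-in for what ArrowStreamWriter emits on the Spark executors.

import uuid

import pyarrow as pa
import ray

# Assumes ray.init() has already connected to the running cluster.
# Build Arrow IPC stream bytes, mimicking the executor-side ArrowStreamWriter output.
table = pa.table({"id": [1, 2, 3]})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)
arrow_bytes = sink.getvalue().to_pybytes()

batch_key = str(uuid.uuid4())               # same role as UUID.randomUUID() in the JVM
owner = ray.get_actor("RayDPSparkMaster")   # assumed owner actor name
ack = ray.get(owner.put_data.remote(batch_key, arrow_bytes))
assert ack is True                          # boolean ack, matching classOf[java.lang.Boolean]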

python/raydp/spark/dataset.py
Lines changed: 12 additions & 25 deletions

@@ -110,39 +110,26 @@ def raydp_master_set_reference_as_state(
     raydp_master_set_reference_as_state)
 
 
-@client_mode_wrap
-def _register_objects(records):
-    worker = ray.worker.global_worker
-    blocks: List[ray.ObjectRef] = []
-    block_sizes: List[int] = []
-    for obj_id, owner, num_record in records:
-        object_ref = ray.ObjectRef(obj_id)
-        # Register the ownership of the ObjectRef
-        worker.core_worker.deserialize_and_register_object_ref(
-            object_ref.binary(), ray.ObjectRef.nil(), owner, "")
-        blocks.append(object_ref)
-        block_sizes.append(num_record)
-    return blocks, block_sizes
-
 def _save_spark_df_to_object_store(df: sql.DataFrame, use_batch: bool = True,
                                    owner: Union[PartitionObjectsOwner, None] = None):
     # call java function from python
     jvm = df.sql_ctx.sparkSession.sparkContext._jvm
     jdf = df._jdf
     object_store_writer = jvm.org.apache.spark.sql.raydp.ObjectStoreWriter(jdf)
-    actor_owner_name = ""
-    if owner is not None:
-        actor_owner_name = owner.actor_name
+    if owner is None:
+        # Default to RayDPSparkMaster as the owner if not specified.
+        owner = get_raydp_master_owner(df.sql_ctx.sparkSession)
+    actor_owner_name = owner.actor_name
     records = object_store_writer.save(use_batch, actor_owner_name)
 
-    record_tuples = [(record.objectId(), record.ownerAddress(), record.numRecords())
-                     for record in records]
-    blocks, block_sizes = _register_objects(record_tuples)
-
-    if owner is not None:
-        actor_owner = ray.get_actor(actor_owner_name)
-        ray.get(owner.set_reference_as_state(actor_owner, blocks))
-
+    # Owner-transfer-only path:
+    # JVM returns List[RecordBatch] where record.objectId() contains UTF-8 bytes of batch_key.
+    # Fetch actual ObjectRefs from the owner actor by key.
+    data_owner_actor = ray.get_actor(actor_owner_name)
+    batch_keys = [bytes(record.objectId()).decode("utf-8") for record in records]
+    block_sizes = [record.numRecords() for record in records]
+    blocks = ray.get(data_owner_actor.get_block_refs.remote(batch_keys))
+    ray.get(owner.set_reference_as_state(data_owner_actor, blocks))
     return blocks, block_sizes
 
 def spark_dataframe_to_ray_dataset(df: sql.DataFrame,
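
The function keeps its (blocks, block_sizes) return shape, so callers such as spark_dataframe_to_ray_dataset continue to work. A rough usage sketch, assuming a RayDP-backed Spark session; the init_spark parameters and the from_arrow_refs consumption step are illustrative assumptions, not part of this commit.

import ray
import raydp
from raydp.spark.dataset import _save_spark_df_to_object_store

ray.init()
spark = raydp.init_spark(app_name="owner_transfer_demo", num_executors=1,
                         executor_cores=1, executor_memory="1G")
spark_df = spark.range(0, 100)

# blocks: ObjectRefs to pyarrow.Table blocks now owned by the RayDPSparkMaster actor;
# block_sizes: per-block record counts reported by the JVM.
blocks, block_sizes = _save_spark_df_to_object_store(spark_df)

# Each ref resolves to a pyarrow.Table, so it can back a Ray Dataset directly.
ds = ray.data.from_arrow_refs(blocks)
print(ds.count(), sum(block_sizes))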

python/raydp/spark/ray_cluster_master.py
Lines changed: 28 additions & 4 deletions

@@ -57,12 +57,36 @@ class RayDPObjectOwnerMixin:
     objects, without using Ray's experimental `ray.put(_owner=...)` API.
     """
 
-    def put_data(self, data) -> "pa.Table":
-        """Put one serialized Arrow batch into the Ray object store."""
-        # data is Arrow IPC stream bytes written by ArrowStreamWriter
+    def _get_raydp_blocks_by_key(self):
+        blocks = getattr(self, "_raydp_blocks_by_key", None)
+        if blocks is None:
+            blocks = {}
+            setattr(self, "_raydp_blocks_by_key", blocks)
+        return blocks
+
+    def put_data(self, batch_key: str, data: bytes) -> bool:
+        """Create one Ray Dataset block owned by this actor.
+
+        Args:
+            batch_key: A per-batch application-level key generated by the JVM.
+            data: Arrow IPC stream bytes written by ArrowStreamWriter on Spark executors.
+
+        Returns:
+            True when the block has been created and stored.
+        """
         reader = pa.ipc.open_stream(pa.BufferReader(data))
         table = reader.read_all()
-        return table
+        ref = ray.put(table)
+        self._get_raydp_blocks_by_key()[batch_key] = ref
+        return True
+
+    def get_block_refs(self, batch_keys):
+        """Fetch (and remove) stored block refs for the given keys."""
+        blocks = self._get_raydp_blocks_by_key()
+        refs = []
+        for k in batch_keys:
+            refs.append(blocks.pop(k))
+        return refs
 
 
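
A local round-trip check of the new key-based protocol. TestOwner is a hypothetical actor class composed from the mixin purely for this sketch; the import path follows this file's location, and the mixin is assumed to need no constructor arguments.

import pyarrow as pa
import ray

from raydp.spark.ray_cluster_master import RayDPObjectOwnerMixin


@ray.remote
class TestOwner(RayDPObjectOwnerMixin):
    pass


ray.init()
owner = TestOwner.remote()

# Arrow IPC stream bytes, standing in for what a Spark executor would send.
table = pa.table({"x": [1, 2, 3]})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)

assert ray.get(owner.put_data.remote("batch-0", sink.getvalue().to_pybytes()))

# get_block_refs returns ObjectRefs; a second ray.get resolves the block itself.
[block_ref] = ray.get(owner.get_block_refs.remote(["batch-0"]))
assert ray.get(block_ref).num_rows == 3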
