Commit 9c43488

Implement equivalent transform func for arrow1 (with tests)
1 parent 7a0f1df commit 9c43488

File tree

7 files changed: +118 −32 lines changed

src/arrow1/mod.rs

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ pub mod writer;
 #[cfg(feature = "writer")]
 pub mod writer_properties;
 
+#[cfg(all(feature = "writer", feature = "async"))]
+pub mod writer_async;
+
 pub mod error;
 
 #[cfg(all(feature = "reader", feature = "async"))]
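(Note: the new writer_async module is gated on the writer and async features together, mirroring the gating of the async reader module declared just below it.)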

src/arrow1/wasm.rs

Lines changed: 21 additions & 0 deletions
@@ -82,3 +82,24 @@ pub async fn read_parquet_stream(
     });
     Ok(wasm_streams::ReadableStream::from_stream(stream).into_raw())
 }
+
+#[wasm_bindgen(js_name = "transformParquetStream")]
+#[cfg(all(feature = "writer", feature = "async"))]
+pub fn transform_parquet_stream(
+    stream: wasm_streams::readable::sys::ReadableStream,
+    writer_properties: Option<crate::arrow1::writer_properties::WriterProperties>,
+) -> WasmResult<wasm_streams::readable::sys::ReadableStream> {
+    use futures::StreamExt;
+    let batches = wasm_streams::ReadableStream::from_raw(stream)
+        .into_stream()
+        .map(|maybe_chunk| {
+            let chunk = maybe_chunk.unwrap();
+            let transformed: arrow_wasm::arrow1::RecordBatch = chunk.try_into().unwrap();
+            transformed
+        });
+    let output_stream = super::writer_async::transform_parquet_stream(
+        batches,
+        writer_properties.unwrap_or_default(),
+    );
+    Ok(output_stream.unwrap())
+}
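Note on the conversion above: each chunk pulled from the JS ReadableStream is unwrapped twice, so a stream error or a chunk that cannot be converted to an arrow_wasm::arrow1::RecordBatch panics inside the Wasm module. A minimal sketch of a more defensive variant of the batches binding (not part of this commit; it silently skips bad chunks rather than surfacing an error):

    // Sketch only: skip chunks that error or fail conversion instead of panicking.
    // Assumes the same `use futures::StreamExt;` import as above.
    let batches = wasm_streams::ReadableStream::from_raw(stream)
        .into_stream()
        .filter_map(|maybe_chunk| async move {
            maybe_chunk
                .ok()
                .and_then(|chunk| arrow_wasm::arrow1::RecordBatch::try_from(chunk).ok())
        });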

src/arrow1/writer_async.rs

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+use crate::arrow1::error::Result;
+use crate::common::stream::WrappedWritableStream;
+use async_compat::CompatExt;
+use futures::StreamExt;
+use parquet::arrow::async_writer::AsyncArrowWriter;
+use wasm_bindgen_futures::spawn_local;
+
+pub fn transform_parquet_stream(
+    batches: impl futures::Stream<Item = arrow_wasm::arrow1::RecordBatch> + 'static,
+    writer_properties: crate::arrow1::writer_properties::WriterProperties,
+) -> Result<wasm_streams::readable::sys::ReadableStream> {
+    let options = Some(writer_properties.into());
+    // let encoding = writer_properties.get_encoding();
+
+    let (writable_stream, output_stream) = {
+        let raw_stream = wasm_streams::transform::sys::TransformStream::new();
+        let raw_writable = raw_stream.writable();
+        let inner_writer = wasm_streams::WritableStream::from_raw(raw_writable).into_async_write();
+        let writable_stream = WrappedWritableStream {
+            stream: inner_writer,
+        };
+        (writable_stream, raw_stream.readable())
+    };
+    spawn_local::<_>(async move {
+        let mut adapted_stream = batches.peekable();
+        let mut pinned_stream = std::pin::pin!(adapted_stream);
+        let first_batch = pinned_stream.as_mut().peek().await.unwrap();
+        let schema = first_batch.schema().into_inner();
+        // Need to create an encoding for each column
+        let mut writer =
+            AsyncArrowWriter::try_new(writable_stream.compat(), schema, 1024, options).unwrap();
+        while let Some(batch) = pinned_stream.next().await {
+            let _ = writer.write(&batch.into()).await;
+        }
+        let _ = writer.close().await;
+    });
+    Ok(output_stream)
+}
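The shape of this function: a JS TransformStream provides a linked writable/readable pair. The writable half is wrapped in WrappedWritableStream (a futures::AsyncWrite) and adapted with async_compat's .compat() to the tokio-style AsyncWrite that parquet's AsyncArrowWriter takes, while the readable half is returned to the caller, and spawn_local drives the write loop on the JS event loop. The input is made peekable() because the writer needs the Arrow schema before any batch is written, and the only source for it is the first batch. That peek is also a sharp edge: an empty input stream makes peek() return None and the unwrap() panic inside the spawned task. A hedged sketch of guarding that case (not in this commit):

    // Sketch only: replace the panicking unwrap on the first peek.
    let first_batch = match pinned_stream.as_mut().peek().await {
        Some(batch) => batch,
        // No batches arrived; end the task. A fuller version would still close
        // the writable half so the JS readable side terminates cleanly.
        None => return,
    };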

src/arrow2/writer_async.rs

Lines changed: 2 additions & 31 deletions
@@ -1,38 +1,9 @@
 use crate::arrow2::error::Result;
+use crate::common::stream::WrappedWritableStream;
 use arrow2::io::parquet::write::FileSink;
-use futures::{AsyncWrite, SinkExt, StreamExt};
+use futures::{SinkExt, StreamExt};
 use wasm_bindgen_futures::spawn_local;
 
-struct WrappedWritableStream<'writer> {
-    stream: wasm_streams::writable::IntoAsyncWrite<'writer>,
-}
-
-impl<'writer> AsyncWrite for WrappedWritableStream<'writer> {
-    fn poll_write(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-        buf: &[u8],
-    ) -> std::task::Poll<std::io::Result<usize>> {
-        AsyncWrite::poll_write(std::pin::Pin::new(&mut self.get_mut().stream), cx, buf)
-    }
-
-    fn poll_flush(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<std::io::Result<()>> {
-        AsyncWrite::poll_flush(std::pin::Pin::new(&mut self.get_mut().stream), cx)
-    }
-
-    fn poll_close(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<std::io::Result<()>> {
-        AsyncWrite::poll_close(std::pin::Pin::new(&mut self.get_mut().stream), cx)
-    }
-}
-
-unsafe impl<'writer> Send for WrappedWritableStream<'writer> {}
-
 pub fn transform_parquet_stream(
     batches: impl futures::Stream<Item = arrow_wasm::arrow2::RecordBatch> + 'static,
     writer_properties: crate::arrow2::writer_properties::WriterProperties,
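The arrow2 writer itself is untouched beyond imports: WrappedWritableStream and its AsyncWrite impl move out of this file into src/common/stream.rs below, so the new arrow1 writer can share them.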

src/common/mod.rs

Lines changed: 3 additions & 0 deletions
@@ -3,3 +3,6 @@ pub mod writer_properties;
 
 #[cfg(feature = "async")]
 pub mod fetch;
+
+#[cfg(feature = "async")]
+pub mod stream;

src/common/stream.rs

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+use futures::AsyncWrite;
+
+pub struct WrappedWritableStream<'writer> {
+    pub stream: wasm_streams::writable::IntoAsyncWrite<'writer>,
+}
+
+impl<'writer> AsyncWrite for WrappedWritableStream<'writer> {
+    fn poll_write(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<std::io::Result<usize>> {
+        AsyncWrite::poll_write(std::pin::Pin::new(&mut self.get_mut().stream), cx, buf)
+    }
+
+    fn poll_flush(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        AsyncWrite::poll_flush(std::pin::Pin::new(&mut self.get_mut().stream), cx)
+    }
+
+    fn poll_close(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        AsyncWrite::poll_close(std::pin::Pin::new(&mut self.get_mut().stream), cx)
+    }
+}
+
+unsafe impl<'writer> Send for WrappedWritableStream<'writer> {}
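A note on the unsafe impl Send at the bottom: IntoAsyncWrite wraps a JS stream handle, which is not thread-safe in general. The impl is presumably justified by the single-threaded Wasm target, where the writer is only ever driven from spawn_local on the same thread, and presumably needed because the async Parquet writer's inner writer type must be Send; the commit does not say so. A hypothetical safety comment one might attach, stating that assumption explicitly:

    // SAFETY (assumed, not stated in the commit): on wasm32 without threads the
    // program runs on a single thread, and this wrapper is only driven from
    // `spawn_local` on that thread, so the JS stream handle never actually
    // crosses a thread boundary.
    unsafe impl<'writer> Send for WrappedWritableStream<'writer> {}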

tests/js/arrow1.ts

Lines changed: 20 additions & 1 deletion
@@ -2,7 +2,7 @@ import * as test from "tape";
 import * as wasm from "../../pkg/node/arrow1";
 import { readFileSync } from "fs";
 import { tableFromIPC, tableToIPC } from "apache-arrow";
-import { testArrowTablesEqual, readExpectedArrowData } from "./utils";
+import { testArrowTablesEqual, readExpectedArrowData, temporaryServer } from "./utils";
 
 // Path from repo root
 const dataDir = "tests/data";
@@ -83,3 +83,22 @@ test("error produced trying to read file with arrayBuffer", (t) => {
 
   t.end();
 });
+
+test("read stream-write stream-read stream round trip (no writer properties provided)", async (t) => {
+  const server = await temporaryServer();
+  const listeningPort = server.addresses()[0].port;
+  const rootUrl = `http://localhost:${listeningPort}`;
+
+  const expectedTable = readExpectedArrowData();
+
+  const url = `${rootUrl}/1-partition-brotli.parquet`;
+  const originalStream = await wasm.readParquetStream(url);
+
+  const stream = await wasm.transformParquetStream(originalStream);
+  const accumulatedBuffer = new Uint8Array(await new Response(stream).arrayBuffer());
+  const roundtripTable = tableFromIPC(wasm.readParquet(accumulatedBuffer).intoIPC());
+
+  testArrowTablesEqual(t, expectedTable, roundtripTable);
+  await server.close();
+  t.end();
+})
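Note that the new test only exercises the default-properties path: no second argument is passed to transformParquetStream, so the Rust side falls back to writer_properties.unwrap_or_default(). A companion test passing an explicit WriterProperties would cover the other branch.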
