Skip to content

Commit 38de0f9

Browse files
goutamvenkat-anyscalemarcostephan
authored andcommitted
[Data] - Optimize memory usage for One Hot Encoder (ray-project#56565)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? Previously, the vector that was holding the values from OneHotEncoder was of type `int64`. We can reduce this to `uint8`, which should result in 8x lower memory usage ## Related issue number <!-- For example: "Closes ray-project#1234" --> ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Goutam V. <goutam@anyscale.com> Signed-off-by: Marco Stephan <marco@magic.dev>
1 parent 44cac84 commit 38de0f9

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

python/ray/data/preprocessors/encoder.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,18 @@ def safe_get(v: Any, stats: Dict[str, int]):
285285

286286
stats = self.stats_[f"unique_values({column})"]
287287
num_categories = len(stats)
288-
one_hot = np.zeros((len(df), num_categories), dtype=int)
288+
one_hot = np.zeros((len(df), num_categories), dtype=np.uint8)
289+
# Integer indices for each category in the column
289290
codes = df[column].apply(lambda v: safe_get(v, stats)).to_numpy()
290-
valid_rows = codes != -1
291-
one_hot[np.nonzero(valid_rows)[0], codes[valid_rows].astype(int)] = 1
291+
# Filter to only the rows that have a valid category
292+
valid_category_mask = codes != -1
293+
# Dimension should be (num_rows, ) - 1D boolean array
294+
non_zero_indices = np.nonzero(valid_category_mask)[0]
295+
# Mark the corresponding categories as 1
296+
one_hot[
297+
non_zero_indices,
298+
codes[valid_category_mask],
299+
] = 1
292300
df[output_column] = one_hot.tolist()
293301

294302
return df

0 commit comments

Comments
 (0)