This repository was archived by the owner on May 6, 2025. It is now read-only.
Created experimental folder and added all NTK sketching codes #142
Open
insuhan wants to merge 49 commits into google:main from insuhan:NT (base: main)
+1,949 −0
Changes from all 49 commits
- c64f307 add NTK Random Features and Sketching codes (insuhan)
- d1a9266 Add NTK Random Features and Sketching codes (insuhan)
- 9dc3536 Delete cache files (insuhan)
- aa12545 Resolve pytype tests (insuhan)
- 12b828e ntk sketch with polynomial approximation to the end kernel function
- 4a5cddd Fix simple issues from code reviews (v1) (insuhan)
- 932987d Fix simple issues from code reviews (v2) (insuhan)
- 6a156ed Fix simple issues from code reviews (v2) (insuhan)
- 9f2a3e5 Automatically preprocess init_fn/feature_fn (insuhan)
- c2bed91 Update for raw inputs (insuhan)
- d874cde Update FlattenFeatures for raw inputs (insuhan)
- 78038f9 changes to the poly sketching alg
- 09ea575 fc ntk sketch
- 71a0946 poly fitting using jaxopt
- 6404c1b poly fitting minor edit
- 5492d68 Make poly_fitting jittable (insuhan)
- 049336d Fix typo (insuhan)
- a14f8aa Edit format of sketching.py (insuhan)
- 5f5ac18 Fix typo in alpha_ computitation (insuhan)
- 7e0f580 Delete unnecessaries (insuhan)
- 391a1b8 Update FC NTK features and check pytype (insuhan)
- d52ae47 Make jit-able test_fc_ntk.py (insuhan)
- ec43fc7 Add ReluNTKFeatures (one-pass sketching) (insuhan)
- e0f5bfe Reflect comments in PR conversation (insuhan)
- cb90151 Merge remote-tracking branch 'upstream/main' into NT (insuhan)
- 61bc32d Add JAXopt package (insuhan)
- 1cfbe5a Reflect comments in PR conversation (v2) (insuhan)
- 5485e17 Compare ReluNTKFeatures to neural_tangents.empirical_ntk_fn (insuhan)
- 5756f92 Merge remote-tracking branch 'upstream/main' into NT (insuhan)
- 3acf87a test (amirzandieh)
- e99f7d0 Merge branch 'NT' of https://github.com/insuhan/neural-tangents into NT (insuhan)
- 07f563d Extend to ConvFeatures with retangular filter o shape (insuhan)
- a65d729 Fix complex dtype warning (insuhan)
- 671c27d Update Cholesky decomposition safely (insuhan)
- c84ed65 Fix nans issue -- complex data type (insuhan)
- 50212c2 recover previous commit (insuhan)
- 52aacb5 Add bias term for DenseFeatures (insuhan)
- 0995169 Fix typo (insuhan)
- 70f53fe Add bias in ConvFeatures (insuhan)
- 5564dac Merge remote-tracking branch 'upstream/main' into NT (insuhan)
- 6d709ff Add aggregate features for graph neural nets (insuhan)
- a6724cb Fix features_test (insuhan)
- 420f836 Fix features_test (insuhan)
- 9a14e60 Update dynamic axis (insuhan)
- e35e551 Update neural-tangents v=0.6.0 (insuhan)
- 7f0b5f7 Add setup.py (insuhan)
- 2036f86 Add jaxopt in setup.py (insuhan)
- af524d2 Change the third argument of init_fn (insuhan)
- a16098c Add ReluNTKFeatures test (insuhan)
@@ -0,0 +1,152 @@
# Efficient Feature Map of Neural Tangent Kernels via Sketching and Random Features

Implementations developed in [[1]](#1-scaling-neural-tangent-kernels-via-sketching-and-random-features). The library is written for users familiar with [JAX](https://github.com/google/jax) and the [Neural Tangents](https://github.com/google/neural-tangents) library. The code is compatible with Neural Tangents v0.5.0.

[PyTorch](https://pytorch.org/) implementations can be found [here](https://github.com/insuhan/ntk-sketch-rf).

## Examples

### Fully-connected NTK approximation via Random Features:

```python
from jax import random
from experimental.features import DenseFeatures, ReluFeatures, serial

relufeat_arg = {
    'method': 'RANDFEAT',
    'feature_dim0': 64,
    'feature_dim1': 128,
    'sketch_dim': 256,
}

init_fn, feature_fn = serial(
    DenseFeatures(512), ReluFeatures(**relufeat_arg),
    DenseFeatures(512), ReluFeatures(**relufeat_arg),
    DenseFeatures(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x = random.normal(key1, (5, 4))

_, feat_fn_inputs = init_fn(key2, x.shape)
feats = feature_fn(x, feat_fn_inputs)
# feats.nngp_feat is a feature map of NNGP kernel
# feats.ntk_feat is a feature map of NTK
assert feats.nngp_feat.shape == (5, relufeat_arg['feature_dim1'])
assert feats.ntk_feat.shape == (5, relufeat_arg['feature_dim1'] + relufeat_arg['sketch_dim'])
```

For more details on the fully-connected NTK features, please check `test_fc_ntk.py`.

### Convolutional NTK approximation via Random Features:

```python
# This example continues from the previous snippet (it reuses `serial`,
# `DenseFeatures`, `ReluFeatures`, `relufeat_arg` and `random`).
from experimental.features import ConvFeatures, AvgPoolFeatures, FlattenFeatures

init_fn, feature_fn = serial(
    ConvFeatures(512, filter_shape=(3, 3)), ReluFeatures(**relufeat_arg),
    AvgPoolFeatures((2, 2), strides=(2, 2)), FlattenFeatures(),
    DenseFeatures(512)
)

n, H, W, C = 5, 8, 8, 3
key1, key2 = random.split(random.PRNGKey(1))
x = random.normal(key1, shape=(n, H, W, C))

_, feat_fn_inputs = init_fn(key2, x.shape)
feats = feature_fn(x, feat_fn_inputs)
# feats.nngp_feat is a feature map of the NNGP kernel.
# feats.ntk_feat is a feature map of the NTK.
assert feats.nngp_feat.shape == (5, (H // 2) * (W // 2) * relufeat_arg['feature_dim1'])
assert feats.ntk_feat.shape == (5, (H // 2) * (W // 2) * (relufeat_arg['feature_dim1'] + relufeat_arg['sketch_dim']))
```
For more complex CNTK features, please check `test_myrtle_networks.py`.

# Modules

All modules return a pair of functions `(init_fn, feature_fn)`. Instead of the kernel function `kernel_fn` in the [Neural Tangents](https://github.com/google/neural-tangents) library, we provide the feature map function `feature_fn`. We do not return an `apply_fn` function.

- `init_fn` takes (1) a random seed and (2) an input shape. It returns (1) a pair of shapes of the NNGP and NTK features and (2) the parameters used for approximating the features (e.g., random vectors for the Random Features approach).
- `feature_fn` takes (1) a feature structure `features.Feature` and (2) the parameters used for the feature approximation (initialized by `init_fn`). It returns a `features.Feature` containing the approximate features of the corresponding module. A minimal end-to-end sketch of this calling convention is shown below.
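
The following sketch spells out this calling convention using the modules from the examples above. It is only an illustration of how the return values are unpacked; the exact feature shapes depend on the module configuration.

```python
from jax import random
from experimental.features import DenseFeatures, ReluFeatures, serial

key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (4, 8))  # tabular input: N = 4 samples, d = 8 features

init_fn, feature_fn = serial(
    DenseFeatures(32),
    ReluFeatures(method='RANDFEAT', feature_dim0=8, feature_dim1=16, sketch_dim=24),
    DenseFeatures(1)
)

# init_fn returns the (NNGP, NTK) feature shapes and the approximation parameters.
feat_shapes, params = init_fn(key2, x.shape)

# feature_fn maps the raw inputs (or a features.Feature) to a features.Feature
# holding the approximate NNGP and NTK feature maps.
feats = feature_fn(x, params)
print(feats.nngp_feat.shape, feats.ntk_feat.shape)
```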

## [`features.DenseFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L88)
`features.DenseFeatures` provides features for a fully-connected dense layer and corresponds to the `stax.Dense` module in [Neural Tangents](https://github.com/google/neural-tangents). We assume that the input is a tabular dataset (i.e., an n-by-d matrix). Its `feature_fn` updates the NTK features by concatenating the NNGP features and the NTK features, because `stax.Dense` computes the new NTK kernel matrix by adding the previous NNGP and NTK kernel matrices. The features of the dense layer are exact; no approximations are applied.

```python
from jax import random
from jax import numpy as np
from neural_tangents import stax
from experimental.features import DenseFeatures, serial

key1 = random.PRNGKey(0)

width = 1
x = random.normal(key1, shape=(3, 2))
_, _, kernel_fn = stax.Dense(width)
nt_kernel = kernel_fn(x)

_, feat_fn = serial(DenseFeatures(width))
feat = feat_fn(x, ())

# The NNGP features reproduce the exact NNGP kernel of a single dense layer,
# while the NTK features of the first layer are empty (initialized to zeros).
assert np.linalg.norm(nt_kernel.nngp - feat.nngp_feat @ feat.nngp_feat.T) <= 1e-12
assert feat.ntk_feat == np.zeros(())
```
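
For intuition on why concatenation implements the NTK update: a feature map of a sum of kernels is the concatenation of the individual feature maps. A minimal numerical illustration of this identity in plain `jax.numpy`, independent of the library code:

```python
from jax import random
from jax import numpy as np

key1, key2 = random.split(random.PRNGKey(0))
f1 = random.normal(key1, (5, 3))  # feature map of K1 = f1 @ f1.T
f2 = random.normal(key2, (5, 7))  # feature map of K2 = f2 @ f2.T

# Concatenating the two feature maps yields a feature map of K1 + K2.
f_cat = np.concatenate([f1, f2], axis=1)
assert np.allclose(f_cat @ f_cat.T, f1 @ f1.T + f2 @ f2.T)
```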

## [`features.ReluFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L119)
`features.ReluFeatures` is a key module of the NTK approximation. We implement feature approximations based on (1) Random Features of arc-cosine kernels [[2]](#2) and (2) the Polynomial Sketch [[3]](#3). The parameters used for the feature approximation are initialized in `init_fn`. We support tabular and image datasets. For a tabular dataset, the input features form an `N x D` matrix and the approximations are applied to the D-dimensional rows.

For an image dataset, the inputs are 4-D tensors of shape `N x H x W x D`, where N is the batch size, H is the image height, W is the image width and D is the feature dimension. We reshape the image features into a 2-D tensor of shape `NHW x D`, apply the appropriate feature approximation, and then reshape the result back into a 4-D tensor of shape `N x H x W x D'`, where `D'` is the output dimension of the feature approximation.
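
This reshape-apply-reshape pattern looks roughly as follows. The sketch below is a standalone `jax.numpy` illustration with a placeholder `approx_fn`, not the module's actual code:

```python
from jax import random
from jax import numpy as np

N, H, W, D = 2, 4, 4, 8
x = random.normal(random.PRNGKey(0), (N, H, W, D))

def approx_fn(v):
  # Placeholder for a feature approximation acting on D-dimensional rows;
  # here it simply doubles the feature dimension.
  return np.concatenate([v, v], axis=-1)

flat = x.reshape(-1, D)               # (N*H*W, D)
flat_out = approx_fn(flat)            # (N*H*W, D')
out = flat_out.reshape(N, H, W, -1)   # (N, H, W, D')
assert out.shape == (N, H, W, 2 * D)
```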

To use the Random Features approach, set the parameter `method` to `RANDFEAT` (the default), e.g.,

```python
from jax import random
from experimental.features import DenseFeatures, ReluFeatures, serial

key1 = random.PRNGKey(0)
x = random.normal(key1, shape=(3, 32))

init_fn, feat_fn = serial(
    DenseFeatures(1),
    ReluFeatures(method='RANDFEAT', feature_dim0=10, feature_dim1=20, sketch_dim=30)
)

_, params = init_fn(key1, x.shape)

out_feat = feat_fn(x, params)

assert out_feat.nngp_feat.shape == (3, 20)
assert out_feat.ntk_feat.shape == (3, 30)
```

To use the exact feature map (based on Cholesky decomposition), set the parameter `method` to `exact`, e.g.,

```python
# Reusing `x` and `key1` from the previous snippet.
init_fn, feat_fn = serial(DenseFeatures(1), ReluFeatures(method='exact'))
_, params = init_fn(key1, x.shape)
out_feat = feat_fn(x, params)

assert out_feat.nngp_feat.shape == (3, 3)
assert out_feat.ntk_feat.shape == (3, 3)
```

(This mode is intended for debugging. The dimension of the exact feature map is equal to the number of inputs, i.e., `N` for a tabular dataset and `NHW` for an image dataset.)
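
The dimension count follows because an exact feature map of an `N x N` positive semi-definite kernel matrix can be taken to be its Cholesky factor, which is itself `N x N`. A minimal conceptual sketch, not the library's internal routine:

```python
from jax import random
from jax import numpy as np

x = random.normal(random.PRNGKey(0), (3, 32))
K = x @ x.T + 1e-6 * np.eye(3)   # an N x N PSD kernel matrix (with a small jitter)
F = np.linalg.cholesky(K)        # exact feature map of K, shape (N, N)
assert np.allclose(F @ F.T, K, atol=1e-3)
```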

## [`features.ConvFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L236)

`features.ConvFeatures` is similar to `features.DenseFeatures` in that it updates the NTK features of the next layer by concatenating the NNGP and NTK features of the previous one, but it additionally requires kernel pooling operations. Precisely, [[4]](#4) showed that the NNGP/NTK kernel computation requires the trace of each submatrix of size `stride_size`, which can be seen as a convolution with an identity matrix of size `stride_size`. On the feature side, this can be done by concatenating shifted features, so the resulting feature dimension becomes `stride_size` times larger. Moreover, since images are two-dimensional, the kernel pooling is applied along both spatial axes; hence the output features have shape `N x H x W x (d * filter_size**2)`, where `filter_size` is the size of the convolution filter and `d` is the input feature dimension.
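
To make the counting argument concrete, the following rough `jax.numpy` sketch concatenates spatially shifted copies of a feature tensor along the channel axis for a `3 x 3` filter. It only illustrates why the feature dimension grows by a factor of `filter_size**2`; the actual `ConvFeatures` implementation may handle boundaries and normalization differently.

```python
from jax import random
from jax import numpy as np

N, H, W, d = 2, 8, 8, 5
filter_size = 3
feat = random.normal(random.PRNGKey(0), (N, H, W, d))

# One shifted copy of the features per filter offset. np.roll wraps around at
# the borders and is only a stand-in for the real boundary handling.
r = filter_size // 2
shifted = [np.roll(feat, (dh, dw), axis=(1, 2))
           for dh in range(-r, r + 1)
           for dw in range(-r, r + 1)]

out = np.concatenate(shifted, axis=-1)
assert out.shape == (N, H, W, d * filter_size**2)
```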

## [`features.AvgPoolFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L269)

`features.AvgPoolFeatures` applies average pooling to both the NNGP and NTK features. It calls the [`_pool_kernel`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L3143) function in [Neural Tangents](https://github.com/google/neural-tangents) as a subroutine.

## [`features.FlattenFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L304)

`features.FlattenFeatures` turns the features into 2-D tensors. Similar to the [`Flatten`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L1641) module in [Neural Tangents](https://github.com/google/neural-tangents), the flattened features are rescaled by the square root of the number of elements. For example, if `nngp_feat` has shape `N x H x W x C`, it returns an `N x HWC` matrix whose entries are divided by `(H*W*C)**0.5`.
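
A quick sketch of that flatten-and-rescale step in plain `jax.numpy`, independent of the library code; the final check illustrates that the rescaling turns the Gram matrix into an average over the `H*W*C` entries:

```python
from jax import random
from jax import numpy as np

N, H, W, C = 3, 4, 4, 2
nngp_feat = random.normal(random.PRNGKey(0), (N, H, W, C))

flat = nngp_feat.reshape(N, -1) / np.sqrt(H * W * C)  # shape (N, H*W*C)
assert flat.shape == (N, H * W * C)

# The Gram matrix of the flattened features averages the per-entry products.
assert np.allclose(flat @ flat.T,
                   np.einsum('nhwc,mhwc->nm', nngp_feat, nngp_feat) / (H * W * C))
```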

## References
#### [1] [Scaling Neural Tangent Kernels via Sketching and Random Features](https://arxiv.org/pdf/2106.07880.pdf)
#### [2] [Kernel methods for deep learning](https://cseweb.ucsd.edu/~saul/papers/nips09_kernel.pdf)
#### [3] [Oblivious Sketching of High-Degree Polynomial Kernels](https://arxiv.org/pdf/1909.01410.pdf)
#### [4] [On Exact Computation with an Infinitely Wide Neural Net](https://arxiv.org/pdf/1904.11955.pdf)