diff --git a/experimental/README.md b/experimental/README.md
new file mode 100644
index 00000000..0a1ccaf0
--- /dev/null
+++ b/experimental/README.md
@@ -0,0 +1,152 @@
+# Efficient Feature Map of Neural Tangent Kernels via Sketching and Random Features
+
+Implementations of the methods developed in [[1]](#1-scaling-neural-tangent-kernels-via-sketching-and-random-features). The library is written for users familiar with [JAX](https://github.com/google/jax) and the [Neural Tangents](https://github.com/google/neural-tangents) library. The code is compatible with Neural Tangents v0.5.0.
+
+A [PyTorch](https://pytorch.org/) implementation can be found [here](https://github.com/insuhan/ntk-sketch-rf).
+
+
+## Examples
+
+### Fully-connected NTK approximation via Random Features:
+
+```python
+from jax import random
+from experimental.features import DenseFeatures, ReluFeatures, serial
+
+relufeat_arg = {
+    'method': 'RANDFEAT',
+    'feature_dim0': 64,
+    'feature_dim1': 128,
+    'sketch_dim': 256,
+}
+
+init_fn, feature_fn = serial(
+    DenseFeatures(512), ReluFeatures(**relufeat_arg),
+    DenseFeatures(512), ReluFeatures(**relufeat_arg),
+    DenseFeatures(1)
+)
+
+key1, key2 = random.split(random.PRNGKey(1))
+x = random.normal(key1, (5, 4))
+
+_, feat_fn_inputs = init_fn(key2, x.shape)
+feats = feature_fn(x, feat_fn_inputs)
+# feats.nngp_feat is a feature map of the NNGP kernel
+# feats.ntk_feat is a feature map of the NTK
+assert feats.nngp_feat.shape == (5, relufeat_arg['feature_dim1'])
+assert feats.ntk_feat.shape == (5, relufeat_arg['feature_dim1'] + relufeat_arg['sketch_dim'])
+```
+
+For more details on fully-connected NTK features, please check `test_fc_ntk.py`.
+
+### Convolutional NTK approximation via Random Features:
+
+```python
+from experimental.features import ConvFeatures, AvgPoolFeatures, FlattenFeatures
+
+init_fn, feature_fn = serial(
+    ConvFeatures(512, filter_shape=(3, 3)), ReluFeatures(**relufeat_arg),
+    AvgPoolFeatures((2, 2), strides=(2, 2)), FlattenFeatures(),
+    DenseFeatures(512)
+)
+
+n, H, W, C = 5, 8, 8, 3
+key1, key2 = random.split(random.PRNGKey(1))
+x = random.normal(key1, shape=(n, H, W, C))
+
+_, feat_fn_inputs = init_fn(key2, x.shape)
+feats = feature_fn(x, feat_fn_inputs)
+# feats.nngp_feat is a feature map of the NNGP kernel
+# feats.ntk_feat is a feature map of the NTK
+assert feats.nngp_feat.shape == (5, (H // 2) * (W // 2) * relufeat_arg['feature_dim1'])
+assert feats.ntk_feat.shape == (5, (H // 2) * (W // 2) * (relufeat_arg['feature_dim1'] + relufeat_arg['sketch_dim']))
+```
+For more complex CNTK features, please check `test_myrtle_networks.py`.
+
+# Modules
+
+All modules return a pair of functions `(init_fn, feature_fn)`. Instead of the kernel function `kernel_fn` of the [Neural Tangents](https://github.com/google/neural-tangents) library, each module provides a feature map function `feature_fn`. We do not return `apply_fn` functions.
+
+- `init_fn` takes (1) a random seed and (2) an input shape. It returns (1) the shapes of both the NNGP and NTK features and (2) the parameters used for approximating the features (e.g., random vectors for the Random Features approach).
+- `feature_fn` takes (1) a feature structure `features.Features` and (2) the parameters used for the feature approximation (initialized by `init_fn`). It returns a `features.Features` containing the approximate features of the corresponding module.
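+
+A minimal sketch of this contract follows, reusing the names from the fully-connected example above; unpacking the returned shape tuple is illustrative and follows the code in `features.py`, not a separately documented API:
+
+```python
+# Continuing the fully-connected example above.
+feat_shapes, feat_fn_inputs = init_fn(key2, x.shape)
+
+# `feat_shapes` collects the NNGP and NTK feature shapes (plus an internal
+# counter used by `ReluFeatures`); `feat_fn_inputs` holds the sketching
+# parameters that `feature_fn` consumes.
+nngp_feat_shape, ntk_feat_shape = feat_shapes[0], feat_shapes[1]
+
+feats = feature_fn(x, feat_fn_inputs)
+# Approximate kernel matrices are recovered as Gram matrices of the features.
+k_nngp_approx = feats.nngp_feat @ feats.nngp_feat.T
+k_ntk_approx = feats.ntk_feat @ feats.ntk_feat.T
+```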
+
+
+## [`features.DenseFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L88)
+`features.DenseFeatures` provides features for a fully-connected dense layer and corresponds to the `stax.Dense` module in [Neural Tangents](https://github.com/google/neural-tangents). We assume that the input is a tabular dataset (i.e., an n-by-d matrix). Its `feature_fn` updates the NTK features by concatenating the NNGP features and the NTK features. This is because `stax.Dense` obtains the new `N x N` NTK kernel matrix by adding the previous NNGP and NTK kernel matrices. The features of the dense layer are exact; no approximation is applied.
+
+```python
+from jax import random
+from jax import numpy as np
+from neural_tangents import stax
+from experimental.features import DenseFeatures, serial
+
+width = 1
+key1 = random.PRNGKey(1)
+x = random.normal(key1, shape=(3, 2))
+_, _, kernel_fn = stax.Dense(width)
+nt_kernel = kernel_fn(x)
+
+_, feat_fn = serial(DenseFeatures(width))
+feat = feat_fn(x, [()])
+
+assert np.linalg.norm(nt_kernel.nngp - feat.nngp_feat @ feat.nngp_feat.T) <= 1e-12
+assert np.linalg.norm(nt_kernel.ntk - feat.ntk_feat @ feat.ntk_feat.T) <= 1e-12
+```
+
+## [`features.ReluFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L119)
+`features.ReluFeatures` is a key module of the NTK approximation. We implement feature approximations based on (1) Random Features of arc-cosine kernels [[2]](#2) and (2) the Polynomial Sketch [[3]](#3). Parameters used for the feature approximation are initialized in `init_fn`. We support tabular and image datasets. For tabular datasets, the input features form an `N x D` matrix and the approximations are applied to the `D`-dimensional rows.
+
+For image datasets, the inputs are 4-D tensors with shape `N x H x W x D`, where `N` is the batch size, `H` is the image height, `W` is the image width and `D` is the feature dimension. We reshape the image features into a 2-D tensor with shape `NHW x D`, apply the feature approximation, and reshape the result back into a 4-D tensor with shape `N x H x W x D'`, where `D'` is the output dimension of the feature approximation.
+
+To use the Random Features approach, set the parameter `method` to `RANDFEAT` (the default), e.g.,
+
+```python
+from experimental.features import DenseFeatures, ReluFeatures, serial
+
+x = random.normal(key1, shape=(3, 32))
+
+init_fn, feat_fn = serial(
+    DenseFeatures(1),
+    ReluFeatures(method='RANDFEAT', feature_dim0=10, feature_dim1=20, sketch_dim=30)
+)
+
+_, params = init_fn(key1, x.shape)
+
+out_feat = feat_fn(x, params)
+
+assert out_feat.nngp_feat.shape == (3, 20)
+assert out_feat.ntk_feat.shape == (3, 30)
+```
+
+To use the exact feature map (based on a Cholesky decomposition), set the parameter `method` to `exact`, e.g.,
+
+```python
+init_fn, feat_fn = serial(DenseFeatures(1), ReluFeatures(method='exact'))
+_, params = init_fn(key1, x.shape)
+out_feat = feat_fn(x, params)
+
+assert out_feat.nngp_feat.shape == (3, 3)
+assert out_feat.ntk_feat.shape == (3, 3)
+```
+
+(This is intended for debugging. The dimension of the exact feature map equals the number of inputs, i.e., `N` for tabular datasets and `NHW` for image datasets.)
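+
+As a quick sanity check, an approximate feature map can be compared against the exact one by forming Gram matrices, as the test suite does. The following is a minimal sketch under the API shown above; the helper `ntk_gram` and the chosen dimensions are illustrative, not part of the library:
+
+```python
+from jax import random
+from jax import numpy as np
+from experimental.features import DenseFeatures, ReluFeatures, serial
+
+key = random.PRNGKey(1)
+x = random.normal(key, shape=(3, 32))
+
+def ntk_gram(method, **relu_args):
+  # Build Dense -> ReLU -> Dense features with the requested ReLU approximation.
+  init_fn, feat_fn = serial(DenseFeatures(1),
+                            ReluFeatures(method=method, **relu_args),
+                            DenseFeatures(1))
+  _, params = init_fn(key, x.shape)
+  f = feat_fn(x, params)
+  ntk_feat = f.ntk_feat
+  if np.iscomplexobj(ntk_feat):  # some methods return complex-valued features
+    ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1)
+  return ntk_feat @ ntk_feat.T
+
+k_exact = ntk_gram('exact')
+k_approx = ntk_gram('RANDFEAT', feature_dim0=512, feature_dim1=512, sketch_dim=512)
+
+# Relative error of the sketched NTK; it should shrink as the feature and
+# sketch dimensions grow.
+rel_err = np.linalg.norm(k_approx - k_exact) / np.linalg.norm(k_exact)
+```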
+
+
+## [`features.ConvFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L236)
+
+`features.ConvFeatures` is similar to `features.DenseFeatures` in that it updates the NTK features of the next layer by concatenating the NNGP and NTK features of the previous one. In addition, it requires a kernel pooling operation. Precisely, [[4]](#4) showed that the NNGP/NTK kernel computations require the trace of a submatrix whose size is given by the filter size. This can be seen as a convolution with an identity filter of that size. On the feature side, this is done by concatenating shifted copies of the features, so the resulting feature dimension becomes `filter_size` times larger. Moreover, since images are 2-D, the kernel pooling is applied along both spatial axes, hence the output feature has shape `N x H x W x (d * filter_size**2)`, where `filter_size` is the side length of the convolution filter and `d` is the input feature dimension.
+
+
+## [`features.AvgPoolFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L269)
+
+`features.AvgPoolFeatures` applies average pooling to both the NNGP and NTK features. It calls the [`_pool_kernel`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L3143) function in [Neural Tangents](https://github.com/google/neural-tangents) as a subroutine.
+
+## [`features.FlattenFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L304)
+
+`features.FlattenFeatures` turns the features into 2-D tensors. Similar to the [`Flatten`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L1641) module in [Neural Tangents](https://github.com/google/neural-tangents), the flattened features are rescaled by the square root of the number of spatial locations. For example, if `nngp_feat` has shape `N x H x W x C`, it returns an `N x HWC` matrix whose entries are divided by `(H*W)**0.5`.
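+
+To make the shape bookkeeping of these modules concrete, here is a minimal sketch; the asserted shape follows from the descriptions above and assumes a `(3, 3)` filter with the default `'SAME'` padding and no bias, so the channel dimension grows by a factor of 9:
+
+```python
+from jax import random
+from experimental.features import ConvFeatures, FlattenFeatures, serial
+
+n, H, W, C = 2, 8, 8, 3
+key1, key2 = random.split(random.PRNGKey(0))
+x = random.normal(key1, shape=(n, H, W, C))
+
+init_fn, feature_fn = serial(ConvFeatures(1, filter_shape=(3, 3)), FlattenFeatures())
+_, params = init_fn(key2, x.shape)
+feats = feature_fn(x, params)
+
+# Each of the 3 * 3 = 9 filter offsets contributes a shifted copy of the C
+# input channels, and flattening folds the H * W spatial locations into the
+# feature axis.
+assert feats.nngp_feat.shape == (n, H * W * 9 * C)
+```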
+ + +## References +#### [1] [Scaling Neural Tangent Kernels via Sketching and Random Features](https://arxiv.org/pdf/2106.07880.pdf) +#### [2] [Kernel methods for deep learning](https://cseweb.ucsd.edu/~saul/papers/nips09_kernel.pdf) +#### [3] [Oblivious Sketching of High-Degree Polynomial Kernels](https://arxiv.org/pdf/1909.01410.pdf) +#### [4] [On Exact Computation with an Infinitely Wide Neural Net](https://arxiv.org/pdf/1904.11955.pdf) + diff --git a/experimental/__init__.py b/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/features.py b/experimental/features.py new file mode 100644 index 00000000..96555e28 --- /dev/null +++ b/experimental/features.py @@ -0,0 +1,858 @@ +import enum +from typing import Optional, Callable, Sequence, Tuple +import frozendict +import string +import functools +import operator as op + +from jax import lax +from jax import random +from jax._src.util import prod +from jax import numpy as np +import jax.example_libraries.stax as ostax +from jax import eval_shape, ShapedArray + +from neural_tangents._src.utils import dataclasses +from neural_tangents._src.utils.typing import Axes +from neural_tangents._src.stax.requirements import _set_req, get_req, _fuse_requirements, _DEFAULT_INPUT_REQ +from neural_tangents._src.stax.combinators import _get_input_req_attr +from neural_tangents._src.stax.linear import _pool_kernel, Padding, _get_dimension_numbers, AggregateImplementation +from neural_tangents._src.stax.linear import _Pooling as Pooling + +from experimental.sketching import TensorSRHT, PolyTensorSketch +from experimental.poly_fitting import kappa0_coeffs, kappa1_coeffs, kappa0, kappa1, relu_ntk_coeffs +""" Implementation for NTK Sketching and Random Features """ + + +@dataclasses.dataclass +class Features: + nngp_feat: np.ndarray + ntk_feat: np.ndarray + + batch_axis: int = 0 + channel_axis: int = -1 + + replace = ... # type: Callable[..., 'Features'] + + +class ReluFeaturesImplementation(enum.Enum): + """Method for ReLU NNGP/NTK features approximation.""" + RANDFEAT = 'RANDFEAT' + POLYSKETCH = 'POLYSKETCH' + PSRF = 'PSRF' + POLY = 'POLY' + EXACT = 'EXACT' + + +def requires(**static_reqs): + + def req(feature_fn): + _set_req(feature_fn, frozendict.frozendict(static_reqs)) + return feature_fn + + return req + + +def layer(layer_fn): + + def new_layer_fns(*args, **kwargs): + init_fn, feature_fn = layer_fn(*args, **kwargs) + init_fn = _preprocess_init_fn(init_fn) + feature_fn = _preprocess_feature_fn(feature_fn) + return init_fn, feature_fn + + return new_layer_fns + + +def _preprocess_init_fn(init_fn): + + def init_fn_any(rng, input_shape_any, **kwargs): + if _is_single_shape(input_shape_any): + # Add a dummy shape for ntk_feat + dummy_shape = (-1,) + (0,) * (len(input_shape_any) - 1) + input_shape = (input_shape_any, dummy_shape, 0) + return init_fn(rng, input_shape, **kwargs) + else: + return init_fn(rng, input_shape_any, **kwargs) + + return init_fn_any + + +def _is_single_shape(input_shape): + if all(isinstance(n, int) for n in input_shape): + return True + elif len(input_shape) == 3 and all( + _is_single_shape(s) for s in input_shape[:2]): + return False + raise ValueError(input_shape) + + +# For flexible `feature_fn` with both input `np.ndarray` and with `Feature`. 
+# Followed https://github.com/google/neural-tangents/blob/main/neural_tangents/_src/stax/requirements.py +def _preprocess_feature_fn(feature_fn): + + def feature_fn_feature(feature, input, **kwargs): + return feature_fn(feature, input, **kwargs) + + def feature_fn_x(x, input, **kwargs): + feature_fn_reqs = get_req(feature_fn) + reqs = _fuse_requirements(feature_fn_reqs, _DEFAULT_INPUT_REQ, **kwargs) + feature = _inputs_to_features(x, **reqs) + return feature_fn(feature, input, **kwargs) + + def feature_fn_any(x_or_feature, input, **kwargs): + if isinstance(x_or_feature, Features): + return feature_fn_feature(x_or_feature, input, **kwargs) + return feature_fn_x(x_or_feature, input, **kwargs) + + _set_req(feature_fn_any, get_req(feature_fn)) + return feature_fn_any + + +def _inputs_to_features(x: np.ndarray, + batch_axis: int = 0, + channel_axis: int = -1, + **kwargs) -> Features: + """Transforms (batches of) inputs to a `Features`.""" + # Followed the same initialization of Neural Tangents library. + if channel_axis is None: + x = np.moveaxis(x, batch_axis, 0).reshape((x.shape[batch_axis], -1)) + batch_axis, channel_axis = 0, 1 + else: + channel_axis %= x.ndim + + nngp_feat = x / x.shape[channel_axis]**0.5 + ntk_feat = np.zeros(x.shape[:channel_axis] + (0,) + + x.shape[channel_axis + 1:], + dtype=x.dtype) + return Features(nngp_feat=nngp_feat, + ntk_feat=ntk_feat, + batch_axis=batch_axis, + channel_axis=channel_axis) # pytype:disable=wrong-keyword-args + + +# Modified the serial process of feature map blocks. +# Followed https://github.com/google/neural-tangents/blob/main/neural_tangents/stax.py +@layer +def serial(*layers): + + init_fns, feature_fns = zip(*layers) + init_fn, _ = ostax.serial(*zip(init_fns, init_fns)) + + @requires(**_get_input_req_attr(feature_fns, fold=op.rshift)) + def feature_fn(features: Features, inputs, **kwargs) -> Features: + if not (len(init_fns) == len(feature_fns) == len(inputs)): + raise ValueError('Length of inputs should be same as that of layers.') + for feature_fn_, input_ in zip(feature_fns, inputs): + features = feature_fn_(features, input_, **kwargs) + return features + + return init_fn, feature_fn + + +@layer +def DenseFeatures(out_dim: int, + W_std: float = 1., + b_std: Optional[float] = None, + batch_axis: int = 0, + channel_axis: int = -1, + parameterization: str = 'ntk'): + + parameterization = parameterization.lower() + + if parameterization != 'ntk': + raise NotImplementedError(f'Parameterization ({parameterization}) is ' + ' not implemented yet.') + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + _channel_axis = channel_axis % len(nngp_feat_shape) + + nngp_feat_dim = nngp_feat_shape[_channel_axis] + (1 if b_std is not None + else 0) + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + nngp_feat_dim,) + nngp_feat_shape[_channel_axis + 1:] + + if prod(ntk_feat_shape) == 0: + new_ntk_feat_shape = new_nngp_feat_shape + else: + ntk_feat_dim = nngp_feat_dim + ntk_feat_shape[_channel_axis] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + ntk_feat_dim,) + ntk_feat_shape[_channel_axis + 1:] + + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), () + + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input, **kwargs): + nngp_feat = f.nngp_feat + ntk_feat = f.ntk_feat + + _channel_axis = channel_axis % nngp_feat.ndim + + if b_std is not None: # concatenate bias vector in nngp_feat + biases = 
b_std * np.ones(nngp_feat.shape[:_channel_axis] + + (1,) + nngp_feat.shape[_channel_axis + 1:], + dtype=nngp_feat.dtype) + nngp_feat = np.concatenate((W_std * nngp_feat, biases), + axis=_channel_axis) + ntk_feat = W_std * ntk_feat + else: + nngp_feat *= W_std + ntk_feat *= W_std + + ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=_channel_axis) + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, feature_fn + + +@layer +def ReluFeatures(method: str = 'RANDFEAT', + feature_dim0: int = 1, + feature_dim1: int = 1, + sketch_dim: int = 1, + poly_degree: int = 8, + poly_sketch_dim: int = 1, + generate_rand_mtx: bool = True, + batch_axis: int = 0, + channel_axis: int = -1): + + method = ReluFeaturesImplementation(method.upper()) + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + + relu_layers_count = input_shape[2] + new_relu_layers_count = relu_layers_count + 1 + + ndim = len(nngp_feat_shape) + _channel_axis = channel_axis % ndim + + if method == ReluFeaturesImplementation.RANDFEAT: + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + feature_dim1,) + nngp_feat_shape[_channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + sketch_dim,) + ntk_feat_shape[_channel_axis + 1:] + + rng1, rng2, rng3 = random.split(rng, 3) + if generate_rand_mtx: + # Random vectors for random features of arc-cosine kernel of order 0. + W0 = random.normal(rng1, (nngp_feat_shape[_channel_axis], feature_dim0)) + # Random vectors for random features of arc-cosine kernel of order 1. + W1 = random.normal(rng2, (nngp_feat_shape[_channel_axis], feature_dim1)) + else: + # if `generate_rand_mtx` is False, return random seeds and shapes instead of np.ndarray. + W0 = (rng1, (nngp_feat_shape[_channel_axis], feature_dim0)) + W1 = (rng2, (nngp_feat_shape[_channel_axis], feature_dim1)) + + # TensorSRHT of degree 2 for approximating tensor product. + tensorsrht = TensorSRHT(rng=rng3, + input_dim1=ntk_feat_shape[_channel_axis], + input_dim2=feature_dim0, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_relu_layers_count), (W0, W1, tensorsrht) + + elif method == ReluFeaturesImplementation.POLYSKETCH: + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + poly_sketch_dim,) + nngp_feat_shape[_channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + sketch_dim,) + ntk_feat_shape[_channel_axis + 1:] + + rng1, rng2 = random.split(rng, 2) + + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + poly_sketch_dim,) + nngp_feat_shape[_channel_axis + 1:] + + kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) + kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) + + # PolySketch expansion for nngp features. + if relu_layers_count == 0: + pts_input_dim = nngp_feat_shape[_channel_axis] + else: + pts_input_dim = int(nngp_feat_shape[_channel_axis] / 2 + 0.5) + polysketch = PolyTensorSketch(rng=rng1, + input_dim=pts_input_dim, + sketch_dim=poly_sketch_dim, + degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args + + # TensorSRHT of degree 2 for approximating tensor product. 
+ if relu_layers_count == 0: + ts_input_dim = ntk_feat_shape[_channel_axis] + else: + ts_input_dim = int(ntk_feat_shape[_channel_axis] / 2 + 0.5) + tensorsrht = TensorSRHT(rng=rng2, + input_dim1=ts_input_dim, + input_dim2=poly_degree * + (polysketch.sketch_dim // 4 - 1) + 1, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_relu_layers_count), (polysketch, tensorsrht, kappa0_coeff, + kappa1_coeff) + + elif method == ReluFeaturesImplementation.PSRF: + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + poly_sketch_dim,) + nngp_feat_shape[_channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + sketch_dim,) + ntk_feat_shape[_channel_axis + 1:] + + rng1, rng2, rng3 = random.split(rng, 3) + + kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) + + # PolySketch expansion for nngp features. + if relu_layers_count == 0: + pts_input_dim = nngp_feat_shape[_channel_axis] + else: + pts_input_dim = int(nngp_feat_shape[_channel_axis] / 2 + 0.5) + polysketch = PolyTensorSketch(rng=rng1, + input_dim=pts_input_dim, + sketch_dim=poly_sketch_dim, + degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args + + # TensorSRHT of degree 2 for approximating tensor product. + if relu_layers_count == 0: + ts_input_dim = ntk_feat_shape[_channel_axis] + else: + ts_input_dim = int(ntk_feat_shape[_channel_axis] / 2 + 0.5) + tensorsrht = TensorSRHT(rng=rng2, + input_dim1=ts_input_dim, + input_dim2=feature_dim0, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + + # Random vectors for random features of arc-cosine kernel of order 0. + if relu_layers_count == 0: + W0 = random.normal(rng3, + (nngp_feat_shape[_channel_axis], feature_dim0 // 2)) + else: + W0 = random.normal( + rng3, + (int(nngp_feat_shape[_channel_axis] / 2 + 0.5), feature_dim0 // 2)) + + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_relu_layers_count), (W0, polysketch, tensorsrht, kappa1_coeff) + + elif method == ReluFeaturesImplementation.POLY: + # This only uses the polynomial approximation without sketching. + feat_dim = prod( + tuple(nngp_feat_shape[i] + for i in range(ndim) + if i not in [_channel_axis])) + + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + feat_dim,) + nngp_feat_shape[_channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + feat_dim,) + ntk_feat_shape[_channel_axis + 1:] + + kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) + kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) + + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_relu_layers_count), (kappa0_coeff, kappa1_coeff) + + elif method == ReluFeaturesImplementation.EXACT: + # The exact feature map computation is for debug. 
+ feat_dim = prod( + tuple(nngp_feat_shape[i] + for i in range(ndim) + if i not in [_channel_axis])) + + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + feat_dim,) + nngp_feat_shape[_channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + feat_dim,) + ntk_feat_shape[_channel_axis + 1:] + + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_relu_layers_count), () + + else: + raise NotImplementedError(f'Invalid method name: {method}') + + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input, **kwargs) -> Features: + ndim = len(f.nngp_feat.shape) + _channel_axis = channel_axis % ndim + spatial_axes = tuple( + f.nngp_feat.shape[i] for i in range(ndim) if i != _channel_axis) + + def _convert_to_original(x): + return np.moveaxis(x.reshape(spatial_axes + (-1,)), -1, _channel_axis) + + def _convert_to_2d(x): + feat_dim = x.shape[_channel_axis] + return np.moveaxis(x, _channel_axis, -1).reshape(-1, feat_dim) + + nngp_feat_2d = _convert_to_2d(f.nngp_feat) + if prod(f.ntk_feat.shape) != 0: + ntk_feat_2d = _convert_to_2d(f.ntk_feat) + + if method == ReluFeaturesImplementation.RANDFEAT: # Random Features approach. + if generate_rand_mtx: + W0: np.ndarray = input[0] + W1: np.ndarray = input[1] + else: + W0 = random.normal(input[0][0], shape=input[0][1]) + W1 = random.normal(input[1][0], shape=input[1][1]) + tensorsrht: TensorSRHT = input[2] + + kappa0_feat = (nngp_feat_2d @ W0 > 0) / W0.shape[-1]**0.5 + del W0 + nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / W1.shape[-1]**0.5) + del W1 + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat, real_output=True) + + nngp_feat = _convert_to_original(nngp_feat) + ntk_feat = _convert_to_original(ntk_feat) + + elif method == ReluFeaturesImplementation.POLYSKETCH: + polysketch: PolyTensorSketch = input[0] + tensorsrht: TensorSRHT = input[1] + kappa0_coeff: np.ndarray = input[2] + kappa1_coeff: np.ndarray = input[3] + + norms = np.linalg.norm(nngp_feat_2d, axis=-1, keepdims=True) + norms = np.maximum(norms, 1e-12) + + nngp_feat_2d /= norms + ntk_feat_2d /= norms + + # Apply PolySketch to approximate feature maps of kappa0 & kappa1 kernels. + polysketch_feats = polysketch.sketch(nngp_feat_2d) + kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) + kappa0_feat = polysketch.expand_feats(polysketch_feats, kappa0_coeff) + del polysketch_feats + + # Apply SRHT to kappa1_feat so that dimension of nngp_feat is poly_sketch_dim//2. + nngp_feat = polysketch.standardsrht(kappa1_feat) + + # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat) + + nngp_feat *= norms + ntk_feat *= norms + + nngp_feat = _convert_to_original(nngp_feat) + ntk_feat = _convert_to_original(ntk_feat) + + elif method == ReluFeaturesImplementation.PSRF: # Combination of PolySketch and Random Features. + W0: np.ndarray = input[0] + polysketch: PolyTensorSketch = input[1] + tensorsrht: TensorSRHT = input[2] + kappa1_coeff: np.ndarray = input[3] + + norms = np.linalg.norm(nngp_feat_2d, axis=-1, keepdims=True) + norms = np.maximum(norms, 1e-12) + + nngp_feat_2d /= norms + ntk_feat_2d /= norms + + # Apply PolySketch to approximate feature maps of kappa1 kernels. 
+ polysketch_feats = polysketch.sketch(nngp_feat_2d) + kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) + del polysketch_feats + + nngp_feat = polysketch.standardsrht(kappa1_feat) + + nngp_proj = nngp_feat_2d @ W0 + kappa0_feat = np.concatenate( + ((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / W0.shape[-1]**0.5 + del W0 + + # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat) + + nngp_feat *= norms + ntk_feat *= norms + + nngp_feat = _convert_to_original(nngp_feat) + ntk_feat = _convert_to_original(ntk_feat) + + elif method == ReluFeaturesImplementation.POLY: # Polynomial approximation without sketching. + kappa0_coeff: np.ndarray = input[0] + kappa1_coeff: np.ndarray = input[1] + + norms = np.linalg.norm(nngp_feat_2d, axis=-1, keepdims=True) + norms = np.maximum(norms, 1e-12) + + nngp_feat_2d /= norms + ntk_feat_2d /= norms + + gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) + nngp_feat = _cholesky(np.polyval(kappa1_coeff[::-1], gram_nngp)) + + ntk = ntk_feat_2d @ ntk_feat_2d.T + kappa0_mat = np.polyval(kappa0_coeff[::-1], gram_nngp) + ntk_feat = _cholesky(ntk * kappa0_mat) + + nngp_feat *= norms + ntk_feat *= norms + + nngp_feat = _convert_to_original(nngp_feat) + ntk_feat = _convert_to_original(ntk_feat) + + elif method == ReluFeaturesImplementation.EXACT: # Exact feature map computations via Cholesky decomposition. + nngp_feat = _convert_to_original( + _cholesky(kappa1(nngp_feat_2d, is_x_matrix=True))) + + if prod(f.ntk_feat.shape) != 0: + ntk = ntk_feat_2d @ ntk_feat_2d.T + kappa0_mat = kappa0(nngp_feat_2d, is_x_matrix=True) + ntk_feat = _convert_to_original(_cholesky(ntk * kappa0_mat)) + else: + ntk_feat = f.ntk_feat + + else: + raise NotImplementedError(f'Invalid method name: {method}') + + if method != ReluFeaturesImplementation.RANDFEAT: + ntk_feat /= 2.0**0.5 + nngp_feat /= 2.0**0.5 + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, feature_fn + + +def _cholesky(mat): + return np.linalg.cholesky(mat + 1e-8 * np.eye(mat.shape[0])) + + +@layer +def ReluNTKFeatures(num_layers: int, + poly_degree: int = 16, + poly_sketch_dim: int = 1024, + batch_axis: int = 0, + channel_axis: int = -1): + + if batch_axis != 0 or channel_axis != -1: + raise NotImplementedError(f'Not supported axes.') + + def init_fn(rng, input_shape): + input_dim = input_shape[0][-1] + + # PolySketch expansion for nngp/ntk features. + polysketch = PolyTensorSketch(rng=rng, + input_dim=input_dim, + sketch_dim=poly_sketch_dim, + degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args + + nngp_coeffs, ntk_coeffs = relu_ntk_coeffs(poly_degree, num_layers) + + return (), (polysketch, nngp_coeffs, ntk_coeffs) + + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input=None, **kwargs): + input_shape = f.nngp_feat.shape[:-1] + + polysketch: PolyTensorSketch = input[0] + nngp_coeffs: np.ndarray = input[1] + ntk_coeffs: np.ndarray = input[2] + + norms = np.linalg.norm(f.nngp_feat, axis=channel_axis, keepdims=True) + nngp_feat = f.nngp_feat / norms + + polysketch_feats = polysketch.sketch(nngp_feat) + nngp_feat = polysketch.expand_feats(polysketch_feats, nngp_coeffs) + ntk_feat = polysketch.expand_feats(polysketch_feats, ntk_coeffs) + + # Apply SRHT to features so that dimensions are poly_sketch_dim//2. 
+ nngp_feat = polysketch.standardsrht(nngp_feat).reshape(input_shape + (-1,)) + ntk_feat = polysketch.standardsrht(ntk_feat).reshape(input_shape + (-1,)) + + # Convert complex features to real ones. + nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=-1) + ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1) + + nngp_feat *= norms / 2**(num_layers / 2.) + ntk_feat *= norms / 2**(num_layers / 2.) + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, feature_fn + + +@layer +def ConvFeatures(out_chan: int, + filter_shape: Sequence[int], + strides: Optional[Sequence[int]] = None, + padding: str = 'SAME', + W_std: float = 1.0, + b_std: Optional[float] = None, + dimension_numbers: Optional[Tuple[str, str, str]] = None, + parameterization: str = 'ntk'): + + parameterization = parameterization.lower() + + if dimension_numbers is None: + dimension_numbers = _get_dimension_numbers(len(filter_shape), False) + + lhs_spec, rhs_spec, out_spec = dimension_numbers + + channel_axis = lhs_spec.index('C') + + patch_size = prod(filter_shape) + + if parameterization != 'ntk': + raise NotImplementedError(f'Parameterization ({parameterization}) is ' + ' not implemented yet.') + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + + nngp_feat_dim = nngp_feat_shape[channel_axis] * patch_size + if b_std is not None: + nngp_feat_dim += 1 + + nngp_feat_dim = nngp_feat_shape[channel_axis] * patch_size + ( + 1 if b_std is not None else 0) + ntk_feat_dim = nngp_feat_dim + ntk_feat_shape[channel_axis] * patch_size + + new_nngp_feat_shape = nngp_feat_shape[:channel_axis] + ( + nngp_feat_dim,) + nngp_feat_shape[channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:channel_axis] + ( + ntk_feat_dim,) + ntk_feat_shape[channel_axis + 1:] + + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), () + + @requires(batch_axis=lhs_spec.index('N'), channel_axis=lhs_spec.index('C')) + def feature_fn(f: Features, input, **kwargs): + + nngp_feat = f.nngp_feat + + _channel_axis = channel_axis % nngp_feat.ndim + + nngp_feat = _concat_shifted_features_2d( + nngp_feat, filter_shape, dimension_numbers) * W_std / patch_size**0.5 + + if b_std is not None: + biases = b_std * np.ones(nngp_feat.shape[:_channel_axis] + + (1,) + nngp_feat.shape[_channel_axis + 1:], + dtype=nngp_feat.dtype) + nngp_feat = np.concatenate((nngp_feat, biases), axis=_channel_axis) + + if prod(f.ntk_feat.shape) == 0: # if ntk_feat is empty skip feature concat + ntk_feat = nngp_feat + else: + ntk_feat = _concat_shifted_features_2d( + f.ntk_feat, filter_shape, dimension_numbers) * W_std / patch_size**0.5 + ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=_channel_axis) + + return f.replace(nngp_feat=nngp_feat, + ntk_feat=ntk_feat, + batch_axis=out_spec.index('N'), + channel_axis=out_spec.index('C')) + + return init_fn, feature_fn + + +def _concat_shifted_features_2d(X: np.ndarray, + filter_shape: Sequence[int], + dimension_numbers: Optional[Tuple[str, str, + str]] = None): + return lax.conv_general_dilated_patches(X, + filter_shape=filter_shape, + window_strides=(1, 1), + padding='SAME', + dimension_numbers=dimension_numbers) + + +@layer +def AvgPoolFeatures(window_shape: Sequence[int], + strides: Optional[Sequence[int]] = None, + padding: str = 'VALID', + normalize_edges: bool = False, + batch_axis: int = 0, + channel_axis: int = -1): + + if window_shape[0] != strides[0] or window_shape[1] != strides[1]: + raise NotImplementedError('window_shape 
should be equal to strides.') + + channel_axis %= 4 + spec = ''.join( + c for c in string.ascii_uppercase if c not in ('N', 'C'))[:len(strides)] + for a in sorted((batch_axis, channel_axis % (2 + len(strides)))): + if a == batch_axis: + spec = spec[:a] + 'N' + spec[a:] + else: + spec = spec[:a] + 'C' + spec[a:] + + _kernel_window_shape = lambda x_: tuple( + [x_[0] if s == 'A' else x_[0] if s == 'B' else 1 for s in spec]) + window_shape_kernel = _kernel_window_shape(window_shape) + strides_kernel = _kernel_window_shape(strides) + + pooling = lambda x: _pool_kernel(x, Pooling.AVG, + window_shape_kernel, strides_kernel, + Padding(padding), normalize_edges, 0) + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + + new_nngp_feat_shape = eval_shape(pooling, + ShapedArray(nngp_feat_shape, + np.float32)).shape + new_ntk_feat_shape = eval_shape(pooling, + ShapedArray(ntk_feat_shape, + np.float32)).shape + + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), () + + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input, **kwargs): + nngp_feat = f.nngp_feat + ntk_feat = f.ntk_feat + + nngp_feat = pooling(nngp_feat) + + if prod(f.ntk_feat.shape) == 0: # check if ntk_feat is empty + ntk_feat = nngp_feat + else: + ntk_feat = pooling(ntk_feat) + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, feature_fn + + +@layer +def GlobalAvgPoolFeatures(batch_axis: int = 0, channel_axis: int = -1): + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + ndim = len(nngp_feat_shape) + non_spatial_axes = (batch_axis % ndim, channel_axis % ndim) + _get_output_shape = lambda _shape: tuple(_shape[i] + for i in range(ndim) + if i in non_spatial_axes) + new_nngp_feat_shape = _get_output_shape(nngp_feat_shape) + new_ntk_feat_shape = _get_output_shape(ntk_feat_shape) + + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), () + + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input, **kwargs): + nngp_feat = f.nngp_feat + ntk_feat = f.ntk_feat + + ndim = len(nngp_feat.shape) + non_spatial_axes = (batch_axis % ndim, channel_axis % ndim) + spatial_axes = tuple(set(range(ndim)) - set(non_spatial_axes)) + + nngp_feat = np.mean(nngp_feat, axis=spatial_axes) + ntk_feat = np.mean(ntk_feat, axis=spatial_axes) + + batch_first = batch_axis % ndim < channel_axis % ndim + return f.replace(nngp_feat=nngp_feat, + ntk_feat=ntk_feat, + batch_axis=0 if batch_first else 1, + channel_axis=1 if batch_first else 0) + + return init_fn, feature_fn + + +@layer +def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): + + if batch_axis_out in (0, -2): + batch_axis_out = 0 + channel_axis_out = 1 + elif batch_axis_out in (1, -1): + batch_axis_out = 1 + channel_axis_out = 0 + else: + raise ValueError(f'`batch_axis_out` must be 0 or 1, got {batch_axis_out}.') + + def get_output_shape(input_shape): + batch_size = input_shape[batch_axis] + channel_size = functools.reduce( + op.mul, input_shape[:batch_axis] + + input_shape[(batch_axis + 1) or len(input_shape):], 1) + if batch_axis_out == 0: + return batch_size, channel_size + return channel_size, batch_size + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + new_nngp_feat_shape = get_output_shape(nngp_feat_shape) + new_ntk_feat_shape = get_output_shape(ntk_feat_shape) + + return (new_nngp_feat_shape, 
new_ntk_feat_shape, input_shape[2]), () + + @requires(batch_axis=batch_axis, channel_axis=None) + def feature_fn(f: Features, input, **kwargs): + nngp_feat = f.nngp_feat + + batch_size = nngp_feat.shape[batch_axis] + nngp_feat_dim = prod( + nngp_feat.shape) / batch_size / f.nngp_feat.shape[f.channel_axis] + nngp_feat = nngp_feat.reshape(batch_size, -1) / nngp_feat_dim**0.5 + + if prod(f.ntk_feat.shape) != 0: # check if ntk_feat is not empty + ntk_feat_dim = prod( + f.ntk_feat.shape) / batch_size / f.ntk_feat.shape[f.channel_axis] + ntk_feat = f.ntk_feat.reshape(batch_size, -1) / ntk_feat_dim**0.5 + else: + ntk_feat = f.ntk_feat.reshape(batch_size, -1) + + return f.replace(nngp_feat=nngp_feat, + ntk_feat=ntk_feat, + batch_axis=batch_axis_out, + channel_axis=channel_axis_out) + + return init_fn, feature_fn + + +@layer +def LayerNormFeatures(axis: Axes = -1, + eps: float = 1e-12, + batch_axis: int = 0, + channel_axis: int = -1): + + def init_fn(rng, input_shape): + return input_shape, () + + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input, **kwargs): + norms = np.linalg.norm(f.nngp_feat, keepdims=True, axis=channel_axis) + norms = np.maximum(norms, eps) + + nngp_feat = f.nngp_feat / norms + ntk_feat = f.ntk_feat / norms if prod(f.ntk_feat.shape) != 0 else f.ntk_feat + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, feature_fn + + +@layer +def AggregateFeatures( + aggregate_axis: Optional[Axes] = None, + batch_axis: int = 0, + channel_axis: int = -1, + to_dense: Optional[Callable[[np.ndarray], np.ndarray]] = lambda p: p, + implementation: str = AggregateImplementation.DENSE.value): + + def init_fn(rng, input_shape): + return input_shape, () + + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input=None, pattern=None, **kwargs): + if pattern is None: + raise NotImplementedError('`pattern=None` is not implemented.') + + nngp_feat = f.nngp_feat + ntk_feat = f.ntk_feat + + pattern_T = np.swapaxes(pattern, 1, 2) + nngp_feat = np.einsum("bnm,bmc->bnc", pattern_T, nngp_feat) + + if prod(f.ntk_feat.shape) != 0: # check if ntk_feat is not empty + ntk_feat = np.einsum("bnm,bmc->bnc", pattern_T, ntk_feat) + else: + ntk_feat = nngp_feat + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, feature_fn diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py new file mode 100644 index 00000000..4a23cfe5 --- /dev/null +++ b/experimental/poly_fitting.py @@ -0,0 +1,148 @@ +from jax import numpy as np +from jaxopt import OSQP + + +def kappa0(x, is_x_matrix): + if is_x_matrix: + xxt = x @ x.T + xnormsq = np.sum(x**2, axis=-1) + prod = np.outer(xnormsq, xnormsq) + return (1 - _arccos(xxt / _sqrt(prod)) / np.pi) + else: # vector input + return (1 - _arccos(x) / np.pi) + + +def kappa1(x, is_x_matrix): + if is_x_matrix: + xxt = x @ x.T + xnormsq = np.sum(x**2, axis=-1) + prod = np.outer(xnormsq, xnormsq) + return (_sqrt(prod - xxt**2) + + (np.pi - _arccos(xxt / _sqrt(prod))) * xxt) / np.pi + else: # vector input + return (_sqrt(1 - x**2) + (np.pi - _arccos(x)) * x) / np.pi + + +def _arccos(x): + return np.arccos(np.clip(x, -1, 1)) + + +def _sqrt(x): + return np.maximum(x, 1e-20)**0.5 + + +def poly_fitting_qp(xvals: np.ndarray, + fvals: np.ndarray, + weights: np.ndarray, + degree: int, + eq_last_point: bool = False): + """ Computes polynomial coefficients that fitting input observations. 
+ For a dot-product kernel (e.g., kappa0 or kappa1), coefficients of its + Taylor series expansion are always nonnegative. Moreover, the kernel + function is a monotone increasing function. This can be solved by + Quadratic Programming (QP) under inequality constraints. + """ + nx = len(xvals) + x_powers = np.ones((nx, degree + 1), dtype=xvals.dtype) + for i in range(degree): + x_powers = xvals.reshape(nx, + 1)**np.arange(degree + 1).reshape(1, degree + 1) + + y_weighted = fvals * weights + x_powers_weighted = x_powers.T * weights[None, :] + + dx_powers = x_powers[:-1, :] - x_powers[1:, :] + + # OSQP algorithm for solving min_x x'*Q*x + c'*x such that A*x=b, G*x<= h + P = x_powers_weighted @ x_powers_weighted.T + Q = .5 * (P.T + P + 1e-5 * np.eye(P.shape[0], dtype=xvals.dtype)) # make sure Q is symmetric + c = -x_powers_weighted @ y_weighted + G = np.concatenate((dx_powers, -np.eye(degree + 1)), axis=0) + h = np.zeros(nx + degree, dtype=xvals.dtype) + + if eq_last_point: + A = x_powers[-1, :][None, :] + b = fvals[-1:] + return OSQP().run(params_obj=(Q, c), params_eq=(A, b), + params_ineq=(G, h)).params.primal + else: + return OSQP().run(params_obj=(Q, c), params_ineq=(G, h)).params.primal + + +def kappa0_coeffs(degree: int, num_layers: int): + + # A lower bound of kappa0^{(num_layers)} reduces to alpha_ from -1 + alpha_ = -1 + for i in range(num_layers): + alpha_ = (alpha_ + kappa1(alpha_, is_x_matrix=False)) / 2. + + # Points for polynomial fitting contain (1) equi-spaced ones from [alpha_,1] + # and (2) non-equi-spaced ones from [0,1]. For (2), cosine function is used + # where more points are around 1. + num_points = 20 * num_layers + 8 * degree + x_eq = np.linspace(alpha_, 1., num=201) + x_noneq = np.cos((2 * np.arange(num_points) + 1) * np.pi / (4 * num_points)) + xvals = np.sort(np.concatenate((x_eq, x_noneq))) + fvals = kappa0(xvals, is_x_matrix=False) + + # For kappa0, we set all weights to be one. + weights = np.ones(len(fvals), dtype=xvals.dtype) + + # Coefficients can be obtained by solving QP with OSQP jaxopt. kappa0 has a + # sharp slope at x=1, hence we add an equailty condition of p_n(1)=f(x). + coeffs = poly_fitting_qp(xvals, fvals, weights, degree, eq_last_point=True) + return np.where(coeffs < 1e-5, 0.0, coeffs) + + +def kappa1_coeffs(degree: int, num_layers: int): + + # A lower bound of kappa1^{(num_layers)} reduces to alpha_ from -1 + alpha_ = -1 + for i in range(num_layers): + alpha_ = (2. * alpha_ + kappa1(alpha_, is_x_matrix=False)) / 3. + + # Points for polynomial fitting contain (1) equi-spaced ones from [alpha_,1] + # and (2) non-equi-spaced ones from [0,1]. For (2), cosine function is used + # where more points are around 1. + num_points = 15 * num_layers + 5 * degree + x_eq = np.linspace(alpha_, 1., num=201) + x_noneq = np.cos( + (2. * np.arange(num_points) + 1.) * np.pi / (4. * num_points)) + xvals = np.sort(np.concatenate((x_eq, x_noneq))) + fvals = kappa1(xvals, is_x_matrix=False) + + # For kappa1, we set all weights to be one. + weights = np.ones(len(fvals), dtype=xvals.dtype) + + # For kappa1, we consider an equality condition for the last point + # (close to 1) because the slope around 1 is much sharper. 
+ coeffs = poly_fitting_qp(xvals, fvals, weights, degree, eq_last_point=True) + return np.where(coeffs < 1e-5, 0.0, coeffs) + + +def relu_ntk_coeffs(degree: int, num_layers: int): + + num_points = 20 * num_layers + 8 * degree + x_eq = np.linspace(-1, 1., num=201) + x_noneq = np.cos((2 * np.arange(num_points) + 1) * np.pi / (4 * num_points)) + x = np.sort(np.concatenate((x_eq, x_noneq))) + + kappa1s = {} + kappa1s[0] = x + for i in range(num_layers): + kappa1s[i + 1] = kappa1(kappa1s[i], is_x_matrix=False) + + weights = np.linspace(0.0, 1.0, num=len(x)) + 2 / num_layers + nngp_coeffs = poly_fitting_qp(x, kappa1s[num_layers], weights, degree) + nngp_coeffs = np.where(nngp_coeffs < 1e-5, 0.0, nngp_coeffs) + + ntk = np.zeros(len(x), dtype=x.dtype) + for i in range(num_layers + 1): + z = kappa1s[i] + for j in range(i, num_layers): + z *= kappa0(kappa1s[j], is_x_matrix=False) + ntk += z + ntk_coeffs = poly_fitting_qp(x, ntk, weights, degree) + ntk_coeffs = np.where(ntk_coeffs < 1e-5, 0.0, ntk_coeffs) + + return nngp_coeffs, ntk_coeffs diff --git a/experimental/sketching.py b/experimental/sketching.py new file mode 100644 index 00000000..9e34eabc --- /dev/null +++ b/experimental/sketching.py @@ -0,0 +1,165 @@ +from jax import random +from jax import numpy as np +from jax.numpy.fft import fftn +from neural_tangents._src.utils import dataclasses +from typing import Optional, Callable + + +def _random_signs_indices(rngs, input_dim, output_dim, shape=()): + rand_signs = random.bernoulli(rngs[0], shape=shape + (input_dim,)) * 2 - 1. + rand_inds = random.choice(rngs[1], input_dim, shape=shape + (output_dim,)) + return rand_signs, rand_inds + + +# TensorSRHT of degree 2. This version allows different input vectors. +@dataclasses.dataclass +class TensorSRHT: + + input_dim1: int + input_dim2: int + sketch_dim: int + + rng: random.KeyArray + shape: Optional[np.ndarray] = None + + rand_signs1: Optional[np.ndarray] = None + rand_signs2: Optional[np.ndarray] = None + rand_inds1: Optional[np.ndarray] = None + rand_inds2: Optional[np.ndarray] = None + + replace = ... # type: Callable[..., 'TensorSRHT'] + + def init_sketches(self) -> 'TensorSRHT': + rng1, rng2, rng3, rng4 = random.split(self.rng, 4) + rand_signs1, rand_inds1 = _random_signs_indices( + (rng1, rng3), self.input_dim1, self.sketch_dim // 2) + rand_signs2, rand_inds2 = _random_signs_indices( + (rng2, rng4), self.input_dim2, self.sketch_dim // 2) + shape = (self.input_dim1, self.input_dim2, self.sketch_dim) + return self.replace(shape=shape, + rand_signs1=rand_signs1, + rand_signs2=rand_signs2, + rand_inds1=rand_inds1, + rand_inds2=rand_inds2) + + def sketch(self, x1, x2, real_output=False): + x1fft = fftn(x1 * self.rand_signs1[None, :], axes=(-1,))[:, self.rand_inds1] + x2fft = fftn(x2 * self.rand_signs2[None, :], axes=(-1,))[:, self.rand_inds2] + out = (x1fft * x2fft) / self.rand_inds1.shape[-1]**0.5 + return np.concatenate((out.real, out.imag), 1) if real_output else out + + +@dataclasses.dataclass +class PolyTensorSketch: + + input_dim: int + sketch_dim: int + degree: int + + rng: random.KeyArray + + tree_rand_signs: Optional[list] = None + tree_rand_inds: Optional[list] = None + rand_signs: Optional[np.ndarray] = None + rand_inds: Optional[np.ndarray] = None + + replace = ... 
# type: Callable[..., 'PolyTensorSketch'] + + def init_sketches(self) -> 'PolyTensorSketch': + height = (self.degree - 1).bit_length() + tree_rand_signs = [0] * height + tree_rand_inds = [0] * height + rng1, rng3 = random.split(self.rng, 2) + + internal_sketch_dim = self.sketch_dim // 4 - 1 + degree = self.degree // 2 + + for lvl in range(height): + rng1, rng2 = random.split(rng1) + + input_dim = self.input_dim if lvl == 0 else internal_sketch_dim + tree_rand_signs[lvl], tree_rand_inds[lvl] = _random_signs_indices( + (rng1, rng2), input_dim, internal_sketch_dim, (degree, 2)) + + degree = degree // 2 + + rng1, rng2 = random.split(rng3, 2) + rand_signs, rand_inds = _random_signs_indices( + (rng1, rng2), 1 + self.degree * internal_sketch_dim, + self.sketch_dim // 2) + + return self.replace(tree_rand_signs=tree_rand_signs, + tree_rand_inds=tree_rand_inds, + rand_signs=rand_signs, + rand_inds=rand_inds) + + # TensorSRHT of degree 2 + def tensorsrht(self, x1, x2, rand_inds, rand_signs): + x1fft = fftn(x1 * rand_signs[:1, :], axes=(-1,))[:, rand_inds[0, :]] + x2fft = fftn(x2 * rand_signs[1:, :], axes=(-1,))[:, rand_inds[1, :]] + return rand_inds.shape[1]**(-0.5) * (x1fft * x2fft) + + # Standard SRHT + def standardsrht(self, x, rand_inds=None, rand_signs=None): + rand_inds = self.rand_inds if rand_inds is None else rand_inds + rand_signs = self.rand_signs if rand_signs is None else rand_signs + xfft = fftn(x * rand_signs[None, :], axes=(-1,))[:, rand_inds] + return rand_inds.shape[0]**(-0.5) * xfft + + def sketch(self, x): + n = x.shape[0] + dtype = np.complex64 if x.real.dtype == np.float32 else np.complex128 + + height = len(self.tree_rand_signs) + V = [np.zeros(())] * height + + for lvl in range(height): + deg = self.tree_rand_signs[lvl].shape[0] + output_dim = self.tree_rand_inds[lvl].shape[2] + V[lvl] = np.zeros((deg, n, output_dim), dtype=dtype) + for j in range(deg): + if lvl == 0: + x1, x2 = x, x + else: + x1, x2 = V[lvl - 1][2 * j, :, :], V[lvl - 1][2 * j + 1, :, :] + + V[lvl] = V[lvl].at[j, :, :].set( + self.tensorsrht(x1, x2, self.tree_rand_inds[lvl][j, :, :], + self.tree_rand_signs[lvl][j, :, :])) + + U = [np.zeros(())] * 2**height + U[0] = V[-1][0, :, :] + + SetE1 = set() + + for j in range(1, 2**height): + p = (j - 1) // 2 + for lvl in range(height): + if j % (2**(lvl + 1)) == 0: + SetE1.add((lvl, p)) + else: + if lvl == 0: + V[lvl] = V[lvl].at[p, :, :].set( + self.standardsrht(x, self.tree_rand_inds[lvl][p, 0, :], + self.tree_rand_signs[lvl][p, 0, :])) + else: + if (lvl - 1, 2 * p) in SetE1: + V[lvl] = V[lvl].at[p, :, :].set(V[lvl - 1][2 * p + 1, :, :]) + else: + V[lvl] = V[lvl].at[p, :, :].set( + self.tensorsrht(V[lvl - 1][2 * p, :, :], + V[lvl - 1][2 * p + 1, :, :], + self.tree_rand_inds[lvl][p, :, :], + self.tree_rand_signs[lvl][p, :, :])) + p = p // 2 + U[j] = V[-1][0, :, :] + + return U + + def expand_feats(self, sketches, coeffs): + n = sketches[0].shape[0] + degree = len(sketches) + return np.concatenate( + [coeffs[0]**0.5 * np.ones((n, 1))] + + [coeffs[i + 1]**0.5 * sketches[-i - 1] for i in range(degree)], + axis=-1) \ No newline at end of file diff --git a/experimental/tests/__init__.py b/experimental/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/tests/features_test.py b/experimental/tests/features_test.py new file mode 100644 index 00000000..5b4d23d6 --- /dev/null +++ b/experimental/tests/features_test.py @@ -0,0 +1,570 @@ +from absl.testing import absltest +from absl.testing import parameterized +import functools +from jax 
import jit +from jax.config import config +import jax.numpy as np +import jax.random as random +from neural_tangents._src.utils import utils +from neural_tangents import stax +from tests import test_utils + +import experimental.features as ft + +config.parse_flags_with_absl() +config.update('jax_numpy_rank_promotion', 'raise') + +test_utils.update_test_tolerance() + +NUM_DIMS = [128, 256, 512] +WEIGHT_VARIANCES = [0.5, 1.] +BIAS_VARIANCES = [None, 0.1] + + +def _convert_features_to_matrices(f_, channel_axis=-1): + if isinstance(f_, ft.Features): + nngp = _convert_features_to_matrices(f_.nngp_feat, f_.channel_axis) + ntk = _convert_features_to_matrices(f_.ntk_feat, f_.channel_axis) + return nngp, ntk + elif isinstance(f_, np.ndarray): + channel_dim = f_.shape[channel_axis] + feat = np.moveaxis(f_, channel_axis, -1).reshape(-1, channel_dim) + k_mat = feat @ feat.T + if f_.ndim > 2: + k_mat = utils.zip_axes( + k_mat.reshape( + tuple(f_.shape[i] + for i in range(len(f_.shape)) + if i != channel_axis) * 2)) + return k_mat + else: + raise ValueError + + +def _convert_image_feature_to_kernel(feat): + return utils.zip_axes(np.einsum("ijkc,xyzc->ijkxyz", feat, feat)) + + +def _get_init_data(rng, shape, normalized_output=False): + x = random.normal(rng, shape) + if normalized_output: + return x / np.linalg.norm(x, axis=-1, keepdims=True) + else: + return x + + +class FeaturesTest(test_utils.NeuralTangentsTestCase): + + @parameterized.product(W_std=WEIGHT_VARIANCES, + b_std=BIAS_VARIANCES, + n_layers=[1, 2, 3, 4], + do_jit=[True, False]) + def test_dense_features(self, W_std, b_std, n_layers, do_jit): + n, d = 4, 256 + rng = random.PRNGKey(1) + x = _get_init_data(rng, (n, d)) + + dense_args = {'out_dim': 1, 'W_std': W_std, 'b_std': b_std} + + kernel_fn = stax.serial(*[stax.Dense(**dense_args)] * n_layers)[2] + feature_fn = ft.serial(*[ft.DenseFeatures(**dense_args)] * n_layers)[1] + + if do_jit: + kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + + k = kernel_fn(x, None) + f = feature_fn(x, [()] * n_layers) + + self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T) + self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T) + + @parameterized.product( + W_std=WEIGHT_VARIANCES, + b_std=BIAS_VARIANCES, + n_layers=[1, 2, 3, 4], + relu_method=['RANDFEAT', 'POLYSKETCH', 'PSRF', 'POLY', 'EXACT'], + do_jit=[True, False]) + def test_fc_relu_nngp_ntk(self, W_std, b_std, n_layers, relu_method, do_jit): + rng = random.PRNGKey(1) + n, d = 4, 256 + x = _get_init_data(rng, (n, d)) + + dense_args = {"out_dim": 1, "W_std": W_std, "b_std": b_std} + relu_args = {'method': relu_method} + if relu_method == 'RANDFEAT': + relu_args['feature_dim0'] = 4096 + relu_args['feature_dim1'] = 4096 + relu_args['sketch_dim'] = 4096 + elif relu_method == 'POLYSKETCH': + relu_args['poly_degree'] = 4 + relu_args['poly_sketch_dim'] = 4096 + relu_args['sketch_dim'] = 4096 + elif relu_method == 'PSRF': + relu_args['feature_dim0'] = 4096 + relu_args['poly_degree'] = 4 + relu_args['poly_sketch_dim'] = 4096 + relu_args['sketch_dim'] = 4096 + elif relu_method == 'POLY': + relu_args['poly_degree'] = 16 + elif relu_method == 'EXACT': + pass + else: + raise ValueError(relu_method) + + kernel_fn = stax.serial( + *[stax.Dense(**dense_args), stax.Relu()] * n_layers + + [stax.Dense(**dense_args)])[2] + init_fn, feature_fn = ft.serial( + *[ft.DenseFeatures(**dense_args), + ft.ReluFeatures(**relu_args)] * n_layers + + [ft.DenseFeatures(**dense_args)]) + + rng2 = random.PRNGKey(2) + _, feat_fn_inputs = init_fn(rng2, x.shape) + + if do_jit: + 
kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + + k = kernel_fn(x, None) + f = feature_fn(x, feat_fn_inputs) + + if np.iscomplexobj(f.nngp_feat) or np.iscomplexobj(f.ntk_feat): + nngp_feat = np.concatenate((f.nngp_feat.real, f.nngp_feat.imag), axis=-1) + ntk_feat = np.concatenate((f.ntk_feat.real, f.ntk_feat.imag), axis=-1) + f = f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + k_nngp_approx = f.nngp_feat @ f.nngp_feat.T + k_ntk_approx = f.ntk_feat @ f.ntk_feat.T + + if relu_method == 'EXACT': + self.assertAllClose(k.nngp, k_nngp_approx) + self.assertAllClose(k.ntk, k_ntk_approx) + else: + test_utils.assert_close_matrices(self, k.nngp, k_nngp_approx, 0.2, 1.) + test_utils.assert_close_matrices(self, k.ntk, k_ntk_approx, 0.2, 1.) + + @parameterized.product(W_std=WEIGHT_VARIANCES, + b_std=BIAS_VARIANCES, + n_layers=[1, 2, 3, 4], + do_jit=[True, False]) + def test_conv_features(self, W_std, b_std, n_layers, do_jit): + n, h, w, c = 3, 4, 5, 2 + rng = random.PRNGKey(1) + x = _get_init_data(rng, (n, h, w, c)) + + conv_args = { + 'out_chan': 1, + 'filter_shape': (3, 3), + 'padding': 'SAME', + 'W_std': W_std, + 'b_std': b_std + } + + kernel_fn = stax.serial(*[stax.Conv(**conv_args)] * n_layers)[2] + feature_fn = ft.serial(*[ft.ConvFeatures(**conv_args)] * n_layers)[1] + + if do_jit: + kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + + k = kernel_fn(x) + f = feature_fn(x, [()] * n_layers) + + if k.is_reversed: + nngp_feat = np.moveaxis(f.nngp_feat, 1, 2) + ntk_feat = np.moveaxis(f.ntk_feat, 1, 2) + f = f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + k_nngp_approx = _convert_image_feature_to_kernel(f.nngp_feat) + k_ntk_approx = _convert_image_feature_to_kernel(f.ntk_feat) + + self.assertAllClose(k.nngp, k_nngp_approx) + self.assertAllClose(k.ntk, k_ntk_approx) + + @parameterized.product(n_layers=[1, 2, 3, 4], do_jit=[True, False]) + def test_avgpool_features(self, n_layers, do_jit): + n, h, w, c = 3, 32, 28, 2 + rng = random.PRNGKey(1) + x = _get_init_data(rng, (n, h, w, c)) + + avgpool_args = { + 'window_shape': (2, 2), + 'strides': (2, 2), + 'padding': 'SAME' + } + + kernel_fn = stax.serial(*[stax.AvgPool(**avgpool_args)] * n_layers + + [stax.Flatten()])[2] + feature_fn = ft.serial(*[ft.AvgPoolFeatures(**avgpool_args)] * n_layers + + [ft.FlattenFeatures()])[1] + + if do_jit: + kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + + k = kernel_fn(x) + f = feature_fn(x, [()] * (n_layers + 1)) + + k_nngp_approx, k_ntk_approx = _convert_features_to_matrices(f) + + self.assertAllClose(k.nngp, k_nngp_approx) + if k.ntk.ndim > 0: + self.assertAllClose(k.ntk, k_ntk_approx) + + @parameterized.parameters([{ + 'ndim': nd, + 'do_jit': do_jit + } for nd in [2, 3, 4] for do_jit in [True, False]]) + def test_flatten_features(self, ndim, do_jit): + key = random.PRNGKey(1) + n, h, w, c = 4, 8, 6, 5 + width = 1 + W_std = 1.7 + b_std = 0.1 + if ndim == 2: + input_shape = (n, h * w * c) + elif ndim == 3: + input_shape = (n, h * w, c) + elif ndim == 4: + input_shape = (n, h, w, c) + else: + raise absltest.SkipTest() + + x = random.normal(key, input_shape) + + dense_kernel = stax.Dense(width, W_std=W_std, b_std=b_std) + dense_feature = ft.DenseFeatures(width, W_std=W_std, b_std=b_std) + + relu_kernel = stax.Relu() + relu_feature = ft.ReluFeatures(method='EXACT') + + kernel_fc = stax.serial(dense_kernel, relu_kernel, dense_kernel)[2] + kernel_top = stax.serial(dense_kernel, relu_kernel, dense_kernel, + stax.Flatten())[2] + kernel_mid = stax.serial(dense_kernel, relu_kernel, 
stax.Flatten(), + dense_kernel)[2] + kernel_bot = stax.serial(stax.Flatten(), dense_kernel, relu_kernel, + dense_kernel)[2] + + feature_fc = ft.serial(dense_feature, relu_feature, dense_feature)[1] + feature_top = ft.serial(dense_feature, relu_feature, dense_feature, + ft.FlattenFeatures())[1] + feature_mid = ft.serial(dense_feature, relu_feature, ft.FlattenFeatures(), + dense_feature)[1] + feature_bot = ft.serial(ft.FlattenFeatures(), dense_feature, relu_feature, + dense_feature)[1] + + if do_jit: + kernel_fc = jit(kernel_fc) + kernel_top = jit(kernel_top) + kernel_mid = jit(kernel_mid) + kernel_bot = jit(kernel_bot) + + feature_fc = jit(feature_fc) + feature_top = jit(feature_top) + feature_mid = jit(feature_mid) + feature_bot = jit(feature_bot) + + k_fc = kernel_fc(x) + f_fc = feature_fc(x, [()] * 3) + nngp_fc, ntk_fc = _convert_features_to_matrices(f_fc) + self.assertAllClose(k_fc.nngp, nngp_fc) + self.assertAllClose(k_fc.ntk, ntk_fc) + + k_top = kernel_top(x) + f_top = feature_top(x, [()] * 4) + nngp_top, ntk_top = _convert_features_to_matrices(f_top) + self.assertAllClose(k_top.nngp, nngp_top) + self.assertAllClose(k_top.ntk, ntk_top) + + k_mid = kernel_mid(x) + f_mid = feature_mid(x, [()] * 4) + nngp_mid, ntk_mid = _convert_features_to_matrices(f_mid) + self.assertAllClose(k_mid.nngp, nngp_mid) + self.assertAllClose(k_mid.ntk, ntk_mid) + + k_bot = kernel_bot(x) + f_bot = feature_bot(x, [()] * 4) + nngp_bot, ntk_bot = _convert_features_to_matrices(f_bot) + self.assertAllClose(k_bot.nngp, nngp_bot) + self.assertAllClose(k_bot.ntk, ntk_bot) + + @parameterized.product(ndim=[2, 3, 4], + channel_axis=[1, 2, 3], + n_layers=[1, 2, 3, 4], + use_conv=[True, False], + use_layernorm=[True, False], + do_pool=[True, False], + do_jit=[True, False]) + def test_channel_axis(self, ndim, channel_axis, use_conv, n_layers, + use_layernorm, do_pool, do_jit): + n, h, w, c = 4, 8, 6, 5 + W_std = 1.7 + b_std = 0.1 + key = random.PRNGKey(1) + + if ndim == 2: + if channel_axis != 1: + raise absltest.SkipTest() + input_shape = (n, h * w * c) + elif ndim == 3: + if channel_axis == 1: + input_shape = (n, c, h * w) + elif channel_axis == 2: + input_shape = (n, h * w, c) + else: + raise absltest.SkipTest() + elif ndim == 4: + if channel_axis == 1: + input_shape = (n, c, h, w) + dn = ('NCAB', 'ABIO', 'NCAB') + elif channel_axis == 3: + input_shape = (n, h, w, c) + dn = ('NABC', 'ABIO', 'NABC') + else: + raise absltest.SkipTest() + + x = random.normal(key, input_shape) + + if use_conv: + if ndim != 4: + raise absltest.SkipTest() + else: + linear = stax.Conv(1, (3, 3), (1, 1), + 'SAME', + W_std=W_std, + b_std=b_std, + dimension_numbers=dn) + linear_feat = ft.ConvFeatures(1, (3, 3), (1, 1), + W_std=W_std, + b_std=b_std, + dimension_numbers=dn) + else: + linear = stax.Dense(1, + W_std=W_std, + b_std=b_std, + channel_axis=channel_axis) + linear_feat = ft.DenseFeatures(1, + W_std=W_std, + b_std=b_std, + channel_axis=channel_axis) + + layers = [linear, stax.Relu()] * n_layers + layers += [linear] + layers += [stax.LayerNorm(channel_axis, channel_axis=channel_axis) + ] if use_layernorm else [] + layers += [stax.GlobalAvgPool( + channel_axis=channel_axis)] if do_pool else [stax.Flatten()] + kernel_fn = stax.serial(*layers)[2] + + layers = [ + linear_feat, + ft.ReluFeatures(method='EXACT', channel_axis=channel_axis) + ] * n_layers + layers += [linear_feat] + layers += [ft.LayerNormFeatures(channel_axis, channel_axis=channel_axis) + ] if use_layernorm else [] + layers += [ft.GlobalAvgPoolFeatures( + 
+        channel_axis=channel_axis)] if do_pool else [ft.FlattenFeatures()]
+    feature_fn = ft.serial(*layers)[1]
+
+    if do_jit:
+      kernel_fn = jit(kernel_fn)
+      feature_fn = jit(feature_fn)
+
+    k = kernel_fn(x)
+    f = feature_fn(x, [()] * len(layers))
+    nngp, ntk = _convert_features_to_matrices(f)
+    self.assertAllClose(k.nngp, nngp)
+    self.assertAllClose(k.ntk, ntk)
+
+  @parameterized.product(
+      channel_axis=[1, 3],
+      W_std=WEIGHT_VARIANCES,
+      b_std=BIAS_VARIANCES,
+      relu_method=['RANDFEAT', 'POLYSKETCH', 'PSRF', 'POLY', 'EXACT'],
+      depth=[5],
+      do_jit=[True, False])
+  def test_myrtle_network(self, channel_axis, W_std, b_std, relu_method, depth,
+                          do_jit):
+    n, h, w, c = 2, 32, 32, 3
+    rng = random.PRNGKey(1)
+    if channel_axis == 1:
+      x = _get_init_data(rng, (n, c, h, w))
+      dn = ('NCAB', 'ABIO', 'NCAB')
+    elif channel_axis == 3:
+      x = _get_init_data(rng, (n, h, w, c))
+      dn = ('NABC', 'ABIO', 'NABC')
+
+    layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
+
+    def _get_myrtle_kernel_fn():
+      conv = functools.partial(stax.Conv,
+                               W_std=W_std,
+                               b_std=b_std,
+                               padding='SAME',
+                               dimension_numbers=dn)
+      layers = []
+      layers += [conv(1, (3, 3)), stax.Relu()] * layer_factor[depth][0]
+      layers += [
+          stax.AvgPool((2, 2), strides=(2, 2), channel_axis=channel_axis)
+      ]
+      layers += [conv(1, (3, 3)), stax.Relu()] * layer_factor[depth][1]
+      layers += [
+          stax.AvgPool((2, 2), strides=(2, 2), channel_axis=channel_axis)
+      ]
+      layers += [conv(1, (3, 3)), stax.Relu()] * layer_factor[depth][2]
+      layers += [
+          stax.AvgPool((2, 2), strides=(2, 2), channel_axis=channel_axis)
+      ] * 3
+      layers += [stax.Flatten(), stax.Dense(1, W_std=W_std, b_std=b_std)]
+
+      return stax.serial(*layers)
+
+    def _get_myrtle_feature_fn(**relu_args):
+      conv = functools.partial(ft.ConvFeatures,
+                               W_std=W_std,
+                               b_std=b_std,
+                               padding='SAME',
+                               dimension_numbers=dn)
+      layers = []
+      layers += [
+          conv(1, (3, 3)),
+          ft.ReluFeatures(channel_axis=channel_axis, **relu_args)
+      ] * layer_factor[depth][0]
+      layers += [
+          ft.AvgPoolFeatures((2, 2), strides=(2, 2), channel_axis=channel_axis)
+      ]
+      layers += [
+          conv(1, (3, 3)),
+          ft.ReluFeatures(channel_axis=channel_axis, **relu_args)
+      ] * layer_factor[depth][1]
+      layers += [
+          ft.AvgPoolFeatures((2, 2), strides=(2, 2), channel_axis=channel_axis)
+      ]
+      layers += [
+          conv(1, (3, 3)),
+          ft.ReluFeatures(channel_axis=channel_axis, **relu_args)
+      ] * layer_factor[depth][2]
+      layers += [
+          ft.AvgPoolFeatures((2, 2), strides=(2, 2), channel_axis=channel_axis)
+      ] * 3
+      layers += [
+          ft.FlattenFeatures(),
+          ft.DenseFeatures(1, W_std=W_std, b_std=b_std)
+      ]
+
+      return ft.serial(*layers)
+
+    kernel_fn = _get_myrtle_kernel_fn()[2]
+
+    relu_args = {'method': relu_method}
+    if relu_method == 'RANDFEAT':
+      relu_args['feature_dim0'] = 2048
+      relu_args['feature_dim1'] = 2048
+      relu_args['sketch_dim'] = 2048
+    elif relu_method == 'POLYSKETCH':
+      relu_args['poly_degree'] = 4
+      relu_args['poly_sketch_dim'] = 2048
+      relu_args['sketch_dim'] = 2048
+    elif relu_method == 'PSRF':
+      relu_args['feature_dim0'] = 2048
+      relu_args['poly_degree'] = 4
+      relu_args['poly_sketch_dim'] = 2048
+      relu_args['sketch_dim'] = 2048
+    elif relu_method == 'POLY':
+      relu_args['poly_degree'] = 16
+    elif relu_method == 'EXACT':
+      pass
+    else:
+      raise ValueError(relu_method)
+
+    init_fn, feature_fn = _get_myrtle_feature_fn(**relu_args)
+
+    if do_jit:
+      kernel_fn = jit(kernel_fn)
+      feature_fn = jit(feature_fn)
+
+    k = kernel_fn(x)
+    k_nngp = k.nngp
+    k_ntk = k.ntk
+
+    _, feat_fn_inputs = init_fn(rng, x.shape)
+    f = feature_fn(x, feat_fn_inputs)
+    if np.iscomplexobj(f.nngp_feat) or np.iscomplexobj(f.ntk_feat):
+      nngp_feat = np.concatenate((f.nngp_feat.real, f.nngp_feat.imag), axis=-1)
+      ntk_feat = np.concatenate((f.ntk_feat.real, f.ntk_feat.imag), axis=-1)
+      f = f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat)
+
+    k_nngp_approx = f.nngp_feat @ f.nngp_feat.T
+    k_ntk_approx = f.ntk_feat @ f.ntk_feat.T
+
+    if relu_method == 'EXACT':
+      self.assertAllClose(k_nngp, k_nngp_approx)
+      self.assertAllClose(k_ntk, k_ntk_approx)
+    else:
+      test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.2, 1.)
+      test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.2, 1.)
+
+  def test_aggregate_features(self):
+    rng = random.PRNGKey(1)
+    rng1, rng2 = random.split(rng, 2)
+
+    batch_size = 4
+    num_channels = 3
+    shape = (5,)
+    width = 1
+
+    x = random.normal(rng1, (batch_size,) + shape + (num_channels,))
+    pattern = random.uniform(rng2, (batch_size,) + shape * 2)
+
+    kernel_fn = stax.serial(stax.Dense(width, W_std=2**0.5), stax.Relu(),
+                            stax.Aggregate(), stax.GlobalAvgPool(),
+                            stax.Dense(width))[2]
+
+    k = jit(kernel_fn)(x, None, pattern=(pattern, pattern))
+
+    feature_fn = ft.serial(ft.DenseFeatures(width, W_std=2**0.5),
+                           ft.ReluFeatures(method='EXACT'),
+                           ft.AggregateFeatures(), ft.GlobalAvgPoolFeatures(),
+                           ft.DenseFeatures(width))[1]
+
+    f = feature_fn(x, [()] * 5, **{'pattern': pattern})
+    self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T)
+    self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T)
+
+  @parameterized.product(n_layers=[1, 2, 3, 4, 5], do_jit=[True, False])
+  def test_onepass_fc_relu_nngp_ntk(self, n_layers, do_jit):
+    rng = random.PRNGKey(1)
+    n, d = 4, 256
+    x = _get_init_data(rng, (n, d))
+
+    kernel_fn = stax.serial(*[stax.Dense(1), stax.Relu()] * n_layers +
+                            [stax.Dense(1)])[2]
+
+    poly_degree = 8
+    poly_sketch_dim = 4096
+
+    init_fn, feature_fn = ft.ReluNTKFeatures(n_layers, poly_degree,
+                                             poly_sketch_dim)
+
+    rng2 = random.PRNGKey(2)
+    _, feat_fn_inputs = init_fn(rng2, x.shape)
+
+    if do_jit:
+      kernel_fn = jit(kernel_fn)
+      feature_fn = jit(feature_fn)
+
+    k = kernel_fn(x)
+    f = feature_fn(x, feat_fn_inputs)
+
+    k_nngp_approx = f.nngp_feat @ f.nngp_feat.T
+    k_ntk_approx = f.ntk_feat @ f.ntk_feat.T
+
+    test_utils.assert_close_matrices(self, k.nngp, k_nngp_approx, 0.2, 1.)
+    test_utils.assert_close_matrices(self, k.ntk, k_ntk_approx, 0.2, 1.)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/experimental/tests/sketching_test.py b/experimental/tests/sketching_test.py
new file mode 100644
index 00000000..16322403
--- /dev/null
+++ b/experimental/tests/sketching_test.py
@@ -0,0 +1,55 @@
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import jax.numpy as np
+from math import factorial
+import jax.random as random
+from experimental.sketching import PolyTensorSketch
+from tests import test_utils
+
+NUM_POINTS = [10, 100, 1000]
+NUM_DIMS = [64, 256, 1024]
+
+
+class SketchingTest(test_utils.NeuralTangentsTestCase):
+
+  @classmethod
+  def _get_init_data(cls, rng, shape, normalized_output=True):
+    x = random.normal(rng, shape)
+    if normalized_output:
+      return x / np.linalg.norm(x, axis=-1, keepdims=True)
+    else:
+      return x
+
+  @parameterized.parameters({
+      'n': n,
+      'd': d,
+      'sketch_dim': 1024,
+      'degree': 16
+  } for n in NUM_POINTS for d in NUM_DIMS)
+  def test_exponential_kernel(self, n, d, sketch_dim, degree):
+    rng = random.PRNGKey(1)
+    x = self._get_init_data(rng, (n, d), True)
+
+    coeffs = np.asarray([1 / factorial(i) for i in range(degree)])
+
+    rng2 = random.PRNGKey(2)
+    pts = PolyTensorSketch(rng=rng2,
+                           input_dim=d,
+                           sketch_dim=sketch_dim,
+                           degree=degree).init_sketches()  # pytype:disable=wrong-keyword-args
+
+    x_sketches = pts.sketch(x)
+
+    z = pts.expand_feats(x_sketches, coeffs)
+    z = pts.standardsrht(z)
+    z = np.concatenate((z.real, z.imag), axis=-1)
+
+    k_exact = np.polyval(coeffs[::-1], x @ x.T)
+    k_approx = z @ z.T
+
+    test_utils.assert_close_matrices(self, k_exact, k_approx, 0.15, 1.)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/setup.py b/setup.py
index 6861b15e..a376cd63 100644
--- a/setup.py
+++ b/setup.py
@@ -29,6 +29,7 @@
     'jax>=0.3.13',
     'frozendict>=2.3',
     'typing_extensions>=4.0.1',
+    'jaxopt>=0.3.1',
     'tf2jax>=0.3.0',
 ]