From c64f307f7545d47e076a272b20e93e7dff6c425b Mon Sep 17 00:00:00 2001
From: insuhan
Date: Thu, 24 Feb 2022 13:59:11 +0900
Subject: [PATCH 01/44] Add NTK Random Features and Sketching codes

---
 experimental/README.md                    | 140 ++++++++
 .../__pycache__/features.cpython-37.pyc   | Bin 0 -> 9881 bytes
 .../__pycache__/sketching.cpython-37.pyc  | Bin 0 -> 4416 bytes
 experimental/features.py                  | 327 ++++++++++++++++++
 experimental/sketching.py                 | 128 +++++++
 experimental/test_fc_ntk.py               | 106 ++++++
 experimental/test_myrtle_network.py       | 117 +++++++
 7 files changed, 818 insertions(+)
 create mode 100644 experimental/README.md
 create mode 100644 experimental/__pycache__/features.cpython-37.pyc
 create mode 100644 experimental/__pycache__/sketching.cpython-37.pyc
 create mode 100644 experimental/features.py
 create mode 100644 experimental/sketching.py
 create mode 100644 experimental/test_fc_ntk.py
 create mode 100644 experimental/test_myrtle_network.py

diff --git a/experimental/README.md b/experimental/README.md
new file mode 100644
index 00000000..9eb1064a
--- /dev/null
+++ b/experimental/README.md
@@ -0,0 +1,140 @@
+# Efficient Feature Map of Neural Tangent Kernels via Sketching and Random Features
+
+Implementations developed in [[1]](#1-scaling-neural-tangent-kernels-via-sketching-and-random-features). The library is written for users familiar with the [JAX](https://github.com/google/jax) and [Neural Tangents](https://github.com/google/neural-tangents) libraries. The code is compatible with NT v0.5.0.
+
+[PyTorch](https://pytorch.org/) implementations can be found [here](https://github.com/insuhan/ntk-sketch-rf).
+
+
+## Examples
+
+### Fully-connected NTK approximation via Random Features:
+
+```python
+from jax import random
+from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
+
+relufeat_arg = {
+    'feature_dim0': 128,
+    'feature_dim1': 128,
+    'sketch_dim': 256,
+    'method': 'rf',
+}
+
+init_fn, _, feature_fn = serial(
+    DenseFeatures(512), ReluFeatures(**relufeat_arg),
+    DenseFeatures(512), ReluFeatures(**relufeat_arg),
+    DenseFeatures(1)
+)
+
+key1, key2 = random.split(random.PRNGKey(1))
+x = random.normal(key1, (5, 4))
+
+initial_nngp_feat_shape = x.shape
+initial_ntk_feat_shape = (-1, 0)
+initial_feat_shape = (initial_nngp_feat_shape, initial_ntk_feat_shape)
+
+_, feat_fn_inputs = init_fn(key2, initial_feat_shape)
+feats = feature_fn(_inputs_to_features(x), feat_fn_inputs)
+# feats.nngp_feat is a feature map of the NNGP kernel
+# feats.ntk_feat is a feature map of the NTK
+```
+For more details on fully-connected NTK features, please check `test_fc_ntk.py`.
+
+### Convolutional NTK approximation via Random Features:
+```python
+init_fn, _, feature_fn = serial(
+    ConvFeatures(512, filter_size=3), ReluFeatures(**relufeat_arg),
+    AvgPoolFeatures(2, 2), FlattenFeatures()
+)
+
+n, H, W, C = 5, 8, 8, 3
+x = random.normal(key1, shape=(n, H, W, C))
+
+_, feat_fn_inputs = init_fn(key2, (x.shape, (-1, 0)))
+feats = feature_fn(_inputs_to_features(x), feat_fn_inputs)
+# feats.nngp_feat is a feature map of the NNGP kernel
+# feats.ntk_feat is a feature map of the NTK
+```
+For more complex CNTK features, please check `test_myrtle_network.py`.
+
+# Modules
+
+All modules return a triple of functions `(init_fn, apply_fn, feature_fn)`. Instead of the kernel function `kernel_fn` in the [Neural Tangents](https://github.com/google/neural-tangents) library, we replace it with the feature map function `feature_fn`.
+
+- `init_fn` takes (1) a random seed and (2) a pair of shapes of the input features for both NNGP and NTK. It returns (1) a pair of shapes of the output features and (2) parameters used for approximating the features (e.g., random vectors for the Random Features approach).
+- `apply_fn` does nothing (a dummy function kept for interface compatibility).
+- `feature_fn` takes (1) a feature structure `features.Feature` and (2) parameters used for the feature approximation (initialized by `init_fn`). It returns a `features.Feature` containing the approximate features of the corresponding module.
+
+
+## [`features.DenseFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L88)
+`features.DenseFeatures` provides features for a fully-connected dense layer and corresponds to the `stax.Dense` module in [Neural Tangents](https://github.com/google/neural-tangents). We assume that the input is a tabular dataset (i.e., an n-by-d matrix). Its `feature_fn` updates the NTK features by concatenating the NNGP features onto the NTK features. This is because `stax.Dense` obtains the new NTK kernel matrix by adding the previous NNGP kernel matrix to the previous NTK kernel matrix, and in feature space this sum corresponds to a concatenation. The features of the dense layer are exact; no approximation is applied.
+```python
+import numpy as np
+
+width = 1
+x = random.normal(key1, shape=(3, 2))
+_, _, kernel_fn = stax.Dense(width)
+nt_kernel = kernel_fn(x)
+
+_, _, feat_fn = DenseFeatures(width)
+feat = feat_fn(_inputs_to_features(x), ())
+
+assert np.linalg.norm(nt_kernel.ntk - feat.ntk_feat @ feat.ntk_feat.T) <= 1e-12
+assert np.linalg.norm(nt_kernel.nngp - feat.nngp_feat @ feat.nngp_feat.T) <= 1e-12
+```
+
+## [`features.ReluFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L119)
+`features.ReluFeatures` is a key module of the NTK approximation. We implement feature approximations based on (1) Random Features of arc-cosine kernels [[2]](#2) and (2) the Polynomial Sketch [[3]](#3). Parameters used for the feature approximation are initialized in `init_fn`. We support both tabular and image datasets. For tabular datasets, the input features form an `N x D` matrix and the approximations are applied to the `D`-dimensional rows.
+
+For image datasets, the inputs are 4-D tensors of shape `N x H x W x D`, where `N` is the batch size, `H` is the image height, `W` is the image width and `D` is the feature dimension. We reshape the image features into a 2-D tensor of shape `NHW x D`, apply the feature approximation, and then reshape the result back to a 4-D tensor of shape `N x H x W x D'`, where `D'` is the output dimension of the feature approximation.
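+
+As a rough sketch of this reshaping step (ours, for illustration only; not part of the library API), the per-pixel random-feature map for the order-1 arc-cosine kernel can be written as follows, where the shapes and the random matrix `W1` are hypothetical stand-ins for the parameters that `init_fn` creates:
+```python
+from jax import random
+from jax import numpy as np
+
+key_w, key_x = random.split(random.PRNGKey(0))
+N, H, W, D, D1 = 2, 4, 4, 8, 16
+feat = random.normal(key_x, (N, H, W, D))  # 4-D image features
+W1 = random.normal(key_w, (D, D1))         # hypothetical random directions
+
+feat_2d = feat.reshape(-1, D)                       # (N*H*W, D)
+out_2d = np.maximum(feat_2d @ W1, 0) / np.sqrt(D1)  # per-pixel ReLU features
+out = out_2d.reshape(N, H, W, D1)                   # back to (N, H, W, D')
+```
+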
+To use the Random Features approach, set the parameter `method` to `'rf'` (the default), e.g.,
+```python
+x = random.normal(key1, shape=(3, 32))
+
+init_fn, _, feat_fn = serial(
+    DenseFeatures(1),
+    ReluFeatures(method='rf', feature_dim0=10, feature_dim1=20, sketch_dim=30)
+)
+
+_, params = init_fn(key1, (x.shape, (-1, 0)))
+
+out_feat = feat_fn(_inputs_to_features(x), params)
+
+assert out_feat.nngp_feat.shape == (3, 20)
+assert out_feat.ntk_feat.shape == (3, 30)
+```
+
+To use the exact feature map (based on the Cholesky decomposition), set the parameter `method` to `'exact'`, e.g.,
+```python
+init_fn, _, feat_fn = serial(DenseFeatures(1), ReluFeatures(method='exact'))
+_, params = init_fn(key1, (x.shape, (-1, 0)))
+out_feat = feat_fn(_inputs_to_features(x), params)
+
+assert out_feat.nngp_feat.shape == (3, 3)
+assert out_feat.ntk_feat.shape == (3, 3)
+```
+(This mode is for debugging: the dimension of the exact feature map equals the number of inputs, i.e., `N` for tabular datasets and `NHW` for image datasets.)
+
+
+## [`features.ConvFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L236)
+
+`features.ConvFeatures` is similar to `features.DenseFeatures` in that it updates the NTK features of the next layer by concatenating the NNGP and NTK features of the previous one, but it additionally requires kernel pooling operations. Precisely, [[4]](#4) showed that computing the NNGP/NTK kernel matrices requires the trace of a submatrix of size `stride_size`, which can be seen as a convolution with an identity matrix of size `stride_size`. On the feature side, this can instead be done by concatenating shifted copies of the features, so the resulting feature dimension becomes `stride_size` times larger. Moreover, since images are 2-D, the kernel pooling is applied along both spatial axes, hence the output features have shape `N x H x W x (d * s**2)`, where `s` is the stride size and `d` is the input feature dimension (see the 1-D sketch below).
+
+To be updated.
+
+
+## [`features.AvgPoolFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L269)
+
+To be updated.
+
+## [`features.FlattenFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L304)
+
+To be updated.
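+
+As a toy illustration of the shifted-feature trick used by `ConvFeatures` (our sketch, not library code; a 1-D analogue of the `conv_feat` helper in `features.py`), concatenating shifted copies of the features reproduces the window-wise sum of inner products that the kernel-side convolution computes:
+```python
+import numpy as onp
+
+def shift_concat_1d(x, filter_size=3):
+  # x: (length, dim) -> (length, dim * filter_size), zero-padded at the borders.
+  shifts = []
+  for offset in range(-(filter_size // 2), filter_size // 2 + 1):
+    shifted = onp.zeros_like(x)
+    if offset < 0:
+      shifted[:offset] = x[-offset:]
+    elif offset > 0:
+      shifted[offset:] = x[:-offset]
+    else:
+      shifted = x.copy()
+    shifts.append(shifted)
+  return onp.concatenate(shifts, axis=-1)
+
+x1, x2 = onp.random.randn(8, 4), onp.random.randn(8, 4)
+f1, f2 = shift_concat_1d(x1), shift_concat_1d(x2)
+p = 4  # an interior position
+assert onp.allclose(f1[p] @ f2[p],
+                    sum(x1[q] @ x2[q] for q in range(p - 1, p + 2)))
+```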
+
+## References
+#### [1] [Scaling Neural Tangent Kernels via Sketching and Random Features](https://arxiv.org/pdf/2106.07880.pdf)
+#### [2] [Kernel methods for deep learning](https://cseweb.ucsd.edu/~saul/papers/nips09_kernel.pdf)
+#### [3] [Oblivious Sketching of High-Degree Polynomial Kernels](https://arxiv.org/pdf/1909.01410.pdf)
+#### [4] [On Exact Computation with an Infinitely Wide Neural Net](https://arxiv.org/pdf/1904.11955.pdf)
+

diff --git a/experimental/__pycache__/features.cpython-37.pyc b/experimental/__pycache__/features.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d56a05bf8ef0223c89e62ee86d2c15dee52959c
GIT binary patch
literal 9881
[binary payload omitted]

literal 0
HcmV?d00001

diff --git a/experimental/__pycache__/sketching.cpython-37.pyc b/experimental/__pycache__/sketching.cpython-37.pyc
new file mode 100644
GIT binary patch
literal 4416
[binary payload omitted]

literal 0
HcmV?d00001

diff --git a/experimental/features.py b/experimental/features.py
new file mode 100644
index 00000000..bdd852cc
--- /dev/null
+++ b/experimental/features.py
@@ -0,0 +1,327 @@
+from jax import random
+from jax import numpy as np
+from jax.numpy.linalg import cholesky
+
+import jax.example_libraries.stax as ostax
+import neural_tangents
+from neural_tangents import stax
+from pkg_resources import parse_version
+if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
+  from neural_tangents._src.utils import utils, dataclasses
+  from neural_tangents._src.stax.linear import _pool_kernel, Padding
+  from neural_tangents._src.stax.linear import _Pooling as Pooling
+else:
+  from neural_tangents.utils import utils, dataclasses
+  from neural_tangents.stax import _pool_kernel, Padding, Pooling
+
+from sketching import TensorSRHT2, PolyTensorSRHT
+"""Implementation for NTK Sketching and Random Features"""
+
+
+def _prod(tuple_):
+  prod = 1
+  for x in tuple_:
+    prod = prod * x
+  return prod
+
+
+# The arc-cosine kernel functions below are for debugging.
+def _arccos(x):
+  return np.arccos(np.clip(x, -1, 1))
+
+
+def _sqrt(x):
+  return np.sqrt(np.maximum(x, 1e-20))
+
+
+def kappa0(x):
+  xxt = x @ x.T
+  prod = np.outer(np.linalg.norm(x, axis=-1)**2, np.linalg.norm(x, axis=-1)**2)
+  return (1 - _arccos(xxt / _sqrt(prod)) / np.pi) / 2
+
+
+def kappa1(x):
+  xxt = x @ x.T
+  prod = np.outer(np.linalg.norm(x, axis=-1)**2, np.linalg.norm(x, axis=-1)**2)
+  return (_sqrt(prod - xxt**2) +
+          (np.pi - _arccos(xxt / _sqrt(prod))) * xxt) / np.pi / 2
+
+
+@dataclasses.dataclass
+class Features:
+  nngp_feat: np.ndarray
+  ntk_feat: np.ndarray
+
+  batch_axis: int = dataclasses.field(pytree_node=False)
+  channel_axis: int = dataclasses.field(pytree_node=False)
+
+  replace = ...  # type: Callable[..., 'Features']
+
+
+def _inputs_to_features(x: np.ndarray,
+                        batch_axis: int = 0,
+                        channel_axis: int = -1,
+                        **kwargs) -> Features:
+  """Transforms (batches of) inputs to a `Features`."""
+
+  # Followed the same initialization of Neural Tangents library.
+  nngp_feat = x / x.shape[channel_axis]**0.5
+  ntk_feat = np.empty((), dtype=nngp_feat.dtype)
+
+  return Features(nngp_feat=nngp_feat,
+                  ntk_feat=ntk_feat,
+                  batch_axis=batch_axis,
+                  channel_axis=channel_axis)
+
+
+# Modified the serial process of feature map blocks.
+# Followed https://github.com/google/neural-tangents/blob/main/neural_tangents/stax.py +def serial(*layers): + init_fns, apply_fns, feature_fns = zip(*layers) + init_fn, apply_fn = ostax.serial(*zip(init_fns, apply_fns)) + + # import time + def feature_fn(k, inputs, **kwargs): + for f, input_ in zip(feature_fns, inputs): + # print(f) + # tic = time.time() + k = f(k, input_, **kwargs) + # print(f"toc: {time.time() - tic:.2f} sec") + return k + + return init_fn, apply_fn, feature_fn + + +def DenseFeatures(out_dim: int, + W_std: float = 1., + b_std: float = None, + parameterization: str = 'ntk', + batch_axis: int = 0, + channel_axis: int = -1): + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + new_ntk_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_shape[-1] + + ntk_feat_shape[-1],) + return (nngp_feat_shape, new_ntk_feat_shape), () + + def apply_fn(**kwargs): + return None + + def kernel_fn(f: Features, input, **kwargs): + nngp_feat, ntk_feat = f.nngp_feat, f.ntk_feat + nngp_feat *= W_std + ntk_feat *= W_std + + if ntk_feat.ndim == 0: # check if ntk_feat is empty + ntk_feat = nngp_feat + else: + ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, apply_fn, kernel_fn + + +def ReluFeatures( + feature_dim0: int = 1, + feature_dim1: int = 1, + sketch_dim: int = 1, + poly_degree0: int = 4, + poly_degree1: int = 4, + poly_sketch_dim0: int = 1, + poly_sketch_dim1: int = 1, + method: str = 'rf', +): + + method = method.lower() + assert method in ['rf', 'ps', 'exact'] + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) + + if method == 'rf': + rng1, rng2, rng3 = random.split(rng, 3) + # Random vectors for random features of arc-cosine kernel of order 0. + W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) + # Random vectors for random features of arc-cosine kernel of order 1. + W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) + # TensorSRHT of degree 2 for approximating tensor product. + ts2 = TensorSRHT2(rng=rng3, + input_dim1=ntk_feat_shape[-1], + input_dim2=feature_dim0, + sketch_dim=sketch_dim).init_sketches() + return (new_nngp_feat_shape, new_ntk_feat_shape), (W0, W1, ts2) + + elif method == 'ps': + rng1, rng2, rng3 = random.split(rng, 3) + # PolySketch algorithm for arc-cosine kernel of order 0. + ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0, + poly_degree0) + # PolySketch algorithm for arc-cosine kernel of order 1. + ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1, + poly_degree1) + # TensorSRHT of degree 2 for approximating tensor product. + ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim) + return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2) + + elif method == 'exact': + # The exact feature map computation is for debug. 
+ new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( + nngp_feat_shape[:-1]),) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) + return (new_nngp_feat_shape, new_ntk_feat_shape), () + + def apply_fn(**kwargs): + return None + + def feature_fn(f: Features, input=None, **kwargs) -> Features: + + input_shape = f.nngp_feat.shape[:-1] + nngp_feat_dim = f.nngp_feat.shape[-1] + ntk_feat_dim = f.ntk_feat.shape[-1] + + nngp_feat_2d = f.nngp_feat.reshape(-1, nngp_feat_dim) + ntk_feat_2d = f.ntk_feat.reshape(-1, ntk_feat_dim) + + if method == 'rf': # Random Features approach. + W0: np.ndarray = input[0] + W1: np.ndarray = input[1] + ts2: TensorSRHT2 = input[2] + + kappa0_feat = (nngp_feat_2d @ W0 > 0) / np.sqrt(W0.shape[-1]) + nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / + np.sqrt(W1.shape[-1])).reshape(input_shape + (-1,)) + ntk_feat = ts2.sketch(ntk_feat_2d, + kappa0_feat).reshape(input_shape + (-1,)) + + elif method == 'ps': + ps0: PolyTensorSRHT = input[0] + ps1: PolyTensorSRHT = input[1] + ts2: TensorSRHT2 = input[2] + raise NotImplementedError + + elif method == 'exact': # Exact feature extraction via Cholesky decomposition. + nngp_feat = cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) + + ntk = ntk_feat_2d @ ntk_feat_2d.T + kappa0_mat = kappa0(nngp_feat_2d) + ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + + else: + raise NotImplementedError + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, apply_fn, feature_fn + + +def conv_feat(X, filter_size): + N, H, W, C = X.shape + out = np.zeros((N, H, W, C * filter_size)) + out = out.at[:, :, :, :C].set(X) + j = 1 + for i in range(1, min((filter_size + 1) // 2, W)): + out = out.at[:, :, :-i, j * C:(j + 1) * C].set(X[:, :, i:]) + j += 1 + out = out.at[:, :, i:, j * C:(j + 1) * C].set(X[:, :, :-i]) + j += 1 + return out + + +def conv2d_feat(X, filter_size): + return conv_feat(np.moveaxis(conv_feat(X, filter_size), 1, 2), filter_size) + + +def ConvFeatures(out_dim: int, + filter_size: int, + W_std: float = 1.0, + b_std: float = 0., + channel_axis: int = -1): + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + new_nngp_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_shape[-1] * + filter_size**2,) + new_ntk_feat_shape = nngp_feat_shape[:-1] + ( + (nngp_feat_shape[-1] + ntk_feat_shape[-1]) * filter_size**2,) + return (new_nngp_feat_shape, new_ntk_feat_shape), () + + def apply_fn(**kwargs): + return None + + def feature_fn(f, input, **kwargs): + nngp_feat, ntk_feat = f.nngp_feat, f.ntk_feat + + nngp_feat = conv2d_feat(nngp_feat, filter_size) / filter_size * W_std + + if ntk_feat.ndim == 0: # check if ntk_feat is empty + ntk_feat = nngp_feat + else: + ntk_feat = conv2d_feat(ntk_feat, filter_size) / filter_size * W_std + ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, apply_fn, feature_fn + + +def AvgPoolFeatures(window_size: int, + stride_size: int = 2, + padding: str = stax.Padding.VALID.name, + normalize_edges: bool = False, + batch_axis: int = 0, + channel_axis: int = -1): + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + + new_nngp_feat_shape = nngp_feat_shape[:1] + ( + nngp_feat_shape[1] // window_size, + nngp_feat_shape[2] // window_size) + nngp_feat_shape[-1:] + new_ntk_feat_shape = ntk_feat_shape[:1] + ( + ntk_feat_shape[1] // window_size, + 
ntk_feat_shape[2] // window_size) + ntk_feat_shape[-1:] + return (new_nngp_feat_shape, new_ntk_feat_shape), () + + def apply_fn(**kwargs): + return None + + def feature_fn(f, input=None, **kwargs): + window_shape_kernel = (1,) + (window_size,) * 2 + (1,) + strides_kernel = (1,) + (window_size,) * 2 + (1,) + pooling = lambda x: _pool_kernel(x, Pooling.AVG, + window_shape_kernel, strides_kernel, + Padding(padding), normalize_edges, 0) + nngp_feat = pooling(f.nngp_feat) + ntk_feat = pooling(f.ntk_feat) + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, apply_fn, feature_fn + + +def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + new_nngp_feat_shape = nngp_feat_shape[:1] + (_prod(nngp_feat_shape[1:]),) + new_ntk_feat_shape = ntk_feat_shape[:1] + (_prod(ntk_feat_shape[1:]),) + return (new_nngp_feat_shape, new_ntk_feat_shape), () + + def apply_fn(**kwargs): + return None + + def feature_fn(f, input=None, **kwargs): + batch_size = f.nngp_feat.shape[0] + nngp_feat = f.nngp_feat.reshape(batch_size, -1) / np.sqrt( + _prod(f.nngp_feat.shape[1:-1])) + ntk_feat = f.ntk_feat.reshape(batch_size, -1) / np.sqrt( + _prod(f.ntk_feat.shape[1:-1])) + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, apply_fn, feature_fn diff --git a/experimental/sketching.py b/experimental/sketching.py new file mode 100644 index 00000000..db48f845 --- /dev/null +++ b/experimental/sketching.py @@ -0,0 +1,128 @@ +from jax import random +from jax import numpy as np +from neural_tangents._src.utils import utils, dataclasses +from neural_tangents._src.utils.typing import Optional + + +# TensorSRHT of degree 2. This version allows different input vectors. +@dataclasses.dataclass +class TensorSRHT2: + + input_dim1: int + input_dim2: int + sketch_dim: int + + rng: np.ndarray + shape: Optional[np.ndarray] = None + + rand_signs1: Optional[np.ndarray] = None + rand_signs2: Optional[np.ndarray] = None + rand_inds1: Optional[np.ndarray] = None + rand_inds2: Optional[np.ndarray] = None + + replace = ... + + def init_sketches(self): + rng1, rng2, rng3, rng4 = random.split(self.rng, 4) + rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1 + rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1 + rand_inds1 = random.choice(rng3, + self.input_dim1, + shape=(self.sketch_dim // 2,)) + rand_inds2 = random.choice(rng4, + self.input_dim2, + shape=(self.sketch_dim // 2,)) + shape = (self.input_dim1, self.input_dim2, self.sketch_dim) + return self.replace(shape=shape, + rand_signs1=rand_signs1, + rand_signs2=rand_signs2, + rand_inds1=rand_inds1, + rand_inds2=rand_inds2) + + def sketch(self, x1, x2): + x1fft = np.fft.fftn(x1 * self.rand_signs1, axes=(-1,))[:, self.rand_inds1] + x2fft = np.fft.fftn(x2 * self.rand_signs2, axes=(-1,))[:, self.rand_inds2] + out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft) + return np.concatenate((out.real, out.imag), 1) + + +# TensorSRHT of degree p. This operates the same input vectors. 
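+# (A degree-p sketch is assembled from degree-2 TensorSRHTs arranged in a
+# binary tree: level 0 sketches the pairs (x, x), and level i combines pairs
+# of level-(i-1) sketches, so (p-1).bit_length() levels cover degree p.)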
+class PolyTensorSRHT: + + def __init__(self, rng, input_dim, sketch_dim, coeffs): + self.coeffs = coeffs + degree = len(coeffs) - 1 + self.degree = degree + + self.tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())] + self.tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())] + rng1, rng2, rng3 = random.split(rng, 3) + + ske_dim_ = sketch_dim // 4 + deg_ = degree // 2 + for i in range((degree - 1).bit_length()): + rng1, rng2 = random.split(rng1) + if i == 0: + self.tree_rand_signs[i] = random.choice( + rng1, 2, shape=(deg_, 2, input_dim)) * 2 - 1 + self.tree_rand_inds[i] = random.choice(rng2, + input_dim, + shape=(deg_, 2, ske_dim_)) + else: + self.tree_rand_signs[i] = random.choice( + rng1, 2, shape=(deg_, 2, ske_dim_)) * 2 - 1 + self.tree_rand_inds[i] = random.choice(rng2, + ske_dim_, + shape=(deg_, 2, ske_dim_)) + deg_ = deg_ // 2 + + rng1, rng2 = random.split(rng3) + self.rand_signs = random.choice(rng1, 2, shape=(degree * ske_dim_,)) * 2 - 1 + self.rand_inds = random.choice(rng2, + degree * ske_dim_, + shape=(sketch_dim // 2,)) + + def sketch(self, x): + n = x.shape[0] + log_degree = len(self.tree_rand_signs) + V = [0 for i in range(log_degree)] + E1 = np.concatenate((np.ones( + (n, 1), dtype=x.dtype), np.zeros((n, x.shape[-1] - 1), dtype=x.dtype)), + 1) + for i in range(log_degree): + deg = self.tree_rand_signs[i].shape[0] + V[i] = np.zeros((deg, n, self.tree_rand_inds[i].shape[2]), + dtype=np.complex64) + for j in range(deg): + if i == 0: + V[i] = V[i].at[j, :, :].set( + tensorsrht(x, x, self.tree_rand_inds[i][j, :, :], + self.tree_rand_signs[i][j, :, :])) + else: + V[i] = V[i].at[j, :, :].set( + tensorsrht(V[i - 1][2 * j, :, :], V[i - 1][2 * j + 1, :, :], + self.tree_rand_inds[i][j, :, :], + self.tree_rand_signs[i][j, :, :])) + U = [0 for i in range(2**log_degree)] + U[0] = V[log_degree - 1][0, :, :].clone() + + for j in range(1, len(U)): + p = (j - 1) // 2 + for i in range(log_degree): + if j % (2**(i + 1)) == 0: + V[i] = V[i].at[p, :, :].set( + np.concatenate((np.ones((n, 1)), np.zeros( + (n, V[i].shape[-1] - 1))), 1)) + else: + if i == 0: + V[i] = V[i].at[p, :, :].set( + tensorsrht(x, E1, self.tree_rand_inds[i][p, :, :], + self.tree_rand_signs[i][p, :, :])) + else: + V[i] = V[i].at[p, :, :].set( + tensorsrht(V[i - 1][2 * p, :, :], V[i - 1][2 * p + 1, :, :], + self.tree_rand_inds[i][p, :, :], + self.tree_rand_signs[i][p, :, :])) + p = p // 2 + U[j] = V[log_degree - 1][0, :, :].clone() + return U diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py new file mode 100644 index 00000000..38586be4 --- /dev/null +++ b/experimental/test_fc_ntk.py @@ -0,0 +1,106 @@ +from numpy.linalg import norm +from jax import random +from jax.config import config +from jax import jit + +config.update("jax_enable_x64", True) +from neural_tangents import stax + +from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial + +seed = 1 +n = 6 +d = 4 + +key1, key2 = random.split(random.PRNGKey(seed)) +x1 = random.normal(key1, (n, d)) + +width = 512 # this does not matter the output + +print("================= Result of Neural Tangent Library =================") + +init_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(width), stax.Relu(), + stax.Dense(width), stax.Relu(), + stax.Dense(1)) + +nt_kernel = kernel_fn(x1, None) + +print("K_nngp :") +print(nt_kernel.nngp) +print() + +print("K_ntk :") +print(nt_kernel.ntk) +print() + +print("================= Result of NTK Random Features =================") + +kappa0_feat_dim = 10000 
+kappa1_feat_dim = 10000 +sketch_dim = 20000 + +f0 = _inputs_to_features(x1) + +relufeat_arg = { + 'method': 'rf', + 'feature_dim0': kappa0_feat_dim, + 'feature_dim1': kappa1_feat_dim, + 'sketch_dim': sketch_dim, +} + +init_fn, _, features_fn = serial(DenseFeatures(width), + ReluFeatures(**relufeat_arg), + DenseFeatures(width), + ReluFeatures(**relufeat_arg), DenseFeatures(1)) + +# Initialize random vectors and sketching algorithms +init_nngp_feat_shape = x1.shape +init_ntk_feat_shape = (-1, 0) +init_feat_shape = (init_nngp_feat_shape, init_ntk_feat_shape) +_, feat_fn_inputs = init_fn(key2, init_feat_shape) + +# Transform input vectors to NNGP/NTK feature map +f0 = _inputs_to_features(x1) +feats = jit(features_fn)(f0, feat_fn_inputs) + +print("K_nngp :") +print(feats.nngp_feat @ feats.nngp_feat.T) +print() + +print("K_ntk :") +print(feats.ntk_feat @ feats.ntk_feat.T) +print() + +print( + f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" +) +print( + f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" +) +print() + +print("================= (Debug) Exact NTK Feature Maps =================") + +relufeat_arg = {'method': 'exact'} + +init_fn, _, features_fn = serial(DenseFeatures(width), + ReluFeatures(**relufeat_arg), + DenseFeatures(width), + ReluFeatures(**relufeat_arg), DenseFeatures(1)) +f0 = _inputs_to_features(x1) +feats = jit(features_fn)(f0, feat_fn_inputs) + +print("K_nngp :") +print(feats.nngp_feat @ feats.nngp_feat.T) +print() + +print("K_ntk :") +print(feats.ntk_feat @ feats.ntk_feat.T) +print() + +print( + f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" +) +print( + f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" +) diff --git a/experimental/test_myrtle_network.py b/experimental/test_myrtle_network.py new file mode 100644 index 00000000..8f6616b2 --- /dev/null +++ b/experimental/test_myrtle_network.py @@ -0,0 +1,117 @@ +import os + +os.environ['CUDA_VISIBLE_DEVICES'] = '' +import functools +from numpy.linalg import norm +from jax.config import config +from jax import jit +# Enable float64 for JAX +config.update("jax_enable_x64", True) + +import jax.numpy as np +from jax import random + +from neural_tangents import stax +from features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features + +layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]} +width = 1 + + +def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0.): + activation_fn = stax.Relu() + conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, padding='SAME') + + layers = [] + layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][0] + layers += [stax.AvgPool((2, 2), strides=(2, 2))] + layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][1] + layers += [stax.AvgPool((2, 2), strides=(2, 2))] + layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][2] + layers += [stax.AvgPool((2, 2), strides=(2, 2))] * 3 + + layers += [stax.Flatten(), stax.Dense(1, W_std, b_std)] + + return stax.serial(*layers) + + +def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), b_std=0., **relu_args): + + conv_fn = functools.partial(ConvFeatures, W_std=W_std, b_std=b_std) + + layers = [] + layers += [conv_fn(width, filter_size=3), + ReluFeatures(**relu_args)] * layer_factor[depth][0] + layers += [AvgPoolFeatures(2, 2)] + layers += [ + 
      ConvFeatures(width, filter_size=3, W_std=W_std),
+      ReluFeatures(**relu_args)
+  ] * layer_factor[depth][1]
+  layers += [AvgPoolFeatures(2, 2)]
+  layers += [
+      ConvFeatures(width, filter_size=3, W_std=W_std),
+      ReluFeatures(**relu_args)
+  ] * layer_factor[depth][2]
+  layers += [AvgPoolFeatures(2, 2)] * 3
+  layers += [FlattenFeatures(), DenseFeatures(1, W_std, b_std)]
+
+  return serial(*layers)
+
+
+key = random.PRNGKey(0)
+
+N, H, W, C = 4, 32, 32, 3
+key1, key2 = random.split(key)
+x = random.normal(key1, shape=(N, H, W, C))
+
+_, _, kernel_fn = MyrtleNetwork(5)
+kernel_fn = jit(kernel_fn)
+
+print("================= Result of Neural Tangent Library =================")
+
+nt_kernel = kernel_fn(x)
+print("K_nngp (exact):")
+print(nt_kernel.nngp)
+print()
+
+print("K_ntk (exact):")
+print(nt_kernel.ntk)
+print()
+
+print("================= CNTK Random Features =================")
+kappa0_feat_dim = 1000
+kappa1_feat_dim = 1000
+sketch_dim = 1000
+
+relufeat_arg = {
+    'method': 'rf',
+    'feature_dim0': kappa0_feat_dim,
+    'feature_dim1': kappa1_feat_dim,
+    'sketch_dim': sketch_dim,
+}
+
+init_fn, _, feature_fn = MyrtleNetworkFeatures(5, **relufeat_arg)
+feature_fn = jit(feature_fn)
+
+init_nngp_feat_shape = x.shape
+init_ntk_feat_shape = (-1, 0)
+init_feat_shape = (init_nngp_feat_shape, init_ntk_feat_shape)
+inputs_shape, feat_fn_inputs = init_fn(key2, init_feat_shape)
+
+f0 = _inputs_to_features(x)
+feats = feature_fn(f0, feat_fn_inputs)
+
+print("K_nngp (approx):")
+print(feats.nngp_feat @ feats.nngp_feat.T)
+print()
+
+print("K_ntk (approx):")
+print(feats.ntk_feat @ feats.ntk_feat.T)
+print()
+
+print(
+    f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}"
+)
+print(
+    f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}"
+)

From d1a9266fcdb9b9e690db500e48c5d75dd4a092f8 Mon Sep 17 00:00:00 2001
From: insuhan
Date: Thu, 24 Feb 2022 14:24:52 +0900
Subject: [PATCH 02/44] Add NTK Random Features and Sketching codes

---
 experimental/README.md              | 7 +++----
 experimental/features.py            | 2 +-
 experimental/sketching.py           | 7 +++++++
 experimental/test_fc_ntk.py         | 3 +--
 experimental/test_myrtle_network.py | 1 +
 5 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/experimental/README.md b/experimental/README.md
index 9eb1064a..f5ee2d75 100644
--- a/experimental/README.md
+++ b/experimental/README.md
@@ -121,16 +121,15 @@ assert out_feat.ntk_feat.shape == (3, 3)
 
 `features.ConvFeatures` is similar to `features.DenseFeatures` in that it updates the NTK features of the next layer by concatenating the NNGP and NTK features of the previous one, but it additionally requires kernel pooling operations. Precisely, [[4]](#4) showed that computing the NNGP/NTK kernel matrices requires the trace of a submatrix of size `stride_size`, which can be seen as a convolution with an identity matrix of size `stride_size`. On the feature side, this can instead be done by concatenating shifted copies of the features, so the resulting feature dimension becomes `stride_size` times larger. Moreover, since images are 2-D, the kernel pooling is applied along both spatial axes, hence the output features have shape `N x H x W x (d * s**2)`, where `s` is the stride size and `d` is the input feature dimension (see the 1-D sketch below).
 
-To be updated.
-
 ## [`features.AvgPoolFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L269)
 
-To be updated.
+`features.AvgPoolFeatures` applies average pooling to both the NNGP and NTK features. It calls the [`_pool_kernel`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L3143) function in [Neural Tangents](https://github.com/google/neural-tangents) as a subroutine.
 
 ## [`features.FlattenFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L304)
 
-To be updated.
+`features.FlattenFeatures` flattens the features into 2-D tensors. Similar to the [`Flatten`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L1641) module in [Neural Tangents](https://github.com/google/neural-tangents), the flattened features are rescaled by the square root of the number of elements. For example, if `nngp_feat` has shape `N x H x W x C`, it returns an `N x HWC` matrix where all entries are divided by `(H*W*C)**0.5`.
+
 
 ## References
 #### [1] [Scaling Neural Tangent Kernels via Sketching and Random Features](https://arxiv.org/pdf/2106.07880.pdf)
diff --git a/experimental/features.py b/experimental/features.py
index bdd852cc..196236da 100644
--- a/experimental/features.py
+++ b/experimental/features.py
@@ -16,7 +16,7 @@
   from neural_tangents.stax import _pool_kernel, Padding, Pooling
 
 from sketching import TensorSRHT2, PolyTensorSRHT
-"""Implementation for NTK Sketching and Random Features"""
+""" Implementation for NTK Sketching and Random Features """
 
 
 def _prod(tuple_):
diff --git a/experimental/sketching.py b/experimental/sketching.py
index db48f845..29f60cf4 100644
--- a/experimental/sketching.py
+++ b/experimental/sketching.py
@@ -46,6 +46,13 @@ def sketch(self, x1, x2):
     return np.concatenate((out.real, out.imag), 1)
 
 
+# Function implementation of TensorSRHT of degree 2 (duplicated)
+def tensorsrht(x1, x2, rand_inds, rand_signs):
+  x1fft = np.fft.fftn(x1 * rand_signs[0, :], axes=(-1,))[:, rand_inds[0, :]]
+  x2fft = np.fft.fftn(x2 * rand_signs[1, :], axes=(-1,))[:, rand_inds[1, :]]
+  return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft)
+
+
 # TensorSRHT of degree p. This operates the same input vectors.
 class PolyTensorSRHT:
diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py
index 38586be4..23e5d575 100644
--- a/experimental/test_fc_ntk.py
+++ b/experimental/test_fc_ntk.py
@@ -9,8 +9,7 @@
 from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
 
 seed = 1
-n = 6
-d = 4
+n, d = 6, 4
 
 key1, key2 = random.split(random.PRNGKey(seed))
 x1 = random.normal(key1, (n, d))
diff --git a/experimental/test_myrtle_network.py b/experimental/test_myrtle_network.py
index 8f6616b2..29d34c73 100644
--- a/experimental/test_myrtle_network.py
+++ b/experimental/test_myrtle_network.py
@@ -96,6 +96,7 @@
 init_nngp_feat_shape = x.shape
 init_ntk_feat_shape = (-1, 0)
 init_feat_shape = (init_nngp_feat_shape, init_ntk_feat_shape)
+
 inputs_shape, feat_fn_inputs = init_fn(key2, init_feat_shape)
 
 f0 = _inputs_to_features(x)

From 9dc3536dcce1d5c982fbdc773bde5d655bbf1aa0 Mon Sep 17 00:00:00 2001
From: insuhan
Date: Thu, 24 Feb 2022 14:25:54 +0900
Subject: [PATCH 03/44] Delete cache files

---
 experimental/__pycache__/features.cpython-37.pyc | Bin 9881 -> 0 bytes
 .../__pycache__/sketching.cpython-37.pyc         | Bin 4416 -> 0 bytes
 2 files changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 experimental/__pycache__/features.cpython-37.pyc
 delete mode 100644 experimental/__pycache__/sketching.cpython-37.pyc

diff --git a/experimental/__pycache__/features.cpython-37.pyc b/experimental/__pycache__/features.cpython-37.pyc
deleted file mode 100644
index 6d56a05bf8ef0223c89e62ee86d2c15dee52959c..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 9881
[binary payload omitted]

diff --git a/experimental/__pycache__/sketching.cpython-37.pyc b/experimental/__pycache__/sketching.cpython-37.pyc
deleted file mode 100644
GIT binary patch
literal 0
HcmV?d00001

literal 4416
[binary payload omitted]

From 12b828e985b68ea97f8885f4cd1b5acd41c1eb12 Mon Sep 17 00:00:00 2001
From: insuhan
Date: Thu, 10 Mar 2022 00:22:42 +0900
Subject: [PATCH 04/44] Resolve pytype tests

---
 experimental/__init__.py            |  0
 experimental/features.py            | 63 +++++++++++++----------
 experimental/sketching.py           | 12 +++---
 experimental/test_fc_ntk.py         |  4 +-
 experimental/test_myrtle_network.py |  5 ++-
 5 files changed, 42 insertions(+), 42 deletions(-)
 create mode 100644 experimental/__init__.py

diff --git a/experimental/__init__.py b/experimental/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/experimental/features.py b/experimental/features.py
index 196236da..faeb2c2f 100644
--- a/experimental/features.py
+++ b/experimental/features.py
@@ -1,21 +1,15 @@
+from
typing import Optional, Callable from jax import random from jax import numpy as np from jax.numpy.linalg import cholesky - import jax.example_libraries.stax as ostax -import neural_tangents -from neural_tangents import stax -from pkg_resources import parse_version -if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'): - from neural_tangents._src.utils import utils, dataclasses - from neural_tangents._src.stax.linear import _pool_kernel, Padding - from neural_tangents._src.stax.linear import _Pooling as Pooling -else: - from neural_tangents.utils import utils, dataclasses - from neural_tangents.stax import _pool_kernel, Padding, Pooling +from neural_tangents import stax +from neural_tangents._src.utils import dataclasses +from neural_tangents._src.stax.linear import _pool_kernel, Padding +from neural_tangents._src.stax.linear import _Pooling as Pooling -from sketching import TensorSRHT2, PolyTensorSRHT +from experimental.sketching import TensorSRHT2 """ Implementation for NTK Sketching and Random Features """ @@ -50,11 +44,11 @@ def kappa1(x): @dataclasses.dataclass class Features: - nngp_feat: np.ndarray - ntk_feat: np.ndarray + nngp_feat: Optional[np.ndarray] = None + ntk_feat: Optional[np.ndarray] = None - batch_axis: int = dataclasses.field(pytree_node=False) - channel_axis: int = dataclasses.field(pytree_node=False) + batch_axis: int = 0 + channel_axis: int = -1 replace = ... # type: Callable[..., 'Features'] @@ -72,7 +66,7 @@ def _inputs_to_features(x: np.ndarray, return Features(nngp_feat=nngp_feat, ntk_feat=ntk_feat, batch_axis=batch_axis, - channel_axis=channel_axis) + channel_axis=channel_axis) # pytype:disable=wrong-keyword-args # Modified the serial process of feature map blocks. @@ -95,7 +89,7 @@ def feature_fn(k, inputs, **kwargs): def DenseFeatures(out_dim: int, W_std: float = 1., - b_std: float = None, + b_std: float = 1., parameterization: str = 'ntk', batch_axis: int = 0, channel_axis: int = -1): @@ -114,7 +108,7 @@ def kernel_fn(f: Features, input, **kwargs): nngp_feat *= W_std ntk_feat *= W_std - if ntk_feat.ndim == 0: # check if ntk_feat is empty + if ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat else: ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) @@ -153,20 +147,21 @@ def init_fn(rng, input_shape): ts2 = TensorSRHT2(rng=rng3, input_dim1=ntk_feat_shape[-1], input_dim2=feature_dim0, - sketch_dim=sketch_dim).init_sketches() + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args return (new_nngp_feat_shape, new_ntk_feat_shape), (W0, W1, ts2) elif method == 'ps': - rng1, rng2, rng3 = random.split(rng, 3) - # PolySketch algorithm for arc-cosine kernel of order 0. - ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0, - poly_degree0) - # PolySketch algorithm for arc-cosine kernel of order 1. - ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1, - poly_degree1) - # TensorSRHT of degree 2 for approximating tensor product. - ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim) - return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2) + # rng1, rng2, rng3 = random.split(rng, 3) + # # PolySketch algorithm for arc-cosine kernel of order 0. + # ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0, + # poly_degree0) + # # PolySketch algorithm for arc-cosine kernel of order 1. + # ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1, + # poly_degree1) + # # TensorSRHT of degree 2 for approximating tensor product. 
+ # ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim) + # return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2) + raise NotImplementedError elif method == 'exact': # The exact feature map computation is for debug. @@ -199,9 +194,9 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: kappa0_feat).reshape(input_shape + (-1,)) elif method == 'ps': - ps0: PolyTensorSRHT = input[0] - ps1: PolyTensorSRHT = input[1] - ts2: TensorSRHT2 = input[2] + # ps0: PolyTensorSRHT = input[0] + # ps1: PolyTensorSRHT = input[1] + # ts2: TensorSRHT2 = input[2] raise NotImplementedError elif method == 'exact': # Exact feature extraction via Cholesky decomposition. @@ -258,7 +253,7 @@ def feature_fn(f, input, **kwargs): nngp_feat = conv2d_feat(nngp_feat, filter_size) / filter_size * W_std - if ntk_feat.ndim == 0: # check if ntk_feat is empty + if ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat else: ntk_feat = conv2d_feat(ntk_feat, filter_size) / filter_size * W_std diff --git a/experimental/sketching.py b/experimental/sketching.py index 29f60cf4..48b54abf 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -1,7 +1,7 @@ from jax import random from jax import numpy as np -from neural_tangents._src.utils import utils, dataclasses -from neural_tangents._src.utils.typing import Optional +from neural_tangents._src.utils import dataclasses +from typing import Optional, Callable # TensorSRHT of degree 2. This version allows different input vectors. @@ -20,9 +20,9 @@ class TensorSRHT2: rand_inds1: Optional[np.ndarray] = None rand_inds2: Optional[np.ndarray] = None - replace = ... + replace = ... # type: Callable[..., 'TensorSRHT2'] - def init_sketches(self): + def init_sketches(self) -> 'TensorSRHT2': rng1, rng2, rng3, rng4 = random.split(self.rng, 4) rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1 rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1 @@ -53,7 +53,8 @@ def tensorsrht(x1, x2, rand_inds, rand_signs): return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft) -# TensorSRHT of degree p. This operates the same input vectors. +# pytype: disable=attribute-error +# TODO: Improve faster TensorSRHT. 
class PolyTensorSRHT: def __init__(self, rng, input_dim, sketch_dim, coeffs): @@ -133,3 +134,4 @@ def sketch(self, x): p = p // 2 U[j] = V[log_degree - 1][0, :, :].clone() return U +# pytype: enable=attribute-error \ No newline at end of file diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py index 23e5d575..c340deac 100644 --- a/experimental/test_fc_ntk.py +++ b/experimental/test_fc_ntk.py @@ -2,11 +2,13 @@ from jax import random from jax.config import config from jax import jit +import sys +sys.path.append("./") config.update("jax_enable_x64", True) from neural_tangents import stax -from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial +from experimental.features import _inputs_to_features, DenseFeatures, ReluFeatures, serial seed = 1 n, d = 6, 4 diff --git a/experimental/test_myrtle_network.py b/experimental/test_myrtle_network.py index 29d34c73..c89c8046 100644 --- a/experimental/test_myrtle_network.py +++ b/experimental/test_myrtle_network.py @@ -1,6 +1,7 @@ import os - os.environ['CUDA_VISIBLE_DEVICES'] = '' +import sys +sys.path.append("./") import functools from numpy.linalg import norm from jax.config import config @@ -12,7 +13,7 @@ from jax import random from neural_tangents import stax -from features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features +from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]} width = 1 From 12b828e985b68ea97f8885f4cd1b5acd41c1eb12 Mon Sep 17 00:00:00 2001 From: Amir Zandieh Date: Wed, 16 Mar 2022 16:54:40 +0100 Subject: [PATCH 05/44] ntk sketch with polynomial approximation to the end kernel function --- experimental/ntk_sketch.py | 84 +++++++++++++++ experimental/sketching.py | 205 +++++++++++++++++++++---------------- 2 files changed, 201 insertions(+), 88 deletions(-) create mode 100644 experimental/ntk_sketch.py diff --git a/experimental/ntk_sketch.py b/experimental/ntk_sketch.py new file mode 100644 index 00000000..cede91d6 --- /dev/null +++ b/experimental/ntk_sketch.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Wed Mar 16 15:41:18 2022 + +@author: amir +""" + +import numpy as onp +import quadprog +from matplotlib import pyplot as plt +from jax import numpy as np +from jax.numpy import linalg as LA +from sketching import standardsrht + +def quadprog_solve_qp(P, q, G=None, h=None, A=None, b=None): + qp_G = .5 * (P + P.T +1e-5*onp.eye(P.shape[0])) # make sure P is symmetric + qp_a = -q + if A is not None: + qp_C = -onp.vstack([A, G]).T + qp_b = -onp.hstack([b, h]) + meq = A.shape[0] + else: # no equality constraint + qp_C = -G.T + qp_b = -h + meq = 0 + return quadprog.solve_qp(qp_G, qp_a, qp_C, qp_b, meq)[0] + +def ntk_poly_coeffs(L,degree): + n=15*L+5*degree + Y = onp.zeros((201+n,L+1)) + Y[:,0] = onp.sort(onp.concatenate((onp.linspace(-1.0, 1.0, num=201), onp.cos((2*onp.arange(n)+1)*onp.pi / (4*n))), axis=0)) + + grid_len = Y.shape[0] + + for i in range(L): + Y[:,i+1] = (onp.sqrt(1-Y[:,i]**2) + Y[:,i]*(onp.pi - onp.arccos(Y[:,i])))/onp.pi + + y = onp.zeros(grid_len) + for i in range(L+1): + z = Y[:,i] + for j in range(i,L): + z = z*(onp.pi - onp.arccos(Y[:,j]))/onp.pi + y = y + z + + Z = onp.zeros((grid_len,degree+1)) + Z[:,0] = onp.ones(grid_len) + for i in range(degree): + Z[:,i+1] = Z[:,i] * Y[:,0] + + + weight_ = onp.linspace(0.0, 1.0, num=grid_len) + 2/L + 
w = y * weight_ + U = Z.T * weight_ + + coeffs = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len)) + coeffs[coeffs < 1e-5] = 0 + + return coeffs + +def poly_ntk_sketch(depth, polysketch, X): + degree = polysketch.degree + n = X.shape[0] + + ntk_coeff = ntk_poly_coeffs(depth, degree) + + norm_x = LA.norm(X, axis=1) + normalizer = np.where(norm_x>0, norm_x, 1.0) + x_normlzd = ((X.T / normalizer).T) + + polysketch_feats = polysketch.sketch(x_normlzd) + + sktch_dim = polysketch_feats[0].shape[1] + + Z = np.zeros((len(polysketch.rand_signs),n), dtype=np.complex64) + for i in range(degree): + Z = Z.at[sktch_dim*i:sktch_dim*(i+1),:].set(np.sqrt( ntk_coeff[i+1] ) * + polysketch_feats[degree-i-1].T) + + Z = standardsrht(Z.T, polysketch.rand_inds, polysketch.rand_signs) + Z = (Z.T * normalizer).T + + return np.concatenate(( np.sqrt(ntk_coeff[0]) * normalizer.reshape((n,1)), np.concatenate((Z.real, Z.imag), 1)), 1) + diff --git a/experimental/sketching.py b/experimental/sketching.py index 48b54abf..80a76139 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -45,93 +45,122 @@ def sketch(self, x1, x2): out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft) return np.concatenate((out.real, out.imag), 1) +# Standard SRHT as a function +def standardsrht(x, rand_inds, rand_signs): + xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds] + return np.sqrt(1 / rand_inds.shape[0]) * xfft -# Function implementation of TensorSRHT of degree 2 (duplicated) -def tensorsrht(x1, x2, rand_inds, rand_signs): - x1fft = np.fft.fftn(x1 * rand_signs[0, :], axes=(-1,))[:, rand_inds[0, :]] - x2fft = np.fft.fftn(x2 * rand_signs[1, :], axes=(-1,))[:, rand_inds[1, :]] - return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft) - - -# pytype: disable=attribute-error -# TODO: Improve faster TensorSRHT. 
-class PolyTensorSRHT: - - def __init__(self, rng, input_dim, sketch_dim, coeffs): - self.coeffs = coeffs - degree = len(coeffs) - 1 - self.degree = degree - - self.tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())] - self.tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())] - rng1, rng2, rng3 = random.split(rng, 3) - - ske_dim_ = sketch_dim // 4 - deg_ = degree // 2 - for i in range((degree - 1).bit_length()): - rng1, rng2 = random.split(rng1) - if i == 0: - self.tree_rand_signs[i] = random.choice( - rng1, 2, shape=(deg_, 2, input_dim)) * 2 - 1 - self.tree_rand_inds[i] = random.choice(rng2, - input_dim, - shape=(deg_, 2, ske_dim_)) - else: - self.tree_rand_signs[i] = random.choice( - rng1, 2, shape=(deg_, 2, ske_dim_)) * 2 - 1 - self.tree_rand_inds[i] = random.choice(rng2, - ske_dim_, - shape=(deg_, 2, ske_dim_)) - deg_ = deg_ // 2 - - rng1, rng2 = random.split(rng3) - self.rand_signs = random.choice(rng1, 2, shape=(degree * ske_dim_,)) * 2 - 1 - self.rand_inds = random.choice(rng2, - degree * ske_dim_, - shape=(sketch_dim // 2,)) - - def sketch(self, x): - n = x.shape[0] - log_degree = len(self.tree_rand_signs) - V = [0 for i in range(log_degree)] - E1 = np.concatenate((np.ones( - (n, 1), dtype=x.dtype), np.zeros((n, x.shape[-1] - 1), dtype=x.dtype)), - 1) - for i in range(log_degree): - deg = self.tree_rand_signs[i].shape[0] - V[i] = np.zeros((deg, n, self.tree_rand_inds[i].shape[2]), - dtype=np.complex64) - for j in range(deg): - if i == 0: - V[i] = V[i].at[j, :, :].set( - tensorsrht(x, x, self.tree_rand_inds[i][j, :, :], - self.tree_rand_signs[i][j, :, :])) - else: - V[i] = V[i].at[j, :, :].set( - tensorsrht(V[i - 1][2 * j, :, :], V[i - 1][2 * j + 1, :, :], - self.tree_rand_inds[i][j, :, :], - self.tree_rand_signs[i][j, :, :])) - U = [0 for i in range(2**log_degree)] - U[0] = V[log_degree - 1][0, :, :].clone() - - for j in range(1, len(U)): - p = (j - 1) // 2 - for i in range(log_degree): - if j % (2**(i + 1)) == 0: - V[i] = V[i].at[p, :, :].set( - np.concatenate((np.ones((n, 1)), np.zeros( - (n, V[i].shape[-1] - 1))), 1)) - else: - if i == 0: - V[i] = V[i].at[p, :, :].set( - tensorsrht(x, E1, self.tree_rand_inds[i][p, :, :], - self.tree_rand_signs[i][p, :, :])) - else: - V[i] = V[i].at[p, :, :].set( - tensorsrht(V[i - 1][2 * p, :, :], V[i - 1][2 * p + 1, :, :], - self.tree_rand_inds[i][p, :, :], - self.tree_rand_signs[i][p, :, :])) - p = p // 2 - U[j] = V[log_degree - 1][0, :, :].clone() - return U +@dataclasses.dataclass +class PolyTensorSketch: + + rng: np.ndarray + + input_dim: int + sketch_dim: int + degree: int + + tree_rand_signs: Optional[list] = None + tree_rand_inds: Optional[list] = None + rand_signs: Optional[np.ndarray] = None + rand_inds: Optional[np.ndarray] = None + + replace = ... 
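+  # (`replace` is a placeholder overwritten by the `dataclasses.dataclass`
+  # decorator with a copy-and-update method; `init_sketch` below uses it to
+  # return a copy carrying the sampled sketch parameters.)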
+ + def init_sketch(self) -> 'PolyTensorSketch': + + tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())] + tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())] + rng1, rng3 = random.split(self.rng, 2) + + ske_dim_ = self.sketch_dim // 4 + deg_ = self.degree // 2 + + for i in range((self.degree - 1).bit_length()): + rng1, rng2 = random.split(rng1) + + if i == 0: + tree_rand_signs[i] = random.choice( + rng1, 2, shape=(deg_, 2, self.input_dim)) * 2 - 1 + tree_rand_inds[i] = random.choice(rng2, + self.input_dim, + shape=(deg_, 2, ske_dim_)) + else: + tree_rand_signs[i] = random.choice( + rng1, 2, shape=(deg_, 2, ske_dim_)) * 2 - 1 + tree_rand_inds[i] = random.choice(rng2, + ske_dim_,shape=(deg_, 2, ske_dim_)) + deg_ = deg_ // 2 + + rng1, rng2 = random.split(rng3,2) + rand_signs = random.choice(rng1, 2, shape=(self.degree * ske_dim_,)) * 2 - 1 + rand_inds = random.choice(rng2, + self.degree * ske_dim_,shape=(self.sketch_dim // 2,)) + + return self.replace(tree_rand_signs=tree_rand_signs, + tree_rand_inds=tree_rand_inds, + rand_signs=rand_signs, + rand_inds=rand_inds) + + + # TensorSRHT of degree 2 + def tensorsrht(self, x1, x2, rand_inds, rand_signs): + x1fft = np.fft.fftn(x1 * rand_signs[0, :], axes=(-1,))[:, rand_inds[0, :]] + x2fft = np.fft.fftn(x2 * rand_signs[1, :], axes=(-1,))[:, rand_inds[1, :]] + return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft) + + # Standard SRHT + def standardsrht(self, x, rand_inds, rand_signs): + xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds] + return np.sqrt(1 / rand_inds.shape[0]) * xfft + + def sketch(self, x): + n = x.shape[0] + log_degree = len(self.tree_rand_signs) + V = [0 for i in range(log_degree)] + #E1 = np.concatenate((np.ones((n, 1), dtype=x.dtype), np.zeros((n, x.shape[-1] - 1), dtype=x.dtype)),1) + + for i in range(log_degree): + deg = self.tree_rand_signs[i].shape[0] + V[i] = np.zeros((deg, n, self.tree_rand_inds[i].shape[2]), + dtype=np.complex64) + for j in range(deg): + if i == 0: + V[i] = V[i].at[j, :, :].set( + self.tensorsrht(x, x, self.tree_rand_inds[i][j, :, :], + self.tree_rand_signs[i][j, :, :])) + + else: + V[i] = V[i].at[j, :, :].set( + self.tensorsrht(V[i - 1][2 * j, :, :], V[i - 1][2 * j + 1, :, :], + self.tree_rand_inds[i][j, :, :], + self.tree_rand_signs[i][j, :, :])) + + U = [0 for i in range(2**log_degree)] + U[0] = V[log_degree - 1][0, :, :].clone() + + SetE1 = set() + + for j in range(1, len(U)): + p = (j - 1) // 2 + for i in range(log_degree): + if j % (2**(i + 1)) == 0: + SetE1.add((i,p)) + #V[i] = V[i].at[p, :, :].set(np.concatenate((np.ones((n, 1)), np.zeros((n, V[i].shape[-1] - 1))), 1)) + else: + if i == 0: + V[i] = V[i].at[p, :, :].set( + self.standardsrht(x, self.tree_rand_inds[i][p, 0, :], + self.tree_rand_signs[i][p, 0, :])) + else: + if (i-1,2*p) in SetE1: + V[i] = V[i].at[p, :, :].set(V[i-1][2*p+1,:,:].clone()) + else: + V[i] = V[i].at[p, :, :].set( + self.tensorsrht(V[i - 1][2 * p, :, :], V[i - 1][2 * p + 1, :, :], + self.tree_rand_inds[i][p, :, :], + self.tree_rand_signs[i][p, :, :])) + p = p // 2 + U[j] = V[log_degree - 1][0, :, :].clone() + + return U # pytype: enable=attribute-error \ No newline at end of file From 4a5cddd98f7fbbc049e6bb180fc6d0cb2dd54b0e Mon Sep 17 00:00:00 2001 From: insuhan Date: Thu, 17 Mar 2022 06:06:04 +0900 Subject: [PATCH 06/44] Fix simple issues from code reviews (v1) --- experimental/README.md | 3 +-- experimental/features.py | 56 ++++++++++++++++------------------------ 2 files changed, 23 insertions(+), 36 deletions(-) diff --git 
a/experimental/README.md b/experimental/README.md index f5ee2d75..d3b6ff85 100644 --- a/experimental/README.md +++ b/experimental/README.md @@ -59,10 +59,9 @@ For more complex CNTK features, please check `test_myrtle_networks.py`. # Modules -All modules return a triple functions `(init_fn, apply_fn, feature_fn)`. Instead of kernel function `kernel_fn` in [Neural Tangents](https://github.com/google/neural-tangents) library, we replace it with the feature map function `feature_fn`. +All modules return a tuple functions `(init_fn, feature_fn)`. Instead of kernel function `kernel_fn` in [Neural Tangents](https://github.com/google/neural-tangents) library, we replace it with the feature map function `feature_fn`. - `init_fn` takes (1) random seed and (2) a pair of shapes of input features for both NNGP and NTK. It returns (1) a pair of shapes of output features and (2) parameters used for approximating the features (e.g., random vectors for Random Features approach). -- `apply_fn` does nothing (dummy functions). - `feature_fn` takes (1) feature structure `features.Feature` and (2) parameters used for feature approximation (initialized by `init_fn`). It returns `features.Feature` including approximate features of the corresponding module. diff --git a/experimental/features.py b/experimental/features.py index faeb2c2f..ee7b8db4 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -73,26 +73,25 @@ def _inputs_to_features(x: np.ndarray, # Followed https://github.com/google/neural-tangents/blob/main/neural_tangents/stax.py def serial(*layers): init_fns, apply_fns, feature_fns = zip(*layers) - init_fn, apply_fn = ostax.serial(*zip(init_fns, apply_fns)) + init_fn, _ = ostax.serial(*zip(init_fns, apply_fns)) - # import time def feature_fn(k, inputs, **kwargs): for f, input_ in zip(feature_fns, inputs): - # print(f) - # tic = time.time() k = f(k, input_, **kwargs) - # print(f"toc: {time.time() - tic:.2f} sec") return k - return init_fn, apply_fn, feature_fn + return init_fn, feature_fn def DenseFeatures(out_dim: int, W_std: float = 1., - b_std: float = 1., - parameterization: str = 'ntk', + b_std: float = 0., batch_axis: int = 0, channel_axis: int = -1): + + if b_std != 0.0: + raise NotImplementedError('Non-zero b_std is not implemented yet .' 
+ ' Please set b_std to be `0`.') def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] @@ -100,10 +99,7 @@ def init_fn(rng, input_shape): ntk_feat_shape[-1],) return (nngp_feat_shape, new_ntk_feat_shape), () - def apply_fn(**kwargs): - return None - - def kernel_fn(f: Features, input, **kwargs): + def feature_fn(f: Features, input, **kwargs): nngp_feat, ntk_feat = f.nngp_feat, f.ntk_feat nngp_feat *= W_std ntk_feat *= W_std @@ -115,7 +111,7 @@ def kernel_fn(f: Features, input, **kwargs): return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - return init_fn, apply_fn, kernel_fn + return init_fn, feature_fn def ReluFeatures( @@ -170,9 +166,6 @@ def init_fn(rng, input_shape): new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) return (new_nngp_feat_shape, new_ntk_feat_shape), () - def apply_fn(**kwargs): - return None - def feature_fn(f: Features, input=None, **kwargs) -> Features: input_shape = f.nngp_feat.shape[:-1] @@ -211,10 +204,10 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - return init_fn, apply_fn, feature_fn + return init_fn, feature_fn -def conv_feat(X, filter_size): +def _conv_feat(X, filter_size): N, H, W, C = X.shape out = np.zeros((N, H, W, C * filter_size)) out = out.at[:, :, :, :C].set(X) @@ -227,8 +220,8 @@ def conv_feat(X, filter_size): return out -def conv2d_feat(X, filter_size): - return conv_feat(np.moveaxis(conv_feat(X, filter_size), 1, 2), filter_size) +def _conv2d_feat(X, filter_size): + return _conv_feat(np.moveaxis(_conv_feat(X, filter_size), 1, 2), filter_size) def ConvFeatures(out_dim: int, @@ -237,6 +230,10 @@ def ConvFeatures(out_dim: int, b_std: float = 0., channel_axis: int = -1): + if b_std != 0.0: + raise NotImplementedError('Non-zero b_std is not implemented yet .' 
+ ' Please set b_std to be `0`.') + def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] new_nngp_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_shape[-1] * @@ -245,23 +242,20 @@ def init_fn(rng, input_shape): (nngp_feat_shape[-1] + ntk_feat_shape[-1]) * filter_size**2,) return (new_nngp_feat_shape, new_ntk_feat_shape), () - def apply_fn(**kwargs): - return None - def feature_fn(f, input, **kwargs): nngp_feat, ntk_feat = f.nngp_feat, f.ntk_feat - nngp_feat = conv2d_feat(nngp_feat, filter_size) / filter_size * W_std + nngp_feat = _conv2d_feat(nngp_feat, filter_size) / filter_size * W_std if ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat else: - ntk_feat = conv2d_feat(ntk_feat, filter_size) / filter_size * W_std + ntk_feat = _conv2d_feat(ntk_feat, filter_size) / filter_size * W_std ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - return init_fn, apply_fn, feature_fn + return init_fn, feature_fn def AvgPoolFeatures(window_size: int, @@ -282,9 +276,6 @@ def init_fn(rng, input_shape): ntk_feat_shape[2] // window_size) + ntk_feat_shape[-1:] return (new_nngp_feat_shape, new_ntk_feat_shape), () - def apply_fn(**kwargs): - return None - def feature_fn(f, input=None, **kwargs): window_shape_kernel = (1,) + (window_size,) * 2 + (1,) strides_kernel = (1,) + (window_size,) * 2 + (1,) @@ -296,7 +287,7 @@ def feature_fn(f, input=None, **kwargs): return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - return init_fn, apply_fn, feature_fn + return init_fn, feature_fn def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): @@ -307,9 +298,6 @@ def init_fn(rng, input_shape): new_ntk_feat_shape = ntk_feat_shape[:1] + (_prod(ntk_feat_shape[1:]),) return (new_nngp_feat_shape, new_ntk_feat_shape), () - def apply_fn(**kwargs): - return None - def feature_fn(f, input=None, **kwargs): batch_size = f.nngp_feat.shape[0] nngp_feat = f.nngp_feat.reshape(batch_size, -1) / np.sqrt( @@ -319,4 +307,4 @@ def feature_fn(f, input=None, **kwargs): return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - return init_fn, apply_fn, feature_fn + return init_fn, feature_fn From 932987dbf0ee772807a972ee91caa91a5a1c6554 Mon Sep 17 00:00:00 2001 From: insuhan Date: Thu, 17 Mar 2022 06:15:58 +0900 Subject: [PATCH 07/44] Fix simple issues from code reviews (v2) --- experimental/features.py | 6 +++--- experimental/test_fc_ntk.py | 22 +++++++++++----------- experimental/test_myrtle_network.py | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index ee7b8db4..9df6948a 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -72,8 +72,8 @@ def _inputs_to_features(x: np.ndarray, # Modified the serial process of feature map blocks. # Followed https://github.com/google/neural-tangents/blob/main/neural_tangents/stax.py def serial(*layers): - init_fns, apply_fns, feature_fns = zip(*layers) - init_fn, _ = ostax.serial(*zip(init_fns, apply_fns)) + init_fns, feature_fns = zip(*layers) + init_fn, _ = ostax.serial(*zip(init_fns, init_fns)) def feature_fn(k, inputs, **kwargs): for f, input_ in zip(feature_fns, inputs): @@ -88,7 +88,7 @@ def DenseFeatures(out_dim: int, b_std: float = 0., batch_axis: int = 0, channel_axis: int = -1): - + if b_std != 0.0: raise NotImplementedError('Non-zero b_std is not implemented yet .' 
' Please set b_std to be `0`.') diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py index c340deac..a623c37f 100644 --- a/experimental/test_fc_ntk.py +++ b/experimental/test_fc_ntk.py @@ -20,9 +20,9 @@ print("================= Result of Neural Tangent Library =================") -init_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(width), stax.Relu(), - stax.Dense(width), stax.Relu(), - stax.Dense(1)) +init_fn, _, kernel_fn = stax.serial(stax.Dense(width), stax.Relu(), + stax.Dense(width), stax.Relu(), + stax.Dense(1)) nt_kernel = kernel_fn(x1, None) @@ -49,10 +49,10 @@ 'sketch_dim': sketch_dim, } -init_fn, _, features_fn = serial(DenseFeatures(width), - ReluFeatures(**relufeat_arg), - DenseFeatures(width), - ReluFeatures(**relufeat_arg), DenseFeatures(1)) +init_fn, features_fn = serial(DenseFeatures(width), + ReluFeatures(**relufeat_arg), + DenseFeatures(width), + ReluFeatures(**relufeat_arg), DenseFeatures(1)) # Initialize random vectors and sketching algorithms init_nngp_feat_shape = x1.shape @@ -84,10 +84,10 @@ relufeat_arg = {'method': 'exact'} -init_fn, _, features_fn = serial(DenseFeatures(width), - ReluFeatures(**relufeat_arg), - DenseFeatures(width), - ReluFeatures(**relufeat_arg), DenseFeatures(1)) +init_fn, features_fn = serial(DenseFeatures(width), + ReluFeatures(**relufeat_arg), + DenseFeatures(width), + ReluFeatures(**relufeat_arg), DenseFeatures(1)) f0 = _inputs_to_features(x1) feats = jit(features_fn)(f0, feat_fn_inputs) diff --git a/experimental/test_myrtle_network.py b/experimental/test_myrtle_network.py index c89c8046..52134b1e 100644 --- a/experimental/test_myrtle_network.py +++ b/experimental/test_myrtle_network.py @@ -91,7 +91,7 @@ def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), b_std=0., **relu_args): 'sketch_dim': sketch_dim, } -init_fn, _, feature_fn = MyrtleNetworkFeatures(5, **relufeat_arg) +init_fn, feature_fn = MyrtleNetworkFeatures(5, **relufeat_arg) feature_fn = jit(feature_fn) init_nngp_feat_shape = x.shape From 6a156ed556ab937f929c83a18ecef5f0a70eac91 Mon Sep 17 00:00:00 2001 From: insuhan Date: Thu, 17 Mar 2022 06:20:14 +0900 Subject: [PATCH 08/44] Fix simple issues from code reviews (v2) --- experimental/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/experimental/README.md b/experimental/README.md index d3b6ff85..c0b8c4d3 100644 --- a/experimental/README.md +++ b/experimental/README.md @@ -20,7 +20,7 @@ relufeat_arg = { 'method': 'rf', } -init_fn, _, feature_fn = serial( +init_fn, feature_fn = serial( DenseFeatures(512), ReluFeatures(**relufeat_arg), DenseFeatures(512), ReluFeatures(**relufeat_arg), DenseFeatures(1) @@ -42,7 +42,7 @@ For more details of fully connected NTK features, please check `test_fc_ntk.py`. ### Convolutional NTK approximation via Random Features: ```python -init_fn, _, feature_fn = serial( +init_fn, feature_fn = serial( ConvFeatures(512, filter_size=3), ReluFeatures(**relufeat_arg), AvgPoolFeatures(2, 2), FlattenFeatures() ) @@ -59,7 +59,7 @@ For more complex CNTK features, please check `test_myrtle_networks.py`. # Modules -All modules return a tuple functions `(init_fn, feature_fn)`. Instead of kernel function `kernel_fn` in [Neural Tangents](https://github.com/google/neural-tangents) library, we replace it with the feature map function `feature_fn`. +All modules return a pair of functions `(init_fn, feature_fn)`. 
Instead of kernel function `kernel_fn` in [Neural Tangents](https://github.com/google/neural-tangents) library, we replace it with the feature map function `feature_fn`. We do not return `apply_fn` functions. - `init_fn` takes (1) random seed and (2) a pair of shapes of input features for both NNGP and NTK. It returns (1) a pair of shapes of output features and (2) parameters used for approximating the features (e.g., random vectors for Random Features approach). - `feature_fn` takes (1) feature structure `features.Feature` and (2) parameters used for feature approximation (initialized by `init_fn`). It returns `features.Feature` including approximate features of the corresponding module. @@ -75,7 +75,7 @@ x = random.normal(key1, shape=(3, 2)) _, _, kernel_fn = stax.Dense(width) nt_kernel = kernel_fn(x) -_, _, feat_fn = DenseFeatures(width) +_, feat_fn = DenseFeatures(width) feat = feat_fn(_inputs_to_features(x), ()) assert np.linalg.norm(nt_kernel.ntk - feat.ntk_feat @ feat.ntk_feat.T) <= 1e-12 @@ -91,7 +91,7 @@ To use the Random Features approach, set the parameter `method` to `rf` (default ```python x = random.normal(key1, shape=(3, 32)) -init_fn, _ , feat_fn = serial( +init_fn, feat_fn = serial( DenseFeatures(1), ReluFeatures(method='rf', feature_dim0=10, feature_dim1=20, sketch_dim=30) ) @@ -106,7 +106,7 @@ assert out_feat.ntk_feat.shape == (3, 30) To use the exact feature map (based on Cholesky decomposition), set the parameter `method` to `exact`, e.g., ```python -init_fn, _ , feat_fn = serial(DenseFeatures(1), ReluFeatures(method='exact')) +init_fn, feat_fn = serial(DenseFeatures(1), ReluFeatures(method='exact')) _, params = init_fn(key1, (x.shape,(-1, 0))) out_feat = feat_fn(_inputs_to_features(x), params) From 9f2a3e5edf9434cdd7437670c2d72d5f29674890 Mon Sep 17 00:00:00 2001 From: insuhan Date: Thu, 17 Mar 2022 15:11:47 +0900 Subject: [PATCH 09/44] Automatically preprocess init_fn/feature_fn --- experimental/README.md | 30 ++++++++-------- experimental/features.py | 53 +++++++++++++++++++++++++++++ experimental/test_fc_ntk.py | 26 ++++++-------- experimental/test_myrtle_network.py | 13 +++---- 4 files changed, 84 insertions(+), 38 deletions(-) diff --git a/experimental/README.md b/experimental/README.md index c0b8c4d3..4acef12e 100644 --- a/experimental/README.md +++ b/experimental/README.md @@ -11,7 +11,7 @@ Implementations developed in [[1]](#1-scaling-neural-tangent-kernels-via-sketchi ```python from jax import random -from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial +from experimental.features import DenseFeatures, ReluFeatures, serial relufeat_arg = { 'feature_dim0': 128, @@ -29,12 +29,8 @@ init_fn, feature_fn = serial( key1, key2 = random.split(random.PRNGKey(1)) x = random.normal(key1, (5, 4)) -initial_nngp_feat_shape = x.shape -initial_ntk_feat_shape = (-1,0) -initial_feat_shape = (initial_nngp_feat_shape, initial_ntk_feat_shape) - -_, feat_fn_inputs = init_fn(key2, initial_feat_shape) -feats = feature_fn(_inputs_to_features(x), feat_fn_inputs) +_, feat_fn_inputs = init_fn(key2, x.shape) +feats = feature_fn(x, feat_fn_inputs) # feats.nngp_feat is a feature map of NNGP kernel # feats.ntk_feat is a feature map of NTK ``` @@ -42,16 +38,19 @@ For more details of fully connected NTK features, please check `test_fc_ntk.py`. 
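As a quick sanity check (an editorial sketch along the lines of `test_fc_ntk.py`, not part of the original files), the Gram matrix of the returned features can be compared against the exact NTK from `neural_tangents`; the relative error shrinks as `feature_dim0`, `feature_dim1`, and `sketch_dim` grow:

```python
# Continues the example above (`x` and `feats`); the stax network mirrors
# the DenseFeatures/ReluFeatures stack used there.
import numpy as onp
from neural_tangents import stax

_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(),
                              stax.Dense(512), stax.Relu(), stax.Dense(1))
nt_kernel = kernel_fn(x, None)

ntk_approx = feats.ntk_feat @ feats.ntk_feat.T
print(onp.linalg.norm(nt_kernel.ntk - ntk_approx) /
      onp.linalg.norm(nt_kernel.ntk))
```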
### Convolutional NTK approximation via Random Features:
```python
+from experimental.features import ConvFeatures, AvgPoolFeatures, FlattenFeatures
+
init_fn, feature_fn = serial(
    ConvFeatures(512, filter_size=3), ReluFeatures(**relufeat_arg),
    AvgPoolFeatures(2, 2), FlattenFeatures()
)

n, H, W, C = 5, 8, 8, 3
+key1, key2 = random.split(random.PRNGKey(1))
x = random.normal(key1, shape=(n, H, W, C))

-_, feat_fn_inputs = init_fn(key2, (x.shape, (-1, 0))
-feats = feature_fn(_inputs_to_features(x), feat_fn_inputs)
+_, feat_fn_inputs = init_fn(key2, x.shape)
+feats = feature_fn(x, feat_fn_inputs)
# feats.nngp_feat is a feature map of NNGP kernel
# feats.ntk_feat is a feature map of NTK
```
For more complex CNTK features, please check `test_myrtle_network.py`.

# Modules

All modules return a pair of functions `(init_fn, feature_fn)`. Instead of the kernel function `kernel_fn` in the [Neural Tangents](https://github.com/google/neural-tangents) library, we provide the feature map function `feature_fn`; no `apply_fn` is returned.

-- `init_fn` takes (1) random seed and (2) a pair of shapes of input features for both NNGP and NTK. It returns (1) a pair of shapes of output features and (2) parameters used for approximating the features (e.g., random vectors for Random Features approach).
+- `init_fn` takes (1) a random seed and (2) an input shape. It returns (1) a pair of shapes of both NNGP and NTK features and (2) the parameters used for approximating the features (e.g., random vectors for the Random Features approach).
- `feature_fn` takes (1) a feature structure `features.Feature` and (2) the parameters used for feature approximation (initialized by `init_fn`). It returns a `features.Feature` containing the approximate features of the corresponding module.

@@ -69,6 +68,7 @@
`features.DenseFeatures` provides features for the fully-connected dense layer and corresponds to the `stax.Dense` module in [Neural Tangents](https://github.com/google/neural-tangents). We assume that the input is a tabular dataset (i.e., an n-by-d matrix). Its `feature_fn` updates the NTK features by concatenating the NNGP and NTK features. This is because `stax.Dense` updates the NTK kernel matrix (`N x N`) by adding the previous NNGP and NTK kernel matrices. The features of the dense layer are exact; no approximations are applied.
```python import numpy as np +from neural_tangents import stax width = 1 x = random.normal(key1, shape=(3, 2)) @@ -76,7 +76,7 @@ _, _, kernel_fn = stax.Dense(width) nt_kernel = kernel_fn(x) _, feat_fn = DenseFeatures(width) -feat = feat_fn(_inputs_to_features(x), ()) +feat = feat_fn(x, ()) assert np.linalg.norm(nt_kernel.ntk - feat.ntk_feat @ feat.ntk_feat.T) <= 1e-12 assert np.linalg.norm(nt_kernel.nngp - feat.nngp_feat @ feat.nngp_feat.T) <= 1e-12 @@ -96,9 +96,9 @@ init_fn, feat_fn = serial( ReluFeatures(method='rf', feature_dim0=10, feature_dim1=20, sketch_dim=30) ) -_, params = init_fn(key1, (x.shape,(-1, 0))) +_, params = init_fn(key1, x.shape) -out_feat = feat_fn(_inputs_to_features(x), params) +out_feat = feat_fn(x, params) assert out_feat.nngp_feat.shape == (3, 20) assert out_feat.ntk_feat.shape == (3, 30) @@ -107,8 +107,8 @@ assert out_feat.ntk_feat.shape == (3, 30) To use the exact feature map (based on Cholesky decomposition), set the parameter `method` to `exact`, e.g., ```python init_fn, feat_fn = serial(DenseFeatures(1), ReluFeatures(method='exact')) -_, params = init_fn(key1, (x.shape,(-1, 0))) -out_feat = feat_fn(_inputs_to_features(x), params) +_, params = init_fn(key1, x.shape) +out_feat = feat_fn(x, params) assert out_feat.nngp_feat.shape == (3, 3) assert out_feat.ntk_feat.shape == (3, 3) diff --git a/experimental/features.py b/experimental/features.py index 9df6948a..4c80c30a 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -69,8 +69,56 @@ def _inputs_to_features(x: np.ndarray, channel_axis=channel_axis) # pytype:disable=wrong-keyword-args +# For flexible `feature_fn` with both input `np.ndarray` and with `Feature`. +# Followed https://github.com/google/neural-tangents/blob/main/neural_tangents/_src/stax/requirements.py +def _preprocess_feature_fn(feature_fn): + + def feature_fn_feature(feature, input, **kwargs): + return feature_fn(feature, input, **kwargs) + + def feature_fn_x(x, input, **kwargs): + feature = _inputs_to_features(x, **kwargs) + return feature_fn(feature, input, **kwargs) + + def feature_fn_any(x_or_feature, input=None, **kwargs): + if isinstance(x_or_feature, Features): + return feature_fn_feature(x_or_feature, input, **kwargs) + return feature_fn_x(x_or_feature, input, **kwargs) + + return feature_fn_any + + +def _is_sinlge_shape(input_shape): + if all(isinstance(n, int) for n in input_shape): + return True + return False + +def _preprocess_init_fn(init_fn): + + def init_fn_any(rng, input_shape_any, **kwargs): + if _is_sinlge_shape(input_shape_any): + input_shape = (input_shape_any, (-1,0)) + return init_fn(rng, input_shape, **kwargs) + else: + return init_fn(rng, input_shape_any, **kwargs) + + return init_fn_any + + +def layer(layer_fn): + + def new_layer_fns(*args, **kwargs): + init_fn, feature_fn = layer_fn(*args, **kwargs) + feature_fn = _preprocess_feature_fn(feature_fn) + init_fn = _preprocess_init_fn(init_fn) + return init_fn, feature_fn + + return new_layer_fns + + # Modified the serial process of feature map blocks. 
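A minimal editorial sketch (the shapes and PRNG key are illustrative) of what the `layer` decorator above buys: after preprocessing, `init_fn` accepts either a plain input shape or an (NNGP, NTK) shape pair, and `feature_fn` accepts a raw array, wrapping it via `_inputs_to_features` internally:

```python
from jax import random
from experimental.features import DenseFeatures

rng = random.PRNGKey(0)
x = random.normal(rng, (5, 4))

init_fn, feature_fn = DenseFeatures(512)
_, params = init_fn(rng, x.shape)             # a plain input shape ...
_, params = init_fn(rng, (x.shape, (-1, 0)))  # ... or an (nngp, ntk) pair
feats = feature_fn(x, params)                 # raw arrays are converted to
                                              # a `Features` object first
```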
# Followed https://github.com/google/neural-tangents/blob/main/neural_tangents/stax.py +@layer def serial(*layers): init_fns, feature_fns = zip(*layers) init_fn, _ = ostax.serial(*zip(init_fns, init_fns)) @@ -83,6 +131,7 @@ def feature_fn(k, inputs, **kwargs): return init_fn, feature_fn +@layer def DenseFeatures(out_dim: int, W_std: float = 1., b_std: float = 0., @@ -114,6 +163,7 @@ def feature_fn(f: Features, input, **kwargs): return init_fn, feature_fn +@layer def ReluFeatures( feature_dim0: int = 1, feature_dim1: int = 1, @@ -224,6 +274,7 @@ def _conv2d_feat(X, filter_size): return _conv_feat(np.moveaxis(_conv_feat(X, filter_size), 1, 2), filter_size) +@layer def ConvFeatures(out_dim: int, filter_size: int, W_std: float = 1.0, @@ -258,6 +309,7 @@ def feature_fn(f, input, **kwargs): return init_fn, feature_fn +@layer def AvgPoolFeatures(window_size: int, stride_size: int = 2, padding: str = stax.Padding.VALID.name, @@ -290,6 +342,7 @@ def feature_fn(f, input=None, **kwargs): return init_fn, feature_fn +@layer def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): def init_fn(rng, input_shape): diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py index a623c37f..be91063a 100644 --- a/experimental/test_fc_ntk.py +++ b/experimental/test_fc_ntk.py @@ -8,13 +8,13 @@ config.update("jax_enable_x64", True) from neural_tangents import stax -from experimental.features import _inputs_to_features, DenseFeatures, ReluFeatures, serial +from experimental.features import DenseFeatures, ReluFeatures, serial seed = 1 n, d = 6, 4 key1, key2 = random.split(random.PRNGKey(seed)) -x1 = random.normal(key1, (n, d)) +x = random.normal(key1, (n, d)) width = 512 # this does not matter the output @@ -24,7 +24,7 @@ stax.Dense(width), stax.Relu(), stax.Dense(1)) -nt_kernel = kernel_fn(x1, None) +nt_kernel = kernel_fn(x, None) print("K_nngp :") print(nt_kernel.nngp) @@ -40,8 +40,6 @@ kappa1_feat_dim = 10000 sketch_dim = 20000 -f0 = _inputs_to_features(x1) - relufeat_arg = { 'method': 'rf', 'feature_dim0': kappa0_feat_dim, @@ -55,20 +53,18 @@ ReluFeatures(**relufeat_arg), DenseFeatures(1)) # Initialize random vectors and sketching algorithms -init_nngp_feat_shape = x1.shape -init_ntk_feat_shape = (-1, 0) -init_feat_shape = (init_nngp_feat_shape, init_ntk_feat_shape) -_, feat_fn_inputs = init_fn(key2, init_feat_shape) +feat_shape, feat_fn_inputs = init_fn(key2, x.shape) # Transform input vectors to NNGP/NTK feature map -f0 = _inputs_to_features(x1) -feats = jit(features_fn)(f0, feat_fn_inputs) +feats = jit(features_fn)(x, feat_fn_inputs) -print("K_nngp :") +print(f"f_nngp shape: {feat_shape[0]}") +print("K_nngp (approx):") print(feats.nngp_feat @ feats.nngp_feat.T) print() -print("K_ntk :") +print(f"f_ntk shape: {feat_shape[1]}") +print("K_ntk (approx):") print(feats.ntk_feat @ feats.ntk_feat.T) print() @@ -88,8 +84,8 @@ ReluFeatures(**relufeat_arg), DenseFeatures(width), ReluFeatures(**relufeat_arg), DenseFeatures(1)) -f0 = _inputs_to_features(x1) -feats = jit(features_fn)(f0, feat_fn_inputs) + +feats = jit(features_fn)(x, feat_fn_inputs) print("K_nngp :") print(feats.nngp_feat @ feats.nngp_feat.T) diff --git a/experimental/test_myrtle_network.py b/experimental/test_myrtle_network.py index 52134b1e..e50f799b 100644 --- a/experimental/test_myrtle_network.py +++ b/experimental/test_myrtle_network.py @@ -13,7 +13,7 @@ from jax import random from neural_tangents import stax -from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, 
FlattenFeatures, DenseFeatures, _inputs_to_features +from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]} width = 1 @@ -94,19 +94,16 @@ def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), b_std=0., **relu_args): init_fn, feature_fn = MyrtleNetworkFeatures(5, **relufeat_arg) feature_fn = jit(feature_fn) -init_nngp_feat_shape = x.shape -init_ntk_feat_shape = (-1, 0) -init_feat_shape = (init_nngp_feat_shape, init_ntk_feat_shape) +feat_shape, feat_fn_inputs = init_fn(key2, x.shape) -inputs_shape, feat_fn_inputs = init_fn(key2, init_feat_shape) - -f0 = _inputs_to_features(x) -feats = feature_fn(f0, feat_fn_inputs) +feats = feature_fn(x, feat_fn_inputs) +print(f"f_nngp shape: {feat_shape[0]}") print("K_nngp (approx):") print(feats.nngp_feat @ feats.nngp_feat.T) print() +print(f"f_ntk shape: {feat_shape[1]}") print("K_ntk (approx):") print(feats.ntk_feat @ feats.ntk_feat.T) print() From c2bed911dc685673abc30fc3e669e17177edaeaa Mon Sep 17 00:00:00 2001 From: insuhan Date: Fri, 18 Mar 2022 03:27:55 +0900 Subject: [PATCH 10/44] Update for raw inputs --- experimental/features.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index 4c80c30a..c78f38f6 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -61,7 +61,7 @@ def _inputs_to_features(x: np.ndarray, # Followed the same initialization of Neural Tangents library. nngp_feat = x / x.shape[channel_axis]**0.5 - ntk_feat = np.empty((), dtype=nngp_feat.dtype) + ntk_feat = np.array([0.0], dtype=nngp_feat.dtype) return Features(nngp_feat=nngp_feat, ntk_feat=ntk_feat, @@ -93,11 +93,16 @@ def _is_sinlge_shape(input_shape): return True return False + +def _is_defaut_feature(feat): + return feat.ndim == 1 + + def _preprocess_init_fn(init_fn): - + def init_fn_any(rng, input_shape_any, **kwargs): if _is_sinlge_shape(input_shape_any): - input_shape = (input_shape_any, (-1,0)) + input_shape = (input_shape_any, (-1, 0)) return init_fn(rng, input_shape, **kwargs) else: return init_fn(rng, input_shape_any, **kwargs) @@ -153,7 +158,7 @@ def feature_fn(f: Features, input, **kwargs): nngp_feat *= W_std ntk_feat *= W_std - if ntk_feat.ndim == 0: # check if ntk_feat is empty + if _is_defaut_feature(ntk_feat): # check if ntk_feat is empty ntk_feat = nngp_feat else: ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) @@ -298,7 +303,7 @@ def feature_fn(f, input, **kwargs): nngp_feat = _conv2d_feat(nngp_feat, filter_size) / filter_size * W_std - if ntk_feat.ndim == 0: # check if ntk_feat is empty + if _is_defaut_feature(ntk_feat): # check if ntk_feat is empty ntk_feat = nngp_feat else: ntk_feat = _conv2d_feat(ntk_feat, filter_size) / filter_size * W_std @@ -355,8 +360,12 @@ def feature_fn(f, input=None, **kwargs): batch_size = f.nngp_feat.shape[0] nngp_feat = f.nngp_feat.reshape(batch_size, -1) / np.sqrt( _prod(f.nngp_feat.shape[1:-1])) - ntk_feat = f.ntk_feat.reshape(batch_size, -1) / np.sqrt( - _prod(f.ntk_feat.shape[1:-1])) + + if _is_defaut_feature(f.ntk_feat): # check if ntk_feat is empty + ntk_feat = f.ntk_feat + else: + ntk_feat = f.ntk_feat.reshape(batch_size, -1) / np.sqrt( + _prod(f.ntk_feat.shape[1:-1])) return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) From d874cde6b844e9ebf082f8aa054f575e29be0778 Mon Sep 17 00:00:00 2001 From: insuhan Date: Fri, 18 Mar 2022 03:29:37 +0900 Subject: [PATCH 11/44] Update 
FlattenFeatures for raw inputs --- experimental/features.py | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/features.py b/experimental/features.py index c78f38f6..9b6d86aa 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -360,7 +360,6 @@ def feature_fn(f, input=None, **kwargs): batch_size = f.nngp_feat.shape[0] nngp_feat = f.nngp_feat.reshape(batch_size, -1) / np.sqrt( _prod(f.nngp_feat.shape[1:-1])) - if _is_defaut_feature(f.ntk_feat): # check if ntk_feat is empty ntk_feat = f.ntk_feat else: From 78038f97a9ce8ba7720b13085bcb8b0493aef90b Mon Sep 17 00:00:00 2001 From: Amir Zandieh Date: Thu, 24 Mar 2022 08:44:48 +0100 Subject: [PATCH 12/44] changes to the poly sketching alg --- experimental/features.py | 195 ++++++++++++++++++++--------------- experimental/poly_fitting.py | 80 ++++++++++++++ experimental/sketching.py | 95 +++++++++++++++-- experimental/test_fc_ntk.py | 30 ++++-- 4 files changed, 304 insertions(+), 96 deletions(-) create mode 100644 experimental/poly_fitting.py diff --git a/experimental/features.py b/experimental/features.py index 9b6d86aa..ee537231 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -3,13 +3,15 @@ from jax import numpy as np from jax.numpy.linalg import cholesky import jax.example_libraries.stax as ostax +from jax.numpy import linalg as LA from neural_tangents import stax from neural_tangents._src.utils import dataclasses from neural_tangents._src.stax.linear import _pool_kernel, Padding from neural_tangents._src.stax.linear import _Pooling as Pooling -from experimental.sketching import TensorSRHT2 +from sketching import TensorSRHT2 +from poly_fitting import kappa0_coeffs, kappa1_coeffs """ Implementation for NTK Sketching and Random Features """ @@ -32,20 +34,21 @@ def _sqrt(x): def kappa0(x): xxt = x @ x.T prod = np.outer(np.linalg.norm(x, axis=-1)**2, np.linalg.norm(x, axis=-1)**2) - return (1 - _arccos(xxt / _sqrt(prod)) / np.pi) / 2 + return (1 - _arccos(xxt / _sqrt(prod)) / np.pi) def kappa1(x): xxt = x @ x.T prod = np.outer(np.linalg.norm(x, axis=-1)**2, np.linalg.norm(x, axis=-1)**2) return (_sqrt(prod - xxt**2) + - (np.pi - _arccos(xxt / _sqrt(prod))) * xxt) / np.pi / 2 + (np.pi - _arccos(xxt / _sqrt(prod))) * xxt) / np.pi @dataclasses.dataclass class Features: nngp_feat: Optional[np.ndarray] = None ntk_feat: Optional[np.ndarray] = None + norms: Optional[np.ndarray] = None batch_axis: int = 0 channel_axis: int = -1 @@ -61,10 +64,15 @@ def _inputs_to_features(x: np.ndarray, # Followed the same initialization of Neural Tangents library. 
nngp_feat = x / x.shape[channel_axis]**0.5 + norms = LA.norm(nngp_feat, axis=channel_axis) + norms = np.where(norms>0, norms, 1.0) + nngp_feat = (nngp_feat.T / norms).T + ntk_feat = np.array([0.0], dtype=nngp_feat.dtype) return Features(nngp_feat=nngp_feat, ntk_feat=ntk_feat, + norms=norms, batch_axis=batch_axis, channel_axis=channel_axis) # pytype:disable=wrong-keyword-args @@ -151,7 +159,11 @@ def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] new_ntk_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_shape[-1] + ntk_feat_shape[-1],) - return (nngp_feat_shape, new_ntk_feat_shape), () + + if len(input_shape) > 2: + return (nngp_feat_shape, new_ntk_feat_shape, input_shape[2]+'D'), () + else: + return (nngp_feat_shape, new_ntk_feat_shape, 'D'), () def feature_fn(f: Features, input, **kwargs): nngp_feat, ntk_feat = f.nngp_feat, f.ntk_feat @@ -178,88 +190,103 @@ def ReluFeatures( poly_sketch_dim0: int = 1, poly_sketch_dim1: int = 1, method: str = 'rf', + top_layer: bool = False ): - method = method.lower() - assert method in ['rf', 'ps', 'exact'] - - def init_fn(rng, input_shape): - nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) - - if method == 'rf': - rng1, rng2, rng3 = random.split(rng, 3) - # Random vectors for random features of arc-cosine kernel of order 0. - W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) - # Random vectors for random features of arc-cosine kernel of order 1. - W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) - # TensorSRHT of degree 2 for approximating tensor product. - ts2 = TensorSRHT2(rng=rng3, - input_dim1=ntk_feat_shape[-1], - input_dim2=feature_dim0, - sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args - return (new_nngp_feat_shape, new_ntk_feat_shape), (W0, W1, ts2) - - elif method == 'ps': - # rng1, rng2, rng3 = random.split(rng, 3) - # # PolySketch algorithm for arc-cosine kernel of order 0. - # ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0, - # poly_degree0) - # # PolySketch algorithm for arc-cosine kernel of order 1. - # ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1, - # poly_degree1) - # # TensorSRHT of degree 2 for approximating tensor product. - # ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim) - # return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2) - raise NotImplementedError - - elif method == 'exact': - # The exact feature map computation is for debug. - new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( - nngp_feat_shape[:-1]),) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) - return (new_nngp_feat_shape, new_ntk_feat_shape), () - - def feature_fn(f: Features, input=None, **kwargs) -> Features: - - input_shape = f.nngp_feat.shape[:-1] - nngp_feat_dim = f.nngp_feat.shape[-1] - ntk_feat_dim = f.ntk_feat.shape[-1] - - nngp_feat_2d = f.nngp_feat.reshape(-1, nngp_feat_dim) - ntk_feat_2d = f.ntk_feat.reshape(-1, ntk_feat_dim) - - if method == 'rf': # Random Features approach. 
- W0: np.ndarray = input[0] - W1: np.ndarray = input[1] - ts2: TensorSRHT2 = input[2] - - kappa0_feat = (nngp_feat_2d @ W0 > 0) / np.sqrt(W0.shape[-1]) - nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / - np.sqrt(W1.shape[-1])).reshape(input_shape + (-1,)) - ntk_feat = ts2.sketch(ntk_feat_2d, - kappa0_feat).reshape(input_shape + (-1,)) - - elif method == 'ps': - # ps0: PolyTensorSRHT = input[0] - # ps1: PolyTensorSRHT = input[1] - # ts2: TensorSRHT2 = input[2] - raise NotImplementedError - - elif method == 'exact': # Exact feature extraction via Cholesky decomposition. - nngp_feat = cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) - - ntk = ntk_feat_2d @ ntk_feat_2d.T - kappa0_mat = kappa0(nngp_feat_2d) - ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) - - else: - raise NotImplementedError - - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - - return init_fn, feature_fn + method = method.lower() + assert method in ['rf', 'ps', 'exact'] + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) + net_shape = input_shape[2] + layer_count = len(net_shape)//2+1 + + if method == 'rf': + rng1, rng2, rng3 = random.split(rng, 3) + # Random vectors for random features of arc-cosine kernel of order 0. + W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) + # Random vectors for random features of arc-cosine kernel of order 1. + W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) + # TensorSRHT of degree 2 for approximating tensor product. + ts2 = TensorSRHT2(rng=rng3, + input_dim1=ntk_feat_shape[-1], + input_dim2=feature_dim0, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (W0, W1, ts2) + + elif method == 'ps': + # rng1, rng2, rng3 = random.split(rng, 3) + # # PolySketch algorithm for arc-cosine kernel of order 0. + # ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0, + # poly_degree0) + # # PolySketch algorithm for arc-cosine kernel of order 1. + # ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1, + # poly_degree1) + # # TensorSRHT of degree 2 for approximating tensor product. + # ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim) + # return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2) + raise NotImplementedError + + elif method == 'exact': + # The exact feature map computation is for debug. + new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( + nngp_feat_shape[:-1]),) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) + + return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (layer_count) + + def feature_fn(f: Features, input=None, **kwargs) -> Features: + + input_shape = f.nngp_feat.shape[:-1] + nngp_feat_dim = f.nngp_feat.shape[-1] + ntk_feat_dim = f.ntk_feat.shape[-1] + + nngp_feat_2d = f.nngp_feat.reshape(-1, nngp_feat_dim) + ntk_feat_2d = f.ntk_feat.reshape(-1, ntk_feat_dim) + + if method == 'rf': # Random Features approach. 
+ W0: np.ndarray = input[0] + W1: np.ndarray = input[1] + ts2: TensorSRHT2 = input[2] + + kappa0_feat = (nngp_feat_2d @ W0 > 0) / np.sqrt(W0.shape[-1]) + nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / + np.sqrt(W1.shape[-1])).reshape(input_shape + (-1,)) + ntk_feat = ts2.sketch(ntk_feat_2d, + kappa0_feat).reshape(input_shape + (-1,)) + + elif method == 'ps': + # ps0: PolyTensorSRHT = input[0] + # ps1: PolyTensorSRHT = input[1] + # ts2: TensorSRHT2 = input[2] + raise NotImplementedError + + elif method == 'exact': # Exact feature extraction via Cholesky decomposition. + layer_count = input + + nngp_feat = cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) + + ntk = ntk_feat_2d @ ntk_feat_2d.T + kappa0_mat = kappa0(nngp_feat_2d) + ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + + if top_layer: + ntk_feat = (1 / 2**(layer_count/2))*ntk_feat + nngp_feat = (1 / 2**(layer_count/2))*nngp_feat + + else: + raise NotImplementedError + + if top_layer: + ntk_feat = (ntk_feat.T * f.norms).T + nngp_feat = (nngp_feat.T * f.norms).T + + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + return init_fn, feature_fn def _conv_feat(X, filter_size): diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py new file mode 100644 index 00000000..bcc05307 --- /dev/null +++ b/experimental/poly_fitting.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue Mar 22 16:56:26 2022 + +@author: amir +""" + +import numpy as onp +import quadprog + + +from matplotlib import pyplot as plt + + + +def quadprog_solve_qp(P, q, G=None, h=None, A=None, b=None): + qp_G = .5 * (P + P.T +1e-5*onp.eye(P.shape[0])) # make sure P is symmetric + qp_a = -q + if A is not None: + qp_C = -onp.vstack([A, G]).T + qp_b = -onp.hstack([b, h]) + meq = A.shape[0] + else: # no equality constraint + qp_C = -G.T + qp_b = -h + meq = 0 + return quadprog.solve_qp(qp_G, qp_a, qp_C, qp_b, meq)[0] + + + +def kappa1_coeffs(degree,h): + alpha_ = -1.0 + for i in range(h): + alpha_ = (2.0*alpha_ + (onp.sqrt(1-alpha_**2) + alpha_*(onp.pi - onp.arccos(alpha_)))/onp.pi)/3.0 + + n=15*h+5*degree + x = onp.sort(onp.concatenate((onp.linspace(alpha_, 1.0, num=201), onp.cos((2*onp.arange(n)+1)*onp.pi / (4*n))), axis=0)) + y = (onp.sqrt(1-x**2) + x*(onp.pi - onp.arccos(x)))/onp.pi + grid_len = len(x) + + Z = onp.zeros((grid_len,degree+1)) + Z[:,0] = onp.ones(grid_len) + for i in range(degree): + Z[:,i+1] = Z[:,i] * x + + w = y + U = Z.T + + beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len), Z[grid_len-1,:][onp.newaxis,:],y[grid_len-1]) + beta_[beta_ < 1e-5] = 0 + + return beta_ + + + +def kappa0_coeffs(degree,h): + alpha_ = -1.0 + for i in range(h): + alpha_ = (1.0*alpha_ + (onp.sqrt(1-alpha_**2) + alpha_*(onp.pi - onp.arccos(alpha_)))/onp.pi)/2.0 + + n=15*h+5*degree + x = onp.sort(onp.concatenate((onp.linspace(alpha_, 1.0, num=201), onp.cos((2*onp.arange(n)+1)*onp.pi / (4*n))), axis=0)) + y = (onp.pi - onp.arccos(x))/onp.pi + grid_len = len(x) + + + Z = onp.zeros((grid_len,degree+1)) + Z[:,0] = onp.ones(grid_len) + for i in range(degree): + Z[:,i+1] = Z[:,i] * x + + weight_ = onp.linspace(0.0, 1.0, num=grid_len) + 1/2 + w = y * weight_ + U = Z.T * weight_ + + beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len))#, Z[200,:][np.newaxis,:],y[200]) + 
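+    # [Editorial note] The stacked inequality constraints above force the
+    # fitted polynomial to be nondecreasing on the grid
+    # ((Z[i] - Z[i+1]) @ beta <= 0) and its coefficients to be nonnegative
+    # (-I @ beta <= 0); the next line clips solver noise below 1e-5 to
+    # exact zeros so downstream sketches can skip those degrees.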
beta_[beta_ < 1e-5] = 0 + + return beta_ \ No newline at end of file diff --git a/experimental/sketching.py b/experimental/sketching.py index 80a76139..5b629eb9 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -45,6 +45,76 @@ def sketch(self, x1, x2): out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft) return np.concatenate((out.real, out.imag), 1) +# Standard SRHT with real valued output +@dataclasses.dataclass +class SRHT: + + input_dim: int + sketch_dim: int + + rng: np.ndarray + shape: Optional[np.ndarray] = None + + rand_signs: Optional[np.ndarray] = None + rand_inds: Optional[np.ndarray] = None + + replace = ... # type: Callable[..., 'TensorSRHT2'] + + def init_sketches(self) -> 'SRHT': + rng1, rng2 = random.split(self.rng, 2) + rand_signs = random.choice(rng1, 2, shape=(self.input_dim,)) * 2 - 1 + rand_inds = random.choice(rng2, self.input_dim, + shape=(self.sketch_dim // 2,)) + shape = (self.input_dim, self.sketch_dim) + return self.replace(shape=shape, + rand_signs=rand_signs, + rand_inds=rand_inds) + + def sketch(self, x): + xfft = np.fft.fftn(x * self.rand_signs, axes=(-1,))[:, self.rand_inds] + out = np.sqrt(1 / self.rand_inds.shape[-1]) * xfft + return np.concatenate((out.real, out.imag), 1) + +# TensorSRHT of degree 2 with complex valued output. This version allows different input vectors. +@dataclasses.dataclass +class CmplxTensorSRHT: + + input_dim1: int + input_dim2: int + sketch_dim: int + + rng: np.ndarray + shape: Optional[np.ndarray] = None + + rand_signs1: Optional[np.ndarray] = None + rand_signs2: Optional[np.ndarray] = None + rand_inds1: Optional[np.ndarray] = None + rand_inds2: Optional[np.ndarray] = None + + replace = ... # type: Callable[..., 'TensorSRHT2'] + + def init_sketches(self) -> 'CmplxTensorSRHT': + rng1, rng2, rng3, rng4 = random.split(self.rng, 4) + rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1 + rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1 + rand_inds1 = random.choice(rng3,self.input_dim1, + shape=(self.sketch_dim // 2,)) + rand_inds2 = random.choice(rng4,self.input_dim2, + shape=(self.sketch_dim // 2,)) + shape = (self.input_dim1, self.input_dim2, self.sketch_dim) + return self.replace(shape=shape, + rand_signs1=rand_signs1, + rand_signs2=rand_signs2, + rand_inds1=rand_inds1, + rand_inds2=rand_inds2) + + def sketch(self, x1, x2): + x1fft = np.fft.fftn(x1 * self.rand_signs1, axes=(-1,))[:, self.rand_inds1] + x2fft = np.fft.fftn(x2 * self.rand_signs2, axes=(-1,))[:, self.rand_inds2] + out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft) + return out + + # Standard SRHT as a function def standardsrht(x, rand_inds, rand_signs): xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds] @@ -72,7 +142,7 @@ def init_sketch(self) -> 'PolyTensorSketch': tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())] rng1, rng3 = random.split(self.rng, 2) - ske_dim_ = self.sketch_dim // 4 + ske_dim_ = self.sketch_dim // 4 -1 deg_ = self.degree // 2 for i in range((self.degree - 1).bit_length()): @@ -92,9 +162,9 @@ def init_sketch(self) -> 'PolyTensorSketch': deg_ = deg_ // 2 rng1, rng2 = random.split(rng3,2) - rand_signs = random.choice(rng1, 2, shape=(self.degree * ske_dim_,)) * 2 - 1 + rand_signs = random.choice(rng1, 2, shape=(1+self.degree * ske_dim_,)) * 2 - 1 rand_inds = random.choice(rng2, - self.degree * ske_dim_,shape=(self.sketch_dim // 2,)) + 1+self.degree * ske_dim_,shape=(self.sketch_dim // 2,)) return 
self.replace(tree_rand_signs=tree_rand_signs, tree_rand_inds=tree_rand_inds, @@ -136,7 +206,7 @@ def sketch(self, x): self.tree_rand_signs[i][j, :, :])) U = [0 for i in range(2**log_degree)] - U[0] = V[log_degree - 1][0, :, :].clone() + U[0] = V[log_degree - 1][0, :, :] SetE1 = set() @@ -153,14 +223,27 @@ def sketch(self, x): self.tree_rand_signs[i][p, 0, :])) else: if (i-1,2*p) in SetE1: - V[i] = V[i].at[p, :, :].set(V[i-1][2*p+1,:,:].clone()) + V[i] = V[i].at[p, :, :].set(V[i-1][2*p+1,:,:]) else: V[i] = V[i].at[p, :, :].set( self.tensorsrht(V[i - 1][2 * p, :, :], V[i - 1][2 * p + 1, :, :], self.tree_rand_inds[i][p, :, :], self.tree_rand_signs[i][p, :, :])) p = p // 2 - U[j] = V[log_degree - 1][0, :, :].clone() + U[j] = V[log_degree - 1][0, :, :] return U + + def expand_feats(self, polysketch_feats, coeffs): + n, sktch_dim = polysketch_feats[0].shape + Z = np.zeros((len(self.rand_signs),n), dtype=np.complex64) + Z = Z.at[1,:].set(np.sqrt(coeffs[0]) * np.ones(n)) + degree = len(polysketch_feats) + for i in range(degree): + Z = Z.at[sktch_dim*i+1:sktch_dim*(i+1)+1,:].set(np.sqrt( coeffs[i+1] ) * + polysketch_feats[degree-i-1].T) + + return Z.T + + # pytype: enable=attribute-error \ No newline at end of file diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py index be91063a..b6811e7c 100644 --- a/experimental/test_fc_ntk.py +++ b/experimental/test_fc_ntk.py @@ -8,7 +8,7 @@ config.update("jax_enable_x64", True) from neural_tangents import stax -from experimental.features import DenseFeatures, ReluFeatures, serial +from features import DenseFeatures, ReluFeatures, serial seed = 1 n, d = 6, 4 @@ -21,6 +21,7 @@ print("================= Result of Neural Tangent Library =================") init_fn, _, kernel_fn = stax.serial(stax.Dense(width), stax.Relu(), + stax.Dense(width), stax.Relu(), stax.Dense(width), stax.Relu(), stax.Dense(1)) @@ -36,9 +37,9 @@ print("================= Result of NTK Random Features =================") -kappa0_feat_dim = 10000 -kappa1_feat_dim = 10000 -sketch_dim = 20000 +kappa0_feat_dim = 2048 +kappa1_feat_dim = 2048 +sketch_dim = 2048 relufeat_arg = { 'method': 'rf', @@ -47,10 +48,20 @@ 'sketch_dim': sketch_dim, } +relufeat_arg_top = { + 'method': 'rf', + 'feature_dim0': kappa0_feat_dim, + 'feature_dim1': kappa1_feat_dim, + 'sketch_dim': sketch_dim, + 'top_layer': True +} + init_fn, features_fn = serial(DenseFeatures(width), ReluFeatures(**relufeat_arg), DenseFeatures(width), - ReluFeatures(**relufeat_arg), DenseFeatures(1)) + ReluFeatures(**relufeat_arg), + DenseFeatures(width), + ReluFeatures(**relufeat_arg_top), DenseFeatures(1)) # Initialize random vectors and sketching algorithms feat_shape, feat_fn_inputs = init_fn(key2, x.shape) @@ -79,11 +90,18 @@ print("================= (Debug) Exact NTK Feature Maps =================") relufeat_arg = {'method': 'exact'} +relufeat_arg_top = {'method': 'exact', 'top_layer': True} + init_fn, features_fn = serial(DenseFeatures(width), ReluFeatures(**relufeat_arg), DenseFeatures(width), - ReluFeatures(**relufeat_arg), DenseFeatures(1)) + ReluFeatures(**relufeat_arg), + DenseFeatures(width), + ReluFeatures(**relufeat_arg_top), DenseFeatures(1)) + +# Initialize random vectors and sketching algorithms +feat_shape, feat_fn_inputs = init_fn(key2, x.shape) feats = jit(features_fn)(x, feat_fn_inputs) From 09ea575497c3d9c13d6044cbf181441dc41fa739 Mon Sep 17 00:00:00 2001 From: Amir Zandieh Date: Thu, 24 Mar 2022 17:25:48 +0100 Subject: [PATCH 13/44] fc ntk sketch --- experimental/features.py | 136 
++++++++++++++++++++++++++++------- experimental/poly_fitting.py | 7 +- experimental/sketching.py | 2 +- experimental/test_fc_ntk.py | 16 +++-- 4 files changed, 124 insertions(+), 37 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index ee537231..a806a4c6 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -10,7 +10,7 @@ from neural_tangents._src.stax.linear import _pool_kernel, Padding from neural_tangents._src.stax.linear import _Pooling as Pooling -from sketching import TensorSRHT2 +from sketching import TensorSRHT2, PolyTensorSketch, CmplxTensorSRHT from poly_fitting import kappa0_coeffs, kappa1_coeffs """ Implementation for NTK Sketching and Random Features """ @@ -43,6 +43,13 @@ def kappa1(x): return (_sqrt(prod - xxt**2) + (np.pi - _arccos(xxt / _sqrt(prod))) * xxt) / np.pi +def poly_expansion(x, coeffs): + y = np.ones_like(x) + results = np.zeros_like(x) + for c in coeffs: + results += c*y + y = y*x + return results @dataclasses.dataclass class Features: @@ -185,16 +192,14 @@ def ReluFeatures( feature_dim0: int = 1, feature_dim1: int = 1, sketch_dim: int = 1, - poly_degree0: int = 4, - poly_degree1: int = 4, - poly_sketch_dim0: int = 1, - poly_sketch_dim1: int = 1, + poly_degree: int = 4, + poly_sketch_dim: int = 1, method: str = 'rf', top_layer: bool = False ): method = method.lower() - assert method in ['rf', 'ps', 'exact'] + assert method in ['rf', 'ps', 'exact', 'psrf'] def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] @@ -215,19 +220,47 @@ def init_fn(rng, input_shape): input_dim2=feature_dim0, sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (W0, W1, ts2) - + + elif method == 'psrf': + new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) + rng1, rng2, rng3 = random.split(rng, 3) + + kappa1_coeff = kappa1_coeffs(poly_degree,layer_count-1) + + # PolySketch expansion for nngp features. + polysketch = PolyTensorSketch(rng1, nngp_feat_shape[-1], + poly_sketch_dim, poly_degree).init_sketch() + # TensorSRHT of degree 2 for approximating tensor product. + tensorsrht = TensorSRHT2(input_dim1=ntk_feat_shape[-1], + input_dim2=feature_dim0, + sketch_dim=sketch_dim , rng=rng2).init_sketches() + + # Random vectors for random features of arc-cosine kernel of order 0. + # W0 = random.choice(rng3, 2, shape=(nngp_feat_shape[-1], feature_dim0//2)) * 2 - 1 + if layer_count ==1: + W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0//2)) + else: + W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0//2),dtype='float32') + + return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (W0, polysketch, tensorsrht, (kappa1_coeff, layer_count)) + elif method == 'ps': - # rng1, rng2, rng3 = random.split(rng, 3) - # # PolySketch algorithm for arc-cosine kernel of order 0. - # ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0, - # poly_degree0) - # # PolySketch algorithm for arc-cosine kernel of order 1. - # ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1, - # poly_degree1) - # # TensorSRHT of degree 2 for approximating tensor product. 
- # ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim) - # return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2) - raise NotImplementedError + new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) + rng1, rng2, rng3 = random.split(rng, 3) + + kappa1_coeff = kappa1_coeffs(poly_degree,layer_count-1) + kappa0_coeff = kappa0_coeffs(poly_degree,layer_count-1) + + # PolySketch expansion for nngp features. + polysketch = PolyTensorSketch(rng1, nngp_feat_shape[-1]//(1+(layer_count>1)), + poly_sketch_dim, poly_degree).init_sketch() + # TensorSRHT of degree 2 for approximating tensor product. + tensorsrht = CmplxTensorSRHT(input_dim1=ntk_feat_shape[-1]//(1+(layer_count>1)), + input_dim2=poly_degree*(polysketch.sketch_dim//4-1)+1, + sketch_dim=sketch_dim , rng=rng2).init_sketches() + + return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (polysketch, tensorsrht,(kappa0_coeff,kappa1_coeff, layer_count)) + raise NotImplementedError elif method == 'exact': # The exact feature map computation is for debug. @@ -235,7 +268,10 @@ def init_fn(rng, input_shape): nngp_feat_shape[:-1]),) new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) - return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (layer_count) + kappa1_coeff = kappa1_coeffs(poly_degree,layer_count-1) + kappa0_coeff = kappa0_coeffs(poly_degree,layer_count-1) + + return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (kappa0_coeff, kappa1_coeff, layer_count) def feature_fn(f: Features, input=None, **kwargs) -> Features: @@ -256,20 +292,66 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: np.sqrt(W1.shape[-1])).reshape(input_shape + (-1,)) ntk_feat = ts2.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) + + elif method == 'psrf': # Combination of Poly Sketch and Random Features. 
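+      # [Editorial summary] The PSRF branch below proceeds in three steps:
+      # (i) PolySketch the unit-norm NNGP features, expand with the kappa1
+      # polynomial coefficients, and mix with a standard SRHT to obtain the
+      # new NNGP features; (ii) approximate kappa0 with the signed random
+      # projection W0; (iii) apply TensorSRHT to the product of the previous
+      # NTK features with the kappa0 features.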
+ W0: np.ndarray = input[0] + polysketch: PolyTensorSketch = input[1] + ts2: TensorSRHT2 = input[2] + kappa1_coeff: np.ndarray = input[3][0] + layer_count = input[3][1] + + polysketch_feats = polysketch.sketch(nngp_feat_2d) + kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) + del polysketch_feats + nngp_feat = polysketch.standardsrht(kappa1_feat, polysketch.rand_inds, + polysketch.rand_signs).reshape(input_shape + (-1,)) + nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) + + nngp_proj = np.dot(nngp_feat_2d , W0) + + + kappa0_feat = np.concatenate(((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / np.sqrt(W0.shape[-1]) + del W0 + ntk_feat = ts2.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) + if top_layer: + ntk_feat = (1 / 2**(layer_count/2))*ntk_feat + nngp_feat = (1 / 2**(layer_count/2))*nngp_feat + elif method == 'ps': - # ps0: PolyTensorSRHT = input[0] - # ps1: PolyTensorSRHT = input[1] - # ts2: TensorSRHT2 = input[2] - raise NotImplementedError + polysketch: PolyTensorSketch = input[0] + tensorsrht: CmplxTensorSRHT = input[1] + kappa0_coeff: np.ndarray = input[2][0] + kappa1_coeff: np.ndarray = input[2][1] + layer_count = input[2][2] + + polysketch_feats = polysketch.sketch(nngp_feat_2d) + kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) + nngp_feat = polysketch.standardsrht(kappa1_feat, polysketch.rand_inds, + polysketch.rand_signs).reshape(input_shape + (-1,)) + + kappa0_feat = polysketch.expand_feats(polysketch_feats, kappa0_coeff) + del polysketch_feats + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) + + if top_layer: + ntk_feat = (1 / 2**(layer_count/2))*np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1) + nngp_feat = (1 / 2**(layer_count/2))*np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) + - elif method == 'exact': # Exact feature extraction via Cholesky decomposition. 
- layer_count = input + elif method == 'exact': + + kappa0_coeff: np.ndarray = input[0] + kappa1_coeff: np.ndarray = input[1] + layer_count = input[2] + + gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) + nngp_feat = cholesky(poly_expansion(gram_nngp, kappa1_coeff)).reshape(input_shape + (-1,)) - nngp_feat = cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) ntk = ntk_feat_2d @ ntk_feat_2d.T - kappa0_mat = kappa0(nngp_feat_2d) + kappa0_mat = poly_expansion(gram_nngp, kappa0_coeff) + # kappa0_mat = kappa0(nngp_feat_2d) ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) if top_layer: diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py index bcc05307..04bc8a2f 100644 --- a/experimental/poly_fitting.py +++ b/experimental/poly_fitting.py @@ -59,7 +59,7 @@ def kappa0_coeffs(degree,h): for i in range(h): alpha_ = (1.0*alpha_ + (onp.sqrt(1-alpha_**2) + alpha_*(onp.pi - onp.arccos(alpha_)))/onp.pi)/2.0 - n=15*h+5*degree + n=20*h+8*degree x = onp.sort(onp.concatenate((onp.linspace(alpha_, 1.0, num=201), onp.cos((2*onp.arange(n)+1)*onp.pi / (4*n))), axis=0)) y = (onp.pi - onp.arccos(x))/onp.pi grid_len = len(x) @@ -70,9 +70,8 @@ def kappa0_coeffs(degree,h): for i in range(degree): Z[:,i+1] = Z[:,i] * x - weight_ = onp.linspace(0.0, 1.0, num=grid_len) + 1/2 - w = y * weight_ - U = Z.T * weight_ + w = y + U = Z.T beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len))#, Z[200,:][np.newaxis,:],y[200]) beta_[beta_ < 1e-5] = 0 diff --git a/experimental/sketching.py b/experimental/sketching.py index 5b629eb9..a2d7b969 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -237,7 +237,7 @@ def sketch(self, x): def expand_feats(self, polysketch_feats, coeffs): n, sktch_dim = polysketch_feats[0].shape Z = np.zeros((len(self.rand_signs),n), dtype=np.complex64) - Z = Z.at[1,:].set(np.sqrt(coeffs[0]) * np.ones(n)) + Z = Z.at[0,:].set(np.sqrt(coeffs[0]) * np.ones(n)) degree = len(polysketch_feats) for i in range(degree): Z = Z.at[sktch_dim*i+1:sktch_dim*(i+1)+1,:].set(np.sqrt( coeffs[i+1] ) * diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py index b6811e7c..d359db8e 100644 --- a/experimental/test_fc_ntk.py +++ b/experimental/test_fc_ntk.py @@ -11,7 +11,7 @@ from features import DenseFeatures, ReluFeatures, serial seed = 1 -n, d = 6, 4 +n, d = 6, 5 key1, key2 = random.split(random.PRNGKey(seed)) x = random.normal(key1, (n, d)) @@ -40,19 +40,25 @@ kappa0_feat_dim = 2048 kappa1_feat_dim = 2048 sketch_dim = 2048 +poly_degree = 4 +poly_sketch_dim = 2048 relufeat_arg = { - 'method': 'rf', + 'method': 'psrf', 'feature_dim0': kappa0_feat_dim, 'feature_dim1': kappa1_feat_dim, 'sketch_dim': sketch_dim, + 'poly_degree': poly_degree, + 'poly_sketch_dim': poly_sketch_dim } relufeat_arg_top = { - 'method': 'rf', + 'method': 'psrf', 'feature_dim0': kappa0_feat_dim, 'feature_dim1': kappa1_feat_dim, 'sketch_dim': sketch_dim, + 'poly_degree': poly_degree, + 'poly_sketch_dim': poly_sketch_dim, 'top_layer': True } @@ -89,8 +95,8 @@ print("================= (Debug) Exact NTK Feature Maps =================") -relufeat_arg = {'method': 'exact'} -relufeat_arg_top = {'method': 'exact', 'top_layer': True} +relufeat_arg = {'poly_degree': poly_degree, 'method': 'exact'} +relufeat_arg_top = {'poly_degree': poly_degree, 'method': 'exact', 'top_layer': True} init_fn, features_fn = serial(DenseFeatures(width), From 
71a0946e5d21e034456e710009d7d3b40575f1a6 Mon Sep 17 00:00:00 2001 From: Amir Zandieh Date: Thu, 24 Mar 2022 20:08:19 +0100 Subject: [PATCH 14/44] poly fitting using jaxopt --- experimental/features.py | 21 +++++++++++---------- experimental/poly_fitting.py | 28 ++++++++++++++++------------ experimental/test_fc_ntk.py | 4 ++-- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index a806a4c6..395d7e8a 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -228,17 +228,17 @@ def init_fn(rng, input_shape): kappa1_coeff = kappa1_coeffs(poly_degree,layer_count-1) # PolySketch expansion for nngp features. - polysketch = PolyTensorSketch(rng1, nngp_feat_shape[-1], + polysketch = PolyTensorSketch(rng1, nngp_feat_shape[-1]//(1+(layer_count>1)), poly_sketch_dim, poly_degree).init_sketch() # TensorSRHT of degree 2 for approximating tensor product. - tensorsrht = TensorSRHT2(input_dim1=ntk_feat_shape[-1], + tensorsrht = CmplxTensorSRHT(input_dim1=ntk_feat_shape[-1]//(1+(layer_count>1)), input_dim2=feature_dim0, sketch_dim=sketch_dim , rng=rng2).init_sketches() # Random vectors for random features of arc-cosine kernel of order 0. # W0 = random.choice(rng3, 2, shape=(nngp_feat_shape[-1], feature_dim0//2)) * 2 - 1 if layer_count ==1: - W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0//2)) + W0 = random.normal(rng3, (2*nngp_feat_shape[-1], feature_dim0//2)) else: W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0//2),dtype='float32') @@ -288,15 +288,17 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: ts2: TensorSRHT2 = input[2] kappa0_feat = (nngp_feat_2d @ W0 > 0) / np.sqrt(W0.shape[-1]) + del W0 nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / np.sqrt(W1.shape[-1])).reshape(input_shape + (-1,)) + del W1 ntk_feat = ts2.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) elif method == 'psrf': # Combination of Poly Sketch and Random Features. 
W0: np.ndarray = input[0] polysketch: PolyTensorSketch = input[1] - ts2: TensorSRHT2 = input[2] + tensorsrht: CmplxTensorSRHT = input[2] kappa1_coeff: np.ndarray = input[3][0] layer_count = input[3][1] @@ -305,17 +307,16 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: del polysketch_feats nngp_feat = polysketch.standardsrht(kappa1_feat, polysketch.rand_inds, polysketch.rand_signs).reshape(input_shape + (-1,)) - nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) + # nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) - nngp_proj = np.dot(nngp_feat_2d , W0) - + nngp_proj = np.dot(np.concatenate((nngp_feat_2d.real, nngp_feat_2d.imag), axis=1) , W0) kappa0_feat = np.concatenate(((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / np.sqrt(W0.shape[-1]) del W0 - ntk_feat = ts2.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) if top_layer: - ntk_feat = (1 / 2**(layer_count/2))*ntk_feat - nngp_feat = (1 / 2**(layer_count/2))*nngp_feat + ntk_feat = (1 / 2**(layer_count/2))*np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1) + nngp_feat = (1 / 2**(layer_count/2))*np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) elif method == 'ps': diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py index 04bc8a2f..a279739e 100644 --- a/experimental/poly_fitting.py +++ b/experimental/poly_fitting.py @@ -8,6 +8,8 @@ import numpy as onp import quadprog +from jax import numpy as jnp +from jaxopt import OSQP from matplotlib import pyplot as plt @@ -15,17 +17,14 @@ def quadprog_solve_qp(P, q, G=None, h=None, A=None, b=None): - qp_G = .5 * (P + P.T +1e-5*onp.eye(P.shape[0])) # make sure P is symmetric - qp_a = -q + qp_Q = .5 * (P + P.T +1e-5*jnp.eye(P.shape[0])) # make sure P is symmetric + + qp = OSQP() if A is not None: - qp_C = -onp.vstack([A, G]).T - qp_b = -onp.hstack([b, h]) - meq = A.shape[0] - else: # no equality constraint - qp_C = -G.T - qp_b = -h - meq = 0 - return quadprog.solve_qp(qp_G, qp_a, qp_C, qp_b, meq)[0] + sol = qp.run(params_obj=(qp_Q, q), params_eq=(A, b), params_ineq=(G, h)).params + else: + sol = qp.run(params_obj=(qp_Q, q), params_eq=None, params_ineq=(G, h)).params + return onp.array(sol.primal) @@ -47,7 +46,10 @@ def kappa1_coeffs(degree,h): w = y U = Z.T - beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len), Z[grid_len-1,:][onp.newaxis,:],y[grid_len-1]) + beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , + onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], + -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len), + Z[grid_len-1,:][onp.newaxis,:],onp.array([y[grid_len-1]])) beta_[beta_ < 1e-5] = 0 return beta_ @@ -73,7 +75,9 @@ def kappa0_coeffs(degree,h): w = y U = Z.T - beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len))#, Z[200,:][np.newaxis,:],y[200]) + beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , + onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], + -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len)) beta_[beta_ < 1e-5] = 0 return beta_ \ No newline at end of file diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py index d359db8e..1a678c6f 100644 --- a/experimental/test_fc_ntk.py +++ b/experimental/test_fc_ntk.py @@ -39,9 +39,9 @@ 
kappa0_feat_dim = 2048 kappa1_feat_dim = 2048 -sketch_dim = 2048 +sketch_dim = 4096 poly_degree = 4 -poly_sketch_dim = 2048 +poly_sketch_dim = 4096 relufeat_arg = { 'method': 'psrf', From 6404c1b1f1c998519a0f6470183ab7bc00b2b9e8 Mon Sep 17 00:00:00 2001 From: Amir Zandieh Date: Thu, 24 Mar 2022 20:45:02 +0100 Subject: [PATCH 15/44] poly fitting minor edit --- experimental/poly_fitting.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py index a279739e..86c6c2fc 100644 --- a/experimental/poly_fitting.py +++ b/experimental/poly_fitting.py @@ -47,8 +47,7 @@ def kappa1_coeffs(degree,h): U = Z.T beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , - onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], - -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len), + -onp.eye(degree+1), onp.zeros(degree+1), Z[grid_len-1,:][onp.newaxis,:],onp.array([y[grid_len-1]])) beta_[beta_ < 1e-5] = 0 @@ -76,8 +75,7 @@ def kappa0_coeffs(degree,h): U = Z.T beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , - onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], - -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len)) + -onp.eye(degree+1), onp.zeros(degree+1)) beta_[beta_ < 1e-5] = 0 return beta_ \ No newline at end of file From 5492d68082b24a440d59ef635f9bc2874cf5b256 Mon Sep 17 00:00:00 2001 From: insuhan Date: Fri, 25 Mar 2022 10:05:28 +0900 Subject: [PATCH 16/44] Make poly_fitting jittable --- experimental/features.py | 34 ++---- experimental/poly_fitting.py | 196 +++++++++++++++++++++-------------- 2 files changed, 125 insertions(+), 105 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index 395d7e8a..eb98c0c0 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -22,34 +22,14 @@ def _prod(tuple_): return prod -# Arc-cosine kernel functions is for debugging. 
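(Aside on the helpers removed in this hunk: they compute the order-0 and order-1 arc-cosine kernels kappa0 and kappa1 in closed form, and reappear in `poly_fitting.py` below. A standalone Monte-Carlo sanity check of those closed forms, plain NumPy and independent of this repo: for unit vectors with correlation rho and w ~ N(0, I), E[step(w.x) step(w.y)] = kappa0(rho)/2 and E[relu(w.x) relu(w.y)] = kappa1(rho)/2, which is also the identity the 'rf' branch exploits with finitely many rows of W.)

```python
# Illustrative sketch only, not part of these patches.
import numpy as onp

d, m = 32, 100_000
rng = onp.random.default_rng(0)
x = rng.normal(size=d); x /= onp.linalg.norm(x)
y = rng.normal(size=d); y /= onp.linalg.norm(y)
rho = x @ y
theta = onp.arccos(onp.clip(rho, -1., 1.))

W = rng.normal(size=(m, d))
mc_kappa0 = 2 * onp.mean((W @ x > 0) * (W @ y > 0))
mc_kappa1 = 2 * onp.mean(onp.maximum(W @ x, 0) * onp.maximum(W @ y, 0))

print(mc_kappa0, (onp.pi - theta) / onp.pi)                                 # kappa0(rho)
print(mc_kappa1, (onp.sqrt(1. - rho**2) + (onp.pi - theta) * rho) / onp.pi) # kappa1(rho)
```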
-def _arccos(x): - return np.arccos(np.clip(x, -1, 1)) - - -def _sqrt(x): - return np.sqrt(np.maximum(x, 1e-20)) - - -def kappa0(x): - xxt = x @ x.T - prod = np.outer(np.linalg.norm(x, axis=-1)**2, np.linalg.norm(x, axis=-1)**2) - return (1 - _arccos(xxt / _sqrt(prod)) / np.pi) - - -def kappa1(x): - xxt = x @ x.T - prod = np.outer(np.linalg.norm(x, axis=-1)**2, np.linalg.norm(x, axis=-1)**2) - return (_sqrt(prod - xxt**2) + - (np.pi - _arccos(xxt / _sqrt(prod))) * xxt) / np.pi - def poly_expansion(x, coeffs): - y = np.ones_like(x) - results = np.zeros_like(x) - for c in coeffs: - results += c*y - y = y*x - return results + y = np.ones_like(x) + results = np.zeros_like(x) + for c in coeffs: + results += c*y + y = y*x + return results + @dataclasses.dataclass class Features: diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py index 86c6c2fc..17240581 100644 --- a/experimental/poly_fitting.py +++ b/experimental/poly_fitting.py @@ -1,81 +1,121 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Tue Mar 22 16:56:26 2022 - -@author: amir -""" - -import numpy as onp -import quadprog -from jax import numpy as jnp +from jax import numpy as np +from jax import lax from jaxopt import OSQP -from matplotlib import pyplot as plt - - - -def quadprog_solve_qp(P, q, G=None, h=None, A=None, b=None): - qp_Q = .5 * (P + P.T +1e-5*jnp.eye(P.shape[0])) # make sure P is symmetric - - qp = OSQP() - if A is not None: - sol = qp.run(params_obj=(qp_Q, q), params_eq=(A, b), params_ineq=(G, h)).params - else: - sol = qp.run(params_obj=(qp_Q, q), params_eq=None, params_ineq=(G, h)).params - return onp.array(sol.primal) - - - -def kappa1_coeffs(degree,h): - alpha_ = -1.0 - for i in range(h): - alpha_ = (2.0*alpha_ + (onp.sqrt(1-alpha_**2) + alpha_*(onp.pi - onp.arccos(alpha_)))/onp.pi)/3.0 - - n=15*h+5*degree - x = onp.sort(onp.concatenate((onp.linspace(alpha_, 1.0, num=201), onp.cos((2*onp.arange(n)+1)*onp.pi / (4*n))), axis=0)) - y = (onp.sqrt(1-x**2) + x*(onp.pi - onp.arccos(x)))/onp.pi - grid_len = len(x) - - Z = onp.zeros((grid_len,degree+1)) - Z[:,0] = onp.ones(grid_len) - for i in range(degree): - Z[:,i+1] = Z[:,i] * x - - w = y - U = Z.T - - beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , - -onp.eye(degree+1), onp.zeros(degree+1), - Z[grid_len-1,:][onp.newaxis,:],onp.array([y[grid_len-1]])) - beta_[beta_ < 1e-5] = 0 - - return beta_ - - - -def kappa0_coeffs(degree,h): - alpha_ = -1.0 - for i in range(h): - alpha_ = (1.0*alpha_ + (onp.sqrt(1-alpha_**2) + alpha_*(onp.pi - onp.arccos(alpha_)))/onp.pi)/2.0 - - n=20*h+8*degree - x = onp.sort(onp.concatenate((onp.linspace(alpha_, 1.0, num=201), onp.cos((2*onp.arange(n)+1)*onp.pi / (4*n))), axis=0)) - y = (onp.pi - onp.arccos(x))/onp.pi - grid_len = len(x) - - - Z = onp.zeros((grid_len,degree+1)) - Z[:,0] = onp.ones(grid_len) - for i in range(degree): - Z[:,i+1] = Z[:,i] * x - - w = y - U = Z.T - - beta_ = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , - -onp.eye(degree+1), onp.zeros(degree+1)) - beta_[beta_ < 1e-5] = 0 - - return beta_ \ No newline at end of file +def _arccos(x): + return np.arccos(np.clip(x, -1, 1)) + + +def _sqrt(x): + return np.sqrt(np.maximum(x, 1e-20)) + + +def kappa0(x, is_x_matrix=True): + if is_x_matrix: + xxt = x @ x.T + xnormsq = np.linalg.norm(x, axis=-1)**2 + prod = np.outer(xnormsq, xnormsq) + return (1 - _arccos(xxt / _sqrt(prod)) / np.pi) + else: # vector input + return (1 - _arccos(x) / np.pi) + + +def kappa1(x, is_x_matrix=True): + if is_x_matrix: + xxt = x @ x.T + xnormsq = 
np.linalg.norm(x, axis=-1)**2 + prod = np.outer(xnormsq, xnormsq) + return (_sqrt(prod - xxt**2) + + (np.pi - _arccos(xxt / _sqrt(prod))) * xxt) / np.pi + else: # vector input + return (_sqrt(1 - x**2) + (np.pi - _arccos(x)) * x) / np.pi + + +def poly_fitting_qp(xvals: np.ndarray, + fvals: np.ndarray, + weights: np.ndarray, + degree: int, + eq_last_point: bool = False): + """ Computes polynomial coefficients that fitting input observations. + For a dot-product kernel (e.g., kappa0 or kappa1), coefficients of its + Taylor series expansion are always nonnegative. Moreover, the kernel + function is a monotone increasing function. This can be solved by + Quadratic Programming (QP) under inequality constraints. + """ + nx = len(xvals) + x_powers = np.ones((nx, degree + 1), dtype=xvals.dtype) + for i in range(degree): + x_powers = x_powers.at[:, i + 1].set(x_powers[:, i] * xvals) + + y_weighted = fvals * weights + x_powers_weighted = x_powers.T * weights + + dx_powers = x_powers[:-1, :] - x_powers[1:, :] + + # OSQP algorithm for solving min_x x'*Q*x + c'*x such that A*x=b, G*x<= h + P = x_powers_weighted @ x_powers_weighted.T + Q = .5 * (P.T + P + 1e-5 * np.eye(P.shape[0], dtype=xvals.dtype)) # make sure Q is symmetric + c = -x_powers_weighted @ y_weighted + G = np.concatenate((dx_powers, -np.eye(degree + 1)), axis=0) + h = np.zeros(nx + degree, dtype=xvals.dtype) + + if eq_last_point: + A = x_powers[-1, :][None, :] + b = fvals[-1:] + return OSQP().run(params_obj=(Q, c), params_eq=(A, b), + params_ineq=(G, h)).params.primal + else: + return OSQP().run(params_obj=(Q, c), params_ineq=(G, h)).params.primal + + +def kappa0_coeffsF(degree: int, num_layers: int): + + # A lower bound of kappa0^{(num_layers)} reduces to alpha_ from -1 + init_alpha_ = -1. + alpha_ = lax.fori_loop( + 0, 4, lambda i, x_: (x_ + kappa1(x_, is_x_matrix=False)) / 2., + init_alpha_) + + # Points for polynomial fitting contain (1) equi-spaced ones from [alpha_,1] + # and (2) non-equi-spaced ones from [0,1]. For (2), cosine function is used + # where more points are around 1. + num_points = 20 * num_layers + 8 * degree + x_eq = np.linspace(alpha_, 1., num=201) + x_noneq = np.cos((2 * np.arange(num_points) + 1) * np.pi / (4 * num_points)) + xvals = np.sort(np.concatenate((x_eq, x_noneq))) + fvals = kappa0(xvals, is_x_matrix=False) + + # For kappa0, we set all weights to be one. + weights = np.ones(len(fvals), dtype=xvals.dtype) + + # Coefficients can be obtained by solving QP with OSQP jaxopt. + coeffs = poly_fitting_qp(xvals, fvals, weights, degree) + return np.where(coeffs < 1e-5, 0.0, coeffs) + + +def kappa1_coeffs(degree: int, num_layers: int): + + # A lower bound of kappa1^{(num_layers)} reduces to alpha_ from -1 + init_alpha_ = -1. + alpha_ = lax.fori_loop( + 0, 4, lambda i, x_: (2. * x_ + kappa1(x_, is_x_matrix=False)) / 3., + init_alpha_) + + # Points for polynomial fitting contain (1) equi-spaced ones from [alpha_,1] + # and (2) non-equi-spaced ones from [0,1]. For (2), cosine function is used + # where more points are around 1. + num_points = 15 * num_layers + 5 * degree + x_eq = np.linspace(alpha_, 1., num=201) + x_noneq = np.cos( + (2. * np.arange(num_points) + 1.) * np.pi / (4. * num_points)) + xvals = np.sort(np.concatenate((x_eq, x_noneq))) + fvals = kappa1(xvals, is_x_matrix=False) + + # For kappa1, we set all weights to be one. + weights = np.ones(len(fvals), dtype=xvals.dtype) + + # For kappa1, we consider an equality condition for the last point + # (close to 1) because the slope around 1 is much sharper. 
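(In matrix form, the program that `poly_fitting_qp` above assembles is the following, where Z is the Vandermonde matrix of the ascending `xvals`, W = diag(`weights`), f = `fvals`, and the notation is an editorial summary rather than anything in the source:)

```latex
\min_{\beta \,\ge\, 0}\ \tfrac12\,\bigl\| W\,(Z\beta - f) \bigr\|_2^2
\quad \text{s.t.} \quad (Z_{i,:} - Z_{i+1,:})\,\beta \le 0 \ \ \text{for all } i,
```

(with the optional equality Z_{n,:} beta = f_n when `eq_last_point=True`. Nonnegativity of beta matches the nonnegative Taylor coefficients of the arc-cosine kernels, and the row-difference constraints force the fitted polynomial to be nondecreasing on the grid.)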
+ coeffs = poly_fitting_qp(xvals, fvals, weights, degree, eq_last_point=True) + return np.where(coeffs < 1e-5, 0.0, coeffs) From 049336ddabafc755164beb8b30141128199aff85 Mon Sep 17 00:00:00 2001 From: insuhan Date: Fri, 25 Mar 2022 10:12:22 +0900 Subject: [PATCH 17/44] Fix typo --- experimental/poly_fitting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py index 17240581..76bb3447 100644 --- a/experimental/poly_fitting.py +++ b/experimental/poly_fitting.py @@ -69,7 +69,7 @@ def poly_fitting_qp(xvals: np.ndarray, return OSQP().run(params_obj=(Q, c), params_ineq=(G, h)).params.primal -def kappa0_coeffsF(degree: int, num_layers: int): +def kappa0_coeffs(degree: int, num_layers: int): # A lower bound of kappa0^{(num_layers)} reduces to alpha_ from -1 init_alpha_ = -1. From a14f8aa94da9f4a5039a25dc6267804b4537daa7 Mon Sep 17 00:00:00 2001 From: insuhan Date: Fri, 25 Mar 2022 10:18:09 +0900 Subject: [PATCH 18/44] Edit format of sketching.py --- experimental/sketching.py | 333 +++++++++++++++++++------------------- 1 file changed, 168 insertions(+), 165 deletions(-) diff --git a/experimental/sketching.py b/experimental/sketching.py index a2d7b969..296ef2c9 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -45,6 +45,7 @@ def sketch(self, x1, x2): out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft) return np.concatenate((out.real, out.imag), 1) + # Standard SRHT with real valued output @dataclasses.dataclass class SRHT: @@ -58,192 +59,194 @@ class SRHT: rand_signs: Optional[np.ndarray] = None rand_inds: Optional[np.ndarray] = None - replace = ... # type: Callable[..., 'TensorSRHT2'] + replace = ... # type: Callable[..., 'SRHT'] def init_sketches(self) -> 'SRHT': rng1, rng2 = random.split(self.rng, 2) rand_signs = random.choice(rng1, 2, shape=(self.input_dim,)) * 2 - 1 - rand_inds = random.choice(rng2, self.input_dim, - shape=(self.sketch_dim // 2,)) + rand_inds = random.choice(rng2, + self.input_dim, + shape=(self.sketch_dim // 2,)) shape = (self.input_dim, self.sketch_dim) - return self.replace(shape=shape, - rand_signs=rand_signs, - rand_inds=rand_inds) + return self.replace(shape=shape, rand_signs=rand_signs, rand_inds=rand_inds) def sketch(self, x): xfft = np.fft.fftn(x * self.rand_signs, axes=(-1,))[:, self.rand_inds] out = np.sqrt(1 / self.rand_inds.shape[-1]) * xfft return np.concatenate((out.real, out.imag), 1) + # TensorSRHT of degree 2 with complex valued output. This version allows different input vectors. @dataclasses.dataclass class CmplxTensorSRHT: - input_dim1: int - input_dim2: int - sketch_dim: int - - rng: np.ndarray - shape: Optional[np.ndarray] = None - - rand_signs1: Optional[np.ndarray] = None - rand_signs2: Optional[np.ndarray] = None - rand_inds1: Optional[np.ndarray] = None - rand_inds2: Optional[np.ndarray] = None - - replace = ... 
# type: Callable[..., 'TensorSRHT2'] - - def init_sketches(self) -> 'CmplxTensorSRHT': - rng1, rng2, rng3, rng4 = random.split(self.rng, 4) - rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1 - rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1 - rand_inds1 = random.choice(rng3,self.input_dim1, - shape=(self.sketch_dim // 2,)) - rand_inds2 = random.choice(rng4,self.input_dim2, - shape=(self.sketch_dim // 2,)) - shape = (self.input_dim1, self.input_dim2, self.sketch_dim) - return self.replace(shape=shape, - rand_signs1=rand_signs1, - rand_signs2=rand_signs2, - rand_inds1=rand_inds1, - rand_inds2=rand_inds2) - - def sketch(self, x1, x2): - x1fft = np.fft.fftn(x1 * self.rand_signs1, axes=(-1,))[:, self.rand_inds1] - x2fft = np.fft.fftn(x2 * self.rand_signs2, axes=(-1,))[:, self.rand_inds2] - out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft) - return out + input_dim1: int + input_dim2: int + sketch_dim: int + + rng: np.ndarray + shape: Optional[np.ndarray] = None + + rand_signs1: Optional[np.ndarray] = None + rand_signs2: Optional[np.ndarray] = None + rand_inds1: Optional[np.ndarray] = None + rand_inds2: Optional[np.ndarray] = None + + replace = ... # type: Callable[..., 'CmplxTensorSRHT'] + + def init_sketches(self) -> 'CmplxTensorSRHT': + rng1, rng2, rng3, rng4 = random.split(self.rng, 4) + rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1 + rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1 + rand_inds1 = random.choice(rng3, + self.input_dim1, + shape=(self.sketch_dim // 2,)) + rand_inds2 = random.choice(rng4, + self.input_dim2, + shape=(self.sketch_dim // 2,)) + shape = (self.input_dim1, self.input_dim2, self.sketch_dim) + return self.replace(shape=shape, + rand_signs1=rand_signs1, + rand_signs2=rand_signs2, + rand_inds1=rand_inds1, + rand_inds2=rand_inds2) + + def sketch(self, x1, x2): + x1fft = np.fft.fftn(x1 * self.rand_signs1, axes=(-1,))[:, self.rand_inds1] + x2fft = np.fft.fftn(x2 * self.rand_signs2, axes=(-1,))[:, self.rand_inds2] + out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft) + return out # Standard SRHT as a function def standardsrht(x, rand_inds, rand_signs): - xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds] - return np.sqrt(1 / rand_inds.shape[0]) * xfft + xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds] + return np.sqrt(1 / rand_inds.shape[0]) * xfft + @dataclasses.dataclass class PolyTensorSketch: - - rng: np.ndarray - - input_dim: int - sketch_dim: int - degree: int - - tree_rand_signs: Optional[list] = None - tree_rand_inds: Optional[list] = None - rand_signs: Optional[np.ndarray] = None - rand_inds: Optional[np.ndarray] = None - - replace = ... 
- - def init_sketch(self) -> 'PolyTensorSketch': - - tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())] - tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())] - rng1, rng3 = random.split(self.rng, 2) - - ske_dim_ = self.sketch_dim // 4 -1 - deg_ = self.degree // 2 - - for i in range((self.degree - 1).bit_length()): - rng1, rng2 = random.split(rng1) - - if i == 0: - tree_rand_signs[i] = random.choice( - rng1, 2, shape=(deg_, 2, self.input_dim)) * 2 - 1 - tree_rand_inds[i] = random.choice(rng2, - self.input_dim, - shape=(deg_, 2, ske_dim_)) - else: - tree_rand_signs[i] = random.choice( - rng1, 2, shape=(deg_, 2, ske_dim_)) * 2 - 1 - tree_rand_inds[i] = random.choice(rng2, - ske_dim_,shape=(deg_, 2, ske_dim_)) - deg_ = deg_ // 2 - - rng1, rng2 = random.split(rng3,2) - rand_signs = random.choice(rng1, 2, shape=(1+self.degree * ske_dim_,)) * 2 - 1 - rand_inds = random.choice(rng2, - 1+self.degree * ske_dim_,shape=(self.sketch_dim // 2,)) - - return self.replace(tree_rand_signs=tree_rand_signs, - tree_rand_inds=tree_rand_inds, - rand_signs=rand_signs, - rand_inds=rand_inds) - - - # TensorSRHT of degree 2 - def tensorsrht(self, x1, x2, rand_inds, rand_signs): - x1fft = np.fft.fftn(x1 * rand_signs[0, :], axes=(-1,))[:, rand_inds[0, :]] - x2fft = np.fft.fftn(x2 * rand_signs[1, :], axes=(-1,))[:, rand_inds[1, :]] - return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft) - - # Standard SRHT - def standardsrht(self, x, rand_inds, rand_signs): - xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds] - return np.sqrt(1 / rand_inds.shape[0]) * xfft - - def sketch(self, x): - n = x.shape[0] - log_degree = len(self.tree_rand_signs) - V = [0 for i in range(log_degree)] - #E1 = np.concatenate((np.ones((n, 1), dtype=x.dtype), np.zeros((n, x.shape[-1] - 1), dtype=x.dtype)),1) - - for i in range(log_degree): - deg = self.tree_rand_signs[i].shape[0] - V[i] = np.zeros((deg, n, self.tree_rand_inds[i].shape[2]), - dtype=np.complex64) - for j in range(deg): - if i == 0: - V[i] = V[i].at[j, :, :].set( - self.tensorsrht(x, x, self.tree_rand_inds[i][j, :, :], - self.tree_rand_signs[i][j, :, :])) - - else: - V[i] = V[i].at[j, :, :].set( - self.tensorsrht(V[i - 1][2 * j, :, :], V[i - 1][2 * j + 1, :, :], - self.tree_rand_inds[i][j, :, :], - self.tree_rand_signs[i][j, :, :])) - - U = [0 for i in range(2**log_degree)] - U[0] = V[log_degree - 1][0, :, :] - - SetE1 = set() - - for j in range(1, len(U)): - p = (j - 1) // 2 - for i in range(log_degree): - if j % (2**(i + 1)) == 0: - SetE1.add((i,p)) - #V[i] = V[i].at[p, :, :].set(np.concatenate((np.ones((n, 1)), np.zeros((n, V[i].shape[-1] - 1))), 1)) + + rng: np.ndarray + + input_dim: int + sketch_dim: int + degree: int + + tree_rand_signs: Optional[list] = None + tree_rand_inds: Optional[list] = None + rand_signs: Optional[np.ndarray] = None + rand_inds: Optional[np.ndarray] = None + + replace = ... 
# type: Callable[..., 'PolyTensorSketch'] + + def init_sketch(self) -> 'PolyTensorSketch': + + tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())] + tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())] + rng1, rng3 = random.split(self.rng, 2) + + ske_dim_ = self.sketch_dim // 4 - 1 + deg_ = self.degree // 2 + + for i in range((self.degree - 1).bit_length()): + rng1, rng2 = random.split(rng1) + + if i == 0: + tree_rand_signs[i] = random.choice( + rng1, 2, shape=(deg_, 2, self.input_dim)) * 2 - 1 + tree_rand_inds[i] = random.choice(rng2, + self.input_dim, + shape=(deg_, 2, ske_dim_)) + else: + tree_rand_signs[i] = random.choice(rng1, 2, + shape=(deg_, 2, ske_dim_)) * 2 - 1 + tree_rand_inds[i] = random.choice(rng2, + ske_dim_, + shape=(deg_, 2, ske_dim_)) + deg_ = deg_ // 2 + + rng1, rng2 = random.split(rng3, 2) + rand_signs = random.choice(rng1, 2, + shape=(1 + self.degree * ske_dim_,)) * 2 - 1 + rand_inds = random.choice(rng2, + 1 + self.degree * ske_dim_, + shape=(self.sketch_dim // 2,)) + + return self.replace(tree_rand_signs=tree_rand_signs, + tree_rand_inds=tree_rand_inds, + rand_signs=rand_signs, + rand_inds=rand_inds) + + # TensorSRHT of degree 2 + def tensorsrht(self, x1, x2, rand_inds, rand_signs): + x1fft = np.fft.fftn(x1 * rand_signs[0, :], axes=(-1,))[:, rand_inds[0, :]] + x2fft = np.fft.fftn(x2 * rand_signs[1, :], axes=(-1,))[:, rand_inds[1, :]] + return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft) + + # Standard SRHT + def standardsrht(self, x, rand_inds, rand_signs): + xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds] + return np.sqrt(1 / rand_inds.shape[0]) * xfft + + def sketch(self, x): + n = x.shape[0] + log_degree = len(self.tree_rand_signs) + V = [0 for i in range(log_degree)] + + for i in range(log_degree): + deg = self.tree_rand_signs[i].shape[0] + V[i] = np.zeros((deg, n, self.tree_rand_inds[i].shape[2]), + dtype=np.complex64) + for j in range(deg): + if i == 0: + V[i] = V[i].at[j, :, :].set( + self.tensorsrht(x, x, self.tree_rand_inds[i][j, :, :], + self.tree_rand_signs[i][j, :, :])) + + else: + V[i] = V[i].at[j, :, :].set( + self.tensorsrht(V[i - 1][2 * j, :, :], V[i - 1][2 * j + 1, :, :], + self.tree_rand_inds[i][j, :, :], + self.tree_rand_signs[i][j, :, :])) + + U = [0 for i in range(2**log_degree)] + U[0] = V[log_degree - 1][0, :, :] + + SetE1 = set() + + for j in range(1, len(U)): + p = (j - 1) // 2 + for i in range(log_degree): + if j % (2**(i + 1)) == 0: + SetE1.add((i, p)) + else: + if i == 0: + V[i] = V[i].at[p, :, :].set( + self.standardsrht(x, self.tree_rand_inds[i][p, 0, :], + self.tree_rand_signs[i][p, 0, :])) + else: + if (i - 1, 2 * p) in SetE1: + V[i] = V[i].at[p, :, :].set(V[i - 1][2 * p + 1, :, :]) else: - if i == 0: - V[i] = V[i].at[p, :, :].set( - self.standardsrht(x, self.tree_rand_inds[i][p, 0, :], - self.tree_rand_signs[i][p, 0, :])) - else: - if (i-1,2*p) in SetE1: - V[i] = V[i].at[p, :, :].set(V[i-1][2*p+1,:,:]) - else: - V[i] = V[i].at[p, :, :].set( - self.tensorsrht(V[i - 1][2 * p, :, :], V[i - 1][2 * p + 1, :, :], - self.tree_rand_inds[i][p, :, :], - self.tree_rand_signs[i][p, :, :])) - p = p // 2 - U[j] = V[log_degree - 1][0, :, :] - - return U - - def expand_feats(self, polysketch_feats, coeffs): - n, sktch_dim = polysketch_feats[0].shape - Z = np.zeros((len(self.rand_signs),n), dtype=np.complex64) - Z = Z.at[0,:].set(np.sqrt(coeffs[0]) * np.ones(n)) - degree = len(polysketch_feats) - for i in range(degree): - Z = Z.at[sktch_dim*i+1:sktch_dim*(i+1)+1,:].set(np.sqrt( coeffs[i+1] ) * - 
                                             polysketch_feats[degree-i-1].T)
-
-    return Z.T
-
-
+            V[i] = V[i].at[p, :, :].set(
+                self.tensorsrht(V[i - 1][2 * p, :, :],
+                                V[i - 1][2 * p + 1, :, :],
+                                self.tree_rand_inds[i][p, :, :],
+                                self.tree_rand_signs[i][p, :, :]))
+        p = p // 2
+      U[j] = V[log_degree - 1][0, :, :]
+
+    return U
+
+  def expand_feats(self, polysketch_feats, coeffs):
+    n, sktch_dim = polysketch_feats[0].shape
+    Z = np.zeros((len(self.rand_signs), n), dtype=np.complex64)
+    Z = Z.at[0, :].set(np.sqrt(coeffs[0]) * np.ones(n))
+    degree = len(polysketch_feats)
+    for i in range(degree):
+      Z = Z.at[sktch_dim * i + 1:sktch_dim * (i + 1) + 1, :].set(
+          np.sqrt(coeffs[i + 1]) * polysketch_feats[degree - i - 1].T)
+
+    return Z.T
 # pytype: enable=attribute-error
\ No newline at end of file
From 5f5ac186a05256631ed3a297d3915cf25724935f Mon Sep 17 00:00:00 2001
From: insuhan
Date: Sat, 26 Mar 2022 05:10:35 +0900
Subject: [PATCH 19/44] Fix typo in alpha_ computation

---
 experimental/poly_fitting.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py
index 76bb3447..9d0e6563 100644
--- a/experimental/poly_fitting.py
+++ b/experimental/poly_fitting.py
@@ -1,5 +1,5 @@
 from jax import numpy as np
-from jax import lax
+from jax.lax import fori_loop
 from jaxopt import OSQP
 
@@ -55,7 +55,8 @@ def poly_fitting_qp(xvals: np.ndarray,
 
   # OSQP algorithm for solving min_x x'*Q*x + c'*x such that A*x=b, G*x<= h
   P = x_powers_weighted @ x_powers_weighted.T
-  Q = .5 * (P.T + P + 1e-5 * np.eye(P.shape[0], dtype=xvals.dtype)) # make sure Q is symmetric
+  Q = .5 * (P.T + P + 1e-5 * np.eye(P.shape[0], dtype=xvals.dtype)
+           )  # make sure Q is symmetric
   c = -x_powers_weighted @ y_weighted
   G = np.concatenate((dx_powers, -np.eye(degree + 1)), axis=0)
   h = np.zeros(nx + degree, dtype=xvals.dtype)
@@ -72,10 +73,8 @@ def poly_fitting_qp(xvals: np.ndarray,
 def kappa0_coeffs(degree: int, num_layers: int):
 
   # A lower bound of kappa0^{(num_layers)} reduces to alpha_ from -1
-  init_alpha_ = -1.
-  alpha_ = lax.fori_loop(
-      0, 4, lambda i, x_: (x_ + kappa1(x_, is_x_matrix=False)) / 2.,
-      init_alpha_)
+  alpha_ = fori_loop(0, num_layers, lambda i, x_:
+                     (x_ + kappa1(x_, is_x_matrix=False)) / 2., -1.)
   # Points for polynomial fitting contain (1) equi-spaced ones from [alpha_,1]
   # and (2) non-equi-spaced ones from [0,1]. For (2), cosine function is used
   # where more points are around 1.
@@ -97,10 +96,9 @@ def kappa0_coeffs(degree: int, num_layers: int):
 def kappa1_coeffs(degree: int, num_layers: int):
 
   # A lower bound of kappa1^{(num_layers)} reduces to alpha_ from -1
-  init_alpha_ = -1.
-  alpha_ = lax.fori_loop(
-      0, 4, lambda i, x_: (2. * x_ + kappa1(x_, is_x_matrix=False)) / 3.,
-      init_alpha_)
+  alpha_ = fori_loop(
+      0, num_layers, lambda i, x_:
+      (2. * x_ + kappa1(x_, is_x_matrix=False)) / 3., -1.)
 
   # Points for polynomial fitting contain (1) equi-spaced ones from [alpha_,1]
   # and (2) non-equi-spaced ones from [0,1]. 
For (2), cosine function is used From 7e0f580502b3149c79e917843ebf341ae7817ddc Mon Sep 17 00:00:00 2001 From: insuhan Date: Sat, 26 Mar 2022 05:38:12 +0900 Subject: [PATCH 20/44] Delete unnecessaries --- experimental/ntk_sketch.py | 84 -------------------------------------- 1 file changed, 84 deletions(-) delete mode 100644 experimental/ntk_sketch.py diff --git a/experimental/ntk_sketch.py b/experimental/ntk_sketch.py deleted file mode 100644 index cede91d6..00000000 --- a/experimental/ntk_sketch.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Wed Mar 16 15:41:18 2022 - -@author: amir -""" - -import numpy as onp -import quadprog -from matplotlib import pyplot as plt -from jax import numpy as np -from jax.numpy import linalg as LA -from sketching import standardsrht - -def quadprog_solve_qp(P, q, G=None, h=None, A=None, b=None): - qp_G = .5 * (P + P.T +1e-5*onp.eye(P.shape[0])) # make sure P is symmetric - qp_a = -q - if A is not None: - qp_C = -onp.vstack([A, G]).T - qp_b = -onp.hstack([b, h]) - meq = A.shape[0] - else: # no equality constraint - qp_C = -G.T - qp_b = -h - meq = 0 - return quadprog.solve_qp(qp_G, qp_a, qp_C, qp_b, meq)[0] - -def ntk_poly_coeffs(L,degree): - n=15*L+5*degree - Y = onp.zeros((201+n,L+1)) - Y[:,0] = onp.sort(onp.concatenate((onp.linspace(-1.0, 1.0, num=201), onp.cos((2*onp.arange(n)+1)*onp.pi / (4*n))), axis=0)) - - grid_len = Y.shape[0] - - for i in range(L): - Y[:,i+1] = (onp.sqrt(1-Y[:,i]**2) + Y[:,i]*(onp.pi - onp.arccos(Y[:,i])))/onp.pi - - y = onp.zeros(grid_len) - for i in range(L+1): - z = Y[:,i] - for j in range(i,L): - z = z*(onp.pi - onp.arccos(Y[:,j]))/onp.pi - y = y + z - - Z = onp.zeros((grid_len,degree+1)) - Z[:,0] = onp.ones(grid_len) - for i in range(degree): - Z[:,i+1] = Z[:,i] * Y[:,0] - - - weight_ = onp.linspace(0.0, 1.0, num=grid_len) + 2/L - w = y * weight_ - U = Z.T * weight_ - - coeffs = quadprog_solve_qp(onp.dot(U, U.T), -onp.dot(U,w) , onp.concatenate((Z[0:grid_len-1,:]-Z[1:grid_len,:], -onp.eye(degree+1)),axis=0), onp.zeros(degree+grid_len)) - coeffs[coeffs < 1e-5] = 0 - - return coeffs - -def poly_ntk_sketch(depth, polysketch, X): - degree = polysketch.degree - n = X.shape[0] - - ntk_coeff = ntk_poly_coeffs(depth, degree) - - norm_x = LA.norm(X, axis=1) - normalizer = np.where(norm_x>0, norm_x, 1.0) - x_normlzd = ((X.T / normalizer).T) - - polysketch_feats = polysketch.sketch(x_normlzd) - - sktch_dim = polysketch_feats[0].shape[1] - - Z = np.zeros((len(polysketch.rand_signs),n), dtype=np.complex64) - for i in range(degree): - Z = Z.at[sktch_dim*i:sktch_dim*(i+1),:].set(np.sqrt( ntk_coeff[i+1] ) * - polysketch_feats[degree-i-1].T) - - Z = standardsrht(Z.T, polysketch.rand_inds, polysketch.rand_signs) - Z = (Z.T * normalizer).T - - return np.concatenate(( np.sqrt(ntk_coeff[0]) * normalizer.reshape((n,1)), np.concatenate((Z.real, Z.imag), 1)), 1) - From 391a1b82600be5e3a8a5d28a9d4e0b60c6170b4f Mon Sep 17 00:00:00 2001 From: insuhan Date: Sun, 27 Mar 2022 12:06:16 +0900 Subject: [PATCH 21/44] Update FC NTK features and check pytype --- experimental/features.py | 448 ++++++++++++++++++++---------------- experimental/sketching.py | 99 +------- experimental/test_fc_ntk.py | 209 ++++++++++++++--- 3 files changed, 434 insertions(+), 322 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index eb98c0c0..d79282ac 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -3,15 +3,14 @@ from jax import numpy as np from jax.numpy.linalg import cholesky 
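(The `cholesky` imported just above backs the `'poly'` and `'exact'` debug paths further down in this file: any PSD kernel matrix K factors as K = L L^T, so the rows of L form an exact, data-dependent feature map. A minimal standalone illustration in plain NumPy, not taken from the repo:)

```python
# Minimal sketch: rows of a Cholesky factor reproduce the kernel exactly.
import numpy as onp

rng = onp.random.default_rng(0)
X = rng.normal(size=(6, 3))
K = (1. + X @ X.T)**2                                # any PSD kernel matrix
L = onp.linalg.cholesky(K + 1e-9 * onp.eye(len(K)))  # small jitter for stability
print(onp.abs(K - L @ L.T).max())                    # ~1e-9: <L[i], L[j]> == K[i, j]
```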
import jax.example_libraries.stax as ostax -from jax.numpy import linalg as LA from neural_tangents import stax from neural_tangents._src.utils import dataclasses from neural_tangents._src.stax.linear import _pool_kernel, Padding from neural_tangents._src.stax.linear import _Pooling as Pooling -from sketching import TensorSRHT2, PolyTensorSketch, CmplxTensorSRHT -from poly_fitting import kappa0_coeffs, kappa1_coeffs +from experimental.sketching import TensorSRHT, PolyTensorSketch +from experimental.poly_fitting import kappa0_coeffs, kappa1_coeffs, kappa0, kappa1 """ Implementation for NTK Sketching and Random Features """ @@ -22,13 +21,8 @@ def _prod(tuple_): return prod -def poly_expansion(x, coeffs): - y = np.ones_like(x) - results = np.zeros_like(x) - for c in coeffs: - results += c*y - y = y*x - return results +def _poly_expansion(x, coeffs): + return np.polyval(coeffs[::-1], x) @dataclasses.dataclass @@ -51,9 +45,9 @@ def _inputs_to_features(x: np.ndarray, # Followed the same initialization of Neural Tangents library. nngp_feat = x / x.shape[channel_axis]**0.5 - norms = LA.norm(nngp_feat, axis=channel_axis) - norms = np.where(norms>0, norms, 1.0) - nngp_feat = (nngp_feat.T / norms).T + norms = np.linalg.norm(nngp_feat, axis=channel_axis) + norms = np.expand_dims(np.where(norms > 0, norms, 1.0), channel_axis) + nngp_feat = nngp_feat / norms ntk_feat = np.array([0.0], dtype=nngp_feat.dtype) @@ -84,9 +78,7 @@ def feature_fn_any(x_or_feature, input=None, **kwargs): def _is_sinlge_shape(input_shape): - if all(isinstance(n, int) for n in input_shape): - return True - return False + return all(isinstance(n, int) for n in input_shape) def _is_defaut_feature(feat): @@ -105,6 +97,12 @@ def init_fn_any(rng, input_shape_any, **kwargs): return init_fn_any +def _renormalize_feature(f: Features, **kwargs): + nngp_feat = f.nngp_feat * f.norms + ntk_feat = f.ntk_feat * f.norms + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + def layer(layer_fn): def new_layer_fns(*args, **kwargs): @@ -116,16 +114,28 @@ def new_layer_fns(*args, **kwargs): return new_layer_fns +def _check_modules_contain_dense_relu(module_names: tuple) -> bool: + + def _check_string_tuple_has_one_entry(str_tuple, entry): + return len(set(str_tuple)) == 1 and str_tuple[0] == entry + + return len(module_names) % 2 == 1 and _check_string_tuple_has_one_entry( + module_names[::2], 'DenseFeatures') and _check_string_tuple_has_one_entry( + module_names[1::2], 'ReluFeatures') + + # Modified the serial process of feature map blocks. 
# Followed https://github.com/google/neural-tangents/blob/main/neural_tangents/stax.py @layer def serial(*layers): + init_fns, feature_fns = zip(*layers) init_fn, _ = ostax.serial(*zip(init_fns, init_fns)) def feature_fn(k, inputs, **kwargs): for f, input_ in zip(feature_fns, inputs): k = f(k, input_, **kwargs) + k = _renormalize_feature(k) return k return init_fn, feature_fn @@ -146,210 +156,244 @@ def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] new_ntk_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_shape[-1] + ntk_feat_shape[-1],) - + if len(input_shape) > 2: - return (nngp_feat_shape, new_ntk_feat_shape, input_shape[2]+'D'), () + return (nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'D'), () else: - return (nngp_feat_shape, new_ntk_feat_shape, 'D'), () + return (nngp_feat_shape, new_ntk_feat_shape, 'D'), () def feature_fn(f: Features, input, **kwargs): - nngp_feat, ntk_feat = f.nngp_feat, f.ntk_feat - nngp_feat *= W_std - ntk_feat *= W_std + nngp_feat, ntk_feat, norms = f.nngp_feat, f.ntk_feat, f.norms + norms *= W_std if _is_defaut_feature(ntk_feat): # check if ntk_feat is empty ntk_feat = nngp_feat else: ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) return init_fn, feature_fn @layer -def ReluFeatures( - feature_dim0: int = 1, - feature_dim1: int = 1, - sketch_dim: int = 1, - poly_degree: int = 4, - poly_sketch_dim: int = 1, - method: str = 'rf', - top_layer: bool = False -): - - method = method.lower() - assert method in ['rf', 'ps', 'exact', 'psrf'] - - def init_fn(rng, input_shape): - nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) - net_shape = input_shape[2] - layer_count = len(net_shape)//2+1 - - if method == 'rf': - rng1, rng2, rng3 = random.split(rng, 3) - # Random vectors for random features of arc-cosine kernel of order 0. - W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) - # Random vectors for random features of arc-cosine kernel of order 1. - W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) - # TensorSRHT of degree 2 for approximating tensor product. - ts2 = TensorSRHT2(rng=rng3, +def ReluFeatures(feature_dim0: int = 1, + feature_dim1: int = 1, + sketch_dim: int = 1, + poly_degree: int = 8, + poly_sketch_dim: int = 1, + method: str = 'rf', + top_layer: bool = False): + + method = method.lower() + assert method in ['rf', 'ps', 'exact', 'psrf', 'poly'] + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) + net_shape = input_shape[2] + layer_count = len(net_shape) // 2 + 1 + + if method == 'rf': + rng1, rng2, rng3 = random.split(rng, 3) + # Random vectors for random features of arc-cosine kernel of order 0. + W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) + # Random vectors for random features of arc-cosine kernel of order 1. + W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) + # TensorSRHT of degree 2 for approximating tensor product. 
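(What "approximating tensor product" means in the comment below: the sketch S(x1, x2) preserves inner products of tensor products, <S(x1, x2), S(y1, y2)> ~= <x1, y1> <x2, y2>, which keeps the NTK update, an elementwise product of the NTK with kappa0, representable in feature space without forming the d1*d2-dimensional product. A simplified single-complex-output variant of the scheme, hedged as an illustration; the class in `sketching.py` additionally splits real and imaginary parts:)

```python
# Simplified TensorSRHT-style sketch (illustrative only, plain NumPy).
import numpy as onp

d, m = 256, 4096
rng = onp.random.default_rng(1)
s1, s2 = rng.choice([-1., 1.], size=d), rng.choice([-1., 1.], size=d)
i1, i2 = rng.integers(0, d, size=m), rng.integers(0, d, size=m)

def sketch(u, v):
  # sign flip -> FFT -> subsample m coordinates -> entrywise product
  return onp.fft.fft(u * s1)[i1] * onp.fft.fft(v * s2)[i2] / onp.sqrt(m)

x1, x2 = rng.normal(size=d), rng.normal(size=d)
est = onp.vdot(sketch(x1, x2), sketch(x1, x2)).real
print(est, (x1 @ x1) * (x2 @ x2))   # agree up to O(1/sqrt(m)) relative error
```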
+ tensorsrht = TensorSRHT(rng=rng3, input_dim1=ntk_feat_shape[-1], input_dim2=feature_dim0, sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args - return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (W0, W1, ts2) - - elif method == 'psrf': - new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) - rng1, rng2, rng3 = random.split(rng, 3) - - kappa1_coeff = kappa1_coeffs(poly_degree,layer_count-1) - - # PolySketch expansion for nngp features. - polysketch = PolyTensorSketch(rng1, nngp_feat_shape[-1]//(1+(layer_count>1)), - poly_sketch_dim, poly_degree).init_sketch() - # TensorSRHT of degree 2 for approximating tensor product. - tensorsrht = CmplxTensorSRHT(input_dim1=ntk_feat_shape[-1]//(1+(layer_count>1)), - input_dim2=feature_dim0, - sketch_dim=sketch_dim , rng=rng2).init_sketches() - - # Random vectors for random features of arc-cosine kernel of order 0. - # W0 = random.choice(rng3, 2, shape=(nngp_feat_shape[-1], feature_dim0//2)) * 2 - 1 - if layer_count ==1: - W0 = random.normal(rng3, (2*nngp_feat_shape[-1], feature_dim0//2)) - else: - W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0//2),dtype='float32') - - return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (W0, polysketch, tensorsrht, (kappa1_coeff, layer_count)) - - elif method == 'ps': - new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) - rng1, rng2, rng3 = random.split(rng, 3) - - kappa1_coeff = kappa1_coeffs(poly_degree,layer_count-1) - kappa0_coeff = kappa0_coeffs(poly_degree,layer_count-1) - - # PolySketch expansion for nngp features. - polysketch = PolyTensorSketch(rng1, nngp_feat_shape[-1]//(1+(layer_count>1)), - poly_sketch_dim, poly_degree).init_sketch() - # TensorSRHT of degree 2 for approximating tensor product. - tensorsrht = CmplxTensorSRHT(input_dim1=ntk_feat_shape[-1]//(1+(layer_count>1)), - input_dim2=poly_degree*(polysketch.sketch_dim//4-1)+1, - sketch_dim=sketch_dim , rng=rng2).init_sketches() - - return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (polysketch, tensorsrht,(kappa0_coeff,kappa1_coeff, layer_count)) - raise NotImplementedError - - elif method == 'exact': - # The exact feature map computation is for debug. - new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( - nngp_feat_shape[:-1]),) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) - - kappa1_coeff = kappa1_coeffs(poly_degree,layer_count-1) - kappa0_coeff = kappa0_coeffs(poly_degree,layer_count-1) - - return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape+'R'), (kappa0_coeff, kappa1_coeff, layer_count) - - def feature_fn(f: Features, input=None, **kwargs) -> Features: - - input_shape = f.nngp_feat.shape[:-1] - nngp_feat_dim = f.nngp_feat.shape[-1] - ntk_feat_dim = f.ntk_feat.shape[-1] - - nngp_feat_2d = f.nngp_feat.reshape(-1, nngp_feat_dim) - ntk_feat_2d = f.ntk_feat.reshape(-1, ntk_feat_dim) - - if method == 'rf': # Random Features approach. - W0: np.ndarray = input[0] - W1: np.ndarray = input[1] - ts2: TensorSRHT2 = input[2] - - kappa0_feat = (nngp_feat_2d @ W0 > 0) / np.sqrt(W0.shape[-1]) - del W0 - nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / - np.sqrt(W1.shape[-1])).reshape(input_shape + (-1,)) - del W1 - ntk_feat = ts2.sketch(ntk_feat_2d, - kappa0_feat).reshape(input_shape + (-1,)) - - elif method == 'psrf': # Combination of Poly Sketch and Random Features. 
- W0: np.ndarray = input[0] - polysketch: PolyTensorSketch = input[1] - tensorsrht: CmplxTensorSRHT = input[2] - kappa1_coeff: np.ndarray = input[3][0] - layer_count = input[3][1] - - polysketch_feats = polysketch.sketch(nngp_feat_2d) - kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) - del polysketch_feats - nngp_feat = polysketch.standardsrht(kappa1_feat, polysketch.rand_inds, - polysketch.rand_signs).reshape(input_shape + (-1,)) - # nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) - - nngp_proj = np.dot(np.concatenate((nngp_feat_2d.real, nngp_feat_2d.imag), axis=1) , W0) - - kappa0_feat = np.concatenate(((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / np.sqrt(W0.shape[-1]) - del W0 - ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) - if top_layer: - ntk_feat = (1 / 2**(layer_count/2))*np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1) - nngp_feat = (1 / 2**(layer_count/2))*np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) - - - elif method == 'ps': - polysketch: PolyTensorSketch = input[0] - tensorsrht: CmplxTensorSRHT = input[1] - kappa0_coeff: np.ndarray = input[2][0] - kappa1_coeff: np.ndarray = input[2][1] - layer_count = input[2][2] - - polysketch_feats = polysketch.sketch(nngp_feat_2d) - kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) - nngp_feat = polysketch.standardsrht(kappa1_feat, polysketch.rand_inds, - polysketch.rand_signs).reshape(input_shape + (-1,)) - - kappa0_feat = polysketch.expand_feats(polysketch_feats, kappa0_coeff) - del polysketch_feats - ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) - - if top_layer: - ntk_feat = (1 / 2**(layer_count/2))*np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1) - nngp_feat = (1 / 2**(layer_count/2))*np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) - - - elif method == 'exact': - - kappa0_coeff: np.ndarray = input[0] - kappa1_coeff: np.ndarray = input[1] - layer_count = input[2] - - gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) - nngp_feat = cholesky(poly_expansion(gram_nngp, kappa1_coeff)).reshape(input_shape + (-1,)) - - - ntk = ntk_feat_2d @ ntk_feat_2d.T - kappa0_mat = poly_expansion(gram_nngp, kappa0_coeff) - # kappa0_mat = kappa0(nngp_feat_2d) - ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) - - if top_layer: - ntk_feat = (1 / 2**(layer_count/2))*ntk_feat - nngp_feat = (1 / 2**(layer_count/2))*nngp_feat - - else: - raise NotImplementedError - - if top_layer: - ntk_feat = (ntk_feat.T * f.norms).T - nngp_feat = (nngp_feat.T * f.norms).T - - - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - - return init_fn, feature_fn + + return (new_nngp_feat_shape, new_ntk_feat_shape, + net_shape + 'R'), (W0, W1, tensorsrht) + + elif method == 'ps': + new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) + rng1, rng2, rng3 = random.split(rng, 3) + + kappa1_coeff = kappa1_coeffs(poly_degree, layer_count - 1) + kappa0_coeff = kappa0_coeffs(poly_degree, layer_count - 1) + + # PolySketch expansion for nngp features. + polysketch = PolyTensorSketch(rng=rng1, + input_dim=nngp_feat_shape[-1] // + (1 + (layer_count > 1)), + sketch_dim=poly_sketch_dim, + degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args + + # TensorSRHT of degree 2 for approximating tensor product. 
+ tensorsrht = TensorSRHT( + input_dim1=ntk_feat_shape[-1] // (1 + (layer_count > 1)), + input_dim2=poly_degree * (polysketch.sketch_dim // 4 - 1) + 1, + sketch_dim=sketch_dim, + rng=rng2).init_sketches() # pytype:disable=wrong-keyword-args + + return (new_nngp_feat_shape, new_ntk_feat_shape, + net_shape + 'R'), (polysketch, tensorsrht, + (kappa0_coeff, kappa1_coeff, layer_count)) + + elif method == 'psrf': + new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) + rng1, rng2, rng3 = random.split(rng, 3) + + kappa1_coeff = kappa1_coeffs(poly_degree, layer_count - 1) + + # PolySketch expansion for nngp features. + polysketch = PolyTensorSketch(rng=rng1, + input_dim=nngp_feat_shape[-1] // + (1 + (layer_count > 1)), + sketch_dim=poly_sketch_dim, + degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args + + # TensorSRHT of degree 2 for approximating tensor product. + tensorsrht = TensorSRHT(rng=rng2, + input_dim1=ntk_feat_shape[-1] // + (1 + (layer_count > 1)), + input_dim2=feature_dim0, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + + # Random vectors for random features of arc-cosine kernel of order 0. + if layer_count == 1: + W0 = random.normal(rng3, (2 * nngp_feat_shape[-1], feature_dim0 // 2)) + else: + W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0 // 2)) + + return (new_nngp_feat_shape, new_ntk_feat_shape, + net_shape + 'R'), (W0, polysketch, tensorsrht, (kappa1_coeff, + layer_count)) + + elif method == 'poly': + # This only uses the polynomial approximation without sketching. + new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( + nngp_feat_shape[:-1]),) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) + + kappa1_coeff = kappa1_coeffs(poly_degree, layer_count - 1) + kappa0_coeff = kappa0_coeffs(poly_degree, layer_count - 1) + + return (new_nngp_feat_shape, new_ntk_feat_shape, + net_shape + 'R'), (kappa0_coeff, kappa1_coeff, layer_count) + + elif method == 'exact': + # The exact feature map computation is for debug. + new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( + nngp_feat_shape[:-1]),) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) + + return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape + 'R'), () + + else: + raise NotImplementedError(f'Invalid method name: {method}') + + def feature_fn(f: Features, input=None, **kwargs) -> Features: + + input_shape = f.nngp_feat.shape[:-1] + nngp_feat_dim = f.nngp_feat.shape[-1] + ntk_feat_dim = f.ntk_feat.shape[-1] + + nngp_feat_2d = f.nngp_feat.reshape(-1, nngp_feat_dim) + ntk_feat_2d = f.ntk_feat.reshape(-1, ntk_feat_dim) + norms = f.norms + + if method == 'rf': # Random Features approach. + W0: np.ndarray = input[0] + W1: np.ndarray = input[1] + tensorsrht: TensorSRHT = input[2] + + kappa0_feat = (nngp_feat_2d @ W0 > 0) / np.sqrt(W0.shape[-1]) + del W0 + nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / + np.sqrt(W1.shape[-1])).reshape(input_shape + (-1,)) + del W1 + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat, + real_output=True).reshape(input_shape + + (-1,)) + + elif method == 'ps': + polysketch: PolyTensorSketch = input[0] + tensorsrht: TensorSRHT = input[1] + kappa0_coeff: np.ndarray = input[2][0] + kappa1_coeff: np.ndarray = input[2][1] + layer_count: int = input[2][2] + + # Apply PolySketch to approximate feature maps of kappa0 & kappa1 kernels. 
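(The step below rests on the identity behind PolySketch: once kappa1, and likewise kappa0, is replaced by a fitted polynomial sum_k c_k rho^k with c_k >= 0, the direct sum over k of sqrt(c_k) x^{(tensor k)} is an exact feature map of the fitted kernel, and `PolyTensorSketch` compresses each x^{(tensor k)} block; this is roughly what `expand_feats` assembles from `coeffs`. A tiny dense check of the identity with exact Kronecker blocks in place of sketches and hypothetical coefficients:)

```python
# Illustrative only: un-sketched feature map of sum_k c_k <x, y>^k.
import numpy as onp
from functools import reduce

def phi(x, coeffs):
  blocks = [onp.sqrt(coeffs[0]) * onp.ones(1)]            # degree-0 block
  for k in range(1, len(coeffs)):
    blocks.append(onp.sqrt(coeffs[k]) * reduce(onp.kron, [x] * k))
  return onp.concatenate(blocks)                          # dim = sum_k d**k

c = onp.array([0.25, 0.5, 0.25])   # hypothetical fitted coefficients, c_k >= 0
x = onp.array([0.6, 0.8])
y = onp.array([1.0, 0.0])
print(phi(x, c) @ phi(y, c), sum(ck * (x @ y)**k for k, ck in enumerate(c)))  # both 0.64
```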
+ polysketch_feats = polysketch.sketch(nngp_feat_2d) + kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) + kappa0_feat = polysketch.expand_feats(polysketch_feats, kappa0_coeff) + del polysketch_feats + + # Apply SRHT to kappa1_feat so that dimension of nngp_feat is poly_sketch_dim//2. + nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + + (-1,)) + # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. + ntk_feat = tensorsrht.sketch(ntk_feat_2d, + kappa0_feat).reshape(input_shape + (-1,)) + + # At the top ReluFeatures, convert complex features to real ones. + if top_layer: + ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1) + nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) + + elif method == 'psrf': # Combination of PolySketch and Random Features. + W0: np.ndarray = input[0] + polysketch: PolyTensorSketch = input[1] + tensorsrht: TensorSRHT = input[2] + kappa1_coeff: np.ndarray = input[3][0] + + polysketch_feats = polysketch.sketch(nngp_feat_2d) + kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) + del polysketch_feats + + nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + + (-1,)) + + nngp_proj = np.concatenate( + (nngp_feat_2d.real, nngp_feat_2d.imag), axis=1) @ W0 + kappa0_feat = np.concatenate( + ((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / np.sqrt(W0.shape[-1]) + del W0 + + # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. + ntk_feat = tensorsrht.sketch(ntk_feat_2d, + kappa0_feat).reshape(input_shape + (-1,)) + + # At the top ReluFeatures, convert complex features to real ones. + if top_layer: + ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1) + nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) + + elif method == 'poly': # Polynomial approximation without sketching. + kappa0_coeff: np.ndarray = input[0] + kappa1_coeff: np.ndarray = input[1] + layer_count = input[2] + + gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) + nngp_feat = cholesky(_poly_expansion( + gram_nngp, kappa1_coeff)).reshape(input_shape + (-1,)) + + ntk = ntk_feat_2d @ ntk_feat_2d.T + kappa0_mat = _poly_expansion(gram_nngp, kappa0_coeff) + ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + + elif method == 'exact': # Exact feature map computations via Cholesky decomposition. + nngp_feat = cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) + + ntk = ntk_feat_2d @ ntk_feat_2d.T + kappa0_mat = kappa0(nngp_feat_2d) + ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + + else: + raise NotImplementedError(f'Invalid method name: {method}') + + if method != 'rf': + norms /= np.sqrt(2.) + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) + + return init_fn, feature_fn def _conv_feat(X, filter_size): diff --git a/experimental/sketching.py b/experimental/sketching.py index 296ef2c9..aaf0b50c 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -6,7 +6,7 @@ # TensorSRHT of degree 2. This version allows different input vectors. @dataclasses.dataclass -class TensorSRHT2: +class TensorSRHT: input_dim1: int input_dim2: int @@ -20,9 +20,9 @@ class TensorSRHT2: rand_inds1: Optional[np.ndarray] = None rand_inds2: Optional[np.ndarray] = None - replace = ... # type: Callable[..., 'TensorSRHT2'] + replace = ... 
# type: Callable[..., 'TensorSRHT'] - def init_sketches(self) -> 'TensorSRHT2': + def init_sketches(self) -> 'TensorSRHT': rng1, rng2, rng3, rng4 = random.split(self.rng, 4) rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1 rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1 @@ -39,100 +39,23 @@ def init_sketches(self) -> 'TensorSRHT2': rand_inds1=rand_inds1, rand_inds2=rand_inds2) - def sketch(self, x1, x2): + def sketch(self, x1, x2, real_output=False): x1fft = np.fft.fftn(x1 * self.rand_signs1, axes=(-1,))[:, self.rand_inds1] x2fft = np.fft.fftn(x2 * self.rand_signs2, axes=(-1,))[:, self.rand_inds2] out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft) - return np.concatenate((out.real, out.imag), 1) - - -# Standard SRHT with real valued output -@dataclasses.dataclass -class SRHT: - - input_dim: int - sketch_dim: int - - rng: np.ndarray - shape: Optional[np.ndarray] = None - - rand_signs: Optional[np.ndarray] = None - rand_inds: Optional[np.ndarray] = None - - replace = ... # type: Callable[..., 'SRHT'] - - def init_sketches(self) -> 'SRHT': - rng1, rng2 = random.split(self.rng, 2) - rand_signs = random.choice(rng1, 2, shape=(self.input_dim,)) * 2 - 1 - rand_inds = random.choice(rng2, - self.input_dim, - shape=(self.sketch_dim // 2,)) - shape = (self.input_dim, self.sketch_dim) - return self.replace(shape=shape, rand_signs=rand_signs, rand_inds=rand_inds) - - def sketch(self, x): - xfft = np.fft.fftn(x * self.rand_signs, axes=(-1,))[:, self.rand_inds] - out = np.sqrt(1 / self.rand_inds.shape[-1]) * xfft - return np.concatenate((out.real, out.imag), 1) - - -# TensorSRHT of degree 2 with complex valued output. This version allows different input vectors. -@dataclasses.dataclass -class CmplxTensorSRHT: - - input_dim1: int - input_dim2: int - sketch_dim: int - - rng: np.ndarray - shape: Optional[np.ndarray] = None - - rand_signs1: Optional[np.ndarray] = None - rand_signs2: Optional[np.ndarray] = None - rand_inds1: Optional[np.ndarray] = None - rand_inds2: Optional[np.ndarray] = None - - replace = ... 
# type: Callable[..., 'CmplxTensorSRHT']
-
-  def init_sketches(self) -> 'CmplxTensorSRHT':
-    rng1, rng2, rng3, rng4 = random.split(self.rng, 4)
-    rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1
-    rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1
-    rand_inds1 = random.choice(rng3,
-                               self.input_dim1,
-                               shape=(self.sketch_dim // 2,))
-    rand_inds2 = random.choice(rng4,
-                               self.input_dim2,
-                               shape=(self.sketch_dim // 2,))
-    shape = (self.input_dim1, self.input_dim2, self.sketch_dim)
-    return self.replace(shape=shape,
-                        rand_signs1=rand_signs1,
-                        rand_signs2=rand_signs2,
-                        rand_inds1=rand_inds1,
-                        rand_inds2=rand_inds2)
-
-  def sketch(self, x1, x2):
-    x1fft = np.fft.fftn(x1 * self.rand_signs1, axes=(-1,))[:, self.rand_inds1]
-    x2fft = np.fft.fftn(x2 * self.rand_signs2, axes=(-1,))[:, self.rand_inds2]
-    out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft)
-    return out
-
-
-# Standard SRHT as a function
-def standardsrht(x, rand_inds, rand_signs):
-  xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds]
-  return np.sqrt(1 / rand_inds.shape[0]) * xfft
+    return np.concatenate((out.real, out.imag), 1) if real_output else out


+# pytype: disable=attribute-error
 @dataclasses.dataclass
 class PolyTensorSketch:

-  rng: np.ndarray
-
   input_dim: int
   sketch_dim: int
   degree: int

+  rng: np.ndarray
+
   tree_rand_signs: Optional[list] = None
   tree_rand_inds: Optional[list] = None
   rand_signs: Optional[np.ndarray] = None
@@ -140,7 +63,7 @@ class PolyTensorSketch:

   replace = ...  # type: Callable[..., 'PolyTensorSketch']

-  def init_sketch(self) -> 'PolyTensorSketch':
+  def init_sketches(self) -> 'PolyTensorSketch':
     tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())]
     tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())]
@@ -185,7 +108,9 @@ def tensorsrht(self, x1, x2, rand_inds, rand_signs):
     return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft)

   # Standard SRHT
-  def standardsrht(self, x, rand_inds, rand_signs):
+  def standardsrht(self, x, rand_inds=None, rand_signs=None):
+    rand_inds = self.rand_inds if rand_inds is None else rand_inds
+    rand_signs = self.rand_signs if rand_signs is None else rand_signs
     xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds]
     return np.sqrt(1 / rand_inds.shape[0]) * xfft
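For reference, the degree-2 `TensorSRHT` above sign-flips each input, Fourier-transforms it, subsamples coordinates, and multiplies the two transforms entrywise, so inner products of sketches estimate products of inner products. The snippet below is a minimal standalone sketch of that idea in plain JAX (toy dimensions; it mirrors, but is not, the library code):

```python
import jax.numpy as jnp
from jax import random

key1, key2, key3, key4, key5 = random.split(random.PRNGKey(0), 5)
d1 = d2 = 64   # input dimensions (made-up values)
m = 2048       # sketch dimension (made-up value)

signs1 = random.choice(key1, 2, shape=(d1,)) * 2 - 1
signs2 = random.choice(key2, 2, shape=(d2,)) * 2 - 1
inds1 = random.choice(key3, d1, shape=(m,))
inds2 = random.choice(key4, d2, shape=(m,))

def sketch(x1, x2):
  # Sign-flip, FFT, and subsample each input, then multiply the transforms.
  x1fft = jnp.fft.fftn(x1 * signs1, axes=(-1,))[:, inds1]
  x2fft = jnp.fft.fftn(x2 * signs2, axes=(-1,))[:, inds2]
  return jnp.sqrt(1 / m) * (x1fft * x2fft)

x = random.normal(key5, (4, d1))
z = sketch(x, x)                 # degree-2 sketch of each row of x
exact = (x @ x.T)**2             # target kernel: squared inner products
approx = (z @ z.conj().T).real   # estimate from the complex sketches
print(jnp.linalg.norm(exact - approx) / jnp.linalg.norm(exact))
```

Concatenating the real and imaginary parts, as `sketch(..., real_output=True)` does above, yields real-valued features with the same Gram matrix as taking the real part of the complex Gram matrix here.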
-print("================= Result of NTK Random Features =================") -kappa0_feat_dim = 2048 -kappa1_feat_dim = 2048 +print("==================== Result of NTK Random Features ====================") + +kappa0_feat_dim = 4096 +kappa1_feat_dim = 4096 sketch_dim = 4096 -poly_degree = 4 -poly_sketch_dim = 4096 relufeat_arg = { - 'method': 'psrf', + 'method': 'rf', 'feature_dim0': kappa0_feat_dim, 'feature_dim1': kappa1_feat_dim, 'sketch_dim': sketch_dim, +} + +print(f"ReluFeatures params:") +for name_, value_ in relufeat_arg.items(): + print(f"{name_:<12} : {value_}") +print() + +init_fn, features_fn = serial( + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(1, W_std=W_std)) + +# Initialize random vectors and sketching algorithms +feat_shape, feat_fn_inputs = init_fn(key2, x.shape) + +# Transform input vectors to NNGP/NTK feature map +feats = jit(features_fn)(x, feat_fn_inputs) + +print(f"f_nngp shape: {feat_shape[0]}") +print(f"f_ntk shape: {feat_shape[1]}") + +print("K_nngp (approx):") +print(feats.nngp_feat @ feats.nngp_feat.T) +print() + +print("K_ntk (approx):") +print(feats.ntk_feat @ feats.ntk_feat.T) +print() + +print( + f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" +) +print( + f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" +) +print() + + +print("==================== Result of NTK wih PolySketch ====================") + +poly_degree = 4 +poly_sketch_dim = 4096 +sketch_dim = 4096 + +relufeat_arg = { + 'method': 'ps', + 'sketch_dim': sketch_dim, 'poly_degree': poly_degree, 'poly_sketch_dim': poly_sketch_dim } -relufeat_arg_top = { +print(f"ReluFeatures params:") +for name_, value_ in relufeat_arg.items(): + print(f"{name_:<12} : {value_}") +print() + +init_fn, features_fn = serial( + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg, top_layer=True), + DenseFeatures(1, W_std=W_std)) + +# Initialize random vectors and sketching algorithms +feat_shape, feat_fn_inputs = init_fn(key2, x.shape) + +# Transform input vectors to NNGP/NTK feature map +feats = features_fn(x, feat_fn_inputs) + +print(f"f_nngp shape: {feat_shape[0]}") +print(f"f_ntk shape: {feat_shape[1]}") + +print("K_nngp (approx):") +print(feats.nngp_feat @ feats.nngp_feat.T) +print() + +print("K_ntk (approx):") +print(feats.ntk_feat @ feats.ntk_feat.T) +print() + +print( + f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" +) +print( + f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" +) +print() + + +print("=============== Result of PolySketch + Random Features ===============") + +kappa0_feat_dim = 2048 +kappa1_feat_dim = 2048 +sketch_dim = 4096 +poly_degree = 4 +poly_sketch_dim = 4096 + +relufeat_arg = { 'method': 'psrf', 'feature_dim0': kappa0_feat_dim, 'feature_dim1': kappa1_feat_dim, 'sketch_dim': sketch_dim, 'poly_degree': poly_degree, - 'poly_sketch_dim': poly_sketch_dim, - 'top_layer': True + 'poly_sketch_dim': poly_sketch_dim } -init_fn, features_fn = serial(DenseFeatures(width), - ReluFeatures(**relufeat_arg), - DenseFeatures(width), - ReluFeatures(**relufeat_arg), - DenseFeatures(width), - 
ReluFeatures(**relufeat_arg_top), DenseFeatures(1)) +print(f"ReluFeatures params:") +for name_, value_ in relufeat_arg.items(): + print(f"{name_:<12} : {value_}") +print() + +init_fn, features_fn = serial( + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg, top_layer=True), + DenseFeatures(1, W_std=W_std)) # Initialize random vectors and sketching algorithms feat_shape, feat_fn_inputs = init_fn(key2, x.shape) # Transform input vectors to NNGP/NTK feature map -feats = jit(features_fn)(x, feat_fn_inputs) +feats = features_fn(x, feat_fn_inputs) print(f"f_nngp shape: {feat_shape[0]}") +print(f"f_ntk shape: {feat_shape[1]}") + print("K_nngp (approx):") print(feats.nngp_feat @ feats.nngp_feat.T) print() -print(f"f_ntk shape: {feat_shape[1]}") print("K_ntk (approx):") print(feats.ntk_feat @ feats.ntk_feat.T) print() @@ -93,23 +192,67 @@ ) print() -print("================= (Debug) Exact NTK Feature Maps =================") -relufeat_arg = {'poly_degree': poly_degree, 'method': 'exact'} -relufeat_arg_top = {'poly_degree': poly_degree, 'method': 'exact', 'top_layer': True} +print("======= (Debug) NTK Feature Maps with Polynomial Approximation =======") +print("\t(*No Sketching algorithm is applied.)") + +relufeat_arg = {'method': 'poly', 'poly_degree':64} + +print(f"ReluFeatures params:") +for name_, value_ in relufeat_arg.items(): + print(f"{name_:<12} : {value_}") +print() + +init_fn, feature_fn = serial( + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(1, W_std=W_std)) + +# Initialize random vectors and sketching algorithms +feat_shape, feat_fn_inputs = init_fn(key2, x.shape) + +feats = jit(feature_fn)(x, feat_fn_inputs) + +print(f"f_nngp shape: {feat_shape[0]}") +print(f"f_ntk shape: {feat_shape[1]}") + +print("K_nngp (approx):") +print(feats.nngp_feat @ feats.nngp_feat.T) +print() + +print("K_ntk :") +print(feats.ntk_feat @ feats.ntk_feat.T) +print() + +print( + f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" +) +print( + f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" +) +print() + + +print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") +relufeat_arg = {'method': 'exact'} -init_fn, features_fn = serial(DenseFeatures(width), - ReluFeatures(**relufeat_arg), - DenseFeatures(width), - ReluFeatures(**relufeat_arg), - DenseFeatures(width), - ReluFeatures(**relufeat_arg_top), DenseFeatures(1)) +print(f"ReluFeatures params:") +for name_, value_ in relufeat_arg.items(): + print(f"{name_:<12} : {value_}") +print() + +init_fn, feature_fn = serial( + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(1, W_std=W_std)) # Initialize random vectors and sketching algorithms feat_shape, feat_fn_inputs = init_fn(key2, x.shape) -feats = jit(features_fn)(x, feat_fn_inputs) +feats = jit(feature_fn)(x, feat_fn_inputs) print("K_nngp :") print(feats.nngp_feat @ feats.nngp_feat.T) From d52ae476d3dda489dbb6f78d1ef4c6117e6097b1 Mon Sep 17 00:00:00 2001 From: insuhan Date: Sun, 27 Mar 2022 14:46:49 +0900 Subject: [PATCH 
22/44] Make jit-able test_fc_ntk.py --- experimental/test_fc_ntk.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py index f8cbd6d9..eab0ca63 100644 --- a/experimental/test_fc_ntk.py +++ b/experimental/test_fc_ntk.py @@ -57,7 +57,7 @@ print(f"{name_:<12} : {value_}") print() -init_fn, features_fn = serial( +init_fn, feature_fn = serial( DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), @@ -67,7 +67,7 @@ feat_shape, feat_fn_inputs = init_fn(key2, x.shape) # Transform input vectors to NNGP/NTK feature map -feats = jit(features_fn)(x, feat_fn_inputs) +feats = jit(feature_fn)(x, feat_fn_inputs) print(f"f_nngp shape: {feat_shape[0]}") print(f"f_ntk shape: {feat_shape[1]}") @@ -107,7 +107,7 @@ print(f"{name_:<12} : {value_}") print() -init_fn, features_fn = serial( +init_fn, feature_fn = serial( DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg, top_layer=True), @@ -117,7 +117,7 @@ feat_shape, feat_fn_inputs = init_fn(key2, x.shape) # Transform input vectors to NNGP/NTK feature map -feats = features_fn(x, feat_fn_inputs) +feats = jit(feature_fn)(x, feat_fn_inputs) print(f"f_nngp shape: {feat_shape[0]}") print(f"f_ntk shape: {feat_shape[1]}") @@ -161,7 +161,7 @@ print(f"{name_:<12} : {value_}") print() -init_fn, features_fn = serial( +init_fn, feature_fn = serial( DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg, top_layer=True), @@ -171,7 +171,7 @@ feat_shape, feat_fn_inputs = init_fn(key2, x.shape) # Transform input vectors to NNGP/NTK feature map -feats = features_fn(x, feat_fn_inputs) +feats = jit(feature_fn)(x, feat_fn_inputs) print(f"f_nngp shape: {feat_shape[0]}") print(f"f_ntk shape: {feat_shape[1]}") From ec43fc756c710fe2260060048122ec7ea0de65aa Mon Sep 17 00:00:00 2001 From: insuhan Date: Mon, 28 Mar 2022 11:52:36 +0900 Subject: [PATCH 23/44] Add ReluNTKFeatures (one-pass sketching) --- experimental/features.py | 61 ++++++++-- experimental/poly_fitting.py | 28 +++++ experimental/test_fc_ntk.py | 223 ++++++++++++++--------------------- 3 files changed, 165 insertions(+), 147 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index d79282ac..47b2cf6b 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -10,7 +10,7 @@ from neural_tangents._src.stax.linear import _Pooling as Pooling from experimental.sketching import TensorSRHT, PolyTensorSketch -from experimental.poly_fitting import kappa0_coeffs, kappa1_coeffs, kappa0, kappa1 +from experimental.poly_fitting import kappa0_coeffs, kappa1_coeffs, kappa0, kappa1, relu_ntk_coeffs """ Implementation for NTK Sketching and Random Features """ @@ -114,16 +114,6 @@ def new_layer_fns(*args, **kwargs): return new_layer_fns -def _check_modules_contain_dense_relu(module_names: tuple) -> bool: - - def _check_string_tuple_has_one_entry(str_tuple, entry): - return len(set(str_tuple)) == 1 and str_tuple[0] == entry - - return len(module_names) % 2 == 1 and _check_string_tuple_has_one_entry( - module_names[::2], 'DenseFeatures') and _check_string_tuple_has_one_entry( - module_names[1::2], 
'ReluFeatures')
-
-
 # Modified the serial process of feature map blocks.
 # Followed https://github.com/google/neural-tangents/blob/main/neural_tangents/stax.py
 @layer
@@ -396,6 +386,55 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features:
   return init_fn, feature_fn


+@layer
+def ReluNTKFeatures(
+    num_layers: int,
+    poly_degree: int,
+    poly_sketch_dim: int,
+    W_std: float = 1.,
+):
+
+  def init_fn(rng, input_shape):
+    input_dim = input_shape[0][-1]
+
+    # PolySketch expansion for nngp/ntk features.
+    polysketch = PolyTensorSketch(rng=rng,
+                                  input_dim=input_dim,
+                                  sketch_dim=poly_sketch_dim,
+                                  degree=poly_degree).init_sketches()  # pytype:disable=wrong-keyword-args
+
+    nngp_coeffs, ntk_coeffs = relu_ntk_coeffs(poly_degree, num_layers)
+
+    return (), (polysketch, nngp_coeffs, ntk_coeffs)
+
+  def feature_fn(f, input=None, **kwargs):
+    input_shape = f.nngp_feat.shape[:-1]
+
+    polysketch: PolyTensorSketch = input[0]
+    nngp_coeffs: np.ndarray = input[1]
+    ntk_coeffs: np.ndarray = input[2]
+
+    polysketch_feats = polysketch.sketch(
+        f.nngp_feat)  # f.ntk_feat should be equal to f.nngp_feat.
+    nngp_feat = polysketch.expand_feats(polysketch_feats, nngp_coeffs)
+    ntk_feat = polysketch.expand_feats(polysketch_feats, ntk_coeffs)
+
+    # Apply SRHT to features so that dimensions are poly_sketch_dim//2.
+    nngp_feat = polysketch.standardsrht(nngp_feat).reshape(input_shape + (-1,))
+    ntk_feat = polysketch.standardsrht(ntk_feat).reshape(input_shape + (-1,))
+
+    # Convert complex features to real ones.
+    ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1)
+    nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1)
+
+    norms = f.norms / 2.**(num_layers / 2) * (W_std**(num_layers + 1))
+
+    return _renormalize_feature(
+        f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms))
+
+  return init_fn, feature_fn
+
+
 def _conv_feat(X, filter_size):
   N, H, W, C = X.shape
   out = np.zeros((N, H, W, C * filter_size))
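`ReluNTKFeatures` collapses the whole depth into a single polynomial expansion whose target coefficients come from `relu_ntk_coeffs`, added to `poly_fitting.py` just below. Before any fitting, the function that helper tabulates is the depth-`num_layers` ReLU NTK on unit-norm inputs. A standalone `jax.numpy` sanity check (a sketch, not part of the patch) that the direct sum it assembles agrees with the usual layerwise NNGP/NTK recursion:

```python
import jax.numpy as jnp

def kappa0(t):
  return (jnp.pi - jnp.arccos(t)) / jnp.pi

def kappa1(t):
  return (jnp.sqrt(1. - t**2) + (jnp.pi - jnp.arccos(t)) * t) / jnp.pi

num_layers = 3
t = jnp.linspace(-1., 1., 11)  # cosine of the angle between two unit inputs

# Direct sum over layers, as assembled in `relu_ntk_coeffs`.
kappa1s = [t]
for i in range(num_layers):
  kappa1s.append(kappa1(kappa1s[i]))
ntk_sum = jnp.zeros_like(t)
for i in range(num_layers + 1):
  z = kappa1s[i]
  for j in range(i, num_layers):
    z = z * kappa0(kappa1s[j])
  ntk_sum = ntk_sum + z

# Layerwise recursion: Theta <- kappa1(Sigma) + Theta * kappa0(Sigma).
sigma, theta = t, t
for _ in range(num_layers):
  sigma, theta = kappa1(sigma), kappa1(sigma) + theta * kappa0(sigma)

assert jnp.allclose(ntk_sum, theta)
```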
diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py
index 9d0e6563..bc6e085b 100644
--- a/experimental/poly_fitting.py
+++ b/experimental/poly_fitting.py
@@ -117,3 +117,31 @@ def kappa1_coeffs(degree: int, num_layers: int):
   # (close to 1) because the slope around 1 is much sharper.
   coeffs = poly_fitting_qp(xvals, fvals, weights, degree, eq_last_point=True)
   return np.where(coeffs < 1e-5, 0.0, coeffs)
+
+
+def relu_ntk_coeffs(degree: int, num_layers: int):
+
+  num_points = 20 * num_layers + 8 * degree
+  x_eq = np.linspace(-1, 1., num=201)
+  x_noneq = np.cos((2 * np.arange(num_points) + 1) * np.pi / (4 * num_points))
+  x = np.sort(np.concatenate((x_eq, x_noneq)))
+
+  kappa1s = {}
+  kappa1s[0] = x
+  for i in range(num_layers):
+    kappa1s[i + 1] = kappa1(kappa1s[i], is_x_matrix=False)
+
+  weights = np.linspace(0.0, 1.0, num=len(x)) + 2 / num_layers
+  nngp_coeffs = poly_fitting_qp(x, kappa1s[num_layers], weights, degree)
+  nngp_coeffs = np.where(nngp_coeffs < 1e-5, 0.0, nngp_coeffs)
+
+  ntk = np.zeros(len(x), dtype=x.dtype)
+  for i in range(num_layers + 1):
+    z = kappa1s[i]
+    for j in range(i, num_layers):
+      z *= kappa0(kappa1s[j], is_x_matrix=False)
+    ntk += z
+  ntk_coeffs = poly_fitting_qp(x, ntk, weights, degree)
+  ntk_coeffs = np.where(ntk_coeffs < 1e-5, 0.0, ntk_coeffs)
+
+  return nngp_coeffs, ntk_coeffs
diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py
index eab0ca63..9c42a501 100644
--- a/experimental/test_fc_ntk.py
+++ b/experimental/test_fc_ntk.py
@@ -9,7 +9,7 @@
 config.update("jax_enable_x64", True)

 from neural_tangents import stax
-from experimental.features import DenseFeatures, ReluFeatures, serial
+from experimental.features import DenseFeatures, ReluFeatures, serial, ReluNTKFeatures

 seed = 1
 n, d = 6, 5
@@ -18,15 +18,14 @@
 x = random.normal(key1, (n, d))

 width = 512  # this does not affect the output
-W_std = 1.234 # std of Gaussian random weights
+W_std = 1.234  # std of Gaussian random weights

 print("================== Result of Neural Tangents Library ==================")

-init_fn, _, kernel_fn = stax.serial(
-    stax.Dense(width, W_std=W_std), stax.Relu(),
-    stax.Dense(width, W_std=W_std), stax.Relu(),
-    stax.Dense(width, W_std=W_std), stax.Relu(),
-    stax.Dense(1, W_std=W_std))
+init_fn, _, kernel_fn = stax.serial(stax.Dense(width, W_std=W_std), stax.Relu(),
+                                    stax.Dense(width, W_std=W_std), stax.Relu(),
+                                    stax.Dense(width, W_std=W_std), stax.Relu(),
+                                    stax.Dense(1, W_std=W_std))

 nt_kernel = kernel_fn(x, None)

@@ -39,6 +38,27 @@
 print()


+def eval_features(f_):
+  print(f"f_nngp shape: {f_.nngp_feat.shape}")
+  print(f"f_ntk shape: {f_.ntk_feat.shape}")
+
+  print("K_nngp:")
+  print(f_.nngp_feat @ f_.nngp_feat.T)
+  print()
+
+  print("K_ntk:")
+  print(f_.ntk_feat @ f_.ntk_feat.T)
+  print()
+
+  print(
+      f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - f_.nngp_feat @ f_.nngp_feat.T)}"
+  )
+  print(
+      f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - f_.ntk_feat @ f_.ntk_feat.T)}"
+  )
+  print()
+
+
 print("==================== Result of NTK Random Features ====================")

 kappa0_feat_dim = 4096
@@ -46,48 +66,29 @@
 sketch_dim = 4096

 relufeat_arg = {
-    'method': 'rf',
-    'feature_dim0': kappa0_feat_dim,
-    'feature_dim1': kappa1_feat_dim,
-    'sketch_dim': sketch_dim,
+  'method': 'rf',
+  'feature_dim0': kappa0_feat_dim,
+  'feature_dim1': kappa1_feat_dim,
+  'sketch_dim': sketch_dim,
 }

 print(f"ReluFeatures params:")
 for name_, value_ in relufeat_arg.items():
-    print(f"{name_:<12} : {value_}")
+  print(f"{name_:<12} : {value_}")
 print()

 init_fn, feature_fn = serial(
-    DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
-    DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
-    DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
-    DenseFeatures(1, W_std=W_std))
+  DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
+  DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
+  DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
+  DenseFeatures(1, W_std=W_std))

 # Initialize random vectors and sketching algorithms
 feat_shape, feat_fn_inputs = init_fn(key2, x.shape)

 # Transform input vectors to NNGP/NTK feature map
 feats = jit(feature_fn)(x, feat_fn_inputs)
-
-print(f"f_nngp shape: {feat_shape[0]}")
-print(f"f_ntk shape: {feat_shape[1]}")
-
-print("K_nngp (approx):")
-print(feats.nngp_feat @ feats.nngp_feat.T)
-print()
-
-print("K_ntk (approx):")
-print(feats.ntk_feat @ feats.ntk_feat.T)
-print()
-
-print(
-    f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}"
-)
-print(
-    f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}"
-)
-print()
-
+eval_features(feats)

 print("==================== Result of NTK with PolySketch ====================")

@@ -96,142 +97,106 @@
 sketch_dim = 4096

 relufeat_arg = {
-    'method': 'ps',
-    'sketch_dim': sketch_dim,
-    'poly_degree': poly_degree,
-    'poly_sketch_dim': poly_sketch_dim
+  'method': 'ps',
+  'sketch_dim': sketch_dim,
+  'poly_degree': poly_degree,
+  'poly_sketch_dim': poly_sketch_dim
 }

 print(f"ReluFeatures params:")
 for name_, value_ in relufeat_arg.items():
-    print(f"{name_:<12} : {value_}")
+  print(f"{name_:<12} : {value_}")
 print()

 init_fn, feature_fn = serial(
-    DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
-    DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
-    DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg, top_layer=True),
-    DenseFeatures(1, W_std=W_std))
+  DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
+  DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
+  DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg, top_layer=True),
+  DenseFeatures(1, W_std=W_std))

 # Initialize random vectors and sketching algorithms
 feat_shape, feat_fn_inputs = init_fn(key2, x.shape)

 # Transform input vectors to NNGP/NTK feature map
 feats = jit(feature_fn)(x, feat_fn_inputs)
-
-print(f"f_nngp shape: {feat_shape[0]}")
-print(f"f_ntk shape: {feat_shape[1]}")
-
-print("K_nngp (approx):")
-print(feats.nngp_feat @ feats.nngp_feat.T)
-print()
-
-print("K_ntk (approx):")
-print(feats.ntk_feat @ feats.ntk_feat.T)
-print()
-
-print(
-    f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}"
-)
-print(
-    f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}"
-)
-print()
-
+eval_features(feats)

 print("=============== Result of PolySketch + Random Features ===============")

 kappa0_feat_dim = 2048
-kappa1_feat_dim = 2048
 sketch_dim = 4096
 poly_degree = 4
 poly_sketch_dim = 4096

 relufeat_arg = {
-    'method': 'psrf',
-    'feature_dim0': kappa0_feat_dim,
-    'feature_dim1': kappa1_feat_dim,
-    'sketch_dim': sketch_dim,
-    'poly_degree': poly_degree,
-    'poly_sketch_dim': poly_sketch_dim
+  'method': 'psrf',
+  'feature_dim0': kappa0_feat_dim,
+  'sketch_dim': sketch_dim,
+  'poly_degree': poly_degree,
+  'poly_sketch_dim': poly_sketch_dim
 }

 print(f"ReluFeatures params:")
 for name_, value_ in relufeat_arg.items():
-    print(f"{name_:<12} : {value_}")
+  print(f"{name_:<12} : {value_}")
 print()

 init_fn, feature_fn = serial(
-    DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
-    DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
-    DenseFeatures(width, W_std=W_std),
ReluFeatures(**relufeat_arg, top_layer=True), - DenseFeatures(1, W_std=W_std)) + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg, top_layer=True), + DenseFeatures(1, W_std=W_std)) # Initialize random vectors and sketching algorithms feat_shape, feat_fn_inputs = init_fn(key2, x.shape) # Transform input vectors to NNGP/NTK feature map feats = jit(feature_fn)(x, feat_fn_inputs) +eval_features(feats) -print(f"f_nngp shape: {feat_shape[0]}") -print(f"f_ntk shape: {feat_shape[1]}") +print("=========== Result of ReLU-NTK Sketch (one-pass sketching) ===========") -print("K_nngp (approx):") -print(feats.nngp_feat @ feats.nngp_feat.T) -print() +relufeat_arg = { + 'num_layers': 3, + 'poly_degree': 32, + 'poly_sketch_dim': 4096, + 'W_std': W_std, +} -print("K_ntk (approx):") -print(feats.ntk_feat @ feats.ntk_feat.T) +print(f"ReluFeatures params:") +for name_, value_ in relufeat_arg.items(): + print(f"{name_:<12} : {value_}") print() -print( - f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" -) -print( - f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" -) -print() +init_fn, feature_fn = ReluNTKFeatures(**relufeat_arg) +_, feat_fn_inputs = init_fn(key2, x.shape) + +# Transform input vectors to NNGP/NTK feature map +feats = jit(feature_fn)(x, feat_fn_inputs) +eval_features(feats) print("======= (Debug) NTK Feature Maps with Polynomial Approximation =======") print("\t(*No Sketching algorithm is applied.)") -relufeat_arg = {'method': 'poly', 'poly_degree':64} +relufeat_arg = {'method': 'poly', 'poly_degree': 64} print(f"ReluFeatures params:") for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") + print(f"{name_:<12} : {value_}") print() init_fn, feature_fn = serial( - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(1, W_std=W_std)) + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(1, W_std=W_std)) # Initialize random vectors and sketching algorithms feat_shape, feat_fn_inputs = init_fn(key2, x.shape) feats = jit(feature_fn)(x, feat_fn_inputs) - -print(f"f_nngp shape: {feat_shape[0]}") -print(f"f_ntk shape: {feat_shape[1]}") - -print("K_nngp (approx):") -print(feats.nngp_feat @ feats.nngp_feat.T) -print() - -print("K_ntk :") -print(feats.ntk_feat @ feats.ntk_feat.T) -print() - -print( - f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" -) -print( - f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" -) -print() +eval_features(feats) print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") @@ -240,31 +205,17 @@ print(f"ReluFeatures params:") for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") + print(f"{name_:<12} : {value_}") print() init_fn, feature_fn = serial( - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - 
DenseFeatures(1, W_std=W_std))
+  DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
+  DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
+  DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg),
+  DenseFeatures(1, W_std=W_std))

 # Initialize random vectors and sketching algorithms
 feat_shape, feat_fn_inputs = init_fn(key2, x.shape)

 feats = jit(feature_fn)(x, feat_fn_inputs)
-
-print("K_nngp :")
-print(feats.nngp_feat @ feats.nngp_feat.T)
-print()
-
-print("K_ntk :")
-print(feats.ntk_feat @ feats.ntk_feat.T)
-print()
-
-print(
-    f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}"
-)
-print(
-    f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}"
-)
+eval_features(feats)
\ No newline at end of file
From e0f5bfe229e9fed440d0d971d6d3aa9abd39f7a4 Mon Sep 17 00:00:00 2001
From: insuhan
Date: Wed, 30 Mar 2022 05:05:25 +0900
Subject: [PATCH 24/44] Reflect comments in PR conversation

---
 experimental/README.md                       |   2 +-
 experimental/features.py                     | 296 ++++++++++--------
 experimental/poly_fitting.py                 |  16 +-
 experimental/sketching.py                    |   4 +-
 experimental/tests/__init__.py               |   0
 .../{test_fc_ntk.py => tests/fc_ntk_test.py} |   3 +-
 .../myrtle_network_test.py}                  |   0
 experimental/tests/sketching_test.py         |  41 +++
 8 files changed, 225 insertions(+), 137 deletions(-)
 create mode 100644 experimental/tests/__init__.py
 rename experimental/{test_fc_ntk.py => tests/fc_ntk_test.py} (99%)
 rename experimental/{test_myrtle_network.py => tests/myrtle_network_test.py} (100%)
 create mode 100644 experimental/tests/sketching_test.py

diff --git a/experimental/README.md b/experimental/README.md
index 4acef12e..d75172dd 100644
--- a/experimental/README.md
+++ b/experimental/README.md
@@ -118,7 +118,7 @@ assert out_feat.ntk_feat.shape == (3, 3)

 ## [`features.ConvFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L236)

-`features.ConvFeatures` is similar to `features.DenseFeatures` as it updates the NTK feature of the next layer by concatenting NNGP and NTK features of the previous one. But, it additionlly requires the kernel pooling operations. Precisely, [[4]](#4) studied that the NNGP/NTK kernel matrices require to compute the trace of submatrix of size `stride_size`. This can be seen as convolution with an identity matrix with size `stride_size`. However, in the feature side, this can be done via concatenating shifted features thus the resulting feature dimension becomes `stride_size` times larger. Moreover, since image datasets are 2-D matrices, the kernel pooling should be applied along with two axes hence the output feature has the shape `N x H x W x (d * s**2)` where `s` is the stride size and `d` is the input feature dimension.
+`features.ConvFeatures` is similar to `features.DenseFeatures` in that it updates the NTK features of the next layer by concatenating the NNGP and NTK features of the previous one. But it additionally requires the kernel pooling operations. Precisely, [[4]](#4) showed that the NNGP/NTK kernel matrices require computing the trace of a submatrix of size `stride_size`. This can be seen as a convolution with an identity matrix of size `stride_size`. On the feature side, however, this can be done by concatenating shifted features, so the resulting feature dimension becomes `stride_size` times larger. Moreover, since image datasets are 2-D, the kernel pooling should be applied along both spatial axes, hence the output feature has the shape `N x H x W x (d * filter_size**2)` where `filter_size` is the size of the convolution filter and `d` is the input feature dimension.
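The shifted-feature construction can be checked directly. Below is a toy `jax.numpy` snippet (not library code; a single spatial axis and made-up sizes are assumed) verifying that the channelwise inner product of concatenated shifted copies equals the windowed sum of per-pixel inner products, i.e. convolution with an identity filter:

```python
import jax.numpy as jnp
from jax import random

key1, key2 = random.split(random.PRNGKey(0))
W, C, filter_size = 8, 3, 3   # spatial size, channels, window (made-up)
x = random.normal(key1, (W, C))
y = random.normal(key2, (W, C))

def shift_concat(z):
  # Concatenate zero-padded copies shifted by -1, 0, +1 along the spatial axis.
  padded = jnp.pad(z, ((1, 1), (0, 0)))
  shifts = [jnp.roll(padded, s, axis=0)[1:-1] for s in (-1, 0, 1)]
  return jnp.concatenate(shifts, axis=-1)

# Inner product of the concatenated features at each position...
via_feats = (shift_concat(x) * shift_concat(y)).sum(-1)
# ...equals the windowed sum of per-pixel inner products (identity filter).
direct = jnp.array([(x[max(0, i - 1):i + 2] * y[max(0, i - 1):i + 2]).sum()
                    for i in range(W)])
assert jnp.allclose(via_feats, direct)
```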
## [`features.AvgPoolFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L269)
diff --git a/experimental/features.py b/experimental/features.py
index 47b2cf6b..a4e8d480 100644
--- a/experimental/features.py
+++ b/experimental/features.py
@@ -1,4 +1,4 @@
-from typing import Optional, Callable
+from typing import Optional, Callable, Sequence, Tuple
 from jax import random
 from jax import numpy as np
 from jax.numpy.linalg import cholesky
@@ -6,7 +6,7 @@
 from neural_tangents import stax
 from neural_tangents._src.utils import dataclasses
-from neural_tangents._src.stax.linear import _pool_kernel, Padding
+from neural_tangents._src.stax.linear import _pool_kernel, Padding, _get_dimension_numbers
 from neural_tangents._src.stax.linear import _Pooling as Pooling

 from experimental.sketching import TensorSRHT, PolyTensorSketch
@@ -14,22 +14,11 @@
 """ Implementation for NTK Sketching and Random Features """

-def _prod(tuple_):
-  prod = 1
-  for x in tuple_:
-    prod = prod * x
-  return prod
-
-
-def _poly_expansion(x, coeffs):
-  return np.polyval(coeffs[::-1], x)
-
-
 @dataclasses.dataclass
 class Features:
-  nngp_feat: Optional[np.ndarray] = None
-  ntk_feat: Optional[np.ndarray] = None
-  norms: Optional[np.ndarray] = None
+  nngp_feat: np.ndarray
+  ntk_feat: np.ndarray
+  norms: np.ndarray

   batch_axis: int = 0
   channel_axis: int = -1
@@ -37,25 +26,38 @@ class Features:

   replace = ...  # type: Callable[..., 'Features']

-def _inputs_to_features(x: np.ndarray,
-                        batch_axis: int = 0,
-                        channel_axis: int = -1,
-                        **kwargs) -> Features:
-  """Transforms (batches of) inputs to a `Features`."""
+def layer(layer_fn):

-  # Followed the same initialization of Neural Tangents library.
-  nngp_feat = x / x.shape[channel_axis]**0.5
-  norms = np.linalg.norm(nngp_feat, axis=channel_axis)
-  norms = np.expand_dims(np.where(norms > 0, norms, 1.0), channel_axis)
-  nngp_feat = nngp_feat / norms
+  def new_layer_fns(*args, **kwargs):
+    init_fn, feature_fn = layer_fn(*args, **kwargs)
+    init_fn = _preprocess_init_fn(init_fn)
+    feature_fn = _preprocess_feature_fn(feature_fn)
+    return init_fn, feature_fn

-  ntk_feat = np.array([0.0], dtype=nngp_feat.dtype)
+  return new_layer_fns

-  return Features(nngp_feat=nngp_feat,
-                  ntk_feat=ntk_feat,
-                  norms=norms,
-                  batch_axis=batch_axis,
-                  channel_axis=channel_axis)  # pytype:disable=wrong-keyword-args
+
+def _preprocess_init_fn(init_fn):
+
+  def init_fn_any(rng, input_shape_any, **kwargs):
+    if _is_sinlge_shape(input_shape_any):
+      input_shape = (input_shape_any, (-1, 0))
+      return init_fn(rng, input_shape, **kwargs)
+    else:
+      return init_fn(rng, input_shape_any, **kwargs)
+
+  return init_fn_any
+
+
+def _is_sinlge_shape(input_shape):
+  if len(input_shape) == 2:
+    if all(isinstance(n, int) for n in input_shape):
+      return True
+    elif all(_is_sinlge_shape(s) for s in input_shape):
+      return False
+  elif len(input_shape) == 3:
+    return _is_sinlge_shape(input_shape[:2])
+  raise ValueError(input_shape)


 # For flexible `feature_fn` with both `np.ndarray` and `Features` inputs.
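The `_preprocess_init_fn` wrapper added here lets each layer's `init_fn` accept either a bare input shape or an (NNGP, NTK) pair of feature shapes, padding the missing NTK shape with the dummy `(-1, 0)`. A toy illustration of that convention (an assumed simplification of `_is_sinlge_shape`, not the patch's code):

```python
def promote(input_shape):
  # A bare shape like (-1, 32) is all ints; a pair of shapes is not.
  is_single = all(isinstance(n, int) for n in input_shape)
  return (input_shape, (-1, 0)) if is_single else input_shape

assert promote((-1, 32)) == ((-1, 32), (-1, 0))             # bare nngp shape
assert promote(((-1, 32), (-1, 0))) == ((-1, 32), (-1, 0))  # already a pair
```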
@@ -77,41 +79,24 @@ def feature_fn_any(x_or_feature, input=None, **kwargs): return feature_fn_any -def _is_sinlge_shape(input_shape): - return all(isinstance(n, int) for n in input_shape) - - -def _is_defaut_feature(feat): - return feat.ndim == 1 - - -def _preprocess_init_fn(init_fn): - - def init_fn_any(rng, input_shape_any, **kwargs): - if _is_sinlge_shape(input_shape_any): - input_shape = (input_shape_any, (-1, 0)) - return init_fn(rng, input_shape, **kwargs) - else: - return init_fn(rng, input_shape_any, **kwargs) - - return init_fn_any - - -def _renormalize_feature(f: Features, **kwargs): - nngp_feat = f.nngp_feat * f.norms - ntk_feat = f.ntk_feat * f.norms - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - - -def layer(layer_fn): +def _inputs_to_features(x: np.ndarray, + batch_axis: int = 0, + channel_axis: int = -1, + **kwargs) -> Features: + """Transforms (batches of) inputs to a `Features`.""" - def new_layer_fns(*args, **kwargs): - init_fn, feature_fn = layer_fn(*args, **kwargs) - feature_fn = _preprocess_feature_fn(feature_fn) - init_fn = _preprocess_init_fn(init_fn) - return init_fn, feature_fn + # Followed the same initialization of Neural Tangents library. + nngp_feat = x / x.shape[channel_axis]**0.5 + norms = np.linalg.norm(nngp_feat, axis=channel_axis) + norms = np.expand_dims(np.where(norms > 0, norms, 1.0), channel_axis) + nngp_feat = nngp_feat / norms + ntk_feat = np.zeros((), dtype=nngp_feat.dtype) - return new_layer_fns + return Features(nngp_feat=nngp_feat, + ntk_feat=ntk_feat, + norms=norms, + batch_axis=batch_axis, + channel_axis=channel_axis) # pytype:disable=wrong-keyword-args # Modified the serial process of feature map blocks. @@ -131,17 +116,28 @@ def feature_fn(k, inputs, **kwargs): return init_fn, feature_fn +def _renormalize_feature(f: Features, **kwargs): + nngp_feat = f.nngp_feat * f.norms + ntk_feat = f.ntk_feat * f.norms + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + @layer def DenseFeatures(out_dim: int, W_std: float = 1., - b_std: float = 0., + b_std: Optional[float] = None, + parameterization: str = 'ntk', batch_axis: int = 0, channel_axis: int = -1): - if b_std != 0.0: - raise NotImplementedError('Non-zero b_std is not implemented yet .' + if b_std is not None: + raise NotImplementedError('Bias variable b_std is not implemented yet .' 
' Please set b_std to be `0`.') + if parameterization != 'ntk': + raise NotImplementedError(f'Parameterization ({parameterization}) is ' + ' not implemented yet.') + def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] new_ntk_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_shape[-1] + @@ -153,10 +149,13 @@ def init_fn(rng, input_shape): return (nngp_feat_shape, new_ntk_feat_shape, 'D'), () def feature_fn(f: Features, input, **kwargs): - nngp_feat, ntk_feat, norms = f.nngp_feat, f.ntk_feat, f.norms + nngp_feat : np.ndarray = f.nngp_feat + ntk_feat : np.ndarray = f.ntk_feat + norms : np.ndarray = f.norms + norms *= W_std - if _is_defaut_feature(ntk_feat): # check if ntk_feat is empty + if ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat else: ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) @@ -280,23 +279,23 @@ def init_fn(rng, input_shape): def feature_fn(f: Features, input=None, **kwargs) -> Features: - input_shape = f.nngp_feat.shape[:-1] - nngp_feat_dim = f.nngp_feat.shape[-1] - ntk_feat_dim = f.ntk_feat.shape[-1] + input_shape: tuple = f.nngp_feat.shape[:-1] + nngp_feat_dim: tuple = f.nngp_feat.shape[-1] + ntk_feat_dim: tuple = f.ntk_feat.shape[-1] - nngp_feat_2d = f.nngp_feat.reshape(-1, nngp_feat_dim) - ntk_feat_2d = f.ntk_feat.reshape(-1, ntk_feat_dim) - norms = f.norms + nngp_feat_2d: np.ndarray = f.nngp_feat.reshape(-1, nngp_feat_dim) + ntk_feat_2d: np.ndarray = f.ntk_feat.reshape(-1, ntk_feat_dim) + norms: np.ndarray = f.norms if method == 'rf': # Random Features approach. W0: np.ndarray = input[0] W1: np.ndarray = input[1] tensorsrht: TensorSRHT = input[2] - kappa0_feat = (nngp_feat_2d @ W0 > 0) / np.sqrt(W0.shape[-1]) + kappa0_feat = (nngp_feat_2d @ W0 > 0) / W0.shape[-1]**0.5 del W0 nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / - np.sqrt(W1.shape[-1])).reshape(input_shape + (-1,)) + W1.shape[-1]**0.5).reshape(input_shape + (-1,)) del W1 ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat, real_output=True).reshape(input_shape + @@ -343,7 +342,7 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: nngp_proj = np.concatenate( (nngp_feat_2d.real, nngp_feat_2d.imag), axis=1) @ W0 kappa0_feat = np.concatenate( - ((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / np.sqrt(W0.shape[-1]) + ((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / W0.shape[-1]**0.5 del W0 # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. @@ -361,11 +360,11 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: layer_count = input[2] gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) - nngp_feat = cholesky(_poly_expansion( - gram_nngp, kappa1_coeff)).reshape(input_shape + (-1,)) + nngp_feat = cholesky(np.polyval(kappa1_coeff[::-1], + gram_nngp)).reshape(input_shape + (-1,)) ntk = ntk_feat_2d @ ntk_feat_2d.T - kappa0_mat = _poly_expansion(gram_nngp, kappa0_coeff) + kappa0_mat = np.polyval(kappa0_coeff[::-1], gram_nngp) ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) elif method == 'exact': # Exact feature map computations via Cholesky decomposition. @@ -379,18 +378,25 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: raise NotImplementedError(f'Invalid method name: {method}') if method != 'rf': - norms /= np.sqrt(2.) 
+ norms /= 2.0**0.5 return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) return init_fn, feature_fn +def _prod(tuple_): + prod = 1 + for x in tuple_: + prod = prod * x + return prod + + @layer def ReluNTKFeatures( num_layers: int, - poly_degree: int, - poly_sketch_dim: int, + poly_degree: int = 16, + poly_sketch_dim: int = 1024, W_std: float = 1., ): @@ -435,33 +441,36 @@ def feature_fn(f, input=None, **kwargs): return init_fn, feature_fn -def _conv_feat(X, filter_size): - N, H, W, C = X.shape - out = np.zeros((N, H, W, C * filter_size)) - out = out.at[:, :, :, :C].set(X) - j = 1 - for i in range(1, min((filter_size + 1) // 2, W)): - out = out.at[:, :, :-i, j * C:(j + 1) * C].set(X[:, :, i:]) - j += 1 - out = out.at[:, :, i:, j * C:(j + 1) * C].set(X[:, :, :-i]) - j += 1 - return out +@layer +def ConvFeatures(out_chan: int, + filter_shape: Sequence[int], + strides: Optional[Sequence[int]], + padding: str, + W_std: float = 1.0, + b_std: Optional[float] = None, + dimension_numbers: Optional[Tuple[str, str, str]] = None, + parameterization: str = 'ntk'): + if b_std is not None: + raise NotImplementedError('Bias variable b_std is not implemented yet .' + ' Please set b_std to be `0`.') -def _conv2d_feat(X, filter_size): - return _conv_feat(np.moveaxis(_conv_feat(X, filter_size), 1, 2), filter_size) + parameterization = parameterization.lower() + if dimension_numbers is None: + dimension_numbers = _get_dimension_numbers(len(filter_shape), False) -@layer -def ConvFeatures(out_dim: int, - filter_size: int, - W_std: float = 1.0, - b_std: float = 0., - channel_axis: int = -1): + lhs_spec, rhs_spec, out_spec = dimension_numbers - if b_std != 0.0: - raise NotImplementedError('Non-zero b_std is not implemented yet .' - ' Please set b_std to be `0`.') + channel_axis = lhs_spec.index('C') + + if parameterization != 'ntk': + raise NotImplementedError(f'Parameterization ({parameterization}) is ' + ' not implemented yet.') + + if filter_shape[0] != filter_shape[1]: + raise NotImplementedError('filter_shape should be square.') + filter_size = filter_shape[0] def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] @@ -469,14 +478,18 @@ def init_fn(rng, input_shape): filter_size**2,) new_ntk_feat_shape = nngp_feat_shape[:-1] + ( (nngp_feat_shape[-1] + ntk_feat_shape[-1]) * filter_size**2,) - return (new_nngp_feat_shape, new_ntk_feat_shape), () + + if len(input_shape) > 2: + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'D'), () + else: + return (new_nngp_feat_shape, new_ntk_feat_shape, 'D'), () def feature_fn(f, input, **kwargs): nngp_feat, ntk_feat = f.nngp_feat, f.ntk_feat nngp_feat = _conv2d_feat(nngp_feat, filter_size) / filter_size * W_std - if _is_defaut_feature(ntk_feat): # check if ntk_feat is empty + if ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat else: ntk_feat = _conv2d_feat(ntk_feat, filter_size) / filter_size * W_std @@ -487,9 +500,30 @@ def feature_fn(f, input, **kwargs): return init_fn, feature_fn +def _conv2d_feat(X, filter_size): + return _conv_feat(np.moveaxis(_conv_feat(X, filter_size), 1, 2), filter_size) + + +def _conv_feat(X, filter_size): + """ + Direct sums of image features. If input shape is (N, H, W, C), the output has + the shape (N, H, W, C * filter_size**2). 
+ """ + N, H, W, C = X.shape + out = np.zeros((N, H, W, C * filter_size)) + out = out.at[:, :, :, :C].set(X) + j = 1 + for i in range(1, min((filter_size + 1) // 2, W)): + out = out.at[:, :, :-i, j * C:(j + 1) * C].set(X[:, :, i:]) + j += 1 + out = out.at[:, :, i:, j * C:(j + 1) * C].set(X[:, :, :-i]) + j += 1 + return out + + @layer -def AvgPoolFeatures(window_size: int, - stride_size: int = 2, +def AvgPoolFeatures(window_shape: Sequence[int], + strides: Optional[Sequence[int]] = None, padding: str = stax.Padding.VALID.name, normalize_edges: bool = False, batch_axis: int = 0, @@ -499,21 +533,25 @@ def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] new_nngp_feat_shape = nngp_feat_shape[:1] + ( - nngp_feat_shape[1] // window_size, - nngp_feat_shape[2] // window_size) + nngp_feat_shape[-1:] + nngp_feat_shape[1] // window_shape[0], + nngp_feat_shape[2] // window_shape[1]) + nngp_feat_shape[-1:] new_ntk_feat_shape = ntk_feat_shape[:1] + ( - ntk_feat_shape[1] // window_size, - ntk_feat_shape[2] // window_size) + ntk_feat_shape[-1:] + ntk_feat_shape[1] // window_shape[0], + ntk_feat_shape[2] // window_shape[1]) + ntk_feat_shape[-1:] return (new_nngp_feat_shape, new_ntk_feat_shape), () def feature_fn(f, input=None, **kwargs): - window_shape_kernel = (1,) + (window_size,) * 2 + (1,) - strides_kernel = (1,) + (window_size,) * 2 + (1,) + window_shape_kernel = (1,) + tuple(window_shape) + (1,) + strides_kernel = (1,) + tuple(strides) + (1,) pooling = lambda x: _pool_kernel(x, Pooling.AVG, window_shape_kernel, strides_kernel, Padding(padding), normalize_edges, 0) nngp_feat = pooling(f.nngp_feat) - ntk_feat = pooling(f.ntk_feat) + + if f.ntk_feat.ndim == 0: # check if ntk_feat is empty + ntk_feat = nngp_feat + else: + ntk_feat = pooling(f.ntk_feat) return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) @@ -523,6 +561,15 @@ def feature_fn(f, input=None, **kwargs): @layer def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): + if batch_axis_out in (0, -2): + batch_axis_out = 0 + channel_axis_out = 1 + elif batch_axis_out in (1, -1): + batch_axis_out = 1 + channel_axis_out = 0 + else: + raise ValueError(f'`batch_axis_out` must be 0 or 1, got {batch_axis_out}.') + def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] new_nngp_feat_shape = nngp_feat_shape[:1] + (_prod(nngp_feat_shape[1:]),) @@ -530,14 +577,15 @@ def init_fn(rng, input_shape): return (new_nngp_feat_shape, new_ntk_feat_shape), () def feature_fn(f, input=None, **kwargs): - batch_size = f.nngp_feat.shape[0] - nngp_feat = f.nngp_feat.reshape(batch_size, -1) / np.sqrt( - _prod(f.nngp_feat.shape[1:-1])) - if _is_defaut_feature(f.ntk_feat): # check if ntk_feat is empty + batch_size = f.nngp_feat.shape[batch_axis] + nngp_feat = f.nngp_feat.reshape(batch_size, -1) / _prod( + f.nngp_feat.shape[1:-1])**0.5 + + if f.ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = f.ntk_feat else: - ntk_feat = f.ntk_feat.reshape(batch_size, -1) / np.sqrt( - _prod(f.ntk_feat.shape[1:-1])) + ntk_feat = f.ntk_feat.reshape(batch_size, -1) / _prod( + f.ntk_feat.shape[1:-1])**0.5 return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py index bc6e085b..a27a52fd 100644 --- a/experimental/poly_fitting.py +++ b/experimental/poly_fitting.py @@ -3,14 +3,6 @@ from jaxopt import OSQP -def _arccos(x): - return np.arccos(np.clip(x, -1, 1)) - - -def _sqrt(x): - return np.sqrt(np.maximum(x, 1e-20)) - - 
def kappa0(x, is_x_matrix=True): if is_x_matrix: xxt = x @ x.T @@ -32,6 +24,14 @@ def kappa1(x, is_x_matrix=True): return (_sqrt(1 - x**2) + (np.pi - _arccos(x)) * x) / np.pi +def _arccos(x): + return np.arccos(np.clip(x, -1, 1)) + + +def _sqrt(x): + return np.maximum(x, 1e-20)**0.5 + + def poly_fitting_qp(xvals: np.ndarray, fvals: np.ndarray, weights: np.ndarray, diff --git a/experimental/sketching.py b/experimental/sketching.py index aaf0b50c..0b8ab858 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -12,7 +12,7 @@ class TensorSRHT: input_dim2: int sketch_dim: int - rng: np.ndarray + rng: random.KeyArray shape: Optional[np.ndarray] = None rand_signs1: Optional[np.ndarray] = None @@ -54,7 +54,7 @@ class PolyTensorSketch: sketch_dim: int degree: int - rng: np.ndarray + rng: random.KeyArray tree_rand_signs: Optional[list] = None tree_rand_inds: Optional[list] = None diff --git a/experimental/tests/__init__.py b/experimental/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/test_fc_ntk.py b/experimental/tests/fc_ntk_test.py similarity index 99% rename from experimental/test_fc_ntk.py rename to experimental/tests/fc_ntk_test.py index 9c42a501..b3c6418d 100644 --- a/experimental/test_fc_ntk.py +++ b/experimental/tests/fc_ntk_test.py @@ -3,7 +3,6 @@ from jax.config import config from jax import jit import sys - sys.path.append("./") config.update("jax_enable_x64", True) @@ -87,7 +86,7 @@ def eval_features(f_): feat_shape, feat_fn_inputs = init_fn(key2, x.shape) # Transform input vectors to NNGP/NTK feature map -feats = jit(feature_fn)(x, feat_fn_inputs) +feats = feature_fn(x, feat_fn_inputs) eval_features(feats) print("==================== Result of NTK wih PolySketch ====================") diff --git a/experimental/test_myrtle_network.py b/experimental/tests/myrtle_network_test.py similarity index 100% rename from experimental/test_myrtle_network.py rename to experimental/tests/myrtle_network_test.py diff --git a/experimental/tests/sketching_test.py b/experimental/tests/sketching_test.py new file mode 100644 index 00000000..c7bf7a45 --- /dev/null +++ b/experimental/tests/sketching_test.py @@ -0,0 +1,41 @@ +import sys + +sys.path.append("./") +import scipy +from jax import random, jit +from jax import numpy as jnp +from experimental.sketching import PolyTensorSketch + +# Coefficients of Taylor series of exp(x) +degree = 8 +coeffs = jnp.asarray([1 / scipy.special.factorial(i) for i in range(degree)]) + +n = 4 +d = 32 +sketch_dim = 256 + +rng = random.PRNGKey(1) +x = random.normal(rng, shape=(n, d)) +norm_x = jnp.linalg.norm(x, axis=-1) +x_normalized = x / norm_x[:, None] + +rng2 = random.PRNGKey(2) +pts = PolyTensorSketch(rng=rng2, + input_dim=d, + sketch_dim=sketch_dim, + degree=degree).init_sketches() # pytype:disable=wrong-keyword-args +x_sketches = pts.sketch(x_normalized) + +z = pts.expand_feats(x_sketches, coeffs) # z.shape[1] is not the desired. +z = pts.standardsrht(z) # z is complex ndarray. 
+z = jnp.concatenate((z.real, z.imag), axis=-1) + +K = jnp.polyval(coeffs[::-1], x_normalized @ x_normalized.T) +K_approx = z @ z.T + +print("Exact kernel matrix:") +print(K) +print() + +print(f"Approximate kernel matrix (sketch_dim: {z.shape[1]}):") +print(K_approx) From 61bc32de36cc73bbca46207d37b823e79e17316f Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 30 Mar 2022 06:14:18 +0900 Subject: [PATCH 25/44] Add JAXopt package --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3b03020b..0cd29007 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,8 @@ INSTALL_REQUIRES = [ 'jax>=0.3', 'frozendict>=2.3', - 'typing_extensions>=4.0.1' + 'typing_extensions>=4.0.1', + 'jaxopt>=0.3.1', ] From 1cfbe5a935d2cd4b912d55785752b00f7dc73623 Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 30 Mar 2022 16:24:23 +0900 Subject: [PATCH 26/44] Reflect comments in PR conversation (v2) --- experimental/README.md | 24 ++- experimental/features.py | 243 +++++++++++++--------- experimental/tests/fc_ntk_test.py | 170 +++++---------- experimental/tests/myrtle_network_test.py | 136 ++++++++---- 4 files changed, 301 insertions(+), 272 deletions(-) diff --git a/experimental/README.md b/experimental/README.md index d75172dd..9a71c006 100644 --- a/experimental/README.md +++ b/experimental/README.md @@ -14,10 +14,10 @@ from jax import random from experimental.features import DenseFeatures, ReluFeatures, serial relufeat_arg = { - 'feature_dim0': 128, + 'method': 'RANDFEAT', + 'feature_dim0': 64, 'feature_dim1': 128, 'sketch_dim': 256, - 'method': 'rf', } init_fn, feature_fn = serial( @@ -33,6 +33,8 @@ _, feat_fn_inputs = init_fn(key2, x.shape) feats = feature_fn(x, feat_fn_inputs) # feats.nngp_feat is a feature map of NNGP kernel # feats.ntk_feat is a feature map of NTK +assert feats.nngp_feat.shape == (5, relufeat_arg['feature_dim1']) +assert feats.ntk_feat.shape == (5, relufeat_arg['feature_dim1'] + relufeat_arg['sketch_dim']) ``` For more details of fully connected NTK features, please check `test_fc_ntk.py`. @@ -41,8 +43,9 @@ For more details of fully connected NTK features, please check `test_fc_ntk.py`. from experimental.features import ConvFeatures, AvgPoolFeatures, FlattenFeatures init_fn, feature_fn = serial( - ConvFeatures(512, filter_size=3), ReluFeatures(**relufeat_arg), - AvgPoolFeatures(2, 2), FlattenFeatures() + ConvFeatures(512, filter_shape=(3, 3)), ReluFeatures(**relufeat_arg), + AvgPoolFeatures((2, 2), strides=(2, 2)), FlattenFeatures(), + DenseFeatures(512) ) n, H, W, C = 5, 8, 8, 3 @@ -53,6 +56,8 @@ _, feat_fn_inputs = init_fn(key2, x.shape) feats = feature_fn(x, feat_fn_inputs) # feats.nngp_feat is a feature map of NNGP kernel # feats.ntk_feat is a feature map of NTK +assert feats.nngp_feat.shape == (5, (H/2)*(W/2)*relufeat_arg['feature_dim1']) +assert feats.ntk_feat.shape == (5, (H/2)*(W/2)*(relufeat_arg['feature_dim1'] + relufeat_arg['sketch_dim'])) ``` For more complex CNTK features, please check `test_myrtle_networks.py`. @@ -67,19 +72,20 @@ All modules return a pair of functions `(init_fn, feature_fn)`. Instead of kerne ## [`features.DenseFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L88) `features.DenseFeatures` provides features for fully-connected dense layer and corresponds to `stax.Dense` module in [Neural Tangents](https://github.com/google/neural-tangents). 
We assume that the input is a tabular dataset (i.e., a n-by-d matrix). Its `feature_fn` updates the NTK features by concatenating NNGP features and NTK features. This is because `stax.Dense` updates a new NTK kernel matrix `(N x D)` by adding the previous NNGP and NTK kernel matrices. The features of dense layer are exact and no approximations are applied.
```python
-import numpy as np
+from jax import numpy as np
 from neural_tangents import stax
+from experimental.features import DenseFeatures, serial

 width = 1
 x = random.normal(key1, shape=(3, 2))
 _, _, kernel_fn = stax.Dense(width)
 nt_kernel = kernel_fn(x)

-_, feat_fn = DenseFeatures(width)
+_, feat_fn = serial(DenseFeatures(width))
 feat = feat_fn(x, ())

-assert np.linalg.norm(nt_kernel.ntk - feat.ntk_feat @ feat.ntk_feat.T) <= 1e-12
 assert np.linalg.norm(nt_kernel.nngp - feat.nngp_feat @ feat.nngp_feat.T) <= 1e-12
+assert feat.ntk_feat == np.zeros(())
```

 ## [`features.ReluFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L119)
@@ -89,11 +95,13 @@ For image dataset, the inputs are 4-D tensors with shape `N x H x W x D` where N

 To use the Random Features approach, set the parameter `method` to `RANDFEAT` (default `RANDFEAT`), e.g.,

 ```python
+from experimental.features import DenseFeatures, ReluFeatures, serial
+
 x = random.normal(key1, shape=(3, 32))

 init_fn, feat_fn = serial(
     DenseFeatures(1),
-    ReluFeatures(method='rf', feature_dim0=10, feature_dim1=20, sketch_dim=30)
+    ReluFeatures(method='RANDFEAT', feature_dim0=10, feature_dim1=20, sketch_dim=30)
 )

 _, params = init_fn(key1, x.shape)
diff --git a/experimental/features.py b/experimental/features.py
index a4e8d480..fe4f13a3 100644
--- a/experimental/features.py
+++ b/experimental/features.py
@@ -1,8 +1,10 @@
+import enum
 from typing import Optional, Callable, Sequence, Tuple
 from jax import random
 from jax import numpy as np
 from jax.numpy.linalg import cholesky
 import jax.example_libraries.stax as ostax
+from jax import eval_shape, ShapedArray

 from neural_tangents import stax
 from neural_tangents._src.utils import dataclasses
@@ -26,6 +28,15 @@ class Features:

   replace = ...
# type: Callable[..., 'Features'] +class ReluFeaturesMethod(enum.Enum): + """Method for ReLU NNGP/NTK features approximation.""" + RANDFEAT = 'RANDFEAT' + POLYSKETCH = 'POLYSKETCH' + PSRF = 'PSRF' + POLY = 'POLY' + EXACT = 'EXACT' + + def layer(layer_fn): def new_layer_fns(*args, **kwargs): @@ -41,7 +52,7 @@ def _preprocess_init_fn(init_fn): def init_fn_any(rng, input_shape_any, **kwargs): if _is_sinlge_shape(input_shape_any): - input_shape = (input_shape_any, (-1, 0)) + input_shape = (input_shape_any, (-1, 0)) # Add a dummy shape for ntk_feat return init_fn(rng, input_shape, **kwargs) else: return init_fn(rng, input_shape_any, **kwargs) @@ -50,13 +61,11 @@ def init_fn_any(rng, input_shape_any, **kwargs): def _is_sinlge_shape(input_shape): - if len(input_shape) == 2: - if all(isinstance(n, int) for n in input_shape): - return True - elif all(_is_sinlge_shape(s) for s in input_shape): - return False - elif len(input_shape) == 3: - return _is_sinlge_shape(input_shape[:2]) + if all(isinstance(n, int) for n in input_shape): + return True + elif (len(input_shape) == 2 or len(input_shape) == 3) and all( + _is_sinlge_shape(s) for s in input_shape[:2]): + return False raise ValueError(input_shape) @@ -110,16 +119,17 @@ def serial(*layers): def feature_fn(k, inputs, **kwargs): for f, input_ in zip(feature_fns, inputs): k = f(k, input_, **kwargs) - k = _renormalize_feature(k) + k = _unnormalize_features(k) return k return init_fn, feature_fn -def _renormalize_feature(f: Features, **kwargs): +def _unnormalize_features(f: Features) -> Features: nngp_feat = f.nngp_feat * f.norms - ntk_feat = f.ntk_feat * f.norms - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + ntk_feat = f.ntk_feat * f.norms if f.ntk_feat.ndim != 0 else f.ntk_feat + norms = np.zeros((), dtype=nngp_feat.dtype) + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) @layer @@ -132,7 +142,7 @@ def DenseFeatures(out_dim: int, if b_std is not None: raise NotImplementedError('Bias variable b_std is not implemented yet .' 
- ' Please set b_std to be `0`.') + ' Please set b_std to be None.') if parameterization != 'ntk': raise NotImplementedError(f'Parameterization ({parameterization}) is ' @@ -149,9 +159,9 @@ def init_fn(rng, input_shape): return (nngp_feat_shape, new_ntk_feat_shape, 'D'), () def feature_fn(f: Features, input, **kwargs): - nngp_feat : np.ndarray = f.nngp_feat - ntk_feat : np.ndarray = f.ntk_feat - norms : np.ndarray = f.norms + nngp_feat: np.ndarray = f.nngp_feat + ntk_feat: np.ndarray = f.ntk_feat + norms: np.ndarray = f.norms norms *= W_std @@ -166,25 +176,24 @@ def feature_fn(f: Features, input, **kwargs): @layer -def ReluFeatures(feature_dim0: int = 1, +def ReluFeatures(method: str = 'RANDFEAT', + feature_dim0: int = 1, feature_dim1: int = 1, sketch_dim: int = 1, poly_degree: int = 8, - poly_sketch_dim: int = 1, - method: str = 'rf', - top_layer: bool = False): + poly_sketch_dim: int = 1): - method = method.lower() - assert method in ['rf', 'ps', 'exact', 'psrf', 'poly'] + method = ReluFeaturesMethod(method.upper()) def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) net_shape = input_shape[2] - layer_count = len(net_shape) // 2 + 1 + relu_layers_count = net_shape.count('R') + new_net_shape = net_shape + 'R' - if method == 'rf': + if method == ReluFeaturesMethod.RANDFEAT: rng1, rng2, rng3 = random.split(rng, 3) # Random vectors for random features of arc-cosine kernel of order 0. W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) @@ -197,82 +206,80 @@ def init_fn(rng, input_shape): sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args return (new_nngp_feat_shape, new_ntk_feat_shape, - net_shape + 'R'), (W0, W1, tensorsrht) + new_net_shape), (W0, W1, tensorsrht) - elif method == 'ps': + elif method == ReluFeaturesMethod.POLYSKETCH: new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) rng1, rng2, rng3 = random.split(rng, 3) - kappa1_coeff = kappa1_coeffs(poly_degree, layer_count - 1) - kappa0_coeff = kappa0_coeffs(poly_degree, layer_count - 1) + kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) + kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) # PolySketch expansion for nngp features. polysketch = PolyTensorSketch(rng=rng1, input_dim=nngp_feat_shape[-1] // - (1 + (layer_count > 1)), + (1 + (relu_layers_count > 0)), sketch_dim=poly_sketch_dim, degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args # TensorSRHT of degree 2 for approximating tensor product. tensorsrht = TensorSRHT( - input_dim1=ntk_feat_shape[-1] // (1 + (layer_count > 1)), + input_dim1=ntk_feat_shape[-1] // (1 + (relu_layers_count > 0)), input_dim2=poly_degree * (polysketch.sketch_dim // 4 - 1) + 1, sketch_dim=sketch_dim, rng=rng2).init_sketches() # pytype:disable=wrong-keyword-args - return (new_nngp_feat_shape, new_ntk_feat_shape, - net_shape + 'R'), (polysketch, tensorsrht, - (kappa0_coeff, kappa1_coeff, layer_count)) + new_net_shape), (polysketch, tensorsrht, kappa0_coeff, + kappa1_coeff) - elif method == 'psrf': + elif method == ReluFeaturesMethod.PSRF: new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) rng1, rng2, rng3 = random.split(rng, 3) - kappa1_coeff = kappa1_coeffs(poly_degree, layer_count - 1) + kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) # PolySketch expansion for nngp features. 
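      # (Descriptive note: `PolyTensorSketch` computes a randomized low-dimensional sketch of the degree-`poly_degree` polynomial expansion of the NNGP features; the `kappa0_coeff`/`kappa1_coeff` fitted above weight the sketched monomials.)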
polysketch = PolyTensorSketch(rng=rng1, input_dim=nngp_feat_shape[-1] // - (1 + (layer_count > 1)), + (1 + (relu_layers_count > 0)), sketch_dim=poly_sketch_dim, degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args # TensorSRHT of degree 2 for approximating tensor product. tensorsrht = TensorSRHT(rng=rng2, input_dim1=ntk_feat_shape[-1] // - (1 + (layer_count > 1)), + (1 + (relu_layers_count > 0)), input_dim2=feature_dim0, sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args # Random vectors for random features of arc-cosine kernel of order 0. - if layer_count == 1: + if relu_layers_count == 0: W0 = random.normal(rng3, (2 * nngp_feat_shape[-1], feature_dim0 // 2)) else: W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0 // 2)) return (new_nngp_feat_shape, new_ntk_feat_shape, - net_shape + 'R'), (W0, polysketch, tensorsrht, (kappa1_coeff, - layer_count)) + new_net_shape), (W0, polysketch, tensorsrht, kappa1_coeff) - elif method == 'poly': + elif method == ReluFeaturesMethod.POLY: # This only uses the polynomial approximation without sketching. new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( nngp_feat_shape[:-1]),) new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) - kappa1_coeff = kappa1_coeffs(poly_degree, layer_count - 1) - kappa0_coeff = kappa0_coeffs(poly_degree, layer_count - 1) + kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) + kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) return (new_nngp_feat_shape, new_ntk_feat_shape, - net_shape + 'R'), (kappa0_coeff, kappa1_coeff, layer_count) + new_net_shape), (kappa0_coeff, kappa1_coeff) - elif method == 'exact': + elif method == ReluFeaturesMethod.EXACT: # The exact feature map computation is for debug. new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( nngp_feat_shape[:-1]),) new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) - return (new_nngp_feat_shape, new_ntk_feat_shape, net_shape + 'R'), () + return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), () else: raise NotImplementedError(f'Invalid method name: {method}') @@ -287,7 +294,7 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: ntk_feat_2d: np.ndarray = f.ntk_feat.reshape(-1, ntk_feat_dim) norms: np.ndarray = f.norms - if method == 'rf': # Random Features approach. + if method == ReluFeaturesMethod.RANDFEAT: # Random Features approach. W0: np.ndarray = input[0] W1: np.ndarray = input[1] tensorsrht: TensorSRHT = input[2] @@ -301,12 +308,11 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: real_output=True).reshape(input_shape + (-1,)) - elif method == 'ps': + elif method == ReluFeaturesMethod.POLYSKETCH: polysketch: PolyTensorSketch = input[0] tensorsrht: TensorSRHT = input[1] - kappa0_coeff: np.ndarray = input[2][0] - kappa1_coeff: np.ndarray = input[2][1] - layer_count: int = input[2][2] + kappa0_coeff: np.ndarray = input[2] + kappa1_coeff: np.ndarray = input[3] # Apply PolySketch to approximate feature maps of kappa0 & kappa1 kernels. polysketch_feats = polysketch.sketch(nngp_feat_2d) @@ -321,16 +327,11 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) - # At the top ReluFeatures, convert complex features to real ones. 
- if top_layer: - ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1) - nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) - - elif method == 'psrf': # Combination of PolySketch and Random Features. + elif method == ReluFeaturesMethod.PSRF: # Combination of PolySketch and Random Features. W0: np.ndarray = input[0] polysketch: PolyTensorSketch = input[1] tensorsrht: TensorSRHT = input[2] - kappa1_coeff: np.ndarray = input[3][0] + kappa1_coeff: np.ndarray = input[3] polysketch_feats = polysketch.sketch(nngp_feat_2d) kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) @@ -349,15 +350,9 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) - # At the top ReluFeatures, convert complex features to real ones. - if top_layer: - ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1) - nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) - - elif method == 'poly': # Polynomial approximation without sketching. + elif method == ReluFeaturesMethod.POLY: # Polynomial approximation without sketching. kappa0_coeff: np.ndarray = input[0] kappa1_coeff: np.ndarray = input[1] - layer_count = input[2] gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) nngp_feat = cholesky(np.polyval(kappa1_coeff[::-1], @@ -367,7 +362,7 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: kappa0_mat = np.polyval(kappa0_coeff[::-1], gram_nngp) ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) - elif method == 'exact': # Exact feature map computations via Cholesky decomposition. + elif method == ReluFeaturesMethod.EXACT: # Exact feature map computations via Cholesky decomposition. nngp_feat = cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) ntk = ntk_feat_2d @ ntk_feat_2d.T @@ -377,7 +372,7 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: else: raise NotImplementedError(f'Invalid method name: {method}') - if method != 'rf': + if method != ReluFeaturesMethod.RANDFEAT: norms /= 2.0**0.5 return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) @@ -430,12 +425,12 @@ def feature_fn(f, input=None, **kwargs): ntk_feat = polysketch.standardsrht(ntk_feat).reshape(input_shape + (-1,)) # Convert complex features to real ones. - ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=1) - nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=1) + ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1) + nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=-1) norms = f.norms / 2.**(num_layers / 2) * (W_std**(num_layers + 1)) - return _renormalize_feature( + return _unnormalize_features( f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms)) return init_fn, feature_fn @@ -444,8 +439,8 @@ def feature_fn(f, input=None, **kwargs): @layer def ConvFeatures(out_chan: int, filter_shape: Sequence[int], - strides: Optional[Sequence[int]], - padding: str, + strides: Optional[Sequence[int]] = None, + padding: str = 'SAME', W_std: float = 1.0, b_std: Optional[float] = None, dimension_numbers: Optional[Tuple[str, str, str]] = None, @@ -453,7 +448,7 @@ def ConvFeatures(out_chan: int, if b_std is not None: raise NotImplementedError('Bias variable b_std is not implemented yet .' 
- ' Please set b_std to be `0`.') + ' Please set b_std to be None.') parameterization = parameterization.lower() @@ -480,22 +475,34 @@ def init_fn(rng, input_shape): (nngp_feat_shape[-1] + ntk_feat_shape[-1]) * filter_size**2,) if len(input_shape) > 2: - return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'D'), () + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'C'), () else: - return (new_nngp_feat_shape, new_ntk_feat_shape, 'D'), () + return (new_nngp_feat_shape, new_ntk_feat_shape, 'C'), () def feature_fn(f, input, **kwargs): - nngp_feat, ntk_feat = f.nngp_feat, f.ntk_feat - + """ + `ConvFeatures` concatenates shifted copies of the features. Since this does + not commute with the per-pixel normalization, we first unnormalize the + features (i.e., multiply them by `norms`) and then re-normalize the output + features. + """ + f_renormalized: Features = _unnormalize_features(f) + nngp_feat: np.ndarray = f_renormalized.nngp_feat + ntk_feat: np.ndarray = f_renormalized.ntk_feat nngp_feat = _conv2d_feat(nngp_feat, filter_size) / filter_size * W_std - if ntk_feat.ndim == 0: # check if ntk_feat is empty + if f.ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat else: ntk_feat = _conv2d_feat(ntk_feat, filter_size) / filter_size * W_std ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + # Re-normalize the features. + norms = np.linalg.norm(nngp_feat, axis=channel_axis) + norms = np.expand_dims(np.where(norms > 0, norms, 1.0), channel_axis) + nngp_feat = nngp_feat / norms + ntk_feat = ntk_feat / norms + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) return init_fn, feature_fn @@ -529,31 +536,50 @@ def AvgPoolFeatures(window_shape: Sequence[int], batch_axis: int = 0, channel_axis: int = -1): + if window_shape[0] != strides[0] or window_shape[1] != strides[1]: + raise NotImplementedError('window_shape should be equal to strides.') + + window_shape_kernel = (1,) + tuple(window_shape) + (1,) + strides_kernel = (1,) + tuple(strides) + (1,) + pooling = lambda x: _pool_kernel(x, Pooling.AVG, + window_shape_kernel, strides_kernel, + Padding(padding), normalize_edges, 0) + def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_nngp_feat_shape = nngp_feat_shape[:1] + ( - nngp_feat_shape[1] // window_shape[0], - nngp_feat_shape[2] // window_shape[1]) + nngp_feat_shape[-1:] - new_ntk_feat_shape = ntk_feat_shape[:1] + ( - ntk_feat_shape[1] // window_shape[0], - ntk_feat_shape[2] // window_shape[1]) + ntk_feat_shape[-1:] - return (new_nngp_feat_shape, new_ntk_feat_shape), () + new_nngp_feat_shape = eval_shape(pooling, + ShapedArray(nngp_feat_shape, + np.float32)).shape + new_ntk_feat_shape = eval_shape(pooling, + ShapedArray(ntk_feat_shape, + np.float32)).shape + + if len(input_shape) > 2: + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'A'), () + else: + return (new_nngp_feat_shape, new_ntk_feat_shape, 'A'), () def feature_fn(f, input=None, **kwargs): - window_shape_kernel = (1,) + tuple(window_shape) + (1,) - strides_kernel = (1,) + tuple(strides) + (1,) - pooling = lambda x: _pool_kernel(x, Pooling.AVG, - window_shape_kernel, strides_kernel, - Padding(padding), normalize_edges, 0) - nngp_feat = pooling(f.nngp_feat) + # Unnormalize the input features. 
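+ # (`norms` keeps one scale per pixel while average pooling mixes pixels, so the scales cannot be pulled through the pooling operator; they are recomputed after pooling.)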
+ f_renormalized: Features = _unnormalize_features(f) + nngp_feat: np.ndarray = f_renormalized.nngp_feat + ntk_feat: np.ndarray = f_renormalized.ntk_feat + + nngp_feat = pooling(nngp_feat) if f.ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat else: - ntk_feat = pooling(f.ntk_feat) + ntk_feat = pooling(ntk_feat) - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + # Re-normalize the features. + norms = np.linalg.norm(nngp_feat, axis=channel_axis) + norms = np.expand_dims(np.where(norms > 0, norms, 1.0), channel_axis) + nngp_feat = nngp_feat / norms + ntk_feat = ntk_feat / norms + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) return init_fn, feature_fn @@ -574,19 +600,30 @@ def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] new_nngp_feat_shape = nngp_feat_shape[:1] + (_prod(nngp_feat_shape[1:]),) new_ntk_feat_shape = ntk_feat_shape[:1] + (_prod(ntk_feat_shape[1:]),) - return (new_nngp_feat_shape, new_ntk_feat_shape), () + if len(input_shape) > 2: + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'F'), () + else: + return (new_nngp_feat_shape, new_ntk_feat_shape, 'F'), () def feature_fn(f, input=None, **kwargs): + f_renormalized: Features = _unnormalize_features(f) + nngp_feat: np.ndarray = f_renormalized.nngp_feat + ntk_feat: np.ndarray = f_renormalized.ntk_feat + batch_size = f.nngp_feat.shape[batch_axis] - nngp_feat = f.nngp_feat.reshape(batch_size, -1) / _prod( - f.nngp_feat.shape[1:-1])**0.5 + nngp_feat = nngp_feat.reshape(batch_size, -1) / _prod( + nngp_feat.shape[1:-1])**0.5 - if f.ntk_feat.ndim == 0: # check if ntk_feat is empty - ntk_feat = f.ntk_feat - else: - ntk_feat = f.ntk_feat.reshape(batch_size, -1) / _prod( - f.ntk_feat.shape[1:-1])**0.5 + if f.ntk_feat.ndim != 0: # check if ntk_feat is not empty + ntk_feat = ntk_feat.reshape(batch_size, -1) / _prod( + ntk_feat.shape[1:-1])**0.5 + + # Re-normalize the features. 
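+ # (After this step each row of `nngp_feat` has unit norm and `norms` holds the removed scales, so `_unnormalize_features` can restore both feature arrays exactly.)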
+ norms = np.linalg.norm(nngp_feat, axis=-1) + norms = np.expand_dims(np.where(norms > 0, norms, 1.0), -1) + nngp_feat = nngp_feat / norms + ntk_feat = ntk_feat / norms - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) return init_fn, feature_fn diff --git a/experimental/tests/fc_ntk_test.py b/experimental/tests/fc_ntk_test.py index b3c6418d..aac17cbe 100644 --- a/experimental/tests/fc_ntk_test.py +++ b/experimental/tests/fc_ntk_test.py @@ -1,4 +1,4 @@ -from numpy.linalg import norm +from jax import numpy as np from jax import random from jax.config import config from jax import jit @@ -10,8 +10,11 @@ from experimental.features import DenseFeatures, ReluFeatures, serial, ReluNTKFeatures + + seed = 1 n, d = 6, 5 +no_jitting = False key1, key2 = random.split(random.PRNGKey(seed)) x = random.normal(key1, (n, d)) @@ -26,6 +29,7 @@ stax.Dense(width, W_std=W_std), stax.Relu(), stax.Dense(1, W_std=W_std)) +kernel_fn = kernel_fn if no_jitting else jit(kernel_fn) nt_kernel = kernel_fn(x, None) print("K_nngp :") @@ -37,57 +41,66 @@ print() -def eval_features(f_): - print(f"f_nngp shape: {f_.nngp_feat.shape}") - print(f"f_ntk shape: {f_.ntk_feat.shape}") +def test_fc_relu_ntk_approx(relufeat_arg, init_fn=None, feature_fn=None): + + print(f"ReluFeatures params:") + for name_, value_ in relufeat_arg.items(): + print(f"{name_:<12} : {value_}") + print() + + if init_fn is None or feature_fn is None: + init_fn, feature_fn = serial( + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), + DenseFeatures(1, W_std=W_std)) + + # Initialize random vectors and sketching algorithms + _, feat_fn_inputs = init_fn(key2, x.shape) + + # Transform input vectors to NNGP/NTK feature map + feature_fn = feature_fn if no_jitting else jit(feature_fn) + feats = feature_fn(x, feat_fn_inputs) + + # PolySketch returns complex features. Convert complex features to real ones. 
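+ # (Splitting into real and imaginary parts preserves the kernel estimate: for complex features z, w we have Re(z)·Re(w) + Im(z)·Im(w) = Re<z, conj(w)>, so the real-valued features realize the same Gram matrix.)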
+ if np.iscomplexobj(feats.nngp_feat) or np.iscomplexobj(feats.ntk_feat): + nngp_feat = np.concatenate((feats.nngp_feat.real, feats.nngp_feat.imag), axis=-1) + ntk_feat = np.concatenate((feats.ntk_feat.real, feats.ntk_feat.imag), axis=-1) + feats = feats.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + print(f"f_nngp shape: {feats.nngp_feat.shape}") + print(f"f_ntk shape: {feats.ntk_feat.shape}") print("K_nngp:") - print(f_.nngp_feat @ f_.nngp_feat.T) + print(feats.nngp_feat @ feats.nngp_feat.T) print() print("K_ntk:") - print(f_.ntk_feat @ f_.ntk_feat.T) + print(feats.ntk_feat @ feats.ntk_feat.T) print() print( - f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - f_.nngp_feat @ f_.nngp_feat.T)}" + f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {np.linalg.norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" ) print( - f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - f_.ntk_feat @ f_.ntk_feat.T)}" + f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {np.linalg.norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" ) print() + print("==================== Result of NTK Random Features ====================") kappa0_feat_dim = 4096 kappa1_feat_dim = 4096 sketch_dim = 4096 -relufeat_arg = { - 'method': 'rf', +test_fc_relu_ntk_approx({ + 'method': 'RANDFEAT', 'feature_dim0': kappa0_feat_dim, 'feature_dim1': kappa1_feat_dim, 'sketch_dim': sketch_dim, -} - -print(f"ReluFeatures params:") -for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") -print() - -init_fn, feature_fn = serial( - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(1, W_std=W_std)) - -# Initialize random vectors and sketching algorithms -feat_shape, feat_fn_inputs = init_fn(key2, x.shape) - -# Transform input vectors to NNGP/NTK feature map -feats = feature_fn(x, feat_fn_inputs) -eval_features(feats) +}) print("==================== Result of NTK wih PolySketch ====================") @@ -95,30 +108,12 @@ def eval_features(f_): poly_sketch_dim = 4096 sketch_dim = 4096 -relufeat_arg = { - 'method': 'ps', +test_fc_relu_ntk_approx({ + 'method': 'POLYSKETCH', 'sketch_dim': sketch_dim, 'poly_degree': poly_degree, 'poly_sketch_dim': poly_sketch_dim -} - -print(f"ReluFeatures params:") -for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") -print() - -init_fn, feature_fn = serial( - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg, top_layer=True), - DenseFeatures(1, W_std=W_std)) - -# Initialize random vectors and sketching algorithms -feat_shape, feat_fn_inputs = init_fn(key2, x.shape) - -# Transform input vectors to NNGP/NTK feature map -feats = jit(feature_fn)(x, feat_fn_inputs) -eval_features(feats) +}) print("=============== Result of PolySketch + Random Features ===============") @@ -127,31 +122,13 @@ def eval_features(f_): poly_degree = 4 poly_sketch_dim = 4096 -relufeat_arg = { - 'method': 'psrf', +test_fc_relu_ntk_approx({ + 'method': 'PSRF', 'feature_dim0': kappa0_feat_dim, 'sketch_dim': sketch_dim, 'poly_degree': poly_degree, 'poly_sketch_dim': poly_sketch_dim -} - -print(f"ReluFeatures params:") -for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") -print() - -init_fn, feature_fn = serial( - DenseFeatures(width, W_std=W_std), 
ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg, top_layer=True), - DenseFeatures(1, W_std=W_std)) - -# Initialize random vectors and sketching algorithms -feat_shape, feat_fn_inputs = init_fn(key2, x.shape) - -# Transform input vectors to NNGP/NTK feature map -feats = jit(feature_fn)(x, feat_fn_inputs) -eval_features(feats) +}) print("=========== Result of ReLU-NTK Sketch (one-pass sketching) ===========") @@ -162,59 +139,14 @@ def eval_features(f_): 'W_std': W_std, } -print(f"ReluFeatures params:") -for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") -print() - init_fn, feature_fn = ReluNTKFeatures(**relufeat_arg) -_, feat_fn_inputs = init_fn(key2, x.shape) - -# Transform input vectors to NNGP/NTK feature map -feats = jit(feature_fn)(x, feat_fn_inputs) -eval_features(feats) - +test_fc_relu_ntk_approx(relufeat_arg, init_fn, feature_fn) print("======= (Debug) NTK Feature Maps with Polynomial Approximation =======") print("\t(*No Sketching algorithm is applied.)") -relufeat_arg = {'method': 'poly', 'poly_degree': 64} - -print(f"ReluFeatures params:") -for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") -print() - -init_fn, feature_fn = serial( - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(1, W_std=W_std)) - -# Initialize random vectors and sketching algorithms -feat_shape, feat_fn_inputs = init_fn(key2, x.shape) - -feats = jit(feature_fn)(x, feat_fn_inputs) -eval_features(feats) - +test_fc_relu_ntk_approx({'method': 'POLY', 'poly_degree': 16}) print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") -relufeat_arg = {'method': 'exact'} - -print(f"ReluFeatures params:") -for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") -print() - -init_fn, feature_fn = serial( - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(1, W_std=W_std)) - -# Initialize random vectors and sketching algorithms -feat_shape, feat_fn_inputs = init_fn(key2, x.shape) - -feats = jit(feature_fn)(x, feat_fn_inputs) -eval_features(feats) \ No newline at end of file +test_fc_relu_ntk_approx({'method': 'EXACT'}) \ No newline at end of file diff --git a/experimental/tests/myrtle_network_test.py b/experimental/tests/myrtle_network_test.py index e50f799b..774d49f2 100644 --- a/experimental/tests/myrtle_network_test.py +++ b/experimental/tests/myrtle_network_test.py @@ -1,5 +1,3 @@ -import os -os.environ['CUDA_VISIBLE_DEVICES'] = '' import sys sys.path.append("./") import functools @@ -36,25 +34,25 @@ def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0.): return stax.serial(*layers) -def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), b_std=0., **relu_args): +def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), **relu_args): - conv_fn = functools.partial(ConvFeatures, W_std=W_std, b_std=b_std) + conv_fn = functools.partial(ConvFeatures, W_std=W_std) layers = [] - layers += [conv_fn(width, filter_size=3), + layers += [conv_fn(width, filter_shape=(3, 3)), ReluFeatures(**relu_args)] * layer_factor[depth][0] - layers += [AvgPoolFeatures(2, 2)] + layers += [AvgPoolFeatures((2, 2), strides=(2, 
2))] layers += [ - ConvFeatures(width, filter_size=3, W_std=W_std), + ConvFeatures(width, filter_shape=(3, 3), W_std=W_std), ReluFeatures(**relu_args) ] * layer_factor[depth][1] - layers += [AvgPoolFeatures(2, 2)] + layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] layers += [ - ConvFeatures(width, filter_size=3, W_std=W_std), + ConvFeatures(width, filter_shape=(3, 3), W_std=W_std), ReluFeatures(**relu_args) ] * layer_factor[depth][2] - layers += [AvgPoolFeatures(2, 2)] * 3 - layers += [FlattenFeatures(), DenseFeatures(1, W_std, b_std)] + layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] * 3 + layers += [FlattenFeatures(), DenseFeatures(1, W_std)] return serial(*layers) @@ -62,14 +60,17 @@ def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), b_std=0., **relu_args): key = random.PRNGKey(0) N, H, W, C = 4, 32, 32, 3 +depth = 5 +no_jitting = False + key1, key2 = random.split(key) x = random.normal(key1, shape=(N, H, W, C)) -_, _, kernel_fn = MyrtleNetwork(5) -kernel_fn = jit(kernel_fn) - print("================= Result of Neural Tangent Library =================") +_, _, kernel_fn = MyrtleNetwork(depth) +kernel_fn = kernel_fn if no_jitting else jit(kernel_fn) + nt_kernel = kernel_fn(x) print("K_nngp (exact):") print(nt_kernel.nngp) @@ -79,38 +80,89 @@ def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), b_std=0., **relu_args): print(nt_kernel.ntk) print() -print("================= CNTK Random Features =================") -kappa0_feat_dim = 1000 -kappa1_feat_dim = 1000 -sketch_dim = 1000 -relufeat_arg = { - 'method': 'rf', - 'feature_dim0': kappa0_feat_dim, - 'feature_dim1': kappa1_feat_dim, - 'sketch_dim': sketch_dim, -} +def test_myrtle_network_approx(relufeat_arg, init_fn=None, feature_fn=None): -init_fn, feature_fn = MyrtleNetworkFeatures(5, **relufeat_arg) -feature_fn = jit(feature_fn) + print(f"ReluFeatures params:") + for name_, value_ in relufeat_arg.items(): + print(f"{name_:<12} : {value_}") + print() -feat_shape, feat_fn_inputs = init_fn(key2, x.shape) + if init_fn is None or feature_fn is None: + init_fn, feature_fn = MyrtleNetworkFeatures(depth, **relufeat_arg) -feats = feature_fn(x, feat_fn_inputs) + # Initialize random vectors and sketching algorithms + _, feat_fn_inputs = init_fn(key2, x.shape) -print(f"f_nngp shape: {feat_shape[0]}") -print("K_nngp (approx):") -print(feats.nngp_feat @ feats.nngp_feat.T) -print() + # Transform input vectors to NNGP/NTK feature map + feature_fn = feature_fn if no_jitting else jit(feature_fn) + feats = feature_fn(x, feat_fn_inputs) -print(f"f_ntk shape: {feat_shape[1]}") -print("K_ntk (approx):") -print(feats.ntk_feat @ feats.ntk_feat.T) -print() + # PolySketch returns complex features. Convert complex features to real ones. 
+ if np.iscomplexobj(feats.nngp_feat) or np.iscomplexobj(feats.ntk_feat): + nngp_feat = np.concatenate((feats.nngp_feat.real, feats.nngp_feat.imag), axis=-1) + ntk_feat = np.concatenate((feats.ntk_feat.real, feats.ntk_feat.imag), axis=-1) + feats = feats.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + print(f"f_nngp shape: {feats.nngp_feat.shape}") + print(f"f_ntk shape: {feats.ntk_feat.shape}") -print( - f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" -) -print( - f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" -) + print("K_nngp:") + print(feats.nngp_feat @ feats.nngp_feat.T) + print() + + print("K_ntk:") + print(feats.ntk_feat @ feats.ntk_feat.T) + print() + + print( + f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {np.linalg.norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" + ) + print( + f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {np.linalg.norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" + ) + print() + + + +print("=================== Result of CNTK Random Features ===================") +kappa0_feat_dim = 1024 +kappa1_feat_dim = 1024 +sketch_dim = 1024 + +test_myrtle_network_approx({ + 'method': 'RANDFEAT', + 'feature_dim0': kappa0_feat_dim, + 'feature_dim1': kappa1_feat_dim, + 'sketch_dim': sketch_dim, +}) + +print("==================== Result of CNTK with PolySketch ====================") +poly_degree = 16 +poly_sketch_dim = 4096 +sketch_dim = 4096 + +test_myrtle_network_approx({ + 'method': 'POLYSKETCH', + 'sketch_dim': sketch_dim, + 'poly_degree': poly_degree, + 'poly_sketch_dim': poly_sketch_dim +}) + +print("=============== Result of PolySketch + Random Features ===============") +kappa0_feat_dim = 2048 +sketch_dim = 4096 +poly_degree = 4 +poly_sketch_dim = 4096 + +test_myrtle_network_approx({ + 'method': 'PSRF', + 'feature_dim0': kappa0_feat_dim, + 'sketch_dim': sketch_dim, + 'poly_degree': poly_degree, + 'poly_sketch_dim': poly_sketch_dim +}) + +print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") + +test_myrtle_network_approx({'method': 'EXACT'}) \ No newline at end of file From 5485e17ec5dabadf985eb9dbcb802223c3450fdd Mon Sep 17 00:00:00 2001 From: insuhan Date: Thu, 31 Mar 2022 00:05:34 +0900 Subject: [PATCH 27/44] Compare ReluNTKFeatures to neural_tangents.empirical_ntk_fn --- experimental/tests/kernel_approx_test.py | 141 +++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 experimental/tests/kernel_approx_test.py diff --git a/experimental/tests/kernel_approx_test.py b/experimental/tests/kernel_approx_test.py new file mode 100644 index 00000000..9d1344b8 --- /dev/null +++ b/experimental/tests/kernel_approx_test.py @@ -0,0 +1,141 @@ +import jax +from jax import numpy as np +from jax import random +from jax.config import config +from jax import jit +import sys + +sys.path.append("./") + +config.update("jax_enable_x64", True) +import neural_tangents as nt +from neural_tangents._src import empirical +from neural_tangents import stax + +from experimental.features import DenseFeatures, ReluFeatures, serial, ReluNTKFeatures + + +def _generate_fc_relu_ntk(width, depth, W_std): + layers = [] + layers += [stax.Dense(width, W_std=W_std), stax.Relu()] * depth + layers += [stax.Dense(output_dim, W_std=W_std)] + init_fn, apply_f, kernel_fn = stax.serial(*layers) + return init_fn, apply_f, kernel_fn + + +# This is a re-implementation of neural_tangents.empirical_ntk_fn. 
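+# (The empirical NTK is the Gram matrix of parameter Jacobians, K(x, x') = J(x) J(x')^T averaged over the output dimension, so flattening and stacking the Jacobian blocks gives an explicit finite-width feature map; `_get_grad_feat_dim` counts its dimension.)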
+# The result is the same as "nt.empirical_ntk_fn(apply_fn)(x, None, params)" +def _get_grad(x, output_dim, params, apply_fn): + + f_output = empirical._get_f_params(apply_fn, x, None, None, {}) + jac_f_output = jax.jacobian(f_output) + jacobian = jac_f_output(params) + + grad_all = [] + for jac_ in jacobian: + if len(jac_) > 0: + for j_ in jac_: + if j_ is None or np.linalg.norm(j_) < 1e-10: + continue + grad_all.append(j_.reshape(n, -1)) + + grad_all = np.hstack(grad_all) + return grad_all / np.sqrt(output_dim) + + +def _get_grad_feat_dim(input_dim, width, output_dim, depth): + dim_1 = input_dim * width + dim_2 = np.asarray([width**2 for _ in range(depth - 1)]).sum() + dim_3 = width * output_dim + return (dim_1 + dim_2 + dim_3) * output_dim - dim_3 + + +def fc_relu_ntk_sketching(relufeat_arg, + rng, + init_fn=None, + feature_fn=None, + W_std=1., + depth=-1, + no_jitting=False): + + if init_fn is None or feature_fn is None: + layers = [] + layers += [ + DenseFeatures(1, W_std=W_std), + ReluFeatures(**relufeat_arg), + ] * depth + layers += [DenseFeatures(1, W_std=W_std)] + init_fn, feature_fn = serial(*layers) + + # Initialize random vectors and sketching algorithms + _, feat_fn_inputs = init_fn(rng, x.shape) + + # Transform input vectors to NNGP/NTK feature map + feature_fn = feature_fn if no_jitting else jit(feature_fn) + feats = feature_fn(x, feat_fn_inputs) + + # PolySketch returns complex features. Convert complex features to real ones. + if np.iscomplexobj(feats.ntk_feat): + return np.concatenate((feats.ntk_feat.real, feats.ntk_feat.imag), axis=-1) + return feats.ntk_feat + + +seed = 1 +n, d = 1000, 28 * 28 +no_jitting = False + +key1, key2, key3 = random.split(random.PRNGKey(seed), 3) +x = random.normal(key1, (n, d)) + +width = 4 +depth = 3 +W_std = 1.234 +output_dim = 2 + +init_fn, apply_fn, kernel_fn = _generate_fc_relu_ntk(width, depth, W_std) + +kernel_fn = kernel_fn if no_jitting else jit(kernel_fn) +nt_kernel = kernel_fn(x, None) + +# Sanity check of grad_feat. +_, params = init_fn(key2, x.shape) +grad_feat = _get_grad(x, output_dim, params, apply_fn) +assert np.linalg.norm( + nt.empirical_ntk_fn(apply_fn)(x, None, params) - + grad_feat @ grad_feat.T) <= 1e-12 + +# Store Frobenius-norm of the exact NTK for estimating relative errors. 
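+# (The relative error reported below is ||f f^T - K||_F / ||K||_F.)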
+ntk_norm = np.linalg.norm(nt_kernel.ntk) + +width_all = np.arange(2, 16) +grad_feat_dims_all = [] + +print("empirical_ntk_fn results:") +for width in width_all: + init_fn, apply_fn, _ = _generate_fc_relu_ntk(width, depth, W_std) + _, params = init_fn(key2, x.shape) + grad_feat = _get_grad(x, output_dim, params, apply_fn) + rel_err = np.linalg.norm(grad_feat @ grad_feat.T - nt_kernel.ntk) / ntk_norm + grad_feat_dims_all.append(grad_feat.shape[1]) + print( + f"feat_dim : {grad_feat.shape[1]} (width : {width}), relative err : {rel_err}" + ) + +print() +print("ReluNTKFeatures results:") +relufeat_arg = { + 'num_layers': depth, + 'poly_degree': 16, + 'W_std': W_std, +} + +for feat_dim in grad_feat_dims_all: + relufeat_arg['poly_sketch_dim'] = feat_dim + init_fn, feature_fn = ReluNTKFeatures(**relufeat_arg) + ntk_feat = fc_relu_ntk_sketching(relufeat_arg, + key3, + init_fn=init_fn, + feature_fn=feature_fn) + + rel_err = np.linalg.norm(ntk_feat @ ntk_feat.T - nt_kernel.ntk) / ntk_norm + print(f"feat_dim : {ntk_feat.shape[1]}, err : {rel_err}") \ No newline at end of file From 3acf87abb828067c42308df3333c40df3701b991 Mon Sep 17 00:00:00 2001 From: Amir Zandieh Date: Tue, 5 Apr 2022 16:45:32 +0200 Subject: [PATCH 28/44] test --- experimental/tests/fc_ntk_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/tests/fc_ntk_test.py b/experimental/tests/fc_ntk_test.py index aac17cbe..dc967640 100644 --- a/experimental/tests/fc_ntk_test.py +++ b/experimental/tests/fc_ntk_test.py @@ -22,7 +22,7 @@ width = 512 # this does not matter the output W_std = 1.234 # std of Gaussian random weights -print("================== Result of Neural Tangent Library ==================") +print("================== Result of Neural Tangent Library ===================") init_fn, _, kernel_fn = stax.serial(stax.Dense(width, W_std=W_std), stax.Relu(), stax.Dense(width, W_std=W_std), stax.Relu(), @@ -149,4 +149,4 @@ def test_fc_relu_ntk_approx(relufeat_arg, init_fn=None, feature_fn=None): print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") -test_fc_relu_ntk_approx({'method': 'EXACT'}) \ No newline at end of file +test_fc_relu_ntk_approx({'method': 'EXACT'}) From 07f563def0675a220be7480a0034a156087edf00 Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 6 Apr 2022 06:14:28 +0900 Subject: [PATCH 29/44] Extend ConvFeatures to rectangular filter shapes --- experimental/README.md | 6 ++ experimental/features.py | 116 ++++++++++++++-------- experimental/poly_fitting.py | 5 +- experimental/tests/myrtle_network_test.py | 14 +-- 4 files changed, 89 insertions(+), 52 deletions(-) diff --git a/experimental/README.md b/experimental/README.md index 9a71c006..0a1ccaf0 100644 --- a/experimental/README.md +++ b/experimental/README.md @@ -36,9 +36,11 @@ feats = feature_fn(x, feat_fn_inputs) assert feats.nngp_feat.shape == (5, relufeat_arg['feature_dim1']) assert feats.ntk_feat.shape == (5, relufeat_arg['feature_dim1'] + relufeat_arg['sketch_dim']) ``` + For more details of fully connected NTK features, please check `test_fc_ntk.py`. ### Convolutional NTK approximation via Random Features: + ```python from experimental.features import ConvFeatures, AvgPoolFeatures, FlattenFeatures @@ -71,6 +73,7 @@ All modules return a pair of functions `(init_fn, feature_fn)`. 
Instead of kerne ## [`features.DenseFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L88) `features.DenseFeatures` provides features for fully-connected dense layer and corresponds to `stax.Dense` module in [Neural Tangents](https://github.com/google/neural-tangents). We assume that the input is a tabular dataset (i.e., a n-by-d matrix). Its `feature_fn` updates the NTK features by concatenating NNGP features and NTK features. This is because `stax.Dense` updates a new NTK kernel matrix `(N x D)` by adding the previous NNGP and NTK kernel matrices. The features of dense layer are exact and no approximations are applied. + ```python from jax import numpy as np from neural_tangents import stax @@ -94,6 +97,7 @@ assert feat.ntk_feat == np.zeros(()) For image dataset, the inputs are 4-D tensors with shape `N x H x W x D` where N is batch size, H is image height, W is image width and D is the feature dimension. We reshape the image features into 2-D tensor with shape `NHW x D` and apply proper feature approximations. Then, the resulting features reshape to 4-D tensor with shape `N x H x W x D'` where `D'` is the output dimension of the feature approximation. To use the Random Features approach, set the parameter `method` to `rf` (default `rf`), e.g., + ```python from experimental.features import DenseFeatures, ReluFeatures, serial @@ -113,6 +117,7 @@ assert out_feat.ntk_feat.shape == (3, 30) ``` To use the exact feature map (based on Cholesky decomposition), set the parameter `method` to `exact`, e.g., + ```python init_fn, feat_fn = serial(DenseFeatures(1), ReluFeatures(method='exact')) _, params = init_fn(key1, x.shape) @@ -121,6 +126,7 @@ out_feat = feat_fn(x, params) assert out_feat.nngp_feat.shape == (3, 3) assert out_feat.ntk_feat.shape == (3, 3) ``` + (This is for debugging. The dimension of the exact feature map is equal to the number of inputs, i.e., `N` for tabular dataset, `NHW` for image dataset). diff --git a/experimental/features.py b/experimental/features.py index fe4f13a3..f4c749bf 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -1,6 +1,7 @@ import enum from typing import Optional, Callable, Sequence, Tuple from jax import random +from jax._src.util import prod from jax import numpy as np from jax.numpy.linalg import cholesky import jax.example_libraries.stax as ostax @@ -22,6 +23,8 @@ class Features: ntk_feat: np.ndarray norms: np.ndarray + is_reversed: bool = dataclasses.field(pytree_node=False) + batch_axis: int = 0 channel_axis: int = -1 @@ -101,9 +104,12 @@ def _inputs_to_features(x: np.ndarray, nngp_feat = nngp_feat / norms ntk_feat = np.zeros((), dtype=nngp_feat.dtype) + is_reversed = False + return Features(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms, + is_reversed=is_reversed, batch_axis=batch_axis, channel_axis=channel_axis) # pytype:disable=wrong-keyword-args @@ -181,7 +187,8 @@ def ReluFeatures(method: str = 'RANDFEAT', feature_dim1: int = 1, sketch_dim: int = 1, poly_degree: int = 8, - poly_sketch_dim: int = 1): + poly_sketch_dim: int = 1, + generate_rand_mtx: bool = True): method = ReluFeaturesMethod(method.upper()) @@ -195,10 +202,16 @@ def init_fn(rng, input_shape): if method == ReluFeaturesMethod.RANDFEAT: rng1, rng2, rng3 = random.split(rng, 3) - # Random vectors for random features of arc-cosine kernel of order 0. 
- W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) - # Random vectors for random features of arc-cosine kernel of order 1. - W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) + if generate_rand_mtx: + # Random vectors for random features of arc-cosine kernel of order 0. + W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) + # Random vectors for random features of arc-cosine kernel of order 1. + W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) + else: + # if `generate_rand_mtx` is False, return random seeds and shapes instead of np.ndarray. + W0 = (rng1, (nngp_feat_shape[-1], feature_dim0)) + W1 = (rng2, (nngp_feat_shape[-1], feature_dim1)) + # TensorSRHT of degree 2 for approximating tensor product. tensorsrht = TensorSRHT(rng=rng3, input_dim1=ntk_feat_shape[-1], @@ -210,7 +223,7 @@ def init_fn(rng, input_shape): elif method == ReluFeaturesMethod.POLYSKETCH: new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) - rng1, rng2, rng3 = random.split(rng, 3) + rng1, rng2 = random.split(rng, 2) kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) @@ -224,10 +237,11 @@ def init_fn(rng, input_shape): # TensorSRHT of degree 2 for approximating tensor product. tensorsrht = TensorSRHT( + rng=rng2, input_dim1=ntk_feat_shape[-1] // (1 + (relu_layers_count > 0)), input_dim2=poly_degree * (polysketch.sketch_dim // 4 - 1) + 1, - sketch_dim=sketch_dim, - rng=rng2).init_sketches() # pytype:disable=wrong-keyword-args + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), (polysketch, tensorsrht, kappa0_coeff, kappa1_coeff) @@ -263,9 +277,8 @@ def init_fn(rng, input_shape): elif method == ReluFeaturesMethod.POLY: # This only uses the polynomial approximation without sketching. - new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( - nngp_feat_shape[:-1]),) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) + new_nngp_feat_shape = nngp_feat_shape[:-1] + (prod(nngp_feat_shape[:-1]),) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (prod(ntk_feat_shape[:-1]),) kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) @@ -275,9 +288,8 @@ def init_fn(rng, input_shape): elif method == ReluFeaturesMethod.EXACT: # The exact feature map computation is for debug. - new_nngp_feat_shape = nngp_feat_shape[:-1] + (_prod( - nngp_feat_shape[:-1]),) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (_prod(ntk_feat_shape[:-1]),) + new_nngp_feat_shape = nngp_feat_shape[:-1] + (prod(nngp_feat_shape[:-1]),) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (prod(ntk_feat_shape[:-1]),) return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), () @@ -295,8 +307,12 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: norms: np.ndarray = f.norms if method == ReluFeaturesMethod.RANDFEAT: # Random Features approach. 
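      # (Descriptive note: thresholding the Gaussian projections at zero yields random features of the order-0 arc-cosine kernel kappa0, while applying ReLU to projections through W1 yields order-1 (kappa1) features; TensorSRHT then sketches the tensor product of the kappa0 features with the incoming NTK features.)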
- W0: np.ndarray = input[0] - W1: np.ndarray = input[1] + if generate_rand_mtx: + W0: np.ndarray = input[0] + W1: np.ndarray = input[1] + else: + W0 = random.normal(input[0][0], shape=input[0][1]) + W1 = random.normal(input[1][0], shape=input[1][1]) tensorsrht: TensorSRHT = input[2] kappa0_feat = (nngp_feat_2d @ W0 > 0) / W0.shape[-1]**0.5 @@ -380,13 +396,6 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: return init_fn, feature_fn -def _prod(tuple_): - prod = 1 - for x in tuple_: - prod = prod * x - return prod - - @layer def ReluNTKFeatures( num_layers: int, @@ -459,20 +468,18 @@ def ConvFeatures(out_chan: int, channel_axis = lhs_spec.index('C') + patch_size = prod(filter_shape) + if parameterization != 'ntk': raise NotImplementedError(f'Parameterization ({parameterization}) is ' ' not implemented yet.') - if filter_shape[0] != filter_shape[1]: - raise NotImplementedError('filter_shape should be square.') - filter_size = filter_shape[0] - def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] new_nngp_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_shape[-1] * - filter_size**2,) + patch_size,) new_ntk_feat_shape = nngp_feat_shape[:-1] + ( - (nngp_feat_shape[-1] + ntk_feat_shape[-1]) * filter_size**2,) + (nngp_feat_shape[-1] + ntk_feat_shape[-1]) * patch_size,) if len(input_shape) > 2: return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'C'), () @@ -485,15 +492,27 @@ def feature_fn(f, input, **kwargs): they are not linear operations, we first unnormalize features (i.e., multiplying them by `norms`) and then re-normalize the output features. """ + is_reversed = f.is_reversed + f_renormalized: Features = _unnormalize_features(f) nngp_feat: np.ndarray = f_renormalized.nngp_feat ntk_feat: np.ndarray = f_renormalized.ntk_feat - nngp_feat = _conv2d_feat(nngp_feat, filter_size) / filter_size * W_std + + if is_reversed: + filter_shape_ = filter_shape[::-1] + else: + filter_shape_ = filter_shape + + is_reversed = not f.is_reversed + + nngp_feat = _concat_shifted_features_2d( + nngp_feat, filter_shape_) * W_std / patch_size**0.5 if f.ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat else: - ntk_feat = _conv2d_feat(ntk_feat, filter_size) / filter_size * W_std + ntk_feat = _concat_shifted_features_2d( + ntk_feat, filter_shape_) * W_std / patch_size**0.5 ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) # Re-normalize the features. @@ -502,19 +521,24 @@ def feature_fn(f, input, **kwargs): nngp_feat = nngp_feat / norms ntk_feat = ntk_feat / norms - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) + return f.replace(nngp_feat=nngp_feat, + ntk_feat=ntk_feat, + norms=norms, + is_reversed=is_reversed) return init_fn, feature_fn -def _conv2d_feat(X, filter_size): - return _conv_feat(np.moveaxis(_conv_feat(X, filter_size), 1, 2), filter_size) +def _concat_shifted_features_2d(X: np.ndarray, filter_shape: Sequence[int]): + return _concat_shifted_features( + np.moveaxis(_concat_shifted_features(X, filter_shape[1]), 1, 2), + filter_shape[0]) -def _conv_feat(X, filter_size): +def _concat_shifted_features(X, filter_size): """ - Direct sums of image features. If input shape is (N, H, W, C), the output has - the shape (N, H, W, C * filter_size**2). + Concatenations of shifted image features. If input shape is (N, H, W, C), + the output has the shape (N, H, W, C * filter_size). 
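+ + A shape-only illustration (example values assumed, not taken from the original code): for `X = np.ones((1, 4, 4, 2))` with layout `(N, H, W, C)`, `_concat_shifted_features(X, 3)` returns an array of shape `(1, 4, 4, 6)`, stacking `X` with its shifted copies (zero-padded at the borders) along the channel axis.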
""" N, H, W, C = X.shape out = np.zeros((N, H, W, C * filter_size)) @@ -579,7 +603,10 @@ def feature_fn(f, input=None, **kwargs): nngp_feat = nngp_feat / norms ntk_feat = ntk_feat / norms - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) + return f.replace(nngp_feat=nngp_feat, + ntk_feat=ntk_feat, + norms=norms, + is_reversed=False) return init_fn, feature_fn @@ -598,8 +625,8 @@ def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_nngp_feat_shape = nngp_feat_shape[:1] + (_prod(nngp_feat_shape[1:]),) - new_ntk_feat_shape = ntk_feat_shape[:1] + (_prod(ntk_feat_shape[1:]),) + new_nngp_feat_shape = nngp_feat_shape[:1] + (prod(nngp_feat_shape[1:]),) + new_ntk_feat_shape = ntk_feat_shape[:1] + (prod(ntk_feat_shape[1:]),) if len(input_shape) > 2: return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'F'), () else: @@ -611,11 +638,11 @@ def feature_fn(f, input=None, **kwargs): ntk_feat: np.ndarray = f_renomalized.ntk_feat batch_size = f.nngp_feat.shape[batch_axis] - nngp_feat = nngp_feat.reshape(batch_size, -1) / _prod( + nngp_feat = nngp_feat.reshape(batch_size, -1) / prod( nngp_feat.shape[1:-1])**0.5 if f.ntk_feat.ndim != 0: # check if ntk_feat is not empty - ntk_feat = ntk_feat.reshape(batch_size, -1) / _prod( + ntk_feat = ntk_feat.reshape(batch_size, -1) / prod( ntk_feat.shape[1:-1])**0.5 # Re-normalize the features. @@ -624,6 +651,9 @@ def feature_fn(f, input=None, **kwargs): nngp_feat = nngp_feat / norms ntk_feat = ntk_feat / norms - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) + return f.replace(nngp_feat=nngp_feat, + ntk_feat=ntk_feat, + norms=norms, + is_reversed=False) return init_fn, feature_fn diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py index a27a52fd..d985b2a3 100644 --- a/experimental/poly_fitting.py +++ b/experimental/poly_fitting.py @@ -88,8 +88,9 @@ def kappa0_coeffs(degree: int, num_layers: int): # For kappa0, we set all weights to be one. weights = np.ones(len(fvals), dtype=xvals.dtype) - # Coefficients can be obtained by solving QP with OSQP jaxopt. - coeffs = poly_fitting_qp(xvals, fvals, weights, degree) + # Coefficients can be obtained by solving QP with OSQP jaxopt. kappa0 has a + # sharp slope at x=1, hence we add an equailty condition of p_n(1)=f(x). 
+ coeffs = poly_fitting_qp(xvals, fvals, weights, degree, eq_last_point=True) return np.where(coeffs < 1e-5, 0.0, coeffs) diff --git a/experimental/tests/myrtle_network_test.py b/experimental/tests/myrtle_network_test.py index 774d49f2..0b2e108e 100644 --- a/experimental/tests/myrtle_network_test.py +++ b/experimental/tests/myrtle_network_test.py @@ -138,9 +138,9 @@ def test_myrtle_network_approx(relufeat_arg, init_fn=None, feature_fn=None): }) print("==================== Result of CNTK wih PolySketch ====================") -poly_degree = 16 -poly_sketch_dim = 4096 -sketch_dim = 4096 +poly_degree = 8 +poly_sketch_dim = 1024 +sketch_dim = 1024 test_myrtle_network_approx({ 'method': 'POLYSKETCH', @@ -150,10 +150,10 @@ def test_myrtle_network_approx(relufeat_arg, init_fn=None, feature_fn=None): }) print("=============== Result of PolySketch + Random Features ===============") -kappa0_feat_dim = 2048 -sketch_dim = 4096 -poly_degree = 4 -poly_sketch_dim = 4096 +kappa0_feat_dim = 512 +sketch_dim = 1024 +poly_degree = 8 +poly_sketch_dim = 1024 test_myrtle_network_approx({ 'method': 'PSRF', From a65d72905ca201318f1a840dc244e3a5bfbce012 Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 6 Apr 2022 08:52:34 +0900 Subject: [PATCH 30/44] Fix complex dtype warning --- experimental/features.py | 2 +- experimental/tests/myrtle_network_test.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index f4c749bf..68deb9e3 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -541,7 +541,7 @@ def _concat_shifted_features(X, filter_size): the output has the shape (N, H, W, C * filter_size). """ N, H, W, C = X.shape - out = np.zeros((N, H, W, C * filter_size)) + out = np.zeros((N, H, W, C * filter_size), dtype=X.dtype) out = out.at[:, :, :, :C].set(X) j = 1 for i in range(1, min((filter_size + 1) // 2, W)): diff --git a/experimental/tests/myrtle_network_test.py b/experimental/tests/myrtle_network_test.py index 0b2e108e..6b3a66e5 100644 --- a/experimental/tests/myrtle_network_test.py +++ b/experimental/tests/myrtle_network_test.py @@ -124,7 +124,6 @@ def test_myrtle_network_approx(relufeat_arg, init_fn=None, feature_fn=None): print() - print("=================== Result of CNTK Random Features ===================") kappa0_feat_dim = 1024 kappa1_feat_dim = 1024 @@ -163,6 +162,11 @@ def test_myrtle_network_approx(relufeat_arg, init_fn=None, feature_fn=None): 'poly_sketch_dim': poly_sketch_dim }) +print("======= (Debug) NTK Feature Maps with Polynomial Approximation =======") +print("\t(*No Sketching algorithm is applied.)") + +test_myrtle_network_approx({'method': 'POLY', 'poly_degree': poly_degree}) + print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") test_myrtle_network_approx({'method': 'EXACT'}) \ No newline at end of file From 671c27de858c0bdc8b746e7e2b0cdf607a4fe7ba Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 6 Apr 2022 15:24:02 +0900 Subject: [PATCH 31/44] Update Cholesky decomposition safely --- experimental/features.py | 13 +- experimental/tests/myrtle_network_test.py | 220 ++++++++++++---------- 2 files changed, 131 insertions(+), 102 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index 68deb9e3..c33347ef 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -3,7 +3,6 @@ from jax import random from jax._src.util import prod from jax import numpy as np -from jax.numpy.linalg import cholesky import jax.example_libraries.stax as ostax 
from jax import eval_shape, ShapedArray @@ -371,19 +370,19 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: kappa1_coeff: np.ndarray = input[1] gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) - nngp_feat = cholesky(np.polyval(kappa1_coeff[::-1], + nngp_feat = _cholesky(np.polyval(kappa1_coeff[::-1], gram_nngp)).reshape(input_shape + (-1,)) ntk = ntk_feat_2d @ ntk_feat_2d.T kappa0_mat = np.polyval(kappa0_coeff[::-1], gram_nngp) - ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + ntk_feat = _cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) elif method == ReluFeaturesMethod.EXACT: # Exact feature map computations via Cholesky decomposition. - nngp_feat = cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) + nngp_feat = _cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) ntk = ntk_feat_2d @ ntk_feat_2d.T kappa0_mat = kappa0(nngp_feat_2d) - ntk_feat = cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + ntk_feat = _cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) else: raise NotImplementedError(f'Invalid method name: {method}') @@ -396,6 +395,10 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: return init_fn, feature_fn +def _cholesky(mat): + return np.linalg.cholesky(mat + 1e-8 * np.eye(mat.shape[0])) + + @layer def ReluNTKFeatures( num_layers: int, diff --git a/experimental/tests/myrtle_network_test.py b/experimental/tests/myrtle_network_test.py index 6b3a66e5..fd920936 100644 --- a/experimental/tests/myrtle_network_test.py +++ b/experimental/tests/myrtle_network_test.py @@ -1,3 +1,4 @@ +import time import sys sys.path.append("./") import functools @@ -14,10 +15,8 @@ from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]} -width = 1 - -def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0.): +def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0., width=1): activation_fn = stax.Relu() conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, padding='SAME') @@ -34,7 +33,7 @@ def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0.): return stax.serial(*layers) -def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), **relu_args): +def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), width=1, **relu_args): conv_fn = functools.partial(ConvFeatures, W_std=W_std) @@ -57,116 +56,143 @@ def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), **relu_args): return serial(*layers) -key = random.PRNGKey(0) +def test_small_dataset(num_data=4, dataset='synthetic', depth=5, no_jitting=False): -N, H, W, C = 4, 32, 32, 3 -depth = 5 -no_jitting = False + print(f"dataset : {dataset}") -key1, key2 = random.split(key) -x = random.normal(key1, shape=(N, H, W, C)) + key = random.PRNGKey(0) -print("================= Result of Neural Tangent Library =================") + if dataset == 'synthetic': + H, W, C = 32, 32, 3 + x = random.normal(key, shape=(num_data, H, W, C)) -_, _, kernel_fn = MyrtleNetwork(depth) -kernel_fn = kernel_fn if no_jitting else jit(kernel_fn) + elif dataset in ['cifar10', 'cifar100']: + from examples import datasets + x = datasets.get_dataset('cifar10', do_flatten_and_normalize=False)[0] + mean_ = np.mean(x) + std_ = np.std(x) + x = (x - mean_) / std_ -nt_kernel = kernel_fn(x) -print("K_nngp (exact):") -print(nt_kernel.nngp) -print() + x = x[random.permutation(key, len(x))[:num_data]] -print("K_ntk (exact):") -print(nt_kernel.ntk) -print() + else: + raise NotImplementedError(f"Invalid dataset 
: {dataset}") + key1, key2 = random.split(key) + print("================= Result of Neural Tangent Library =================") -def test_myrtle_network_approx(relufeat_arg, init_fn=None, feature_fn=None): + _, _, kernel_fn = MyrtleNetwork(depth) + kernel_fn = kernel_fn if no_jitting else jit(kernel_fn) - print(f"ReluFeatures params:") - for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") - print() + tic = time.time() + nt_kernel = kernel_fn(x) + toc = time.time() - tic + print(f"nt kernel time: {toc:.4f} sec") - if init_fn is None or feature_fn is None: - init_fn, feature_fn = MyrtleNetworkFeatures(depth, **relufeat_arg) + if num_data <= 8: + print("K_nngp (exact):") + print(nt_kernel.nngp) + print() - # Initialize random vectors and sketching algorithms - _, feat_fn_inputs = init_fn(key2, x.shape) + print("K_ntk (exact):") + print(nt_kernel.ntk) + print() - # Transform input vectors to NNGP/NTK feature map - feature_fn = feature_fn if no_jitting else jit(feature_fn) - feats = feature_fn(x, feat_fn_inputs) + def test_myrtle_network_approx(relufeat_arg): - # PolySketch returns complex features. Convert complex features to real ones. - if np.iscomplexobj(feats.nngp_feat) or np.iscomplexobj(feats.ntk_feat): - nngp_feat = np.concatenate((feats.nngp_feat.real, feats.nngp_feat.imag), axis=-1) - ntk_feat = np.concatenate((feats.ntk_feat.real, feats.ntk_feat.imag), axis=-1) - feats = feats.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + print(f"ReluFeatures params:") + for name_, value_ in relufeat_arg.items(): + print(f"{name_:<12} : {value_}") + print() - print(f"f_nngp shape: {feats.nngp_feat.shape}") - print(f"f_ntk shape: {feats.ntk_feat.shape}") + init_fn, feature_fn = MyrtleNetworkFeatures(depth, **relufeat_arg) - print("K_nngp:") - print(feats.nngp_feat @ feats.nngp_feat.T) - print() + # Initialize random vectors and sketching algorithms + _, feat_fn_inputs = init_fn(key2, x.shape) + + # Transform input vectors to NNGP/NTK feature map + feature_fn = feature_fn if no_jitting else jit(feature_fn) + + tic = time.time() + feats = feature_fn(x, feat_fn_inputs) + toc = time.time() - tic + print(f"{relufeat_arg['method']} feature time: {toc:.4f} sec") + + # PolySketch returns complex features. Convert complex features to real ones. 
+ if np.iscomplexobj(feats.nngp_feat) or np.iscomplexobj(feats.ntk_feat): + nngp_feat = np.concatenate((feats.nngp_feat.real, feats.nngp_feat.imag), axis=-1) + ntk_feat = np.concatenate((feats.ntk_feat.real, feats.ntk_feat.imag), axis=-1) + feats = feats.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + print(f"f_nngp shape: {feats.nngp_feat.shape}") + print(f"f_ntk shape: {feats.ntk_feat.shape}") + + if num_data <= 8: + print("K_nngp:") + print(feats.nngp_feat @ feats.nngp_feat.T) + print() + + print("K_ntk:") + print(feats.ntk_feat @ feats.ntk_feat.T) + print() + + print( + f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {np.linalg.norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" + ) + print( + f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {np.linalg.norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" + ) + print() + + + print("================= Result of CNTK Random Features =================") + kappa0_feat_dim = 1024 + kappa1_feat_dim = 1024 + sketch_dim = 1024 + + test_myrtle_network_approx({ + 'method': 'RANDFEAT', + 'feature_dim0': kappa0_feat_dim, + 'feature_dim1': kappa1_feat_dim, + 'sketch_dim': sketch_dim, + }) + + print("================== Result of CNTK wih PolySketch ==================") + poly_degree = 8 + poly_sketch_dim = 1024 + sketch_dim = 1024 + + test_myrtle_network_approx({ + 'method': 'POLYSKETCH', + 'sketch_dim': sketch_dim, + 'poly_degree': poly_degree, + 'poly_sketch_dim': poly_sketch_dim + }) + + print("============== Result of PolySketch + Random Features ==============") + kappa0_feat_dim = 512 + sketch_dim = 1024 + poly_degree = 8 + poly_sketch_dim = 1024 + + test_myrtle_network_approx({ + 'method': 'PSRF', + 'feature_dim0': kappa0_feat_dim, + 'sketch_dim': sketch_dim, + 'poly_degree': poly_degree, + 'poly_sketch_dim': poly_sketch_dim + }) - print("K_ntk:") - print(feats.ntk_feat @ feats.ntk_feat.T) - print() + print("===== (Debug) NTK Feature Maps with Polynomial Approximation =====") + print("\t(*No Sketching algorithm is applied.)") - print( - f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {np.linalg.norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" - ) - print( - f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {np.linalg.norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" - ) - print() + test_myrtle_network_approx({'method': 'POLY', 'poly_degree': poly_degree}) + print("==== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ====") -print("=================== Result of CNTK Random Features ===================") -kappa0_feat_dim = 1024 -kappa1_feat_dim = 1024 -sketch_dim = 1024 + test_myrtle_network_approx({'method': 'EXACT'}) -test_myrtle_network_approx({ - 'method': 'RANDFEAT', - 'feature_dim0': kappa0_feat_dim, - 'feature_dim1': kappa1_feat_dim, - 'sketch_dim': sketch_dim, -}) - -print("==================== Result of CNTK wih PolySketch ====================") -poly_degree = 8 -poly_sketch_dim = 1024 -sketch_dim = 1024 - -test_myrtle_network_approx({ - 'method': 'POLYSKETCH', - 'sketch_dim': sketch_dim, - 'poly_degree': poly_degree, - 'poly_sketch_dim': poly_sketch_dim -}) - -print("=============== Result of PolySketch + Random Features ===============") -kappa0_feat_dim = 512 -sketch_dim = 1024 -poly_degree = 8 -poly_sketch_dim = 1024 - -test_myrtle_network_approx({ - 'method': 'PSRF', - 'feature_dim0': kappa0_feat_dim, - 'sketch_dim': sketch_dim, - 'poly_degree': poly_degree, - 'poly_sketch_dim': poly_sketch_dim -}) - -print("======= (Debug) NTK Feature Maps with Polynomial Approximation =======") -print("\t(*No Sketching algorithm is 
applied.)") - -test_myrtle_network_approx({'method': 'POLY', 'poly_degree': poly_degree}) - -print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") - -test_myrtle_network_approx({'method': 'EXACT'}) \ No newline at end of file +if __name__ == "__main__": + test_small_dataset(num_data=6, dataset='synthetic', depth=5, no_jitting=False) + test_small_dataset(num_data=6, dataset='cifar10', depth=5, no_jitting=False) + test_small_dataset(num_data=6, dataset='cifar100', depth=5, no_jitting=False) \ No newline at end of file From c84ed65fea9698eff1a1d7a3ae5d3ad238a1a9aa Mon Sep 17 00:00:00 2001 From: insuhan Date: Tue, 10 May 2022 07:33:15 +0900 Subject: [PATCH 32/44] Fix nans issue -- complex data type --- experimental/features.py | 349 ++++++++++++++++++++++++++++++++++++++ experimental/sketching.py | 8 +- 2 files changed, 354 insertions(+), 3 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index c33347ef..e5d8c601 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -395,6 +395,355 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: return init_fn, feature_fn +@layer +def ReluFeatures2(method: str = 'RANDFEAT', + feature_dim0: int = 1, + feature_dim1: int = 1, + sketch_dim: int = 1, + poly_degree: int = 8, + poly_sketch_dim: int = 1, + generate_rand_mtx: bool = True): + + method = ReluFeaturesMethod(method.upper()) + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) + net_shape = input_shape[2] + relu_layers_count = net_shape.count('R') + new_net_shape = net_shape + 'R' + + if method == ReluFeaturesMethod.POLYSKETCH: + new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) + rng1, rng2 = random.split(rng, 2) + + kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) + kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) + + # PolySketch expansion for nngp features. + polysketch = PolyTensorSketch(rng=rng1, + input_dim=nngp_feat_shape[-1] // + (1 + (relu_layers_count > 0)), + sketch_dim=poly_sketch_dim, + degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args + + # TensorSRHT of degree 2 for approximating tensor product. + tensorsrht = TensorSRHT( + rng=rng2, + input_dim1=ntk_feat_shape[-1] // (1 + (relu_layers_count > 0)), + input_dim2=poly_degree * (polysketch.sketch_dim // 4 - 1) + 1, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_net_shape), (polysketch, tensorsrht, kappa0_coeff, + kappa1_coeff) + + else: + raise NotImplementedError(f'Invalid method name: {method}') + + def feature_fn(f: Features, input=None, **kwargs) -> Features: + + input_shape: tuple = f.nngp_feat.shape[:-1] + nngp_feat_dim: tuple = f.nngp_feat.shape[-1] + ntk_feat_dim: tuple = f.ntk_feat.shape[-1] + + nngp_feat_2d: np.ndarray = f.nngp_feat.reshape(-1, nngp_feat_dim) + ntk_feat_2d: np.ndarray = f.ntk_feat.reshape(-1, ntk_feat_dim) + norms: np.ndarray = f.norms + + if method == ReluFeaturesMethod.POLYSKETCH: + polysketch: PolyTensorSketch = input[0] + tensorsrht: TensorSRHT = input[1] + kappa0_coeff: np.ndarray = input[2] + kappa1_coeff: np.ndarray = input[3] + + # Apply PolySketch to approximate feature maps of kappa0 & kappa1 kernels. 
+ polysketch_feats = polysketch.sketch(nngp_feat_2d) + kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) + kappa0_feat = polysketch.expand_feats(polysketch_feats, kappa0_coeff) + + # Apply SRHT to kappa1_feat so that dimension of nngp_feat is poly_sketch_dim//2. + nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + + (-1,)) + # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. + ntk_feat = tensorsrht.sketch(ntk_feat_2d, + kappa0_feat).reshape(input_shape + (-1,)) + + else: + raise NotImplementedError(f'Invalid method name: {method}') + + if method != ReluFeaturesMethod.RANDFEAT: + norms /= 2.0**0.5 + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) + + return init_fn, feature_fn + + +import jax +import numpy as onp +@layer +def NormGaussFeatures(method: str = 'RANDFEAT', + feature_dim0: int = 1, + feature_dim1: int = 1, + sketch_dim: int = 1, + poly_degree: int = 8, + poly_sketch_dim: int = 1, + generate_rand_mtx: bool = True): + + method = ReluFeaturesMethod(method.upper()) + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) + net_shape = input_shape[2] + relu_layers_count = net_shape.count('R') + new_net_shape = net_shape + 'R' + + # print(input_shape) + # print(relu_layers_count) + + if method == ReluFeaturesMethod.RANDFEAT: + rng1, rng2, rng3 = random.split(rng, 3) + if generate_rand_mtx: + # Random vectors for random features of arc-cosine kernel of order 0. + W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) + # Random vectors for random features of arc-cosine kernel of order 1. + W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) + else: + # if `generate_rand_mtx` is False, return random seeds and shapes instead of np.ndarray. + W0 = (rng1, (nngp_feat_shape[-1], feature_dim0)) + W1 = (rng2, (nngp_feat_shape[-1], feature_dim1)) + + # TensorSRHT of degree 2 for approximating tensor product. + tensorsrht = TensorSRHT(rng=rng3, + input_dim1=ntk_feat_shape[-1], + input_dim2=feature_dim0, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_net_shape), (W0, W1, tensorsrht) + + elif method == ReluFeaturesMethod.POLYSKETCH: + new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) + rng1, rng2 = random.split(rng, 2) + + kappa1_coeff = np.exp(-1) * np.array([1/np.exp(jax.scipy.special.gammaln(i+1)) for i in range(poly_degree+1)]) + kappa0_coeff = np.exp(-1) * np.array([1/np.exp(jax.scipy.special.gammaln(i+1)) for i in range(poly_degree+1)]) + + # PolySketch expansion for nngp features. + polysketch = PolyTensorSketch(rng=rng1, + input_dim=nngp_feat_shape[-1] // + (1 + (relu_layers_count > 0)), + sketch_dim=poly_sketch_dim, + degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args + + # TensorSRHT of degree 2 for approximating tensor product. 
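+      # (The degree-2 TensorSRHT compresses the implicit tensor product of
+      # the incoming NTK features (input_dim1) with the kappa0 random
+      # features (input_dim2) down to `sketch_dim` coordinates, avoiding an
+      # explicit input_dim1 * input_dim2 feature blow-up.)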
+ tensorsrht = TensorSRHT( + rng=rng2, + input_dim1=ntk_feat_shape[-1] // (1 + (relu_layers_count > 0)), + input_dim2=poly_degree * (polysketch.sketch_dim // 4 - 1) + 1, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + + # print(f"new_nngp_feat_shape : {new_nngp_feat_shape}") + # print(f"new_ntk_feat_shape : {new_ntk_feat_shape}") + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_net_shape), (polysketch, tensorsrht, kappa0_coeff, + kappa1_coeff) + + elif method == ReluFeaturesMethod.PSRF: + new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) + rng1, rng2, rng3 = random.split(rng, 3) + + kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) + + # PolySketch expansion for nngp features. + polysketch = PolyTensorSketch(rng=rng1, + input_dim=nngp_feat_shape[-1] // + (1 + (relu_layers_count > 0)), + sketch_dim=poly_sketch_dim, + degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args + + # TensorSRHT of degree 2 for approximating tensor product. + tensorsrht = TensorSRHT(rng=rng2, + input_dim1=ntk_feat_shape[-1] // + (1 + (relu_layers_count > 0)), + input_dim2=feature_dim0, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + + # Random vectors for random features of arc-cosine kernel of order 0. + if relu_layers_count == 0: + W0 = random.normal(rng3, (2 * nngp_feat_shape[-1], feature_dim0 // 2)) + else: + W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0 // 2)) + + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_net_shape), (W0, polysketch, tensorsrht, kappa1_coeff) + + elif method == ReluFeaturesMethod.POLY: + # This only uses the polynomial approximation without sketching. + new_nngp_feat_shape = nngp_feat_shape[:-1] + (prod(nngp_feat_shape[:-1]),) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (prod(ntk_feat_shape[:-1]),) + + kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) + kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) + + return (new_nngp_feat_shape, new_ntk_feat_shape, + new_net_shape), (kappa0_coeff, kappa1_coeff) + + elif method == ReluFeaturesMethod.EXACT: + # The exact feature map computation is for debug. + new_nngp_feat_shape = nngp_feat_shape[:-1] + (prod(nngp_feat_shape[:-1]),) + new_ntk_feat_shape = ntk_feat_shape[:-1] + (prod(ntk_feat_shape[:-1]),) + + return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), () + + else: + raise NotImplementedError(f'Invalid method name: {method}') + + def feature_fn(f: Features, input=None, **kwargs) -> Features: + + input_shape: tuple = f.nngp_feat.shape[:-1] + nngp_feat_dim: tuple = f.nngp_feat.shape[-1] + ntk_feat_dim: tuple = f.ntk_feat.shape[-1] + + nngp_feat_2d: np.ndarray = f.nngp_feat.reshape(-1, nngp_feat_dim) + ntk_feat_2d: np.ndarray = f.ntk_feat.reshape(-1, ntk_feat_dim) + norms: np.ndarray = f.norms + + # if np.any(np.isnan(nngp_feat_2d)) or np.any(np.isnan(ntk_feat_2d)): + # print("1. nan is observed") + # import pdb; pdb.set_trace(); + + if method == ReluFeaturesMethod.RANDFEAT: # Random Features approach. 
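+      # (When the Gaussian matrices were not materialized in init_fn, only
+      # their PRNG keys and shapes were stored; they are regenerated here at
+      # the cost of extra sampling per call.)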
+ if generate_rand_mtx: + W0: np.ndarray = input[0] + W1: np.ndarray = input[1] + else: + W0 = random.normal(input[0][0], shape=input[0][1]) + W1 = random.normal(input[1][0], shape=input[1][1]) + tensorsrht: TensorSRHT = input[2] + + kappa0_feat = (nngp_feat_2d @ W0 > 0) / W0.shape[-1]**0.5 + del W0 + nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / + W1.shape[-1]**0.5).reshape(input_shape + (-1,)) + del W1 + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat, + real_output=True).reshape(input_shape + + (-1,)) + + elif method == ReluFeaturesMethod.POLYSKETCH: + polysketch: PolyTensorSketch = input[0] + tensorsrht: TensorSRHT = input[1] + kappa0_coeff: np.ndarray = input[2] + kappa1_coeff: np.ndarray = input[3] + + # Apply PolySketch to approximate feature maps of kappa0 & kappa1 kernels. + polysketch_feats = polysketch.sketch(nngp_feat_2d) + kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) + kappa0_feat = polysketch.expand_feats(polysketch_feats, kappa0_coeff) + # del polysketch_feats + + # Apply SRHT to kappa1_feat so that dimension of nngp_feat is poly_sketch_dim//2. + nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + + (-1,)) + # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) + # print(f"ntk_feat_2d avg norm : {np.sum(ntk_feat_2d.real**2 + ntk_feat_2d.imag**2, axis=-1).mean()}") + # print(f"kappa0_feat avg norm : {np.sum(kappa0_feat.real**2 + kappa0_feat.imag**2, axis=-1).mean()}") + # print(f"ntk_feat avg norm : {np.sum(ntk_feat.real**2 + ntk_feat.imag**2, axis=-1).mean()}") + + # aa = np.fft.fftn(ntk_feat_2d * tensorsrht.rand_signs1, axes=(-1,))[:, tensorsrht.rand_inds1] + # bb = np.fft.fftn( (kappa0_feat * tensorsrht.rand_signs2), axes=(-1,))[:, tensorsrht.rand_inds2] + # bb = np.fft.fftn(kappa0_feat * tensorsrht.rand_signs2, axes=(-1,))[:, tensorsrht.rand_inds2] + # ntk_feat = np.sqrt(1 / tensorsrht.rand_inds1.shape[-1]) * (aa * bb).reshape(input_shape + (-1,)) + # if np.any(np.isnan(ntk_feat)) or np.any(np.isnan(nngp_feat)): + # kk = kappa0_feat * tensorsrht.rand_signs2 + # np.sum(kk.real **2 + kk.imag**2, axis=-1) + # bk = np.fft.fftn(kk, axes=(-1,)) + # np.sum(bk.real **2 + bk.imag**2, axis=-1) + # qq = np.fft.fftn(kappa0_feat * tensorsrht.rand_signs2, axes=(-1,)) + # np.sum(qq.real **2 + qq.imag**2, axis=-1) + # bb = np.fft.fftn( (kappa0_feat * tensorsrht.rand_signs2), axes=(-1,)) + # np.sum(bb.real**2 + bb.imag**2, axis=-1) + # np.sum(ntk_feat_2d.real ** 2 + ntk_feat_2d.imag**2, axis=-1) + # np.sum(kappa0_feat.real ** 2 + kappa0_feat.imag**2, axis=-1) + # np.sum(kk.real ** 2 + kk.imag**2, axis=-1) + + # np.sum(kappa0_feat.real ** 2 + kappa0_feat.imag**2, axis=-1) + + # cc = kappa0_feat * tensorsrht.rand_signs2 + # ee = np.fft.fftn(cc, axes=(-1,)) + # enorm = np.sum(ee.real**2 + ee.imag**2, axis=-1) + + # print(cc.sum()) + # print("1.5. nan is observed") + # import pdb; pdb.set_trace(); + # else: + # print("sdfdsfs") + # print(cc.sum()) + # import pdb; pdb.set_trace(); + + elif method == ReluFeaturesMethod.PSRF: # Combination of PolySketch and Random Features. 
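+      # (PSRF approximates kappa1 with the polynomial sketch and kappa0 with
+      # sign-pattern random projections; TensorSRHT below combines the
+      # kappa0 features with the incoming NTK features.)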
+ W0: np.ndarray = input[0] + polysketch: PolyTensorSketch = input[1] + tensorsrht: TensorSRHT = input[2] + kappa1_coeff: np.ndarray = input[3] + + polysketch_feats = polysketch.sketch(nngp_feat_2d) + kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) + del polysketch_feats + + nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + + (-1,)) + + nngp_proj = np.concatenate( + (nngp_feat_2d.real, nngp_feat_2d.imag), axis=1) @ W0 + kappa0_feat = np.concatenate( + ((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / W0.shape[-1]**0.5 + del W0 + + # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. + ntk_feat = tensorsrht.sketch(ntk_feat_2d, + kappa0_feat).reshape(input_shape + (-1,)) + + elif method == ReluFeaturesMethod.POLY: # Polynomial approximation without sketching. + kappa0_coeff: np.ndarray = input[0] + kappa1_coeff: np.ndarray = input[1] + + gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) + nngp_feat = _cholesky(np.polyval(kappa1_coeff[::-1], + gram_nngp)).reshape(input_shape + (-1,)) + + ntk = ntk_feat_2d @ ntk_feat_2d.T + kappa0_mat = np.polyval(kappa0_coeff[::-1], gram_nngp) + ntk_feat = _cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + + elif method == ReluFeaturesMethod.EXACT: # Exact feature map computations via Cholesky decomposition. + nngp_feat = _cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) + + ntk = ntk_feat_2d @ ntk_feat_2d.T + kappa0_mat = kappa0(nngp_feat_2d) + ntk_feat = _cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + + else: + raise NotImplementedError(f'Invalid method name: {method}') + + if method != ReluFeaturesMethod.RANDFEAT: + norms /= 2.0**0.5 + + # if np.any(np.isnan(nngp_feat)) or np.any(np.isnan(ntk_feat)): + # print("2. 
nan is observed") + # import pdb; pdb.set_trace(); + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) + + return init_fn, feature_fn + + def _cholesky(mat): return np.linalg.cholesky(mat + 1e-8 * np.eye(mat.shape[0])) diff --git a/experimental/sketching.py b/experimental/sketching.py index 0b8ab858..25a072f3 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -118,11 +118,11 @@ def sketch(self, x): n = x.shape[0] log_degree = len(self.tree_rand_signs) V = [0 for i in range(log_degree)] + dtype = np.complex64 if x.real.dtype == np.float32 else np.complex128 for i in range(log_degree): deg = self.tree_rand_signs[i].shape[0] - V[i] = np.zeros((deg, n, self.tree_rand_inds[i].shape[2]), - dtype=np.complex64) + V[i] = np.zeros((deg, n, self.tree_rand_inds[i].shape[2]), dtype=dtype) for j in range(deg): if i == 0: V[i] = V[i].at[j, :, :].set( @@ -166,7 +166,9 @@ def sketch(self, x): def expand_feats(self, polysketch_feats, coeffs): n, sktch_dim = polysketch_feats[0].shape - Z = np.zeros((len(self.rand_signs), n), dtype=np.complex64) + dtype = np.complex64 if polysketch_feats[ + 0].real.dtype == np.float32 else np.complex128 + Z = np.zeros((len(self.rand_signs), n), dtype=dtype) Z = Z.at[0, :].set(np.sqrt(coeffs[0]) * np.ones(n)) degree = len(polysketch_feats) for i in range(degree): From 50212c281174d5c2cdf31717679c378ad907a3b3 Mon Sep 17 00:00:00 2001 From: insuhan Date: Tue, 10 May 2022 08:57:36 +0900 Subject: [PATCH 33/44] recover previous commit --- experimental/features.py | 349 --------------------------------------- 1 file changed, 349 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index e5d8c601..c33347ef 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -395,355 +395,6 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: return init_fn, feature_fn -@layer -def ReluFeatures2(method: str = 'RANDFEAT', - feature_dim0: int = 1, - feature_dim1: int = 1, - sketch_dim: int = 1, - poly_degree: int = 8, - poly_sketch_dim: int = 1, - generate_rand_mtx: bool = True): - - method = ReluFeaturesMethod(method.upper()) - - def init_fn(rng, input_shape): - nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) - net_shape = input_shape[2] - relu_layers_count = net_shape.count('R') - new_net_shape = net_shape + 'R' - - if method == ReluFeaturesMethod.POLYSKETCH: - new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) - rng1, rng2 = random.split(rng, 2) - - kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) - kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) - - # PolySketch expansion for nngp features. - polysketch = PolyTensorSketch(rng=rng1, - input_dim=nngp_feat_shape[-1] // - (1 + (relu_layers_count > 0)), - sketch_dim=poly_sketch_dim, - degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args - - # TensorSRHT of degree 2 for approximating tensor product. 
- tensorsrht = TensorSRHT( - rng=rng2, - input_dim1=ntk_feat_shape[-1] // (1 + (relu_layers_count > 0)), - input_dim2=poly_degree * (polysketch.sketch_dim // 4 - 1) + 1, - sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args - - return (new_nngp_feat_shape, new_ntk_feat_shape, - new_net_shape), (polysketch, tensorsrht, kappa0_coeff, - kappa1_coeff) - - else: - raise NotImplementedError(f'Invalid method name: {method}') - - def feature_fn(f: Features, input=None, **kwargs) -> Features: - - input_shape: tuple = f.nngp_feat.shape[:-1] - nngp_feat_dim: tuple = f.nngp_feat.shape[-1] - ntk_feat_dim: tuple = f.ntk_feat.shape[-1] - - nngp_feat_2d: np.ndarray = f.nngp_feat.reshape(-1, nngp_feat_dim) - ntk_feat_2d: np.ndarray = f.ntk_feat.reshape(-1, ntk_feat_dim) - norms: np.ndarray = f.norms - - if method == ReluFeaturesMethod.POLYSKETCH: - polysketch: PolyTensorSketch = input[0] - tensorsrht: TensorSRHT = input[1] - kappa0_coeff: np.ndarray = input[2] - kappa1_coeff: np.ndarray = input[3] - - # Apply PolySketch to approximate feature maps of kappa0 & kappa1 kernels. - polysketch_feats = polysketch.sketch(nngp_feat_2d) - kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) - kappa0_feat = polysketch.expand_feats(polysketch_feats, kappa0_coeff) - - # Apply SRHT to kappa1_feat so that dimension of nngp_feat is poly_sketch_dim//2. - nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + - (-1,)) - # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. - ntk_feat = tensorsrht.sketch(ntk_feat_2d, - kappa0_feat).reshape(input_shape + (-1,)) - - else: - raise NotImplementedError(f'Invalid method name: {method}') - - if method != ReluFeaturesMethod.RANDFEAT: - norms /= 2.0**0.5 - - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) - - return init_fn, feature_fn - - -import jax -import numpy as onp -@layer -def NormGaussFeatures(method: str = 'RANDFEAT', - feature_dim0: int = 1, - feature_dim1: int = 1, - sketch_dim: int = 1, - poly_degree: int = 8, - poly_sketch_dim: int = 1, - generate_rand_mtx: bool = True): - - method = ReluFeaturesMethod(method.upper()) - - def init_fn(rng, input_shape): - nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) - net_shape = input_shape[2] - relu_layers_count = net_shape.count('R') - new_net_shape = net_shape + 'R' - - # print(input_shape) - # print(relu_layers_count) - - if method == ReluFeaturesMethod.RANDFEAT: - rng1, rng2, rng3 = random.split(rng, 3) - if generate_rand_mtx: - # Random vectors for random features of arc-cosine kernel of order 0. - W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) - # Random vectors for random features of arc-cosine kernel of order 1. - W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) - else: - # if `generate_rand_mtx` is False, return random seeds and shapes instead of np.ndarray. - W0 = (rng1, (nngp_feat_shape[-1], feature_dim0)) - W1 = (rng2, (nngp_feat_shape[-1], feature_dim1)) - - # TensorSRHT of degree 2 for approximating tensor product. 
- tensorsrht = TensorSRHT(rng=rng3, - input_dim1=ntk_feat_shape[-1], - input_dim2=feature_dim0, - sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args - - return (new_nngp_feat_shape, new_ntk_feat_shape, - new_net_shape), (W0, W1, tensorsrht) - - elif method == ReluFeaturesMethod.POLYSKETCH: - new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) - rng1, rng2 = random.split(rng, 2) - - kappa1_coeff = np.exp(-1) * np.array([1/np.exp(jax.scipy.special.gammaln(i+1)) for i in range(poly_degree+1)]) - kappa0_coeff = np.exp(-1) * np.array([1/np.exp(jax.scipy.special.gammaln(i+1)) for i in range(poly_degree+1)]) - - # PolySketch expansion for nngp features. - polysketch = PolyTensorSketch(rng=rng1, - input_dim=nngp_feat_shape[-1] // - (1 + (relu_layers_count > 0)), - sketch_dim=poly_sketch_dim, - degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args - - # TensorSRHT of degree 2 for approximating tensor product. - tensorsrht = TensorSRHT( - rng=rng2, - input_dim1=ntk_feat_shape[-1] // (1 + (relu_layers_count > 0)), - input_dim2=poly_degree * (polysketch.sketch_dim // 4 - 1) + 1, - sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args - - # print(f"new_nngp_feat_shape : {new_nngp_feat_shape}") - # print(f"new_ntk_feat_shape : {new_ntk_feat_shape}") - return (new_nngp_feat_shape, new_ntk_feat_shape, - new_net_shape), (polysketch, tensorsrht, kappa0_coeff, - kappa1_coeff) - - elif method == ReluFeaturesMethod.PSRF: - new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,) - rng1, rng2, rng3 = random.split(rng, 3) - - kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) - - # PolySketch expansion for nngp features. - polysketch = PolyTensorSketch(rng=rng1, - input_dim=nngp_feat_shape[-1] // - (1 + (relu_layers_count > 0)), - sketch_dim=poly_sketch_dim, - degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args - - # TensorSRHT of degree 2 for approximating tensor product. - tensorsrht = TensorSRHT(rng=rng2, - input_dim1=ntk_feat_shape[-1] // - (1 + (relu_layers_count > 0)), - input_dim2=feature_dim0, - sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args - - # Random vectors for random features of arc-cosine kernel of order 0. - if relu_layers_count == 0: - W0 = random.normal(rng3, (2 * nngp_feat_shape[-1], feature_dim0 // 2)) - else: - W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0 // 2)) - - return (new_nngp_feat_shape, new_ntk_feat_shape, - new_net_shape), (W0, polysketch, tensorsrht, kappa1_coeff) - - elif method == ReluFeaturesMethod.POLY: - # This only uses the polynomial approximation without sketching. - new_nngp_feat_shape = nngp_feat_shape[:-1] + (prod(nngp_feat_shape[:-1]),) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (prod(ntk_feat_shape[:-1]),) - - kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) - kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) - - return (new_nngp_feat_shape, new_ntk_feat_shape, - new_net_shape), (kappa0_coeff, kappa1_coeff) - - elif method == ReluFeaturesMethod.EXACT: - # The exact feature map computation is for debug. 
- new_nngp_feat_shape = nngp_feat_shape[:-1] + (prod(nngp_feat_shape[:-1]),) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (prod(ntk_feat_shape[:-1]),) - - return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), () - - else: - raise NotImplementedError(f'Invalid method name: {method}') - - def feature_fn(f: Features, input=None, **kwargs) -> Features: - - input_shape: tuple = f.nngp_feat.shape[:-1] - nngp_feat_dim: tuple = f.nngp_feat.shape[-1] - ntk_feat_dim: tuple = f.ntk_feat.shape[-1] - - nngp_feat_2d: np.ndarray = f.nngp_feat.reshape(-1, nngp_feat_dim) - ntk_feat_2d: np.ndarray = f.ntk_feat.reshape(-1, ntk_feat_dim) - norms: np.ndarray = f.norms - - # if np.any(np.isnan(nngp_feat_2d)) or np.any(np.isnan(ntk_feat_2d)): - # print("1. nan is observed") - # import pdb; pdb.set_trace(); - - if method == ReluFeaturesMethod.RANDFEAT: # Random Features approach. - if generate_rand_mtx: - W0: np.ndarray = input[0] - W1: np.ndarray = input[1] - else: - W0 = random.normal(input[0][0], shape=input[0][1]) - W1 = random.normal(input[1][0], shape=input[1][1]) - tensorsrht: TensorSRHT = input[2] - - kappa0_feat = (nngp_feat_2d @ W0 > 0) / W0.shape[-1]**0.5 - del W0 - nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / - W1.shape[-1]**0.5).reshape(input_shape + (-1,)) - del W1 - ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat, - real_output=True).reshape(input_shape + - (-1,)) - - elif method == ReluFeaturesMethod.POLYSKETCH: - polysketch: PolyTensorSketch = input[0] - tensorsrht: TensorSRHT = input[1] - kappa0_coeff: np.ndarray = input[2] - kappa1_coeff: np.ndarray = input[3] - - # Apply PolySketch to approximate feature maps of kappa0 & kappa1 kernels. - polysketch_feats = polysketch.sketch(nngp_feat_2d) - kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) - kappa0_feat = polysketch.expand_feats(polysketch_feats, kappa0_coeff) - # del polysketch_feats - - # Apply SRHT to kappa1_feat so that dimension of nngp_feat is poly_sketch_dim//2. - nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + - (-1,)) - # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. 
- ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat).reshape(input_shape + (-1,)) - # print(f"ntk_feat_2d avg norm : {np.sum(ntk_feat_2d.real**2 + ntk_feat_2d.imag**2, axis=-1).mean()}") - # print(f"kappa0_feat avg norm : {np.sum(kappa0_feat.real**2 + kappa0_feat.imag**2, axis=-1).mean()}") - # print(f"ntk_feat avg norm : {np.sum(ntk_feat.real**2 + ntk_feat.imag**2, axis=-1).mean()}") - - # aa = np.fft.fftn(ntk_feat_2d * tensorsrht.rand_signs1, axes=(-1,))[:, tensorsrht.rand_inds1] - # bb = np.fft.fftn( (kappa0_feat * tensorsrht.rand_signs2), axes=(-1,))[:, tensorsrht.rand_inds2] - # bb = np.fft.fftn(kappa0_feat * tensorsrht.rand_signs2, axes=(-1,))[:, tensorsrht.rand_inds2] - # ntk_feat = np.sqrt(1 / tensorsrht.rand_inds1.shape[-1]) * (aa * bb).reshape(input_shape + (-1,)) - # if np.any(np.isnan(ntk_feat)) or np.any(np.isnan(nngp_feat)): - # kk = kappa0_feat * tensorsrht.rand_signs2 - # np.sum(kk.real **2 + kk.imag**2, axis=-1) - # bk = np.fft.fftn(kk, axes=(-1,)) - # np.sum(bk.real **2 + bk.imag**2, axis=-1) - # qq = np.fft.fftn(kappa0_feat * tensorsrht.rand_signs2, axes=(-1,)) - # np.sum(qq.real **2 + qq.imag**2, axis=-1) - # bb = np.fft.fftn( (kappa0_feat * tensorsrht.rand_signs2), axes=(-1,)) - # np.sum(bb.real**2 + bb.imag**2, axis=-1) - # np.sum(ntk_feat_2d.real ** 2 + ntk_feat_2d.imag**2, axis=-1) - # np.sum(kappa0_feat.real ** 2 + kappa0_feat.imag**2, axis=-1) - # np.sum(kk.real ** 2 + kk.imag**2, axis=-1) - - # np.sum(kappa0_feat.real ** 2 + kappa0_feat.imag**2, axis=-1) - - # cc = kappa0_feat * tensorsrht.rand_signs2 - # ee = np.fft.fftn(cc, axes=(-1,)) - # enorm = np.sum(ee.real**2 + ee.imag**2, axis=-1) - - # print(cc.sum()) - # print("1.5. nan is observed") - # import pdb; pdb.set_trace(); - # else: - # print("sdfdsfs") - # print(cc.sum()) - # import pdb; pdb.set_trace(); - - elif method == ReluFeaturesMethod.PSRF: # Combination of PolySketch and Random Features. - W0: np.ndarray = input[0] - polysketch: PolyTensorSketch = input[1] - tensorsrht: TensorSRHT = input[2] - kappa1_coeff: np.ndarray = input[3] - - polysketch_feats = polysketch.sketch(nngp_feat_2d) - kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) - del polysketch_feats - - nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + - (-1,)) - - nngp_proj = np.concatenate( - (nngp_feat_2d.real, nngp_feat_2d.imag), axis=1) @ W0 - kappa0_feat = np.concatenate( - ((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / W0.shape[-1]**0.5 - del W0 - - # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. - ntk_feat = tensorsrht.sketch(ntk_feat_2d, - kappa0_feat).reshape(input_shape + (-1,)) - - elif method == ReluFeaturesMethod.POLY: # Polynomial approximation without sketching. - kappa0_coeff: np.ndarray = input[0] - kappa1_coeff: np.ndarray = input[1] - - gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) - nngp_feat = _cholesky(np.polyval(kappa1_coeff[::-1], - gram_nngp)).reshape(input_shape + (-1,)) - - ntk = ntk_feat_2d @ ntk_feat_2d.T - kappa0_mat = np.polyval(kappa0_coeff[::-1], gram_nngp) - ntk_feat = _cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) - - elif method == ReluFeaturesMethod.EXACT: # Exact feature map computations via Cholesky decomposition. 
- nngp_feat = _cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) - - ntk = ntk_feat_2d @ ntk_feat_2d.T - kappa0_mat = kappa0(nngp_feat_2d) - ntk_feat = _cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) - - else: - raise NotImplementedError(f'Invalid method name: {method}') - - if method != ReluFeaturesMethod.RANDFEAT: - norms /= 2.0**0.5 - - # if np.any(np.isnan(nngp_feat)) or np.any(np.isnan(ntk_feat)): - # print("2. nan is observed") - # import pdb; pdb.set_trace(); - - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) - - return init_fn, feature_fn - - def _cholesky(mat): return np.linalg.cholesky(mat + 1e-8 * np.eye(mat.shape[0])) From 52aacb54d25c388aa7fe2a17d74a36564cc2a3c4 Mon Sep 17 00:00:00 2001 From: insuhan Date: Thu, 2 Jun 2022 00:40:56 +0900 Subject: [PATCH 34/44] Add bias term for DenseFeatures --- experimental/features.py | 46 +++++---- experimental/sketching.py | 26 +++--- experimental/tests/fc_ntk_test.py | 108 +++++++++++----------- experimental/tests/kernel_approx_test.py | 8 +- experimental/tests/myrtle_network_test.py | 31 ++++--- experimental/tests/sketching_test.py | 4 +- 6 files changed, 123 insertions(+), 100 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index c33347ef..44452870 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -66,7 +66,7 @@ def _is_sinlge_shape(input_shape): if all(isinstance(n, int) for n in input_shape): return True elif (len(input_shape) == 2 or len(input_shape) == 3) and all( - _is_sinlge_shape(s) for s in input_shape[:2]): + _is_sinlge_shape(s) for s in input_shape[:2]): return False raise ValueError(input_shape) @@ -82,7 +82,7 @@ def feature_fn_x(x, input, **kwargs): feature = _inputs_to_features(x, **kwargs) return feature_fn(feature, input, **kwargs) - def feature_fn_any(x_or_feature, input=None, **kwargs): + def feature_fn_any(x_or_feature, input, **kwargs): if isinstance(x_or_feature, Features): return feature_fn_feature(x_or_feature, input, **kwargs) return feature_fn_x(x_or_feature, input, **kwargs) @@ -98,8 +98,8 @@ def _inputs_to_features(x: np.ndarray, # Followed the same initialization of Neural Tangents library. nngp_feat = x / x.shape[channel_axis]**0.5 - norms = np.linalg.norm(nngp_feat, axis=channel_axis) - norms = np.expand_dims(np.where(norms > 0, norms, 1.0), channel_axis) + norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) + norms = np.where(norms > 0, norms, 1.0) nngp_feat = nngp_feat / norms ntk_feat = np.zeros((), dtype=nngp_feat.dtype) @@ -145,10 +145,6 @@ def DenseFeatures(out_dim: int, batch_axis: int = 0, channel_axis: int = -1): - if b_std is not None: - raise NotImplementedError('Bias variable b_std is not implemented yet .' 
- ' Please set b_std to be None.') - if parameterization != 'ntk': raise NotImplementedError(f'Parameterization ({parameterization}) is ' ' not implemented yet.') @@ -168,7 +164,23 @@ def feature_fn(f: Features, input, **kwargs): ntk_feat: np.ndarray = f.ntk_feat norms: np.ndarray = f.norms - norms *= W_std + if b_std is not None: + f_renomalized: Features = _unnormalize_features(f) + nngp_feat: np.ndarray = f_renomalized.nngp_feat + ntk_feat: np.ndarray = f_renomalized.ntk_feat + + biases = b_std * np.ones((nngp_feat.shape[0], 1), dtype=nngp_feat.dtype) + nngp_feat = np.concatenate((W_std * nngp_feat, biases), axis=-1) + ntk_feat = W_std * ntk_feat + + norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) + norms = np.where(norms > 0, norms, 1.0) + + nngp_feat = nngp_feat / norms + ntk_feat = ntk_feat / norms + + else: + norms *= W_std if ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat @@ -371,7 +383,7 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) nngp_feat = _cholesky(np.polyval(kappa1_coeff[::-1], - gram_nngp)).reshape(input_shape + (-1,)) + gram_nngp)).reshape(input_shape + (-1,)) ntk = ntk_feat_2d @ ntk_feat_2d.T kappa0_mat = np.polyval(kappa0_coeff[::-1], gram_nngp) @@ -519,8 +531,8 @@ def feature_fn(f, input, **kwargs): ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) # Re-normalize the features. - norms = norms = np.linalg.norm(nngp_feat, axis=channel_axis) - norms = np.expand_dims(np.where(norms > 0, norms, 1.0), channel_axis) + norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) + norms = np.where(norms > 0, norms, 1.0) nngp_feat = nngp_feat / norms ntk_feat = ntk_feat / norms @@ -601,8 +613,8 @@ def feature_fn(f, input=None, **kwargs): ntk_feat = pooling(ntk_feat) # Re-normalize the features. - norms = norms = np.linalg.norm(nngp_feat, axis=channel_axis) - norms = np.expand_dims(np.where(norms > 0, norms, 1.0), channel_axis) + norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) + norms = np.where(norms > 0, norms, 1.0) nngp_feat = nngp_feat / norms ntk_feat = ntk_feat / norms @@ -613,7 +625,7 @@ def feature_fn(f, input=None, **kwargs): return init_fn, feature_fn - +# TODO(insu): fix reshaping features for general batch/channel axes. @layer def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): @@ -649,8 +661,8 @@ def feature_fn(f, input=None, **kwargs): ntk_feat.shape[1:-1])**0.5 # Re-normalize the features. 
- norms = norms = np.linalg.norm(nngp_feat, axis=-1) - norms = np.expand_dims(np.where(norms > 0, norms, 1.0), -1) + norms = np.linalg.norm(nngp_feat, axis=-1, keepdims=True) + norms = np.where(norms > 0, norms, 1.0) nngp_feat = nngp_feat / norms ntk_feat = ntk_feat / norms diff --git a/experimental/sketching.py b/experimental/sketching.py index 25a072f3..124ce316 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -24,8 +24,8 @@ class TensorSRHT: def init_sketches(self) -> 'TensorSRHT': rng1, rng2, rng3, rng4 = random.split(self.rng, 4) - rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1 - rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1 + rand_signs1 = random.bernoulli(rng1, shape=(self.input_dim1,)) * 2 - 1 + rand_signs2 = random.bernoulli(rng2, shape=(self.input_dim2,)) * 2 - 1 rand_inds1 = random.choice(rng3, self.input_dim1, shape=(self.sketch_dim // 2,)) @@ -42,7 +42,7 @@ def init_sketches(self) -> 'TensorSRHT': def sketch(self, x1, x2, real_output=False): x1fft = np.fft.fftn(x1 * self.rand_signs1, axes=(-1,))[:, self.rand_inds1] x2fft = np.fft.fftn(x2 * self.rand_signs2, axes=(-1,))[:, self.rand_inds2] - out = np.sqrt(1 / self.rand_inds1.shape[-1]) * (x1fft * x2fft) + out = self.rand_inds1.shape[-1]**(-0.5) * (x1fft * x2fft) return np.concatenate((out.real, out.imag), 1) if real_output else out @@ -76,22 +76,22 @@ def init_sketches(self) -> 'PolyTensorSketch': rng1, rng2 = random.split(rng1) if i == 0: - tree_rand_signs[i] = random.choice( - rng1, 2, shape=(deg_, 2, self.input_dim)) * 2 - 1 + tree_rand_signs[i] = random.bernoulli( + rng1, shape=(deg_, 2, self.input_dim)) * 2 - 1 tree_rand_inds[i] = random.choice(rng2, self.input_dim, shape=(deg_, 2, ske_dim_)) else: - tree_rand_signs[i] = random.choice(rng1, 2, - shape=(deg_, 2, ske_dim_)) * 2 - 1 + tree_rand_signs[i] = random.bernoulli(rng1, + shape=(deg_, 2, ske_dim_)) * 2 - 1 tree_rand_inds[i] = random.choice(rng2, ske_dim_, shape=(deg_, 2, ske_dim_)) deg_ = deg_ // 2 rng1, rng2 = random.split(rng3, 2) - rand_signs = random.choice(rng1, 2, - shape=(1 + self.degree * ske_dim_,)) * 2 - 1 + rand_signs = random.bernoulli(rng1, + shape=(1 + self.degree * ske_dim_,)) * 2 - 1 rand_inds = random.choice(rng2, 1 + self.degree * ske_dim_, shape=(self.sketch_dim // 2,)) @@ -105,14 +105,14 @@ def init_sketches(self) -> 'PolyTensorSketch': def tensorsrht(self, x1, x2, rand_inds, rand_signs): x1fft = np.fft.fftn(x1 * rand_signs[0, :], axes=(-1,))[:, rand_inds[0, :]] x2fft = np.fft.fftn(x2 * rand_signs[1, :], axes=(-1,))[:, rand_inds[1, :]] - return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft) + return rand_inds.shape[1]**(-0.5) * (x1fft * x2fft) # Standard SRHT def standardsrht(self, x, rand_inds=None, rand_signs=None): rand_inds = self.rand_inds if rand_inds is None else rand_inds rand_signs = self.rand_signs if rand_signs is None else rand_signs xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds] - return np.sqrt(1 / rand_inds.shape[0]) * xfft + return rand_inds.shape[0]**(-0.5) * xfft def sketch(self, x): n = x.shape[0] @@ -169,11 +169,11 @@ def expand_feats(self, polysketch_feats, coeffs): dtype = np.complex64 if polysketch_feats[ 0].real.dtype == np.float32 else np.complex128 Z = np.zeros((len(self.rand_signs), n), dtype=dtype) - Z = Z.at[0, :].set(np.sqrt(coeffs[0]) * np.ones(n)) + Z = Z.at[0, :].set(coeffs[0]**0.5 * np.ones(n)) degree = len(polysketch_feats) for i in range(degree): Z = Z.at[sktch_dim * i + 1:sktch_dim * (i + 1) + 1, :].set( - np.sqrt(coeffs[i 
+ 1]) * polysketch_feats[degree - i - 1].T) + coeffs[i + 1]**0.5 * polysketch_feats[degree - i - 1].T) return Z.T # pytype: enable=attribute-error \ No newline at end of file diff --git a/experimental/tests/fc_ntk_test.py b/experimental/tests/fc_ntk_test.py index dc967640..a6e3372a 100644 --- a/experimental/tests/fc_ntk_test.py +++ b/experimental/tests/fc_ntk_test.py @@ -21,13 +21,15 @@ width = 512 # this does not matter the output W_std = 1.234 # std of Gaussian random weights +b_std = 0.567 # std of the biases +dense_kwargs = {"out_dim": width, "W_std": W_std, "b_std": b_std} print("================== Result of Neural Tangent Library ===================") -init_fn, _, kernel_fn = stax.serial(stax.Dense(width, W_std=W_std), stax.Relu(), - stax.Dense(width, W_std=W_std), stax.Relu(), - stax.Dense(width, W_std=W_std), stax.Relu(), - stax.Dense(1, W_std=W_std)) +init_fn, _, kernel_fn = stax.serial(stax.Dense(**dense_kwargs), stax.Relu(), + stax.Dense(**dense_kwargs), stax.Relu(), + stax.Dense(**dense_kwargs), stax.Relu(), + stax.Dense(1, W_std=W_std, b_std=b_std)) kernel_fn = kernel_fn if no_jitting else jit(kernel_fn) nt_kernel = kernel_fn(x, None) @@ -50,10 +52,10 @@ def test_fc_relu_ntk_approx(relufeat_arg, init_fn=None, feature_fn=None): if init_fn is None or feature_fn is None: init_fn, feature_fn = serial( - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(width, W_std=W_std), ReluFeatures(**relufeat_arg), - DenseFeatures(1, W_std=W_std)) + DenseFeatures(**dense_kwargs), ReluFeatures(**relufeat_arg), + DenseFeatures(**dense_kwargs), ReluFeatures(**relufeat_arg), + DenseFeatures(**dense_kwargs), ReluFeatures(**relufeat_arg), + DenseFeatures(1, W_std=W_std, b_std=b_std)) # Initialize random vectors and sketching algorithms _, feat_fn_inputs = init_fn(key2, x.shape) @@ -89,64 +91,64 @@ def test_fc_relu_ntk_approx(relufeat_arg, init_fn=None, feature_fn=None): -print("==================== Result of NTK Random Features ====================") +# print("==================== Result of NTK Random Features ====================") -kappa0_feat_dim = 4096 -kappa1_feat_dim = 4096 -sketch_dim = 4096 +# kappa0_feat_dim = 4096 +# kappa1_feat_dim = 4096 +# sketch_dim = 4096 -test_fc_relu_ntk_approx({ - 'method': 'RANDFEAT', - 'feature_dim0': kappa0_feat_dim, - 'feature_dim1': kappa1_feat_dim, - 'sketch_dim': sketch_dim, -}) +# test_fc_relu_ntk_approx({ +# 'method': 'RANDFEAT', +# 'feature_dim0': kappa0_feat_dim, +# 'feature_dim1': kappa1_feat_dim, +# 'sketch_dim': sketch_dim, +# }) -print("==================== Result of NTK wih PolySketch ====================") +# print("==================== Result of NTK wih PolySketch ====================") -poly_degree = 4 -poly_sketch_dim = 4096 -sketch_dim = 4096 +# poly_degree = 4 +# poly_sketch_dim = 4096 +# sketch_dim = 4096 -test_fc_relu_ntk_approx({ - 'method': 'POLYSKETCH', - 'sketch_dim': sketch_dim, - 'poly_degree': poly_degree, - 'poly_sketch_dim': poly_sketch_dim -}) +# test_fc_relu_ntk_approx({ +# 'method': 'POLYSKETCH', +# 'sketch_dim': sketch_dim, +# 'poly_degree': poly_degree, +# 'poly_sketch_dim': poly_sketch_dim +# }) -print("=============== Result of PolySketch + Random Features ===============") +# print("=============== Result of PolySketch + Random Features ===============") -kappa0_feat_dim = 2048 -sketch_dim = 4096 -poly_degree = 4 -poly_sketch_dim = 4096 +# kappa0_feat_dim = 2048 +# sketch_dim = 4096 +# poly_degree = 4 +# poly_sketch_dim = 4096 
-test_fc_relu_ntk_approx({ - 'method': 'PSRF', - 'feature_dim0': kappa0_feat_dim, - 'sketch_dim': sketch_dim, - 'poly_degree': poly_degree, - 'poly_sketch_dim': poly_sketch_dim -}) +# test_fc_relu_ntk_approx({ +# 'method': 'PSRF', +# 'feature_dim0': kappa0_feat_dim, +# 'sketch_dim': sketch_dim, +# 'poly_degree': poly_degree, +# 'poly_sketch_dim': poly_sketch_dim +# }) -print("=========== Result of ReLU-NTK Sketch (one-pass sketching) ===========") +# print("=========== Result of ReLU-NTK Sketch (one-pass sketching) ===========") -relufeat_arg = { - 'num_layers': 3, - 'poly_degree': 32, - 'poly_sketch_dim': 4096, - 'W_std': W_std, -} +# relufeat_arg = { +# 'num_layers': 3, +# 'poly_degree': 32, +# 'poly_sketch_dim': 4096, +# 'W_std': W_std, +# } -init_fn, feature_fn = ReluNTKFeatures(**relufeat_arg) -test_fc_relu_ntk_approx(relufeat_arg, init_fn, feature_fn) +# init_fn, feature_fn = ReluNTKFeatures(**relufeat_arg) +# test_fc_relu_ntk_approx(relufeat_arg, init_fn, feature_fn) -print("======= (Debug) NTK Feature Maps with Polynomial Approximation =======") -print("\t(*No Sketching algorithm is applied.)") +# print("======= (Debug) NTK Feature Maps with Polynomial Approximation =======") +# print("\t(*No Sketching algorithm is applied.)") -test_fc_relu_ntk_approx({'method': 'POLY', 'poly_degree': 16}) +# test_fc_relu_ntk_approx({'method': 'POLY', 'poly_degree': 16}) -print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") +# print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") test_fc_relu_ntk_approx({'method': 'EXACT'}) diff --git a/experimental/tests/kernel_approx_test.py b/experimental/tests/kernel_approx_test.py index 9d1344b8..90c83c3e 100644 --- a/experimental/tests/kernel_approx_test.py +++ b/experimental/tests/kernel_approx_test.py @@ -81,7 +81,7 @@ def fc_relu_ntk_sketching(relufeat_arg, seed = 1 -n, d = 1000, 28 * 28 +n, d = 10, 32 no_jitting = False key1, key2, key3 = random.split(random.PRNGKey(seed), 3) @@ -124,9 +124,9 @@ def fc_relu_ntk_sketching(relufeat_arg, print() print("ReluNTKFeatures results:") relufeat_arg = { - 'num_layers': depth, - 'poly_degree': 16, - 'W_std': W_std, + 'num_layers': depth, + 'poly_degree': 16, + 'W_std': W_std, } for feat_dim in grad_feat_dims_all: diff --git a/experimental/tests/myrtle_network_test.py b/experimental/tests/myrtle_network_test.py index fd920936..e455e4f7 100644 --- a/experimental/tests/myrtle_network_test.py +++ b/experimental/tests/myrtle_network_test.py @@ -1,8 +1,9 @@ import time import sys -sys.path.append("./") +sys.path.insert(0, "./") import functools from numpy.linalg import norm +import jax from jax.config import config from jax import jit # Enable float64 for JAX @@ -16,6 +17,9 @@ layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]} +H, W, C = 4, 4, 3; num_final_avgpools = 0 +# H, W, C = 32, 32, 3; num_final_avgpools = 3 + def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0., width=1): activation_fn = stax.Relu() conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, padding='SAME') @@ -26,7 +30,7 @@ def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0., width=1): layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][1] layers += [stax.AvgPool((2, 2), strides=(2, 2))] layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][2] - layers += [stax.AvgPool((2, 2), strides=(2, 2))] * 3 + layers += [stax.AvgPool((2, 2), strides=(2, 2))] * num_final_avgpools layers += [stax.Flatten(), stax.Dense(1, W_std, b_std)] @@ -50,25 +54,26 @@ def 
MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), width=1, **relu_args): ConvFeatures(width, filter_shape=(3, 3), W_std=W_std), ReluFeatures(**relu_args) ] * layer_factor[depth][2] - layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] * 3 + layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] * num_final_avgpools layers += [FlattenFeatures(), DenseFeatures(1, W_std)] return serial(*layers) -def test_small_dataset(num_data=4, dataset='synthetic', depth=5, no_jitting=False): +def test_small_dataset(num_data=4, dataset='synthetic', depth=5, do_jit=True): print(f"dataset : {dataset}") key = random.PRNGKey(0) if dataset == 'synthetic': - H, W, C = 32, 32, 3 x = random.normal(key, shape=(num_data, H, W, C)) elif dataset in ['cifar10', 'cifar100']: from examples import datasets x = datasets.get_dataset('cifar10', do_flatten_and_normalize=False)[0] + if (H, W) != (32, 32): + x = jax.image.resize(x, (x.shape[0], H, W, 3), method='linear') mean_ = np.mean(x) std_ = np.std(x) x = (x - mean_) / std_ @@ -82,7 +87,7 @@ def test_small_dataset(num_data=4, dataset='synthetic', depth=5, no_jitting=Fals print("================= Result of Neural Tangent Library =================") _, _, kernel_fn = MyrtleNetwork(depth) - kernel_fn = kernel_fn if no_jitting else jit(kernel_fn) + kernel_fn = jit(kernel_fn) if do_jit else kernel_fn tic = time.time() nt_kernel = kernel_fn(x) @@ -108,10 +113,14 @@ def test_myrtle_network_approx(relufeat_arg): init_fn, feature_fn = MyrtleNetworkFeatures(depth, **relufeat_arg) # Initialize random vectors and sketching algorithms - _, feat_fn_inputs = init_fn(key2, x.shape) + # _, feat_fn_inputs = init_fn(key2, x.shape) # Transform input vectors to NNGP/NTK feature map - feature_fn = feature_fn if no_jitting else jit(feature_fn) + # feature_fn = jit(feature_fn) if do_jit else feature_fn + @jit + def feature_fn_joint(x, key): + _, feat_fn_inputs = init_fn(key, x.shape) + return feature_fn(x, feat_fn_inputs) tic = time.time() feats = feature_fn(x, feat_fn_inputs) @@ -193,6 +202,6 @@ def test_myrtle_network_approx(relufeat_arg): test_myrtle_network_approx({'method': 'EXACT'}) if __name__ == "__main__": - test_small_dataset(num_data=6, dataset='synthetic', depth=5, no_jitting=False) - test_small_dataset(num_data=6, dataset='cifar10', depth=5, no_jitting=False) - test_small_dataset(num_data=6, dataset='cifar100', depth=5, no_jitting=False) \ No newline at end of file + test_small_dataset(num_data=4, dataset='synthetic', depth=5) + # test_small_dataset(num_data=4, dataset='cifar10', depth=5) + # test_small_dataset(num_data=4, dataset='cifar100', depth=5) \ No newline at end of file diff --git a/experimental/tests/sketching_test.py b/experimental/tests/sketching_test.py index c7bf7a45..247019f4 100644 --- a/experimental/tests/sketching_test.py +++ b/experimental/tests/sketching_test.py @@ -16,8 +16,8 @@ rng = random.PRNGKey(1) x = random.normal(rng, shape=(n, d)) -norm_x = jnp.linalg.norm(x, axis=-1) -x_normalized = x / norm_x[:, None] +norm_x = jnp.linalg.norm(x, axis=-1, keepdims=True) +x_normalized = x / norm_x rng2 = random.PRNGKey(2) pts = PolyTensorSketch(rng=rng2, From 0995169cde1eaf84173bf10e1cd7056317ffa499 Mon Sep 17 00:00:00 2001 From: insuhan Date: Thu, 2 Jun 2022 01:12:41 +0900 Subject: [PATCH 35/44] Fix typo --- experimental/tests/myrtle_network_test.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/experimental/tests/myrtle_network_test.py b/experimental/tests/myrtle_network_test.py index e455e4f7..b3e69c49 100644 --- 
a/experimental/tests/myrtle_network_test.py +++ b/experimental/tests/myrtle_network_test.py @@ -113,14 +113,10 @@ def test_myrtle_network_approx(relufeat_arg): init_fn, feature_fn = MyrtleNetworkFeatures(depth, **relufeat_arg) # Initialize random vectors and sketching algorithms - # _, feat_fn_inputs = init_fn(key2, x.shape) + _, feat_fn_inputs = init_fn(key2, x.shape) # Transform input vectors to NNGP/NTK feature map - # feature_fn = jit(feature_fn) if do_jit else feature_fn - @jit - def feature_fn_joint(x, key): - _, feat_fn_inputs = init_fn(key, x.shape) - return feature_fn(x, feat_fn_inputs) + feature_fn = jit(feature_fn) if do_jit else feature_fn tic = time.time() feats = feature_fn(x, feat_fn_inputs) From 70f53fee4097771270607375ff96332bf16cfed3 Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 8 Jun 2022 01:46:49 +0900 Subject: [PATCH 36/44] Add bias in ConvFeatures --- experimental/features.py | 105 +++++-- experimental/poly_fitting.py | 2 +- experimental/sketching.py | 154 +++++---- experimental/tests/fc_ntk_test.py | 154 --------- experimental/tests/features_test.py | 361 ++++++++++++++++++++++ experimental/tests/kernel_approx_test.py | 141 --------- experimental/tests/myrtle_network_test.py | 203 ------------ experimental/tests/sketching_test.py | 79 +++-- 8 files changed, 554 insertions(+), 645 deletions(-) delete mode 100644 experimental/tests/fc_ntk_test.py create mode 100644 experimental/tests/features_test.py delete mode 100644 experimental/tests/kernel_approx_test.py delete mode 100644 experimental/tests/myrtle_network_test.py diff --git a/experimental/features.py b/experimental/features.py index 44452870..e4b00763 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -8,6 +8,7 @@ from neural_tangents import stax from neural_tangents._src.utils import dataclasses +from neural_tangents._src.utils.typing import Axes from neural_tangents._src.stax.linear import _pool_kernel, Padding, _get_dimension_numbers from neural_tangents._src.stax.linear import _Pooling as Pooling @@ -54,7 +55,9 @@ def _preprocess_init_fn(init_fn): def init_fn_any(rng, input_shape_any, **kwargs): if _is_sinlge_shape(input_shape_any): - input_shape = (input_shape_any, (-1, 0)) # Add a dummy shape for ntk_feat + # Add a dummy shape for ntk_feat + dummy_shape = (-1,) + (0,) * (len(input_shape_any) - 1) + input_shape = (input_shape_any, dummy_shape) return init_fn(rng, input_shape, **kwargs) else: return init_fn(rng, input_shape_any, **kwargs) @@ -66,7 +69,7 @@ def _is_sinlge_shape(input_shape): if all(isinstance(n, int) for n in input_shape): return True elif (len(input_shape) == 2 or len(input_shape) == 3) and all( - _is_sinlge_shape(s) for s in input_shape[:2]): + _is_sinlge_shape(s) for s in input_shape[:2]): return False raise ValueError(input_shape) @@ -151,15 +154,21 @@ def DenseFeatures(out_dim: int, def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_ntk_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_shape[-1] + - ntk_feat_shape[-1],) + + nngp_feat_dim = nngp_feat_shape[channel_axis] + if b_std is not None: + nngp_feat_dim += 1 + + new_nngp_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_dim,) + new_ntk_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_dim + + ntk_feat_shape[channel_axis],) if len(input_shape) > 2: - return (nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'D'), () + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'D'), () else: - return (nngp_feat_shape, new_ntk_feat_shape, 'D'), () + return 
(new_nngp_feat_shape, new_ntk_feat_shape, 'D'), () - def feature_fn(f: Features, input, **kwargs): + def feature_fn(f: Features, input=None, **kwargs): nngp_feat: np.ndarray = f.nngp_feat ntk_feat: np.ndarray = f.ntk_feat norms: np.ndarray = f.norms @@ -170,14 +179,14 @@ def feature_fn(f: Features, input, **kwargs): ntk_feat: np.ndarray = f_renomalized.ntk_feat biases = b_std * np.ones((nngp_feat.shape[0], 1), dtype=nngp_feat.dtype) - nngp_feat = np.concatenate((W_std * nngp_feat, biases), axis=-1) + nngp_feat = np.concatenate((W_std * nngp_feat, biases), axis=channel_axis) ntk_feat = W_std * ntk_feat norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) norms = np.where(norms > 0, norms, 1.0) nngp_feat = nngp_feat / norms - ntk_feat = ntk_feat / norms + ntk_feat = ntk_feat / norms if ntk_feat.ndim != 0 else ntk_feat else: norms *= W_std @@ -240,18 +249,25 @@ def init_fn(rng, input_shape): kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) # PolySketch expansion for nngp features. + if relu_layers_count == 0: + pts_input_dim = nngp_feat_shape[-1] + else: + pts_input_dim = int(nngp_feat_shape[-1] / 2 + 0.5) polysketch = PolyTensorSketch(rng=rng1, - input_dim=nngp_feat_shape[-1] // - (1 + (relu_layers_count > 0)), + input_dim=pts_input_dim, sketch_dim=poly_sketch_dim, degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args # TensorSRHT of degree 2 for approximating tensor product. - tensorsrht = TensorSRHT( - rng=rng2, - input_dim1=ntk_feat_shape[-1] // (1 + (relu_layers_count > 0)), - input_dim2=poly_degree * (polysketch.sketch_dim // 4 - 1) + 1, - sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args + if relu_layers_count == 0: + ts_input_dim = ntk_feat_shape[-1] + else: + ts_input_dim = int(ntk_feat_shape[-1] / 2 + 0.5) + tensorsrht = TensorSRHT(rng=rng2, + input_dim1=ts_input_dim, + input_dim2=poly_degree * + (polysketch.sketch_dim // 4 - 1) + 1, + sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), (polysketch, tensorsrht, kappa0_coeff, @@ -264,24 +280,31 @@ def init_fn(rng, input_shape): kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) # PolySketch expansion for nngp features. + if relu_layers_count == 0: + pts_input_dim = nngp_feat_shape[-1] + else: + pts_input_dim = int(nngp_feat_shape[-1] / 2 + 0.5) polysketch = PolyTensorSketch(rng=rng1, - input_dim=nngp_feat_shape[-1] // - (1 + (relu_layers_count > 0)), + input_dim=pts_input_dim, sketch_dim=poly_sketch_dim, degree=poly_degree).init_sketches() # pytype:disable=wrong-keyword-args # TensorSRHT of degree 2 for approximating tensor product. + if relu_layers_count == 0: + ts_input_dim = ntk_feat_shape[-1] + else: + ts_input_dim = int(ntk_feat_shape[-1] / 2 + 0.5) tensorsrht = TensorSRHT(rng=rng2, - input_dim1=ntk_feat_shape[-1] // - (1 + (relu_layers_count > 0)), + input_dim1=ts_input_dim, input_dim2=feature_dim0, sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args # Random vectors for random features of arc-cosine kernel of order 0. 
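       # (For unit-norm inputs at angle theta, a Gaussian projection w gives
       # P[sign(w @ x) == sign(w @ y)] = 1 - theta / pi = kappa0(x, y), which
       # is what the concatenated indicator features estimate.)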
if relu_layers_count == 0: - W0 = random.normal(rng3, (2 * nngp_feat_shape[-1], feature_dim0 // 2)) - else: W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0 // 2)) + else: + W0 = random.normal( + rng3, (int(nngp_feat_shape[-1] / 2 + 0.5), feature_dim0 // 2)) return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), (W0, polysketch, tensorsrht, kappa1_coeff) @@ -367,8 +390,7 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + (-1,)) - nngp_proj = np.concatenate( - (nngp_feat_2d.real, nngp_feat_2d.imag), axis=1) @ W0 + nngp_proj = nngp_feat_2d @ W0 kappa0_feat = np.concatenate( ((nngp_proj > 0), (nngp_proj <= 0)), axis=1) / W0.shape[-1]**0.5 del W0 @@ -470,10 +492,6 @@ def ConvFeatures(out_chan: int, dimension_numbers: Optional[Tuple[str, str, str]] = None, parameterization: str = 'ntk'): - if b_std is not None: - raise NotImplementedError('Bias variable b_std is not implemented yet .' - ' Please set b_std to be None.') - parameterization = parameterization.lower() if dimension_numbers is None: @@ -491,10 +509,14 @@ def ConvFeatures(out_chan: int, def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_nngp_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_shape[-1] * - patch_size,) + + nngp_feat_dim = nngp_feat_shape[channel_axis] * patch_size + if b_std is not None: + nngp_feat_dim += 1 + + new_nngp_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_dim,) new_ntk_feat_shape = nngp_feat_shape[:-1] + ( - (nngp_feat_shape[-1] + ntk_feat_shape[-1]) * patch_size,) + nngp_feat_dim + ntk_feat_shape[channel_axis] * patch_size,) if len(input_shape) > 2: return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'C'), () @@ -523,6 +545,11 @@ def feature_fn(f, input, **kwargs): nngp_feat = _concat_shifted_features_2d( nngp_feat, filter_shape_) * W_std / patch_size**0.5 + if b_std is not None: + biases = b_std * np.ones(nngp_feat.shape[:-1] + (1,), + dtype=nngp_feat.dtype) + nngp_feat = np.concatenate((nngp_feat, biases), axis=channel_axis) + if f.ntk_feat.ndim == 0: # check if ntk_feat is empty ntk_feat = nngp_feat else: @@ -625,6 +652,7 @@ def feature_fn(f, input=None, **kwargs): return init_fn, feature_fn + # TODO(insu): fix reshaping features for general batch/channel axes. 
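Supporting `b_std` in `ConvFeatures` (and `DenseFeatures`) above comes down to appending a constant `b_std` column to the un-normalized NNGP features, so every pairwise inner product gains an additive `b_std**2`, which is exactly the bias term of the corresponding kernel update. The identity in isolation, a minimal sketch with made-up shapes:

```python
import jax.numpy as np
import jax.random as random

nngp_feat = random.normal(random.PRNGKey(0), (3, 8))  # made-up feature block
W_std, b_std = 1.5, 0.567

biases = b_std * np.ones(nngp_feat.shape[:-1] + (1,), dtype=nngp_feat.dtype)
new_feat = np.concatenate((W_std * nngp_feat, biases), axis=-1)

# Pairwise inner products pick up the additive bias term:
#   <new_i, new_j> = W_std**2 * <f_i, f_j> + b_std**2
assert np.allclose(new_feat @ new_feat.T,
                   W_std**2 * (nngp_feat @ nngp_feat.T) + b_std**2)
```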
@layer def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): @@ -672,3 +700,18 @@ def feature_fn(f, input=None, **kwargs): is_reversed=False) return init_fn, feature_fn + + +@layer +def LayerNormFeatures(axis: Axes = -1, + eps: float = 1e-12, + batch_axis: int = 0, + channel_axis: int = -1): + + def init_fn(rng, input_shape): + return input_shape, () + + def feature_fn(f, input=None, **kwargs): + return f.replace(norms=np.ones_like(f.norms)) + + return init_fn, feature_fn \ No newline at end of file diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py index d985b2a3..06dc8c50 100644 --- a/experimental/poly_fitting.py +++ b/experimental/poly_fitting.py @@ -49,7 +49,7 @@ def poly_fitting_qp(xvals: np.ndarray, x_powers = x_powers.at[:, i + 1].set(x_powers[:, i] * xvals) y_weighted = fvals * weights - x_powers_weighted = x_powers.T * weights + x_powers_weighted = x_powers.T * weights[None,:] dx_powers = x_powers[:-1, :] - x_powers[1:, :] diff --git a/experimental/sketching.py b/experimental/sketching.py index 124ce316..9e34eabc 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -1,9 +1,16 @@ from jax import random from jax import numpy as np +from jax.numpy.fft import fftn from neural_tangents._src.utils import dataclasses from typing import Optional, Callable +def _random_signs_indices(rngs, input_dim, output_dim, shape=()): + rand_signs = random.bernoulli(rngs[0], shape=shape + (input_dim,)) * 2 - 1. + rand_inds = random.choice(rngs[1], input_dim, shape=shape + (output_dim,)) + return rand_signs, rand_inds + + # TensorSRHT of degree 2. This version allows different input vectors. @dataclasses.dataclass class TensorSRHT: @@ -24,14 +31,10 @@ class TensorSRHT: def init_sketches(self) -> 'TensorSRHT': rng1, rng2, rng3, rng4 = random.split(self.rng, 4) - rand_signs1 = random.bernoulli(rng1, shape=(self.input_dim1,)) * 2 - 1 - rand_signs2 = random.bernoulli(rng2, shape=(self.input_dim2,)) * 2 - 1 - rand_inds1 = random.choice(rng3, - self.input_dim1, - shape=(self.sketch_dim // 2,)) - rand_inds2 = random.choice(rng4, - self.input_dim2, - shape=(self.sketch_dim // 2,)) + rand_signs1, rand_inds1 = _random_signs_indices( + (rng1, rng3), self.input_dim1, self.sketch_dim // 2) + rand_signs2, rand_inds2 = _random_signs_indices( + (rng2, rng4), self.input_dim2, self.sketch_dim // 2) shape = (self.input_dim1, self.input_dim2, self.sketch_dim) return self.replace(shape=shape, rand_signs1=rand_signs1, @@ -40,13 +43,12 @@ def init_sketches(self) -> 'TensorSRHT': rand_inds2=rand_inds2) def sketch(self, x1, x2, real_output=False): - x1fft = np.fft.fftn(x1 * self.rand_signs1, axes=(-1,))[:, self.rand_inds1] - x2fft = np.fft.fftn(x2 * self.rand_signs2, axes=(-1,))[:, self.rand_inds2] - out = self.rand_inds1.shape[-1]**(-0.5) * (x1fft * x2fft) + x1fft = fftn(x1 * self.rand_signs1[None, :], axes=(-1,))[:, self.rand_inds1] + x2fft = fftn(x2 * self.rand_signs2[None, :], axes=(-1,))[:, self.rand_inds2] + out = (x1fft * x2fft) / self.rand_inds1.shape[-1]**0.5 return np.concatenate((out.real, out.imag), 1) if real_output else out -# pytype: disable=attribute-error @dataclasses.dataclass class PolyTensorSketch: @@ -64,37 +66,27 @@ class PolyTensorSketch: replace = ... 
# type: Callable[..., 'PolyTensorSketch'] def init_sketches(self) -> 'PolyTensorSketch': - - tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())] - tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())] + height = (self.degree - 1).bit_length() + tree_rand_signs = [0] * height + tree_rand_inds = [0] * height rng1, rng3 = random.split(self.rng, 2) - ske_dim_ = self.sketch_dim // 4 - 1 - deg_ = self.degree // 2 + internal_sketch_dim = self.sketch_dim // 4 - 1 + degree = self.degree // 2 - for i in range((self.degree - 1).bit_length()): + for lvl in range(height): rng1, rng2 = random.split(rng1) - if i == 0: - tree_rand_signs[i] = random.bernoulli( - rng1, shape=(deg_, 2, self.input_dim)) * 2 - 1 - tree_rand_inds[i] = random.choice(rng2, - self.input_dim, - shape=(deg_, 2, ske_dim_)) - else: - tree_rand_signs[i] = random.bernoulli(rng1, - shape=(deg_, 2, ske_dim_)) * 2 - 1 - tree_rand_inds[i] = random.choice(rng2, - ske_dim_, - shape=(deg_, 2, ske_dim_)) - deg_ = deg_ // 2 + input_dim = self.input_dim if lvl == 0 else internal_sketch_dim + tree_rand_signs[lvl], tree_rand_inds[lvl] = _random_signs_indices( + (rng1, rng2), input_dim, internal_sketch_dim, (degree, 2)) + + degree = degree // 2 rng1, rng2 = random.split(rng3, 2) - rand_signs = random.bernoulli(rng1, - shape=(1 + self.degree * ske_dim_,)) * 2 - 1 - rand_inds = random.choice(rng2, - 1 + self.degree * ske_dim_, - shape=(self.sketch_dim // 2,)) + rand_signs, rand_inds = _random_signs_indices( + (rng1, rng2), 1 + self.degree * internal_sketch_dim, + self.sketch_dim // 2) return self.replace(tree_rand_signs=tree_rand_signs, tree_rand_inds=tree_rand_inds, @@ -103,77 +95,71 @@ def init_sketches(self) -> 'PolyTensorSketch': # TensorSRHT of degree 2 def tensorsrht(self, x1, x2, rand_inds, rand_signs): - x1fft = np.fft.fftn(x1 * rand_signs[0, :], axes=(-1,))[:, rand_inds[0, :]] - x2fft = np.fft.fftn(x2 * rand_signs[1, :], axes=(-1,))[:, rand_inds[1, :]] + x1fft = fftn(x1 * rand_signs[:1, :], axes=(-1,))[:, rand_inds[0, :]] + x2fft = fftn(x2 * rand_signs[1:, :], axes=(-1,))[:, rand_inds[1, :]] return rand_inds.shape[1]**(-0.5) * (x1fft * x2fft) # Standard SRHT def standardsrht(self, x, rand_inds=None, rand_signs=None): rand_inds = self.rand_inds if rand_inds is None else rand_inds rand_signs = self.rand_signs if rand_signs is None else rand_signs - xfft = np.fft.fftn(x * rand_signs, axes=(-1,))[:, rand_inds] + xfft = fftn(x * rand_signs[None, :], axes=(-1,))[:, rand_inds] return rand_inds.shape[0]**(-0.5) * xfft def sketch(self, x): n = x.shape[0] - log_degree = len(self.tree_rand_signs) - V = [0 for i in range(log_degree)] dtype = np.complex64 if x.real.dtype == np.float32 else np.complex128 - for i in range(log_degree): - deg = self.tree_rand_signs[i].shape[0] - V[i] = np.zeros((deg, n, self.tree_rand_inds[i].shape[2]), dtype=dtype) - for j in range(deg): - if i == 0: - V[i] = V[i].at[j, :, :].set( - self.tensorsrht(x, x, self.tree_rand_inds[i][j, :, :], - self.tree_rand_signs[i][j, :, :])) + height = len(self.tree_rand_signs) + V = [np.zeros(())] * height + for lvl in range(height): + deg = self.tree_rand_signs[lvl].shape[0] + output_dim = self.tree_rand_inds[lvl].shape[2] + V[lvl] = np.zeros((deg, n, output_dim), dtype=dtype) + for j in range(deg): + if lvl == 0: + x1, x2 = x, x else: - V[i] = V[i].at[j, :, :].set( - self.tensorsrht(V[i - 1][2 * j, :, :], V[i - 1][2 * j + 1, :, :], - self.tree_rand_inds[i][j, :, :], - self.tree_rand_signs[i][j, :, :])) + x1, x2 = V[lvl - 1][2 * j, :, :], V[lvl - 1][2 * j + 1, 
:, :] + + V[lvl] = V[lvl].at[j, :, :].set( + self.tensorsrht(x1, x2, self.tree_rand_inds[lvl][j, :, :], + self.tree_rand_signs[lvl][j, :, :])) - U = [0 for i in range(2**log_degree)] - U[0] = V[log_degree - 1][0, :, :] + U = [np.zeros(())] * 2**height + U[0] = V[-1][0, :, :] SetE1 = set() - for j in range(1, len(U)): + for j in range(1, 2**height): p = (j - 1) // 2 - for i in range(log_degree): - if j % (2**(i + 1)) == 0: - SetE1.add((i, p)) + for lvl in range(height): + if j % (2**(lvl + 1)) == 0: + SetE1.add((lvl, p)) else: - if i == 0: - V[i] = V[i].at[p, :, :].set( - self.standardsrht(x, self.tree_rand_inds[i][p, 0, :], - self.tree_rand_signs[i][p, 0, :])) + if lvl == 0: + V[lvl] = V[lvl].at[p, :, :].set( + self.standardsrht(x, self.tree_rand_inds[lvl][p, 0, :], + self.tree_rand_signs[lvl][p, 0, :])) else: - if (i - 1, 2 * p) in SetE1: - V[i] = V[i].at[p, :, :].set(V[i - 1][2 * p + 1, :, :]) + if (lvl - 1, 2 * p) in SetE1: + V[lvl] = V[lvl].at[p, :, :].set(V[lvl - 1][2 * p + 1, :, :]) else: - V[i] = V[i].at[p, :, :].set( - self.tensorsrht(V[i - 1][2 * p, :, :], - V[i - 1][2 * p + 1, :, :], - self.tree_rand_inds[i][p, :, :], - self.tree_rand_signs[i][p, :, :])) + V[lvl] = V[lvl].at[p, :, :].set( + self.tensorsrht(V[lvl - 1][2 * p, :, :], + V[lvl - 1][2 * p + 1, :, :], + self.tree_rand_inds[lvl][p, :, :], + self.tree_rand_signs[lvl][p, :, :])) p = p // 2 - U[j] = V[log_degree - 1][0, :, :] + U[j] = V[-1][0, :, :] return U - def expand_feats(self, polysketch_feats, coeffs): - n, sktch_dim = polysketch_feats[0].shape - dtype = np.complex64 if polysketch_feats[ - 0].real.dtype == np.float32 else np.complex128 - Z = np.zeros((len(self.rand_signs), n), dtype=dtype) - Z = Z.at[0, :].set(coeffs[0]**0.5 * np.ones(n)) - degree = len(polysketch_feats) - for i in range(degree): - Z = Z.at[sktch_dim * i + 1:sktch_dim * (i + 1) + 1, :].set( - coeffs[i + 1]**0.5 * polysketch_feats[degree - i - 1].T) - - return Z.T -# pytype: enable=attribute-error \ No newline at end of file + def expand_feats(self, sketches, coeffs): + n = sketches[0].shape[0] + degree = len(sketches) + return np.concatenate( + [coeffs[0]**0.5 * np.ones((n, 1))] + + [coeffs[i + 1]**0.5 * sketches[-i - 1] for i in range(degree)], + axis=-1) \ No newline at end of file diff --git a/experimental/tests/fc_ntk_test.py b/experimental/tests/fc_ntk_test.py deleted file mode 100644 index a6e3372a..00000000 --- a/experimental/tests/fc_ntk_test.py +++ /dev/null @@ -1,154 +0,0 @@ -from jax import numpy as np -from jax import random -from jax.config import config -from jax import jit -import sys -sys.path.append("./") - -config.update("jax_enable_x64", True) -from neural_tangents import stax - -from experimental.features import DenseFeatures, ReluFeatures, serial, ReluNTKFeatures - - - -seed = 1 -n, d = 6, 5 -no_jitting = False - -key1, key2 = random.split(random.PRNGKey(seed)) -x = random.normal(key1, (n, d)) - -width = 512 # this does not matter the output -W_std = 1.234 # std of Gaussian random weights -b_std = 0.567 # std of the biases -dense_kwargs = {"out_dim": width, "W_std": W_std, "b_std": b_std} - -print("================== Result of Neural Tangent Library ===================") - -init_fn, _, kernel_fn = stax.serial(stax.Dense(**dense_kwargs), stax.Relu(), - stax.Dense(**dense_kwargs), stax.Relu(), - stax.Dense(**dense_kwargs), stax.Relu(), - stax.Dense(1, W_std=W_std, b_std=b_std)) - -kernel_fn = kernel_fn if no_jitting else jit(kernel_fn) -nt_kernel = kernel_fn(x, None) - -print("K_nngp :") -print(nt_kernel.nngp) -print() - 
-print("K_ntk :") -print(nt_kernel.ntk) -print() - - -def test_fc_relu_ntk_approx(relufeat_arg, init_fn=None, feature_fn=None): - - print(f"ReluFeatures params:") - for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") - print() - - if init_fn is None or feature_fn is None: - init_fn, feature_fn = serial( - DenseFeatures(**dense_kwargs), ReluFeatures(**relufeat_arg), - DenseFeatures(**dense_kwargs), ReluFeatures(**relufeat_arg), - DenseFeatures(**dense_kwargs), ReluFeatures(**relufeat_arg), - DenseFeatures(1, W_std=W_std, b_std=b_std)) - - # Initialize random vectors and sketching algorithms - _, feat_fn_inputs = init_fn(key2, x.shape) - - # Transform input vectors to NNGP/NTK feature map - feature_fn = feature_fn if no_jitting else jit(feature_fn) - feats = feature_fn(x, feat_fn_inputs) - - # PolySketch returns complex features. Convert complex features to real ones. - if np.iscomplexobj(feats.nngp_feat) or np.iscomplexobj(feats.ntk_feat): - nngp_feat = np.concatenate((feats.nngp_feat.real, feats.nngp_feat.imag), axis=-1) - ntk_feat = np.concatenate((feats.ntk_feat.real, feats.ntk_feat.imag), axis=-1) - feats = feats.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - - print(f"f_nngp shape: {feats.nngp_feat.shape}") - print(f"f_ntk shape: {feats.ntk_feat.shape}") - - print("K_nngp:") - print(feats.nngp_feat @ feats.nngp_feat.T) - print() - - print("K_ntk:") - print(feats.ntk_feat @ feats.ntk_feat.T) - print() - - print( - f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {np.linalg.norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" - ) - print( - f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {np.linalg.norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" - ) - print() - - - -# print("==================== Result of NTK Random Features ====================") - -# kappa0_feat_dim = 4096 -# kappa1_feat_dim = 4096 -# sketch_dim = 4096 - -# test_fc_relu_ntk_approx({ -# 'method': 'RANDFEAT', -# 'feature_dim0': kappa0_feat_dim, -# 'feature_dim1': kappa1_feat_dim, -# 'sketch_dim': sketch_dim, -# }) - -# print("==================== Result of NTK wih PolySketch ====================") - -# poly_degree = 4 -# poly_sketch_dim = 4096 -# sketch_dim = 4096 - -# test_fc_relu_ntk_approx({ -# 'method': 'POLYSKETCH', -# 'sketch_dim': sketch_dim, -# 'poly_degree': poly_degree, -# 'poly_sketch_dim': poly_sketch_dim -# }) - -# print("=============== Result of PolySketch + Random Features ===============") - -# kappa0_feat_dim = 2048 -# sketch_dim = 4096 -# poly_degree = 4 -# poly_sketch_dim = 4096 - -# test_fc_relu_ntk_approx({ -# 'method': 'PSRF', -# 'feature_dim0': kappa0_feat_dim, -# 'sketch_dim': sketch_dim, -# 'poly_degree': poly_degree, -# 'poly_sketch_dim': poly_sketch_dim -# }) - -# print("=========== Result of ReLU-NTK Sketch (one-pass sketching) ===========") - -# relufeat_arg = { -# 'num_layers': 3, -# 'poly_degree': 32, -# 'poly_sketch_dim': 4096, -# 'W_std': W_std, -# } - -# init_fn, feature_fn = ReluNTKFeatures(**relufeat_arg) -# test_fc_relu_ntk_approx(relufeat_arg, init_fn, feature_fn) - -# print("======= (Debug) NTK Feature Maps with Polynomial Approximation =======") -# print("\t(*No Sketching algorithm is applied.)") - -# test_fc_relu_ntk_approx({'method': 'POLY', 'poly_degree': 16}) - -# print("====== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ======") - -test_fc_relu_ntk_approx({'method': 'EXACT'}) diff --git a/experimental/tests/features_test.py b/experimental/tests/features_test.py new file mode 100644 index 00000000..4bf81299 --- /dev/null +++ 
b/experimental/tests/features_test.py @@ -0,0 +1,361 @@ +from absl.testing import absltest +from absl.testing import parameterized +import functools + +from jax import jit +from jax.config import config +import jax.numpy as np +import jax.random as random +from jax import test_util as jtu +from neural_tangents._src.utils import utils as ntutils +from neural_tangents import stax +from tests import test_utils + +from experimental.features import DenseFeatures, ReluFeatures, ConvFeatures, AvgPoolFeatures, FlattenFeatures, serial + +config.update("jax_enable_x64", True) +config.parse_flags_with_absl() +config.update('jax_numpy_rank_promotion', 'raise') + +NUM_DIMS = [64, 128, 256, 512] +WEIGHT_VARIANCES = [0.001, 0.01, 0.1, 1.] +BIAS_VARIANCES = [None, 0.001, 0.01, 0.1] +test_utils.update_test_tolerance() + + +class FeaturesTest(jtu.JaxTestCase): + + @classmethod + def _get_init_data(cls, rng, shape, normalized_output=False): + x = random.normal(rng, shape) + if normalized_output: + return x / np.linalg.norm(x, axis=-1, keepdims=True) + else: + return x + + @classmethod + def _convert_image_feature_to_kernel(cls, f_): + return ntutils.zip_axes(np.einsum("ijkc,xyzc->ijkxyz", f_, f_)) + + @parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': + ' [Wstd{}_bstd{}_{}layers_{}] '.format(W_std, b_std, n_layers, + 'jit' if do_jit else ''), + 'W_std': + W_std, + 'b_std': + b_std, + 'n_layers': + n_layers, + 'do_jit': + do_jit, + } for W_std in WEIGHT_VARIANCES for b_std in BIAS_VARIANCES + for n_layers in [1, 2, 3, 4] + for do_jit in [True, False])) + def testDenseFeatures(self, W_std, b_std, n_layers, do_jit): + n, d = 4, 256 + rng = random.PRNGKey(1) + x = self._get_init_data(rng, (n, d)) + + dense_args = {'out_dim': 1, 'W_std': W_std, 'b_std': b_std} + + kernel_fn = stax.serial(*[stax.Dense(**dense_args)] * n_layers)[2] + feature_fn = serial(*[DenseFeatures(**dense_args)] * n_layers)[1] + + if do_jit: + kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + + k = kernel_fn(x, None) + f = feature_fn(x, [()] * n_layers) + + self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T) + self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T) + + @parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': + ' [Wstd{}_bstd{}_numlayers{}_{}_{}] '.format( + W_std, b_std, n_layers, relu_method, 'jit' if do_jit else ''), + 'W_std': + W_std, + 'b_std': + b_std, + 'n_layers': + n_layers, + 'relu_method': + relu_method, + 'do_jit': + do_jit, + } for W_std in WEIGHT_VARIANCES for b_std in BIAS_VARIANCES + for relu_method in + ['RANDFEAT', 'POLYSKETCH', 'PSRF', 'POLY', 'EXACT'] + for n_layers in [1, 2, 3, 4] + for do_jit in [True, False])) + def test_fc_relu_nngp_ntk(self, W_std, b_std, n_layers, relu_method, do_jit): + rng = random.PRNGKey(1) + n, d = 4, 256 + x = self._get_init_data(rng, (n, d)) + + dense_args = {"out_dim": 1, "W_std": W_std, "b_std": b_std} + relu_args = {'method': relu_method} + if relu_method == 'RANDFEAT': + relu_args['feature_dim0'] = 4096 + relu_args['feature_dim1'] = 4096 + relu_args['sketch_dim'] = 4096 + elif relu_method == 'POLYSKETCH': + relu_args['poly_degree'] = 4 + relu_args['poly_sketch_dim'] = 4096 + relu_args['sketch_dim'] = 4096 + elif relu_method == 'PSRF': + relu_args['feature_dim0'] = 4096 + relu_args['poly_degree'] = 4 + relu_args['poly_sketch_dim'] = 4096 + relu_args['sketch_dim'] = 4096 + elif relu_method in ['EXACT', 'POLY']: + pass + else: + raise ValueError(relu_method) + + _, _, kernel_fn = stax.serial( + 
*[stax.Dense(**dense_args), stax.Relu()] * n_layers + + [stax.Dense(**dense_args)]) + init_fn, feature_fn = serial( + *[DenseFeatures(**dense_args), + ReluFeatures(**relu_args)] * n_layers + [DenseFeatures(**dense_args)]) + + rng2 = random.PRNGKey(2) + _, feat_fn_inputs = init_fn(rng2, x.shape) + + if do_jit: + kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + + k = kernel_fn(x, None) + k_nngp = k.nngp + k_ntk = k.ntk + + f = feature_fn(x, feat_fn_inputs) + if np.iscomplexobj(f.nngp_feat) or np.iscomplexobj(f.ntk_feat): + nngp_feat = np.concatenate((f.nngp_feat.real, f.nngp_feat.imag), axis=-1) + ntk_feat = np.concatenate((f.ntk_feat.real, f.ntk_feat.imag), axis=-1) + f = f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + k_nngp_approx = f.nngp_feat @ f.nngp_feat.T + k_ntk_approx = f.ntk_feat @ f.ntk_feat.T + + if relu_method == 'EXACT': + self.assertAllClose(k_nngp, k_nngp_approx) + self.assertAllClose(k_ntk, k_ntk_approx) + else: + test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.1, 1.) + test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.1, 1.) + + @parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': + ' [Wstd{}_bstd{}_numlayers{}_{}] '.format( + W_std, b_std, n_layers, 'jit' if do_jit else ''), + 'W_std': + W_std, + 'b_std': + b_std, + 'n_layers': + n_layers, + 'do_jit': + do_jit, + } for W_std in WEIGHT_VARIANCES for b_std in BIAS_VARIANCES + for n_layers in [1, 2, 3, 4] + for do_jit in [True, False])) + def test_conv_features(self, W_std, b_std, n_layers, do_jit): + n, h, w, c = 3, 4, 5, 2 + rng = random.PRNGKey(1) + x = self._get_init_data(rng, (n, h, w, c)) + + conv_args = { + 'out_chan': 1, + 'filter_shape': (3, 3), + 'padding': 'SAME', + 'W_std': W_std, + 'b_std': b_std + } + + kernel_fn = stax.serial(*[stax.Conv(**conv_args)] * n_layers)[2] + feature_fn = serial(*[ConvFeatures(**conv_args)] * n_layers)[1] + + if do_jit: + kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + + k = kernel_fn(x) + f = feature_fn(x, [()] * n_layers) + + k_nngp_approx = self._convert_image_feature_to_kernel(f.nngp_feat) + k_ntk_approx = self._convert_image_feature_to_kernel(f.ntk_feat) + + self.assertAllClose(k.nngp, k_nngp_approx) + self.assertAllClose(k.ntk, k_ntk_approx) + + @parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': + ' [nlayers{}_{}] '.format(n_layers, 'jit' if do_jit else ''), + 'n_layers': + n_layers, + 'do_jit': + do_jit + } for n_layers in [1, 2, 3, 4] for do_jit in [True, False])) + def test_avgpool_features(self, n_layers, do_jit): + n, h, w, c = 3, 32, 28, 2 + rng = random.PRNGKey(1) + x = self._get_init_data(rng, (n, h, w, c)) + + avgpool_args = { + 'window_shape': (2, 2), + 'strides': (2, 2), + 'padding': 'SAME' + } + + kernel_fn = stax.serial(*[stax.AvgPool(**avgpool_args)] * n_layers)[2] + feature_fn = serial(*[AvgPoolFeatures(**avgpool_args)] * n_layers)[1] + + if do_jit: + kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + + k = kernel_fn(x) + f = feature_fn(x, [()] * n_layers) + + k_nngp_approx = self._convert_image_feature_to_kernel(f.nngp_feat) + + self.assertAllClose(k.nngp, k_nngp_approx) + + def test_flatten_features(self): + n, h, w, c = 3, 32, 28, 2 + n_layers = 1 + rng = random.PRNGKey(1) + x = self._get_init_data(rng, (n, h, w, c)) + + kernel_fn = stax.serial(*[stax.Flatten()] * n_layers)[2] + + k = kernel_fn(x) + + feature_fn = serial(*[FlattenFeatures()] * n_layers)[1] + + f = feature_fn(x, [()] * n_layers) + + self.assertAllClose(k.nngp, f.nngp_feat @ 
f.nngp_feat.T) + + @parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': + ' [Wstd{}_bstd{}_depth{}_{}_{}] '.format( + W_std, b_std, depth, relu_method, 'jit' if do_jit else ''), + 'W_std': + W_std, + 'b_std': + b_std, + 'depth': + depth, + 'relu_method': + relu_method, + 'do_jit': + do_jit, + } for W_std in WEIGHT_VARIANCES for b_std in BIAS_VARIANCES + for relu_method in ['PSRF'] for depth in [5] + for do_jit in [False])) + def test_myrtle_network(self, W_std, b_std, relu_method, depth, do_jit): + if relu_method in ['RANDFEAT', 'POLYSKETCH', 'PSRF']: + import os + os.environ['CUDA_VISIBLE_DEVICES'] = '' + + n, h, w, c = 2, 32, 32, 3 + rng = random.PRNGKey(1) + x = self._get_init_data(rng, (n, h, w, c)) + + layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]} + + def _get_myrtle_kernel_fn(): + conv = functools.partial(stax.Conv, + W_std=W_std, + b_std=b_std, + padding='SAME') + + layers = [] + layers += [conv(1, (3, 3)), stax.Relu()] * layer_factor[depth][0] + layers += [stax.AvgPool((2, 2), strides=(2, 2))] + layers += [conv(1, (3, 3)), stax.Relu()] * layer_factor[depth][1] + layers += [stax.AvgPool((2, 2), strides=(2, 2))] + layers += [conv(1, (3, 3)), stax.Relu()] * layer_factor[depth][2] + layers += [stax.AvgPool((2, 2), strides=(2, 2))] * 3 + layers += [stax.Flatten(), stax.Dense(1, W_std=W_std, b_std=b_std)] + + return stax.serial(*layers) + + def _get_myrtle_feature_fn(**relu_args): + conv = functools.partial(ConvFeatures, W_std=W_std, b_std=b_std) + layers = [] + layers += [conv(1, (3, 3)), ReluFeatures(**relu_args) + ] * layer_factor[depth][0] + layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] + layers += [conv(1, (3, 3)), ReluFeatures(**relu_args) + ] * layer_factor[depth][1] + layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] + layers += [conv(1, (3, 3)), ReluFeatures(**relu_args) + ] * layer_factor[depth][2] + layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] * 3 + layers += [FlattenFeatures(), DenseFeatures(1, W_std=W_std, b_std=b_std)] + + return serial(*layers) + + _, _, kernel_fn = _get_myrtle_kernel_fn() + + relu_args = {'method': relu_method} + if relu_method == 'RANDFEAT': + relu_args['feature_dim0'] = 2048 + relu_args['feature_dim1'] = 2048 + relu_args['sketch_dim'] = 2048 + elif relu_method == 'POLYSKETCH': + relu_args['poly_degree'] = 4 + relu_args['poly_sketch_dim'] = 2048 + relu_args['sketch_dim'] = 2048 + elif relu_method == 'PSRF': + relu_args['feature_dim0'] = 2048 + relu_args['poly_degree'] = 4 + relu_args['poly_sketch_dim'] = 2048 + relu_args['sketch_dim'] = 2048 + elif relu_method in ['EXACT', 'POLY']: + pass + else: + raise ValueError(relu_method) + + init_fn, feature_fn = _get_myrtle_feature_fn(**relu_args) + + if do_jit: + kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + + k = kernel_fn(x) + k_nngp = k.nngp + k_ntk = k.ntk + + _, feat_fn_inputs = init_fn(rng, x.shape) + f = feature_fn(x, feat_fn_inputs) + if np.iscomplexobj(f.nngp_feat) or np.iscomplexobj(f.ntk_feat): + nngp_feat = np.concatenate((f.nngp_feat.real, f.nngp_feat.imag), axis=-1) + ntk_feat = np.concatenate((f.ntk_feat.real, f.ntk_feat.imag), axis=-1) + f = f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + k_nngp_approx = f.nngp_feat @ f.nngp_feat.T + k_ntk_approx = f.ntk_feat @ f.ntk_feat.T + + if relu_method == 'EXACT': + self.assertAllClose(k_nngp, k_nngp_approx) + self.assertAllClose(k_ntk, k_ntk_approx) + else: + test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.15, 1.) 
+ test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.15, 1.) + + +if __name__ == "__main__": + absltest.main() diff --git a/experimental/tests/kernel_approx_test.py b/experimental/tests/kernel_approx_test.py deleted file mode 100644 index 90c83c3e..00000000 --- a/experimental/tests/kernel_approx_test.py +++ /dev/null @@ -1,141 +0,0 @@ -import jax -from jax import numpy as np -from jax import random -from jax.config import config -from jax import jit -import sys - -sys.path.append("./") - -config.update("jax_enable_x64", True) -import neural_tangents as nt -from neural_tangents._src import empirical -from neural_tangents import stax - -from experimental.features import DenseFeatures, ReluFeatures, serial, ReluNTKFeatures - - -def _generate_fc_relu_ntk(width, depth, W_std): - layers = [] - layers += [stax.Dense(width, W_std=W_std), stax.Relu()] * depth - layers += [stax.Dense(output_dim, W_std=W_std)] - init_fn, apply_f, kernel_fn = stax.serial(*layers) - return init_fn, apply_f, kernel_fn - - -# This is re-implementation of neural_tangents.empirical_ntk_fn. -# The result is same with "nt.empirical_ntk_fn(apply_fn)(x, None, params)" -def _get_grad(x, output_dim, params, apply_fn): - - f_output = empirical._get_f_params(apply_fn, x, None, None, {}) - jac_f_output = jax.jacobian(f_output) - jacobian = jac_f_output(params) - - grad_all = [] - for jac_ in jacobian: - if len(jac_) > 0: - for j_ in jac_: - if j_ is None or np.linalg.norm(j_) < 1e-10: - continue - grad_all.append(j_.reshape(n, -1)) - - grad_all = np.hstack(grad_all) - return grad_all / np.sqrt(output_dim) - - -def _get_grad_feat_dim(input_dim, width, output_dim, depth): - dim_1 = input_dim * width - dim_2 = np.asarray([width**2 for _ in range(depth - 1)]).sum() - dim_3 = width * output_dim - return (dim_1 + dim_2 + dim_3) * output_dim - dim_3 - - -def fc_relu_ntk_sketching(relufeat_arg, - rng, - init_fn=None, - feature_fn=None, - W_std=1., - depth=-1, - no_jitting=False): - - if init_fn is None or feature_fn is None: - layers = [] - layers += [ - DenseFeatures(1, W_std=W_std), - ReluFeatures(**relufeat_arg), - ] * depth - layers += [DenseFeatures(1, W_std=W_std)] - init_fn, feature_fn = serial(*layers) - - # Initialize random vectors and sketching algorithms - _, feat_fn_inputs = init_fn(rng, x.shape) - - # Transform input vectors to NNGP/NTK feature map - feature_fn = feature_fn if no_jitting else jit(feature_fn) - feats = feature_fn(x, feat_fn_inputs) - - # PolySketch returns complex features. Convert complex features to real ones. - if np.iscomplexobj(feats.ntk_feat): - return np.concatenate((feats.ntk_feat.real, feats.ntk_feat.imag), axis=-1) - return feats.ntk_feat - - -seed = 1 -n, d = 10, 32 -no_jitting = False - -key1, key2, key3 = random.split(random.PRNGKey(seed), 3) -x = random.normal(key1, (n, d)) - -width = 4 -depth = 3 -W_std = 1.234 -output_dim = 2 - -init_fn, apply_fn, kernel_fn = _generate_fc_relu_ntk(width, depth, W_std) - -kernel_fn = kernel_fn if no_jitting else jit(kernel_fn) -nt_kernel = kernel_fn(x, None) - -# Sanity check of grad_feat. -_, params = init_fn(key2, x.shape) -grad_feat = _get_grad(x, output_dim, params, apply_fn) -assert np.linalg.norm( - nt.empirical_ntk_fn(apply_fn)(x, None, params) - - grad_feat @ grad_feat.T) <= 1e-12 - -# Store Frobenius-norm of the exact NTK for estimating relative errors. 
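The deleted `_get_grad` helper above builds exact finite-width NTK features by flattening per-example parameter gradients, relying on the identity `Theta(x, y) = <df(x)/dtheta, df(y)/dtheta>`. A toy version of the same identity under a made-up two-layer network (all names and sizes here are illustrative, not from the patch):

```python
import jax
import jax.numpy as np

def f(params, x):  # toy two-layer network, scalar output per example
  w1, w2 = params
  return np.tanh(x @ w1) @ w2

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (5, 7)), jax.random.normal(k2, (7, 1)))
x = jax.random.normal(k3, (4, 5))

# Per-example gradients w.r.t. all parameters, flattened into feature rows.
grads = jax.vmap(lambda xi: jax.grad(lambda p: f(p, xi[None])[0, 0])(params))(x)
feat = np.concatenate([g.reshape(x.shape[0], -1) for g in grads], axis=1)
ntk = feat @ feat.T  # the empirical NTK of the finite-width network
```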
-ntk_norm = np.linalg.norm(nt_kernel.ntk) - -width_all = np.arange(2, 16) -grad_feat_dims_all = [] - -print("empirical_ntk_fn results:") -for width in width_all: - init_fn, apply_fn, _ = _generate_fc_relu_ntk(width, depth, W_std) - _, params = init_fn(key2, x.shape) - grad_feat = _get_grad(x, output_dim, params, apply_fn) - rel_err = np.linalg.norm(grad_feat @ grad_feat.T - nt_kernel.ntk) / ntk_norm - grad_feat_dims_all.append(grad_feat.shape[1]) - print( - f"feat_dim : {grad_feat.shape[1]} (width : {width}), relative err : {rel_err}" - ) - -print() -print("ReluNTKFeatures results:") -relufeat_arg = { - 'num_layers': depth, - 'poly_degree': 16, - 'W_std': W_std, -} - -for feat_dim in grad_feat_dims_all: - relufeat_arg['poly_sketch_dim'] = feat_dim - init_fn, feature_fn = ReluNTKFeatures(**relufeat_arg) - ntk_feat = fc_relu_ntk_sketching(relufeat_arg, - key3, - init_fn=init_fn, - feature_fn=feature_fn) - - rel_err = np.linalg.norm(ntk_feat @ ntk_feat.T - nt_kernel.ntk) / ntk_norm - print(f"feat_dim : {ntk_feat.shape[1]}, err : {rel_err}") \ No newline at end of file diff --git a/experimental/tests/myrtle_network_test.py b/experimental/tests/myrtle_network_test.py deleted file mode 100644 index b3e69c49..00000000 --- a/experimental/tests/myrtle_network_test.py +++ /dev/null @@ -1,203 +0,0 @@ -import time -import sys -sys.path.insert(0, "./") -import functools -from numpy.linalg import norm -import jax -from jax.config import config -from jax import jit -# Enable float64 for JAX -config.update("jax_enable_x64", True) - -import jax.numpy as np -from jax import random - -from neural_tangents import stax -from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures - -layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]} - -H, W, C = 4, 4, 3; num_final_avgpools = 0 -# H, W, C = 32, 32, 3; num_final_avgpools = 3 - -def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0., width=1): - activation_fn = stax.Relu() - conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, padding='SAME') - - layers = [] - layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][0] - layers += [stax.AvgPool((2, 2), strides=(2, 2))] - layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][1] - layers += [stax.AvgPool((2, 2), strides=(2, 2))] - layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][2] - layers += [stax.AvgPool((2, 2), strides=(2, 2))] * num_final_avgpools - - layers += [stax.Flatten(), stax.Dense(1, W_std, b_std)] - - return stax.serial(*layers) - - -def MyrtleNetworkFeatures(depth, W_std=np.sqrt(2.0), width=1, **relu_args): - - conv_fn = functools.partial(ConvFeatures, W_std=W_std) - - layers = [] - layers += [conv_fn(width, filter_shape=(3, 3)), - ReluFeatures(**relu_args)] * layer_factor[depth][0] - layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] - layers += [ - ConvFeatures(width, filter_shape=(3, 3), W_std=W_std), - ReluFeatures(**relu_args) - ] * layer_factor[depth][1] - layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] - layers += [ - ConvFeatures(width, filter_shape=(3, 3), W_std=W_std), - ReluFeatures(**relu_args) - ] * layer_factor[depth][2] - layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] * num_final_avgpools - layers += [FlattenFeatures(), DenseFeatures(1, W_std)] - - return serial(*layers) - - -def test_small_dataset(num_data=4, dataset='synthetic', depth=5, do_jit=True): - - print(f"dataset : {dataset}") - - key = random.PRNGKey(0) - - if dataset == 'synthetic': - x = 
random.normal(key, shape=(num_data, H, W, C)) - - elif dataset in ['cifar10', 'cifar100']: - from examples import datasets - x = datasets.get_dataset('cifar10', do_flatten_and_normalize=False)[0] - if (H, W) != (32, 32): - x = jax.image.resize(x, (x.shape[0], H, W, 3), method='linear') - mean_ = np.mean(x) - std_ = np.std(x) - x = (x - mean_) / std_ - - x = x[random.permutation(key, len(x))[:num_data]] - - else: - raise NotImplementedError(f"Invalid dataset : {dataset}") - - key1, key2 = random.split(key) - print("================= Result of Neural Tangent Library =================") - - _, _, kernel_fn = MyrtleNetwork(depth) - kernel_fn = jit(kernel_fn) if do_jit else kernel_fn - - tic = time.time() - nt_kernel = kernel_fn(x) - toc = time.time() - tic - print(f"nt kernel time: {toc:.4f} sec") - - if num_data <= 8: - print("K_nngp (exact):") - print(nt_kernel.nngp) - print() - - print("K_ntk (exact):") - print(nt_kernel.ntk) - print() - - def test_myrtle_network_approx(relufeat_arg): - - print(f"ReluFeatures params:") - for name_, value_ in relufeat_arg.items(): - print(f"{name_:<12} : {value_}") - print() - - init_fn, feature_fn = MyrtleNetworkFeatures(depth, **relufeat_arg) - - # Initialize random vectors and sketching algorithms - _, feat_fn_inputs = init_fn(key2, x.shape) - - # Transform input vectors to NNGP/NTK feature map - feature_fn = jit(feature_fn) if do_jit else feature_fn - - tic = time.time() - feats = feature_fn(x, feat_fn_inputs) - toc = time.time() - tic - print(f"{relufeat_arg['method']} feature time: {toc:.4f} sec") - - # PolySketch returns complex features. Convert complex features to real ones. - if np.iscomplexobj(feats.nngp_feat) or np.iscomplexobj(feats.ntk_feat): - nngp_feat = np.concatenate((feats.nngp_feat.real, feats.nngp_feat.imag), axis=-1) - ntk_feat = np.concatenate((feats.ntk_feat.real, feats.ntk_feat.imag), axis=-1) - feats = feats.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) - - print(f"f_nngp shape: {feats.nngp_feat.shape}") - print(f"f_ntk shape: {feats.ntk_feat.shape}") - - if num_data <= 8: - print("K_nngp:") - print(feats.nngp_feat @ feats.nngp_feat.T) - print() - - print("K_ntk:") - print(feats.ntk_feat @ feats.ntk_feat.T) - print() - - print( - f"|| K_nngp - f_nngp @ f_nngp.T ||_fro = {np.linalg.norm(nt_kernel.nngp - feats.nngp_feat @ feats.nngp_feat.T)}" - ) - print( - f"|| K_ntk - f_ntk @ f_ntk.T ||_fro = {np.linalg.norm(nt_kernel.ntk - feats.ntk_feat @ feats.ntk_feat.T)}" - ) - print() - - - print("================= Result of CNTK Random Features =================") - kappa0_feat_dim = 1024 - kappa1_feat_dim = 1024 - sketch_dim = 1024 - - test_myrtle_network_approx({ - 'method': 'RANDFEAT', - 'feature_dim0': kappa0_feat_dim, - 'feature_dim1': kappa1_feat_dim, - 'sketch_dim': sketch_dim, - }) - - print("================== Result of CNTK wih PolySketch ==================") - poly_degree = 8 - poly_sketch_dim = 1024 - sketch_dim = 1024 - - test_myrtle_network_approx({ - 'method': 'POLYSKETCH', - 'sketch_dim': sketch_dim, - 'poly_degree': poly_degree, - 'poly_sketch_dim': poly_sketch_dim - }) - - print("============== Result of PolySketch + Random Features ==============") - kappa0_feat_dim = 512 - sketch_dim = 1024 - poly_degree = 8 - poly_sketch_dim = 1024 - - test_myrtle_network_approx({ - 'method': 'PSRF', - 'feature_dim0': kappa0_feat_dim, - 'sketch_dim': sketch_dim, - 'poly_degree': poly_degree, - 'poly_sketch_dim': poly_sketch_dim - }) - - print("===== (Debug) NTK Feature Maps with Polynomial Approximation =====") - print("\t(*No Sketching 
algorithm is applied.)") - - test_myrtle_network_approx({'method': 'POLY', 'poly_degree': poly_degree}) - - print("==== (Debug) Exact NTK Feature Maps via Cholesky Decomposition ====") - - test_myrtle_network_approx({'method': 'EXACT'}) - -if __name__ == "__main__": - test_small_dataset(num_data=4, dataset='synthetic', depth=5) - # test_small_dataset(num_data=4, dataset='cifar10', depth=5) - # test_small_dataset(num_data=4, dataset='cifar100', depth=5) \ No newline at end of file diff --git a/experimental/tests/sketching_test.py b/experimental/tests/sketching_test.py index 247019f4..b5334c3a 100644 --- a/experimental/tests/sketching_test.py +++ b/experimental/tests/sketching_test.py @@ -1,41 +1,58 @@ -import sys +from absl.testing import absltest +from absl.testing import parameterized -sys.path.append("./") -import scipy -from jax import random, jit -from jax import numpy as jnp +import jax.numpy as np +from math import factorial +import jax.random as random +from jax import test_util as jtu from experimental.sketching import PolyTensorSketch +from tests import test_utils -# Coefficients of Taylor series of exp(x) -degree = 8 -coeffs = jnp.asarray([1 / scipy.special.factorial(i) for i in range(degree)]) +NUM_POINTS = [10, 100, 1000] +NUM_DIMS = [64, 256, 1024] -n = 4 -d = 32 -sketch_dim = 256 -rng = random.PRNGKey(1) -x = random.normal(rng, shape=(n, d)) -norm_x = jnp.linalg.norm(x, axis=-1, keepdims=True) -x_normalized = x / norm_x +class SketchingTest(jtu.JaxTestCase): -rng2 = random.PRNGKey(2) -pts = PolyTensorSketch(rng=rng2, - input_dim=d, - sketch_dim=sketch_dim, - degree=degree).init_sketches() # pytype:disable=wrong-keyword-args -x_sketches = pts.sketch(x_normalized) + @classmethod + def _get_init_data(cls, rng, shape, normalized_output=True): + x = random.normal(rng, shape) + if normalized_output: + return x / np.linalg.norm(x, axis=-1, keepdims=True) + else: + return x -z = pts.expand_feats(x_sketches, coeffs) # z.shape[1] is not the desired. -z = pts.standardsrht(z) # z is complex ndarray. -z = jnp.concatenate((z.real, z.imag), axis=-1) + @parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': f' [n{n}_d{d}]', + 'n': 4, + 'd': 32, + 'sketch_dim': 1024, + 'degree': 16 + } for n in NUM_POINTS for d in NUM_DIMS)) + def test_exponential_kernel(self, n, d, sketch_dim, degree): + rng = random.PRNGKey(1) + x = self._get_init_data(rng, (n, d), True) -K = jnp.polyval(coeffs[::-1], x_normalized @ x_normalized.T) -K_approx = z @ z.T + coeffs = np.asarray([1 / factorial(i) for i in range(degree)]) -print("Exact kernel matrix:") -print(K) -print() + rng2 = random.PRNGKey(2) + pts = PolyTensorSketch(rng=rng2, + input_dim=d, + sketch_dim=sketch_dim, + degree=degree).init_sketches() # pytype:disable=wrong-keyword-args -print(f"Approximate kernel matrix (sketch_dim: {z.shape[1]}):") -print(K_approx) + x_sketches = pts.sketch(x) + + z = pts.expand_feats(x_sketches, coeffs) + z = pts.standardsrht(z) + z = np.concatenate((z.real, z.imag), axis=-1) + + k_exact = np.polyval(coeffs[::-1], x @ x.T) + k_approx = z @ z.T + + test_utils.assert_close_matrices(self, k_exact, k_approx, 0.15, 1.) 
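What this test approximates with sketches can be written exactly for tiny `d` as a truncated Taylor feature map of `exp(<x, y>)`: stacking scaled tensor powers `x^{(x) i} / sqrt(i!)` yields features whose Gram matrix equals the truncated series that `expand_feats` targets. A brute-force sketch, feasible only at toy dimensions since the feature count grows as `d**i`:

```python
from math import factorial
import jax.numpy as np
import jax.random as random

def taylor_features(x, degree):
  # phi(x) = [x^{(x) i} / sqrt(i!)]_{i < degree}, flattened tensor powers.
  n = x.shape[0]
  feats, power = [np.ones((n, 1))], np.ones((n, 1))
  for i in range(1, degree):
    power = np.einsum('np,nd->npd', power, x).reshape(n, -1)
    feats.append(power / factorial(i)**0.5)
  return np.concatenate(feats, axis=1)

degree = 8
x = random.normal(random.PRNGKey(1), (4, 3))
x /= np.linalg.norm(x, axis=-1, keepdims=True)

z = taylor_features(x, degree)
coeffs = np.asarray([1 / factorial(i) for i in range(degree)])
k_exact = np.polyval(coeffs[::-1], x @ x.T)
assert np.allclose(z @ z.T, k_exact, atol=1e-4)
```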
+ + +if __name__ == "__main__": + absltest.main() From 6d709ffa2ff2cff96c360f1b24df0ddd926b048e Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 8 Jun 2022 13:52:35 +0900 Subject: [PATCH 37/44] Add aggregate features for graph neural nets --- experimental/features.py | 85 ++++++++++++++++++++++++++++- experimental/tests/features_test.py | 83 ++++++++++++++++++++++++---- 2 files changed, 154 insertions(+), 14 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index e4b00763..373c9149 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -9,7 +9,7 @@ from neural_tangents import stax from neural_tangents._src.utils import dataclasses from neural_tangents._src.utils.typing import Axes -from neural_tangents._src.stax.linear import _pool_kernel, Padding, _get_dimension_numbers +from neural_tangents._src.stax.linear import _pool_kernel, Padding, _get_dimension_numbers, AggregateImplementation from neural_tangents._src.stax.linear import _Pooling as Pooling from experimental.sketching import TensorSRHT, PolyTensorSketch @@ -653,6 +653,51 @@ def feature_fn(f, input=None, **kwargs): return init_fn, feature_fn +@layer +def GlobalAvgPoolFeatures(batch_axis: int = 0, channel_axis: int = -1): + + def init_fn(rng, input_shape): + nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + ndim = len(nngp_feat_shape) + non_spatial_axes = (batch_axis % ndim, channel_axis % ndim) + _get_output_shape = lambda _shape: tuple(_shape[i] + for i in range(ndim) + if i in non_spatial_axes) + new_nngp_feat_shape = _get_output_shape(nngp_feat_shape) + new_ntk_feat_shape = _get_output_shape(ntk_feat_shape) + + if len(input_shape) > 2: + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), () + else: + return (new_nngp_feat_shape, new_ntk_feat_shape, ''), () + + def feature_fn(f, input=None, **kwargs): + # Unnormalize the input features. + f_renomalized: Features = _unnormalize_features(f) + nngp_feat: np.ndarray = f_renomalized.nngp_feat + ntk_feat: np.ndarray = f_renomalized.ntk_feat + + ndim = len(nngp_feat.shape) + non_spatial_axes = (batch_axis % ndim, channel_axis % ndim) + spatial_axes = tuple(set(range(ndim)) - set(non_spatial_axes)) + + nngp_feat = np.mean(nngp_feat, axis=spatial_axes) + ntk_feat = np.mean(ntk_feat, axis=spatial_axes) + + # Re-normalize the features. + norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) + norms = np.where(norms > 0, norms, 1.0) + nngp_feat = nngp_feat / norms + ntk_feat = ntk_feat / norms + + return f.replace(nngp_feat=nngp_feat, + ntk_feat=ntk_feat, + norms=norms, + is_reversed=False) + + return init_fn, feature_fn + + # TODO(insu): fix reshaping features for general batch/channel axes. 
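`GlobalAvgPoolFeatures` exploits the linearity of averaging: pooling the feature map over spatial axes and then taking inner products equals averaging the un-pooled kernel over all spatial index pairs (the re-normalization bookkeeping aside). The identity in isolation, a minimal sketch:

```python
import jax.numpy as np
import jax.random as random

f = random.normal(random.PRNGKey(0), (2, 3, 3, 5))  # (N, H, W, C) features

pooled = np.mean(f, axis=(1, 2))                     # (N, C)
k_pooled = pooled @ pooled.T

k_full = np.einsum('nhwc,mxyc->nmhwxy', f, f)        # kernel over spatial pairs
assert np.allclose(k_pooled, np.mean(k_full, axis=(2, 3, 4, 5)))
```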
@layer def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): @@ -714,4 +759,40 @@ def init_fn(rng, input_shape): def feature_fn(f, input=None, **kwargs): return f.replace(norms=np.ones_like(f.norms)) - return init_fn, feature_fn \ No newline at end of file + return init_fn, feature_fn + + +@layer +def AggregateFeatures( + aggregate_axis: Optional[Axes] = None, + batch_axis: int = 0, + channel_axis: int = -1, + to_dense: Optional[Callable[[np.ndarray], np.ndarray]] = lambda p: p, + implementation: str = AggregateImplementation.DENSE.value): + + init_fn = lambda rng, input_shape: (input_shape, ()) + + def feature_fn(f, input=None, pattern= None, **kwargs): + if pattern is None: + raise NotImplementedError('`pattern=None` is not implemented.') + + f_renomalized: Features = _unnormalize_features(f) + nngp_feat: np.ndarray = f_renomalized.nngp_feat + ntk_feat: np.ndarray = f_renomalized.ntk_feat + + pattern_T = np.swapaxes(pattern, 1, 2) + nngp_feat = np.einsum("bnm,bmc->bnc", pattern_T, nngp_feat) + if f.ntk_feat.ndim == 0: + ntk_feat = nngp_feat + else: + ntk_feat = np.einsum("bnm,bmc->bnc", pattern_T, ntk_feat) + + # Re-normalize the features. + norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) + norms = np.where(norms > 0, norms, 1.0) + nngp_feat = nngp_feat / norms + ntk_feat = ntk_feat / norms + + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) + + return init_fn, feature_fn diff --git a/experimental/tests/features_test.py b/experimental/tests/features_test.py index 4bf81299..7e251498 100644 --- a/experimental/tests/features_test.py +++ b/experimental/tests/features_test.py @@ -1,29 +1,31 @@ from absl.testing import absltest from absl.testing import parameterized import functools - from jax import jit from jax.config import config import jax.numpy as np import jax.random as random -from jax import test_util as jtu from neural_tangents._src.utils import utils as ntutils from neural_tangents import stax from tests import test_utils -from experimental.features import DenseFeatures, ReluFeatures, ConvFeatures, AvgPoolFeatures, FlattenFeatures, serial +from experimental.features import DenseFeatures, ReluFeatures, ConvFeatures, AvgPoolFeatures, FlattenFeatures, serial, GlobalAvgPoolFeatures, AggregateFeatures + -config.update("jax_enable_x64", True) config.parse_flags_with_absl() config.update('jax_numpy_rank_promotion', 'raise') +test_utils.update_test_tolerance() + +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '' + NUM_DIMS = [64, 128, 256, 512] WEIGHT_VARIANCES = [0.001, 0.01, 0.1, 1.] 
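`AggregateFeatures` above multiplies the feature map by `pattern^T` per batch element; at the kernel level this reproduces the congruence `K -> A^T K A` that graph aggregation applies, because the Gram matrix of `A^T f` is `A^T (f f^T) A`. The single-graph case, with re-normalization omitted:

```python
import jax.numpy as np
import jax.random as random

k1, k2 = random.split(random.PRNGKey(0))
f = random.normal(k1, (6, 4))    # node features of one graph: (N, C)
A = random.uniform(k2, (6, 6))   # aggregation pattern, e.g. an adjacency matrix

g = A.T @ f                      # per-batch step of AggregateFeatures
assert np.allclose(g @ g.T, A.T @ (f @ f.T) @ A)
```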
BIAS_VARIANCES = [None, 0.001, 0.01, 0.1] -test_utils.update_test_tolerance() -class FeaturesTest(jtu.JaxTestCase): +class FeaturesTest(test_utils.NeuralTangentsTestCase): @classmethod def _get_init_data(cls, rng, shape, normalized_output=False): @@ -38,7 +40,7 @@ def _convert_image_feature_to_kernel(cls, f_): return ntutils.zip_axes(np.einsum("ijkc,xyzc->ijkxyz", f_, f_)) @parameterized.named_parameters( - jtu.cases_from_list({ + test_utils.cases_from_list({ 'testcase_name': ' [Wstd{}_bstd{}_{}layers_{}] '.format(W_std, b_std, n_layers, 'jit' if do_jit else ''), @@ -53,7 +55,7 @@ def _convert_image_feature_to_kernel(cls, f_): } for W_std in WEIGHT_VARIANCES for b_std in BIAS_VARIANCES for n_layers in [1, 2, 3, 4] for do_jit in [True, False])) - def testDenseFeatures(self, W_std, b_std, n_layers, do_jit): + def test_dense_features(self, W_std, b_std, n_layers, do_jit): n, d = 4, 256 rng = random.PRNGKey(1) x = self._get_init_data(rng, (n, d)) @@ -74,7 +76,7 @@ def testDenseFeatures(self, W_std, b_std, n_layers, do_jit): self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T) @parameterized.named_parameters( - jtu.cases_from_list({ + test_utils.cases_from_list({ 'testcase_name': ' [Wstd{}_bstd{}_numlayers{}_{}_{}] '.format( W_std, b_std, n_layers, relu_method, 'jit' if do_jit else ''), @@ -152,7 +154,7 @@ def test_fc_relu_nngp_ntk(self, W_std, b_std, n_layers, relu_method, do_jit): test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.1, 1.) @parameterized.named_parameters( - jtu.cases_from_list({ + test_utils.cases_from_list({ 'testcase_name': ' [Wstd{}_bstd{}_numlayers{}_{}] '.format( W_std, b_std, n_layers, 'jit' if do_jit else ''), @@ -197,7 +199,7 @@ def test_conv_features(self, W_std, b_std, n_layers, do_jit): self.assertAllClose(k.ntk, k_ntk_approx) @parameterized.named_parameters( - jtu.cases_from_list({ + test_utils.cases_from_list({ 'testcase_name': ' [nlayers{}_{}] '.format(n_layers, 'jit' if do_jit else ''), 'n_layers': @@ -247,7 +249,7 @@ def test_flatten_features(self): self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T) @parameterized.named_parameters( - jtu.cases_from_list({ + test_utils.cases_from_list({ 'testcase_name': ' [Wstd{}_bstd{}_depth{}_{}_{}] '.format( W_std, b_std, depth, relu_method, 'jit' if do_jit else ''), @@ -356,6 +358,63 @@ def _get_myrtle_feature_fn(**relu_args): test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.15, 1.) test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.15, 1.) 
+ def test_global_average_pooling_features(self): + rng = random.PRNGKey(1) + input_shape = (4, 5, 6, 7) + x = random.normal(rng, input_shape) + + _, _, kernel_fn = stax.serial( + stax.Conv(1, (3, 3), padding='SAME'), + stax.Relu(), + stax.GlobalAvgPool() + ) + + _, feature_fn = serial( + ConvFeatures(1, (3, 3)), + ReluFeatures(method='EXACT'), + GlobalAvgPoolFeatures() + ) + + k = jit(kernel_fn)(x) + f = jit(feature_fn)(x, [()] * 3) + + self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T) + self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T) + + def test_aggregate_features(self): + rng = random.PRNGKey(1) + rng1, rng2 = random.split(rng, 2) + + batch_size = 4 + num_channels = 3 + shape = (5, ) + width = 1 + + x = random.normal(rng1, (batch_size,) + shape + (num_channels,)) + pattern = random.uniform(rng2, (batch_size,) + shape * 2) + + _, _, kernel_fn = stax.serial( + stax.Dense(width, W_std=2**0.5), + stax.Relu(), + stax.Aggregate(), + stax.GlobalAvgPool(), + stax.Dense(width) + ) + + k = jit(kernel_fn)(x, None, pattern=(pattern, pattern)) + + _, feature_fn = serial( + DenseFeatures(width, W_std=2**0.5), + ReluFeatures(method='EXACT'), + AggregateFeatures(), + GlobalAvgPoolFeatures(), + DenseFeatures(width) + ) + + f = feature_fn(x, [()] * 5, **{'pattern': pattern}) + self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T) + self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T) + if __name__ == "__main__": absltest.main() From a6724cbf6163f0848d74a67d3b6393fd2e27c2ea Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 8 Jun 2022 14:06:43 +0900 Subject: [PATCH 38/44] Fix features_test --- experimental/tests/features_test.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/experimental/tests/features_test.py b/experimental/tests/features_test.py index 7e251498..48efc3a4 100644 --- a/experimental/tests/features_test.py +++ b/experimental/tests/features_test.py @@ -12,6 +12,7 @@ from experimental.features import DenseFeatures, ReluFeatures, ConvFeatures, AvgPoolFeatures, FlattenFeatures, serial, GlobalAvgPoolFeatures, AggregateFeatures +config.update("jax_enable_x64", True) config.parse_flags_with_absl() config.update('jax_numpy_rank_promotion', 'raise') @@ -20,9 +21,9 @@ import os os.environ['CUDA_VISIBLE_DEVICES'] = '' -NUM_DIMS = [64, 128, 256, 512] -WEIGHT_VARIANCES = [0.001, 0.01, 0.1, 1.] -BIAS_VARIANCES = [None, 0.001, 0.01, 0.1] +NUM_DIMS = [128, 256, 512] +WEIGHT_VARIANCES = [0.01, 0.1, 1.] +BIAS_VARIANCES = [None, 0.01, 0.1] class FeaturesTest(test_utils.NeuralTangentsTestCase): @@ -355,8 +356,8 @@ def _get_myrtle_feature_fn(**relu_args): self.assertAllClose(k_nngp, k_nngp_approx) self.assertAllClose(k_ntk, k_ntk_approx) else: - test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.15, 1.) - test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.15, 1.) + test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.2, 1.) + test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.2, 1.) 
def test_global_average_pooling_features(self): rng = random.PRNGKey(1) From 420f83682ca1b95a4e3a870a21b97c27069b03d2 Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 8 Jun 2022 14:16:22 +0900 Subject: [PATCH 39/44] Fix features_test --- experimental/tests/features_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/tests/features_test.py b/experimental/tests/features_test.py index 48efc3a4..5dd882aa 100644 --- a/experimental/tests/features_test.py +++ b/experimental/tests/features_test.py @@ -151,8 +151,8 @@ def test_fc_relu_nngp_ntk(self, W_std, b_std, n_layers, relu_method, do_jit): self.assertAllClose(k_nngp, k_nngp_approx) self.assertAllClose(k_ntk, k_ntk_approx) else: - test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.1, 1.) - test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.1, 1.) + test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.2, 1.) + test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.2, 1.) @parameterized.named_parameters( test_utils.cases_from_list({ From 9a14e603d0afbd91ccd6bd9c28e8fa57395e4d71 Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 13 Jul 2022 12:25:47 +0900 Subject: [PATCH 40/44] Update dynamic axis --- experimental/features.py | 613 +++++++++++++++------------ experimental/poly_fitting.py | 30 +- experimental/tests/features_test.py | 517 +++++++++++++--------- experimental/tests/sketching_test.py | 17 +- 4 files changed, 672 insertions(+), 505 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index 373c9149..f69bca0f 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -1,14 +1,21 @@ import enum from typing import Optional, Callable, Sequence, Tuple +import frozendict +import string +import functools +import operator as op + +from jax import lax from jax import random from jax._src.util import prod from jax import numpy as np import jax.example_libraries.stax as ostax from jax import eval_shape, ShapedArray -from neural_tangents import stax from neural_tangents._src.utils import dataclasses from neural_tangents._src.utils.typing import Axes +from neural_tangents._src.stax.requirements import _set_req, get_req, _fuse_requirements, _DEFAULT_INPUT_REQ +from neural_tangents._src.stax.combinators import _get_input_req_attr from neural_tangents._src.stax.linear import _pool_kernel, Padding, _get_dimension_numbers, AggregateImplementation from neural_tangents._src.stax.linear import _Pooling as Pooling @@ -21,9 +28,6 @@ class Features: nngp_feat: np.ndarray ntk_feat: np.ndarray - norms: np.ndarray - - is_reversed: bool = dataclasses.field(pytree_node=False) batch_axis: int = 0 channel_axis: int = -1 @@ -31,7 +35,7 @@ class Features: replace = ... 
# type: Callable[..., 'Features'] -class ReluFeaturesMethod(enum.Enum): +class ReluFeaturesImplementation(enum.Enum): """Method for ReLU NNGP/NTK features approximation.""" RANDFEAT = 'RANDFEAT' POLYSKETCH = 'POLYSKETCH' @@ -40,6 +44,15 @@ class ReluFeaturesMethod(enum.Enum): EXACT = 'EXACT' +def requires(**static_reqs): + + def req(feature_fn): + _set_req(feature_fn, frozendict.frozendict(static_reqs)) + return feature_fn + + return req + + def layer(layer_fn): def new_layer_fns(*args, **kwargs): @@ -54,10 +67,10 @@ def new_layer_fns(*args, **kwargs): def _preprocess_init_fn(init_fn): def init_fn_any(rng, input_shape_any, **kwargs): - if _is_sinlge_shape(input_shape_any): + if _is_single_shape(input_shape_any): # Add a dummy shape for ntk_feat dummy_shape = (-1,) + (0,) * (len(input_shape_any) - 1) - input_shape = (input_shape_any, dummy_shape) + input_shape = (input_shape_any, dummy_shape, '') return init_fn(rng, input_shape, **kwargs) else: return init_fn(rng, input_shape_any, **kwargs) @@ -65,11 +78,11 @@ def init_fn_any(rng, input_shape_any, **kwargs): return init_fn_any -def _is_sinlge_shape(input_shape): +def _is_single_shape(input_shape): if all(isinstance(n, int) for n in input_shape): return True - elif (len(input_shape) == 2 or len(input_shape) == 3) and all( - _is_sinlge_shape(s) for s in input_shape[:2]): + elif len(input_shape) == 3 and all( + _is_single_shape(s) for s in input_shape[:2]): return False raise ValueError(input_shape) @@ -82,7 +95,9 @@ def feature_fn_feature(feature, input, **kwargs): return feature_fn(feature, input, **kwargs) def feature_fn_x(x, input, **kwargs): - feature = _inputs_to_features(x, **kwargs) + feature_fn_reqs = get_req(feature_fn) + reqs = _fuse_requirements(feature_fn_reqs, _DEFAULT_INPUT_REQ, **kwargs) + feature = _inputs_to_features(x, **reqs) return feature_fn(feature, input, **kwargs) def feature_fn_any(x_or_feature, input, **kwargs): @@ -90,6 +105,7 @@ def feature_fn_any(x_or_feature, input, **kwargs): return feature_fn_feature(x_or_feature, input, **kwargs) return feature_fn_x(x_or_feature, input, **kwargs) + _set_req(feature_fn_any, get_req(feature_fn)) return feature_fn_any @@ -98,20 +114,19 @@ def _inputs_to_features(x: np.ndarray, channel_axis: int = -1, **kwargs) -> Features: """Transforms (batches of) inputs to a `Features`.""" - # Followed the same initialization of Neural Tangents library. 
- nngp_feat = x / x.shape[channel_axis]**0.5 - norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) - norms = np.where(norms > 0, norms, 1.0) - nngp_feat = nngp_feat / norms - ntk_feat = np.zeros((), dtype=nngp_feat.dtype) - - is_reversed = False + if channel_axis is None: + x = np.moveaxis(x, batch_axis, 0).reshape((x.shape[batch_axis], -1)) + batch_axis, channel_axis = 0, 1 + else: + channel_axis %= x.ndim + nngp_feat = x / x.shape[channel_axis]**0.5 + ntk_feat = np.zeros(x.shape[:channel_axis] + (0,) + + x.shape[channel_axis + 1:], + dtype=x.dtype) return Features(nngp_feat=nngp_feat, ntk_feat=ntk_feat, - norms=norms, - is_reversed=is_reversed, batch_axis=batch_axis, channel_axis=channel_axis) # pytype:disable=wrong-keyword-args @@ -124,29 +139,26 @@ def serial(*layers): init_fns, feature_fns = zip(*layers) init_fn, _ = ostax.serial(*zip(init_fns, init_fns)) - def feature_fn(k, inputs, **kwargs): - for f, input_ in zip(feature_fns, inputs): - k = f(k, input_, **kwargs) - k = _unnormalize_features(k) - return k + @requires(**_get_input_req_attr(feature_fns, fold=op.rshift)) + def feature_fn(features: Features, inputs, **kwargs) -> Features: + if not (len(init_fns) == len(feature_fns) == len(inputs)): + raise ValueError('Length of inputs should be same as that of layers.') + for feature_fn_, input_ in zip(feature_fns, inputs): + features = feature_fn_(features, input_, **kwargs) + return features return init_fn, feature_fn -def _unnormalize_features(f: Features) -> Features: - nngp_feat = f.nngp_feat * f.norms - ntk_feat = f.ntk_feat * f.norms if f.ntk_feat.ndim != 0 else f.ntk_feat - norms = np.zeros((), dtype=nngp_feat.dtype) - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) - - @layer def DenseFeatures(out_dim: int, W_std: float = 1., b_std: Optional[float] = None, - parameterization: str = 'ntk', batch_axis: int = 0, - channel_axis: int = -1): + channel_axis: int = -1, + parameterization: str = 'ntk'): + + parameterization = parameterization.lower() if parameterization != 'ntk': raise NotImplementedError(f'Parameterization ({parameterization}) is ' @@ -154,49 +166,42 @@ def DenseFeatures(out_dim: int, def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] + _channel_axis = channel_axis % len(nngp_feat_shape) - nngp_feat_dim = nngp_feat_shape[channel_axis] - if b_std is not None: - nngp_feat_dim += 1 + nngp_feat_dim = nngp_feat_shape[_channel_axis] + (1 if b_std is not None + else 0) + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + nngp_feat_dim,) + nngp_feat_shape[_channel_axis + 1:] - new_nngp_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_dim,) - new_ntk_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_dim + - ntk_feat_shape[channel_axis],) - - if len(input_shape) > 2: - return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'D'), () + if prod(ntk_feat_shape) == 0: + new_ntk_feat_shape = new_nngp_feat_shape else: - return (new_nngp_feat_shape, new_ntk_feat_shape, 'D'), () - - def feature_fn(f: Features, input=None, **kwargs): - nngp_feat: np.ndarray = f.nngp_feat - ntk_feat: np.ndarray = f.ntk_feat - norms: np.ndarray = f.norms - - if b_std is not None: - f_renomalized: Features = _unnormalize_features(f) - nngp_feat: np.ndarray = f_renomalized.nngp_feat - ntk_feat: np.ndarray = f_renomalized.ntk_feat - - biases = b_std * np.ones((nngp_feat.shape[0], 1), dtype=nngp_feat.dtype) - nngp_feat = np.concatenate((W_std * nngp_feat, biases), axis=channel_axis) - ntk_feat = W_std * ntk_feat + 
ntk_feat_dim = nngp_feat_dim + ntk_feat_shape[_channel_axis] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + ntk_feat_dim,) + ntk_feat_shape[_channel_axis + 1:] - norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) - norms = np.where(norms > 0, norms, 1.0) + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'D'), () - nngp_feat = nngp_feat / norms - ntk_feat = ntk_feat / norms if ntk_feat.ndim != 0 else ntk_feat + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input, **kwargs): + nngp_feat = f.nngp_feat + ntk_feat = f.ntk_feat - else: - norms *= W_std + _channel_axis = channel_axis % nngp_feat.ndim - if ntk_feat.ndim == 0: # check if ntk_feat is empty - ntk_feat = nngp_feat + if b_std is not None: # concatenate bias vector in nngp_feat + biases = b_std * np.ones(nngp_feat.shape[:_channel_axis] + + (1,) + nngp_feat.shape[_channel_axis + 1:], + dtype=nngp_feat.dtype) + nngp_feat = np.concatenate((W_std * nngp_feat, biases), + axis=_channel_axis) + ntk_feat = W_std * ntk_feat else: - ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) + nngp_feat *= W_std + ntk_feat *= W_std - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) + ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=_channel_axis) + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) return init_fn, feature_fn @@ -208,51 +213,67 @@ def ReluFeatures(method: str = 'RANDFEAT', sketch_dim: int = 1, poly_degree: int = 8, poly_sketch_dim: int = 1, - generate_rand_mtx: bool = True): + generate_rand_mtx: bool = True, + batch_axis: int = 0, + channel_axis: int = -1): - method = ReluFeaturesMethod(method.upper()) + method = ReluFeaturesImplementation(method.upper()) def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_nngp_feat_shape = nngp_feat_shape[:-1] + (feature_dim1,) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (sketch_dim,) + net_shape = input_shape[2] relu_layers_count = net_shape.count('R') new_net_shape = net_shape + 'R' - if method == ReluFeaturesMethod.RANDFEAT: + ndim = len(nngp_feat_shape) + _channel_axis = channel_axis % ndim + + if method == ReluFeaturesImplementation.RANDFEAT: + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + feature_dim1,) + nngp_feat_shape[_channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + sketch_dim,) + ntk_feat_shape[_channel_axis + 1:] + rng1, rng2, rng3 = random.split(rng, 3) if generate_rand_mtx: # Random vectors for random features of arc-cosine kernel of order 0. - W0 = random.normal(rng1, (nngp_feat_shape[-1], feature_dim0)) + W0 = random.normal(rng1, (nngp_feat_shape[_channel_axis], feature_dim0)) # Random vectors for random features of arc-cosine kernel of order 1. - W1 = random.normal(rng2, (nngp_feat_shape[-1], feature_dim1)) + W1 = random.normal(rng2, (nngp_feat_shape[_channel_axis], feature_dim1)) else: # if `generate_rand_mtx` is False, return random seeds and shapes instead of np.ndarray. - W0 = (rng1, (nngp_feat_shape[-1], feature_dim0)) - W1 = (rng2, (nngp_feat_shape[-1], feature_dim1)) + W0 = (rng1, (nngp_feat_shape[_channel_axis], feature_dim0)) + W1 = (rng2, (nngp_feat_shape[_channel_axis], feature_dim1)) # TensorSRHT of degree 2 for approximating tensor product. 
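+ # Degree-2 TensorSRHT acts here as a tensor-product sketch: inner products
+ # of sketched pairs approximate products of inner products,
+ # <sketch(u, v), sketch(u', v')> ~= <u, u'> * <v, v'> up to sketching error,
+ # so sketching ntk_feat together with kappa0_feat realizes the NTK update
+ # ntk <- ntk * kappa0 without materializing the tensor product.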
tensorsrht = TensorSRHT(rng=rng3,
- input_dim1=ntk_feat_shape[-1],
+ input_dim1=ntk_feat_shape[_channel_axis],
 input_dim2=feature_dim0,
 sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args
 return (new_nngp_feat_shape, new_ntk_feat_shape,
 new_net_shape), (W0, W1, tensorsrht)
- elif method == ReluFeaturesMethod.POLYSKETCH:
- new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,)
+ elif method == ReluFeaturesImplementation.POLYSKETCH:
+ new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + (
+ poly_sketch_dim,) + nngp_feat_shape[_channel_axis + 1:]
+ new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + (
+ sketch_dim,) + ntk_feat_shape[_channel_axis + 1:]
+
+ rng1, rng2 = random.split(rng, 2)
+
 kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count)
 kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count)
 # PolySketch expansion for nngp features.
 if relu_layers_count == 0:
- pts_input_dim = nngp_feat_shape[-1]
+ pts_input_dim = nngp_feat_shape[_channel_axis]
 else:
- pts_input_dim = int(nngp_feat_shape[-1] / 2 + 0.5)
+ pts_input_dim = int(nngp_feat_shape[_channel_axis] / 2 + 0.5)
 polysketch = PolyTensorSketch(rng=rng1,
 input_dim=pts_input_dim,
 sketch_dim=poly_sketch_dim,
@@ -260,9 +281,9 @@ def init_fn(rng, input_shape):
 # TensorSRHT of degree 2 for approximating tensor product.
 if relu_layers_count == 0:
- ts_input_dim = ntk_feat_shape[-1]
+ ts_input_dim = ntk_feat_shape[_channel_axis]
 else:
- ts_input_dim = int(ntk_feat_shape[-1] / 2 + 0.5)
+ ts_input_dim = int(ntk_feat_shape[_channel_axis] / 2 + 0.5)
 tensorsrht = TensorSRHT(rng=rng2,
 input_dim1=ts_input_dim,
 input_dim2=poly_degree *
@@ -273,17 +294,21 @@ def init_fn(rng, input_shape):
 new_net_shape), (polysketch, tensorsrht, kappa0_coeff,
 kappa1_coeff)
- elif method == ReluFeaturesMethod.PSRF:
- new_nngp_feat_shape = nngp_feat_shape[:-1] + (poly_sketch_dim,)
+ elif method == ReluFeaturesImplementation.PSRF:
+ new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + (
+ poly_sketch_dim,) + nngp_feat_shape[_channel_axis + 1:]
+ new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + (
+ sketch_dim,) + ntk_feat_shape[_channel_axis + 1:]
+
 rng1, rng2, rng3 = random.split(rng, 3)
 kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count)
 # PolySketch expansion for nngp features.
 if relu_layers_count == 0:
- pts_input_dim = nngp_feat_shape[-1]
+ pts_input_dim = nngp_feat_shape[_channel_axis]
 else:
- pts_input_dim = int(nngp_feat_shape[-1] / 2 + 0.5)
+ pts_input_dim = int(nngp_feat_shape[_channel_axis] / 2 + 0.5)
 polysketch = PolyTensorSketch(rng=rng1,
 input_dim=pts_input_dim,
 sketch_dim=poly_sketch_dim,
@@ -291,9 +316,9 @@ def init_fn(rng, input_shape):
 # TensorSRHT of degree 2 for approximating tensor product.
 if relu_layers_count == 0:
- ts_input_dim = ntk_feat_shape[-1]
+ ts_input_dim = ntk_feat_shape[_channel_axis]
 else:
- ts_input_dim = int(ntk_feat_shape[-1] / 2 + 0.5)
+ ts_input_dim = int(ntk_feat_shape[_channel_axis] / 2 + 0.5)
 tensorsrht = TensorSRHT(rng=rng2,
 input_dim1=ts_input_dim,
 input_dim2=feature_dim0,
@@ -301,18 +326,27 @@ def init_fn(rng, input_shape):
 # Random vectors for random features of arc-cosine kernel of order 0.
if relu_layers_count == 0: - W0 = random.normal(rng3, (nngp_feat_shape[-1], feature_dim0 // 2)) + W0 = random.normal(rng3, + (nngp_feat_shape[_channel_axis], feature_dim0 // 2)) else: W0 = random.normal( - rng3, (int(nngp_feat_shape[-1] / 2 + 0.5), feature_dim0 // 2)) + rng3, + (int(nngp_feat_shape[_channel_axis] / 2 + 0.5), feature_dim0 // 2)) return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), (W0, polysketch, tensorsrht, kappa1_coeff) - elif method == ReluFeaturesMethod.POLY: + elif method == ReluFeaturesImplementation.POLY: # This only uses the polynomial approximation without sketching. - new_nngp_feat_shape = nngp_feat_shape[:-1] + (prod(nngp_feat_shape[:-1]),) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (prod(ntk_feat_shape[:-1]),) + feat_dim = prod( + tuple(nngp_feat_shape[i] + for i in range(ndim) + if i not in [_channel_axis])) + + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + feat_dim,) + nngp_feat_shape[_channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + feat_dim,) + ntk_feat_shape[_channel_axis + 1:] kappa1_coeff = kappa1_coeffs(poly_degree, relu_layers_count) kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) @@ -320,27 +354,42 @@ def init_fn(rng, input_shape): return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), (kappa0_coeff, kappa1_coeff) - elif method == ReluFeaturesMethod.EXACT: + elif method == ReluFeaturesImplementation.EXACT: # The exact feature map computation is for debug. - new_nngp_feat_shape = nngp_feat_shape[:-1] + (prod(nngp_feat_shape[:-1]),) - new_ntk_feat_shape = ntk_feat_shape[:-1] + (prod(ntk_feat_shape[:-1]),) + feat_dim = prod( + tuple(nngp_feat_shape[i] + for i in range(ndim) + if i not in [_channel_axis])) + + new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( + feat_dim,) + nngp_feat_shape[_channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( + feat_dim,) + ntk_feat_shape[_channel_axis + 1:] return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), () else: raise NotImplementedError(f'Invalid method name: {method}') - def feature_fn(f: Features, input=None, **kwargs) -> Features: + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input, **kwargs) -> Features: + ndim = len(f.nngp_feat.shape) + _channel_axis = channel_axis % ndim + spatial_axes = tuple( + f.nngp_feat.shape[i] for i in range(ndim) if i != _channel_axis) - input_shape: tuple = f.nngp_feat.shape[:-1] - nngp_feat_dim: tuple = f.nngp_feat.shape[-1] - ntk_feat_dim: tuple = f.ntk_feat.shape[-1] + def _convert_to_original(x): + return np.moveaxis(x.reshape(spatial_axes + (-1,)), -1, _channel_axis) - nngp_feat_2d: np.ndarray = f.nngp_feat.reshape(-1, nngp_feat_dim) - ntk_feat_2d: np.ndarray = f.ntk_feat.reshape(-1, ntk_feat_dim) - norms: np.ndarray = f.norms + def _convert_to_2d(x): + feat_dim = x.shape[_channel_axis] + return np.moveaxis(x, _channel_axis, -1).reshape(-1, feat_dim) - if method == ReluFeaturesMethod.RANDFEAT: # Random Features approach. + nngp_feat_2d = _convert_to_2d(f.nngp_feat) + if prod(f.ntk_feat.shape) != 0: + ntk_feat_2d = _convert_to_2d(f.ntk_feat) + + if method == ReluFeaturesImplementation.RANDFEAT: # Random Features approach. 
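+ # Recall the estimator used below (for w ~ N(0, I) and unit-norm x, y):
+ # E[1{w.x > 0} * 1{w.y > 0}] = (pi - theta(x, y)) / (2 pi) = kappa0(x, y) / 2,
+ # so the binary features built from W0 estimate the order-0 arc-cosine
+ # kernel up to a factor of 2, consistent with the 1/sqrt(2) rescaling that
+ # the other methods apply at the end of this function.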
if generate_rand_mtx: W0: np.ndarray = input[0] W1: np.ndarray = input[1] @@ -351,19 +400,25 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: kappa0_feat = (nngp_feat_2d @ W0 > 0) / W0.shape[-1]**0.5 del W0 - nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / - W1.shape[-1]**0.5).reshape(input_shape + (-1,)) + nngp_feat = (np.maximum(nngp_feat_2d @ W1, 0) / W1.shape[-1]**0.5) del W1 - ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat, - real_output=True).reshape(input_shape + - (-1,)) + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat, real_output=True) + + nngp_feat = _convert_to_original(nngp_feat) + ntk_feat = _convert_to_original(ntk_feat) - elif method == ReluFeaturesMethod.POLYSKETCH: + elif method == ReluFeaturesImplementation.POLYSKETCH: polysketch: PolyTensorSketch = input[0] tensorsrht: TensorSRHT = input[1] kappa0_coeff: np.ndarray = input[2] kappa1_coeff: np.ndarray = input[3] + norms = np.linalg.norm(nngp_feat_2d, axis=-1, keepdims=True) + norms = np.maximum(norms, 1e-12) + + nngp_feat_2d /= norms + ntk_feat_2d /= norms + # Apply PolySketch to approximate feature maps of kappa0 & kappa1 kernels. polysketch_feats = polysketch.sketch(nngp_feat_2d) kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) @@ -371,24 +426,35 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: del polysketch_feats # Apply SRHT to kappa1_feat so that dimension of nngp_feat is poly_sketch_dim//2. - nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + - (-1,)) + nngp_feat = polysketch.standardsrht(kappa1_feat) + # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. - ntk_feat = tensorsrht.sketch(ntk_feat_2d, - kappa0_feat).reshape(input_shape + (-1,)) + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat) + + nngp_feat *= norms + ntk_feat *= norms - elif method == ReluFeaturesMethod.PSRF: # Combination of PolySketch and Random Features. + nngp_feat = _convert_to_original(nngp_feat) + ntk_feat = _convert_to_original(ntk_feat) + + elif method == ReluFeaturesImplementation.PSRF: # Combination of PolySketch and Random Features. W0: np.ndarray = input[0] polysketch: PolyTensorSketch = input[1] tensorsrht: TensorSRHT = input[2] kappa1_coeff: np.ndarray = input[3] + norms = np.linalg.norm(nngp_feat_2d, axis=-1, keepdims=True) + norms = np.maximum(norms, 1e-12) + + nngp_feat_2d /= norms + ntk_feat_2d /= norms + + # Apply PolySketch to approximate feature maps of kappa1 kernels. polysketch_feats = polysketch.sketch(nngp_feat_2d) kappa1_feat = polysketch.expand_feats(polysketch_feats, kappa1_coeff) del polysketch_feats - nngp_feat = polysketch.standardsrht(kappa1_feat).reshape(input_shape + - (-1,)) + nngp_feat = polysketch.standardsrht(kappa1_feat) nngp_proj = nngp_feat_2d @ W0 kappa0_feat = np.concatenate( @@ -396,35 +462,56 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features: del W0 # Apply TensorSRHT to ntk_feat_2d and kappa0_feat to approximate their tensor product. - ntk_feat = tensorsrht.sketch(ntk_feat_2d, - kappa0_feat).reshape(input_shape + (-1,)) + ntk_feat = tensorsrht.sketch(ntk_feat_2d, kappa0_feat) + + nngp_feat *= norms + ntk_feat *= norms + + nngp_feat = _convert_to_original(nngp_feat) + ntk_feat = _convert_to_original(ntk_feat) - elif method == ReluFeaturesMethod.POLY: # Polynomial approximation without sketching. + elif method == ReluFeaturesImplementation.POLY: # Polynomial approximation without sketching. 
kappa0_coeff: np.ndarray = input[0] kappa1_coeff: np.ndarray = input[1] + norms = np.linalg.norm(nngp_feat_2d, axis=-1, keepdims=True) + norms = np.maximum(norms, 1e-12) + + nngp_feat_2d /= norms + ntk_feat_2d /= norms + gram_nngp = np.dot(nngp_feat_2d, nngp_feat_2d.T) - nngp_feat = _cholesky(np.polyval(kappa1_coeff[::-1], - gram_nngp)).reshape(input_shape + (-1,)) + nngp_feat = _cholesky(np.polyval(kappa1_coeff[::-1], gram_nngp)) ntk = ntk_feat_2d @ ntk_feat_2d.T kappa0_mat = np.polyval(kappa0_coeff[::-1], gram_nngp) - ntk_feat = _cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + ntk_feat = _cholesky(ntk * kappa0_mat) - elif method == ReluFeaturesMethod.EXACT: # Exact feature map computations via Cholesky decomposition. - nngp_feat = _cholesky(kappa1(nngp_feat_2d)).reshape(input_shape + (-1,)) + nngp_feat *= norms + ntk_feat *= norms - ntk = ntk_feat_2d @ ntk_feat_2d.T - kappa0_mat = kappa0(nngp_feat_2d) - ntk_feat = _cholesky(ntk * kappa0_mat).reshape(input_shape + (-1,)) + nngp_feat = _convert_to_original(nngp_feat) + ntk_feat = _convert_to_original(ntk_feat) + + elif method == ReluFeaturesImplementation.EXACT: # Exact feature map computations via Cholesky decomposition. + nngp_feat = _convert_to_original( + _cholesky(kappa1(nngp_feat_2d, is_x_matrix=True))) + + if prod(f.ntk_feat.shape) != 0: + ntk = ntk_feat_2d @ ntk_feat_2d.T + kappa0_mat = kappa0(nngp_feat_2d, is_x_matrix=True) + ntk_feat = _convert_to_original(_cholesky(ntk * kappa0_mat)) + else: + ntk_feat = f.ntk_feat else: raise NotImplementedError(f'Invalid method name: {method}') - if method != ReluFeaturesMethod.RANDFEAT: - norms /= 2.0**0.5 + if method != ReluFeaturesImplementation.RANDFEAT: + ntk_feat /= 2.0**0.5 + nngp_feat /= 2.0**0.5 - return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms) + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) return init_fn, feature_fn @@ -461,8 +548,7 @@ def feature_fn(f, input=None, **kwargs): nngp_coeffs: np.ndarray = input[1] ntk_coeffs: np.ndarray = input[2] - polysketch_feats = polysketch.sketch( - f.nngp_feat) # f.ntk_feat should be equal to f.nngp_feat. + polysketch_feats = polysketch.sketch(f.nngp_feat) nngp_feat = polysketch.expand_feats(polysketch_feats, nngp_coeffs) ntk_feat = polysketch.expand_feats(polysketch_feats, ntk_coeffs) @@ -474,10 +560,7 @@ def feature_fn(f, input=None, **kwargs): ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1) nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=-1) - norms = f.norms / 2.**(num_layers / 2) * (W_std**(num_layers + 1)) - - return _unnormalize_features( - f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms)) + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) return init_fn, feature_fn @@ -514,90 +597,63 @@ def init_fn(rng, input_shape): if b_std is not None: nngp_feat_dim += 1 - new_nngp_feat_shape = nngp_feat_shape[:-1] + (nngp_feat_dim,) - new_ntk_feat_shape = nngp_feat_shape[:-1] + ( - nngp_feat_dim + ntk_feat_shape[channel_axis] * patch_size,) + nngp_feat_dim = nngp_feat_shape[channel_axis] * patch_size + ( + 1 if b_std is not None else 0) + ntk_feat_dim = nngp_feat_dim + ntk_feat_shape[channel_axis] * patch_size - if len(input_shape) > 2: - return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'C'), () - else: - return (new_nngp_feat_shape, new_ntk_feat_shape, 'C'), () - - def feature_fn(f, input, **kwargs): - """ - Operations under ConvFeatures is concatenation of shifted features. 
Since - they are not linear operations, we first unnormalize features (i.e., - multiplying them by `norms`) and then re-normalize the output features. - """ - is_reversed = f.is_reversed - - f_renormalized: Features = _unnormalize_features(f) - nngp_feat: np.ndarray = f_renormalized.nngp_feat - ntk_feat: np.ndarray = f_renormalized.ntk_feat - - if is_reversed: - filter_shape_ = filter_shape[::-1] - else: - filter_shape_ = filter_shape + new_nngp_feat_shape = nngp_feat_shape[:channel_axis] + ( + nngp_feat_dim,) + nngp_feat_shape[channel_axis + 1:] + new_ntk_feat_shape = ntk_feat_shape[:channel_axis] + ( + ntk_feat_dim,) + ntk_feat_shape[channel_axis + 1:] + + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'C'), () + + @requires(batch_axis=lhs_spec.index('N'), channel_axis=lhs_spec.index('C')) + def feature_fn(f: Features, input, **kwargs): - is_reversed = not f.is_reversed + nngp_feat = f.nngp_feat + + _channel_axis = channel_axis % nngp_feat.ndim nngp_feat = _concat_shifted_features_2d( - nngp_feat, filter_shape_) * W_std / patch_size**0.5 + nngp_feat, filter_shape, dimension_numbers) * W_std / patch_size**0.5 if b_std is not None: - biases = b_std * np.ones(nngp_feat.shape[:-1] + (1,), + biases = b_std * np.ones(nngp_feat.shape[:_channel_axis] + + (1,) + nngp_feat.shape[_channel_axis + 1:], dtype=nngp_feat.dtype) - nngp_feat = np.concatenate((nngp_feat, biases), axis=channel_axis) + nngp_feat = np.concatenate((nngp_feat, biases), axis=_channel_axis) - if f.ntk_feat.ndim == 0: # check if ntk_feat is empty + if prod(f.ntk_feat.shape) == 0: # if ntk_feat is empty skip feature concat ntk_feat = nngp_feat else: ntk_feat = _concat_shifted_features_2d( - ntk_feat, filter_shape_) * W_std / patch_size**0.5 - ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) - - # Re-normalize the features. - norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) - norms = np.where(norms > 0, norms, 1.0) - nngp_feat = nngp_feat / norms - ntk_feat = ntk_feat / norms + f.ntk_feat, filter_shape, dimension_numbers) * W_std / patch_size**0.5 + ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=_channel_axis) return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, - norms=norms, - is_reversed=is_reversed) + batch_axis=out_spec.index('N'), + channel_axis=out_spec.index('C')) return init_fn, feature_fn -def _concat_shifted_features_2d(X: np.ndarray, filter_shape: Sequence[int]): - return _concat_shifted_features( - np.moveaxis(_concat_shifted_features(X, filter_shape[1]), 1, 2), - filter_shape[0]) - - -def _concat_shifted_features(X, filter_size): - """ - Concatenations of shifted image features. If input shape is (N, H, W, C), - the output has the shape (N, H, W, C * filter_size). 
- """
- N, H, W, C = X.shape
- out = np.zeros((N, H, W, C * filter_size), dtype=X.dtype)
- out = out.at[:, :, :, :C].set(X)
- j = 1
- for i in range(1, min((filter_size + 1) // 2, W)):
- out = out.at[:, :, :-i, j * C:(j + 1) * C].set(X[:, :, i:])
- j += 1
- out = out.at[:, :, i:, j * C:(j + 1) * C].set(X[:, :, :-i])
- j += 1
- return out
+def _concat_shifted_features_2d(X: np.ndarray,
+ filter_shape: Sequence[int],
+ dimension_numbers: Optional[Tuple[str, str,
+ str]] = None):
+ return lax.conv_general_dilated_patches(X,
+ filter_shape=filter_shape,
+ window_strides=(1, 1),
+ padding='SAME',
+ dimension_numbers=dimension_numbers)
@layer
def AvgPoolFeatures(window_shape: Sequence[int],
 strides: Optional[Sequence[int]] = None,
- padding: str = stax.Padding.VALID.name,
+ padding: str = 'VALID',
 normalize_edges: bool = False,
 batch_axis: int = 0,
 channel_axis: int = -1):
@@ -605,8 +661,20 @@ def AvgPoolFeatures(window_shape: Sequence[int],
 if window_shape[0] != strides[0] or window_shape[1] != strides[1]:
 raise NotImplementedError('window_shape should be equal to strides.')
- window_shape_kernel = (1,) + tuple(window_shape) + (1,)
- strides_kernel = (1,) + tuple(strides) + (1,)
+ channel_axis %= 4
+ spec = ''.join(
+ c for c in string.ascii_uppercase if c not in ('N', 'C'))[:len(strides)]
+ for a in sorted((batch_axis, channel_axis % (2 + len(strides)))):
+ if a == batch_axis:
+ spec = spec[:a] + 'N' + spec[a:]
+ else:
+ spec = spec[:a] + 'C' + spec[a:]
+
+ _kernel_window_shape = lambda x_: tuple(
+ [x_[0] if s == 'A' else x_[1] if s == 'B' else 1 for s in spec])
+ window_shape_kernel = _kernel_window_shape(window_shape)
+ strides_kernel = _kernel_window_shape(strides)
+
 pooling = lambda x: _pool_kernel(x, Pooling.AVG, window_shape_kernel,
 strides_kernel, Padding(padding),
 normalize_edges, 0)
@@ -621,34 +689,21 @@ def init_fn(rng, input_shape):
 ShapedArray(ntk_feat_shape, np.float32)).shape
- if len(input_shape) > 2:
- return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'A'), ()
- else:
- return (new_nngp_feat_shape, new_ntk_feat_shape, 'A'), ()
+ return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'A'), ()
- def feature_fn(f, input=None, **kwargs):
- # Unnormalize the input features.
- f_renomalized: Features = _unnormalize_features(f)
- nngp_feat: np.ndarray = f_renomalized.nngp_feat
- ntk_feat: np.ndarray = f_renomalized.ntk_feat
+ @requires(batch_axis=batch_axis, channel_axis=channel_axis)
+ def feature_fn(f: Features, input, **kwargs):
+ nngp_feat = f.nngp_feat
+ ntk_feat = f.ntk_feat
 nngp_feat = pooling(nngp_feat)
- if f.ntk_feat.ndim == 0: # check if ntk_feat is empty
+ if prod(f.ntk_feat.shape) == 0: # check if ntk_feat is empty
 ntk_feat = nngp_feat
 else:
 ntk_feat = pooling(ntk_feat)
- # Re-normalize the features.
- norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) - norms = np.where(norms > 0, norms, 1.0) - nngp_feat = nngp_feat / norms - ntk_feat = ntk_feat / norms - - return f.replace(nngp_feat=nngp_feat, - ntk_feat=ntk_feat, - norms=norms, - is_reversed=False) + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) return init_fn, feature_fn @@ -666,16 +721,12 @@ def init_fn(rng, input_shape): new_nngp_feat_shape = _get_output_shape(nngp_feat_shape) new_ntk_feat_shape = _get_output_shape(ntk_feat_shape) - if len(input_shape) > 2: - return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), () - else: - return (new_nngp_feat_shape, new_ntk_feat_shape, ''), () + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), () - def feature_fn(f, input=None, **kwargs): - # Unnormalize the input features. - f_renomalized: Features = _unnormalize_features(f) - nngp_feat: np.ndarray = f_renomalized.nngp_feat - ntk_feat: np.ndarray = f_renomalized.ntk_feat + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input, **kwargs): + nngp_feat = f.nngp_feat + ntk_feat = f.ntk_feat ndim = len(nngp_feat.shape) non_spatial_axes = (batch_axis % ndim, channel_axis % ndim) @@ -684,21 +735,15 @@ def feature_fn(f, input=None, **kwargs): nngp_feat = np.mean(nngp_feat, axis=spatial_axes) ntk_feat = np.mean(ntk_feat, axis=spatial_axes) - # Re-normalize the features. - norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True) - norms = np.where(norms > 0, norms, 1.0) - nngp_feat = nngp_feat / norms - ntk_feat = ntk_feat / norms - + batch_first = batch_axis % ndim < channel_axis % ndim return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, - norms=norms, - is_reversed=False) + batch_axis=0 if batch_first else 1, + channel_axis=1 if batch_first else 0) return init_fn, feature_fn -# TODO(insu): fix reshaping features for general batch/channel axes. 
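# A minimal sanity check for the pooling feature maps (an illustrative,
# self-contained sketch with assumed shapes, not taken from the diff above):
# averaging features over spatial axes reproduces the kernel obtained by
# averaging the exact kernel over all pairs of spatial positions, because
# the Gram matrix is bilinear in the features.
import jax.numpy as jnp
from jax import random as jrandom

feat = jrandom.normal(jrandom.PRNGKey(0), (4, 8, 8, 3))  # (N, H, W, C)
pooled = jnp.mean(feat, axis=(1, 2))  # (N, C), as in GlobalAvgPoolFeatures
k_pool = pooled @ pooled.T  # Gram matrix of the pooled features
k_avg = jnp.einsum('nhwc,mabc->nm', feat, feat) / (8 * 8)**2
assert jnp.allclose(k_pool, k_avg, atol=1e-5)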
@layer def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): @@ -711,38 +756,42 @@ def FlattenFeatures(batch_axis: int = 0, batch_axis_out: int = 0): else: raise ValueError(f'`batch_axis_out` must be 0 or 1, got {batch_axis_out}.') + def get_output_shape(input_shape): + batch_size = input_shape[batch_axis] + channel_size = functools.reduce( + op.mul, input_shape[:batch_axis] + + input_shape[(batch_axis + 1) or len(input_shape):], 1) + if batch_axis_out == 0: + return batch_size, channel_size + return channel_size, batch_size + def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - new_nngp_feat_shape = nngp_feat_shape[:1] + (prod(nngp_feat_shape[1:]),) - new_ntk_feat_shape = ntk_feat_shape[:1] + (prod(ntk_feat_shape[1:]),) - if len(input_shape) > 2: - return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'F'), () - else: - return (new_nngp_feat_shape, new_ntk_feat_shape, 'F'), () + new_nngp_feat_shape = get_output_shape(nngp_feat_shape) + new_ntk_feat_shape = get_output_shape(ntk_feat_shape) - def feature_fn(f, input=None, **kwargs): - f_renomalized: Features = _unnormalize_features(f) - nngp_feat: np.ndarray = f_renomalized.nngp_feat - ntk_feat: np.ndarray = f_renomalized.ntk_feat + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'F'), () - batch_size = f.nngp_feat.shape[batch_axis] - nngp_feat = nngp_feat.reshape(batch_size, -1) / prod( - nngp_feat.shape[1:-1])**0.5 + @requires(batch_axis=batch_axis, channel_axis=None) + def feature_fn(f: Features, input, **kwargs): + nngp_feat = f.nngp_feat - if f.ntk_feat.ndim != 0: # check if ntk_feat is not empty - ntk_feat = ntk_feat.reshape(batch_size, -1) / prod( - ntk_feat.shape[1:-1])**0.5 + batch_size = nngp_feat.shape[batch_axis] + nngp_feat_dim = prod( + nngp_feat.shape) / batch_size / f.nngp_feat.shape[f.channel_axis] + nngp_feat = nngp_feat.reshape(batch_size, -1) / nngp_feat_dim**0.5 - # Re-normalize the features. 
- norms = np.linalg.norm(nngp_feat, axis=-1, keepdims=True) - norms = np.where(norms > 0, norms, 1.0) - nngp_feat = nngp_feat / norms - ntk_feat = ntk_feat / norms + if prod(f.ntk_feat.shape) != 0: # check if ntk_feat is not empty + ntk_feat_dim = prod( + f.ntk_feat.shape) / batch_size / f.ntk_feat.shape[f.channel_axis] + ntk_feat = f.ntk_feat.reshape(batch_size, -1) / ntk_feat_dim**0.5 + else: + ntk_feat = f.ntk_feat.reshape(batch_size, -1) return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, - norms=norms, - is_reversed=False) + batch_axis=batch_axis_out, + channel_axis=channel_axis_out) return init_fn, feature_fn @@ -756,8 +805,14 @@ def LayerNormFeatures(axis: Axes = -1, def init_fn(rng, input_shape): return input_shape, () - def feature_fn(f, input=None, **kwargs): - return f.replace(norms=np.ones_like(f.norms)) + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input, **kwargs): + norms = np.linalg.norm(f.nngp_feat, keepdims=True, axis=channel_axis) + norms = np.maximum(norms, eps) + + nngp_feat = f.nngp_feat / norms + ntk_feat = f.ntk_feat / norms if prod(f.ntk_feat.shape) != 0 else f.ntk_feat + return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) return init_fn, feature_fn @@ -770,29 +825,25 @@ def AggregateFeatures( to_dense: Optional[Callable[[np.ndarray], np.ndarray]] = lambda p: p, implementation: str = AggregateImplementation.DENSE.value): - init_fn = lambda rng, input_shape: (input_shape, ()) + def init_fn(rng, input_shape): + return input_shape, () - def feature_fn(f, input=None, pattern= None, **kwargs): + @requires(batch_axis=batch_axis, channel_axis=channel_axis) + def feature_fn(f: Features, input=None, pattern=None, **kwargs): if pattern is None: raise NotImplementedError('`pattern=None` is not implemented.') - f_renomalized: Features = _unnormalize_features(f) - nngp_feat: np.ndarray = f_renomalized.nngp_feat - ntk_feat: np.ndarray = f_renomalized.ntk_feat + nngp_feat = f.nngp_feat + ntk_feat = f.ntk_feat pattern_T = np.swapaxes(pattern, 1, 2) nngp_feat = np.einsum("bnm,bmc->bnc", pattern_T, nngp_feat) - if f.ntk_feat.ndim == 0: - ntk_feat = nngp_feat - else: - ntk_feat = np.einsum("bnm,bmc->bnc", pattern_T, ntk_feat) - # Re-normalize the features. 
- norms = np.linalg.norm(nngp_feat, axis=channel_axis, keepdims=True)
- norms = np.where(norms > 0, norms, 1.0)
- nngp_feat = nngp_feat / norms
- ntk_feat = ntk_feat / norms
+ if prod(f.ntk_feat.shape) != 0: # check if ntk_feat is not empty
+ ntk_feat = np.einsum("bnm,bmc->bnc", pattern_T, ntk_feat)
+ else:
+ ntk_feat = nngp_feat
- return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, norms=norms)
+ return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat)
 return init_fn, feature_fn
diff --git a/experimental/poly_fitting.py b/experimental/poly_fitting.py
index 06dc8c50..4a23cfe5 100644
--- a/experimental/poly_fitting.py
+++ b/experimental/poly_fitting.py
@@ -1,22 +1,21 @@
 from jax import numpy as np
-from jax.lax import fori_loop
 from jaxopt import OSQP
-def kappa0(x, is_x_matrix=True):
+def kappa0(x, is_x_matrix):
 if is_x_matrix:
 xxt = x @ x.T
- xnormsq = np.linalg.norm(x, axis=-1)**2
+ xnormsq = np.sum(x**2, axis=-1)
 prod = np.outer(xnormsq, xnormsq)
 return (1 - _arccos(xxt / _sqrt(prod)) / np.pi)
 else: # vector input
 return (1 - _arccos(x) / np.pi)
-def kappa1(x, is_x_matrix=True):
+def kappa1(x, is_x_matrix):
 if is_x_matrix:
 xxt = x @ x.T
- xnormsq = np.linalg.norm(x, axis=-1)**2
+ xnormsq = np.sum(x**2, axis=-1)
 prod = np.outer(xnormsq, xnormsq)
 return (_sqrt(prod - xxt**2) +
 (np.pi - _arccos(xxt / _sqrt(prod))) * xxt) / np.pi
@@ -46,17 +45,15 @@ def poly_fitting_qp(xvals: np.ndarray,
 nx = len(xvals)
- x_powers = np.ones((nx, degree + 1), dtype=xvals.dtype)
- for i in range(degree):
- x_powers = x_powers.at[:, i + 1].set(x_powers[:, i] * xvals)
+ x_powers = xvals.reshape(nx,
+ 1)**np.arange(degree + 1).reshape(1, degree + 1)
 y_weighted = fvals * weights
- x_powers_weighted = x_powers.T * weights[None,:]
+ x_powers_weighted = x_powers.T * weights[None, :]
 dx_powers = x_powers[:-1, :] - x_powers[1:, :]
 # OSQP algorithm for solving min_x x'*Q*x + c'*x such that A*x=b, G*x<= h
 P = x_powers_weighted @ x_powers_weighted.T
- Q = .5 * (P.T + P + 1e-5 * np.eye(P.shape[0], dtype=xvals.dtype)
- ) # make sure Q is symmetric
+ Q = .5 * (P.T + P + 1e-5 * np.eye(P.shape[0], dtype=xvals.dtype)) # make sure Q is symmetric
 c = -x_powers_weighted @ y_weighted
 G = np.concatenate((dx_powers, -np.eye(degree + 1)), axis=0)
 h = np.zeros(nx + degree, dtype=xvals.dtype)
@@ -73,8 +72,9 @@ def poly_fitting_qp(xvals: np.ndarray,
 def kappa0_coeffs(degree: int, num_layers: int):
 # A lower bound of kappa0^{(num_layers)} reduces to alpha_ from -1
- alpha_ = fori_loop(0, num_layers, lambda i, x_:
- (x_ + kappa1(x_, is_x_matrix=False)) / 2., -1.)
+ alpha_ = -1
+ for i in range(num_layers):
+ alpha_ = (alpha_ + kappa1(alpha_, is_x_matrix=False)) / 2.
 # Points for polynomial fitting contain (1) equi-spaced ones from [alpha_,1]
 # and (2) non-equi-spaced ones from [0,1]. For (2), cosine function is used
@@ -88,7 +88,7 @@ def kappa0_coeffs(degree: int, num_layers: int):
 # For kappa0, we set all weights to be one.
 weights = np.ones(len(fvals), dtype=xvals.dtype)
- # Coefficients can be obtained by solving QP with OSQP jaxopt. kappa0 has a
+ # Coefficients can be obtained by solving QP with OSQP jaxopt. kappa0 has a
 # sharp slope at x=1, hence we add an equality condition of p_n(1)=f(x).
coeffs = poly_fitting_qp(xvals, fvals, weights, degree, eq_last_point=True) return np.where(coeffs < 1e-5, 0.0, coeffs) @@ -97,9 +97,9 @@ def kappa0_coeffs(degree: int, num_layers: int): def kappa1_coeffs(degree: int, num_layers: int): # A lower bound of kappa1^{(num_layers)} reduces to alpha_ from -1 - alpha_ = fori_loop( - 0, num_layers, lambda i, x_: - (2. * x_ + kappa1(x_, is_x_matrix=False)) / 3., -1.) + alpha_ = -1 + for i in range(num_layers): + alpha_ = (2. * alpha_ + kappa1(alpha_, is_x_matrix=False)) / 3. # Points for polynomial fitting contain (1) equi-spaced ones from [alpha_,1] # and (2) non-equi-spaced ones from [0,1]. For (2), cosine function is used diff --git a/experimental/tests/features_test.py b/experimental/tests/features_test.py index 5dd882aa..b9ec2bb2 100644 --- a/experimental/tests/features_test.py +++ b/experimental/tests/features_test.py @@ -5,66 +5,69 @@ from jax.config import config import jax.numpy as np import jax.random as random -from neural_tangents._src.utils import utils as ntutils +from neural_tangents._src.utils import utils from neural_tangents import stax from tests import test_utils -from experimental.features import DenseFeatures, ReluFeatures, ConvFeatures, AvgPoolFeatures, FlattenFeatures, serial, GlobalAvgPoolFeatures, AggregateFeatures +import experimental.features as ft - -config.update("jax_enable_x64", True) config.parse_flags_with_absl() config.update('jax_numpy_rank_promotion', 'raise') test_utils.update_test_tolerance() -import os -os.environ['CUDA_VISIBLE_DEVICES'] = '' - NUM_DIMS = [128, 256, 512] -WEIGHT_VARIANCES = [0.01, 0.1, 1.] -BIAS_VARIANCES = [None, 0.01, 0.1] +WEIGHT_VARIANCES = [0.5, 1.] +BIAS_VARIANCES = [None, 0.1] + + +def _convert_features_to_matrices(f_, channel_axis=-1): + if isinstance(f_, ft.Features): + nngp = _convert_features_to_matrices(f_.nngp_feat, f_.channel_axis) + ntk = _convert_features_to_matrices(f_.ntk_feat, f_.channel_axis) + return nngp, ntk + elif isinstance(f_, np.ndarray): + channel_dim = f_.shape[channel_axis] + feat = np.moveaxis(f_, channel_axis, -1).reshape(-1, channel_dim) + k_mat = feat @ feat.T + if f_.ndim > 2: + k_mat = utils.zip_axes( + k_mat.reshape( + tuple(f_.shape[i] + for i in range(len(f_.shape)) + if i != channel_axis) * 2)) + return k_mat + else: + raise ValueError + + +def _convert_image_feature_to_kernel(feat): + return utils.zip_axes(np.einsum("ijkc,xyzc->ijkxyz", feat, feat)) + + +def _get_init_data(rng, shape, normalized_output=False): + x = random.normal(rng, shape) + if normalized_output: + return x / np.linalg.norm(x, axis=-1, keepdims=True) + else: + return x class FeaturesTest(test_utils.NeuralTangentsTestCase): - @classmethod - def _get_init_data(cls, rng, shape, normalized_output=False): - x = random.normal(rng, shape) - if normalized_output: - return x / np.linalg.norm(x, axis=-1, keepdims=True) - else: - return x - - @classmethod - def _convert_image_feature_to_kernel(cls, f_): - return ntutils.zip_axes(np.einsum("ijkc,xyzc->ijkxyz", f_, f_)) - - @parameterized.named_parameters( - test_utils.cases_from_list({ - 'testcase_name': - ' [Wstd{}_bstd{}_{}layers_{}] '.format(W_std, b_std, n_layers, - 'jit' if do_jit else ''), - 'W_std': - W_std, - 'b_std': - b_std, - 'n_layers': - n_layers, - 'do_jit': - do_jit, - } for W_std in WEIGHT_VARIANCES for b_std in BIAS_VARIANCES - for n_layers in [1, 2, 3, 4] - for do_jit in [True, False])) + @parameterized.product(W_std=WEIGHT_VARIANCES, + b_std=BIAS_VARIANCES, + n_layers=[1, 2, 3, 4], + do_jit=[True, False]) def 
test_dense_features(self, W_std, b_std, n_layers, do_jit): n, d = 4, 256 rng = random.PRNGKey(1) - x = self._get_init_data(rng, (n, d)) + x = _get_init_data(rng, (n, d)) dense_args = {'out_dim': 1, 'W_std': W_std, 'b_std': b_std} kernel_fn = stax.serial(*[stax.Dense(**dense_args)] * n_layers)[2] - feature_fn = serial(*[DenseFeatures(**dense_args)] * n_layers)[1] + feature_fn = ft.serial(*[ft.DenseFeatures(**dense_args)] * n_layers)[1] if do_jit: kernel_fn = jit(kernel_fn) @@ -76,30 +79,16 @@ def test_dense_features(self, W_std, b_std, n_layers, do_jit): self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T) self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T) - @parameterized.named_parameters( - test_utils.cases_from_list({ - 'testcase_name': - ' [Wstd{}_bstd{}_numlayers{}_{}_{}] '.format( - W_std, b_std, n_layers, relu_method, 'jit' if do_jit else ''), - 'W_std': - W_std, - 'b_std': - b_std, - 'n_layers': - n_layers, - 'relu_method': - relu_method, - 'do_jit': - do_jit, - } for W_std in WEIGHT_VARIANCES for b_std in BIAS_VARIANCES - for relu_method in - ['RANDFEAT', 'POLYSKETCH', 'PSRF', 'POLY', 'EXACT'] - for n_layers in [1, 2, 3, 4] - for do_jit in [True, False])) + @parameterized.product( + W_std=WEIGHT_VARIANCES, + b_std=BIAS_VARIANCES, + n_layers=[1, 2, 3, 4], + relu_method=['RANDFEAT', 'POLYSKETCH', 'PSRF', 'POLY', 'EXACT'], + do_jit=[True, False]) def test_fc_relu_nngp_ntk(self, W_std, b_std, n_layers, relu_method, do_jit): rng = random.PRNGKey(1) n, d = 4, 256 - x = self._get_init_data(rng, (n, d)) + x = _get_init_data(rng, (n, d)) dense_args = {"out_dim": 1, "W_std": W_std, "b_std": b_std} relu_args = {'method': relu_method} @@ -116,17 +105,20 @@ def test_fc_relu_nngp_ntk(self, W_std, b_std, n_layers, relu_method, do_jit): relu_args['poly_degree'] = 4 relu_args['poly_sketch_dim'] = 4096 relu_args['sketch_dim'] = 4096 - elif relu_method in ['EXACT', 'POLY']: + elif relu_method == 'POLY': + relu_args['poly_degree'] = 16 + elif relu_method == 'EXACT': pass else: raise ValueError(relu_method) - _, _, kernel_fn = stax.serial( + kernel_fn = stax.serial( *[stax.Dense(**dense_args), stax.Relu()] * n_layers + - [stax.Dense(**dense_args)]) - init_fn, feature_fn = serial( - *[DenseFeatures(**dense_args), - ReluFeatures(**relu_args)] * n_layers + [DenseFeatures(**dense_args)]) + [stax.Dense(**dense_args)])[2] + init_fn, feature_fn = ft.serial( + *[ft.DenseFeatures(**dense_args), + ft.ReluFeatures(**relu_args)] * n_layers + + [ft.DenseFeatures(**dense_args)]) rng2 = random.PRNGKey(2) _, feat_fn_inputs = init_fn(rng2, x.shape) @@ -136,10 +128,8 @@ def test_fc_relu_nngp_ntk(self, W_std, b_std, n_layers, relu_method, do_jit): feature_fn = jit(feature_fn) k = kernel_fn(x, None) - k_nngp = k.nngp - k_ntk = k.ntk - f = feature_fn(x, feat_fn_inputs) + if np.iscomplexobj(f.nngp_feat) or np.iscomplexobj(f.ntk_feat): nngp_feat = np.concatenate((f.nngp_feat.real, f.nngp_feat.imag), axis=-1) ntk_feat = np.concatenate((f.ntk_feat.real, f.ntk_feat.imag), axis=-1) @@ -148,32 +138,20 @@ def test_fc_relu_nngp_ntk(self, W_std, b_std, n_layers, relu_method, do_jit): k_ntk_approx = f.ntk_feat @ f.ntk_feat.T if relu_method == 'EXACT': - self.assertAllClose(k_nngp, k_nngp_approx) - self.assertAllClose(k_ntk, k_ntk_approx) + self.assertAllClose(k.nngp, k_nngp_approx) + self.assertAllClose(k.ntk, k_ntk_approx) else: - test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.2, 1.) - test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.2, 1.) 
+ test_utils.assert_close_matrices(self, k.nngp, k_nngp_approx, 0.2, 1.) + test_utils.assert_close_matrices(self, k.ntk, k_ntk_approx, 0.2, 1.) - @parameterized.named_parameters( - test_utils.cases_from_list({ - 'testcase_name': - ' [Wstd{}_bstd{}_numlayers{}_{}] '.format( - W_std, b_std, n_layers, 'jit' if do_jit else ''), - 'W_std': - W_std, - 'b_std': - b_std, - 'n_layers': - n_layers, - 'do_jit': - do_jit, - } for W_std in WEIGHT_VARIANCES for b_std in BIAS_VARIANCES - for n_layers in [1, 2, 3, 4] - for do_jit in [True, False])) + @parameterized.product(W_std=WEIGHT_VARIANCES, + b_std=BIAS_VARIANCES, + n_layers=[1, 2, 3, 4], + do_jit=[True, False]) def test_conv_features(self, W_std, b_std, n_layers, do_jit): n, h, w, c = 3, 4, 5, 2 rng = random.PRNGKey(1) - x = self._get_init_data(rng, (n, h, w, c)) + x = _get_init_data(rng, (n, h, w, c)) conv_args = { 'out_chan': 1, @@ -184,7 +162,7 @@ def test_conv_features(self, W_std, b_std, n_layers, do_jit): } kernel_fn = stax.serial(*[stax.Conv(**conv_args)] * n_layers)[2] - feature_fn = serial(*[ConvFeatures(**conv_args)] * n_layers)[1] + feature_fn = ft.serial(*[ft.ConvFeatures(**conv_args)] * n_layers)[1] if do_jit: kernel_fn = jit(kernel_fn) @@ -193,25 +171,22 @@ def test_conv_features(self, W_std, b_std, n_layers, do_jit): k = kernel_fn(x) f = feature_fn(x, [()] * n_layers) - k_nngp_approx = self._convert_image_feature_to_kernel(f.nngp_feat) - k_ntk_approx = self._convert_image_feature_to_kernel(f.ntk_feat) + if k.is_reversed: + nngp_feat = np.moveaxis(f.nngp_feat, 1, 2) + ntk_feat = np.moveaxis(f.ntk_feat, 1, 2) + f = f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat) + + k_nngp_approx = _convert_image_feature_to_kernel(f.nngp_feat) + k_ntk_approx = _convert_image_feature_to_kernel(f.ntk_feat) self.assertAllClose(k.nngp, k_nngp_approx) self.assertAllClose(k.ntk, k_ntk_approx) - @parameterized.named_parameters( - test_utils.cases_from_list({ - 'testcase_name': - ' [nlayers{}_{}] '.format(n_layers, 'jit' if do_jit else ''), - 'n_layers': - n_layers, - 'do_jit': - do_jit - } for n_layers in [1, 2, 3, 4] for do_jit in [True, False])) + @parameterized.product(n_layers=[1, 2, 3, 4], do_jit=[True, False]) def test_avgpool_features(self, n_layers, do_jit): n, h, w, c = 3, 32, 28, 2 rng = random.PRNGKey(1) - x = self._get_init_data(rng, (n, h, w, c)) + x = _get_init_data(rng, (n, h, w, c)) avgpool_args = { 'window_shape': (2, 2), @@ -219,62 +194,209 @@ def test_avgpool_features(self, n_layers, do_jit): 'padding': 'SAME' } - kernel_fn = stax.serial(*[stax.AvgPool(**avgpool_args)] * n_layers)[2] - feature_fn = serial(*[AvgPoolFeatures(**avgpool_args)] * n_layers)[1] + kernel_fn = stax.serial(*[stax.AvgPool(**avgpool_args)] * n_layers + + [stax.Flatten()])[2] + feature_fn = ft.serial(*[ft.AvgPoolFeatures(**avgpool_args)] * n_layers + + [ft.FlattenFeatures()])[1] if do_jit: kernel_fn = jit(kernel_fn) feature_fn = jit(feature_fn) k = kernel_fn(x) - f = feature_fn(x, [()] * n_layers) + f = feature_fn(x, [()] * (n_layers + 1)) - k_nngp_approx = self._convert_image_feature_to_kernel(f.nngp_feat) + k_nngp_approx, k_ntk_approx = _convert_features_to_matrices(f) self.assertAllClose(k.nngp, k_nngp_approx) + if k.ntk.ndim > 0: + self.assertAllClose(k.ntk, k_ntk_approx) + + @parameterized.parameters([{ + 'ndim': nd, + 'do_jit': do_jit + } for nd in [2, 3, 4] for do_jit in [True, False]]) + def test_flatten_features(self, ndim, do_jit): + key = random.PRNGKey(1) + n, h, w, c = 4, 8, 6, 5 + width = 1 + W_std = 1.7 + b_std = 0.1 + if ndim == 2: + input_shape = (n, 
h * w * c) + elif ndim == 3: + input_shape = (n, h * w, c) + elif ndim == 4: + input_shape = (n, h, w, c) + else: + raise absltest.SkipTest() - def test_flatten_features(self): - n, h, w, c = 3, 32, 28, 2 - n_layers = 1 - rng = random.PRNGKey(1) - x = self._get_init_data(rng, (n, h, w, c)) + x = random.normal(key, input_shape) - kernel_fn = stax.serial(*[stax.Flatten()] * n_layers)[2] + dense_kernel = stax.Dense(width, W_std=W_std, b_std=b_std) + dense_feature = ft.DenseFeatures(width, W_std=W_std, b_std=b_std) - k = kernel_fn(x) + relu_kernel = stax.Relu() + relu_feature = ft.ReluFeatures(method='EXACT') - feature_fn = serial(*[FlattenFeatures()] * n_layers)[1] + kernel_fc = stax.serial(dense_kernel, relu_kernel, dense_kernel)[2] + kernel_top = stax.serial(dense_kernel, relu_kernel, dense_kernel, + stax.Flatten())[2] + kernel_mid = stax.serial(dense_kernel, relu_kernel, stax.Flatten(), + dense_kernel)[2] + kernel_bot = stax.serial(stax.Flatten(), dense_kernel, relu_kernel, + dense_kernel)[2] - f = feature_fn(x, [()] * n_layers) + feature_fc = ft.serial(dense_feature, relu_feature, dense_feature)[1] + feature_top = ft.serial(dense_feature, relu_feature, dense_feature, + ft.FlattenFeatures())[1] + feature_mid = ft.serial(dense_feature, relu_feature, ft.FlattenFeatures(), + dense_feature)[1] + feature_bot = ft.serial(ft.FlattenFeatures(), dense_feature, relu_feature, + dense_feature)[1] - self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T) + if do_jit: + kernel_fc = jit(kernel_fc) + kernel_top = jit(kernel_top) + kernel_mid = jit(kernel_mid) + kernel_bot = jit(kernel_bot) + + feature_fc = jit(feature_fc) + feature_top = jit(feature_top) + feature_mid = jit(feature_mid) + feature_bot = jit(feature_bot) + + k_fc = kernel_fc(x) + f_fc = feature_fc(x, [()] * 3) + nngp_fc, ntk_fc = _convert_features_to_matrices(f_fc) + self.assertAllClose(k_fc.nngp, nngp_fc) + self.assertAllClose(k_fc.ntk, ntk_fc) + + k_top = kernel_top(x) + f_top = feature_top(x, [()] * 4) + nngp_top, ntk_top = _convert_features_to_matrices(f_top) + self.assertAllClose(k_top.nngp, nngp_top) + self.assertAllClose(k_top.ntk, ntk_top) + + k_mid = kernel_mid(x) + f_mid = feature_mid(x, [()] * 4) + nngp_mid, ntk_mid = _convert_features_to_matrices(f_mid) + self.assertAllClose(k_mid.nngp, nngp_mid) + self.assertAllClose(k_mid.ntk, ntk_mid) + + k_bot = kernel_bot(x) + f_bot = feature_bot(x, [()] * 4) + nngp_bot, ntk_bot = _convert_features_to_matrices(f_bot) + self.assertAllClose(k_bot.nngp, nngp_bot) + self.assertAllClose(k_bot.ntk, ntk_bot) + + @parameterized.product(ndim=[2, 3, 4], + channel_axis=[1, 2, 3], + n_layers=[1, 2, 3, 4], + use_conv=[True, False], + use_layernorm=[True, False], + do_pool=[True, False], + do_jit=[True, False]) + def test_channel_axis(self, ndim, channel_axis, use_conv, n_layers, + use_layernorm, do_pool, do_jit): + n, h, w, c = 4, 8, 6, 5 + W_std = 1.7 + b_std = 0.1 + key = random.PRNGKey(1) + channel_axis %= ndim + + if ndim == 2: + if channel_axis != 1: + raise absltest.SkipTest() + input_shape = (n, h * w * c) + elif ndim == 3: + if channel_axis == 1: + input_shape = (n, c, h * w) + elif channel_axis == 2: + input_shape = (n, h * w, c) + else: + raise absltest.SkipTest() + elif ndim == 4: + if channel_axis == 1: + input_shape = (n, c, h, w) + dn = ('NCAB', 'ABIO', 'NCAB') + elif channel_axis == 3: + input_shape = (n, h, w, c) + dn = ('NABC', 'ABIO', 'NABC') + else: + raise absltest.SkipTest() + + x = random.normal(key, input_shape) + + if use_conv: + if ndim != 4: + raise absltest.SkipTest() + else: 
+ linear = stax.Conv(1, (3, 3), (1, 1), + 'SAME', + W_std=W_std, + b_std=b_std, + dimension_numbers=dn) + linear_feat = ft.ConvFeatures(1, (3, 3), (1, 1), + W_std=W_std, + b_std=b_std, + dimension_numbers=dn) + else: + linear = stax.Dense(1, + W_std=W_std, + b_std=b_std, + channel_axis=channel_axis) + linear_feat = ft.DenseFeatures(1, + W_std=W_std, + b_std=b_std, + channel_axis=channel_axis) + + layers = [linear, stax.Relu()] * n_layers + layers += [linear] + layers += [stax.LayerNorm(channel_axis, channel_axis=channel_axis) + ] if use_layernorm else [] + layers += [stax.GlobalAvgPool( + channel_axis=channel_axis)] if do_pool else [stax.Flatten()] + kernel_fn = stax.serial(*layers)[2] + + layers = [ + linear_feat, + ft.ReluFeatures(method='EXACT', channel_axis=channel_axis) + ] * n_layers + layers += [linear_feat] + layers += [ft.LayerNormFeatures(channel_axis, channel_axis=channel_axis) + ] if use_layernorm else [] + layers += [ft.GlobalAvgPoolFeatures( + channel_axis=channel_axis)] if do_pool else [ft.FlattenFeatures()] + feature_fn = ft.serial(*layers)[1] - @parameterized.named_parameters( - test_utils.cases_from_list({ - 'testcase_name': - ' [Wstd{}_bstd{}_depth{}_{}_{}] '.format( - W_std, b_std, depth, relu_method, 'jit' if do_jit else ''), - 'W_std': - W_std, - 'b_std': - b_std, - 'depth': - depth, - 'relu_method': - relu_method, - 'do_jit': - do_jit, - } for W_std in WEIGHT_VARIANCES for b_std in BIAS_VARIANCES - for relu_method in ['PSRF'] for depth in [5] - for do_jit in [False])) - def test_myrtle_network(self, W_std, b_std, relu_method, depth, do_jit): - if relu_method in ['RANDFEAT', 'POLYSKETCH', 'PSRF']: - import os - os.environ['CUDA_VISIBLE_DEVICES'] = '' + if do_jit: + kernel_fn = jit(kernel_fn) + feature_fn = jit(feature_fn) + k = kernel_fn(x) + f = feature_fn(x, [()] * len(layers)) + nngp, ntk = _convert_features_to_matrices(f) + self.assertAllClose(k.nngp, nngp) + self.assertAllClose(k.ntk, ntk) + + @parameterized.product( + channel_axis=[1, 3], + W_std=WEIGHT_VARIANCES, + b_std=BIAS_VARIANCES, + relu_method=['RANDFEAT', 'POLYSKETCH', 'PSRF', 'POLY', 'EXACT'], + depth=[5], + do_jit=[True, False]) + def test_myrtle_network(self, channel_axis, W_std, b_std, relu_method, depth, + do_jit): n, h, w, c = 2, 32, 32, 3 rng = random.PRNGKey(1) - x = self._get_init_data(rng, (n, h, w, c)) + if channel_axis == 1: + x = _get_init_data(rng, (n, c, h, w)) + dn = ('NCAB', 'ABIO', 'NCAB') + elif channel_axis == 3: + x = _get_init_data(rng, (n, h, w, c)) + dn = ('NABC', 'ABIO', 'NABC') layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]} @@ -282,36 +404,61 @@ def _get_myrtle_kernel_fn(): conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, - padding='SAME') - + padding='SAME', + dimension_numbers=dn) layers = [] layers += [conv(1, (3, 3)), stax.Relu()] * layer_factor[depth][0] - layers += [stax.AvgPool((2, 2), strides=(2, 2))] + layers += [ + stax.AvgPool((2, 2), strides=(2, 2), channel_axis=channel_axis) + ] layers += [conv(1, (3, 3)), stax.Relu()] * layer_factor[depth][1] - layers += [stax.AvgPool((2, 2), strides=(2, 2))] + layers += [ + stax.AvgPool((2, 2), strides=(2, 2), channel_axis=channel_axis) + ] layers += [conv(1, (3, 3)), stax.Relu()] * layer_factor[depth][2] - layers += [stax.AvgPool((2, 2), strides=(2, 2))] * 3 + layers += [ + stax.AvgPool((2, 2), strides=(2, 2), channel_axis=channel_axis) + ] * 3 layers += [stax.Flatten(), stax.Dense(1, W_std=W_std, b_std=b_std)] return stax.serial(*layers) def _get_myrtle_feature_fn(**relu_args): - conv = 
functools.partial(ConvFeatures, W_std=W_std, b_std=b_std) + conv = functools.partial(ft.ConvFeatures, + W_std=W_std, + b_std=b_std, + padding='SAME', + dimension_numbers=dn) layers = [] - layers += [conv(1, (3, 3)), ReluFeatures(**relu_args) - ] * layer_factor[depth][0] - layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] - layers += [conv(1, (3, 3)), ReluFeatures(**relu_args) - ] * layer_factor[depth][1] - layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] - layers += [conv(1, (3, 3)), ReluFeatures(**relu_args) - ] * layer_factor[depth][2] - layers += [AvgPoolFeatures((2, 2), strides=(2, 2))] * 3 - layers += [FlattenFeatures(), DenseFeatures(1, W_std=W_std, b_std=b_std)] - - return serial(*layers) - - _, _, kernel_fn = _get_myrtle_kernel_fn() + layers += [ + conv(1, (3, 3)), + ft.ReluFeatures(channel_axis=channel_axis, **relu_args) + ] * layer_factor[depth][0] + layers += [ + ft.AvgPoolFeatures((2, 2), strides=(2, 2), channel_axis=channel_axis) + ] + layers += [ + conv(1, (3, 3)), + ft.ReluFeatures(channel_axis=channel_axis, **relu_args) + ] * layer_factor[depth][1] + layers += [ + ft.AvgPoolFeatures((2, 2), strides=(2, 2), channel_axis=channel_axis) + ] + layers += [ + conv(1, (3, 3)), + ft.ReluFeatures(channel_axis=channel_axis, **relu_args) + ] * layer_factor[depth][2] + layers += [ + ft.AvgPoolFeatures((2, 2), strides=(2, 2), channel_axis=channel_axis) + ] * 3 + layers += [ + ft.FlattenFeatures(), + ft.DenseFeatures(1, W_std=W_std, b_std=b_std) + ] + + return ft.serial(*layers) + + kernel_fn = _get_myrtle_kernel_fn()[2] relu_args = {'method': relu_method} if relu_method == 'RANDFEAT': @@ -327,7 +474,9 @@ def _get_myrtle_feature_fn(**relu_args): relu_args['poly_degree'] = 4 relu_args['poly_sketch_dim'] = 2048 relu_args['sketch_dim'] = 2048 - elif relu_method in ['EXACT', 'POLY']: + elif relu_method == 'POLY': + relu_args['poly_degree'] = 16 + elif relu_method == 'EXACT': pass else: raise ValueError(relu_method) @@ -359,58 +508,28 @@ def _get_myrtle_feature_fn(**relu_args): test_utils.assert_close_matrices(self, k_nngp, k_nngp_approx, 0.2, 1.) test_utils.assert_close_matrices(self, k_ntk, k_ntk_approx, 0.2, 1.) 
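# A sketch of the accuracy criterion behind the assertions above (the helper
# below is hypothetical; the tests use test_utils.assert_close_matrices):
# a feature map F is validated by comparing K_approx = F @ F.T against the
# exact kernel K in relative norm, with tolerance around 0.2 for the
# randomized methods and exact agreement (up to numerics) for 'EXACT'.
import jax.numpy as jnp

def _relative_distance(k_exact: jnp.ndarray, k_approx: jnp.ndarray) -> float:
  # Frobenius-norm variant of the closeness check used by the tests.
  return float(jnp.linalg.norm(k_exact - k_approx) / jnp.linalg.norm(k_exact))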
- def test_global_average_pooling_features(self): - rng = random.PRNGKey(1) - input_shape = (4, 5, 6, 7) - x = random.normal(rng, input_shape) - - _, _, kernel_fn = stax.serial( - stax.Conv(1, (3, 3), padding='SAME'), - stax.Relu(), - stax.GlobalAvgPool() - ) - - _, feature_fn = serial( - ConvFeatures(1, (3, 3)), - ReluFeatures(method='EXACT'), - GlobalAvgPoolFeatures() - ) - - k = jit(kernel_fn)(x) - f = jit(feature_fn)(x, [()] * 3) - - self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T) - self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T) - def test_aggregate_features(self): rng = random.PRNGKey(1) rng1, rng2 = random.split(rng, 2) batch_size = 4 num_channels = 3 - shape = (5, ) + shape = (5,) width = 1 x = random.normal(rng1, (batch_size,) + shape + (num_channels,)) pattern = random.uniform(rng2, (batch_size,) + shape * 2) - _, _, kernel_fn = stax.serial( - stax.Dense(width, W_std=2**0.5), - stax.Relu(), - stax.Aggregate(), - stax.GlobalAvgPool(), - stax.Dense(width) - ) + kernel_fn = stax.serial(stax.Dense(width, W_std=2**0.5), stax.Relu(), + stax.Aggregate(), stax.GlobalAvgPool(), + stax.Dense(width))[2] k = jit(kernel_fn)(x, None, pattern=(pattern, pattern)) - _, feature_fn = serial( - DenseFeatures(width, W_std=2**0.5), - ReluFeatures(method='EXACT'), - AggregateFeatures(), - GlobalAvgPoolFeatures(), - DenseFeatures(width) - ) + feature_fn = ft.serial(ft.DenseFeatures(width, W_std=2**0.5), + ft.ReluFeatures(method='EXACT'), + ft.AggregateFeatures(), ft.GlobalAvgPoolFeatures(), + ft.DenseFeatures(width))[1] f = feature_fn(x, [()] * 5, **{'pattern': pattern}) self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T) diff --git a/experimental/tests/sketching_test.py b/experimental/tests/sketching_test.py index b5334c3a..16322403 100644 --- a/experimental/tests/sketching_test.py +++ b/experimental/tests/sketching_test.py @@ -4,7 +4,6 @@ import jax.numpy as np from math import factorial import jax.random as random -from jax import test_util as jtu from experimental.sketching import PolyTensorSketch from tests import test_utils @@ -12,7 +11,7 @@ NUM_DIMS = [64, 256, 1024] -class SketchingTest(jtu.JaxTestCase): +class SketchingTest(test_utils.NeuralTangentsTestCase): @classmethod def _get_init_data(cls, rng, shape, normalized_output=True): @@ -22,14 +21,12 @@ def _get_init_data(cls, rng, shape, normalized_output=True): else: return x - @parameterized.named_parameters( - jtu.cases_from_list({ - 'testcase_name': f' [n{n}_d{d}]', - 'n': 4, - 'd': 32, - 'sketch_dim': 1024, - 'degree': 16 - } for n in NUM_POINTS for d in NUM_DIMS)) + @parameterized.parameters({ + 'n': n, + 'd': d, + 'sketch_dim': 1024, + 'degree': 16 + } for n in NUM_POINTS for d in NUM_DIMS) def test_exponential_kernel(self, n, d, sketch_dim, degree): rng = random.PRNGKey(1) x = self._get_init_data(rng, (n, d), True) From 7f0b5f78349219de9f236558aef6ca2fc707863e Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 13 Jul 2022 13:47:50 +0900 Subject: [PATCH 41/44] Add setup.py --- setup.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/setup.py b/setup.py index d4183204..6861b15e 100644 --- a/setup.py +++ b/setup.py @@ -29,11 +29,7 @@ 'jax>=0.3.13', 'frozendict>=2.3', 'typing_extensions>=4.0.1', -<<<<<<< HEAD - 'jaxopt>=0.3.1', -======= 'tf2jax>=0.3.0', ->>>>>>> upstream/main ] From 2036f86a18bfcab820c534e971b21b6210c374e9 Mon Sep 17 00:00:00 2001 From: insuhan Date: Wed, 13 Jul 2022 13:58:16 +0900 Subject: [PATCH 42/44] Add jaxopt in setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py 
b/setup.py index 6861b15e..a376cd63 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ 'jax>=0.3.13', 'frozendict>=2.3', 'typing_extensions>=4.0.1', + 'jaxopt>=0.3.1', 'tf2jax>=0.3.0', ] From af524d2fe56e812cb969c91a486383c4dbf3b104 Mon Sep 17 00:00:00 2001 From: insuhan Date: Thu, 14 Jul 2022 00:40:06 +0900 Subject: [PATCH 43/44] Change the third argument of init_fn --- experimental/features.py | 28 ++++++++++++++-------------- experimental/tests/features_test.py | 1 - 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/experimental/features.py b/experimental/features.py index f69bca0f..a9230aca 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -70,7 +70,7 @@ def init_fn_any(rng, input_shape_any, **kwargs): if _is_single_shape(input_shape_any): # Add a dummy shape for ntk_feat dummy_shape = (-1,) + (0,) * (len(input_shape_any) - 1) - input_shape = (input_shape_any, dummy_shape, '') + input_shape = (input_shape_any, dummy_shape, 0) return init_fn(rng, input_shape, **kwargs) else: return init_fn(rng, input_shape_any, **kwargs) @@ -180,7 +180,7 @@ def init_fn(rng, input_shape): new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + ( ntk_feat_dim,) + ntk_feat_shape[_channel_axis + 1:] - return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'D'), () + return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), () @requires(batch_axis=batch_axis, channel_axis=channel_axis) def feature_fn(f: Features, input, **kwargs): @@ -222,9 +222,8 @@ def ReluFeatures(method: str = 'RANDFEAT', def init_fn(rng, input_shape): nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1] - net_shape = input_shape[2] - relu_layers_count = net_shape.count('R') - new_net_shape = net_shape + 'R' + relu_layers_count = input_shape[2] + new_relu_layers_count = relu_layers_count + 1 ndim = len(nngp_feat_shape) _channel_axis = channel_axis % ndim @@ -253,7 +252,7 @@ def init_fn(rng, input_shape): sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args return (new_nngp_feat_shape, new_ntk_feat_shape, - new_net_shape), (W0, W1, tensorsrht) + new_relu_layers_count), (W0, W1, tensorsrht) elif method == ReluFeaturesImplementation.POLYSKETCH: new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( @@ -291,8 +290,8 @@ def init_fn(rng, input_shape): sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args return (new_nngp_feat_shape, new_ntk_feat_shape, - new_net_shape), (polysketch, tensorsrht, kappa0_coeff, - kappa1_coeff) + new_relu_layers_count), (polysketch, tensorsrht, kappa0_coeff, + kappa1_coeff) elif method == ReluFeaturesImplementation.PSRF: new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + ( @@ -334,7 +333,7 @@ def init_fn(rng, input_shape): (int(nngp_feat_shape[_channel_axis] / 2 + 0.5), feature_dim0 // 2)) return (new_nngp_feat_shape, new_ntk_feat_shape, - new_net_shape), (W0, polysketch, tensorsrht, kappa1_coeff) + new_relu_layers_count), (W0, polysketch, tensorsrht, kappa1_coeff) elif method == ReluFeaturesImplementation.POLY: # This only uses the polynomial approximation without sketching. @@ -352,7 +351,7 @@ def init_fn(rng, input_shape): kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count) return (new_nngp_feat_shape, new_ntk_feat_shape, - new_net_shape), (kappa0_coeff, kappa1_coeff) + new_relu_layers_count), (kappa0_coeff, kappa1_coeff) elif method == ReluFeaturesImplementation.EXACT: # The exact feature map computation is for debug. 
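# Why a running count of preceding ReLU layers is threaded through init_fn
# (a hedged summary of the kappa fitting code in poly_fitting.py): after each
# ReLU, entries of the normalized NNGP kernel are confined to a shrinking
# interval [alpha, 1], with alpha following the recursion from kappa0_coeffs,
# so the polynomial fit only needs to be accurate on that interval.
import math

def _kappa1_scalar(u: float) -> float:
  # kappa1 on scalars, matching poly_fitting.kappa1 on unit-norm inputs.
  return (math.sqrt(1. - u * u) + (math.pi - math.acos(u)) * u) / math.pi

alpha = -1.
for _ in range(3):  # e.g. three preceding ReLU layers
  alpha = (alpha + _kappa1_scalar(alpha)) / 2.
print(alpha)  # lower end of the fitting interval [alpha, 1]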
From af524d2fe56e812cb969c91a486383c4dbf3b104 Mon Sep 17 00:00:00 2001
From: insuhan
Date: Thu, 14 Jul 2022 00:40:06 +0900
Subject: [PATCH 43/44] Change the third argument of init_fn

---
 experimental/features.py            | 28 ++++++++++++++--------------
 experimental/tests/features_test.py |  1 -
 2 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/experimental/features.py b/experimental/features.py
index f69bca0f..a9230aca 100644
--- a/experimental/features.py
+++ b/experimental/features.py
@@ -70,7 +70,7 @@ def init_fn_any(rng, input_shape_any, **kwargs):
     if _is_single_shape(input_shape_any):
       # Add a dummy shape for ntk_feat
       dummy_shape = (-1,) + (0,) * (len(input_shape_any) - 1)
-      input_shape = (input_shape_any, dummy_shape, '')
+      input_shape = (input_shape_any, dummy_shape, 0)
       return init_fn(rng, input_shape, **kwargs)
     else:
       return init_fn(rng, input_shape_any, **kwargs)
@@ -180,7 +180,7 @@ def init_fn(rng, input_shape):
     new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + (
         ntk_feat_dim,) + ntk_feat_shape[_channel_axis + 1:]
 
-    return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'D'), ()
+    return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), ()
 
   @requires(batch_axis=batch_axis, channel_axis=channel_axis)
   def feature_fn(f: Features, input, **kwargs):
@@ -222,9 +222,8 @@ def ReluFeatures(method: str = 'RANDFEAT',
 
   def init_fn(rng, input_shape):
     nngp_feat_shape, ntk_feat_shape = input_shape[0], input_shape[1]
-    net_shape = input_shape[2]
-    relu_layers_count = net_shape.count('R')
-    new_net_shape = net_shape + 'R'
+    relu_layers_count = input_shape[2]
+    new_relu_layers_count = relu_layers_count + 1
 
     ndim = len(nngp_feat_shape)
     _channel_axis = channel_axis % ndim
@@ -253,7 +252,7 @@ def init_fn(rng, input_shape):
           sketch_dim=sketch_dim).init_sketches()  # pytype:disable=wrong-keyword-args
 
       return (new_nngp_feat_shape, new_ntk_feat_shape,
-              new_net_shape), (W0, W1, tensorsrht)
+              new_relu_layers_count), (W0, W1, tensorsrht)
 
     elif method == ReluFeaturesImplementation.POLYSKETCH:
       new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + (
@@ -291,8 +290,8 @@ def init_fn(rng, input_shape):
           sketch_dim=sketch_dim).init_sketches()  # pytype:disable=wrong-keyword-args
 
       return (new_nngp_feat_shape, new_ntk_feat_shape,
-              new_net_shape), (polysketch, tensorsrht, kappa0_coeff,
-                               kappa1_coeff)
+              new_relu_layers_count), (polysketch, tensorsrht, kappa0_coeff,
+                                       kappa1_coeff)
 
     elif method == ReluFeaturesImplementation.PSRF:
       new_nngp_feat_shape = nngp_feat_shape[:_channel_axis] + (
@@ -334,7 +333,7 @@ def init_fn(rng, input_shape):
           (int(nngp_feat_shape[_channel_axis] / 2 + 0.5), feature_dim0 // 2))
 
       return (new_nngp_feat_shape, new_ntk_feat_shape,
-              new_net_shape), (W0, polysketch, tensorsrht, kappa1_coeff)
+              new_relu_layers_count), (W0, polysketch, tensorsrht, kappa1_coeff)
 
     elif method == ReluFeaturesImplementation.POLY:
       # This only uses the polynomial approximation without sketching.
@@ -352,7 +351,7 @@ def init_fn(rng, input_shape):
       kappa0_coeff = kappa0_coeffs(poly_degree, relu_layers_count)
 
       return (new_nngp_feat_shape, new_ntk_feat_shape,
-              new_net_shape), (kappa0_coeff, kappa1_coeff)
+              new_relu_layers_count), (kappa0_coeff, kappa1_coeff)
 
     elif method == ReluFeaturesImplementation.EXACT:
      # The exact feature map computation is for debug.
@@ -366,7 +365,8 @@ def init_fn(rng, input_shape):
       new_ntk_feat_shape = ntk_feat_shape[:_channel_axis] + (
           feat_dim,) + ntk_feat_shape[_channel_axis + 1:]
 
-      return (new_nngp_feat_shape, new_ntk_feat_shape, new_net_shape), ()
+      return (new_nngp_feat_shape, new_ntk_feat_shape,
+              new_relu_layers_count), ()
 
     else:
       raise NotImplementedError(f'Invalid method name: {method}')
@@ -606,7 +606,7 @@ def init_fn(rng, input_shape):
     new_ntk_feat_shape = ntk_feat_shape[:channel_axis] + (
         ntk_feat_dim,) + ntk_feat_shape[channel_axis + 1:]
 
-    return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'C'), ()
+    return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), ()
 
   @requires(batch_axis=lhs_spec.index('N'), channel_axis=lhs_spec.index('C'))
   def feature_fn(f: Features, input, **kwargs):
@@ -689,7 +689,7 @@ def init_fn(rng, input_shape):
         ShapedArray(ntk_feat_shape, np.float32)).shape
 
-    return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'A'), ()
+    return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), ()
 
   @requires(batch_axis=batch_axis, channel_axis=channel_axis)
   def feature_fn(f: Features, input, **kwargs):
@@ -770,7 +770,7 @@ def init_fn(rng, input_shape):
     new_nngp_feat_shape = get_output_shape(nngp_feat_shape)
     new_ntk_feat_shape = get_output_shape(ntk_feat_shape)
 
-    return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2] + 'F'), ()
+    return (new_nngp_feat_shape, new_ntk_feat_shape, input_shape[2]), ()
 
   @requires(batch_axis=batch_axis, channel_axis=None)
   def feature_fn(f: Features, input, **kwargs):
diff --git a/experimental/tests/features_test.py b/experimental/tests/features_test.py
index b9ec2bb2..22175bc2 100644
--- a/experimental/tests/features_test.py
+++ b/experimental/tests/features_test.py
@@ -303,7 +303,6 @@ def test_channel_axis(self, ndim, channel_axis, use_conv, n_layers,
     W_std = 1.7
     b_std = 0.1
     key = random.PRNGKey(1)
-    channel_axis %= ndim
 
     if ndim == 2:
       if channel_axis != 1:
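PATCH 43 above replaces the string-based layer log with an integer. Previously the third element of the shape triple was a string of layer tags ('D' for dense, 'R' for ReLU, 'C' for conv, 'A' for aggregate, 'F' for flatten) and the ReLU depth was recovered with `net_shape.count('R')`; now the slot holds the count itself, seeded with `0` by `init_fn_any` and incremented only by `ReluFeatures.init_fn`, whose kappa0/kappa1 coefficient computations are the sole consumers. A schematic of the new convention follows; the concrete values are illustrative, not a fixed API contract.

```python
# The triple threaded through every init_fn after this patch. The third slot
# is now an int counting the ReLU layers seen so far (previously a string
# such as 'DRD' built by appending one tag per layer).
nngp_feat_shape = (5, 4)   # e.g. 5 inputs with 4 features each
ntk_feat_shape = (-1, 0)   # dummy NTK feature shape before the first layer
relu_layers_count = 0      # seeded with 0, as in init_fn_any above

input_shape = (nngp_feat_shape, ntk_feat_shape, relu_layers_count)

# DenseFeatures, ConvFeatures, AggregateFeatures and FlattenFeatures now pass
# input_shape[2] through unchanged; ReluFeatures returns it incremented by one,
# so kappa0_coeffs(poly_degree, relu_layers_count) sees the correct depth.
```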
From a16098c00c973929e4c2c004c4e67cedc293efdc Mon Sep 17 00:00:00 2001
From: insuhan
Date: Thu, 14 Jul 2022 01:42:55 +0900
Subject: [PATCH 44/44] Add ReluNTKFeatures test

---
 experimental/features.py            | 27 ++++++++++++++---------
 experimental/tests/features_test.py | 31 +++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/experimental/features.py b/experimental/features.py
index a9230aca..96555e28 100644
--- a/experimental/features.py
+++ b/experimental/features.py
@@ -521,12 +521,14 @@ def _cholesky(mat):
 
 
 @layer
-def ReluNTKFeatures(
-    num_layers: int,
-    poly_degree: int = 16,
-    poly_sketch_dim: int = 1024,
-    W_std: float = 1.,
-):
+def ReluNTKFeatures(num_layers: int,
+                    poly_degree: int = 16,
+                    poly_sketch_dim: int = 1024,
+                    batch_axis: int = 0,
+                    channel_axis: int = -1):
+
+  if batch_axis != 0 or channel_axis != -1:
+    raise NotImplementedError(f'Not supported axes.')
 
   def init_fn(rng, input_shape):
     input_dim = input_shape[0][-1]
@@ -541,14 +543,18 @@ def init_fn(rng, input_shape):
 
     return (), (polysketch, nngp_coeffs, ntk_coeffs)
 
-  def feature_fn(f, input=None, **kwargs):
+  @requires(batch_axis=batch_axis, channel_axis=channel_axis)
+  def feature_fn(f: Features, input=None, **kwargs):
     input_shape = f.nngp_feat.shape[:-1]
 
     polysketch: PolyTensorSketch = input[0]
    nngp_coeffs: np.ndarray = input[1]
     ntk_coeffs: np.ndarray = input[2]
 
-    polysketch_feats = polysketch.sketch(f.nngp_feat)
+    norms = np.linalg.norm(f.nngp_feat, axis=channel_axis, keepdims=True)
+    nngp_feat = f.nngp_feat / norms
+
+    polysketch_feats = polysketch.sketch(nngp_feat)
     nngp_feat = polysketch.expand_feats(polysketch_feats, nngp_coeffs)
     ntk_feat = polysketch.expand_feats(polysketch_feats, ntk_coeffs)
@@ -557,8 +563,11 @@ def feature_fn(f, input=None, **kwargs):
     ntk_feat = polysketch.standardsrht(ntk_feat).reshape(input_shape + (-1,))
 
     # Convert complex features to real ones.
-    ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1)
     nngp_feat = np.concatenate((nngp_feat.real, nngp_feat.imag), axis=-1)
+    ntk_feat = np.concatenate((ntk_feat.real, ntk_feat.imag), axis=-1)
+
+    nngp_feat *= norms / 2**(num_layers / 2.)
+    ntk_feat *= norms / 2**(num_layers / 2.)
 
     return f.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat)
diff --git a/experimental/tests/features_test.py b/experimental/tests/features_test.py
index 22175bc2..5b4d23d6 100644
--- a/experimental/tests/features_test.py
+++ b/experimental/tests/features_test.py
@@ -534,6 +534,37 @@ def test_aggregate_features(self):
 
     self.assertAllClose(k.nngp, f.nngp_feat @ f.nngp_feat.T)
     self.assertAllClose(k.ntk, f.ntk_feat @ f.ntk_feat.T)
 
+  @parameterized.product(n_layers=[1, 2, 3, 4, 5], do_jit=[True, False])
+  def test_onepass_fc_relu_nngp_ntk(self, n_layers, do_jit):
+    rng = random.PRNGKey(1)
+    n, d = 4, 256
+    x = _get_init_data(rng, (n, d))
+
+    kernel_fn = stax.serial(*[stax.Dense(1), stax.Relu()] * n_layers +
+                            [stax.Dense(1)])[2]
+
+    poly_degree = 8
+    poly_sketch_dim = 4096
+
+    init_fn, feature_fn = ft.ReluNTKFeatures(n_layers, poly_degree,
+                                             poly_sketch_dim)
+
+    rng2 = random.PRNGKey(2)
+    _, feat_fn_inputs = init_fn(rng2, x.shape)
+
+    if do_jit:
+      kernel_fn = jit(kernel_fn)
+      feature_fn = jit(feature_fn)
+
+    k = kernel_fn(x)
+    f = feature_fn(x, feat_fn_inputs)
+
+    k_nngp_approx = f.nngp_feat @ f.nngp_feat.T
+    k_ntk_approx = f.ntk_feat @ f.ntk_feat.T
+
+    test_utils.assert_close_matrices(self, k.nngp, k_nngp_approx, 0.2, 1.)
+    test_utils.assert_close_matrices(self, k.ntk, k_ntk_approx, 0.2, 1.)
+
 
 if __name__ == "__main__":
  absltest.main()
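The new `test_onepass_fc_relu_nngp_ntk` doubles as usage documentation for `ReluNTKFeatures`, which approximates a whole depth-`n_layers` fully-connected ReLU network in one pass: inputs are row-normalized, sketched once by `PolyTensorSketch`, expanded against per-depth NNGP and NTK coefficients, and rescaled by the stored norms. Below is a usage sketch distilled from that test; the module path `experimental.features` and the explicit row normalization (mirroring what `_get_init_data` presumably does) are assumptions, not shown in the patch itself.

```python
import jax.numpy as np
from jax import jit, random

import experimental.features as ft  # assumed import path, aliased as in the tests

n_layers, poly_degree, poly_sketch_dim = 3, 8, 4096

x = random.normal(random.PRNGKey(1), (4, 256))
x /= np.linalg.norm(x, axis=-1, keepdims=True)  # assuming row-normalized inputs

# One module approximates the entire depth-3 fully-connected ReLU network.
init_fn, feature_fn = ft.ReluNTKFeatures(n_layers, poly_degree, poly_sketch_dim)
_, feat_fn_inputs = init_fn(random.PRNGKey(2), x.shape)

f = jit(feature_fn)(x, feat_fn_inputs)
k_nngp_approx = f.nngp_feat @ f.nngp_feat.T  # approximates the NNGP kernel
k_ntk_approx = f.ntk_feat @ f.ntk_feat.T     # approximates the NTK
```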