Skip to content

Commit 46fedcd

Browse files
authored
Merge pull request #1394 from wjmaddox/kplt_general_roots
Make KroneckerProductLazyTensor call each component's root decomposition
2 parents 1b8e880 + 3d8e81d commit 46fedcd

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

gpytorch/lazy/kronecker_product_lazy_tensor.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,18 +199,26 @@ def _matmul(self, rhs):
199199
def root_decomposition(self, method: Optional[str] = None):
200200
from gpytorch.lazy import RootLazyTensor
201201

202-
if method == "symeig" or method is None:
203-
evals, evecs = self._symeig(eigenvectors=True, return_evals_as_lazy=True)
204-
# TODO: only use non-zero evals (req. dealing w/ batches...)
205-
f_list = [
206-
evec * eval.diag().clamp(0.0).sqrt().unsqueeze(-2)
207-
for eval, evec in zip(evals.lazy_tensors, evecs.lazy_tensors)
208-
]
209-
F = KroneckerProductLazyTensor(*f_list)
210-
return RootLazyTensor(F)
211-
else:
202+
# return a dense root decomposition if the matrix is small
203+
if self.shape[-1] <= settings.max_cholesky_size.value():
212204
return super().root_decomposition(method=method)
213205

206+
root_list = [lt.root_decomposition(method=method).root for lt in self.lazy_tensors]
207+
kronecker_root = KroneckerProductLazyTensor(*root_list)
208+
return RootLazyTensor(kronecker_root)
209+
210+
@cached(name="root_inv_decomposition")
211+
def root_inv_decomposition(self, initial_vectors=None, test_vectors=None):
212+
from gpytorch.lazy import RootLazyTensor
213+
214+
# return a dense root decomposition if the matrix is small
215+
if self.shape[-1] <= settings.max_cholesky_size.value():
216+
return super().root_inv_decomposition()
217+
218+
root_list = [lt.root_inv_decomposition().root for lt in self.lazy_tensors]
219+
kronecker_root = KroneckerProductLazyTensor(*root_list)
220+
return RootLazyTensor(kronecker_root)
221+
214222
@cached(name="size")
215223
def _size(self):
216224
left_size = _prod(lazy_tensor.size(-2) for lazy_tensor in self.lazy_tensors)

test/lazy/test_kronecker_product_lazy_tensor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def kron(a, b):
2020

2121
class TestKroneckerProductLazyTensor(LazyTensorTestCase, unittest.TestCase):
2222
seed = 0
23-
should_call_lanczos = False
23+
should_call_lanczos = True
2424

2525
def create_lazy_tensor(self):
2626
a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
@@ -40,12 +40,14 @@ def evaluate_lazy_tensor(self, lazy_tensor):
4040

4141
class TestKroneckerProductLazyTensorBatch(TestKroneckerProductLazyTensor):
4242
seed = 0
43-
should_call_lanczos = False
43+
should_call_lanczos = True
4444

4545
def create_lazy_tensor(self):
4646
a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float).repeat(3, 1, 1)
4747
b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float).repeat(3, 1, 1)
48-
c = torch.tensor([[4, 0, 1, 0], [0, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]], dtype=torch.float).repeat(3, 1, 1)
48+
c = torch.tensor([[4, 0.1, 1, 0], [0.1, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]], dtype=torch.float).repeat(
49+
3, 1, 1
50+
)
4951
a.requires_grad_(True)
5052
b.requires_grad_(True)
5153
c.requires_grad_(True)

0 commit comments

Comments
 (0)