@@ -1840,13 +1840,16 @@ vector<Stmt> LowererImpl::codeToInitializeDenseAcceleratorArrays(Where where) {
1840
1840
// Returns true if the following conditions are met:
1841
1841
// 1) The temporary is a dense vector
1842
1842
// 2) There is only one value on the right hand side of the consumer
1843
- // -- We would need to handle sparse acceleration in the merge lattices for multiple operands on the RHS
1844
- // 3) There are no reduced accesses
1845
- // 4) The left hand side of the where consumer is sparse TODO: update this
1846
- // 5) CPU Code is being generated (TEMPORARY - This should be removed)
1847
- // -- The sorting calls and calloc call in lower where are CPU specific. We could map calloc to a cudaMalloc
1848
- // and use a library like CUB to emit the sort. CUB support is built into CUDA 11 but not prior versions
1849
- // of CUDA so in that case, we'd probably need to include the CUB headers in the generated code.
1843
+ // -- We would need to handle sparse acceleration in the merge lattices for
1844
+ // multiple operands on the RHS
1845
+ // 3) The left hand side of the where consumer is sparse, if the consumer is an
1846
+ // assignment
1847
+ // 4) CPU Code is being generated (TEMPORARY - This should be removed)
1848
+ // -- The sorting calls and calloc call in lower where are CPU specific. We
1849
+ // could map calloc to a cudaMalloc and use a library like CUB to emit
1850
+ // the sort. CUB support is built into CUDA 11 but not prior versions of
1851
+ // CUDA so in that case, we'd probably need to include the CUB headers in
1852
+ // the generated code.
1850
1853
std::pair<bool ,bool > LowererImpl::canAccelerateDenseTemp (Where where) {
1851
1854
// TODO: TEMPORARY -- Needs to be removed
1852
1855
if (should_use_CUDA_codegen ()) {
@@ -1868,20 +1871,13 @@ std::pair<bool,bool> LowererImpl::canAccelerateDenseTemp(Where where) {
1868
1871
return std::make_pair (false , false );
1869
1872
}
1870
1873
1871
- std::tie (resultAccesses, reducedAccesses) = getResultAccesses (where.getConsumer ());
1872
- // (3) Contains reduced accesses
1873
- if (!reducedAccesses.empty ()) {
1874
- return std::make_pair (false , false );
1875
- }
1876
-
1877
1874
// no or multiple results?
1878
1875
if (resultAccesses.size () > 1 || resultAccesses.empty ()) {
1879
1876
return std::make_pair (false , false );
1880
1877
}
1881
1878
1882
- // (4) Level of result is sparse
1883
- // No check for size of tempVar since we enforced the temporary is a vector and if there is only one RHS value,
1884
- // it must (should?) be the temporary
1879
+ // No check for size of tempVar since we enforced the temporary is a vector
1880
+ // and if there is only one RHS value, it must (should?) be the temporary
1885
1881
std::vector<IndexVar> tempVar = inputAccesses[0 ].getIndexVars ();
1886
1882
1887
1883
// Get index vars in result.
@@ -1900,9 +1896,7 @@ std::pair<bool,bool> LowererImpl::canAccelerateDenseTemp(Where where) {
1900
1896
TensorVar resultTensor = resultAccesses[0 ].getTensorVar ();
1901
1897
int modeIndex = resultTensor.getFormat ().getModeOrdering ()[index];
1902
1898
ModeFormat varFmt = resultTensor.getFormat ().getModeFormats ()[modeIndex];
1903
-
1904
- // Actual check for condition (4). If the current mode is full, no
1905
- // optimizations necessary
1899
+ // (3) Level of result is sparse
1906
1900
if (varFmt.isFull ()) {
1907
1901
return std::make_pair (false , false );
1908
1902
}
0 commit comments