Skip to content

Commit 8af7f77

Browse files
committed
fix more tests
1 parent 9ad56ee commit 8af7f77

3 files changed

Lines changed: 10 additions & 7 deletions

File tree

onedal/svm/tests/test_nusvc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _test_libsvm_parameters(queue, array_constr, dtype):
3939
clf.dual_coef_, [[-0.04761905, -0.0952381, 0.0952381, 0.04761905]]
4040
)
4141
assert_array_equal(clf.support_, [0, 1, 3, 4])
42-
assert_array_equal(clf.support_vectors_, X[clf.support_])
42+
assert_array_equal(clf.support_vectors_, X[clf.support_.astype(int)])
4343
assert_array_equal(clf.intercept_, [0.0])
4444
assert_array_equal(clf.predict(X, queue=queue), y)
4545

@@ -70,7 +70,7 @@ def test_sample_weight(queue):
7070
y = np.array([1, 1, 1, 2, 2, 2])
7171

7272
clf = NuSVC(kernel="linear")
73-
clf.fit(X, y, sample_weight=[1] * 6, queue=queue)
73+
clf.fit(X, y, sample_weight=np.array([1] * 6), queue=queue)
7474
assert_array_almost_equal(clf.intercept_, [0.0])
7575

7676

onedal/svm/tests/test_svc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_sample_weight(queue):
8989
y = np.array([1, 1, 1, 2, 2, 2])
9090

9191
clf = SVC(kernel="linear")
92-
clf.fit(X, y, sample_weight=[1] * 6, queue=queue)
92+
clf.fit(X, y, sample_weight=np.array([1] * 6), queue=queue)
9393
assert_array_almost_equal(clf.intercept_, [0.0])
9494

9595

onedal/svm/tests/test_svr.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
2020
from sklearn import datasets
21+
from sklearn.metrics import r2_score
2122
from sklearn.metrics.pairwise import rbf_kernel
2223
from sklearn.svm import SVR as SklearnSVR
2324

@@ -176,10 +177,12 @@ def test_synth_linear_compare_with_sklearn(queue, C):
176177

177178
def _test_synth_poly_compare_with_sklearn(queue, params):
178179
x, y = datasets.make_regression(**synth_params)
179-
clf = SVR(kernel="poly", **params)
180+
gamma = 1.0 / (x.shape[1] * x.var())
181+
clf = SVR(kernel="poly", gamma=gamma, **params)
180182
clf.fit(x, y, queue=queue)
181-
result = clf.score(x, y, queue=queue)
183+
result = r2_score(y, clf.predict(x, queue=queue))
182184

185+
# gamma='scale' by default in sklearn
183186
clf = SklearnSVR(kernel="poly", **params)
184187
clf.fit(x, y)
185188
expected = clf.score(x, y)
@@ -193,8 +196,8 @@ def _test_synth_poly_compare_with_sklearn(queue, params):
193196
@pytest.mark.parametrize(
194197
"params",
195198
[
196-
{"degree": 2, "coef0": 0.1, "gamma": "scale", "C": 100},
197-
{"degree": 3, "coef0": 0.0, "gamma": "scale", "C": 1000},
199+
{"degree": 2, "coef0": 0.1, "C": 100},
200+
{"degree": 3, "coef0": 0.0, "C": 1000},
198201
],
199202
)
200203
def test_synth_poly_compare_with_sklearn(queue, params):

0 commit comments

Comments
 (0)