Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onedal/neighbors/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def predict_proba(self, X, queue=None):
_y = self._y.reshape((-1, 1))
classes_ = [self.classes_]

n_queries = _num_samples(X)
n_queries = _num_samples(X if X is not None else self._fit_X)

weights = self._get_weights(neigh_dist, self.weights)
if weights is None:
Expand Down
13 changes: 13 additions & 0 deletions sklearnex/neighbors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,19 @@
patching_status = PatchingConditionsChain(
f"sklearn.neighbors.{class_name}.{method_name}"
)
# TODO: with verbosity enabled, here it would emit a log saying that it fell

Check notice on line 154 in sklearnex/neighbors/common.py

View check run for this annotation

codefactor.io / CodeFactor

sklearnex/neighbors/common.py#L154

Unresolved comment '# TODO: with verbosity enabled, here it would emit a log saying that it fell'. (C100)
# back to sklearn, but internally, sklearn will end up calling 'kneighbors'
# which is overridden in the sklearnex classes, thus it will end up calling
# oneDAL in the end, but the log will say otherwise. Find a way to make the
# log consistent with what happens in practice.
patching_status.and_conditions(
[
(
not (data[0] is None and method_name in ["predict", "score"]),
"Predictions on 'None' data are handled by internal sklearn methods.",
)
]
)
if not patching_status.and_condition(
"radius" not in method_name, "RadiusNeighbors not implemented in sklearnex"
):
Expand Down
9 changes: 6 additions & 3 deletions sklearnex/neighbors/knn_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def fit(self, X, y):
@wrap_output_data
def predict(self, X):
check_is_fitted(self)
check_feature_names(self, X, reset=False)
if X is not None:
check_feature_names(self, X, reset=False)
return dispatch(
self,
"predict",
Expand All @@ -93,7 +94,8 @@ def predict(self, X):
@wrap_output_data
def predict_proba(self, X):
check_is_fitted(self)
check_feature_names(self, X, reset=False)
if X is not None:
check_feature_names(self, X, reset=False)
return dispatch(
self,
"predict_proba",
Expand All @@ -107,7 +109,8 @@ def predict_proba(self, X):
@wrap_output_data
def score(self, X, y, sample_weight=None):
check_is_fitted(self)
check_feature_names(self, X, reset=False)
if X is not None:
check_feature_names(self, X, reset=False)
return dispatch(
self,
"score",
Expand Down
Loading