Skip to content

Commit

Permalink
add centroid metric, fix calculating distance from hyperplane for mul…
Browse files Browse the repository at this point in the history
…ticlass classification
  • Loading branch information
balins committed Aug 29, 2024
1 parent 9d2c12e commit 0b9ac85
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
26 changes: 21 additions & 5 deletions fsvm/_fuzzy_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class FuzzySVC(ClassifierMixin, BaseEstimator):
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
"""

Expand All @@ -238,6 +238,7 @@ class FuzzySVC(ClassifierMixin, BaseEstimator):
StrOptions({"centroid", "hyperplane"}),
callable,
],
"centroid_metric": [StrOptions({"euclidean", "manhattan"})],
"membership_decay": [StrOptions({"exponential", "linear"}), callable],
"beta": [Interval(Real, 0.0, 1.0, closed="both")],
"balanced": ["boolean"],
Expand All @@ -251,8 +252,9 @@ def __init__(
self,
*,
distance_metric="centroid",
centroid_metric="euclidean",
membership_decay="exponential",
beta=0.5,
beta=0.1,
balanced=True,
C=1.0,
kernel="rbf",
Expand All @@ -270,6 +272,7 @@ def __init__(
random_state=None,
):
self.distance_metric = distance_metric
self.centroid_metric = centroid_metric
self.membership_decay = membership_decay
self.beta = beta
self.balanced = balanced
Expand Down Expand Up @@ -349,8 +352,21 @@ def fit(self, X, y):
centroids = _NearestCentroid().fit(X, y).centroids_
self.distance_ = np.linalg.norm(X - centroids[y_], axis=1)
elif self.distance_metric == "hyperplane":
svc = _SVC(**svc_args).fit(X, y)
self.distance_ = np.abs(svc.decision_function(X))
hyperplane_svc_args = {**svc_args, "decision_function_shape": "ovr"}
svc = _SVC(**hyperplane_svc_args).fit(X, y)
decision_function_output = svc.decision_function(X)

# For multiclass, extract the distances corresponding
# to the true class labels
if decision_function_output.ndim > 1:
y_indices = np.array(
[svc.classes_.tolist().index(class_label) for class_label in y]
)
decision_function_output = decision_function_output[
np.arange(len(y)), y_indices
]

self.distance_ = np.abs(decision_function_output)
elif callable(self.distance_metric):
self.distance_ = self.distance_metric(X)

Expand Down Expand Up @@ -416,7 +432,7 @@ def __calculate_membership_degree(self):
membership = 2 / (1 + np.exp(self.beta * self.distance_))
elif self.membership_decay == "linear":
max_distance = np.amax(self.distance_)
delta = 1e-6
delta = 1e-9
membership = 1 - (self.distance_ / (max_distance + delta))
elif callable(self.membership_decay):
membership = self.membership_decay(self.distance_)
Expand Down
2 changes: 1 addition & 1 deletion fsvm/tests/test_fuzzy_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_fuzzy_svc(data):
clf = FuzzySVC()
assert clf.distance_metric == "centroid"
assert clf.membership_decay == "exponential"
assert clf.beta == 0.5
assert clf.beta == 0.1
assert clf.balanced is True

clf.fit(X, y)
Expand Down

0 comments on commit 0b9ac85

Please sign in to comment.