Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flipped sign in BCE loss and add hinge-loss option. #7

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions hoplite/agile/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,28 @@ def bce_loss(
y_true = tf.cast(y_true, dtype=logits.dtype)
log_p = tf.math.log_sigmoid(logits)
log_not_p = tf.math.log_sigmoid(-logits)
raw_bce = -y_true * log_p + (1.0 - y_true) * log_not_p
# optax sigmoid_binary_cross_entropy:
# -labels * log_p - (1.0 - labels) * log_not_p
raw_bce = -y_true * log_p - (1.0 - y_true) * log_not_p
is_labeled_mask = tf.cast(is_labeled_mask, dtype=logits.dtype)
weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask
return tf.reduce_mean(raw_bce * weights)


def hinge_loss(
y_true: tf.Tensor,
logits: tf.Tensor,
is_labeled_mask: tf.Tensor,
weak_neg_weight: float,
) -> tf.Tensor:
"""Weighted SVM hinge loss."""
# Convert multihot to +/- 1 labels.
y_true = 2 * y_true - 1
weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask
raw_hinge_loss = tf.maximum(0, 1 - y_true * logits)
return tf.reduce_mean(raw_hinge_loss * weights)


def infer(params, embeddings: np.ndarray):
"""Apply the model to embeddings."""
return np.dot(embeddings, params['beta']) + params['beta_bias']
Expand Down Expand Up @@ -105,19 +121,26 @@ def train_linear_classifier(
learning_rate: float,
weak_neg_weight: float,
num_train_steps: int,
loss: str = 'bce',
):
"""Train a linear classifier."""
embedding_dim = data_manager.db.embedding_dimension()
num_classes = len(data_manager.get_target_labels())
lin_model = get_linear_model(embedding_dim, num_classes)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
lin_model.compile(optimizer=optimizer, loss='binary_crossentropy')
if loss == 'hinge':
loss_fn = hinge_loss
elif loss == 'bce':
loss_fn = bce_loss
else:
raise ValueError(f'Unknown loss: {loss}')

@tf.function
def train_step(y_true, embeddings, is_labeled_mask):
with tf.GradientTape() as tape:
logits = lin_model(embeddings, training=True)
loss = bce_loss(y_true, logits, is_labeled_mask, weak_neg_weight)
loss = loss_fn(y_true, logits, is_labeled_mask, weak_neg_weight)
loss = tf.reduce_mean(loss)
grads = tape.gradient(loss, lin_model.trainable_variables)
optimizer.apply_gradients(zip(grads, lin_model.trainable_variables))
Expand Down
Loading