Skip to content

Commit

Permalink
fix loss_computer.get_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
Shirley Wu committed Apr 14, 2024
1 parent b3d1ab7 commit d817d3b
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions disc/run_epoch_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,14 @@ def run_epoch(
optimizer.step()

if is_training and (batch_idx+1) % log_every==0:
csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args))
csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args, is_training))
csv_logger.flush()
loss_computer.log_stats(logger, is_training)
loss_computer.reset_stats()

if (not is_training) or loss_computer.batch_count > 0:
model = model.cuda()
csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args))
csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args, is_training))
csv_logger.flush()
loss_computer.log_stats(logger, is_training)
if is_training:
Expand Down
4 changes: 2 additions & 2 deletions disc/run_epoch_disc.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def run_epoch_disc(
optimizer.step()

if is_training and (batch_idx + 1) % log_every==0:
csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args))
csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args, is_training))
csv_logger.flush()
loss_computer.log_stats(logger, is_training)
loss_computer.reset_stats()

if (not is_training) or loss_computer.batch_count > 0:
csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args))
csv_logger.log(epoch, batch_idx, loss_computer.get_stats(model, args, is_training))
csv_logger.flush()
loss_computer.log_stats(logger, is_training)
if is_training:
Expand Down
2 changes: 1 addition & 1 deletion disc/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def get_model_stats(self, model, args, stats_dict):
stats_dict['reg_loss'] = args.weight_decay / 2 * model_norm_sq.item()
return stats_dict

def get_stats(self, model=None, args=None):
def get_stats(self, model=None, args=None, is_training=True):
stats_dict = {}
accs = []

Expand Down
1 change: 1 addition & 0 deletions run_expt.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
seed=args.seed
)
if args.lisa_mix_up:
train_loader = {}
for i in range(train_data.n_groups):
idxes = np.where(train_data.get_group_array() == i)[0]
if len(idxes) == 0:
Expand Down
4 changes: 2 additions & 2 deletions scripts/waterbirds.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ python run_expt.py -s confounder -d CUB -t waterbird_complete95 -c forest2water2
# ERM
SEED=0
ROOT=./DISC
python run_expt.py -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.001 --batch_size 32 --weight_decay 0.0001 --model resnet50 --n_epochs 300 --root_dir $ROOT/data/cub --log_dir $ROOT/output/ --save_best --save_last --seed 1
python run_expt.py -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.001 --batch_size 32 --weight_decay 0.0001 --model resnet50 --n_epochs 300 --root_dir $ROOT/data/cub --log_dir $ROOT/output/ --save_best --save_last --seed $SEED


# ERM + aug
SEED=0
ROOT=./DISC
python run_expt.py -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.001 --batch_size 32 --weight_decay 0.0001 --model resnet50 --n_epochs 300 --root_dir $ROOT/data/cub --log_dir $ROOT/output/waterbirds/ERM_aug --augment_data --save_best --save_last
python run_expt.py -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.001 --batch_size 32 --weight_decay 0.0001 --model resnet50 --n_epochs 300 --root_dir $ROOT/data/cub --log_dir $ROOT/output/ --augment_data --save_best --save_last --seed $SEED

0 comments on commit d817d3b

Please sign in to comment.