Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made fixed dataset shuffle a non-default again #17

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions cpc/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def __init__(self,
phoneLabelsDict,
nSpeakers,
nProcessLoader=50,
MAX_SIZE_LOADED=4000000000):
MAX_SIZE_LOADED=4000000000,
keepSameSeedForDSshuffle=False,
newTorchaudio=False):
"""
Args:
- path (string): path to the training dataset
Expand All @@ -45,6 +47,8 @@ def __init__(self,
- MAX_SIZE_LOADED (int): target maximal size of the floating array
containing all loaded data.
"""
self.keepSameSeedForDSshuffle = keepSameSeedForDSshuffle
self.newTorchaudio = newTorchaudio
self.MAX_SIZE_LOADED = MAX_SIZE_LOADED
self.nProcessLoader = nProcessLoader
self.dbPath = Path(path)
Expand Down Expand Up @@ -91,15 +95,24 @@ def clear(self):
del self.seqLabel

def prepare(self):
randomstate = random.getstate()
random.seed(767543) # set seed only for batching so that it is random but always same for same dataset
# so that capturing captures data for same audio across runs if same dataset provided
if self.keepSameSeedForDSshuffle:
print("--> setting same seed for DS seqNames shuffling")
randomstate = random.getstate()
random.seed(767543) # set seed only for batching so that it is random but always same for same dataset
# so that capturing captures data for same audio across runs if same dataset provided
else:
print("--> using random seed for DS seqNames shuffling")
random.shuffle(self.seqNames)
random.setstate(randomstate) # restore random state so that other stuff changes with seed in args
if self.keepSameSeedForDSshuffle:
random.setstate(randomstate) # restore random state so that other stuff changes with seed in args
start_time = time.time()

print("Checking length...")
allLength = self.reload_pool.map(extractLength, self.seqNames)
if self.newTorchaudio:
mapFun = extractLengthNewTorchaudio
else:
mapFun = extractLength
allLength = self.reload_pool.map(mapFun, self.seqNames)

self.packageIndex, self.totSize = [], 0
start, packageSize = 0, 0
Expand Down Expand Up @@ -423,11 +436,17 @@ def __iter__(self):
return iter(self.batches)


def extractLength(couple):  # for old torchaudio
    """Return the length reported by torchaudio for one audio file.

    Variant for the old torchaudio API, where ``torchaudio.info`` returns an
    indexable result whose first element carries a ``length`` attribute.

    Args:
        couple: a 2-element pair ``(speaker, path)``; only the path is used.

    Returns:
        The ``length`` attribute of the first element returned by
        ``torchaudio.info`` (presumably the total sample count of the file;
        confirm against the installed torchaudio version).
    """
    _, audioPath = couple
    metadata = torchaudio.info(str(audioPath))
    signalInfo = metadata[0]
    return signalInfo.length

def extractLengthNewTorchaudio(couple):  # linux machines, new torchaudio 0.8.1+ for CUDA around >= 11
    """Return the length of one audio file using the newer torchaudio API.

    Variant for torchaudio 0.8.1+, where ``torchaudio.info`` returns a
    metadata object exposing ``num_frames`` and ``num_channels``; see
    https://pytorch.org/audio/stable/backend.html#torchaudio.backend.common.AudioMetaData

    Args:
        couple: a 2-element pair ``(speaker, path)``; only the path is used.

    Returns:
        ``num_frames * num_channels`` for the file (default 'sox' backend).
    """
    _, audioPath = couple
    metadata = torchaudio.info(str(audioPath))
    frameCount = metadata.num_frames
    channelCount = metadata.num_channels
    return frameCount * channelCount


def findAllSeqs(dirName,
extension='.flac',
Expand Down
36 changes: 26 additions & 10 deletions cpc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def run(trainDataset,
for epoch in range(startEpoch, nEpoch):

print(f"Starting epoch {epoch}")
sys.stdout.flush()
utils.cpu_stats()

trainLoader = trainDataset.getDataLoader(batchSize, samplingMode,
Expand Down Expand Up @@ -505,7 +506,9 @@ def main(args):
phoneLabels,
len(speakers),
nProcessLoader=args.n_process_loader,
MAX_SIZE_LOADED=args.max_size_loaded)
MAX_SIZE_LOADED=args.max_size_loaded,
keepSameSeedForDSshuffle=args.fixedDSshuffleSeed,
newTorchaudio=args.newTorchaudio)
print("Training dataset loaded")
print("")

Expand All @@ -515,7 +518,9 @@ def main(args):
seqVal,
phoneLabels,
len(speakers),
nProcessLoader=args.n_process_loader)
nProcessLoader=args.n_process_loader,
keepSameSeedForDSshuffle=args.fixedDSshuffleSeed,
newTorchaudio=args.newTorchaudio)
print("Validation dataset loaded")
print("")
else:
Expand All @@ -538,7 +543,9 @@ def main(args):
seqCapture,
phoneLabelsForCapture,
len(speakers),
nProcessLoader=args.n_process_loader)
nProcessLoader=args.n_process_loader,
keepSameSeedForDSshuffle=True,
newTorchaudio=args.newTorchaudio)
print("Capture dataset loaded")
print("")

Expand Down Expand Up @@ -713,16 +720,20 @@ def constructSpeakerCriterionAndOptimizer():

return speaker_criterion, speaker_optimizer

linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
phoneLabelsData, len(speakers))
linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
phoneLabelsData, len(speakers))
# loading this second time kills RAM
# linsep_db_train = AudioBatchData(args.pathDB, args.sizeWindow, seqTrain,
# phoneLabelsData, len(speakers), keepSameSeedForDSshuffle=args.fixedDSshuffleSeed,
# newTorchaudio=args.newTorchaudio)
# linsep_db_val = AudioBatchData(args.pathDB, args.sizeWindow, seqVal,
# phoneLabelsData, len(speakers), keepSameSeedForDSshuffle=args.fixedDSshuffleSeed,
# newTorchaudio=args.newTorchaudio)

linsep_train_loader = linsep_db_train.getDataLoader(linsep_batch_size, "uniform", True,
linsep_train_loader = trainDataset.getDataLoader(linsep_batch_size, "uniform", True,
numWorkers=0)

linsep_val_loader = linsep_db_val.getDataLoader(linsep_batch_size, 'sequential', False,
print("linsep_train_loader ready")
linsep_val_loader = valDataset.getDataLoader(linsep_batch_size, 'sequential', False,
numWorkers=0)
print("linsep_val_loader ready")

def runLinsepClassificationTraining(numOfEpoch, cpcMdl, cpcStateEpoch):
log_path_for_epoch = os.path.join(args.linsep_logs_dir, str(numOfEpoch))
Expand Down Expand Up @@ -841,6 +852,11 @@ def parseArgs(argv):
group_db.add_argument('--pathVal', type=str, default=None,
help='Path to a .txt file containing the list of the '
'validation sequences.')
group_db.add_argument('--fixedDSshuffleSeed', action='store_true',
help="if set, will always shuffle train & val DS same way (with same seed); "
"if not set, will use randomized seed used for other stuff also for this")
group_db.add_argument('--newTorchaudio', action='store_true',
help="if set, use newer audio data loading API compatible with newer torchaudio (0.8.1+)")
# stuff below for capturing data
group_db.add_argument('--onlyCapture', action='store_true',
help='Only capture data from learned model for one epoch, ignore training; '
Expand Down