Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Commit

Permalink
Merge branch 'reinvent.3.1'
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoJeff committed Nov 15, 2021
2 parents 7ec8d30 + c8355fd commit e2463b8
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def save_checkpoint(self, step, scaffold_filter, agent):
actual_step = step + 1
if self._log_config.logging_frequency > 0 and actual_step % self._log_config.logging_frequency == 0:
self.save_diversity_memory(scaffold_filter)
agent.save_to_file(os.path.join(self._log_config.result_folder, f'Agent.{actual_step}.ckpt'))
agent.save(os.path.join(self._log_config.result_folder, f'Agent.{actual_step}.ckpt'))

@abstractmethod
def save_final_state(self, agent, scaffold_filter):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ def __init__(self, strategy_configuration: LibInventScoringStrategyConfiguration
self.reaction_filter = ReactionFilter(strategy_configuration.reaction_filter)

def evaluate(self, sampled_sequences: List[SampledSequencesDTO], step) -> FinalSummary:
score_summary = self._apply_scoring_function(sampled_sequences)
score_summary = self._apply_scoring_function(sampled_sequences, step)

score_summary.total_score = self.diversity_filter.update_score(score_summary, sampled_sequences, step)
return score_summary

def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO]) -> FinalSummary:
def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO], step:int) -> FinalSummary:
molecules = self._join_scaffolds_and_decorations(sampled_sequences)
smiles = [self._conversion.mol_to_smiles(molecule) if molecule else "INVALID" for molecule in molecules]
final_score: FinalSummary = self.scoring_function.get_final_score(smiles)
final_score: FinalSummary = self.scoring_function.get_final_score_for_step(smiles, step)
final_score = self._apply_reaction_filters(molecules, final_score)
return final_score

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ def __init__(self, strategy_configuration: ScoringStrategyConfiguration, diversi
super().__init__(strategy_configuration, diversity_filter, logger)

def evaluate(self, sampled_sequences: List[SampledSequencesDTO], step) -> FinalSummary:
score_summary = self._apply_scoring_function(sampled_sequences)
score_summary = self._apply_scoring_function(sampled_sequences, step)
score_summary = self._clean_scored_smiles(score_summary)
score_summary.total_score = self.diversity_filter.update_score(score_summary, sampled_sequences, step)
return score_summary

def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO]) -> FinalSummary:
def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO], step) -> FinalSummary:
molecules = self._join_linker_and_warheads(sampled_sequences, keep_labels=True)
smiles = []
for idx, molecule in enumerate(molecules):
Expand All @@ -33,7 +33,7 @@ def _apply_scoring_function(self, sampled_sequences: List[SampledSequencesDTO])
f'\n\toutput: {sampled_sequences[idx].output}\n')
finally:
smiles.append(smiles_str)
final_score: FinalSummary = self.scoring_function.get_final_score(smiles)
final_score: FinalSummary = self.scoring_function.get_final_score_for_step(smiles, step)
return final_score

def _join_linker_and_warheads(self, sampled_sequences: List[SampledSequencesDTO], keep_labels=False):
Expand Down
3 changes: 0 additions & 3 deletions running_modes/sampling/logging/local_sampling_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ def log_message(self, message: str):
def timestep_report(self, smiles: [], likelihoods: np.array):
self._log_timestep(smiles, likelihoods)

def __del__(self):
self._summary_writer.close()

def _log_timestep(self, smiles: np.array, likelihoods: np.array):
valid_smiles_fraction = fraction_valid_smiles(smiles)
fraction_unique_entries = self._get_unique_entires_fraction(likelihoods)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ def __init__(self, configuration: GeneralConfigurationEnvelope):
super().__init__(configuration)
self._summary_writer = SummaryWriter(log_dir=self._log_config.logging_path)

def __del__(self):
self._summary_writer.close()

def log_out_input_configuration(self):
file = os.path.join(self._log_config.logging_path, "input.json")
jsonstr = json.dumps(self._configuration, default=lambda x: x.__dict__, sort_keys=True, indent=4,
Expand Down
12 changes: 11 additions & 1 deletion running_modes/validation/logging/remote_validation_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

from running_modes.configurations.general_configuration_envelope import GeneralConfigurationEnvelope
from running_modes.validation.logging.base_validation_logger import BaseValidationLogger
import running_modes.utils.configuration as utils_log

from running_modes.configurations.logging import get_remote_logging_auth_token


class RemoteValidationLogger(BaseValidationLogger):
def __init__(self, configuration: GeneralConfigurationEnvelope):
super().__init__(configuration)
self._is_dev = utils_log._is_development_environment()

def log_message(self, message: str):
data = {"valid": self.model_is_valid, "message": message}
Expand All @@ -15,13 +19,19 @@ def log_message(self, message: str):
def _notify_server(self, data, to_address):
"""This is called every time we are posting data to server"""
try:
headers = {
'Accept': 'application/json', 'Content-Type': 'application/json',
'Authorization': get_remote_logging_auth_token()
}
self._common_logger.warning(f"posting to {to_address}")
response = requests.post(to_address, data=data)
response = requests.post(to_address, json=data, headers=headers)

if response.status_code == requests.codes.ok:
self._common_logger.info(f"SUCCESS: {response.status_code}")
self._common_logger.info(response.content)
else:
self._common_logger.info(f"PROBLEM: {response.status_code}")
self._common_logger.exception(data, exc_info=False)
except Exception as e:
self._common_logger.exception("Exception occurred", exc_info=True)
self._common_logger.exception(f"Attempted posting the following data:")
Expand Down

0 comments on commit e2463b8

Please sign in to comment.