
Code base refactor - Discussion #468

Open
lhjohn opened this issue Jul 2, 2024 · 2 comments

lhjohn commented Jul 2, 2024

A place to discuss the refactor of PLP, get an overview of the current code base, and weigh options for a prospective one. Currently the project is organized by file and is function-based. Below is a "class" diagram of all files and functions in the R folder.

classDiagram
class AdditionalCovariates {
createCohortCovariateSettings
getCohortCovariateData
}
class AndromedaHelperFunctions {
batchRestrict
calculatePrevs
limitCovariatesToPopulation
}
class CalibrationSummary {
getCalibrationSummary
getCalibrationSummary_binary
getCalibrationSummary_survival
}
class CovariateSummary {
aggregateCovariateSummaries
covariateSummary
covariateSummarySubset
createCovariateSubsets
getCovariatesForGroup
}
class CyclopsModels {
createCyclopsModel
filterCovariateIds
fitCyclopsModel
getCV
getVariableImportance
modelTypeToCyclopsModelType
predictCyclops
predictCyclopsType
reparamTransferCoefs
}
class CyclopsSettings {
setCoxModel
setIterativeHardThresholding
setLassoLogisticRegression
}
class DatabaseMigration {
getDataMigrator
migrateDataModel
}
class DataSplitting {
checkInputsSplit
createDefaultSplitSetting
dataSummary
randomSplitter
splitData
subjectSplitter
timeSplitter
}
class DemographicSummary {
getDemographicSummary
getDemographicSummary_binary
getDemographicSummary_survival
}
class DiagnosePlp {
cos_sim
diagnoseMultiplePlp
diagnosePlp
getDiagnostic
getMaxEndDaysFromCovariates
getOutcomeSummary
probastDesign
probastOutcome
probastParticipants
probastPredictors
}
class EvaluatePlp {
evaluatePlp
modelBasedConcordance
}
class EvaluationSummary {
aucWithCi
aucWithoutCi
averagePrecision
brierScore
calculateEStatisticsBinary
calibrationInLarge
calibrationInLargeIntercept
calibrationLine
calibrationWeak
computeAuc
getEvaluationStatistics
getEvaluationStatistics_binary
getEvaluationStatistics_survival
ici
}
class ExternalValidatePlp {
createValidationDesign
createValidationSettings
externalValidateDbPlp
externalValidatePlp
validateExternal
validateModel
}
class ExtractData {
createDatabaseDetails
createRestrictPlpDataSettings
getPlpData
print.plpData
print.summary.plpData
summary.plpData
}
class FeatureEngineering {
calculateStratifiedMeans
createFeatureEngineeringSettings
createRandomForestFeatureSelection
createSplineSettings
createStratifiedImputationSettings
createUnivariateFeatureSelection
featureEngineer
imputeMissingMeans
randomForestFeatureSelection
splineCovariates
splineMap
stratifiedImputeCovariates
univariateFeatureSelection
}
class FeatureImportance {
permute
permutePerf
pfi
}
class Fit {
fitPlp
}
class Formatting {
checkRam
MapIds
toSparseM
}
class GradientBoostingMachine {
fitXgboost
predictXgboost
setGradientBoostingMachine
varImpXgboost
}
class HelperFunctions {
configurePython
createTempModelLoc
cut2
ensure_installed
getOs
is_installed
listAppend
nrow
nrow.default
nrow.tbl
removeInvalidString
setPythonEnvironment
}
class ImportFromCsv {
extractCohortDefinitionsCSV
extractDatabaseListCSV
extractDiagnosticFromCsv
extractObjectFromCsv
getModelDesignCsv
getModelDesignSettingTable
getPerformanceEvaluationCsv
getTableNamesPlp
insertCsvToDatabase
}
class KNN {
fitKNN
predictKnn
setKNN
}
class LearningCurve {
createLearningCurve
getTrainFractions
lcWrapper
learningCurveHelper
plotLearningCurve
}
class LightGBM {
fitLightGBM
predictLightGBM
setLightGBM
varImpLightGBM
}
class Logging {
checkFileExists
closeLog
createLog
createLogSettings
}
class ParamChecks {
checkBoolean
checkHigher
checkHigherEqual
checkInStringVector
checkIsClass
checkLower
checkLowerEqual
checkNotNull
}
class PatientLevelPrediction

class Plotting {
outcomeSurvivalPlot
plotDemographicSummary
plotF1Measure
plotGeneralizability
plotPlp
plotPrecisionRecall
plotPredictedPDF
plotPredictionDistribution
plotPreferencePDF
plotSmoothCalibration
plotSmoothCalibrationLoess
plotSmoothCalibrationRcs
plotSparseCalibration
plotSparseCalibration2
plotSparseRoc
plotVariableScatterplot
}
class PopulationSettings {
createStudyPopulation
createStudyPopulationSettings
getCounts
getCounts2
}
class Predict {
applyFeatureengineering
applyTidyCovariateData
predictPlp
}
class PredictionDistribution {
getPredictionDistribution
getPredictionDistribution_binary
getPredictionDistribution_survival
}
class PreprocessingData {
createPreprocessSettings
preprocessData
}
class RClassifier {
applyCrossValidationInR
fitRclassifier
}
class Recalibration {
inverseLog
logFunct
recalibratePlp
recalibratePlpRefit
recalibrationInTheLarge
weakRecalibration
}
class RunMultiplePlp {
convertToJson
createModelDesign
loadPlpAnalysesJson
runMultiplePlp
savePlpAnalysesJson
validateMultiplePlp
}
class RunPlp {
runPlp
}
class RunPlpHelpers {
checkInputs
createDefaultExecuteSettings
createExecuteSettings
printHeader
}
class Sampling {
createSampleSettings
overSampleData
sameData
sampleData
underSampleData
}
class SaveLoadPlp {
applyMinCellCount
extractDatabaseToCsv
getPlpSensitiveColumns
loadPlpData
loadPlpModel
loadPlpResult
loadPlpShareable
loadPrediction
removeCellCount
removeList
saveModelPart
savePlpData
savePlpModel
savePlpResult
savePlpShareable
savePrediction
}
class Simulation {
simulatePlpData
}
class SklearnClassifier {
checkPySettings
computeGridPerformance
fitPythonModel
fitSklearn
gridCvPython
predictPythonSklearn
predictValues
}
class SklearnClassifierHelpers {
listCartesian
}
class SklearnClassifierSettings {
AdaBoostClassifierInputs
DecisionTreeClassifierInputs
GaussianNBInputs
MLPClassifierInputs
RandomForestClassifierInputs
setAdaBoost
setDecisionTree
setMLP
setNaiveBayes
setRandomForest
setSVM
SVCInputs
}
class SklearnToJson {
deSerializeAdaboost
deSerializeCsrMatrix
deSerializeDecisionTree
deSerializeMlp
deSerializeNaiveBayes
deSerializeRandomForest
deSerializeSVM
deSerializeTree
serializeAdaboost
serializeCsrMatrix
serializeDecisionTree
serializeMLP
serializeNaiveBayes
serializeRandomForest
serializeSVM
serializeTree
sklearnFromJson
sklearnToJson
}
class ThresholdSummary {
accuracy
checkToByTwoTableInputs
diagnosticOddsRatio
f1Score
falseDiscoveryRate
falseNegativeRate
falseOmissionRate
falsePositiveRate
getThresholdSummary
getThresholdSummary_binary
getThresholdSummary_survival
negativeLikelihoodRatio
negativePredictiveValue
positiveLikelihoodRatio
positivePredictiveValue
sensitivity
specificity
stdca
}
class uploadToDatabase {
addCohort
addDatabase
addModel
addMultipleRunPlpToDatabase
addRunPlpToDatabase
checkJson
checkTable
cleanNum
createDatabaseList
createDatabaseSchemaSettings
createPlpResultTables
deleteTables
enc
getCohortDef
getPlpResultTables
getResultLocations
insertModelInDatabase
insertResultsToSqlite
insertRunPlpToSqlite
}
class uploadToDatabaseDiagnostics {
addDiagnosePlpToDatabase
addDiagnostic
addMultipleDiagnosePlpToDatabase
addResultTable
insertDiagnosisToDatabase
}
class uploadToDatabaseModelDesign {
addCovariateSetting
addFESetting
addModelDesign
addModelSetting
addPlpDataSetting
addPopulationSetting
addSampleSetting
addSplitSettings
addTar
addTidySetting
insertModelDesignInDatabase
insertModelDesignSettings
orderJson
}
class uploadToDatabasePerformance {
addAttrition
addCalibrationSummary
addCovariateSummary
addDemographicSummary
addEvaluation
addEvaluationStatistics
addPerformance
addPredictionDistribution
addThresholdSummary
checkResultExists
getColumnNames
insertPerformanceInDatabase
}
class ViewShinyPlp {
viewDatabaseResultPlp
viewMultiplePlp
viewPlp
viewPlps
}

Related resources:
Draft PR for new model API: #462


lhjohn commented Jul 2, 2024

To design a PLP system inspired by mlr3, we can organize PLP components into mlr3's building blocks:

  1. Learner: Corresponds to model type in PLP.
  2. Task: Could correspond to study population and other cohort parameters.
  3. Resample: Could correspond to data splitting in PLP.
  4. Measure: Corresponds to evaluation functions in PLP.
  5. Prediction: Could correspond to the model object in PLP, used for internal and external validation.
  6. Data: Although not considered a building block in mlr3, it is useful to represent the data object; it exists as the DataBackend class in mlr3.
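As a rough illustration of what this mapping could look like in code, here is a minimal R6 sketch. All class and method names here are hypothetical, not part of PLP or mlr3; the commented-out calls to fitPlp()/predictPlp() only indicate where existing PLP functions could plug in.

```r
library(R6)

# Hypothetical "Task": wraps the study population / cohort parameters.
PlpTask <- R6Class("PlpTask",
  public = list(
    population = NULL,
    initialize = function(population) {
      self$population <- population
    }
  )
)

# Hypothetical "Learner": wraps a model type and its settings.
PlpLearner <- R6Class("PlpLearner",
  public = list(
    settings = NULL,
    model = NULL,
    initialize = function(settings) {
      self$settings <- settings
    },
    train = function(task) {
      # self$model <- fitPlp(task$population, self$settings)
      invisible(self)
    },
    predict = function(task) {
      # predictPlp(self$model, task$population)  # returns a Prediction
    }
  )
)

# Hypothetical "Measure": wraps an evaluation function, e.g. AUC.
PlpMeasure <- R6Class("PlpMeasure",
  public = list(
    fun = NULL,
    initialize = function(fun) {
      self$fun <- fun
    },
    score = function(prediction) {
      self$fun(prediction)
    }
  )
)
```

The appeal of this shape is that Resample and external validation become orthogonal: any Learner can be trained on any Task split, and any Measure can score the resulting Prediction.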

Forcing our existing PLP functions into mlr3 building blocks could look something like this:

PatientLevelPrediction:

classDiagram
class PatientLevelPrediction

PatientLevelPrediction --> HelperFunctions
PatientLevelPrediction --> Fit
PatientLevelPrediction --> Logging
PatientLevelPrediction --> ParamChecks
PatientLevelPrediction --> DatabaseMigration
PatientLevelPrediction --> RunMultiplePlp
PatientLevelPrediction --> RunPlp
PatientLevelPrediction --> RunPlpHelpers
PatientLevelPrediction --> SaveLoadPlp
PatientLevelPrediction --> LearningCurve

class HelperFunctions {
  +configurePython()
  +createTempModelLoc()
  +cut2()
  +ensure_installed()
  +getOs()
  +is_installed()
  +listAppend()
  +nrow()
  +nrow.default()
  +nrow.tbl()
  +removeInvalidString()
  +setPythonEnvironment()
}
class Fit {
  +fitPlp()
}
class Logging {
  +checkFileExists()
  +closeLog()
  +createLog()
  +createLogSettings()
}
class ParamChecks {
  +checkBoolean()
  +checkHigher()
  +checkHigherEqual()
  +checkInStringVector()
  +checkIsClass()
  +checkLower()
  +checkLowerEqual()
  +checkNotNull()
}
class DatabaseMigration {
  +getDataMigrator()
  +migrateDataModel()
}
class RunMultiplePlp {
  +convertToJson()
  +createModelDesign()
  +loadPlpAnalysesJson()
  +runMultiplePlp()
  +savePlpAnalysesJson()
  +validateMultiplePlp()
}
class RunPlp {
  +runPlp()
}
class RunPlpHelpers {
  +checkInputs()
  +createDefaultExecuteSettings()
  +createExecuteSettings()
  +printHeader()
}
class SaveLoadPlp {
  +applyMinCellCount()
  +extractDatabaseToCsv()
  +getPlpSensitiveColumns()
  +loadPlpData()
  +loadPlpModel()
  +loadPlpResult()
  +loadPlpShareable()
  +loadPrediction()
  +removeCellCount()
  +removeList()
  +saveModelPart()
  +savePlpData()
  +savePlpModel()
  +savePlpResult()
  +savePlpShareable()
  +savePrediction()
}
class LearningCurve {
  +createLearningCurve()
  +getTrainFractions()
  +lcWrapper()
  +learningCurveHelper()
  +plotLearningCurve()
}

Data:

classDiagram
class Data

Data --> ExtractData
Data --> Simulation
Data --> PreprocessingData
Data --> FeatureEngineering
Data --> FeatureImportance
Data --> Formatting
Data --> AdditionalCovariates
Data --> AndromedaHelperFunctions

class ExtractData {
  +createDatabaseDetails()
  +createRestrictPlpDataSettings()
  +getPlpData()
  +print.plpData()
  +print.summary.plpData()
  +summary.plpData()
}
class Simulation {
  +simulatePlpData()
}
class PreprocessingData {
  +createPreprocessSettings()
  +preprocessData()
}
class FeatureEngineering {
  +calculateStratifiedMeans()
  +createFeatureEngineeringSettings()
  +createRandomForestFeatureSelection()
  +createSplineSettings()
  +createStratifiedImputationSettings()
  +createUnivariateFeatureSelection()
  +featureEngineer()
  +imputeMissingMeans()
  +randomForestFeatureSelection()
  +splineCovariates()
  +splineMap()
  +stratifiedImputeCovariates()
  +univariateFeatureSelection()
}
class FeatureImportance {
  +permute()
  +permutePerf()
  +pfi()
}
class Formatting {
  +checkRam()
  +MapIds()
  +toSparseM()
}
class AdditionalCovariates {
  +createCohortCovariateSettings()
  +getCohortCovariateData()
}
class AndromedaHelperFunctions {
  +batchRestrict()
  +calculatePrevs()
  +limitCovariatesToPopulation()
}

Resample:

classDiagram
class Resample

class Sampling {
  +createSampleSettings()
  +overSampleData()
  +sameData()
  +sampleData()
  +underSampleData()
}

class DataSplitting {
  +checkInputsSplit()
  +createDefaultSplitSetting()
  +dataSummary()
  +randomSplitter()
  +splitData()
  +subjectSplitter()
  +timeSplitter()
}

Resample --> Sampling
Resample --> DataSplitting

Task:

classDiagram
class Task

Task --> PopulationSettings
Task --> DiagnosePlp

class PopulationSettings {
  +createStudyPopulation()
  +createStudyPopulationSettings()
  +getCounts()
  +getCounts2()
}
class DiagnosePlp {
  +cos_sim()
  +diagnoseMultiplePlp()
  +diagnosePlp()
  +getDiagnostic()
  +getMaxEndDaysFromCovariates()
  +getOutcomeSummary()
  +probastDesign()
  +probastOutcome()
  +probastParticipants()
  +probastPredictors()
}

Learner:

classDiagram
class Learner

Learner --> SklearnToJson
Learner --> SklearnClassifierSettings
Learner --> SklearnClassifierHelpers
Learner --> SklearnClassifier
Learner --> RClassifier
Learner --> KNN
Learner --> LightGBM
Learner --> GradientBoostingMachine
Learner --> CyclopsModels
Learner --> CyclopsSettings

class SklearnToJson {
  +deSerializeAdaboost()
  +deSerializeCsrMatrix()
  +deSerializeDecisionTree()
  +deSerializeMlp()
  +deSerializeNaiveBayes()
  +deSerializeRandomForest()
  +deSerializeSVM()
  +deSerializeTree()
  +serializeAdaboost()
  +serializeCsrMatrix()
  +serializeDecisionTree()
  +serializeMLP()
  +serializeNaiveBayes()
  +serializeRandomForest()
  +serializeSVM()
  +serializeTree()
  +sklearnFromJson()
  +sklearnToJson()
}
class SklearnClassifierSettings {
  +AdaBoostClassifierInputs()
  +DecisionTreeClassifierInputs()
  +GaussianNBInputs()
  +MLPClassifierInputs()
  +RandomForestClassifierInputs()
  +setAdaBoost()
  +setDecisionTree()
  +setMLP()
  +setNaiveBayes()
  +setRandomForest()
  +setSVM()
  +SVCInputs()
}
class SklearnClassifierHelpers {
  +listCartesian()
}
class SklearnClassifier {
  +checkPySettings()
  +computeGridPerformance()
  +fitPythonModel()
  +fitSklearn()
  +gridCvPython()
  +predictPythonSklearn()
  +predictValues()
}
class RClassifier {
  +applyCrossValidationInR()
  +fitRclassifier()
}
class KNN {
  +fitKNN()
  +predictKnn()
  +setKNN()
}
class LightGBM {
  +fitLightGBM()
  +predictLightGBM()
  +setLightGBM()
  +varImpLightGBM()
}
class GradientBoostingMachine {
  +fitXgboost()
  +predictXgboost()
  +setGradientBoostingMachine()
  +varImpXgboost()
}
class CyclopsModels {
  +createCyclopsModel()
  +filterCovariateIds()
  +fitCyclopsModel()
  +getCV()
  +getVariableImportance()
  +modelTypeToCyclopsModelType()
  +predictCyclops()
  +predictCyclopsType()
  +reparamTransferCoefs()
}
class CyclopsSettings {
  +setCoxModel()
  +setIterativeHardThresholding()
  +setLassoLogisticRegression()
}

Measure:

classDiagram
class Measure

Measure --> ViewShinyPlp
Measure --> uploadToDatabasePerformance
Measure --> uploadToDatabase
Measure --> uploadToDatabaseDiagnostics
Measure --> uploadToDatabaseModelDesign
Measure --> ThresholdSummary
Measure --> PredictionDistribution
Measure --> Plotting
Measure --> CovariateSummary
Measure --> EvaluatePlp
Measure --> EvaluationSummary
Measure --> DemographicSummary
Measure --> CalibrationSummary
Measure --> ImportFromCsv

class ViewShinyPlp {
  +viewDatabaseResultPlp()
  +viewMultiplePlp()
  +viewPlp()
  +viewPlps()
}
class uploadToDatabasePerformance {
  +addAttrition()
  +addCalibrationSummary()
  +addCovariateSummary()
  +addDemographicSummary()
  +addEvaluation()
  +addEvaluationStatistics()
  +addPerformance()
  +addPredictionDistribution()
  +addThresholdSummary()
  +checkResultExists()
  +getColumnNames()
  +insertPerformanceInDatabase()
}
class uploadToDatabase {
  +addCohort()
  +addDatabase()
  +addModel()
  +addMultipleRunPlpToDatabase()
  +addRunPlpToDatabase()
  +checkJson()
  +checkTable()
  +cleanNum()
  +createDatabaseList()
  +createDatabaseSchemaSettings()
  +createPlpResultTables()
  +deleteTables()
  +enc()
  +getCohortDef()
  +getPlpResultTables()
  +getResultLocations()
  +insertModelInDatabase()
  +insertResultsToSqlite()
  +insertRunPlpToSqlite()
}
class uploadToDatabaseDiagnostics {
  +addDiagnosePlpToDatabase()
  +addDiagnostic()
  +addMultipleDiagnosePlpToDatabase()
  +addResultTable()
  +insertDiagnosisToDatabase()
}
class uploadToDatabaseModelDesign {
  +addCovariateSetting()
  +addFESetting()
  +addModelDesign()
  +addModelSetting()
  +addPlpDataSetting()
  +addPopulationSetting()
  +addSampleSetting()
  +addSplitSettings()
  +addTar()
  +addTidySetting()
  +insertModelDesignInDatabase()
  +insertModelDesignSettings()
  +orderJson()
}
class ThresholdSummary {
  +accuracy()
  +checkToByTwoTableInputs()
  +diagnosticOddsRatio()
  +f1Score()
  +falseDiscoveryRate()
  +falseNegativeRate()
  +falseOmissionRate()
  +falsePositiveRate()
  +getThresholdSummary()
  +getThresholdSummary_binary()
  +getThresholdSummary_survival()
  +negativeLikelihoodRatio()
  +negativePredictiveValue()
  +positiveLikelihoodRatio()
  +positivePredictiveValue()
  +sensitivity()
  +specificity()
  +stdca()
}
class PredictionDistribution {
  +getPredictionDistribution()
  +getPredictionDistribution_binary()
  +getPredictionDistribution_survival()
}
class Plotting {
  +outcomeSurvivalPlot()
  +plotDemographicSummary()
  +plotF1Measure()
  +plotGeneralizability()
  +plotPlp()
  +plotPrecisionRecall()
  +plotPredictedPDF()
  +plotPredictionDistribution()
  +plotPreferencePDF()
  +plotSmoothCalibration()
  +plotSmoothCalibrationLoess()
  +plotSmoothCalibrationRcs()
  +plotSparseCalibration()
  +plotSparseCalibration2()
  +plotSparseRoc()
  +plotVariableScatterplot()
}
class CovariateSummary {
  +aggregateCovariateSummaries()
  +covariateSummary()
  +covariateSummarySubset()
  +createCovariateSubsets()
  +getCovariatesForGroup()
}
class EvaluatePlp {
  +evaluatePlp()
  +modelBasedConcordance()
}
class EvaluationSummary {
  +aucWithCi()
  +aucWithoutCi()
  +averagePrecision()
  +brierScore()
  +calculateEStatisticsBinary()
  +calibrationInLarge()
  +calibrationInLargeIntercept()
  +calibrationLine()
  +calibrationWeak()
  +computeAuc()
  +getEvaluationStatistics()
  +getEvaluationStatistics_binary()
  +getEvaluationStatistics_survival()
  +ici()
}
class DemographicSummary {
  +getDemographicSummary()
  +getDemographicSummary_binary()
  +getDemographicSummary_survival()
}
class CalibrationSummary {
  +getCalibrationSummary()
  +getCalibrationSummary_binary()
  +getCalibrationSummary_survival()
}
class ImportFromCsv {
  +extractCohortDefinitionsCSV()
  +extractDatabaseListCSV()
  +extractDiagnosticFromCsv()
  +extractObjectFromCsv()
  +getModelDesignCsv()
  +getModelDesignSettingTable()
  +getPerformanceEvaluationCsv()
  +getTableNamesPlp()
  +insertCsvToDatabase()
}

Prediction:

classDiagram
class Prediction

Prediction --> ExternalValidatePlp
Prediction --> Recalibration
Prediction --> Predict

class ExternalValidatePlp {
  +createValidationDesign()
  +createValidationSettings()
  +externalValidateDbPlp()
  +externalValidatePlp()
  +validateExternal()
  +validateModel()
}
class Recalibration {
  +inverseLog()
  +logFunct()
  +recalibratePlp()
  +recalibratePlpRefit()
  +recalibrationInTheLarge()
  +weakRecalibration()
}
class Predict {
  +applyFeatureengineering()
  +applyTidyCovariateData()
  +predictPlp()
}

lhjohn changed the title from "Discussion on code base refactor" to "Code base refactor - Discussion" on Jul 3, 2024

egillax commented Jul 31, 2024

For information: tidymodels uses parsnip to provide model interfaces. They describe their design here:

https://github.com/tidymodels/parsnip/tree/main/R#readme

They seem to be using plain function calls for model registration rather than classes, although the design is somewhat involved.
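To make that concrete, parsnip's registration pattern is a sequence of `set_*()` calls (documented in its "build a parsnip model" vignette). The sketch below is illustrative only: the model name "plp_classifier", the "plp" engine, and the use of fitPlp as the fit function are hypothetical, not an existing integration.

```r
library(parsnip)

# Register a hypothetical new model type the parsnip way:
set_new_model("plp_classifier")
set_model_mode(model = "plp_classifier", mode = "classification")
set_model_engine("plp_classifier", mode = "classification", eng = "plp")
set_dependency("plp_classifier", eng = "plp", pkg = "PatientLevelPrediction")

# Tell parsnip how to fit with this engine:
set_fit(
  model = "plp_classifier",
  eng = "plp",
  mode = "classification",
  value = list(
    interface = "data.frame",
    protect = c("x", "y"),
    func = c(pkg = "PatientLevelPrediction", fun = "fitPlp"),
    defaults = list()
  )
)
```

Compared with the R6 route, this keeps state in a package-level registry instead of objects, which is simpler to extend from outside but harder to introspect.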
