diff --git a/pmml-rexp-lightgbm/src/main/java/org/jpmml/rexp/lightgbm/LightGBMConverter.java b/pmml-rexp-lightgbm/src/main/java/org/jpmml/rexp/lightgbm/LightGBMConverter.java index a52411b..1d6884f 100644 --- a/pmml-rexp-lightgbm/src/main/java/org/jpmml/rexp/lightgbm/LightGBMConverter.java +++ b/pmml-rexp-lightgbm/src/main/java/org/jpmml/rexp/lightgbm/LightGBMConverter.java @@ -22,23 +22,67 @@ import java.io.IOException; import java.io.InputStream; import java.util.Collections; +import java.util.List; +import java.util.Map; -import org.dmg.pmml.PMML; +import org.dmg.pmml.mining.MiningModel; +import org.jpmml.converter.Feature; +import org.jpmml.converter.Label; +import org.jpmml.converter.Schema; import org.jpmml.lightgbm.GBDT; import org.jpmml.lightgbm.LightGBMUtil; -import org.jpmml.rexp.Converter; +import org.jpmml.rexp.ModelConverter; import org.jpmml.rexp.REnvironment; import org.jpmml.rexp.RExpEncoder; import org.jpmml.rexp.RRaw; -public class LightGBMConverter extends Converter { +public class LightGBMConverter extends ModelConverter { + + private GBDT gbdt = null; + public LightGBMConverter(REnvironment environment){ super(environment); } @Override - public PMML encodePMML(RExpEncoder encoder){ + public void encodeSchema(RExpEncoder encoder){ + GBDT gbdt = ensureGBDT(); + + Schema schema = gbdt.encodeSchema(null, null, encoder); + + Label label = schema.getLabel(); + List features = schema.getFeatures(); + + encoder.setLabel(label); + + for(Feature feature : features){ + encoder.addFeature(feature); + } + } + + @Override + public MiningModel encodeModel(Schema schema){ + GBDT gbdt = ensureGBDT(); + + // XXX + Map options = Collections.emptyMap(); + + Schema lgbmSchema = gbdt.toLightGBMSchema(schema); + + return gbdt.encodeModel(options, lgbmSchema); + } + + private GBDT ensureGBDT(){ + + if(this.gbdt == null){ + this.gbdt = loadGBDT(); + } + + return this.gbdt; + } + + private GBDT loadGBDT(){ REnvironment environment = getObject(); RRaw raw = (RRaw)environment.findVariable("raw"); @@ -46,14 +90,10 @@ public PMML encodePMML(RExpEncoder encoder){ throw new IllegalArgumentException(); } - GBDT gbdt; - try(InputStream is = new ByteArrayInputStream(raw.getValue())){ - gbdt = LightGBMUtil.loadGBDT(is); + return LightGBMUtil.loadGBDT(is); } catch(IOException ioe){ throw new IllegalArgumentException(ioe); } - - return gbdt.encodePMML(Collections.emptyMap(), null, null); } } \ No newline at end of file