From 64a758016902aed39368d1470b40eff88adecdaf Mon Sep 17 00:00:00 2001 From: chenjiale Date: Tue, 16 Dec 2025 21:18:35 +0800 Subject: [PATCH] =?UTF-8?q?feat(alert):=20=E6=96=B0=E5=A2=9E=20ANN=20?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83=E4=B8=8E=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 ANN 模型训练接口 /alert/model/train/ann - 添加 ANN 模型测试接口 /alert/model/test/ann - 扩展 ModelInfo 和 ModelInitVO 支持 ANN 层级和迭代参数 - 更新模型服务接口和实现以支持 ANN 特定逻辑 - 重构训练参数结构以区分 PCA 与 ANN 配置 - 修复模型创建时算法相关字段设置问题 - 增强模型更新时对不同算法参数的处理逻辑 - 优化点数据结构增加 type 字段支持 --- .../admin/model/ModelController.java | 18 +++ .../admin/model/model/ModelInfo.java | 8 +- .../controller/admin/model/model/Point.java | 2 + .../admin/model/vo/ModelInitVO.java | 10 ++ .../yudao/module/alert/param/TrainParam.java | 58 +++++++++- .../alert/service/model/ModelService.java | 6 + .../service/model/impl/ModelServiceImpl.java | 109 +++++++++++++++++- 7 files changed, 204 insertions(+), 7 deletions(-) diff --git a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/ModelController.java b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/ModelController.java index b1fe66d..503d459 100644 --- a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/ModelController.java +++ b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/ModelController.java @@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.alert.controller.admin.model; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.module.alert.controller.admin.model.vo.*; +import cn.iocoder.yudao.module.alert.param.AnnTestParam; +import cn.iocoder.yudao.module.alert.param.AnnTrainParam; import cn.iocoder.yudao.module.alert.param.ModelTestParam; import cn.iocoder.yudao.module.alert.param.TrainParam; import cn.iocoder.yudao.module.alert.service.model.ModelService; @@ -72,12 +74,28 @@ public class ModelController { } } + @PostMapping("/train/ann") + public CommonResult getAnnTrainData(@RequestBody AnnTrainParam param) { + try { + String trainInfo = modelService.trainAnn(param); + return CommonResult.success(trainInfo); + } catch (Exception e) { + return CommonResult.error(INTERNAL_SERVER_ERROR.getCode(), e.getMessage()); + } + } + @PostMapping("/test") public CommonResult getTestData(@RequestBody ModelTestParam param) { ModelTestData modelTestData = modelService.getModelTestData(param); return CommonResult.success(modelTestData); } + @PostMapping("/test/ann") + public CommonResult getAnnTestData(@RequestBody AnnTestParam param) { + ModelTestData modelTestData = modelService.getAnnModelTestData(param); + return CommonResult.success(modelTestData); + } + /** * 创建草稿版本(v-test) */ diff --git a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/ModelInfo.java b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/ModelInfo.java index 0b56f2b..0dd78ba 100644 --- a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/ModelInfo.java +++ b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/ModelInfo.java @@ -69,6 +69,12 @@ public class ModelInfo { @JsonProperty("rate") private String rate; + @JsonProperty("layer") + private String layer; + + @JsonProperty("iter") + private Integer iter; + @JsonProperty("outPointInfo") private List outPointInfo; @@ -102,7 +108,7 @@ public class ModelInfo { private AlarmModelSet alarmModelSet; @JsonProperty("para") - private TrainInfo para; + private Object para; @JsonProperty("principal") private Integer principal; diff --git a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/Point.java b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/Point.java index 039645b..a71bfd8 100644 --- a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/Point.java +++ b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/model/Point.java @@ -57,4 +57,6 @@ public class Point { private BigDecimal C95; @JsonProperty("C99") private BigDecimal C99; + + private boolean type; } diff --git a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/vo/ModelInitVO.java b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/vo/ModelInitVO.java index a068366..1f76bfa 100644 --- a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/vo/ModelInitVO.java +++ b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/controller/admin/model/vo/ModelInitVO.java @@ -57,4 +57,14 @@ public class ModelInitVO { private Integer sampling; private String rate; + + /** + * ANN 模型层级配置 + */ + private String layer; + + /** + * ANN 迭代次数 + */ + private Integer iter; } diff --git a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/TrainParam.java b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/TrainParam.java index 6878137..60f6c3d 100644 --- a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/TrainParam.java +++ b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/param/TrainParam.java @@ -16,11 +16,36 @@ public class TrainParam { @JsonProperty("Train_Data") private TrainData trainData; + /** + * 算法类型:PCA / ANN 等 + */ + private String type; + + /** + * 条件(保留原拼写以兼容旧入参) + */ + private String conditon; + + /** + * 正确拼写的条件字段,兼容新入参 + */ + @JsonAlias("condition") + private String condition; + + /** + * PCA 专用配置 + */ + private PcaParam pca; + + /** + * ANN 专用配置(如仍走旧接口可使用) + */ + private AnnParam ann; + + // === 旧字段兼容 === @JsonAlias("Hyper_para") @JsonProperty("Hyper_para") private HyperPara hyperPara; - private String type; - private String conditon; @JsonAlias("smote_config") @JsonProperty("smote_config") @@ -30,6 +55,10 @@ public class TrainParam { private String targetPoint; + private String iter; + + private String hide; + @Data public static class Smote { @@ -57,4 +86,29 @@ public class TrainParam { private String uplow; private Integer interval; } + + @Data + @AllArgsConstructor + @NoArgsConstructor + public static class PcaParam { + @JsonAlias("Hyper_para") + @JsonProperty("Hyper_para") + private HyperPara hyperPara; + + @JsonAlias("smote_config") + @JsonProperty("smote_config") + private List smoteConfig; + + private Boolean smote; + + private String targetPoint; + } + + @Data + @AllArgsConstructor + @NoArgsConstructor + public static class AnnParam { + private String iter; + private String hide; + } } diff --git a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/ModelService.java b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/ModelService.java index b4002d2..332e21c 100644 --- a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/ModelService.java +++ b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/ModelService.java @@ -2,7 +2,9 @@ package cn.iocoder.yudao.module.alert.service.model; import cn.iocoder.yudao.module.alert.controller.admin.model.vo.*; +import cn.iocoder.yudao.module.alert.param.AnnTestParam; import cn.iocoder.yudao.module.alert.param.ModelTestParam; +import cn.iocoder.yudao.module.alert.param.AnnTrainParam; import cn.iocoder.yudao.module.alert.param.TrainParam; import java.util.List; @@ -55,8 +57,12 @@ public interface ModelService { TrainInfo trainModel(TrainParam param); + String trainAnn(AnnTrainParam param); + ModelTestData getModelTestData(ModelTestParam param); + ModelTestData getAnnModelTestData(AnnTestParam param); + /** * 模型下装:校验训练/评估后,将模型状态置为已下装并落库版本 * diff --git a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/impl/ModelServiceImpl.java b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/impl/ModelServiceImpl.java index a84b2b7..8c362e0 100644 --- a/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/impl/ModelServiceImpl.java +++ b/yudao-module-alert/yudao-module-alert-biz/src/main/java/cn/iocoder/yudao/module/alert/service/model/impl/ModelServiceImpl.java @@ -20,6 +20,8 @@ import cn.iocoder.yudao.module.alert.dao.mapper.AssessReportCfgMapper; import cn.iocoder.yudao.module.alert.dao.service.ModelCfgService; import cn.iocoder.yudao.module.alert.dao.service.ModelVersionService; import cn.iocoder.yudao.module.alert.dao.service.SystemCfgService; +import cn.iocoder.yudao.module.alert.param.AnnTestParam; +import cn.iocoder.yudao.module.alert.param.AnnTrainParam; import cn.iocoder.yudao.module.alert.param.ModelTestParam; import cn.iocoder.yudao.module.alert.param.TrainParam; import cn.iocoder.yudao.module.alert.service.model.ModelService; @@ -124,11 +126,12 @@ public class ModelServiceImpl implements ModelService { @Override public Integer createModel(ModelInitVO model) { try { + Algorithm algorithm = model.getAlgorithm(); ModelInfo info = new ModelInfo(); BeanUtil.copyProperties(model, info); ModelCfg modelCfg = ModelCfg.builder() .systemId(model.getSystemId()) - .algorithmId(model.getAlgorithm().code) + .algorithmId(algorithm.code) .modelName(model.getModelName()) .createTime(new Date()) .trash(ModelTrash.NORMAL.code) @@ -141,13 +144,25 @@ public class ModelServiceImpl implements ModelService { info.setId(String.valueOf(modelId)); info.setSystemId(model.getSystemId()); info.setUnit(model.getUnit()); - info.setAlgorithm(model.getAlgorithm()); + info.setAlgorithm(algorithm); info.setFounder(modelCfg.getCreator()); info.setCreateTime(modelCfg.getCreateTime()); info.setAlarmModelSet(ModelInfo.AlarmModelSet.defaultInit()); info.setName(model.getModelName()); info.setDescription(model.getDescription()); - info.setRate(model.getRate()); + if (Algorithm.PCA.equals(algorithm)) { + info.setRate(model.getRate()); + info.setLayer(null); + info.setIter(null); + } else if (Algorithm.ANN.equals(algorithm)) { + info.setLayer(model.getLayer()); + info.setRate(null); + info.setIter(model.getIter()); + } else { + info.setRate(model.getRate()); + info.setLayer(model.getLayer()); + info.setIter(model.getIter()); + } info.setTrainTime(new ArrayList<>()); modelCfg.setModelInfo(JsonUtils.toJsonString(info)); modelCfgService.updateById(modelCfg); @@ -161,6 +176,19 @@ public class ModelServiceImpl implements ModelService { @Override public Boolean updateModelInfo(ModelInfoVO modelInfo) { + ModelCfg existCfg = modelCfgService.getById(Integer.parseInt(modelInfo.getId())); + if (existCfg != null) { + Algorithm algorithm = Algorithm.of(existCfg.getAlgorithmId()); + modelInfo.setAlgorithm(algorithm); + if (Algorithm.PCA.equals(algorithm)) { + modelInfo.setLayer(null); + modelInfo.setIter(null); + } else if (Algorithm.ANN.equals(algorithm)) { + modelInfo.setRate(null); + } else { + // 其他算法默认保留传入字段 + } + } modelInfo.setModifier(SecurityFrameworkUtils.getLoginUserNickname()); modelInfo.setModifiedTime(new Date()); ModelCfg modelCfg = ModelCfg.builder() @@ -196,7 +224,29 @@ public class ModelServiceImpl implements ModelService { @Override public TrainInfo trainModel(TrainParam param) { - String trainBody = HttpUtils.post(algorithmHost + "/api/test/ClearTrain", null, JsonUtils.toJsonString(param)); + if ("ANN".equalsIgnoreCase(param.getType())) { + throw new RuntimeException("请使用 /alert/model/train/ann 接口进行 ANN 训练"); + } + return trainPca(param); + } + + private TrainInfo trainPca(TrainParam param) { + Map payload = new HashMap<>(); + payload.put("Train_Data", param.getTrainData()); + payload.put("type", param.getType()); + payload.put("condition", param.getCondition()); + payload.put("conditon", param.getConditon()); + + TrainParam.PcaParam pca = param.getPca(); + TrainParam.HyperPara hyperPara = pca != null && pca.getHyperPara() != null ? pca.getHyperPara() : param.getHyperPara(); + if (hyperPara != null) { + payload.put("Hyper_para", hyperPara); + } + payload.put("smote_config", pca != null ? pca.getSmoteConfig() : param.getSmoteConfig()); + payload.put("smote", pca != null && pca.getSmote() != null ? pca.getSmote() : param.getSmote()); + payload.put("targetPoint", pca != null ? pca.getTargetPoint() : param.getTargetPoint()); + + String trainBody = HttpUtils.post(algorithmHost + "/api/test/ClearTrain", null, JsonUtils.toJsonString(payload)); if (trainBody.contains("error_msg")) { throw new RuntimeException("模型训练异常:" + JsonUtils.parseObject(trainBody, new TypeReference>() { @@ -206,6 +256,42 @@ public class ModelServiceImpl implements ModelService { }); } + @Override + public String trainAnn(AnnTrainParam param) { + Map payload = new HashMap<>(); + TrainParam.TrainData trainData = param.getTrainData(); + String time = trainData != null ? trainData.getTime() : null; + payload.put("time", time == null ? null : time.replaceAll(";+$", "")); + payload.put("point", trainData != null ? trainData.getPoints() : null); + Integer interval = trainData != null ? trainData.getInterval() : null; + payload.put("interval", interval == null ? null : interval * 1000); + + String iter = param.getIter(); + String hide = param.getHide(); + + payload.put("iter", iter); + payload.put("layer", StringUtils.hasText(hide) ? hide.split("-") : new String[]{}); + payload.put("type", param.getType()); + payload.put("dead", trainData != null ? trainData.getDead() : null); + payload.put("limit", trainData != null ? trainData.getLimit() : null); + payload.put("uplow", trainData != null ? trainData.getUplow() : null); + payload.put("condition", resolveCondition(param)); + String trainBody = HttpUtils.post(algorithmHost + "/api/test/ANN_Train", null, JsonUtils.toJsonString(payload)); + if (trainBody.contains("error_msg")) { + throw new RuntimeException("模型训练异常:" + JsonUtils.parseObject(trainBody, + new TypeReference>() { + }).get("error_msg")); + } + return trainBody; + } + + private String resolveCondition(AnnTrainParam param) { + if (StringUtils.hasText(param.getCondition())) { + return param.getCondition(); + } + return param.getConditon(); + } + @Override public ModelTestData getModelTestData(ModelTestParam param) { @@ -214,6 +300,21 @@ public class ModelServiceImpl implements ModelService { return JsonUtils.parseObject(result, ModelTestData.class); } + @Override + public ModelTestData getAnnModelTestData(AnnTestParam param) { + Map payload = new HashMap<>(); + payload.put("point", param.getPoints()); + payload.put("model", param.getModel()); + payload.put("time", param.getTime()); + Integer interval = param.getInterval(); + payload.put("interval", interval == null ? null : interval); + payload.put("type", param.getType()); + + String result = HttpUtils.post(algorithmHost + "/api/test/ANN_Test", null, JsonUtils.toJsonString(payload)) + .replace("NaN", "-1").replace("Infinity", "1"); + return JsonUtils.parseObject(result, ModelTestData.class); + } + @Override @Transactional(rollbackFor = Exception.class) public ModelInfoVO bottomModel(Integer id, Long reportId) {