From 36c289af8fd5b89e2a7bcc1029caad2ae4250c66 Mon Sep 17 00:00:00 2001 From: chenjiale Date: Tue, 16 Dec 2025 21:18:52 +0800 Subject: [PATCH] =?UTF-8?q?feat(model):=20=E6=B7=BB=E5=8A=A0=20ANN=20?= =?UTF-8?q?=E7=AE=97=E6=B3=95=E6=94=AF=E6=8C=81=E5=B9=B6=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 ANN 算法相关 API 接口和路由配置 - 实现 PCA 与 ANN 算法的参数分离管理机制 - 增加测点类型的输入/输出标识功能 - 优化模型训练和测试的数据处理逻辑 - 完善界面展示,区分不同算法的配置项显示 - 强化表单校验和数据标准化处理函数 - 更新模型评估报告中的算法适配逻辑 --- src/api/alert/model/models.ts | 6 + src/views/model/AssessReport.vue | 14 +- src/views/model/list/step/Step2.vue | 22 ++- src/views/model/list/step/Step3.vue | 12 +- src/views/model/list/step/data.tsx | 10 ++ src/views/model/train/data.tsx | 6 + src/views/model/train/index.vue | 238 ++++++++++++++++++++++------ 7 files changed, 254 insertions(+), 54 deletions(-) diff --git a/src/api/alert/model/models.ts b/src/api/alert/model/models.ts index 0de2561..ce101ed 100644 --- a/src/api/alert/model/models.ts +++ b/src/api/alert/model/models.ts @@ -10,7 +10,9 @@ enum Api { CALCULATE_BACK = '/alert/model/data/calculate/', OPTIMISTIC = '/alert/optimistic', TRAIN_MODEL = '/alert/model/train', + TRAIN_MODEL_ANN = '/alert/model/train/ann', TEST_MODEL = '/alert/model/test', + TEST_MODEL_ANN = '/alert/model/test/ann', BOTTOM_MODEL = '/alert/model/bottom/', VERSION_LIST = '/alert/model/version/', VERSION_NEW = '/alert/model/version/new/', @@ -42,8 +44,12 @@ export const getOptimisticApi = (params: any) => defHttp.get({ url: Api.OPT export const trainModelApi = (params: any) => defHttp.post({ url: Api.TRAIN_MODEL, data: params }) +export const trainModelAnnApi = (params: any) => defHttp.post({ url: Api.TRAIN_MODEL_ANN, data: params }) + export const testModelApi = (params: any) => defHttp.post({ url: Api.TEST_MODEL, data: params }) +export const testModelAnnApi = (params: any) => defHttp.post({ url: Api.TEST_MODEL_ANN, data: params }) + export function bottomModelApi(id: any, reportId?: number | string) { return defHttp.post({ url: Api.BOTTOM_MODEL + id, diff --git a/src/views/model/AssessReport.vue b/src/views/model/AssessReport.vue index 2aca232..895c025 100644 --- a/src/views/model/AssessReport.vue +++ b/src/views/model/AssessReport.vue @@ -258,10 +258,20 @@ export default defineComponent({ }, 0) } + const normalizeTypeFlag = (val: any) => { + if (val === true || val === '是' || val === 1 || val === '1' || val === '输出' || val === 'output') + return true + if (val === false || val === '否' || val === 0 || val === '0' || val === '输入' || val === 'input') + return false + return !!val + } + const normalizePointRow = (item: any, index: number): AssessPointRow => { const tMax = item.TMax ?? item.tMax const tMin = item.TMin ?? item.tMin const locked = item.lock === true + const typeFlag = normalizeTypeFlag(item.type) + const isAnn = algorithm.value.toUpperCase() === 'ANN' return { ...item, description: item.description ?? item.Description, @@ -275,8 +285,8 @@ export default defineComponent({ Upper: item.Upper ?? item.upper, Unit: item.Unit ?? item.unit, lock: locked, - // 锁定则不参与,否则默认参与预警 - alarm: !locked, + // 锁定则不参与;ANN 算法仅 type 为 true 参与预警 + alarm: !locked && (!isAnn || typeFlag === true), } } diff --git a/src/views/model/list/step/Step2.vue b/src/views/model/list/step/Step2.vue index 8a0f35b..d00e410 100644 --- a/src/views/model/list/step/Step2.vue +++ b/src/views/model/list/step/Step2.vue @@ -48,10 +48,24 @@ export default defineComponent({ try { const values = await validate() - const modelInfo = toRaw(props.beforeData) - modelInfo.algorithm = values.algorithm - modelInfo.sampling = values.sampling - modelInfo.rate = values.rate + const modelInfo = { ...(toRaw(props.beforeData) || {}) } + const { algorithm, sampling, rate, iteration, hiddenLayer, iter } = values + + modelInfo.algorithm = algorithm + modelInfo.sampling = sampling + + if (algorithm === 'PCA') { + modelInfo.rate = rate + delete modelInfo.iteration + delete modelInfo.layer + delete modelInfo.iter + } + else if (algorithm === 'ANN') { + modelInfo.iteration = iteration + modelInfo.layer = hiddenLayer + modelInfo.iter = iter + delete modelInfo.rate + } emit('next', modelInfo) } catch (error) { diff --git a/src/views/model/list/step/Step3.vue b/src/views/model/list/step/Step3.vue index d505620..74d9aad 100644 --- a/src/views/model/list/step/Step3.vue +++ b/src/views/model/list/step/Step3.vue @@ -78,6 +78,7 @@ export default defineComponent({ unit: point.unit, Lower: point.Lower, Upper: point.Upper, + type: props.beforeData?.algorithm === 'ANN', dead: true, limit: false, } @@ -131,9 +132,18 @@ export default defineComponent({ {{ beforeData.sampling }} - + {{ beforeData.rate }} + + {{ beforeData.iteration }} + + + {{ beforeData.layer }} + + + {{ beforeData.iter }} + diff --git a/src/views/model/list/step/data.tsx b/src/views/model/list/step/data.tsx index 73d0ced..6baf93b 100644 --- a/src/views/model/list/step/data.tsx +++ b/src/views/model/list/step/data.tsx @@ -79,6 +79,16 @@ export const step2Schemas: FormSchema[] = [ colProps: { span: 24 }, labelWidth: 120, // 增大label宽度 }, + { + field: 'iter', + component: 'InputNumber', + label: '最大迭代次数', + defaultValue: 300, + required: ({ values }) => values.algorithm === 'ANN', + ifShow: ({ values }) => values.algorithm === 'ANN', + colProps: { span: 24 }, + labelWidth: 120, + }, ] export const step3Schemas: FormSchema[] = [ diff --git a/src/views/model/train/data.tsx b/src/views/model/train/data.tsx index acee5c6..6ae7b19 100644 --- a/src/views/model/train/data.tsx +++ b/src/views/model/train/data.tsx @@ -155,6 +155,12 @@ export const pointTableSchema: BasicColumn[] = [ width: 120, dataIndex: 'limit', }, + { + title: '测点类型', + width: 120, + dataIndex: 'type', + slots: { customRender: 'pointType' }, + }, { title: '编辑', width: 100, diff --git a/src/views/model/train/index.vue b/src/views/model/train/index.vue index 2f6d9bb..4991498 100644 --- a/src/views/model/train/index.vue +++ b/src/views/model/train/index.vue @@ -17,15 +17,16 @@ import { Input, InputNumber, Modal, + Radio, RangePicker, Row, Select, Space, Spin, Steps, + Switch, Table, Tabs, - Radio, } from 'ant-design-vue' import VueECharts from 'vue-echarts' import PointTransfer from '../components/PointTransfer.vue' @@ -37,7 +38,9 @@ import { bottomModelApi, createDraftVersionApi, modelInfoApi, + testModelAnnApi, testModelApi, + trainModelAnnApi, trainModelApi, updateModelInfo, versionListApi, @@ -79,6 +82,7 @@ export default defineComponent({ ASpace: Space, ARadio: Radio, ARadioGroup: Radio.Group, + ASwitch: Switch, Icon, PointTransfer, }, @@ -134,12 +138,20 @@ export default defineComponent({ const pointData = computed(() => { const list = model.value?.pointInfo || [] + const normalizeBool = (val: any) => { + if (val === true || val === '是' || val === 1 || val === '1' || val === '输出' || val === 'output') + return true + if (val === false || val === '否' || val === 0 || val === '0' || val === '输入' || val === 'input') + return false + return !!val + } return list.map((p: any) => ({ description: p.description ?? p.Description, PointId: p.PointId ?? p.pointId, unit: p.unit ?? p.Unit, Upper: p.Upper ?? p.upper, Lower: p.Lower ?? p.lower, + type: normalizeBool(p.type), dead: p.dead === true ? '是' : '否', limit: p.limit === true ? '是' : '否', upperBound: p.upperBound ?? p.upperbound ?? p.upperBound, @@ -174,6 +186,26 @@ export default defineComponent({ return sum + (Number.isFinite(modeVal) ? modeVal : 0) }, 500) }) + function handlePointTypeChange(pointId: any, value: boolean) { + if (!isANNAlgorithm.value) + return + if (!showTrainActions.value) { + createMessage.warning('非草稿版本不可修改测点类型') + return + } + const infoList = model.value?.pointInfo + if (!Array.isArray(infoList)) + return + const idx = infoList.findIndex(item => (item?.PointId ?? item?.pointId) === pointId) + if (idx === -1) + return + infoList[idx] = { + ...infoList[idx], + type: value, + } + model.value.pointInfo = [...infoList] + updateModelInfoDebounced() + } const activeKey = ref('1') type RangeValue = [Dayjs, Dayjs] @@ -230,18 +262,43 @@ export default defineComponent({ createMessage.error('未获取到模型信息,无法获取测试数据') return } - const params = { - Model_id: id, - version: model.value?.version ? model.value?.version : 'v-test', - Test_Data: { - time: timeRange - .map(t => dayjs(t).format('YYYY-MM-DD HH:mm:ss')) - .join(','), - points: model.value.pointInfo.map(t => t.PointId).join(','), - interval: model.value.sampling * 1000, - }, + const algorithm = (model.value?.algorithm || 'PCA').toUpperCase() + const timeStr = timeRange + .map(t => dayjs(t).format('YYYY-MM-DD HH:mm:ss')) + .join(',') + const pointStr = model.value.pointInfo.map(t => t.PointId).join(',') + const intervalMs = model.value.sampling * 1000 + let result + if (algorithm === 'ANN') { + const normalizeTypeFlag = (val: any) => { + if (val === true || val === '是' || val === 1 || val === '1' || val === '输出' || val === 'output') + return true + if (val === false || val === '否' || val === 0 || val === '0' || val === '输入' || val === 'input') + return false + return !!val + } + const typeStr = model.value.pointInfo.map(t => (normalizeTypeFlag(t?.type) ? '1' : '0')).join(',') + const params = { + time: timeStr, + points: pointStr, + interval: intervalMs, + model: JSON.stringify(model.value?.para ?? ''), + type: typeStr, + } + result = await testModelAnnApi(params) + } + else { + const params = { + Model_id: id, + version: model.value?.version ? model.value?.version : 'v-test', + Test_Data: { + time: timeStr, + points: pointStr, + interval: intervalMs, + }, + } + result = await testModelApi(params) } - const result = await testModelApi(params) const sampleData = result?.sampleData ?? result?.SampleData ?? [] const reconData = result?.reconData ?? result?.ReconData ?? [] if (!Array.isArray(sampleData) || !Array.isArray(reconData) || sampleData.length === 0 || reconData.length === 0) { @@ -561,35 +618,59 @@ export default defineComponent({ console.error('模型参数点信息为空,无法训练') return } - const params = { - conditon: modelInfo.alarmmodelset?.alarmcondition || '1=1', - Hyper_para: { - percent: modelInfo.rate, - }, - Train_Data: { - points: pointInfo.map(item => item.PointId).join(','), - dead: pointInfo.map(item => (item.dead ? '1' : '0')).join(','), - limit: pointInfo.map(item => (item.limit ? '1' : '0')).join(','), - uplow: pointInfo - .map( - item => - `${item.Upper ? item.Upper : null},${ - item.Lower ? item.Lower : null - }`, - ) - .join(';'), - interval: modelInfo.sampling * 1000, - time: modelInfo.trainTime - .map(item => `${item.st},${item.et}`) - .join(';'), - }, - type: 'PCA', - smote_config: [], - smote: true, + const algorithm = (modelInfo.algorithm || 'PCA').toUpperCase() + const condition = modelInfo.alarmmodelset?.alarmcondition || '1=1' + const trainData = { + points: pointInfo.map(item => item.PointId).join(','), + dead: pointInfo.map(item => (item.dead ? '1' : '0')).join(','), + limit: pointInfo.map(item => (item.limit ? '1' : '0')).join(','), + uplow: pointInfo + .map( + item => + `${item.Upper ?? ''},${item.Lower ?? ''}`, + ) + .join(';'), + interval: modelInfo.sampling * 1000, + time: modelInfo.trainTime.map(item => `${item.st},${item.et}`).join(';'), } spinning.value = true try { - const response = await trainModelApi(params) + let response + if (algorithm === 'ANN') { + const normalizeTypeFlag = (val: any) => { + if (val === true || val === '是' || val === 1 || val === '1' || val === '输出' || val === 'output') + return true + if (val === false || val === '否' || val === 0 || val === '0' || val === '输入' || val === 'input') + return false + return !!val + } + const typeStr = pointInfo.map(item => (normalizeTypeFlag(item?.type) ? '1' : '0')).join(',') + const params = { + type: typeStr, + iter: modelInfo.iter, + hide: modelInfo.layer, + conditon: condition, + condition, + Train_Data: trainData, + } + response = await trainModelAnnApi(params) + response = JSON.parse(response) + response.filename = JSON.parse(response.filename) + } + else { + const hyperPara: Record = {} + if (algorithm === 'PCA') + hyperPara.percent = modelInfo.rate + const params = { + conditon: condition, + Hyper_para: hyperPara, + Train_Data: trainData, + type: algorithm, + smote_config: [], + smote: true, + } + response = await trainModelApi(params) + } model.value.para = response const modelInfoDetail = response.Model_info || response.modelInfo @@ -770,8 +851,13 @@ export default defineComponent({ const editModelForm = ref({ sampling: 0, rate: 0, + iter: 0, + layer: '', selectedKeys: [], }) + const algorithmValue = computed(() => (model.value?.algorithm || 'PCA').toUpperCase()) + const isPCAAlgorithm = computed(() => algorithmValue.value === 'PCA') + const isANNAlgorithm = computed(() => algorithmValue.value === 'ANN') function openEditModel() { if (!canEditModel.value) { @@ -780,6 +866,8 @@ export default defineComponent({ } editModelForm.value.sampling = model.value?.sampling || 0 editModelForm.value.rate = model.value?.rate || 0 + editModelForm.value.iter = model.value?.iter || 0 + editModelForm.value.layer = model.value?.layer || '' editModelForm.value.selectedKeys = (model.value?.pointInfo || []).map(item => buildPointKeyFromInfo({ description: item.description ?? item.Description, PointId: item.PointId ?? item.pointId, @@ -795,7 +883,22 @@ export default defineComponent({ return } model.value.sampling = editModelForm.value.sampling - model.value.rate = editModelForm.value.rate + if (isPCAAlgorithm.value) { + model.value.rate = editModelForm.value.rate + delete model.value.layer + delete model.value.iter + } + else if (isANNAlgorithm.value) { + model.value.layer = editModelForm.value.layer + model.value.iter = editModelForm.value.iter + if (!Array.isArray(model.value.pointInfo)) + model.value.pointInfo = [] + model.value.pointInfo = model.value.pointInfo.map((item: any) => ({ + ...item, + type: item?.type ?? true, + })) + delete model.value.rate + } model.value.pointInfo = editModelForm.value.selectedKeys.map((key) => { const { description, PointId, unit, Lower, Upper } = parsePointKey(key) return { @@ -804,6 +907,7 @@ export default defineComponent({ unit, Lower, Upper, + type: isANNAlgorithm.value ? true : undefined, dead: true, limit: false, } @@ -899,9 +1003,9 @@ export default defineComponent({ else { queryParts.push('sampleType=train') } - if (assessSampleCount.value !== null && assessSampleCount.value !== undefined && assessSampleCount.value !== '') { + if (assessSampleCount.value !== null && assessSampleCount.value !== undefined && assessSampleCount.value !== '') queryParts.push(`sampleCount=${assessSampleCount.value}`) - } + assessConfigVisible.value = false go(`/model/assess-report/${id}?${queryParts.join('&')}`) } @@ -1135,6 +1239,9 @@ export default defineComponent({ createDraftVersion, goAssessReport, showTrainActions, + isPCAAlgorithm, + isANNAlgorithm, + handlePointTypeChange, canEditModel, effectiveSampleCount, assessConfigVisible, @@ -1219,12 +1326,18 @@ export default defineComponent({ {{ model?.pointInfo.length || "暂无" }} - + {{ model?.rate }} - + {{ model?.principal }} + + {{ model?.layer || '暂无' }} + + + {{ model?.iter }} + {{ model?.precision }} @@ -1272,6 +1385,18 @@ export default defineComponent({ +