forked from SciSharp/BotSharp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPlatformBuilderBase.cs
More file actions
245 lines (200 loc) · 8.4 KB
/
PlatformBuilderBase.cs
File metadata and controls
245 lines (200 loc) · 8.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
using BotSharp.Core.Engines;
using BotSharp.Platform.Abstractions;
using BotSharp.Platform.Models;
using BotSharp.Platform.Models.AiRequest;
using BotSharp.Platform.Models.AiResponse;
using BotSharp.Platform.Models.Contexts;
using BotSharp.Platform.Models.Entities;
using BotSharp.Platform.Models.Intents;
using BotSharp.Platform.Models.MachineLearning;
using DotNetToolkit;
using Microsoft.Extensions.Configuration;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
namespace BotSharp.Core
{
/// <summary>
/// Base class for platform builders: loads, stores, trains and queries agents of type
/// <typeparamref name="TAgent"/>. Derived platforms supply <see cref="FallbackResponse"/>
/// and <see cref="AssembleResult{TResult}"/>.
/// </summary>
/// <typeparam name="TAgent">Concrete agent type handled by this platform.</typeparam>
public abstract class PlatformBuilderBase<TAgent> where TAgent : AgentBase
{
    /// <summary>Label assigned to the synthetic intent used when prediction confidence is too low.</summary>
    private const string FallbackIntentName = "fallback";

    /// <summary>
    /// Agent most recently set by <see cref="LoadAgentFromFile{TImporter}"/> or fetched by
    /// <see cref="TextRequest{TResult}"/>.
    /// </summary>
    public TAgent Agent { get; set; }

    /// <summary>Agent storage; initialised from <see cref="agentStorageFactory"/> in the constructor.</summary>
    public IAgentStorage<TAgent> Storage { get; set; }

    protected readonly IAgentStorageFactory<TAgent> agentStorageFactory;
    protected readonly IContextStorageFactory<AIContext> contextStorageFactory;
    protected readonly IPlatformSettings settings;

    /// <summary>
    /// Stores the injected dependencies and resolves <see cref="Storage"/> immediately.
    /// </summary>
    /// <param name="agentStorageFactory">Factory used to obtain the agent storage.</param>
    /// <param name="contextStorageFactory">Factory used to obtain per-session context storage.</param>
    /// <param name="settings">Platform settings handed to the trainer.</param>
    public PlatformBuilderBase(IAgentStorageFactory<TAgent> agentStorageFactory, IContextStorageFactory<AIContext> contextStorageFactory, IPlatformSettings settings)
    {
        this.agentStorageFactory = agentStorageFactory;
        this.contextStorageFactory = contextStorageFactory;
        this.settings = settings;

        GetAgentStorage();
    }

    /// <summary>Returns every agent known to <see cref="Storage"/>.</summary>
    public async Task<List<TAgent>> GetAllAgents()
    {
        return await Storage.Query();
    }

    /// <summary>
    /// Imports an agent from <paramref name="dataDir"/> (which must contain <c>meta.json</c>)
    /// using a <typeparamref name="TImporter"/>, and makes it the current <see cref="Agent"/>.
    /// </summary>
    /// <typeparam name="TImporter">Importer that understands the on-disk agent format.</typeparam>
    /// <param name="dataDir">Folder holding the exported agent.</param>
    /// <returns>The imported agent (not persisted; call <see cref="SaveAgent"/> for that).</returns>
    public async Task<TAgent> LoadAgentFromFile<TImporter>(string dataDir) where TImporter : IAgentImporter<TAgent>, new()
    {
        Console.WriteLine($"Loading agent from folder {dataDir}");

        var meta = LoadMeta(dataDir);
        var importer = new TImporter
        {
            AgentDir = dataDir
        };

        // Load agent summary
        var agent = await importer.LoadAgent(meta);

        // Load user custom entities
        await importer.LoadCustomEntities(agent);

        // Load agent intents
        await importer.LoadIntents(agent);

        // Load system built-in entities
        await importer.LoadBuildinEntities(agent);

        Console.WriteLine($"Loaded agent: {agent.Name} {agent.Id}");

        Agent = agent;
        return agent;
    }

    /// <summary>Reads and deserialises <c>meta.json</c> from <paramref name="dataDir"/>.</summary>
    private AgentImportHeader LoadMeta(string dataDir)
    {
        string metaJson = File.ReadAllText(Path.Combine(dataDir, "meta.json"));
        return JsonConvert.DeserializeObject<AgentImportHeader>(metaJson);
    }

    /// <summary>Fetches an agent from <see cref="Storage"/> by id.</summary>
    public async Task<TAgent> GetAgentById(string agentId)
    {
        return await Storage.FetchById(agentId);
    }

    /// <summary>Fetches an agent from <see cref="Storage"/> by name.</summary>
    public async Task<TAgent> GetAgentByName(string agentName)
    {
        return await Storage.FetchByName(agentName);
    }

    /// <summary>
    /// Trains one model per distinct <c>ContextHash</c> found in <paramref name="corpus"/>.
    /// Each model is written under <c>options.Model/&lt;ContextHash&gt;</c> inside <c>options.AgentDir</c>.
    /// </summary>
    /// <param name="agent">Agent being trained; its <c>Corpus</c> is overwritten with each context's slice.</param>
    /// <param name="corpus">Full training corpus (entities plus user utterances).</param>
    /// <param name="options">
    /// Training options. When empty, <c>AgentDir</c> defaults to <c>DataPath/Projects/&lt;agent.Id&gt;</c>
    /// and <c>Model</c> defaults to <c>model_yyyyMMdd</c> (UTC). Both defaults are written back into
    /// <paramref name="options"/>.
    /// </param>
    /// <returns>
    /// Metadata of the last context model trained, with its <c>Pipeline</c> cleared and
    /// <c>Model</c> set to <c>options.Model</c>.
    /// </returns>
    /// <exception cref="InvalidOperationException">The corpus contains no user utterances.</exception>
    public virtual async Task<ModelMetaData> Train(TAgent agent, TrainingCorpus corpus, BotTrainOptions options)
    {
        if (String.IsNullOrEmpty(options.AgentDir))
        {
            options.AgentDir = GetProjectPath(agent.Id);
        }

        if (String.IsNullOrEmpty(options.Model))
        {
            // Invariant culture keeps the folder name Gregorian/ASCII regardless of the host culture.
            options.Model = "model_" + DateTime.UtcNow.ToString("yyyyMMdd", System.Globalization.CultureInfo.InvariantCulture);
        }

        // Split the corpus by context: each group already holds exactly the utterances for its key.
        var contextCorpora = corpus.UserSays
            .GroupBy(x => x.ContextHash)
            .Select(g => new
            {
                Context = g.Key,
                Corpus = new TrainingCorpus
                {
                    Entities = corpus.Entities,
                    UserSays = g.ToList()
                }
            })
            .ToList();

        ModelMetaData meta = null;

        // Train sequentially and await each step. An async lambda passed to List.ForEach is
        // async void: nothing waits for it, so `meta` could still be null below, the trainings
        // would race on agent.Corpus, and trainer exceptions could not be observed by the caller.
        foreach (var c in contextCorpora)
        {
            var trainer = new BotTrainer(settings);
            agent.Corpus = c.Corpus;

            meta = await trainer.Train(agent, new BotTrainOptions
            {
                AgentDir = options.AgentDir,
                Model = options.Model + $"{Path.DirectorySeparatorChar}{c.Context}"
            });
        }

        if (meta == null)
        {
            throw new InvalidOperationException("Training corpus contains no user utterances; no model was trained.");
        }

        meta.Pipeline.Clear();
        meta.Model = options.Model;

        return meta;
    }

    /// <summary>
    /// Classifies <c>request.Text</c> with the agent's newest trained model for the request's
    /// merged contexts, falling back to <see cref="FallbackResponse"/> when confidence is below
    /// <c>Agent.MlConfig.MinConfidence</c>. The result is shaped by <see cref="AssembleResult{TResult}"/>.
    /// </summary>
    /// <typeparam name="TResult">Platform-specific response type.</typeparam>
    /// <param name="request">
    /// Incoming request. Its <c>Contexts</c>, <c>AgentDir</c> and <c>Model</c> are updated in place.
    /// </param>
    /// <exception cref="InvalidOperationException">The agent's project folder holds no model directory.</exception>
    public virtual async Task<TResult> TextRequest<TResult>(AiRequest request)
    {
        // merge last contexts
        string contextHash = await GetContextsHash(request);

        Console.WriteLine($"TextRequest: {request.Text}, {request.AgentId}, {string.Join(",", request.Contexts)}, {request.SessionId}");

        // Load agent
        var projectPath = GetProjectPath(request.AgentId);
        var model = FindLatestModel(projectPath);

        request.AgentDir = projectPath;
        request.Model = model + $"{Path.DirectorySeparatorChar}{contextHash}";

        Agent = await GetAgentById(request.AgentId);

        var preditor = new BotPredictor();
        var doc = await preditor.Predict(Agent, request);

        var predictedIntent = doc.Sentences[0].Intent;

        if (predictedIntent.Confidence < Agent.MlConfig.MinConfidence)
        {
            predictedIntent = await FallbackResponse(request);
            predictedIntent.Confidence = Agent.MlConfig.MinConfidence;
            predictedIntent.Label = FallbackIntentName;

            // Register a synthetic intent whose only response echoes the fallback text, so
            // AssembleResult can resolve the "fallback" label like any other intent.
            // NOTE(review): this mutates the agent returned by storage; if the storage caches
            // instances, these intents accumulate across requests — confirm against the storage impl.
            Agent.Intents.Add(new Intent
            {
                Name = predictedIntent.Label,
                Responses = new List<IntentResponse>
                {
                    new IntentResponse
                    {
                        IntentName = predictedIntent.Label,
                        Messages = new List<IntentResponseMessage>
                        {
                            new IntentResponseMessage
                            {
                                Speech = "\"" + predictedIntent.Text + "\"",
                                Type = AIResponseMessageType.Text
                            }
                        }
                    }
                }
            });
        }

        var aiResponse = new AiResponse
        {
            ResolvedQuery = request.Text,
            Score = predictedIntent.Confidence,
            Source = predictedIntent.Classifier,
            Intent = predictedIntent.Label,
            Entities = doc.Sentences[0].Entities
        };

        Console.WriteLine($"TextResponse: {aiResponse.Intent}, {request.SessionId}");

        return await AssembleResult<TResult>(request, aiResponse);
    }

    /// <summary>Root folder for an agent's project data: <c>DataPath/Projects/&lt;agentId&gt;</c>.</summary>
    private static string GetProjectPath(string agentId)
    {
        return Path.Combine(AppDomain.CurrentDomain.GetData("DataPath").ToString(), "Projects", agentId);
    }

    /// <summary>
    /// Returns the folder name (not the full path) of the newest model under <paramref name="projectPath"/>.
    /// Candidates are sub-folders whose name contains <c>model_</c>; the newest is the last one in
    /// ordinal name order, which matches the <c>model_yyyyMMdd</c> naming used by <see cref="Train"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">No model folder exists.</exception>
    private static string FindLatestModel(string projectPath)
    {
        // Directory.GetDirectories returns entries in no guaranteed order, so sort explicitly.
        // Match on the folder name only, so "model_" in projectPath itself cannot match every folder.
        var latest = Directory.GetDirectories(projectPath)
            .Select(dir => Path.GetFileName(dir))
            .Where(name => name.Contains("model_"))
            .OrderBy(name => name, StringComparer.Ordinal)
            .LastOrDefault();

        if (latest == null)
        {
            throw new InvalidOperationException($"No trained model (folder containing 'model_') found in '{projectPath}'.");
        }

        return latest;
    }

    /// <summary>
    /// Merges the session's still-alive stored contexts (<c>Lifespan &gt; 0</c>) into
    /// <c>request.Contexts</c>, sorts them, and returns the MD5 hash of their
    /// underscore-joined names. This hash selects the per-context model trained by <see cref="Train"/>.
    /// </summary>
    private async Task<string> GetContextsHash(AiRequest request)
    {
        if (request.Contexts == null)
        {
            request.Contexts = new List<string>();
        }

        var ctxStore = contextStorageFactory.Get();
        var contexts = await ctxStore.Fetch(request.SessionId);

        if (contexts != null)
        {
            foreach (var ctx in contexts)
            {
                if (ctx.Lifespan > 0 && !request.Contexts.Exists(x => x == ctx.Name))
                {
                    request.Contexts.Add(ctx.Name);
                }
            }
        }

        // Sort so the hash does not depend on the order in which contexts were added.
        request.Contexts = request.Contexts.OrderBy(x => x).ToList();

        return String.Join("_", request.Contexts).GetMd5Hash();
    }

    /// <summary>
    /// Produces the classification used when the model's confidence is too low.
    /// Derived platforms must override this; the base implementation always fails.
    /// </summary>
    /// <exception cref="NotImplementedException">Always, in the base class.</exception>
    public virtual async Task<TextClassificationResult> FallbackResponse(AiRequest request)
    {
        throw new NotImplementedException("FallbackResponse");
    }

    /// <summary>
    /// Converts the generic <see cref="AiResponse"/> into the platform-specific result type.
    /// Derived platforms must override this; the base implementation always fails.
    /// </summary>
    /// <exception cref="NotImplementedException">Always, in the base class.</exception>
    public virtual async Task<TResult> AssembleResult<TResult>(AiRequest request, AiResponse response)
    {
        throw new NotImplementedException();
    }

    /// <summary>Persists <paramref name="agent"/> through <see cref="Storage"/>.</summary>
    /// <returns>Always <c>true</c> once <c>Persist</c> completes.</returns>
    public virtual async Task<bool> SaveAgent(TAgent agent)
    {
        // default save agent in FileStorage
        await Storage.Persist(agent);
        return true;
    }

    /// <summary>Returns <see cref="Storage"/>, creating it from the factory on first use.</summary>
    protected IAgentStorage<TAgent> GetAgentStorage()
    {
        if (Storage == null)
        {
            Storage = agentStorageFactory.Get();
        }

        return Storage;
    }
}
}