// File: yaze/src/cli/service/ollama_ai_service.cc
// Listing metadata: 264 lines, 8.1 KiB, C++

#include "cli/service/ollama_ai_service.h"
#include <cstdlib>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
// Check if we have httplib available (from vcpkg or bundled)
#if __has_include("httplib.h")
#define YAZE_HAS_HTTPLIB 1
#include "httplib.h"
#elif __has_include("incl/httplib.h")
#define YAZE_HAS_HTTPLIB 1
#include "incl/httplib.h"
#else
#define YAZE_HAS_HTTPLIB 0
#endif
// Check if we have JSON library available
#if __has_include("third_party/json/src/json.hpp")
#define YAZE_HAS_JSON 1
#include "third_party/json/src/json.hpp"
#elif __has_include("json.hpp")
#define YAZE_HAS_JSON 1
#include "json.hpp"
#else
#define YAZE_HAS_JSON 0
#endif
namespace yaze {
namespace cli {
// Stores a copy of `config` and, when the caller supplied no system prompt,
// derives one from the prompt builder.
OllamaAIService::OllamaAIService(const OllamaConfig& config) : config_(config) {
  // Load command documentation into prompt builder
  prompt_builder_.LoadResourceCatalogue("");  // TODO: Pass actual yaml path when available

  // A caller-provided system prompt is kept untouched.
  if (!config_.system_prompt.empty()) {
    return;
  }

  // Enhanced prompting (instruction plus examples) is the preferred source.
  if (config_.use_enhanced_prompting) {
    config_.system_prompt = prompt_builder_.BuildSystemInstructionWithExamples();
    return;
  }

  // Otherwise fall back to the basic instruction.
  config_.system_prompt = BuildSystemPrompt();
}
// Produces the fallback system prompt: the prompt builder's basic system
// instruction. The constructor uses this when enhanced prompting is disabled.
std::string OllamaAIService::BuildSystemPrompt() {
  std::string basic_instruction = prompt_builder_.BuildSystemInstruction();
  return basic_instruction;
}
// Verifies that the Ollama server at config_.base_url answers GET /api/tags
// and that config_.model is among the models it lists.
//
// Returns:
//   OkStatus           - the server responded and lists the requested model.
//   UnimplementedError - this build lacks httplib or JSON support.
//   UnavailableError   - no HTTP response was received from the server.
//   InternalError      - the server replied with a non-200 status, or an
//                        exception was thrown (e.g. the body is not JSON).
//   NotFoundError      - the server is up but does not list the model.
absl::Status OllamaAIService::CheckAvailability() {
#if !YAZE_HAS_HTTPLIB || !YAZE_HAS_JSON
  return absl::UnimplementedError(
      "Ollama service requires httplib and JSON support. "
      "Install vcpkg dependencies or use bundled libraries.");
#else
  try {
    httplib::Client cli(config_.base_url);
    cli.set_connection_timeout(5);  // 5 second timeout
    auto res = cli.Get("/api/tags");
    if (!res) {
      // Point the "Verify" hint at the URL that actually failed rather than a
      // hard-coded localhost address.
      return absl::UnavailableError(absl::StrFormat(
          "Cannot connect to Ollama server at %s.\n"
          "Make sure Ollama is installed and running:\n"
          " 1. Install: brew install ollama (macOS) or https://ollama.com/download\n"
          " 2. Start: ollama serve\n"
          " 3. Verify: curl %s/api/tags",
          config_.base_url, config_.base_url));
    }
    if (res->status != 200) {
      return absl::InternalError(absl::StrFormat(
          "Ollama server error: HTTP %d\nResponse: %s",
          res->status, res->body));
    }

    // Ollama treats a model name without a tag as "<name>:latest". Build that
    // fully-tagged form so the check accepts exactly what a later generate
    // request would resolve to. A substring match would wrongly accept
    // "codellama:7b" for "llama", or any model for an empty name.
    // The tag separator is only looked for after the last '/', so a registry
    // host with a port ("host:5000/model") is not mistaken for a tag.
    std::string tagged_model = config_.model;
    const std::size_t last_slash = tagged_model.rfind('/');
    const std::size_t tag_search_from =
        last_slash == std::string::npos ? 0 : last_slash + 1;
    if (tagged_model.find(':', tag_search_from) == std::string::npos) {
      tagged_model += ":latest";
    }

    // Check if requested model is available
    const nlohmann::json models_json = nlohmann::json::parse(res->body);
    bool model_found = false;
    const auto models_it = models_json.find("models");
    if (models_it != models_json.end() && models_it->is_array()) {
      for (const auto& model : *models_it) {
        const auto name_it = model.find("name");
        // Skip entries without a usable string name instead of aborting the
        // whole health check on one malformed element.
        if (name_it == model.end() || !name_it->is_string()) {
          continue;
        }
        const std::string model_name = name_it->get<std::string>();
        if (model_name == config_.model || model_name == tagged_model) {
          model_found = true;
          break;
        }
      }
    }
    if (!model_found) {
      return absl::NotFoundError(absl::StrFormat(
          "Model '%s' not found on Ollama server.\n"
          "Pull it with: ollama pull %s\n"
          "Available models: ollama list",
          config_.model, config_.model));
    }
    return absl::OkStatus();
  } catch (const std::exception& e) {
    return absl::InternalError(absl::StrCat(
        "Ollama health check failed: ", e.what()));
  }
#endif
}
// Queries GET /api/tags on the configured server and collects the "name" of
// every entry in the reply's "models" array.
//
// Returns the (possibly empty) list of names, or:
//   UnimplementedError - this build lacks httplib or JSON support.
//   UnavailableError   - no response, or a non-200 status.
//   InternalError      - an exception was thrown while handling the reply.
absl::StatusOr<std::vector<std::string>> OllamaAIService::ListAvailableModels() {
#if !YAZE_HAS_HTTPLIB || !YAZE_HAS_JSON
  return absl::UnimplementedError("Requires httplib and JSON support");
#else
  try {
    httplib::Client http(config_.base_url);
    http.set_connection_timeout(5);
    auto reply = http.Get("/api/tags");

    const bool got_ok_reply = reply && reply->status == 200;
    if (!got_ok_reply) {
      return absl::UnavailableError(
          "Cannot list Ollama models. Is the server running?");
    }

    const nlohmann::json payload = nlohmann::json::parse(reply->body);
    std::vector<std::string> models;

    // A missing or non-array "models" field simply yields an empty list.
    const auto list_it = payload.find("models");
    if (list_it != payload.end() && list_it->is_array()) {
      for (const auto& entry : *list_it) {
        const auto name_it = entry.find("name");
        if (name_it == entry.end()) {
          continue;  // entry carries no name; nothing to report
        }
        models.push_back(name_it->get<std::string>());
      }
    }
    return models;
  } catch (const std::exception& e) {
    return absl::InternalError(absl::StrCat(
        "Failed to list models: ", e.what()));
  }
#endif
}
// Extracts the generated text from an Ollama reply body.
//
// `json_response` is parsed as JSON and the string stored under the
// "response" key is returned.
//
// Returns:
//   UnimplementedError   - this build lacks JSON support.
//   InvalidArgumentError - the parsed document has no "response" field.
//   InternalError        - the JSON library threw (invalid JSON, or the field
//                          is not convertible to a string).
absl::StatusOr<std::string> OllamaAIService::ParseOllamaResponse(
    const std::string& json_response) {
#if !YAZE_HAS_JSON
  return absl::UnimplementedError("Requires JSON support");
#else
  try {
    const nlohmann::json parsed = nlohmann::json::parse(json_response);
    const auto text_it = parsed.find("response");
    if (text_it == parsed.end()) {
      return absl::InvalidArgumentError(
          "Ollama response missing 'response' field");
    }
    return text_it->get<std::string>();
  } catch (const nlohmann::json::exception& e) {
    return absl::InternalError(absl::StrCat(
        "Failed to parse Ollama response: ", e.what()));
  }
#endif
}
// Sends `prompt` (prefixed with the configured system prompt) to the Ollama
// /api/generate endpoint and interprets the generated text as a JSON array of
// command strings.
//
// Non-string elements of the returned array are ignored.
//
// Returns the list of commands, or:
//   UnimplementedError   - this build lacks httplib or JSON support.
//   UnavailableError     - no HTTP response was received.
//   InternalError        - non-200 status, or an exception was thrown while
//                          performing the request.
//   InvalidArgumentError - the generated text is not (and does not contain) a
//                          JSON array, or the array holds no strings.
//   Any error reported by ParseOllamaResponse() for the reply body.
absl::StatusOr<std::vector<std::string>> OllamaAIService::GetCommands(
    const std::string& prompt) {
#if !YAZE_HAS_HTTPLIB || !YAZE_HAS_JSON
  return absl::UnimplementedError(
      "Ollama service requires httplib and JSON support. "
      "Install vcpkg dependencies or use bundled libraries.");
#else
  // Build request payload
  nlohmann::json request_body = {
      {"model", config_.model},
      {"prompt", config_.system_prompt + "\n\nUSER REQUEST: " + prompt},
      {"stream", false},
      {"options", {
          {"temperature", config_.temperature},
          {"num_predict", config_.max_tokens}
      }},
      {"format", "json"}  // Force JSON output
  };
  try {
    httplib::Client cli(config_.base_url);
    // Use the same 5 second connect limit as the other requests in this file
    // so an unreachable server is reported promptly.
    cli.set_connection_timeout(5);
    cli.set_read_timeout(60);  // Longer timeout for inference
    auto res = cli.Post("/api/generate", request_body.dump(), "application/json");
    if (!res) {
      return absl::UnavailableError(
          "Failed to connect to Ollama. Is 'ollama serve' running?\n"
          "Start with: ollama serve");
    }
    if (res->status != 200) {
      return absl::InternalError(absl::StrFormat(
          "Ollama API error: HTTP %d\nResponse: %s",
          res->status, res->body));
    }

    // Parse response to extract generated text
    auto generated_text_or = ParseOllamaResponse(res->body);
    if (!generated_text_or.ok()) {
      return generated_text_or.status();
    }
    // Bind by reference: the StatusOr outlives every use below.
    const std::string& generated_text = *generated_text_or;

    // Parse the command array from generated text
    nlohmann::json commands_json;
    try {
      commands_json = nlohmann::json::parse(generated_text);
    } catch (const nlohmann::json::exception&) {
      // Sometimes the LLM includes extra text - try to extract JSON array
      const size_t start = generated_text.find('[');
      const size_t end = generated_text.rfind(']');
      if (start != std::string::npos && end != std::string::npos && end > start) {
        const std::string json_only =
            generated_text.substr(start, end - start + 1);
        try {
          commands_json = nlohmann::json::parse(json_only);
        } catch (const nlohmann::json::exception&) {
          return absl::InvalidArgumentError(
              "LLM did not return valid JSON. Response:\n" + generated_text);
        }
      } else {
        return absl::InvalidArgumentError(
            "LLM did not return a JSON array. Response:\n" + generated_text);
      }
    }
    if (!commands_json.is_array()) {
      return absl::InvalidArgumentError(
          "LLM did not return a JSON array. Response:\n" + generated_text);
    }

    std::vector<std::string> commands;
    commands.reserve(commands_json.size());
    for (const auto& cmd : commands_json) {
      if (cmd.is_string()) {
        commands.push_back(cmd.get<std::string>());
      }
    }
    if (commands.empty()) {
      return absl::InvalidArgumentError(
          "LLM returned empty command list. Prompt may be unclear.\n"
          "Try rephrasing your request to be more specific.");
    }
    return commands;
  } catch (const std::exception& e) {
    return absl::InternalError(absl::StrCat(
        "Ollama request failed: ", e.what()));
  }
#endif
}
} // namespace cli
} // namespace yaze