feat: Add AI GUI controller and vision action refiner for enhanced automation

- Introduced `AIGUIController` class to manage AI-driven GUI automation with vision feedback, enabling natural language command execution and iterative action refinement.
- Implemented `VisionActionRefiner` class to analyze screenshots and refine actions based on visual feedback, improving action success rates.
- Added header and implementation files for both classes, along with necessary methods for screenshot analysis, action verification, and UI element location.
- Updated CMake configuration to include new source files for the AI GUI controller and vision action refiner functionalities.
This commit is contained in:
scawful
2025-10-04 23:09:59 -04:00
parent 39edadb7b6
commit ec88f087a2
5 changed files with 1034 additions and 0 deletions

View File

@@ -73,6 +73,8 @@ set(YAZE_AGENT_SOURCES
cli/service/agent/learned_knowledge_service.cc
cli/service/ai/ai_service.cc
cli/service/ai/ai_action_parser.cc
cli/service/ai/vision_action_refiner.cc
cli/service/ai/ai_gui_controller.cc
cli/service/ai/ollama_ai_service.cc
cli/service/ai/prompt_builder.cc
cli/service/ai/service_factory.cc

View File

@@ -0,0 +1,351 @@
#include "cli/service/ai/ai_gui_controller.h"
#include <chrono>
#include <thread>
#include "absl/strings/str_format.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "cli/service/ai/gemini_ai_service.h"
#ifdef YAZE_WITH_GRPC
#include "cli/service/gui/gui_automation_client.h"
#include "app/core/service/screenshot_utils.h"
#endif
namespace yaze {
namespace cli {
namespace ai {
// Stores the (non-owned) service pointers and builds the vision refiner.
// Throws std::invalid_argument when either dependency is null.
//
// NOTE(review): vision_refiner_ is constructed in the member-initializer
// list, so when gemini_service is null the VisionActionRefiner constructor
// throws first (with its own message) before the checks in this body run —
// the checks below effectively guard only gui_client_. Confirm this
// ordering is intended.
AIGUIController::AIGUIController(GeminiAIService* gemini_service,
                                 gui::GuiAutomationClient* gui_client)
    : gemini_service_(gemini_service),
      gui_client_(gui_client),
      vision_refiner_(std::make_unique<VisionActionRefiner>(gemini_service)) {
  if (!gemini_service_) {
    throw std::invalid_argument("Gemini service cannot be null");
  }
  if (!gui_client_) {
    throw std::invalid_argument("GUI client cannot be null");
  }
}
// Applies the supplied configuration and prepares the screenshot directory.
// Always succeeds; directory-creation problems are only warned about.
absl::Status AIGUIController::Initialize(const ControlLoopConfig& config) {
  config_ = config;
  screenshots_dir_ = config_.screenshots_dir;
  EnsureScreenshotsDirectory();
  return absl::OkStatus();
}
// Translates a natural-language command into AIActions and runs them
// through the vision-guided execution loop.
absl::StatusOr<ControlResult> AIGUIController::ExecuteCommand(
    const std::string& command) {
  auto parsed = AIActionParser::ParseCommand(command);
  if (!parsed.ok()) return parsed.status();
  return ExecuteActions(*parsed);
}
// Runs the vision-guided control loop over a sequence of actions.
//
// Each action is attempted up to config_.max_retries_per_action times; a
// global budget of config_.max_iterations caps total attempts across ALL
// actions. When an attempt fails and iterative refinement is enabled, the
// vision refiner may adjust the action's parameters for the next retry or
// declare that a different approach is needed (which aborts the loop).
//
// Returns a ControlResult carrying the success flag, every executed attempt
// and its vision analysis, and (on success) a description of the final GUI
// state. Failure paths return the partially-filled result rather than an
// error status so callers can inspect progress.
absl::StatusOr<ControlResult> AIGUIController::ExecuteActions(
    const std::vector<AIAction>& actions) {
  ControlResult result;
  result.success = false;
  for (const auto& action : actions) {
    int retry_count = 0;
    bool action_succeeded = false;
    // Work on a copy so refinement can adjust parameters between retries.
    AIAction current_action = action;
    while (retry_count < config_.max_retries_per_action && !action_succeeded) {
      result.iterations_performed++;
      if (result.iterations_performed > config_.max_iterations) {
        result.error_message = "Max iterations reached";
        return result;
      }
      // Execute the action with vision verification
      auto execute_result = ExecuteSingleAction(
          current_action,
          config_.enable_vision_verification
      );
      if (!execute_result.ok()) {
        // Infrastructure failure (screenshot / gRPC), not an action failure.
        result.error_message = std::string(execute_result.status().message());
        return result;
      }
      result.vision_analyses.push_back(*execute_result);
      result.actions_executed.push_back(current_action);
      if (execute_result->action_successful) {
        action_succeeded = true;
      } else if (config_.enable_iterative_refinement) {
        // Refine action and retry
        auto refinement = vision_refiner_->RefineAction(
            current_action,
            *execute_result
        );
        if (!refinement.ok()) {
          result.error_message =
              absl::StrCat("Failed to refine action: ",
                           refinement.status().message());
          return result;
        }
        if (refinement->needs_different_approach) {
          // Refiner judged retries pointless (e.g. wrong editor is open).
          result.error_message =
              absl::StrCat("Action requires different approach: ",
                           refinement->reasoning);
          return result;
        }
        if (refinement->needs_retry) {
          // Overlay corrected parameters for the next attempt.
          // Update action parameters
          for (const auto& [key, value] : refinement->adjusted_parameters) {
            current_action.parameters[key] = value;
          }
        }
        // NOTE(review): when needs_retry is false (and no different approach
        // was requested) the loop still retries with unchanged parameters —
        // confirm that is intended rather than being treated as success.
        retry_count++;
      } else {
        // No refinement, just fail
        result.error_message = execute_result->error_message;
        return result;
      }
    }
    if (!action_succeeded) {
      result.error_message =
          absl::StrFormat("Action failed after %d retries", retry_count);
      return result;
    }
  }
  result.success = true;
  // Capture final state
  auto final_screenshot = CaptureCurrentState("final_state");
  if (final_screenshot.ok()) {
    result.screenshots_taken.push_back(*final_screenshot);
    // Analyze final state
    auto final_analysis = vision_refiner_->AnalyzeScreenshot(
        *final_screenshot,
        "Verify all actions completed successfully"
    );
    if (final_analysis.ok()) {
      result.final_state_description = final_analysis->description;
    }
  }
  return result;
}
// Executes one action via gRPC, optionally sandwiched between before/after
// screenshots so the vision refiner can judge whether it worked.
absl::StatusOr<VisionAnalysisResult> AIGUIController::ExecuteSingleAction(
    const AIAction& action,
    bool verify_with_vision) {
  VisionAnalysisResult outcome;

  // Snapshot the GUI before acting so success can be judged by comparison.
  std::filesystem::path pre_action_shot;
  if (verify_with_vision) {
    auto captured = CaptureCurrentState("before_action");
    if (!captured.ok()) return captured.status();
    pre_action_shot = *captured;
  }

  // Give the UI a moment to settle before dispatching the action.
  if (config_.screenshot_delay_ms > 0) {
    std::this_thread::sleep_for(
        std::chrono::milliseconds(config_.screenshot_delay_ms));
  }

  if (auto grpc_status = ExecuteGRPCAction(action); !grpc_status.ok()) {
    outcome.action_successful = false;
    outcome.error_message = std::string(grpc_status.message());
    return outcome;
  }

  // Allow the action's effects to render before any follow-up capture.
  std::this_thread::sleep_for(
      std::chrono::milliseconds(config_.screenshot_delay_ms));

  if (!verify_with_vision) {
    outcome.action_successful = true;
    outcome.description = "Action executed (no vision verification)";
    return outcome;
  }

  auto post_capture = CaptureCurrentState("after_action");
  if (!post_capture.ok()) return post_capture.status();
  return VerifyActionSuccess(action, pre_action_shot, *post_capture);
}
// Captures the current GUI and asks the vision refiner to describe it;
// no action is executed.
absl::StatusOr<VisionAnalysisResult> AIGUIController::AnalyzeCurrentGUIState(
    const std::string& context) {
  auto shot = CaptureCurrentState("analysis");
  if (!shot.ok()) return shot.status();
  return vision_refiner_->AnalyzeScreenshot(*shot, context);
}
// Private helper methods
// Captures a screenshot of the current GUI state into screenshots_dir_.
//
// @param description Suffix embedded in the generated filename.
// @return Path to the written screenshot, or an error. When built without
//         gRPC support this always returns kUnimplemented.
absl::StatusOr<std::filesystem::path> AIGUIController::CaptureCurrentState(
    const std::string& description) {
#ifdef YAZE_WITH_GRPC
  std::filesystem::path path = GenerateScreenshotPath(description);
  // Delegates the actual capture to the shared test-harness utility.
  auto result = yaze::test::CaptureHarnessScreenshot(path.string());
  if (!result.ok()) {
    return result.status();
  }
  // The harness reports where the file actually landed; prefer that path
  // over the one we requested.
  return std::filesystem::path(result->file_path);
#else
  return absl::UnimplementedError("Screenshot capture requires gRPC support");
#endif
}
// Dispatches one AIAction to the GUI via the automation client.
//
// Fix: the original called std::stoi directly on user/AI-supplied parameter
// strings; a malformed value ("abc", "") would throw std::invalid_argument
// and crash the control loop. Parameters are now parsed defensively and a
// bad value yields an InvalidArgumentError status instead.
//
// @return OkStatus on success, or the first failing client status /
//         InvalidArgumentError for malformed numeric parameters.
absl::Status AIGUIController::ExecuteGRPCAction(const AIAction& action) {
  // Parses a base-10 integer, returning nullopt on malformed or trailing
  // input instead of throwing like std::stoi does.
  auto parse_int = [](const std::string& text) -> std::optional<int> {
    try {
      size_t consumed = 0;
      int value = std::stoi(text, &consumed);
      if (consumed != text.size()) return std::nullopt;
      return value;
    } catch (const std::exception&) {
      return std::nullopt;
    }
  };

  // Convert AI action to gRPC test commands
  auto grpc_commands = action_generator_.GenerateGRPCCommands({action});
  if (grpc_commands.empty()) {
    return absl::InternalError("No gRPC commands generated for action");
  }
  // Execute each command.
  // NOTE(review): command_json is not yet parsed — this is a placeholder;
  // dispatch is currently driven by action.type alone, so multiple generated
  // commands re-execute the same action. Confirm once JSON dispatch lands.
  for ([[maybe_unused]] const auto& command_json : grpc_commands) {
    if (action.type == AIActionType::kClickButton) {
      auto button_it = action.parameters.find("button");
      if (button_it != action.parameters.end()) {
        auto status = gui_client_->ClickButton(button_it->second);
        if (!status.ok()) {
          return status;
        }
      }
    } else if (action.type == AIActionType::kPlaceTile) {
      // Extract and validate the tile-placement parameters.
      auto x_it = action.parameters.find("x");
      auto y_it = action.parameters.find("y");
      auto tile_it = action.parameters.find("tile_id");
      if (x_it != action.parameters.end() &&
          y_it != action.parameters.end() &&
          tile_it != action.parameters.end()) {
        auto x = parse_int(x_it->second);
        auto y = parse_int(y_it->second);
        auto tile_id = parse_int(tile_it->second);
        if (!x || !y || !tile_id) {
          return absl::InvalidArgumentError(absl::StrCat(
              "PlaceTile parameters must be integers: x=", x_it->second,
              " y=", y_it->second, " tile_id=", tile_it->second));
        }
        // Use GUI client to place tile
        // (This would need actual implementation in GuiAutomationClient)
        auto status = gui_client_->ExecuteTestScript(
            absl::StrFormat("PlaceTile(%d, %d, %d)", *x, *y, *tile_id));
        if (!status.ok()) {
          return status;
        }
      }
    } else if (action.type == AIActionType::kWait) {
      // Fall back to the configured delay when duration_ms is absent or
      // malformed.
      int wait_ms = config_.screenshot_delay_ms;
      auto wait_it = action.parameters.find("duration_ms");
      if (wait_it != action.parameters.end()) {
        wait_ms = parse_int(wait_it->second).value_or(wait_ms);
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(wait_ms));
    }
  }
  return absl::OkStatus();
}
// Thin forwarding wrapper: the vision refiner owns the before/after
// comparison logic.
absl::StatusOr<VisionAnalysisResult> AIGUIController::VerifyActionSuccess(
    const AIAction& action,
    const std::filesystem::path& before_screenshot,
    const std::filesystem::path& after_screenshot) {
  return vision_refiner_->VerifyAction(action, before_screenshot,
                                       after_screenshot);
}
// Produces a copy of the action with any refiner-suggested parameter
// corrections applied. The original action is left untouched.
absl::StatusOr<AIAction> AIGUIController::RefineActionWithVision(
    const AIAction& original_action,
    const VisionAnalysisResult& analysis) {
  auto suggestion = vision_refiner_->RefineAction(original_action, analysis);
  if (!suggestion.ok()) return suggestion.status();

  // Start from the original and overlay the corrected parameters.
  AIAction updated = original_action;
  for (const auto& [name, corrected_value] : suggestion->adjusted_parameters) {
    updated.parameters[name] = corrected_value;
  }
  return updated;
}
// Best-effort creation of screenshots_dir_; a failure is only logged
// because screenshot capture degrades gracefully without it.
void AIGUIController::EnsureScreenshotsDirectory() {
  std::error_code ec;
  std::filesystem::create_directories(screenshots_dir_, ec);
  if (!ec) return;
  std::cerr << "Warning: Failed to create screenshots directory: "
            << ec.message() << std::endl;
}
// Builds a unique screenshot path of the form
// "<screenshots_dir>/ai_gui_<suffix>_<unix-millis>.png".
std::filesystem::path AIGUIController::GenerateScreenshotPath(
    const std::string& suffix) {
  const int64_t now_ms = absl::ToUnixMillis(absl::Now());
  return screenshots_dir_ /
         absl::StrFormat("ai_gui_%s_%lld.png", suffix,
                         static_cast<long long>(now_ms));
}
} // namespace ai
} // namespace cli
} // namespace yaze

View File

@@ -0,0 +1,173 @@
#ifndef YAZE_CLI_SERVICE_AI_AI_GUI_CONTROLLER_H_
#define YAZE_CLI_SERVICE_AI_AI_GUI_CONTROLLER_H_
#include <filesystem>
#include <memory>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "cli/service/ai/ai_action_parser.h"
#include "cli/service/ai/vision_action_refiner.h"
#include "cli/service/gui/gui_action_generator.h"
namespace yaze {
namespace cli {
// Forward declares
class GeminiAIService;
namespace gui {
class GuiAutomationClient;
}
namespace ai {
/**
 * @struct ControlLoopConfig
 * @brief Configuration for the AI GUI control loop
 *
 * Tunables for retry budgets, screenshot timing, and which verification
 * features are enabled.
 */
struct ControlLoopConfig {
  int max_iterations = 10;  // Global cap on attempts across ALL actions in a run
  int screenshot_delay_ms = 500;  // Settle time before/after actions and captures
  bool enable_vision_verification = true;  // Verify each action with before/after screenshots
  bool enable_iterative_refinement = true;  // Retry failed actions with refined parameters
  int max_retries_per_action = 3;  // Per-action retry cap
  std::string screenshots_dir = "/tmp/yaze/ai_gui_control";  // Where captures are written
};
/**
 * @struct ControlResult
 * @brief Result of AI-controlled GUI automation
 *
 * Returned by AIGUIController even on failure so callers can inspect the
 * partial progress (attempts made, analyses gathered) before the abort.
 */
struct ControlResult {
  bool success = false;             // True only when every action succeeded
  int iterations_performed = 0;     // Total attempts across all actions
  std::vector<ai::AIAction> actions_executed;           // One entry per attempt
  std::vector<VisionAnalysisResult> vision_analyses;    // Parallel to actions_executed
  std::vector<std::filesystem::path> screenshots_taken; // Captures recorded during the run
  std::string error_message;            // Populated on failure paths
  std::string final_state_description;  // Vision summary of the final GUI (success only)
};
/**
* @class AIGUIController
* @brief High-level controller for AI-driven GUI automation with vision feedback
*
* This class implements the complete vision-guided control loop:
*
* 1. **Parse Command** → Natural language → AIActions
* 2. **Take Screenshot** → Capture current GUI state
* 3. **Analyze Vision** → Gemini analyzes screenshot
* 4. **Execute Action** → Send gRPC command to GUI
* 5. **Verify Success** → Compare before/after screenshots
* 6. **Refine & Retry** → Adjust parameters if action failed
* 7. **Repeat** → Until goal achieved or max iterations reached
*
* Example usage:
* ```cpp
* AIGUIController controller(gemini_service, gui_client);
* controller.Initialize(config);
*
* auto result = controller.ExecuteCommand(
* "Place tile 0x42 at overworld position (5, 7)"
* );
*
* if (result->success) {
* std::cout << "Success! Took " << result->iterations_performed
* << " iterations\n";
* }
* ```
*/
class AIGUIController {
 public:
  /**
   * @brief Construct controller with required services
   * @param gemini_service Gemini AI service for vision analysis
   * @param gui_client gRPC client for GUI automation
   *
   * Both pointers are borrowed (not owned) and must outlive this
   * controller. Throws std::invalid_argument when either is null.
   */
  AIGUIController(GeminiAIService* gemini_service,
                  gui::GuiAutomationClient* gui_client);
  ~AIGUIController() = default;
  /**
   * @brief Initialize the controller with configuration
   *
   * Also creates the screenshots directory (best-effort). Always returns OK.
   */
  absl::Status Initialize(const ControlLoopConfig& config);
  /**
   * @brief Execute a natural language command with AI vision guidance
   * @param command Natural language command (e.g., "Place tile 0x42 at (5, 7)")
   * @return Result including success status and execution details
   */
  absl::StatusOr<ControlResult> ExecuteCommand(const std::string& command);
  /**
   * @brief Execute a sequence of pre-parsed actions
   * @param actions Vector of AI actions to execute
   * @return Result including success status
   */
  absl::StatusOr<ControlResult> ExecuteActions(
      const std::vector<ai::AIAction>& actions);
  /**
   * @brief Execute a single action with optional vision verification
   * @param action The action to execute
   * @param verify_with_vision Whether to use vision to verify success
   * @return Success status and vision analysis
   */
  absl::StatusOr<VisionAnalysisResult> ExecuteSingleAction(
      const AIAction& action,
      bool verify_with_vision = true);
  /**
   * @brief Analyze the current GUI state without executing actions
   * @param context What to look for in the GUI
   * @return Vision analysis of current state
   */
  absl::StatusOr<VisionAnalysisResult> AnalyzeCurrentGUIState(
      const std::string& context = "");
  /**
   * @brief Get the current configuration
   */
  const ControlLoopConfig& config() const { return config_; }
  /**
   * @brief Update configuration
   *
   * Note: unlike Initialize(), this does not refresh screenshots_dir_.
   */
  void SetConfig(const ControlLoopConfig& config) { config_ = config; }

 private:
  GeminiAIService* gemini_service_;  // Not owned
  gui::GuiAutomationClient* gui_client_;  // Not owned
  std::unique_ptr<VisionActionRefiner> vision_refiner_;  // Owned vision helper
  gui::GuiActionGenerator action_generator_;  // Converts AIActions to gRPC commands
  ControlLoopConfig config_;
  std::filesystem::path screenshots_dir_;  // Set by Initialize()

  // Helper methods
  // Captures a screenshot of the current GUI (kUnimplemented without gRPC).
  absl::StatusOr<std::filesystem::path> CaptureCurrentState(
      const std::string& description);
  // Dispatches one action to the GUI automation client.
  absl::Status ExecuteGRPCAction(const AIAction& action);
  // Asks the vision refiner whether the action visibly succeeded.
  absl::StatusOr<VisionAnalysisResult> VerifyActionSuccess(
      const AIAction& action,
      const std::filesystem::path& before_screenshot,
      const std::filesystem::path& after_screenshot);
  // Returns a copy of the action with refiner-suggested parameter fixes.
  absl::StatusOr<AIAction> RefineActionWithVision(
      const AIAction& original_action,
      const VisionAnalysisResult& analysis);
  // Best-effort creation of screenshots_dir_ (logs on failure).
  void EnsureScreenshotsDirectory();
  // Builds a timestamped screenshot filename under screenshots_dir_.
  std::filesystem::path GenerateScreenshotPath(const std::string& suffix);
};
} // namespace ai
} // namespace cli
} // namespace yaze
#endif // YAZE_CLI_SERVICE_AI_AI_GUI_CONTROLLER_H_

View File

@@ -0,0 +1,353 @@
#include "cli/service/ai/vision_action_refiner.h"
#include <algorithm>
#include <sstream>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "cli/service/ai/gemini_ai_service.h"
namespace yaze {
namespace cli {
namespace ai {
// Borrows the Gemini service (not owned); rejects a null pointer up front
// so later calls never have to check.
VisionActionRefiner::VisionActionRefiner(GeminiAIService* gemini_service)
    : gemini_service_(gemini_service) {
  if (gemini_service_ == nullptr) {
    throw std::invalid_argument("Gemini service cannot be null");
  }
}
// Sends the screenshot plus an analysis prompt to Gemini Vision and parses
// the textual reply into a VisionAnalysisResult.
absl::StatusOr<VisionAnalysisResult> VisionActionRefiner::AnalyzeScreenshot(
    const std::filesystem::path& screenshot_path,
    const std::string& context) {
  if (!std::filesystem::exists(screenshot_path)) {
    return absl::NotFoundError(
        absl::StrCat("Screenshot not found: ", screenshot_path.string()));
  }
  auto reply = gemini_service_->GenerateMultimodalResponse(
      screenshot_path.string(), BuildAnalysisPrompt(context));
  if (!reply.ok()) return reply.status();
  return ParseAnalysisResponse(reply->text_response);
}
// Judges whether an action succeeded by asking the model about the
// post-action screenshot.
absl::StatusOr<VisionAnalysisResult> VisionActionRefiner::VerifyAction(
    const AIAction& action,
    const std::filesystem::path& before_screenshot,
    const std::filesystem::path& after_screenshot) {
  if (!std::filesystem::exists(before_screenshot)) {
    return absl::NotFoundError("Before screenshot not found");
  }
  if (!std::filesystem::exists(after_screenshot)) {
    return absl::NotFoundError("After screenshot not found");
  }
  // Only the post-action screenshot is sent to the model; the verification
  // prompt describes the attempted action.
  auto reply = gemini_service_->GenerateMultimodalResponse(
      after_screenshot.string(), BuildVerificationPrompt(action));
  if (!reply.ok()) return reply.status();
  return ParseVerificationResponse(reply->text_response, action);
}
// Decides how a failed action should be adjusted based on vision feedback.
//
// Strategy:
//  - "not found"/"missing" in the error -> a different approach is required
//    (e.g. a different editor must be opened first).
//  - "wrong"/"incorrect" -> retry, harvesting corrected coordinates from any
//    suggestion containing a "position (x, y)" pattern.
//  - anything else -> generic retry with unchanged parameters.
//
// Fix: the original lowercased via std::transform + ::tolower, which is UB
// for char values outside unsigned-char range (e.g. UTF-8 bytes in a model
// response); absl::AsciiStrToLower is well-defined for all bytes.
absl::StatusOr<ActionRefinement> VisionActionRefiner::RefineAction(
    const AIAction& original_action,
    const VisionAnalysisResult& analysis) {
  ActionRefinement refinement;
  // A successful action needs no refinement; return the default (no-op).
  if (analysis.action_successful) {
    return refinement;
  }
  const std::string error_lower = absl::AsciiStrToLower(analysis.error_message);
  if (error_lower.find("not found") != std::string::npos ||
      error_lower.find("missing") != std::string::npos) {
    refinement.needs_different_approach = true;
    refinement.reasoning =
        "UI element not found, may need to open different editor";
  } else if (error_lower.find("wrong") != std::string::npos ||
             error_lower.find("incorrect") != std::string::npos) {
    refinement.needs_retry = true;
    refinement.reasoning =
        "Action executed on wrong element, adjusting parameters";
    // Scan suggestions for a corrected position, e.g.
    // "Try position (6, 8) instead".
    for (const auto& suggestion : analysis.suggestions) {
      if (suggestion.find("position") == std::string::npos) continue;
      const size_t open = suggestion.find('(');
      if (open == std::string::npos) continue;
      const size_t close = suggestion.find(')', open);
      if (close == std::string::npos) continue;
      const std::string coords = suggestion.substr(open + 1, close - open - 1);
      std::vector<std::string> parts = absl::StrSplit(coords, ',');
      if (parts.size() == 2) {
        refinement.adjusted_parameters["x"] =
            std::string(absl::StripAsciiWhitespace(parts[0]));
        refinement.adjusted_parameters["y"] =
            std::string(absl::StripAsciiWhitespace(parts[1]));
      }
    }
  } else {
    refinement.needs_retry = true;
    refinement.reasoning = "Generic failure, will retry with same parameters";
  }
  return refinement;
}
// Asks the model where a named UI element is in the screenshot.
//
// Returned map keys: "found" ("true"/"false"), "description" (raw model
// text), and — when a "(x, y)" pattern is present — "x" and "y".
//
// Fix: the original lowercased via std::transform + ::tolower, which is UB
// for char values outside unsigned-char range; absl::AsciiStrToLower is
// well-defined for all bytes.
absl::StatusOr<std::map<std::string, std::string>>
VisionActionRefiner::LocateUIElement(
    const std::filesystem::path& screenshot_path,
    const std::string& element_name) {
  auto response = gemini_service_->GenerateMultimodalResponse(
      screenshot_path.string(), BuildElementLocationPrompt(element_name));
  if (!response.ok()) {
    return response.status();
  }
  std::map<std::string, std::string> location;
  // Lowercase copy for keyword matching; digits are unaffected, so the
  // coordinate extraction below can work on this copy too.
  const std::string text = absl::AsciiStrToLower(response->text_response);
  if (text.find("not found") != std::string::npos ||
      text.find("not visible") != std::string::npos) {
    location["found"] = "false";
    location["description"] = response->text_response;
    return location;
  }
  location["found"] = "true";
  location["description"] = response->text_response;
  // Best-effort coordinate extraction from the first "(x, y)" pattern.
  const size_t open = text.find('(');
  if (open != std::string::npos) {
    const size_t close = text.find(')', open);
    if (close != std::string::npos) {
      const std::string coords = text.substr(open + 1, close - open - 1);
      std::vector<std::string> parts = absl::StrSplit(coords, ',');
      if (parts.size() == 2) {
        location["x"] = std::string(absl::StripAsciiWhitespace(parts[0]));
        location["y"] = std::string(absl::StripAsciiWhitespace(parts[1]));
      }
    }
  }
  return location;
}
// Asks the model to list visible widgets and parses the bulleted/numbered
// reply into one string per widget.
//
// Fix: std::isdigit(line[0]) is UB when char is signed and the byte is
// negative (possible with UTF-8 in a model response); cast through
// unsigned char as the C library requires.
absl::StatusOr<std::vector<std::string>>
VisionActionRefiner::ExtractVisibleWidgets(
    const std::filesystem::path& screenshot_path) {
  auto response = gemini_service_->GenerateMultimodalResponse(
      screenshot_path.string(), BuildWidgetExtractionPrompt());
  if (!response.ok()) {
    return response.status();
  }
  std::vector<std::string> widgets;
  std::stringstream ss(response->text_response);
  std::string line;
  while (std::getline(ss, line)) {
    // Skip blank / whitespace-only lines.
    if (line.empty() ||
        line.find_first_not_of(" \t\n\r") == std::string::npos) {
      continue;
    }
    // Strip a leading list marker: "-", "*", or "1."-style numbering.
    size_t start = 0;
    if (line[0] == '-' || line[0] == '*') {
      start = 1;
    } else if (std::isdigit(static_cast<unsigned char>(line[0]))) {
      start = line.find('.');
      start = (start == std::string::npos) ? 0 : start + 1;
    }
    absl::string_view widget_view =
        absl::StripAsciiWhitespace(absl::string_view(line).substr(start));
    if (!widget_view.empty()) {
      widgets.push_back(std::string(widget_view));
    }
  }
  return widgets;
}
// Private helper methods
// Builds the general screenshot-analysis prompt, optionally appending
// caller-supplied context.
std::string VisionActionRefiner::BuildAnalysisPrompt(
    const std::string& context) {
  const std::string base_prompt =
      "Analyze this screenshot of the YAZE ROM editor GUI. "
      "Identify all visible UI elements, windows, and widgets. "
      "List them in order of importance.";
  return context.empty()
             ? base_prompt
             : absl::StrCat(base_prompt, "\n\nContext: ", context);
}
// Builds a prompt describing the attempted action and asking for an
// explicit "SUCCESS:"/"FAILURE:" verdict ParseVerificationResponse can parse.
std::string VisionActionRefiner::BuildVerificationPrompt(
    const AIAction& action) {
  return absl::StrCat(
      "This screenshot was taken after attempting to perform the following action: ",
      AIActionParser::ActionToString(action),
      "\n\nDid the action succeed? Look for visual evidence that the action completed. "
      "Respond with:\n"
      "SUCCESS: <description of what changed>\n"
      "or\n"
      "FAILURE: <description of what went wrong>");
}
// Builds a prompt asking the model to locate one named element; the
// "NOT FOUND" phrasing is what LocateUIElement keys off.
std::string VisionActionRefiner::BuildElementLocationPrompt(
    const std::string& element_name) {
  return absl::StrCat(
      "Locate the '", element_name,
      "' UI element in this screenshot. "
      "If found, describe its position (coordinates if possible, or relative position). "
      "If not found, state 'NOT FOUND'.");
}
// Builds the fixed widget-enumeration prompt; the one-per-line format is
// what ExtractVisibleWidgets expects to parse.
std::string VisionActionRefiner::BuildWidgetExtractionPrompt() {
  static constexpr char kPrompt[] =
      "List all visible UI widgets, buttons, windows, and interactive elements "
      "in this screenshot. Format as a bulleted list, one element per line.";
  return kPrompt;
}
// Heuristically extracts widget mentions and suggestions from a free-form
// model reply; the full reply is kept as the description.
//
// Fix: the original lowercased via std::transform + ::tolower, which is UB
// for char values outside unsigned-char range (possible with UTF-8 bytes in
// a model response); absl::AsciiStrToLower is well-defined for all bytes.
VisionAnalysisResult VisionActionRefiner::ParseAnalysisResponse(
    const std::string& response) {
  VisionAnalysisResult result;
  result.description = response;
  std::stringstream ss(response);
  std::string line;
  while (std::getline(ss, line)) {
    const std::string lower = absl::AsciiStrToLower(line);
    // Widget mention: line names a common UI element type.
    if (lower.find("button") != std::string::npos ||
        lower.find("window") != std::string::npos ||
        lower.find("panel") != std::string::npos ||
        lower.find("selector") != std::string::npos ||
        lower.find("editor") != std::string::npos) {
      result.widgets.push_back(std::string(absl::StripAsciiWhitespace(line)));
    }
    // Suggestion: line reads like advice ("suggest", "try", "could").
    if (lower.find("suggest") != std::string::npos ||
        lower.find("try") != std::string::npos ||
        lower.find("could") != std::string::npos) {
      result.suggestions.push_back(
          std::string(absl::StripAsciiWhitespace(line)));
    }
  }
  return result;
}
// Interprets the model's "SUCCESS:"/"FAILURE:" verdict. An ambiguous reply
// (neither keyword present) is conservatively treated as failure.
//
// Fix: the original uppercased via std::transform + ::toupper, which is UB
// for char values outside unsigned-char range; absl::AsciiStrToUpper is
// well-defined for all bytes. The magic offset 8 is now a named constant.
VisionAnalysisResult VisionActionRefiner::ParseVerificationResponse(
    const std::string& response,
    const AIAction& action) {
  VisionAnalysisResult result;
  result.description = response;
  const std::string response_upper = absl::AsciiStrToUpper(response);
  // Both "SUCCESS:" and "FAILURE:" are 8 characters long.
  constexpr size_t kVerdictTagLen = 8;
  if (response_upper.find("SUCCESS") != std::string::npos) {
    result.action_successful = true;
    // Keep only the text after the tag as the description, when present.
    const size_t pos = response_upper.find("SUCCESS:");
    if (pos != std::string::npos) {
      result.description = std::string(
          absl::StripAsciiWhitespace(response.substr(pos + kVerdictTagLen)));
    }
  } else if (response_upper.find("FAILURE") != std::string::npos) {
    result.action_successful = false;
    const size_t pos = response_upper.find("FAILURE:");
    if (pos != std::string::npos) {
      result.error_message = std::string(
          absl::StripAsciiWhitespace(response.substr(pos + kVerdictTagLen)));
    } else {
      result.error_message = "Action failed (details in description)";
    }
  } else {
    // Ambiguous response, assume failure
    result.action_successful = false;
    result.error_message =
        "Could not determine action success from vision analysis";
  }
  return result;
}
} // namespace ai
} // namespace cli
} // namespace yaze

View File

@@ -0,0 +1,155 @@
#ifndef YAZE_CLI_SERVICE_AI_VISION_ACTION_REFINER_H_
#define YAZE_CLI_SERVICE_AI_VISION_ACTION_REFINER_H_
#include <filesystem>
#include <map>
#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "cli/service/ai/ai_action_parser.h"
namespace yaze {
namespace cli {
// Forward declare
class GeminiAIService;
namespace ai {
/**
 * @struct VisionAnalysisResult
 * @brief Result of analyzing a screenshot with Gemini Vision
 */
struct VisionAnalysisResult {
  std::string description;  // What Gemini sees in the image (raw or trimmed reply)
  std::vector<std::string> widgets;  // Detected UI widgets (heuristic keyword matches)
  std::vector<std::string> suggestions;  // Action suggestions extracted from the reply
  bool action_successful = false;  // Whether the last action succeeded
  std::string error_message;  // Error description if action failed
};
/**
 * @struct ActionRefinement
 * @brief Refined action parameters based on vision analysis
 *
 * The default-constructed state (both flags false) means "no refinement
 * needed" — this is what RefineAction returns for a successful action.
 */
struct ActionRefinement {
  bool needs_retry = false;  // Retry the action (possibly with adjusted_parameters)
  bool needs_different_approach = false;  // Retrying is pointless; caller should abort/re-plan
  std::map<std::string, std::string> adjusted_parameters;  // Parameter overrides for the retry
  std::string reasoning;  // Human-readable explanation of the decision
};
/**
* @class VisionActionRefiner
* @brief Uses Gemini Vision to analyze GUI screenshots and refine AI actions
*
* This class implements the vision-guided action loop:
* 1. Take screenshot of current GUI state
* 2. Send to Gemini Vision with contextual prompt
* 3. Analyze response to determine next action
* 4. Verify action success by comparing screenshots
*
* Example usage:
* ```cpp
* VisionActionRefiner refiner(gemini_service);
*
* // Analyze current state
* auto analysis = refiner.AnalyzeCurrentState(
* "overworld_editor",
* "Looking for tile selector"
* );
*
* // Verify action was successful
* auto verification = refiner.VerifyAction(
* AIAction(kPlaceTile, {{"x", "5"}, {"y", "7"}}),
* before_screenshot,
* after_screenshot
* );
*
* // Refine failed action
* if (!verification->action_successful) {
* auto refinement = refiner.RefineAction(
* original_action,
* *verification
* );
* }
* ```
*/
class VisionActionRefiner {
 public:
  /**
   * @brief Construct refiner with Gemini service
   * @param gemini_service Pointer to Gemini AI service (not owned)
   * @throws std::invalid_argument if gemini_service is null
   */
  explicit VisionActionRefiner(GeminiAIService* gemini_service);
  /**
   * @brief Analyze the current GUI state from a screenshot
   * @param screenshot_path Path to screenshot file
   * @param context Additional context about what we're looking for
   * @return Vision analysis result, or kNotFound if the file is missing
   */
  absl::StatusOr<VisionAnalysisResult> AnalyzeScreenshot(
      const std::filesystem::path& screenshot_path,
      const std::string& context = "");
  /**
   * @brief Verify an action was successful by comparing before/after screenshots
   * @param action The action that was performed
   * @param before_screenshot Screenshot before action
   * @param after_screenshot Screenshot after action
   * @return Analysis indicating whether action succeeded
   */
  absl::StatusOr<VisionAnalysisResult> VerifyAction(
      const AIAction& action,
      const std::filesystem::path& before_screenshot,
      const std::filesystem::path& after_screenshot);
  /**
   * @brief Refine an action based on vision analysis feedback
   * @param original_action The action that failed or needs adjustment
   * @param analysis Vision analysis showing why action failed
   * @return Refined action with adjusted parameters (no-op if the action
   *         succeeded)
   */
  absl::StatusOr<ActionRefinement> RefineAction(
      const AIAction& original_action,
      const VisionAnalysisResult& analysis);
  /**
   * @brief Find a specific UI element in a screenshot
   * @param screenshot_path Path to screenshot
   * @param element_name Name/description of element to find
   * @return Map with keys "found", "description", and optionally "x"/"y"
   */
  absl::StatusOr<std::map<std::string, std::string>> LocateUIElement(
      const std::filesystem::path& screenshot_path,
      const std::string& element_name);
  /**
   * @brief Extract all visible widgets from a screenshot
   * @param screenshot_path Path to screenshot
   * @return List of detected widgets, one entry per reply line
   */
  absl::StatusOr<std::vector<std::string>> ExtractVisibleWidgets(
      const std::filesystem::path& screenshot_path);

 private:
  GeminiAIService* gemini_service_;  // Not owned

  // Build prompts for different vision analysis tasks
  std::string BuildAnalysisPrompt(const std::string& context);
  std::string BuildVerificationPrompt(const AIAction& action);
  std::string BuildElementLocationPrompt(const std::string& element_name);
  std::string BuildWidgetExtractionPrompt();

  // Parse Gemini vision responses
  VisionAnalysisResult ParseAnalysisResponse(const std::string& response);
  VisionAnalysisResult ParseVerificationResponse(
      const std::string& response, const AIAction& action);
};
} // namespace ai
} // namespace cli
} // namespace yaze
#endif // YAZE_CLI_SERVICE_AI_VISION_ACTION_REFINER_H_