[RAG] – Bổ sung dữ liệu cho LLM

Tình huống

Tôi đã triển khai LLAMA3.2 (Instructor + Vision) trên máy chủ của tôi, chạy dưới dạng locally nhằm tiết kiệm chi phí so với việc sử dụng OpenAI API hay Gemini API. Tôi muốn bổ sung thêm thông tin (dữ liệu mới) cho LLM để thuận tiện truy vấn trong chatbot (TELEGRAM), do đó tôi sẽ sử dụng RAG để cho chatbot có thể lấy thêm thông tin, dữ liệu từ bên ngoài.

Cách làm

Bước 1. Triển khai ollama

Bước 2. Triển khai LLAMA3.2

Bước 3: Triển khai QWEEN2.5

Bổ sung dữ liệu chứng khoán cho Chatbot

Phần này tôi chỉ tóm tắt nhanh, mã nguồn chính được lưu trữ ở đây:
https://github.com/taipm/crewai_stock_market.git

Các bước tiến hành

Bước 1. Định nghĩa các tools

# Danh sách các công cụ với phiên bản cache
tools = [
    Tool(
        name="Calculator",
        func=calculator,
        description="Thực hiện các phép tính toán học cơ bản."
    ),
    Tool(
        name="GetWeather",
        func=cached_get_weather,
        description="Cung cấp thông tin thời tiết cho một địa điểm cụ thể."
    ),
    Tool(
        name="GetStockPrice",
        func=cached_get_stock_price,
        description="Lấy giá hiện tại của một mã cổ phiếu cụ thể sử dụng yfinance."
    ),
    Tool(
        name="ReadNewsFromURL",
        func=cached_read_news_from_url,
        description="Đọc và trích xuất nội dung tin tức từ một URL cụ thể."
    ),
    Tool(
        name="PlotStockChart",
        func=plot_stock_chart,
        description="Vẽ đồ thị giá cổ phiếu hàng ngày của một mã cổ phiếu cụ thể trong một khoảng thời gian."
    ),
]

Giả sử, ta cần bổ sung một hàm tóm tắt thông tin thị trường: MarketSummary:

Tool(
        name="GetMarketSummary",
        func=cached_get_market_summary,
        description="Tóm tắt thị trường chứng khoán trong ngày."
    ),

Tiếp đó bổ sung:
@tool_cache()
def cached_get_market_summary(*args, **kwargs):
    return GetMarketSummary(*args, **kwargs)

from app.tools.vn_stock import (
    get_stock_price,
    GetMarketSummary,
    plot_stock_chart
)
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from functools import lru_cache
import json
import time

from langchain_community.llms import Ollama
from langchain.agents import AgentExecutor, Tool, initialize_agent, AgentType
from langchain.memory import ConversationBufferMemory
from langchain.cache import InMemoryCache, SQLiteCache
from langchain.globals import set_llm_cache
from app.config import OLLAMA_BASE_URL, LLAMA_MODEL

# Import các công cụ và cấu hình
from app.tools.common import (
    calculator, 
    get_weather, 
    read_news_from_url
)
from app.tools.vn_stock import (
    get_stock_price,
    plot_stock_chart
)

# Thiết lập cache 
CACHE_DIRECTORY = os.path.join(os.path.dirname(__file__), 'cache')
os.makedirs(CACHE_DIRECTORY, exist_ok=True)

# Lựa chọn cache - SQLite cho persistent caching
set_llm_cache(SQLiteCache(database_path=os.path.join(CACHE_DIRECTORY, 'langchain.db')))

# Cache decorator cho các hàm công cụ
def tool_cache(maxsize=128, typed=False):
    def decorator(func):
        @lru_cache(maxsize=maxsize, typed=typed)
        def wrapper(*args, **kwargs):
            # Chuyển đổi kwargs thành một dạng có thể cache được
            kwargs_key = tuple(sorted(kwargs.items()))
            return func(*args, *kwargs_key)
        return wrapper
    return decorator

# Áp dụng cache cho các công cụ
@tool_cache()
def cached_get_weather(*args, **kwargs):
    return get_weather(*args, **kwargs)

@tool_cache()
def cached_get_stock_price(*args, **kwargs):
    return get_stock_price(*args, **kwargs)

@tool_cache()
def cached_read_news_from_url(*args, **kwargs):
    return read_news_from_url(*args, **kwargs)

# Danh sách các công cụ với phiên bản cache
tools = [
    Tool(
        name="Calculator",
        func=calculator,
        description="Thực hiện các phép tính toán học cơ bản."
    ),
    Tool(
        name="GetWeather",
        func=cached_get_weather,
        description="Cung cấp thông tin thời tiết cho một địa điểm cụ thể."
    ),
    Tool(
        name="GetStockPrice",
        func=cached_get_stock_price,
        description="Lấy giá hiện tại của một mã cổ phiếu cụ thể sử dụng yfinance."
    ),
    Tool(
        name="ReadNewsFromURL",
        func=cached_read_news_from_url,
        description="Đọc và trích xuất nội dung tin tức từ một URL cụ thể."
    ),
    Tool(
        name="PlotStockChart",
        func=plot_stock_chart,
        description="Vẽ đồ thị giá cổ phiếu hàng ngày của một mã cổ phiếu cụ thể trong một khoảng thời gian."
    ),
]

# Khởi tạo LLM với các tối ưu
llm = Ollama(
    model=LLAMA_MODEL,
    base_url=OLLAMA_BASE_URL,
    verbose=False,  # Giảm verbose để tăng tốc độ
    temperature=0.7,  # Điều chỉnh sáng tạo
    num_predict=4096,  # Tăng độ dài token
    timeout=120,  # Tăng thời gian chờ
)

# Quản lý bộ nhớ hội thoại
memory = ConversationBufferMemory(
    memory_key="chat_history", 
    return_messages=True,
    max_token_limit=4096  # Giới hạn kích thước bộ nhớ
)

agent_executor = initialize_agent(
    tools,
    llm,
    agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
    verbose=True,  # Để kiểm tra chi tiết lỗi nếu có
    handle_parsing_errors=True,
    memory=memory,
    max_iterations=10,  # Tăng giới hạn vòng lặp
    early_stopping_method="force",
    agent_kwargs={
        "system_message": (
            "Bạn là một trợ lý AI thông minh có thể sử dụng nhiều công cụ để trả lời câu hỏi một cách nhanh chóng và hiệu quả. "
            "Khi trả lời, hãy bao gồm cả phần 'Observation' chi tiết và 'Summary', ví dụ:"
            "Thông tin cổ phiếu DGC (HOSE):"
            "- Ngày: 2024-11-29 00:00:00"
            "- Giá mở cửa: 108000.00"
            "- Giá cao nhất: 110500.00"
            "- Giá thấp nhất: 107900.00"
            "- Giá đóng cửa: 109500.00"
            "- Tỷ lệ tăng giá: 1.67%"
            "- Khối lượng giao dịch: 2099600.0"
            "- Tỷ lệ tăng khối lượng (so với trung bình 20 phiên): 45.54%"
            "- Tín hiệu: Giữ"
        ),
    }
)

# Hàm thực thi với timeout và xử lý ngoại lệ
async def execute_agent_query(query, max_timeout=30):
    try:
        with ThreadPoolExecutor() as pool:
            # Sử dụng asyncio để thiết lập timeout
            future = pool.submit(agent_executor.run, query)
            response = await asyncio.wrap_future(future)
            return response
    except TimeoutError:
        return "Xin lỗi, yêu cầu mất quá nhiều thời gian để xử lý."
    except Exception as e:
        return f"Đã xảy ra lỗi: {str(e)}"

# Hàm quản lý cache
def manage_cache():
    # Tùy chọn xóa cache cũ
    cache_dir = os.path.join(os.path.dirname(__file__), 'cache')
    for filename in os.listdir(cache_dir):
        file_path = os.path.join(cache_dir, filename)
        try:
            if os.path.isfile(file_path):
                # Xóa file cache cũ hơn 7 ngày
                if os.path.getmtime(file_path) < (time.time() - 7 * 86400):
                    os.unlink(file_path)
        except Exception as e:
            print(f"Lỗi khi dọn cache: {e}")

# Đăng ký hàm quản lý cache (nếu cần)
if __name__ == "__main__":
    manage_cache()

Bổ sung mã nguồn cho vn_stock.py

from pymongo import MongoClient
from datetime import datetime
import numpy as np
import pandas as pd
from app.config import MONGO_URI
import mplfinance as mpf
from io import BytesIO
import matplotlib.pyplot as plt


# Kết nối MongoDB
client = MongoClient(MONGO_URI)
print(f"Connected to MongoDB with URI: {MONGO_URI}")
db = client["DailyDb"]

ddef get_market_summary(date_str: str = '') -> str:
    # Nếu không cung cấp ngày, tự động lấy ngày cuối cùng của dữ liệu
    if not date_str:
        all_symbols = db.list_collection_names()
        max_date = None
        for symbol in all_symbols:
            collection = db[symbol]
            latest_date_doc = collection.find_one({}, sort=[("date", -1)])  # Tìm ngày lớn nhất
            if latest_date_doc:
                if not max_date or latest_date_doc["date"] > max_date:
                    max_date = latest_date_doc["date"]
        
        if not max_date:
            return "Không tìm thấy dữ liệu trong cơ sở dữ liệu."
        
        # Chuyển ngày lớn nhất thành chuỗi để sử dụng
        date_str = max_date.strftime("%Y-%m-%d")
    
    # Tạo bộ lọc theo ngày
    try:
        filter_query = {"date": datetime.strptime(date_str, "%Y-%m-%d")}
    except ValueError:
        return "Ngày không hợp lệ! Vui lòng nhập theo định dạng YYYY-MM-DD."

    # Lấy danh sách các collection (mỗi collection là một mã cổ phiếu)
    all_symbols = db.list_collection_names()

    summary_data = []
    for symbol in all_symbols:
        collection = db[symbol]
        # Truy vấn dữ liệu theo ngày
        stock_data = list(collection.find(filter_query))
        if stock_data:
            for record in stock_data:
                summary_data.append({
                    "symbol": record["symbol"],
                    "market": record["market"],
                    "volume": record["volume"],
                    "price_change": (record["close"] - record["open"]) / record["open"] * 100 if record["open"] else 0,
                    "money": record["money"]
                })

    if not summary_data:
        return f"Không có dữ liệu giao dịch cho ngày {date_str}."

    # Tạo DataFrame để xử lý dữ liệu
    df = pd.DataFrame(summary_data)

    # Lấy top 5 cổ phiếu tăng giá mạnh nhất
    top_gainers = df.sort_values(by="price_change", ascending=False).head(5)
    # Lấy top 5 cổ phiếu có khối lượng giao dịch lớn nhất
    top_volume = df.sort_values(by="volume", ascending=False).head(5)

    # Tạo kết quả trả về
    result = f"Ngày giao dịch: {date_str}\n\n"
    result += "Top cổ phiếu tăng giá mạnh nhất:\n"
    result += "\n".join([f"{row['symbol']} ({row['market']}): {row['price_change']:.2f}%" for _, row in top_gainers.iterrows()])
    result += "\n\nTop cổ phiếu có khối lượng giao dịch lớn nhất:\n"
    result += "\n".join([f"{row['symbol']} ({row['market']}): {row['volume']}" for _, row in top_volume.iterrows()])

    return result

def get_stock_price(ticker: str) -> str:
    print(f"Fetching stock price for {ticker} from MongoDB")
    try:
        # Truy cập collection tương ứng với mã cổ phiếu
        collection = db[ticker]

        # Lấy dữ liệu 21 phiên gần nhất (để tính trung bình 20 phiên)
        recent_data = list(collection.find().sort("date", -1).limit(21))
        print(recent_data)
        
        if not recent_data or len(recent_data) < 2:
            return f"Không đủ dữ liệu để tính toán cho mã cổ phiếu: {ticker}."
        
        # Lấy thông tin ngày gần nhất và ngày liền trước
        latest_entry = recent_data[0]
        previous_entry = recent_data[1]
        
        # Giá đóng cửa gần nhất
        latest_price = latest_entry.get("close")
        previous_price = previous_entry.get("close")
        
        # Tính tỷ lệ tăng giá (so với ngày trước đó)
        price_change = ((latest_price - previous_price) / previous_price) * 100 if previous_price else 0
        
        # Khối lượng giao dịch
        latest_volume = latest_entry.get("volume")
        volumes = [entry.get("volume", 0) for entry in recent_data[1:21]]
        avg_volume_20 = np.mean(volumes)
        volume_change = ((latest_volume - avg_volume_20) / avg_volume_20) * 100 if avg_volume_20 else 0
        
        # Thông tin bổ sung
        date = latest_entry.get("date")
        market = latest_entry.get("market")
        high = latest_entry.get("high")
        low = latest_entry.get("low")
        open_price = latest_entry.get("open")
        
        # Tín hiệu mua/bán cơ bản
        signal = "Mua" if price_change > 2 and volume_change > 50 else "Giữ" if price_change > 0 else "Bán"

        # Kết quả cuối cùng
        return (
            f"Thông tin cổ phiếu {ticker} ({market}):\n"
            f"- Ngày: {date}\n"
            f"- Giá mở cửa: {open_price:.2f}\n"
            f"- Giá cao nhất: {high:.2f}\n"
            f"- Giá thấp nhất: {low:.2f}\n"
            f"- Giá đóng cửa: {latest_price:.2f}\n"
            f"- Tỷ lệ tăng giá: {price_change:.2f}%\n"
            f"- Khối lượng giao dịch: {latest_volume}\n"
            f"- Tỷ lệ tăng khối lượng (so với trung bình 20 phiên): {volume_change:.2f}%\n"
            f"- Tín hiệu: {signal}\n"
        )
    except Exception as e:
        return f"Lỗi khi lấy dữ liệu cổ phiếu từ MongoDB: {str(e)}"


# def plot_stock_chart(ticker: str, period: str = "6mo"):
#     ticker = ticker.upper()
#     print(f'Đang gọi hàm vẽ đồ thị của cổ phiếu {ticker}')
#     try:
#         # Truy cập collection tương ứng với mã cổ phiếu
#         collection = db[ticker]
        
#         # Lấy dữ liệu từ MongoDB
#         data = list(collection.find().sort("date", 1))  # Sắp xếp theo ngày tăng dần
#         print(data)
#         if not data:
#             return None, f"Không tìm thấy dữ liệu cho mã cổ phiếu: {ticker}."
        
#         # Trích xuất dữ liệu cần thiết để vẽ
#         dates = [datetime.strptime(d["date"], "%Y-%m-%dT%H:%M:%S.%fZ") for d in data]
#         closes = [d["close"] for d in data]
        
#         # Vẽ đồ thị
#         plt.figure(figsize=(10, 5))
#         plt.plot(dates, closes, label='Giá Đóng Cửa')
#         plt.title(f"Biểu Đồ Giá Đóng Cửa Hàng Ngày của {ticker}")
#         plt.xlabel("Ngày")
#         plt.ylabel("Giá")
#         plt.legend()
#         plt.grid(True)

#         # Lưu hình ảnh vào bộ nhớ
#         img_buffer = BytesIO()
#         plt.savefig(img_buffer, format='png')
#         plt.close()
#         img_buffer.seek(0)

#         return img_buffer, None  # Trả về dữ liệu hình ảnh và None cho message
#     except Exception as e:
#         return None, f"Lỗi khi vẽ đồ thị: {str(e)}"


def plot_stock_chart(ticker: str, period: str = "6mo"):
    ticker = ticker.upper()
    print(f'Vẽ đồ thị candlestick cho cổ phiếu {ticker}')
    try:
        # Truy cập collection tương ứng với mã cổ phiếu
        collection = db[ticker]

        # Lấy dữ liệu từ MongoDB
        data = list(collection.find().sort("date", 1))  # Sắp xếp theo ngày tăng dần
        print(data)
        if not data:
            return None, f"Không tìm thấy dữ liệu cho mã cổ phiếu: {ticker}."

        # Tạo DataFrame từ dữ liệu
        df = pd.DataFrame(data)
        df['date'] = pd.to_datetime(df['date'])  # Chuyển đổi định dạng ngày
        df.set_index('date', inplace=True)

        # Đảm bảo các cột cần thiết đều là số, thay thế giá trị không hợp lệ
        numeric_columns = ['open', 'high', 'low', 'close', 'volume']
        for col in numeric_columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')  # Chuyển thành số, lỗi -> NaN

        # Loại bỏ các hàng có giá trị NaN trong các cột quan trọng
        df.dropna(subset=['open', 'high', 'low', 'close'], inplace=True)

        # Kiểm tra nếu không còn dữ liệu hợp lệ sau khi làm sạch
        if df.empty:
            return None, f"Dữ liệu không đủ để vẽ đồ thị candlestick cho mã cổ phiếu {ticker}."

        # Vẽ đồ thị candlestick
        fig, ax = mpf.plot(
            df,
            type='candle',
            style='yahoo',
            title=f"Biểu Đồ Candlestick của {ticker}",
            ylabel='Giá',
            volume=True,
            returnfig=True
        )

        # Lưu hình ảnh vào bộ nhớ
        img_buffer = BytesIO()
        fig.savefig(img_buffer, format='png')
        plt.close(fig)
        img_buffer.seek(0)

        return img_buffer, None  # Trả về hình ảnh
    except Exception as e:
        return None, f"Lỗi khi vẽ đồ thị candlestick: {str(e)}"

Docker

Thực thi lệnh [sh setup.sh] để cập nhật chương trình trên docker.

Kiểm thử chatbot

Thực thi một số truy vấn, dưới đây là scripts mẫu:

Tóm tắt thị trường hôm nay

Lấy giá cổ phiếu FPT

Comments

No comments yet. Why don’t you start the discussion?

Leave a Reply

Your email address will not be published. Required fields are marked *