AI人工智能学习之自定义工具


AI人工智能学习之自定义工具


引言

人工智能中可以使用自定义工具,一种是按照工具函数的定义,一种是按照python类的定义。

正文

下面有两个示例。在开始之前先准备一下必要的库。

LLM 配置

一个是环境变量文件 env_utils.py:

import os

from dotenv import load_dotenv

load_dotenv(override=True)

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
DEEPSEEK_API_KEY = os.getenv('DEEPSEEK_API_KEY')
ZHIPU_API_KEY = os.getenv('ZHIPU_API_KEY')
MINIMAX_API_KEY = os.getenv('MINIMAX_API_KEY')
ALIBABA_API_KEY = os.getenv('ALIBABA_API_KEY')
K2_API_KEY = os.getenv('K2_API_KEY')

K2_BASE_URL = os.getenv('K2_BASE_URL')
ALIBABA_BASE_URL = os.getenv('ALIBABA_BASE_URL')
MINIMAX_BASE_URL = os.getenv('MINIMAX_BASE_URL')
OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL')
DEEPSEEK_BASE_URL = os.getenv('DEEPSEEK_BASE_URL')
ZHIPU_BASE_URL = os.getenv('ZHIPU_BASE_URL')

LOCAL_BASE_URL = os.getenv('LOCAL_BASE_URL')

一个是大模型集合文件 my_llm.py:

# 视频地址:https://www.bilibili.com/video/BV11rQzBEETd?vd_source=14e623b3280938e774caf714015caa22&spm_id_from=333.788.videopod.episodes&p=53

from langchain_openai import ChatOpenAI

from agent.llm.env_utils import OPENAI_API_KEY, OPENAI_BASE_URL, ALIBABA_API_KEY, ALIBABA_BASE_URL

# 调用阿里云百炼里面的的 DeepSeek 模型
llm = ChatOpenAI(  # 第一种
    model_name="deepseek-v3.2",
    temperature=1.1,
    api_key=ALIBABA_API_KEY,
    base_url=ALIBABA_BASE_URL,
)

# llm = ChatOpenAI(  # 第一种
#     model_name="deepseek-chat",
#     # model_name="deepseek-reasoner",
#     temperature=1.2,
#     api_key=DEEPSEEK_API_KEY,
#     base_url=DEEPSEEK_BASE_URL,
# )

# llm = ChatDeepSeek(  # 第一种
#     # model_name="deepseek-chat",
#     model_name="deepseek-reasoner",
#     temperature=1.3,
#     api_key=DEEPSEEK_API_KEY,
#     base_url=DEEPSEEK_BASE_URL,
# )
# resp = llm.invoke('用三句话简单介绍一下:机器学习的基本概念')
# print(type(resp))
# print(resp)
#
#
# llm = ChatDeepSeek(  # langchain-deepseek: 第二种
#     model_name="deepseek-r1-0528",
#     # model_name="deepseek-v3",
#     temperature=1.3,
#     api_key=ALIBABA_API_KEY,
#     api_base=ALIBABA_BASE_URL,
# )

# llm = BailianCustomChatModel(  # 自定义 的类来调用大模型: 第三种
#     model_name="deepseek-r1-0528",
#     api_key=ALIBABA_API_KEY,
#     base_url=ALIBABA_BASE_URL,
# )


# 在线的openai的大模型
# llm = ChatOpenAI(
#     model_name="gpt-4.1-mini",
#     temperature=0.5,
#     api_key=OPENAI_API_KEY,
#     base_url=OPENAI_BASE_URL,
# )

# llm = ChatOpenAI(
#     model='qwen3-max',
#     # model='qwen-plus',
#     # model='qwen3-8b',
#     temperature=0.6,
#     api_key=ALIBABA_API_KEY,
#     base_url=ALIBABA_BASE_URL,
# )


# 速率限制
# rate_limiter = InMemoryRateLimiter(
#     requests_per_second=0.1,  # 每10秒允许1个请求
#     check_every_n_seconds=0.1,  # 每100毫秒检查一次是否允许发出请求
#     max_bucket_size=10,  #  控制最大突发请求数量
# )

# llm = init_chat_model(  # V1.0后才有的写法
#     model="deepseek-r1-0528",
#     model_provider="openai",
#     api_key=ALIBABA_API_KEY,
#     base_url=ALIBABA_BASE_URL,
#     rate_limiter=rate_limiter
# )

# 智谱AI: https://docs.bigmodel.cn/cn/guide/tools/web-search
# pip install zhipuai
# zhipuai_client = ZhipuAI(api_key=ZHIPU_API_KEY)

日志类

再加一个日志类 logger_utils.py:

import sys, os
from loguru import logger

# 获得当前项目的绝对路径
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
log_dir = os.path.join(root_dir, "logs")  # 存放项目日志目录的绝对路径

if not os.path.exists(log_dir):  # 如果日志目录不存在,则创建
    os.mkdir(log_dir)

# LOG_FILE = "translation.log"  # 存储日志的文件

# Trace < Debug < Info < Success < Warning < Error < Critical

class MyLogger:
    def __init__(self):
        # log_file_path = os.path.join(log_dir, LOG_FILE)
        self.logger = logger  # 写日志的对象
        # 清空所有设置
        self.logger.remove()
        # 添加控制台输出的格式,sys.stdout为输出到屏幕;关于这些配置还需要自定义请移步官网查看相关参数说明
        self.logger.add(sys.stdout, level='DEBUG',
                        format="<green>{time:YYYYMMDD HH:mm:ss}</green> | "  # 颜色>时间
                               "{process.name} | "  # 进程名
                               "{thread.name} | "  # 线程名
                               "<cyan>{module}</cyan>.<cyan>{function}</cyan>"  # 模块名.方法名
                               ":<cyan>{line}</cyan> | "  # 行号
                               "<level>{level}</level>: "  # 等级
                               "<level>{message}</level>",  # 日志内容
                        )
        # 输出到文件的格式,注释下面的add',则关闭日志写入
        # self.logger.add(log_file_path, level='DEBUG', encoding='UTF-8',
        #                 format='{time:YYYYMMDD HH:mm:ss} - '  # 时间
        #                        "{process.name} | "  # 进程名
        #                        "{thread.name} | "  # 进程名
        #                        '{module}.{function}:{line} - {level} -{message}',  # 模块名.方法名:行号
        #                 rotation="10 MB",  # 日志文件生成的规则  rotation="1 week"  rotation="1 days"
        #                 retention=20  # 保留日志文件的规则
        #                 )

    def get_logger(self):
        return self.logger


log = MyLogger().get_logger()

if __name__ == '__main__':
    # log.debug("This is a debug message.")
    # log.info("This is an info message.")
    # log.warning('这是一个警告')
    # log.trace('xxxx')
    print('str.pdf'['str.pdf'.rindex('.'):])
    # @log.catch  # 整个函数自动加上try, catch。自动捕获异常,并且通过日志打印
    def test():
        try:
            print(3/0)
        except ZeroDivisionError as e:
            # log.error(e)
            log.exception(e)  # 以后常用

方式一

视频地址:https://www.bilibili.com/video/BV11rQzBEETd/?p=60

# ------------------第一种方式--------------
@tool('my_web_search', parse_docstring=True)
def web_search(query: str) -> str:
    """互联网搜索的工具,可以搜索所有公开的信息。

    Args:
        query: 需要进行互联网查询的的信息。

    Returns:
        返回搜索的结果信息,该信息是一个文本字符串。
    """
    pass


# ------------------第二种方式--------------

class SearchArgs(BaseModel):
    query: str = Field(..., description='需要进行互联网查询的查询信息')
    second: int = Field(..., description='第二个参数')

@tool('my_web_search2', args_schema=SearchArgs, description='互联网搜索的工具,可以搜索所有公开的信息')
def web_search2(query: str, second: int) -> str:
    pass


# ------------------第一种方式,示例--------------
# pip install zhipuai
from agent.my_llm import zhipuai_client

@tool('my_web_search3', parse_docstring=True)
def web_search3(query: str) -> str:
    """互联网搜索的工具,可以搜索所有公开的信息。

    Args:
        query: 需要进行互联网查询的的信息。

    Returns:
        返回搜索的结果信息,该信息是一个文本字符串。
    """
    try:
        resp = zhipuai_client.web_search.web_search(
            search_engine='search_pro',
            search_query=query,
        )
        if resp.search_result:
            return "\n\n".join([d.content for d in resp.search_result])
        return"没有搜索到任何结果"
    except Exception as e:
        print(e)
        return f"Error: {e}"

# ------------------第一种方式,异步--------------
@tool('my_web_search4', parse_docstring=True)
async def web_search4(query: str) -> str:
    """互联网搜索的工具,可以搜索所有公开的信息。

    Args:
        query: 需要进行互联网查询的的信息。

    Returns:
        返回搜索的结果信息,该信息是一个文本字符串。
    """
    pass

# ------------------测试--------------

if __name__ =='__main__':
    print(web_search3.name)#工具的名字
    print(web_search3.description)#工具的描述
    print(web_search3.args)#工具的参数
    print(web_search3.args_schema.model_json_schema())#工具的参数的json schema(描述son字符串)

    result=web_search3.invoke({'query':'如何使用langchain?'})
    print(result)

方式二

视频地址:https://www.bilibili.com/video/BV11rQzBEETd?p=63

from agent.my_llm import zhipuai_client

from typing import Type
from langchain_core.tools import BaseTool
from pydantic import create_model, Field, BaseModel

class SearchArgs(BaseModel): #  类:数据模型类
    query: str = Field(..., description='需要进行互联网查询的查询信息')
    #second: int=Field(...,description='第二个参数')

class MyWebSearchTool(BaseTool):
    name: str="web_search2"#定义工具的名称

    description:str="使用这个工具可以进行网络搜索。"
    #第一种写法
    #args_schema:Type[BaseModel]=SearchArgs #工具的参数
    
    #第二种写法
    # def __init__(self, **kwargs): # 动态参数模型类
    #     super().__init__(kwargs)
    #     self.args_schema = create_model(model_name:"SearchInput",query=(str,Field(...,description="需要进行互联网查询的查询信息")))
    def __init__(self): # 不传参
        super().__init__()
        self.args_schema = create_model(model_name="SearchInput",query=(str,Field(...,description="需要进行互联网查询的查询信息")))


    def _run(self, query: str) -> str:
        try:
            # print("执行我的Python中的工具,输入的参数为:", query)
            response = zhipuai_client.web_search.web_search(
                search_engine="search_pro",
                search_query=query
            )
            # print(response)
            if response.search_result:
                return "\n\n".join([d.content for d in response.search_result])
            return '没有搜索到任何内容!'
        except Exception as e:
            print(e)
            return '没有搜索到任何内容!'

    async def _run(self, query: str) -> str:
        return self._run(query)







参考资料


返回