{ "cells": [ { "cell_type": "markdown", "id": "4b3a0584-b52c-4873-abb8-8382e13ff5c0", "metadata": {}, "source": [ "# Validation with Pydantic\n", "\n", "Here, we'll see how to use pydantic to specify the schema and validate the results.\n", "\n", "**ATTENTION** Validation does *NOT* imply that extraction was correct. Validation only implies that the\n", "data was returned in the correct shape and meets all validation criteria. This doesn't mean\n", "that the LLM didn't make up some information!" ] }, { "cell_type": "code", "execution_count": 1, "id": "0b4597b2-2a43-4491-8830-bf9f79428074", "metadata": { "nbsphinx": "hidden", "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import sys\n", "\n", "sys.path.insert(0, \"../../\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "c719e4fc-3ccf-4633-a787-b2fe0d1eac65", "metadata": { "tags": [] }, "outputs": [], "source": [ "import enum\n", "from langchain.chat_models import ChatOpenAI\n", "from kor import create_extraction_chain, Object, Text, Number\n", "import pydantic\n", "from typing import List\n", "from kor import from_pydantic\n", "from pydantic import BaseModel, Field\n", "from typing import Optional" ] }, { "cell_type": "code", "execution_count": 3, "id": "f1313c02-d415-4ce6-bff0-3df537cc06c2", "metadata": { "tags": [] }, "outputs": [], "source": [ "llm = ChatOpenAI(\n", " model_name=\"gpt-3.5-turbo\",\n", " temperature=0,\n", ")" ] }, { "cell_type": "markdown", "id": "c6ce4726-db5b-49ee-abf7-780fd707be5f", "metadata": {}, "source": [ "Let's return to our hypothetical music player API:" ] }, { "cell_type": "code", "execution_count": 4, "id": "fface268-cda5-430e-a0dc-c354ee4cfe2f", "metadata": { "tags": [] }, "outputs": [], "source": [ "class Action(enum.Enum):\n", " play = \"play\"\n", " stop = \"stop\"\n", " previous = \"previous\"\n", " next_ = \"next\"\n", "\n", "\n", "class MusicRequest(BaseModel):\n", " song: Optional[List[str]] = 
Field(\n", " default=None, description=\"The song(s) that the user would like to be played.\"\n", " )\n", " album: Optional[List[str]] = Field(\n", " default=None, description=\"The album(s) that the user would like to be played.\"\n", " )\n", " artist: Optional[List[str]] = Field(\n", " default=None,\n", " description=\"The artist(s) whose music the user would like to hear.\",\n", " examples=[(\"Songs by paul simon\", \"paul simon\")],\n", " )\n", " action: Optional[Action] = Field(\n", " default=None,\n", " description=\"The action that should be taken; one of `play`, `stop`, `next`, `previous`\",\n", " examples=[\n", " (\"Please stop the music\", \"stop\"),\n", " (\"play something\", \"play\"),\n", " (\"play a song\", \"play\"),\n", " (\"next song\", \"next\"),\n", " ],\n", " )" ] }, { "cell_type": "code", "execution_count": 5, "id": "4fe1ec70-1428-4433-acac-c190674a666e", "metadata": { "tags": [] }, "outputs": [], "source": [ "schema, validator = from_pydantic(MusicRequest)" ] }, { "cell_type": "markdown", "id": "49029ef7-c084-46f2-9791-6fb731252a9f", "metadata": {}, "source": [ "**ATTENTION** Use the JSON encoder here rather than the default CSV encoder as it supports nested lists" ] }, { "cell_type": "code", "execution_count": 6, "id": "ff9ad27f-7a81-4123-8d0b-1e14802df67e", "metadata": { "tags": [] }, "outputs": [], "source": [ "chain = create_extraction_chain(\n", " llm, schema, encoder_or_encoder_class=\"json\", validator=validator\n", ")" ] }, { "cell_type": "markdown", "id": "3cc3f196-5c63-44ae-bbf3-471c2af47bf8", "metadata": {}, "source": [ "Let's test it out" ] }, { "cell_type": "code", "execution_count": 7, "id": "248fa47a-18b5-414f-a656-31c88f7a8dc4", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Your goal is to extract structured information from the user's input that matches the form described below. When extracting information please make sure it matches the type information exactly. 
Do not add any attributes that do not appear in the schema shown below.\n", "\n", "```TypeScript\n", "\n", "musicrequest: { // \n", " song: Array // The song(s) that the user would like to be played.\n", " album: Array // The album(s) that the user would like to be played.\n", " artist: Array // The artist(s) whose music the user would like to hear.\n", " action: \"play\" | \"stop\" | \"previous\" | \"next\" // The action that should be taken; one of `play`, `stop`, `next`, `previous`\n", "}\n", "```\n", "\n", "\n", "Please output the extracted information in JSON format. Do not output anything except for the extracted information. Do not add any clarifying information. Do not add any fields that are not in the schema. If the text contains attributes that do not appear in the schema, please ignore them. All output must be in JSON format and follow the schema specified above. Wrap the JSON in tags.\n", "\n", "\n", "\n", "Input: Songs by paul simon\n", "Output: {\"musicrequest\": {\"artist\": [\"paul simon\"]}}\n", "Input: Please stop the music\n", "Output: {\"musicrequest\": {\"action\": \"stop\"}}\n", "Input: play something\n", "Output: {\"musicrequest\": {\"action\": \"play\"}}\n", "Input: play a song\n", "Output: {\"musicrequest\": {\"action\": \"play\"}}\n", "Input: next song\n", "Output: {\"musicrequest\": {\"action\": \"next\"}}\n", "Input: [user input]\n", "Output:\n" ] } ], "source": [ "print(chain.prompt.format_prompt(text=\"[user input]\").to_string())" ] }, { "cell_type": "code", "execution_count": 8, "id": "760baa5f-9368-4b5a-abc0-6ac65c34b7a7", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MusicRequest(song=None, album=None, artist=None, action=)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"stop the music now\")[\"validated_data\"]" ] }, { "cell_type": "code", "execution_count": 9, "id": "593b106a-72e7-4c86-82a8-f9630a3bfabb", "metadata": { "tags": [] }, "outputs": [ { 
"data": { "text/plain": [ "MusicRequest(song=['yellow submarine'], album=None, artist=['the beatles'], action=None)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"i want to hear yellow submarine by the beatles\")[\"validated_data\"]" ] }, { "cell_type": "code", "execution_count": 10, "id": "8d728835-2e3e-40d3-8bba-2cd4bd4ec789", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MusicRequest(song=['goliath'], album=None, artist=['smith&thell'], action=None)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"play goliath by smith&thell\")[\"validated_data\"]" ] }, { "cell_type": "code", "execution_count": 11, "id": "02c7f1e5-1c8d-4e9f-82e6-c37a41d6de14", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MusicRequest(song=None, album=['the lion king soundtrack'], artist=None, action=None)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"can you play the lion king soundtrack\")[\"validated_data\"]" ] }, { "cell_type": "code", "execution_count": 12, "id": "7a6d918c-53fe-426b-b37e-eec2abb8a704", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MusicRequest(song=None, album=None, artist=['paul simon', 'led zeppelin', 'the doors'], action=None)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"play songs by paul simon and led zeppelin and the doors\")[\"validated_data\"]" ] }, { "cell_type": "code", "execution_count": 13, "id": "b18acf0a-d99e-48de-ace5-fb01bded5a41", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MusicRequest(song=None, album=None, artist=None, action=)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"could you play the previous song again?\")[\"validated_data\"]" ] }, { "cell_type": "code", 
"execution_count": 14, "id": "8c5c06b5-7f3a-416d-a809-19004cb14750", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MusicRequest(song=None, album=None, artist=None, action=)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"previous\")[\"validated_data\"]" ] }, { "cell_type": "code", "execution_count": 15, "id": "03f33a42-e1c9-4bc2-a6b4-15d3f3c5ce32", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MusicRequest(song=None, album=None, artist=None, action=)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"play the song before\")[\"validated_data\"]" ] }, { "cell_type": "markdown", "id": "95f22605-a2b0-401e-ac7d-48d97913d535", "metadata": {}, "source": [ "## Validation in Action" ] }, { "cell_type": "code", "execution_count": 16, "id": "f779d4a2-ddb3-49c4-86cf-d9537b6ca6d4", "metadata": { "tags": [] }, "outputs": [], "source": [ "class Player(BaseModel):\n", " song: List[str] = Field(\n", " description=\"The song(s) that the user would like to be played.\"\n", " ) # <-- Note this is NOT Optional\n", " album: Optional[List[str]] = Field(\n", " default=None, description=\"The album(s) that the user would like to be played.\"\n", " )\n", " artist: Optional[List[str]] = Field(\n", " default=None,\n", " description=\"The artist(s) whose music the user would like to hear.\",\n", " examples=[(\"Songs by paul simon\", \"paul simon\")],\n", " )\n", " action: Optional[Action] = Field(\n", " default=None,\n", " description=\"The action that should be taken; one of `play`, `stop`, `next`, `previous`\",\n", " examples=[\n", " (\"Please stop the music\", \"stop\"),\n", " (\"play something\", \"play\"),\n", " (\"play a song\", \"play\"),\n", " (\"next song\", \"next\"),\n", " ],\n", " )" ] }, { "cell_type": "code", "execution_count": 17, "id": "38979da3-13aa-4e34-b11c-3656d7e3b4f6", "metadata": { "tags": [] }, "outputs": [], 
"source": [ "schema, validator = from_pydantic(Player)\n", "chain = create_extraction_chain(\n", " llm, schema, encoder_or_encoder_class=\"json\", validator=validator\n", ")" ] }, { "cell_type": "markdown", "id": "ceaec432-480f-44f0-aad9-9eb8146f0f02", "metadata": {}, "source": [ "Now the schema expects that a list of songs parsed out in the query." ] }, { "cell_type": "markdown", "id": "229771c7-eb16-4a54-85ac-13f4f0d8c527", "metadata": {}, "source": [ "### No valid data!\n", "\n", "We made *SONG* a required attribute in the pydantic schema above! Let's see what happens now!" ] }, { "cell_type": "code", "execution_count": 19, "id": "d594eb6e-02a0-4774-9dca-421db192372d", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Error in LangChainTracer.on_chain_end callback: No constructor defined\n" ] }, { "data": { "text/plain": [ "{'data': {'player': {'action': 'stop'}},\n", " 'raw': '{\"player\": {\"action\": \"stop\"}}',\n", " 'errors': [1 validation error for Player\n", " song\n", " Field required [type=missing, input_value={'action': 'stop'}, input_type=dict]\n", " For further information visit https://errors.pydantic.dev/2.3/v/missing],\n", " 'validated_data': None}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"stop the music now\")" ] }, { "cell_type": "code", "execution_count": 21, "id": "672f9908-c6b7-47cf-8c82-03b0e5b8fa84", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "Player(song=['yellow submarine'], album=None, artist=['the beatles'], action=None)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"i want to hear yellow submarine by the beatles\")[\"validated_data\"]" ] }, { "cell_type": "markdown", "id": "02559660-20a4-4fce-83a4-2e3b68794f19", "metadata": {}, "source": [ "## Validating Collections\n", "\n", "Currently, there are a few gotchyas when modeling collections 
that depend on the encoder." ] }, { "cell_type": "markdown", "id": "3da8c0fe-3765-4a6e-8925-85806e9500cc", "metadata": {}, "source": [ "### CSV Encoder\n", "\n", "The CSV encoder is expected to work best when encoding a list of records.\n", "\n", "At the moment, the CSV encoder doesn't handle embedded lists or objects.\n", "\n", "(The example below works with either the JSON or the CSV encoder.)" ] }, { "cell_type": "code", "execution_count": 22, "id": "cd9dffbe-82bf-4d1f-9b0a-3ec2c58b63d6", "metadata": { "tags": [] }, "outputs": [], "source": [ "class Person(BaseModel):\n", " name: str = Field(description=\"The person's name\")\n", " age: int = Field(description=\"The age of the person\")" ] }, { "cell_type": "code", "execution_count": 23, "id": "f859b6e1-c2d8-48e0-af17-2ffb286bffe9", "metadata": { "tags": [] }, "outputs": [], "source": [ "schema, validator = from_pydantic(\n", " Person,\n", " description=\"Personal information\",\n", " many=True,\n", " examples=[(\"Joe is 10 years old\", {\"name\": \"Joe\", \"age\": \"10\"})],\n", ")\n", "chain = create_extraction_chain(\n", " llm, schema, encoder_or_encoder_class=\"csv\", validator=validator\n", ")" ] }, { "cell_type": "code", "execution_count": 24, "id": "af6c6339-81db-482b-9507-31f41d2fa489", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Your goal is to extract structured information from the user's input that matches the form described below. When extracting information please make sure it matches the type information exactly. Do not add any attributes that do not appear in the schema shown below.\n", "\n", "```TypeScript\n", "\n", "person: Array<{ // Personal information\n", " name: string // The person's name\n", " age: number // The age of the person\n", "}>\n", "```\n", "\n", "\n", "Please output the extracted information in CSV format in Excel dialect. Please use a | as the delimiter. \n", " Do NOT add any clarifying information. Output MUST follow the schema above. 
Do NOT add any additional columns that do not appear in the schema.\n", "\n", "\n", "\n", "Input: Joe is 10 years old\n", "Output: name|age\n", "Joe|10\n", "\n", "Input: [user input]\n", "Output:\n" ] } ], "source": [ "print(chain.prompt.format_prompt(text=\"[user input]\").to_string())" ] }, { "cell_type": "code", "execution_count": 25, "id": "837a08c2-8de0-448a-9984-0cf19a73d4a3", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "[Person(name='john', age=13), Person(name='maria', age=24)]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\"john is 13 years old. maria is 24 years old\")[\"validated_data\"]" ] }, { "cell_type": "markdown", "id": "7622135e-3a2e-41b2-a99f-980e5dd9f804", "metadata": {}, "source": [ "## Complex Structure\n", "\n", "To serialize more complex structures, use the JSON encoder.\n", "\n", "For the example above, the following alternative works:" ] }, { "cell_type": "code", "execution_count": 26, "id": "ee92ae5e-52e9-4405-9718-be71d25ce412", "metadata": { "tags": [] }, "outputs": [], "source": [ "class Person(BaseModel):\n", " name: str = Field(description=\"The person's name\")\n", " age: int = Field(description=\"The age of the person\")\n", "\n", "\n", "class Root(BaseModel):\n", " people: List[Person] = Field(\n", " description=\"Personal information\",\n", " examples=[(\"John was 23 years old\", {\"name\": \"John\", \"age\": 23})],\n", " )" ] }, { "cell_type": "markdown", "id": "b1c618c8-adab-48bd-bc5d-45dc5adb1dbb", "metadata": {}, "source": [ "**NOTE** This uses a `Root` container with `many=False`." ] }, { "cell_type": "code", "execution_count": 27, "id": "7464aab0-45fb-4e22-bed4-b695c7f60629", "metadata": { "tags": [] }, "outputs": [], "source": [ "schema, validator = from_pydantic(Root, description=\"Personal information\", many=False)\n", "chain = create_extraction_chain(\n", " llm, schema, encoder_or_encoder_class=\"json\", validator=validator\n", ")" ] 
}, { "cell_type": "code", "execution_count": 28, "id": "8becd13d-bd23-4d37-8fd3-7548a7fe51c1", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'data': {'root': {'people': [{'name': 'tom', 'age': 23},\n", " {'name': 'Jessica', 'age': 75}]}},\n", " 'raw': '{\"root\": {\"people\": [{\"name\": \"tom\", \"age\": 23}, {\"name\": \"Jessica\", \"age\": 75}]}}',\n", " 'errors': [],\n", " 'validated_data': Root(people=[Person(name='tom', age=23), Person(name='Jessica', age=75)])}" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\n", " \"My name is tom and i am 23 years old. Her name is Jessica and she is 75 years old.\"\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "id": "7ea06fe0-104b-487a-ab78-cd23de66ec88", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Your goal is to extract structured information from the user's input that matches the form described below. When extracting information please make sure it matches the type information exactly. Do not add any attributes that do not appear in the schema shown below.\n", "\n", "```TypeScript\n", "\n", "root: { // Personal information\n", " people: Array<{ // Personal information\n", " name: string // The person's name\n", " age: number // The age of the person\n", " }>\n", "}\n", "```\n", "\n", "\n", "Please output the extracted information in JSON format. Do not output anything except for the extracted information. Do not add any clarifying information. Do not add any fields that are not in the schema. If the text contains attributes that do not appear in the schema, please ignore them. All output must be in JSON format and follow the schema specified above. 
Wrap the JSON in tags.\n", "\n", "\n", "\n", "Input: John was 23 years old\n", "Output: {\"root\": {\"people\": [{\"name\": \"John\", \"age\": 23}]}}\n", "Input: [user input]\n", "Output:\n" ] } ], "source": [ "print(chain.prompt.format_prompt(text=\"[user input]\").to_string())" ] }, { "cell_type": "code", "execution_count": 41, "id": "38245cd6-e188-40c9-a940-184da8755983", "metadata": { "tags": [] }, "outputs": [], "source": [ "class Pet(BaseModel):\n", " name: str = Field(description=\"the name of the pet\")\n", " species: Optional[str] = Field(\n", " default=None, description=\"The species of the pet; e.g., dog or cat\"\n", " )\n", " age: Optional[int] = Field(default=None, description=\"The number of the age; e.g.,\")\n", " age_unit: Optional[str] = Field(\n", " default=None, description=\"The unit of the age; e.g., days or weeks\"\n", " )\n", "\n", "\n", "class Person(BaseModel):\n", " name: str = Field(description=\"The person's name\")\n", " age: Optional[int] = Field(default=None, description=\"The age of the person\")\n", " pets: List[Pet] = Field(\n", " description=\"The pets owned by the person\",\n", " examples=[\n", " (\n", " \"he had a dog by the name of charles that was 5 days old\",\n", " {\"name\": \"dog\", \"species\": \"dog\", \"age\": \"5\", \"age_unit\": \"days\"},\n", " )\n", " ],\n", " )\n", "\n", "\n", "class Root(BaseModel):\n", " people: List[Person] = Field(\n", " description=\"Personal information\",\n", " examples=[(\"John was 23 years old\", {\"name\": \"John\", \"age\": 23})],\n", " )" ] }, { "cell_type": "code", "execution_count": 42, "id": "236a9510-6f69-4d63-8854-62e2c57380a6", "metadata": { "tags": [] }, "outputs": [], "source": [ "schema, validator = from_pydantic(\n", " Root, description=\"Personal information for multiple people\", many=False\n", ")\n", "chain = create_extraction_chain(\n", " llm, schema, encoder_or_encoder_class=\"json\", validator=validator\n", ")" ] }, { "cell_type": "code", "execution_count": 43, "id": 
"843e992a-32c5-4382-95ab-33eb3cd7810b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Your goal is to extract structured information from the user's input that matches the form described below. When extracting information please make sure it matches the type information exactly. Do not add any attributes that do not appear in the schema shown below.\n", "\n", "```TypeScript\n", "\n", "root: { // Personal information for multiple people\n", " people: Array<{ // Personal information\n", " name: string // The person's name\n", " age: number // The age of the person\n", " pets: Array<{ // The pets owned by the person\n", " name: string // the name of the pet\n", " species: string // The species of the pet; e.g., dog or cat\n", " age: number // The number of the age; e.g.,\n", " age_unit: string // The unit of the age; e.g., days or weeks\n", " }>\n", " }>\n", "}\n", "```\n", "\n", "\n", "Please output the extracted information in JSON format. Do not output anything except for the extracted information. Do not add any clarifying information. Do not add any fields that are not in the schema. If the text contains attributes that do not appear in the schema, please ignore them. All output must be in JSON format and follow the schema specified above. 
Wrap the JSON in tags.\n", "\n", "\n", "\n", "Input: John was 23 years old\n", "Output: {\"root\": {\"people\": [{\"name\": \"John\", \"age\": 23}]}}\n", "Input: he had a dog by the name of charles that was 5 days old\n", "Output: {\"root\": {\"people\": [{\"pets\": [{\"name\": \"dog\", \"species\": \"dog\", \"age\": \"5\", \"age_unit\": \"days\"}]}]}}\n", "Input: [user input]\n", "Output:\n" ] } ], "source": [ "print(chain.prompt.format_prompt(text=\"[user input]\").to_string())" ] }, { "cell_type": "code", "execution_count": 45, "id": "e34b6194-9a2b-43a1-95c6-2c9930d036ed", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "Root(people=[Person(name='Neo', age=None, pets=[Pet(name='Tom', species='dog', age=None, age_unit=None), Pet(name='Weeby', species='cat', age=23, age_unit='days')]), Person(name='Julia', age=None, pets=[Pet(name='Wind', species='horse', age=None, age_unit=None)])])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.run(\n", " text=\"Neo had a dog by the name of Tom and a cat by the name of Weeby. Weeby was 23 days old. Julia owned a horse. The horses name was Wind\"\n", ")[\"validated_data\"]" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }