"""Definitions of input elements."""
from __future__ import annotations
import abc
import copy
from typing import (
Any,
Generic,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel
from ._pydantic import PYDANTIC_MAJOR_VERSION
# Name of field to store the type discriminator
TYPE_DISCRIMINATOR_FIELD = "$type"
# Type variable for the value produced by a visitor: it is the return type of
# every ``AbstractVisitor.visit_*`` method and of each node's ``accept``.
T = TypeVar("T")
# Visitor is defined here for now, to avoid circular imports.
class AbstractVisitor(Generic[T], abc.ABC):
    """An abstract visitor over schema nodes.

    Every ``visit_*`` method forwards to :meth:`visit_default` unless it is
    overridden, so a concrete visitor only has to implement the node types
    it handles (or override ``visit_default`` as a catch-all).
    """

    def visit_text(self, node: Text, **kwargs: Any) -> T:
        """Visit text node."""
        return self.visit_default(node, **kwargs)

    def visit_number(self, node: Number, **kwargs: Any) -> T:
        """Visit number node."""
        return self.visit_default(node, **kwargs)

    def visit_bool(self, node: Bool, **kwargs: Any) -> T:
        """Visit bool node."""
        return self.visit_default(node, **kwargs)

    def visit_object(self, node: Object, **kwargs: Any) -> T:
        """Visit object node."""
        return self.visit_default(node, **kwargs)

    def visit_selection(self, node: Selection, **kwargs: Any) -> T:
        """Visit selection node."""
        return self.visit_default(node, **kwargs)

    def visit_option(self, node: Option, **kwargs: Any) -> T:
        """Visit option node."""
        return self.visit_default(node, **kwargs)

    def visit_default(self, node: AbstractSchemaNode, **kwargs: Any) -> T:
        """Default node implementation.

        Reached whenever a specific ``visit_*`` method was not overridden.

        Raises:
            NotImplementedError: always, unless a subclass overrides this
                method.
        """
        # Name both the visitor and the node type so a missing handler is
        # easy to locate from the traceback alone.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement a visit method for "
            f"{type(node).__name__} nodes"
        )
class AbstractSchemaNode(BaseModel):
    """Abstract schema node.

    Each node is expected to have a unique ID, and should
    only use alphanumeric characters.

    The ID should be unique across all inputs that belong
    to a given form.

    The description should describe what the node represents.
    It is used during prompt generation.
    """

    id: str
    description: str = ""
    many: bool = False

    @abc.abstractmethod
    def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
        """Accept a visitor."""
        raise NotImplementedError()

    # Update return type to `Self` when bumping python version.
    def replace(
        self,
        id: Optional[str] = None,  # pylint: disable=redefined-builtin
        description: Optional[str] = None,
    ) -> "AbstractSchemaNode":
        """Return a shallow copy of this node with selected fields overridden.

        Args:
            id: New ID for the copy. ``None`` or an empty string keeps the
                current ID.
            description: New description for the copy. ``None`` keeps the
                current description; any string, including ``""``, replaces
                it.

        Returns:
            A new node produced by ``copy.copy``; the original is not
            modified.
        """
        new_object = copy.copy(self)
        # An empty ID is not a usable identifier, so falsy values are ignored.
        if id:
            new_object.id = id
        # ``""`` is the field's default and therefore a legitimate value;
        # only ``None`` means "leave unchanged".
        if description is not None:
            new_object.description = description
        return new_object
class Number(ExtractionSchemaNode):
    """Built-in number input.

    Each entry of ``examples`` is a ``(text, value)`` pair whose value is a
    single number or a sequence of numbers.
    """

    examples: Sequence[
        Tuple[
            str,
            Union[int, float, Sequence[Union[float, int]]],
        ]
    ] = ()

    def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
        """Dispatch to ``visitor.visit_number``, forwarding keyword arguments."""
        return visitor.visit_number(self, **kwargs)
class Text(ExtractionSchemaNode):
    """Built-in text input.

    Each entry of ``examples`` is a ``(text, value)`` pair whose value is a
    single string or a sequence of strings.
    """

    examples: Sequence[
        Tuple[
            str,
            Union[Sequence[str], str],
        ]
    ] = ()

    def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
        """Dispatch to ``visitor.visit_text``, forwarding keyword arguments."""
        return visitor.visit_text(self, **kwargs)
class Bool(ExtractionSchemaNode):
    """Built-in bool input.

    Each entry of ``examples`` is a ``(text, value)`` pair whose value is a
    single boolean or a sequence of booleans.
    """

    examples: Sequence[
        Tuple[
            str,
            Union[Sequence[bool], bool],
        ]
    ] = ()

    def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
        """Dispatch to ``visitor.visit_bool``, forwarding keyword arguments."""
        return visitor.visit_bool(self, **kwargs)
class Option(AbstractSchemaNode):
    """Built-in option input; it must be part of a selection input."""

    # Example strings for this option; empty by default.
    examples: Sequence[str] = ()

    def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
        """Dispatch to ``visitor.visit_option``, forwarding keyword arguments."""
        return visitor.visit_option(self, **kwargs)
class Selection(AbstractSchemaNode):
    """Built-in selection node (aka Enum).

    A selection input is composed of one or more options.

    A selection node supports both examples and null_examples.

    Null examples are segments of text for which nothing should be extracted.

    Examples:

        .. code-block:: python

            selection = Selection(
                id="species",
                description="What is your favorite animal species?",
                options=[
                    Option(id="dog", description="Dog"),
                    Option(id="cat", description="Cat"),
                    Option(id="bird", description="Bird"),
                ],
                examples=[
                    ("I like dogs", "dog"),
                    ("I like cats", "cat"),
                    ("I like birds", "bird"),
                ],
                null_examples=[
                    "I like flowers",
                ],
                many=False
            )
    """

    options: Sequence[Option]
    # (text, option id or sequence of option ids) pairs, as in the docstring
    # example above.
    examples: Sequence[
        Tuple[
            str,
            Union[str, Sequence[str]],
        ]
    ] = ()
    null_examples: Sequence[str] = ()

    def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
        """Dispatch to ``visitor.visit_selection``, forwarding keyword arguments."""
        return visitor.visit_selection(self, **kwargs)
class Object(AbstractSchemaNode):
    """Built-in representation for an object.

    Use an object node to represent an entire object that should be extracted.

    Each entry of ``examples`` pairs a text with either a single mapping or a
    sequence of mappings.

    Example:

        .. code-block:: python

            object = Object(
                id="cookie",
                description="Information about a cookie including price and name.",
                attributes=[
                    Text(id="name", description="The name of the cookie"),
                    Number(id="price", description="The price of the cookie"),
                ],
                examples=[
                    ("I bought this Big Cookie for $10",
                        {"name": "Big Cookie", "price": "$10"}),
                    ("Eggs cost twelve dollars", {}), # Not a cookie
                ],
            )
    """

    attributes: Sequence[Union[ExtractionSchemaNode, Selection, Object]]
    examples: Sequence[
        Tuple[str, Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]
    ] = ()

    def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
        """Dispatch to ``visitor.visit_object``, forwarding keyword arguments."""
        return visitor.visit_object(self, **kwargs)

    @classmethod
    def parse_raw(cls, *args: Any, **kwargs: Any) -> Object:
        """Parse raw data.

        Delegates to the parent ``parse_raw`` when running on pydantic 1.

        Raises:
            NotImplementedError: if the pydantic major version is not 1.
        """
        if PYDANTIC_MAJOR_VERSION == 1:
            return super().parse_raw(*args, **kwargs)
        raise NotImplementedError(
            f"parse_raw is not supported for pydantic {PYDANTIC_MAJOR_VERSION}"
        )

    @classmethod
    def parse_obj(cls, *args: Any, **kwargs: Any) -> Object:
        """Parse an object.

        Delegates to the parent ``parse_obj`` when running on pydantic 1.

        Raises:
            NotImplementedError: if the pydantic major version is not 1.
        """
        if PYDANTIC_MAJOR_VERSION == 1:
            return super().parse_obj(*args, **kwargs)
        raise NotImplementedError(
            f"parse_obj is not supported for pydantic {PYDANTIC_MAJOR_VERSION}"
        )