Source code for ayx_python_sdk.providers.amp_provider.repositories.output_record_packet_repository

# Copyright (C) 2022 Alteryx, Inc. All rights reserved.
#
# Licensed under the ALTERYX SDK AND API LICENSE AGREEMENT;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.alteryx.com/alteryx-sdk-and-api-license-agreement
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class that saves/retrieves output record packets."""
import logging
from typing import Dict, List, TYPE_CHECKING

from ayx_python_sdk.providers.amp_provider.builders import RecordPacketBuilder
from ayx_python_sdk.providers.amp_provider.repositories.grpc_repository import (
    GrpcRepository,
)
from ayx_python_sdk.providers.amp_provider.repositories.singleton import Singleton
from ayx_python_sdk.providers.amp_provider.resources.generated.outgoing_record_packet_push_pb2 import (
    OutgoingRecordPacketPush,
)
from ayx_python_sdk.providers.amp_provider.resources.generated.record_packet_pb2 import (
    RecordPacket as ProtobufRecordPacket,
)

if TYPE_CHECKING:
    from ayx_python_sdk.core.record_packet_base import RecordPacketBase

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class OutputRecordPacketRepository(metaclass=Singleton):
    """
    Repository that stores output record packets.

    State is kept per output anchor name in three dictionaries: the most
    recently saved record packet, a running count of the protobuf packets
    generated so far (handed to the record packet builder on each
    conversion), and the last saved progress value.
    """

    # One builder shared by every conversion to protobuf format.
    _record_packet_builder = RecordPacketBuilder()

    def __init__(self) -> None:
        """Initialize the output record packet repository."""
        self._record_packet_map: Dict[str, "RecordPacketBase"] = {}
        self._record_packet_sequence: Dict[str, int] = {}
        self._anchor_progress: Dict[str, float] = {}

    def save_record_packet(
        self, anchor_name: str, record_packet: "RecordPacketBase"
    ) -> None:
        """
        Save a record packet.

        The packet replaces any packet previously stored for the anchor. If an
        SDK engine client can be obtained, the packet is then converted to
        protobuf format and every resulting packet is pushed through that
        client.

        Parameters
        ----------
        anchor_name
            The name of the anchor that the record packet is associated with.
        record_packet
            The record packet to save to the repository.
        """
        logger.debug(
            "Saving record packet for %s in OutputRecordPacketRepository", anchor_name
        )
        self._record_packet_map[anchor_name] = record_packet
        logger.debug(
            "Current OutputRecordPacketRepository State: %r", self._record_packet_map
        )

        try:
            client = GrpcRepository().get_sdk_engine_client()
        except ValueError as err:
            # Best effort: the packet stays stored even though it cannot be
            # pushed. Leave a trace instead of silently ignoring the failure.
            logger.debug(
                "SDK engine client unavailable (%s); record packet for anchor %s "
                "was saved but not pushed",
                err,
                anchor_name,
            )
            return

        logger.debug("Pushing output record packet for anchor %s", anchor_name)
        # Use a distinct loop name so the ``record_packet`` parameter is not
        # shadowed by the protobuf packets.
        for pb_packet in self.get_grpc_record_packets(anchor_name):
            client.PushOutgoingRecordPacket(
                OutgoingRecordPacketPush(
                    anchor_name=anchor_name,
                    record_packet=pb_packet,
                )
            )

    def get_record_packet(self, anchor_name: str) -> "RecordPacketBase":
        """
        Get a record packet.

        Parameters
        ----------
        anchor_name
            The name of the output anchor that the record packet is associated with.

        Returns
        -------
        RecordPacketBase
            The record packet associated with the anchor name.

        Raises
        ------
        ValueError
            If no record packet has been saved for the anchor name.
        """
        try:
            return self._record_packet_map[anchor_name]
        except KeyError:
            # Callers expect ValueError; the KeyError adds nothing useful.
            raise ValueError(f"Anchor {anchor_name} not found in repository.") from None

    def get_grpc_record_packets(self, anchor_name: str) -> List[ProtobufRecordPacket]:
        """
        Get a record packet in protobuf format.

        Each call advances the anchor's stored sequence counter by the number
        of protobuf packets produced.

        Parameters
        ----------
        anchor_name
            The name of the anchor whose record packet should be converted.

        Returns
        -------
        List[ProtobufRecordPacket]
            The list of protobuf record packets that are associated with the
            passed in anchor name.

        Raises
        ------
        ValueError
            If no record packet has been saved for the anchor name.
        """
        # Look the packet up first so an unknown anchor raises before any
        # sequence state is created for it.
        record_packet = self.get_record_packet(anchor_name)
        sequence = self._record_packet_sequence.setdefault(anchor_name, 0)

        pb_packets = self._record_packet_builder.to_protobuf(
            record_packet,
            sequence,
            self.get_anchor_progress(anchor_name),
        )

        self._record_packet_sequence[anchor_name] += len(pb_packets)
        return pb_packets

    def save_anchor_progress(self, anchor_name: str, progress: float) -> None:
        """
        Save the anchor progress.

        Parameters
        ----------
        anchor_name
            The name of the anchor whose progress is being saved.
        progress
            The progress of the anchor, as a fraction between 0 and 1
            inclusive.

        Raises
        ------
        ValueError
            If progress is not between 0 and 1.
        """
        if not 0.0 <= progress <= 1.0:
            raise ValueError("Progress must be between 0 and 1.")

        self._anchor_progress[anchor_name] = progress

    def get_anchor_progress(self, anchor_name: str) -> float:
        """
        Get the anchor progress.

        Parameters
        ----------
        anchor_name
            The name of the anchor whose progress should be returned.

        Returns
        -------
        progress
            The progress saved for the anchor, or 0.0 if none has been saved.
        """
        return self._anchor_progress.get(anchor_name, 0.0)

    def clear_repository(self) -> None:
        """Delete all data in the repository."""
        logger.debug("Clearing OutputRecordPacketRepository")
        self._record_packet_map = {}
        self._record_packet_sequence = {}
        self._anchor_progress = {}
        logger.debug(
            "Current OutputRecordPacketRepository State: %r", self._record_packet_map
        )