Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cartpole/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from cartpole.common import Config, Error, State

3 changes: 1 addition & 2 deletions cartpole/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from cartpole.common.interface import Config, Error, State, CartPoleBase
from cartpole.common.view import generate_pyplot_animation
from cartpole.common.interface import Config, Error, State
114 changes: 35 additions & 79 deletions cartpole/common/interface.py
Original file line number Diff line number Diff line change
@@ -1,105 +1,67 @@
import enum
import dataclasses as dc

import numpy as np
import numpy
import torch

from pydantic import BaseModel
from typing import Any

class Error(enum.IntEnum):
NO_ERROR = 0
NEED_RESET = 1
X_OVERFLOW = 2
V_OVERFLOW = 3
A_OVERFLOW = 4
MOTOR_STALLED = 5
ENDSTOP_HIT = 6
CART_POSITION_OVERFLOW = 2
CART_VELOCITY_OVERFLOW = 3
CART_ACCELERATION_OVERFLOW = 4
HARDWARE = 5

def __bool__(self) -> bool:
return self != Error.NO_ERROR


@dc.dataclass
class Config:
class Config(BaseModel):
# software cart limits
max_position: float = 0.25 # m
max_velocity: float = 2.0 # m/s
max_acceleration: float = 3.5 # m/s^2
# hardware limits
hard_max_position: float = 0.27 # m
hard_max_velocity: float = 2.5 # m/s
hard_max_acceleration: float = 5.0 # m/s^2
# hardware flags
clamp_position: bool = False
clamp_velocity: bool = False
clamp_acceleration: bool = False
# physical params
pole_length: float = 0.3 # m
pole_mass: float = 0.118 # kg
gravity: float = 9.8 # m/s^2


@dc.dataclass
class State:
cart_position: float = 0
cart_velocity: float = 0
pole_angle: float = 0
pole_angular_velocity: float = 0
error: Error = Error.NO_ERROR
cart_acceleration: float = 0

@staticmethod
def from_array(a):
'''
q = (x, a, v, w)
'''
return State(a[0], a[2], a[1], a[3])

@staticmethod
def home():
return State(.0, .0, .0, .0)

def as_tuple(self):
return (
self.cart_position,
self.pole_angle,
self.cart_velocity,
self.pole_angular_velocity,
)
class State(BaseModel):
# cart state
cart_position: float = 0.0
cart_velocity: float = 0.0

def as_array(self):
return np.array(self.as_tuple())
# pole state
pole_angle: float = 0.0
pole_angular_velocity: float = 0.0

def as_array_4x1(self):
return self.as_array().reshape(4, 1)
# control state
cart_acceleration: float = 0.0
error: Error = Error.NO_ERROR

def __repr__(self):
return '(x={x:+.2f}, v={v:+.2f}, a={a:+.2f}, w={w:+.2f}, err={err})'.format(
x = self.cart_position,
v = self.cart_velocity,
a = self.pole_angle,
w = self.pole_angular_velocity,
err=self.error,
)
@staticmethod
def home() -> 'State':
return State()


class CartPoleBase:
'''
Description:
Сlass implements a physical simulation of the cart-pole device.
The class specifies a interface of the cart-pole (device or simulation).
A pole is attached by an joint to a cart, which moves along guide axis.
The pendulum is initially at rest state. The goal is to maintain it in
upright pose by increasing and reducing the cart's velocity.
upright pose by increasing and reducing cart's velocity.
Source:
This environment is some variation of the cart-pole problem
described by Barto, Sutton, and Anderson
Initial state:
A pole is at starting position 0 with no velocity and acceleration.
'''

def reset(self, config: Config) -> None:
def __init__(self, config: Config):
self.config = config

def reset(self, state: State = State.home()) -> None:
'''
Resets the device to the initial state.
The pole is at rest position and cart is centered.
Resets the device to the state.
It must be called at the beginning of any session.
For real device only position may be set.
'''
raise NotImplementedError

Expand All @@ -109,33 +71,27 @@ def get_state(self) -> State:
'''
raise NotImplementedError

def get_info(self) -> dict:
def get_info(self) -> Any:
'''
Returns usefull debug information.
'''
raise NotImplementedError

def get_target(self) -> float:
'''
Returns current target acceleration.
'''
raise NotImplementedError

def set_target(self, target: float) -> None:
'''
Set desired target acceleration.
'''
raise NotImplementedError

def advance(self, delta: float = None) -> None:
def advance(self, delta: float) -> None:
'''
Advance the dynamic system by delta seconds.
Advance system by delta seconds (has means only for simulation).
'''
pass

def timestamp(self):
def timestamp(self) -> float:
'''
Current time.
Current time in seconds (float).
'''
raise NotImplementedError

Expand Down
196 changes: 196 additions & 0 deletions cartpole/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import asyncio
import json
import logging
import time

from threading import Event, Thread
from foxglove_websocket.server import FoxgloveServer
from mcap.writer import Writer
from pydantic import BaseModel
from typing import Any

from cartpole import State


def to_ns(t: float) -> int:
'''
Convert time in seconds to nanoseconds
'''
return int(t * 1e9)

class Registration(BaseModel):
cls: Any
channel_id: int

class MCAPLogger:
def __init__(self, log_path: str):
'''
Args:
log_path: path to mcap log file
'''

self._writer = Writer(open(log_path, "wb"))
self._topic_to_registration: Dict[str, Registration] = {}

def _register(self, topic_name: str, cls: Any) -> None:
assert issubclass(cls, BaseModel), 'Required pydantic model, but got {cls.__name__}'

if topic_name in self._topic_to_registration:
cached = self._topic_to_registration[topic_name]
assert cached.cls == cls, f'Topic {topic} already registered with {cached.cls.__name__}'
return

schema_id = self._writer.register_schema(
name=cls.__name__,
encoding="jsonschema",
data=cls.schema_json().encode())

channel_id = self._writer.register_channel(
schema_id=schema_id,
topic=topic_name,
message_encoding="json")

self._topic_to_registration[topic_name] = Registration(cls=cls, channel_id=channel_id)

def publish(self, topic: str, obj: BaseModel, stamp: float) -> None:
'''
Args:
* topic: topic name
* obj: object to dump (pydantic model)
* stamp: timestamp in nanoseconds (float)
'''

self._register(topic, type(obj))
self._writer.add_message(
channel_id=self._topic_to_registration[topic].channel_id,
log_time=to_ns(stamp),
data=obj.json().encode(),
publish_time=to_ns(stamp))

def foxglove_logger() -> logging.Logger:
logger = logging.getLogger("LogServer")
logger.setLevel(logging.ERROR)

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s: [%(levelname)s] %(message)s"))
logger.addHandler(handler)

return logger

async def _foxglove_async_entrypoint(queue: asyncio.Queue, stop: Event) -> None:
async with FoxgloveServer("0.0.0.0", 8765, "CartPole", logger=foxglove_logger()) as server:
topic_to_registration = {}

async def register(topic_name, cls):
assert issubclass(cls, BaseModel), f'Required pydantic model, but got {cls.__name__}'

if topic_name in topic_to_registration:
cached = topic_to_registration[topic_name]
assert cached.cls == cls, f'Topic {topic} already registered with {cached.cls.__name__}'
return

spec = {
"topic": topic_name,
"encoding": "json",
"schemaName": cls.__name__,
"schema": cls.schema_json().encode(),
}

channel_id = await server.add_channel(spec)
topic_to_registration[topic_name] = Registration(cls=cls, channel_id=channel_id)

while not stop.is_set():
topic_name, stamp, obj = await queue.get()
await register(topic_name, type(obj))
channel_id = topic_to_registration[topic_name].channel_id
await server.send_message(channel_id, to_ns(stamp), obj.json().encode())

def foxglove_main(loop: asyncio.AbstractEventLoop, queue: asyncio.Queue, stop: Event) -> None:
asyncio.set_event_loop(loop)
loop.run_until_complete(_foxglove_async_entrypoint(queue, stop))

class FoxgloveWebsocketLogger:
def __init__(self):
self._loop = asyncio.new_event_loop()
self._queue = asyncio.Queue()
self._stop = Event()
self._writer = None

self._foxlgove_thread = Thread(
target=foxglove_main,
name='foxglove_main_loop',
daemon=True,
args=(self._loop, self._queue, self._stop))

self._foxlgove_thread.start()

def publish(self, topic_name: str, obj: BaseModel, stamp: float) -> None:
'''
Args:
* topic_name: topic name
* obj: object to dump (pydantic model)
* stamp: timestamp in nanoseconds (float)
'''

if not (self._loop.is_running() and self._foxlgove_thread.is_alive()):
raise RuntimeError('Foxglove logger is not running')

item = (topic_name, stamp, obj)
asyncio.run_coroutine_threadsafe(self._queue.put(item), self._loop)

def __del__(self):
self._stop.set()

class Logger:
def __init__(self, log_path: str = ''):
'''
Args:
* log_path: path to mcap log file, if not provided, no mcap log will be created
'''

self._foxglove_log = FoxgloveWebsocketLogger()
self._mcap_log = None
if log_path:
self.mcap_log = MCAPLogger(log_path)

def publish(self, topic_name: str, obj: BaseModel, stamp: float) -> None:
'''
Args:
* topic_name: topic name
* obj: pydantic model
* stamp: timestamp in nanoseconds (float)
'''

if self._mcap_log:
self._mcap_log.publish(topic_name, obj, stamp)

self._foxglove_log.publish(topic_name, obj, stamp)


__logger = None

def setup(log_path: str = '') -> None:
'''
Args:
* log_path: path to mcap log file, if not provided, no mcap log will be created
'''

global __logger
__logger = Logger(log_path)


def publish(topic_name: str, obj: BaseModel, stamp: float|None = None) -> None:
'''
Args:
* topic_name: topic name
* obj: pydantic model
* stamp: timestamp in nanoseconds (float), if not provided, current time used
'''

if not __logger:
setup()

if stamp is None:
stamp = time.time()

__logger.publish(topic_name, obj, stamp)
Loading