Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,5 @@ venv
*$py.class

# Protobuf auto-generated bindings
firmware/src/protocol.pb.[ch]
firmware/src/nanopb.pb.[ch]
cartpole/device/protocol_pb2.py
cartpole/device/nanopb_pb2.py
firmware/src/proto
cartpole/device/proto
60 changes: 42 additions & 18 deletions cartpole/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,46 @@

def parse_args():
common = argparse.ArgumentParser(
prog='cartpole',
description='cartpole control experiments'
prog='cartpole', description='cartpole control experiments'
)

subparsers = common.add_subparsers(title='commands', dest='command', required=True, help='command help')
subparsers = common.add_subparsers(
title='commands', dest='command', required=True, help='command help'
)

# common arguments

common.add_argument('-S', '--simulation', action='store_true', help='simulation mode')
common.add_argument(
'-S', '--simulation', action='store_true', help='simulation mode'
)
common.add_argument('-c', '--config', type=str, help='cartpole yaml config file')
common.add_argument('-m', '--mcap', type=str, default='', help='mcap log file')
common.add_argument('-a', '--advance', type=float, default=0.01, help='advance simulation time (seconds)')
common.add_argument(
'-a',
'--advance',
type=float,
default=0.01,
help='advance simulation time (seconds)',
)

# eval arguments
eval = subparsers.add_parser('eval', help='system identification')

eval.add_argument('-d', '--duration', type=float, default=10.0, help='experiment duration (seconds)')
eval.add_argument(
'-d',
'--duration',
type=float,
default=10.0,
help='experiment duration (seconds)',
)
eval.add_argument('-O', '--output', type=str, help='output yaml config file')

return common.parse_args()


def evaluate(device: CartPoleBase, config: Config, args: argparse.Namespace) -> None:
log.info('parameters evaluation')
random.seed(0)

position_margin = 0.01
position_tolerance = 0.005
Expand All @@ -48,8 +64,10 @@ def evaluate(device: CartPoleBase, config: Config, args: argparse.Namespace) ->

log.info(f'run calibration session for {duration:.2f} seconds')
device.reset()
time.sleep(5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Это чтоб палочка успокоилась?


start = device.get_state()
print(start)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Это потом подчистите?

state = start

target = Target(position=0, velocity=0, acceleration=0)
Expand All @@ -59,25 +77,29 @@ def evaluate(device: CartPoleBase, config: Config, args: argparse.Namespace) ->
state = device.get_state()

if abs(target.position - state.cart_position) < position_tolerance:
position = random.uniform(position_max/2, position_max)
position = random.uniform(position_max / 2, position_max)
target.position = position if target.position < 0 else -position
target.velocity = random.uniform(velocity_max/2, velocity_max)
target.acceleration = random.uniform(acceleration_max/2, acceleration_max)
target.velocity = random.uniform(velocity_max / 2, velocity_max)
target.acceleration = random.uniform(acceleration_max / 2, acceleration_max)

log.info(f'target {target}')
device.set_target(target)
state = device.set_target(target)

log.publish('/cartpole/state', state, state.stamp)
log.publish('/cartpole/target', target, state.stamp)
log.publish('/cartpole/info', device.get_info(), state.stamp)
Comment on lines -70 to -72
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

А почему решили убрать stamp?

log.publish('/cartpole/state', state)
log.publish('/cartpole/target', target)
# log.publish('/cartpole/info', device.get_info(), state.stamp)

states.append(state)
device.advance(advance)

if args.simulation:
time.sleep(advance) # simulate real time
time.sleep(advance) # simulate real time

if state.error:
print('ERR', state.error)

log.info(f'find parameters')
print(len(states))
parameters = find_parameters(states, config.parameters.gravity)

log.info(f'parameters: {parameters}')
Expand Down Expand Up @@ -106,24 +128,26 @@ def main():

if args.simulation:
log.info('simulation mode')
device = Simulator(integration_step=min(args.advance/20, 0.001))
device = Simulator(integration_step=min(args.advance / 20, 0.001))
else:
raise NotImplementedError()

from cartpole.device import CartPoleDevice
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

МБ сверху лучше?


device = CartPoleDevice(hard_reset=True)

if args.config:
log.debug(f'config file: {args.config}')
config = Config.from_yaml_file(args.config)
else:
log.warning('no config file specified, using defaults')
config = Config()


device.set_config(config)

if args.command == 'eval':
evaluate(device, config, args)
else:
raise NotImplementedError()


if __name__ == "__main__":
main()
29 changes: 7 additions & 22 deletions cartpole/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import torch

from pydantic import BaseModel, Field
from typing import Any
from typing import Any, Optional

import json
import yaml
import time


class Limits(BaseModel):
Expand Down Expand Up @@ -152,7 +153,7 @@ class State(BaseModel):
pole_angle - absolute accumulated pole angle (rad)
pole_angular_velocity - pole angular velocity (rad/s)

stamp - system time stamp (s)
stamp - system time stamp in seconds
error - system error code
'''

Expand All @@ -163,25 +164,9 @@ class State(BaseModel):
pole_angle: float = 0.0
pole_angular_velocity: float = 0.0

stamp: float = 0.0
stamp: float = Field(default_factory=time.perf_counter)
error: Error = Error.NO_ERROR

class Config:
@staticmethod
def json_schema_extra(schema: Any, model: Any) -> None:
# make schema lightweight
schema.pop('definitions', None)

properties = schema['properties']
for name in properties:
properties[name].pop('title', None)

# simplify schema for foxglove
properties['error'] = {
'type': 'integer',
'enum': [e.value for e in Error]
}

def validate(self, config: Config) -> None:
'''
Validates state against limits.
Expand Down Expand Up @@ -240,9 +225,9 @@ class Target(BaseModel):
If velocity/accleration is not specified (absolute value needed), use control limit as a default.
'''

position: float | None = None
velocity: float | None = None
acceleration: float | None = None
position: Optional[float] = None
velocity: Optional[float] = None
acceleration: Optional[float] = None

def acceleration_or(self, default: float) -> float:
'''
Expand Down
2 changes: 1 addition & 1 deletion cartpole/device/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from cartpole.device._device import *
from cartpole.device.device import CartPoleDevice
79 changes: 0 additions & 79 deletions cartpole/device/_device.py

This file was deleted.

Loading