Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions examples/human.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
import sys
import time
import argparse

from scienceworld import ScienceWorldEnv


# Optional dependency: prompt_toolkit gives the console command line history
# and autocompletion. It is only enabled when stdout is attached to a terminal.
try:
    from prompt_toolkit import prompt
    from prompt_toolkit.completion import WordCompleter
    from prompt_toolkit.history import InMemoryHistory
except ImportError:
    prompt_toolkit_available = False
else:
    prompt_toolkit_available = sys.stdout.isatty()

# Fallback for command line history when prompt_toolkit is not available.
# The module is imported only for its side effect, hence the unused-import
# suppression.
try:
    import readline  # noqa: F401
except ImportError:
    pass


def userConsole(args):
""" Example user input console, to play through a game. """
history = None
if prompt_toolkit_available:
history = InMemoryHistory()

exitCommands = ["quit", "exit"]

taskIdx = args['task_num']
Expand Down Expand Up @@ -99,12 +121,22 @@ def userConsole(args):
print("isCompleted: " + str(isCompleted))
#print("info: " + str(info))

print("'help' lists valid action templates, 'objects' lists valid objects, 'valid' lists valid action-object combinations (long!). ")
print("'help' lists valid action templates, 'objects' lists valid objects, use <tab> to list valid actions. ")
print("'goals' lists progress on subgoals.")
print("type 'exit' to quit.")

# Select a random action
validActions = env.getValidActionObjectCombinations()

# Get user input
userInputStr = input('> ')
if prompt_toolkit_available:
actions_completer = WordCompleter(validActions, ignore_case=True, sentence=True)
userInputStr = prompt('> ', completer=actions_completer,
history=history, enable_history_search=True)
else:
print("Valid Actions: " + str(validActions))
userInputStr = input('> ')

# Sanitize input
userInputStr = userInputStr.lower().strip()

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
gymnasium @ git+https://github.com/MarcCote/Gymnasium.git@enh_vector_reset_options
py4j
7 changes: 7 additions & 0 deletions scienceworld/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from .version import __version__
from .scienceworld import ScienceWorldEnv
from .scienceworld import BufferedHistorySaver

from gymnasium.envs.registration import register

# Register the Gymnasium wrapper under the id "ScienceWorld-v0". The entry
# point names the ScienceWorldEnv class of the scienceworld.scienceworld_gym
# module. This runs as a side effect of importing the package.
register(
id="ScienceWorld-v0",
entry_point="scienceworld.scienceworld_gym:ScienceWorldEnv",
)
Binary file not shown.
100 changes: 49 additions & 51 deletions scienceworld/scienceworld.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
# scienceworld.py
#
# conda create --name scienceworld python=3.8
# conda activate scienceworld
# pip install py4j (for scala-python interface)
# pip install -U pywebio (for web server)

from py4j.java_gateway import JavaGateway, GatewayParameters, launch_gateway, CallbackServerParameters

import os
import json
import numpy as np

import logging
import scienceworld


BASEPATH = os.path.dirname(os.path.abspath(__file__))
JAR_FILE = 'scienceworld-{version}.jar'.format(version=scienceworld.__version__)
JAR_PATH = os.path.join(BASEPATH, JAR_FILE)
Expand All @@ -22,8 +19,7 @@ class ScienceWorldEnv:
#
# Constructor
#
def __init__(self, taskName, serverPath=None, envStepLimit=100):
self.taskName = taskName
def __init__(self, taskName=None, serverPath=None, envStepLimit=100):
serverPath = serverPath or JAR_PATH # Use the builtin jar.

# Launch the server and connect to the JVM.
Expand All @@ -49,12 +45,15 @@ def __init__(self, taskName, serverPath=None, envStepLimit=100):
python_port)

self.server = self._gateway.jvm.scienceworld.runtime.pythonapi.PythonInterface()
logger.info("ScienceWorld server running on port", port)

# Keep track of the last step score, to calculate reward from score
self.lastStepScore = 0

# Load the script
self.load(self.taskName, 0, "")
self.taskName = taskName
if self.taskName:
self.load(taskName, 0, "")

# Set the environment step limit
self.envStepLimit = envStepLimit
Expand All @@ -65,33 +64,48 @@ def __init__(self, taskName, serverPath=None, envStepLimit=100):
# By default, set that the gold path was not generated unless the user asked for it
self.goldPathGenerated = False

# Ask the simulator to load an environment from a script
def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath=False):
""" Load a given task and its variation. """

# Check loading arguments.
if isinstance(taskName, int):
# Retrieve task from its id.
taskName = self.getTaskNames()[taskName]

#
# Methods
#
# Validate task name.
if taskName not in self.getTaskNames():
msg = "Unknown taskName: '{}'. ".format(taskName)
msg += "Supported tasks are: {}".format(self.getTaskNames())
raise ValueError(msg)

# Ask the simulator to load an environment from a script
def load(self, taskName, variationIdx, simplificationStr, generateGoldPath=False):
self.scriptFilename = taskName
self.taskName = taskName

logger.info("Load: " + self.scriptFilename + " (variation: " + str(variationIdx) + ")" + " (simplifications: " + simplificationStr + ")")
# Validate simplification string.
possible_simplifications = ["easy"] + self.getPossibleSimplifications()
for simplification in simplificationStr.split(","):
if simplification and simplification not in possible_simplifications:
msg = "Unknown simplification: '{}'. ".format(simplification)
msg += "Supported simplifications are: {}".format(possible_simplifications)
raise ValueError(msg)

is_electrical_task = "power-component" in taskName or "conductivity" in taskName
if is_electrical_task and "noElectricalAction" in simplificationStr:
msg = "Invalid simplification. Task '{}' requires electrical actions but '--no-electrical' was provided."
raise ValueError(msg.format(taskName))

errMsg = self.server.load(self.scriptFilename, variationIdx, simplificationStr, generateGoldPath)
if errMsg and taskName: # Do not raise error if intentionally loading empty task
raise RuntimeError(errMsg)
self.simplificationStr = simplificationStr
self.variationIdx = variationIdx

logger.info(f"Loading: {self.taskName} (variation: {self.variationIdx}) (simplifications: {self.simplificationStr})")
self.server.load(self.taskName, self.variationIdx, self.simplificationStr, generateGoldPath)

# Reset last step score (used to calculate reward from current-previous score)
self.lastStepScore = 0

# Keep track of whether the gold path was generated, to generate verbose error messages
self.goldPathGenerated = generateGoldPath


# Ask the simulator to reset an environment back to its initial state
def reset(self):
self.server.reset()
Expand All @@ -102,29 +116,15 @@ def reset(self):
# Make first move
observation, score, isCompleted, info = self.step("look around")

# Return a tuple that looks like the Jericho signature for reset
# Return a tuple that looks like the Jericho signature for reset
return observation, info

# Ask the simulator to reset an environment back to its initial state
def resetWithVariation(self, variationIdx, simplificationStr):
self.load(self.scriptFilename, variationIdx, simplificationStr)

# Reset last step score (used to calculate reward from current-previous score)
self.lastStepScore = 0

# Make first move
observation, score, isCompleted, info = self.step("look around")

# Return a tuple that looks like the Jericho signature for reset
return observation, info


# Simplifications
def getSimplificationsUsed(self):
return self.server.getSimplificationsUsed()

def getPossibleSimplifications(self):
return self.server.getPossibleSimplifications()
return self.server.getPossibleSimplifications().split(", ")


# Get a list of valid tasks/environments
Expand Down Expand Up @@ -224,7 +224,6 @@ def getTaskDescription(self):
#
def getRunHistory(self):
historyStr = self.server.getRunHistoryJSON()
#logger.info("historyStr: " + str(historyStr))
jsonOut = json.loads(historyStr)
return jsonOut

Expand Down Expand Up @@ -255,7 +254,6 @@ def saveRunHistories(self, filenameOutPrefix):
logger.info("* Saving run history (" + str(filenameOut) + ")...")

with open(filenameOut, 'w') as outfile:
#logger.info(type(self.runHistories))
json.dump(self.runHistories, outfile, sort_keys=True, indent=4)

def getRunHistorySize(self):
Expand Down Expand Up @@ -319,18 +317,19 @@ def step(self, inputStr:str):
if (score < 0):
isCompleted = True

#logger.info("> " + str(inputStr))
#logger.info("score: " + str(score))
#logger.info("moves: " + str(numMoves))

# Mirror of Jericho API
infos = {'moves': numMoves,
'score': score,
'reward': reward,
'look': self.look(),
'inv': self.inventory(),
'taskDesc': self.taskdescription(),
'valid': self.getValidActionObjectCombinations() }
infos = {
'moves': numMoves,
'score': score,
'reward': reward,
'look': self.look(),
'inv': self.inventory(),
'taskDesc': self.taskdescription(),
'valid': self.getValidActionObjectCombinations(),
'variationIdx': self.variationIdx,
'taskName': self.taskName,
'simplificationStr': self.simplificationStr,
}

return observation, reward, isCompleted, infos

Expand Down Expand Up @@ -395,7 +394,6 @@ def saveRunHistories(self):
logger.info("* Saving run history ( " + str(filenameOut) + ")...")

with open(filenameOut, 'w') as outfile:
#logger.info(type(self.runHistories))
json.dump(self.runHistories, outfile, sort_keys=True, indent=4)

def getRunHistorySize(self):
Expand Down
93 changes: 93 additions & 0 deletions scienceworld/scienceworld_gym.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import numpy as np

import gymnasium as gym
from gymnasium import spaces

import scienceworld


class String(spaces.Space):
    """Minimal Gymnasium space whose members are arbitrary Python strings.

    The space carries no parameters, so every ``String`` instance is
    considered equal to every other one.
    """

    def __init__(self):
        super().__init__(dtype=str)

    def sample(self, mask=None, probability=None):
        """Return a placeholder sample, which is always the empty string.

        ``mask`` and ``probability`` are accepted only so that the signature
        stays compatible with ``Space.sample``; both are ignored.
        """
        return ''

    def contains(self, obj):
        """Return True if ``obj`` is a ``str``."""
        return isinstance(obj, str)

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance."""
        return isinstance(other, String)

    def __hash__(self) -> int:
        # Defining __eq__ alone would make instances unhashable. All String
        # spaces compare equal, so they must share a single hash value.
        return hash(String)


class ScienceWorldEnv(gym.Env):
    """Gymnasium wrapper around ``scienceworld.ScienceWorldEnv``.

    Observations and actions are plain strings (see ``String``). The task,
    variation(s), simplifications and gold-path generation are selected
    through the ``options`` dictionary given to :meth:`reset`.
    """

    metadata = {"render_modes": ["human", "ansi"]}

    def __init__(self, render_mode=None):
        """Create the wrapper.

        Args:
            render_mode: ``None``, ``"human"`` (print each command and
                observation) or ``"ansi"`` (``render()`` returns them as text).
        """
        # Validate first so that an invalid argument does not leave a
        # simulator server running in the background.
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        # Must be stored on the instance: otherwise `self.render_mode` falls
        # back to the `gym.Env` class default (None) and rendering is disabled.
        self.render_mode = render_mode

        self.env = scienceworld.ScienceWorldEnv()

        self.observation_space = String()
        self.action_space = String()
        self.options = {}
        self.variation_id = 0

        # Nothing to render until the first reset().
        self.last_command = None
        self.observation = None

    def reset(self, seed=None, options=None):
        """Load the next variation of the configured task and restart it.

        Args:
            seed: Seed forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: Optional dict with the keys ``"task"``, ``"variation"``
                (``"train"``, ``"dev"``, ``"test"``, an int, or an iterable of
                ints; defaults to the train variations), ``"simplification"``
                and ``"generate_gold_path"``. When given, it replaces the
                previously stored options and restarts the variation cycle.
                When omitted, the stored options are reused and the next
                variation in the shuffled list is loaded.

        Returns:
            The ``(observation, info)`` tuple of the underlying environment.

        Raises:
            ValueError: If no task has been configured yet, or if the
                variation is an unknown string.
        """
        # We need the following line to seed self.np_random
        super().reset(seed=seed)

        if options is not None:
            self.options = dict(options)

            # The task is loaded first so that the simulator can be asked for
            # the variations of that task.
            task = self.options.get("task")
            self.env.load(task)

            variations = self.options.get("variation")
            if variations == "train" or variations is None:
                variations = self.env.getVariationsTrain()
            elif variations == "dev":
                variations = self.env.getVariationsDev()
            elif variations == "test":
                variations = self.env.getVariationsTest()
            elif isinstance(variations, int):
                variations = [variations]
            elif isinstance(variations, str):
                # Without this check list() below would silently split the
                # string into single characters.
                msg = "Unknown variation: '{}'. ".format(variations)
                msg += "Expected 'train', 'dev', 'test', an int or a list of ints."
                raise ValueError(msg)

            self.options["variation"] = list(variations)
            self.np_random.shuffle(self.options["variation"])
            self.variation_id = 0

        if "task" not in self.options or "variation" not in self.options:
            raise ValueError("No task selected. Provide options={'task': ...} "
                             "on the first call to reset().")

        task = self.options["task"]
        variation = self.options["variation"][self.variation_id]
        simplification = self.options.get("simplification", "")
        generate_gold_path = self.options.get("generate_gold_path", False)
        # TODO: if task is not provided, choose one at random.

        # Advance variation counter.
        self.variation_id = (self.variation_id + 1) % len(self.options["variation"])

        self.env.load(task, variation, simplification, generate_gold_path)
        self.observation, info = self.env.reset()
        self.last_command = "look around"

        if self.render_mode == "human":
            print(">", self.last_command, "\n", self.observation)

        return self.observation, info

    def step(self, command):
        """Send ``command`` to the simulator.

        Returns:
            ``(observation, reward, terminated, truncated, info)``, where
            ``truncated`` is always ``False``.
        """
        self.last_command = command
        self.observation, reward, terminated, info = self.env.step(command)

        if self.render_mode == "human":
            print(">", self.last_command, "\n", self.observation)

        return self.observation, reward, terminated, False, info  # Truncated is always False

    def render(self):
        """Return the last command and observation as text in "ansi" mode.

        Returns ``None`` in every other mode, and before the first reset().
        """
        if self.render_mode == "ansi" and self.last_command is not None:
            return ">" + self.last_command + "\n" + self.observation
        return None
2 changes: 1 addition & 1 deletion scienceworld/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.0.3rc2'
__version__ = '1.1.0rc1'
2 changes: 1 addition & 1 deletion simulator/build.sbt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name := "scienceworld-scala"

version := "1.0.3rc2"
version := "1.1.0rc1"

scalaVersion := "2.12.9"

Expand Down
Loading